From 431ab035365f9cfbe98b2ce0c4ef901ce9fec142 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 25 Mar 2020 12:30:37 +0300 Subject: [PATCH 001/233] initial commit Signed-off-by: raver119 --- libnd4j/include/memory/ColdZoneManager.h | 35 ++++++++++++ libnd4j/include/memory/MemoryZone.h | 32 +++++++++++ libnd4j/include/memory/ZoneManager.h | 57 +++++++++++++++++++ .../include/memory/impl/ColdZoneManager.cpp | 21 +++++++ libnd4j/include/memory/impl/ZoneManager.cpp | 21 +++++++ 5 files changed, 166 insertions(+) create mode 100644 libnd4j/include/memory/ColdZoneManager.h create mode 100644 libnd4j/include/memory/MemoryZone.h create mode 100644 libnd4j/include/memory/ZoneManager.h create mode 100644 libnd4j/include/memory/impl/ColdZoneManager.cpp create mode 100644 libnd4j/include/memory/impl/ZoneManager.cpp diff --git a/libnd4j/include/memory/ColdZoneManager.h b/libnd4j/include/memory/ColdZoneManager.h new file mode 100644 index 000000000000..1d234e14f770 --- /dev/null +++ b/libnd4j/include/memory/ColdZoneManager.h @@ -0,0 +1,35 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + + +#ifndef SD_COLDZONEMANAGER_H +#define SD_COLDZONEMANAGER_H + +#include + +namespace sd { + class ColdZoneManager : public ZoneManager { + public: + + }; +} + + +#endif //SD_COLDZONEMANAGER_H diff --git a/libnd4j/include/memory/MemoryZone.h b/libnd4j/include/memory/MemoryZone.h new file mode 100644 index 000000000000..e31e87cf9a63 --- /dev/null +++ b/libnd4j/include/memory/MemoryZone.h @@ -0,0 +1,32 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_MEMORYZONE_H +#define SD_MEMORYZONE_H + +namespace sd { + enum MemoryZone { + COLD = 0, + WARM = 10, + HOT = 20, + }; +} + +#endif //SD_MEMORYZONE_H diff --git a/libnd4j/include/memory/ZoneManager.h b/libnd4j/include/memory/ZoneManager.h new file mode 100644 index 000000000000..02ed896f5fc2 --- /dev/null +++ b/libnd4j/include/memory/ZoneManager.h @@ -0,0 +1,57 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_ZONEMANAGER_H +#define SD_ZONEMANAGER_H + +#include +#include + +namespace sd { + /** + * Abstract class that defines common methods for zone managers + */ + class ZoneManager { + public: + ZoneManager() = default; + ~ZoneManager() = default; + + /** + * This method returns id of the current zone served by this manager instance + * @return MemoryZone enum + */ + virtual MemoryZone zone() const = 0; + + /** + * This method returns amount of memory available in this zone + * @return number of bytes + */ + virtual uint64_t available() const = 0; + + /** + * This method returns amount of memory currently used in this zone + * @return number of bytes + */ + virtual uint64_t used() const = 0; + }; +} + + +#endif //SD_ZONEMANAGER_H diff --git a/libnd4j/include/memory/impl/ColdZoneManager.cpp b/libnd4j/include/memory/impl/ColdZoneManager.cpp new file mode 100644 index 000000000000..102a82691286 --- /dev/null +++ b/libnd4j/include/memory/impl/ColdZoneManager.cpp @@ -0,0 +1,21 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include diff --git a/libnd4j/include/memory/impl/ZoneManager.cpp b/libnd4j/include/memory/impl/ZoneManager.cpp new file mode 100644 index 000000000000..4b90acddc2ad --- /dev/null +++ b/libnd4j/include/memory/impl/ZoneManager.cpp @@ -0,0 +1,21 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include From aa93c614ab7403f86cdaa83bad4ffafd12718463 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 25 Mar 2020 18:23:26 +0300 Subject: [PATCH 002/233] mmap for Graphs Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 102 ++++++++++-------- libnd4j/include/graph/impl/Graph.cpp | 29 +++++ libnd4j/include/helpers/FileUtils.h | 36 +++++++ libnd4j/include/helpers/files.h | 8 +- libnd4j/include/helpers/impl/FileUtils.cpp | 40 +++++++ libnd4j/include/memory/ColdZoneManager.h | 8 ++ .../include/memory/impl/ColdZoneManager.cpp | 6 ++ .../tests_cpu/layers_tests/OneOffTests.cpp | 12 +++ 8 files changed, 193 insertions(+), 48 deletions(-) create mode 100644 libnd4j/include/helpers/FileUtils.h create mode 100644 libnd4j/include/helpers/impl/FileUtils.cpp diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index a160872fd2db..ab1310b42f8e 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -85,6 +85,8 @@ namespace sd { ~Graph(); + static Graph* fromFlatBuffers(const char* fileName); + // this method applies toposort to nodes void toposortNodes(); @@ -222,60 +224,72 @@ namespace sd { void replaceState(VariableSpace *state, ExecutorConfiguration *configuration); - FORCEINLINE std::vector* nodes() { - return _nodes; - } + FORCEINLINE std::vector* nodes(); - FORCEINLINE std::vector* autos() { - return &_autos; - } + FORCEINLINE std::vector* autos(); - FORCEINLINE std::vector* output() { - return &_output; - } + FORCEINLINE std::vector* output(); - FORCEINLINE MAP_IMPL* scopes() { - return &_mappedScopes; - } + FORCEINLINE MAP_IMPL* scopes(); - FORCEINLINE bool built() { - return _built.load(); - } + FORCEINLINE bool built(); - FORCEINLINE void pullState(Graph *other) { - for (int e = 0; e < other->nodes()->size(); e++) - this->_nodes->emplace_back(other->nodes()->at(e)); + FORCEINLINE void pullState(Graph *other); + }; - for (int e = 0; e < other->output()->size(); e++) - this->_output.emplace_back(other->output()->at(e)); - - for (int e = 0; e < other->autos()->size(); e++) - this->_autos.emplace_back(other->autos()->at(e)); + FORCEINLINE std::vector* Graph::nodes() { + return _nodes; + } - for (auto &v: *other->scopes()) { - auto scp = v.second->clone(); - this->_mappedScopes[v.first] = scp; - this->_scopes.emplace_back(scp); - } - - for (auto &v: *other->getOnion()) { - auto vec = this->_onion->count(v.first) > 0 ? this->_onion->at(v.first) : new std::vector(); - - auto ovec = (*other->getOnion())[v.first]; - for (auto x: *(ovec)) { - auto n = x->clone(); - vec->emplace_back(n); - _handles.emplace_back(n); - (*this->_mapped)[n->id()] = n; - } - - if (this->_onion->count(v.first) < 1) - (*this->_onion)[v.first] = vec; + FORCEINLINE std::vector* Graph::autos() { + return &_autos; + } + + FORCEINLINE std::vector* Graph::output() { + return &_output; + } + + FORCEINLINE MAP_IMPL* Graph::scopes() { + return &_mappedScopes; + } + + FORCEINLINE bool Graph::built() { + return _built.load(); + } + + FORCEINLINE void Graph::pullState(Graph *other) { + for (int e = 0; e < other->nodes()->size(); e++) + this->_nodes->emplace_back(other->nodes()->at(e)); + + for (int e = 0; e < other->output()->size(); e++) + this->_output.emplace_back(other->output()->at(e)); + + for (int e = 0; e < other->autos()->size(); e++) + this->_autos.emplace_back(other->autos()->at(e)); + + for (auto &v: *other->scopes()) { + auto scp = v.second->clone(); + this->_mappedScopes[v.first] = scp; + this->_scopes.emplace_back(scp); + } + + for (auto &v: *other->getOnion()) { + auto vec = this->_onion->count(v.first) > 0 ? this->_onion->at(v.first) : new std::vector(); + + auto ovec = (*other->getOnion())[v.first]; + for (auto x: *(ovec)) { + auto n = x->clone(); + vec->emplace_back(n); + _handles.emplace_back(n); + (*this->_mapped)[n->id()] = n; } - this->_built.store(other->built()); + if (this->_onion->count(v.first) < 1) + (*this->_onion)[v.first] = vec; } - }; + + this->_built.store(other->built()); + } } } diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 15db128a850c..0d34ce89746e 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -30,6 +30,7 @@ #include #include #include +#include namespace sd { namespace graph { @@ -1453,6 +1454,34 @@ namespace sd { return hash; } + + + Graph* Graph::fromFlatBuffers(const char* fileName) { + // check if file exists + if (!FileUtils::fileExists(fileName)) + throw std::runtime_error("Graph file doesn't exist"); + + // get file size + auto fsize = FileUtils::fileSize(fileName); + Nd4jLong *ref; + void *ptrGraph; + + // check if mmap is supported + if (true) { + // mmap this file + ref = ::mmapFile(nullptr, fileName, fsize); + ptrGraph = reinterpret_cast(ref[0]); + } else { + // if mmap is not supported - load it directly + + } + + // get FlatGraph out of it + auto fg = GetFlatGraph(reinterpret_cast(ptrGraph)); + + // return Graph from this FlatGraph + return new Graph(fg); + } } } diff --git a/libnd4j/include/helpers/FileUtils.h b/libnd4j/include/helpers/FileUtils.h new file mode 100644 index 000000000000..9698b4916487 --- /dev/null +++ b/libnd4j/include/helpers/FileUtils.h @@ -0,0 +1,36 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_FILEUTILS_H +#define SD_FILEUTILS_H + +#include + +namespace sd { + class FileUtils { + public: + static bool fileExists(const char *filename); + + static int64_t fileSize(const char *filename); + }; +} + + +#endif //SD_FILEUTILS_H diff --git a/libnd4j/include/helpers/files.h b/libnd4j/include/helpers/files.h index c49cedbb737b..9be8e4e5db62 100644 --- a/libnd4j/include/helpers/files.h +++ b/libnd4j/include/helpers/files.h @@ -32,7 +32,7 @@ char *strsave(const char *s, const char *lim); char ** shellpath(void); void freeshellpath (char *shellpath[]); unsigned maxpathlen(char *path[], const char *base); -bool file_exists(char *name); +bool file_exists(const char *name); void *malloc_check(const char *what, size_t n) { void *p = malloc(n); @@ -65,7 +65,7 @@ char ** shellpath(void) { #ifdef _WIN32 char *q = strchr(p, ';'); // windows uses ; as delimiter #else - char *q = strchr(p, ':'); // linux and derivatives use : as delimiter + char *q = strchr(const_cast(p), ':'); // linux and derivatives use : as delimiter #endif vector[next++] = strsave(p, q); p = q ? q + 1 : NULL; @@ -89,8 +89,8 @@ unsigned maxpathlen(char *path[], const char *base) { } return blen+n+1; } -bool file_exists(char *name){ - printf("Trying file: [%s]\n", name); +bool file_exists(const char *name){ + //printf("Trying file: [%s]\n", name); FILE *file; if (file = fopen(name, "r")) { fclose(file); diff --git a/libnd4j/include/helpers/impl/FileUtils.cpp b/libnd4j/include/helpers/impl/FileUtils.cpp new file mode 100644 index 000000000000..b3da6e6d4e92 --- /dev/null +++ b/libnd4j/include/helpers/impl/FileUtils.cpp @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include +#include + +namespace sd { + bool FileUtils::fileExists(const char *filename) { + if (filename == nullptr) + return false; + + return file_exists(filename); + } + + int64_t FileUtils::fileSize(const char *filename) { + struct stat stat_buf; + int rc = stat(filename, &stat_buf); + return rc == 0 ? stat_buf.st_size : -1; + } +} \ No newline at end of file diff --git a/libnd4j/include/memory/ColdZoneManager.h b/libnd4j/include/memory/ColdZoneManager.h index 1d234e14f770..ea840f85e8ee 100644 --- a/libnd4j/include/memory/ColdZoneManager.h +++ b/libnd4j/include/memory/ColdZoneManager.h @@ -27,6 +27,14 @@ namespace sd { class ColdZoneManager : public ZoneManager { public: + /** + * This constructor is used to initialize ZoneManager with existing FlatBuffers file + * @param filename - full path to existing file (i.e. FlatBuffers file) + */ + explicit ColdZoneManager(const char* filename); + ColdZoneManager() = default; + ~ColdZoneManager() = default; + }; } diff --git a/libnd4j/include/memory/impl/ColdZoneManager.cpp b/libnd4j/include/memory/impl/ColdZoneManager.cpp index 102a82691286..0f0a64f077cc 100644 --- a/libnd4j/include/memory/impl/ColdZoneManager.cpp +++ b/libnd4j/include/memory/impl/ColdZoneManager.cpp @@ -19,3 +19,9 @@ // #include + +namespace sd { + ColdZoneManager::ColdZoneManager(const char* filename) { + + } +} diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index e1cf4ec52663..7cbbab24a1f3 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -48,6 +48,18 @@ TEST_F(OneOffTests, test_avg_pool_3d_1) { delete graph; } +TEST_F(OneOffTests, test_avg_pool_3d_2) { + auto graph = Graph::fromFlatBuffers("./resources/avg_pooling3d.fb"); + + ASSERT_TRUE(graph != nullptr); + + // graph->printOut(); + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + delete graph; +} + TEST_F(OneOffTests, test_non2d_0A_1) { auto graph = GraphExecutioner::importFromFlatBuffers("./resources/non2d_0A.fb"); From e894d8171f3c3f0fa3e4e5b0196bd0823e3bb10d Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 25 Mar 2020 19:08:42 +0300 Subject: [PATCH 003/233] few methods removed from GraphExecutioner Signed-off-by: raver119 --- libnd4j/include/array/impl/NDArrayFactory.cpp | 4 +- libnd4j/include/graph/Graph.h | 3 +- libnd4j/include/graph/GraphExecutioner.h | 7 --- libnd4j/include/graph/impl/Graph.cpp | 20 ++++++- .../include/graph/impl/GraphExecutioner.cpp | 55 ------------------- libnd4j/include/legacy/cpu/NativeOps.cpp | 2 +- .../tests_cpu/layers_tests/OneOffTests.cpp | 34 ++++++------ .../layers_tests/PlaygroundTests.cpp | 13 +++-- 8 files changed, 47 insertions(+), 91 deletions(-) diff --git a/libnd4j/include/array/impl/NDArrayFactory.cpp b/libnd4j/include/array/impl/NDArrayFactory.cpp index 870fdc19880e..0f54f26a9616 100644 --- a/libnd4j/include/array/impl/NDArrayFactory.cpp +++ b/libnd4j/include/array/impl/NDArrayFactory.cpp @@ -33,6 +33,7 @@ #include #include +#include namespace sd { @@ -694,8 +695,7 @@ template ND4J_EXPORT NDArray NDArrayFactory::create(int16_t* buffer, const char NDArray NDArrayFactory::fromNpyFile(const char *fileName) { - auto size = sd::graph::getFileSize(fileName); - if (size < 0) + if (!FileUtils::fileExists(fileName)) throw std::runtime_error("File doesn't exit"); auto pNPY = reinterpret_cast(::numpyFromFile(std::string(fileName))); diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index ab1310b42f8e..746f39d7a372 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -85,7 +85,8 @@ namespace sd { ~Graph(); - static Graph* fromFlatBuffers(const char* fileName); + static Graph* fromFlatBuffers(const char *fileName); + static Graph* fromFlatPointer(void *ptr); // this method applies toposort to nodes void toposortNodes(); diff --git a/libnd4j/include/graph/GraphExecutioner.h b/libnd4j/include/graph/GraphExecutioner.h index 148b27951a01..4e81638f6ec7 100644 --- a/libnd4j/include/graph/GraphExecutioner.h +++ b/libnd4j/include/graph/GraphExecutioner.h @@ -67,15 +67,8 @@ namespace sd { static flatbuffers::Offset execute(Graph *graph, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request); static Graph *importFromTensorFlow(const char *fileName); - - - static Graph *importFromFlatBuffers(const char *filename); - - static Graph *importFromFlatPointer(Nd4jPointer ptr); }; - long getFileSize(const char * filename); - uint8_t* readFlatBuffers(const char * filename); } } diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 0d34ce89746e..b3732e0b6e2e 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -1466,7 +1466,7 @@ namespace sd { Nd4jLong *ref; void *ptrGraph; - // check if mmap is supported + // TODO: check if mmap is supported if (true) { // mmap this file ref = ::mmapFile(nullptr, fileName, fsize); @@ -1474,10 +1474,26 @@ namespace sd { } else { // if mmap is not supported - load it directly + ptrGraph = new uint8_t[fsize]; + auto data = reinterpret_cast(ptrGraph); + + FILE *in = fopen(fileName, "rb"); + int cnt = 0; + int b = 0; + while (cnt < fsize) { + b = fread(data + cnt, 1, fsize < 16384 ? fsize : 16384, in); + + cnt += b; + } + fclose(in); } + return fromFlatPointer(ptrGraph); + } + + Graph* Graph::fromFlatPointer(void *ptr) { // get FlatGraph out of it - auto fg = GetFlatGraph(reinterpret_cast(ptrGraph)); + auto fg = GetFlatGraph(reinterpret_cast(ptr)); // return Graph from this FlatGraph return new Graph(fg); diff --git a/libnd4j/include/graph/impl/GraphExecutioner.cpp b/libnd4j/include/graph/impl/GraphExecutioner.cpp index c673d2b31a55..a8ab282c7a4c 100644 --- a/libnd4j/include/graph/impl/GraphExecutioner.cpp +++ b/libnd4j/include/graph/impl/GraphExecutioner.cpp @@ -809,43 +809,6 @@ Graph* GraphExecutioner::importFromTensorFlow(const char *fileName) { return nullptr; } -/** -* This function returns file size for the given file name, or -1 if something went wrong -*/ -long getFileSize(const char * filename) { - struct stat stat_buf; - int rc = stat(filename, &stat_buf); - return rc == 0 ? stat_buf.st_size : -1; -} - -/** -* Helper function, that loads given filename into uint8_t array -* -*/ -uint8_t* readFlatBuffers(const char * filename) { - long fileLen = getFileSize(filename); - if (fileLen < 0) { - nd4j_printf("File [%s] wasn't found. Please check path and permissions\n", filename); - throw std::runtime_error("File not found"); - } - - nd4j_debug("File length: %i\n", fileLen); - - uint8_t * data = new uint8_t[fileLen]; - - FILE *in = fopen(filename, "rb"); - int cnt = 0; - int b = 0; - while (cnt < fileLen) { - b = fread(data + cnt, 1, 1, in); - - cnt += b; - } - fclose(in); - - return data; -} - flatbuffers::Offset GraphExecutioner::execute(Graph *graph, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request) { ExecutionResult result; auto varSpace = graph->getVariableSpace(); @@ -884,23 +847,5 @@ flatbuffers::Offset GraphExecutioner::execute(Graph *graph, flatbuff } - /** - * This method reads given FlatBuffers file, and returns Graph instance - * - * PLEASE NOTE: This method is mostly suited for tests and debugging/profiling - */ - Graph* GraphExecutioner::importFromFlatBuffers(const char *filename) { - auto data = readFlatBuffers(filename); - auto restoredGraph = importFromFlatPointer(reinterpret_cast(data)); - delete[] data; - return restoredGraph; - } - - Graph *GraphExecutioner::importFromFlatPointer(Nd4jPointer ptr) { - auto fg = GetFlatGraph(reinterpret_cast(ptr)); - auto restoredGraph = new Graph(fg); - - return restoredGraph; - } } } \ No newline at end of file diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp index cf04acbe7cab..39d0e62133a0 100644 --- a/libnd4j/include/legacy/cpu/NativeOps.cpp +++ b/libnd4j/include/legacy/cpu/NativeOps.cpp @@ -2172,7 +2172,7 @@ int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBu int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer) { try { - auto graph = sd::graph::GraphExecutioner::importFromFlatPointer(flatBufferPointer); + auto graph = sd::graph::Graph::fromFlatPointer(flatBufferPointer); sd::graph::GraphHolder::getInstance()->registerGraph(graphId, graph); diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index 7cbbab24a1f3..acd7125f751e 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -37,7 +37,7 @@ class OneOffTests : public testing::Test { }; TEST_F(OneOffTests, test_avg_pool_3d_1) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/avg_pooling3d.fb"); + auto graph = Graph::fromFlatBuffers("./resources/avg_pooling3d.fb"); ASSERT_TRUE(graph != nullptr); @@ -61,7 +61,7 @@ TEST_F(OneOffTests, test_avg_pool_3d_2) { } TEST_F(OneOffTests, test_non2d_0A_1) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/non2d_0A.fb"); + auto graph = Graph::fromFlatBuffers("./resources/non2d_0A.fb"); ASSERT_TRUE(graph != nullptr); @@ -77,7 +77,7 @@ TEST_F(OneOffTests, test_assert_scalar_float32_1) { sd::ops::Assert op; sd::ops::identity op1; sd::ops::noop op2; - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/scalar_float32.fb"); + auto graph = Graph::fromFlatBuffers("./resources/scalar_float32.fb"); ASSERT_TRUE(graph != nullptr); @@ -92,7 +92,7 @@ TEST_F(OneOffTests, test_assert_scalar_float32_2) { sd::ops::Assert op; sd::ops::identity op1; sd::ops::noop op2; - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/assertsomething.fb"); + auto graph = Graph::fromFlatBuffers("./resources/assertsomething.fb"); ASSERT_TRUE(graph != nullptr); @@ -106,7 +106,7 @@ TEST_F(OneOffTests, test_assert_scalar_float32_2) { TEST_F(OneOffTests, test_pad_1D_1) { auto e = NDArrayFactory::create('c', {7}, {10.f,0.778786f, 0.801198f, 0.724375f, 0.230894f, 0.727141f,10.f}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/pad_1D.fb"); + auto graph = Graph::fromFlatBuffers("./resources/pad_1D.fb"); ASSERT_TRUE(graph != nullptr); @@ -134,7 +134,7 @@ TEST_F(OneOffTests, test_scatter_nd_update_1) { 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/scatter_nd_update.fb"); + auto graph = Graph::fromFlatBuffers("./resources/scatter_nd_update.fb"); ASSERT_TRUE(graph != nullptr); graph->printOut(); @@ -158,7 +158,7 @@ TEST_F(OneOffTests, test_scatter_nd_update_1) { TEST_F(OneOffTests, test_conv2d_nhwc_failed_1) { auto e = NDArrayFactory::create('c', {1, 5, 5, 6}, {0.55744928f, 0.76827729f, 1.09401524f, 0.00000000f, 0.00000000f, 0.00000000f, 0.56373537f, 0.90029907f, 0.78997850f, 0.00000000f, 0.00000000f, 0.00000000f, 0.14252824f, 0.95961076f, 0.87750554f, 0.00000000f, 0.00000000f, 0.00000000f, 0.44874173f, 0.99537718f, 1.17154264f, 0.00000000f, 0.00000000f, 0.00000000f, 0.60377145f, 0.79939061f, 0.56031001f, 0.00000000f, 0.00000000f, 0.00000000f, 0.52975273f, 0.90678585f, 0.73763013f, 0.00000000f, 0.00000000f, 0.00000000f, 0.22146404f, 0.82499605f, 0.47222072f, 0.00000000f, 0.00000000f, 0.00000000f, 0.42772964f, 0.39793295f, 0.71436501f, 0.00000000f, 0.00000000f, 0.00000000f, 0.48836520f, 1.01658893f, 0.74419701f, 0.00000000f, 0.00000000f, 0.00000000f, 0.78984612f, 0.94083673f, 0.83841157f, 0.00000000f, 0.00000000f, 0.00000000f, 0.40448499f, 0.67732805f, 0.75499672f, 0.00000000f, 0.00000000f, 0.00000000f, 0.43675962f, 0.79476535f, 0.72976631f, 0.00000000f, 0.00000000f, 0.00000000f, 0.58808053f, 0.65222591f, 0.72552216f, 0.00000000f, 0.00000000f, 0.00000000f, 0.37445742f, 1.22581339f, 1.05341125f, 0.00000000f, 0.00000000f, 0.00000000f, 0.30095795f, 0.59941679f, 0.63323414f, 0.00000000f, 0.00000000f, 0.00000000f, 0.24199286f, 1.02546394f, 0.69537812f, 0.00000000f, 0.00000000f, 0.00000000f, 0.23628944f, 0.90791851f, 1.01209974f, 0.00000000f, 0.00000000f, 0.00000000f, 0.62740159f, 0.56518674f, 0.76692569f, 0.00000000f, 0.00000000f, 0.00000000f, 0.13327584f, 0.32628393f, 0.10280430f, 0.00000000f, 0.00000000f, 0.00000000f, 0.42691272f, 0.25625113f, 0.30524066f, 0.00000000f, 0.00000000f, 0.00000000f, 0.17797673f, 0.84179950f, 0.80061519f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00199084f, 0.51838887f, 0.43932241f, 0.00000000f, 0.00000000f, 0.00000000f, 0.16684581f, 0.50822425f, 0.48668745f, 0.00000000f, 0.00000000f, 0.00000000f, 0.16749343f, 0.93093169f, 0.86871749f, 0.00000000f, 0.00000000f, 0.00000000f, 0.17486368f, 0.44460732f, 0.44499981f, 0.00000000f, 0.00000000f, 0.00000000f}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/channels_last_b1_k2_s1_d1_SAME_crelu.fb"); + auto graph = Graph::fromFlatBuffers("./resources/channels_last_b1_k2_s1_d1_SAME_crelu.fb"); ASSERT_TRUE(graph != nullptr); // graph->printOut(); @@ -179,7 +179,7 @@ TEST_F(OneOffTests, test_conv2d_nhwc_failed_1) { TEST_F(OneOffTests, test_tensor_array_1) { auto e = NDArrayFactory::create('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_close_sz1_float32_nodynamic_noname_noshape.fb"); + auto graph = Graph::fromFlatBuffers("./resources/tensor_array_close_sz1_float32_nodynamic_noname_noshape.fb"); ASSERT_TRUE(graph != nullptr); // graph->printOut(); @@ -199,7 +199,7 @@ TEST_F(OneOffTests, test_tensor_array_1) { TEST_F(OneOffTests, test_tensor_array_2) { auto e = NDArrayFactory::create('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_split_sz1_float32_nodynamic_noname_noshape.fb"); + auto graph = Graph::fromFlatBuffers("./resources/tensor_array_split_sz1_float32_nodynamic_noname_noshape.fb"); ASSERT_TRUE(graph != nullptr); // graph->printOut(); @@ -219,7 +219,7 @@ TEST_F(OneOffTests, test_tensor_array_2) { TEST_F(OneOffTests, test_tensor_array_3) { auto e = NDArrayFactory::create('c', {3, 2, 3}, {7, 2, 9, 4, 3, 3, 8, 7, 0, 0, 6, 8, 7, 9, 0, 1, 1, 4}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_stack_sz3-1_int32_dynamic_name_shape.fb"); + auto graph = Graph::fromFlatBuffers("./resources/tensor_array_stack_sz3-1_int32_dynamic_name_shape.fb"); ASSERT_TRUE(graph != nullptr); // graph->printOut(); @@ -240,7 +240,7 @@ TEST_F(OneOffTests, test_tensor_array_3) { TEST_F(OneOffTests, test_tensor_array_4) { auto e = NDArrayFactory::create('c', {2, 3}, {4, 3, 1, 1, 1, 0}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_unstack_sz1_int64_nodynamic_noname_shape2-3.fb"); + auto graph = Graph::fromFlatBuffers("./resources/tensor_array_unstack_sz1_int64_nodynamic_noname_shape2-3.fb"); ASSERT_TRUE(graph != nullptr); // graph->printOut(); @@ -261,7 +261,7 @@ TEST_F(OneOffTests, test_tensor_array_4) { TEST_F(OneOffTests, test_assert_4) { auto e = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/assert_type_rank2_int64.fb"); + auto graph = Graph::fromFlatBuffers("./resources/assert_type_rank2_int64.fb"); ASSERT_TRUE(graph != nullptr); // graph->printOut(); @@ -282,7 +282,7 @@ TEST_F(OneOffTests, test_assert_4) { // TEST_F(OneOffTests, test_cond_true_1) { // auto e = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); -// auto graph = GraphExecutioner::importFromFlatBuffers("./resources/cond_true.fb"); +// auto graph = Graph::fromFlatBuffers("./resources/cond_true.fb"); // ASSERT_TRUE(graph != nullptr); // graph->printOut(); @@ -306,7 +306,7 @@ TEST_F(OneOffTests, test_assert_4) { TEST_F(OneOffTests, test_cond_false_1) { auto e = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/cond_false.fb"); + auto graph = Graph::fromFlatBuffers("./resources/cond_false.fb"); ASSERT_TRUE(graph != nullptr); graph->printOut(); @@ -332,7 +332,7 @@ TEST_F(OneOffTests, test_identity_n_2) { sd::ops::identity_n op; - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/identity_n_2.fb"); + auto graph = Graph::fromFlatBuffers("./resources/identity_n_2.fb"); ASSERT_TRUE(graph != nullptr); // graph->printOut(); @@ -354,7 +354,7 @@ TEST_F(OneOffTests, test_identity_n_2) { TEST_F(OneOffTests, test_non2d_1) { auto e = NDArrayFactory::create('c', {1, 1}, {5.42746449f}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/non2d_1.fb"); + auto graph = Graph::fromFlatBuffers("./resources/non2d_1.fb"); ASSERT_TRUE(graph != nullptr); // graph->printOut(); @@ -376,7 +376,7 @@ TEST_F(OneOffTests, test_non2d_1) { TEST_F(OneOffTests, test_reduce_all_1) { auto e = NDArrayFactory::create('c', {1, 4}, {true, false, false, false}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_all_rank2_d0_keep.fb"); + auto graph = Graph::fromFlatBuffers("./resources/reduce_all_rank2_d0_keep.fb"); ASSERT_TRUE(graph != nullptr); // graph->printOut(); diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 94156d4bc9f0..67f1c212dc44 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -46,6 +46,7 @@ #include #include +#include using namespace sd; using namespace sd::graph; @@ -92,10 +93,10 @@ TEST_F(PlaygroundTests, test_biasAdd_1) { TEST_F(PlaygroundTests, test_bert_full_1) { // this test will run ONLY if this model exists - if (sd::graph::getFileSize("/home/raver119/Downloads/BertFull/model.fb") < 0) + if (!FileUtils::fileExists("/home/raver119/Downloads/BertFull/model.fb")) return; - auto graph = GraphExecutioner::importFromFlatBuffers("/home/raver119/Downloads/BertFull/model.fb"); + auto graph = Graph::fromFlatBuffers("/home/raver119/Downloads/BertFull/model.fb"); auto t = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/BertFull/in0_IteratorGetNext.npy"); auto u = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/BertFull/in1_IteratorGetNext_1.npy"); @@ -152,10 +153,10 @@ TEST_F(PlaygroundTests, test_bert_full_1) { TEST_F(PlaygroundTests, test_bert_1) { // this test will run ONLY if this model exists - if (sd::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb") < 0) + if (!FileUtils::fileExists("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb")) return; - auto graph = GraphExecutioner::importFromFlatBuffers("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb"); + auto graph = Graph::fromFlatBuffers("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb"); auto t = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_input_IteratorGetNext.numpy"); auto u = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_input_IteratorGetNext_1.numpy"); @@ -210,10 +211,10 @@ TEST_F(PlaygroundTests, test_bert_1) { TEST_F(PlaygroundTests, test_bert_2) { // this test will run ONLY if this model exists - if (sd::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb") < 0) + if (!FileUtils::fileExists("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb")) return; - auto graph = GraphExecutioner::importFromFlatBuffers("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb"); + auto graph = Graph::fromFlatBuffers("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb"); //graph->printOut(); From 49dc7098c13bf98b3ee9a427d5c6b689fc196fd4 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 25 Mar 2020 19:20:59 +0300 Subject: [PATCH 004/233] disable certain tests in debug mode Signed-off-by: raver119 --- .../layers_tests/PlaygroundTests.cpp | 13 ++++++ .../java/org/nd4j/nativeblas/Nd4jCpu.java | 45 +++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 67f1c212dc44..d094e09caec4 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -92,6 +92,10 @@ TEST_F(PlaygroundTests, test_biasAdd_1) { TEST_F(PlaygroundTests, test_bert_full_1) { +#ifdef _RELEASE + + + // this test will run ONLY if this model exists if (!FileUtils::fileExists("/home/raver119/Downloads/BertFull/model.fb")) return; @@ -148,10 +152,13 @@ TEST_F(PlaygroundTests, test_bert_full_1) { nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); */ delete graph; + +#endif } TEST_F(PlaygroundTests, test_bert_1) { +#ifdef _RELEASE // this test will run ONLY if this model exists if (!FileUtils::fileExists("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb")) return; @@ -207,9 +214,13 @@ TEST_F(PlaygroundTests, test_bert_1) { nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); */ delete graph; + +#endif } TEST_F(PlaygroundTests, test_bert_2) { +#ifdef _RELEASE + // this test will run ONLY if this model exists if (!FileUtils::fileExists("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb")) return; @@ -257,6 +268,8 @@ TEST_F(PlaygroundTests, test_bert_2) { nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); */ delete graph; + +#endif } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 47791f865744..3643e1e04634 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -16672,6 +16672,21 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } + @Namespace("sd::ops") public static class mergemax_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergemax_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergemax_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergemax_bp position(long position) { + return (mergemax_bp)super.position(position); + } + + public mergemax_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif /* * Complete tensor with max indices merged from all input tensors list @@ -16714,6 +16729,21 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } + @Namespace("sd::ops") public static class mergeadd_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergeadd_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergeadd_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergeadd_bp position(long position) { + return (mergeadd_bp)super.position(position); + } + + public mergeadd_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif // #if NOT_EXCLUDED(OP_mergeavg) @@ -16732,6 +16762,21 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } + @Namespace("sd::ops") public static class mergeavg_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergeavg_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergeavg_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergeavg_bp position(long position) { + return (mergeavg_bp)super.position(position); + } + + public mergeavg_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif // #if NOT_EXCLUDED(OP_scatter_update) From 93808dc6280b2dea1babb158925f0535fa4627c0 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 26 Mar 2020 14:48:32 +0300 Subject: [PATCH 005/233] OpSequence skeleton Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 4 + libnd4j/include/graph/GraphExecutioner.h | 2 - libnd4j/include/graph/execution/OpSequence.h | 66 ++++++ .../graph/execution/impl/OpSequence.cpp | 68 ++++++ libnd4j/include/graph/impl/Graph.cpp | 199 ++++++++++++++++++ .../include/graph/impl/GraphExecutioner.cpp | 199 ------------------ 6 files changed, 337 insertions(+), 201 deletions(-) create mode 100644 libnd4j/include/graph/execution/OpSequence.h create mode 100644 libnd4j/include/graph/execution/impl/OpSequence.cpp diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 746f39d7a372..939cf95f9a2c 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -85,6 +85,10 @@ namespace sd { ~Graph(); + /** + * Methods that allow Graph imports + */ + static Graph *importFromTensorFlow(const char *fileName); static Graph* fromFlatBuffers(const char *fileName); static Graph* fromFlatPointer(void *ptr); diff --git a/libnd4j/include/graph/GraphExecutioner.h b/libnd4j/include/graph/GraphExecutioner.h index 4e81638f6ec7..2921d30089f9 100644 --- a/libnd4j/include/graph/GraphExecutioner.h +++ b/libnd4j/include/graph/GraphExecutioner.h @@ -65,8 +65,6 @@ namespace sd { static sd::graph::ResultWrapper* executeFlatBuffer(Nd4jPointer pointer); static flatbuffers::Offset execute(Graph *graph, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request); - - static Graph *importFromTensorFlow(const char *fileName); }; diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h new file mode 100644 index 000000000000..c1ecee4b9f9c --- /dev/null +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -0,0 +1,66 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_OPSEQUENCE_H +#define SD_OPSEQUENCE_H + +#include +#include + + +namespace sd { + namespace graph { + class OpSequence { + protected: + // main thing here. sorted list of operations and their contexts + std::vector> _ops; + public: + explicit OpSequence(const std::vector> &ops); + OpSequence() = default; + ~OpSequence() = default; + + OpSequence(const OpSequence& other) noexcept; + + OpSequence& operator=(const OpSequence& other) noexcept; + + // move constructor + OpSequence(OpSequence&& other) noexcept; + + // move assignment operator + OpSequence& operator=(OpSequence&& other) noexcept; + + /** + * This method returns number of individual operations within this sequence + * @return + */ + uint64_t length(); + + /** + * This method allows to add DeclarableOp to the end of execution queue + * @param op - Op to be executed + * @param ctx - Context for this operation with inputs/outputs/args defined + */ + void append(sd::ops::DeclarableOp *op, sd::graph::Context *ctx); + }; + } +} + + +#endif //SD_OPSEQUENCE_H diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp new file mode 100644 index 000000000000..0e8c7f2054c6 --- /dev/null +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -0,0 +1,68 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { + namespace graph { + OpSequence::OpSequence(const std::vector> &ops) { + for (const auto v : ops) + _ops.emplace_back(v); + } + + OpSequence::OpSequence(const OpSequence& other) noexcept{ + for (const auto v : other._ops) + _ops.emplace_back(v); + } + + //////////////////////////////////////////////////////////////////////// + // move constructor + OpSequence::OpSequence(OpSequence&& other) noexcept { + _ops = std::move(other._ops); + } + + OpSequence& OpSequence::operator=(OpSequence&& other) noexcept { + if (this == &other) + return *this; + + _ops = std::move(other._ops); + + return *this; + } + + OpSequence& OpSequence::operator=(const OpSequence& other) noexcept { + if (this == &other) + return *this; + + for (const auto v : other._ops) + _ops.emplace_back(v); + + return *this; + } + + uint64_t OpSequence::length() { + return _ops.size(); + } + + void OpSequence::append(sd::ops::DeclarableOp *op, sd::graph::Context *ctx) { + _ops.emplace_back(std::pair{op, ctx}); + } + } +} diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index b3732e0b6e2e..85bedf14dbcd 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -1498,6 +1498,205 @@ namespace sd { // return Graph from this FlatGraph return new Graph(fg); } + + Graph* Graph::importFromTensorFlow(const char *fileName) { + throw std::runtime_error("Graph::importFromTensorFlow() not implemented yet"); + /* + if (fileName == nullptr) + return nullptr; + + int fd = open(fileName, O_RDONLY); + + if (fd < 0) { + nd4j_printf("File not found: [%s]\n", fileName); + return nullptr; + } + + nd4j_verbose("Trying to load TF GraphDef from file [%s]\n", fileName); + + tensorflow::GraphDef graphDef; + bool res = graphDef.ParseFromFileDescriptor(fd); + + // trying to read graph as text + if(!res) { + close(fd); + fd = open(fileName, O_RDONLY); + + google::protobuf::io::FileInputStream fileInput(fd); + fileInput.SetCloseOnDelete(true); + + if (!google::protobuf::TextFormat::Parse(&fileInput, &graphDef)) { + nd4j_printf("Failed to read file\n",""); + } else { + res = true; + } + } + + close(fd); + + if (!res) + return nullptr; + + auto graph = new Graph(); + auto variableSpace = graph->getVariableSpace(); + + std::map variablesMap; + + int variablesCounter = 0; + int nodesCounter = 0; + nd4j_verbose("Number of nodes in graphDef: %i\n", graphDef.node_size()); + for (int n = 0; n < graphDef.node_size(); n++) { + auto node = graphDef.node(n); + + // if that's external variable - we put it to variable space + if (strcmp(TF_VAR, node.op().c_str()) == 0 || strcmp(TF_CONST, node.op().c_str()) == 0 || strcmp(TF_INPUT, node.op().c_str()) == 0) { + nd4j_printf("Variable found: %s\n", node.name().c_str()); + auto variable = new Variable(); + variable->setName(new std::string(node.name().c_str())); + variable->setId(--variablesCounter); + variableSpace->putVariable(variable->id(), variable); + + std::pair pair(node.name(), variable->id()); + variablesMap.insert(pair); + + // TODO: we might want to have something like that. + // it basically just gives input validation option, since settles expectations for input + if (strcmp(TF_INPUT, node.op().c_str()) == 0) + continue; + + // checking shape, not applicable to input, since it can vary + if (node.attr().count("shape")) { + auto attr = node.attr().at("shape"); + int dims = attr.shape().dim_size(); + + if (dims > 0) { + std::vector __shape; + + // we don't have rank1 arrays. vector is 2d. + if (dims == 1) + __shape.push_back(1); + + // roll through dimensions + for (auto s: attr.shape().dim()) { + __shape.push_back((int) s.size()) ; + } + + variable->setNDArray(new NDArray('c', __shape)); + + nd4j_printf("Shape found: %i dims;\n", dims); + variable->getNDArray()->printShapeInfo(); + } + } + + // checking tensor attached + if (node.attr().count("value")) { + auto attr = node.attr().at("value"); + + // int + if (attr.tensor().dtype() == ::tensorflow::DataType::DT_INT32) { + nd4j_verbose("Int size: %i\n", attr.tensor().int_val_size()); + + Nd4jLong __length = 0; + + nd4j_verbose("Tensor has shape: %i\n", attr.tensor().has_tensor_shape()); + if (attr.tensor().has_tensor_shape()) { + auto shape = attr.tensor().tensor_shape(); + int dims = shape.dim_size(); + + if (dims > 0) { + std::vector __shape; + // we don't have rank1 arrays. vector is 2d. + if (dims == 1) + __shape.push_back(1); + + // roll through dimensions + for (auto s: shape.dim()) { + __shape.push_back((int) s.size()); + } + + variable->setNDArray(new NDArray('c', __shape)); + __length = variable->getNDArray()->lengthOf(); + + nd4j_printf("Tensor shape found: %i dims;\n", dims); + variable->getNDArray()->printShapeInfo(); + } + } + + // it can be valueOf array + if (attr.tensor().int_val_size() == 1 && __length > 0) { + variable->getNDArray()->assign((T) attr.tensor().int_val(0)); + } + } + } + } else { + nd4j_verbose("Node id: [%i]; name: [%s]; opName: [%s]\n", n + 1, node.name().c_str(), + node.op().c_str()); + + sd::ops::DeclarableOp *op = sd::ops::OpRegistrator::getInstance()->getOperationFloat(node.op().c_str()); + + if (op == nullptr) { + nd4j_verbose("Op wasn't found: %s\n", node.op().c_str()); + return nullptr; + } + + auto jNode = new Node(); + jNode->setName(node.name()); + jNode->setId(++nodesCounter); + jNode->setCustomOp(op); + jNode->setBlock(new Block(jNode->id(), variableSpace)); + + std::pair pair(node.name(), jNode->id()); + variablesMap.insert(pair); + + // multi-output nodes require special treatment + for (int e = 0; e < op->getOpDescriptor()->getNumberOfOutputs(); e++) { + std::string deepName(node.name()); + deepName += ":" + std::to_string(e); + auto deepVar = new Variable(); + deepVar->setName(&deepName); + + if (e > 0) + deepVar->setId(--variablesCounter); + else + deepVar->setId(jNode->id()); + + std::pair pair(deepName, deepVar->id()); + variablesMap.insert(pair); + + variableSpace->putVariable(deepVar->id(), deepVar); + + std::pair nodepair(jNode->id(), e); + variableSpace->putVariable(nodepair, deepVar); + } + + + printf(" Inputs: ["); + for (int i = 0; i < node.input_size(); i++) { + nd4j_printf("Trying input: %s\n", node.input(i).c_str()); + + // if this fails - we're probably on partial input :) + if (!variablesMap.count(node.input(i))) + return nullptr; + + printf("%s (%i)", node.input(i).c_str(), variablesMap.at(node.input(i))); + + + jNode->pickInput(variablesMap.at(node.input(i))); + jNode->getBlock()->pickInput(variablesMap.at(node.input(i))); + + + if (i < node.input_size() + 1) + printf(", "); + } + printf("]\n"); + + graph->addNode(jNode); + } + } + + return graph; + */ + } } } diff --git a/libnd4j/include/graph/impl/GraphExecutioner.cpp b/libnd4j/include/graph/impl/GraphExecutioner.cpp index a8ab282c7a4c..b68c3ccd3203 100644 --- a/libnd4j/include/graph/impl/GraphExecutioner.cpp +++ b/libnd4j/include/graph/impl/GraphExecutioner.cpp @@ -610,205 +610,6 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace) return new ResultWrapper(builder.GetSize(), reinterpret_cast(res)); } -Graph* GraphExecutioner::importFromTensorFlow(const char *fileName) { - /* - if (fileName == nullptr) - return nullptr; - - int fd = open(fileName, O_RDONLY); - - if (fd < 0) { - nd4j_printf("File not found: [%s]\n", fileName); - return nullptr; - } - - nd4j_verbose("Trying to load TF GraphDef from file [%s]\n", fileName); - - tensorflow::GraphDef graphDef; - bool res = graphDef.ParseFromFileDescriptor(fd); - - // trying to read graph as text - if(!res) { - close(fd); - fd = open(fileName, O_RDONLY); - - google::protobuf::io::FileInputStream fileInput(fd); - fileInput.SetCloseOnDelete(true); - - if (!google::protobuf::TextFormat::Parse(&fileInput, &graphDef)) { - nd4j_printf("Failed to read file\n",""); - } else { - res = true; - } - } - - close(fd); - - if (!res) - return nullptr; - - auto graph = new Graph(); - auto variableSpace = graph->getVariableSpace(); - - std::map variablesMap; - - int variablesCounter = 0; - int nodesCounter = 0; - nd4j_verbose("Number of nodes in graphDef: %i\n", graphDef.node_size()); - for (int n = 0; n < graphDef.node_size(); n++) { - auto node = graphDef.node(n); - - // if that's external variable - we put it to variable space - if (strcmp(TF_VAR, node.op().c_str()) == 0 || strcmp(TF_CONST, node.op().c_str()) == 0 || strcmp(TF_INPUT, node.op().c_str()) == 0) { - nd4j_printf("Variable found: %s\n", node.name().c_str()); - auto variable = new Variable(); - variable->setName(new std::string(node.name().c_str())); - variable->setId(--variablesCounter); - variableSpace->putVariable(variable->id(), variable); - - std::pair pair(node.name(), variable->id()); - variablesMap.insert(pair); - - // TODO: we might want to have something like that. - // it basically just gives input validation option, since settles expectations for input - if (strcmp(TF_INPUT, node.op().c_str()) == 0) - continue; - - // checking shape, not applicable to input, since it can vary - if (node.attr().count("shape")) { - auto attr = node.attr().at("shape"); - int dims = attr.shape().dim_size(); - - if (dims > 0) { - std::vector __shape; - - // we don't have rank1 arrays. vector is 2d. - if (dims == 1) - __shape.push_back(1); - - // roll through dimensions - for (auto s: attr.shape().dim()) { - __shape.push_back((int) s.size()) ; - } - - variable->setNDArray(new NDArray('c', __shape)); - - nd4j_printf("Shape found: %i dims;\n", dims); - variable->getNDArray()->printShapeInfo(); - } - } - - // checking tensor attached - if (node.attr().count("value")) { - auto attr = node.attr().at("value"); - - // int - if (attr.tensor().dtype() == ::tensorflow::DataType::DT_INT32) { - nd4j_verbose("Int size: %i\n", attr.tensor().int_val_size()); - - Nd4jLong __length = 0; - - nd4j_verbose("Tensor has shape: %i\n", attr.tensor().has_tensor_shape()); - if (attr.tensor().has_tensor_shape()) { - auto shape = attr.tensor().tensor_shape(); - int dims = shape.dim_size(); - - if (dims > 0) { - std::vector __shape; - // we don't have rank1 arrays. vector is 2d. - if (dims == 1) - __shape.push_back(1); - - // roll through dimensions - for (auto s: shape.dim()) { - __shape.push_back((int) s.size()); - } - - variable->setNDArray(new NDArray('c', __shape)); - __length = variable->getNDArray()->lengthOf(); - - nd4j_printf("Tensor shape found: %i dims;\n", dims); - variable->getNDArray()->printShapeInfo(); - } - } - - // it can be valueOf array - if (attr.tensor().int_val_size() == 1 && __length > 0) { - variable->getNDArray()->assign((T) attr.tensor().int_val(0)); - } - } - } - } else { - nd4j_verbose("Node id: [%i]; name: [%s]; opName: [%s]\n", n + 1, node.name().c_str(), - node.op().c_str()); - - sd::ops::DeclarableOp *op = sd::ops::OpRegistrator::getInstance()->getOperationFloat(node.op().c_str()); - - if (op == nullptr) { - nd4j_verbose("Op wasn't found: %s\n", node.op().c_str()); - return nullptr; - } - - auto jNode = new Node(); - jNode->setName(node.name()); - jNode->setId(++nodesCounter); - jNode->setCustomOp(op); - jNode->setBlock(new Block(jNode->id(), variableSpace)); - - std::pair pair(node.name(), jNode->id()); - variablesMap.insert(pair); - - // multi-output nodes require special treatment - for (int e = 0; e < op->getOpDescriptor()->getNumberOfOutputs(); e++) { - std::string deepName(node.name()); - deepName += ":" + std::to_string(e); - auto deepVar = new Variable(); - deepVar->setName(&deepName); - - if (e > 0) - deepVar->setId(--variablesCounter); - else - deepVar->setId(jNode->id()); - - std::pair pair(deepName, deepVar->id()); - variablesMap.insert(pair); - - variableSpace->putVariable(deepVar->id(), deepVar); - - std::pair nodepair(jNode->id(), e); - variableSpace->putVariable(nodepair, deepVar); - } - - - printf(" Inputs: ["); - for (int i = 0; i < node.input_size(); i++) { - nd4j_printf("Trying input: %s\n", node.input(i).c_str()); - - // if this fails - we're probably on partial input :) - if (!variablesMap.count(node.input(i))) - return nullptr; - - printf("%s (%i)", node.input(i).c_str(), variablesMap.at(node.input(i))); - - - jNode->pickInput(variablesMap.at(node.input(i))); - jNode->getBlock()->pickInput(variablesMap.at(node.input(i))); - - - if (i < node.input_size() + 1) - printf(", "); - } - printf("]\n"); - - graph->addNode(jNode); - } - } - - return graph; - */ - return nullptr; -} - flatbuffers::Offset GraphExecutioner::execute(Graph *graph, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request) { ExecutionResult result; auto varSpace = graph->getVariableSpace(); From de46d8745f7ea511405acdfa4da1587710233146 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 26 Mar 2020 15:27:53 +0300 Subject: [PATCH 006/233] OpSequence iterator + test Signed-off-by: raver119 --- libnd4j/include/graph/execution/OpSequence.h | 28 ++++++++- .../graph/execution/impl/OpSequence.cpp | 34 +++++++++- .../layers_tests/OpSequenceTests.cpp | 62 +++++++++++++++++++ .../tests_cpu/layers_tests/OpTrackerTests.cpp | 2 - 4 files changed, 121 insertions(+), 5 deletions(-) create mode 100644 libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index c1ecee4b9f9c..5be27791c5a0 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -27,7 +27,9 @@ namespace sd { namespace graph { - class OpSequence { + class OpSequence : public std::iterator> { + // our internal iterator for OpSequence + class iterator; protected: // main thing here. sorted list of operations and their contexts std::vector> _ops; @@ -50,7 +52,7 @@ namespace sd { * This method returns number of individual operations within this sequence * @return */ - uint64_t length(); + uint64_t length() const; /** * This method allows to add DeclarableOp to the end of execution queue @@ -58,6 +60,28 @@ namespace sd { * @param ctx - Context for this operation with inputs/outputs/args defined */ void append(sd::ops::DeclarableOp *op, sd::graph::Context *ctx); + + /** + * Iterator functionality for OpSequence + * @return + */ + + iterator begin(); + iterator end(); + + // additional private section + private: + class iterator : public std::iterator> { + private: + uint64_t _position = 0; + OpSequence & _container; + public: + explicit iterator(OpSequence & container, uint64_t index = 0); + std::pair operator*() const; + iterator & operator++(); + iterator & operator++(int); + bool operator!=(const iterator &) const; + }; }; } } diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 0e8c7f2054c6..8e0db284dabf 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -57,12 +57,44 @@ namespace sd { return *this; } - uint64_t OpSequence::length() { + uint64_t OpSequence::length() const { return _ops.size(); } void OpSequence::append(sd::ops::DeclarableOp *op, sd::graph::Context *ctx) { _ops.emplace_back(std::pair{op, ctx}); } + + + OpSequence::iterator + OpSequence::begin() { + return OpSequence::iterator(*this, 0); + } + + OpSequence::iterator + OpSequence::end() { + return OpSequence::iterator(*this, length()); + } + + OpSequence::iterator::iterator(OpSequence &container, uint64_t index) :_container(container), _position(index) { + // + } + + std::pair OpSequence::iterator::operator*() const { + return _container._ops[_position]; + } + + OpSequence::iterator &OpSequence::iterator::operator++() { + _position++; + return *this; + } + + OpSequence::iterator &OpSequence::iterator::operator++(int inc) { + return ++(*this); + } + + bool OpSequence::iterator::operator!=(const OpSequence::iterator &other) const { + return _position != other._position; + } } } diff --git a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp new file mode 100644 index 000000000000..eae93dd23c21 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp @@ -0,0 +1,62 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::ops; +using namespace sd::graph; + +class OpSequenceTests : public testing::Test { +public: + + OpSequenceTests() { + } +}; + +TEST_F(OpSequenceTests, test_iterator_1) { + OpSequence sequence; + + ASSERT_EQ(0, sequence.length()); + + ops::add op1; + ops::multiply op2; + + Context ctx1(1); + Context ctx2(2); + + sequence.append(&op1, &ctx1); + sequence.append(&op2, &ctx2); + + ASSERT_EQ(2, sequence.length()); + + int cnt = 1; + for (const auto &v:sequence) { + ASSERT_EQ(cnt++, v.second->nodeId()); + } + + ASSERT_EQ(3, cnt); +} diff --git a/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp b/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp index fe581e09e21e..35e47788027a 100644 --- a/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp @@ -34,8 +34,6 @@ class OpTrackerTests : public testing::Test { int poolSize = 10; OpTrackerTests() { - printf("\n"); - fflush(stdout); } }; From caa0fbee7ff935de116c953b790db9506ab7ea69 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 26 Mar 2020 17:02:37 +0300 Subject: [PATCH 007/233] OpimizedGraph abstraction Signed-off-by: raver119 --- libnd4j/include/graph/OptimizedGraph.h | 65 +++++++++++++++++++ libnd4j/include/graph/execution/OpSequence.h | 3 + libnd4j/include/graph/impl/OptimizedGraph.cpp | 37 +++++++++++ 3 files changed, 105 insertions(+) create mode 100644 libnd4j/include/graph/OptimizedGraph.h create mode 100644 libnd4j/include/graph/impl/OptimizedGraph.cpp diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h new file mode 100644 index 000000000000..5b9a472f1ac9 --- /dev/null +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -0,0 +1,65 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// +#ifndef SD_OPTIMIZEDGRAPH_H +#define SD_OPTIMIZEDGRAPH_H + +#include +#include +#include + +namespace sd { + namespace graph { + /** + * This class acts as a topologically sorted & optimized for top + */ + class OptimizedGraph { + protected: + // here we store independent OpSequences + // Graph starts from layer 0, and goes deeper step by step + // on each layer we can have 1+ OpSequences that can be executed independent + std::map> _onion; + public: + OptimizedGraph() = default; + ~OptimizedGraph() = default; + + /** + * This method returns number of layers within OptimizedGraph + * @return + */ + uint64_t layers() const; + + /** + * This method returns OpSequences stored in a given layer + * @param index + * @return + */ + const std::vector& layer(uint64_t index) const; + + /** + * This method allows to append layer to this OptimizedGraph instance + */ + // FIXME: this method should be removed or made private + void append(const std::vector &layer); + }; + } +} + + +#endif //SD_OPTIMIZEDGRAPH_H diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index 5be27791c5a0..73320103d1c4 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -27,6 +27,9 @@ namespace sd { namespace graph { + /** + * This class represents independent and immutable sequence of operations + */ class OpSequence : public std::iterator> { // our internal iterator for OpSequence class iterator; diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp new file mode 100644 index 000000000000..ef091b5d5972 --- /dev/null +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -0,0 +1,37 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { + namespace graph { + uint64_t OptimizedGraph::layers() const { + return _onion.size(); + } + + const std::vector &OptimizedGraph::layer(uint64_t index) const { + return _onion.at(index); + } + + void OptimizedGraph::append(const std::vector &layer) { + _onion[_onion.size()] = layer; + } + } +} \ No newline at end of file From de21d7d9fa8c3c9f55f39ef2141afb7e59724e43 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 26 Mar 2020 17:03:31 +0300 Subject: [PATCH 008/233] javadoc update Signed-off-by: raver119 --- libnd4j/include/graph/OptimizedGraph.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index 5b9a472f1ac9..d601779ae001 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -27,7 +27,7 @@ namespace sd { namespace graph { /** - * This class acts as a topologically sorted & optimized for top + * This class acts as a topologically sorted & optimized Graph representation, ready for execution */ class OptimizedGraph { protected: From e4f41a5f53f50310bfa7c9e8e7c8cede1d58efd7 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 26 Mar 2020 17:28:44 +0300 Subject: [PATCH 009/233] - OpSequence now works with ContextPrototype - OptimizedGraph copy/move operators/constructors Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 4 +++ libnd4j/include/graph/OptimizedGraph.h | 11 ++++++++ libnd4j/include/graph/execution/OpSequence.h | 14 +++++----- .../graph/execution/impl/OpSequence.cpp | 8 +++--- libnd4j/include/graph/impl/Graph.cpp | 6 +++++ libnd4j/include/graph/impl/OptimizedGraph.cpp | 26 +++++++++++++++++++ 6 files changed, 58 insertions(+), 11 deletions(-) diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 939cf95f9a2c..c64ab128f05f 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -36,6 +36,7 @@ #include #include #include +#include namespace sd { namespace graph { @@ -240,6 +241,9 @@ namespace sd { FORCEINLINE bool built(); FORCEINLINE void pullState(Graph *other); + + + OptimizedGraph optimizedGraph() const; }; FORCEINLINE std::vector* Graph::nodes() { diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index d601779ae001..4de4d6a3d836 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -39,6 +39,17 @@ namespace sd { OptimizedGraph() = default; ~OptimizedGraph() = default; + OptimizedGraph(const OptimizedGraph& other) noexcept; + + OptimizedGraph& operator=(const OptimizedGraph& other) noexcept; + + // move constructor + OptimizedGraph(OptimizedGraph&& other) noexcept; + + // move assignment operator + OptimizedGraph& operator=(OptimizedGraph&& other) noexcept; + + /** * This method returns number of layers within OptimizedGraph * @return diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index 73320103d1c4..d1a08a63ca1b 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -30,14 +30,14 @@ namespace sd { /** * This class represents independent and immutable sequence of operations */ - class OpSequence : public std::iterator> { + class OpSequence : public std::iterator> { // our internal iterator for OpSequence class iterator; protected: // main thing here. sorted list of operations and their contexts - std::vector> _ops; + std::vector> _ops; public: - explicit OpSequence(const std::vector> &ops); + explicit OpSequence(const std::vector> &ops); OpSequence() = default; ~OpSequence() = default; @@ -60,9 +60,9 @@ namespace sd { /** * This method allows to add DeclarableOp to the end of execution queue * @param op - Op to be executed - * @param ctx - Context for this operation with inputs/outputs/args defined + * @param ctx - ContextPrototype for this operation with inputs/outputs/args defined */ - void append(sd::ops::DeclarableOp *op, sd::graph::Context *ctx); + void append(sd::ops::DeclarableOp *op, sd::graph::ContextPrototype *ctx); /** * Iterator functionality for OpSequence @@ -74,13 +74,13 @@ namespace sd { // additional private section private: - class iterator : public std::iterator> { + class iterator : public std::iterator> { private: uint64_t _position = 0; OpSequence & _container; public: explicit iterator(OpSequence & container, uint64_t index = 0); - std::pair operator*() const; + std::pair operator*() const; iterator & operator++(); iterator & operator++(int); bool operator!=(const iterator &) const; diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 8e0db284dabf..2354deaec7c5 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -22,7 +22,7 @@ namespace sd { namespace graph { - OpSequence::OpSequence(const std::vector> &ops) { + OpSequence::OpSequence(const std::vector> &ops) { for (const auto v : ops) _ops.emplace_back(v); } @@ -61,8 +61,8 @@ namespace sd { return _ops.size(); } - void OpSequence::append(sd::ops::DeclarableOp *op, sd::graph::Context *ctx) { - _ops.emplace_back(std::pair{op, ctx}); + void OpSequence::append(sd::ops::DeclarableOp *op, sd::graph::ContextPrototype *ctx) { + _ops.emplace_back(std::pair{op, ctx}); } @@ -80,7 +80,7 @@ namespace sd { // } - std::pair OpSequence::iterator::operator*() const { + std::pair OpSequence::iterator::operator*() const { return _container._ops[_position]; } diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 85bedf14dbcd..cb0845423f2e 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -1697,6 +1697,12 @@ namespace sd { return graph; */ } + + + OptimizedGraph Graph::optimizedGraph() const { + // TODO: implement this method + return OptimizedGraph(); + } } } diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index ef091b5d5972..bcc6720c7176 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -22,6 +22,32 @@ namespace sd { namespace graph { + OptimizedGraph::OptimizedGraph(const OptimizedGraph &other) noexcept { + _onion = other._onion; + } + + OptimizedGraph &OptimizedGraph::operator=(const OptimizedGraph &other) noexcept { + if (this == &other) + return *this; + + _onion = other._onion; + + return *this; + } + + OptimizedGraph::OptimizedGraph(OptimizedGraph &&other) noexcept { + _onion = std::move(other._onion); + } + + OptimizedGraph &OptimizedGraph::operator=(OptimizedGraph &&other) noexcept { + if (this == &other) + return *this; + + _onion = std::move(other._onion); + + return *this; + } + uint64_t OptimizedGraph::layers() const { return _onion.size(); } From 3242d250a0a60cbf3c2ed9602d13f484c5ab71c2 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 26 Mar 2020 17:53:01 +0300 Subject: [PATCH 010/233] few more methods Signed-off-by: raver119 --- libnd4j/include/graph/OptimizedGraph.h | 1 + libnd4j/include/graph/impl/OptimizedGraph.cpp | 4 ++++ libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp | 7 +++++++ 3 files changed, 12 insertions(+) diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index 4de4d6a3d836..f59e51081ebb 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -68,6 +68,7 @@ namespace sd { */ // FIXME: this method should be removed or made private void append(const std::vector &layer); + void append(OpSequence &sequence); }; } } diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index bcc6720c7176..3e706107f30c 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -59,5 +59,9 @@ namespace sd { void OptimizedGraph::append(const std::vector &layer) { _onion[_onion.size()] = layer; } + + void OptimizedGraph::append(OpSequence &sequence) { + append(std::vector{sequence}); + } } } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp index eae93dd23c21..fae8110e766b 100644 --- a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp @@ -25,6 +25,7 @@ #include #include #include +#include using namespace sd; using namespace sd::ops; @@ -59,4 +60,10 @@ TEST_F(OpSequenceTests, test_iterator_1) { } ASSERT_EQ(3, cnt); + + OptimizedGraph optimizedGraph; + ASSERT_EQ(0, optimizedGraph.layers()); + + optimizedGraph.append(sequence); + ASSERT_EQ(1, optimizedGraph.layers()); } From 27feeb67228f1c45815b42e69527b5b56f03582b Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 26 Mar 2020 18:37:48 +0300 Subject: [PATCH 011/233] some additional testing Signed-off-by: raver119 --- libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp index fae8110e766b..0027f7052c23 100644 --- a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp @@ -66,4 +66,13 @@ TEST_F(OpSequenceTests, test_iterator_1) { optimizedGraph.append(sequence); ASSERT_EQ(1, optimizedGraph.layers()); + + auto layer = optimizedGraph.layer(0); + + // we expect exactly 1 sequence in this layer + ASSERT_EQ(1, layer.size()); + + auto seq = layer[0]; + + ASSERT_EQ(2, seq.length()); } From 667559c2d078e78917ee21fa6e0826bce06b7ed7 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 26 Mar 2020 18:54:33 +0300 Subject: [PATCH 012/233] ManagedDataBuffer Signed-off-by: raver119 --- libnd4j/include/array/DataBuffer.h | 2 +- libnd4j/include/array/ManagedDataBuffer.h | 38 +++++++++++++++++++ .../include/array/impl/ManagedDataBuffer.cpp | 21 ++++++++++ libnd4j/include/graph/OptimizedGraph.h | 2 +- libnd4j/include/graph/execution/OpSequence.h | 2 +- libnd4j/include/helpers/FileUtils.h | 3 +- 6 files changed, 64 insertions(+), 4 deletions(-) create mode 100644 libnd4j/include/array/ManagedDataBuffer.h create mode 100644 libnd4j/include/array/impl/ManagedDataBuffer.cpp diff --git a/libnd4j/include/array/DataBuffer.h b/libnd4j/include/array/DataBuffer.h index 59ffe3045e08..f4c9bfd0448a 100644 --- a/libnd4j/include/array/DataBuffer.h +++ b/libnd4j/include/array/DataBuffer.h @@ -34,7 +34,7 @@ namespace sd { class ND4J_EXPORT DataBuffer { - private: + protected: void* _primaryBuffer = nullptr; void* _specialBuffer = nullptr; diff --git a/libnd4j/include/array/ManagedDataBuffer.h b/libnd4j/include/array/ManagedDataBuffer.h new file mode 100644 index 000000000000..6b85f2ecc2fb --- /dev/null +++ b/libnd4j/include/array/ManagedDataBuffer.h @@ -0,0 +1,38 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_MANAGEDDATABUFFER_H +#define SD_MANAGEDDATABUFFER_H + +#include + +namespace sd { + /** + * This class provides special DataBuffer implementation for use within Graphs + */ + class ManagedDataBuffer : public DataBuffer { + protected: + public: + ManagedDataBuffer() = default; + ~ManagedDataBuffer() = default; + }; +} + +#endif //SD_MANAGEDDATABUFFER_H diff --git a/libnd4j/include/array/impl/ManagedDataBuffer.cpp b/libnd4j/include/array/impl/ManagedDataBuffer.cpp new file mode 100644 index 000000000000..4884725b6324 --- /dev/null +++ b/libnd4j/include/array/impl/ManagedDataBuffer.cpp @@ -0,0 +1,21 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index f59e51081ebb..f75504408f21 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -29,7 +29,7 @@ namespace sd { /** * This class acts as a topologically sorted & optimized Graph representation, ready for execution */ - class OptimizedGraph { + class ND4J_EXPORT OptimizedGraph { protected: // here we store independent OpSequences // Graph starts from layer 0, and goes deeper step by step diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index d1a08a63ca1b..ce8fced76bae 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -30,7 +30,7 @@ namespace sd { /** * This class represents independent and immutable sequence of operations */ - class OpSequence : public std::iterator> { + class ND4J_EXPORT OpSequence : public std::iterator> { // our internal iterator for OpSequence class iterator; protected: diff --git a/libnd4j/include/helpers/FileUtils.h b/libnd4j/include/helpers/FileUtils.h index 9698b4916487..85e4bf2e3fc6 100644 --- a/libnd4j/include/helpers/FileUtils.h +++ b/libnd4j/include/helpers/FileUtils.h @@ -22,9 +22,10 @@ #define SD_FILEUTILS_H #include +#include namespace sd { - class FileUtils { + class ND4J_EXPORT FileUtils { public: static bool fileExists(const char *filename); From d2d603e55a7cd7f616bbf015e56b848f6cf396eb Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 26 Mar 2020 19:01:00 +0300 Subject: [PATCH 013/233] ManagedDataBuffer Signed-off-by: raver119 --- libnd4j/include/array/ManagedDataBuffer.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libnd4j/include/array/ManagedDataBuffer.h b/libnd4j/include/array/ManagedDataBuffer.h index 6b85f2ecc2fb..6ffde3ac7a72 100644 --- a/libnd4j/include/array/ManagedDataBuffer.h +++ b/libnd4j/include/array/ManagedDataBuffer.h @@ -27,7 +27,7 @@ namespace sd { /** * This class provides special DataBuffer implementation for use within Graphs */ - class ManagedDataBuffer : public DataBuffer { + class ND4J_EXPORT ManagedDataBuffer : public DataBuffer { protected: public: ManagedDataBuffer() = default; From a52e8621119133d0604194500cd8b9fcbca0c8a3 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 26 Mar 2020 19:15:04 +0300 Subject: [PATCH 014/233] one more planned Graph method Signed-off-by: raver119 --- libnd4j/include/array/ManagedDataBuffer.h | 2 +- libnd4j/include/array/impl/ManagedDataBuffer.cpp | 6 ++++++ libnd4j/include/graph/Graph.h | 7 +++++++ libnd4j/include/graph/impl/Graph.cpp | 5 +++++ 4 files changed, 19 insertions(+), 1 deletion(-) diff --git a/libnd4j/include/array/ManagedDataBuffer.h b/libnd4j/include/array/ManagedDataBuffer.h index 6ffde3ac7a72..5bdcbedc6c69 100644 --- a/libnd4j/include/array/ManagedDataBuffer.h +++ b/libnd4j/include/array/ManagedDataBuffer.h @@ -31,7 +31,7 @@ namespace sd { protected: public: ManagedDataBuffer() = default; - ~ManagedDataBuffer() = default; + ~ManagedDataBuffer(); }; } diff --git a/libnd4j/include/array/impl/ManagedDataBuffer.cpp b/libnd4j/include/array/impl/ManagedDataBuffer.cpp index 4884725b6324..0103d506743c 100644 --- a/libnd4j/include/array/impl/ManagedDataBuffer.cpp +++ b/libnd4j/include/array/impl/ManagedDataBuffer.cpp @@ -19,3 +19,9 @@ // #include + +namespace sd { + ManagedDataBuffer::~ManagedDataBuffer() { + // + } +} \ No newline at end of file diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index c64ab128f05f..e1cdd7887006 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -244,6 +244,13 @@ namespace sd { OptimizedGraph optimizedGraph() const; + + /** + * This method executes this Graph instance and returns execution results + * @param dictionary + * @return + */ + std::map execute(const std::map &dictionary = {}, const std::vector &outputs = {}) const; }; FORCEINLINE std::vector* Graph::nodes() { diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index cb0845423f2e..2543828048bb 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -1703,6 +1703,11 @@ namespace sd { // TODO: implement this method return OptimizedGraph(); } + + std::map Graph::execute(const std::map &dictionary, const std::vector &outputs) const { + // TODO: implement this method + return std::map(); + } } } From 0e09e3532cbb6538729add30d45b558b9541f5a3 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 26 Mar 2020 21:56:27 +0300 Subject: [PATCH 015/233] virtual DataBuffer methods Signed-off-by: raver119 --- libnd4j/include/array/ConstantDataBuffer.h | 2 +- libnd4j/include/array/ConstantDescriptor.h | 6 +-- libnd4j/include/array/ConstantHolder.h | 2 +- libnd4j/include/array/DataBuffer.h | 37 ++++++++++--------- libnd4j/include/array/ExtraArguments.h | 6 +-- libnd4j/include/array/NDArrayFactory.h | 6 +-- libnd4j/include/array/ShapeDescriptor.h | 6 +-- libnd4j/include/array/TadDescriptor.h | 6 +-- libnd4j/include/array/TadPack.h | 6 +-- .../include/exceptions/allocation_exception.h | 6 +-- libnd4j/include/exceptions/cuda_exception.h | 6 +-- .../include/exceptions/datatype_exception.h | 6 +-- libnd4j/include/exceptions/graph_exception.h | 2 +- .../exceptions/graph_execution_exception.h | 6 +-- .../exceptions/graph_exists_exception.h | 6 +-- .../include/exceptions/no_results_exception.h | 6 +-- .../exceptions/unknown_graph_exception.h | 6 +-- libnd4j/include/execution/AffinityManager.h | 2 +- libnd4j/include/execution/BlockingQueue.h | 2 +- libnd4j/include/execution/CallableInterface.h | 2 +- .../include/execution/CallableWithArguments.h | 6 +-- libnd4j/include/execution/ContextBuffers.h | 2 +- libnd4j/include/execution/ErrorReference.h | 6 +-- libnd4j/include/execution/ThreadPool.h | 2 +- libnd4j/include/execution/Ticket.h | 2 +- libnd4j/include/graph/InferenceRequest.h | 6 +-- .../exceptions/unresolved_input_exception.h | 6 +-- .../exceptions/unresolved_output_exception.h | 6 +-- libnd4j/include/helpers/BenchmarkHelper.h | 2 +- libnd4j/include/helpers/ConstantHelper.h | 6 +-- libnd4j/include/helpers/ConstantShapeHelper.h | 6 +-- libnd4j/include/helpers/ConstantTadHelper.h | 6 +-- libnd4j/include/helpers/OpBenchmark.h | 6 +-- libnd4j/include/helpers/ShapeBuilders.h | 6 +-- libnd4j/include/helpers/SimpleReadWriteLock.h | 2 +- libnd4j/include/helpers/benchmark/BasicSuit.h | 6 +-- .../helpers/benchmark/BoolParameters.h | 6 +-- .../helpers/benchmark/BroadcastBenchmark.h | 6 +-- .../helpers/benchmark/DeclarableBenchmark.h | 6 +-- .../include/helpers/benchmark/IntParameters.h | 6 +-- .../helpers/benchmark/IntPowerParameters.h | 6 +-- .../helpers/benchmark/MatrixBenchmark.h | 6 +-- .../helpers/benchmark/PairwiseBenchmark.h | 6 +-- .../include/helpers/benchmark/Parameters.h | 6 +-- .../helpers/benchmark/ParametersBatch.h | 6 +-- .../helpers/benchmark/ParametersSpace.h | 6 +-- .../helpers/benchmark/PredefinedParameters.h | 6 +-- .../helpers/benchmark/ReductionBenchmark.h | 6 +-- .../helpers/benchmark/ScalarBenchmark.h | 6 +-- .../helpers/benchmark/TransformBenchmark.h | 6 +-- libnd4j/include/helpers/cublasHelper.h | 6 +-- .../loops/BroadcastPairwiseConverter.h | 6 +-- .../include/loops/BroadcastScalarConverter.h | 6 +-- libnd4j/include/loops/ReduceType.h | 6 +-- .../cuda/inplace_loops/reduce_same_inplace.h | 6 +-- .../loops/cuda/inplace_loops/scalar_inplace.h | 6 +-- .../inplace_loops/transform_strict_inplace.h | 6 +-- libnd4j/include/math/platformmath.h | 2 +- libnd4j/include/memory/AllocationEntry.h | 6 +-- libnd4j/include/memory/MemoryTracker.h | 6 +-- libnd4j/include/memory/MemoryType.h | 6 +-- libnd4j/include/ops/BroadcastBoolOpsTuple.h | 6 +-- libnd4j/include/ops/BroadcastIntOpsTuple.h | 6 +-- libnd4j/include/ops/BroadcastOpsTuple.h | 6 +-- libnd4j/include/ops/declarable/headers/nlp.h | 6 +-- .../include/ops/declarable/helpers/flatten.h | 6 +-- .../include/ops/declarable/helpers/hamming.h | 2 +- .../include/ops/declarable/helpers/hashcode.h | 6 +-- .../ops/declarable/helpers/histogram.h | 2 +- .../include/ops/declarable/helpers/one_hot.h | 6 +-- .../include/ops/declarable/helpers/scatter.h | 6 +-- .../include/ops/declarable/helpers/sg_cb.h | 6 +-- .../include/ops/declarable/helpers/shift.h | 6 +-- .../ops/declarable/helpers/toggle_bits.h | 6 +-- .../include/ops/declarable/helpers/where.h | 6 +-- .../declarable/platform/mkldnn/mkldnnUtils.h | 6 +-- .../performance/benchmarking/BenchmarkSuit.h | 2 +- .../benchmarking/FullBenchmarkSuit.h | 2 +- .../benchmarking/LightBenchmarkSuit.h | 2 +- libnd4j/include/system/BlasVersionHelper.h | 2 +- libnd4j/include/system/msvc.h | 2 +- libnd4j/include/system/openmp_pragmas.h | 6 +-- libnd4j/include/types/utf8string.h | 6 +-- 83 files changed, 227 insertions(+), 226 deletions(-) diff --git a/libnd4j/include/array/ConstantDataBuffer.h b/libnd4j/include/array/ConstantDataBuffer.h index e8bafe114a36..c3c6923f35be 100644 --- a/libnd4j/include/array/ConstantDataBuffer.h +++ b/libnd4j/include/array/ConstantDataBuffer.h @@ -56,4 +56,4 @@ namespace sd { }; } -#endif //DEV_TESTS_CONSTANTDATABUFFER_H +#endif //SD_CONSTANTDATABUFFER_H diff --git a/libnd4j/include/array/ConstantDescriptor.h b/libnd4j/include/array/ConstantDescriptor.h index 589ba23532a3..09efbcad6743 100644 --- a/libnd4j/include/array/ConstantDescriptor.h +++ b/libnd4j/include/array/ConstantDescriptor.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_CONSTANTDESCRIPTOR_H -#define DEV_TESTS_CONSTANTDESCRIPTOR_H +#ifndef SD_CONSTANTDESCRIPTOR_H +#define SD_CONSTANTDESCRIPTOR_H #include #include @@ -72,4 +72,4 @@ namespace std { #endif -#endif //DEV_TESTS_CONSTANTDESCRIPTOR_H +#endif //SD_CONSTANTDESCRIPTOR_H diff --git a/libnd4j/include/array/ConstantHolder.h b/libnd4j/include/array/ConstantHolder.h index a404e580843f..833549a742c1 100644 --- a/libnd4j/include/array/ConstantHolder.h +++ b/libnd4j/include/array/ConstantHolder.h @@ -62,4 +62,4 @@ namespace sd { } -#endif //DEV_TESTS_CONSTANTHOLDER_H +#endif //SD_CONSTANTHOLDER_H diff --git a/libnd4j/include/array/DataBuffer.h b/libnd4j/include/array/DataBuffer.h index f4c9bfd0448a..e9f725a70123 100644 --- a/libnd4j/include/array/DataBuffer.h +++ b/libnd4j/include/array/DataBuffer.h @@ -19,8 +19,8 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#ifndef DEV_TESTS_DATABUFFER_H -#define DEV_TESTS_DATABUFFER_H +#ifndef SD_DATABUFFER_H +#define SD_DATABUFFER_H #include #include @@ -55,9 +55,9 @@ class ND4J_EXPORT DataBuffer { void setCountersToZero(); void copyCounters(const DataBuffer& other); - void deleteSpecial(); - void deletePrimary(); - void deleteBuffers(); + virtual void deleteSpecial(); + virtual void deletePrimary(); + virtual void deleteBuffers(); void setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial = false); void allocateBuffers(const bool allocBoth = false); void setSpecial(void* special, const bool isOwnerSpecial); @@ -87,8 +87,8 @@ class ND4J_EXPORT DataBuffer { explicit DataBuffer(); ~DataBuffer(); - DataBuffer& operator=(const DataBuffer& other); - DataBuffer& operator=(DataBuffer&& other) noexcept; + virtual DataBuffer& operator=(const DataBuffer& other); + virtual DataBuffer& operator=(DataBuffer&& other) noexcept; DataType getDataType(); void setDataType(DataType dataType); @@ -97,8 +97,8 @@ class ND4J_EXPORT DataBuffer { void* primary(); void* special(); - void allocatePrimary(); - void allocateSpecial(); + virtual void allocatePrimary(); + virtual void allocateSpecial(); void writePrimary() const; void writeSpecial() const; @@ -107,11 +107,12 @@ class ND4J_EXPORT DataBuffer { bool isPrimaryActual() const; bool isSpecialActual() const; - void expand(const uint64_t size); + virtual void expand(const uint64_t size); int deviceId() const; void setDeviceId(int deviceId); - void migrate(); + + virtual void migrate(); template FORCEINLINE T* primaryAsT(); template FORCEINLINE T* specialAsT(); @@ -119,35 +120,35 @@ class ND4J_EXPORT DataBuffer { void syncToPrimary(const LaunchContext* context, const bool forceSync = false); void syncToSpecial(const bool forceSync = false); - void setToZeroBuffers(const bool both = false); + virtual void setToZeroBuffers(const bool both = false); void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0); static void memcpy(const DataBuffer &dst, const DataBuffer &src); - void setPrimaryBuffer(void *buffer, size_t length); - void setSpecialBuffer(void *buffer, size_t length); + virtual void setPrimaryBuffer(void *buffer, size_t length); + virtual void setSpecialBuffer(void *buffer, size_t length); /** * This method deletes buffers, if we're owners */ - void close(); + virtual void close(); }; ///// IMLEMENTATION OF INLINE METHODS ///// //////////////////////////////////////////////////////////////////////// template T* DataBuffer::primaryAsT() { - return reinterpret_cast(_primaryBuffer); + return reinterpret_cast(primary()); } //////////////////////////////////////////////////////////////////////// template T* DataBuffer::specialAsT() { - return reinterpret_cast(_specialBuffer); + return reinterpret_cast(special()); } } -#endif //DEV_TESTS_DATABUFFER_H +#endif //SD_DATABUFFER_H diff --git a/libnd4j/include/array/ExtraArguments.h b/libnd4j/include/array/ExtraArguments.h index 131e8cd92924..337667eea630 100644 --- a/libnd4j/include/array/ExtraArguments.h +++ b/libnd4j/include/array/ExtraArguments.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_EXTRAARGUMENTS_H -#define DEV_TESTS_EXTRAARGUMENTS_H +#ifndef SD_EXTRAARGUMENTS_H +#define SD_EXTRAARGUMENTS_H #include #include @@ -62,4 +62,4 @@ namespace sd { -#endif //DEV_TESTS_EXTRAARGUMENTS_H +#endif //SD_EXTRAARGUMENTS_H diff --git a/libnd4j/include/array/NDArrayFactory.h b/libnd4j/include/array/NDArrayFactory.h index f25c68fb4f32..bfe3aa3e6600 100644 --- a/libnd4j/include/array/NDArrayFactory.h +++ b/libnd4j/include/array/NDArrayFactory.h @@ -20,8 +20,8 @@ // @author Oleg Semeniv // -#ifndef DEV_TESTS_NDARRAYFACTORY_H -#define DEV_TESTS_NDARRAYFACTORY_H +#ifndef SD_NDARRAYFACTORY_H +#define SD_NDARRAYFACTORY_H #include #include @@ -188,4 +188,4 @@ namespace sd { }; } -#endif //DEV_TESTS_NDARRAYFACTORY_H +#endif //SD_NDARRAYFACTORY_H diff --git a/libnd4j/include/array/ShapeDescriptor.h b/libnd4j/include/array/ShapeDescriptor.h index 6e2299ba08cb..3f0704889943 100644 --- a/libnd4j/include/array/ShapeDescriptor.h +++ b/libnd4j/include/array/ShapeDescriptor.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_SHAPEDESCRIPTOR_H -#define DEV_TESTS_SHAPEDESCRIPTOR_H +#ifndef SD_SHAPEDESCRIPTOR_H +#define SD_SHAPEDESCRIPTOR_H #include #include @@ -100,4 +100,4 @@ namespace std { #endif -#endif //DEV_TESTS_SHAPEDESCRIPTOR_H +#endif //SD_SHAPEDESCRIPTOR_H diff --git a/libnd4j/include/array/TadDescriptor.h b/libnd4j/include/array/TadDescriptor.h index 01ea1caa1270..8e14c1ec1efd 100644 --- a/libnd4j/include/array/TadDescriptor.h +++ b/libnd4j/include/array/TadDescriptor.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_TADDESCRIPTOR_H -#define DEV_TESTS_TADDESCRIPTOR_H +#ifndef SD_TADDESCRIPTOR_H +#define SD_TADDESCRIPTOR_H #include "ShapeDescriptor.h" #include @@ -71,4 +71,4 @@ namespace std { #endif -#endif //DEV_TESTS_TADDESCRIPTOR_H +#endif //SD_TADDESCRIPTOR_H diff --git a/libnd4j/include/array/TadPack.h b/libnd4j/include/array/TadPack.h index 09b084548def..a5f880b9f3b8 100644 --- a/libnd4j/include/array/TadPack.h +++ b/libnd4j/include/array/TadPack.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_TADPACK_H -#define DEV_TESTS_TADPACK_H +#ifndef SD_TADPACK_H +#define SD_TADPACK_H #include "ConstantDataBuffer.h" @@ -54,4 +54,4 @@ namespace sd { } -#endif //DEV_TESTS_TADPACK_H +#endif //SD_TADPACK_H diff --git a/libnd4j/include/exceptions/allocation_exception.h b/libnd4j/include/exceptions/allocation_exception.h index 1e9b6653b550..98a5a47556ec 100644 --- a/libnd4j/include/exceptions/allocation_exception.h +++ b/libnd4j/include/exceptions/allocation_exception.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_ALLOCATION_EXCEPTION_H -#define DEV_TESTS_ALLOCATION_EXCEPTION_H +#ifndef SD_ALLOCATION_EXCEPTION_H +#define SD_ALLOCATION_EXCEPTION_H #include #include @@ -45,4 +45,4 @@ namespace sd { } -#endif //DEV_TESTS_ALLOCATION_EXCEPTION_H +#endif //SD_ALLOCATION_EXCEPTION_H diff --git a/libnd4j/include/exceptions/cuda_exception.h b/libnd4j/include/exceptions/cuda_exception.h index 2dc98eec3dff..a800cfcbcb1f 100644 --- a/libnd4j/include/exceptions/cuda_exception.h +++ b/libnd4j/include/exceptions/cuda_exception.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_CUDA_EXCEPTION_H -#define DEV_TESTS_CUDA_EXCEPTION_H +#ifndef SD_CUDA_EXCEPTION_H +#define SD_CUDA_EXCEPTION_H #include #include @@ -44,4 +44,4 @@ namespace sd { -#endif //DEV_TESTS_CUDA_EXCEPTION_H +#endif //SD_CUDA_EXCEPTION_H diff --git a/libnd4j/include/exceptions/datatype_exception.h b/libnd4j/include/exceptions/datatype_exception.h index 74829d54c647..fbbe62164586 100644 --- a/libnd4j/include/exceptions/datatype_exception.h +++ b/libnd4j/include/exceptions/datatype_exception.h @@ -18,8 +18,8 @@ // Created by raver on 11/26/2018. // -#ifndef DEV_TESTS_DATATYPE_EXCEPTION_H -#define DEV_TESTS_DATATYPE_EXCEPTION_H +#ifndef SD_DATATYPE_EXCEPTION_H +#define SD_DATATYPE_EXCEPTION_H #include #include @@ -47,4 +47,4 @@ namespace sd { } -#endif //DEV_TESTS_DATATYPE_EXCEPTION_H +#endif //SD_DATATYPE_EXCEPTION_H diff --git a/libnd4j/include/exceptions/graph_exception.h b/libnd4j/include/exceptions/graph_exception.h index 7c9345a4deed..d7180e6ce07d 100644 --- a/libnd4j/include/exceptions/graph_exception.h +++ b/libnd4j/include/exceptions/graph_exception.h @@ -54,4 +54,4 @@ namespace sd { -#endif //DEV_TESTS_GRAPH_EXCEPTION_H +#endif //SD_GRAPH_EXCEPTION_H diff --git a/libnd4j/include/exceptions/graph_execution_exception.h b/libnd4j/include/exceptions/graph_execution_exception.h index 37f8e636e878..d64b0ccfdecd 100644 --- a/libnd4j/include/exceptions/graph_execution_exception.h +++ b/libnd4j/include/exceptions/graph_execution_exception.h @@ -18,8 +18,8 @@ // Created by raver on 8/31/2018. // -#ifndef DEV_TESTS_GRAPH_EXECUTION_EXCEPTION_H -#define DEV_TESTS_GRAPH_EXECUTION_EXCEPTION_H +#ifndef SD_GRAPH_EXECUTION_EXCEPTION_H +#define SD_GRAPH_EXECUTION_EXCEPTION_H #include #include @@ -41,4 +41,4 @@ namespace sd { }; } -#endif //DEV_TESTS_UNKNOWN_GRAPH_EXCEPTION_H +#endif //SD_UNKNOWN_GRAPH_EXCEPTION_H diff --git a/libnd4j/include/exceptions/graph_exists_exception.h b/libnd4j/include/exceptions/graph_exists_exception.h index 63554c31b386..a50566d90636 100644 --- a/libnd4j/include/exceptions/graph_exists_exception.h +++ b/libnd4j/include/exceptions/graph_exists_exception.h @@ -18,8 +18,8 @@ // Created by raver on 8/31/2018. // -#ifndef DEV_TESTS_GRAPH_EXISTS_EXCEPTION_H -#define DEV_TESTS_GRAPH_EXISTS_EXCEPTION_H +#ifndef SD_GRAPH_EXISTS_EXCEPTION_H +#define SD_GRAPH_EXISTS_EXCEPTION_H #include #include @@ -41,4 +41,4 @@ namespace sd { }; } -#endif //DEV_TESTS_UNKNOWN_GRAPH_EXCEPTION_H +#endif //SD_UNKNOWN_GRAPH_EXCEPTION_H diff --git a/libnd4j/include/exceptions/no_results_exception.h b/libnd4j/include/exceptions/no_results_exception.h index b2687854b25d..aa144679ac39 100644 --- a/libnd4j/include/exceptions/no_results_exception.h +++ b/libnd4j/include/exceptions/no_results_exception.h @@ -18,8 +18,8 @@ // Created by raver on 8/31/2018. // -#ifndef DEV_TESTS_NO_RESULTS_EXCEPTION_H -#define DEV_TESTS_NO_RESULTS_EXCEPTION_H +#ifndef SD_NO_RESULTS_EXCEPTION_H +#define SD_NO_RESULTS_EXCEPTION_H #include #include @@ -41,4 +41,4 @@ namespace sd { }; } -#endif //DEV_TESTS_UNKNOWN_GRAPH_EXCEPTION_H +#endif //SD_UNKNOWN_GRAPH_EXCEPTION_H diff --git a/libnd4j/include/exceptions/unknown_graph_exception.h b/libnd4j/include/exceptions/unknown_graph_exception.h index 917aeb757954..bcb1e2619dff 100644 --- a/libnd4j/include/exceptions/unknown_graph_exception.h +++ b/libnd4j/include/exceptions/unknown_graph_exception.h @@ -18,8 +18,8 @@ // Created by raver on 8/31/2018. // -#ifndef DEV_TESTS_UNKNOWN_GRAPH_EXCEPTION_H -#define DEV_TESTS_UNKNOWN_GRAPH_EXCEPTION_H +#ifndef SD_UNKNOWN_GRAPH_EXCEPTION_H +#define SD_UNKNOWN_GRAPH_EXCEPTION_H #include #include @@ -41,4 +41,4 @@ namespace sd { }; } -#endif //DEV_TESTS_UNKNOWN_GRAPH_EXCEPTION_H +#endif //SD_UNKNOWN_GRAPH_EXCEPTION_H diff --git a/libnd4j/include/execution/AffinityManager.h b/libnd4j/include/execution/AffinityManager.h index 757f637cee05..10020d311c02 100644 --- a/libnd4j/include/execution/AffinityManager.h +++ b/libnd4j/include/execution/AffinityManager.h @@ -43,4 +43,4 @@ namespace sd { }; } -#endif //DEV_TESTS_AFFINITYMANAGER_H +#endif //SD_AFFINITYMANAGER_H diff --git a/libnd4j/include/execution/BlockingQueue.h b/libnd4j/include/execution/BlockingQueue.h index a78196dfc745..b3c2c654c773 100644 --- a/libnd4j/include/execution/BlockingQueue.h +++ b/libnd4j/include/execution/BlockingQueue.h @@ -49,4 +49,4 @@ namespace samediff { }; } -#endif //DEV_TESTS_BLOCKINGQUEUE_H +#endif //SD_BLOCKINGQUEUE_H diff --git a/libnd4j/include/execution/CallableInterface.h b/libnd4j/include/execution/CallableInterface.h index aad83b379222..c5053ecbc933 100644 --- a/libnd4j/include/execution/CallableInterface.h +++ b/libnd4j/include/execution/CallableInterface.h @@ -91,4 +91,4 @@ namespace samediff { } -#endif //DEV_TESTS_CALLABLEINTERFACE_H +#endif //SD_CALLABLEINTERFACE_H diff --git a/libnd4j/include/execution/CallableWithArguments.h b/libnd4j/include/execution/CallableWithArguments.h index 28ef8433e3ad..84a3be3f8b8e 100644 --- a/libnd4j/include/execution/CallableWithArguments.h +++ b/libnd4j/include/execution/CallableWithArguments.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_CALLABLEWITHARGUMENTS_H -#define DEV_TESTS_CALLABLEWITHARGUMENTS_H +#ifndef SD_CALLABLEWITHARGUMENTS_H +#define SD_CALLABLEWITHARGUMENTS_H #include #include @@ -89,4 +89,4 @@ namespace samediff { } -#endif //DEV_TESTS_CALLABLEWITHARGUMENTS_H +#endif //SD_CALLABLEWITHARGUMENTS_H diff --git a/libnd4j/include/execution/ContextBuffers.h b/libnd4j/include/execution/ContextBuffers.h index c14671e426f9..386d84039444 100644 --- a/libnd4j/include/execution/ContextBuffers.h +++ b/libnd4j/include/execution/ContextBuffers.h @@ -73,4 +73,4 @@ namespace sd { } -#endif //DEV_TESTS_CONTEXTBUFFERS_H +#endif //SD_CONTEXTBUFFERS_H diff --git a/libnd4j/include/execution/ErrorReference.h b/libnd4j/include/execution/ErrorReference.h index b71090248994..0a3465dbe56b 100644 --- a/libnd4j/include/execution/ErrorReference.h +++ b/libnd4j/include/execution/ErrorReference.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_ERRORREFERENCE_H -#define DEV_TESTS_ERRORREFERENCE_H +#ifndef SD_ERRORREFERENCE_H +#define SD_ERRORREFERENCE_H #include #include @@ -43,4 +43,4 @@ namespace sd { } -#endif //DEV_TESTS_ERRORREFERENCE_H +#endif //SD_ERRORREFERENCE_H diff --git a/libnd4j/include/execution/ThreadPool.h b/libnd4j/include/execution/ThreadPool.h index 6811f1b1c843..a1ce73f408a2 100644 --- a/libnd4j/include/execution/ThreadPool.h +++ b/libnd4j/include/execution/ThreadPool.h @@ -68,4 +68,4 @@ namespace samediff { } -#endif //DEV_TESTS_THREADPOOL_H +#endif //SD_THREADPOOL_H diff --git a/libnd4j/include/execution/Ticket.h b/libnd4j/include/execution/Ticket.h index 80bf54145661..a44a1e95756c 100644 --- a/libnd4j/include/execution/Ticket.h +++ b/libnd4j/include/execution/Ticket.h @@ -64,4 +64,4 @@ namespace samediff { } -#endif //DEV_TESTS_TICKET_H +#endif //SD_TICKET_H diff --git a/libnd4j/include/graph/InferenceRequest.h b/libnd4j/include/graph/InferenceRequest.h index b445fa0e1daa..38eed9b67e66 100644 --- a/libnd4j/include/graph/InferenceRequest.h +++ b/libnd4j/include/graph/InferenceRequest.h @@ -17,8 +17,8 @@ // // @author raver119@gmail.com // -#ifndef DEV_TESTS_INFERENCEREQUEST_H -#define DEV_TESTS_INFERENCEREQUEST_H +#ifndef SD_INFERENCEREQUEST_H +#define SD_INFERENCEREQUEST_H #include #include @@ -57,4 +57,4 @@ namespace sd { -#endif //DEV_TESTS_INFERENCEREQUEST_H +#endif //SD_INFERENCEREQUEST_H diff --git a/libnd4j/include/graph/exceptions/unresolved_input_exception.h b/libnd4j/include/graph/exceptions/unresolved_input_exception.h index 5e38977a99e7..842e5ef8c728 100644 --- a/libnd4j/include/graph/exceptions/unresolved_input_exception.h +++ b/libnd4j/include/graph/exceptions/unresolved_input_exception.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_UNRESOLVED_INPUT_H -#define DEV_TESTS_UNRESOLVED_INPUT_H +#ifndef SD_UNRESOLVED_INPUT_H +#define SD_UNRESOLVED_INPUT_H #include #include @@ -39,4 +39,4 @@ namespace sd { } } -#endif //DEV_TESTS_UNRESOLVED_INPUT_H +#endif //SD_UNRESOLVED_INPUT_H diff --git a/libnd4j/include/graph/exceptions/unresolved_output_exception.h b/libnd4j/include/graph/exceptions/unresolved_output_exception.h index 05d39c514818..b9f09bf4cbdf 100644 --- a/libnd4j/include/graph/exceptions/unresolved_output_exception.h +++ b/libnd4j/include/graph/exceptions/unresolved_output_exception.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_UNRESOLVED_OUTPUT_H -#define DEV_TESTS_UNRESOLVED_OUTPUT_H +#ifndef SD_UNRESOLVED_OUTPUT_H +#define SD_UNRESOLVED_OUTPUT_H #include #include @@ -40,4 +40,4 @@ namespace sd { } -#endif //DEV_TESTS_UNRESOLVED_INPUT_H +#endif //SD_UNRESOLVED_INPUT_H diff --git a/libnd4j/include/helpers/BenchmarkHelper.h b/libnd4j/include/helpers/BenchmarkHelper.h index f76f787d8b69..17c427cd8078 100644 --- a/libnd4j/include/helpers/BenchmarkHelper.h +++ b/libnd4j/include/helpers/BenchmarkHelper.h @@ -86,4 +86,4 @@ namespace sd { } -#endif //DEV_TESTS_BENCHMARKHELPER_H +#endif //SD_BENCHMARKHELPER_H diff --git a/libnd4j/include/helpers/ConstantHelper.h b/libnd4j/include/helpers/ConstantHelper.h index 3e5681fb68bc..7c519d6351e5 100644 --- a/libnd4j/include/helpers/ConstantHelper.h +++ b/libnd4j/include/helpers/ConstantHelper.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_CONSTANTHELPER_H -#define DEV_TESTS_CONSTANTHELPER_H +#ifndef SD_CONSTANTHELPER_H +#define SD_CONSTANTHELPER_H #include #include @@ -61,4 +61,4 @@ namespace sd { }; } -#endif //DEV_TESTS_CONSTANTHELPER_H +#endif //SD_CONSTANTHELPER_H diff --git a/libnd4j/include/helpers/ConstantShapeHelper.h b/libnd4j/include/helpers/ConstantShapeHelper.h index 4454776a44e9..e821e55f5a89 100644 --- a/libnd4j/include/helpers/ConstantShapeHelper.h +++ b/libnd4j/include/helpers/ConstantShapeHelper.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_CONSTANTSHAPEHELPER_H -#define DEV_TESTS_CONSTANTSHAPEHELPER_H +#ifndef SD_CONSTANTSHAPEHELPER_H +#define SD_CONSTANTSHAPEHELPER_H #include #include @@ -95,4 +95,4 @@ namespace sd { }; } -#endif //DEV_TESTS_CONSTANTSHAPEHELPER_H +#endif //SD_CONSTANTSHAPEHELPER_H diff --git a/libnd4j/include/helpers/ConstantTadHelper.h b/libnd4j/include/helpers/ConstantTadHelper.h index 80efaa86f7b5..c6bb834fb30f 100644 --- a/libnd4j/include/helpers/ConstantTadHelper.h +++ b/libnd4j/include/helpers/ConstantTadHelper.h @@ -19,8 +19,8 @@ // -#ifndef DEV_TESTS_CONSTANTTADHELPER_H -#define DEV_TESTS_CONSTANTTADHELPER_H +#ifndef SD_CONSTANTTADHELPER_H +#define SD_CONSTANTTADHELPER_H #include #include @@ -86,4 +86,4 @@ namespace sd { }; } -#endif //DEV_TESTS_CONSTANTTADHELPER_H +#endif //SD_CONSTANTTADHELPER_H diff --git a/libnd4j/include/helpers/OpBenchmark.h b/libnd4j/include/helpers/OpBenchmark.h index 328b20dce3ab..dfd303a626c5 100644 --- a/libnd4j/include/helpers/OpBenchmark.h +++ b/libnd4j/include/helpers/OpBenchmark.h @@ -18,8 +18,8 @@ // Created by raver on 2/28/2019. // -#ifndef DEV_TESTS_OPEXECUTIONER_H -#define DEV_TESTS_OPEXECUTIONER_H +#ifndef SD_OPEXECUTIONER_H +#define SD_OPEXECUTIONER_H #include #include @@ -73,4 +73,4 @@ namespace sd { } -#endif //DEV_TESTS_OPEXECUTIONER_H +#endif //SD_OPEXECUTIONER_H diff --git a/libnd4j/include/helpers/ShapeBuilders.h b/libnd4j/include/helpers/ShapeBuilders.h index e2c29a280405..b1e5a3eba876 100644 --- a/libnd4j/include/helpers/ShapeBuilders.h +++ b/libnd4j/include/helpers/ShapeBuilders.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_SHAPEBUILDERS_H -#define DEV_TESTS_SHAPEBUILDERS_H +#ifndef SD_SHAPEBUILDERS_H +#define SD_SHAPEBUILDERS_H #include #include @@ -66,4 +66,4 @@ namespace sd { } -#endif //DEV_TESTS_SHAPEBUILDERS_H +#endif //SD_SHAPEBUILDERS_H diff --git a/libnd4j/include/helpers/SimpleReadWriteLock.h b/libnd4j/include/helpers/SimpleReadWriteLock.h index 5d1fce7119b1..7b87ef5eafea 100644 --- a/libnd4j/include/helpers/SimpleReadWriteLock.h +++ b/libnd4j/include/helpers/SimpleReadWriteLock.h @@ -57,4 +57,4 @@ namespace sd { } -#endif //DEV_TESTS_READWRITELOCK_H +#endif //SD_READWRITELOCK_H diff --git a/libnd4j/include/helpers/benchmark/BasicSuit.h b/libnd4j/include/helpers/benchmark/BasicSuit.h index 1e4c156fbe72..8e06f4f0e452 100644 --- a/libnd4j/include/helpers/benchmark/BasicSuit.h +++ b/libnd4j/include/helpers/benchmark/BasicSuit.h @@ -18,8 +18,8 @@ * @author raver119@gmail.com */ -#ifndef DEV_TESTS_BASICSUIT_H -#define DEV_TESTS_BASICSUIT_H +#ifndef SD_BASICSUIT_H +#define SD_BASICSUIT_H namespace sd { class BasicSuit { @@ -30,4 +30,4 @@ namespace sd { }; } -#endif //DEV_TESTS_BASICSUIT_H \ No newline at end of file +#endif //SD_BASICSUIT_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/BoolParameters.h b/libnd4j/include/helpers/benchmark/BoolParameters.h index bac8a0c5cb69..20d547cee591 100644 --- a/libnd4j/include/helpers/benchmark/BoolParameters.h +++ b/libnd4j/include/helpers/benchmark/BoolParameters.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_BOOLPARAMETERS_H -#define DEV_TESTS_BOOLPARAMETERS_H +#ifndef SD_BOOLPARAMETERS_H +#define SD_BOOLPARAMETERS_H #include #include @@ -45,4 +45,4 @@ namespace sd { }; } -#endif //DEV_TESTS_PARAMETERSPACE_H \ No newline at end of file +#endif //SD_PARAMETERSPACE_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h b/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h index 3a043be5903b..6fc1822d0433 100644 --- a/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h +++ b/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h @@ -20,8 +20,8 @@ #include "../OpBenchmark.h" -#ifndef DEV_TESTS_BROADCASTBENCHMARK_H -#define DEV_TESTS_BROADCASTBENCHMARK_H +#ifndef SD_BROADCASTBENCHMARK_H +#define SD_BROADCASTBENCHMARK_H namespace sd { class ND4J_EXPORT BroadcastBenchmark : public OpBenchmark { @@ -130,4 +130,4 @@ void executeOnce() override { }; } -#endif //DEV_TESTS_BROADCASTBENCHMARK_H \ No newline at end of file +#endif //SD_BROADCASTBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h b/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h index f9347eb0568b..032d05d0a677 100644 --- a/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h +++ b/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h @@ -19,8 +19,8 @@ // Created by raver on 3/2/2019. // -#ifndef DEV_TESTS_DECLARABLEBENCHMARK_H -#define DEV_TESTS_DECLARABLEBENCHMARK_H +#ifndef SD_DECLARABLEBENCHMARK_H +#define SD_DECLARABLEBENCHMARK_H #include #include @@ -174,4 +174,4 @@ namespace sd { }; } -#endif //DEV_TESTS_DECLARABLEBENCHMARKS_H \ No newline at end of file +#endif //SD_DECLARABLEBENCHMARKS_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/IntParameters.h b/libnd4j/include/helpers/benchmark/IntParameters.h index 10a1763e4979..3f4e4cc344f8 100644 --- a/libnd4j/include/helpers/benchmark/IntParameters.h +++ b/libnd4j/include/helpers/benchmark/IntParameters.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_INTPARAMETERS_H -#define DEV_TESTS_INTPARAMETERS_H +#ifndef SD_INTPARAMETERS_H +#define SD_INTPARAMETERS_H #include #include @@ -52,4 +52,4 @@ namespace sd { }; } -#endif //DEV_TESTS_INTPARAMETERS_H \ No newline at end of file +#endif //SD_INTPARAMETERS_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/IntPowerParameters.h b/libnd4j/include/helpers/benchmark/IntPowerParameters.h index 82c58bb2317e..29667ae43d64 100644 --- a/libnd4j/include/helpers/benchmark/IntPowerParameters.h +++ b/libnd4j/include/helpers/benchmark/IntPowerParameters.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_INTPOWERPARAMETERS_H -#define DEV_TESTS_INTPOWERPARAMETERS_H +#ifndef SD_INTPOWERPARAMETERS_H +#define SD_INTPOWERPARAMETERS_H #include #include @@ -54,4 +54,4 @@ namespace sd { }; } -#endif //DEV_TESTS_INTPOWERPARAMETERS_H \ No newline at end of file +#endif //SD_INTPOWERPARAMETERS_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/MatrixBenchmark.h b/libnd4j/include/helpers/benchmark/MatrixBenchmark.h index eb8fd2619993..139806bb4a51 100644 --- a/libnd4j/include/helpers/benchmark/MatrixBenchmark.h +++ b/libnd4j/include/helpers/benchmark/MatrixBenchmark.h @@ -21,8 +21,8 @@ #include #include -#ifndef DEV_TESTS_MATRIXBENCHMARK_H -#define DEV_TESTS_MATRIXBENCHMARK_H +#ifndef SD_MATRIXBENCHMARK_H +#define SD_MATRIXBENCHMARK_H namespace sd { class ND4J_EXPORT MatrixBenchmark : public OpBenchmark { @@ -121,4 +121,4 @@ namespace sd { }; } -#endif //DEV_TESTS_SCALARBENCHMARK_H \ No newline at end of file +#endif //SD_SCALARBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h b/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h index ca92e96b3219..58ac17b4f72f 100644 --- a/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h +++ b/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h @@ -20,8 +20,8 @@ #include "../OpBenchmark.h" -#ifndef DEV_TESTS_PAIRWISEBENCHMARK_H -#define DEV_TESTS_PAIRWISEBENCHMARK_H +#ifndef SD_PAIRWISEBENCHMARK_H +#define SD_PAIRWISEBENCHMARK_H using namespace sd::graph; @@ -105,4 +105,4 @@ namespace sd { }; } -#endif //DEV_TESTS_SCALARBENCHMARK_H \ No newline at end of file +#endif //SD_SCALARBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/Parameters.h b/libnd4j/include/helpers/benchmark/Parameters.h index eee443574b81..6810edec2b67 100644 --- a/libnd4j/include/helpers/benchmark/Parameters.h +++ b/libnd4j/include/helpers/benchmark/Parameters.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_PARAMETERS_H -#define DEV_TESTS_PARAMETERS_H +#ifndef SD_PARAMETERS_H +#define SD_PARAMETERS_H #include #include @@ -49,4 +49,4 @@ namespace sd { }; } -#endif //DEV_TESTS_PARAMETERS_H \ No newline at end of file +#endif //SD_PARAMETERS_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/ParametersBatch.h b/libnd4j/include/helpers/benchmark/ParametersBatch.h index 68c4dfb9f726..7706492183a3 100644 --- a/libnd4j/include/helpers/benchmark/ParametersBatch.h +++ b/libnd4j/include/helpers/benchmark/ParametersBatch.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_PARAMETERSBATCH_H -#define DEV_TESTS_PARAMETERSBATCH_H +#ifndef SD_PARAMETERSBATCH_H +#define SD_PARAMETERSBATCH_H #include #include @@ -81,4 +81,4 @@ namespace sd { }; } -#endif //DEV_TESTS_PARAMETERSBATCH_H \ No newline at end of file +#endif //SD_PARAMETERSBATCH_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/ParametersSpace.h b/libnd4j/include/helpers/benchmark/ParametersSpace.h index a7c59f9a6d45..a71b03d962d1 100644 --- a/libnd4j/include/helpers/benchmark/ParametersSpace.h +++ b/libnd4j/include/helpers/benchmark/ParametersSpace.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_PARAMETERSPACE_H -#define DEV_TESTS_PARAMETERSPACE_H +#ifndef SD_PARAMETERSPACE_H +#define SD_PARAMETERSPACE_H #include @@ -39,4 +39,4 @@ namespace sd { }; } -#endif //DEV_TESTS_PARAMETERSPACE_H \ No newline at end of file +#endif //SD_PARAMETERSPACE_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/PredefinedParameters.h b/libnd4j/include/helpers/benchmark/PredefinedParameters.h index f2a7fc347655..d26b7bf0e7b4 100644 --- a/libnd4j/include/helpers/benchmark/PredefinedParameters.h +++ b/libnd4j/include/helpers/benchmark/PredefinedParameters.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_PREDEFINEDPARAMETERS_H -#define DEV_TESTS_PREDEFINEDPARAMETERS_H +#ifndef SD_PREDEFINEDPARAMETERS_H +#define SD_PREDEFINEDPARAMETERS_H #include "ParametersSpace.h" @@ -43,4 +43,4 @@ namespace sd { }; } -#endif //DEV_TESTS_PREDEFINEDPARAMETERS_H \ No newline at end of file +#endif //SD_PREDEFINEDPARAMETERS_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/ReductionBenchmark.h b/libnd4j/include/helpers/benchmark/ReductionBenchmark.h index a1dc0126f7da..e5396af41060 100644 --- a/libnd4j/include/helpers/benchmark/ReductionBenchmark.h +++ b/libnd4j/include/helpers/benchmark/ReductionBenchmark.h @@ -22,8 +22,8 @@ #include #include "../OpBenchmark.h" -#ifndef DEV_TESTS_REDUCEBENCHMARK_H -#define DEV_TESTS_REDUCEBENCHMARK_H +#ifndef SD_REDUCEBENCHMARK_H +#define SD_REDUCEBENCHMARK_H using namespace sd::graph; @@ -149,4 +149,4 @@ namespace sd { }; } -#endif //DEV_TESTS_SCALARBENCHMARK_H \ No newline at end of file +#endif //SD_SCALARBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/ScalarBenchmark.h b/libnd4j/include/helpers/benchmark/ScalarBenchmark.h index 3b0cdecf5bcc..9063f3c526db 100644 --- a/libnd4j/include/helpers/benchmark/ScalarBenchmark.h +++ b/libnd4j/include/helpers/benchmark/ScalarBenchmark.h @@ -19,8 +19,8 @@ // #include "../OpBenchmark.h" -#ifndef DEV_TESTS_SCALARBENCHMARK_H -#define DEV_TESTS_SCALARBENCHMARK_H +#ifndef SD_SCALARBENCHMARK_H +#define SD_SCALARBENCHMARK_H using namespace sd::graph; @@ -98,4 +98,4 @@ namespace sd { }; } -#endif //DEV_TESTS_SCALARBENCHMARK_H \ No newline at end of file +#endif //SD_SCALARBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/TransformBenchmark.h b/libnd4j/include/helpers/benchmark/TransformBenchmark.h index 024857633490..81a388edb39d 100644 --- a/libnd4j/include/helpers/benchmark/TransformBenchmark.h +++ b/libnd4j/include/helpers/benchmark/TransformBenchmark.h @@ -19,8 +19,8 @@ // #include "../OpBenchmark.h" -#ifndef DEV_TESTS_TRANSFORMBENCHMARK_H -#define DEV_TESTS_TRANSFORMBENCHMARK_H +#ifndef SD_TRANSFORMBENCHMARK_H +#define SD_TRANSFORMBENCHMARK_H using namespace sd::graph; @@ -132,4 +132,4 @@ namespace sd { }; } -#endif //DEV_TESTS_SCALARBENCHMARK_H \ No newline at end of file +#endif //SD_SCALARBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/cublasHelper.h b/libnd4j/include/helpers/cublasHelper.h index 0300f3698329..8e690778936d 100644 --- a/libnd4j/include/helpers/cublasHelper.h +++ b/libnd4j/include/helpers/cublasHelper.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_CUBLASHELPER_H -#define DEV_TESTS_CUBLASHELPER_H +#ifndef SD_CUBLASHELPER_H +#define SD_CUBLASHELPER_H #include #include @@ -49,4 +49,4 @@ namespace sd { }; } -#endif //DEV_TESTS_CUBLASHELPER_H +#endif //SD_CUBLASHELPER_H diff --git a/libnd4j/include/loops/BroadcastPairwiseConverter.h b/libnd4j/include/loops/BroadcastPairwiseConverter.h index acb7e8d64035..cb0655224d97 100644 --- a/libnd4j/include/loops/BroadcastPairwiseConverter.h +++ b/libnd4j/include/loops/BroadcastPairwiseConverter.h @@ -18,8 +18,8 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 09.04.2019 // -#ifndef DEV_TESTS_BROADCASTPAIRWISECONVERTER_H -#define DEV_TESTS_BROADCASTPAIRWISECONVERTER_H +#ifndef SD_BROADCASTPAIRWISECONVERTER_H +#define SD_BROADCASTPAIRWISECONVERTER_H #include #include @@ -93,4 +93,4 @@ inline pairwise::BoolOps fromBroadcastToPairwiseBool(broadcast::BoolOps op) { } } -#endif //DEV_TESTS_BROADCASTPAIRWISECONVERTER_H \ No newline at end of file +#endif //SD_BROADCASTPAIRWISECONVERTER_H \ No newline at end of file diff --git a/libnd4j/include/loops/BroadcastScalarConverter.h b/libnd4j/include/loops/BroadcastScalarConverter.h index 12006c293a2b..a4a334adecc4 100644 --- a/libnd4j/include/loops/BroadcastScalarConverter.h +++ b/libnd4j/include/loops/BroadcastScalarConverter.h @@ -17,8 +17,8 @@ /** * @author raver119@gmail.com */ -#ifndef DEV_TESTS_BROADCASTSCALARCONVERTER_H -#define DEV_TESTS_BROADCASTSCALARCONVERTER_H +#ifndef SD_BROADCASTSCALARCONVERTER_H +#define SD_BROADCASTSCALARCONVERTER_H #include #include @@ -54,4 +54,4 @@ namespace sd { } } -#endif //DEV_TESTS_BROADCASTSCALARCONVERTER_H +#endif //SD_BROADCASTSCALARCONVERTER_H diff --git a/libnd4j/include/loops/ReduceType.h b/libnd4j/include/loops/ReduceType.h index 501b7229e3ff..ae0bef08d496 100644 --- a/libnd4j/include/loops/ReduceType.h +++ b/libnd4j/include/loops/ReduceType.h @@ -18,8 +18,8 @@ * @author raver119@gmail.com */ -#ifndef DEV_TESTS_REDUCETYPE_H -#define DEV_TESTS_REDUCETYPE_H +#ifndef SD_REDUCETYPE_H +#define SD_REDUCETYPE_H namespace functions { enum ReduceType { @@ -33,4 +33,4 @@ namespace functions { }; } -#endif //DEV_TESTS_REDUCETYPE_H +#endif //SD_REDUCETYPE_H diff --git a/libnd4j/include/loops/cuda/inplace_loops/reduce_same_inplace.h b/libnd4j/include/loops/cuda/inplace_loops/reduce_same_inplace.h index 1989cadc5c53..9e7bdf63409f 100644 --- a/libnd4j/include/loops/cuda/inplace_loops/reduce_same_inplace.h +++ b/libnd4j/include/loops/cuda/inplace_loops/reduce_same_inplace.h @@ -19,8 +19,8 @@ // -#ifndef DEV_TESTS_REDUCE_SAME_LOOPS_H -#define DEV_TESTS_REDUCE_SAME_LOOPS_H +#ifndef SD_REDUCE_SAME_LOOPS_H +#define SD_REDUCE_SAME_LOOPS_H #include #include @@ -170,4 +170,4 @@ namespace functions { } } -#endif //DEV_TESTS_REDUCE_SAME_LOOPS_H +#endif //SD_REDUCE_SAME_LOOPS_H diff --git a/libnd4j/include/loops/cuda/inplace_loops/scalar_inplace.h b/libnd4j/include/loops/cuda/inplace_loops/scalar_inplace.h index df1a87ba896e..7049483e7852 100644 --- a/libnd4j/include/loops/cuda/inplace_loops/scalar_inplace.h +++ b/libnd4j/include/loops/cuda/inplace_loops/scalar_inplace.h @@ -19,8 +19,8 @@ // -#ifndef DEV_TESTS_SCALAR_INPLACE_H -#define DEV_TESTS_SCALAR_INPLACE_H +#ifndef SD_SCALAR_INPLACE_H +#define SD_SCALAR_INPLACE_H #include #include @@ -79,4 +79,4 @@ namespace functions { } } -#endif //DEV_TESTS_SCALAR_INPLACE_H +#endif //SD_SCALAR_INPLACE_H diff --git a/libnd4j/include/loops/cuda/inplace_loops/transform_strict_inplace.h b/libnd4j/include/loops/cuda/inplace_loops/transform_strict_inplace.h index c4b94fca5651..903ed8c53291 100644 --- a/libnd4j/include/loops/cuda/inplace_loops/transform_strict_inplace.h +++ b/libnd4j/include/loops/cuda/inplace_loops/transform_strict_inplace.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_TRANSFORM_FLOAT_INPLACE_H -#define DEV_TESTS_TRANSFORM_FLOAT_INPLACE_H +#ifndef SD_TRANSFORM_FLOAT_INPLACE_H +#define SD_TRANSFORM_FLOAT_INPLACE_H #include #include @@ -96,4 +96,4 @@ namespace functions { } #undef LOCAL_TRANSFORM_STRICT_OPS -#endif //DEV_TESTS_TRANSFORM_FLOAT_INPLACE_H +#endif //SD_TRANSFORM_FLOAT_INPLACE_H diff --git a/libnd4j/include/math/platformmath.h b/libnd4j/include/math/platformmath.h index e4990cc87338..e68e75fc46b3 100644 --- a/libnd4j/include/math/platformmath.h +++ b/libnd4j/include/math/platformmath.h @@ -871,4 +871,4 @@ namespace sd { } } -#endif //DEV_TESTS_PLATFORM_MATH_H +#endif //SD_PLATFORM_MATH_H diff --git a/libnd4j/include/memory/AllocationEntry.h b/libnd4j/include/memory/AllocationEntry.h index 815a5c992f2b..b2d9839af4e8 100644 --- a/libnd4j/include/memory/AllocationEntry.h +++ b/libnd4j/include/memory/AllocationEntry.h @@ -18,8 +18,8 @@ // Created by raver119 on 07.05.19. // -#ifndef DEV_TESTS_ALLOCATIONENTRY_H -#define DEV_TESTS_ALLOCATIONENTRY_H +#ifndef SD_ALLOCATIONENTRY_H +#define SD_ALLOCATIONENTRY_H #include #include @@ -47,4 +47,4 @@ namespace sd { } -#endif //DEV_TESTS_ALLOCATIONENTRY_H +#endif //SD_ALLOCATIONENTRY_H diff --git a/libnd4j/include/memory/MemoryTracker.h b/libnd4j/include/memory/MemoryTracker.h index 38bb926ca0eb..60e347f03580 100644 --- a/libnd4j/include/memory/MemoryTracker.h +++ b/libnd4j/include/memory/MemoryTracker.h @@ -18,8 +18,8 @@ // Created by raver119 on 07.05.19. // -#ifndef DEV_TESTS_MEMORYTRACKER_H -#define DEV_TESTS_MEMORYTRACKER_H +#ifndef SD_MEMORYTRACKER_H +#define SD_MEMORYTRACKER_H #include #include @@ -55,4 +55,4 @@ namespace sd { } -#endif //DEV_TESTS_MEMORYTRACKER_H +#endif //SD_MEMORYTRACKER_H diff --git a/libnd4j/include/memory/MemoryType.h b/libnd4j/include/memory/MemoryType.h index 113d8d16d383..4f9ced600e26 100644 --- a/libnd4j/include/memory/MemoryType.h +++ b/libnd4j/include/memory/MemoryType.h @@ -2,8 +2,8 @@ // Created by raver119 on 07.05.19. // -#ifndef DEV_TESTS_MEMORYTYPE_H -#define DEV_TESTS_MEMORYTYPE_H +#ifndef SD_MEMORYTYPE_H +#define SD_MEMORYTYPE_H namespace sd { namespace memory { @@ -14,4 +14,4 @@ namespace sd { } } -#endif //DEV_TESTS_MEMORYTYPE_H +#endif //SD_MEMORYTYPE_H diff --git a/libnd4j/include/ops/BroadcastBoolOpsTuple.h b/libnd4j/include/ops/BroadcastBoolOpsTuple.h index 188186b4ceb4..69c633eeda87 100644 --- a/libnd4j/include/ops/BroadcastBoolOpsTuple.h +++ b/libnd4j/include/ops/BroadcastBoolOpsTuple.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_BROADCASTBOOLOPSTUPLE_H -#define DEV_TESTS_BROADCASTBOOLOPSTUPLE_H +#ifndef SD_BROADCASTBOOLOPSTUPLE_H +#define SD_BROADCASTBOOLOPSTUPLE_H #include #include @@ -47,4 +47,4 @@ namespace sd { } -#endif //DEV_TESTS_BROADCASTOPSTUPLE_H +#endif //SD_BROADCASTOPSTUPLE_H diff --git a/libnd4j/include/ops/BroadcastIntOpsTuple.h b/libnd4j/include/ops/BroadcastIntOpsTuple.h index 258719004aba..64e6c407079f 100644 --- a/libnd4j/include/ops/BroadcastIntOpsTuple.h +++ b/libnd4j/include/ops/BroadcastIntOpsTuple.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_BROADCASTINTOPSTUPLE_H -#define DEV_TESTS_BROADCASTINTOPSTUPLE_H +#ifndef SD_BROADCASTINTOPSTUPLE_H +#define SD_BROADCASTINTOPSTUPLE_H #include #include @@ -47,4 +47,4 @@ namespace sd { } -#endif //DEV_TESTS_BROADCASTOPSTUPLE_H +#endif //SD_BROADCASTOPSTUPLE_H diff --git a/libnd4j/include/ops/BroadcastOpsTuple.h b/libnd4j/include/ops/BroadcastOpsTuple.h index 34e2c603995d..c47f5462efcd 100644 --- a/libnd4j/include/ops/BroadcastOpsTuple.h +++ b/libnd4j/include/ops/BroadcastOpsTuple.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_BROADCASTOPSTUPLE_H -#define DEV_TESTS_BROADCASTOPSTUPLE_H +#ifndef SD_BROADCASTOPSTUPLE_H +#define SD_BROADCASTOPSTUPLE_H #include #include @@ -59,4 +59,4 @@ namespace sd { } -#endif //DEV_TESTS_BROADCASTOPSTUPLE_H +#endif //SD_BROADCASTOPSTUPLE_H diff --git a/libnd4j/include/ops/declarable/headers/nlp.h b/libnd4j/include/ops/declarable/headers/nlp.h index e12db1402576..1eb846365744 100644 --- a/libnd4j/include/ops/declarable/headers/nlp.h +++ b/libnd4j/include/ops/declarable/headers/nlp.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_NLP_H -#define DEV_TESTS_NLP_H +#ifndef SD_NLP_H +#define SD_NLP_H #include namespace sd { @@ -35,4 +35,4 @@ namespace sd { } } -#endif //DEV_TESTS_NLP_H +#endif //SD_NLP_H diff --git a/libnd4j/include/ops/declarable/helpers/flatten.h b/libnd4j/include/ops/declarable/helpers/flatten.h index da6253dfa0d9..bddf5362011e 100644 --- a/libnd4j/include/ops/declarable/helpers/flatten.h +++ b/libnd4j/include/ops/declarable/helpers/flatten.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_FLATTEN_H -#define DEV_TESTS_FLATTEN_H +#ifndef SD_FLATTEN_H +#define SD_FLATTEN_H #include #include @@ -65,4 +65,4 @@ INLINEDEF _CUDA_HD Nd4jLong getIndexOffsetOrdered(Nd4jLong index, const Nd4jLong } } -#endif //DEV_TESTS_FLATTEN_H +#endif //SD_FLATTEN_H diff --git a/libnd4j/include/ops/declarable/helpers/hamming.h b/libnd4j/include/ops/declarable/helpers/hamming.h index 6450d788240e..2b6883a4f680 100644 --- a/libnd4j/include/ops/declarable/helpers/hamming.h +++ b/libnd4j/include/ops/declarable/helpers/hamming.h @@ -29,4 +29,4 @@ namespace sd { } } -#endif //DEV_TESTS_HAMMING_H +#endif //SD_HAMMING_H diff --git a/libnd4j/include/ops/declarable/helpers/hashcode.h b/libnd4j/include/ops/declarable/helpers/hashcode.h index 730249d1aa66..6e76fc97609e 100644 --- a/libnd4j/include/ops/declarable/helpers/hashcode.h +++ b/libnd4j/include/ops/declarable/helpers/hashcode.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_HASHCODE_H -#define DEV_TESTS_HASHCODE_H +#ifndef SD_HASHCODE_H +#define SD_HASHCODE_H #include "helpers.h" @@ -67,4 +67,4 @@ namespace sd { } } -#endif //DEV_TESTS_HASHCODE_H +#endif //SD_HASHCODE_H diff --git a/libnd4j/include/ops/declarable/helpers/histogram.h b/libnd4j/include/ops/declarable/helpers/histogram.h index b9738ef07a49..2963d5f0e37b 100644 --- a/libnd4j/include/ops/declarable/helpers/histogram.h +++ b/libnd4j/include/ops/declarable/helpers/histogram.h @@ -31,4 +31,4 @@ namespace sd { } } -#endif //DEV_TESTS_HISTOGRAM_H +#endif //SD_HISTOGRAM_H diff --git a/libnd4j/include/ops/declarable/helpers/one_hot.h b/libnd4j/include/ops/declarable/helpers/one_hot.h index 1fefd7c446ed..2c435a75e948 100644 --- a/libnd4j/include/ops/declarable/helpers/one_hot.h +++ b/libnd4j/include/ops/declarable/helpers/one_hot.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_ONE_HOT_H -#define DEV_TESTS_ONE_HOT_H +#ifndef SD_ONE_HOT_H +#define SD_ONE_HOT_H #include #include @@ -34,4 +34,4 @@ namespace helpers { } } -#endif //DEV_TESTS_ONE_HOT_H +#endif //SD_ONE_HOT_H diff --git a/libnd4j/include/ops/declarable/helpers/scatter.h b/libnd4j/include/ops/declarable/helpers/scatter.h index 6e456ff9728f..0460a702d3cc 100644 --- a/libnd4j/include/ops/declarable/helpers/scatter.h +++ b/libnd4j/include/ops/declarable/helpers/scatter.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_SCATTER_H -#define DEV_TESTS_SCATTER_H +#ifndef SD_SCATTER_H +#define SD_SCATTER_H #include @@ -37,4 +37,4 @@ namespace sd { } } -#endif //DEV_TESTS_SCATTER_H +#endif //SD_SCATTER_H diff --git a/libnd4j/include/ops/declarable/helpers/sg_cb.h b/libnd4j/include/ops/declarable/helpers/sg_cb.h index 6b0824a81a4b..abf073786b6d 100644 --- a/libnd4j/include/ops/declarable/helpers/sg_cb.h +++ b/libnd4j/include/ops/declarable/helpers/sg_cb.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_SG_CB_H -#define DEV_TESTS_SG_CB_H +#ifndef SD_SG_CB_H +#define SD_SG_CB_H #include #include @@ -37,4 +37,4 @@ namespace sd { } } -#endif //DEV_TESTS_SG_CB_H +#endif //SD_SG_CB_H diff --git a/libnd4j/include/ops/declarable/helpers/shift.h b/libnd4j/include/ops/declarable/helpers/shift.h index f1b21741ca57..da816a902640 100644 --- a/libnd4j/include/ops/declarable/helpers/shift.h +++ b/libnd4j/include/ops/declarable/helpers/shift.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_SHIFT_H -#define DEV_TESTS_SHIFT_H +#ifndef SD_SHIFT_H +#define SD_SHIFT_H #include #include @@ -39,4 +39,4 @@ namespace sd { } } -#endif //DEV_TESTS_SHIFT_H +#endif //SD_SHIFT_H diff --git a/libnd4j/include/ops/declarable/helpers/toggle_bits.h b/libnd4j/include/ops/declarable/helpers/toggle_bits.h index 6d8ffe44af3d..5c30765dd22c 100644 --- a/libnd4j/include/ops/declarable/helpers/toggle_bits.h +++ b/libnd4j/include/ops/declarable/helpers/toggle_bits.h @@ -20,8 +20,8 @@ #include -#ifndef DEV_TESTS_TOGGLE_BITS_H -#define DEV_TESTS_TOGGLE_BITS_H +#ifndef SD_TOGGLE_BITS_H +#define SD_TOGGLE_BITS_H namespace sd { namespace ops { @@ -34,4 +34,4 @@ namespace sd { } } -#endif //DEV_TESTS_TOGGLE_BITS_H +#endif //SD_TOGGLE_BITS_H diff --git a/libnd4j/include/ops/declarable/helpers/where.h b/libnd4j/include/ops/declarable/helpers/where.h index 2c958846246d..284a365f8730 100644 --- a/libnd4j/include/ops/declarable/helpers/where.h +++ b/libnd4j/include/ops/declarable/helpers/where.h @@ -18,8 +18,8 @@ // Created by raver119 on 24/09/18. // -#ifndef DEV_TESTS_WHERE_H -#define DEV_TESTS_WHERE_H +#ifndef SD_WHERE_H +#define SD_WHERE_H #include @@ -31,4 +31,4 @@ namespace sd { } } -#endif //DEV_TESTS_WHERE_H +#endif //SD_WHERE_H diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index dd512a884cdf..5f18ad7431e3 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -19,8 +19,8 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#ifndef DEV_TESTS_MKLDNNUTILS_H -#define DEV_TESTS_MKLDNNUTILS_H +#ifndef SD_MKLDNNUTILS_H +#define SD_MKLDNNUTILS_H #include @@ -194,4 +194,4 @@ namespace sd { -#endif //DEV_TESTS_MKLDNNUTILS_H +#endif //SD_MKLDNNUTILS_H diff --git a/libnd4j/include/performance/benchmarking/BenchmarkSuit.h b/libnd4j/include/performance/benchmarking/BenchmarkSuit.h index 7805e570e47a..91d966339eb6 100644 --- a/libnd4j/include/performance/benchmarking/BenchmarkSuit.h +++ b/libnd4j/include/performance/benchmarking/BenchmarkSuit.h @@ -38,4 +38,4 @@ namespace sd { } -#endif //DEV_TESTS_BENCHMARKSUIT_H +#endif //SD_BENCHMARKSUIT_H diff --git a/libnd4j/include/performance/benchmarking/FullBenchmarkSuit.h b/libnd4j/include/performance/benchmarking/FullBenchmarkSuit.h index 6b2314b96138..d5a653649f99 100644 --- a/libnd4j/include/performance/benchmarking/FullBenchmarkSuit.h +++ b/libnd4j/include/performance/benchmarking/FullBenchmarkSuit.h @@ -31,4 +31,4 @@ namespace sd { } -#endif //DEV_TESTS_FULLBENCHMARKSUIT_H +#endif //SD_FULLBENCHMARKSUIT_H diff --git a/libnd4j/include/performance/benchmarking/LightBenchmarkSuit.h b/libnd4j/include/performance/benchmarking/LightBenchmarkSuit.h index 65a74b1fe7e0..1822a6f98588 100644 --- a/libnd4j/include/performance/benchmarking/LightBenchmarkSuit.h +++ b/libnd4j/include/performance/benchmarking/LightBenchmarkSuit.h @@ -31,4 +31,4 @@ namespace sd { } -#endif //DEV_TESTS_LIGHTBENCHMARKSUIT_H +#endif //SD_LIGHTBENCHMARKSUIT_H diff --git a/libnd4j/include/system/BlasVersionHelper.h b/libnd4j/include/system/BlasVersionHelper.h index 7cc97a26cc1e..0d894ed8df70 100644 --- a/libnd4j/include/system/BlasVersionHelper.h +++ b/libnd4j/include/system/BlasVersionHelper.h @@ -37,4 +37,4 @@ namespace sd { }; } -#endif //DEV_TESTS_BLASVERSIONHELPER_H +#endif //SD_BLASVERSIONHELPER_H diff --git a/libnd4j/include/system/msvc.h b/libnd4j/include/system/msvc.h index c884736f3ec7..4708d97c6874 100644 --- a/libnd4j/include/system/msvc.h +++ b/libnd4j/include/system/msvc.h @@ -36,4 +36,4 @@ #endif -#endif //DEV_TESTS_MSVC_H +#endif //SD_MSVC_H diff --git a/libnd4j/include/system/openmp_pragmas.h b/libnd4j/include/system/openmp_pragmas.h index 667f54521d31..0259ed75474f 100644 --- a/libnd4j/include/system/openmp_pragmas.h +++ b/libnd4j/include/system/openmp_pragmas.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_OPENMP_PRAGMAS_H -#define DEV_TESTS_OPENMP_PRAGMAS_H +#ifndef SD_OPENMP_PRAGMAS_H +#define SD_OPENMP_PRAGMAS_H #if defined(_MSC_VER) @@ -135,4 +135,4 @@ #define PRAGMA_THREADS_FOR_2D [&](uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y) -> void #define PRAGMA_THREADS_FOR_3D [&](uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z) -> void -#endif //DEV_TESTS_OPENMP_PRAGMAS_H +#endif //SD_OPENMP_PRAGMAS_H diff --git a/libnd4j/include/types/utf8string.h b/libnd4j/include/types/utf8string.h index ed25c6e10735..a0df70558b93 100644 --- a/libnd4j/include/types/utf8string.h +++ b/libnd4j/include/types/utf8string.h @@ -18,8 +18,8 @@ // @author raver119@gmail.com // -#ifndef DEV_TESTS_UTF8STRING_H -#define DEV_TESTS_UTF8STRING_H +#ifndef SD_UTF8STRING_H +#define SD_UTF8STRING_H #include #include @@ -46,4 +46,4 @@ namespace sd { } -#endif //DEV_TESTS_UTF8STRING_H +#endif //SD_UTF8STRING_H From b2c2fa5f969fcc406e57b4ee6df719572935db10 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 27 Mar 2020 08:19:47 +0300 Subject: [PATCH 016/233] few build fixes Signed-off-by: raver119 --- libnd4j/include/helpers/files.h | 3 +-- libnd4j/include/helpers/impl/FileUtils.cpp | 1 - libnd4j/include/legacy/cuda/NativeOps.cu | 2 +- libnd4j/minifier/minifier.cpp | 2 +- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/libnd4j/include/helpers/files.h b/libnd4j/include/helpers/files.h index 9be8e4e5db62..c85b98be2870 100644 --- a/libnd4j/include/helpers/files.h +++ b/libnd4j/include/helpers/files.h @@ -23,7 +23,6 @@ #define LIBND4J_FILES_H #include #include -#include #include @@ -63,7 +62,7 @@ char ** shellpath(void) { int next = 0; while (p) { #ifdef _WIN32 - char *q = strchr(p, ';'); // windows uses ; as delimiter + char *q = strchr(const_cast(p), ';'); // windows uses ; as delimiter #else char *q = strchr(const_cast(p), ':'); // linux and derivatives use : as delimiter #endif diff --git a/libnd4j/include/helpers/impl/FileUtils.cpp b/libnd4j/include/helpers/impl/FileUtils.cpp index b3da6e6d4e92..15c426a4b9da 100644 --- a/libnd4j/include/helpers/impl/FileUtils.cpp +++ b/libnd4j/include/helpers/impl/FileUtils.cpp @@ -22,7 +22,6 @@ #include #include #include -#include namespace sd { bool FileUtils::fileExists(const char *filename) { diff --git a/libnd4j/include/legacy/cuda/NativeOps.cu b/libnd4j/include/legacy/cuda/NativeOps.cu index 1a4de3de5063..bd6c23e8c05d 100755 --- a/libnd4j/include/legacy/cuda/NativeOps.cu +++ b/libnd4j/include/legacy/cuda/NativeOps.cu @@ -2903,7 +2903,7 @@ int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opConte int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer) { try { - auto graph = sd::graph::GraphExecutioner::importFromFlatPointer(flatBufferPointer); + auto graph = sd::graph::Graph::fromFlatPointer(flatBufferPointer); sd::graph::GraphHolder::getInstance()->registerGraph(graphId, graph); diff --git a/libnd4j/minifier/minifier.cpp b/libnd4j/minifier/minifier.cpp index 7846c1846956..ef1cf39d0613 100644 --- a/libnd4j/minifier/minifier.cpp +++ b/libnd4j/minifier/minifier.cpp @@ -110,7 +110,7 @@ main(int argc, char *argv[]) { #endif if (st.st_size != 0) { //std::cout << "File " << file << " exists and can be read" << std::endl; - auto graph = GraphExecutioner::importFromFlatBuffers(file.c_str()); + auto graph = Graph::fromFlatBuffers(file.c_str()); auto ops = graph->getOperations(); for (auto &v:ops) { From 9334cbb76d0d3f60b86780f9e8465aab2f585993 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 27 Mar 2020 09:06:19 +0300 Subject: [PATCH 017/233] few other abstractions Signed-off-by: raver119 --- libnd4j/include/memory/HotZoneManager.h | 36 +++++++++++++++++++ libnd4j/include/memory/WarmZoneManager.h | 36 +++++++++++++++++++ libnd4j/include/memory/ZoneManager.h | 3 +- .../include/memory/impl/HotZoneManager.cpp | 21 +++++++++++ .../include/memory/impl/WarmZoneManager.cpp | 21 +++++++++++ .../layers_tests/ManagedDataBufferTests.cpp | 35 ++++++++++++++++++ 6 files changed, 151 insertions(+), 1 deletion(-) create mode 100644 libnd4j/include/memory/HotZoneManager.h create mode 100644 libnd4j/include/memory/WarmZoneManager.h create mode 100644 libnd4j/include/memory/impl/HotZoneManager.cpp create mode 100644 libnd4j/include/memory/impl/WarmZoneManager.cpp create mode 100644 libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp diff --git a/libnd4j/include/memory/HotZoneManager.h b/libnd4j/include/memory/HotZoneManager.h new file mode 100644 index 000000000000..becb44e8cb1b --- /dev/null +++ b/libnd4j/include/memory/HotZoneManager.h @@ -0,0 +1,36 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_HOTZONEMANAGER_H +#define SD_HOTZONEMANAGER_H + +#include + +namespace sd { + class ND4J_EXPORT HotZoneManager : public ZoneManager { + protected: + public: + HotZoneManager() = default; + ~HotZoneManager() = default; + }; +} + + +#endif //SD_HOTZONEMANAGER_H diff --git a/libnd4j/include/memory/WarmZoneManager.h b/libnd4j/include/memory/WarmZoneManager.h new file mode 100644 index 000000000000..d57b7ca5b0de --- /dev/null +++ b/libnd4j/include/memory/WarmZoneManager.h @@ -0,0 +1,36 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_WARMZONEMANAGER_H +#define SD_WARMZONEMANAGER_H + +#include + +namespace sd { + class ND4J_EXPORT WarmZoneManager : public ZoneManager { + protected: + public: + WarmZoneManager() = default; + ~WarmZoneManager() = default; + }; +} + + +#endif //SD_WARMZONEMANAGER_H diff --git a/libnd4j/include/memory/ZoneManager.h b/libnd4j/include/memory/ZoneManager.h index 02ed896f5fc2..82d609caeb3d 100644 --- a/libnd4j/include/memory/ZoneManager.h +++ b/libnd4j/include/memory/ZoneManager.h @@ -21,6 +21,7 @@ #ifndef SD_ZONEMANAGER_H #define SD_ZONEMANAGER_H +#include #include #include @@ -28,7 +29,7 @@ namespace sd { /** * Abstract class that defines common methods for zone managers */ - class ZoneManager { + class ND4J_EXPORT ZoneManager { public: ZoneManager() = default; ~ZoneManager() = default; diff --git a/libnd4j/include/memory/impl/HotZoneManager.cpp b/libnd4j/include/memory/impl/HotZoneManager.cpp new file mode 100644 index 000000000000..1b7b390b3363 --- /dev/null +++ b/libnd4j/include/memory/impl/HotZoneManager.cpp @@ -0,0 +1,21 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include diff --git a/libnd4j/include/memory/impl/WarmZoneManager.cpp b/libnd4j/include/memory/impl/WarmZoneManager.cpp new file mode 100644 index 000000000000..d6969fae099f --- /dev/null +++ b/libnd4j/include/memory/impl/WarmZoneManager.cpp @@ -0,0 +1,21 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include diff --git a/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp b/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp new file mode 100644 index 000000000000..c0d18f11baff --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp @@ -0,0 +1,35 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class ManagedDataBufferTests : public testing::Test { +public: + ManagedDataBufferTests() { + // + } +}; \ No newline at end of file From 540bbaf99b24517296178a86a8eb38c37ade4789 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 27 Mar 2020 10:22:56 +0300 Subject: [PATCH 018/233] first basic test Signed-off-by: raver119 --- .../tests_cpu/layers_tests/ManagedDataBufferTests.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp b/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp index c0d18f11baff..6f901030dde2 100644 --- a/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp @@ -23,6 +23,7 @@ #include #include #include +#include using namespace sd; using namespace sd::graph; @@ -32,4 +33,10 @@ class ManagedDataBufferTests : public testing::Test { ManagedDataBufferTests() { // } -}; \ No newline at end of file +}; + +TEST_F(ManagedDataBufferTests, basic_constructor_test_1) { + auto mdb = std::make_shared(); + + NDArray array(mdb, 'c', {0}); +} \ No newline at end of file From 2a9b3c7805f9cf6e27eb1df0e15b18c14e45819a Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 27 Mar 2020 12:38:03 +0300 Subject: [PATCH 019/233] two more abstractions Signed-off-by: raver119 --- libnd4j/include/memory/ColdZoneManager.h | 27 +++++---- libnd4j/include/memory/GraphMemoryManager.h | 59 +++++++++++++++++++ libnd4j/include/memory/HotZoneManager.h | 14 +++-- libnd4j/include/memory/MemoryDescriptor.h | 53 +++++++++++++++++ libnd4j/include/memory/MemoryZone.h | 12 ++-- libnd4j/include/memory/WarmZoneManager.h | 14 +++-- libnd4j/include/memory/ZoneManager.h | 51 ++++++++-------- .../include/memory/cpu/GraphMemoryManager.cpp | 21 +++++++ .../include/memory/impl/ColdZoneManager.cpp | 4 +- .../include/memory/impl/MemoryDescriptor.cpp | 59 +++++++++++++++++++ 10 files changed, 260 insertions(+), 54 deletions(-) create mode 100644 libnd4j/include/memory/GraphMemoryManager.h create mode 100644 libnd4j/include/memory/MemoryDescriptor.h create mode 100644 libnd4j/include/memory/cpu/GraphMemoryManager.cpp create mode 100644 libnd4j/include/memory/impl/MemoryDescriptor.cpp diff --git a/libnd4j/include/memory/ColdZoneManager.h b/libnd4j/include/memory/ColdZoneManager.h index ea840f85e8ee..06d03d10411a 100644 --- a/libnd4j/include/memory/ColdZoneManager.h +++ b/libnd4j/include/memory/ColdZoneManager.h @@ -25,18 +25,21 @@ #include namespace sd { - class ColdZoneManager : public ZoneManager { - public: - /** - * This constructor is used to initialize ZoneManager with existing FlatBuffers file - * @param filename - full path to existing file (i.e. FlatBuffers file) - */ - explicit ColdZoneManager(const char* filename); - ColdZoneManager() = default; - ~ColdZoneManager() = default; - - - }; + namespace memory { + class ColdZoneManager : public ZoneManager { + public: + /** + * This constructor is used to initialize ZoneManager with existing FlatBuffers file + * @param filename - full path to existing file (i.e. FlatBuffers file) + */ + explicit ColdZoneManager(const char *filename); + + ColdZoneManager() = default; + ~ColdZoneManager() = default; + + + }; + } } diff --git a/libnd4j/include/memory/GraphMemoryManager.h b/libnd4j/include/memory/GraphMemoryManager.h new file mode 100644 index 000000000000..ffa48543c083 --- /dev/null +++ b/libnd4j/include/memory/GraphMemoryManager.h @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_GRAPHMEMORYMANAGER_H +#define SD_GRAPHMEMORYMANAGER_H + +#include +#include +#include +#include + +using namespace sd::memory; + +namespace sd { + namespace graph { + class GraphMemoryManager { + protected: + std::map _zones; + + public: + GraphMemoryManager() = default; + ~GraphMemoryManager() = default; + + /** + * This method does allocation (probably) and returns structure that describes it + * @param numBytes - number of bytes to be allocated + * @param zone - memory zone for allocation + * @return + */ + MemoryDescriptor allocate(size_t numBytes, MemoryZone zone); + + /** + * This method releases (probably) memory chunk described by given descriptor + * @param descriptor + */ + void release(MemoryDescriptor &descriptor); + }; + } +} + + +#endif //SD_GRAPHMEMORYMANAGER_H diff --git a/libnd4j/include/memory/HotZoneManager.h b/libnd4j/include/memory/HotZoneManager.h index becb44e8cb1b..03ad41ebd709 100644 --- a/libnd4j/include/memory/HotZoneManager.h +++ b/libnd4j/include/memory/HotZoneManager.h @@ -24,12 +24,14 @@ #include namespace sd { - class ND4J_EXPORT HotZoneManager : public ZoneManager { - protected: - public: - HotZoneManager() = default; - ~HotZoneManager() = default; - }; + namespace memory { + class ND4J_EXPORT HotZoneManager : public ZoneManager { + protected: + public: + HotZoneManager() = default; + ~HotZoneManager() = default; + }; + } } diff --git a/libnd4j/include/memory/MemoryDescriptor.h b/libnd4j/include/memory/MemoryDescriptor.h new file mode 100644 index 000000000000..92858ea33239 --- /dev/null +++ b/libnd4j/include/memory/MemoryDescriptor.h @@ -0,0 +1,53 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_MEMORYDESCRIPTOR_H +#define SD_MEMORYDESCRIPTOR_H + +#include +#include +#include + +namespace sd { + namespace memory { + class ND4J_EXPORT MemoryDescriptor { + private: + void* _ptr; + MemoryZone _zone; + uint64_t _bytes; + public: + MemoryDescriptor(void *ptr, MemoryZone zone, uint64_t bytes); + ~MemoryDescriptor() = default; + + MemoryDescriptor(const MemoryDescriptor& other) noexcept; + + MemoryDescriptor& operator=(const MemoryDescriptor& other) noexcept; + + // move constructor + MemoryDescriptor(MemoryDescriptor&& other) noexcept; + + // move assignment operator + MemoryDescriptor& operator=(MemoryDescriptor&& other) noexcept; + }; + } +} + + +#endif //SD_MEMORYDESCRIPTOR_H diff --git a/libnd4j/include/memory/MemoryZone.h b/libnd4j/include/memory/MemoryZone.h index e31e87cf9a63..16921e9b5d79 100644 --- a/libnd4j/include/memory/MemoryZone.h +++ b/libnd4j/include/memory/MemoryZone.h @@ -22,11 +22,13 @@ #define SD_MEMORYZONE_H namespace sd { - enum MemoryZone { - COLD = 0, - WARM = 10, - HOT = 20, - }; + namespace memory { + enum MemoryZone { + COLD = 0, + WARM = 10, + HOT = 20, + }; + } } #endif //SD_MEMORYZONE_H diff --git a/libnd4j/include/memory/WarmZoneManager.h b/libnd4j/include/memory/WarmZoneManager.h index d57b7ca5b0de..00c604ec1b77 100644 --- a/libnd4j/include/memory/WarmZoneManager.h +++ b/libnd4j/include/memory/WarmZoneManager.h @@ -24,12 +24,14 @@ #include namespace sd { - class ND4J_EXPORT WarmZoneManager : public ZoneManager { - protected: - public: - WarmZoneManager() = default; - ~WarmZoneManager() = default; - }; + namespace memory { + class ND4J_EXPORT WarmZoneManager : public ZoneManager { + protected: + public: + WarmZoneManager() = default; + ~WarmZoneManager() = default; + }; + } } diff --git a/libnd4j/include/memory/ZoneManager.h b/libnd4j/include/memory/ZoneManager.h index 82d609caeb3d..6d94ffc6446b 100644 --- a/libnd4j/include/memory/ZoneManager.h +++ b/libnd4j/include/memory/ZoneManager.h @@ -26,32 +26,35 @@ #include namespace sd { - /** - * Abstract class that defines common methods for zone managers - */ - class ND4J_EXPORT ZoneManager { - public: - ZoneManager() = default; - ~ZoneManager() = default; - - /** - * This method returns id of the current zone served by this manager instance - * @return MemoryZone enum - */ - virtual MemoryZone zone() const = 0; - - /** - * This method returns amount of memory available in this zone - * @return number of bytes - */ - virtual uint64_t available() const = 0; - + namespace memory { /** - * This method returns amount of memory currently used in this zone - * @return number of bytes + * Abstract class that defines common methods for zone managers */ - virtual uint64_t used() const = 0; - }; + class ND4J_EXPORT ZoneManager { + public: + ZoneManager() = default; + + ~ZoneManager() = default; + + /** + * This method returns id of the current zone served by this manager instance + * @return MemoryZone enum + */ + virtual MemoryZone zone() const = 0; + + /** + * This method returns amount of memory available in this zone + * @return number of bytes + */ + virtual uint64_t available() const = 0; + + /** + * This method returns amount of memory currently used in this zone + * @return number of bytes + */ + virtual uint64_t used() const = 0; + }; + } } diff --git a/libnd4j/include/memory/cpu/GraphMemoryManager.cpp b/libnd4j/include/memory/cpu/GraphMemoryManager.cpp new file mode 100644 index 000000000000..54cef3f7048a --- /dev/null +++ b/libnd4j/include/memory/cpu/GraphMemoryManager.cpp @@ -0,0 +1,21 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include diff --git a/libnd4j/include/memory/impl/ColdZoneManager.cpp b/libnd4j/include/memory/impl/ColdZoneManager.cpp index 0f0a64f077cc..f495c3849246 100644 --- a/libnd4j/include/memory/impl/ColdZoneManager.cpp +++ b/libnd4j/include/memory/impl/ColdZoneManager.cpp @@ -21,7 +21,9 @@ #include namespace sd { - ColdZoneManager::ColdZoneManager(const char* filename) { + namespace memory { + ColdZoneManager::ColdZoneManager(const char *filename) { + } } } diff --git a/libnd4j/include/memory/impl/MemoryDescriptor.cpp b/libnd4j/include/memory/impl/MemoryDescriptor.cpp new file mode 100644 index 000000000000..4512639e3558 --- /dev/null +++ b/libnd4j/include/memory/impl/MemoryDescriptor.cpp @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { + namespace memory { + MemoryDescriptor::MemoryDescriptor(void *ptr, MemoryZone zone, uint64_t bytes) : _ptr(ptr), _zone(zone), _bytes(bytes) { + // + } + + MemoryDescriptor::MemoryDescriptor(const MemoryDescriptor &other) noexcept : _ptr(other._ptr), _zone(other._zone), _bytes(other._bytes) { + // + } + + MemoryDescriptor &MemoryDescriptor::operator=(const MemoryDescriptor &other) noexcept { + if (this == &other) + return *this; + + _ptr = other._ptr; + _zone = other._zone; + _bytes = other._bytes; + + return *this; + } + + MemoryDescriptor::MemoryDescriptor(MemoryDescriptor &&other) noexcept : _ptr(other._ptr), _zone(other._zone), _bytes(other._bytes) { + // + } + + MemoryDescriptor &MemoryDescriptor::operator=(MemoryDescriptor &&other) noexcept { + if (this == &other) + return *this; + + _ptr = other._ptr; + _zone = other._zone; + _bytes = other._bytes; + + return *this; + } + } +} From 8131f188c4688ef3399aebe5d1860a31bad573aa Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 27 Mar 2020 13:10:41 +0300 Subject: [PATCH 020/233] few missing methods Signed-off-by: raver119 --- libnd4j/include/memory/MemoryDescriptor.h | 4 ++++ libnd4j/include/memory/cpu/GraphMemoryManager.cpp | 12 ++++++++++++ libnd4j/include/memory/impl/MemoryDescriptor.cpp | 12 ++++++++++++ 3 files changed, 28 insertions(+) diff --git a/libnd4j/include/memory/MemoryDescriptor.h b/libnd4j/include/memory/MemoryDescriptor.h index 92858ea33239..799df092ec4c 100644 --- a/libnd4j/include/memory/MemoryDescriptor.h +++ b/libnd4j/include/memory/MemoryDescriptor.h @@ -45,6 +45,10 @@ namespace sd { // move assignment operator MemoryDescriptor& operator=(MemoryDescriptor&& other) noexcept; + + void* address() const; + MemoryZone zone() const; + uint64_t bytes() const; }; } } diff --git a/libnd4j/include/memory/cpu/GraphMemoryManager.cpp b/libnd4j/include/memory/cpu/GraphMemoryManager.cpp index 54cef3f7048a..45fa4f20c7e8 100644 --- a/libnd4j/include/memory/cpu/GraphMemoryManager.cpp +++ b/libnd4j/include/memory/cpu/GraphMemoryManager.cpp @@ -19,3 +19,15 @@ // #include + +namespace sd { + namespace graph { + MemoryDescriptor GraphMemoryManager::allocate(size_t numBytes, MemoryZone zone) { + return MemoryDescriptor(nullptr, COLD, 0); + } + + void GraphMemoryManager::release(MemoryDescriptor &descriptor) { + + } + } +} diff --git a/libnd4j/include/memory/impl/MemoryDescriptor.cpp b/libnd4j/include/memory/impl/MemoryDescriptor.cpp index 4512639e3558..7d3dfeb753f2 100644 --- a/libnd4j/include/memory/impl/MemoryDescriptor.cpp +++ b/libnd4j/include/memory/impl/MemoryDescriptor.cpp @@ -55,5 +55,17 @@ namespace sd { return *this; } + + void *MemoryDescriptor::address() const { + return _ptr; + } + + MemoryZone MemoryDescriptor::zone() const { + return _zone; + } + + uint64_t MemoryDescriptor::bytes() const { + return _bytes; + } } } From 9d52e98f95b93de2fdbeb4dc2222132c1291edde Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 27 Mar 2020 14:40:04 +0300 Subject: [PATCH 021/233] meh Signed-off-by: raver119 --- libnd4j/include/array/DataBuffer.h | 11 +++++-- libnd4j/include/array/ManagedDataBuffer.h | 14 ++++++++- libnd4j/include/array/cpu/DataBuffer.cpp | 4 +++ .../include/array/cpu/ManagedDataBuffer.cpp | 31 +++++++++++++++++++ libnd4j/include/array/cuda/DataBuffer.cu | 4 +++ .../include/array/impl/ManagedDataBuffer.cpp | 9 ++++++ 6 files changed, 70 insertions(+), 3 deletions(-) create mode 100644 libnd4j/include/array/cpu/ManagedDataBuffer.cpp diff --git a/libnd4j/include/array/DataBuffer.h b/libnd4j/include/array/DataBuffer.h index e9f725a70123..4b0c928440e9 100644 --- a/libnd4j/include/array/DataBuffer.h +++ b/libnd4j/include/array/DataBuffer.h @@ -94,8 +94,9 @@ class ND4J_EXPORT DataBuffer { void setDataType(DataType dataType); size_t getLenInBytes() const; - void* primary(); - void* special(); + virtual void* primary(); + virtual void* special(); + virtual void* platform(); virtual void allocatePrimary(); virtual void allocateSpecial(); @@ -116,6 +117,7 @@ class ND4J_EXPORT DataBuffer { template FORCEINLINE T* primaryAsT(); template FORCEINLINE T* specialAsT(); + template FORCEINLINE T* platformAsT(); void syncToPrimary(const LaunchContext* context, const bool forceSync = false); void syncToSpecial(const bool forceSync = false); @@ -148,6 +150,11 @@ class ND4J_EXPORT DataBuffer { return reinterpret_cast(special()); } +//////////////////////////////////////////////////////////////////////// + template + T* DataBuffer::platformAsT() { + return reinterpret_cast(platform()); + } } diff --git a/libnd4j/include/array/ManagedDataBuffer.h b/libnd4j/include/array/ManagedDataBuffer.h index 5bdcbedc6c69..8406e4f42867 100644 --- a/libnd4j/include/array/ManagedDataBuffer.h +++ b/libnd4j/include/array/ManagedDataBuffer.h @@ -22,16 +22,28 @@ #define SD_MANAGEDDATABUFFER_H #include +#include namespace sd { /** * This class provides special DataBuffer implementation for use within Graphs */ class ND4J_EXPORT ManagedDataBuffer : public DataBuffer { + private: + graph::GraphMemoryManager &_manager; + protected: + uint64_t _bytes; + DataType _dtype; + memory::MemoryZone _zone; + MemoryDescriptor _descriptor; + public: - ManagedDataBuffer() = default; + ManagedDataBuffer(graph::GraphMemoryManager &manager, uint64_t numberOfBytes, DataType dtype, memory::MemoryZone zone); ~ManagedDataBuffer(); + + void* primary() override; + void* special() override; }; } diff --git a/libnd4j/include/array/cpu/DataBuffer.cpp b/libnd4j/include/array/cpu/DataBuffer.cpp index 2575e2ba41bc..677518746e0e 100644 --- a/libnd4j/include/array/cpu/DataBuffer.cpp +++ b/libnd4j/include/array/cpu/DataBuffer.cpp @@ -127,6 +127,10 @@ void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { dst.readPrimary(); } + void *DataBuffer::platform() { + return primary(); + } + //////////////////////////////////////////////////////////////////////// void DataBuffer::writePrimary() const { } diff --git a/libnd4j/include/array/cpu/ManagedDataBuffer.cpp b/libnd4j/include/array/cpu/ManagedDataBuffer.cpp new file mode 100644 index 000000000000..bacb6d7886e1 --- /dev/null +++ b/libnd4j/include/array/cpu/ManagedDataBuffer.cpp @@ -0,0 +1,31 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { + void *ManagedDataBuffer::primary() { + return _descriptor.address(); + } + + void *ManagedDataBuffer::special() { + return nullptr; + } +} \ No newline at end of file diff --git a/libnd4j/include/array/cuda/DataBuffer.cu b/libnd4j/include/array/cuda/DataBuffer.cu index 922b6967bfe5..9660d38669e7 100644 --- a/libnd4j/include/array/cuda/DataBuffer.cu +++ b/libnd4j/include/array/cuda/DataBuffer.cu @@ -281,6 +281,10 @@ void DataBuffer::migrate() { _specialBuffer = newBuffer; } + void *DataBuffer::platform() { + return special(); + } + //////////////////////////////////////////////////////////////////////// void DataBuffer::writePrimary() const {_writePrimary = ++_counter; } void DataBuffer::writeSpecial() const { _writeSpecial = ++_counter; } diff --git a/libnd4j/include/array/impl/ManagedDataBuffer.cpp b/libnd4j/include/array/impl/ManagedDataBuffer.cpp index 0103d506743c..2744fcfe5657 100644 --- a/libnd4j/include/array/impl/ManagedDataBuffer.cpp +++ b/libnd4j/include/array/impl/ManagedDataBuffer.cpp @@ -21,6 +21,15 @@ #include namespace sd { + ManagedDataBuffer::ManagedDataBuffer(graph::GraphMemoryManager &manager, uint64_t numberOfBytes, DataType dtype, memory::MemoryZone zone) : + _manager(manager), + _bytes(numberOfBytes), + _dtype(dtype), + _zone(zone), + _descriptor(manager.allocate(numberOfBytes, zone)) { + // everything already initialized + } + ManagedDataBuffer::~ManagedDataBuffer() { // } From d6f4c9bf95d4c4b3b28244c971db4eda52bef697 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 27 Mar 2020 14:40:35 +0300 Subject: [PATCH 022/233] meh Signed-off-by: raver119 --- .../layers_tests/ManagedDataBufferTests.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp b/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp index 6f901030dde2..5f832652dda7 100644 --- a/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp @@ -36,7 +36,17 @@ class ManagedDataBufferTests : public testing::Test { }; TEST_F(ManagedDataBufferTests, basic_constructor_test_1) { - auto mdb = std::make_shared(); + GraphMemoryManager mgr; + auto mdb = std::make_shared(mgr, 0, DataType::FLOAT32, memory::MemoryZone::HOT); NDArray array(mdb, 'c', {0}); +} + +TEST_F(ManagedDataBufferTests, basic_constructor_test_2) { + GraphMemoryManager mgr; + auto mdb = std::make_shared(mgr, 20, DataType::FLOAT32, memory::MemoryZone::HOT); + + ASSERT_NE(nullptr, mdb->platform()); + + NDArray array(mdb, 'c', {5}); } \ No newline at end of file From 5c8c1ae98c2f0ef2597679843f88d1f5deb67d9b Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 27 Mar 2020 18:40:25 +0300 Subject: [PATCH 023/233] ManagedDataBuffer stuff Signed-off-by: raver119 --- .../include/array/impl/ManagedDataBuffer.cpp | 5 +- libnd4j/include/memory/ColdZoneManager.h | 8 +++ libnd4j/include/memory/GraphMemoryManager.h | 2 +- ...HotZoneManager.cpp => HotRamZoneManager.h} | 20 ++++++++ libnd4j/include/memory/HotZoneManager.h | 14 ++++++ libnd4j/include/memory/ZoneManager.h | 20 ++++++++ .../include/memory/cpu/ColdZoneManager.cpp | 49 +++++++++++++++++++ .../include/memory/cpu/GraphMemoryManager.cpp | 15 +++++- .../HotZoneManager.cpp} | 15 ++++-- .../include/memory/impl/HotRamZoneManager.cpp | 38 ++++++++++++++ 10 files changed, 179 insertions(+), 7 deletions(-) rename libnd4j/include/memory/{impl/HotZoneManager.cpp => HotRamZoneManager.h} (64%) create mode 100644 libnd4j/include/memory/cpu/ColdZoneManager.cpp rename libnd4j/include/memory/{impl/ColdZoneManager.cpp => cpu/HotZoneManager.cpp} (74%) create mode 100644 libnd4j/include/memory/impl/HotRamZoneManager.cpp diff --git a/libnd4j/include/array/impl/ManagedDataBuffer.cpp b/libnd4j/include/array/impl/ManagedDataBuffer.cpp index 2744fcfe5657..a2418281def3 100644 --- a/libnd4j/include/array/impl/ManagedDataBuffer.cpp +++ b/libnd4j/include/array/impl/ManagedDataBuffer.cpp @@ -27,10 +27,11 @@ namespace sd { _dtype(dtype), _zone(zone), _descriptor(manager.allocate(numberOfBytes, zone)) { - // everything already initialized + // everything already initialized by now } ManagedDataBuffer::~ManagedDataBuffer() { - // + // if we know that MDB can be released - it means that all NDArrays were released, so it's really safe to release + _manager.release(_descriptor); } } \ No newline at end of file diff --git a/libnd4j/include/memory/ColdZoneManager.h b/libnd4j/include/memory/ColdZoneManager.h index 06d03d10411a..448b5e97e61f 100644 --- a/libnd4j/include/memory/ColdZoneManager.h +++ b/libnd4j/include/memory/ColdZoneManager.h @@ -37,7 +37,15 @@ namespace sd { ColdZoneManager() = default; ~ColdZoneManager() = default; + MemoryZone zone() const override; + uint64_t available() const override; + + uint64_t used() const override; + + MemoryDescriptor allocate(uint64_t numBytes) override; + + void release(MemoryDescriptor &descriptor) override; }; } } diff --git a/libnd4j/include/memory/GraphMemoryManager.h b/libnd4j/include/memory/GraphMemoryManager.h index ffa48543c083..87fe68468919 100644 --- a/libnd4j/include/memory/GraphMemoryManager.h +++ b/libnd4j/include/memory/GraphMemoryManager.h @@ -35,7 +35,7 @@ namespace sd { std::map _zones; public: - GraphMemoryManager() = default; + GraphMemoryManager(); ~GraphMemoryManager() = default; /** diff --git a/libnd4j/include/memory/impl/HotZoneManager.cpp b/libnd4j/include/memory/HotRamZoneManager.h similarity index 64% rename from libnd4j/include/memory/impl/HotZoneManager.cpp rename to libnd4j/include/memory/HotRamZoneManager.h index 1b7b390b3363..39e4362ab53e 100644 --- a/libnd4j/include/memory/impl/HotZoneManager.cpp +++ b/libnd4j/include/memory/HotRamZoneManager.h @@ -18,4 +18,24 @@ // @author raver119@gmail.com // +#ifndef SD_HOTRAMZONEMANAGER_H +#define SD_HOTRAMZONEMANAGER_H + #include + +namespace sd { + namespace memory { + class HotRamZoneManager : public HotZoneManager { + public: + HotRamZoneManager() = default; + ~HotRamZoneManager() = default; + + MemoryDescriptor allocate(uint64_t numBytes) override; + + void release(MemoryDescriptor &descriptor) override; + }; + } +} + + +#endif //SD_HOTRAMZONEMANAGER_H diff --git a/libnd4j/include/memory/HotZoneManager.h b/libnd4j/include/memory/HotZoneManager.h index 03ad41ebd709..e33fb236905b 100644 --- a/libnd4j/include/memory/HotZoneManager.h +++ b/libnd4j/include/memory/HotZoneManager.h @@ -22,14 +22,28 @@ #define SD_HOTZONEMANAGER_H #include +#include namespace sd { namespace memory { class ND4J_EXPORT HotZoneManager : public ZoneManager { protected: + std::atomic _used = {0}; + std::atomic _available = {0}; + public: HotZoneManager() = default; ~HotZoneManager() = default; + + MemoryZone zone() const override; + + uint64_t available() const override; + + uint64_t used() const override; + + virtual MemoryDescriptor allocate(uint64_t numBytes) = 0; + + virtual void release(MemoryDescriptor &descriptor) = 0; }; } } diff --git a/libnd4j/include/memory/ZoneManager.h b/libnd4j/include/memory/ZoneManager.h index 6d94ffc6446b..a67158bb7dc9 100644 --- a/libnd4j/include/memory/ZoneManager.h +++ b/libnd4j/include/memory/ZoneManager.h @@ -23,7 +23,9 @@ #include #include +#include #include +#include namespace sd { namespace memory { @@ -31,6 +33,9 @@ namespace sd { * Abstract class that defines common methods for zone managers */ class ND4J_EXPORT ZoneManager { + protected: + std::mutex _lock; + public: ZoneManager() = default; @@ -53,6 +58,21 @@ namespace sd { * @return number of bytes */ virtual uint64_t used() const = 0; + + /** + * This method allocates (probably) some memory chunk, and returns you pointer to it. + * @param numBytes + * @return + */ + virtual MemoryDescriptor allocate(uint64_t numBytes) = 0; + + /** + * This method releases (probably) memory described by given MemoryDescriptor + * @param descriptor + */ + virtual void release(MemoryDescriptor &descriptor) = 0; + + }; } } diff --git a/libnd4j/include/memory/cpu/ColdZoneManager.cpp b/libnd4j/include/memory/cpu/ColdZoneManager.cpp new file mode 100644 index 000000000000..a9599c229d0c --- /dev/null +++ b/libnd4j/include/memory/cpu/ColdZoneManager.cpp @@ -0,0 +1,49 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { + namespace memory { + ColdZoneManager::ColdZoneManager(const char *filename) { + // + } + + MemoryZone ColdZoneManager::zone() const { + return COLD; + } + + uint64_t ColdZoneManager::available() const { + return 0; + } + + uint64_t ColdZoneManager::used() const { + return 0; + } + + MemoryDescriptor ColdZoneManager::allocate(uint64_t numBytes) { + return MemoryDescriptor(nullptr, COLD, numBytes); + } + + void ColdZoneManager::release(MemoryDescriptor &descriptor) { + // + } + } +} diff --git a/libnd4j/include/memory/cpu/GraphMemoryManager.cpp b/libnd4j/include/memory/cpu/GraphMemoryManager.cpp index 45fa4f20c7e8..82f8d50d7f81 100644 --- a/libnd4j/include/memory/cpu/GraphMemoryManager.cpp +++ b/libnd4j/include/memory/cpu/GraphMemoryManager.cpp @@ -19,11 +19,24 @@ // #include +#include +#include namespace sd { namespace graph { + GraphMemoryManager::GraphMemoryManager() { + // first of all we initialize all memory managers + // CPU backend only has two: HOT and COLD + + _zones[MemoryZone::HOT] = new memory::HotRamZoneManager(); + _zones[MemoryZone::COLD] = new memory::ColdZoneManager(); + } + MemoryDescriptor GraphMemoryManager::allocate(size_t numBytes, MemoryZone zone) { - return MemoryDescriptor(nullptr, COLD, 0); + if (zone == MemoryZone::WARM) + zone = MemoryZone::HOT; + + return _zones[zone]->allocate(numBytes); } void GraphMemoryManager::release(MemoryDescriptor &descriptor) { diff --git a/libnd4j/include/memory/impl/ColdZoneManager.cpp b/libnd4j/include/memory/cpu/HotZoneManager.cpp similarity index 74% rename from libnd4j/include/memory/impl/ColdZoneManager.cpp rename to libnd4j/include/memory/cpu/HotZoneManager.cpp index f495c3849246..229029ad153e 100644 --- a/libnd4j/include/memory/impl/ColdZoneManager.cpp +++ b/libnd4j/include/memory/cpu/HotZoneManager.cpp @@ -18,12 +18,21 @@ // @author raver119@gmail.com // -#include +#include + namespace sd { namespace memory { - ColdZoneManager::ColdZoneManager(const char *filename) { + MemoryZone HotZoneManager::zone() const { + return HOT; + } + + uint64_t HotZoneManager::available() const { + return _available; + } + uint64_t HotZoneManager::used() const { + return _used; } } -} +} \ No newline at end of file diff --git a/libnd4j/include/memory/impl/HotRamZoneManager.cpp b/libnd4j/include/memory/impl/HotRamZoneManager.cpp new file mode 100644 index 000000000000..651eb501a30a --- /dev/null +++ b/libnd4j/include/memory/impl/HotRamZoneManager.cpp @@ -0,0 +1,38 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { + namespace memory { + MemoryDescriptor HotRamZoneManager::allocate(uint64_t numBytes) { + _used += numBytes; + + auto ptr = new int8_t[numBytes]; + return MemoryDescriptor(ptr, zone(), numBytes); + } + + void HotRamZoneManager::release(MemoryDescriptor &descriptor) { + _used -= descriptor.bytes(); + + delete[](reinterpret_cast(descriptor.address())); + } + } +} \ No newline at end of file From b84a51fa1f5224cb3208a0fcaed10fff65b7cd0b Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 27 Mar 2020 18:50:21 +0300 Subject: [PATCH 024/233] additional assert Signed-off-by: raver119 --- libnd4j/include/array/ManagedDataBuffer.h | 2 -- libnd4j/include/array/impl/ManagedDataBuffer.cpp | 6 +++--- libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp | 5 +++++ 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/libnd4j/include/array/ManagedDataBuffer.h b/libnd4j/include/array/ManagedDataBuffer.h index 8406e4f42867..f4ab9e46882c 100644 --- a/libnd4j/include/array/ManagedDataBuffer.h +++ b/libnd4j/include/array/ManagedDataBuffer.h @@ -33,8 +33,6 @@ namespace sd { graph::GraphMemoryManager &_manager; protected: - uint64_t _bytes; - DataType _dtype; memory::MemoryZone _zone; MemoryDescriptor _descriptor; diff --git a/libnd4j/include/array/impl/ManagedDataBuffer.cpp b/libnd4j/include/array/impl/ManagedDataBuffer.cpp index a2418281def3..8f2191fa85b2 100644 --- a/libnd4j/include/array/impl/ManagedDataBuffer.cpp +++ b/libnd4j/include/array/impl/ManagedDataBuffer.cpp @@ -23,11 +23,11 @@ namespace sd { ManagedDataBuffer::ManagedDataBuffer(graph::GraphMemoryManager &manager, uint64_t numberOfBytes, DataType dtype, memory::MemoryZone zone) : _manager(manager), - _bytes(numberOfBytes), - _dtype(dtype), _zone(zone), _descriptor(manager.allocate(numberOfBytes, zone)) { - // everything already initialized by now + + _lenInBytes = numberOfBytes; + _dataType = dtype; } ManagedDataBuffer::~ManagedDataBuffer() { diff --git a/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp b/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp index 5f832652dda7..c6eec989573e 100644 --- a/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp @@ -43,10 +43,15 @@ TEST_F(ManagedDataBufferTests, basic_constructor_test_1) { } TEST_F(ManagedDataBufferTests, basic_constructor_test_2) { + auto exp = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + GraphMemoryManager mgr; auto mdb = std::make_shared(mgr, 20, DataType::FLOAT32, memory::MemoryZone::HOT); ASSERT_NE(nullptr, mdb->platform()); NDArray array(mdb, 'c', {5}); + array.assign(1.0f); + + ASSERT_EQ(exp, array); } \ No newline at end of file From 6312a91526b4254f0de4ba1f46656b3028b9872e Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 27 Mar 2020 18:56:37 +0300 Subject: [PATCH 025/233] .gitignore Signed-off-by: raver119 --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index ad2e28e6fd79..eaa7643bd07e 100644 --- a/.gitignore +++ b/.gitignore @@ -70,3 +70,7 @@ venv2/ # Ignore the nd4j files that are created by javacpp at build to stop merge conflicts nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java + +# ignore some specific folders related to CLion +libnd4j/tests_cpu/libnd4j_tests/cmake* +libnd4j/tests_cpu/libnd4j_tests/cmake-build-debug-kraken/ \ No newline at end of file From a28154c81ba10929afa06ddf5af99568d32821d9 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 27 Mar 2020 18:59:28 +0300 Subject: [PATCH 026/233] meh Signed-off-by: raver119 --- libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp b/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp index c6eec989573e..5202fc3a1bc1 100644 --- a/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp @@ -31,7 +31,7 @@ using namespace sd::graph; class ManagedDataBufferTests : public testing::Test { public: ManagedDataBufferTests() { - // + /// } }; From 7fe07d1b3e6997c32cf6197eb45de0b4d0c6bec7 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 27 Mar 2020 19:03:20 +0300 Subject: [PATCH 027/233] virtual destructor Signed-off-by: raver119 --- libnd4j/include/memory/GraphMemoryManager.h | 6 +++--- libnd4j/include/memory/ZoneManager.h | 2 +- libnd4j/include/memory/cpu/GraphMemoryManager.cpp | 7 ++++++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/libnd4j/include/memory/GraphMemoryManager.h b/libnd4j/include/memory/GraphMemoryManager.h index 87fe68468919..edcdf9d3cfe0 100644 --- a/libnd4j/include/memory/GraphMemoryManager.h +++ b/libnd4j/include/memory/GraphMemoryManager.h @@ -36,7 +36,7 @@ namespace sd { public: GraphMemoryManager(); - ~GraphMemoryManager() = default; + ~GraphMemoryManager(); /** * This method does allocation (probably) and returns structure that describes it @@ -44,13 +44,13 @@ namespace sd { * @param zone - memory zone for allocation * @return */ - MemoryDescriptor allocate(size_t numBytes, MemoryZone zone); + virtual MemoryDescriptor allocate(size_t numBytes, MemoryZone zone); /** * This method releases (probably) memory chunk described by given descriptor * @param descriptor */ - void release(MemoryDescriptor &descriptor); + virtual void release(MemoryDescriptor &descriptor); }; } } diff --git a/libnd4j/include/memory/ZoneManager.h b/libnd4j/include/memory/ZoneManager.h index a67158bb7dc9..c30095597b0a 100644 --- a/libnd4j/include/memory/ZoneManager.h +++ b/libnd4j/include/memory/ZoneManager.h @@ -39,7 +39,7 @@ namespace sd { public: ZoneManager() = default; - ~ZoneManager() = default; + virtual ~ZoneManager() = default; /** * This method returns id of the current zone served by this manager instance diff --git a/libnd4j/include/memory/cpu/GraphMemoryManager.cpp b/libnd4j/include/memory/cpu/GraphMemoryManager.cpp index 82f8d50d7f81..47c1653eaed4 100644 --- a/libnd4j/include/memory/cpu/GraphMemoryManager.cpp +++ b/libnd4j/include/memory/cpu/GraphMemoryManager.cpp @@ -32,6 +32,11 @@ namespace sd { _zones[MemoryZone::COLD] = new memory::ColdZoneManager(); } + GraphMemoryManager::~GraphMemoryManager() { + delete _zones[MemoryZone::HOT]; + delete _zones[MemoryZone::COLD]; + } + MemoryDescriptor GraphMemoryManager::allocate(size_t numBytes, MemoryZone zone) { if (zone == MemoryZone::WARM) zone = MemoryZone::HOT; @@ -40,7 +45,7 @@ namespace sd { } void GraphMemoryManager::release(MemoryDescriptor &descriptor) { - + _zones[descriptor.zone()]->release(descriptor); } } } From b0ad7188c8786de562d43c194cf30b2aeb2089fe Mon Sep 17 00:00:00 2001 From: raver119 Date: Sat, 28 Mar 2020 11:26:41 +0300 Subject: [PATCH 028/233] first analyzer test Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 1 + libnd4j/include/graph/Node.h | 7 ++ libnd4j/include/graph/execution/OpSequence.h | 7 ++ .../graph/execution/impl/OpSequence.cpp | 4 ++ libnd4j/include/graph/impl/Graph.cpp | 12 +++- libnd4j/include/graph/impl/Node.cpp | 43 ++++++++++++ .../include/ops/declarable/OpRegistrator.h | 2 +- .../ops/declarable/impl/OpRegistrator.cpp | 2 +- .../layers_tests/GraphAnalysisTests.cpp | 66 +++++++++++++++++++ 9 files changed, 139 insertions(+), 5 deletions(-) create mode 100644 libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index e1cdd7887006..578ec4998dad 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -127,6 +127,7 @@ namespace sd { * @param node */ void addNode(sd::graph::Node *node); + void addNode(const sd::graph::Node &node); /** * This method returns layered representation of the graph diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 5fde65f3c16d..6657687cb0e1 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -86,6 +86,9 @@ namespace sd { // each node can be active or inactive, if used with divergents, like IF statements bool _active = true; + // meh + mutable bool _removable = true; + // these fields contain information about Scope these ops are related to int _scope_id = 0; std::string _scope_name; @@ -96,6 +99,7 @@ namespace sd { Nd4jLong _frameId = -1; public: + explicit Node(const std::string &opName, const int id = 0, const std::vector> &inputs = {}, const std::vector &tArgs = {}, const std::vector &iArgs = {}); explicit Node(sd::ops::DeclarableOp *customOp, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); explicit Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); explicit Node(const sd::graph::FlatNode *node); @@ -127,6 +131,9 @@ namespace sd { bool isMultiInput(); bool isMultiOutput(); + bool isRemovable() const; + void markRemovable(bool reallyRemovable) const; + int getLayer(); void setLayer(int layer); diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index ce8fced76bae..65d96d1ecff0 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -57,6 +57,13 @@ namespace sd { */ uint64_t length() const; + /** + * This method returns specific Op/ContextPrototype pair for specified index + * @param index + * @return + */ + std::pair at(uint64_t index); + /** * This method allows to add DeclarableOp to the end of execution queue * @param op - Op to be executed diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 2354deaec7c5..0bc38ef8568a 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -57,6 +57,10 @@ namespace sd { return *this; } + std::pair OpSequence::at(uint64_t index) { + return _ops[index]; + } + uint64_t OpSequence::length() const { return _ops.size(); } diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 2543828048bb..5633c9204ddc 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -288,15 +288,16 @@ namespace sd { Graph::~Graph() { for (auto &v: *_mapped) - delete v.second; + if (v.second->isRemovable()) + delete v.second; for (auto &v: _unmapped) - delete v.second; + if (v.second->isRemovable()) + delete v.second; for (auto &v: *_onion) delete v.second; - for (auto v: _scopes) delete v; @@ -307,6 +308,11 @@ namespace sd { delete _configuration; } + void Graph::addNode(const sd::graph::Node &node) { + node.markRemovable(true); + addNode(const_cast(&node)); + } + void Graph::addNode(Node *node) { _built.store(false); diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index e3ea75ef9073..b288096c28b3 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -73,6 +73,14 @@ namespace sd { } } + bool Node::isRemovable() const { + return _removable; + } + + void Node::markRemovable(bool reallyRemovable) const { + _removable = reallyRemovable; + } + OpClass sd::graph::Node::getOpClass() { return _opClass; } @@ -313,6 +321,40 @@ namespace sd { } BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT Node* Node::asT, (), LIBND4J_TYPES); + Node::Node(const std::string &opName, const int id, const std::vector> &inputs, const std::vector &tArgs, const std::vector &iArgs) { + + auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName); + + this->_opType = OpType_CUSTOM; + this->_id = id; + this->_opNum = customOp->getOpHash(); + this->_extraParams = nullptr; + this->_dataType = sd::DataType::FLOAT32; // float as default + this->_dim = nullptr; + this->_customOp = customOp; + + _hasExternalInputs = false; + _hasExternalOutputs = false; + _hasInternalInputs = false; + _hasInternalOutputs = false; + + // FIXME: get rid of this!!! + _scalar = NDArrayFactory::create(0); + + for (auto i: inputs) + pickInput(i); + + auto block = new ContextPrototype(this->getCustomOp()->getOpDescriptor(), this->id(), false); + + for (auto v: iArgs) + block->getIArguments()->emplace_back(v); + + for (auto v: tArgs) + block->getTArguments()->emplace_back(v); + + this->setContextPrototype(block); + } + sd::graph::Node::Node(sd::ops::DeclarableOp *customOp, int id, std::initializer_list input, std::initializer_list output, std::initializer_list dimensions, float scalar, std::initializer_list tArgs, std::initializer_list iArgs) { this->_opType = OpType_CUSTOM; this->_id = id; @@ -327,6 +369,7 @@ namespace sd { _hasInternalInputs = false; _hasInternalOutputs = false; + // FIXME: get rid of this!!! _scalar = NDArrayFactory::create(scalar); for (auto i: input) diff --git a/libnd4j/include/ops/declarable/OpRegistrator.h b/libnd4j/include/ops/declarable/OpRegistrator.h index 3a9fb3df6e44..d9528ab7a4e3 100644 --- a/libnd4j/include/ops/declarable/OpRegistrator.h +++ b/libnd4j/include/ops/declarable/OpRegistrator.h @@ -123,7 +123,7 @@ namespace sd { sd::ops::DeclarableOp* getOperation(const char *name); sd::ops::DeclarableOp* getOperation(Nd4jLong hash); - sd::ops::DeclarableOp* getOperation(std::string &name); + sd::ops::DeclarableOp* getOperation(const std::string &name); sd::ops::platforms::PlatformHelper* getPlatformHelper(Nd4jLong hash, samediff::Engine engine); diff --git a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp index 65d694dea6f6..177e3a87e8d8 100644 --- a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp +++ b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp @@ -221,7 +221,7 @@ namespace sd { return _declarablesLD.at(hash); } - sd::ops::DeclarableOp *OpRegistrator::getOperation(std::string& name) { + sd::ops::DeclarableOp *OpRegistrator::getOperation(const std::string& name) { if (!_declarablesD.count(name)) { nd4j_debug("Unknown operation requested: [%s]\n", name.c_str()); return nullptr; diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp new file mode 100644 index 000000000000..1cb2c0eba87b --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -0,0 +1,66 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class GraphAnalysisTests : public testing::Test { +public: + GraphAnalysisTests() { + /// + } +}; + +TEST_F(GraphAnalysisTests, basic_toposort_test_1) { + Graph graph; + + Node a("multiply", 10); + Node b("add", 20, {{10, 0}}); + + graph.addNode(b); + graph.addNode(a); + + // we just check that nodes were really added + ASSERT_EQ(2, graph.totalNodes()); + + auto optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(1, optimized.layers()); + + auto layer = optimized.layer(0); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer.size()); + auto sequence = layer[0]; + + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(2, sequence.length()); + + ASSERT_EQ(10, sequence.at(0).second->nodeId()); + ASSERT_EQ(20, sequence.at(1).second->nodeId()); +} \ No newline at end of file From 16e544ea4ff4116b18350cfcdd148d756afa14d9 Mon Sep 17 00:00:00 2001 From: raver119 Date: Sat, 28 Mar 2020 11:36:54 +0300 Subject: [PATCH 029/233] second analyzer test Signed-off-by: raver119 --- libnd4j/include/graph/VariableProxy.h | 2 +- libnd4j/include/graph/VariableSpace.h | 2 +- libnd4j/include/graph/impl/VariableProxy.cpp | 2 +- libnd4j/include/graph/impl/VariableSpace.cpp | 4 +- .../layers_tests/GraphAnalysisTests.cpp | 59 ++++++++++++++++++- 5 files changed, 62 insertions(+), 7 deletions(-) diff --git a/libnd4j/include/graph/VariableProxy.h b/libnd4j/include/graph/VariableProxy.h index 1569b477d23b..b917a9192bcf 100644 --- a/libnd4j/include/graph/VariableProxy.h +++ b/libnd4j/include/graph/VariableProxy.h @@ -58,7 +58,7 @@ namespace sd { virtual void putVariable(int id, Variable *variable); virtual void putVariable(int id, NDArray *array); virtual Variable* putVariable(int id, int idx, NDArray *array); - virtual void putVariable(int id, int idx, NDArray &array); + void putVariable(int id, int idx, const NDArray &array) override; virtual void putVariable(int id, int idx, Variable *array); virtual void replaceVariable(Variable *variable); diff --git a/libnd4j/include/graph/VariableSpace.h b/libnd4j/include/graph/VariableSpace.h index ea3c6370d16c..619725e7a1d1 100644 --- a/libnd4j/include/graph/VariableSpace.h +++ b/libnd4j/include/graph/VariableSpace.h @@ -100,7 +100,7 @@ namespace sd { virtual void putVariable(int id, Variable *variable); virtual void putVariable(int id, NDArray *array); virtual Variable* putVariable(int id, int idx, NDArray *array); - virtual void putVariable(int id, int idx, NDArray &array); + virtual void putVariable(int id, int idx, const NDArray &array); virtual void putVariable(int id, int idx, Variable *array); virtual void dropVariable(std::pair &pair); diff --git a/libnd4j/include/graph/impl/VariableProxy.cpp b/libnd4j/include/graph/impl/VariableProxy.cpp index 2736e2a9e32b..8ec199d8762c 100644 --- a/libnd4j/include/graph/impl/VariableProxy.cpp +++ b/libnd4j/include/graph/impl/VariableProxy.cpp @@ -191,7 +191,7 @@ namespace sd { _current->putVariable(id, array); } - void sd::graph::VariableProxy::putVariable(int id, int idx, NDArray &array) { + void sd::graph::VariableProxy::putVariable(int id, int idx, const NDArray &array) { _current->putVariable(id, idx, array); } diff --git a/libnd4j/include/graph/impl/VariableSpace.cpp b/libnd4j/include/graph/impl/VariableSpace.cpp index 0e8634d07303..a8fe297ee215 100644 --- a/libnd4j/include/graph/impl/VariableSpace.cpp +++ b/libnd4j/include/graph/impl/VariableSpace.cpp @@ -305,8 +305,8 @@ namespace sd { } } - void sd::graph::VariableSpace::putVariable(int id, int idx, NDArray &array) { - auto *var = new sd::graph::Variable(&array, "", id, idx); + void sd::graph::VariableSpace::putVariable(int id, int idx, const NDArray &array) { + auto *var = new sd::graph::Variable(const_cast(&array), "", id, idx); var->markRemovable(false); var->markReadOnly(true); diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 1cb2c0eba87b..78bced33a21b 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -38,8 +38,17 @@ class GraphAnalysisTests : public testing::Test { TEST_F(GraphAnalysisTests, basic_toposort_test_1) { Graph graph; - Node a("multiply", 10); - Node b("add", 20, {{10, 0}}); + // A + graph.getVariableSpace()->putVariable(-1, 0, NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // B + graph.getVariableSpace()->putVariable(-2, 0, NDArrayFactory::create('c', {3}, {2, 2, 2})); + + // C + graph.getVariableSpace()->putVariable(-3, 0, NDArrayFactory::create('c', {3}, {3, 3, 3})); + + Node a("multiply", 10, {{-1, 0}, {-2, 0}}); + Node b("add", 20, {{10, 0}, {-3, 0}}); graph.addNode(b); graph.addNode(a); @@ -63,4 +72,50 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_1) { ASSERT_EQ(10, sequence.at(0).second->nodeId()); ASSERT_EQ(20, sequence.at(1).second->nodeId()); +} + +TEST_F(GraphAnalysisTests, basic_toposort_test_2) { + Graph graph; + + // A + graph.getVariableSpace()->putVariable(-1, 0, NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // B + graph.getVariableSpace()->putVariable(-2, 0, NDArrayFactory::create('c', {3}, {2, 2, 2})); + + // C + graph.getVariableSpace()->putVariable(-3, 0, NDArrayFactory::create('c', {3}, {3, 3, 3})); + + // D + graph.getVariableSpace()->putVariable(-4, 0, NDArrayFactory::create('c', {3}, {4, 4, 4})); + + Node a("multiply", 10, {{-1, 0}, {-2, 0}}); + Node b("add", 20, {{10, 0}, {-3, 0}}); + Node c("subtract", 30, {{10, 0}, {-4, 0}}); + + + graph.addNode(b); + graph.addNode(c); + graph.addNode(a); + + // we just check that nodes were really added + ASSERT_EQ(3, graph.totalNodes()); + + auto optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(2, optimized.layers()); + + // checking first layer first + auto layer = optimized.layer(0); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer.size()); + auto sequence = layer[0]; + + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, sequence.length()); + + ASSERT_EQ(10, sequence.at(0).second->nodeId()); + } \ No newline at end of file From eade08e9d16d1e83b8c427cb34a31461a65a848d Mon Sep 17 00:00:00 2001 From: raver119 Date: Sat, 28 Mar 2020 11:48:05 +0300 Subject: [PATCH 030/233] minor fix Signed-off-by: raver119 --- libnd4j/include/graph/impl/Graph.cpp | 2 +- .../layers_tests/GraphAnalysisTests.cpp | 21 ++++++++++++++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 5633c9204ddc..5c0ead9feea0 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -309,7 +309,7 @@ namespace sd { } void Graph::addNode(const sd::graph::Node &node) { - node.markRemovable(true); + node.markRemovable(false); addNode(const_cast(&node)); } diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 78bced33a21b..e46e2b8d45fb 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -107,15 +107,30 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { ASSERT_EQ(2, optimized.layers()); // checking first layer first - auto layer = optimized.layer(0); + auto layer0 = optimized.layer(0); // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer.size()); - auto sequence = layer[0]; + ASSERT_EQ(1, layer0.size()); + auto sequence = layer0[0]; // we expect that OpSequence has exactly 2 ops ASSERT_EQ(1, sequence.length()); ASSERT_EQ(10, sequence.at(0).second->nodeId()); + // checking second layer now + auto layer1 = optimized.layer(0); + + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.size()); + + sequence = layer1[0]; + + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(20, sequence.at(0).second->nodeId()); + + sequence = layer1[1]; + + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(30, sequence.at(0).second->nodeId()); } \ No newline at end of file From c743a35d4df10306d8b3120e5894ee1e5399b317 Mon Sep 17 00:00:00 2001 From: raver119 Date: Sat, 28 Mar 2020 12:13:20 +0300 Subject: [PATCH 031/233] additional convenience operator for OpSequence Signed-off-by: raver119 --- libnd4j/include/graph/execution/OpSequence.h | 3 ++- libnd4j/include/graph/execution/impl/OpSequence.cpp | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index 65d96d1ecff0..56e4ffcd2779 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -62,7 +62,8 @@ namespace sd { * @param index * @return */ - std::pair at(uint64_t index); + std::pair at(uint64_t index) const; + std::pair operator[](uint64_t index) const; /** * This method allows to add DeclarableOp to the end of execution queue diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 0bc38ef8568a..0cc41d0f3e9b 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -57,10 +57,14 @@ namespace sd { return *this; } - std::pair OpSequence::at(uint64_t index) { + std::pair OpSequence::at(uint64_t index) const { return _ops[index]; } + std::pair OpSequence::operator[](uint64_t index) const { + return at(index); + } + uint64_t OpSequence::length() const { return _ops.size(); } From d920ca939547b6559cfa28d33ac7d7916cb1afce Mon Sep 17 00:00:00 2001 From: raver119 Date: Sat, 28 Mar 2020 15:05:36 +0300 Subject: [PATCH 032/233] ExecutionLayer abstraction instead of std::vector for OptimizedGraph Signed-off-by: raver119 --- libnd4j/include/array/ArrayOptions.h | 2 +- libnd4j/include/array/ByteOrderUtils.h | 2 +- libnd4j/include/array/ConstantDataBuffer.h | 2 +- libnd4j/include/array/ConstantDescriptor.h | 4 +- libnd4j/include/array/DataBuffer.h | 2 +- libnd4j/include/array/DataTypeConversions.h | 2 +- libnd4j/include/array/DataTypeUtils.h | 2 +- libnd4j/include/array/ExtraArguments.h | 2 +- libnd4j/include/array/InteropDataBuffer.h | 2 +- libnd4j/include/array/ManagedDataBuffer.h | 2 +- libnd4j/include/array/NDArray.h | 44 +- libnd4j/include/array/NDArray.hXX | 554 +++++++++--------- libnd4j/include/array/NDArrayFactory.h | 2 +- libnd4j/include/array/NDArrayList.h | 2 +- libnd4j/include/array/ResultSet.h | 2 +- libnd4j/include/array/ShapeDescriptor.h | 4 +- libnd4j/include/array/ShapeList.h | 2 +- libnd4j/include/array/TadDescriptor.h | 4 +- libnd4j/include/array/TadPack.h | 2 +- libnd4j/include/array/cuda/NDArray.cu | 2 +- .../include/array/impl/ConstantDataBuffer.cpp | 16 +- libnd4j/include/array/impl/ConstantHolder.cpp | 6 +- libnd4j/include/array/impl/ExtraArguments.cpp | 4 +- libnd4j/include/array/impl/NDArrayFactory.cpp | 322 +++++----- libnd4j/include/cnpy/cnpy.h | 46 +- .../include/exceptions/allocation_exception.h | 2 +- libnd4j/include/exceptions/cuda_exception.h | 2 +- .../include/exceptions/datatype_exception.h | 2 +- libnd4j/include/exceptions/graph_exception.h | 2 +- .../exceptions/graph_execution_exception.h | 2 +- .../exceptions/graph_exists_exception.h | 2 +- .../include/exceptions/no_results_exception.h | 2 +- .../exceptions/unknown_graph_exception.h | 2 +- libnd4j/include/execution/AffinityManager.h | 2 +- libnd4j/include/execution/ContextBuffers.h | 2 +- libnd4j/include/execution/ErrorReference.h | 2 +- libnd4j/include/execution/LaunchContext.h | 2 +- libnd4j/include/execution/ThreadPool.h | 2 +- libnd4j/include/execution/Threads.h | 10 +- libnd4j/include/execution/Ticket.h | 2 +- libnd4j/include/graph/ArgumentsList.h | 2 +- libnd4j/include/graph/Context.h | 2 +- libnd4j/include/graph/ContextPrototype.h | 2 +- libnd4j/include/graph/ExecutorConfiguration.h | 2 +- libnd4j/include/graph/FlatUtils.h | 2 +- libnd4j/include/graph/FlowPath.h | 2 +- libnd4j/include/graph/FrameState.h | 2 +- libnd4j/include/graph/Graph.h | 2 +- libnd4j/include/graph/GraphExecutioner.h | 2 +- libnd4j/include/graph/GraphHolder.h | 2 +- libnd4j/include/graph/GraphState.h | 2 +- libnd4j/include/graph/GraphUtils.h | 2 +- libnd4j/include/graph/InferenceRequest.h | 2 +- libnd4j/include/graph/Intervals.h | 2 +- libnd4j/include/graph/Node.h | 2 +- libnd4j/include/graph/NodeState.h | 2 +- libnd4j/include/graph/OptimizedGraph.h | 11 +- libnd4j/include/graph/RandomGenerator.h | 6 +- libnd4j/include/graph/ResultWrapper.h | 2 +- libnd4j/include/graph/Scope.h | 2 +- libnd4j/include/graph/SessionLocalStorage.h | 2 +- libnd4j/include/graph/Stash.h | 6 +- libnd4j/include/graph/Status.h | 2 +- libnd4j/include/graph/TimeHolder.h | 2 +- libnd4j/include/graph/Variable.h | 10 +- libnd4j/include/graph/VariableProxy.h | 2 +- libnd4j/include/graph/VariableSpace.h | 2 +- libnd4j/include/graph/VariablesSet.h | 2 +- .../include/graph/execution/ExecutionLayer.h | 70 +++ libnd4j/include/graph/execution/OpSequence.h | 2 +- .../graph/execution/impl/ExecutionLayer.cpp | 72 +++ libnd4j/include/graph/impl/Node.cpp | 2 +- libnd4j/include/graph/impl/OptimizedGraph.cpp | 10 +- libnd4j/include/graph/impl/Variable.cpp | 2 +- .../include/graph/profiling/GraphProfile.h | 2 +- libnd4j/include/graph/profiling/NodeProfile.h | 2 +- libnd4j/include/helpers/AttentionHelper.h | 2 +- libnd4j/include/helpers/BenchmarkHelper.h | 2 +- libnd4j/include/helpers/BitwiseUtils.h | 2 +- libnd4j/include/helpers/ConstantHelper.h | 2 +- libnd4j/include/helpers/ConstantShapeHelper.h | 2 +- libnd4j/include/helpers/ConstantTadHelper.h | 2 +- libnd4j/include/helpers/CudaLaunchHelper.h | 2 +- libnd4j/include/helpers/DebugHelper.h | 2 +- libnd4j/include/helpers/DebugInfo.h | 2 +- libnd4j/include/helpers/FileUtils.h | 2 +- libnd4j/include/helpers/GradCheck.h | 2 +- libnd4j/include/helpers/LoopKind.h | 2 +- libnd4j/include/helpers/Loops.h | 14 +- libnd4j/include/helpers/MmulHelper.h | 2 +- libnd4j/include/helpers/OmpLaunchHelper.h | 2 +- libnd4j/include/helpers/OpArgsHolder.h | 2 +- libnd4j/include/helpers/OpBenchmark.h | 2 +- libnd4j/include/helpers/OpTracker.h | 2 +- libnd4j/include/helpers/PointersManager.h | 2 +- libnd4j/include/helpers/RandomLauncher.h | 2 +- libnd4j/include/helpers/ShapeBuilders.h | 2 +- libnd4j/include/helpers/ShapeUtils.h | 2 +- libnd4j/include/helpers/SimpleReadWriteLock.h | 2 +- libnd4j/include/helpers/StringUtils.h | 2 +- .../helpers/benchmark/BroadcastBenchmark.h | 2 +- .../helpers/benchmark/DeclarableBenchmark.h | 2 +- .../helpers/benchmark/MatrixBenchmark.h | 2 +- .../helpers/benchmark/PairwiseBenchmark.h | 2 +- .../helpers/benchmark/ReductionBenchmark.h | 2 +- .../helpers/benchmark/ScalarBenchmark.h | 2 +- .../helpers/benchmark/TransformBenchmark.h | 2 +- libnd4j/include/helpers/cpu/householder.cpp | 8 +- libnd4j/include/helpers/cpu/jacobiSVD.cpp | 8 +- .../helpers/cpu/loops/Reduction3Loops_0.cpp | 2 +- .../helpers/cpu/loops/Reduction3Loops_1.cpp | 2 +- .../helpers/cpu/loops/Reduction3Loops_2.cpp | 2 +- .../helpers/cpu/loops/Reduction3Loops_3.cpp | 2 +- .../helpers/cpu/loops/ReductionLoops_bool.cpp | 2 +- .../cpu/loops/ReductionLoops_float_0.cpp | 2 +- .../cpu/loops/ReductionLoops_float_1.cpp | 2 +- .../cpu/loops/ReductionLoops_float_2.cpp | 2 +- .../cpu/loops/ReductionLoops_float_3.cpp | 2 +- .../helpers/cpu/loops/ReductionLoops_long.cpp | 2 +- libnd4j/include/helpers/cpu/svd.cpp | 2 +- libnd4j/include/helpers/cublasHelper.h | 2 +- libnd4j/include/helpers/helper_generator.h | 10 +- libnd4j/include/helpers/helper_hash.h | 2 +- libnd4j/include/helpers/logger.h | 2 +- libnd4j/include/helpers/shape.h | 370 ++++++------ libnd4j/include/indexing/IndicesList.h | 2 +- libnd4j/include/indexing/NDIndex.h | 8 +- libnd4j/include/legacy/NativeOpExecutioner.h | 2 +- libnd4j/include/legacy/NativeOps.h | 470 +++++++-------- libnd4j/include/legacy/impl/cnpy.cpp | 6 +- .../include/loops/cpu/broadcasting_bool.hpp | 2 +- .../include/loops/cpu/broadcasting_int.hpp | 2 +- .../compilation_units/broadcast_bool_p0.cpp | 2 +- .../compilation_units/broadcast_bool_p1.cpp | 2 +- .../compilation_units/broadcast_bool_p2.cpp | 2 +- .../compilation_units/broadcast_bool_p3.cpp | 2 +- .../compilation_units/broadcast_bool_p4.cpp | 2 +- .../compilation_units/broadcast_bool_p5.cpp | 2 +- .../compilation_units/broadcast_bool_p6.cpp | 2 +- .../compilation_units/broadcast_bool_p7.cpp | 2 +- .../compilation_units/broadcast_bool_p8.cpp | 2 +- .../compilation_units/broadcast_bool_p9.cpp | 2 +- .../compilation_units/broadcast_int_p0.cpp | 2 +- .../compilation_units/broadcast_int_p1.cpp | 2 +- .../compilation_units/broadcast_int_p2.cpp | 2 +- .../compilation_units/broadcast_int_p3.cpp | 2 +- .../compilation_units/broadcast_int_p4.cpp | 2 +- .../compilation_units/broadcast_int_p5.cpp | 2 +- .../compilation_units/broadcast_int_p6.cpp | 2 +- .../compilation_units/broadcast_int_p7.cpp | 2 +- .../cpu/compilation_units/broadcast_p0.cpp | 2 +- .../cpu/compilation_units/broadcast_p1.cpp | 2 +- .../cpu/compilation_units/broadcast_p10.cpp | 2 +- .../cpu/compilation_units/broadcast_p11.cpp | 2 +- .../cpu/compilation_units/broadcast_p12.cpp | 2 +- .../cpu/compilation_units/broadcast_p2.cpp | 2 +- .../cpu/compilation_units/broadcast_p3.cpp | 2 +- .../cpu/compilation_units/broadcast_p4.cpp | 2 +- .../cpu/compilation_units/broadcast_p5.cpp | 2 +- .../cpu/compilation_units/broadcast_p6.cpp | 2 +- .../cpu/compilation_units/broadcast_p7.cpp | 2 +- .../cpu/compilation_units/broadcast_p8.cpp | 2 +- .../cpu/compilation_units/broadcast_p9.cpp | 2 +- .../compilation_units/indexreduce_int32_0.cpp | 2 +- .../compilation_units/indexreduce_int32_1.cpp | 2 +- .../compilation_units/indexreduce_int32_2.cpp | 2 +- .../compilation_units/indexreduce_int32_3.cpp | 2 +- .../compilation_units/indexreduce_int32_4.cpp | 2 +- .../compilation_units/indexreduce_int32_5.cpp | 2 +- .../compilation_units/indexreduce_int32_6.cpp | 2 +- .../compilation_units/indexreduce_int32_7.cpp | 2 +- .../compilation_units/indexreduce_int32_8.cpp | 2 +- .../compilation_units/indexreduce_int32_9.cpp | 2 +- .../compilation_units/indexreduce_int64_0.cpp | 2 +- .../compilation_units/indexreduce_int64_1.cpp | 2 +- .../compilation_units/indexreduce_int64_2.cpp | 2 +- .../compilation_units/indexreduce_int64_3.cpp | 2 +- .../compilation_units/indexreduce_int64_4.cpp | 2 +- .../compilation_units/indexreduce_int64_5.cpp | 2 +- .../compilation_units/indexreduce_int64_6.cpp | 2 +- .../compilation_units/indexreduce_int64_7.cpp | 2 +- .../compilation_units/indexreduce_int64_8.cpp | 2 +- .../compilation_units/indexreduce_int64_9.cpp | 2 +- .../cpu/compilation_units/pairwise_p0.cpp | 2 +- .../cpu/compilation_units/pairwise_p1.cpp | 2 +- .../cpu/compilation_units/pairwise_p10.cpp | 2 +- .../cpu/compilation_units/pairwise_p11.cpp | 2 +- .../cpu/compilation_units/pairwise_p12.cpp | 2 +- .../cpu/compilation_units/pairwise_p2.cpp | 2 +- .../cpu/compilation_units/pairwise_p3.cpp | 2 +- .../cpu/compilation_units/pairwise_p4.cpp | 2 +- .../cpu/compilation_units/pairwise_p5.cpp | 2 +- .../cpu/compilation_units/pairwise_p6.cpp | 2 +- .../cpu/compilation_units/pairwise_p7.cpp | 2 +- .../cpu/compilation_units/pairwise_p8.cpp | 2 +- .../cpu/compilation_units/pairwise_p9.cpp | 2 +- .../loops/cpu/compilation_units/random_0.cpp | 2 +- .../loops/cpu/compilation_units/random_1.cpp | 2 +- .../loops/cpu/compilation_units/random_2.cpp | 2 +- .../loops/cpu/compilation_units/random_3.cpp | 2 +- .../compilation_units/reduce3_bfloat16_0.cpp | 2 +- .../compilation_units/reduce3_bfloat16_1.cpp | 2 +- .../compilation_units/reduce3_bfloat16_2.cpp | 2 +- .../compilation_units/reduce3_bfloat16_3.cpp | 2 +- .../compilation_units/reduce3_bfloat16_4.cpp | 2 +- .../compilation_units/reduce3_bfloat16_5.cpp | 2 +- .../compilation_units/reduce3_bfloat16_6.cpp | 2 +- .../compilation_units/reduce3_bfloat16_7.cpp | 2 +- .../compilation_units/reduce3_bfloat16_8.cpp | 2 +- .../compilation_units/reduce3_bfloat16_9.cpp | 2 +- .../compilation_units/reduce3_double_0.cpp | 2 +- .../compilation_units/reduce3_double_1.cpp | 2 +- .../compilation_units/reduce3_double_2.cpp | 2 +- .../compilation_units/reduce3_double_3.cpp | 2 +- .../compilation_units/reduce3_double_4.cpp | 2 +- .../compilation_units/reduce3_double_5.cpp | 2 +- .../compilation_units/reduce3_double_6.cpp | 2 +- .../compilation_units/reduce3_double_7.cpp | 2 +- .../compilation_units/reduce3_double_8.cpp | 2 +- .../compilation_units/reduce3_double_9.cpp | 2 +- .../compilation_units/reduce3_float16_0.cpp | 2 +- .../compilation_units/reduce3_float16_1.cpp | 2 +- .../compilation_units/reduce3_float16_2.cpp | 2 +- .../compilation_units/reduce3_float16_3.cpp | 2 +- .../compilation_units/reduce3_float16_4.cpp | 2 +- .../compilation_units/reduce3_float16_5.cpp | 2 +- .../compilation_units/reduce3_float16_6.cpp | 2 +- .../compilation_units/reduce3_float16_7.cpp | 2 +- .../compilation_units/reduce3_float16_8.cpp | 2 +- .../compilation_units/reduce3_float16_9.cpp | 2 +- .../cpu/compilation_units/reduce3_float_0.cpp | 2 +- .../cpu/compilation_units/reduce3_float_1.cpp | 2 +- .../cpu/compilation_units/reduce3_float_2.cpp | 2 +- .../cpu/compilation_units/reduce3_float_3.cpp | 2 +- .../cpu/compilation_units/reduce3_float_4.cpp | 2 +- .../cpu/compilation_units/reduce3_float_5.cpp | 2 +- .../cpu/compilation_units/reduce3_float_6.cpp | 2 +- .../cpu/compilation_units/reduce3_float_7.cpp | 2 +- .../cpu/compilation_units/reduce3_float_8.cpp | 2 +- .../cpu/compilation_units/reduce3_float_9.cpp | 2 +- .../cpu/compilation_units/reduce_float_0.cpp | 2 +- .../cpu/compilation_units/reduce_float_1.cpp | 2 +- .../cpu/compilation_units/reduce_float_2.cpp | 2 +- .../cpu/compilation_units/reduce_float_3.cpp | 2 +- .../loops/cpu/compilation_units/scalar_p0.cpp | 2 +- .../loops/cpu/compilation_units/scalar_p1.cpp | 2 +- .../cpu/compilation_units/scalar_p10.cpp | 2 +- .../cpu/compilation_units/scalar_p11.cpp | 2 +- .../cpu/compilation_units/scalar_p12.cpp | 2 +- .../loops/cpu/compilation_units/scalar_p2.cpp | 2 +- .../loops/cpu/compilation_units/scalar_p3.cpp | 2 +- .../loops/cpu/compilation_units/scalar_p4.cpp | 2 +- .../loops/cpu/compilation_units/scalar_p5.cpp | 2 +- .../loops/cpu/compilation_units/scalar_p6.cpp | 2 +- .../loops/cpu/compilation_units/scalar_p7.cpp | 2 +- .../loops/cpu/compilation_units/scalar_p8.cpp | 2 +- .../loops/cpu/compilation_units/scalar_p9.cpp | 2 +- libnd4j/include/loops/cpu/pairwise_bool.cpp | 2 +- libnd4j/include/loops/cpu/pairwise_int.cpp | 2 +- libnd4j/include/loops/cpu/random.hpp | 2 +- .../include/loops/cpu/reduce/reduce_bool.cpp | 2 +- .../include/loops/cpu/reduce/reduce_long.cpp | 2 +- .../include/loops/cpu/reduce/reduce_same.cpp | 2 +- libnd4j/include/loops/cpu/scalar_bool.cpp | 2 +- libnd4j/include/loops/cpu/scalar_int.cpp | 2 +- .../include/loops/cpu/summarystatsreduce.cpp | 2 +- .../loops/cpu/transform/transform_any.cpp | 2 +- .../loops/cpu/transform/transform_bool.cpp | 2 +- .../loops/cpu/transform/transform_float.cpp | 2 +- .../loops/cpu/transform/transform_same.cpp | 2 +- .../loops/cpu/transform/transform_strict.cpp | 2 +- libnd4j/include/loops/cuda/broadcasting.chpp | 20 +- .../include/loops/cuda/broadcasting_bool.cu | 2 +- .../include/loops/cuda/broadcasting_int.cu | 2 +- .../broadcasting/broadcasting_0.cu | 2 +- .../broadcasting/broadcasting_1.cu | 2 +- .../broadcasting/broadcasting_10.cu | 2 +- .../broadcasting/broadcasting_11.cu | 2 +- .../broadcasting/broadcasting_12.cu | 2 +- .../broadcasting/broadcasting_2.cu | 2 +- .../broadcasting/broadcasting_3.cu | 2 +- .../broadcasting/broadcasting_4.cu | 2 +- .../broadcasting/broadcasting_5.cu | 2 +- .../broadcasting/broadcasting_6.cu | 2 +- .../broadcasting/broadcasting_7.cu | 2 +- .../broadcasting/broadcasting_8.cu | 2 +- .../broadcasting/broadcasting_9.cu | 2 +- .../compilation_units/pairwise/pairwise_0.cu | 2 +- .../compilation_units/pairwise/pairwise_1.cu | 2 +- .../compilation_units/pairwise/pairwise_10.cu | 2 +- .../compilation_units/pairwise/pairwise_11.cu | 2 +- .../compilation_units/pairwise/pairwise_12.cu | 2 +- .../compilation_units/pairwise/pairwise_2.cu | 2 +- .../compilation_units/pairwise/pairwise_3.cu | 2 +- .../compilation_units/pairwise/pairwise_4.cu | 2 +- .../compilation_units/pairwise/pairwise_5.cu | 2 +- .../compilation_units/pairwise/pairwise_6.cu | 2 +- .../compilation_units/pairwise/pairwise_7.cu | 2 +- .../compilation_units/pairwise/pairwise_8.cu | 2 +- .../compilation_units/pairwise/pairwise_9.cu | 2 +- .../compilation_units/reduce3/reduce3_0.cu | 2 +- .../compilation_units/reduce3/reduce3_1.cu | 2 +- .../compilation_units/reduce3/reduce3_2.cu | 2 +- .../compilation_units/reduce3/reduce3_3.cu | 2 +- .../reduce_float/reduce_float_0.cu | 2 +- .../reduce_float/reduce_float_1.cu | 2 +- .../reduce_float/reduce_float_2.cu | 2 +- .../reduce_float/reduce_float_3.cu | 2 +- .../cuda/compilation_units/scalar/scalar_0.cu | 2 +- .../cuda/compilation_units/scalar/scalar_1.cu | 2 +- .../compilation_units/scalar/scalar_10.cu | 2 +- .../compilation_units/scalar/scalar_11.cu | 2 +- .../compilation_units/scalar/scalar_12.cu | 2 +- .../cuda/compilation_units/scalar/scalar_2.cu | 2 +- .../cuda/compilation_units/scalar/scalar_3.cu | 2 +- .../cuda/compilation_units/scalar/scalar_4.cu | 2 +- .../cuda/compilation_units/scalar/scalar_5.cu | 2 +- .../cuda/compilation_units/scalar/scalar_6.cu | 2 +- .../cuda/compilation_units/scalar/scalar_7.cu | 2 +- .../cuda/compilation_units/scalar/scalar_8.cu | 2 +- .../cuda/compilation_units/scalar/scalar_9.cu | 2 +- libnd4j/include/loops/cuda/indexreduce.cu | 2 +- .../loops/cuda/legacy/transform.legacy | 6 +- libnd4j/include/loops/cuda/pairwise.chpp | 20 +- libnd4j/include/loops/cuda/pairwise_bool.cu | 2 +- libnd4j/include/loops/cuda/pairwise_int.cu | 2 +- libnd4j/include/loops/cuda/random.cu | 2 +- .../include/loops/cuda/reduce/reduce_bool.cu | 2 +- .../loops/cuda/reduce/reduce_float.chpp | 2 +- .../include/loops/cuda/reduce/reduce_long.cu | 2 +- .../include/loops/cuda/reduce/reduce_same.cu | 2 +- libnd4j/include/loops/cuda/reduce3.chpp | 2 +- libnd4j/include/loops/cuda/scalar_bool.cu | 2 +- libnd4j/include/loops/cuda/scalar_int.cu | 2 +- .../loops/cuda/specials/accumulateKernel.cu | 2 +- .../loops/cuda/specials/averagingKernel.cu | 2 +- .../cuda/specials/bitonicArbitraryStep.cu | 4 +- .../loops/cuda/specials/bitonicSortStep.cu | 4 +- .../loops/cuda/specials/concatKernel.cu | 2 +- .../loops/cuda/specials/concatKernelHStack.cu | 2 +- .../loops/cuda/specials/concatKernelScalar.cu | 2 +- .../loops/cuda/specials/concatKernelVStack.cu | 2 +- .../loops/cuda/specials/convertHalfs.cu | 2 +- .../loops/cuda/specials/convertToHalf.cu | 2 +- .../cuda/specials/fillDimensionalIsMax.cu | 2 +- .../include/loops/cuda/specials/fillIsMax.cu | 2 +- .../include/loops/cuda/specials/flatten.cu | 2 +- libnd4j/include/loops/cuda/specials/oesTad.cu | 4 +- .../loops/cuda/specials/pullRowsKernel.cu | 2 +- .../loops/cuda/specials/shuffleKernel.cu | 2 +- .../include/loops/cuda/specials/tearKernel.cu | 2 +- .../include/loops/cuda/summarystatsreduce.cu | 2 +- .../loops/cuda/transform/transform_any.cu | 2 +- .../loops/cuda/transform/transform_bool.cu | 2 +- .../loops/cuda/transform/transform_float.cu | 2 +- .../loops/cuda/transform/transform_same.cu | 2 +- .../loops/cuda/transform/transform_strict.cu | 2 +- .../include/loops/cuda/type_conversions.cu | 10 +- libnd4j/include/loops/transform_any.h | 2 +- libnd4j/include/loops/transform_bool.h | 2 +- libnd4j/include/loops/transform_float.h | 2 +- libnd4j/include/loops/transform_same.h | 2 +- libnd4j/include/loops/transform_strict.h | 2 +- libnd4j/include/memory/ExternalWorkspace.h | 2 +- libnd4j/include/memory/HotZoneManager.h | 2 +- libnd4j/include/memory/MemoryCounter.h | 2 +- libnd4j/include/memory/MemoryDescriptor.h | 2 +- libnd4j/include/memory/MemoryRegistrator.h | 2 +- libnd4j/include/memory/MemoryReport.h | 2 +- libnd4j/include/memory/MemoryTracker.h | 2 +- libnd4j/include/memory/MemoryUtils.h | 2 +- libnd4j/include/memory/WarmZoneManager.h | 2 +- libnd4j/include/memory/Workspace.h | 2 +- libnd4j/include/memory/ZoneManager.h | 2 +- libnd4j/include/ops/BroadcastBoolOpsTuple.h | 2 +- libnd4j/include/ops/BroadcastIntOpsTuple.h | 2 +- libnd4j/include/ops/BroadcastOpsTuple.h | 2 +- libnd4j/include/ops/declarable/BooleanOp.h | 2 +- .../include/ops/declarable/BroadcastableOp.h | 2 +- .../include/ops/declarable/CustomOperations.h | 2 +- .../ops/declarable/DeclarableCustomOp.h | 2 +- .../include/ops/declarable/DeclarableListOp.h | 2 +- libnd4j/include/ops/declarable/DeclarableOp.h | 4 +- .../ops/declarable/DeclarableReductionOp.h | 2 +- .../ops/declarable/LegacyBroadcastBoolOp.h | 2 +- .../ops/declarable/LegacyBroadcastOp.h | 2 +- .../ops/declarable/LegacyIndexReduceOp.h | 2 +- libnd4j/include/ops/declarable/LegacyOp.h | 2 +- .../LegacyPairwiseTransformBoolOp.h | 2 +- .../declarable/LegacyPairwiseTransformOp.h | 2 +- .../include/ops/declarable/LegacyRandomOp.h | 2 +- .../include/ops/declarable/LegacyReduce3Op.h | 2 +- .../ops/declarable/LegacyReduceBoolOp.h | 2 +- .../ops/declarable/LegacyReduceFloatOp.h | 2 +- .../ops/declarable/LegacyReduceLongOp.h | 2 +- .../include/ops/declarable/LegacyReduceOp.h | 2 +- .../ops/declarable/LegacyReduceSameOp.h | 2 +- .../ops/declarable/LegacyScalarBoolOp.h | 2 +- .../include/ops/declarable/LegacyScalarOp.h | 2 +- .../include/ops/declarable/LegacyStatsOp.h | 2 +- .../ops/declarable/LegacyTransformAnyOp.h | 2 +- .../ops/declarable/LegacyTransformBoolOp.h | 2 +- .../ops/declarable/LegacyTransformFloatOp.h | 2 +- .../ops/declarable/LegacyTransformOp.h | 2 +- .../ops/declarable/LegacyTransformSameOp.h | 2 +- .../ops/declarable/LegacyTransformStrictOp.h | 2 +- libnd4j/include/ops/declarable/LogicOp.h | 2 +- libnd4j/include/ops/declarable/OpDescriptor.h | 2 +- .../include/ops/declarable/OpRegistrator.h | 2 +- libnd4j/include/ops/declarable/OpTuple.h | 2 +- .../include/ops/declarable/PlatformHelper.h | 2 +- .../ops/declarable/helpers/activations.h | 18 +- .../include/ops/declarable/helpers/col2im.h | 2 +- .../ops/declarable/helpers/convolutions.h | 2 +- .../ops/declarable/helpers/cpu/svd.cpp | 2 +- .../include/ops/declarable/helpers/im2col.h | 2 +- .../ops/declarable/helpers/lstmLayer.h | 4 +- .../ops/declarable/helpers/multiUnique.h | 2 +- libnd4j/include/ops/impl/specials_sparse.cpp | 2 +- libnd4j/include/ops/specials.h | 6 +- .../performance/benchmarking/BenchmarkSuit.h | 2 +- libnd4j/include/system/BlasVersionHelper.h | 2 +- libnd4j/include/system/Environment.h | 2 +- libnd4j/include/system/dll.h | 4 +- libnd4j/include/system/op_boilerplate.h | 30 +- libnd4j/include/system/platform_boilerplate.h | 4 +- libnd4j/include/types/pair.h | 2 +- libnd4j/include/types/triple.h | 2 +- libnd4j/include/types/utf8string.h | 2 +- .../layers_tests/GraphAnalysisTests.cpp | 6 +- .../layers_tests/OpSequenceTests.cpp | 2 +- 431 files changed, 1585 insertions(+), 1432 deletions(-) create mode 100644 libnd4j/include/graph/execution/ExecutionLayer.h create mode 100644 libnd4j/include/graph/execution/impl/ExecutionLayer.cpp diff --git a/libnd4j/include/array/ArrayOptions.h b/libnd4j/include/array/ArrayOptions.h index 1f0c2570503a..0e874155d484 100644 --- a/libnd4j/include/array/ArrayOptions.h +++ b/libnd4j/include/array/ArrayOptions.h @@ -88,7 +88,7 @@ namespace sd { - class ND4J_EXPORT ArrayOptions { + class SD_EXPORT ArrayOptions { private: static FORCEINLINE _CUDA_HD Nd4jLong& extra(Nd4jLong* shape); diff --git a/libnd4j/include/array/ByteOrderUtils.h b/libnd4j/include/array/ByteOrderUtils.h index 0f335ea65f93..b8d16d81119a 100644 --- a/libnd4j/include/array/ByteOrderUtils.h +++ b/libnd4j/include/array/ByteOrderUtils.h @@ -26,7 +26,7 @@ #include namespace sd { - class ND4J_EXPORT ByteOrderUtils { + class SD_EXPORT ByteOrderUtils { public: static ByteOrder fromFlatByteOrder(sd::graph::ByteOrder order); }; diff --git a/libnd4j/include/array/ConstantDataBuffer.h b/libnd4j/include/array/ConstantDataBuffer.h index c3c6923f35be..17e8a22c5de9 100644 --- a/libnd4j/include/array/ConstantDataBuffer.h +++ b/libnd4j/include/array/ConstantDataBuffer.h @@ -25,7 +25,7 @@ namespace sd { - class ND4J_EXPORT ConstantDataBuffer { + class SD_EXPORT ConstantDataBuffer { private: Nd4jPointer _primaryBuffer = nullptr; Nd4jPointer _specialBuffer = nullptr; diff --git a/libnd4j/include/array/ConstantDescriptor.h b/libnd4j/include/array/ConstantDescriptor.h index 09efbcad6743..664bb819abc9 100644 --- a/libnd4j/include/array/ConstantDescriptor.h +++ b/libnd4j/include/array/ConstantDescriptor.h @@ -29,7 +29,7 @@ #include namespace sd { - class ND4J_EXPORT ConstantDescriptor { + class SD_EXPORT ConstantDescriptor { private: std::vector _integerValues; std::vector _floatValues; @@ -63,7 +63,7 @@ namespace sd { namespace std { template<> - class ND4J_EXPORT hash { + class SD_EXPORT hash { public: size_t operator()(const sd::ConstantDescriptor &k) const; }; diff --git a/libnd4j/include/array/DataBuffer.h b/libnd4j/include/array/DataBuffer.h index 4b0c928440e9..1d0f93400915 100644 --- a/libnd4j/include/array/DataBuffer.h +++ b/libnd4j/include/array/DataBuffer.h @@ -32,7 +32,7 @@ namespace sd { -class ND4J_EXPORT DataBuffer { +class SD_EXPORT DataBuffer { protected: diff --git a/libnd4j/include/array/DataTypeConversions.h b/libnd4j/include/array/DataTypeConversions.h index 44f55553373b..a3190728749f 100644 --- a/libnd4j/include/array/DataTypeConversions.h +++ b/libnd4j/include/array/DataTypeConversions.h @@ -33,7 +33,7 @@ namespace sd { template - class ND4J_EXPORT DataTypeConversions { + class SD_EXPORT DataTypeConversions { private: template static FORCEINLINE void rconv(bool isBe, bool canKeep, T *buffer, Nd4jLong length, void *src) { diff --git a/libnd4j/include/array/DataTypeUtils.h b/libnd4j/include/array/DataTypeUtils.h index bd89605d153b..a0c40a89f312 100644 --- a/libnd4j/include/array/DataTypeUtils.h +++ b/libnd4j/include/array/DataTypeUtils.h @@ -35,7 +35,7 @@ #include namespace sd { - class ND4J_EXPORT DataTypeUtils { + class SD_EXPORT DataTypeUtils { public: static int asInt(DataType type); static DataType fromInt(int dtype); diff --git a/libnd4j/include/array/ExtraArguments.h b/libnd4j/include/array/ExtraArguments.h index 337667eea630..9404447dce65 100644 --- a/libnd4j/include/array/ExtraArguments.h +++ b/libnd4j/include/array/ExtraArguments.h @@ -29,7 +29,7 @@ #include namespace sd { - class ND4J_EXPORT ExtraArguments { + class SD_EXPORT ExtraArguments { private: std::vector _fpArgs; std::vector _intArgs; diff --git a/libnd4j/include/array/InteropDataBuffer.h b/libnd4j/include/array/InteropDataBuffer.h index 27b17aabb568..10d14b6133b4 100644 --- a/libnd4j/include/array/InteropDataBuffer.h +++ b/libnd4j/include/array/InteropDataBuffer.h @@ -30,7 +30,7 @@ namespace sd { /** * This class is a wrapper for DataBuffer, suitable for sharing DataBuffer between front-end and back-end languages */ - class ND4J_EXPORT InteropDataBuffer { + class SD_EXPORT InteropDataBuffer { private: std::shared_ptr _dataBuffer; uint64_t _offset = 0; diff --git a/libnd4j/include/array/ManagedDataBuffer.h b/libnd4j/include/array/ManagedDataBuffer.h index f4ab9e46882c..7cb32f065459 100644 --- a/libnd4j/include/array/ManagedDataBuffer.h +++ b/libnd4j/include/array/ManagedDataBuffer.h @@ -28,7 +28,7 @@ namespace sd { /** * This class provides special DataBuffer implementation for use within Graphs */ - class ND4J_EXPORT ManagedDataBuffer : public DataBuffer { + class SD_EXPORT ManagedDataBuffer : public DataBuffer { private: graph::GraphMemoryManager &_manager; diff --git a/libnd4j/include/array/NDArray.h b/libnd4j/include/array/NDArray.h index 6ab301200b42..867eaaf76350 100644 --- a/libnd4j/include/array/NDArray.h +++ b/libnd4j/include/array/NDArray.h @@ -50,56 +50,56 @@ namespace sd { template ::value>::type> - ND4J_EXPORT NDArray operator+(const NDArray& arr, const T& scalar); + SD_EXPORT NDArray operator+(const NDArray& arr, const T& scalar); template ::value>::type> - ND4J_EXPORT NDArray operator+(NDArray&& arr, const T& scalar); + SD_EXPORT NDArray operator+(NDArray&& arr, const T& scalar); template ::value>::type> - ND4J_EXPORT NDArray operator+(const T& scalar, const NDArray& arr); + SD_EXPORT NDArray operator+(const T& scalar, const NDArray& arr); template ::value>::type> - ND4J_EXPORT NDArray operator+(const T& scalar, NDArray&& arr); + SD_EXPORT NDArray operator+(const T& scalar, NDArray&& arr); template ::value>::type> - ND4J_EXPORT NDArray operator-(const NDArray& arr, const T& scalar); + SD_EXPORT NDArray operator-(const NDArray& arr, const T& scalar); template ::value>::type> - ND4J_EXPORT NDArray operator-(NDArray&& arr, const T& scalar); + SD_EXPORT NDArray operator-(NDArray&& arr, const T& scalar); template ::value>::type> - ND4J_EXPORT NDArray operator-(const T& scalar, const NDArray& arr); + SD_EXPORT NDArray operator-(const T& scalar, const NDArray& arr); template ::value>::type> - ND4J_EXPORT NDArray operator-(const T& scalar, NDArray&& arr); + SD_EXPORT NDArray operator-(const T& scalar, NDArray&& arr); template ::value>::type> - ND4J_EXPORT NDArray operator*(const NDArray& arr, const T& scalar); + SD_EXPORT NDArray operator*(const NDArray& arr, const T& scalar); template ::value>::type> - ND4J_EXPORT NDArray operator*(NDArray&& arr, const T& scalar); + SD_EXPORT NDArray operator*(NDArray&& arr, const T& scalar); template ::value>::type> - ND4J_EXPORT NDArray operator*(const T& scalar, const NDArray& arr); + SD_EXPORT NDArray operator*(const T& scalar, const NDArray& arr); template ::value>::type> - ND4J_EXPORT NDArray operator*(const T& scalar, NDArray&& arr); + SD_EXPORT NDArray operator*(const T& scalar, NDArray&& arr); template ::value>::type> - ND4J_EXPORT NDArray operator/(const NDArray& arr, const T& scalar); + SD_EXPORT NDArray operator/(const NDArray& arr, const T& scalar); template ::value>::type> - ND4J_EXPORT NDArray operator/(NDArray&& arr, const T& scalar); + SD_EXPORT NDArray operator/(NDArray&& arr, const T& scalar); template ::value>::type> - ND4J_EXPORT NDArray operator/(const T& scalar, const NDArray& arr); + SD_EXPORT NDArray operator/(const T& scalar, const NDArray& arr); template ::value>::type> - ND4J_EXPORT NDArray operator/(const T& scalar, NDArray&& arr); + SD_EXPORT NDArray operator/(const T& scalar, NDArray&& arr); template ::type>::value && std::is_same::type>::value>::type> - ND4J_EXPORT NDArray operator+(T1&& arr1, T2&& arr2); + SD_EXPORT NDArray operator+(T1&& arr1, T2&& arr2); template ::type>::value && std::is_same::type>::value>::type> - ND4J_EXPORT NDArray operator-(T1&& arr1, T2&& arr2); + SD_EXPORT NDArray operator-(T1&& arr1, T2&& arr2); template ::type>::value && std::is_same::type>::value>::type> - ND4J_EXPORT NDArray operator*(T1&& arr1, T2&& arr2); + SD_EXPORT NDArray operator*(T1&& arr1, T2&& arr2); template ::type>::value && std::is_same::type>::value>::type> - ND4J_EXPORT NDArray operator/(T1&& arr1, T2&& arr2); + SD_EXPORT NDArray operator/(T1&& arr1, T2&& arr2); - ND4J_EXPORT NDArray mmul(const NDArray&, const NDArray&); + SD_EXPORT NDArray mmul(const NDArray&, const NDArray&); - class ND4J_EXPORT NDArray { + class SD_EXPORT NDArray { private: /** * This method applies given value to the buffer, wrt templates diff --git a/libnd4j/include/array/NDArray.hXX b/libnd4j/include/array/NDArray.hXX index 43c6fe2ad865..4c3ce8def7c2 100644 --- a/libnd4j/include/array/NDArray.hXX +++ b/libnd4j/include/array/NDArray.hXX @@ -30,13 +30,13 @@ namespace sd { template <> -ND4J_EXPORT utf8string NDArray::e(const Nd4jLong i) const; +SD_EXPORT utf8string NDArray::e(const Nd4jLong i) const; template <> -ND4J_EXPORT std::string NDArray::e(const Nd4jLong i) const; +SD_EXPORT std::string NDArray::e(const Nd4jLong i) const; template <> -ND4J_EXPORT std::u16string NDArray::e(const Nd4jLong i) const; +SD_EXPORT std::u16string NDArray::e(const Nd4jLong i) const; template <> -ND4J_EXPORT std::u32string NDArray::e(const Nd4jLong i) const; +SD_EXPORT std::u32string NDArray::e(const Nd4jLong i) const; //////////////////////////////////////////////////////////////////////// // copy constructor @@ -985,7 +985,7 @@ std::vector NDArray::getBufferAsVector() { vector[e] = this->e(e); return vector; } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::getBufferAsVector(), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT std::vector, NDArray::getBufferAsVector(), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// std::vector NDArray::getShapeAsFlatVector() { @@ -1121,19 +1121,19 @@ NDArray& NDArray::operator=(const T scalar) { this->assign(scalar); return *this; } -template ND4J_EXPORT NDArray& NDArray::operator=(const double scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const float scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const float16 scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const bfloat16 scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const Nd4jLong scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const int scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const int8_t scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const uint8_t scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const uint16_t scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const uint32_t scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const uint64_t scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const int16_t scalar); -template ND4J_EXPORT NDArray& NDArray::operator=(const bool scalar); +template SD_EXPORT NDArray& NDArray::operator=(const double scalar); +template SD_EXPORT NDArray& NDArray::operator=(const float scalar); +template SD_EXPORT NDArray& NDArray::operator=(const float16 scalar); +template SD_EXPORT NDArray& NDArray::operator=(const bfloat16 scalar); +template SD_EXPORT NDArray& NDArray::operator=(const Nd4jLong scalar); +template SD_EXPORT NDArray& NDArray::operator=(const int scalar); +template SD_EXPORT NDArray& NDArray::operator=(const int8_t scalar); +template SD_EXPORT NDArray& NDArray::operator=(const uint8_t scalar); +template SD_EXPORT NDArray& NDArray::operator=(const uint16_t scalar); +template SD_EXPORT NDArray& NDArray::operator=(const uint32_t scalar); +template SD_EXPORT NDArray& NDArray::operator=(const uint64_t scalar); +template SD_EXPORT NDArray& NDArray::operator=(const int16_t scalar); +template SD_EXPORT NDArray& NDArray::operator=(const bool scalar); ////////////////////////////////////////////////////////////////////////// void NDArray::copyBuffersContinuouslyFrom(const NDArray& other, size_t sizeToCopyInBytes, Nd4jLong offsetThis, Nd4jLong offsetOther) { @@ -1220,19 +1220,19 @@ void NDArray::assign(const T& value, bool allowParallelism) { NativeOpExecutioner::execScalar(getContext(), sd::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism); NDArray::registerSpecialUse({this}, {&temp}); } -template ND4J_EXPORT void NDArray::assign(const double& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const float& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const float16& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const bfloat16& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const Nd4jLong& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const int& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const int8_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const int16_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const uint8_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const uint16_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const uint32_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const uint64_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const bool& value, bool allowParallelism); +template SD_EXPORT void NDArray::assign(const double& value, bool allowParallelism); +template SD_EXPORT void NDArray::assign(const float& value, bool allowParallelism); +template SD_EXPORT void NDArray::assign(const float16& value, bool allowParallelism); +template SD_EXPORT void NDArray::assign(const bfloat16& value, bool allowParallelism); +template SD_EXPORT void NDArray::assign(const Nd4jLong& value, bool allowParallelism); +template SD_EXPORT void NDArray::assign(const int& value, bool allowParallelism); +template SD_EXPORT void NDArray::assign(const int8_t& value, bool allowParallelism); +template SD_EXPORT void NDArray::assign(const int16_t& value, bool allowParallelism); +template SD_EXPORT void NDArray::assign(const uint8_t& value, bool allowParallelism); +template SD_EXPORT void NDArray::assign(const uint16_t& value, bool allowParallelism); +template SD_EXPORT void NDArray::assign(const uint32_t& value, bool allowParallelism); +template SD_EXPORT void NDArray::assign(const uint64_t& value, bool allowParallelism); +template SD_EXPORT void NDArray::assign(const bool& value, bool allowParallelism); ////////////////////////////////////////////////////////////////////////// NDArray* NDArray::detach() { @@ -1319,7 +1319,7 @@ void NDArray::templatedSet(void *buffer, const Nd4jLong *indices, const void *va auto xOffset = shape::getOffset(getShapeInfo(), indices); t[xOffset] = static_cast(y); } -BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong *indices, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template SD_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong *indices, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -1329,7 +1329,7 @@ void NDArray::templatedSet(void *buffer, const Nd4jLong offset, const void *valu t[offset] = static_cast(y); } -BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong offset, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template SD_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong offset, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void NDArray::setContext(sd::LaunchContext *context) { @@ -1795,7 +1795,7 @@ template void* NDArray::templatedPointerShift(const Nd4jLong offset) const { return reinterpret_cast(getBuffer()) + offset; } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void* NDArray::templatedPointerShift, (const Nd4jLong offset) const, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void* NDArray::templatedPointerShift, (const Nd4jLong offset) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // method makes copy of this array and applies to the copy transpose operation, this array remains unaffected @@ -2158,7 +2158,7 @@ bool NDArray::isUnitary() { ////////////////////////////////////////////////////////////////////////// template <> -std::string* ND4J_EXPORT NDArray::bufferAsT() const { +std::string* SD_EXPORT NDArray::bufferAsT() const { throw std::runtime_error("This method is NOT supposed to be used"); } @@ -2170,7 +2170,7 @@ T* NDArray::bufferAsT() const { return reinterpret_cast(getBuffer()); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , * NDArray::bufferAsT() const, LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT , * NDArray::bufferAsT() const, LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// NDArray NDArray::subarray(IndicesList& idx) const { @@ -2288,7 +2288,7 @@ NDArray NDArray::asT() const{ return result; } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asT, () const, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArray::asT, () const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -2400,7 +2400,7 @@ NDArray NDArray::asS() const { return res; } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asS, () const, LIBND4J_STRINGTYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArray::asS, () const, LIBND4J_STRINGTYPES); //////////////////////////////////////////////////////////////////////// NDArray NDArray::asT(DataType dtype) const { @@ -2592,13 +2592,13 @@ void NDArray::operator+=(const T value) { NDArray::registerSpecialUse({this}, {}); } -template ND4J_EXPORT void NDArray::operator+=(const double value); -template ND4J_EXPORT void NDArray::operator+=(const float value); -template ND4J_EXPORT void NDArray::operator+=(const float16 value); -template ND4J_EXPORT void NDArray::operator+=(const bfloat16 value); -template ND4J_EXPORT void NDArray::operator+=(const Nd4jLong value); -template ND4J_EXPORT void NDArray::operator+=(const int value); -template ND4J_EXPORT void NDArray::operator+=(const bool value); +template SD_EXPORT void NDArray::operator+=(const double value); +template SD_EXPORT void NDArray::operator+=(const float value); +template SD_EXPORT void NDArray::operator+=(const float16 value); +template SD_EXPORT void NDArray::operator+=(const bfloat16 value); +template SD_EXPORT void NDArray::operator+=(const Nd4jLong value); +template SD_EXPORT void NDArray::operator+=(const int value); +template SD_EXPORT void NDArray::operator+=(const bool value); //////////////////////////////////////////////////////////////////////// template @@ -2614,13 +2614,13 @@ void NDArray::operator-=(const T value) { NDArray::registerSpecialUse({this}, {}); } -template ND4J_EXPORT void NDArray::operator-=(const double value); -template ND4J_EXPORT void NDArray::operator-=(const float value); -template ND4J_EXPORT void NDArray::operator-=(const float16 value); -template ND4J_EXPORT void NDArray::operator-=(const bfloat16 value); -template ND4J_EXPORT void NDArray::operator-=(const Nd4jLong value); -template ND4J_EXPORT void NDArray::operator-=(const int value); -template ND4J_EXPORT void NDArray::operator-=(const bool value); +template SD_EXPORT void NDArray::operator-=(const double value); +template SD_EXPORT void NDArray::operator-=(const float value); +template SD_EXPORT void NDArray::operator-=(const float16 value); +template SD_EXPORT void NDArray::operator-=(const bfloat16 value); +template SD_EXPORT void NDArray::operator-=(const Nd4jLong value); +template SD_EXPORT void NDArray::operator-=(const int value); +template SD_EXPORT void NDArray::operator-=(const bool value); //////////////////////////////////////////////////////////////////////// template @@ -2634,16 +2634,16 @@ void NDArray::operator*=(const T scalar) { NDArray::registerSpecialUse({this}, {}); } -template ND4J_EXPORT void NDArray::operator*=(const double scalar); -template ND4J_EXPORT void NDArray::operator*=(const float scalar); -template ND4J_EXPORT void NDArray::operator*=(const float16 scalar); -template ND4J_EXPORT void NDArray::operator*=(const bfloat16 scalar); -template ND4J_EXPORT void NDArray::operator*=(const Nd4jLong scalar); -template ND4J_EXPORT void NDArray::operator*=(const int scalar); -template ND4J_EXPORT void NDArray::operator*=(const int16_t scalar); -template ND4J_EXPORT void NDArray::operator*=(const int8_t scalar); -template ND4J_EXPORT void NDArray::operator*=(const uint8_t scalar); -template ND4J_EXPORT void NDArray::operator*=(const bool scalar); +template SD_EXPORT void NDArray::operator*=(const double scalar); +template SD_EXPORT void NDArray::operator*=(const float scalar); +template SD_EXPORT void NDArray::operator*=(const float16 scalar); +template SD_EXPORT void NDArray::operator*=(const bfloat16 scalar); +template SD_EXPORT void NDArray::operator*=(const Nd4jLong scalar); +template SD_EXPORT void NDArray::operator*=(const int scalar); +template SD_EXPORT void NDArray::operator*=(const int16_t scalar); +template SD_EXPORT void NDArray::operator*=(const int8_t scalar); +template SD_EXPORT void NDArray::operator*=(const uint8_t scalar); +template SD_EXPORT void NDArray::operator*=(const bool scalar); //////////////////////////////////////////////////////////////////////// template @@ -2656,16 +2656,16 @@ void NDArray::operator/=(const T scalar) { NativeOpExecutioner::execScalar(getContext(), sd::scalar::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({this}, {}); } -template ND4J_EXPORT void NDArray::operator/=(const double scalar); -template ND4J_EXPORT void NDArray::operator/=(const float scalar); -template ND4J_EXPORT void NDArray::operator/=(const float16 scalar); -template ND4J_EXPORT void NDArray::operator/=(const bfloat16 scalar); -template ND4J_EXPORT void NDArray::operator/=(const Nd4jLong scalar); -template ND4J_EXPORT void NDArray::operator/=(const int scalar); -template ND4J_EXPORT void NDArray::operator/=(const int16_t scalar); -template ND4J_EXPORT void NDArray::operator/=(const int8_t scalar); -template ND4J_EXPORT void NDArray::operator/=(const uint8_t scalar); -template ND4J_EXPORT void NDArray::operator/=(const bool scalar); +template SD_EXPORT void NDArray::operator/=(const double scalar); +template SD_EXPORT void NDArray::operator/=(const float scalar); +template SD_EXPORT void NDArray::operator/=(const float16 scalar); +template SD_EXPORT void NDArray::operator/=(const bfloat16 scalar); +template SD_EXPORT void NDArray::operator/=(const Nd4jLong scalar); +template SD_EXPORT void NDArray::operator/=(const int scalar); +template SD_EXPORT void NDArray::operator/=(const int16_t scalar); +template SD_EXPORT void NDArray::operator/=(const int8_t scalar); +template SD_EXPORT void NDArray::operator/=(const uint8_t scalar); +template SD_EXPORT void NDArray::operator/=(const bool scalar); //////////////////////////////////////////////////////////////////////// // negative operator, it makes all array elements = -elements @@ -3243,7 +3243,7 @@ std::vector NDArray::asVectorT() { return result; } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::asVectorT(), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT std::vector, NDArray::asVectorT(), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // set new order and shape in case of suitable array length @@ -3353,7 +3353,7 @@ template void NDArray::templatedSet(void *buffer, const Nd4jLong xOfsset, sd::DataType dtype, const void *value) { BUILD_SINGLE_PARTIAL_SELECTOR(dtype, templatedSet< , T>(buffer, xOfsset, value), LIBND4J_TYPES); } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong xOfsset, sd::DataType dtype, const void *value), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong xOfsset, sd::DataType dtype, const void *value), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void NDArray::applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const{ @@ -3416,7 +3416,7 @@ void NDArray::templatedDoubleAssign(void *xBuffer, const Nd4jLong xOffset, const const auto y = reinterpret_cast(yBuffer); x[xOffset] = static_cast(y[yOffset]); } -BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedDoubleAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template SD_EXPORT void NDArray::templatedDoubleAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES, LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void NDArray::varianceAlongDimension(sd::variance::Ops op, NDArray& target, const bool biasCorrected, const std::vector& dimensions) const { @@ -3737,7 +3737,7 @@ T NDArray::e(const Nd4jLong i) const { BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), rp), LIBND4J_TYPES); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong) const, LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT , NDArray::e(const Nd4jLong) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // Returns value from 2D matrix by coordinates/indexes @@ -3757,7 +3757,7 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j) const { return static_cast(119); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // returns value from 3D tensor by coordinates @@ -3777,7 +3777,7 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { return static_cast(119); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // returns value from 3D tensor by coordinates @@ -3797,7 +3797,7 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLon return static_cast(119); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// NDArray NDArray::e(const Nd4jLong i) const { @@ -4037,17 +4037,17 @@ void NDArray::applyScalar(sd::scalar::IntOps op, const T scalar, NDArray& target applyScalarArr(op, scalarArr, target, extraParams); } -template <> ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const; +template <> SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} +template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const; //////////////////////////////////////////////////////////////////////// template @@ -4056,17 +4056,17 @@ void NDArray::applyScalar(sd::scalar::Ops op, const T scalar, NDArray& target, E auto scalarArr = NDArrayFactory::create(dataType(), scalar, this->getContext()); applyScalarArr(op, scalarArr, target, extraParams); } -template <> ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const double scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const float scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const float16 scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const bool scalar, NDArray &target, ExtraArguments *extraParams); +template <> SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const double scalar, NDArray &target, ExtraArguments *extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const float scalar, NDArray &target, ExtraArguments *extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const float16 scalar, NDArray &target, ExtraArguments *extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int scalar, NDArray &target, ExtraArguments *extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const bool scalar, NDArray &target, ExtraArguments *extraParams); //////////////////////////////////////////////////////////////////////// template @@ -4076,17 +4076,17 @@ void NDArray::applyScalar(sd::scalar::BoolOps op, const T scalar, NDArray &targe applyScalarArr(op, scalarArr, target, extraParams); } -template <> ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const; +template <> SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} +template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const; +template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const; //////////////////////////////////////////////////////////////////////// void NDArray::applyIndexReduce(sd::indexreduce::Ops op, NDArray& target, const std::vector& dimensions, const ExtraArguments *extraParams) const { @@ -4377,19 +4377,19 @@ void NDArray::p(const Nd4jLong i, const T value) { NDArray::registerPrimaryUse({this}, {}); } -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const double value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const float value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const float16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const bfloat16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint32_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint64_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const bool value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const double value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const float value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const float16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const bfloat16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const int value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const int8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const uint8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const uint16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const uint32_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const uint64_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const int16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const bool value); ////////////////////////////////////////////////////////////////////////// // This method sets value in 2D matrix to position i, j @@ -4407,19 +4407,19 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const T value) { BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES); NDArray::registerPrimaryUse({this}, {}); } -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bfloat16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint32_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint64_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bool value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bfloat16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint32_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint64_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bool value); ////////////////////////////////////////////////////////////////////////// // This method sets value in 3D matrix to position i,j,k @@ -4437,19 +4437,19 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T va BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES); NDArray::registerPrimaryUse({this}, {}); } -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bfloat16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint32_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint64_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bool value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bfloat16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint32_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint64_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bool value); ////////////////////////////////////////////////////////////////////////// template @@ -4466,19 +4466,19 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4j BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES); NDArray::registerPrimaryUse({this}, {}); } -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bfloat16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const Nd4jLong value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint32_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint64_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bool value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bfloat16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const Nd4jLong value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint32_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint64_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bool value); //////////////////////////////////////////////////////////////////////// void NDArray::p(const Nd4jLong i, const NDArray& scalar) { @@ -4667,7 +4667,7 @@ void NDArray::templatedAssign(void *xBuffer, Nd4jLong xOffset, const void *yBuff if (xBuffer != nullptr && yBuffer != nullptr) *(reinterpret_cast(xBuffer) + xOffset) = *(reinterpret_cast(yBuffer) + yOffset); } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void NDArray::templatedAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// @@ -5031,11 +5031,11 @@ NDArray operator+(NDArray&& arr, const T& scalar) { return std::move(arr); } -template ND4J_EXPORT NDArray operator+(NDArray&& arr, const double& scalar); -template ND4J_EXPORT NDArray operator+(NDArray&& arr, const float& scalar); -template ND4J_EXPORT NDArray operator+(NDArray&& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator+(NDArray&& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator+(NDArray&& arr, const int& scalar); +template SD_EXPORT NDArray operator+(NDArray&& arr, const double& scalar); +template SD_EXPORT NDArray operator+(NDArray&& arr, const float& scalar); +template SD_EXPORT NDArray operator+(NDArray&& arr, const float16& scalar); +template SD_EXPORT NDArray operator+(NDArray&& arr, const bfloat16& scalar); +template SD_EXPORT NDArray operator+(NDArray&& arr, const int& scalar); //////////////////////////////////////////////////////////////////////// template @@ -5053,22 +5053,22 @@ NDArray operator+(const NDArray& arr, const T& scalar) { return result; } -template ND4J_EXPORT NDArray operator+(const NDArray& arr, const double& scalar); -template ND4J_EXPORT NDArray operator+(const NDArray& arr, const float& scalar); -template ND4J_EXPORT NDArray operator+(const NDArray& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator+(const NDArray& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator+(const NDArray& arr, const int& scalar); +template SD_EXPORT NDArray operator+(const NDArray& arr, const double& scalar); +template SD_EXPORT NDArray operator+(const NDArray& arr, const float& scalar); +template SD_EXPORT NDArray operator+(const NDArray& arr, const float16& scalar); +template SD_EXPORT NDArray operator+(const NDArray& arr, const bfloat16& scalar); +template SD_EXPORT NDArray operator+(const NDArray& arr, const int& scalar); //////////////////////////////////////////////////////////////////////// template NDArray operator+(const T& scalar, NDArray&& arr) { return std::move(arr) + scalar; } -template ND4J_EXPORT NDArray operator+(const double& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator+(const float& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator+(const float16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator+(const bfloat16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator+(const int& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator+(const double& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator+(const float& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator+(const float16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator+(const bfloat16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator+(const int& scalar, NDArray&& arr); //////////////////////////////////////////////////////////////////////// @@ -5076,9 +5076,9 @@ template NDArray operator+(const T& scalar, const NDArray& arr) { return arr + scalar; } -template ND4J_EXPORT NDArray operator+(const double& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator+(const float& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator+(const int& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator+(const double& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator+(const float& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator+(const int& scalar, const NDArray& arr); /////////////////////////////////////////////////////////////////////// // addition operator array - scalar @@ -5101,8 +5101,8 @@ NDArray operator-(NDArray&& arr, const T& scalar) { return std::move(arr); } -template ND4J_EXPORT NDArray operator-(NDArray&& arr, const double& scalar); -template ND4J_EXPORT NDArray operator-(NDArray&& arr, const float& scalar); +template SD_EXPORT NDArray operator-(NDArray&& arr, const double& scalar); +template SD_EXPORT NDArray operator-(NDArray&& arr, const float& scalar); //////////////////////////////////////////////////////////////////////// template @@ -5120,11 +5120,11 @@ NDArray operator-(const NDArray& arr, const T& scalar) { return result; } -template ND4J_EXPORT NDArray operator-(const NDArray& arr, const double& scalar); -template ND4J_EXPORT NDArray operator-(const NDArray& arr, const float& scalar); -template ND4J_EXPORT NDArray operator-(const NDArray& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator-(const NDArray& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator-(const NDArray& arr, const int& scalar); +template SD_EXPORT NDArray operator-(const NDArray& arr, const double& scalar); +template SD_EXPORT NDArray operator-(const NDArray& arr, const float& scalar); +template SD_EXPORT NDArray operator-(const NDArray& arr, const float16& scalar); +template SD_EXPORT NDArray operator-(const NDArray& arr, const bfloat16& scalar); +template SD_EXPORT NDArray operator-(const NDArray& arr, const int& scalar); //////////////////////////////////////////////////////////////////////// template @@ -5145,11 +5145,11 @@ NDArray operator-(const T& scalar, NDArray&& arr) { return std::move(arr); } -template ND4J_EXPORT NDArray operator-(const double& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator-(const float& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator-(const float16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator-(const bfloat16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator-(const int& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator-(const double& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator-(const float& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator-(const float16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator-(const bfloat16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator-(const int& scalar, NDArray&& arr); //////////////////////////////////////////////////////////////////////// template @@ -5167,9 +5167,9 @@ NDArray operator-(const T& scalar, const NDArray& arr) { return result; } -template ND4J_EXPORT NDArray operator-(const double& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator-(const float& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator-(const int& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator-(const double& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator-(const float& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator-(const int& scalar, const NDArray& arr); /////////////////////////////////////////////////////////////////////// // addition operator array + scalar @@ -5192,12 +5192,12 @@ NDArray operator*(NDArray&& arr, const T& scalar) { return std::move(arr); } -template ND4J_EXPORT NDArray operator*(NDArray&& arr, const double& scalar); -template ND4J_EXPORT NDArray operator*(NDArray&& arr, const float& scalar); -template ND4J_EXPORT NDArray operator*(NDArray&& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator*(NDArray&& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator*(NDArray&& arr, const int& scalar); -template ND4J_EXPORT NDArray operator*(NDArray&& arr, const long long& scalar); +template SD_EXPORT NDArray operator*(NDArray&& arr, const double& scalar); +template SD_EXPORT NDArray operator*(NDArray&& arr, const float& scalar); +template SD_EXPORT NDArray operator*(NDArray&& arr, const float16& scalar); +template SD_EXPORT NDArray operator*(NDArray&& arr, const bfloat16& scalar); +template SD_EXPORT NDArray operator*(NDArray&& arr, const int& scalar); +template SD_EXPORT NDArray operator*(NDArray&& arr, const long long& scalar); //////////////////////////////////////////////////////////////////////// template @@ -5216,24 +5216,24 @@ NDArray operator*(const NDArray& arr, const T& scalar) { return result; } -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const double& scalar); -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const float& scalar); -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const int& scalar); -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const long long& scalar); +template SD_EXPORT NDArray operator*(const NDArray& arr, const double& scalar); +template SD_EXPORT NDArray operator*(const NDArray& arr, const float& scalar); +template SD_EXPORT NDArray operator*(const NDArray& arr, const float16& scalar); +template SD_EXPORT NDArray operator*(const NDArray& arr, const bfloat16& scalar); +template SD_EXPORT NDArray operator*(const NDArray& arr, const int& scalar); +template SD_EXPORT NDArray operator*(const NDArray& arr, const long long& scalar); //////////////////////////////////////////////////////////////////////// template NDArray operator*(const T& scalar, NDArray&& arr) { return std::move(arr) * scalar; } -template ND4J_EXPORT NDArray operator*(const double& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator*(const float& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator*(const float16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator*(const bfloat16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator*(const int& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator*(const long long& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator*(const double& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator*(const float& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator*(const float16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator*(const bfloat16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator*(const int& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator*(const long long& scalar, NDArray&& arr); //////////////////////////////////////////////////////////////////////// @@ -5241,12 +5241,12 @@ template NDArray operator*(const T& scalar, const NDArray& arr) { return arr * scalar; } -template ND4J_EXPORT NDArray operator*(const double& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator*(const float& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator*(const float16& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator*(const bfloat16& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator*(const int& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator*(const long long& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator*(const double& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator*(const float& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator*(const float16& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator*(const bfloat16& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator*(const int& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator*(const long long& scalar, const NDArray& arr); /////////////////////////////////////////////////////////////////////// template @@ -5268,11 +5268,11 @@ NDArray operator/(NDArray&& arr, const T& scalar) { return std::move(arr); } -template ND4J_EXPORT NDArray operator/(NDArray&& arr, const double& scalar); -template ND4J_EXPORT NDArray operator/(NDArray&& arr, const float& scalar); -template ND4J_EXPORT NDArray operator/(NDArray&& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator/(NDArray&& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator/(NDArray&& arr, const long long& scalar); +template SD_EXPORT NDArray operator/(NDArray&& arr, const double& scalar); +template SD_EXPORT NDArray operator/(NDArray&& arr, const float& scalar); +template SD_EXPORT NDArray operator/(NDArray&& arr, const float16& scalar); +template SD_EXPORT NDArray operator/(NDArray&& arr, const bfloat16& scalar); +template SD_EXPORT NDArray operator/(NDArray&& arr, const long long& scalar); //////////////////////////////////////////////////////////////////////// template @@ -5290,12 +5290,12 @@ NDArray operator/(const NDArray& arr, const T& scalar) { return result; } -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const double& scalar); -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const float& scalar); -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const int& scalar); -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const long long& scalar); +template SD_EXPORT NDArray operator/(const NDArray& arr, const double& scalar); +template SD_EXPORT NDArray operator/(const NDArray& arr, const float& scalar); +template SD_EXPORT NDArray operator/(const NDArray& arr, const float16& scalar); +template SD_EXPORT NDArray operator/(const NDArray& arr, const bfloat16& scalar); +template SD_EXPORT NDArray operator/(const NDArray& arr, const int& scalar); +template SD_EXPORT NDArray operator/(const NDArray& arr, const long long& scalar); //////////////////////////////////////////////////////////////////////// template @@ -5316,11 +5316,11 @@ NDArray operator/(const T& scalar, NDArray&& arr) { return std::move(arr); } -template ND4J_EXPORT NDArray operator/(const double& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator/(const float& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator/(const float16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator/(const bfloat16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator/(const int& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator/(const double& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator/(const float& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator/(const float16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator/(const bfloat16& scalar, NDArray&& arr); +template SD_EXPORT NDArray operator/(const int& scalar, NDArray&& arr); //////////////////////////////////////////////////////////////////////// @@ -5339,9 +5339,9 @@ NDArray operator/(const T& scalar, const NDArray& arr) { return result; } -template ND4J_EXPORT NDArray operator/(const double& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator/(const float& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator/(const int& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator/(const double& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator/(const float& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator/(const int& scalar, const NDArray& arr); //////////////////////////////////////////////////////////////////////// // addition operator array + array @@ -5383,15 +5383,15 @@ NDArray operator+(T1&& arr1, T2&& arr2) { return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), std::forward(arr2)); } -template ND4J_EXPORT NDArray operator+(NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator+(NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator+(NDArray&& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator+(NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator+(const NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator+(const NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator+(const NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator+(NDArray&& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator+(NDArray&& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator+(NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator+(NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator+(NDArray&& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator+(NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator+(const NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator+(const NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator+(const NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator+(NDArray&& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator+(NDArray&& arr1, NDArray&& arr2); //////////////////////////////////////////////////////////////////////// // addition operator array - array @@ -5433,15 +5433,15 @@ NDArray operator-(T1&& arr1, T2&& arr2) { return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), std::forward(arr2)); } -template ND4J_EXPORT NDArray operator-(NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator-(NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator-(NDArray&& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator-(NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator-(const NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator-(const NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator-(const NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator-(NDArray&& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator-(NDArray&& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator-(NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator-(NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator-(NDArray&& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator-(NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator-(const NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator-(const NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator-(const NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator-(NDArray&& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator-(NDArray&& arr1, NDArray&& arr2); //////////////////////////////////////////////////////////////////////// // multiplication operator array*array @@ -5483,15 +5483,15 @@ NDArray operator*(T1&& arr1, T2&& arr2) { return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), std::forward(arr2)); } -template ND4J_EXPORT NDArray operator*(NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator*(NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator*(NDArray&& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator*(NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator*(const NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator*(const NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator*(const NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator*(NDArray&& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator*(NDArray&& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator*(NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator*(NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator*(NDArray&& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator*(NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator*(const NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator*(const NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator*(const NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator*(NDArray&& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator*(NDArray&& arr1, NDArray&& arr2); //////////////////////////////////////////////////////////////////////// // multiplication operator array*array @@ -5533,15 +5533,15 @@ NDArray operator/(T1&& arr1, T2&& arr2) { return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), std::forward(arr2)); } -template ND4J_EXPORT NDArray operator/(NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator/(NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator/(NDArray&& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator/(NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator/(const NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator/(const NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator/(const NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator/(NDArray&& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator/(NDArray&& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator/(NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator/(NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator/(NDArray&& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator/(NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator/(const NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator/(const NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator/(const NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator/(NDArray&& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator/(NDArray&& arr1, NDArray&& arr2); /* diff --git a/libnd4j/include/array/NDArrayFactory.h b/libnd4j/include/array/NDArrayFactory.h index bfe3aa3e6600..00f6aa732d9e 100644 --- a/libnd4j/include/array/NDArrayFactory.h +++ b/libnd4j/include/array/NDArrayFactory.h @@ -32,7 +32,7 @@ namespace sd { - class ND4J_EXPORT NDArrayFactory { + class SD_EXPORT NDArrayFactory { private: template static void memcpyFromVector(void *ptr, const std::vector &vector); diff --git a/libnd4j/include/array/NDArrayList.h b/libnd4j/include/array/NDArrayList.h index e446213f2ec6..661422de914d 100644 --- a/libnd4j/include/array/NDArrayList.h +++ b/libnd4j/include/array/NDArrayList.h @@ -31,7 +31,7 @@ #include namespace sd { - class ND4J_EXPORT NDArrayList { + class SD_EXPORT NDArrayList { private: // workspace where chunks belong to //sd::memory::Workspace* _workspace = nullptr; diff --git a/libnd4j/include/array/ResultSet.h b/libnd4j/include/array/ResultSet.h index 6c80e7b1816a..916ae0d41fe3 100644 --- a/libnd4j/include/array/ResultSet.h +++ b/libnd4j/include/array/ResultSet.h @@ -34,7 +34,7 @@ namespace sd { class NDArray; // forward declaration of template class NDArray - class ND4J_EXPORT ResultSet { + class SD_EXPORT ResultSet { private: std::vector _content; Nd4jStatus _status = ND4J_STATUS_OK; diff --git a/libnd4j/include/array/ShapeDescriptor.h b/libnd4j/include/array/ShapeDescriptor.h index 3f0704889943..196930c75f08 100644 --- a/libnd4j/include/array/ShapeDescriptor.h +++ b/libnd4j/include/array/ShapeDescriptor.h @@ -30,7 +30,7 @@ namespace sd { -class ND4J_EXPORT ShapeDescriptor { +class SD_EXPORT ShapeDescriptor { private: int _rank = 0; @@ -91,7 +91,7 @@ class ND4J_EXPORT ShapeDescriptor { namespace std { template<> - class ND4J_EXPORT hash { + class SD_EXPORT hash { public: size_t operator()(const sd::ShapeDescriptor &k) const; }; diff --git a/libnd4j/include/array/ShapeList.h b/libnd4j/include/array/ShapeList.h index 2d0fde4ad2aa..0fb170172f10 100644 --- a/libnd4j/include/array/ShapeList.h +++ b/libnd4j/include/array/ShapeList.h @@ -26,7 +26,7 @@ #include namespace sd { - class ND4J_EXPORT ShapeList { + class SD_EXPORT ShapeList { protected: std::vector _shapes; diff --git a/libnd4j/include/array/TadDescriptor.h b/libnd4j/include/array/TadDescriptor.h index 8e14c1ec1efd..e10525c2a681 100644 --- a/libnd4j/include/array/TadDescriptor.h +++ b/libnd4j/include/array/TadDescriptor.h @@ -25,7 +25,7 @@ #include namespace sd { - class ND4J_EXPORT TadDescriptor { + class SD_EXPORT TadDescriptor { private: ShapeDescriptor _originalShape; @@ -62,7 +62,7 @@ namespace sd { namespace std { template<> - class ND4J_EXPORT hash { + class SD_EXPORT hash { public: size_t operator()(const sd::TadDescriptor &k) const; }; diff --git a/libnd4j/include/array/TadPack.h b/libnd4j/include/array/TadPack.h index a5f880b9f3b8..350e5de4759e 100644 --- a/libnd4j/include/array/TadPack.h +++ b/libnd4j/include/array/TadPack.h @@ -24,7 +24,7 @@ #include "ConstantDataBuffer.h" namespace sd { - class ND4J_EXPORT TadPack { + class SD_EXPORT TadPack { private: ConstantDataBuffer _tadShape; ConstantDataBuffer _tadOffsets; diff --git a/libnd4j/include/array/cuda/NDArray.cu b/libnd4j/include/array/cuda/NDArray.cu index 9ddb8be595cf..6e2c7980420c 100644 --- a/libnd4j/include/array/cuda/NDArray.cu +++ b/libnd4j/include/array/cuda/NDArray.cu @@ -147,7 +147,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& t manager.synchronize(); } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// template diff --git a/libnd4j/include/array/impl/ConstantDataBuffer.cpp b/libnd4j/include/array/impl/ConstantDataBuffer.cpp index 20c842266d31..5a56a50d02ee 100644 --- a/libnd4j/include/array/impl/ConstantDataBuffer.cpp +++ b/libnd4j/include/array/impl/ConstantDataBuffer.cpp @@ -55,18 +55,18 @@ namespace sd { T* ConstantDataBuffer::primaryAsT() { return reinterpret_cast(_primaryBuffer); } - template ND4J_EXPORT float* ConstantDataBuffer::primaryAsT(); - template ND4J_EXPORT double* ConstantDataBuffer::primaryAsT(); - template ND4J_EXPORT int* ConstantDataBuffer::primaryAsT(); - template ND4J_EXPORT Nd4jLong* ConstantDataBuffer::primaryAsT(); + template SD_EXPORT float* ConstantDataBuffer::primaryAsT(); + template SD_EXPORT double* ConstantDataBuffer::primaryAsT(); + template SD_EXPORT int* ConstantDataBuffer::primaryAsT(); + template SD_EXPORT Nd4jLong* ConstantDataBuffer::primaryAsT(); template T* ConstantDataBuffer::specialAsT() { return reinterpret_cast(_specialBuffer); } - template ND4J_EXPORT float* ConstantDataBuffer::specialAsT(); - template ND4J_EXPORT double* ConstantDataBuffer::specialAsT(); - template ND4J_EXPORT int* ConstantDataBuffer::specialAsT(); - template ND4J_EXPORT Nd4jLong* ConstantDataBuffer::specialAsT(); + template SD_EXPORT float* ConstantDataBuffer::specialAsT(); + template SD_EXPORT double* ConstantDataBuffer::specialAsT(); + template SD_EXPORT int* ConstantDataBuffer::specialAsT(); + template SD_EXPORT Nd4jLong* ConstantDataBuffer::specialAsT(); } diff --git a/libnd4j/include/array/impl/ConstantHolder.cpp b/libnd4j/include/array/impl/ConstantHolder.cpp index 08637862c5e3..204ff879e136 100644 --- a/libnd4j/include/array/impl/ConstantHolder.cpp +++ b/libnd4j/include/array/impl/ConstantHolder.cpp @@ -40,7 +40,7 @@ namespace sd { bool ConstantHolder::hasBuffer() { return hasBuffer(DataTypeUtils::fromT()); } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT bool ConstantHolder::hasBuffer, (void), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template SD_EXPORT bool ConstantHolder::hasBuffer, (void), LIBND4J_TYPES); void ConstantHolder::addBuffer(ConstantDataBuffer &pointer, sd::DataType dataType) { _buffers[dataType] = pointer; @@ -50,7 +50,7 @@ namespace sd { void ConstantHolder::addBuffer(ConstantDataBuffer &pointer) { addBuffer(pointer, DataTypeUtils::fromT()); } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void ConstantHolder::addBuffer, (ConstantDataBuffer& cb), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template SD_EXPORT void ConstantHolder::addBuffer, (ConstantDataBuffer& cb), LIBND4J_TYPES); ConstantDataBuffer* ConstantHolder::getConstantDataBuffer(sd::DataType dataType) { if (!hasBuffer(dataType)) @@ -63,5 +63,5 @@ namespace sd { ConstantDataBuffer* ConstantHolder::getConstantDataBuffer() { return getConstantDataBuffer(DataTypeUtils::fromT()); } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT ConstantDataBuffer* ConstantHolder::getConstantDataBuffer, (), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template SD_EXPORT ConstantDataBuffer* ConstantHolder::getConstantDataBuffer, (), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/array/impl/ExtraArguments.cpp b/libnd4j/include/array/impl/ExtraArguments.cpp index 084f327cc290..2f512cf50b84 100644 --- a/libnd4j/include/array/impl/ExtraArguments.cpp +++ b/libnd4j/include/array/impl/ExtraArguments.cpp @@ -89,7 +89,7 @@ namespace sd { delete[] target; #endif } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void ExtraArguments::convertAndCopy, (Nd4jPointer pointer, Nd4jLong offset), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template SD_EXPORT void ExtraArguments::convertAndCopy, (Nd4jPointer pointer, Nd4jLong offset), LIBND4J_TYPES); void* ExtraArguments::allocate(size_t length, size_t elementSize) { #ifdef __CUDABLAS__ @@ -119,7 +119,7 @@ namespace sd { void* ExtraArguments::argumentsAsT(Nd4jLong offset) { return argumentsAsT(DataTypeUtils::fromT(), offset); } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void *ExtraArguments::argumentsAsT, (Nd4jLong offset), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template SD_EXPORT void *ExtraArguments::argumentsAsT, (Nd4jLong offset), LIBND4J_TYPES); void* ExtraArguments::argumentsAsT(sd::DataType dataType, Nd4jLong offset) { diff --git a/libnd4j/include/array/impl/NDArrayFactory.cpp b/libnd4j/include/array/impl/NDArrayFactory.cpp index 0f54f26a9616..128f032faef7 100644 --- a/libnd4j/include/array/impl/NDArrayFactory.cpp +++ b/libnd4j/include/array/impl/NDArrayFactory.cpp @@ -39,7 +39,7 @@ namespace sd { //////////////////////////////////////////////////////////////////////// template <> - ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context) { + SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context) { if ((int) shape.size() > MAX_RANK) throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !"); @@ -82,25 +82,25 @@ namespace sd { return result; } - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); //////////////////////////////////////////////////////////////////////// template NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, sd::LaunchContext * context) { return create_(order, shape, DataTypeUtils::fromT(), context); } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray* NDArrayFactory::create_, (const char order, const std::vector &shape, sd::LaunchContext * context), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray* NDArrayFactory::create_, (const char order, const std::vector &shape, sd::LaunchContext * context), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// template @@ -110,20 +110,20 @@ void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector) { } template <> -void ND4J_EXPORT NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector) { +void SD_EXPORT NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector) { auto p = reinterpret_cast(ptr); for (Nd4jLong e = 0; e < vector.size(); e++) p[e] = vector[e]; } -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); #ifndef __JAVACPP_HACK__ @@ -132,16 +132,16 @@ template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std: NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const T value, const char order, sd::LaunchContext * context) { return valueOf(std::vector(shape), value, order); } - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const double value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float16 value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bfloat16 value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const Nd4jLong value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const uint8_t value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int8_t value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int16_t value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bool value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const double value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float16 value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bfloat16 value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const Nd4jLong value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const uint8_t value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int8_t value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int16_t value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bool value, const char order, sd::LaunchContext * context); //////////////////////////////////////////////////////////////////////// template @@ -149,18 +149,18 @@ template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std: std::vector vec(data); return create(order, shape, vec, context); } - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); #endif @@ -179,19 +179,19 @@ template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std: return res; } - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const double scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const float scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const float16 scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const bfloat16 scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const Nd4jLong scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const int scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const bool scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const int8_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const uint8_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const uint16_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const uint32_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const uint64_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::create_(const int16_t scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::create_(const double scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::create_(const float scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::create_(const float16 scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::create_(const bfloat16 scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::create_(const Nd4jLong scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::create_(const int scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::create_(const bool scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::create_(const int8_t scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::create_(const uint8_t scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::create_(const uint16_t scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::create_(const uint32_t scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::create_(const uint64_t scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::create_(const int16_t scalar, sd::LaunchContext * context); template NDArray NDArrayFactory::create(sd::DataType type, const T scalar, sd::LaunchContext * context) { @@ -205,20 +205,20 @@ template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std: return res; } -// BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT NDArray NDArrayFactory::create, (DataType type, const T scalar, sd::LaunchContext * context), LIBND4J_TYPES); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const double scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const float scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const float16 scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const bfloat16 scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const Nd4jLong scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const int scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const int8_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const uint8_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const uint16_t scalar, sd::LaunchContext* workspace); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const uint32_t scalar, sd::LaunchContext* workspace); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const uint64_t scalar, sd::LaunchContext* workspace); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const int16_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(DataType type, const bool scalar, sd::LaunchContext * context); +// BUILD_DOUBLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::create, (DataType type, const T scalar, sd::LaunchContext * context), LIBND4J_TYPES); + template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const double scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const float scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const float16 scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const bfloat16 scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const Nd4jLong scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const int scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const int8_t scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const uint8_t scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const uint16_t scalar, sd::LaunchContext* workspace); + template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const uint32_t scalar, sd::LaunchContext* workspace); + template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const uint64_t scalar, sd::LaunchContext* workspace); + template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const int16_t scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const bool scalar, sd::LaunchContext * context); template NDArray NDArrayFactory::create(const T scalar, sd::LaunchContext * context) { @@ -234,19 +234,19 @@ template ND4J_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std: return res; } - template ND4J_EXPORT NDArray NDArrayFactory::create(const double scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const float scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const float16 scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const bfloat16 scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const Nd4jLong scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const int scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const int8_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const uint8_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const int16_t scalar, sd::LaunchContext * context); - template ND4J_EXPORT NDArray NDArrayFactory::create(const uint16_t scalar, sd::LaunchContext* workspace); - template ND4J_EXPORT NDArray NDArrayFactory::create(const uint32_t scalar, sd::LaunchContext* workspace); - template ND4J_EXPORT NDArray NDArrayFactory::create(const uint64_t scalar, sd::LaunchContext* workspace); - template ND4J_EXPORT NDArray NDArrayFactory::create(const bool scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const double scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const float scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const float16 scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const bfloat16 scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const Nd4jLong scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const int scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const int8_t scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const uint8_t scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const int16_t scalar, sd::LaunchContext * context); + template SD_EXPORT NDArray NDArrayFactory::create(const uint16_t scalar, sd::LaunchContext* workspace); + template SD_EXPORT NDArray NDArrayFactory::create(const uint32_t scalar, sd::LaunchContext* workspace); + template SD_EXPORT NDArray NDArrayFactory::create(const uint64_t scalar, sd::LaunchContext* workspace); + template SD_EXPORT NDArray NDArrayFactory::create(const bool scalar, sd::LaunchContext * context); //////////////////////////////////////////////////////////////////////// @@ -255,31 +255,31 @@ NDArray* NDArrayFactory::create_(const char order, const std::vector & return new NDArray(NDArrayFactory::create(order, shape, data, context)); } -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); //////////////////////////////////////////////////////////////////////// template <> - ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, NDArray* value, const char order, sd::LaunchContext * context) { + SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, NDArray* value, const char order, sd::LaunchContext * context) { auto result = create_(order, shape, value->dataType(), context); result->assign(*value); return result; } template <> - ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, NDArray& value, const char order, sd::LaunchContext * context) { + SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, NDArray& value, const char order, sd::LaunchContext * context) { auto result = create_(order, shape, value.dataType(), context); result->assign(value); return result; @@ -291,16 +291,16 @@ template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const st result->assign(value); return result; } - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const double value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const float value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const float16 value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const bfloat16 value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const Nd4jLong value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int16_t value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int8_t value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const uint8_t value, const char order, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const bool value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const double value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const float value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const float16 value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const bfloat16 value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const Nd4jLong value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int16_t value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int8_t value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const uint8_t value, const char order, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const bool value, const char order, sd::LaunchContext * context); //////////////////////////////////////////////////////////////////////// @@ -316,19 +316,19 @@ template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const st return result; } - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const double from, const double to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const float from, const float to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const float16 from, const float16 to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const bfloat16 from, const bfloat16 to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const Nd4jLong from, const Nd4jLong to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const int from, const int to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const int16_t from, const int16_t to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const uint8_t from, const uint8_t to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const uint16_t from, const uint16_t to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const uint32_t from, const uint32_t to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const uint64_t from, const uint64_t to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const int8_t from, const int8_t to, const Nd4jLong numElements); - template ND4J_EXPORT NDArray* NDArrayFactory::linspace(const bool from, const bool to, const Nd4jLong numElements); + template SD_EXPORT NDArray* NDArrayFactory::linspace(const double from, const double to, const Nd4jLong numElements); + template SD_EXPORT NDArray* NDArrayFactory::linspace(const float from, const float to, const Nd4jLong numElements); + template SD_EXPORT NDArray* NDArrayFactory::linspace(const float16 from, const float16 to, const Nd4jLong numElements); + template SD_EXPORT NDArray* NDArrayFactory::linspace(const bfloat16 from, const bfloat16 to, const Nd4jLong numElements); + template SD_EXPORT NDArray* NDArrayFactory::linspace(const Nd4jLong from, const Nd4jLong to, const Nd4jLong numElements); + template SD_EXPORT NDArray* NDArrayFactory::linspace(const int from, const int to, const Nd4jLong numElements); + template SD_EXPORT NDArray* NDArrayFactory::linspace(const int16_t from, const int16_t to, const Nd4jLong numElements); + template SD_EXPORT NDArray* NDArrayFactory::linspace(const uint8_t from, const uint8_t to, const Nd4jLong numElements); + template SD_EXPORT NDArray* NDArrayFactory::linspace(const uint16_t from, const uint16_t to, const Nd4jLong numElements); + template SD_EXPORT NDArray* NDArrayFactory::linspace(const uint32_t from, const uint32_t to, const Nd4jLong numElements); + template SD_EXPORT NDArray* NDArrayFactory::linspace(const uint64_t from, const uint64_t to, const Nd4jLong numElements); + template SD_EXPORT NDArray* NDArrayFactory::linspace(const int8_t from, const int8_t to, const Nd4jLong numElements); + template SD_EXPORT NDArray* NDArrayFactory::linspace(const bool from, const bool to, const Nd4jLong numElements); //////////////////////////////////////////////////////////////////////// template @@ -345,19 +345,19 @@ template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const st return res; } - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const double startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const float startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const float16 startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const bfloat16 startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const Nd4jLong startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const int startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint8_t startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint16_t startingValue, sd::LaunchContext *workspace); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint32_t startingValue, sd::LaunchContext *workspace); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint64_t startingValue, sd::LaunchContext *workspace); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const int8_t startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const int16_t startingValue, sd::LaunchContext * context); - template ND4J_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const bool startingValue, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const double startingValue, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const float startingValue, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const float16 startingValue, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const bfloat16 startingValue, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const Nd4jLong startingValue, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const int startingValue, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint8_t startingValue, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint16_t startingValue, sd::LaunchContext *workspace); + template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint32_t startingValue, sd::LaunchContext *workspace); + template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint64_t startingValue, sd::LaunchContext *workspace); + template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const int8_t startingValue, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const int16_t startingValue, sd::LaunchContext * context); + template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const bool startingValue, sd::LaunchContext * context); //////////////////////////////////////////////////////////////////////// template @@ -365,14 +365,14 @@ template ND4J_EXPORT NDArray* NDArrayFactory::create_(const char order, const st std::vector vec(shape); return create(order, vec, context); } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArrayFactory::create, (const char, const std::initializer_list&, sd::LaunchContext * context), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::create, (const char, const std::initializer_list&, sd::LaunchContext * context), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// template NDArray NDArrayFactory::create(const char order, const std::vector &shape, sd::LaunchContext * context) { return create(order, shape, DataTypeUtils::fromT(), context); } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArrayFactory::create, (const char order, const std::vector &shape, sd::LaunchContext * context), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::create, (const char order, const std::vector &shape, sd::LaunchContext * context), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// NDArray NDArrayFactory::create(const char order, const std::vector &shape, sd::DataType dtype, sd::LaunchContext* context) { @@ -425,17 +425,17 @@ NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * return res; } -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); //////////////////////////////////////////////////////////////////////// template @@ -448,7 +448,7 @@ template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &val return result; } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray* NDArrayFactory::empty_, (sd::LaunchContext * context), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray* NDArrayFactory::empty_, (sd::LaunchContext * context), LIBND4J_TYPES); NDArray* NDArrayFactory::empty_(sd::DataType dataType, sd::LaunchContext * context) { if (context == nullptr) @@ -468,7 +468,7 @@ template ND4J_EXPORT NDArray NDArrayFactory::create(const std::vector &val NDArray NDArrayFactory::empty(sd::LaunchContext * context) { return empty(DataTypeUtils::fromT(), context); } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArrayFactory::empty, (sd::LaunchContext * context), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::empty, (sd::LaunchContext * context), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// NDArray NDArrayFactory::empty(sd::DataType dataType, sd::LaunchContext * context) { @@ -511,16 +511,16 @@ NDArray NDArrayFactory::create(T* buffer, const char order, const std::initializ return result; } -template ND4J_EXPORT NDArray NDArrayFactory::create(double* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(float* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(float16* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(bfloat16* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(Nd4jLong * buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(int* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(bool* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(uint8_t * buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(int8_t* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template ND4J_EXPORT NDArray NDArrayFactory::create(int16_t* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(double* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(float* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(float16* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(bfloat16* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(Nd4jLong * buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(int* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(bool* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(uint8_t * buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(int8_t* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); +template SD_EXPORT NDArray NDArrayFactory::create(int16_t* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); ///////////////////////////////////////////////////////////////////////////////////// NDArray NDArrayFactory::string(const char16_t* u16string, sd::DataType dtype, sd::LaunchContext* context) { diff --git a/libnd4j/include/cnpy/cnpy.h b/libnd4j/include/cnpy/cnpy.h index ea847c3e7423..62fab018bdd0 100644 --- a/libnd4j/include/cnpy/cnpy.h +++ b/libnd4j/include/cnpy/cnpy.h @@ -59,7 +59,7 @@ namespace cnpy { /** * The numpy array */ - struct ND4J_EXPORT NpyArray { + struct SD_EXPORT NpyArray { char* data; std::vector shape; unsigned int wordSize; @@ -69,7 +69,7 @@ namespace cnpy { } }; - struct ND4J_EXPORT npz_t : public std::unordered_map { + struct SD_EXPORT npz_t : public std::unordered_map { void destruct() { npz_t::iterator it = this->begin(); for(; it != this->end(); ++it) (*it).second.destruct(); @@ -81,7 +81,7 @@ namespace cnpy { * @param path * @return */ - ND4J_EXPORT char* loadFile(const char *path); + SD_EXPORT char* loadFile(const char *path); @@ -97,10 +97,10 @@ namespace cnpy { * @param t * @return */ - ND4J_EXPORT char mapType(const std::type_info &t); + SD_EXPORT char mapType(const std::type_info &t); template - ND4J_EXPORT char mapType(); + SD_EXPORT char mapType(); /** * @@ -111,7 +111,7 @@ namespace cnpy { * @return */ template - ND4J_EXPORT std::vector createNpyHeader(const void *data, + SD_EXPORT std::vector createNpyHeader(const void *data, const unsigned int *shape, const unsigned int ndims, unsigned int wordSize = 4); @@ -126,7 +126,7 @@ namespace cnpy { * @param ndims * @param fortranOrder */ - ND4J_EXPORT void parseNpyHeader(FILE *fp, + SD_EXPORT void parseNpyHeader(FILE *fp, unsigned int &wordSize, unsigned int *&shape, unsigned int &ndims, @@ -143,7 +143,7 @@ namespace cnpy { * @param ndims * @param fortran_order */ - ND4J_EXPORT void parseNpyHeaderPointer( + SD_EXPORT void parseNpyHeaderPointer( const char *header, unsigned int& word_size, unsigned int*& shape, @@ -156,7 +156,7 @@ namespace cnpy { * @param global_header_size * @param global_header_offset */ - ND4J_EXPORT void parseZipFooter(FILE *fp, + SD_EXPORT void parseZipFooter(FILE *fp, unsigned short &nrecs, unsigned int &global_header_size, unsigned int &global_header_offset); @@ -167,14 +167,14 @@ namespace cnpy { * @param varname * @return */ - ND4J_EXPORT NpyArray npzLoad(std::string fname, std::string varname); + SD_EXPORT NpyArray npzLoad(std::string fname, std::string varname); /** * * @param fname * @return */ - ND4J_EXPORT NpyArray npyLoad(std::string fname); + SD_EXPORT NpyArray npyLoad(std::string fname); /** * Parse the numpy header from @@ -187,7 +187,7 @@ namespace cnpy { * @param ndims * @param fortranOrder */ - ND4J_EXPORT void parseNpyHeaderStr(std::string header, + SD_EXPORT void parseNpyHeaderStr(std::string header, unsigned int &wordSize, unsigned int *&shape, unsigned int &ndims, @@ -199,46 +199,46 @@ namespace cnpy { * @param fp * @return */ - ND4J_EXPORT int* shapeFromFile(FILE *fp); + SD_EXPORT int* shapeFromFile(FILE *fp); /** * * @param data * @return */ - ND4J_EXPORT int* shapeFromPointer(char *data); + SD_EXPORT int* shapeFromPointer(char *data); /** * Load the numpy array from the given file. * @param fp the file to load * @return the loaded array */ - ND4J_EXPORT NpyArray loadNpyFromFile(FILE *fp); + SD_EXPORT NpyArray loadNpyFromFile(FILE *fp); /** * Load the numpy array archive from the given file. * @param fp the file to load * @return the loaded archive */ - ND4J_EXPORT npz_t npzLoad(FILE* fp); + SD_EXPORT npz_t npzLoad(FILE* fp); /** * * @param data * @return */ - ND4J_EXPORT NpyArray loadNpyFromPointer(char *data); + SD_EXPORT NpyArray loadNpyFromPointer(char *data); /** * * @param data * @return */ - ND4J_EXPORT NpyArray loadNpyFromHeader(char *data); + SD_EXPORT NpyArray loadNpyFromHeader(char *data); - ND4J_EXPORT npz_t npzLoad(std::string fname); + SD_EXPORT npz_t npzLoad(std::string fname); - ND4J_EXPORT sd::DataType dataTypeFromHeader(char *data); + SD_EXPORT sd::DataType dataTypeFromHeader(char *data); /** * Parse the numpy header from * the given file @@ -250,7 +250,7 @@ namespace cnpy { * @param ndims * @param fortran_order */ - ND4J_EXPORT void parseNpyHeader(std::string header, + SD_EXPORT void parseNpyHeader(std::string header, unsigned int &word_size, unsigned int *&shape, unsigned int &ndims, @@ -273,7 +273,7 @@ namespace cnpy { template - ND4J_EXPORT void npy_save(std::string fname, const T* data, const unsigned int* shape, const unsigned int ndims, std::string mode = "w"); + SD_EXPORT void npy_save(std::string fname, const T* data, const unsigned int* shape, const unsigned int ndims, std::string mode = "w"); } @@ -285,7 +285,7 @@ namespace cnpy { * @return */ template - ND4J_EXPORT std::vector& operator+=(std::vector& lhs, const T rhs); + SD_EXPORT std::vector& operator+=(std::vector& lhs, const T rhs); #endif diff --git a/libnd4j/include/exceptions/allocation_exception.h b/libnd4j/include/exceptions/allocation_exception.h index 98a5a47556ec..156f6937ca3b 100644 --- a/libnd4j/include/exceptions/allocation_exception.h +++ b/libnd4j/include/exceptions/allocation_exception.h @@ -34,7 +34,7 @@ #endif namespace sd { - class ND4J_EXPORT allocation_exception : public std::runtime_error { + class SD_EXPORT allocation_exception : public std::runtime_error { public: allocation_exception(std::string message); ~allocation_exception() = default; diff --git a/libnd4j/include/exceptions/cuda_exception.h b/libnd4j/include/exceptions/cuda_exception.h index a800cfcbcb1f..b4e6f591c002 100644 --- a/libnd4j/include/exceptions/cuda_exception.h +++ b/libnd4j/include/exceptions/cuda_exception.h @@ -33,7 +33,7 @@ #endif namespace sd { - class ND4J_EXPORT cuda_exception : public std::runtime_error { + class SD_EXPORT cuda_exception : public std::runtime_error { public: cuda_exception(std::string message); ~cuda_exception() = default; diff --git a/libnd4j/include/exceptions/datatype_exception.h b/libnd4j/include/exceptions/datatype_exception.h index fbbe62164586..32e4441d62e7 100644 --- a/libnd4j/include/exceptions/datatype_exception.h +++ b/libnd4j/include/exceptions/datatype_exception.h @@ -34,7 +34,7 @@ #endif namespace sd { - class ND4J_EXPORT datatype_exception : public std::runtime_error { + class SD_EXPORT datatype_exception : public std::runtime_error { public: datatype_exception(std::string message); ~datatype_exception() = default; diff --git a/libnd4j/include/exceptions/graph_exception.h b/libnd4j/include/exceptions/graph_exception.h index d7180e6ce07d..7866dc1bb909 100644 --- a/libnd4j/include/exceptions/graph_exception.h +++ b/libnd4j/include/exceptions/graph_exception.h @@ -34,7 +34,7 @@ #endif namespace sd { - class ND4J_EXPORT graph_exception : public std::runtime_error { + class SD_EXPORT graph_exception : public std::runtime_error { protected: Nd4jLong _graphId; std::string _message; diff --git a/libnd4j/include/exceptions/graph_execution_exception.h b/libnd4j/include/exceptions/graph_execution_exception.h index d64b0ccfdecd..612b549f81ff 100644 --- a/libnd4j/include/exceptions/graph_execution_exception.h +++ b/libnd4j/include/exceptions/graph_execution_exception.h @@ -35,7 +35,7 @@ #endif namespace sd { - class ND4J_EXPORT graph_execution_exception: public graph_exception { + class SD_EXPORT graph_execution_exception: public graph_exception { public: explicit graph_execution_exception(Nd4jLong graphId); }; diff --git a/libnd4j/include/exceptions/graph_exists_exception.h b/libnd4j/include/exceptions/graph_exists_exception.h index a50566d90636..edec72303d2c 100644 --- a/libnd4j/include/exceptions/graph_exists_exception.h +++ b/libnd4j/include/exceptions/graph_exists_exception.h @@ -35,7 +35,7 @@ #endif namespace sd { - class ND4J_EXPORT graph_exists_exception: public graph_exception { + class SD_EXPORT graph_exists_exception: public graph_exception { public: explicit graph_exists_exception(Nd4jLong graphId); }; diff --git a/libnd4j/include/exceptions/no_results_exception.h b/libnd4j/include/exceptions/no_results_exception.h index aa144679ac39..97a281826d26 100644 --- a/libnd4j/include/exceptions/no_results_exception.h +++ b/libnd4j/include/exceptions/no_results_exception.h @@ -35,7 +35,7 @@ #endif namespace sd { - class ND4J_EXPORT no_results_exception: public graph_exception { + class SD_EXPORT no_results_exception: public graph_exception { public: explicit no_results_exception(Nd4jLong graphId); }; diff --git a/libnd4j/include/exceptions/unknown_graph_exception.h b/libnd4j/include/exceptions/unknown_graph_exception.h index bcb1e2619dff..39b63cda25d2 100644 --- a/libnd4j/include/exceptions/unknown_graph_exception.h +++ b/libnd4j/include/exceptions/unknown_graph_exception.h @@ -35,7 +35,7 @@ #endif namespace sd { - class ND4J_EXPORT unknown_graph_exception: public graph_exception { + class SD_EXPORT unknown_graph_exception: public graph_exception { public: explicit unknown_graph_exception(Nd4jLong graphId); }; diff --git a/libnd4j/include/execution/AffinityManager.h b/libnd4j/include/execution/AffinityManager.h index 10020d311c02..d054dc885bd9 100644 --- a/libnd4j/include/execution/AffinityManager.h +++ b/libnd4j/include/execution/AffinityManager.h @@ -27,7 +27,7 @@ #include namespace sd { - class ND4J_EXPORT AffinityManager { + class SD_EXPORT AffinityManager { private: static std::atomic _lastDevice; static int _numberOfDevices; diff --git a/libnd4j/include/execution/ContextBuffers.h b/libnd4j/include/execution/ContextBuffers.h index 386d84039444..8b1b88bfac89 100644 --- a/libnd4j/include/execution/ContextBuffers.h +++ b/libnd4j/include/execution/ContextBuffers.h @@ -26,7 +26,7 @@ #include namespace sd { - class ND4J_EXPORT ContextBuffers { + class SD_EXPORT ContextBuffers { private: void* _reductionPointer = nullptr; void* _scalarPointer = nullptr; diff --git a/libnd4j/include/execution/ErrorReference.h b/libnd4j/include/execution/ErrorReference.h index 0a3465dbe56b..108878cc5f65 100644 --- a/libnd4j/include/execution/ErrorReference.h +++ b/libnd4j/include/execution/ErrorReference.h @@ -25,7 +25,7 @@ #include namespace sd { - class ND4J_EXPORT ErrorReference { + class SD_EXPORT ErrorReference { private: int _errorCode = 0; std::string _errorMessage; diff --git a/libnd4j/include/execution/LaunchContext.h b/libnd4j/include/execution/LaunchContext.h index 4eaf2ca0f1bc..308f028498bc 100644 --- a/libnd4j/include/execution/LaunchContext.h +++ b/libnd4j/include/execution/LaunchContext.h @@ -48,7 +48,7 @@ namespace sd { -class ND4J_EXPORT LaunchContext { +class SD_EXPORT LaunchContext { private: static std::vector> _contexts; diff --git a/libnd4j/include/execution/ThreadPool.h b/libnd4j/include/execution/ThreadPool.h index a1ce73f408a2..050ce5dfbac2 100644 --- a/libnd4j/include/execution/ThreadPool.h +++ b/libnd4j/include/execution/ThreadPool.h @@ -33,7 +33,7 @@ #include namespace samediff { - class ND4J_EXPORT ThreadPool { + class SD_EXPORT ThreadPool { private: static ThreadPool* _INSTANCE; diff --git a/libnd4j/include/execution/Threads.h b/libnd4j/include/execution/Threads.h index 2ea8295a824d..26c0c0afd542 100644 --- a/libnd4j/include/execution/Threads.h +++ b/libnd4j/include/execution/Threads.h @@ -27,7 +27,7 @@ #include namespace samediff { - class ND4J_EXPORT ThreadsHelper { + class SD_EXPORT ThreadsHelper { public: static int numberOfThreads(int maxThreads, uint64_t numberOfElements); static int numberOfThreads2d(int maxThreads, uint64_t iters_x, uint64_t iters_y); @@ -36,7 +36,7 @@ namespace samediff { static int pickLoop3d(int numThreads, uint64_t iters_x, uint64_t iters_y, uint64_t iters_z); }; - class ND4J_EXPORT Span { + class SD_EXPORT Span { private: int64_t _startX, _stopX, _incX; public: @@ -50,7 +50,7 @@ namespace samediff { static Span build(uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x); }; - class ND4J_EXPORT Span2 { + class SD_EXPORT Span2 { private: int64_t _startX, _stopX, _incX; int64_t _startY, _stopY, _incY; @@ -70,7 +70,7 @@ namespace samediff { static Span2 build(int loop, uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y); }; - class ND4J_EXPORT Span3 { + class SD_EXPORT Span3 { private: int64_t _startX, _stopX, _incX; int64_t _startY, _stopY, _incY; @@ -94,7 +94,7 @@ namespace samediff { static Span3 build(int loop, uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z); }; - class ND4J_EXPORT Threads { + class SD_EXPORT Threads { public: /** * This function executes 1 dimensional loop for a given number of threads diff --git a/libnd4j/include/execution/Ticket.h b/libnd4j/include/execution/Ticket.h index a44a1e95756c..e6f8fdf6c003 100644 --- a/libnd4j/include/execution/Ticket.h +++ b/libnd4j/include/execution/Ticket.h @@ -29,7 +29,7 @@ #include namespace samediff { - class ND4J_EXPORT Ticket { + class SD_EXPORT Ticket { private: bool _acquired = false; std::vector*> _queues; diff --git a/libnd4j/include/graph/ArgumentsList.h b/libnd4j/include/graph/ArgumentsList.h index 75bdf857a2b3..f9c7ecad3e03 100644 --- a/libnd4j/include/graph/ArgumentsList.h +++ b/libnd4j/include/graph/ArgumentsList.h @@ -29,7 +29,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT ArgumentsList { + class SD_EXPORT ArgumentsList { protected: std::vector _arguments; public: diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index 96d7e8b1201f..f5d17f1e3b41 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -44,7 +44,7 @@ namespace sd { /** * This class defines input desired for any given node/operation within graph */ - class ND4J_EXPORT Context : public sd::graph::ContextPrototype { + class SD_EXPORT Context : public sd::graph::ContextPrototype { protected: sd::memory::Workspace* _workspace = nullptr; sd::graph::VariableSpace* _variableSpace = nullptr; diff --git a/libnd4j/include/graph/ContextPrototype.h b/libnd4j/include/graph/ContextPrototype.h index 57d773dbb391..db27992a110b 100644 --- a/libnd4j/include/graph/ContextPrototype.h +++ b/libnd4j/include/graph/ContextPrototype.h @@ -38,7 +38,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT ContextPrototype { + class SD_EXPORT ContextPrototype { protected: // int ids of the input nodes std::vector> _inputs; diff --git a/libnd4j/include/graph/ExecutorConfiguration.h b/libnd4j/include/graph/ExecutorConfiguration.h index 40f299f02513..61c001381eb1 100644 --- a/libnd4j/include/graph/ExecutorConfiguration.h +++ b/libnd4j/include/graph/ExecutorConfiguration.h @@ -27,7 +27,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT ExecutorConfiguration { + class SD_EXPORT ExecutorConfiguration { public: sd::graph::ProfilingMode _profilingMode; sd::graph::ExecutionMode _executionMode; diff --git a/libnd4j/include/graph/FlatUtils.h b/libnd4j/include/graph/FlatUtils.h index 1b2a02dca841..dc5911708545 100644 --- a/libnd4j/include/graph/FlatUtils.h +++ b/libnd4j/include/graph/FlatUtils.h @@ -29,7 +29,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT FlatUtils { + class SD_EXPORT FlatUtils { public: static std::pair fromIntPair(IntPair* pair); diff --git a/libnd4j/include/graph/FlowPath.h b/libnd4j/include/graph/FlowPath.h index 59752024929e..70bbc727dda1 100644 --- a/libnd4j/include/graph/FlowPath.h +++ b/libnd4j/include/graph/FlowPath.h @@ -32,7 +32,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT FlowPath { + class SD_EXPORT FlowPath { private: MAP_IMPL _states; MAP_IMPL _frames; diff --git a/libnd4j/include/graph/FrameState.h b/libnd4j/include/graph/FrameState.h index 1c0edbc0bbcb..8f97404a0cae 100644 --- a/libnd4j/include/graph/FrameState.h +++ b/libnd4j/include/graph/FrameState.h @@ -27,7 +27,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT FrameState { + class SD_EXPORT FrameState { private: std::string _name; Nd4jLong _id = 0; diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 578ec4998dad..b2653a30b087 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -41,7 +41,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT Graph { + class SD_EXPORT Graph { protected: ExecutorConfiguration *_configuration; VariableSpace *_variableSpace; diff --git a/libnd4j/include/graph/GraphExecutioner.h b/libnd4j/include/graph/GraphExecutioner.h index 2921d30089f9..276c82a1c329 100644 --- a/libnd4j/include/graph/GraphExecutioner.h +++ b/libnd4j/include/graph/GraphExecutioner.h @@ -40,7 +40,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT GraphExecutioner { + class SD_EXPORT GraphExecutioner { protected: diff --git a/libnd4j/include/graph/GraphHolder.h b/libnd4j/include/graph/GraphHolder.h index 07e091f42503..b200bf37cf75 100644 --- a/libnd4j/include/graph/GraphHolder.h +++ b/libnd4j/include/graph/GraphHolder.h @@ -28,7 +28,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT GraphHolder { + class SD_EXPORT GraphHolder { private: static GraphHolder *_INSTANCE; MAP_IMPL _graphF; diff --git a/libnd4j/include/graph/GraphState.h b/libnd4j/include/graph/GraphState.h index 89343997fa40..122fe37a4f83 100644 --- a/libnd4j/include/graph/GraphState.h +++ b/libnd4j/include/graph/GraphState.h @@ -38,7 +38,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT GraphState { + class SD_EXPORT GraphState { protected: // id of this GraphState instance Nd4jLong _id = 0; diff --git a/libnd4j/include/graph/GraphUtils.h b/libnd4j/include/graph/GraphUtils.h index 3aaf820aeb8e..42328a98c6da 100644 --- a/libnd4j/include/graph/GraphUtils.h +++ b/libnd4j/include/graph/GraphUtils.h @@ -28,7 +28,7 @@ namespace sd { namespace graph { -class ND4J_EXPORT GraphUtils { +class SD_EXPORT GraphUtils { public: typedef std::vector OpList; diff --git a/libnd4j/include/graph/InferenceRequest.h b/libnd4j/include/graph/InferenceRequest.h index 38eed9b67e66..2ac698841717 100644 --- a/libnd4j/include/graph/InferenceRequest.h +++ b/libnd4j/include/graph/InferenceRequest.h @@ -28,7 +28,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT InferenceRequest { + class SD_EXPORT InferenceRequest { private: Nd4jLong _id; std::vector _variables; diff --git a/libnd4j/include/graph/Intervals.h b/libnd4j/include/graph/Intervals.h index 3a796407608d..939a65d0538c 100644 --- a/libnd4j/include/graph/Intervals.h +++ b/libnd4j/include/graph/Intervals.h @@ -28,7 +28,7 @@ namespace sd { - class ND4J_EXPORT Intervals { + class SD_EXPORT Intervals { private: std::vector> _content; diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 6657687cb0e1..84cce5f7cd09 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -36,7 +36,7 @@ namespace sd { class Graph; - class ND4J_EXPORT Node { + class SD_EXPORT Node { protected: // TODO: this field must be removed sd::DataType _dataType; diff --git a/libnd4j/include/graph/NodeState.h b/libnd4j/include/graph/NodeState.h index 5e0a7a6d2dcf..3ed3020fb4de 100644 --- a/libnd4j/include/graph/NodeState.h +++ b/libnd4j/include/graph/NodeState.h @@ -26,7 +26,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT NodeState { + class SD_EXPORT NodeState { private: // inner time spent on specific node Nd4jLong _inner = 0; diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index f75504408f21..8c7bd749a055 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -21,20 +21,24 @@ #define SD_OPTIMIZEDGRAPH_H #include +#include #include #include +#include namespace sd { namespace graph { /** * This class acts as a topologically sorted & optimized Graph representation, ready for execution */ - class ND4J_EXPORT OptimizedGraph { + class SD_EXPORT OptimizedGraph { protected: // here we store independent OpSequences // Graph starts from layer 0, and goes deeper step by step // on each layer we can have 1+ OpSequences that can be executed independent - std::map> _onion; + std::map _onion; + + std::mutex _mutex; public: OptimizedGraph() = default; ~OptimizedGraph() = default; @@ -61,13 +65,14 @@ namespace sd { * @param index * @return */ - const std::vector& layer(uint64_t index) const; + const ExecutionLayer& layer(uint64_t index) const; /** * This method allows to append layer to this OptimizedGraph instance */ // FIXME: this method should be removed or made private void append(const std::vector &layer); + void append(const ExecutionLayer &layer); void append(OpSequence &sequence); }; } diff --git a/libnd4j/include/graph/RandomGenerator.h b/libnd4j/include/graph/RandomGenerator.h index ef06c345d611..755fce062910 100644 --- a/libnd4j/include/graph/RandomGenerator.h +++ b/libnd4j/include/graph/RandomGenerator.h @@ -38,7 +38,7 @@ namespace sd { namespace graph { #ifdef __CUDACC__ - class ND4J_EXPORT CudaManagedRandomGenerator { + class SD_EXPORT CudaManagedRandomGenerator { private: protected: @@ -59,9 +59,9 @@ namespace sd { } }; - class ND4J_EXPORT RandomGenerator : public CudaManagedRandomGenerator { + class SD_EXPORT RandomGenerator : public CudaManagedRandomGenerator { #else - class ND4J_EXPORT RandomGenerator { + class SD_EXPORT RandomGenerator { #endif private: #ifndef __CUDACC__ diff --git a/libnd4j/include/graph/ResultWrapper.h b/libnd4j/include/graph/ResultWrapper.h index fe5193097818..4b21d946fa69 100644 --- a/libnd4j/include/graph/ResultWrapper.h +++ b/libnd4j/include/graph/ResultWrapper.h @@ -27,7 +27,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT ResultWrapper { + class SD_EXPORT ResultWrapper { private: Nd4jLong _size = 0L; Nd4jPointer _pointer = nullptr; diff --git a/libnd4j/include/graph/Scope.h b/libnd4j/include/graph/Scope.h index 42b99c18e9bc..a78f9fe97471 100644 --- a/libnd4j/include/graph/Scope.h +++ b/libnd4j/include/graph/Scope.h @@ -34,7 +34,7 @@ namespace sd { * * @tparam T */ - class ND4J_EXPORT Scope { + class SD_EXPORT Scope { protected: // Graph-unique IDs for Scope instances int _id; diff --git a/libnd4j/include/graph/SessionLocalStorage.h b/libnd4j/include/graph/SessionLocalStorage.h index 3cb77ec3a5f0..b3636f2615b1 100644 --- a/libnd4j/include/graph/SessionLocalStorage.h +++ b/libnd4j/include/graph/SessionLocalStorage.h @@ -31,7 +31,7 @@ namespace sd{ namespace graph { - class ND4J_EXPORT SessionLocalStorage { + class SD_EXPORT SessionLocalStorage { protected: std::atomic _sessionCounter; MAP_IMPL _threadSession; diff --git a/libnd4j/include/graph/Stash.h b/libnd4j/include/graph/Stash.h index ba431d05756e..2e9e832eaf4f 100644 --- a/libnd4j/include/graph/Stash.h +++ b/libnd4j/include/graph/Stash.h @@ -32,7 +32,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT KeyPair { + class SD_EXPORT KeyPair { int _node; std::string _name; public: @@ -54,7 +54,7 @@ namespace sd { namespace std { template <> - class ND4J_EXPORT hash { + class SD_EXPORT hash { public: size_t operator()(const sd::graph::KeyPair& k) const; }; @@ -64,7 +64,7 @@ namespace std { namespace sd { namespace graph { - class ND4J_EXPORT Stash { + class SD_EXPORT Stash { protected: std::map _stash; std::vector _handles; diff --git a/libnd4j/include/graph/Status.h b/libnd4j/include/graph/Status.h index 42794488dd6b..2f055dc4dbdf 100644 --- a/libnd4j/include/graph/Status.h +++ b/libnd4j/include/graph/Status.h @@ -27,7 +27,7 @@ #include namespace sd { - class ND4J_EXPORT Status { + class SD_EXPORT Status { public: static FORCEINLINE Nd4jStatus OK() { return ND4J_STATUS_OK; diff --git a/libnd4j/include/graph/TimeHolder.h b/libnd4j/include/graph/TimeHolder.h index 191a75bace13..110cdf2f5d28 100644 --- a/libnd4j/include/graph/TimeHolder.h +++ b/libnd4j/include/graph/TimeHolder.h @@ -27,7 +27,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT TimeHolder { + class SD_EXPORT TimeHolder { private: std::map _outer; std::map _inner; diff --git a/libnd4j/include/graph/Variable.h b/libnd4j/include/graph/Variable.h index b3ac74533f34..61f37dbfb5d7 100644 --- a/libnd4j/include/graph/Variable.h +++ b/libnd4j/include/graph/Variable.h @@ -34,19 +34,19 @@ namespace std { template <> - class ND4J_EXPORT hash> { + class SD_EXPORT hash> { public: size_t operator()(const std::pair& k) const; }; template <> - class ND4J_EXPORT hash { + class SD_EXPORT hash { public: size_t operator()(const bfloat16& k) const; }; template <> - class ND4J_EXPORT hash { + class SD_EXPORT hash { public: size_t operator()(const float16& k) const; }; @@ -56,7 +56,7 @@ namespace std { namespace sd { namespace graph { - class ND4J_EXPORT Variable { + class SD_EXPORT Variable { protected: int _id = 0; int _index = 0; @@ -93,7 +93,7 @@ namespace sd { Variable* clone(); template - ND4J_EXPORT Variable* asT(); + SD_EXPORT Variable* asT(); bool hasNDArray(); sd::NDArray* getNDArray(); diff --git a/libnd4j/include/graph/VariableProxy.h b/libnd4j/include/graph/VariableProxy.h index b917a9192bcf..8d09f84a93fd 100644 --- a/libnd4j/include/graph/VariableProxy.h +++ b/libnd4j/include/graph/VariableProxy.h @@ -22,7 +22,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT VariableProxy: public VariableSpace { + class SD_EXPORT VariableProxy: public VariableSpace { protected: VariableSpace* _backed = nullptr; VariableSpace* _current = nullptr; diff --git a/libnd4j/include/graph/VariableSpace.h b/libnd4j/include/graph/VariableSpace.h index 619725e7a1d1..fded177660a8 100644 --- a/libnd4j/include/graph/VariableSpace.h +++ b/libnd4j/include/graph/VariableSpace.h @@ -38,7 +38,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT VariableSpace { + class SD_EXPORT VariableSpace { protected: sd::memory::Workspace *_workspace; diff --git a/libnd4j/include/graph/VariablesSet.h b/libnd4j/include/graph/VariablesSet.h index 682b7fce4dab..b3e4844df60f 100644 --- a/libnd4j/include/graph/VariablesSet.h +++ b/libnd4j/include/graph/VariablesSet.h @@ -30,7 +30,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT VariablesSet { + class SD_EXPORT VariablesSet { protected: std::vector _holder; Nd4jStatus _status; diff --git a/libnd4j/include/graph/execution/ExecutionLayer.h b/libnd4j/include/graph/execution/ExecutionLayer.h new file mode 100644 index 000000000000..f119de8a778d --- /dev/null +++ b/libnd4j/include/graph/execution/ExecutionLayer.h @@ -0,0 +1,70 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + + +#ifndef SD_EXECUTIONLAYER_H +#define SD_EXECUTIONLAYER_H + +#include +#include +#include + +namespace sd { + namespace graph { + class SD_EXPORT ExecutionLayer { + protected: + std::vector _sequences; + public: + ExecutionLayer(const std::vector &sequences = {}); + ~ExecutionLayer() = default; + + ExecutionLayer(const ExecutionLayer& other) noexcept; + + ExecutionLayer& operator=(const ExecutionLayer& other) noexcept; + + // move constructor + ExecutionLayer(ExecutionLayer&& other) noexcept; + + // move assignment operator + ExecutionLayer& operator=(ExecutionLayer&& other) noexcept; + + /** + * This method returns number of sequences in this layer + * @return + */ + uint64_t width() const; + + /** + * This method returns specified OpSequence from this layer + * @return + */ + OpSequence at(uint64_t index) const; + OpSequence operator[](uint64_t index) const; + + /** + * This method appends OpSequence to the end of this layer + * @param sequence + */ + void append(const OpSequence &sequence); + }; + } +} + +#endif //SD_EXECUTIONLAYER_H diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index 56e4ffcd2779..3c6e163933ed 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -30,7 +30,7 @@ namespace sd { /** * This class represents independent and immutable sequence of operations */ - class ND4J_EXPORT OpSequence : public std::iterator> { + class SD_EXPORT OpSequence : public std::iterator> { // our internal iterator for OpSequence class iterator; protected: diff --git a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp new file mode 100644 index 000000000000..64db317964e0 --- /dev/null +++ b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp @@ -0,0 +1,72 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + + +#include + +namespace sd { + namespace graph { + ExecutionLayer::ExecutionLayer(const std::vector &sequences) { + _sequences = sequences; + } + + uint64_t ExecutionLayer::width() const { + return _sequences.size(); + } + + OpSequence ExecutionLayer::at(uint64_t index) const { + return _sequences[index]; + } + + OpSequence ExecutionLayer::operator[](uint64_t index) const { + return at(index); + } + + void ExecutionLayer::append(const OpSequence &sequence) { + _sequences.emplace_back(sequence); + } + + ExecutionLayer::ExecutionLayer(const ExecutionLayer &other) noexcept { + _sequences = other._sequences; + } + + ExecutionLayer &ExecutionLayer::operator=(const ExecutionLayer &other) noexcept { + if (this == &other) + return *this; + + _sequences = other._sequences; + + return *this; + } + + ExecutionLayer::ExecutionLayer(ExecutionLayer &&other) noexcept { + _sequences = std::move(other._sequences); + } + + ExecutionLayer &ExecutionLayer::operator=(ExecutionLayer &&other) noexcept { + if (this == &other) + return *this; + + _sequences = std::move(other._sequences); + + return *this; + } + } +} \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index b288096c28b3..9ad1285fb10a 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -319,7 +319,7 @@ namespace sd { node->_dataType = DataTypeUtils::fromT(); return node; } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT Node* Node::asT, (), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template SD_EXPORT Node* Node::asT, (), LIBND4J_TYPES); Node::Node(const std::string &opName, const int id, const std::vector> &inputs, const std::vector &tArgs, const std::vector &iArgs) { diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 3e706107f30c..5401abc32ca6 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -52,16 +52,22 @@ namespace sd { return _onion.size(); } - const std::vector &OptimizedGraph::layer(uint64_t index) const { + const ExecutionLayer &OptimizedGraph::layer(uint64_t index) const { return _onion.at(index); } void OptimizedGraph::append(const std::vector &layer) { + std::lock_guard lock(_mutex); _onion[_onion.size()] = layer; } void OptimizedGraph::append(OpSequence &sequence) { - append(std::vector{sequence}); + append(ExecutionLayer({sequence})); + } + + void OptimizedGraph::append(const ExecutionLayer &layer) { + std::lock_guard lock(_mutex); + _onion[_onion.size()] = layer; } } } \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index e87c51897ebb..4e990e9b9ad2 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -50,7 +50,7 @@ namespace sd { return result; } - BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT Variable* Variable::asT, (), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template SD_EXPORT Variable* Variable::asT, (), LIBND4J_TYPES); sd::graph::Variable* sd::graph::Variable::clone() { auto result = new Variable(this->isPlaceholder()); diff --git a/libnd4j/include/graph/profiling/GraphProfile.h b/libnd4j/include/graph/profiling/GraphProfile.h index f0ada4f90f19..8e76b0436683 100644 --- a/libnd4j/include/graph/profiling/GraphProfile.h +++ b/libnd4j/include/graph/profiling/GraphProfile.h @@ -31,7 +31,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT GraphProfile { + class SD_EXPORT GraphProfile { private: // this variable Nd4jLong _merges = 1L; diff --git a/libnd4j/include/graph/profiling/NodeProfile.h b/libnd4j/include/graph/profiling/NodeProfile.h index 871eb57483fb..c7dcaff50a9a 100644 --- a/libnd4j/include/graph/profiling/NodeProfile.h +++ b/libnd4j/include/graph/profiling/NodeProfile.h @@ -28,7 +28,7 @@ namespace sd { namespace graph { - class ND4J_EXPORT NodeProfile { + class SD_EXPORT NodeProfile { private: int _id; std::string _name; diff --git a/libnd4j/include/helpers/AttentionHelper.h b/libnd4j/include/helpers/AttentionHelper.h index 02d9da995541..a2c52c61b403 100644 --- a/libnd4j/include/helpers/AttentionHelper.h +++ b/libnd4j/include/helpers/AttentionHelper.h @@ -24,7 +24,7 @@ #include "array/NDArray.h" namespace sd { - class ND4J_EXPORT AttentionHelper { + class SD_EXPORT AttentionHelper { public: static sd::NDArray multiHeadProject(const sd::NDArray* input, const sd::NDArray* projectionMatrix, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); diff --git a/libnd4j/include/helpers/BenchmarkHelper.h b/libnd4j/include/helpers/BenchmarkHelper.h index 17c427cd8078..38e3826535c1 100644 --- a/libnd4j/include/helpers/BenchmarkHelper.h +++ b/libnd4j/include/helpers/BenchmarkHelper.h @@ -44,7 +44,7 @@ namespace sd { - class ND4J_EXPORT BenchmarkHelper { + class SD_EXPORT BenchmarkHelper { private: unsigned int _wIterations; unsigned int _rIterations; diff --git a/libnd4j/include/helpers/BitwiseUtils.h b/libnd4j/include/helpers/BitwiseUtils.h index 6b7e5c231fcf..38bcd1716ed4 100644 --- a/libnd4j/include/helpers/BitwiseUtils.h +++ b/libnd4j/include/helpers/BitwiseUtils.h @@ -28,7 +28,7 @@ #include namespace sd { - class ND4J_EXPORT BitwiseUtils { + class SD_EXPORT BitwiseUtils { public: diff --git a/libnd4j/include/helpers/ConstantHelper.h b/libnd4j/include/helpers/ConstantHelper.h index 7c519d6351e5..44e4a71c7d14 100644 --- a/libnd4j/include/helpers/ConstantHelper.h +++ b/libnd4j/include/helpers/ConstantHelper.h @@ -33,7 +33,7 @@ #include namespace sd { - class ND4J_EXPORT ConstantHelper { + class SD_EXPORT ConstantHelper { private: static ConstantHelper* _INSTANCE; ConstantHelper(); diff --git a/libnd4j/include/helpers/ConstantShapeHelper.h b/libnd4j/include/helpers/ConstantShapeHelper.h index e821e55f5a89..f015eec85472 100644 --- a/libnd4j/include/helpers/ConstantShapeHelper.h +++ b/libnd4j/include/helpers/ConstantShapeHelper.h @@ -33,7 +33,7 @@ namespace sd { - class ND4J_EXPORT ConstantShapeHelper { + class SD_EXPORT ConstantShapeHelper { private: static ConstantShapeHelper *_INSTANCE; diff --git a/libnd4j/include/helpers/ConstantTadHelper.h b/libnd4j/include/helpers/ConstantTadHelper.h index c6bb834fb30f..ef1aa4157c89 100644 --- a/libnd4j/include/helpers/ConstantTadHelper.h +++ b/libnd4j/include/helpers/ConstantTadHelper.h @@ -33,7 +33,7 @@ #include namespace sd { - class ND4J_EXPORT ConstantTadHelper { + class SD_EXPORT ConstantTadHelper { private: static ConstantTadHelper *_INSTANCE; diff --git a/libnd4j/include/helpers/CudaLaunchHelper.h b/libnd4j/include/helpers/CudaLaunchHelper.h index 6bf44317fd86..3e933c546476 100644 --- a/libnd4j/include/helpers/CudaLaunchHelper.h +++ b/libnd4j/include/helpers/CudaLaunchHelper.h @@ -28,7 +28,7 @@ #include namespace sd { - class ND4J_EXPORT CudaLaunchHelper { + class SD_EXPORT CudaLaunchHelper { public: static Triple getFlatLaunchParams(Nd4jLong length, int SM, int CORES, int SHARED_MEMORY); static int getReductionBlocks(Nd4jLong xLength, int blockSize = 512); diff --git a/libnd4j/include/helpers/DebugHelper.h b/libnd4j/include/helpers/DebugHelper.h index b0387dd8cc2b..0e1bd28374ec 100644 --- a/libnd4j/include/helpers/DebugHelper.h +++ b/libnd4j/include/helpers/DebugHelper.h @@ -38,7 +38,7 @@ #include namespace sd { class NDArray; - class ND4J_EXPORT DebugHelper { + class SD_EXPORT DebugHelper { public: // cuda-specific debug functions diff --git a/libnd4j/include/helpers/DebugInfo.h b/libnd4j/include/helpers/DebugInfo.h index c2efb00fe576..6a2487c537b5 100644 --- a/libnd4j/include/helpers/DebugInfo.h +++ b/libnd4j/include/helpers/DebugInfo.h @@ -38,7 +38,7 @@ #endif namespace sd { - struct ND4J_EXPORT DebugInfo { + struct SD_EXPORT DebugInfo { double _minValue; double _maxValue; double _meanValue; diff --git a/libnd4j/include/helpers/FileUtils.h b/libnd4j/include/helpers/FileUtils.h index 85e4bf2e3fc6..aea296119fb2 100644 --- a/libnd4j/include/helpers/FileUtils.h +++ b/libnd4j/include/helpers/FileUtils.h @@ -25,7 +25,7 @@ #include namespace sd { - class ND4J_EXPORT FileUtils { + class SD_EXPORT FileUtils { public: static bool fileExists(const char *filename); diff --git a/libnd4j/include/helpers/GradCheck.h b/libnd4j/include/helpers/GradCheck.h index 0d184a5a1574..baf9adb3871c 100644 --- a/libnd4j/include/helpers/GradCheck.h +++ b/libnd4j/include/helpers/GradCheck.h @@ -27,7 +27,7 @@ namespace sd { -class ND4J_EXPORT GradCheck { +class SD_EXPORT GradCheck { public: enum LossFunc {MEAN = 0, SUM = 1}; diff --git a/libnd4j/include/helpers/LoopKind.h b/libnd4j/include/helpers/LoopKind.h index 95e9238ada2e..7a2995dacb63 100644 --- a/libnd4j/include/helpers/LoopKind.h +++ b/libnd4j/include/helpers/LoopKind.h @@ -34,7 +34,7 @@ namespace sd { -class ND4J_EXPORT LoopKind { +class SD_EXPORT LoopKind { public: enum Kind { SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y, BROADCAST_3D, BROADCAST_4D, BROADCAST_5D }; diff --git a/libnd4j/include/helpers/Loops.h b/libnd4j/include/helpers/Loops.h index 508b84f20a38..063983296a8e 100644 --- a/libnd4j/include/helpers/Loops.h +++ b/libnd4j/include/helpers/Loops.h @@ -36,7 +36,7 @@ namespace sd { template - class ND4J_EXPORT ReductionLoops { + class SD_EXPORT ReductionLoops { protected: public: @@ -54,7 +54,7 @@ namespace sd { }; template - class ND4J_EXPORT ReductionBoolLoops : public ReductionLoops { + class SD_EXPORT ReductionBoolLoops : public ReductionLoops { public: static void wrapper(const int opNum, X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); @@ -63,7 +63,7 @@ namespace sd { }; template - class ND4J_EXPORT ReductionLongLoops : public ReductionLoops { + class SD_EXPORT ReductionLongLoops : public ReductionLoops { public: static void wrapper(const int opNum, X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); @@ -72,7 +72,7 @@ namespace sd { }; template - class ND4J_EXPORT ReductionSameLoops : public ReductionLoops { + class SD_EXPORT ReductionSameLoops : public ReductionLoops { public: static void wrapper(const int opNum, X* x, Nd4jLong* xShapeInfo, X* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); @@ -82,7 +82,7 @@ namespace sd { template - class ND4J_EXPORT IndexReductionLoops { + class SD_EXPORT IndexReductionLoops { private: public: static void wrapIndexReduce(const int opNum, void* x, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* extraParams); @@ -93,7 +93,7 @@ namespace sd { template - class ND4J_EXPORT TransformLoops { + class SD_EXPORT TransformLoops { public: @@ -102,7 +102,7 @@ namespace sd { }; template - class ND4J_EXPORT Reduction3Loops { + class SD_EXPORT Reduction3Loops { public: template diff --git a/libnd4j/include/helpers/MmulHelper.h b/libnd4j/include/helpers/MmulHelper.h index 6e38be5c1cd2..5b1aaf734c18 100644 --- a/libnd4j/include/helpers/MmulHelper.h +++ b/libnd4j/include/helpers/MmulHelper.h @@ -25,7 +25,7 @@ #include "array/NDArray.h" namespace sd { - class ND4J_EXPORT MmulHelper { + class SD_EXPORT MmulHelper { private: diff --git a/libnd4j/include/helpers/OmpLaunchHelper.h b/libnd4j/include/helpers/OmpLaunchHelper.h index 3e0e50391472..574a3182ceec 100644 --- a/libnd4j/include/helpers/OmpLaunchHelper.h +++ b/libnd4j/include/helpers/OmpLaunchHelper.h @@ -28,7 +28,7 @@ namespace sd { -class ND4J_EXPORT OmpLaunchHelper { +class SD_EXPORT OmpLaunchHelper { public: diff --git a/libnd4j/include/helpers/OpArgsHolder.h b/libnd4j/include/helpers/OpArgsHolder.h index a9432f134aa9..0850181a0634 100644 --- a/libnd4j/include/helpers/OpArgsHolder.h +++ b/libnd4j/include/helpers/OpArgsHolder.h @@ -27,7 +27,7 @@ namespace sd { -class ND4J_EXPORT OpArgsHolder { +class SD_EXPORT OpArgsHolder { private: diff --git a/libnd4j/include/helpers/OpBenchmark.h b/libnd4j/include/helpers/OpBenchmark.h index dfd303a626c5..8e665cb70089 100644 --- a/libnd4j/include/helpers/OpBenchmark.h +++ b/libnd4j/include/helpers/OpBenchmark.h @@ -28,7 +28,7 @@ #include namespace sd { - class ND4J_EXPORT OpBenchmark { + class SD_EXPORT OpBenchmark { protected: int _opNum = 0; std::string _testName; diff --git a/libnd4j/include/helpers/OpTracker.h b/libnd4j/include/helpers/OpTracker.h index 122f4f32be79..1d38fabe38a3 100644 --- a/libnd4j/include/helpers/OpTracker.h +++ b/libnd4j/include/helpers/OpTracker.h @@ -30,7 +30,7 @@ #include namespace sd { - class ND4J_EXPORT OpTracker { + class SD_EXPORT OpTracker { private: static OpTracker* _INSTANCE; diff --git a/libnd4j/include/helpers/PointersManager.h b/libnd4j/include/helpers/PointersManager.h index 4f7af94098eb..ba01713df739 100644 --- a/libnd4j/include/helpers/PointersManager.h +++ b/libnd4j/include/helpers/PointersManager.h @@ -30,7 +30,7 @@ namespace sd { -class ND4J_EXPORT PointersManager { +class SD_EXPORT PointersManager { private: diff --git a/libnd4j/include/helpers/RandomLauncher.h b/libnd4j/include/helpers/RandomLauncher.h index 49e961062d96..5eec7e0bd11d 100644 --- a/libnd4j/include/helpers/RandomLauncher.h +++ b/libnd4j/include/helpers/RandomLauncher.h @@ -24,7 +24,7 @@ #include namespace sd { - class ND4J_EXPORT RandomLauncher { + class SD_EXPORT RandomLauncher { public: static void applyDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr); static void applyInvertedDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr); diff --git a/libnd4j/include/helpers/ShapeBuilders.h b/libnd4j/include/helpers/ShapeBuilders.h index b1e5a3eba876..fbbfb71500ef 100644 --- a/libnd4j/include/helpers/ShapeBuilders.h +++ b/libnd4j/include/helpers/ShapeBuilders.h @@ -29,7 +29,7 @@ #include namespace sd { - class ND4J_EXPORT ShapeBuilders { + class SD_EXPORT ShapeBuilders { public: static Nd4jLong* createScalarShapeInfo(sd::DataType dataType, sd::memory::Workspace* workspace = nullptr); diff --git a/libnd4j/include/helpers/ShapeUtils.h b/libnd4j/include/helpers/ShapeUtils.h index 8d2a119c33bd..cd4cb0bf97fa 100644 --- a/libnd4j/include/helpers/ShapeUtils.h +++ b/libnd4j/include/helpers/ShapeUtils.h @@ -26,7 +26,7 @@ namespace sd { - class ND4J_EXPORT ShapeUtils { + class SD_EXPORT ShapeUtils { public: diff --git a/libnd4j/include/helpers/SimpleReadWriteLock.h b/libnd4j/include/helpers/SimpleReadWriteLock.h index 7b87ef5eafea..5079eba65452 100644 --- a/libnd4j/include/helpers/SimpleReadWriteLock.h +++ b/libnd4j/include/helpers/SimpleReadWriteLock.h @@ -32,7 +32,7 @@ * Basic idea: write lock won't be obtained before all read requests served */ namespace sd { - class ND4J_EXPORT SimpleReadWriteLock { + class SD_EXPORT SimpleReadWriteLock { private: std::atomic _read_locks; std::atomic _write_locks; diff --git a/libnd4j/include/helpers/StringUtils.h b/libnd4j/include/helpers/StringUtils.h index ef9586637eea..fe96e0287da9 100644 --- a/libnd4j/include/helpers/StringUtils.h +++ b/libnd4j/include/helpers/StringUtils.h @@ -32,7 +32,7 @@ #include namespace sd { - class ND4J_EXPORT StringUtils { + class SD_EXPORT StringUtils { public: template static FORCEINLINE std::string valueToString(T value) { diff --git a/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h b/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h index 6fc1822d0433..2214283eae46 100644 --- a/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h +++ b/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h @@ -24,7 +24,7 @@ #define SD_BROADCASTBENCHMARK_H namespace sd { - class ND4J_EXPORT BroadcastBenchmark : public OpBenchmark { + class SD_EXPORT BroadcastBenchmark : public OpBenchmark { public: BroadcastBenchmark() : OpBenchmark() { // diff --git a/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h b/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h index 032d05d0a677..7abaa232bedf 100644 --- a/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h +++ b/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h @@ -30,7 +30,7 @@ #include namespace sd { - class ND4J_EXPORT DeclarableBenchmark : public OpBenchmark { + class SD_EXPORT DeclarableBenchmark : public OpBenchmark { protected: sd::ops::DeclarableOp *_op = nullptr; sd::graph::Context *_context = nullptr; diff --git a/libnd4j/include/helpers/benchmark/MatrixBenchmark.h b/libnd4j/include/helpers/benchmark/MatrixBenchmark.h index 139806bb4a51..f307b41287da 100644 --- a/libnd4j/include/helpers/benchmark/MatrixBenchmark.h +++ b/libnd4j/include/helpers/benchmark/MatrixBenchmark.h @@ -25,7 +25,7 @@ #define SD_MATRIXBENCHMARK_H namespace sd { - class ND4J_EXPORT MatrixBenchmark : public OpBenchmark { + class SD_EXPORT MatrixBenchmark : public OpBenchmark { private: float _alpha = 1.0f; float _beta = 0.0f; diff --git a/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h b/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h index 58ac17b4f72f..76d08db760e1 100644 --- a/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h +++ b/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h @@ -26,7 +26,7 @@ using namespace sd::graph; namespace sd { - class ND4J_EXPORT PairwiseBenchmark : public OpBenchmark { + class SD_EXPORT PairwiseBenchmark : public OpBenchmark { public: PairwiseBenchmark() : OpBenchmark() { // diff --git a/libnd4j/include/helpers/benchmark/ReductionBenchmark.h b/libnd4j/include/helpers/benchmark/ReductionBenchmark.h index e5396af41060..78c41171541d 100644 --- a/libnd4j/include/helpers/benchmark/ReductionBenchmark.h +++ b/libnd4j/include/helpers/benchmark/ReductionBenchmark.h @@ -28,7 +28,7 @@ using namespace sd::graph; namespace sd { - class ND4J_EXPORT ReductionBenchmark : public OpBenchmark { + class SD_EXPORT ReductionBenchmark : public OpBenchmark { protected: int _opType; //0=Float, 1=Same public: diff --git a/libnd4j/include/helpers/benchmark/ScalarBenchmark.h b/libnd4j/include/helpers/benchmark/ScalarBenchmark.h index 9063f3c526db..f75f4a5a4fed 100644 --- a/libnd4j/include/helpers/benchmark/ScalarBenchmark.h +++ b/libnd4j/include/helpers/benchmark/ScalarBenchmark.h @@ -25,7 +25,7 @@ using namespace sd::graph; namespace sd { - class ND4J_EXPORT ScalarBenchmark : public OpBenchmark { + class SD_EXPORT ScalarBenchmark : public OpBenchmark { public: ScalarBenchmark() : OpBenchmark() { // diff --git a/libnd4j/include/helpers/benchmark/TransformBenchmark.h b/libnd4j/include/helpers/benchmark/TransformBenchmark.h index 81a388edb39d..eabb956dbed5 100644 --- a/libnd4j/include/helpers/benchmark/TransformBenchmark.h +++ b/libnd4j/include/helpers/benchmark/TransformBenchmark.h @@ -25,7 +25,7 @@ using namespace sd::graph; namespace sd { - class ND4J_EXPORT TransformBenchmark : public OpBenchmark { + class SD_EXPORT TransformBenchmark : public OpBenchmark { protected: int _opType; // 0=StrictOps, 1=Same, 2=Any, 3=Float diff --git a/libnd4j/include/helpers/cpu/householder.cpp b/libnd4j/include/helpers/cpu/householder.cpp index 69d4ca3dbcc5..39d97f1e1244 100644 --- a/libnd4j/include/helpers/cpu/householder.cpp +++ b/libnd4j/include/helpers/cpu/householder.cpp @@ -205,10 +205,10 @@ void Householder::mulRight(NDArray& matrix, const NDArray& tail, const T coef } -template class ND4J_EXPORT Householder; -template class ND4J_EXPORT Householder; -template class ND4J_EXPORT Householder; -template class ND4J_EXPORT Householder; +template class SD_EXPORT Householder; +template class SD_EXPORT Householder; +template class SD_EXPORT Householder; +template class SD_EXPORT Householder; diff --git a/libnd4j/include/helpers/cpu/jacobiSVD.cpp b/libnd4j/include/helpers/cpu/jacobiSVD.cpp index 372a2a4092d5..1be91d50b631 100644 --- a/libnd4j/include/helpers/cpu/jacobiSVD.cpp +++ b/libnd4j/include/helpers/cpu/jacobiSVD.cpp @@ -412,10 +412,10 @@ void JacobiSVD::evalData(const NDArray& matrix) { -template class ND4J_EXPORT JacobiSVD; -template class ND4J_EXPORT JacobiSVD; -template class ND4J_EXPORT JacobiSVD; -template class ND4J_EXPORT JacobiSVD; +template class SD_EXPORT JacobiSVD; +template class SD_EXPORT JacobiSVD; +template class SD_EXPORT JacobiSVD; +template class SD_EXPORT JacobiSVD; diff --git a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_0.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_0.cpp index f721c5994eb5..f041adc66d81 100644 --- a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_0.cpp +++ b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_0.cpp @@ -56,5 +56,5 @@ namespace sd { #endif } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_0); } diff --git a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_1.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_1.cpp index 19a248896560..21fde4ff74fd 100644 --- a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_1.cpp +++ b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_1.cpp @@ -56,5 +56,5 @@ namespace sd { #endif } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_1); } diff --git a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_2.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_2.cpp index e90050e4e4a5..38e576695c54 100644 --- a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_2.cpp +++ b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_2.cpp @@ -56,5 +56,5 @@ namespace sd { #endif } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_2); } diff --git a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_3.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_3.cpp index d109d1013aa1..066401dda23e 100644 --- a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_3.cpp +++ b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_3.cpp @@ -56,5 +56,5 @@ namespace sd { #endif } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_3); } diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp index 31ec60d939bd..8826c3503a68 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp @@ -42,6 +42,6 @@ namespace sd { #endif } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionBoolLoops, , LIBND4J_TYPES, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionBoolLoops, , LIBND4J_TYPES, BOOL_TYPES); } diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_0.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_0.cpp index f4243d1c98af..bf3a5efb0a2c 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_0.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_0.cpp @@ -43,7 +43,7 @@ namespace sd { #endif } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_0); } diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_1.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_1.cpp index 1c5b46d40c7b..445782c05769 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_1.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_1.cpp @@ -43,7 +43,7 @@ namespace sd { #endif } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_1); } diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_2.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_2.cpp index 08ca08cdb7cb..5b6b23e7b075 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_2.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_2.cpp @@ -43,7 +43,7 @@ namespace sd { #endif } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_2); } diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_3.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_3.cpp index 7735c21253da..f8bf4322fc86 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_3.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_3.cpp @@ -43,7 +43,7 @@ namespace sd { #endif } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_3); } diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp index e4f4ab2e0cc0..ad82388dacda 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp @@ -48,5 +48,5 @@ namespace sd { #endif } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionLongLoops, , LIBND4J_TYPES, LONG_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionLongLoops, , LIBND4J_TYPES, LONG_TYPES); } diff --git a/libnd4j/include/helpers/cpu/svd.cpp b/libnd4j/include/helpers/cpu/svd.cpp index 4e257b267926..1db9d60420ec 100644 --- a/libnd4j/include/helpers/cpu/svd.cpp +++ b/libnd4j/include/helpers/cpu/svd.cpp @@ -926,7 +926,7 @@ void SVD::evalData(const NDArray& matrix) { } -BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT SVD,,FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT SVD,,FLOAT_TYPES); diff --git a/libnd4j/include/helpers/cublasHelper.h b/libnd4j/include/helpers/cublasHelper.h index 8e690778936d..26de981f8b4b 100644 --- a/libnd4j/include/helpers/cublasHelper.h +++ b/libnd4j/include/helpers/cublasHelper.h @@ -27,7 +27,7 @@ #include namespace sd { - class ND4J_EXPORT CublasHelper { + class SD_EXPORT CublasHelper { private: static CublasHelper *_INSTANCE; static std::mutex _mutex; diff --git a/libnd4j/include/helpers/helper_generator.h b/libnd4j/include/helpers/helper_generator.h index ecf87ae81f10..760cd0e24c45 100644 --- a/libnd4j/include/helpers/helper_generator.h +++ b/libnd4j/include/helpers/helper_generator.h @@ -52,7 +52,7 @@ namespace sd { namespace random { #ifdef __CUDACC__ - class ND4J_EXPORT CudaManaged { + class SD_EXPORT CudaManaged { private: protected: @@ -70,9 +70,9 @@ namespace sd { } }; - class ND4J_EXPORT RandomBuffer : public CudaManaged { + class SD_EXPORT RandomBuffer : public CudaManaged { #else - class ND4J_EXPORT RandomBuffer { + class SD_EXPORT RandomBuffer { #endif private: void *devHolder; @@ -511,7 +511,7 @@ namespace sd { }; - class ND4J_EXPORT IGenerator { + class SD_EXPORT IGenerator { protected: Nd4jLong limit; Nd4jLong seed; @@ -549,7 +549,7 @@ namespace sd { - class ND4J_EXPORT Xoroshiro128 : public IGenerator { + class SD_EXPORT Xoroshiro128 : public IGenerator { protected: uint64_t state[2]; diff --git a/libnd4j/include/helpers/helper_hash.h b/libnd4j/include/helpers/helper_hash.h index 1b032238fad5..0ff76eea1bcc 100644 --- a/libnd4j/include/helpers/helper_hash.h +++ b/libnd4j/include/helpers/helper_hash.h @@ -29,7 +29,7 @@ namespace sd { namespace ops { - class ND4J_EXPORT HashHelper { + class SD_EXPORT HashHelper { private: static HashHelper* _INSTANCE; diff --git a/libnd4j/include/helpers/logger.h b/libnd4j/include/helpers/logger.h index c13785ff70bf..625a549318be 100644 --- a/libnd4j/include/helpers/logger.h +++ b/libnd4j/include/helpers/logger.h @@ -49,7 +49,7 @@ #endif namespace sd { - class ND4J_EXPORT Logger { + class SD_EXPORT Logger { public: diff --git a/libnd4j/include/helpers/shape.h b/libnd4j/include/helpers/shape.h index 2c18615fc391..3d5d0e03a61d 100644 --- a/libnd4j/include/helpers/shape.h +++ b/libnd4j/include/helpers/shape.h @@ -64,7 +64,7 @@ namespace shape { * Shape information approximating * the information on an ndarray */ - struct ND4J_EXPORT ShapeInformation { + struct SD_EXPORT ShapeInformation { _CUDA_HD ShapeInformation(Nd4jLong *shape_ = nullptr, Nd4jLong *stride_ = nullptr, char order_ = 0, int rank_ = 0, int offset_ = 0, int elementWiseStride_ = 0) : shape(shape_), stride(stride_), order(order_), rank(rank_), offset(offset_), elementWiseStride(elementWiseStride_) {} @@ -81,7 +81,7 @@ namespace shape { * Indexing information * for bounds checking */ - struct ND4J_EXPORT CurrentIndexing { + struct SD_EXPORT CurrentIndexing { int numElementsPerThread; int blockStartingIndex; int startingThreadIndex; @@ -91,72 +91,72 @@ namespace shape { - ND4J_EXPORT _CUDA_HD bool shapeEquals(const int shape1Rank, const Nd4jLong *shape1, const int shape2Rank, const Nd4jLong *shape2); + SD_EXPORT _CUDA_HD bool shapeEquals(const int shape1Rank, const Nd4jLong *shape1, const int shape2Rank, const Nd4jLong *shape2); - ND4J_EXPORT _CUDA_HD Nd4jLong* detachShape(Nd4jLong *originalShape); + SD_EXPORT _CUDA_HD Nd4jLong* detachShape(Nd4jLong *originalShape); - ND4J_EXPORT _CUDA_HD Nd4jLong* copyShape(Nd4jLong *originalShape); + SD_EXPORT _CUDA_HD Nd4jLong* copyShape(Nd4jLong *originalShape); - ND4J_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2); + SD_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2); - ND4J_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3); + SD_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3); - ND4J_EXPORT _CUDA_HD bool strideEquals(int shape1Rank,Nd4jLong *shape1,int shape2Rank,Nd4jLong *shape2); + SD_EXPORT _CUDA_HD bool strideEquals(int shape1Rank,Nd4jLong *shape1,int shape2Rank,Nd4jLong *shape2); - ND4J_EXPORT _CUDA_HD bool strideEquals(Nd4jLong *shapeInfo1,Nd4jLong *shapeInfo2); + SD_EXPORT _CUDA_HD bool strideEquals(Nd4jLong *shapeInfo1,Nd4jLong *shapeInfo2); - ND4J_EXPORT _CUDA_HD bool strideEquals(Nd4jLong *stride1,int rank1,Nd4jLong *stride2,int rank2); + SD_EXPORT _CUDA_HD bool strideEquals(Nd4jLong *stride1,int rank1,Nd4jLong *stride2,int rank2); - ND4J_EXPORT _CUDA_HD bool equalsSoft(const Nd4jLong *shapeA, const Nd4jLong *shapeB); + SD_EXPORT _CUDA_HD bool equalsSoft(const Nd4jLong *shapeA, const Nd4jLong *shapeB); - ND4J_EXPORT _CUDA_HD bool equalsTypesAndShapesSoft(const Nd4jLong *shapeA, const Nd4jLong *shapeB); + SD_EXPORT _CUDA_HD bool equalsTypesAndShapesSoft(const Nd4jLong *shapeA, const Nd4jLong *shapeB); - ND4J_EXPORT _CUDA_HD bool equalsStrict(const Nd4jLong *shapeA, const Nd4jLong *shapeB); + SD_EXPORT _CUDA_HD bool equalsStrict(const Nd4jLong *shapeA, const Nd4jLong *shapeB); // returns true if ranks, shapes and strides are the same - ND4J_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2); - ND4J_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3); + SD_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2); + SD_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3); - ND4J_EXPORT _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim); - ND4J_EXPORT _CUDA_HD Nd4jLong strideAt(const Nd4jLong *shapeInfo, const int dim); + SD_EXPORT _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim); + SD_EXPORT _CUDA_HD Nd4jLong strideAt(const Nd4jLong *shapeInfo, const int dim); template - ND4J_EXPORT _CUDA_HD void fill(T* buffer, T value, Nd4jLong length); + SD_EXPORT _CUDA_HD void fill(T* buffer, T value, Nd4jLong length); - ND4J_EXPORT _CUDA_HD void traceNew(int id); + SD_EXPORT _CUDA_HD void traceNew(int id); - ND4J_EXPORT _CUDA_HD int tadIndexForLinear(int linearIndex, int tadLength); + SD_EXPORT _CUDA_HD int tadIndexForLinear(int linearIndex, int tadLength); - ND4J_EXPORT _CUDA_HD Nd4jLong tadLength(Nd4jLong *shapeInfo, int *dimension, int dimensionLength); + SD_EXPORT _CUDA_HD Nd4jLong tadLength(Nd4jLong *shapeInfo, int *dimension, int dimensionLength); - ND4J_EXPORT _CUDA_HD bool canReshape(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShape, bool isFOrder); + SD_EXPORT _CUDA_HD bool canReshape(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShape, bool isFOrder); - ND4J_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, const char newOrder, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo); + SD_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, const char newOrder, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo); /** * newShapeInfo contains rank, shape and order only, no strides/ews/type */ - ND4J_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShapeInfo); + SD_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShapeInfo); /** * Get the shape info buffer * for the given rank and shape. */ - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong *shape); + SD_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong *shape); - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong *shape, Nd4jLong *buffer); + SD_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong *shape, Nd4jLong *buffer); /** * Get the shape info buffer * for the given rank and shape. */ - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong *shape); + SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong *shape); - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong *shape, Nd4jLong *output); + SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong *shape, Nd4jLong *output); #ifdef __CUDACC__ - __device__ ND4J_EXPORT Nd4jLong *cuMalloc(Nd4jLong *buffer, long size); + __device__ SD_EXPORT Nd4jLong *cuMalloc(Nd4jLong *buffer, long size); #endif @@ -168,9 +168,9 @@ namespace shape { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - ND4J_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank); + SD_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank); - ND4J_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank, Nd4jLong* ret); + SD_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong *shape, int rank, Nd4jLong* ret); /** * Computes the standard packed array strides for a given shape. @@ -180,17 +180,17 @@ namespace shape { * @return the strides for a matrix of n dimensions */ - ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank); + SD_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank); - ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank, Nd4jLong* ret); + SD_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank, Nd4jLong* ret); - ND4J_EXPORT _CUDA_HD void updateStrides(Nd4jLong *shape, const char order); - ND4J_EXPORT _CUDA_HD void updateStrides(const int rank, const Nd4jLong *shapeOnly, Nd4jLong *stridesOnly, const char order); + SD_EXPORT _CUDA_HD void updateStrides(Nd4jLong *shape, const char order); + SD_EXPORT _CUDA_HD void updateStrides(const int rank, const Nd4jLong *shapeOnly, Nd4jLong *stridesOnly, const char order); // check whether input dimensions are permuted, not permuted dimensions order have to be 0,....,rank-1 template - ND4J_EXPORT _CUDA_HD bool isDimPermuted(const T* dimensions, const int dimSize); + SD_EXPORT _CUDA_HD bool isDimPermuted(const T* dimensions, const int dimSize); /** * Computes the standard packed array strides for a given shape. @@ -199,9 +199,9 @@ namespace shape { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - ND4J_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong *shape, int rank, int startNum); + SD_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong *shape, int rank, int startNum); - ND4J_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong *shape, int rank, int startNum, Nd4jLong* ret); + SD_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong *shape, int rank, int startNum, Nd4jLong* ret); /** * Computes the standard packed array strides for a given shape. @@ -210,27 +210,27 @@ namespace shape { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank, int startNum); + SD_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank, int startNum); - ND4J_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank, int startNum, Nd4jLong* ret); + SD_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong *shape, int rank, int startNum, Nd4jLong* ret); /** * @param toCopy the shape to copy * @return a copy of the original struct */ - ND4J_EXPORT _CUDA_HD ShapeInformation *shapeCopy( ShapeInformation *toCopy); + SD_EXPORT _CUDA_HD ShapeInformation *shapeCopy( ShapeInformation *toCopy); - ND4J_EXPORT _CUDA_HD bool strideDescendingCAscendingF(const Nd4jLong *shapeBuffer); + SD_EXPORT _CUDA_HD bool strideDescendingCAscendingF(const Nd4jLong *shapeBuffer); - ND4J_EXPORT _CUDA_HD bool isContiguous(const Nd4jLong* shapeInfo); + SD_EXPORT _CUDA_HD bool isContiguous(const Nd4jLong* shapeInfo); /** * copy-past from java hasDefaultStridesForShape function * check whether array is not permuted and has contiguous elements in memory */ - ND4J_EXPORT _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo); + SD_EXPORT _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo); /** @@ -244,7 +244,7 @@ namespace shape { * @return 0 if there is no element wise stride the * element wise stride of reshape(1,length) otherwise */ - ND4J_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong *shape, Nd4jLong *stride, int isFOrder); + SD_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong *shape, Nd4jLong *stride, int isFOrder); /** * Compute the element wise stride @@ -257,11 +257,11 @@ namespace shape { * @return 0 if there is no element wise stride the * element wise stride of reshape(1,length) otherwise */ - ND4J_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong *shape, Nd4jLong *stride, int isFOrder, Nd4jLong *dimension, int dimensionLength); + SD_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong *shape, Nd4jLong *stride, int isFOrder, Nd4jLong *dimension, int dimensionLength); - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride); + SD_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride); - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride, Nd4jLong *buffer); + SD_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride, Nd4jLong *buffer); /** * * @param length @@ -269,7 +269,7 @@ namespace shape { * @param rearrange * @return */ - ND4J_EXPORT _CUDA_HD Nd4jLong *doPermuteSwap(int length, Nd4jLong *shape, int* rearrange); + SD_EXPORT _CUDA_HD Nd4jLong *doPermuteSwap(int length, Nd4jLong *shape, int* rearrange); @@ -279,13 +279,13 @@ namespace shape { * @param shape * @param rearrange */ - ND4J_EXPORT _CUDA_HD void doPermuteSwap(int length, Nd4jLong **shape, int* rearrange); + SD_EXPORT _CUDA_HD void doPermuteSwap(int length, Nd4jLong **shape, int* rearrange); - ND4J_EXPORT _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong *shapeBuffer, int* rearrange); + SD_EXPORT _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong *shapeBuffer, int* rearrange); - ND4J_EXPORT _CUDA_HD void permuteShapeBufferInPlace(Nd4jLong *shapeBuffer, int* rearrange, Nd4jLong *out); + SD_EXPORT _CUDA_HD void permuteShapeBufferInPlace(Nd4jLong *shapeBuffer, int* rearrange, Nd4jLong *out); - ND4J_EXPORT _CUDA_HD void doPermuteShapeInfo(Nd4jLong *shapeBuffer, const int *rearrange, Nd4jLong len = -1); + SD_EXPORT _CUDA_HD void doPermuteShapeInfo(Nd4jLong *shapeBuffer, const int *rearrange, Nd4jLong len = -1); /** * Rearrange the permute indexes @@ -302,16 +302,16 @@ namespace shape { * wise stride. */ - ND4J_EXPORT _CUDA_HD Nd4jLong* createPermuteIndexes(int originalRank, int *dimension,int dimensionLength); + SD_EXPORT _CUDA_HD Nd4jLong* createPermuteIndexes(int originalRank, int *dimension,int dimensionLength); - ND4J_EXPORT _CUDA_HD Nd4jLong* computeResultShape(Nd4jLong *originalShapeBuffer, int *dimension,int dimensionLength); + SD_EXPORT _CUDA_HD Nd4jLong* computeResultShape(Nd4jLong *originalShapeBuffer, int *dimension,int dimensionLength); /** * This method does inplace transpose of given shapeBuffer * * @param shapeBuffer */ - ND4J_EXPORT _CUDA_HD void transposeInplace(Nd4jLong *shapeBuffer); + SD_EXPORT _CUDA_HD void transposeInplace(Nd4jLong *shapeBuffer); /** @@ -322,7 +322,7 @@ namespace shape { * @param elementStride * @return */ - ND4J_EXPORT _CUDA_HD char getOrder(int length, Nd4jLong *shape, Nd4jLong *stride, int elementStride); + SD_EXPORT _CUDA_HD char getOrder(int length, Nd4jLong *shape, Nd4jLong *stride, int elementStride); /** * Ensure that every value in the re arrange @@ -334,7 +334,7 @@ namespace shape { * @return */ template - ND4J_EXPORT _CUDA_HD int checkArrangeArray(T *arr, int arrLength, int shapeLength); + SD_EXPORT _CUDA_HD int checkArrangeArray(T *arr, int arrLength, int shapeLength); /** * Permute the shape information @@ -342,7 +342,7 @@ namespace shape { * @param rearrange the order to re arrange * @param rank the rank of the rearrange array */ - ND4J_EXPORT _CUDA_HD void permute(ShapeInformation **info, int *rearrange, int rank); + SD_EXPORT _CUDA_HD void permute(ShapeInformation **info, int *rearrange, int rank); /** * Returns whether the @@ -350,32 +350,32 @@ namespace shape { * @param shape the shape of the array * @param rank the rank of cthe shape */ - ND4J_EXPORT _CUDA_HD int isVector(Nd4jLong *shape, int rank); + SD_EXPORT _CUDA_HD int isVector(Nd4jLong *shape, int rank); /** * When 1 dimension is the whole length of the * array */ - ND4J_EXPORT _CUDA_HD int oneDimEqualToLength(Nd4jLong *shape, int rank); + SD_EXPORT _CUDA_HD int oneDimEqualToLength(Nd4jLong *shape, int rank); - ND4J_EXPORT _CUDA_HD int oneDimEqualToLength(Nd4jLong *shapeInfo); + SD_EXPORT _CUDA_HD int oneDimEqualToLength(Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD int isVector(const Nd4jLong *shapeInfo); + SD_EXPORT _CUDA_HD int isVector(const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD bool isLikeVector(Nd4jLong *shapeInfo, int& posOfNonUnityDim); + SD_EXPORT _CUDA_HD bool isLikeVector(Nd4jLong *shapeInfo, int& posOfNonUnityDim); - ND4J_EXPORT _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, int& posOfNonUnityDim); + SD_EXPORT _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, int& posOfNonUnityDim); - ND4J_EXPORT _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo); + SD_EXPORT _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD bool isColumnVector(Nd4jLong *shapeInfo); + SD_EXPORT _CUDA_HD bool isColumnVector(Nd4jLong *shapeInfo); /** * shape - input inShape is shape only, not shapeInfo * returns number of non-unity dimensions in inShape */ - ND4J_EXPORT _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape); + SD_EXPORT _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape); /** * Returns whether the @@ -384,15 +384,15 @@ namespace shape { * @param rank the rank of the shape */ - ND4J_EXPORT _CUDA_HD int isMatrix(Nd4jLong *shape, int rank); + SD_EXPORT _CUDA_HD int isMatrix(Nd4jLong *shape, int rank); INLINEDEF _CUDA_HD int isMatrix(Nd4jLong *shapeInfo); /** * Returns the shape portion of an information * buffer */ - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeOf(const Nd4jLong *shapeInfo); + SD_EXPORT _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *shapeInfo); + SD_EXPORT _CUDA_HD Nd4jLong *shapeOf(const Nd4jLong *shapeInfo); /** * Return a copy of a buffer. @@ -401,10 +401,10 @@ namespace shape { */ template - ND4J_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T *toCopy); + SD_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T *toCopy); template - ND4J_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T *toCopy, T *ret); + SD_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T *toCopy, T *ret); /** * Return a copy of a buffer. @@ -413,13 +413,13 @@ namespace shape { */ template - ND4J_EXPORT _CUDA_HD void copyTo(Nd4jLong length, T *from, T *to); + SD_EXPORT _CUDA_HD void copyTo(Nd4jLong length, T *from, T *to); /** * Return a copy of a buffer. * This buffer allocates memory * that must be freed elsewhere. */ - ND4J_EXPORT _CUDA_HD void copyTo(int length, Nd4jLong *from, Nd4jLong *to, Nd4jLong *indexes); + SD_EXPORT _CUDA_HD void copyTo(int length, Nd4jLong *from, Nd4jLong *to, Nd4jLong *indexes); /** * Permute the given strides @@ -430,18 +430,18 @@ namespace shape { * and all must be filled in) * @return the rearranged array */ - //ND4J_EXPORT _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int shapeRank, Nd4jLong *rearrange); + //SD_EXPORT _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int shapeRank, Nd4jLong *rearrange); /** * Return the slice (shape + 1 in pointer arithmetic) * @param shape the shape to take the slice of * @return the shape array - the first entry */ - ND4J_EXPORT _CUDA_HD Nd4jLong *slice(Nd4jLong *shape); + SD_EXPORT _CUDA_HD Nd4jLong *slice(Nd4jLong *shape); - ND4J_EXPORT _CUDA_HD int slices(Nd4jLong *shapeBuffer); + SD_EXPORT _CUDA_HD int slices(Nd4jLong *shapeBuffer); - ND4J_EXPORT _CUDA_HD Nd4jLong *sliceOfShapeBuffer(Nd4jLong sliceIdx, Nd4jLong *shapeBuffer); + SD_EXPORT _CUDA_HD Nd4jLong *sliceOfShapeBuffer(Nd4jLong sliceIdx, Nd4jLong *shapeBuffer); /** * Returns the length of the * shape information buffer: @@ -450,30 +450,30 @@ namespace shape { * info length for * @return rank * 2 + 4 */ - ND4J_EXPORT _CUDA_HD int shapeInfoLength(int rank); + SD_EXPORT _CUDA_HD int shapeInfoLength(int rank); - ND4J_EXPORT _CUDA_HD int shapeInfoLength(Nd4jLong* shapeInfo); + SD_EXPORT _CUDA_HD int shapeInfoLength(Nd4jLong* shapeInfo); - ND4J_EXPORT _CUDA_HD int shapeInfoLength(const Nd4jLong* shapeInfo); + SD_EXPORT _CUDA_HD int shapeInfoLength(const Nd4jLong* shapeInfo); - ND4J_EXPORT _CUDA_HD size_t shapeInfoByteLength(int rank); + SD_EXPORT _CUDA_HD size_t shapeInfoByteLength(int rank); - ND4J_EXPORT _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong* shapeInfo); + SD_EXPORT _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong* shapeInfo); - ND4J_EXPORT _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong* shapeInfo); + SD_EXPORT _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong* shapeInfo); /** * Returns the rank portion of * an information buffer */ - ND4J_EXPORT _CUDA_HD int rank(const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD int rank(const int *shapeInfo); - ND4J_EXPORT _CUDA_HD int rank(const unsigned int *shapeInfo); + SD_EXPORT _CUDA_HD int rank(const Nd4jLong *shapeInfo); + SD_EXPORT _CUDA_HD int rank(const int *shapeInfo); + SD_EXPORT _CUDA_HD int rank(const unsigned int *shapeInfo); /** * returns pointer on elementWiseStride */ - ND4J_EXPORT _CUDA_HD Nd4jLong* ews(Nd4jLong* shapeInfo); + SD_EXPORT _CUDA_HD Nd4jLong* ews(Nd4jLong* shapeInfo); /** * Converts a raw int buffer of the layout: @@ -485,50 +485,50 @@ namespace shape { * * where shape and stride are both straight int pointers */ - ND4J_EXPORT _CUDA_HD ShapeInformation *infoFromBuffer(Nd4jLong *buffer); + SD_EXPORT _CUDA_HD ShapeInformation *infoFromBuffer(Nd4jLong *buffer); /** * Returns the stride portion of an information * buffer */ - ND4J_EXPORT _CUDA_HD Nd4jLong *stride(Nd4jLong *buffer); + SD_EXPORT _CUDA_HD Nd4jLong *stride(Nd4jLong *buffer); - ND4J_EXPORT _CUDA_HD Nd4jLong *stride(const Nd4jLong *buffer); + SD_EXPORT _CUDA_HD Nd4jLong *stride(const Nd4jLong *buffer); /** * Compute the length of the given shape */ - ND4J_EXPORT _CUDA_HD bool isEmpty(const Nd4jLong *shapeInfo); + SD_EXPORT _CUDA_HD bool isEmpty(const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD Nd4jLong length(const Nd4jLong *shapeInfo); + SD_EXPORT _CUDA_HD Nd4jLong length(const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD Nd4jLong length(std::initializer_list& shape); + SD_EXPORT _CUDA_HD Nd4jLong length(std::initializer_list& shape); - ND4J_EXPORT _CUDA_HD Nd4jLong length(std::initializer_list& shape); + SD_EXPORT _CUDA_HD Nd4jLong length(std::initializer_list& shape); /*** * Returns the offset portion of an information buffer */ - ND4J_EXPORT _CUDA_HD Nd4jLong offset(Nd4jLong *buffer); + SD_EXPORT _CUDA_HD Nd4jLong offset(Nd4jLong *buffer); - ND4J_EXPORT _CUDA_HD Nd4jLong& extra(Nd4jLong *buffer); + SD_EXPORT _CUDA_HD Nd4jLong& extra(Nd4jLong *buffer); /** * Returns the ordering * for this shape information buffer */ - ND4J_EXPORT _CUDA_HD char order(const Nd4jLong *buffer); + SD_EXPORT _CUDA_HD char order(const Nd4jLong *buffer); /** * Returns the type */ - ND4J_EXPORT _CUDA_HD Nd4jLong type(const Nd4jLong* shapeInfo); + SD_EXPORT _CUDA_HD Nd4jLong type(const Nd4jLong* shapeInfo); /** * Returns the element wise stride for this information * buffer */ - ND4J_EXPORT _CUDA_HD Nd4jLong elementWiseStride(const Nd4jLong *buffer); + SD_EXPORT _CUDA_HD Nd4jLong elementWiseStride(const Nd4jLong *buffer); /** @@ -536,14 +536,14 @@ namespace shape { * buffer * relative to a dimension and ordering for a reduction index */ - ND4J_EXPORT _CUDA_HD Nd4jLong reductionIndexElementWiseStride(Nd4jLong *buffer, int *dimension, int dimensionLength); + SD_EXPORT _CUDA_HD Nd4jLong reductionIndexElementWiseStride(Nd4jLong *buffer, int *dimension, int dimensionLength); /** * Returns whether * the given shape info buffer * represents a scalar shape */ - ND4J_EXPORT _CUDA_HD int isScalar(const Nd4jLong *info); + SD_EXPORT _CUDA_HD int isScalar(const Nd4jLong *info); /** * Returns whether @@ -551,7 +551,7 @@ namespace shape { * represents a scalar * shape or not */ - ND4J_EXPORT _CUDA_HD int isScalar(volatile ShapeInformation *info); + SD_EXPORT _CUDA_HD int isScalar(volatile ShapeInformation *info); /** * Return a copy of this array with the @@ -566,7 +566,7 @@ namespace shape { * item */ template - ND4J_EXPORT _CUDA_HD void removeIndex(T1 *data, T2 *indexes, Nd4jLong dataLength, Nd4jLong indexesLength, T1 *out); + SD_EXPORT _CUDA_HD void removeIndex(T1 *data, T2 *indexes, Nd4jLong dataLength, Nd4jLong indexesLength, T1 *out); /** * Return a copy of this array with the @@ -582,7 +582,7 @@ namespace shape { */ template - ND4J_EXPORT _CUDA_HD T1* removeIndex(T1 *data, T2 *indexes, Nd4jLong dataLength, Nd4jLong indexesLength); + SD_EXPORT _CUDA_HD T1* removeIndex(T1 *data, T2 *indexes, Nd4jLong dataLength, Nd4jLong indexesLength); /** * Iterate over a given set of indexes @@ -595,7 +595,7 @@ namespace shape { * indexes should be the indexes to exclude * indexes length should be the length of indexes */ - ND4J_EXPORT _CUDA_HD Nd4jLong* everyIndexBut(Nd4jLong *indexes,int indexesLength,int begin,int end); + SD_EXPORT _CUDA_HD Nd4jLong* everyIndexBut(Nd4jLong *indexes,int indexesLength,int begin,int end); /** * Computes the offset for accessing @@ -605,7 +605,7 @@ namespace shape { //#ifdef __CUDACC__ // __device__ //#endif -// ND4J_EXPORT int tadOffset(shape::ShapeInformation *xInfo, int offset); +// SD_EXPORT int tadOffset(shape::ShapeInformation *xInfo, int offset); /** * Returns a shape @@ -615,11 +615,11 @@ namespace shape { * for the shape to be returned as * @return the new shape */ - ND4J_EXPORT _CUDA_HD Nd4jLong* ensureVectorShape(Nd4jLong *shape); + SD_EXPORT _CUDA_HD Nd4jLong* ensureVectorShape(Nd4jLong *shape); - ND4J_EXPORT _CUDA_HD Nd4jLong* createScalarShapeInfo(); + SD_EXPORT _CUDA_HD Nd4jLong* createScalarShapeInfo(); - ND4J_EXPORT _CUDA_HD Nd4jLong* createScalarShapeInfo(Nd4jLong *ret); + SD_EXPORT _CUDA_HD Nd4jLong* createScalarShapeInfo(Nd4jLong *ret); /** * Generate an int buffer @@ -628,20 +628,20 @@ namespace shape { * */ template - ND4J_EXPORT _CUDA_HD T* range(int from, int to, int increment); + SD_EXPORT _CUDA_HD T* range(int from, int to, int increment); /** * Range between from and two with an * increment of 1 */ template - ND4J_EXPORT _CUDA_HD T* range(int from, int to); + SD_EXPORT _CUDA_HD T* range(int from, int to); /** * Keep the given indexes * in the data */ - ND4J_EXPORT _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int* index, int indexLength, int dataLength); + SD_EXPORT _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int* index, int indexLength, int dataLength); /** * Generate reverse copy of the data @@ -651,16 +651,16 @@ namespace shape { */ template - ND4J_EXPORT _CUDA_HD T* reverseCopy(T *data, Nd4jLong length); + SD_EXPORT _CUDA_HD T* reverseCopy(T *data, Nd4jLong length); template - ND4J_EXPORT _CUDA_HD void reverseCopyTo(T *from, T *to, Nd4jLong length); + SD_EXPORT _CUDA_HD void reverseCopyTo(T *from, T *to, Nd4jLong length); template - ND4J_EXPORT _CUDA_HD void reverseCopyTo(T *from, T *to, Nd4jLong *indexes, Nd4jLong length); + SD_EXPORT _CUDA_HD void reverseCopyTo(T *from, T *to, Nd4jLong *indexes, Nd4jLong length); template - ND4J_EXPORT _CUDA_H void convertT(T1 *from, T2 *to, Nd4jLong length); + SD_EXPORT _CUDA_H void convertT(T1 *from, T2 *to, Nd4jLong length); /** * * @param arr1 @@ -670,7 +670,7 @@ namespace shape { * @return */ template - ND4J_EXPORT _CUDA_HD T* concat(T* arr1, Nd4jLong arr1Length, T* arr2, Nd4jLong arr2Length); + SD_EXPORT _CUDA_HD T* concat(T* arr1, Nd4jLong arr1Length, T* arr2, Nd4jLong arr2Length); /** * @@ -681,7 +681,7 @@ namespace shape { * @return */ template - ND4J_EXPORT _CUDA_HD T* concat(int numArrays, int numTotalElements, Nd4jLong **arr, Nd4jLong *lengths); + SD_EXPORT _CUDA_HD T* concat(int numArrays, int numTotalElements, Nd4jLong **arr, Nd4jLong *lengths); /** * Get the length per slice of the @@ -695,7 +695,7 @@ namespace shape { * @return the length per slice of the given shape * along the given dimension */ - ND4J_EXPORT _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong *shape, int *dimension, int dimensionLength); + SD_EXPORT _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong *shape, int *dimension, int dimensionLength); /** * calculates the offset for a tensor @@ -704,7 +704,7 @@ namespace shape { * @param tensorShape * @return */ - ND4J_EXPORT _CUDA_HD Nd4jLong sliceOffsetForTensor(int rank, + SD_EXPORT _CUDA_HD Nd4jLong sliceOffsetForTensor(int rank, int index, Nd4jLong *shape, Nd4jLong *tensorShape, @@ -719,7 +719,7 @@ namespace shape { * @param tensorShape * @return */ - ND4J_EXPORT _CUDA_HD Nd4jLong sliceOffsetForTensor(int index,int tensorLength,int lengthPerSlice2); + SD_EXPORT _CUDA_HD Nd4jLong sliceOffsetForTensor(int index,int tensorLength,int lengthPerSlice2); /** * Computes the tensor along dimension * offset @@ -728,7 +728,7 @@ namespace shape { * @param info the shape information to use for tad * @param dimension the dimensions to use for computing the tensor along dimensions */ -// ND4J_EXPORT _CUDA_HD int offset(int index, +// SD_EXPORT _CUDA_HD int offset(int index, // int rank, // shape::ShapeInformation *info, // Nd4jLong *dimension, @@ -740,7 +740,7 @@ namespace shape { * of tensors along * a given dimension */ - ND4J_EXPORT _CUDA_HD Nd4jLong tensorsAlongDimension(int rank, + SD_EXPORT _CUDA_HD Nd4jLong tensorsAlongDimension(int rank, volatile int length, volatile Nd4jLong *shape, int *dimension, @@ -751,7 +751,7 @@ namespace shape { * of tensors along * a given dimension */ - ND4J_EXPORT _CUDA_HD Nd4jLong tensorsAlongDimension(Nd4jLong *shapeInfo, int *dimension, int dimensionLength); + SD_EXPORT _CUDA_HD Nd4jLong tensorsAlongDimension(Nd4jLong *shapeInfo, int *dimension, int dimensionLength); @@ -763,24 +763,24 @@ namespace shape { * @param i * @return */ - ND4J_EXPORT _CUDA_HD int tadForBlockIndex(int blockSize, int blockIdx, int i); + SD_EXPORT _CUDA_HD int tadForBlockIndex(int blockSize, int blockIdx, int i); /** * Computes the number of tads per block * */ - ND4J_EXPORT _CUDA_HD int tadsPerBlock(int blockSize, int tads); + SD_EXPORT _CUDA_HD int tadsPerBlock(int blockSize, int tads); -// ND4J_EXPORT _CUDA_HD Nd4jLong *tadShapeInfo(int index, Nd4jLong *xShapeInfo, Nd4jLong *dimension, +// SD_EXPORT _CUDA_HD Nd4jLong *tadShapeInfo(int index, Nd4jLong *xShapeInfo, Nd4jLong *dimension, // int dimensionLength); /** * Returns a shape buffer * for the shape information metadata. */ - ND4J_EXPORT _CUDA_HD Nd4jLong *toShapeBuffer( ShapeInformation *info); + SD_EXPORT _CUDA_HD Nd4jLong *toShapeBuffer( ShapeInformation *info); - ND4J_EXPORT _CUDA_HD Nd4jLong *toShapeBuffer( ShapeInformation *info, Nd4jLong* ret); + SD_EXPORT _CUDA_HD Nd4jLong *toShapeBuffer( ShapeInformation *info, Nd4jLong* ret); /** * Returns the number of elements per thread @@ -831,7 +831,7 @@ namespace shape { * @param numElementsPerTad the number of elements * per tad */ - ND4J_EXPORT _CUDA_HD int tadIndex(int i, int elementWiseStride, int numElementsPerTad); + SD_EXPORT _CUDA_HD int tadIndex(int i, int elementWiseStride, int numElementsPerTad); /** * Map a tad to a @@ -841,7 +841,7 @@ namespace shape { * @param tadsForReduced the number of tads for the shrunk down problem (eg: 2,3) * @param tadsForOriginal the number of tads for the smaller problem (eg: 3) */ - ND4J_EXPORT _CUDA_HD int reductionIndexForTad(int tadIndexForOriginal, int tadsForReduced, + SD_EXPORT _CUDA_HD int reductionIndexForTad(int tadIndexForOriginal, int tadsForReduced, int tadsForOriginal); /** @@ -849,7 +849,7 @@ namespace shape { * per reduce index for the * reduction tad. */ - ND4J_EXPORT _CUDA_HD int tadsPerReduceIndex(int tadsForReduce, int tadsForOriginal); + SD_EXPORT _CUDA_HD int tadsPerReduceIndex(int tadsForReduce, int tadsForOriginal); /** * Maps a linear index to a reduction index @@ -859,14 +859,14 @@ namespace shape { * @param tadNum the number of tads for the shrunken problem * @param originalTadNum the tad number for the reduced version of the problem */ - ND4J_EXPORT _CUDA_HD int reductionIndexForLinear(int i, int elementWiseStride, int numElementsPerTad, + SD_EXPORT _CUDA_HD int reductionIndexForLinear(int i, int elementWiseStride, int numElementsPerTad, int tadNum, int originalTadNum); /** * Returns the prod of the data * up to the given length */ - ND4J_EXPORT _CUDA_HD Nd4jLong prodLong(const Nd4jLong *data, int length); + SD_EXPORT _CUDA_HD Nd4jLong prodLong(const Nd4jLong *data, int length); /** * Returns the rear most left over item not present in @@ -884,7 +884,7 @@ namespace shape { * the last item of the dimension array */ -// ND4J_EXPORT _CUDA_HD int rearMostLeftOverItem(Nd4jLong *data,int length,Nd4jLong *dimension,int dimensionLength); +// SD_EXPORT _CUDA_HD int rearMostLeftOverItem(Nd4jLong *data,int length,Nd4jLong *dimension,int dimensionLength); /** * Get an offset for retrieval @@ -898,44 +898,44 @@ namespace shape { * @return the double at the specified index */ - ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *coords, Nd4jLong baseOffset = 0); - ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset = 0); - ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset = 0); + SD_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *coords, Nd4jLong baseOffset = 0); + SD_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset = 0); + SD_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset = 0); - ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank); + SD_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank); - ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank, Nd4jLong *buffer); + SD_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank, Nd4jLong *buffer); /** * Convert a linear index to the corresponding coordinates * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1] */ - ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords); - ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords); - ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords); - ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords); - ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, int *coords); + SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords); + SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords); + SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords); + SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords); + SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, int *coords); - ND4J_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, Nd4jLong *coords); - ND4J_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, int *coords); + SD_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, Nd4jLong *coords); + SD_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, int *coords); /** * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! */ - ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords, const int dimsSize, const int* tadDims); + SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords, const int dimsSize, const int* tadDims); /** * Convert coordinates to the corresponding linear index (sequence number in other words) * for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned */ - ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *coords); - ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords); - ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const uint *coords); - ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const int *coords); + SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *coords); + SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords); + SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const uint *coords); + SD_EXPORT _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const int *coords); /** * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! */ - ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords, const int dimsSize, const int* tadDims); + SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords, const int dimsSize, const int* tadDims); /** * increment n-dimensional array by one iteration by changing coord appropriately @@ -947,75 +947,75 @@ namespace shape { /* calculates an array buffer offset for given "index" using following formula: offset = coord_0*stride_0 + coord_1*stride_1 + ... + coord_{rank-1}*stride_{rank-1} */ - ND4J_EXPORT _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo); - ND4J_EXPORT _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeInfo, const uint* uShapeInfo, const bool useUnsigned); + SD_EXPORT _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo); + SD_EXPORT _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo); + SD_EXPORT _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeInfo, const uint* uShapeInfo, const bool useUnsigned); - ND4J_EXPORT _CUDA_HD void printShapeInfo(Nd4jLong *shapeInfo); + SD_EXPORT _CUDA_HD void printShapeInfo(Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD void printShapeInfoLinear(const Nd4jLong *shapeInfo); + SD_EXPORT _CUDA_HD void printShapeInfoLinear(const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD void printShapeInfoLinear(const char *msg, const Nd4jLong *shapeInfo); + SD_EXPORT _CUDA_HD void printShapeInfoLinear(const char *msg, const Nd4jLong *shapeInfo); - ND4J_EXPORT _CUDA_HD void printShapeInfoLinear(const char *msg, int rank, const Nd4jLong *shape, const Nd4jLong *strides); + SD_EXPORT _CUDA_HD void printShapeInfoLinear(const char *msg, int rank, const Nd4jLong *shape, const Nd4jLong *strides); - ND4J_EXPORT _CUDA_HD void printIntArray(const Nd4jLong *arr, const int length); - ND4J_EXPORT _CUDA_HD void printIntArray(const int *arr, const int length); + SD_EXPORT _CUDA_HD void printIntArray(const Nd4jLong *arr, const int length); + SD_EXPORT _CUDA_HD void printIntArray(const int *arr, const int length); - ND4J_EXPORT _CUDA_HD void printArray(float *arr,int length); + SD_EXPORT _CUDA_HD void printArray(float *arr,int length); template - ND4J_EXPORT _CUDA_HD void printArray(T *arr,int length, const char *message); + SD_EXPORT _CUDA_HD void printArray(T *arr,int length, const char *message); - ND4J_EXPORT _CUDA_HD Nd4jLong* shapeBufferOfNpy(int rank, unsigned int *shape,bool fortranOrder); + SD_EXPORT _CUDA_HD Nd4jLong* shapeBufferOfNpy(int rank, unsigned int *shape,bool fortranOrder); - ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpy(cnpy::NpyArray arr); + SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpy(cnpy::NpyArray arr); -// ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpyBuffer(char *buffer); +// SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpyBuffer(char *buffer); // this function checks the consistence of dimensions with array rank (negative dimensions, too large dimensions, too big number of dimensions) // also sort input array of dimensions, this operation is also necessary for creating TAD object - ND4J_EXPORT _CUDA_H void checkDimensions(const int rank, std::vector& dimensions); + SD_EXPORT _CUDA_H void checkDimensions(const int rank, std::vector& dimensions); // function calculates linear index of array min, min is sub-array of max, index to be returned is min-array's index and corresponds to maxIdx of max array // dimsToExclude - should be sorted in increasing order - ND4J_EXPORT _CUDA_HD Nd4jLong subArrayIndex(const Nd4jLong maxIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr, const int dimsLen = -1); + SD_EXPORT _CUDA_HD Nd4jLong subArrayIndex(const Nd4jLong maxIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr, const int dimsLen = -1); // function calculates absolute offset of min array, min is sub-array of max, offset to be returned corresponds to maxIdx of max array // dimsToExclude - should be sorted in increasing order - ND4J_EXPORT _CUDA_HD Nd4jLong subArrayOffset(const Nd4jLong maxIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr, const int dimsLen = -1); + SD_EXPORT _CUDA_HD Nd4jLong subArrayOffset(const Nd4jLong maxIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr, const int dimsLen = -1); // max array is outer for min array, min array is sub-array of max array // function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) // dimsToExclude - should be sorted in increasing order // dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank - ND4J_EXPORT _CUDA_HD void maxIndToMinInd(int* maxIdxs, int* minIdxs, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr, const int dimsLen = -1); + SD_EXPORT _CUDA_HD void maxIndToMinInd(int* maxIdxs, int* minIdxs, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr, const int dimsLen = -1); // calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array // dimsToExclude - should be sorted in increasing order - ND4J_EXPORT _CUDA_HD int outerArrayIndexes(int* maxIdxs, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr); + SD_EXPORT _CUDA_HD int outerArrayIndexes(int* maxIdxs, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr); // calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array // maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand // dimsToExclude - should be sorted in increasing order // memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand - ND4J_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, int* memBuff, const int* dimsToExclude = nullptr); + SD_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, int* memBuff, const int* dimsToExclude = nullptr); // calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array // rank is equal to size of shape - ND4J_EXPORT void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong* strides, Nd4jLong* offsets, const char order = 'c'); - ND4J_EXPORT void calcOffsets(const Nd4jLong* shapeInfo, Nd4jLong* offsets, const char order = 'c'); - // ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order = 'c'); - // ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order = 'c'); - ND4J_EXPORT _CUDA_HD void shapeOldScalar(sd::DataType dtype, Nd4jLong* const buffer, const char order); + SD_EXPORT void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong* strides, Nd4jLong* offsets, const char order = 'c'); + SD_EXPORT void calcOffsets(const Nd4jLong* shapeInfo, Nd4jLong* offsets, const char order = 'c'); + // SD_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order = 'c'); + // SD_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order = 'c'); + SD_EXPORT _CUDA_HD void shapeOldScalar(sd::DataType dtype, Nd4jLong* const buffer, const char order); // deduce order and element-wise stride // if array is scalar or unit length vector then ews = 1 and order is preserved // if array is common vector then ews = stride of non-unity dimension and order is preserved // if strides are normal/contiguous then ews = 1 and corresponding order is set, otherwise ews = 0 and order is preserved - ND4J_EXPORT _CUDA_HD void checkStridesEwsAndOrder(Nd4jLong* shapeInfo, const char proposedOrder, const int numOfNonUnitDims, const Nd4jLong* shapeNoUnities, const Nd4jLong* stridesNoUnities); - ND4J_EXPORT _CUDA_HD void checkStridesEwsAndOrder(Nd4jLong* shapeInfo); + SD_EXPORT _CUDA_HD void checkStridesEwsAndOrder(Nd4jLong* shapeInfo, const char proposedOrder, const int numOfNonUnitDims, const Nd4jLong* shapeNoUnities, const Nd4jLong* stridesNoUnities); + SD_EXPORT _CUDA_HD void checkStridesEwsAndOrder(Nd4jLong* shapeInfo); /** * processes whole set of sub-arrays @@ -1029,7 +1029,7 @@ namespace shape { * subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer * keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} */ - ND4J_EXPORT _CUDA_HD void calcSubArrsShapeInfoAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false); + SD_EXPORT _CUDA_HD void calcSubArrsShapeInfoAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false); /** * processes only one sub-array, evaluates shapeInfo of sub-array and its buffer offset from original array @@ -1045,7 +1045,7 @@ namespace shape { * isStrided - input argument, if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd, * numOfUntiesInMinShape - input argument, number of occurrences in idx when (dimEnd - dimStart) = 1 */ - ND4J_EXPORT void calcSubArrShapeInfoAndOffset(const Nd4jLong* idx, const Nd4jLong* maxShapeInfo, Nd4jLong* minShapeInfo, Nd4jLong& minOffset, const bool keepUnitiesInShape = false, const bool isStrided = false, const int numOfUntiesInMinShape = 0); + SD_EXPORT void calcSubArrShapeInfoAndOffset(const Nd4jLong* idx, const Nd4jLong* maxShapeInfo, Nd4jLong* minShapeInfo, Nd4jLong& minOffset, const bool keepUnitiesInShape = false, const bool isStrided = false, const int numOfUntiesInMinShape = 0); /** * for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99} @@ -1054,7 +1054,7 @@ namespace shape { * returns number of non-unity dimensions in inShapeInfo * if there is no unities in inShapeInfo, then no copy procedure will be performed and shapeNoUnities/stridesNoUnities will point on corresponding places in inShapeInfo */ - ND4J_EXPORT _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities); + SD_EXPORT _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities); /** * for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {1,3}, dimsSize = 2 diff --git a/libnd4j/include/indexing/IndicesList.h b/libnd4j/include/indexing/IndicesList.h index a652615d5355..f657a191ee60 100644 --- a/libnd4j/include/indexing/IndicesList.h +++ b/libnd4j/include/indexing/IndicesList.h @@ -25,7 +25,7 @@ #include "NDIndex.h" namespace sd { - class ND4J_EXPORT IndicesList { + class SD_EXPORT IndicesList { protected: std::vector _indices; public: diff --git a/libnd4j/include/indexing/NDIndex.h b/libnd4j/include/indexing/NDIndex.h index 799da4e6ca99..32c831ce2f45 100644 --- a/libnd4j/include/indexing/NDIndex.h +++ b/libnd4j/include/indexing/NDIndex.h @@ -26,7 +26,7 @@ #include namespace sd { - class ND4J_EXPORT NDIndex { + class SD_EXPORT NDIndex { protected: std::vector _indices; Nd4jLong _stride = 1; @@ -46,7 +46,7 @@ namespace sd { static NDIndex* interval(Nd4jLong start, Nd4jLong end, Nd4jLong stride = 1); }; - class ND4J_EXPORT NDIndexAll : public NDIndex { + class SD_EXPORT NDIndexAll : public NDIndex { public: NDIndexAll(); virtual bool isInterval(); @@ -54,14 +54,14 @@ namespace sd { }; - class ND4J_EXPORT NDIndexPoint : public NDIndex { + class SD_EXPORT NDIndexPoint : public NDIndex { public: NDIndexPoint(Nd4jLong point); virtual bool isInterval(); ~NDIndexPoint() = default; }; - class ND4J_EXPORT NDIndexInterval : public NDIndex { + class SD_EXPORT NDIndexInterval : public NDIndex { public: NDIndexInterval(Nd4jLong start, Nd4jLong end, Nd4jLong stride = 1); virtual bool isInterval(); diff --git a/libnd4j/include/legacy/NativeOpExecutioner.h b/libnd4j/include/legacy/NativeOpExecutioner.h index 4d55a3357724..c57d59e339d0 100644 --- a/libnd4j/include/legacy/NativeOpExecutioner.h +++ b/libnd4j/include/legacy/NativeOpExecutioner.h @@ -34,7 +34,7 @@ * */ -class ND4J_EXPORT NativeOpExecutioner { +class SD_EXPORT NativeOpExecutioner { public: /** * diff --git a/libnd4j/include/legacy/NativeOps.h b/libnd4j/include/legacy/NativeOps.h index ea8352362f6d..6060d01ca16d 100755 --- a/libnd4j/include/legacy/NativeOps.h +++ b/libnd4j/include/legacy/NativeOps.h @@ -50,9 +50,9 @@ //IT DEFINES THE EXPORT MACRO FOR THE EDITOR AND THEN //RE ADDS THE DEFINITION VIA dll.h #ifdef _WIN32 -#define ND4J_EXPORT __declspec(dllexport) +#define SD_EXPORT __declspec(dllexport) #else -#define ND4J_EXPORT +#define SD_EXPORT #endif #include @@ -86,32 +86,32 @@ extern "C" { * This function returns last error code stored, * @return non-zero if something bad happened */ -ND4J_EXPORT int lastErrorCode(); +SD_EXPORT int lastErrorCode(); /** * This function returns last error message, if last error code > 0 * @return */ -ND4J_EXPORT const char* lastErrorMessage(); +SD_EXPORT const char* lastErrorMessage(); /** * * @param p * @param len */ -ND4J_EXPORT void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len); +SD_EXPORT void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len); /** * * @param num */ -ND4J_EXPORT void setElementThreshold(int num); +SD_EXPORT void setElementThreshold(int num); /** * * @param num */ -ND4J_EXPORT void setTADThreshold(int num); +SD_EXPORT void setTADThreshold(int num); /** * @@ -120,7 +120,7 @@ ND4J_EXPORT void setTADThreshold(int num); * @param xShapeInfo * @param extraParams */ -ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers, +SD_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, @@ -137,7 +137,7 @@ ND4J_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers, +SD_EXPORT void execIndexReduce(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, @@ -156,7 +156,7 @@ ND4J_EXPORT void execIndexReduce(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -ND4J_EXPORT void execBroadcast( +SD_EXPORT void execBroadcast( Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, @@ -165,7 +165,7 @@ ND4J_EXPORT void execBroadcast( OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); -ND4J_EXPORT void execBroadcastBool( +SD_EXPORT void execBroadcastBool( Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, @@ -186,7 +186,7 @@ ND4J_EXPORT void execBroadcastBool( * @param extraParams * @param n */ -ND4J_EXPORT void execPairwiseTransform( +SD_EXPORT void execPairwiseTransform( Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, @@ -194,7 +194,7 @@ ND4J_EXPORT void execPairwiseTransform( OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); -ND4J_EXPORT void execPairwiseTransformBool( +SD_EXPORT void execPairwiseTransformBool( Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, @@ -211,26 +211,26 @@ ND4J_EXPORT void execPairwiseTransformBool( * @param result * @param resultShapeInfo */ -ND4J_EXPORT void execReduceFloat(Nd4jPointer *extraPointers, +SD_EXPORT void execReduceFloat(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); -ND4J_EXPORT void execReduceSame(Nd4jPointer *extraPointers, +SD_EXPORT void execReduceSame(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); -ND4J_EXPORT void execReduceBool(Nd4jPointer *extraPointers, +SD_EXPORT void execReduceBool(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo); -ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers, +SD_EXPORT void execReduceLong(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, @@ -245,7 +245,7 @@ ND4J_EXPORT void execReduceLong(Nd4jPointer *extraPointers, * @param result * @param resultShapeInfo */ -ND4J_EXPORT void execReduceFloat2(Nd4jPointer *extraPointers, +SD_EXPORT void execReduceFloat2(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, @@ -253,7 +253,7 @@ ND4J_EXPORT void execReduceFloat2(Nd4jPointer *extraPointers, OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); -ND4J_EXPORT void execReduceSame2(Nd4jPointer *extraPointers, +SD_EXPORT void execReduceSame2(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, @@ -261,7 +261,7 @@ ND4J_EXPORT void execReduceSame2(Nd4jPointer *extraPointers, OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); -ND4J_EXPORT void execReduceBool2(Nd4jPointer *extraPointers, +SD_EXPORT void execReduceBool2(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, @@ -269,7 +269,7 @@ ND4J_EXPORT void execReduceBool2(Nd4jPointer *extraPointers, OpaqueDataBuffer *dbDimension, Nd4jLong *hDimensionShape, Nd4jLong *dDimensionShape); -ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers, +SD_EXPORT void execReduceLong2(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, @@ -287,7 +287,7 @@ ND4J_EXPORT void execReduceLong2(Nd4jPointer *extraPointers, * @param result * @param resultShapeInfo */ -ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers, +SD_EXPORT void execReduce3(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, @@ -303,7 +303,7 @@ ND4J_EXPORT void execReduce3(Nd4jPointer *extraPointers, * @param y * @param yShapeInfo */ -ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers, +SD_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, @@ -322,7 +322,7 @@ ND4J_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -ND4J_EXPORT void execReduce3Tad(Nd4jPointer *extraPointers, +SD_EXPORT void execReduce3Tad(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, @@ -333,7 +333,7 @@ ND4J_EXPORT void execReduce3Tad(Nd4jPointer *extraPointers, Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets); -ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers, +SD_EXPORT void execReduce3All(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParamsVals, @@ -354,14 +354,14 @@ ND4J_EXPORT void execReduce3All(Nd4jPointer *extraPointers, * @param extraParams * @param n */ -ND4J_EXPORT void execScalar(Nd4jPointer *extraPointers, +SD_EXPORT void execScalar(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbScalar, Nd4jLong *hSscalarShapeInfo, Nd4jLong *dSscalarShapeInfo, void *extraParams); -ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers, +SD_EXPORT void execScalarBool(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, @@ -375,7 +375,7 @@ ND4J_EXPORT void execScalarBool(Nd4jPointer *extraPointers, * @param xShapeInfo * @param extraParams */ -ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers, +SD_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, @@ -390,7 +390,7 @@ ND4J_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers, * @param result * @param resultShapeInfo */ -ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers, +SD_EXPORT void execSummaryStats(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, @@ -407,7 +407,7 @@ ND4J_EXPORT void execSummaryStats(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers, +SD_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, void *extraParams, @@ -426,31 +426,31 @@ ND4J_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers, * @param extraParams * @param n */ -ND4J_EXPORT void execTransformFloat(Nd4jPointer *extraPointers, +SD_EXPORT void execTransformFloat(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); -ND4J_EXPORT void execTransformSame(Nd4jPointer *extraPointers, +SD_EXPORT void execTransformSame(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); -ND4J_EXPORT void execTransformBool(Nd4jPointer *extraPointers, +SD_EXPORT void execTransformBool(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); -ND4J_EXPORT void execTransformAny(Nd4jPointer *extraPointers, +SD_EXPORT void execTransformAny(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, void *extraParams); -ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers, +SD_EXPORT void execTransformStrict(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, @@ -469,7 +469,7 @@ ND4J_EXPORT void execTransformStrict(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -ND4J_EXPORT void execScalarTad(Nd4jPointer *extraPointers, +SD_EXPORT void execScalarTad(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, @@ -479,7 +479,7 @@ ND4J_EXPORT void execScalarTad(Nd4jPointer *extraPointers, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); -ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers, +SD_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeInfo, Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeInfo, Nd4jLong *dZShapeInfo, @@ -489,7 +489,7 @@ ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); -ND4J_EXPORT void specialConcat ( +SD_EXPORT void specialConcat ( Nd4jPointer *extraPointers, int dimension, int numArrays, @@ -504,9 +504,9 @@ ND4J_EXPORT void specialConcat ( * This method implementation exists only for cuda. * The other backends should have dummy method for JNI compatibility reasons. */ -ND4J_EXPORT void initializeDevicesAndFunctions(); +SD_EXPORT void initializeDevicesAndFunctions(); -ND4J_EXPORT void initializeFunctions(Nd4jPointer *functions); +SD_EXPORT void initializeFunctions(Nd4jPointer *functions); /** * This method acquires memory chunk of requested size on host side @@ -515,7 +515,7 @@ ND4J_EXPORT void initializeFunctions(Nd4jPointer *functions); * @param memorySize memory size, in bytes * @param flags optional parameter */ -ND4J_EXPORT Nd4jPointer mallocHost(Nd4jLong memorySize, int flags); +SD_EXPORT Nd4jPointer mallocHost(Nd4jLong memorySize, int flags); /** * This method acquires memory chunk of requested size on specified device @@ -525,14 +525,14 @@ ND4J_EXPORT Nd4jPointer mallocHost(Nd4jLong memorySize, int flags); * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for OpenCL that's pointer to device_id, etc * @param flags optional parameter */ -ND4J_EXPORT Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags); +SD_EXPORT Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags); /** * This method releases previously allocated host memory space * * @param pointer pointer that'll be freed */ -ND4J_EXPORT int freeHost(Nd4jPointer pointer); +SD_EXPORT int freeHost(Nd4jPointer pointer); /** * This method releases previously allocated memory space on device @@ -540,52 +540,52 @@ ND4J_EXPORT int freeHost(Nd4jPointer pointer); * @param pointer pointer that'll be freed * @param ptrToDeviceId pointer to deviceId. */ -ND4J_EXPORT int freeDevice(Nd4jPointer pointer, int deviceId); +SD_EXPORT int freeDevice(Nd4jPointer pointer, int deviceId); /** * * @return */ -ND4J_EXPORT int ompGetMaxThreads(); +SD_EXPORT int ompGetMaxThreads(); /** * * @return */ -ND4J_EXPORT int ompGetNumThreads(); +SD_EXPORT int ompGetNumThreads(); /** * * @param threads */ -ND4J_EXPORT void setOmpNumThreads(int threads); +SD_EXPORT void setOmpNumThreads(int threads); /** * * @param threads */ -ND4J_EXPORT void setOmpMinThreads(int threads); +SD_EXPORT void setOmpMinThreads(int threads); -ND4J_EXPORT bool isBlasVersionMatches(int major, int minor, int build); +SD_EXPORT bool isBlasVersionMatches(int major, int minor, int build); /** * * @return */ -ND4J_EXPORT Nd4jPointer createContext(); +SD_EXPORT Nd4jPointer createContext(); /** * * @return */ -ND4J_EXPORT Nd4jPointer createStream(); +SD_EXPORT Nd4jPointer createStream(); /** * * @return */ -ND4J_EXPORT Nd4jPointer createEvent(); +SD_EXPORT Nd4jPointer createEvent(); /** * @@ -593,89 +593,89 @@ ND4J_EXPORT Nd4jPointer createEvent(); * @param stream * @return */ -ND4J_EXPORT int registerEvent(Nd4jPointer event, Nd4jPointer stream); +SD_EXPORT int registerEvent(Nd4jPointer event, Nd4jPointer stream); /** * * @param event * @return */ -ND4J_EXPORT int destroyEvent(Nd4jPointer event); +SD_EXPORT int destroyEvent(Nd4jPointer event); /** * * @param ptrToDeviceId * @return */ -ND4J_EXPORT int setDevice(int deviceId); +SD_EXPORT int setDevice(int deviceId); /** * * @return */ -ND4J_EXPORT int getDevice(); +SD_EXPORT int getDevice(); /** * * @param stream * @return */ -ND4J_EXPORT int streamSynchronize(Nd4jPointer stream); +SD_EXPORT int streamSynchronize(Nd4jPointer stream); /** * * @param event * @return */ -ND4J_EXPORT int eventSynchronize(Nd4jPointer event); +SD_EXPORT int eventSynchronize(Nd4jPointer event); /** * * @param ptrToDeviceId * @return */ -ND4J_EXPORT Nd4jLong getDeviceFreeMemory(int deviceId); +SD_EXPORT Nd4jLong getDeviceFreeMemory(int deviceId); /** * Returns amount of free memory for current device * @return */ -ND4J_EXPORT Nd4jLong getDeviceFreeMemoryDefault(); +SD_EXPORT Nd4jLong getDeviceFreeMemoryDefault(); /** * * @param ptrToDeviceId * @return */ -ND4J_EXPORT Nd4jLong getDeviceTotalMemory(int deviceId); +SD_EXPORT Nd4jLong getDeviceTotalMemory(int deviceId); /** * * @param ptrToDeviceId * @return */ -ND4J_EXPORT int getDeviceMajor(int deviceId); +SD_EXPORT int getDeviceMajor(int deviceId); /** * This method returns amount of cached memory * @param deviceId * @return */ -ND4J_EXPORT Nd4jLong getCachedMemory(int deviceId); +SD_EXPORT Nd4jLong getCachedMemory(int deviceId); /** * * @param ptrToDeviceId * @return */ -ND4J_EXPORT int getDeviceMinor(int deviceId); +SD_EXPORT int getDeviceMinor(int deviceId); /** * * @param ptrToDeviceId * @return */ -ND4J_EXPORT const char * getDeviceName(int deviceId); +SD_EXPORT const char * getDeviceName(int deviceId); /** * @@ -686,7 +686,7 @@ ND4J_EXPORT const char * getDeviceName(int deviceId); * @param reserved * @return */ -ND4J_EXPORT int memcpySync(Nd4jPointer dst, +SD_EXPORT int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, @@ -701,7 +701,7 @@ ND4J_EXPORT int memcpySync(Nd4jPointer dst, * @param reserved * @return */ -ND4J_EXPORT int memcpyAsync(Nd4jPointer dst, +SD_EXPORT int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, @@ -716,7 +716,7 @@ ND4J_EXPORT int memcpyAsync(Nd4jPointer dst, * @param reserved * @return */ -ND4J_EXPORT int memsetSync(Nd4jPointer dst, +SD_EXPORT int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, @@ -731,7 +731,7 @@ ND4J_EXPORT int memsetSync(Nd4jPointer dst, * @param reserved * @return */ -ND4J_EXPORT int memsetAsync(Nd4jPointer dst, +SD_EXPORT int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, @@ -746,7 +746,7 @@ ND4J_EXPORT int memsetAsync(Nd4jPointer dst, * @param reserved * @return */ -ND4J_EXPORT int memcpyConstantAsync(Nd4jLong dst, +SD_EXPORT int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, @@ -756,31 +756,31 @@ ND4J_EXPORT int memcpyConstantAsync(Nd4jLong dst, * * @return */ -ND4J_EXPORT Nd4jPointer getConstantSpace(); +SD_EXPORT Nd4jPointer getConstantSpace(); /** * * @return */ -ND4J_EXPORT int getAvailableDevices(); +SD_EXPORT int getAvailableDevices(); /** * * @param reallyEnable */ -ND4J_EXPORT void enableDebugMode(bool reallyEnable); +SD_EXPORT void enableDebugMode(bool reallyEnable); /** * * @param reallyEnable */ -ND4J_EXPORT void enableVerboseMode(bool reallyEnable); +SD_EXPORT void enableVerboseMode(bool reallyEnable); /** * * @param gridSize */ -ND4J_EXPORT void setGridLimit(int gridSize); +SD_EXPORT void setGridLimit(int gridSize); typedef sd::TadPack OpaqueTadPack; @@ -792,18 +792,18 @@ typedef sd::TadPack OpaqueTadPack; * @param targetBuffer * @param offsetsBuffer */ -ND4J_EXPORT OpaqueTadPack* tadOnlyShapeInfo(Nd4jLong *xShapeInfo, +SD_EXPORT OpaqueTadPack* tadOnlyShapeInfo(Nd4jLong *xShapeInfo, int *dimension, int dimensionLength); -ND4J_EXPORT Nd4jLong* getPrimaryShapeInfo(OpaqueTadPack* pack); -ND4J_EXPORT Nd4jLong* getPrimaryOffsets(OpaqueTadPack* pack); -ND4J_EXPORT Nd4jLong* getSpecialShapeInfo(OpaqueTadPack* pack); -ND4J_EXPORT Nd4jLong* getSpecialOffsets(OpaqueTadPack* pack); -ND4J_EXPORT Nd4jLong getNumberOfTads(OpaqueTadPack* pack); -ND4J_EXPORT int getShapeInfoLength(OpaqueTadPack* pack); +SD_EXPORT Nd4jLong* getPrimaryShapeInfo(OpaqueTadPack* pack); +SD_EXPORT Nd4jLong* getPrimaryOffsets(OpaqueTadPack* pack); +SD_EXPORT Nd4jLong* getSpecialShapeInfo(OpaqueTadPack* pack); +SD_EXPORT Nd4jLong* getSpecialOffsets(OpaqueTadPack* pack); +SD_EXPORT Nd4jLong getNumberOfTads(OpaqueTadPack* pack); +SD_EXPORT int getShapeInfoLength(OpaqueTadPack* pack); -ND4J_EXPORT void deleteTadPack(OpaqueTadPack* ptr); +SD_EXPORT void deleteTadPack(OpaqueTadPack* ptr); /* * PullRow special op @@ -823,7 +823,7 @@ ND4J_EXPORT void deleteTadPack(OpaqueTadPack* ptr); * @param zTadShapeInfo * @param zTadOffsets */ -ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers, +SD_EXPORT void pullRows(Nd4jPointer *extraPointers, OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo, OpaqueDataBuffer *dbZ, Nd4jLong *zShapeInfo, Nd4jLong *dzShapeInfo, Nd4jLong n, @@ -842,7 +842,7 @@ ND4J_EXPORT void pullRows(Nd4jPointer *extraPointers, * @param length * @param propagate */ -ND4J_EXPORT void average(Nd4jPointer *extras, +SD_EXPORT void average(Nd4jPointer *extras, Nd4jPointer *x, Nd4jLong *xShapeInfo, Nd4jPointer *dx, Nd4jLong *dxShapeInfo, void *z, Nd4jLong *zShapeInfo, @@ -852,7 +852,7 @@ ND4J_EXPORT void average(Nd4jPointer *extras, bool propagate); -ND4J_EXPORT void accumulate(Nd4jPointer *extras, +SD_EXPORT void accumulate(Nd4jPointer *extras, Nd4jPointer *x, Nd4jLong *xShapeInfo, Nd4jPointer *dx, Nd4jLong *dxShapeInfo, void *z, Nd4jLong *zShapeInfo, @@ -868,18 +868,18 @@ ND4J_EXPORT void accumulate(Nd4jPointer *extras, * * @param enable */ -ND4J_EXPORT void enableP2P(bool enable); +SD_EXPORT void enableP2P(bool enable); /** * */ -ND4J_EXPORT void checkP2P(); +SD_EXPORT void checkP2P(); /** * * @return */ -ND4J_EXPORT bool isP2PAvailable(); +SD_EXPORT bool isP2PAvailable(); /** * Shuffle methods @@ -897,7 +897,7 @@ ND4J_EXPORT bool isP2PAvailable(); * @param tadShapeInfo * @param tadOffsets */ -ND4J_EXPORT void shuffle(Nd4jPointer *extras, +SD_EXPORT void shuffle(Nd4jPointer *extras, Nd4jPointer *x, Nd4jPointer *xShapeInfo, Nd4jPointer *dx, Nd4jPointer *dxShapeInfo, Nd4jPointer *z, Nd4jPointer *zShapeInfo, @@ -921,14 +921,14 @@ ND4J_EXPORT void shuffle(Nd4jPointer *extras, * @param dstType * @param z */ -ND4J_EXPORT void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer x, Nd4jLong N, int dstType, Nd4jPointer z); +SD_EXPORT void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer x, Nd4jLong N, int dstType, Nd4jPointer z); /** * * @return */ -ND4J_EXPORT bool isExperimentalEnabled(); +SD_EXPORT bool isExperimentalEnabled(); /** * Aggregate @@ -949,7 +949,7 @@ ND4J_EXPORT bool isExperimentalEnabled(); * @param realArguments * @param numRealArguments */ -ND4J_EXPORT void execAggregate(Nd4jPointer *extraPointers, +SD_EXPORT void execAggregate(Nd4jPointer *extraPointers, int opNum, void **arguments, int numArguments, @@ -964,7 +964,7 @@ ND4J_EXPORT void execAggregate(Nd4jPointer *extraPointers, sd::DataType dtype); -ND4J_EXPORT void batchExecutor(Nd4jPointer *extraPointers, +SD_EXPORT void batchExecutor(Nd4jPointer *extraPointers, int numAggregates, int opNum, int maxArgs, @@ -976,7 +976,7 @@ ND4J_EXPORT void batchExecutor(Nd4jPointer *extraPointers, void *ptrToArguments, sd::DataType dtype); -ND4J_EXPORT void execAggregateBatch(Nd4jPointer *extraPointers, +SD_EXPORT void execAggregateBatch(Nd4jPointer *extraPointers, int numAggregates, int opNum, int maxArgs, @@ -1001,7 +1001,7 @@ ND4J_EXPORT void execAggregateBatch(Nd4jPointer *extraPointers, * @param zShapeBuffer * @param extraArguments */ -ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers, +SD_EXPORT void execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, OpaqueDataBuffer *dbZ, Nd4jLong *hZShapeBuffer, Nd4jLong *dZShapeBuffer, @@ -1020,7 +1020,7 @@ ND4J_EXPORT void execRandom(Nd4jPointer *extraPointers, * @param zShapeBuffer * @param extraArguments */ -ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers, +SD_EXPORT void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer, @@ -1039,7 +1039,7 @@ ND4J_EXPORT void execRandom3(Nd4jPointer *extraPointers, * @param zShapeBuffer * @param extraArguments */ -ND4J_EXPORT void execRandom2(Nd4jPointer *extraPointers, +SD_EXPORT void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, OpaqueDataBuffer *dbX, Nd4jLong *hXShapeBuffer, Nd4jLong *dXShapeBuffer, @@ -1055,7 +1055,7 @@ ND4J_EXPORT void execRandom2(Nd4jPointer *extraPointers, * @param ptrToBuffer * @return */ -ND4J_EXPORT Nd4jPointer initRandom(Nd4jPointer *extraPointers, +SD_EXPORT Nd4jPointer initRandom(Nd4jPointer *extraPointers, long seed, long bufferSize, Nd4jPointer ptrToBuffer); @@ -1066,7 +1066,7 @@ ND4J_EXPORT Nd4jPointer initRandom(Nd4jPointer *extraPointers, * @param seed * @param ptrRandom */ -ND4J_EXPORT void refreshBuffer(Nd4jPointer *extraPointers, +SD_EXPORT void refreshBuffer(Nd4jPointer *extraPointers, long seed, Nd4jPointer ptrRandom); @@ -1076,7 +1076,7 @@ ND4J_EXPORT void refreshBuffer(Nd4jPointer *extraPointers, * @param seed * @param ptrRandom */ -ND4J_EXPORT void reSeedBuffer(Nd4jPointer *extraPointers, +SD_EXPORT void reSeedBuffer(Nd4jPointer *extraPointers, long seed, Nd4jPointer ptrRandom); @@ -1084,7 +1084,7 @@ ND4J_EXPORT void reSeedBuffer(Nd4jPointer *extraPointers, * * @param ptrRandom */ -ND4J_EXPORT void destroyRandom(Nd4jPointer ptrRandom); +SD_EXPORT void destroyRandom(Nd4jPointer ptrRandom); } @@ -1202,7 +1202,7 @@ static Nd4jPointer numpyFromNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLo * @param npyArray * @return */ -ND4J_EXPORT Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray); +SD_EXPORT Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray); /** @@ -1321,7 +1321,7 @@ static void* getNpyArrayFromMap(void *map, int index){ throw std::runtime_error("No array at index."); } -ND4J_EXPORT int dataTypeFromNpyHeader(void *header); +SD_EXPORT int dataTypeFromNpyHeader(void *header); static void* getNpyArrayData(void *npArray){ cnpy::NpyArray* npyArray2 = reinterpret_cast(npArray); @@ -1405,7 +1405,7 @@ static void releaseNumpy(Nd4jPointer npyArray) { * @param buffer the buffer pointer to check * @return */ -ND4J_EXPORT int lengthForShapeBufferPointer(Nd4jPointer buffer); +SD_EXPORT int lengthForShapeBufferPointer(Nd4jPointer buffer); /** @@ -1415,7 +1415,7 @@ ND4J_EXPORT int lengthForShapeBufferPointer(Nd4jPointer buffer); * @return the pointer for the given address */ -ND4J_EXPORT Nd4jPointer pointerForAddress(Nd4jLong address); +SD_EXPORT Nd4jPointer pointerForAddress(Nd4jLong address); /** * This method takes single N-dimensional tensor, and copies its TADs to target arrays @@ -1426,44 +1426,44 @@ ND4J_EXPORT Nd4jPointer pointerForAddress(Nd4jLong address); * @param zShapeInfo * @return */ -ND4J_EXPORT void tear(Nd4jPointer *extraPointers, +SD_EXPORT void tear(Nd4jPointer *extraPointers, OpaqueDataBuffer *dbX, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo, Nd4jPointer *targets, Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); -ND4J_EXPORT Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *dx, Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold); -ND4J_EXPORT void decodeBitmap(Nd4jPointer *extraPointers, void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo); +SD_EXPORT Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *dx, Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold); +SD_EXPORT void decodeBitmap(Nd4jPointer *extraPointers, void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo); -ND4J_EXPORT void encodeThresholdP1(Nd4jPointer *extraPointers, void *dx, Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold); -ND4J_EXPORT void encodeThresholdP2Int(Nd4jPointer *extraPointers, int *dx, Nd4jLong N, int *dz); -ND4J_EXPORT void encodeThresholdP3(Nd4jPointer *extraPointers, void *dx, Nd4jLong *xShapeInfo, int *offsets, Nd4jLong N, int *dz); +SD_EXPORT void encodeThresholdP1(Nd4jPointer *extraPointers, void *dx, Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold); +SD_EXPORT void encodeThresholdP2Int(Nd4jPointer *extraPointers, int *dx, Nd4jLong N, int *dz); +SD_EXPORT void encodeThresholdP3(Nd4jPointer *extraPointers, void *dx, Nd4jLong *xShapeInfo, int *offsets, Nd4jLong N, int *dz); -ND4J_EXPORT void decodeThreshold(Nd4jPointer *extraPointers, void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo); +SD_EXPORT void decodeThreshold(Nd4jPointer *extraPointers, void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo); -ND4J_EXPORT void sort(Nd4jPointer *extraPointers, +SD_EXPORT void sort(Nd4jPointer *extraPointers, void *x, Nd4jLong *xShapeInfo, void *dx, Nd4jLong *dxShapeInfo, bool descending); -ND4J_EXPORT void sortByKey(Nd4jPointer *extraPointers, +SD_EXPORT void sortByKey(Nd4jPointer *extraPointers, void *x, Nd4jLong *xShapeInfo, void *dx, Nd4jLong *dxShapeInfo, void *y, Nd4jLong *yShapeInfo, void *dy, Nd4jLong *dyShapeInfo, bool descending); -ND4J_EXPORT void sortByValue(Nd4jPointer *extraPointers, +SD_EXPORT void sortByValue(Nd4jPointer *extraPointers, void *x, Nd4jLong *xShapeInfo, void *dx, Nd4jLong *dxShapeInfo, void *y, Nd4jLong *yShapeInfo, void *dy, Nd4jLong *dyShapeInfo, bool descending); -ND4J_EXPORT void sortTad(Nd4jPointer *extraPointers, +SD_EXPORT void sortTad(Nd4jPointer *extraPointers, void *x, Nd4jLong *xShapeInfo, void *dx, Nd4jLong *dxShapeInfo, int *dimension, @@ -1472,7 +1472,7 @@ ND4J_EXPORT void sortTad(Nd4jPointer *extraPointers, Nd4jLong *tadOffsets, bool descending); -ND4J_EXPORT void sortTadByKey(Nd4jPointer *extraPointers, +SD_EXPORT void sortTadByKey(Nd4jPointer *extraPointers, void *x, Nd4jLong *xShapeInfo, void *dx, Nd4jLong *dxShapeInfo, void *y, Nd4jLong *yShapeInfo, @@ -1481,7 +1481,7 @@ ND4J_EXPORT void sortTadByKey(Nd4jPointer *extraPointers, int dimensionLength, bool descending); -ND4J_EXPORT void sortTadByValue(Nd4jPointer *extraPointers, +SD_EXPORT void sortTadByValue(Nd4jPointer *extraPointers, void *x, Nd4jLong *xShapeInfo, void *dx, Nd4jLong *dxShapeInfo, void *y, Nd4jLong *yShapeInfo, @@ -1492,178 +1492,178 @@ ND4J_EXPORT void sortTadByValue(Nd4jPointer *extraPointers, // special sort impl for sorting out COO indices and values -ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank); +SD_EXPORT void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank); -ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length); +SD_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length); -ND4J_EXPORT void munmapFile(Nd4jPointer *extraPointers, Nd4jLong* ptrMap, Nd4jLong length); +SD_EXPORT void munmapFile(Nd4jPointer *extraPointers, Nd4jLong* ptrMap, Nd4jLong length); typedef sd::graph::ResultWrapper OpaqueResultWrapper; // flatbuffers execution -ND4J_EXPORT OpaqueResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer); +SD_EXPORT OpaqueResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer); -ND4J_EXPORT Nd4jLong getResultWrapperSize(OpaqueResultWrapper* ptr); -ND4J_EXPORT Nd4jPointer getResultWrapperPointer(OpaqueResultWrapper* ptr); +SD_EXPORT Nd4jLong getResultWrapperSize(OpaqueResultWrapper* ptr); +SD_EXPORT Nd4jPointer getResultWrapperPointer(OpaqueResultWrapper* ptr); -ND4J_EXPORT const char* getAllCustomOps(); +SD_EXPORT const char* getAllCustomOps(); -ND4J_EXPORT const char* getAllOperations(); +SD_EXPORT const char* getAllOperations(); // customOp executioner -ND4J_EXPORT int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace); -ND4J_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext); +SD_EXPORT int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace); +SD_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext); typedef sd::ShapeList OpaqueShapeList; -ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs); -ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs); +SD_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs); +SD_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs); -ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list); -ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i); +SD_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list); +SD_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i); -ND4J_EXPORT void deleteShapeList(Nd4jPointer shapeList); +SD_EXPORT void deleteShapeList(Nd4jPointer shapeList); -ND4J_EXPORT int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer); +SD_EXPORT int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer); typedef sd::graph::VariablesSet OpaqueVariablesSet; typedef sd::graph::Variable OpaqueVariable; -ND4J_EXPORT OpaqueVariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs); +SD_EXPORT OpaqueVariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs); -ND4J_EXPORT Nd4jLong getVariablesSetSize(OpaqueVariablesSet* set); -ND4J_EXPORT Nd4jStatus getVariablesSetStatus(OpaqueVariablesSet* set); -ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariablesSet* set, Nd4jLong i); -ND4J_EXPORT int getVariableId(OpaqueVariable* variable); -ND4J_EXPORT int getVariableIndex(OpaqueVariable* variable); -ND4J_EXPORT const char* getVariableName(OpaqueVariable* variable); -ND4J_EXPORT Nd4jLong* getVariableShape(OpaqueVariable* variable); -ND4J_EXPORT void* getVariableBuffer(OpaqueVariable* variable); +SD_EXPORT Nd4jLong getVariablesSetSize(OpaqueVariablesSet* set); +SD_EXPORT Nd4jStatus getVariablesSetStatus(OpaqueVariablesSet* set); +SD_EXPORT OpaqueVariable* getVariable(OpaqueVariablesSet* set, Nd4jLong i); +SD_EXPORT int getVariableId(OpaqueVariable* variable); +SD_EXPORT int getVariableIndex(OpaqueVariable* variable); +SD_EXPORT const char* getVariableName(OpaqueVariable* variable); +SD_EXPORT Nd4jLong* getVariableShape(OpaqueVariable* variable); +SD_EXPORT void* getVariableBuffer(OpaqueVariable* variable); -ND4J_EXPORT int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId); +SD_EXPORT int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId); -ND4J_EXPORT void deleteCharArray(Nd4jPointer pointer); -ND4J_EXPORT void deleteIntArray(Nd4jPointer pointer); -ND4J_EXPORT void deleteLongArray(Nd4jPointer pointer); -ND4J_EXPORT void deletePointerArray(Nd4jPointer pointer); +SD_EXPORT void deleteCharArray(Nd4jPointer pointer); +SD_EXPORT void deleteIntArray(Nd4jPointer pointer); +SD_EXPORT void deleteLongArray(Nd4jPointer pointer); +SD_EXPORT void deletePointerArray(Nd4jPointer pointer); -ND4J_EXPORT void deleteVariablesSet(OpaqueVariablesSet* pointer); +SD_EXPORT void deleteVariablesSet(OpaqueVariablesSet* pointer); // GraphState creation -ND4J_EXPORT Nd4jPointer getGraphState(Nd4jLong id); +SD_EXPORT Nd4jPointer getGraphState(Nd4jLong id); -ND4J_EXPORT void deleteGraphState(Nd4jPointer state); +SD_EXPORT void deleteGraphState(Nd4jPointer state); -ND4J_EXPORT void deleteResultWrapper(Nd4jPointer ptr); +SD_EXPORT void deleteResultWrapper(Nd4jPointer ptr); -ND4J_EXPORT int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer x, Nd4jLong *xShapeInfo, int N, float threshold); +SD_EXPORT int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer x, Nd4jLong *xShapeInfo, int N, float threshold); // this method executes op that requires scope to be present: if/while/cond/whatever -ND4J_EXPORT Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs); +SD_EXPORT Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs); //void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer); -ND4J_EXPORT Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int length); -ND4J_EXPORT Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr); -ND4J_EXPORT char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr); -ND4J_EXPORT void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr); +SD_EXPORT Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int length); +SD_EXPORT Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr); +SD_EXPORT char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr); +SD_EXPORT void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr); -ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, +SD_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets, void* dX, Nd4jLong* dXShapeInfo, Nd4jLong* dXOffsets, void* hY, Nd4jLong* hYShapeInfo, Nd4jLong* hYOffsets, void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets, void* hIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo); -ND4J_EXPORT void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo); +SD_EXPORT void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo); typedef sd::ConstantDataBuffer OpaqueConstantDataBuffer; -ND4J_EXPORT OpaqueConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty); +SD_EXPORT OpaqueConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty); -ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong *data, int length); -ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferDouble(sd::DataType dtype, double *data, int length); -ND4J_EXPORT OpaqueConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::ConstantDescriptor *descriptor); +SD_EXPORT OpaqueConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong *data, int length); +SD_EXPORT OpaqueConstantDataBuffer* constantBufferDouble(sd::DataType dtype, double *data, int length); +SD_EXPORT OpaqueConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::ConstantDescriptor *descriptor); -ND4J_EXPORT Nd4jPointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer* dbf); -ND4J_EXPORT Nd4jPointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer* dbf); -ND4J_EXPORT Nd4jLong getConstantDataBufferLength(OpaqueConstantDataBuffer* dbf); -ND4J_EXPORT Nd4jLong getConstantDataBufferSizeOf(OpaqueConstantDataBuffer* dbf); +SD_EXPORT Nd4jPointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer* dbf); +SD_EXPORT Nd4jPointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer* dbf); +SD_EXPORT Nd4jLong getConstantDataBufferLength(OpaqueConstantDataBuffer* dbf); +SD_EXPORT Nd4jLong getConstantDataBufferSizeOf(OpaqueConstantDataBuffer* dbf); -ND4J_EXPORT void deleteShapeBuffer(OpaqueConstantDataBuffer* ptr); +SD_EXPORT void deleteShapeBuffer(OpaqueConstantDataBuffer* ptr); typedef sd::graph::Context OpaqueContext; typedef sd::graph::RandomGenerator OpaqueRandomGenerator; -ND4J_EXPORT OpaqueContext* createGraphContext(int nodeId); -ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr); -ND4J_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow); -ND4J_EXPORT void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride); -ND4J_EXPORT void ctxSetExecutionMode(OpaqueContext* ptr, int execMode); -ND4J_EXPORT void ctxPurge(OpaqueContext* ptr); -ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace); -ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer); -ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); -ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); -ND4J_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); -ND4J_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); -ND4J_EXPORT void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments); -ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments); -ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments); -ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments); -ND4J_EXPORT void deleteGraphContext(OpaqueContext* ptr); - -ND4J_EXPORT OpaqueRandomGenerator* createRandomGenerator(Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); -ND4J_EXPORT Nd4jLong getRandomGeneratorRootState(OpaqueRandomGenerator* ptr); -ND4J_EXPORT Nd4jLong getRandomGeneratorNodeState(OpaqueRandomGenerator* ptr); -ND4J_EXPORT void setRandomGeneratorStates(OpaqueRandomGenerator* ptr, Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); -ND4J_EXPORT int getRandomGeneratorRelativeInt(OpaqueRandomGenerator* ptr, Nd4jLong index); -ND4J_EXPORT Nd4jLong getRandomGeneratorRelativeLong(OpaqueRandomGenerator* ptr, Nd4jLong index); -ND4J_EXPORT void deleteRandomGenerator(OpaqueRandomGenerator* ptr); - -ND4J_EXPORT const char* runLightBenchmarkSuit(bool printOut); -ND4J_EXPORT const char* runFullBenchmarkSuit(bool printOut); +SD_EXPORT OpaqueContext* createGraphContext(int nodeId); +SD_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr); +SD_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow); +SD_EXPORT void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride); +SD_EXPORT void ctxSetExecutionMode(OpaqueContext* ptr, int execMode); +SD_EXPORT void ctxPurge(OpaqueContext* ptr); +SD_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace); +SD_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer); +SD_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); +SD_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); +SD_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); +SD_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); +SD_EXPORT void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments); +SD_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments); +SD_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments); +SD_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments); +SD_EXPORT void deleteGraphContext(OpaqueContext* ptr); + +SD_EXPORT OpaqueRandomGenerator* createRandomGenerator(Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); +SD_EXPORT Nd4jLong getRandomGeneratorRootState(OpaqueRandomGenerator* ptr); +SD_EXPORT Nd4jLong getRandomGeneratorNodeState(OpaqueRandomGenerator* ptr); +SD_EXPORT void setRandomGeneratorStates(OpaqueRandomGenerator* ptr, Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); +SD_EXPORT int getRandomGeneratorRelativeInt(OpaqueRandomGenerator* ptr, Nd4jLong index); +SD_EXPORT Nd4jLong getRandomGeneratorRelativeLong(OpaqueRandomGenerator* ptr, Nd4jLong index); +SD_EXPORT void deleteRandomGenerator(OpaqueRandomGenerator* ptr); + +SD_EXPORT const char* runLightBenchmarkSuit(bool printOut); +SD_EXPORT const char* runFullBenchmarkSuit(bool printOut); typedef sd::LaunchContext OpaqueLaunchContext; -ND4J_EXPORT OpaqueLaunchContext* defaultLaunchContext(); -ND4J_EXPORT Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc); -ND4J_EXPORT Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc); -ND4J_EXPORT Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc); -ND4J_EXPORT Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc); -ND4J_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc); -ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc); -ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc); - -ND4J_EXPORT OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth); -ND4J_EXPORT OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset); -ND4J_EXPORT Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements); -ND4J_EXPORT void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes); -ND4J_EXPORT void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes); -ND4J_EXPORT void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT int dbLocality(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT int dbDeviceId(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId); -ND4J_EXPORT void dbTickHostRead(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbTickHostWrite(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbClose(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void deleteDataBuffer(OpaqueDataBuffer *dataBuffer); -ND4J_EXPORT void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements); - - -ND4J_EXPORT int binaryLevel(); -ND4J_EXPORT int optimalLevel(); - -ND4J_EXPORT bool isMinimalRequirementsMet(); -ND4J_EXPORT bool isOptimalRequirementsMet(); +SD_EXPORT OpaqueLaunchContext* defaultLaunchContext(); +SD_EXPORT Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc); +SD_EXPORT Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc); +SD_EXPORT Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc); +SD_EXPORT Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc); +SD_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc); +SD_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc); +SD_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc); + +SD_EXPORT OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth); +SD_EXPORT OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset); +SD_EXPORT Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer); +SD_EXPORT Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer); +SD_EXPORT void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements); +SD_EXPORT void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer); +SD_EXPORT void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer); +SD_EXPORT void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes); +SD_EXPORT void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes); +SD_EXPORT void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer); +SD_EXPORT void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer); +SD_EXPORT int dbLocality(OpaqueDataBuffer *dataBuffer); +SD_EXPORT int dbDeviceId(OpaqueDataBuffer *dataBuffer); +SD_EXPORT void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId); +SD_EXPORT void dbTickHostRead(OpaqueDataBuffer *dataBuffer); +SD_EXPORT void dbTickHostWrite(OpaqueDataBuffer *dataBuffer); +SD_EXPORT void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer); +SD_EXPORT void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer); +SD_EXPORT void dbClose(OpaqueDataBuffer *dataBuffer); +SD_EXPORT void deleteDataBuffer(OpaqueDataBuffer *dataBuffer); +SD_EXPORT void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements); + + +SD_EXPORT int binaryLevel(); +SD_EXPORT int optimalLevel(); + +SD_EXPORT bool isMinimalRequirementsMet(); +SD_EXPORT bool isOptimalRequirementsMet(); } diff --git a/libnd4j/include/legacy/impl/cnpy.cpp b/libnd4j/include/legacy/impl/cnpy.cpp index ee4fa36b0f81..c24c016875e0 100644 --- a/libnd4j/include/legacy/impl/cnpy.cpp +++ b/libnd4j/include/legacy/impl/cnpy.cpp @@ -729,6 +729,6 @@ std::vector cnpy::createNpyHeader(const void *vdata, return header; } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector cnpy::createNpyHeader, (const void *data, const unsigned int *shape, const unsigned int ndims, unsigned int wordSize), LIBND4J_TYPES); -//template ND4J_EXPORT std::vector cnpy::createNpyHeader(const void *data, const unsigned int *shape, const unsigned int ndims, unsigned int wordSize); -template ND4J_EXPORT void cnpy::npy_save(std::string fname, const float* data, const unsigned int* shape, const unsigned int ndims, std::string mode); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT std::vector cnpy::createNpyHeader, (const void *data, const unsigned int *shape, const unsigned int ndims, unsigned int wordSize), LIBND4J_TYPES); +//template SD_EXPORT std::vector cnpy::createNpyHeader(const void *data, const unsigned int *shape, const unsigned int ndims, unsigned int wordSize); +template SD_EXPORT void cnpy::npy_save(std::string fname, const float* data, const unsigned int* shape, const unsigned int ndims, std::string mode); diff --git a/libnd4j/include/loops/cpu/broadcasting_bool.hpp b/libnd4j/include/loops/cpu/broadcasting_bool.hpp index 21b40cb55e67..94f7a335f959 100644 --- a/libnd4j/include/loops/cpu/broadcasting_bool.hpp +++ b/libnd4j/include/loops/cpu/broadcasting_bool.hpp @@ -756,7 +756,7 @@ void BroadcastBool::exec(const void *vx, const Nd4jLong *xShapeInfo, } } - //BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES); + //BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES); } diff --git a/libnd4j/include/loops/cpu/broadcasting_int.hpp b/libnd4j/include/loops/cpu/broadcasting_int.hpp index 456994b1608c..2e6fe55cf0d1 100644 --- a/libnd4j/include/loops/cpu/broadcasting_int.hpp +++ b/libnd4j/include/loops/cpu/broadcasting_int.hpp @@ -740,6 +740,6 @@ void BroadcastInt::exec(const void *vx, const Nd4jLong *xShapeInfo, } } -//BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES); +//BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p0.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p0.cpp index 08ebd92f7bcc..a21ea1109d12 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p0.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES_0, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_0, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p1.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p1.cpp index 16e4c817aaa3..8cb7bc865fa4 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p1.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES_1, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_1, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p2.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p2.cpp index 10b32ca41dfb..b073e4603652 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p2.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES_2, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_2, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p3.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p3.cpp index 547ddd371535..6d5032a88f04 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p3.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES_3, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_3, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p4.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p4.cpp index 3c7dee0a0ce5..e312a564387a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p4.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES_4, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_4, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p5.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p5.cpp index b71925dab7cd..2f37d6505e1a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p5.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES_5, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_5, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p6.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p6.cpp index 23eedd289b87..e15adcd9fcba 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p6.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES_6, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_6, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p7.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p7.cpp index c18e7641ea4c..4dfc22073780 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p7.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES_7, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_7, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p8.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p8.cpp index efee34519197..ab59a846d87a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p8.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES_8, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_8, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p9.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p9.cpp index 2ab19328532b..e43382ec032b 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p9.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES_9, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_9, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p0.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p0.cpp index d3f5ada43a58..aec26e5ea83d 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p0.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES_0); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p1.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p1.cpp index 82969bdb008a..50cd1972268c 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p1.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES_1); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p2.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p2.cpp index 53d928111bc2..807a613bf832 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p2.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES_2); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p3.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p3.cpp index eba7b78d1d1c..26dfa1985cd2 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p3.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES_3); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p4.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p4.cpp index 47b7350f23c7..652974b39d58 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p4.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES_4); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_4); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p5.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p5.cpp index 3afad08f6334..4159e5d0b7be 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p5.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES_5); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_5); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p6.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p6.cpp index 286c2680ffaf..1fb44733ab2b 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p6.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES_6); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_6); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p7.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p7.cpp index 242441561e7b..72d127f312f8 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p7.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES_7); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_7); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p0.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p0.cpp index 943186a8a6e7..059ef45b97d7 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p0.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_0); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p1.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p1.cpp index b38a1c801f5e..b93b66aef95e 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p1.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_1); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p10.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p10.cpp index 98330500776d..09c6cbd50e4c 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p10.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p10.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_10); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_10); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p11.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p11.cpp index 206b1476346f..33da9553fb3b 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p11.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p11.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_11); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_11); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p12.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p12.cpp index 825c07adf62c..f85e7ff9c592 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p12.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p12.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_12); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_12); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p2.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p2.cpp index 341f1afb4635..77242a3145e8 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p2.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_2); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p3.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p3.cpp index 9aa4c227bb5b..683629bf1b95 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p3.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_3); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p4.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p4.cpp index 7f68bb1f83ba..fc720385475f 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p4.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_4); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_4); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p5.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p5.cpp index d2e586bf8ef2..f9ce462c993b 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p5.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_5); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_5); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p6.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p6.cpp index a9db2f7f8000..dbf09ea2273e 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p6.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_6); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_6); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p7.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p7.cpp index 9a2111ee5c5f..700a5c771c59 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p7.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_7); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_7); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p8.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p8.cpp index 4bbd88ba656c..60720b2da364 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p8.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_8); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_8); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p9.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p9.cpp index 406a8f8e28cc..c0255f2d3d26 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p9.cpp @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_9); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_9); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_0.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_0.cpp index 89b85485ac3d..137258f77d62 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_0.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_0, (sd::DataType::INT32, int32_t)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_0, (sd::DataType::INT32, int32_t)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_1.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_1.cpp index ada7844cb3c6..3aaf3fde7282 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_1.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_1, (sd::DataType::INT32, int32_t)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_1, (sd::DataType::INT32, int32_t)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_2.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_2.cpp index 47dce2d5a38d..c4f87dfae7da 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_2.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_2, (sd::DataType::INT32, int32_t)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_2, (sd::DataType::INT32, int32_t)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_3.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_3.cpp index c3d33e7f1c3d..1a86d3eb4cc5 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_3.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_3, (sd::DataType::INT32, int32_t)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_3, (sd::DataType::INT32, int32_t)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_4.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_4.cpp index 37a81e441281..d263456400d9 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_4.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_4, (sd::DataType::INT32, int32_t)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_4, (sd::DataType::INT32, int32_t)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_5.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_5.cpp index 1d6555ddf7d4..4195c48a8dcf 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_5.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_5, (sd::DataType::INT32, int32_t)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_5, (sd::DataType::INT32, int32_t)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_6.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_6.cpp index 0bb8aef4da7c..b6966425da82 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_6.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_6, (sd::DataType::INT32, int32_t)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_6, (sd::DataType::INT32, int32_t)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_7.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_7.cpp index a7d3c733f81d..931d9a5ad91f 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_7.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_7, (sd::DataType::INT32, int32_t)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_7, (sd::DataType::INT32, int32_t)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_8.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_8.cpp index 8c5de9653bb5..6b282d8fb102 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_8.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_8, (sd::DataType::INT32, int32_t)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_8, (sd::DataType::INT32, int32_t)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_9.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_9.cpp index f61d604e222a..17d14a835381 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_9.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_9, (sd::DataType::INT32, int32_t)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_9, (sd::DataType::INT32, int32_t)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_0.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_0.cpp index d399f5e0ee06..63b5347a176c 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_0.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_0, (sd::DataType::INT64, Nd4jLong)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_0, (sd::DataType::INT64, Nd4jLong)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_1.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_1.cpp index c4df4d2e4ad5..b7bab85cba95 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_1.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_1, (sd::DataType::INT64, Nd4jLong)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_1, (sd::DataType::INT64, Nd4jLong)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_2.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_2.cpp index 538e369eb93d..eb4217f66390 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_2.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_2, (sd::DataType::INT64, Nd4jLong)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_2, (sd::DataType::INT64, Nd4jLong)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_3.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_3.cpp index b0d082bce3fb..fceeb38298d0 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_3.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_3, (sd::DataType::INT64, Nd4jLong)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_3, (sd::DataType::INT64, Nd4jLong)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_4.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_4.cpp index 98e13bb63225..0bb478598c3f 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_4.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_4, (sd::DataType::INT64, Nd4jLong)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_4, (sd::DataType::INT64, Nd4jLong)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_5.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_5.cpp index 4b7f599d9a31..851fe1edb26d 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_5.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_5, (sd::DataType::INT64, Nd4jLong)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_5, (sd::DataType::INT64, Nd4jLong)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_6.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_6.cpp index 8d7de9822288..b9268e519770 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_6.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_6, (sd::DataType::INT64, Nd4jLong)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_6, (sd::DataType::INT64, Nd4jLong)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_7.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_7.cpp index 8f9befddb007..c17d61930394 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_7.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_7, (sd::DataType::INT64, Nd4jLong)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_7, (sd::DataType::INT64, Nd4jLong)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_8.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_8.cpp index b38112631abe..ddea061ac185 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_8.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_8, (sd::DataType::INT64, Nd4jLong)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_8, (sd::DataType::INT64, Nd4jLong)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_9.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_9.cpp index baacdc432c46..79a6ddac2730 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_9.cpp @@ -23,6 +23,6 @@ namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_9, (sd::DataType::INT64, Nd4jLong)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_9, (sd::DataType::INT64, Nd4jLong)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p0.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p0.cpp index d498a4400586..3dbc22427690 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p0.cpp @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_0); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p1.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p1.cpp index 2a665d9d2330..607467b47884 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p1.cpp @@ -22,7 +22,7 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_1); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p10.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p10.cpp index 4a8aaf94adea..365ff223a933 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p10.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p10.cpp @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_10); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_10); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p11.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p11.cpp index 1f4eb1389735..6222e487aeef 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p11.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p11.cpp @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_11); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_11); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p12.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p12.cpp index 3c0984db9cc3..9a9909bca945 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p12.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p12.cpp @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_12); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_12); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p2.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p2.cpp index 0725ae862ad8..83bee2bb3b95 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p2.cpp @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_2); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p3.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p3.cpp index f9dcf3519824..804a8887521b 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p3.cpp @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_3); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p4.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p4.cpp index a7b63427df4c..c244607b3c5e 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p4.cpp @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_4); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_4); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p5.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p5.cpp index 3f8557ea9f8b..51043f20f9d6 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p5.cpp @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_5); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_5); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p6.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p6.cpp index 2b5dc9ed470d..02ed81a9a27f 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p6.cpp @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_6); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_6); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p7.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p7.cpp index f5deef7195f3..9cd8ff32a9fc 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p7.cpp @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_7); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_7); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p8.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p8.cpp index e2fa75bbb438..0f57bc913d9a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p8.cpp @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_8); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_8); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p9.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p9.cpp index eb3da276e382..6ef3f7e07bf0 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p9.cpp @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_9); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_9); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/random_0.cpp b/libnd4j/include/loops/cpu/compilation_units/random_0.cpp index 6424ccb6ecf6..ef5c075533e0 100644 --- a/libnd4j/include/loops/cpu/compilation_units/random_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/random_0.cpp @@ -22,6 +22,6 @@ namespace functions { namespace random { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES_0); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/random_1.cpp b/libnd4j/include/loops/cpu/compilation_units/random_1.cpp index 316d55bf6373..c4ec6bc1d855 100644 --- a/libnd4j/include/loops/cpu/compilation_units/random_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/random_1.cpp @@ -22,6 +22,6 @@ namespace functions { namespace random { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES_1); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/random_2.cpp b/libnd4j/include/loops/cpu/compilation_units/random_2.cpp index 90d080b632f6..d766d5caf115 100644 --- a/libnd4j/include/loops/cpu/compilation_units/random_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/random_2.cpp @@ -22,6 +22,6 @@ namespace functions { namespace random { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES_2); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/random_3.cpp b/libnd4j/include/loops/cpu/compilation_units/random_3.cpp index 97e5211e8a63..08032c8b4bf9 100644 --- a/libnd4j/include/loops/cpu/compilation_units/random_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/random_3.cpp @@ -22,6 +22,6 @@ namespace functions { namespace random { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES_3); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_0.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_0.cpp index 19483c1dfda1..9b4c36769bbc 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_0.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_0, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_0, FLOAT_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_1.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_1.cpp index 88225bd85c7e..c4e77433ab7a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_1.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_1, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_1, FLOAT_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_2.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_2.cpp index 7bed85c5d7b7..327c7d47e73c 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_2.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_2, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_2, FLOAT_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_3.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_3.cpp index 87042d34205b..d26a609040fd 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_3.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_3, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_3, FLOAT_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_4.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_4.cpp index 0802e11f4026..8dac72f07052 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_4.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_4, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_4, FLOAT_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_5.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_5.cpp index 87ec2d3f8761..0b35b957e574 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_5.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_5, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_5, FLOAT_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_6.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_6.cpp index 10dc7d69b8be..b7a4f0e24242 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_6.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_6, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_6, FLOAT_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_7.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_7.cpp index 28ba56376bd2..cc66ed99b8d0 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_7.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_7, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_7, FLOAT_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_8.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_8.cpp index 8087f6a07af3..f8da905c3a7d 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_8.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_8, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_8, FLOAT_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_9.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_9.cpp index 4a5186cf0320..ebbaef251a96 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_9.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_9, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_9, FLOAT_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_0.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_0.cpp index 34172b4b3cbe..f4d9c53e933f 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_0.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_0, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_0, FLOAT_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_1.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_1.cpp index c2f7c7e9c62c..ee6dd7ff5dd8 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_1.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_1, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_1, FLOAT_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_2.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_2.cpp index 41c1dd679bff..8bc9ab053949 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_2.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_2, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_2, FLOAT_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_3.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_3.cpp index a440852328d1..139f955b5b9e 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_3.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_3, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_3, FLOAT_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_4.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_4.cpp index d346d175b0d8..5146e675dac7 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_4.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_4, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_4, FLOAT_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_5.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_5.cpp index 86cf48ff754c..bee768cd8233 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_5.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_5, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_5, FLOAT_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_6.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_6.cpp index 92f7ac39ecf5..cc9fe5cb6687 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_6.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_6, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_6, FLOAT_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_7.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_7.cpp index eb216f89f35e..fc46966d8ee8 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_7.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_7, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_7, FLOAT_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_8.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_8.cpp index d1e9f8c96c69..59e293f9a678 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_8.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_8, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_8, FLOAT_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_9.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_9.cpp index fa00bde190de..58926d9e33d6 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_9.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_9, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_9, FLOAT_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_0.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_0.cpp index cb212b06b902..57069c7c8730 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_0.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_0, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_0, FLOAT_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_1.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_1.cpp index 4a7fdee8a4ca..7075e4fe1d40 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_1.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_1, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_1, FLOAT_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_2.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_2.cpp index aaafe1baea40..b4fe17ed8e0b 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_2.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_2, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_2, FLOAT_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_3.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_3.cpp index 9b8cf0c6a775..109d46c0cc5f 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_3.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_3, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_3, FLOAT_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_4.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_4.cpp index 4d02ffe53538..9390c388e9c4 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_4.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_4, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_4, FLOAT_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_5.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_5.cpp index 88ce3e5e2c81..44162fb56527 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_5.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_5, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_5, FLOAT_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_6.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_6.cpp index 26d4df1dd166..078ae7968af6 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_6.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_6, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_6, FLOAT_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_7.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_7.cpp index 3b04f47aab29..1027b86916f9 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_7.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_7, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_7, FLOAT_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_8.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_8.cpp index c87090229a3e..0addb0a3f44d 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_8.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_8, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_8, FLOAT_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_9.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_9.cpp index d5acb3935dc6..1e2878851ff1 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_9.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_9, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_9, FLOAT_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_0.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_0.cpp index e7e1fab61389..52e0648d8688 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_0.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_0, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_0, FLOAT_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_1.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_1.cpp index 98ccf8b357be..3fcf252de0e7 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_1.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_1, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_1, FLOAT_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_2.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_2.cpp index 6782d74ed6fa..bd21f708b73d 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_2.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_2, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_2, FLOAT_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_3.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_3.cpp index 915b0ac0e69f..e15aaa14c13c 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_3.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_3, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_3, FLOAT_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_4.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_4.cpp index d34e611815f9..ac3a138106d7 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_4.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_4, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_4, FLOAT_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_5.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_5.cpp index 89a8f164f4fd..7a9b85d9a578 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_5.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_5, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_5, FLOAT_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_6.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_6.cpp index 70e482b8bbaa..ff6490bebdc7 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_6.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_6, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_6, FLOAT_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_7.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_7.cpp index 88663cd7dbac..36270a134959 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_7.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_7, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_7, FLOAT_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_8.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_8.cpp index d5399a4d81b7..ed9c9ff64a7c 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_8.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_8, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_8, FLOAT_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_9.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_9.cpp index e27e7ab12d72..e619ccdeff32 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_9.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_9, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_9, FLOAT_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce_float_0.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce_float_0.cpp index de4619f29109..3ebf86606e73 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce_float_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce_float_0.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_0); } } diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce_float_1.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce_float_1.cpp index bfa88bc3b480..a0bc314e24cc 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce_float_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce_float_1.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_1); } } diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce_float_2.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce_float_2.cpp index 8cc2795a45a4..387516ed470a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce_float_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce_float_2.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_2); } } diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce_float_3.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce_float_3.cpp index 0b94831c3133..6194cb7e45a4 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce_float_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce_float_3.cpp @@ -23,6 +23,6 @@ namespace functions { namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_3); } } diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p0.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p0.cpp index 32f670f46cba..3e1008b57fac 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p0.cpp @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_0); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p1.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p1.cpp index 5146d70bdcf1..90e5d0a47398 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p1.cpp @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_1); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p10.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p10.cpp index 7175a8603825..4433cef9f755 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p10.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p10.cpp @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_10); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_10); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p11.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p11.cpp index a6b7bafac95c..4366a23489ba 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p11.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p11.cpp @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_11); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_11); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p12.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p12.cpp index 69cbeb7ff015..14c9b3774dfa 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p12.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p12.cpp @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_12); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_12); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p2.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p2.cpp index 1e0f25909498..29616adc10e2 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p2.cpp @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_2); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p3.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p3.cpp index e4f2c6457377..686a3f7f7852 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p3.cpp @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_3); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p4.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p4.cpp index daabf93259d0..865c7bf5ab10 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p4.cpp @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_4); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_4); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p5.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p5.cpp index cadad858ec55..4284efac914d 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p5.cpp @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_5); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_5); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p6.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p6.cpp index 7e56f65c74d0..29a13300a8b6 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p6.cpp @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_6); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_6); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p7.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p7.cpp index 85cedcecd109..b542c60ab56d 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p7.cpp @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_7); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_7); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p8.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p8.cpp index d593889b8854..79983ab1d714 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p8.cpp @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_8); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_8); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p9.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p9.cpp index 14eb788d712a..41b39bb3f2d9 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p9.cpp @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_9); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_9); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/pairwise_bool.cpp b/libnd4j/include/loops/cpu/pairwise_bool.cpp index d77413e8cb97..633af24b1165 100644 --- a/libnd4j/include/loops/cpu/pairwise_bool.cpp +++ b/libnd4j/include/loops/cpu/pairwise_bool.cpp @@ -231,6 +231,6 @@ namespace functions { } } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); } } diff --git a/libnd4j/include/loops/cpu/pairwise_int.cpp b/libnd4j/include/loops/cpu/pairwise_int.cpp index 9af092a0f3d2..af2bcb92d44a 100644 --- a/libnd4j/include/loops/cpu/pairwise_int.cpp +++ b/libnd4j/include/loops/cpu/pairwise_int.cpp @@ -231,6 +231,6 @@ namespace functions { } } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT PairWiseIntTransform, , INTEGER_TYPES); } } diff --git a/libnd4j/include/loops/cpu/random.hpp b/libnd4j/include/loops/cpu/random.hpp index 034179f07efb..b533e5453beb 100644 --- a/libnd4j/include/loops/cpu/random.hpp +++ b/libnd4j/include/loops/cpu/random.hpp @@ -281,6 +281,6 @@ namespace functions { } - //BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES); + //BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp b/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp index afb441a45e18..c7cd93b9e134 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp @@ -236,6 +236,6 @@ namespace functions { } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceBoolFunction, , LIBND4J_TYPES, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceBoolFunction, , LIBND4J_TYPES, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/reduce/reduce_long.cpp b/libnd4j/include/loops/cpu/reduce/reduce_long.cpp index 98b462ebdb1a..61eeb98daf9e 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_long.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_long.cpp @@ -259,6 +259,6 @@ namespace functions { } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceLongFunction, , LIBND4J_TYPES, LONG_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceLongFunction, , LIBND4J_TYPES, LONG_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/reduce/reduce_same.cpp b/libnd4j/include/loops/cpu/reduce/reduce_same.cpp index f357b7e64ea2..079777480bc7 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_same.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_same.cpp @@ -268,6 +268,6 @@ namespace functions { } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ReduceSameFunction, , LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ReduceSameFunction, , LIBND4J_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/scalar_bool.cpp b/libnd4j/include/loops/cpu/scalar_bool.cpp index c6f437ba8874..af55bdd40596 100644 --- a/libnd4j/include/loops/cpu/scalar_bool.cpp +++ b/libnd4j/include/loops/cpu/scalar_bool.cpp @@ -211,7 +211,7 @@ namespace functions { } } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ScalarBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ScalarBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); } } diff --git a/libnd4j/include/loops/cpu/scalar_int.cpp b/libnd4j/include/loops/cpu/scalar_int.cpp index ed85e28ef79f..bfd8679f006d 100644 --- a/libnd4j/include/loops/cpu/scalar_int.cpp +++ b/libnd4j/include/loops/cpu/scalar_int.cpp @@ -212,7 +212,7 @@ namespace functions { } } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ScalarIntTransform, , INTEGER_TYPES); } } diff --git a/libnd4j/include/loops/cpu/summarystatsreduce.cpp b/libnd4j/include/loops/cpu/summarystatsreduce.cpp index f6b44b75cb9e..df0f0b54ccf2 100644 --- a/libnd4j/include/loops/cpu/summarystatsreduce.cpp +++ b/libnd4j/include/loops/cpu/summarystatsreduce.cpp @@ -191,6 +191,6 @@ namespace functions { } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT SummaryStatsReduce, , LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT SummaryStatsReduce, , LIBND4J_TYPES, FLOAT_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/transform/transform_any.cpp b/libnd4j/include/loops/cpu/transform/transform_any.cpp index 3fc9af1b325e..1a38f809fd38 100644 --- a/libnd4j/include/loops/cpu/transform/transform_any.cpp +++ b/libnd4j/include/loops/cpu/transform/transform_any.cpp @@ -57,6 +57,6 @@ void _CUDA_H TransformAny::exec(void *vx, Nd4jLong *xShapeInfo, -BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformAny, , LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformAny, , LIBND4J_TYPES, LIBND4J_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/transform/transform_bool.cpp b/libnd4j/include/loops/cpu/transform/transform_bool.cpp index 7302ef970b9d..50bfb741cc4e 100644 --- a/libnd4j/include/loops/cpu/transform/transform_bool.cpp +++ b/libnd4j/include/loops/cpu/transform/transform_bool.cpp @@ -57,6 +57,6 @@ namespace functions { sd::TransformLoops::template loopTransform(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/transform/transform_float.cpp b/libnd4j/include/loops/cpu/transform/transform_float.cpp index 833b263f16c5..088eebb64185 100644 --- a/libnd4j/include/loops/cpu/transform/transform_float.cpp +++ b/libnd4j/include/loops/cpu/transform/transform_float.cpp @@ -56,6 +56,6 @@ namespace functions { sd::TransformLoops::template loopTransform(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/transform/transform_same.cpp b/libnd4j/include/loops/cpu/transform/transform_same.cpp index bc9d2e525ba7..42cc4609610b 100644 --- a/libnd4j/include/loops/cpu/transform/transform_same.cpp +++ b/libnd4j/include/loops/cpu/transform/transform_same.cpp @@ -55,6 +55,6 @@ namespace functions { sd::TransformLoops::template loopTransform(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformSame, , LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformSame, , LIBND4J_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/transform/transform_strict.cpp b/libnd4j/include/loops/cpu/transform/transform_strict.cpp index 2ef3b808e07a..22067cf37254 100644 --- a/libnd4j/include/loops/cpu/transform/transform_strict.cpp +++ b/libnd4j/include/loops/cpu/transform/transform_strict.cpp @@ -56,6 +56,6 @@ namespace functions { sd::TransformLoops::template loopTransform(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformStrict, , FLOAT_TYPES); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformStrict, , FLOAT_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/broadcasting.chpp b/libnd4j/include/loops/cuda/broadcasting.chpp index 848522a358a5..998ace05b517 100644 --- a/libnd4j/include/loops/cuda/broadcasting.chpp +++ b/libnd4j/include/loops/cuda/broadcasting.chpp @@ -299,16 +299,16 @@ __device__ void Broadcast::transformCuda( } /* - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_0); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_1); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_2); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_3); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_4); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_5); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_6); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_7); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_8); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_9); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_0); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_1); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_2); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_3); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_4); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_5); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_6); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_7); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_8); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_9); */ } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/broadcasting_bool.cu b/libnd4j/include/loops/cuda/broadcasting_bool.cu index 1c7bc358ef1d..2ed5b3aee0ae 100644 --- a/libnd4j/include/loops/cuda/broadcasting_bool.cu +++ b/libnd4j/include/loops/cuda/broadcasting_bool.cu @@ -315,6 +315,6 @@ __device__ void BroadcastBool::transformCuda(const void *vx, const Nd4jLong } -BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES); +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/broadcasting_int.cu b/libnd4j/include/loops/cuda/broadcasting_int.cu index 998ac9ae8eab..7fc116c2387c 100644 --- a/libnd4j/include/loops/cuda/broadcasting_int.cu +++ b/libnd4j/include/loops/cuda/broadcasting_int.cu @@ -295,6 +295,6 @@ __device__ void BroadcastInt::transformCuda(const void *vx, const Nd4jLong *x } -BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES); +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_0.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_0.cu index d7902af87642..ffc274e1b7e6 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_0.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_0.cu @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_0); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_1.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_1.cu index b24ebdb6cd51..37c4edefce73 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_1.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_1.cu @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_1); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_10.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_10.cu index 4d19a893cec4..1400b7289092 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_10.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_10.cu @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_10); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_10); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_11.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_11.cu index 8b643965bd8f..4cfd95934238 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_11.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_11.cu @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_11); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_11); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_12.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_12.cu index 935297a530ec..9600cd9f8452 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_12.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_12.cu @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_12); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_12); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_2.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_2.cu index 7d7fdc1b61a4..92112eb5b30b 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_2.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_2.cu @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_2); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_3.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_3.cu index d5c09f1147cf..b72bd706a4bf 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_3.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_3.cu @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_3); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_4.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_4.cu index f3c64a91a36c..b592b874e081 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_4.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_4.cu @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_4); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_4); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_5.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_5.cu index 5ca557a3015e..c66438f5c70d 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_5.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_5.cu @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_5); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_5); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_6.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_6.cu index 9c53e8b366dd..5381d45f439d 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_6.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_6.cu @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_6); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_6); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_7.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_7.cu index a64b6f0d3d55..d917b7c0fe6d 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_7.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_7.cu @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_7); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_7); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_8.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_8.cu index 4404fed7cece..b24f16bc6746 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_8.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_8.cu @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_8); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_8); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_9.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_9.cu index dbb560f5c724..48bc66b120d0 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_9.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_9.cu @@ -22,6 +22,6 @@ namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_9); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_9); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_0.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_0.cu index e57433ae2b96..21c6550e630c 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_0.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_0.cu @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_0); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_1.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_1.cu index 513a2c056c9b..729b612d349b 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_1.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_1.cu @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_1); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_10.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_10.cu index fac835b185d4..01b197a007c9 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_10.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_10.cu @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_10); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_10); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_11.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_11.cu index f01ef7eb3662..e552367e74d2 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_11.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_11.cu @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_11); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_11); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_12.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_12.cu index 8cf8c367f27f..6c3176ee4993 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_12.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_12.cu @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_12); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_12); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_2.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_2.cu index 8e0261d140bf..f0e43a85ff62 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_2.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_2.cu @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_2); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_3.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_3.cu index 86c23344aff6..62de38e6c3d9 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_3.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_3.cu @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_3); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_4.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_4.cu index 1ac28891f47b..e77c3934ffb1 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_4.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_4.cu @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_4); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_4); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_5.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_5.cu index 713fe344c0f4..0e312afc2b92 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_5.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_5.cu @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_5); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_5); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_6.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_6.cu index 0983be1e934f..ce15e77ca6c0 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_6.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_6.cu @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_6); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_6); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_7.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_7.cu index b12d82eac48b..1d9572fe6c08 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_7.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_7.cu @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_7); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_7); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_8.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_8.cu index fc1876f3dcdf..2df8be1d1819 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_8.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_8.cu @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_8); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_8); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_9.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_9.cu index f13c28e859f0..235d4b6aae91 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_9.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_9.cu @@ -22,6 +22,6 @@ namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_9); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_9); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_0.cu b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_0.cu index d3aeadb5f06d..5208ece7453c 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_0.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_0.cu @@ -22,6 +22,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_1.cu b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_1.cu index cfc7cb5f359a..f9f242ef584e 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_1.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_1.cu @@ -22,6 +22,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_2.cu b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_2.cu index 754ac9f52b14..4360574a31b1 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_2.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_2.cu @@ -22,6 +22,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_3.cu b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_3.cu index 340698b34a0f..ac2f43d60a32 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_3.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_3.cu @@ -22,6 +22,6 @@ namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_0.cu b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_0.cu index dd893939d6ff..f32d9f16a2fe 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_0.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_0.cu @@ -22,6 +22,6 @@ namespace functions { namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_1.cu b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_1.cu index 4d98cb61c4da..c94c45fdc489 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_1.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_1.cu @@ -22,6 +22,6 @@ namespace functions { namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_2.cu b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_2.cu index 346627563ea8..8c3b2325a90a 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_2.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_2.cu @@ -22,6 +22,6 @@ namespace functions { namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_3.cu b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_3.cu index 2852063adcf7..af745591afef 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_3.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_3.cu @@ -22,6 +22,6 @@ namespace functions { namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_0.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_0.cu index 28f754b14d57..149c85487836 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_0.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_0.cu @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_0); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_1.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_1.cu index fb54e476714e..a088f94a2bd0 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_1.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_1.cu @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_1); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_10.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_10.cu index e06cad235ce8..fded63e8c46a 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_10.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_10.cu @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_10); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_10); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_11.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_11.cu index 3c5549339605..5506d1708df4 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_11.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_11.cu @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_11); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_11); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_12.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_12.cu index 7f7f74156df1..ce44f409086a 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_12.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_12.cu @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_12); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_12); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_2.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_2.cu index af2de5b0e2df..0711ce550d1e 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_2.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_2.cu @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_2); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_3.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_3.cu index a50cee507bd8..64f803d48ede 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_3.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_3.cu @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_3); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_4.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_4.cu index 7f99764d8bcb..8806668a0cce 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_4.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_4.cu @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_4); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_4); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_5.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_5.cu index 10e93e14cf3f..51b1e6b5140c 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_5.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_5.cu @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_5); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_5); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_6.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_6.cu index a1a98cf414d6..95d0b46489fd 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_6.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_6.cu @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_6); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_6); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_7.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_7.cu index f29d26c44011..8a99df7d88c7 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_7.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_7.cu @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_7); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_7); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_8.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_8.cu index 38d275b6f53f..c1c233e9b843 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_8.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_8.cu @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_8); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_8); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_9.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_9.cu index be7c66956059..4afcc624e746 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_9.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_9.cu @@ -22,6 +22,6 @@ namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_9); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_9); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/indexreduce.cu b/libnd4j/include/loops/cuda/indexreduce.cu index 6383458c9868..71dddf6490fb 100644 --- a/libnd4j/include/loops/cuda/indexreduce.cu +++ b/libnd4j/include/loops/cuda/indexreduce.cu @@ -370,7 +370,7 @@ namespace functions { } } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES, INDEXING_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES, INDEXING_TYPES); } } diff --git a/libnd4j/include/loops/cuda/legacy/transform.legacy b/libnd4j/include/loops/cuda/legacy/transform.legacy index e7f76751aa7f..b6edc9e1ea5f 100644 --- a/libnd4j/include/loops/cuda/legacy/transform.legacy +++ b/libnd4j/include/loops/cuda/legacy/transform.legacy @@ -297,9 +297,9 @@ namespace functions { } - //template class ND4J_EXPORT Transform; - //template class ND4J_EXPORT Transform; - //template class ND4J_EXPORT Transform; + //template class SD_EXPORT Transform; + //template class SD_EXPORT Transform; + //template class SD_EXPORT Transform; BUILD_CALL_1(template __device__ void Transform::transformCuda, float, (float*, Nd4jLong*, float*, float*,Nd4jLong*, int*,float*, UnifiedSharedMemory*, Nd4jLong*, Nd4jLong*), TRANSFORM_OPS) BUILD_CALL_1(template __device__ void Transform::transformCuda, float16, (float16*, Nd4jLong*, float16*, float16*,Nd4jLong*, int*, float16*, UnifiedSharedMemory*, Nd4jLong*, Nd4jLong*), TRANSFORM_OPS) diff --git a/libnd4j/include/loops/cuda/pairwise.chpp b/libnd4j/include/loops/cuda/pairwise.chpp index d3252d862787..5d2cadb14188 100644 --- a/libnd4j/include/loops/cuda/pairwise.chpp +++ b/libnd4j/include/loops/cuda/pairwise.chpp @@ -106,16 +106,16 @@ void __host__ PairWiseTransform::executeCudaShaped(dim3& launchDims, cuda } /* - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_0); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_1); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_2); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_3); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_4); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_5); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_6); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_7); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_8); - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_9); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_0); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_1); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_2); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_3); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_4); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_5); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_6); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_7); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_8); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_9); */ } } diff --git a/libnd4j/include/loops/cuda/pairwise_bool.cu b/libnd4j/include/loops/cuda/pairwise_bool.cu index f697de814a50..096d192fefdb 100644 --- a/libnd4j/include/loops/cuda/pairwise_bool.cu +++ b/libnd4j/include/loops/cuda/pairwise_bool.cu @@ -110,7 +110,7 @@ void PairWiseBoolTransform::executeCudaShaped(dim3& launchDims, cudaStream_ DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_BOOL_OPS); } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); } } diff --git a/libnd4j/include/loops/cuda/pairwise_int.cu b/libnd4j/include/loops/cuda/pairwise_int.cu index 44447605e36d..2c408828160f 100644 --- a/libnd4j/include/loops/cuda/pairwise_int.cu +++ b/libnd4j/include/loops/cuda/pairwise_int.cu @@ -109,7 +109,7 @@ void PairWiseIntTransform::executeCudaShaped(dim3& launchDims, cudaStream_t * DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_INT_OPS); } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT PairWiseIntTransform, , INTEGER_TYPES); } } diff --git a/libnd4j/include/loops/cuda/random.cu b/libnd4j/include/loops/cuda/random.cu index c7550b926813..0d4dfdeadf6a 100644 --- a/libnd4j/include/loops/cuda/random.cu +++ b/libnd4j/include/loops/cuda/random.cu @@ -442,6 +442,6 @@ namespace functions { DEBUG_KERNEL(stream, opNum); } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES); } } diff --git a/libnd4j/include/loops/cuda/reduce/reduce_bool.cu b/libnd4j/include/loops/cuda/reduce/reduce_bool.cu index 3aa2626a2617..aa64ea990a1f 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_bool.cu +++ b/libnd4j/include/loops/cuda/reduce/reduce_bool.cu @@ -314,7 +314,7 @@ __device__ void initializeShared(X *extraParams, X **sPartials, int sMemSize) { } -BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceBoolFunction, , LIBND4J_TYPES, BOOL_TYPES); +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceBoolFunction, , LIBND4J_TYPES, BOOL_TYPES); } } diff --git a/libnd4j/include/loops/cuda/reduce/reduce_float.chpp b/libnd4j/include/loops/cuda/reduce/reduce_float.chpp index e1b95ae55e57..4f3183d018f0 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_float.chpp +++ b/libnd4j/include/loops/cuda/reduce/reduce_float.chpp @@ -304,7 +304,7 @@ __device__ void initializeShared(X *extraParams, X **sPartials, int sMemSize) { } -//BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES); +//BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES); } } diff --git a/libnd4j/include/loops/cuda/reduce/reduce_long.cu b/libnd4j/include/loops/cuda/reduce/reduce_long.cu index e55ecd11c66f..db1bf5dcb131 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_long.cu +++ b/libnd4j/include/loops/cuda/reduce/reduce_long.cu @@ -325,7 +325,7 @@ __device__ void initializeShared(X *extraParams, X **sPartials, int sMemSize) { } -BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceLongFunction, , LIBND4J_TYPES, LONG_TYPES); +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceLongFunction, , LIBND4J_TYPES, LONG_TYPES); } } diff --git a/libnd4j/include/loops/cuda/reduce/reduce_same.cu b/libnd4j/include/loops/cuda/reduce/reduce_same.cu index c3c74c8065b6..2b5076314b2b 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_same.cu +++ b/libnd4j/include/loops/cuda/reduce/reduce_same.cu @@ -322,7 +322,7 @@ __device__ void initializeShared(X *extraParams, X **sPartials, int sMemSize) { } -BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ReduceSameFunction, , LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ReduceSameFunction, , LIBND4J_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/reduce3.chpp b/libnd4j/include/loops/cuda/reduce3.chpp index 2fa16e9aca7c..71770c0d63ec 100644 --- a/libnd4j/include/loops/cuda/reduce3.chpp +++ b/libnd4j/include/loops/cuda/reduce3.chpp @@ -556,7 +556,7 @@ __host__ void Reduce3::execScalar(dim3 launchDims, cudaStream_t *stream, - //BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES); + //BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/scalar_bool.cu b/libnd4j/include/loops/cuda/scalar_bool.cu index 1c8929ef3bfc..1858b56ddcdc 100644 --- a/libnd4j/include/loops/cuda/scalar_bool.cu +++ b/libnd4j/include/loops/cuda/scalar_bool.cu @@ -230,7 +230,7 @@ void ScalarBoolTransform::executeCudaAlongDimension(dim3& launchDims, cudaS DISPATCH_BY_OPNUM_TT(intermediateAlongDimension, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), SCALAR_BOOL_OPS); } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ScalarBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ScalarBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); } } diff --git a/libnd4j/include/loops/cuda/scalar_int.cu b/libnd4j/include/loops/cuda/scalar_int.cu index bb761c76c311..00c64eb8e365 100644 --- a/libnd4j/include/loops/cuda/scalar_int.cu +++ b/libnd4j/include/loops/cuda/scalar_int.cu @@ -228,7 +228,7 @@ void ScalarIntTransform::executeCudaAlongDimension(dim3& launchDims, cudaStre DISPATCH_BY_OPNUM_T(intermediateAlongDimension, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), SCALAR_INT_OPS); } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ScalarIntTransform, , INTEGER_TYPES); } } diff --git a/libnd4j/include/loops/cuda/specials/accumulateKernel.cu b/libnd4j/include/loops/cuda/specials/accumulateKernel.cu index 6d6dd42a4f10..b07827f2638f 100644 --- a/libnd4j/include/loops/cuda/specials/accumulateKernel.cu +++ b/libnd4j/include/loops/cuda/specials/accumulateKernel.cu @@ -86,5 +86,5 @@ namespace sd { sd::DebugHelper::checkErrorCode(stream, "accumulate(...) failed"); } - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT accumulateKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * *vx, void * vz, int n, const Nd4jLong length), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void SD_EXPORT accumulateKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * *vx, void * vz, int n, const Nd4jLong length), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/averagingKernel.cu b/libnd4j/include/loops/cuda/specials/averagingKernel.cu index 798b273cfdc4..fe8798d74880 100644 --- a/libnd4j/include/loops/cuda/specials/averagingKernel.cu +++ b/libnd4j/include/loops/cuda/specials/averagingKernel.cu @@ -100,5 +100,5 @@ namespace sd { sd::DebugHelper::checkErrorCode(stream, "averaging(...) failed"); } - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT averagingKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * *vdx, void * vdz, int n, Nd4jLong length, bool propagate), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void SD_EXPORT averagingKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * *vdx, void * vdz, int n, Nd4jLong length, bool propagate), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu b/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu index 13ad1d5b46cf..89ce8f12d6ab 100644 --- a/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu +++ b/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu @@ -186,5 +186,5 @@ __host__ void bitonicArbitraryStepGenericKey(dim3 &launchDims, cudaStream_t *str bitonicArbitraryStepKernelKey<<>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending); } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES); -BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT bitonicArbitraryStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template void SD_EXPORT bitonicArbitraryStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); diff --git a/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu b/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu index 6bd1e8a33daa..a4163d6bfe8d 100644 --- a/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu +++ b/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu @@ -136,5 +136,5 @@ __host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream, } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES); -BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT bitonicSortStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template void SD_EXPORT bitonicSortStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); diff --git a/libnd4j/include/loops/cuda/specials/concatKernel.cu b/libnd4j/include/loops/cuda/specials/concatKernel.cu index a4a849e49ec2..0d1c8485458d 100644 --- a/libnd4j/include/loops/cuda/specials/concatKernel.cu +++ b/libnd4j/include/loops/cuda/specials/concatKernel.cu @@ -266,5 +266,5 @@ namespace sd { sd::DebugHelper::checkErrorCode(stream, "concatGenericLegacy(...) failed"); } - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT concatKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, Nd4jPointer * inputShapeInfos, void * vz, Nd4jLong *zShapeInfo, Nd4jPointer * tadPointers, Nd4jPointer * offsetPointers, Nd4jLong * zTadShape, Nd4jLong * zOffsets), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, Nd4jPointer * inputShapeInfos, void * vz, Nd4jLong *zShapeInfo, Nd4jPointer * tadPointers, Nd4jPointer * offsetPointers, Nd4jLong * zTadShape, Nd4jLong * zOffsets), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/concatKernelHStack.cu b/libnd4j/include/loops/cuda/specials/concatKernelHStack.cu index 8ef9dfd24838..39a42e99ed93 100644 --- a/libnd4j/include/loops/cuda/specials/concatKernelHStack.cu +++ b/libnd4j/include/loops/cuda/specials/concatKernelHStack.cu @@ -91,5 +91,5 @@ namespace sd { sd::DebugHelper::checkErrorCode(stream, "concatHStack(...) failed"); } - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT concatKernelHStackGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, Nd4jPointer * inputShapeInfos, void * vz, Nd4jLong * zShapeInfo), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelHStackGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, Nd4jPointer * inputShapeInfos, void * vz, Nd4jLong * zShapeInfo), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/concatKernelScalar.cu b/libnd4j/include/loops/cuda/specials/concatKernelScalar.cu index 6614480f25ec..a949c12f4f2f 100644 --- a/libnd4j/include/loops/cuda/specials/concatKernelScalar.cu +++ b/libnd4j/include/loops/cuda/specials/concatKernelScalar.cu @@ -51,5 +51,5 @@ namespace sd { sd::DebugHelper::checkErrorCode(stream, "concatScalar(...) failed"); } - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT concatKernelScalarGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, void * vz), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelScalarGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, void * vz), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/concatKernelVStack.cu b/libnd4j/include/loops/cuda/specials/concatKernelVStack.cu index f95bad413dd1..cd9b7ca80b3b 100644 --- a/libnd4j/include/loops/cuda/specials/concatKernelVStack.cu +++ b/libnd4j/include/loops/cuda/specials/concatKernelVStack.cu @@ -81,5 +81,5 @@ namespace sd { sd::DebugHelper::checkErrorCode(stream, "concatVStack(...) failed"); } - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT concatKernelVStackGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, Nd4jPointer * inputShapeInfos, void * vz, Nd4jLong *zShapeInfo), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelVStackGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, Nd4jPointer * inputShapeInfos, void * vz, Nd4jLong *zShapeInfo), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/convertHalfs.cu b/libnd4j/include/loops/cuda/specials/convertHalfs.cu index dec1705a42d3..c63b78552fc9 100644 --- a/libnd4j/include/loops/cuda/specials/convertHalfs.cu +++ b/libnd4j/include/loops/cuda/specials/convertHalfs.cu @@ -42,5 +42,5 @@ namespace sd { sd::DebugHelper::checkErrorCode(stream, "convertHalfsToGeneric(...) failed"); } - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT convertHalfsToGeneric, (dim3 & launchDims, cudaStream_t * stream, half * dx, Nd4jLong n, void * dz), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void SD_EXPORT convertHalfsToGeneric, (dim3 & launchDims, cudaStream_t * stream, half * dx, Nd4jLong n, void * dz), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/convertToHalf.cu b/libnd4j/include/loops/cuda/specials/convertToHalf.cu index d86982d03585..ad82f5a8fb35 100644 --- a/libnd4j/include/loops/cuda/specials/convertToHalf.cu +++ b/libnd4j/include/loops/cuda/specials/convertToHalf.cu @@ -40,6 +40,6 @@ namespace sd { sd::DebugHelper::checkErrorCode(stream, "convertToHalfs(...) failed"); } - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT convertToHalfGeneric, (dim3 & launchDims, cudaStream_t * stream, void * dx, Nd4jLong n, half * dz), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void SD_EXPORT convertToHalfGeneric, (dim3 & launchDims, cudaStream_t * stream, void * dx, Nd4jLong n, half * dz), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu b/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu index 813de162da3a..e8cc61699bea 100644 --- a/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu +++ b/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu @@ -90,5 +90,5 @@ namespace sd { execfillDimensionalIsMax<<>>(dX, dZ, zShapeInfo, tadOnlyShapeInfo, dimension, dimensionLength, tadOffsets); sd::DebugHelper::checkErrorCode(stream, "fillDimensionalIsMax(...) failed"); } - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT fillDimensionalIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, void *dX, void *dZ, Nd4jLong *zShapeInfo, Nd4jLong *tadOnlyShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOffsets), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void SD_EXPORT fillDimensionalIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, void *dX, void *dZ, Nd4jLong *zShapeInfo, Nd4jLong *tadOnlyShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOffsets), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/fillIsMax.cu b/libnd4j/include/loops/cuda/specials/fillIsMax.cu index 1a994a13c28c..4b677e4be780 100644 --- a/libnd4j/include/loops/cuda/specials/fillIsMax.cu +++ b/libnd4j/include/loops/cuda/specials/fillIsMax.cu @@ -41,5 +41,5 @@ namespace sd { } - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT fillIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, void* dz, Nd4jLong *zShapeInfo, Nd4jLong length, long idx), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void SD_EXPORT fillIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, void* dz, Nd4jLong *zShapeInfo, Nd4jLong length, long idx), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/flatten.cu b/libnd4j/include/loops/cuda/specials/flatten.cu index b0bbf58e12cb..ef86ac596c11 100644 --- a/libnd4j/include/loops/cuda/specials/flatten.cu +++ b/libnd4j/include/loops/cuda/specials/flatten.cu @@ -65,7 +65,7 @@ __host__ void flattenKernelGeneric(dim3& launchDims, cudaStream_t *stream, sd::DebugHelper::checkErrorCode(stream, "flattenGeneric(...) failed"); } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT flattenKernelGeneric, (dim3& launchDims, cudaStream_t *stream, Nd4jPointer *extraPointers, int dOffset, char order, void *vz, Nd4jLong *zShapeInfo, void *vy, Nd4jLong *yShapeInfo), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT flattenKernelGeneric, (dim3& launchDims, cudaStream_t *stream, Nd4jPointer *extraPointers, int dOffset, char order, void *vz, Nd4jLong *zShapeInfo, void *vy, Nd4jLong *yShapeInfo), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/oesTad.cu b/libnd4j/include/loops/cuda/specials/oesTad.cu index 9f41ffbb90d4..eff7d0928e1c 100644 --- a/libnd4j/include/loops/cuda/specials/oesTad.cu +++ b/libnd4j/include/loops/cuda/specials/oesTad.cu @@ -201,5 +201,5 @@ __host__ void oesTadGenericKey(dim3 &launchDims, cudaStream_t *stream, execOesTadKernelKey<<>>(vx, xShapeInfo, vy, yShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending); } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT oesTadGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending), LIBND4J_TYPES); -BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT oesTadGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT oesTadGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending), LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template void SD_EXPORT oesTadGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); diff --git a/libnd4j/include/loops/cuda/specials/pullRowsKernel.cu b/libnd4j/include/loops/cuda/specials/pullRowsKernel.cu index 7ef6a46dbc12..2f233f0d3940 100644 --- a/libnd4j/include/loops/cuda/specials/pullRowsKernel.cu +++ b/libnd4j/include/loops/cuda/specials/pullRowsKernel.cu @@ -87,6 +87,6 @@ namespace sd { sd::DebugHelper::checkErrorCode(stream, "pullRows(...) failed"); } - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT pullRowsKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * vx, void * vz, Nd4jLong len, Nd4jLong * indexes, Nd4jLong * tadShapeInfo, Nd4jLong * tadOffsets, Nd4jLong *zTadShapeInfo, Nd4jLong * zTadOffsets), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void SD_EXPORT pullRowsKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * vx, void * vz, Nd4jLong len, Nd4jLong * indexes, Nd4jLong * tadShapeInfo, Nd4jLong * tadOffsets, Nd4jLong *zTadShapeInfo, Nd4jLong * zTadOffsets), LIBND4J_TYPES); } diff --git a/libnd4j/include/loops/cuda/specials/shuffleKernel.cu b/libnd4j/include/loops/cuda/specials/shuffleKernel.cu index db63c2af728c..643b0fd12a46 100644 --- a/libnd4j/include/loops/cuda/specials/shuffleKernel.cu +++ b/libnd4j/include/loops/cuda/specials/shuffleKernel.cu @@ -121,5 +121,5 @@ namespace sd { sd::DebugHelper::checkErrorCode(stream, "shuffleGeneric(...) failed"); } - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT shuffleKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * *vdX, Nd4jLong * *xShapeInfo, void **vdZ, int N, int * shuffleMap, Nd4jLong * *tadOnlyShapeInfo, Nd4jLong * *tadOffsets), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void SD_EXPORT shuffleKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * *vdX, Nd4jLong * *xShapeInfo, void **vdZ, int N, int * shuffleMap, Nd4jLong * *tadOnlyShapeInfo, Nd4jLong * *tadOffsets), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/tearKernel.cu b/libnd4j/include/loops/cuda/specials/tearKernel.cu index a6285b5a5a01..8f6007220804 100644 --- a/libnd4j/include/loops/cuda/specials/tearKernel.cu +++ b/libnd4j/include/loops/cuda/specials/tearKernel.cu @@ -91,5 +91,5 @@ namespace sd { sd::DebugHelper::checkErrorCode(stream, "tear(...) failed"); } - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT tearKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * vx, Nd4jLong * xShapeInfo, Nd4jPointer *targets, Nd4jLong * zShapeInfo, Nd4jLong * tadShapeInfo, Nd4jLong * tadOffsets), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void SD_EXPORT tearKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * vx, Nd4jLong * xShapeInfo, Nd4jPointer *targets, Nd4jLong * zShapeInfo, Nd4jLong * tadShapeInfo, Nd4jLong * tadOffsets), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/summarystatsreduce.cu b/libnd4j/include/loops/cuda/summarystatsreduce.cu index c858d80986a4..e3c55779b419 100644 --- a/libnd4j/include/loops/cuda/summarystatsreduce.cu +++ b/libnd4j/include/loops/cuda/summarystatsreduce.cu @@ -414,6 +414,6 @@ void _CUDA_G summaryStatsReduceT(int op, void *dx, Nd4jLong *xShapeInfo, int xRa } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT SummaryStatsReduce, , LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT SummaryStatsReduce, , LIBND4J_TYPES, FLOAT_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/transform/transform_any.cu b/libnd4j/include/loops/cuda/transform/transform_any.cu index d13b9459986a..a34a29f2df53 100644 --- a/libnd4j/include/loops/cuda/transform/transform_any.cu +++ b/libnd4j/include/loops/cuda/transform/transform_any.cu @@ -114,6 +114,6 @@ namespace functions { sd::DebugHelper::checkErrorCode(stream, "transformAny(...) failed"); } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformAny, , LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformAny, , LIBND4J_TYPES, LIBND4J_TYPES); } } diff --git a/libnd4j/include/loops/cuda/transform/transform_bool.cu b/libnd4j/include/loops/cuda/transform/transform_bool.cu index fec14a745a3a..39de73090179 100644 --- a/libnd4j/include/loops/cuda/transform/transform_bool.cu +++ b/libnd4j/include/loops/cuda/transform/transform_bool.cu @@ -120,6 +120,6 @@ namespace functions { sd::DebugHelper::checkErrorCode(stream, "transformBool(...) failed"); } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES); } } diff --git a/libnd4j/include/loops/cuda/transform/transform_float.cu b/libnd4j/include/loops/cuda/transform/transform_float.cu index f631fd4d71c2..4f0fe7ddf49e 100644 --- a/libnd4j/include/loops/cuda/transform/transform_float.cu +++ b/libnd4j/include/loops/cuda/transform/transform_float.cu @@ -142,6 +142,6 @@ namespace functions { sd::DebugHelper::checkErrorCode(stream, "transformFloat(...) failed"); } - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES); } } diff --git a/libnd4j/include/loops/cuda/transform/transform_same.cu b/libnd4j/include/loops/cuda/transform/transform_same.cu index 368a9b602b5c..66614c91631d 100644 --- a/libnd4j/include/loops/cuda/transform/transform_same.cu +++ b/libnd4j/include/loops/cuda/transform/transform_same.cu @@ -118,6 +118,6 @@ namespace functions { sd::DebugHelper::checkErrorCode(stream, "transformSame(...) failed"); } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformSame, , LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformSame, , LIBND4J_TYPES); } } diff --git a/libnd4j/include/loops/cuda/transform/transform_strict.cu b/libnd4j/include/loops/cuda/transform/transform_strict.cu index 155e5aa2308b..5fed51e2df86 100644 --- a/libnd4j/include/loops/cuda/transform/transform_strict.cu +++ b/libnd4j/include/loops/cuda/transform/transform_strict.cu @@ -119,6 +119,6 @@ namespace functions { sd::DebugHelper::checkErrorCode(stream, "transformStrict(...) failed"); } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformStrict, , FLOAT_TYPES); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformStrict, , FLOAT_TYPES); } } diff --git a/libnd4j/include/loops/cuda/type_conversions.cu b/libnd4j/include/loops/cuda/type_conversions.cu index 8c38561f4d1b..3f0bbe007c5f 100644 --- a/libnd4j/include/loops/cuda/type_conversions.cu +++ b/libnd4j/include/loops/cuda/type_conversions.cu @@ -230,7 +230,7 @@ __host__ void encoderKernelP1Generic(dim3 &launchDims, cudaStream_t *stream, voi execEncoderKernelP1<<>>(dx, N, dz, threshold); sd::DebugHelper::checkErrorCode(stream, "encoderP1(...) failed"); } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT encoderKernelP1Generic, (dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong N, void *dz, float threshold), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT encoderKernelP1Generic, (dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong N, void *dz, float threshold), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -245,7 +245,7 @@ __host__ void encoderKernelP3Generic(dim3 &launchDims, cudaStream_t *stream, voi execEncoderKernelP3<<>>(dx, offsets, N, dz); sd::DebugHelper::checkErrorCode(stream, "encoderP3(...) failed"); } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT encoderKernelP3Generic, (dim3 &launchDims, cudaStream_t *stream, void *dx, int *offsets, Nd4jLong N, void *dz), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT encoderKernelP3Generic, (dim3 &launchDims, cudaStream_t *stream, void *dx, int *offsets, Nd4jLong N, void *dz), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -261,7 +261,7 @@ __host__ void decoderKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void execDecoderKernel<<>>(dx, N, dz); sd::DebugHelper::checkErrorCode(stream, "execDecoder(...) failed"); } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT decoderKernelGeneric, (dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong N, void *dz), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT decoderKernelGeneric, (dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong N, void *dz), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// @@ -278,7 +278,7 @@ __host__ void cudaEncodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, vo execCudaEncodeBitmapKernel<<>>(vdx, N, dz, scalar, reductionBuffer, threshold); sd::DebugHelper::checkErrorCode(stream, "encodeBitmap(...) failed"); } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT cudaEncodeBitmapGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vdx, Nd4jLong N, int *dz, int *scalar, int *reductionBuffer, float threshold), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT cudaEncodeBitmapGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vdx, Nd4jLong N, int *dz, int *scalar, int *reductionBuffer, float threshold), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// @@ -295,7 +295,7 @@ __host__ void cudaDecodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, vo execCudaDecodeBitmapKernel<<>>(dx, N, vdz); sd::DebugHelper::checkErrorCode(stream, "cudeDecodeBitmap(...) failed"); } -BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT cudaDecodeBitmapGeneric, (dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong N, void *vdz), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT cudaDecodeBitmapGeneric, (dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong N, void *vdz), LIBND4J_TYPES); template diff --git a/libnd4j/include/loops/transform_any.h b/libnd4j/include/loops/transform_any.h index 22d56a4d397f..00af17433c70 100644 --- a/libnd4j/include/loops/transform_any.h +++ b/libnd4j/include/loops/transform_any.h @@ -68,7 +68,7 @@ class TransformAny { static void exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, uint64_t threadId, uint64_t numThreads); template - static ND4J_EXPORT void exec(void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, uint64_t threadId, uint64_t numThreads); + static SD_EXPORT void exec(void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, uint64_t threadId, uint64_t numThreads); #endif }; diff --git a/libnd4j/include/loops/transform_bool.h b/libnd4j/include/loops/transform_bool.h index 56a7f8f7e41d..70dee5c0b5ff 100644 --- a/libnd4j/include/loops/transform_bool.h +++ b/libnd4j/include/loops/transform_bool.h @@ -77,7 +77,7 @@ namespace functions { static void exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, uint64_t threadId, uint64_t numThreads); template - static ND4J_EXPORT void exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, uint64_t threadId, uint64_t numThreads); + static SD_EXPORT void exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, uint64_t threadId, uint64_t numThreads); #endif }; } diff --git a/libnd4j/include/loops/transform_float.h b/libnd4j/include/loops/transform_float.h index 1d9b6fb71aec..eeb8fd1a996a 100644 --- a/libnd4j/include/loops/transform_float.h +++ b/libnd4j/include/loops/transform_float.h @@ -101,7 +101,7 @@ namespace functions { static void exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, uint64_t threadId, uint64_t numThreads); template - static ND4J_EXPORT void exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, uint64_t threadId, uint64_t numThreads); + static SD_EXPORT void exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, uint64_t threadId, uint64_t numThreads); #endif }; } diff --git a/libnd4j/include/loops/transform_same.h b/libnd4j/include/loops/transform_same.h index cb36ba8726ad..d6a021a608c9 100644 --- a/libnd4j/include/loops/transform_same.h +++ b/libnd4j/include/loops/transform_same.h @@ -79,7 +79,7 @@ namespace functions { static void exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, uint64_t threadId, uint64_t numThreads); template - static ND4J_EXPORT void exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, uint64_t threadId, uint64_t numThreads); + static SD_EXPORT void exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, uint64_t threadId, uint64_t numThreads); #endif }; } diff --git a/libnd4j/include/loops/transform_strict.h b/libnd4j/include/loops/transform_strict.h index b7ba63e46e82..3da05c6fcd4a 100644 --- a/libnd4j/include/loops/transform_strict.h +++ b/libnd4j/include/loops/transform_strict.h @@ -81,7 +81,7 @@ namespace functions { static void exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, uint64_t threadId, uint64_t numThreads); template - static ND4J_EXPORT void exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, uint64_t threadId, uint64_t numThreads); + static SD_EXPORT void exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, uint64_t threadId, uint64_t numThreads); #endif }; diff --git a/libnd4j/include/memory/ExternalWorkspace.h b/libnd4j/include/memory/ExternalWorkspace.h index 772afc6082ae..f557f1c484b9 100644 --- a/libnd4j/include/memory/ExternalWorkspace.h +++ b/libnd4j/include/memory/ExternalWorkspace.h @@ -26,7 +26,7 @@ namespace sd { namespace memory { - class ND4J_EXPORT ExternalWorkspace { + class SD_EXPORT ExternalWorkspace { private: void *_ptrH = nullptr; void *_ptrD = nullptr; diff --git a/libnd4j/include/memory/HotZoneManager.h b/libnd4j/include/memory/HotZoneManager.h index e33fb236905b..bdeb05bcad07 100644 --- a/libnd4j/include/memory/HotZoneManager.h +++ b/libnd4j/include/memory/HotZoneManager.h @@ -26,7 +26,7 @@ namespace sd { namespace memory { - class ND4J_EXPORT HotZoneManager : public ZoneManager { + class SD_EXPORT HotZoneManager : public ZoneManager { protected: std::atomic _used = {0}; std::atomic _available = {0}; diff --git a/libnd4j/include/memory/MemoryCounter.h b/libnd4j/include/memory/MemoryCounter.h index 91aaeecff9bb..909eb7819c9d 100644 --- a/libnd4j/include/memory/MemoryCounter.h +++ b/libnd4j/include/memory/MemoryCounter.h @@ -32,7 +32,7 @@ namespace sd { /** * This class provides simple per-device counter */ - class ND4J_EXPORT MemoryCounter { + class SD_EXPORT MemoryCounter { private: static MemoryCounter* _INSTANCE; diff --git a/libnd4j/include/memory/MemoryDescriptor.h b/libnd4j/include/memory/MemoryDescriptor.h index 799df092ec4c..fb4ced9a3408 100644 --- a/libnd4j/include/memory/MemoryDescriptor.h +++ b/libnd4j/include/memory/MemoryDescriptor.h @@ -27,7 +27,7 @@ namespace sd { namespace memory { - class ND4J_EXPORT MemoryDescriptor { + class SD_EXPORT MemoryDescriptor { private: void* _ptr; MemoryZone _zone; diff --git a/libnd4j/include/memory/MemoryRegistrator.h b/libnd4j/include/memory/MemoryRegistrator.h index ad1b0333aa32..a12d960b29ba 100644 --- a/libnd4j/include/memory/MemoryRegistrator.h +++ b/libnd4j/include/memory/MemoryRegistrator.h @@ -30,7 +30,7 @@ namespace sd { namespace memory { - class ND4J_EXPORT MemoryRegistrator { + class SD_EXPORT MemoryRegistrator { protected: static MemoryRegistrator* _INSTANCE; Workspace* _workspace; diff --git a/libnd4j/include/memory/MemoryReport.h b/libnd4j/include/memory/MemoryReport.h index 647886ab54dd..40c87e188346 100644 --- a/libnd4j/include/memory/MemoryReport.h +++ b/libnd4j/include/memory/MemoryReport.h @@ -26,7 +26,7 @@ namespace sd { namespace memory { - class ND4J_EXPORT MemoryReport { + class SD_EXPORT MemoryReport { private: Nd4jLong _vm = 0; Nd4jLong _rss = 0; diff --git a/libnd4j/include/memory/MemoryTracker.h b/libnd4j/include/memory/MemoryTracker.h index 60e347f03580..36a54f0c5208 100644 --- a/libnd4j/include/memory/MemoryTracker.h +++ b/libnd4j/include/memory/MemoryTracker.h @@ -33,7 +33,7 @@ namespace sd { /** * This class is used for tracking memory allocation wrt their allocation points in code */ - class ND4J_EXPORT MemoryTracker { + class SD_EXPORT MemoryTracker { private: static MemoryTracker* _INSTANCE; std::map _allocations; diff --git a/libnd4j/include/memory/MemoryUtils.h b/libnd4j/include/memory/MemoryUtils.h index 027008238535..c53ffa76752d 100644 --- a/libnd4j/include/memory/MemoryUtils.h +++ b/libnd4j/include/memory/MemoryUtils.h @@ -26,7 +26,7 @@ namespace sd { namespace memory { - class ND4J_EXPORT MemoryUtils { + class SD_EXPORT MemoryUtils { public: static bool retrieveMemoryStatistics(MemoryReport& report); }; diff --git a/libnd4j/include/memory/WarmZoneManager.h b/libnd4j/include/memory/WarmZoneManager.h index 00c604ec1b77..7d830736abfa 100644 --- a/libnd4j/include/memory/WarmZoneManager.h +++ b/libnd4j/include/memory/WarmZoneManager.h @@ -25,7 +25,7 @@ namespace sd { namespace memory { - class ND4J_EXPORT WarmZoneManager : public ZoneManager { + class SD_EXPORT WarmZoneManager : public ZoneManager { protected: public: WarmZoneManager() = default; diff --git a/libnd4j/include/memory/Workspace.h b/libnd4j/include/memory/Workspace.h index c97f6a178978..0f2bbfd94c8b 100644 --- a/libnd4j/include/memory/Workspace.h +++ b/libnd4j/include/memory/Workspace.h @@ -36,7 +36,7 @@ namespace sd { namespace memory { - class ND4J_EXPORT Workspace { + class SD_EXPORT Workspace { protected: char* _ptrHost = nullptr; char* _ptrDevice = nullptr; diff --git a/libnd4j/include/memory/ZoneManager.h b/libnd4j/include/memory/ZoneManager.h index c30095597b0a..e097e8e5952d 100644 --- a/libnd4j/include/memory/ZoneManager.h +++ b/libnd4j/include/memory/ZoneManager.h @@ -32,7 +32,7 @@ namespace sd { /** * Abstract class that defines common methods for zone managers */ - class ND4J_EXPORT ZoneManager { + class SD_EXPORT ZoneManager { protected: std::mutex _lock; diff --git a/libnd4j/include/ops/BroadcastBoolOpsTuple.h b/libnd4j/include/ops/BroadcastBoolOpsTuple.h index 69c633eeda87..d663a6c6cd4f 100644 --- a/libnd4j/include/ops/BroadcastBoolOpsTuple.h +++ b/libnd4j/include/ops/BroadcastBoolOpsTuple.h @@ -25,7 +25,7 @@ #include namespace sd { - class ND4J_EXPORT BroadcastBoolOpsTuple { + class SD_EXPORT BroadcastBoolOpsTuple { private: public: diff --git a/libnd4j/include/ops/BroadcastIntOpsTuple.h b/libnd4j/include/ops/BroadcastIntOpsTuple.h index 64e6c407079f..571e870adb9b 100644 --- a/libnd4j/include/ops/BroadcastIntOpsTuple.h +++ b/libnd4j/include/ops/BroadcastIntOpsTuple.h @@ -25,7 +25,7 @@ #include namespace sd { - class ND4J_EXPORT BroadcastIntOpsTuple { + class SD_EXPORT BroadcastIntOpsTuple { private: public: diff --git a/libnd4j/include/ops/BroadcastOpsTuple.h b/libnd4j/include/ops/BroadcastOpsTuple.h index c47f5462efcd..2b55535198f3 100644 --- a/libnd4j/include/ops/BroadcastOpsTuple.h +++ b/libnd4j/include/ops/BroadcastOpsTuple.h @@ -25,7 +25,7 @@ #include namespace sd { - class ND4J_EXPORT BroadcastOpsTuple { + class SD_EXPORT BroadcastOpsTuple { private: public: diff --git a/libnd4j/include/ops/declarable/BooleanOp.h b/libnd4j/include/ops/declarable/BooleanOp.h index b04ca8ecab54..4e22ce2d4403 100644 --- a/libnd4j/include/ops/declarable/BooleanOp.h +++ b/libnd4j/include/ops/declarable/BooleanOp.h @@ -27,7 +27,7 @@ namespace sd { namespace ops { - class ND4J_EXPORT BooleanOp : public DeclarableOp { + class SD_EXPORT BooleanOp : public DeclarableOp { protected: OpDescriptor * _descriptor; diff --git a/libnd4j/include/ops/declarable/BroadcastableOp.h b/libnd4j/include/ops/declarable/BroadcastableOp.h index 9bc7561283e8..3bb1d55f4aa1 100644 --- a/libnd4j/include/ops/declarable/BroadcastableOp.h +++ b/libnd4j/include/ops/declarable/BroadcastableOp.h @@ -28,7 +28,7 @@ namespace sd { namespace ops { - class ND4J_EXPORT BroadcastableOp : public DeclarableCustomOp{ + class SD_EXPORT BroadcastableOp : public DeclarableCustomOp{ protected: Nd4jStatus validateAndExecute(Context& block) override = 0; public: diff --git a/libnd4j/include/ops/declarable/CustomOperations.h b/libnd4j/include/ops/declarable/CustomOperations.h index f98deb784197..5e63a3a8c3ca 100644 --- a/libnd4j/include/ops/declarable/CustomOperations.h +++ b/libnd4j/include/ops/declarable/CustomOperations.h @@ -59,7 +59,7 @@ namespace sd { - struct ND4J_EXPORT _loader { + struct SD_EXPORT _loader { _loader(); }; diff --git a/libnd4j/include/ops/declarable/DeclarableCustomOp.h b/libnd4j/include/ops/declarable/DeclarableCustomOp.h index 4aa133a4bba8..080ed6648637 100644 --- a/libnd4j/include/ops/declarable/DeclarableCustomOp.h +++ b/libnd4j/include/ops/declarable/DeclarableCustomOp.h @@ -25,7 +25,7 @@ namespace sd { namespace ops { - class ND4J_EXPORT DeclarableCustomOp : public sd::ops::DeclarableOp { + class SD_EXPORT DeclarableCustomOp : public sd::ops::DeclarableOp { protected: /** * This method executes this Op diff --git a/libnd4j/include/ops/declarable/DeclarableListOp.h b/libnd4j/include/ops/declarable/DeclarableListOp.h index cc77ee17b271..6cf3589758b6 100644 --- a/libnd4j/include/ops/declarable/DeclarableListOp.h +++ b/libnd4j/include/ops/declarable/DeclarableListOp.h @@ -30,7 +30,7 @@ using namespace sd::graph; namespace sd { namespace ops { - class ND4J_EXPORT DeclarableListOp : public sd::ops::DeclarableOp { + class SD_EXPORT DeclarableListOp : public sd::ops::DeclarableOp { protected: Nd4jStatus validateAndExecute(Context& block) override = 0; diff --git a/libnd4j/include/ops/declarable/DeclarableOp.h b/libnd4j/include/ops/declarable/DeclarableOp.h index 3cce3b8e4fd8..cbdd500fcf66 100644 --- a/libnd4j/include/ops/declarable/DeclarableOp.h +++ b/libnd4j/include/ops/declarable/DeclarableOp.h @@ -44,7 +44,7 @@ using namespace sd::graph; namespace sd { namespace ops { - Nd4jStatus ND4J_EXPORT conditionHelper(const char *file, int line, int condition, int argNumber, const char *format, ...); + Nd4jStatus SD_EXPORT conditionHelper(const char *file, int line, int condition, int argNumber, const char *format, ...); template @@ -64,7 +64,7 @@ namespace sd { * This class is the basic building block of Graph Operations. Any CustomOp out there is built on top of this "abstract" class. * */ - class ND4J_EXPORT DeclarableOp { + class SD_EXPORT DeclarableOp { private: std::mutex _registrator; bool _registered = false; diff --git a/libnd4j/include/ops/declarable/DeclarableReductionOp.h b/libnd4j/include/ops/declarable/DeclarableReductionOp.h index 11f4ec410b94..8af574a2ef62 100644 --- a/libnd4j/include/ops/declarable/DeclarableReductionOp.h +++ b/libnd4j/include/ops/declarable/DeclarableReductionOp.h @@ -25,7 +25,7 @@ namespace sd { namespace ops { - class ND4J_EXPORT DeclarableReductionOp : public sd::ops::DeclarableOp { + class SD_EXPORT DeclarableReductionOp : public sd::ops::DeclarableOp { protected: /** * This method executes this Op diff --git a/libnd4j/include/ops/declarable/LegacyBroadcastBoolOp.h b/libnd4j/include/ops/declarable/LegacyBroadcastBoolOp.h index 67787ca4b1ad..3e952bd4b6fc 100644 --- a/libnd4j/include/ops/declarable/LegacyBroadcastBoolOp.h +++ b/libnd4j/include/ops/declarable/LegacyBroadcastBoolOp.h @@ -28,7 +28,7 @@ namespace sd { /** * This class provides wrapper for broadcast operations. */ - class ND4J_EXPORT LegacyBroadcastBoolOp : public LegacyOp { + class SD_EXPORT LegacyBroadcastBoolOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context& block) override ; public: diff --git a/libnd4j/include/ops/declarable/LegacyBroadcastOp.h b/libnd4j/include/ops/declarable/LegacyBroadcastOp.h index 755277397d0f..44518798ed30 100644 --- a/libnd4j/include/ops/declarable/LegacyBroadcastOp.h +++ b/libnd4j/include/ops/declarable/LegacyBroadcastOp.h @@ -28,7 +28,7 @@ namespace sd { /** * This class provides wrapper for broadcast operations. */ - class ND4J_EXPORT LegacyBroadcastOp : public LegacyOp { + class SD_EXPORT LegacyBroadcastOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context& block) override; public: diff --git a/libnd4j/include/ops/declarable/LegacyIndexReduceOp.h b/libnd4j/include/ops/declarable/LegacyIndexReduceOp.h index fae0c5e8fd3a..bfe33874c431 100644 --- a/libnd4j/include/ops/declarable/LegacyIndexReduceOp.h +++ b/libnd4j/include/ops/declarable/LegacyIndexReduceOp.h @@ -30,7 +30,7 @@ namespace sd { * * TODO: eventually we want this op class to return long long instead of T */ - class ND4J_EXPORT LegacyIndexReduceOp : public LegacyOp { + class SD_EXPORT LegacyIndexReduceOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context& block) override; public: diff --git a/libnd4j/include/ops/declarable/LegacyOp.h b/libnd4j/include/ops/declarable/LegacyOp.h index 0dfd91a429fe..90ebf0fcc975 100644 --- a/libnd4j/include/ops/declarable/LegacyOp.h +++ b/libnd4j/include/ops/declarable/LegacyOp.h @@ -33,7 +33,7 @@ namespace sd { * * */ - class ND4J_EXPORT LegacyOp : public DeclarableOp { + class SD_EXPORT LegacyOp : public DeclarableOp { protected: // this field is mainly for debugging // it defines, which legacy op should be invoked on a given data diff --git a/libnd4j/include/ops/declarable/LegacyPairwiseTransformBoolOp.h b/libnd4j/include/ops/declarable/LegacyPairwiseTransformBoolOp.h index 16a482811244..8e76226df161 100644 --- a/libnd4j/include/ops/declarable/LegacyPairwiseTransformBoolOp.h +++ b/libnd4j/include/ops/declarable/LegacyPairwiseTransformBoolOp.h @@ -28,7 +28,7 @@ namespace sd { /** * This class provides wrapper for Pairwise transform operations */ - class ND4J_EXPORT LegacyPairwiseTransformBoolOp: public LegacyOp { + class SD_EXPORT LegacyPairwiseTransformBoolOp: public LegacyOp { protected: Nd4jStatus validateAndExecute(Context& block) override; public: diff --git a/libnd4j/include/ops/declarable/LegacyPairwiseTransformOp.h b/libnd4j/include/ops/declarable/LegacyPairwiseTransformOp.h index 81bbdc71556e..0418db506fe3 100644 --- a/libnd4j/include/ops/declarable/LegacyPairwiseTransformOp.h +++ b/libnd4j/include/ops/declarable/LegacyPairwiseTransformOp.h @@ -28,7 +28,7 @@ namespace sd { /** * This class provides wrapper for Pairwise transform operations */ - class ND4J_EXPORT LegacyPairwiseTransformOp: public LegacyOp { + class SD_EXPORT LegacyPairwiseTransformOp: public LegacyOp { protected: Nd4jStatus validateAndExecute(Context& block) override; public: diff --git a/libnd4j/include/ops/declarable/LegacyRandomOp.h b/libnd4j/include/ops/declarable/LegacyRandomOp.h index c0bab879d661..095a61e0d850 100644 --- a/libnd4j/include/ops/declarable/LegacyRandomOp.h +++ b/libnd4j/include/ops/declarable/LegacyRandomOp.h @@ -30,7 +30,7 @@ namespace sd { /** * This class provides wrapper for Random operations (i.e. linspace or Uniform) */ - class ND4J_EXPORT LegacyRandomOp : public LegacyOp { + class SD_EXPORT LegacyRandomOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context& block) override; public: diff --git a/libnd4j/include/ops/declarable/LegacyReduce3Op.h b/libnd4j/include/ops/declarable/LegacyReduce3Op.h index b0a06bd94e81..bfcd666a6a78 100644 --- a/libnd4j/include/ops/declarable/LegacyReduce3Op.h +++ b/libnd4j/include/ops/declarable/LegacyReduce3Op.h @@ -28,7 +28,7 @@ namespace sd { /** * This class provides wrapper for Reduce3 operations (i.e. dot, cosineDistance etc) */ - class ND4J_EXPORT LegacyReduce3Op : public LegacyOp { + class SD_EXPORT LegacyReduce3Op : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context& block) override; public: diff --git a/libnd4j/include/ops/declarable/LegacyReduceBoolOp.h b/libnd4j/include/ops/declarable/LegacyReduceBoolOp.h index 11cd52146874..647c88a6a322 100644 --- a/libnd4j/include/ops/declarable/LegacyReduceBoolOp.h +++ b/libnd4j/include/ops/declarable/LegacyReduceBoolOp.h @@ -25,7 +25,7 @@ namespace sd { namespace ops { - class ND4J_EXPORT LegacyReduceBoolOp : public LegacyOp { + class SD_EXPORT LegacyReduceBoolOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context& block) override; public: diff --git a/libnd4j/include/ops/declarable/LegacyReduceFloatOp.h b/libnd4j/include/ops/declarable/LegacyReduceFloatOp.h index ed36a04fe20c..59f3ec8d2717 100644 --- a/libnd4j/include/ops/declarable/LegacyReduceFloatOp.h +++ b/libnd4j/include/ops/declarable/LegacyReduceFloatOp.h @@ -25,7 +25,7 @@ namespace sd { namespace ops { - class ND4J_EXPORT LegacyReduceFloatOp : public LegacyOp { + class SD_EXPORT LegacyReduceFloatOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context& block) override; public: diff --git a/libnd4j/include/ops/declarable/LegacyReduceLongOp.h b/libnd4j/include/ops/declarable/LegacyReduceLongOp.h index 4f23a9717f53..1f2b8339b848 100644 --- a/libnd4j/include/ops/declarable/LegacyReduceLongOp.h +++ b/libnd4j/include/ops/declarable/LegacyReduceLongOp.h @@ -25,7 +25,7 @@ namespace sd { namespace ops { - class ND4J_EXPORT LegacyReduceLongOp : public LegacyOp { + class SD_EXPORT LegacyReduceLongOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context& block) override; public: diff --git a/libnd4j/include/ops/declarable/LegacyReduceOp.h b/libnd4j/include/ops/declarable/LegacyReduceOp.h index 3e289fe258ed..1ce9b5609b35 100644 --- a/libnd4j/include/ops/declarable/LegacyReduceOp.h +++ b/libnd4j/include/ops/declarable/LegacyReduceOp.h @@ -25,7 +25,7 @@ /* namespace sd { namespace ops { - class ND4J_EXPORT LegacyReduceOp : public LegacyOp { + class SD_EXPORT LegacyReduceOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context& block); public: diff --git a/libnd4j/include/ops/declarable/LegacyReduceSameOp.h b/libnd4j/include/ops/declarable/LegacyReduceSameOp.h index 86cc06a0ecbe..63472ec67e9c 100644 --- a/libnd4j/include/ops/declarable/LegacyReduceSameOp.h +++ b/libnd4j/include/ops/declarable/LegacyReduceSameOp.h @@ -25,7 +25,7 @@ namespace sd { namespace ops { - class ND4J_EXPORT LegacyReduceSameOp: public LegacyOp { + class SD_EXPORT LegacyReduceSameOp: public LegacyOp { protected: Nd4jStatus validateAndExecute(Context& block) override; public: diff --git a/libnd4j/include/ops/declarable/LegacyScalarBoolOp.h b/libnd4j/include/ops/declarable/LegacyScalarBoolOp.h index 0d52eee9d35b..5da57ad18810 100644 --- a/libnd4j/include/ops/declarable/LegacyScalarBoolOp.h +++ b/libnd4j/include/ops/declarable/LegacyScalarBoolOp.h @@ -28,7 +28,7 @@ namespace sd { /** * This class provides wrapper for scalar transform operations, i.e. a + b = c, where either a or b is scalar primitive and other operand is NDArray */ - class ND4J_EXPORT LegacyScalarBoolOp : public LegacyOp { + class SD_EXPORT LegacyScalarBoolOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context& block) override; diff --git a/libnd4j/include/ops/declarable/LegacyScalarOp.h b/libnd4j/include/ops/declarable/LegacyScalarOp.h index 9f2a1a23a35b..4fbfa5cc28cf 100644 --- a/libnd4j/include/ops/declarable/LegacyScalarOp.h +++ b/libnd4j/include/ops/declarable/LegacyScalarOp.h @@ -28,7 +28,7 @@ namespace sd { /** * This class provides wrapper for scalar transform operations, i.e. a + b = c, where either a or b is scalar primitive and other operand is NDArray */ - class ND4J_EXPORT LegacyScalarOp : public LegacyOp { + class SD_EXPORT LegacyScalarOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context& block) override; diff --git a/libnd4j/include/ops/declarable/LegacyStatsOp.h b/libnd4j/include/ops/declarable/LegacyStatsOp.h index 74520b9ddf03..eb74a803f507 100644 --- a/libnd4j/include/ops/declarable/LegacyStatsOp.h +++ b/libnd4j/include/ops/declarable/LegacyStatsOp.h @@ -28,7 +28,7 @@ namespace sd { /** * This class provides wrapper for SummaryStats operations: Variance and Standard Deviation */ - class ND4J_EXPORT LegacyStatsOp : public LegacyOp { + class SD_EXPORT LegacyStatsOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context &block) override; public: diff --git a/libnd4j/include/ops/declarable/LegacyTransformAnyOp.h b/libnd4j/include/ops/declarable/LegacyTransformAnyOp.h index f98ccd4c85e7..5b4e82b41587 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformAnyOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformAnyOp.h @@ -29,7 +29,7 @@ namespace sd { /** * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) */ - class ND4J_EXPORT LegacyTransformAnyOp : public LegacyOp { + class SD_EXPORT LegacyTransformAnyOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context &block) override; public: diff --git a/libnd4j/include/ops/declarable/LegacyTransformBoolOp.h b/libnd4j/include/ops/declarable/LegacyTransformBoolOp.h index d64dd4b019c0..b9e81d7c3574 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformBoolOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformBoolOp.h @@ -30,7 +30,7 @@ namespace sd { /** * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) */ - class ND4J_EXPORT LegacyTransformBoolOp : public LegacyOp { + class SD_EXPORT LegacyTransformBoolOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context &block) override; public: diff --git a/libnd4j/include/ops/declarable/LegacyTransformFloatOp.h b/libnd4j/include/ops/declarable/LegacyTransformFloatOp.h index 37bd0edce8b4..6ce3f2649655 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformFloatOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformFloatOp.h @@ -29,7 +29,7 @@ namespace sd { /** * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) */ - class ND4J_EXPORT LegacyTransformFloatOp : public LegacyOp { + class SD_EXPORT LegacyTransformFloatOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context &block) override; public: diff --git a/libnd4j/include/ops/declarable/LegacyTransformOp.h b/libnd4j/include/ops/declarable/LegacyTransformOp.h index 7eb265bcbaa8..82848e714441 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformOp.h @@ -29,7 +29,7 @@ namespace sd { /** * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) */ - class ND4J_EXPORT LegacyTransformOp : public LegacyOp { + class SD_EXPORT LegacyTransformOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context &block); public: diff --git a/libnd4j/include/ops/declarable/LegacyTransformSameOp.h b/libnd4j/include/ops/declarable/LegacyTransformSameOp.h index 4d9312dafa1b..f2fb71dff08f 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformSameOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformSameOp.h @@ -30,7 +30,7 @@ namespace sd { /** * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) */ - class ND4J_EXPORT LegacyTransformSameOp : public LegacyOp { + class SD_EXPORT LegacyTransformSameOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context &block) override; public: diff --git a/libnd4j/include/ops/declarable/LegacyTransformStrictOp.h b/libnd4j/include/ops/declarable/LegacyTransformStrictOp.h index ee48c02b7b57..1a936f8c6ce3 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformStrictOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformStrictOp.h @@ -30,7 +30,7 @@ namespace sd { /** * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) */ - class ND4J_EXPORT LegacyTransformStrictOp : public LegacyOp { + class SD_EXPORT LegacyTransformStrictOp : public LegacyOp { protected: Nd4jStatus validateAndExecute(Context &block) override; public: diff --git a/libnd4j/include/ops/declarable/LogicOp.h b/libnd4j/include/ops/declarable/LogicOp.h index d3ad59af29cf..68092f346ecc 100644 --- a/libnd4j/include/ops/declarable/LogicOp.h +++ b/libnd4j/include/ops/declarable/LogicOp.h @@ -32,7 +32,7 @@ namespace sd { * Their code is the part of GraphExecutioner logic. But we still want them to be expressed via Graph * @tparam T */ - class ND4J_EXPORT LogicOp : public DeclarableOp { + class SD_EXPORT LogicOp : public DeclarableOp { protected: Nd4jStatus validateAndExecute(sd::graph::Context& block) override; public: diff --git a/libnd4j/include/ops/declarable/OpDescriptor.h b/libnd4j/include/ops/declarable/OpDescriptor.h index 3feff5916735..ceb10db38f9d 100644 --- a/libnd4j/include/ops/declarable/OpDescriptor.h +++ b/libnd4j/include/ops/declarable/OpDescriptor.h @@ -36,7 +36,7 @@ namespace sd { * This class is very basic info holder for ops. bean/pojo pretty much. * */ - class ND4J_EXPORT OpDescriptor { + class SD_EXPORT OpDescriptor { protected: // opNum for legacy XYZ ops int _opNum = 0; diff --git a/libnd4j/include/ops/declarable/OpRegistrator.h b/libnd4j/include/ops/declarable/OpRegistrator.h index d9528ab7a4e3..a024c0f5ad03 100644 --- a/libnd4j/include/ops/declarable/OpRegistrator.h +++ b/libnd4j/include/ops/declarable/OpRegistrator.h @@ -62,7 +62,7 @@ namespace sd { * available at runtime via this singleton. * */ - class ND4J_EXPORT OpRegistrator { + class SD_EXPORT OpRegistrator { private: static OpRegistrator* _INSTANCE; OpRegistrator() { diff --git a/libnd4j/include/ops/declarable/OpTuple.h b/libnd4j/include/ops/declarable/OpTuple.h index 7458ef3d0bfb..39ce03d05432 100644 --- a/libnd4j/include/ops/declarable/OpTuple.h +++ b/libnd4j/include/ops/declarable/OpTuple.h @@ -27,7 +27,7 @@ namespace sd { namespace ops { - class ND4J_EXPORT OpTuple { + class SD_EXPORT OpTuple { public: std::string _opName; std::vector _inputs; diff --git a/libnd4j/include/ops/declarable/PlatformHelper.h b/libnd4j/include/ops/declarable/PlatformHelper.h index e0231ad9addd..9f456b1ca96d 100644 --- a/libnd4j/include/ops/declarable/PlatformHelper.h +++ b/libnd4j/include/ops/declarable/PlatformHelper.h @@ -34,7 +34,7 @@ namespace sd { /** * This abstract class defines methods used by platform-specific helpers implementations */ - class ND4J_EXPORT PlatformHelper { + class SD_EXPORT PlatformHelper { protected: // target engine for this impl samediff::Engine _engine; diff --git a/libnd4j/include/ops/declarable/helpers/activations.h b/libnd4j/include/ops/declarable/helpers/activations.h index ab652ab24fe1..b20eb8450ab3 100644 --- a/libnd4j/include/ops/declarable/helpers/activations.h +++ b/libnd4j/include/ops/declarable/helpers/activations.h @@ -27,23 +27,23 @@ namespace sd { namespace ops { namespace helpers { - ND4J_EXPORT void softMaxForVector(sd::LaunchContext * context, const NDArray &input, NDArray &output); + SD_EXPORT void softMaxForVector(sd::LaunchContext * context, const NDArray &input, NDArray &output); - ND4J_EXPORT void logSoftMaxForVector(sd::LaunchContext * context, const NDArray &input, NDArray &output); + SD_EXPORT void logSoftMaxForVector(sd::LaunchContext * context, const NDArray &input, NDArray &output); - ND4J_EXPORT void softmax(sd::LaunchContext * context, const NDArray &input, NDArray &output, const int dimension); + SD_EXPORT void softmax(sd::LaunchContext * context, const NDArray &input, NDArray &output, const int dimension); - ND4J_EXPORT void logSoftmax(sd::LaunchContext * context, const NDArray &input, NDArray &output, const int dimension); + SD_EXPORT void logSoftmax(sd::LaunchContext * context, const NDArray &input, NDArray &output, const int dimension); - ND4J_EXPORT void softmaxDerivative(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension); + SD_EXPORT void softmaxDerivative(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension); - ND4J_EXPORT void prelu(sd::LaunchContext * context, const NDArray &input, const NDArray &alpha, NDArray &output); + SD_EXPORT void prelu(sd::LaunchContext * context, const NDArray &input, const NDArray &alpha, NDArray &output); - ND4J_EXPORT void preluBP(sd::LaunchContext * context, const NDArray &input, const NDArray &alpha, const NDArray &dLdO, NDArray &dLdI, NDArray &dLdA); + SD_EXPORT void preluBP(sd::LaunchContext * context, const NDArray &input, const NDArray &alpha, const NDArray &dLdO, NDArray &dLdI, NDArray &dLdA); - ND4J_EXPORT void thresholdRelu(sd::LaunchContext * context, const NDArray &input, double threshold, NDArray &output); + SD_EXPORT void thresholdRelu(sd::LaunchContext * context, const NDArray &input, double threshold, NDArray &output); - ND4J_EXPORT void thresholdReluDerivative(sd::LaunchContext * context, NDArray *input, double threshold, NDArray* dLdO, NDArray *output); + SD_EXPORT void thresholdReluDerivative(sd::LaunchContext * context, NDArray *input, double threshold, NDArray* dLdO, NDArray *output); } } } diff --git a/libnd4j/include/ops/declarable/helpers/col2im.h b/libnd4j/include/ops/declarable/helpers/col2im.h index 39a29da85581..7a1468d88223 100644 --- a/libnd4j/include/ops/declarable/helpers/col2im.h +++ b/libnd4j/include/ops/declarable/helpers/col2im.h @@ -27,7 +27,7 @@ namespace sd { namespace ops { namespace helpers { - ND4J_EXPORT void col2im(sd::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW); + SD_EXPORT void col2im(sd::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW); } diff --git a/libnd4j/include/ops/declarable/helpers/convolutions.h b/libnd4j/include/ops/declarable/helpers/convolutions.h index f38692a3570f..031e788bbfe1 100644 --- a/libnd4j/include/ops/declarable/helpers/convolutions.h +++ b/libnd4j/include/ops/declarable/helpers/convolutions.h @@ -36,7 +36,7 @@ namespace sd { PNORM_POOL = 2, }; - class ND4J_EXPORT ConvolutionUtils { + class SD_EXPORT ConvolutionUtils { public: static inline void calcOutSizePool2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int paddingMode) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp index c4f99af3f7ba..7ae4221d4abc 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp @@ -929,7 +929,7 @@ void SVD::evalData(const NDArray& matrix) { } -BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT SVD,,FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT SVD,,FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/im2col.h b/libnd4j/include/ops/declarable/helpers/im2col.h index 6b61535f96ed..87eaa3bbc05a 100644 --- a/libnd4j/include/ops/declarable/helpers/im2col.h +++ b/libnd4j/include/ops/declarable/helpers/im2col.h @@ -27,7 +27,7 @@ namespace sd { namespace ops { namespace helpers { - ND4J_EXPORT void im2col(sd::LaunchContext & context, const NDArray& im, NDArray& col, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal); + SD_EXPORT void im2col(sd::LaunchContext & context, const NDArray& im, NDArray& col, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal); } } } diff --git a/libnd4j/include/ops/declarable/helpers/lstmLayer.h b/libnd4j/include/ops/declarable/helpers/lstmLayer.h index dfa9268b4df1..5402502a0edc 100644 --- a/libnd4j/include/ops/declarable/helpers/lstmLayer.h +++ b/libnd4j/include/ops/declarable/helpers/lstmLayer.h @@ -29,13 +29,13 @@ namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, +void SD_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, const std::vector& params, NDArray* h, NDArray* c); ////////////////////////////////////////////////////////////////////////// -void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, +void SD_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp, const std::vector& params, const bool forward, diff --git a/libnd4j/include/ops/declarable/helpers/multiUnique.h b/libnd4j/include/ops/declarable/helpers/multiUnique.h index 3119901c12f1..a7ce14818d27 100644 --- a/libnd4j/include/ops/declarable/helpers/multiUnique.h +++ b/libnd4j/include/ops/declarable/helpers/multiUnique.h @@ -26,7 +26,7 @@ namespace sd { namespace ops { namespace helpers { - ND4J_EXPORT bool multiUnique(std::vector const& inputList, sd::memory::Workspace* workspace = nullptr); + SD_EXPORT bool multiUnique(std::vector const& inputList, sd::memory::Workspace* workspace = nullptr); } } diff --git a/libnd4j/include/ops/impl/specials_sparse.cpp b/libnd4j/include/ops/impl/specials_sparse.cpp index c782ccf188a4..798e3f93a6cd 100644 --- a/libnd4j/include/ops/impl/specials_sparse.cpp +++ b/libnd4j/include/ops/impl/specials_sparse.cpp @@ -217,6 +217,6 @@ PRAGMA_OMP_SINGLE_ARGS(nowait) #endif } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT SparseUtils, , LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT SparseUtils, , LIBND4J_TYPES); } } diff --git a/libnd4j/include/ops/specials.h b/libnd4j/include/ops/specials.h index c250d72f630e..e4be65c7a36e 100644 --- a/libnd4j/include/ops/specials.h +++ b/libnd4j/include/ops/specials.h @@ -41,14 +41,14 @@ namespace sd { } FloatBits2; - class ND4J_EXPORT SpecialTypeConverter { + class SD_EXPORT SpecialTypeConverter { public: template static void convertGeneric(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); }; template - class ND4J_EXPORT SpecialMethods { + class SD_EXPORT SpecialMethods { public: static void concatCpuGeneric(const std::vector& inArrs, NDArray& output, const int axis); static void concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *result, Nd4jLong *resultShapeInfo); @@ -72,7 +72,7 @@ namespace sd { }; template - class ND4J_EXPORT DoubleMethods{ + class SD_EXPORT DoubleMethods{ public: static void sortByKey(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, bool descending); static void sortByValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, bool descending); diff --git a/libnd4j/include/performance/benchmarking/BenchmarkSuit.h b/libnd4j/include/performance/benchmarking/BenchmarkSuit.h index 91d966339eb6..ab150ddbbb52 100644 --- a/libnd4j/include/performance/benchmarking/BenchmarkSuit.h +++ b/libnd4j/include/performance/benchmarking/BenchmarkSuit.h @@ -28,7 +28,7 @@ #include namespace sd { - class ND4J_EXPORT BenchmarkSuit { + class SD_EXPORT BenchmarkSuit { public: BenchmarkSuit() = default; ~BenchmarkSuit() = default; diff --git a/libnd4j/include/system/BlasVersionHelper.h b/libnd4j/include/system/BlasVersionHelper.h index 0d894ed8df70..cee95d26bf98 100644 --- a/libnd4j/include/system/BlasVersionHelper.h +++ b/libnd4j/include/system/BlasVersionHelper.h @@ -26,7 +26,7 @@ #include namespace sd { - class ND4J_EXPORT BlasVersionHelper { + class SD_EXPORT BlasVersionHelper { public: int _blasMajorVersion = 0; int _blasMinorVersion = 0; diff --git a/libnd4j/include/system/Environment.h b/libnd4j/include/system/Environment.h index 392e7087136e..1369f99506b6 100644 --- a/libnd4j/include/system/Environment.h +++ b/libnd4j/include/system/Environment.h @@ -30,7 +30,7 @@ #include namespace sd{ - class ND4J_EXPORT Environment { + class SD_EXPORT Environment { private: std::atomic _tadThreshold; std::atomic _elementThreshold; diff --git a/libnd4j/include/system/dll.h b/libnd4j/include/system/dll.h index 71098f8bf472..63d7c39c4bc1 100644 --- a/libnd4j/include/system/dll.h +++ b/libnd4j/include/system/dll.h @@ -25,8 +25,8 @@ #ifdef _WIN32 //#include -# define ND4J_EXPORT __declspec(dllexport) +# define SD_EXPORT __declspec(dllexport) #else -# define ND4J_EXPORT +# define SD_EXPORT #endif #endif //NATIVEOPERATIONS_DLL_H diff --git a/libnd4j/include/system/op_boilerplate.h b/libnd4j/include/system/op_boilerplate.h index 4e7a288f0552..ac7f281d9207 100644 --- a/libnd4j/include/system/op_boilerplate.h +++ b/libnd4j/include/system/op_boilerplate.h @@ -1231,12 +1231,12 @@ #define REQUIRE_OK(A) if (sd::ops::resultHelper( (A), #A, __FILE__, __LINE__ ) != 0) return ND4J_STATUS_VALIDATION; #define REQUIRE_TRUE(COND, ...) if (!(COND)) { if (sd::ops::conditionHelper(__FILE__, __LINE__, COND, __VA_ARGS__) != 0) throw std::invalid_argument("Op validation failed");}; -#define DECLARE_ENTRY(NAME, ...) template struct ND4J_EXPORT __registratorFloat>; \ - template struct ND4J_EXPORT __registratorHalf>; \ - template struct ND4J_EXPORT __registratorDouble>; \ - template struct ND4J_EXPORT __registratorSynonymHalf>; \ - template struct ND4J_EXPORT __registratorSynonymDouble>; \ - template struct ND4J_EXPORT __registratorSynonymFloat>; +#define DECLARE_ENTRY(NAME, ...) template struct SD_EXPORT __registratorFloat>; \ + template struct SD_EXPORT __registratorHalf>; \ + template struct SD_EXPORT __registratorDouble>; \ + template struct SD_EXPORT __registratorSynonymHalf>; \ + template struct SD_EXPORT __registratorSynonymDouble>; \ + template struct SD_EXPORT __registratorSynonymFloat>; #if defined(_MSC_VER) || defined(_WIN64) || defined(_WIN32) || defined(__CLION_IDE__) || defined(__VSCODE__) @@ -1277,7 +1277,7 @@ #define REGISTER_C(NAME) #endif -#define DECLARE_OP(NAME, NIN, NOUT, INPLACEABLE) class ND4J_EXPORT NAME: public sd::ops::DeclarableOp { \ +#define DECLARE_OP(NAME, NIN, NOUT, INPLACEABLE) class SD_EXPORT NAME: public sd::ops::DeclarableOp { \ public:\ NAME(); \ sd::ShapeList* calculateOutputShape(sd::ShapeList* inputShape, sd::graph::Context& block); \ @@ -1287,7 +1287,7 @@ };\ REGISTER_H(NAME) -#define DECLARE_BOOLEAN_OP(NAME, NIN, SCALAR) class ND4J_EXPORT NAME: public sd::ops::BooleanOp { \ +#define DECLARE_BOOLEAN_OP(NAME, NIN, SCALAR) class SD_EXPORT NAME: public sd::ops::BooleanOp { \ public:\ NAME(); \ protected: \ @@ -1300,7 +1300,7 @@ REGISTER_C(NAME) \ Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -#define DECLARE_LIST_OP(NAME, NIN, NOUT, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableListOp { \ +#define DECLARE_LIST_OP(NAME, NIN, NOUT, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableListOp { \ public:\ NAME(); \ protected: \ @@ -1312,7 +1312,7 @@ REGISTER_C(NAME) \ Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -#define DECLARE_LOGIC_OP(NAME) class ND4J_EXPORT NAME: public sd::ops::LogicOp { \ +#define DECLARE_LOGIC_OP(NAME) class SD_EXPORT NAME: public sd::ops::LogicOp { \ public:\ NAME(); \ protected: \ @@ -1355,7 +1355,7 @@ };\ static sd::ops::__registratorSynonym_##NAME zzz_register_opd_##NAME(#NAME, #ORIGINAL) -#define DECLARE_DIVERGENT_OP(NAME, NIN, NOUT, INPLACEABLE) class ND4J_EXPORT NAME: public sd::ops::DeclarableOp { \ +#define DECLARE_DIVERGENT_OP(NAME, NIN, NOUT, INPLACEABLE) class SD_EXPORT NAME: public sd::ops::DeclarableOp { \ public:\ NAME(); \ sd::ShapeList* calculateOutputShape(sd::ShapeList* inputShape, sd::graph::Context& block); \ @@ -1378,7 +1378,7 @@ } \ Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -#define DECLARE_CONFIGURABLE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableOp { \ +#define DECLARE_CONFIGURABLE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableOp { \ public:\ NAME(); \ sd::ShapeList* calculateOutputShape(sd::ShapeList* inputShape, sd::graph::Context& block); \ @@ -1401,7 +1401,7 @@ } \ Nd4jStatus sd::ops::NAME::validateAndExecute(Context& block) -#define DECLARE_REDUCTION_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableReductionOp { \ +#define DECLARE_REDUCTION_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableReductionOp { \ public:\ NAME(); \ protected: \ @@ -1415,7 +1415,7 @@ Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -#define DECLARE_CUSTOM_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableCustomOp { \ +#define DECLARE_CUSTOM_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableCustomOp { \ protected: \ void registerTypes(); \ Nd4jStatus validateAndExecute(Context& block); \ @@ -1437,7 +1437,7 @@ #define DECLARE_TYPES(NAME) void sd::ops::NAME::registerTypes() -#define DECLARE_BROADCASTABLE_OP(NAME,TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::BroadcastableOp { \ +#define DECLARE_BROADCASTABLE_OP(NAME,TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::BroadcastableOp { \ protected: \ void registerTypes(); \ Nd4jStatus validateAndExecute(Context& block); \ diff --git a/libnd4j/include/system/platform_boilerplate.h b/libnd4j/include/system/platform_boilerplate.h index bdbb1a05161c..f8ee0eb55652 100644 --- a/libnd4j/include/system/platform_boilerplate.h +++ b/libnd4j/include/system/platform_boilerplate.h @@ -28,7 +28,7 @@ #define CONCATP(A,B) A ##_##B -#define DECLARE_PLATFORM_F(NAME, ENGINE, CNAME) class ND4J_EXPORT PLATFORM_##CNAME : public PlatformHelper {\ +#define DECLARE_PLATFORM_F(NAME, ENGINE, CNAME) class SD_EXPORT PLATFORM_##CNAME : public PlatformHelper {\ public: \ PLATFORM_##CNAME() : PlatformHelper(#NAME, samediff::Engine::ENGINE) { } \ bool isUsable(graph::Context &context) override; \ @@ -37,7 +37,7 @@ #define DECLARE_PLATFORM(NAME, ENGINE) DECLARE_PLATFORM_F(NAME, ENGINE, NAME ##_## ENGINE) -#define PLATFORM_IMPL_F(NAME, ENGINE, CNAME) struct ND4J_EXPORT __registratorPlatformHelper_##CNAME { \ +#define PLATFORM_IMPL_F(NAME, ENGINE, CNAME) struct SD_EXPORT __registratorPlatformHelper_##CNAME { \ __registratorPlatformHelper_##CNAME() { \ auto helper = new PLATFORM_##CNAME(); \ OpRegistrator::getInstance()->registerHelper(helper); \ diff --git a/libnd4j/include/types/pair.h b/libnd4j/include/types/pair.h index 0471c45ed33f..7067e888242c 100644 --- a/libnd4j/include/types/pair.h +++ b/libnd4j/include/types/pair.h @@ -24,7 +24,7 @@ #include namespace sd { - class ND4J_EXPORT Pair { + class SD_EXPORT Pair { protected: int _first = 0; int _second = 0; diff --git a/libnd4j/include/types/triple.h b/libnd4j/include/types/triple.h index 0a5310265888..520e24d569ab 100644 --- a/libnd4j/include/types/triple.h +++ b/libnd4j/include/types/triple.h @@ -25,7 +25,7 @@ #include namespace sd { - class ND4J_EXPORT Triple { + class SD_EXPORT Triple { protected: int _first = 0; int _second = 0; diff --git a/libnd4j/include/types/utf8string.h b/libnd4j/include/types/utf8string.h index a0df70558b93..efab2c027c8d 100644 --- a/libnd4j/include/types/utf8string.h +++ b/libnd4j/include/types/utf8string.h @@ -25,7 +25,7 @@ #include namespace sd { - struct ND4J_EXPORT utf8string { + struct SD_EXPORT utf8string { private: bool _allocated = false; public: diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index e46e2b8d45fb..74e7d33e573a 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -64,7 +64,7 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_1) { auto layer = optimized.layer(0); // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer.size()); + ASSERT_EQ(1, layer.width()); auto sequence = layer[0]; // we expect that OpSequence has exactly 2 ops @@ -110,7 +110,7 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { auto layer0 = optimized.layer(0); // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer0.size()); + ASSERT_EQ(1, layer0.width()); auto sequence = layer0[0]; // we expect that OpSequence has exactly 2 ops @@ -122,7 +122,7 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { auto layer1 = optimized.layer(0); // we expect layer has exactly 2 OpSequences - ASSERT_EQ(2, layer1.size()); + ASSERT_EQ(2, layer1.width()); sequence = layer1[0]; diff --git a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp index 0027f7052c23..9f2168020317 100644 --- a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp @@ -70,7 +70,7 @@ TEST_F(OpSequenceTests, test_iterator_1) { auto layer = optimizedGraph.layer(0); // we expect exactly 1 sequence in this layer - ASSERT_EQ(1, layer.size()); + ASSERT_EQ(1, layer.width()); auto seq = layer[0]; From 0d42c67bd78519fe97e9245403cf8c33c6ac780f Mon Sep 17 00:00:00 2001 From: raver119 Date: Sat, 28 Mar 2020 16:08:29 +0300 Subject: [PATCH 033/233] NodeOptimizer abstraction Signed-off-by: raver119 --- .../graph/optimization/NodeOptimizer.h | 58 +++++++++++++++++++ .../graph/optimization/cudnn/README.md | 1 + .../graph/optimization/generic/README.md | 1 + .../graph/optimization/impl/NodeOptimizer.cpp | 29 ++++++++++ .../graph/optimization/mkldnn/README.md | 1 + 5 files changed, 90 insertions(+) create mode 100644 libnd4j/include/graph/optimization/NodeOptimizer.h create mode 100644 libnd4j/include/graph/optimization/cudnn/README.md create mode 100644 libnd4j/include/graph/optimization/generic/README.md create mode 100644 libnd4j/include/graph/optimization/impl/NodeOptimizer.cpp create mode 100644 libnd4j/include/graph/optimization/mkldnn/README.md diff --git a/libnd4j/include/graph/optimization/NodeOptimizer.h b/libnd4j/include/graph/optimization/NodeOptimizer.h new file mode 100644 index 000000000000..0befb67e5dbb --- /dev/null +++ b/libnd4j/include/graph/optimization/NodeOptimizer.h @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_NODEOPTIMIZER_H +#define SD_NODEOPTIMIZER_H + +#include +#include +#include + +namespace sd { + namespace graph { + /** + * This abstract class defines basic methods needed for Inputs/Outputs optimizations. I.e. weight format changes or data types changes for a specific backend + */ + class SD_EXPORT NodeOptimizer { + protected: + std::string _target = {}; + + public: + NodeOptimizer() = default; + virtual ~NodeOptimizer() = default; + + /** + * This method applu + * @param node + */ + virtual void optimize(Node &node) = 0; + + /** + * This method returns target Op name for this optimizer + * @return + */ + const std::string& targetOp() const; + }; + } +} + + + +#endif //DEV_TESTS_NODEOPTIMIZER_H diff --git a/libnd4j/include/graph/optimization/cudnn/README.md b/libnd4j/include/graph/optimization/cudnn/README.md new file mode 100644 index 000000000000..7c872fe4132c --- /dev/null +++ b/libnd4j/include/graph/optimization/cudnn/README.md @@ -0,0 +1 @@ +This folder contains NodeOptimizer implementations for cuDNN \ No newline at end of file diff --git a/libnd4j/include/graph/optimization/generic/README.md b/libnd4j/include/graph/optimization/generic/README.md new file mode 100644 index 000000000000..535e905f3fb4 --- /dev/null +++ b/libnd4j/include/graph/optimization/generic/README.md @@ -0,0 +1 @@ +This folder contains generic NodeOptimizer implementations suitable for all platforms \ No newline at end of file diff --git a/libnd4j/include/graph/optimization/impl/NodeOptimizer.cpp b/libnd4j/include/graph/optimization/impl/NodeOptimizer.cpp new file mode 100644 index 000000000000..7b9e681c8c34 --- /dev/null +++ b/libnd4j/include/graph/optimization/impl/NodeOptimizer.cpp @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { + namespace graph { + const std::string &NodeOptimizer::targetOp() const { + return _target; + } + } +} \ No newline at end of file diff --git a/libnd4j/include/graph/optimization/mkldnn/README.md b/libnd4j/include/graph/optimization/mkldnn/README.md new file mode 100644 index 000000000000..1cedbb91e007 --- /dev/null +++ b/libnd4j/include/graph/optimization/mkldnn/README.md @@ -0,0 +1 @@ +This folder contains NodeOptimizer implementations for MKL-DNN \ No newline at end of file From 34deb6657c06cd7aa69376800e8794c499d01d12 Mon Sep 17 00:00:00 2001 From: raver119 Date: Sat, 28 Mar 2020 17:09:36 +0300 Subject: [PATCH 034/233] GraphOptimizer abstraction Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 4 +- libnd4j/include/graph/impl/Graph.cpp | 8 ++-- .../graph/optimization/GraphOptimizer.h | 41 +++++++++++++++++++ .../optimization/impl/GraphOptimizer.cpp | 33 +++++++++++++++ 4 files changed, 80 insertions(+), 6 deletions(-) create mode 100644 libnd4j/include/graph/optimization/GraphOptimizer.h create mode 100644 libnd4j/include/graph/optimization/impl/GraphOptimizer.cpp diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index b2653a30b087..9aa99032c697 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -195,12 +195,12 @@ namespace sd { /** * This method returns clone of the graph */ - Graph* clone(); + Graph* clone() const; /** * This method returns clone of the graph, backed by VariableProxy instead of VariableSpace */ - Graph* cloneWithProxy(); + Graph* cloneWithProxy() const; /** * This method removes reference to VariableSpace from this Graph diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 5c0ead9feea0..cd20fabac986 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -1299,7 +1299,7 @@ namespace sd { _configuration = configuration; } - Graph* Graph::cloneWithProxy() { + Graph* Graph::cloneWithProxy() const { auto clone = new Graph(); clone->replaceState(new VariableProxy(this->_variableSpace), this->_configuration->clone()); @@ -1333,7 +1333,7 @@ namespace sd { for (auto x: *(ovec)) { auto n = x->clone(); vec->emplace_back(n); - _handles.emplace_back(n); + clone->_handles.emplace_back(n); (*clone->_mapped)[n->id()] = n; } @@ -1350,7 +1350,7 @@ namespace sd { return clone; } - Graph* Graph::clone() { + Graph* Graph::clone() const { auto clone = new Graph(); clone->replaceState(this->_variableSpace->clone(), this->_configuration->clone()); @@ -1384,7 +1384,7 @@ namespace sd { for (auto x: *(ovec)) { auto n = x->clone(); vec->emplace_back(n); - _handles.emplace_back(n); + clone->_handles.emplace_back(n); (*clone->_mapped)[n->id()] = n; } diff --git a/libnd4j/include/graph/optimization/GraphOptimizer.h b/libnd4j/include/graph/optimization/GraphOptimizer.h new file mode 100644 index 000000000000..df3a6f559baf --- /dev/null +++ b/libnd4j/include/graph/optimization/GraphOptimizer.h @@ -0,0 +1,41 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_GRAPHOPTIMIZER_H +#define SD_GRAPHOPTIMIZER_H + +#include + +namespace sd { + namespace graph { + class SD_EXPORT GraphOptimizer { + public: + /** + * This method optimizes given Graph and returns independent cloned Graph + * @param graph + * @return + */ + static Graph* optimize(const Graph &graph); + }; + } +} + + +#endif //SD_GRAPHOPTIMIZER_H diff --git a/libnd4j/include/graph/optimization/impl/GraphOptimizer.cpp b/libnd4j/include/graph/optimization/impl/GraphOptimizer.cpp new file mode 100644 index 000000000000..f838a22e790e --- /dev/null +++ b/libnd4j/include/graph/optimization/impl/GraphOptimizer.cpp @@ -0,0 +1,33 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { + namespace graph { + Graph *GraphOptimizer::optimize(const Graph &graph) { + auto clone = graph.clone(); + + //TODO: implement this method + + return clone; + } + } +} \ No newline at end of file From e7ed426a8a6a46b2feae18c13fb2462213a4e500 Mon Sep 17 00:00:00 2001 From: raver119 Date: Sat, 28 Mar 2020 18:55:59 +0300 Subject: [PATCH 035/233] few files moved around Signed-off-by: raver119 --- .../include/graph/impl/GraphExecutioner.cpp | 2 +- .../{execution => logic}/LogicConditional.h | 0 .../graph/{execution => logic}/LogicEnter.h | 0 .../{execution => logic}/LogicExecutor.h | 0 .../graph/{execution => logic}/LogicExit.h | 0 .../graph/{execution => logic}/LogicExpose.h | 0 .../{execution => logic}/LogicLoopCond.h | 0 .../graph/{execution => logic}/LogicMerge.h | 0 .../{execution => logic}/LogicNextIteration.h | 0 .../graph/{execution => logic}/LogicReturn.h | 0 .../graph/{execution => logic}/LogicScope.h | 0 .../graph/{execution => logic}/LogicSwitch.h | 0 .../graph/{execution => logic}/LogicWhile.h | 0 .../impl/LogicConditional.cpp | 4 ++-- .../{execution => logic}/impl/LogicEnter.cpp | 2 +- .../impl/LogicExecutor.cpp | 24 +++++++++---------- .../{execution => logic}/impl/LogicExit.cpp | 2 +- .../{execution => logic}/impl/LogicExpose.cpp | 2 +- .../impl/LogicLoopCond.cpp | 2 +- .../{execution => logic}/impl/LogicMerge.cpp | 2 +- .../impl/LogicNextIteration.cpp | 2 +- .../{execution => logic}/impl/LogicReturn.cpp | 2 +- .../{execution => logic}/impl/LogicScope.cpp | 2 +- .../{execution => logic}/impl/LogicSwitch.cpp | 2 +- .../{execution => logic}/impl/LogicWhile.cpp | 6 ++--- libnd4j/include/legacy/NativeOps.h | 2 +- 26 files changed, 28 insertions(+), 28 deletions(-) rename libnd4j/include/graph/{execution => logic}/LogicConditional.h (100%) rename libnd4j/include/graph/{execution => logic}/LogicEnter.h (100%) rename libnd4j/include/graph/{execution => logic}/LogicExecutor.h (100%) rename libnd4j/include/graph/{execution => logic}/LogicExit.h (100%) rename libnd4j/include/graph/{execution => logic}/LogicExpose.h (100%) rename libnd4j/include/graph/{execution => logic}/LogicLoopCond.h (100%) rename libnd4j/include/graph/{execution => logic}/LogicMerge.h (100%) rename libnd4j/include/graph/{execution => logic}/LogicNextIteration.h (100%) rename libnd4j/include/graph/{execution => logic}/LogicReturn.h (100%) rename libnd4j/include/graph/{execution => logic}/LogicScope.h (100%) rename libnd4j/include/graph/{execution => logic}/LogicSwitch.h (100%) rename libnd4j/include/graph/{execution => logic}/LogicWhile.h (100%) rename libnd4j/include/graph/{execution => logic}/impl/LogicConditional.cpp (98%) rename libnd4j/include/graph/{execution => logic}/impl/LogicEnter.cpp (98%) rename libnd4j/include/graph/{execution => logic}/impl/LogicExecutor.cpp (83%) rename libnd4j/include/graph/{execution => logic}/impl/LogicExit.cpp (97%) rename libnd4j/include/graph/{execution => logic}/impl/LogicExpose.cpp (96%) rename libnd4j/include/graph/{execution => logic}/impl/LogicLoopCond.cpp (97%) rename libnd4j/include/graph/{execution => logic}/impl/LogicMerge.cpp (99%) rename libnd4j/include/graph/{execution => logic}/impl/LogicNextIteration.cpp (97%) rename libnd4j/include/graph/{execution => logic}/impl/LogicReturn.cpp (98%) rename libnd4j/include/graph/{execution => logic}/impl/LogicScope.cpp (96%) rename libnd4j/include/graph/{execution => logic}/impl/LogicSwitch.cpp (99%) rename libnd4j/include/graph/{execution => logic}/impl/LogicWhile.cpp (97%) diff --git a/libnd4j/include/graph/impl/GraphExecutioner.cpp b/libnd4j/include/graph/impl/GraphExecutioner.cpp index b68c3ccd3203..b03895a42efd 100644 --- a/libnd4j/include/graph/impl/GraphExecutioner.cpp +++ b/libnd4j/include/graph/impl/GraphExecutioner.cpp @@ -43,7 +43,7 @@ #include #include -#include +#include #include #include #include diff --git a/libnd4j/include/graph/execution/LogicConditional.h b/libnd4j/include/graph/logic/LogicConditional.h similarity index 100% rename from libnd4j/include/graph/execution/LogicConditional.h rename to libnd4j/include/graph/logic/LogicConditional.h diff --git a/libnd4j/include/graph/execution/LogicEnter.h b/libnd4j/include/graph/logic/LogicEnter.h similarity index 100% rename from libnd4j/include/graph/execution/LogicEnter.h rename to libnd4j/include/graph/logic/LogicEnter.h diff --git a/libnd4j/include/graph/execution/LogicExecutor.h b/libnd4j/include/graph/logic/LogicExecutor.h similarity index 100% rename from libnd4j/include/graph/execution/LogicExecutor.h rename to libnd4j/include/graph/logic/LogicExecutor.h diff --git a/libnd4j/include/graph/execution/LogicExit.h b/libnd4j/include/graph/logic/LogicExit.h similarity index 100% rename from libnd4j/include/graph/execution/LogicExit.h rename to libnd4j/include/graph/logic/LogicExit.h diff --git a/libnd4j/include/graph/execution/LogicExpose.h b/libnd4j/include/graph/logic/LogicExpose.h similarity index 100% rename from libnd4j/include/graph/execution/LogicExpose.h rename to libnd4j/include/graph/logic/LogicExpose.h diff --git a/libnd4j/include/graph/execution/LogicLoopCond.h b/libnd4j/include/graph/logic/LogicLoopCond.h similarity index 100% rename from libnd4j/include/graph/execution/LogicLoopCond.h rename to libnd4j/include/graph/logic/LogicLoopCond.h diff --git a/libnd4j/include/graph/execution/LogicMerge.h b/libnd4j/include/graph/logic/LogicMerge.h similarity index 100% rename from libnd4j/include/graph/execution/LogicMerge.h rename to libnd4j/include/graph/logic/LogicMerge.h diff --git a/libnd4j/include/graph/execution/LogicNextIteration.h b/libnd4j/include/graph/logic/LogicNextIteration.h similarity index 100% rename from libnd4j/include/graph/execution/LogicNextIteration.h rename to libnd4j/include/graph/logic/LogicNextIteration.h diff --git a/libnd4j/include/graph/execution/LogicReturn.h b/libnd4j/include/graph/logic/LogicReturn.h similarity index 100% rename from libnd4j/include/graph/execution/LogicReturn.h rename to libnd4j/include/graph/logic/LogicReturn.h diff --git a/libnd4j/include/graph/execution/LogicScope.h b/libnd4j/include/graph/logic/LogicScope.h similarity index 100% rename from libnd4j/include/graph/execution/LogicScope.h rename to libnd4j/include/graph/logic/LogicScope.h diff --git a/libnd4j/include/graph/execution/LogicSwitch.h b/libnd4j/include/graph/logic/LogicSwitch.h similarity index 100% rename from libnd4j/include/graph/execution/LogicSwitch.h rename to libnd4j/include/graph/logic/LogicSwitch.h diff --git a/libnd4j/include/graph/execution/LogicWhile.h b/libnd4j/include/graph/logic/LogicWhile.h similarity index 100% rename from libnd4j/include/graph/execution/LogicWhile.h rename to libnd4j/include/graph/logic/LogicWhile.h diff --git a/libnd4j/include/graph/execution/impl/LogicConditional.cpp b/libnd4j/include/graph/logic/impl/LogicConditional.cpp similarity index 98% rename from libnd4j/include/graph/execution/impl/LogicConditional.cpp rename to libnd4j/include/graph/logic/impl/LogicConditional.cpp index 25627df4564d..cf0a34d4a43f 100644 --- a/libnd4j/include/graph/execution/impl/LogicConditional.cpp +++ b/libnd4j/include/graph/logic/impl/LogicConditional.cpp @@ -18,9 +18,9 @@ // Created by raver119 on 20.10.2017. // -#include +#include #include -#include +#include #include diff --git a/libnd4j/include/graph/execution/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp similarity index 98% rename from libnd4j/include/graph/execution/impl/LogicEnter.cpp rename to libnd4j/include/graph/logic/impl/LogicEnter.cpp index f10ff792f765..1f9a973d88b2 100644 --- a/libnd4j/include/graph/execution/impl/LogicEnter.cpp +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -18,7 +18,7 @@ // @author raver119@gmail.com // -#include +#include #include diff --git a/libnd4j/include/graph/execution/impl/LogicExecutor.cpp b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp similarity index 83% rename from libnd4j/include/graph/execution/impl/LogicExecutor.cpp rename to libnd4j/include/graph/logic/impl/LogicExecutor.cpp index fd7ce3e852e4..2e4898da31e4 100644 --- a/libnd4j/include/graph/execution/impl/LogicExecutor.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp @@ -18,18 +18,18 @@ // Created by raver119 on 20.10.2017. // -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace sd { diff --git a/libnd4j/include/graph/execution/impl/LogicExit.cpp b/libnd4j/include/graph/logic/impl/LogicExit.cpp similarity index 97% rename from libnd4j/include/graph/execution/impl/LogicExit.cpp rename to libnd4j/include/graph/logic/impl/LogicExit.cpp index 9a0e217938a8..1b2a8d49d97e 100644 --- a/libnd4j/include/graph/execution/impl/LogicExit.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExit.cpp @@ -18,7 +18,7 @@ // @author raver119@gmail.com // -#include +#include namespace sd { diff --git a/libnd4j/include/graph/execution/impl/LogicExpose.cpp b/libnd4j/include/graph/logic/impl/LogicExpose.cpp similarity index 96% rename from libnd4j/include/graph/execution/impl/LogicExpose.cpp rename to libnd4j/include/graph/logic/impl/LogicExpose.cpp index b19e1df55311..06ddbc61d773 100644 --- a/libnd4j/include/graph/execution/impl/LogicExpose.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExpose.cpp @@ -18,7 +18,7 @@ // Created by raver119 on 12.11.2017. // -#include +#include namespace sd { namespace graph { diff --git a/libnd4j/include/graph/execution/impl/LogicLoopCond.cpp b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp similarity index 97% rename from libnd4j/include/graph/execution/impl/LogicLoopCond.cpp rename to libnd4j/include/graph/logic/impl/LogicLoopCond.cpp index 292452719770..8409b0bd82ea 100644 --- a/libnd4j/include/graph/execution/impl/LogicLoopCond.cpp +++ b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp @@ -18,7 +18,7 @@ // @author raver119@gmail.com // -#include +#include namespace sd { diff --git a/libnd4j/include/graph/execution/impl/LogicMerge.cpp b/libnd4j/include/graph/logic/impl/LogicMerge.cpp similarity index 99% rename from libnd4j/include/graph/execution/impl/LogicMerge.cpp rename to libnd4j/include/graph/logic/impl/LogicMerge.cpp index 9d032a93f110..6d374005dee7 100644 --- a/libnd4j/include/graph/execution/impl/LogicMerge.cpp +++ b/libnd4j/include/graph/logic/impl/LogicMerge.cpp @@ -18,7 +18,7 @@ // Created by raver119 on 30.01.18. // -#include +#include #include namespace sd { diff --git a/libnd4j/include/graph/execution/impl/LogicNextIteration.cpp b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp similarity index 97% rename from libnd4j/include/graph/execution/impl/LogicNextIteration.cpp rename to libnd4j/include/graph/logic/impl/LogicNextIteration.cpp index fb7eaa513872..0765eb4ee783 100644 --- a/libnd4j/include/graph/execution/impl/LogicNextIteration.cpp +++ b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp @@ -18,7 +18,7 @@ // @author raver119@gmail.com // -#include +#include namespace sd { diff --git a/libnd4j/include/graph/execution/impl/LogicReturn.cpp b/libnd4j/include/graph/logic/impl/LogicReturn.cpp similarity index 98% rename from libnd4j/include/graph/execution/impl/LogicReturn.cpp rename to libnd4j/include/graph/logic/impl/LogicReturn.cpp index c9dbafd6d026..2e8aacb812c7 100644 --- a/libnd4j/include/graph/execution/impl/LogicReturn.cpp +++ b/libnd4j/include/graph/logic/impl/LogicReturn.cpp @@ -18,7 +18,7 @@ // Created by raver119 on 28.10.2017. // -#include "graph/execution/LogicReturn.h" +#include "graph/logic/LogicReturn.h" #include #include diff --git a/libnd4j/include/graph/execution/impl/LogicScope.cpp b/libnd4j/include/graph/logic/impl/LogicScope.cpp similarity index 96% rename from libnd4j/include/graph/execution/impl/LogicScope.cpp rename to libnd4j/include/graph/logic/impl/LogicScope.cpp index 1773aa6ea766..5319397d6969 100644 --- a/libnd4j/include/graph/execution/impl/LogicScope.cpp +++ b/libnd4j/include/graph/logic/impl/LogicScope.cpp @@ -18,7 +18,7 @@ // Created by raver119 on 20.10.2017. // -#include +#include #include diff --git a/libnd4j/include/graph/execution/impl/LogicSwitch.cpp b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp similarity index 99% rename from libnd4j/include/graph/execution/impl/LogicSwitch.cpp rename to libnd4j/include/graph/logic/impl/LogicSwitch.cpp index 1089046a3546..6d4bdf99fe70 100644 --- a/libnd4j/include/graph/execution/impl/LogicSwitch.cpp +++ b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp @@ -19,7 +19,7 @@ // #include -#include +#include #include #include diff --git a/libnd4j/include/graph/execution/impl/LogicWhile.cpp b/libnd4j/include/graph/logic/impl/LogicWhile.cpp similarity index 97% rename from libnd4j/include/graph/execution/impl/LogicWhile.cpp rename to libnd4j/include/graph/logic/impl/LogicWhile.cpp index 1dfd3aaf2006..0ce0039a8b57 100644 --- a/libnd4j/include/graph/execution/impl/LogicWhile.cpp +++ b/libnd4j/include/graph/logic/impl/LogicWhile.cpp @@ -18,10 +18,10 @@ // Created by raver119 on 20.10.2017. // -#include -#include +#include +#include #include -#include +#include #include diff --git a/libnd4j/include/legacy/NativeOps.h b/libnd4j/include/legacy/NativeOps.h index 6060d01ca16d..7351fc0fc394 100755 --- a/libnd4j/include/legacy/NativeOps.h +++ b/libnd4j/include/legacy/NativeOps.h @@ -73,7 +73,7 @@ bool verbose = false; #include #include #include -#include +#include #include #include #include From 49469d5c1491309c612c03f4fb8224a68666f328 Mon Sep 17 00:00:00 2001 From: raver119 Date: Sat, 28 Mar 2020 19:08:19 +0300 Subject: [PATCH 036/233] few more files moved around Signed-off-by: raver119 --- libnd4j/include/array/impl/NDArrayFactory.cpp | 2 +- libnd4j/include/graph/{ => execution}/GraphExecutioner.h | 0 libnd4j/include/graph/{ => execution}/impl/GraphExecutioner.cpp | 2 +- libnd4j/include/graph/impl/GraphHolder.cpp | 2 +- libnd4j/include/graph/logic/impl/LogicConditional.cpp | 2 +- libnd4j/include/graph/logic/impl/LogicSwitch.cpp | 2 +- libnd4j/include/graph/logic/impl/LogicWhile.cpp | 2 +- libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp | 2 +- libnd4j/include/legacy/cpu/NativeOps.cpp | 2 +- libnd4j/tests_cpu/layers_tests/ConditionalTests.cpp | 2 +- libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp | 2 +- libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp | 1 - libnd4j/tests_cpu/layers_tests/OneOffTests.cpp | 2 +- libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp | 2 +- libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp | 2 +- libnd4j/tests_cpu/layers_tests/testlayers.h | 2 +- 16 files changed, 14 insertions(+), 15 deletions(-) rename libnd4j/include/graph/{ => execution}/GraphExecutioner.h (100%) rename libnd4j/include/graph/{ => execution}/impl/GraphExecutioner.cpp (99%) diff --git a/libnd4j/include/array/impl/NDArrayFactory.cpp b/libnd4j/include/array/impl/NDArrayFactory.cpp index 128f032faef7..d67514236141 100644 --- a/libnd4j/include/array/impl/NDArrayFactory.cpp +++ b/libnd4j/include/array/impl/NDArrayFactory.cpp @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include diff --git a/libnd4j/include/graph/GraphExecutioner.h b/libnd4j/include/graph/execution/GraphExecutioner.h similarity index 100% rename from libnd4j/include/graph/GraphExecutioner.h rename to libnd4j/include/graph/execution/GraphExecutioner.h diff --git a/libnd4j/include/graph/impl/GraphExecutioner.cpp b/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp similarity index 99% rename from libnd4j/include/graph/impl/GraphExecutioner.cpp rename to libnd4j/include/graph/execution/impl/GraphExecutioner.cpp index b03895a42efd..631f69c8e346 100644 --- a/libnd4j/include/graph/impl/GraphExecutioner.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/libnd4j/include/graph/impl/GraphHolder.cpp b/libnd4j/include/graph/impl/GraphHolder.cpp index c480508f5709..3eaf9c3b50be 100644 --- a/libnd4j/include/graph/impl/GraphHolder.cpp +++ b/libnd4j/include/graph/impl/GraphHolder.cpp @@ -19,7 +19,7 @@ // #include -#include +#include #include #include diff --git a/libnd4j/include/graph/logic/impl/LogicConditional.cpp b/libnd4j/include/graph/logic/impl/LogicConditional.cpp index cf0a34d4a43f..8b6af83a98fb 100644 --- a/libnd4j/include/graph/logic/impl/LogicConditional.cpp +++ b/libnd4j/include/graph/logic/impl/LogicConditional.cpp @@ -19,7 +19,7 @@ // #include -#include +#include #include #include diff --git a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp index 6d4bdf99fe70..85fd92191e46 100644 --- a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp +++ b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp @@ -20,7 +20,7 @@ #include #include -#include +#include #include namespace sd { diff --git a/libnd4j/include/graph/logic/impl/LogicWhile.cpp b/libnd4j/include/graph/logic/impl/LogicWhile.cpp index 0ce0039a8b57..bd75c82eab8f 100644 --- a/libnd4j/include/graph/logic/impl/LogicWhile.cpp +++ b/libnd4j/include/graph/logic/impl/LogicWhile.cpp @@ -20,7 +20,7 @@ #include #include -#include +#include #include #include diff --git a/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp b/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp index 03c2411e28d4..268b396e03b2 100644 --- a/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp +++ b/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp @@ -19,7 +19,7 @@ // #include -#include +#include namespace sd { namespace graph { diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp index 39d0e62133a0..db10cd7d9f77 100644 --- a/libnd4j/include/legacy/cpu/NativeOps.cpp +++ b/libnd4j/include/legacy/cpu/NativeOps.cpp @@ -23,7 +23,7 @@ #include #include "legacy/NativeOpExecutioner.h" #include -#include +#include #include #include #include diff --git a/libnd4j/tests_cpu/layers_tests/ConditionalTests.cpp b/libnd4j/tests_cpu/layers_tests/ConditionalTests.cpp index 00752ca0f9e7..15974723f580 100644 --- a/libnd4j/tests_cpu/layers_tests/ConditionalTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConditionalTests.cpp @@ -20,7 +20,7 @@ #include "testlayers.h" #include -#include +#include #include #include diff --git a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp index bdb8bde681f0..17e63f33f150 100644 --- a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include using namespace sd; diff --git a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp index 04e4a70e8cf3..0c6410aea22a 100644 --- a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp @@ -20,7 +20,6 @@ #include "testlayers.h" #include -#include #include using namespace sd; diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index acd7125f751e..34ebcfd9a026 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp b/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp index fe2f97bb6c0e..e2927aa0b205 100644 --- a/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp @@ -20,7 +20,7 @@ #include "testlayers.h" -#include +#include /* diff --git a/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp b/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp index e0d03731b5f3..7e7fb03737b4 100644 --- a/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp @@ -19,7 +19,7 @@ // #include "testlayers.h" -#include +#include #include #include diff --git a/libnd4j/tests_cpu/layers_tests/testlayers.h b/libnd4j/tests_cpu/layers_tests/testlayers.h index 9106223d861b..d63694dac94f 100644 --- a/libnd4j/tests_cpu/layers_tests/testlayers.h +++ b/libnd4j/tests_cpu/layers_tests/testlayers.h @@ -33,7 +33,7 @@ #include #include #include -#include +#include #include #include #include From c4f03871cf86ef03879eea3c95e8fcb5699cd61f Mon Sep 17 00:00:00 2001 From: raver119 Date: Sat, 28 Mar 2020 19:34:14 +0300 Subject: [PATCH 037/233] one more method Signed-off-by: raver119 --- .../graph/execution/GraphExecutioner.h | 64 +- .../graph/execution/impl/GraphExecutioner.cpp | 900 +++++++++--------- 2 files changed, 487 insertions(+), 477 deletions(-) diff --git a/libnd4j/include/graph/execution/GraphExecutioner.h b/libnd4j/include/graph/execution/GraphExecutioner.h index 276c82a1c329..59da386ab9df 100644 --- a/libnd4j/include/graph/execution/GraphExecutioner.h +++ b/libnd4j/include/graph/execution/GraphExecutioner.h @@ -39,35 +39,43 @@ namespace sd { namespace graph { - - class SD_EXPORT GraphExecutioner { - protected: - - - public: - //static Nd4jStatus executeFlatNode(sd::graph::Graph *graph, sd::graph::Node *node, sd::graph::VariableSpace *variableSpace); - - static Nd4jStatus executeFlatNode(Graph *graph, Node *node, VariableSpace *variableSpace); - /** - * This method executes given Graph - * @return - */ - static Nd4jStatus execute(Graph *graph, VariableSpace *variableSpace = nullptr); - - - /** - * This method executes graph stored at given FlatBuffers pointer - * - * @param pointer Pointer to FlatBuffer - * @return pointer to FlatBuffer with result - */ - static sd::graph::ResultWrapper* executeFlatBuffer(Nd4jPointer pointer); - - static flatbuffers::Offset execute(Graph *graph, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request); - }; - - + * This class provides Graph execution functionality + */ + class SD_EXPORT GraphExecutioner { + private: + protected: + + public: + GraphExecutioner() = default; + virtual ~GraphExecutioner() = default; + + virtual Nd4jStatus execute(const OptimizedGraph &graph); + + /** + * TODO: REMOVE ALL METHODS BELOW + */ + //static Nd4jStatus executeFlatNode(sd::graph::Graph *graph, sd::graph::Node *node, sd::graph::VariableSpace *variableSpace); + + static Nd4jStatus executeFlatNode(Graph *graph, Node *node, VariableSpace *variableSpace); + + /** + * This method executes given Graph + * @return + */ + static Nd4jStatus execute(Graph *graph, VariableSpace *variableSpace = nullptr); + + + /** + * This method executes graph stored at given FlatBuffers pointer + * + * @param pointer Pointer to FlatBuffer + * @return pointer to FlatBuffer with result + */ + static sd::graph::ResultWrapper* executeFlatBuffer(Nd4jPointer pointer); + + static flatbuffers::Offset execute(Graph *graph, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request); + }; } } diff --git a/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp b/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp index 631f69c8e346..83237961c0f0 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp @@ -57,596 +57,598 @@ #include namespace sd{ -namespace graph { + namespace graph { -/** - * This method executes given Node (as in Op within Node) - * - * Basically it just does DeclarableOp::execute(Block), and ops to their job. However, there are some additional functionality. - * - * @param graph - Graph instance pointer - * @param node - Node instance pointer, which will be executed - * @param variableSpace - VariableSpace instance pointer - varspace specific to current Thread/Session - * @return - */ - Nd4jStatus GraphExecutioner::executeFlatNode(Graph *graph, Node *node, VariableSpace *variableSpace) { - OpType opType = node->opType(); - int opNum = node->opNum(); -// std::string opName = *(node->getCustomOp()->getOpName()); - - if (opType == OpType_BOOLEAN) { - nd4j_debug("Executing boolean graph node_%i", node->id()); - } else if (opType == OpType_LOGIC) { - nd4j_debug("Executing logic graph node_%i", node->id()); - } else if (opType == OpType_GRAPH) { - nd4j_debug("Executing embedded graph node_%i", node->id()); - } else if (opType != OpType_CUSTOM) { - nd4j_debug("Executing node_%i{%i}\n", node->id(), opNum); - } else { - nd4j_debug("Executing node_%i{%s}\n", node->id(), node->getCustomOp()->getOpName()->c_str()); - } - - Context context(node->getContextPrototype(), variableSpace); - - if (sd::Environment::getInstance()->isDebugAndVerbose()) { - //nd4j_debug("Input variables: %i\n", node->input()->size()); - printf(" Inputs: {"); - for (int e = 0; e < node->input()->size(); e++) { - printf("[%i:%i]", node->input()->at(e).first, node->input()->at(e).second); - - if (e < node->input()->size() - 1) - printf(", "); + Nd4jStatus GraphExecutioner::execute(const OptimizedGraph &graph) { + return Status::CODE(500, "Not implemented yet :)"); } - printf("}\n"); - fflush(stdout); - } - - if (node->id() == 13) - nd4j_debug("",""); - - // if true - this is special case: Graph-in-Graph. - if (node->hasGraphEmbedded()) { - auto embedded = node->getGraph(); /** - * basically, we should do following things here: - * 1) fill embedded graph with input variables from this graph, if anything should be filled in - * 2) invoke embedded graph - * 3) announce its results as corresponding output variables in current VariableSpace + * This method executes given Node (as in Op within Node) + * + * Basically it just does DeclarableOp::execute(Block), and ops to their job. However, there are some additional functionality. + * + * @param graph - Graph instance pointer + * @param node - Node instance pointer, which will be executed + * @param variableSpace - VariableSpace instance pointer - varspace specific to current Thread/Session + * @return */ + Nd4jStatus GraphExecutioner::executeFlatNode(Graph *graph, Node *node, VariableSpace *variableSpace) { + OpType opType = node->opType(); + int opNum = node->opNum(); + // std::string opName = *(node->getCustomOp()->getOpName()); + + if (opType == OpType_BOOLEAN) { + nd4j_debug("Executing boolean graph node_%i", node->id()); + } else if (opType == OpType_LOGIC) { + nd4j_debug("Executing logic graph node_%i", node->id()); + } else if (opType == OpType_GRAPH) { + nd4j_debug("Executing embedded graph node_%i", node->id()); + } else if (opType != OpType_CUSTOM) { + nd4j_debug("Executing node_%i{%i}\n", node->id(), opNum); + } else { + nd4j_debug("Executing node_%i{%s}\n", node->id(), node->getCustomOp()->getOpName()->c_str()); + } - // enforcing IMPLICIT mode. or not... should we try to be smarter then user? - //embedded->getExecutorConfiguration()->_outputMode = OutputMode_IMPLICIT; + Context context(node->getContextPrototype(), variableSpace); - if (node->input()->size() != embedded->numberOfPlaceholders()) { - nd4j_debug("Placeholders amount mismatch: %i expected, and %i available\n",node->input()->size(), embedded->numberOfPlaceholders()); - return ND4J_STATUS_BAD_INPUT; - } + if (sd::Environment::getInstance()->isDebugAndVerbose()) { + //nd4j_debug("Input variables: %i\n", node->input()->size()); + printf(" Inputs: {"); + for (int e = 0; e < node->input()->size(); e++) { + printf("[%i:%i]", node->input()->at(e).first, node->input()->at(e).second); - // we need to propagate required variables to the embedded graph - ResultSet deletables; - int cnt = 0; - for (Variable* v: *embedded->getPlaceholders()) { - if (v->getName() != nullptr && v->getName()->size() > 0) { - - // trying symbolic lookup first - if (variableSpace->hasVariable(v->getName())) { - // symbolic feeder - auto array = variableSpace->getVariable(v->getName())->getNDArray(); - auto vr = new NDArray(array->dup()); -// deletables.push_back(vr); - v->setNDArray(vr); - } else { - nd4j_debug("Can't find variable [%s] in parent graph...", v->getName()->c_str()); - return ND4J_STATUS_BAD_INPUT; - //throw "Can't find desired variable"; + if (e < node->input()->size() - 1) + printf(", "); } - } else { - // if we're not using symbolic lookup - we'll use sequential approach then - auto p = node->input()->at(cnt); - auto array = variableSpace->getVariable(p)->getNDArray(); - auto vr = new NDArray(array->dup()); - //deletables.push_back(vr); - v->setNDArray(vr); + printf("}\n"); + fflush(stdout); } - cnt++; - } + if (node->id() == 13) + nd4j_debug("",""); - // executing embedded graph as independent one - Nd4jStatus status = GraphExecutioner::execute(embedded); - if (status != ND4J_STATUS_OK) - return status; + // if true - this is special case: Graph-in-Graph. + if (node->hasGraphEmbedded()) { + auto embedded = node->getGraph(); - // now we should migrate its results to this node, as its own outputs - cnt = 0; - auto outputs = embedded->fetchOutputs(); + /** + * basically, we should do following things here: + * 1) fill embedded graph with input variables from this graph, if anything should be filled in + * 2) invoke embedded graph + * 3) announce its results as corresponding output variables in current VariableSpace + */ - for (auto v: *outputs){ - NDArray *array = v->getNDArray(); - v->setNDArray(nullptr); + // enforcing IMPLICIT mode. or not... should we try to be smarter then user? + //embedded->getExecutorConfiguration()->_outputMode = OutputMode_IMPLICIT; - std::pair pair(node->id(), cnt++); + if (node->input()->size() != embedded->numberOfPlaceholders()) { + nd4j_debug("Placeholders amount mismatch: %i expected, and %i available\n",node->input()->size(), embedded->numberOfPlaceholders()); + return ND4J_STATUS_BAD_INPUT; + } - auto var = variableSpace->getVariable(pair); + // we need to propagate required variables to the embedded graph + ResultSet deletables; + int cnt = 0; + for (Variable* v: *embedded->getPlaceholders()) { + if (v->getName() != nullptr && v->getName()->size() > 0) { + + // trying symbolic lookup first + if (variableSpace->hasVariable(v->getName())) { + // symbolic feeder + auto array = variableSpace->getVariable(v->getName())->getNDArray(); + auto vr = new NDArray(array->dup()); + // deletables.push_back(vr); + v->setNDArray(vr); + } else { + nd4j_debug("Can't find variable [%s] in parent graph...", v->getName()->c_str()); + return ND4J_STATUS_BAD_INPUT; + //throw "Can't find desired variable"; + } + } else { + // if we're not using symbolic lookup - we'll use sequential approach then + auto p = node->input()->at(cnt); + auto array = variableSpace->getVariable(p)->getNDArray(); + auto vr = new NDArray(array->dup()); + //deletables.push_back(vr); + v->setNDArray(vr); + } - //nd4j_printf("HasArray: [%i]; Removable: [%i]\n", var->hasNDArray(), var->isRemovable()); - var->setNDArray(array); - var->markRemovable(true); - } - deletables.size(); - delete outputs; - nd4j_debug("Embedded graph execution finished. %i variable(s) migrated\n", cnt); - - } else if (node->hasCustomOp()) { - // now, if we have something to execute - lets just execute it. - auto status = node->getCustomOp()->execute(&context); - if (status != ND4J_STATUS_OK) - return status; - - // propagate variables - if (node->hasExternalOutputs()) { - for (auto v: *node->output()) { - if (variableSpace->hasExternalVariable(v.first)) { - variableSpace->getVariable(v.first)->getNDArray()->assign(variableSpace->getVariable(node->id())->getNDArray()); + cnt++; } - } - } - - return status; - } - return ND4J_STATUS_OK; -} + // executing embedded graph as independent one + Nd4jStatus status = GraphExecutioner::execute(embedded); + if (status != ND4J_STATUS_OK) + return status; -/** - * This method executes given Graph instance, and returns error code. - * - * @param graph - * @return one of error codes defined in pointercast.h - */ -Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace) { - auto __variableSpace = variableSpace == nullptr ? graph->getVariableSpace() : variableSpace; - - bool tempFlow = false; - if (__variableSpace->flowPath() == nullptr) { - tempFlow = true; - __variableSpace->setFlowPath(new FlowPath()); - } - auto flowPath = __variableSpace->flowPath(); + // now we should migrate its results to this node, as its own outputs + cnt = 0; + auto outputs = embedded->fetchOutputs(); - Nd4jLong tb0 = Environment::getInstance()->isProfiling() ? GraphProfile::currentTime() : 0L; - graph->buildGraph(); + for (auto v: *outputs){ + NDArray *array = v->getNDArray(); + v->setNDArray(nullptr); - auto footprintForward = sd::memory::MemoryRegistrator::getInstance()->getGraphMemoryFootprint(graph->hashCode()); - if (footprintForward > 0) { - if (__variableSpace->launchContext()->getWorkspace() != nullptr) { - // this method will work only if current workspace size is smaller then proposed value - nd4j_debug("Setting workspace to %lld bytes\n", footprintForward); - __variableSpace->launchContext()->getWorkspace()->expandTo(footprintForward); - } - } + std::pair pair(node->id(), cnt++); - // optionally saving graph build time - if (Environment::getInstance()->isProfiling()) - flowPath->profile()->setBuildTime(GraphProfile::relativeTime(tb0)); + auto var = variableSpace->getVariable(pair); - Nd4jLong timeStart = Environment::getInstance()->isProfiling() ? GraphProfile::currentTime() : 0L; + //nd4j_printf("HasArray: [%i]; Removable: [%i]\n", var->hasNDArray(), var->isRemovable()); + var->setNDArray(array); + var->markRemovable(true); + } + deletables.size(); + delete outputs; + nd4j_debug("Embedded graph execution finished. %i variable(s) migrated\n", cnt); - bool pe = graph->getExecutorConfiguration()->_executionMode == ExecutionMode_AUTO; + } else if (node->hasCustomOp()) { + // now, if we have something to execute - lets just execute it. + auto status = node->getCustomOp()->execute(&context); + if (status != ND4J_STATUS_OK) + return status; + // propagate variables + if (node->hasExternalOutputs()) { + for (auto v: *node->output()) { + if (variableSpace->hasExternalVariable(v.first)) { + variableSpace->getVariable(v.first)->getNDArray()->assign(variableSpace->getVariable(node->id())->getNDArray()); + } + } + } - // basically if at some point code diverges, code branch might be _DISABLED_, and all nodes within that branch will be disabled as well + return status; + } + return ND4J_STATUS_OK; + } - std::deque frames; - bool inFrame = false; - bool leftFrame = false; - auto nodeTime = GraphProfile::currentTime(); - int lastId = -10000000; - Nd4jLong exec_counter = 0; - // we loop through op layers here - for (int l = 0; l < (int) graph->getOnion()->size(); l++) { - int layerSize = graph->getOnion()->count(l) == 1 ? graph->getOnion()->at(l)->size() : 0; + /** + * This method executes given Graph instance, and returns error code. + * + * @param graph + * @return one of error codes defined in pointercast.h + */ + Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace) { + auto __variableSpace = variableSpace == nullptr ? graph->getVariableSpace() : variableSpace; - int n = 0; -// this omp block will probably never be the case - for (; n < layerSize; n++) { - if (++exec_counter > 10000) { - l = graph->getOnion()->size(); - return Status::THROW("Early termination hit"); + bool tempFlow = false; + if (__variableSpace->flowPath() == nullptr) { + tempFlow = true; + __variableSpace->setFlowPath(new FlowPath()); } + auto flowPath = __variableSpace->flowPath(); - Node* node = graph->getOnion()->at(l)->at(n); + Nd4jLong tb0 = Environment::getInstance()->isProfiling() ? GraphProfile::currentTime() : 0L; + graph->buildGraph(); + auto footprintForward = sd::memory::MemoryRegistrator::getInstance()->getGraphMemoryFootprint(graph->hashCode()); + if (footprintForward > 0) { + if (__variableSpace->launchContext()->getWorkspace() != nullptr) { + // this method will work only if current workspace size is smaller then proposed value + nd4j_debug("Setting workspace to %lld bytes\n", footprintForward); + __variableSpace->launchContext()->getWorkspace()->expandTo(footprintForward); + } + } + + // optionally saving graph build time if (Environment::getInstance()->isProfiling()) - flowPath->profile()->nodeById(node->id(), node->name()->c_str()); + flowPath->profile()->setBuildTime(GraphProfile::relativeTime(tb0)); - if (lastId != node->id() && Environment::getInstance()->isProfiling()) { - if (lastId != -10000000) - flowPath->profile()->nodeById(lastId)->setTotalTime(GraphProfile::relativeTime(nodeTime)); + Nd4jLong timeStart = Environment::getInstance()->isProfiling() ? GraphProfile::currentTime() : 0L; - lastId = node->id(); - nodeTime = GraphProfile::currentTime(); - } + bool pe = graph->getExecutorConfiguration()->_executionMode == ExecutionMode_AUTO; - nd4j_debug("Step: %lld; Node: %i <%s>\n", exec_counter, node->id(), node->name()->c_str()); - // on first non-Exit node after loop we can rewind (if planned) - if (!(node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Exit)) { - // VALIDATED + // basically if at some point code diverges, code branch might be _DISABLED_, and all nodes within that branch will be disabled as well - // if we're out of frame - let's remove it from queue - if (leftFrame) { - auto frame_id = frames.back(); - frames.pop_back(); - flowPath->markFrameActive(frame_id, false); - flowPath->forgetFrame(frame_id); + std::deque frames; + bool inFrame = false; + bool leftFrame = false; - leftFrame = false; - } + auto nodeTime = GraphProfile::currentTime(); + int lastId = -10000000; + Nd4jLong exec_counter = 0; + // we loop through op layers here + for (int l = 0; l < (int) graph->getOnion()->size(); l++) { + int layerSize = graph->getOnion()->count(l) == 1 ? graph->getOnion()->at(l)->size() : 0; + int n = 0; + // this omp block will probably never be the case + for (; n < layerSize; n++) { + if (++exec_counter > 10000) { + l = graph->getOnion()->size(); + return Status::THROW("Early termination hit"); + } - // TODO: move inactivity check right here - bool shouldSkip = false; - if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Merge) { - // Merge node has own checkout logic + Node* node = graph->getOnion()->at(l)->at(n); - auto inputId0 = node->input()->at(0); - auto inputId1 = node->input()->at(1); + if (Environment::getInstance()->isProfiling()) + flowPath->profile()->nodeById(node->id(), node->name()->c_str()); - // Merge node can be skipped only both inputs are inactive - if (!flowPath->isNodeActive(inputId0.first) && !flowPath->isNodeActive(inputId1.first)) - shouldSkip = true; + if (lastId != node->id() && Environment::getInstance()->isProfiling()) { + if (lastId != -10000000) + flowPath->profile()->nodeById(lastId)->setTotalTime(GraphProfile::relativeTime(nodeTime)); - } else { - // let's check for input nodes, if they are disabled or contain divergents - for (int e = 0; e < node->input()->size(); e++) { - auto inputId = node->input()->at(e); + lastId = node->id(); + nodeTime = GraphProfile::currentTime(); + } - // not a node. skipping checks - if (graph->getMapped()->count(inputId.first) == 0) - continue; + nd4j_debug("Step: %lld; Node: %i <%s>\n", exec_counter, node->id(), node->name()->c_str()); - /** - * We can skip current node, in two cases: - * 1) If previous node was disabled - * 2) If previous node was divergent node (i.e. IF op) and code went other way - */ - Node *prevNode = graph->getMapped()->at(inputId.first); - if (!flowPath->isNodeActive(inputId.first)) { - shouldSkip = true; - flowPath->markNodeActive(node->id(), false); + // on first non-Exit node after loop we can rewind (if planned) + if (!(node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Exit)) { + // VALIDATED - nd4j_debug("Skipping Node_%i due to inactive input [%i]\n", node->id(), inputId.first); - break; + // if we're out of frame - let's remove it from queue + if (leftFrame) { + auto frame_id = frames.back(); + frames.pop_back(); + flowPath->markFrameActive(frame_id, false); + flowPath->forgetFrame(frame_id); + + leftFrame = false; + } + + + // TODO: move inactivity check right here + bool shouldSkip = false; + if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Merge) { + // Merge node has own checkout logic + + auto inputId0 = node->input()->at(0); + auto inputId1 = node->input()->at(1); - } else if (prevNode->isDivergencePoint()) { // literally checking for switch here - if (flowPath->branch(inputId.first) != inputId.second) { + // Merge node can be skipped only both inputs are inactive + if (!flowPath->isNodeActive(inputId0.first) && !flowPath->isNodeActive(inputId1.first)) shouldSkip = true; - flowPath->markNodeActive(node->id(), false); - nd4j_debug("Skipping Node_%i due to divergent branch [%i]\n", node->id(), - inputId.first); - break; + + } else { + // let's check for input nodes, if they are disabled or contain divergents + for (int e = 0; e < node->input()->size(); e++) { + auto inputId = node->input()->at(e); + + // not a node. skipping checks + if (graph->getMapped()->count(inputId.first) == 0) + continue; + + /** + * We can skip current node, in two cases: + * 1) If previous node was disabled + * 2) If previous node was divergent node (i.e. IF op) and code went other way + */ + Node *prevNode = graph->getMapped()->at(inputId.first); + if (!flowPath->isNodeActive(inputId.first)) { + shouldSkip = true; + flowPath->markNodeActive(node->id(), false); + + nd4j_debug("Skipping Node_%i due to inactive input [%i]\n", node->id(), inputId.first); + break; + + } else if (prevNode->isDivergencePoint()) { // literally checking for switch here + if (flowPath->branch(inputId.first) != inputId.second) { + shouldSkip = true; + flowPath->markNodeActive(node->id(), false); + nd4j_debug("Skipping Node_%i due to divergent branch [%i]\n", node->id(), + inputId.first); + break; + } + } } } - } - } - if (shouldSkip) - continue; - } + if (shouldSkip) + continue; + } - // we're propagating frameId here (but only if wasn't set earlier) - if (frames.size() > 0 && node->getFrameId() < 0) - node->setFrameId(frames.back()); + // we're propagating frameId here (but only if wasn't set earlier) + if (frames.size() > 0 && node->getFrameId() < 0) + node->setFrameId(frames.back()); - flowPath->markNodeActive(node->id(), true); + flowPath->markNodeActive(node->id(), true); - if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Enter) { - // Enter operation - // VALIDATED + if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Enter) { + // Enter operation + // VALIDATED - // we expect this node to have frameId set - auto frame_id = node->getFrameId(); + // we expect this node to have frameId set + auto frame_id = node->getFrameId(); - // new frame starts here - if (frames.size() == 0 || (frames.size() > 0 && frames.back() != frame_id)) { - flowPath->registerFrame(frame_id); - frames.emplace_back(frame_id); - inFrame = true; - } + // new frame starts here + if (frames.size() == 0 || (frames.size() > 0 && frames.back() != frame_id)) { + flowPath->registerFrame(frame_id); + frames.emplace_back(frame_id); + inFrame = true; + } - auto status = LogicExecutor::processNode(graph, node); - if (status != Status::OK()) - return status; + auto status = LogicExecutor::processNode(graph, node); + if (status != Status::OK()) + return status; - } else if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::NextIteration) { - /** - * NextIteration is special case: after successful execution of this op - we're changing execution position - */ - // VALIDATED - auto inputId = node->input()->at(0); + } else if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::NextIteration) { + /** + * NextIteration is special case: after successful execution of this op - we're changing execution position + */ + // VALIDATED + auto inputId = node->input()->at(0); - auto status = LogicExecutor::processNode(graph, node); - if (status != Status::OK()) - return status; + auto status = LogicExecutor::processNode(graph, node); + if (status != Status::OK()) + return status; - auto frame_id = frames.back(); + auto frame_id = frames.back(); - flowPath->markNodeActive(node->id(), true); - flowPath->markExecuted(node->id(), true); + flowPath->markNodeActive(node->id(), true); + flowPath->markExecuted(node->id(), true); - if (!flowPath->isRewindPlanned(frame_id)) { - auto nextLayer = node->getRewindLayer(); + if (!flowPath->isRewindPlanned(frame_id)) { + auto nextLayer = node->getRewindLayer(); - nd4j_debug("Node_%i planned rewind to Node_%i at [%i:%i]\n", node->id(), node->getRewindNode(), nextLayer.first, nextLayer.second); + nd4j_debug("Node_%i planned rewind to Node_%i at [%i:%i]\n", node->id(), node->getRewindNode(), nextLayer.first, nextLayer.second); - flowPath->planRewind(frame_id, true); - flowPath->setRewindPositionOnce(frame_id, nextLayer.first - 1); + flowPath->planRewind(frame_id, true); + flowPath->setRewindPositionOnce(frame_id, nextLayer.first - 1); - continue; - } + continue; + } - } else if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Exit) { - // Exit node is another special case: it can rewind executioner to specific point in graph - // VALIDATED + } else if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Exit) { + // Exit node is another special case: it can rewind executioner to specific point in graph + // VALIDATED - auto frame_id = frames.back(); + auto frame_id = frames.back(); - // if this loop frame wasn't activated - just skip it - if (!flowPath->isFrameActive(frame_id)) { - flowPath->markNodeActive(node->id(), false); + // if this loop frame wasn't activated - just skip it + if (!flowPath->isFrameActive(frame_id)) { + flowPath->markNodeActive(node->id(), false); - leftFrame = true; - continue; - } + leftFrame = true; + continue; + } - if (flowPath->isRewindPlanned(frame_id)) { - // just break loop here - l = flowPath->getRewindPosition(frame_id); - flowPath->setRewindPosition(frame_id, -1); - flowPath->planRewind(frame_id, false); + if (flowPath->isRewindPlanned(frame_id)) { + // just break loop here + l = flowPath->getRewindPosition(frame_id); + flowPath->setRewindPosition(frame_id, -1); + flowPath->planRewind(frame_id, false); - break; - } else { - // execute Exit node otherwise + break; + } else { + // execute Exit node otherwise - auto status = LogicExecutor::processNode(graph, node); - if (status != Status::OK()) - return status; + auto status = LogicExecutor::processNode(graph, node); + if (status != Status::OK()) + return status; - leftFrame = true; - } + leftFrame = true; + } - } else if (node->opType() == OpType_LOGIC) { - /** - * If this LOGIC op, we'll use another execution model here - */ - auto status = LogicExecutor::processNode(graph, node); + } else if (node->opType() == OpType_LOGIC) { + /** + * If this LOGIC op, we'll use another execution model here + */ + auto status = LogicExecutor::processNode(graph, node); - if (status != Status::OK()) - return status; - } else { + if (status != Status::OK()) + return status; + } else { - auto timeStart = std::chrono::system_clock::now(); + auto timeStart = std::chrono::system_clock::now(); - // actual node execution happens right here - Nd4jStatus status = executeFlatNode(graph, node, __variableSpace); + // actual node execution happens right here + Nd4jStatus status = executeFlatNode(graph, node, __variableSpace); - auto timeEnd = std::chrono::system_clock::now(); + auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - flowPath->setOuterTime(node->id(), outerTime); + flowPath->setOuterTime(node->id(), outerTime); - if (status != ND4J_STATUS_OK) - return status; + if (status != ND4J_STATUS_OK) + return status; - // here we should handle divergent ops, and disable nodes accordingly - if (node->isDivergencePoint()) { - auto activeBranch = flowPath->branch(node->id()); - nd4j_debug("Active branch at node [%i]: %i\n", node->id(), activeBranch); + // here we should handle divergent ops, and disable nodes accordingly + if (node->isDivergencePoint()) { + auto activeBranch = flowPath->branch(node->id()); + nd4j_debug("Active branch at node [%i]: %i\n", node->id(), activeBranch); - // now we skip all branches except of this active one - } + // now we skip all branches except of this active one + } - if (sd::Environment::getInstance()->isDebugAndVerbose()) { - - if (__variableSpace->getVariable(node->id())->hasNDArray()) { - auto array = __variableSpace->getVariable(node->id())->getNDArray(); - auto shape = ShapeUtils::shapeAsString(array); - auto values = array->asIndexedString(16); - auto type = DataTypeUtils::asString(array->dataType()); - nd4j_debug("node_%i finished. result shape: %s; data type: %s; first values: %s\n", node->id(), shape.c_str(), type.c_str(), values.c_str()); - } else if (__variableSpace->getVariable(node->id())->hasNDArrayList()) { - auto list = __variableSpace->getVariable(node->id())->hasNDArrayList() ? __variableSpace->getVariable(node->id())->getNDArrayList() : nullptr; - nd4j_debug("node_% is ListOp, skipping evaluation", node->id()); - } else { - nd4j_debug("node_% is Unknown: has no NDArray or NDArrayList", node->id()); + if (sd::Environment::getInstance()->isDebugAndVerbose()) { + + if (__variableSpace->getVariable(node->id())->hasNDArray()) { + auto array = __variableSpace->getVariable(node->id())->getNDArray(); + auto shape = ShapeUtils::shapeAsString(array); + auto values = array->asIndexedString(16); + auto type = DataTypeUtils::asString(array->dataType()); + nd4j_debug("node_%i finished. result shape: %s; data type: %s; first values: %s\n", node->id(), shape.c_str(), type.c_str(), values.c_str()); + } else if (__variableSpace->getVariable(node->id())->hasNDArrayList()) { + auto list = __variableSpace->getVariable(node->id())->hasNDArrayList() ? __variableSpace->getVariable(node->id())->getNDArrayList() : nullptr; + nd4j_debug("node_% is ListOp, skipping evaluation", node->id()); + } else { + nd4j_debug("node_% is Unknown: has no NDArray or NDArrayList", node->id()); + } + } } + + // if node was executed - tag it as active + flowPath->markExecuted(node->id(), true); } } - // if node was executed - tag it as active - flowPath->markExecuted(node->id(), true); - } - } - - // optionally saving execution time - if (Environment::getInstance()->isProfiling()) { - flowPath->profile()->nodeById(lastId)->setTotalTime(GraphProfile::relativeTime(nodeTime)); - flowPath->profile()->setExecutionTime(GraphProfile::relativeTime(timeStart)); - //flowPath->profile().printOut(); - } - - // saving memory footprint for current run - if (__variableSpace->launchContext()->getWorkspace() != nullptr) { - auto m = __variableSpace->launchContext()->getWorkspace()->getAllocatedSize(); - auto h = graph->hashCode(); - sd::memory::MemoryRegistrator::getInstance()->setGraphMemoryFootprintIfGreater(h, m); - } - - if (tempFlow) { - delete flowPath; - __variableSpace->setFlowPath(nullptr); - } + // optionally saving execution time + if (Environment::getInstance()->isProfiling()) { + flowPath->profile()->nodeById(lastId)->setTotalTime(GraphProfile::relativeTime(nodeTime)); + flowPath->profile()->setExecutionTime(GraphProfile::relativeTime(timeStart)); + //flowPath->profile().printOut(); + } - return Status::OK(); -} + // saving memory footprint for current run + if (__variableSpace->launchContext()->getWorkspace() != nullptr) { + auto m = __variableSpace->launchContext()->getWorkspace()->getAllocatedSize(); + auto h = graph->hashCode(); + sd::memory::MemoryRegistrator::getInstance()->setGraphMemoryFootprintIfGreater(h, m); + } -/** - * This method is provided for IPC: - * 1) it accepts pointer to FlatBuffers buffer - * 2) restores Graph from it - * 3) Executes this Graph - * 4) Packs execution results into FlatBuffers (FlatResults instance) - * 5) Returns pointer to FlatBuffer results buffer - * - */ - sd::graph::ResultWrapper* GraphExecutioner::executeFlatBuffer(Nd4jPointer pointer) { - uint8_t *buffer = reinterpret_cast(pointer); + if (tempFlow) { + delete flowPath; + __variableSpace->setFlowPath(nullptr); + } - // nd4j_debug("Trying to restore graph\n", 0); + return Status::OK(); + } - auto restoredGraph = GetFlatGraph(buffer); + /** + * This method is provided for IPC: + * 1) it accepts pointer to FlatBuffers buffer + * 2) restores Graph from it + * 3) Executes this Graph + * 4) Packs execution results into FlatBuffers (FlatResults instance) + * 5) Returns pointer to FlatBuffer results buffer + * + */ + sd::graph::ResultWrapper* GraphExecutioner::executeFlatBuffer(Nd4jPointer pointer) { + uint8_t *buffer = reinterpret_cast(pointer); - // nd4j_debug("Graph restored\n", 0); + // nd4j_debug("Trying to restore graph\n", 0); - // converting FlatGraph to internal representation - auto nativeGraph = new Graph(restoredGraph); + auto restoredGraph = GetFlatGraph(buffer); - if (Environment::getInstance()->isDebugAndVerbose()) { - nativeGraph->printOut(); - } + // nd4j_debug("Graph restored\n", 0); - FlowPath flowPath; - nativeGraph->getVariableSpace()->setFlowPath(&flowPath); + // converting FlatGraph to internal representation + auto nativeGraph = new Graph(restoredGraph); + if (Environment::getInstance()->isDebugAndVerbose()) { + nativeGraph->printOut(); + } - // nd4j_debug("Going to execute graph\n", 0); + FlowPath flowPath; + nativeGraph->getVariableSpace()->setFlowPath(&flowPath); - // executing internal representation - auto status = GraphExecutioner::execute(nativeGraph); - if (status != ND4J_STATUS_OK) { - nd4j_printf("Graph execution failed with status: [%i]\n", status) - return nullptr; - } - // nd4j_debug("Building output...\n", 0); + // nd4j_debug("Going to execute graph\n", 0); - flatbuffers::FlatBufferBuilder builder(1024); + // executing internal representation + auto status = GraphExecutioner::execute(nativeGraph); + if (status != ND4J_STATUS_OK) { + nd4j_printf("Graph execution failed with status: [%i]\n", status) + return nullptr; + } - // fetching time reports - std::vector> timings_vector; - for (int e = 0; e < (int) nativeGraph->getAllNodes()->size(); e++) { - Node *node = nativeGraph->getAllNodes()->at(e); + // nd4j_debug("Building output...\n", 0); - if (node->getContextPrototype() == nullptr) - continue; + flatbuffers::FlatBufferBuilder builder(1024); - auto pair = CreateLongPair(builder, flowPath.outerTime(node->id()), flowPath.innerTime(node->id())); - if (node->getName() != nullptr) { - auto name = builder.CreateString(node->getName()->c_str()); - auto fr = CreateFlatTiming(builder, node->id(), name, pair); - timings_vector.push_back(fr); - } else { - auto fr = CreateFlatTiming(builder, node->id(), 0, pair); - timings_vector.push_back(fr); - } - } + // fetching time reports + std::vector> timings_vector; + for (int e = 0; e < (int) nativeGraph->getAllNodes()->size(); e++) { + Node *node = nativeGraph->getAllNodes()->at(e); + if (node->getContextPrototype() == nullptr) + continue; - // now, we'll prepare output, depending on given outputmode - auto outputs = nativeGraph->fetchOutputs(); - auto size = static_cast(outputs->size()); - int arrays = 0; - std::vector> variables_vector; - for (int e = 0; e < size; e++) { - auto var = outputs->at(e); + auto pair = CreateLongPair(builder, flowPath.outerTime(node->id()), flowPath.innerTime(node->id())); + if (node->getName() != nullptr) { + auto name = builder.CreateString(node->getName()->c_str()); + auto fr = CreateFlatTiming(builder, node->id(), name, pair); + timings_vector.push_back(fr); + } else { + auto fr = CreateFlatTiming(builder, node->id(), 0, pair); + timings_vector.push_back(fr); + } + } - // FIXME: we want to export multi-output nodes as well - // FIXME: we want to export NDArrayList and skip nodes without outputs - if (!var->hasNDArray()) - continue; + // now, we'll prepare output, depending on given outputmode + auto outputs = nativeGraph->fetchOutputs(); + auto size = static_cast(outputs->size()); + int arrays = 0; + std::vector> variables_vector; + for (int e = 0; e < size; e++) { + auto var = outputs->at(e); - auto array = var->getNDArray(); + // FIXME: we want to export multi-output nodes as well + // FIXME: we want to export NDArrayList and skip nodes without outputs + if (!var->hasNDArray()) + continue; - auto fArray = FlatUtils::toFlatArray(builder, *array); - auto fName = builder.CreateString(*(var->getName())); - auto id = CreateIntPair(builder, var->id(), var->index()); + auto array = var->getNDArray(); - auto fv = CreateFlatVariable(builder, id, fName, static_cast(array->dataType()), 0, fArray); + auto fArray = FlatUtils::toFlatArray(builder, *array); - variables_vector.push_back(fv); - arrays++; - } + auto fName = builder.CreateString(*(var->getName())); + auto id = CreateIntPair(builder, var->id(), var->index()); - nd4j_debug("Returning %i variables back\n", arrays); + auto fv = CreateFlatVariable(builder, id, fName, static_cast(array->dataType()), 0, fArray); - auto varTimings = builder.CreateVector(timings_vector); - auto varVectors = builder.CreateVector(variables_vector); - auto result = CreateFlatResult(builder, restoredGraph->id(), varVectors, varTimings); - builder.Finish(result); + variables_vector.push_back(fv); + arrays++; + } - // we might want to keep this graph for future - delete outputs; - delete nativeGraph; + nd4j_debug("Returning %i variables back\n", arrays); - char* res = new char[builder.GetSize()]; - memcpy(res, builder.GetBufferPointer(), builder.GetSize()); + auto varTimings = builder.CreateVector(timings_vector); + auto varVectors = builder.CreateVector(variables_vector); + auto result = CreateFlatResult(builder, restoredGraph->id(), varVectors, varTimings); + builder.Finish(result); - nd4j_debug("Buffer size: %lld\n", static_cast(builder.GetSize())); + // we might want to keep this graph for future + delete outputs; + delete nativeGraph; - return new ResultWrapper(builder.GetSize(), reinterpret_cast(res)); -} + char* res = new char[builder.GetSize()]; + memcpy(res, builder.GetBufferPointer(), builder.GetSize()); -flatbuffers::Offset GraphExecutioner::execute(Graph *graph, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request) { - ExecutionResult result; - auto varSpace = graph->getVariableSpace(); + nd4j_debug("Buffer size: %lld\n", static_cast(builder.GetSize())); - if (request != nullptr && request->variables() != nullptr) { - auto vars = request->variables(); - for (int e = 0; e < vars->size(); e++) { - auto fv = vars->Get(e); - auto v = new Variable(fv); - varSpace->replaceVariable(v); + return new ResultWrapper(builder.GetSize(), reinterpret_cast(res)); } - } - if (Environment::getInstance()->isDebugAndVerbose()) - graph->printOut(); + flatbuffers::Offset GraphExecutioner::execute(Graph *graph, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request) { + ExecutionResult result; + auto varSpace = graph->getVariableSpace(); - auto status = GraphExecutioner::execute(graph); - if (status != sd::Status::OK()) - throw graph_execution_exception(request->id()); + if (request != nullptr && request->variables() != nullptr) { + auto vars = request->variables(); + for (int e = 0; e < vars->size(); e++) { + auto fv = vars->Get(e); + auto v = new Variable(fv); + varSpace->replaceVariable(v); + } + } - auto outputs = graph->fetchOutputs(); + if (Environment::getInstance()->isDebugAndVerbose()) + graph->printOut(); - if (outputs->size() == 0) - throw no_results_exception(request->id()); + auto status = GraphExecutioner::execute(graph); + if (status != sd::Status::OK()) + throw graph_execution_exception(request->id()); + auto outputs = graph->fetchOutputs(); - for (auto v: *outputs) { - result.emplace_back(v); - } + if (outputs->size() == 0) + throw no_results_exception(request->id()); - auto t = result.asFlatResult(builder); - delete outputs; + for (auto v: *outputs) { + result.emplace_back(v); + } - return t; -} + auto t = result.asFlatResult(builder); + delete outputs; + return t; + } } } \ No newline at end of file From 8d7e47081078c91704363877b007c9c81020ea7e Mon Sep 17 00:00:00 2001 From: raver119 Date: Sat, 28 Mar 2020 20:05:47 +0300 Subject: [PATCH 038/233] GraphExecutor Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 3 +- .../graph/execution/GraphExecutioner.h | 13 +----- .../include/graph/execution/GraphExecutor.h | 40 +++++++++++++++++++ .../graph/execution/impl/GraphExecutioner.cpp | 5 --- .../graph/execution/impl/GraphExecutor.cpp | 29 ++++++++++++++ libnd4j/include/graph/impl/Graph.cpp | 2 +- 6 files changed, 73 insertions(+), 19 deletions(-) create mode 100644 libnd4j/include/graph/execution/GraphExecutor.h create mode 100644 libnd4j/include/graph/execution/impl/GraphExecutor.cpp diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 9aa99032c697..1a9149ce04d2 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -36,6 +36,7 @@ #include #include #include +#include #include namespace sd { @@ -251,7 +252,7 @@ namespace sd { * @param dictionary * @return */ - std::map execute(const std::map &dictionary = {}, const std::vector &outputs = {}) const; + std::map execute(const std::map &dictionary = {}, const std::vector &outputs = {}, const GraphExecutor &executor = GraphExecutor()) const; }; FORCEINLINE std::vector* Graph::nodes() { diff --git a/libnd4j/include/graph/execution/GraphExecutioner.h b/libnd4j/include/graph/execution/GraphExecutioner.h index 59da386ab9df..785ca3a2658d 100644 --- a/libnd4j/include/graph/execution/GraphExecutioner.h +++ b/libnd4j/include/graph/execution/GraphExecutioner.h @@ -40,21 +40,10 @@ namespace sd { namespace graph { /** - * This class provides Graph execution functionality + * TODO: REMOVE THIS CLASS */ class SD_EXPORT GraphExecutioner { - private: - protected: - public: - GraphExecutioner() = default; - virtual ~GraphExecutioner() = default; - - virtual Nd4jStatus execute(const OptimizedGraph &graph); - - /** - * TODO: REMOVE ALL METHODS BELOW - */ //static Nd4jStatus executeFlatNode(sd::graph::Graph *graph, sd::graph::Node *node, sd::graph::VariableSpace *variableSpace); static Nd4jStatus executeFlatNode(Graph *graph, Node *node, VariableSpace *variableSpace); diff --git a/libnd4j/include/graph/execution/GraphExecutor.h b/libnd4j/include/graph/execution/GraphExecutor.h new file mode 100644 index 000000000000..3b1b7ca522ae --- /dev/null +++ b/libnd4j/include/graph/execution/GraphExecutor.h @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_GRAPHEXECUTOR_H +#define SD_GRAPHEXECUTOR_H + +#include +#include + +namespace sd { + namespace graph { + class SD_EXPORT GraphExecutor { + public: + GraphExecutor() = default; + virtual ~GraphExecutor() = default; + + virtual Nd4jStatus execute(const OptimizedGraph &graph) const ; + }; + } +} + + +#endif //SD_GRAPHEXECUTOR_H diff --git a/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp b/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp index 83237961c0f0..d88f69b4aa91 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp @@ -58,11 +58,6 @@ namespace sd{ namespace graph { - - Nd4jStatus GraphExecutioner::execute(const OptimizedGraph &graph) { - return Status::CODE(500, "Not implemented yet :)"); - } - /** * This method executes given Node (as in Op within Node) * diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp new file mode 100644 index 000000000000..8c0124b5af28 --- /dev/null +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { + namespace graph { + Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph) const { + throw std::runtime_error("GraphExecutor::execute - Not implemented yet"); + } + } +} diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index cd20fabac986..7fb7f463f794 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -1710,7 +1710,7 @@ namespace sd { return OptimizedGraph(); } - std::map Graph::execute(const std::map &dictionary, const std::vector &outputs) const { + std::map Graph::execute(const std::map &dictionary, const std::vector &outputs, const GraphExecutor &executor) const { // TODO: implement this method return std::map(); } From d3a46e702a61c46dd501c804ddfa154c87971550 Mon Sep 17 00:00:00 2001 From: raver119 Date: Sat, 28 Mar 2020 20:37:51 +0300 Subject: [PATCH 039/233] one include fixed Signed-off-by: raver119 --- libnd4j/minifier/minifier.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libnd4j/minifier/minifier.cpp b/libnd4j/minifier/minifier.cpp index ef1cf39d0613..3d047211deac 100644 --- a/libnd4j/minifier/minifier.cpp +++ b/libnd4j/minifier/minifier.cpp @@ -23,7 +23,7 @@ #endif #include #include "graphopt.h" -#include +#include #include #include From 9939fa94cc18f8f6a18a966e3298d58fd1448b69 Mon Sep 17 00:00:00 2001 From: raver119 Date: Sun, 29 Mar 2020 15:01:22 +0300 Subject: [PATCH 040/233] few minor api changes Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 5 + libnd4j/include/graph/Node.h | 2 + libnd4j/include/graph/Variable.h | 37 ++++--- libnd4j/include/graph/VariableProxy.h | 70 ++++++------ libnd4j/include/graph/VariableSpace.h | 6 +- .../graph/execution/impl/GraphExecutioner.cpp | 6 +- .../include/graph/impl/ExecutionResult.cpp | 4 +- libnd4j/include/graph/impl/Graph.cpp | 14 ++- libnd4j/include/graph/impl/Node.cpp | 38 ++++++- libnd4j/include/graph/impl/Variable.cpp | 46 ++++---- libnd4j/include/graph/impl/VariableProxy.cpp | 10 +- libnd4j/include/graph/impl/VariableSpace.cpp | 32 +++--- libnd4j/include/legacy/cpu/NativeOps.cpp | 2 +- .../layers_tests/GraphExecutionerTests.cpp | 103 ------------------ .../layers_tests/GraphExecutorTests.cpp | 61 +++++++++++ libnd4j/tests_cpu/layers_tests/GraphTests.cpp | 40 +++---- .../layers_tests/VariableSpaceTests.cpp | 10 +- .../tests_cpu/layers_tests/VariableTests.cpp | 6 +- 18 files changed, 251 insertions(+), 241 deletions(-) delete mode 100644 libnd4j/tests_cpu/layers_tests/GraphExecutionerTests.cpp create mode 100644 libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 1a9149ce04d2..4383a01e5a45 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -130,6 +130,11 @@ namespace sd { void addNode(sd::graph::Node *node); void addNode(const sd::graph::Node &node); + /** + * This method allows to add placeholder with some pre-defined properties + */ + void addPlaceholder(const std::string &nodeName, const int id = 0, const DataType dataType = sd::DataType::ANY, const std::vector &shape = {}); + /** * This method returns layered representation of the graph * diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 84cce5f7cd09..1f2077b9e4aa 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -99,6 +99,7 @@ namespace sd { Nd4jLong _frameId = -1; public: + explicit Node(const std::string &opName, const std::string &nodeName, const int id, const std::vector &inputs = {}, const std::vector &tArgs = {}, const std::vector &iArgs = {}); explicit Node(const std::string &opName, const int id = 0, const std::vector> &inputs = {}, const std::vector &tArgs = {}, const std::vector &iArgs = {}); explicit Node(sd::ops::DeclarableOp *customOp, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); explicit Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); @@ -159,6 +160,7 @@ namespace sd { void pickInput(int inputId); void pickInput(int nodeId, int outputId); void pickInput(std::pair& id); + void pickInput(const std::string &id); bool isDeductable(); void setDeductable(bool reallyDeductable); diff --git a/libnd4j/include/graph/Variable.h b/libnd4j/include/graph/Variable.h index 61f37dbfb5d7..5c5a876a151a 100644 --- a/libnd4j/include/graph/Variable.h +++ b/libnd4j/include/graph/Variable.h @@ -64,6 +64,7 @@ namespace sd { std::string _name; std::vector _shape; + DataType _dtype; bool _external = false; bool _readOnly = false; @@ -80,7 +81,7 @@ namespace sd { VariableType _variableType = VariableType::NDARRAY; public: - Variable(bool placeHolder); + Variable(bool placeHolder, DataType dataType = DataType::ANY, const std::vector &shape = {}); Variable(sd::NDArray *arrayw, const char *name, int id, int idx = 0); Variable(sd::NDArray *array = nullptr, const char *name = nullptr); @@ -90,27 +91,27 @@ namespace sd { ~Variable(); - Variable* clone(); + Variable* clone() const; template - SD_EXPORT Variable* asT(); + SD_EXPORT Variable* asT() const; - bool hasNDArray(); - sd::NDArray* getNDArray(); + bool hasNDArray() const; + sd::NDArray* getNDArray() const; void setNDArray(sd::NDArray *array); - bool hasNDArrayList(); - sd::NDArrayList* getNDArrayList(); + bool hasNDArrayList() const; + sd::NDArrayList* getNDArrayList() const; void setNDArrayList(sd::NDArrayList* list); - bool isExternal(); - bool isReadOnly(); - bool isEmpty(); - bool isRemovable(); + bool isExternal() const; + bool isReadOnly() const; + bool isEmpty() const; + bool isRemovable() const; - bool isPlaceholder(); + bool isPlaceholder() const; - VariableType variableType(); + VariableType variableType() const; void setVariableType(VariableType variableType); /** @@ -124,16 +125,16 @@ namespace sd { void markReadOnly(bool reallyReadOnly); void markRemovable(bool reallyRemovable); - int id(); - int index(); + int id() const; + int index() const; void setIndex(int index); void setId(int id); void setId(int id, int idx); - std::string *getName(); - void setName(std::string *name); + const std::string &getName() const; + void setName(const std::string &name); - std::vector& shape(); + const std::vector& shape() const; #ifndef __JAVACPP_HACK__ /** diff --git a/libnd4j/include/graph/VariableProxy.h b/libnd4j/include/graph/VariableProxy.h index 8d09f84a93fd..7d7ea62da985 100644 --- a/libnd4j/include/graph/VariableProxy.h +++ b/libnd4j/include/graph/VariableProxy.h @@ -32,58 +32,58 @@ namespace sd { virtual VariableSpace& operator=(const VariableSpace& other); - virtual int numberOfPlaceholders(); - virtual std::vector* getPlaceholders(); + virtual int numberOfPlaceholders() override; + virtual std::vector* getPlaceholders() override; virtual sd::memory::Workspace *workspace(); - virtual bool hasExternalVariable(int it); - virtual bool hasExternalVariable(std::pair& pair); - virtual bool hasExternalVariable(std::string *symbol); + virtual bool hasExternalVariable(int it) override; + virtual bool hasExternalVariable(std::pair& pair) override; + virtual bool hasExternalVariable(const std::string &symbol) override; - virtual bool hasVariable(int id); - virtual bool hasVariable(int id, int idx); - virtual bool hasVariable(std::pair& pair); - virtual bool hasVariable(std::string *symbol); + virtual bool hasVariable(int id) override; + virtual bool hasVariable(int id, int idx) override; + virtual bool hasVariable(std::pair& pair) override; + virtual bool hasVariable(const std::string &symbol) override; - virtual sd::graph::Variable *getVariable(int id); - virtual sd::graph::Variable *getVariable(int id, int idx); - virtual sd::graph::Variable *getVariable(std::pair& pair); - virtual sd::graph::Variable *getVariable(std::string *symbol); + virtual sd::graph::Variable *getVariable(int id) override; + virtual sd::graph::Variable *getVariable(int id, int idx) override; + virtual sd::graph::Variable *getVariable(std::pair& pair) override; + virtual sd::graph::Variable *getVariable(const std::string &symbol) override; - virtual std::vector getVariables(); + virtual std::vector getVariables() override; - virtual Variable* putVariable(std::pair& pair, NDArray *array); - virtual void putVariable(std::pair& pair, Variable *variable); - virtual void putVariable(int id, Variable *variable); - virtual void putVariable(int id, NDArray *array); - virtual Variable* putVariable(int id, int idx, NDArray *array); + virtual Variable* putVariable(std::pair& pair, NDArray *array) override; + virtual void putVariable(std::pair& pair, Variable *variable) override; + virtual void putVariable(int id, Variable *variable) override; + virtual void putVariable(int id, NDArray *array) override; + virtual Variable* putVariable(int id, int idx, NDArray *array) override; void putVariable(int id, int idx, const NDArray &array) override; - virtual void putVariable(int id, int idx, Variable *array); + virtual void putVariable(int id, int idx, Variable *array) override; - virtual void replaceVariable(Variable *variable); + virtual void replaceVariable(Variable *variable) override; - virtual void dropVariable(std::pair &pair); - virtual void dropVariable(int id, int idx); + virtual void dropVariable(std::pair &pair) override; + virtual void dropVariable(int id, int idx) override; - virtual void putOutputVariable(Variable *variable); + virtual void putOutputVariable(Variable *variable) override; - virtual void trackList(sd::NDArrayList *list); + virtual void trackList(sd::NDArrayList *list) override; // memory-related statistics - virtual Nd4jLong externalMemory(); - virtual Nd4jLong internalMemory(); - virtual Nd4jLong totalMemory(); + virtual Nd4jLong externalMemory() override; + virtual Nd4jLong internalMemory() override; + virtual Nd4jLong totalMemory() override; - virtual int externalEntries(); - virtual int internalEntries(); - virtual int totalEntries(); + virtual int externalEntries() override; + virtual int internalEntries() override; + virtual int totalEntries() override; - virtual sd::graph::VariableSpace *clone(); + virtual sd::graph::VariableSpace *clone() override; - virtual sd::graph::Stash* getStash(); - virtual void setFlowPath(FlowPath* timers); - virtual FlowPath* flowPath(); + virtual sd::graph::Stash* getStash() override; + virtual void setFlowPath(FlowPath* timers) override; + virtual FlowPath* flowPath() override; }; } } \ No newline at end of file diff --git a/libnd4j/include/graph/VariableSpace.h b/libnd4j/include/graph/VariableSpace.h index fded177660a8..66fef674e09c 100644 --- a/libnd4j/include/graph/VariableSpace.h +++ b/libnd4j/include/graph/VariableSpace.h @@ -81,17 +81,17 @@ namespace sd { virtual bool hasExternalVariable(int it); virtual bool hasExternalVariable(std::pair& pair); - virtual bool hasExternalVariable(std::string *symbol); + virtual bool hasExternalVariable(const std::string &symbol); virtual bool hasVariable(int id); virtual bool hasVariable(int id, int idx); virtual bool hasVariable(std::pair& pair); - virtual bool hasVariable(std::string *symbol); + virtual bool hasVariable(const std::string &symbol); virtual sd::graph::Variable* getVariable(int id); virtual sd::graph::Variable* getVariable(int id, int idx); virtual sd::graph::Variable* getVariable(std::pair& pair); - virtual sd::graph::Variable* getVariable(std::string *symbol); + virtual sd::graph::Variable* getVariable(const std::string &symbol); virtual std::vector getVariables(); diff --git a/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp b/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp index d88f69b4aa91..e51b6029de2f 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp @@ -126,7 +126,7 @@ namespace sd{ ResultSet deletables; int cnt = 0; for (Variable* v: *embedded->getPlaceholders()) { - if (v->getName() != nullptr && v->getName()->size() > 0) { + if (!v->getName().empty()) { // trying symbolic lookup first if (variableSpace->hasVariable(v->getName())) { @@ -136,7 +136,7 @@ namespace sd{ // deletables.push_back(vr); v->setNDArray(vr); } else { - nd4j_debug("Can't find variable [%s] in parent graph...", v->getName()->c_str()); + nd4j_debug("Can't find variable [%s] in parent graph...", v->getName().c_str()); return ND4J_STATUS_BAD_INPUT; //throw "Can't find desired variable"; } @@ -581,7 +581,7 @@ namespace sd{ auto fArray = FlatUtils::toFlatArray(builder, *array); - auto fName = builder.CreateString(*(var->getName())); + auto fName = builder.CreateString(var->getName()); auto id = CreateIntPair(builder, var->id(), var->index()); auto fv = CreateFlatVariable(builder, id, fName, static_cast(array->dataType()), 0, fArray); diff --git a/libnd4j/include/graph/impl/ExecutionResult.cpp b/libnd4j/include/graph/impl/ExecutionResult.cpp index fd2bed054201..3f0fbdf7fdd4 100644 --- a/libnd4j/include/graph/impl/ExecutionResult.cpp +++ b/libnd4j/include/graph/impl/ExecutionResult.cpp @@ -55,8 +55,8 @@ namespace sd { void ExecutionResult::emplace_back(Variable *variable) { _variables.emplace_back(variable); - if (variable->getName() != nullptr) - _stringIdMap[*variable->getName()] = variable; + if (!variable->getName().empty()) + _stringIdMap[variable->getName()] = variable; std::pair p(variable->id(), variable->index()); _pairIdMap[p] = variable; diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 7fb7f463f794..a7b55def63ff 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -329,7 +329,7 @@ namespace sd { auto cname = node->getName() == nullptr ? nullptr : node->getName()->c_str(); auto nodeState = _variableSpace->hasVariable(node->id()) ? _variableSpace->getVariable(node->id()) :new Variable(nullptr, cname, node->id()); if (node->getName() != nullptr) - nodeState->setName(node->getName()); + nodeState->setName(*node->getName()); if (node->isInplace()) @@ -420,7 +420,7 @@ namespace sd { // nd4j_logger("Adding auto output variable; Output size: %i\n", node->output()->size()); var->setId(node->id()); - var->setName(node->getName()); + var->setName(*node->getName()); _variableSpace->putOutputVariable(var); //node->pickExternalOutput(var->id()); @@ -1130,8 +1130,8 @@ namespace sd { auto values = v->getNDArray()->asString(16); auto dtype = DataTypeUtils::asString(v->getNDArray()->dataType()); - if (v->getName() != nullptr && !v->getName()->empty()) { - nd4j_printf("<%s> <%i:%i> dtype: %s; shape: %s; values: %s;\n", v->getName()->c_str(), v->id(), v->index(), dtype.c_str(), shape.c_str(), values.c_str()); + if (!v->getName().empty()) { + nd4j_printf("<%s> <%i:%i> dtype: %s; shape: %s; values: %s;\n", v->getName().c_str(), v->id(), v->index(), dtype.c_str(), shape.c_str(), values.c_str()); } else { nd4j_printf("<%i:%i> dtype: %s; shape: %s; values: %s;\n", v->id(), v->index(), dtype.c_str(), shape.c_str(), values.c_str()); } @@ -1704,6 +1704,12 @@ namespace sd { */ } + void Graph::addPlaceholder(const std::string &nodeName, const int id, DataType dataType, const std::vector &shape) { + auto var = new Variable(true, dataType, shape); + var->setName(nodeName); + _variableSpace->putVariable(id, var); + } + OptimizedGraph Graph::optimizedGraph() const { // TODO: implement this method diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 9ad1285fb10a..bd84daf85c3a 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -180,6 +180,10 @@ namespace sd { _input.push_back(pair); } + void Node::pickInput(const std::string &id) { + throw std::runtime_error("Node::pickInput - Not implemented yet"); + } + void sd::graph::Node::pickInput(int inputId, int outputId) { std::pair p(inputId,outputId); pickInput(p); @@ -321,8 +325,40 @@ namespace sd { } BUILD_SINGLE_TEMPLATE(template SD_EXPORT Node* Node::asT, (), LIBND4J_TYPES); - Node::Node(const std::string &opName, const int id, const std::vector> &inputs, const std::vector &tArgs, const std::vector &iArgs) { + Node::Node(const std::string &opName, const std::string &nodeName, const int id, const std::vector &inputs, const std::vector &tArgs, const std::vector &iArgs) { + auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName); + + this->_opType = OpType_CUSTOM; + this->_id = id; + this->_opNum = customOp->getOpHash(); + this->_extraParams = nullptr; + this->_dataType = sd::DataType::FLOAT32; // float as default + this->_dim = nullptr; + this->_customOp = customOp; + _hasExternalInputs = false; + _hasExternalOutputs = false; + _hasInternalInputs = false; + _hasInternalOutputs = false; + + // FIXME: get rid of this!!! + _scalar = NDArrayFactory::create(0); + + for (auto i: inputs) + pickInput(i); + + auto block = new ContextPrototype(this->getCustomOp()->getOpDescriptor(), this->id(), false); + + for (auto v: iArgs) + block->getIArguments()->emplace_back(v); + + for (auto v: tArgs) + block->getTArguments()->emplace_back(v); + + this->setContextPrototype(block); + } + + Node::Node(const std::string &opName, const int id, const std::vector> &inputs, const std::vector &tArgs, const std::vector &iArgs) { auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName); this->_opType = OpType_CUSTOM; diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index 4e990e9b9ad2..c92349f23d8a 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -30,13 +30,13 @@ namespace sd { namespace graph { template - Variable* Variable::asT() { + Variable* Variable::asT() const { auto result = new Variable(this->isPlaceholder()); result->markExternal(this->_external); result->setId(this->_id); result->markReadOnly(this->_readOnly); - result->setName(&this->_name); + result->setName(this->_name); result->setIndex(this->_index); if (this->_ndarray != nullptr) @@ -50,9 +50,9 @@ namespace sd { return result; } - BUILD_SINGLE_TEMPLATE(template SD_EXPORT Variable* Variable::asT, (), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template SD_EXPORT Variable* Variable::asT, () const, LIBND4J_TYPES); - sd::graph::Variable* sd::graph::Variable::clone() { + sd::graph::Variable* sd::graph::Variable::clone() const { auto result = new Variable(this->isPlaceholder()); result->_external = this->_external; result->_id = this->_id; @@ -76,7 +76,7 @@ namespace sd { _index = index; } - bool sd::graph::Variable::hasNDArray() { + bool sd::graph::Variable::hasNDArray() const { return _ndarray != nullptr; } @@ -84,27 +84,27 @@ namespace sd { _variableType = variableType; } - bool sd::graph::Variable::hasNDArrayList() { + bool sd::graph::Variable::hasNDArrayList() const { return _list != nullptr; } - bool sd::graph::Variable::isPlaceholder() { + bool sd::graph::Variable::isPlaceholder() const { return _placeholder; } - std::string * sd::graph::Variable::getName() { - return &_name; + const std::string& sd::graph::Variable::getName() const { + return _name; } - void sd::graph::Variable::setName(std::string *name) { - _name = *name; + void sd::graph::Variable::setName(const std::string &name) { + _name = name; } - int sd::graph::Variable::id() { + int sd::graph::Variable::id() const { return _id; } - int sd::graph::Variable::index() { + int sd::graph::Variable::index() const { return _index; } @@ -112,7 +112,7 @@ namespace sd { _id = id; } - bool sd::graph::Variable::isEmpty() { + bool sd::graph::Variable::isEmpty() const { if (_variableType == VariableType::NDARRAY) return _ndarray == nullptr || !_ndarray->nonNull(); else if (_variableType == VariableType::ARRAY_LIST) @@ -121,11 +121,11 @@ namespace sd { return false; } - bool sd::graph::Variable::isExternal() { + bool sd::graph::Variable::isExternal() const { return _external; } - bool sd::graph::Variable::isReadOnly() { + bool sd::graph::Variable::isReadOnly() const { return _readOnly; } @@ -143,7 +143,7 @@ namespace sd { this->_readOnly = reallyReadOnly; } - sd::NDArray * sd::graph::Variable::getNDArray() { + sd::NDArray * sd::graph::Variable::getNDArray() const { if (_variableType != VariableType::NDARRAY) { nd4j_printf("Variable[%i:%i/<%s>] is has [%s] type, but NDArray was requested\n", this->_id, this->_index, this->_name.c_str(), EnumUtils::_VariableTypeToString(_variableType)); } @@ -162,7 +162,7 @@ namespace sd { return this->_ndarray; } - sd::NDArrayList * sd::graph::Variable::getNDArrayList() { + sd::NDArrayList * sd::graph::Variable::getNDArrayList() const { if (_variableType != VariableType::ARRAY_LIST) { nd4j_debug("Variable[%i:%i/<%s>] is has [%s] type, but NDArrayList was requested\n", this->_id, this->_index, this->_name.c_str(), EnumUtils::_VariableTypeToString(_variableType)); } @@ -170,7 +170,7 @@ namespace sd { } - bool Variable::isRemovable() { + bool Variable::isRemovable() const { return _removable; } @@ -187,7 +187,7 @@ namespace sd { } - VariableType sd::graph::Variable::variableType() { + VariableType sd::graph::Variable::variableType() const { return _variableType; } @@ -270,12 +270,14 @@ namespace sd { } } - std::vector& sd::graph::Variable::shape() { + const std::vector& sd::graph::Variable::shape() const { return _shape; } - sd::graph::Variable::Variable(bool placeholder) { + sd::graph::Variable::Variable(bool placeholder, DataType dataType, const std::vector &shape) { _placeholder = placeholder; + _dtype = dataType; + _shape = shape; } diff --git a/libnd4j/include/graph/impl/VariableProxy.cpp b/libnd4j/include/graph/impl/VariableProxy.cpp index 8ec199d8762c..d15d3785af2f 100644 --- a/libnd4j/include/graph/impl/VariableProxy.cpp +++ b/libnd4j/include/graph/impl/VariableProxy.cpp @@ -57,7 +57,7 @@ namespace sd { } - bool VariableProxy::hasExternalVariable(std::string *symbol) { + bool VariableProxy::hasExternalVariable(const std::string &symbol) { return _backed->hasExternalVariable(symbol); } @@ -105,7 +105,7 @@ namespace sd { } - bool VariableProxy::hasVariable(std::string *symbol) { + bool VariableProxy::hasVariable(const std::string &symbol) { return _current->hasVariable(symbol) || _backed->hasVariable(symbol); } @@ -146,20 +146,20 @@ namespace sd { } - sd::graph::Variable *VariableProxy::getVariable(std::string *symbol) { + sd::graph::Variable *VariableProxy::getVariable(const std::string &symbol) { if (_current->hasVariable(symbol)) return _current->getVariable(symbol); if (_backed->hasVariable(symbol)) return _backed->getVariable(symbol); - nd4j_printf("Unable to get Variable to proxy: [%s]\n", symbol->c_str()); + nd4j_printf("Unable to get Variable to proxy: [%s]\n", symbol.c_str()); throw std::runtime_error("Bad arguments"); } void VariableProxy::replaceVariable(Variable *variable) { - if (variable->getName() != nullptr && !variable->getName()->empty()) { + if (!variable->getName().empty()) { // if variable has name defined - we should resolve it via backing var space if (_backed->hasVariable(variable->getName())) { auto origVar = _backed->getVariable(variable->getName()); diff --git a/libnd4j/include/graph/impl/VariableSpace.cpp b/libnd4j/include/graph/impl/VariableSpace.cpp index a8fe297ee215..b4ea8afeff51 100644 --- a/libnd4j/include/graph/impl/VariableSpace.cpp +++ b/libnd4j/include/graph/impl/VariableSpace.cpp @@ -73,8 +73,8 @@ namespace sd { this->_temporary[pair.first] = variable; } - if (variable->getName() != nullptr && variable->getName()->length() > 0) - this->_symbolic[*(variable->getName())] = variable; + if (!variable->getName().empty()) + this->_symbolic[variable->getName()] = variable; this->_paired[pair] = variable; @@ -89,12 +89,12 @@ namespace sd { return _placeholders.size(); } - bool sd::graph::VariableSpace::hasVariable(std::string *symbol) { - return _symbolic.count(*symbol) == 1; + bool sd::graph::VariableSpace::hasVariable(const std::string &symbol) { + return _symbolic.count(symbol) == 1; } - sd::graph::Variable * sd::graph::VariableSpace::getVariable(std::string *symbol) { - return _symbolic.at(*symbol); + sd::graph::Variable * sd::graph::VariableSpace::getVariable(const std::string &symbol) { + return _symbolic.at(symbol); } bool sd::graph::VariableSpace::hasVariable(int id, int index) { @@ -118,7 +118,7 @@ namespace sd { return var->isExternal(); } - bool VariableSpace::hasExternalVariable(std::string *symbol) { + bool VariableSpace::hasExternalVariable(const std::string &symbol) { if (!hasVariable(symbol)) return false; @@ -235,8 +235,8 @@ namespace sd { if (pair.second == 0 && !this->hasVariable(pair.first)) { this->putVariable(pair.first, variable); } else { - if (variable->getName() != nullptr && variable->getName()->length() != 0) { - _symbolic[*(variable->getName())] = variable; + if (!variable->getName().empty()) { + _symbolic[variable->getName()] = variable; } _varmap.lock(); @@ -276,9 +276,9 @@ namespace sd { variable->setId(id); - if (variable->getName() != nullptr && variable->getName()->length() != 0) { + if (!variable->getName().empty()) { //std::pair pair(*(variable->getName()), variable); - _symbolic[*(variable->getName())] = variable; + _symbolic[variable->getName()] = variable; } // we have special list for external variables to ensure graph completeness @@ -378,8 +378,8 @@ namespace sd { this->_temporary[pair.first] = clonedVar; } - if (clonedVar->getName() != nullptr && clonedVar->getName()->length() > 0) - this->_symbolic[*(clonedVar->getName())] = clonedVar; + if (!clonedVar->getName().empty()) + this->_symbolic[clonedVar->getName()] = clonedVar; this->_paired[pair] = clonedVar; @@ -392,10 +392,10 @@ namespace sd { void VariableSpace::replaceVariable(Variable *variable) { bool replaced = false; // trying name first - if (variable->getName() != nullptr && !variable->getName()->empty()) { - nd4j_printf("Trying to replace variable by name: [%s]\n", variable->getName()->c_str()); + if (!variable->getName().empty()) { + nd4j_printf("Trying to replace variable by name: [%s]\n", variable->getName().c_str()); if (hasVariable(variable->getName())) { - nd4j_printf("Replacing by name: [%s]\n", variable->getName()->c_str()); + nd4j_printf("Replacing by name: [%s]\n", variable->getName().c_str()); auto vs = getVariable(variable->getName()); dropVariable(vs->id(), vs->index()); putVariable(vs->id(), vs->index(), variable); diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp index db10cd7d9f77..b2cdd3d05d0d 100644 --- a/libnd4j/include/legacy/cpu/NativeOps.cpp +++ b/libnd4j/include/legacy/cpu/NativeOps.cpp @@ -2255,7 +2255,7 @@ int getVariableIndex(sd::graph::Variable* variable) { } const char* getVariableName(sd::graph::Variable* variable) { - return variable->getName()->c_str(); + return variable->getName().c_str(); } Nd4jLong* getVariableShape(sd::graph::Variable* variable) { diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutionerTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutionerTests.cpp deleted file mode 100644 index 7a2856dc0c0a..000000000000 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutionerTests.cpp +++ /dev/null @@ -1,103 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 29.11.17. -// - - -#include "testlayers.h" -#include -#include -#include -#include -#include -#include -#include - -using namespace sd; -using namespace sd::graph; - -class GraphExecutionerTests : public testing::Test { -public: - -}; - -#ifdef GRAPH_TESTS_OK -TEST_F(GraphExecutionerTests, Test_Implicit_Output_1) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_slice.fb"); - graph->buildGraph(); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(1, outputs->size()); - - auto var0 = outputs->at(0); - - ASSERT_EQ(7, var0->id()); - ASSERT_EQ(0, var0->index()); - - delete outputs; - delete graph; -} - - -TEST_F(GraphExecutionerTests, Test_Implicit_Output_2) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - graph->buildGraph(); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(1, outputs->size()); - - auto var0 = outputs->at(0); - - ASSERT_EQ(3, var0->id()); - ASSERT_EQ(0, var0->index()); - - delete outputs; - delete graph; -} - - -TEST_F(GraphExecutionerTests, Test_Implicit_Output_3) { - auto exp = NDArrayFactory::create('c', {3}, {3, 3, 3}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - auto status = GraphExecutioner::execute(graph); - - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(1, outputs->size()); - - auto var0 = outputs->at(0); - - ASSERT_EQ(3, var0->id()); - ASSERT_EQ(0, var0->index()); - - auto array = var0->getNDArray(); - - ASSERT_TRUE(array != nullptr); - - ASSERT_TRUE(exp.isSameShape(array)); - ASSERT_TRUE(exp.equalsTo(array)); - - delete outputs; - delete graph; -} -#endif diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp new file mode 100644 index 000000000000..1dcc8c0abbe2 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 29.11.17. +// + + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class GraphExecutorTests : public testing::Test { +public: + +}; + +TEST_F(GraphExecutorTests, test_execution_1) { + Graph graph; + + // A + graph.getVariableSpace()->putVariable(-1, 0, NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // B + graph.getVariableSpace()->putVariable(-2, 0, NDArrayFactory::create('c', {3}, {2, 2, 2})); + + // C + graph.getVariableSpace()->putVariable(-3, 0, NDArrayFactory::create('c', {3}, {3, 3, 3})); + + Node a("multiply", 10, {{-1, 0}, {-2, 0}}); + Node b("add", 20, {{10, 0}, {-3, 0}}); + + graph.addNode(b); + graph.addNode(a); + + auto result = graph.execute({}, {"add_node"}); + ASSERT_EQ(1, result.size()); + ASSERT_EQ(1, result.count("add_node")); +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests.cpp index 73aac9c3bbe8..cc965d084fe6 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests.cpp @@ -596,8 +596,8 @@ TEST_F(GraphTests, SymbolicLookupTest1) { std::string a("alpha"); std::string o("omega"); - vX->setName(&a); - vZ->setName(&o); + vX->setName(a); + vZ->setName(o); graph->getVariableSpace()->putVariable(-1, vX); graph->getVariableSpace()->putVariable(-2, vZ); @@ -615,14 +615,14 @@ TEST_F(GraphTests, SymbolicLookupTest1) { graph->addNode(nodeB); - auto rX = graph->getVariableSpace()->getVariable(&a); - auto rZ = graph->getVariableSpace()->getVariable(&o); + auto rX = graph->getVariableSpace()->getVariable(a); + auto rZ = graph->getVariableSpace()->getVariable(o); std::string om("omicron"); ASSERT_TRUE(rX->getNDArray() == vX->getNDArray()); ASSERT_TRUE(rZ->getNDArray() == vZ->getNDArray()); - ASSERT_FALSE(graph->getVariableSpace()->hasVariable(&om)); + ASSERT_FALSE(graph->getVariableSpace()->hasVariable(om)); ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); @@ -630,8 +630,8 @@ TEST_F(GraphTests, SymbolicLookupTest1) { GraphExecutioner::execute(graph); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(&p)); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(&t)); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(p)); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(t)); ASSERT_NEAR(1.4142135, z->reduceNumber(reduce::Mean).e(0), 1e-5); @@ -654,8 +654,8 @@ TEST_F(GraphTests, OutputValidation1) { std::string a("alpha"); std::string o("omega"); - vX->setName(&a); - vZ->setName(&o); + vX->setName(a); + vZ->setName(o); graph->getVariableSpace()->putVariable(-1, vX); graph->getVariableSpace()->putVariable(-2, vZ); @@ -691,8 +691,8 @@ TEST_F(GraphTests, OutputValidation2) { std::string a("alpha"); std::string o("omega"); - vX->setName(&a); - vZ->setName(&o); + vX->setName(a); + vZ->setName(o); graph->getVariableSpace()->putVariable(-1, vX); graph->getVariableSpace()->putVariable(-2, vZ); @@ -733,8 +733,8 @@ TEST_F(GraphTests, OutputValidation3) { std::string a("alpha"); std::string o("omega"); - vX->setName(&a); - vZ->setName(&o); + vX->setName(a); + vZ->setName(o); graph->getVariableSpace()->putVariable(-1, vX); graph->getVariableSpace()->putVariable(-2, vZ); @@ -773,8 +773,8 @@ TEST_F(GraphTests, OutputValidation4) { std::string a("alpha"); std::string o("omega"); - vX->setName(&a); - vZ->setName(&o); + vX->setName(a); + vZ->setName(o); graph->getVariableSpace()->putVariable(-1, vX); graph->getVariableSpace()->putVariable(-2, vZ); @@ -819,8 +819,8 @@ TEST_F(GraphTests, OutputValidation5) { std::string a("alpha"); std::string o("omega"); - vX->setName(&a); - vZ->setName(&o); + vX->setName(a); + vZ->setName(o); graph->getVariableSpace()->putVariable(-1, vX); graph->getVariableSpace()->putVariable(-2, vZ); @@ -859,8 +859,8 @@ TEST_F(GraphTests, OutputValidation6) { std::string a("alpha"); std::string o("omega"); - vX->setName(&a); - vZ->setName(&o); + vX->setName(a); + vZ->setName(o); graph->getVariableSpace()->putVariable(-1, vX); graph->getVariableSpace()->putVariable(-2, vZ); @@ -1173,7 +1173,7 @@ TEST_F(GraphTests, TestGraphInGraph_2) { // this is placeholder variable auto placeHolder = new Variable(true); - placeHolder->setName(&nameA1); + placeHolder->setName(nameA1); graphB.getVariableSpace()->putVariable(-1, placeHolder); // abs, result is 5 diff --git a/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp index ec10f3db097b..26eef254cf9f 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp @@ -97,11 +97,11 @@ TEST_F(VariableSpaceTest, EqualityTest1) { ASSERT_TRUE(space.hasVariable(1)); ASSERT_TRUE(space.hasVariable(pair)); - ASSERT_TRUE(space.hasVariable(&name)); + ASSERT_TRUE(space.hasVariable(name)); auto rV1 = space.getVariable(1); auto rV2 = space.getVariable(pair); - auto rV3 = space.getVariable(&name); + auto rV3 = space.getVariable(name); ASSERT_TRUE(rV1 == rV2); ASSERT_TRUE(rV2 == rV3); @@ -164,7 +164,7 @@ TEST_F(VariableSpaceTest, CloneTests_2) { spaceA.putVariable(pair, variableA); - ASSERT_TRUE(spaceA.hasVariable(&str)); + ASSERT_TRUE(spaceA.hasVariable(str)); ASSERT_TRUE(spaceA.hasVariable(pair)); auto spaceB = spaceA.clone(); @@ -172,7 +172,7 @@ TEST_F(VariableSpaceTest, CloneTests_2) { ASSERT_FALSE(spaceB->hasVariable(1)); ASSERT_FALSE(spaceB->hasVariable(2)); ASSERT_TRUE(spaceB->hasVariable(pair)); - ASSERT_TRUE(spaceB->hasVariable(&str)); + ASSERT_TRUE(spaceB->hasVariable(str)); auto arrayB = spaceB->getVariable(pair)->getNDArray(); @@ -184,7 +184,7 @@ TEST_F(VariableSpaceTest, CloneTests_2) { delete spaceB; - ASSERT_TRUE(spaceA.hasVariable(&str)); + ASSERT_TRUE(spaceA.hasVariable(str)); ASSERT_TRUE(spaceA.hasVariable(pair)); } diff --git a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp index 49b9b02d6dc3..acf92a0075d5 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp @@ -49,12 +49,12 @@ TEST_F(VariableTests, TestClone_1) { ASSERT_TRUE(array1->equalsTo(array2)); ASSERT_EQ(var1->id(), var2->id()); - ASSERT_EQ(*var1->getName(), *var2->getName()); + ASSERT_EQ(var1->getName(), var2->getName()); delete var1; std::string str("alpha"); - ASSERT_EQ(*var2->getName(), str); + ASSERT_EQ(var2->getName(), str); array2->assign(2.0); ASSERT_NEAR(2.0, array2->meanNumber().e(0), 1e-5); @@ -209,7 +209,7 @@ TEST_F(VariableTests, Test_Dtype_Conversion_1) { auto vd = v.template asT(); auto vf = vd->template asT(); - ASSERT_EQ(*v.getName(), *vf->getName()); + ASSERT_EQ(v.getName(), vf->getName()); ASSERT_EQ(v.id(), vf->id()); ASSERT_EQ(v.index(), vf->index()); From e7c22f62b67f7ae9f2b442f192b06973131a0e16 Mon Sep 17 00:00:00 2001 From: raver119 Date: Sun, 29 Mar 2020 20:21:18 +0300 Subject: [PATCH 041/233] placeholder test Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 2 +- libnd4j/include/graph/Variable.h | 1 + .../graph/execution/impl/GraphExecutioner.cpp | 2 +- libnd4j/include/graph/impl/Graph.cpp | 4 +- libnd4j/include/graph/impl/Variable.cpp | 3 + .../tests_cpu/layers_tests/GraphTests2.cpp | 60 +++++++++++++++++++ 6 files changed, 68 insertions(+), 4 deletions(-) create mode 100644 libnd4j/tests_cpu/layers_tests/GraphTests2.cpp diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 4383a01e5a45..292554e9c096 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -114,7 +114,7 @@ namespace sd { int numberOfPlaceholders(); - std::vector* getPlaceholders(); + const std::vector& getPlaceholders() const; /** * This method returns pointer to thread_local VariableSpace diff --git a/libnd4j/include/graph/Variable.h b/libnd4j/include/graph/Variable.h index 5c5a876a151a..c302741c93cd 100644 --- a/libnd4j/include/graph/Variable.h +++ b/libnd4j/include/graph/Variable.h @@ -135,6 +135,7 @@ namespace sd { void setName(const std::string &name); const std::vector& shape() const; + DataType dataType() const; #ifndef __JAVACPP_HACK__ /** diff --git a/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp b/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp index e51b6029de2f..b38fd3ad71be 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp @@ -125,7 +125,7 @@ namespace sd{ // we need to propagate required variables to the embedded graph ResultSet deletables; int cnt = 0; - for (Variable* v: *embedded->getPlaceholders()) { + for (Variable* v: embedded->getPlaceholders()) { if (!v->getName().empty()) { // trying symbolic lookup first diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index a7b55def63ff..619b5e253b54 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -38,8 +38,8 @@ namespace sd { return &_handles; } - std::vector* Graph::getPlaceholders() { - return _variableSpace->getPlaceholders(); + const std::vector& Graph::getPlaceholders() const { + return *_variableSpace->getPlaceholders(); } int Graph::numberOfPlaceholders() { diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index c92349f23d8a..5f341b2da3c7 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -294,6 +294,9 @@ namespace sd { _variableType = VariableType::NDARRAY; } + DataType Variable::dataType() const { + return _dtype; + } sd::graph::Variable::Variable(NDArray *array, const char *name, int id, int idx) : Variable(array, name) { _id = id; diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp new file mode 100644 index 000000000000..2496187c2471 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp @@ -0,0 +1,60 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class GraphTests2 : public testing::Test { +public: + + GraphTests2() { + // + } +}; + +TEST_F(GraphTests2, test_placeholder_1) { + Graph graph; + + graph.addPlaceholder("input", 0, DataType::BFLOAT16, {4, 12, 48}); + + ASSERT_TRUE(graph.getVariableSpace()->hasVariable("input")); + + auto variable = graph.getVariableSpace()->getVariable("input"); + + ASSERT_NE(nullptr, variable); + ASSERT_TRUE(variable->isPlaceholder()); + ASSERT_EQ(DataType::BFLOAT16, variable->dataType()); + ASSERT_EQ(std::vector({4, 12, 48}), variable->shape()); + + auto placeholders = graph.getPlaceholders(); + ASSERT_EQ(1, placeholders.size()); + ASSERT_EQ(placeholders[0], variable); +} \ No newline at end of file From f4a43871b8faa60da0824b32bb085c4f4cdb4005 Mon Sep 17 00:00:00 2001 From: raver119 Date: Sun, 29 Mar 2020 22:15:00 +0300 Subject: [PATCH 042/233] GraphMemoryManager propagation Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 9 ++++++--- libnd4j/include/graph/execution/GraphExecutor.h | 3 +++ libnd4j/include/graph/impl/Graph.cpp | 10 +++++----- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 292554e9c096..0ec3220a335a 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -38,6 +38,7 @@ #include #include #include +#include namespace sd { namespace graph { @@ -69,6 +70,8 @@ namespace sd { MAP_IMPL _mappedScopes; std::vector _scopes; + const GraphMemoryManager &_memoryMaager; + //////////////////////////////////////// Nd4jStatus validateNode(sd::graph::Node *node); @@ -83,7 +86,7 @@ namespace sd { void prepareOutputs(); public: - Graph(const FlatGraph *flatGraph = nullptr, VariableSpace *variableSpace = nullptr); + Graph(const FlatGraph *flatGraph = nullptr, VariableSpace *variableSpace = nullptr, const GraphMemoryManager &memoryManager = GraphMemoryManager()); ~Graph(); @@ -91,8 +94,8 @@ namespace sd { * Methods that allow Graph imports */ static Graph *importFromTensorFlow(const char *fileName); - static Graph* fromFlatBuffers(const char *fileName); - static Graph* fromFlatPointer(void *ptr); + static Graph* fromFlatBuffers(const char *fileName, const GraphMemoryManager &memoryManager = GraphMemoryManager()); + static Graph* fromFlatPointer(void *ptr, const GraphMemoryManager &memoryManager = GraphMemoryManager()); // this method applies toposort to nodes void toposortNodes(); diff --git a/libnd4j/include/graph/execution/GraphExecutor.h b/libnd4j/include/graph/execution/GraphExecutor.h index 3b1b7ca522ae..97510a3d1578 100644 --- a/libnd4j/include/graph/execution/GraphExecutor.h +++ b/libnd4j/include/graph/execution/GraphExecutor.h @@ -23,10 +23,13 @@ #include #include +#include namespace sd { namespace graph { class SD_EXPORT GraphExecutor { + protected: + public: GraphExecutor() = default; virtual ~GraphExecutor() = default; diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 619b5e253b54..37deaabdde15 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -872,7 +872,7 @@ namespace sd { } } - Graph::Graph(const FlatGraph *flatGraph, VariableSpace *variableSpace) { + Graph::Graph(const FlatGraph *flatGraph, VariableSpace *variableSpace, const GraphMemoryManager &memoryManager) : _memoryMaager(memoryManager) { this->_onion = new MAP_IMPL *>(); this->_mapped = new MAP_IMPL (); this->_nodes = new std::vector(); @@ -1462,7 +1462,7 @@ namespace sd { } - Graph* Graph::fromFlatBuffers(const char* fileName) { + Graph* Graph::fromFlatBuffers(const char* fileName, const GraphMemoryManager &memoryManager) { // check if file exists if (!FileUtils::fileExists(fileName)) throw std::runtime_error("Graph file doesn't exist"); @@ -1494,15 +1494,15 @@ namespace sd { fclose(in); } - return fromFlatPointer(ptrGraph); + return fromFlatPointer(ptrGraph, memoryManager); } - Graph* Graph::fromFlatPointer(void *ptr) { + Graph* Graph::fromFlatPointer(void *ptr, const GraphMemoryManager &memoryManager) { // get FlatGraph out of it auto fg = GetFlatGraph(reinterpret_cast(ptr)); // return Graph from this FlatGraph - return new Graph(fg); + return new Graph(fg, nullptr, memoryManager); } Graph* Graph::importFromTensorFlow(const char *fileName) { From 49835c7b55df89be7a9123885231d7ad4d56c1e9 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 30 Mar 2020 08:01:32 +0300 Subject: [PATCH 043/233] GraphMemoryManager propagation 2 Signed-off-by: raver119 --- libnd4j/include/graph/OptimizedGraph.h | 11 ++++++++++- libnd4j/include/graph/impl/Graph.cpp | 2 +- libnd4j/include/graph/impl/OptimizedGraph.cpp | 12 ++++++++++++ libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp | 3 ++- 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index 8c7bd749a055..07eca9b22a61 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -22,6 +22,7 @@ #include #include +#include #include #include #include @@ -38,9 +39,11 @@ namespace sd { // on each layer we can have 1+ OpSequences that can be executed independent std::map _onion; + GraphMemoryManager *_memoryManager; + std::mutex _mutex; public: - OptimizedGraph() = default; + OptimizedGraph(GraphMemoryManager &memoryManager); ~OptimizedGraph() = default; OptimizedGraph(const OptimizedGraph& other) noexcept; @@ -74,6 +77,12 @@ namespace sd { void append(const std::vector &layer); void append(const ExecutionLayer &layer); void append(OpSequence &sequence); + + /** + * This method returns GraphMemoryManager instance that manages this Graph + * @return + */ + const GraphMemoryManager& memoryManager() const; }; } } diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 37deaabdde15..de4d1ab02114 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -1713,7 +1713,7 @@ namespace sd { OptimizedGraph Graph::optimizedGraph() const { // TODO: implement this method - return OptimizedGraph(); + return OptimizedGraph(const_cast(_memoryMaager)); } std::map Graph::execute(const std::map &dictionary, const std::vector &outputs, const GraphExecutor &executor) const { diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 5401abc32ca6..ed7393c22ca4 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -22,8 +22,13 @@ namespace sd { namespace graph { + OptimizedGraph::OptimizedGraph(GraphMemoryManager &memoryManager) { + _memoryManager = &memoryManager; + } + OptimizedGraph::OptimizedGraph(const OptimizedGraph &other) noexcept { _onion = other._onion; + _memoryManager = other._memoryManager; } OptimizedGraph &OptimizedGraph::operator=(const OptimizedGraph &other) noexcept { @@ -31,12 +36,14 @@ namespace sd { return *this; _onion = other._onion; + _memoryManager = other._memoryManager; return *this; } OptimizedGraph::OptimizedGraph(OptimizedGraph &&other) noexcept { _onion = std::move(other._onion); + _memoryManager = other._memoryManager; } OptimizedGraph &OptimizedGraph::operator=(OptimizedGraph &&other) noexcept { @@ -44,6 +51,7 @@ namespace sd { return *this; _onion = std::move(other._onion); + _memoryManager = other._memoryManager; return *this; } @@ -69,5 +77,9 @@ namespace sd { std::lock_guard lock(_mutex); _onion[_onion.size()] = layer; } + + const GraphMemoryManager &OptimizedGraph::memoryManager() const { + return *_memoryManager; + } } } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp index 9f2168020317..76fce192d135 100644 --- a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp @@ -61,7 +61,8 @@ TEST_F(OpSequenceTests, test_iterator_1) { ASSERT_EQ(3, cnt); - OptimizedGraph optimizedGraph; + GraphMemoryManager mgr; + OptimizedGraph optimizedGraph(mgr); ASSERT_EQ(0, optimizedGraph.layers()); optimizedGraph.append(sequence); From 0fda980f7ea83610003ed99aa79968c280da61a0 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 30 Mar 2020 09:26:28 +0300 Subject: [PATCH 044/233] few more tests Signed-off-by: raver119 --- .../layers_tests/GraphExecutorTests.cpp | 50 ++++++++++++++++++- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index 1dcc8c0abbe2..247adfd1e349 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -28,6 +28,8 @@ #include #include #include +#include +#include using namespace sd; using namespace sd::graph; @@ -49,8 +51,8 @@ TEST_F(GraphExecutorTests, test_execution_1) { // C graph.getVariableSpace()->putVariable(-3, 0, NDArrayFactory::create('c', {3}, {3, 3, 3})); - Node a("multiply", 10, {{-1, 0}, {-2, 0}}); - Node b("add", 20, {{10, 0}, {-3, 0}}); + Node a("multiply", "multiply_node", 10, {{-1, 0}, {-2, 0}}); + Node b("add", "add_node", 20, {{10, 0}, {-3, 0}}); graph.addNode(b); graph.addNode(a); @@ -58,4 +60,48 @@ TEST_F(GraphExecutorTests, test_execution_1) { auto result = graph.execute({}, {"add_node"}); ASSERT_EQ(1, result.size()); ASSERT_EQ(1, result.count("add_node")); +} + +TEST_F(GraphExecutorTests, test_placeholder_resolution_1) { + Graph graph; + + graph.addPlaceholder("input", 0, DataType::FLOAT32); + + graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); + + // this test must throw an exception, because input isn't resolved yet + ASSERT_ANY_THROW(graph.execute()); +} + +TEST_F(GraphExecutorTests, test_placeholder_resolution_2) { + Graph graph; + + graph.addPlaceholder("input", 0, DataType::FLOAT32); + + graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); + + auto result = graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"tanh_node"}); + +} + +TEST_F(GraphExecutorTests, test_output_resolution_1) { + Graph graph; + + graph.addPlaceholder("input", 0, DataType::FLOAT32); + + graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); + + // since we're requesting output of non-existent node - we expect exception + ASSERT_THROW(graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"pow_node"}), graph::unresolved_output_exception); +} + +TEST_F(GraphExecutorTests, test_input_resolution_1) { + Graph graph; + + graph.addPlaceholder("input", 0, DataType::FLOAT32); + + graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); + + // since we're trying to resolve non-existent placeholder - we expect exception + ASSERT_THROW(graph.execute({{"array", NDArrayFactory::create(0.5f)}}, {"tanh_node"}), graph::unresolved_input_exception); } \ No newline at end of file From 78b4068759412ea53280d1eeaa10d531c8774368 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 30 Mar 2020 09:37:49 +0300 Subject: [PATCH 045/233] nano rearrangement in tests Signed-off-by: raver119 --- .../layers_tests/GraphExecutorTests.cpp | 65 +---------------- .../tests_cpu/layers_tests/GraphTests2.cpp | 69 +++++++++++++++++++ 2 files changed, 70 insertions(+), 64 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index 247adfd1e349..71ca402cfb03 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -39,69 +39,6 @@ class GraphExecutorTests : public testing::Test { }; -TEST_F(GraphExecutorTests, test_execution_1) { - Graph graph; +TEST_F(GraphExecutorTests, test_basic_exec_1) { - // A - graph.getVariableSpace()->putVariable(-1, 0, NDArrayFactory::create('c', {3}, {1, 1, 1})); - - // B - graph.getVariableSpace()->putVariable(-2, 0, NDArrayFactory::create('c', {3}, {2, 2, 2})); - - // C - graph.getVariableSpace()->putVariable(-3, 0, NDArrayFactory::create('c', {3}, {3, 3, 3})); - - Node a("multiply", "multiply_node", 10, {{-1, 0}, {-2, 0}}); - Node b("add", "add_node", 20, {{10, 0}, {-3, 0}}); - - graph.addNode(b); - graph.addNode(a); - - auto result = graph.execute({}, {"add_node"}); - ASSERT_EQ(1, result.size()); - ASSERT_EQ(1, result.count("add_node")); -} - -TEST_F(GraphExecutorTests, test_placeholder_resolution_1) { - Graph graph; - - graph.addPlaceholder("input", 0, DataType::FLOAT32); - - graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); - - // this test must throw an exception, because input isn't resolved yet - ASSERT_ANY_THROW(graph.execute()); -} - -TEST_F(GraphExecutorTests, test_placeholder_resolution_2) { - Graph graph; - - graph.addPlaceholder("input", 0, DataType::FLOAT32); - - graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); - - auto result = graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"tanh_node"}); - -} - -TEST_F(GraphExecutorTests, test_output_resolution_1) { - Graph graph; - - graph.addPlaceholder("input", 0, DataType::FLOAT32); - - graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); - - // since we're requesting output of non-existent node - we expect exception - ASSERT_THROW(graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"pow_node"}), graph::unresolved_output_exception); -} - -TEST_F(GraphExecutorTests, test_input_resolution_1) { - Graph graph; - - graph.addPlaceholder("input", 0, DataType::FLOAT32); - - graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); - - // since we're trying to resolve non-existent placeholder - we expect exception - ASSERT_THROW(graph.execute({{"array", NDArrayFactory::create(0.5f)}}, {"tanh_node"}), graph::unresolved_input_exception); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp index 2496187c2471..468cf360884b 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp @@ -28,6 +28,8 @@ #include #include #include +#include +#include using namespace sd; using namespace sd::graph; @@ -57,4 +59,71 @@ TEST_F(GraphTests2, test_placeholder_1) { auto placeholders = graph.getPlaceholders(); ASSERT_EQ(1, placeholders.size()); ASSERT_EQ(placeholders[0], variable); +} + +TEST_F(GraphTests2, test_execution_1) { + Graph graph; + + // A + graph.getVariableSpace()->putVariable(-1, 0, NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // B + graph.getVariableSpace()->putVariable(-2, 0, NDArrayFactory::create('c', {3}, {2, 2, 2})); + + // C + graph.getVariableSpace()->putVariable(-3, 0, NDArrayFactory::create('c', {3}, {3, 3, 3})); + + Node a("multiply", "multiply_node", 10, {{-1, 0}, {-2, 0}}); + Node b("add", "add_node", 20, {{10, 0}, {-3, 0}}); + + graph.addNode(b); + graph.addNode(a); + + auto result = graph.execute({}, {"add_node"}); + ASSERT_EQ(1, result.size()); + ASSERT_EQ(1, result.count("add_node")); +} + +TEST_F(GraphTests2, test_placeholder_resolution_1) { + Graph graph; + + graph.addPlaceholder("input", 0, DataType::FLOAT32); + + graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); + + // this test must throw an exception, because input isn't resolved yet + ASSERT_ANY_THROW(graph.execute()); +} + +TEST_F(GraphTests2, test_placeholder_resolution_2) { + Graph graph; + + graph.addPlaceholder("input", 0, DataType::FLOAT32); + + graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); + + auto result = graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"tanh_node"}); + +} + +TEST_F(GraphTests2, test_output_resolution_1) { + Graph graph; + + graph.addPlaceholder("input", 0, DataType::FLOAT32); + + graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); + + // since we're requesting output of non-existent node - we expect exception + ASSERT_THROW(graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"pow_node"}), graph::unresolved_output_exception); +} + +TEST_F(GraphTests2, test_input_resolution_1) { + Graph graph; + + graph.addPlaceholder("input", 0, DataType::FLOAT32); + + graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); + + // since we're trying to resolve non-existent placeholder - we expect exception + ASSERT_THROW(graph.execute({{"array", NDArrayFactory::create(0.5f)}}, {"tanh_node"}), graph::unresolved_input_exception); } \ No newline at end of file From f79d1c705d52223f7fa9afe8e3f03d0d18ed542e Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 30 Mar 2020 13:13:43 +0300 Subject: [PATCH 046/233] GraphExecutor skeleton Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 2 +- libnd4j/include/graph/OptimizedGraph.h | 11 ++++- .../include/graph/execution/GraphExecutor.h | 42 ++++++++++++++++- libnd4j/include/graph/execution/OpSequence.h | 13 ++++-- .../graph/execution/impl/GraphExecutor.cpp | 45 ++++++++++++++++++- .../graph/execution/impl/OpSequence.cpp | 12 ++++- libnd4j/include/graph/impl/Graph.cpp | 2 +- libnd4j/include/graph/impl/OptimizedGraph.cpp | 5 +++ .../layers_tests/GraphExecutorTests.cpp | 2 +- 9 files changed, 123 insertions(+), 11 deletions(-) diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 0ec3220a335a..7b78026a5a31 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -123,7 +123,7 @@ namespace sd { * This method returns pointer to thread_local VariableSpace * @return */ - sd::graph::VariableSpace *getVariableSpace(); + sd::graph::VariableSpace *getVariableSpace() const; /** * This method adds given node to the graph diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index 07eca9b22a61..c2893a8c199c 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -29,6 +29,8 @@ namespace sd { namespace graph { + class Graph; + /** * This class acts as a topologically sorted & optimized Graph representation, ready for execution */ @@ -39,7 +41,8 @@ namespace sd { // on each layer we can have 1+ OpSequences that can be executed independent std::map _onion; - GraphMemoryManager *_memoryManager; + GraphMemoryManager *_memoryManager = nullptr; + Graph *_originalGraph = nullptr; std::mutex _mutex; public: @@ -83,6 +86,12 @@ namespace sd { * @return */ const GraphMemoryManager& memoryManager() const; + + /** + * This method returns pointer to original Graph + * @return + */ + const Graph& originalGraph() const; }; } } diff --git a/libnd4j/include/graph/execution/GraphExecutor.h b/libnd4j/include/graph/execution/GraphExecutor.h index 97510a3d1578..c11d40716a7e 100644 --- a/libnd4j/include/graph/execution/GraphExecutor.h +++ b/libnd4j/include/graph/execution/GraphExecutor.h @@ -27,14 +27,54 @@ namespace sd { namespace graph { + class Graph; + class SD_EXPORT GraphExecutor { protected: + virtual Context prepareContext(ContextPrototype *contextPrototype, VariableSpace &variableSpace) const; + + /* + * preprocessor call involves: + * - ensure all inputs reside in HOT memory zone + * - shape function call + * - open workspace + */ + virtual Nd4jStatus preprocess(sd::ops::DeclarableOp *op, Context &context) const; + + /** + * postporcessor call involves: + * - remove all inputs that are not going to be used later from HOT memory zone + * - close workspace + * @return + */ + virtual Nd4jStatus postprocess(sd::ops::DeclarableOp *op, Context *context) const; public: GraphExecutor() = default; virtual ~GraphExecutor() = default; - virtual Nd4jStatus execute(const OptimizedGraph &graph) const ; + /** + * This method executes OptimizedGraph instance + * @param graph + * @return + */ + virtual Nd4jStatus execute(const OptimizedGraph &graph) const; + + /** + * This method executes OpSequence + * @param sequence + * @param deviceId - this argument allows to override device affinity specified in OpSequence, keep it < 0 to follow OpSequence + * @return + */ + virtual Nd4jStatus execute(const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId = -1) const; + + /** + * This method executes given op + * @param op + * @param contextPrototype + * @return + */ + virtual Nd4jStatus execute(sd::ops::DeclarableOp *op, ContextPrototype *contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId) const; }; } } diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index 3c6e163933ed..7fb1a333d55c 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -36,9 +36,11 @@ namespace sd { protected: // main thing here. sorted list of operations and their contexts std::vector> _ops; + + int _deviceId = 0; public: - explicit OpSequence(const std::vector> &ops); - OpSequence() = default; + explicit OpSequence(const std::vector> &ops, const int deviceId = 0); + OpSequence(const int deviceId = 0); ~OpSequence() = default; OpSequence(const OpSequence& other) noexcept; @@ -51,6 +53,9 @@ namespace sd { // move assignment operator OpSequence& operator=(OpSequence&& other) noexcept; + + int deviceId() const; + /** * This method returns number of individual operations within this sequence * @return @@ -77,8 +82,8 @@ namespace sd { * @return */ - iterator begin(); - iterator end(); + OpSequence::iterator begin(); + OpSequence::iterator end(); // additional private section private: diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 8c0124b5af28..2db06f22f073 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -19,11 +19,54 @@ // #include +#include namespace sd { namespace graph { + Context GraphExecutor::prepareContext(ContextPrototype *contextPrototype, VariableSpace &variableSpace) const { + // TODO: maybe we'll want to do something here? + return Context(contextPrototype, &variableSpace); + } + + Nd4jStatus GraphExecutor::preprocess(sd::ops::DeclarableOp *op, Context &context) const { + return Status::OK(); + } + + Nd4jStatus GraphExecutor::postprocess(sd::ops::DeclarableOp *op, Context *context) const { + return Status::OK(); + } + + + Nd4jStatus GraphExecutor::execute(sd::ops::DeclarableOp *op, ContextPrototype *contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId) const { + auto ctx = prepareContext(contextPrototype, *graph.originalGraph().getVariableSpace()); + return op->execute(&ctx); + } + + Nd4jStatus GraphExecutor::execute(const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId) const { + /* + * this is a basic implementation that works without dispatching etc + */ + for (int e = 0; e < sequence.length(); e++) { + auto v = sequence[e]; + auto result = execute(v.first, v.second, sequence, graph, deviceId >= 0 ? deviceId : sequence.deviceId()); + if (result != Status::OK()) + return result; + } + + return Status::OK(); + } + Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph) const { - throw std::runtime_error("GraphExecutor::execute - Not implemented yet"); + /* + * this is a basic exection logic: roll through layers and sequences and execute them one by one sequentially + */ + for (uint64_t l = 0; l < graph.layers(); l++) { + auto layer = graph.layer(l); + + for (uint64_t o = 0; layer.width(); o++) { + execute(layer[o], graph); + } + } } } } diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 0cc41d0f3e9b..d6edbe1a08d1 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -22,7 +22,13 @@ namespace sd { namespace graph { - OpSequence::OpSequence(const std::vector> &ops) { + OpSequence::OpSequence(const int deviceId) : _deviceId(deviceId) { + // + } + + OpSequence::OpSequence(const std::vector> &ops, const int deviceId) { + _deviceId = deviceId; + for (const auto v : ops) _ops.emplace_back(v); } @@ -57,6 +63,10 @@ namespace sd { return *this; } + int OpSequence::deviceId() const { + return _deviceId; + } + std::pair OpSequence::at(uint64_t index) const { return _ops[index]; } diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index de4d1ab02114..88372efa455d 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -282,7 +282,7 @@ namespace sd { _onion->insert(pair); } - VariableSpace * Graph::getVariableSpace() { + VariableSpace * Graph::getVariableSpace() const { return _variableSpace; } diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index ed7393c22ca4..aa49db24f641 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -19,6 +19,7 @@ // #include +#include namespace sd { namespace graph { @@ -81,5 +82,9 @@ namespace sd { const GraphMemoryManager &OptimizedGraph::memoryManager() const { return *_memoryManager; } + + const Graph &OptimizedGraph::originalGraph() const { + return *_originalGraph; + } } } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index 71ca402cfb03..909073545d8e 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -40,5 +40,5 @@ class GraphExecutorTests : public testing::Test { }; TEST_F(GraphExecutorTests, test_basic_exec_1) { - + GraphExecutor executor; } \ No newline at end of file From b2194f11e91d12e5ee45b073cee50692c5755102 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 30 Mar 2020 15:47:05 +0300 Subject: [PATCH 047/233] GraphExecutor skeleton 2 Signed-off-by: raver119 --- libnd4j/include/graph/Context.h | 7 ++++++- libnd4j/include/graph/execution/OpSequence.h | 6 ++++++ .../graph/execution/impl/GraphExecutor.cpp | 20 +++++++++++++++++++ .../graph/execution/impl/OpSequence.cpp | 5 +++++ libnd4j/include/graph/impl/Context.cpp | 7 ++++++- libnd4j/include/ops/declarable/DeclarableOp.h | 10 +++++----- 6 files changed, 48 insertions(+), 7 deletions(-) diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index f5d17f1e3b41..54842e10ee6c 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -46,7 +47,9 @@ namespace sd { */ class SD_EXPORT Context : public sd::graph::ContextPrototype { protected: + sd::graph::GraphMemoryManager *_memoryManager = nullptr; sd::memory::Workspace* _workspace = nullptr; + sd::graph::VariableSpace* _variableSpace = nullptr; std::pair _executionTime; sd::random::RandomBuffer* _rng = nullptr; @@ -73,7 +76,7 @@ namespace sd { // special flag used during conversion from Graph exec to FastPath exec bool _forbidFastPath = false; public: - Context(ContextPrototype* prototype, VariableSpace* variableSpace); + Context(ContextPrototype* prototype, VariableSpace* variableSpace, GraphMemoryManager *memoryManager = nullptr); explicit Context(int nodeId, VariableSpace *variableSpace = nullptr); Context(int nodeId, VariableSpace *variableSpace, bool isInplace); @@ -233,6 +236,8 @@ namespace sd { bool isTraining(); bool isInference(); + + const GraphMemoryManager& memoryManager() const; }; } } diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index 7fb1a333d55c..23d8dda366c3 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -56,6 +56,12 @@ namespace sd { int deviceId() const; + /** + * This method blocks until all operations within sequence are processed + * @return + */ + Nd4jStatus wait() const; + /** * This method returns number of individual operations within this sequence * @return diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 2db06f22f073..ae6df0538d2f 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -29,6 +29,13 @@ namespace sd { } Nd4jStatus GraphExecutor::preprocess(sd::ops::DeclarableOp *op, Context &context) const { + // time to allocate outputs, if that's not inplace op + // inplace case is covered there + op->prepareOutputs(context); + + // once prepareOutputs method was called - we don't need shape function anymore + context.setShapeFunctionOverride(true); + return Status::OK(); } @@ -57,16 +64,29 @@ namespace sd { } Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph) const { + const auto numDevices = AffinityManager::numberOfDevices(); + /* * this is a basic exection logic: roll through layers and sequences and execute them one by one sequentially */ + Nd4jStatus result = Status::OK(); for (uint64_t l = 0; l < graph.layers(); l++) { auto layer = graph.layer(l); for (uint64_t o = 0; layer.width(); o++) { execute(layer[o], graph); } + + // optionally block until all sequences in this layer processed + if (layer.width() > 0 && numDevices > 1) + for (uint64_t o = 0; layer.width(); o++) { + result = layer[o].wait(); + if (result != Status::OK()) + return result; + } } + + return result; } } } diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index d6edbe1a08d1..7c2a0dda5d80 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -114,5 +114,10 @@ namespace sd { bool OpSequence::iterator::operator!=(const OpSequence::iterator &other) const { return _position != other._position; } + + Nd4jStatus OpSequence::wait() const { + // TODO: to be implemented + return Status::OK(); + } } } diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index 954329f4290d..1c0f944d7c39 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -26,7 +26,8 @@ namespace sd { namespace graph { - Context::Context(ContextPrototype* prototype, VariableSpace* variableSpace) { + Context::Context(ContextPrototype* prototype, VariableSpace* variableSpace, GraphMemoryManager *memoryManager) { + _memoryManager = memoryManager; _variableSpace = variableSpace; _dataType = prototype->dataType(); @@ -585,6 +586,10 @@ namespace sd { _handles.clear(); } + + const GraphMemoryManager &Context::memoryManager() const { + return *_memoryManager; + } } } diff --git a/libnd4j/include/ops/declarable/DeclarableOp.h b/libnd4j/include/ops/declarable/DeclarableOp.h index cbdd500fcf66..ac9a106c9994 100644 --- a/libnd4j/include/ops/declarable/DeclarableOp.h +++ b/libnd4j/include/ops/declarable/DeclarableOp.h @@ -108,11 +108,6 @@ namespace sd { sd::NDArray* getZ(Context& block, int inputId = 0); sd::NDArray* getNullifiedZ(Context& block, int inputId = 0); - /** - * This method pre-allocates NDArrays for Op output, in case they are not available at op execution time - */ - int prepareOutputs(Context& block); - virtual samediff::EmptyHandling emptyHandling(); public: // for special cases, like BooleanOps @@ -215,6 +210,11 @@ namespace sd { // this method checks if number of available arguments matches op expectations Nd4jStatus validateArguments(Context& block); + + /** + * This method pre-allocates NDArrays for Op output, in case they are not available at op execution time + */ + int prepareOutputs(Context& block); }; } } From 8122ab895a0b762e6e4dff1f864f7b8ba15ed294 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 30 Mar 2020 16:23:01 +0300 Subject: [PATCH 048/233] GraphExecutor skeleton 3 Signed-off-by: raver119 --- libnd4j/include/graph/execution/GraphExecutor.h | 2 +- libnd4j/include/graph/execution/impl/GraphExecutor.cpp | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/libnd4j/include/graph/execution/GraphExecutor.h b/libnd4j/include/graph/execution/GraphExecutor.h index c11d40716a7e..01b448d12a2b 100644 --- a/libnd4j/include/graph/execution/GraphExecutor.h +++ b/libnd4j/include/graph/execution/GraphExecutor.h @@ -31,7 +31,7 @@ namespace sd { class SD_EXPORT GraphExecutor { protected: - virtual Context prepareContext(ContextPrototype *contextPrototype, VariableSpace &variableSpace) const; + virtual Context prepareContext(ContextPrototype *contextPrototype, VariableSpace &variableSpace, const GraphMemoryManager &memoryManager) const; /* * preprocessor call involves: diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index ae6df0538d2f..604720789b0a 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -23,9 +23,9 @@ namespace sd { namespace graph { - Context GraphExecutor::prepareContext(ContextPrototype *contextPrototype, VariableSpace &variableSpace) const { + Context GraphExecutor::prepareContext(ContextPrototype *contextPrototype, VariableSpace &variableSpace, const GraphMemoryManager &memoryManager) const { // TODO: maybe we'll want to do something here? - return Context(contextPrototype, &variableSpace); + return Context(contextPrototype, &variableSpace, const_cast(&memoryManager)); } Nd4jStatus GraphExecutor::preprocess(sd::ops::DeclarableOp *op, Context &context) const { @@ -45,7 +45,7 @@ namespace sd { Nd4jStatus GraphExecutor::execute(sd::ops::DeclarableOp *op, ContextPrototype *contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId) const { - auto ctx = prepareContext(contextPrototype, *graph.originalGraph().getVariableSpace()); + auto ctx = prepareContext(contextPrototype, *graph.originalGraph().getVariableSpace(), graph.memoryManager()); return op->execute(&ctx); } From 2dd19a5d89e2598f848a18e2ce21cbf8fa827167 Mon Sep 17 00:00:00 2001 From: Oleg Date: Mon, 30 Mar 2020 18:11:04 +0300 Subject: [PATCH 049/233] libnd4j raw implementation of topological sort Signed-off-by: Oleg --- libnd4j/include/graph/Graph.h | 11 +++++ libnd4j/include/graph/impl/Graph.cpp | 61 +++++++++++++++++++++++++--- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 7b78026a5a31..ba81d066421b 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -261,6 +261,17 @@ namespace sd { * @return */ std::map execute(const std::map &dictionary = {}, const std::vector &outputs = {}, const GraphExecutor &executor = GraphExecutor()) const; +protected: + /* + * Topological graph analysis + * @param const start node for search + * @param const reference to list of nodes without external inputs + * @param const node positions in _handler + * @param operation gather + * @return stop iterating + */ + bool topolSearch(const int startNode, const std::set& nodeBranches, + const std::unordered_map& positions, OpSequence& opSeq) const; }; FORCEINLINE std::vector* Graph::nodes() { diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 88372efa455d..63f83673b1c5 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -1710,16 +1710,65 @@ namespace sd { _variableSpace->putVariable(id, var); } - - OptimizedGraph Graph::optimizedGraph() const { - // TODO: implement this method - return OptimizedGraph(const_cast(_memoryMaager)); - } - std::map Graph::execute(const std::map &dictionary, const std::vector &outputs, const GraphExecutor &executor) const { // TODO: implement this method return std::map(); } + + bool Graph::topolSearch(const int startNode, const std::set& nodeBranches, const std::unordered_map& positions, OpSequence& opSeq) const { + + if (nodeBranches.empty() || _handles.empty()) + return false; + + for (const auto& itNodes : nodeBranches) { + + auto position = positions.find(itNodes); + if (position != positions.end() && startNode != itNodes) { + for (auto in = _handles[position->second]->input()->begin(); in != _handles[position->second]->input()->end(); ++in) { + if (startNode == in->first) { + opSeq.append(_handles[position->second]->getCustomOp(), _handles[position->second]->getContextPrototype()); + return topolSearch(itNodes, nodeBranches, positions, opSeq); + } + } + } + } + + return true; + } + + OptimizedGraph Graph::optimizedGraph() const { + + OptimizedGraph optGraf(const_cast(_memoryMaager)); + OpSequence opSeq; + std::set nodesMap, startNodes; + std::unordered_map iDpositions; + + for (int i = 0; i < _handles.size(); ++i) { + auto ID = _handles[i]->id(); + int iNcouts = 0; + for (auto in = _handles[i]->input()->begin(); in != _handles[i]->input()->end(); ++in) { + if (in->first < 0) { + iNcouts++; + } + } + + if (iNcouts == _handles[i]->input()->size()) + startNodes.insert(ID); + + nodesMap.insert(ID); + iDpositions[ID] = i; + } + + for (const auto& start : startNodes) { + + auto position = iDpositions.find(start); + opSeq.append(_handles[position->second]->getCustomOp(), _handles[position->second]->getContextPrototype()); + topolSearch(start, nodesMap, iDpositions, opSeq); + } + optGraf.append(opSeq); + return optGraf; + } + } } From 079e4f247a982eee994fc99e1fe14b61f2743492 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 31 Mar 2020 07:45:51 +0300 Subject: [PATCH 050/233] some more rearrangements Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 11 +- libnd4j/include/graph/Node.h | 14 +- .../graph/execution/impl/GraphExecutioner.cpp | 8 +- libnd4j/include/graph/impl/Graph.cpp | 50 ++++-- libnd4j/include/graph/impl/Node.cpp | 162 +++++++++++------- .../include/graph/logic/impl/LogicEnter.cpp | 4 +- .../graph/logic/impl/LogicExecutor.cpp | 4 +- .../include/graph/logic/impl/LogicMerge.cpp | 6 +- .../graph/logic/impl/LogicNextIteration.cpp | 2 +- .../include/graph/logic/impl/LogicWhile.cpp | 4 +- libnd4j/tests_cpu/layers_tests/NodeTests.cpp | 4 +- 11 files changed, 166 insertions(+), 103 deletions(-) diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index ba81d066421b..a75efc4bbefb 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -126,13 +126,20 @@ namespace sd { sd::graph::VariableSpace *getVariableSpace() const; /** - * This method adds given node to the graph - * + * These methods add given node to the graph + * FIXME: deprecated * @param node */ void addNode(sd::graph::Node *node); void addNode(const sd::graph::Node &node); + /** + * These methods add given node to the graph + * @param node + */ + void addNode(Node &node, const std::vector &inputs); + void addNode(Node &node, const std::vector> &inputs); + /** * This method allows to add placeholder with some pre-defined properties */ diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 1f2077b9e4aa..dafeaaebb98a 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -99,12 +99,18 @@ namespace sd { Nd4jLong _frameId = -1; public: + + explicit Node(const std::string &nodeName, const std::string &opName, const std::vector &tArgs = {}, const std::vector &iArgs = {}, const std::vector &bArgs = {}, const std::vector &dArgs = {}); + explicit Node(const sd::graph::FlatNode *node); + ~Node(); + + /* + * FIXME: deprecated methods, to be removed + */ explicit Node(const std::string &opName, const std::string &nodeName, const int id, const std::vector &inputs = {}, const std::vector &tArgs = {}, const std::vector &iArgs = {}); explicit Node(const std::string &opName, const int id = 0, const std::vector> &inputs = {}, const std::vector &tArgs = {}, const std::vector &iArgs = {}); explicit Node(sd::ops::DeclarableOp *customOp, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); explicit Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); - explicit Node(const sd::graph::FlatNode *node); - ~Node(); bool equals(Node *other); @@ -167,8 +173,8 @@ namespace sd { void setName(std::string *name); void setName(const std::string& name); - std::string * getName(); - std::string * name(); + const std::string& getName() const; + const std::string& name() const; int totalReferences(); void addReference(int nodeId); diff --git a/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp b/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp index b38fd3ad71be..d2284455b3d2 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp @@ -259,7 +259,7 @@ namespace sd{ Node* node = graph->getOnion()->at(l)->at(n); if (Environment::getInstance()->isProfiling()) - flowPath->profile()->nodeById(node->id(), node->name()->c_str()); + flowPath->profile()->nodeById(node->id(), node->name().c_str()); if (lastId != node->id() && Environment::getInstance()->isProfiling()) { if (lastId != -10000000) @@ -269,7 +269,7 @@ namespace sd{ nodeTime = GraphProfile::currentTime(); } - nd4j_debug("Step: %lld; Node: %i <%s>\n", exec_counter, node->id(), node->name()->c_str()); + nd4j_debug("Step: %lld; Node: %i <%s>\n", exec_counter, node->id(), node->name().c_str()); // on first non-Exit node after loop we can rewind (if planned) if (!(node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Exit)) { @@ -552,8 +552,8 @@ namespace sd{ continue; auto pair = CreateLongPair(builder, flowPath.outerTime(node->id()), flowPath.innerTime(node->id())); - if (node->getName() != nullptr) { - auto name = builder.CreateString(node->getName()->c_str()); + if (!node->name().empty()) { + auto name = builder.CreateString(node->getName().c_str()); auto fr = CreateFlatTiming(builder, node->id(), name, pair); timings_vector.push_back(fr); } else { diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 63f83673b1c5..7650354599aa 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -308,6 +308,14 @@ namespace sd { delete _configuration; } + void Graph::addNode(Node &node, const std::vector &inputs) { + throw std::runtime_error("not implemented yet"); + } + + void Graph::addNode(Node &node, const std::vector> &inputs) { + throw std::runtime_error("not implemented yet"); + } + void Graph::addNode(const sd::graph::Node &node) { node.markRemovable(false); addNode(const_cast(&node)); @@ -320,16 +328,16 @@ namespace sd { // nd4j_debug("Adding LogicOp [%i]\n", node->opNum()); // SCOPE if (node->opNum() == logic::Scope) { - auto scope = new Scope(node->id(), node->getName() != nullptr ? node->getName()->c_str() : ""); + auto scope = new Scope(node->id(), !node->getName().empty() ? node->getName().c_str() : ""); _mappedScopes[node->id()] = scope; _scopes.push_back(scope); } } - auto cname = node->getName() == nullptr ? nullptr : node->getName()->c_str(); + auto cname = node->name().empty() ? nullptr : node->getName().c_str(); auto nodeState = _variableSpace->hasVariable(node->id()) ? _variableSpace->getVariable(node->id()) :new Variable(nullptr, cname, node->id()); - if (node->getName() != nullptr) - nodeState->setName(*node->getName()); + if (!node->name().empty()) + nodeState->setName(node->name()); if (node->isInplace()) @@ -420,7 +428,7 @@ namespace sd { // nd4j_logger("Adding auto output variable; Output size: %i\n", node->output()->size()); var->setId(node->id()); - var->setName(*node->getName()); + var->setName(node->getName()); _variableSpace->putOutputVariable(var); //node->pickExternalOutput(var->id()); @@ -558,10 +566,10 @@ namespace sd { // single-input node if (node->input()->size() == 1) { - if (node->getName() == nullptr) { + if (!node->name().empty()) { nd4j_debug("Trying SI Node_%i\n", node->id()); } else { - nd4j_debug("Trying SI Node_%i:[%s]\n", node->id(), node->getName()->c_str()); + nd4j_debug("Trying SI Node_%i:[%s]\n", node->id(), node->getName().c_str()); } int iNode = node->input()->at(0).first; @@ -628,11 +636,11 @@ namespace sd { queue.emplace_back(node->id()); } else { // multi-input node - if (node->getName() == nullptr) { + if (!node->name().empty()) { nd4j_debug("Trying MI Node_%i\n", node->id()); } else { - std::string np = *(node->getName()); - nd4j_debug("Trying MI Node_%i:[%s]\n", node->id(), node->getName()->c_str()); + auto np = node->name(); + nd4j_debug("Trying MI Node_%i:[%s]\n", node->id(), node->getName().c_str()); } int maxLayer = 0; @@ -823,8 +831,8 @@ namespace sd { continue; Node* node = _mapped->at(v); - if (node->name() != nullptr) { - nd4j_debug("Node %i; Name: [%s]\n", v, node->name()->c_str()); + if (!node->name().empty()) { + nd4j_debug("Node %i; Name: [%s]\n", v, node->name().c_str()); } else { nd4j_debug("Node %i\n", v); } @@ -849,8 +857,8 @@ namespace sd { Node* node = _mapped->at(v); if (!node->hasInternalOutputs()) { - if (node->name() != nullptr) { - nd4j_debug("Output node found: [%i:<%s>]\n", v, node->name()->c_str()); + if (!node->name().empty()) { + nd4j_debug("Output node found: [%i:<%s>]\n", v, node->name().c_str()); } else { nd4j_debug("Output node found: [%i]\n", v); } @@ -860,7 +868,7 @@ namespace sd { if (std::find(_output.begin(), _output.end(), node->id()) == _output.end()) _output.emplace_back(node->id()); } else if (Environment::getInstance()->isDebugAndVerbose()) { - nd4j_debug("Node [%i:<%s>] has %i outputs announced:\n", v, node->name()->c_str(), node->output()->size()); + nd4j_debug("Node [%i:<%s>] has %i outputs announced:\n", v, node->name().c_str(), node->output()->size()); printf("{"); for (auto s : *node->output()) { printf("[%i:%i], ", s.first, s.second); @@ -1193,7 +1201,9 @@ namespace sd { for (int n = 0; n < layerSize; n++) { Node* node = _onion->at(l)->at(n); - if (node->name() == nullptr) continue; + if (node->name().empty()) + continue; + sd::ops::OpDescriptor* pOpDescriptor = nullptr; std::string opNameStr; //node->name(); int numInputs = 0; @@ -1241,7 +1251,9 @@ namespace sd { for (int n = 0; n < scope->nodes()->size(); n++) { Node* node = scope->nodes()->at(n); //printOutNode(node); - if (node->name() == nullptr) continue; + if (node->name().empty()) + continue; + std::string opNameStr; //node->name(); sd::ops::OpDescriptor* pOpDescriptor = nullptr; int numInputs = 0; @@ -1448,8 +1460,8 @@ namespace sd { Node *node = v.second; // optional part: node names - if (!node->name()->empty()) { - localStamp += *(node->name()); + if (!node->name().empty()) { + localStamp += node->name(); } } diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index bd84daf85c3a..0b7a675d3187 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -44,29 +44,67 @@ namespace sd { namespace graph { - void sd::graph::Node::setOuterTime(Nd4jLong time){ + Node::Node(const std::string &nodeName, const std::string &opName, const std::vector &tArgs, + const std::vector &iArgs, const std::vector &bArgs, + const std::vector &dArgs) { + + auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName); + + this->_opType = OpType_CUSTOM; + this->_opNum = customOp->getOpHash(); + this->_extraParams = nullptr; + this->_dataType = sd::DataType::FLOAT32; // float as default + this->_dim = nullptr; + this->_customOp = customOp; + + _hasExternalInputs = false; + _hasExternalOutputs = false; + _hasInternalInputs = false; + _hasInternalOutputs = false; + + // FIXME: get rid of this!!! + _scalar = NDArrayFactory::create(0); + + auto block = new ContextPrototype(this->getCustomOp()->getOpDescriptor(), this->id(), false); + + for (auto v: iArgs) + block->getIArguments()->emplace_back(v); + + for (auto v: tArgs) + block->getTArguments()->emplace_back(v); + + for (auto v: bArgs) + block->getBArguments()->emplace_back(v); + + for (auto v: dArgs) + block->getDArguments()->emplace_back(v); + + this->setContextPrototype(block); + } + + void Node::setOuterTime(Nd4jLong time){ // if (hasBlockAttached()) // _block->setOuterTime(time); } - void sd::graph::Node::setInnerTime(Nd4jLong time){ + void Node::setInnerTime(Nd4jLong time){ // if (hasBlockAttached()) // _block->setInnerTime(time); } - void sd::graph::Node::setGraph(sd::graph::Graph* graph) { + void Node::setGraph(Graph* graph) { _graph = graph; } - sd::graph::Graph* sd::graph::Node::getGraph() { + Graph* Node::getGraph() { return _graph; } - bool sd::graph::Node::hasGraphEmbedded() { + bool Node::hasGraphEmbedded() { return _graph != nullptr; } - void sd::graph::Node::markInplace(bool reallyInplace) { + void Node::markInplace(bool reallyInplace) { _isInplace = reallyInplace; if (_protoContext != nullptr) { _protoContext->markInplace(reallyInplace); @@ -81,19 +119,19 @@ namespace sd { _removable = reallyRemovable; } - OpClass sd::graph::Node::getOpClass() { + OpClass Node::getOpClass() { return _opClass; } - bool sd::graph::Node::hasBlockAttached() { + bool Node::hasBlockAttached() { return _protoContext != nullptr; } - bool sd::graph::Node::isInplace() { + bool Node::isInplace() { return _isInplace; } - bool sd::graph::Node::isDivergencePoint() { + bool Node::isDivergencePoint() { if (hasCustomOp()) { return _customOp->getOpDescriptor()->isDivergent(); } else if (opType() == OpType_LOGIC && opNum() == 30) @@ -102,11 +140,11 @@ namespace sd { return false; } - void sd::graph::Node::setActive(bool reallyActive) { + void Node::setActive(bool reallyActive) { _active = reallyActive; } - bool sd::graph::Node::isActive() { + bool Node::isActive() { return _active; } @@ -118,7 +156,7 @@ namespace sd { _frameId = frameId; } - ContextPrototype * sd::graph::Node::getContextPrototype() { + ContextPrototype * Node::getContextPrototype() { if (_protoContext == nullptr) _protoContext = new ContextPrototype(this->getCustomOp() != nullptr ? this->getCustomOp()->getOpDescriptor() : nullptr, this->id()); if (_protoContext->inputs()->empty()) { @@ -129,22 +167,22 @@ namespace sd { return _protoContext; } - void sd::graph::Node::setContextPrototype(ContextPrototype *block) { + void Node::setContextPrototype(ContextPrototype *block) { if (_protoContext != nullptr) throw std::runtime_error("Block already exists"); _protoContext = block; } - void sd::graph::Node::setId(int id) { + void Node::setId(int id) { _id = id; } - sd::ops::DeclarableOp* sd::graph::Node::getCustomOp() { + sd::ops::DeclarableOp* Node::getCustomOp() { return _customOp; } - void sd::graph::Node::setCustomOp(sd::ops::DeclarableOp *customOp) { + void Node::setCustomOp(sd::ops::DeclarableOp *customOp) { _customOp = customOp; // divergent ops (Switch etc) are always inplace, they don't allocate anything @@ -152,31 +190,31 @@ namespace sd { _isInplace = true; } - bool sd::graph::Node::hasCustomOp() { + bool Node::hasCustomOp() { return _customOp != nullptr; } - std::string * sd::graph::Node::name() { + const std::string & Node::name() const { return this->getName(); } - std::string * sd::graph::Node::getName() { - return &_name; + const std::string & Node::getName() const { + return _name; } - void sd::graph::Node::setName(const std::string& name) { - _name = name.c_str(); + void Node::setName(const std::string& name) { + _name = name; } - void sd::graph::Node::setName(std::string *name) { + void Node::setName(std::string *name) { _name = *name; } - double sd::graph::Node::scalar() { + double Node::scalar() { return _scalar.e(0); }; - void sd::graph::Node::pickInput(std::pair& pair) { + void Node::pickInput(std::pair& pair) { _input.push_back(pair); } @@ -184,12 +222,12 @@ namespace sd { throw std::runtime_error("Node::pickInput - Not implemented yet"); } - void sd::graph::Node::pickInput(int inputId, int outputId) { + void Node::pickInput(int inputId, int outputId) { std::pair p(inputId,outputId); pickInput(p); } - void sd::graph::Node::pickInput(int inputId) { + void Node::pickInput(int inputId) { pickInput(inputId, 0); if (inputId < 0) @@ -198,25 +236,25 @@ namespace sd { _hasInternalInputs = true; } - void sd::graph::Node::pickExternalOutput(int outputId) { + void Node::pickExternalOutput(int outputId) { std::pair pair(outputId, 0); _output.push_back(pair); _hasExternalOutputs = true; } - void sd::graph::Node::pickOutputOnce(int outputId) { + void Node::pickOutputOnce(int outputId) { std::pair pair(outputId, 0); if (std::find(_output.begin(), _output.end(), pair) == _output.end()) pickOutput(outputId); } - void sd::graph::Node::pickOutput(int nodeId, int outputId) { + void Node::pickOutput(int nodeId, int outputId) { std::pair pair(nodeId, outputId); _output.emplace_back(pair); } - void sd::graph::Node::pickOutput(int outputId) { + void Node::pickOutput(int outputId) { std::pair pair(outputId, 0); _output.emplace_back(pair); @@ -226,47 +264,47 @@ namespace sd { _hasInternalOutputs = true; } - int * sd::graph::Node::getDimensionsPtr() { + int * Node::getDimensionsPtr() { return _dim; } - std::vector * sd::graph::Node::getDimensions() { + std::vector * Node::getDimensions() { return &_dimensions; } - int sd::graph::Node::getLayer() { + int Node::getLayer() { return _layer; } - void sd::graph::Node::setLayer(int layer) { + void Node::setLayer(int layer) { _layer = layer; } - bool sd::graph::Node::hasExternalOutputs() { + bool Node::hasExternalOutputs() { return _hasExternalOutputs; } - bool sd::graph::Node::hasExternalInputs() { + bool Node::hasExternalInputs() { return _hasExternalInputs; } - bool sd::graph::Node::hasInternalOutputs() { + bool Node::hasInternalOutputs() { return _hasInternalOutputs; } - bool sd::graph::Node::hasInternalInputs() { + bool Node::hasInternalInputs() { return _hasInternalInputs; } - bool sd::graph::Node::isMultiInput() { + bool Node::isMultiInput() { return _input.size() > 1; } - bool sd::graph::Node::isMultiOutput() { + bool Node::isMultiOutput() { return _output.size() > 1; } - double * sd::graph::Node::extraParams() { + double * Node::extraParams() { return _extraParams; } @@ -278,23 +316,23 @@ namespace sd { _referencedBy.emplace_back(nodeId); } - sd::graph::OpType sd::graph::Node::opType() { + OpType Node::opType() { return _opType; } - int sd::graph::Node::id() { + int Node::id() { return _id; } - Nd4jLong sd::graph::Node::opNum() { + Nd4jLong Node::opNum() { return _opNum; } - std::vector> *sd::graph::Node::input() { + std::vector> *Node::input() { return &_input; } - std::vector> *sd::graph::Node::output() { + std::vector> *Node::output() { return &_output; } @@ -391,7 +429,7 @@ namespace sd { this->setContextPrototype(block); } - sd::graph::Node::Node(sd::ops::DeclarableOp *customOp, int id, std::initializer_list input, std::initializer_list output, std::initializer_list dimensions, float scalar, std::initializer_list tArgs, std::initializer_list iArgs) { + Node::Node(sd::ops::DeclarableOp *customOp, int id, std::initializer_list input, std::initializer_list output, std::initializer_list dimensions, float scalar, std::initializer_list tArgs, std::initializer_list iArgs) { this->_opType = OpType_CUSTOM; this->_id = id; this->_opNum = customOp->getOpHash(); @@ -437,11 +475,11 @@ namespace sd { this->setContextPrototype(block); } - void sd::graph::Node::setOpType(OpType opType) { + void Node::setOpType(OpType opType) { this->_opType = opType; } - sd::graph::Node::Node(OpType opType, int opNum, int id, std::initializer_list input, std::initializer_list output, std::initializer_list dimensions, float scalar, std::initializer_list tArgs, std::initializer_list iArgs) { + Node::Node(OpType opType, int opNum, int id, std::initializer_list input, std::initializer_list output, std::initializer_list dimensions, float scalar, std::initializer_list tArgs, std::initializer_list iArgs) { this->_opType = opType; this->_id = id; this->_opNum = opNum; @@ -534,7 +572,7 @@ namespace sd { } }; - sd::graph::Node::Node(const sd::graph::FlatNode *node) { + Node::Node(const FlatNode *node) { _hasExternalInputs = false; _hasExternalOutputs = false; _hasInternalInputs = false; @@ -549,7 +587,7 @@ namespace sd { this->_scope_name = node->scope_name()->str(); if (node->scalar() != nullptr) { - auto scalar = sd::graph::FlatUtils::fromFlatArray(node->scalar()); + auto scalar = FlatUtils::fromFlatArray(node->scalar()); _scalar = *scalar; delete scalar; } @@ -769,7 +807,7 @@ namespace sd { return _protoContext; } - sd::graph::Node::~Node() { + Node::~Node() { if (_extraParams != nullptr) delete[] _extraParams; @@ -784,31 +822,31 @@ namespace sd { } } - int sd::graph::Node::getRewindNode() { + int Node::getRewindNode() { return _rewindNode; } - void sd::graph::Node::setRewindNode(int nodeId) { + void Node::setRewindNode(int nodeId) { _rewindNode = nodeId; } - std::pair& sd::graph::Node::getRewindLayer() { + std::pair& Node::getRewindLayer() { return _rewindLayer; }; - void sd::graph::Node::setRewindLayer(int layerId, int stepId) { + void Node::setRewindLayer(int layerId, int stepId) { _rewindLayer.first = layerId; _rewindLayer.second = stepId; } - bool sd::graph::Node::equals(Node *other) { + bool Node::equals(Node *other) { if (_opType == other->_opType && _dataType == other->_dataType && _opNum == other->_opNum) return true; return false; } - void sd::graph::Node::deleteOpByType(OpType opType, void *op) { + void Node::deleteOpByType(OpType opType, void *op) { switch (opType) { case OpType_PAIRWISE: delete reinterpret_cast(op); @@ -872,7 +910,7 @@ namespace sd { } } - sd::ops::DeclarableOp* sd::graph::Node::buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar) { + sd::ops::DeclarableOp* Node::buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar) { switch (opType) { case OpType_PAIRWISE: return new sd::ops::LegacyPairwiseTransformOp(opNum); @@ -925,7 +963,7 @@ namespace sd { Node* Node::clone() { - if (this->_customOp && this->_opType == sd::graph::OpType_CUSTOM) { + if (this->_customOp && this->_opType == OpType_CUSTOM) { auto clone = new Node(this->_customOp, _id); clone->pullValues(this); return clone; diff --git a/libnd4j/include/graph/logic/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp index 1f9a973d88b2..c9e7b0c17e6f 100644 --- a/libnd4j/include/graph/logic/impl/LogicEnter.cpp +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -42,7 +42,7 @@ namespace sd { if (__variableSpace->hasVariable(node->id(), 0)) lvar = __variableSpace->getVariable(node->id(), 0); else - lvar = new Variable(nullptr, node->getName()->c_str(), node->id(), 0); + lvar = new Variable(nullptr, node->getName().c_str(), node->id(), 0); auto array = var->getNDArray(); lvar->setNDArray(array); @@ -54,7 +54,7 @@ namespace sd { if (__variableSpace->hasVariable(node->id(), 0)) lvar = __variableSpace->getVariable(node->id(), 0); else - lvar = new Variable(nullptr, node->getName()->c_str(), node->id(), 0); + lvar = new Variable(nullptr, node->getName().c_str(), node->id(), 0); auto list = var->getNDArrayList(); lvar->setNDArrayList(list); diff --git a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp index 2e4898da31e4..70195eb22be3 100644 --- a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp @@ -60,10 +60,10 @@ namespace sd { return LogicEnter::processNode(graph, node); } - if (node->getName() == nullptr) { + if (node->getName().empty()) { nd4j_printf("Unknown LogicOp used at node [%i]: [%i]\n", node->id(), node->opNum()); } else { - nd4j_printf("Unknown LogicOp used at node [%i:<%s>]: [%i]\n", node->id(), node->getName()->c_str(), node->opNum()); + nd4j_printf("Unknown LogicOp used at node [%i:<%s>]: [%i]\n", node->id(), node->getName().c_str(), node->opNum()); } return ND4J_STATUS_BAD_INPUT; } diff --git a/libnd4j/include/graph/logic/impl/LogicMerge.cpp b/libnd4j/include/graph/logic/impl/LogicMerge.cpp index 6d374005dee7..7118cd606418 100644 --- a/libnd4j/include/graph/logic/impl/LogicMerge.cpp +++ b/libnd4j/include/graph/logic/impl/LogicMerge.cpp @@ -64,7 +64,7 @@ namespace sd { if (__variableSpace->hasVariable(node->id(), 0)) lvar = __variableSpace->getVariable(node->id(), 0); else - lvar = new Variable(nullptr, node->getName()->c_str(), node->id(), 0); + lvar = new Variable(nullptr, node->getName().c_str(), node->id(), 0); // if (lvar->hasNDArray()) // delete lvar->getNDArray(); @@ -87,7 +87,7 @@ namespace sd { if (__variableSpace->hasVariable(node->id(), 0)) lvar = __variableSpace->getVariable(node->id(), 0); else - lvar = new Variable(nullptr, node->getName()->c_str(), node->id(), 0); + lvar = new Variable(nullptr, node->getName().c_str(), node->id(), 0); // if (lvar->hasNDArray()) // delete lvar->getNDArray(); @@ -113,7 +113,7 @@ namespace sd { if (__variableSpace->hasVariable(node->id(), 0)) lvar = __variableSpace->getVariable(node->id(), 0); else - lvar = new Variable(nullptr, node->getName()->c_str(), node->id(), 0); + lvar = new Variable(nullptr, node->getName().c_str(), node->id(), 0); if (lvar->hasNDArray()) delete lvar->getNDArray(); diff --git a/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp index 0765eb4ee783..9d3ae15cecff 100644 --- a/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp +++ b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp @@ -35,7 +35,7 @@ namespace sd { if (__variableSpace->hasVariable(node->id(), 0)) lvar = __variableSpace->getVariable(node->id(), 0); else - lvar = new Variable(nullptr, node->getName()->c_str(), node->id(), 0); + lvar = new Variable(nullptr, node->getName().c_str(), node->id(), 0); // if (lvar->hasNDArray()) // delete lvar->getNDArray(); diff --git a/libnd4j/include/graph/logic/impl/LogicWhile.cpp b/libnd4j/include/graph/logic/impl/LogicWhile.cpp index bd75c82eab8f..b2dd00589484 100644 --- a/libnd4j/include/graph/logic/impl/LogicWhile.cpp +++ b/libnd4j/include/graph/logic/impl/LogicWhile.cpp @@ -79,7 +79,7 @@ namespace sd { nd4j_debug("Falling back to logic\n",""); LogicExecutor::processNode(graph, v); } else { - nd4j_debug("Op [<%s>]\n", v->getName()->c_str()); + nd4j_debug("Op [<%s>]\n", v->getName().c_str()); Nd4jStatus status = GraphExecutioner::executeFlatNode(graph, v, __variableSpace); if (status != ND4J_STATUS_OK) return status; @@ -114,7 +114,7 @@ namespace sd { nd4j_debug("Falling back to logic\n",""); LogicExecutor::processNode(graph, v); } else { - nd4j_debug("Op [<%s>]\n", v->getName()->c_str()); + nd4j_debug("Op [<%s>]\n", v->getName().c_str()); //v->getBlock()->updateVariables(); Nd4jStatus status = GraphExecutioner::executeFlatNode(graph, v, __variableSpace); if (status != ND4J_STATUS_OK) diff --git a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp index 8f4a8ae70bde..5c6d5d652413 100644 --- a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp @@ -39,7 +39,7 @@ TEST_F(NodeTests, Test_Dtype_Conversion_1) { auto nf = nd->asT(); ASSERT_EQ(nodeA->id(), nf->id()); - ASSERT_EQ(*nodeA->name(), *nf->name()); + ASSERT_EQ(nodeA->name(), nf->name()); ASSERT_EQ(nodeA->getOpClass(), nf->getOpClass()); ASSERT_EQ(nodeA->opType(), nf->opType()); ASSERT_EQ(nodeA->opNum(), nf->opNum()); @@ -61,7 +61,7 @@ TEST_F(NodeTests, Test_Dtype_Conversion_2) { auto nf = nd->asT(); ASSERT_EQ(nodeA->id(), nf->id()); - ASSERT_EQ(*nodeA->name(), *nf->name()); + ASSERT_EQ(nodeA->name(), nf->name()); // ASSERT_EQ(nodeA->getOpClass(), nf->getOpClass()); ASSERT_EQ(nodeA->opType(), nf->opType()); ASSERT_EQ(nodeA->opNum(), nf->opNum()); From a5ad0a1dcf4ae686091cb2bed56ca3adb12fe535 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 31 Mar 2020 10:56:31 +0300 Subject: [PATCH 051/233] next step Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 36 +++++---- libnd4j/include/graph/Node.h | 3 +- libnd4j/include/graph/VariableProxy.h | 12 +-- libnd4j/include/graph/VariableSpace.h | 24 +++--- .../graph/execution/impl/GraphExecutioner.cpp | 2 +- libnd4j/include/graph/impl/Graph.cpp | 47 ++++++++++-- libnd4j/include/graph/impl/Node.cpp | 39 ++++++++++ libnd4j/include/graph/impl/VariableProxy.cpp | 16 ++-- libnd4j/include/graph/impl/VariableSpace.cpp | 74 +++++++++---------- libnd4j/include/ops/declarable/DeclarableOp.h | 4 +- .../include/ops/declarable/impl/BooleanOp.cpp | 2 +- .../ops/declarable/impl/DeclarableListOp.cpp | 2 +- .../ops/declarable/impl/DeclarableOp.cpp | 26 +++---- .../ops/declarable/impl/OpRegistrator.cpp | 2 +- .../layers_tests/DeclarableOpsTests1.cpp | 8 +- .../tests_cpu/layers_tests/GraphTests2.cpp | 29 ++++---- .../layers_tests/JavaInteropTests.cpp | 12 +-- 17 files changed, 213 insertions(+), 125 deletions(-) diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index a75efc4bbefb..f42ebf43e3b9 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -54,10 +54,11 @@ namespace sd { // vector holds ID's of top nodes only std::vector *_nodes; - MAP_IMPL *_mapped; + MAP_IMPL *_mapped; - MAP_IMPL *> *_onion; - MAP_IMPL _unmapped; + MAP_IMPL *> *_onion; + MAP_IMPL _unmapped; + MAP_IMPL _symbolicLookupTable; std::vector _unmappedMap; // macOS? std::mutex _mutexPreprocessing; @@ -66,6 +67,9 @@ namespace sd { std::vector _output; std::vector _autos; + // we want to know last node id + int _maxId = 1; + MAP_IMPL _mappedScopes; std::vector _scopes; @@ -73,11 +77,11 @@ namespace sd { const GraphMemoryManager &_memoryMaager; //////////////////////////////////////// - Nd4jStatus validateNode(sd::graph::Node *node); + Nd4jStatus validateNode(Node *node); void expandOnion(int newLayer); - void injectNode(sd::graph::Node *node); + void injectNode(Node *node); void pushToOutputOnce(int id); @@ -85,6 +89,7 @@ namespace sd { void prepareOutputs(); + int idByName(const std::string &nodeName) const; public: Graph(const FlatGraph *flatGraph = nullptr, VariableSpace *variableSpace = nullptr, const GraphMemoryManager &memoryManager = GraphMemoryManager()); @@ -117,21 +122,21 @@ namespace sd { int numberOfPlaceholders(); - const std::vector& getPlaceholders() const; + const std::vector& getPlaceholders() const; /** * This method returns pointer to thread_local VariableSpace * @return */ - sd::graph::VariableSpace *getVariableSpace() const; + VariableSpace *getVariableSpace() const; /** * These methods add given node to the graph * FIXME: deprecated * @param node */ - void addNode(sd::graph::Node *node); - void addNode(const sd::graph::Node &node); + void addNode(Node *node); + void addNode(const Node &node); /** * These methods add given node to the graph @@ -140,29 +145,32 @@ namespace sd { void addNode(Node &node, const std::vector &inputs); void addNode(Node &node, const std::vector> &inputs); + void addVariable(const std::string &name, NDArray &array); + void addVariable(const std::string &name, NDArray &&array); + /** * This method allows to add placeholder with some pre-defined properties */ - void addPlaceholder(const std::string &nodeName, const int id = 0, const DataType dataType = sd::DataType::ANY, const std::vector &shape = {}); + void addPlaceholder(const std::string &nodeName, const DataType dataType = sd::DataType::ANY, const std::vector &shape = {}); /** * This method returns layered representation of the graph * * @return */ - MAP_IMPL *> *getOnion(); + MAP_IMPL *> *getOnion(); /** * This method returns map of all nodes of the graph * @return */ - MAP_IMPL* getMapped(); + MAP_IMPL* getMapped(); /** * This method returns outputs of this graph * @return */ - std::vector *fetchOutputs(); + std::vector *fetchOutputs(); /** * This method returns pointer to ExecutorConfiguration @@ -181,7 +189,7 @@ namespace sd { * This method returns all nodes at once (order is NOT guaranteed) * @return */ - std::vector *getAllNodes(); + std::vector *getAllNodes(); /** * This method prints out Graph op-by-op, and respective inputs diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index dafeaaebb98a..f06d89d44815 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -44,7 +44,7 @@ namespace sd { OpType _opType; ContextPrototype* _protoContext = nullptr; Nd4jLong _opNum; - int _id; + int _id = 0; std::vector> _input; std::vector> _output; std::vector _dimensions; @@ -100,6 +100,7 @@ namespace sd { public: + explicit Node(const std::string &nodeName, const sd::ops::DeclarableOp &opName, const std::vector &tArgs = {}, const std::vector &iArgs = {}, const std::vector &bArgs = {}, const std::vector &dArgs = {}); explicit Node(const std::string &nodeName, const std::string &opName, const std::vector &tArgs = {}, const std::vector &iArgs = {}, const std::vector &bArgs = {}, const std::vector &dArgs = {}); explicit Node(const sd::graph::FlatNode *node); ~Node(); diff --git a/libnd4j/include/graph/VariableProxy.h b/libnd4j/include/graph/VariableProxy.h index 7d7ea62da985..2e855ce7a9f1 100644 --- a/libnd4j/include/graph/VariableProxy.h +++ b/libnd4j/include/graph/VariableProxy.h @@ -46,10 +46,10 @@ namespace sd { virtual bool hasVariable(std::pair& pair) override; virtual bool hasVariable(const std::string &symbol) override; - virtual sd::graph::Variable *getVariable(int id) override; - virtual sd::graph::Variable *getVariable(int id, int idx) override; - virtual sd::graph::Variable *getVariable(std::pair& pair) override; - virtual sd::graph::Variable *getVariable(const std::string &symbol) override; + virtual Variable *getVariable(int id) override; + virtual Variable *getVariable(int id, int idx) override; + virtual Variable *getVariable(std::pair& pair) override; + virtual Variable *getVariable(const std::string &symbol) override; virtual std::vector getVariables() override; @@ -79,9 +79,9 @@ namespace sd { virtual int internalEntries() override; virtual int totalEntries() override; - virtual sd::graph::VariableSpace *clone() override; + virtual VariableSpace *clone() override; - virtual sd::graph::Stash* getStash() override; + virtual Stash* getStash() override; virtual void setFlowPath(FlowPath* timers) override; virtual FlowPath* flowPath() override; }; diff --git a/libnd4j/include/graph/VariableSpace.h b/libnd4j/include/graph/VariableSpace.h index 66fef674e09c..a7fd05228ad6 100644 --- a/libnd4j/include/graph/VariableSpace.h +++ b/libnd4j/include/graph/VariableSpace.h @@ -43,7 +43,7 @@ namespace sd { sd::memory::Workspace *_workspace; // stash is NOT cloned - sd::graph::Stash _stash; + Stash _stash; MAP_IMPL, Variable*> _paired; MAP_IMPL _symbolic; @@ -53,7 +53,7 @@ namespace sd { std::vector _lists; - std::vector _placeholders; + std::vector _placeholders; void silentPutVariable(std::pair& pair, Variable *variable); @@ -61,9 +61,9 @@ namespace sd { std::mutex _varmap; - MAP_IMPL _temporary; + MAP_IMPL _temporary; - std::vector *_handles; + std::vector *_handles; FlowPath* _flow = nullptr; @@ -88,10 +88,10 @@ namespace sd { virtual bool hasVariable(std::pair& pair); virtual bool hasVariable(const std::string &symbol); - virtual sd::graph::Variable* getVariable(int id); - virtual sd::graph::Variable* getVariable(int id, int idx); - virtual sd::graph::Variable* getVariable(std::pair& pair); - virtual sd::graph::Variable* getVariable(const std::string &symbol); + virtual Variable* getVariable(int id); + virtual Variable* getVariable(int id, int idx); + virtual Variable* getVariable(std::pair& pair); + virtual Variable* getVariable(const std::string &symbol); virtual std::vector getVariables(); @@ -121,17 +121,17 @@ namespace sd { virtual int internalEntries(); virtual int totalEntries(); - virtual sd::graph::VariableSpace* clone(); + virtual VariableSpace* clone(); std::vector *handles(); - sd::graph::VariableSpace* asT(); + VariableSpace* asT(); void injectVariable(std::pair &pair, Variable* variable); - virtual sd::graph::Stash* getStash(); + virtual Stash* getStash(); - virtual std::vector * getExternalVariables(); + virtual std::vector * getExternalVariables(); virtual void setFlowPath(FlowPath* timers); virtual FlowPath* flowPath(); diff --git a/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp b/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp index d2284455b3d2..4cda626bb115 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp @@ -82,7 +82,7 @@ namespace sd{ } else if (opType != OpType_CUSTOM) { nd4j_debug("Executing node_%i{%i}\n", node->id(), opNum); } else { - nd4j_debug("Executing node_%i{%s}\n", node->id(), node->getCustomOp()->getOpName()->c_str()); + nd4j_debug("Executing node_%i{%s}\n", node->id(), node->getCustomOp()->getOpName().c_str()); } Context context(node->getContextPrototype(), variableSpace); diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 7650354599aa..dd63c7dec62c 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -78,7 +78,7 @@ namespace sd { //} - nd4j_debug("Trying estimation [%i] on [%s]\n", node->id(), node->getCustomOp()->getOpName()->c_str()); + nd4j_debug("Trying estimation [%i] on [%s]\n", node->id(), node->getCustomOp()->getOpName().c_str()); auto op = node->getCustomOp(); auto in = node->input()->at(0); @@ -308,15 +308,48 @@ namespace sd { delete _configuration; } + int Graph::idByName(const std::string &nodeName) const { + if (_symbolicLookupTable.count(nodeName) == 0) + throw std::runtime_error("Can't find node [" + nodeName + "]"); + + return _symbolicLookupTable.at(nodeName); + } + + void Graph::addVariable(const std::string &name, NDArray &array) { + int id = _maxId++; + _symbolicLookupTable[name] = id; + _variableSpace->putVariable(id, 0, array); + } + + void Graph::addVariable(const std::string &name, NDArray &&array) { + auto lvalue = array; + addVariable(name, lvalue); + } + void Graph::addNode(Node &node, const std::vector &inputs) { - throw std::runtime_error("not implemented yet"); + if (node.id() != 0) + throw std::runtime_error("Graph::addNode - Node has id defined"); + + node.markRemovable(false); + + // node must have numeric id + node.setId(_maxId++); + _symbolicLookupTable[node.name()] = node.id(); + + // converting string ids to numeric ones + for (auto &v:inputs) + node.pickInput(idByName(v), 0); + + addNode(&node); } void Graph::addNode(Node &node, const std::vector> &inputs) { + node.markRemovable(false); + throw std::runtime_error("not implemented yet"); } - void Graph::addNode(const sd::graph::Node &node) { + void Graph::addNode(const Node &node) { node.markRemovable(false); addNode(const_cast(&node)); } @@ -1082,7 +1115,7 @@ namespace sd { nd4j_printf("%i. ", node->id()); switch(node->opType()) { case OpType_CUSTOM: { - printf("%s; ", node->getCustomOp()->getOpName()->c_str()); + printf("%s; ", node->getCustomOp()->getOpName().c_str()); } break; case OpType_LOGIC: { @@ -1716,7 +1749,11 @@ namespace sd { */ } - void Graph::addPlaceholder(const std::string &nodeName, const int id, DataType dataType, const std::vector &shape) { + void Graph::addPlaceholder(const std::string &nodeName, DataType dataType, const std::vector &shape) { + int id = _maxId++; + + _symbolicLookupTable[nodeName] = id; + auto var = new Variable(true, dataType, shape); var->setName(nodeName); _variableSpace->putVariable(id, var); diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 0b7a675d3187..632dd76dc798 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -44,12 +44,51 @@ namespace sd { namespace graph { + Node::Node(const std::string &nodeName, const ops::DeclarableOp &opName, const std::vector &tArgs, + const std::vector &iArgs, const std::vector &bArgs, + const std::vector &dArgs) { + auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName.getOpHash()); + + this->_name = nodeName; + this->_opType = OpType_CUSTOM; + this->_opNum = customOp->getOpHash(); + this->_extraParams = nullptr; + this->_dataType = sd::DataType::FLOAT32; // float as default + this->_dim = nullptr; + this->_customOp = customOp; + + _hasExternalInputs = false; + _hasExternalOutputs = false; + _hasInternalInputs = false; + _hasInternalOutputs = false; + + // FIXME: get rid of this!!! + _scalar = NDArrayFactory::create(0); + + auto block = new ContextPrototype(this->getCustomOp()->getOpDescriptor(), this->id(), false); + + for (auto v: iArgs) + block->getIArguments()->emplace_back(v); + + for (auto v: tArgs) + block->getTArguments()->emplace_back(v); + + for (auto v: bArgs) + block->getBArguments()->emplace_back(v); + + for (auto v: dArgs) + block->getDArguments()->emplace_back(v); + + this->setContextPrototype(block); + } + Node::Node(const std::string &nodeName, const std::string &opName, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs, const std::vector &dArgs) { auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName); + this->_name = nodeName; this->_opType = OpType_CUSTOM; this->_opNum = customOp->getOpHash(); this->_extraParams = nullptr; diff --git a/libnd4j/include/graph/impl/VariableProxy.cpp b/libnd4j/include/graph/impl/VariableProxy.cpp index d15d3785af2f..dee01997d904 100644 --- a/libnd4j/include/graph/impl/VariableProxy.cpp +++ b/libnd4j/include/graph/impl/VariableProxy.cpp @@ -110,7 +110,7 @@ namespace sd { } - sd::graph::Variable *VariableProxy::getVariable(int id) { + Variable *VariableProxy::getVariable(int id) { if (_current->hasVariable(id)) return _current->getVariable(id); @@ -122,7 +122,7 @@ namespace sd { } - sd::graph::Variable *VariableProxy::getVariable(int id, int idx) { + Variable *VariableProxy::getVariable(int id, int idx) { if (_current->hasVariable(id, idx)) return _current->getVariable(id, idx); @@ -134,7 +134,7 @@ namespace sd { } - sd::graph::Variable *VariableProxy::getVariable(std::pair& pair) { + Variable *VariableProxy::getVariable(std::pair& pair) { if (_current->hasVariable(pair)) return _current->getVariable(pair); @@ -146,7 +146,7 @@ namespace sd { } - sd::graph::Variable *VariableProxy::getVariable(const std::string &symbol) { + Variable *VariableProxy::getVariable(const std::string &symbol) { if (_current->hasVariable(symbol)) return _current->getVariable(symbol); @@ -191,7 +191,7 @@ namespace sd { _current->putVariable(id, array); } - void sd::graph::VariableProxy::putVariable(int id, int idx, const NDArray &array) { + void VariableProxy::putVariable(int id, int idx, const NDArray &array) { _current->putVariable(id, idx, array); } @@ -210,7 +210,7 @@ namespace sd { } - sd::graph::Stash* VariableProxy::getStash() { + Stash* VariableProxy::getStash() { return _current->getStash(); } @@ -260,7 +260,7 @@ namespace sd { } - sd::graph::VariableSpace* VariableProxy::clone() { + VariableSpace* VariableProxy::clone() { auto clone = new VariableProxy(_backed); delete clone->_current; @@ -279,7 +279,7 @@ namespace sd { } - sd::memory::Workspace * sd::graph::VariableProxy::workspace() { + sd::memory::Workspace * VariableProxy::workspace() { return _workspace; } } diff --git a/libnd4j/include/graph/impl/VariableSpace.cpp b/libnd4j/include/graph/impl/VariableSpace.cpp index b4ea8afeff51..024768bcf689 100644 --- a/libnd4j/include/graph/impl/VariableSpace.cpp +++ b/libnd4j/include/graph/impl/VariableSpace.cpp @@ -23,15 +23,15 @@ namespace sd { namespace graph { - std::vector * sd::graph::VariableSpace::getExternalVariables() { + std::vector * VariableSpace::getExternalVariables() { return &_external; } - sd::graph::Stash* sd::graph::VariableSpace::getStash() { + Stash* VariableSpace::getStash() { return &_stash; } - sd::graph::VariableSpace* sd::graph::VariableSpace::clone() { + VariableSpace* VariableSpace::clone() { auto result = new VariableSpace(); for (auto const& x : _paired) { @@ -50,7 +50,7 @@ namespace sd { } - sd::graph::VariableSpace* sd::graph::VariableSpace::asT() { + VariableSpace* VariableSpace::asT() { auto result = new VariableSpace(); for (auto const& x : _paired) { @@ -65,7 +65,7 @@ namespace sd { } - void sd::graph::VariableSpace::injectVariable(std::pair &pair, Variable* variable) { + void VariableSpace::injectVariable(std::pair &pair, Variable* variable) { if (pair.second == 0) { if (pair.first < 0) this->_variables[pair.first] = variable; @@ -81,23 +81,23 @@ namespace sd { this->_handles->push_back(variable); } - std::vector * sd::graph::VariableSpace::getPlaceholders() { + std::vector * VariableSpace::getPlaceholders() { return &_placeholders; } - int sd::graph::VariableSpace ::numberOfPlaceholders() { + int VariableSpace ::numberOfPlaceholders() { return _placeholders.size(); } - bool sd::graph::VariableSpace::hasVariable(const std::string &symbol) { + bool VariableSpace::hasVariable(const std::string &symbol) { return _symbolic.count(symbol) == 1; } - sd::graph::Variable * sd::graph::VariableSpace::getVariable(const std::string &symbol) { + Variable * VariableSpace::getVariable(const std::string &symbol) { return _symbolic.at(symbol); } - bool sd::graph::VariableSpace::hasVariable(int id, int index) { + bool VariableSpace::hasVariable(int id, int index) { std::pair pair(id, index); return hasVariable(pair); } @@ -126,12 +126,12 @@ namespace sd { return var->isExternal(); } - sd::graph::Variable * sd::graph::VariableSpace::getVariable(int id, int index) { + Variable * VariableSpace::getVariable(int id, int index) { std::pair pair(id, index); return getVariable(pair); } - sd::graph::Variable * sd::graph::VariableSpace::getVariable(std::pair& pair) { + Variable * VariableSpace::getVariable(std::pair& pair) { if (pair.first < 0) return getVariable(pair.first); else @@ -141,32 +141,32 @@ namespace sd { throw std::runtime_error("Unknown variable requested"); } - bool sd::graph::VariableSpace::hasVariable(int id) { + bool VariableSpace::hasVariable(int id) { return _variables.count(id) == 1 || _temporary.count(id) == 1; } - bool sd::graph::VariableSpace::hasVariable(std::pair& id) { + bool VariableSpace::hasVariable(std::pair& id) { return _paired.count(id) > 0; } - void sd::graph::VariableSpace::putOutputVariable(Variable *variable) { + void VariableSpace::putOutputVariable(Variable *variable) { //putVariable(_auto_counter--, variable); putVariable(variable->id(), variable); } - int sd::graph::VariableSpace::externalEntries() { + int VariableSpace::externalEntries() { return _external.size(); } - int sd::graph::VariableSpace::internalEntries() { + int VariableSpace::internalEntries() { return _internal.size(); } - int sd::graph::VariableSpace::totalEntries() { + int VariableSpace::totalEntries() { return externalEntries() + internalEntries(); } - Nd4jLong sd::graph::VariableSpace::externalMemory() { + Nd4jLong VariableSpace::externalMemory() { Nd4jLong size = 0; for (auto n: _external) { size += n->getNDArray()->memoryFootprint(); @@ -187,7 +187,7 @@ namespace sd { return result; } - Nd4jLong sd::graph::VariableSpace::internalMemory() { + Nd4jLong VariableSpace::internalMemory() { Nd4jLong size = 0; for (auto n: _internal) { size += n->getNDArray()->memoryFootprint(); @@ -196,36 +196,36 @@ namespace sd { return size; } - Nd4jLong sd::graph::VariableSpace::totalMemory() { + Nd4jLong VariableSpace::totalMemory() { return externalMemory() + internalMemory(); } - Variable* sd::graph::VariableSpace::putVariable(std::pair& pair, NDArray *array) { + Variable* VariableSpace::putVariable(std::pair& pair, NDArray *array) { auto variable = new Variable(array, nullptr, pair.first, pair.second); this->putVariable(pair, variable); return variable; } - Variable* sd::graph::VariableSpace::putVariable(int node, int idx, NDArray *array) { + Variable* VariableSpace::putVariable(int node, int idx, NDArray *array) { std::pair pair(node, idx); return this->putVariable(pair, array); } - void sd::graph::VariableSpace::putVariable(int node, int idx, Variable *variable) { + void VariableSpace::putVariable(int node, int idx, Variable *variable) { std::pair pair(node, idx); this->putVariable(pair, variable); } - void sd::graph::VariableSpace::silentPutVariable(std::pair& pair, Variable *variable) { + void VariableSpace::silentPutVariable(std::pair& pair, Variable *variable) { _varmap.lock(); - //std::pair, sd::graph::Variable *> p(pair, variable); + //std::pair, Variable *> p(pair, variable); _paired[pair] = variable; _varmap.unlock(); } - void sd::graph::VariableSpace::putVariable(std::pair& pair, Variable *variable) { + void VariableSpace::putVariable(std::pair& pair, Variable *variable) { silentPutVariable(pair, variable); if (variable->isPlaceholder()) @@ -251,7 +251,7 @@ namespace sd { _lists.emplace_back(list); } - void sd::graph::VariableSpace::putVariable(int id, Variable *variable) { + void VariableSpace::putVariable(int id, Variable *variable) { // we don't want to add variables more then once if (_variables.count(id) > 0 || _temporary.count(id) > 0) { auto local = id < 0 ? _variables.at(id) : _temporary.at(id); @@ -277,7 +277,7 @@ namespace sd { variable->setId(id); if (!variable->getName().empty()) { - //std::pair pair(*(variable->getName()), variable); + //std::pair pair(*(variable->getName()), variable); _symbolic[variable->getName()] = variable; } @@ -305,8 +305,8 @@ namespace sd { } } - void sd::graph::VariableSpace::putVariable(int id, int idx, const NDArray &array) { - auto *var = new sd::graph::Variable(const_cast(&array), "", id, idx); + void VariableSpace::putVariable(int id, int idx, const NDArray &array) { + auto *var = new Variable(const_cast(&array), "", id, idx); var->markRemovable(false); var->markReadOnly(true); @@ -320,12 +320,12 @@ namespace sd { delete var; } - void sd::graph::VariableSpace::putVariable(int id, NDArray *array) { - auto *var = new sd::graph::Variable(array); + void VariableSpace::putVariable(int id, NDArray *array) { + auto *var = new Variable(array); this->putVariable(id, var); } - sd::graph::Variable * sd::graph::VariableSpace::getVariable(int id) { + Variable * VariableSpace::getVariable(int id) { if (id < 0) { return _variables.at(id); } else { @@ -333,18 +333,18 @@ namespace sd { } } - LaunchContext* sd::graph::VariableSpace::launchContext() { + LaunchContext* VariableSpace::launchContext() { return LaunchContext::defaultContext(); } - std::vector* sd::graph::VariableSpace::handles() { + std::vector* VariableSpace::handles() { return _handles; } /* * FIXME: this thing have nice chances to become backend-specific! */ - sd::graph::VariableSpace::~VariableSpace() { + VariableSpace::~VariableSpace() { // loop through variables and release them for (auto p: *_handles) { delete p; diff --git a/libnd4j/include/ops/declarable/DeclarableOp.h b/libnd4j/include/ops/declarable/DeclarableOp.h index ac9a106c9994..06c2e36ecc6a 100644 --- a/libnd4j/include/ops/declarable/DeclarableOp.h +++ b/libnd4j/include/ops/declarable/DeclarableOp.h @@ -140,12 +140,12 @@ namespace sd { * * @return */ - std::string *getOpName(); + const std::string& getOpName() const; /** * Returns opHash */ - Nd4jLong getOpHash(); + Nd4jLong getOpHash() const; /** * This method sets arguments for op diff --git a/libnd4j/include/ops/declarable/impl/BooleanOp.cpp b/libnd4j/include/ops/declarable/impl/BooleanOp.cpp index 00079f9ae358..1023ae127335 100644 --- a/libnd4j/include/ops/declarable/impl/BooleanOp.cpp +++ b/libnd4j/include/ops/declarable/impl/BooleanOp.cpp @@ -111,7 +111,7 @@ namespace sd { if (status == ND4J_STATUS_FALSE || status == ND4J_STATUS_TRUE) return ND4J_STATUS_OK; - nd4j_printf("%s: node_%i got unexpected result instead of boolean: [%i]\n", this->getOpName()->c_str(), block->nodeId(), status); + nd4j_printf("%s: node_%i got unexpected result instead of boolean: [%i]\n", this->getOpName().c_str(), block->nodeId(), status); return ND4J_STATUS_KERNEL_FAILURE; } diff --git a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp index 13aa763f8957..013c7c143d78 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp @@ -75,7 +75,7 @@ namespace sd { if (block == nullptr) throw std::invalid_argument("Block is NULL"); - nd4j_debug("Executing list op: [%s]\n", this->getOpName()->c_str()); + nd4j_debug("Executing list op: [%s]\n", this->getOpName().c_str()); // ensure number of IArgs, TArgs match our expectations REQUIRE_OK(this->validateArguments(*block)); diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 44fbaae42fc1..212617ec70fc 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -88,11 +88,11 @@ namespace sd { return _descriptor; } - std::string *DeclarableOp::getOpName() { - return _descriptor->getOpName(); + const std::string& DeclarableOp::getOpName() const { + return *_descriptor->getOpName(); } - Nd4jLong DeclarableOp::getOpHash() { + Nd4jLong DeclarableOp::getOpHash() const { return _descriptor->getHash(); } @@ -593,7 +593,7 @@ namespace sd { } Nd4jStatus sd::ops::DeclarableOp::execute(Context* block) { - nd4j_debug("Executing op: [%s]\n", this->getOpName()->c_str()); + nd4j_debug("Executing op: [%s]\n", this->getOpName().c_str()); std::chrono::time_point timeEnter, timeStart, timeEnd; Nd4jLong prepTime, outerTime; @@ -738,25 +738,25 @@ namespace sd { */ if (_descriptor->getNumberOfTArgs() > 0) { if ((int) block.getTArguments()->size() < _descriptor->getNumberOfTArgs()) { - nd4j_printf("%s: %i T args expected, but %i received\n", this->getOpName()->c_str(), _descriptor->getNumberOfTArgs(), block.getTArguments()->size()); + nd4j_printf("%s: %i T args expected, but %i received\n", this->getOpName().c_str(), _descriptor->getNumberOfTArgs(), block.getTArguments()->size()); return ND4J_STATUS_BAD_PARAMS; } } else if (_descriptor->getNumberOfTArgs() == -1) if (block.getTArguments()->size() == 0) { - nd4j_printf("%s: Number of T arguments should be positive number, but got 0 arguments\n", this->getOpName()->c_str()); + nd4j_printf("%s: Number of T arguments should be positive number, but got 0 arguments\n", this->getOpName().c_str()); return ND4J_STATUS_BAD_PARAMS; } if (_descriptor->getNumberOfIArgs() > 0) { if ((int) block.getIArguments()->size() < _descriptor->getNumberOfIArgs()) { - nd4j_printf("%s: %i int args expected, but %i received\n", this->getOpName()->c_str(), _descriptor->getNumberOfIArgs(), block.getIArguments()->size()); + nd4j_printf("%s: %i int args expected, but %i received\n", this->getOpName().c_str(), _descriptor->getNumberOfIArgs(), block.getIArguments()->size()); return ND4J_STATUS_BAD_PARAMS; } } else if (_descriptor->getNumberOfIArgs() == -1) if (block.getIArguments()->size() == 0) { - nd4j_printf("%s: Number of Integer arguments should be positive number, but got 0 arguments\n", this->getOpName()->c_str()); + nd4j_printf("%s: Number of Integer arguments should be positive number, but got 0 arguments\n", this->getOpName().c_str()); return ND4J_STATUS_BAD_PARAMS; } @@ -799,7 +799,7 @@ namespace sd { return Status::OK(); if (block.width() < 1) { - nd4j_printf("%s: no operands provided for the op", this->getOpName()->c_str()); + nd4j_printf("%s: no operands provided for the op", this->getOpName().c_str()); return ND4J_STATUS_BAD_INPUT; } @@ -808,8 +808,8 @@ namespace sd { for (auto p: *block.inputs()) { auto v = block.variable(p); if (v == nullptr) { - if (this->getOpName() != nullptr) { - nd4j_printf("Node [%i:<%s>]: Variable [%i] (%i:%i) is NULL\n", block.getNodeId(), this->getOpName()->c_str(), cnt, p.first, p.second); + if (!this->getOpName().empty()) { + nd4j_printf("Node [%i:<%s>]: Variable [%i] (%i:%i) is NULL\n", block.getNodeId(), this->getOpName().c_str(), cnt, p.first, p.second); } else { nd4j_printf("Node [%i:]: Variable [%i] (%i:%i) is NULL\n", block.getNodeId(), cnt, p.first, p.second); } @@ -824,8 +824,8 @@ namespace sd { continue; if (aV == nullptr || !aV->nonNull()) { - if (this->getOpName() != nullptr) { - nd4j_printf("Node [%i:<%s>]: NDArray [%i] (%i:%i) is NULL\n", block.getNodeId(), this->getOpName()->c_str(), cnt, p.first, p.second); + if (!this->getOpName().empty()) { + nd4j_printf("Node [%i:<%s>]: NDArray [%i] (%i:%i) is NULL\n", block.getNodeId(), this->getOpName().c_str(), cnt, p.first, p.second); } else { nd4j_printf("Node [%i:]: NDArray [%i] (%i:%i) is NULL\n", block.getNodeId(), cnt, p.first, p.second); } diff --git a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp index 177e3a87e8d8..2ad7fd93670c 100644 --- a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp +++ b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp @@ -169,7 +169,7 @@ namespace sd { */ bool OpRegistrator::registerOperation(sd::ops::DeclarableOp *op) { _uniqueD.emplace_back(op); - return registerOperation(op->getOpName()->c_str(), op); + return registerOperation(op->getOpName().c_str(), op); } void OpRegistrator::registerHelper(sd::ops::platforms::PlatformHelper* op) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 28240cc10d83..ca4a356ae794 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -95,7 +95,7 @@ TYPED_TEST_CASE(TypedDeclarableOpsTests1, TestingTypes); TEST_F(DeclarableOpsTests1, BasicInitialization1) { auto concat = new sd::ops::concat(); std::string expName("concat"); - ASSERT_EQ(expName, *(concat->getOpName())); + ASSERT_EQ(expName, concat->getOpName()); auto x0 = NDArrayFactory::create_('c', { 1, 5 }); auto x1 = NDArrayFactory::create_('c', { 1, 5 }); @@ -148,7 +148,7 @@ TEST_F(DeclarableOpsTests1, BasicInitialization2) { ASSERT_TRUE(op != nullptr); std::string expName("concat"); - ASSERT_EQ(expName, *(op->getOpName())); + ASSERT_EQ(expName, op->getOpName()); ASSERT_EQ(-1, op->getOpDescriptor()->getNumberOfInputs()); ASSERT_EQ(1, op->getOpDescriptor()->getNumberOfOutputs()); @@ -233,7 +233,7 @@ TEST_F(DeclarableOpsTests1, SynonymInitialization2) { ASSERT_TRUE(op != nullptr); std::string expName("multiply"); - ASSERT_EQ(expName, *(op->getOpName())); + ASSERT_EQ(expName, op->getOpName()); ASSERT_TRUE(op == op2); } @@ -601,7 +601,7 @@ TEST_F(DeclarableOpsTests1, DivergentCheck1) { ASSERT_TRUE(op != nullptr); std::string expName("Switch"); - ASSERT_EQ(expName, *(op->getOpName())); + ASSERT_EQ(expName, op->getOpName()); ASSERT_TRUE(op->getOpDescriptor()->isDivergent()); ASSERT_EQ(2, op->getOpDescriptor()->getNumberOfOutputs()); } diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp index 468cf360884b..32e5f18efe8b 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp @@ -30,6 +30,7 @@ #include #include #include +#include using namespace sd; using namespace sd::graph; @@ -45,7 +46,7 @@ class GraphTests2 : public testing::Test { TEST_F(GraphTests2, test_placeholder_1) { Graph graph; - graph.addPlaceholder("input", 0, DataType::BFLOAT16, {4, 12, 48}); + graph.addPlaceholder("input", DataType::BFLOAT16, {4, 12, 48}); ASSERT_TRUE(graph.getVariableSpace()->hasVariable("input")); @@ -65,19 +66,19 @@ TEST_F(GraphTests2, test_execution_1) { Graph graph; // A - graph.getVariableSpace()->putVariable(-1, 0, NDArrayFactory::create('c', {3}, {1, 1, 1})); + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); // B - graph.getVariableSpace()->putVariable(-2, 0, NDArrayFactory::create('c', {3}, {2, 2, 2})); + graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); // C - graph.getVariableSpace()->putVariable(-3, 0, NDArrayFactory::create('c', {3}, {3, 3, 3})); + graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - Node a("multiply", "multiply_node", 10, {{-1, 0}, {-2, 0}}); - Node b("add", "add_node", 20, {{10, 0}, {-3, 0}}); + Node a("multiply_node", sd::ops::multiply()); + Node b("add_node", sd::ops::add()); - graph.addNode(b); - graph.addNode(a); + graph.addNode(a, std::vector{"A", "B"}); + graph.addNode(b, std::vector{"multiply_node", "C"}); auto result = graph.execute({}, {"add_node"}); ASSERT_EQ(1, result.size()); @@ -87,7 +88,7 @@ TEST_F(GraphTests2, test_execution_1) { TEST_F(GraphTests2, test_placeholder_resolution_1) { Graph graph; - graph.addPlaceholder("input", 0, DataType::FLOAT32); + graph.addPlaceholder("input", DataType::FLOAT32); graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); @@ -98,18 +99,20 @@ TEST_F(GraphTests2, test_placeholder_resolution_1) { TEST_F(GraphTests2, test_placeholder_resolution_2) { Graph graph; - graph.addPlaceholder("input", 0, DataType::FLOAT32); + graph.addPlaceholder("input", DataType::FLOAT32); - graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); + Node a("tanh_node", "tanh"); + graph.addNode(a, {"input"}); auto result = graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"tanh_node"}); + // TODO: add result validation here } TEST_F(GraphTests2, test_output_resolution_1) { Graph graph; - graph.addPlaceholder("input", 0, DataType::FLOAT32); + graph.addPlaceholder("input", DataType::FLOAT32); graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); @@ -120,7 +123,7 @@ TEST_F(GraphTests2, test_output_resolution_1) { TEST_F(GraphTests2, test_input_resolution_1) { Graph graph; - graph.addPlaceholder("input", 0, DataType::FLOAT32); + graph.addPlaceholder("input", DataType::FLOAT32); graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index 6f559230b027..3baf082bf2d1 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -363,8 +363,8 @@ TEST_F(JavaInteropTests, Test_Synonyms_1) { ASSERT_TRUE(op != nullptr); ASSERT_TRUE(opRef != nullptr); - std::string name = *(op->getOpName()); - std::string nameRef = *(opRef->getOpName()); + std::string name = op->getOpName(); + std::string nameRef = opRef->getOpName(); ASSERT_EQ(nameExp, nameRef); ASSERT_EQ(nameRef, name); @@ -378,8 +378,8 @@ TEST_F(JavaInteropTests, Test_Synonyms_2) { ASSERT_TRUE(op != nullptr); ASSERT_TRUE(opRef != nullptr); - std::string name = *(op->getOpName()); - std::string nameRef = *(opRef->getOpName()); + std::string name = op->getOpName(); + std::string nameRef = opRef->getOpName(); ASSERT_EQ(nameExp, nameRef); ASSERT_EQ(nameRef, name); @@ -393,8 +393,8 @@ TEST_F(JavaInteropTests, Test_Synonyms_3) { ASSERT_TRUE(op != nullptr); ASSERT_TRUE(opRef != nullptr); - std::string name = *(op->getOpName()); - std::string nameRef = *(opRef->getOpName()); + std::string name = op->getOpName(); + std::string nameRef = opRef->getOpName(); ASSERT_EQ(nameExp, nameRef); ASSERT_EQ(nameRef, name); From 15365a0ea2f3d430b9b2039450078ed744bd273d Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 31 Mar 2020 13:04:25 +0300 Subject: [PATCH 052/233] next step Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 1 + libnd4j/include/graph/Node.h | 22 ++++- .../impl/unresolved_input_exception.cpp | 20 ++-- .../impl/unresolved_output_exception.cpp | 16 ++-- .../exceptions/unresolved_input_exception.h | 8 +- .../exceptions/unresolved_output_exception.h | 8 +- libnd4j/include/graph/impl/Graph.cpp | 11 +++ libnd4j/include/graph/impl/Node.cpp | 93 ++++++++++++++++++- .../tests_cpu/layers_tests/GraphTests2.cpp | 3 +- 9 files changed, 149 insertions(+), 33 deletions(-) diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index f42ebf43e3b9..3edda9cb5b9d 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -142,6 +142,7 @@ namespace sd { * These methods add given node to the graph * @param node */ + void addNode(Node &&node, const std::vector &inputs); void addNode(Node &node, const std::vector &inputs); void addNode(Node &node, const std::vector> &inputs); diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index f06d89d44815..8d7e65f65cf8 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -80,7 +80,7 @@ namespace sd { OpClass _opClass; // these fields are used to store embedded CustomOps and Graph in case of Graph-in-Graph scenario - sd::graph::Graph * _graph= nullptr; + Graph * _graph= nullptr; sd::ops::DeclarableOp *_customOp = nullptr; // each node can be active or inactive, if used with divergents, like IF statements @@ -102,7 +102,7 @@ namespace sd { explicit Node(const std::string &nodeName, const sd::ops::DeclarableOp &opName, const std::vector &tArgs = {}, const std::vector &iArgs = {}, const std::vector &bArgs = {}, const std::vector &dArgs = {}); explicit Node(const std::string &nodeName, const std::string &opName, const std::vector &tArgs = {}, const std::vector &iArgs = {}, const std::vector &bArgs = {}, const std::vector &dArgs = {}); - explicit Node(const sd::graph::FlatNode *node); + explicit Node(const FlatNode *node); ~Node(); /* @@ -113,7 +113,19 @@ namespace sd { explicit Node(sd::ops::DeclarableOp *customOp, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); explicit Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); - bool equals(Node *other); + + Node(const Node& other) noexcept; + + Node& operator=(const Node& other) noexcept; + + // move constructor + Node(Node&& other) noexcept; + + // move assignment operator + Node& operator=(Node&& other) noexcept; + + + bool equals(Node *other) const; sd::DataType dataType(); ContextPrototype *protoContext(); @@ -188,8 +200,8 @@ namespace sd { sd::ops::DeclarableOp* getCustomOp(); bool hasCustomOp(); - void setGraph(sd::graph::Graph* graph = nullptr); - sd::graph::Graph* getGraph(); + void setGraph(Graph* graph = nullptr); + Graph* getGraph(); bool hasGraphEmbedded(); bool isInplace(); diff --git a/libnd4j/include/graph/exceptions/impl/unresolved_input_exception.cpp b/libnd4j/include/graph/exceptions/impl/unresolved_input_exception.cpp index fe6e45875dda..5b25bbc0ad27 100644 --- a/libnd4j/include/graph/exceptions/impl/unresolved_input_exception.cpp +++ b/libnd4j/include/graph/exceptions/impl/unresolved_input_exception.cpp @@ -23,28 +23,28 @@ namespace sd { namespace graph { - unresolved_input_exception::unresolved_input_exception(std::string message) : std::runtime_error(message) { + unresolved_input_exception::unresolved_input_exception(const std::string &message) : std::runtime_error(message) { // } - unresolved_input_exception unresolved_input_exception::build(std::string message, int nodeId, std::pair &varIndex) { + unresolved_input_exception unresolved_input_exception::build(const std::string & message, int nodeId, std::pair &varIndex) { auto node = StringUtils::valueToString(nodeId); auto varId = StringUtils::valueToString(varIndex.first); auto outputIdx = StringUtils::valueToString(varIndex.second); - message += "; Node: [" + node +":0]; Variable: [" + varId + ":" + outputIdx + "]"; - return unresolved_input_exception(message); + auto rmessage = message + "; Node: [" + node +":0]; Variable: [" + varId + ":" + outputIdx + "]"; + return unresolved_input_exception(rmessage); } - unresolved_input_exception unresolved_input_exception::build(std::string message, std::pair &varIndex) { + unresolved_input_exception unresolved_input_exception::build(const std::string &message, std::pair &varIndex) { auto nodeId = StringUtils::valueToString(varIndex.first); auto outputIdx = StringUtils::valueToString(varIndex.second); - message += "; Variable: [" + nodeId + ":" + outputIdx + "]"; - return unresolved_input_exception(message); + auto rmessage = message + "; Variable: [" + nodeId + ":" + outputIdx + "]"; + return unresolved_input_exception(rmessage); } - unresolved_input_exception unresolved_input_exception::build(std::string message, std::string &varName) { - message += "; Variable: [" + varName + "]"; - return unresolved_input_exception(message); + unresolved_input_exception unresolved_input_exception::build(const std::string &message, const std::string &varName) { + auto rmessage = message + "; Variable: [" + varName + "]"; + return unresolved_input_exception(rmessage); } } } \ No newline at end of file diff --git a/libnd4j/include/graph/exceptions/impl/unresolved_output_exception.cpp b/libnd4j/include/graph/exceptions/impl/unresolved_output_exception.cpp index df8b5eb00c53..d7dede17ca08 100644 --- a/libnd4j/include/graph/exceptions/impl/unresolved_output_exception.cpp +++ b/libnd4j/include/graph/exceptions/impl/unresolved_output_exception.cpp @@ -24,26 +24,26 @@ namespace sd { namespace graph { - unresolved_output_exception::unresolved_output_exception(std::string message) : std::runtime_error(message) { + unresolved_output_exception::unresolved_output_exception(const std::string &message) : std::runtime_error(message) { // } - unresolved_output_exception unresolved_output_exception::build(std::string message, std::pair &varIndex) { + unresolved_output_exception unresolved_output_exception::build(const std::string &message, std::pair &varIndex) { auto nodeId = StringUtils::valueToString(varIndex.first); auto outputIdx = StringUtils::valueToString(varIndex.second); - message += "; Variable: [" + nodeId + ":" + outputIdx + "]"; - return unresolved_output_exception(message); + auto rmessage = message + "; Variable: [" + nodeId + ":" + outputIdx + "]"; + return unresolved_output_exception(rmessage); } - unresolved_output_exception unresolved_output_exception::build(std::string message, int nodeId, int outputIndex) { + unresolved_output_exception unresolved_output_exception::build(const std::string &message, int nodeId, int outputIndex) { std::pair p(nodeId, outputIndex); return build(message, p); } - unresolved_output_exception unresolved_output_exception::build(std::string message, std::string &varName, int outputIndex) { + unresolved_output_exception unresolved_output_exception::build(const std::string &message, const std::string &varName, int outputIndex) { auto outputIdx = StringUtils::valueToString(outputIndex); - message += "; Variable: [" + varName + ":" + outputIdx + "]"; - return unresolved_output_exception(message); + auto rmessage = message +"; Variable: [" + varName + ":" + outputIdx + "]"; + return unresolved_output_exception(rmessage); } } } \ No newline at end of file diff --git a/libnd4j/include/graph/exceptions/unresolved_input_exception.h b/libnd4j/include/graph/exceptions/unresolved_input_exception.h index 842e5ef8c728..024798a95290 100644 --- a/libnd4j/include/graph/exceptions/unresolved_input_exception.h +++ b/libnd4j/include/graph/exceptions/unresolved_input_exception.h @@ -29,12 +29,12 @@ namespace sd { namespace graph { class unresolved_input_exception : public std::runtime_error { public: - unresolved_input_exception(std::string message); + unresolved_input_exception(const std::string &message); ~unresolved_input_exception() = default; - static unresolved_input_exception build(std::string message, int nodeId, std::pair &varIndex); - static unresolved_input_exception build(std::string message, std::pair &varIndex); - static unresolved_input_exception build(std::string message, std::string &varName); + static unresolved_input_exception build(const std::string &message, int nodeId, std::pair &varIndex); + static unresolved_input_exception build(const std::string &message, std::pair &varIndex); + static unresolved_input_exception build(const std::string &message, const std::string &varName); }; } } diff --git a/libnd4j/include/graph/exceptions/unresolved_output_exception.h b/libnd4j/include/graph/exceptions/unresolved_output_exception.h index b9f09bf4cbdf..c923da785e2f 100644 --- a/libnd4j/include/graph/exceptions/unresolved_output_exception.h +++ b/libnd4j/include/graph/exceptions/unresolved_output_exception.h @@ -29,12 +29,12 @@ namespace sd { namespace graph { class unresolved_output_exception : public std::runtime_error { public: - unresolved_output_exception(std::string message); + unresolved_output_exception(const std::string &message); ~unresolved_output_exception() = default; - static unresolved_output_exception build(std::string message, int nodeId, int outputIndex); - static unresolved_output_exception build(std::string message, std::pair &varIndex); - static unresolved_output_exception build(std::string message, std::string &varName, int outputIndex); + static unresolved_output_exception build(const std::string &message, int nodeId, int outputIndex); + static unresolved_output_exception build(const std::string &message, std::pair &varIndex); + static unresolved_output_exception build(const std::string &message, const std::string &varName, int outputIndex = 0); }; } } diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index dd63c7dec62c..4b5dc246a104 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -326,6 +326,11 @@ namespace sd { addVariable(name, lvalue); } + void Graph::addNode(Node &&node, const std::vector &inputs) { + auto lvalue = node; + addNode(lvalue, inputs); + } + void Graph::addNode(Node &node, const std::vector &inputs) { if (node.id() != 0) throw std::runtime_error("Graph::addNode - Node has id defined"); @@ -1760,6 +1765,12 @@ namespace sd { } std::map Graph::execute(const std::map &dictionary, const std::vector &outputs, const GraphExecutor &executor) const { + // first of all we check existence of placeholders in dictionary + for (const auto &v:dictionary) { + if (_symbolicLookupTable.count(v.first) == 0) + throw unresolved_input_exception::build("Dictionary entry doesn't exist", v.first); + } + // TODO: implement this method return std::map(); } diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 632dd76dc798..bcaf42ba99ad 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -878,13 +878,104 @@ namespace sd { _rewindLayer.second = stepId; } - bool Node::equals(Node *other) { + bool Node::equals(Node *other) const { if (_opType == other->_opType && _dataType == other->_dataType && _opNum == other->_opNum) return true; return false; } + Node::Node(const Node &other) noexcept { + + } + + Node &Node::operator=(const Node &other) noexcept { + if (this == &other) + return *this; + + _dataType = other._dataType; + _opType = other._opType; + _opClass = other._opClass; + _opNum = other._opNum; + _customOp = other._customOp; + _name = other._name; + _scope_id = other._scope_id; + _scope_name = other._scope_name; + _rewindNode = other._rewindNode; + _layer = other._layer; + + _hasExternalOutputs = other._hasExternalOutputs; + _hasExternalInputs = other._hasExternalInputs; + _hasInternalOutputs = other._hasInternalOutputs; + _hasInternalInputs = other._hasInternalInputs; + _isInplace = other._isInplace; + _isDeductable = other._isDeductable; + _active = other._active; + _removable = other._removable; + + _graph = other._graph; + _customOp = other._customOp; + _dim = other._dim; + _extraParams = other._extraParams; + _protoContext = other._protoContext; + + _input = other._input; + _output = other._output; + _dimensions = other._dimensions; + _rewindLayer = other._rewindLayer; + _referencedBy = other._referencedBy; + _scalar = other._scalar; + + return *this; + } + + Node::Node(Node &&other) noexcept { + + } + + Node &Node::operator=(Node &&other) noexcept { + if (this == &other) + return *this; + + _dataType = other._dataType; + _opType = other._opType; + _opClass = other._opClass; + _opNum = other._opNum; + _customOp = other._customOp; + _scope_id = other._scope_id; + _name = std::move(other._name); + _scope_name = std::move(other._scope_name); + _rewindNode = other._rewindNode; + _layer = other._layer; + + _hasExternalOutputs = other._hasExternalOutputs; + _hasExternalInputs = other._hasExternalInputs; + _hasInternalOutputs = other._hasInternalOutputs; + _hasInternalInputs = other._hasInternalInputs; + _isInplace = other._isInplace; + _isDeductable = other._isDeductable; + _active = other._active; + _removable = other._removable; + + _graph = other._graph; + _customOp = other._customOp; + _dim = other._dim; + _extraParams = other._extraParams; + _protoContext = other._protoContext; + + _input = std::move(other._input); + _output = std::move(other._output); + _dimensions = std::move(other._dimensions); + _rewindLayer = std::move(other._rewindLayer); + _referencedBy = std::move(other._referencedBy); + _scalar = std::move(other._scalar); + + other._protoContext = nullptr; + other._customOp = nullptr; + + return *this; + } + void Node::deleteOpByType(OpType opType, void *op) { switch (opType) { case OpType_PAIRWISE: diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp index 32e5f18efe8b..3206ba513995 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp @@ -125,7 +125,8 @@ TEST_F(GraphTests2, test_input_resolution_1) { graph.addPlaceholder("input", DataType::FLOAT32); - graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); + Node a("tanh_node", sd::ops::tanh()); + graph.addNode(a, {"input"}); // since we're trying to resolve non-existent placeholder - we expect exception ASSERT_THROW(graph.execute({{"array", NDArrayFactory::create(0.5f)}}, {"tanh_node"}), graph::unresolved_input_exception); From a350345ffd350c0bd219522acb1a98d67754068f Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 31 Mar 2020 18:46:57 +0300 Subject: [PATCH 053/233] getting rid of legacy stuff Signed-off-by: raver119 --- libnd4j/include/array/impl/NDArrayFactory.cpp | 1 - libnd4j/include/graph/ExecutorConfiguration.h | 2 +- libnd4j/include/graph/Graph.h | 196 +-- libnd4j/include/graph/GraphState.h | 142 -- libnd4j/include/graph/Node.h | 25 +- libnd4j/include/graph/Variable.h | 3 +- .../graph/execution/GraphExecutioner.h | 72 - .../graph/execution/impl/GraphExecutioner.cpp | 649 -------- .../graph/impl/ExecutorConfiguration.cpp | 13 +- libnd4j/include/graph/impl/Graph.cpp | 1319 +---------------- libnd4j/include/graph/impl/GraphHolder.cpp | 5 +- libnd4j/include/graph/impl/GraphState.cpp | 167 --- libnd4j/include/graph/impl/Node.cpp | 62 +- libnd4j/include/graph/impl/Variable.cpp | 4 + .../graph/logic/impl/LogicConditional.cpp | 4 +- .../include/graph/logic/impl/LogicEnter.cpp | 3 + .../include/graph/logic/impl/LogicMerge.cpp | 11 +- .../graph/logic/impl/LogicNextIteration.cpp | 3 + .../include/graph/logic/impl/LogicReturn.cpp | 3 + .../include/graph/logic/impl/LogicSwitch.cpp | 4 +- .../include/graph/logic/impl/LogicWhile.cpp | 4 +- .../profiling/impl/GraphProfilingHelper.cpp | 6 +- libnd4j/include/legacy/NativeOps.h | 1 - libnd4j/include/legacy/cpu/NativeOps.cpp | 114 +- .../layers_tests/ConditionalTests.cpp | 332 ----- .../layers_tests/FlatBuffersTests.cpp | 1 - .../layers_tests/GraphAnalysisTests.cpp | 38 +- .../layers_tests/GraphStateTests.cpp | 349 ----- libnd4j/tests_cpu/layers_tests/GraphTests.cpp | 1007 +++---------- .../tests_cpu/layers_tests/GraphTests2.cpp | 10 +- .../layers_tests/ListOperationsTests.cpp | 260 ---- libnd4j/tests_cpu/layers_tests/NodeTests.cpp | 2 +- .../tests_cpu/layers_tests/OneOffTests.cpp | 71 +- .../tests_cpu/layers_tests/ProtoBufTests.cpp | 1 - libnd4j/tests_cpu/layers_tests/ScopeTests.cpp | 165 --- .../layers_tests/ServerRelatedTests.cpp | 1 - .../tests_cpu/layers_tests/SwitchTests.cpp | 251 ---- libnd4j/tests_cpu/layers_tests/testlayers.h | 1 - 38 files changed, 384 insertions(+), 4918 deletions(-) delete mode 100644 libnd4j/include/graph/GraphState.h delete mode 100644 libnd4j/include/graph/execution/GraphExecutioner.h delete mode 100644 libnd4j/include/graph/execution/impl/GraphExecutioner.cpp delete mode 100644 libnd4j/include/graph/impl/GraphState.cpp delete mode 100644 libnd4j/tests_cpu/layers_tests/ConditionalTests.cpp delete mode 100644 libnd4j/tests_cpu/layers_tests/GraphStateTests.cpp delete mode 100644 libnd4j/tests_cpu/layers_tests/ScopeTests.cpp delete mode 100644 libnd4j/tests_cpu/layers_tests/SwitchTests.cpp diff --git a/libnd4j/include/array/impl/NDArrayFactory.cpp b/libnd4j/include/array/impl/NDArrayFactory.cpp index d67514236141..9e6a0c8964f6 100644 --- a/libnd4j/include/array/impl/NDArrayFactory.cpp +++ b/libnd4j/include/array/impl/NDArrayFactory.cpp @@ -24,7 +24,6 @@ #include #include #include -#include #include #include diff --git a/libnd4j/include/graph/ExecutorConfiguration.h b/libnd4j/include/graph/ExecutorConfiguration.h index 61c001381eb1..8a9d3e831700 100644 --- a/libnd4j/include/graph/ExecutorConfiguration.h +++ b/libnd4j/include/graph/ExecutorConfiguration.h @@ -40,7 +40,7 @@ namespace sd { explicit ExecutorConfiguration(const sd::graph::FlatConfiguration *conf = nullptr); ~ExecutorConfiguration() = default; - ExecutorConfiguration* clone(); + ExecutorConfiguration clone() const; #ifndef __JAVACPP_HACK__ flatbuffers::Offset asFlatConfiguration(flatbuffers::FlatBufferBuilder &builder); diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 3edda9cb5b9d..bc78fc2fb8cc 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -45,51 +45,29 @@ namespace sd { class SD_EXPORT Graph { protected: - ExecutorConfiguration *_configuration; + ExecutorConfiguration _configuration; VariableSpace *_variableSpace; Stash* _stash; - // this list holds references to Node ptrs, which should be free'd in Graph destructor - std::vector _handles; + MAP_IMPL _unmapped; - // vector holds ID's of top nodes only - std::vector *_nodes; - MAP_IMPL *_mapped; - - MAP_IMPL *> *_onion; - MAP_IMPL _unmapped; + // string -> id conversion table MAP_IMPL _symbolicLookupTable; - std::vector _unmappedMap; // macOS? std::mutex _mutexPreprocessing; std::atomic _built; - std::vector _output; - std::vector _autos; - // we want to know last node id int _maxId = 1; - - MAP_IMPL _mappedScopes; - std::vector _scopes; - const GraphMemoryManager &_memoryMaager; //////////////////////////////////////// Nd4jStatus validateNode(Node *node); - void expandOnion(int newLayer); - - void injectNode(Node *node); - - void pushToOutputOnce(int id); - - void printOutNode(Node* node); - - void prepareOutputs(); - int idByName(const std::string &nodeName) const; + + void printOutNode(const Node &node) const; public: Graph(const FlatGraph *flatGraph = nullptr, VariableSpace *variableSpace = nullptr, const GraphMemoryManager &memoryManager = GraphMemoryManager()); @@ -102,25 +80,13 @@ namespace sd { static Graph* fromFlatBuffers(const char *fileName, const GraphMemoryManager &memoryManager = GraphMemoryManager()); static Graph* fromFlatPointer(void *ptr, const GraphMemoryManager &memoryManager = GraphMemoryManager()); - // this method applies toposort to nodes - void toposortNodes(); - // method that'll print out graph Nd4jStatus validate(); - // this method will build structured representation of graph - Nd4jStatus buildGraph(); - - // this method will return estimated memory size (in bytes) required for 1 full graph execution round - Nd4jLong estimateRequiredMemory(); - - // this method returns number of root nodes in this graph - int rootNodes(); - // this method returns total number of nodes in this graph - int totalNodes(); + int size() const; - int numberOfPlaceholders(); + int numberOfPlaceholders() const; const std::vector& getPlaceholders() const; @@ -132,19 +98,14 @@ namespace sd { /** * These methods add given node to the graph - * FIXME: deprecated * @param node */ - void addNode(Node *node); - void addNode(const Node &node); + //void addNode(Node &&node, const std::vector &inputs); + + void addNode(Node &node, const std::initializer_list &inputs); + void addNode(Node &node, const std::initializer_list &inputs); + void addNode(Node &node, const std::initializer_list> &inputs); - /** - * These methods add given node to the graph - * @param node - */ - void addNode(Node &&node, const std::vector &inputs); - void addNode(Node &node, const std::vector &inputs); - void addNode(Node &node, const std::vector> &inputs); void addVariable(const std::string &name, NDArray &array); void addVariable(const std::string &name, NDArray &&array); @@ -154,69 +115,19 @@ namespace sd { */ void addPlaceholder(const std::string &nodeName, const DataType dataType = sd::DataType::ANY, const std::vector &shape = {}); - /** - * This method returns layered representation of the graph - * - * @return - */ - MAP_IMPL *> *getOnion(); - - /** - * This method returns map of all nodes of the graph - * @return - */ - MAP_IMPL* getMapped(); - - /** - * This method returns outputs of this graph - * @return - */ - std::vector *fetchOutputs(); /** * This method returns pointer to ExecutorConfiguration * * @return */ - ExecutorConfiguration *getExecutorConfiguration(); - - /** - * This method adds specified node (by ID) to de - * @param id - */ - void addOutput(int id); - - /** - * This method returns all nodes at once (order is NOT guaranteed) - * @return - */ - std::vector *getAllNodes(); + const ExecutorConfiguration& getExecutorConfiguration() const; /** * This method prints out Graph op-by-op, and respective inputs */ void printOut(); - /** - * This method collect all ops from the graph into ops vector - */ - std::vector getOperations(); - - /** - * This method returns Scope ptr specified with id - * - * @param id - * @return - */ - Scope* scopeById(int id); - - /** - * This method returns TRUE if specified ID refers to Scope, and false otherwise - * @param id - * @return - */ - bool hasScope(int id); - /** * This method returns clone of the graph */ @@ -232,43 +143,15 @@ namespace sd { */ void forgetVariableSpace(); - /** - * This method returns Node with given Id - */ - Node* nodeById(int nodeId); - - /** - * This method returns True if node with given ID exists, False otherwise - * @param nodeId - * @return - */ - bool hasNode(int nodeId); - /** * This method returns hash of given Graph instance */ - Nd4jLong hashCode(); - - /** - * PLEASE NOTE: This method will be moved to private section - */ - void tagInplaceNodes(); - - void replaceState(VariableSpace *state, ExecutorConfiguration *configuration); - - FORCEINLINE std::vector* nodes(); + Nd4jLong hashCode() const; - FORCEINLINE std::vector* autos(); - - FORCEINLINE std::vector* output(); - - FORCEINLINE MAP_IMPL* scopes(); + void replaceState(VariableSpace *state, const ExecutorConfiguration &configuration); FORCEINLINE bool built(); - FORCEINLINE void pullState(Graph *other); - - OptimizedGraph optimizedGraph() const; /** @@ -290,59 +173,10 @@ namespace sd { const std::unordered_map& positions, OpSequence& opSeq) const; }; - FORCEINLINE std::vector* Graph::nodes() { - return _nodes; - } - - FORCEINLINE std::vector* Graph::autos() { - return &_autos; - } - - FORCEINLINE std::vector* Graph::output() { - return &_output; - } - - FORCEINLINE MAP_IMPL* Graph::scopes() { - return &_mappedScopes; - } FORCEINLINE bool Graph::built() { return _built.load(); } - - FORCEINLINE void Graph::pullState(Graph *other) { - for (int e = 0; e < other->nodes()->size(); e++) - this->_nodes->emplace_back(other->nodes()->at(e)); - - for (int e = 0; e < other->output()->size(); e++) - this->_output.emplace_back(other->output()->at(e)); - - for (int e = 0; e < other->autos()->size(); e++) - this->_autos.emplace_back(other->autos()->at(e)); - - for (auto &v: *other->scopes()) { - auto scp = v.second->clone(); - this->_mappedScopes[v.first] = scp; - this->_scopes.emplace_back(scp); - } - - for (auto &v: *other->getOnion()) { - auto vec = this->_onion->count(v.first) > 0 ? this->_onion->at(v.first) : new std::vector(); - - auto ovec = (*other->getOnion())[v.first]; - for (auto x: *(ovec)) { - auto n = x->clone(); - vec->emplace_back(n); - _handles.emplace_back(n); - (*this->_mapped)[n->id()] = n; - } - - if (this->_onion->count(v.first) < 1) - (*this->_onion)[v.first] = vec; - } - - this->_built.store(other->built()); - } } } diff --git a/libnd4j/include/graph/GraphState.h b/libnd4j/include/graph/GraphState.h deleted file mode 100644 index 122fe37a4f83..000000000000 --- a/libnd4j/include/graph/GraphState.h +++ /dev/null @@ -1,142 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 23.01.18. -// - -#ifndef LIBND4J_GRAPHSTATE_H -#define LIBND4J_GRAPHSTATE_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace sd { -namespace graph { - - class SD_EXPORT GraphState { - protected: - // id of this GraphState instance - Nd4jLong _id = 0; - - // map of scopes. Scope id is used as key, since it's referred in calls later anyway - MAP_IMPL _scopes; - - // this variable space holds temp references - VariableSpace _variableSpace; - - Graph *_graph; - - public: - explicit GraphState(Nd4jLong id); - ~GraphState(); - - /** - * - * @return - */ - Nd4jLong id(); - - /** - * This method adds scope to this state tracker - * - * @param scopeId - * @return - */ - Nd4jStatus registerScope(int scopeId); - - /** - * This method cheks if scope with given ID exists - * - * @param scopeId - ID of the scope - * @return - TRUE if scope exists, FALSE otherwise - */ - bool hasScope(int scopeId); - - /** - * This method removes specified scope from this state tracker - * - * @param scopeId - * @return - */ - Nd4jStatus forgetScope(int scopeId); - -#ifndef __JAVACPP_HACK__ - /** - * This method adds given op to the end of specified scope - * PLEASE NOTE: This method is used for tests mostly - * - * @param scopeId - * @param op - * @return - */ - Nd4jStatus attachOpToScope(int scopeId, int nodeId, sd::ops::DeclarableOp *op, ArgumentsList inputs); - - /** - * This method returns pointer to the scope with given id - * - * @param scopeId - id of the scope - */ - Scope* getScope(int scopeId); - - Graph* graph(); -#endif - /** - * This method adds given op to the end of specified scope - * - * @param scopeId - * @param opNum - * @param type - * @return - */ - Nd4jStatus attachOpToScope(int scopeId, Nd4jLong opNum, int type, ArgumentsList inputs); - - /** - * This method adds return statement to specified scope - * - * PLEASE NOTE: should be used only in body scopes - * - * @param scopeId - * @param nodeId - * @param args - * @return - */ - Nd4jStatus defineReturn(int scopeId, int nodeId, ArgumentsList args); - - /** - * This method returns current variable space of this state holder - * - * @return - */ - VariableSpace* variableSpace(); - }; -} -} - - - -#endif //LIBND4J_GRAPHSTATE_H diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 8d7e65f65cf8..6b4eff3e3316 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -124,16 +124,15 @@ namespace sd { // move assignment operator Node& operator=(Node&& other) noexcept; - bool equals(Node *other) const; sd::DataType dataType(); - ContextPrototype *protoContext(); - OpType opType(); - Nd4jLong opNum(); - int id(); - std::vector> *input(); - std::vector> *output(); + ContextPrototype *protoContext() const; + OpType opType() const; + Nd4jLong opNum() const; + int id() const; + const std::vector>& input() const; + const std::vector>& output() const; Nd4jLong getFrameId(); void setFrameId(Nd4jLong frameId); @@ -197,12 +196,12 @@ namespace sd { bool hasBlockAttached(); void setCustomOp(sd::ops::DeclarableOp *customOp = nullptr); - sd::ops::DeclarableOp* getCustomOp(); - bool hasCustomOp(); + sd::ops::DeclarableOp* customOp() const; + bool hasCustomOp() const; void setGraph(Graph* graph = nullptr); - Graph* getGraph(); - bool hasGraphEmbedded(); + Graph* graph() const; + bool hasGraphEmbedded() const; bool isInplace(); void markInplace(bool reallyInplace); @@ -251,10 +250,10 @@ namespace sd { if (this->_customOp != nullptr && _isDeductable) delete this->_customOp; - for (auto v: *other->input()) + for (auto &v: other->input()) this->_input.emplace_back(v); - for (auto v: *other->output()) + for (auto &v: other->output()) this->_output.emplace_back(v); for (auto v: *other->getDimensions()) diff --git a/libnd4j/include/graph/Variable.h b/libnd4j/include/graph/Variable.h index c302741c93cd..d1f2ddc78b91 100644 --- a/libnd4j/include/graph/Variable.h +++ b/libnd4j/include/graph/Variable.h @@ -131,7 +131,8 @@ namespace sd { void setId(int id); void setId(int id, int idx); - const std::string &getName() const; + const std::string& name() const; + const std::string& getName() const; void setName(const std::string &name); const std::vector& shape() const; diff --git a/libnd4j/include/graph/execution/GraphExecutioner.h b/libnd4j/include/graph/execution/GraphExecutioner.h deleted file mode 100644 index 785ca3a2658d..000000000000 --- a/libnd4j/include/graph/execution/GraphExecutioner.h +++ /dev/null @@ -1,72 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#ifndef LIBND4J_GRAPHEXECUTIONER_H -#define LIBND4J_GRAPHEXECUTIONER_H - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#define TF_INPUT "Placeholder" -#define TF_CONST "Const" -#define TF_VAR "VariableV2" - -namespace sd { - namespace graph { - /** - * TODO: REMOVE THIS CLASS - */ - class SD_EXPORT GraphExecutioner { - public: - //static Nd4jStatus executeFlatNode(sd::graph::Graph *graph, sd::graph::Node *node, sd::graph::VariableSpace *variableSpace); - - static Nd4jStatus executeFlatNode(Graph *graph, Node *node, VariableSpace *variableSpace); - - /** - * This method executes given Graph - * @return - */ - static Nd4jStatus execute(Graph *graph, VariableSpace *variableSpace = nullptr); - - - /** - * This method executes graph stored at given FlatBuffers pointer - * - * @param pointer Pointer to FlatBuffer - * @return pointer to FlatBuffer with result - */ - static sd::graph::ResultWrapper* executeFlatBuffer(Nd4jPointer pointer); - - static flatbuffers::Offset execute(Graph *graph, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request); - }; - } -} - - -#endif //LIBND4J_GRAPHEXECUTIONER_H diff --git a/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp b/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp deleted file mode 100644 index 4cda626bb115..000000000000 --- a/libnd4j/include/graph/execution/impl/GraphExecutioner.cpp +++ /dev/null @@ -1,649 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include -#include -#include - -//#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -//#include -//#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace sd{ - namespace graph { - /** - * This method executes given Node (as in Op within Node) - * - * Basically it just does DeclarableOp::execute(Block), and ops to their job. However, there are some additional functionality. - * - * @param graph - Graph instance pointer - * @param node - Node instance pointer, which will be executed - * @param variableSpace - VariableSpace instance pointer - varspace specific to current Thread/Session - * @return - */ - Nd4jStatus GraphExecutioner::executeFlatNode(Graph *graph, Node *node, VariableSpace *variableSpace) { - OpType opType = node->opType(); - int opNum = node->opNum(); - // std::string opName = *(node->getCustomOp()->getOpName()); - - if (opType == OpType_BOOLEAN) { - nd4j_debug("Executing boolean graph node_%i", node->id()); - } else if (opType == OpType_LOGIC) { - nd4j_debug("Executing logic graph node_%i", node->id()); - } else if (opType == OpType_GRAPH) { - nd4j_debug("Executing embedded graph node_%i", node->id()); - } else if (opType != OpType_CUSTOM) { - nd4j_debug("Executing node_%i{%i}\n", node->id(), opNum); - } else { - nd4j_debug("Executing node_%i{%s}\n", node->id(), node->getCustomOp()->getOpName().c_str()); - } - - Context context(node->getContextPrototype(), variableSpace); - - if (sd::Environment::getInstance()->isDebugAndVerbose()) { - //nd4j_debug("Input variables: %i\n", node->input()->size()); - printf(" Inputs: {"); - for (int e = 0; e < node->input()->size(); e++) { - printf("[%i:%i]", node->input()->at(e).first, node->input()->at(e).second); - - if (e < node->input()->size() - 1) - printf(", "); - } - printf("}\n"); - fflush(stdout); - } - - if (node->id() == 13) - nd4j_debug("",""); - - // if true - this is special case: Graph-in-Graph. - if (node->hasGraphEmbedded()) { - auto embedded = node->getGraph(); - - /** - * basically, we should do following things here: - * 1) fill embedded graph with input variables from this graph, if anything should be filled in - * 2) invoke embedded graph - * 3) announce its results as corresponding output variables in current VariableSpace - */ - - // enforcing IMPLICIT mode. or not... should we try to be smarter then user? - //embedded->getExecutorConfiguration()->_outputMode = OutputMode_IMPLICIT; - - if (node->input()->size() != embedded->numberOfPlaceholders()) { - nd4j_debug("Placeholders amount mismatch: %i expected, and %i available\n",node->input()->size(), embedded->numberOfPlaceholders()); - return ND4J_STATUS_BAD_INPUT; - } - - // we need to propagate required variables to the embedded graph - ResultSet deletables; - int cnt = 0; - for (Variable* v: embedded->getPlaceholders()) { - if (!v->getName().empty()) { - - // trying symbolic lookup first - if (variableSpace->hasVariable(v->getName())) { - // symbolic feeder - auto array = variableSpace->getVariable(v->getName())->getNDArray(); - auto vr = new NDArray(array->dup()); - // deletables.push_back(vr); - v->setNDArray(vr); - } else { - nd4j_debug("Can't find variable [%s] in parent graph...", v->getName().c_str()); - return ND4J_STATUS_BAD_INPUT; - //throw "Can't find desired variable"; - } - } else { - // if we're not using symbolic lookup - we'll use sequential approach then - auto p = node->input()->at(cnt); - auto array = variableSpace->getVariable(p)->getNDArray(); - auto vr = new NDArray(array->dup()); - //deletables.push_back(vr); - v->setNDArray(vr); - } - - cnt++; - } - - // executing embedded graph as independent one - Nd4jStatus status = GraphExecutioner::execute(embedded); - if (status != ND4J_STATUS_OK) - return status; - - // now we should migrate its results to this node, as its own outputs - cnt = 0; - auto outputs = embedded->fetchOutputs(); - - for (auto v: *outputs){ - NDArray *array = v->getNDArray(); - v->setNDArray(nullptr); - - std::pair pair(node->id(), cnt++); - - auto var = variableSpace->getVariable(pair); - - //nd4j_printf("HasArray: [%i]; Removable: [%i]\n", var->hasNDArray(), var->isRemovable()); - var->setNDArray(array); - var->markRemovable(true); - } - deletables.size(); - delete outputs; - nd4j_debug("Embedded graph execution finished. %i variable(s) migrated\n", cnt); - - } else if (node->hasCustomOp()) { - // now, if we have something to execute - lets just execute it. - auto status = node->getCustomOp()->execute(&context); - if (status != ND4J_STATUS_OK) - return status; - - // propagate variables - if (node->hasExternalOutputs()) { - for (auto v: *node->output()) { - if (variableSpace->hasExternalVariable(v.first)) { - variableSpace->getVariable(v.first)->getNDArray()->assign(variableSpace->getVariable(node->id())->getNDArray()); - } - } - } - - return status; - } - return ND4J_STATUS_OK; - } - - - /** - * This method executes given Graph instance, and returns error code. - * - * @param graph - * @return one of error codes defined in pointercast.h - */ - Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace) { - auto __variableSpace = variableSpace == nullptr ? graph->getVariableSpace() : variableSpace; - - bool tempFlow = false; - if (__variableSpace->flowPath() == nullptr) { - tempFlow = true; - __variableSpace->setFlowPath(new FlowPath()); - } - auto flowPath = __variableSpace->flowPath(); - - Nd4jLong tb0 = Environment::getInstance()->isProfiling() ? GraphProfile::currentTime() : 0L; - graph->buildGraph(); - - auto footprintForward = sd::memory::MemoryRegistrator::getInstance()->getGraphMemoryFootprint(graph->hashCode()); - if (footprintForward > 0) { - if (__variableSpace->launchContext()->getWorkspace() != nullptr) { - // this method will work only if current workspace size is smaller then proposed value - nd4j_debug("Setting workspace to %lld bytes\n", footprintForward); - __variableSpace->launchContext()->getWorkspace()->expandTo(footprintForward); - } - } - - // optionally saving graph build time - if (Environment::getInstance()->isProfiling()) - flowPath->profile()->setBuildTime(GraphProfile::relativeTime(tb0)); - - Nd4jLong timeStart = Environment::getInstance()->isProfiling() ? GraphProfile::currentTime() : 0L; - - bool pe = graph->getExecutorConfiguration()->_executionMode == ExecutionMode_AUTO; - - - // basically if at some point code diverges, code branch might be _DISABLED_, and all nodes within that branch will be disabled as well - - std::deque frames; - bool inFrame = false; - bool leftFrame = false; - - auto nodeTime = GraphProfile::currentTime(); - int lastId = -10000000; - Nd4jLong exec_counter = 0; - // we loop through op layers here - for (int l = 0; l < (int) graph->getOnion()->size(); l++) { - int layerSize = graph->getOnion()->count(l) == 1 ? graph->getOnion()->at(l)->size() : 0; - - int n = 0; - // this omp block will probably never be the case - for (; n < layerSize; n++) { - if (++exec_counter > 10000) { - l = graph->getOnion()->size(); - return Status::THROW("Early termination hit"); - } - - Node* node = graph->getOnion()->at(l)->at(n); - - if (Environment::getInstance()->isProfiling()) - flowPath->profile()->nodeById(node->id(), node->name().c_str()); - - if (lastId != node->id() && Environment::getInstance()->isProfiling()) { - if (lastId != -10000000) - flowPath->profile()->nodeById(lastId)->setTotalTime(GraphProfile::relativeTime(nodeTime)); - - lastId = node->id(); - nodeTime = GraphProfile::currentTime(); - } - - nd4j_debug("Step: %lld; Node: %i <%s>\n", exec_counter, node->id(), node->name().c_str()); - - // on first non-Exit node after loop we can rewind (if planned) - if (!(node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Exit)) { - // VALIDATED - - // if we're out of frame - let's remove it from queue - if (leftFrame) { - auto frame_id = frames.back(); - frames.pop_back(); - flowPath->markFrameActive(frame_id, false); - flowPath->forgetFrame(frame_id); - - leftFrame = false; - } - - - // TODO: move inactivity check right here - bool shouldSkip = false; - if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Merge) { - // Merge node has own checkout logic - - auto inputId0 = node->input()->at(0); - auto inputId1 = node->input()->at(1); - - // Merge node can be skipped only both inputs are inactive - if (!flowPath->isNodeActive(inputId0.first) && !flowPath->isNodeActive(inputId1.first)) - shouldSkip = true; - - } else { - // let's check for input nodes, if they are disabled or contain divergents - for (int e = 0; e < node->input()->size(); e++) { - auto inputId = node->input()->at(e); - - // not a node. skipping checks - if (graph->getMapped()->count(inputId.first) == 0) - continue; - - /** - * We can skip current node, in two cases: - * 1) If previous node was disabled - * 2) If previous node was divergent node (i.e. IF op) and code went other way - */ - Node *prevNode = graph->getMapped()->at(inputId.first); - if (!flowPath->isNodeActive(inputId.first)) { - shouldSkip = true; - flowPath->markNodeActive(node->id(), false); - - nd4j_debug("Skipping Node_%i due to inactive input [%i]\n", node->id(), inputId.first); - break; - - } else if (prevNode->isDivergencePoint()) { // literally checking for switch here - if (flowPath->branch(inputId.first) != inputId.second) { - shouldSkip = true; - flowPath->markNodeActive(node->id(), false); - nd4j_debug("Skipping Node_%i due to divergent branch [%i]\n", node->id(), - inputId.first); - break; - } - } - } - } - - if (shouldSkip) - continue; - } - - // we're propagating frameId here (but only if wasn't set earlier) - if (frames.size() > 0 && node->getFrameId() < 0) - node->setFrameId(frames.back()); - - - flowPath->markNodeActive(node->id(), true); - - if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Enter) { - // Enter operation - // VALIDATED - - // we expect this node to have frameId set - auto frame_id = node->getFrameId(); - - // new frame starts here - if (frames.size() == 0 || (frames.size() > 0 && frames.back() != frame_id)) { - flowPath->registerFrame(frame_id); - frames.emplace_back(frame_id); - inFrame = true; - } - - - auto status = LogicExecutor::processNode(graph, node); - if (status != Status::OK()) - return status; - - } else if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::NextIteration) { - /** - * NextIteration is special case: after successful execution of this op - we're changing execution position - */ - // VALIDATED - auto inputId = node->input()->at(0); - - auto status = LogicExecutor::processNode(graph, node); - if (status != Status::OK()) - return status; - - auto frame_id = frames.back(); - - flowPath->markNodeActive(node->id(), true); - flowPath->markExecuted(node->id(), true); - - if (!flowPath->isRewindPlanned(frame_id)) { - auto nextLayer = node->getRewindLayer(); - - nd4j_debug("Node_%i planned rewind to Node_%i at [%i:%i]\n", node->id(), node->getRewindNode(), nextLayer.first, nextLayer.second); - - flowPath->planRewind(frame_id, true); - flowPath->setRewindPositionOnce(frame_id, nextLayer.first - 1); - - continue; - } - - - } else if (node->opType() == OpType_LOGIC && node->opNum() == sd::logic::Exit) { - // Exit node is another special case: it can rewind executioner to specific point in graph - // VALIDATED - - auto frame_id = frames.back(); - - // if this loop frame wasn't activated - just skip it - if (!flowPath->isFrameActive(frame_id)) { - flowPath->markNodeActive(node->id(), false); - - leftFrame = true; - continue; - } - - if (flowPath->isRewindPlanned(frame_id)) { - // just break loop here - l = flowPath->getRewindPosition(frame_id); - flowPath->setRewindPosition(frame_id, -1); - flowPath->planRewind(frame_id, false); - - break; - } else { - // execute Exit node otherwise - - auto status = LogicExecutor::processNode(graph, node); - if (status != Status::OK()) - return status; - - leftFrame = true; - } - - - } else if (node->opType() == OpType_LOGIC) { - /** - * If this LOGIC op, we'll use another execution model here - */ - auto status = LogicExecutor::processNode(graph, node); - - if (status != Status::OK()) - return status; - } else { - - - auto timeStart = std::chrono::system_clock::now(); - - // actual node execution happens right here - Nd4jStatus status = executeFlatNode(graph, node, __variableSpace); - - auto timeEnd = std::chrono::system_clock::now(); - - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - - - flowPath->setOuterTime(node->id(), outerTime); - - if (status != ND4J_STATUS_OK) - return status; - - - // here we should handle divergent ops, and disable nodes accordingly - if (node->isDivergencePoint()) { - auto activeBranch = flowPath->branch(node->id()); - nd4j_debug("Active branch at node [%i]: %i\n", node->id(), activeBranch); - - // now we skip all branches except of this active one - } - - if (sd::Environment::getInstance()->isDebugAndVerbose()) { - - if (__variableSpace->getVariable(node->id())->hasNDArray()) { - auto array = __variableSpace->getVariable(node->id())->getNDArray(); - auto shape = ShapeUtils::shapeAsString(array); - auto values = array->asIndexedString(16); - auto type = DataTypeUtils::asString(array->dataType()); - nd4j_debug("node_%i finished. result shape: %s; data type: %s; first values: %s\n", node->id(), shape.c_str(), type.c_str(), values.c_str()); - } else if (__variableSpace->getVariable(node->id())->hasNDArrayList()) { - auto list = __variableSpace->getVariable(node->id())->hasNDArrayList() ? __variableSpace->getVariable(node->id())->getNDArrayList() : nullptr; - nd4j_debug("node_% is ListOp, skipping evaluation", node->id()); - } else { - nd4j_debug("node_% is Unknown: has no NDArray or NDArrayList", node->id()); - } - } - } - - // if node was executed - tag it as active - flowPath->markExecuted(node->id(), true); - } - } - - // optionally saving execution time - if (Environment::getInstance()->isProfiling()) { - flowPath->profile()->nodeById(lastId)->setTotalTime(GraphProfile::relativeTime(nodeTime)); - flowPath->profile()->setExecutionTime(GraphProfile::relativeTime(timeStart)); - //flowPath->profile().printOut(); - } - - // saving memory footprint for current run - if (__variableSpace->launchContext()->getWorkspace() != nullptr) { - auto m = __variableSpace->launchContext()->getWorkspace()->getAllocatedSize(); - auto h = graph->hashCode(); - sd::memory::MemoryRegistrator::getInstance()->setGraphMemoryFootprintIfGreater(h, m); - } - - if (tempFlow) { - delete flowPath; - __variableSpace->setFlowPath(nullptr); - } - - return Status::OK(); - } - - /** - * This method is provided for IPC: - * 1) it accepts pointer to FlatBuffers buffer - * 2) restores Graph from it - * 3) Executes this Graph - * 4) Packs execution results into FlatBuffers (FlatResults instance) - * 5) Returns pointer to FlatBuffer results buffer - * - */ - sd::graph::ResultWrapper* GraphExecutioner::executeFlatBuffer(Nd4jPointer pointer) { - uint8_t *buffer = reinterpret_cast(pointer); - - // nd4j_debug("Trying to restore graph\n", 0); - - auto restoredGraph = GetFlatGraph(buffer); - - // nd4j_debug("Graph restored\n", 0); - - // converting FlatGraph to internal representation - auto nativeGraph = new Graph(restoredGraph); - - if (Environment::getInstance()->isDebugAndVerbose()) { - nativeGraph->printOut(); - } - - FlowPath flowPath; - nativeGraph->getVariableSpace()->setFlowPath(&flowPath); - - - // nd4j_debug("Going to execute graph\n", 0); - - // executing internal representation - auto status = GraphExecutioner::execute(nativeGraph); - if (status != ND4J_STATUS_OK) { - nd4j_printf("Graph execution failed with status: [%i]\n", status) - return nullptr; - } - - // nd4j_debug("Building output...\n", 0); - - flatbuffers::FlatBufferBuilder builder(1024); - - // fetching time reports - std::vector> timings_vector; - for (int e = 0; e < (int) nativeGraph->getAllNodes()->size(); e++) { - Node *node = nativeGraph->getAllNodes()->at(e); - - if (node->getContextPrototype() == nullptr) - continue; - - auto pair = CreateLongPair(builder, flowPath.outerTime(node->id()), flowPath.innerTime(node->id())); - if (!node->name().empty()) { - auto name = builder.CreateString(node->getName().c_str()); - auto fr = CreateFlatTiming(builder, node->id(), name, pair); - timings_vector.push_back(fr); - } else { - auto fr = CreateFlatTiming(builder, node->id(), 0, pair); - timings_vector.push_back(fr); - } - } - - - // now, we'll prepare output, depending on given outputmode - auto outputs = nativeGraph->fetchOutputs(); - auto size = static_cast(outputs->size()); - int arrays = 0; - std::vector> variables_vector; - for (int e = 0; e < size; e++) { - auto var = outputs->at(e); - - // FIXME: we want to export multi-output nodes as well - // FIXME: we want to export NDArrayList and skip nodes without outputs - if (!var->hasNDArray()) - continue; - - - auto array = var->getNDArray(); - - auto fArray = FlatUtils::toFlatArray(builder, *array); - - auto fName = builder.CreateString(var->getName()); - auto id = CreateIntPair(builder, var->id(), var->index()); - - auto fv = CreateFlatVariable(builder, id, fName, static_cast(array->dataType()), 0, fArray); - - variables_vector.push_back(fv); - arrays++; - } - - nd4j_debug("Returning %i variables back\n", arrays); - - auto varTimings = builder.CreateVector(timings_vector); - auto varVectors = builder.CreateVector(variables_vector); - auto result = CreateFlatResult(builder, restoredGraph->id(), varVectors, varTimings); - builder.Finish(result); - - // we might want to keep this graph for future - delete outputs; - delete nativeGraph; - - char* res = new char[builder.GetSize()]; - memcpy(res, builder.GetBufferPointer(), builder.GetSize()); - - nd4j_debug("Buffer size: %lld\n", static_cast(builder.GetSize())); - - return new ResultWrapper(builder.GetSize(), reinterpret_cast(res)); - } - - flatbuffers::Offset GraphExecutioner::execute(Graph *graph, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request) { - ExecutionResult result; - auto varSpace = graph->getVariableSpace(); - - if (request != nullptr && request->variables() != nullptr) { - auto vars = request->variables(); - for (int e = 0; e < vars->size(); e++) { - auto fv = vars->Get(e); - auto v = new Variable(fv); - varSpace->replaceVariable(v); - } - } - - if (Environment::getInstance()->isDebugAndVerbose()) - graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - if (status != sd::Status::OK()) - throw graph_execution_exception(request->id()); - - auto outputs = graph->fetchOutputs(); - - if (outputs->size() == 0) - throw no_results_exception(request->id()); - - - for (auto v: *outputs) { - result.emplace_back(v); - } - - auto t = result.asFlatResult(builder); - - delete outputs; - - return t; - } - } -} \ No newline at end of file diff --git a/libnd4j/include/graph/impl/ExecutorConfiguration.cpp b/libnd4j/include/graph/impl/ExecutorConfiguration.cpp index f296ef3cd4df..0a578848980f 100644 --- a/libnd4j/include/graph/impl/ExecutorConfiguration.cpp +++ b/libnd4j/include/graph/impl/ExecutorConfiguration.cpp @@ -39,17 +39,8 @@ namespace sd { } }; - ExecutorConfiguration* ExecutorConfiguration::clone() { - auto clone = new ExecutorConfiguration(); - clone->_profilingMode = _profilingMode; - clone->_executionMode = _executionMode; - clone->_outputMode = _outputMode; - clone->_timestats = _timestats; - clone->_direction = _direction; - clone->_footprintForward = _footprintForward; - clone->_footprintBackward = _footprintBackward; - - return clone; + ExecutorConfiguration ExecutorConfiguration::clone() const { + return ExecutorConfiguration(*this); }; flatbuffers::Offset ExecutorConfiguration::asFlatConfiguration(flatbuffers::FlatBufferBuilder &builder) { diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 4b5dc246a104..4773685dd9da 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -34,278 +34,24 @@ namespace sd { namespace graph { - std::vector* Graph::getAllNodes() { - return &_handles; - } - const std::vector& Graph::getPlaceholders() const { return *_variableSpace->getPlaceholders(); } - int Graph::numberOfPlaceholders() { + int Graph::numberOfPlaceholders() const { return _variableSpace->numberOfPlaceholders(); }; - Nd4jLong Graph::estimateRequiredMemory() { - - Nd4jLong result = 0L; - Nd4jLong lastStep = 0L; - - std::vector shapes; - MAP_IMPL, Nd4jLong*> shapesMap; - - int cntFD = 0; - - // we loop in similar way to execution - for (int l = 0; l < (int) _onion->size(); l++) { - int layerSize = _onion->count(l) == 1 ? _onion->at(l)->size() : 0; - - - for (int n = 0; n < layerSize; n++) { - Node* node = _onion->at(l)->at(n); - - /* - * Limited number of options here: - * - * 1) Op is inplace, so adds nothing to total - * 2) Op is not inplace, and 1:1 transform - * 3) Op is reduction (i.e. sum) - * 4) Op is multiplicator (i.e. im2col) - */ - if (node->hasCustomOp()) { - //if (node->isInplace()) { - // continue; - //} - - - nd4j_debug("Trying estimation [%i] on [%s]\n", node->id(), node->getCustomOp()->getOpName().c_str()); - - auto op = node->getCustomOp(); - auto in = node->input()->at(0); - - auto block = node->getContextPrototype(); - std::vector inputShapes; - int *oldShape; - for (auto v: *node->input()) { - nd4j_debug(" inputs for estimation are: %i:%i\n", v.first, v.second); - if (v.first < 0) { - inputShapes.push_back(_variableSpace->getVariable(v.first)->getNDArray()->getShapeInfo()); - } else { - inputShapes.push_back(shapesMap.at(v)); - } - } - - Context ctx(block, _variableSpace); - - ShapeList inSha(inputShapes); - auto outSha = op->calculateOutputShape(&inSha, ctx); - - int cnt = 0; - for (auto newShape: *outSha->asVector()) { - std::pair pairAddr(node->id(), cnt++); - std::pair, Nd4jLong*> pairShape(pairAddr, newShape); - - shapesMap.insert(pairShape); - - if (!block->isInplace() && !node->isInplace()) - result += shape::length(newShape) * DataTypeUtils::sizeOfElement(node->dataType()); - - shapes.push_back(newShape); - } - - delete outSha; - } else if (node->getOpClass() == OpClass_TRANSFORM) { - auto vec = node->input(); - - auto in = node->input()->at(0); - if (in.first < 0) { - - auto x = _variableSpace->getVariable(in); - auto z = _variableSpace->getVariable(node->id()); - - auto newShape = new Nd4jLong[shape::shapeInfoLength(x->getNDArray()->getShapeInfo())]; - memcpy(newShape, x->getNDArray()->getShapeInfo(), shape::shapeInfoByteLength(x->getNDArray()->getShapeInfo())); - - std::pair pairAddr(node->id(), 0); - std::pair, Nd4jLong*> pairShape(pairAddr, newShape); - - shapesMap.insert(pairShape); - - if (!node->isInplace()) - result += shape::length(newShape) * DataTypeUtils::sizeOfElement(node->dataType()); - - shapes.push_back(newShape); - } else { - auto prevShape = shapesMap.at(in); - - auto newShape = new Nd4jLong[shape::shapeInfoLength(prevShape)]; - memcpy(newShape, prevShape, shape::shapeInfoByteLength(prevShape)); - - std::pair pairAddr(node->id(), 0); - std::pair, Nd4jLong*> pairShape(pairAddr, newShape); - - shapesMap.insert(pairShape); - - if (!node->isInplace()) - result += shape::length(newShape) * DataTypeUtils::sizeOfElement(node->dataType()); - - shapes.push_back(newShape); - } - - } else if (node->getOpClass() == OpClass_REDUCTION) { - Nd4jLong *newShape = nullptr; - - // if that's scalar output - we don't care about previous node - if (node->getDimensions()->size() == 0 || (node->getDimensions()->size() == 1 && node->getDimensions()->at(0) == sd::DataTypeUtils::max())) { - newShape = new Nd4jLong[8]; - - newShape[0] = 2; - newShape[1] = 1; - newShape[2] = 1; - newShape[3] = 1; - newShape[4] = 1; - newShape[5] = 8192; // set type as FLOAT32 by default - newShape[6] = 1; - newShape[7] = 99; - - } else { - auto in = node->input()->at(0); - - Nd4jLong *oldShape = nullptr; - // calculate tads here - if (in.first < 0) { - auto x = _variableSpace->getVariable(in)->getNDArray(); - - oldShape = x->getShapeInfo(); - } else { - - oldShape = shapesMap.at(in); - } - - //shape::TAD tad(oldShape, node->getDimensions()->data(), node->getDimensions()->size()); - auto numTads = shape::tadLength(oldShape, node->getDimensions()->data(), node->getDimensions()->size()); - Nd4jLong shape[2] = {1, (int) numTads}; - newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(oldShape), 'c', 2, shape); - } - - std::pair pairAddr(node->id(), 0); - std::pair, Nd4jLong*> pairShape(pairAddr, newShape); - - shapesMap.insert(pairShape); - - result += shape::length(newShape) * DataTypeUtils::sizeOfElement(node->dataType()); - - shapes.push_back(newShape); - } else if (node->getOpClass() == OpClass_MULTIPLICATOR) { - // can't be in non-special op - } - - - cntFD++; - } - } - - // this is the only place where we deallocate shapes. - //if (_variableSpace->launchContext()->getWorkspace() == nullptr) - // for (auto v: shapes) - // delete[] v; - - return result; - } - - void Graph::pushToOutputOnce(int id) { - if (std::find(_output.begin(), _output.end(), id) == _output.end()) - _output.emplace_back(id); - } - - void Graph::addOutput(int id) { - if (_configuration->_outputMode == OutputMode_EXPLICIT || _configuration->_outputMode == OutputMode_EXPLICIT_AND_IMPLICIT) - pushToOutputOnce(id); - } - - ExecutorConfiguration * Graph::getExecutorConfiguration() { + const ExecutorConfiguration& Graph::getExecutorConfiguration() const { return _configuration; } - std::vector * Graph::fetchOutputs() { - auto res = new std::vector(); - - nd4j_debug("Graph output size: %i\n", _output.size()); - for (int e = 0; e < (int) _output.size(); e++) { - auto nodeId = _output.at(e); - nd4j_debug("Output node: %i\n", nodeId); - - for (int e = 0; e < DataTypeUtils::max(); e++) { - if (_variableSpace->hasVariable(nodeId, e)) { - res->push_back(_variableSpace->getVariable(nodeId, e)); - } else { - if (e == 0) { - throw unresolved_output_exception::build("Can't find output variable", nodeId, e); - } else - break; - } - } - } - - return res; - } - - MAP_IMPL * Graph::getMapped() { - return _mapped; - } - - MAP_IMPL *>* Graph::getOnion() { - return _onion; - } - - void Graph::injectNode(Node *node) { - if (node->getLayer() < 0) - throw std::runtime_error("Only nodes with non-negative layer defined can be inserted"); - - std::pair pair(node->id(), node); - if (_mapped->count(pair.first) > 0) - return; - - nd4j_debug("Node_%i mapped to layer_%i\n", node->id(), node->getLayer()); - - - _onion->at(node->getLayer())->push_back(node); - _mapped->insert(pair); - } - - void Graph::expandOnion(int newLayer) { - if (_onion->count(newLayer) > 0) - return; - - std::vector *rootList = new std::vector(); - std::pair*> pair(newLayer, rootList); - _onion->insert(pair); - } - VariableSpace * Graph::getVariableSpace() const { return _variableSpace; } Graph::~Graph() { - for (auto &v: *_mapped) - if (v.second->isRemovable()) - delete v.second; - - for (auto &v: _unmapped) - if (v.second->isRemovable()) - delete v.second; - - for (auto &v: *_onion) - delete v.second; - - for (auto v: _scopes) - delete v; - - delete _mapped; - delete _nodes; delete _variableSpace; - delete _onion; - delete _configuration; } int Graph::idByName(const std::string &nodeName) const { @@ -325,18 +71,17 @@ namespace sd { auto lvalue = array; addVariable(name, lvalue); } - +/* void Graph::addNode(Node &&node, const std::vector &inputs) { auto lvalue = node; addNode(lvalue, inputs); } + */ - void Graph::addNode(Node &node, const std::vector &inputs) { + void Graph::addNode(Node &node, const std::initializer_list &inputs) { if (node.id() != 0) throw std::runtime_error("Graph::addNode - Node has id defined"); - node.markRemovable(false); - // node must have numeric id node.setId(_maxId++); _symbolicLookupTable[node.name()] = node.id(); @@ -345,620 +90,56 @@ namespace sd { for (auto &v:inputs) node.pickInput(idByName(v), 0); - addNode(&node); + _unmapped[node.id()] = node; } - void Graph::addNode(Node &node, const std::vector> &inputs) { + void Graph::addNode(Node &node, const std::initializer_list &inputs) { node.markRemovable(false); throw std::runtime_error("not implemented yet"); } - void Graph::addNode(const Node &node) { + void Graph::addNode(Node &node, const std::initializer_list> &inputs) { node.markRemovable(false); - addNode(const_cast(&node)); - } - - void Graph::addNode(Node *node) { - _built.store(false); - - if (node->opType() == OpType_LOGIC) { - // nd4j_debug("Adding LogicOp [%i]\n", node->opNum()); - // SCOPE - if (node->opNum() == logic::Scope) { - auto scope = new Scope(node->id(), !node->getName().empty() ? node->getName().c_str() : ""); - _mappedScopes[node->id()] = scope; - _scopes.push_back(scope); - } - } - - auto cname = node->name().empty() ? nullptr : node->getName().c_str(); - auto nodeState = _variableSpace->hasVariable(node->id()) ? _variableSpace->getVariable(node->id()) :new Variable(nullptr, cname, node->id()); - if (!node->name().empty()) - nodeState->setName(node->name()); - - - if (node->isInplace()) - nodeState->markRemovable(false); - - _handles.push_back(node); - - - _nodes->emplace_back(node->id()); - - // storing node state now - _variableSpace->putVariable(node->id(), nodeState); - - // here we're filling our blocks with future variables - if (node->opType() == OpType_LOGIC && node->opNum() == logic::While) { - // filling while - int inputs = node->input()->size(); - for (int e = 0; e < inputs - 2; e++){ - auto deepVar = new Variable(nullptr, nullptr, node->id(), e); - - std::pair id(node->id(), e); - _variableSpace->putVariable(id, deepVar); - } - } else if (node->hasCustomOp()) { - // custom ops require Block inside. but we'll set it inside buildGraph - - // TODO: we want to change this, to make blocks thread-local/session-local - ContextPrototype* block = nullptr; - - if (!node->hasBlockAttached()) { - block = new ContextPrototype(node->getCustomOp()->getOpDescriptor(), node->id()); - node->setContextPrototype(block); - } else - block = node->getContextPrototype(); - - - if (!block->hasVariablesFilled()) { - - for (uint32_t e = 0; e < node->input()->size(); e++) { - auto p = node->input()->at(e); - - block->pickInput(p); - } - } - - // and might have > 1 output - if (node->getCustomOp()->getOpDescriptor()->getNumberOfOutputs() > 1) { - for (int e = 1; e < node->getCustomOp()->getOpDescriptor()->getNumberOfOutputs(); e++) { - auto deepVar = new Variable(nullptr, nullptr, node->id()); - //deepVar->setId(node->id()); - deepVar->setId(node->id(), e); - if (node->isInplace()) - deepVar->markRemovable(false); - - std::pair id(node->id(), e); - _variableSpace->putVariable(id, deepVar); - } - } else { - // we need to check, if we should propagate output of this variable somewhere - for (int e = 0; e < node->output()->size(); e++) { - auto out = node->output()->at(e); - if (out.first < 0) { - nd4j_debug("Node [%i] will be propagating its output to Variable [%i]\n", node->id(), out.first); - auto extVar = _variableSpace->getVariable(out); - if (extVar->hasNDArray()) { - nodeState->setNDArray(extVar->getNDArray()); - nodeState->markRemovable(false); - } - } - } - } - } - - // we're saving only ops that have internal outpus here - if (_configuration->_outputMode == OutputMode_VARIABLE_SPACE) - if (node->hasInternalOutputs()) - pushToOutputOnce(node->id()); - - // if outputs are undefined, we have to auto-create variable - if (node->output()->size() == 0 || (node->output()->size() == 1 && node->output()->at(0).first == 0)){ - Variable* var; - if (!_variableSpace->hasVariable(node->id())) { - var = new Variable(); - } else { - var = _variableSpace->getVariable(node->id()); - } - // nd4j_logger("Adding auto output variable; Output size: %i\n", node->output()->size()); - - var->setId(node->id()); - var->setName(node->getName()); - _variableSpace->putOutputVariable(var); - //node->pickExternalOutput(var->id()); - - this->_autos.push_back(var->id()); - -// } - } else if (node->hasExternalOutputs()) { - // TODO: we might want this behavior configurable! - nd4j_logger("Adding specific output variable: Outputs: %i; HasInternal: %i;\n", node->output()->size(), node->hasInternalOutputs()); - - // we're pushing this node to output only - if ((!node->hasInternalOutputs() && (_configuration->_outputMode == OutputMode_IMPLICIT || _configuration->_outputMode == OutputMode_EXPLICIT_AND_IMPLICIT)) ) { - for (int e = 0; e < (int) node->output()->size(); e++) { - if (node->output()->at(e).first < 0) - pushToOutputOnce(node->output()->at(e).first); - } - - nd4j_logger("Loop finished: %i outputs now\n", this->_output.size()); - } - } - - // ops that are tied to specific scope are never placed into the structure. - if (node->isScoped()) { - if (_mappedScopes.count(node->scopeId()) < 1) { - nd4j_printf("Requested scope [%i/%s] wasn't created yet\n", node->scopeId(), node->scopeName()->c_str()); - throw std::invalid_argument("Unknown scope requested"); - } - - Scope* scope = _mappedScopes.at(node->scopeId()); - scope->push_back(node); - - return; - } - - std::pair pair(node->id(), node); - // nd4j_debug("Adding node_%i\n", node->id()); - // if model has only external variables as input - it goes to first layer, no matter what. - if (node->hasExternalInputs() && !node->hasInternalInputs()) { - node->setLayer(0); - - injectNode(node); - - // nd4j_logger("A Node_%i mapped to layer_%i; Output: %i;\n", node->id(), node->getLayer(), node->output()->at(0)); - - return; - } else { - // in some cases we're able to put stuff immediately - if (node->hasInternalInputs() && !node->hasExternalInputs() && node->input()->size() == 1) { - - bool automapAllowed = true; - for (int e = 0; e < node->input()->size(); e++) { - auto cInput = node->input()->at(e); - int cFirst = cInput.first; - if (_mapped->count(cFirst) == 0) { - automapAllowed = false; - break; - } - } - - // we only can put single input nodes, whose outputs were not mapped yet - //if (_mapped->count(node->input()->at(0).first) == 1 && (node->output()->size() == 0 || _mapped->count(node->output()->at(0).first) == 0)) { - if (automapAllowed) { - auto parent = _mapped->at(node->input()->at(0).first); - int nLayer = parent->getLayer() + 1; - - expandOnion(nLayer); - node->setLayer(nLayer); - injectNode(node); - - if (node->output()->size() > 0) { - nd4j_logger("Node_%i mapped to layer_%i; Output: %i;\n", node->id(), node->getLayer(), node->output()->at(0)); - } else { - nd4j_logger("Node_%i mapped to layer_%i; Output: none;\n", node->id(), node->getLayer()); - } - - return; - } - } /*else if (node->opType() == OpType_LOGIC && node->opNum() == 10) { - // Scopes are just being added. They won't be executed on their own anyway. - - int nLayer = _onion->size(); - - expandOnion(nLayer); - node->setLayer(nLayer); - injectNode(node); - - nd4j_logger("Node_%i mapped Scope to layer_%i; Output: %i;\n", node->id(), node->getLayer(), node->output()->at(0)); - - return; - } -*/ - // otherwise we're putting it to unmapped space for further sorting - _unmapped.insert(pair); - _unmappedMap.emplace_back(pair.first); - nd4j_debug("adding: %i\n", pair.first); - } - } - - Nd4jStatus Graph::buildGraph() { - if (_built.load()) { - prepareOutputs(); - return ND4J_STATUS_OK; - } - - typename MAP_IMPL::iterator fit; - int cnts = 0; - for ( fit = _unmapped.begin(); fit != _unmapped.end(); fit++ ) { - int tK = fit->first; - int tF = _unmappedMap.at(cnts++); - } - - int buildCnt = 0; - int buildLimit = _unmapped.size() * 2; - while (_unmapped.size() > 0) { - - int sz = _unmapped.size(); - int sf = _unmappedMap.size(); - - std::vector queue; - - // first pass for unmapped nodes, we try to build tale here - typename MAP_IMPL::iterator it; - int cntf = 0; - nd4j_debug("-----------\n",""); - for ( it = _unmapped.begin(); it != _unmapped.end(); it++ ) { - auto node = it->second; - int tK = it->first; - int tF = _unmappedMap.at(cntf++); - - //nd4j_printf("tK: %i; tF: %i\n", tK, tF); - //for (int f = 0; f < sz; f++) { - // auto node = _unmapped.at(_unmappedMap.at(f)); - - - // single-input node - if (node->input()->size() == 1) { - - if (!node->name().empty()) { - nd4j_debug("Trying SI Node_%i\n", node->id()); - } else { - nd4j_debug("Trying SI Node_%i:[%s]\n", node->id(), node->getName().c_str()); - } - - int iNode = node->input()->at(0).first; - if (iNode < 0 || _variableSpace->hasExternalVariable(iNode)) { - // this is external variable, should we check, who's the last user of this variable? - int lastLayer = _onion->size(); - expandOnion(lastLayer); - - node->setLayer(lastLayer); - this->injectNode(node); - - if (node->hasCustomOp()) { - ContextPrototype* block = nullptr; - - if (!node->hasBlockAttached()) { - block = new ContextPrototype(node->getCustomOp()->getOpDescriptor(), node->id()); - node->setContextPrototype(block); - } else - block = node->getContextPrototype(); - - - if (!block->hasVariablesFilled()) { - - for (int e = 0; e < node->input()->size(); e++) { - auto p = node->input()->at(e); - - block->pickInput(p); - } - } - } - } else if (_mapped->count(iNode) > 0) { - int maxLayer = _mapped->at(iNode)->getLayer() + 1; - - node->setLayer(maxLayer); - if (_onion->count(maxLayer) == 0) - expandOnion(maxLayer); - - this->injectNode(node); - queue.emplace_back(node->id()); - - if (node->hasCustomOp()) { - ContextPrototype* block = nullptr; - - if (!node->hasBlockAttached()) { - block = new ContextPrototype(node->getCustomOp()->getOpDescriptor(), node->id()); - node->setContextPrototype(block); - } else - block = node->getContextPrototype(); - - - if (!block->hasVariablesFilled()) { - - for (uint32_t e = 0; e < node->input()->size(); e++) { - auto p = node->input()->at(e); - - block->pickInput(p); - } - } - } - } else - continue; - - //_unmapped.erase(node->id()); - queue.emplace_back(node->id()); - } else { - // multi-input node - if (!node->name().empty()) { - nd4j_debug("Trying MI Node_%i\n", node->id()); - } else { - auto np = node->name(); - nd4j_debug("Trying MI Node_%i:[%s]\n", node->id(), node->getName().c_str()); - } - - int maxLayer = 0; - bool breaker = false; - for (unsigned int e = 0; e < node->input()->size(); e++) { - int nodeId = node->input()->at(e).first; - - // if input node wasn't mapped yet - we'll have skip it in this round - if (_mapped->count(nodeId) == 1) { - auto iNode = _mapped->at(nodeId); - - if (maxLayer < iNode->getLayer()) - maxLayer = iNode->getLayer(); - } else - if (node->opType() == OpType_LOGIC) { - // just allow it? - } else // checking if that's static variable - if (nodeId > 0 && !_variableSpace->hasExternalVariable(nodeId)) { - breaker = true; - break; - } - } - - if (breaker) - continue; - - maxLayer++; - if (_onion->count(maxLayer) == 0) - expandOnion(maxLayer); - - node->setLayer(maxLayer); - injectNode(node); - queue.emplace_back(node->id()); - - if (node->hasCustomOp()) { - ContextPrototype* block = nullptr; - - if (!node->hasBlockAttached()) { - block = new ContextPrototype(node->getCustomOp()->getOpDescriptor(), node->id()); - node->setContextPrototype(block); - } else - block = node->getContextPrototype(); - - if (!block->hasVariablesFilled()) { - - for (uint32_t e = 0; e < node->input()->size(); e++) { - auto p = node->input()->at(e); - - block->pickInput(p); - } - } - } - } - } - - for (auto &v: queue) - _unmapped.erase(v); - - // second pass is mover, we'll be moving onion layers around here - buildCnt++; - if (buildCnt > buildLimit) { - nd4j_printf("Unable to build graph, probably unmapped nodes, or something: %i nodes left\n", _unmapped.size()); - for (auto v: _unmapped) { - Node* node = v.second; - nd4j_printf("Unmapped node: [%i]\n", node->id()); - } - - throw std::runtime_error("Unable to build graph"); - } - } - - if (_unmapped.size() == 0) - _built.store(true); - - prepareOutputs(); - - return sd::Status::OK(); - } - - void Graph::tagInplaceNodes() { - // just calling, in case it wasn't built before - if (!_built.load()) - this->buildGraph(); - - bool buildRef = false; - - // checking for non-refenenced nodes - for (auto v: *_nodes) { - // skipping unmapped nodes - if (_mapped->count(v) == 0) - continue; - - Node* node = _mapped->at(v); - if (node->totalReferences() == 0) { - buildRef = true; - break; - } - } - - if (buildRef) { - for (auto v: *_nodes) { - // skipping unmapped nodes - if (_mapped->count(v) == 0) - continue; - - Node* node = _mapped->at(v); - auto inputs = node->input(); - for (auto &t: *inputs) { - if (_mapped->count(t.first) == 0) - continue; - - Node* inode = _mapped->at(t.first); - inode->addReference(node->id()); - } - } - } - - - for (auto v: *_nodes) { - // skipping unmapped nodes - if (_mapped->count(v) == 0) - continue; - - Node* node = _mapped->at(v); - - /** - * Node can be inplace if 2 requirements met: - * 1) current node allows in-place modification - * 2) source node has only 1 output - */ - - // checking for first requirement first - if (node->getCustomOp() != nullptr) - if (node->getCustomOp()->getOpDescriptor()->allowsInplace()){ - bool singleInput = true; - auto inputs = node->input(); - for (auto &t: *inputs) { - if (_mapped->count(t.first) == 0) - continue; - - Node* inode = _mapped->at(t.first); - - int output_size = inode->output()->size(); - - // checking for second requirement: inputNode must not be used as input anywhere - if (inode->totalReferences() > 1) { - singleInput = false; - break; - } - } - - node->markInplace(singleInput); - } - } - } - - void Graph::prepareOutputs() { - // if we're dumping everything out there - we'll add external variables as well - if (_configuration->_outputMode == OutputMode_VARIABLE_SPACE) { - auto ext = _variableSpace->getExternalVariables(); - nd4j_verbose("Number of external variables: %i\n", ext->size()) - for (unsigned int e = 0; e < ext->size(); e++) { - pushToOutputOnce(ext->at(e)->id()); - } - - for (auto v: *_nodes) { - if (_mapped->count(v) == 0) - continue; - - Node* node = _mapped->at(v); - - if (std::find(_output.begin(), _output.end(), node->id()) == _output.end()) - _output.emplace_back(node->id()); - } - - } else if (_configuration->_outputMode == OutputMode_IMPLICIT) { - // we're adding final nodes of the graph. those, not used as input anywhere - nd4j_debug("Paring nodes... \n", ""); - - if (Environment::getInstance()->isDebugAndVerbose()) { - // nd4j_printv("current _output", _output); - } - //_output.clear(); - - for (auto v: *_nodes) { - // we should check for scopes, and other possible non-mapped stuff - if (_mapped->count(v) == 0) - continue; - - Node* node = _mapped->at(v); - if (!node->name().empty()) { - nd4j_debug("Node %i; Name: [%s]\n", v, node->name().c_str()); - } else { - nd4j_debug("Node %i\n", v); - } - - // updating outputs now - for (int e = 0; e < node->input()->size(); e++) { - auto inP = node->input()->at(e); - - // input can be variable, or node. we only care about nodes - if (_mapped->count(inP.first) > 0) { - _mapped->at(inP.first)->pickOutputOnce(v); - } - } - } - // at this point all nodes have filled inputs/outputs, so we know nodes that do not have any connected outputs - - for (auto v: *_nodes) { - // we should check for scopes, and other possible non-mapped stuff - if (_mapped->count(v) == 0) - continue; - - Node* node = _mapped->at(v); - - if (!node->hasInternalOutputs()) { - if (!node->name().empty()) { - nd4j_debug("Output node found: [%i:<%s>]\n", v, node->name().c_str()); - } else { - nd4j_debug("Output node found: [%i]\n", v); - } - - // FIXME: we don't really need search here. - - if (std::find(_output.begin(), _output.end(), node->id()) == _output.end()) - _output.emplace_back(node->id()); - } else if (Environment::getInstance()->isDebugAndVerbose()) { - nd4j_debug("Node [%i:<%s>] has %i outputs announced:\n", v, node->name().c_str(), node->output()->size()); - printf("{"); - for (auto s : *node->output()) { - printf("[%i:%i], ", s.first, s.second); - } - printf("}\n"); - fflush(stdout); - } - } - } + throw std::runtime_error("not implemented yet"); } Graph::Graph(const FlatGraph *flatGraph, VariableSpace *variableSpace, const GraphMemoryManager &memoryManager) : _memoryMaager(memoryManager) { - this->_onion = new MAP_IMPL *>(); - this->_mapped = new MAP_IMPL (); - this->_nodes = new std::vector(); this->_variableSpace = variableSpace == nullptr ? new VariableSpace() : variableSpace; bool trusted = flatGraph != nullptr; - // add 0 layer - this->expandOnion(0); - // if there was no exec configuration in flatgraph - create default one if (flatGraph != nullptr && flatGraph->configuration() != nullptr) { - _configuration = new ExecutorConfiguration(flatGraph->configuration()); + _configuration = ExecutorConfiguration(flatGraph->configuration()); } else - _configuration = new ExecutorConfiguration(); + _configuration = ExecutorConfiguration(); // if memory reqs were set - initialize workspace - if (_configuration->_footprintForward > 0) { + if (_configuration._footprintForward > 0) { sd::memory::Workspace *workspace = this->_variableSpace->launchContext()->getWorkspace(); - workspace->expandBy(_configuration->_footprintForward); + workspace->expandBy(_configuration._footprintForward); } // parsing variables here if (flatGraph != nullptr && flatGraph->variables() != nullptr && flatGraph->variables()->size() > 0) { for (unsigned int e = 0; e < flatGraph->variables()->size(); e++) { auto flatVar = flatGraph->variables()->Get(e); - - auto var = new Variable(flatVar); std::pair pair(flatVar->id()->first(), flatVar->id()->second()); - _variableSpace->putVariable(pair, var); - // if that's VariableSpace mode - we're pushing it to _output - if (_configuration->_outputMode == OutputMode_VARIABLE_SPACE) - pushToOutputOnce(var->id()); + auto var = new Variable(flatVar); + if (flatVar->name() != nullptr) { + var->setName(flatVar->name()->str()); + _symbolicLookupTable[var->name()] = pair.first; + } + _variableSpace->putVariable(pair, var); } } // at this point we expect all variables are already registered // we're saving outputs only if explicit mode is set - if (_configuration->_outputMode == OutputMode_EXPLICIT || _configuration->_outputMode == OutputMode_EXPLICIT_AND_IMPLICIT) { + if (_configuration._outputMode == OutputMode_EXPLICIT || _configuration._outputMode == OutputMode_EXPLICIT_AND_IMPLICIT) { if (flatGraph != nullptr && flatGraph->outputs() != nullptr) { for (unsigned int e = 0; e < flatGraph->outputs()->size(); e++) { auto out = flatGraph->outputs()->Get(e); @@ -967,9 +148,6 @@ namespace sd { nd4j_verbose("Non-existent variable requested: %i\n", out); throw std::runtime_error("Non-existent variable requested"); } - - // TODO: fix this .first - pushToOutputOnce(vp.first); } } } @@ -984,166 +162,56 @@ namespace sd { } nd4j_debug("Node name: [%s]\n", node->name()->c_str()); - auto nnode = new Node(node); - /* - expandOnion(e); - nnode->setLayer(e); - this->addNode(nnode); - injectNode(nnode); - _unmapped.erase(nnode->id()); - */ + Node nnode(node); // just filling list of nodes - _unmapped[nnode->id()] = nnode; - } - - - this->toposortNodes(); - - _built = true; - } - - /** - * we allow in-place execution optimizations ONLY if 2 requirements met: - * 1) this is FeedForward pass ONLY - * 2) OPTIMIZED mode is set, so no intermediate results are going to be used - */ - if (_configuration->_direction == Direction_FORWARD_ONLY && _configuration->_outputMode == OutputMode_OPTIMIZED) - this->tagInplaceNodes(); - } - - - void Graph::toposortNodes() { - int attempts = 0; + _unmapped[nnode.id()] = nnode; - // in worst possible case number of rolls equals to the number of nodes - int limit = _unmapped.size() + 1; - - std::vector tbd; - while (!_unmapped.empty() && attempts < limit) { - for (auto np:_unmapped) { - auto id = np.first; - tbd.emplace_back(id); + if (!nnode.name().empty()) + _symbolicLookupTable[nnode.name()] = nnode.id(); } - - // rolling through unmapped nodes - for (auto np:tbd) { - auto id = np; - auto node = _unmapped[id]; - - // this variables contains the layer of maximal dependency - int maxDependencyLayer = -1; - - // simple flag - bool canMap = true; - - // looping through inputs, to check if they were already mapped - auto inputs = node->input(); - for (auto in:*inputs) { - // only 2 options here, in either refers to the node, or to the variable - // however, node can be already mapped, or not yet. this makes it 3 options :) - - if (hasNode(in.first)) { // node was mapped - auto dependency = nodeById(in.first); - auto layer = dependency->getLayer(); - if (layer > maxDependencyLayer) - maxDependencyLayer = layer; - - } else if (_unmapped.count(in.first) > 0) { // node is unmapped yet - // can't map this node yet, due to non-resolved dependencies - canMap = false; - } else if (_variableSpace->hasVariable(in.first)){ // that's probably variable. if not - we'll throw exception later - // do nothing, maxDepLayer is -1 here, because it's a variable input - } else { - throw graph::unresolved_input_exception::build("Unknown input specified", id, in); - } - } - - if (canMap) { - auto layer = maxDependencyLayer + 1; - this->expandOnion(layer); - node->setLayer(layer); - this->addNode(node); - this->injectNode(node); - _unmapped.erase(id); - } - } - - // if something was successfully mapped - remove it from unmapped entries - if (!tbd.empty()) - tbd.clear(); - - attempts++; } - - if (!_unmapped.empty()) - throw graph_exception("Graph wasn't toposorted", 0); - - _built = true; - } - - -/** - * This method returns number of root nodes in this graph - * @return - */ - int Graph::rootNodes() { - return this->_onion->at(0)->size(); } /** * This method returns total number of nodes in this graph * @return */ - int Graph::totalNodes() { - if (_built.load() != true) - buildGraph(); - - return _mapped->size(); + int Graph::size() const { + return _unmapped.size(); } Nd4jStatus Graph::validate() { - if (!_built) { - _mutexPreprocessing.lock(); - if (!_built) { - this->buildGraph(); - } - _mutexPreprocessing.unlock(); - } - - if (_built.load() != true) - return ND4J_STATUS_BAD_GRAPH; - - return ND4J_STATUS_OK; + throw std::runtime_error("Graph::validate - method not implemented"); }; - void Graph::printOutNode(Node* node) { - nd4j_printf("%i. ", node->id()); - switch(node->opType()) { + void Graph::printOutNode(const Node &node) const { + nd4j_printf("%i. ", node.id()); + switch(node.opType()) { case OpType_CUSTOM: { - printf("%s; ", node->getCustomOp()->getOpName().c_str()); + printf("%s; ", node.customOp()->getOpName().c_str()); } break; case OpType_LOGIC: { - printf("%s; ", EnumUtils::_LogicOpToString(node->opNum())); + printf("%s; ", EnumUtils::_LogicOpToString(node.opNum())); } break; default: { - printf("%s:{%i}; ", EnumUtils::_OpTypeToString(node->opType()), (int) node->opNum()); + printf("%s:{%i}; ", EnumUtils::_OpTypeToString(node.opType()), (int) node.opNum()); } } nd4j_printf("Inputs: [", ""); //auto block = node->getBlock(); - for (int e = 0; e < node->input()->size(); e++) { + for (int e = 0; e < node.input().size(); e++) { - auto in = node->input()->at(e); + auto in = node.input()[e]; printf("{%i:%i}", in.first, in.second); - if (e < node->input()->size() - 1) + if (e < node.input().size() - 1) nd4j_printf(", ", ""); } - if (node->opType() == OpType_CUSTOM) { - auto ctx = node->protoContext(); + if (node.opType() == OpType_CUSTOM) { + auto ctx = node.protoContext(); if (ctx->getIArguments()->size() > 0) { printf("]; iArgs: ["); @@ -1163,8 +231,6 @@ namespace sd { } void Graph::printOut() { - buildGraph(); - // print variables first if (_variableSpace->totalEntries() > 0) { nd4j_printf("\nPrinting out Variables...\n", ""); @@ -1188,39 +254,9 @@ namespace sd { } } - if (_onion->size() > 0) - nd4j_printf("\nPrinting out Graph...\n", ""); - - int opCnt = 0; - for (int l = 0; l < _onion->size(); l++) { - int layerSize = _onion->count(l) == 1 ? _onion->at(l)->size() : 0; - - for (int n = 0; n < layerSize; n++) { - Node* node = _onion->at(l)->at(n); - - // we're skipping Scopes here - if (node->opType() == OpType_LOGIC && node->opNum() == logic::Scope) - continue; - - printOutNode(node); - } - } - - - if (_scopes.size() > 0) - nd4j_printf("\nPrinting out Scopes...\n",""); - - for (int s = 0; s < _scopes.size(); s++) { - Scope* scope = _scopes.at(s); - nd4j_printf("Scope %i:<%s>:\n", scope->id(), scope->name()->c_str()); - - for (int n = 0; n < scope->nodes()->size(); n++) { - Node* node = scope->nodes()->at(n); - printOutNode(node); - } - } - fflush(stdout); + + throw std::runtime_error("Graph::printOut - not implemented yet"); } Nd4jStatus Graph::validateNode(Node *node) { @@ -1228,122 +264,12 @@ namespace sd { return ND4J_STATUS_OK; } - std::vector Graph::getOperations() { - buildGraph(); - // nd4j_printf("\nRetrieving ops from the Graph and collect them...\n", ""); - std::vector res; - - int opCnt = 0; - for (int l = 0; l < _onion->size(); l++) { - int layerSize = _onion->count(l) == 1 ? _onion->at(l)->size() : 0; - - for (int n = 0; n < layerSize; n++) { - Node* node = _onion->at(l)->at(n); - if (node->name().empty()) - continue; - - sd::ops::OpDescriptor* pOpDescriptor = nullptr; - std::string opNameStr; //node->name(); - int numInputs = 0; - int numOutputs = 0; - - switch(node->opType()) { - case OpType_CUSTOM: { - pOpDescriptor = node->getCustomOp()->getOpDescriptor(); - } - break; - case OpType_LOGIC: { - opNameStr = std::string(EnumUtils::_LogicOpToString(node->opNum())); - } - break; - default: { - opNameStr = std::string(EnumUtils::_OpTypeToString(node->opType()))+"{" + ops::OpRegistrator::getInstance()->local_to_string((int) node->opNum()) + "}"; - } - } - - if (node->input()) - numInputs = node->input()->size(); - - if (node->output()) - numOutputs = node->output()->size(); - bool inplace = node->isInplace(); - - //OpDescriptor opDescriptor(numInputs, numOutputs, opNameStr, inplace); - - // we're skipping Scopes here - if (node->opType() == OpType_LOGIC && node->opNum() == logic::Scope) - continue; - if (pOpDescriptor) - res.emplace_back(*pOpDescriptor); - else - res.emplace_back(sd::ops::OpDescriptor(numInputs, numOutputs, opNameStr, inplace)); - } - } - - - // nd4j_printf("\nCollecting out Scopes...\n",""); - for (int s = 0; s < _scopes.size(); s++) { - Scope* scope = _scopes.at(s); - // nd4j_printf("Scope %i:<%s>:\n", scope->id(), scope->name()->c_str()); - - for (int n = 0; n < scope->nodes()->size(); n++) { - Node* node = scope->nodes()->at(n); - //printOutNode(node); - if (node->name().empty()) - continue; - - std::string opNameStr; //node->name(); - sd::ops::OpDescriptor* pOpDescriptor = nullptr; - int numInputs = 0; - int numOutputs = 0; - - switch(node->opType()) { - case OpType_CUSTOM: { - pOpDescriptor = node->getCustomOp()->getOpDescriptor(); - } - break; - case OpType_LOGIC: { - opNameStr = std::string(EnumUtils::_LogicOpToString(node->opNum())); - } - break; - default: { - opNameStr = std::string(EnumUtils::_OpTypeToString(node->opType()))+"{" + ops::OpRegistrator::getInstance()->local_to_string((int) node->opNum()) + "}"; - } - } - - if (node->input()) - numInputs = node->input()->size(); - - if (node->output()) - numOutputs = node->output()->size(); - bool inplace = node->isInplace(); - - if (pOpDescriptor != nullptr) - res.emplace_back(*pOpDescriptor); - else - res.emplace_back(sd::ops::OpDescriptor(numInputs, numOutputs, opNameStr, inplace)); - } - } - - return res; - } - - Scope *Graph::scopeById(int id) { - if (_mappedScopes.count(id) == 0) { - nd4j_printf("Requested Scope [%i] doesn't exist\n", id); - throw std::runtime_error("Non-existent Scope was requested"); - } - - return _mappedScopes.at(id); - } - void Graph::forgetVariableSpace() { _variableSpace = nullptr; } - void Graph::replaceState(VariableSpace *state, ExecutorConfiguration *configuration) { + void Graph::replaceState(VariableSpace *state, const ExecutorConfiguration &configuration) { delete _variableSpace; - delete _configuration; _variableSpace = state; _configuration = configuration; @@ -1352,163 +278,21 @@ namespace sd { Graph* Graph::cloneWithProxy() const { auto clone = new Graph(); - clone->replaceState(new VariableProxy(this->_variableSpace), this->_configuration->clone()); - - // transfer nodes - for (int e = 0; e < _nodes->size(); e++) - clone->_nodes->emplace_back(_nodes->at(e)); + clone->replaceState(new VariableProxy(this->_variableSpace), this->_configuration.clone()); - // transfer outputs - for (auto v: _output) - clone->_output.emplace_back(v); - - // transfer autos - for (auto v: _autos) - clone->_autos.emplace_back(v); - - // transfer scopes - for (auto &v: _mappedScopes) { - auto scp = v.second->clone(); - clone->_mappedScopes[v.first] = scp; - clone->_scopes.emplace_back(scp); - } - - // transfer mapped nodes - for (auto &v: *_onion) { - auto vec = clone->_onion->count(v.first) > 0 ? clone->_onion->at(v.first) : new std::vector(); - - - // cloning actual nodes - auto ovec = (*_onion)[v.first]; - for (auto x: *(ovec)) { - auto n = x->clone(); - vec->emplace_back(n); - clone->_handles.emplace_back(n); - (*clone->_mapped)[n->id()] = n; - } - - if (clone->_onion->count(v.first) < 1) - (*clone->_onion)[v.first] = vec; - } - - // transfer mapped nodes - for (auto &v: _unmapped) - clone->_unmapped[v.first] = v.second->clone(); - - clone->_built.store(_built.load()); - - return clone; + throw std::runtime_error("Graph::cloneWithProxy - not implemented yet"); } Graph* Graph::clone() const { auto clone = new Graph(); - clone->replaceState(this->_variableSpace->clone(), this->_configuration->clone()); - - // transfer nodes - for (int e = 0; e < _nodes->size(); e++) - clone->_nodes->emplace_back(_nodes->at(e)); - - // transfer outputs - for (auto v: _output) - clone->_output.emplace_back(v); - - // transfer autos - for (auto v: _autos) - clone->_autos.emplace_back(v); - - // transfer scopes - for (auto &v: _mappedScopes) { - auto scp = v.second->clone(); - clone->_mappedScopes[v.first] = scp; - clone->_scopes.emplace_back(scp); - } - - // transfer mapped nodes - for (auto &v: *_onion) { - auto vec = clone->_onion->count(v.first) > 0 ? clone->_onion->at(v.first) : new std::vector(); - - - // cloning actual nodes - auto ovec = (*_onion)[v.first]; - for (auto x: *(ovec)) { - auto n = x->clone(); - vec->emplace_back(n); - clone->_handles.emplace_back(n); - (*clone->_mapped)[n->id()] = n; - } - - if (clone->_onion->count(v.first) < 1) - (*clone->_onion)[v.first] = vec; - } - - // transfer mapped nodes - for (auto &v: _unmapped) - clone->_unmapped[v.first] = v.second->clone(); - - clone->_built.store(_built.load()); - - return clone; - } - - bool Graph::hasNode(int id) { - return _mapped->count(id) > 0; - } + clone->replaceState(this->_variableSpace->clone(), this->_configuration.clone()); - Node* Graph::nodeById(int id) { - return _mapped->at(id); + throw std::runtime_error("Graph::clone - not implemented yet"); } - bool Graph::hasScope(int id) { - return _mappedScopes.count(id) > 0; - } - - Nd4jLong Graph::hashCode() { - if (!_built.load()) - this->buildGraph(); - - Nd4jLong hash = 0L; - std::string localStamp; - /** - * Plan is: - * 1) get shapes of existing variables - * 2) get hash codes of individual ops - * 3) optionally: get node names, if they are defined - * 4) use long hash on that - */ - int cnt = 0; - /* - if (_variableSpace != nullptr) { - // loop over existing variables - for (auto v: *(_variableSpace->handles())) { - if (v->hasNDArray()) { - NDArray *arr = v->getNDArray(); - if (arr != nullptr && arr->nonNull()) { - auto shape = arr->getShapeAsVector(); - auto string = ShapeUtils::shapeAsString(shape); - localStamp += string; - } - } - } - } - */ - - // loop over nodes in graph - for (auto &v: *_mapped) { - Node *node = v.second; - - // optional part: node names - if (!node->name().empty()) { - localStamp += node->name(); - } - } - - - hash = ops::HashHelper::getInstance()->getLongHash(localStamp); - - nd4j_debug("Graph hash: %lld\n", hash); - - return hash; + Nd4jLong Graph::hashCode() const { + throw std::runtime_error("Graph::hashCode - not implemented yet"); } @@ -1776,7 +560,7 @@ namespace sd { } bool Graph::topolSearch(const int startNode, const std::set& nodeBranches, const std::unordered_map& positions, OpSequence& opSeq) const { - +/* if (nodeBranches.empty() || _handles.empty()) return false; @@ -1793,12 +577,16 @@ namespace sd { } } - return true; + return false; + */ + + throw std::runtime_error("Graph::topolSearch - not implemented yet"); } OptimizedGraph Graph::optimizedGraph() const { OptimizedGraph optGraf(const_cast(_memoryMaager)); + /* OpSequence opSeq; std::set nodesMap, startNodes; std::unordered_map iDpositions; @@ -1826,6 +614,7 @@ namespace sd { topolSearch(start, nodesMap, iDpositions, opSeq); } optGraf.append(opSeq); + */ return optGraf; } diff --git a/libnd4j/include/graph/impl/GraphHolder.cpp b/libnd4j/include/graph/impl/GraphHolder.cpp index 3eaf9c3b50be..b79ba2b5ff64 100644 --- a/libnd4j/include/graph/impl/GraphHolder.cpp +++ b/libnd4j/include/graph/impl/GraphHolder.cpp @@ -19,7 +19,6 @@ // #include -#include #include #include @@ -115,7 +114,7 @@ namespace sd { flatbuffers::Offset GraphHolder::execute(Nd4jLong graphId, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request) { if (!hasGraph(graphId)) throw unknown_graph_exception(graphId); - +/* lockRead(graphId); auto graph = cloneGraph(graphId); @@ -125,6 +124,8 @@ namespace sd { unlockRead(graphId); return res; + */ + throw std::runtime_error("GraphHolder::execute - not implemented yet"); } GraphHolder* GraphHolder::_INSTANCE = 0; diff --git a/libnd4j/include/graph/impl/GraphState.cpp b/libnd4j/include/graph/impl/GraphState.cpp deleted file mode 100644 index a8b25603a512..000000000000 --- a/libnd4j/include/graph/impl/GraphState.cpp +++ /dev/null @@ -1,167 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 23.01.18. -// - -#include -#include - - -namespace sd { -namespace graph { - GraphState::GraphState(Nd4jLong id) { - _id = id; - _graph = new Graph(nullptr, &_variableSpace); - }; - - GraphState::~GraphState() { - // stupid. we should get rid of pointers here i think. no need to bother - for (const auto &v: _scopes) { - v.second->forgetNodes(); - delete v.second; - } - - // we must remove reference to VariableSpace - _graph->forgetVariableSpace(); - - delete _graph; - }; - - Nd4jStatus GraphState::registerScope(int scopeId) { - auto scope = new Scope(scopeId); - _scopes[scopeId] = scope; - - auto scopeWrapper = new Node(OpType_LOGIC, 10, scopeId); - _graph->addNode(scopeWrapper); - - return Status::OK(); - }; - - Nd4jStatus GraphState::forgetScope(int scopeId) { - if (_scopes.count(scopeId) > 0) - _scopes.erase(scopeId); - else - return Status::THROW("Non-existent scope requested"); - - return Status::OK(); - }; - -#ifndef __JAVACPP_HACK__ - Nd4jStatus GraphState::attachOpToScope(int scopeId, int nodeId, ops::DeclarableOp *op, ArgumentsList inputs) { - if (_scopes.count(scopeId) == 0) - return Status::THROW("GraphState: can't attach op to unknown scope"); - - auto scope = _scopes[scopeId]; - - // creating new Node -// auto node = new Node(OpType_CUSTOM, 0, nodeId); - auto node = new Node(op, nodeId); -// node->setCustomOp(op); - node->setScopeInfo(scopeId); - - // mapping inputs here - for (int e = 0; e < inputs.size(); e++) { - auto p = inputs.at(e); - - // each expected input is Variable in current VariableSpace - // it should have it's numerical and symbolic ID - - if (!_variableSpace.hasVariable(p.first(), p.second())) { - auto var = new Variable(); - var->setId(p.first(), p.second()); - _variableSpace.putVariable(p.first(), p.second(), var); - } - - // nd4j_printf("Node_%i: adding input [%i:%i]\n", node->id(), p.first(), p.second()); - node->pickInput(p.first(), p.second()); - } - - scope->push_back(node); - - _graph->addNode(node); - - return Status::OK(); - }; - - Graph* GraphState::graph() { - return _graph; - } - - Scope* GraphState::getScope(int scopeId) { - if (_scopes.count(scopeId) == 0) { - nd4j_printf("GraphState: Unknown scope requested %i\n", scopeId); - return nullptr; - } - - return _scopes[scopeId]; - } -#endif - Nd4jStatus GraphState::defineReturn(int scopeId, int nodeId, ArgumentsList args) { - if (_scopes.count(scopeId) == 0) - return Status::THROW("GraphState: can't attach op to unknown scope"); - - auto scope = _scopes[scopeId]; - - // creating new Node for RETURN - auto node = new Node(OpType_LOGIC, 40, nodeId); - node->setScopeInfo(scopeId); - - // mapping inputs here - for (int e = 0; e < args.size(); e++) { - auto p = args.at(e); - - // each expected input is Variable in current VariableSpace - // it should have it's numerical and symbolic ID - - if (!_variableSpace.hasVariable(p.first(), p.second())) { - auto var = new Variable(); - var->setId(p.first(), p.second()); - _variableSpace.putVariable(p.first(), p.second(), var); - } - - // nd4j_printf("Node_%i: adding input [%i:%i]\n", node->id(), p.first(), p.second()); - node->pickInput(p.first(), p.second()); - node->pickOutput(0, e); - } - - scope->push_back(node); - - _graph->addNode(node); - - - return Status::OK(); - } - - bool GraphState::hasScope(int scopeId) { - return _scopes.count(scopeId) > 0; - } - - VariableSpace* GraphState::variableSpace() { - return &_variableSpace; - }; - - Nd4jLong GraphState::id() { - return _id; - } - - Nd4jStatus GraphState::attachOpToScope(int scopeId, Nd4jLong opNum, int type, ArgumentsList inputs) { - // we should use OpRegistrator here, to create Node and push it to specific scope - return Status::OK(); - } -} -} \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index bcaf42ba99ad..04ac581535b5 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -65,7 +65,7 @@ namespace sd { // FIXME: get rid of this!!! _scalar = NDArrayFactory::create(0); - auto block = new ContextPrototype(this->getCustomOp()->getOpDescriptor(), this->id(), false); + auto block = new ContextPrototype(this->customOp()->getOpDescriptor(), this->id(), false); for (auto v: iArgs) block->getIArguments()->emplace_back(v); @@ -104,7 +104,7 @@ namespace sd { // FIXME: get rid of this!!! _scalar = NDArrayFactory::create(0); - auto block = new ContextPrototype(this->getCustomOp()->getOpDescriptor(), this->id(), false); + auto block = new ContextPrototype(this->customOp()->getOpDescriptor(), this->id(), false); for (auto v: iArgs) block->getIArguments()->emplace_back(v); @@ -135,14 +135,10 @@ namespace sd { _graph = graph; } - Graph* Node::getGraph() { + Graph* Node::graph() const { return _graph; } - bool Node::hasGraphEmbedded() { - return _graph != nullptr; - } - void Node::markInplace(bool reallyInplace) { _isInplace = reallyInplace; if (_protoContext != nullptr) { @@ -197,10 +193,10 @@ namespace sd { ContextPrototype * Node::getContextPrototype() { if (_protoContext == nullptr) - _protoContext = new ContextPrototype(this->getCustomOp() != nullptr ? this->getCustomOp()->getOpDescriptor() : nullptr, this->id()); + _protoContext = new ContextPrototype(this->customOp() != nullptr ? this->customOp()->getOpDescriptor() : nullptr, this->id()); if (_protoContext->inputs()->empty()) { - for (int e = 0; e < this->input()->size(); e++) { - _protoContext->inputs()->emplace_back(this->input()->at(e)); + for (int e = 0; e < this->input().size(); e++) { + _protoContext->inputs()->emplace_back(this->input().at(e)); } } return _protoContext; @@ -217,7 +213,7 @@ namespace sd { _id = id; } - sd::ops::DeclarableOp* Node::getCustomOp() { + sd::ops::DeclarableOp* Node::customOp() const { return _customOp; } @@ -229,7 +225,7 @@ namespace sd { _isInplace = true; } - bool Node::hasCustomOp() { + bool Node::hasCustomOp() const { return _customOp != nullptr; } @@ -355,24 +351,24 @@ namespace sd { _referencedBy.emplace_back(nodeId); } - OpType Node::opType() { + OpType Node::opType() const { return _opType; } - int Node::id() { + int Node::id() const { return _id; } - Nd4jLong Node::opNum() { + Nd4jLong Node::opNum() const { return _opNum; } - std::vector> *Node::input() { - return &_input; + const std::vector>& Node::input() const { + return _input; } - std::vector> *Node::output() { - return &_output; + const std::vector>& Node::output() const { + return _output; } bool Node::isScoped() { @@ -424,7 +420,7 @@ namespace sd { for (auto i: inputs) pickInput(i); - auto block = new ContextPrototype(this->getCustomOp()->getOpDescriptor(), this->id(), false); + auto block = new ContextPrototype(this->customOp()->getOpDescriptor(), this->id(), false); for (auto v: iArgs) block->getIArguments()->emplace_back(v); @@ -457,7 +453,7 @@ namespace sd { for (auto i: inputs) pickInput(i); - auto block = new ContextPrototype(this->getCustomOp()->getOpDescriptor(), this->id(), false); + auto block = new ContextPrototype(this->customOp()->getOpDescriptor(), this->id(), false); for (auto v: iArgs) block->getIArguments()->emplace_back(v); @@ -500,7 +496,7 @@ namespace sd { } } - auto block = new ContextPrototype(this->getCustomOp()->getOpDescriptor(), this->id(), false); + auto block = new ContextPrototype(this->customOp()->getOpDescriptor(), this->id(), false); for (auto v: dimensions) block->getAxis()->emplace_back(v); @@ -592,10 +588,10 @@ namespace sd { this->setContextPrototype(block); this->setCustomOp(Node::buildOpByType(opType, (int) input.size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), opNum, &_scalar)); - block->setOpDescriptor(this->getCustomOp()->getOpDescriptor()); + block->setOpDescriptor(this->customOp()->getOpDescriptor()); } else if (opType == OpType_CUSTOM) { - if (this->getCustomOp()) { - auto block = new ContextPrototype(this->getCustomOp()->getOpDescriptor(), this->id(), false); + if (this->customOp()) { + auto block = new ContextPrototype(this->customOp()->getOpDescriptor(), this->id(), false); for (auto v: dimensions) block->getAxis()->emplace_back(v); @@ -751,14 +747,14 @@ namespace sd { this->setContextPrototype(block); this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar)); - block->setOpDescriptor(this->getCustomOp()->getOpDescriptor()); + block->setOpDescriptor(this->customOp()->getOpDescriptor()); } else if (node->inputPaired() != nullptr && node->inputPaired()->size() > 0) { this->_isDeductable = true; auto block = new ContextPrototype(nullptr, this->id(), false); - for (int e = 0; e < this->input()->size(); e++) { - block->inputs()->emplace_back(this->input()->at(e)); + for (int e = 0; e < this->input().size(); e++) { + block->inputs()->emplace_back(this->input().at(e)); } // there's no other IArgs in legacy options, actually @@ -789,7 +785,7 @@ namespace sd { this->setContextPrototype(block); this->setCustomOp(Node::buildOpByType(_opType, (int) node->inputPaired()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar)); - block->setOpDescriptor(this->getCustomOp()->getOpDescriptor()); + block->setOpDescriptor(this->customOp()->getOpDescriptor()); } } else if (this->_opType == OpType_CUSTOM) { auto op = sd::ops::OpRegistrator::getInstance()->getOperation(this->opNum()); @@ -800,8 +796,8 @@ namespace sd { auto block = new ContextPrototype(nullptr, this->id()); - for (int e = 0; e < this->input()->size(); e++) { - block->inputs()->emplace_back(this->input()->at(e)); + for (int e = 0; e < this->input().size(); e++) { + block->inputs()->emplace_back(this->input().at(e)); } if (node->extraInteger() != nullptr) @@ -831,7 +827,7 @@ namespace sd { this->setContextPrototype(block); this->setCustomOp(op); - block->setOpDescriptor(this->getCustomOp()->getOpDescriptor()); + block->setOpDescriptor(this->customOp()->getOpDescriptor()); } } else { // empty dynamic node, tests probably @@ -842,7 +838,7 @@ namespace sd { return _dataType; } - ContextPrototype* Node::protoContext() { + ContextPrototype* Node::protoContext() const { return _protoContext; } diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index 5f341b2da3c7..fd89b2d602e8 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -92,6 +92,10 @@ namespace sd { return _placeholder; } + const std::string& sd::graph::Variable::name() const { + return _name; + } + const std::string& sd::graph::Variable::getName() const { return _name; } diff --git a/libnd4j/include/graph/logic/impl/LogicConditional.cpp b/libnd4j/include/graph/logic/impl/LogicConditional.cpp index 8b6af83a98fb..c4a3747c0e19 100644 --- a/libnd4j/include/graph/logic/impl/LogicConditional.cpp +++ b/libnd4j/include/graph/logic/impl/LogicConditional.cpp @@ -19,7 +19,6 @@ // #include -#include #include #include @@ -27,6 +26,8 @@ namespace sd { namespace graph { Nd4jStatus LogicConditional::processNode(Graph *graph, Node *node) { + throw std::runtime_error("LogicConditional::processNode - not implemented yet"); + /* auto __variableSpace = graph->getVariableSpace(); auto size = node->input()->size(); @@ -131,6 +132,7 @@ namespace sd { } return sd::Status::OK(); + */ } } } \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp index c9e7b0c17e6f..b8ca2999e71f 100644 --- a/libnd4j/include/graph/logic/impl/LogicEnter.cpp +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -25,6 +25,8 @@ namespace sd { namespace graph { Nd4jStatus LogicEnter::processNode(Graph *graph, Node *node) { + throw std::runtime_error("LogicEnter::processNode - not implemented yet"); + /* // this op replicates input variable into the frame. basically happens once for single loop. // sure, if there's inner loop within outer loop, it'll be called once for outer loop and multiple times for inner loop @@ -69,6 +71,7 @@ namespace sd { } return sd::Status::OK(); + */ } } } \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicMerge.cpp b/libnd4j/include/graph/logic/impl/LogicMerge.cpp index 7118cd606418..e00fca1f0d21 100644 --- a/libnd4j/include/graph/logic/impl/LogicMerge.cpp +++ b/libnd4j/include/graph/logic/impl/LogicMerge.cpp @@ -24,13 +24,15 @@ namespace sd { namespace graph { Nd4jStatus LogicMerge::processNode(Graph *graph, Node *node) { + throw std::runtime_error("LogicMerge::processNode - not implemented yet"); + /* // at merge node only one of inputs exist if that's just switch and other node isn't LogicNextItration auto __variableSpace = graph->getVariableSpace(); auto __flowPath = __variableSpace->flowPath(); // merge MUST have 2 inputs - auto inputAddr0 = node->input()->at(0); - auto inputAddr1 = node->input()->at(1); + auto inputAddr0 = node->input().at(0); + auto inputAddr1 = node->input().at(1); bool isWhile = false; @@ -101,8 +103,8 @@ namespace sd { } else { // basically, first non-null variable is our target - for (int e = 0; e < node->input()->size(); e++) { - auto inputAddr = node->input()->at(e); + for (int e = 0; e < node->input().size(); e++) { + auto inputAddr = node->input().at(e); if (__variableSpace->hasVariable(inputAddr)) { auto var = __variableSpace->getVariable(inputAddr); @@ -129,6 +131,7 @@ namespace sd { } return Status::OK(); + */ } } } diff --git a/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp index 9d3ae15cecff..9d26f0c950e4 100644 --- a/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp +++ b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp @@ -24,6 +24,8 @@ namespace sd { namespace graph { Nd4jStatus LogicNextIeration::processNode(Graph *graph, Node *node) { + throw std::runtime_error("LogicNextIeration::processNode - not implemented yet"); + /* auto __variableSpace = graph->getVariableSpace(); auto __flowPath = __variableSpace->flowPath(); @@ -45,6 +47,7 @@ namespace sd { lvar->markReadOnly(true); return ND4J_STATUS_OK; + */ } } } \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicReturn.cpp b/libnd4j/include/graph/logic/impl/LogicReturn.cpp index 2e8aacb812c7..7a1a4f997091 100644 --- a/libnd4j/include/graph/logic/impl/LogicReturn.cpp +++ b/libnd4j/include/graph/logic/impl/LogicReturn.cpp @@ -25,6 +25,8 @@ namespace sd { namespace graph { Nd4jStatus LogicReturn::processNode(Graph *graph, Node *node) { + throw std::runtime_error("LogicReturn::processNode - not implemented yet"); + /* auto __variableSpace = graph->getVariableSpace(); for (int e = 0; e < node->input()->size(); e++) { @@ -50,6 +52,7 @@ namespace sd { } return sd::Status::OK(); + */ } } } diff --git a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp index 85fd92191e46..cb522b4d0d23 100644 --- a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp +++ b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp @@ -20,12 +20,13 @@ #include #include -#include #include namespace sd { namespace graph { Nd4jStatus LogicSwitch::processNode(Graph* graph, Node* node) { + throw std::runtime_error("LogicSwitch::processNode - not implemented yet"); + /* auto __variableSpace = graph->getVariableSpace(); auto __flowPath = __variableSpace->flowPath(); @@ -103,6 +104,7 @@ namespace sd { } return sd::Status::OK(); + */ }; } } diff --git a/libnd4j/include/graph/logic/impl/LogicWhile.cpp b/libnd4j/include/graph/logic/impl/LogicWhile.cpp index b2dd00589484..cf15cc88877b 100644 --- a/libnd4j/include/graph/logic/impl/LogicWhile.cpp +++ b/libnd4j/include/graph/logic/impl/LogicWhile.cpp @@ -20,7 +20,6 @@ #include #include -#include #include #include @@ -28,6 +27,8 @@ namespace sd { namespace graph { Nd4jStatus LogicWhile::processNode(Graph *graph, Node *node) { + throw std::runtime_error("LogicWhile::processNode - not implemented yet"); + /* auto __variableSpace = graph->getVariableSpace(); nd4j_debug("Starting on WHILE loop: [%i]\n", node->id()); @@ -139,6 +140,7 @@ namespace sd { } return sd::Status::OK(); + */ } } } diff --git a/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp b/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp index 268b396e03b2..8f93306a72bf 100644 --- a/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp +++ b/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp @@ -19,7 +19,7 @@ // #include -#include + namespace sd { namespace graph { @@ -38,7 +38,7 @@ namespace sd { auto _vs = varSpace->clone(); //_vs->workspace()->expandTo(100000); _vs->setFlowPath(&fp); - GraphExecutioner::execute(graph, _vs); + //GraphExecutioner::execute(graph, _vs); delete _vs; } @@ -52,7 +52,7 @@ namespace sd { auto _vs = varSpace->clone(); //_vs->workspace()->expandTo(100000); _vs->setFlowPath(&fp); - GraphExecutioner::execute(graph, _vs); + //GraphExecutioner::execute(graph, _vs); auto p = fp.profile(); if (e == 0) diff --git a/libnd4j/include/legacy/NativeOps.h b/libnd4j/include/legacy/NativeOps.h index 7351fc0fc394..986e20c179d3 100755 --- a/libnd4j/include/legacy/NativeOps.h +++ b/libnd4j/include/legacy/NativeOps.h @@ -72,7 +72,6 @@ bool verbose = false; #include #include #include -#include #include #include #include diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp index b2cdd3d05d0d..072829d937e5 100644 --- a/libnd4j/include/legacy/cpu/NativeOps.cpp +++ b/libnd4j/include/legacy/cpu/NativeOps.cpp @@ -23,7 +23,6 @@ #include #include "legacy/NativeOpExecutioner.h" #include -#include #include #include #include @@ -1907,13 +1906,7 @@ void munmapFile(Nd4jPointer *extraPointers, Nd4jLong *ptrMap, Nd4jLong length) { } sd::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer) { - try { - return sd::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } + return nullptr; } Nd4jLong getResultWrapperSize(sd::graph::ResultWrapper* ptr) { @@ -2207,27 +2200,7 @@ static VariablesSet* executeStoredGraphT(Nd4jPointer *extraPointers, Nd4jLong gr varSpace->putVariable(idx, array); } - auto hZ = sd::graph::GraphExecutioner::execute(graph, varSpace); - auto varSet = new sd::graph::VariablesSet(hZ); - - if (hZ == ND4J_STATUS_OK) { - // pull back results, and provide them - auto outputs = graph->fetchOutputs(); - for (int e = 0; e < outputs->size(); e++) { - // we're only getting variable ID/Index from original grap. values will be taken from cloned workspace - std::pair varId(outputs->at(e)->id(), outputs->at(e)->index()); - - auto var = varSpace->getVariable(varId); - - varSet->push_back(var->clone()); - } - - delete outputs; - } - - delete graph; - - return varSet; + throw std::runtime_error("executeStoredGraphT - not implemented yet"); } sd::graph::VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) { @@ -2301,89 +2274,8 @@ const char* getAllOperations() { return sd::OpTracker::getInstance()->exportOperations(); } - -Nd4jPointer getGraphState(Nd4jLong id) { - return (Nd4jPointer) new sd::graph::GraphState(id); -} - -void deleteGraphState(Nd4jPointer state) { - auto stateP = reinterpret_cast(state); - delete stateP; -} - -Nd4jStatus execCustomOpWithScope_(Nd4jPointer *extraPointers, sd::graph::GraphState *state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs) { - /** - * That's basically exec, with VariableSpace provided in GraphState: - * depending on operation (i.e. while of if), different logic executors could be used - */ - - auto graph = state->graph(); - auto varSpace = state->variableSpace(); - - // Node is dynamically created, and has nothing beyond it: only inputs and outputs - // this node has id of 0, and inputs are - Node node(OpType_LOGIC, opHash, 0); - - // mapping inputs - for (int e = 0; e < numInputs; e++) { - auto buffer = inputBuffers[e]; - auto shapeInfo = reinterpret_cast(inputShapes[e]); - - auto array = new sd::NDArray(buffer, shapeInfo, varSpace->launchContext()); - - // now we just put array to VarSpace - varSpace->putVariable(0, e, array); - node.pickInput(0, e); - } - - // mapping scopes - for (int e = 0; e < numScopes; e++) { - // we should check scope existence in GraphState/Graph - int scopeId = (int) scopes[e]; - if (!state->hasScope(scopeId)) { - // nd4j_printf("execCustomOpWithScope: referenced scope [%i] doesn't exist\n", scopeId); - return Status::THROW(); - } - node.pickInput(scopeId, 0); - } - - auto hZ = LogicExecutor::processNode(graph, &node); - if (hZ != Status::OK()) - return hZ; - - // mapping outputs - - for (int e = 0; e < numOutputs; e++) { - auto buffer = outputBuffers[e]; - auto shapeInfo = reinterpret_cast(outputShapes[e]); - - NDArray array(buffer, shapeInfo, varSpace->launchContext()); - - // now we just put array to VarSpace to the same ID - //varSpace->putVariable(0, e, array); - - auto t = varSpace->getVariable(0, e)->getNDArray(); - array.assign(t); - } - - // removing input variables - for (int e = 0; e < numInputs; e++) { - varSpace->dropVariable(0, e); - } - - - // after some bla-bla-bla we should have Graph and Node for current op - return Status::OK(); -} - Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs) { - try { - return execCustomOpWithScope_(extraPointers, reinterpret_cast(state), opHash, scopes, numScopes, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, numOutputs); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; - } + } void deleteResultWrapper(Nd4jPointer ptr) { diff --git a/libnd4j/tests_cpu/layers_tests/ConditionalTests.cpp b/libnd4j/tests_cpu/layers_tests/ConditionalTests.cpp deleted file mode 100644 index 15974723f580..000000000000 --- a/libnd4j/tests_cpu/layers_tests/ConditionalTests.cpp +++ /dev/null @@ -1,332 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 16.10.2017. -// - -#include "testlayers.h" -#include -#include -#include -#include - -using namespace sd; -using namespace sd::graph; - -class ConditionalTests : public testing::Test { -public: - ConditionalTests(){ - //Environment::getInstance()->setVerbose(true); - //Environment::getInstance()->setDebug(true); - } - - ~ConditionalTests(){ - //Environment::getInstance()->setVerbose(false); - //Environment::getInstance()->setDebug(false); - } -}; - - -TEST_F(ConditionalTests, BasicTests_1) { - Graph graph; - - auto x = NDArrayFactory::valueOf({2, 2}, 1.0f); - auto y0 = NDArrayFactory::valueOf({2, 2}, 5.0f); - auto y1 = NDArrayFactory::valueOf({2, 2}, -5.0f); - auto scalar = NDArrayFactory::create_(1.0f); - - auto variableSpace = graph.getVariableSpace(); - - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y0); - variableSpace->putVariable(-3, y1); - variableSpace->putVariable(-4, scalar); - - - auto scopeCondition = new Node(OpType_LOGIC, logic::Scope, 1); - scopeCondition->setName("scopeCondition"); - - auto scopeFalse = new Node(OpType_LOGIC, logic::Scope, 2); - scopeFalse->setName("scopeFalse"); - - auto scopeTrue = new Node(OpType_LOGIC, logic::Scope, 3); - scopeTrue->setName("scopeTrue"); - - auto nodeF = new Node(OpType_PAIRWISE, pairwise::Add, 5, {-1, -2}); - nodeF->setScopeInfo(2, "scopeFalse"); - - auto nodeT = new Node(OpType_PAIRWISE, pairwise::Subtract, 6, {-1, -2}); - nodeT->setScopeInfo(3, "scopeTrue"); - - auto nodeC0 = new Node(OpType_REDUCE_SAME, reduce::Sum, 7, {-1}); - nodeC0->setScopeInfo(1, "scopeCondition"); - - sd::ops::eq_scalar op; - auto nodeC1 = new Node(&op, 8, {7, -4}); - nodeC1->setScopeInfo(1, "scopeCondition"); - - graph.addNode(scopeCondition); - graph.addNode(scopeFalse); - graph.addNode(scopeTrue); - graph.addNode(nodeF); - graph.addNode(nodeT); - graph.addNode(nodeC0); - graph.addNode(nodeC1); - - // at this point graph should ounly have Nodes referring to the Scopes: condition scope, true scope and false scope - ASSERT_EQ(3, graph.totalNodes()); - - // now we're adding Condition op, that'll take all of those in - auto nodeCondition = new Node(OpType_LOGIC, logic::Conditional, 10, {1, 2, 3}); - graph.addNode(nodeCondition); - - ASSERT_EQ(4, graph.totalNodes()); - - Nd4jStatus status = GraphExecutioner::execute(&graph); - ASSERT_EQ(ND4J_STATUS_OK, status); - - ASSERT_TRUE(variableSpace->hasVariable(10, 0)); - auto conditionalResult = variableSpace->getVariable(10, 0)->getNDArray(); - ASSERT_NE(nullptr, conditionalResult); - - ASSERT_NEAR(6.0, conditionalResult->meanNumber().e(0), 1e-5); -} -#ifdef GRAPH_FILES_OK -/** - * Condition is False - */ -TEST_F(ConditionalTests, Flat_Test_1) { - sd::ops::identity op0; - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simpleif_0_1.fb"); - auto varSpace = graph->getVariableSpace(); - //varSpace->getVariable(1)->getNDArray()->assign(2.0); - //varSpace->getVariable(2)->getNDArray()->assign(0.0); - - //graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(15)); - - auto z = varSpace->getVariable(15)->getNDArray(); - - ASSERT_NE(nullptr, z); - - auto exp = NDArrayFactory::create('c', {2, 2}, {-2, -2, -2, -2}); - - ASSERT_TRUE(exp.equalsTo(z)); - - delete graph; -} - -/** - * Condition is True - */ -TEST_F(ConditionalTests, Flat_Test_2) { - Environment::getInstance()->setDebug(true); - Environment::getInstance()->setVerbose(true); - sd::ops::identity op0; - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simpleif_0.fb"); - auto varSpace = graph->getVariableSpace(); - varSpace->getVariable(1)->getNDArray()->assign(-1.0); - - graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(15)); - - auto z = varSpace->getVariable(15)->getNDArray(); - - ASSERT_NE(nullptr, z); - - auto exp = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); - - ASSERT_TRUE(exp.equalsTo(z)); - delete graph; -} - - -/** - * Condition is false here, so there loop will be skipped - */ -TEST_F(ConditionalTests, Flat_Test_3) { - sd::ops::identity op0; - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_0_3.fb"); - auto varSpace = graph->getVariableSpace(); - varSpace->getVariable(1)->getNDArray()->assign(1.0); - - //graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(17)); - - auto z = varSpace->getVariable(17)->getNDArray(); - - ASSERT_NE(nullptr, z); - - auto exp = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); - ASSERT_TRUE(exp.equalsTo(z)); - - delete graph; -} - -/** - * just one cycle in body - */ -TEST_F(ConditionalTests, Flat_Test_4) { - sd::ops::identity op0; - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_0_4.fb"); - auto varSpace = graph->getVariableSpace(); - varSpace->getVariable(2)->getNDArray()->assign(4.0); - - //graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(17)); - - auto z = varSpace->getVariable(17)->getNDArray(); - - ASSERT_NE(nullptr, z); - - // 0.0 + 2.0 = 2.0 in each element - auto exp = NDArrayFactory::create('c', {2, 2}, {2, 2, 2, 2}); - ASSERT_TRUE(exp.equalsTo(z)); - - delete graph; -} - - -/** - * just two cycles in body - */ -TEST_F(ConditionalTests, Flat_Test_5) { - sd::ops::identity op0; - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_0_4.fb"); - auto varSpace = graph->getVariableSpace(); - varSpace->getVariable(2)->getNDArray()->assign(9.0); - - //graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(17)); - - auto z = varSpace->getVariable(17)->getNDArray(); - - ASSERT_NE(nullptr, z); - - // 0.0 + 2.0 + 2.0 = 4.0 in each element - auto exp = NDArrayFactory::create('c', {2, 2}, {4, 4, 4, 4}); - ASSERT_TRUE(exp.equalsTo(z)); - - delete graph; -} - -/** - * While loop with multiple variables - */ -TEST_F(ConditionalTests, Flat_Test_6) { - sd::ops::identity op0; - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_1.fb"); - auto varSpace = graph->getVariableSpace(); - varSpace->getVariable(1)->getNDArray()->assign(-4.0f); - varSpace->getVariable(2)->getNDArray()->assign(1.0f); - - //graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(25)); - - auto z = varSpace->getVariable(25)->getNDArray(); - - ASSERT_NE(nullptr, z); - - //z->printIndexedBuffer(); - - auto exp = NDArrayFactory::create('c', {2, 2}, {-1, -1, -1, -1}); - ASSERT_TRUE(exp.equalsTo(z)); - - delete graph; -} - -TEST_F(ConditionalTests, Flat_Test_7) { - sd::ops::identity op0; - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_1.fb"); - auto varSpace = graph->getVariableSpace(); - varSpace->getVariable(1)->getNDArray()->assign(-9.0f); - varSpace->getVariable(2)->getNDArray()->assign(1.0f); - - //graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(25)); - - auto z = varSpace->getVariable(25)->getNDArray(); - - ASSERT_NE(nullptr, z); - - //z->printIndexedBuffer(); - - auto exp = NDArrayFactory::create('c', {2, 2}, {-3, -3, -3, -3}); - ASSERT_TRUE(exp.equalsTo(z)); - - delete graph; -} - -/** - * This test checks nested while execution - */ -TEST_F(ConditionalTests, Flat_Test_8) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_nested.fb"); - auto varSpace = graph->getVariableSpace(); - //graph->printOut(); - - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(varSpace->hasVariable(52)); - - auto z = varSpace->getVariable(52)->getNDArray(); - - ASSERT_NE(nullptr, z); - - //val exp = Nd4j.create(2, 2).assign(15.0); - auto exp = NDArrayFactory::create('c', {2, 2}, {15, 15, 15, 15}); - ASSERT_TRUE(exp.equalsTo(z)); - - delete graph; -} -#endif diff --git a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp index 17e63f33f150..65a77d1539d3 100644 --- a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp @@ -25,7 +25,6 @@ #include #include #include -#include #include using namespace sd; diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 74e7d33e573a..b8820a3f274d 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -39,22 +39,22 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_1) { Graph graph; // A - graph.getVariableSpace()->putVariable(-1, 0, NDArrayFactory::create('c', {3}, {1, 1, 1})); + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); // B - graph.getVariableSpace()->putVariable(-2, 0, NDArrayFactory::create('c', {3}, {2, 2, 2})); + graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); // C - graph.getVariableSpace()->putVariable(-3, 0, NDArrayFactory::create('c', {3}, {3, 3, 3})); + graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - Node a("multiply", 10, {{-1, 0}, {-2, 0}}); - Node b("add", 20, {{10, 0}, {-3, 0}}); + Node a("multiply", sd::ops::multiply()); + Node b("add", sd::ops::add()); - graph.addNode(b); - graph.addNode(a); + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"multiply", "C"}); // we just check that nodes were really added - ASSERT_EQ(2, graph.totalNodes()); + ASSERT_EQ(2, graph.size()); auto optimized = graph.optimizedGraph(); @@ -78,28 +78,28 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { Graph graph; // A - graph.getVariableSpace()->putVariable(-1, 0, NDArrayFactory::create('c', {3}, {1, 1, 1})); + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); // B - graph.getVariableSpace()->putVariable(-2, 0, NDArrayFactory::create('c', {3}, {2, 2, 2})); + graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); // C - graph.getVariableSpace()->putVariable(-3, 0, NDArrayFactory::create('c', {3}, {3, 3, 3})); + graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); // D - graph.getVariableSpace()->putVariable(-4, 0, NDArrayFactory::create('c', {3}, {4, 4, 4})); + graph.addVariable("C", NDArrayFactory::create('c', {3}, {4, 4, 4})); - Node a("multiply", 10, {{-1, 0}, {-2, 0}}); - Node b("add", 20, {{10, 0}, {-3, 0}}); - Node c("subtract", 30, {{10, 0}, {-4, 0}}); + Node a("multiply", sd::ops::multiply()); + Node b("add", sd::ops::add()); + Node c("subtract", sd::ops::subtract()); - graph.addNode(b); - graph.addNode(c); - graph.addNode(a); + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"multiply", "C"}); + graph.addNode(c, {"multiply", "D"}); // we just check that nodes were really added - ASSERT_EQ(3, graph.totalNodes()); + ASSERT_EQ(3, graph.size()); auto optimized = graph.optimizedGraph(); diff --git a/libnd4j/tests_cpu/layers_tests/GraphStateTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphStateTests.cpp deleted file mode 100644 index 878b05712351..000000000000 --- a/libnd4j/tests_cpu/layers_tests/GraphStateTests.cpp +++ /dev/null @@ -1,349 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "testlayers.h" -#include -#include -#include -#include -#include - -using namespace sd; -using namespace sd::graph; - -class GraphStateTests : public testing::Test { -public: - GraphStateTests() { - Environment::getInstance()->setDebug(false); - Environment::getInstance()->setVerbose(false); - }; - - ~GraphStateTests() { - Environment::getInstance()->setDebug(false); - Environment::getInstance()->setVerbose(false); - } -}; - -/* - * PLAN: - * Create GraphState - * Register Scope - * Add few Ops to it - * Call conditional, that refers to scopes - * Check results - */ - -TEST_F(GraphStateTests, Basic_Tests_1) { - auto state = (GraphState *) getGraphState(117L); - ASSERT_EQ(117L, state->id()); - - // this call will create scope internally - state->registerScope(119); - - sd::ops::add opA; - sd::ops::LegacyTransformSameOp opB(transform::Neg); // simdOps::Neg - - ArgumentsList argsA; - ArgumentsList argsB; - - state->attachOpToScope(119, 1, &opA, argsA); - state->attachOpToScope(119, 2, &opB, argsB); - - auto scope = state->getScope(119); - ASSERT_TRUE(scope != nullptr); - ASSERT_EQ(2, scope->size()); - - deleteGraphState(state); -} - -// just separate case for doubles wrapper in NativeOps, nothing else -TEST_F(GraphStateTests, Basic_Tests_2) { - auto state = (GraphState *) getGraphState(117L); - ASSERT_EQ(117L, state->id()); - - // this call will create scope internally - state->registerScope(119); - - sd::ops::add opA; - sd::ops::LegacyTransformSameOp opB(transform::Neg); // simdOps::Neg - - ArgumentsList argsA; - ArgumentsList argsB; - - state->attachOpToScope(119, 1, &opA, argsA); - state->attachOpToScope(119, 2, &opB, argsB); - - auto scope = state->getScope(119); - ASSERT_TRUE(scope != nullptr); - ASSERT_EQ(2, scope->size()); - - deleteGraphState(state); -} - -/* -TEST_F(GraphStateTests, Stateful_Execution_1) { - auto state = getGraphState(117L); - - Nd4jLong scopes[] = {22, 33}; - //auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); - auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); - - ASSERT_EQ(Status::THROW(), status); - - deleteGraphState(state); -} - -TEST_F(GraphStateTests, Stateful_Execution_2) { - auto state = (GraphState *) getGraphState(117L); - - state->registerScope(22); - state->registerScope(33); - - Nd4jLong scopes[] = {22, 33}; - auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); - // it's no-op: just LogicScope - ASSERT_EQ(Status::OK(), status); - - deleteGraphState(state); -} - -// This test checks WHILE loop -TEST_F(GraphStateTests, Stateful_Execution_3) { - auto var0 = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto var1 = NDArrayFactory::create(11.0f); - auto var2 = NDArrayFactory::create(2.0f); - - auto res0 = NDArrayFactory::create('c', {2, 2}); - auto res1 = NDArrayFactory::create(0.0f); - auto res2 = NDArrayFactory::create(0.0f); - - // registering our GraphState holder - auto state = (GraphState *) getGraphState(117L); - - // we're prepping pointers to input/output buffers - Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer(), (Nd4jPointer)var2.buffer()}; - Nd4jPointer ptrShapes[] = {(Nd4jPointer) var0.shapeInfo(), (Nd4jPointer) var1.shapeInfo(), (Nd4jPointer)var2.shapeInfo()}; - - Nd4jPointer outBuffers[] = {(Nd4jPointer) res0.buffer(), (Nd4jPointer) res1.buffer(), (Nd4jPointer) res2.buffer()}; - Nd4jPointer outShapes[] = {(Nd4jPointer) res0.shapeInfo(), (Nd4jPointer) res1.shapeInfo(), (Nd4jPointer) res2.shapeInfo()}; - - // conditional scope - state->registerScope(22); - - sd::ops::LegacyReduceSameOp op1(reduce::Sum); - sd::ops::lt_scalar op2; - - // while sum(var0) < var1 - // this op takes sum - ArgumentsList args1({{0, 0}}); - - // this op compares result of sum to input variable 0:1 - ArgumentsList args2({{1, 0}, {0, 1}}); - - state->attachOpToScope(22, 1, &op1, args1); - state->attachOpToScope(22, 2, &op2, args2); - - // body scope - state->registerScope(33); - - // var0 + var1 + var1 - // this op is var0 + var1 - ArgumentsList args3({{0, 0}, {0, 2}}); - - // this op is result of previous op + 1 - ArgumentsList args4({{3, 0}, {0, 2}}); - - sd::ops::add op3; - sd::ops::add op4; - - state->attachOpToScope(33, 3, &op3, args3); - state->attachOpToScope(33, 4, &op4, args4); - - // Now we define RETURN, which returns 1 modified variable, and 2 unmodified variables - ArgumentsList args5({{4, 0}, {0, 1}, {0, 2}}); - - // so, at the end of body, initial variables will be updated - state->defineReturn(33, 5, args5); - - Nd4jLong scopes[] = {22, 33}; - - // we're executing while loop - auto status = execCustomOpWithScope(nullptr, state, 0, scopes, 2, ptrBuffers, ptrShapes, 3, outBuffers, outShapes, 3); - ASSERT_EQ(Status::OK(), status); - - // now we check provided result array - float sum = res0.reduceNumber(reduce::Sum).e(0); - - // Expected result is {1, 2, 3, 4} + {2} elementwise + {2} elementwise, which gives { 5, 6, 7, 8}, and sum should be 26 - ASSERT_NEAR(26.0f, sum, 1e-5); - - // nd4j_printf("0 ------------------\n",""); - - deleteGraphState(state); - - // nd4j_printf("1 ------------------\n",""); -} - -// This test checks CONDITIONAL execution for FALSE -TEST_F(GraphStateTests, Stateful_Execution_4) { - auto var0 = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto var1 = NDArrayFactory::create(5.0f); - - auto res0 = NDArrayFactory::create('c', {2, 2}); - auto res1 = NDArrayFactory::create(0.0f); - - auto exp = NDArrayFactory::create('c', {2, 2}, {-4, -3, -2, -1}); - - - // registering our GraphState holder - auto state = (GraphState *) getGraphState(117L); - - // we're prepping pointers to input/output buffers - Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()}; - Nd4jPointer ptrShapes[] = {(Nd4jPointer) var0.shapeInfo(), (Nd4jPointer) var1.shapeInfo()}; - - Nd4jPointer outBuffers[] = {(Nd4jPointer) res0.buffer(), (Nd4jPointer) res1.buffer()}; - Nd4jPointer outShapes[] = {(Nd4jPointer) res0.shapeInfo(), (Nd4jPointer) res1.shapeInfo()}; - - // conditional scope - state->registerScope(22); - - sd::ops::LegacyReduceSameOp op1(reduce::Sum); - sd::ops::lt_scalar op2; - - // if sum(var0) < var1 - // this op takes sum - ArgumentsList args1({{0, 0}}); - - // this op compares result of sum to input variable 0:1 - ArgumentsList args2({{1, 0}, {0, 1}}); - - state->attachOpToScope(22, 1, &op1, args1); - state->attachOpToScope(22, 2, &op2, args2); - - // false scope - state->registerScope(33); - - ArgumentsList args3({{0, 0}, {0, 1}}); - sd::ops::subtract op3; - state->attachOpToScope(33, 3, &op3, args3); - - // return for false scope - ArgumentsList args10({{3, 0}, {0, 1}}); - state->defineReturn(33, 10, args10); - - // true scope - state->registerScope(44); - - ArgumentsList args4({{0, 0}, {0, 1}}); - sd::ops::add op4; - state->attachOpToScope(44, 4, &op4, args4); - - // return for false scope - ArgumentsList args20({{4, 0}, {0, 1}}); - state->defineReturn(44, 20, args20); - - - Nd4jLong scopes[] = {22, 33, 44}; - - // we're executing conditional op - auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(exp.isSameShape(&res0)); - ASSERT_TRUE(exp.equalsTo(&res0)); - - - deleteGraphState(state); -} - - -// This test checks CONDITIONAL execution for TRUE -TEST_F(GraphStateTests, Stateful_Execution_5) { - auto var0 = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto var1 = NDArrayFactory::create(5.0f); - - auto res0 = NDArrayFactory::create('c', {2, 2}); - auto res1 = NDArrayFactory::create(0.0f); - - auto exp = NDArrayFactory::create('c', {2, 2}, {6, 7, 8, 9}); - - - // registering our GraphState holder - auto state = (GraphState *) getGraphState(117L); - - // we're prepping pointers to input/output buffers - Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()}; - Nd4jPointer ptrShapes[] = {(Nd4jPointer) var0.shapeInfo(), (Nd4jPointer) var1.shapeInfo()}; - - Nd4jPointer outBuffers[] = {(Nd4jPointer) res0.buffer(), (Nd4jPointer) res1.buffer()}; - Nd4jPointer outShapes[] = {(Nd4jPointer) res0.shapeInfo(), (Nd4jPointer) res1.shapeInfo()}; - - // conditional scope - state->registerScope(22); - - sd::ops::LegacyReduceSameOp op1(reduce::Sum); - sd::ops::gt_scalar op2; - - // if sum(var0) < var1 - // this op takes sum - ArgumentsList args1({{0, 0}}); - - // this op compares result of sum to input variable 0:1 - ArgumentsList args2({{1, 0}, {0, 1}}); - - state->attachOpToScope(22, 1, &op1, args1); - state->attachOpToScope(22, 2, &op2, args2); - - // false scope - state->registerScope(33); - - ArgumentsList args3({{0, 0}, {0, 1}}); - sd::ops::subtract op3; - state->attachOpToScope(33, 3, &op3, args3); - - // return for false scope - ArgumentsList args10({{3, 0}, {0, 1}}); - state->defineReturn(33, 10, args10); - - // true scope - state->registerScope(44); - - ArgumentsList args4({{0, 0}, {0, 1}}); - sd::ops::add op4; - state->attachOpToScope(44, 4, &op4, args4); - - // return for false scope - ArgumentsList args20({{4, 0}, {0, 1}}); - state->defineReturn(44, 20, args20); - - - Nd4jLong scopes[] = {22, 33, 44}; - - // we're executing conditional op - auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(exp.isSameShape(&res0)); - ASSERT_TRUE(exp.equalsTo(&res0)); - - deleteGraphState(state); -} -*/ \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests.cpp index cc965d084fe6..792e1b07aa6d 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests.cpp @@ -45,37 +45,34 @@ class GraphTests : public testing::Test { }; TEST_F(GraphTests, SingleInput1) { - auto graph = new Graph(); + Graph graph; auto x = NDArrayFactory::create_('c', {5, 5}); x->assign(-2.0f); - graph->getVariableSpace()->putVariable(-1, x); + graph.getVariableSpace()->putVariable(-1, x); - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_STRICT, transform::Cosine, 2, {1}, {3}); - auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Abs, 3, {2}, {}); + Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); + Node nodeB(OpType_TRANSFORM_STRICT, transform::Cosine); + Node nodeC(OpType_TRANSFORM_SAME, transform::Abs); - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); + graph.addNode(nodeA, {-1}); + graph.addNode(nodeB, {1}); + graph.addNode(nodeC, {2}); - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(3, graph->totalNodes()); + ASSERT_EQ(3, graph.size()); - GraphExecutioner::execute(graph); + graph.execute(); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(3)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(3)); - auto node3 = graph->getVariableSpace()->getVariable(3)->getNDArray(); + auto node3 = graph.getVariableSpace()->getVariable(3)->getNDArray(); ASSERT_NEAR(0.4161468, node3->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; } TEST_F(GraphTests, DoubleInput1) { - auto graph = new Graph(); + Graph graph; auto x = NDArrayFactory::create_('c', {5, 5}); x->assign(-2.0); @@ -85,30 +82,28 @@ TEST_F(GraphTests, DoubleInput1) { auto z = NDArrayFactory::create_('c', {5, 5}); - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, y); - graph->getVariableSpace()->putVariable(-3, z); + graph.getVariableSpace()->putVariable(-1, x); + graph.getVariableSpace()->putVariable(-2, y); + graph.getVariableSpace()->putVariable(-3, z); - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {3}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {-2}, {3}); - auto nodeC = new Node(OpType_PAIRWISE, pairwise::Add, 3, {1, 2}, {-3}); + Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); + Node nodeB(OpType_TRANSFORM_SAME, transform::Abs); + Node nodeC(OpType_PAIRWISE, pairwise::Add); - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); + graph.addNode(nodeA, {-1}); + graph.addNode(nodeB, {-2}); + graph.addNode(nodeC, {1, 2}); - ASSERT_EQ(2, graph->rootNodes()); - ASSERT_EQ(3, graph->totalNodes()); - GraphExecutioner::execute(graph); + ASSERT_EQ(3, graph.size()); - ASSERT_NEAR(3.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); + graph.execute(); - delete graph; + ASSERT_NEAR(3.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); } TEST_F(GraphTests, SingleInput3) { - auto graph = new Graph(); + Graph graph; auto x = NDArrayFactory::create_('c', {5, 5}); x->assign(-2.0); @@ -116,31 +111,28 @@ TEST_F(GraphTests, SingleInput3) { auto v0 = NDArrayFactory::create_('c', {5, 5}); auto v1 = NDArrayFactory::create_('c', {5, 5}); - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, v0); - graph->getVariableSpace()->putVariable(-3, v1); + graph.getVariableSpace()->putVariable(-1, x); + graph.getVariableSpace()->putVariable(-2, v0); + graph.getVariableSpace()->putVariable(-3, v1); - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2, 3}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); - auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Ones, 3, {1}, {-3}); + Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); + Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt); + Node nodeC(OpType_TRANSFORM_SAME, transform::Ones); - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); + graph.addNode(nodeA, {-1}); + graph.addNode(nodeB, {1}); + graph.addNode(nodeC, {1}); - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(3, graph->totalNodes()); + ASSERT_EQ(3, graph.size()); - GraphExecutioner::execute(graph); + graph.execute(); ASSERT_NEAR(1.4142135, v0->reduceNumber(reduce::Mean).e(0), 1e-5); ASSERT_NEAR(1.0, v1->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; } TEST_F(GraphTests, SingleInput4) { - auto graph = new Graph(); + Graph graph; auto x = NDArrayFactory::create_('c', {5, 5}); x->assign(-2.0); @@ -148,37 +140,34 @@ TEST_F(GraphTests, SingleInput4) { auto v0 = NDArrayFactory::create_('c', {5, 5}); auto v1 = NDArrayFactory::create_('c', {5, 5}); - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, v0); - graph->getVariableSpace()->putVariable(-3, v1); + graph.getVariableSpace()->putVariable(-1, x); + graph.getVariableSpace()->putVariable(-2, v0); + graph.getVariableSpace()->putVariable(-3, v1); - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); - auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Neg, 3, {2}, {4, 5}); + Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); + Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt); + Node nodeC(OpType_TRANSFORM_SAME, transform::Neg); - auto nodeS = new Node(OpType_TRANSFORM_SAME, transform::Ones, 4, {3}, {-2}); - auto nodeE = new Node(OpType_TRANSFORM_SAME, transform::Identity, 5, {3}, {-3}); + Node nodeS(OpType_TRANSFORM_SAME, transform::Ones); + Node nodeE(OpType_TRANSFORM_SAME, transform::Identity); - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); - graph->addNode(nodeS); - graph->addNode(nodeE); + graph.addNode(nodeA, {-1}); + graph.addNode(nodeB, {1}); + graph.addNode(nodeC, {2}); + graph.addNode(nodeS, {3}); + graph.addNode(nodeE, {3}); - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(5, graph->totalNodes()); + ASSERT_EQ(5, graph.size()); - GraphExecutioner::execute(graph); + graph.execute(); ASSERT_NEAR(1.0, v0->reduceNumber(reduce::Mean).e(0), 1e-5); ASSERT_NEAR(-1.4142135, v1->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; } TEST_F(GraphTests, DoubleInput2) { - auto graph = new Graph(); + Graph graph; auto x = NDArrayFactory::create_('c', {5, 5}); x->assign(-2.0); @@ -189,41 +178,38 @@ TEST_F(GraphTests, DoubleInput2) { auto z0 = NDArrayFactory::create_('c', {5, 5}); auto z1 = NDArrayFactory::create_('c', {5, 5}); - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, y); - graph->getVariableSpace()->putVariable(-3, z0); - graph->getVariableSpace()->putVariable(-4, z1); + graph.getVariableSpace()->putVariable(-1, x); + graph.getVariableSpace()->putVariable(-2, y); + graph.getVariableSpace()->putVariable(-3, z0); + graph.getVariableSpace()->putVariable(-4, z1); - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); - auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Neg, 3, {2}, {-3}); + Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); + Node nodeC(OpType_TRANSFORM_SAME, transform::Neg, 3, {2}, {-3}); - auto nodeT = new Node(OpType_TRANSFORM_SAME, transform::Abs, 11, {-2}, {12}); - auto nodeU = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 12, {11}, {13}); - auto nodeV = new Node(OpType_TRANSFORM_SAME, transform::Neg, 13, {12}, {-4}); + Node nodeT(OpType_TRANSFORM_SAME, transform::Abs, 11, {-2}, {12}); + Node nodeU(OpType_TRANSFORM_FLOAT, transform::Sqrt, 12, {11}, {13}); + Node nodeV(OpType_TRANSFORM_SAME, transform::Neg, 13, {12}, {-4}); - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); - graph->addNode(nodeT); - graph->addNode(nodeU); - graph->addNode(nodeV); + graph.addNode(nodeA, {-1}); + graph.addNode(nodeB, {1}); + graph.addNode(nodeC, {2}); + graph.addNode(nodeT, {-2}); + graph.addNode(nodeU, {4}); + graph.addNode(nodeV, {5}); - ASSERT_EQ(2, graph->rootNodes()); - ASSERT_EQ(6, graph->totalNodes()); + ASSERT_EQ(6, graph.size()); - GraphExecutioner::execute(graph); + graph.execute(); ASSERT_NEAR(-1.4142135, z0->reduceNumber(reduce::Mean).e(0), 1e-5); ASSERT_NEAR(-1.0, z1->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; } TEST_F(GraphTests, DoubleInput3) { - auto graph = new Graph(); + Graph graph; auto x = NDArrayFactory::create_('c', {5, 5}); x->assign(-2.0); @@ -237,49 +223,46 @@ TEST_F(GraphTests, DoubleInput3) { auto w = NDArrayFactory::create_('c', {5, 5}); - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, y); - graph->getVariableSpace()->putVariable(-3, z0); - graph->getVariableSpace()->putVariable(-4, z1); - graph->getVariableSpace()->putVariable(-5, w); + graph.getVariableSpace()->putVariable(-1, x); + graph.getVariableSpace()->putVariable(-2, y); + graph.getVariableSpace()->putVariable(-3, z0); + graph.getVariableSpace()->putVariable(-4, z1); + graph.getVariableSpace()->putVariable(-5, w); - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); - auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Neg, 3, {2}, {-3, 21}); + Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); + Node nodeC(OpType_TRANSFORM_SAME, transform::Neg, 3, {2}, {-3, 21}); - auto nodeT = new Node(OpType_TRANSFORM_SAME, transform::Abs, 11, {-2}, {12}); - auto nodeU = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 12, {11}, {13}); - auto nodeV = new Node(OpType_TRANSFORM_SAME, transform::Neg, 13, {12}, {-4, 21}); + Node nodeT(OpType_TRANSFORM_SAME, transform::Abs, 11, {-2}, {12}); + Node nodeU(OpType_TRANSFORM_FLOAT, transform::Sqrt, 12, {11}, {13}); + Node nodeV(OpType_TRANSFORM_SAME, transform::Neg, 13, {12}, {-4, 21}); - auto nodeW = new Node(OpType_PAIRWISE, pairwise::Add, 21, {3, 13}, {22}); - auto nodeZ = new Node(OpType_TRANSFORM_SAME, transform::Abs, 22, {21}, {-5}); + Node nodeW(OpType_PAIRWISE, pairwise::Add, 21, {3, 13}, {22}); + Node nodeZ(OpType_TRANSFORM_SAME, transform::Abs, 22, {21}, {-5}); - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); - graph->addNode(nodeT); - graph->addNode(nodeU); - graph->addNode(nodeV); - graph->addNode(nodeW); - graph->addNode(nodeZ); + graph.addNode(nodeA, {-1}); + graph.addNode(nodeB, {1}); + graph.addNode(nodeC, {2}); + graph.addNode(nodeT, {-2}); + graph.addNode(nodeU, {4}); + graph.addNode(nodeV, {5}); + graph.addNode(nodeW, {3, 6}); + graph.addNode(nodeZ, {7}); - ASSERT_EQ(2, graph->rootNodes()); - ASSERT_EQ(8, graph->totalNodes()); + ASSERT_EQ(8, graph.size()); - GraphExecutioner::execute(graph); + graph.execute(); ASSERT_NEAR(-1.4142135, z0->reduceNumber(reduce::Mean).e(0), 1e-5); ASSERT_NEAR(-1.0, z1->reduceNumber(reduce::Mean).e(0), 1e-5); ASSERT_NEAR(2.4142135, w->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; } TEST_F(GraphTests, QuadInput1) { - auto graph = new Graph(); + Graph graph; auto x0 = NDArrayFactory::create_('c', {5, 5}); x0->assign(0.0); @@ -296,91 +279,83 @@ TEST_F(GraphTests, QuadInput1) { auto z = NDArrayFactory::create_('c', {5, 5}); z->assign(119.0); - graph->getVariableSpace()->putVariable(-1, x0); - graph->getVariableSpace()->putVariable(-2, x1); - graph->getVariableSpace()->putVariable(-3, x2); - graph->getVariableSpace()->putVariable(-4, x3); - graph->getVariableSpace()->putVariable(-5, z); + graph.getVariableSpace()->putVariable(-1, x0); + graph.getVariableSpace()->putVariable(-2, x1); + graph.getVariableSpace()->putVariable(-3, x2); + graph.getVariableSpace()->putVariable(-4, x3); + graph.getVariableSpace()->putVariable(-5, z); - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {11}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {-2}, {11}); - auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Abs, 3, {-3}, {21}); - auto nodeD = new Node(OpType_TRANSFORM_SAME, transform::Abs, 4, {-4}, {21}); + Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {11}); + Node nodeB(OpType_TRANSFORM_SAME, transform::Abs, 2, {-2}, {11}); + Node nodeC(OpType_TRANSFORM_SAME, transform::Abs, 3, {-3}, {21}); + Node nodeD(OpType_TRANSFORM_SAME, transform::Abs, 4, {-4}, {21}); - auto nodeP1 = new Node(OpType_PAIRWISE, pairwise::Add, 11, {1, 2}, {31}); - auto nodeP2 = new Node(OpType_PAIRWISE, pairwise::Add, 21, {3, 4}, {31}); + Node nodeP1(OpType_PAIRWISE, pairwise::Add, 11, {1, 2}, {31}); + Node nodeP2(OpType_PAIRWISE, pairwise::Add, 21, {3, 4}, {31}); - auto nodeZ = new Node(OpType_PAIRWISE, pairwise::Add, 31, {11, 21}, {-5}); + Node nodeZ(OpType_PAIRWISE, pairwise::Add, 31, {11, 21}, {-5}); - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); - graph->addNode(nodeD); - graph->addNode(nodeP1); - graph->addNode(nodeP2); - graph->addNode(nodeZ); + graph.addNode(nodeA, {-1}); + graph.addNode(nodeB, {-2}); + graph.addNode(nodeC, {-3}); + graph.addNode(nodeD, {-4}); + graph.addNode(nodeP1, {1, 2}); + graph.addNode(nodeP2, {3, 4}); + graph.addNode(nodeZ, {11, 21}); - ASSERT_EQ(4, graph->rootNodes()); - ASSERT_EQ(7, graph->totalNodes()); + ASSERT_EQ(7, graph.size()); - GraphExecutioner::execute(graph); + graph.execute(); ASSERT_NEAR(6.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; } TEST_F(GraphTests, InternalBranching1) { - auto graph = new Graph(); + Graph graph; auto x = NDArrayFactory::create_('c', {5, 5}); x->assign(0.0); auto z = NDArrayFactory::create_('c', {5, 5}); - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, z); + graph.getVariableSpace()->putVariable(-1, x); + graph.getVariableSpace()->putVariable(-2, z); // 1.0 - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Ones, 1, {-1}, {11, 21}); + Node nodeA(OpType_TRANSFORM_SAME, transform::Ones, 1, {-1}, {11, 21}); // -1 - auto nodeK = new Node(OpType_TRANSFORM_SAME, transform::Neg, 11, {1}, {12}); + Node nodeK(OpType_TRANSFORM_SAME, transform::Neg, 11, {1}, {12}); // 2.0 - auto nodeL = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 12, {11}, {31}); + Node nodeL(OpType_TRANSFORM_SAME, transform::OneMinus, 12, {11}, {31}); // -1 - auto nodeR = new Node(OpType_TRANSFORM_SAME, transform::Neg, 21, {1}, {22}); + Node nodeR(OpType_TRANSFORM_SAME, transform::Neg, 21, {1}, {22}); // 1 - auto nodeS = new Node(OpType_TRANSFORM_SAME, transform::Neg, 22, {21}, {31}); + Node nodeS(OpType_TRANSFORM_SAME, transform::Neg, 22, {21}, {31}); // 1.0 - auto nodeZ = new Node(OpType_PAIRWISE, pairwise::Add, 31, {12, 22}, {-2}); + Node nodeZ(OpType_PAIRWISE, pairwise::Add, 31, {12, 22}, {-2}); - graph->addNode(nodeA); - graph->addNode(nodeK); - graph->addNode(nodeL); - graph->addNode(nodeR); - graph->addNode(nodeS); - graph->addNode(nodeZ); - - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(6, graph->totalNodes()); + graph.addNode(nodeA, {-1}); + graph.addNode(nodeK, {1}); + graph.addNode(nodeL, {2}); + graph.addNode(nodeR, {1}); + graph.addNode(nodeS, {1}); + graph.addNode(nodeZ, {1, 1}); - GraphExecutioner::execute(graph); + ASSERT_EQ(6, graph.size()); - ASSERT_EQ(3, nodeZ->getLayer()); + graph.execute(); ASSERT_NEAR(3.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; } TEST_F(GraphTests, ReductionsTest1) { - auto graph = new Graph(); + Graph graph; auto x = NDArrayFactory::create_('c', {5, 5}); for (int r = 0; r < x->rows(); r++) { @@ -391,30 +366,25 @@ TEST_F(GraphTests, ReductionsTest1) { auto z = NDArrayFactory::create_('c', {5}); - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, z); + graph.getVariableSpace()->putVariable(-1, x); + graph.getVariableSpace()->putVariable(-2, z); -// sd::graph::Node::Node(OpType opType, int opNum, int id, std::initializer_list input, std::initializer_list output, std::initializer_list dimensions, float scalar, std::initializer_list tArgs, std::initializer_list iArgs) { + Node nodeA(OpType_REDUCE_FLOAT, reduce::Mean, 1, {-1}, {2}, {1}, {}); + Node nodeB(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {-2}); - auto nodeA = new Node(OpType_REDUCE_FLOAT, reduce::Mean, 1, {-1}, {2}, {1}, {}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {-2}); + graph.addNode(nodeA, {-1}); + graph.addNode(nodeB, {1}); - graph->addNode(nodeA); - graph->addNode(nodeB); - - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(2, graph->totalNodes()); + ASSERT_EQ(2, graph.size()); - GraphExecutioner::execute(graph); + graph.execute(); ASSERT_NEAR(2.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; } TEST_F(GraphTests, IndexReductionsTest1) { - auto graph = new Graph(); + Graph graph; auto x = NDArrayFactory::create_('c', {5, 5}); for (int r = 0; r < x->rows(); r++) { @@ -425,26 +395,23 @@ TEST_F(GraphTests, IndexReductionsTest1) { auto z = NDArrayFactory::create_('c', {5, 1}); auto axis = NDArrayFactory::create_('c', {1}, {1}); - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, z); + + graph.getVariableSpace()->putVariable(-1, x); + graph.getVariableSpace()->putVariable(-2, z); //graph->getVariableSpace()->putVariable(-3, axis); - auto nodeA = new Node(OpType_INDEX_REDUCE, indexreduce::IndexMin, 1, {-1}, {2}, {1}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {-2}); + Node nodeA(OpType_INDEX_REDUCE, indexreduce::IndexMin, 1, {-1}, {2}, {1}); + Node nodeB(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {-2}); - graph->addNode(nodeA); - graph->addNode(nodeB); + graph.addNode(nodeA, {-1, -2}); + graph.addNode(nodeB, {1}); - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(2, graph->totalNodes()); + ASSERT_EQ(2, graph.size()); - GraphExecutioner::execute(graph); + graph.execute(); ASSERT_NEAR(4.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; - delete axis; } #if 0 @@ -524,7 +491,7 @@ TEST_F(GraphTests, AutoOutput2) { #endif TEST_F(GraphTests, BroadcastTest1) { - auto graph = new Graph(); + Graph graph; auto x = NDArrayFactory::create_('c', {5, 5}); x->assign(0.f); @@ -535,55 +502,50 @@ TEST_F(GraphTests, BroadcastTest1) { auto z = NDArrayFactory::create_('c', {5, 5}); - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, y); - graph->getVariableSpace()->putVariable(-3, z); + graph.getVariableSpace()->putVariable(-1, x); + graph.getVariableSpace()->putVariable(-2, y); + graph.getVariableSpace()->putVariable(-3, z); - auto nodeA = new Node(OpType_BROADCAST, broadcast::Subtract, 1, {-1, -2}, {2}, {1}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Neg, 2, {1}, {-3}); + Node nodeA(OpType_BROADCAST, broadcast::Subtract, 1, {-1, -2}, {2}, {1}); + Node nodeB(OpType_TRANSFORM_SAME, transform::Neg, 2, {1}, {-3}); - graph->addNode(nodeA); - graph->addNode(nodeB); + graph.addNode(nodeA, {-1, -2}); + graph.addNode(nodeB, {1}); - GraphExecutioner::execute(graph); + graph.execute(); ASSERT_NEAR(3.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; } TEST_F(GraphTests, ScalarTest1) { - auto graph = new Graph(); + Graph graph; auto x = NDArrayFactory::create_('c', {5, 5}); x->assign(-2.0); auto z = NDArrayFactory::create_('c', {5, 5}); - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, z); + graph.getVariableSpace()->putVariable(-1, x); + graph.getVariableSpace()->putVariable(-2, z); - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); - auto nodeE = new Node(OpType_SCALAR, scalar::Add, 3, {2}, {-2}, {}, 1.3f); + Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); + Node nodeE(OpType_SCALAR, scalar::Add, 3, {2}, {-2}, {}, 1.3f); - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeE); + graph.addNode(nodeA, {-1}); + graph.addNode(nodeB, {1}); + graph.addNode(nodeE, {2}); - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(3, graph->totalNodes()); + ASSERT_EQ(3, graph.size()); - GraphExecutioner::execute(graph); + graph.execute(); ASSERT_NEAR(2.714213, z->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; } TEST_F(GraphTests, SymbolicLookupTest1) { - auto graph = new Graph(); + Graph graph; auto x = NDArrayFactory::create_('c', {5, 5}); x->assign(-2.0); @@ -599,628 +561,41 @@ TEST_F(GraphTests, SymbolicLookupTest1) { vX->setName(a); vZ->setName(o); - graph->getVariableSpace()->putVariable(-1, vX); - graph->getVariableSpace()->putVariable(-2, vZ); + graph.getVariableSpace()->putVariable(-1, vX); + graph.getVariableSpace()->putVariable(-2, vZ); - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); + Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); std::string p("phi"); std::string t("theta"); - nodeA->setName(&p); - nodeB->setName(&t); + nodeA.setName(&p); + nodeB.setName(&t); - graph->addNode(nodeA); - graph->addNode(nodeB); + graph.addNode(nodeA, {-1}); + graph.addNode(nodeB, {1}); - auto rX = graph->getVariableSpace()->getVariable(a); - auto rZ = graph->getVariableSpace()->getVariable(o); + auto rX = graph.getVariableSpace()->getVariable(a); + auto rZ = graph.getVariableSpace()->getVariable(o); std::string om("omicron"); ASSERT_TRUE(rX->getNDArray() == vX->getNDArray()); ASSERT_TRUE(rZ->getNDArray() == vZ->getNDArray()); - ASSERT_FALSE(graph->getVariableSpace()->hasVariable(om)); + ASSERT_FALSE(graph.getVariableSpace()->hasVariable(om)); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(2)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(1)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(2)); - GraphExecutioner::execute(graph); + graph.execute(); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(p)); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(t)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(p)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(t)); ASSERT_NEAR(1.4142135, z->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; -} - -TEST_F(GraphTests, OutputValidation1) { - auto graph = new Graph(); - - graph->getExecutorConfiguration()->_outputMode = OutputMode_EXPLICIT; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - auto vX = new Variable(x); - auto vZ = new Variable(z); - - std::string a("alpha"); - std::string o("omega"); - - vX->setName(a); - vZ->setName(o); - - graph->getVariableSpace()->putVariable(-1, vX); - graph->getVariableSpace()->putVariable(-2, vZ); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(0, outputs->size()); - - delete graph; - delete outputs; -} - -TEST_F(GraphTests, OutputValidation2) { - auto graph = new Graph(); - - graph->getExecutorConfiguration()->_outputMode = OutputMode_EXPLICIT; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - auto vX = new Variable(x); - auto vZ = new Variable(z); - - std::string a("alpha"); - std::string o("omega"); - - vX->setName(a); - vZ->setName(o); - - graph->getVariableSpace()->putVariable(-1, vX); - graph->getVariableSpace()->putVariable(-2, vZ); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - graph->addOutput(-2); - - GraphExecutioner::execute(graph); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(1, outputs->size()); - - ASSERT_NEAR(1.4142135, outputs->at(0)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; - delete outputs; -} - -TEST_F(GraphTests, OutputValidation3) { - auto graph = new Graph(); - - graph->getExecutorConfiguration()->_outputMode = OutputMode_IMPLICIT; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - auto vX = new Variable(x); - auto vZ = new Variable(z); - - std::string a("alpha"); - std::string o("omega"); - - vX->setName(a); - vZ->setName(o); - - graph->getVariableSpace()->putVariable(-1, vX); - graph->getVariableSpace()->putVariable(-2, vZ); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - GraphExecutioner::execute(graph); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(1, outputs->size()); - - ASSERT_NEAR(1.4142135, outputs->at(0)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; - delete outputs; -} - -TEST_F(GraphTests, OutputValidation4) { - auto graph = new Graph(); - - graph->getExecutorConfiguration()->_outputMode = OutputMode_EXPLICIT_AND_IMPLICIT; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - auto vX = new Variable(x); - auto vZ = new Variable(z); - - std::string a("alpha"); - std::string o("omega"); - - vX->setName(a); - vZ->setName(o); - - graph->getVariableSpace()->putVariable(-1, vX); - graph->getVariableSpace()->putVariable(-2, vZ); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); - - graph->addOutput(-1); - - // not a typo. we want this value only once - graph->addOutput(-1); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - GraphExecutioner::execute(graph); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(2, outputs->size()); - - ASSERT_NEAR(1.4142135, outputs->at(1)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; - delete outputs; -} - - -TEST_F(GraphTests, OutputValidation5) { - auto graph = new Graph(); - - graph->getExecutorConfiguration()->_outputMode = OutputMode_VARIABLE_SPACE; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - auto vX = new Variable(x); - auto vZ = new Variable(z); - - std::string a("alpha"); - std::string o("omega"); - - vX->setName(a); - vZ->setName(o); - - graph->getVariableSpace()->putVariable(-1, vX); - graph->getVariableSpace()->putVariable(-2, vZ); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Sqrt, 2, {1}, {-2}); - - graph->addOutput(-1); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - GraphExecutioner::execute(graph); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(4, outputs->size()); - - delete graph; - delete outputs; -} - -TEST_F(GraphTests, OutputValidation6) { - auto graph = new Graph(); - - graph->getExecutorConfiguration()->_outputMode = OutputMode_VARIABLE_SPACE; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); - - auto vX = new Variable(x); - auto vZ = new Variable(z); - - std::string a("alpha"); - std::string o("omega"); - - vX->setName(a); - vZ->setName(o); - - graph->getVariableSpace()->putVariable(-1, vX); - graph->getVariableSpace()->putVariable(-2, vZ); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {}); - - //graph->addOutput(-1); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - GraphExecutioner::execute(graph); - - auto outputs = graph->fetchOutputs(); - -// nd4j_printf("Returned variables: \n", ""); -// for (int e = 0; e < outputs->size(); e++) { -// printf("%i, ", outputs->at(e)->id()); -// } -// printf("\n"); - - ASSERT_EQ(4, outputs->size()); - - //ASSERT_NEAR(1.4142135, graph->fetchOutputs()->at(1)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); - delete graph; - delete outputs; -} - -TEST_F(GraphTests, TestMultiOutput1) { - sd::ops::testop2i2o op1; - auto graph = new Graph(); - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto y = NDArrayFactory::create_('c', {5, 5}); - y->assign(-3.0); - - graph->getVariableSpace()->putVariable(-1, x); - graph->getVariableSpace()->putVariable(-2, y); - - - // Abs - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {11}); - nodeA0->markInplace(false); - auto nodeB0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {-2}, {11}); - nodeB0->markInplace(false); - - auto op = sd::ops::OpRegistrator::getInstance()->getOperation("testop2i2o"); - - // this op will add 1.0 to first input, and 2.0 for second input - auto nodeT = new Node(op, 11, {1, 2}, {21, 31}, {}, 0.0f); - nodeT->setName("TestOp2i2o"); - nodeT->markInplace(false); - - - // this op will subtract this value from 1.0 - auto nodeX = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 21); - nodeX->markInplace(false); - nodeX->pickInput(11, 0); - - // this op will subtract this value from 1.0 - auto nodeY = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 31); - nodeY->markInplace(false); - nodeY->pickInput(11, 1); - - graph->addNode(nodeA0); - graph->addNode(nodeB0); - graph->addNode(nodeT); - graph->addNode(nodeX); - graph->addNode(nodeY); - - std::pair pair0(11,0); - std::pair pair1(11,1); - - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(pair0)); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(pair1)); - - Nd4jStatus status = GraphExecutioner::execute(graph); - - ASSERT_EQ(ND4J_STATUS_OK, status); - - ASSERT_NEAR(-2.0f, graph->getVariableSpace()->getVariable(21)->getNDArray()->meanNumber().e(0), 1e-5); - ASSERT_NEAR(-4.0f, graph->getVariableSpace()->getVariable(31)->getNDArray()->meanNumber().e(0), 1e-5); - - delete graph; -} - -TEST_F(GraphTests, TestDivergentNode1) { - auto op = sd::ops::OpRegistrator::getInstance()->getOperation("Switch"); - auto nodeY = new Node(op, 1); - - ASSERT_TRUE(nodeY->isDivergencePoint()); - ASSERT_TRUE(nodeY->isActive()); - - delete nodeY; -} - - -TEST_F(GraphTests, MemoryEstimationTest1) { - Graph graph; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - graph.getVariableSpace()->putVariable(-1, x); - - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {}); - nodeA1->markInplace(false); - - graph.addNode(nodeA0); - graph.addNode(nodeA1); - - ASSERT_EQ(2, graph.totalNodes()); - ASSERT_EQ(1, graph.rootNodes()); - - auto memReq = graph.estimateRequiredMemory(); - - ASSERT_EQ(25 * x->sizeOfT(), memReq); -} - -TEST_F(GraphTests, MemoryEstimationTest2) { - Graph graph; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - graph.getVariableSpace()->putVariable(-1, x); - - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {}); - //nodeA1->markInplace(false); - - graph.addNode(nodeA0); - graph.addNode(nodeA1); - - ASSERT_EQ(2, graph.totalNodes()); - ASSERT_EQ(1, graph.rootNodes()); - - auto memReq = graph.estimateRequiredMemory(); - - ASSERT_EQ(0, memReq); -} - -TEST_F(GraphTests, MemoryEstimationTest3) { - Graph graph; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - graph.getVariableSpace()->putVariable(-1, x); - - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); - auto nodeA2 = new Node(OpType_REDUCE_FLOAT, reduce::Mean, 3, {2}, {}, {}); - nodeA1->markInplace(false); - - graph.addNode(nodeA0); - graph.addNode(nodeA1); - graph.addNode(nodeA2); - - ASSERT_EQ(3, graph.totalNodes()); - ASSERT_EQ(1, graph.rootNodes()); - - auto memReq = graph.estimateRequiredMemory(); - - ASSERT_EQ(26 * x->sizeOfT(), memReq); -} - -TEST_F(GraphTests, MemoryEstimationTest4) { - Graph graph; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - graph.getVariableSpace()->putVariable(-1, x); - - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); - auto nodeA2 = new Node(OpType_REDUCE_FLOAT, reduce::Mean, 3, {2}, {}, {1}); - nodeA1->markInplace(false); - - graph.addNode(nodeA0); - graph.addNode(nodeA1); - graph.addNode(nodeA2); - - ASSERT_EQ(3, graph.totalNodes()); - ASSERT_EQ(1, graph.rootNodes()); - - auto memReq = graph.estimateRequiredMemory(); - - ASSERT_EQ(30 * x->sizeOfT(), memReq); -} - -TEST_F(GraphTests, MemoryEstimationTest5) { - Graph graph; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - graph.getVariableSpace()->putVariable(-1, x); - - sd::ops::testcustom op; - - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); - auto nodeA2 = new Node(&op, 3, {2}, {}, {}); - nodeA1->markInplace(false); - - - graph.addNode(nodeA0); - graph.addNode(nodeA1); - graph.addNode(nodeA2); - - graph.buildGraph(); - - ASSERT_EQ(3, graph.totalNodes()); - ASSERT_EQ(1, graph.rootNodes()); - - auto memReq = graph.estimateRequiredMemory(); - - ASSERT_EQ((25 + 100) * x->sizeOfT(), memReq); -} - -TEST_F(GraphTests, TestGraphInGraph_1) { - // this one is external graph - Graph graphA; - - // and this ons is embedded - Graph graphB; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-5.0); - - auto modifier = NDArrayFactory::create_('c', {5, 5}); - modifier->assign(3.0); - - graphA.getVariableSpace()->putVariable(-1, x); - graphB.getVariableSpace()->putVariable(-2, modifier); - - // this is placeholder variable - graphB.getVariableSpace()->putVariable(-1, new Variable(true)); - - // abs, result is 5 - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - // 1-, result -4 - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 2, {1}, {3}); - - // graph should return 12: abs(3.0 x -4) - auto nodeA2 = new Node(OpType_GRAPH, -1, 3, {2}, {4}); - - // 1 - 12 = -11 - auto nodeA3 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 4, {3}, {}); - - nodeA2->setGraph(&graphB); - - graphA.addNode(nodeA0); - graphA.addNode(nodeA1); - graphA.addNode(nodeA2); - graphA.addNode(nodeA3); - - // this is going to be PWT - auto nodeB0 = new Node(OpType_PAIRWISE, pairwise::Multiply, 1, {-1, -2}, {2}); - auto nodeB1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {}); - - graphB.addNode(nodeB0); - graphB.addNode(nodeB1); - - graphB.buildGraph(); - graphA.buildGraph(); - - ASSERT_EQ(0, nodeA0->getLayer()); - ASSERT_EQ(1, nodeA1->getLayer()); - ASSERT_EQ(2, nodeA2->getLayer()); - ASSERT_EQ(3, nodeA3->getLayer()); - - ASSERT_EQ(0, nodeB0->getLayer()); - ASSERT_EQ(1, nodeB1->getLayer()); - - Nd4jStatus status = GraphExecutioner::execute(&graphA); - ASSERT_EQ(ND4J_STATUS_OK, status); - - float m = graphA.getVariableSpace()->getVariable(4)->getNDArray()->meanNumber().e(0); - - //nd4j_printf("OpResult: %f\n", m); - - ASSERT_NEAR(-11.0, m, 1e-5); -} - -// test for symbolic lookup -TEST_F(GraphTests, TestGraphInGraph_2) { - // this one is external graph - Graph graphA; - - // and this ons is embedded - Graph graphB; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-5.0); - - auto modifier = NDArrayFactory::create_('c', {5, 5}); - modifier->assign(3.0); - - std::string nameA1("_nodeA1"); - - graphA.getVariableSpace()->putVariable(-1, x); - graphB.getVariableSpace()->putVariable(-2, modifier); - - // this is placeholder variable - auto placeHolder = new Variable(true); - placeHolder->setName(nameA1); - graphB.getVariableSpace()->putVariable(-1, placeHolder); - - // abs, result is 5 - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - // 1-, result -4 - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 2, {1}, {3}); - nodeA1->setName(nameA1); - - // graph should return 12: abs(3.0 x -4) - auto nodeA2 = new Node(OpType_GRAPH, -1, 3, {2}, {4}); - - // 1 - 12 = -11 - auto nodeA3 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 4, {3}, {}); - - nodeA2->setGraph(&graphB); - - graphA.addNode(nodeA0); - graphA.addNode(nodeA1); - graphA.addNode(nodeA2); - graphA.addNode(nodeA3); - - // this is going to be PWT - auto nodeB0 = new Node(OpType_PAIRWISE, pairwise::Multiply, 1, {-1, -2}, {2}); - auto nodeB1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {}); - - graphB.addNode(nodeB0); - graphB.addNode(nodeB1); - - graphB.buildGraph(); - graphA.buildGraph(); - - ASSERT_EQ(0, nodeA0->getLayer()); - ASSERT_EQ(1, nodeA1->getLayer()); - ASSERT_EQ(2, nodeA2->getLayer()); - ASSERT_EQ(3, nodeA3->getLayer()); - - ASSERT_EQ(0, nodeB0->getLayer()); - ASSERT_EQ(1, nodeB1->getLayer()); - - Nd4jStatus status = GraphExecutioner::execute(&graphA); - ASSERT_EQ(ND4J_STATUS_OK, status); - - float m = graphA.getVariableSpace()->getVariable(4)->getNDArray()->meanNumber().e(0); - - //nd4j_printf("OpResult: %f\n", m); - - ASSERT_NEAR(-11.0, m, 1e-5); } #if 0 diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp index 3206ba513995..baac6bbc154a 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp @@ -77,8 +77,8 @@ TEST_F(GraphTests2, test_execution_1) { Node a("multiply_node", sd::ops::multiply()); Node b("add_node", sd::ops::add()); - graph.addNode(a, std::vector{"A", "B"}); - graph.addNode(b, std::vector{"multiply_node", "C"}); + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"multiply_node", "C"}); auto result = graph.execute({}, {"add_node"}); ASSERT_EQ(1, result.size()); @@ -90,7 +90,8 @@ TEST_F(GraphTests2, test_placeholder_resolution_1) { graph.addPlaceholder("input", DataType::FLOAT32); - graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); + Node node("tanh_node", sd::ops::tanh()); + graph.addNode(node, {"input"}); // this test must throw an exception, because input isn't resolved yet ASSERT_ANY_THROW(graph.execute()); @@ -114,7 +115,8 @@ TEST_F(GraphTests2, test_output_resolution_1) { graph.addPlaceholder("input", DataType::FLOAT32); - graph.addNode(Node ("tanh", "tanh_node", 10, {{"input"}})); + Node node("tanh_node", sd::ops::tanh()); + graph.addNode(node, {"input"}); // since we're requesting output of non-existent node - we expect exception ASSERT_THROW(graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"pow_node"}), graph::unresolved_output_exception); diff --git a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp index 0c6410aea22a..284313d081bf 100644 --- a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp @@ -397,264 +397,4 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) { ASSERT_TRUE(exp.equalsTo(z)); -} - -TEST_F(ListOperationsTests, GraphTests_Sequential_1) { - Graph graph; - - auto matrix = NDArrayFactory::create_('c', {3, 3}); - auto tads = matrix->allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - tads.at(e)->assign((float) (e+1)); - } - - - auto exp = NDArrayFactory::create('c', {3, 3}); - auto tadsExp = exp.allTensorsAlongDimension({1}); - tadsExp.at(0)->assign(0.f); - tadsExp.at(1)->assign(-1.f); - tadsExp.at(2)->assign(-2.f); - - auto indices = NDArrayFactory::valueOf({3}, 1, 'c'); - //indices->linspace(0); - - - auto variableSpace = graph.getVariableSpace(); - variableSpace->putVariable(-1, matrix); - variableSpace->putVariable(-2, indices); - - - auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}); - - // creating list - sd::ops::create_list opB; - auto nodeB = new Node(&opB, 2, {1},{},{}, 0.0f, {}, {0, 1}); - //nodeB->setCustomOp(&opB); - - // filling list with matrix - sd::ops::split_list opC; - auto nodeC = new Node(&opC, 3, {2, 1, -2}); - //nodeC->setCustomOp(&opC); - - // reading chunks from List. We're adding op number 3 in inputs, to ensure graph will execute this node after split - sd::ops::read_list opD; - auto nodeD0 = new Node(&opD, 5, {2, 3}, {},{}, 0.0f, {}, {0}); - auto nodeD1 = new Node(&opD, 6, {2, 3}, {},{}, 0.0f, {}, {1}); - auto nodeD2 = new Node(&opD, 7, {2, 3}, {},{}, 0.0f, {}, {2}); - //nodeD0->setCustomOp(&opD); - //nodeD1->setCustomOp(&opD); - //nodeD2->setCustomOp(&opD); - - // using OneMinus on each chunk separately - auto nodeE0 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 10, {5}); - auto nodeE1 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 11, {6}); - auto nodeE2 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 12, {7}); - - // writing chunks back to the List - sd::ops::write_list opF; - auto nodeF0 = new Node(&opF, 15, {2, 10}, {},{}, 0.0f, {}, {0}); - auto nodeF1 = new Node(&opF, 16, {2, 11}, {},{}, 0.0f, {}, {1}); - auto nodeF2 = new Node(&opF, 17, {2, 12}, {},{}, 0.0f, {}, {2}); - -// nodeF0->setCustomOp(&opF); -// nodeF1->setCustomOp(&opF); -// nodeF2->setCustomOp(&opF); - - // now we're stacking chunks back to matrix state - sd::ops::stack_list opG; - auto nodeG = new Node(&opG, 20, {2, 15, 16, 17}); - //auto nodeG = new Node(OpType_CUSTOM, 0, 20, {2}); - -// nodeG->setCustomOp(&opG); - - - graph.addNode(nodeA); - graph.addNode(nodeB); - graph.addNode(nodeC); - graph.addNode(nodeD0); - graph.addNode(nodeD1); - graph.addNode(nodeD2); - graph.addNode(nodeE0); - graph.addNode(nodeE1); - graph.addNode(nodeE2); - - graph.addNode(nodeF0); - graph.addNode(nodeF1); - graph.addNode(nodeF2); - - graph.addNode(nodeG); - - // let's also validate structural integrity - graph.buildGraph(); - - ASSERT_EQ(0, nodeA->getLayer()); - ASSERT_EQ(1, nodeB->getLayer()); - ASSERT_EQ(2, nodeC->getLayer()); - - ASSERT_EQ(3, nodeD0->getLayer()); - ASSERT_EQ(3, nodeD1->getLayer()); - ASSERT_EQ(3, nodeD2->getLayer()); - - ASSERT_EQ(4, nodeE0->getLayer()); - ASSERT_EQ(4, nodeE1->getLayer()); - ASSERT_EQ(4, nodeE2->getLayer()); - - ASSERT_EQ(5, nodeF0->getLayer()); - ASSERT_EQ(5, nodeF1->getLayer()); - ASSERT_EQ(5, nodeF2->getLayer()); - - ASSERT_EQ(6, nodeG->getLayer()); - - auto result = GraphExecutioner::execute(&graph); - ASSERT_EQ(ND4J_STATUS_OK, result); - - ASSERT_TRUE(variableSpace->hasVariable(2)); - auto list = variableSpace->getVariable(2)->getNDArrayList(); - - ASSERT_TRUE(list != nullptr); - - ASSERT_EQ(3, list->height()); - ASSERT_EQ(3, list->elements()); - - - ASSERT_TRUE(variableSpace->hasVariable(20)); - - auto stack = variableSpace->getVariable(20)->getNDArray(); - - ASSERT_TRUE(stack != nullptr); - - ASSERT_TRUE(exp.isSameShape(stack)); - ASSERT_TRUE(exp.equalsTo(stack)); -} - - -TEST_F(ListOperationsTests, GraphTests_Sequential_2) { - Graph graph; - - auto scalar = NDArrayFactory::create_(0.0f); - auto matrix = NDArrayFactory::create_('c', {3, 3}); - auto tads = matrix->allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - tads.at(e)->assign((float) (e+1)); - } - - auto exp = NDArrayFactory::create('c', {3, 3}); - auto tadsExp = exp.allTensorsAlongDimension({1}); - tadsExp.at(0)->assign(0.f); - tadsExp.at(1)->assign(-1.f); - tadsExp.at(2)->assign(-2.f); - - //auto indices = NDArray::valueOf({1, 3}, 1.0f, 'c'); - auto indices = NDArrayFactory::create_('c', {1, 3}); - indices->linspace(0); - - - auto variableSpace = graph.getVariableSpace(); - variableSpace->putVariable(-1, matrix); - variableSpace->putVariable(-2, indices); - variableSpace->putVariable(-3, scalar); - - - auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}); - - // creating list - sd::ops::create_list opB; - auto nodeB = new Node(&opB, 2, {1},{},{}, 0.0f, {}, {0, 1}); -// nodeB->setCustomOp(&opB); - - // filling list with matrix - sd::ops::scatter_list opC; - auto nodeC = new Node(&opC, 3, {2, -2, 1, -3}); - - //nodeC->setCustomOp(&opC); - - sd::ops::read_list opD; - auto nodeD0 = new Node(&opD, 5, {2, 3}, {},{}, 0.0f, {}, {0}); - auto nodeD1 = new Node(&opD, 6, {2, 3, 15}, {},{}, 0.0f, {}, {1}); - auto nodeD2 = new Node(&opD, 7, {2, 3, 16}, {},{}, 0.0f, {}, {2}); - -// nodeD0->setCustomOp(&opD); -// nodeD1->setCustomOp(&opD); -// nodeD2->setCustomOp(&opD); - - - // using OneMinus on each chunk separately - auto nodeE0 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 10, {5}); - auto nodeE1 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 11, {6}); - auto nodeE2 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 12, {7}); - - // writing chunks back to the List - sd::ops::write_list opF; - auto nodeF0 = new Node(&opF, 15, {2, 10}, {},{}, 0.0f, {}, {0}); - auto nodeF1 = new Node(&opF, 16, {2, 11}, {},{}, 0.0f, {}, {1}); - auto nodeF2 = new Node(&opF, 17, {2, 12}, {},{}, 0.0f, {}, {2}); - -// nodeF0->setCustomOp(&opF); -// nodeF1->setCustomOp(&opF); -// nodeF2->setCustomOp(&opF); - - // now we're gathering chunks back to matrix state - sd::ops::pick_list opG; - auto nodeG = new Node(&opG, 20, {2, -2, 15, 16, 17}); - //auto nodeG = new Node(OpType_CUSTOM, 0, 20, {2}); - - //nodeG->setCustomOp(&opG); - - graph.addNode(nodeA); - graph.addNode(nodeB); - graph.addNode(nodeC); - graph.addNode(nodeD0); - graph.addNode(nodeD1); - graph.addNode(nodeD2); - graph.addNode(nodeE0); - graph.addNode(nodeE1); - graph.addNode(nodeE2); - - graph.addNode(nodeF0); - graph.addNode(nodeF1); - graph.addNode(nodeF2); - - graph.addNode(nodeG); - - // let's also validate structural integrity - graph.buildGraph(); - - ASSERT_EQ(0, nodeA->getLayer()); - ASSERT_EQ(1, nodeB->getLayer()); - ASSERT_EQ(2, nodeC->getLayer()); - - ASSERT_EQ(3, nodeD0->getLayer()); - ASSERT_EQ(4, nodeE0->getLayer()); - ASSERT_EQ(5, nodeF0->getLayer()); - - ASSERT_EQ(6, nodeD1->getLayer()); - ASSERT_EQ(7, nodeE1->getLayer()); - ASSERT_EQ(8, nodeF1->getLayer()); - - ASSERT_EQ(9, nodeD2->getLayer()); - ASSERT_EQ(10, nodeE2->getLayer()); - ASSERT_EQ(11, nodeF2->getLayer()); - - ASSERT_EQ(12, nodeG->getLayer()); - - - auto result = GraphExecutioner::execute(&graph); - ASSERT_EQ(ND4J_STATUS_OK, result); - - ASSERT_TRUE(variableSpace->hasVariable(2)); - auto list = variableSpace->getVariable(2)->getNDArrayList(); - - ASSERT_TRUE(list != nullptr); - - ASSERT_EQ(3, list->height()); - ASSERT_EQ(3, list->elements()); - - ASSERT_TRUE(variableSpace->hasVariable(20)); - - auto stack = variableSpace->getVariable(20)->getNDArray(); - - ASSERT_TRUE(stack != nullptr); - - ASSERT_TRUE(exp.isSameShape(stack)); - ASSERT_TRUE(exp.equalsTo(stack)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp index 5c6d5d652413..71cf5eb622da 100644 --- a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp @@ -65,7 +65,7 @@ TEST_F(NodeTests, Test_Dtype_Conversion_2) { // ASSERT_EQ(nodeA->getOpClass(), nf->getOpClass()); ASSERT_EQ(nodeA->opType(), nf->opType()); ASSERT_EQ(nodeA->opNum(), nf->opNum()); - ASSERT_EQ(nodeA->getCustomOp()->getOpHash(), nf->getCustomOp()->getOpHash()); + ASSERT_EQ(nodeA->customOp()->getOpHash(), nf->customOp()->getOpHash()); delete nodeA; delete nd; diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index 34ebcfd9a026..20b7780067ef 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -23,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -41,10 +40,7 @@ TEST_F(OneOffTests, test_avg_pool_3d_1) { ASSERT_TRUE(graph != nullptr); - // graph->printOut(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); + graph->execute(); delete graph; } @@ -53,10 +49,7 @@ TEST_F(OneOffTests, test_avg_pool_3d_2) { ASSERT_TRUE(graph != nullptr); - // graph->printOut(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); + graph->execute(); delete graph; } @@ -65,10 +58,7 @@ TEST_F(OneOffTests, test_non2d_0A_1) { ASSERT_TRUE(graph != nullptr); - // graph->printOut(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); + graph->execute(); delete graph; } @@ -96,10 +86,7 @@ TEST_F(OneOffTests, test_assert_scalar_float32_2) { ASSERT_TRUE(graph != nullptr); - // graph->printOut(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); + graph->execute(); delete graph; } @@ -110,18 +97,13 @@ TEST_F(OneOffTests, test_pad_1D_1) { ASSERT_TRUE(graph != nullptr); - // graph->printOut(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); + graph->execute(); ASSERT_TRUE(graph->getVariableSpace()->hasVariable(4)); auto z = graph->getVariableSpace()->getVariable(4)->getNDArray(); ASSERT_TRUE(z != nullptr); - // z->printIndexedBuffer("z"); - ASSERT_EQ(e, *z); delete graph; } @@ -161,10 +143,7 @@ TEST_F(OneOffTests, test_conv2d_nhwc_failed_1) { auto graph = Graph::fromFlatBuffers("./resources/channels_last_b1_k2_s1_d1_SAME_crelu.fb"); ASSERT_TRUE(graph != nullptr); - // graph->printOut(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); + graph->execute(); ASSERT_TRUE(graph->getVariableSpace()->hasVariable(9)); @@ -182,10 +161,8 @@ TEST_F(OneOffTests, test_tensor_array_1) { auto graph = Graph::fromFlatBuffers("./resources/tensor_array_close_sz1_float32_nodynamic_noname_noshape.fb"); ASSERT_TRUE(graph != nullptr); - // graph->printOut(); + graph->execute(); - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); ASSERT_TRUE(graph->getVariableSpace()->hasVariable(5)); auto z = graph->getVariableSpace()->getVariable(5)->getNDArray(); @@ -202,10 +179,8 @@ TEST_F(OneOffTests, test_tensor_array_2) { auto graph = Graph::fromFlatBuffers("./resources/tensor_array_split_sz1_float32_nodynamic_noname_noshape.fb"); ASSERT_TRUE(graph != nullptr); - // graph->printOut(); + graph->execute(); - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); auto z = graph->getVariableSpace()->getVariable(6)->getNDArray(); @@ -222,11 +197,8 @@ TEST_F(OneOffTests, test_tensor_array_3) { auto graph = Graph::fromFlatBuffers("./resources/tensor_array_stack_sz3-1_int32_dynamic_name_shape.fb"); ASSERT_TRUE(graph != nullptr); - // graph->printOut(); - + graph->execute(); - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); ASSERT_TRUE(graph->getVariableSpace()->hasVariable(15)); auto z = graph->getVariableSpace()->getVariable(15)->getNDArray(); @@ -243,11 +215,8 @@ TEST_F(OneOffTests, test_tensor_array_4) { auto graph = Graph::fromFlatBuffers("./resources/tensor_array_unstack_sz1_int64_nodynamic_noname_shape2-3.fb"); ASSERT_TRUE(graph != nullptr); - // graph->printOut(); + graph->execute(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); ASSERT_TRUE(graph->getVariableSpace()->hasVariable(11)); auto z = graph->getVariableSpace()->getVariable(11)->getNDArray(); @@ -264,11 +233,8 @@ TEST_F(OneOffTests, test_assert_4) { auto graph = Graph::fromFlatBuffers("./resources/assert_type_rank2_int64.fb"); ASSERT_TRUE(graph != nullptr); - // graph->printOut(); + graph->execute(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); auto z = graph->getVariableSpace()->getVariable(1)->getNDArray(); @@ -335,11 +301,8 @@ TEST_F(OneOffTests, test_identity_n_2) { auto graph = Graph::fromFlatBuffers("./resources/identity_n_2.fb"); ASSERT_TRUE(graph != nullptr); - // graph->printOut(); - + graph->execute(); - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1, 1)); @@ -357,10 +320,7 @@ TEST_F(OneOffTests, test_non2d_1) { auto graph = Graph::fromFlatBuffers("./resources/non2d_1.fb"); ASSERT_TRUE(graph != nullptr); - // graph->printOut(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); + graph->execute(); ASSERT_TRUE(graph->getVariableSpace()->hasVariable(3)); @@ -379,10 +339,7 @@ TEST_F(OneOffTests, test_reduce_all_1) { auto graph = Graph::fromFlatBuffers("./resources/reduce_all_rank2_d0_keep.fb"); ASSERT_TRUE(graph != nullptr); - // graph->printOut(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); + graph->execute(); ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); diff --git a/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp b/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp index e2927aa0b205..520ed0134b7a 100644 --- a/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp @@ -20,7 +20,6 @@ #include "testlayers.h" -#include /* diff --git a/libnd4j/tests_cpu/layers_tests/ScopeTests.cpp b/libnd4j/tests_cpu/layers_tests/ScopeTests.cpp deleted file mode 100644 index 6c83e869e91d..000000000000 --- a/libnd4j/tests_cpu/layers_tests/ScopeTests.cpp +++ /dev/null @@ -1,165 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 15.10.2017. -// - -#include "testlayers.h" -#include -#include -#include - -using namespace sd; -using namespace sd::graph; - -class ScopeTests : public testing::Test { -public: - -}; - -TEST_F(ScopeTests, BasicTests_1) { - Graph graph; - - auto x = NDArrayFactory::create_('c', {2, 2}); - x->assign(0.0f); - - auto variableSpace = graph.getVariableSpace(); - variableSpace->putVariable(-1, x); - - sd::ops::Scope opScope; - - auto scopeBody = new Node(OpType_LOGIC, 10, 1); - scopeBody->setName("scopeBody"); - scopeBody->setCustomOp(&opScope); - - - graph.addNode(scopeBody); - - ASSERT_EQ(1, graph.totalNodes()); - - auto scopedB0 = new Node(OpType_SCALAR, 0, 6, {-1}, {}, {}, 1.0f); - scopedB0->markInplace(true); - scopedB0->setScopeInfo(1, "scopeBody"); - - graph.addNode(scopedB0); - - ASSERT_EQ(1, graph.totalNodes()); - -} -/* -TEST_F(ScopeTests, RealTests_1) { - Graph graph; - - auto x = NDArrayFactory::create_('c', {2, 2}); - x->assign(0.0f); - - auto y = NDArrayFactory::create_('c', {2, 2}); - y->assign(0.0); - -// auto scalar = NDArrayFactory::create_('c', {1, 1}); - auto scalar = NDArrayFactory::create_(10.f); - //scalar->p(0, 10); - - auto variableSpace = graph.getVariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - variableSpace->putVariable(-3, scalar); - - // just few ops coming before while - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 1, {-1}); - auto nodeB = new Node(OpType_SCALAR, scalar::Add, 2, {1}, {}, {}, 1.0); - - // - auto scopeCondition = new Node(OpType_LOGIC, logic::Scope, 3); - scopeCondition->setName("scopeCondition"); - sd::ops::Scope opScope; - scopeCondition->setCustomOp(&opScope); - - // this is scope of the body, it'll be executed multiple times - auto scopeBody = new Node(OpType_LOGIC, logic::Scope, 10); - scopeBody->setName("scopeBody"); - scopeBody->setCustomOp(&opScope); - -//////////////////////////////////////////////////////////////////////////////////////////////////// -//// filling out condition scope -//////////////////////////////////////////////////////////////////////////////////////////////////// - // this is Sum accumulation, which feed - auto scopedA0 = new Node(OpType_REDUCE_SAME, reduce::Sum, 4, {12}); - scopedA0->setScopeInfo(3, "scopeCondition"); - - // this op compares LT A0 result with variable `scalar` which is 10; - sd::ops::lt_scalar op; - auto scopedA1 = new Node(&op, 5, {4, -3}); - scopedA1->setScopeInfo(3, "scopeCondition"); - - -//////////////////////////////////////////////////////////////////////////////////////////////////// -//// filling out body scope -//////////////////////////////////////////////////////////////////////////////////////////////////// - auto scopedB0 = new Node(OpType_SCALAR, scalar::Add, 6, {12}, {}, {}, 1.0f); - scopedB0->markInplace(false); - scopedB0->setScopeInfo(10, "scopeBody"); - - auto nodeReturn = new Node(OpType_LOGIC, logic::Return, 7, {6}, {12}); - sd::ops::Return opReturn; - nodeReturn->setCustomOp(&opReturn); - nodeReturn->setScopeInfo(10, "scopeBody"); - - // WHILE operations takes 2 scopes - :0 is condition scope, and :1 is loop body scope - auto nodeWhile = new Node(OpType_LOGIC, logic::While, 12, {-2, 3, 10}); - sd::ops::While opWhile; - nodeWhile->setCustomOp(&opWhile); - - // adding root nodes first, nothing unusual expected here - graph.addNode(nodeA); - graph.addNode(nodeB); - - // now we're registering our scopes - graph.addNode(scopeCondition); - graph.addNode(scopeBody); - - // at this moment graph should have 4 (four) nodes registered - ASSERT_EQ(4, graph.totalNodes()); - - // adding node that's attached to some scope. so it should be pushed to specific scope - graph.addNode(scopedA0); - - // we should still have 4 ops in graph, because node added above - goes directly into the scope - // thus falls out of the graph direct execution - it can be executed only via Scope - ASSERT_EQ(4, graph.totalNodes()); - - graph.addNode(scopedA1); - graph.addNode(scopedB0); - graph.addNode(nodeReturn); - - // should be still 4. no options here. - ASSERT_EQ(4, graph.totalNodes()); - - // WHILE is valid node, so we expect nodes counter to go up - graph.addNode(nodeWhile); - ASSERT_EQ(5, graph.totalNodes()); - - // now, let's try to execute graph - Nd4jStatus status = GraphExecutioner::execute(&graph); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto w = variableSpace->getVariable(12, 0)->getNDArray(); - - w->printShapeInfo("w shape"); - ASSERT_NEAR(12.f, w->sumNumber().e(0), 1e-5f); -} -*/ \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp b/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp index 7e7fb03737b4..570f2c0c6aaf 100644 --- a/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp @@ -19,7 +19,6 @@ // #include "testlayers.h" -#include #include #include diff --git a/libnd4j/tests_cpu/layers_tests/SwitchTests.cpp b/libnd4j/tests_cpu/layers_tests/SwitchTests.cpp deleted file mode 100644 index 8d6a8d180e63..000000000000 --- a/libnd4j/tests_cpu/layers_tests/SwitchTests.cpp +++ /dev/null @@ -1,251 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 13.10.2017. -// - -#include "testlayers.h" -#include - -using namespace sd; -using namespace sd::ops; -using namespace sd::graph; - -class SwitchTests : public testing::Test { -public: - -}; - -TEST_F(SwitchTests, SwitchTest1) { - Graph graph; - - FlowPath flowPath; - - auto variableSpace = graph.getVariableSpace(); - variableSpace->setFlowPath(&flowPath); - - auto input = NDArrayFactory::create_('c',{32, 100}); - input->assign(-119.0f); - - auto condtionX = NDArrayFactory::create_('c', {1, 1}); - condtionX->p(0, 0.0f); - auto condtionY = NDArrayFactory::create_('c', {1, 1}); - condtionY->p(0, 0.0f); - - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, condtionX); - variableSpace->putVariable(-3, condtionY); - - // this is just 2 ops, that are executed sequentially. We don't really care bout them - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); - - // this is our condition op, we'll be using Equals condition, on variables conditionX and conditionY (ids -2 and -3 respectively) - // we're creating this op manually in tests, as always. - sd::ops::eq_scalar eqOp; - auto nodeCondition = new Node(&eqOp, 119, {-2, -3}); - //nodeCondition->setOpType(OpType_BOOLEAN); - - // now, this is Switch operation. It takes BooleanOperation operation in, - // and based on evaluation result (true/false) - it'll pass data via :0 or :1 output - // other idx will be considered disabled, and that graph branch won't be executed - sd::ops::Switch switchOp; - auto nodeSwitch = new Node(&switchOp, 3, {2, 119}, {4, 5}); - - // these 2 ops are connected to FALSE and TRUE outputs. output :0 considered FALSE, and output :1 considered TRUE - auto nodeZ0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 4, {}, {}); - nodeZ0->pickInput(3, 0); - auto nodeZ1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 5, {}, {}); - nodeZ1->pickInput(3, 1); - - - graph.addNode(nodeA); - graph.addNode(nodeB); - graph.addNode(nodeCondition); - graph.addNode(nodeSwitch); - graph.addNode(nodeZ0); - graph.addNode(nodeZ1); - - graph.buildGraph(); - - // we're making sure nodes connected to the Switch have no other inputs in this Graph - ASSERT_EQ(1, nodeZ0->input()->size()); - ASSERT_EQ(1, nodeZ1->input()->size()); - - // just validating topo sort - ASSERT_EQ(0, nodeA->getLayer()); - ASSERT_EQ(0, nodeCondition->getLayer()); - ASSERT_EQ(1, nodeB->getLayer()); - ASSERT_EQ(2, nodeSwitch->getLayer()); - ASSERT_EQ(3, nodeZ0->getLayer()); - ASSERT_EQ(3, nodeZ1->getLayer()); - - // executing graph - Nd4jStatus status = GraphExecutioner::execute(&graph); - - ASSERT_EQ(ND4J_STATUS_OK, status); - - // nd4j_printf("Z0: [%i]; Z1: [%i]\n", flowPath.isNodeActive(nodeZ0->id()), flowPath.isNodeActive(nodeZ1->id())); - - // we know that Switch got TRUE evaluation, so :0 should be inactive - ASSERT_FALSE(flowPath.isNodeActive(nodeZ0->id())); - - // and :1 should be active - ASSERT_TRUE(flowPath.isNodeActive(nodeZ1->id())); - - std::pair unexpected(4,0); - std::pair expectedResultIndex(5,0); - ASSERT_TRUE(variableSpace->hasVariable(expectedResultIndex)); - - // getting output of nodeZ1 - auto output = variableSpace->getVariable(expectedResultIndex)->getNDArray(); - - // and veryfing it against known expected value - ASSERT_NEAR(-118.0f, output->e(0), 1e-5f); -} - -TEST_F(SwitchTests, SwitchTest2) { - Graph graph; - - FlowPath flowPath; - auto variableSpace = graph.getVariableSpace(); - variableSpace->setFlowPath(&flowPath); - - auto input = NDArrayFactory::create_('c',{32, 100}); - input->assign(-119.0f); - - auto condtionX = NDArrayFactory::create_('c', {1, 1}); - condtionX->p(0, 1.0f); - auto condtionY = NDArrayFactory::create_('c', {1, 1}); - condtionY->p(0, 1.0f); - - - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, condtionX); - variableSpace->putVariable(-3, condtionY); - - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); - - auto scopeCondition = new Node(OpType_LOGIC, logic::Scope, 3); - scopeCondition->setName("scopeCondition"); - - auto nodeCondition = new Node(OpType_LOGIC, logic::Scope, 119, {-2, -3}); - nodeCondition->setScopeInfo(3, "scopeCondition"); - - sd::ops::eq_scalar eqOp; - nodeCondition->setCustomOp(&eqOp); - - auto nodeSwitch = new Node(OpType_LOGIC, logic::Switch, 5, {3, 2}); - - sd::ops::Switch switchOp; - nodeSwitch->setCustomOp(&switchOp); - - - // these 2 ops are connected to FALSE and TRUE outputs. output :0 considered FALSE, and output :1 considered TRUE - auto nodeZ0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 6, {}, {}); - nodeZ0->pickInput(5, 0); - auto nodeZ1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 7, {}, {}); - nodeZ1->pickInput(5, 1); - - graph.addNode(nodeA); - graph.addNode(nodeB); - graph.addNode(scopeCondition); - graph.addNode(nodeCondition); - graph.addNode(nodeSwitch); - graph.addNode(nodeZ0); - graph.addNode(nodeZ1); - - Nd4jStatus status = GraphExecutioner::execute(&graph); - - ASSERT_EQ(ND4J_STATUS_OK, status); - - ASSERT_TRUE(!flowPath.isNodeActive(nodeZ0->id())); - ASSERT_TRUE(flowPath.isNodeActive(nodeZ1->id())); - - auto z = graph.getVariableSpace()->getVariable(7)->getNDArray(); - - // abs(-119) = 119; 1 - 119 = -118 - ASSERT_NEAR(-118.f, z->e(0), 1e-5); -} - -TEST_F(SwitchTests, SwitchTest3) { - Graph graph; - - FlowPath flowPath; - auto variableSpace = graph.getVariableSpace(); - variableSpace->setFlowPath(&flowPath); - - auto input = NDArrayFactory::create_('c',{32, 100}); - input->assign(-119.0f); - - auto condtionX = NDArrayFactory::create_('c', {1, 1}); - condtionX->p(0, 2.0f); - auto condtionY = NDArrayFactory::create_('c', {1, 1}); - condtionY->p(0, 1.0f); - - - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, condtionX); - variableSpace->putVariable(-3, condtionY); - - - auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); - - auto scopeCondition = new Node(OpType_LOGIC, logic::Scope, 3); - scopeCondition->setName("scopeCondition"); - - auto nodeCondition = new Node(OpType_LOGIC, logic::Scope, 119, {-2, -3}); - nodeCondition->setScopeInfo(3, "scopeCondition"); - - sd::ops::eq_scalar eqOp; - nodeCondition->setCustomOp(&eqOp); - - auto nodeSwitch = new Node(OpType_LOGIC, logic::Switch, 5, {3, 2}); - - sd::ops::Switch switchOp; - nodeSwitch->setCustomOp(&switchOp); - - - // these 2 ops are connected to FALSE and TRUE outputs. output :0 considered FALSE, and output :1 considered TRUE - auto nodeZ0 = new Node(OpType_TRANSFORM_SAME, transform::Neg, 6, {}, {}); - nodeZ0->pickInput(5, 0); - auto nodeZ1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 7, {}, {}); - nodeZ1->pickInput(5, 1); - - graph.addNode(nodeA); - graph.addNode(nodeB); - graph.addNode(scopeCondition); - graph.addNode(nodeCondition); - graph.addNode(nodeSwitch); - graph.addNode(nodeZ0); - graph.addNode(nodeZ1); - - Nd4jStatus status = GraphExecutioner::execute(&graph); - - ASSERT_EQ(ND4J_STATUS_OK, status); - - ASSERT_TRUE(flowPath.isNodeActive(nodeZ0->id())); - ASSERT_TRUE(!flowPath.isNodeActive(nodeZ1->id())); - - auto z = graph.getVariableSpace()->getVariable(6)->getNDArray(); - - // abs(-119) = 119; Neg(119) = 119 - ASSERT_NEAR(-119.f, z->e(0), 1e-5); -} diff --git a/libnd4j/tests_cpu/layers_tests/testlayers.h b/libnd4j/tests_cpu/layers_tests/testlayers.h index d63694dac94f..697e61693e2e 100644 --- a/libnd4j/tests_cpu/layers_tests/testlayers.h +++ b/libnd4j/tests_cpu/layers_tests/testlayers.h @@ -33,7 +33,6 @@ #include #include #include -#include #include #include #include From 916a3599f8cc5503fdea8f8220a7d9b6a5d60548 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 31 Mar 2020 21:54:17 +0300 Subject: [PATCH 054/233] getting rid of legacy stuff Signed-off-by: raver119 --- libnd4j/include/graph/Context.h | 8 +- libnd4j/include/graph/ContextPrototype.h | 53 +++--- libnd4j/include/graph/Node.h | 14 +- .../graph/execution/impl/GraphExecutor.cpp | 2 +- libnd4j/include/graph/impl/Context.cpp | 50 +++--- .../include/graph/impl/ContextPrototype.cpp | 94 ++++++++--- libnd4j/include/graph/impl/Graph.cpp | 8 +- libnd4j/include/graph/impl/Node.cpp | 157 +++++++----------- .../include/graph/logic/impl/LogicExit.cpp | 2 +- .../graph/logic/impl/LogicLoopCond.cpp | 2 +- libnd4j/include/helpers/ShapeUtils.h | 10 +- .../helpers/benchmark/DeclarableBenchmark.h | 24 +-- libnd4j/include/helpers/impl/ShapeUtils.cpp | 12 +- libnd4j/include/legacy/cpu/NativeOps.cpp | 12 +- .../ops/declarable/generic/blas/axpy.cpp | 4 +- .../ops/declarable/generic/blas/matmul.cpp | 6 +- .../generic/broadcastable/meshgrid.cpp | 4 +- .../generic/broadcastable/percentile.cpp | 14 +- .../declarable/generic/images/adjust_hue.cpp | 2 +- .../generic/images/adjust_saturation.cpp | 2 +- .../declarable/generic/images/hsvToRgb.cpp | 2 +- .../generic/images/image_resize.cpp | 4 +- .../generic/images/resize_bicubic.cpp | 4 +- .../declarable/generic/images/rgbToGrs.cpp | 4 +- .../declarable/generic/images/rgbToHsv.cpp | 2 +- .../declarable/generic/images/rgbToYiq.cpp | 2 +- .../declarable/generic/images/rgbToYuv.cpp | 2 +- .../declarable/generic/images/yiqToRgb.cpp | 2 +- .../declarable/generic/images/yuvToRgb.cpp | 2 +- .../ops/declarable/generic/linalg/eye.cpp | 4 +- .../ops/declarable/generic/linalg/lup.cpp | 4 +- .../ops/declarable/generic/linalg/moments.cpp | 8 +- .../ops/declarable/generic/linalg/qr.cpp | 5 +- .../ops/declarable/generic/linalg/triu.cpp | 4 +- .../ops/declarable/generic/list/pick_list.cpp | 4 +- .../declarable/generic/list/write_list.cpp | 2 +- .../loss/softmaxCrossEntropyWithLogits.cpp | 6 +- .../generic/nn/activations/prelu.cpp | 4 +- .../generic/nn/activations/relu.cpp | 2 +- .../nn/activations/thresholdedrelu.cpp | 4 +- .../ops/declarable/generic/nn/apply_sgd.cpp | 2 +- .../ops/declarable/generic/nn/batchnorm.cpp | 4 +- .../ops/declarable/generic/nn/bias_add.cpp | 4 +- .../declarable/generic/nn/convo/conv1d.cpp | 16 +- .../declarable/generic/nn/convo/conv2d.cpp | 24 +-- .../declarable/generic/nn/convo/conv3d.cpp | 16 +- .../declarable/generic/nn/convo/deconv2d.cpp | 16 +- .../generic/nn/convo/deconv2d_tf.cpp | 8 +- .../declarable/generic/nn/convo/deconv3d.cpp | 16 +- .../generic/nn/convo/depthwiseConv2d.cpp | 16 +- .../declarable/generic/nn/convo/im2col.cpp | 4 +- .../ops/declarable/generic/nn/convo/ismax.cpp | 2 +- .../generic/nn/convo/pointwiseConv2d.cpp | 8 +- .../declarable/generic/nn/convo/sconv2d.cpp | 16 +- .../generic/nn/convo/upsampling2d.cpp | 6 +- .../generic/nn/convo/upsampling3d.cpp | 6 +- .../declarable/generic/nn/fusedBatchNorm.cpp | 2 +- .../ops/declarable/generic/nn/layer_norm.cpp | 8 +- .../ops/declarable/generic/nn/logSoftmax.cpp | 4 +- .../generic/nn/pooling/avgpool2d.cpp | 8 +- .../generic/nn/pooling/avgpool3d.cpp | 6 +- .../generic/nn/pooling/maxpool2d.cpp | 6 +- .../generic/nn/pooling/maxpool3d.cpp | 6 +- .../nn/pooling/maxpool_with_argmax.cpp | 2 +- .../generic/nn/pooling/pnormpool2d.cpp | 8 +- .../nn/recurrent/dynamicBidirectionalRNN.cpp | 4 +- .../generic/nn/recurrent/dynamicRNN.cpp | 4 +- .../ops/declarable/generic/nn/relu_layer.cpp | 2 +- .../ops/declarable/generic/nn/softmax.cpp | 4 +- .../generic/parity_ops/confusion_matrix.cpp | 2 +- .../fake_quant_with_min_max_vars.cpp | 8 +- ...ke_quant_with_min_max_vars_per_channel.cpp | 4 +- .../parity_ops/non_max_suppression.cpp | 24 +-- .../non_max_suppression_overlaps.cpp | 8 +- .../generic/parity_ops/normalize_moments.cpp | 2 +- .../generic/parity_ops/nth_element.cpp | 2 +- .../declarable/generic/parity_ops/roll.cpp | 2 +- .../generic/parity_ops/sequence_mask.cpp | 2 +- .../ops/declarable/generic/random/gamma.cpp | 2 +- .../declarable/generic/random/multinomial.cpp | 4 +- .../ops/declarable/generic/random/poisson.cpp | 2 +- .../declarable/generic/random/random_crop.cpp | 2 +- .../declarable/generic/random/set_seed.cpp | 2 +- .../ops/declarable/generic/random/uniform.cpp | 4 +- .../ops/declarable/generic/reduce/argmax.cpp | 4 +- .../ops/declarable/generic/reduce/argmin.cpp | 4 +- .../ops/declarable/generic/reduce/norm.cpp | 6 +- .../declarable/generic/reduce/reduceMean.cpp | 20 +-- .../declarable/generic/reduce/reduceStDev.cpp | 28 ++-- .../generic/reduce/reduceVariance.cpp | 28 ++-- .../declarable/generic/reduce/reduce_dot.cpp | 12 +- .../generic/reduce/reduce_logsumexp.cpp | 12 +- .../declarable/generic/reduce/reduce_max.cpp | 16 +- .../declarable/generic/reduce/reduce_min.cpp | 16 +- .../generic/reduce/reduce_norm1.cpp | 24 +-- .../generic/reduce/reduce_norm2.cpp | 24 +-- .../generic/reduce/reduce_norm_max.cpp | 20 +-- .../declarable/generic/reduce/reduce_prod.cpp | 24 +-- .../generic/reduce/reduce_sqnorm.cpp | 20 +-- .../declarable/generic/reduce/reduce_sum.cpp | 24 +-- .../ops/declarable/generic/shape/permute.cpp | 8 +- .../ops/declarable/generic/shape/reshape.cpp | 34 ++-- .../generic/shape/tile_to_shape.cpp | 4 +- .../declarable/generic/shape/transpose.cpp | 8 +- .../ops/declarable/generic/tensor/range.cpp | 8 +- .../generic/tensor/strided_slice.cpp | 16 +- .../generic/thrid_party/firas_sparse.cpp | 4 +- .../transforms/clip_by_averaged_norm.cpp | 2 +- .../generic/transforms/clip_by_norm.cpp | 4 +- .../declarable/generic/transforms/concat.cpp | 8 +- .../declarable/generic/transforms/cumprod.cpp | 2 +- .../declarable/generic/transforms/cumsum.cpp | 2 +- .../declarable/generic/transforms/gather.cpp | 6 +- .../generic/transforms/gatherNd.cpp | 2 +- .../transforms/histogram_fixed_width.cpp | 4 +- .../generic/transforms/merge_max_idx.cpp | 2 +- .../ops/declarable/generic/transforms/pad.cpp | 2 +- .../declarable/generic/transforms/repeat.cpp | 4 +- .../declarable/generic/transforms/reverse.cpp | 4 +- .../generic/transforms/scatter_add.cpp | 4 +- .../generic/transforms/scatter_div.cpp | 4 +- .../generic/transforms/scatter_max.cpp | 4 +- .../generic/transforms/scatter_min.cpp | 4 +- .../generic/transforms/scatter_mul.cpp | 4 +- .../generic/transforms/scatter_nd.cpp | 4 +- .../generic/transforms/scatter_nd_add.cpp | 4 +- .../generic/transforms/scatter_nd_sub.cpp | 4 +- .../generic/transforms/scatter_nd_update.cpp | 4 +- .../generic/transforms/scatter_sub.cpp | 4 +- .../generic/transforms/scatter_upd.cpp | 4 +- .../generic/transforms/scatter_update.cpp | 2 +- .../declarable/generic/transforms/slice.cpp | 17 +- .../declarable/generic/transforms/split_v.cpp | 4 +- .../declarable/generic/transforms/stack.cpp | 4 +- .../generic/transforms/standardize.cpp | 4 +- .../declarable/generic/transforms/tear.cpp | 6 +- .../declarable/generic/transforms/tile.cpp | 16 +- .../declarable/generic/tsne/symmetrized.cpp | 5 +- .../generic/updaters/adaDeltaUpdater.cpp | 2 +- .../generic/updaters/adaGradUpdater.cpp | 2 +- .../generic/updaters/adaMaxUpdater.cpp | 4 +- .../generic/updaters/adamUpdater.cpp | 4 +- .../generic/updaters/amsGradUpdater.cpp | 4 +- .../generic/updaters/nadamUpdater.cpp | 4 +- .../generic/updaters/nesterovsUpdater.cpp | 2 +- .../generic/updaters/rmsPropUpdater.cpp | 2 +- .../generic/updaters/sgdUpdater.cpp | 2 +- .../ops/declarable/impl/DeclarableListOp.cpp | 4 +- .../ops/declarable/impl/DeclarableOp.cpp | 42 ++--- .../declarable/impl/DeclarableReductionOp.cpp | 8 +- .../declarable/impl/LegacyBroadcastBoolOp.cpp | 2 +- .../ops/declarable/impl/LegacyBroadcastOp.cpp | 2 +- .../declarable/impl/LegacyIndexReduceOp.cpp | 14 +- .../impl/LegacyPairwiseTransformBoolOp.cpp | 2 +- .../impl/LegacyPairwiseTransformOp.cpp | 2 +- .../ops/declarable/impl/LegacyRandomOp.cpp | 20 +-- .../ops/declarable/impl/LegacyReduce3Op.cpp | 10 +- .../declarable/impl/LegacyReduceBoolOp.cpp | 8 +- .../declarable/impl/LegacyReduceFloatOp.cpp | 12 +- .../declarable/impl/LegacyReduceLongOp.cpp | 8 +- .../declarable/impl/LegacyReduceSameOp.cpp | 8 +- .../declarable/impl/LegacyScalarBoolOp.cpp | 4 +- .../ops/declarable/impl/LegacyScalarOp.cpp | 4 +- .../ops/declarable/impl/LegacyStatsOp.cpp | 12 +- .../declarable/impl/LegacyTransformAnyOp.cpp | 2 +- .../declarable/impl/LegacyTransformBoolOp.cpp | 2 +- .../impl/LegacyTransformFloatOp.cpp | 2 +- .../declarable/impl/LegacyTransformSameOp.cpp | 2 +- .../impl/LegacyTransformStrictOp.cpp | 2 +- .../platform/mkldnn/avgpooling2d.cpp | 4 +- .../platform/mkldnn/avgpooling3d.cpp | 4 +- .../declarable/platform/mkldnn/batchnorm.cpp | 8 +- .../ops/declarable/platform/mkldnn/conv2d.cpp | 8 +- .../ops/declarable/platform/mkldnn/conv3d.cpp | 8 +- .../declarable/platform/mkldnn/deconv2d.cpp | 8 +- .../platform/mkldnn/deconv2d_tf.cpp | 4 +- .../declarable/platform/mkldnn/deconv3d.cpp | 8 +- .../platform/mkldnn/depthwiseConv2d.cpp | 8 +- .../ops/declarable/platform/mkldnn/matmul.cpp | 2 +- .../platform/mkldnn/maxpooling2d.cpp | 4 +- .../platform/mkldnn/maxpooling3d.cpp | 4 +- .../declarable/platform/mkldnn/softmax.cpp | 4 +- libnd4j/include/system/op_boilerplate.h | 8 +- .../tests_cpu/layers_tests/ContextTests.cpp | 36 ++-- .../layers_tests/ConvolutionTests1.cpp | 66 ++++---- .../layers_tests/ConvolutionTests2.cpp | 34 ++-- .../layers_tests/DeclarableOpsTests1.cpp | 35 ++-- .../layers_tests/DeclarableOpsTests4.cpp | 12 +- .../layers_tests/DeclarableOpsTests6.cpp | 12 +- 189 files changed, 926 insertions(+), 906 deletions(-) diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index 54842e10ee6c..bbf1a1c90fcb 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -76,7 +76,7 @@ namespace sd { // special flag used during conversion from Graph exec to FastPath exec bool _forbidFastPath = false; public: - Context(ContextPrototype* prototype, VariableSpace* variableSpace, GraphMemoryManager *memoryManager = nullptr); + Context(const ContextPrototype &prototype, VariableSpace* variableSpace, GraphMemoryManager *memoryManager = nullptr); explicit Context(int nodeId, VariableSpace *variableSpace = nullptr); Context(int nodeId, VariableSpace *variableSpace, bool isInplace); @@ -90,9 +90,9 @@ namespace sd { Nd4jLong getOuterTime(); Nd4jLong getInnerTime(); - sd::DataType dataType() override; + sd::DataType dataType() const override; - sd::DataType dataType(int index) override; + sd::DataType dataType(int index) const override; void setDataType(int index, sd::DataType type) override; // these methods are related to Workspace abstraction bool hasWorkspaceProvided(); @@ -178,7 +178,7 @@ namespace sd { Variable* ensureVariable(int idx = 0); - unsigned long width() override; + unsigned long width() const override; // methods used in java interop /** diff --git a/libnd4j/include/graph/ContextPrototype.h b/libnd4j/include/graph/ContextPrototype.h index db27992a110b..028588507287 100644 --- a/libnd4j/include/graph/ContextPrototype.h +++ b/libnd4j/include/graph/ContextPrototype.h @@ -71,19 +71,19 @@ namespace sd { explicit ContextPrototype(sd::ops::OpDescriptor* opDescriptor = nullptr, int nodeId = 1, bool inPlace = false); ~ContextPrototype() = default; - int getNodeId(); - int nodeId(); + int getNodeId() const; + int nodeId() const; // this method returns true, if inputs are defined - bool hasVariablesFilled(); + bool hasVariablesFilled() const; void setOpDescriptor(sd::ops::OpDescriptor* opDescriptor); - virtual sd::DataType dataType(); - virtual sd::DataType dataType(int index); + virtual sd::DataType dataType() const; + virtual sd::DataType dataType(int index) const ; virtual void setDataType(int index, sd::DataType type); - bool isInplace(); + bool isInplace() const; void markInplace(bool reallyInplace); void pickInput(int input); @@ -91,34 +91,45 @@ namespace sd { void pickInput(std::pair& p); void fillInputs(std::initializer_list inputs); void fillInputs(std::vector& inputs); - std::vector>* inputs(); + std::vector>& inputs() const; - std::vector* getTArguments(); - std::vector* getIArguments(); - std::vector* getBArguments(); - std::vector* getDArguments(); - std::vector* getAxis(); + const std::vector& getTArguments() const; + const std::vector& getIArguments() const; + const std::vector& getBArguments() const; + const std::vector& getDArguments() const; + const std::vector& getAxis() const; - samediff::Engine engine(); + void appendI(const std::vector &value); + void appendT(const std::vector &value); + void appendB(const std::vector &value); + void appendD(const std::vector &value); - size_t numT(); - size_t numI(); - size_t numB(); - size_t numD(); + void appendA(Nd4jLong value); + void appendI(Nd4jLong value); + void appendT(double value); + void appendB(bool value); + void appendD(DataType value); - std::pair* input(int idx); + samediff::Engine engine() const; - int opNum(); + size_t numT() const; + size_t numI() const; + size_t numB() const; + size_t numD() const; + + const std::pair& input(int idx) const; + + int opNum() const; void setOpNum(int opNum); - bool isUseMKLDNN() { return _useMKLDNN; } + bool isUseMKLDNN() const { return _useMKLDNN; } void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN = useMKLDNN; } /** * This method returns number of inputs available in this block * @return */ - virtual unsigned long width(); + virtual unsigned long width() const; // just a clone ContextPrototype* clone(); diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 6b4eff3e3316..30e8a950dc5f 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -42,7 +42,7 @@ namespace sd { sd::DataType _dataType; OpType _opType; - ContextPrototype* _protoContext = nullptr; + ContextPrototype _protoContext; Nd4jLong _opNum; int _id = 0; std::vector> _input; @@ -127,7 +127,7 @@ namespace sd { bool equals(Node *other) const; sd::DataType dataType(); - ContextPrototype *protoContext() const; + const ContextPrototype& protoContext() const; OpType opType() const; Nd4jLong opNum() const; int id() const; @@ -191,8 +191,8 @@ namespace sd { int totalReferences(); void addReference(int nodeId); - void setContextPrototype(ContextPrototype *block); - ContextPrototype* getContextPrototype(); + void setContextPrototype(const ContextPrototype &block); + const ContextPrototype& contextPrototype() const; bool hasBlockAttached(); void setCustomOp(sd::ops::DeclarableOp *customOp = nullptr); @@ -228,12 +228,8 @@ namespace sd { Node* asT(); FORCEINLINE void pullValues(Node *other) { - - if (this->_protoContext != nullptr) - delete _protoContext; - this->_dataType = other->dataType(); - this->_protoContext = other->protoContext()->clone(); + this->_protoContext = other->protoContext(); this->_scalar = other->scalar(); this->_hasExternalInputs = other->hasExternalInputs(); this->_hasExternalOutputs = other->hasExternalOutputs(); diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 604720789b0a..d3b2167da245 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -25,7 +25,7 @@ namespace sd { namespace graph { Context GraphExecutor::prepareContext(ContextPrototype *contextPrototype, VariableSpace &variableSpace, const GraphMemoryManager &memoryManager) const { // TODO: maybe we'll want to do something here? - return Context(contextPrototype, &variableSpace, const_cast(&memoryManager)); + return Context(*contextPrototype, &variableSpace, const_cast(&memoryManager)); } Nd4jStatus GraphExecutor::preprocess(sd::ops::DeclarableOp *op, Context &context) const { diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index 1c0f944d7c39..cb582e2ea7fa 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -26,48 +26,44 @@ namespace sd { namespace graph { - Context::Context(ContextPrototype* prototype, VariableSpace* variableSpace, GraphMemoryManager *memoryManager) { + Context::Context(const ContextPrototype& prototype, VariableSpace* variableSpace, GraphMemoryManager *memoryManager) { _memoryManager = memoryManager; _variableSpace = variableSpace; - _dataType = prototype->dataType(); + _dataType = prototype.dataType(); - if (prototype != nullptr) { - for (const auto &v: *(prototype->inputs())) { - this->_inputs.push_back(v); - } - - for (const auto &v: *(prototype->getTArguments())) { - this->_tArgs.push_back(v); - } + for (const auto &v: prototype.inputs()) { + this->_inputs.push_back(v); + } - for (const auto &v: *(prototype->getIArguments())) { - this->_iArgs.push_back(v); - } + for (const auto &v: prototype.getTArguments()) { + this->_tArgs.push_back(v); + } - for (const auto &v: *(prototype->getBArguments())) { - this->_bArgs.push_back(v); - } + for (const auto &v: prototype.getIArguments()) { + this->_iArgs.push_back(v); + } - for (const auto &v: *(prototype->getAxis())) { - this->_axis.push_back(v); - } + for (const auto &v: prototype.getBArguments()) { + this->_bArgs.push_back(v); + } - this->_opNum = prototype->opNum(); - this->_isInplace = prototype->isInplace(); - this->_nodeId = prototype->nodeId(); - this->_useMKLDNN = prototype->isUseMKLDNN(); + for (const auto &v: prototype.getAxis()) { + this->_axis.push_back(v); } + this->_opNum = prototype.opNum(); + this->_isInplace = prototype.isInplace(); + this->_nodeId = prototype.nodeId(); + this->_useMKLDNN = prototype.isUseMKLDNN(); if (variableSpace != nullptr && variableSpace->launchContext()->getWorkspace() != nullptr) this->_workspace = variableSpace->launchContext()->getWorkspace(); } - sd::DataType Context::dataType(int index) { - + sd::DataType Context::dataType(int index) const { return _dataType; } - sd::DataType Context::dataType() { + sd::DataType Context::dataType() const { return dataType(0); } @@ -392,7 +388,7 @@ namespace sd { } } - unsigned long Context::width() { + unsigned long Context::width() const { if (!_fastpath_in.empty()) return _fastpath_in.size(); else diff --git a/libnd4j/include/graph/impl/ContextPrototype.cpp b/libnd4j/include/graph/impl/ContextPrototype.cpp index 417c46b3a11b..0288e2765820 100644 --- a/libnd4j/include/graph/impl/ContextPrototype.cpp +++ b/libnd4j/include/graph/impl/ContextPrototype.cpp @@ -40,7 +40,7 @@ namespace sd { pickInput(pair); } - int ContextPrototype::opNum() { + int ContextPrototype::opNum() const { return this->_opNum; } @@ -48,8 +48,8 @@ namespace sd { this->_opNum = opNum; } - std::vector>* ContextPrototype::inputs() { - return &_inputs; + std::vector> & ContextPrototype::inputs() const { + return const_cast> &>(_inputs); } void ContextPrototype::fillInputs(std::vector& inputs) { @@ -59,32 +59,32 @@ namespace sd { } } - samediff::Engine ContextPrototype::engine() { + samediff::Engine ContextPrototype::engine() const { return _engine; } - bool ContextPrototype::hasVariablesFilled() { + bool ContextPrototype::hasVariablesFilled() const { return this->_inputs.size() > 0; } - bool ContextPrototype::isInplace() { + bool ContextPrototype::isInplace() const { return this->_isInplace; } - std::vector* ContextPrototype::getTArguments() { - return &(this->_tArgs); + const std::vector & ContextPrototype::getTArguments() const { + return const_cast&>(_tArgs); } - std::vector* ContextPrototype::getIArguments() { - return &(this->_iArgs); + const std::vector & ContextPrototype::getIArguments() const { + return const_cast&>(_iArgs); } - std::vector* ContextPrototype::getBArguments() { - return &(this->_bArgs); + const std::vector & ContextPrototype::getBArguments() const { + return const_cast&>(_bArgs); } - std::vector* ContextPrototype::getAxis() { - return &(this->_axis); + const std::vector & ContextPrototype::getAxis() const { + return const_cast&>(_axis); } void ContextPrototype::pickInput(int input) { @@ -92,8 +92,8 @@ namespace sd { this->_inputs.emplace_back(pair); } - std::pair* ContextPrototype::input(int idx) { - return &(this->_inputs.at(idx)); + const std::pair& ContextPrototype::input(int idx) const { + return this->_inputs.at(idx); } void ContextPrototype::fillInputs(std::initializer_list inputs) { @@ -102,15 +102,15 @@ namespace sd { } } - int ContextPrototype::nodeId() { + int ContextPrototype::nodeId() const { return getNodeId(); } - sd::DataType ContextPrototype::dataType() { + sd::DataType ContextPrototype::dataType() const { return dataType(0); } - sd::DataType ContextPrototype::dataType(int index) { + sd::DataType ContextPrototype::dataType(int index) const { return _dataType; } @@ -119,19 +119,19 @@ namespace sd { _dataType = type; } - size_t ContextPrototype::numT() { + size_t ContextPrototype::numT() const { return (int) _tArgs.size(); } - size_t ContextPrototype::numI() { + size_t ContextPrototype::numI() const { return (int) _iArgs.size(); } - size_t ContextPrototype::numB() { + size_t ContextPrototype::numB() const { return (int) _bArgs.size(); } - int ContextPrototype::getNodeId() { + int ContextPrototype::getNodeId() const { return this->_nodeId; } @@ -139,7 +139,7 @@ namespace sd { * This method returns number of inputs available in this block * @return */ - unsigned long ContextPrototype::width() { + unsigned long ContextPrototype::width() const { return this->_inputs.size(); }; @@ -174,12 +174,52 @@ namespace sd { return clone; } - std::vector *ContextPrototype::getDArguments() { - return &_dArgs; + const std::vector & ContextPrototype::getDArguments() const { + return const_cast&>(_dArgs); } - size_t ContextPrototype::numD() { + size_t ContextPrototype::numD() const { return _dArgs.size(); } + + void ContextPrototype::appendI(const std::vector &value) { + for (auto v:value) + _iArgs.emplace_back(v); + } + + void ContextPrototype::appendT(const std::vector &value) { + for (auto v:value) + _tArgs.emplace_back(v); + } + + void ContextPrototype::appendB(const std::vector &value) { + for (auto v:value) + _bArgs.emplace_back(v); + } + + void ContextPrototype::appendD(const std::vector &value) { + for (auto v:value) + _dArgs.emplace_back(v); + } + + void ContextPrototype::appendA(Nd4jLong value) { + _axis.emplace_back(value); + } + + void ContextPrototype::appendI(Nd4jLong value) { + _iArgs.emplace_back(value); + } + + void ContextPrototype::appendT(double value) { + _tArgs.emplace_back(value); + } + + void ContextPrototype::appendB(bool value) { + _bArgs.emplace_back(value); + } + + void ContextPrototype::appendD(DataType value) { + _dArgs.emplace_back(value); + } } } \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 4773685dd9da..c090ff1574ee 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -212,12 +212,12 @@ namespace sd { if (node.opType() == OpType_CUSTOM) { auto ctx = node.protoContext(); - if (ctx->getIArguments()->size() > 0) { + if (ctx.numI() > 0) { printf("]; iArgs: ["); - for (int e = 0; e < ctx->getIArguments()->size(); e++) { - printf("%i", ctx->getIArguments()->at(e)); - if (e < ctx->getIArguments()->size() - 1) + for (int e = 0; e < ctx.numI(); e++) { + printf("%i", ctx.getIArguments().at(e)); + if (e < ctx.getIArguments().size() - 1) nd4j_printf(", ", ""); } } diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 04ac581535b5..201b412cc911 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -65,19 +65,12 @@ namespace sd { // FIXME: get rid of this!!! _scalar = NDArrayFactory::create(0); - auto block = new ContextPrototype(this->customOp()->getOpDescriptor(), this->id(), false); + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); - for (auto v: iArgs) - block->getIArguments()->emplace_back(v); - - for (auto v: tArgs) - block->getTArguments()->emplace_back(v); - - for (auto v: bArgs) - block->getBArguments()->emplace_back(v); - - for (auto v: dArgs) - block->getDArguments()->emplace_back(v); + block.appendI(iArgs); + block.appendT(tArgs); + block.appendB(bArgs); + block.appendD(dArgs); this->setContextPrototype(block); } @@ -104,19 +97,12 @@ namespace sd { // FIXME: get rid of this!!! _scalar = NDArrayFactory::create(0); - auto block = new ContextPrototype(this->customOp()->getOpDescriptor(), this->id(), false); + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); - for (auto v: iArgs) - block->getIArguments()->emplace_back(v); - - for (auto v: tArgs) - block->getTArguments()->emplace_back(v); - - for (auto v: bArgs) - block->getBArguments()->emplace_back(v); - - for (auto v: dArgs) - block->getDArguments()->emplace_back(v); + block.appendI(iArgs); + block.appendT(tArgs); + block.appendB(bArgs); + block.appendD(dArgs); this->setContextPrototype(block); } @@ -141,9 +127,7 @@ namespace sd { void Node::markInplace(bool reallyInplace) { _isInplace = reallyInplace; - if (_protoContext != nullptr) { - _protoContext->markInplace(reallyInplace); - } + _protoContext.markInplace(reallyInplace); } bool Node::isRemovable() const { @@ -159,7 +143,7 @@ namespace sd { } bool Node::hasBlockAttached() { - return _protoContext != nullptr; + return true; } bool Node::isInplace() { @@ -191,21 +175,11 @@ namespace sd { _frameId = frameId; } - ContextPrototype * Node::getContextPrototype() { - if (_protoContext == nullptr) - _protoContext = new ContextPrototype(this->customOp() != nullptr ? this->customOp()->getOpDescriptor() : nullptr, this->id()); - if (_protoContext->inputs()->empty()) { - for (int e = 0; e < this->input().size(); e++) { - _protoContext->inputs()->emplace_back(this->input().at(e)); - } - } + const ContextPrototype& Node::contextPrototype() const { return _protoContext; } - void Node::setContextPrototype(ContextPrototype *block) { - if (_protoContext != nullptr) - throw std::runtime_error("Block already exists"); - + void Node::setContextPrototype(const ContextPrototype &block) { _protoContext = block; } @@ -420,13 +394,10 @@ namespace sd { for (auto i: inputs) pickInput(i); - auto block = new ContextPrototype(this->customOp()->getOpDescriptor(), this->id(), false); + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); - for (auto v: iArgs) - block->getIArguments()->emplace_back(v); - - for (auto v: tArgs) - block->getTArguments()->emplace_back(v); + block.appendI(iArgs); + block.appendT(tArgs); this->setContextPrototype(block); } @@ -453,13 +424,10 @@ namespace sd { for (auto i: inputs) pickInput(i); - auto block = new ContextPrototype(this->customOp()->getOpDescriptor(), this->id(), false); + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); - for (auto v: iArgs) - block->getIArguments()->emplace_back(v); - - for (auto v: tArgs) - block->getTArguments()->emplace_back(v); + block.appendI(iArgs); + block.appendT(tArgs); this->setContextPrototype(block); } @@ -496,16 +464,16 @@ namespace sd { } } - auto block = new ContextPrototype(this->customOp()->getOpDescriptor(), this->id(), false); + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); for (auto v: dimensions) - block->getAxis()->emplace_back(v); + block.appendA(v); for (auto v: iArgs) - block->getIArguments()->emplace_back(v); + block.appendI(v); for (auto v: tArgs) - block->getTArguments()->emplace_back(v); + block.appendT(v); this->setContextPrototype(block); } @@ -575,32 +543,33 @@ namespace sd { this->_isDeductable = true; - auto block = new ContextPrototype(nullptr, this->id(), false); + ContextPrototype block(nullptr, this->id(), false); for (auto v: dimensions) - block->getAxis()->emplace_back(v); + block.appendA(v); for (auto v: iArgs) - block->getIArguments()->emplace_back(v); + block.appendI(v); for (auto v: tArgs) - block->getTArguments()->emplace_back(v); + block.appendT(v); this->setContextPrototype(block); - this->setCustomOp(Node::buildOpByType(opType, (int) input.size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), opNum, &_scalar)); - block->setOpDescriptor(this->customOp()->getOpDescriptor()); + + this->setCustomOp(Node::buildOpByType(opType, (int) input.size(), (int) block.getIArguments().size(), (int) block.getTArguments().size(), opNum, &_scalar)); + block.setOpDescriptor(this->customOp()->getOpDescriptor()); } else if (opType == OpType_CUSTOM) { if (this->customOp()) { - auto block = new ContextPrototype(this->customOp()->getOpDescriptor(), this->id(), false); + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); for (auto v: dimensions) - block->getAxis()->emplace_back(v); + block.appendA(v); for (auto v: iArgs) - block->getIArguments()->emplace_back(v); + block.appendI(v); for (auto v: tArgs) - block->getTArguments()->emplace_back(v); + block.appendT(v); this->setContextPrototype(block); } else throw std::runtime_error("wrong custom operation given"); @@ -718,74 +687,74 @@ namespace sd { if (node->input() != nullptr && node->input()->size() > 0) { this->_isDeductable = true; - auto block = new ContextPrototype(nullptr, this->id(), false); + ContextPrototype block(nullptr, this->id(), false); for (auto v: _dimensions) - block->getAxis()->emplace_back(v); + block.appendA(v); if (node->extraParams() != nullptr && node->extraParams()->size() > 0) for (int e = 0; e < (int) node->extraParams()->size(); e++) { - block->getTArguments()->emplace_back(static_cast(node->extraParams()->Get(e))); + block.appendT(static_cast(node->extraParams()->Get(e))); } if (node->extraBools() != nullptr && node->extraBools()->size() > 0) for (int e = 0; e < (int) node->extraBools()->size(); e++) { - block->getBArguments()->push_back(node->extraBools()->Get(e)); + block.appendB(node->extraBools()->Get(e)); } if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0) for (int e = 0; e < (int) node->extraInteger()->size(); e++) { - block->getIArguments()->emplace_back(node->extraInteger()->Get(e)); + block.appendI(node->extraInteger()->Get(e)); } if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { for (int e = 0; e < (int) node->extraTypes()->size(); e++) { - block->getDArguments()->emplace_back((sd::DataType) node->extraTypes()->Get(e)); + block.appendD((sd::DataType) node->extraTypes()->Get(e)); } } this->setContextPrototype(block); - this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar)); - block->setOpDescriptor(this->customOp()->getOpDescriptor()); + this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block.getIArguments().size(), (int) block.getTArguments().size(), (int) _opNum, &_scalar)); + block.setOpDescriptor(this->customOp()->getOpDescriptor()); } else if (node->inputPaired() != nullptr && node->inputPaired()->size() > 0) { this->_isDeductable = true; - auto block = new ContextPrototype(nullptr, this->id(), false); + ContextPrototype block(nullptr, this->id(), false); for (int e = 0; e < this->input().size(); e++) { - block->inputs()->emplace_back(this->input().at(e)); + block.inputs().emplace_back(this->input().at(e)); } // there's no other IArgs in legacy options, actually for (auto v: _dimensions) - block->getAxis()->emplace_back(v); + block.appendA(v); if (node->extraParams() != nullptr && node->extraParams()->size() > 0) for (int e = 0; e < (int) node->extraParams()->size(); e++) { - block->getTArguments()->emplace_back(static_cast(node->extraParams()->Get(e))); + block.appendT(static_cast(node->extraParams()->Get(e))); } if (node->extraBools() != nullptr && node->extraBools()->size() > 0) for (int e = 0; e < (int) node->extraBools()->size(); e++) { - block->getBArguments()->push_back(node->extraBools()->Get(e)); + block.appendB(node->extraBools()->Get(e)); } if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0) for (int e = 0; e < (int) node->extraInteger()->size(); e++) { - block->getIArguments()->emplace_back(node->extraInteger()->Get(e)); + block.appendI(node->extraInteger()->Get(e)); } if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { for (int e = 0; e < (int) node->extraTypes()->size(); e++) { - block->getDArguments()->emplace_back((sd::DataType) node->extraTypes()->Get(e)); + block.appendD((sd::DataType) node->extraTypes()->Get(e)); } } this->setContextPrototype(block); - this->setCustomOp(Node::buildOpByType(_opType, (int) node->inputPaired()->size(), (int) block->getIArguments()->size(), (int) block->getTArguments()->size(), (int) _opNum, &_scalar)); - block->setOpDescriptor(this->customOp()->getOpDescriptor()); + this->setCustomOp(Node::buildOpByType(_opType, (int) node->inputPaired()->size(), (int) block.getIArguments().size(), (int) block.getTArguments().size(), (int) _opNum, &_scalar)); + block.setOpDescriptor(this->customOp()->getOpDescriptor()); } } else if (this->_opType == OpType_CUSTOM) { auto op = sd::ops::OpRegistrator::getInstance()->getOperation(this->opNum()); @@ -794,40 +763,40 @@ namespace sd { throw std::runtime_error("Can't find requested operation"); } - auto block = new ContextPrototype(nullptr, this->id()); + ContextPrototype block(nullptr, this->id()); for (int e = 0; e < this->input().size(); e++) { - block->inputs()->emplace_back(this->input().at(e)); + block.inputs().emplace_back(this->input().at(e)); } if (node->extraInteger() != nullptr) for (uint32_t e = 0; e < node->extraInteger()->size(); e++) { auto v = node->extraInteger()->Get(e); // FIXME: remove this static_cast, iArgs should be Nd4jLong - block->getIArguments()->emplace_back(static_cast(v)); + block.appendI(static_cast(v)); } if (node->extraParams() != nullptr) for (uint32_t e = 0; e < node->extraParams()->size(); e++) - block->getTArguments()->emplace_back(static_cast(node->extraParams()->Get(e))); + block.appendT(static_cast(node->extraParams()->Get(e))); if (node->extraBools() != nullptr && node->extraBools()->size() > 0) for (int e = 0; e < (int) node->extraBools()->size(); e++) { - block->getBArguments()->push_back(node->extraBools()->Get(e)); + block.appendB(node->extraBools()->Get(e)); } if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { for (int e = 0; e < (int) node->extraTypes()->size(); e++) { - block->getDArguments()->emplace_back((sd::DataType) node->extraTypes()->Get(e)); + block.appendD((sd::DataType) node->extraTypes()->Get(e)); } } for (auto v: _dimensions) - block->getAxis()->emplace_back(v); + block.appendA(v); this->setContextPrototype(block); this->setCustomOp(op); - block->setOpDescriptor(this->customOp()->getOpDescriptor()); + block.setOpDescriptor(this->customOp()->getOpDescriptor()); } } else { // empty dynamic node, tests probably @@ -838,7 +807,7 @@ namespace sd { return _dataType; } - ContextPrototype* Node::protoContext() const { + const ContextPrototype& Node::protoContext() const { return _protoContext; } @@ -849,9 +818,6 @@ namespace sd { if (_dim != nullptr) delete[] _dim; - if (_protoContext != nullptr) - delete _protoContext; - if (_isDeductable && _customOp != nullptr) { Node::deleteOpByType(_opType, _customOp); } @@ -966,7 +932,6 @@ namespace sd { _referencedBy = std::move(other._referencedBy); _scalar = std::move(other._scalar); - other._protoContext = nullptr; other._customOp = nullptr; return *this; diff --git a/libnd4j/include/graph/logic/impl/LogicExit.cpp b/libnd4j/include/graph/logic/impl/LogicExit.cpp index 1b2a8d49d97e..f351795e05f7 100644 --- a/libnd4j/include/graph/logic/impl/LogicExit.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExit.cpp @@ -30,7 +30,7 @@ namespace sd { auto __variableSpace = graph->getVariableSpace(); auto __flowPath = __variableSpace->flowPath(); - Context ctx(node->getContextPrototype(), __variableSpace); + Context ctx(node->protoContext(), __variableSpace); auto input = ctx.variable(0)->getNDArray(); std::pair pair0(node->id(), 0); diff --git a/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp index 8409b0bd82ea..8f04e0c3966e 100644 --- a/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp +++ b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp @@ -27,7 +27,7 @@ namespace sd { auto __variableSpace = graph->getVariableSpace(); auto __flowPath = __variableSpace->flowPath(); - Context ctx(node->getContextPrototype(), __variableSpace); + Context ctx(node->protoContext(), __variableSpace); auto input = ctx.variable(0)->getNDArray(); std::pair pair0(node->id(), 0); diff --git a/libnd4j/include/helpers/ShapeUtils.h b/libnd4j/include/helpers/ShapeUtils.h index cd4cb0bf97fa..2ff4d25f1293 100644 --- a/libnd4j/include/helpers/ShapeUtils.h +++ b/libnd4j/include/helpers/ShapeUtils.h @@ -35,16 +35,16 @@ namespace sd { static std::vector evalShapeForTensorDot(const NDArray* a, const NDArray* b, const std::vector& axesA, const std::vector& axesB, std::vector& permutAt, std::vector& permutBt, std::vector& shapeAt, std::vector& shapeBt); // evaluate resulting shape after reduce operation - static Nd4jLong* evalReduceShapeInfo(const char order, std::vector& dimensions, const NDArray& arr, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); - static Nd4jLong* evalReduceShapeInfo(const char order, std::vector& dimensions, const Nd4jLong* shapeInfo, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); - static Nd4jLong* evalReduceShapeInfo(const char order, std::vector& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); - static Nd4jLong* evalReduceShapeInfo(const char order, std::vector& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); + static Nd4jLong* evalReduceShapeInfo(const char order, const std::vector& dimensions, const NDArray& arr, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); + static Nd4jLong* evalReduceShapeInfo(const char order, const std::vector& dimensions, const Nd4jLong* shapeInfo, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); + static Nd4jLong* evalReduceShapeInfo(const char order, const std::vector& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); + static Nd4jLong* evalReduceShapeInfo(const char order, const std::vector& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); /** * evaluate output shape for reduce operation when input shape is empty * behavior is analogous to tf */ - static Nd4jLong* evalReduceShapeInfoEmpty(const char order, std::vector& dimensions, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace); + static Nd4jLong* evalReduceShapeInfoEmpty(const char order, const std::vector& dimensions, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace); // evaluate shape for array which is result of repeat operation applied to arr static std::vector evalRepeatShape(int axis, const std::vector& repeats, const NDArray& arr); diff --git a/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h b/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h index 7abaa232bedf..f570e676d70e 100644 --- a/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h +++ b/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h @@ -124,41 +124,41 @@ namespace sd { std::string extra() override { if(_context != nullptr){ - std::vector* iargs = _context->getIArguments(); - std::vector* targs = _context->getTArguments(); - std::vector* bargs = _context->getBArguments(); + auto iargs = _context->getIArguments(); + auto targs = _context->getTArguments(); + auto bargs = _context->getBArguments(); std::string e; bool any = false; - if(iargs != nullptr){ + if(!iargs.empty()){ e += "iargs=["; - for( int i=0; isize(); i++ ){ + for( int i=0; i 0) e += ","; - e += std::to_string(iargs->at(i)); + e += std::to_string(iargs.at(i)); } e += "]"; any = true; } - if(targs != nullptr){ + if(!targs.empty()){ if(any) e += ","; e += "targs=["; - for( int i=0; isize(); i++ ){ + for( int i=0; i 0) e += ","; - e += std::to_string(targs->at(i)); + e += std::to_string(targs.at(i)); } e += "]"; any = true; } - if(bargs != nullptr){ + if(!bargs.empty()){ if(any) e += ","; e += "bargs=["; - for( int i=0; isize(); i++ ){ + for( int i=0; i 0) e += ","; - e += std::to_string(bargs->at(i)); + e += std::to_string(bargs.at(i)); } e += "]"; } diff --git a/libnd4j/include/helpers/impl/ShapeUtils.cpp b/libnd4j/include/helpers/impl/ShapeUtils.cpp index aa8e917cc197..2b450ea9ec3c 100644 --- a/libnd4j/include/helpers/impl/ShapeUtils.cpp +++ b/libnd4j/include/helpers/impl/ShapeUtils.cpp @@ -124,7 +124,8 @@ std::vector ShapeUtils::evalShapeForTensorDot(const NDArray* a, cons ////////////////////////////////////////////////////////////////////////// // evaluate output shape for reduce operation when input shape is empty -Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector& dimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace) { +Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, const std::vector& vdimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace) { + auto dimsToExclude = vdimsToExclude; if (dimsToExclude.size() == 0) { // return copy of input shape Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dataType, true, workspace); @@ -171,22 +172,23 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vectorbufferForShapeInfo(descriptor).primaryAsT(); } -Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { +Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, const std::vector& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { return evalReduceShapeInfo(order, dimsToExclude, arr, arr.dataType(), keepDims, supportOldShapes, workspace); } -Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { +Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, const std::vector& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace); } ////////////////////////////////////////////////////////////////////////// -Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimsToExclude, const NDArray& arr, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { +Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, const std::vector& dimsToExclude, const NDArray& arr, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { return evalReduceShapeInfo(order, dimsToExclude, arr.getShapeInfo(), dataType, keepDims, supportOldShapes, workspace); } ////////////////////////////////////////////////////////////////////////// // evaluate shape resulting from reduce operation -Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { +Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, const std::vector& vdimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { + auto dimsToExclude = vdimsToExclude; if(ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY) return ShapeUtils::evalReduceShapeInfoEmpty(order, dimsToExclude, shapeInfo, dataType, keepDims, workspace); diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp index 072829d937e5..2b0409b88205 100644 --- a/libnd4j/include/legacy/cpu/NativeOps.cpp +++ b/libnd4j/include/legacy/cpu/NativeOps.cpp @@ -1973,16 +1973,16 @@ sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::Decla sd::ShapeList inShapes; for (int e = 0; e < numIArgs; e++) - block.getIArguments()->push_back(iArgs[e]); + block.appendI(iArgs[e]); for (int e = 0; e < numTArgs; e++) - block.getTArguments()->push_back(tArgs[e]); + block.appendT(tArgs[e]); for (int e = 0; e < numBArgs; e++) - block.getBArguments()->push_back(bArgs[e]); + block.appendB(bArgs[e]); for (int e = 0; e < numDArgs; e++) - block.getDArguments()->push_back((sd::DataType) dArgs[e]); + block.appendD((sd::DataType) dArgs[e]); for (int e = 0; e < numInputShapes; e++) { auto shape_ = reinterpret_cast(inputShapes[e]); @@ -2028,10 +2028,10 @@ sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::Decla sd::ShapeList inShapes; for (int e = 0; e < numIArgs; e++) - block.getIArguments()->push_back(iArgs[e]); + block.appendI(iArgs[e]); for (int e = 0; e < numTArgs; e++) - block.getTArguments()->push_back(tArgs[e]); + block.appendT(tArgs[e]); for (int e = 0; e < numInputShapes; e++) inShapes.push_back(reinterpret_cast(inputShapes[e])); diff --git a/libnd4j/include/ops/declarable/generic/blas/axpy.cpp b/libnd4j/include/ops/declarable/generic/blas/axpy.cpp index 7c115059991a..3e27e8921af8 100644 --- a/libnd4j/include/ops/declarable/generic/blas/axpy.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/axpy.cpp @@ -38,7 +38,7 @@ namespace sd { if (block.width() > 2) { auto alpha = INPUT_VARIABLE(2); REQUIRE_TRUE(alpha->isScalar(), 0, "Axpy: alpha argument should be scalar or TArg"); - } else if (block.getTArguments()->size() > 0) { + } else if (block.numT() > 0) { a = T_ARG(0); } @@ -46,7 +46,7 @@ namespace sd { y->applyPairwiseTransform(pairwise::Axpy, *x, *z, &arguments); - return ND4J_STATUS_OK; + return Status::OK(); } DECLARE_TYPES(axpy) { diff --git a/libnd4j/include/ops/declarable/generic/blas/matmul.cpp b/libnd4j/include/ops/declarable/generic/blas/matmul.cpp index 6209e7bbf8fd..2e8fbe736232 100644 --- a/libnd4j/include/ops/declarable/generic/blas/matmul.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/matmul.cpp @@ -36,7 +36,7 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) { auto y = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); - const int iSize = (int) block.getIArguments()->size(); + const int iSize = (int) block.numI(); int transX = iSize > 0 ? INT_ARG(0) : 0; int transY = iSize > 1 ? INT_ARG(1) : 0; const int transZ = iSize > 2 ? INT_ARG(2) : 0; @@ -98,7 +98,7 @@ DECLARE_SHAPE_FN(matmul) { auto xShapeInfo = inputShape->at(0); auto yShapeInfo = inputShape->at(1); - const int iSize = (int) block.getIArguments()->size(); + const int iSize = (int) block.numI(); int transX = iSize > 0 ? INT_ARG(0) : 0; int transY = iSize > 1 ? INT_ARG(1) : 0; const int transZ = iSize > 2 ? INT_ARG(2) : 0; @@ -147,7 +147,7 @@ CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) { auto dldx = OUTPUT_VARIABLE(0); auto dldy = OUTPUT_VARIABLE(1); - const int iSize = (int) block.getIArguments()->size(); + const int iSize = (int) block.numI(); int transX = iSize > 0 ? INT_ARG(0) : 0; int transY = iSize > 1 ? INT_ARG(1) : 0; const int transZ = iSize > 2 ? INT_ARG(2) : 0; diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp index b07f50202c08..ea3757485e90 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp @@ -37,7 +37,7 @@ CUSTOM_OP_IMPL(meshgrid, -1, -1, false, 0, 0) { return Status::OK(); } - bool swapFirst2Dims = block.getIArguments()->size() > 0 ? (bool)INT_ARG(0) : true; + bool swapFirst2Dims = block.numI() > 0 ? (bool)INT_ARG(0) : true; std::vector inArrs(rank); std::vector outArrs(rank); @@ -61,7 +61,7 @@ CUSTOM_OP_IMPL(meshgrid, -1, -1, false, 0, 0) { DECLARE_SHAPE_FN(meshgrid) { - bool swapFirst2Dims = block.getIArguments()->size() > 0 ? (bool)INT_ARG(0) : true; + bool swapFirst2Dims = block.numI() > 0 ? (bool)INT_ARG(0) : true; int rank = block.width(); Nd4jLong* outShapeInfo = nullptr; diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp index e2bf723b30c2..c3a2ac25876c 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp @@ -33,10 +33,10 @@ CUSTOM_OP_IMPL(percentile, 1, 1, false, 1, -2) { auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) const auto q = T_ARG(0); // percentile - const int interpolation = block.getTArguments()->size() > 1 ? T_ARG(1) : 2.; // 0-"lower", 1-"higher", 2-"nearest"(default) - const int keepDims = block.getTArguments()->size() > 2 ? T_ARG(2) : 0.; // false is default + const int interpolation = block.numT() > 1 ? T_ARG(1) : 2.; // 0-"lower", 1-"higher", 2-"nearest"(default) + const int keepDims = block.numT() > 2 ? T_ARG(2) : 0.; // false is default - const int axisArrRank = block.getIArguments()->size(); + const int axisArrRank = block.numI(); const int inputArrRank = input->rankOf(); REQUIRE_TRUE(inputArrRank > 0, 0, "PERCENTILE OP: rank of input array must be positive (>0), but got %i instead !", inputArrRank); @@ -49,7 +49,7 @@ CUSTOM_OP_IMPL(percentile, 1, 1, false, 1, -2) { REQUIRE_TRUE(dim < inputArrRank, 0, "PERCENTILE OP: element (dimension) of axis array at position %i is >= rank of input array (%i >= %i), which is unacceptable !", i, dim, inputArrRank); } - std::vector axises = *block.getIArguments(); + auto axises = block.getIArguments(); helpers::percentile(block.launchContext(), *input, *output, axises, q, interpolation); return Status::OK(); @@ -66,9 +66,9 @@ CUSTOM_OP_IMPL(percentile, 1, 1, false, 1, -2) { DECLARE_SHAPE_FN(percentile) { Nd4jLong* inputShapeInfo = inputShape->at(0); - const int keepDims = block.getTArguments()->size() > 2 ? T_ARG(2) : 0.; // false is default + const int keepDims = block.numT() > 2 ? T_ARG(2) : 0.; // false is default - const int axisArrRank = block.getIArguments()->size(); + const int axisArrRank = block.numI(); const int inputArrRank = inputShapeInfo[0]; REQUIRE_TRUE(inputArrRank > 0, 0, "PERCENTILE OP: rank of input array must be positive (>0), but got %i instead !", inputArrRank); @@ -79,7 +79,7 @@ DECLARE_SHAPE_FN(percentile) { REQUIRE_TRUE(dim < inputArrRank, 0, "PERCENTILE OP: element (dimension) of axis array at position %i is >= rank of input array (%i >= %i), which is unacceptable !", i, dim, inputArrRank); } - std::vector axises = *block.getIArguments(); + auto axises = block.getIArguments(); Nd4jLong* outputShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShapeInfo), axises, inputShapeInfo, keepDims, false, block.getWorkspace()); return SHAPELIST(outputShapeInfo); diff --git a/libnd4j/include/ops/declarable/generic/images/adjust_hue.cpp b/libnd4j/include/ops/declarable/generic/images/adjust_hue.cpp index 436fae28d75c..ff564dd80b8b 100644 --- a/libnd4j/include/ops/declarable/generic/images/adjust_hue.cpp +++ b/libnd4j/include/ops/declarable/generic/images/adjust_hue.cpp @@ -39,7 +39,7 @@ CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 0, 0) { return Status::OK(); const int rank = input->rankOf(); - const int arg_size = block.getIArguments()->size(); + const int arg_size = block.numI(); const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_HUE: delta factor is required !"); diff --git a/libnd4j/include/ops/declarable/generic/images/adjust_saturation.cpp b/libnd4j/include/ops/declarable/generic/images/adjust_saturation.cpp index 5be1699f4a6b..40243f2d68bc 100644 --- a/libnd4j/include/ops/declarable/generic/images/adjust_saturation.cpp +++ b/libnd4j/include/ops/declarable/generic/images/adjust_saturation.cpp @@ -38,7 +38,7 @@ CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 0, 0) { return Status::OK(); const int rank = input->rankOf(); - const int arg_size = block.getIArguments()->size(); + const int arg_size = block.numI(); const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank); diff --git a/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp b/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp index d5211e498241..43cae28c50a1 100644 --- a/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp +++ b/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp @@ -35,7 +35,7 @@ CONFIGURABLE_OP_IMPL(hsv_to_rgb, 1, 1, true, 0, 0) { return Status::OK(); const int rank = input->rankOf(); - const int argSize = block.getIArguments()->size(); + const int argSize = block.numI(); const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; REQUIRE_TRUE(rank >= 1, 0, "HSVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); diff --git a/libnd4j/include/ops/declarable/generic/images/image_resize.cpp b/libnd4j/include/ops/declarable/generic/images/image_resize.cpp index 3ceba93d88c1..543bf7d87f23 100644 --- a/libnd4j/include/ops/declarable/generic/images/image_resize.cpp +++ b/libnd4j/include/ops/declarable/generic/images/image_resize.cpp @@ -39,9 +39,9 @@ namespace sd { REQUIRE_TRUE(size->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %lld.", size->lengthOf()); width = size->e(0); height = size->e(1); - if (block.getBArguments()->size()) { + if (block.numB()) { preserveAspectRatio = B_ARG(0); - if (block.getBArguments()->size() > 1) + if (block.numB() > 1) antialias = B_ARG(1); } diff --git a/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp b/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp index a867a2147421..65b5ef0b5af8 100644 --- a/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp +++ b/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp @@ -55,9 +55,9 @@ namespace sd { auto alignCorners = false; auto halfPixelAlign = false; if (block.numB() > 0) { - alignCorners = block.getBArguments()->at(0); + alignCorners = block.getBArguments().at(0); if (block.numB()> 1) - halfPixelAlign = block.getBArguments()->at(1); + halfPixelAlign = block.getBArguments().at(1); } REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false"); diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp index f7378d3336fb..348cce162353 100644 --- a/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp +++ b/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp @@ -32,7 +32,7 @@ CUSTOM_OP_IMPL(rgb_to_grs, 1, 1, false, 0, 0) { auto output = OUTPUT_VARIABLE(0); const int inRank = input->rankOf(); - const int argSize = block.getIArguments()->size(); + const int argSize = block.numI(); const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + inRank) : inRank - 1; REQUIRE_TRUE(inRank >= 1, 0, "RGBtoGrayScale: Fails to meet the inRank requirement: %i >= 1 ", inRank); @@ -55,7 +55,7 @@ DECLARE_SHAPE_FN(rgb_to_grs) { const auto input = INPUT_VARIABLE(0); const int inRank = input->rankOf(); - const int argSize = block.getIArguments()->size(); + const int argSize = block.numI(); const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + inRank) : inRank - 1; REQUIRE_TRUE(inRank >= 1, 0, "RGBtoGrayScale: Fails to meet the inRank requirement: %i >= 1 ", inRank); diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp index ac5a27c667d8..026c93749a95 100644 --- a/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp +++ b/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp @@ -37,7 +37,7 @@ CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, true, 0, 0) { return Status::OK(); const int rank = input->rankOf(); - const int argSize = block.getIArguments()->size(); + const int argSize = block.numI(); const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; REQUIRE_TRUE(rank >= 1, 0, "RGBtoHSV: Fails to meet the rank requirement: %i >= 1 ", rank); diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToYiq.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToYiq.cpp index 40c936e4f995..bf7d4f32236f 100644 --- a/libnd4j/include/ops/declarable/generic/images/rgbToYiq.cpp +++ b/libnd4j/include/ops/declarable/generic/images/rgbToYiq.cpp @@ -37,7 +37,7 @@ namespace sd { return Status::OK(); const int rank = input->rankOf(); - const int arg_size = block.getIArguments()->size(); + const int arg_size = block.numI(); const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; REQUIRE_TRUE(rank >= 1, 0, "RGBtoYIQ: Fails to meet the rank requirement: %i >= 1 ", rank); diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToYuv.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToYuv.cpp index b52b5a8a6b6c..9bf40a2c1256 100644 --- a/libnd4j/include/ops/declarable/generic/images/rgbToYuv.cpp +++ b/libnd4j/include/ops/declarable/generic/images/rgbToYuv.cpp @@ -38,7 +38,7 @@ CONFIGURABLE_OP_IMPL(rgb_to_yuv, 1, 1, true, 0, 0) { return Status::OK(); const int rank = input->rankOf(); - const int argSize = block.getIArguments()->size(); + const int argSize = block.numI(); const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; REQUIRE_TRUE(rank >= 1, 0, "RGBtoYUV: Fails to meet the rank requirement: %i >= 1 ", rank); diff --git a/libnd4j/include/ops/declarable/generic/images/yiqToRgb.cpp b/libnd4j/include/ops/declarable/generic/images/yiqToRgb.cpp index e339fb74b678..08f4be2b77a6 100644 --- a/libnd4j/include/ops/declarable/generic/images/yiqToRgb.cpp +++ b/libnd4j/include/ops/declarable/generic/images/yiqToRgb.cpp @@ -37,7 +37,7 @@ namespace sd { return Status::OK(); const int rank = input->rankOf(); - const int arg_size = block.getIArguments()->size(); + const int arg_size = block.numI(); const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; REQUIRE_TRUE(rank >= 1, 0, "YIQtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); diff --git a/libnd4j/include/ops/declarable/generic/images/yuvToRgb.cpp b/libnd4j/include/ops/declarable/generic/images/yuvToRgb.cpp index 48d4e379a9eb..1665ea131939 100644 --- a/libnd4j/include/ops/declarable/generic/images/yuvToRgb.cpp +++ b/libnd4j/include/ops/declarable/generic/images/yuvToRgb.cpp @@ -36,7 +36,7 @@ CONFIGURABLE_OP_IMPL(yuv_to_rgb, 1, 1, true, 0, 0) { return Status::OK(); const int rank = input->rankOf(); - const int argSize = block.getIArguments()->size(); + const int argSize = block.numI(); const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; REQUIRE_TRUE(rank >= 1, 0, "YUVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); diff --git a/libnd4j/include/ops/declarable/generic/linalg/eye.cpp b/libnd4j/include/ops/declarable/generic/linalg/eye.cpp index 41469468cf4a..785f72185fa8 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/eye.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/eye.cpp @@ -44,10 +44,10 @@ namespace ops { std::vector params; - sd::DataType dtype = block.getTArguments()->empty() ? sd::DataType::FLOAT32 : sd::DataTypeUtils::fromInt(T_ARG(0)); + sd::DataType dtype = block.getTArguments().empty() ? sd::DataType::FLOAT32 : sd::DataTypeUtils::fromInt(T_ARG(0)); if(block.width() == 0) { - params = *block.getIArguments(); + params = block.getIArguments(); } else { for (int i = 0; i < block.width(); i++) { diff --git a/libnd4j/include/ops/declarable/generic/linalg/lup.cpp b/libnd4j/include/ops/declarable/generic/linalg/lup.cpp index e0b1eb8d7e32..1607a08ea458 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/lup.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/lup.cpp @@ -30,7 +30,7 @@ namespace sd { auto z = OUTPUT_VARIABLE(0); auto p = OUTPUT_VARIABLE(1); - if (block.getIArguments()->size()) { + if (block.numI()) { DataType dtype = (DataType)INT_ARG(0); REQUIRE_TRUE(dtype == sd::DataType::INT32 || dtype == sd::DataType::INT64, 0, "lu: Permutation data type should be 32bit or 64bit int only, but '%s' given.", DataTypeUtils::asString(dtype).c_str()); } @@ -46,7 +46,7 @@ namespace sd { auto shapeVector = ShapeUtils::shapeAsVector(in); auto luShape = ShapeBuilders::copyShapeInfoAndType(in, in, true, block.workspace()); auto dtype = sd::DataType::INT32; - if (block.getIArguments()->size()) { + if (block.numI()) { dtype = (DataType)INT_ARG(0); REQUIRE_TRUE(dtype == sd::DataType::INT32 || dtype == sd::DataType::INT64, 0, "lu: Permutation data type should be 32bit or 64bit int only, but '%s' given.", DataTypeUtils::asString(dtype).c_str()); } diff --git a/libnd4j/include/ops/declarable/generic/linalg/moments.cpp b/libnd4j/include/ops/declarable/generic/linalg/moments.cpp index c8fdf2e48cf9..d9e9b00c39a0 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/moments.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/moments.cpp @@ -31,8 +31,8 @@ namespace sd { auto means = OUTPUT_VARIABLE(0); auto variances = OUTPUT_VARIABLE(1); - std::vector axis = *block.getIArguments(); - const bool keepDims = block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + auto axis = block.getIArguments(); + const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; // axis might be dynamic (i.e. tf mode) if (block.width() > 1 && axis.size() == 0) { @@ -56,7 +56,7 @@ namespace sd { } DECLARE_SHAPE_FN(moments) { - auto axis = *block.getIArguments(); + auto axis = block.getIArguments(); auto input = INPUT_VARIABLE(0); // axis might be dynamic (i.e. tf mode) @@ -73,7 +73,7 @@ namespace sd { } //std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); - const bool keepDims = block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; auto meanShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, keepDims, false, block.workspace()); auto varianceShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, keepDims, false, block.workspace()); diff --git a/libnd4j/include/ops/declarable/generic/linalg/qr.cpp b/libnd4j/include/ops/declarable/generic/linalg/qr.cpp index 2cf9156ce7b4..4da3914ce4b8 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/qr.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/qr.cpp @@ -30,8 +30,9 @@ namespace sd { auto outputQ = OUTPUT_VARIABLE(0); auto outputR = OUTPUT_VARIABLE(1); auto fullMatricies = false; - if (block.getBArguments()->size()) + if (block.numB()) fullMatricies = B_ARG(0); + REQUIRE_TRUE(input->rankOf() >=2, 0, "qr: The rank of input array should not be less than 2, but %i is given", input->rankOf()); REQUIRE_TRUE((fullMatricies && outputQ->sizeAt(-1) == input->sizeAt(-2)) || (!fullMatricies && outputQ->isSameShape(input)), 0, "qr: The last dimmensions should be equal to result Q, but %i and %i are given", outputQ->sizeAt(-1), input->sizeAt(-2)); REQUIRE_TRUE((fullMatricies && outputR->sizeAt(-1) == input->sizeAt(-1)) || (!fullMatricies && outputR->sizeAt(-1) == outputR->sizeAt(-2)), 0, "qr: The last dimmensions should be equal to result R, but %i and %i are given", outputR->sizeAt(-1), input->sizeAt(-1)); @@ -49,7 +50,7 @@ namespace sd { int targetRank = shape::rank(inShape); // last two dimensions will be reduced to scalar auto fullMatricies = false; - if (block.getBArguments()->size()) + if (block.numB()) fullMatricies = B_ARG(0); auto shape = ShapeUtils::shapeAsVector(inShape); diff --git a/libnd4j/include/ops/declarable/generic/linalg/triu.cpp b/libnd4j/include/ops/declarable/generic/linalg/triu.cpp index 839828f62a77..24525f09a10d 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/triu.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/triu.cpp @@ -34,7 +34,7 @@ CUSTOM_OP_IMPL(triu, 1, 1, false, 0, 0) { REQUIRE_TRUE(input->rankOf() > 0, 0, "TRIU OP: the rank of input array must be > 0, but got %i instead !", input->rankOf()); - const int diag = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + const int diag = block.numI() > 0 ? INT_ARG(0) : 0; BUILD_SINGLE_SELECTOR(input->dataType(), input->fillAsTriangular, (0, diag, 0, *output, 'l' ), LIBND4J_TYPES); @@ -81,7 +81,7 @@ CUSTOM_OP_IMPL(triu_bp, 2, 1, false, 0, 0) { REQUIRE_TRUE(input->rankOf() > 0, 0, "TRIU_BP OP: the rank of input array must be > 0, but got %i instead !", input->rankOf()); - const int diag = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + const int diag = block.numI() > 0 ? INT_ARG(0) : 0; helpers::triuBP(block.launchContext(), *input, *gradO, *gradI, diag); diff --git a/libnd4j/include/ops/declarable/generic/list/pick_list.cpp b/libnd4j/include/ops/declarable/generic/list/pick_list.cpp index 1254456bda4e..8fde9fdc2495 100644 --- a/libnd4j/include/ops/declarable/generic/list/pick_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/pick_list.cpp @@ -33,8 +33,8 @@ namespace sd { auto ia = INPUT_VARIABLE(1); for (int e = 0; e < ia->lengthOf(); e++) indices.emplace_back(ia->e(e)); - } else if (block.getIArguments()->size() > 0) { - indices = *(block.getIArguments()); + } else if (block.numI() > 0) { + indices = block.getIArguments(); } else return ND4J_STATUS_BAD_ARGUMENTS; for (auto& v: indices) { diff --git a/libnd4j/include/ops/declarable/generic/list/write_list.cpp b/libnd4j/include/ops/declarable/generic/list/write_list.cpp index c61bcb68b3e0..859aecdee7fe 100644 --- a/libnd4j/include/ops/declarable/generic/list/write_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/write_list.cpp @@ -48,7 +48,7 @@ namespace sd { // OVERWRITE_RESULT(res); return result; - } else if (block.getIArguments()->size() == 1) { + } else if (block.numI() == 1) { auto input = INPUT_VARIABLE(1); auto idx = INT_ARG(0); diff --git a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp index 6dab14365c7c..1b3135f2762f 100644 --- a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp @@ -33,7 +33,7 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0) { auto labels = INPUT_VARIABLE(1); auto output = OUTPUT_VARIABLE(0); - const int classesDim = block.getIArguments()->size() > 0 ? INT_ARG(0) : logits->rankOf()-1; + const int classesDim = block.numI() > 0 ? INT_ARG(0) : logits->rankOf()-1; // input validation REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); @@ -62,7 +62,7 @@ DECLARE_SHAPE_FN(softmax_cross_entropy_loss_with_logits) { auto logitsShapeInfo = inputShape->at(0); auto labelsShapeInfo = inputShape->at(1); - const int classesDim = block.getIArguments()->size() > 0 ? INT_ARG(0) : -1; + const int classesDim = block.numI() > 0 ? INT_ARG(0) : -1; std::vector dimensions = {classesDim}; // labels and logits must have the same shapes @@ -89,7 +89,7 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_with_logits_grad, 2, 2, false, 0, 0) { auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits auto dLdl = OUTPUT_VARIABLE(1); // dL/dlabels - const int classesDim = block.getIArguments()->size() > 0 ? INT_ARG(0) : logits->rankOf()-1; + const int classesDim = block.numI() > 0 ? INT_ARG(0) : logits->rankOf()-1; // input validation REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/prelu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/prelu.cpp index b7d260a4c740..c382939ccba8 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/prelu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/prelu.cpp @@ -36,7 +36,7 @@ CONFIGURABLE_OP_IMPL(prelu, 2, 1, true, 0, 0) { auto alpha = INPUT_VARIABLE(1); auto output = OUTPUT_VARIABLE(0); - std::vector sharedAxes = *block.getIArguments(); + std::vector sharedAxes = block.getIArguments(); const int inputRank = input->rankOf(); const int numSharedAxes = sharedAxes.size(); // can be zero as well @@ -87,7 +87,7 @@ CONFIGURABLE_OP_IMPL(prelu_bp, 3, 2, true, 0, 0) { auto dLdI = OUTPUT_VARIABLE(0); auto dLdA = OUTPUT_VARIABLE(1); - std::vector sharedAxes = *block.getIArguments(); + std::vector sharedAxes = block.getIArguments(); const int inputRank = input->rankOf(); const int numSharedAxes = sharedAxes.size(); // can be zero as well diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/relu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/relu.cpp index 3b42c2e5af00..91c599126aae 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/relu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/relu.cpp @@ -30,7 +30,7 @@ namespace sd { auto first = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0; + auto scalar = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; first->applyScalar(sd::scalar::RELU, scalar, *z); diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/thresholdedrelu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/thresholdedrelu.cpp index a0cba155ac05..519f09d6d95f 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/thresholdedrelu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/thresholdedrelu.cpp @@ -34,7 +34,7 @@ CONFIGURABLE_OP_IMPL(thresholdedrelu, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0; + auto scalar = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; helpers::thresholdRelu(block.launchContext(), *input, scalar, *output); @@ -53,7 +53,7 @@ CONFIGURABLE_OP_IMPL(thresholdedrelu_bp, 2, 1, true, 0, 0) { auto dLdO = INPUT_VARIABLE(1); auto dLdI = OUTPUT_VARIABLE(0); - auto threshold = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0; + auto threshold = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; helpers::thresholdReluDerivative(block.launchContext(), input, threshold, dLdO, dLdI); diff --git a/libnd4j/include/ops/declarable/generic/nn/apply_sgd.cpp b/libnd4j/include/ops/declarable/generic/nn/apply_sgd.cpp index 389d07c7b651..e287c5035ccc 100644 --- a/libnd4j/include/ops/declarable/generic/nn/apply_sgd.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/apply_sgd.cpp @@ -37,7 +37,7 @@ namespace sd { if (block.width() == 3) { auto tarr = INPUT_VARIABLE(2); lr = tarr->e(0); - } else if (block.getTArguments()->size() == 1) { + } else if (block.numT() == 1) { lr = T_ARG(0); } else { REQUIRE_TRUE(false, 0, "ApplyGradients op should have LR announced either es T argument or additional NDArray!"); diff --git a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp index e69b370ca717..c9fc0b91c323 100644 --- a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp @@ -50,7 +50,7 @@ CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) { if(applyOffset) beta = INPUT_VARIABLE(3 + (int)applyScale); - const int numOfIntArgs = block.getIArguments()->size(); + const int numOfIntArgs = block.numI(); const int inRank = input->rankOf(); // get axes args to normalize input array over @@ -156,7 +156,7 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) { dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); } - const int numOfIntArgs = block.getIArguments()->size(); + const int numOfIntArgs = block.numI(); const int inRank = input->rankOf(); // get axes args to normalize input array over diff --git a/libnd4j/include/ops/declarable/generic/nn/bias_add.cpp b/libnd4j/include/ops/declarable/generic/nn/bias_add.cpp index eec864c5ed80..9d45797b470e 100644 --- a/libnd4j/include/ops/declarable/generic/nn/bias_add.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/bias_add.cpp @@ -36,7 +36,7 @@ CUSTOM_OP_IMPL(biasadd, 2, 1, true, 0, 0) { auto output = OUTPUT_VARIABLE(0); - const bool isNCHW = !block.getBArguments()->empty() ? B_ARG(0) : false; + const bool isNCHW = !block.getBArguments().empty() ? B_ARG(0) : false; const int channelDim = isNCHW ? 1 : input->rankOf() - 1; // second or last REQUIRE_TRUE(bias->rankOf() == 1, 0, "BIASADD CUSTOM_OP: bias array should have rank = 1, but got %i instead !", bias->rankOf()); @@ -77,7 +77,7 @@ CUSTOM_OP_IMPL(biasadd_bp, 3, 2, false, 0, 0) { auto gradI = OUTPUT_VARIABLE(0); auto gradB = OUTPUT_VARIABLE(1); - const bool isNCHW = !block.getBArguments()->empty() ? B_ARG(0) : false; + const bool isNCHW = !block.getBArguments().empty() ? B_ARG(0) : false; const int channelDim = isNCHW ? 1 : input->rankOf() - 1; // second or last gradI->assign(gradO); diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp index 27081b545cc9..5815e0b55f14 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp @@ -44,8 +44,8 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { int pW = INT_ARG(2); // paddings width int dW = INT_ARG(3); // dilations width int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL - int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 0-NCW, 1-NWC - int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] + int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 0-NCW, 1-NWC + int wFormat = block.numI() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] const int rank = 3; REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); @@ -105,8 +105,8 @@ DECLARE_SHAPE_FN(conv1d) { int pW = INT_ARG(2); // paddings width int dW = INT_ARG(3); // dilations width int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME - int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW - int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] + int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW + int wFormat = block.numI() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); if(!isNCW) { @@ -178,8 +178,8 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { int pW = INT_ARG(2); // paddings width int dW = INT_ARG(3); // dilations width int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL - int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW - int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] + int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW + int wFormat = block.numI() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] const int rank = 3; REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); @@ -253,8 +253,8 @@ DECLARE_SHAPE_FN(conv1d_bp) { int pW = INT_ARG(2); // paddings width int dW = INT_ARG(3); // dilations width int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME - int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW - int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] + int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW + int wFormat = block.numI() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); if(!isNCW) { diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp index 4377c1487217..eba662d08ce2 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp @@ -49,8 +49,8 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width @@ -86,8 +86,8 @@ DECLARE_SHAPE_FN(conv2d) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0)); // filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1)); // filter(kernel) width @@ -176,8 +176,8 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); @@ -226,8 +226,8 @@ DECLARE_SHAPE_FN(conv2d_bp) { const int dH = INT_ARG(6); // dilations height const int dW = INT_ARG(7); // dilations width const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] + const int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + const int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 3 : 0); if(!isNCHW) { @@ -282,8 +282,8 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] const int rank = gradO->rankOf(); @@ -344,8 +344,8 @@ DECLARE_SHAPE_FN(conv2d_input_bp) { const int dH = INT_ARG(6); // dilations height const int dW = INT_ARG(7); // dilations width const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] + const int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + const int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH; if(!isNCHW) { diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp index 0657f6dc2ede..05bb837deccd 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp @@ -52,8 +52,8 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] + int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes @@ -128,8 +128,8 @@ DECLARE_SHAPE_FN(conv3dnew) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID; - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] + int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] const int rank = 5; REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); @@ -211,8 +211,8 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] + int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes @@ -318,8 +318,8 @@ DECLARE_SHAPE_FN(conv3dnew_bp) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] + int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] const int rank = 5; REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp index 8d6c0e3a76e4..4ba57c93d859 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp @@ -52,8 +52,8 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes @@ -121,8 +121,8 @@ DECLARE_SHAPE_FN(deconv2d) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] int indIOioC, indIiH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3)); if(!isNCHW) { @@ -194,8 +194,8 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes @@ -281,8 +281,8 @@ DECLARE_SHAPE_FN(deconv2d_bp) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3)); if(!isNCHW) { diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp index ae97c3d65173..8eb361118c47 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp @@ -46,8 +46,8 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] const int rank = gradO->rankOf(); @@ -102,8 +102,8 @@ DECLARE_SHAPE_FN(deconv2d_tf) { const int dH = INT_ARG(6); // dilations height const int dW = INT_ARG(7); // dilations width const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] + const int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + const int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH; if(!isNCHW) { diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp index ab6e49836ac5..f7f6416478c9 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp @@ -53,8 +53,8 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] + int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes @@ -129,8 +129,8 @@ DECLARE_SHAPE_FN(deconv3d) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] + int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4)); if(!isNCDHW) { @@ -209,8 +209,8 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] + int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes @@ -305,8 +305,8 @@ DECLARE_SHAPE_FN(deconv3d_bp) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] + int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4)); if(!isNCDHW) { diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp index 30580e7a6e41..8c03ea6eb73f 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp @@ -49,8 +49,8 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI()> 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes @@ -92,8 +92,8 @@ DECLARE_SHAPE_FN(depthwise_conv2d) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); if(!isNCHW) { @@ -171,8 +171,8 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes @@ -216,8 +216,8 @@ DECLARE_SHAPE_FN(depthwise_conv2d_bp) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); if(!isNCHW) { diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp index 2e5818c56754..db65329835f0 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp @@ -45,7 +45,7 @@ namespace sd { int dX = INT_ARG(7); //Dilation, width/x dimension bool isSameMode = INT_ARG(8) > 0; double zeroPadVal = 0.0; - if (block.getTArguments()->size() > 0) + if (block.numT() > 0) zeroPadVal = T_ARG(0); // FIXME: zeropad value is void @@ -120,7 +120,7 @@ namespace sd { int dX = INT_ARG(7); //Dilation, width/x dimension bool isSameMode = INT_ARG(8) > 0; double zeroPadVal = 0.0; - if (block.getTArguments()->size() > 0) + if (block.numT() > 0) zeroPadVal = T_ARG(0); //Assuming NCHW format here diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/ismax.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/ismax.cpp index d786504adb72..24488dd5de31 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/ismax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/ismax.cpp @@ -32,7 +32,7 @@ CONFIGURABLE_OP_IMPL(ismax, 1, 1, true, 0, -2) { auto x = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - auto dimensions = *(block.getIArguments()); // argI + auto dimensions = block.getIArguments(); // argI if (x->isScalar()) z->assign(1); else diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp index 52960c3fc866..e519585e62b7 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp @@ -47,8 +47,8 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) { int pW = 0; // paddings width int dH = 1; // dilations height int dW = 1; // dilations width - int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 1 ? INT_ARG(1) : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC] + int isNCHW = block.numI() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 1 ? INT_ARG(1) : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC] int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes @@ -81,8 +81,8 @@ DECLARE_SHAPE_FN(pointwise_conv2d) { REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM POINTWISECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); - int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 1 ? INT_ARG(1) : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC] + int isNCHW = block.numI() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 1 ? INT_ARG(1) : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC] int indIOioC, indWoC(0 == wFormat ? 3 : 0); if(!isNCHW) diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp index a804abafad79..3ea1eec675ed 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp @@ -66,8 +66,8 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width @@ -136,8 +136,8 @@ DECLARE_SHAPE_FN(sconv2d) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); if(!isNCHW) { @@ -246,8 +246,8 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes @@ -338,8 +338,8 @@ DECLARE_SHAPE_FN(sconv2d_bp) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); if(!isNCHW) { diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp index 4800b3db9ddc..fd41313e4470 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp @@ -36,7 +36,7 @@ CUSTOM_OP_IMPL(upsampling2d, 1, 1, false, 0, 2) { const int factorH = INT_ARG(0); const int factorW = INT_ARG(1); - const int isNCHW = block.getIArguments()->size() > 2 ? INT_ARG(2) : 0; // INT_ARG(2): 0-NCHW, 1-NHWC + const int isNCHW = block.numI() > 2 ? INT_ARG(2) : 0; // INT_ARG(2): 0-NCHW, 1-NHWC REQUIRE_TRUE(input->rankOf() == 4, 0, "UPSAMPLING2D op: input should be 4D, but got %i instead!", input->rankOf()); REQUIRE_TRUE(output->rankOf() == 4, 0, "UPSAMPLING2D op: output should be 4D, but got %i instead!", output->rankOf()); @@ -61,7 +61,7 @@ DECLARE_SHAPE_FN(upsampling2d) { const int factorH = INT_ARG(0); const int factorW = INT_ARG(1); - const int isNCHW = block.getIArguments()->size() > 2 ? INT_ARG(2) : 0; // INT_ARG(2): 0-NCHW, 1-NHWC + const int isNCHW = block.numI() > 2 ? INT_ARG(2) : 0; // INT_ARG(2): 0-NCHW, 1-NHWC Nd4jLong *outputShapeInfo = nullptr; ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo[0]), Nd4jLong); @@ -99,7 +99,7 @@ CUSTOM_OP_IMPL(upsampling2d_bp, 2, 1, false, 0, 0) { auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) auto gradI = OUTPUT_NULLIFIED(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - const int isNCHW = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC + const int isNCHW = block.numI() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC // REQUIRE_TRUE(input->rankOf() == 4, 0, "UPSAMPLING2D_BP op: input array must be 4D, but got %i instead!", input->rankOf()); REQUIRE_TRUE(gradO->rankOf() == 4, 0, "UPSAMPLING2D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf()); diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp index 557468d147b0..19bc5a9ecaf0 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp @@ -36,7 +36,7 @@ CUSTOM_OP_IMPL(upsampling3d, 1, 1, false, 0, 3) { const int factorD = INT_ARG(0); const int factorH = INT_ARG(1); const int factorW = INT_ARG(2); - const int isNCDHW = block.getIArguments()->size() > 3 ? INT_ARG(3) : 0; // INT_ARG(3): 0-NCDHW, 1-NDHWC + const int isNCDHW = block.numI() > 3 ? INT_ARG(3) : 0; // INT_ARG(3): 0-NCDHW, 1-NDHWC REQUIRE_TRUE(input->rankOf() == 5, 0, "UPSAMPLING3D op: input should be 5D, but got %i instead!", input->rankOf()); REQUIRE_TRUE(output->rankOf() == 5, 0, "UPSAMPLING3D op: output should be 5D, but got %i instead!", output->rankOf()); @@ -61,7 +61,7 @@ DECLARE_SHAPE_FN(upsampling3d) { const int factorD = INT_ARG(0); const int factorH = INT_ARG(1); const int factorW = INT_ARG(2); - const int isNCDHW = block.getIArguments()->size() > 3 ? INT_ARG(3) : 0; // INT_ARG(3): 0-NCHW, 1-NHWC + const int isNCDHW = block.numI() > 3 ? INT_ARG(3) : 0; // INT_ARG(3): 0-NCHW, 1-NHWC Nd4jLong *outputShapeInfo = nullptr; ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo[0]), Nd4jLong); @@ -99,7 +99,7 @@ CUSTOM_OP_IMPL(upsampling3d_bp, 2, 1, false, 0, 0) { auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) auto gradI = OUTPUT_NULLIFIED(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - const int isNCDHW = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC + const int isNCDHW = block.numI() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC // REQUIRE_TRUE(input->rankOf() == 5, 0, "UPSAMPLING3D_BP op: input array must be 4D, but got %i instead!", input->rankOf()); REQUIRE_TRUE(gradO->rankOf() == 5, 0, "UPSAMPLING3D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf()); diff --git a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp index 6e911e405043..9fc71572dd43 100644 --- a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp @@ -78,7 +78,7 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) { // FIXME: double? double epsilon; - if(block.getTArguments()->size() > 0) + if(block.numT() > 0) epsilon = T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5; else epsilon = 0.001; diff --git a/libnd4j/include/ops/declarable/generic/nn/layer_norm.cpp b/libnd4j/include/ops/declarable/generic/nn/layer_norm.cpp index 5643932cbf7a..1371d6b68b43 100644 --- a/libnd4j/include/ops/declarable/generic/nn/layer_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/layer_norm.cpp @@ -33,9 +33,9 @@ namespace ops { auto gain = INPUT_VARIABLE(1); auto output = OUTPUT_VARIABLE(0); - std::vector axis = *block.getIArguments(); + auto axis = block.getIArguments(); - const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC + const bool isNCHW = block.numB() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC const int dimC = isNCHW ? 1 : input->rankOf() - 1; REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str()); @@ -82,12 +82,12 @@ namespace ops { auto dLdg = OUTPUT_VARIABLE(1); auto dLdb = block.width() == 4 ? OUTPUT_VARIABLE(2) : nullptr; - const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC + const bool isNCHW = block.numB() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC const int dimC = isNCHW ? 1 : input->rankOf() - 1; REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str()); - std::vector axis = *block.getIArguments(); + auto axis = block.getIArguments(); std::vector longAxis = ArrayUtils::toLongVector(axis); diff --git a/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp b/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp index 64aadce370a7..be9653a9d3ec 100644 --- a/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp @@ -39,7 +39,7 @@ CONFIGURABLE_OP_IMPL(log_softmax, 1, 1, true, 0, 0) { auto output = OUTPUT_VARIABLE(0); const int rank = input->rankOf(); - const int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; + const int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; REQUIRE_TRUE(dim < rank, 0, "LOG_SOFTMAX OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); @@ -62,7 +62,7 @@ CONFIGURABLE_OP_IMPL(log_softmax_bp, 2, 1, true, 0, 0) { auto gradI = OUTPUT_VARIABLE(0); const int rank = input->rankOf(); - const int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; + const int dim = block.numI()> 0 ? INT_ARG(0) : rank - 1; REQUIRE_TRUE(dim < rank, 0, "LOG_SOFTMAX_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp index b93cbe47f257..359c0ac3ca43 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp @@ -45,7 +45,7 @@ CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) { const auto dW = INT_ARG(7); const auto isSameMode = static_cast(INT_ARG(8)); const auto extraParam0 = INT_ARG(9); - const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + const int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf()); REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); @@ -93,7 +93,7 @@ DECLARE_SHAPE_FN(avgpool2d) { auto shapeOf = shape::shapeOf(inShape); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - auto argI = *(block.getIArguments()); + auto argI = block.getIArguments(); const int kH = INT_ARG(0); const int kW = INT_ARG(1); const int sH = INT_ARG(2); @@ -104,7 +104,7 @@ DECLARE_SHAPE_FN(avgpool2d) { const int dW = INT_ARG(7); const int isSameMode = INT_ARG(8); - const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + const int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); @@ -159,7 +159,7 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int extraParam0 = INT_ARG(9); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp index 85b8d88339a2..32efcfb826e6 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp @@ -48,7 +48,7 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) { int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID int extraParam0 = INT_ARG(13); - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); @@ -100,7 +100,7 @@ DECLARE_SHAPE_FN(avgpool3dnew) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); @@ -165,7 +165,7 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) { const int dW = INT_ARG(11); // dilations width const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging - const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + const int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", input->rankOf()); REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp index d92c27442073..b3329e8b7c65 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp @@ -55,7 +55,7 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) { int oH = 0; int oW = 0; - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW + int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW const int iH = isNCHW ? input->sizeAt(2) : input->sizeAt(1); const int iW = isNCHW ? input->sizeAt(3) : input->sizeAt(2); @@ -107,7 +107,7 @@ DECLARE_SHAPE_FN(maxpool2d) { int dH = INT_ARG(6); int dW = INT_ARG(7); int isSameMode = INT_ARG(8); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW + int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); @@ -161,7 +161,7 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW + int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp index 3fd5f9c51ba6..4f82976bf9c3 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp @@ -48,7 +48,7 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) { int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); @@ -102,7 +102,7 @@ DECLARE_SHAPE_FN(maxpool3dnew) { int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID // int extraParam0 = INT_ARG(13); - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); @@ -167,7 +167,7 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) { const int dW = INT_ARG(11); // dilations width const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", input->rankOf()); REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp index 111846584115..05aab696290b 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp @@ -35,7 +35,7 @@ namespace sd { REQUIRE_TRUE(x->rankOf() == 4, 0, "max_pool_with_argmax: Input should have rank of 4, but got %i instead", x->rankOf()); - auto argI = *(block.getIArguments()); + auto argI = block.getIArguments(); helpers::maxPoolingFunctor(block.launchContext(), block, x, z, argI, indices); diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp index 4c9319ca1350..a3c27f02e5f2 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp @@ -52,7 +52,7 @@ namespace sd { int oY = 0; int oX = 0; - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW + int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW if(!isNCHW) { input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] @@ -93,7 +93,7 @@ namespace sd { auto shapeOf = shape::shapeOf(inShape); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - std::vector argI = *(block.getIArguments()); + auto argI = block.getIArguments(); int kH = INT_ARG(0); int kW = INT_ARG(1); int sH = INT_ARG(2); @@ -103,7 +103,7 @@ namespace sd { int dH = INT_ARG(6); int dW = INT_ARG(7); int isSameMode = INT_ARG(8); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW + int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW REQUIRE_TRUE(dH != 0 && dW != 0, 0, "PNORMPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); @@ -157,7 +157,7 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) { int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME int pnorm = INT_ARG(9); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW + int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW // FIXME: double? double eps = T_ARG(0); diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp index 33fd5e8ea7d6..ae70fd451d8f 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp @@ -38,7 +38,7 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) { NDArray* h0BW = nullptr; // initial cell output for backward RNN (at time step = 0) [bS x numUnitsBW] NDArray* maxTimeStep = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - const int timeMajor = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // if non zero then [time, bS, ...], else [bS, time, ...] + const int timeMajor = block.numI() > 0 ? INT_ARG(0) : 0; // if non zero then [time, bS, ...], else [bS, time, ...] switch(block.width()) { case 8: @@ -147,7 +147,7 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) { NDArray* h0BW = nullptr; // initial cell output for backward RNN (at time step = 0) [bS x numUnitsBW] NDArray* maxTimeStep = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - const int timeMajor = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // if true then [time, bS, ...], else [bS, time, ...] + const int timeMajor = block.numI() > 0 ? INT_ARG(0) : 0; // if true then [time, bS, ...], else [bS, time, ...] switch(block.width()) { case 8: diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp index 41696638d20e..2e6f9597672f 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp @@ -36,7 +36,7 @@ CUSTOM_OP_IMPL(dynamic_rnn, 4, 2, false, 0, 0) { NDArray* h0 = nullptr; // initial cell output (at time step = 0) [bS x numUnits] NDArray* maxTimeStep = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - const int timeMajor = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // if true then [time, bS, ...], else [bS, time, ...] + const int timeMajor = block.numI() > 0 ? INT_ARG(0) : 0; // if true then [time, bS, ...], else [bS, time, ...] if(block.width() == 5) { if ((*INPUT_VARIABLE(4)).rankOf() == 2) @@ -112,7 +112,7 @@ DECLARE_SHAPE_FN(dynamic_rnn) { Nd4jLong* h0ShapeInfo = nullptr; // initial cell output (at time step = 0) [bS x numUnits] Nd4jLong* maxTimeStepShapeInfo = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - const int timeMajor = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // if true then [time, bS, ...], else [bS, time, ...] + const int timeMajor = block.numI() > 0 ? INT_ARG(0) : 0; // if true then [time, bS, ...], else [bS, time, ...] if(block.width() == 5) { if (inputShape->at(4)[0] == 2) diff --git a/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp b/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp index 94a4a0ca4ccf..0bab69059ec0 100644 --- a/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp @@ -39,7 +39,7 @@ namespace sd { auto status = op.execute({x, w, b}, {output}); REQUIRE_TRUE(Status::OK() == status, 0, "relu_layer: xw_plus_b op failed on input data."); - auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0; + auto scalar = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; output->applyScalar(sd::scalar::RELU, scalar, *output); diff --git a/libnd4j/include/ops/declarable/generic/nn/softmax.cpp b/libnd4j/include/ops/declarable/generic/nn/softmax.cpp index d5c58bb7a24a..bc292e09022a 100644 --- a/libnd4j/include/ops/declarable/generic/nn/softmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/softmax.cpp @@ -40,7 +40,7 @@ CONFIGURABLE_OP_IMPL(softmax, 1, 1, true, 0, 0) { auto output = OUTPUT_VARIABLE(0); const int rank = input->rankOf(); - const int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; + const int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; REQUIRE_TRUE(dim < rank, 0, "SOFTMAX OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); @@ -56,7 +56,7 @@ CONFIGURABLE_OP_IMPL(softmax_bp, 2, 1, true, 0, 0) { auto gradI = OUTPUT_VARIABLE(0); const int rank = input->rankOf(); - const int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; + const int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; REQUIRE_TRUE(dim < rank, 0, "SOFTMAX_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp index cc8a64fa661e..e79b8cc2d991 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp @@ -71,7 +71,7 @@ namespace sd { int numClasses = 0; - if (block.getIArguments()->size() > 0) { + if (block.numI() > 0) { numClasses = INT_ARG(0); } else { diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp index 291e8b7c1fd7..2ce8643159e3 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp @@ -32,14 +32,14 @@ namespace sd { NDArray* min; NDArray* max; - REQUIRE_TRUE(block.width() == 3 || block.getTArguments()->size() == 2, 0, "fake_quant_with_min_max_vars: No minimum/maximum values provided by either input arrays or TArgs"); + REQUIRE_TRUE(block.width() == 3 || block.numT() == 2, 0, "fake_quant_with_min_max_vars: No minimum/maximum values provided by either input arrays or TArgs"); NDArray m; NDArray m2; if(block.width() == 3){ min = INPUT_VARIABLE(1); max = INPUT_VARIABLE(2); - } else if(block.getTArguments()->size() == 2){ + } else if(block.numT() == 2){ m = NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext()); m2 = NDArrayFactory::create(x->dataType(), T_ARG(1), block.launchContext()); min = &m; @@ -49,10 +49,10 @@ namespace sd { REQUIRE_TRUE(x->dataType() == output->dataType(), 0, "fake_quant_with_min_max_vars: input and output data types must be the same"); int numBits = 8; - if (block.getIArguments() && block.getIArguments()->size()) + if (block.numI()) numBits = INT_ARG(0); bool narrowed = false; - if (block.getBArguments() && block.getBArguments()->size()) { + if (block.numB()) { narrowed = B_ARG(0); } REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars: Number of \ diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp index 4af483e22e48..c4f4b471e33c 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp @@ -44,11 +44,11 @@ namespace sd { REQUIRE_TRUE(x->dataType() == output->dataType(), 0, "fake_quant_with_min_max_vars_per_channel: input and output data types must be the same"); int numBits = 8; - if (block.getIArguments() && block.getIArguments()->size()) + if (block.numI()) numBits = INT_ARG(0); bool narrowed = false; //INT_ARG(1); - if (block.getBArguments() && block.getBArguments()->size()) { + if (block.numB()) { narrowed = B_ARG(0); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp index c32ee1ba995c..b6906bffcc07 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp @@ -32,7 +32,7 @@ namespace sd { int maxOutputSize; // = INT_ARG(0); if (block.width() > 2) maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.getIArguments()->size() == 1) + else if (block.numI() == 1) maxOutputSize = INT_ARG(0); else REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); @@ -43,14 +43,14 @@ namespace sd { if (block.width() > 3) { overlayThreshold = INPUT_VARIABLE(3)->e(0); } - else if (block.getTArguments()->size() > 0) { + else if (block.numT() > 0) { overlayThreshold = T_ARG(0); } if (block.width() > 4) { scoreThreshold = INPUT_VARIABLE(4)->e(0); } - else if (block.getTArguments()->size() > 1) { + else if (block.numT() > 1) { scoreThreshold = T_ARG(1); } if (boxes->isEmpty() || scales->isEmpty()) @@ -85,16 +85,16 @@ namespace sd { int maxOutputSize; if (block.width() > 2) maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.getIArguments()->size() == 1) + else if (block.numI() == 1) maxOutputSize = INT_ARG(0); else REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); if (maxOutputSize > 0) { auto actualIndicesCount = shape::sizeAt(in, 0); - if (block.getTArguments()->size() > 1 || block.width() > 4) { + if (block.numT() > 1 || block.width() > 4) { auto scoreThreshold = - block.getTArguments()->size() > 1 ? T_ARG(1) : INPUT_VARIABLE(4)->e(0); + block.numT() > 1 ? T_ARG(1) : INPUT_VARIABLE(4)->e(0); auto scales = INPUT_VARIABLE(1); scales->syncToHost(); for (auto e = 0; e < scales->lengthOf(); e++) { @@ -130,7 +130,7 @@ namespace sd { int maxOutputSize; // = INT_ARG(0); if (block.width() > 2) maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.getIArguments()->size() == 1) + else if (block.numI() == 1) maxOutputSize = INT_ARG(0); else REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); @@ -141,14 +141,14 @@ namespace sd { if (block.width() > 3) { overlayThreshold = INPUT_VARIABLE(3)->e(0); } - else if (block.getTArguments()->size() > 0) { + else if (block.numT() > 0) { overlayThreshold = T_ARG(0); } if (block.width() > 4) { scoreThreshold = INPUT_VARIABLE(4)->e(0); } - else if (block.getTArguments()->size() > 1) { + else if (block.numT() > 1) { scoreThreshold = T_ARG(1); } if (boxes->isEmpty() || scales->isEmpty()) @@ -183,7 +183,7 @@ namespace sd { int maxOutputSize; if (block.width() > 2) maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.getIArguments()->size() == 1) + else if (block.numI() == 1) maxOutputSize = INT_ARG(0); else REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); @@ -196,14 +196,14 @@ namespace sd { if (block.width() > 3) { overlayThreshold = INPUT_VARIABLE(3)->e(0); } - else if (block.getTArguments()->size() > 0) { + else if (block.numT() > 0) { overlayThreshold = T_ARG(0); } if (block.width() > 4) { scoreThreshold = INPUT_VARIABLE(4)->e(0); } - else if (block.getTArguments()->size() > 1) { + else if (block.numT() > 1) { scoreThreshold = T_ARG(1); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp index a8477c63a9da..36cb02256ce8 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp @@ -32,7 +32,7 @@ namespace sd { int maxOutputSize; // = INT_ARG(0); if (block.width() > 2) maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.getIArguments()->size() == 1) + else if (block.numI() == 1) maxOutputSize = INT_ARG(0); else REQUIRE_TRUE(false, 0, "image.non_max_suppression_overlaps: Max output size argument cannot be retrieved."); @@ -44,9 +44,9 @@ namespace sd { // maxOutputSize = scales->lengthOf(); double overlapThreshold = 0.5; double scoreThreshold = -DataTypeUtils::infOrMax(); - if (block.getTArguments()->size() > 0) + if (block.numT() > 0) overlapThreshold = T_ARG(0); - if (block.getTArguments()->size() > 1) + if (block.numT() > 1) scoreThreshold = T_ARG(1); // TODO: refactor helpers to multithreaded facility @@ -63,7 +63,7 @@ namespace sd { int maxOutputSize; if (block.width() > 2) maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.getIArguments()->size() == 1) + else if (block.numI() == 1) maxOutputSize = INT_ARG(0); else REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp index d3ccff82aac8..1ec70a26e510 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp @@ -36,7 +36,7 @@ namespace sd { // FIXME: double? NDArray shift = NDArrayFactory::create(0.); - if (block.getTArguments()->size() > 0) { + if (block.numT() > 0) { shift.assign(T_ARG(0)); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp index b1b68c23d4ed..10970e965938 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp @@ -27,7 +27,7 @@ namespace sd { auto input = INPUT_VARIABLE(0); auto n = INPUT_VARIABLE(1); bool reverse = false; - if (block.getIArguments()->size() > 0) + if (block.numI() > 0) reverse = (bool)INT_ARG(0); auto output = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp index 75f102fa0cc3..18e271409359 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp @@ -65,7 +65,7 @@ namespace ops { else // cut shift to value between 1 and inputLen - 1 shift %= inputLen; - axes.resize(block.getIArguments()->size() - 1); + axes.resize(block.numI() - 1); if (axes.size()) shifts.resize(axes.size());//emplace_back(shift); else diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp index 6b0402ebb6ce..fa4625288235 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp @@ -31,7 +31,7 @@ namespace sd { //REQUIRE_TRUE(inRank >= 1, 0, "sequence_mask: input array must have rank >= 1, but %i given!", inRank); Nd4jLong maxInd = input->argMax(); float max = input->e(maxInd); - if (block.getIArguments()->size() > 0) { + if (block.numI() > 0) { maxInd = INT_ARG(0); if (maxInd < max) maxInd = static_cast(max); diff --git a/libnd4j/include/ops/declarable/generic/random/gamma.cpp b/libnd4j/include/ops/declarable/generic/random/gamma.cpp index d508e1929bf8..d775aa943235 100644 --- a/libnd4j/include/ops/declarable/generic/random/gamma.cpp +++ b/libnd4j/include/ops/declarable/generic/random/gamma.cpp @@ -41,7 +41,7 @@ namespace sd { auto output = OUTPUT_VARIABLE(0); auto seed = 0; - if (block.getIArguments()->size()) { + if (block.numI()) { seed = INT_ARG(0); } diff --git a/libnd4j/include/ops/declarable/generic/random/multinomial.cpp b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp index 5361d1bbb66a..52f0b3070637 100644 --- a/libnd4j/include/ops/declarable/generic/random/multinomial.cpp +++ b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp @@ -61,7 +61,7 @@ namespace sd { const int rank = input->rankOf(); REQUIRE_TRUE(rank == 2, 0, "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = 2, but got instead rank = %i.", rank); - const int argSize = block.getIArguments()->size(); + const int argSize = block.numI(); const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; auto dimA = (0 == dimC) ? 1 : 0; @@ -91,7 +91,7 @@ namespace sd { const int rank = input->rankOf(); REQUIRE_TRUE(rank == 2, 0, "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = 2, but got instead rank = %i.", rank); - const int argSize = block.getIArguments()->size(); + const int argSize = block.numI(); const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; auto nShape = input->getShapeAsVector(); diff --git a/libnd4j/include/ops/declarable/generic/random/poisson.cpp b/libnd4j/include/ops/declarable/generic/random/poisson.cpp index 74f3a8570659..30fc27ecf508 100644 --- a/libnd4j/include/ops/declarable/generic/random/poisson.cpp +++ b/libnd4j/include/ops/declarable/generic/random/poisson.cpp @@ -33,7 +33,7 @@ namespace sd { auto lambda = INPUT_VARIABLE(1); auto output = OUTPUT_VARIABLE(0); auto seed = 0; - if (block.getIArguments()->size()) { + if (block.numI()) { seed = INT_ARG(0); } rng.setSeed(seed); diff --git a/libnd4j/include/ops/declarable/generic/random/random_crop.cpp b/libnd4j/include/ops/declarable/generic/random/random_crop.cpp index 2ac2495d3651..a52f7b6548cf 100644 --- a/libnd4j/include/ops/declarable/generic/random/random_crop.cpp +++ b/libnd4j/include/ops/declarable/generic/random/random_crop.cpp @@ -36,7 +36,7 @@ CUSTOM_OP_IMPL(random_crop, 2, 1, false, 0, 0) { int seed = 0; - if (block.getIArguments()->size() > 0) + if (block.numI() > 0) seed = INT_ARG(0); REQUIRE_TRUE(shape->isVector(), 0, "random_crop: Shape tensor should be a vector."); diff --git a/libnd4j/include/ops/declarable/generic/random/set_seed.cpp b/libnd4j/include/ops/declarable/generic/random/set_seed.cpp index f4c240d50722..5102321ad8f4 100644 --- a/libnd4j/include/ops/declarable/generic/random/set_seed.cpp +++ b/libnd4j/include/ops/declarable/generic/random/set_seed.cpp @@ -31,7 +31,7 @@ namespace sd { auto rng = block.getRng(); //.getRNG(); Nd4jLong seed = 0; - if (block.getIArguments()->size() > 0) { + if (block.numI() > 0) { seed = INT_ARG(0); } else if (block.width() > 0) { auto input = INPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/generic/random/uniform.cpp b/libnd4j/include/ops/declarable/generic/random/uniform.cpp index 6dec967392c9..918b8da7bf50 100644 --- a/libnd4j/include/ops/declarable/generic/random/uniform.cpp +++ b/libnd4j/include/ops/declarable/generic/random/uniform.cpp @@ -41,7 +41,7 @@ namespace sd { // uniform distribution auto rng = block.randomGenerator(); auto dtype = DataType::FLOAT32; - if (block.getIArguments()->size()) + if (block.numI()) dtype = (DataType)INT_ARG(0); auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr; @@ -75,7 +75,7 @@ namespace sd { auto shape = in->template asVectorT(); auto dtype = DataType::FLOAT32; //ArrayOptions::dataType(inputShape->at(1)); // output type is by given min - if (block.getIArguments()->size()) + if (block.numI()) dtype = (DataType)INT_ARG(0); if (block.width() > 1) REQUIRE_TRUE(dtype == INPUT_VARIABLE(1)->dataType(), 0, "RandomUniform: data type of output and min/max args should be the same"); diff --git a/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp b/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp index 928a0f7d0b51..151b5a022492 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp @@ -37,7 +37,7 @@ namespace sd { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - auto axis = *block.getIArguments(); + auto axis = block.getIArguments(); // axis might be dynamic (i.e. tf mode) if (block.width() > 1 && axis.size() == 0) { @@ -60,7 +60,7 @@ namespace sd { std::vector dims; if (block.width() == 1) { - dims = *block.getIArguments(); + dims = block.getIArguments(); } else { auto y = INPUT_VARIABLE(1); dims = y->template asVectorT(); diff --git a/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp b/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp index f4fb25daa680..ef7f58b5e049 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp @@ -35,7 +35,7 @@ namespace sd { CUSTOM_OP_IMPL(argmin, 1, 1, false, 0, -2) { auto input = INPUT_VARIABLE(0); - auto axis = *block.getIArguments(); + auto axis = block.getIArguments(); auto output = OUTPUT_VARIABLE(0); @@ -60,7 +60,7 @@ namespace sd { std::vector dims; auto in = inputShape->at(0); if (block.width() == 1) { - dims = *block.getIArguments(); + dims = block.getIArguments(); } else { auto y = INPUT_VARIABLE(1); dims = y->template asVectorT(); diff --git a/libnd4j/include/ops/declarable/generic/reduce/norm.cpp b/libnd4j/include/ops/declarable/generic/reduce/norm.cpp index 64c2e5ccb672..6a0b7880b591 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/norm.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/norm.cpp @@ -31,7 +31,7 @@ namespace sd { NDArray *output = OUTPUT_VARIABLE(0); auto mode = (int) T_ARG(0); - std::vector dims = *block.getIArguments(); + std::vector dims = block.getIArguments(); bool overwrite = false; if (block.width() == 1) { @@ -40,14 +40,12 @@ namespace sd { auto axisVector = INPUT_VARIABLE(1); dims.resize(axisVector->lengthOf()); helpers::adjustAxis(input->rankOf(), axisVector, dims); - axisVector->printIndexedBuffer("AXIS"); auto shape = ShapeUtils::evalReduceShapeInfo(input->ordering(), dims, *input, false, false); if (!shape::equalsStrict(shape, output->shapeInfo())) { output = new NDArray(shape, false, block.launchContext()); overwrite = true; } } - output->printShapeInfo("Output Shape Info"); switch(mode) { case 0: { REQUIRE_TRUE(dims.size() == 2 || (input->rankOf() == 2 && dims.size() == 0), 0, "Norm: Frobenius is defined for 2D matrices or TADS only"); @@ -81,7 +79,7 @@ namespace sd { break; default: { // p-norm - REQUIRE_TRUE(block.getIArguments()->size() > 1, 0, "P-Norm reductions requires 2 TArguments, but only 1 was provided"); + REQUIRE_TRUE(block.numI() > 1, 0, "P-Norm reductions requires 2 TArguments, but only 1 was provided"); // FIXME: p is required here //T p = T_ARG(1); input->reduceAlongDimension(reduce::NormP, *output, dims, false, output->rankOf() == 2); diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp index 90560bbb6feb..01de97b5f0ba 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp @@ -30,16 +30,16 @@ CUSTOM_OP_IMPL(reduce_mean, 1, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 1) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } bool keepDims = false; - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_MEAN OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -54,7 +54,7 @@ CUSTOM_OP_IMPL(reduce_mean, 1, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_mean) { - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); auto in = inputShape->at(0); if (block.width() > 1) { auto axesVector = INPUT_VARIABLE(1); @@ -62,9 +62,9 @@ DECLARE_SHAPE_FN(reduce_mean) { } bool keepDims = false; - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); REQUIRE_TRUE(dimensions.size() <= in[0], 0, "REDUCE_MEAN OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -91,16 +91,16 @@ CUSTOM_OP_IMPL(reduce_mean_bp, 2, 1, false, 0, 0) { auto gradI = OUTPUT_VARIABLE(0); - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } bool keepDims = false; - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_MEAN_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -129,7 +129,7 @@ CUSTOM_OP_IMPL(reduce_mean_bp, 2, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_mean_bp) { auto in = inputShape->at(0); - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); auto rank = shape::rank(in); if (block.width() > 2) { diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp index 1682b9d72146..3e0491ffefaf 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp @@ -33,20 +33,20 @@ CUSTOM_OP_IMPL(reduce_stdev, 1, 1, false, 0, 0) { bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; bool biasCorrected = false;//block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 1) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - if (block.getBArguments()->size()) { + if (block.numB()) { keepDims = B_ARG(0); - if (block.getBArguments()->size() > 1) + if (block.numB() > 1) biasCorrected = B_ARG(1); } - else if (block.getTArguments()->size()) { + else if (block.numT()) { keepDims = (bool)T_ARG(0); - if (block.getTArguments()->size() > 1) + if (block.numT() > 1) biasCorrected = (bool)T_ARG(1); } @@ -64,17 +64,17 @@ DECLARE_SHAPE_FN(reduce_stdev) { auto in = inputShape->at(0); auto rank = shape::rank(in); bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 1) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(rank, axesVector, dimensions); } - if (block.getBArguments()->size()) { + if (block.numB()) { keepDims = B_ARG(0); } - else if (block.getTArguments()->size()) { + else if (block.numT()) { keepDims = (bool)T_ARG(0); } @@ -105,20 +105,20 @@ CUSTOM_OP_IMPL(reduce_stdev_bp, 2, 1, false, 0, 0) { bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; bool biasCorrected = false;//block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - if (block.getBArguments()->size()) { + if (block.numB()) { keepDims = B_ARG(0); - if (block.getBArguments()->size() > 1) + if (block.numB() > 1) biasCorrected = B_ARG(1); } - else if (block.getTArguments()->size()) { + else if (block.numT()) { keepDims = (bool)T_ARG(0); - if (block.getTArguments()->size() > 1) + if (block.numT() > 1) biasCorrected = (bool)T_ARG(1); } @@ -150,7 +150,7 @@ CUSTOM_OP_IMPL(reduce_stdev_bp, 2, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_stdev_bp) { auto in = inputShape->at(0); auto rank = shape::rank(in); - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(rank, axesVector, dimensions); diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp index cd7441304f2c..00a4f6d7e8de 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp @@ -33,20 +33,20 @@ CUSTOM_OP_IMPL(reduce_variance, 1, 1, false, 0, 0) { bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; bool biasCorrected = false;//block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 1) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - if (block.getBArguments()->size()) { + if (block.numB()) { keepDims = B_ARG(0); - if (block.getBArguments()->size() > 1) + if (block.numB() > 1) biasCorrected = B_ARG(1); } - else if (block.getTArguments()->size()) { + else if (block.numT()) { keepDims = (bool)T_ARG(0); - if (block.getTArguments()->size() > 1) + if (block.numT() > 1) biasCorrected = (bool)T_ARG(1); } @@ -63,16 +63,16 @@ CUSTOM_OP_IMPL(reduce_variance, 1, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_variance) { bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 1) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); } - if (block.getBArguments()->size()) { + if (block.numB()) { keepDims = B_ARG(0); } - else if (block.getTArguments()->size()) { + else if (block.numT()) { keepDims = (bool)T_ARG(0); } @@ -102,20 +102,20 @@ CUSTOM_OP_IMPL(reduce_variance_bp, 2, 1, false, 0, 0) { bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; bool biasCorrected = false;//block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } // else if (block.getIArguments()->size()) - if (block.getBArguments()->size()) { + if (block.numB()) { keepDims = B_ARG(0); - if (block.getBArguments()->size() > 1) + if (block.numB() > 1) biasCorrected = B_ARG(1); } - else if (block.getTArguments()->size()) { + else if (block.numT()) { keepDims = (bool)T_ARG(0); - if (block.getTArguments()->size() > 1) + if (block.numT() > 1) biasCorrected = (bool)T_ARG(1); } @@ -146,7 +146,7 @@ CUSTOM_OP_IMPL(reduce_variance_bp, 2, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_variance_bp) { auto in = inputShape->at(0); auto rank = shape::rank(in); - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(rank, axesVector, dimensions); diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp index 75cb40ca27bf..7f0c822a8729 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp @@ -48,16 +48,16 @@ CUSTOM_OP_IMPL(reduce_dot_bp, 3, 2, false, 0, 0) { else { bool keepDims = false; - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 3) { auto axesVector = INPUT_VARIABLE(3); helpers::adjustAxis(x->rankOf(), axesVector, dimensions); } - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); REQUIRE_TRUE(dimensions.size() <= x->rankOf(), 0, "REDUCE_DOT_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -87,16 +87,16 @@ DECLARE_SHAPE_FN(reduce_dot_bp) { if(shape::length(inputShape->at(2)) > 1) { bool keepDims = false; - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 3) { auto axesVector = INPUT_VARIABLE(3); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); } - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_DOT_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp index 805db18834f2..2acade3202db 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp @@ -32,14 +32,14 @@ namespace ops { auto axisVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axisVector, axes ); } - else if (block.getIArguments()->size() > 0) { - axes = *block.getIArguments(); + else if (block.numI() > 0) { + axes = block.getIArguments(); } for(const auto& item : axes) REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item shapeInfo()[0], 0, "REDUCE_LOGSUMEXP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - const bool keepDims = block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; Nd4jLong maxI = input->argMax(); auto maxVals = input->e(maxI); //void* whereMax = (void*)(); @@ -58,7 +58,7 @@ namespace ops { } DECLARE_SHAPE_FN(reduce_logsumexp) { - const bool keepDims = block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; auto input = INPUT_VARIABLE(0); std::vector axes; // = *block.getIArguments(); @@ -66,8 +66,8 @@ namespace ops { auto axisVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axisVector, axes ); } - else if (block.getIArguments()->size() > 0) { - axes = *block.getIArguments(); + else if (block.numI() > 0) { + axes = block.getIArguments(); } Nd4jLong* outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), axes, inputShape->at(0), keepDims, false, block.getWorkspace()); diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp index 3d2dbe57ee74..f70e613e8da1 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp @@ -34,7 +34,7 @@ CUSTOM_OP_IMPL(reduce_max, 1, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - std::vector dimensions = *block.getIArguments(); + std::vector dimensions = block.getIArguments(); if (block.width() > 1) { auto axesVector = INPUT_VARIABLE(1); @@ -47,9 +47,9 @@ CUSTOM_OP_IMPL(reduce_max, 1, 1, false, 0, 0) { REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_MAX OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); bool keepDims = false;//: false; - if (block.getBArguments()->size() > 0) + if (block.numB() > 0) keepDims = B_ARG(0); - else if (block.getTArguments()->size() > 0) + else if (block.numT() > 0) keepDims = (bool)T_ARG(0); input->reduceAlongDimension(reduce::Max, *output, dimensions, keepDims); @@ -61,12 +61,12 @@ DECLARE_SHAPE_FN(reduce_max) { bool keepDims = false;//: false; - if (block.getBArguments()->size() > 0) + if (block.numB() > 0) keepDims = B_ARG(0); - else if (block.getTArguments()->size() > 0) + else if (block.numT() > 0) keepDims = (bool)T_ARG(0); - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 1) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); @@ -98,7 +98,7 @@ CUSTOM_OP_IMPL(reduce_max_bp, 2, 1, false, 0, 0) { auto gradO = INPUT_VARIABLE(1); auto gradI = OUTPUT_VARIABLE(0); - std::vector dimensions = *block.getIArguments(); + std::vector dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); @@ -131,7 +131,7 @@ CUSTOM_OP_IMPL(reduce_max_bp, 2, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_max_bp) { - std::vector dimensions = *block.getIArguments(); + std::vector dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp index 254cfe021ec3..85addc6121c4 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp @@ -34,7 +34,7 @@ CUSTOM_OP_IMPL(reduce_min, 1, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - std::vector dimensions = *block.getIArguments(); + std::vector dimensions = block.getIArguments(); if (block.width() > 1) { auto axesVector = INPUT_VARIABLE(1); @@ -47,9 +47,9 @@ CUSTOM_OP_IMPL(reduce_min, 1, 1, false, 0, 0) { REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_MIN OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); bool keepDims = false;//: false; - if (block.getBArguments()->size() > 0) + if (block.numB() > 0) keepDims = B_ARG(0); - else if (block.getTArguments()->size() > 0) + else if (block.numT() > 0) keepDims = (bool)T_ARG(0); input->reduceAlongDimension(reduce::Min, *output, dimensions, keepDims); @@ -61,12 +61,12 @@ DECLARE_SHAPE_FN(reduce_min) { bool keepDims = false;//: false; - if (block.getBArguments()->size() > 0) + if (block.numB() > 0) keepDims = B_ARG(0); - else if (block.getTArguments()->size() > 0) + else if (block.numT() > 0) keepDims = (bool)T_ARG(0); - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 1) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); @@ -101,7 +101,7 @@ CUSTOM_OP_IMPL(reduce_min_bp, 2, 1, false, 0, 0) { auto gradO = INPUT_VARIABLE(1); auto gradI = OUTPUT_VARIABLE(0); - std::vector dimensions = *block.getIArguments(); + std::vector dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); @@ -133,7 +133,7 @@ CUSTOM_OP_IMPL(reduce_min_bp, 2, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_min_bp) { - std::vector dimensions = *block.getIArguments(); + std::vector dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp index 31261fe5ccd6..5f15c0e8e89f 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp @@ -37,8 +37,8 @@ CUSTOM_OP_IMPL(reduce_norm1, 1, 1, false, 0, 0) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); + else if (block.numI()) + dimensions = block.getIArguments(); REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM1 OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -46,9 +46,9 @@ CUSTOM_OP_IMPL(reduce_norm1, 1, 1, false, 0, 0) { REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_NORM1 OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); bool keepDims = false; - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); input->reduceAlongDimension(reduce::Norm1, *output, dimensions, keepDims); @@ -59,9 +59,9 @@ CUSTOM_OP_IMPL(reduce_norm1, 1, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_norm1) { bool keepDims = false; - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); std::vector dimensions; @@ -69,8 +69,8 @@ DECLARE_SHAPE_FN(reduce_norm1) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); + else if (block.numI()) + dimensions = block.getIArguments(); REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM1 OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -108,16 +108,16 @@ CUSTOM_OP_IMPL(reduce_norm1_bp, 2, 1, false, 0, 0) { else { bool keepDims = false; - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM1_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -139,7 +139,7 @@ CUSTOM_OP_IMPL(reduce_norm1_bp, 2, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_norm1_bp) { - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp index c9ea8e374d79..60d7385a1c34 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp @@ -36,8 +36,8 @@ CUSTOM_OP_IMPL(reduce_norm2, 1, 1, false, 0, 0) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); + else if (block.numI()) + dimensions = block.getIArguments(); REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM2 OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -45,9 +45,9 @@ CUSTOM_OP_IMPL(reduce_norm2, 1, 1, false, 0, 0) { REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_NORM2 OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); bool keepDims = false; - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); input->reduceAlongDimension(reduce::Norm2, *output, dimensions, keepDims); @@ -59,9 +59,9 @@ CUSTOM_OP_IMPL(reduce_norm2, 1, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_norm2) { bool keepDims = false; - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); std::vector dimensions; @@ -69,8 +69,8 @@ DECLARE_SHAPE_FN(reduce_norm2) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); + else if (block.numI()) + dimensions = block.getIArguments(); REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM2 OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -105,16 +105,16 @@ CUSTOM_OP_IMPL(reduce_norm2_bp, 2, 1, false, 0, 0) { else { bool keepDims = false; - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM2_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -137,7 +137,7 @@ CUSTOM_OP_IMPL(reduce_norm2_bp, 2, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_norm2_bp) { - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp index b1a0189009e1..4a7d18f17dc9 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp @@ -38,8 +38,8 @@ CUSTOM_OP_IMPL(reduce_norm_max, 1, 1, false, 0, 0) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); + else if (block.numI()) + dimensions = block.getIArguments(); REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM_MAX OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -47,9 +47,9 @@ CUSTOM_OP_IMPL(reduce_norm_max, 1, 1, false, 0, 0) { REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_NORM_MAX OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); bool keepDims = false; - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); input->reduceAlongDimension(reduce::NormMax, *output, dimensions, keepDims); @@ -61,9 +61,9 @@ DECLARE_SHAPE_FN(reduce_norm_max) { auto in = inputShape->at(0); bool keepDims = false; - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); std::vector dimensions; @@ -71,8 +71,8 @@ DECLARE_SHAPE_FN(reduce_norm_max) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); + else if (block.numI()) + dimensions = block.getIArguments(); REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM_MAX OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -98,7 +98,7 @@ CUSTOM_OP_IMPL(reduce_norm_max_bp, 2, 1, false, 0, 0) { auto gradO = INPUT_VARIABLE(1); auto gradI = OUTPUT_VARIABLE(0); - std::vector dimensions = *block.getIArguments(); + std::vector dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); @@ -133,7 +133,7 @@ CUSTOM_OP_IMPL(reduce_norm_max_bp, 2, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_norm_max_bp) { - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp index e873220efa77..b35b5d77fd9d 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp @@ -37,8 +37,8 @@ CUSTOM_OP_IMPL(reduce_prod, 1, 1, false, 0, 0) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); + else if (block.numI()) + dimensions = block.getIArguments(); REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_PROD OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -46,9 +46,9 @@ CUSTOM_OP_IMPL(reduce_prod, 1, 1, false, 0, 0) { REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_PROD OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); bool keepDims = false; - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); input->reduceAlongDimension(reduce::Prod, *output, dimensions, keepDims); @@ -59,9 +59,9 @@ CUSTOM_OP_IMPL(reduce_prod, 1, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_prod) { bool keepDims = false; - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); std::vector dimensions; @@ -69,8 +69,8 @@ DECLARE_SHAPE_FN(reduce_prod) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); + else if (block.numI()) + dimensions = block.getIArguments(); REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_PROD OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -104,16 +104,16 @@ CUSTOM_OP_IMPL(reduce_prod_bp, 2, 1, false, 0, 0) { else { bool keepDims = false; - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM1_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -139,7 +139,7 @@ CUSTOM_OP_IMPL(reduce_prod_bp, 2, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_prod_bp) { - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp index 0c53a261b095..25c8688b2113 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp @@ -33,16 +33,16 @@ CUSTOM_OP_IMPL(reduce_sqnorm, 1, 1, false, 0, 0) { bool keepDims = false; - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 1) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_SQNORM OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -57,7 +57,7 @@ CUSTOM_OP_IMPL(reduce_sqnorm, 1, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_sqnorm) { - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); bool keepDims = false; if (block.width() > 1) { @@ -65,9 +65,9 @@ DECLARE_SHAPE_FN(reduce_sqnorm) { helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); } - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_SQNORM OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -103,16 +103,16 @@ CUSTOM_OP_IMPL(reduce_sqnorm_bp, 2, 1, false, 0, 0) { else { bool keepDims = false; - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_SQNORM_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -135,7 +135,7 @@ DECLARE_SHAPE_FN(reduce_sqnorm_bp) { if(shape::length(inputShape->at(1)) > 1) { - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp index 0f4a5f467556..3aca5854a8a5 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp @@ -37,8 +37,8 @@ CUSTOM_OP_IMPL(reduce_sum, 1, 1, false, 0, 0) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); + else if (block.numI()) + dimensions = block.getIArguments(); REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_SUM OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -46,9 +46,9 @@ CUSTOM_OP_IMPL(reduce_sum, 1, 1, false, 0, 0) { REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_SUM OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); bool keepDims = false; - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); input->reduceAlongDimension(reduce::Sum, *output, dimensions, keepDims); @@ -59,9 +59,9 @@ CUSTOM_OP_IMPL(reduce_sum, 1, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_sum) { bool keepDims = false; - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); std::vector dimensions; @@ -69,8 +69,8 @@ DECLARE_SHAPE_FN(reduce_sum) { auto axesVector = INPUT_VARIABLE(1); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); } - else if (block.getIArguments()->size()) - dimensions = *block.getIArguments(); + else if (block.numI()) + dimensions = block.getIArguments(); REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_SUM OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -101,16 +101,16 @@ CUSTOM_OP_IMPL(reduce_sum_bp, 2, 1, false, 0, 0) { else { bool keepDims = false; - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - if (block.getBArguments()->size()) + if (block.numB()) keepDims = B_ARG(0); - else if (block.getTArguments()->size()) + else if (block.numT()) keepDims = (bool)T_ARG(0); REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_SUM_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); @@ -133,7 +133,7 @@ CUSTOM_OP_IMPL(reduce_sum_bp, 2, 1, false, 0, 0) { DECLARE_SHAPE_FN(reduce_sum_bp) { - auto dimensions = *block.getIArguments(); + auto dimensions = block.getIArguments(); if (block.width() > 2) { auto axesVector = INPUT_VARIABLE(2); helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); diff --git a/libnd4j/include/ops/declarable/generic/shape/permute.cpp b/libnd4j/include/ops/declarable/generic/shape/permute.cpp index f612aec92755..c3065455364e 100644 --- a/libnd4j/include/ops/declarable/generic/shape/permute.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/permute.cpp @@ -40,12 +40,12 @@ CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) { return Status::OK(); //No op } - if (block.width() == 1 && block.getIArguments()->size() == 0) { + if (block.width() == 1 && block.numI() == 0) { z->assign(x->transpose()); return Status::OK(); } - std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getIArguments(); + std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getIArguments(); z->assign(x->permute(permutationVector)); @@ -65,10 +65,10 @@ DECLARE_SHAPE_FN(permute) { auto x = INPUT_VARIABLE(0); - if (block.width() == 1 && block.getIArguments()->size() == 0) + if (block.width() == 1 && block.numI() == 0) return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true)); - std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getIArguments(); + std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getIArguments(); auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true); diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp index ace58a0b88ea..4fee3d9fd2e7 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp @@ -43,12 +43,12 @@ CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) { if (block.width() == 1) { auto arguments = block.getIArguments(); - int argsSize = arguments->size(); + int argsSize = arguments.size(); int e = 1; - char order = (char) -(*arguments)[0]; + char order = (char) -arguments[0]; if (order != 'c' && order != 'f') { order = 'c'; //x->ordering(); e = 0; @@ -58,20 +58,20 @@ CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) { std::vector shapeNew; int e2 = e; - for (; e < (int) arguments->size(); e++) { - if (arguments->at(e) == -1){ + for (; e < (int) arguments.size(); e++) { + if (arguments.at(e) == -1){ Nd4jLong shapeLength = 1; for(; e2 < e; e2++){ - shapeLength *= arguments->at(e2); + shapeLength *= arguments.at(e2); } - for(e2 = e + 1; e2 < arguments->size(); e2++){ - shapeLength *= arguments->at(e2); + for(e2 = e + 1; e2 < arguments.size(); e2++){ + shapeLength *= arguments.at(e2); } Nd4jLong realShape = x->lengthOf() / shapeLength; shapeNew.push_back(realShape); } else{ - shapeNew.push_back(arguments->at(e)); + shapeNew.push_back(arguments.at(e)); } } @@ -156,10 +156,10 @@ DECLARE_SHAPE_FN(reshape) { // we can launch op using Int arguments if (inputShape->size() == 1) { REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined"); - std::vector *arguments = block.getIArguments(); + std::vector arguments = block.getIArguments(); int e = 1; - char order = (char) -(*arguments)[0]; + char order = (char) -arguments[0]; if (order != 'c' && order != 'f') { order = shape::order(inp); e = 0; @@ -168,16 +168,16 @@ DECLARE_SHAPE_FN(reshape) { std::vector shapeNew; int e2 = e; - for (; e < (int) arguments->size(); e++) { - if ((int) arguments->at(e) == -1){ + for (; e < (int) arguments.size(); e++) { + if ((int) arguments.at(e) == -1){ Nd4jLong shapeLength = 1; for(; e2 < e; e2 ++){ - shapeLength *= arguments->at(e2); + shapeLength *= arguments.at(e2); } - for(e2 = e + 1; e2 < arguments->size(); e2++){ - REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - shapeLength *= arguments->at(e2); + for(e2 = e + 1; e2 < arguments.size(); e2++){ + REQUIRE_TRUE(arguments.at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); + shapeLength *= arguments.at(e2); } if(shapeLength == 0){ @@ -190,7 +190,7 @@ DECLARE_SHAPE_FN(reshape) { } } else{ - shapeNew.push_back(arguments->at(e)); + shapeNew.push_back(arguments.at(e)); } } diff --git a/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp b/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp index 687d79f25baf..03ef8b9ab44a 100644 --- a/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp @@ -30,7 +30,7 @@ namespace ops { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - std::vector outShape(block.getIArguments()->begin(), block.getIArguments()->end()); + std::vector outShape(block.getIArguments().begin(), block.getIArguments().end()); if (block.isInplace()) { input->tileToShape(outShape, *input); @@ -46,7 +46,7 @@ namespace ops { // output shape always equals to arguments - auto conv = ArrayUtils::toLongVector(*block.getIArguments()); + auto conv = ArrayUtils::toLongVector(block.getIArguments()); auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in), shape::order(in), conv); diff --git a/libnd4j/include/ops/declarable/generic/shape/transpose.cpp b/libnd4j/include/ops/declarable/generic/shape/transpose.cpp index 0b12f415ffc6..bf48544bbbd9 100644 --- a/libnd4j/include/ops/declarable/generic/shape/transpose.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/transpose.cpp @@ -40,12 +40,12 @@ CUSTOM_OP_IMPL(transpose, 1, 1, false, 0, 0) { return Status::OK(); //No op } - if (block.width() == 1 && block.getIArguments()->size() == 0) { + if (block.width() == 1 && block.numI() == 0) { z->assign(x->transpose()); return Status::OK(); } - std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getIArguments(); + std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getIArguments(); z->assign(x->permute(permutationVector)); @@ -62,10 +62,10 @@ DECLARE_SHAPE_FN(transpose) { auto x = INPUT_VARIABLE(0); - if (block.width() == 1 && block.getIArguments()->size() == 0) + if (block.width() == 1 && block.numI() == 0) return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true)); - std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getIArguments(); + std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getIArguments(); auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true); diff --git a/libnd4j/include/ops/declarable/generic/tensor/range.cpp b/libnd4j/include/ops/declarable/generic/tensor/range.cpp index a39e0791295e..987b7424d235 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/range.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/range.cpp @@ -32,8 +32,8 @@ CUSTOM_OP_IMPL(range, -2, 1, false, -2, -2) { auto output = OUTPUT_VARIABLE(0); const int numInArrs = block.width(); - const int numTArgs = block.getTArguments()->size(); - const int numIArgs = block.getIArguments()->size(); + const int numTArgs = block.numT(); + const int numIArgs = block.numI(); NDArray *s = nullptr; NDArray *d = nullptr; @@ -126,8 +126,8 @@ CUSTOM_OP_IMPL(range, -2, 1, false, -2, -2) { DECLARE_SHAPE_FN(range) { const int numInArrs = block.width(); - const int numTArgs = block.getTArguments()->size(); - const int numIArgs = block.getIArguments()->size(); + const int numTArgs = block.numT(); + const int numIArgs = block.numI(); Nd4jLong steps = 0; sd::DataType dataType = block.numD() ? D_ARG(0) : sd::DataType::INHERIT; diff --git a/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp b/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp index 747331ef0dc1..6466b9c1cc7e 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp @@ -325,12 +325,12 @@ namespace sd { std::vector args; // statically evaluated - if (block.getIArguments()->size() > 5) { - dim_values = block.getIArguments()->size() - 5; + if (block.numI() > 5) { + dim_values = block.numI() - 5; delta = dim_values % 3; elements = dim_values / 3; - for (int e = 5; e < block.getIArguments()->size(); e++) + for (int e = 5; e < block.numI(); e++) args.emplace_back(INT_ARG(e)); REQUIRE_TRUE(delta == 0, 0, "StridedSlice: Number of Integer arguments should be equal to input rank x 3 = %i, but got %i instead", (x->rankOf() * 3), dim_values); @@ -448,7 +448,7 @@ namespace sd { int x_rank = shape::rank(inShape); - int dim_values = block.getIArguments()->size() - 5; + int dim_values = block.numI() - 5; int delta = dim_values % 3; int elements = dim_values / 3; @@ -466,7 +466,7 @@ namespace sd { int delta2 = dim_values / x_rank; std::vector args; - for (int e = 5; e < block.getIArguments()->size(); e++) + for (int e = 5; e < block.numI(); e++) args.emplace_back(INT_ARG(e)); // FIXME: propably template required here @@ -558,12 +558,12 @@ namespace sd { std::vector args; // statically evaluated - if (block.getIArguments()->size() > 5) { - dim_values = block.getIArguments()->size() - 5; + if (block.numI() > 5) { + dim_values = block.numI() - 5; delta = dim_values % 3; elements = dim_values / 3; - for (int e = 5; e < block.getIArguments()->size(); e++) + for (int e = 5; e < block.numI(); e++) args.emplace_back(INT_ARG(e)); REQUIRE_TRUE(delta == 0, 0, "StridedSliceBP: Number of Integer arguments should be equal to input rank x 3 = %i, but got %i instead", (x->rankOf() * 3), dim_values); diff --git a/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp b/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp index 7860036ed528..dfb9a9fc2916 100644 --- a/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp +++ b/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp @@ -53,7 +53,7 @@ namespace sd { int batchSize = x->sizeAt(0); int numColumns = x->sizeAt(1); - std::vector indices(*block.getIArguments()); + std::vector indices(block.getIArguments()); std::map sparse2dense; @@ -92,7 +92,7 @@ namespace sd { DECLARE_SHAPE_FN(firas_sparse) { auto inP = inputShape->at(0); - std::vector shape({shape::shapeOf(inP)[0], (Nd4jLong) block.getIArguments()->size()}); + std::vector shape({shape::shapeOf(inP)[0], (Nd4jLong) block.numI()}); auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inP), 'c', shape); return SHAPELIST(newShape); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp b/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp index 958a90410b6c..2bfe06dd9f40 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp @@ -35,7 +35,7 @@ CONFIGURABLE_OP_IMPL(clipbyavgnorm, 1, 1, true, 1, 0) { const bool isInplace = block.isInplace(); auto ts = NDArrayFactory::create(T_ARG(0), block.launchContext()); - helpers::clipByAveraged(block.launchContext(), *input, *output, *block.getIArguments(), ts, isInplace); + helpers::clipByAveraged(block.launchContext(), *input, *output, block.getIArguments(), ts, isInplace); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp b/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp index 43b23ba18be3..d062d47105c1 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp @@ -34,7 +34,7 @@ namespace ops { const auto clipNorm = NDArrayFactory::create(input->dataType(), T_ARG(0), block.launchContext()); const bool isInplace = block.isInplace(); - helpers::clipByNorm(block.launchContext(), *input, *output, *block.getIArguments(), clipNorm, isInplace); + helpers::clipByNorm(block.launchContext(), *input, *output, block.getIArguments(), clipNorm, isInplace); return Status::OK(); } @@ -47,7 +47,7 @@ namespace ops { auto gradI = OUTPUT_VARIABLE(0); const auto clipNorm = NDArrayFactory::create(T_ARG(0)); - helpers::clipByNormBP(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm); + helpers::clipByNormBP(block.launchContext(), *input, *gradO, *gradI, block.getIArguments(), clipNorm); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp index 0b171b36f623..5e587e40d453 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp @@ -32,7 +32,7 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) { REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided"); - const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); @@ -126,7 +126,7 @@ DECLARE_SHAPE_FN(concat) { REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided"); - const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); @@ -376,7 +376,7 @@ DECLARE_SHAPE_FN(concat) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(concat_bp, -1, -1, false, 0, 0) { - const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); @@ -419,7 +419,7 @@ DECLARE_TYPES(concat_bp) { DECLARE_SHAPE_FN(concat_bp) { - const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp b/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp index c0b011f997f1..79644f51b8d4 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp @@ -40,7 +40,7 @@ namespace sd { const bool exclusive = INT_ARG(0) == 1; const bool reverse = INT_ARG(1) == 1; - if (block.getIArguments()->size() == 2 && block.width() == 1) { + if (block.numI() == 2 && block.width() == 1) { // all at once case sd::ops::helpers::prefix(block.launchContext(), scalar::Multiply, input, output, exclusive, reverse); } else { diff --git a/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp b/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp index 97389fddbfb5..712eaca70392 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp @@ -41,7 +41,7 @@ CONFIGURABLE_OP_IMPL(cumsum, 1, 1, true, 0, 2) { return Status::OK(); } - if (block.getIArguments()->size() == 2 && block.width() == 1) { + if (block.numI() == 2 && block.width() == 1) { // all at once case sd::ops::helpers::prefix(block.launchContext(), scalar::Add, input, output, exclusive, reverse); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/gather.cpp b/libnd4j/include/ops/declarable/generic/transforms/gather.cpp index 79ce8ad29fff..848f72c998fe 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/gather.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/gather.cpp @@ -37,7 +37,7 @@ CUSTOM_OP_IMPL(gather, 1, 1, false, 0, -2) { auto indices = block.width() > 1 ? INPUT_VARIABLE(1) : nullptr; auto output = OUTPUT_VARIABLE(0); - const bool checkIndices = block.getBArguments()->empty() ? false : B_ARG(0); + const bool checkIndices = block.getBArguments().empty() ? false : B_ARG(0); //Edge case: empty indices -> empty output if(indices != nullptr && indices->isEmpty()){ @@ -56,7 +56,7 @@ CUSTOM_OP_IMPL(gather, 1, 1, false, 0, -2) { intArgs.emplace_back(0); else for (int i = 0; i < numOfIntArgs; ++i) - intArgs.emplace_back(block.getIArguments()->at(i)); + intArgs.emplace_back(block.getIArguments().at(i)); } const int inputRank = input->rankOf(); @@ -101,7 +101,7 @@ DECLARE_SHAPE_FN(gather) { if (block.width() > 2) { axis = INPUT_VARIABLE(2)->e(0); } else - axis = block.numI() > 0 ? block.getIArguments()->at(0) : 0; + axis = block.numI() > 0 ? block.getIArguments().at(0) : 0; int inputRank = shape::rank(inputShapeInfo); if(axis < 0) diff --git a/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp b/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp index 30b5b19ef252..cf0081c76046 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp @@ -36,7 +36,7 @@ CUSTOM_OP_IMPL(gather_nd, 2, 1, false, 0, 0) { auto indices = INPUT_VARIABLE(1); auto output = OUTPUT_VARIABLE(0); - const bool checkIndices = block.getBArguments()->empty() ? false : B_ARG(0); + const bool checkIndices = block.getBArguments().empty() ? false : B_ARG(0); const int rankIn = input->rankOf(); const int rankInd = indices->rankOf(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp b/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp index 36175fc0181d..7c8bb4c6c4c5 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp @@ -33,7 +33,7 @@ CUSTOM_OP_IMPL(histogram_fixed_width, 2, 1, false, 0, 0) { auto range = INPUT_VARIABLE(1); auto output = OUTPUT_VARIABLE(0); - const int nbins = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : block.getIArguments()->empty() ? 100 : INT_ARG(0); + const int nbins = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : block.getIArguments().empty() ? 100 : INT_ARG(0); const double leftEdge = range->e(0); const double rightEdge = range->e(1); @@ -56,7 +56,7 @@ DECLARE_TYPES(histogram_fixed_width) { ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(histogram_fixed_width) { - const int nbins = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : block.getIArguments()->empty() ? 100 : INT_ARG(0); + const int nbins = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : block.getIArguments().empty() ? 100 : INT_ARG(0); auto outShapeInfo = ConstantShapeHelper::getInstance()->vectorShapeInfo(nbins, DataType::INT64); return SHAPELIST(outShapeInfo); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp b/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp index 1ffe42f4ba27..0a4e15ff44f9 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp @@ -52,7 +52,7 @@ DECLARE_SYN(MergeMaxIndex, mergemaxindex); DECLARE_SHAPE_FN(mergemaxindex) { auto in = inputShape->at(0); auto dtype = DataType::INT32; - if (block.getIArguments()->size()> 0) + if (block.numI() > 0) dtype = (DataType)INT_ARG(0); auto resShape = ShapeBuilders::copyShapeInfoAndType(in, dtype, block.workspace()); diff --git a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp index d5d38aaeb715..9c08662857b7 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp @@ -51,7 +51,7 @@ CUSTOM_OP_IMPL(pad, 2, 1, false, 0, 1) { REQUIRE_TRUE(input->dataType() == INPUT_VARIABLE(2)->dataType(), 0, "PAD op: data types of input and padValue arrays should be the same but got %i and %i correspondingly !", input->dataType(), INPUT_VARIABLE(2)->dataType()); padValue.assign(INPUT_VARIABLE(2)->e(0)); } - else if (!block.getTArguments()->empty()) + else if (!block.getTArguments().empty()) padValue = T_ARG(0); } else if(INT_ARG(0) == 1) { // REFLECT mode diff --git a/libnd4j/include/ops/declarable/generic/transforms/repeat.cpp b/libnd4j/include/ops/declarable/generic/transforms/repeat.cpp index 99ab3d635cc9..4c121d115df8 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/repeat.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/repeat.cpp @@ -33,7 +33,7 @@ CUSTOM_OP_IMPL(repeat, 1, 1, true, 0, -1) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - std::vector repeats = *block.getIArguments(); + std::vector repeats = block.getIArguments(); const int axis = repeats.back() < 0 ? repeats.back() + input->rankOf() : repeats.back(); @@ -58,7 +58,7 @@ DECLARE_SHAPE_FN(repeat) { auto input = INPUT_VARIABLE(0); - std::vector repeats = *block.getIArguments(); + auto repeats = block.getIArguments(); const int axis = repeats.back() < 0 ? repeats.back() + input->rankOf() : repeats.back(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp index ceb953979fe5..da2307562130 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp @@ -43,7 +43,7 @@ namespace ops { if (block.width() > 1) axis = INPUT_VARIABLE(1)->template asVectorT(); else if (block.numI() > 0) - axis = *block.getIArguments(); + axis = block.getIArguments(); if(axis.empty()) { // do not perform reversion if (!block.isInplace()) @@ -76,7 +76,7 @@ namespace ops { if (block.width() == 3) axis = INPUT_VARIABLE(1)->template asVectorT(); else if (block.numI() > 0) - axis = *block.getIArguments(); + axis = block.getIArguments(); if(axis.empty()) { // reversion is not performed in this case output->assign(eps); diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_add.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_add.cpp index e624afeb1e57..09a0efc5585e 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_add.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_add.cpp @@ -38,8 +38,8 @@ OP_IMPL(scatter_add, 3, 1, true) { if (!block.isInplace()) output->assign(input); - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_div.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_div.cpp index fd0b2a7305a3..a4febc684483 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_div.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_div.cpp @@ -38,8 +38,8 @@ namespace sd { if (!block.isInplace()) output->assign(input); - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_max.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_max.cpp index b3342c5a58ac..220c68227b1e 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_max.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_max.cpp @@ -38,8 +38,8 @@ OP_IMPL(scatter_max, 3, 1, true) { if (!block.isInplace()) output->assign(input); - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_min.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_min.cpp index d37adb692a4e..957541e5e74f 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_min.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_min.cpp @@ -38,8 +38,8 @@ OP_IMPL(scatter_min, 3, 1, true) { if (!block.isInplace()) output->assign(input); - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_mul.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_mul.cpp index 9bf5be7487b8..24e329f6da31 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_mul.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_mul.cpp @@ -34,8 +34,8 @@ namespace sd { auto output = OUTPUT_VARIABLE(0); - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp index 7c2194c6c9da..564a8f478582 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp @@ -34,8 +34,8 @@ namespace ops { auto output = OUTPUT_VARIABLE(0); - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); const int indRank = indices->rankOf(); const int updRank = updates->rankOf(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_add.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_add.cpp index 8fb4288ee86d..691894a9ec72 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_add.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_add.cpp @@ -34,8 +34,8 @@ OP_IMPL(scatter_nd_add, 3, 1, true) { auto output = OUTPUT_VARIABLE(0); - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_sub.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_sub.cpp index 6cfa5d0463c9..d9db29bb132f 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_sub.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_sub.cpp @@ -34,8 +34,8 @@ OP_IMPL(scatter_nd_sub, 3, 1, true) { auto output = OUTPUT_VARIABLE(0); - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_update.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_update.cpp index b6122c724937..9f3f6be0849c 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_update.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_update.cpp @@ -34,8 +34,8 @@ OP_IMPL(scatter_nd_update, 3, 1, true) { auto output = OUTPUT_VARIABLE(0); - const bool lock = block.getBArguments()->empty() ? true : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); + const bool lock = block.getBArguments().empty() ? true : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_sub.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_sub.cpp index c955ac04221a..8ceb5de5fe88 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_sub.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_sub.cpp @@ -37,8 +37,8 @@ namespace sd { if (!block.isInplace()) output->assign(input); - const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_upd.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_upd.cpp index ef54b98138e0..d48a86fdc142 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_upd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_upd.cpp @@ -36,8 +36,8 @@ namespace sd { if (!block.isInplace()) output->assign(input); - const bool lock = block.getBArguments()->empty() ? true : B_ARG(0); - const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); + const bool lock = block.getBArguments().empty() ? true : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_update.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_update.cpp index d15b4c85949d..72b087e7baa9 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_update.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_update.cpp @@ -43,7 +43,7 @@ namespace sd { auto operand = INPUT_VARIABLE(0); auto updates = INPUT_VARIABLE(1); - helpers::scatterUpdate(block.launchContext(), *operand, *updates, block.getIArguments()); + helpers::scatterUpdate(block.launchContext(), *operand, *updates, &block.getIArguments()); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/slice.cpp b/libnd4j/include/ops/declarable/generic/transforms/slice.cpp index dc4671ef72fb..bd8423504bc6 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/slice.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/slice.cpp @@ -44,8 +44,10 @@ namespace sd { } else { REQUIRE_TRUE(block.numI() >= x_rank * 2, 0, "Number of IArgs should be equal to [%i] but got [%i] instead", x_rank * 2, block.numI()); - ShapeUtils::copyVectorPart(begin, *(block.getIArguments()), x_rank, 0); - ShapeUtils::copyVectorPart(sz, *(block.getIArguments()), x_rank, x_rank); + auto vec = block.getIArguments(); + + ShapeUtils::copyVectorPart(begin, vec, x_rank, 0); + ShapeUtils::copyVectorPart(sz, vec, x_rank, x_rank); } REQUIRE_TRUE(begin.size() == x_rank, 0, "begin array should have length of [%i] but got [%i] instead", x_rank, begin.size()); @@ -131,8 +133,9 @@ namespace sd { } else { REQUIRE_TRUE(block.numI() >= x_rank * 2, 0, "Number of IArgs should be equal to [%i] but got [%i] instead", x_rank * 2, block.numI()); - ShapeUtils::copyVectorPart(begin, *(block.getIArguments()), x_rank, 0); - ShapeUtils::copyVectorPart(sz, *(block.getIArguments()), x_rank, x_rank); + auto vec = block.getIArguments(); + ShapeUtils::copyVectorPart(begin, vec, x_rank, 0); + ShapeUtils::copyVectorPart(sz, vec, x_rank, x_rank); } REQUIRE_TRUE(begin.size() == x_rank, 0, "Begin array should have length of [%i] but got [%i] instead", x_rank, begin.size()); @@ -191,8 +194,10 @@ namespace sd { } else { REQUIRE_TRUE(block.numI() >= x_rank * 2, 0, "Number of IArgs should be equal to [%i] but got [%i] instead", x_rank * 2, block.numI()); - ShapeUtils::copyVectorPart(begin, *(block.getIArguments()), x_rank, 0); - ShapeUtils::copyVectorPart(end, *(block.getIArguments()), x_rank, x_rank); + auto vec = block.getIArguments(); + + ShapeUtils::copyVectorPart(begin, vec, x_rank, 0); + ShapeUtils::copyVectorPart(end, vec, x_rank, x_rank); } REQUIRE_TRUE(begin.size() == x_rank, 0, "begin array should have length of [%i] but got [%i] instead", x_rank, begin.size()); diff --git a/libnd4j/include/ops/declarable/generic/transforms/split_v.cpp b/libnd4j/include/ops/declarable/generic/transforms/split_v.cpp index 0bda3a6bef6a..a3499e93c7aa 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/split_v.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/split_v.cpp @@ -31,7 +31,7 @@ namespace ops { int axis = 0; - if (block.getIArguments()->size() > 0) { + if (block.numI() > 0) { axis = INT_ARG(0); } else if (block.width() > 2){ auto _a = INPUT_VARIABLE(2); @@ -88,7 +88,7 @@ namespace ops { // 0 is just default axis int axis = 0; - if (block.getIArguments()->size() > 0) + if (block.numI() > 0) axis = INT_ARG(0); else if (block.width() > 2) { auto _a = INPUT_VARIABLE(2); diff --git a/libnd4j/include/ops/declarable/generic/transforms/stack.cpp b/libnd4j/include/ops/declarable/generic/transforms/stack.cpp index a78442b03449..2197a9d72bed 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/stack.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/stack.cpp @@ -30,7 +30,7 @@ namespace ops { CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + int dim = block.numI() > 0 ? INT_ARG(0) : 0; if(dim < 0) dim += input->rankOf() + 1; @@ -70,7 +70,7 @@ DECLARE_SHAPE_FN(stack) { // check whether input dimension is within rank range auto inShapeInfo = inputShape->at(0); int rank = shape::rank(inShapeInfo); - int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + int dim = block.numI() > 0 ? INT_ARG(0) : 0; if(dim < 0 ) dim += rank + 1; diff --git a/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp b/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp index f4e8a6f7acc8..4df64d20f780 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp @@ -38,7 +38,7 @@ namespace ops { if (block.width() > 1) axis = INPUT_VARIABLE(1)->template asVectorT(); else if (block.numI() > 0) - axis = *block.getIArguments(); + axis = block.getIArguments(); REQUIRE_TRUE(!axis.empty(), 0, "STANDARDIZE OP: axis has to be non-empty") @@ -72,7 +72,7 @@ namespace ops { if (block.width() == 3) axis = INPUT_VARIABLE(1)->template asVectorT(); else if (block.numI() > 0) - axis = *block.getIArguments(); + axis = block.getIArguments(); REQUIRE_TRUE(!axis.empty(), 0, "STANDARDIZE OP: axis has to be non-empty") diff --git a/libnd4j/include/ops/declarable/generic/transforms/tear.cpp b/libnd4j/include/ops/declarable/generic/transforms/tear.cpp index 61850ab0eea6..2c850bc93a86 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/tear.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/tear.cpp @@ -30,9 +30,9 @@ namespace sd { CUSTOM_OP_IMPL(tear, 1, -1, false, 0, -1) { auto input = INPUT_VARIABLE(0); - REQUIRE_TRUE(!block.getIArguments()->empty(), 0, "At least 1 dimension should be specified for Tear"); + REQUIRE_TRUE(!block.getIArguments().empty(), 0, "At least 1 dimension should be specified for Tear"); - std::vector dims(*block.getIArguments()); + std::vector dims(block.getIArguments()); for (auto &v: dims) REQUIRE_TRUE(v >= 0 && v < input->rankOf(), 0, "Tear dimensions should be non-negative values, and lower then input rank. Got %i instead", v); @@ -52,7 +52,7 @@ namespace sd { DECLARE_SHAPE_FN(tear) { auto inShape = inputShape->at(0); - std::vector dims(*block.getIArguments()); + std::vector dims(block.getIArguments()); if (dims.size() > 1) std::sort(dims.begin(), dims.end()); diff --git a/libnd4j/include/ops/declarable/generic/transforms/tile.cpp b/libnd4j/include/ops/declarable/generic/transforms/tile.cpp index 6041d1c41a26..ffc64da6c520 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/tile.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/tile.cpp @@ -36,9 +36,9 @@ CUSTOM_OP_IMPL(tile, 1, 1, false, 0, -2) { const int inRank = input->rankOf(); std::vector reps; - if (block.getIArguments()->size() == inRank) { + if (block.numI() == inRank) { - reps = ArrayUtils::toLongVector(*(block.getIArguments())); + reps = ArrayUtils::toLongVector(block.getIArguments()); } else if (block.width() > 1) { @@ -72,9 +72,9 @@ DECLARE_SHAPE_FN(tile) { const int inRank = inShape[0]; std::vector reps; - if (block.getIArguments()->size() == inRank) { + if (block.numI() == inRank) { - reps = ArrayUtils::toLongVector(*(block.getIArguments())); + reps = ArrayUtils::toLongVector(block.getIArguments()); } else if (block.width() > 1) { @@ -109,9 +109,9 @@ CUSTOM_OP_IMPL(tile_bp, 2, 1, false, 0, -2) { std::vector reps; - if (block.getIArguments()->size() == inRank) { + if (block.numI() == inRank) { - reps = ArrayUtils::toLongVector(*(block.getIArguments())); + reps = ArrayUtils::toLongVector(block.getIArguments()); } else if (block.width() > 2) { @@ -151,9 +151,9 @@ DECLARE_SHAPE_FN(tile_bp) { std::vector reps; - if (block.getIArguments()->size() == inRank) { + if (block.numI() == inRank) { - reps = ArrayUtils::toLongVector(*(block.getIArguments())); + reps = ArrayUtils::toLongVector(block.getIArguments()); } else if (block.width() > 2) { auto reps_vector = INPUT_VARIABLE(1); diff --git a/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp b/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp index 82dd8c36e779..f229975dfeef 100644 --- a/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp +++ b/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp @@ -37,7 +37,7 @@ namespace ops { auto outputCols = OUTPUT_VARIABLE(1); auto outputVals = OUTPUT_VARIABLE(2); - if (block.getIArguments()->size() > 0) + if (block.numI() > 0) N = INT_ARG(0); if (rowCountsPtr) { @@ -65,8 +65,9 @@ namespace ops { auto rowP = INPUT_VARIABLE(0); auto colP = INPUT_VARIABLE(1); auto N = rowP->lengthOf() - 1; - if (block.getIArguments()->size() > 0) + if (block.numI() > 0) N = INT_ARG(0); + auto dataType = rowP->dataType(); //ArrayOptions::dataType(inputShape->at(0)); NDArray* rowCounts = NDArrayFactory::create_('c', {N}); //rowP->dup(); //srowCounts->assign(0); diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp index bab2055432c5..a7aef8c8b4e4 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp @@ -47,7 +47,7 @@ namespace sd { " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), ShapeUtils::shapeAsString(initStateMsdx->getShapeInfo()).c_str()); - bool bParamsSupply = 5 == block.width() || 2 == block.getTArguments()->size(); + bool bParamsSupply = 5 == block.width() || 2 == block.numT(); REQUIRE_TRUE(bParamsSupply, 0, "ADA_DELTA UPDATER OP: Rho and epsilon were not provided!"); diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp index a7a92b4104bf..084716e22392 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp @@ -43,7 +43,7 @@ namespace sd { ShapeUtils::shapeAsString(initState->getShapeInfo()).c_str()); - bool bParamsSupply = 4 == block.width() || 2 == block.getTArguments()->size(); + bool bParamsSupply = 4 == block.width() || 2 == block.numT(); REQUIRE_TRUE(bParamsSupply, 0, "ADA_GRAD UPDATER OP: learning rate and epsilon were not provided!"); diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp index 4e34c24f643a..38ca6e8b1679 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp @@ -49,9 +49,9 @@ namespace sd { ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str()); - bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size(); + bool bParamsSupply = 7 == block.width() || 4 == block.numT(); - int iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + int iteration = block.numI() > 0 ? INT_ARG(0) : 0; REQUIRE_TRUE(bParamsSupply, 0, "ADA_MAX UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); diff --git a/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp index a696d2388729..31fbfde5b3e3 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp @@ -48,9 +48,9 @@ namespace sd { " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str()); - bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size(); + bool bParamsSupply = 7 == block.width() || 4 == block.numT(); - auto iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + auto iteration = block.numI() > 0 ? INT_ARG(0) : 0; REQUIRE_TRUE(bParamsSupply, 0, "ADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); diff --git a/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp index bc0f4beac51c..1b6f401ac3ae 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp @@ -53,9 +53,9 @@ namespace sd { " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), ShapeUtils::shapeAsString(initStateH->getShapeInfo()).c_str()); - bool bParamsSupply = 8 == block.width() || 4 == block.getTArguments()->size(); + bool bParamsSupply = 8 == block.width() || 4 == block.numT(); - auto iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + auto iteration = block.numI() > 0 ? INT_ARG(0) : 0; REQUIRE_TRUE(bParamsSupply, 0, "AMSGRAD UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); diff --git a/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp index c6af0686be9e..3e7267c508b3 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp @@ -48,9 +48,9 @@ namespace sd { " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), ShapeUtils::shapeAsString(initStateV->getShapeInfo()).c_str()); - bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size(); + bool bParamsSupply = 7 == block.width() || 4 == block.numT(); - auto nIteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + auto nIteration = block.numI() > 0 ? INT_ARG(0) : 0; REQUIRE_TRUE(bParamsSupply, 0, "NADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); diff --git a/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp index c77abd448b5d..98cedbb299b4 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp @@ -42,7 +42,7 @@ namespace sd { " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), ShapeUtils::shapeAsString(initState->getShapeInfo()).c_str()); - bool bParamsSupply = 4 == block.width() || 2 == block.getTArguments()->size(); + bool bParamsSupply = 4 == block.width() || 2 == block.numT(); REQUIRE_TRUE(bParamsSupply, 0, "NESTEROVS UPDATER OP: learning rate and momentum were not provided!"); diff --git a/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp index 1ca318e26568..9f4e214e07df 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp @@ -42,7 +42,7 @@ namespace sd { " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), ShapeUtils::shapeAsString(initState->getShapeInfo()).c_str()); - bool bParamsSupply = 5 == block.width() || 3 == block.getTArguments()->size(); + bool bParamsSupply = 5 == block.width() || 3 == block.numT(); REQUIRE_TRUE(bParamsSupply, 0, "RSM_PROB UPDATER OP: learning rate, rsm decay and epsilon were not provided!"); diff --git a/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp index 491d7b53e203..829db3dcee9c 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp @@ -35,7 +35,7 @@ namespace sd { if (input->isEmpty()) return Status::OK(); - bool bLearningRate = 2 == block.width() || 1 == block.getTArguments()->size(); + bool bLearningRate = 2 == block.width() || 1 == block.numT(); REQUIRE_TRUE(bLearningRate, 0, "SGD UPDATER OP: Learning rate was not provided!"); diff --git a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp index 013c7c143d78..e06c39c2ed88 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp @@ -124,11 +124,11 @@ namespace sd { block.fillInputs(in); for (int e = 0; e < tArgs.size(); e++) - block.getTArguments()->emplace_back(tArgs.at(e)); + block.appendT(tArgs.at(e)); for (int e = 0; e < iArgs.size(); e++) - block.getIArguments()->emplace_back(iArgs.at(e)); + block.appendI(iArgs.at(e)); Nd4jStatus result = this->validateAndExecute(block); diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 212617ec70fc..aca97b798652 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -170,7 +170,7 @@ namespace sd { if (fp) { // } else { - for (auto p: *ctx.inputs()) { + for (auto p: ctx.inputs()) { auto var = ctx.variable(p); if (var->variableType() == VariableType::NDARRAY) { NDArray *array = var->getNDArray(); @@ -187,7 +187,7 @@ namespace sd { int cnt = 0; auto id = ctx.nodeId(); auto vs = ctx.getVariableSpace(); - for (auto p: *ctx.inputs()) { + for (auto p: ctx.inputs()) { auto var = ctx.variable(p); if (var->variableType() == VariableType::NDARRAY) { NDArray *array = var->getNDArray(); @@ -238,7 +238,7 @@ namespace sd { } } else { int arrCnt = 0; - for (auto p: *ctx.inputs()) { + for (auto p: ctx.inputs()) { auto var = ctx.variable(p); if (var->variableType() == VariableType::NDARRAY) { NDArray *array = var->getNDArray(); @@ -456,7 +456,7 @@ namespace sd { cnt++; } } else { - for (auto &p: *(block.inputs())) { + for (auto &p: block.inputs()) { auto var = block.variable(p); // we're not checking validity, if ANY types were explicitly allowed @@ -737,25 +737,25 @@ namespace sd { * If number of args is variable (-1), but variables MUST be present - we check for non-zero number of arguments */ if (_descriptor->getNumberOfTArgs() > 0) { - if ((int) block.getTArguments()->size() < _descriptor->getNumberOfTArgs()) { - nd4j_printf("%s: %i T args expected, but %i received\n", this->getOpName().c_str(), _descriptor->getNumberOfTArgs(), block.getTArguments()->size()); + if ((int) block.numT() < _descriptor->getNumberOfTArgs()) { + nd4j_printf("%s: %i T args expected, but %i received\n", this->getOpName().c_str(), _descriptor->getNumberOfTArgs(), block.numT()); return ND4J_STATUS_BAD_PARAMS; } } else if (_descriptor->getNumberOfTArgs() == -1) - if (block.getTArguments()->size() == 0) { + if (block.numT() == 0) { nd4j_printf("%s: Number of T arguments should be positive number, but got 0 arguments\n", this->getOpName().c_str()); return ND4J_STATUS_BAD_PARAMS; } if (_descriptor->getNumberOfIArgs() > 0) { - if ((int) block.getIArguments()->size() < _descriptor->getNumberOfIArgs()) { - nd4j_printf("%s: %i int args expected, but %i received\n", this->getOpName().c_str(), _descriptor->getNumberOfIArgs(), block.getIArguments()->size()); + if ((int) block.numI() < _descriptor->getNumberOfIArgs()) { + nd4j_printf("%s: %i int args expected, but %i received\n", this->getOpName().c_str(), _descriptor->getNumberOfIArgs(), block.numI()); return ND4J_STATUS_BAD_PARAMS; } } else if (_descriptor->getNumberOfIArgs() == -1) - if (block.getIArguments()->size() == 0) { + if (block.numI() == 0) { nd4j_printf("%s: Number of Integer arguments should be positive number, but got 0 arguments\n", this->getOpName().c_str()); return ND4J_STATUS_BAD_PARAMS; } @@ -768,7 +768,7 @@ namespace sd { if (block.width() == 0) return ND4J_STATUS_OK; - for (auto p: *block.inputs()) { + for (auto p: block.inputs()) { auto v = block.variable(p); NDArray *aV = v->getNDArray(); @@ -805,7 +805,7 @@ namespace sd { int cnt = 0; - for (auto p: *block.inputs()) { + for (auto p: block.inputs()) { auto v = block.variable(p); if (v == nullptr) { if (!this->getOpName().empty()) { @@ -844,7 +844,7 @@ namespace sd { return ND4J_STATUS_OK; NDArray *a0 = block.variable(0)->getNDArray(); - for (auto p: *block.inputs()) { + for (auto p: block.inputs()) { auto v = block.variable(p); NDArray *aV = v->getNDArray(); if (a0->ordering() != aV->ordering()) @@ -889,17 +889,17 @@ namespace sd { block.setRng(rng); for (int e = 0; e < tArgs.size(); e++) - block.getTArguments()->emplace_back(tArgs.at(e)); + block.appendT(tArgs.at(e)); // FIXME: iargs should be Nd4jLong for (int e = 0; e < iArgs.size(); e++) - block.getIArguments()->emplace_back(static_cast(iArgs.at(e))); + block.appendI(static_cast(iArgs.at(e))); for (int e = 0; e < bArgs.size(); e++) - block.getBArguments()->push_back(static_cast(bArgs.at(e))); + block.appendB(static_cast(bArgs.at(e))); for (int e = 0; e < dArgs.size(); e++) - block.getDArguments()->push_back(dArgs.at(e)); + block.appendD(dArgs.at(e)); Nd4jStatus result = this->execute(&block); @@ -1038,16 +1038,16 @@ namespace sd { // block.setRNG(ProviderRNG::getInstance().getRNG()); for (int e = 0; e < tArgs.size(); e++) - block.getTArguments()->emplace_back(tArgs.at(e)); + block.appendT(tArgs.at(e)); for (int e = 0; e < iArgs.size(); e++) - block.getIArguments()->emplace_back(iArgs.at(e)); + block.appendI(iArgs.at(e)); for (int e = 0; e < bArgs.size(); e++) - block.getBArguments()->push_back(bArgs.at(e)); + block.appendB(bArgs.at(e)); for (int e = 0; e < dArgs.size(); e++) - block.getDArguments()->push_back(dArgs.at(e)); + block.appendD(dArgs.at(e)); Nd4jStatus status = this->execute(&block); ResultSet arrayList; diff --git a/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp index 4f6646694c2d..c58766f65ac8 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp @@ -40,11 +40,11 @@ namespace sd { for (int e = 0; e < axis->lengthOf(); e++) dims.push_back(axis->e(e)); } - else if (block.getIArguments()->size()) - for (int e = 0; e < block.getIArguments()->size(); e++) + else if (block.numI()) + for (int e = 0; e < block.numI(); e++) dims.push_back(INT_ARG(e)); - else if (block.getAxis()->size()) { - dims = *block.getAxis(); //.push_back(axis->e(e)); + else if (block.getAxis().size()) { + dims = block.getAxis(); //.push_back(axis->e(e)); } if (dims.size() > 1) diff --git a/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp index 03f34d269ee6..886318787171 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp @@ -33,7 +33,7 @@ namespace sd { auto z = OUTPUT_VARIABLE(0); - std::vector dims(*block.getIArguments()); + std::vector dims(block.getIArguments()); if (dims.size() > 0) std::sort(dims.begin(), dims.end()); diff --git a/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp index 0297df28a806..f4954324f64f 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp @@ -35,7 +35,7 @@ namespace sd { NDArray::prepareSpecialUse({z}, {x, y}); - std::vector dims(*block.getAxis()); + std::vector dims(block.getAxis()); if (dims.size() == 0 && block.width() > 2) { auto axis = INPUT_VARIABLE(2); helpers::adjustAxis(x->rankOf(), axis, dims); diff --git a/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp index c92577f3b413..8db309e3c354 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp @@ -43,7 +43,7 @@ namespace sd { auto inShape = inputShape->at(0); Nd4jLong *newShape; - if (block.getAxis()->size() == 0 && block.width() == 1) { + if (block.getAxis().size() == 0 && block.width() == 1) { // in this case we just return scalar ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); newShape[0] = 2; @@ -57,11 +57,11 @@ namespace sd { auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newShape, DataType::INT64)); RELEASE(newShape, block.getWorkspace()); return SHAPELIST(result); - } else if (block.getAxis()->size()){ + } else if (block.getAxis().size()){ // in this case we're building proper shape for reduction auto array = INPUT_VARIABLE(0); //new NDArray(nullptr, inShape, block.getWorkspace()); - newShape = ShapeUtils::evalReduceShapeInfo('c', *block.getAxis(), *array, DataType::INT64, false, true, block.workspace()); + newShape = ShapeUtils::evalReduceShapeInfo('c', block.getAxis(), *array, DataType::INT64, false, true, block.workspace()); return SHAPELIST(newShape); } else { @@ -118,11 +118,11 @@ namespace sd { bool allAxes = false; - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(), "LegacyIndexReduceOp"); if (block.width() == 1) { - if (block.getAxis()->size() == 0) { + if (block.getAxis().size() == 0) { // scalar NativeOpExecutioner::execIndexReduceScalar(block.launchContext(), opNum, x->getBuffer(), x->getShapeInfo(), x->getSpecialBuffer(), x->getSpecialShapeInfo(), @@ -131,9 +131,9 @@ namespace sd { z->getSpecialBuffer(), z->getSpecialShapeInfo()); } else { // TAD - std::vector dims(block.getAxis()->size()); + std::vector dims(block.getAxis().size()); for (size_t e = 0; e < dims.size(); e++) { - auto axe = block.getAxis()->at(e); + auto axe = block.getAxis().at(e); dims[e] = axe < 0 ? axe + x->rankOf(): axe; } if (dims.size() > 1) diff --git a/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformBoolOp.cpp index eb75141a9c39..28e2d6374fed 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformBoolOp.cpp @@ -48,7 +48,7 @@ namespace sd { int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(), "LegacyPairwiseTransformBoolOp"); NativeOpExecutioner::execPairwiseTransform(block.launchContext(), opNum, x->getBuffer(), x->getShapeInfo(), x->getSpecialBuffer(), x->getSpecialShapeInfo(), diff --git a/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformOp.cpp index 7f6eecb19080..e540f77338b3 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformOp.cpp @@ -48,7 +48,7 @@ namespace sd { int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(), "LegacyPairwiseTransformOp"); NativeOpExecutioner::execPairwiseTransform(block.launchContext(), opNum, x->getBuffer(), x->getShapeInfo(), x->specialBuffer(), x->specialShapeInfo(), diff --git a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp index eeb80f4036ad..8be8d5d2a263 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp @@ -72,7 +72,7 @@ namespace sd { from = arg1->e(0); to = arg2->e(0); - } else if (block.getTArguments()->size() == 2) { + } else if (block.numT() == 2) { from = T_ARG(0); to = T_ARG(1); } else { @@ -96,7 +96,7 @@ namespace sd { REQUIRE_TRUE(arg->isScalar(), 0, "DropOut: Second argument must be scalar"); prob = arg->e(0); - } else if (block.getTArguments()->size() > 0) { + } else if (block.numT() > 0) { prob = T_ARG(0); } else { REQUIRE_TRUE(false, 0, "DropOut requires either TArgs or second argument to be present"); @@ -125,7 +125,7 @@ namespace sd { mean = arg1->e(0); stdev = arg2->e(0); - } else if (block.getTArguments()->size() == 2) { + } else if (block.numT() == 2) { mean = T_ARG(0); stdev = T_ARG(1); } else { @@ -154,7 +154,7 @@ namespace sd { REQUIRE_TRUE(arg1->isScalar(), 0, "Bernoulli: Second argument must be scalar"); prob = arg1->e(0); - } else if (block.getTArguments()->size() > 0) { + } else if (block.numT() > 0) { prob = T_ARG(0); } else { REQUIRE_TRUE(false, 0, "Bernoulli requires either 1 TArg or 2 arguments to be present"); @@ -186,7 +186,7 @@ namespace sd { trials = arg1->e(0); prob = arg2->e(0); - } else if (block.getTArguments()->size() == 1 && block.getIArguments()->size() == 1) { + } else if (block.numT() == 1 && block.numI() == 1) { trials = INT_ARG(0); prob = T_ARG(0); } else { @@ -218,7 +218,7 @@ namespace sd { mean = arg1->e(0); stdev = arg2->e(0); - } else if (block.getTArguments()->size() == 2) { + } else if (block.numT() == 2) { mean = T_ARG(0); stdev = T_ARG(1); } else { @@ -250,7 +250,7 @@ namespace sd { mean = arg1->e(0); stdev = arg2->e(0); - } else if (block.getTArguments()->size() == 2) { + } else if (block.numT() == 2) { mean = T_ARG(0); stdev = T_ARG(1); } else { @@ -286,7 +286,7 @@ namespace sd { a = arg2->e(0); b = arg3->e(0); pa = arg4->e(0); - } else if (block.getTArguments()->size() == 4) { + } else if (block.numT() == 4) { prob = T_ARG(0); a = T_ARG(1); b = T_ARG(2); @@ -388,11 +388,11 @@ namespace sd { block.markInplace(isInplace); for (int e = 0; e < tArgs.size(); e++) - block.getTArguments()->emplace_back(tArgs.at(e)); + block.appendT(tArgs.at(e)); for (int e = 0; e < iArgs.size(); e++) - block.getIArguments()->emplace_back(iArgs.at(e)); + block.appendI(iArgs.at(e)); Nd4jStatus status = this->execute(&block); arrayList.setStatus(status); diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp index 7143c3bbdc87..40bf279f6144 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp @@ -37,17 +37,17 @@ namespace sd { nd4j_debug("Executing LegacyReduce3Op: [%i]\n", opNum); - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(), "LegacyReduce3Op"); - if (x->isSameShape(y) && (block.getIArguments()->size() == 0 || (block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()))) { + if (x->isSameShape(y) && (block.numI() == 0 || (block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()))) { // reduce3 to scalar NativeOpExecutioner::execReduce3Scalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(z->dataType()), y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); } else { - std::vector dims(*block.getAxis()); + std::vector dims(block.getAxis()); for (int e = 0; e < dims.size(); e++) if (dims[e] < 0) dims[e] += x->rankOf(); @@ -98,7 +98,7 @@ namespace sd { Nd4jLong *zShape = nullptr; - if (shape::equalsSoft(xShape, yShape) && (block.getIArguments()->size() == 0 || (block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()))) { + if (shape::equalsSoft(xShape, yShape) && (block.numI() == 0 || (block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()))) { // reduce3 to scalar case ALLOCATE(zShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); zShape[0] = 2; @@ -112,7 +112,7 @@ namespace sd { } else { auto array = new NDArray(nullptr, xShape, block.launchContext()); - xShape = ShapeUtils::evalReduceShapeInfo('c', *block.getIArguments(), *array, false, true); + xShape = ShapeUtils::evalReduceShapeInfo('c', block.getIArguments(), *array, false, true); delete array; } diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp index 433e173fc6c4..66b142b6978c 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp @@ -49,11 +49,11 @@ namespace sd { int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); nd4j_debug("Executing LegacyReduceFloatOp: [%i]\n", opNum); - auto axis = *block.getAxis(); + auto axis = block.getAxis(); bool allAxes = false; - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(),"LegacyReduceBoolOp"); if (block.width() == 1) { @@ -101,7 +101,7 @@ namespace sd { dims[e] = f >= 0 ? f : f += x->rankOf(); } - if ((block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { + if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { // scalar NativeOpExecutioner::execReduceBoolScalar(block.launchContext(), opNum, x->getBuffer(), x->getShapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); } else { @@ -139,7 +139,7 @@ namespace sd { auto keepDims = block.numB() > 0 ? B_ARG(0) : false; auto newFormat = block.numB() > 1 ? B_ARG(1) : true; - auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getAxis(); + auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); if (axis.size() == shape::rank(inShape)) allAxes = true; diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp index 23f863ba2040..0e1e5cb93975 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp @@ -50,9 +50,9 @@ namespace sd { nd4j_debug("Executing LegacyReduceFloatOp: [%i]\n", opNum); bool allAxes = false; - auto axis = *block.getAxis(); + auto axis = block.getAxis(); - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(), "LegacyReduceFloatOp"); if (block.width() == 1) { @@ -62,13 +62,13 @@ namespace sd { // _axis.(block.getIArguments()->size() == 0) || // (block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) - if (block.getAxis()->empty() || allAxes) { + if (block.getAxis().empty() || allAxes) { // scalar NativeOpExecutioner::execReduceFloatScalar(block.launchContext(), opNum, x->getBuffer(), x->getShapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); } else { // TAD - std::vector dims(*block.getAxis()); + std::vector dims(block.getAxis()); for (int e = 0; e < dims.size(); e++) if (dims[e] < 0) @@ -102,7 +102,7 @@ namespace sd { dims[e] = f >= 0 ? f : f += x->rankOf(); } - if ((block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { + if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { // scalar NativeOpExecutioner::execReduceFloatScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); } else { @@ -140,7 +140,7 @@ namespace sd { auto keepDims = block.numB() > 0 ? B_ARG(0) : false; auto newFormat = block.numB() > 1 ? B_ARG(1) : true; - auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getAxis(); + auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); if (axis.size() == shape::rank(inShape)) allAxes = true; diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp index 17cba42274ab..4d674de5ec44 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp @@ -49,10 +49,10 @@ namespace sd { int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); nd4j_debug("Executing LegacyReduceFloatOp: [%i]\n", opNum); - auto axis = *block.getAxis(); + auto axis = block.getAxis(); bool allAxes = false; - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(),"LegacyReduceLongOp"); if (block.width() == 1) { @@ -104,7 +104,7 @@ namespace sd { dims[e] = f >= 0 ? f : f += x->rankOf(); } - if ((block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { + if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { // scalar NativeOpExecutioner::execReduceLongScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); } else { @@ -140,7 +140,7 @@ namespace sd { auto keepDims = block.numB() > 0 ? B_ARG(0) : false; auto newFormat = block.numB() > 1 ? B_ARG(1) : true; - auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getAxis(); + auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); if (axis.size() == shape::rank(inShape)) allAxes = true; diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp index 3c96bca70c88..f9d42bd5d074 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp @@ -49,10 +49,10 @@ namespace sd { int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); nd4j_debug("Executing LegacyReduceSameOp: [%i]\n", opNum); - auto axis = *block.getAxis(); + auto axis = block.getAxis(); bool allAxes = false; - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(), "LegacyReduceSameOp"); if (block.width() == 1) { @@ -99,7 +99,7 @@ namespace sd { dims[e] = f >= 0 ? f : f += x->rankOf(); } - if ((block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { + if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { // scalar NativeOpExecutioner::execReduceSameScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); } else { @@ -136,7 +136,7 @@ namespace sd { auto keepDims = block.numB() > 0 ? B_ARG(0) : false; auto newFormat = block.numB() > 1 ? B_ARG(1) : true; - auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getAxis(); + auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); if (axis.size() == shape::rank(inShape)) allAxes = true; diff --git a/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp index 46728ede133f..bbaaca15656f 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp @@ -56,7 +56,7 @@ namespace sd { int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(), "LegacyScalarBoolOp"); if (block.width() > 1) { @@ -65,7 +65,7 @@ namespace sd { NDArray::prepareSpecialUse({z}, {x, y}); NativeOpExecutioner::execScalarBool(block.launchContext(), opNum, x->getBuffer(), x->getShapeInfo(), x->specialBuffer(), x->specialShapeInfo(), z->getBuffer(), z->getShapeInfo(), z->specialBuffer(), z->specialShapeInfo(), y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), extras.argumentsAsT(x->dataType())); - } else if (block.getTArguments()->size() > 0) { + } else if (block.numT() > 0) { auto y = NDArrayFactory::create(T_ARG(0), block.launchContext()); NDArray::prepareSpecialUse({z}, {x, &y}); diff --git a/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp index de104a11d8cf..be11a98c01ec 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp @@ -56,7 +56,7 @@ namespace sd { int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(), "LegacyScalarOp"); if (block.width() > 1) { @@ -67,7 +67,7 @@ namespace sd { NativeOpExecutioner::execScalar(block.launchContext(), opNum, x->getBuffer(), x->getShapeInfo(), x->specialBuffer(), x->specialShapeInfo(), z->getBuffer(), z->getShapeInfo(), z->specialBuffer(), z->specialShapeInfo(), y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), extras.argumentsAsT(z->dataType())); NDArray::registerSpecialUse({z}, {x, y}); - } else if (block.getTArguments()->size() > 0) { + } else if (block.numT() > 0) { auto y = NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext()); x->applyScalarArr(static_cast(opNum), y, *z); diff --git a/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp index 74f82d16291b..11a8fe1daa2f 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp @@ -38,20 +38,20 @@ namespace sd { // bias goes as first argument, unlike all other reductions bool biasCorrected = false; - if (block.getIArguments()->size() > 0) + if (block.numI() > 0) biasCorrected = INT_ARG(0) > 0; - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(),"LegacyStatsOp"); - if (block.getIArguments()->size() == 1 || (block.getIArguments()->size() == 2 && INT_ARG(1) == sd::DataTypeUtils::max())) { + if (block.numI() == 1 || (block.numI() == 2 && INT_ARG(1) == sd::DataTypeUtils::max())) { // scalar NativeOpExecutioner::execSummaryStatsScalar(block.launchContext(), opNum, x->getBuffer(), x->getShapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(z->dataType()), z->getBuffer(), z->getShapeInfo(), z->specialBuffer(), z->specialShapeInfo(), biasCorrected); } else { // dimensions for TAD // we should skip first argument here, because it's addressing bias correction - std::vector dims(*block.getIArguments()); + std::vector dims(block.getIArguments()); for (int e = 0; e < dims.size(); e++) if (dims[e] < 0) dims[e] += x->rankOf(); @@ -93,7 +93,7 @@ namespace sd { auto inShape = inputShape->at(0); Nd4jLong *newShape; - if (block.getIArguments()->size() == 0 || (block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max())) { + if (block.numI() == 0 || (block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max())) { // in this case we just return scalar ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); newShape[0] = 2; @@ -108,7 +108,7 @@ namespace sd { // in this case we're building proper shape for reduction auto array = new NDArray(nullptr, inShape, block.launchContext()); - newShape = ShapeUtils::evalReduceShapeInfo('c', *block.getIArguments(), *array, false, true); + newShape = ShapeUtils::evalReduceShapeInfo('c', block.getIArguments(), *array, false, true); delete array; } diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformAnyOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformAnyOp.cpp index def577eb379f..86c1b672a4b7 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformAnyOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformAnyOp.cpp @@ -45,7 +45,7 @@ namespace sd { int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(),"LegacyTransformAnyOp"); NativeOpExecutioner::execTransformAny(block.launchContext(), opNum, input->getBuffer(), input->getShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformBoolOp.cpp index 99b856b8afb6..dfce53cd4467 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformBoolOp.cpp @@ -45,7 +45,7 @@ namespace sd { int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(),"LegacyTransformBoolOp"); NativeOpExecutioner::execTransformBool(block.launchContext(), opNum, input->getBuffer(), input->getShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformFloatOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformFloatOp.cpp index f0795b7bbfe6..ea363947c047 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformFloatOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformFloatOp.cpp @@ -45,7 +45,7 @@ namespace sd { int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(), "LegacyTransformFloatOp"); NativeOpExecutioner::execTransformFloat(block.launchContext(), opNum, input->getBuffer(), input->getShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformSameOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformSameOp.cpp index 0d827787e9cc..30160d118a9d 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformSameOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformSameOp.cpp @@ -45,7 +45,7 @@ namespace sd { int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(), "LegacyTransformSameOp"); NativeOpExecutioner::execTransformSame(block.launchContext(), opNum, input->getBuffer(), input->getShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformStrictOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformStrictOp.cpp index f36853579997..7bb1acb67f65 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformStrictOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformStrictOp.cpp @@ -45,7 +45,7 @@ namespace sd { int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(*block.getTArguments()); + ExtraArguments extras(block.getTArguments()); PointersManager manager(block.launchContext(), "LegacyTransformStrictOp"); NativeOpExecutioner::execTransformStrict(block.launchContext(), opNum, input->getBuffer(), input->getShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), z->getBuffer(), z->getShapeInfo(), z->specialBuffer(), z->specialShapeInfo(), extras.argumentsAsT(z->dataType()), nullptr, nullptr); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp index 4adab2dfef49..2fc2d12c8342 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp @@ -53,7 +53,7 @@ PLATFORM_IMPL(avgpool2d, ENGINE_CPU) { const auto dW = INT_ARG(7); const auto paddingMode = INT_ARG(8); const auto extraParam0 = INT_ARG(9); - const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + const int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf()); REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); @@ -98,7 +98,7 @@ PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) { int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int extraParam0 = INT_ARG(9); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf()); REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp index 96110bd295d5..ff3199e3e206 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp @@ -54,7 +54,7 @@ PLATFORM_IMPL(avgpool3dnew, ENGINE_CPU) { int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID int extraParam0 = INT_ARG(13); - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW MKLDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); @@ -102,7 +102,7 @@ PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) { const int dW = INT_ARG(11); // dilations width const int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging - const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + const int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but got %i instead", input->rankOf()); REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index 173880e63dd0..05deae601336 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -363,7 +363,7 @@ PLATFORM_IMPL(batchnorm, ENGINE_CPU) { if(applyOffset) beta = INPUT_VARIABLE(3 + (int)applyScale); - const int numOfIntArgs = block.getIArguments()->size(); + const int numOfIntArgs = block.numI(); const int inRank = input->rankOf(); // get axes args to normalize input array over @@ -434,7 +434,7 @@ PLATFORM_CHECK(batchnorm, ENGINE_CPU) { beta = INPUT_VARIABLE(3 + (int)applyScale); - const int numOfIntArgs = block.getIArguments()->size(); + const int numOfIntArgs = block.numI(); std::vector axes; if(numOfIntArgs > 2) for(int i = 2; i < numOfIntArgs; ++i) @@ -617,7 +617,7 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) { dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); } - const int numOfIntArgs = block.getIArguments()->size(); + const int numOfIntArgs = block.numI(); const int inRank = input->rankOf(); // get axes args to normalize input array over @@ -707,7 +707,7 @@ PLATFORM_CHECK(batchnorm_bp, ENGINE_CPU) { dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); } - const int numOfIntArgs = block.getIArguments()->size(); + const int numOfIntArgs = block.numI(); std::vector axes; if(numOfIntArgs > 2) for(int i = 2; i < numOfIntArgs; ++i) diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp index 0aa05f7f2f2e..396ff554cc1e 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp @@ -536,8 +536,8 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] + bool isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI()> 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width @@ -589,8 +589,8 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp index 68f0eea89989..51c083bcbc91 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp @@ -542,8 +542,8 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC] + int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes @@ -600,8 +600,8 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC] + int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp index a1ca2a717a09..fce710983480 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp @@ -350,8 +350,8 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes @@ -422,8 +422,8 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp index 3236990b1e21..5f951cce2123 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp @@ -142,8 +142,8 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] const int rank = gradO->rankOf(); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp index bcc3d700a40d..c05b877c2628 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp @@ -362,8 +362,8 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] + int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes @@ -439,8 +439,8 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] + int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp index 2ca16bb8ea9a..d36b6e850809 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp @@ -368,8 +368,8 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes @@ -438,8 +438,8 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp index 91e56d8016e8..66602c0d3e04 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp @@ -224,7 +224,7 @@ PLATFORM_IMPL(matmul, ENGINE_CPU) { if(x->isEmpty() || y->isEmpty()) return Status::OK(); - const int iSize = (int) block.getIArguments()->size(); + const int iSize = (int) block.numI(); int transX = iSize > 0 ? INT_ARG(0) : 0; int transY = iSize > 1 ? INT_ARG(1) : 0; const int transZ = iSize > 2 ? INT_ARG(2) : 0; diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp index 50b3fafa5625..cab94ffa8a76 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp @@ -54,7 +54,7 @@ PLATFORM_IMPL(maxpool2d, ENGINE_CPU) { const int dW = INT_ARG(7); const int paddingMode = INT_ARG(8); // const int extraParam0 = INT_ARG(9); - const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW + const int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); @@ -95,7 +95,7 @@ PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) { int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME // int extraParam0 = INT_ARG(9); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf()); REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp index 078b45ba0bf9..59f6d19499d8 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp @@ -53,7 +53,7 @@ PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) { int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); @@ -100,7 +100,7 @@ PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) { const int dW = INT_ARG(11); // dilations width const int paddngMode = INT_ARG(12); // 1-SAME, 0-VALID // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but got %i instead", input->rankOf()); REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp index a178e84c2d6a..8d741c9b899d 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp @@ -105,7 +105,7 @@ namespace sd { auto output = OUTPUT_VARIABLE(0); const int rank = input->rankOf(); - int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; + int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; if (dim < 0) { dim += rank; @@ -224,7 +224,7 @@ namespace sd { const int rank = input->rankOf(); const int dLdzRank = dLdz->rankOf(); - int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; + int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; if (dim < 0) { dim += rank; diff --git a/libnd4j/include/system/op_boilerplate.h b/libnd4j/include/system/op_boilerplate.h index ac7f281d9207..816888931bcd 100644 --- a/libnd4j/include/system/op_boilerplate.h +++ b/libnd4j/include/system/op_boilerplate.h @@ -1517,11 +1517,11 @@ #define INPUT_LIST(INDEX) reinterpret_cast(block.getVariable(INDEX)->getNDArrayList()) -#define D_ARG(INDEX) block.getDArguments()->at(INDEX) -#define INT_ARG(INDEX) block.getIArguments()->at(INDEX) +#define D_ARG(INDEX) block.getDArguments().at(INDEX) +#define INT_ARG(INDEX) block.getIArguments().at(INDEX) #define I_ARG(INDEX) INT_ARG(INDEX) -#define T_ARG(INDEX) block.getTArguments()->at(INDEX) -#define B_ARG(INDEX) block.getBArguments()->at(INDEX) +#define T_ARG(INDEX) block.getTArguments().at(INDEX) +#define B_ARG(INDEX) block.getBArguments().at(INDEX) #define COPY_SHAPE(SRC, TGT) TGT = ShapeBuilders::copyShapeInfo(SRC, true, block.getWorkspace()) diff --git a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp index 57d9ce88d94d..a51a65bcc8c8 100644 --- a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp @@ -48,7 +48,7 @@ TEST_F(ContextTests, Basic_Test_1) { block.pickInput(2, 0); block.pickInput(2, 1); - ASSERT_EQ(2, block.inputs()->size()); + ASSERT_EQ(2, block.inputs().size()); ASSERT_EQ(2, block.width()); ASSERT_TRUE(variableSpace.hasVariable(2, 0)); @@ -76,7 +76,7 @@ TEST_F(ContextTests, Basic_Test_2) { block.pickInput(-1); block.pickInput(-2); - ASSERT_EQ(2, block.inputs()->size()); + ASSERT_EQ(2, block.inputs().size()); ASSERT_EQ(2, block.width()); ASSERT_TRUE(variableSpace.hasVariable(-1)); @@ -247,26 +247,26 @@ TEST_F(ContextTests, Prototype_Test_1) { prototype.pickInput(12, 3); prototype.pickInput(12, 4); - prototype.getTArguments()->push_back(2.0); - prototype.getTArguments()->push_back(-2.0); + prototype.appendT(2.0); + prototype.appendT(-2.0); - prototype.getIArguments()->push_back(17); - prototype.getIArguments()->push_back(119); + prototype.appendI(17); + prototype.appendI(119); - Context ctx(&prototype, nullptr); + Context ctx(prototype, nullptr); ASSERT_EQ(ctx.nodeId(), prototype.nodeId()); ASSERT_EQ(ctx.isInplace(), prototype.isInplace()); - ASSERT_EQ(2, ctx.inputs()->size()); - ASSERT_EQ(2, ctx.getTArguments()->size()); - ASSERT_EQ(2, ctx.getIArguments()->size()); + ASSERT_EQ(2, ctx.inputs().size()); + ASSERT_EQ(2, ctx.getTArguments().size()); + ASSERT_EQ(2, ctx.getIArguments().size()); - ASSERT_EQ(2.0, ctx.getTArguments()->at(0)); - ASSERT_EQ(-2.0, ctx.getTArguments()->at(1)); + ASSERT_EQ(2.0, ctx.getTArguments().at(0)); + ASSERT_EQ(-2.0, ctx.getTArguments().at(1)); - ASSERT_EQ(17, ctx.getIArguments()->at(0)); - ASSERT_EQ(119, ctx.getIArguments()->at(1)); + ASSERT_EQ(17, ctx.getIArguments().at(0)); + ASSERT_EQ(119, ctx.getIArguments().at(1)); } @@ -274,14 +274,14 @@ TEST_F(ContextTests, Prototype_Test_2) { ContextPrototype prototype(nullptr, 119, false); prototype.setOpNum(179); - Context ctx(&prototype, nullptr); + Context ctx(prototype, nullptr); ASSERT_EQ(ctx.isInplace(), prototype.isInplace()); ASSERT_EQ(ctx.opNum(), prototype.opNum()); - ASSERT_EQ(0, ctx.inputs()->size()); - ASSERT_EQ(0, ctx.getTArguments()->size()); - ASSERT_EQ(0, ctx.getIArguments()->size()); + ASSERT_EQ(0, ctx.inputs().size()); + ASSERT_EQ(0, ctx.getTArguments().size()); + ASSERT_EQ(0, ctx.getIArguments().size()); } TEST_F(ContextTests, test_short_context_1) { diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index 149ab3c5fb3c..d33ed6fff780 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -81,26 +81,26 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_1) { auto block = new Context(1, variableSpace, false); // not-in-place block->fillInputs({-1, -2}); // 5,5 kernel - block->getIArguments()->push_back(kH); - block->getIArguments()->push_back(kW); + block->appendI(kH); + block->appendI(kW); // 1,1 stride - block->getIArguments()->push_back(sH); - block->getIArguments()->push_back(sW); + block->appendI(sH); + block->appendI(sW); // 0,0 padding - block->getIArguments()->push_back(pH); - block->getIArguments()->push_back(pW); + block->appendI(pH); + block->appendI(pW); // 1,1 dilation - block->getIArguments()->push_back(dH); - block->getIArguments()->push_back(dW); + block->appendI(dH); + block->appendI(dW); // same mode - block->getIArguments()->push_back(1); + block->appendI(1); // is NHWC - block->getIArguments()->push_back(0); + block->appendI(0); sd::ops::conv2d op; @@ -409,21 +409,21 @@ TEST_F(ConvolutionTests1, sconv2d_1) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1, -2}); - block->getIArguments()->push_back(kY); - block->getIArguments()->push_back(kX); + block->appendI(kY); + block->appendI(kX); - block->getIArguments()->push_back(sY); - block->getIArguments()->push_back(sX); + block->appendI(sY); + block->appendI(sX); - block->getIArguments()->push_back(pY); - block->getIArguments()->push_back(pX); + block->appendI(pY); + block->appendI(pX); // dilation - block->getIArguments()->push_back(1); - block->getIArguments()->push_back(1); + block->appendI(1); + block->appendI(1); // NOT same mode - block->getIArguments()->push_back(0); + block->appendI(0); sd::ops::sconv2d op; @@ -1984,24 +1984,24 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test9) { ShapeList shapeList({x.shapeInfo(), w.shapeInfo()}); ContextPrototype proto; Context ctx(1); - ctx.getIArguments()->push_back(2); - ctx.getIArguments()->push_back(5); - ctx.getIArguments()->push_back(5); + ctx.appendI(2); + ctx.appendI(5); + ctx.appendI(5); - ctx.getIArguments()->push_back(5); - ctx.getIArguments()->push_back(4); - ctx.getIArguments()->push_back(3); + ctx.appendI(5); + ctx.appendI(4); + ctx.appendI(3); - ctx.getIArguments()->push_back(0); - ctx.getIArguments()->push_back(0); - ctx.getIArguments()->push_back(0); + ctx.appendI(0); + ctx.appendI(0); + ctx.appendI(0); - ctx.getIArguments()->push_back(1); - ctx.getIArguments()->push_back(1); - ctx.getIArguments()->push_back(1); + ctx.appendI(1); + ctx.appendI(1); + ctx.appendI(1); - ctx.getIArguments()->push_back(0); - ctx.getIArguments()->push_back(1); // previous variant was "ctx.getIArguments()->push_back(0)" and this caused fail + ctx.appendI(0); + ctx.appendI(1); // previous variant was "ctx.appendI(0)" and this caused fail auto shapes = op.calculateOutputShape(&shapeList, ctx); ASSERT_EQ(1, shapes->size()); diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index 169c51124ff2..5d86c294405e 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -909,8 +909,8 @@ TEST_F(ConvolutionTests2, maxpool2d_1) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + std::vector argI = block->getIArguments(); + argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::maxpool2d pooling; Nd4jStatus status = pooling.execute(block); @@ -953,8 +953,8 @@ TEST_F(ConvolutionTests2, maxpool2d_2) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + std::vector argI = block->getIArguments(); + argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::maxpool2d pooling; Nd4jStatus status = pooling.execute(block); @@ -997,8 +997,8 @@ TEST_F(ConvolutionTests2, maxpool2d_3) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 1}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + std::vector argI = block->getIArguments(); + argI = {kH,kW, sH,sW, pH,pW, dH,dW, 1}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::maxpool2d pooling; Nd4jStatus status = pooling.execute(block); @@ -1041,8 +1041,8 @@ TEST_F(ConvolutionTests2, maxpool2d_4) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + std::vector argI = block->getIArguments(); + argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::maxpool2d pooling; Nd4jStatus status = pooling.execute(block); @@ -1085,8 +1085,8 @@ TEST_F(ConvolutionTests2, maxpool2d_5) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 1}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + std::vector argI = block->getIArguments(); + argI = {kH,kW, sH,sW, pH,pW, dH,dW, 1}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::maxpool2d pooling; Nd4jStatus status = pooling.execute(block); @@ -1703,8 +1703,8 @@ TEST_F(ConvolutionTests2, maxpool2d_bp_1) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); block->fillInputs({-2}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + std::vector argI = block->getIArguments(); + argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::maxpool2d_bp bp; Nd4jStatus status = bp.execute(block); @@ -1887,8 +1887,8 @@ TEST_F(ConvolutionTests2, avgpool2d_bp_1) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); block->fillInputs({-2}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 1, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode, 9 - extraParam0 (unnecessary for avg mode), 10 - data format + std::vector argI = block->getIArguments(); + argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 1, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode, 9 - extraParam0 (unnecessary for avg mode), 10 - data format sd::ops::avgpool2d_bp bp; Nd4jStatus status = bp.execute(block); @@ -2059,9 +2059,9 @@ TEST_F(ConvolutionTests2, pnormpool2d_bp_1) { block->fillInputs({-1}); block->fillInputs({-2}); auto argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 3}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; 9 - divisor - std::vector* argT = block->getTArguments(); - *argT = {0.000001}; + argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 3}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; 9 - divisor + std::vector argT = block->getTArguments(); + argT = {0.000001}; sd::ops::pnormpool2d_bp bp; Nd4jStatus status = bp.execute(block); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index ca4a356ae794..bef547d58536 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -122,7 +122,7 @@ TEST_F(DeclarableOpsTests1, BasicInitialization1) { variableSpace->putVariable(1, nodeVar); Context block(1, variableSpace); - block.getIArguments()->push_back(1); + block.appendI(1); block.fillInputs({ -1, -2, -3, -4, -5 }); ASSERT_FALSE(nodeVar->hasNDArray()); @@ -872,8 +872,10 @@ TEST_F(DeclarableOpsTests1, ClipByValue1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(1, new Variable()); auto block = new Context(1, variableSpace, true); - block->getTArguments()->push_back(0.0f); - block->getTArguments()->push_back(3.0f); + + block->appendT(0.0f); + block->appendT(3.0f); + block->fillInputs({ -1 }); sd::ops::clipbyvalue clip; @@ -1862,11 +1864,12 @@ TEST_F(DeclarableOpsTests1, Reshape2) { auto block = new Context(1, variableSpace, false); block->fillInputs({ -1 }); - std::vector* arguments = block->getIArguments(); - arguments->push_back(-y->ordering()); - arguments->push_back(3); - arguments->push_back(5); - arguments->push_back(4); + std::vector arguments = block->getIArguments(); + + arguments.push_back(-y->ordering()); + arguments.push_back(3); + arguments.push_back(5); + arguments.push_back(4); sd::ops::reshape reshape; @@ -2005,7 +2008,7 @@ TEST_F(DeclarableOpsTests1, Permute1) { auto block = new Context(1, variableSpace, false); // not-in-place block->fillInputs({ -1 }); auto arguments = block->getIArguments(); - *arguments = perm; // set dimensions to be permuted + arguments = perm; // set dimensions to be permuted sd::ops::permute permute; Nd4jStatus status = permute.execute(block); @@ -2059,7 +2062,7 @@ TEST_F(DeclarableOpsTests1, TestReductionShape1) { block->fillInputs({ -1 }); // kernel params - block->getIArguments()->push_back(MAX_INT); + block->appendI(MAX_INT); sd::ops::testreduction testop; @@ -2095,10 +2098,10 @@ TEST_F(DeclarableOpsTests1, TestReductionShape2) { // kernel params //block->getIArguments()->push_back(4); - block->getIArguments()->push_back(1); - block->getIArguments()->push_back(2); - block->getIArguments()->push_back(3); - block->getIArguments()->push_back(4); + block->appendI(1); + block->appendI(2); + block->appendI(3); + block->appendI(4); sd::ops::testreduction testop; @@ -2193,8 +2196,8 @@ TEST_F(DeclarableOpsTests1, Pnormpool2d1) { auto block = new Context(1, variableSpace, false); block->fillInputs({ -1 }); - std::vector* argI = block->getIArguments(); - *argI = { kH,kW, sH,sW, pH,pW, dW,dH, 0, 1, 0 }; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; 9 - extraParam0 for pnorm case; + std::vector argI = block->getIArguments(); + argI = { kH,kW, sH,sW, pH,pW, dW,dH, 0, 1, 0 }; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; 9 - extraParam0 for pnorm case; sd::ops::pnormpool2d pooling; Nd4jStatus status = pooling.execute(block); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index 69dec8359fdd..4772b7972553 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -404,8 +404,8 @@ TEST_F(DeclarableOpsTests4, avgpool2d_13) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + std::vector argI = block->getIArguments(); + argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::avgpool2d pooling; Nd4jStatus status = pooling.execute(block); @@ -447,8 +447,8 @@ TEST_F(DeclarableOpsTests4, avgpool2d_14) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + std::vector argI = block->getIArguments(); + argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::avgpool2d pooling; Nd4jStatus status = pooling.execute(block); @@ -490,8 +490,8 @@ TEST_F(DeclarableOpsTests4, Avgpool2d_test15) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + std::vector argI = block->getIArguments(); + argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::avgpool2d pooling; Nd4jStatus status = pooling.execute(block); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index 002e3376f12e..2726947c9704 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -151,11 +151,13 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) { block->fillInputs({-2}); block->fillInputs({-3}); block->fillInputs({-4}); - block->getIArguments()->push_back(0); - block->getIArguments()->push_back(0); - block->getIArguments()->push_back(1); - block->getIArguments()->push_back(0); - block->getIArguments()->push_back(0); + + block->appendI(0); + block->appendI(0); + block->appendI(1); + block->appendI(0); + block->appendI(0); + auto inputShapes = new ShapeList({ones->getShapeInfo(), b->getShapeInfo(), e->getShapeInfo(), s->getShapeInfo()}); sd::ops::strided_slice op; auto result = op.calculateOutputShape(inputShapes, *block); //execute({ones, &b, &e, &s}, {}, {0, 1, 0, 0, 0}); From 6db387e45ecd86a63abb298bd9d5fb1033072589 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 08:45:13 +0300 Subject: [PATCH 055/233] bunch of tests fixed Signed-off-by: raver119 --- .../layers_tests/ConvolutionTests2.cpp | 27 +++++++------------ .../layers_tests/DeclarableOpsTests1.cpp | 19 +++++++------ .../layers_tests/DeclarableOpsTests4.cpp | 9 +++---- 3 files changed, 21 insertions(+), 34 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index 5d86c294405e..80ec596a7805 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -909,8 +909,7 @@ TEST_F(ConvolutionTests2, maxpool2d_1) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector argI = block->getIArguments(); - argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + block->appendI({kH,kW, sH,sW, pH,pW, dH,dW, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::maxpool2d pooling; Nd4jStatus status = pooling.execute(block); @@ -953,8 +952,7 @@ TEST_F(ConvolutionTests2, maxpool2d_2) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector argI = block->getIArguments(); - argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + block->appendI({kH,kW, sH,sW, pH,pW, dH,dW, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::maxpool2d pooling; Nd4jStatus status = pooling.execute(block); @@ -997,8 +995,7 @@ TEST_F(ConvolutionTests2, maxpool2d_3) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector argI = block->getIArguments(); - argI = {kH,kW, sH,sW, pH,pW, dH,dW, 1}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + block->appendI({kH,kW, sH,sW, pH,pW, dH,dW, 1}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::maxpool2d pooling; Nd4jStatus status = pooling.execute(block); @@ -1041,8 +1038,7 @@ TEST_F(ConvolutionTests2, maxpool2d_4) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector argI = block->getIArguments(); - argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + block->appendI({kH,kW, sH,sW, pH,pW, dH,dW, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::maxpool2d pooling; Nd4jStatus status = pooling.execute(block); @@ -1085,8 +1081,7 @@ TEST_F(ConvolutionTests2, maxpool2d_5) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector argI = block->getIArguments(); - argI = {kH,kW, sH,sW, pH,pW, dH,dW, 1}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + block->appendI({kH,kW, sH,sW, pH,pW, dH,dW, 1}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::maxpool2d pooling; Nd4jStatus status = pooling.execute(block); @@ -1703,8 +1698,7 @@ TEST_F(ConvolutionTests2, maxpool2d_bp_1) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); block->fillInputs({-2}); - std::vector argI = block->getIArguments(); - argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + block->appendI({kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::maxpool2d_bp bp; Nd4jStatus status = bp.execute(block); @@ -1887,8 +1881,7 @@ TEST_F(ConvolutionTests2, avgpool2d_bp_1) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); block->fillInputs({-2}); - std::vector argI = block->getIArguments(); - argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 1, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode, 9 - extraParam0 (unnecessary for avg mode), 10 - data format + block->appendI({kH,kW, sH,sW, pH,pW, dW,dH, 0, 1, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode, 9 - extraParam0 (unnecessary for avg mode), 10 - data format sd::ops::avgpool2d_bp bp; Nd4jStatus status = bp.execute(block); @@ -2058,10 +2051,8 @@ TEST_F(ConvolutionTests2, pnormpool2d_bp_1) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); block->fillInputs({-2}); - auto argI = block->getIArguments(); - argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 3}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; 9 - divisor - std::vector argT = block->getTArguments(); - argT = {0.000001}; + block->appendI({kH,kW, sH,sW, pH,pW, dW,dH, 0, 3}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; 9 - divisor + block->appendT(0.000001); sd::ops::pnormpool2d_bp bp; Nd4jStatus status = bp.execute(block); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index bef547d58536..dd165152b295 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -1864,12 +1864,12 @@ TEST_F(DeclarableOpsTests1, Reshape2) { auto block = new Context(1, variableSpace, false); block->fillInputs({ -1 }); - std::vector arguments = block->getIArguments(); + block->getIArguments(); - arguments.push_back(-y->ordering()); - arguments.push_back(3); - arguments.push_back(5); - arguments.push_back(4); + block->appendI(-y->ordering()); + block->appendI(3); + block->appendI(5); + block->appendI(4); sd::ops::reshape reshape; @@ -1993,7 +1993,7 @@ TEST_F(DeclarableOpsTests1, Permute1) { Nd4jLong shapeX[] = { 3, 5,10,15, 150,15,1, 0,1,99 }; Nd4jLong shapeExp[] = { 3, 15,5,10, 50,10,1, 0,1,99 }; - const std::vector perm = { 2, 0, 1 }; + const std::vector perm = { 2, 0, 1 }; ArrayOptions::setDataType(shapeX, sd::DataType::FLOAT32); ArrayOptions::setDataType(shapeExp, sd::DataType::FLOAT32); @@ -2007,8 +2007,8 @@ TEST_F(DeclarableOpsTests1, Permute1) { auto block = new Context(1, variableSpace, false); // not-in-place block->fillInputs({ -1 }); - auto arguments = block->getIArguments(); - arguments = perm; // set dimensions to be permuted + + block->appendI(perm); // set dimensions to be permuted sd::ops::permute permute; Nd4jStatus status = permute.execute(block); @@ -2196,8 +2196,7 @@ TEST_F(DeclarableOpsTests1, Pnormpool2d1) { auto block = new Context(1, variableSpace, false); block->fillInputs({ -1 }); - std::vector argI = block->getIArguments(); - argI = { kH,kW, sH,sW, pH,pW, dW,dH, 0, 1, 0 }; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; 9 - extraParam0 for pnorm case; + block->appendI({ kH,kW, sH,sW, pH,pW, dW,dH, 0, 1, 0 }); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; 9 - extraParam0 for pnorm case; sd::ops::pnormpool2d pooling; Nd4jStatus status = pooling.execute(block); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index 4772b7972553..0717d713039c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -404,8 +404,7 @@ TEST_F(DeclarableOpsTests4, avgpool2d_13) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector argI = block->getIArguments(); - argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + block->appendI({kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::avgpool2d pooling; Nd4jStatus status = pooling.execute(block); @@ -447,8 +446,7 @@ TEST_F(DeclarableOpsTests4, avgpool2d_14) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector argI = block->getIArguments(); - argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + block->appendI({kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::avgpool2d pooling; Nd4jStatus status = pooling.execute(block); @@ -490,8 +488,7 @@ TEST_F(DeclarableOpsTests4, Avgpool2d_test15) { auto block = new Context(1, variableSpace, false); block->fillInputs({-1}); - std::vector argI = block->getIArguments(); - argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + block->appendI({kH,kW, sH,sW, pH,pW, dW,dH, 1, 0, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::avgpool2d pooling; Nd4jStatus status = pooling.execute(block); From c212277daacb40288ced3da696e60c149393ec2b Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 08:47:40 +0300 Subject: [PATCH 056/233] minor fix Signed-off-by: raver119 --- libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index b8820a3f274d..24a269c00bc7 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -87,7 +87,7 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); // D - graph.addVariable("C", NDArrayFactory::create('c', {3}, {4, 4, 4})); + graph.addVariable("D", NDArrayFactory::create('c', {3}, {4, 4, 4})); Node a("multiply", sd::ops::multiply()); Node b("add", sd::ops::add()); From e23dbec2fb24d556458e8d0e6cc30b16d011c01c Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 09:49:39 +0300 Subject: [PATCH 057/233] inputs/outputs validation Signed-off-by: raver119 --- .../include/exceptions/datatype_exception.h | 8 +-- .../exceptions/impl/datatype_exception.cpp | 20 ++++---- .../impl/shape_mismatch_exception.cpp | 36 +++++++++++++ .../exceptions/shape_mismatch_exception.h | 50 +++++++++++++++++++ libnd4j/include/graph/Graph.h | 2 + libnd4j/include/graph/impl/Graph.cpp | 28 +++++++++++ .../tests_cpu/layers_tests/GraphTests2.cpp | 24 +++++++++ 7 files changed, 154 insertions(+), 14 deletions(-) create mode 100644 libnd4j/include/exceptions/impl/shape_mismatch_exception.cpp create mode 100644 libnd4j/include/exceptions/shape_mismatch_exception.h diff --git a/libnd4j/include/exceptions/datatype_exception.h b/libnd4j/include/exceptions/datatype_exception.h index 32e4441d62e7..a2e9ed96fbb7 100644 --- a/libnd4j/include/exceptions/datatype_exception.h +++ b/libnd4j/include/exceptions/datatype_exception.h @@ -36,13 +36,13 @@ namespace sd { class SD_EXPORT datatype_exception : public std::runtime_error { public: - datatype_exception(std::string message); + datatype_exception(const std::string &message); ~datatype_exception() = default; - static datatype_exception build(std::string message, sd::DataType actual); - static datatype_exception build(std::string message, sd::DataType expected, sd::DataType actual); - static datatype_exception build(std::string message, sd::DataType expected, sd::DataType actualX, sd::DataType actualY); + static datatype_exception build(const std::string &message, sd::DataType actual); + static datatype_exception build(const std::string &message, sd::DataType expected, sd::DataType actual); + static datatype_exception build(const std::string &message, sd::DataType expected, sd::DataType actualX, sd::DataType actualY); }; } diff --git a/libnd4j/include/exceptions/impl/datatype_exception.cpp b/libnd4j/include/exceptions/impl/datatype_exception.cpp index 9aab37951032..ee4bd1a5a77a 100644 --- a/libnd4j/include/exceptions/impl/datatype_exception.cpp +++ b/libnd4j/include/exceptions/impl/datatype_exception.cpp @@ -22,28 +22,28 @@ #include namespace sd { - datatype_exception::datatype_exception(std::string message) : std::runtime_error(message){ + datatype_exception::datatype_exception(const std::string &message) : std::runtime_error(message){ // } - datatype_exception datatype_exception::build(std::string message, sd::DataType expected, sd::DataType actual) { + datatype_exception datatype_exception::build(const std::string &message, sd::DataType expected, sd::DataType actual) { auto exp = DataTypeUtils::asString(expected); auto act = DataTypeUtils::asString(actual); - message += "; Expected: [" + exp + "]; Actual: [" + act + "]"; - return datatype_exception(message); + auto fmessage = message + "; Expected: [" + exp + "]; Actual: [" + act + "]"; + return datatype_exception(fmessage); } - datatype_exception datatype_exception::build(std::string message, sd::DataType expected, sd::DataType actualX, sd::DataType actualY) { + datatype_exception datatype_exception::build(const std::string &message, sd::DataType expected, sd::DataType actualX, sd::DataType actualY) { auto exp = DataTypeUtils::asString(expected); auto actX = DataTypeUtils::asString(actualX); auto actY = DataTypeUtils::asString(actualY); - message += "; Expected: [" + exp + "]; Actual: [" + actX + ", " + actY + "]"; - return datatype_exception(message); + auto fmessage = message + "; Expected: [" + exp + "]; Actual: [" + actX + ", " + actY + "]"; + return datatype_exception(fmessage); } - datatype_exception datatype_exception::build(std::string message, sd::DataType actual) { + datatype_exception datatype_exception::build(const std::string &message, sd::DataType actual) { auto act = DataTypeUtils::asString(actual); - message += "; Actual: [" + act + "]"; - return datatype_exception(message); + auto fmessage = message + "; Actual: [" + act + "]"; + return datatype_exception(fmessage); } } \ No newline at end of file diff --git a/libnd4j/include/exceptions/impl/shape_mismatch_exception.cpp b/libnd4j/include/exceptions/impl/shape_mismatch_exception.cpp new file mode 100644 index 000000000000..5b53ccc561ba --- /dev/null +++ b/libnd4j/include/exceptions/impl/shape_mismatch_exception.cpp @@ -0,0 +1,36 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include + +namespace sd { + shape_mismatch_exception::shape_mismatch_exception(const std::string &message) : std::runtime_error(message) { + // + } + + shape_mismatch_exception + shape_mismatch_exception::build(const std::string &message, const std::vector &expected, const std::vector &actual) { + auto exp = ShapeUtils::shapeAsString(expected); + auto act = ShapeUtils::shapeAsString(actual); + auto fmessage = message + "; Expected shape: " + exp + "; Actual shape: " + act + ";"; + return shape_mismatch_exception(fmessage); + } +} diff --git a/libnd4j/include/exceptions/shape_mismatch_exception.h b/libnd4j/include/exceptions/shape_mismatch_exception.h new file mode 100644 index 000000000000..8a9c2092e860 --- /dev/null +++ b/libnd4j/include/exceptions/shape_mismatch_exception.h @@ -0,0 +1,50 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + + +#ifndef SD_SHAPE_MISMATCH_EXCEPTION_H +#define SD_SHAPE_MISMATCH_EXCEPTION_H + +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) + +// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library +#pragma warning( disable : 4275 ) + +#endif + +namespace sd { + class SD_EXPORT shape_mismatch_exception : public std::runtime_error { + public: + shape_mismatch_exception(const std::string &message); + ~shape_mismatch_exception() = default; + + static shape_mismatch_exception build(const std::string &message, const std::vector &expected, const std::vector &actual); + }; +} + + +#endif //SD_SHAPE_MISMATCH_EXCEPTION_H diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index bc78fc2fb8cc..448503e63808 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -68,6 +68,8 @@ namespace sd { int idByName(const std::string &nodeName) const; void printOutNode(const Node &node) const; + + std::vector _placeholders; public: Graph(const FlatGraph *flatGraph = nullptr, VariableSpace *variableSpace = nullptr, const GraphMemoryManager &memoryManager = GraphMemoryManager()); diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index c090ff1574ee..decf1a026f4a 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -31,6 +31,8 @@ #include #include #include +#include +#include namespace sd { namespace graph { @@ -546,13 +548,39 @@ namespace sd { auto var = new Variable(true, dataType, shape); var->setName(nodeName); _variableSpace->putVariable(id, var); + + _placeholders.emplace_back(var); } std::map Graph::execute(const std::map &dictionary, const std::vector &outputs, const GraphExecutor &executor) const { // first of all we check existence of placeholders in dictionary + int placeholdersCount = 0; for (const auto &v:dictionary) { if (_symbolicLookupTable.count(v.first) == 0) throw unresolved_input_exception::build("Dictionary entry doesn't exist", v.first); + + // we also check if arrays provided here do match placeholder restrictions of shape and dtype + auto var = _variableSpace->getVariable(v.first); + if (var->dataType() != DataType::ANY && var->dataType() != v.second.dataType()) + throw datatype_exception::build("Placeholder requires another data type", var->dataType(), v.second.dataType()); + + auto shape = v.second.getShapeAsVector(); + if (shape != var->shape()) + throw shape_mismatch_exception::build("Placeholder requires specific shape", var->shape(), shape); + + // we must also check if all placeholders were resolved + placeholdersCount++; + } + + // TODO: it would be nice if we'll print out unresolved placeholders + if (placeholdersCount != _placeholders.size()) + throw std::runtime_error("Some placeholders were not resolved"); + + + // we also must check existence of requested outputs + for (const auto &v:outputs) { + if (_symbolicLookupTable.count(v) == 0) + throw unresolved_output_exception::build("Requested output doesn't exist", v); } // TODO: implement this method diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp index baac6bbc154a..c910b4310b0a 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp @@ -31,6 +31,8 @@ #include #include #include +#include +#include using namespace sd; using namespace sd::graph; @@ -110,6 +112,28 @@ TEST_F(GraphTests2, test_placeholder_resolution_2) { // TODO: add result validation here } +TEST_F(GraphTests2, test_placeholder_resolution_3) { + Graph graph; + + graph.addPlaceholder("input", DataType::FLOAT32); + + Node a("tanh_node", "tanh"); + graph.addNode(a, {"input"}); + + ASSERT_THROW(graph.execute({{"input", NDArrayFactory::create(5)}}, {"tanh_node"}), sd::datatype_exception); +} + +TEST_F(GraphTests2, test_placeholder_resolution_4) { + Graph graph; + + graph.addPlaceholder("input", DataType::FLOAT32, {3, 4, 5}); + + Node a("tanh_node", "tanh"); + graph.addNode(a, {"input"}); + + ASSERT_THROW(graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"tanh_node"}), sd::shape_mismatch_exception); +} + TEST_F(GraphTests2, test_output_resolution_1) { Graph graph; From d4c17d31d0ba1fe4ddf1a9c92969261c01d2c9bd Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 11:54:17 +0300 Subject: [PATCH 058/233] shared_ptr for ops in cache and Nodes Signed-off-by: raver119 --- libnd4j/include/graph/Node.h | 13 +- libnd4j/include/graph/impl/Node.cpp | 129 +++++++++++++----- libnd4j/include/helpers/helper_hash.h | 2 +- libnd4j/include/helpers/impl/helper_hash.cpp | 2 +- libnd4j/include/legacy/cpu/NativeOps.cpp | 6 +- libnd4j/include/ops/declarable/LegacyOp.h | 10 ++ .../include/ops/declarable/OpRegistrator.h | 17 +-- .../include/ops/declarable/impl/LegacyOp.cpp | 30 ++++ .../ops/declarable/impl/OpRegistrator.cpp | 43 +++--- libnd4j/include/system/op_boilerplate.h | 4 +- 10 files changed, 174 insertions(+), 82 deletions(-) diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 30e8a950dc5f..d27c8a740f22 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -81,7 +81,7 @@ namespace sd { // these fields are used to store embedded CustomOps and Graph in case of Graph-in-Graph scenario Graph * _graph= nullptr; - sd::ops::DeclarableOp *_customOp = nullptr; + std::shared_ptr _customOp; // each node can be active or inactive, if used with divergents, like IF statements bool _active = true; @@ -111,6 +111,7 @@ namespace sd { explicit Node(const std::string &opName, const std::string &nodeName, const int id, const std::vector &inputs = {}, const std::vector &tArgs = {}, const std::vector &iArgs = {}); explicit Node(const std::string &opName, const int id = 0, const std::vector> &inputs = {}, const std::vector &tArgs = {}, const std::vector &iArgs = {}); explicit Node(sd::ops::DeclarableOp *customOp, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); + explicit Node(std::shared_ptr customOp, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); explicit Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); @@ -195,8 +196,8 @@ namespace sd { const ContextPrototype& contextPrototype() const; bool hasBlockAttached(); - void setCustomOp(sd::ops::DeclarableOp *customOp = nullptr); - sd::ops::DeclarableOp* customOp() const; + void setCustomOp(std::shared_ptr customOp); + std::shared_ptr customOp() const; bool hasCustomOp() const; void setGraph(Graph* graph = nullptr); @@ -242,10 +243,6 @@ namespace sd { this->setLayer(other->getLayer()); this->setDeductable(other->isDeductable()); - - if (this->_customOp != nullptr && _isDeductable) - delete this->_customOp; - for (auto &v: other->input()) this->_input.emplace_back(v); @@ -257,7 +254,7 @@ namespace sd { } - static sd::ops::DeclarableOp* buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar); + static std::shared_ptr buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar); static void deleteOpByType(OpType opType, void *op); }; } diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 201b412cc911..882391732a10 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -187,15 +187,15 @@ namespace sd { _id = id; } - sd::ops::DeclarableOp* Node::customOp() const { + std::shared_ptr Node::customOp() const { return _customOp; } - void Node::setCustomOp(sd::ops::DeclarableOp *customOp) { + void Node::setCustomOp(std::shared_ptr customOp) { _customOp = customOp; // divergent ops (Switch etc) are always inplace, they don't allocate anything - if (_customOp != nullptr && customOp->getOpDescriptor()->isDivergent()) + if (_customOp.get() != nullptr && _customOp->getOpDescriptor()->isDivergent()) _isInplace = true; } @@ -439,7 +439,12 @@ namespace sd { this->_extraParams = nullptr; this->_dataType = sd::DataType::FLOAT32; // float as default this->_dim = nullptr; - this->_customOp = customOp; + + // if custom op is a registered one - pull it from cache, otherwise - clone locally + if (sd::ops::OpRegistrator::getInstance()->hasOperation(_opNum)) + this->_customOp = sd::ops::OpRegistrator::getInstance()->getOperation(_opNum); + else + throw std::runtime_error("Can't create a node with custom operation within"); _hasExternalInputs = false; _hasExternalOutputs = false; @@ -817,10 +822,6 @@ namespace sd { if (_dim != nullptr) delete[] _dim; - - if (_isDeductable && _customOp != nullptr) { - Node::deleteOpByType(_opType, _customOp); - } } int Node::getRewindNode() { @@ -848,7 +849,38 @@ namespace sd { } Node::Node(const Node &other) noexcept { + _dataType = other._dataType; + _opType = other._opType; + _opClass = other._opClass; + _opNum = other._opNum; + _customOp = other._customOp; + _name = other._name; + _scope_id = other._scope_id; + _scope_name = other._scope_name; + _rewindNode = other._rewindNode; + _layer = other._layer; + _hasExternalOutputs = other._hasExternalOutputs; + _hasExternalInputs = other._hasExternalInputs; + _hasInternalOutputs = other._hasInternalOutputs; + _hasInternalInputs = other._hasInternalInputs; + _isInplace = other._isInplace; + _isDeductable = other._isDeductable; + _active = other._active; + _removable = other._removable; + + _graph = other._graph; + _customOp = other._customOp; + _dim = other._dim; + _extraParams = other._extraParams; + _protoContext = other._protoContext; + + _input = other._input; + _output = other._output; + _dimensions = other._dimensions; + _rewindLayer = other._rewindLayer; + _referencedBy = other._referencedBy; + _scalar = other._scalar; } Node &Node::operator=(const Node &other) noexcept { @@ -892,7 +924,40 @@ namespace sd { } Node::Node(Node &&other) noexcept { + _dataType = other._dataType; + _opType = other._opType; + _opClass = other._opClass; + _opNum = other._opNum; + _customOp = other._customOp; + _scope_id = other._scope_id; + _name = std::move(other._name); + _scope_name = std::move(other._scope_name); + _rewindNode = other._rewindNode; + _layer = other._layer; + + _hasExternalOutputs = other._hasExternalOutputs; + _hasExternalInputs = other._hasExternalInputs; + _hasInternalOutputs = other._hasInternalOutputs; + _hasInternalInputs = other._hasInternalInputs; + _isInplace = other._isInplace; + _isDeductable = other._isDeductable; + _active = other._active; + _removable = other._removable; + _graph = other._graph; + _customOp = other._customOp; + _dim = other._dim; + _extraParams = other._extraParams; + _protoContext = other._protoContext; + + _input = std::move(other._input); + _output = std::move(other._output); + _dimensions = std::move(other._dimensions); + _rewindLayer = std::move(other._rewindLayer); + _referencedBy = std::move(other._referencedBy); + _scalar = std::move(other._scalar); + + other._customOp = nullptr; } Node &Node::operator=(Node &&other) noexcept { @@ -1001,44 +1066,44 @@ namespace sd { } } - sd::ops::DeclarableOp* Node::buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar) { + std::shared_ptr Node::buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar) { switch (opType) { case OpType_PAIRWISE: - return new sd::ops::LegacyPairwiseTransformOp(opNum); + return std::make_shared(opNum); case OpType_PAIRWISE_BOOL: - return new sd::ops::LegacyPairwiseTransformBoolOp(opNum); + return std::make_shared(opNum); case OpType_TRANSFORM_STRICT: - return new sd::ops::LegacyTransformStrictOp(opNum); + return std::make_shared(opNum); case OpType_TRANSFORM_SAME: - return new sd::ops::LegacyTransformSameOp(opNum); + return std::make_shared(opNum); case OpType_TRANSFORM_FLOAT: - return new sd::ops::LegacyTransformFloatOp(opNum); + return std::make_shared(opNum); case OpType_TRANSFORM_BOOL: - return new sd::ops::LegacyTransformBoolOp(opNum); + return std::make_shared(opNum); case OpType_SCALAR: - return scalar == nullptr ? new sd::ops::LegacyScalarOp(opNum) : new sd::ops::LegacyScalarOp(opNum, *scalar); + return scalar == nullptr ? std::make_shared(opNum) : std::make_shared(opNum, *scalar); case OpType_SCALAR_BOOL: - return scalar == nullptr ? new sd::ops::LegacyScalarBoolOp(opNum) : new sd::ops::LegacyScalarBoolOp(opNum, *scalar); + return scalar == nullptr ? std::make_shared(opNum) : std::make_shared(opNum, *scalar); case OpType_REDUCE_3: - return new sd::ops::LegacyReduce3Op(opNum); + return std::make_shared(opNum); case OpType_REDUCE_SAME: - return new sd::ops::LegacyReduceSameOp(opNum); + return std::make_shared(opNum); case OpType_REDUCE_FLOAT: - return new sd::ops::LegacyReduceFloatOp(opNum); + return std::make_shared(opNum); case OpType_REDUCE_LONG: - return new sd::ops::LegacyReduceLongOp(opNum); + return std::make_shared(opNum); case OpType_REDUCE_BOOL: - return new sd::ops::LegacyReduceBoolOp(opNum); + return std::make_shared(opNum); case OpType_INDEX_REDUCE: - return new sd::ops::LegacyIndexReduceOp(opNum); + return std::make_shared(opNum); case OpType_SUMMARYSTATS: - return new sd::ops::LegacyStatsOp(opNum); + return std::make_shared(opNum); case OpType_RANDOM: - return new sd::ops::LegacyRandomOp(opNum); + return std::make_shared(opNum); case OpType_BROADCAST: - return new sd::ops::LegacyBroadcastOp(opNum); + return std::make_shared(opNum); case OpType_BROADCAST_BOOL: - return new sd::ops::LegacyBroadcastBoolOp(opNum); + return std::make_shared(opNum); default: throw std::runtime_error("Bad opType passed in"); } @@ -1055,7 +1120,8 @@ namespace sd { Node* Node::clone() { if (this->_customOp && this->_opType == OpType_CUSTOM) { - auto clone = new Node(this->_customOp, _id); + auto clone = new Node(nullptr, _id); + clone->_customOp = _customOp; clone->pullValues(this); return clone; } @@ -1065,12 +1131,7 @@ namespace sd { clone->pullValues(this); // op time - if (!_isDeductable) - clone->_customOp = _customOp; - else { - auto c = dynamic_cast(_customOp); - clone->_customOp = c->clone(); - } + clone->_customOp = _customOp; return clone; } diff --git a/libnd4j/include/helpers/helper_hash.h b/libnd4j/include/helpers/helper_hash.h index 0ff76eea1bcc..10a22f2237d4 100644 --- a/libnd4j/include/helpers/helper_hash.h +++ b/libnd4j/include/helpers/helper_hash.h @@ -42,7 +42,7 @@ namespace sd { public: static HashHelper* getInstance(); - Nd4jLong getLongHash(std::string& str); + Nd4jLong getLongHash(const std::string& str); }; } } diff --git a/libnd4j/include/helpers/impl/helper_hash.cpp b/libnd4j/include/helpers/impl/helper_hash.cpp index b12acb273fc8..0f45c8cb3616 100644 --- a/libnd4j/include/helpers/impl/helper_hash.cpp +++ b/libnd4j/include/helpers/impl/helper_hash.cpp @@ -31,7 +31,7 @@ namespace sd { return _INSTANCE; } - Nd4jLong HashHelper::getLongHash(std::string& str) { + Nd4jLong HashHelper::getLongHash(const std::string& str) { _locker.lock(); if (!_isInit) { nd4j_verbose("Building HashUtil table\n",""); diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp index 2b0409b88205..67419d31e02c 100644 --- a/libnd4j/include/legacy/cpu/NativeOps.cpp +++ b/libnd4j/include/legacy/cpu/NativeOps.cpp @@ -2015,7 +2015,7 @@ sd::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, try { auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); - return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs); + return _calculateOutputShapes(extraPointers, op.get(), inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs); } catch (std::exception &e) { sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -2046,7 +2046,7 @@ sd::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, try { auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); - return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs); + return _calculateOutputShapes(extraPointers, op.get(), inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs); } catch (std::exception &e) { sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); @@ -2155,7 +2155,7 @@ Nd4jStatus realExec(sd::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4jL int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) { try { auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); - return realExec(op, extraPointers, hash, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, numOutputs, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace); + return realExec(op.get(), extraPointers, hash, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, numOutputs, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace); } catch (std::exception &e) { sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); diff --git a/libnd4j/include/ops/declarable/LegacyOp.h b/libnd4j/include/ops/declarable/LegacyOp.h index 90ebf0fcc975..f199f362120b 100644 --- a/libnd4j/include/ops/declarable/LegacyOp.h +++ b/libnd4j/include/ops/declarable/LegacyOp.h @@ -47,6 +47,16 @@ namespace sd { LegacyOp(int numInputs, int opNum); ~LegacyOp() = default; + LegacyOp(const LegacyOp& other) noexcept; + + LegacyOp& operator=(const LegacyOp& other) noexcept; + + // move constructor + LegacyOp(LegacyOp&& other) noexcept; + + // move assignment operator + LegacyOp& operator=(LegacyOp&& other) noexcept; + // All Op classes provide own specific implementation for this method ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override = 0; virtual LegacyOp* clone() = 0; diff --git a/libnd4j/include/ops/declarable/OpRegistrator.h b/libnd4j/include/ops/declarable/OpRegistrator.h index a024c0f5ad03..e055b7622a3f 100644 --- a/libnd4j/include/ops/declarable/OpRegistrator.h +++ b/libnd4j/include/ops/declarable/OpRegistrator.h @@ -82,9 +82,8 @@ namespace sd { MAP_IMPL _msvc; // pointers to our operations - MAP_IMPL _declarablesLD; - MAP_IMPL _declarablesD; - std::vector _uniqueD; + MAP_IMPL> _declarablesLD; + MAP_IMPL> _declarablesD; // pointers to platform-specific helpers MAP_IMPL, sd::ops::platforms::PlatformHelper*> _helpersLH; @@ -114,16 +113,18 @@ namespace sd { * * @param op */ - bool registerOperation(const char* name, sd::ops::DeclarableOp* op); - bool registerOperation(sd::ops::DeclarableOp *op); + bool registerOperation(const std::string &opName, std::shared_ptr op); + bool registerOperation(std::shared_ptr op); void registerHelper(sd::ops::platforms::PlatformHelper* op); bool hasHelper(Nd4jLong hash, samediff::Engine engine); - sd::ops::DeclarableOp* getOperation(const char *name); - sd::ops::DeclarableOp* getOperation(Nd4jLong hash); - sd::ops::DeclarableOp* getOperation(const std::string &name); + std::shared_ptr getOperation(Nd4jLong hash); + std::shared_ptr getOperation(const std::string &name); + + bool hasOperation(const std::string &opName) const; + bool hasOperation(const Nd4jLong opName) const; sd::ops::platforms::PlatformHelper* getPlatformHelper(Nd4jLong hash, samediff::Engine engine); diff --git a/libnd4j/include/ops/declarable/impl/LegacyOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyOp.cpp index c179488dffb4..dfa4d42a0002 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyOp.cpp @@ -31,5 +31,35 @@ namespace sd { _opNum = opNum; _numInputs = numInputs; } + + LegacyOp::LegacyOp(const LegacyOp &other) noexcept { + _numInputs = other._numInputs; + _opNum = other._opNum; + } + + LegacyOp &LegacyOp::operator=(const LegacyOp &other) noexcept { + if (this == &other) + return *this; + + _numInputs = other._numInputs; + _opNum = other._opNum; + + return *this; + } + + LegacyOp::LegacyOp(LegacyOp &&other) noexcept { + _numInputs = other._numInputs; + _opNum = other._opNum; + } + + LegacyOp &LegacyOp::operator=(LegacyOp &&other) noexcept { + if (this == &other) + return *this; + + _numInputs = other._numInputs; + _opNum = other._opNum; + + return *this; + } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp index 2ad7fd93670c..a37b6e422395 100644 --- a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp +++ b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp @@ -110,14 +110,9 @@ namespace sd { #ifndef _RELEASE _msvc.clear(); - for (auto x : _uniqueD) - delete x; - for (auto x: _uniqueH) delete x; - _uniqueD.clear(); - _uniqueH.clear(); _declarablesD.clear(); @@ -130,7 +125,7 @@ namespace sd { _locker.lock(); if (!isInit) { - for (MAP_IMPL::iterator it=_declarablesD.begin(); it!=_declarablesD.end(); ++it) { + for (MAP_IMPL>::iterator it=_declarablesD.begin(); it!=_declarablesD.end(); ++it) { std::string op = it->first + ":" + local_to_string(it->second->getOpDescriptor()->getHash()) + ":" + local_to_string(it->second->getOpDescriptor()->getNumberOfInputs()) + ":" @@ -150,14 +145,20 @@ namespace sd { return _opsList.c_str(); } + bool OpRegistrator::hasOperation(const std::string &opName) const { + return _declarablesD.count(opName) > 1; + } - bool OpRegistrator::registerOperation(const char* name, sd::ops::DeclarableOp* op) { - std::string str(name); - std::pair pair(str, op); + bool OpRegistrator::hasOperation(Nd4jLong opName) const { + return _declarablesLD.count(opName) > 1; + } + + bool OpRegistrator::registerOperation(const std::string &opName, std::shared_ptr op) { + std::pair> pair(opName, op); _declarablesD.insert(pair); - auto hash = sd::ops::HashHelper::getInstance()->getLongHash(str); - std::pair pair2(hash, op); + auto hash = sd::ops::HashHelper::getInstance()->getLongHash(opName); + std::pair> pair2(hash, op); _declarablesLD.insert(pair2); return true; } @@ -167,9 +168,8 @@ namespace sd { * * @param op */ - bool OpRegistrator::registerOperation(sd::ops::DeclarableOp *op) { - _uniqueD.emplace_back(op); - return registerOperation(op->getOpName().c_str(), op); + bool OpRegistrator::registerOperation(std::shared_ptr op) { + return registerOperation(op->getOpName(), op); } void OpRegistrator::registerHelper(sd::ops::platforms::PlatformHelper* op) { @@ -188,40 +188,33 @@ namespace sd { _helpersLH.insert(pair2); } - sd::ops::DeclarableOp* OpRegistrator::getOperation(const char *name) { - std::string str(name); - return getOperation(str); - } - /** * This method returns registered Op by name * * @param name * @return */ - sd::ops::DeclarableOp *OpRegistrator::getOperation(Nd4jLong hash) { + std::shared_ptr OpRegistrator::getOperation(Nd4jLong hash) { if (!_declarablesLD.count(hash)) { if (!_msvc.count(hash)) { nd4j_printf("Unknown D operation requested by hash: [%lld]\n", hash); return nullptr; } else { - _locker.lock(); + std::lock_guard lock(_locker); auto str = _msvc.at(hash); auto op = _declarablesD.at(str); auto oHash = op->getOpDescriptor()->getHash(); - std::pair pair(oHash, op); + std::pair> pair(oHash, op); _declarablesLD.insert(pair); - - _locker.unlock(); } } return _declarablesLD.at(hash); } - sd::ops::DeclarableOp *OpRegistrator::getOperation(const std::string& name) { + std::shared_ptr OpRegistrator::getOperation(const std::string& name) { if (!_declarablesD.count(name)) { nd4j_debug("Unknown operation requested: [%s]\n", name.c_str()); return nullptr; diff --git a/libnd4j/include/system/op_boilerplate.h b/libnd4j/include/system/op_boilerplate.h index 816888931bcd..92aeb9218445 100644 --- a/libnd4j/include/system/op_boilerplate.h +++ b/libnd4j/include/system/op_boilerplate.h @@ -1268,7 +1268,7 @@ #define REGISTER_C(NAME) template \ struct __registrator_##NAME {\ __registrator_##NAME() {\ - OpName *ptr = new OpName(); \ + auto ptr = std::make_shared(); \ OpRegistrator::getInstance()->registerOperation(ptr); \ }\ };\ @@ -1343,7 +1343,7 @@ #define DECLARE_SYN(NAME, ORIGINAL) template \ struct __registratorSynonym_##NAME {\ __registratorSynonym_##NAME(const char *name, const char *oname) {\ - auto ptr = reinterpret_cast(OpRegistrator::getInstance()->getOperation(oname)); \ + auto ptr = OpRegistrator::getInstance()->getOperation(oname); \ if (ptr == nullptr) { \ std::string newName(name); \ std::string oldName(oname); \ From ba31c638492ec7e740afde1e73a3f1c60ca90701 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 12:03:19 +0300 Subject: [PATCH 059/233] few small fixes Signed-off-by: raver119 --- libnd4j/include/graph/impl/Node.cpp | 3 +-- libnd4j/include/ops/declarable/impl/OpRegistrator.cpp | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 882391732a10..613fe57d2dcc 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -1120,8 +1120,7 @@ namespace sd { Node* Node::clone() { if (this->_customOp && this->_opType == OpType_CUSTOM) { - auto clone = new Node(nullptr, _id); - clone->_customOp = _customOp; + auto clone = new Node(_customOp.get(), _id); clone->pullValues(this); return clone; } diff --git a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp index a37b6e422395..d94748eb9ddb 100644 --- a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp +++ b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp @@ -146,11 +146,11 @@ namespace sd { } bool OpRegistrator::hasOperation(const std::string &opName) const { - return _declarablesD.count(opName) > 1; + return _declarablesD.count(opName) > 0; } bool OpRegistrator::hasOperation(Nd4jLong opName) const { - return _declarablesLD.count(opName) > 1; + return _declarablesLD.count(opName) > 0; } bool OpRegistrator::registerOperation(const std::string &opName, std::shared_ptr op) { From d4434d64a903352f7db1dbb087103bbcd9a96415 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 12:44:59 +0300 Subject: [PATCH 060/233] few other Node fields removed Signed-off-by: raver119 --- libnd4j/include/graph/Node.h | 19 +----- libnd4j/include/graph/impl/Node.cpp | 99 ++++------------------------- 2 files changed, 14 insertions(+), 104 deletions(-) diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index d27c8a740f22..d00722dde675 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -51,7 +51,6 @@ namespace sd { std::vector _referencedBy; - int * _dim = nullptr; std::string _name; @@ -61,11 +60,6 @@ namespace sd { // many ops require extra parameters to run double *_extraParams = nullptr; - - // optional scalar. used in scalar ops and in summary stats - // TODO: this field must be removed - NDArray _scalar; - bool _hasExternalOutputs; bool _hasExternalInputs; bool _hasInternalOutputs; @@ -166,12 +160,6 @@ namespace sd { bool hasInternalOutputs(); bool hasInternalInputs(); - double scalar(); - - std::vector * getDimensions(); - int * getDimensionsPtr(); - - void pickOutputOnce(int outputId); void pickOutput(int outputId); void pickOutput(int nodeId, int outputId); @@ -231,7 +219,6 @@ namespace sd { FORCEINLINE void pullValues(Node *other) { this->_dataType = other->dataType(); this->_protoContext = other->protoContext(); - this->_scalar = other->scalar(); this->_hasExternalInputs = other->hasExternalInputs(); this->_hasExternalOutputs = other->hasExternalOutputs(); this->_hasInternalInputs = other->hasInternalInputs(); @@ -248,13 +235,9 @@ namespace sd { for (auto &v: other->output()) this->_output.emplace_back(v); - - for (auto v: *other->getDimensions()) - this->_dimensions.emplace_back(v); - } - static std::shared_ptr buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar); + static std::shared_ptr buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum); static void deleteOpByType(OpType opType, void *op); }; } diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 613fe57d2dcc..455459a34bb2 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -54,7 +54,6 @@ namespace sd { this->_opNum = customOp->getOpHash(); this->_extraParams = nullptr; this->_dataType = sd::DataType::FLOAT32; // float as default - this->_dim = nullptr; this->_customOp = customOp; _hasExternalInputs = false; @@ -62,9 +61,6 @@ namespace sd { _hasInternalInputs = false; _hasInternalOutputs = false; - // FIXME: get rid of this!!! - _scalar = NDArrayFactory::create(0); - ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); block.appendI(iArgs); @@ -86,7 +82,6 @@ namespace sd { this->_opNum = customOp->getOpHash(); this->_extraParams = nullptr; this->_dataType = sd::DataType::FLOAT32; // float as default - this->_dim = nullptr; this->_customOp = customOp; _hasExternalInputs = false; @@ -94,9 +89,6 @@ namespace sd { _hasInternalInputs = false; _hasInternalOutputs = false; - // FIXME: get rid of this!!! - _scalar = NDArrayFactory::create(0); - ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); block.appendI(iArgs); @@ -219,9 +211,6 @@ namespace sd { _name = *name; } - double Node::scalar() { - return _scalar.e(0); - }; void Node::pickInput(std::pair& pair) { _input.push_back(pair); @@ -273,14 +262,6 @@ namespace sd { _hasInternalOutputs = true; } - int * Node::getDimensionsPtr() { - return _dim; - } - - std::vector * Node::getDimensions() { - return &_dimensions; - } - int Node::getLayer() { return _layer; } @@ -380,7 +361,6 @@ namespace sd { this->_opNum = customOp->getOpHash(); this->_extraParams = nullptr; this->_dataType = sd::DataType::FLOAT32; // float as default - this->_dim = nullptr; this->_customOp = customOp; _hasExternalInputs = false; @@ -388,9 +368,6 @@ namespace sd { _hasInternalInputs = false; _hasInternalOutputs = false; - // FIXME: get rid of this!!! - _scalar = NDArrayFactory::create(0); - for (auto i: inputs) pickInput(i); @@ -410,7 +387,6 @@ namespace sd { this->_opNum = customOp->getOpHash(); this->_extraParams = nullptr; this->_dataType = sd::DataType::FLOAT32; // float as default - this->_dim = nullptr; this->_customOp = customOp; _hasExternalInputs = false; @@ -418,9 +394,6 @@ namespace sd { _hasInternalInputs = false; _hasInternalOutputs = false; - // FIXME: get rid of this!!! - _scalar = NDArrayFactory::create(0); - for (auto i: inputs) pickInput(i); @@ -438,7 +411,6 @@ namespace sd { this->_opNum = customOp->getOpHash(); this->_extraParams = nullptr; this->_dataType = sd::DataType::FLOAT32; // float as default - this->_dim = nullptr; // if custom op is a registered one - pull it from cache, otherwise - clone locally if (sd::ops::OpRegistrator::getInstance()->hasOperation(_opNum)) @@ -451,24 +423,12 @@ namespace sd { _hasInternalInputs = false; _hasInternalOutputs = false; - // FIXME: get rid of this!!! - _scalar = NDArrayFactory::create(scalar); - for (auto i: input) pickInput(i); for (auto o: output) pickOutput(o); - if (dimensions.size() > 0) { - _dim = new int[dimensions.size()]; - int cnt = 0; - for (auto d: dimensions) { - _dimensions.push_back(d); - _dim[cnt++] = d; - } - } - ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); for (auto v: dimensions) @@ -493,30 +453,18 @@ namespace sd { this->_opNum = opNum; this->_extraParams = nullptr; this->_dataType = sd::DataType::FLOAT32; // float as default - this->_dim = nullptr; _hasExternalInputs = false; _hasExternalOutputs = false; _hasInternalInputs = false; _hasInternalOutputs = false; - _scalar = NDArrayFactory::create(scalar); - for (auto i: input) pickInput(i); for (auto o: output) pickOutput(o); - if (dimensions.size() > 0) { - _dim = new int[dimensions.size()]; - int cnt = 0; - for (auto d: dimensions) { - _dimensions.push_back(d); - _dim[cnt++] = d; - } - } - // these ops allow in-place execution by design if (opType == OpType_TRANSFORM_SAME || opType == OpType_TRANSFORM_FLOAT || opType == OpType_TRANSFORM_STRICT || opType == OpType_TRANSFORM_BOOL || opType == OpType_SCALAR || opType == OpType_BROADCAST) { if (_output.size() <= 1) { @@ -561,7 +509,7 @@ namespace sd { this->setContextPrototype(block); - this->setCustomOp(Node::buildOpByType(opType, (int) input.size(), (int) block.getIArguments().size(), (int) block.getTArguments().size(), opNum, &_scalar)); + this->setCustomOp(Node::buildOpByType(opType, (int) input.size(), (int) block.getIArguments().size(), (int) block.getTArguments().size(), opNum)); block.setOpDescriptor(this->customOp()->getOpDescriptor()); } else if (opType == OpType_CUSTOM) { if (this->customOp()) { @@ -587,7 +535,7 @@ namespace sd { _hasInternalInputs = false; _hasInternalOutputs = false; _extraParams = nullptr; - _dim = nullptr; + _dataType = sd::DataType::FLOAT32; // float as default if (node->scope_id() != 0) this->_scope_id = node->scope_id(); @@ -595,11 +543,8 @@ namespace sd { if (node->scope_name() != nullptr && node->scope_name()->size() > 0) this->_scope_name = node->scope_name()->str(); - if (node->scalar() != nullptr) { - auto scalar = FlatUtils::fromFlatArray(node->scalar()); - _scalar = *scalar; - delete scalar; - } + if (node->scalar() != nullptr) + throw std::runtime_error("FlatNode has scalar defined, it's deprecated"); if (node != nullptr) { this->_id = node->id(); @@ -648,13 +593,8 @@ namespace sd { } } - if (node->dimensions() != nullptr && node->dimensions()->size() > 0) { - _dim = new int[node->dimensions()->size()]; - for (int e = 0; e < (int) node->dimensions()->size(); e++) { - _dimensions.emplace_back(node->dimensions()->Get(e)); - _dim[e] = node->dimensions()->Get(e); - } - } + if (node->dimensions() != nullptr && node->dimensions()->size() > 0) + throw std::runtime_error("FlatNode has dimensions defined. Graph is outdated"); if (this->opType() == OpType_LOGIC && this->opNum() == 100L) { if (node->extraInteger()->size() < 1) { @@ -720,7 +660,7 @@ namespace sd { } this->setContextPrototype(block); - this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block.getIArguments().size(), (int) block.getTArguments().size(), (int) _opNum, &_scalar)); + this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block.getIArguments().size(), (int) block.getTArguments().size(), (int) _opNum)); block.setOpDescriptor(this->customOp()->getOpDescriptor()); } else if (node->inputPaired() != nullptr && node->inputPaired()->size() > 0) { this->_isDeductable = true; @@ -758,7 +698,7 @@ namespace sd { this->setContextPrototype(block); - this->setCustomOp(Node::buildOpByType(_opType, (int) node->inputPaired()->size(), (int) block.getIArguments().size(), (int) block.getTArguments().size(), (int) _opNum, &_scalar)); + this->setCustomOp(Node::buildOpByType(_opType, (int) node->inputPaired()->size(), (int) block.getIArguments().size(), (int) block.getTArguments().size(), (int) _opNum)); block.setOpDescriptor(this->customOp()->getOpDescriptor()); } } else if (this->_opType == OpType_CUSTOM) { @@ -819,9 +759,6 @@ namespace sd { Node::~Node() { if (_extraParams != nullptr) delete[] _extraParams; - - if (_dim != nullptr) - delete[] _dim; } int Node::getRewindNode() { @@ -871,7 +808,6 @@ namespace sd { _graph = other._graph; _customOp = other._customOp; - _dim = other._dim; _extraParams = other._extraParams; _protoContext = other._protoContext; @@ -880,7 +816,6 @@ namespace sd { _dimensions = other._dimensions; _rewindLayer = other._rewindLayer; _referencedBy = other._referencedBy; - _scalar = other._scalar; } Node &Node::operator=(const Node &other) noexcept { @@ -909,7 +844,6 @@ namespace sd { _graph = other._graph; _customOp = other._customOp; - _dim = other._dim; _extraParams = other._extraParams; _protoContext = other._protoContext; @@ -918,7 +852,6 @@ namespace sd { _dimensions = other._dimensions; _rewindLayer = other._rewindLayer; _referencedBy = other._referencedBy; - _scalar = other._scalar; return *this; } @@ -945,17 +878,15 @@ namespace sd { _removable = other._removable; _graph = other._graph; - _customOp = other._customOp; - _dim = other._dim; _extraParams = other._extraParams; _protoContext = other._protoContext; + _customOp = std::move(other._customOp); _input = std::move(other._input); _output = std::move(other._output); _dimensions = std::move(other._dimensions); _rewindLayer = std::move(other._rewindLayer); _referencedBy = std::move(other._referencedBy); - _scalar = std::move(other._scalar); other._customOp = nullptr; } @@ -985,19 +916,15 @@ namespace sd { _removable = other._removable; _graph = other._graph; - _customOp = other._customOp; - _dim = other._dim; _extraParams = other._extraParams; _protoContext = other._protoContext; + _customOp = std::move(other._customOp); _input = std::move(other._input); _output = std::move(other._output); _dimensions = std::move(other._dimensions); _rewindLayer = std::move(other._rewindLayer); _referencedBy = std::move(other._referencedBy); - _scalar = std::move(other._scalar); - - other._customOp = nullptr; return *this; } @@ -1066,7 +993,7 @@ namespace sd { } } - std::shared_ptr Node::buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar) { + std::shared_ptr Node::buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum) { switch (opType) { case OpType_PAIRWISE: return std::make_shared(opNum); @@ -1081,9 +1008,9 @@ namespace sd { case OpType_TRANSFORM_BOOL: return std::make_shared(opNum); case OpType_SCALAR: - return scalar == nullptr ? std::make_shared(opNum) : std::make_shared(opNum, *scalar); + return std::make_shared(opNum); case OpType_SCALAR_BOOL: - return scalar == nullptr ? std::make_shared(opNum) : std::make_shared(opNum, *scalar); + return std::make_shared(opNum); case OpType_REDUCE_3: return std::make_shared(opNum); case OpType_REDUCE_SAME: From 7ea87ced33fc407eec587fda6c5ff55d653f0b64 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 13:17:12 +0300 Subject: [PATCH 061/233] next step Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 20 ++- libnd4j/include/graph/impl/Graph.cpp | 63 +++++++++- libnd4j/include/legacy/cpu/NativeOps.cpp | 2 +- .../layers_tests/GraphHolderTests.cpp | 14 +-- .../tests_cpu/layers_tests/OneOffTests.cpp | 119 ++++++------------ 5 files changed, 115 insertions(+), 103 deletions(-) diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 448503e63808..87a60718a172 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -47,7 +47,7 @@ namespace sd { protected: ExecutorConfiguration _configuration; VariableSpace *_variableSpace; - Stash* _stash; + Stash _stash; MAP_IMPL _unmapped; @@ -69,18 +69,28 @@ namespace sd { void printOutNode(const Node &node) const; - std::vector _placeholders; + std::vector _placeholders; public: Graph(const FlatGraph *flatGraph = nullptr, VariableSpace *variableSpace = nullptr, const GraphMemoryManager &memoryManager = GraphMemoryManager()); ~Graph(); + Graph(const Graph& other); + + Graph& operator=(const Graph& other) noexcept; + + // move constructor + Graph(Graph&& other); + + // move assignment operator + Graph& operator=(Graph&& other) noexcept; + /** * Methods that allow Graph imports */ - static Graph *importFromTensorFlow(const char *fileName); - static Graph* fromFlatBuffers(const char *fileName, const GraphMemoryManager &memoryManager = GraphMemoryManager()); - static Graph* fromFlatPointer(void *ptr, const GraphMemoryManager &memoryManager = GraphMemoryManager()); + static Graph importFromTensorFlow(const char *fileName); + static Graph fromFlatBuffers(const char *fileName, const GraphMemoryManager &memoryManager = GraphMemoryManager()); + static Graph fromFlatPointer(void *ptr, const GraphMemoryManager &memoryManager = GraphMemoryManager()); // method that'll print out graph Nd4jStatus validate(); diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index decf1a026f4a..69a7bd047626 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -298,7 +298,7 @@ namespace sd { } - Graph* Graph::fromFlatBuffers(const char* fileName, const GraphMemoryManager &memoryManager) { + Graph Graph::fromFlatBuffers(const char* fileName, const GraphMemoryManager &memoryManager) { // check if file exists if (!FileUtils::fileExists(fileName)) throw std::runtime_error("Graph file doesn't exist"); @@ -333,15 +333,15 @@ namespace sd { return fromFlatPointer(ptrGraph, memoryManager); } - Graph* Graph::fromFlatPointer(void *ptr, const GraphMemoryManager &memoryManager) { + Graph Graph::fromFlatPointer(void *ptr, const GraphMemoryManager &memoryManager) { // get FlatGraph out of it auto fg = GetFlatGraph(reinterpret_cast(ptr)); // return Graph from this FlatGraph - return new Graph(fg, nullptr, memoryManager); + return Graph(fg, nullptr, memoryManager); } - Graph* Graph::importFromTensorFlow(const char *fileName) { + Graph Graph::importFromTensorFlow(const char *fileName) { throw std::runtime_error("Graph::importFromTensorFlow() not implemented yet"); /* if (fileName == nullptr) @@ -549,7 +549,7 @@ namespace sd { var->setName(nodeName); _variableSpace->putVariable(id, var); - _placeholders.emplace_back(var); + _placeholders.emplace_back(nodeName); } std::map Graph::execute(const std::map &dictionary, const std::vector &outputs, const GraphExecutor &executor) const { @@ -646,6 +646,59 @@ namespace sd { return optGraf; } + Graph::Graph(const Graph &other) : _memoryMaager(other._memoryMaager) { + _configuration = other._configuration; + _variableSpace = other._variableSpace; + _stash = other._stash; + _unmapped = other._unmapped; + _symbolicLookupTable = other._symbolicLookupTable; + _built = false; + _maxId = _maxId; + } + + Graph &Graph::operator=(const Graph &other) noexcept { + if (this == &other) + return *this; + + _configuration = other._configuration; + _variableSpace = other._variableSpace; + _stash = other._stash; + _unmapped = other._unmapped; + _symbolicLookupTable = other._symbolicLookupTable; + _built = false; + _maxId = _maxId; + + return *this; + } + + Graph::Graph(Graph &&other) : _memoryMaager(other._memoryMaager) { + _configuration = other._configuration; + _variableSpace = other._variableSpace; + _stash = other._stash; + + _unmapped = std::move(other._unmapped); + _symbolicLookupTable = std::move(other._symbolicLookupTable); + + _built = false; + _maxId = _maxId; + } + + Graph &Graph::operator=(Graph &&other) noexcept { + if (this == &other) + return *this; + + _configuration = other._configuration; + _variableSpace = other._variableSpace; + _stash = other._stash; + + _unmapped = std::move(other._unmapped); + _symbolicLookupTable = std::move(other._symbolicLookupTable); + + _built = false; + _maxId = _maxId; + + return *this; + } } } diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp index 67419d31e02c..aea3cfa5c613 100644 --- a/libnd4j/include/legacy/cpu/NativeOps.cpp +++ b/libnd4j/include/legacy/cpu/NativeOps.cpp @@ -2167,7 +2167,7 @@ int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flat try { auto graph = sd::graph::Graph::fromFlatPointer(flatBufferPointer); - sd::graph::GraphHolder::getInstance()->registerGraph(graphId, graph); + //sd::graph::GraphHolder::getInstance()->registerGraph(graphId, graph); return ND4J_STATUS_OK; } catch (std::exception &e) { diff --git a/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp index f1f7195e746e..943d525db873 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp @@ -45,42 +45,36 @@ TEST_F(GraphHolderTests, SimpleTests_1) { TEST_F(GraphHolderTests, SimpleTests_2) { - auto graph = new Graph; + Graph graph; Nd4jLong graphId = 117; - GraphHolder::getInstance()->registerGraph(graphId, graph); + GraphHolder::getInstance()->registerGraph(graphId, &graph); ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(graphId)); auto graph2 = GraphHolder::getInstance()->cloneGraph(graphId); - ASSERT_TRUE(graph != graph2); ASSERT_TRUE(graph2 != nullptr); GraphHolder::getInstance()->forgetGraph(graphId); ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(graphId)); - delete graph; delete graph2; } TEST_F(GraphHolderTests, SimpleTests_3) { - auto graph = new Graph; + Graph graph; Nd4jLong graphId = 117; - GraphHolder::getInstance()->registerGraph(graphId, graph); + GraphHolder::getInstance()->registerGraph(graphId, &graph); ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(graphId)); auto graph2 = GraphHolder::getInstance()->cloneGraph(graphId); - ASSERT_TRUE(graph != graph2); ASSERT_TRUE(graph2 != nullptr); GraphHolder::getInstance()->dropGraph(graphId); ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(graphId)); - - - delete graph2; } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index 20b7780067ef..cea4ab3e55a8 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -38,28 +38,19 @@ class OneOffTests : public testing::Test { TEST_F(OneOffTests, test_avg_pool_3d_1) { auto graph = Graph::fromFlatBuffers("./resources/avg_pooling3d.fb"); - ASSERT_TRUE(graph != nullptr); - - graph->execute(); - delete graph; + graph.execute(); } TEST_F(OneOffTests, test_avg_pool_3d_2) { auto graph = Graph::fromFlatBuffers("./resources/avg_pooling3d.fb"); - ASSERT_TRUE(graph != nullptr); - - graph->execute(); - delete graph; + graph.execute(); } TEST_F(OneOffTests, test_non2d_0A_1) { auto graph = Graph::fromFlatBuffers("./resources/non2d_0A.fb"); - ASSERT_TRUE(graph != nullptr); - - graph->execute(); - delete graph; + graph.execute(); } /* @@ -84,10 +75,7 @@ TEST_F(OneOffTests, test_assert_scalar_float32_2) { sd::ops::noop op2; auto graph = Graph::fromFlatBuffers("./resources/assertsomething.fb"); - ASSERT_TRUE(graph != nullptr); - - graph->execute(); - delete graph; + graph.execute(); } @@ -95,17 +83,14 @@ TEST_F(OneOffTests, test_pad_1D_1) { auto e = NDArrayFactory::create('c', {7}, {10.f,0.778786f, 0.801198f, 0.724375f, 0.230894f, 0.727141f,10.f}); auto graph = Graph::fromFlatBuffers("./resources/pad_1D.fb"); - ASSERT_TRUE(graph != nullptr); - - graph->execute(); + graph.execute(); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(4)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(4)); - auto z = graph->getVariableSpace()->getVariable(4)->getNDArray(); + auto z = graph.getVariableSpace()->getVariable(4)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); - delete graph; } /* TEST_F(OneOffTests, test_scatter_nd_update_1) { @@ -141,108 +126,90 @@ TEST_F(OneOffTests, test_conv2d_nhwc_failed_1) { auto e = NDArrayFactory::create('c', {1, 5, 5, 6}, {0.55744928f, 0.76827729f, 1.09401524f, 0.00000000f, 0.00000000f, 0.00000000f, 0.56373537f, 0.90029907f, 0.78997850f, 0.00000000f, 0.00000000f, 0.00000000f, 0.14252824f, 0.95961076f, 0.87750554f, 0.00000000f, 0.00000000f, 0.00000000f, 0.44874173f, 0.99537718f, 1.17154264f, 0.00000000f, 0.00000000f, 0.00000000f, 0.60377145f, 0.79939061f, 0.56031001f, 0.00000000f, 0.00000000f, 0.00000000f, 0.52975273f, 0.90678585f, 0.73763013f, 0.00000000f, 0.00000000f, 0.00000000f, 0.22146404f, 0.82499605f, 0.47222072f, 0.00000000f, 0.00000000f, 0.00000000f, 0.42772964f, 0.39793295f, 0.71436501f, 0.00000000f, 0.00000000f, 0.00000000f, 0.48836520f, 1.01658893f, 0.74419701f, 0.00000000f, 0.00000000f, 0.00000000f, 0.78984612f, 0.94083673f, 0.83841157f, 0.00000000f, 0.00000000f, 0.00000000f, 0.40448499f, 0.67732805f, 0.75499672f, 0.00000000f, 0.00000000f, 0.00000000f, 0.43675962f, 0.79476535f, 0.72976631f, 0.00000000f, 0.00000000f, 0.00000000f, 0.58808053f, 0.65222591f, 0.72552216f, 0.00000000f, 0.00000000f, 0.00000000f, 0.37445742f, 1.22581339f, 1.05341125f, 0.00000000f, 0.00000000f, 0.00000000f, 0.30095795f, 0.59941679f, 0.63323414f, 0.00000000f, 0.00000000f, 0.00000000f, 0.24199286f, 1.02546394f, 0.69537812f, 0.00000000f, 0.00000000f, 0.00000000f, 0.23628944f, 0.90791851f, 1.01209974f, 0.00000000f, 0.00000000f, 0.00000000f, 0.62740159f, 0.56518674f, 0.76692569f, 0.00000000f, 0.00000000f, 0.00000000f, 0.13327584f, 0.32628393f, 0.10280430f, 0.00000000f, 0.00000000f, 0.00000000f, 0.42691272f, 0.25625113f, 0.30524066f, 0.00000000f, 0.00000000f, 0.00000000f, 0.17797673f, 0.84179950f, 0.80061519f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00199084f, 0.51838887f, 0.43932241f, 0.00000000f, 0.00000000f, 0.00000000f, 0.16684581f, 0.50822425f, 0.48668745f, 0.00000000f, 0.00000000f, 0.00000000f, 0.16749343f, 0.93093169f, 0.86871749f, 0.00000000f, 0.00000000f, 0.00000000f, 0.17486368f, 0.44460732f, 0.44499981f, 0.00000000f, 0.00000000f, 0.00000000f}); auto graph = Graph::fromFlatBuffers("./resources/channels_last_b1_k2_s1_d1_SAME_crelu.fb"); - ASSERT_TRUE(graph != nullptr); - graph->execute(); + graph.execute(); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(9)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(9)); - auto z = graph->getVariableSpace()->getVariable(9)->getNDArray(); + auto z = graph.getVariableSpace()->getVariable(9)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); - - delete graph; } TEST_F(OneOffTests, test_tensor_array_1) { auto e = NDArrayFactory::create('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f}); auto graph = Graph::fromFlatBuffers("./resources/tensor_array_close_sz1_float32_nodynamic_noname_noshape.fb"); - ASSERT_TRUE(graph != nullptr); - graph->execute(); + graph.execute(); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(5)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(5)); - auto z = graph->getVariableSpace()->getVariable(5)->getNDArray(); + auto z = graph.getVariableSpace()->getVariable(5)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); - - delete graph; } TEST_F(OneOffTests, test_tensor_array_2) { auto e = NDArrayFactory::create('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f}); auto graph = Graph::fromFlatBuffers("./resources/tensor_array_split_sz1_float32_nodynamic_noname_noshape.fb"); - ASSERT_TRUE(graph != nullptr); - graph->execute(); + graph.execute(); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(6)); - auto z = graph->getVariableSpace()->getVariable(6)->getNDArray(); + auto z = graph.getVariableSpace()->getVariable(6)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); - - delete graph; } TEST_F(OneOffTests, test_tensor_array_3) { auto e = NDArrayFactory::create('c', {3, 2, 3}, {7, 2, 9, 4, 3, 3, 8, 7, 0, 0, 6, 8, 7, 9, 0, 1, 1, 4}); auto graph = Graph::fromFlatBuffers("./resources/tensor_array_stack_sz3-1_int32_dynamic_name_shape.fb"); - ASSERT_TRUE(graph != nullptr); - graph->execute(); + graph.execute(); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(15)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(15)); - auto z = graph->getVariableSpace()->getVariable(15)->getNDArray(); + auto z = graph.getVariableSpace()->getVariable(15)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); - - delete graph; } TEST_F(OneOffTests, test_tensor_array_4) { auto e = NDArrayFactory::create('c', {2, 3}, {4, 3, 1, 1, 1, 0}); auto graph = Graph::fromFlatBuffers("./resources/tensor_array_unstack_sz1_int64_nodynamic_noname_shape2-3.fb"); - ASSERT_TRUE(graph != nullptr); - graph->execute(); + graph.execute(); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(11)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(11)); - auto z = graph->getVariableSpace()->getVariable(11)->getNDArray(); + auto z = graph.getVariableSpace()->getVariable(11)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); - - delete graph; } TEST_F(OneOffTests, test_assert_4) { auto e = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); auto graph = Graph::fromFlatBuffers("./resources/assert_type_rank2_int64.fb"); - ASSERT_TRUE(graph != nullptr); - graph->execute(); + graph.execute(); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(1)); - auto z = graph->getVariableSpace()->getVariable(1)->getNDArray(); + auto z = graph.getVariableSpace()->getVariable(1)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); - - delete graph; } // TEST_F(OneOffTests, test_cond_true_1) { @@ -299,59 +266,47 @@ TEST_F(OneOffTests, test_identity_n_2) { sd::ops::identity_n op; auto graph = Graph::fromFlatBuffers("./resources/identity_n_2.fb"); - ASSERT_TRUE(graph != nullptr); - graph->execute(); + graph.execute(); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1, 1)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(1)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(1, 1)); - auto z = graph->getVariableSpace()->getVariable(1)->getNDArray(); + auto z = graph.getVariableSpace()->getVariable(1)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); - - delete graph; } TEST_F(OneOffTests, test_non2d_1) { auto e = NDArrayFactory::create('c', {1, 1}, {5.42746449f}); auto graph = Graph::fromFlatBuffers("./resources/non2d_1.fb"); - ASSERT_TRUE(graph != nullptr); - graph->execute(); + graph.execute(); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(3)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(3)); - auto z = graph->getVariableSpace()->getVariable(3)->getNDArray(); + auto z = graph.getVariableSpace()->getVariable(3)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); - - - delete graph; } TEST_F(OneOffTests, test_reduce_all_1) { auto e = NDArrayFactory::create('c', {1, 4}, {true, false, false, false}); auto graph = Graph::fromFlatBuffers("./resources/reduce_all_rank2_d0_keep.fb"); - ASSERT_TRUE(graph != nullptr); - - graph->execute(); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); + graph.execute(); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(2)); - auto in = graph->getVariableSpace()->getVariable(2)->getNDArray(); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(1)); + ASSERT_TRUE(graph.getVariableSpace()->hasVariable(2)); + auto in = graph.getVariableSpace()->getVariable(2)->getNDArray(); - auto z = graph->getVariableSpace()->getVariable(1)->getNDArray(); + auto z = graph.getVariableSpace()->getVariable(1)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); - - - delete graph; } \ No newline at end of file From bc773619742d7edcb0bc15c7d8ce1924bcd09da3 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 13:47:47 +0300 Subject: [PATCH 062/233] one more signature changed Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 2 +- libnd4j/include/graph/impl/Graph.cpp | 8 ++++---- libnd4j/include/graph/impl/GraphHolder.cpp | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 87a60718a172..d1e502746647 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -148,7 +148,7 @@ namespace sd { /** * This method returns clone of the graph, backed by VariableProxy instead of VariableSpace */ - Graph* cloneWithProxy() const; + Graph cloneWithProxy() const; /** * This method removes reference to VariableSpace from this Graph diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 69a7bd047626..d6cf00eaea91 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -277,12 +277,12 @@ namespace sd { _configuration = configuration; } - Graph* Graph::cloneWithProxy() const { - auto clone = new Graph(); + Graph Graph::cloneWithProxy() const { + Graph clone; - clone->replaceState(new VariableProxy(this->_variableSpace), this->_configuration.clone()); + clone.replaceState(new VariableProxy(this->_variableSpace), this->_configuration); - throw std::runtime_error("Graph::cloneWithProxy - not implemented yet"); + return clone; } Graph* Graph::clone() const { diff --git a/libnd4j/include/graph/impl/GraphHolder.cpp b/libnd4j/include/graph/impl/GraphHolder.cpp index b79ba2b5ff64..c0b3fe59f421 100644 --- a/libnd4j/include/graph/impl/GraphHolder.cpp +++ b/libnd4j/include/graph/impl/GraphHolder.cpp @@ -49,7 +49,7 @@ namespace sd { auto graph = _graphF[graphId]->cloneWithProxy(); - return graph; + throw std::runtime_error("GraphHolder::cloneGraph - not implemented yet"); } Graph* GraphHolder::pullGraph(Nd4jLong graphId) { From b4639312639be9b15b50f005247d75c455615a05 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 13:57:51 +0300 Subject: [PATCH 063/233] one more signature for Graph Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 2 +- libnd4j/include/graph/Node.h | 2 +- libnd4j/include/graph/impl/Graph.cpp | 7 +++---- libnd4j/tests_cpu/layers_tests/GraphTests2.cpp | 11 ++++------- 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index d1e502746647..78eb02966062 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -112,7 +112,7 @@ namespace sd { * These methods add given node to the graph * @param node */ - //void addNode(Node &&node, const std::vector &inputs); + void addNode(Node &&node, const std::initializer_list &inputs); void addNode(Node &node, const std::initializer_list &inputs); void addNode(Node &node, const std::initializer_list &inputs); diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index d00722dde675..223acd59b51d 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -94,7 +94,7 @@ namespace sd { public: - explicit Node(const std::string &nodeName, const sd::ops::DeclarableOp &opName, const std::vector &tArgs = {}, const std::vector &iArgs = {}, const std::vector &bArgs = {}, const std::vector &dArgs = {}); + explicit Node(const std::string &nodeName, const sd::ops::DeclarableOp &op, const std::vector &tArgs = {}, const std::vector &iArgs = {}, const std::vector &bArgs = {}, const std::vector &dArgs = {}); explicit Node(const std::string &nodeName, const std::string &opName, const std::vector &tArgs = {}, const std::vector &iArgs = {}, const std::vector &bArgs = {}, const std::vector &dArgs = {}); explicit Node(const FlatNode *node); ~Node(); diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index d6cf00eaea91..2d6de75fd452 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -73,12 +73,11 @@ namespace sd { auto lvalue = array; addVariable(name, lvalue); } -/* - void Graph::addNode(Node &&node, const std::vector &inputs) { - auto lvalue = node; + + void Graph::addNode(Node &&node, const std::initializer_list &inputs) { + auto lvalue = std::move(node); addNode(lvalue, inputs); } - */ void Graph::addNode(Node &node, const std::initializer_list &inputs) { if (node.id() != 0) diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp index c910b4310b0a..489ac4c379e0 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp @@ -76,10 +76,9 @@ TEST_F(GraphTests2, test_execution_1) { // C graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - Node a("multiply_node", sd::ops::multiply()); Node b("add_node", sd::ops::add()); - graph.addNode(a, {"A", "B"}); + graph.addNode(Node("multiply_node", sd::ops::multiply()), {"A", "B"}); graph.addNode(b, {"multiply_node", "C"}); auto result = graph.execute({}, {"add_node"}); @@ -104,8 +103,7 @@ TEST_F(GraphTests2, test_placeholder_resolution_2) { graph.addPlaceholder("input", DataType::FLOAT32); - Node a("tanh_node", "tanh"); - graph.addNode(a, {"input"}); + graph.addNode(Node("tanh_node", sd::ops::tanh()), {"input"}); auto result = graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"tanh_node"}); @@ -117,8 +115,7 @@ TEST_F(GraphTests2, test_placeholder_resolution_3) { graph.addPlaceholder("input", DataType::FLOAT32); - Node a("tanh_node", "tanh"); - graph.addNode(a, {"input"}); + graph.addNode(Node("tanh_node", sd::ops::tanh()), {"input"}); ASSERT_THROW(graph.execute({{"input", NDArrayFactory::create(5)}}, {"tanh_node"}), sd::datatype_exception); } @@ -128,7 +125,7 @@ TEST_F(GraphTests2, test_placeholder_resolution_4) { graph.addPlaceholder("input", DataType::FLOAT32, {3, 4, 5}); - Node a("tanh_node", "tanh"); + Node a("tanh_node", sd::ops::tanh()); graph.addNode(a, {"input"}); ASSERT_THROW(graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"tanh_node"}), sd::shape_mismatch_exception); From bdbce0b0b694400be0e81586dd8636d423be0901 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 14:32:44 +0300 Subject: [PATCH 064/233] ExecutionTask instead of pair Signed-off-by: raver119 --- .../include/graph/execution/ExecutionTask.h | 59 +++++++++++++++++ .../include/graph/execution/GraphExecutor.h | 4 +- libnd4j/include/graph/execution/OpSequence.h | 18 ++--- .../graph/execution/impl/ExecutionTask.cpp | 65 +++++++++++++++++++ .../graph/execution/impl/GraphExecutor.cpp | 8 +-- .../graph/execution/impl/OpSequence.cpp | 18 +++-- .../layers_tests/GraphAnalysisTests.cpp | 10 +-- .../layers_tests/OpSequenceTests.cpp | 6 +- 8 files changed, 160 insertions(+), 28 deletions(-) create mode 100644 libnd4j/include/graph/execution/ExecutionTask.h create mode 100644 libnd4j/include/graph/execution/impl/ExecutionTask.cpp diff --git a/libnd4j/include/graph/execution/ExecutionTask.h b/libnd4j/include/graph/execution/ExecutionTask.h new file mode 100644 index 000000000000..a43c59904a3f --- /dev/null +++ b/libnd4j/include/graph/execution/ExecutionTask.h @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_EXECUTIONTASK_H +#define SD_EXECUTIONTASK_H + +#include +#include +#include + +namespace sd { + namespace graph { + class SD_EXPORT ExecutionTask { + protected: + std::shared_ptr _op; + const ContextPrototype &_context; + + public: + ExecutionTask(std::shared_ptr op, const ContextPrototype &ctx); + + ~ExecutionTask() = default; + + ExecutionTask(const ExecutionTask& other); + + ExecutionTask& operator=(const ExecutionTask& other) noexcept; + + // move constructor + ExecutionTask(ExecutionTask&& other); + + // move assignment operator + ExecutionTask& operator=(ExecutionTask&& other) noexcept; + + + std::shared_ptr op() const; + + const ContextPrototype &protoContext() const; + }; + } +} + + +#endif //SD_EXECUTIONTASK_H diff --git a/libnd4j/include/graph/execution/GraphExecutor.h b/libnd4j/include/graph/execution/GraphExecutor.h index 01b448d12a2b..3d7364dce9d3 100644 --- a/libnd4j/include/graph/execution/GraphExecutor.h +++ b/libnd4j/include/graph/execution/GraphExecutor.h @@ -31,7 +31,7 @@ namespace sd { class SD_EXPORT GraphExecutor { protected: - virtual Context prepareContext(ContextPrototype *contextPrototype, VariableSpace &variableSpace, const GraphMemoryManager &memoryManager) const; + virtual Context prepareContext(const ContextPrototype &contextPrototype, VariableSpace &variableSpace, const GraphMemoryManager &memoryManager) const; /* * preprocessor call involves: @@ -74,7 +74,7 @@ namespace sd { * @param contextPrototype * @return */ - virtual Nd4jStatus execute(sd::ops::DeclarableOp *op, ContextPrototype *contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId) const; + virtual Nd4jStatus execute(std::shared_ptr op, const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId) const; }; } } diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index 23d8dda366c3..5ae2c97b759b 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -23,6 +23,7 @@ #include #include +#include namespace sd { @@ -30,16 +31,16 @@ namespace sd { /** * This class represents independent and immutable sequence of operations */ - class SD_EXPORT OpSequence : public std::iterator> { + class SD_EXPORT OpSequence : public std::iterator { // our internal iterator for OpSequence class iterator; protected: // main thing here. sorted list of operations and their contexts - std::vector> _ops; + std::vector _ops; int _deviceId = 0; public: - explicit OpSequence(const std::vector> &ops, const int deviceId = 0); + explicit OpSequence(const std::vector &ops, const int deviceId = 0); OpSequence(const int deviceId = 0); ~OpSequence() = default; @@ -73,15 +74,16 @@ namespace sd { * @param index * @return */ - std::pair at(uint64_t index) const; - std::pair operator[](uint64_t index) const; + ExecutionTask at(uint64_t index) const; + ExecutionTask operator[](uint64_t index) const; /** * This method allows to add DeclarableOp to the end of execution queue * @param op - Op to be executed * @param ctx - ContextPrototype for this operation with inputs/outputs/args defined */ - void append(sd::ops::DeclarableOp *op, sd::graph::ContextPrototype *ctx); + void append(std::shared_ptr op, const sd::graph::ContextPrototype &ctx); + void append(sd::ops::DeclarableOp *op, const sd::graph::ContextPrototype &ctx); /** * Iterator functionality for OpSequence @@ -93,13 +95,13 @@ namespace sd { // additional private section private: - class iterator : public std::iterator> { + class iterator : public std::iterator { private: uint64_t _position = 0; OpSequence & _container; public: explicit iterator(OpSequence & container, uint64_t index = 0); - std::pair operator*() const; + ExecutionTask operator*() const; iterator & operator++(); iterator & operator++(int); bool operator!=(const iterator &) const; diff --git a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp new file mode 100644 index 000000000000..a51452696019 --- /dev/null +++ b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp @@ -0,0 +1,65 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { + namespace graph { + ExecutionTask::ExecutionTask(std::shared_ptr op, const ContextPrototype &ctx) : _op(op), _context(ctx) { + // + } + + std::shared_ptr ExecutionTask::op() const { + return _op; + } + + const ContextPrototype &ExecutionTask::protoContext() const { + return _context; + } + + ExecutionTask::ExecutionTask(const ExecutionTask &other) : _op(other._op), _context(other._context) { + // + } + + ExecutionTask &ExecutionTask::operator=(const ExecutionTask &other) noexcept { + if (this == &other) + return *this; + + _op = other._op; + const_cast(_context) = other._context; + + return *this; + } + + ExecutionTask::ExecutionTask(ExecutionTask &&other) : _op(other._op), _context(other._context) { + // + } + + ExecutionTask &ExecutionTask::operator=(ExecutionTask &&other) noexcept { + if (this == &other) + return *this; + + _op = std::move(other._op); + const_cast(_context) = std::move(other._context); + + return *this; + } + } +} diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index d3b2167da245..4356b049c727 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -23,9 +23,9 @@ namespace sd { namespace graph { - Context GraphExecutor::prepareContext(ContextPrototype *contextPrototype, VariableSpace &variableSpace, const GraphMemoryManager &memoryManager) const { + Context GraphExecutor::prepareContext(const ContextPrototype &contextPrototype, VariableSpace &variableSpace, const GraphMemoryManager &memoryManager) const { // TODO: maybe we'll want to do something here? - return Context(*contextPrototype, &variableSpace, const_cast(&memoryManager)); + return Context(contextPrototype, &variableSpace, const_cast(&memoryManager)); } Nd4jStatus GraphExecutor::preprocess(sd::ops::DeclarableOp *op, Context &context) const { @@ -44,7 +44,7 @@ namespace sd { } - Nd4jStatus GraphExecutor::execute(sd::ops::DeclarableOp *op, ContextPrototype *contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId) const { + Nd4jStatus GraphExecutor::execute(std::shared_ptr op, const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId) const { auto ctx = prepareContext(contextPrototype, *graph.originalGraph().getVariableSpace(), graph.memoryManager()); return op->execute(&ctx); } @@ -55,7 +55,7 @@ namespace sd { */ for (int e = 0; e < sequence.length(); e++) { auto v = sequence[e]; - auto result = execute(v.first, v.second, sequence, graph, deviceId >= 0 ? deviceId : sequence.deviceId()); + auto result = execute(v.op(), v.protoContext(), sequence, graph, deviceId >= 0 ? deviceId : sequence.deviceId()); if (result != Status::OK()) return result; } diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 7c2a0dda5d80..2e609863806e 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -19,6 +19,7 @@ // #include +#include namespace sd { namespace graph { @@ -26,7 +27,7 @@ namespace sd { // } - OpSequence::OpSequence(const std::vector> &ops, const int deviceId) { + OpSequence::OpSequence(const std::vector &ops, const int deviceId) { _deviceId = deviceId; for (const auto v : ops) @@ -67,11 +68,11 @@ namespace sd { return _deviceId; } - std::pair OpSequence::at(uint64_t index) const { + ExecutionTask OpSequence::at(uint64_t index) const { return _ops[index]; } - std::pair OpSequence::operator[](uint64_t index) const { + ExecutionTask OpSequence::operator[](uint64_t index) const { return at(index); } @@ -79,10 +80,15 @@ namespace sd { return _ops.size(); } - void OpSequence::append(sd::ops::DeclarableOp *op, sd::graph::ContextPrototype *ctx) { - _ops.emplace_back(std::pair{op, ctx}); + void OpSequence::append(std::shared_ptr op, const sd::graph::ContextPrototype &ctx) { + ExecutionTask task(op, ctx); + _ops.emplace_back(task); } + void OpSequence::append(sd::ops::DeclarableOp *op, const ContextPrototype &ctx) { + auto rop = sd::ops::OpRegistrator::getInstance()->getOperation(op->getOpHash()); + append(rop, ctx); + } OpSequence::iterator OpSequence::begin() { @@ -98,7 +104,7 @@ namespace sd { // } - std::pair OpSequence::iterator::operator*() const { + ExecutionTask OpSequence::iterator::operator*() const { return _container._ops[_position]; } diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 24a269c00bc7..0acfb5d76eea 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -70,8 +70,8 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_1) { // we expect that OpSequence has exactly 2 ops ASSERT_EQ(2, sequence.length()); - ASSERT_EQ(10, sequence.at(0).second->nodeId()); - ASSERT_EQ(20, sequence.at(1).second->nodeId()); + ASSERT_EQ(10, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(20, sequence.at(1).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_2) { @@ -116,7 +116,7 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { // we expect that OpSequence has exactly 2 ops ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(10, sequence.at(0).second->nodeId()); + ASSERT_EQ(10, sequence.at(0).protoContext().nodeId()); // checking second layer now auto layer1 = optimized.layer(0); @@ -127,10 +127,10 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { sequence = layer1[0]; ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(20, sequence.at(0).second->nodeId()); + ASSERT_EQ(20, sequence.at(0).protoContext().nodeId()); sequence = layer1[1]; ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(30, sequence.at(0).second->nodeId()); + ASSERT_EQ(30, sequence.at(0).protoContext().nodeId()); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp index 76fce192d135..658c36e183b5 100644 --- a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp @@ -49,14 +49,14 @@ TEST_F(OpSequenceTests, test_iterator_1) { Context ctx1(1); Context ctx2(2); - sequence.append(&op1, &ctx1); - sequence.append(&op2, &ctx2); + sequence.append(&op1, ctx1); + sequence.append(&op2, ctx2); ASSERT_EQ(2, sequence.length()); int cnt = 1; for (const auto &v:sequence) { - ASSERT_EQ(cnt++, v.second->nodeId()); + ASSERT_EQ(cnt++, v.protoContext().nodeId()); } ASSERT_EQ(3, cnt); From 1ecc0c4ae3faa5f3474ea167d83a7f91b691cb49 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 15:04:47 +0300 Subject: [PATCH 065/233] nano fix Signed-off-by: raver119 --- libnd4j/include/graph/execution/impl/GraphExecutor.cpp | 4 ++-- libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 4356b049c727..1b339dfa2e14 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -73,13 +73,13 @@ namespace sd { for (uint64_t l = 0; l < graph.layers(); l++) { auto layer = graph.layer(l); - for (uint64_t o = 0; layer.width(); o++) { + for (uint64_t o = 0; o < layer.width(); o++) { execute(layer[o], graph); } // optionally block until all sequences in this layer processed if (layer.width() > 0 && numDevices > 1) - for (uint64_t o = 0; layer.width(); o++) { + for (uint64_t o = 0; o < layer.width(); o++) { result = layer[o].wait(); if (result != Status::OK()) return result; diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index 909073545d8e..8a0b06c99b79 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -40,5 +40,13 @@ class GraphExecutorTests : public testing::Test { }; TEST_F(GraphExecutorTests, test_basic_exec_1) { + GraphMemoryManager memoryManager; + + OptimizedGraph optimizedGraph(memoryManager); + OpSequence sequence; + + optimizedGraph.append(sequence); + GraphExecutor executor; + executor.execute(optimizedGraph); } \ No newline at end of file From 6f68769ae87b2bddb8f922a85ef806600d1902e7 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 15:42:34 +0300 Subject: [PATCH 066/233] few more methods/fields removed Signed-off-by: raver119 --- libnd4j/include/graph/Context.h | 4 - libnd4j/include/graph/ContextPrototype.h | 18 ++-- libnd4j/include/graph/Node.h | 15 --- libnd4j/include/graph/impl/Context.cpp | 14 --- .../include/graph/impl/ContextPrototype.cpp | 95 ++++++++++++++++--- libnd4j/include/graph/impl/Node.cpp | 31 ------ .../ops/declarable/generic/linalg/tri.cpp | 2 +- .../generic/nn/convo/dilation2d.cpp | 2 +- .../generic/parity_ops/confusion_matrix.cpp | 3 +- .../declarable/generic/random/bernoulli.cpp | 2 +- .../declarable/generic/random/exponential.cpp | 2 +- .../ops/declarable/generic/random/normal.cpp | 2 +- .../declarable/generic/random/set_seed.cpp | 2 +- .../declarable/generic/tests/testcustom.cpp | 2 +- .../declarable/generic/transforms/concat.cpp | 2 +- .../declarable/generic/transforms/tear.cpp | 2 +- .../ops/declarable/helpers/cpu/dropout.cpp | 6 +- .../ops/declarable/impl/BroadcastableOp.cpp | 4 +- .../ops/declarable/impl/DeclarableListOp.cpp | 2 +- .../ops/declarable/impl/DeclarableOp.cpp | 6 +- .../declarable/impl/DeclarableReductionOp.cpp | 2 +- .../ops/declarable/impl/LegacyRandomOp.cpp | 2 +- 22 files changed, 112 insertions(+), 108 deletions(-) diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index bbf1a1c90fcb..0d03caafb231 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -90,10 +90,6 @@ namespace sd { Nd4jLong getOuterTime(); Nd4jLong getInnerTime(); - sd::DataType dataType() const override; - - sd::DataType dataType(int index) const override; - void setDataType(int index, sd::DataType type) override; // these methods are related to Workspace abstraction bool hasWorkspaceProvided(); void attachWorkspace(sd::memory::Workspace* workspace); diff --git a/libnd4j/include/graph/ContextPrototype.h b/libnd4j/include/graph/ContextPrototype.h index 028588507287..2e1295f3796b 100644 --- a/libnd4j/include/graph/ContextPrototype.h +++ b/libnd4j/include/graph/ContextPrototype.h @@ -49,8 +49,6 @@ namespace sd { std::vector _axis; std::vector _dArgs; - // TODO: remove this field - sd::DataType _dataType = sd::DataType::FLOAT32; bool _isInplace; // opNum for legacy XYZ ops @@ -58,8 +56,6 @@ namespace sd { uint64_t _rootSeed; RandomGenerator _randomGenerator; - std::vector _dataTypes; - sd::ops::OpDescriptor* _opDescriptor; bool _useMKLDNN = sd::Environment::getInstance()->isUseMKLDNN(); @@ -71,6 +67,16 @@ namespace sd { explicit ContextPrototype(sd::ops::OpDescriptor* opDescriptor = nullptr, int nodeId = 1, bool inPlace = false); ~ContextPrototype() = default; + ContextPrototype(const ContextPrototype& other) noexcept; + + ContextPrototype& operator=(const ContextPrototype& other) noexcept; + + // move constructor + ContextPrototype(ContextPrototype&& other) noexcept; + + // move assignment operator + ContextPrototype& operator=(ContextPrototype&& other) noexcept; + int getNodeId() const; int nodeId() const; @@ -79,10 +85,6 @@ namespace sd { void setOpDescriptor(sd::ops::OpDescriptor* opDescriptor); - virtual sd::DataType dataType() const; - virtual sd::DataType dataType(int index) const ; - virtual void setDataType(int index, sd::DataType type); - bool isInplace() const; void markInplace(bool reallyInplace); diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 223acd59b51d..cda7bad4d79a 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -53,10 +53,6 @@ namespace sd { std::string _name; - - // this variable points to onion layer within graph - int _layer = -1; - // many ops require extra parameters to run double *_extraParams = nullptr; @@ -68,9 +64,6 @@ namespace sd { // this field is used to check, if op should be used in-place (so it can/will modify its inputs) bool _isInplace = false; - // this field is used to delete attached customOp - bool _isDeductable = false; - OpClass _opClass; // these fields are used to store embedded CustomOps and Graph in case of Graph-in-Graph scenario @@ -148,9 +141,6 @@ namespace sd { bool isRemovable() const; void markRemovable(bool reallyRemovable) const; - int getLayer(); - void setLayer(int layer); - bool isDivergencePoint(); void setActive(bool reallyActive); bool isActive(); @@ -169,9 +159,6 @@ namespace sd { void pickInput(std::pair& id); void pickInput(const std::string &id); - bool isDeductable(); - void setDeductable(bool reallyDeductable); - void setName(std::string *name); void setName(const std::string& name); const std::string& getName() const; @@ -227,8 +214,6 @@ namespace sd { this->markInplace(other->isInplace()); this->setActive(other->isActive()); this->setScopeInfo(other->scopeId(), other->scopeName()->c_str()); - this->setLayer(other->getLayer()); - this->setDeductable(other->isDeductable()); for (auto &v: other->input()) this->_input.emplace_back(v); diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index cb582e2ea7fa..0bdac22ce6d4 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -29,7 +29,6 @@ namespace sd { Context::Context(const ContextPrototype& prototype, VariableSpace* variableSpace, GraphMemoryManager *memoryManager) { _memoryManager = memoryManager; _variableSpace = variableSpace; - _dataType = prototype.dataType(); for (const auto &v: prototype.inputs()) { this->_inputs.push_back(v); @@ -59,19 +58,6 @@ namespace sd { if (variableSpace != nullptr && variableSpace->launchContext()->getWorkspace() != nullptr) this->_workspace = variableSpace->launchContext()->getWorkspace(); } - sd::DataType Context::dataType(int index) const { - return _dataType; - } - - sd::DataType Context::dataType() const { - return dataType(0); - } - - void Context::setDataType(int index, sd::DataType type) { - if (this->_dataTypes.size() > (size_t)index) - _dataTypes[index] = type; - _dataType = type; - } Context::Context(int nodeId, VariableSpace *variableSpace) { this->_nodeId = nodeId; diff --git a/libnd4j/include/graph/impl/ContextPrototype.cpp b/libnd4j/include/graph/impl/ContextPrototype.cpp index 0288e2765820..d839348b5b18 100644 --- a/libnd4j/include/graph/impl/ContextPrototype.cpp +++ b/libnd4j/include/graph/impl/ContextPrototype.cpp @@ -106,19 +106,6 @@ namespace sd { return getNodeId(); } - sd::DataType ContextPrototype::dataType() const { - return dataType(0); - } - - sd::DataType ContextPrototype::dataType(int index) const { - return _dataType; - } - - void ContextPrototype::setDataType(int index, sd::DataType type) { - // if (_outputs->size() == 0) - _dataType = type; - } - size_t ContextPrototype::numT() const { return (int) _tArgs.size(); } @@ -221,5 +208,87 @@ namespace sd { void ContextPrototype::appendD(DataType value) { _dArgs.emplace_back(value); } + + ContextPrototype::ContextPrototype(const ContextPrototype &other) noexcept { + _inputs = other._inputs; + _tArgs = other._tArgs; + _iArgs = other._iArgs; + _bArgs = other._bArgs; + _dArgs = other._dArgs; + + _nodeId = other._nodeId; + _isInplace = other._isInplace; + _opNum = other._opNum; + _rootSeed = other._rootSeed; + _randomGenerator = other._randomGenerator; + _opDescriptor = other._opDescriptor; + _useMKLDNN = other._useMKLDNN; + _engine = other._engine; + _execMode = other._execMode; + } + + ContextPrototype &ContextPrototype::operator=(const ContextPrototype &other) noexcept { + if (this == &other) + return *this; + + _inputs = other._inputs; + _tArgs = other._tArgs; + _iArgs = other._iArgs; + _bArgs = other._bArgs; + _dArgs = other._dArgs; + + _nodeId = other._nodeId; + _isInplace = other._isInplace; + _opNum = other._opNum; + _rootSeed = other._rootSeed; + _randomGenerator = other._randomGenerator; + _opDescriptor = other._opDescriptor; + _useMKLDNN = other._useMKLDNN; + _engine = other._engine; + _execMode = other._execMode; + + return *this; + } + + ContextPrototype::ContextPrototype(ContextPrototype &&other) noexcept { + _inputs = std::move(other._inputs); + _tArgs = std::move(other._tArgs); + _iArgs = std::move(other._iArgs); + _bArgs = std::move(other._bArgs); + _dArgs = std::move(other._dArgs); + + _nodeId = other._nodeId; + _isInplace = other._isInplace; + _opNum = other._opNum; + _rootSeed = other._rootSeed; + _randomGenerator = other._randomGenerator; + _opDescriptor = other._opDescriptor; + _useMKLDNN = other._useMKLDNN; + _engine = other._engine; + _execMode = other._execMode; + } + + ContextPrototype &ContextPrototype::operator=(ContextPrototype &&other) noexcept { + if (this == &other) + return *this; + + _inputs = std::move(other._inputs); + _tArgs = std::move(other._tArgs); + _iArgs = std::move(other._iArgs); + _bArgs = std::move(other._bArgs); + _dArgs = std::move(other._dArgs); + + _nodeId = other._nodeId; + _isInplace = other._isInplace; + _opNum = other._opNum; + _rootSeed = other._rootSeed; + _randomGenerator = other._randomGenerator; + _opDescriptor = other._opDescriptor; + _useMKLDNN = other._useMKLDNN; + _engine = other._engine; + _execMode = other._execMode; + + return *this; + } } } \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 455459a34bb2..ce2839136499 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -262,14 +262,6 @@ namespace sd { _hasInternalOutputs = true; } - int Node::getLayer() { - return _layer; - } - - void Node::setLayer(int layer) { - _layer = layer; - } - bool Node::hasExternalOutputs() { return _hasExternalOutputs; } @@ -494,8 +486,6 @@ namespace sd { opType == OpType_SCALAR_BOOL || opType == OpType_SCALAR) { - this->_isDeductable = true; - ContextPrototype block(nullptr, this->id(), false); for (auto v: dimensions) @@ -630,8 +620,6 @@ namespace sd { } if (node->input() != nullptr && node->input()->size() > 0) { - this->_isDeductable = true; - ContextPrototype block(nullptr, this->id(), false); @@ -663,8 +651,6 @@ namespace sd { this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block.getIArguments().size(), (int) block.getTArguments().size(), (int) _opNum)); block.setOpDescriptor(this->customOp()->getOpDescriptor()); } else if (node->inputPaired() != nullptr && node->inputPaired()->size() > 0) { - this->_isDeductable = true; - ContextPrototype block(nullptr, this->id(), false); for (int e = 0; e < this->input().size(); e++) { @@ -795,14 +781,12 @@ namespace sd { _scope_id = other._scope_id; _scope_name = other._scope_name; _rewindNode = other._rewindNode; - _layer = other._layer; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; _hasInternalOutputs = other._hasInternalOutputs; _hasInternalInputs = other._hasInternalInputs; _isInplace = other._isInplace; - _isDeductable = other._isDeductable; _active = other._active; _removable = other._removable; @@ -831,14 +815,12 @@ namespace sd { _scope_id = other._scope_id; _scope_name = other._scope_name; _rewindNode = other._rewindNode; - _layer = other._layer; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; _hasInternalOutputs = other._hasInternalOutputs; _hasInternalInputs = other._hasInternalInputs; _isInplace = other._isInplace; - _isDeductable = other._isDeductable; _active = other._active; _removable = other._removable; @@ -866,14 +848,12 @@ namespace sd { _name = std::move(other._name); _scope_name = std::move(other._scope_name); _rewindNode = other._rewindNode; - _layer = other._layer; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; _hasInternalOutputs = other._hasInternalOutputs; _hasInternalInputs = other._hasInternalInputs; _isInplace = other._isInplace; - _isDeductable = other._isDeductable; _active = other._active; _removable = other._removable; @@ -904,14 +884,12 @@ namespace sd { _name = std::move(other._name); _scope_name = std::move(other._scope_name); _rewindNode = other._rewindNode; - _layer = other._layer; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; _hasInternalOutputs = other._hasInternalOutputs; _hasInternalInputs = other._hasInternalInputs; _isInplace = other._isInplace; - _isDeductable = other._isDeductable; _active = other._active; _removable = other._removable; @@ -1036,15 +1014,6 @@ namespace sd { } } - bool Node::isDeductable() { - return _isDeductable; - } - - void Node::setDeductable(bool reallyDeductable) { - _isDeductable = reallyDeductable; - } - - Node* Node::clone() { if (this->_customOp && this->_opType == OpType_CUSTOM) { auto clone = new Node(_customOp.get(), _id); diff --git a/libnd4j/include/ops/declarable/generic/linalg/tri.cpp b/libnd4j/include/ops/declarable/generic/linalg/tri.cpp index 19144e2fbe44..82b5ba5c3109 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/tri.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/tri.cpp @@ -51,7 +51,7 @@ DECLARE_SHAPE_FN(tri) { const int rows = INT_ARG(0); const int cols = block.numI() > 1 ? INT_ARG(1) : rows; - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', {rows, cols})); + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(DataType::FLOAT32, 'c', {rows, cols})); } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/dilation2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/dilation2d.cpp index ea11934009e0..6a03e6c42844 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/dilation2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/dilation2d.cpp @@ -109,7 +109,7 @@ namespace ops { rates = r->template asVectorT(); } else { if (block.numI() < 9) { - newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(block.dataType()); + newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(input)); return SHAPELIST(newShape); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp index e79b8cc2d991..eac7349280f0 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp @@ -64,8 +64,7 @@ namespace sd { DECLARE_SHAPE_FN(confusion_matrix) { auto labels = INPUT_VARIABLE(0); auto predictions = INPUT_VARIABLE(1); - auto dtype = block.dataType(); - dtype = sd::DataType::INT64; // dtype - should be a param with int argument + auto dtype = sd::DataType::INT64; // dtype - should be a param with int argument if (block.numI() > 1) dtype = (sd::DataType)INT_ARG(1); diff --git a/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp b/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp index 1441448c9471..fafe8c566321 100644 --- a/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp +++ b/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp @@ -53,7 +53,7 @@ namespace sd { auto in = INPUT_VARIABLE(0); auto shape = in->template asVectorT(); - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', shape); + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataType::FLOAT32, 'c', shape); return SHAPELIST(newShape); } diff --git a/libnd4j/include/ops/declarable/generic/random/exponential.cpp b/libnd4j/include/ops/declarable/generic/random/exponential.cpp index 8605ffafea13..fae51a9675d2 100644 --- a/libnd4j/include/ops/declarable/generic/random/exponential.cpp +++ b/libnd4j/include/ops/declarable/generic/random/exponential.cpp @@ -63,7 +63,7 @@ namespace sd { auto in = INPUT_VARIABLE(0); auto shape = in->template asVectorT(); - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', shape); + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataType::FLOAT32, 'c', shape); return SHAPELIST(newShape); } diff --git a/libnd4j/include/ops/declarable/generic/random/normal.cpp b/libnd4j/include/ops/declarable/generic/random/normal.cpp index 8bfbd8db6cae..8eb617fbd74d 100644 --- a/libnd4j/include/ops/declarable/generic/random/normal.cpp +++ b/libnd4j/include/ops/declarable/generic/random/normal.cpp @@ -48,7 +48,7 @@ namespace sd { auto in = INPUT_VARIABLE(0); auto shape = in->template asVectorT(); - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', shape); + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataType::FLOAT32, 'c', shape); return SHAPELIST(newShape); } diff --git a/libnd4j/include/ops/declarable/generic/random/set_seed.cpp b/libnd4j/include/ops/declarable/generic/random/set_seed.cpp index 5102321ad8f4..dff6a4e6484d 100644 --- a/libnd4j/include/ops/declarable/generic/random/set_seed.cpp +++ b/libnd4j/include/ops/declarable/generic/random/set_seed.cpp @@ -48,7 +48,7 @@ namespace sd { } DECLARE_SHAPE_FN(set_seed) { - auto newshape = ConstantShapeHelper::getInstance()->scalarShapeInfo(block.dataType()); + auto newshape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::FLOAT32); return SHAPELIST(newshape); } diff --git a/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp b/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp index 89480e5bca8f..283efde0cbe4 100644 --- a/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp @@ -39,7 +39,7 @@ namespace sd { for (int e = 0; e < shape::rank(inputShape->at(0)); e++) shapeOf[e] = inputShape->at(0)[e+1] * 2; - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', shape::rank(inputShape->at(0)), shapeOf); + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataType::FLOAT32, 'c', shape::rank(inputShape->at(0)), shapeOf); RELEASE(shapeOf, block.getWorkspace()); return SHAPELIST(newShape); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp index 5e587e40d453..f78558a93db8 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp @@ -43,7 +43,7 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) { int index = 0; bool allOfSameType = true; auto rankOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0; - auto typeOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : block.dataType(); + auto typeOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : DataType::FLOAT32; for(int i = 0; i < numOfInArrs; ++i) { auto input = INPUT_VARIABLE(i); diff --git a/libnd4j/include/ops/declarable/generic/transforms/tear.cpp b/libnd4j/include/ops/declarable/generic/transforms/tear.cpp index 2c850bc93a86..c1def43bf3bc 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/tear.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/tear.cpp @@ -62,7 +62,7 @@ namespace sd { auto result = SHAPELIST(); for (Nd4jLong e = 0; e < numTads; e++) { - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), shape::order(inShape), shape::rank(tadPack.primaryShapeInfo()), shape::shapeOf(tadPack.primaryShapeInfo())); + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), shape::rank(tadPack.primaryShapeInfo()), shape::shapeOf(tadPack.primaryShapeInfo())); result->push_back(newShape); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp b/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp index 54981dea5c5e..1d5de5cc7d46 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp @@ -155,17 +155,17 @@ namespace helpers { } int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - BUILD_SINGLE_SELECTOR(context.dataType(), return dropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(gradOut->dataType(), return dropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue), FLOAT_TYPES); } BUILD_SINGLE_TEMPLATE(template int dropOutFunctorBP_, (graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue), FLOAT_TYPES); int alphaDropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctor_, (context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(output->dataType(), return alphaDropOutFunctor_, (context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); } BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctor_, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta), FLOAT_TYPES); int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(gradOut->dataType(), return alphaDropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); } BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctorBP_, (graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta), FLOAT_TYPES); diff --git a/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp b/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp index eb691b84dff2..fa8640db156f 100644 --- a/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp @@ -34,8 +34,8 @@ namespace sd { auto x = inputShape->at(0); auto y = inputShape->at(1); auto outputs = _descriptor->getOutputTypesForOutput(0); - sd::DataType dtype = block.dataType(0); - if (block.dataType(0) != sd::DataType::BOOL && !(outputs.size() == 1 && outputs[0] == sd::DataType::BOOL)) { + sd::DataType dtype; + if (!(outputs.size() == 1 && outputs[0] == sd::DataType::BOOL)) { if (Environment::getInstance()->isExperimentalBuild()) { if (shape::length(y) > shape::length(x)) { dtype = DataTypeUtils::pickPairwiseResultType(y, x); diff --git a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp index e06c39c2ed88..a83b970fbede 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp @@ -47,7 +47,7 @@ namespace sd { ShapeList* DeclarableListOp::calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) { // TODO: ensure this method isn't ever called - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', {1, 1}); + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataType::FLOAT32, 'c', {1, 1}); return SHAPELIST(newShape); } diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index aca97b798652..3a061661e5fe 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -420,11 +420,11 @@ namespace sd { Nd4jLong len = shape::length(shape); // if that's first run - we probably have nothing here if (var->getNDArray() == nullptr) { - var->setNDArray(new NDArray(order, shape, block.dataType(), block.launchContext())); + var->setNDArray(new NDArray(order, shape, DataType::FLOAT32, block.launchContext())); } else if(var->getNDArray()->lengthOf() != len) { // if length not match - lets reallocate array delete var->getNDArray(); - var->setNDArray(new NDArray(order, shape, block.dataType(), block.launchContext())); + var->setNDArray(new NDArray(order, shape, DataType::FLOAT32, block.launchContext())); } return true; @@ -882,7 +882,6 @@ namespace sd { Context block(1, &variableSpace, false); block.fillInputs(in); block.markInplace(isInplace); - block.setDataType(0, type); // we need this line for tests basically //if (rng != nullptr) @@ -1032,7 +1031,6 @@ namespace sd { } Context block(1, &variableSpace, false); - block.setDataType(0, sd::DataType::FLOAT32); block.fillInputs(in); block.markInplace(isInplace); // block.setRNG(ProviderRNG::getInstance().getRNG()); diff --git a/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp index c58766f65ac8..e483b0fb1f92 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp @@ -52,7 +52,7 @@ namespace sd { // special case - output is scalar if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { - auto newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(block.dataType()); + auto newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::FLOAT32); return SHAPELIST(newShape); } diff --git a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp index 8be8d5d2a263..f42edbcc727e 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp @@ -342,7 +342,7 @@ namespace sd { } else if (DataTypeUtils::isZ(xType)) { auto zShapeArr = INPUT_VARIABLE(0); auto zShapeVector = zShapeArr->asVectorT(); - auto dtype = block.dataType(); + auto dtype = DataType::BFLOAT16; newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', zShapeVector); return SHAPELIST(newShape); From 49e66c4bfdb86aa9e643408c4c649dfd58c780e3 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 16:08:43 +0300 Subject: [PATCH 067/233] ContextPrototype nodeId Signed-off-by: raver119 --- libnd4j/include/graph/ContextPrototype.h | 1 + libnd4j/include/graph/impl/ContextPrototype.cpp | 4 ++++ libnd4j/include/graph/impl/Node.cpp | 1 + libnd4j/tests_cpu/layers_tests/ContextTests.cpp | 8 ++++++++ 4 files changed, 14 insertions(+) diff --git a/libnd4j/include/graph/ContextPrototype.h b/libnd4j/include/graph/ContextPrototype.h index 2e1295f3796b..cea6a3422be3 100644 --- a/libnd4j/include/graph/ContextPrototype.h +++ b/libnd4j/include/graph/ContextPrototype.h @@ -79,6 +79,7 @@ namespace sd { int getNodeId() const; int nodeId() const; + void setNodeId(int id); // this method returns true, if inputs are defined bool hasVariablesFilled() const; diff --git a/libnd4j/include/graph/impl/ContextPrototype.cpp b/libnd4j/include/graph/impl/ContextPrototype.cpp index d839348b5b18..64827a869c1e 100644 --- a/libnd4j/include/graph/impl/ContextPrototype.cpp +++ b/libnd4j/include/graph/impl/ContextPrototype.cpp @@ -290,5 +290,9 @@ namespace sd { return *this; } + + void ContextPrototype::setNodeId(int id) { + _nodeId = id; + } } } \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index ce2839136499..1eb4d46bee22 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -177,6 +177,7 @@ namespace sd { void Node::setId(int id) { _id = id; + _protoContext.setNodeId(id); } std::shared_ptr Node::customOp() const { diff --git a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp index a51a65bcc8c8..571efc1afb04 100644 --- a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp @@ -353,4 +353,12 @@ TEST_F(ContextTests, test_short_context_3) { auto z = ctx.fastpath_out()[0]; ASSERT_EQ(exp, *z); +} + +TEST_F(ContextTests, test_copy_1) { + ContextPrototype prototype(nullptr, 12); + + auto copy = prototype; + + ASSERT_EQ(prototype.nodeId(), copy.nodeId()); } \ No newline at end of file From b020d0abd30a75e13db864d192ddff703cbca8f9 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 18:27:40 +0300 Subject: [PATCH 068/233] few more changes around GraphExecutor Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 4 +- libnd4j/include/graph/OptimizedGraph.h | 2 +- .../graph/execution/impl/GraphExecutor.cpp | 2 +- libnd4j/include/graph/impl/Graph.cpp | 10 +- libnd4j/include/graph/impl/OptimizedGraph.cpp | 5 +- .../graph/logic/impl/LogicConditional.cpp | 2 +- .../include/graph/logic/impl/LogicEnter.cpp | 2 +- .../include/graph/logic/impl/LogicExit.cpp | 2 +- .../graph/logic/impl/LogicLoopCond.cpp | 2 +- .../include/graph/logic/impl/LogicMerge.cpp | 2 +- .../graph/logic/impl/LogicNextIteration.cpp | 2 +- .../include/graph/logic/impl/LogicReturn.cpp | 2 +- .../include/graph/logic/impl/LogicSwitch.cpp | 2 +- .../include/graph/logic/impl/LogicWhile.cpp | 2 +- .../profiling/impl/GraphProfilingHelper.cpp | 2 +- libnd4j/include/legacy/cpu/NativeOps.cpp | 2 +- libnd4j/include/legacy/cuda/NativeOps.cu | 2 +- .../layers_tests/FlatBuffersTests.cpp | 62 +++++----- .../layers_tests/GraphExecutorTests.cpp | 32 ++++- libnd4j/tests_cpu/layers_tests/GraphTests.cpp | 116 +++++++++--------- .../tests_cpu/layers_tests/GraphTests2.cpp | 4 +- .../tests_cpu/layers_tests/OneOffTests.cpp | 58 ++++----- .../layers_tests/OpSequenceTests.cpp | 4 +- .../layers_tests/PlaygroundTests.cpp | 24 ++-- .../tests_cpu/layers_tests/ProtoBufTests.cpp | 12 +- .../tests_cpu/layers_tests/SanityTests.cpp | 6 +- .../tests_cpu/layers_tests/WorkspaceTests.cpp | 2 +- 27 files changed, 202 insertions(+), 165 deletions(-) diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 78eb02966062..725dfe261935 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -106,7 +106,9 @@ namespace sd { * This method returns pointer to thread_local VariableSpace * @return */ - VariableSpace *getVariableSpace() const; + VariableSpace *variableSpace() const; + + const GraphMemoryManager& memoryManager() const; /** * These methods add given node to the graph diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index c2893a8c199c..6048bff464ff 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -46,7 +46,7 @@ namespace sd { std::mutex _mutex; public: - OptimizedGraph(GraphMemoryManager &memoryManager); + OptimizedGraph(Graph *original); ~OptimizedGraph() = default; OptimizedGraph(const OptimizedGraph& other) noexcept; diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 1b339dfa2e14..d488702a0ea3 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -45,7 +45,7 @@ namespace sd { Nd4jStatus GraphExecutor::execute(std::shared_ptr op, const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId) const { - auto ctx = prepareContext(contextPrototype, *graph.originalGraph().getVariableSpace(), graph.memoryManager()); + auto ctx = prepareContext(contextPrototype, *graph.originalGraph().variableSpace(), graph.memoryManager()); return op->execute(&ctx); } diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 2d6de75fd452..3c3e959b50da 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -48,7 +48,7 @@ namespace sd { return _configuration; } - VariableSpace * Graph::getVariableSpace() const { + VariableSpace * Graph::variableSpace() const { return _variableSpace; } @@ -379,7 +379,7 @@ namespace sd { return nullptr; auto graph = new Graph(); - auto variableSpace = graph->getVariableSpace(); + auto variableSpace = graph->variableSpace(); std::map variablesMap; @@ -612,7 +612,7 @@ namespace sd { OptimizedGraph Graph::optimizedGraph() const { - OptimizedGraph optGraf(const_cast(_memoryMaager)); + OptimizedGraph optGraf(const_cast(this)); /* OpSequence opSeq; std::set nodesMap, startNodes; @@ -698,6 +698,10 @@ namespace sd { return *this; } + + const GraphMemoryManager &Graph::memoryManager() const { + return _memoryMaager; + } } } diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index aa49db24f641..28db8145c286 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -23,8 +23,9 @@ namespace sd { namespace graph { - OptimizedGraph::OptimizedGraph(GraphMemoryManager &memoryManager) { - _memoryManager = &memoryManager; + OptimizedGraph::OptimizedGraph(Graph *original) { + _originalGraph = original; + _memoryManager = const_cast(&original->memoryManager()); } OptimizedGraph::OptimizedGraph(const OptimizedGraph &other) noexcept { diff --git a/libnd4j/include/graph/logic/impl/LogicConditional.cpp b/libnd4j/include/graph/logic/impl/LogicConditional.cpp index c4a3747c0e19..a1ba5a9d0007 100644 --- a/libnd4j/include/graph/logic/impl/LogicConditional.cpp +++ b/libnd4j/include/graph/logic/impl/LogicConditional.cpp @@ -28,7 +28,7 @@ namespace sd { Nd4jStatus LogicConditional::processNode(Graph *graph, Node *node) { throw std::runtime_error("LogicConditional::processNode - not implemented yet"); /* - auto __variableSpace = graph->getVariableSpace(); + auto __variableSpace = graph->variableSpace(); auto size = node->input()->size(); diff --git a/libnd4j/include/graph/logic/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp index b8ca2999e71f..95788197c47b 100644 --- a/libnd4j/include/graph/logic/impl/LogicEnter.cpp +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -30,7 +30,7 @@ namespace sd { // this op replicates input variable into the frame. basically happens once for single loop. // sure, if there's inner loop within outer loop, it'll be called once for outer loop and multiple times for inner loop - auto __variableSpace = graph->getVariableSpace(); + auto __variableSpace = graph->variableSpace(); auto __flowPath = __variableSpace->flowPath(); // basically, first non-null variable is our target diff --git a/libnd4j/include/graph/logic/impl/LogicExit.cpp b/libnd4j/include/graph/logic/impl/LogicExit.cpp index f351795e05f7..92aca938306c 100644 --- a/libnd4j/include/graph/logic/impl/LogicExit.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExit.cpp @@ -27,7 +27,7 @@ namespace sd { // this op is basically no-op // we just know it exists - auto __variableSpace = graph->getVariableSpace(); + auto __variableSpace = graph->variableSpace(); auto __flowPath = __variableSpace->flowPath(); Context ctx(node->protoContext(), __variableSpace); diff --git a/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp index 8f04e0c3966e..35273ead76ed 100644 --- a/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp +++ b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp @@ -24,7 +24,7 @@ namespace sd { namespace graph { Nd4jStatus LogicLoopCond::processNode(Graph *graph, Node *node) { - auto __variableSpace = graph->getVariableSpace(); + auto __variableSpace = graph->variableSpace(); auto __flowPath = __variableSpace->flowPath(); Context ctx(node->protoContext(), __variableSpace); diff --git a/libnd4j/include/graph/logic/impl/LogicMerge.cpp b/libnd4j/include/graph/logic/impl/LogicMerge.cpp index e00fca1f0d21..a0068c968ce4 100644 --- a/libnd4j/include/graph/logic/impl/LogicMerge.cpp +++ b/libnd4j/include/graph/logic/impl/LogicMerge.cpp @@ -27,7 +27,7 @@ namespace sd { throw std::runtime_error("LogicMerge::processNode - not implemented yet"); /* // at merge node only one of inputs exist if that's just switch and other node isn't LogicNextItration - auto __variableSpace = graph->getVariableSpace(); + auto __variableSpace = graph->variableSpace(); auto __flowPath = __variableSpace->flowPath(); // merge MUST have 2 inputs diff --git a/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp index 9d26f0c950e4..4a24d2fe0288 100644 --- a/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp +++ b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp @@ -26,7 +26,7 @@ namespace sd { Nd4jStatus LogicNextIeration::processNode(Graph *graph, Node *node) { throw std::runtime_error("LogicNextIeration::processNode - not implemented yet"); /* - auto __variableSpace = graph->getVariableSpace(); + auto __variableSpace = graph->variableSpace(); auto __flowPath = __variableSpace->flowPath(); auto inputAddr = node->input()->at(0); diff --git a/libnd4j/include/graph/logic/impl/LogicReturn.cpp b/libnd4j/include/graph/logic/impl/LogicReturn.cpp index 7a1a4f997091..e7112a40d109 100644 --- a/libnd4j/include/graph/logic/impl/LogicReturn.cpp +++ b/libnd4j/include/graph/logic/impl/LogicReturn.cpp @@ -27,7 +27,7 @@ namespace sd { Nd4jStatus LogicReturn::processNode(Graph *graph, Node *node) { throw std::runtime_error("LogicReturn::processNode - not implemented yet"); /* - auto __variableSpace = graph->getVariableSpace(); + auto __variableSpace = graph->variableSpace(); for (int e = 0; e < node->input()->size(); e++) { auto inputAddr = node->input()->at(e); diff --git a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp index cb522b4d0d23..e58b42534fa5 100644 --- a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp +++ b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp @@ -27,7 +27,7 @@ namespace sd { Nd4jStatus LogicSwitch::processNode(Graph* graph, Node* node) { throw std::runtime_error("LogicSwitch::processNode - not implemented yet"); /* - auto __variableSpace = graph->getVariableSpace(); + auto __variableSpace = graph->variableSpace(); auto __flowPath = __variableSpace->flowPath(); Context ctx(node->getContextPrototype(), __variableSpace); diff --git a/libnd4j/include/graph/logic/impl/LogicWhile.cpp b/libnd4j/include/graph/logic/impl/LogicWhile.cpp index cf15cc88877b..073c8c09dac7 100644 --- a/libnd4j/include/graph/logic/impl/LogicWhile.cpp +++ b/libnd4j/include/graph/logic/impl/LogicWhile.cpp @@ -29,7 +29,7 @@ namespace sd { Nd4jStatus LogicWhile::processNode(Graph *graph, Node *node) { throw std::runtime_error("LogicWhile::processNode - not implemented yet"); /* - auto __variableSpace = graph->getVariableSpace(); + auto __variableSpace = graph->variableSpace(); nd4j_debug("Starting on WHILE loop: [%i]\n", node->id()); diff --git a/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp b/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp index 8f93306a72bf..106ac96fbdf7 100644 --- a/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp +++ b/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp @@ -26,7 +26,7 @@ namespace sd { GraphProfile *GraphProfilingHelper::profile(Graph *graph, int iterations) { // saving original workspace - auto varSpace = graph->getVariableSpace()->clone(); + auto varSpace = graph->variableSpace()->clone(); // printing out graph structure // graph->printOut(); diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp index aea3cfa5c613..dc098949322f 100644 --- a/libnd4j/include/legacy/cpu/NativeOps.cpp +++ b/libnd4j/include/legacy/cpu/NativeOps.cpp @@ -2179,7 +2179,7 @@ int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flat static VariablesSet* executeStoredGraphT(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) { auto graph = sd::graph::GraphHolder::getInstance()->cloneGraph(graphId); - auto varSpace = graph->getVariableSpace(); + auto varSpace = graph->variableSpace(); std::vector handles; diff --git a/libnd4j/include/legacy/cuda/NativeOps.cu b/libnd4j/include/legacy/cuda/NativeOps.cu index bd6c23e8c05d..27dd7de76d68 100755 --- a/libnd4j/include/legacy/cuda/NativeOps.cu +++ b/libnd4j/include/legacy/cuda/NativeOps.cu @@ -2918,7 +2918,7 @@ int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flat static VariablesSet* executeStoredGraphT(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) { auto graph = sd::graph::GraphHolder::getInstance()->pullGraph(graphId); - auto varSpace = graph->getVariableSpace()->clone(); + auto varSpace = graph->variableSpace()->clone(); std::vector handles; diff --git a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp index 65a77d1539d3..c83b11f0ec69 100644 --- a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp @@ -168,7 +168,7 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) { ASSERT_EQ(1, graph.rootNodes()); - auto vs = graph.getVariableSpace(); + auto vs = graph.variableSpace(); ASSERT_EQ(OutputMode_IMPLICIT, graph.getExecutorConfiguration()->_outputMode); @@ -359,7 +359,7 @@ TEST_F(FlatBuffersTest, ReadFile3) { ASSERT_EQ(ND4J_STATUS_OK, status); - auto z = graph->getVariableSpace()->getVariable(2)->getNDArray(); + auto z = graph->variableSpace()->getVariable(2)->getNDArray(); ASSERT_EQ(1, z->lengthOf()); ASSERT_EQ(8, z->e(0)); @@ -374,9 +374,9 @@ TEST_F(FlatBuffersTest, ReadInception1) { Nd4jStatus status = GraphExecutioner::execute(graph); ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(227)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(227)); - auto lastNode = graph->getVariableSpace()->getVariable(227)->getNDArray(); + auto lastNode = graph->variableSpace()->getVariable(227)->getNDArray(); lastNode->printShapeInfo("Result shape"); @@ -403,12 +403,12 @@ TEST_F(FlatBuffersTest, ReadLoops_3argsWhile_1) { auto expPhi('c', {2, 2}); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(-1)); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(-2)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(-1)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(-2)); - auto phi = graph->getVariableSpace()->getVariable(-2)->getNDArray(); - auto constA = graph->getVariableSpace()->getVariable(-5)->getNDArray(); - auto lessY = graph->getVariableSpace()->getVariable(-6)->getNDArray(); + auto phi = graph->variableSpace()->getVariable(-2)->getNDArray(); + auto constA = graph->variableSpace()->getVariable(-5)->getNDArray(); + auto lessY = graph->variableSpace()->getVariable(-6)->getNDArray(); //constA->printBuffer("constA"); //lessY->printBuffer("lessY"); @@ -421,8 +421,8 @@ TEST_F(FlatBuffersTest, ReadLoops_3argsWhile_1) { // now, we expect some values - auto x = graph->getVariableSpace()->getVariable(20); - auto y = graph->getVariableSpace()->getVariable(21); + auto x = graph->variableSpace()->getVariable(20); + auto y = graph->variableSpace()->getVariable(21); ASSERT_NEAR(110.0f, x->getNDArray()->meanNumber(), 1e-5); ASSERT_NEAR(33.0f, y->getNDArray()->meanNumber(), 1e-5); @@ -444,7 +444,7 @@ TEST_F(FlatBuffersTest, ReadTensorArrayLoop_1) { ASSERT_EQ(ND4J_STATUS_OK, status); - auto variableSpace = graph->getVariableSpace(); + auto variableSpace = graph->variableSpace(); ASSERT_TRUE(variableSpace->hasVariable(23,0)); @@ -475,9 +475,9 @@ TEST_F(FlatBuffersTest, ReadLoops_NestedWhile_1) { ASSERT_EQ(ND4J_STATUS_OK, status); - auto x = graph->getVariableSpace()->getVariable(28); - auto y = graph->getVariableSpace()->getVariable(29); - auto z = graph->getVariableSpace()->getVariable(11, 2); + auto x = graph->variableSpace()->getVariable(28); + auto y = graph->variableSpace()->getVariable(29); + auto z = graph->variableSpace()->getVariable(11, 2); ASSERT_NEAR(110.0f, x->getNDArray()->meanNumber(), 1e-5); ASSERT_NEAR(33.0f, y->getNDArray()->meanNumber(), 1e-5); @@ -503,9 +503,9 @@ TEST_F(FlatBuffersTest, ReadTensorArray_1) { ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(14)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(14)); - auto z = graph->getVariableSpace()->getVariable(14)->getNDArray(); + auto z = graph->variableSpace()->getVariable(14)->getNDArray(); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -525,9 +525,9 @@ TEST_F(FlatBuffersTest, ReadStridedSlice_1) { ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(7)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(7)); - auto z = graph->getVariableSpace()->getVariable(7)->getNDArray(); + auto z = graph->variableSpace()->getVariable(7)->getNDArray(); ASSERT_NEAR(73.5f, z->e(0), 1e-5); @@ -545,7 +545,7 @@ TEST_F(FlatBuffersTest, ReduceDim_1) { graph->printOut(); - auto variableSpace = graph->getVariableSpace(); + auto variableSpace = graph->variableSpace(); ASSERT_TRUE(variableSpace->hasVariable(1)); @@ -578,7 +578,7 @@ TEST_F(FlatBuffersTest, ReduceDim_2) { graph->printOut(); - auto variableSpace = graph->getVariableSpace(); + auto variableSpace = graph->variableSpace(); ASSERT_TRUE(variableSpace->hasVariable(1)); @@ -617,9 +617,9 @@ TEST_F(FlatBuffersTest, Ae_00) { auto result = GraphExecutioner::execute(graph); ASSERT_EQ(ND4J_STATUS_OK, result); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(18)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(18)); - auto z = graph->getVariableSpace()->getVariable(18)->getNDArray(); + auto z = graph->variableSpace()->getVariable(18)->getNDArray(); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -638,9 +638,9 @@ TEST_F(FlatBuffersTest, expand_dims) { auto result = GraphExecutioner::execute(graph); ASSERT_EQ(ND4J_STATUS_OK, result); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(5)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(5)); - auto z = graph->getVariableSpace()->getVariable(5)->getNDArray(); + auto z = graph->variableSpace()->getVariable(5)->getNDArray(); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -721,9 +721,9 @@ TEST_F(FlatBuffersTest, Test_TensorDotMisc) { auto result = GraphExecutioner::execute(graph); ASSERT_EQ(Status::OK(), result); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(77)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(77)); - auto z = graph->getVariableSpace()->getVariable(77,0)->getNDArray(); + auto z = graph->variableSpace()->getVariable(77,0)->getNDArray(); ASSERT_EQ(e, *z); @@ -739,9 +739,9 @@ TEST_F(FlatBuffersTest, Test_MNIST_00_1) { auto result = GraphExecutioner::execute(graph); ASSERT_EQ(Status::OK(), result); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(6)); - auto z = graph->getVariableSpace()->getVariable(6,0)->getNDArray(); + auto z = graph->variableSpace()->getVariable(6,0)->getNDArray(); ASSERT_EQ(e, *z); @@ -774,9 +774,9 @@ TEST_F(FlatBuffersTest, nhwc_conv_0) { auto result = GraphExecutioner::execute(graph); ASSERT_EQ(ND4J_STATUS_OK, result); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(11)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(11)); - auto z = graph->getVariableSpace()->getVariable(11)->getNDArray(); + auto z = graph->variableSpace()->getVariable(11)->getNDArray(); //z->printShapeInfo("z buffr"); //z->printIndexedBuffer("z shape"); diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index 8a0b06c99b79..a41a46a746c7 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -30,6 +30,7 @@ #include #include #include +#include using namespace sd; using namespace sd::graph; @@ -41,12 +42,41 @@ class GraphExecutorTests : public testing::Test { TEST_F(GraphExecutorTests, test_basic_exec_1) { GraphMemoryManager memoryManager; + Graph graph; - OptimizedGraph optimizedGraph(memoryManager); + OptimizedGraph optimizedGraph(&graph); OpSequence sequence; optimizedGraph.append(sequence); + GraphExecutor executor; + executor.execute(optimizedGraph); +} + +TEST_F(GraphExecutorTests, test_basic_exec_2) { + GraphMemoryManager mgr; + Graph graph(nullptr, nullptr, mgr); + + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + + Node m("mul", sd::ops::multiply()); + Node a("add", sd::ops::add()); + graph.addNode(m, {"A", "B"}); + graph.addNode(a, {"mul", "C"}); + + OptimizedGraph optimizedGraph(&graph); + OpSequence sequence; + + sequence.append(m.customOp(), m.protoContext()); + sequence.append(a.customOp(), a.protoContext()); + + optimizedGraph.append(sequence); + + ASSERT_EQ(2, sequence.length()); + ASSERT_EQ(1, optimizedGraph.layers()); + GraphExecutor executor; executor.execute(optimizedGraph); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests.cpp index 792e1b07aa6d..0266023a6814 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests.cpp @@ -50,7 +50,7 @@ TEST_F(GraphTests, SingleInput1) { auto x = NDArrayFactory::create_('c', {5, 5}); x->assign(-2.0f); - graph.getVariableSpace()->putVariable(-1, x); + graph.variableSpace()->putVariable(-1, x); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); Node nodeB(OpType_TRANSFORM_STRICT, transform::Cosine); @@ -64,9 +64,9 @@ TEST_F(GraphTests, SingleInput1) { graph.execute(); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(3)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(3)); - auto node3 = graph.getVariableSpace()->getVariable(3)->getNDArray(); + auto node3 = graph.variableSpace()->getVariable(3)->getNDArray(); ASSERT_NEAR(0.4161468, node3->reduceNumber(reduce::Mean).e(0), 1e-5); } @@ -82,9 +82,9 @@ TEST_F(GraphTests, DoubleInput1) { auto z = NDArrayFactory::create_('c', {5, 5}); - graph.getVariableSpace()->putVariable(-1, x); - graph.getVariableSpace()->putVariable(-2, y); - graph.getVariableSpace()->putVariable(-3, z); + graph.variableSpace()->putVariable(-1, x); + graph.variableSpace()->putVariable(-2, y); + graph.variableSpace()->putVariable(-3, z); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); Node nodeB(OpType_TRANSFORM_SAME, transform::Abs); @@ -111,9 +111,9 @@ TEST_F(GraphTests, SingleInput3) { auto v0 = NDArrayFactory::create_('c', {5, 5}); auto v1 = NDArrayFactory::create_('c', {5, 5}); - graph.getVariableSpace()->putVariable(-1, x); - graph.getVariableSpace()->putVariable(-2, v0); - graph.getVariableSpace()->putVariable(-3, v1); + graph.variableSpace()->putVariable(-1, x); + graph.variableSpace()->putVariable(-2, v0); + graph.variableSpace()->putVariable(-3, v1); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt); @@ -140,9 +140,9 @@ TEST_F(GraphTests, SingleInput4) { auto v0 = NDArrayFactory::create_('c', {5, 5}); auto v1 = NDArrayFactory::create_('c', {5, 5}); - graph.getVariableSpace()->putVariable(-1, x); - graph.getVariableSpace()->putVariable(-2, v0); - graph.getVariableSpace()->putVariable(-3, v1); + graph.variableSpace()->putVariable(-1, x); + graph.variableSpace()->putVariable(-2, v0); + graph.variableSpace()->putVariable(-3, v1); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt); @@ -178,10 +178,10 @@ TEST_F(GraphTests, DoubleInput2) { auto z0 = NDArrayFactory::create_('c', {5, 5}); auto z1 = NDArrayFactory::create_('c', {5, 5}); - graph.getVariableSpace()->putVariable(-1, x); - graph.getVariableSpace()->putVariable(-2, y); - graph.getVariableSpace()->putVariable(-3, z0); - graph.getVariableSpace()->putVariable(-4, z1); + graph.variableSpace()->putVariable(-1, x); + graph.variableSpace()->putVariable(-2, y); + graph.variableSpace()->putVariable(-3, z0); + graph.variableSpace()->putVariable(-4, z1); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); @@ -223,11 +223,11 @@ TEST_F(GraphTests, DoubleInput3) { auto w = NDArrayFactory::create_('c', {5, 5}); - graph.getVariableSpace()->putVariable(-1, x); - graph.getVariableSpace()->putVariable(-2, y); - graph.getVariableSpace()->putVariable(-3, z0); - graph.getVariableSpace()->putVariable(-4, z1); - graph.getVariableSpace()->putVariable(-5, w); + graph.variableSpace()->putVariable(-1, x); + graph.variableSpace()->putVariable(-2, y); + graph.variableSpace()->putVariable(-3, z0); + graph.variableSpace()->putVariable(-4, z1); + graph.variableSpace()->putVariable(-5, w); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); @@ -279,11 +279,11 @@ TEST_F(GraphTests, QuadInput1) { auto z = NDArrayFactory::create_('c', {5, 5}); z->assign(119.0); - graph.getVariableSpace()->putVariable(-1, x0); - graph.getVariableSpace()->putVariable(-2, x1); - graph.getVariableSpace()->putVariable(-3, x2); - graph.getVariableSpace()->putVariable(-4, x3); - graph.getVariableSpace()->putVariable(-5, z); + graph.variableSpace()->putVariable(-1, x0); + graph.variableSpace()->putVariable(-2, x1); + graph.variableSpace()->putVariable(-3, x2); + graph.variableSpace()->putVariable(-4, x3); + graph.variableSpace()->putVariable(-5, z); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {11}); Node nodeB(OpType_TRANSFORM_SAME, transform::Abs, 2, {-2}, {11}); @@ -318,8 +318,8 @@ TEST_F(GraphTests, InternalBranching1) { auto z = NDArrayFactory::create_('c', {5, 5}); - graph.getVariableSpace()->putVariable(-1, x); - graph.getVariableSpace()->putVariable(-2, z); + graph.variableSpace()->putVariable(-1, x); + graph.variableSpace()->putVariable(-2, z); // 1.0 Node nodeA(OpType_TRANSFORM_SAME, transform::Ones, 1, {-1}, {11, 21}); @@ -366,8 +366,8 @@ TEST_F(GraphTests, ReductionsTest1) { auto z = NDArrayFactory::create_('c', {5}); - graph.getVariableSpace()->putVariable(-1, x); - graph.getVariableSpace()->putVariable(-2, z); + graph.variableSpace()->putVariable(-1, x); + graph.variableSpace()->putVariable(-2, z); Node nodeA(OpType_REDUCE_FLOAT, reduce::Mean, 1, {-1}, {2}, {1}, {}); Node nodeB(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {-2}); @@ -396,9 +396,9 @@ TEST_F(GraphTests, IndexReductionsTest1) { auto z = NDArrayFactory::create_('c', {5, 1}); auto axis = NDArrayFactory::create_('c', {1}, {1}); - graph.getVariableSpace()->putVariable(-1, x); - graph.getVariableSpace()->putVariable(-2, z); - //graph->getVariableSpace()->putVariable(-3, axis); + graph.variableSpace()->putVariable(-1, x); + graph.variableSpace()->putVariable(-2, z); + //graph->variableSpace()->putVariable(-3, axis); Node nodeA(OpType_INDEX_REDUCE, indexreduce::IndexMin, 1, {-1}, {2}, {1}); @@ -420,7 +420,7 @@ TEST_F(GraphTests, AutoOutput1) { auto x = NDArrayFactory::create_('c', {5, 5}); x->assign(-2.0); - graph->getVariableSpace()->putVariable(-1, x); + graph->variableSpace()->putVariable(-1, x); auto nodeA = new Node(OpType_TRANSFORM_FLOAT, 0, 1, {-1}, {2}); auto nodeB = new Node(OpType_TRANSFORM_FLOAT, 35, 2, {1}, {}); @@ -433,7 +433,7 @@ TEST_F(GraphTests, AutoOutput1) { graph->buildGraph(); - ASSERT_TRUE(graph->getVariableSpace()->getVariable(2) != nullptr); + ASSERT_TRUE(graph->variableSpace()->getVariable(2) != nullptr); GraphExecutioner::execute(graph); @@ -455,7 +455,7 @@ TEST_F(GraphTests, AutoOutput2) { auto x = NDArrayFactory::create_('c', {5, 5}); x->assign(-2.0); - graph->getVariableSpace()->putVariable(-1, x); + graph->variableSpace()->putVariable(-1, x); auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}, {2, 3, -1}); auto nodeB = new Node(OpType_TRANSFORM_SAME, 35, 2, {1}, {}); @@ -470,9 +470,9 @@ TEST_F(GraphTests, AutoOutput2) { graph->buildGraph(); - ASSERT_TRUE(graph->getVariableSpace()->getVariable(-1) != nullptr); - ASSERT_TRUE(graph->getVariableSpace()->getVariable(2) != nullptr); - ASSERT_TRUE(graph->getVariableSpace()->getVariable(3) != nullptr); + ASSERT_TRUE(graph->variableSpace()->getVariable(-1) != nullptr); + ASSERT_TRUE(graph->variableSpace()->getVariable(2) != nullptr); + ASSERT_TRUE(graph->variableSpace()->getVariable(3) != nullptr); GraphExecutioner::execute(graph); @@ -502,9 +502,9 @@ TEST_F(GraphTests, BroadcastTest1) { auto z = NDArrayFactory::create_('c', {5, 5}); - graph.getVariableSpace()->putVariable(-1, x); - graph.getVariableSpace()->putVariable(-2, y); - graph.getVariableSpace()->putVariable(-3, z); + graph.variableSpace()->putVariable(-1, x); + graph.variableSpace()->putVariable(-2, y); + graph.variableSpace()->putVariable(-3, z); Node nodeA(OpType_BROADCAST, broadcast::Subtract, 1, {-1, -2}, {2}, {1}); Node nodeB(OpType_TRANSFORM_SAME, transform::Neg, 2, {1}, {-3}); @@ -526,8 +526,8 @@ TEST_F(GraphTests, ScalarTest1) { auto z = NDArrayFactory::create_('c', {5, 5}); - graph.getVariableSpace()->putVariable(-1, x); - graph.getVariableSpace()->putVariable(-2, z); + graph.variableSpace()->putVariable(-1, x); + graph.variableSpace()->putVariable(-2, z); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); @@ -561,8 +561,8 @@ TEST_F(GraphTests, SymbolicLookupTest1) { vX->setName(a); vZ->setName(o); - graph.getVariableSpace()->putVariable(-1, vX); - graph.getVariableSpace()->putVariable(-2, vZ); + graph.variableSpace()->putVariable(-1, vX); + graph.variableSpace()->putVariable(-2, vZ); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); @@ -577,23 +577,23 @@ TEST_F(GraphTests, SymbolicLookupTest1) { graph.addNode(nodeB, {1}); - auto rX = graph.getVariableSpace()->getVariable(a); - auto rZ = graph.getVariableSpace()->getVariable(o); + auto rX = graph.variableSpace()->getVariable(a); + auto rZ = graph.variableSpace()->getVariable(o); std::string om("omicron"); ASSERT_TRUE(rX->getNDArray() == vX->getNDArray()); ASSERT_TRUE(rZ->getNDArray() == vZ->getNDArray()); - ASSERT_FALSE(graph.getVariableSpace()->hasVariable(om)); + ASSERT_FALSE(graph.variableSpace()->hasVariable(om)); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(1)); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(2)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(1)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(2)); graph.execute(); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(p)); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(t)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(p)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(t)); ASSERT_NEAR(1.4142135, z->reduceNumber(reduce::Mean).e(0), 1e-5); } @@ -605,7 +605,7 @@ TEST_F(GraphTests, Test_Clone_1) { auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - auto variableSpace = graph->getVariableSpace(); + auto variableSpace = graph->variableSpace(); //graph->buildGraph(); auto clone = graph->clone(); @@ -643,7 +643,7 @@ TEST_F(GraphTests, Test_Clone_2) { auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - auto variableSpace = graph->getVariableSpace(); + auto variableSpace = graph->variableSpace(); graph->buildGraph(); auto clone = graph->clone(); @@ -877,15 +877,15 @@ TEST_F(GraphTests, Test_Inplace_Execution_1) { ASSERT_TRUE(graph->nodeById(17)->isInplace()); ASSERT_TRUE(graph->nodeById(18)->isInplace()); - auto status = GraphExecutioner::execute(graph, graph->getVariableSpace()); + auto status = GraphExecutioner::execute(graph, graph->variableSpace()); ASSERT_EQ(Status::OK(), status); - auto z = graph->getVariableSpace()->getVariable(18)->getNDArray(); + auto z = graph->variableSpace()->getVariable(18)->getNDArray(); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - auto z_17 = graph->getVariableSpace()->getVariable(17)->getNDArray(); + auto z_17 = graph->variableSpace()->getVariable(17)->getNDArray(); ASSERT_TRUE(z_17 == z); delete graph; diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp index 489ac4c379e0..3984eaa2a5ae 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp @@ -50,9 +50,9 @@ TEST_F(GraphTests2, test_placeholder_1) { graph.addPlaceholder("input", DataType::BFLOAT16, {4, 12, 48}); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable("input")); + ASSERT_TRUE(graph.variableSpace()->hasVariable("input")); - auto variable = graph.getVariableSpace()->getVariable("input"); + auto variable = graph.variableSpace()->getVariable("input"); ASSERT_NE(nullptr, variable); ASSERT_TRUE(variable->isPlaceholder()); diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index cea4ab3e55a8..f4fc4a9b7695 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -85,9 +85,9 @@ TEST_F(OneOffTests, test_pad_1D_1) { graph.execute(); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(4)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(4)); - auto z = graph.getVariableSpace()->getVariable(4)->getNDArray(); + auto z = graph.variableSpace()->getVariable(4)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -109,9 +109,9 @@ TEST_F(OneOffTests, test_scatter_nd_update_1) { Nd4jStatus status = GraphExecutioner::execute(graph); ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(6)); - auto z = graph->getVariableSpace()->getVariable(6)->getNDArray(); + auto z = graph->variableSpace()->getVariable(6)->getNDArray(); ASSERT_TRUE(z != nullptr); z->printIndexedBuffer("z"); @@ -129,9 +129,9 @@ TEST_F(OneOffTests, test_conv2d_nhwc_failed_1) { graph.execute(); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(9)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(9)); - auto z = graph.getVariableSpace()->getVariable(9)->getNDArray(); + auto z = graph.variableSpace()->getVariable(9)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -144,9 +144,9 @@ TEST_F(OneOffTests, test_tensor_array_1) { graph.execute(); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(5)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(5)); - auto z = graph.getVariableSpace()->getVariable(5)->getNDArray(); + auto z = graph.variableSpace()->getVariable(5)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -159,9 +159,9 @@ TEST_F(OneOffTests, test_tensor_array_2) { graph.execute(); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(6)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(6)); - auto z = graph.getVariableSpace()->getVariable(6)->getNDArray(); + auto z = graph.variableSpace()->getVariable(6)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -174,9 +174,9 @@ TEST_F(OneOffTests, test_tensor_array_3) { graph.execute(); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(15)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(15)); - auto z = graph.getVariableSpace()->getVariable(15)->getNDArray(); + auto z = graph.variableSpace()->getVariable(15)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -189,9 +189,9 @@ TEST_F(OneOffTests, test_tensor_array_4) { graph.execute(); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(11)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(11)); - auto z = graph.getVariableSpace()->getVariable(11)->getNDArray(); + auto z = graph.variableSpace()->getVariable(11)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -204,9 +204,9 @@ TEST_F(OneOffTests, test_assert_4) { graph.execute(); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(1)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(1)); - auto z = graph.getVariableSpace()->getVariable(1)->getNDArray(); + auto z = graph.variableSpace()->getVariable(1)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -223,9 +223,9 @@ TEST_F(OneOffTests, test_assert_4) { // Nd4jStatus status = GraphExecutioner::execute(graph); // ASSERT_EQ(Status::OK(), status); -// ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); +// ASSERT_TRUE(graph->variableSpace()->hasVariable(6)); -// auto z = graph->getVariableSpace()->getVariable(6)->getNDArray(); +// auto z = graph->variableSpace()->getVariable(6)->getNDArray(); // ASSERT_TRUE(z != nullptr); // z->printIndexedBuffer("z buffer"); @@ -247,9 +247,9 @@ TEST_F(OneOffTests, test_cond_false_1) { Nd4jStatus status = GraphExecutioner::execute(graph); ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(6)); - auto z = graph->getVariableSpace()->getVariable(6)->getNDArray(); + auto z = graph->variableSpace()->getVariable(6)->getNDArray(); ASSERT_TRUE(z != nullptr); z->printIndexedBuffer("z buffer"); @@ -269,10 +269,10 @@ TEST_F(OneOffTests, test_identity_n_2) { graph.execute(); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(1)); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(1, 1)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(1)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(1, 1)); - auto z = graph.getVariableSpace()->getVariable(1)->getNDArray(); + auto z = graph.variableSpace()->getVariable(1)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -285,9 +285,9 @@ TEST_F(OneOffTests, test_non2d_1) { graph.execute(); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(3)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(3)); - auto z = graph.getVariableSpace()->getVariable(3)->getNDArray(); + auto z = graph.variableSpace()->getVariable(3)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -300,12 +300,12 @@ TEST_F(OneOffTests, test_reduce_all_1) { graph.execute(); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(1)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(1)); - ASSERT_TRUE(graph.getVariableSpace()->hasVariable(2)); - auto in = graph.getVariableSpace()->getVariable(2)->getNDArray(); + ASSERT_TRUE(graph.variableSpace()->hasVariable(2)); + auto in = graph.variableSpace()->getVariable(2)->getNDArray(); - auto z = graph.getVariableSpace()->getVariable(1)->getNDArray(); + auto z = graph.variableSpace()->getVariable(1)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); diff --git a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp index 658c36e183b5..0e13966ea87a 100644 --- a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp @@ -39,6 +39,7 @@ class OpSequenceTests : public testing::Test { }; TEST_F(OpSequenceTests, test_iterator_1) { + Graph graph; OpSequence sequence; ASSERT_EQ(0, sequence.length()); @@ -61,8 +62,7 @@ TEST_F(OpSequenceTests, test_iterator_1) { ASSERT_EQ(3, cnt); - GraphMemoryManager mgr; - OptimizedGraph optimizedGraph(mgr); + OptimizedGraph optimizedGraph(&graph); ASSERT_EQ(0, optimizedGraph.layers()); optimizedGraph.append(sequence); diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index d094e09caec4..5bd519dd8f3e 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -111,17 +111,17 @@ TEST_F(PlaygroundTests, test_bert_full_1) { graph->tagInplaceNodes(); - graph->getVariableSpace()->putVariable(658,0, t); - graph->getVariableSpace()->putVariable(659,0, u); - graph->getVariableSpace()->putVariable(660,0, v); + graph->variableSpace()->putVariable(658,0, t); + graph->variableSpace()->putVariable(659,0, u); + graph->variableSpace()->putVariable(660,0, v); /* // validating graph now auto status = GraphExecutioner::execute(graph); ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1620)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(1620)); - auto array = graph->getVariableSpace()->getVariable(1620)->getNDArray(); + auto array = graph->variableSpace()->getVariable(1620)->getNDArray(); ASSERT_EQ(z, *array); */ @@ -174,17 +174,17 @@ TEST_F(PlaygroundTests, test_bert_1) { graph->tagInplaceNodes(); - graph->getVariableSpace()->putVariable(85,0, t); - graph->getVariableSpace()->putVariable(86,0, u); - graph->getVariableSpace()->putVariable(87,0, v); + graph->variableSpace()->putVariable(85,0, t); + graph->variableSpace()->putVariable(86,0, u); + graph->variableSpace()->putVariable(87,0, v); /* // validating graph now auto status = GraphExecutioner::execute(graph); ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(198)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(198)); - auto array = graph->getVariableSpace()->getVariable(198)->getNDArray(); + auto array = graph->variableSpace()->getVariable(198)->getNDArray(); ASSERT_EQ(z, *array); */ @@ -236,9 +236,9 @@ TEST_F(PlaygroundTests, test_bert_2) { // validating graph now auto status = GraphExecutioner::execute(graph); ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->getVariableSpace()->hasVariable(198)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(198)); - auto array = graph->getVariableSpace()->getVariable(198)->getNDArray(); + auto array = graph->variableSpace()->getVariable(198)->getNDArray(); ASSERT_EQ(z, *array); */ diff --git a/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp b/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp index 520ed0134b7a..2ab89663af0e 100644 --- a/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp @@ -53,10 +53,10 @@ TEST_F(ProtoBufTests, TestTextLoad2) { ASSERT_FALSE(graph == nullptr); - ASSERT_EQ(2, graph->getVariableSpace()->externalEntries()); + ASSERT_EQ(2, graph->variableSpace()->externalEntries()); - auto var0 = graph->getVariableSpace()->getVariable(new std::string("zeros")); - auto var1 = graph->getVariableSpace()->getVariable(new std::string("ones")); + auto var0 = graph->variableSpace()->getVariable(new std::string("zeros")); + auto var1 = graph->variableSpace()->getVariable(new std::string("ones")); // first we're veryfying variable states @@ -91,10 +91,10 @@ TEST_F(ProtoBufTests, TestTextLoad3) { ASSERT_FALSE(graph == nullptr); - ASSERT_EQ(2, graph->getVariableSpace()->externalEntries()); + ASSERT_EQ(2, graph->variableSpace()->externalEntries()); - auto var0 = graph->getVariableSpace()->getVariable(new std::string("Placeholder")); - auto var1 = graph->getVariableSpace()->getVariable(new std::string("Placeholder_1")); + auto var0 = graph->variableSpace()->getVariable(new std::string("Placeholder")); + auto var1 = graph->variableSpace()->getVariable(new std::string("Placeholder_1")); ASSERT_TRUE(var0 != nullptr); ASSERT_TRUE(var1 != nullptr); diff --git a/libnd4j/tests_cpu/layers_tests/SanityTests.cpp b/libnd4j/tests_cpu/layers_tests/SanityTests.cpp index 7ca6732fe62b..c84f45bb3417 100644 --- a/libnd4j/tests_cpu/layers_tests/SanityTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/SanityTests.cpp @@ -54,9 +54,9 @@ TEST_F(SanityTests, VariableSpace_2) { TEST_F(SanityTests, Graph_1) { Graph graph; - graph.getVariableSpace()->putVariable(1, new Variable(NDArrayFactory::create_('c', {3, 3}))); - graph.getVariableSpace()->putVariable(1, 1, new Variable(NDArrayFactory::create_('c', {3, 3}))); + graph.variableSpace()->putVariable(1, new Variable(NDArrayFactory::create_('c', {3, 3}))); + graph.variableSpace()->putVariable(1, 1, new Variable(NDArrayFactory::create_('c', {3, 3}))); std::pair pair(1, 2); - graph.getVariableSpace()->putVariable(pair, new Variable(NDArrayFactory::create_('c', {3, 3}))); + graph.variableSpace()->putVariable(pair, new Variable(NDArrayFactory::create_('c', {3, 3}))); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cpp b/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cpp index 571db71f3048..74409f16a2cd 100644 --- a/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cpp @@ -240,7 +240,7 @@ TEST_F(WorkspaceTests, Test_Arrays_1) { #ifdef GRAPH_FILES_OK TEST_F(WorkspaceTests, Test_Graph_1) { auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); - auto workspace = graph->getVariableSpace()->workspace(); + auto workspace = graph->variableSpace()->workspace(); auto status = GraphExecutioner::execute(graph); ASSERT_EQ(Status::OK(), status); From f6cc648a77c69a37e7136e3de81399cecfad78fd Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 19:05:47 +0300 Subject: [PATCH 069/233] Node inputs propagation Signed-off-by: raver119 --- libnd4j/include/graph/impl/Node.cpp | 1 + .../layers_tests/GraphExecutorTests.cpp | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 1eb4d46bee22..a0094b7684b9 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -215,6 +215,7 @@ namespace sd { void Node::pickInput(std::pair& pair) { _input.push_back(pair); + _protoContext.pickInput(pair); } void Node::pickInput(const std::string &id) { diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index a41a46a746c7..03236ff2ea8d 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -57,9 +57,13 @@ TEST_F(GraphExecutorTests, test_basic_exec_2) { GraphMemoryManager mgr; Graph graph(nullptr, nullptr, mgr); - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); - graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + auto A = NDArrayFactory::create('c', {3}, {1, 1, 1}); + auto B = NDArrayFactory::create('c', {3}, {2, 2, 2}); + auto C = NDArrayFactory::create('c', {3}, {3, 3, 3}); + + graph.addVariable("A", A); + graph.addVariable("B", B); + graph.addVariable("C", C); Node m("mul", sd::ops::multiply()); Node a("add", sd::ops::add()); @@ -69,6 +73,9 @@ TEST_F(GraphExecutorTests, test_basic_exec_2) { OptimizedGraph optimizedGraph(&graph); OpSequence sequence; + ASSERT_EQ(2, m.protoContext().inputs().size()); + ASSERT_EQ(2, a.protoContext().inputs().size()); + sequence.append(m.customOp(), m.protoContext()); sequence.append(a.customOp(), a.protoContext()); @@ -79,4 +86,10 @@ TEST_F(GraphExecutorTests, test_basic_exec_2) { GraphExecutor executor; executor.execute(optimizedGraph); + + ASSERT_TRUE(graph.variableSpace()->hasVariable("mul")); + ASSERT_TRUE(graph.variableSpace()->hasVariable("add")); + + ASSERT_TRUE(graph.variableSpace()->hasVariable(1)); + ASSERT_TRUE(graph.variableSpace()->hasVariable(2)); } \ No newline at end of file From 4c75473a1c09e3b83ee4a64dab440dd99378634c Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 1 Apr 2020 20:04:13 +0300 Subject: [PATCH 070/233] first execution test passes Signed-off-by: raver119 --- libnd4j/include/graph/ContextPrototype.h | 4 ++++ libnd4j/include/graph/impl/Context.cpp | 6 ++++++ libnd4j/include/graph/impl/ContextPrototype.cpp | 12 ++++++++++++ libnd4j/include/graph/impl/Node.cpp | 2 ++ .../tests_cpu/layers_tests/GraphExecutorTests.cpp | 13 +++++++++++-- 5 files changed, 35 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/graph/ContextPrototype.h b/libnd4j/include/graph/ContextPrototype.h index cea6a3422be3..9b945b4c281f 100644 --- a/libnd4j/include/graph/ContextPrototype.h +++ b/libnd4j/include/graph/ContextPrototype.h @@ -43,6 +43,7 @@ namespace sd { // int ids of the input nodes std::vector> _inputs; int _nodeId; + std::string _name; std::vector _tArgs; std::vector _iArgs; std::vector _bArgs; @@ -128,6 +129,9 @@ namespace sd { bool isUseMKLDNN() const { return _useMKLDNN; } void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN = useMKLDNN; } + std::string name() const; + void setName(const std::string &name); + /** * This method returns number of inputs available in this block * @return diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index 0bdac22ce6d4..26a7d261f55f 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -53,6 +53,7 @@ namespace sd { this->_opNum = prototype.opNum(); this->_isInplace = prototype.isInplace(); this->_nodeId = prototype.nodeId(); + this->_name = prototype.name(); this->_useMKLDNN = prototype.isUseMKLDNN(); if (variableSpace != nullptr && variableSpace->launchContext()->getWorkspace() != nullptr) @@ -320,6 +321,11 @@ namespace sd { if (!_variableSpace->hasVariable(pair)) { auto var = new Variable(nullptr, nullptr, this->nodeId(), idx); + auto name = this->name(); + + if (!name.empty()) + var->setName(name); + _variableSpace->putVariable(pair, var); return var; } else { diff --git a/libnd4j/include/graph/impl/ContextPrototype.cpp b/libnd4j/include/graph/impl/ContextPrototype.cpp index 64827a869c1e..3fa993bd1027 100644 --- a/libnd4j/include/graph/impl/ContextPrototype.cpp +++ b/libnd4j/include/graph/impl/ContextPrototype.cpp @@ -215,6 +215,7 @@ namespace sd { _iArgs = other._iArgs; _bArgs = other._bArgs; _dArgs = other._dArgs; + _name = other._name; _nodeId = other._nodeId; _isInplace = other._isInplace; @@ -236,6 +237,7 @@ namespace sd { _iArgs = other._iArgs; _bArgs = other._bArgs; _dArgs = other._dArgs; + _name = other._name; _nodeId = other._nodeId; _isInplace = other._isInplace; @@ -256,6 +258,7 @@ namespace sd { _iArgs = std::move(other._iArgs); _bArgs = std::move(other._bArgs); _dArgs = std::move(other._dArgs); + _name = std::move(other._name); _nodeId = other._nodeId; _isInplace = other._isInplace; @@ -277,6 +280,7 @@ namespace sd { _iArgs = std::move(other._iArgs); _bArgs = std::move(other._bArgs); _dArgs = std::move(other._dArgs); + _name = std::move(other._name); _nodeId = other._nodeId; _isInplace = other._isInplace; @@ -294,5 +298,13 @@ namespace sd { void ContextPrototype::setNodeId(int id) { _nodeId = id; } + + std::string ContextPrototype::name() const { + return _name; + } + + void ContextPrototype::setName(const std::string &name) { + _name = name; + } } } \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index a0094b7684b9..dde28824baff 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -62,6 +62,7 @@ namespace sd { _hasInternalOutputs = false; ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); + block.setName(nodeName); block.appendI(iArgs); block.appendT(tArgs); @@ -90,6 +91,7 @@ namespace sd { _hasInternalOutputs = false; ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); + block.setName(nodeName); block.appendI(iArgs); block.appendT(tArgs); diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index 03236ff2ea8d..4fd4cd7ca66e 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -61,12 +61,15 @@ TEST_F(GraphExecutorTests, test_basic_exec_2) { auto B = NDArrayFactory::create('c', {3}, {2, 2, 2}); auto C = NDArrayFactory::create('c', {3}, {3, 3, 3}); + auto exp = NDArrayFactory::create('c', {3}, {5, 5, 5}); + graph.addVariable("A", A); graph.addVariable("B", B); graph.addVariable("C", C); Node m("mul", sd::ops::multiply()); Node a("add", sd::ops::add()); + graph.addNode(m, {"A", "B"}); graph.addNode(a, {"mul", "C"}); @@ -87,9 +90,15 @@ TEST_F(GraphExecutorTests, test_basic_exec_2) { GraphExecutor executor; executor.execute(optimizedGraph); + // checking results by ID + ASSERT_TRUE(graph.variableSpace()->hasVariable(m.id())); + ASSERT_TRUE(graph.variableSpace()->hasVariable(a.id())); + + // checking results by name ASSERT_TRUE(graph.variableSpace()->hasVariable("mul")); ASSERT_TRUE(graph.variableSpace()->hasVariable("add")); - ASSERT_TRUE(graph.variableSpace()->hasVariable(1)); - ASSERT_TRUE(graph.variableSpace()->hasVariable(2)); + // checking if result is valid + auto result = graph.variableSpace()->getVariable(a.id())->getNDArray(); + ASSERT_EQ(exp, *result); } \ No newline at end of file From e13ac10cf38c2b63f1a4f2f1f461c33ec3bb866d Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 2 Apr 2020 08:03:49 +0300 Subject: [PATCH 071/233] graph execution and results Signed-off-by: raver119 --- .../exceptions/graph_execution_exception.h | 1 + .../impl/graph_execution_exception.cpp | 4 ++++ libnd4j/include/graph/impl/Graph.cpp | 21 +++++++++++++++++-- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/exceptions/graph_execution_exception.h b/libnd4j/include/exceptions/graph_execution_exception.h index 612b549f81ff..3bf3b3311faa 100644 --- a/libnd4j/include/exceptions/graph_execution_exception.h +++ b/libnd4j/include/exceptions/graph_execution_exception.h @@ -38,6 +38,7 @@ namespace sd { class SD_EXPORT graph_execution_exception: public graph_exception { public: explicit graph_execution_exception(Nd4jLong graphId); + explicit graph_execution_exception(const std::string &message, Nd4jStatus status); }; } diff --git a/libnd4j/include/exceptions/impl/graph_execution_exception.cpp b/libnd4j/include/exceptions/impl/graph_execution_exception.cpp index 086796517f7c..aff8f76638e5 100644 --- a/libnd4j/include/exceptions/impl/graph_execution_exception.cpp +++ b/libnd4j/include/exceptions/impl/graph_execution_exception.cpp @@ -25,4 +25,8 @@ namespace sd { graph_execution_exception::graph_execution_exception(Nd4jLong graphId) : graph_exception(StringUtils::buildGraphErrorMessage("Caught exception during graph execution", graphId), graphId) { _graphId = graphId; } + + graph_execution_exception::graph_execution_exception(const std::string &message, Nd4jStatus status) : graph_exception(message, status) { + // + } } diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 3c3e959b50da..ab057c41a6fa 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -33,6 +33,7 @@ #include #include #include +#include namespace sd { namespace graph { @@ -582,8 +583,24 @@ namespace sd { throw unresolved_output_exception::build("Requested output doesn't exist", v); } - // TODO: implement this method - return std::map(); + // execute optimized version of this graph + auto status = executor.execute(optimizedGraph()); + if (status != Status::OK()) + throw graph_execution_exception("Graph execution failed, error code: ", status); + + // fetch outputs from VariableSpace + std::map result; + for (const auto &v:outputs) { + if (!_variableSpace->hasVariable(v)) + throw unresolved_output_exception::build("Requested output doesn't exist after execution", v); + + auto var = _variableSpace->getVariable(v); + + // TODO: we want to make sure ManagedDataBuffer doesn't leak here + result[v] = *var->getNDArray(); + } + + return result; } bool Graph::topolSearch(const int startNode, const std::set& nodeBranches, const std::unordered_map& positions, OpSequence& opSeq) const { From 8ca3351825c7866d7077b2586044f12c8e72b595 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 2 Apr 2020 08:56:11 +0300 Subject: [PATCH 072/233] additional validation during graph construction Signed-off-by: raver119 --- libnd4j/include/graph/Node.h | 4 +-- libnd4j/include/graph/impl/Graph.cpp | 17 +++++++-- libnd4j/include/graph/impl/Node.cpp | 4 +-- .../layers_tests/GraphAnalysisTests.cpp | 11 +++--- .../layers_tests/GraphExecutorTests.cpp | 4 +-- .../tests_cpu/layers_tests/GraphTests2.cpp | 35 ++++++++++++++----- 6 files changed, 52 insertions(+), 23 deletions(-) diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index cda7bad4d79a..d98064eb1142 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -87,8 +87,8 @@ namespace sd { public: - explicit Node(const std::string &nodeName, const sd::ops::DeclarableOp &op, const std::vector &tArgs = {}, const std::vector &iArgs = {}, const std::vector &bArgs = {}, const std::vector &dArgs = {}); - explicit Node(const std::string &nodeName, const std::string &opName, const std::vector &tArgs = {}, const std::vector &iArgs = {}, const std::vector &bArgs = {}, const std::vector &dArgs = {}); + explicit Node(const sd::ops::DeclarableOp &op, const std::string &nodeName = {}, const std::vector &tArgs = {}, const std::vector &iArgs = {}, const std::vector &bArgs = {}, const std::vector &dArgs = {}); + explicit Node(const std::string &opName, const std::string &nodeName = {}, const std::vector &tArgs = {}, const std::vector &iArgs = {}, const std::vector &bArgs = {}, const std::vector &dArgs = {}); explicit Node(const FlatNode *node); ~Node(); diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index ab057c41a6fa..01532ff0666c 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -81,23 +81,34 @@ namespace sd { } void Graph::addNode(Node &node, const std::initializer_list &inputs) { + // temporary check. basically we're okay if Node has id defined if (node.id() != 0) throw std::runtime_error("Graph::addNode - Node has id defined"); + if (node.name().empty()) { + // if name is empty we'll make up a name based on Op name + } else { + if (_symbolicLookupTable.count(node.name()) > 0) + throw std::runtime_error("Graph::addNode - Graph alread has Node [" + node.name() + "] defined"); + } + // node must have numeric id node.setId(_maxId++); _symbolicLookupTable[node.name()] = node.id(); // converting string ids to numeric ones - for (auto &v:inputs) + for (auto &v:inputs) { + // we don't allow self-references + if (v == node.name()) + throw unresolved_input_exception::build("Graph::addNode - Node references itself", v); + node.pickInput(idByName(v), 0); + } _unmapped[node.id()] = node; } void Graph::addNode(Node &node, const std::initializer_list &inputs) { - node.markRemovable(false); - throw std::runtime_error("not implemented yet"); } diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index dde28824baff..b72c1277e5c3 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -44,7 +44,7 @@ namespace sd { namespace graph { - Node::Node(const std::string &nodeName, const ops::DeclarableOp &opName, const std::vector &tArgs, + Node::Node(const ops::DeclarableOp &opName, const std::string &nodeName, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs, const std::vector &dArgs) { auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName.getOpHash()); @@ -72,7 +72,7 @@ namespace sd { this->setContextPrototype(block); } - Node::Node(const std::string &nodeName, const std::string &opName, const std::vector &tArgs, + Node::Node(const std::string &opName, const std::string &nodeName, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs, const std::vector &dArgs) { diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 0acfb5d76eea..3919596bd3bd 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -47,8 +47,8 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_1) { // C graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - Node a("multiply", sd::ops::multiply()); - Node b("add", sd::ops::add()); + Node a(sd::ops::multiply(), "multiply"); + Node b(sd::ops::add(), "add"); graph.addNode(a, {"A", "B"}); graph.addNode(b, {"multiply", "C"}); @@ -89,10 +89,9 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { // D graph.addVariable("D", NDArrayFactory::create('c', {3}, {4, 4, 4})); - Node a("multiply", sd::ops::multiply()); - Node b("add", sd::ops::add()); - Node c("subtract", sd::ops::subtract()); - + Node a(sd::ops::multiply(), "multiply"); + Node b(sd::ops::add(), "add"); + Node c(sd::ops::subtract(), "subtract"); graph.addNode(a, {"A", "B"}); graph.addNode(b, {"multiply", "C"}); diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index 4fd4cd7ca66e..1d65c2072a46 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -67,8 +67,8 @@ TEST_F(GraphExecutorTests, test_basic_exec_2) { graph.addVariable("B", B); graph.addVariable("C", C); - Node m("mul", sd::ops::multiply()); - Node a("add", sd::ops::add()); + Node m(sd::ops::multiply(), "mul"); + Node a(sd::ops::add(), "add"); graph.addNode(m, {"A", "B"}); graph.addNode(a, {"mul", "C"}); diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp index 3984eaa2a5ae..b5360abaefe8 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp @@ -76,9 +76,9 @@ TEST_F(GraphTests2, test_execution_1) { // C graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - Node b("add_node", sd::ops::add()); + Node b(sd::ops::add(), "add_node"); - graph.addNode(Node("multiply_node", sd::ops::multiply()), {"A", "B"}); + graph.addNode(Node( sd::ops::multiply(), "multiply_node"), {"A", "B"}); graph.addNode(b, {"multiply_node", "C"}); auto result = graph.execute({}, {"add_node"}); @@ -91,7 +91,7 @@ TEST_F(GraphTests2, test_placeholder_resolution_1) { graph.addPlaceholder("input", DataType::FLOAT32); - Node node("tanh_node", sd::ops::tanh()); + Node node(sd::ops::tanh(), "tanh_node"); graph.addNode(node, {"input"}); // this test must throw an exception, because input isn't resolved yet @@ -103,7 +103,7 @@ TEST_F(GraphTests2, test_placeholder_resolution_2) { graph.addPlaceholder("input", DataType::FLOAT32); - graph.addNode(Node("tanh_node", sd::ops::tanh()), {"input"}); + graph.addNode(Node(sd::ops::tanh(), "tanh_node"), {"input"}); auto result = graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"tanh_node"}); @@ -115,7 +115,7 @@ TEST_F(GraphTests2, test_placeholder_resolution_3) { graph.addPlaceholder("input", DataType::FLOAT32); - graph.addNode(Node("tanh_node", sd::ops::tanh()), {"input"}); + graph.addNode(Node(sd::ops::tanh(), "tanh_node"), {"input"}); ASSERT_THROW(graph.execute({{"input", NDArrayFactory::create(5)}}, {"tanh_node"}), sd::datatype_exception); } @@ -125,7 +125,7 @@ TEST_F(GraphTests2, test_placeholder_resolution_4) { graph.addPlaceholder("input", DataType::FLOAT32, {3, 4, 5}); - Node a("tanh_node", sd::ops::tanh()); + Node a(sd::ops::tanh(), "tanh_node"); graph.addNode(a, {"input"}); ASSERT_THROW(graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"tanh_node"}), sd::shape_mismatch_exception); @@ -136,7 +136,7 @@ TEST_F(GraphTests2, test_output_resolution_1) { graph.addPlaceholder("input", DataType::FLOAT32); - Node node("tanh_node", sd::ops::tanh()); + Node node(sd::ops::tanh(), "tanh_node"); graph.addNode(node, {"input"}); // since we're requesting output of non-existent node - we expect exception @@ -148,9 +148,28 @@ TEST_F(GraphTests2, test_input_resolution_1) { graph.addPlaceholder("input", DataType::FLOAT32); - Node a("tanh_node", sd::ops::tanh()); + Node a(sd::ops::tanh(), "tanh_node"); graph.addNode(a, {"input"}); // since we're trying to resolve non-existent placeholder - we expect exception ASSERT_THROW(graph.execute({{"array", NDArrayFactory::create(0.5f)}}, {"tanh_node"}), graph::unresolved_input_exception); +} + +TEST_F(GraphTests2, test_double_name_1) { + Graph graph; + + graph.addPlaceholder("input", DataType::FLOAT32); + + graph.addNode(Node(sd::ops::tanh(), "tanh_node"), {"input"}); + graph.addNode(Node(sd::ops::add(), "add_node"), {"tanh_node"}); + ASSERT_ANY_THROW(graph.addNode(Node(sd::ops::add(), "add_node"), {"tanh_node"})); +} + +TEST_F(GraphTests2, test_self_reference) { + Graph graph; + + graph.addPlaceholder("input", DataType::FLOAT32); + + graph.addNode(Node(sd::ops::tanh(), "tanh_node"), {"input"}); + ASSERT_ANY_THROW(graph.addNode(Node(sd::ops::add(), "add_node"), {"add_node"})); } \ No newline at end of file From 54c1bf02b14f4757d78dba3e776d4506cf130e69 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 2 Apr 2020 09:16:37 +0300 Subject: [PATCH 073/233] few files updated after merge Signed-off-by: raver119 --- libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp | 6 +++--- .../include/ops/declarable/platform/mkldnn/xw_plus_b.cpp | 4 ++-- libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp | 9 ++++----- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp b/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp index dbabad395d4f..fc90c8125032 100644 --- a/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp @@ -39,7 +39,7 @@ namespace sd { if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty()) return Status::OK(); - const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false); + const bool bTranspose = (block.numI() > 0 ? INT_ARG(0) == 1 : false); auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1); @@ -66,7 +66,7 @@ namespace sd { auto weights = INPUT_VARIABLE(1); - const int nWeightsFormat = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + const int nWeightsFormat = block.numI() > 0 ? INT_ARG(0) : 0; auto weightsShape = (1 == nWeightsFormat) ? ShapeUtils::evalTranspShapeInfo(*weights, block.getWorkspace()) : inputShape->at(1); @@ -95,7 +95,7 @@ namespace sd { if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty() || dLdz->isEmpty()) return Status::OK(); - const bool bTranspose = (block.getIArguments()->size() > 0 ? INT_ARG(0) == 1 : false); + const bool bTranspose = (block.numI() > 0 ? INT_ARG(0) == 1 : false); auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp index 01a003c2c286..a8b38fafa6c9 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp @@ -315,7 +315,7 @@ namespace sd { const int wRank = w->rankOf(); const int zRank = z->rankOf(); - const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] + const bool bShouldTransp = block.numI() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] REQUIRE_TRUE(xRank == 2, 0, "xw_plus_b MKL: Input x array should have rank equal 2, but got instead %i!", xRank); REQUIRE_TRUE(wRank == 2, 0, "xw_plus_b MKL: Input weights array should have rank equal 2, but got instead %i!", wRank); @@ -378,7 +378,7 @@ namespace sd { const int wRank = w->rankOf(); const int dLdzRank = dLdz->rankOf(); - const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] + const bool bShouldTransp = block.numI() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP MKL: Input x array should have rank equal 2, but got instead %i!", x->rankOf()); REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP MKL: Input weights array should have rank equal 2, but got instead %i!", w->rankOf()); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp index b4c9839abb10..dc6b0b550dcb 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp @@ -2098,11 +2098,10 @@ TEST_F(DeclarableOpsTests14, Reshape2) { auto block = new Context(1, variableSpace, false); block->fillInputs({ -1 }); - std::vector* arguments = block->getIArguments(); - arguments->push_back(-y->ordering()); - arguments->push_back(3); - arguments->push_back(5); - arguments->push_back(4); + block->appendI(-y->ordering()); + block->appendI(3); + block->appendI(5); + block->appendI(4); sd::ops::reshape reshape; From 5f6986286ee17cb19e3eabca0391aed4ab8f9556 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 2 Apr 2020 10:33:34 +0300 Subject: [PATCH 074/233] few nano tweaks in RNGTests Signed-off-by: raver119 --- libnd4j/tests_cpu/layers_tests/RNGTests.cpp | 41 ++++++--------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 889e194a656e..a1d7ab46f684 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -574,8 +574,8 @@ TEST_F(RNGTests, Test_Uniform_2) { RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); - auto op = new sd::ops::LegacyRandomOp(0); - auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); + sd::ops::LegacyRandomOp op(0); + auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}); ASSERT_EQ(Status::OK(), result.status()); @@ -583,9 +583,6 @@ TEST_F(RNGTests, Test_Uniform_2) { ASSERT_TRUE(x1.isSameShape(z)); ASSERT_TRUE(x1.equalsTo(z)); - - delete op; - } TEST_F(RNGTests, Test_Uniform_SGA_3) { @@ -603,8 +600,8 @@ TEST_F(RNGTests, Test_Gaussian_2) { RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); - auto op = new sd::ops::LegacyRandomOp(random::GaussianDistribution); - auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); + sd::ops::LegacyRandomOp op(random::GaussianDistribution); + auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}); ASSERT_EQ(Status::OK(), result.status()); @@ -612,9 +609,6 @@ TEST_F(RNGTests, Test_Gaussian_2) { ASSERT_TRUE(x1.isSameShape(z)); ASSERT_TRUE(x1.equalsTo(z)); - - delete op; - } TEST_F(RNGTests, Test_LogNorm_2) { @@ -623,8 +617,8 @@ TEST_F(RNGTests, Test_LogNorm_2) { RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); - auto op = new sd::ops::LegacyRandomOp(random::LogNormalDistribution); - auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); + sd::ops::LegacyRandomOp op(random::LogNormalDistribution); + auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}); ASSERT_EQ(Status::OK(), result.status()); @@ -632,9 +626,6 @@ TEST_F(RNGTests, Test_LogNorm_2) { ASSERT_TRUE(x1.isSameShape(z)); ASSERT_TRUE(x1.equalsTo(z)); - - delete op; - } TEST_F(RNGTests, Test_TruncatedNorm_2) { @@ -643,8 +634,8 @@ TEST_F(RNGTests, Test_TruncatedNorm_2) { RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); - auto op = new sd::ops::LegacyRandomOp(random::TruncatedNormalDistribution); - auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); + sd::ops::LegacyRandomOp op(random::TruncatedNormalDistribution); + auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}); ASSERT_EQ(Status::OK(), result.status()); @@ -652,8 +643,6 @@ TEST_F(RNGTests, Test_TruncatedNorm_2) { ASSERT_TRUE(x1.isSameShape(z)); ASSERT_TRUE(x1.equalsTo(z)); - delete op; - } @@ -663,8 +652,8 @@ TEST_F(RNGTests, Test_Binomial_2) { RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngB, &x1, 3, 0.5f); - auto op = new sd::ops::LegacyRandomOp(random::BinomialDistributionEx); - auto result = op->execute(_rngA, {&input}, {0.5f}, {3}); + sd::ops::LegacyRandomOp op(random::BinomialDistributionEx); + auto result = op.execute(_rngA, {&input}, {0.5f}, {3}); ASSERT_EQ(Status::OK(), result.status()); @@ -672,9 +661,6 @@ TEST_F(RNGTests, Test_Binomial_2) { ASSERT_TRUE(x1.isSameShape(z)); ASSERT_TRUE(x1.equalsTo(z)); - - delete op; - } @@ -684,8 +670,8 @@ TEST_F(RNGTests, Test_Bernoulli_2) { RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngB, &x1, 0.5f); - auto op = new sd::ops::LegacyRandomOp(random::BernoulliDistribution); - auto result = op->execute(_rngA, {&input}, {0.5f}, {}); + sd::ops::LegacyRandomOp op(random::BernoulliDistribution); + auto result = op.execute(_rngA, {&input}, {0.5f}, {}); ASSERT_EQ(Status::OK(), result.status()); @@ -693,9 +679,6 @@ TEST_F(RNGTests, Test_Bernoulli_2) { ASSERT_TRUE(x1.isSameShape(z)); ASSERT_TRUE(x1.equalsTo(z)); - - delete op; - } TEST_F(RNGTests, Test_GaussianDistribution_1) { From d4ba199fae3ee14fe7c91bba6de439170ecf7e3b Mon Sep 17 00:00:00 2001 From: Oleg Date: Thu, 2 Apr 2020 14:08:21 +0300 Subject: [PATCH 075/233] libnd4j next step of graph topological sorting implementation Signed-off-by: Oleg --- libnd4j/include/graph/Graph.h | 63 ++++- libnd4j/include/graph/impl/Graph.cpp | 230 +++++++++++++----- .../layers_tests/GraphAnalysisTests.cpp | 97 +++++++- 3 files changed, 319 insertions(+), 71 deletions(-) diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 725dfe261935..48a29323c9ab 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -43,6 +43,7 @@ namespace sd { namespace graph { + class NodeInfo; class SD_EXPORT Graph { protected: ExecutorConfiguration _configuration; @@ -178,13 +179,67 @@ namespace sd { /* * Topological graph analysis * @param const start node for search - * @param const reference to list of nodes without external inputs - * @param const node positions in _handler + * @param const reference for nodes infor container * @param operation gather * @return stop iterating */ - bool topolSearch(const int startNode, const std::set& nodeBranches, - const std::unordered_map& positions, OpSequence& opSeq) const; + bool topolSearch(const int startNode, const std::unordered_map& nodesConnections, std::vector>& opSeq) const; + /* + * Optimized graph analysis prototyping, gather nodes infor + * @param reference to node information collector + * @param reference to start nodes + * @param reference to input branching nodes (input branching node - atleast 2 internal inputs) + * @return stop iterating + */ + bool opGraphProto(std::unordered_map& collector, std::set& startNodes, std::set& inBranchingNodes) const; + /* + * Define layers and sequence positions based on nodes infor + * @param reference to node information collector + * @param node ID + * @param layer ID + * @param sequence ID + * @return stop iterating + */ + bool layersSeqDefine(std::unordered_map& collection, int ID, int layer, int nStartSeq) const; + /* + * Initialize container with operations and context + * @param code reference to node information collector + * @param reference to opSequence collector + * @return stop iterating + */ + bool initOpSeqContainer(const std::unordered_map& collection, std::vector>& vOpSeq) const; + + }; + + class NodeInfo{ + private: + std::set sConnections; + bool bStart; + bool bInBranching; + bool bOutBranching; + int nLayer; + int nSequence; + public: + + void setStart(bool bValue){ bStart = bValue; } + void setInBranching(bool bValue){ bInBranching = bValue; } + void setOutBranching(bool bValue){ bOutBranching = bValue; } + + void reset(){ sConnections.clear(); bStart = bInBranching = bOutBranching = false; nLayer = 0; } + + int getLayer() const { return nLayer; } + void setLayer(int layer){ nLayer = layer; } + + int getSequence() const { return nSequence; } + void setSequence(int sequence){ nSequence = sequence; } + + void addConnection(int id){ sConnections.emplace(id); } + const std::set& connections() const { return sConnections; } + + bool isStart() const { return bStart; } + bool isInBranching() const { return bInBranching; } + bool isOutBranching() const { return bOutBranching; } + }; diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 01532ff0666c..76aeaf6abc9e 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -614,65 +614,6 @@ namespace sd { return result; } - bool Graph::topolSearch(const int startNode, const std::set& nodeBranches, const std::unordered_map& positions, OpSequence& opSeq) const { -/* - if (nodeBranches.empty() || _handles.empty()) - return false; - - for (const auto& itNodes : nodeBranches) { - - auto position = positions.find(itNodes); - if (position != positions.end() && startNode != itNodes) { - for (auto in = _handles[position->second]->input()->begin(); in != _handles[position->second]->input()->end(); ++in) { - if (startNode == in->first) { - opSeq.append(_handles[position->second]->getCustomOp(), _handles[position->second]->getContextPrototype()); - return topolSearch(itNodes, nodeBranches, positions, opSeq); - } - } - } - } - - return false; - */ - - throw std::runtime_error("Graph::topolSearch - not implemented yet"); - } - - OptimizedGraph Graph::optimizedGraph() const { - - OptimizedGraph optGraf(const_cast(this)); - /* - OpSequence opSeq; - std::set nodesMap, startNodes; - std::unordered_map iDpositions; - - for (int i = 0; i < _handles.size(); ++i) { - auto ID = _handles[i]->id(); - int iNcouts = 0; - for (auto in = _handles[i]->input()->begin(); in != _handles[i]->input()->end(); ++in) { - if (in->first < 0) { - iNcouts++; - } - } - - if (iNcouts == _handles[i]->input()->size()) - startNodes.insert(ID); - - nodesMap.insert(ID); - iDpositions[ID] = i; - } - - for (const auto& start : startNodes) { - - auto position = iDpositions.find(start); - opSeq.append(_handles[position->second]->getCustomOp(), _handles[position->second]->getContextPrototype()); - topolSearch(start, nodesMap, iDpositions, opSeq); - } - optGraf.append(opSeq); - */ - return optGraf; - } - Graph::Graph(const Graph &other) : _memoryMaager(other._memoryMaager) { _configuration = other._configuration; _variableSpace = other._variableSpace; @@ -730,6 +671,177 @@ namespace sd { const GraphMemoryManager &Graph::memoryManager() const { return _memoryMaager; } + + + bool Graph::opGraphProto(std::unordered_map& collector, std::set& startNodes, std::set& inBranchingNodes) const{ + + if(_unmapped.empty()) + return false; + + for (const auto& it : _unmapped) { + + const auto& ID = it.first; + const auto& inputs = it.second.input(); + + if(collector.find(ID) == collector.end()) + collector[ID] = NodeInfo(); + + NodeInfo& parentNode = collector[ID]; + + int inExCounts = 0, inInternalCounts = 0; + for (auto in = inputs.begin(); in != inputs.end(); ++in) { + if (variableSpace()->hasVariable(in->first, 0)) { + inExCounts++; + } + else{ + inInternalCounts++; + if(collector.find(in->first) == collector.end()) + collector[in->first] = NodeInfo(); + collector[in->first].addConnection(ID); + } + } + + parentNode.setInBranching( (inInternalCounts == inputs.size() && inInternalCounts > 1) ); + + parentNode.setStart(inExCounts == inputs.size() ); + parentNode.setSequence(-1); + + if(parentNode.isStart()){ + parentNode.setLayer(0); + startNodes.emplace(ID); + } + else{ + if(parentNode.isInBranching()) + inBranchingNodes.emplace(ID); + } + } + return true; + } + + bool Graph::topolSearch(const int startNode, const std::unordered_map& collector, + std::vector >& opSeq) const { + + if (_unmapped.empty()) + return false; + + auto itParent = collector.find(startNode); + if(itParent != collector.end()){ + + for (const auto& itNodes : itParent->second.connections() ) { + + auto itChild = collector.find(itNodes); + + if(itChild != collector.end()){ + + if(itChild->second.isInBranching()){ + return true; + } + + const auto it = _unmapped.find(itNodes); + const auto& child = itChild->second; + opSeq[child.getLayer()][child.getSequence()].append( it->second.customOp(), it->second.contextPrototype() ); + + topolSearch(itNodes, collector, opSeq); + } + } + } + return true; + } + + OptimizedGraph Graph::optimizedGraph() const { + + std::unordered_map collector; + std::set startNodes, inBranching; + + OptimizedGraph optGraph(const_cast(this)); + + // todo check this will be empty Optimized graph + if(!opGraphProto(collector, startNodes, inBranching)) + throw std::runtime_error("Graph::optimizedGraph() - not prototyped"); + + int startSeq = 0; + for(const auto& id : startNodes){ + layersSeqDefine(collector, id, 0, startSeq); + startSeq++; + } + + std::vector> vOpSeq; + initOpSeqContainer(collector, vOpSeq); + + startNodes.insert(inBranching.begin(), inBranching.end()); + + for (const auto& id : startNodes) { + + const auto it = _unmapped.find(id); + const auto& nodeInfo = collector[id]; + vOpSeq[nodeInfo.getLayer()][nodeInfo.getSequence()].append(it->second.customOp(), it->second.contextPrototype()); + + topolSearch(id, collector, vOpSeq); + } + + for(auto& vSeq : vOpSeq){ + optGraph.append(vSeq); + } + return optGraph; + } + + bool Graph::initOpSeqContainer(const std::unordered_map& collection, std::vector>& vOpSeq) const { + + if(collection.empty()) + return false; + + int layer = 0; + for(const auto& node : collection){ + if(layer < node.second.getLayer()) + layer = node.second.getLayer(); + } + + vOpSeq.resize(layer + 1); + // each layer will have it own max sequence + for(int i = 0; i <= layer; ++i){ + int nSeq = 0; + for(const auto& node : collection){ + if(nSeq < node.second.getSequence() && i == node.second.getLayer()) + nSeq = node.second.getSequence(); + } + vOpSeq[i].resize(nSeq + 1); + } + + return true; + } + + bool Graph::layersSeqDefine(std::unordered_map& collection, int ID, int layer, int startSeq) const { + + auto parent = collection.find(ID); + if(parent == collection.end()) + return false; + + if( parent->second.isInBranching() ){ + layer++; + startSeq--; + } + + parent->second.setLayer(layer); + parent->second.setSequence(startSeq); + + parent->second.setOutBranching( parent->second.connections().size() > 1 ); + if(parent->second.isOutBranching()) + layer++; + + for(const auto& id : parent->second.connections() ){ + + auto child = collection.find(id); + + layersSeqDefine(collection , id, layer, startSeq); + + if(parent->second.isOutBranching()) + startSeq++; + } + + return true; + } + + } } diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 3919596bd3bd..2f99b3f2c1bb 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -70,8 +70,8 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_1) { // we expect that OpSequence has exactly 2 ops ASSERT_EQ(2, sequence.length()); - ASSERT_EQ(10, sequence.at(0).protoContext().nodeId()); - ASSERT_EQ(20, sequence.at(1).protoContext().nodeId()); + ASSERT_EQ(4, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(5, sequence.at(1).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_2) { @@ -93,6 +93,7 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { Node b(sd::ops::add(), "add"); Node c(sd::ops::subtract(), "subtract"); + graph.addNode(a, {"A", "B"}); graph.addNode(b, {"multiply", "C"}); graph.addNode(c, {"multiply", "D"}); @@ -102,7 +103,7 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { auto optimized = graph.optimizedGraph(); - // we expect that OptimizedGraph has exactly 1 layer + // we expect that OptimizedGraph has exactly 2 layers ASSERT_EQ(2, optimized.layers()); // checking first layer first @@ -115,10 +116,10 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { // we expect that OpSequence has exactly 2 ops ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(10, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); // checking second layer now - auto layer1 = optimized.layer(0); + auto layer1 = optimized.layer(1); // we expect layer has exactly 2 OpSequences ASSERT_EQ(2, layer1.width()); @@ -126,10 +127,90 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { sequence = layer1[0]; ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(20, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); sequence = layer1[1]; ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(30, sequence.at(0).protoContext().nodeId()); -} \ No newline at end of file + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); +} + +TEST_F(GraphAnalysisTests, basic_toposort_test_3) { + Graph graph; + + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // D + graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + Node a(sd::ops::multiply(), "multiply"); + Node b(sd::ops::add(), "add"); + Node c(sd::ops::subtract(), "subtract"); + Node d(sd::ops::add(), "add2"); + Node e(sd::ops::multiply(), "multiply2"); + + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"multiply", "C"}); + + graph.addNode(c, {"add", "D"}); + graph.addNode(d, {"add", "D"}); + + graph.addNode(e, {"subtract", "add2"}); + + // we just check that nodes were really added + ASSERT_EQ(5, graph.size()); + + auto optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(3, optimized.layers()); + + // checking first layer first + auto layer0 = optimized.layer(0); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer0.width()); + auto sequence = layer0[0]; + + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(2, sequence.length()); + + ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(6, sequence.at(1).protoContext().nodeId()); + + // checking second layer now + auto layer1 = optimized.layer(1); + + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.width()); + + sequence = layer1[0]; + + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + + sequence = layer1[1]; + + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + + + // checking last layer + auto layer2 = optimized.layer(2); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer2.width()); + sequence = layer2[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + + ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); +} From 80d58a4ce738bfdbc17ee8d793a7380791456bef Mon Sep 17 00:00:00 2001 From: Oleg Date: Fri, 3 Apr 2020 10:19:25 +0300 Subject: [PATCH 076/233] libnd4j moved graph optimization semantic from graph to optimized graph, fixed several places, add next step of analysis test Signed-off-by: Oleg --- libnd4j/include/graph/Graph.h | 71 +------ libnd4j/include/graph/OptimizedGraph.h | 73 +++++++- libnd4j/include/graph/impl/Graph.cpp | 165 ----------------- libnd4j/include/graph/impl/OptimizedGraph.cpp | 174 ++++++++++++++++++ .../layers_tests/GraphAnalysisTests.cpp | 121 ++++++++++++ 5 files changed, 372 insertions(+), 232 deletions(-) diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 48a29323c9ab..d8926ca0c1ae 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -108,6 +108,11 @@ namespace sd { * @return */ VariableSpace *variableSpace() const; + /** + * This method returns unmapped nodes + * @return + */ + const MAP_IMPL& unmappedNodes() const { return _unmapped; }; const GraphMemoryManager& memoryManager() const; @@ -175,73 +180,7 @@ namespace sd { * @return */ std::map execute(const std::map &dictionary = {}, const std::vector &outputs = {}, const GraphExecutor &executor = GraphExecutor()) const; -protected: - /* - * Topological graph analysis - * @param const start node for search - * @param const reference for nodes infor container - * @param operation gather - * @return stop iterating - */ - bool topolSearch(const int startNode, const std::unordered_map& nodesConnections, std::vector>& opSeq) const; - /* - * Optimized graph analysis prototyping, gather nodes infor - * @param reference to node information collector - * @param reference to start nodes - * @param reference to input branching nodes (input branching node - atleast 2 internal inputs) - * @return stop iterating - */ - bool opGraphProto(std::unordered_map& collector, std::set& startNodes, std::set& inBranchingNodes) const; - /* - * Define layers and sequence positions based on nodes infor - * @param reference to node information collector - * @param node ID - * @param layer ID - * @param sequence ID - * @return stop iterating - */ - bool layersSeqDefine(std::unordered_map& collection, int ID, int layer, int nStartSeq) const; - /* - * Initialize container with operations and context - * @param code reference to node information collector - * @param reference to opSequence collector - * @return stop iterating - */ - bool initOpSeqContainer(const std::unordered_map& collection, std::vector>& vOpSeq) const; - }; - - class NodeInfo{ - private: - std::set sConnections; - bool bStart; - bool bInBranching; - bool bOutBranching; - int nLayer; - int nSequence; - public: - - void setStart(bool bValue){ bStart = bValue; } - void setInBranching(bool bValue){ bInBranching = bValue; } - void setOutBranching(bool bValue){ bOutBranching = bValue; } - - void reset(){ sConnections.clear(); bStart = bInBranching = bOutBranching = false; nLayer = 0; } - - int getLayer() const { return nLayer; } - void setLayer(int layer){ nLayer = layer; } - - int getSequence() const { return nSequence; } - void setSequence(int sequence){ nSequence = sequence; } - - void addConnection(int id){ sConnections.emplace(id); } - const std::set& connections() const { return sConnections; } - - bool isStart() const { return bStart; } - bool isInBranching() const { return bInBranching; } - bool isOutBranching() const { return bOutBranching; } - - }; - FORCEINLINE bool Graph::built() { return _built.load(); diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index 6048bff464ff..fc1c4f105f7e 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -30,7 +30,7 @@ namespace sd { namespace graph { class Graph; - + class NodeInfo; /** * This class acts as a topologically sorted & optimized Graph representation, ready for execution */ @@ -92,7 +92,78 @@ namespace sd { * @return */ const Graph& originalGraph() const; + + protected: + /* + * optimize original graph + */ + void optimizedGraph(); + /* + * Topological graph analysis + * @param const start node for search + * @param const reference for nodes infor container + * @param operation gather + * @return stop iterating + */ + bool topolSearch(const int startNode, const std::unordered_map& nodesConnections, std::vector>& opSeq) const; + /* + * Optimized graph analysis prototyping, gather nodes infor + * @param reference to node information collector + * @param reference to start nodes + * @param reference to input branching nodes (input branching node - atleast 2 internal inputs) + * @return stop iterating + */ + bool opGraphProto(std::unordered_map& collector, std::set& startNodes, std::set& inBranchingNodes) const; + /* + * Define layers and sequence positions based on nodes infor + * @param reference to node information collector + * @param node ID + * @param layer ID + * @param sequence ID + * @return stop iterating + */ + bool layersSeqDefine(std::unordered_map& collection, int ID, int layer, int nStartSeq) const; + /* + * Initialize container with operations and context + * @param code reference to node information collector + * @param reference to opSequence collector + * @return stop iterating + */ + bool initOpSeqContainer(const std::unordered_map& collection, std::vector>& vOpSeq) const; + }; + + class NodeInfo { + private: + std::set sConnections; + bool bStart; + bool bInBranching; + bool bOutBranching; + int nLayer; + int nSequence; + public: + + void setStart(bool bValue) { bStart = bValue; } + void setInBranching(bool bValue) { bInBranching = bValue; } + void setOutBranching(bool bValue) { bOutBranching = bValue; } + + void reset() { sConnections.clear(); bStart = bInBranching = bOutBranching = false; nLayer = 0; } + + int getLayer() const { return nLayer; } + void setLayer(int layer) { nLayer = layer; } + + int getSequence() const { return nSequence; } + void setSequence(int sequence) { nSequence = sequence; } + + void addConnection(int id) { sConnections.emplace(id); } + const std::set& connections() const { return sConnections; } + + bool isStart() const { return bStart; } + bool isInBranching() const { return bInBranching; } + bool isOutBranching() const { return bOutBranching; } + + }; + } } diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 76aeaf6abc9e..8176a683be05 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -672,176 +672,11 @@ namespace sd { return _memoryMaager; } - - bool Graph::opGraphProto(std::unordered_map& collector, std::set& startNodes, std::set& inBranchingNodes) const{ - - if(_unmapped.empty()) - return false; - - for (const auto& it : _unmapped) { - - const auto& ID = it.first; - const auto& inputs = it.second.input(); - - if(collector.find(ID) == collector.end()) - collector[ID] = NodeInfo(); - - NodeInfo& parentNode = collector[ID]; - - int inExCounts = 0, inInternalCounts = 0; - for (auto in = inputs.begin(); in != inputs.end(); ++in) { - if (variableSpace()->hasVariable(in->first, 0)) { - inExCounts++; - } - else{ - inInternalCounts++; - if(collector.find(in->first) == collector.end()) - collector[in->first] = NodeInfo(); - collector[in->first].addConnection(ID); - } - } - - parentNode.setInBranching( (inInternalCounts == inputs.size() && inInternalCounts > 1) ); - - parentNode.setStart(inExCounts == inputs.size() ); - parentNode.setSequence(-1); - - if(parentNode.isStart()){ - parentNode.setLayer(0); - startNodes.emplace(ID); - } - else{ - if(parentNode.isInBranching()) - inBranchingNodes.emplace(ID); - } - } - return true; - } - - bool Graph::topolSearch(const int startNode, const std::unordered_map& collector, - std::vector >& opSeq) const { - - if (_unmapped.empty()) - return false; - - auto itParent = collector.find(startNode); - if(itParent != collector.end()){ - - for (const auto& itNodes : itParent->second.connections() ) { - - auto itChild = collector.find(itNodes); - - if(itChild != collector.end()){ - - if(itChild->second.isInBranching()){ - return true; - } - - const auto it = _unmapped.find(itNodes); - const auto& child = itChild->second; - opSeq[child.getLayer()][child.getSequence()].append( it->second.customOp(), it->second.contextPrototype() ); - - topolSearch(itNodes, collector, opSeq); - } - } - } - return true; - } - OptimizedGraph Graph::optimizedGraph() const { - std::unordered_map collector; - std::set startNodes, inBranching; - OptimizedGraph optGraph(const_cast(this)); - - // todo check this will be empty Optimized graph - if(!opGraphProto(collector, startNodes, inBranching)) - throw std::runtime_error("Graph::optimizedGraph() - not prototyped"); - - int startSeq = 0; - for(const auto& id : startNodes){ - layersSeqDefine(collector, id, 0, startSeq); - startSeq++; - } - - std::vector> vOpSeq; - initOpSeqContainer(collector, vOpSeq); - - startNodes.insert(inBranching.begin(), inBranching.end()); - - for (const auto& id : startNodes) { - - const auto it = _unmapped.find(id); - const auto& nodeInfo = collector[id]; - vOpSeq[nodeInfo.getLayer()][nodeInfo.getSequence()].append(it->second.customOp(), it->second.contextPrototype()); - - topolSearch(id, collector, vOpSeq); - } - - for(auto& vSeq : vOpSeq){ - optGraph.append(vSeq); - } return optGraph; } - - bool Graph::initOpSeqContainer(const std::unordered_map& collection, std::vector>& vOpSeq) const { - - if(collection.empty()) - return false; - - int layer = 0; - for(const auto& node : collection){ - if(layer < node.second.getLayer()) - layer = node.second.getLayer(); - } - - vOpSeq.resize(layer + 1); - // each layer will have it own max sequence - for(int i = 0; i <= layer; ++i){ - int nSeq = 0; - for(const auto& node : collection){ - if(nSeq < node.second.getSequence() && i == node.second.getLayer()) - nSeq = node.second.getSequence(); - } - vOpSeq[i].resize(nSeq + 1); - } - - return true; - } - - bool Graph::layersSeqDefine(std::unordered_map& collection, int ID, int layer, int startSeq) const { - - auto parent = collection.find(ID); - if(parent == collection.end()) - return false; - - if( parent->second.isInBranching() ){ - layer++; - startSeq--; - } - - parent->second.setLayer(layer); - parent->second.setSequence(startSeq); - - parent->second.setOutBranching( parent->second.connections().size() > 1 ); - if(parent->second.isOutBranching()) - layer++; - - for(const auto& id : parent->second.connections() ){ - - auto child = collection.find(id); - - layersSeqDefine(collection , id, layer, startSeq); - - if(parent->second.isOutBranching()) - startSeq++; - } - - return true; - } - - } } diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 28db8145c286..769a1eaf7b09 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -26,6 +26,7 @@ namespace sd { OptimizedGraph::OptimizedGraph(Graph *original) { _originalGraph = original; _memoryManager = const_cast(&original->memoryManager()); + optimizedGraph(); } OptimizedGraph::OptimizedGraph(const OptimizedGraph &other) noexcept { @@ -87,5 +88,178 @@ namespace sd { const Graph &OptimizedGraph::originalGraph() const { return *_originalGraph; } + + + bool OptimizedGraph::opGraphProto(std::unordered_map& collector, std::set& startNodes, std::set& inBranchingNodes) const { + + if (originalGraph().unmappedNodes().empty()) + return false; + + for (const auto& it : originalGraph().unmappedNodes()) { + + const auto& ID = it.first; + const auto& inputs = it.second.input(); + + if (collector.find(ID) == collector.end()) + collector[ID] = NodeInfo(); + + NodeInfo& parentNode = collector[ID]; + + int inExCounts = 0, inInternalCounts = 0; + for (auto in = inputs.begin(); in != inputs.end(); ++in) { + if (originalGraph().variableSpace()->hasVariable(in->first, 0)) { + inExCounts++; + } + else { + inInternalCounts++; + if (collector.find(in->first) == collector.end()) + collector[in->first] = NodeInfo(); + collector[in->first].addConnection(ID); + } + } + + parentNode.setInBranching((inInternalCounts == inputs.size() && inInternalCounts > 1)); + + parentNode.setStart(inExCounts == inputs.size()); + parentNode.setSequence(-1); + + if (parentNode.isStart()) { + parentNode.setLayer(0); + startNodes.emplace(ID); + } + else { + if (parentNode.isInBranching()) + inBranchingNodes.emplace(ID); + } + } + return true; + } + + bool OptimizedGraph::topolSearch(const int startNode, const std::unordered_map& collector, + std::vector >& opSeq) const { + + if (originalGraph().unmappedNodes().empty()) + return false; + + auto itParent = collector.find(startNode); + if (itParent != collector.end()) { + + for (const auto& itNodes : itParent->second.connections()) { + + auto itChild = collector.find(itNodes); + + if (itChild != collector.end()) { + + if (itChild->second.isInBranching()) { + return true; + } + + const auto it = originalGraph().unmappedNodes().find(itNodes); + const auto& child = itChild->second; + opSeq[child.getLayer()][child.getSequence()].append(it->second.customOp(), it->second.contextPrototype()); + + topolSearch(itNodes, collector, opSeq); + } + } + } + return true; + } + + void OptimizedGraph::optimizedGraph() { + + std::unordered_map collector; + std::set startNodes, inBranching; + + // todo check this will be empty Optimized graph + if (!opGraphProto(collector, startNodes, inBranching)) + throw std::runtime_error("OptimizedGraph::optimizedGraph() - not prototyped"); + + int startSeq = 0; + for (const auto& id : startNodes) { + layersSeqDefine(collector, id, 0, startSeq); + startSeq++; + } + + std::vector> vOpSeq; + initOpSeqContainer(collector, vOpSeq); + + startNodes.insert(inBranching.begin(), inBranching.end()); + + for (const auto& id : startNodes) { + + const auto it = originalGraph().unmappedNodes().find(id); + const auto& nodeInfo = collector[id]; + vOpSeq[nodeInfo.getLayer()][nodeInfo.getSequence()].append(it->second.customOp(), it->second.contextPrototype()); + + topolSearch(id, collector, vOpSeq); + } + + for (auto& vSeq : vOpSeq) { + this->append(vSeq); + } + } + + bool OptimizedGraph::initOpSeqContainer(const std::unordered_map& collection, std::vector>& vOpSeq) const { + + if (collection.empty()) + return false; + + int layer = 0; + std::vector vSeq; + for (const auto& node : collection) { + + int nodeLayer = node.second.getLayer(); + int nodeSeq = node.second.getSequence(); + if (layer < nodeLayer) + layer = nodeLayer; + + if (vSeq.size() < nodeLayer + 1) { + vSeq.resize(nodeLayer + 1, 0); + } + // each layer will have it own max sequence + if (vSeq[nodeLayer] < nodeSeq) + vSeq[nodeLayer] = nodeSeq; + } + + vOpSeq.resize(layer + 1); + for (int i = 0; i <= layer; ++i) { + vOpSeq[i].resize(vSeq[i] + 1); + } + return true; + } + + bool OptimizedGraph::layersSeqDefine(std::unordered_map& collection, int ID, int layer, int startSeq) const { + + auto parent = collection.find(ID); + if (parent == collection.end()) + return false; + + if (parent->second.isInBranching()) { + layer++; + if (startSeq > 0) + startSeq--; + } + + parent->second.setLayer(layer); + // sequence have to be init once + if (parent->second.getSequence() < 0) + parent->second.setSequence(startSeq); + + parent->second.setOutBranching(parent->second.connections().size() > 1); + if (parent->second.isOutBranching()) + layer++; + + for (const auto& id : parent->second.connections()) { + + auto child = collection.find(id); + + layersSeqDefine(collection, id, layer, startSeq); + + if (parent->second.isOutBranching()) + startSeq++; + } + + return true; + } } } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 2f99b3f2c1bb..e7450dda1432 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -214,3 +214,124 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_3) { ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); } + +// currently does not work correctly +TEST_F(GraphAnalysisTests, DISABLED_basic_toposort_test_4) { + Graph graph; + + // A + graph.addVariable("A", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + + // B + graph.addVariable("B", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + + // C + graph.addVariable("C", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + + // D + graph.addVariable("D", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + + // E + graph.addVariable("E", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + + // F + graph.addVariable("F", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + + + Node a1(sd::ops::multiply(), "a1"); + Node a2(sd::ops::add(), "a2"); + + Node b1(sd::ops::subtract(), "b1"); + Node b2(sd::ops::add(), "b2"); + Node b3(sd::ops::multiply(), "b3"); + + Node d1(sd::ops::multiply(), "d1"); + Node d2(sd::ops::add(), "d2"); + + Node e(sd::ops::subtract(), "e"); + + graph.addNode(a1, { "A", "B" }); + graph.addNode(a2, { "C", "D" }); + + graph.addNode(b1, { "a1", "E" }); + graph.addNode(b2, { "a1", "a2" }); + graph.addNode(b3, { "a2", "F" }); + + graph.addNode(d1, { "b1", "b2" }); + graph.addNode(d2, { "b3", "b2" }); + + graph.addNode(e, { "d1", "d2" }); + + // we just check that nodes were really added + ASSERT_EQ(8, graph.size()); + + auto optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(4, optimized.layers()); + + // checking first layer first + auto layer0 = optimized.layer(0); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(2, layer0.width()); + auto sequence = layer0[0]; + + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, sequence.length()); + + sequence = layer0[1]; + + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, sequence.length()); + + // ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + // ASSERT_EQ(6, sequence.at(1).protoContext().nodeId()); + + // checking second layer now + auto layer1 = optimized.layer(1); + + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(3, layer1.width()); + + sequence = layer1[0]; + + ASSERT_EQ(1, sequence.length()); + // ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + + sequence = layer1[1]; + + ASSERT_EQ(1, sequence.length()); + // ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + sequence = layer1[2]; + + ASSERT_EQ(1, sequence.length()); + + auto layer2 = optimized.layer(2); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(2, layer2.width()); + sequence = layer2[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + + // ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); + + sequence = layer2[1]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + // ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); + + // checking last layer + auto layer3 = optimized.layer(3); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer2.width()); + sequence = layer2[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + // ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); +} From 7aeb96e3147ac1831e4e906ecadf6977d1a56990 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 3 Apr 2020 10:34:48 +0300 Subject: [PATCH 077/233] internal changes Signed-off-by: raver119 --- libnd4j/include/array/NDArray.hXX | 6 +- libnd4j/include/array/NDArrayList.h | 3 + libnd4j/include/array/ResultSet.h | 8 +- libnd4j/include/array/impl/NDArrayList.cpp | 14 +- libnd4j/include/array/impl/ResultSet.cpp | 18 +- libnd4j/include/graph/Context.h | 90 +-- libnd4j/include/graph/FlatUtils.h | 2 +- libnd4j/include/graph/Graph.h | 14 +- libnd4j/include/graph/InferenceRequest.h | 19 +- libnd4j/include/graph/SessionLocalStorage.h | 65 -- libnd4j/include/graph/Variable.h | 38 +- libnd4j/include/graph/VariableProxy.h | 69 +- libnd4j/include/graph/VariableSpace.h | 113 ++- .../graph/execution/impl/GraphExecutor.cpp | 6 +- libnd4j/include/graph/impl/Context.cpp | 206 ++--- libnd4j/include/graph/impl/FlatUtils.cpp | 8 +- libnd4j/include/graph/impl/Graph.cpp | 61 +- .../include/graph/impl/InferenceRequest.cpp | 28 +- .../graph/impl/SessionLocalStorage.cpp | 123 --- libnd4j/include/graph/impl/Variable.cpp | 81 +- libnd4j/include/graph/impl/VariableProxy.cpp | 133 ++-- libnd4j/include/graph/impl/VariableSpace.cpp | 332 ++++----- .../include/graph/logic/impl/LogicExit.cpp | 4 +- .../graph/logic/impl/LogicLoopCond.cpp | 3 + .../profiling/impl/GraphProfilingHelper.cpp | 15 +- libnd4j/include/helpers/OpBenchmark.h | 28 +- libnd4j/include/helpers/ShapeUtils.h | 2 + .../helpers/benchmark/BroadcastBenchmark.h | 43 +- .../helpers/benchmark/DeclarableBenchmark.h | 14 +- .../helpers/benchmark/MatrixBenchmark.h | 34 +- .../helpers/benchmark/PairwiseBenchmark.h | 32 +- .../helpers/benchmark/ReductionBenchmark.h | 31 +- .../helpers/benchmark/ScalarBenchmark.h | 29 +- .../helpers/benchmark/TransformBenchmark.h | 34 +- .../include/helpers/impl/BenchmarkHelper.cpp | 10 +- libnd4j/include/helpers/impl/GradCheck.cpp | 10 +- libnd4j/include/helpers/impl/OpBenchmark.cpp | 42 +- libnd4j/include/helpers/impl/ShapeUtils.cpp | 8 + libnd4j/include/legacy/cpu/NativeOps.cpp | 20 +- .../ops/declarable/generic/blas/svd.cpp | 4 +- .../ops/declarable/generic/boolean/select.cpp | 4 +- .../ops/declarable/generic/boolean/where.cpp | 6 +- .../declarable/generic/boolean/where_np.cpp | 10 +- .../generic/broadcastable/meshgrid.cpp | 2 +- .../generic/broadcastable/multiply.cpp | 4 +- .../generic/broadcastable/percentile.cpp | 2 +- .../declarable/generic/broadcastable/pow.cpp | 2 +- .../generic/flow/flow_control_ops.cpp | 5 +- .../generic/images/extract_image_patches.cpp | 2 +- .../generic/images/image_resize.cpp | 2 +- .../declarable/generic/images/resize_area.cpp | 2 +- .../generic/images/resize_bicubic.cpp | 2 +- .../generic/images/resize_linear.cpp | 2 +- .../generic/images/resize_neighbor.cpp | 2 +- .../declarable/generic/linalg/diagPart.cpp | 2 +- .../ops/declarable/generic/linalg/eye.cpp | 8 +- .../generic/linalg/matrixDiagPart.cpp | 2 +- .../declarable/generic/linalg/matrix_diag.cpp | 2 +- .../ops/declarable/generic/linalg/trace.cpp | 4 +- .../ops/declarable/generic/linalg/triu.cpp | 4 +- .../declarable/generic/list/create_list.cpp | 4 +- .../declarable/generic/list/scatter_list.cpp | 5 +- .../declarable/generic/list/split_list.cpp | 4 +- .../generic/loss/absoluteDifference.cpp | 6 +- .../generic/loss/cosineDistance.cpp | 12 +- .../ops/declarable/generic/loss/hingeLoss.cpp | 8 +- .../ops/declarable/generic/loss/huberLoss.cpp | 8 +- .../ops/declarable/generic/loss/logLoss.cpp | 6 +- .../generic/loss/log_poisson_loss.cpp | 12 +- .../generic/loss/meanPairWsSqErr.cpp | 8 +- .../ops/declarable/generic/loss/meanSqErr.cpp | 6 +- .../generic/loss/sigmCrossEntropy.cpp | 6 +- .../generic/loss/softmaxCrossEntropy.cpp | 6 +- .../loss/softmaxCrossEntropyWithLogits.cpp | 2 +- .../sparseSoftmaxCrossEntropyWithLogits.cpp | 4 +- .../ops/declarable/generic/nlp/skipgram.cpp | 2 +- .../generic/nn/activations/crelu.cpp | 6 +- .../generic/nn/activations/identity_n.cpp | 2 +- .../ops/declarable/generic/nn/batchnorm.cpp | 2 +- .../declarable/generic/nn/convo/col2im.cpp | 2 +- .../declarable/generic/nn/convo/conv1d.cpp | 8 +- .../declarable/generic/nn/convo/conv2d.cpp | 10 +- .../declarable/generic/nn/convo/conv3d.cpp | 8 +- .../declarable/generic/nn/convo/deconv2d.cpp | 6 +- .../declarable/generic/nn/convo/deconv3d.cpp | 8 +- .../generic/nn/convo/depthwiseConv2d.cpp | 8 +- .../declarable/generic/nn/convo/im2col.cpp | 2 +- .../generic/nn/convo/pointwiseConv2d.cpp | 2 +- .../declarable/generic/nn/convo/sconv2d.cpp | 14 +- .../generic/nn/convo/upsampling2d.cpp | 4 +- .../generic/nn/convo/upsampling3d.cpp | 4 +- .../generic/nn/embedding_lookup.cpp | 4 +- .../generic/nn/pooling/pnormpool2d.cpp | 2 +- .../nn/recurrent/dynamicBidirectionalRNN.cpp | 12 +- .../generic/nn/recurrent/dynamicRNN.cpp | 4 +- .../declarable/generic/nn/recurrent/gru.cpp | 2 +- .../generic/nn/recurrent/gruCell.cpp | 2 +- .../declarable/generic/nn/recurrent/lstm.cpp | 4 +- .../generic/nn/recurrent/lstmBlock.cpp | 2 +- .../generic/nn/recurrent/lstmBlockCell.cpp | 2 +- .../generic/nn/recurrent/lstmCell.cpp | 4 +- .../declarable/generic/nn/recurrent/sru.cpp | 10 +- .../generic/nn/recurrent/sruCell.cpp | 4 +- .../nn/recurrent/staticBidirectionalRNN.cpp | 6 +- .../generic/nn/recurrent/staticRNN.cpp | 4 +- .../ops/declarable/generic/nn/relu_layer.cpp | 2 +- .../ops/declarable/generic/nn/xw_plus_b.cpp | 4 +- .../declarable/generic/parity_ops/expose.cpp | 3 +- .../generic/parity_ops/normalize_moments.cpp | 4 +- .../generic/parity_ops/nth_element.cpp | 2 +- .../generic/parity_ops/segment_max.cpp | 2 +- .../generic/parity_ops/segment_mean.cpp | 2 +- .../generic/parity_ops/segment_min.cpp | 2 +- .../generic/parity_ops/segment_prod.cpp | 2 +- .../generic/parity_ops/segment_sum.cpp | 2 +- .../generic/parity_ops/sequence_mask.cpp | 2 +- .../declarable/generic/parity_ops/top_k.cpp | 4 +- .../declarable/generic/parity_ops/unique.cpp | 2 +- .../parity_ops/unsorted_segment_max.cpp | 2 +- .../parity_ops/unsorted_segment_mean.cpp | 2 +- .../parity_ops/unsorted_segment_min.cpp | 2 +- .../parity_ops/unsorted_segment_prod.cpp | 2 +- .../parity_ops/unsorted_segment_sqrt_n.cpp | 2 +- .../parity_ops/unsorted_segment_sum.cpp | 2 +- .../ops/declarable/generic/reduce/argmax.cpp | 2 +- .../ops/declarable/generic/reduce/argmin.cpp | 2 +- .../declarable/generic/reduce/reduceMean.cpp | 4 +- .../declarable/generic/reduce/reduceStDev.cpp | 4 +- .../generic/reduce/reduceVariance.cpp | 4 +- .../declarable/generic/reduce/reduce_dot.cpp | 2 +- .../generic/reduce/reduce_logsumexp.cpp | 2 +- .../declarable/generic/reduce/reduce_max.cpp | 2 +- .../declarable/generic/reduce/reduce_min.cpp | 2 +- .../generic/reduce/reduce_norm1.cpp | 4 +- .../generic/reduce/reduce_norm2.cpp | 4 +- .../generic/reduce/reduce_norm_max.cpp | 2 +- .../declarable/generic/reduce/reduce_prod.cpp | 4 +- .../generic/reduce/reduce_sqnorm.cpp | 4 +- .../declarable/generic/reduce/reduce_sum.cpp | 4 +- .../ops/declarable/generic/tensor/fill.cpp | 2 +- .../generic/tensor/strided_slice.cpp | 4 +- .../declarable/generic/tests/test_scalar.cpp | 2 +- .../declarable/generic/tests/testcustom.cpp | 4 +- .../generic/thrid_party/firas_sparse.cpp | 4 +- .../declarable/generic/transforms/concat.cpp | 10 +- .../generic/transforms/dynamic_parititon.cpp | 6 +- .../declarable/generic/transforms/gather.cpp | 6 +- .../generic/transforms/gatherNd.cpp | 2 +- .../generic/transforms/mirrorPad.cpp | 2 +- .../ops/declarable/generic/transforms/pad.cpp | 4 +- .../generic/transforms/parallelStack.cpp | 2 +- .../generic/transforms/scatter_nd.cpp | 2 +- .../declarable/generic/transforms/slice.cpp | 4 +- .../declarable/generic/tsne/edge_force.cpp | 2 +- .../declarable/generic/tsne/symmetrized.cpp | 6 +- .../ops/declarable/helpers/cpu/clip.cpp | 16 +- .../ops/declarable/helpers/cpu/confusion.cpp | 2 +- .../ops/declarable/helpers/cpu/cross.cpp | 2 +- .../ops/declarable/helpers/cpu/dynamic.cpp | 6 +- .../helpers/cpu/extract_patches.cpp | 2 +- .../ops/declarable/helpers/cpu/eye.cpp | 2 +- .../ops/declarable/helpers/cpu/lstsq.cpp | 2 +- .../ops/declarable/helpers/cpu/lup.cpp | 36 +- .../declarable/helpers/cpu/matrix_band.cpp | 16 +- .../helpers/cpu/matrix_diag_part.cpp | 2 +- .../ops/declarable/helpers/cpu/meshgrid.cpp | 2 +- .../declarable/helpers/cpu/nth_element.cpp | 2 +- .../ops/declarable/helpers/cpu/percentile.cpp | 4 +- .../ops/declarable/helpers/cpu/prefix.cpp | 2 +- .../include/ops/declarable/helpers/cpu/qr.cpp | 2 +- .../declarable/helpers/cpu/randomShuffle.cpp | 8 +- .../ops/declarable/helpers/cpu/reverse.cpp | 16 +- .../ops/declarable/helpers/cpu/roll.cpp | 6 +- .../ops/declarable/helpers/cpu/segment.cpp | 128 ++-- .../ops/declarable/helpers/cpu/solve.cpp | 10 +- .../ops/declarable/helpers/cpu/svd.cpp | 8 +- .../ops/declarable/helpers/cpu/trace.cpp | 2 +- .../helpers/cpu/triangular_solve.cpp | 8 +- .../include/ops/declarable/helpers/cross.h | 2 +- .../ops/declarable/helpers/cuda/lup.cu | 4 +- .../ops/declarable/helpers/impl/lstmLayer.cpp | 70 +- .../declarable/helpers/impl/multiUnique.cpp | 6 +- .../include/ops/declarable/impl/BooleanOp.cpp | 7 +- .../ops/declarable/impl/DeclarableListOp.cpp | 29 +- .../ops/declarable/impl/DeclarableOp.cpp | 94 +-- .../declarable/impl/DeclarableReductionOp.cpp | 2 +- .../ops/declarable/impl/LegacyBroadcastOp.cpp | 2 +- .../declarable/impl/LegacyIndexReduceOp.cpp | 12 +- .../ops/declarable/impl/LegacyRandomOp.cpp | 18 +- .../ops/declarable/impl/LegacyReduce3Op.cpp | 2 +- .../ops/declarable/impl/LegacyReduceOp.cpp | 6 +- .../ops/declarable/impl/LegacyStatsOp.cpp | 2 +- .../ops/declarable/impl/PlatformHelper.cpp | 13 +- .../benchmarking/impl/FullBenchmarkSuit.cpp | 534 +++++++------ .../benchmarking/impl/LightBenchmarkSuit.cpp | 196 ++--- libnd4j/include/system/op_boilerplate.h | 6 +- .../layers_tests/BooleanOpsTests.cpp | 2 +- .../layers_tests/BroadcastableOpsTests.cpp | 12 +- .../tests_cpu/layers_tests/ContextTests.cpp | 54 +- .../layers_tests/ConvolutionTests1.cpp | 94 +-- .../layers_tests/ConvolutionTests2.cpp | 112 +-- .../layers_tests/DataTypesValidationTests.cpp | 12 +- .../layers_tests/DeclarableOpsTests1.cpp | 564 +++++++------- .../layers_tests/DeclarableOpsTests10.cpp | 192 ++--- .../layers_tests/DeclarableOpsTests11.cpp | 404 +++++----- .../layers_tests/DeclarableOpsTests12.cpp | 204 +++-- .../layers_tests/DeclarableOpsTests13.cpp | 135 ++-- .../layers_tests/DeclarableOpsTests14.cpp | 135 ++-- .../layers_tests/DeclarableOpsTests15.cpp | 264 +++---- .../layers_tests/DeclarableOpsTests16.cpp | 110 +-- .../layers_tests/DeclarableOpsTests17.cpp | 4 +- .../layers_tests/DeclarableOpsTests18.cpp | 2 +- .../layers_tests/DeclarableOpsTests2.cpp | 702 +++++++++--------- .../layers_tests/DeclarableOpsTests3.cpp | 192 ++--- .../layers_tests/DeclarableOpsTests4.cpp | 46 +- .../layers_tests/DeclarableOpsTests5.cpp | 90 ++- .../layers_tests/DeclarableOpsTests6.cpp | 62 +- .../layers_tests/DeclarableOpsTests7.cpp | 42 +- .../layers_tests/DeclarableOpsTests8.cpp | 22 +- .../layers_tests/DeclarableOpsTests9.cpp | 78 +- libnd4j/tests_cpu/layers_tests/EmptyTests.cpp | 51 +- .../tests_cpu/layers_tests/FlatUtilsTests.cpp | 16 +- .../layers_tests/GraphExecutorTests.cpp | 12 +- libnd4j/tests_cpu/layers_tests/GraphTests.cpp | 267 ++++--- .../tests_cpu/layers_tests/GraphTests2.cpp | 6 +- .../tests_cpu/layers_tests/IndexingTests.cpp | 24 +- .../tests_cpu/layers_tests/LegacyOpsTests.cpp | 12 +- .../layers_tests/ListOperationsTests.cpp | 39 +- .../layers_tests/MultiDataTypeTests.cpp | 2 +- .../tests_cpu/layers_tests/OneOffTests.cpp | 46 +- .../tests_cpu/layers_tests/ParityOpsTests.cpp | 26 +- libnd4j/tests_cpu/layers_tests/RNGTests.cpp | 36 +- .../tests_cpu/layers_tests/ResultSetTests.cpp | 4 +- .../tests_cpu/layers_tests/SanityTests.cpp | 22 +- .../layers_tests/SessionLocalTests.cpp | 93 --- .../tests_cpu/layers_tests/SingleDimTests.cpp | 2 +- .../layers_tests/VariableProxyTests.cpp | 53 +- .../layers_tests/VariableSpaceTests.cpp | 121 +-- .../tests_cpu/layers_tests/VariableTests.cpp | 58 +- 239 files changed, 3548 insertions(+), 4434 deletions(-) delete mode 100644 libnd4j/include/graph/SessionLocalStorage.h delete mode 100644 libnd4j/include/graph/impl/SessionLocalStorage.cpp delete mode 100644 libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp diff --git a/libnd4j/include/array/NDArray.hXX b/libnd4j/include/array/NDArray.hXX index 86e097b76efa..d280cec37f79 100644 --- a/libnd4j/include/array/NDArray.hXX +++ b/libnd4j/include/array/NDArray.hXX @@ -4706,7 +4706,7 @@ ResultSet NDArray::multipleTensorsAlongDimension(const std::vector &indices throw std::runtime_error("Bad index"); } - auto array = new NDArray(getDataBuffer(), ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + getBufferOffset()); + NDArray array(getDataBuffer(), ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + getBufferOffset()); result.push_back(array); } @@ -4820,8 +4820,8 @@ ResultSet NDArray::allTensorsAlongDimension(const std::vector &dimensions) auto numTads = pack.numberOfTads(); for (Nd4jLong idx = 0; idx < numTads; idx++ ) { - auto array = new NDArray(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + getBufferOffset()); - array->_isView = true; + NDArray array(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + getBufferOffset()); + array._isView = true; result.push_back(array); } diff --git a/libnd4j/include/array/NDArrayList.h b/libnd4j/include/array/NDArrayList.h index 661422de914d..63aaf2beafc0 100644 --- a/libnd4j/include/array/NDArrayList.h +++ b/libnd4j/include/array/NDArrayList.h @@ -65,6 +65,9 @@ namespace sd { NDArrayList(int height, bool expandable = false); ~NDArrayList(); + NDArrayList(const sd::NDArrayList &other); + NDArrayList(sd::NDArrayList &&other); + sd::DataType dataType(); NDArray* read(int idx); diff --git a/libnd4j/include/array/ResultSet.h b/libnd4j/include/array/ResultSet.h index 916ae0d41fe3..6883a14c105a 100644 --- a/libnd4j/include/array/ResultSet.h +++ b/libnd4j/include/array/ResultSet.h @@ -36,7 +36,7 @@ namespace sd { class SD_EXPORT ResultSet { private: - std::vector _content; + std::vector _content; Nd4jStatus _status = ND4J_STATUS_OK; bool _removable = true; @@ -62,9 +62,9 @@ namespace sd { ~ResultSet(); int size(); - sd::NDArray* at(const unsigned long idx) const; - sd::NDArray* operator[](const unsigned long idx) const; - void push_back(sd::NDArray* array); + sd::NDArray& at(const unsigned long idx) const; + sd::NDArray& operator[](const unsigned long idx) const; + void push_back(const sd::NDArray &array); Nd4jStatus status(); void setStatus(Nd4jStatus status); diff --git a/libnd4j/include/array/impl/NDArrayList.cpp b/libnd4j/include/array/impl/NDArrayList.cpp index ecd4bcacac96..eba1288f8d39 100644 --- a/libnd4j/include/array/impl/NDArrayList.cpp +++ b/libnd4j/include/array/impl/NDArrayList.cpp @@ -26,6 +26,14 @@ #include namespace sd { + NDArrayList::NDArrayList(const NDArrayList &other) { + + } + + NDArrayList::NDArrayList(NDArrayList &&other) { + + } + NDArrayList::NDArrayList(int height, bool expandable) { _expandable = expandable; _elements.store(0); @@ -138,8 +146,8 @@ namespace sd { auto newAxis = ShapeUtils::evalDimsToExclude(array->rankOf(), args); auto result = array->allTensorsAlongDimension(newAxis); for (int e = 0; e < result.size(); e++) { - auto chunk = result.at(e);//->dup(array->ordering()); - write(e, new NDArray(chunk->dup(array->ordering()))); + auto chunk = result.at(e); + write(e, new NDArray(chunk.dup(array->ordering()))); } } @@ -229,7 +237,7 @@ namespace sd { throw std::runtime_error("Number of TADs should match number of indices"); for (int e = 0; e < indicesSize; e++) - tads.at(e)->assign(_chunks[indices[e]]); + tads.at(e).assign(_chunks[indices[e]]); return array; } diff --git a/libnd4j/include/array/impl/ResultSet.cpp b/libnd4j/include/array/impl/ResultSet.cpp index d9d824d4689e..53cde399e6b8 100644 --- a/libnd4j/include/array/impl/ResultSet.cpp +++ b/libnd4j/include/array/impl/ResultSet.cpp @@ -30,7 +30,7 @@ namespace sd { for (int e = 0; e < result->variables()->size(); e++) { auto var = result->variables()->Get(e); - NDArray* array; + NDArray array; if (var->ndarray() != nullptr) { array = sd::graph::FlatUtils::fromFlatArray(var->ndarray()); @@ -48,7 +48,7 @@ namespace sd { shape.emplace_back(shapeInfo.at(i + 1)); } - array = new NDArray((char) shapeInfo.at(shapeInfo.size() - 1), shape, DataTypeUtils::fromFlatDataType(var->dtype())); + array = NDArray((char) shapeInfo.at(shapeInfo.size() - 1), shape, DataTypeUtils::fromFlatDataType(var->dtype())); } else { nd4j_printf("Either shape or NDArray should be defined in FlatResult variable\n",""); throw std::runtime_error("Empty variable"); @@ -111,9 +111,7 @@ namespace sd { } void ResultSet::delContent() { - if (_removable) - for (auto v : _content) - delete v; + // } ResultSet::~ResultSet() { @@ -129,15 +127,15 @@ namespace sd { return (int) _content.size(); } - sd::NDArray* ResultSet::at(const unsigned long idx) const { - return _content.at(idx); + sd::NDArray& ResultSet::at(const unsigned long idx) const { + return const_cast(_content[idx]); } - sd::NDArray* ResultSet::operator[](const unsigned long idx) const { - return _content[idx]; + sd::NDArray& ResultSet::operator[](const unsigned long idx) const { + return const_cast(_content[idx]); } - void ResultSet::push_back(sd::NDArray *array) { + void ResultSet::push_back(const sd::NDArray &array) { _content.emplace_back(array); } diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index 0d03caafb231..646df825f5f3 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -52,7 +52,6 @@ namespace sd { sd::graph::VariableSpace* _variableSpace = nullptr; std::pair _executionTime; - sd::random::RandomBuffer* _rng = nullptr; sd::DataType _dataType = sd::DataType::FLOAT32; // branch for divergent_op @@ -64,9 +63,8 @@ namespace sd { std::vector _dataTypes; // fields for fast execution (out-of-graph ops use) - std::vector _fastpath_in; - std::vector _fastpath_out; - std::vector _handles; + std::vector> _fastpath_in; + std::vector> _fastpath_out; bool _helpersAllowed = true; @@ -87,62 +85,37 @@ namespace sd { // these methods are for execution timing void setOuterTime(Nd4jLong time); void setInnerTime(Nd4jLong time); - Nd4jLong getOuterTime(); - Nd4jLong getInnerTime(); + Nd4jLong outerTime() const; + Nd4jLong innerTime() const; // these methods are related to Workspace abstraction - bool hasWorkspaceProvided(); + bool hasWorkspaceProvided() const; void attachWorkspace(sd::memory::Workspace* workspace); - void forgetWorkspace(); - // these methods return full-time workspace - sd::memory::Workspace* getWorkspace(); - sd::memory::Workspace* workspace(); - sd::memory::Workspace* fWorkspace(); + sd::memory::Workspace* workspace() const; - // this method returns workspace for temporary allocations - sd::memory::Workspace* tWorkspace(); - - // this method returns workspace for object allocations - sd::memory::Workspace* oWorkspace(); void setVariableSpace(VariableSpace* variableSpace); - sd::random::RandomBuffer* getRNG(); - void setRNG(sd::random::RandomBuffer* rng); - void setTargetEngine(samediff::Engine engine); VariableSpace *getVariableSpace(); LaunchContext* launchContext(); - // these fields define, if we can execute specific node in-place, without generating new array - - - // these variables are only for Divergent Nodes - int getBranch(); - void setBranch(int branch); - /** * * @return */ - Stash* getStash(); - - /** - * - */ - void trackList(NDArrayList* list); - + Stash* stash() const; /** * This method returns variable for a given input index for this block * @param idx * @return */ - Variable* getVariable(int idx); - Variable* variable(int idx); + std::shared_ptr getVariable(int idx) const; + std::shared_ptr variable(int idx) const; /** * This method is shortcut to getVariable(int idx); @@ -150,8 +123,8 @@ namespace sd { * + it check fastpath for array availability (preferred) * @return */ - NDArray* getNDArray(int idx); - NDArray* array(int idx); + std::shared_ptr getNDArray(int idx) const; + std::shared_ptr array(int idx) const; /** @@ -159,20 +132,21 @@ namespace sd { * @param p * @return */ - Variable* variable(int node, int index); - Variable* variable(std::pair& p); - Variable* variable(std::initializer_list p); + std::shared_ptr variable(int node, int index) const; + std::shared_ptr variable(const std::pair& p) const; + std::shared_ptr variable(std::initializer_list p) const; - void pushNDArrayToVariableSpace(int nodeId, int index, NDArray* array, bool removable = true); - void pushNDArrayToVariableSpace(std::pair& pair, NDArray* array, bool removable = true); + void pushNDArrayToVariableSpace(int nodeId, int index, const NDArray &array); + void pushNDArrayToVariableSpace(const std::pair& pair, const NDArray &array); - void pushNDArrayListToVariableSpace(int nodeId, int index, NDArrayList* list, bool track = true); - void pushNDArrayListToVariableSpace(std::pair& pair, NDArrayList* list, bool track = true); + void pushNDArrayListToVariableSpace(int nodeId, int index, std::shared_ptr list); + void pushNDArrayListToVariableSpace(int nodeId, int index, const NDArrayList &list, bool track = true); + void pushNDArrayListToVariableSpace(const std::pair& pair, const NDArrayList &list, bool track = true); - bool isValueAvailable(int idx = 0); + bool isValueAvailable(int idx = 0) const; - Variable* ensureVariable(int idx = 0); + std::shared_ptr ensureVariable(int idx = 0); unsigned long width() const override; @@ -181,7 +155,7 @@ namespace sd { * This method checks if Context uses fastpath variable access * @return */ - bool isFastPath(); + bool isFastPath() const; /** * Method allows to forbid FastPath execution @@ -190,15 +164,17 @@ namespace sd { void forbidFastPath(bool reallyForbid); #ifndef __JAVACPP_HACK__ - std::vector& fastpath_in(); - std::vector& fastpath_out(); + const std::vector>& fastpath_in() const; + const std::vector>& fastpath_out() const; #endif - void setInputArray(int index, NDArray *array, bool removable = false); + void setInputArray(int index, const std::shared_ptr &array); + void setInputArray(int index, const NDArray &array); void setInputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); void setInputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo); - void setOutputArray(int index, NDArray *array, bool removable = false); + void setOutputArray(int index, const std::shared_ptr &array); + void setOutputArray(int index, const NDArray &array); void setOutputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); void setOutputArray(int index, void *databuffer, void *shapeInfo, void *specialShapeInfo); @@ -222,16 +198,16 @@ namespace sd { void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer); void allowHelpers(bool reallyAllow); - bool helpersAllowed(); + bool helpersAllowed() const; void setShapeFunctionOverride(bool reallyOverride); - bool shapeFunctionOverride(); + bool shapeFunctionOverride() const; - samediff::ExecutionMode executionMode(); + samediff::ExecutionMode executionMode() const; void setExecutionMode(samediff::ExecutionMode executionMode); - bool isTraining(); - bool isInference(); + bool isTraining() const; + bool isInference() const; const GraphMemoryManager& memoryManager() const; }; diff --git a/libnd4j/include/graph/FlatUtils.h b/libnd4j/include/graph/FlatUtils.h index dc5911708545..0d1e54025272 100644 --- a/libnd4j/include/graph/FlatUtils.h +++ b/libnd4j/include/graph/FlatUtils.h @@ -35,7 +35,7 @@ namespace sd { static std::pair fromLongPair(LongPair* pair); - static NDArray* fromFlatArray(const sd::graph::FlatArray* flatArray); + static NDArray fromFlatArray(const sd::graph::FlatArray* flatArray); static flatbuffers::Offset toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array); }; diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 725dfe261935..2deb04e7477a 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -46,7 +46,8 @@ namespace sd { class SD_EXPORT Graph { protected: ExecutorConfiguration _configuration; - VariableSpace *_variableSpace; + VariableSpace _variableSpace; + memory::Workspace _workspace; Stash _stash; MAP_IMPL _unmapped; @@ -71,7 +72,7 @@ namespace sd { std::vector _placeholders; public: - Graph(const FlatGraph *flatGraph = nullptr, VariableSpace *variableSpace = nullptr, const GraphMemoryManager &memoryManager = GraphMemoryManager()); + Graph(const FlatGraph *flatGraph = nullptr, const GraphMemoryManager &memoryManager = GraphMemoryManager()); ~Graph(); @@ -100,13 +101,13 @@ namespace sd { int numberOfPlaceholders() const; - const std::vector& getPlaceholders() const; + const std::vector>& placeholders() const; /** * This method returns pointer to thread_local VariableSpace * @return */ - VariableSpace *variableSpace() const; + VariableSpace& variableSpace() const; const GraphMemoryManager& memoryManager() const; @@ -152,11 +153,6 @@ namespace sd { */ Graph cloneWithProxy() const; - /** - * This method removes reference to VariableSpace from this Graph - */ - void forgetVariableSpace(); - /** * This method returns hash of given Graph instance */ diff --git a/libnd4j/include/graph/InferenceRequest.h b/libnd4j/include/graph/InferenceRequest.h index 2ac698841717..75418d7cbe3b 100644 --- a/libnd4j/include/graph/InferenceRequest.h +++ b/libnd4j/include/graph/InferenceRequest.h @@ -31,22 +31,21 @@ namespace sd { class SD_EXPORT InferenceRequest { private: Nd4jLong _id; - std::vector _variables; - std::vector _deletables; + std::vector> _variables; - ExecutorConfiguration *_configuration = nullptr; + ExecutorConfiguration _configuration; - void insertVariable(Variable* variable); + void insertVariable(std::shared_ptr variable); public: - InferenceRequest(Nd4jLong graphId, ExecutorConfiguration *configuration = nullptr); + InferenceRequest(Nd4jLong graphId, const ExecutorConfiguration &configuration); ~InferenceRequest(); - void appendVariable(int id, NDArray *array); - void appendVariable(int id, int index, NDArray *array); - void appendVariable(std::string &name, NDArray *array); - void appendVariable(std::string &name, int id, int index, NDArray *array); - void appendVariable(Variable *variable); + void appendVariable(int id, const NDArray &array); + void appendVariable(int id, int index, const NDArray &array); + void appendVariable(const std::string &name, const NDArray &array); + void appendVariable(const std::string &name, int id, int index, const NDArray &array); + void appendVariable(std::shared_ptr variable); #ifndef __JAVACPP_HACK__ flatbuffers::Offset asFlatInferenceRequest(flatbuffers::FlatBufferBuilder &builder); diff --git a/libnd4j/include/graph/SessionLocalStorage.h b/libnd4j/include/graph/SessionLocalStorage.h deleted file mode 100644 index b3636f2615b1..000000000000 --- a/libnd4j/include/graph/SessionLocalStorage.h +++ /dev/null @@ -1,65 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#ifndef LIBND4J_SESSIONLOCALSTORAGE_H -#define LIBND4J_SESSIONLOCALSTORAGE_H - -#include -#include -#include -#include "VariableSpace.h" -#include "Context.h" -#include "Stash.h" -#include - -namespace sd{ - namespace graph { - class SD_EXPORT SessionLocalStorage { - protected: - std::atomic _sessionCounter; - MAP_IMPL _threadSession; - MAP_IMPL _threadVariableSpace; - - VariableSpace* _variableSpace; - Stash* _stash; - - std::mutex _mutex; - - Nd4jLong getSessionId(); - Nd4jLong getThreadId(); - public: - SessionLocalStorage(VariableSpace* variableSpace = nullptr, Stash* stash = nullptr); - - ~SessionLocalStorage(); - - VariableSpace* localVariableSpace(); - VariableSpace* localVariableSpace(Nd4jLong sessionId); - - - Nd4jLong startSession(); - void endSession(Nd4jLong sessionId); - void endSession(); - - int numberOfSessions(); - }; - } -} - -#endif //LIBND4J_SESSIONLOCALSTORAGE_H diff --git a/libnd4j/include/graph/Variable.h b/libnd4j/include/graph/Variable.h index d1f2ddc78b91..ae5878bc22f2 100644 --- a/libnd4j/include/graph/Variable.h +++ b/libnd4j/include/graph/Variable.h @@ -60,7 +60,7 @@ namespace sd { protected: int _id = 0; int _index = 0; - sd::NDArray *_ndarray = nullptr; + std::shared_ptr _ndarray; std::string _name; std::vector _shape; @@ -71,38 +71,31 @@ namespace sd { bool _placeholder = false; bool _removable = true; - // for now we're setting default to numeric - // in future we'll be fetching it right from the array, - //InputType _variableType = InputType_UNDEFINED; - //DataType _dataType = INHERIT; - - sd::NDArrayList *_list = nullptr; + std::shared_ptr _list; VariableType _variableType = VariableType::NDARRAY; public: - Variable(bool placeHolder, DataType dataType = DataType::ANY, const std::vector &shape = {}); - Variable(sd::NDArray *arrayw, const char *name, int id, int idx = 0); - Variable(sd::NDArray *array = nullptr, const char *name = nullptr); + explicit Variable(bool placeHolder, DataType dataType = DataType::ANY, const std::vector &shape = {}); + explicit Variable(const sd::NDArray &array, const std::string &name, int id, int idx = 0); + explicit Variable(std::shared_ptr array, const std::string &name, int id, int idx = 0); + explicit Variable(std::shared_ptr array, const char *name = nullptr); + explicit Variable(); #ifndef __JAVACPP_HACK__ - Variable(const sd::graph::FlatVariable *flatVariable); + explicit Variable(const sd::graph::FlatVariable *flatVariable); #endif ~Variable(); - Variable* clone() const; - - template - SD_EXPORT Variable* asT() const; bool hasNDArray() const; - sd::NDArray* getNDArray() const; - void setNDArray(sd::NDArray *array); + std::shared_ptr getNDArray() const; + void setNDArray(std::shared_ptr array); bool hasNDArrayList() const; - sd::NDArrayList* getNDArrayList() const; - void setNDArrayList(sd::NDArrayList* list); + std::shared_ptr getNDArrayList() const; + void setNDArrayList(std::shared_ptr list); bool isExternal() const; bool isReadOnly() const; @@ -114,13 +107,6 @@ namespace sd { VariableType variableType() const; void setVariableType(VariableType variableType); - /** - * This method returns InputType of this variable - */ - //InputType variableType() { - // return _variableType; - //} - void markExternal(bool reallyExternal); void markReadOnly(bool reallyReadOnly); void markRemovable(bool reallyRemovable); diff --git a/libnd4j/include/graph/VariableProxy.h b/libnd4j/include/graph/VariableProxy.h index 2e855ce7a9f1..b49b9f55df34 100644 --- a/libnd4j/include/graph/VariableProxy.h +++ b/libnd4j/include/graph/VariableProxy.h @@ -32,58 +32,51 @@ namespace sd { virtual VariableSpace& operator=(const VariableSpace& other); - virtual int numberOfPlaceholders() override; - virtual std::vector* getPlaceholders() override; + virtual int numberOfPlaceholders() const override; + virtual const std::vector>& placeholders() const override; - virtual sd::memory::Workspace *workspace(); + virtual bool hasExternalVariable(int it) const override; + virtual bool hasExternalVariable(const std::pair& pair) const override; + virtual bool hasExternalVariable(const std::string &symbol) const override; - virtual bool hasExternalVariable(int it) override; - virtual bool hasExternalVariable(std::pair& pair) override; - virtual bool hasExternalVariable(const std::string &symbol) override; + virtual bool hasVariable(int id) const override; + virtual bool hasVariable(int id, int idx) const override; + virtual bool hasVariable(const std::pair& pair) const override; + virtual bool hasVariable(const std::string &symbol) const override; - virtual bool hasVariable(int id) override; - virtual bool hasVariable(int id, int idx) override; - virtual bool hasVariable(std::pair& pair) override; - virtual bool hasVariable(const std::string &symbol) override; + virtual std::shared_ptr getVariable(int id) const override; + virtual std::shared_ptr getVariable(int id, int idx) const override; + virtual std::shared_ptr getVariable(const std::pair& pair) const override; + virtual std::shared_ptr getVariable(const std::string &symbol) const override; - virtual Variable *getVariable(int id) override; - virtual Variable *getVariable(int id, int idx) override; - virtual Variable *getVariable(std::pair& pair) override; - virtual Variable *getVariable(const std::string &symbol) override; + virtual std::vector> variables() const override; - virtual std::vector getVariables() override; + virtual std::shared_ptr putVariable(const std::pair& pair, const NDArray &array) override; + virtual std::shared_ptr putVariable(int id, const NDArray &array) override; + virtual std::shared_ptr putVariable(int id, int idx, const NDArray &array) override; + virtual std::shared_ptr putVariable(const std::string &name, int id, int idx, const NDArray &array) override; + virtual void putVariable(const std::string& name, int id, int idx, std::shared_ptr variable) override; + virtual void putVariable(const std::pair& pair, std::shared_ptr variable) override; + virtual void putVariable(int id, std::shared_ptr variable) override; - virtual Variable* putVariable(std::pair& pair, NDArray *array) override; - virtual void putVariable(std::pair& pair, Variable *variable) override; - virtual void putVariable(int id, Variable *variable) override; - virtual void putVariable(int id, NDArray *array) override; - virtual Variable* putVariable(int id, int idx, NDArray *array) override; - void putVariable(int id, int idx, const NDArray &array) override; - virtual void putVariable(int id, int idx, Variable *array) override; + virtual void replaceVariable(std::shared_ptr variable) override; - virtual void replaceVariable(Variable *variable) override; - - virtual void dropVariable(std::pair &pair) override; + virtual void dropVariable(const std::pair &pair) override; virtual void dropVariable(int id, int idx) override; - virtual void putOutputVariable(Variable *variable) override; - - virtual void trackList(sd::NDArrayList *list) override; + virtual void putOutputVariable(std::shared_ptr variable) override; // memory-related statistics - virtual Nd4jLong externalMemory() override; - virtual Nd4jLong internalMemory() override; - virtual Nd4jLong totalMemory() override; + virtual Nd4jLong externalMemory() const override; + virtual Nd4jLong internalMemory() const override; + virtual Nd4jLong totalMemory() const override; - virtual int externalEntries() override; - virtual int internalEntries() override; - virtual int totalEntries() override; + virtual int externalEntries() const override; + virtual int internalEntries() const override; + virtual int totalEntries() const override; - virtual VariableSpace *clone() override; - virtual Stash* getStash() override; - virtual void setFlowPath(FlowPath* timers) override; - virtual FlowPath* flowPath() override; + virtual Stash* stash() const override; }; } } \ No newline at end of file diff --git a/libnd4j/include/graph/VariableSpace.h b/libnd4j/include/graph/VariableSpace.h index a7fd05228ad6..c7d6d0238c86 100644 --- a/libnd4j/include/graph/VariableSpace.h +++ b/libnd4j/include/graph/VariableSpace.h @@ -40,101 +40,90 @@ namespace sd { namespace graph { class SD_EXPORT VariableSpace { protected: - sd::memory::Workspace *_workspace; - // stash is NOT cloned Stash _stash; - MAP_IMPL, Variable*> _paired; - MAP_IMPL _symbolic; - MAP_IMPL _variables; - std::vector _external; - std::vector _internal; + // lookup tables: by name, by id, by id:idx + MAP_IMPL, std::shared_ptr> _paired; + MAP_IMPL> _symbolic; + MAP_IMPL> _variables; + + // direct references to external variables and internally-generated variables + std::vector> _external; + std::vector> _internal; - std::vector _lists; + // meh + std::vector> _lists; - std::vector _placeholders; + // placeholders. must be resolved before Graph execution + std::vector> _placeholders; - void silentPutVariable(std::pair& pair, Variable *variable); + void silentPutVariable(const std::pair& pair, std::shared_ptr variable); int _auto_counter = -1; std::mutex _varmap; - MAP_IMPL _temporary; - - std::vector *_handles; - - FlowPath* _flow = nullptr; - public: VariableSpace(); virtual ~VariableSpace(); + VariableSpace(const sd::graph::VariableSpace &variableSpace); + VariableSpace(sd::graph::VariableSpace &&variableSpace); + virtual VariableSpace& operator=(const VariableSpace& other); + virtual VariableSpace& operator=(VariableSpace&& other); - virtual int numberOfPlaceholders(); - virtual std::vector* getPlaceholders(); - virtual void setWorkspace(sd::memory::Workspace *workspace); + virtual int numberOfPlaceholders() const; - virtual LaunchContext* launchContext(); + virtual const std::vector>& placeholders() const; - virtual bool hasExternalVariable(int it); - virtual bool hasExternalVariable(std::pair& pair); - virtual bool hasExternalVariable(const std::string &symbol); + virtual bool hasExternalVariable(int it) const; + virtual bool hasExternalVariable(const std::pair& pair) const; + virtual bool hasExternalVariable(const std::string &symbol) const; - virtual bool hasVariable(int id); - virtual bool hasVariable(int id, int idx); - virtual bool hasVariable(std::pair& pair); - virtual bool hasVariable(const std::string &symbol); + virtual bool hasVariable(int id) const; + virtual bool hasVariable(int id, int idx) const; + virtual bool hasVariable(const std::pair& pair) const; + virtual bool hasVariable(const std::string &symbol) const; - virtual Variable* getVariable(int id); - virtual Variable* getVariable(int id, int idx); - virtual Variable* getVariable(std::pair& pair); - virtual Variable* getVariable(const std::string &symbol); + virtual std::shared_ptr getVariable(int id) const; + virtual std::shared_ptr getVariable(int id, int idx) const; + virtual std::shared_ptr getVariable(const std::pair& pair) const; + virtual std::shared_ptr getVariable(const std::string &symbol) const; - virtual std::vector getVariables(); + virtual std::vector> variables() const; - virtual Variable* putVariable(std::pair& pair, NDArray *array); - virtual void putVariable(std::pair& pair, Variable *variable); - virtual void putVariable(int id, Variable *variable); - virtual void putVariable(int id, NDArray *array); - virtual Variable* putVariable(int id, int idx, NDArray *array); - virtual void putVariable(int id, int idx, const NDArray &array); - virtual void putVariable(int id, int idx, Variable *array); + virtual std::shared_ptr putVariable(const std::pair& pair, const NDArray &array); + virtual std::shared_ptr putVariable(int id, const NDArray &array); + virtual std::shared_ptr putVariable(int id, int idx, const std::shared_ptr &array); + virtual std::shared_ptr putVariable(int id, int idx, const NDArray &array); + virtual std::shared_ptr putVariable(const std::string &name, int id, int idx, const NDArray &array); + virtual void putVariable(const std::string& name, int id, int idx, std::shared_ptr variable); + virtual void putVariable(const std::pair& pair, std::shared_ptr variable); + virtual void putVariable(int id, std::shared_ptr variable); - virtual void dropVariable(std::pair &pair); + virtual void dropVariable(const std::string &pair); + virtual void dropVariable(const std::pair &pair); virtual void dropVariable(int id, int idx); - virtual void trackList(sd::NDArrayList *list); + virtual void putOutputVariable(std::shared_ptr variable); - virtual void putOutputVariable(Variable *variable); - - virtual void replaceVariable(Variable *variable); + virtual void replaceVariable(std::shared_ptr variable); // memory-related statistics - virtual Nd4jLong externalMemory(); - virtual Nd4jLong internalMemory(); - virtual Nd4jLong totalMemory(); - - virtual int externalEntries(); - virtual int internalEntries(); - virtual int totalEntries(); - - virtual VariableSpace* clone(); - - std::vector *handles(); - + virtual Nd4jLong externalMemory() const; + virtual Nd4jLong internalMemory() const; + virtual Nd4jLong totalMemory() const; - VariableSpace* asT(); - void injectVariable(std::pair &pair, Variable* variable); + virtual int externalEntries() const; + virtual int internalEntries() const; + virtual int totalEntries() const; - virtual Stash* getStash(); + void injectVariable(const std::pair &pair, std::shared_ptr variable); - virtual std::vector * getExternalVariables(); + virtual Stash* stash() const; - virtual void setFlowPath(FlowPath* timers); - virtual FlowPath* flowPath(); }; } } diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index d488702a0ea3..07bb0da02f43 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -45,8 +45,10 @@ namespace sd { Nd4jStatus GraphExecutor::execute(std::shared_ptr op, const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId) const { - auto ctx = prepareContext(contextPrototype, *graph.originalGraph().variableSpace(), graph.memoryManager()); - return op->execute(&ctx); + //auto varSpace = graph.originalGraph().variableSpace(); + //auto ctx = prepareContext(contextPrototype, varSpace, graph.memoryManager()); + //return op->execute(&ctx); + throw std::runtime_error("GraphExecutor::execute - Not implemented yet"); } Nd4jStatus GraphExecutor::execute(const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId) const { diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index 26a7d261f55f..da8a5037baee 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -55,9 +55,6 @@ namespace sd { this->_nodeId = prototype.nodeId(); this->_name = prototype.name(); this->_useMKLDNN = prototype.isUseMKLDNN(); - - if (variableSpace != nullptr && variableSpace->launchContext()->getWorkspace() != nullptr) - this->_workspace = variableSpace->launchContext()->getWorkspace(); } Context::Context(int nodeId, VariableSpace *variableSpace) { @@ -68,9 +65,6 @@ namespace sd { this->_executionTime.first = 0; this->_executionTime.second = 0; - - if (variableSpace != nullptr && variableSpace->launchContext()->getWorkspace() != nullptr) - this->_workspace = variableSpace->launchContext()->getWorkspace(); } Context::Context(int nodeId, VariableSpace *variableSpace, bool isInplace) : Context(nodeId, variableSpace) { @@ -84,9 +78,6 @@ namespace sd { this->_fastpath_in.clear(); this->_fastpath_out.clear(); - for (auto v:_handles) - delete v; - if (_context != nullptr) delete _context; } @@ -95,10 +86,6 @@ namespace sd { _engine = engine; } - bool Context::hasWorkspaceProvided() { - return this->_workspace != nullptr; - } - void Context::attachWorkspace(sd::memory::Workspace* workspace) { this->_workspace = workspace; } @@ -107,19 +94,15 @@ namespace sd { this->_variableSpace = variableSpace; } - void Context::forgetWorkspace() { - _workspace = nullptr; - } - - std::vector& Context::fastpath_in() { + const std::vector>& Context::fastpath_in() const { return _fastpath_in; } - std::vector& Context::fastpath_out() { + const std::vector>& Context::fastpath_out() const { return _fastpath_out; } - bool Context::isFastPath() { + bool Context::isFastPath() const { auto ie = _fastpath_in.empty(); auto io = _fastpath_out.empty(); // two options here. @@ -139,65 +122,19 @@ namespace sd { return _variableSpace; } - sd::memory::Workspace* Context::getWorkspace() { + sd::memory::Workspace* Context::workspace() const { return _workspace; } - sd::memory::Workspace* Context::workspace() { - return _workspace; - } - - sd::random::RandomBuffer* Context::getRNG() { - return _rng; - } - - void Context::setRNG(sd::random::RandomBuffer* rng) { - _rng = rng; - } - - /** - * This method returns variableSpace used in this block - * @return - */ - /* - VariableSpace* Context::getVariableSpace() { - return _variableSpace; - } -*/ - - Stash* Context::getStash() { - return _variableSpace->getStash(); + Stash* Context::stash() const { + return _variableSpace->stash(); } - void Context::trackList(NDArrayList* list) { - _variableSpace->trackList(list); - } - -/* - void Block::updateVariables() { - _variables.clear(); - auto x = _inputs.size(); - for (auto &v:_inputs) { - auto var = _variableSpace->getVariable(v); - _variables.emplace_back(var); - } - } -*/ - int Context::getBranch() { - return _variableSpace->flowPath()->branch(this->nodeId()); - } - - void Context::setBranch(int branch) { - //_branch = branch; - if (_variableSpace->flowPath() != nullptr) - _variableSpace->flowPath()->markBranch(this->nodeId(), branch); - } - - Nd4jLong sd::graph::Context::getOuterTime(){ + Nd4jLong sd::graph::Context::outerTime() const { return this->_executionTime.first; } - Nd4jLong sd::graph::Context::getInnerTime(){ + Nd4jLong sd::graph::Context::innerTime() const { return this->_executionTime.second; } @@ -210,7 +147,7 @@ namespace sd { } - Variable* Context::getVariable(int idx) { + std::shared_ptr Context::getVariable(int idx) const { if (idx >= this->_inputs.size()) { nd4j_printf("Node %i; Variable [%i] requested, but only %i inputs available\n", this->_nodeId, idx, this->_inputs.size()); throw std::runtime_error("Context: bad Variable index"); @@ -222,7 +159,7 @@ namespace sd { if (Environment::getInstance()->isDebugAndVerbose() && v != nullptr && v->getNDArray() != nullptr) { auto array = v->getNDArray(); - std::string shape_ = ShapeUtils::shapeAsString(array); + std::string shape_ = ShapeUtils::shapeAsString(array.get()); auto type = DataTypeUtils::asString(array->dataType()); float m = std::numeric_limits::quiet_NaN(); if (!array->isEmpty()) { @@ -237,11 +174,11 @@ namespace sd { return v; } - Variable* Context::variable(int idx) { + std::shared_ptr Context::variable(int idx) const { return getVariable(idx); } - Variable* Context::variable(std::initializer_list p) { + std::shared_ptr Context::variable(std::initializer_list p) const { if (p.size() != 2) throw std::runtime_error("Variable address should have size of 2"); @@ -251,12 +188,12 @@ namespace sd { return variable(pair); } - Variable* Context::variable(int node, int idx) { + std::shared_ptr Context::variable(int node, int idx) const { std::pair pair(node, idx); return variable(pair); } - Variable* Context::variable(std::pair& p) { + std::shared_ptr Context::variable(const std::pair& p) const { try { return _variableSpace->getVariable(p); } catch (std::exception &e) { @@ -265,62 +202,51 @@ namespace sd { } } - void Context::pushNDArrayToVariableSpace(int nodeId, int index, NDArray *array, bool removable) { + void Context::pushNDArrayToVariableSpace(int nodeId, int index, const NDArray &array) { std::pair pair(nodeId, index); - pushNDArrayToVariableSpace(pair, array, removable); + pushNDArrayToVariableSpace(pair, array); } - void Context::pushNDArrayToVariableSpace(std::pair &pair, NDArray *array, bool removable) { + void Context::pushNDArrayToVariableSpace(const std::pair &pair, const NDArray &array) { if (_variableSpace != nullptr) { if (!_variableSpace->hasVariable(pair)) { - auto var = new Variable(array, nullptr, pair.first, pair.second); + auto var = std::make_shared(array, "", pair.first, pair.second); _variableSpace->putVariable(pair, var); - var->markRemovable(removable); } else { auto var = _variableSpace->getVariable(pair); if (var->hasNDArray()) { - if (var->getNDArray() != array) { - if (var->isRemovable() && var->hasNDArray()) - delete var->getNDArray(); - - var->setNDArray(array); - var->markRemovable(removable); - } - } else { - var->setNDArray(array); - var->markRemovable(removable); + var->setNDArray(std::make_shared(array)); } } } } - void Context::pushNDArrayListToVariableSpace(int nodeId, int index, NDArrayList* list, bool track) { + void Context::pushNDArrayListToVariableSpace(int nodeId, int index, const NDArrayList &list, bool track) { std::pair pair(nodeId, index); pushNDArrayListToVariableSpace(pair, list, track); } - void Context::pushNDArrayListToVariableSpace(std::pair& pair, NDArrayList* list, bool track) { + void Context::pushNDArrayListToVariableSpace(const std::pair& pair, const NDArrayList &list, bool track) { if (!_variableSpace->hasVariable(pair)) { - auto var = new Variable(nullptr, nullptr, pair.first, pair.second); - var->setNDArrayList(list); + auto var = std::make_shared(); + var->setId(pair.first, pair.second); + var->setNDArrayList(std::make_shared(list)); _variableSpace->putVariable(pair, var); } else { auto var = _variableSpace->getVariable(pair); - var->setNDArrayList(list); + var->setNDArrayList(std::make_shared(list)); } - - if (track) - _variableSpace->trackList(list); } - Variable* Context::ensureVariable(int idx) { + std::shared_ptr Context::ensureVariable(int idx) { std::pair pair(this->nodeId(), idx); if (_variableSpace == nullptr) throw std::runtime_error("Context::ensureVariable VariableSpace is NULL!"); if (!_variableSpace->hasVariable(pair)) { - auto var = new Variable(nullptr, nullptr, this->nodeId(), idx); + auto var = std::make_shared(); + var->setId(this->nodeId(), idx); auto name = this->name(); if (!name.empty()) @@ -333,8 +259,8 @@ namespace sd { } } - bool Context::isValueAvailable(int idx) { - auto var = ensureVariable(idx); + bool Context::isValueAvailable(int idx) const { + auto var = const_cast(this)->ensureVariable(idx); if (var->variableType() == VariableType::NDARRAY) { return var->hasNDArray(); @@ -345,11 +271,11 @@ namespace sd { return false; } - NDArray* Context::getNDArray(int idx) { + std::shared_ptr Context::getNDArray(int idx) const { return array(idx); } - NDArray* Context::array(int idx) { + std::shared_ptr Context::array(int idx) const { // we check for fastpath first if (!_fastpath_in.empty() && _fastpath_in.size() > idx) { return _fastpath_in[idx]; @@ -359,18 +285,6 @@ namespace sd { return getVariable(idx)->getNDArray(); } - sd::memory::Workspace *Context::fWorkspace() { - return workspace(); - } - - sd::memory::Workspace *Context::tWorkspace() { - return nullptr; - } - - sd::memory::Workspace *Context::oWorkspace() { - return nullptr; - } - LaunchContext* Context::launchContext() { //FIXME: we need proper context to be shared here if (_context == nullptr) { @@ -387,46 +301,39 @@ namespace sd { return _inputs.size(); } - void Context::setInputArray(int index, NDArray *array, bool removable) { + void Context::setInputArray(int index, const NDArray &array) { if (_fastpath_in.size() < index + 1) _fastpath_in.resize(index+1); - _fastpath_in[index] = array; - if (removable) - _handles.emplace_back(array); + _fastpath_in[index] = std::make_shared(array); } void Context::setInputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { - auto array = new NDArray(buffer, specialBuffer, reinterpret_cast(shapeInfo)); + auto array = std::make_shared(buffer, specialBuffer, reinterpret_cast(shapeInfo)); if (_fastpath_in.size() < index + 1) _fastpath_in.resize(index+1); _fastpath_in[index] = array; - _handles.emplace_back(array); if (_context != nullptr) array->setContext(_context); } - void Context::setOutputArray(int index, NDArray *array, bool removable) { + void Context::setOutputArray(int index, const NDArray &array) { if (_fastpath_out.size() < index + 1) _fastpath_out.resize(index+1); - _fastpath_out[index] = array; - - if (removable) - _handles.emplace_back(array); + _fastpath_out[index] = std::make_shared(array); } void Context::setOutputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { if (_fastpath_out.size() < index + 1) _fastpath_out.resize(index+1); - auto array = new NDArray(buffer, specialBuffer, reinterpret_cast(shapeInfo)); + auto array = std::make_shared(buffer, specialBuffer, reinterpret_cast(shapeInfo)); _fastpath_out[index] = array; - _handles.emplace_back(array); if (_context != nullptr) array->setContext(_context); @@ -438,14 +345,13 @@ namespace sd { if (_fastpath_in.size() < index + 1) _fastpath_in.resize(index+1); - NDArray *array; + std::shared_ptr array; if (dataBuffer != nullptr) - array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast(shapeInfo)))); + array = std::make_shared(dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast(shapeInfo)))); else - array = new NDArray(nullptr, nullptr, reinterpret_cast(shapeInfo)); + array = std::make_shared(nullptr, nullptr, reinterpret_cast(shapeInfo)); _fastpath_in[index] = array; - _handles.emplace_back(array); if (_context != nullptr) array->setContext(_context); @@ -457,14 +363,13 @@ namespace sd { if (_fastpath_out.size() < index + 1) _fastpath_out.resize(index+1); - NDArray *array; + std::shared_ptr array; if (dataBuffer != nullptr) - array = new NDArray(dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast(shapeInfo)))); + array = std::make_shared(dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast(shapeInfo)))); else - array = new NDArray(nullptr, nullptr, reinterpret_cast(shapeInfo)); + array = std::make_shared(nullptr, nullptr, reinterpret_cast(shapeInfo)); _fastpath_out[index] = array; - _handles.emplace_back(array); if (_context != nullptr) array->setContext(_context); @@ -510,7 +415,7 @@ namespace sd { _helpersAllowed = reallyAllow; } - bool Context::helpersAllowed() { + bool Context::helpersAllowed() const { return _helpersAllowed; } @@ -533,11 +438,11 @@ namespace sd { _shapeFunctionOverride = reallyOverride; } - bool Context::shapeFunctionOverride() { + bool Context::shapeFunctionOverride() const { return _shapeFunctionOverride; } - samediff::ExecutionMode Context::executionMode() { + samediff::ExecutionMode Context::executionMode() const { return _execMode; } @@ -545,11 +450,11 @@ namespace sd { _execMode = executionMode; } - bool Context::isTraining() { + bool Context::isTraining() const { return _execMode == samediff::ExecutionMode::MODE_TRAINING; } - bool Context::isInference() { + bool Context::isInference() const { return _execMode == samediff::ExecutionMode::MODE_INFERENCE; } @@ -568,16 +473,19 @@ namespace sd { void Context::clearFastPath() { _fastpath_in.clear(); _fastpath_out.clear(); - - for (auto v:_handles) - delete v; - - _handles.clear(); } const GraphMemoryManager &Context::memoryManager() const { return *_memoryManager; } + + void Context::setInputArray(int index, const std::shared_ptr &array) { + _fastpath_in[index] = array; + } + + void Context::setOutputArray(int index, const std::shared_ptr &array) { + _fastpath_out[index] = array; + } } } diff --git a/libnd4j/include/graph/impl/FlatUtils.cpp b/libnd4j/include/graph/impl/FlatUtils.cpp index e6984bb9705d..0843b45870cc 100644 --- a/libnd4j/include/graph/impl/FlatUtils.cpp +++ b/libnd4j/include/graph/impl/FlatUtils.cpp @@ -36,7 +36,7 @@ namespace sd { return std::pair(pair->first(), pair->second()); } - NDArray* FlatUtils::fromFlatArray(const sd::graph::FlatArray *flatArray) { + NDArray FlatUtils::fromFlatArray(const sd::graph::FlatArray *flatArray) { auto rank = static_cast(flatArray->shape()->Get(0)); auto newShape = new Nd4jLong[shape::shapeInfoLength(rank)]; memcpy(newShape, flatArray->shape()->data(), shape::shapeInfoByteLength(rank)); @@ -47,7 +47,7 @@ namespace sd { // empty arrays is special case, nothing to restore here if (shape::isEmpty(newShape)) { delete[] newShape; - return NDArrayFactory::empty_(dtype, nullptr); + return NDArrayFactory::empty(dtype, nullptr); } // TODO fix UTF16 and UTF32 if (dtype == UTF8) { @@ -88,7 +88,7 @@ namespace sd { delete[] offsets; delete[] newShape; // string order always 'c' - return NDArrayFactory::string_(shapeVector, substrings); + return NDArrayFactory::string(shapeVector, substrings); } @@ -96,7 +96,7 @@ namespace sd { BUILD_SINGLE_SELECTOR(dtype, DataTypeConversions, ::convertType(newBuffer, (void *)flatArray->buffer()->data(), dtype, ByteOrderUtils::fromFlatByteOrder(flatArray->byteOrder()), length), LIBND4J_TYPES); - auto array = new NDArray(newBuffer, newShape, sd::LaunchContext::defaultContext(), true); + NDArray array(newBuffer, newShape, sd::LaunchContext::defaultContext(), true); delete[] newShape; return array; diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 01532ff0666c..c6265ebe32f3 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -37,24 +37,24 @@ namespace sd { namespace graph { - const std::vector& Graph::getPlaceholders() const { - return *_variableSpace->getPlaceholders(); + const std::vector>& Graph::placeholders() const { + return _variableSpace.placeholders(); } int Graph::numberOfPlaceholders() const { - return _variableSpace->numberOfPlaceholders(); + return _variableSpace.numberOfPlaceholders(); }; const ExecutorConfiguration& Graph::getExecutorConfiguration() const { return _configuration; } - VariableSpace * Graph::variableSpace() const { - return _variableSpace; + VariableSpace& Graph::variableSpace() const { + return const_cast(_variableSpace); } Graph::~Graph() { - delete _variableSpace; + } int Graph::idByName(const std::string &nodeName) const { @@ -67,7 +67,7 @@ namespace sd { void Graph::addVariable(const std::string &name, NDArray &array) { int id = _maxId++; _symbolicLookupTable[name] = id; - _variableSpace->putVariable(id, 0, array); + _variableSpace.putVariable(id, 0, array); } void Graph::addVariable(const std::string &name, NDArray &&array) { @@ -118,8 +118,7 @@ namespace sd { throw std::runtime_error("not implemented yet"); } - Graph::Graph(const FlatGraph *flatGraph, VariableSpace *variableSpace, const GraphMemoryManager &memoryManager) : _memoryMaager(memoryManager) { - this->_variableSpace = variableSpace == nullptr ? new VariableSpace() : variableSpace; + Graph::Graph(const FlatGraph *flatGraph, const GraphMemoryManager &memoryManager) : _memoryMaager(memoryManager) { bool trusted = flatGraph != nullptr; // if there was no exec configuration in flatgraph - create default one @@ -130,8 +129,7 @@ namespace sd { // if memory reqs were set - initialize workspace if (_configuration._footprintForward > 0) { - sd::memory::Workspace *workspace = this->_variableSpace->launchContext()->getWorkspace(); - workspace->expandBy(_configuration._footprintForward); + _workspace.expandBy(_configuration._footprintForward); } // parsing variables here @@ -140,13 +138,13 @@ namespace sd { auto flatVar = flatGraph->variables()->Get(e); std::pair pair(flatVar->id()->first(), flatVar->id()->second()); - auto var = new Variable(flatVar); + auto var = std::make_shared(flatVar); if (flatVar->name() != nullptr) { var->setName(flatVar->name()->str()); _symbolicLookupTable[var->name()] = pair.first; } - _variableSpace->putVariable(pair, var); + _variableSpace.putVariable(pair, var); } } @@ -157,7 +155,7 @@ namespace sd { for (unsigned int e = 0; e < flatGraph->outputs()->size(); e++) { auto out = flatGraph->outputs()->Get(e); std::pair vp(out->first(), out->second()); - if (!_variableSpace->hasVariable(vp)) { + if (!_variableSpace.hasVariable(vp)) { nd4j_verbose("Non-existent variable requested: %i\n", out); throw std::runtime_error("Non-existent variable requested"); } @@ -245,13 +243,13 @@ namespace sd { void Graph::printOut() { // print variables first - if (_variableSpace->totalEntries() > 0) { + if (_variableSpace.totalEntries() > 0) { nd4j_printf("\nPrinting out Variables...\n", ""); - auto vars = _variableSpace->getVariables(); + auto vars = _variableSpace.variables(); - for (Variable* v: vars) { + for (auto &v: vars) { if (v->hasNDArray()) { - auto shape = ShapeUtils::shapeAsString(v->getNDArray()); + auto shape = ShapeUtils::shapeAsString(v->getNDArray().get()); auto values = v->getNDArray()->asString(16); auto dtype = DataTypeUtils::asString(v->getNDArray()->dataType()); @@ -277,29 +275,24 @@ namespace sd { return ND4J_STATUS_OK; } - void Graph::forgetVariableSpace() { - _variableSpace = nullptr; - } - void Graph::replaceState(VariableSpace *state, const ExecutorConfiguration &configuration) { - delete _variableSpace; - - _variableSpace = state; + _variableSpace = *state; _configuration = configuration; } Graph Graph::cloneWithProxy() const { Graph clone; - clone.replaceState(new VariableProxy(this->_variableSpace), this->_configuration); + //clone.replaceState(new VariableProxy(&this->_variableSpace), this->_configuration); - return clone; + //return clone; + throw std::runtime_error("Graph::cloneWithProxy - Not implemented yet"); } Graph* Graph::clone() const { auto clone = new Graph(); - clone->replaceState(this->_variableSpace->clone(), this->_configuration.clone()); + //clone->replaceState(&this->_variableSpace, this->_configuration.clone()); throw std::runtime_error("Graph::clone - not implemented yet"); } @@ -349,7 +342,7 @@ namespace sd { auto fg = GetFlatGraph(reinterpret_cast(ptr)); // return Graph from this FlatGraph - return Graph(fg, nullptr, memoryManager); + return Graph(fg, memoryManager); } Graph Graph::importFromTensorFlow(const char *fileName) { @@ -556,9 +549,9 @@ namespace sd { _symbolicLookupTable[nodeName] = id; - auto var = new Variable(true, dataType, shape); + auto var = std::make_shared(true, dataType, shape); var->setName(nodeName); - _variableSpace->putVariable(id, var); + _variableSpace.putVariable(id, var); _placeholders.emplace_back(nodeName); } @@ -571,7 +564,7 @@ namespace sd { throw unresolved_input_exception::build("Dictionary entry doesn't exist", v.first); // we also check if arrays provided here do match placeholder restrictions of shape and dtype - auto var = _variableSpace->getVariable(v.first); + auto var = _variableSpace.getVariable(v.first); if (var->dataType() != DataType::ANY && var->dataType() != v.second.dataType()) throw datatype_exception::build("Placeholder requires another data type", var->dataType(), v.second.dataType()); @@ -602,10 +595,10 @@ namespace sd { // fetch outputs from VariableSpace std::map result; for (const auto &v:outputs) { - if (!_variableSpace->hasVariable(v)) + if (!_variableSpace.hasVariable(v)) throw unresolved_output_exception::build("Requested output doesn't exist after execution", v); - auto var = _variableSpace->getVariable(v); + auto var = _variableSpace.getVariable(v); // TODO: we want to make sure ManagedDataBuffer doesn't leak here result[v] = *var->getNDArray(); diff --git a/libnd4j/include/graph/impl/InferenceRequest.cpp b/libnd4j/include/graph/impl/InferenceRequest.cpp index 29fde1eb1d5b..6aff5a96efb2 100644 --- a/libnd4j/include/graph/impl/InferenceRequest.cpp +++ b/libnd4j/include/graph/impl/InferenceRequest.cpp @@ -23,53 +23,51 @@ namespace sd { namespace graph { - InferenceRequest::InferenceRequest(Nd4jLong graphId, ExecutorConfiguration *configuration) { + InferenceRequest::InferenceRequest(Nd4jLong graphId, const ExecutorConfiguration &configuration) { this->_id = graphId; this->_configuration = configuration; } InferenceRequest::~InferenceRequest() { - for (auto v : _deletables) - delete v; + // } - void InferenceRequest::appendVariable(int id, NDArray *array) { + void InferenceRequest::appendVariable(int id, const NDArray &array) { appendVariable(id, 0, array); } - void InferenceRequest::appendVariable(int id, int index, NDArray *array) { - auto v = new Variable(array, nullptr, id, index); + void InferenceRequest::appendVariable(int id, int index, const NDArray &array) { + auto v = std::make_shared(std::make_shared(array), nullptr, id, index); insertVariable(v); } - void InferenceRequest::appendVariable(std::string &id, NDArray *array) { - auto v = new Variable(array, id.c_str()); + void InferenceRequest::appendVariable(const std::string &id, const NDArray &array) { + auto v = std::make_shared(std::make_shared(array), id.c_str()); insertVariable(v); } - void InferenceRequest::appendVariable(std::string &name, int id, int index, NDArray *array) { - auto v = new Variable(array, name.c_str(), id, index); + void InferenceRequest::appendVariable(const std::string &name, int id, int index, const NDArray &array) { + auto v = std::make_shared(std::make_shared(array), name, id, index); insertVariable(v); } - void InferenceRequest::insertVariable(Variable *variable) { + void InferenceRequest::insertVariable(std::shared_ptr variable) { variable->markRemovable(false); variable->markReadOnly(true); _variables.emplace_back(variable); - _deletables.emplace_back(variable); } - void InferenceRequest::appendVariable(Variable *variable) { + void InferenceRequest::appendVariable(std::shared_ptr variable) { _variables.emplace_back(variable); } flatbuffers::Offset InferenceRequest::asFlatInferenceRequest(flatbuffers::FlatBufferBuilder &builder) { std::vector> vec; - for (Variable* v : _variables) { + for (const auto &v : _variables) { vec.emplace_back(v->asFlatVariable(builder)); } - auto confOffset = _configuration != nullptr ? _configuration->asFlatConfiguration(builder) : 0; + auto confOffset = _configuration.asFlatConfiguration(builder); auto vecOffset = builder.CreateVector(vec); diff --git a/libnd4j/include/graph/impl/SessionLocalStorage.cpp b/libnd4j/include/graph/impl/SessionLocalStorage.cpp deleted file mode 100644 index 9c512b0b6f48..000000000000 --- a/libnd4j/include/graph/impl/SessionLocalStorage.cpp +++ /dev/null @@ -1,123 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include -#include -#include - -namespace sd { - namespace graph { - SessionLocalStorage::SessionLocalStorage(VariableSpace* variableSpace, Stash* stash) { - // we start from 1, since key 0 holds original VariableSpace - _sessionCounter.store(1); - _variableSpace = variableSpace; - _stash = stash; - } - - VariableSpace* SessionLocalStorage::localVariableSpace(Nd4jLong sessionId) { - _mutex.lock(); - auto varSpace = _threadVariableSpace.at(sessionId); - _mutex.unlock(); - - return varSpace; - } - - VariableSpace* SessionLocalStorage::localVariableSpace() { - return localVariableSpace(getSessionId()); - } - - SessionLocalStorage::~SessionLocalStorage() { - for (const auto & v: _threadVariableSpace) { - delete v.second; - } - } - - - Nd4jLong SessionLocalStorage::getThreadId() { -#ifdef __APPLE__ - // syscall? -#elif _WIN32 - // some win32api -#else - // syscall! -#endif - auto id=std::this_thread::get_id(); - uint64_t* ptr=(uint64_t*) &id; - return (*ptr); - } - - int SessionLocalStorage::numberOfSessions() { - _mutex.lock(); - int size = (int) _threadSession.size(); - _mutex.unlock(); - return size; - } - - void SessionLocalStorage::endSession(Nd4jLong sessionId) { - // we should delete specific holders here - _mutex.lock(); - auto vs = _threadVariableSpace[sessionId]; - _threadVariableSpace.erase(sessionId); - - delete vs; - _mutex.unlock(); - } - - void SessionLocalStorage::endSession() { - auto tid = getThreadId(); - - _mutex.lock(); - - auto ntid = _threadSession[tid]; - _threadSession.erase(tid); - - _mutex.unlock(); - - endSession(ntid); - } - - Nd4jLong SessionLocalStorage::getSessionId() { - auto tid = getThreadId(); - - _mutex.lock(); - auto ntid = _threadSession[tid]; - - _mutex.unlock(); - - return ntid; - } - - Nd4jLong sd::graph::SessionLocalStorage::startSession() { - auto tid = getThreadId(); - - nd4j_debug("Adding ThreadId: %i;\n", (int) tid); - Nd4jLong ntid = _sessionCounter++; - _mutex.lock(); - - _threadSession[tid] = ntid; - _threadVariableSpace[ntid] = _variableSpace->clone(); - - _mutex.unlock(); - - return ntid; - } - } -} - diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index fd89b2d602e8..01a8a94008cb 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -28,48 +28,18 @@ namespace sd { namespace graph { + Variable::Variable(const NDArray &array, const std::string &name, int id, int idx) { + _ndarray = std::make_shared(array); - template - Variable* Variable::asT() const { - auto result = new Variable(this->isPlaceholder()); + if (!name.empty()) + _name = name; - result->markExternal(this->_external); - result->setId(this->_id); - result->markReadOnly(this->_readOnly); - result->setName(this->_name); - result->setIndex(this->_index); - - if (this->_ndarray != nullptr) - result->setNDArray(new NDArray(this->_ndarray->template asT())); - - // FIXME: add support for ArrayList - if (this->_list != nullptr) { - nd4j_printf("ArrayList not supported yet\n", ""); - throw std::runtime_error("ArrayList not supported yet for asT"); - } - - return result; + _id = id; + _index = idx; } - BUILD_SINGLE_TEMPLATE(template SD_EXPORT Variable* Variable::asT, () const, LIBND4J_TYPES); - sd::graph::Variable* sd::graph::Variable::clone() const { - auto result = new Variable(this->isPlaceholder()); - result->_external = this->_external; - result->_id = this->_id; - result->_readOnly = this->_readOnly; - result->_name = this->_name; - result->_index = this->_index; - - if (this->_ndarray != nullptr) { - result->_ndarray = new NDArray(this->_ndarray->dup(this->_ndarray->ordering())); - result->_readOnly = false; - result->_removable = true; - } - - if (this->_list != nullptr) - result->_list = this->_list->clone(); - - return result; + Variable::Variable() { + // } void sd::graph::Variable::setIndex(int index) { @@ -77,7 +47,7 @@ namespace sd { } bool sd::graph::Variable::hasNDArray() const { - return _ndarray != nullptr; + return _ndarray.get() != nullptr; } void sd::graph::Variable::setVariableType(VariableType variableType) { @@ -85,7 +55,7 @@ namespace sd { } bool sd::graph::Variable::hasNDArrayList() const { - return _list != nullptr; + return _list.get() != nullptr; } bool sd::graph::Variable::isPlaceholder() const { @@ -147,12 +117,12 @@ namespace sd { this->_readOnly = reallyReadOnly; } - sd::NDArray * sd::graph::Variable::getNDArray() const { + std::shared_ptr sd::graph::Variable::getNDArray() const { if (_variableType != VariableType::NDARRAY) { nd4j_printf("Variable[%i:%i/<%s>] is has [%s] type, but NDArray was requested\n", this->_id, this->_index, this->_name.c_str(), EnumUtils::_VariableTypeToString(_variableType)); } - if (this->_ndarray == nullptr) { + if (this->_ndarray.get() == nullptr) { if (_name.empty()) { auto nodeId = StringUtils::valueToString(this->id()); auto outputIndex = StringUtils::valueToString(this->index()); @@ -166,7 +136,7 @@ namespace sd { return this->_ndarray; } - sd::NDArrayList * sd::graph::Variable::getNDArrayList() const { + std::shared_ptr sd::graph::Variable::getNDArrayList() const { if (_variableType != VariableType::ARRAY_LIST) { nd4j_debug("Variable[%i:%i/<%s>] is has [%s] type, but NDArrayList was requested\n", this->_id, this->_index, this->_name.c_str(), EnumUtils::_VariableTypeToString(_variableType)); } @@ -179,13 +149,13 @@ namespace sd { } - void sd::graph::Variable::setNDArrayList(sd::NDArrayList * list) { + void sd::graph::Variable::setNDArrayList(std::shared_ptr list) { this->_variableType = VariableType::ARRAY_LIST; this->_list = list; } - void sd::graph::Variable::setNDArray(sd::NDArray * array) { + void sd::graph::Variable::setNDArray(std::shared_ptr array) { this->_variableType = VariableType::NDARRAY; this->_ndarray = array; } @@ -215,7 +185,7 @@ namespace sd { // ????? if (flatVariable->ndarray() != nullptr) { auto ar = flatVariable->ndarray(); - _ndarray = sd::graph::FlatUtils::fromFlatArray(ar); + _ndarray = std::make_shared(sd::graph::FlatUtils::fromFlatArray(ar)); } _variableType = VariableType::NDARRAY; @@ -227,9 +197,9 @@ namespace sd { auto ar = flatVariable->ndarray(); if (ar->dtype() == DType_UTF8) { - _ndarray = sd::graph::FlatUtils::fromFlatArray(ar); + _ndarray = std::make_shared(sd::graph::FlatUtils::fromFlatArray(ar)); } else { - _ndarray = sd::graph::FlatUtils::fromFlatArray(ar); + _ndarray = std::make_shared(sd::graph::FlatUtils::fromFlatArray(ar)); } _variableType = VariableType::NDARRAY; @@ -240,7 +210,7 @@ namespace sd { // ????? if (flatVariable->ndarray() != nullptr) { auto ar = flatVariable->ndarray(); - _ndarray = sd::graph::FlatUtils::fromFlatArray(ar); + _ndarray = std::make_shared(sd::graph::FlatUtils::fromFlatArray(ar)); // _ndarray->triggerAllocationFlag(true); } @@ -253,7 +223,7 @@ namespace sd { if (flatVariable->ndarray() != nullptr) { auto ar = flatVariable->ndarray(); - _ndarray = sd::graph::FlatUtils::fromFlatArray(ar); + _ndarray = std::make_shared(sd::graph::FlatUtils::fromFlatArray(ar)); // _ndarray->triggerAllocationFlag(true); _variableType = VariableType::NDARRAY; @@ -285,7 +255,7 @@ namespace sd { } - sd::graph::Variable::Variable(NDArray *array, const char *name ) { + sd::graph::Variable::Variable(std::shared_ptr array, const char *name ) { _ndarray = array; _external = false; @@ -302,19 +272,14 @@ namespace sd { return _dtype; } - sd::graph::Variable::Variable(NDArray *array, const char *name, int id, int idx) : Variable(array, name) { + sd::graph::Variable::Variable(std::shared_ptr array, const std::string &name, int id, int idx) : Variable(array, name.c_str()) { _id = id; _index = idx; } sd::graph::Variable::~Variable() { - //nd4j_printf("Removing variable [%i:%i]\n", _id, _index); - if (_variableType == VariableType::NDARRAY) { - nd4j_debug("Removing variable <%i:%i>\n", _id, _index); - if (_ndarray != nullptr && _removable && !_readOnly) - delete _ndarray; - } + // } diff --git a/libnd4j/include/graph/impl/VariableProxy.cpp b/libnd4j/include/graph/impl/VariableProxy.cpp index dee01997d904..b7b087fde3a7 100644 --- a/libnd4j/include/graph/impl/VariableProxy.cpp +++ b/libnd4j/include/graph/impl/VariableProxy.cpp @@ -38,46 +38,46 @@ namespace sd { } - int VariableProxy::numberOfPlaceholders() { + int VariableProxy::numberOfPlaceholders() const { return _backed->numberOfPlaceholders(); } - std::vector* VariableProxy::getPlaceholders() { - return _backed->getPlaceholders(); + const std::vector>& VariableProxy::placeholders() const { + return _backed->placeholders(); } - bool VariableProxy::hasExternalVariable(int it) { + bool VariableProxy::hasExternalVariable(int it) const { return _backed->hasExternalVariable(it); } - bool VariableProxy::hasExternalVariable(std::pair& pair) { + bool VariableProxy::hasExternalVariable(const std::pair& pair) const { return _backed->hasExternalVariable(pair); } - bool VariableProxy::hasExternalVariable(const std::string &symbol) { + bool VariableProxy::hasExternalVariable(const std::string &symbol) const { return _backed->hasExternalVariable(symbol); } - bool VariableProxy::hasVariable(int id) { + bool VariableProxy::hasVariable(int id) const { return _current->hasVariable(id) || _backed->hasVariable(id); } - bool VariableProxy::hasVariable(int id, int idx) { + bool VariableProxy::hasVariable(int id, int idx) const { return _current->hasVariable(id, idx) || _backed->hasVariable(id, idx); } - bool VariableProxy::hasVariable(std::pair& pair) { + bool VariableProxy::hasVariable(const std::pair& pair) const { return _current->hasVariable(pair) || _backed->hasVariable(pair); } - void VariableProxy::dropVariable(std::pair &pair) { + void VariableProxy::dropVariable(const std::pair &pair) { dropVariable(pair.first, pair.second); } @@ -89,11 +89,11 @@ namespace sd { } - std::vector VariableProxy::getVariables() { - std::vector result; + std::vector> VariableProxy::variables() const { + std::vector> result; - auto b = _backed->getVariables(); - auto c = _current->getVariables(); + auto b = _backed->variables(); + auto c = _current->variables(); for (auto v: b) result.emplace_back(v); @@ -105,60 +105,60 @@ namespace sd { } - bool VariableProxy::hasVariable(const std::string &symbol) { + bool VariableProxy::hasVariable(const std::string &symbol) const { return _current->hasVariable(symbol) || _backed->hasVariable(symbol); } - Variable *VariableProxy::getVariable(int id) { + std::shared_ptr VariableProxy::getVariable(int id) const { if (_current->hasVariable(id)) return _current->getVariable(id); if (_backed->hasVariable(id)) return _backed->getVariable(id); - nd4j_printf("Unable to get Variable to proxy: [%i]\n", id); + nd4j_printf("Unable to get Variable from proxy: [%i]\n", id); throw std::runtime_error("Bad arguments"); } - Variable *VariableProxy::getVariable(int id, int idx) { + std::shared_ptr VariableProxy::getVariable(int id, int idx) const { if (_current->hasVariable(id, idx)) return _current->getVariable(id, idx); if (_backed->hasVariable(id, idx)) return _backed->getVariable(id, idx); - nd4j_printf("Unable to get Variable to proxy: [%i:%i]\n", id, idx); + nd4j_printf("Unable to get Variable from proxy: [%i:%i]\n", id, idx); throw std::runtime_error("Bad arguments"); } - Variable *VariableProxy::getVariable(std::pair& pair) { + std::shared_ptr VariableProxy::getVariable(const std::pair& pair) const { if (_current->hasVariable(pair)) return _current->getVariable(pair); if (_backed->hasVariable(pair)) return _backed->getVariable(pair); - nd4j_printf("Unable to get Variable to proxy: [%i:%i]\n", pair.first, pair.second); + nd4j_printf("Unable to get Variable from proxy: [%i:%i]\n", pair.first, pair.second); throw std::runtime_error("Bad arguments"); } - Variable *VariableProxy::getVariable(const std::string &symbol) { + std::shared_ptr VariableProxy::getVariable(const std::string &symbol) const { if (_current->hasVariable(symbol)) return _current->getVariable(symbol); if (_backed->hasVariable(symbol)) return _backed->getVariable(symbol); - nd4j_printf("Unable to get Variable to proxy: [%s]\n", symbol.c_str()); + nd4j_printf("Unable to get Variable from proxy: [%s]\n", symbol.c_str()); throw std::runtime_error("Bad arguments"); } - void VariableProxy::replaceVariable(Variable *variable) { + void VariableProxy::replaceVariable(std::shared_ptr variable) { if (!variable->getName().empty()) { // if variable has name defined - we should resolve it via backing var space if (_backed->hasVariable(variable->getName())) { @@ -171,104 +171,74 @@ namespace sd { _current->replaceVariable(variable); } - - Variable* VariableProxy::putVariable(std::pair& pair, NDArray *array) { - return _current->putVariable(pair, array); - } - - - void VariableProxy::putVariable(std::pair& pair, Variable *variable) { - _current->putVariable(pair, variable); - } - - - void VariableProxy::putVariable(int id, Variable *variable) { - _current->putVariable(id, variable); + std::shared_ptr + VariableProxy::putVariable(const std::string &name, int id, int idx, const NDArray &array) { + return _current->putVariable(name, id, idx, array); } - - void VariableProxy::putVariable(int id, NDArray *array) { - _current->putVariable(id, array); + void VariableProxy::putOutputVariable(std::shared_ptr variable) { + _current->putOutputVariable(variable); } - void VariableProxy::putVariable(int id, int idx, const NDArray &array) { - _current->putVariable(id, idx, array); - } - - Variable* VariableProxy::putVariable(int id, int idx, NDArray *array) { - return _current->putVariable(id, idx, array); + std::shared_ptr VariableProxy::putVariable(const std::pair& pair, const NDArray &array) { + return _current->putVariable(pair, array); } - void VariableProxy::putVariable(int id, int idx, Variable *array) { - _current->putVariable(id, idx, array); + void VariableProxy::putVariable(const std::pair& pair, std::shared_ptr variable) { + _current->putVariable(pair, variable); } - void VariableProxy::trackList(sd::NDArrayList* list) { - _current->trackList(list); + void VariableProxy::putVariable(int id, std::shared_ptr variable) { + _current->putVariable(id, variable); } - Stash* VariableProxy::getStash() { - return _current->getStash(); + std::shared_ptr VariableProxy::putVariable(int id, const NDArray &array) { + return _current->putVariable(id, array); } - - void VariableProxy::setFlowPath(FlowPath* timers) { - _current->setFlowPath(timers); + std::shared_ptr VariableProxy::putVariable(int id, int idx, const NDArray &array) { + return _current->putVariable(id, idx, array); } - - FlowPath* VariableProxy::flowPath() { - return _current->flowPath(); + void VariableProxy::putVariable(const std::string& name, int id, int idx, std::shared_ptr array) { + _current->putVariable(name, id, idx, array); } - - void VariableProxy::putOutputVariable(Variable *variable) { - _current->putOutputVariable(variable); + Stash* VariableProxy::stash() const { + return _current->stash(); } - - Nd4jLong VariableProxy::externalMemory() { + Nd4jLong VariableProxy::externalMemory() const { return _backed->externalMemory() + _current->externalMemory(); } - Nd4jLong VariableProxy::internalMemory() { + Nd4jLong VariableProxy::internalMemory() const { return _backed->internalMemory() + _current->internalMemory(); } - Nd4jLong VariableProxy::totalMemory() { + Nd4jLong VariableProxy::totalMemory() const { return _backed->totalMemory() + _current->totalMemory(); } - int VariableProxy::externalEntries() { + int VariableProxy::externalEntries() const { return _backed->externalEntries() + _current->externalEntries(); } - int VariableProxy::internalEntries() { + int VariableProxy::internalEntries() const { return _backed->internalEntries() + _current->internalEntries(); } - int VariableProxy::totalEntries() { + int VariableProxy::totalEntries() const { return _backed->totalEntries() + _current->totalEntries(); } - - - VariableSpace* VariableProxy::clone() { - auto clone = new VariableProxy(_backed); - - delete clone->_current; - clone->_current = _current->clone(); - - return clone; - } - VariableSpace& VariableProxy::operator=(const VariableSpace& other) { if (this == &other) return *this; @@ -276,11 +246,6 @@ namespace sd { nd4j_printf("VariableProxy = not implemented\n",""); return *this; - } - - - sd::memory::Workspace * VariableProxy::workspace() { - return _workspace; } } } diff --git a/libnd4j/include/graph/impl/VariableSpace.cpp b/libnd4j/include/graph/impl/VariableSpace.cpp index 024768bcf689..b5eacfe32507 100644 --- a/libnd4j/include/graph/impl/VariableSpace.cpp +++ b/libnd4j/include/graph/impl/VariableSpace.cpp @@ -23,86 +23,43 @@ namespace sd { namespace graph { - std::vector * VariableSpace::getExternalVariables() { - return &_external; + Stash* VariableSpace::stash() const { + return const_cast(&_stash); } - Stash* VariableSpace::getStash() { - return &_stash; - } - - VariableSpace* VariableSpace::clone() { - auto result = new VariableSpace(); - - for (auto const& x : _paired) { - std::pair pair(x.first.first, x.first.second); - - Variable* clonedVar = x.second->clone(); - - result->injectVariable(pair, clonedVar); - } - - return result; - } - - void VariableSpace::setWorkspace(sd::memory::Workspace *workspace) { - //_workspace = *workspace; - } - - - VariableSpace* VariableSpace::asT() { - auto result = new VariableSpace(); - - for (auto const& x : _paired) { - std::pair pair(x.first.first, x.first.second); - - //Variable* clonedVar = x.second->template asT(); - - //result->injectVariable(pair, clonedVar); - } - - return result; - } - - - void VariableSpace::injectVariable(std::pair &pair, Variable* variable) { + void VariableSpace::injectVariable(const std::pair &pair, std::shared_ptr variable) { if (pair.second == 0) { - if (pair.first < 0) - this->_variables[pair.first] = variable; - else - this->_temporary[pair.first] = variable; + this->_variables[pair.first] = variable; } if (!variable->getName().empty()) this->_symbolic[variable->getName()] = variable; this->_paired[pair] = variable; - - this->_handles->push_back(variable); } - std::vector * VariableSpace::getPlaceholders() { - return &_placeholders; + const std::vector>& VariableSpace::placeholders() const { + return _placeholders; } - int VariableSpace ::numberOfPlaceholders() { + int VariableSpace::numberOfPlaceholders() const { return _placeholders.size(); } - bool VariableSpace::hasVariable(const std::string &symbol) { - return _symbolic.count(symbol) == 1; + bool VariableSpace::hasVariable(const std::string &symbol) const { + return _symbolic.count(symbol) > 0; } - Variable * VariableSpace::getVariable(const std::string &symbol) { + std::shared_ptr VariableSpace::getVariable(const std::string &symbol) const { return _symbolic.at(symbol); } - bool VariableSpace::hasVariable(int id, int index) { + bool VariableSpace::hasVariable(int id, int index) const { std::pair pair(id, index); return hasVariable(pair); } - bool VariableSpace::hasExternalVariable(int id) { + bool VariableSpace::hasExternalVariable(int id) const { if (!hasVariable(id)) return false; @@ -110,7 +67,7 @@ namespace sd { return var->isExternal(); } - bool VariableSpace::hasExternalVariable(std::pair& pair) { + bool VariableSpace::hasExternalVariable(const std::pair& pair) const { if (!hasVariable(pair)) return false; @@ -118,7 +75,7 @@ namespace sd { return var->isExternal(); } - bool VariableSpace::hasExternalVariable(const std::string &symbol) { + bool VariableSpace::hasExternalVariable(const std::string &symbol) const { if (!hasVariable(symbol)) return false; @@ -126,12 +83,12 @@ namespace sd { return var->isExternal(); } - Variable * VariableSpace::getVariable(int id, int index) { + std::shared_ptr VariableSpace::getVariable(int id, int index) const { std::pair pair(id, index); return getVariable(pair); } - Variable * VariableSpace::getVariable(std::pair& pair) { + std::shared_ptr VariableSpace::getVariable(const std::pair& pair) const { if (pair.first < 0) return getVariable(pair.first); else @@ -141,32 +98,32 @@ namespace sd { throw std::runtime_error("Unknown variable requested"); } - bool VariableSpace::hasVariable(int id) { - return _variables.count(id) == 1 || _temporary.count(id) == 1; + bool VariableSpace::hasVariable(int id) const { + return _variables.count(id) == 1; } - bool VariableSpace::hasVariable(std::pair& id) { + bool VariableSpace::hasVariable(const std::pair& id) const { return _paired.count(id) > 0; } - void VariableSpace::putOutputVariable(Variable *variable) { + void VariableSpace::putOutputVariable(std::shared_ptr variable) { //putVariable(_auto_counter--, variable); putVariable(variable->id(), variable); } - int VariableSpace::externalEntries() { + int VariableSpace::externalEntries() const { return _external.size(); } - int VariableSpace::internalEntries() { + int VariableSpace::internalEntries() const { return _internal.size(); } - int VariableSpace::totalEntries() { + int VariableSpace::totalEntries() const { return externalEntries() + internalEntries(); } - Nd4jLong VariableSpace::externalMemory() { + Nd4jLong VariableSpace::externalMemory() const { Nd4jLong size = 0; for (auto n: _external) { size += n->getNDArray()->memoryFootprint(); @@ -175,8 +132,8 @@ namespace sd { return size; } - std::vector VariableSpace::getVariables() { - std::vector result; + std::vector> VariableSpace::variables() const { + std::vector> result; for (auto v: _internal) result.emplace_back(v); @@ -187,7 +144,7 @@ namespace sd { return result; } - Nd4jLong VariableSpace::internalMemory() { + Nd4jLong VariableSpace::internalMemory() const { Nd4jLong size = 0; for (auto n: _internal) { size += n->getNDArray()->memoryFootprint(); @@ -196,36 +153,50 @@ namespace sd { return size; } - Nd4jLong VariableSpace::totalMemory() { + Nd4jLong VariableSpace::totalMemory() const { return externalMemory() + internalMemory(); } - Variable* VariableSpace::putVariable(std::pair& pair, NDArray *array) { - auto variable = new Variable(array, nullptr, pair.first, pair.second); + std::shared_ptr VariableSpace::putVariable(int id, int idx, const std::shared_ptr &array) { + auto variable = std::make_shared(array, "", id, idx); + this->putVariable({id, idx}, variable); + return variable; + } + + std::shared_ptr + VariableSpace::putVariable(const std::string &name, int id, int idx, const NDArray &array) { + return std::shared_ptr(); + } + + void VariableSpace::dropVariable(const std::string &pair) { + throw std::runtime_error("VariableSpace::dropVariable - not implemented yet"); + } + + + std::shared_ptr VariableSpace::putVariable(const std::pair& pair, const NDArray &array) { + auto variable = std::make_shared(array, nullptr, pair.first, pair.second); this->putVariable(pair, variable); return variable; } - Variable* VariableSpace::putVariable(int node, int idx, NDArray *array) { + std::shared_ptr VariableSpace::putVariable(int node, int idx, const NDArray &array) { std::pair pair(node, idx); return this->putVariable(pair, array); } - void VariableSpace::putVariable(int node, int idx, Variable *variable) { + void VariableSpace::putVariable(const std::string& name, int node, int idx, std::shared_ptr variable) { std::pair pair(node, idx); + variable->setName(name); this->putVariable(pair, variable); } - void VariableSpace::silentPutVariable(std::pair& pair, Variable *variable) { - _varmap.lock(); + void VariableSpace::silentPutVariable(const std::pair& pair, std::shared_ptr variable) { + std::lock_guard lock(_varmap); - //std::pair, Variable *> p(pair, variable); _paired[pair] = variable; - - _varmap.unlock(); } - void VariableSpace::putVariable(std::pair& pair, Variable *variable) { + void VariableSpace::putVariable(const std::pair& pair, std::shared_ptr variable) { silentPutVariable(pair, variable); if (variable->isPlaceholder()) @@ -238,63 +209,41 @@ namespace sd { if (!variable->getName().empty()) { _symbolic[variable->getName()] = variable; } - - _varmap.lock(); - - _handles->push_back(variable); - - _varmap.unlock(); } } - void VariableSpace::trackList(sd::NDArrayList* list) { - _lists.emplace_back(list); - } - void VariableSpace::putVariable(int id, Variable *variable) { + void VariableSpace::putVariable(int id, std::shared_ptr variable) { // we don't want to add variables more then once - if (_variables.count(id) > 0 || _temporary.count(id) > 0) { - auto local = id < 0 ? _variables.at(id) : _temporary.at(id); - - if (!local->hasNDArray() && variable->hasNDArray()) { - local->setNDArray(variable->getNDArray()); - - // we're inheriting this from Variable - local->markReadOnly(variable->isReadOnly()); - local->markRemovable(variable->isRemovable()); - } - - return; + if (_variables.count(id) > 0) { + throw std::runtime_error("VariableSpace::putVariable - duplicate found"); } - _varmap.lock(); - - _handles->emplace_back(variable); - - if (_auto_counter >= id) - _auto_counter = id - 1; + { + std::lock_guard lock(_varmap); - variable->setId(id); + if (_auto_counter >= id) + _auto_counter = id - 1; - if (!variable->getName().empty()) { - //std::pair pair(*(variable->getName()), variable); - _symbolic[variable->getName()] = variable; - } + variable->setId(id); - // we have special list for external variables to ensure graph completeness + if (!variable->getName().empty()) { + //std::pair pair(*(variable->getName()), variable); + _symbolic[variable->getName()] = variable; + } - if (id < 0) { - //if (variable->isExternal()) - _external.push_back(variable); + // we have special list for external variables to ensure graph completeness - _variables[id] = variable; - } else { - _internal.push_back(variable); + if (id < 0) { + //if (variable->isExternal()) + _external.push_back(variable); - _temporary[id] = variable; + _variables[id] = variable; + } else { + _internal.push_back(variable); + } } - _varmap.unlock(); std::pair pair(id, 0); if (!hasVariable(pair)) { @@ -305,91 +254,96 @@ namespace sd { } } - void VariableSpace::putVariable(int id, int idx, const NDArray &array) { - auto *var = new Variable(const_cast(&array), "", id, idx); - var->markRemovable(false); - var->markReadOnly(true); - - // let's see if this op needs - bool d = this->hasVariable(id, idx); - + std::shared_ptr VariableSpace::putVariable(int id, const NDArray &array) { + auto var = std::make_shared(array, "", id, 0); this->putVariable(id, var); - - // if var for this nodeid already exists - we'll just delete variable - if (d) - delete var; } - void VariableSpace::putVariable(int id, NDArray *array) { - auto *var = new Variable(array); - this->putVariable(id, var); + std::shared_ptr VariableSpace::getVariable(int id) const { + return _variables.at(id); } - Variable * VariableSpace::getVariable(int id) { - if (id < 0) { - return _variables.at(id); - } else { - return _temporary.at(id); - } + VariableSpace::~VariableSpace() { + // } - LaunchContext* VariableSpace::launchContext() { - return LaunchContext::defaultContext(); + VariableSpace::VariableSpace(const VariableSpace &other) { + _stash = other._stash; + + _paired = other._paired; + _symbolic = other._symbolic; + _variables = other._variables; + + _external = other._external; + _internal = other._internal; + + _lists = other._lists; + _placeholders = other._placeholders; + + + _auto_counter = other._auto_counter; } - std::vector* VariableSpace::handles() { - return _handles; + VariableSpace::VariableSpace(VariableSpace &&other) { + _stash = std::move(other._stash); + + _paired = std::move(other._paired); + _symbolic = std::move(other._symbolic); + _variables = std::move(other._variables); + + _external = std::move(other._external); + _internal = std::move(other._internal); + + _lists = std::move(other._lists); + _placeholders = std::move(other._placeholders); + + + _auto_counter = other._auto_counter; } -/* - * FIXME: this thing have nice chances to become backend-specific! - */ - VariableSpace::~VariableSpace() { - // loop through variables and release them - for (auto p: *_handles) { - delete p; - } + VariableSpace& VariableSpace::operator=(VariableSpace &&other) { + if (this == &other) return *this; + + _stash = std::move(other._stash); + + _paired = std::move(other._paired); + _symbolic = std::move(other._symbolic); + _variables = std::move(other._variables); + + _external = std::move(other._external); + _internal = std::move(other._internal); - delete _handles; + _lists = std::move(other._lists); + _placeholders = std::move(other._placeholders); - //_internal.clear(); - //_external.clear(); - //_temporary.clear(); - //nd4j_printf("Number of NDArrayLists in this space: [%i]\n", _lists.size()) - for (auto p: _lists) - delete p; + _auto_counter = other._auto_counter; - _lists.clear(); + return *this; } VariableSpace& VariableSpace::operator=(const VariableSpace& other) { if (this == &other) return *this; - for (auto const& x : other._paired) { - std::pair pair(x.first.first, x.first.second); + _stash = other._stash; - Variable* clonedVar = x.second->clone(); + _paired = other._paired; + _symbolic = other._symbolic; + _variables = other._variables; - if (pair.second == 0) { - if (pair.first < 0) - this->_variables[pair.first] = clonedVar; - else - this->_temporary[pair.first] = clonedVar; - } + _external = other._external; + _internal = other._internal; - if (!clonedVar->getName().empty()) - this->_symbolic[clonedVar->getName()] = clonedVar; + _lists = other._lists; + _placeholders = other._placeholders; - this->_paired[pair] = clonedVar; - this->_handles->push_back(clonedVar); - } + _auto_counter = other._auto_counter; return *this; } - void VariableSpace::replaceVariable(Variable *variable) { + void VariableSpace::replaceVariable(std::shared_ptr variable) { bool replaced = false; // trying name first if (!variable->getName().empty()) { @@ -398,7 +352,8 @@ namespace sd { nd4j_printf("Replacing by name: [%s]\n", variable->getName().c_str()); auto vs = getVariable(variable->getName()); dropVariable(vs->id(), vs->index()); - putVariable(vs->id(), vs->index(), variable); + + putVariable({vs->id(), vs->index()}, variable); //delete vs; replaced = true; } @@ -408,7 +363,7 @@ namespace sd { nd4j_printf("Replacing by id: [%i:%i]\n", variable->id(), variable->index()); auto vs = getVariable(variable->id(), variable->index()); dropVariable(variable->id(), variable->index()); - putVariable(vs->id(), vs->index(), variable); + putVariable({vs->id(), vs->index()}, variable); //delete vs; replaced = true; } @@ -416,11 +371,11 @@ namespace sd { if (!replaced) { nd4j_printf("wasn't able to replace variable, putting\n", ""); - putVariable(variable->id(), variable->index(), variable); + putVariable({variable->id(), variable->index()}, variable); } } - void VariableSpace::dropVariable(std::pair &pair) { + void VariableSpace::dropVariable(const std::pair &pair) { dropVariable(pair.first, pair.second); } @@ -428,17 +383,8 @@ namespace sd { } - - void VariableSpace::setFlowPath(FlowPath* flow) { - _flow = flow; - } - - FlowPath* VariableSpace::flowPath() { - return _flow; - } - VariableSpace::VariableSpace() { - _handles = new std::vector; + } } } \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicExit.cpp b/libnd4j/include/graph/logic/impl/LogicExit.cpp index 92aca938306c..dfabcb8ac861 100644 --- a/libnd4j/include/graph/logic/impl/LogicExit.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExit.cpp @@ -26,7 +26,8 @@ namespace sd { Nd4jStatus LogicExit::processNode(Graph *graph, Node *node) { // this op is basically no-op // we just know it exists - + throw std::runtime_error("LogicExit::processNode - Not implemented yet"); +/* auto __variableSpace = graph->variableSpace(); auto __flowPath = __variableSpace->flowPath(); @@ -42,6 +43,7 @@ namespace sd { __variableSpace->getVariable(pair0)->markRemovable(false); return ND4J_STATUS_OK; + */ } } } \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp index 35273ead76ed..8c464f231765 100644 --- a/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp +++ b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp @@ -24,6 +24,8 @@ namespace sd { namespace graph { Nd4jStatus LogicLoopCond::processNode(Graph *graph, Node *node) { + throw std::runtime_error("LogicLoopCond::processNode - Not implemented yet"); + /* auto __variableSpace = graph->variableSpace(); auto __flowPath = __variableSpace->flowPath(); @@ -49,6 +51,7 @@ namespace sd { } return ND4J_STATUS_OK; + */ } } } \ No newline at end of file diff --git a/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp b/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp index 106ac96fbdf7..001a8517a344 100644 --- a/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp +++ b/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp @@ -24,9 +24,11 @@ namespace sd { namespace graph { GraphProfile *GraphProfilingHelper::profile(Graph *graph, int iterations) { + if (1 > 0) + throw std::runtime_error("GraphProfilingHelper::profile - Not implemented yet"); // saving original workspace - auto varSpace = graph->variableSpace()->clone(); + //auto varSpace = graph->variableSpace(); // printing out graph structure // graph->printOut(); @@ -35,19 +37,19 @@ namespace sd { for (int e = 0; e < iterations; e++) { FlowPath fp; - auto _vs = varSpace->clone(); + //auto _vs = varSpace->clone(); //_vs->workspace()->expandTo(100000); - _vs->setFlowPath(&fp); + //_vs->setFlowPath(&fp); //GraphExecutioner::execute(graph, _vs); - delete _vs; + //delete _vs; } auto profile = new GraphProfile(); for (int e = 0; e < iterations; e++) { FlowPath fp; - +/* // we're always starting from "fresh" varspace here auto _vs = varSpace->clone(); //_vs->workspace()->expandTo(100000); @@ -61,10 +63,9 @@ namespace sd { profile->merge(p); delete _vs; + */ } - delete varSpace; - return profile; } } diff --git a/libnd4j/include/helpers/OpBenchmark.h b/libnd4j/include/helpers/OpBenchmark.h index 8e665cb70089..4a4996929926 100644 --- a/libnd4j/include/helpers/OpBenchmark.h +++ b/libnd4j/include/helpers/OpBenchmark.h @@ -32,30 +32,28 @@ namespace sd { protected: int _opNum = 0; std::string _testName; - NDArray *_x = nullptr; - NDArray *_y = nullptr; - NDArray *_z = nullptr; + NDArray _x; + NDArray _y; + NDArray _z; std::vector _axis; public: OpBenchmark() = default; - OpBenchmark(std::string name, NDArray *x, NDArray *y, NDArray *z); - OpBenchmark(std::string name, NDArray *x, NDArray *z); - OpBenchmark(std::string name, NDArray *x, NDArray *z, std::initializer_list axis); - OpBenchmark(std::string name, NDArray *x, NDArray *z, std::vector axis); - OpBenchmark(std::string name, NDArray *x, NDArray *y, NDArray *z, std::initializer_list axis); - OpBenchmark(std::string name, NDArray *x, NDArray *y, NDArray *z, std::vector axis); + OpBenchmark(const std::string& name, const NDArray &x, const NDArray &y, const NDArray &z); + OpBenchmark(const std::string& name, const NDArray &x, const NDArray &z); + OpBenchmark(const std::string& name, const NDArray &x, const NDArray &z, const std::vector &axis); + OpBenchmark(const std::string& name, const NDArray &x, const NDArray &y, const NDArray &z, const std::vector &axis); void setOpNum(int opNum); - void setTestName(std::string testName); - void setX(NDArray *array); - void setY(NDArray *array); - void setZ(NDArray *array); + void setTestName(const std::string &testName); + void setX(const NDArray &array); + void setY(const NDArray &array); + void setZ(const NDArray &array); void setAxis(std::vector axis); void setAxis(std::initializer_list axis); NDArray& x(); - int opNum(); - std::string testName(); + int opNum() const; + const std::string& testName() const; std::vector getAxis(); virtual std::string extra(); diff --git a/libnd4j/include/helpers/ShapeUtils.h b/libnd4j/include/helpers/ShapeUtils.h index 2ff4d25f1293..380c20c89b85 100644 --- a/libnd4j/include/helpers/ShapeUtils.h +++ b/libnd4j/include/helpers/ShapeUtils.h @@ -93,10 +93,12 @@ namespace sd { // returns shape part of shapeInfo as std::vector static std::vector pullShapeFromShapeInfo(Nd4jLong *shapeInfo); + static std::string shapeAsString(const NDArray &array); static std::string shapeAsString(const NDArray* array); static std::string shapeAsString(const std::vector& shape); static std::string shapeAsString(const Nd4jLong* shapeInfo); static std::string shapeAsString(const int rank, const Nd4jLong* shapeInfo); + static std::string strideAsString(const NDArray& array); static std::string strideAsString(const NDArray* array); static std::string shapeInfoAsString(const Nd4jLong* shapeInfo); diff --git a/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h b/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h index 2214283eae46..2429df5441ec 100644 --- a/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h +++ b/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h @@ -30,47 +30,26 @@ namespace sd { // } - BroadcastBenchmark(broadcast::Ops op, std::string testName, NDArray *x, NDArray *y, NDArray *z, std::vector axis) : OpBenchmark(testName, x, y, z, axis) { + BroadcastBenchmark(broadcast::Ops op, const std::string &testName, const NDArray &x, const NDArray &y, const NDArray &z, const std::vector &axis) : OpBenchmark(testName, x, y, z, axis) { _opNum = (int) op; } + - BroadcastBenchmark(broadcast::Ops op, std::string testName, NDArray *x, NDArray *y, NDArray *z, std::initializer_list axis) : OpBenchmark(testName, x, y, z, axis) { - _opNum = (int) op; - } - - BroadcastBenchmark(broadcast::Ops op, std::string name, std::vector axis) : OpBenchmark() { - _opNum = (int) op; - _testName = name; - _axis = axis; - } - - BroadcastBenchmark(broadcast::Ops op, std::string name, std::initializer_list axis) : OpBenchmark() { + BroadcastBenchmark(broadcast::Ops op, const std::string &name, const std::vector &axis) : OpBenchmark() { _opNum = (int) op; _testName = name; _axis = axis; } ~BroadcastBenchmark(){ - if (_x != _y && _x != _z && _y != _z) { - delete _x; - delete _y; - delete _z; - } else if (_x == _y && _x == _z) { - delete _x; - } else if (_x == _z) { - delete _x; - delete _y; - } else if (_y == _z) { - delete _x; - delete _y; - } + // } void executeOnce() override { PointersManager manager(LaunchContext::defaultContext(), "BroadcastBM"); - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(_x->shapeInfo(), _axis); - auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(_z->shapeInfo(), _axis); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions(_x.shapeInfo(), _axis); + auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(_z.shapeInfo(), _axis); auto tadOnlyShapeInfo = Environment::getInstance()->isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); auto tadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); @@ -78,7 +57,7 @@ void executeOnce() override { auto tadOnlyShapeInfoZ = Environment::getInstance()->isCPU() ? packZ.primaryShapeInfo() : packZ.specialShapeInfo(); auto tadOffsetsZ = Environment::getInstance()->isCPU() ? packZ.primaryOffsets() : packZ.specialOffsets(); - NativeOpExecutioner::execBroadcast(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), _y->buffer(), _y->shapeInfo(), _y->specialBuffer(), _y->specialShapeInfo(), _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size(), + NativeOpExecutioner::execBroadcast(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _y.buffer(), _y.shapeInfo(), _y.specialBuffer(), _y.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr, _axis.size(), /*Nd4jLong **/ tadOnlyShapeInfo, /*Nd4jLong */ tadOffsets, /*Nd4jLong */ tadOnlyShapeInfoZ, /*Nd4jLong */ tadOffsetsZ); manager.synchronize(); @@ -106,11 +85,11 @@ void executeOnce() override { std::string orders() override { std::string result; - result += _x->ordering(); + result += _x.ordering(); result += "/"; - result += _y->ordering(); + result += _y.ordering(); result += "/"; - result += _z == nullptr ? _x->ordering() : _z->ordering(); + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); return result; } @@ -120,7 +99,7 @@ void executeOnce() override { result += "/"; result += ShapeUtils::strideAsString(_y); result += "/"; - result += _z == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); + result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); return result; } diff --git a/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h b/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h index f570e676d70e..fb414a837520 100644 --- a/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h +++ b/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h @@ -50,13 +50,13 @@ namespace sd { std::string orders() override { if(_context != nullptr && _context->isFastPath()){ - std::vector& ins = _context->fastpath_in(); + auto& ins = _context->fastpath_in(); std::string s; for( int i=0; i 0){ s += "/"; } - s += ShapeUtils::strideAsString(_context->getNDArray(i)); + s += ShapeUtils::strideAsString(_context->getNDArray(i).get()); } return s; } @@ -65,13 +65,13 @@ namespace sd { std::string strides() override { if (_context != nullptr && _context->isFastPath()) { - std::vector& ins = _context->fastpath_in(); + auto& ins = _context->fastpath_in(); std::string s(""); for( int i=0; i 0){ s += "/"; } - s += ShapeUtils::strideAsString(_context->getNDArray(i)); + s += ShapeUtils::strideAsString(_context->getNDArray(i).get()); } return s; } else @@ -94,13 +94,13 @@ namespace sd { std::string shape() override { if (_context != nullptr && _context->isFastPath()) { - std::vector& ins = _context->fastpath_in(); + auto& ins = _context->fastpath_in(); std::string s; for( int i=0; i 0){ s += "/"; } - s += ShapeUtils::shapeAsString(_context->getNDArray(i)); + s += ShapeUtils::shapeAsString(_context->getNDArray(i).get()); } return s; } else @@ -109,7 +109,7 @@ namespace sd { std::string dataType() override { if (_context != nullptr && _context->isFastPath()){ - std::vector& ins = _context->fastpath_in(); + auto& ins = _context->fastpath_in(); std::string s; for( int i=0; i 0){ diff --git a/libnd4j/include/helpers/benchmark/MatrixBenchmark.h b/libnd4j/include/helpers/benchmark/MatrixBenchmark.h index f307b41287da..6d040b1fd6d6 100644 --- a/libnd4j/include/helpers/benchmark/MatrixBenchmark.h +++ b/libnd4j/include/helpers/benchmark/MatrixBenchmark.h @@ -36,14 +36,14 @@ namespace sd { // } - MatrixBenchmark(float alpha, float beta, std::string testName, NDArray *x, NDArray *y, NDArray *z) : OpBenchmark(testName, x, y, z) { + MatrixBenchmark(float alpha, float beta, const std::string &testName, const NDArray &x, const NDArray &y, const NDArray &z) : OpBenchmark(testName, x, y, z) { _alpha = alpha; _beta = beta; _tA = false; _tB = false; } - MatrixBenchmark(float alpha, float beta, bool tA, bool tB, std::string name) : OpBenchmark() { + MatrixBenchmark(float alpha, float beta, bool tA, bool tB, const std::string &name) : OpBenchmark() { _testName = name; _alpha = alpha; _beta = beta; @@ -52,26 +52,14 @@ namespace sd { } ~MatrixBenchmark(){ - if (_x != _y && _x != _z && _y != _z) { - delete _x; - delete _y; - delete _z; - } else if (_x == _y && _x == _z) { - delete _x; - } else if (_x == _z) { - delete _x; - delete _y; - } else if (_y == _z) { - delete _x; - delete _y; - } + // } void executeOnce() override { - auto xT = (_tA ? _x->transpose() : *_x); - auto yT = (_tB ? _y->transpose() : *_y); + auto xT = (_tA ? _x.transpose() : _x); + auto yT = (_tB ? _y.transpose() : _y); - MmulHelper::mmul(&xT, &yT, _z, _alpha, _beta); + MmulHelper::mmul(&xT, &yT, &_z, _alpha, _beta); } std::string axis() override { @@ -84,11 +72,11 @@ namespace sd { std::string orders() override { std::string result; - result += _x->ordering(); + result += _x.ordering(); result += "/"; - result += _y->ordering(); + result += _y.ordering(); result += "/"; - result += _z == nullptr ? _x->ordering() : _z->ordering(); + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); return result; } @@ -98,7 +86,7 @@ namespace sd { result += "/"; result += ShapeUtils::strideAsString(_y); result += "/"; - result += _z == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); + result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); return result; } @@ -108,7 +96,7 @@ namespace sd { result += "x"; result += ShapeUtils::shapeAsString(_y); result += "="; - result += _z == nullptr ? "" : ShapeUtils::shapeAsString(_z); + result += _z.shapeInfo() == nullptr ? "" : ShapeUtils::shapeAsString(_z); return result; } diff --git a/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h b/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h index 76d08db760e1..5cc686238c94 100644 --- a/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h +++ b/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h @@ -32,7 +32,7 @@ namespace sd { // } - PairwiseBenchmark(pairwise::Ops op, std::string testName, NDArray *x, NDArray *y, NDArray *z) : OpBenchmark(testName, x, y, z) { + PairwiseBenchmark(pairwise::Ops op, const std::string &testName, const NDArray &x, const NDArray &y, const NDArray &z) : OpBenchmark(testName, x, y, z) { _opNum = (int) op; } @@ -42,25 +42,13 @@ namespace sd { } ~PairwiseBenchmark(){ - if (_x != _y && _x != _z && _y != _z) { - delete _x; - delete _y; - delete _z; - } else if (_x == _y && _x == _z) { - delete _x; - } else if (_x == _z) { - delete _x; - delete _y; - } else if (_y == _z) { - delete _x; - delete _y; - } + // } void executeOnce() override { PointersManager manager(LaunchContext::defaultContext(), "PairwiseBM"); - NativeOpExecutioner::execPairwiseTransform(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), _y->buffer(), _y->shapeInfo(), _y->specialBuffer(), _y->specialShapeInfo(), _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr); + NativeOpExecutioner::execPairwiseTransform(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _y.buffer(), _y.shapeInfo(), _y.specialBuffer(), _y.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr); manager.synchronize(); } @@ -71,21 +59,21 @@ namespace sd { std::string inplace() override { std::string result; - result += (_x == _y ? "x==y" : "x!=y"); + result += (_x.platformBuffer() == _y.platformBuffer() ? "x==y" : "x!=y"); result += "/"; - result += (_x == _z ? "x==z" : "x!=z"); + result += (_x.platformBuffer() == _z.platformBuffer() ? "x==z" : "x!=z"); result += "/"; - result += (_y == _z ? "y==z" : "y!=z"); + result += (_y.platformBuffer() == _z.platformBuffer() ? "y==z" : "y!=z"); return result; } std::string orders() override { std::string result; - result += _x->ordering(); + result += _x.ordering(); result += "/"; - result += _y->ordering(); + result += _y.ordering(); result += "/"; - result += _z == nullptr ? _x->ordering() : _z->ordering(); + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); return result; } @@ -95,7 +83,7 @@ namespace sd { result += "/"; result += ShapeUtils::strideAsString(_y); result += "/"; - result += _z == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); + result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); return result; } diff --git a/libnd4j/include/helpers/benchmark/ReductionBenchmark.h b/libnd4j/include/helpers/benchmark/ReductionBenchmark.h index 78c41171541d..491d0254859f 100644 --- a/libnd4j/include/helpers/benchmark/ReductionBenchmark.h +++ b/libnd4j/include/helpers/benchmark/ReductionBenchmark.h @@ -36,12 +36,12 @@ namespace sd { // } - ReductionBenchmark(reduce::FloatOps op, std::string testName, NDArray *x, NDArray *z, std::initializer_list axis) : OpBenchmark(testName, x, z, axis) { + ReductionBenchmark(reduce::FloatOps op, const std::string &testName, const NDArray &x, const NDArray &z, std::initializer_list axis) : OpBenchmark(testName, x, z, axis) { _opNum = (int) op; _opType = 0; } - ReductionBenchmark(reduce::SameOps op, std::string testName, NDArray *x, NDArray *z, std::initializer_list axis) : OpBenchmark(testName, x, z, axis) { + ReductionBenchmark(reduce::SameOps op, const std::string &testName, const NDArray &x, const NDArray &z, std::initializer_list axis) : OpBenchmark(testName, x, z, axis) { _opNum = (int) op; _opType = 1; } @@ -52,7 +52,7 @@ namespace sd { _opType = 0; } - ReductionBenchmark(reduce::FloatOps op, std::string testName) : OpBenchmark() { + ReductionBenchmark(reduce::FloatOps op, const std::string &testName) : OpBenchmark() { _opNum = (int) op; _opType = 0; _testName = testName; @@ -63,18 +63,18 @@ namespace sd { _opType = 1; } - ReductionBenchmark(reduce::SameOps op, std::string testName) : OpBenchmark() { + ReductionBenchmark(reduce::SameOps op, const std::string &testName) : OpBenchmark() { _opNum = (int) op; _opType = 1; _testName = testName; } - ReductionBenchmark(reduce::FloatOps op, std::string testName, NDArray *x, NDArray *z, std::vector axis) : OpBenchmark(testName ,x, z, axis) { + ReductionBenchmark(reduce::FloatOps op, const std::string &testName, const NDArray &x, const NDArray &z, const std::vector &axis) : OpBenchmark(testName ,x, z, axis) { _opNum = (int) op; _opType = 0; } - ReductionBenchmark(reduce::SameOps op, std::string testName, NDArray *x, NDArray *z, std::vector axis) : OpBenchmark(testName ,x, z, axis) { + ReductionBenchmark(reduce::SameOps op, const std::string &testName, const NDArray &x, const NDArray &z, const std::vector &axis) : OpBenchmark(testName ,x, z, axis) { _opNum = (int) op; _opType = 1; } @@ -82,21 +82,21 @@ namespace sd { void executeOnce() override { PointersManager manager(LaunchContext::defaultContext(), "reductionBM"); - if (_z->isScalar() || _y == nullptr) + if (_z.isScalar() || _y.shapeInfo() == nullptr) if (_opType == 0) - NativeOpExecutioner::execReduceFloatScalar(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo()); + NativeOpExecutioner::execReduceFloatScalar(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo()); else - NativeOpExecutioner::execReduceSameScalar(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo()); + NativeOpExecutioner::execReduceSameScalar(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo()); else { - auto pack = ConstantTadHelper::getInstance()->tadForDimensions(_x->shapeInfo(), _axis); + auto pack = ConstantTadHelper::getInstance()->tadForDimensions(_x.shapeInfo(), _axis); auto tadOnlyShapeInfo = Environment::getInstance()->isCPU() ? pack.primaryShapeInfo() : pack.specialShapeInfo(); auto tadOffsets = Environment::getInstance()->isCPU() ? pack.primaryOffsets() : pack.specialOffsets(); if (_opType == 0) - NativeOpExecutioner::execReduceFloat(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, tadOffsets); + NativeOpExecutioner::execReduceFloat(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, tadOffsets); else - NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, tadOffsets); + NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, tadOffsets); } manager.synchronize(); @@ -104,9 +104,9 @@ namespace sd { std::string orders() override { std::string result; - result += _x->ordering(); + result += _x.ordering(); result += "/"; - result += _z == nullptr ? _x->ordering() : _z->ordering(); + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); return result; } @@ -121,8 +121,7 @@ namespace sd { } ~ReductionBenchmark(){ - delete _x; - delete _z; + // } std::string axis() override { diff --git a/libnd4j/include/helpers/benchmark/ScalarBenchmark.h b/libnd4j/include/helpers/benchmark/ScalarBenchmark.h index f75f4a5a4fed..9b1da8b33d72 100644 --- a/libnd4j/include/helpers/benchmark/ScalarBenchmark.h +++ b/libnd4j/include/helpers/benchmark/ScalarBenchmark.h @@ -32,47 +32,38 @@ namespace sd { } ~ScalarBenchmark(){ - if (_x != _y && _x != _z && _y != _z) { - delete _x; - delete _y; - delete _z; - } else if (_x == _y && _x == _z) { - delete _x; - } else if (_x == _z) { - delete _x; - delete _y; - } + } ScalarBenchmark(scalar::Ops op) : OpBenchmark() { _opNum = (int) op; } - ScalarBenchmark(scalar::Ops op, std::string testName) : OpBenchmark() { + ScalarBenchmark(scalar::Ops op, const std::string &testName) : OpBenchmark() { _opNum = (int) op; _testName = testName; } - ScalarBenchmark(scalar::Ops op, std::string testName, NDArray *x, NDArray *y, NDArray *z) : OpBenchmark(testName, x, y, z) { + ScalarBenchmark(scalar::Ops op, const std::string &testName, const NDArray &x, const NDArray &y, const NDArray &z) : OpBenchmark(testName, x, y, z) { _opNum = (int) op; } void executeOnce() override { PointersManager manager(LaunchContext::defaultContext(), "ScalarBM"); - if (_z == nullptr) - NativeOpExecutioner::execScalar(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), _y->buffer(), _y->shapeInfo(), _y->specialBuffer(), _y->specialShapeInfo(), nullptr); + if (_z.shapeInfo() == nullptr) + NativeOpExecutioner::execScalar(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _y.buffer(), _y.shapeInfo(), _y.specialBuffer(), _y.specialShapeInfo(), nullptr); else - NativeOpExecutioner::execScalar(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), _y->buffer(), _y->shapeInfo(), _y->specialBuffer(), _y->specialShapeInfo(), nullptr); + NativeOpExecutioner::execScalar(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), _y.buffer(), _y.shapeInfo(), _y.specialBuffer(), _y.specialShapeInfo(), nullptr); manager.synchronize(); } std::string orders() override { std::string result; - result += _x->ordering(); + result += _x.ordering(); result += "/"; - result += _z == nullptr ? _x->ordering() : _z->ordering(); + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); return result; } @@ -80,7 +71,7 @@ namespace sd { std::string result; result += ShapeUtils::strideAsString(_x); result += "/"; - result += _z == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); + result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); return result; } @@ -93,7 +84,7 @@ namespace sd { } OpBenchmark* clone() override { - return new ScalarBenchmark((scalar::Ops) _opNum, _testName, _x == nullptr ? _x : new NDArray(_x->dup()) , _y == nullptr ? _y : new NDArray(_y->dup()), _z == nullptr ? _z : new NDArray(_z->dup())); + return new ScalarBenchmark((scalar::Ops) _opNum, _testName, _x.shapeInfo() == nullptr ? _x : NDArray(_x.dup()) , _y.shapeInfo() == nullptr ? _y : NDArray(_y.dup()), _z.shapeInfo() == nullptr ? _z : NDArray(_z.dup())); } }; } diff --git a/libnd4j/include/helpers/benchmark/TransformBenchmark.h b/libnd4j/include/helpers/benchmark/TransformBenchmark.h index eabb956dbed5..1476fd95284e 100644 --- a/libnd4j/include/helpers/benchmark/TransformBenchmark.h +++ b/libnd4j/include/helpers/benchmark/TransformBenchmark.h @@ -35,35 +35,35 @@ namespace sd { // } - TransformBenchmark(int opNum, int opType, std::string testName, NDArray *x, NDArray *z) : OpBenchmark(testName, x, z) { + TransformBenchmark(int opNum, int opType, const std::string &testName, const NDArray &x, const NDArray &z) : OpBenchmark(testName, x, z) { _opNum = opNum; _opType = opType; } - TransformBenchmark(transform::StrictOps op, std::string testName, NDArray *x, NDArray *z) : OpBenchmark(testName, x, z) { + TransformBenchmark(transform::StrictOps op, const std::string &testName, const NDArray &x, const NDArray &z) : OpBenchmark(testName, x, z) { _opNum = (int) op; _opType = 0; } - TransformBenchmark(transform::StrictOps op, std::string name) : OpBenchmark() { + TransformBenchmark(transform::StrictOps op, const std::string &name) : OpBenchmark() { _opNum = (int) op; _opType = 0; _testName = name; } - TransformBenchmark(transform::SameOps op, std::string name) : OpBenchmark() { + TransformBenchmark(transform::SameOps op, const std::string &name) : OpBenchmark() { _opNum = (int) op; _opType = 1; _testName = name; } - TransformBenchmark(transform::AnyOps op, std::string name) : OpBenchmark() { + TransformBenchmark(transform::AnyOps op, const std::string &name) : OpBenchmark() { _opNum = (int) op; _opType = 2; _testName = name; } - TransformBenchmark(transform::FloatOps op, std::string name) : OpBenchmark() { + TransformBenchmark(transform::FloatOps op, const std::string &name) : OpBenchmark() { _opNum = (int) op; _opType = 3; _testName = name; @@ -71,31 +71,25 @@ namespace sd { ~TransformBenchmark(){ - if (_x == _z) { - delete _x; - } else { - delete _x; - delete _z; - } } void executeOnce() override { PointersManager manager(LaunchContext::defaultContext(), "TransformBM"); - auto z = _z == nullptr ? _x : _z; + auto z = _z.shapeInfo() == nullptr ? _x : _z; switch (_opType) { case 0: - NativeOpExecutioner::execTransformStrict(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), nullptr, nullptr, nullptr); + NativeOpExecutioner::execTransformStrict(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr, nullptr, nullptr); break; case 1: - NativeOpExecutioner::execTransformSame(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), nullptr, nullptr, nullptr); + NativeOpExecutioner::execTransformSame(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr, nullptr, nullptr); break; case 2: - NativeOpExecutioner::execTransformAny(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), nullptr, nullptr, nullptr); + NativeOpExecutioner::execTransformAny(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr, nullptr, nullptr); break; case 3: - NativeOpExecutioner::execTransformFloat(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), nullptr, nullptr, nullptr); + NativeOpExecutioner::execTransformFloat(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr, nullptr, nullptr); break; } @@ -108,9 +102,9 @@ namespace sd { std::string orders() override { std::string result; - result += _x->ordering(); + result += _x.ordering(); result += "/"; - result += _z == nullptr ? _x->ordering() : _z->ordering(); + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); return result; } @@ -118,7 +112,7 @@ namespace sd { std::string result; result += ShapeUtils::strideAsString(_x); result += "/"; - result += _z == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); + result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); return result; } diff --git a/libnd4j/include/helpers/impl/BenchmarkHelper.cpp b/libnd4j/include/helpers/impl/BenchmarkHelper.cpp index 9e85cc5b73fd..df97b08c6e8d 100644 --- a/libnd4j/include/helpers/impl/BenchmarkHelper.cpp +++ b/libnd4j/include/helpers/impl/BenchmarkHelper.cpp @@ -457,9 +457,8 @@ namespace sd { clone->setX(x_); clone->setZ(z_); - if (y_ != nullptr) { - clone->setAxis(y_->asVectorT()); - delete y_; + if (y_.shapeInfo() != nullptr) { + clone->setAxis(y_.asVectorT()); } result.emplace_back(clone); @@ -500,9 +499,8 @@ namespace sd { clone->setX(x_); clone->setZ(z_); - if (y_ != nullptr) { - clone->setAxis(y_->asVectorT()); - delete y_; + if (y_.shapeInfo() != nullptr) { + clone->setAxis(y_.asVectorT()); } result.emplace_back(clone); } diff --git a/libnd4j/include/helpers/impl/GradCheck.cpp b/libnd4j/include/helpers/impl/GradCheck.cpp index 2643a7b6d596..34b3ccfb09e6 100644 --- a/libnd4j/include/helpers/impl/GradCheck.cpp +++ b/libnd4j/include/helpers/impl/GradCheck.cpp @@ -84,9 +84,9 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); + outArrsFF.at(k).reduceNumber(reduce::Sum, tmpScalar); else - outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); + outArrsFF.at(k).reduceNumber(reduce::Mean, tmpScalar); scorePlus += tmpScalar.e(0); } @@ -97,9 +97,9 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); + outArrsFF.at(k).reduceNumber(reduce::Sum, tmpScalar); else - outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); + outArrsFF.at(k).reduceNumber(reduce::Mean, tmpScalar); scoreMinus += tmpScalar.e(0); } @@ -114,7 +114,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons } // get analytical gradient - const double analyticGrad = outArrsBP.at(i)->e(j); + const double analyticGrad = outArrsBP.at(i).e(j); if(std::isnan(analyticGrad) || std::isinf(analyticGrad)) { printf("GradCheck::checkGrad: got wrong value for analytical gradient for input array # %i and its element at position %lld ! \n", i, j); throw std::runtime_error(""); diff --git a/libnd4j/include/helpers/impl/OpBenchmark.cpp b/libnd4j/include/helpers/impl/OpBenchmark.cpp index 6cb0dc08a788..1eba37d03561 100644 --- a/libnd4j/include/helpers/impl/OpBenchmark.cpp +++ b/libnd4j/include/helpers/impl/OpBenchmark.cpp @@ -21,36 +21,30 @@ #include "../OpBenchmark.h" namespace sd { - OpBenchmark::OpBenchmark(std::string name, NDArray *x, NDArray *y, NDArray *z) { + OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, const NDArray &y, const NDArray &z) { _testName = name; _x = x; _y = y; _z = z; } - OpBenchmark::OpBenchmark(std::string name, NDArray *x, NDArray *z) { + OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, const NDArray &z) { _testName = name; _x = x; _z = z; } - OpBenchmark::OpBenchmark(std::string name, NDArray *x, NDArray *z, std::initializer_list axis) : OpBenchmark(name, x, nullptr, z, axis){ } - - OpBenchmark::OpBenchmark(std::string name, NDArray *x, NDArray *y, NDArray *z, std::initializer_list axis){ + OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, const NDArray &z, const std::vector &axis) { _testName = name; _x = x; - _y = y; _z = z; - _axis = std::vector(axis); + _axis = axis; if (_axis.size() > 1) std::sort(_axis.begin(), _axis.end()); - } - OpBenchmark::OpBenchmark(std::string name, NDArray *x, NDArray *z, std::vector axis) : OpBenchmark(name, x, nullptr, z, axis) { } - - OpBenchmark::OpBenchmark(std::string name, NDArray *x, NDArray *y, NDArray *z, std::vector axis) { + OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, const NDArray &y, const NDArray &z, const std::vector &axis) { _testName = name; _x = x; _y = y; @@ -63,13 +57,13 @@ namespace sd { NDArray& OpBenchmark::x() { - return *_x; + return _x; } - int OpBenchmark::opNum() { + int OpBenchmark::opNum() const { return _opNum; } - std::string OpBenchmark::testName(){ + const std::string& OpBenchmark::testName() const{ return _testName; } @@ -77,19 +71,19 @@ namespace sd { _opNum = opNum; } - void OpBenchmark::setTestName(std::string name){ + void OpBenchmark::setTestName(const std::string &name){ _testName = name; } - void OpBenchmark::setX(NDArray *array) { + void OpBenchmark::setX(const NDArray &array) { _x = array; } - void OpBenchmark::setY(NDArray *array) { + void OpBenchmark::setY(const NDArray &array) { _y = array; } - void OpBenchmark::setZ(NDArray *array) { + void OpBenchmark::setZ(const NDArray &array) { _z = array; } @@ -110,19 +104,19 @@ namespace sd { } std::string OpBenchmark::shape() { - if (_x != nullptr) + if (_x.shapeInfo() != nullptr) return ShapeUtils::shapeAsString(_x); - else if (_z != nullptr) + else if (_z.shapeInfo() != nullptr) return ShapeUtils::shapeAsString(_z); else return "N/A"; } std::string OpBenchmark::dataType() { - if (_x != nullptr) - return DataTypeUtils::asString(_x->dataType()); - else if (_z != nullptr) - return DataTypeUtils::asString(_z->dataType()); + if (_x.shapeInfo() != nullptr) + return DataTypeUtils::asString(_x.dataType()); + else if (_z.shapeInfo() != nullptr) + return DataTypeUtils::asString(_z.dataType()); else return "N/A"; } diff --git a/libnd4j/include/helpers/impl/ShapeUtils.cpp b/libnd4j/include/helpers/impl/ShapeUtils.cpp index 2b450ea9ec3c..e5c846bb6ffe 100644 --- a/libnd4j/include/helpers/impl/ShapeUtils.cpp +++ b/libnd4j/include/helpers/impl/ShapeUtils.cpp @@ -623,6 +623,14 @@ Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector(buffer_, shape_, LaunchContext::defaultContext(), false); // block should contain references to proper variable varSpace.putVariable(1, e, array); @@ -2005,9 +2005,6 @@ sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::Decla auto shapeList = op->calculateOutputShape(&inShapes, block); - if (varSpace.launchContext() != nullptr) - shapeList->detach(); - return shapeList; } @@ -2181,23 +2178,16 @@ static VariablesSet* executeStoredGraphT(Nd4jPointer *extraPointers, Nd4jLong gr auto graph = sd::graph::GraphHolder::getInstance()->cloneGraph(graphId); auto varSpace = graph->variableSpace(); - std::vector handles; - for (int e = 0; e < numInputs; e++) { auto idx = inputIndices[e]; // we'll delete this array later, together with cloned VariableSpace - auto array = new sd::NDArray(inputBuffers[e], reinterpret_cast(inputShapes[e])); - handles.emplace_back(array); - - if (varSpace->hasVariable(idx)) { - auto var = varSpace->getVariable(idx); - if (var->hasNDArray()) - delete var->getNDArray(); - var->setNDArray(array); + if (varSpace.hasVariable(idx)) { + auto var = varSpace.getVariable(idx); + var->setNDArray(std::make_shared(inputBuffers[e], reinterpret_cast(inputShapes[e]))); } else - varSpace->putVariable(idx, array); + varSpace.putVariable(idx, sd::NDArray(inputBuffers[e], reinterpret_cast(inputShapes[e]))); } throw std::runtime_error("executeStoredGraphT - not implemented yet"); diff --git a/libnd4j/include/ops/declarable/generic/blas/svd.cpp b/libnd4j/include/ops/declarable/generic/blas/svd.cpp index ca5fd52c2092..1c316cf4a91a 100644 --- a/libnd4j/include/ops/declarable/generic/blas/svd.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/svd.cpp @@ -68,12 +68,12 @@ DECLARE_SHAPE_FN(svd) { Nd4jLong* sShapeInfo(nullptr); if(rank == 2) { - ALLOCATE(sShapeInfo, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong); + ALLOCATE(sShapeInfo, block.workspace(), shape::shapeInfoLength(1), Nd4jLong); sShapeInfo[0] = 1; sShapeInfo[1] = diagSize; } else { - ALLOCATE(sShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank-1), Nd4jLong); + ALLOCATE(sShapeInfo, block.workspace(), shape::shapeInfoLength(rank-1), Nd4jLong); sShapeInfo[0] = rank - 1; for(int i=1; i <= rank-2; ++i) sShapeInfo[i] = inShapeInfo[i]; diff --git a/libnd4j/include/ops/declarable/generic/boolean/select.cpp b/libnd4j/include/ops/declarable/generic/boolean/select.cpp index e8e257258b3c..1e3c9e1cc9ad 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/select.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/select.cpp @@ -72,9 +72,9 @@ namespace sd { for (int e = 0; e < tadsX.size(); e++) { if (!cond->e(e)) { - tadsZ.at(e)->assign(tadsY.at(e)); + tadsZ.at(e).assign(tadsY.at(e)); } else { - tadsZ.at(e)->assign(tadsX.at(e)); + tadsZ.at(e).assign(tadsX.at(e)); } } } diff --git a/libnd4j/include/ops/declarable/generic/boolean/where.cpp b/libnd4j/include/ops/declarable/generic/boolean/where.cpp index c72c10d6b72d..b09414ecec10 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/where.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/where.cpp @@ -61,9 +61,9 @@ namespace sd { for (int e = 0; e < tadsX.size(); e++) { if (!condition->e(e)) { - tadsZ.at(e)->assign(tadsY.at(e)); + tadsZ.at(e).assign(tadsY.at(e)); } else { - tadsZ.at(e)->assign(tadsX.at(e)); + tadsZ.at(e).assign(tadsX.at(e)); } } } @@ -102,7 +102,7 @@ namespace sd { Nd4jLong *newShape; if (numOfTrue > 0) { - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); newShape[0] = 2; newShape[1] = numOfTrue; diff --git a/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp b/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp index 65cb52cddb05..0d90c11a3c10 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp @@ -91,9 +91,9 @@ namespace sd { for (int e = 0; e < tadsX.size(); e++) { if (!condition->e(e)) - tadsZ.at(e)->assign(tadsY.at(e)); + tadsZ.at(e).assign(tadsY.at(e)); else - tadsZ.at(e)->assign(tadsX.at(e)); + tadsZ.at(e).assign(tadsX.at(e)); } } } else { @@ -106,14 +106,14 @@ namespace sd { sd::ops::Where op; auto res(op.evaluate({condition})); REQUIRE_OK(res.status()); - NDArray* whereTrue = res.at(0); + auto& whereTrue = res.at(0); - if (whereTrue->isEmpty()) + if (whereTrue.isEmpty()) return ND4J_STATUS_OK; for (Nd4jLong outNext = 0; outNext < width; ++outNext) { auto output = OUTPUT_VARIABLE(outNext); for (Nd4jLong e = 0; e < output->lengthOf(); ++e) { - output->p(e, whereTrue->e(e, outNext)); + output->p(e, whereTrue.e(e, outNext)); } } } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp index ea3757485e90..51c0d876a283 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp @@ -65,7 +65,7 @@ DECLARE_SHAPE_FN(meshgrid) { int rank = block.width(); Nd4jLong* outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); outShapeInfo[0] = rank; for(int i = 1; i <= rank; ++i) outShapeInfo[i] = (Nd4jLong)shape::length(inputShape->at(i - 1)); diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp index 3ddbe57ca697..a8421d41470d 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp @@ -35,7 +35,7 @@ namespace ops { BROADCAST_CHECK_EMPTY(x,y,z); Nd4jLong* zShapeInfo = nullptr; - const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->getShapeInfo(), y->getShapeInfo(), true, zShapeInfo, block.getWorkspace()); + const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->getShapeInfo(), y->getShapeInfo(), true, zShapeInfo, block.workspace()); REQUIRE_TRUE(areShapesBroadcastable, 0, "MULTIPLY OP: the shapes of x %s and y %s are not suitable for broadcast !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); auto tZ = BroadcastHelper::broadcastApply(sd::BroadcastOpsTuple::Multiply(), x, y, z); @@ -71,7 +71,7 @@ CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) { auto dLdy = OUTPUT_VARIABLE(1); Nd4jLong* dLdzShapeInfo = nullptr; - const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->getShapeInfo(), y->getShapeInfo(), true, dLdzShapeInfo, block.getWorkspace()); + const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->getShapeInfo(), y->getShapeInfo(), true, dLdzShapeInfo, block.workspace()); REQUIRE_TRUE(areShapesBroadcastable, 0, "MULTIPLY_BP OP: the shapes of x %s and y %s are not suitable for broadcast !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); REQUIRE_TRUE(shape::equalsSoft(dLdz->shapeInfo(), dLdzShapeInfo), 0, "MULTIPLY_BP OP: wrong shape of next epsilon array (dLdOut), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(dLdzShapeInfo).c_str(), ShapeUtils::shapeAsString(dLdz).c_str()); diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp index c3a2ac25876c..31517a483c7e 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp @@ -80,7 +80,7 @@ DECLARE_SHAPE_FN(percentile) { } auto axises = block.getIArguments(); - Nd4jLong* outputShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShapeInfo), axises, inputShapeInfo, keepDims, false, block.getWorkspace()); + Nd4jLong* outputShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShapeInfo), axises, inputShapeInfo, keepDims, false, block.workspace()); return SHAPELIST(outputShapeInfo); } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp index 5a1ac02c5809..9eea9fd0c3fd 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp @@ -63,7 +63,7 @@ namespace ops { auto dLdy = OUTPUT_VARIABLE(1); Nd4jLong* dLdzShapeInfo = nullptr; - const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->getShapeInfo(), y->getShapeInfo(), true, dLdzShapeInfo, block.getWorkspace()); + const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->getShapeInfo(), y->getShapeInfo(), true, dLdzShapeInfo, block.workspace()); REQUIRE_TRUE(areShapesBroadcastable, 0, "POW_BP OP: the shapes of x %s" " and y %s are not suitable for broadcast !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); diff --git a/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp b/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp index 108660c7b7db..ef42be652fd1 100644 --- a/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp +++ b/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp @@ -37,7 +37,7 @@ namespace sd { // we'll store signal to both ends //STORE_2_RESULTS(*input, *input); - +/* // but we'll ensure only one node is active, and other is disabled if (condition->e(0) == 0) { block.setBranch(0); @@ -48,6 +48,9 @@ namespace sd { } return Status::OK(); + */ + + throw std::runtime_error("Switch - Not implemented yet"); } DECLARE_SYN(switch, Switch); DECLARE_SYN(if, Switch); diff --git a/libnd4j/include/ops/declarable/generic/images/extract_image_patches.cpp b/libnd4j/include/ops/declarable/generic/images/extract_image_patches.cpp index 1bcb8ef36b81..9bb24cf536a0 100644 --- a/libnd4j/include/ops/declarable/generic/images/extract_image_patches.cpp +++ b/libnd4j/include/ops/declarable/generic/images/extract_image_patches.cpp @@ -80,7 +80,7 @@ namespace sd { } - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outputShape[0] = outRank; outputShape[1] = batchSizeDim; diff --git a/libnd4j/include/ops/declarable/generic/images/image_resize.cpp b/libnd4j/include/ops/declarable/generic/images/image_resize.cpp index 543bf7d87f23..4787fc6897a5 100644 --- a/libnd4j/include/ops/declarable/generic/images/image_resize.cpp +++ b/libnd4j/include/ops/declarable/generic/images/image_resize.cpp @@ -67,7 +67,7 @@ namespace sd { width = newImageSize->e(0); height = newImageSize->e(1); - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(4), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(4), Nd4jLong); outputShape[0] = 4; outputShape[1] = in[1]; outputShape[2] = width; diff --git a/libnd4j/include/ops/declarable/generic/images/resize_area.cpp b/libnd4j/include/ops/declarable/generic/images/resize_area.cpp index 4ae03cc256af..8ed5c4954abf 100644 --- a/libnd4j/include/ops/declarable/generic/images/resize_area.cpp +++ b/libnd4j/include/ops/declarable/generic/images/resize_area.cpp @@ -91,7 +91,7 @@ namespace sd { REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_area: Source tensor should have rank 4, but %i given.", inRank); - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); outputShape[0] = inRank; if (inRank == 4) { outputShape[1] = in[1]; diff --git a/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp b/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp index 65b5ef0b5af8..28bfaad97705 100644 --- a/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp +++ b/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp @@ -83,7 +83,7 @@ namespace sd { REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank); - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); outputShape[0] = inRank; if (inRank == 4) { outputShape[1] = in[1]; diff --git a/libnd4j/include/ops/declarable/generic/images/resize_linear.cpp b/libnd4j/include/ops/declarable/generic/images/resize_linear.cpp index 6d72bf889728..f5f89fe23f0b 100644 --- a/libnd4j/include/ops/declarable/generic/images/resize_linear.cpp +++ b/libnd4j/include/ops/declarable/generic/images/resize_linear.cpp @@ -95,7 +95,7 @@ namespace sd { height = INT_ARG(1); } - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); outputShape[0] = inRank; if (inRank == 4) { outputShape[1] = in[1]; diff --git a/libnd4j/include/ops/declarable/generic/images/resize_neighbor.cpp b/libnd4j/include/ops/declarable/generic/images/resize_neighbor.cpp index 3454fb897e6e..46a8949b6e99 100644 --- a/libnd4j/include/ops/declarable/generic/images/resize_neighbor.cpp +++ b/libnd4j/include/ops/declarable/generic/images/resize_neighbor.cpp @@ -93,7 +93,7 @@ namespace sd { height = INT_ARG(1); } - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); outputShape[0] = inRank; if (inRank == 4) { outputShape[1] = in[1]; diff --git a/libnd4j/include/ops/declarable/generic/linalg/diagPart.cpp b/libnd4j/include/ops/declarable/generic/linalg/diagPart.cpp index 925c4b6c1149..12d08b7b43fe 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/diagPart.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/diagPart.cpp @@ -64,7 +64,7 @@ namespace ops { int outRank = inRank/2; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outShapeInfo[0] = outRank; for(int i = 1; i <= outRank; ++i) diff --git a/libnd4j/include/ops/declarable/generic/linalg/eye.cpp b/libnd4j/include/ops/declarable/generic/linalg/eye.cpp index 785f72185fa8..5561d9cb6a8d 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/eye.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/eye.cpp @@ -75,14 +75,14 @@ namespace ops { switch(size) { case 2: - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); outShapeInfo[0] = 2; outShapeInfo[1] = params[1]; outShapeInfo[2] = params[1]; break; case 3: - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); outShapeInfo[0] = 2; outShapeInfo[1] = params[1]; outShapeInfo[2] = params[2]; @@ -90,7 +90,7 @@ namespace ops { default: int rank = size-1; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); outShapeInfo[0] = rank; outShapeInfo[rank-1] = params[1]; outShapeInfo[rank] = params[2]; @@ -101,7 +101,7 @@ namespace ops { shape::updateStrides(outShapeInfo, static_cast(-params[0])); auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo, dtype)); - RELEASE(outShapeInfo, block.getWorkspace()); + RELEASE(outShapeInfo, block.workspace()); return SHAPELIST(result); } diff --git a/libnd4j/include/ops/declarable/generic/linalg/matrixDiagPart.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrixDiagPart.cpp index 9d4a00be3f3c..15bc82dfabaa 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/matrixDiagPart.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/matrixDiagPart.cpp @@ -49,7 +49,7 @@ namespace sd { outShapeInfo = ConstantShapeHelper::getInstance()->vectorShapeInfo(lastDimension, ArrayOptions::dataType(in)); } else { - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outShapeInfo[0] = outRank; for(int i = 0; i < outRank - 1; ++i) outShapeInfo[i + 1] = shape::sizeAt(in, i); diff --git a/libnd4j/include/ops/declarable/generic/linalg/matrix_diag.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrix_diag.cpp index 6e95d127de6d..79061ca5e31c 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/matrix_diag.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/matrix_diag.cpp @@ -47,7 +47,7 @@ DECLARE_SHAPE_FN(matrix_diag) { int outRank = inRank + 1; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outShapeInfo[0] = outRank; for(int i = 0; i < inRank; ++i) outShapeInfo[i + 1] = shape::sizeAt(in, i); diff --git a/libnd4j/include/ops/declarable/generic/linalg/trace.cpp b/libnd4j/include/ops/declarable/generic/linalg/trace.cpp index fa9fd5f56a5e..b707605d2639 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/trace.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/trace.cpp @@ -51,7 +51,7 @@ DECLARE_SHAPE_FN(trace) { const int rank = inShapeInfo[0] - 2; Nd4jLong* outShapeInfo(nullptr); - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); outShapeInfo[0] = rank; for(int i=1; i <= rank; ++i) @@ -59,7 +59,7 @@ DECLARE_SHAPE_FN(trace) { shape::updateStrides(outShapeInfo, shape::order(inShapeInfo)); auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo, ArrayOptions::dataType(inShapeInfo))); - RELEASE(outShapeInfo, block.getWorkspace()); + RELEASE(outShapeInfo, block.workspace()); return SHAPELIST(result); } diff --git a/libnd4j/include/ops/declarable/generic/linalg/triu.cpp b/libnd4j/include/ops/declarable/generic/linalg/triu.cpp index 24525f09a10d..18c5ac8ebde8 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/triu.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/triu.cpp @@ -55,7 +55,7 @@ DECLARE_SHAPE_FN(triu) { int rank = (inShapeInfo[0] == 1) ? 2 : inShapeInfo[0]; Nd4jLong *outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); memcpy(outShapeInfo, inShapeInfo, (1 + rank) * sizeof(Nd4jLong)); // copy rank and dimensions values only if(inShapeInfo[0] == 1) { @@ -100,7 +100,7 @@ DECLARE_SHAPE_FN(triu_bp) { int rank = gradOShapeInfo[0]; Nd4jLong* outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); memcpy(outShapeInfo, gradOShapeInfo, (1 + rank) * sizeof(Nd4jLong)); // copy rank and dimensions values only auto in = inputShape->at(0); diff --git a/libnd4j/include/ops/declarable/generic/list/create_list.cpp b/libnd4j/include/ops/declarable/generic/list/create_list.cpp index 606558e7edab..08d9ff273f4f 100644 --- a/libnd4j/include/ops/declarable/generic/list/create_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/create_list.cpp @@ -48,8 +48,8 @@ namespace sd { setupResultList(list, block); // OVERWRITE_RESULT(list); - auto scalar = NDArrayFactory::create_(list->counter()); - block.pushNDArrayToVariableSpace(block.getNodeId(), 1, scalar); + auto scalar = NDArrayFactory::create(list->counter()); + block.pushNDArrayToVariableSpace(block.nodeId(), 1, scalar); return ND4J_STATUS_OK; } diff --git a/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp b/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp index 38a4da7bd9ce..c83f07c65069 100644 --- a/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp @@ -43,7 +43,8 @@ namespace sd { array = INPUT_VARIABLE(1); indices = INPUT_VARIABLE(2); list = new NDArrayList(indices->lengthOf(), false); - block.trackList(list); + + throw std::runtime_error("scatter_list - Not implemented yet"); } REQUIRE_TRUE(indices->isVector() || indices->rankOf() == 1, 0, "ScatterList: Indices for Scatter should be a vector") @@ -56,7 +57,7 @@ namespace sd { if (idx >= tads.size()) return ND4J_STATUS_BAD_ARGUMENTS; - auto arr = new NDArray(tads.at(e)->dup(array->ordering())); + auto arr = new NDArray(tads.at(e).dup(array->ordering())); auto res = list->write(idx, arr); if (res != ND4J_STATUS_OK) return res; diff --git a/libnd4j/include/ops/declarable/generic/list/split_list.cpp b/libnd4j/include/ops/declarable/generic/list/split_list.cpp index c490479617c9..74a75b29fe9f 100644 --- a/libnd4j/include/ops/declarable/generic/list/split_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/split_list.cpp @@ -42,7 +42,9 @@ namespace sd { array = INPUT_VARIABLE(0); sizes = INPUT_VARIABLE(1); list = new NDArrayList(sizes->lengthOf(), false); - block.trackList(list); + //block.trackList(list); + + throw std::runtime_error("split_list - Not implemented yet"); } REQUIRE_TRUE(sizes->isZ(), 0, "split_list: sizes array must have one of integer types"); diff --git a/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp b/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp index 812588710686..3b715b1fc020 100644 --- a/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp @@ -279,9 +279,9 @@ DECLARE_SHAPE_FN(absolute_difference_loss_grad) { DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp b/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp index 10995c90b8a9..38458226314f 100644 --- a/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp @@ -147,7 +147,7 @@ DECLARE_SHAPE_FN(cosine_distance_loss) { else { // in this case output has the same shape as labels reduced by dim axis std::vector dimensions = {dim}; - outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, outType, true, false, block.getWorkspace()); + outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, outType, true, false, block.workspace()); // weights array can be single scalar or has the same rank as output, and must be broadcastable to output REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), 0, "COSINE_DISTANCE_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(outShapeInfo)); @@ -186,7 +186,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) { REQUIRE_TRUE(labels->isSameShape(predictions), 0, "COSINE_DISTANCE_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "COSINE_DISTANCE_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(predictions->ordering(), dimensions, predictions->getShapeInfo(), true, false, block.getWorkspace()); + auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(predictions->ordering(), dimensions, predictions->getShapeInfo(), true, false, block.workspace()); // weights array can be single scalar or has the same shape as loss, and must be broadcastable to loss shape REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == shape::rank(lossShapeInfo), 0, "COSINE_DISTANCE_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", weights->rankOf(), shape::rank(lossShapeInfo)); // check whether broadcast operation is possible for weights array @@ -323,7 +323,7 @@ DECLARE_SHAPE_FN(cosine_distance_loss_grad) { // labels and predictions must have the same shapes REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "COSINE_DISTANCE_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, true, false, block.getWorkspace()); + auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, true, false, block.workspace()); // weights array can be single scalar or has the same rank as loss, and must be broadcastable to loss REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(lossShapeInfo), 0, "COSINE_DISTANCE_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(lossShapeInfo)); // check whether broadcast operation is possible for weights array @@ -333,9 +333,9 @@ DECLARE_SHAPE_FN(cosine_distance_loss_grad) { auto outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp index 7d8eeec3a8e6..75add512e3e4 100644 --- a/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp @@ -176,7 +176,7 @@ namespace sd { E.applyScalar(scalar::RELU, 0.0f, E); // turn E into gradient mask - NDArray gradientMask(E.getShapeInfo(), block.getWorkspace()); + NDArray gradientMask(E.getShapeInfo(), block.workspace()); E.applyTransform(sd::transform::Sign, gradientMask); dLdp->assign(-z * gradientMask); @@ -291,9 +291,9 @@ namespace sd { DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); + Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); } diff --git a/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp index a29bd1cf2572..bab25837fd09 100644 --- a/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp @@ -54,7 +54,7 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) { auto error = *predictions - *labels; error.applyTransform(transform::Abs, error); - NDArray quadratic(error.getShapeInfo(), block.getWorkspace()); + NDArray quadratic(error.getShapeInfo(), block.workspace()); error.applyScalar(scalar::MinPairwise, delta, quadratic); NDArray E = quadratic * quadratic * 0.5f + (error - quadratic)*delta; @@ -306,9 +306,9 @@ DECLARE_SHAPE_FN(huber_loss) { DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); + Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); } diff --git a/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp index 99140a394c1a..b1d3d76e7bc9 100644 --- a/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp @@ -298,9 +298,9 @@ DECLARE_SHAPE_FN(log_loss_grad) { DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp b/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp index 20e03e92bdde..2d651c000663 100644 --- a/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp @@ -53,7 +53,7 @@ namespace ops { weightsBroad = new NDArray(weights->tileToShape(log_predictions->getShapeInfo())); - NDArray E(labels->getShapeInfo(), block.getWorkspace()); + NDArray E(labels->getShapeInfo(), block.workspace()); if (computeFullLoss) labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, *log_predictions, E); else @@ -175,11 +175,11 @@ namespace ops { weightsBroad = new NDArray(weights->tileToShape(log_predictions->getShapeInfo())); - NDArray E(labels->getShapeInfo(), block.getWorkspace()); + NDArray E(labels->getShapeInfo(), block.workspace()); if (computeFullLoss) { labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, *log_predictions, E); - NDArray rDiv(labels->getShapeInfo(), block.getWorkspace()); + NDArray rDiv(labels->getShapeInfo(), block.workspace()); labels->applyScalar(scalar::ReverseDivide, 0.5f, rDiv); dLdl->assign(rDiv + labels->transform(transform::Log) + -(*log_predictions)); } else { @@ -299,9 +299,9 @@ namespace ops { DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); + Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); } diff --git a/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp b/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp index f8006a3ed5c9..6c6b6826d6a5 100644 --- a/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp @@ -203,7 +203,7 @@ namespace sd { outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); else { // in this case output has the shape as labels and logits minus last dimension std::vector dimensions = {-1}; - outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, false, true, block.getWorkspace()); + outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, false, true, block.workspace()); // weights array can be single scalar or has the same rank as output, and must be broadcastable to output REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), 0, "MEAN_PAIRWSSQERR_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(outShapeInfo)); @@ -369,9 +369,9 @@ namespace sd { DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); + Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); } diff --git a/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp b/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp index b0ccf968b2ea..171c736aa57e 100644 --- a/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp @@ -288,9 +288,9 @@ DECLARE_SHAPE_FN(mean_sqerr_loss_grad) { DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp b/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp index 28d66bc9395d..5fdaef19f653 100644 --- a/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp @@ -313,9 +313,9 @@ DECLARE_SHAPE_FN(sigm_cross_entropy_loss_grad) { DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.getWorkspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace()); + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp index 3ea9ce2bd9e0..b34fbb39a001 100644 --- a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp @@ -164,7 +164,7 @@ DECLARE_SHAPE_FN(softmax_cross_entropy_loss) { outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); else { // in this case output has the shape as labels and logits minus last dimension std::vector dimensions = {-1}; - outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, true, block.getWorkspace()); + outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, true, block.workspace()); // weights array can be single scalar or has the same rank as output, and must be broadcastable to output REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(outShapeInfo)); @@ -207,7 +207,7 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) { REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); // only 4 possible reduction modes exist REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(logits->ordering(), dimensions, logits->getShapeInfo(), false, false, block.getWorkspace()); + auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(logits->ordering(), dimensions, logits->getShapeInfo(), false, false, block.workspace()); // weights array can be single scalar or has the same shape as loss, and must be broadcastable to loss shape REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == shape::rank(lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", weights->rankOf(), shape::rank(lossShapeInfo)); // check whether broadcast operation is possible for weights array @@ -376,7 +376,7 @@ DECLARE_SHAPE_FN(softmax_cross_entropy_loss_grad) { // labels and logits must have the same shapes REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, false, block.getWorkspace()); + auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, false, block.workspace()); // weights array can be single scalar or has the same rank as loss, and must be broadcastable to loss REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(lossShapeInfo)); // check whether broadcast operation is possible for weights array diff --git a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp index 1b3135f2762f..7e4785ce14e5 100644 --- a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp @@ -69,7 +69,7 @@ DECLARE_SHAPE_FN(softmax_cross_entropy_loss_with_logits) { REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); auto outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - auto reducedShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(labelsShapeInfo), dimensions, labelsShapeInfo, outType, false, false, block.getWorkspace()); + auto reducedShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(labelsShapeInfo), dimensions, labelsShapeInfo, outType, false, false, block.workspace()); return SHAPELIST(reducedShapeInfo); diff --git a/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp b/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp index c641bf12f5b0..6f671e148002 100644 --- a/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp @@ -83,7 +83,7 @@ DECLARE_SHAPE_FN(sparse_softmax_cross_entropy_loss_with_logits) { REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, logitsShapeInfo, false, block.getWorkspace()); + auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, logitsShapeInfo, false, block.workspace()); return SHAPELIST(CONSTANT(outShapeInfo)); } @@ -155,7 +155,7 @@ DECLARE_SHAPE_FN(sparse_softmax_cross_entropy_loss_with_logits_grad) { DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.getWorkspace()); + Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.workspace()); return SHAPELIST(CONSTANT(dLdpShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp b/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp index 921662fa6b5e..c96ddb503df6 100644 --- a/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp +++ b/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp @@ -84,7 +84,7 @@ namespace sd { /* DECLARE_SHAPE_FN(skipgram) { - return SHAPELIST(ShapeBuilders::createScalarShapeInfo(DataType::INT8, block.getWorkspace())); + return SHAPELIST(ShapeBuilders::createScalarShapeInfo(DataType::INT8, block.workspace())); } */ } diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/crelu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/crelu.cpp index 539b211452f5..8ef400ed4482 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/crelu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/crelu.cpp @@ -81,10 +81,10 @@ namespace sd { // now we do RELU backward pass //actv->applyPairwiseTransform(pairwise::RELUDerivativeE, *epsilon, nullptr); - helpers::reluDerivative(block.launchContext(), actv, epsilonNext); + helpers::reluDerivative(block.launchContext(), &actv, epsilonNext); // now we split updated array into 2 chunks along last dimension sd::ops::concat_bp opc; - auto dec = opc.evaluate({input, input, actv}, {-1}); + auto dec = opc.evaluate({input, input, &actv}, {-1}); if (dec.status() != ND4J_STATUS_OK) return dec.status(); @@ -92,7 +92,7 @@ namespace sd { auto pos = dec.at(0); auto neg = dec.at(1); - pos->applyPairwiseTransform(sd::pairwise::Subtract, *neg, *epsilon); + pos.applyPairwiseTransform(sd::pairwise::Subtract, neg, *epsilon); return ND4J_STATUS_OK; } diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/identity_n.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/identity_n.cpp index 4b7088660de4..38b1ade26ab3 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/identity_n.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/identity_n.cpp @@ -44,7 +44,7 @@ namespace sd { auto shapes = SHAPELIST(); for (size_t i = 0; i < inputShape->size(); ++i) { Nd4jLong* shape; - COPY_SHAPE_EX(inputShape->at(i), shape, block.getWorkspace()); + COPY_SHAPE_EX(inputShape->at(i), shape, block.workspace()); shapes->push_back(CONSTANT(shape)); } return shapes; diff --git a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp index c9fc0b91c323..02b0da90ac01 100644 --- a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp @@ -122,7 +122,7 @@ DECLARE_SHAPE_FN(batchnorm) { auto inShapeInfo = inputShape->at(0); DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo)); - auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(inShapeInfo, outType, false, block.getWorkspace()); // output shape is identical to input shape + auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(inShapeInfo, outType, false, block.workspace()); // output shape is identical to input shape return SHAPELIST(CONSTANT(outShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp index d6e95a582678..07c0d29c90ea 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp @@ -64,7 +64,7 @@ namespace sd { bool isSameMode = INT_ARG(8) > 0; Nd4jLong* zShape; - ALLOCATE(zShape, block.getWorkspace(), shape::shapeInfoLength(4), Nd4jLong); + ALLOCATE(zShape, block.workspace(), shape::shapeInfoLength(4), Nd4jLong); zShape[0] = 4; zShape[1] = bS; diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp index 5815e0b55f14..ea8a4424f9ac 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp @@ -134,7 +134,7 @@ DECLARE_SHAPE_FN(conv1d) { ConvolutionUtils::calcOutSizePool2D(oH,oW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); + ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); outputShapeInfo[0] = 3; outputShapeInfo[1] = bS; @@ -279,11 +279,11 @@ DECLARE_SHAPE_FN(conv1d_bp) { if(biasShapeInfo) REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.workspace()); if(biasShapeInfo) { - auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp index eba662d08ce2..6024ef9d68c5 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp @@ -117,7 +117,7 @@ DECLARE_SHAPE_FN(conv2d) { REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); + ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); int oH, oW; // output height, width ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); @@ -253,11 +253,11 @@ DECLARE_SHAPE_FN(conv2d_bp) { if(biasShapeInfo) REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.workspace()); if(biasShapeInfo) { - auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo)); } @@ -372,7 +372,7 @@ DECLARE_SHAPE_FN(conv2d_input_bp) { REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); Nd4jLong* gradIshapeInfo(nullptr); - ALLOCATE(gradIshapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); + ALLOCATE(gradIshapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); gradIshapeInfo[0] = rank; gradIshapeInfo[1] = bS; diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp index 05bb837deccd..4475ecf402c3 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp @@ -160,7 +160,7 @@ DECLARE_SHAPE_FN(conv3dnew) { ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); + ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); outputShapeInfo[0] = rank; outputShapeInfo[1] = bS; @@ -352,11 +352,11 @@ DECLARE_SHAPE_FN(conv3dnew_bp) { if(biasShapeInfo) REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.workspace()); if(biasShapeInfo) { - auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp index 4ba57c93d859..8aa37bdd7aab 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp @@ -308,13 +308,13 @@ DECLARE_SHAPE_FN(deconv2d_bp) { if(biasShapeInfo) REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - auto gradWShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.workspace()); auto shapes = SHAPELIST(CONSTANT(gradIShapeInfo), CONSTANT(gradWShapeInfo)); if (biasShapeInfo != nullptr) { - auto gradBShapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + auto gradBShapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); shapes->push_back(CONSTANT(gradBShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp index f7f6416478c9..25b5976f59ec 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp @@ -156,7 +156,7 @@ DECLARE_SHAPE_FN(deconv3d) { ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); + ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); outputShapeInfo[0] = rank; outputShapeInfo[1] = bS; @@ -333,13 +333,13 @@ DECLARE_SHAPE_FN(deconv3d_bp) { if(biasShapeInfo) REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - auto gradWShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.workspace()); auto shapes = SHAPELIST(CONSTANT(gradIShapeInfo), CONSTANT(gradWShapeInfo)); if (biasShapeInfo != nullptr) { - auto gradBShapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + auto gradBShapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); shapes->push_back(CONSTANT(gradBShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp index 8c03ea6eb73f..1be9b379b59e 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp @@ -120,7 +120,7 @@ DECLARE_SHAPE_FN(depthwise_conv2d) { ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); + ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); outputShapeInfo[0] = rank; outputShapeInfo[1] = bS; @@ -244,11 +244,11 @@ DECLARE_SHAPE_FN(depthwise_conv2d_bp) { if(biasShapeInfo) REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); - auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.workspace()); if(biasShapeInfo) { - Nd4jLong* gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + Nd4jLong* gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp index db65329835f0..b7df3748251b 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp @@ -75,7 +75,7 @@ namespace sd { // output is always 6d for im2col Nd4jLong* zShape; - ALLOCATE(zShape, block.getWorkspace(), shape::shapeInfoLength(6), Nd4jLong); + ALLOCATE(zShape, block.workspace(), shape::shapeInfoLength(6), Nd4jLong); int oY = 0; int oX = 0; diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp index e519585e62b7..848d12e1bbd6 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp @@ -99,7 +99,7 @@ DECLARE_SHAPE_FN(pointwise_conv2d) { if (biasShapeInfo) REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - auto outputShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, weightsShapeInfo, true, block.getWorkspace()); + auto outputShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, weightsShapeInfo, true, block.workspace()); // do not forget to put oC instead of iC in outputShapeInfo outputShapeInfo[indIOioC + 1] = oC; diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp index 3ea1eec675ed..53bdaaaed79c 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp @@ -167,7 +167,7 @@ DECLARE_SHAPE_FN(sconv2d) { ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); + ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); outputShapeInfo[0] = 4; outputShapeInfo[1] = bS; @@ -370,24 +370,24 @@ DECLARE_SHAPE_FN(sconv2d_bp) { if (biasShapeInfo) REQUIRE_TRUE((biasShapeInfo[0] == 1 || biasShapeInfo[0] == 2) && oC == shape::length(biasShapeInfo), 0, "SCONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - auto gradWDshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsDShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWDshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsDShapeInfo, gradOShapeInfo, false, block.workspace()); Nd4jLong* gradWPshapeInfo(nullptr), *gradBshapeInfo(nullptr); if(weightsPShapeInfo && biasShapeInfo) { - gradWPshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsPShapeInfo, gradOShapeInfo, false, block.getWorkspace()); - gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + gradWPshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsPShapeInfo, gradOShapeInfo, false, block.workspace()); + gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo), CONSTANT(gradWPshapeInfo), CONSTANT(gradBshapeInfo)); } if(weightsPShapeInfo && !biasShapeInfo) { - gradWPshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsPShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + gradWPshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsPShapeInfo, gradOShapeInfo, false, block.workspace()); return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo), CONSTANT(gradWPshapeInfo)); } if(!weightsPShapeInfo && biasShapeInfo) { - gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace()); + gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo), CONSTANT(gradBshapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp index fd41313e4470..dc8dcf770e08 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp @@ -64,7 +64,7 @@ DECLARE_SHAPE_FN(upsampling2d) { const int isNCHW = block.numI() > 2 ? INT_ARG(2) : 0; // INT_ARG(2): 0-NCHW, 1-NHWC Nd4jLong *outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo[0]), Nd4jLong); + ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(inputShapeInfo[0]), Nd4jLong); outputShapeInfo[0] = inputShapeInfo[0]; outputShapeInfo[1] = inputShapeInfo[1]; @@ -117,7 +117,7 @@ DECLARE_SHAPE_FN(upsampling2d_bp) { REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, "UPSAMPLING2D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]); REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, "UPSAMPLING2D_BP op: output's gradient array must be 4D, but got %i instead!", inputShape->at(1)[0]); - auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), inputShape->at(1), false, block.getWorkspace()); + auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), inputShape->at(1), false, block.workspace()); return SHAPELIST(CONSTANT(gradIShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp index 19bc5a9ecaf0..dce61910eba1 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp @@ -64,7 +64,7 @@ DECLARE_SHAPE_FN(upsampling3d) { const int isNCDHW = block.numI() > 3 ? INT_ARG(3) : 0; // INT_ARG(3): 0-NCHW, 1-NHWC Nd4jLong *outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo[0]), Nd4jLong); + ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(inputShapeInfo[0]), Nd4jLong); outputShapeInfo[0] = inputShapeInfo[0]; outputShapeInfo[1] = inputShapeInfo[1]; @@ -116,7 +116,7 @@ DECLARE_SHAPE_FN(upsampling3d_bp) { REQUIRE_TRUE(inputShape->at(0)[0] == 5, 0, "UPSAMPLING3D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]); REQUIRE_TRUE(inputShape->at(1)[0] == 5, 0, "UPSAMPLING3D_BP op: output's gradient array must be 4D, but got %i instead!", inputShape->at(1)[0]); - auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), inputShape->at(1), false, block.getWorkspace()); + auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), inputShape->at(1), false, block.workspace()); return SHAPELIST(CONSTANT(gradIShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/nn/embedding_lookup.cpp b/libnd4j/include/ops/declarable/generic/nn/embedding_lookup.cpp index 0888854ee969..e015af2dab91 100644 --- a/libnd4j/include/ops/declarable/generic/nn/embedding_lookup.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/embedding_lookup.cpp @@ -53,7 +53,7 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) { Nd4jLong thisIndex = (*indeces).e(e); input = INPUT_VARIABLE(thisIndex); // lookup param - outputView.at(e)->assign(input); + outputView.at(e).assign(input); } } else { @@ -68,7 +68,7 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) { auto result(op.evaluate({input, indeces}, {0})); REQUIRE_TRUE(result.status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op."); - REQUIRE_TRUE(result.at(0)->isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op."); + REQUIRE_TRUE(result.at(0).isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op."); output->assign(result.at(0)); } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp index a3c27f02e5f2..532f4cb10daf 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp @@ -187,7 +187,7 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) { // NDArray* columns = columnsWrongShape.permute({0, 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] -> [bS, iC, kH, kW, oH, oW] // NDArray* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1}); // NDArray* columns2d = columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW}); - // NDArray pNorm(columns2d->getShapeInfo(), block.getWorkspace()); + // NDArray pNorm(columns2d->getShapeInfo(), block.workspace()); // input->template applyTransform>(columns, std::vector({(T)kH, (T)kW, (T)sH, (T)sW, (T)pH, (T)pW, (T)dH, (T)dW, (T)0.f, (T)0.f}).data()); diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp index ae70fd451d8f..3420f1758281 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp @@ -112,12 +112,12 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) { auto revInput = resultsIn.at(0); // backward steps - auto resultsBW = dynamicRnn.evaluate({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor}); + auto resultsBW = dynamicRnn.evaluate({&revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor}); auto hBWtemp = resultsBW.at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW] hBWFinal->assign(resultsBW.at(1)); // reverse hBWtemp - auto resultsOut = timeMajor ? reverse.evaluate({hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({hBWtemp, seqLen}, {1, 0}); + auto resultsOut = timeMajor ? reverse.evaluate({&hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({&hBWtemp, seqLen}, {1, 0}); hBW->assign(resultsOut.at(0)); if(seqLen != maxTimeStep) @@ -199,10 +199,10 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) { // evaluate output shapeInfos Nd4jLong *hFWShapeInfo(nullptr), *hBWShapeInfo(nullptr), *hFWFinalPrevShapeInfo(nullptr), *hBWFinalPrevShapeInfo(nullptr); - ALLOCATE(hFWShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); - ALLOCATE(hBWShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); - ALLOCATE(hFWFinalPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); - ALLOCATE(hBWFinalPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); + ALLOCATE(hFWShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); + ALLOCATE(hBWShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); + ALLOCATE(hFWFinalPrevShapeInfo, block.workspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); + ALLOCATE(hBWFinalPrevShapeInfo, block.workspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); hFWShapeInfo[0] = hBWShapeInfo[0] = inRank; hFWShapeInfo[1] = hBWShapeInfo[1] = timeMajor ? time : bS; diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp index 2e6f9597672f..663d0b91d6d2 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp @@ -149,8 +149,8 @@ DECLARE_SHAPE_FN(dynamic_rnn) { // evaluate output shapeInfos Nd4jLong *hShapeInfo(nullptr), *hPrevShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); - ALLOCATE(hPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); + ALLOCATE(hPrevShapeInfo, block.workspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); hShapeInfo[0] = inRank; hPrevShapeInfo[0] = inRank-1; diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp index c6cd2e8f1388..a7ad50b3ae4e 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp @@ -94,7 +94,7 @@ DECLARE_SHAPE_FN(gru) { // evaluate output shapeInfo Nd4jLong *hShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); hShapeInfo[0] = rank; hShapeInfo[1] = time; diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp index 204a1ca63944..5072b0fb03b3 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp @@ -104,7 +104,7 @@ DECLARE_SHAPE_FN(gruCell) { REQUIRE_TRUE(shape::rank(bc)==1 && bc[1]==nU, 0, "gruCell: cell biases must be rank 1, size nU"); Nd4jLong *s0(nullptr); - ALLOCATE(s0, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);// [bS x nU] + ALLOCATE(s0, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong);// [bS x nU] s0[0] = rank; s0[1] = bS; diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstm.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstm.cpp index 915be3129d25..fc63449585f1 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstm.cpp @@ -127,8 +127,8 @@ DECLARE_SHAPE_FN(lstm) { // evaluate output shapeInfos Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [time x bS x numProj] - ALLOCATE(cShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [time x bS x numUnits] + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); // [time x bS x numProj] + ALLOCATE(cShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); // [time x bS x numUnits] hShapeInfo[0] = cShapeInfo[0] = rank; hShapeInfo[1] = cShapeInfo[1] = time; diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlock.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlock.cpp index 3225f3f74f60..30e86a229df7 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlock.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlock.cpp @@ -92,7 +92,7 @@ DECLARE_SHAPE_FN(lstmBlock) { int nOut = cLast[2]; //rank, bs, nOut, ...] Nd4jLong *s(nullptr); - ALLOCATE(s, block.getWorkspace(), shape::shapeInfoLength(3), Nd4jLong); // [time, bS, nOut] + ALLOCATE(s, block.workspace(), shape::shapeInfoLength(3), Nd4jLong); // [time, bS, nOut] s[0] = 3; if(dataFormat == 0){ //[rank, seqLen, bs, nIn, ...] diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlockCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlockCell.cpp index 333854ba31f0..532307cc19a2 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlockCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlockCell.cpp @@ -107,7 +107,7 @@ DECLARE_SHAPE_FN(lstmBlockCell) { // evaluate output shapeInfos const int bS = xt[1]; Nd4jLong *s(nullptr); - ALLOCATE(s, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); // [bS, numUnits] + ALLOCATE(s, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); // [bS, numUnits] s[0] = 2; s[1] = bS; diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmCell.cpp index 20a9e6710c37..765dd36cc863 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmCell.cpp @@ -125,8 +125,8 @@ DECLARE_SHAPE_FN(lstmCell) { // evaluate output shapeInfos Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numProj] - ALLOCATE(cShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numUnits] + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numProj] + ALLOCATE(cShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numUnits] hShapeInfo[0] = cShapeInfo[0] = rank; hShapeInfo[1] = cShapeInfo[1] = bS; diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/sru.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/sru.cpp index 9b78a5c56885..42a88c210e65 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/sru.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/sru.cpp @@ -117,7 +117,7 @@ DECLARE_SHAPE_FN(sru) { REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str()); Nd4jLong* newShapeInfo1 = nullptr; - ALLOCATE(newShapeInfo1, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x inSize x time] + ALLOCATE(newShapeInfo1, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x inSize x time] newShapeInfo1[0] = rank; newShapeInfo1[1] = bS; @@ -126,7 +126,7 @@ DECLARE_SHAPE_FN(sru) { ShapeUtils::updateStridesAndType(newShapeInfo1, xShapeInfo, shape::order(xShapeInfo)); ShapeDescriptor descriptor(newShapeInfo1); - RELEASE(newShapeInfo1, block.getWorkspace()); + RELEASE(newShapeInfo1, block.workspace()); auto result = ConstantShapeHelper::getInstance()->createShapeInfo(descriptor); return SHAPELIST(result, result); } @@ -660,7 +660,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // char order = (char)(inShape[size-1]); // Nd4jLong* newShapeInfo1 = nullptr; -// ALLOCATE(newShapeInfo1, block.getWorkspace(), size, Nd4jLong); +// ALLOCATE(newShapeInfo1, block.workspace(), size, Nd4jLong); // newShapeInfo1[0] = rank; // newShapeInfo1[1] = bS; @@ -762,7 +762,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // char order = (char)(inShape[size-1]); // Nd4jLong *newShapeInfo1 = nullptr; -// ALLOCATE(newShapeInfo1, block.getWorkspace(), size, Nd4jLong); +// ALLOCATE(newShapeInfo1, block.workspace(), size, Nd4jLong); // newShapeInfo1[0] = rank; // newShapeInfo1[1] = bS; @@ -772,7 +772,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // ShapeUtils::updateStridesAndType(newShapeInfo1, inShape, order); // auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newShapeInfo1)); -// RELEASE(newShapeInfo1, block.getWorkspace()); +// RELEASE(newShapeInfo1, block.workspace()); // return SHAPELIST(result, result); // } diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/sruCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/sruCell.cpp index ee446037ce50..52253f74f817 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/sruCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/sruCell.cpp @@ -87,8 +87,8 @@ DECLARE_SHAPE_FN(sruCell) { // evaluate output shapeInfos Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numProj] - ALLOCATE(cShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numUnits] + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numProj] + ALLOCATE(cShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numUnits] hShapeInfo[0] = cShapeInfo[0] = rank; hShapeInfo[1] = cShapeInfo[1] = bS; diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/staticBidirectionalRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/staticBidirectionalRNN.cpp index bc27c08f69b3..57a5ddeadfcf 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/staticBidirectionalRNN.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/staticBidirectionalRNN.cpp @@ -198,9 +198,9 @@ DECLARE_SHAPE_FN(static_bidirectional_rnn) { // evaluate output shapeInfos Nd4jLong *hShapeInfo(nullptr), *hFWFinalPrevShapeInfo(nullptr), *hBWFinalPrevShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); - ALLOCATE(hFWFinalPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); - ALLOCATE(hBWFinalPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); + ALLOCATE(hFWFinalPrevShapeInfo, block.workspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); + ALLOCATE(hBWFinalPrevShapeInfo, block.workspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); hShapeInfo[0] = inRank; hFWFinalPrevShapeInfo[0] = hBWFinalPrevShapeInfo[0] = inRank-1; diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/staticRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/staticRNN.cpp index 4100f67454a2..110db1c2afb9 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/staticRNN.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/staticRNN.cpp @@ -126,8 +126,8 @@ DECLARE_SHAPE_FN(static_rnn) { // evaluate output shapeInfos Nd4jLong *hShapeInfo(nullptr), *hPrevShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); - ALLOCATE(hPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); + ALLOCATE(hPrevShapeInfo, block.workspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); hShapeInfo[0] = inRank; hPrevShapeInfo[0] = inRank-1; diff --git a/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp b/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp index 0bab69059ec0..670b09b86ca5 100644 --- a/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp @@ -49,7 +49,7 @@ namespace sd { DECLARE_SHAPE_FN(relu_layer) { auto inShape = inputShape->at(0); auto weightsShape = inputShape->at(1); - auto outputShape = ShapeUtils::matrixProductShape(inShape, weightsShape, false, false, ArrayOptions::dataType(inShape), block.getWorkspace()); + auto outputShape = ShapeUtils::matrixProductShape(inShape, weightsShape, false, false, ArrayOptions::dataType(inShape), block.workspace()); return SHAPELIST(CONSTANT(outputShape)); } diff --git a/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp b/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp index fc90c8125032..405567c509a0 100644 --- a/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp @@ -68,10 +68,10 @@ namespace sd { const int nWeightsFormat = block.numI() > 0 ? INT_ARG(0) : 0; - auto weightsShape = (1 == nWeightsFormat) ? ShapeUtils::evalTranspShapeInfo(*weights, block.getWorkspace()) : inputShape->at(1); + auto weightsShape = (1 == nWeightsFormat) ? ShapeUtils::evalTranspShapeInfo(*weights, block.workspace()) : inputShape->at(1); auto outputShape = ShapeUtils::matrixProductShape(inputShape->at(0), weightsShape, false, false, - ArrayOptions::dataType(inputShape->at(0)), block.getWorkspace()); + ArrayOptions::dataType(inputShape->at(0)), block.workspace()); return SHAPELIST(CONSTANT(outputShape)); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp index fd3315157097..b625c56a2fdf 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp @@ -36,7 +36,8 @@ namespace sd { if (!var->hasNDArrayList()) { auto list = inVar->getNDArrayList(); - block.pushNDArrayListToVariableSpace(block.nodeId(), e, list, false); + //block.pushNDArrayListToVariableSpace(block.nodeId(), e, list); + throw std::runtime_error("Expose - not implemented yet"); } } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp index 7dbed736d8ad..7262d7fa71a1 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp @@ -63,8 +63,8 @@ namespace sd { Nd4jLong* meanShape = nullptr; Nd4jLong* varianceShape = nullptr; - COPY_SHAPE_EX(in, meanShape, block.getWorkspace()); - COPY_SHAPE_EX(in, varianceShape, block.getWorkspace()); + COPY_SHAPE_EX(in, meanShape, block.workspace()); + COPY_SHAPE_EX(in, varianceShape, block.workspace()); auto shapeList = SHAPELIST(); shapeList->push_back(CONSTANT(meanShape)); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp index 10970e965938..98b24cfcdcb9 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp @@ -51,7 +51,7 @@ namespace sd { int outRank = shape::rank(in) - 1; Nd4jLong *outputShape = nullptr; if (outRank > 1) { - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outputShape[0] = outRank; for (Nd4jLong e = 0; e < outRank; e++) outputShape[e + 1] = in[e + 1]; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp index 7ab19668a8ca..11d41ee5ee50 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp @@ -53,7 +53,7 @@ namespace sd { int numOfClasses = val + 1; - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outputShape[0] = outRank; outputShape[1] = numOfClasses; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp index abb865d8e933..a0b0e91a3e2b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp @@ -51,7 +51,7 @@ namespace sd { int numOfClasses = val + 1; - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outputShape[0] = outRank; outputShape[1] = numOfClasses; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp index a245b000b4fd..0b75a5416fe6 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp @@ -51,7 +51,7 @@ namespace sd { int numOfClasses = val + 1; - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outputShape[0] = outRank; outputShape[1] = numOfClasses; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp index 478eb9e233a8..d600d3935115 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp @@ -51,7 +51,7 @@ namespace sd { int numOfClasses = val + 1; - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outputShape[0] = outRank; outputShape[1] = numOfClasses; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp index bb959fd3d3e5..3fd84c1aebbb 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp @@ -51,7 +51,7 @@ namespace sd { int numOfClasses = static_cast(val) + 1; - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outputShape[0] = outRank; outputShape[1] = numOfClasses; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp index fa4625288235..5ac400478905 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp @@ -85,7 +85,7 @@ namespace sd { } int lastDimension = maxInd; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outShapeInfo[0] = outRank; for(int i = 0; i < outRank - 1; ++i) outShapeInfo[i + 1] = shape::sizeAt(in, i); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp index 7995727940a9..360e7fdd437b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp @@ -69,7 +69,7 @@ namespace sd { for (int e = 0; e < 2; e++) { // 2 element tuple at output Nd4jLong* aShape; - ALLOCATE(aShape, block.getWorkspace(), shape::shapeInfoLength(shapeRank), Nd4jLong); + ALLOCATE(aShape, block.workspace(), shape::shapeInfoLength(shapeRank), Nd4jLong); aShape[0] = shapeRank; for (int i = 1 ; i < shapeRank; ++i) aShape[i] = shape::sizeAt(in, i - 1); @@ -78,7 +78,7 @@ namespace sd { shape::updateStrides(aShape, shape::order(in)); shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(aShape, (e == 0?ArrayOptions::dataType(in):sd::DataType::INT64)))); - RELEASE(aShape, block.getWorkspace()); + RELEASE(aShape, block.workspace()); } return shapeList; } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unique.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unique.cpp index 64b915c53040..4ed1c3850114 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unique.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unique.cpp @@ -55,7 +55,7 @@ namespace sd { // second output is always LONG indicesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::length(in), sd::DataType::INT64); - //COPY_SHAPE_EX(in, indicesShape, block.getWorkspace()); + //COPY_SHAPE_EX(in, indicesShape, block.workspace()); return SHAPELIST(valuesShape, indicesShape); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp index 77e851104962..1f10e95f127c 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp @@ -54,7 +54,7 @@ namespace sd { Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); Nd4jLong* outputShape; - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outputShape[0] = outRank; outputShape[1] = numOfClasses; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp index cad59b7e9657..d5ce627c3c97 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp @@ -56,7 +56,7 @@ namespace sd { Nd4jLong* outputShape = nullptr; Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outputShape[0] = outRank; outputShape[1] = numOfClasses; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp index 87b96e844e89..1a67a428ec75 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp @@ -48,7 +48,7 @@ namespace sd { Nd4jLong* outputShape = nullptr; Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outputShape[0] = outRank; outputShape[1] = numOfClasses; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp index e430c8f7713a..4f9d2269b298 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp @@ -48,7 +48,7 @@ namespace sd { Nd4jLong* outputShape = nullptr; Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outputShape[0] = outRank; outputShape[1] = numOfClasses; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp index eeaa6e2c2f6d..376b12ae5b18 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp @@ -48,7 +48,7 @@ namespace sd { Nd4jLong* outputShape = nullptr; Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outputShape[0] = outRank; outputShape[1] = numOfClasses; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp index 9414964246bc..b786ba581d34 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp @@ -55,7 +55,7 @@ namespace sd { Nd4jLong* outputShape = nullptr; Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outputShape[0] = outRank; outputShape[1] = numOfClasses; diff --git a/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp b/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp index 151b5a022492..f0c27da53d31 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp @@ -82,7 +82,7 @@ namespace sd { return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT64)); } - return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), DataType::INT64, false, false, block.getWorkspace())); + return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), DataType::INT64, false, false, block.workspace())); } } } diff --git a/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp b/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp index ef7f58b5e049..a9c7951a8774 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp @@ -81,7 +81,7 @@ namespace sd { return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64)); } - auto newShape = ShapeUtils::evalReduceShapeInfo('c', dims, in, DataType::INT64, false, false, block.getWorkspace()); + auto newShape = ShapeUtils::evalReduceShapeInfo('c', dims, in, DataType::INT64, false, false, block.workspace()); return SHAPELIST(newShape); } diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp index 01de97b5f0ba..aefea142485d 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp @@ -72,7 +72,7 @@ DECLARE_SHAPE_FN(reduce_mean) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_MEAN OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(in), dimensions, in, keepDims, false, block.getWorkspace()); + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(in), dimensions, in, keepDims, false, block.workspace()); return SHAPELIST(outShapeInfo); } @@ -116,7 +116,7 @@ CUSTOM_OP_IMPL(reduce_mean_bp, 2, 1, false, 0, 0) { gradI->assign((gradO->lengthOf() + 0.) / input->lengthOf()); if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); + auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] } else diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp index 3e0491ffefaf..2cba7b79c61c 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp @@ -83,7 +83,7 @@ DECLARE_SHAPE_FN(reduce_stdev) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_STDEV OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - Nd4jLong* outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(in), dimensions, in, keepDims, false, block.getWorkspace()); + Nd4jLong* outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(in), dimensions, in, keepDims, false, block.workspace()); return SHAPELIST(outShapeInfo); } @@ -138,7 +138,7 @@ CUSTOM_OP_IMPL(reduce_stdev_bp, 2, 1, false, 0, 0) { gradI->assign( (*input - mean) / (variance * NminusOne)); // automatic broadcasting happens here if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); + auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] } else diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp index 00a4f6d7e8de..1c60f53ab76e 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp @@ -81,7 +81,7 @@ DECLARE_SHAPE_FN(reduce_variance) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_VARIANCE OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace()); + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace()); return SHAPELIST(outShapeInfo); } @@ -134,7 +134,7 @@ CUSTOM_OP_IMPL(reduce_variance_bp, 2, 1, false, 0, 0) { gradI->assign( (*input - mean) * (2.0f / NminusOne)); // automatic broadcasting happens here if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); + auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] } else *gradI *= *gradO; // automatic broadcasting happens here diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp index 7f0c822a8729..f64bb86444cc 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp @@ -66,7 +66,7 @@ CUSTOM_OP_IMPL(reduce_dot_bp, 3, 2, false, 0, 0) { REQUIRE_TRUE(item >= -x->rankOf() && item < x->rankOf(), 0, "REDUCE_DOT_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , x->rankOf(), x->rankOf(), item); if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *x, true, false, block.getWorkspace()); + auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *x, true, false, block.workspace()); auto r = gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] gradX->assign((*y) * r); diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp index 2acade3202db..242bbd2b9983 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp @@ -70,7 +70,7 @@ namespace ops { axes = block.getIArguments(); } - Nd4jLong* outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), axes, inputShape->at(0), keepDims, false, block.getWorkspace()); + Nd4jLong* outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), axes, inputShape->at(0), keepDims, false, block.workspace()); return SHAPELIST(outShapeInfo); } diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp index f70e613e8da1..72c4ab4768e2 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp @@ -77,7 +77,7 @@ DECLARE_SHAPE_FN(reduce_max) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_MAX OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - Nd4jLong* outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace()); + Nd4jLong* outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace()); return SHAPELIST(outShapeInfo); } diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp index 85addc6121c4..735b515c586f 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp @@ -77,7 +77,7 @@ DECLARE_SHAPE_FN(reduce_min) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_MIN OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - Nd4jLong* outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace()); + Nd4jLong* outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace()); return SHAPELIST(outShapeInfo); } diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp index 5f15c0e8e89f..273497548f00 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp @@ -77,7 +77,7 @@ DECLARE_SHAPE_FN(reduce_norm1) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_NORM1 OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace())); + return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace())); } DECLARE_TYPES(reduce_norm1) { @@ -128,7 +128,7 @@ CUSTOM_OP_IMPL(reduce_norm1_bp, 2, 1, false, 0, 0) { // *** calculations *** // if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); + auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] } else *gradI *= *gradO; diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp index 60d7385a1c34..329b567e544f 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp @@ -77,7 +77,7 @@ DECLARE_SHAPE_FN(reduce_norm2) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_NORM2 OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace())); + return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace())); } DECLARE_TYPES(reduce_norm2) { @@ -127,7 +127,7 @@ CUSTOM_OP_IMPL(reduce_norm2_bp, 2, 1, false, 0, 0) { *gradI /= input->reduceAlongDimension(reduce::Norm2, dimensions, true); if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); + auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] } else *gradI *= *gradO; diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp index 4a7d18f17dc9..89cc2a5530ee 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp @@ -79,7 +79,7 @@ DECLARE_SHAPE_FN(reduce_norm_max) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_NORM_MAX OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(in), dimensions, in, keepDims, false, block.getWorkspace())); + return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(in), dimensions, in, keepDims, false, block.workspace())); } DECLARE_TYPES(reduce_norm_max) { diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp index b35b5d77fd9d..e111597ea49b 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp @@ -77,7 +77,7 @@ DECLARE_SHAPE_FN(reduce_prod) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_PROD OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace())); + return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace())); } DECLARE_TYPES(reduce_prod) { @@ -128,7 +128,7 @@ CUSTOM_OP_IMPL(reduce_prod_bp, 2, 1, false, 0, 0) { *gradI /= *input; if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); + auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] } else *gradI *= *gradO; diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp index 25c8688b2113..1f70cd57e11e 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp @@ -75,7 +75,7 @@ DECLARE_SHAPE_FN(reduce_sqnorm) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_SQNORM OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - Nd4jLong* outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace()); + Nd4jLong* outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace()); return SHAPELIST(outShapeInfo); } @@ -123,7 +123,7 @@ CUSTOM_OP_IMPL(reduce_sqnorm_bp, 2, 1, false, 0, 0) { // *** calculations *** // if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); + auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); gradI->assign(2. * (*input) *gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims))); // for example could be something like [a,b] -> [1,a,1,b] } else gradI->assign(2. * (*input) * *gradO); diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp index 3aca5854a8a5..e23ef5502b94 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp @@ -77,7 +77,7 @@ DECLARE_SHAPE_FN(reduce_sum) { for(const auto& item : dimensions) REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_SUM OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.getWorkspace())); + return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace())); } DECLARE_TYPES(reduce_sum) { @@ -121,7 +121,7 @@ CUSTOM_OP_IMPL(reduce_sum_bp, 2, 1, false, 0, 0) { // *** calculations *** // if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.getWorkspace()); + auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); auto r = gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] gradI->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), r, *gradI); } else diff --git a/libnd4j/include/ops/declarable/generic/tensor/fill.cpp b/libnd4j/include/ops/declarable/generic/tensor/fill.cpp index 18b9ce2b8ac5..5a35cee2cf84 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/fill.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/fill.cpp @@ -68,7 +68,7 @@ namespace sd { auto shapeArray = INPUT_VARIABLE(0); const int len = (int) shapeArray->lengthOf(); Nd4jLong *newShape = nullptr; - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(len), Nd4jLong); + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(len), Nd4jLong); newShape[0] = len; for (int e = 0; e < shapeArray->lengthOf(); e++){ diff --git a/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp b/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp index 6466b9c1cc7e..39389ba9757c 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp @@ -412,7 +412,7 @@ namespace sd { // else { if (indices.size()) { Nd4jLong* subArrShapeInfo = nullptr; - ALLOCATE(subArrShapeInfo, block.getWorkspace(), shape::shapeInfoLength(x->rankOf()), Nd4jLong); + ALLOCATE(subArrShapeInfo, block.workspace(), shape::shapeInfoLength(x->rankOf()), Nd4jLong); Nd4jLong offset; shape::calcSubArrShapeInfoAndOffset(indices.data(), x->getShapeInfo(), subArrShapeInfo, offset, true, true); @@ -428,7 +428,7 @@ namespace sd { NDArray::registerSpecialUse({z}, {x}); - RELEASE(subArrShapeInfo, block.getWorkspace()); + RELEASE(subArrShapeInfo, block.workspace()); } else if (!z->isEmpty()){ z->assign(x->e(0)); diff --git a/libnd4j/include/ops/declarable/generic/tests/test_scalar.cpp b/libnd4j/include/ops/declarable/generic/tests/test_scalar.cpp index 437222052caa..76510720a30d 100644 --- a/libnd4j/include/ops/declarable/generic/tests/test_scalar.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/test_scalar.cpp @@ -51,7 +51,7 @@ namespace sd { ArrayOptions::setDataType(newShape, ArrayOptions::dataType(inputShape->at(0))); auto shape = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newShape)); - RELEASE(newShape, block.getWorkspace()); + RELEASE(newShape, block.workspace()); return SHAPELIST(shape); } diff --git a/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp b/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp index 283efde0cbe4..86368d98404f 100644 --- a/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp @@ -35,12 +35,12 @@ namespace sd { DECLARE_SHAPE_FN(testcustom) { // this test op will just return back original shape doubled Nd4jLong *shapeOf; - ALLOCATE(shapeOf, block.getWorkspace(), shape::rank(inputShape->at(0)), Nd4jLong); + ALLOCATE(shapeOf, block.workspace(), shape::rank(inputShape->at(0)), Nd4jLong); for (int e = 0; e < shape::rank(inputShape->at(0)); e++) shapeOf[e] = inputShape->at(0)[e+1] * 2; auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataType::FLOAT32, 'c', shape::rank(inputShape->at(0)), shapeOf); - RELEASE(shapeOf, block.getWorkspace()); + RELEASE(shapeOf, block.workspace()); return SHAPELIST(newShape); } diff --git a/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp b/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp index dfb9a9fc2916..2fe7ddef62b3 100644 --- a/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp +++ b/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp @@ -70,14 +70,14 @@ namespace sd { auto row = rows.at(r); for (int e = 0; e < numColumns; e += 2) { - int idx = row->e(e); + int idx = row.e(e); if (idx < 0) break; int denseIdx = sparse2dense.at(idx); - float value = row->e(e); + float value = row.e(e); float current = z->e(r, denseIdx); z->p(r, denseIdx, value + current); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp index f78558a93db8..9ce344c8d8cd 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp @@ -188,10 +188,10 @@ DECLARE_SHAPE_FN(concat) { // delete dynamically allocated vectors shapes with length=1 for(int index : shapesToDelete) - RELEASE(arrShapes[index], block.getWorkspace()); + RELEASE(arrShapes[index], block.workspace()); auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo)); - RELEASE(outShapeInfo, block.getWorkspace()); + RELEASE(outShapeInfo, block.workspace()); return SHAPELIST(result); } @@ -328,7 +328,7 @@ DECLARE_SHAPE_FN(concat) { // // all scalars // if (allScalars) { - // ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong); + // ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(1), Nd4jLong); // shape::shapeBuffer(1, &elements, newShape); // return SHAPELIST(newShape); @@ -336,7 +336,7 @@ DECLARE_SHAPE_FN(concat) { // // any scalar // if (hasScalars) { - // ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong); + // ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(1), Nd4jLong); // Nd4jLong length = shape::length(inp); // for (int i = 1; i < block.width(); i++) { // auto c = INPUT_VARIABLE(i); @@ -352,7 +352,7 @@ DECLARE_SHAPE_FN(concat) { // } - // ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(first->shapeInfo()), Nd4jLong); + // ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(first->shapeInfo()), Nd4jLong); // if (_dimension < 0) // _dimension += first->rankOf(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp b/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp index 6a055c02c380..8ab7ba1e053d 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp @@ -69,7 +69,7 @@ namespace ops { int outRank = shape::rank(in) - shape::rank(idx) + 1; for (int e = 0; e < numPartition; e++) { Nd4jLong *newShape; - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); //shape::shapeVector(partitionSizes[e], newShape); newShape[0] = outRank; newShape[1] = partitionSizes[e]; @@ -117,13 +117,13 @@ namespace ops { ops::dynamic_stitch stichOp; std::vector partitions(numPartition * 2); for (size_t i = 0; i < res.size(); i++) { - partitions[i] = res.at(i); + partitions[i] = &res.at(i); partitions[i + numPartition] = gradOutList[i]; } auto result = stichOp.evaluate(partitions, {numPartition}); REQUIRE_TRUE(result.status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning."); - result.at(0)->reshapei(outputList[0]->getShapeAsVector()); + result.at(0).reshapei(outputList[0]->getShapeAsVector()); outputList[1]->assign(indices); outputList[0]->assign(result.at(0)); diff --git a/libnd4j/include/ops/declarable/generic/transforms/gather.cpp b/libnd4j/include/ops/declarable/generic/transforms/gather.cpp index 848f72c998fe..e8db96837d4b 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/gather.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/gather.cpp @@ -118,7 +118,7 @@ DECLARE_SHAPE_FN(gather) { int outputRank = inputRank + indicesRank - 1; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outputRank), Nd4jLong); + ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(outputRank), Nd4jLong); // fill output shapeInfo outputShapeInfo[0] = outputRank; @@ -138,7 +138,7 @@ DECLARE_SHAPE_FN(gather) { int indicesRank = block.numI() == 2 ? 0 : 1; int outputRank = inputRank + indicesRank - 1; - ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outputRank), Nd4jLong); + ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(outputRank), Nd4jLong); // building shape manually outputShapeInfo[0] = outputRank; @@ -162,7 +162,7 @@ DECLARE_SHAPE_FN(gather) { } auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outputShapeInfo)); - RELEASE(outputShapeInfo, block.getWorkspace()); + RELEASE(outputShapeInfo, block.workspace()); return SHAPELIST(result); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp b/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp index cf0081c76046..e5b895ec1962 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp @@ -76,7 +76,7 @@ DECLARE_SHAPE_FN(gather_nd) { int outRank = (rankInd - 1) + (rankIn - lastIndDim); Nd4jLong* outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); outShapeInfo[0] = outRank; diff --git a/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp b/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp index a4b934853e3a..d5f50f40d175 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp @@ -86,7 +86,7 @@ DECLARE_SHAPE_FN(mirror_pad) { outShapeInfo = ConstantShapeHelper::getInstance()->vectorShapeInfo(len, input->dataType()); } else { - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); outShapeInfo[0] = rank; for(int i = 0; i < rank; ++i) outShapeInfo[i+1] = input->sizeAt(i) + paddings->e(i,0) + paddings->e(i,1); diff --git a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp index 9c08662857b7..0c5dfc7f0bb2 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp @@ -96,14 +96,14 @@ DECLARE_SHAPE_FN(pad) { REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, "PAD op: wrong shape of paddings array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedPaddingsShape).c_str(), ShapeUtils::shapeAsString(currentPaddingsShape).c_str()); Nd4jLong* outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); outShapeInfo[0] = rank; for(int i=1; i <= rank; ++i) outShapeInfo[i] = inputShapeInfo[i] + paddings->e(i-1,0) + paddings->e(i-1,1); ShapeUtils::updateStridesAndType(outShapeInfo, inputShapeInfo, shape::order(inputShapeInfo)); ShapeDescriptor descriptor(outShapeInfo); - RELEASE(outShapeInfo, block.getWorkspace()); + RELEASE(outShapeInfo, block.workspace()); return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor)); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/parallelStack.cpp b/libnd4j/include/ops/declarable/generic/transforms/parallelStack.cpp index b6a2ba1e1cf4..f91ad39d4f10 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/parallelStack.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/parallelStack.cpp @@ -58,7 +58,7 @@ DECLARE_SHAPE_FN(parallel_stack) { int rank = inShapeInfo[0]; Nd4jLong* outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank+1), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank+1), Nd4jLong); outShapeInfo[0] = rank + 1; outShapeInfo[1] = block.width(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp index 564a8f478582..1a907b1331a8 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp @@ -82,7 +82,7 @@ namespace ops { auto updShapeInfo = inputShape->at(1); Nd4jLong *outShapeInfo; - ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(shape->lengthOf()), Nd4jLong); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(shape->lengthOf()), Nd4jLong); outShapeInfo[0] = shape->lengthOf(); for (int i = 0; i < outShapeInfo[0]; ++i) diff --git a/libnd4j/include/ops/declarable/generic/transforms/slice.cpp b/libnd4j/include/ops/declarable/generic/transforms/slice.cpp index bd8423504bc6..814780bfa285 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/slice.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/slice.cpp @@ -83,7 +83,7 @@ namespace sd { } Nd4jLong* subArrShapeInfo = nullptr; - ALLOCATE(subArrShapeInfo, block.getWorkspace(), shape::shapeInfoLength(input->rankOf()), Nd4jLong); + ALLOCATE(subArrShapeInfo, block.workspace(), shape::shapeInfoLength(input->rankOf()), Nd4jLong); Nd4jLong offset; @@ -101,7 +101,7 @@ namespace sd { NDArray::registerSpecialUse({output}, {input}); - RELEASE(subArrShapeInfo, block.getWorkspace()); + RELEASE(subArrShapeInfo, block.workspace()); // auto sub = (*input)(indices, true); // output->assign(sub); diff --git a/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp b/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp index 64be499fb719..11700be2c374 100644 --- a/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp +++ b/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp @@ -58,7 +58,7 @@ namespace ops { DECLARE_SHAPE_FN(barnes_edge_forces) { Nd4jLong* bufShape; Nd4jLong* outShapeInfo; - outShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShape->at(3), inputShape->at(3), false, block.getWorkspace()); + outShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShape->at(3), inputShape->at(3), false, block.workspace()); return SHAPELIST(CONSTANT(outShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp b/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp index 6e3c562d47fb..3f02bdd79c2a 100644 --- a/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp +++ b/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp @@ -81,9 +81,9 @@ namespace sd { // outShapeInfo[2] = len; // ShapeUtils::updateStridesAndType(outShapeInfo, ArrayOptions::dataType(valPShapeInfo), 'c'); //outShapeInfo = ShapeBuilders::createVectorShapeInfo(ArrayOptions::dataType(valPShapeInfo), len, block.workspace()); - outShapeInfo = sd::ShapeBuilders::createShapeInfo(ArrayOptions::dataType(valPShapeInfo), 'c', { 1, len }, block.getWorkspace()); - auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, len }, block.getWorkspace()); - auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, N + 1 }, block.getWorkspace()); + outShapeInfo = sd::ShapeBuilders::createShapeInfo(ArrayOptions::dataType(valPShapeInfo), 'c', { 1, len }, block.workspace()); + auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, len }, block.workspace()); + auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, N + 1 }, block.workspace()); return SHAPELIST(CONSTANT(outRowsShapeInfo), CONSTANT(outColsShapeInfo), CONSTANT(outShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp b/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp index d4240d7803af..2b02e1a46a57 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp @@ -51,7 +51,7 @@ static void clipByNorm_(NDArray& input, NDArray& output, const std::vector& for (auto i = start; i < stop; i++) { const T iNormActual = norm2.e(i); if (iNormActual > normClip) - *listOfInSubArrs.at(i) *= normClip / iNormActual; + listOfInSubArrs.at(i) *= normClip / iNormActual; } }; samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size()); @@ -75,12 +75,12 @@ static void clipByNorm_(NDArray& input, NDArray& output, const std::vector& for (auto i = start; i < stop; i++) { auto inputSubArr = listOfInSubArrs.at(i); auto outputSubArr = listOfOutSubArrs.at(i); - outputSubArr->assign(inputSubArr); + outputSubArr.assign(inputSubArr); const T iNormActual = norm2.e(i); if (iNormActual > clipNorm.e(0)) - *outputSubArr *= clipNorm / iNormActual; + outputSubArr *= clipNorm / iNormActual; } }; samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size()); @@ -178,7 +178,7 @@ static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& g if (N > cn) { auto inputSubArr = inputSubArrs.at(i); - const T sumOfProd = (*inputSubArr * *gradOSubArr).reduceNumber(reduce::Sum).e(0); // reduce to scalar + const T sumOfProd = (inputSubArr * gradOSubArr).reduceNumber(reduce::Sum).e(0); // reduce to scalar const T factor1 = static_cast(1.f) / N; const T factor3 = factor1 / (N * N); // 1 / (N*N*N) @@ -186,9 +186,9 @@ static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& g return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd); }; - inputSubArr->applyPairwiseLambda(*gradOSubArr, lambda, *gradISubArr); + inputSubArr.applyPairwiseLambda(gradOSubArr, lambda, gradISubArr); } else - gradISubArr->assign(gradOSubArr); + gradISubArr.assign(gradOSubArr); } }; samediff::Threads::parallel_tad(func, 0, gradISubArrs.size()); @@ -228,11 +228,11 @@ static void clipByAveraged_(NDArray& input, NDArray& output, const std::vector(e) / tads.at(e)->lengthOf(); + T n2 = norm2.e(e) / tads.at(e).lengthOf(); const T factor = cn / n2; if (n2 > cn) { auto lambda = LAMBDA_T(_x, factor) {return _x * factor;}; - tads.at(e)->applyLambda(lambda, output); + tads.at(e).applyLambda(lambda, output); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp b/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp index 685d80d2dcf0..b524dc1eacb2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp @@ -36,7 +36,7 @@ namespace helpers { auto label = labels->e(j); auto pred = predictions->e(j); T value = (weights == nullptr ? (T) 1.0f : weights->e(j)); - arrs.at(label)->p(pred, value); + arrs.at(label).p(pred, value); } }; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp b/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp index 51af1840be42..7f662e1c2845 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp @@ -44,7 +44,7 @@ void crossBatched(sd::LaunchContext * context, NDArray *a, NDArray *b, NDArray * auto b_ = tadsB.at(e); auto o_ = tadsO.at(e); - helpers::cross(context, a_, b_, o_); + helpers::cross(context, &a_, &b_, &o_); } }; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp b/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp index 2b6b4cd02cde..5547c92832f6 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp @@ -55,7 +55,7 @@ namespace sd { //PRAGMA_OMP_PARALLEL_FOR_IF(indices->lengthOf() > Environment::getInstance()->elementwiseThreshold()) for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) if ((*indices).e(e) == i) - listOutForCurrent.at(outputs[i].second++)->assign(listOfTensors.at(e)); + listOutForCurrent.at(outputs[i].second++).assign(listOfTensors.at(e)); } } else { @@ -126,7 +126,7 @@ namespace sd { return ND4J_STATUS_VALIDATION; } - listOfOutTensors.at(pos)->assign(listOfTensors.at(i)); + listOfOutTensors.at(pos).assign(listOfTensors.at(i)); } } } @@ -160,7 +160,7 @@ namespace sd { for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) if (indices->e(e) == i) - listOfTensors.at(e)->assign(listOutForCurrent.at(outputs[i].second++)); + listOfTensors.at(e).assign(listOutForCurrent.at(outputs[i].second++)); } } else { // one-dimensional case diff --git a/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp b/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp index 15ea569e8687..377ea559fffe 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp @@ -73,7 +73,7 @@ namespace helpers { bool setUp = (theSame && row >= 0 && col >= 0 && row < rowDim && col < colDim) || (!theSame); if (setUp) { - outMatrix->t(i, j, pos) = patch->e(row, col, pixel); + outMatrix.t(i, j, pos) = patch.e(row, col, pixel); } pos++; } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/eye.cpp b/libnd4j/include/ops/declarable/helpers/cpu/eye.cpp index 30a83b8713c9..9c343eafdf0b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/eye.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/eye.cpp @@ -34,7 +34,7 @@ void eye(sd::LaunchContext * context, NDArray& output) { auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i++) - arrs.at(i)->setIdentity(); + arrs.at(i).setIdentity(); }; samediff::Threads::parallel_tad(func, 0, arrs.size()); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp index 554486bbf9ce..32bb1833ec63 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp @@ -40,7 +40,7 @@ namespace helpers { for (auto x = 0; x < lastDims.size(); x++) { for (auto r = 0; r < rows; r++) { - lastDims[x]->t(r,r) = (T)value; + lastDims[x].t(r,r) = (T)value; } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index 8938a98f9e0b..8fc2ece79e2f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -324,7 +324,7 @@ namespace helpers { auto loop = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i++) { - luNN_(context, outputs.at(i), permutationVectors?permutations.at(i):nullptr, n); + luNN_(context, &outputs.at(i), permutationVectors ? &permutations.at(i) : nullptr, n); } }; samediff::Threads::parallel_for(loop, 0, outputs.size(), 1); @@ -342,7 +342,7 @@ namespace helpers { Nd4jLong n = input->sizeAt(-1); Nd4jLong n2 = n * n; - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace()); + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.workspace()); for (int e = 0; e < output->lengthOf(); e++) { for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) @@ -363,7 +363,7 @@ template Nd4jLong n = input->sizeAt(-1); Nd4jLong n2 = n * n; - NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace()); + NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.workspace()); for (int e = 0; e < output->lengthOf(); e++) { for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) { matrix.p(row, input->e(k)); @@ -388,8 +388,8 @@ template auto totalCount = output->lengthOf() / n2; output->assign(0.f); // fill up output tensor with zeros - auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); - auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); + auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.workspace()); + auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.workspace()); auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); @@ -441,8 +441,8 @@ template auto totalCount = output->lengthOf() / n2; output->assign(0.f); // fill up output tensor with zeros - auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); - auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); + auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.workspace()); + auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.workspace()); auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); @@ -484,14 +484,14 @@ template auto n2 = n * n; output->nullify(); // fill up output tensor with zeros -// auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); +// auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.workspace()); // auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); // auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); auto inputPart = input->allTensorsAlongDimension({-2, -1}); auto outputPart = output->allTensorsAlongDimension({-2, -1}); auto totalCount = outputPart.size(); //lengthOf() / n2; for (int e = 0; e < totalCount; e++) { - invertUpperMatrix(inputPart.at(e), outputPart.at(e)); + invertUpperMatrix(&inputPart.at(e), &outputPart.at(e)); } return Status::OK(); } @@ -510,20 +510,20 @@ template template static bool checkCholeskyInput_(sd::LaunchContext * context, NDArray const* input) { - //std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType())); //, block.getWorkspace()); + //std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType())); //, block.workspace()); ResultSet lastMatrixList = input->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf()-1}); for (size_t i = 0; i < lastMatrixList.size(); i++) { auto thisMatrix = lastMatrixList.at(i); // check for symmetric - for (Nd4jLong r = 0; r < thisMatrix->rows(); r++) - for (Nd4jLong c = 0; c < thisMatrix->columns(); c++) - if (sd::math::nd4j_abs(thisMatrix->e(r, c) - lastMatrixList.at(i)->e(c,r)) > DataTypeUtils::min()) return false; + for (Nd4jLong r = 0; r < thisMatrix.rows(); r++) + for (Nd4jLong c = 0; c < thisMatrix.columns(); c++) + if (sd::math::nd4j_abs(thisMatrix.e(r, c) - lastMatrixList.at(i).e(c,r)) > DataTypeUtils::min()) return false; NDArray output = NDArrayFactory::create(0., context); - if (ND4J_STATUS_OK != determinant(context, thisMatrix, &output)) return false; + if (ND4J_STATUS_OK != determinant(context, &thisMatrix, &output)) return false; if (output.e(0) <= T(0)) return 0; - NDArray reversedMatrix(*thisMatrix); - if (ND4J_STATUS_OK != inverse(context, thisMatrix, &reversedMatrix)) return false; + NDArray reversedMatrix(thisMatrix); + if (ND4J_STATUS_OK != inverse(context, &thisMatrix, &reversedMatrix)) return false; if (ND4J_STATUS_OK != determinant(context, &reversedMatrix, &output)) return false; if (output.e(0) <= T(0)) return 0; @@ -546,7 +546,7 @@ template if (!inplace) output->assign(0.f); // fill up output tensor with zeros only inplace=false - std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); //, block.getWorkspace()); + std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); //, block.workspace()); std::unique_ptr lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), context)); for (int e = 0; e < totalCount; e++) { @@ -597,7 +597,7 @@ template for (Nd4jLong e = 0; e < totalCount; e++) { for (size_t i = 0; i < n; ++i) - output->t(e) += sd::math::nd4j_log(sd::math::nd4j_pow(matricies.at(e)->t(i, i), T(2))); + output->t(e) += sd::math::nd4j_log(sd::math::nd4j_pow(matricies.at(e).t(i, i), T(2))); } return ND4J_STATUS_OK; } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp index d83f0dab94de..660184afd93d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp @@ -33,15 +33,15 @@ namespace helpers { ResultSet listOut = output->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}); ResultSet listDiag = input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}); for (Nd4jLong e = 0; e < static_cast(listOut.size()); ++e) { - NDArray* inputMatrix = listDiag.at(e); - NDArray* outputMatrix = listOut.at(e); - if (outputMatrix != inputMatrix) // if not inplace - outputMatrix->assign(inputMatrix); + auto inputMatrix = listDiag.at(e); + auto outputMatrix = listOut.at(e); + if (outputMatrix.platformBuffer() != inputMatrix.platformBuffer()) // if not inplace + outputMatrix.assign(inputMatrix); if (lowerBand >= 0) { - for (Nd4jLong row = 0; row < inputMatrix->rows(); ++row) { + for (Nd4jLong row = 0; row < inputMatrix.rows(); ++row) { for (Nd4jLong col = 0; col < row; ++col) { if ((row - col) > lowerBand) - outputMatrix->p(row, col, 0.); + outputMatrix.p(row, col, 0.); // else // (*outputMatrix)(row, col) = (*inputMatrix)(row, col); } @@ -49,10 +49,10 @@ namespace helpers { } } if (upperBand >= 0) { - for (Nd4jLong col = 0; col < inputMatrix->columns(); ++col) { + for (Nd4jLong col = 0; col < inputMatrix.columns(); ++col) { for (Nd4jLong row = 0; row < col; ++row) { if ((col - row) > upperBand) - outputMatrix->p(row, col, 0.); + outputMatrix.p(row, col, 0.); // else // (*outputMatrix)(row, col) = (*inputMatrix)(row, col); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp index 3271dc110cab..afd4aa64d597 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp @@ -48,7 +48,7 @@ int _matrixDiagPart(const NDArray* input, NDArray* output) { auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i++) for (int j = 0; j < lastDimension; ++j) - listOut.at(i)->p(j, listDiag.at(i)->e(j, j)); + listOut.at(i).p(j, listDiag.at(i).e(j, j)); }; samediff::Threads::parallel_tad(func, 0, lO); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp b/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp index 336eacf20dfd..a6c167f27f87 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp @@ -43,7 +43,7 @@ void meshgrid(sd::LaunchContext * context, const std::vector& inArrs, for(int i = 0; i < rank; ++i) { auto list = outArrs[i]->allTensorsAlongDimension({inIndices[i]}); for(int j = 0; j < list.size(); ++j) - list.at(j)->assign(inArrs[i]); + list.at(j).assign(inArrs[i]); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp b/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp index 2730d9e886d5..9ee5d7b45f66 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp @@ -57,7 +57,7 @@ namespace helpers { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e++) { auto row = rows.at(e); - output->p(e, row->e(n)); + output->p(e, row.e(n)); } }; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp b/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp index 3ffa4dd8234a..7c85f3e0e7c2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp @@ -42,9 +42,9 @@ static void _percentile(const NDArray& input, NDArray& output, std::vector& auto listOfSubArrs = input.allTensorsAlongDimension(axises); - std::vector shapeOfSubArr(listOfSubArrs.at(0)->rankOf()); + std::vector shapeOfSubArr(listOfSubArrs.at(0).rankOf()); for(int i=0; ishapeOf()[i]; + shapeOfSubArr[i] = listOfSubArrs.at(0).shapeOf()[i]; auto flattenedArr = NDArrayFactory::create('c', shapeOfSubArr, input.dataType(), input.getContext()); const int len = flattenedArr.lengthOf(); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp b/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp index 5307f841ea3c..27ababdd108f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp @@ -107,7 +107,7 @@ namespace sd { auto tx = xTads.at(e); auto tz = zTads.at(e); - prefix_(op, tx->buffer(), tx->shapeInfo(), tz->buffer(), tz->shapeInfo(), exclusive, reverse); + prefix_(op, tx.buffer(), tx.shapeInfo(), tz.buffer(), tz.shapeInfo(), exclusive, reverse); } }; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp b/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp index 1f980e553697..c2bbdb7c870b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp @@ -115,7 +115,7 @@ namespace helpers { auto batching = PRAGMA_THREADS_FOR { for (auto batch = start; batch < stop; batch++) { //qr here - qrSingle(listInput.at(batch), listOutQ.at(batch), listOutR.at(batch), fullMatricies); + qrSingle(&listInput.at(batch), &listOutQ.at(batch), &listOutR.at(batch), fullMatricies); } }; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp b/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp index 7323c393730f..90de864aa98c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp @@ -90,7 +90,7 @@ void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& if(i == r) continue; - subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r)); + subArrsListIn.at(i).swapUnsafe(subArrsListIn.at(r)); } } else { @@ -102,16 +102,16 @@ void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold()) for(int i = firstDim - 1; i > 0; --i) { int r = rng.relativeInt(i) % i; - subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r])); + subArrsListOut.at(i).assign(subArrsListIn.at(indices[r])); if(r == 0) isZeroShuffled = true; if(i == r) continue; - subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i])); + subArrsListOut.at(r).assign(subArrsListIn.at(indices[i])); math::nd4j_swap(indices[i], indices[r]); } if(!isZeroShuffled) - subArrsListOut.at(0)->assign(subArrsListIn.at(0)); + subArrsListOut.at(0).assign(subArrsListIn.at(0)); } rng.rewindH(firstDim-1); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp b/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp index 3d17fb62af02..6fd1214f6679 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp @@ -176,13 +176,13 @@ static void reverseSequence_(sd::LaunchContext * context, const NDArray* input, Nd4jLong numOfElemsToReverse = seqLengths->e(i); if(numOfElemsToReverse == 0 || numOfElemsToReverse == 1) { - outSubArrsSet.at(i)->assign(inSubArrsSet.at(i)); + outSubArrsSet.at(i).assign(inSubArrsSet.at(i)); } else { - auto inInnerSet = inSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); - auto outInnerSet = outSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); + auto inInnerSet = inSubArrsSet.at(i).allTensorsAlongDimension({seqDim}); + auto outInnerSet = outSubArrsSet.at(i).allTensorsAlongDimension({seqDim}); for(int j = 0; j < inInnerSet.size(); ++j) - helpers::reverseArray(context, inInnerSet.at(j)->getBuffer(), inInnerSet.at(j)->getShapeInfo(), outInnerSet.at(j)->getBuffer(), outInnerSet.at(j)->getShapeInfo(), numOfElemsToReverse); + helpers::reverseArray(context, inInnerSet.at(j).getBuffer(), inInnerSet.at(j).getShapeInfo(), outInnerSet.at(j).getBuffer(), outInnerSet.at(j).getShapeInfo(), numOfElemsToReverse); } } } @@ -201,12 +201,10 @@ void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, auto listOut = output->allTensorsAlongDimension(dimensions); auto listIn = input->allTensorsAlongDimension(dimensions); - NDArray *subArrIn, *subArrOut; - for(int i = 0; i < listIn.size(); ++i) { // listIn.size() = listOut.size() - subArrIn = listIn.at(i); - subArrOut = listOut.at(i); - BUILD_SINGLE_SELECTOR(input->dataType(), helpers::reverseArray, (context, subArrIn->getBuffer(), subArrIn->getShapeInfo(), subArrOut->getBuffer(), subArrOut->getShapeInfo()), LIBND4J_TYPES); + auto subArrIn = listIn.at(i); + auto subArrOut = listOut.at(i); + BUILD_SINGLE_SELECTOR(input->dataType(), helpers::reverseArray, (context, subArrIn.getBuffer(), subArrIn.getShapeInfo(), subArrOut.getBuffer(), subArrOut.getShapeInfo()), LIBND4J_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp b/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp index 278f3bcf5ae6..228435d2b64d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp @@ -105,7 +105,7 @@ namespace helpers { theShift -= fullLen * (theShift / fullLen - 1); } for (int k = 0; k < fullLen; k++) { - rollFunctorLinear(context, listOfTensors.at(k), listOfOutTensors.at(k), theShift, true); + rollFunctorLinear(context, &listOfTensors.at(k), &listOfOutTensors.at(k), theShift, true); } } else { @@ -133,7 +133,7 @@ namespace helpers { for (int e = theShift; e < sizeAt - theShift; ++e) { auto sourceM = listOfTensors.at(dim * sizeAt + e - theShift); auto targetM = listOfOutTensors.at(dim * sizeAt + e); - sourceM->swapUnsafe(*targetM); + sourceM.swapUnsafe(targetM); } for (int e = 0; e < theShift; ++e) { @@ -141,7 +141,7 @@ namespace helpers { auto sourceM = listOfTensors.at(sourceIndex); auto targetM = listOfOutTensors.at(dim * sizeAt + e); - sourceM->swapUnsafe(*targetM); + sourceM.swapUnsafe(targetM); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp index e57264e666b3..241e3e131a2a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp @@ -59,19 +59,19 @@ namespace helpers { auto maxT = listOfOutTensors.at(idx); //int pos = 0; - maxT->assign(listOfTensors.at(0)); + maxT.assign(listOfTensors.at(0)); for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { if (indices->e(i) == idx) { - for (Nd4jLong e = 0; e < maxT->lengthOf(); e++) { - maxT->t(e) = sd::math::nd4j_max(maxT->t(e), listOfTensors.at(i)->t(e)); + for (Nd4jLong e = 0; e < maxT.lengthOf(); e++) { + maxT.t(e) = sd::math::nd4j_max(maxT.t(e), listOfTensors.at(i).t(e)); } } else { idx = indices->e(i); maxT = listOfOutTensors.at(idx); - maxT->assign(listOfTensors.at(i)); + maxT.assign(listOfTensors.at(i)); } } @@ -110,19 +110,19 @@ namespace helpers { auto minT = listOfOutTensors.at(idx); int pos = 0; - minT->assign(listOfTensors.at(0)); + minT.assign(listOfTensors.at(0)); for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { if (indices->e(i) == idx) { - for (Nd4jLong e = 0; e < minT->lengthOf(); e++) { - minT->p(e, sd::math::nd4j_min(minT->e(e), listOfTensors.at(i)->e(e))); + for (Nd4jLong e = 0; e < minT.lengthOf(); e++) { + minT.p(e, sd::math::nd4j_min(minT.e(e), listOfTensors.at(i).e(e))); } } else { idx = indices->e(i); minT = listOfOutTensors.at(idx); - minT->assign(listOfTensors.at(i)); + minT.assign(listOfTensors.at(i)); } } } @@ -163,29 +163,29 @@ namespace helpers { std::vector> outputs(numOfClasses); auto meanT = listOfOutTensors.at(idx); int count = 1; - auto meanV = meanT->dup(); + auto meanV = meanT.dup(); meanV.assign(listOfTensors.at(0)); for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { if (indices->e(i) == idx) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e++) { - meanV.p(e, meanV.e(e) + listOfTensors.at(i)->e(e)); + meanV.p(e, meanV.e(e) + listOfTensors.at(i).e(e)); } }; - samediff::Threads::parallel_for(func, 0, meanT->lengthOf()); + samediff::Threads::parallel_for(func, 0, meanT.lengthOf()); count++; } else { //meanT->assign(meanV); - meanV.applyScalar(scalar::Divide, count, *meanT); + meanV.applyScalar(scalar::Divide, count, meanT); idx = indices->e(i); meanT = listOfOutTensors.at(idx); meanV.assign(listOfTensors.at(i)); count = 1; } - meanV.applyScalar(scalar::Divide, count, *meanT); + meanV.applyScalar(scalar::Divide, count, meanT); } } } @@ -224,15 +224,15 @@ namespace helpers { if (indices->e(i) == idx) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e++) { - sumT->p(e, sumT->e(e) + listOfTensors.at(i)->e(e)); + sumT.p(e, sumT.e(e) + listOfTensors.at(i).e(e)); } }; - samediff::Threads::parallel_for(func, 0, sumT->lengthOf()); + samediff::Threads::parallel_for(func, 0, sumT.lengthOf()); } else { idx = indices->e(i); sumT = listOfOutTensors.at(idx); - sumT->assign(listOfTensors.at(i)); + sumT.assign(listOfTensors.at(i)); } } } @@ -268,20 +268,20 @@ namespace helpers { int numOfClasses = output->sizeAt(0); // number of classes auto sumT = listOfOutTensors.at(idx); - sumT->assign(listOfTensors.at(0)); + sumT.assign(listOfTensors.at(0)); for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { if (indices->e(i) == idx) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e++) { - sumT->p(e, sumT->e(e) * listOfTensors.at(i)->e(e)); + sumT.p(e, sumT.e(e) * listOfTensors.at(i).e(e)); } }; - samediff::Threads::parallel_for(func, 0, sumT->lengthOf()); + samediff::Threads::parallel_for(func, 0, sumT.lengthOf()); } else { idx = indices->e(i); sumT = listOfOutTensors.at(idx); - sumT->assign(listOfTensors.at(i)); + sumT.assign(listOfTensors.at(i)); } } } @@ -379,13 +379,13 @@ namespace helpers { for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); + outputT.assign(listOfTensors.at(fi->second.at(0))); for (Nd4jLong idx = 1; idx < static_cast(fi->second.size()); ++idx) { auto maxT = listOfTensors.at(fi->second.at(idx)); - for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) { - T val = sd::math::nd4j_max(maxT->e(e), outputT->e(e)); + for (Nd4jLong e = 0; e < outputT.lengthOf(); ++e) { + T val = sd::math::nd4j_max(maxT.e(e), outputT.e(e)); - outputT->p(e, val); + outputT.p(e, val); } } } @@ -431,12 +431,12 @@ namespace helpers { for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); + outputT.assign(listOfTensors.at(fi->second.at(0))); for (size_t idx = 1; idx < fi->second.size(); ++idx) { auto minT = listOfTensors.at(fi->second.at(idx)); - for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) { - outputT->t(e) = sd::math::nd4j_min(minT->t(e), outputT->t(e)); + for (Nd4jLong e = 0; e < outputT.lengthOf(); ++e) { + outputT.t(e) = sd::math::nd4j_min(minT.t(e), outputT.t(e)); } } //outputT->assign(maxT); @@ -481,14 +481,14 @@ namespace helpers { // FIXME: parallelism here? for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); + outputT.assign(listOfTensors.at(fi->second.at(0))); Nd4jLong loopSize = fi->second.size(); for (Nd4jLong idx = 1; idx < loopSize; ++idx) { auto current = listOfTensors.at(fi->second.at(idx)); - *outputT += *current; + outputT += current; } - (*outputT) /= double(fi->second.size()); + (outputT) /= double(fi->second.size()); } } } @@ -519,13 +519,13 @@ namespace helpers { for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); + outputT.assign(listOfTensors.at(fi->second.at(0))); Nd4jLong loop_size = fi->second.size(); // FIXME: parallelism here? for (Nd4jLong idx = 1; idx < loop_size; ++idx) { auto current = listOfTensors.at(fi->second.at(idx)); - *(outputT) += *current; + outputT += current; } //outputT->assign(maxT); } @@ -559,11 +559,11 @@ namespace helpers { for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); + outputT.assign(listOfTensors.at(fi->second.at(0))); for (size_t idx = 1; idx < fi->second.size(); ++idx) { auto current = listOfTensors.at(fi->second.at(idx)); - *outputT *= *current; + outputT *= current; } } } @@ -598,13 +598,13 @@ namespace helpers { for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); + outputT.assign(listOfTensors.at(fi->second.at(0))); for (size_t idx = 1; idx < fi->second.size(); ++idx) { auto current = listOfTensors.at(fi->second.at(idx)); - *outputT += *current; + outputT += current; } //outputT->assign(maxT); - (*outputT) /= sd::math::nd4j_sqrt(fi->second.size()); + outputT /= sd::math::nd4j_sqrt(fi->second.size()); } } } @@ -651,9 +651,9 @@ namespace helpers { auto currentOut = listOfOutTensors.at(i); auto currentGradOut = listOfGradOuts.at(classNum); - for (Nd4jLong e = 0; e < current->lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) <= T(1.e-6)) - currentOut->p(e, currentGradOut->e(e)); + for (Nd4jLong e = 0; e < current.lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).e(e) - current.e(e)) <= T(1.e-6)) + currentOut.p(e, currentGradOut.e(e)); } } }; @@ -703,10 +703,10 @@ namespace helpers { auto currentOut = listOfOutTensors.at(i); auto currentGradOut = listOfGradOuts.at(classNum); - for (Nd4jLong e = 0; e < current->lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) < + for (Nd4jLong e = 0; e < current.lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).e(e) - current.e(e)) < 1.e-5) - currentOut->p(e, currentGradOut->e(e)); + currentOut.p(e, currentGradOut.e(e)); } } }; @@ -752,8 +752,8 @@ namespace helpers { auto currentOut = listOfOutTensors.at(i); auto currentGradOut = listOfGradOuts.at(classNum); - for (Nd4jLong e = 0; e < current->lengthOf(); e++) { - currentOut->p(e, currentGradOut->e(e) / classCount.at(classNum)); + for (Nd4jLong e = 0; e < current.lengthOf(); e++) { + currentOut.p(e, currentGradOut.e(e) / classCount.at(classNum)); } } //}; @@ -787,7 +787,7 @@ namespace helpers { auto currentOut = listOfOutTensors.at(i); auto currentGradOut = listOfGradOuts.at(classNum); - currentOut->assign(currentGradOut); + currentOut.assign(currentGradOut); } //}; @@ -824,7 +824,7 @@ namespace helpers { auto currentGradOut = listOfGradOuts.at(classNum); auto currentFFOut = listOfBPTensors.at(classNum); - currentOut->assign((*currentFFOut) * (*currentGradOut) / (*current)); + currentOut.assign(currentFFOut * currentGradOut / current); } //}; @@ -862,12 +862,12 @@ namespace helpers { for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { Nd4jLong classNum = indices->e(i); - NDArray* current = listOfTensors.at(i); - NDArray* currentOut = listOfOutTensors.at(i); - NDArray* currentGradOut = listOfGradOuts.at(classNum); - for (int e = 0; e < current->lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) < 1.e-5) - currentOut->p(e, currentGradOut->e(e)); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + for (int e = 0; e < current.lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).e(e) - current.e(e)) < 1.e-5) + currentOut.p(e, currentGradOut.e(e)); } } } @@ -911,9 +911,9 @@ namespace helpers { auto currentOut = listOfOutTensors.at(i); auto currentGradOut = listOfGradOuts.at(classNum); - for (Nd4jLong e = 0; e < current->lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->t(e) - current->t(e)) < 1.e-6) - currentOut->t(e) = currentGradOut->t(e); + for (Nd4jLong e = 0; e < current.lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).t(e) - current.t(e)) < 1.e-6) + currentOut.t(e) = currentGradOut.t(e); } } //}; @@ -957,10 +957,10 @@ namespace helpers { for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { Nd4jLong classNum = indices->e(i); - NDArray* current = listOfTensors.at(i); - NDArray* currentOut = listOfOutTensors.at(i); - NDArray* currentGradOut = listOfGradOuts.at(classNum); - currentOut->assign(*currentGradOut / double(classCount[classNum])); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + currentOut.assign(currentGradOut / double(classCount[classNum])); } } return ND4J_STATUS_OK; @@ -989,7 +989,7 @@ namespace helpers { auto currentOut = listOfOutTensors.at(i); auto currentGradOut = listOfGradOuts.at(classNum); - currentOut->assign(currentGradOut); + currentOut.assign(currentGradOut); } //}; @@ -1028,7 +1028,7 @@ namespace helpers { auto currentGradOut = listOfGradOuts.at(classNum); auto currentFFOut = listOfBPTensors.at(classNum); - currentOut->assign((*currentFFOut) * (*currentGradOut) / (*current)); + currentOut.assign(currentFFOut * currentGradOut / current); } //}; @@ -1075,8 +1075,8 @@ namespace helpers { auto currentOut = listOfOutTensors.at(i); auto currentGradOut = listOfGradOuts.at(classNum); - for (int e = 0; e < current->lengthOf(); e++) { - currentOut->p(e, currentGradOut->e(e) / sd::math::nd4j_sqrt(classCount[classNum])); + for (int e = 0; e < current.lengthOf(); e++) { + currentOut.p(e, currentGradOut.e(e) / sd::math::nd4j_sqrt(classCount[classNum])); } } //}; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp index 9a06975aa666..24e8ef317e86 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp @@ -43,7 +43,7 @@ namespace helpers { for (auto batch = start; batch < stop; batch++) { for (Nd4jLong r = 0; r < rows; r++) { for (Nd4jLong c = 0; c < r; c++) { - math::nd4j_swap(outputPart[batch]->t(r, c) , outputPart[batch]->t(c, r)); + math::nd4j_swap(outputPart[batch].t(r, c) , outputPart[batch].t(c, r)); } } } @@ -66,8 +66,8 @@ namespace helpers { auto permutationsPart = permutations.allTensorsAlongDimension({-1}); for (auto batch = 0; batch < permutationsPart.size(); ++batch) { - for (Nd4jLong row = 0; row < PPart[batch]->rows(); ++row) { - PPart[batch]->t(row, permutationsPart[batch]->t(row)) = T(1.f); + for (Nd4jLong row = 0; row < PPart[batch].rows(); ++row) { + PPart[batch].t(row, permutationsPart[batch].t(row)) = T(1.f); } } @@ -77,8 +77,8 @@ namespace helpers { MmulHelper::matmul(&P, rightInput, &rightPermuted, 0, 0); ResultSet leftLowerPart = leftLower.allTensorsAlongDimension({-2, -1}); for (auto i = 0; i < leftLowerPart.size(); i++) { - for (Nd4jLong r = 0; r < leftLowerPart[i]->rows(); r++) - leftLowerPart[i]->t(r,r) = (T)1.f; + for (Nd4jLong r = 0; r < leftLowerPart[i].rows(); r++) + leftLowerPart[i].t(r,r) = (T)1.f; } // stage 2: triangularSolveFunctor for Lower with given b helpers::triangularSolveFunctor(context, &leftLower, &rightPermuted, true, false, &rightOutput); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp index 7ae4221d4abc..23ce7888071c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp @@ -957,12 +957,12 @@ static void svd_(const NDArray* x, const std::vector& outArrs, const b // NDArray matrix(x->ordering(), {listX.at(i)->sizeAt(0), listX.at(i)->sizeAt(1)}, block.getContext()); // matrix.assign(listX.at(i)); - helpers::SVD svdObj(*(listX.at(i)), switchNum, calcUV, calcUV, fullUV); - listS.at(i)->assign(svdObj._s); + helpers::SVD svdObj(listX.at(i), switchNum, calcUV, calcUV, fullUV); + listS.at(i).assign(svdObj._s); if(calcUV) { - listU->at(i)->assign(svdObj._u); - listV->at(i)->assign(svdObj._v); + listU->at(i).assign(svdObj._u); + listV->at(i).assign(svdObj._v); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/trace.cpp b/libnd4j/include/ops/declarable/helpers/cpu/trace.cpp index d544fa24eea6..c80829aabf34 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/trace.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/trace.cpp @@ -34,7 +34,7 @@ static void trace_(const NDArray& input, NDArray& output) { auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i++) - output.p(i, setOfSubArrs.at(i)->getTrace()); + output.p(i, setOfSubArrs.at(i).getTrace()); }; samediff::Threads::parallel_for(func, 0, setOfSubArrs.size()); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp index bcf4063920a7..4bb09378930d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp @@ -92,9 +92,9 @@ namespace helpers { auto batchLoop = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i++) { if (lower) { - lowerTriangularSolve(context, leftPart[i], rightPart[i], adjoint, outputPart[i]); + lowerTriangularSolve(context, &leftPart[i], &rightPart[i], adjoint, &outputPart[i]); } else { - upperTriangularSolve(context, leftPart[i], rightPart[i], adjoint, outputPart[i]); + upperTriangularSolve(context, &leftPart[i], &rightPart[i], adjoint, &outputPart[i]); } } }; @@ -116,13 +116,13 @@ namespace helpers { if (!lower) { for (Nd4jLong r = 0; r < rows; r++) { for (Nd4jLong c = 0; c <= r; c++) { - outputPart[batch]->t(r, c) = inputPart[batch]->t(c, r); + outputPart[batch].t(r, c) = inputPart[batch].t(c, r); } } } else { for (Nd4jLong r = 0; r < rows; r++) { for (Nd4jLong c = r; c < cols; c++) { - outputPart[batch]->t(r, c) = inputPart[batch]->t(c, r); + outputPart[batch].t(r, c) = inputPart[batch].t(c, r); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cross.h b/libnd4j/include/ops/declarable/helpers/cross.h index bd1e2a61dbe9..2a3ee22c1626 100644 --- a/libnd4j/include/ops/declarable/helpers/cross.h +++ b/libnd4j/include/ops/declarable/helpers/cross.h @@ -73,7 +73,7 @@ void FORCEINLINE cross(sd::LaunchContext * context, NDArray *a, NDArray *b, NDAr auto b_ = tadsB.at(e); auto o_ = tadsO.at(e); - helpers::cross(context, a_, b_, o_); + helpers::cross(context, &a_, &b_, &o_); } }; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index 2ca731912108..55aadc55640a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -630,7 +630,7 @@ namespace helpers { // DataType dtype = input->dataType(); // if (dtype != DataType::DOUBLE) // dtype = DataType::FLOAT32; - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT(), context); //, block.workspace()); auto det = NDArrayFactory::create(1, context); auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input}); @@ -676,7 +676,7 @@ namespace helpers { if (dtype != DataType::DOUBLE) dtype = DataType::FLOAT32; - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace()); + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.workspace()); auto det = NDArrayFactory::create(1, context); auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input}); diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp index 435a3e32daae..2aac4aa4c68f 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp @@ -231,15 +231,15 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(!h) { // seqLen and h are absent - lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step + lstmLayerCell(&xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step for (Nd4jLong t = 1; t < sL; ++t) - lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps + lstmLayerCell(&xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps } else { // seqLen is absent and h is present - lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, hSet->at(0), ct); // first time step + lstmLayerCell(&xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, &hSet->at(0), ct); // first time step for (Nd4jLong t = 1; t < sL; ++t) - lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t - 1), ct, Wp, params, hSet->at(t), ct); // rest time steps + lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t - 1), ct, Wp, params, &hSet->at(t), ct); // rest time steps if(hL) hL->assign(hSet->at(sL - 1)); // assign last output to hL if it is not nullptr @@ -255,18 +255,18 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(limit == 0) { if(cL) - ctSet->at(e)->nullify(); + ctSet->at(e).nullify(); if(hL) - htSet->at(e)->nullify(); + htSet->at(e).nullify(); continue; } auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e); - lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), Wp, params, &htSet->at(e), &ctSet->at(e)); // first time step for (int t = 1; t < limit; ++t) { ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &htSet->at(e), &ctSet->at(e), Wp, params, &htSet->at(e), &ctSet->at(e)); // rest time steps } } } @@ -281,24 +281,24 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range if(cL) - ctSet->at(e)->nullify(); + ctSet->at(e).nullify(); if(hL) - htSet->at(e)->nullify(); + htSet->at(e).nullify(); continue; } auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e); - lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step + lstmLayerCell(&xSet->at(indPrev), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), Wp, params, &hSet->at(indPrev), &ctSet->at(e)); // first time step for (int t = 1; t < limit; ++t) { auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps + lstmLayerCell(&xSet->at(indCurr), Wx, Wr, b, &hSet->at(indPrev), &ctSet->at(e), Wp, params, &hSet->at(indCurr), &ctSet->at(e)); // rest time steps indPrev = indCurr; } if(hL) - htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if hL is not nullptr + htSet->at(e).assign(hSet->at(indPrev)); // assign last output to hL if hL is not nullptr tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) } @@ -311,15 +311,15 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(!h) { // seqLen and h are absent - lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step + lstmLayerCell(&xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step for (Nd4jLong t = sL - 2; t >= 0; --t) - lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps + lstmLayerCell(&xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps } else { // seqLen is absent and h is present - lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, hSet->at(sL - 1), ct); // first time step + lstmLayerCell(&xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, &hSet->at(sL - 1), ct); // first time step for (Nd4jLong t = sL - 2; t >= 0; --t) - lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t + 1), ct, Wp, params, hSet->at(t), ct); // rest time steps + lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t + 1), ct, Wp, params, &hSet->at(t), ct); // rest time steps if(hL) hL->assign(hSet->at(0)); // assign last output to hL if it is not nullptr @@ -335,18 +335,18 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(limit == 0) { if(cL) - ctSet->at(e)->nullify(); + ctSet->at(e).nullify(); if(hL) - htSet->at(e)->nullify(); + htSet->at(e).nullify(); continue; } auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e); - lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), Wp, params, &htSet->at(e), &ctSet->at(e)); // first time step for (Nd4jLong t = sL - 2; t >= sL - limit; --t) { ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &htSet->at(e), &ctSet->at(e), Wp, params, &htSet->at(e), &ctSet->at(e)); // rest time steps } } } @@ -361,24 +361,24 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range if(cL) - ctSet->at(e)->nullify(); + ctSet->at(e).nullify(); if(hL) - htSet->at(e)->nullify(); + htSet->at(e).nullify(); continue; } auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e); - lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step + lstmLayerCell(&xSet->at(indPrev), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), Wp, params, &hSet->at(indPrev), &ctSet->at(e)); // first time step for (Nd4jLong t = sL - 2; t >= sL - limit; --t) { auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps + lstmLayerCell(&xSet->at(indCurr), Wx, Wr, b, &hSet->at(indPrev), &ctSet->at(e), Wp, params, &hSet->at(indCurr), &ctSet->at(e)); // rest time steps indPrev = indCurr; } if(hL) - htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr + htSet->at(e).assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL) } @@ -394,18 +394,18 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(limit == 0) { if(cL) - ctSet->at(e)->nullify(); + ctSet->at(e).nullify(); if(hL) - htSet->at(e)->nullify(); + htSet->at(e).nullify(); continue; } auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e); - lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), Wp, params, &htSet->at(e), &ctSet->at(e)); // first time step for (int t = limit - 2; t >= 0; --t) { ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &htSet->at(e), &ctSet->at(e), Wp, params, &htSet->at(e), &ctSet->at(e)); // rest time steps } } } @@ -420,24 +420,24 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range if(cL) - ctSet->at(e)->nullify(); + ctSet->at(e).nullify(); if(hL) - htSet->at(e)->nullify(); + htSet->at(e).nullify(); continue; } auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e); - lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step + lstmLayerCell(&xSet->at(indPrev), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), Wp, params, &hSet->at(indPrev), &ctSet->at(e)); // first time step for (int t = limit - 2; t >= 0; --t) { auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps + lstmLayerCell(&xSet->at(indCurr), Wx, Wr, b, &hSet->at(indPrev), &ctSet->at(e), Wp, params, &hSet->at(indCurr), &ctSet->at(e)); // rest time steps indPrev = indCurr; } if(hL) - htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr + htSet->at(e).assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) } diff --git a/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp b/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp index 5989f5246f67..ef04b9a4e240 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp @@ -36,13 +36,13 @@ namespace helpers { throw std::runtime_error("multiUnique: this op support INT32 data type only."); reshaped[pos] = array->reshape(array->ordering(), {-1}); - cContext.setInputArray(pos, &reshaped[pos]); + cContext.setInputArray(pos, reshaped[pos]); length += array->lengthOf(); pos++; } NDArray arrayFull('c', {length}, sd::DataType::INT32, inputList[0]->getContext()); - cContext.setOutputArray(0, &arrayFull); + cContext.setOutputArray(0, arrayFull); cContext.setIArguments(&axis, 1); sd::ops::concat opConcat; @@ -57,7 +57,7 @@ namespace helpers { auto uniqueVals = uResult.at(0); - bool res = uniqueVals->lengthOf() == arrayFull.lengthOf(); + bool res = uniqueVals.lengthOf() == arrayFull.lengthOf(); return res; } diff --git a/libnd4j/include/ops/declarable/impl/BooleanOp.cpp b/libnd4j/include/ops/declarable/impl/BooleanOp.cpp index 1023ae127335..4c2a77308fd1 100644 --- a/libnd4j/include/ops/declarable/impl/BooleanOp.cpp +++ b/libnd4j/include/ops/declarable/impl/BooleanOp.cpp @@ -68,13 +68,12 @@ namespace sd { std::pair pair(ctx.nodeId(), e); if (!variableSpace->hasVariable(pair)) - variableSpace->putVariable(pair, new Variable()); + variableSpace->putVariable(pair, std::make_shared()); auto var = ctx.variable(pair); if (!var->hasNDArray()) { - var->setNDArray(NDArrayFactory::create_(false, ctx.launchContext())); - var->markRemovable(true); + var->setNDArray(std::make_shared(NDArrayFactory::create(false, ctx.launchContext()))); } } @@ -121,7 +120,7 @@ namespace sd { int cnt = -1; std::vector in; for (auto v: args) { - auto var = new Variable(v); + auto var = std::make_shared(v); var->markRemovable(false); in.push_back(cnt); variableSpace.putVariable(cnt--, var); diff --git a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp index a83b970fbede..15d76c75e461 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp @@ -57,11 +57,11 @@ namespace sd { } void DeclarableListOp::setupResult(NDArray* array, Context& block) { - block.pushNDArrayToVariableSpace(block.getNodeId(), 0, array); + block.pushNDArrayToVariableSpace(block.getNodeId(), 0, *array); } void DeclarableListOp::setupResultList(NDArrayList* arrayList, Context& block) { - block.pushNDArrayListToVariableSpace(block.getNodeId(), 0, arrayList); + block.pushNDArrayListToVariableSpace(block.getNodeId(), 0, *arrayList); } ResultSet DeclarableListOp::execute(NDArrayList* list, std::initializer_list inputs, std::initializer_list tArgs, std::initializer_list iArgs) { @@ -105,19 +105,22 @@ namespace sd { if (list->id().first == 0) list->id().first = -1; - auto listVar = new Variable(nullptr, nullptr, -119, 0); - listVar->setNDArrayList(list); - varSpace.putVariable(-1, listVar); - in.push_back(-1); - cnt--; + auto listVar = std::make_shared(); + listVar->setId(-119, 0); + //listVar->setNDArrayList(list); + //varSpace.putVariable(-1, listVar); + //in.push_back(-1); + //cnt--; + throw std::runtime_error("DeclarableListOp::execute - Not implemented yet"); } for (auto v: inputs) { - auto var = new Variable(v); - var->markRemovable(false); - in.push_back(cnt); - varSpace.putVariable(cnt--, var); + //auto var = new Variable(v); + //var->markRemovable(false); + //in.push_back(cnt); + //varSpace.putVariable(cnt--, var); + throw std::runtime_error("DeclarableListOp::execute - Not implemented yet"); } Context block(1, &varSpace, false); @@ -143,10 +146,10 @@ namespace sd { auto arr = var->getNDArray(); if (arr->isAttached()) { auto d = arr->detach(); - res.push_back(d); + res.push_back(*d); } else { var->markRemovable(false); - res.push_back(arr); + res.push_back(*arr); } } } else diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 3a061661e5fe..d049293c172d 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -111,32 +111,33 @@ namespace sd { if (ctx.isFastPath()) { if (ctx.fastpath_out().size() <= inputId) { if (ctx.isInplace()) { - z = ctx.fastpath_in()[inputId]; + z = ctx.fastpath_in()[inputId].get(); } else throw std::runtime_error("fastpath_out: unresolved output array"); } else { - z = ctx.fastpath_out()[inputId]; + z = ctx.fastpath_out()[inputId].get(); } } else { std::pair pair(ctx.nodeId(), inputId); if (ctx.isInplace()) { - z = ctx.variable(inputId)->getNDArray(); + auto vz = ctx.variable(inputId)->getNDArray(); + z = vz.get(); // hypothetically it's possible to have no variable. chances are low, but who knows. let's just create it for now if (!ctx.getVariableSpace()->hasVariable(pair)) { - auto var = new Variable(); + auto var = std::make_shared(); ctx.getVariableSpace()->putVariable(pair, var); } // now we're saving input array as output array auto var = ctx.getVariableSpace()->getVariable(pair); var->markRemovable(false); - var->setNDArray(z); + var->setNDArray(vz); } else if (!ctx.isInplace()) { auto var = ctx.variable(pair); if (var->getNDArray() != nullptr && var->getNDArray()->nonNull()) { - z = var->getNDArray(); + z = var->getNDArray().get(); } else { nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId()); } @@ -149,8 +150,8 @@ namespace sd { return z; } - int sd::ops::DeclarableOp::prepareOutputs(Context &ctx) { - auto workspace = ctx.getWorkspace(); + int DeclarableOp::prepareOutputs(Context &ctx) { + auto workspace = ctx.workspace(); GraphProfile *prof = nullptr; NodeProfile *node = nullptr; std::chrono::time_point inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd; @@ -159,10 +160,13 @@ namespace sd { auto fp = ctx.isFastPath(); if (Environment::getInstance()->isProfiling()) { + /* if (ctx.getVariableSpace() != nullptr && ctx.getVariableSpace()->flowPath() != nullptr) { prof = ctx.getVariableSpace()->flowPath()->profile(); node = prof->nodeById(ctx.nodeId()); } + */ + throw std::runtime_error("DeclarableOp::prepareOutputs - Not implemented yet"); } if (ctx.isInplace()) { @@ -173,7 +177,7 @@ namespace sd { for (auto p: ctx.inputs()) { auto var = ctx.variable(p); if (var->variableType() == VariableType::NDARRAY) { - NDArray *array = var->getNDArray(); + auto array = var->getNDArray().get(); node->addInputShape(array->shapeInfo()); node->addOutputShape(array->shapeInfo()); @@ -190,7 +194,7 @@ namespace sd { for (auto p: ctx.inputs()) { auto var = ctx.variable(p); if (var->variableType() == VariableType::NDARRAY) { - NDArray *array = var->getNDArray(); + auto array = var->getNDArray(); ctx.setInputArray(cnt, array); ctx.setOutputArray(cnt, array); @@ -241,8 +245,8 @@ namespace sd { for (auto p: ctx.inputs()) { auto var = ctx.variable(p); if (var->variableType() == VariableType::NDARRAY) { - NDArray *array = var->getNDArray(); - if (array == nullptr) + auto array = var->getNDArray(); + if (array.get() == nullptr) throw unresolved_input_exception::build("Variable wasn't resolved prior shape calculation", p); inSha.push_back(array->getShapeInfo()); @@ -303,7 +307,7 @@ namespace sd { shape::printShapeInfoLinear("Going to create variable with shape", out); // we're creating non-initialized array here - auto outArr = new NDArray(out, true, ctx.launchContext(), false); + NDArray outArr(out, true, ctx.launchContext(), false); ctx.pushNDArrayToVariableSpace(pair, outArr); @@ -342,8 +346,8 @@ namespace sd { auto idx = cnt++; if (fout.size() <= idx) { // array doesnt exist - auto outArr = new NDArray(out, true, ctx.launchContext()); - ctx.setOutputArray(idx, outArr, true); + auto outArr = std::make_shared(out, true, ctx.launchContext()); + ctx.setOutputArray(idx, outArr); } else { auto array = fout[idx]; // checking out shape equality @@ -382,13 +386,13 @@ namespace sd { } void sd::ops::DeclarableOp::storeResult(sd::graph::Context &ctx, int outputNumber, NDArray& array) { - ctx.pushNDArrayToVariableSpace(ctx.nodeId(), outputNumber, &array, !ctx.isInplace()); + ctx.pushNDArrayToVariableSpace(ctx.nodeId(), outputNumber, array); } bool sd::ops::DeclarableOp::allocateResult(Context& block, Nd4jLong* shape) { auto var = block.variable(block.getNodeId(), 0); - auto workspace = block.getWorkspace(); + auto workspace = block.workspace(); Nd4jLong len = shape::length(shape); Nd4jLong* __shape; @@ -400,13 +404,12 @@ namespace sd { if (var->getNDArray() == nullptr) { std::shared_ptr buffer = std::make_shared(len * sizeof(int8_t), ArrayOptions::dataType(__shape), workspace); - var->setNDArray(new NDArray(buffer, ShapeDescriptor(__shape), block.launchContext())); + var->setNDArray(std::make_shared(buffer, ShapeDescriptor(__shape), block.launchContext())); } else if(var->getNDArray()->lengthOf() != len) { // if length not match - lets reallocate array - delete var->getNDArray(); std::shared_ptr buffer = std::make_shared(len * sizeof(int8_t), ArrayOptions::dataType(__shape), workspace); - var->setNDArray(new NDArray(buffer, ShapeDescriptor(__shape), block.launchContext())); + var->setNDArray(std::make_shared(buffer, ShapeDescriptor(__shape), block.launchContext())); } return true; @@ -415,16 +418,15 @@ namespace sd { bool sd::ops::DeclarableOp::allocateResult(Context& block, std::initializer_list& shape, char order) { auto var = block.variable(block.getNodeId(), 0); - auto workspace = block.getWorkspace(); + auto workspace = block.workspace(); Nd4jLong len = shape::length(shape); // if that's first run - we probably have nothing here if (var->getNDArray() == nullptr) { - var->setNDArray(new NDArray(order, shape, DataType::FLOAT32, block.launchContext())); + var->setNDArray(std::make_shared(order, shape, DataType::FLOAT32, block.launchContext())); } else if(var->getNDArray()->lengthOf() != len) { // if length not match - lets reallocate array - delete var->getNDArray(); - var->setNDArray(new NDArray(order, shape, DataType::FLOAT32, block.launchContext())); + var->setNDArray(std::make_shared(order, shape, DataType::FLOAT32, block.launchContext())); } return true; @@ -592,7 +594,7 @@ namespace sd { return ND4J_STATUS_OK; } - Nd4jStatus sd::ops::DeclarableOp::execute(Context* block) { + Nd4jStatus DeclarableOp::execute(Context* block) { nd4j_debug("Executing op: [%s]\n", this->getOpName().c_str()); std::chrono::time_point timeEnter, timeStart, timeEnd; @@ -648,6 +650,7 @@ namespace sd { } if (Environment::getInstance()->isProfiling() && block->getVariableSpace() != nullptr) { + /* auto fp = block->getVariableSpace()->flowPath(); if (fp != nullptr) { auto p = fp->profile(); @@ -659,6 +662,8 @@ namespace sd { p->nodeById(block->nodeId())->setTotalSize(memoryUsed); } } + */ + throw std::runtime_error("DeclarableOp::execute - Not implemented yet"); } @@ -685,7 +690,7 @@ namespace sd { auto array = block->isFastPath() ? block->isInplace() ? block->fastpath_in()[e] : block->fastpath_out()[e] : vs->getVariable(block->nodeId(), e)->getNDArray(); - auto shape = ShapeUtils::shapeAsString(array); + auto shape = ShapeUtils::shapeAsString(array.get()); auto first = array->isEmpty() ? std::string("Empty NDArray") : array->asString(32); auto type = DataTypeUtils::asString(array->dataType()); @@ -770,7 +775,7 @@ namespace sd { for (auto p: block.inputs()) { auto v = block.variable(p); - NDArray *aV = v->getNDArray(); + NDArray *aV = v->getNDArray().get(); if (aV == nullptr) return ND4J_STATUS_BAD_INPUT; @@ -817,11 +822,12 @@ namespace sd { } if (v->variableType() == VariableType::NDARRAY) { - NDArray *aV = v->getNDArray(); - // if array is empty intentionally - we're ok with that - if (v->hasNDArray() && v->isEmpty()) + if (v->hasNDArray() && v->isEmpty()) { continue; + } + + NDArray *aV = v->getNDArray().get(); if (aV == nullptr || !aV->nonNull()) { if (!this->getOpName().empty()) { @@ -843,10 +849,10 @@ namespace sd { if (block.width() == 0) return ND4J_STATUS_OK; - NDArray *a0 = block.variable(0)->getNDArray(); + NDArray *a0 = block.variable(0)->getNDArray().get(); for (auto p: block.inputs()) { auto v = block.variable(p); - NDArray *aV = v->getNDArray(); + NDArray *aV = v->getNDArray().get(); if (a0->ordering() != aV->ordering()) return ND4J_STATUS_BAD_ORDER; } @@ -857,7 +863,7 @@ namespace sd { Nd4jStatus sd::ops::DeclarableOp::execute(sd::graph::RandomGenerator& rng, const std::vector& inputs, const std::vector& outputs, const std::vector& tArgs, const std::vector& iArgs, const std::vector& bArgs, const std::vector& dArgs, bool isInplace, sd::DataType type) { VariableSpace variableSpace; FlowPath fp; - variableSpace.setFlowPath(&fp); + //variableSpace.setFlowPath(&fp); int cnt = -1; std::vector in; @@ -865,7 +871,7 @@ namespace sd { if (v == nullptr) continue; - auto var = new Variable(v); + auto var = std::make_shared(v); var->markRemovable(false); in.push_back(cnt); variableSpace.putVariable(cnt--, var); @@ -873,7 +879,7 @@ namespace sd { int et = 0; for (auto v: outputs) { - auto var = new Variable(v); + auto var = std::make_shared(v); var->markRemovable(false); std::pair pair(1, et++); variableSpace.putVariable(pair, var); @@ -951,11 +957,11 @@ namespace sd { Context ctx(1); for (int e = 0; e < inputs.size(); e++) { - ctx.setInputArray(e, inputs[e]); + ctx.setInputArray(e, *inputs[e]); } for (int e = 0; e < outputs.size(); e++) { - ctx.setOutputArray(e, outputs[e]); + ctx.setOutputArray(e, *outputs[e]); } @@ -1016,7 +1022,7 @@ namespace sd { VariableSpace variableSpace; //ResultSet arrayList; FlowPath fp; - variableSpace.setFlowPath(&fp); + //variableSpace.setFlowPath(&fp); int cnt = -1; std::vector in; @@ -1024,7 +1030,7 @@ namespace sd { if (v == nullptr) continue; - auto var = new Variable(v); + auto var = std::make_shared(v); var->markRemovable(false); in.push_back(cnt); variableSpace.putVariable(cnt--, var); @@ -1065,16 +1071,16 @@ namespace sd { if (!arr->isAttached()) { var->markRemovable(false); arr->setContext(sd::LaunchContext::defaultContext()); - arrayList.push_back(arr); + arrayList.push_back(*arr.get()); } else { - arrayList.push_back(arr->detach()); + arrayList.push_back(*arr->detach()); } } else break; } } else { for (auto v:inputs) { - arrayList.push_back(v); + arrayList.push_back(*v); } } @@ -1090,7 +1096,7 @@ namespace sd { if (block.width() == 0) return ND4J_STATUS_OK; - NDArray *a0 = block.array(0); + NDArray *a0 = block.array(0).get(); for (int e = 0; e < block.width(); e++) { auto aV = block.array(e); if (!shape::equalsSoft(a0->getShapeInfo(), aV->getShapeInfo())) @@ -1128,7 +1134,7 @@ namespace sd { // default implementation suits transform, so just returns the same shape int* newshape; - ALLOCATE(newshape, block.getWorkspace(), shape::shapeInfoLength(inputShape), int); + ALLOCATE(newshape, block.workspace(), shape::shapeInfoLength(inputShape), int); memcpy(newshape, inputShape, shape::shapeInfoByteLength(inputShape)); return newshape; diff --git a/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp index e483b0fb1f92..4e1b6db0b2c2 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp @@ -56,7 +56,7 @@ namespace sd { return SHAPELIST(newShape); } - auto newShape = ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), false, false, block.getWorkspace()); + auto newShape = ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), false, false, block.workspace()); return SHAPELIST(newShape); } } diff --git a/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp index f4954324f64f..dc9630ff1bff 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp @@ -99,7 +99,7 @@ namespace sd { // FIXME: remove memcpy Nd4jLong *newShape; - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(inShape), Nd4jLong); + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(inShape), Nd4jLong); memcpy(newShape, inShape, shape::shapeInfoByteLength(inShape)); return SHAPELIST(CONSTANT(newShape)); diff --git a/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp index 8db309e3c354..b7e73961df50 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp @@ -45,7 +45,7 @@ namespace sd { Nd4jLong *newShape; if (block.getAxis().size() == 0 && block.width() == 1) { // in this case we just return scalar - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); newShape[0] = 2; newShape[1] = 1; newShape[2] = 1; @@ -55,11 +55,11 @@ namespace sd { newShape[7] = 99; auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newShape, DataType::INT64)); - RELEASE(newShape, block.getWorkspace()); + RELEASE(newShape, block.workspace()); return SHAPELIST(result); } else if (block.getAxis().size()){ // in this case we're building proper shape for reduction - auto array = INPUT_VARIABLE(0); //new NDArray(nullptr, inShape, block.getWorkspace()); + auto array = INPUT_VARIABLE(0); //new NDArray(nullptr, inShape, block.workspace()); newShape = ShapeUtils::evalReduceShapeInfo('c', block.getAxis(), *array, DataType::INT64, false, true, block.workspace()); return SHAPELIST(newShape); @@ -79,7 +79,7 @@ namespace sd { } if (allAxes){ // in this case we just return scalar - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); newShape[0] = 2; newShape[1] = 1; newShape[2] = 1; @@ -89,11 +89,11 @@ namespace sd { newShape[7] = 99; auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newShape, DataType::INT64)); - RELEASE(newShape, block.getWorkspace()); + RELEASE(newShape, block.workspace()); return SHAPELIST(result); } else { // in this case we're building proper shape for reduction - auto array = INPUT_VARIABLE(0); //new NDArray(nullptr, inShape, block.getWorkspace()); + auto array = INPUT_VARIABLE(0); //new NDArray(nullptr, inShape, block.workspace()); newShape = ShapeUtils::evalReduceShapeInfo('c', axis, *array, DataType::INT64, false, true, block.workspace()); return SHAPELIST(newShape); } diff --git a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp index f42edbcc727e..e83e3755dcf0 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp @@ -79,7 +79,7 @@ namespace sd { REQUIRE_TRUE(false, 0, "Uniform requires either TArgs or 3 arguments to be present"); } - auto z = OUTPUT_VARIABLE(0); //NDArrayFactory::create_('c', shape, block.getWorkspace()); + auto z = OUTPUT_VARIABLE(0); //NDArrayFactory::create_('c', shape, block.workspace()); RandomLauncher::fillUniform(block.launchContext(), block.randomGenerator(), z, from, to); @@ -138,7 +138,7 @@ namespace sd { for (int e = 0; e < input->lengthOf(); e++) shape[e] = input->e(e); - auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_('c', shape, block.getWorkspace()); + auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_('c', shape, block.workspace()); RandomLauncher::fillGaussian(block.launchContext(), block.randomGenerator(), z, mean, stdev); @@ -166,7 +166,7 @@ namespace sd { for (int e = 0; e < input->lengthOf(); e++) shape[e] = input->e(e); - auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_('c', shape, block.getWorkspace()); + auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_('c', shape, block.workspace()); RandomLauncher::fillBernoulli(block.launchContext(), block.randomGenerator(), z, prob); @@ -199,7 +199,7 @@ namespace sd { for (int e = 0; e < input->lengthOf(); e++) shape[e] = input->e(e); - auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_('c', shape, block.getWorkspace()); + auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_('c', shape, block.workspace()); RandomLauncher::fillBinomial(block.launchContext(), block.randomGenerator(), z, trials, prob); @@ -231,7 +231,7 @@ namespace sd { for (int e = 0; e < input->lengthOf(); e++) shape[e] = input->e(e); - auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_('c', shape, block.getWorkspace()); + auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_('c', shape, block.workspace()); RandomLauncher::fillLogNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev); @@ -263,7 +263,7 @@ namespace sd { for (int e = 0; e < input->lengthOf(); e++) shape[e] = input->e(e); - auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_('c', shape, block.getWorkspace()); + auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_('c', shape, block.workspace()); RandomLauncher::fillTruncatedNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev); } @@ -375,7 +375,7 @@ namespace sd { if (v == nullptr) continue; - auto var = new Variable(v); + auto var = std::make_shared(v); var->markRemovable(false); in.push_back(cnt); variableSpace.putVariable(cnt--, var); @@ -407,9 +407,9 @@ namespace sd { auto arr = var->getNDArray(); if (!arr->isAttached()) { var->markRemovable(false); - arrayList.push_back(arr); + arrayList.push_back(*arr.get()); } else { - arrayList.push_back(arr->detach()); + arrayList.push_back(*arr->detach()); } } else break; diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp index 40bf279f6144..b91aced15064 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp @@ -100,7 +100,7 @@ namespace sd { if (shape::equalsSoft(xShape, yShape) && (block.numI() == 0 || (block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()))) { // reduce3 to scalar case - ALLOCATE(zShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); + ALLOCATE(zShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); zShape[0] = 2; zShape[1] = 1; zShape[2] = 1; diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceOp.cpp index 46be149c6222..30577b606d3b 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceOp.cpp @@ -153,7 +153,7 @@ namespace sd { if (block.getIArguments()->size() == 0 || (block.getIArguments()->size() == 1 && INT_ARG(0) == MAX_INT) || allAxes) { if (block.getIArguments()->size() > 0 && block.getIArguments()->at(0) == 1) { // in this case we just return legacy scalar - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); newShape[0] = 2; newShape[1] = 1; newShape[2] = 1; @@ -164,7 +164,7 @@ namespace sd { newShape[7] = 99; //ArrayOptions::setDataType(newShape, block.dataType() == DataType::BOOL?block.dataType():ArrayOptions::dataType(inShape)); } else { - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(0), Nd4jLong); + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(0), Nd4jLong); newShape[0] = 0; newShape[1] = 0; newShape[2] = 1; @@ -173,7 +173,7 @@ namespace sd { } } else { // in this case we're building proper shape for reduction - auto array = new NDArray(nullptr, inShape, block.getWorkspace()); + auto array = new NDArray(nullptr, inShape, block.workspace()); newShape = ShapeUtils::evalReduceShapeInfo(shape::order(inShape), *block.getIArguments(), *array, false, false, block.workspace()); diff --git a/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp index 11a8fe1daa2f..587a70832418 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp @@ -95,7 +95,7 @@ namespace sd { Nd4jLong *newShape; if (block.numI() == 0 || (block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max())) { // in this case we just return scalar - ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); newShape[0] = 2; newShape[1] = 1; newShape[2] = 1; diff --git a/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp b/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp index dfc18d33b550..6cd7873a1539 100644 --- a/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp +++ b/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp @@ -45,32 +45,33 @@ namespace sd { if (ctx.isFastPath()) { if (ctx.fastpath_out().size() <= inputId) { if (ctx.isInplace()) { - z = ctx.fastpath_in()[inputId]; + z = ctx.fastpath_in()[inputId].get(); } else throw std::runtime_error("fastpath_out: unresolved output array"); } else { - z = ctx.fastpath_out()[inputId]; + z = ctx.fastpath_out()[inputId].get(); } } else { std::pair pair(ctx.nodeId(), inputId); if (ctx.isInplace()) { - z = ctx.variable(inputId)->getNDArray(); + auto vz = ctx.variable(inputId)->getNDArray(); + z = vz.get(); // hypothetically it's possible to have no variable. chances are low, but who knows. let's just create it for now if (!ctx.getVariableSpace()->hasVariable(pair)) { - auto var = new graph::Variable(); + auto var = std::make_shared(); ctx.getVariableSpace()->putVariable(pair, var); } // now we're saving input array as output array auto var = ctx.getVariableSpace()->getVariable(pair); var->markRemovable(false); - var->setNDArray(z); + var->setNDArray(vz); } else if (!ctx.isInplace()) { auto var = ctx.variable(pair); if (var->getNDArray() != nullptr && var->getNDArray()->nonNull()) { - z = var->getNDArray(); + z = var->getNDArray().get(); } else { nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId()); } diff --git a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp index 0eabe959ad11..f1bccad73017 100644 --- a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp +++ b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp @@ -103,21 +103,21 @@ namespace sd { int axis; if (n == 0) { //nchw - auto input = NDArrayFactory::create_('c', {16, c, hw, hw}); - auto output = NDArrayFactory::create_('c', {16, c, hw, hw}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); + auto input = NDArrayFactory::create('c', {16, c, hw, hw}); + auto output = NDArrayFactory::create('c', {16, c, hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); axis = 1; } else { - auto input = NDArrayFactory::create_('c', {32, hw, hw, c}); - auto output = NDArrayFactory::create_('c', {32, hw, hw, c}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); + auto input = NDArrayFactory::create('c', {32, hw, hw, c}); + auto output = NDArrayFactory::create('c', {32, hw, hw, c}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); axis = 3; } - auto bias = NDArrayFactory::create_('c', {c}); - ctx->setInputArray(1, bias, true); + auto bias = NDArrayFactory::create('c', {c}); + ctx->setInputArray(1, bias); auto iargs = new Nd4jLong[1]; iargs[0] = axis; ctx->setIArguments(iargs, 1); @@ -164,12 +164,12 @@ namespace sd { //Same mode + stride 1: output is same shape as input if(format == 1) { //NDHWC - ctx->setInputArray(0, NDArrayFactory::create_('c', {mb, dhw, dhw, dhw, chIn}), true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {mb, dhw, dhw, dhw, chIn}), true); + ctx->setInputArray(0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); + ctx->setOutputArray(0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); } else { //NCDHW - ctx->setInputArray(0, NDArrayFactory::create_('c', {mb, chIn, dhw, dhw, dhw}), true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {mb, chIn, dhw, dhw, dhw}), true); + ctx->setInputArray(0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); + ctx->setOutputArray(0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); } auto iargs = new Nd4jLong[15]; @@ -230,17 +230,17 @@ namespace sd { //Same mode + stride 1: output is same shape as input if(format == 1) { //NDHWC - ctx->setInputArray(0, NDArrayFactory::create_('c', {mb, dhw, dhw, dhw, chIn}), true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {mb, dhw, dhw, dhw, chIn}), true); + ctx->setInputArray(0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); + ctx->setOutputArray(0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); } else { //NCDHW - ctx->setInputArray(0, NDArrayFactory::create_('c', {mb, chIn, dhw, dhw, dhw}), true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {mb, chIn, dhw, dhw, dhw}), true); + ctx->setInputArray(0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); + ctx->setOutputArray(0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); } //Weights and bias: - ctx->setInputArray(1, NDArrayFactory::create_('c', {3, 3, 3, chIn, chOut}), true); - ctx->setInputArray(2, NDArrayFactory::create_('c', {chOut}), true); + ctx->setInputArray(1, NDArrayFactory::create('c', {3, 3, 3, chIn, chOut})); + ctx->setInputArray(2, NDArrayFactory::create('c', {chOut})); auto iargs = new Nd4jLong[14]; @@ -296,46 +296,46 @@ namespace sd { int n = p.getIntParam("nInOut"); Nd4jLong l = 0; - ctx->setInputArray(0, NDArrayFactory::create_(l), true); //Max TS length (unused) + ctx->setInputArray(0, NDArrayFactory::create(l)); //Max TS length (unused) if (f == 0) { //TNS format - ctx->setInputArray(1, NDArrayFactory::create_('c', {seqLength, m, n}), true); //x - ctx->setOutputArray(0, NDArrayFactory::create_('c', {seqLength, m, n}), true); //i - ctx->setOutputArray(1, NDArrayFactory::create_('c', {seqLength, m, n}), true); //c - ctx->setOutputArray(2, NDArrayFactory::create_('c', {seqLength, m, n}), true); //f - ctx->setOutputArray(3, NDArrayFactory::create_('c', {seqLength, m, n}), true); //o - ctx->setOutputArray(4, NDArrayFactory::create_('c', {seqLength, m, n}), true); //z - ctx->setOutputArray(5, NDArrayFactory::create_('c', {seqLength, m, n}), true); //h - ctx->setOutputArray(6, NDArrayFactory::create_('c', {seqLength, m, n}), true); //y + ctx->setInputArray(1, NDArrayFactory::create('c', {seqLength, m, n})); //x + ctx->setOutputArray(0, NDArrayFactory::create('c', {seqLength, m, n})); //i + ctx->setOutputArray(1, NDArrayFactory::create('c', {seqLength, m, n})); //c + ctx->setOutputArray(2, NDArrayFactory::create('c', {seqLength, m, n})); //f + ctx->setOutputArray(3, NDArrayFactory::create('c', {seqLength, m, n})); //o + ctx->setOutputArray(4, NDArrayFactory::create('c', {seqLength, m, n})); //z + ctx->setOutputArray(5, NDArrayFactory::create('c', {seqLength, m, n})); //h + ctx->setOutputArray(6, NDArrayFactory::create('c', {seqLength, m, n})); //y } else { //NST format - ctx->setInputArray(1, NDArrayFactory::create_('f', {m, n, seqLength}), true); //x - ctx->setOutputArray(0, NDArrayFactory::create_('f', {m, n, seqLength}), true); //i - ctx->setOutputArray(1, NDArrayFactory::create_('f', {m, n, seqLength}), true); //c - ctx->setOutputArray(2, NDArrayFactory::create_('f', {m, n, seqLength}), true); //f - ctx->setOutputArray(3, NDArrayFactory::create_('f', {m, n, seqLength}), true); //o - ctx->setOutputArray(4, NDArrayFactory::create_('f', {m, n, seqLength}), true); //z - ctx->setOutputArray(5, NDArrayFactory::create_('f', {m, n, seqLength}), true); //h - ctx->setOutputArray(6, NDArrayFactory::create_('f', {m, n, seqLength}), true); //y + ctx->setInputArray(1, NDArrayFactory::create('f', {m, n, seqLength})); //x + ctx->setOutputArray(0, NDArrayFactory::create('f', {m, n, seqLength})); //i + ctx->setOutputArray(1, NDArrayFactory::create('f', {m, n, seqLength})); //c + ctx->setOutputArray(2, NDArrayFactory::create('f', {m, n, seqLength})); //f + ctx->setOutputArray(3, NDArrayFactory::create('f', {m, n, seqLength})); //o + ctx->setOutputArray(4, NDArrayFactory::create('f', {m, n, seqLength})); //z + ctx->setOutputArray(5, NDArrayFactory::create('f', {m, n, seqLength})); //h + ctx->setOutputArray(6, NDArrayFactory::create('f', {m, n, seqLength})); //y } - auto cLast = NDArrayFactory::create_('c', {m, n}); - auto yLast = NDArrayFactory::create_('c', {m, n}); - auto W = NDArrayFactory::create_('c', {2 * n, 4 * n}); - auto Wci = NDArrayFactory::create_('c', {n}); - auto Wcf = NDArrayFactory::create_('c', {n}); - auto Wco = NDArrayFactory::create_('c', {n}); - auto b = NDArrayFactory::create_('c', {4 * n}); - - ctx->setInputArray(2, cLast, true); - ctx->setInputArray(3, yLast, true); - ctx->setInputArray(4, W, true); - ctx->setInputArray(5, Wci, true); - ctx->setInputArray(6, Wcf, true); - ctx->setInputArray(7, Wco, true); - ctx->setInputArray(8, b, true); + auto cLast = NDArrayFactory::create('c', {m, n}); + auto yLast = NDArrayFactory::create('c', {m, n}); + auto W = NDArrayFactory::create('c', {2 * n, 4 * n}); + auto Wci = NDArrayFactory::create('c', {n}); + auto Wcf = NDArrayFactory::create('c', {n}); + auto Wco = NDArrayFactory::create('c', {n}); + auto b = NDArrayFactory::create('c', {4 * n}); + + ctx->setInputArray(2, cLast); + ctx->setInputArray(3, yLast); + ctx->setInputArray(4, W); + ctx->setInputArray(5, Wci); + ctx->setInputArray(6, Wcf); + ctx->setInputArray(7, Wco); + ctx->setInputArray(8, b); auto iargs = new Nd4jLong[2]; iargs[0] = 0; //No peephole @@ -380,31 +380,31 @@ namespace sd { auto args = new Nd4jLong[3]; args[0] = args[1] = 1; //apply scale and offset if (n == 0) { - auto input = NDArrayFactory::create_('c', {32, ch, hw, hw}); - auto output = NDArrayFactory::create_('c', {32, ch, hw, hw}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); + auto input = NDArrayFactory::create('c', {32, ch, hw, hw}); + auto output = NDArrayFactory::create('c', {32, ch, hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); args[2] = 1; //axis } else { - auto input = NDArrayFactory::create_('c', {32, hw, hw, ch}); - auto output = NDArrayFactory::create_('c', {32, hw, hw, ch}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); + auto input = NDArrayFactory::create('c', {32, hw, hw, ch}); + auto output = NDArrayFactory::create('c', {32, hw, hw, ch}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); args[2] = 3; //axis } ctx->setIArguments(args, 3); delete[] args; - ctx->setInputArray(1, NDArrayFactory::create_('c', {ch}), true); //mean - auto v = NDArrayFactory::create_('c', {ch}); - v->assign(1.0f); - ctx->setInputArray(2, v, true); //variance - auto g = NDArrayFactory::create_('c', {ch}); - g->assign(1.0); - ctx->setInputArray(3, g, true); //gamma - auto b = NDArrayFactory::create_('c', {ch}); - b->assign(1.0); - ctx->setInputArray(4, b, true); //beta + ctx->setInputArray(1, NDArrayFactory::create('c', {ch})); //mean + auto v = NDArrayFactory::create('c', {ch}); + v.assign(1.0f); + ctx->setInputArray(2, v); //variance + auto g = NDArrayFactory::create('c', {ch}); + g.assign(1.0); + ctx->setInputArray(3, g); //gamma + auto b = NDArrayFactory::create('c', {ch}); + b.assign(1.0); + ctx->setInputArray(4, b); //beta auto targs = new double[1]; targs[0] = 1e-5; @@ -446,15 +446,15 @@ namespace sd { int khw = p.getIntParam("k"); if (n == 0) { - auto input = NDArrayFactory::create_('c', {32, p.getIntParam("c"), hw, hw}); - auto output = NDArrayFactory::create_('c', {32, p.getIntParam("c"), hw, hw}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); + auto input = NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); + auto output = NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } else { - auto input = NDArrayFactory::create_('c', {32, hw, hw, p.getIntParam("c")}); - auto output = NDArrayFactory::create_('c', {32, hw, hw, p.getIntParam("c")}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); + auto input = NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); + auto output = NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } auto args = new Nd4jLong[11]; @@ -507,22 +507,22 @@ namespace sd { int khw = p.getIntParam("k"); if (n == 0) { - auto input = NDArrayFactory::create_('c', {32, p.getIntParam("c"), hw, hw}); - auto output = NDArrayFactory::create_('c', {32, p.getIntParam("c"), hw, hw}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); + auto input = NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); + auto output = NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } else { - auto input = NDArrayFactory::create_('c', {32, hw, hw, p.getIntParam("c")}); - auto output = NDArrayFactory::create_('c', {32, hw, hw, p.getIntParam("c")}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); + auto input = NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); + auto output = NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } - auto b = NDArrayFactory::create_('c', {p.getIntParam("c")}); - auto w = NDArrayFactory::create_('c', {khw, khw, p.getIntParam("c"), p.getIntParam("c")}); // [kH, kW, iC, oC] always + auto b = NDArrayFactory::create('c', {p.getIntParam("c")}); + auto w = NDArrayFactory::create('c', {khw, khw, p.getIntParam("c"), p.getIntParam("c")}); // [kH, kW, iC, oC] always - ctx->setInputArray(1, w, true); - ctx->setInputArray(2, b, true); + ctx->setInputArray(1, w); + ctx->setInputArray(2, b); auto args = new Nd4jLong[10]; args[0] = args[1] = khw; //Kernel @@ -552,8 +552,8 @@ namespace sd { auto gen01 = PARAMETRIC_D() { auto ctx = new Context(1); - ctx->setInputArray(0, NDArrayFactory::create_('c', {2},{1, p.getIntParam("length")}), true); //Shape as NDArray - ctx->setOutputArray(0, NDArrayFactory::create_('c', {1, p.getIntParam("length")}), true); + ctx->setInputArray(0, NDArrayFactory::create('c', {2},{1, p.getIntParam("length")})); //Shape as NDArray + ctx->setOutputArray(0, NDArrayFactory::create('c', {1, p.getIntParam("length")})); auto d = new double[2]; d[0] = 0.0; d[1] = 1.0; @@ -564,8 +564,8 @@ namespace sd { auto gen05 = PARAMETRIC_D() { auto ctx = new Context(1); - ctx->setInputArray(0, NDArrayFactory::create_('c', {2},{1, p.getIntParam("length")}), true); //Shape as NDArray - ctx->setOutputArray(0, NDArrayFactory::create_('c', {1, p.getIntParam("length")}), true); + ctx->setInputArray(0, NDArrayFactory::create('c', {2},{1, p.getIntParam("length")})); //Shape as NDArray + ctx->setOutputArray(0, NDArrayFactory::create('c', {1, p.getIntParam("length")})); auto d = new double[1]; d[0] = 0.5; ctx->setTArguments(d, 1); @@ -638,9 +638,9 @@ namespace sd { } else { shapeB = {b, c}; } - auto A = NDArrayFactory::create_('c', shapeA); - auto B = NDArrayFactory::create_('c', shapeB); - auto C = NDArrayFactory::create_('f', {a, c}); + auto A = NDArrayFactory::create('c', shapeA); + auto B = NDArrayFactory::create('c', shapeB); + auto C = NDArrayFactory::create('f', {a, c}); x.push_back(A); y.push_back(B); @@ -674,9 +674,9 @@ namespace sd { } else { shapeB = {b, c}; } - auto A = NDArrayFactory::create_('c', shapeA); - auto B = NDArrayFactory::create_('c', shapeB); - auto C = NDArrayFactory::create_('f', {a, c}); + auto A = NDArrayFactory::create('c', shapeA); + auto B = NDArrayFactory::create('c', shapeB); + auto C = NDArrayFactory::create('f', {a, c}); x.push_back(A); y.push_back(B); @@ -710,9 +710,9 @@ namespace sd { } else { shapeB = {b, c}; } - auto A = NDArrayFactory::create_('c', shapeA); - auto B = NDArrayFactory::create_('c', shapeB); - auto C = NDArrayFactory::create_('f', {a, c}); + auto A = NDArrayFactory::create('c', shapeA); + auto B = NDArrayFactory::create('c', shapeB); + auto C = NDArrayFactory::create('f', {a, c}); x.push_back(A); y.push_back(B); @@ -752,13 +752,13 @@ namespace sd { auto ctx = new Context(1); if(rank == 3){ - ctx->setInputArray(0, NDArrayFactory::create_('c', {32, 1024, 1024}), true); - ctx->setInputArray(1, NDArrayFactory::create_('c', {32, 1024, 1024}), true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {32, 1024, 1024}), true); + ctx->setInputArray(0, NDArrayFactory::create('c', {32, 1024, 1024})); + ctx->setInputArray(1, NDArrayFactory::create('c', {32, 1024, 1024})); + ctx->setOutputArray(0, NDArrayFactory::create('c', {32, 1024, 1024})); } else { - ctx->setInputArray(0, NDArrayFactory::create_('c', {4, 8, 1024, 1024}), true); - ctx->setInputArray(1, NDArrayFactory::create_('c', {4, 8, 1024, 1024}), true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {4, 8, 1024, 1024}), true); + ctx->setInputArray(0, NDArrayFactory::create('c', {4, 8, 1024, 1024})); + ctx->setInputArray(1, NDArrayFactory::create('c', {4, 8, 1024, 1024})); + ctx->setOutputArray(0, NDArrayFactory::create('c', {4, 8, 1024, 1024})); } return ctx; @@ -787,9 +787,9 @@ namespace sd { auto generator = PARAMETRIC_XYZ() { auto s = p.getIntParam("sz"); - auto A = NDArrayFactory::create_('c', {s, s}); - auto B = NDArrayFactory::create_('c', {s, s}); - auto C = NDArrayFactory::create_(resultOrder, {s, s}); + auto A = NDArrayFactory::create('c', {s, s}); + auto B = NDArrayFactory::create('c', {s, s}); + auto C = NDArrayFactory::create(resultOrder, {s, s}); x.push_back(A); y.push_back(B); @@ -827,9 +827,9 @@ namespace sd { auto generator = PARAMETRIC_D() { auto ctx = new Context(1); int length = p.getIntParam("length"); - auto in = NDArrayFactory::create_('c', {length}); - auto indices = NDArrayFactory::create_('c', {length}); - auto updates = NDArrayFactory::create_('c', {length}); + auto in = NDArrayFactory::create('c', {length}); + auto indices = NDArrayFactory::create('c', {length}); + auto updates = NDArrayFactory::create('c', {length}); int* a = new int[length]; for( int i=0; ip(i, a[i]); + indices.p(i, a[i]); } delete[] a; - ctx->setInputArray(0, in, true); - ctx->setInputArray(1, indices, true); - ctx->setInputArray(2, updates, true); + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setInputArray(2, updates); ctx->setOutputArray(0, in); //Needs to be inplace to avoid copy! ctx->markInplace(true); return ctx; @@ -862,9 +862,9 @@ namespace sd { auto ctx = new Context(1); int rows = p.getIntParam("rows"); int cols = p.getIntParam("cols"); - auto in = NDArrayFactory::create_('c', {rows, cols}); - auto indices = NDArrayFactory::create_('c', {rows}); - auto updates = NDArrayFactory::create_('c', {rows, cols}); + auto in = NDArrayFactory::create('c', {rows, cols}); + auto indices = NDArrayFactory::create('c', {rows}); + auto updates = NDArrayFactory::create('c', {rows, cols}); int* a = new int[rows]; for( int i=0; ip(i, a[i]); + indices.p(i, a[i]); } delete[] a; - ctx->setInputArray(0, in, true); - ctx->setInputArray(1, indices, true); - ctx->setInputArray(2, updates, true); + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setInputArray(2, updates); ctx->setOutputArray(0, in); //Needs to be inplace to avoid copy! ctx->markInplace(true); return ctx; @@ -897,9 +897,9 @@ namespace sd { auto ctx = new Context(1); int sz0 = p.getIntParam("sz0"); int sz1 = p.getIntParam("sz1"); - auto in = NDArrayFactory::create_('c', {sz0, sz1, 512/sz1}); - auto indices = NDArrayFactory::create_('c', {sz0}); - auto updates = NDArrayFactory::create_('c', {sz0, sz1, 512/sz1}); + auto in = NDArrayFactory::create('c', {sz0, sz1, 512/sz1}); + auto indices = NDArrayFactory::create('c', {sz0}); + auto updates = NDArrayFactory::create('c', {sz0, sz1, 512/sz1}); int* a = new int[sz0]; for( int i=0; ip(i, a[i]); + indices.p(i, a[i]); } delete[] a; - ctx->setInputArray(0, in, true); - ctx->setInputArray(1, indices, true); - ctx->setInputArray(2, updates, true); + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setInputArray(2, updates); ctx->setOutputArray(0, in); //Needs to be inplace to avoid copy! ctx->markInplace(true); return ctx; @@ -937,8 +937,8 @@ namespace sd { auto generator = PARAMETRIC_D() { auto ctx = new Context(1); int length = p.getIntParam("length"); - auto in = NDArrayFactory::create_('c', {length}); - auto indices = NDArrayFactory::create_('c', {length}); + auto in = NDArrayFactory::create('c', {length}); + auto indices = NDArrayFactory::create('c', {length}); int* a = new int[length]; for( int i=0; ip(i, a[i]); + indices.p(i, a[i]); } delete[] a; - ctx->setInputArray(0, in, true); - ctx->setInputArray(1, indices, true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {length}), true); + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setOutputArray(0, NDArrayFactory::create('c', {length})); return ctx; }; @@ -968,8 +968,8 @@ namespace sd { auto ctx = new Context(1); int rows = p.getIntParam("rows"); int cols = p.getIntParam("cols"); - auto in = NDArrayFactory::create_('c', {rows, cols}); - auto indices = NDArrayFactory::create_('c', {rows}); + auto in = NDArrayFactory::create('c', {rows, cols}); + auto indices = NDArrayFactory::create('c', {rows}); int* a = new int[rows]; for( int i=0; ip(i, a[i]); + indices.p(i, a[i]); } delete[] a; - ctx->setInputArray(0, in, true); - ctx->setInputArray(1, indices, true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {rows, cols}), true); + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setOutputArray(0, NDArrayFactory::create('c', {rows, cols})); return ctx; }; @@ -1000,8 +1000,8 @@ namespace sd { auto ctx = new Context(1); int sz0 = p.getIntParam("sz0"); int sz1 = p.getIntParam("sz1"); - auto in = NDArrayFactory::create_('c', {sz0, sz1, 512/sz1}); - auto indices = NDArrayFactory::create_('c', {sz0}); + auto in = NDArrayFactory::create('c', {sz0, sz1, 512/sz1}); + auto indices = NDArrayFactory::create('c', {sz0}); int* a = new int[sz0]; for( int i=0; ip(i, a[i]); + indices.p(i, a[i]); } delete[] a; - ctx->setInputArray(0, in, true); - ctx->setInputArray(1, indices, true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {sz0, sz1, 512/sz1}), true); + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setOutputArray(0, NDArrayFactory::create('c', {sz0, sz1, 512/sz1})); return ctx; }; @@ -1040,8 +1040,8 @@ namespace sd { int cols = numElements / rows; bool c = p.getIntParam("cf"); - auto arr = NDArrayFactory::create_(c ? 'c' : 'f', {rows, cols}); - auto arr2 = NDArrayFactory::create_(c ? 'f' : 'c', {rows, cols}); + auto arr = NDArrayFactory::create(c ? 'c' : 'f', {rows, cols}); + auto arr2 = NDArrayFactory::create(c ? 'f' : 'c', {rows, cols}); x.push_back(arr); z.push_back(arr2); }; @@ -1056,15 +1056,15 @@ namespace sd { bool nchw = p.getIntParam("nchw"); if(nchw) { - auto orig = NDArrayFactory::create_('c', {16, 32, 64, 64}); - orig->permutei({0,2,3,1}); + auto orig = NDArrayFactory::create('c', {16, 32, 64, 64}); + orig.permutei({0,2,3,1}); x.push_back(orig); - z.push_back(NDArrayFactory::create_('c', {16, 64, 64, 32})); + z.push_back(NDArrayFactory::create('c', {16, 64, 64, 32})); } else { - auto orig = NDArrayFactory::create_('c', {16, 64, 64, 32}); - orig->permutei({0,3,1,2}); + auto orig = NDArrayFactory::create('c', {16, 64, 64, 32}); + orig.permutei({0,3,1,2}); x.push_back(orig); - z.push_back(NDArrayFactory::create_('c', {16, 32, 64, 64})); + z.push_back(NDArrayFactory::create('c', {16, 32, 64, 64})); } }; @@ -1156,9 +1156,9 @@ namespace sd { } auto ctx = new Context(1); - ctx->setInputArray(0, NDArrayFactory::create_('c', shape), true); - ctx->setInputArray(1, NDArrayFactory::create_('c', toBcShape), true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', shape), true); + ctx->setInputArray(0, NDArrayFactory::create('c', shape)); + ctx->setInputArray(1, NDArrayFactory::create('c', toBcShape)); + ctx->setOutputArray(0, NDArrayFactory::create('c', shape)); return ctx; }; @@ -1188,20 +1188,20 @@ namespace sd { auto generator = PARAMETRIC_D() { auto a = p.getIntParam("axis"); - auto arr = NDArrayFactory::create_('c', {p.getIntParam("rows"), p.getIntParam("cols")}); + auto arr = NDArrayFactory::create('c', {p.getIntParam("rows"), p.getIntParam("cols")}); auto ctx = new Context(1); - ctx->setInputArray(0, arr, true); + ctx->setInputArray(0, arr); if(a == 0){ - ctx->setInputArray(1, NDArrayFactory::create_('c', {p.getIntParam("rows"), 1}), true); + ctx->setInputArray(1, NDArrayFactory::create('c', {p.getIntParam("rows"), 1})); } else { - ctx->setInputArray(1, NDArrayFactory::create_('c', {1, p.getIntParam("cols")}), true); + ctx->setInputArray(1, NDArrayFactory::create('c', {1, p.getIntParam("cols")})); } if (p.getIntParam("inplace") == 1) { ctx->setOutputArray(0, arr); ctx->markInplace(true); } else { - ctx->setOutputArray(0, NDArrayFactory::create_('c', {p.getIntParam("rows"), p.getIntParam("cols")}), true); + ctx->setOutputArray(0, NDArrayFactory::create('c', {p.getIntParam("rows"), p.getIntParam("cols")})); } return ctx; }; @@ -1226,17 +1226,17 @@ namespace sd { ParametersBatch batch({&rows, &cols, &inplace}); auto generator = PARAMETRIC_XYZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("rows"), p.getIntParam("cols")}); + auto arr = NDArrayFactory::create('c', {p.getIntParam("rows"), p.getIntParam("cols")}); x.push_back(arr); if(axis == 0){ - y.push_back(NDArrayFactory::create_('c', {p.getIntParam("rows")})); + y.push_back(NDArrayFactory::create('c', {p.getIntParam("rows")})); } else { - y.push_back(NDArrayFactory::create_('c', {p.getIntParam("cols")})); + y.push_back(NDArrayFactory::create('c', {p.getIntParam("cols")})); } if (p.getIntParam("inplace") == 1) { z.push_back(arr); } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("rows"), p.getIntParam("cols")})); + z.push_back(NDArrayFactory::create('c', {p.getIntParam("rows"), p.getIntParam("cols")})); } }; @@ -1264,9 +1264,9 @@ namespace sd { //Note: always inplace here auto generator = PARAMETRIC_XYZ() { - auto arr = NDArrayFactory::create_('c', shape); + auto arr = NDArrayFactory::create('c', shape); x.push_back(arr); - y.push_back(NDArrayFactory::create_('c', {vectorLength})); + y.push_back(NDArrayFactory::create('c', {vectorLength})); z.push_back(arr); }; @@ -1291,21 +1291,20 @@ namespace sd { //This is an edge case: technically an EWS *should* be available here auto generator1 = PARAMETRIC_XYZ() { auto stride = p.getIntParam("stride"); - auto arr = NDArrayFactory::create_('c', {131072 + (stride == 1 ? 0 : 1), stride}); + auto arr = NDArrayFactory::create('c', {131072 + (stride == 1 ? 0 : 1), stride}); - NDArray* strided; + NDArray strided; if(stride == 1){ strided = arr; } else { IndicesList indices({NDIndex::interval(0,131072), NDIndex::interval(0,1)}); - strided = new NDArray(arr->subarray(indices)); //All rows, first column - delete arr; + strided = arr.subarray(indices); //All rows, first column } - strided->assign(1.0); + strided.assign(1.0); x.push_back(strided); - y.push_back(nullptr); - z.push_back(NDArrayFactory::create_(0.0f)); + y.push_back(NDArray()); + z.push_back(NDArrayFactory::create(0.0f)); }; ReductionBenchmark rbSum(reduce::SameOps::Sum, "stridedSum"); @@ -1315,21 +1314,20 @@ namespace sd { //No EWS defined for this case auto generator2 = PARAMETRIC_XYZ() { auto stride = p.getIntParam("stride"); - auto arr = NDArrayFactory::create_('c', {(stride == 1 ? 1 : 2) * 1024, 1024, stride}); + auto arr = NDArrayFactory::create('c', {(stride == 1 ? 1 : 2) * 1024, 1024, stride}); - NDArray* strided; + NDArray strided; if(stride == 1){ strided = arr; } else { IndicesList indices({NDIndex::interval(0,2*1024,2), NDIndex::all(), NDIndex::interval(0,1)}); - strided = new NDArray(arr->subarray(indices)); - delete arr; + strided = arr.subarray(indices); } - strided->assign(1.0); + strided.assign(1.0); x.push_back(strided); - y.push_back(nullptr); - z.push_back(NDArrayFactory::create_(0.0f)); + y.push_back(NDArray()); + z.push_back(NDArrayFactory::create(0.0f)); }; ReductionBenchmark rbSum2(reduce::SameOps::Sum, "stridedSumNoEWS"); @@ -1351,21 +1349,20 @@ namespace sd { auto generator = PARAMETRIC_XYZ() { auto stride = p.getIntParam("stride"); - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length"), stride}); + auto arr = NDArrayFactory::create('c', {p.getIntParam("length"), stride}); - NDArray* strided; + NDArray strided; if(stride == 1){ strided = arr; } else { IndicesList indices({NDIndex::all(), NDIndex::interval(0,1)}); - strided = new NDArray(arr->subarray(indices)); //All rows, first column - delete arr; + strided = arr.subarray(indices); //All rows, first column } - strided->assign(1.0); + strided.assign(1.0); x.push_back(strided); - y.push_back(nullptr); - z.push_back(NDArrayFactory::create_(0.0f)); + y.push_back(NDArray()); + z.push_back(NDArrayFactory::create(0.0f)); }; ReductionBenchmark rbSum(reduce::SameOps::Sum, "stridedSum"); @@ -1386,22 +1383,21 @@ namespace sd { auto generator = PARAMETRIC_XYZ() { auto stride = p.getIntParam("stride"); - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length"), stride}); + auto arr = NDArrayFactory::create('c', {p.getIntParam("length"), stride}); - NDArray* strided; + NDArray strided; if(stride == 1){ strided = arr; } else { IndicesList indices({NDIndex::all(), NDIndex::point(0)}); - strided = new NDArray(arr->subarray(indices)); //All rows, first column - delete arr; + strided = arr.subarray(indices); //All rows, first column } - strided->assign(1.0); + strided.assign(1.0); x.push_back(strided); - y.push_back(nullptr); -// z.push_back(NDArrayFactory::create_(0.0f)); - z.push_back(NDArrayFactory::create_('c', {1})); + y.push_back(NDArray()); +// z.push_back(NDArrayFactory::create(0.0f)); + z.push_back(NDArrayFactory::create('c', {1})); }; ReductionBenchmark rbSum(reduce::SameOps::Sum, "Strided Sum"); @@ -1411,20 +1407,19 @@ namespace sd { auto generator3 = PARAMETRIC_D(){ auto ctx = new Context(1); auto stride = p.getIntParam("stride"); - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length"), stride}); + auto arr = NDArrayFactory::create('c', {p.getIntParam("length"), stride}); - NDArray* strided; + NDArray strided; if(stride == 1){ strided = arr; } else { IndicesList indices({NDIndex::all(), NDIndex::point(0)}); - strided = new NDArray(arr->subarray(indices)); //All rows, first column - delete arr; + strided = arr.subarray(indices); //All rows, first column } - strided->assign(1.0); - ctx->setInputArray(0, strided, true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {1}), true); + strided.assign(1.0); + ctx->setInputArray(0, strided); + ctx->setOutputArray(0, NDArrayFactory::create('c', {1})); auto iargs = new Nd4jLong[1]; iargs[0] = 0; ctx->setIArguments(iargs, 1); @@ -1457,17 +1452,17 @@ namespace sd { int rows = p.getIntParam("rows"); int cols = length[i] / rows; int dim = p.getIntParam("dim"); - auto arr = NDArrayFactory::create_('c', {rows, cols}); + auto arr = NDArrayFactory::create('c', {rows, cols}); x.push_back(arr); - y.push_back(NDArrayFactory::create_(dim)); + y.push_back(NDArrayFactory::create(dim)); - NDArray* result; + NDArray result; if(dim == 0){ - result = NDArrayFactory::create_('c', {cols}); + result = NDArrayFactory::create('c', {cols}); } else { - result = NDArrayFactory::create_('c', {rows}); + result = NDArrayFactory::create('c', {rows}); } z.push_back(result); }; @@ -1486,22 +1481,22 @@ namespace sd { int rows = p.getIntParam("rows"); int cols = length[i] / rows; int dim = p.getIntParam("dim"); - auto arr = NDArrayFactory::create_('c', {rows, cols}); + auto arr = NDArrayFactory::create('c', {rows, cols}); Nd4jLong* dimArg = new Nd4jLong[1]; dimArg[0] = dim; ctx->setIArguments(dimArg, 1); delete[] dimArg; - ctx->setInputArray(0, arr, true); + ctx->setInputArray(0, arr); - NDArray* result; + NDArray result; if(dim == 0){ - result = NDArrayFactory::create_('c', {cols}); + result = NDArrayFactory::create('c', {cols}); } else { - result = NDArrayFactory::create_('c', {rows}); + result = NDArrayFactory::create('c', {rows}); } - ctx->setOutputArray(0, result, true); + ctx->setOutputArray(0, result); return ctx; }; @@ -1525,11 +1520,11 @@ namespace sd { ParametersBatch batch({&length}); auto generator = PARAMETRIC_XYZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); x.push_back(arr); - y.push_back(nullptr); - z.push_back(NDArrayFactory::create_(0.0f)); + y.push_back(NDArray()); + z.push_back(NDArrayFactory::create(0.0f)); }; ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); @@ -1542,9 +1537,9 @@ namespace sd { auto generator3 = PARAMETRIC_D(){ auto ctx = new Context(1); - ctx->setInputArray(0, NDArrayFactory::create_('c', {p.getIntParam("length")}), true); - ctx->setInputArray(1, NDArrayFactory::create_((Nd4jLong)0), true); - ctx->setOutputArray(0, NDArrayFactory::create_(0), true); + ctx->setInputArray(0, NDArrayFactory::create('c', {p.getIntParam("length")})); + ctx->setInputArray(1, NDArrayFactory::create((Nd4jLong)0)); + ctx->setOutputArray(0, NDArrayFactory::create(0)); return ctx; }; @@ -1563,21 +1558,20 @@ namespace sd { auto generator = PARAMETRIC_XZ() { int r = p.getIntParam("rowcol"); - auto arr = NDArrayFactory::create_('c', {r, r+1}); + auto arr = NDArrayFactory::create('c', {r, r+1}); IndicesList indices({NDIndex::all(), NDIndex::interval(0,r-1)}); - auto view = new NDArray(arr->subarray(indices)); + auto view = arr.subarray(indices); //nd4j_printf("VIEW ARRAY: rows=%lld, columns=%lld", view->sizeAt(0), view->sizeAt(1)); x.push_back(view); if(p.getIntParam("inplace") == 1){ z.push_back(view); } else { - z.push_back(NDArrayFactory::create_('c', {view->sizeAt(0),view->sizeAt(1)})); + z.push_back(NDArrayFactory::create('c', {view.sizeAt(0),view.sizeAt(1)})); } - delete arr; }; ScalarBenchmark sbLRelu(scalar::Ops::LeakyRELU, "LeakyRELU_View"); - sbLRelu.setY(NDArrayFactory::create_(0.0)); + sbLRelu.setY(NDArrayFactory::create(0.0)); TransformBenchmark tbExp(transform::StrictOps::Exp, "exp view"); @@ -1596,14 +1590,14 @@ namespace sd { ParametersBatch batch({&length, &inplace}); auto generator = PARAMETRIC_XYZ() { - auto arr1 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - auto arr2 = NDArrayFactory::create_('c', {p.getIntParam("length")}); + auto arr1 = NDArrayFactory::create('c', {p.getIntParam("length")}); + auto arr2 = NDArrayFactory::create('c', {p.getIntParam("length")}); x.push_back(arr1); y.push_back(arr2); if(p.getIntParam("inplace") == 1){ z.push_back(arr1); } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); + z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); } }; @@ -1625,13 +1619,13 @@ namespace sd { ParametersBatch batch({&length, &inplace}); auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); - arr->assign(1.0); + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); x.push_back(arr); if (p.getIntParam("inplace") == 1) { z.push_back(arr); } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); + z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); } }; @@ -1644,13 +1638,13 @@ namespace sd { DeclarableBenchmark pg(op1, "polygamma"); auto generator2 = PARAMETRIC_D() { auto ctx = new Context(1); - auto in0 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - in0->assign(0.25); - auto in1 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - in1->assign(0.5); - ctx->setInputArray(0, in0, true); - ctx->setInputArray(1, in1, true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {p.getIntParam("length")}), true); + auto in0 = NDArrayFactory::create('c', {p.getIntParam("length")}); + in0.assign(0.25); + auto in1 = NDArrayFactory::create('c', {p.getIntParam("length")}); + in1.assign(0.5); + ctx->setInputArray(0, in0); + ctx->setInputArray(1, in1); + ctx->setOutputArray(0, NDArrayFactory::create('c', {p.getIntParam("length")})); return ctx; }; @@ -1661,16 +1655,16 @@ namespace sd { DeclarableBenchmark binc(op2, "betainc"); auto generator3 = PARAMETRIC_D() { auto ctx = new Context(1); - auto in0 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - in0->assign(0.25); - auto in1 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - in1->assign(0.5); - auto in2 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - in2->assign(0.75); - ctx->setInputArray(0, in0, true); - ctx->setInputArray(1, in1, true); - ctx->setInputArray(2, in2, true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {p.getIntParam("length")}), true); + auto in0 = NDArrayFactory::create('c', {p.getIntParam("length")}); + in0.assign(0.25); + auto in1 = NDArrayFactory::create('c', {p.getIntParam("length")}); + in1.assign(0.5); + auto in2 = NDArrayFactory::create('c', {p.getIntParam("length")}); + in2.assign(0.75); + ctx->setInputArray(0, in0); + ctx->setInputArray(1, in1); + ctx->setInputArray(2, in2); + ctx->setOutputArray(0, NDArrayFactory::create('c', {p.getIntParam("length")})); return ctx; }; @@ -1691,13 +1685,13 @@ namespace sd { ParametersBatch batch({&length, &inplace}); auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); - arr->assign(1.0); + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); x.push_back(arr); if(p.getIntParam("inplace") == 1){ z.push_back(arr); } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); + z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); } }; @@ -1715,13 +1709,13 @@ namespace sd { ParametersBatch batch2({&rows, &cols, &inplace}); auto generator2 = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("rows"), p.getIntParam("cols")}); - arr->assign(1.0); + auto arr = NDArrayFactory::create('c', {p.getIntParam("rows"), p.getIntParam("cols")}); + arr.assign(1.0); x.push_back(arr); if(p.getIntParam("inplace") == 1){ z.push_back(arr); } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("rows"), p.getIntParam("cols")})); + z.push_back(NDArrayFactory::create('c', {p.getIntParam("rows"), p.getIntParam("cols")})); } }; @@ -1741,18 +1735,18 @@ namespace sd { ParametersBatch batch({&length, &inplace}); auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); - arr->assign(1.0); + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); x.push_back(arr); if(p.getIntParam("inplace") == 1){ z.push_back(arr); } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); + z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); } }; ScalarBenchmark sbLRelu(scalar::Ops::LeakyRELU, "LeakyRELU"); - sbLRelu.setY(NDArrayFactory::create_(0.0)); + sbLRelu.setY(NDArrayFactory::create(0.0)); TransformBenchmark tbAbs(transform::SameOps::Abs, "abs"); TransformBenchmark tbExp(transform::StrictOps::Exp, "exp"); @@ -1774,13 +1768,13 @@ namespace sd { ParametersBatch batch({&length, &inplace}); auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); - arr->assign(1.0); + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); x.push_back(arr); if(p.getIntParam("inplace") == 1){ z.push_back(arr); } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); + z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); } }; @@ -1789,9 +1783,9 @@ namespace sd { ScalarBenchmark sbPow(scalar::Ops::Pow, "sPow"); - sbAdd.setY(NDArrayFactory::create_(3.14159265359)); - sbDiv.setY(NDArrayFactory::create_(3.14159265359)); - sbPow.setY(NDArrayFactory::create_(3.14159265359)); + sbAdd.setY(NDArrayFactory::create(3.14159265359)); + sbDiv.setY(NDArrayFactory::create(3.14159265359)); + sbPow.setY(NDArrayFactory::create(3.14159265359)); output += helper.runOperationSuit(&sbAdd, generator, batch, "Scalar Addition - x.add(3.14159265359) - F32"); diff --git a/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp b/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp index 99a1b05bf92a..158131bbec54 100644 --- a/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp +++ b/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp @@ -46,18 +46,18 @@ namespace sd { ParametersBatch batch({&length, &inplace}); auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); - arr->assign(1.0); + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); x.push_back(arr); if(p.getIntParam("inplace") == 1){ z.push_back(arr); } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); + z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); } }; ScalarBenchmark sbRelu(scalar::Ops::RELU, "RELU"); - sbRelu.setY(NDArrayFactory::create_(0.0)); + sbRelu.setY(NDArrayFactory::create(0.0)); TransformBenchmark tbSigmoid(transform::StrictOps::Sigmoid, "sigmoid"); //TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax"); @@ -82,13 +82,13 @@ namespace sd { ParametersBatch batch({&length, &inplace}); auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); - arr->assign(1.0); + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); x.push_back(arr); if(p.getIntParam("inplace") == 1){ z.push_back(arr); } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); + z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); } }; @@ -97,9 +97,9 @@ namespace sd { ScalarBenchmark sbPow(scalar::Ops::Pow, "sPow"); - sbAdd.setY(NDArrayFactory::create_(3.14159265359)); - sbDiv.setY(NDArrayFactory::create_(3.14159265359)); - sbPow.setY(NDArrayFactory::create_(3.14159265359)); + sbAdd.setY(NDArrayFactory::create(3.14159265359)); + sbDiv.setY(NDArrayFactory::create(3.14159265359)); + sbPow.setY(NDArrayFactory::create(3.14159265359)); output += helper.runOperationSuit(&sbAdd, generator, batch, "Scalar Addition - x.add(3.14159265359)"); @@ -122,14 +122,14 @@ namespace sd { ParametersBatch batch({&length, &inplace}); auto generator = PARAMETRIC_XYZ() { - auto arr1 = NDArrayFactory::create_('c', {p.getIntParam("length")}); - auto arr2 = NDArrayFactory::create_('c', {p.getIntParam("length")}); + auto arr1 = NDArrayFactory::create('c', {p.getIntParam("length")}); + auto arr2 = NDArrayFactory::create('c', {p.getIntParam("length")}); x.push_back(arr1); y.push_back(arr2); if(p.getIntParam("inplace") == 1){ z.push_back(arr1); } else { - z.push_back(NDArrayFactory::create_('c', {p.getIntParam("length")})); + z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); } }; @@ -157,8 +157,8 @@ namespace sd { int cols = numElements / rows; bool c = p.getIntParam("cf"); - auto arr = NDArrayFactory::create_(c ? 'c' : 'f', {rows, cols}); - auto arr2 = NDArrayFactory::create_(c ? 'f' : 'c', {rows, cols}); + auto arr = NDArrayFactory::create(c ? 'c' : 'f', {rows, cols}); + auto arr2 = NDArrayFactory::create(c ? 'f' : 'c', {rows, cols}); x.push_back(arr); z.push_back(arr2); }; @@ -176,15 +176,15 @@ namespace sd { bool nchw = p.getIntParam("nchw"); if(nchw) { - auto orig = NDArrayFactory::create_('c', {mb, c, hw, hw}); - orig->permutei({0,2,3,1}); + auto orig = NDArrayFactory::create('c', {mb, c, hw, hw}); + orig.permutei({0,2,3,1}); x.push_back(orig); - z.push_back(NDArrayFactory::create_('c', {mb, hw, hw, c})); + z.push_back(NDArrayFactory::create('c', {mb, hw, hw, c})); } else { - auto orig = NDArrayFactory::create_('c', {mb, hw, hw, c}); - orig->permutei({0,3,1,2}); + auto orig = NDArrayFactory::create('c', {mb, hw, hw, c}); + orig.permutei({0,3,1,2}); x.push_back(orig); - z.push_back(NDArrayFactory::create_('c', {mb, c, hw, hw})); + z.push_back(NDArrayFactory::create('c', {mb, c, hw, hw})); } }; @@ -213,9 +213,9 @@ namespace sd { std::vector shapeB; shapeA = {a, b}; shapeB = {b, c}; - auto A = NDArrayFactory::create_('c', shapeA); - auto B = NDArrayFactory::create_('c', shapeB); - auto C = NDArrayFactory::create_(resultOrder, {a, c}); + auto A = NDArrayFactory::create('c', shapeA); + auto B = NDArrayFactory::create('c', shapeB); + auto C = NDArrayFactory::create(resultOrder, {a, c}); x.push_back(A); y.push_back(B); @@ -246,11 +246,11 @@ namespace sd { ParametersBatch batch({&length}); auto generator = PARAMETRIC_XYZ() { - auto arr = NDArrayFactory::create_('c', {p.getIntParam("length")}); + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); x.push_back(arr); - y.push_back(nullptr); - z.push_back(NDArrayFactory::create_(0.0f)); + y.push_back(NDArray()); + z.push_back(NDArrayFactory::create(0.0f)); }; ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); @@ -267,9 +267,9 @@ namespace sd { auto generator3 = PARAMETRIC_D(){ auto ctx = new Context(1); - ctx->setInputArray(0, NDArrayFactory::create_('c', {p.getIntParam("length")}), true); - ctx->setInputArray(1, NDArrayFactory::create_((Nd4jLong)0), true); - ctx->setOutputArray(0, NDArrayFactory::create_(0), true); + ctx->setInputArray(0, NDArrayFactory::create('c', {p.getIntParam("length")})); + ctx->setInputArray(1, NDArrayFactory::create((Nd4jLong)0)); + ctx->setOutputArray(0, NDArrayFactory::create(0)); return ctx; }; @@ -298,17 +298,17 @@ namespace sd { int rows = p.getIntParam("rows"); int cols = length[i] / rows; int dim = p.getIntParam("dim"); - auto arr = NDArrayFactory::create_('c', {rows, cols}); + auto arr = NDArrayFactory::create('c', {rows, cols}); x.push_back(arr); - y.push_back(NDArrayFactory::create_(dim)); + y.push_back(NDArrayFactory::create(dim)); - NDArray* result; + NDArray result; if(dim == 0){ - result = NDArrayFactory::create_('c', {cols}); + result = NDArrayFactory::create('c', {cols}); } else { - result = NDArrayFactory::create_('c', {rows}); + result = NDArrayFactory::create('c', {rows}); } z.push_back(result); }; @@ -331,22 +331,22 @@ namespace sd { int rows = p.getIntParam("rows"); int cols = length[i] / rows; int dim = p.getIntParam("dim"); - auto arr = NDArrayFactory::create_('c', {rows, cols}); + auto arr = NDArrayFactory::create('c', {rows, cols}); auto dimArg = new Nd4jLong[1]; dimArg[0] = dim; ctx->setIArguments(dimArg, 1); delete[] dimArg; - ctx->setInputArray(0, arr, true); + ctx->setInputArray(0, arr); - NDArray* result; + NDArray result; if(dim == 0){ - result = NDArrayFactory::create_('c', {cols}); + result = NDArrayFactory::create('c', {cols}); } else { - result = NDArrayFactory::create_('c', {rows}); + result = NDArrayFactory::create('c', {rows}); } - ctx->setOutputArray(0, result, true); + ctx->setOutputArray(0, result); return ctx; }; @@ -382,22 +382,22 @@ namespace sd { int khw = p.getIntParam("k"); if (n == 0) { - auto input = NDArrayFactory::create_('c', {8, 3, hw, hw}); - auto output = NDArrayFactory::create_('c', {8, 3, hw, hw}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); + auto input = NDArrayFactory::create('c', {8, 3, hw, hw}); + auto output = NDArrayFactory::create('c', {8, 3, hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } else { - auto input = NDArrayFactory::create_('c', {8, hw, hw, 3}); - auto output = NDArrayFactory::create_('c', {8, hw, hw, 3}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); + auto input = NDArrayFactory::create('c', {8, hw, hw, 3}); + auto output = NDArrayFactory::create('c', {8, hw, hw, 3}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } - auto b = NDArrayFactory::create_('c', {3}); - auto w = NDArrayFactory::create_('c', {khw, khw, 3, 3}); // [kH, kW, iC, oC] always + auto b = NDArrayFactory::create('c', {3}); + auto w = NDArrayFactory::create('c', {khw, khw, 3, 3}); // [kH, kW, iC, oC] always - ctx->setInputArray(1, w, true); - ctx->setInputArray(2, b, true); + ctx->setInputArray(1, w); + ctx->setInputArray(2, b); auto args = new Nd4jLong[10]; args[0] = args[1] = khw; //Kernel @@ -437,15 +437,15 @@ namespace sd { int khw = p.getIntParam("k"); if (n == 0) { - auto input = NDArrayFactory::create_('c', {8, c, hw, hw}); - auto output = NDArrayFactory::create_('c', {8, c, hw, hw}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); + auto input = NDArrayFactory::create('c', {8, c, hw, hw}); + auto output = NDArrayFactory::create('c', {8, c, hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } else { - auto input = NDArrayFactory::create_('c', {8, hw, hw, c}); - auto output = NDArrayFactory::create_('c', {8, hw, hw, c}); - ctx->setInputArray(0, input, true); - ctx->setOutputArray(0, output, true); + auto input = NDArrayFactory::create('c', {8, hw, hw, c}); + auto output = NDArrayFactory::create('c', {8, hw, hw, c}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } auto args = new Nd4jLong[11]; @@ -494,46 +494,46 @@ namespace sd { int m = p.getIntParam("mb"); Nd4jLong l = 0; - ctx->setInputArray(0, NDArrayFactory::create_(l), true); //Max TS length (unused) + ctx->setInputArray(0, NDArrayFactory::create(l)); //Max TS length (unused) if (f == 0) { //TNS format - ctx->setInputArray(1, NDArrayFactory::create_('c', {seqLength, m, n}), true); //x - ctx->setOutputArray(0, NDArrayFactory::create_('c', {seqLength, m, n}), true); //i - ctx->setOutputArray(1, NDArrayFactory::create_('c', {seqLength, m, n}), true); //c - ctx->setOutputArray(2, NDArrayFactory::create_('c', {seqLength, m, n}), true); //f - ctx->setOutputArray(3, NDArrayFactory::create_('c', {seqLength, m, n}), true); //o - ctx->setOutputArray(4, NDArrayFactory::create_('c', {seqLength, m, n}), true); //z - ctx->setOutputArray(5, NDArrayFactory::create_('c', {seqLength, m, n}), true); //h - ctx->setOutputArray(6, NDArrayFactory::create_('c', {seqLength, m, n}), true); //y + ctx->setInputArray(1, NDArrayFactory::create('c', {seqLength, m, n})); //x + ctx->setOutputArray(0, NDArrayFactory::create('c', {seqLength, m, n})); //i + ctx->setOutputArray(1, NDArrayFactory::create('c', {seqLength, m, n})); //c + ctx->setOutputArray(2, NDArrayFactory::create('c', {seqLength, m, n})); //f + ctx->setOutputArray(3, NDArrayFactory::create('c', {seqLength, m, n})); //o + ctx->setOutputArray(4, NDArrayFactory::create('c', {seqLength, m, n})); //z + ctx->setOutputArray(5, NDArrayFactory::create('c', {seqLength, m, n})); //h + ctx->setOutputArray(6, NDArrayFactory::create('c', {seqLength, m, n})); //y } else { //NST format - ctx->setInputArray(1, NDArrayFactory::create_('f', {m, n, seqLength}), true); //x - ctx->setOutputArray(0, NDArrayFactory::create_('f', {m, n, seqLength}), true); //i - ctx->setOutputArray(1, NDArrayFactory::create_('f', {m, n, seqLength}), true); //c - ctx->setOutputArray(2, NDArrayFactory::create_('f', {m, n, seqLength}), true); //f - ctx->setOutputArray(3, NDArrayFactory::create_('f', {m, n, seqLength}), true); //o - ctx->setOutputArray(4, NDArrayFactory::create_('f', {m, n, seqLength}), true); //z - ctx->setOutputArray(5, NDArrayFactory::create_('f', {m, n, seqLength}), true); //h - ctx->setOutputArray(6, NDArrayFactory::create_('f', {m, n, seqLength}), true); //y + ctx->setInputArray(1, NDArrayFactory::create('f', {m, n, seqLength})); //x + ctx->setOutputArray(0, NDArrayFactory::create('f', {m, n, seqLength})); //i + ctx->setOutputArray(1, NDArrayFactory::create('f', {m, n, seqLength})); //c + ctx->setOutputArray(2, NDArrayFactory::create('f', {m, n, seqLength})); //f + ctx->setOutputArray(3, NDArrayFactory::create('f', {m, n, seqLength})); //o + ctx->setOutputArray(4, NDArrayFactory::create('f', {m, n, seqLength})); //z + ctx->setOutputArray(5, NDArrayFactory::create('f', {m, n, seqLength})); //h + ctx->setOutputArray(6, NDArrayFactory::create('f', {m, n, seqLength})); //y } - auto cLast = NDArrayFactory::create_('c', {m, n}); - auto yLast = NDArrayFactory::create_('c', {m, n}); - auto W = NDArrayFactory::create_('c', {2 * n, 4 * n}); - auto Wci = NDArrayFactory::create_('c', {n}); - auto Wcf = NDArrayFactory::create_('c', {n}); - auto Wco = NDArrayFactory::create_('c', {n}); - auto b = NDArrayFactory::create_('c', {4 * n}); - - ctx->setInputArray(2, cLast, true); - ctx->setInputArray(3, yLast, true); - ctx->setInputArray(4, W, true); - ctx->setInputArray(5, Wci, true); - ctx->setInputArray(6, Wcf, true); - ctx->setInputArray(7, Wco, true); - ctx->setInputArray(8, b, true); + auto cLast = NDArrayFactory::create('c', {m, n}); + auto yLast = NDArrayFactory::create('c', {m, n}); + auto W = NDArrayFactory::create('c', {2 * n, 4 * n}); + auto Wci = NDArrayFactory::create('c', {n}); + auto Wcf = NDArrayFactory::create('c', {n}); + auto Wco = NDArrayFactory::create('c', {n}); + auto b = NDArrayFactory::create('c', {4 * n}); + + ctx->setInputArray(2, cLast); + ctx->setInputArray(3, yLast); + ctx->setInputArray(4, W); + ctx->setInputArray(5, Wci); + ctx->setInputArray(6, Wcf); + ctx->setInputArray(7, Wco); + ctx->setInputArray(8, b); auto iargs = new Nd4jLong[2]; iargs[0] = 0; //No peephole @@ -566,20 +566,20 @@ namespace sd { auto generator = PARAMETRIC_D() { auto a = p.getIntParam("axis"); - auto arr = NDArrayFactory::create_('c', {rows, p.getIntParam("cols")}); + auto arr = NDArrayFactory::create('c', {rows, p.getIntParam("cols")}); auto ctx = new Context(1); - ctx->setInputArray(0, arr, true); + ctx->setInputArray(0, arr); if(a == 0){ - ctx->setInputArray(1, NDArrayFactory::create_('c', {rows, 1}), true); + ctx->setInputArray(1, NDArrayFactory::create('c', {rows, 1})); } else { - ctx->setInputArray(1, NDArrayFactory::create_('c', {1, p.getIntParam("cols")}), true); + ctx->setInputArray(1, NDArrayFactory::create('c', {1, p.getIntParam("cols")})); } if (p.getIntParam("inplace") == 1) { ctx->setOutputArray(0, arr); ctx->markInplace(true); } else { - ctx->setOutputArray(0, NDArrayFactory::create_('c', {rows, p.getIntParam("cols")}), true); + ctx->setOutputArray(0, NDArrayFactory::create('c', {rows, p.getIntParam("cols")})); } return ctx; }; diff --git a/libnd4j/include/system/op_boilerplate.h b/libnd4j/include/system/op_boilerplate.h index 92aeb9218445..251babee9ddb 100644 --- a/libnd4j/include/system/op_boilerplate.h +++ b/libnd4j/include/system/op_boilerplate.h @@ -1511,11 +1511,11 @@ #define CHECK_STASH(NAME) block.getStash()->checkStash(block.getNodeId(), NAME); #define UNSTASH(NAME) block.getStash()->extractArray(block.getNodeId(), NAME); -#define INPUT_VARIABLE(INDEX) block.array(INDEX) +#define INPUT_VARIABLE(INDEX) block.array(INDEX).get() #define OUTPUT_VARIABLE(INDEX) reinterpret_cast(this->getZ(block, INDEX)) #define OUTPUT_NULLIFIED(INDEX) reinterpret_cast(this->getNullifiedZ(block, INDEX)) -#define INPUT_LIST(INDEX) reinterpret_cast(block.getVariable(INDEX)->getNDArrayList()) +#define INPUT_LIST(INDEX) reinterpret_cast(block.getVariable(INDEX)->getNDArrayList().get()) #define D_ARG(INDEX) block.getDArguments().at(INDEX) #define INT_ARG(INDEX) block.getIArguments().at(INDEX) @@ -1524,7 +1524,7 @@ #define B_ARG(INDEX) block.getBArguments().at(INDEX) -#define COPY_SHAPE(SRC, TGT) TGT = ShapeBuilders::copyShapeInfo(SRC, true, block.getWorkspace()) +#define COPY_SHAPE(SRC, TGT) TGT = ShapeBuilders::copyShapeInfo(SRC, true, block.workspace()) #define COPY_SHAPE_EX(SRC, TGT, WORKSPACE) TGT = ShapeBuilders::copyShapeInfo(SRC, true, WORKSPACE) diff --git a/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp index 199cb88eb0d5..74d4cdb44b38 100644 --- a/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp @@ -143,6 +143,6 @@ TEST_F(BooleanOpsTests, test_where_1) { //z->printIndexedBuffer("z"); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } diff --git a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp index 51c6e2375550..d9f1c760f91e 100644 --- a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp @@ -594,7 +594,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_3) { auto z = result.at(0); ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, broadcast_empty_4) { @@ -611,7 +611,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_4) { auto z = result.at(0); ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_TRUE(e.equalsTo(z)); } @@ -629,7 +629,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_5) { auto z = result.at(0); ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_TRUE(e.equalsTo(z)); } @@ -647,7 +647,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_6) { auto z = result.at(0); ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_TRUE(e.equalsTo(z)); } @@ -665,7 +665,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_7) { auto z = result.at(0); ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_TRUE(e.equalsTo(z)); } @@ -700,7 +700,7 @@ TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) { ASSERT_EQ(Status::OK(), result.status()); ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, broadcast_bool_1) { diff --git a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp index 571efc1afb04..4baf1278b696 100644 --- a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp @@ -34,11 +34,11 @@ class ContextTests : public testing::Test { TEST_F(ContextTests, Basic_Test_1) { VariableSpace variableSpace; - auto _20 = NDArrayFactory::create_('c', {2, 2}); - auto _21 = NDArrayFactory::create_('c', {2, 2}); + auto _20 = NDArrayFactory::create('c', {2, 2}); + auto _21 = NDArrayFactory::create('c', {2, 2}); - _20->assign(1.0f); - _21->assign(2.0f); + _20.assign(1.0f); + _21.assign(2.0f); variableSpace.putVariable(2, 0, _20); variableSpace.putVariable(2, 1, _21); @@ -62,11 +62,11 @@ TEST_F(ContextTests, Basic_Test_1) { TEST_F(ContextTests, Basic_Test_2) { VariableSpace variableSpace; - auto _20 = NDArrayFactory::create_('c', {2, 2}); - auto _21 = NDArrayFactory::create_('c', {2, 2}); + auto _20 = NDArrayFactory::create('c', {2, 2}); + auto _21 = NDArrayFactory::create('c', {2, 2}); - _20->assign(1.0f); - _21->assign(2.0f); + _20.assign(1.0f); + _21.assign(2.0f); variableSpace.putVariable(-1, _20); variableSpace.putVariable(-2, _21); @@ -92,7 +92,7 @@ TEST_F(ContextTests, Basic_Test_3) { Context ctx(1, &variableSpace); - auto _20 = NDArrayFactory::create_('c', {2, 2}); + auto _20 = NDArrayFactory::create('c', {2, 2}); ctx.pushNDArrayToVariableSpace(1, 1, _20); @@ -105,11 +105,11 @@ TEST_F(ContextTests, Basic_Test_4) { Context ctx(1, &variableSpace); - auto _20 = NDArrayFactory::create_('c', {2, 2}); - _20->linspace(1); + auto _20 = NDArrayFactory::create('c', {2, 2}); + _20.linspace(1); - auto _21 = NDArrayFactory::create_('c', {2, 2}); - _21->linspace(10); + auto _21 = NDArrayFactory::create('c', {2, 2}); + _21.linspace(10); ctx.pushNDArrayToVariableSpace(1, 1, _20); @@ -127,10 +127,10 @@ TEST_F(ContextTests, Basic_Test_5) { Context ctx(1, &variableSpace); - auto _20 = NDArrayFactory::create_('c', {2, 2}); - _20->linspace(1); + auto _20 = NDArrayFactory::create('c', {2, 2}); + _20.linspace(1); - auto exp = new NDArray(_20->dup()); + auto exp = _20.dup(); ctx.pushNDArrayToVariableSpace(1, 1, _20); @@ -140,11 +140,7 @@ TEST_F(ContextTests, Basic_Test_5) { auto vA = ctx.variable(1, 1); - ASSERT_TRUE(vA->getNDArray() == _20); - ASSERT_TRUE(vA->getNDArray()->equalsTo(exp)); - - delete exp; } @@ -185,11 +181,11 @@ TEST_F(ContextTests, Basic_Test_7) { ASSERT_TRUE(v1 == var1); - auto _10 = NDArrayFactory::create_('c', {2, 2}); - _10->linspace(1); + auto _10 = NDArrayFactory::create('c', {2, 2}); + _10.linspace(1); - auto _11 = NDArrayFactory::create_('c', {2, 2}); - _11->linspace(10); + auto _11 = NDArrayFactory::create('c', {2, 2}); + _11.linspace(10); ctx.pushNDArrayToVariableSpace(1, 0, _10); ctx.pushNDArrayToVariableSpace(1, 1, _11); @@ -206,11 +202,11 @@ TEST_F(ContextTests, Basic_Test_8) { Context ctx(1, &variableSpace); - auto _10 = NDArrayFactory::create_('c', {2, 2}); - _10->linspace(1); + auto _10 = NDArrayFactory::create('c', {2, 2}); + _10.linspace(1); - auto _11 = NDArrayFactory::create_('c', {2, 2}); - _11->linspace(10); + auto _11 = NDArrayFactory::create('c', {2, 2}); + _11.linspace(10); ctx.pushNDArrayToVariableSpace(1, 0, _10); ctx.pushNDArrayToVariableSpace(1, 1, _11); @@ -232,7 +228,7 @@ TEST_F(ContextTests, Basic_Test_9) { auto in = NDArrayFactory::create('c', {5, 5}); Context ctx(1, &variableSpace, true); - ctx.pushNDArrayToVariableSpace(1, 1, &in, false); + ctx.pushNDArrayToVariableSpace(1, 1, in); } TEST_F(ContextTests, Basic_Test_10) { diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index d33ed6fff780..87c5add77999 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -60,18 +60,18 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_1) { int bS=1, iH=5,iW=4, iC=2,oC=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; TypeParam _expB[]{664.0, 700.0, 736.0, 344.0, 808.0, 844.0, 880.0, 408.0, 952.0, 988.0, 1024.0, 472.0, 1096.0, 1132.0, 1168.0, 536.0, 466.0, 480.0, 494.0, 220.0, 1528.0, 1628.0, 1728.0, 856.0, 1928.0, 2028.0, 2128.0, 1048.0, 2328.0, 2428.0, 2528.0, 1240.0, 2728.0, 2828.0, 2928.0, 1432.0, 1346.0, 1392.0, 1438.0, 700.0, 2392.0, 2556.0, 2720.0, 1368.0, 3048.0, 3212.0, 3376.0, 1688.0, 3704.0, 3868.0, 4032.0, 2008.0, 4360.0, 4524.0, 4688.0, 2328.0, 2226.0, 2304.0, 2382.0, 1180.0}; Nd4jLong _expS[]{4, 1, 3, 5, 4, 60, 20, 4, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - auto input = NDArrayFactory::create_('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create_('c', {oC, iC, kH, kW}); - for (int e = 0; e < input->lengthOf(); e++) - input->p(e, e + 1); + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); + for (int e = 0; e < input.lengthOf(); e++) + input.p(e, e + 1); - for (int e = 0; e < weights->lengthOf(); e++) - weights->p(e, e + 1); - weights->permutei({2,3,1,0}); + for (int e = 0; e < weights.lengthOf(); e++) + weights.p(e, e + 1); + weights.permutei({2,3,1,0}); // weights->printShapeInfo("weights"); - ArrayOptions::setDataType(_expS, input->dataType()); + ArrayOptions::setDataType(_expS, input.dataType()); auto exp = new NDArray(_expB, _expS); auto variableSpace = new VariableSpace(); @@ -264,7 +264,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_7) { sd::ops::conv2d op; auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); @@ -393,14 +393,14 @@ TEST_F(ConvolutionTests1, sconv2d_1) { int iX = 10; int B = 2; - auto input = NDArrayFactory::create_('c', {B, iC, iY, iX}); - for (int e = 0; e < input->lengthOf(); e++) - input->p(e, e+1); + auto input = NDArrayFactory::create('c', {B, iC, iY, iX}); + for (int e = 0; e < input.lengthOf(); e++) + input.p(e, e+1); - auto weights = NDArrayFactory::create_('c', {oC, iC, kY, kX}); - for (int e = 0; e < weights->lengthOf(); e++) - weights->p(e, e+1); - weights->permutei({2,3,1,0}); + auto weights = NDArrayFactory::create('c', {oC, iC, kY, kX}); + for (int e = 0; e < weights.lengthOf(); e++) + weights.p(e, e+1); + weights.permutei({2,3,1,0}); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, input); @@ -434,11 +434,11 @@ TEST_F(ConvolutionTests1, sconv2d_1) { //exp.printShapeInfo("Expected shape"); //output->printShapeInfo("Result shape"); - ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.isSameShape(*output)); //exp.printBuffer("Expctd buffer"); //output->printBuffer("Result buffer"); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.equalsTo(*output)); delete block; delete variableSpace; @@ -482,11 +482,11 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_2) { //z->printShapeInfo("FF shape"); - ASSERT_TRUE(z->isSameShape(&expFF)); + ASSERT_TRUE(z.isSameShape(&expFF)); //expFF.printBuffer("e"); //z->printBuffer("z"); - ASSERT_TRUE(z->equalsTo(&expFF, 1e-3)); + ASSERT_TRUE(z.equalsTo(&expFF, 1e-3)); } TYPED_TEST(TypedConvolutionTests1, sconv2d_3) { @@ -547,7 +547,7 @@ TEST_F(ConvolutionTests1, sconv2d_4) { sd::ops::sconv2d op; auto results = op.evaluate({&input, &weightsD, &weightsP, &biases}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); @@ -720,19 +720,19 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) { auto z = resultFF.at(0); - ASSERT_TRUE(z->isSameShape(&expFF)); - ASSERT_TRUE(z->equalsTo(&expFF, 1)); + ASSERT_TRUE(z.isSameShape(&expFF)); + ASSERT_TRUE(z.equalsTo(&expFF, 1)); sd::ops::conv2d op2d; // weightsP.printShapeInfo(); - auto result2D = op2d.evaluate({z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); + auto result2D = op2d.evaluate({&z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); auto z2d = result2D.at(0); // z2d->printBuffer(); - ASSERT_TRUE(z2d->isSameShape(&exp2FF)); - ASSERT_TRUE(z2d->equalsTo(&exp2FF)); + ASSERT_TRUE(z2d.isSameShape(&exp2FF)); + ASSERT_TRUE(z2d.equalsTo(&exp2FF)); } TEST_F(ConvolutionTests1, deconv2d_bp_1) { @@ -898,7 +898,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) { sd::ops::conv1d_bp op_bp; - auto epsilonNxt = new NDArray(z->dup()); + auto epsilonNxt = new NDArray(z.dup()); epsilonNxt->linspace(1); auto result_BP = op_bp.evaluate({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 1, 0, 0}); @@ -1635,9 +1635,9 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) { sd::ops::conv3dnew_bp op; auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* gradI = results.at(0); - auto* gradW = results.at(1); - auto* gradB = results.at(2); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); ASSERT_EQ(Status::OK(), results.status()); @@ -1783,7 +1783,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test1) { sd::ops::conv3dnew op; auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expected.isSameShape(output)); @@ -1809,7 +1809,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test2) { sd::ops::conv3dnew op; auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expected.isSameShape(output)); @@ -1834,7 +1834,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test3) { sd::ops::conv3dnew op; auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expected.isSameShape(output)); @@ -1861,7 +1861,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test4) { sd::ops::conv3dnew op; auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); // output->printIndexedBuffer(); @@ -1889,7 +1889,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test5) { sd::ops::conv3dnew op; auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); // output->printIndexedBuffer(); @@ -1919,7 +1919,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test6) { sd::ops::conv3dnew op; auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); // output->printIndexedBuffer(); @@ -1948,7 +1948,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test7) { sd::ops::conv3dnew op; auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expected.isSameShape(output)); @@ -2032,7 +2032,7 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test10) { sd::ops::conv3dnew op; auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); @@ -2055,10 +2055,10 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test11) { sd::ops::conv3dnew op; auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(output->isSameShape(&expected)); + ASSERT_TRUE(output.isSameShape(&expected)); } @@ -2154,7 +2154,7 @@ TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) { sd::ops::pointwise_conv2d op; auto results = op.evaluate({&input, &weights, &bias}, {}, {dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expOutput.isSameShape(output)); @@ -2274,7 +2274,7 @@ TEST_F(ConvolutionTests1, upsampling2d_test1) { sd::ops::upsampling2d op; auto results = op.evaluate({&input}, {factorH, factorW, isNCHW}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expOutput.isSameShape(output)); @@ -2301,7 +2301,7 @@ TEST_F(ConvolutionTests1, upsampling2d_test2) { sd::ops::upsampling2d op; auto results = op.evaluate({&input}, {factorH, factorW, isNCHW}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); @@ -2338,7 +2338,7 @@ TEST_F(ConvolutionTests1, upsampling3d_test1) { sd::ops::upsampling3d op; auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expOutput.isSameShape(output)); @@ -2371,7 +2371,7 @@ TEST_F(ConvolutionTests1, upsampling3d_test2) { sd::ops::upsampling3d op; auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); @@ -2397,7 +2397,7 @@ TEST_F(ConvolutionTests1, upsampling3d_bp_test1) { sd::ops::upsampling3d_bp op; auto results = op.evaluate({&input, &gradO}, {isNCDHW}); - auto* gradI = results.at(0); + auto gradI = results.at(0); ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expGradI.isSameShape(gradI)); @@ -2486,7 +2486,7 @@ TEST_F(ConvolutionTests1, upsampling3d_bp_test3) { sd::ops::upsampling3d_bp op; auto results = op.evaluate({&input, &gradO}, {isNCDHW}); - auto* gradI = results.at(0); + auto gradI = results.at(0); ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expGradI.isSameShape(gradI)); diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index 80ec596a7805..2bd6c5445cc2 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -141,7 +141,7 @@ TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_1) { auto result = op.evaluate({&input0, &input1, &input2}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 0, 1}); ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(exp, *result.at(0)); + ASSERT_EQ(exp, result.at(0)); } @@ -249,22 +249,22 @@ TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_1) { //_gradWP->printBuffer("gradWP"); - ASSERT_TRUE(_gradWP->isSameShape(&expGWP)); - ASSERT_TRUE(_gradWP->isSameShape(&weightsP)); + ASSERT_TRUE(_gradWP.isSameShape(&expGWP)); + ASSERT_TRUE(_gradWP.isSameShape(&weightsP)); - ASSERT_TRUE(_gradWP->equalsTo(&expGWP)); + ASSERT_TRUE(_gradWP.equalsTo(&expGWP)); //_gradWD->printShapeInfo("gradWD shape"); - ASSERT_TRUE(_gradWD->isSameShape(&expGWD)); - ASSERT_TRUE(_gradWD->isSameShape(&weightsD)); + ASSERT_TRUE(_gradWD.isSameShape(&expGWD)); + ASSERT_TRUE(_gradWD.isSameShape(&weightsD)); // _gradWD->printIndexedBuffer(); - ASSERT_TRUE(_gradWD->equalsTo(&expGWD)); + ASSERT_TRUE(_gradWD.equalsTo(&expGWD)); - ASSERT_TRUE(_epsilon->isSameShape(&input)); - ASSERT_TRUE(_epsilon->isSameShape(&expE)); + ASSERT_TRUE(_epsilon.isSameShape(&input)); + ASSERT_TRUE(_epsilon.isSameShape(&expE)); - ASSERT_TRUE(_epsilon->equalsTo(&expE)); + ASSERT_TRUE(_epsilon.equalsTo(&expE)); } @@ -369,8 +369,8 @@ TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_4) { sd::ops::sconv2d_bp op; auto results = op.evaluate({&input, &gradO, &weightsDepth, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* gradI = results.at(0); - auto* gradWD = results.at(1); + auto gradI = results.at(0); + auto gradWD = results.at(1); ASSERT_EQ(Status::OK(), results.status()); @@ -899,7 +899,7 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test6) { ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_1) { - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); // auto z('c',{bS,iD,oH,oW}); @@ -917,7 +917,7 @@ TEST_F(ConvolutionTests2, maxpool2d_1) { auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); // result.printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.isSameShape(*result)); delete variableSpace; delete block; @@ -942,7 +942,7 @@ TEST_F(ConvolutionTests2, maxpool2d_2) { const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); // auto z('c',{bS,iD,oH,oW}); @@ -960,7 +960,7 @@ TEST_F(ConvolutionTests2, maxpool2d_2) { auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); // result.printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.isSameShape(*result)); delete variableSpace; delete block; @@ -985,7 +985,7 @@ TEST_F(ConvolutionTests2, maxpool2d_3) { const int oW = (int) sd::math::nd4j_ceil(iW * 1.f / sW); - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); // auto z('c',{bS,iD,oH,oW}); @@ -1003,7 +1003,7 @@ TEST_F(ConvolutionTests2, maxpool2d_3) { auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); // result.printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.isSameShape(*result)); delete variableSpace; delete block; @@ -1028,7 +1028,7 @@ TEST_F(ConvolutionTests2, maxpool2d_4) { const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); // auto z('c',{bS,iD,oH,oW}); @@ -1046,7 +1046,7 @@ TEST_F(ConvolutionTests2, maxpool2d_4) { auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); // result.printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.isSameShape(*result)); delete variableSpace; delete block; @@ -1071,7 +1071,7 @@ TEST_F(ConvolutionTests2, maxpool2d_5) { const int oW = (int) sd::math::nd4j_ceil(iW * 1.f / sW); - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); // auto z('c',{bS,iD,oH,oW}); @@ -1089,7 +1089,7 @@ TEST_F(ConvolutionTests2, maxpool2d_5) { auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); // result.printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.isSameShape(*result)); delete variableSpace; delete block; @@ -1173,7 +1173,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_9) { auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(output->isSameShape({bS, iC, oH, oW})); + ASSERT_TRUE(output.isSameShape({bS, iC, oH, oW})); } @@ -1196,7 +1196,7 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_10) { sd::ops::maxpool2d op; auto results = op.evaluate({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); @@ -1686,8 +1686,8 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test4) { ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_bp_1) { - auto input = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto epsilon = NDArrayFactory::create_('c', {bS,iD,oH,oW}); + auto input = NDArrayFactory::create('c', {bS,iD,iH,iW}); + auto epsilon = NDArrayFactory::create('c', {bS,iD,oH,oW}); auto exp = NDArrayFactory::create('c', {bS,iD,iH,iW}); auto variableSpace = new VariableSpace(); @@ -1705,7 +1705,7 @@ TEST_F(ConvolutionTests2, maxpool2d_bp_1) { ASSERT_EQ(ND4J_STATUS_OK, status); auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.isSameShape(*result)); delete variableSpace; delete block; @@ -1869,8 +1869,8 @@ TEST_F(ConvolutionTests2, maxpool2d_bp_7) { ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, avgpool2d_bp_1) { - auto input = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto epsilon = NDArrayFactory::create_('c', {bS,iD,oH,oW}); + auto input = NDArrayFactory::create('c', {bS,iD,iH,iW}); + auto epsilon = NDArrayFactory::create('c', {bS,iD,oH,oW}); auto exp = NDArrayFactory::create('c', {bS,iD,iH,iW}); auto variableSpace = new VariableSpace(); @@ -1888,7 +1888,7 @@ TEST_F(ConvolutionTests2, avgpool2d_bp_1) { ASSERT_EQ(ND4J_STATUS_OK, status); auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.isSameShape(*result)); delete variableSpace; delete block; @@ -2039,8 +2039,8 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_6) { ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, pnormpool2d_bp_1) { - auto input = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto epsilon = NDArrayFactory::create_('c', {bS,iD,oH,oW}); + auto input = NDArrayFactory::create('c', {bS,iD,iH,iW}); + auto epsilon = NDArrayFactory::create('c', {bS,iD,oH,oW}); auto exp = NDArrayFactory::create('c', {bS,iD,iH,iW}); auto variableSpace = new VariableSpace(); @@ -2059,7 +2059,7 @@ TEST_F(ConvolutionTests2, pnormpool2d_bp_1) { ASSERT_EQ(ND4J_STATUS_OK, status); auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.isSameShape(*result)); delete variableSpace; delete block; @@ -2146,7 +2146,7 @@ TEST_F(ConvolutionTests2, upsampling2d_bp_1) { sd::ops::upsampling2d_bp op; auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); - auto* gradI = results.at(0); + auto gradI = results.at(0); ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expGradI.isSameShape(gradI)); @@ -2170,7 +2170,7 @@ TEST_F(ConvolutionTests2, upsampling2d_bp_2) { sd::ops::upsampling2d_bp op; auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); - auto* gradI = results.at(0); + auto gradI = results.at(0); ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expGradI.isSameShape(gradI)); @@ -2200,7 +2200,7 @@ TEST_F(ConvolutionTests2, upsampling2d_bp_3) { sd::ops::upsampling2d_bp op; auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); - auto* gradI = results.at(0); + auto gradI = results.at(0); ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expGradI.isSameShape(gradI)); @@ -2230,7 +2230,7 @@ TYPED_TEST(TypedConvolutionTests2, depthwise_conv2d_1) { sd::ops::depthwise_conv2d op; auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); @@ -2258,7 +2258,7 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_2) { sd::ops::depthwise_conv2d op; auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); @@ -2289,7 +2289,7 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_3) { sd::ops::depthwise_conv2d op; auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); @@ -2372,7 +2372,7 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_6) { sd::ops::depthwise_conv2d op; auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - NDArray* output = results.at(0); + auto output = results.at(0); // output.printIndexedBuffer(); ASSERT_EQ(Status::OK(), results.status()); @@ -2405,7 +2405,7 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_7) { sd::ops::depthwise_conv2d op; auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); @@ -2511,7 +2511,7 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_10) { sd::ops::depthwise_conv2d op; auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto* output = results.at(0); + auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); @@ -2583,8 +2583,8 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test1) { sd::ops::depthwise_conv2d_bp op; auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* gradI = results.at(0); - auto* gradW = results.at(1); + auto gradI = results.at(0); + auto gradW = results.at(1); ASSERT_EQ(Status::OK(), results.status()); @@ -2620,8 +2620,8 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test2) { sd::ops::depthwise_conv2d_bp op; auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* gradI = results.at(0); - auto* gradW = results.at(1); + auto gradI = results.at(0); + auto gradW = results.at(1); ASSERT_EQ(Status::OK(), results.status()); @@ -2686,9 +2686,9 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test4) { sd::ops::depthwise_conv2d_bp op; auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - NDArray* gradI = results.at(0); - NDArray* gradW = results.at(1); - NDArray* gradB = results.at(2); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); ASSERT_EQ(Status::OK(), results.status()); @@ -2740,9 +2740,9 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test5) { sd::ops::depthwise_conv2d_bp op; auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - NDArray* gradI = results.at(0); - NDArray* gradW = results.at(1); - NDArray* gradB = results.at(2); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); ASSERT_EQ(Status::OK(), results.status()); @@ -2782,8 +2782,8 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test6) { sd::ops::depthwise_conv2d_bp op; auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto* gradI = results.at(0); - auto* gradW = results.at(1); + auto gradI = results.at(0); + auto gradW = results.at(1); ASSERT_EQ(Status::OK(), results.status()); @@ -2821,8 +2821,8 @@ TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test7) { sd::ops::depthwise_conv2d_bp op; auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto* gradI = results.at(0); - auto* gradW = results.at(1); + auto gradI = results.at(0); + auto gradW = results.at(1); ASSERT_EQ(Status::OK(), results.status()); diff --git a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp index 83f3a15f58ef..be6440b707e4 100644 --- a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp @@ -131,9 +131,9 @@ TEST_F(DataTypesValidationTests, test_bits_hamming_distance_1) { auto z = NDArrayFactory::create(0); Context ctx(1); - ctx.setInputArray(0, &x); - ctx.setInputArray(1, &y); - ctx.setOutputArray(0, &z); + ctx.setInputArray(0, x); + ctx.setInputArray(1, y); + ctx.setOutputArray(0, z); sd::ops::bits_hamming_distance op; auto status = op.execute(&ctx); @@ -146,9 +146,9 @@ TEST_F(DataTypesValidationTests, test_bits_hamming_distance_2) { auto z = NDArrayFactory::create(0); Context ctx(1); - ctx.setInputArray(0, &x); - ctx.setInputArray(1, &y); - ctx.setOutputArray(0, &z); + ctx.setInputArray(0, x); + ctx.setInputArray(1, y); + ctx.setOutputArray(0, z); sd::ops::bits_hamming_distance op; auto status = op.execute(&ctx); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 1f7b9885b91e..649d4db6ab3d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -97,17 +97,17 @@ TEST_F(DeclarableOpsTests1, BasicInitialization1) { std::string expName("concat"); ASSERT_EQ(expName, concat->getOpName()); - auto x0 = NDArrayFactory::create_('c', { 1, 5 }); - auto x1 = NDArrayFactory::create_('c', { 1, 5 }); - auto x2 = NDArrayFactory::create_('c', { 1, 5 }); - auto x3 = NDArrayFactory::create_('c', { 1, 5 }); - auto x4 = NDArrayFactory::create_('c', { 1, 5 }); - - x0->assign(1.0f); - x1->assign(1.0f); - x2->assign(1.0f); - x3->assign(1.0f); - x4->assign(1.0f); + auto x0 = NDArrayFactory::create('c', { 1, 5 }); + auto x1 = NDArrayFactory::create('c', { 1, 5 }); + auto x2 = NDArrayFactory::create('c', { 1, 5 }); + auto x3 = NDArrayFactory::create('c', { 1, 5 }); + auto x4 = NDArrayFactory::create('c', { 1, 5 }); + + x0.assign(1.0f); + x1.assign(1.0f); + x2.assign(1.0f); + x3.assign(1.0f); + x4.assign(1.0f); auto variableSpace = new VariableSpace(); @@ -117,7 +117,7 @@ TEST_F(DeclarableOpsTests1, BasicInitialization1) { variableSpace->putVariable(-4, x3); variableSpace->putVariable(-5, x4); - auto nodeVar = new Variable(); + auto nodeVar = std::make_shared(); variableSpace->putVariable(1, nodeVar); @@ -165,7 +165,7 @@ TEST_F(DeclarableOpsTests1, ApplyGradientDescent_1) { ASSERT_EQ(result.status(), ND4J_STATUS_OK); auto z = result.at(0); - ASSERT_TRUE(z->equalsTo(exp)); + ASSERT_TRUE(z.equalsTo(exp)); } @@ -179,7 +179,7 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_1) { ASSERT_EQ(result.status(), ND4J_STATUS_OK); auto z = result.at(0); - ASSERT_TRUE(z->equalsTo(exp)); + ASSERT_TRUE(z.equalsTo(exp)); } @@ -196,8 +196,8 @@ TEST_F(DeclarableOpsTests1, AssignBroadcastTest_2) { auto z1 = result.at(0); auto z2 = result.at(1); - ASSERT_TRUE(z1->equalsTo(exp1)); - ASSERT_TRUE(z2->equalsTo(exp2)); + ASSERT_TRUE(z1.equalsTo(exp1)); + ASSERT_TRUE(z2.equalsTo(exp2)); } @@ -212,7 +212,7 @@ TEST_F(DeclarableOpsTests1, AXpY_Test_1) { ASSERT_EQ(result.status(), ND4J_STATUS_OK); auto z = result.at(0); - ASSERT_TRUE(z->equalsTo(exp)); + ASSERT_TRUE(z.equalsTo(exp)); } @@ -253,7 +253,7 @@ TEST_F(DeclarableOpsTests1, TestTensorMmul1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* out = results.at(0); + auto out = results.at(0); ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.equalsTo(out)); @@ -273,7 +273,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* out = results.at(0); + auto out = results.at(0); ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.equalsTo(out)); @@ -293,7 +293,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* out = results.at(0); + auto out = results.at(0); ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.equalsTo(out)); @@ -313,7 +313,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* out = results.at(0); + auto out = results.at(0); ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.equalsTo(out)); @@ -333,7 +333,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -355,7 +355,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -376,7 +376,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -397,7 +397,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -426,7 +426,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -447,7 +447,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -469,7 +469,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot11) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -490,7 +490,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot12) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -511,7 +511,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot13) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -532,7 +532,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot14) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -553,7 +553,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot15) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -574,7 +574,7 @@ TEST_F(DeclarableOpsTests1, TestTensorDot16) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.equalsTo(result)); @@ -609,12 +609,12 @@ TEST_F(DeclarableOpsTests1, DivergentCheck1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AddMatrices1) { - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 5, 3 }); - auto exp = NDArrayFactory::create_('c', { 5, 3 }); - x->assign(2); - y->assign(1); - exp->assign(3); + auto x = NDArrayFactory::create('c', { 5, 3 }); + auto y = NDArrayFactory::create('c', { 5, 3 }); + auto exp = NDArrayFactory::create('c', { 5, 3 }); + x.assign(2); + y.assign(1); + exp.assign(3); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); @@ -626,23 +626,21 @@ TEST_F(DeclarableOpsTests1, AddMatrices1) { addOp.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(x.equalsTo(exp)); - delete exp; delete block; delete variableSpace; - } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AddVectorVector1) { - auto x = NDArrayFactory::create_('c', { 1, 15 }); - auto y = NDArrayFactory::create_('c', { 1, 15 }); - auto exp = NDArrayFactory::create_('c', { 1, 15 }); - x->assign(2); - y->assign(1); - exp->assign(3); + auto x = NDArrayFactory::create('c', { 1, 15 }); + auto y = NDArrayFactory::create('c', { 1, 15 }); + auto exp = NDArrayFactory::create('c', { 1, 15 }); + x.assign(2); + y.assign(1); + exp.assign(3); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); @@ -654,9 +652,8 @@ TEST_F(DeclarableOpsTests1, AddVectorVector1) { addOp.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(x.equalsTo(exp)); - delete exp; delete block; delete variableSpace; } @@ -664,11 +661,11 @@ TEST_F(DeclarableOpsTests1, AddVectorVector1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AddMatrixScalar1) { - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto x = NDArrayFactory::create('c', { 5, 3 }); + auto y = NDArrayFactory::create('c', { 1, 1 }); auto exp = NDArrayFactory::create('c', { 5, 3 }); - x->assign(2); - y->assign(1); + x.assign(2); + y.assign(1); exp.assign(3); auto variableSpace = new VariableSpace(); @@ -681,7 +678,7 @@ TEST_F(DeclarableOpsTests1, AddMatrixScalar1) { addOp.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); delete variableSpace; delete block; @@ -690,11 +687,11 @@ TEST_F(DeclarableOpsTests1, AddMatrixScalar1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AddScalarScalar1) { - auto x = NDArrayFactory::create_('c', { 1, 1 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto x = NDArrayFactory::create('c', { 1, 1 }); + auto y = NDArrayFactory::create('c', { 1, 1 }); auto exp = NDArrayFactory::create('c', { 1, 1 }); - x->assign(2); - y->assign(1); + x.assign(2); + y.assign(1); exp.assign(3); auto variableSpace = new VariableSpace(); @@ -707,7 +704,7 @@ TEST_F(DeclarableOpsTests1, AddScalarScalar1) { addOp.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); delete variableSpace; delete block; @@ -716,11 +713,11 @@ TEST_F(DeclarableOpsTests1, AddScalarScalar1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractMatrices1) { - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 5, 3 }); + auto x = NDArrayFactory::create('c', { 5, 3 }); + auto y = NDArrayFactory::create('c', { 5, 3 }); auto exp = NDArrayFactory::create('c', { 5, 3 }); - x->assign(3); - y->assign(1); + x.assign(3); + y.assign(1); exp.assign(2); auto variableSpace = new VariableSpace(); @@ -733,8 +730,7 @@ TEST_F(DeclarableOpsTests1, SubtractMatrices1) { subOp.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); - + ASSERT_TRUE(x.equalsTo(&exp)); delete variableSpace; delete block; @@ -743,11 +739,11 @@ TEST_F(DeclarableOpsTests1, SubtractMatrices1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractTest_1) { - auto x = NDArrayFactory::create_('c', { 1, 6 }); - auto y = NDArrayFactory::create_('c', { 1, 6 }); + auto x = NDArrayFactory::create('c', { 1, 6 }); + auto y = NDArrayFactory::create('c', { 1, 6 }); auto exp = NDArrayFactory::create('c', { 1, 6 }); - x->assign(3); - y->assign(1); + x.assign(3); + y.assign(1); exp.assign(2); auto variableSpace = new VariableSpace(); @@ -760,7 +756,7 @@ TEST_F(DeclarableOpsTests1, SubtractTest_1) { subOp.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); delete variableSpace; @@ -784,7 +780,7 @@ TEST_F(DeclarableOpsTests1, SubtractTest_2) { ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } @@ -826,20 +822,20 @@ TEST_F(DeclarableOpsTests1, TestRng1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MergeSumTest1) { - auto x = NDArrayFactory::create_('c', { 5, 5 }); - auto y = NDArrayFactory::create_('c', { 5, 5 }); - auto z = NDArrayFactory::create_('c', { 5, 5 }); + auto x = NDArrayFactory::create('c', { 5, 5 }); + auto y = NDArrayFactory::create('c', { 5, 5 }); + auto z = NDArrayFactory::create('c', { 5, 5 }); auto exp = NDArrayFactory::create('c', { 5, 5 }); - x->assign(3); - y->assign(1); - z->assign(2); + x.assign(3); + y.assign(1); + z.assign(2); exp.assign(6); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); variableSpace->putVariable(-3, z); - variableSpace->putVariable(1, new Variable(NDArrayFactory::create_('c', { 5, 5 }))); + variableSpace->putVariable(1, NDArrayFactory::create('c', { 5, 5 })); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2, -3 }); @@ -859,18 +855,18 @@ TEST_F(DeclarableOpsTests1, MergeSumTest1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ClipByValue1) { - auto x = NDArrayFactory::create_('c', { 5, 5 }); + auto x = NDArrayFactory::create('c', { 5, 5 }); auto exp = NDArrayFactory::create('c', { 5, 5 }); - x->assign(4); - x->p(0, -1); - x->p(1, 2); + x.assign(4); + x.p(0, -1); + x.p(1, 2); exp.assign(3); exp.p(0, 0); exp.p(1, 2); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, new Variable()); + variableSpace->putVariable(1, std::make_shared()); auto block = new Context(1, variableSpace, true); block->appendT(0.0f); @@ -882,8 +878,7 @@ TEST_F(DeclarableOpsTests1, ClipByValue1) { clip.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); - + ASSERT_TRUE(x.equalsTo(&exp)); delete variableSpace; delete block; @@ -892,13 +887,13 @@ TEST_F(DeclarableOpsTests1, ClipByValue1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MergeAvgTest1) { - auto x = NDArrayFactory::create_('c', { 5, 5 }); - auto y = NDArrayFactory::create_('c', { 5, 5 }); - auto z = NDArrayFactory::create_('c', { 5, 5 }); + auto x = NDArrayFactory::create('c', { 5, 5 }); + auto y = NDArrayFactory::create('c', { 5, 5 }); + auto z = NDArrayFactory::create('c', { 5, 5 }); auto exp = NDArrayFactory::create('c', { 5, 5 }); - x->assign(3); - y->assign(1); - z->assign(2); + x.assign(3); + y.assign(1); + z.assign(2); exp.assign(2); auto zu = NDArrayFactory::create('c', { 5, 5 }); @@ -907,7 +902,7 @@ TEST_F(DeclarableOpsTests1, MergeAvgTest1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); variableSpace->putVariable(-3, z); - variableSpace->putVariable(1, new Variable(NDArrayFactory::create_('c', { 5, 5 }))); + variableSpace->putVariable(1, NDArrayFactory::create('c', { 5, 5 })); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2, -3 }); @@ -917,7 +912,7 @@ TEST_F(DeclarableOpsTests1, MergeAvgTest1) { auto res = variableSpace->getVariable(1)->getNDArray(); - ASSERT_TRUE(res->equalsTo(&exp)); + ASSERT_TRUE(res->equalsTo(exp)); delete block; delete variableSpace; @@ -927,11 +922,11 @@ TEST_F(DeclarableOpsTests1, MergeAvgTest1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractVectorVector1) { - auto x = NDArrayFactory::create_('c', { 1, 15 }); - auto y = NDArrayFactory::create_('c', { 1, 15 }); + auto x = NDArrayFactory::create('c', { 1, 15 }); + auto y = NDArrayFactory::create('c', { 1, 15 }); auto exp = NDArrayFactory::create('c', { 1, 15 }); - x->assign(3); - y->assign(1); + x.assign(3); + y.assign(1); exp.assign(2); auto variableSpace = new VariableSpace(); @@ -944,7 +939,7 @@ TEST_F(DeclarableOpsTests1, SubtractVectorVector1) { subOp.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); delete block; delete variableSpace; @@ -955,11 +950,11 @@ TEST_F(DeclarableOpsTests1, SubtractVectorVector1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractMatrixScalar1) { - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto x = NDArrayFactory::create('c', { 5, 3 }); + auto y = NDArrayFactory::create('c', { 1, 1 }); auto exp = NDArrayFactory::create('c', { 5, 3 }); - x->assign(3); - y->assign(1); + x.assign(3); + y.assign(1); exp.assign(2); auto variableSpace = new VariableSpace(); @@ -972,7 +967,7 @@ TEST_F(DeclarableOpsTests1, SubtractMatrixScalar1) { subOp.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); delete block; delete variableSpace; @@ -982,11 +977,11 @@ TEST_F(DeclarableOpsTests1, SubtractMatrixScalar1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractScalarScalar1) { - auto x = NDArrayFactory::create_('c', { 1, 1 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto x = NDArrayFactory::create('c', { 1, 1 }); + auto y = NDArrayFactory::create('c', { 1, 1 }); auto exp = NDArrayFactory::create('c', { 1, 1 }); - x->assign(3); - y->assign(1); + x.assign(3); + y.assign(1); exp.assign(2); auto variableSpace = new VariableSpace(); @@ -999,7 +994,7 @@ TEST_F(DeclarableOpsTests1, SubtractScalarScalar1) { subOp.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); delete block; delete variableSpace; @@ -1008,11 +1003,11 @@ TEST_F(DeclarableOpsTests1, SubtractScalarScalar1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractMatrices1) { - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 5, 3 }); + auto x = NDArrayFactory::create('c', { 5, 3 }); + auto y = NDArrayFactory::create('c', { 5, 3 }); auto exp = NDArrayFactory::create('c', { 5, 3 }); - x->assign(3.f); - y->assign(1.f); + x.assign(3.f); + y.assign(1.f); exp.assign(-2.f); auto variableSpace = new VariableSpace(); @@ -1025,7 +1020,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractMatrices1) { subOp.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); delete variableSpace; delete block; @@ -1046,7 +1041,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_1) { auto res = subOp.evaluate({ &x, &y }); ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } @@ -1064,14 +1059,14 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_2) { exp.assign(-2.f); x.applyTrueBroadcast(BROADCAST(ReverseSubtract), y, z, true); - ASSERT_TRUE(exp.equalsTo(&z)); + ASSERT_TRUE(exp.equalsTo(z)); sd::ops::reversesubtract subOp; auto res = subOp.evaluate({ &x, &y }); ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } @@ -1093,7 +1088,7 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_3) { auto res = subOp.evaluate({ &x, &y }); ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } @@ -1120,7 +1115,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_1) { auto res = subOp.evaluate({ &x, &y }); ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); ASSERT_TRUE(exp.equalsTo(&z)); @@ -1147,7 +1142,7 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_2) { auto res = subOp.evaluate({ &x, &y }); ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } @@ -1155,12 +1150,12 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_2) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractVectorVector1) { - auto x = NDArrayFactory::create_('c', { 1, 15 }); - auto y = NDArrayFactory::create_('c', { 1, 15 }); - auto exp = NDArrayFactory::create_('c', { 1, 15 }); - x->assign(3); - y->assign(1); - exp->assign(-2); + auto x = NDArrayFactory::create('c', { 1, 15 }); + auto y = NDArrayFactory::create('c', { 1, 15 }); + auto exp = NDArrayFactory::create('c', { 1, 15 }); + x.assign(3); + y.assign(1); + exp.assign(-2); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); @@ -1172,23 +1167,22 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractVectorVector1) { subOp.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(x.equalsTo(exp)); delete variableSpace; delete block; - delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractMatrixScalar1) { - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); - auto exp = NDArrayFactory::create_('c', { 5, 3 }); - x->assign(3); - y->assign(1); - exp->assign(-2); + auto x = NDArrayFactory::create('c', { 5, 3 }); + auto y = NDArrayFactory::create('c', { 1, 1 }); + auto exp = NDArrayFactory::create('c', { 5, 3 }); + x.assign(3); + y.assign(1); + exp.assign(-2); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); @@ -1200,23 +1194,22 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractMatrixScalar1) { subOp.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(x.equalsTo(exp)); delete variableSpace; delete block; - delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractScalarScalar1) { - auto x = NDArrayFactory::create_('c', { 1, 1 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); - auto exp = NDArrayFactory::create_('c', { 1, 1 }); - x->assign(3); - y->assign(1); - exp->assign(-2); + auto x = NDArrayFactory::create('c', { 1, 1 }); + auto y = NDArrayFactory::create('c', { 1, 1 }); + auto exp = NDArrayFactory::create('c', { 1, 1 }); + x.assign(3); + y.assign(1); + exp.assign(-2); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); @@ -1228,22 +1221,21 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractScalarScalar1) { subOp.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(x.equalsTo(exp)); delete variableSpace; delete block; - delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MultiplyMatrices1) { - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 5, 3 }); - auto exp = NDArrayFactory::create_('c', { 5, 3 }); - x->assign(2); - y->assign(3); - exp->assign(6); + auto x = NDArrayFactory::create('c', { 5, 3 }); + auto y = NDArrayFactory::create('c', { 5, 3 }); + auto exp = NDArrayFactory::create('c', { 5, 3 }); + x.assign(2); + y.assign(3); + exp.assign(6); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); @@ -1255,22 +1247,21 @@ TEST_F(DeclarableOpsTests1, MultiplyMatrices1) { mul.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(x.equalsTo(exp)); delete variableSpace; delete block; - delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MultiplyVectorVector1) { - auto x = NDArrayFactory::create_('c', { 1, 15 }); - auto y = NDArrayFactory::create_('c', { 1, 15 }); - auto exp = NDArrayFactory::create_('c', { 1, 15 }); - x->assign(2); - y->assign(3); - exp->assign(6); + auto x = NDArrayFactory::create('c', { 1, 15 }); + auto y = NDArrayFactory::create('c', { 1, 15 }); + auto exp = NDArrayFactory::create('c', { 1, 15 }); + x.assign(2); + y.assign(3); + exp.assign(6); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); @@ -1282,22 +1273,21 @@ TEST_F(DeclarableOpsTests1, MultiplyVectorVector1) { mul.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(x.equalsTo(exp)); delete variableSpace; delete block; - delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MultiplyMatrixScalar) { - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); - auto exp = NDArrayFactory::create_('c', { 5, 3 }); - x->assign(2); - y->assign(3); - exp->assign(6); + auto x = NDArrayFactory::create('c', { 5, 3 }); + auto y = NDArrayFactory::create('c', { 1, 1 }); + auto exp = NDArrayFactory::create('c', { 5, 3 }); + x.assign(2); + y.assign(3); + exp.assign(6); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); @@ -1309,22 +1299,21 @@ TEST_F(DeclarableOpsTests1, MultiplyMatrixScalar) { mul.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(x.equalsTo(exp)); delete variableSpace; delete block; - delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MultiplyScalarScalar1) { - auto x = NDArrayFactory::create_('c', { 1, 1 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); - auto exp = NDArrayFactory::create_('c', { 1, 1 }); - x->assign(2); - y->assign(3); - exp->assign(6); + auto x = NDArrayFactory::create('c', { 1, 1 }); + auto y = NDArrayFactory::create('c', { 1, 1 }); + auto exp = NDArrayFactory::create('c', { 1, 1 }); + x.assign(2); + y.assign(3); + exp.assign(6); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); @@ -1336,34 +1325,33 @@ TEST_F(DeclarableOpsTests1, MultiplyScalarScalar1) { mul.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(x.equalsTo(exp)); delete block; delete variableSpace; - delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestSoftMax_bp_1) { - auto input = NDArrayFactory::create_('c', { 2, 2 }); - for (int e = 0; e < input->lengthOf(); e++) - input->p(e, e + 1); + auto input = NDArrayFactory::create('c', { 2, 2 }); + for (int e = 0; e < input.lengthOf(); e++) + input.p(e, e + 1); - auto epsilon = NDArrayFactory::create_('c', { 2, 2 }); - epsilon->p(0, 0.1f); - epsilon->p(1, 0.2f); - epsilon->p(2, 0.3f); - epsilon->p(3, 0.4f); + auto epsilon = NDArrayFactory::create('c', { 2, 2 }); + epsilon.p(0, 0.1f); + epsilon.p(1, 0.2f); + epsilon.p(2, 0.3f); + epsilon.p(3, 0.4f); - auto output = NDArrayFactory::create_('c', { 2, 2 }); - output->assign(1.0f); + auto output = NDArrayFactory::create('c', { 2, 2 }); + output.assign(1.0f); - auto exp = NDArrayFactory::create_('c', { 2, 2 }); - exp->p(0, -0.019661194f); - exp->p(1, 0.019661194f); - exp->p(2, -0.019661194f); - exp->p(3, 0.019661194f); + auto exp = NDArrayFactory::create('c', { 2, 2 }); + exp.p(0, -0.019661194f); + exp.p(1, 0.019661194f); + exp.p(2, -0.019661194f); + exp.p(3, 0.019661194f); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, input); @@ -1379,12 +1367,10 @@ TEST_F(DeclarableOpsTests1, TestSoftMax_bp_1) { Nd4jStatus status = op.execute(block); ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output->equalsTo(exp)); + ASSERT_TRUE(output.equalsTo(exp)); delete variableSpace; delete block; - delete exp; - } ////////////////////////////////////////////////////////////////////// @@ -1402,7 +1388,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) { auto res = div.evaluate({ &x, &y }); ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(exp)); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } @@ -1421,7 +1407,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) { auto res = div.evaluate({ &x, &y }); ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(exp)); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } @@ -1437,7 +1423,7 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) { auto res = div.evaluate({ &x, &y }); ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(exp)); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } @@ -1458,7 +1444,7 @@ TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0)->equalsTo(exp)); + ASSERT_TRUE(res.at(0).equalsTo(exp)); auto z(exp); x.applyTrueBroadcast(BROADCAST(ReverseDivide), y, z, true); y.applyTrueBroadcast(BROADCAST(Divide), x, exp, true); @@ -1471,12 +1457,12 @@ TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, DivideMatrices1) { - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 5, 3 }); - auto exp = NDArrayFactory::create_('c', { 5, 3 }); - x->assign(6); - y->assign(2); - exp->assign(3); + auto x = NDArrayFactory::create('c', { 5, 3 }); + auto y = NDArrayFactory::create('c', { 5, 3 }); + auto exp = NDArrayFactory::create('c', { 5, 3 }); + x.assign(6); + y.assign(2); + exp.assign(3); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); @@ -1488,21 +1474,20 @@ TEST_F(DeclarableOpsTests1, DivideMatrices1) { div.execute(block); - ASSERT_TRUE(x->equalsTo(exp)); + ASSERT_TRUE(x.equalsTo(exp)); delete variableSpace; delete block; - delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, DivideVectorVector1) { - auto x = NDArrayFactory::create_('c', { 1, 15 }); - auto y = NDArrayFactory::create_('c', { 1, 15 }); + auto x = NDArrayFactory::create('c', { 1, 15 }); + auto y = NDArrayFactory::create('c', { 1, 15 }); auto exp = NDArrayFactory::create('c', { 1, 15 }); - x->assign(6); - y->assign(2); + x.assign(6); + y.assign(2); exp.assign(3); auto variableSpace = new VariableSpace(); @@ -1515,7 +1500,7 @@ TEST_F(DeclarableOpsTests1, DivideVectorVector1) { div.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); delete variableSpace; delete block; @@ -1524,11 +1509,11 @@ TEST_F(DeclarableOpsTests1, DivideVectorVector1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, DivideMatrixScalar1) { - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto x = NDArrayFactory::create('c', { 5, 3 }); + auto y = NDArrayFactory::create('c', { 1, 1 }); auto exp = NDArrayFactory::create('c', { 5, 3 }); - x->assign(6); - y->assign(2); + x.assign(6); + y.assign(2); exp.assign(3); auto variableSpace = new VariableSpace(); @@ -1541,7 +1526,7 @@ TEST_F(DeclarableOpsTests1, DivideMatrixScalar1) { div.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); delete block; delete variableSpace; @@ -1551,11 +1536,11 @@ TEST_F(DeclarableOpsTests1, DivideMatrixScalar1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, DivideScalarScalar1) { - auto x = NDArrayFactory::create_('c', { 5, 1 }); - auto y = NDArrayFactory::create_('c', { 5, 1 }); + auto x = NDArrayFactory::create('c', { 5, 1 }); + auto y = NDArrayFactory::create('c', { 5, 1 }); auto exp = NDArrayFactory::create('c', { 5, 1 }); - x->assign(6); - y->assign(2); + x.assign(6); + y.assign(2); exp.assign(3); auto variableSpace = new VariableSpace(); @@ -1568,7 +1553,7 @@ TEST_F(DeclarableOpsTests1, DivideScalarScalar1) { div.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); delete variableSpace; delete block; @@ -1577,11 +1562,11 @@ TEST_F(DeclarableOpsTests1, DivideScalarScalar1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseDivideMatrices1) { - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 5, 3 }); + auto x = NDArrayFactory::create('c', { 5, 3 }); + auto y = NDArrayFactory::create('c', { 5, 3 }); auto exp = NDArrayFactory::create('c', { 5, 3 }); - x->assign(2); - y->assign(6); + x.assign(2); + y.assign(6); exp.assign(3); auto variableSpace = new VariableSpace(); @@ -1594,7 +1579,7 @@ TEST_F(DeclarableOpsTests1, ReverseDivideMatrices1) { div.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); delete variableSpace; delete block; @@ -1603,11 +1588,11 @@ TEST_F(DeclarableOpsTests1, ReverseDivideMatrices1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseDivideVectorVector1) { - auto x = NDArrayFactory::create_('c', { 1, 15 }); - auto y = NDArrayFactory::create_('c', { 1, 15 }); + auto x = NDArrayFactory::create('c', { 1, 15 }); + auto y = NDArrayFactory::create('c', { 1, 15 }); auto exp = NDArrayFactory::create('c', { 1, 15 }); - x->assign(2); - y->assign(6); + x.assign(2); + y.assign(6); exp.assign(3); auto variableSpace = new VariableSpace(); @@ -1620,7 +1605,7 @@ TEST_F(DeclarableOpsTests1, ReverseDivideVectorVector1) { div.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); delete variableSpace; delete block; @@ -1629,11 +1614,11 @@ TEST_F(DeclarableOpsTests1, ReverseDivideVectorVector1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseDivideMatrixScalar1) { - auto x = NDArrayFactory::create_('c', { 5, 3 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto x = NDArrayFactory::create('c', { 5, 3 }); + auto y = NDArrayFactory::create('c', { 1, 1 }); auto exp = NDArrayFactory::create('c', { 5, 3 }); - x->assign(2); - y->assign(6); + x.assign(2); + y.assign(6); exp.assign(3); auto variableSpace = new VariableSpace(); @@ -1646,7 +1631,7 @@ TEST_F(DeclarableOpsTests1, ReverseDivideMatrixScalar1) { div.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); delete variableSpace; delete block; @@ -1655,11 +1640,11 @@ TEST_F(DeclarableOpsTests1, ReverseDivideMatrixScalar1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseDivideScalarScalar1) { - auto x = NDArrayFactory::create_('c', { 1, 1 }); - auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto x = NDArrayFactory::create('c', { 1, 1 }); + auto y = NDArrayFactory::create('c', { 1, 1 }); auto exp = NDArrayFactory::create('c', { 1, 1 }); - x->assign(2); - y->assign(6); + x.assign(2); + y.assign(6); exp.assign(3); auto variableSpace = new VariableSpace(); @@ -1672,7 +1657,7 @@ TEST_F(DeclarableOpsTests1, ReverseDivideScalarScalar1) { div.execute(block); - ASSERT_TRUE(x->equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); delete variableSpace; delete block; @@ -1828,12 +1813,12 @@ TEST_F(DeclarableOpsTests1, TestGemv1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Transpose1) { - auto x = NDArrayFactory::create_('c', { 3,5,2 }); - auto exp = NDArrayFactory::create_('c', { 2,5,3 }); + auto x = NDArrayFactory::create('c', { 3,5,2 }); + auto exp = NDArrayFactory::create('c', { 2,5,3 }); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, new Variable()); + variableSpace->putVariable(1, std::make_shared()); auto block = new Context(1, variableSpace, false); // not-in-place block->fillInputs({ -1 }); @@ -1844,11 +1829,10 @@ TEST_F(DeclarableOpsTests1, Transpose1) { auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp->isSameShape(result)); - ASSERT_TRUE(exp->dataType() == result->dataType()); - ASSERT_TRUE(exp->ordering() == result->ordering()); + ASSERT_TRUE(exp.isSameShape(*result)); + ASSERT_TRUE(exp.dataType() == result->dataType()); + ASSERT_TRUE(exp.ordering() == result->ordering()); - delete exp; delete block; delete variableSpace; } @@ -1864,12 +1848,12 @@ TEST_F(DeclarableOpsTests1, Permute1) { ArrayOptions::setDataType(shapeX, sd::DataType::FLOAT32); ArrayOptions::setDataType(shapeExp, sd::DataType::FLOAT32); - auto x = new NDArray(shapeX, true); - auto exp = new NDArray(shapeExp, true); + NDArray x(shapeX, true); + NDArray exp(shapeExp, true); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, new Variable()); + variableSpace->putVariable(1, std::make_shared()); auto block = new Context(1, variableSpace, false); // not-in-place block->fillInputs({ -1 }); @@ -1881,11 +1865,10 @@ TEST_F(DeclarableOpsTests1, Permute1) { auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(result->isSameShapeStrict(*exp)); + ASSERT_TRUE(result->isSameShapeStrict(exp)); delete block; delete variableSpace; - delete exp; } ////////////////////////////////////////////////////////////////////// @@ -1897,12 +1880,12 @@ TEST_F(DeclarableOpsTests1, TestArgumentsValidation1) { ArrayOptions::setDataType(shapeExp, sd::DataType::FLOAT32); const std::vector perm = { 2, 0, 1 }; - auto x = new NDArray(shapeX); - auto exp = new NDArray(shapeExp); + NDArray x(shapeX); + NDArray exp(shapeExp); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, new Variable()); + variableSpace->putVariable(1, std::make_shared()); auto block = new Context(1, variableSpace, false); // not-in-place block->fillInputs({ -1 }); @@ -1912,14 +1895,13 @@ TEST_F(DeclarableOpsTests1, TestArgumentsValidation1) { ASSERT_TRUE(status != 0); - delete exp; delete block; delete variableSpace; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestReductionShape1) { - auto input = NDArrayFactory::create_('c', { 4, 5, 5, 10, 10 }); + auto input = NDArrayFactory::create('c', { 4, 5, 5, 10, 10 }); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, input); @@ -1932,8 +1914,8 @@ TEST_F(DeclarableOpsTests1, TestReductionShape1) { sd::ops::testreduction testop; - auto inP = new Nd4jLong[shape::shapeInfoLength(input->getShapeInfo())]; - memcpy(inP, input->getShapeInfo(), shape::shapeInfoByteLength(input->rankOf())); + auto inP = new Nd4jLong[shape::shapeInfoLength(input.getShapeInfo())]; + memcpy(inP, input.getShapeInfo(), shape::shapeInfoByteLength(input.rankOf())); auto inshape = new ShapeList(inP); @@ -1954,7 +1936,7 @@ TEST_F(DeclarableOpsTests1, TestReductionShape1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestReductionShape2) { - auto input = NDArrayFactory::create_('c', { 4, 5, 5, 10, 10 }); + auto input = NDArrayFactory::create('c', { 4, 5, 5, 10, 10 }); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, input); @@ -1971,7 +1953,7 @@ TEST_F(DeclarableOpsTests1, TestReductionShape2) { sd::ops::testreduction testop; - auto inshapes = new ShapeList(input->getShapeInfo()); + auto inshapes = new ShapeList(input.getShapeInfo()); auto shapes = testop.calculateOutputShape(inshapes, *block); ASSERT_EQ(1, shapes->size()); ASSERT_EQ(1, shapes->at(0)[0]); @@ -1986,7 +1968,7 @@ TEST_F(DeclarableOpsTests1, TestReductionShape2) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestCustomShape1) { - auto input = NDArrayFactory::create_('c', { 2, 3, 4 }); + auto input = NDArrayFactory::create('c', { 2, 3, 4 }); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, input); @@ -1996,14 +1978,14 @@ TEST_F(DeclarableOpsTests1, TestCustomShape1) { sd::ops::testcustom test; - auto inshapes = new ShapeList(input->getShapeInfo()); + auto inshapes = new ShapeList(input.getShapeInfo()); auto shapes = test.calculateOutputShape(inshapes, *block); - ASSERT_EQ(input->getShapeInfo()[0], shapes->at(0)[0]); - ASSERT_EQ(input->getShapeInfo()[1] * 2, shapes->at(0)[1]); - ASSERT_EQ(input->getShapeInfo()[2] * 2, shapes->at(0)[2]); - ASSERT_EQ(input->getShapeInfo()[3] * 2, shapes->at(0)[3]); + ASSERT_EQ(input.getShapeInfo()[0], shapes->at(0)[0]); + ASSERT_EQ(input.getShapeInfo()[1] * 2, shapes->at(0)[1]); + ASSERT_EQ(input.getShapeInfo()[2] * 2, shapes->at(0)[2]); + ASSERT_EQ(input.getShapeInfo()[3] * 2, shapes->at(0)[3]); delete variableSpace; delete block; @@ -2052,7 +2034,7 @@ TEST_F(DeclarableOpsTests1, Sum1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Pnormpool2d1) { - auto x = NDArrayFactory::create_('c', { bS,iD,iH,iW }); + auto x = NDArrayFactory::create('c', { bS,iD,iH,iW }); auto exp = NDArrayFactory::create('c', { bS,iD,oH,oW }); // auto z('c',{bS,iD,oH,oW}); @@ -2069,7 +2051,7 @@ TEST_F(DeclarableOpsTests1, Pnormpool2d1) { ASSERT_EQ(ND4J_STATUS_OK, status); auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.isSameShape(*result)); delete variableSpace; delete block; @@ -2516,13 +2498,12 @@ TEST_F(DeclarableOpsTests1, ArgMax6) { ASSERT_EQ(Status::OK(), expected.status()); auto exp = expected.at(0); - auto result = op.evaluate({ &x, &dim }, {}, {}); ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_EQ(*exp, *z); + ASSERT_EQ(exp, z); } @@ -2543,8 +2524,6 @@ TEST_F(DeclarableOpsTests1, ArgMin1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } @@ -2675,7 +2654,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_6) { auto result = op.evaluate({ &indices }, { 1.0, 0.0 }, { 0, 3 }); auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -2688,7 +2667,7 @@ TEST_F(DeclarableOpsTests1, OneHotTests_7) { auto result = op.evaluate({ &indices }, { 1.0, 0.0 }, { 0, 3 }, {}, { sd::DataType::HALF }, false); auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -2706,7 +2685,7 @@ TEST_F(DeclarableOpsTests1, FillAs_1) { ASSERT_TRUE(x.isSameShape(result.at(0))); - ASSERT_NEAR(scalar, result.at(0)->meanNumber().e(0), 1e-5f); + ASSERT_NEAR(scalar, result.at(0).meanNumber().e(0), 1e-5f); } @@ -2997,7 +2976,7 @@ TEST_F(DeclarableOpsTests1, Reverse_1) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3024,8 +3003,6 @@ TEST_F(DeclarableOpsTests1, Reverse_2) { ASSERT_TRUE(expected.isSameShapeStrict(input)); ASSERT_TRUE(expected.equalsTo(&input)); - - } ////////////////////////////////////////////////////////////////////// @@ -3048,7 +3025,7 @@ TEST_F(DeclarableOpsTests1, Reverse_3) { auto result = results.at(0); // result->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3074,7 +3051,7 @@ TEST_F(DeclarableOpsTests1, Reverse_4) { auto result = results.at(0); // result->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3099,7 +3076,7 @@ TEST_F(DeclarableOpsTests1, Reverse_5) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3153,7 +3130,7 @@ TEST_F(DeclarableOpsTests1, Reverse_7) { //expected.printIndexedBuffer("E"); //result->printIndexedBuffer("R"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3181,7 +3158,7 @@ TEST_F(DeclarableOpsTests1, Reverse_8) { auto result = results.at(0); // result->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3206,7 +3183,7 @@ TEST_F(DeclarableOpsTests1, Reverse_9) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3245,7 +3222,7 @@ TEST_F(DeclarableOpsTests1, Reverse_11) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3267,7 +3244,7 @@ TEST_F(DeclarableOpsTests1, Reverse_12) { auto result = results.at(0); //result->printIndexedBuffer("Result reverse"); //expected.printIndexedBuffer("Expected reverse"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3288,7 +3265,7 @@ TEST_F(DeclarableOpsTests1, Reverse_13) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3309,7 +3286,7 @@ TEST_F(DeclarableOpsTests1, Reverse_14) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3337,12 +3314,11 @@ TEST_F(DeclarableOpsTests1, Test_Expose_1) { TEST_F(DeclarableOpsTests1, Test_Expose_2) { auto list = new NDArrayList(0, true); - auto var = new Variable(nullptr, "arraylist", -1, 0); - var->setNDArrayList(list); + auto var = std::make_shared(NDArray(), "arraylist", -1, 0); + //var->setNDArrayList(list); VariableSpace variableSpace; variableSpace.putVariable(-1, var); - variableSpace.trackList(list); Context block(1, &variableSpace); block.pickInput(-1); @@ -3359,7 +3335,7 @@ TEST_F(DeclarableOpsTests1, Test_Expose_2) { auto list1 = var1->getNDArrayList(); - ASSERT_TRUE(list == list1); + ASSERT_TRUE(list == list1.get()); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 03e5ae53f634..c028bd54b6f6 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -64,7 +64,7 @@ TEST_F(DeclarableOpsTests10, Test_ArgMax_1) { ASSERT_EQ(Status::OK(), result.status()); - auto z = *result.at(0); + auto z = result.at(0); ASSERT_EQ(e, z); } @@ -80,7 +80,7 @@ TEST_F(DeclarableOpsTests10, Test_ArgMax_2) { auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result.status()); - auto z = *result.at(0); + auto z = result.at(0); //z.printIndexedBuffer("z"); //z.printShapeInfo("z shape"); @@ -97,7 +97,7 @@ TEST_F(DeclarableOpsTests10, Test_And_1) { auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *result.at(0)); + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests10, Test_Or_1) { @@ -109,7 +109,7 @@ TEST_F(DeclarableOpsTests10, Test_Or_1) { auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *result.at(0)); + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests10, Test_Not_1) { @@ -134,7 +134,7 @@ TEST_F(DeclarableOpsTests10, Test_Size_at_1) { auto result = op.evaluate({&x}, {1}); ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *result.at(0)); + ASSERT_EQ(e, result.at(0)); } @@ -294,7 +294,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_4) { auto res = op.evaluate({&input}, {}, {}); ASSERT_TRUE(res.status() == ND4J_STATUS_OK); auto resA = res.at(0); - ASSERT_TRUE(resA->isEmpty()); + ASSERT_TRUE(resA.isEmpty()); //resA->printIndexedBuffer("Result A"); //resA->printShapeInfo("ShapeA"); //ASSERT_TRUE(exp.equalsTo(resA)); @@ -331,7 +331,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_4) { auto res = op.evaluate({&input}, {}, {}); ASSERT_TRUE(res.status() == ND4J_STATUS_OK); auto resA = res.at(0); - ASSERT_TRUE(resA->isEmpty()); + ASSERT_TRUE(resA.isEmpty()); //resA->printIndexedBuffer("Result A"); //resA->printShapeInfo("ShapeA"); //ASSERT_TRUE(exp.equalsTo(resA)); @@ -859,7 +859,7 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *out = results.at(0); + auto out = results.at(0); ASSERT_TRUE(exp.isSameShape(out)); // out->printBuffer("5HIST"); @@ -902,7 +902,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -922,7 +922,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -942,7 +942,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -962,7 +962,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -981,7 +981,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_04) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1000,7 +1000,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1020,7 +1020,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); } @@ -1038,7 +1038,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_06) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); } @@ -1062,7 +1062,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1087,7 +1087,7 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1107,7 +1107,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1127,7 +1127,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1147,7 +1147,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1165,7 +1165,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1183,7 +1183,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1201,7 +1201,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1219,7 +1219,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1239,7 +1239,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1259,7 +1259,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1279,7 +1279,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1333,7 +1333,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); //result.printIndexedBuffer("Resized to 10x10"); //expected.printIndexedBuffer("Expect for 10x10"); @@ -1355,8 +1355,8 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_11) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); - ASSERT_NE(*result, ex); + auto result = results.at(0); + ASSERT_NE(result, ex); } //////////////////////////////////////////////////////////////////// @@ -1372,8 +1372,8 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_12) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); - ASSERT_NE(*result, ex); + auto result = results.at(0); + ASSERT_NE(result, ex); } TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) { @@ -1414,7 +1414,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printIndexedBuffer("Resized to 4x5 bilinear with half pixels"); //expected.printIndexedBuffer("Expect for 10x10"); @@ -1461,7 +1461,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Resized to 4x5"); // expected.printBuffer("Expect for 4x5"); @@ -1517,7 +1517,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); //result.printIndexedBuffer("Resized to 10x10"); //expected.printIndexedBuffer("Expect for 10x10"); @@ -1672,7 +1672,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Resized to 9x9"); // expected.printBuffer("Expect for 9x9"); @@ -1731,7 +1731,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1859,7 +1859,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1986,7 +1986,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printIndexedBuffer("Resized to 10x10"); // expected.printIndexedBuffer("Expected of 10x10"); // result.printShapeInfo("Resized to 10x10 shape"); @@ -2050,7 +2050,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printIndexedBuffer("Resized to 4x5"); // expected.printIndexedBuffer("Expect for 4x5"); @@ -2096,7 +2096,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printIndexedBuffer("Resized to 4x5"); // expected.printIndexedBuffer("Expect for 4x5"); @@ -2142,7 +2142,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1_1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printIndexedBuffer("Resized to 4x5"); // expected.printBuffer("Expect for 4x5"); @@ -2187,7 +2187,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); //result.printIndexedBuffer("Resized to 4x5"); //expected.printIndexedBuffer("Expect for 4x5"); @@ -2209,7 +2209,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_1) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2229,7 +2229,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_2) { auto result = results.at(0); // result.printIndexedBuffer("REDUCE_LOGSUMEXP"); // expected.printIndexedBuffer("LSE EXPECTED"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2249,7 +2249,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_3) { auto result = results.at(0); // result.printIndexedBuffer("REDUCE_LOGSUMEXP"); // expected.printIndexedBuffer("LSE EXPECTED"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// @@ -2265,10 +2265,10 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); //result.printIndexedBuffer("OOOOUUUUTTT"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2287,9 +2287,9 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("NonMaxSuppression OUtput2"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2307,9 +2307,9 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_3) { ASSERT_EQ(Status::OK(), results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("NonMaxSuppression OUtput3"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2328,9 +2328,9 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_4) { ASSERT_EQ(Status::OK(), results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("NonMaxSuppression OUtput4"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) { @@ -2348,9 +2348,9 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) { ASSERT_EQ(Status::OK(), results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("NonMaxSuppression OUtput4"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2369,10 +2369,10 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_6) { ASSERT_EQ(Status::OK(), results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("NonMaxSuppression OUtput6"); // result.printShapeInfo("Ouput6 shape is"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2391,10 +2391,10 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_06) { ASSERT_EQ(Status::OK(), results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("NonMaxSuppression OUtput06"); // result.printShapeInfo("Ouput06 shape is"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2412,10 +2412,10 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_7) { ASSERT_EQ(Status::OK(), results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("NonMaxSuppression OUtput7"); // result.printShapeInfo("Ouput6 shape is"); - ASSERT_TRUE(result->isEmpty()); + ASSERT_TRUE(result.isEmpty()); } //////////////////////////////////////////////////////////////////// @@ -2435,9 +2435,9 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("NonMaxSuppressionOverlap1 Output"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2458,9 +2458,9 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("NonMaxSuppressionOverlap Output"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2481,9 +2481,9 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("NonMaxSuppressionOverlap Output"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2506,7 +2506,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) { auto result = results.at(0); // result.printIndexedBuffer("Cropped and Resized"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2528,7 +2528,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2550,7 +2550,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2572,7 +2572,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) { auto result = results.at(0); // result.printIndexedBuffer("Cropped and Resized"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2594,7 +2594,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); //ASSERT_TRUE(expected.equalsTo(result)); } @@ -2627,10 +2627,10 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); auto result = results.at(0); - result->syncToHost(); + result.syncToHost(); // result.printBuffer("Bounded boxes"); // expected.printBuffer("Bounded expec"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2661,7 +2661,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) { // result.syncToHost(); // result.printBuffer("Bounded boxes 2"); // expected.printBuffer("Bounded expec 2"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2714,7 +2714,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) { // result.syncToHost(); // result.printBuffer("Bounded boxes 2"); // expected.printBuffer("Bounded expec 2"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); } @@ -2734,7 +2734,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) { auto result = results.at(0); // result.printBuffer("Quantized"); // exp.printBuffer("Expected"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// @@ -2752,7 +2752,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) { auto result = results.at(0); // result.printIndexedBuffer("Quantized2"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); } @@ -2771,7 +2771,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) { auto result = results.at(0); // result.printIndexedBuffer("Quantized2"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); } @@ -2793,7 +2793,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03) { auto result = results.at(0); // result.printIndexedBuffer("Quantized03"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); } @@ -2815,7 +2815,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_1) { auto result = results.at(0); // result.printIndexedBuffer("Quantized03_1"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); } @@ -2836,8 +2836,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); auto result = results.at(0); - result->printIndexedBuffer("Quantized03_2"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); + //result.printIndexedBuffer("Quantized03_2"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); } @@ -2857,8 +2857,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); auto result = results.at(0); - result->printIndexedBuffer("Quantized03_3"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); + //result->printIndexedBuffer("Quantized03_3"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); } @@ -2896,9 +2896,9 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) { auto result = results.at(0); // result.printBuffer("Quantized per channels 4"); // exp.printBuffer("Quantized per channest E"); -// auto diff = *result - exp; +// auto diff = result - exp; // diff.printIndexedBuffer("Difference"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); } @@ -2947,10 +2947,10 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) { auto result = results.at(0); // result.printBuffer("Quantized per channels 5"); // exp.printBuffer("Quantized per channest E"); -// auto diff = *result - exp; +// auto diff = result - exp; // diff.printIndexedBuffer("Difference"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); } @@ -2979,10 +2979,10 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) { auto result = results.at(0); // result.printBuffer("Quantized per channels 5"); // exp.printBuffer("Quantized per channest E"); -// auto diff = *result - exp; +// auto diff = result - exp; // diff.printIndexedBuffer("Difference"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); } @@ -3023,7 +3023,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) { auto result = results.at(0); // result.printBuffer("Quantized7"); // exp.printBuffer("Expected 7"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); } @@ -3047,7 +3047,7 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) { // x.printBuffer("SourInput8"); // result.printBuffer("Quantized8"); // exp.printBuffer("Expected 8"); - ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.equalsTo(result)); } @@ -3108,7 +3108,7 @@ TEST_F(DeclarableOpsTests10, printIndexedTest_1) { // printf("["); // else // printf(" "); - lastDims.at(k++)->printBuffer(); + lastDims.at(k++).printBuffer(); //if (k == arr.sizeAt(i)) // printf("]\n"); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index b8c89322cd03..ffda9aeccb6c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -73,9 +73,9 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -103,7 +103,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -131,9 +131,9 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -161,7 +161,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); // dLdw->printIndexedBuffer(); // dLdw->printShapeInfo(); @@ -192,9 +192,9 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -222,7 +222,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -246,7 +246,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -279,9 +279,9 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -314,9 +314,9 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -344,7 +344,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -368,7 +368,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test11) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -403,9 +403,9 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test12) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -440,9 +440,9 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test13) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -620,7 +620,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) { auto results = op.evaluate({&input, &size}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Resized to 30x30"); // expected.printBuffer("Expect for 30x30"); @@ -694,7 +694,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Resized to 10x8"); // expected.printBuffer("Expect for 10x8"); @@ -730,7 +730,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Resized to 6x6"); // expected.printBuffer("Expect for 6x6"); @@ -766,7 +766,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Resized to 6x8"); // expected.printBuffer("Expect for 6x8"); @@ -808,7 +808,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Resized to 8x8"); // expected.printBuffer("Expect for 8x8"); @@ -936,7 +936,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) { auto results = op.evaluate({&input, &size}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Resized to 30x30"); // expected.printBuffer("Expect for 30x30"); @@ -994,7 +994,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Resized to 9x9"); // expected.printBuffer("Expect for 9x9"); @@ -1046,7 +1046,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Resized to 9x9"); // testData.printBuffer("Expect for 9x9"); @@ -1106,7 +1106,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Area Resized to 6x6"); // expected.printBuffer("Area Expect for 6x6"); @@ -1132,7 +1132,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Area Resized to 6x6"); // expected.printBuffer("Area Expect for 6x6"); @@ -1159,7 +1159,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Area Resized to 6x6"); // expected.printBuffer("Area Expect for 6x6"); @@ -1196,7 +1196,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Area Resized to 6x6"); // expected.printBuffer("Area Expect for 6x6"); @@ -1233,7 +1233,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Area Resized to 6x6"); // expected.printBuffer("Area Expect for 6x6"); @@ -1270,7 +1270,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Area Resized to 6x6"); // expected.printBuffer("Area Expect for 6x6"); @@ -1307,7 +1307,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Area Resized to 6x6"); // expected.printBuffer("Area Expect for 6x6"); @@ -1336,7 +1336,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Area Resized to 6x6"); // expected.printBuffer("Area Expect for 6x6"); @@ -1362,7 +1362,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Area Resized to 10x10"); // expected.printBuffer("Area Expect for 6x6"); @@ -1388,7 +1388,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Area Resized to 10x10"); // expected.printBuffer("Area Expect for 6x6"); @@ -1414,7 +1414,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test11) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Area Resized to 6x9"); // expected.printBuffer("Area Expect for 6x6"); @@ -1440,7 +1440,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test12) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Area Resized to 6x9"); // expected.printBuffer("Area Expect for 6x6"); @@ -1466,7 +1466,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Area Resized to 8x8"); // expected.printBuffer("Area Expect for 6x6"); @@ -1496,7 +1496,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test14) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Area Resized to 8x7"); // expected.printBuffer("Area Expect for 8x7"); ASSERT_TRUE(expected.isSameShape(result)); @@ -1524,7 +1524,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result.printBuffer("Area Resized to 8x7"); // expected.printBuffer("Area Expect for 8x7"); ASSERT_TRUE(expected.isSameShape(result)); @@ -2023,8 +2023,8 @@ TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); auto z = res.at(0); - z->printIndexedBuffer("L matrix is"); - exp.printIndexedBuffer("L expected is"); + //z->printIndexedBuffer("L matrix is"); + //exp.printIndexedBuffer("L expected is"); ASSERT_TRUE(exp.equalsTo(z)); @@ -2054,7 +2054,7 @@ TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2_2) { // z->printIndexedBuffer("L matrix is"); // exp.printIndexedBuffer("L expected is"); - MmulHelper::matmul(z, z, &exp, false, true); + MmulHelper::matmul(&z, &z, &exp, false, true); ASSERT_TRUE(exp.equalsTo(a)); } @@ -2087,8 +2087,8 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) { ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } @@ -2110,7 +2110,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -2137,16 +2137,16 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } @@ -2168,7 +2168,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -2196,16 +2196,16 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } @@ -2227,7 +2227,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -2252,7 +2252,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -2284,16 +2284,16 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } @@ -2318,16 +2318,16 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } @@ -2349,7 +2349,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -2374,7 +2374,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test11) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -2406,16 +2406,16 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test12) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } @@ -2442,16 +2442,16 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test13) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } @@ -2519,8 +2519,8 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test1) { ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// @@ -2541,7 +2541,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -2568,16 +2568,16 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } @@ -2599,7 +2599,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -2627,16 +2627,16 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } @@ -2658,7 +2658,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -2683,7 +2683,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -2715,16 +2715,16 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } @@ -2749,16 +2749,16 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } @@ -2780,7 +2780,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -2805,7 +2805,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test11) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -2837,16 +2837,16 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test12) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } @@ -2873,16 +2873,16 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test13) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } @@ -2902,7 +2902,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); auto res = results.at(0); - ASSERT_TRUE(res->equalsTo(exp)); + ASSERT_TRUE(res.equalsTo(exp)); } @@ -2922,7 +2922,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); auto res = results.at(0); - ASSERT_TRUE(res->equalsTo(exp)); + ASSERT_TRUE(res.equalsTo(exp)); } @@ -2942,7 +2942,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); auto res = results.at(0); - ASSERT_TRUE(res->equalsTo(exp)); + ASSERT_TRUE(res.equalsTo(exp)); } /////////////////////////////////////////////////////////////////// @@ -2968,9 +2968,9 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3003,9 +3003,9 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3038,9 +3038,9 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3069,7 +3069,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -3099,9 +3099,9 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3130,7 +3130,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -3155,7 +3155,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -3188,9 +3188,9 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3223,9 +3223,9 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3254,7 +3254,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -3279,7 +3279,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test11) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdw = results.at(1); + auto dLdw = results.at(1); ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); @@ -3313,9 +3313,9 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test12) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3350,9 +3350,9 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test13) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3379,7 +3379,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); auto res = results.at(0); - ASSERT_TRUE(res->equalsTo(exp)); + ASSERT_TRUE(res.equalsTo(exp)); } /////////////////////////////////////////////////////////////////// @@ -3398,7 +3398,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); auto res = results.at(0); - ASSERT_TRUE(res->equalsTo(exp)); + ASSERT_TRUE(res.equalsTo(exp)); } @@ -3418,7 +3418,7 @@ TEST_F(DeclarableOpsTests11, BFloat16_Test_6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); auto res = results.at(0); - ASSERT_TRUE(res->equalsTo(exp)); + ASSERT_TRUE(res.equalsTo(exp)); } ///////////////////////////////////////////////////////////////// @@ -3440,9 +3440,9 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3469,9 +3469,9 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3498,9 +3498,9 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3527,9 +3527,9 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3557,9 +3557,9 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3587,9 +3587,9 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3617,9 +3617,9 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3658,9 +3658,9 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); // dLdp->printIndexedBuffer(); @@ -3698,7 +3698,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); + auto dLdp = results.at(0); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3721,7 +3721,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); + auto dLdp = results.at(0); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3743,7 +3743,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); + auto dLdp = results.at(0); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3764,7 +3764,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); + auto dLdp = results.at(0); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3785,7 +3785,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); + auto dLdp = results.at(0); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3806,7 +3806,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); + auto dLdp = results.at(0); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3827,7 +3827,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); + auto dLdp = results.at(0); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3848,7 +3848,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); + auto dLdp = results.at(0); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3874,7 +3874,7 @@ TEST_F(DeclarableOpsTests11, Multiply_BP_Test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdo = results.at(0); + auto dLdo = results.at(0); ASSERT_TRUE(dLdpExp.isSameShape(dLdo)); ASSERT_TRUE(dLdpExp.equalsTo(dLdo)); @@ -3896,7 +3896,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); + auto dLdp = results.at(0); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3919,7 +3919,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); + auto dLdp = results.at(0); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3940,7 +3940,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); + auto dLdp = results.at(0); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3963,7 +3963,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); + auto dLdp = results.at(0); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -3984,7 +3984,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); + auto dLdp = results.at(0); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 615f95bbd946..9f788f3439d7 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -49,7 +49,7 @@ TEST_F(DeclarableOpsTests12, test_any_validation_1) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_EQ(x.dataType(), z->dataType()); + ASSERT_EQ(x.dataType(), z.dataType()); } @@ -74,9 +74,9 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -106,9 +106,9 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -116,8 +116,6 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test2) { ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - - } ///////////////////////////////////////////////////////////////// @@ -140,9 +138,9 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -174,9 +172,9 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -209,9 +207,9 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -243,9 +241,9 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -279,9 +277,9 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -315,9 +313,9 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -351,9 +349,9 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *dLdp = results.at(0); - auto *dLdw = results.at(1); - auto *dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); @@ -917,7 +915,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_1) { auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {5}); auto gradI = results.at(0); - ASSERT_EQ(*gradI, exp); + ASSERT_EQ(gradI, exp); } @@ -948,7 +946,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_2) { auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {2}); auto gradI = results.at(0); - ASSERT_EQ(*gradI, exp); + ASSERT_EQ(gradI, exp); } @@ -979,7 +977,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_3) { auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {7}); auto gradI = results.at(0); - ASSERT_EQ(*gradI, exp); + ASSERT_EQ(gradI, exp); } @@ -1010,7 +1008,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_4) { auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {12}); auto gradI = results.at(0); - ASSERT_EQ(*gradI, exp); + ASSERT_EQ(gradI, exp); } @@ -1033,7 +1031,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_5) { auto results = op.evaluate({&input, &gradO}, {1., 1., 0.5}, {2}); auto gradI = results.at(0); - ASSERT_EQ(*gradI, exp); + ASSERT_EQ(gradI, exp); } @@ -1052,7 +1050,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_6) { auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {10}); auto gradI = results.at(0); - ASSERT_EQ(*gradI, exp); + ASSERT_EQ(gradI, exp); } @@ -1109,7 +1107,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_9) { // for (int i = 0; i < exp.lengthOf(); ++i) // printf("%10.5f %10.5f\n", exp.e(i), gradI->e(i)); - ASSERT_EQ(*gradI, exp); + ASSERT_EQ(gradI, exp); } @@ -1126,7 +1124,7 @@ TEST_F(DeclarableOpsTests12, lrn_bp_10) { auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {1}); auto gradI = results.at(0); - ASSERT_EQ(*gradI, exp); + ASSERT_EQ(gradI, exp); } @@ -1147,7 +1145,7 @@ TEST_F(DeclarableOpsTests12, lrn_1) { auto results = op.evaluate({&input}, {1., 2., 0.5}, {2}); auto output = results.at(0); - ASSERT_EQ(*output, exp); + ASSERT_EQ(output, exp); } @@ -1162,7 +1160,7 @@ TEST_F(DeclarableOpsTests12, lrn_2) { auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5}); auto output = results.at(0); - ASSERT_EQ(*output, exp); + ASSERT_EQ(output, exp); } @@ -1177,7 +1175,7 @@ TEST_F(DeclarableOpsTests12, lrn_3) { auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5}); auto output = results.at(0); - ASSERT_EQ(*output, exp); + ASSERT_EQ(output, exp); } @@ -1192,7 +1190,7 @@ TEST_F(DeclarableOpsTests12, lrn_4) { auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0}); auto output = results.at(0); - ASSERT_EQ(*output, exp); + ASSERT_EQ(output, exp); } @@ -1207,7 +1205,7 @@ TEST_F(DeclarableOpsTests12, lrn_5) { auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0}); auto output = results.at(0); - ASSERT_EQ(*output, exp); + ASSERT_EQ(output, exp); } @@ -1249,7 +1247,7 @@ TEST_F(DeclarableOpsTests12, inTopK_2) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); //res.at(0)->printIndexedBuffer("IN_TOP_K output"); - ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } @@ -1375,7 +1373,7 @@ TEST_F(DeclarableOpsTests12, pad_tests1) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1402,7 +1400,7 @@ TEST_F(DeclarableOpsTests12, pad_tests2) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1429,7 +1427,7 @@ TEST_F(DeclarableOpsTests12, pad_tests3) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1460,7 +1458,7 @@ TEST_F(DeclarableOpsTests12, pad_tests4) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); // for(int i = 0; i < expected.lengthOf(); ++i) { @@ -1494,7 +1492,7 @@ TEST_F(DeclarableOpsTests12, pad_tests5) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1521,7 +1519,7 @@ TEST_F(DeclarableOpsTests12, pad_tests6) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1544,10 +1542,10 @@ TEST_F(DeclarableOpsTests12, pad_tests7) ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1570,10 +1568,10 @@ TEST_F(DeclarableOpsTests12, pad_tests8) ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1596,10 +1594,10 @@ TEST_F(DeclarableOpsTests12, pad_tests9) ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1621,7 +1619,7 @@ TEST_F(DeclarableOpsTests12, pad_tests10) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1643,7 +1641,7 @@ TEST_F(DeclarableOpsTests12, pad_tests11) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1672,7 +1670,7 @@ TEST_F(DeclarableOpsTests12, pad_tests12) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1694,7 +1692,7 @@ TEST_F(DeclarableOpsTests12, pad_tests13) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1715,7 +1713,7 @@ TEST_F(DeclarableOpsTests12, pad_tests14) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1736,7 +1734,7 @@ TEST_F(DeclarableOpsTests12, pad_tests15) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1757,7 +1755,7 @@ TEST_F(DeclarableOpsTests12, pad_tests16) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1778,7 +1776,7 @@ TEST_F(DeclarableOpsTests12, pad_tests17) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1799,7 +1797,7 @@ TEST_F(DeclarableOpsTests12, pad_tests18) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1820,7 +1818,7 @@ TEST_F(DeclarableOpsTests12, pad_tests19) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1841,7 +1839,7 @@ TEST_F(DeclarableOpsTests12, pad_tests20) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1864,7 +1862,7 @@ TEST_F(DeclarableOpsTests12, pad_tests21) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1887,7 +1885,7 @@ TEST_F(DeclarableOpsTests12, pad_tests22) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1911,7 +1909,7 @@ TEST_F(DeclarableOpsTests12, pad_tests23) { // result->printShapeInfo("r"); // expected.printShapeInfo("e"); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1933,7 +1931,7 @@ TEST_F(DeclarableOpsTests12, pad_tests24) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1955,7 +1953,7 @@ TEST_F(DeclarableOpsTests12, pad_tests25) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1977,7 +1975,7 @@ TEST_F(DeclarableOpsTests12, pad_tests26) { auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2147,7 +2145,7 @@ TEST_F(DeclarableOpsTests12, Pad_1) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2174,7 +2172,7 @@ TEST_F(DeclarableOpsTests12, Pad_2) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2201,7 +2199,7 @@ TEST_F(DeclarableOpsTests12, Pad_3) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2228,7 +2226,7 @@ TEST_F(DeclarableOpsTests12, Pad_4) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2255,7 +2253,7 @@ TEST_F(DeclarableOpsTests12, Pad_5) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2282,7 +2280,7 @@ TEST_F(DeclarableOpsTests12, Pad_6) { auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2305,10 +2303,10 @@ TEST_F(DeclarableOpsTests12, Pad_7) ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2331,10 +2329,10 @@ TEST_F(DeclarableOpsTests12, Pad_8) ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2357,10 +2355,10 @@ TEST_F(DeclarableOpsTests12, Pad_9) ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2740,13 +2738,13 @@ TEST_F(DeclarableOpsTests12, QR_Test_1) { // q->printShapeInfo("Q shape"); // r->printShapeInfo("R shape"); sd::ops::matmul opMul; - auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false); + auto res2 = opMul.evaluate({&q, &r}); //MmulHelper::matmul(q, r, &in, false, false); auto exp = res2.at(0);//->printIndexedBuffer("Result as result"); - ASSERT_TRUE(exp->isSameShape(in)); + ASSERT_TRUE(exp.isSameShape(in)); // ASSERT_TRUE(q->isSameShape(expQ)); //ASSERT_TRUE(expQ.equalsTo(q)); - ASSERT_TRUE(exp->equalsTo(in)); + ASSERT_TRUE(exp.equalsTo(in)); } @@ -2786,13 +2784,13 @@ TEST_F(DeclarableOpsTests12, QR_Test_1_1) { // q->printShapeInfo("Q shape"); // r->printShapeInfo("R shape"); sd::ops::matmul opMul; - auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false); + auto res2 = opMul.evaluate({&q, &r}); //MmulHelper::matmul(q, r, &in, false, false); auto exp = res2.at(0);//->printIndexedBuffer("Result as result"); - ASSERT_TRUE(exp->isSameShape(in)); + ASSERT_TRUE(exp.isSameShape(in)); // ASSERT_TRUE(q->isSameShape(expQ)); //ASSERT_TRUE(expQ.equalsTo(q)); - ASSERT_TRUE(exp->equalsTo(in)); + ASSERT_TRUE(exp.equalsTo(in)); } @@ -2810,14 +2808,14 @@ TEST_F(DeclarableOpsTests12, QR_Test_2) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); auto q = res.at(0); auto r = res.at(1); - ASSERT_TRUE(q->isSameShape(expQ)); - ASSERT_TRUE(r->isSameShape(expR)); + ASSERT_TRUE(q.isSameShape(expQ)); + ASSERT_TRUE(r.isSameShape(expR)); sd::ops::matmul opMul; - auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false); + auto res2 = opMul.evaluate({&q, &r}); //MmulHelper::matmul(q, r, &in, false, false); auto exp = res2.at(0);//->printIndexedBuffer("Result as result"); - ASSERT_TRUE(exp->isSameShape(in)); - ASSERT_TRUE(exp->equalsTo(in)); + ASSERT_TRUE(exp.isSameShape(in)); + ASSERT_TRUE(exp.equalsTo(in)); } @@ -2999,7 +2997,7 @@ TEST_F(DeclarableOpsTests12, SolveLs_Test_1) { auto z = res.at(0); // z->printIndexedBuffer("MatrixSolveLS"); - MmulHelper::matmul(&a, z, &exp, false, false); + MmulHelper::matmul(&a, &z, &exp, false, false); ASSERT_TRUE(exp.equalsTo(b)); @@ -3022,7 +3020,7 @@ TEST_F(DeclarableOpsTests12, SolveLs_Test_2) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); auto z = res.at(0); - MmulHelper::matmul(&a, z, &exp, false, false); + MmulHelper::matmul(&a, &z, &exp, false, false); // z->printIndexedBuffer("MatrixSolveLS2"); @@ -3048,7 +3046,7 @@ TEST_F(DeclarableOpsTests12, SolveLs_Test_3) { auto z = res.at(0); // z->printIndexedBuffer("MatrixSolveLS3"); - MmulHelper::matmul(&a, z, &exp, false, false); + MmulHelper::matmul(&a, &z, &exp, false, false); ASSERT_TRUE(exp.equalsTo(b)); } @@ -3090,7 +3088,7 @@ TEST_F(DeclarableOpsTests12, SolveLs_Test_5) { auto res = op.evaluate({&a, &b}, {false}); ASSERT_EQ(res.status(), ND4J_STATUS_OK); auto z = res.at(0); - ASSERT_TRUE(z->isEmpty()); + ASSERT_TRUE(z.isEmpty()); } @@ -3106,9 +3104,7 @@ TEST_F(DeclarableOpsTests12, Solve_Test_6) { auto res = op.evaluate({&a, &b}, {true}); ASSERT_EQ(res.status(), ND4J_STATUS_OK); auto z = res.at(0); - ASSERT_TRUE(z->isEmpty()); - - + ASSERT_TRUE(z.isEmpty()); } ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -3135,7 +3131,7 @@ TEST_F(DeclarableOpsTests12, TriangularSolve_Test_6) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); auto z = res.at(0); - z->printIndexedBuffer("TriangularSolve with adjoint"); + //z.printIndexedBuffer("TriangularSolve with adjoint"); ASSERT_TRUE(exp.equalsTo(z)); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 4b5a24bb912d..228c728fce9a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -63,7 +63,7 @@ TEST_F(DeclarableOpsTests13, test_pow_1) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, test_empty_range_1) { @@ -75,7 +75,7 @@ TEST_F(DeclarableOpsTests13, test_empty_range_1) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_TRUE(z->isEmpty()); + ASSERT_TRUE(z.isEmpty()); } @@ -87,7 +87,7 @@ TEST_F(DeclarableOpsTests13, test_empty_range_2) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_TRUE(z->isEmpty()); + ASSERT_TRUE(z.isEmpty()); } TEST_F(DeclarableOpsTests13, test_empty_range_3) { @@ -97,24 +97,21 @@ TEST_F(DeclarableOpsTests13, test_empty_range_3) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_TRUE(z->isEmpty()); + ASSERT_TRUE(z.isEmpty()); } TEST_F(DeclarableOpsTests13, test_argmax_edge_1) { - auto ctx = new Context(1); - auto arr = NDArrayFactory::create_('c', {1024,1}); + auto ctx = Context(1); + auto arr = NDArrayFactory::create('c', {1024,1}); - ctx->setInputArray(0, arr, true); - ctx->setOutputArray(0, NDArrayFactory::create_('c', {1}), true); - ctx->setInputArray(1, NDArrayFactory::create_(0), true); //Axis 0 + ctx.setInputArray(0, arr); + ctx.setOutputArray(0, NDArrayFactory::create('c', {1})); + ctx.setInputArray(1, NDArrayFactory::create(0)); //Axis 0 sd::ops::argmax op; - auto result = op.execute(ctx); + auto result = op.execute(&ctx); ASSERT_EQ(Status::OK(), result); - - //nd4j_printf("Done\n",""); - delete ctx; } TEST_F(DeclarableOpsTests13, test_add_1) { @@ -162,7 +159,7 @@ TEST_F(DeclarableOpsTests13, test_eval_reduction_shape_1) { auto z = result.at(0); - ASSERT_EQ(exp, *z); + ASSERT_EQ(exp, z); } TEST_F(DeclarableOpsTests13, test_or_1) { @@ -415,7 +412,7 @@ TEST_F(DeclarableOpsTests13, CellContains_test_1) { sd::ops::cell_contains op; auto result = op.evaluate({&corners, &width, &point}, {}, {5}); ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(result.at(0)->e(0)); + ASSERT_TRUE(result.at(0).e(0)); //result.at(2)->printBuffer("Symmetrized3"); //exp.printBuffer("EXPect symm3"); // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); @@ -623,7 +620,7 @@ TEST_F(DeclarableOpsTests13, shift_bits_1) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -640,7 +637,7 @@ TEST_F(DeclarableOpsTests13, rshift_bits_1) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -657,7 +654,7 @@ TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -674,7 +671,7 @@ TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_1) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -692,7 +689,7 @@ TEST_F(DeclarableOpsTests13, shift_bits_2) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -710,7 +707,7 @@ TEST_F(DeclarableOpsTests13, rshift_bits_2) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -729,7 +726,7 @@ TEST_F(DeclarableOpsTests13, cyclic_shift_bits_2) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -747,7 +744,7 @@ TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_2) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, shift_bits_3) { @@ -764,7 +761,7 @@ TEST_F(DeclarableOpsTests13, shift_bits_3) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -1167,8 +1164,8 @@ TEST_F(DeclarableOpsTests13, lstmLayer_1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *h = results.at(0); - auto *cL = results.at(2); + auto h = results.at(0); + auto cL = results.at(2); ASSERT_TRUE(expH.isSameShape(h)); ASSERT_TRUE(expH.equalsTo(h)); @@ -1234,8 +1231,8 @@ TEST_F(DeclarableOpsTests13, lstmLayer_2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *h = results.at(0); - auto *cL = results.at(2); + auto h = results.at(0); + auto cL = results.at(2); ASSERT_TRUE(expH.isSameShape(h)); ASSERT_TRUE(expH.equalsTo(h)); @@ -2088,7 +2085,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test1) { auto output = results.at(0); // output->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -2121,7 +2118,7 @@ TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test2) { auto output = results.at(0); // output->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); } @@ -2148,7 +2145,7 @@ TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test3) { auto output = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); } @@ -2175,7 +2172,7 @@ TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test4) { auto output = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); } @@ -2203,7 +2200,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test5) { auto output = results.at(0); // output->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); } @@ -2230,7 +2227,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test6) { auto output = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -2299,7 +2296,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test8) { auto output = results.at(0); - ASSERT_TRUE(expected.isSameShape(*output)); + ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); } @@ -2342,7 +2339,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test9) { auto output = results.at(0); // output->printBuffer(); - ASSERT_TRUE(expected.isSameShape(*output)); + ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -2378,13 +2375,13 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test1) { auto dLdG = results.at(3); auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); @@ -2421,13 +2418,13 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test2) { auto dLdG = results.at(3); auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } @@ -2462,13 +2459,13 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test3) { auto dLdG = results.at(3); auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } @@ -2500,13 +2497,13 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test4) { auto dLdG = results.at(3); auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } @@ -2542,13 +2539,13 @@ return; auto dLdG = results.at(3); auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } @@ -2585,13 +2582,13 @@ return; auto dLdG = results.at(3); auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } @@ -2633,13 +2630,13 @@ return; // dLdI->printBuffer(); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); @@ -2682,13 +2679,13 @@ return; // dLdI->printBuffer(); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } @@ -2733,13 +2730,13 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test9) { auto dLdG = results.at(3); auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } @@ -2784,13 +2781,13 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test10) { auto dLdG = results.at(3); auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } @@ -2847,13 +2844,13 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test11) { auto dLdG = results.at(3); auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp index dc6b0b550dcb..bfe9f60d4ab2 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp @@ -49,7 +49,7 @@ TEST_F(DeclarableOpsTests14, Test_Validation_Edge_1) { auto z = result.at(0); - ASSERT_EQ(exp, *z); + ASSERT_EQ(exp, z); } @@ -87,12 +87,8 @@ TEST_F(DeclarableOpsTests14, Multiply_test) { sd::ops::multiply op; auto result = op.evaluate({&x, &y}); auto f = result.at(0); - NDArray r = *f; - - ASSERT_EQ(e, r); - ASSERT_EQ(e, *f); - + ASSERT_EQ(e, f); } } @@ -106,7 +102,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -121,7 +117,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -175,7 +171,7 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_1) { auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *result.at(0)); + ASSERT_EQ(e, result.at(0)); } @@ -192,7 +188,7 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) { auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *result.at(0)); + ASSERT_EQ(e, result.at(0)); } @@ -206,7 +202,7 @@ TEST_F(DeclarableOpsTests14, test_empty_fill_1) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_EQ(y, *z); + ASSERT_EQ(y, z); } @@ -242,7 +238,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) { ASSERT_EQ(res2.status(), Status::OK()); auto out = res2.at(0); - ASSERT_EQ(out->e(0), DataTypeUtils::infOrMax()); + ASSERT_EQ(out.e(0), DataTypeUtils::infOrMax()); } @@ -254,7 +250,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) { ASSERT_EQ(res2.status(), Status::OK()); auto out = res2.at(0); - ASSERT_EQ(out->e(0), -DataTypeUtils::infOrMax()); + ASSERT_EQ(out.e(0), -DataTypeUtils::infOrMax()); } @@ -269,7 +265,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) { auto res2 = sumOp.evaluate({&e}, {1.}, {1}); ASSERT_EQ(res2.status(), Status::OK()); auto out = res2.at(0); - ASSERT_EQ(out->e(0), 0.f); + ASSERT_EQ(out.e(0), 0.f); } @@ -286,7 +282,7 @@ TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) { auto out = res2.at(0); // out->printShapeInfo("ReduceMean empty shape with keep dims"); // out->printIndexedBuffer("ReduceMean scalar"); - ASSERT_TRUE(std::isnan(out->e(0))); + ASSERT_TRUE(std::isnan(out.e(0))); } @@ -345,7 +341,7 @@ TEST_F(DeclarableOpsTests14, test_empty_argmax_1) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -373,7 +369,7 @@ TEST_F(DeclarableOpsTests14, test_empty_tanh_5) { auto z = result.at(0); ASSERT_TRUE(x.isSameShape(z)); - ASSERT_EQ(x, *z); + ASSERT_EQ(x, z); } @@ -482,7 +478,7 @@ TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) { auto result = op.evaluate({ &x, &y }); ASSERT_EQ(Status::OK(), result.status()); - auto res = *result.at(0); + auto res = result.at(0); ASSERT_EQ(e, res); @@ -503,7 +499,7 @@ TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) { auto result = op.evaluate({ &x, &y }); ASSERT_EQ(Status::OK(), result.status()); - auto res = *result.at(0); + auto res = result.at(0); ASSERT_EQ(e, res); @@ -809,11 +805,11 @@ TEST_F(DeclarableOpsTests14, matmul_test9) { TEST_F(DeclarableOpsTests14, matmul_test10) { - auto x = NDArrayFactory::create_('c', {3, 5}); - x->linspace(1); + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); - auto y = NDArrayFactory::create_('c', {5, 3}); - y->linspace(1); + auto y = NDArrayFactory::create('c', {5, 3}); + y.linspace(1); float _expB[]{135.0f, 310.0f, 485.0f, 150.0f, 350.0f, 550.0f, 165.0f, 390.0f, 615.0f}; Nd4jLong _expS[] {2, 3, 3, 1, 3, 0, 1, 102}; // expected shape @@ -823,7 +819,7 @@ TEST_F(DeclarableOpsTests14, matmul_test10) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, new Variable()); + variableSpace->putVariable(1, std::make_shared()); auto block = new Context(1, variableSpace, false); block->fillInputs({-1, -2}); @@ -836,7 +832,7 @@ TEST_F(DeclarableOpsTests14, matmul_test10) { auto result = variableSpace->getVariable(1)->getNDArray(); - ASSERT_TRUE(result->equalsTo(&exp)); + ASSERT_TRUE(result->equalsTo(exp)); delete block; delete variableSpace; @@ -967,7 +963,7 @@ TEST_F(DeclarableOpsTests14, matmul_test17) { auto result = op.evaluate({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(exp, *result.at(0)); + ASSERT_EQ(exp, result.at(0)); } @@ -1598,7 +1594,7 @@ TEST_F(DeclarableOpsTests14, Stack_1) { auto results = op.evaluate({&input1, &input2}, {}, {0}); auto output = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1626,7 +1622,7 @@ TEST_F(DeclarableOpsTests14, Stack_2) { auto results = op.evaluate({&input1, &input2}, {}, {1}); auto output = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1654,7 +1650,7 @@ TEST_F(DeclarableOpsTests14, Stack_3) { auto results = op.evaluate({&input1, &input2}, {}, {0}); auto output = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1681,7 +1677,7 @@ TEST_F(DeclarableOpsTests14, Stack_4) { auto results = op.evaluate({&input1, &input2}, {}, {1}); auto output = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1708,7 +1704,7 @@ TEST_F(DeclarableOpsTests14, Stack_5) { auto results = op.evaluate({&input1, &input2}, {}, {0}); auto output = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1735,7 +1731,7 @@ TEST_F(DeclarableOpsTests14, Stack_6) { auto results = op.evaluate({&input1, &input2}, {}, {1}); auto output = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1759,7 +1755,7 @@ TEST_F(DeclarableOpsTests14, Stack_7) { auto results = op.evaluate({&input1, &input1, &input1}, {}, {0}); auto output = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1782,7 +1778,7 @@ TEST_F(DeclarableOpsTests14, Stack_8) { auto results = op.evaluate({&input1, &input1, &input1}, {}, {0}); auto output = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1805,7 +1801,7 @@ TEST_F(DeclarableOpsTests14, Stack_9) { auto results = op.evaluate({&input1, &input1, &input1}, {}, {1}); auto output = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1831,7 +1827,7 @@ TEST_F(DeclarableOpsTests14, Stack_10) { //expected.printShapeInfo("exp"); //output->printShapeInfo("out"); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1853,7 +1849,7 @@ TEST_F(DeclarableOpsTests14, Stack_11) { auto results = op.evaluate({&input1, &input1, &input1}, {}, {}); auto output = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1995,13 +1991,13 @@ TEST_F(DeclarableOpsTests14, Stack_18) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); sd::ops::reduce_min sumOp; auto res2 = sumOp.evaluate({&e}, {1.}, {1}); ASSERT_EQ(res2.status(), Status::OK()); auto out = res2.at(0); - ASSERT_EQ(out->e(0), DataTypeUtils::infOrMax()); + ASSERT_EQ(out.e(0), DataTypeUtils::infOrMax()); } @@ -2015,7 +2011,7 @@ TEST_F(DeclarableOpsTests14, Stack_19) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -2029,7 +2025,7 @@ TEST_F(DeclarableOpsTests14, Stack_20) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -2053,10 +2049,10 @@ TEST_F(DeclarableOpsTests14, Stack_21) { auto outStack = resultStack.at(0); auto outConcat = resultConcat.at(0); - outConcat->reshapei({2,3,2}); + outConcat.reshapei({2,3,2}); - ASSERT_TRUE(outStack->isSameShape(outConcat)); - ASSERT_TRUE(outStack->equalsTo(outConcat)); + ASSERT_TRUE(outStack.isSameShape(outConcat)); + ASSERT_TRUE(outStack.equalsTo(outConcat)); } ////////////////////////////////////////////////////////////////////// @@ -2064,8 +2060,8 @@ TEST_F(DeclarableOpsTests14, Reshape1) { const std::vector xShape = { 5,4,3 }; const std::vector yShape = { 3,5,4 }; - auto x = NDArrayFactory::create_('f', xShape); - auto y = NDArrayFactory::create_('f', yShape); + auto x = NDArrayFactory::create('f', xShape); + auto y = NDArrayFactory::create('f', yShape); auto variableSpace = new VariableSpace(); @@ -2078,7 +2074,7 @@ TEST_F(DeclarableOpsTests14, Reshape1) { reshape.execute(block); - ASSERT_TRUE(x->isSameShape(y)); + ASSERT_TRUE(x.isSameShape(y)); delete variableSpace; delete block; @@ -2089,16 +2085,16 @@ TEST_F(DeclarableOpsTests14, Reshape2) { const std::vector xShape = { 5,4,3 }; const std::vector yShape = { 3,5,4 }; - auto x = NDArrayFactory::create_('c', xShape); - auto y = NDArrayFactory::create_('c', yShape); + auto x = NDArrayFactory::create('c', xShape); + auto y = NDArrayFactory::create('c', yShape); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, new Variable()); + variableSpace->putVariable(1, std::make_shared()); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1 }); - block->appendI(-y->ordering()); + block->appendI(-y.ordering()); block->appendI(3); block->appendI(5); block->appendI(4); @@ -2111,7 +2107,6 @@ TEST_F(DeclarableOpsTests14, Reshape2) { ASSERT_TRUE(result->isSameShape(y)); - delete y; delete block; delete variableSpace; } @@ -2162,7 +2157,7 @@ TEST_F(DeclarableOpsTests14, Reshape6) { auto z = result.at(0); - ASSERT_TRUE(z->isSameShape(exp)); + ASSERT_TRUE(z.isSameShape(exp)); } TEST_F(DeclarableOpsTests14, Reshape7) { @@ -2176,7 +2171,7 @@ TEST_F(DeclarableOpsTests14, Reshape7) { auto z = result.at(0); - ASSERT_TRUE(z->isSameShape(exp)); + ASSERT_TRUE(z.isSameShape(exp)); } TEST_F(DeclarableOpsTests14, Reshape8) { @@ -2203,7 +2198,7 @@ TEST_F(DeclarableOpsTests14, Reshape9) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Reshape10) { @@ -2258,7 +2253,7 @@ TEST_F(DeclarableOpsTests14, Reshape13) { ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(exp, *result.at(0)); + ASSERT_EQ(exp, result.at(0)); delete empty; } @@ -2275,7 +2270,7 @@ TEST_F(DeclarableOpsTests14, Reshape14) { auto z = result.at(0); ASSERT_TRUE(e.isSameShape(z)); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -2294,12 +2289,12 @@ TEST_F(DeclarableOpsTests14, Reshape15) { auto result0 = op.evaluate({&x0, &shape0}, {}, {}); ASSERT_EQ(Status::OK(), result0.status()); auto z0 = result0.at(0); - ASSERT_EQ(e0, *z0); + ASSERT_EQ(e0, z0); auto result1 = op.evaluate({&x1, &shape1}, {}, {}); ASSERT_EQ(Status::OK(), result1.status()); auto z1 = result1.at(0); - ASSERT_EQ(e1, *z1); + ASSERT_EQ(e1, z1); } TEST_F(DeclarableOpsTests14, Reshape16) { @@ -2379,45 +2374,45 @@ TEST_F(DeclarableOpsTests14, Reshape20) { auto result = op.evaluate({&x1}, {}, {2, -1}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({2,0})); + ASSERT_TRUE(result.at(0).isSameShape({2,0})); result = op.evaluate({&x2}, {}, {2, 0, -1}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({2,0,5})); + ASSERT_TRUE(result.at(0).isSameShape({2,0,5})); result = op.evaluate({&x2}, {}, {5, 2, -1}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({5,2,0})); + ASSERT_TRUE(result.at(0).isSameShape({5,2,0})); result = op.evaluate({&x2}, {}, {-1, 2, 0}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({5,2,0})); + ASSERT_TRUE(result.at(0).isSameShape({5,2,0})); result = op.evaluate({&x3}, {}, {2, 0, -1}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({2,0,10})); + ASSERT_TRUE(result.at(0).isSameShape({2,0,10})); result = op.evaluate({&x4}, {}, {2, -1, 0}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({2,5,0})); + ASSERT_TRUE(result.at(0).isSameShape({2,5,0})); result = op.evaluate({&x5}, {}, {2, 0, 0, 0, -1}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({2,0,0,0,10})); + ASSERT_TRUE(result.at(0).isSameShape({2,0,0,0,10})); result = op.evaluate({&x6}, {}, {-1, 2, 0}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({5, 2, 0})); + ASSERT_TRUE(result.at(0).isSameShape({5, 2, 0})); result = op.evaluate({&x7}, {}, {-1, 0}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({2, 0})); + ASSERT_TRUE(result.at(0).isSameShape({2, 0})); result = op.evaluate({&x7}, {}, {10,0,50,100}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100})); + ASSERT_TRUE(result.at(0).isSameShape({10,0,50,100})); result = op.evaluate({&x7}, {}, {2,0,-1}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0)->isSameShape({2,0,1})); + ASSERT_TRUE(result.at(0).isSameShape({2,0,1})); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 5fffa73c549c..df3061f16e1f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -412,7 +412,7 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) { auto out = result.at(0); // out->printBuffer("Adjusted Constrast7"); // e.printBuffer("Adjusted expected 7"); - auto diff = e - *out; + auto diff = e - out; // diff.printBuffer("Adjusted subtract 7"); ASSERT_TRUE(e.equalsTo(out)); @@ -486,7 +486,7 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_4_1) { // e.printIndexedBuffer("Double to int64"); auto res = result.at(0); - ASSERT_EQ(*res, e); + ASSERT_EQ(res, e); } @@ -570,8 +570,8 @@ TEST_F(DeclarableOpsTests15, test_non_decreasing_1) { sd::ops::is_non_decreasing op; Context ctx(1); - ctx.setInputArray(0, &x); - ctx.setOutputArray(0, &z); + ctx.setInputArray(0, x); + ctx.setOutputArray(0, z); auto status = op.execute(&ctx); ASSERT_EQ(Status::OK(), status); @@ -588,7 +588,7 @@ TEST_F(DeclarableOpsTests15, test_check_numeric_1) { auto z = result.at(0); - ASSERT_EQ(x, *z); + ASSERT_EQ(x, z); } TEST_F(DeclarableOpsTests15, test_check_numeric_2) { @@ -686,8 +686,8 @@ TEST_F(DeclarableOpsTests15, test_hashCode_1) { // resultA0->at(0)->printIndexedBuffer("A0"); // resultA1->at(0)->printIndexedBuffer("A1"); // resultB0->at(0)->printIndexedBuffer("B0"); - ASSERT_EQ(*resultA0.at(0), *resultA1.at(0)); - ASSERT_NE(*resultA0.at(0), *resultB0.at(0)); + ASSERT_EQ(resultA0.at(0), resultA1.at(0)); + ASSERT_NE(resultA0.at(0), resultB0.at(0)); } TEST_F(DeclarableOpsTests15, test_hashCode_2) { @@ -706,8 +706,8 @@ TEST_F(DeclarableOpsTests15, test_hashCode_2) { // resultA1->at(0)->printIndexedBuffer("A1"); // resultB0->at(0)->printIndexedBuffer("B0"); - ASSERT_EQ(*resultA0.at(0), *resultA1.at(0)); - ASSERT_NE(*resultA0.at(0), *resultB0.at(0)); + ASSERT_EQ(resultA0.at(0), resultA1.at(0)); + ASSERT_NE(resultA0.at(0), resultB0.at(0)); } TEST_F(DeclarableOpsTests15, test_rank_1) { @@ -731,7 +731,7 @@ TEST_F(DeclarableOpsTests15, test_rank_2) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -810,8 +810,8 @@ TEST_F(DeclarableOpsTests15, test_empty_increasing_1) { auto z = NDArrayFactory::create(false); Context ctx(1); - ctx.setInputArray(0, &x); - ctx.setOutputArray(0, &z); + ctx.setInputArray(0, x); + ctx.setOutputArray(0, z); sd::ops::is_strictly_increasing op; auto status = op.execute(&ctx); @@ -825,8 +825,8 @@ TEST_F(DeclarableOpsTests15, test_empty_decreasing_1) { auto z = NDArrayFactory::create(false); Context ctx(1); - ctx.setInputArray(0, &x); - ctx.setOutputArray(0, &z); + ctx.setInputArray(0, x); + ctx.setOutputArray(0, z); sd::ops::is_non_decreasing op; auto status = op.execute(&ctx); @@ -1197,8 +1197,8 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* dLdx = results.at(0); - auto* dLdy = results.at(1); + auto dLdx = results.at(0); + auto dLdy = results.at(1); ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); @@ -1223,8 +1223,8 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test2) { auto results = op.evaluate({ &x, &y, &dLdz }, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* dLdx = results.at(0); - auto* dLdy = results.at(1); + auto dLdx = results.at(0); + auto dLdy = results.at(1); ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); @@ -1252,8 +1252,8 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test3) { ASSERT_EQ(ND4J_STATUS_OK, resultsY.status()); - auto* dLdxY = resultsY.at(0); - auto* dLdyY = resultsY.at(1); + auto dLdxY = resultsY.at(0); + auto dLdyY = resultsY.at(1); ASSERT_TRUE(dLdxExpY.isSameShape(dLdxY)); ASSERT_TRUE(dLdxExpY.equalsTo(dLdxY)); @@ -1282,8 +1282,8 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test4) { ASSERT_EQ(ND4J_STATUS_OK, resultsX.status()); - auto* dLdxX = resultsX.at(0); - auto* dLdyX = resultsX.at(1); + auto dLdxX = resultsX.at(0); + auto dLdyX = resultsX.at(1); ASSERT_TRUE(dLdxExpX.isSameShape(dLdxX)); ASSERT_TRUE(dLdxExpX.equalsTo(dLdxX)); @@ -1311,8 +1311,8 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test5) { auto results = op.evaluate({ &xConst, &yConst, &dLdz }, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* dLdx = results.at(0); - auto* dLdy = results.at(1); + auto dLdx = results.at(0); + auto dLdy = results.at(1); ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); @@ -1339,8 +1339,8 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test6) { auto resultsXC = op.evaluate({ &xConst, &y, &dLdzC }, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsXC.status()); - auto* dLdxXC = resultsXC.at(0); - auto* dLdyXC = resultsXC.at(1); + auto dLdxXC = resultsXC.at(0); + auto dLdyXC = resultsXC.at(1); ASSERT_TRUE(dLdxExpXC.isSameShape(dLdxXC)); ASSERT_TRUE(dLdxExpXC.equalsTo(dLdxXC)); @@ -1367,8 +1367,8 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test7) { auto resultsYs = op.evaluate({ &x, &Y, &dLdzC }, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsYs.status()); - auto* dLdxY = resultsYs.at(0); - auto* dLdyY = resultsYs.at(1); + auto dLdxY = resultsYs.at(0); + auto dLdyY = resultsYs.at(1); ASSERT_TRUE(dLdxExpYs.isSameShape(dLdxY)); ASSERT_TRUE(dLdxExpYs.equalsTo(dLdxY)); @@ -1392,8 +1392,8 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* dLdx = results.at(0); - auto* dLdy = results.at(1); + auto dLdx = results.at(0); + auto dLdy = results.at(1); ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); @@ -1420,8 +1420,8 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test9) { auto results = op.evaluate({ &x, &y, &dLdz }, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* dLdx = results.at(0); - auto* dLdy = results.at(1); + auto dLdx = results.at(0); + auto dLdy = results.at(1); ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); @@ -1448,8 +1448,8 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test10) { ASSERT_EQ(ND4J_STATUS_OK, resultsB.status()); - auto* dLdxB = resultsB.at(0); - auto* dLdyB = resultsB.at(1); + auto dLdxB = resultsB.at(0); + auto dLdyB = resultsB.at(1); ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); ASSERT_TRUE(dLdxExpB.equalsTo(dLdxB)); @@ -1476,19 +1476,19 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test11) { auto resultsB = op.evaluate({ &xB, &yB, &dLdzB }, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsB.status()); - auto* dLdxB = resultsB.at(0); - auto* dLdyB = resultsB.at(1); + auto dLdxB = resultsB.at(0); + auto dLdyB = resultsB.at(1); ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); - for (int i = 0; i < dLdxB->lengthOf(); ++i) { - if (!sd::math::nd4j_isnan(dLdxB->e(i)) && !sd::math::nd4j_isnan(dLdxExpB.e(i))) - ASSERT_NEAR(dLdxB->e(i), dLdxExpB.e(i), 0.00001); + for (int i = 0; i < dLdxB.lengthOf(); ++i) { + if (!sd::math::nd4j_isnan(dLdxB.e(i)) && !sd::math::nd4j_isnan(dLdxExpB.e(i))) + ASSERT_NEAR(dLdxB.e(i), dLdxExpB.e(i), 0.00001); } ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB)); - for (int i = 0; i < dLdyB->lengthOf(); ++i) { - if (!sd::math::nd4j_isnan(dLdyB->e(i)) && !sd::math::nd4j_isnan(dLdyExpB.e(i))) - ASSERT_NEAR(dLdyB->e(i), dLdyExpB.e(i), 0.00001); + for (int i = 0; i < dLdyB.lengthOf(); ++i) { + if (!sd::math::nd4j_isnan(dLdyB.e(i)) && !sd::math::nd4j_isnan(dLdyExpB.e(i))) + ASSERT_NEAR(dLdyB.e(i), dLdyExpB.e(i), 0.00001); } @@ -1509,14 +1509,14 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) { ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); + ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); - ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// @@ -1531,14 +1531,14 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP2) { ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(B.isSameShape(*dLdAbp)); - ASSERT_TRUE(B.equalsTo(*dLdAbp)); + ASSERT_TRUE(B.isSameShape(dLdAbp)); + ASSERT_TRUE(B.equalsTo(dLdAbp)); - ASSERT_TRUE(A.isSameShape(*dLdBbp)); - ASSERT_TRUE(A.equalsTo(*dLdBbp)); + ASSERT_TRUE(A.isSameShape(dLdBbp)); + ASSERT_TRUE(A.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// @@ -1557,14 +1557,14 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP3) { ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dA.equalsTo(*dLdAbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); - ASSERT_TRUE(dB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// @@ -1583,14 +1583,14 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP4) { ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); + ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); - ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// @@ -1609,14 +1609,14 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP5) { ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); + ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); - ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// @@ -1632,14 +1632,14 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP6) { ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(B.isSameShape(*dLdAbp)); - ASSERT_TRUE(B.equalsTo(*dLdAbp)); + ASSERT_TRUE(B.isSameShape(dLdAbp)); + ASSERT_TRUE(B.equalsTo(dLdAbp)); - ASSERT_TRUE(A.isSameShape(*dLdBbp)); - ASSERT_TRUE(A.equalsTo(*dLdBbp)); + ASSERT_TRUE(A.isSameShape(dLdBbp)); + ASSERT_TRUE(A.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// @@ -1657,14 +1657,14 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP7) { auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); + ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); - ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// @@ -1683,14 +1683,14 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP8) { ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); + ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); - ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// @@ -1709,14 +1709,14 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP9) { ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dA.equalsTo(*dLdAbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); - ASSERT_TRUE(dB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// @@ -1736,14 +1736,14 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP10) { ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dA.equalsTo(*dLdAbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); - ASSERT_TRUE(dB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// @@ -1763,14 +1763,14 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP11) { ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dA.equalsTo(*dLdAbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); - ASSERT_TRUE(dB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// @@ -1791,14 +1791,14 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP12) { ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dA.equalsTo(*dLdAbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); - ASSERT_TRUE(dB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// @@ -1819,14 +1819,14 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP13) { ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dA.equalsTo(*dLdAbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); - ASSERT_TRUE(dB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// @@ -1851,14 +1851,14 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP14) { ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto* dLdAbp = resultsBP.at(0); - auto* dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dA.isSameShape(*dLdAbp)); - ASSERT_TRUE(dA.equalsTo(*dLdAbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); - ASSERT_TRUE(dB.isSameShape(*dLdBbp)); - ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// @@ -1877,14 +1877,14 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP15) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* dLdA = results.at(0); - auto* dLdB = results.at(1); + auto dLdA = results.at(0); + auto dLdB = results.at(1); - ASSERT_TRUE(dA.isSameShape(*dLdA)); - ASSERT_TRUE(dA.equalsTo(*dLdA)); + ASSERT_TRUE(dA.isSameShape(dLdA)); + ASSERT_TRUE(dA.equalsTo(dLdA)); - ASSERT_TRUE(dB.isSameShape(*dLdB)); - ASSERT_TRUE(dB.equalsTo(*dLdB)); + ASSERT_TRUE(dB.isSameShape(dLdB)); + ASSERT_TRUE(dB.equalsTo(dLdB)); } ////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index c909b7686a10..77262e0f747d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -51,7 +51,7 @@ TEST_F(DeclarableOpsTests16, scatter_upd_1) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests16, scatter_upd_2) { @@ -69,7 +69,7 @@ TEST_F(DeclarableOpsTests16, scatter_upd_2) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests16, scatter_upd_3) { @@ -136,7 +136,7 @@ TEST_F(DeclarableOpsTests16, test_hamming_distance_1) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) { @@ -162,7 +162,7 @@ TEST_F(DeclarableOpsTests16, test_empty_cast_1) { sd::ops::cast op; auto result = op.evaluate({&x}, {10}); ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *result.at(0)); + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests16, test_range_1) { @@ -171,7 +171,7 @@ TEST_F(DeclarableOpsTests16, test_range_1) { Context ctx(1); ctx.setTArguments({ -1.0, 1.0, 0.01 }); - ctx.setOutputArray(0, &z); + ctx.setOutputArray(0, z); auto status = op.execute(&ctx); ASSERT_EQ(Status::OK(), status); @@ -214,8 +214,8 @@ TEST_F(DeclarableOpsTests16, test_reverse_1) { auto listE = exp.allTensorsAlongDimension({ 1 }); for (int e = 0; e < r; e++) { - listI.at(e)->assign(rowOriginal); - listE.at(e)->assign(rowReversed); + listI.at(e).assign(rowOriginal); + listE.at(e).assign(rowReversed); } sd::ops::reverse op; @@ -272,8 +272,8 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_1) { auto actual = NDArrayFactory::create('c', { 5,4,3 }); Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); sd::ops::rgb_to_hsv op; auto status = op.execute(&ctx); @@ -326,8 +326,8 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_2) { auto actual = NDArrayFactory::create('c', { 5,3,4 }); Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); ctx.setIArguments({ 1 }); sd::ops::rgb_to_hsv op; auto status = op.execute(&ctx); @@ -353,8 +353,8 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_3) { auto actual = NDArrayFactory::create('c', { 4, 3 }); Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); sd::ops::rgb_to_hsv op; auto status = op.execute(&ctx); @@ -380,8 +380,8 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_4) { auto actual = NDArrayFactory::create('c', { 3, 4 }); Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); ctx.setIArguments({ 0 }); sd::ops::rgb_to_hsv op; auto status = op.execute(&ctx); @@ -402,8 +402,8 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_5) { auto actual = NDArrayFactory::create('c', { 3 }); Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); sd::ops::rgb_to_hsv op; auto status = op.execute(&ctx); @@ -439,8 +439,8 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) { auto actual = NDArrayFactory::create('c', { 3 }); Context ctx(1); - ctx.setInputArray(0, &subArrRgbs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, subArrRgbs); + ctx.setOutputArray(0, actual); sd::ops::rgb_to_hsv op; auto status = op.execute(&ctx); @@ -484,8 +484,8 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_1) { auto actual = NDArrayFactory::create('c', { 5,4,3 }); Context ctx(1); - ctx.setInputArray(0, &hsvs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, hsvs); + ctx.setOutputArray(0, actual); sd::ops::hsv_to_rgb op; auto status = op.execute(&ctx); @@ -527,8 +527,8 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_2) { auto actual = NDArrayFactory::create('c', { 5,3,4 }); Context ctx(1); - ctx.setInputArray(0, &hsvs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, hsvs); + ctx.setOutputArray(0, actual); ctx.setIArguments({ 1 }); sd::ops::hsv_to_rgb op; auto status = op.execute(&ctx); @@ -552,8 +552,8 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_3) { auto actual = NDArrayFactory::create('c', { 4,3 }); Context ctx(1); - ctx.setInputArray(0, &hsvs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, hsvs); + ctx.setOutputArray(0, actual); sd::ops::hsv_to_rgb op; auto status = op.execute(&ctx); @@ -578,8 +578,8 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_4) { auto actual = NDArrayFactory::create('c', { 3, 4 }); Context ctx(1); - ctx.setInputArray(0, &hsvs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, hsvs); + ctx.setOutputArray(0, actual); ctx.setIArguments({ 0 }); sd::ops::hsv_to_rgb op; auto status = op.execute(&ctx); @@ -601,8 +601,8 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_5) { auto actual = NDArrayFactory::create('c', { 3 }); Context ctx(1); - ctx.setInputArray(0, &hsvs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, hsvs); + ctx.setOutputArray(0, actual); sd::ops::hsv_to_rgb op; auto status = op.execute(&ctx); @@ -638,8 +638,8 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) { #endif Context ctx(1); - ctx.setInputArray(0, &subArrHsvs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, subArrHsvs); + ctx.setOutputArray(0, actual); sd::ops::hsv_to_rgb op; auto status = op.execute(&ctx); @@ -698,8 +698,8 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_1) { auto actual = NDArrayFactory::create('c', { 5, 4, 3 }); Context ctx(1); - ctx.setInputArray(0, &rgb); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, rgb); + ctx.setOutputArray(0, actual); sd::ops::rgb_to_yiq op; auto status = op.execute(&ctx); @@ -746,8 +746,8 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) { auto actual = NDArrayFactory::create('c', { 5, 3, 4 }); Context ctx(1); - ctx.setInputArray(0, &rgb); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, rgb); + ctx.setOutputArray(0, actual); ctx.setIArguments({ 1 }); sd::ops::rgb_to_yiq op; auto status = op.execute(&ctx); @@ -776,8 +776,8 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) { auto actual = NDArrayFactory::create('c', { 4, 3 }); Context ctx(1); - ctx.setInputArray(0, &rgb); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, rgb); + ctx.setOutputArray(0, actual); sd::ops::rgb_to_yiq op; auto status = op.execute(&ctx); @@ -806,8 +806,8 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) { auto actual = NDArrayFactory::create('c', { 3, 4 }); Context ctx(1); - ctx.setInputArray(0, &rgb); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, rgb); + ctx.setOutputArray(0, actual); ctx.setIArguments({ 0 }); sd::ops::rgb_to_yiq op; auto status = op.execute(&ctx); @@ -829,8 +829,8 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_5) { auto actual = NDArrayFactory::create('c', { 3 }); Context ctx(1); - ctx.setInputArray(0, &rgbs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); sd::ops::rgb_to_yiq op; auto status = op.execute(&ctx); @@ -867,8 +867,8 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) { auto actual = NDArrayFactory::create('c', { 3 }); Context ctx(1); - ctx.setInputArray(0, &subArrRgbs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, subArrRgbs); + ctx.setOutputArray(0, actual); sd::ops::rgb_to_yiq op; auto status = op.execute(&ctx); @@ -910,8 +910,8 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) { auto actual = NDArrayFactory::create('c', { 5, 4, 3 }); Context ctx(1); - ctx.setInputArray(0, &yiqs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, yiqs); + ctx.setOutputArray(0, actual); sd::ops::yiq_to_rgb op; auto status = op.execute(&ctx); @@ -954,8 +954,8 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) { auto actual = NDArrayFactory::create('c', { 5, 3, 4 }); Context ctx(1); - ctx.setInputArray(0, &yiqs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, yiqs); + ctx.setOutputArray(0, actual); ctx.setIArguments({ 1 }); sd::ops::yiq_to_rgb op; auto status = op.execute(&ctx); @@ -980,8 +980,8 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) { auto actual = NDArrayFactory::create('c', { 4, 3 }); Context ctx(1); - ctx.setInputArray(0, &yiqs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, yiqs); + ctx.setOutputArray(0, actual); sd::ops::yiq_to_rgb op; auto status = op.execute(&ctx); @@ -1006,8 +1006,8 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) { auto actual = NDArrayFactory::create('c', { 3, 4 }); Context ctx(1); - ctx.setInputArray(0, &yiqs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, yiqs); + ctx.setOutputArray(0, actual); ctx.setIArguments({ 0 }); sd::ops::yiq_to_rgb op; auto status = op.execute(&ctx); @@ -1028,8 +1028,8 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) { auto actual = NDArrayFactory::create('c', { 3 }); Context ctx(1); - ctx.setInputArray(0, &yiqs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, yiqs); + ctx.setOutputArray(0, actual); sd::ops::yiq_to_rgb op; auto status = op.execute(&ctx); @@ -1066,8 +1066,8 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) { auto actual = NDArrayFactory::create('c', { 3 }); Context ctx(1); - ctx.setInputArray(0, &subArrYiqs); - ctx.setOutputArray(0, &actual); + ctx.setInputArray(0, subArrYiqs); + ctx.setOutputArray(0, actual); sd::ops::yiq_to_rgb op; auto status = op.execute(&ctx); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp index 1341312f8deb..08434b25b11e 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp @@ -84,7 +84,7 @@ TEST_F(DeclarableOpsTests17, test_compat_string_split_1) { ASSERT_TRUE(exp0.isSameShape(z0)); ASSERT_TRUE(exp1.isSameShape(z1)); - ASSERT_EQ(exp0, *z0); - ASSERT_EQ(exp1, *z1); + ASSERT_EQ(exp0, z0); + ASSERT_EQ(exp1, z1); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp index 1f36a8f2c77a..7d7a39cdac63 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp @@ -1186,7 +1186,7 @@ TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta3) { ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); - results = op.evaluate({ &grad, results.at(1), results.at(2) }, { 0.95, 1.0e-6 }, { }); + results = op.evaluate({ &grad, &results.at(1), &results.at(2) }, { 0.95, 1.0e-6 }, { }); NDArray update1C('c', { 1, 5 }, { 0.0045290622655332, 0.00452909666868751, 0.00452910303972733, 0.00452910526959756, 0.00452910630171004 }, DataType::FLOAT32); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp index f4847889b8fb..1661eed56fe2 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp @@ -45,9 +45,9 @@ TEST_F(DeclarableOpsTests2, gather_1) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto* output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); } @@ -63,9 +63,9 @@ TEST_F(DeclarableOpsTests2, gather_2) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto* output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); } @@ -83,9 +83,9 @@ TEST_F(DeclarableOpsTests2, gather_3) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto* output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); } @@ -101,9 +101,9 @@ TEST_F(DeclarableOpsTests2, gather_4) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto* output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); } @@ -120,9 +120,9 @@ TEST_F(DeclarableOpsTests2, gather_5) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto* output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -141,9 +141,9 @@ TEST_F(DeclarableOpsTests2, gather_6) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto* output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); } @@ -161,9 +161,9 @@ TEST_F(DeclarableOpsTests2, gather_7) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto* output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); } @@ -179,11 +179,11 @@ TEST_F(DeclarableOpsTests2, gather_8) { auto result = op.evaluate({&input, &indices}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto* output = result.at(0); + auto output = result.at(0); // output->printShapeInfo(); // output->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); } @@ -278,9 +278,9 @@ TEST_F(DeclarableOpsTests2, gather_13) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto* output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); ASSERT_TRUE(expected.equalsTo(output)); } @@ -539,7 +539,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -564,8 +564,8 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); - // result->printIndexedBuffer("ADL test2"); + auto result = results.at(0); + // result.printIndexedBuffer("ADL test2"); // expected.printIndexedBuffer("ADL expec"); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -590,7 +590,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -615,7 +615,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -640,7 +640,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -665,7 +665,7 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -688,10 +688,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 60.f); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 60.f); } @@ -711,10 +711,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 0.f); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 0.f); @@ -736,10 +736,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 60.); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 60.); @@ -761,10 +761,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 60.f); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 60.f); @@ -786,10 +786,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_11) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 1.f); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.f); @@ -811,10 +811,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_12) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 0.f); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 0.f); @@ -836,10 +836,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_13) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 1.f); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.f); @@ -863,10 +863,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_14) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 1.f); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.f); @@ -888,10 +888,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_15) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 2.f); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 2.f); @@ -917,10 +917,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_16) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 2.01667, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 2.01667, 1e-5); @@ -950,10 +950,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_17) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 1.93333, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.93333, 1e-5); @@ -983,10 +983,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_18) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 1.93333f, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.93333f, 1e-5); @@ -1009,10 +1009,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_19) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 1.); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.); @@ -1034,10 +1034,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_20) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 1.); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.); @@ -1059,10 +1059,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_21) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 1.f); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.f); @@ -1084,10 +1084,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_22) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 0.); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 0.); @@ -1121,10 +1121,10 @@ TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_23) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.965517, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.965517, 1e-5); @@ -1147,7 +1147,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1173,7 +1173,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1200,7 +1200,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1225,7 +1225,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1250,10 +1250,10 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -71.); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -71.); @@ -1275,10 +1275,10 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -71.f); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -71.f); @@ -1300,10 +1300,10 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -69.f); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -69.f); @@ -1325,10 +1325,10 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -24.f); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -24.f); @@ -1350,10 +1350,10 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -24.); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -24.); @@ -1377,10 +1377,10 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -32.); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -32.); @@ -1402,8 +1402,8 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); - // result->printBuffer(); + auto result = results.at(0); + // result.printBuffer(); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1427,8 +1427,8 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); - // result->printBuffer(); + auto result = results.at(0); + // result.printBuffer(); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1452,8 +1452,8 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); - // result->printBuffer(); + auto result = results.at(0); + // result.printBuffer(); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1477,10 +1477,10 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 83.); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 83.); } @@ -1501,10 +1501,10 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 83.); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 83.); } @@ -1525,10 +1525,10 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 83.); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 83.); } @@ -1549,10 +1549,10 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 6.91667, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 6.91667, 1e-5); } @@ -1573,10 +1573,10 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 6.91667, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 6.91667, 1e-5); } @@ -1597,10 +1597,10 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 6.91667, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 6.91667, 1e-5); } @@ -1621,10 +1621,10 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 3.45833, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 3.45833, 1e-5); } @@ -1645,10 +1645,10 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test11) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 3.45833, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 3.45833, 1e-5); } @@ -1673,10 +1673,10 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test12) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 3.975, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 3.975, 1e-5); } @@ -1697,10 +1697,10 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test13) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == 0.); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 0.); } @@ -1722,7 +1722,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1747,7 +1747,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1772,7 +1772,7 @@ TEST_F(DeclarableOpsTests2, huber_loss_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1796,10 +1796,10 @@ TEST_F(DeclarableOpsTests2, huber_loss_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 13.44, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 13.44, 1e-5); } @@ -1820,10 +1820,10 @@ TEST_F(DeclarableOpsTests2, huber_loss_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 13.44, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 13.44, 1e-5); } @@ -1844,10 +1844,10 @@ TEST_F(DeclarableOpsTests2, huber_loss_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 1.12, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.12, 1e-5); } @@ -1868,10 +1868,10 @@ TEST_F(DeclarableOpsTests2, huber_loss_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 1.12, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.12, 1e-5); } @@ -1896,10 +1896,10 @@ TEST_F(DeclarableOpsTests2, huber_loss_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 1.3, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.3, 1e-5); } @@ -1920,10 +1920,10 @@ TEST_F(DeclarableOpsTests2, huber_loss_test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.56, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.56, 1e-5); } @@ -1944,10 +1944,10 @@ TEST_F(DeclarableOpsTests2, huber_loss_test10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.56, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.56, 1e-5); } @@ -1972,10 +1972,10 @@ TEST_F(DeclarableOpsTests2, huber_loss_test11) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.65, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.65, 1e-5); } @@ -1997,7 +1997,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2022,7 +2022,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2047,7 +2047,7 @@ TEST_F(DeclarableOpsTests2, log_loss_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2071,10 +2071,10 @@ TEST_F(DeclarableOpsTests2, log_loss_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -113.886429, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -113.886429, 1e-5); } @@ -2095,10 +2095,10 @@ TEST_F(DeclarableOpsTests2, log_loss_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -113.886429, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -113.886429, 1e-5); } @@ -2119,10 +2119,10 @@ TEST_F(DeclarableOpsTests2, log_loss_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -113.886429, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -113.886429, 1e-5); } @@ -2143,10 +2143,10 @@ TEST_F(DeclarableOpsTests2, log_loss_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -9.490536, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -9.490536, 1e-5); } @@ -2167,10 +2167,10 @@ TEST_F(DeclarableOpsTests2, log_loss_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -9.490536, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -9.490536, 1e-5); } @@ -2191,10 +2191,10 @@ TEST_F(DeclarableOpsTests2, log_loss_test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -9.490536, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -9.490536, 1e-5); } @@ -2219,10 +2219,10 @@ TEST_F(DeclarableOpsTests2, log_loss_test10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -12.443609, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -12.443609, 1e-5); } @@ -2243,10 +2243,10 @@ TEST_F(DeclarableOpsTests2, log_loss_test11) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -4.745268, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -4.745268, 1e-5); } @@ -2267,10 +2267,10 @@ TEST_F(DeclarableOpsTests2, log_loss_test12) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -4.745268, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -4.745268, 1e-5); } @@ -2295,10 +2295,10 @@ TEST_F(DeclarableOpsTests2, log_loss_test13) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -6.221805, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -6.221805, 1e-5); } @@ -2315,7 +2315,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2335,7 +2335,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2355,7 +2355,7 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2374,10 +2374,10 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 60.74394998193965, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 60.74394998193965, 1e-5); } @@ -2393,10 +2393,10 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 15.189082270182983, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 15.189082270182983, 1e-5); } @@ -2412,10 +2412,10 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 13.568564090650312, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 13.568564090650312, 1e-5); } @@ -2431,10 +2431,10 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 198.318201904499, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 198.318201904499, 1e-5); } @@ -2450,10 +2450,10 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 10.709003499121707, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 10.709003499121707, 1e-5); } @@ -2469,10 +2469,10 @@ TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 17.686067864414472, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 17.686067864414472, 1e-5); } @@ -2493,7 +2493,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2518,7 +2518,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2543,7 +2543,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2572,7 +2572,7 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2596,10 +2596,10 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 612.5, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 612.5, 1e-5); } @@ -2620,10 +2620,10 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 612.5, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 612.5, 1e-5); } @@ -2644,10 +2644,10 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 612.5, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 612.5, 1e-5); } @@ -2672,10 +2672,10 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 608.75, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 608.75, 1e-5); } @@ -2696,10 +2696,10 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 51.041668, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 51.041668, 1e-5); } @@ -2720,10 +2720,10 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 51.041668, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 51.041668, 1e-5); } @@ -2744,10 +2744,10 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test11) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 51.041668, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 51.041668, 1e-5); } @@ -2771,10 +2771,10 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test12) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 88.541664, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 88.541664, 1e-5); } @@ -2795,10 +2795,10 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test13) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 25.520834, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 25.520834, 1e-5); } @@ -2819,10 +2819,10 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test14) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 25.520834, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 25.520834, 1e-5); } @@ -2843,10 +2843,10 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test15) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 25.520834, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 25.520834, 1e-5); } @@ -2870,10 +2870,10 @@ TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test16) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 44.270832, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 44.270832, 1e-5); } @@ -2894,7 +2894,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2918,7 +2918,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2942,7 +2942,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2966,7 +2966,7 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -2989,10 +2989,10 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 11.2187976837, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 11.2187976837, 1e-5); } @@ -3012,10 +3012,10 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 11.2187976837, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 11.2187976837, 1e-5); } @@ -3035,10 +3035,10 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 11.2187976837, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 11.2187976837, 1e-5); } @@ -3058,10 +3058,10 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 10.2187976837, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 10.2187976837, 1e-5); } @@ -3085,10 +3085,10 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 6.06840181351, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 6.06840181351, 1e-5); } @@ -3108,10 +3108,10 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.934899806976, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.934899806976, 1e-5); } @@ -3131,10 +3131,10 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test11) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.934899806976, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.934899806976, 1e-5); } @@ -3154,10 +3154,10 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test12) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.851566493511, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.851566493511, 1e-5); } @@ -3180,10 +3180,10 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test13) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 1.01140034199, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.01140034199, 1e-5); } @@ -3203,10 +3203,10 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test14) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.467449903488, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.467449903488, 1e-5); } @@ -3226,10 +3226,10 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test15) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.467449903488, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.467449903488, 1e-5); } @@ -3249,10 +3249,10 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test16) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.425783246756, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.425783246756, 1e-5); } @@ -3275,10 +3275,10 @@ TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test17) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 0.505700170994, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.505700170994, 1e-5); } @@ -3299,7 +3299,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3322,7 +3322,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3346,7 +3346,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3370,7 +3370,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3394,7 +3394,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3417,10 +3417,10 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), 8.55521392822, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 8.55521392822, 1e-5); } @@ -3440,10 +3440,10 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -6.37014198303, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -6.37014198303, 1e-5); } @@ -3463,10 +3463,10 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -6.37014198303, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -6.37014198303, 1e-5); } @@ -3486,10 +3486,10 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -6.37014198303, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -6.37014198303, 1e-5); } @@ -3509,10 +3509,10 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -2.12338066101, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -2.12338066101, 1e-5); } @@ -3532,10 +3532,10 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test11) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -1.06169033051, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -1.06169033051, 1e-5); } @@ -3555,10 +3555,10 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test12) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result->isScalar()); - ASSERT_NEAR(result->e(0), -2.18880319595, 1e-5); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -2.18880319595, 1e-5); } @@ -3579,7 +3579,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test13) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3605,7 +3605,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test14) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3629,7 +3629,7 @@ TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test15) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *result = results.at(0); + auto result = results.at(0); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3671,8 +3671,8 @@ TEST_F(DeclarableOpsTests2, lstmCell_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -3716,8 +3716,8 @@ TEST_F(DeclarableOpsTests2, lstmCell_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -3761,8 +3761,8 @@ TEST_F(DeclarableOpsTests2, lstmCell_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -3806,8 +3806,8 @@ TEST_F(DeclarableOpsTests2, lstmCell_test4) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -3851,8 +3851,8 @@ TEST_F(DeclarableOpsTests2, lstmCell_test5) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -3896,8 +3896,8 @@ TEST_F(DeclarableOpsTests2, lstmCell_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -3941,8 +3941,8 @@ TEST_F(DeclarableOpsTests2, lstmCell_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -3987,8 +3987,8 @@ TEST_F(DeclarableOpsTests2, lstmCell_test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht,1e-4)); @@ -4032,8 +4032,8 @@ TEST_F(DeclarableOpsTests2, lstmCell_test9) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht,1e-4)); @@ -4077,8 +4077,8 @@ TEST_F(DeclarableOpsTests2, lstmCell_test10) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -4122,8 +4122,8 @@ TEST_F(DeclarableOpsTests2, lstmCell_test11) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -4167,8 +4167,8 @@ TEST_F(DeclarableOpsTests2, lstmCell_test12) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index eddef73b32da..e4883e7cc832 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -324,7 +324,7 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_3) { auto result = op.evaluate({&x}, {1.0}, {1}); auto z = result.at(0); - auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true); + auto zNorm1 = z.reduceAlongDimension(reduce::Norm2, {1}, true); auto exp = NDArrayFactory::create('c', {3, 1}, {1., 1., xNorm1.e(2)}); ASSERT_TRUE(exp.isSameShape(&zNorm1)); @@ -347,8 +347,8 @@ TEST_F(DeclarableOpsTests3, Test_ListDiff_1) { auto z0 = result.at(0); auto z1 = result.at(1); - z0->getDataBuffer()->syncToSpecial(true); // force sync - z1->getDataBuffer()->syncToSpecial(true); // force sync + z0.getDataBuffer()->syncToSpecial(true); // force sync + z1.getDataBuffer()->syncToSpecial(true); // force sync ASSERT_TRUE(exp0.isSameShape(z0)); ASSERT_TRUE(exp0.equalsTo(z0)); @@ -798,8 +798,8 @@ TEST_F(DeclarableOpsTests3, sruCell_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -833,8 +833,8 @@ TEST_F(DeclarableOpsTests3, sruCell_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -867,8 +867,8 @@ TEST_F(DeclarableOpsTests3, sruCell_test3) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(0); - auto *ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -906,7 +906,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(3); + auto ht = results.at(3); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -941,7 +941,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test2) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *ht = results.at(3); + auto ht = results.at(3); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -978,7 +978,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test3) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *ht = result.at(3); + auto ht = result.at(3); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -996,7 +996,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test1) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1016,7 +1016,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test2) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1035,7 +1035,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test3) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1055,7 +1055,7 @@ TEST_F(DeclarableOpsTests3, diag_test1) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1075,7 +1075,7 @@ TEST_F(DeclarableOpsTests3, diag_test2) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1115,7 +1115,7 @@ TEST_F(DeclarableOpsTests3, diag_test_col_vector) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1136,7 +1136,7 @@ TEST_F(DeclarableOpsTests3, diag_test3) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1156,7 +1156,7 @@ TEST_F(DeclarableOpsTests3, diag_test4) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1176,7 +1176,7 @@ TEST_F(DeclarableOpsTests3, diag_test5) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1196,7 +1196,7 @@ TEST_F(DeclarableOpsTests3, diag_test6) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1218,7 +1218,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test1) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1240,7 +1240,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test2) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1263,7 +1263,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test3) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1285,7 +1285,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test4) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1305,7 +1305,7 @@ TEST_F(DeclarableOpsTests3, diagPart_test1) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); // output->printBuffer(); ASSERT_TRUE(expected.isSameShape(output)); @@ -1326,7 +1326,7 @@ TEST_F(DeclarableOpsTests3, diagPart_test2) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1346,7 +1346,7 @@ TEST_F(DeclarableOpsTests3, diagPart_test3) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1396,7 +1396,7 @@ TEST_F(DeclarableOpsTests3, betainc_test2) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1421,7 +1421,7 @@ TEST_F(DeclarableOpsTests3, betainc_test3) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1446,7 +1446,7 @@ TEST_F(DeclarableOpsTests3, betainc_test4) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output, 1e-6)); @@ -1471,7 +1471,7 @@ TEST_F(DeclarableOpsTests3, betainc_test5) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output, 1e-6)); @@ -1496,7 +1496,7 @@ TEST_F(DeclarableOpsTests3, betainc_test6) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output, 1e-6)); @@ -1521,7 +1521,7 @@ TEST_F(DeclarableOpsTests3, betainc_test7) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output, 1e-6)); @@ -1546,7 +1546,7 @@ TEST_F(DeclarableOpsTests3, betainc_test8) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output, 1e-6)); @@ -1572,7 +1572,7 @@ TEST_F(DeclarableOpsTests3, betainc_test9) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1597,7 +1597,7 @@ TEST_F(DeclarableOpsTests3, betainc_test10) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1617,7 +1617,7 @@ TEST_F(DeclarableOpsTests3, betainc_test11) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1638,7 +1638,7 @@ TEST_F(DeclarableOpsTests3, betainc_test12) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1661,7 +1661,7 @@ TEST_F(DeclarableOpsTests3, zeta_test1) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1684,7 +1684,7 @@ TEST_F(DeclarableOpsTests3, zeta_test2) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1709,7 +1709,7 @@ TEST_F(DeclarableOpsTests3, zeta_test3) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1733,7 +1733,7 @@ TEST_F(DeclarableOpsTests3, zeta_test4) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1756,7 +1756,7 @@ TEST_F(DeclarableOpsTests3, zeta_test5) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1779,7 +1779,7 @@ TEST_F(DeclarableOpsTests3, zeta_test6) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1803,7 +1803,7 @@ TEST_F(DeclarableOpsTests3, zeta_test7) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1826,7 +1826,7 @@ TEST_F(DeclarableOpsTests3, zeta_test8) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *output = result.at(0); + auto output = result.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1851,7 +1851,7 @@ TEST_F(DeclarableOpsTests3, zeta_test9) { ASSERT_EQ(ND4J_STATUS_OK, results); - //auto *output = result.at(0); + //auto output = result.at(0); // z.printIndexedBuffer("Zeta output"); ASSERT_TRUE(expected.isSameShape(z)); ASSERT_TRUE(expected.equalsTo(z)); @@ -1876,7 +1876,7 @@ TEST_F(DeclarableOpsTests3, zeta_test10) { ASSERT_EQ(ND4J_STATUS_OK, results); - //auto *output = result.at(0); + //auto output = result.at(0); // z.printIndexedBuffer("Zeta output"); ASSERT_TRUE(expected.isSameShape(z)); ASSERT_TRUE(expected.equalsTo(z)); @@ -2023,9 +2023,9 @@ TEST_F(DeclarableOpsTests3, svd_test1) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expU.isSameShape(u)); @@ -2039,9 +2039,9 @@ TEST_F(DeclarableOpsTests3, svd_test1) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5); } } @@ -2059,9 +2059,9 @@ TEST_F(DeclarableOpsTests3, svd_test2) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expU.isSameShape(u)); @@ -2075,9 +2075,9 @@ TEST_F(DeclarableOpsTests3, svd_test2) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5); } } @@ -2095,9 +2095,9 @@ TEST_F(DeclarableOpsTests3, svd_test3) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expU.isSameShape(u)); @@ -2111,9 +2111,9 @@ TEST_F(DeclarableOpsTests3, svd_test3) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5f); + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5f); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5f); + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5f); } } @@ -2131,9 +2131,9 @@ TEST_F(DeclarableOpsTests3, svd_test4) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expU.isSameShape(u)); @@ -2147,9 +2147,9 @@ TEST_F(DeclarableOpsTests3, svd_test4) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5f); + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5f); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5f); + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5f); } } @@ -2167,9 +2167,9 @@ TEST_F(DeclarableOpsTests3, svd_test5) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expU.isSameShape(u)); @@ -2183,9 +2183,9 @@ TEST_F(DeclarableOpsTests3, svd_test5) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5f); + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5f); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5f); + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5f); } } @@ -2221,9 +2221,9 @@ TEST_F(DeclarableOpsTests3, svd_test6) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expU.isSameShape(u)); @@ -2237,9 +2237,9 @@ TEST_F(DeclarableOpsTests3, svd_test6) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5f); + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5f); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5f); + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5f); } } @@ -2259,7 +2259,7 @@ TEST_F(DeclarableOpsTests3, svd_test7) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *s = result.at(0); + auto s = result.at(0); ASSERT_TRUE(expS.equalsTo(s)); ASSERT_TRUE(expS.isSameShape(s)); @@ -2395,9 +2395,9 @@ TEST_F(DeclarableOpsTests3, svd_test7) { // ASSERT_EQ(ND4J_STATUS_OK, result.status()); -// auto *s = result.at(0); -// auto *u = result.at(1); -// auto *v = result.at(2); +// auto s = result.at(0); +// auto u = result.at(1); +// auto v = result.at(2); // ASSERT_TRUE(expS.isSameShape(s)); // ASSERT_TRUE(expU.isSameShape(u)); @@ -2458,9 +2458,9 @@ TEST_F(DeclarableOpsTests3, svd_test9) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expU.isSameShape(u)); @@ -2474,9 +2474,9 @@ TEST_F(DeclarableOpsTests3, svd_test9) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5); } } @@ -2516,9 +2516,9 @@ TEST_F(DeclarableOpsTests3, svd_test10) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto *s = result.at(0); - auto *u = result.at(1); - auto *v = result.at(2); + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expU.isSameShape(u)); @@ -2532,9 +2532,9 @@ TEST_F(DeclarableOpsTests3, svd_test10) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5); } } @@ -2576,9 +2576,9 @@ TEST_F(DeclarableOpsTests3, svd_test11) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5); } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index 3c3accf40457..2af6b480e2b6 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -337,7 +337,7 @@ TEST_F(DeclarableOpsTests4, avgpool2d_11) { } m /= c; - ASSERT_EQ(m, *z); + ASSERT_EQ(m, z); } @@ -394,7 +394,7 @@ TEST_F(DeclarableOpsTests4, avgpool2d_13) { const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); // auto z('c',{bS,iD,oH,oW}); @@ -411,7 +411,7 @@ TEST_F(DeclarableOpsTests4, avgpool2d_13) { ASSERT_EQ(ND4J_STATUS_OK, status); auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.isSameShape(*result)); delete variableSpace; @@ -436,7 +436,7 @@ TEST_F(DeclarableOpsTests4, avgpool2d_14) { const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); // auto z('c',{bS,iD,oH,oW}); @@ -454,7 +454,7 @@ TEST_F(DeclarableOpsTests4, avgpool2d_14) { auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); // result->printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.isSameShape(*result)); delete variableSpace; delete block; @@ -478,7 +478,7 @@ TEST_F(DeclarableOpsTests4, Avgpool2d_test15) { const int oW = (int) sd::math::nd4j_ceil(iW * 1.f / sW); - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); // auto z('c',{bS,iD,oH,oW}); @@ -496,7 +496,7 @@ TEST_F(DeclarableOpsTests4, Avgpool2d_test15) { auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); // result->printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.isSameShape(*result)); delete variableSpace; delete block; @@ -600,11 +600,11 @@ TEST_F(DeclarableOpsTests4, biasadd_bp_1) { auto gradI = result.at(0); auto gradB = result.at(1); - ASSERT_TRUE(gradI->isSameShape(gradO)); - ASSERT_TRUE(gradI->equalsTo(gradO)); + ASSERT_TRUE(gradI.isSameShape(gradO)); + ASSERT_TRUE(gradI.equalsTo(gradO)); - ASSERT_TRUE(gradB->isSameShape(expGradB)); - ASSERT_TRUE(gradB->equalsTo(expGradB)); + ASSERT_TRUE(gradB.isSameShape(expGradB)); + ASSERT_TRUE(gradB.equalsTo(expGradB)); } @@ -628,11 +628,11 @@ TEST_F(DeclarableOpsTests4, biasadd_bp_2) { auto gradI = result.at(0); auto gradB = result.at(1); - ASSERT_TRUE(gradI->isSameShape(gradO)); - ASSERT_TRUE(gradI->equalsTo(gradO)); + ASSERT_TRUE(gradI.isSameShape(gradO)); + ASSERT_TRUE(gradI.equalsTo(gradO)); - ASSERT_TRUE(gradB->isSameShape(expGradB)); - ASSERT_TRUE(gradB->equalsTo(expGradB)); + ASSERT_TRUE(gradB.isSameShape(expGradB)); + ASSERT_TRUE(gradB.equalsTo(expGradB)); } @@ -963,7 +963,7 @@ TEST_F(DeclarableOpsTests4, split_test6) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); for (int i = 0; i < numSplits; ++i) - ASSERT_TRUE(results.at(i)->isSameShape(expShape)); + ASSERT_TRUE(results.at(i).isSameShape(expShape)); } /////////////////////////////////////////////////////////////////// @@ -981,7 +981,7 @@ TEST_F(DeclarableOpsTests4, split_test7) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); for (int i = 0; i < numSplits; ++i) - ASSERT_TRUE(results.at(i)->isSameShape(expShape)); + ASSERT_TRUE(results.at(i).isSameShape(expShape)); } @@ -1182,7 +1182,7 @@ TEST_F(DeclarableOpsTests4, Test_Add_119) { auto z = result.at(0); - ASSERT_EQ(2, z->rankOf()); + ASSERT_EQ(2, z.rankOf()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1268,7 +1268,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_TRUE(z->isEmpty()); + ASSERT_TRUE(z.isEmpty()); } @@ -1287,7 +1287,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_4) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_TRUE(z->lengthOf() == 1); + ASSERT_TRUE(z.lengthOf() == 1); ASSERT_TRUE(exp.equalsTo(z)); } @@ -1749,9 +1749,9 @@ TEST_F(DeclarableOpsTests4, lstm_test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto *h = results.at(0); - auto *c = results.at(1); - auto cLast = (*c)({4,5,0,0,0,0},true); + auto h = results.at(0); + auto c = results.at(1); + auto cLast = c({4,5,0,0,0,0},true); ASSERT_TRUE(expH.isSameShape(h)); ASSERT_TRUE(expH.equalsTo(h)); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index 6ac9d34cdbf0..df0d9ad7998a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -216,7 +216,7 @@ TEST_F(DeclarableOpsTests5, Test_Boolean_diff_1) { sd::ops::less op; auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(result.at(0)->t(0), true); + ASSERT_EQ(result.at(0).t(0), true); } @@ -409,7 +409,7 @@ TEST_F(DeclarableOpsTests5, Identity_test2) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto z = result.at(0); - ASSERT_TRUE(z->equalsTo(eps)); + ASSERT_TRUE(z.equalsTo(eps)); } @@ -425,7 +425,7 @@ TEST_F(DeclarableOpsTests5, Log1p_test1) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto z = result.at(0); - ASSERT_TRUE(z->equalsTo(y)); + ASSERT_TRUE(z.equalsTo(y)); } @@ -462,8 +462,6 @@ TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } @@ -817,7 +815,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test8) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -1135,7 +1133,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test14) { auto z = results.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -1174,7 +1172,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_0) { ASSERT_TRUE(expI.equalsTo(i)); // repeat res again for (int cases = 0; cases < 100; ++cases) { - op.execute({&x}, std::vector{v, i}, {}, {1, 0}, {}); // without sorting + op.execute({&x}, {&v, &i}, {}, {1, 0}, {}); // without sorting } } @@ -1213,7 +1211,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_1) { ASSERT_TRUE(expI.equalsTo(i)); // repeat res again for (int cases = 0; cases < 100; ++cases) { - op.execute({&x}, std::vector{v, i}, {}, {1, 0}, {}); // without sorting + op.execute({&x}, {&v, &i}, {}, {1, 0}, {}); // without sorting } } @@ -1436,9 +1434,9 @@ TEST_F(DeclarableOpsTests5, Test_Moments_1) { // v->printIndexedBuffer("Result is "); // d->printIndexedBuffer("Result is "); - ASSERT_TRUE(v->isScalar()); - ASSERT_NEAR(expMean, v->e(0), inf); - ASSERT_NEAR(expDeviation, d->e(0), inf); + ASSERT_TRUE(v.isScalar()); + ASSERT_NEAR(expMean, v.e(0), inf); + ASSERT_NEAR(expDeviation, d.e(0), inf); } @@ -1464,11 +1462,11 @@ TEST_F(DeclarableOpsTests5, Test_Moments_2) { auto v = result.at(0); auto d = result.at(1); - ASSERT_TRUE(v->isVector()); - ASSERT_TRUE(d->isVector()); + ASSERT_TRUE(v.isVector()); + ASSERT_TRUE(d.isVector()); - ASSERT_TRUE(v->equalsTo(&expV)); - ASSERT_TRUE(d->equalsTo(&expD)); + ASSERT_TRUE(v.equalsTo(&expV)); + ASSERT_TRUE(d.equalsTo(&expD)); } @@ -1498,11 +1496,11 @@ TEST_F(DeclarableOpsTests5, Test_Moments_3) { auto v = result.at(0); auto d = result.at(1); - ASSERT_TRUE(v->isMatrix()); - ASSERT_TRUE(d->isMatrix()); + ASSERT_TRUE(v.isMatrix()); + ASSERT_TRUE(d.isMatrix()); - ASSERT_TRUE(v->equalsTo(&expV)); - ASSERT_TRUE(d->equalsTo(&expD)); + ASSERT_TRUE(v.equalsTo(&expV)); + ASSERT_TRUE(d.equalsTo(&expD)); } @@ -1525,8 +1523,8 @@ TEST_F(DeclarableOpsTests5, Test_Moments_4) { auto v = result.at(0); auto d = result.at(1); - ASSERT_TRUE(v->isMatrix()); - ASSERT_TRUE(d->isMatrix()); + ASSERT_TRUE(v.isMatrix()); + ASSERT_TRUE(d.isMatrix()); // v->printIndexedBuffer("v"); // expV.printIndexedBuffer("expV"); @@ -1534,8 +1532,8 @@ TEST_F(DeclarableOpsTests5, Test_Moments_4) { // d->printIndexedBuffer("d"); // expD.printIndexedBuffer("expD"); - ASSERT_TRUE(v->equalsTo(&expV)); - ASSERT_TRUE(d->equalsTo(&expD)); + ASSERT_TRUE(v.equalsTo(&expV)); + ASSERT_TRUE(d.equalsTo(&expD)); } @@ -1644,8 +1642,8 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test1) { auto output = results.at(0); bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) + for(int i = 0; i < output.lengthOf(); ++i) + if(output.e(i) == (float)0.) haveZeros = true; ASSERT_EQ(Status::OK(), results.status()); @@ -1684,8 +1682,8 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) { auto output = results.at(0); bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) + for(int i = 0; i < output.lengthOf(); ++i) + if(output.e(i) == (float)0.) haveZeros = true; ASSERT_EQ(Status::OK(), results.status()); @@ -1728,8 +1726,8 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test4) { ASSERT_EQ(Status::OK(), results.status()); auto output = results.at(0); bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) + for(int i = 0; i < output.lengthOf(); ++i) + if(output.e(i) == (float)0.) haveZeros = true; ASSERT_TRUE(input.isSameShape(output)); @@ -1750,8 +1748,8 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test5) { auto output = results.at(0); bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) + for(int i = 0; i < output.lengthOf(); ++i) + if(output.e(i) == (float)0.) haveZeros = true; ASSERT_EQ(Status::OK(), results.status()); @@ -1773,8 +1771,8 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test6) { auto output = results.at(0); bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) + for(int i = 0; i < output.lengthOf(); ++i) + if(output.e(i) == (float)0.) haveZeros = true; ASSERT_EQ(Status::OK(), results.status()); @@ -1826,8 +1824,8 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_1) { auto result = op.evaluate({&x, &y}, {}, {0}); auto output = result.at(0); // x.printShapeInfo("Input"); - output->printShapeInfo("Output"); - exp.printShapeInfo("Expected"); + //output->printShapeInfo("Output"); + //exp.printShapeInfo("Expected"); ASSERT_EQ(ND4J_STATUS_OK, result.status()); ASSERT_TRUE(exp.isSameShape(output)); //output->printIndexedBuffer("Output"); @@ -2047,7 +2045,7 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_3) { for (int e = 0; e < result.size(); e++) { auto output = result.at(e); - if (output) + if (output.shapeInfo()) { // output->printShapeInfo("Output shape> "); // exp[e].printShapeInfo("Expected shape> "); @@ -2397,8 +2395,8 @@ TEST_F(DeclarableOpsTests5, ZeroFraction_1) { auto res = op.evaluate({&x}, {}, {}); ASSERT_EQ(Status::OK(), res.status()); - ASSERT_TRUE(res.at(0)->isScalar()); - ASSERT_EQ(res.at(0)->e(0), 0.25); + ASSERT_TRUE(res.at(0).isScalar()); + ASSERT_EQ(res.at(0).e(0), 0.25); } @@ -2412,8 +2410,8 @@ TEST_F(DeclarableOpsTests5, ZeroFraction_2) { auto res = op.evaluate({&x}, {}, {}); ASSERT_EQ(Status::OK(), res.status()); - ASSERT_TRUE(res.at(0)->isScalar()); - ASSERT_EQ(res.at(0)->e(0), 0.375); + ASSERT_TRUE(res.at(0).isScalar()); + ASSERT_EQ(res.at(0).e(0), 0.375); } @@ -2427,8 +2425,8 @@ TEST_F(DeclarableOpsTests5, ZeroFraction_3) { auto res = op.evaluate({&x}, {}, {}); ASSERT_EQ(Status::OK(), res.status()); - ASSERT_TRUE(res.at(0)->isScalar()); - ASSERT_EQ(res.at(0)->e(0), 0.375); + ASSERT_TRUE(res.at(0).isScalar()); + ASSERT_EQ(res.at(0).e(0), 0.375); } @@ -2867,9 +2865,9 @@ TEST_F(DeclarableOpsTests5, L2_Loss_1) { auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(output->isScalar()); + ASSERT_TRUE(output.isScalar()); - ASSERT_EQ(output->e(0), exp); + ASSERT_EQ(output.e(0), exp); } @@ -2884,7 +2882,7 @@ TEST_F(DeclarableOpsTests5, L2_Loss_2) { auto z = results.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index 2726947c9704..afcff6d40f93 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -76,7 +76,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) { auto z = result.at(0); - ASSERT_EQ(exp, *z); + ASSERT_EQ(exp, z); } @@ -97,7 +97,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) { auto z = result.at(0); //z->printShapeInfo("SS OS shape"); - ASSERT_TRUE(z->isEmpty()); + ASSERT_TRUE(z.isEmpty()); //ASSERT_EQ(exp, *z); @@ -119,7 +119,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) { auto z = result.at(0); - ASSERT_TRUE(z->equalsTo(exp)); + ASSERT_TRUE(z.equalsTo(exp)); //ASSERT_EQ(exp, *z); @@ -128,9 +128,9 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) { TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) { int z = 0; auto matrix = NDArrayFactory::create('c', {1}, {10}); - auto b = NDArrayFactory::create_('c', {1}, {1}); - auto e = NDArrayFactory::create_('c', {1}, {z}); - auto s = NDArrayFactory::create_('c', {1}, {1}); + auto b = NDArrayFactory::create('c', {1}, {1}); + auto e = NDArrayFactory::create('c', {1}, {z}); + auto s = NDArrayFactory::create('c', {1}, {1}); sd::ops::ones_as opOnes; //auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); auto onesRes = opOnes.evaluate({&matrix}); @@ -138,8 +138,8 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) { ASSERT_EQ(onesRes.status(), Status::OK()); auto ones = onesRes.at(0); - *ones *= 10; - auto onesD = new NDArray(ones->dup()); + ones *= 10; + auto onesD = ones.dup(); auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, onesD); @@ -158,8 +158,10 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) { block->appendI(0); block->appendI(0); - auto inputShapes = new ShapeList({ones->getShapeInfo(), b->getShapeInfo(), e->getShapeInfo(), s->getShapeInfo()}); + auto inputShapes = new ShapeList({ones.getShapeInfo(), b.getShapeInfo(), e.getShapeInfo(), s.getShapeInfo()}); + sd::ops::strided_slice op; + auto result = op.calculateOutputShape(inputShapes, *block); //execute({ones, &b, &e, &s}, {}, {0, 1, 0, 0, 0}); ASSERT_EQ(result->size(), 1); ASSERT_TRUE(shape::isEmpty(result->at(0))); @@ -326,7 +328,7 @@ TEST_F(DeclarableOpsTests6, Test_Order_1) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - ASSERT_NE(x.ordering(), z->ordering()); + ASSERT_NE(x.ordering(), z.ordering()); } @@ -630,7 +632,7 @@ TEST_F(DeclarableOpsTests6, cumSum_16) { // z->printShapeInfo(); // x.printShapeInfo(); - ASSERT_TRUE(z->ews() == 1); + ASSERT_TRUE(z.ews() == 1); ASSERT_TRUE(x.ews() == 1); @@ -781,10 +783,10 @@ TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_1) { auto res = op.evaluate({&x, &y, &z}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, res.status()); -// res.at(0)->printIndexedBuffer("MergeMaxIndex Result is "); -// res.at(0)->printShapeInfo("Shape info for MergeMaxIdex"); +// res.at(0).printIndexedBuffer("MergeMaxIndex Result is "); +// res.at(0).printShapeInfo("Shape info for MergeMaxIdex"); // x.printIndexedBuffer("Input is"); - ASSERT_TRUE(res.at(0)->equalsTo(exp)); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } @@ -800,10 +802,10 @@ TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_2) { auto ress = op.evaluate({&x, &y, &z}, {}, {sd::DataType::INT64}); ASSERT_EQ(ND4J_STATUS_OK, ress.status()); -// res.at(0)->printIndexedBuffer("MergeMaxIndex2 Result is "); -// res.at(0)->printShapeInfo("Shape info for MergeMaxIdex2"); +// res.at(0).printIndexedBuffer("MergeMaxIndex2 Result is "); +// res.at(0).printShapeInfo("Shape info for MergeMaxIdex2"); // x.printIndexedBuffer("Input is"); - ASSERT_TRUE(ress.at(0)->equalsTo(exp)); + ASSERT_TRUE(ress.at(0).equalsTo(exp)); } @@ -817,7 +819,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_1) { auto res = op.evaluate({&x, &shape}, {0.2f}, {113}); ASSERT_EQ(ND4J_STATUS_OK, res.status()); - //res.at(0)->printIndexedBuffer("Result is "); + //res.at(0).printIndexedBuffer("Result is "); //x.printIndexedBuffer("Input is"); @@ -833,9 +835,9 @@ TEST_F(DeclarableOpsTests6, TestMod_1) { auto res = op.evaluate({&x, &y}); ASSERT_EQ(ND4J_STATUS_OK, res.status()); -// res.at(0)->printIndexedBuffer("MOD Result is "); +// res.at(0).printIndexedBuffer("MOD Result is "); // x.printIndexedBuffer("Input is"); - ASSERT_TRUE(res.at(0)->equalsTo(exp)); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } @@ -851,10 +853,10 @@ TEST_F(DeclarableOpsTests6, TestMod_BP_1) { auto res = op.evaluate({&x, &y, &eps}); ASSERT_EQ(ND4J_STATUS_OK, res.status()); -// res.at(0)->printIndexedBuffer("MOD_BP Result is "); +// res.at(0).printIndexedBuffer("MOD_BP Result is "); // x.printIndexedBuffer("Input is"); - ASSERT_TRUE(res.at(0)->equalsTo(exp)); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } @@ -871,7 +873,7 @@ TEST_F(DeclarableOpsTests6, TestRank_1) { ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(res.at(0)->equalsTo(exp)); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } TEST_F(DeclarableOpsTests6, TestDropout_2) { @@ -945,7 +947,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_1) { auto res = op.evaluate({&x, &axis}); ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_EQ(res.at(0)->e(0), count); + ASSERT_EQ(res.at(0).e(0), count); ASSERT_TRUE(sumExp.equalsTo(res.at(1))); ASSERT_TRUE(sqrExp.equalsTo(res.at(2))); @@ -977,7 +979,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_2) { auto res = op.evaluate({&x, &axis}); ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_EQ(res.at(0)->e(0), count); + ASSERT_EQ(res.at(0).e(0), count); ASSERT_TRUE(sumExp.equalsTo(res.at(1))); ASSERT_TRUE(sqrExp.equalsTo(res.at(2))); @@ -1342,7 +1344,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(y)); - ASSERT_TRUE(result.at(2)->isScalar()); + ASSERT_TRUE(result.at(2).isScalar()); ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(y)); @@ -1475,7 +1477,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_6) { //z->printIndexedBuffer("Output "); //z->printShapeInfo("Shape"); //exp.printIndexedBuffer("Expected "); - ASSERT_TRUE(z->isScalar()); + ASSERT_TRUE(z.isScalar()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -2763,7 +2765,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_1) { auto result = op.evaluate({&x}, {}, {}); ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *result.at(0)); + ASSERT_EQ(e, result.at(0)); } @@ -2776,7 +2778,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_2) { auto result = op.evaluate({&x}, {}, {}); ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *result.at(0)); + ASSERT_EQ(e, result.at(0)); } @@ -2789,7 +2791,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_3) { auto result = op.evaluate({&x}, {}, {}); ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, *result.at(0)); + ASSERT_EQ(e, result.at(0)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index 7f39c3d761ab..d11b34b5ce16 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -66,7 +66,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LARGE) { auto z = result.at(1); - ASSERT_EQ(148,z->e(0)); + ASSERT_EQ(148,z.e(0)); //ASSERT_TRUE(exp.isSameShape(z)); @@ -88,8 +88,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_ZERO) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(1); - auto array = *z; - ASSERT_EQ(3,array.e(0)); + ASSERT_EQ(3, z.e(0)); //ASSERT_TRUE(exp.isSameShape(z)); @@ -113,7 +112,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_EQ(3, z->lengthOf()); + ASSERT_EQ(3, z.lengthOf()); //ASSERT_TRUE(exp.isSameShape(z)); @@ -137,7 +136,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LEFT) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_EQ(3,z->lengthOf()); + ASSERT_EQ(3,z.lengthOf()); //ASSERT_TRUE(exp.isSameShape(z)); @@ -160,7 +159,7 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_EQ(2,z->lengthOf()); + ASSERT_EQ(2,z.lengthOf()); //ASSERT_TRUE(exp.isSameShape(z)); @@ -183,11 +182,8 @@ TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR_GTE) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_EQ(3,z->lengthOf()); + ASSERT_EQ(3,z.lengthOf()); //ASSERT_TRUE(exp.isSameShape(z)); - - - } @@ -476,7 +472,7 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_Prof_1) { auto timeStart = std::chrono::system_clock::now(); for (int i = 0; i < numOfCases; i++) { - op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {res}, {}, {}, {}); + op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {&res}, {}, {}, {}); } auto timeEnd = std::chrono::system_clock::now(); @@ -572,8 +568,8 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119_1) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_TRUE(z->isSameShape(exp)); - ASSERT_TRUE(z->equalsTo(exp)); + ASSERT_TRUE(z.isSameShape(exp)); + ASSERT_TRUE(z.equalsTo(exp)); } @@ -663,8 +659,8 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119_2) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_TRUE(z->isSameShape(exp)); - ASSERT_TRUE(z->equalsTo(exp)); + ASSERT_TRUE(z.isSameShape(exp)); + ASSERT_TRUE(z.equalsTo(exp)); } @@ -3975,7 +3971,7 @@ TEST_F(DeclarableOpsTests7, RealDiv_1) { auto z = result.at(0); // z->printIndexedBuffer("OUtput RealDiv"); ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_TRUE(e.equalsTo(z)); } @@ -4022,7 +4018,7 @@ TEST_F(DeclarableOpsTests7, ShapesOf_1) { auto z = result.at(0); // z->printIndexedBuffer("OUtput RealDiv"); // ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_TRUE(e.equalsTo(z)); } @@ -4065,7 +4061,7 @@ TEST_F(DeclarableOpsTests7, Size_1) { auto z = result.at(0); // z->printIndexedBuffer("OUtput SIZE"); /// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_TRUE(e.equalsTo(z)); } @@ -4084,7 +4080,7 @@ TEST_F(DeclarableOpsTests7, Size_2) { auto z = result.at(0); // z->printIndexedBuffer("OUtput SIZE"); /// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_TRUE(e.equalsTo(z)); } @@ -4102,7 +4098,7 @@ TEST_F(DeclarableOpsTests7, Softplus_1) { auto z = result.at(0); // z->printIndexedBuffer("OUtput Softplus"); /// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_TRUE(e.equalsTo(z)); } @@ -4142,7 +4138,7 @@ TEST_F(DeclarableOpsTests7, Softsign_1) { auto z = result.at(0); // z->printIndexedBuffer("OUtput Softsign"); /// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(*z)); + ASSERT_TRUE(e.equalsTo(z)); } @@ -5018,7 +5014,7 @@ TYPED_TEST(TypedDeclarableOpsTests7, Test_Pnorm_Once_Again) { auto result = op.evaluate({&input}, {}, {1,1, 1,1, 0,0, 1,1,1, 3, 0}); ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(exp, *result.at(0)); + ASSERT_EQ(exp, result.at(0)); } @@ -5972,7 +5968,7 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_2) { sd::ops::reduce_prod op_exp; auto res = op_exp.evaluate({&input}); auto result = op.evaluate({&input, &eps}, {}, {}); - exp.assign(res.at(0)->e(0)); + exp.assign(res.at(0).e(0)); exp /= input; exp *= eps.e(0); ASSERT_EQ(Status::OK(), result.status()); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp index 589adebcbde0..9b5cf41848f1 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp @@ -2473,7 +2473,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test1) { ASSERT_EQ(Status::OK(), results.status()); - auto *output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -2495,7 +2495,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test2) { ASSERT_EQ(Status::OK(), results.status()); - auto *output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -2517,7 +2517,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test3) { ASSERT_EQ(Status::OK(), results.status()); - auto *output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -2627,7 +2627,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test8) { ASSERT_EQ(Status::OK(), results.status()); - auto *output = results.at(0); + auto output = results.at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -2826,7 +2826,7 @@ TEST_F(DeclarableOpsTests8, clipbynorm_test_tf_119_1) { sd::ops::clipbynorm op; auto result = op.evaluate({&x}, {0.54}, {}); - ASSERT_EQ(e, *result.at(0)); + ASSERT_EQ(e, result.at(0)); } @@ -2938,8 +2938,8 @@ TEST_F(DeclarableOpsTests8, zeros_as_test2) { ASSERT_EQ(Status::OK(), result.status()); auto y = result.at(0); - ASSERT_TRUE(y->isSameShape(exp)); - ASSERT_TRUE(y->equalsTo(exp)); + ASSERT_TRUE(y.isSameShape(exp)); + ASSERT_TRUE(y.equalsTo(exp)); } @@ -2972,8 +2972,8 @@ TEST_F(DeclarableOpsTests8, ones_as_test2) { auto results = op.evaluate({&x}); ASSERT_EQ(Status::OK(), results.status()); auto y = results.at(0); - ASSERT_TRUE(y->isSameShape(exp)); - ASSERT_TRUE(y->equalsTo(exp)); + ASSERT_TRUE(y.isSameShape(exp)); + ASSERT_TRUE(y.equalsTo(exp)); } @@ -2991,8 +2991,8 @@ TEST_F(DeclarableOpsTests8, ones_as_test3) { ASSERT_EQ(Status::OK(), results.status()); auto y = results.at(0); - ASSERT_TRUE(y->isSameShape(exp)); - ASSERT_TRUE(y->equalsTo(exp)); + ASSERT_TRUE(y.isSameShape(exp)); + ASSERT_TRUE(y.equalsTo(exp)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index c7e704a21437..713cb2df4148 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -300,8 +300,6 @@ TEST_F(DeclarableOpsTests9, concat_test3) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto output = result.at(0); - output->printBuffer(); - ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -539,11 +537,11 @@ TEST_F(DeclarableOpsTests9, concat_test14) { auto z = result.at(0); - Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0}); + Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z.getShapeInfo(), {0}); ASSERT_TRUE(2 == numOfTads); for (int e = 0; e < numOfTads; ++e) { - NDArray tad = (*z)(e, {0}); + NDArray tad = z(e, {0}); auto mean = tad.meanNumber().e(0); ASSERT_NEAR((e+1)*1., mean, 1e-5); } @@ -601,11 +599,11 @@ TEST_F(DeclarableOpsTests9, concat_test17) { // z->printShapeInfo(); // z->printIndexedBuffer(); - Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0}); + Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z.getShapeInfo(), {0}); ASSERT_TRUE(2 == numOfTads); for (int e = 0; e < numOfTads; ++e) { - NDArray tad = (*z)(e, {0}); + NDArray tad = z(e, {0}); auto mean = tad.meanNumber().e(0); ASSERT_NEAR((e+1)*1., mean, 1e-5); } @@ -618,13 +616,13 @@ TEST_F(DeclarableOpsTests9, concat_test18) { // we crate bunch of arrays, filled with specific values for (int e = 0; e < 2000; e++) { - auto array = NDArrayFactory::create_('c', {1, 300}); - array->assign(e); - context.setInputArray(e, array, true); + auto array = NDArrayFactory::create('c', {1, 300}); + array.assign(e); + context.setInputArray(e, array); } auto z = NDArrayFactory::create('c', {2000, 300}); - context.setOutputArray(0, &z, false); + context.setOutputArray(0, z); context.setIArguments(&axis, 1); sd::ops::concat op; @@ -646,13 +644,13 @@ TEST_F(DeclarableOpsTests9, concat_test19) { // we crate bunch of arrays, filled with specific values for (int e = 0; e < 10; e++) { - auto array = NDArrayFactory::create_('c', {1, 5, 20}); - array->assign(e); - context.setInputArray(e, array, true); + auto array = NDArrayFactory::create('c', {1, 5, 20}); + array.assign(e); + context.setInputArray(e, array); } auto z = NDArrayFactory::create('c', {10, 5, 20}); - context.setOutputArray(0, &z, false); + context.setOutputArray(0, z); context.setIArguments(&axis, 1); sd::ops::concat op; @@ -680,11 +678,11 @@ TEST_F(DeclarableOpsTests9, concat_test20) { auto z = result.at(0); - Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0}); + Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z.getShapeInfo(), {0}); ASSERT_TRUE(4 == numOfTads); for (int e = 0; e < numOfTads; e++) { - NDArray tad = (*z)(e, {0}); + NDArray tad = z(e, {0}); auto mean = tad.meanNumber().e(0); ASSERT_NEAR((double) e+1, mean, 1e-5); } @@ -793,7 +791,6 @@ TEST_F(DeclarableOpsTests9, concat_test26) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto output = result.at(0); - output->printLinearBuffer(); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -815,7 +812,7 @@ TEST_F(DeclarableOpsTests9, concat_test27) { auto z = result.at(0); - ASSERT_TRUE(z->isSameShape(expShape)); + ASSERT_TRUE(z.isSameShape(expShape)); } ////////////////////////////////////////////////////////////////////// @@ -989,7 +986,7 @@ TEST_F(DeclarableOpsTests9, TestDropout_BP_1) { ASSERT_EQ(ND4J_STATUS_OK, ress.status()); //ress.at(0)->printIndexedBuffer("Result is "); //x.printIndexedBuffer("Input is"); - ASSERT_FALSE(ress.at(0)->equalsTo(errs)); + ASSERT_FALSE(ress.at(0).equalsTo(errs)); } @@ -1004,20 +1001,20 @@ TEST_F(DeclarableOpsTests9, TestDropout_1) { auto ress = op.evaluate({&x}, {0.2f}, {113}); ASSERT_EQ(ND4J_STATUS_OK, ress.status()); - NDArray* res = ress.at(0); //->printIndexedBuffer("Result is "); + auto res = ress.at(0); //->printIndexedBuffer("Result is "); //x.printIndexedBuffer("Input is"); //res->printIndexedBuffer("Result for Dropout_1"); - auto countZero = res->reduceNumber(reduce::CountZero); + auto countZero = res.reduceNumber(reduce::CountZero); ASSERT_NEAR(countZero.e(0), 80, 5); auto ress2 = op.evaluate({&x}, {0.2f}, {113}); ASSERT_EQ(ND4J_STATUS_OK, ress2.status()); - NDArray* res2 = ress2.at(0); + auto res2 = ress2.at(0); - countZero = res->reduceNumber(reduce::CountZero); + countZero = res.reduceNumber(reduce::CountZero); ASSERT_NEAR(countZero.e(0), 80, 5); //res2->printIndexedBuffer("Result for Dropout_2"); - ASSERT_TRUE(res->equalsTo(res2)); + ASSERT_TRUE(res.equalsTo(res2)); //res->printIndexedBuffer("FF dropout"); //res2->printIndexedBuffer("BP dropout"); @@ -1060,7 +1057,7 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) { ASSERT_EQ(ND4J_STATUS_OK, ress.status()); //ress.at(0)->printIndexedBuffer("01Dropout result is "); - auto count = ress.at(0)->reduceNumber(reduce::CountNonZero); + auto count = ress.at(0).reduceNumber(reduce::CountNonZero); // nd4j_printf("\n01Dropout count %i\n\n", count); sd::ops::dropout_bp op2; @@ -1076,9 +1073,9 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) { //ressY->at(0)->printIndexedBuffer("BP"); //ress.at(0)->printIndexedBuffer("FF"); bool ret = true; - for (int e = 0; e < ress.at(0)->lengthOf(); e++) { - if (ress.at(0)->e(e) == 0.f) - if (ressX.at(0)->e(e) != ress.at(0)->e(e)) { + for (int e = 0; e < ress.at(0).lengthOf(); e++) { + if (ress.at(0).e(e) == 0.f) + if (ressX.at(0).e(e) != ress.at(0).e(e)) { ret = false; break; } @@ -1122,16 +1119,16 @@ TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) { //ressY->at(0)->printIndexedBuffer("BP Dropout result is "); - auto countZero = ress.at(0)->reduceNumber(reduce::CountZero); + auto countZero = ress.at(0).reduceNumber(reduce::CountZero); ASSERT_NEAR(countZero.e(0), 50.f, 10.f); - countZero = ressX.at(0)->reduceNumber(reduce::CountZero); + countZero = ressX.at(0).reduceNumber(reduce::CountZero); //nd4j_printf("X zero count is %f\n", countZero); ASSERT_NEAR(countZero.e(0), 50.f, 10.f); - countZero = ressY.at(0)->reduceNumber(reduce::CountZero); + countZero = ressY.at(0).reduceNumber(reduce::CountZero); //nd4j_printf("Y zero count is %f\n", countZero); ASSERT_NEAR(countZero.e(0), 50.f, 10.f); // ASSERT_TRUE(exp.equalsTo(ressX->at(0))); - ASSERT_TRUE(ressX.at(0)->equalsTo(ressY.at(0))); + ASSERT_TRUE(ressX.at(0).equalsTo(ressY.at(0))); } @@ -1149,15 +1146,15 @@ TEST_F(DeclarableOpsTests9, Test_AlphaDropout_BP_1) { auto ress = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119}); ASSERT_EQ(ND4J_STATUS_OK, ress.status()); - NDArray* res = ress.at(0); + auto res = ress.at(0); auto ress2 = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119}); ASSERT_EQ(ND4J_STATUS_OK, ress2.status()); - NDArray* res2 = ress2.at(0); + auto res2 = ress2.at(0); //res->printIndexedBuffer("Result1AlphaBP1"); //res2->printIndexedBuffer("Result1AlphaBP2"); - ASSERT_TRUE(res2->equalsTo(res)); + ASSERT_TRUE(res2.equalsTo(res)); } @@ -1171,7 +1168,6 @@ TEST_F(DeclarableOpsTests9, test_range_int_1) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - } TEST_F(DeclarableOpsTests9, test_range_empty_1) { @@ -1185,8 +1181,7 @@ TEST_F(DeclarableOpsTests9, test_range_empty_1) { auto z = result.at(0); - ASSERT_TRUE(z->isEmpty()); - + ASSERT_TRUE(z.isEmpty()); } @@ -1221,7 +1216,6 @@ TEST_F(DeclarableOpsTests9, test_unstack_1) { auto result = op.evaluate({&x}, {}, {0}); ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(5, result.size()); - } //////////////////////////////////////////////////////////////////////////////// @@ -1239,8 +1233,8 @@ TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) { ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(5, result.size()); for (size_t i = 0; i < result.size(); i++) { - ASSERT_TRUE(result.at(i)->isSameShape(z[i])); - ASSERT_TRUE(result.at(i)->equalsTo(z[i])); + ASSERT_TRUE(result.at(i).isSameShape(z[i])); + ASSERT_TRUE(result.at(i).equalsTo(z[i])); } } @@ -2319,7 +2313,7 @@ TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_1) { // printf("How many: %ul\n", res2->size()); // res2->at(0)->printBuffer("Ouputput0"); // res2->at(1)->printBuffer("Ouputput1"); - ASSERT_TRUE(res2.at(0)->equalsTo(exp)); + ASSERT_TRUE(res2.at(0).equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp index e6aeb43d435f..2cbb691b2ca0 100644 --- a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp @@ -60,13 +60,13 @@ TEST_F(EmptyTests, Test_Create_Empty_2) { TEST_F(EmptyTests, Test_Concat_1) { // auto empty = NDArrayFactory::empty_(); - auto empty = new NDArray('c', {0}, sd::DataType::FLOAT32);//NDArrayFactory::create_('c', {(Nd4jLong)0}}; - auto vector = NDArrayFactory::create_('c', {1}, {1.0f}); + auto empty = NDArray('c', {0}, sd::DataType::FLOAT32);//NDArrayFactory::create_('c', {(Nd4jLong)0}}; + auto vector = NDArrayFactory::create('c', {1}, {1.0f}); - ASSERT_TRUE(empty->isEmpty()); + ASSERT_TRUE(empty.isEmpty()); sd::ops::concat op; - auto result = op.evaluate({empty, vector}, {}, {0}); + auto result = op.evaluate({&empty, &vector}, {}, {0}); ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); @@ -74,35 +74,25 @@ TEST_F(EmptyTests, Test_Concat_1) { // z->printShapeInfo("z shape"); // z->printIndexedBuffer("z buffr"); - ASSERT_EQ(*vector, *z); - - delete empty; - delete vector; + ASSERT_EQ(vector, z); } TEST_F(EmptyTests, Test_Concat_2) { - auto empty = new NDArray('c', {0}, sd::DataType::FLOAT32); //NDArrayFactory::empty_(); - auto scalar1 = NDArrayFactory::create_('c', {1}, {1.0f}); - auto scalar2 = NDArrayFactory::create_('c', {1}, {2.0f}); + auto empty = NDArray('c', {0}, sd::DataType::FLOAT32); //NDArrayFactory::empty_(); + auto scalar1 = NDArrayFactory::create('c', {1}, {1.0f}); + auto scalar2 = NDArrayFactory::create('c', {1}, {2.0f}); auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); - ASSERT_TRUE(empty->isEmpty()); + ASSERT_TRUE(empty.isEmpty()); sd::ops::concat op; - auto result = op.evaluate({empty, scalar1, scalar2}, {}, {0}); + auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0}); ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); -// z->printShapeInfo("z shape"); -// z->printIndexedBuffer("z buffr"); - - ASSERT_EQ(exp, *z); - - delete empty; - delete scalar1; - delete scalar2; + ASSERT_EQ(exp, z); } TEST_F(EmptyTests, Test_Concat_3) { @@ -119,8 +109,7 @@ TEST_F(EmptyTests, Test_Concat_3) { auto z = result.at(0); - ASSERT_EQ(exp, *z); - + ASSERT_EQ(exp, z); } TEST_F(EmptyTests, Test_Concat_4) { @@ -137,17 +126,15 @@ TEST_F(EmptyTests, Test_Concat_4) { auto z = result.at(0); - ASSERT_EQ(exp, *z); + ASSERT_EQ(exp, z); } TEST_F(EmptyTests, Test_dup_1) { auto empty = NDArrayFactory::empty(); - auto dup = new NDArray(empty.dup()); - - ASSERT_TRUE(dup->isEmpty()); - ASSERT_EQ(empty, *dup); + auto dup = empty.dup(); - delete dup; + ASSERT_TRUE(dup.isEmpty()); + ASSERT_EQ(empty, dup); } TEST_F(EmptyTests, test_empty_scatter_1) { @@ -162,7 +149,7 @@ TEST_F(EmptyTests, test_empty_scatter_1) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_EQ(x, *z); + ASSERT_EQ(x, z); } @@ -236,7 +223,7 @@ TEST_F(EmptyTests, test_empty_matmul_1) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } @@ -250,5 +237,5 @@ TEST_F(EmptyTests, test_empty_matmul_2) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } diff --git a/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp b/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp index f31a1c7ec47b..ae4d654e9007 100644 --- a/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp @@ -43,9 +43,7 @@ TEST_F(FlatUtilsTests, flat_float_serde_1) { auto restored = FlatUtils::fromFlatArray(pfArray); - ASSERT_EQ(array, *restored); - - delete restored; + ASSERT_EQ(array, restored); } TEST_F(FlatUtilsTests, flat_int_serde_1) { @@ -60,9 +58,7 @@ TEST_F(FlatUtilsTests, flat_int_serde_1) { auto restored = FlatUtils::fromFlatArray(pfArray); - ASSERT_EQ(array, *restored); - - delete restored; + ASSERT_EQ(array, restored); } TEST_F(FlatUtilsTests, flat_bool_serde_1) { @@ -77,9 +73,7 @@ TEST_F(FlatUtilsTests, flat_bool_serde_1) { auto restored = FlatUtils::fromFlatArray(pfArray); - ASSERT_EQ(array, *restored); - - delete restored; + ASSERT_EQ(array, restored); } TEST_F(FlatUtilsTests, flat_string_serde_1) { @@ -94,7 +88,5 @@ TEST_F(FlatUtilsTests, flat_string_serde_1) { auto restored = FlatUtils::fromFlatArray(pfArray); - ASSERT_EQ(array, *restored); - - delete restored; + ASSERT_EQ(array, restored); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index 1d65c2072a46..f02c9feb54f9 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -55,7 +55,7 @@ TEST_F(GraphExecutorTests, test_basic_exec_1) { TEST_F(GraphExecutorTests, test_basic_exec_2) { GraphMemoryManager mgr; - Graph graph(nullptr, nullptr, mgr); + Graph graph(nullptr, mgr); auto A = NDArrayFactory::create('c', {3}, {1, 1, 1}); auto B = NDArrayFactory::create('c', {3}, {2, 2, 2}); @@ -91,14 +91,14 @@ TEST_F(GraphExecutorTests, test_basic_exec_2) { executor.execute(optimizedGraph); // checking results by ID - ASSERT_TRUE(graph.variableSpace()->hasVariable(m.id())); - ASSERT_TRUE(graph.variableSpace()->hasVariable(a.id())); + ASSERT_TRUE(graph.variableSpace().hasVariable(m.id())); + ASSERT_TRUE(graph.variableSpace().hasVariable(a.id())); // checking results by name - ASSERT_TRUE(graph.variableSpace()->hasVariable("mul")); - ASSERT_TRUE(graph.variableSpace()->hasVariable("add")); + ASSERT_TRUE(graph.variableSpace().hasVariable("mul")); + ASSERT_TRUE(graph.variableSpace().hasVariable("add")); // checking if result is valid - auto result = graph.variableSpace()->getVariable(a.id())->getNDArray(); + auto result = graph.variableSpace().getVariable(a.id())->getNDArray(); ASSERT_EQ(exp, *result); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests.cpp index 0266023a6814..a1394c17cc86 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests.cpp @@ -47,10 +47,10 @@ class GraphTests : public testing::Test { TEST_F(GraphTests, SingleInput1) { Graph graph; - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0f); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(-2.0f); - graph.variableSpace()->putVariable(-1, x); + graph.variableSpace().putVariable(-1, x); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); Node nodeB(OpType_TRANSFORM_STRICT, transform::Cosine); @@ -64,9 +64,9 @@ TEST_F(GraphTests, SingleInput1) { graph.execute(); - ASSERT_TRUE(graph.variableSpace()->hasVariable(3)); + ASSERT_TRUE(graph.variableSpace().hasVariable(3)); - auto node3 = graph.variableSpace()->getVariable(3)->getNDArray(); + auto node3 = graph.variableSpace().getVariable(3)->getNDArray(); ASSERT_NEAR(0.4161468, node3->reduceNumber(reduce::Mean).e(0), 1e-5); } @@ -74,17 +74,17 @@ TEST_F(GraphTests, SingleInput1) { TEST_F(GraphTests, DoubleInput1) { Graph graph; - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(-2.0); - auto y = NDArrayFactory::create_('c', {5, 5}); - y->assign(-1.0); + auto y = NDArrayFactory::create('c', {5, 5}); + y.assign(-1.0); - auto z = NDArrayFactory::create_('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); - graph.variableSpace()->putVariable(-1, x); - graph.variableSpace()->putVariable(-2, y); - graph.variableSpace()->putVariable(-3, z); + graph.variableSpace().putVariable(-1, x); + graph.variableSpace().putVariable(-2, y); + graph.variableSpace().putVariable(-3, z); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); Node nodeB(OpType_TRANSFORM_SAME, transform::Abs); @@ -99,21 +99,21 @@ TEST_F(GraphTests, DoubleInput1) { graph.execute(); - ASSERT_NEAR(3.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(3.0, z.reduceNumber(reduce::Mean).e(0), 1e-5); } TEST_F(GraphTests, SingleInput3) { Graph graph; - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(-2.0); - auto v0 = NDArrayFactory::create_('c', {5, 5}); - auto v1 = NDArrayFactory::create_('c', {5, 5}); + auto v0 = NDArrayFactory::create('c', {5, 5}); + auto v1 = NDArrayFactory::create('c', {5, 5}); - graph.variableSpace()->putVariable(-1, x); - graph.variableSpace()->putVariable(-2, v0); - graph.variableSpace()->putVariable(-3, v1); + graph.variableSpace().putVariable(-1, x); + graph.variableSpace().putVariable(-2, v0); + graph.variableSpace().putVariable(-3, v1); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt); @@ -127,22 +127,22 @@ TEST_F(GraphTests, SingleInput3) { graph.execute(); - ASSERT_NEAR(1.4142135, v0->reduceNumber(reduce::Mean).e(0), 1e-5); - ASSERT_NEAR(1.0, v1->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(1.4142135, v0.reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(1.0, v1.reduceNumber(reduce::Mean).e(0), 1e-5); } TEST_F(GraphTests, SingleInput4) { Graph graph; - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(-2.0); - auto v0 = NDArrayFactory::create_('c', {5, 5}); - auto v1 = NDArrayFactory::create_('c', {5, 5}); + auto v0 = NDArrayFactory::create('c', {5, 5}); + auto v1 = NDArrayFactory::create('c', {5, 5}); - graph.variableSpace()->putVariable(-1, x); - graph.variableSpace()->putVariable(-2, v0); - graph.variableSpace()->putVariable(-3, v1); + graph.variableSpace().putVariable(-1, x); + graph.variableSpace().putVariable(-2, v0); + graph.variableSpace().putVariable(-3, v1); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt); @@ -161,27 +161,27 @@ TEST_F(GraphTests, SingleInput4) { graph.execute(); - ASSERT_NEAR(1.0, v0->reduceNumber(reduce::Mean).e(0), 1e-5); - ASSERT_NEAR(-1.4142135, v1->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(1.0, v0.reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(-1.4142135, v1.reduceNumber(reduce::Mean).e(0), 1e-5); } TEST_F(GraphTests, DoubleInput2) { Graph graph; - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(-2.0); - auto y = NDArrayFactory::create_('c', {5, 5}); - y->assign(-1.0); + auto y = NDArrayFactory::create('c', {5, 5}); + y.assign(-1.0); - auto z0 = NDArrayFactory::create_('c', {5, 5}); - auto z1 = NDArrayFactory::create_('c', {5, 5}); + auto z0 = NDArrayFactory::create('c', {5, 5}); + auto z1 = NDArrayFactory::create('c', {5, 5}); - graph.variableSpace()->putVariable(-1, x); - graph.variableSpace()->putVariable(-2, y); - graph.variableSpace()->putVariable(-3, z0); - graph.variableSpace()->putVariable(-4, z1); + graph.variableSpace().putVariable(-1, x); + graph.variableSpace().putVariable(-2, y); + graph.variableSpace().putVariable(-3, z0); + graph.variableSpace().putVariable(-4, z1); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); @@ -203,31 +203,31 @@ TEST_F(GraphTests, DoubleInput2) { graph.execute(); - ASSERT_NEAR(-1.4142135, z0->reduceNumber(reduce::Mean).e(0), 1e-5); - ASSERT_NEAR(-1.0, z1->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(-1.4142135, z0.reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(-1.0, z1.reduceNumber(reduce::Mean).e(0), 1e-5); } TEST_F(GraphTests, DoubleInput3) { Graph graph; - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(-2.0); - auto y = NDArrayFactory::create_('c', {5, 5}); - y->assign(-1.0); + auto y = NDArrayFactory::create('c', {5, 5}); + y.assign(-1.0); - auto z0 = NDArrayFactory::create_('c', {5, 5}); - auto z1 = NDArrayFactory::create_('c', {5, 5}); + auto z0 = NDArrayFactory::create('c', {5, 5}); + auto z1 = NDArrayFactory::create('c', {5, 5}); - auto w = NDArrayFactory::create_('c', {5, 5}); + auto w = NDArrayFactory::create('c', {5, 5}); - graph.variableSpace()->putVariable(-1, x); - graph.variableSpace()->putVariable(-2, y); - graph.variableSpace()->putVariable(-3, z0); - graph.variableSpace()->putVariable(-4, z1); - graph.variableSpace()->putVariable(-5, w); + graph.variableSpace().putVariable(-1, x); + graph.variableSpace().putVariable(-2, y); + graph.variableSpace().putVariable(-3, z0); + graph.variableSpace().putVariable(-4, z1); + graph.variableSpace().putVariable(-5, w); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); @@ -254,36 +254,36 @@ TEST_F(GraphTests, DoubleInput3) { graph.execute(); - ASSERT_NEAR(-1.4142135, z0->reduceNumber(reduce::Mean).e(0), 1e-5); - ASSERT_NEAR(-1.0, z1->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(-1.4142135, z0.reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(-1.0, z1.reduceNumber(reduce::Mean).e(0), 1e-5); - ASSERT_NEAR(2.4142135, w->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(2.4142135, w.reduceNumber(reduce::Mean).e(0), 1e-5); } TEST_F(GraphTests, QuadInput1) { Graph graph; - auto x0 = NDArrayFactory::create_('c', {5, 5}); - x0->assign(0.0); + auto x0 = NDArrayFactory::create('c', {5, 5}); + x0.assign(0.0); - auto x1 = NDArrayFactory::create_('c', {5, 5}); - x1->assign(-1.0); + auto x1 = NDArrayFactory::create('c', {5, 5}); + x1.assign(-1.0); - auto x2 = NDArrayFactory::create_('c', {5, 5}); - x2->assign(-2.0); + auto x2 = NDArrayFactory::create('c', {5, 5}); + x2.assign(-2.0); - auto x3 = NDArrayFactory::create_('c', {5, 5}); - x3->assign(-3.0); + auto x3 = NDArrayFactory::create('c', {5, 5}); + x3.assign(-3.0); - auto z = NDArrayFactory::create_('c', {5, 5}); - z->assign(119.0); + auto z = NDArrayFactory::create('c', {5, 5}); + z.assign(119.0); - graph.variableSpace()->putVariable(-1, x0); - graph.variableSpace()->putVariable(-2, x1); - graph.variableSpace()->putVariable(-3, x2); - graph.variableSpace()->putVariable(-4, x3); - graph.variableSpace()->putVariable(-5, z); + graph.variableSpace().putVariable(-1, x0); + graph.variableSpace().putVariable(-2, x1); + graph.variableSpace().putVariable(-3, x2); + graph.variableSpace().putVariable(-4, x3); + graph.variableSpace().putVariable(-5, z); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {11}); Node nodeB(OpType_TRANSFORM_SAME, transform::Abs, 2, {-2}, {11}); @@ -307,19 +307,19 @@ TEST_F(GraphTests, QuadInput1) { graph.execute(); - ASSERT_NEAR(6.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(6.0, z.reduceNumber(reduce::Mean).e(0), 1e-5); } TEST_F(GraphTests, InternalBranching1) { Graph graph; - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(0.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(0.0); - auto z = NDArrayFactory::create_('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); - graph.variableSpace()->putVariable(-1, x); - graph.variableSpace()->putVariable(-2, z); + graph.variableSpace().putVariable(-1, x); + graph.variableSpace().putVariable(-2, z); // 1.0 Node nodeA(OpType_TRANSFORM_SAME, transform::Ones, 1, {-1}, {11, 21}); @@ -350,24 +350,24 @@ TEST_F(GraphTests, InternalBranching1) { graph.execute(); - ASSERT_NEAR(3.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(3.0, z.reduceNumber(reduce::Mean).e(0), 1e-5); } TEST_F(GraphTests, ReductionsTest1) { Graph graph; - auto x = NDArrayFactory::create_('c', {5, 5}); - for (int r = 0; r < x->rows(); r++) { - for (int c = 0; c < x->columns(); c++) { - x->p(r, c, -c); + auto x = NDArrayFactory::create('c', {5, 5}); + for (int r = 0; r < x.rows(); r++) { + for (int c = 0; c < x.columns(); c++) { + x.p(r, c, -c); } } - auto z = NDArrayFactory::create_('c', {5}); + auto z = NDArrayFactory::create('c', {5}); - graph.variableSpace()->putVariable(-1, x); - graph.variableSpace()->putVariable(-2, z); + graph.variableSpace().putVariable(-1, x); + graph.variableSpace().putVariable(-2, z); Node nodeA(OpType_REDUCE_FLOAT, reduce::Mean, 1, {-1}, {2}, {1}, {}); Node nodeB(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {-2}); @@ -379,26 +379,26 @@ TEST_F(GraphTests, ReductionsTest1) { graph.execute(); - ASSERT_NEAR(2.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(2.0, z.reduceNumber(reduce::Mean).e(0), 1e-5); } TEST_F(GraphTests, IndexReductionsTest1) { Graph graph; - auto x = NDArrayFactory::create_('c', {5, 5}); - for (int r = 0; r < x->rows(); r++) { - for (int c = 0; c < x->columns(); c++) { - x->p(r, c, -c); + auto x = NDArrayFactory::create('c', {5, 5}); + for (int r = 0; r < x.rows(); r++) { + for (int c = 0; c < x.columns(); c++) { + x.p(r, c, -c); } } - auto z = NDArrayFactory::create_('c', {5, 1}); - auto axis = NDArrayFactory::create_('c', {1}, {1}); + auto z = NDArrayFactory::create('c', {5, 1}); + auto axis = NDArrayFactory::create('c', {1}, {1}); - graph.variableSpace()->putVariable(-1, x); - graph.variableSpace()->putVariable(-2, z); - //graph->variableSpace()->putVariable(-3, axis); + graph.variableSpace().putVariable(-1, x); + graph.variableSpace().putVariable(-2, z); + //graph->variableSpace().putVariable(-3, axis); Node nodeA(OpType_INDEX_REDUCE, indexreduce::IndexMin, 1, {-1}, {2}, {1}); @@ -411,7 +411,7 @@ TEST_F(GraphTests, IndexReductionsTest1) { graph.execute(); - ASSERT_NEAR(4.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(4.0, z.reduceNumber(reduce::Mean).e(0), 1e-5); } #if 0 @@ -420,7 +420,7 @@ TEST_F(GraphTests, AutoOutput1) { auto x = NDArrayFactory::create_('c', {5, 5}); x->assign(-2.0); - graph->variableSpace()->putVariable(-1, x); + graph->variableSpace().putVariable(-1, x); auto nodeA = new Node(OpType_TRANSFORM_FLOAT, 0, 1, {-1}, {2}); auto nodeB = new Node(OpType_TRANSFORM_FLOAT, 35, 2, {1}, {}); @@ -455,7 +455,7 @@ TEST_F(GraphTests, AutoOutput2) { auto x = NDArrayFactory::create_('c', {5, 5}); x->assign(-2.0); - graph->variableSpace()->putVariable(-1, x); + graph->variableSpace().putVariable(-1, x); auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}, {2, 3, -1}); auto nodeB = new Node(OpType_TRANSFORM_SAME, 35, 2, {1}, {}); @@ -492,19 +492,19 @@ TEST_F(GraphTests, AutoOutput2) { TEST_F(GraphTests, BroadcastTest1) { Graph graph; - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(0.f); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(0.f); - auto y = NDArrayFactory::create_('c', {1, 5}); - for (int e = 0; e < y->columns(); e++) { - y->p(e, (float)e+1); + auto y = NDArrayFactory::create('c', {1, 5}); + for (int e = 0; e < y.columns(); e++) { + y.p(e, (float)e+1); } - auto z = NDArrayFactory::create_('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); - graph.variableSpace()->putVariable(-1, x); - graph.variableSpace()->putVariable(-2, y); - graph.variableSpace()->putVariable(-3, z); + graph.variableSpace().putVariable(-1, x); + graph.variableSpace().putVariable(-2, y); + graph.variableSpace().putVariable(-3, z); Node nodeA(OpType_BROADCAST, broadcast::Subtract, 1, {-1, -2}, {2}, {1}); Node nodeB(OpType_TRANSFORM_SAME, transform::Neg, 2, {1}, {-3}); @@ -514,20 +514,20 @@ TEST_F(GraphTests, BroadcastTest1) { graph.execute(); - ASSERT_NEAR(3.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(3.0, z.reduceNumber(reduce::Mean).e(0), 1e-5); } TEST_F(GraphTests, ScalarTest1) { Graph graph; - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(-2.0); - auto z = NDArrayFactory::create_('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); - graph.variableSpace()->putVariable(-1, x); - graph.variableSpace()->putVariable(-2, z); + graph.variableSpace().putVariable(-1, x); + graph.variableSpace().putVariable(-2, z); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); @@ -541,28 +541,25 @@ TEST_F(GraphTests, ScalarTest1) { graph.execute(); - ASSERT_NEAR(2.714213, z->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(2.714213, z.reduceNumber(reduce::Mean).e(0), 1e-5); } TEST_F(GraphTests, SymbolicLookupTest1) { Graph graph; - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - auto z = NDArrayFactory::create_('c', {5, 5}); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(-2.0); - auto vX = new Variable(x); - auto vZ = new Variable(z); + auto z = NDArrayFactory::create('c', {5, 5}); std::string a("alpha"); std::string o("omega"); - vX->setName(a); - vZ->setName(o); + auto vX = std::make_shared(x, a, -1); + auto vZ = std::make_shared(z, o, -1); - graph.variableSpace()->putVariable(-1, vX); - graph.variableSpace()->putVariable(-2, vZ); + graph.variableSpace().putVariable(-1, vX); + graph.variableSpace().putVariable(-2, vZ); Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); @@ -577,25 +574,25 @@ TEST_F(GraphTests, SymbolicLookupTest1) { graph.addNode(nodeB, {1}); - auto rX = graph.variableSpace()->getVariable(a); - auto rZ = graph.variableSpace()->getVariable(o); + auto rX = graph.variableSpace().getVariable(a); + auto rZ = graph.variableSpace().getVariable(o); std::string om("omicron"); ASSERT_TRUE(rX->getNDArray() == vX->getNDArray()); ASSERT_TRUE(rZ->getNDArray() == vZ->getNDArray()); - ASSERT_FALSE(graph.variableSpace()->hasVariable(om)); + ASSERT_FALSE(graph.variableSpace().hasVariable(om)); - ASSERT_TRUE(graph.variableSpace()->hasVariable(1)); - ASSERT_TRUE(graph.variableSpace()->hasVariable(2)); + ASSERT_TRUE(graph.variableSpace().hasVariable(1)); + ASSERT_TRUE(graph.variableSpace().hasVariable(2)); graph.execute(); - ASSERT_TRUE(graph.variableSpace()->hasVariable(p)); - ASSERT_TRUE(graph.variableSpace()->hasVariable(t)); + ASSERT_TRUE(graph.variableSpace().hasVariable(p)); + ASSERT_TRUE(graph.variableSpace().hasVariable(t)); - ASSERT_NEAR(1.4142135, z->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(1.4142135, z.reduceNumber(reduce::Mean).e(0), 1e-5); } #if 0 diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp index b5360abaefe8..4af34615fcf3 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp @@ -50,16 +50,16 @@ TEST_F(GraphTests2, test_placeholder_1) { graph.addPlaceholder("input", DataType::BFLOAT16, {4, 12, 48}); - ASSERT_TRUE(graph.variableSpace()->hasVariable("input")); + ASSERT_TRUE(graph.variableSpace().hasVariable("input")); - auto variable = graph.variableSpace()->getVariable("input"); + auto variable = graph.variableSpace().getVariable("input"); ASSERT_NE(nullptr, variable); ASSERT_TRUE(variable->isPlaceholder()); ASSERT_EQ(DataType::BFLOAT16, variable->dataType()); ASSERT_EQ(std::vector({4, 12, 48}), variable->shape()); - auto placeholders = graph.getPlaceholders(); + auto placeholders = graph.placeholders(); ASSERT_EQ(1, placeholders.size()); ASSERT_EQ(placeholders[0], variable); } diff --git a/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp b/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp index dbe7ccd0a79d..dc0bcb7443f9 100644 --- a/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp @@ -195,7 +195,7 @@ TEST_F(IndexingTests, MaskedSlice_0) { auto matrix = NDArrayFactory::create('c', {3, 5}); auto tads = matrix.allTensorsAlongDimension({1}); for (int e = 0; e < tads.size(); e++) { - tads.at(e)->assign((float) (e+1)); + tads.at(e).assign((float) (e+1)); } auto exp = NDArrayFactory::create('c', {1, 5}); @@ -221,7 +221,7 @@ TEST_F(IndexingTests, MaskedSlice_00) { auto matrix = NDArrayFactory::create('c', {3, 5}); auto tads = matrix.allTensorsAlongDimension({1}); for (int e = 0; e < tads.size(); e++) { - tads.at(e)->assign((float) (e+1)); + tads.at(e).assign((float) (e+1)); } auto exp = NDArrayFactory::create('c', {1, 2}, {2, 2}); @@ -245,7 +245,7 @@ TEST_F(IndexingTests, MaskedSlice_1) { auto matrix = NDArrayFactory::create('c', {3, 5}); auto tads = matrix.allTensorsAlongDimension({1}); for (int e = 0; e < tads.size(); e++) { - tads.at(e)->assign((float) (e+1)); + tads.at(e).assign((float) (e+1)); } auto exp = NDArrayFactory::create('c', {5}); @@ -281,8 +281,6 @@ TEST_F(IndexingTests, MaskedSlice_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } @@ -301,8 +299,6 @@ TEST_F(IndexingTests, MaskedSlice_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } @@ -322,8 +318,6 @@ TEST_F(IndexingTests, MaskedSlice_4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } TEST_F(IndexingTests, Live_Slice_1) { @@ -346,8 +340,6 @@ TEST_F(IndexingTests, Live_Slice_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } @@ -367,8 +359,6 @@ TEST_F(IndexingTests, Test_StridedSlice_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } TEST_F(IndexingTests, Test_StridedSlice_2) { @@ -389,8 +379,6 @@ TEST_F(IndexingTests, Test_StridedSlice_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } @@ -410,8 +398,6 @@ TEST_F(IndexingTests, Test_StridedSlice_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } @@ -434,8 +420,6 @@ TEST_F(IndexingTests, Test_StridedSlice_4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } TEST_F(IndexingTests, Test_Subarray_Strided_1) { @@ -464,7 +448,5 @@ TEST_F(IndexingTests, MaskedSlice_5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } */ \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp index dae5ba5b9179..e1e7831531e4 100644 --- a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp @@ -173,8 +173,8 @@ TEST_F(LegacyOpsTests, ReduceTests_1) { auto z = result.at(0); // z->printBuffer("ReduceTest1"); - ASSERT_TRUE(z->isScalar()); - ASSERT_NEAR(x.sumNumber().e(0), z->e(0), 1e-5f); + ASSERT_TRUE(z.isScalar()); + ASSERT_NEAR(x.sumNumber().e(0), z.e(0), 1e-5f); } @@ -253,8 +253,8 @@ TEST_F(LegacyOpsTests, ReduceTests_5) { auto z = result.at(0); // z->printBuffer("ReduceTest1"); - ASSERT_TRUE(z->isScalar()); - ASSERT_NEAR(x.meanNumber().e(0), z->e(0), 1e-5f); + ASSERT_TRUE(z.isScalar()); + ASSERT_NEAR(x.meanNumber().e(0), z.e(0), 1e-5f); } @@ -335,8 +335,8 @@ TEST_F(LegacyOpsTests, IndexReduceTests_1) { auto z = result.at(0); - ASSERT_TRUE(z->isScalar()); - ASSERT_EQ(24, z->e(0)); + ASSERT_TRUE(z.isScalar()); + ASSERT_EQ(24, z.e(0)); } diff --git a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp index 284313d081bf..2f5842ec9521 100644 --- a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp @@ -58,7 +58,7 @@ TEST_F(ListOperationsTests, BasicTest_Stack_1) { auto row = NDArrayFactory::create_('c', {100}); row->assign((double) e); list.write(e, row); - tads.at(e)->assign(row); + tads.at(e).assign(row); } sd::ops::stack_list op; @@ -81,11 +81,10 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { auto x = NDArrayFactory::create('c', {10, 100}); auto tads = x.allTensorsAlongDimension({1}); for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {100}); - row->assign((double) e); + auto row = NDArrayFactory::create('c', {100}); + row.assign((double) e); //list.write(e, row); - tads.at(e)->assign(row); - delete row; + tads.at(e).assign(row); } sd::ops::unstack_list op; @@ -182,10 +181,10 @@ TEST_F(ListOperationsTests, BasicTest_Pick_1) { } auto tads = exp.allTensorsAlongDimension({1}); - tads.at(0)->assign(1.0f); - tads.at(1)->assign(1.0f); - tads.at(2)->assign(3.0f); - tads.at(3)->assign(3.0f); + tads.at(0).assign(1.0f); + tads.at(1).assign(1.0f); + tads.at(2).assign(3.0f); + tads.at(3).assign(3.0f); sd::ops::pick_list op; @@ -268,14 +267,14 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {5}); row->assign((double) e); - tads.at(e)->assign(row); + tads.at(e).assign(row); if (e < 2) - tads0.at(cnt0++)->assign(row); + tads0.at(cnt0++).assign(row); else if (e < 5) - tads1.at(cnt1++)->assign(row); + tads1.at(cnt1++).assign(row); else - tads2.at(cnt2++)->assign(row); + tads2.at(cnt2++).assign(row); delete row; } @@ -307,7 +306,7 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {1, 5}); row->assign((double) e); - tads.at(e)->assign(row); + tads.at(e).assign(row); delete row; } @@ -335,18 +334,18 @@ TEST_F(ListOperationsTests, BasicTest_Clone_1) { auto list = new NDArrayList(0, true); VariableSpace variableSpace; - auto var = new Variable(nullptr, nullptr, -1, 0); - var->setNDArrayList(list); + auto var = new Variable(); + //var->setNDArrayList(list); - variableSpace.putVariable(-1, var); - variableSpace.trackList(list); + //variableSpace.putVariable(-1, var); + //variableSpace.trackList(list); Context block(1, &variableSpace); block.pickInput(-1); sd::ops::clone_list op; - ASSERT_TRUE(list == block.variable(0)->getNDArrayList()); + ASSERT_TRUE(list == block.variable(0)->getNDArrayList().get()); auto result = op.execute(&block); @@ -375,7 +374,7 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) { auto tads = exp.allTensorsAlongDimension({1}); for (int e = 0; e < 10; e++) { auto tad = tads.at(9 - e); - tad->assign(e); + tad.assign(e); } auto indices = NDArrayFactory::create('c', {1, 10}); diff --git a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp index db342771e5dc..b57a741e0fa6 100644 --- a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp @@ -138,7 +138,7 @@ TEST_F(MultiDataTypeTests, Basic_Test_7) { auto z = result.at(0); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index f4fc4a9b7695..eb0c8e897cca 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -85,9 +85,9 @@ TEST_F(OneOffTests, test_pad_1D_1) { graph.execute(); - ASSERT_TRUE(graph.variableSpace()->hasVariable(4)); + ASSERT_TRUE(graph.variableSpace().hasVariable(4)); - auto z = graph.variableSpace()->getVariable(4)->getNDArray(); + auto z = graph.variableSpace().getVariable(4)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -129,9 +129,9 @@ TEST_F(OneOffTests, test_conv2d_nhwc_failed_1) { graph.execute(); - ASSERT_TRUE(graph.variableSpace()->hasVariable(9)); + ASSERT_TRUE(graph.variableSpace().hasVariable(9)); - auto z = graph.variableSpace()->getVariable(9)->getNDArray(); + auto z = graph.variableSpace().getVariable(9)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -144,9 +144,9 @@ TEST_F(OneOffTests, test_tensor_array_1) { graph.execute(); - ASSERT_TRUE(graph.variableSpace()->hasVariable(5)); + ASSERT_TRUE(graph.variableSpace().hasVariable(5)); - auto z = graph.variableSpace()->getVariable(5)->getNDArray(); + auto z = graph.variableSpace().getVariable(5)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -159,9 +159,9 @@ TEST_F(OneOffTests, test_tensor_array_2) { graph.execute(); - ASSERT_TRUE(graph.variableSpace()->hasVariable(6)); + ASSERT_TRUE(graph.variableSpace().hasVariable(6)); - auto z = graph.variableSpace()->getVariable(6)->getNDArray(); + auto z = graph.variableSpace().getVariable(6)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -174,9 +174,9 @@ TEST_F(OneOffTests, test_tensor_array_3) { graph.execute(); - ASSERT_TRUE(graph.variableSpace()->hasVariable(15)); + ASSERT_TRUE(graph.variableSpace().hasVariable(15)); - auto z = graph.variableSpace()->getVariable(15)->getNDArray(); + auto z = graph.variableSpace().getVariable(15)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -189,9 +189,9 @@ TEST_F(OneOffTests, test_tensor_array_4) { graph.execute(); - ASSERT_TRUE(graph.variableSpace()->hasVariable(11)); + ASSERT_TRUE(graph.variableSpace().hasVariable(11)); - auto z = graph.variableSpace()->getVariable(11)->getNDArray(); + auto z = graph.variableSpace().getVariable(11)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -204,9 +204,9 @@ TEST_F(OneOffTests, test_assert_4) { graph.execute(); - ASSERT_TRUE(graph.variableSpace()->hasVariable(1)); + ASSERT_TRUE(graph.variableSpace().hasVariable(1)); - auto z = graph.variableSpace()->getVariable(1)->getNDArray(); + auto z = graph.variableSpace().getVariable(1)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -269,10 +269,10 @@ TEST_F(OneOffTests, test_identity_n_2) { graph.execute(); - ASSERT_TRUE(graph.variableSpace()->hasVariable(1)); - ASSERT_TRUE(graph.variableSpace()->hasVariable(1, 1)); + ASSERT_TRUE(graph.variableSpace().hasVariable(1)); + ASSERT_TRUE(graph.variableSpace().hasVariable(1, 1)); - auto z = graph.variableSpace()->getVariable(1)->getNDArray(); + auto z = graph.variableSpace().getVariable(1)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -285,9 +285,9 @@ TEST_F(OneOffTests, test_non2d_1) { graph.execute(); - ASSERT_TRUE(graph.variableSpace()->hasVariable(3)); + ASSERT_TRUE(graph.variableSpace().hasVariable(3)); - auto z = graph.variableSpace()->getVariable(3)->getNDArray(); + auto z = graph.variableSpace().getVariable(3)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); @@ -300,12 +300,12 @@ TEST_F(OneOffTests, test_reduce_all_1) { graph.execute(); - ASSERT_TRUE(graph.variableSpace()->hasVariable(1)); + ASSERT_TRUE(graph.variableSpace().hasVariable(1)); - ASSERT_TRUE(graph.variableSpace()->hasVariable(2)); - auto in = graph.variableSpace()->getVariable(2)->getNDArray(); + ASSERT_TRUE(graph.variableSpace().hasVariable(2)); + auto in = graph.variableSpace().getVariable(2)->getNDArray(); - auto z = graph.variableSpace()->getVariable(1)->getNDArray(); + auto z = graph.variableSpace().getVariable(1)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); diff --git a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp index 089b4a92f5db..f9987a9d246c 100644 --- a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp @@ -45,8 +45,8 @@ TEST_F(ParityOpsTests, TestZeroAs1) { auto z = result.at(0); - ASSERT_TRUE(z->isSameShape(&x)); - ASSERT_TRUE(z->equalsTo(&exp)); + ASSERT_TRUE(z.isSameShape(&x)); + ASSERT_TRUE(z.equalsTo(&exp)); } @@ -93,8 +93,8 @@ TEST_F(ParityOpsTests, TestTear1) { auto input = NDArrayFactory::create('c', {10, 5}); auto tads = input.allTensorsAlongDimension({1}); for (int e = 0; e < tads.size(); e++) { - ASSERT_EQ(5, tads.at(e)->lengthOf()); - tads.at(e)->assign((float) e + 1); + ASSERT_EQ(5, tads.at(e).lengthOf()); + tads.at(e).assign((float) e + 1); } sd::ops::tear op; @@ -104,7 +104,7 @@ TEST_F(ParityOpsTests, TestTear1) { ASSERT_EQ(10, result.size()); for (int e = 0; e < result.size(); e++) - ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); + ASSERT_TRUE(tads.at(e).equalsTo(result.at(e))); } @@ -113,8 +113,8 @@ TEST_F(ParityOpsTests, TestUnstack1) { auto input = NDArrayFactory::create('c', {10, 5}); auto tads = input.allTensorsAlongDimension({1}); for (int e = 0; e < tads.size(); e++) { - ASSERT_EQ(5, tads.at(e)->lengthOf()); - tads.at(e)->assign((float) e + 1); + ASSERT_EQ(5, tads.at(e).lengthOf()); + tads.at(e).assign((float) e + 1); } sd::ops::unstack op; @@ -124,7 +124,7 @@ TEST_F(ParityOpsTests, TestUnstack1) { ASSERT_EQ(10, result.size()); for (int e = 0; e < result.size(); e++) - ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); + ASSERT_TRUE(tads.at(e).equalsTo(result.at(e))); } @@ -135,8 +135,8 @@ TEST_F(ParityOpsTests, TestUnstack2) { auto input = NDArrayFactory::create('c', {5,2,6}); auto tads = input.allTensorsAlongDimension({0,1}); for (int e = 0; e < tads.size(); e++) { - ASSERT_EQ(10, tads.at(e)->lengthOf()); - tads.at(e)->assign((float) e + 1); + ASSERT_EQ(10, tads.at(e).lengthOf()); + tads.at(e).assign((float) e + 1); } sd::ops::unstack op; @@ -146,7 +146,7 @@ TEST_F(ParityOpsTests, TestUnstack2) { ASSERT_EQ(6, result.size()); for (int e = 0; e < result.size(); e++) - ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); + ASSERT_TRUE(tads.at(e).equalsTo(result.at(e))); } @@ -340,7 +340,7 @@ TEST_F(ParityOpsTests, TestUnstack13) { ASSERT_EQ(3, result.size()); for (int e = 0; e < 3; e++) - ASSERT_EQ(1, result.at(e)->rankOf()); + ASSERT_EQ(1, result.at(e).rankOf()); } @@ -680,7 +680,7 @@ TEST_F(ParityOpsTests, Test_Bias_Add_1) { auto z = result.at(0); - auto tads = z->allTensorsAlongDimension({1}); + auto tads = z.allTensorsAlongDimension({1}); for (int e = 0; e < tads.size(); e++) { ASSERT_TRUE(bias.equalsTo(tads.at(e))); } diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index a1d7ab46f684..98aa4ce67289 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -282,8 +282,8 @@ TEST_F(RNGTests, Test_Gaussian_21) { // mean->printIndexedBuffer("Mean"); // variance->printIndexedBuffer("Variance"); - ASSERT_NEAR(sd::math::nd4j_abs(mean->e(0)), 0.f, 0.2f); - ASSERT_NEAR(variance->e(0), 1.0f, 0.2f); + ASSERT_NEAR(sd::math::nd4j_abs(mean.e(0)), 0.f, 0.2f); + ASSERT_NEAR(variance.e(0), 1.0f, 0.2f); } @@ -313,8 +313,8 @@ TEST_F(RNGTests, Test_Gaussian_22) { //mean0->printIndexedBuffer("Mean"); //variance0->printIndexedBuffer("Variance"); - ASSERT_NEAR(sd::math::nd4j_abs(mean0->e(0)), 0.f, 1.0e-3f); - ASSERT_NEAR(variance0->e(0), 1.0f, 1.e-3f); + ASSERT_NEAR(sd::math::nd4j_abs(mean0.e(0)), 0.f, 1.0e-3f); + ASSERT_NEAR(variance0.e(0), 1.0f, 1.e-3f); } @@ -736,9 +736,9 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1) { ASSERT_TRUE(exp0.isSameShape(z)); ASSERT_FALSE(exp0.equalsTo(z)); // - z->printBuffer("\nExponential1"); - auto mean = z->reduceNumber(reduce::Mean); - auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); + //z->printBuffer("\nExponential1"); + auto mean = z.reduceNumber(reduce::Mean); + auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); mean.printBuffer("Mean for exponential with param 0.25 (4 exp) is"); variance.printBuffer("Variance for exponential with param 0.25 (16 exp) is"); ASSERT_FALSE(nexp0->equalsTo(z)); @@ -761,9 +761,9 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1_SGA) { ASSERT_TRUE(exp0.isSameShape(z)); ASSERT_FALSE(exp0.equalsTo(z)); // - z->printBuffer("\nExponential2"); - auto mean = z->reduceNumber(reduce::Mean); - auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); + //z->printBuffer("\nExponential2"); + auto mean = z.reduceNumber(reduce::Mean); + auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); ASSERT_FALSE(nexp0->equalsTo(z)); @@ -1150,14 +1150,14 @@ TEST_F(RNGTests, test_multinomial_5) { auto outputR = resultR.at(0); ASSERT_EQ(Status::OK(), resultR.status()); - deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false); - mean = outputR->meanNumber(); + deviation = outputR.varianceNumber(variance::SummaryStatsStandardDeviation, false); + mean = outputR.meanNumber(); // printf("Random seed - Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); ASSERT_NEAR(0.5, deviation.e(0), 45e-3); // 1000000 35e-3); ASSERT_NEAR(0.5, mean.e(0), 45e-3); // 1000000 35e-3); - for (int i = 0; i < outputR->lengthOf(); i++) { - auto value = outputR->e(i); + for (int i = 0; i < outputR.lengthOf(); i++) { + auto value = outputR.e(i); ASSERT_TRUE(value >= 0 && value < ClassValue); } @@ -1184,8 +1184,8 @@ TEST_F(RNGTests, test_multinomial_6) { NDArray countsR('c', { ClassValue }, { 0., 0, 0, 0, 0 }, sd::DataType::DOUBLE); - for (int i = 0; i < outputR->lengthOf(); i++) { - auto value = outputR->e(i); + for (int i = 0; i < outputR.lengthOf(); i++) { + auto value = outputR.e(i); ASSERT_TRUE(value >= 0 && value < ClassValue); double* z = countsR.bufferAsT(); z[value] += 1; @@ -1198,8 +1198,8 @@ TEST_F(RNGTests, test_multinomial_6) { ASSERT_NEAR((c / Samples), p, 45e-3); // 1000000 35e-3); } - auto deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false); - auto mean = outputR->meanNumber(); + auto deviation = outputR.varianceNumber(variance::SummaryStatsStandardDeviation, false); + auto mean = outputR.meanNumber(); // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); ASSERT_NEAR(1.2175, deviation.e(0), 45e-3); // 1000000 35e-3); ASSERT_NEAR(2.906, mean.e(0), 45e-3); // 1000000 35e-3); diff --git a/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp b/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp index 4ca8a3806572..bfee646cb7a8 100644 --- a/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp @@ -42,8 +42,8 @@ TEST_F(ResultSetTests, basic_test_1) { ASSERT_EQ(3, set.size()); for (int e = 0; e < set.size(); e++) - ASSERT_EQ(5, set.at(e)->lengthOf()); + ASSERT_EQ(5, set.at(e).lengthOf()); for (int e = 0; e < tensors.size(); e++) - ASSERT_EQ(5, tensors.at(e)->lengthOf()); + ASSERT_EQ(5, tensors.at(e).lengthOf()); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/SanityTests.cpp b/libnd4j/tests_cpu/layers_tests/SanityTests.cpp index c84f45bb3417..47cbd5f0518e 100644 --- a/libnd4j/tests_cpu/layers_tests/SanityTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/SanityTests.cpp @@ -31,32 +31,22 @@ class SanityTests : public testing::Test { }; - -TEST_F(SanityTests, VariableSpace_1) { - VariableSpace variableSpace; - variableSpace.putVariable(1, new Variable()); - variableSpace.putVariable(1, 1, new Variable()); - - std::pair pair(1, 2); - variableSpace.putVariable(pair, new Variable()); -} - TEST_F(SanityTests, VariableSpace_2) { VariableSpace variableSpace; - variableSpace.putVariable(1, new Variable(NDArrayFactory::create_('c', {3, 3}))); - variableSpace.putVariable(1, 1, new Variable(NDArrayFactory::create_('c', {3, 3}))); + variableSpace.putVariable(1, NDArrayFactory::create('c', {3, 3})); + variableSpace.putVariable({1, 1}, NDArrayFactory::create('c', {3, 3})); std::pair pair(1, 2); - variableSpace.putVariable(pair, new Variable(NDArrayFactory::create_('c', {3, 3}))); + variableSpace.putVariable(pair, NDArrayFactory::create('c', {3, 3})); } TEST_F(SanityTests, Graph_1) { Graph graph; - graph.variableSpace()->putVariable(1, new Variable(NDArrayFactory::create_('c', {3, 3}))); - graph.variableSpace()->putVariable(1, 1, new Variable(NDArrayFactory::create_('c', {3, 3}))); + graph.variableSpace().putVariable(1, NDArrayFactory::create('c', {3, 3})); + graph.variableSpace().putVariable({1, 1}, NDArrayFactory::create('c', {3, 3})); std::pair pair(1, 2); - graph.variableSpace()->putVariable(pair, new Variable(NDArrayFactory::create_('c', {3, 3}))); + graph.variableSpace().putVariable(pair, NDArrayFactory::create('c', {3, 3})); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp b/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp deleted file mode 100644 index 8481dfde5cef..000000000000 --- a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp +++ /dev/null @@ -1,93 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#ifndef LIBND4J_SESSIONLOCALTESTS_H -#define LIBND4J_SESSIONLOCALTESTS_H - -#include "testlayers.h" -#include -#include - -using namespace sd::graph; - -class SessionLocalTests : public testing::Test { -public: - -}; - -TEST_F(SessionLocalTests, BasicTests_1) { - VariableSpace variableSpace; - SessionLocalStorage storage(&variableSpace, nullptr); - - if (omp_get_max_threads() <= 1) - return; - - PRAGMA_OMP_PARALLEL_FOR_THREADS(4) - for (int e = 0; e < 4; e++) { - storage.startSession(); - } - - ASSERT_EQ(4, storage.numberOfSessions()); - - PRAGMA_OMP_PARALLEL_FOR_THREADS(4) - for (int e = 0; e < 4; e++) { - storage.endSession(); - } - - ASSERT_EQ(0, storage.numberOfSessions()); -} - - -TEST_F(SessionLocalTests, BasicTests_2) { - VariableSpace variableSpace; - SessionLocalStorage storage(&variableSpace, nullptr); - - if (omp_get_max_threads() <= 1) - return; - - auto alpha = sd::NDArrayFactory::create_('c',{5,5}); - alpha->assign(0.0); - - variableSpace.putVariable(-1, alpha); - - PRAGMA_OMP_PARALLEL_FOR_THREADS(4) - for (int e = 0; e < 4; e++) { - storage.startSession(); - - auto varSpace = storage.localVariableSpace(); - - auto arr = varSpace->getVariable(-1)->getNDArray(); - arr->applyScalar(sd::scalar::Add, (float) e+1, *arr); - } - - float lastValue = 0.0f; - for (int e = 1; e <= 4; e++) { - auto varSpace = storage.localVariableSpace((Nd4jLong) e); - - auto arr = varSpace->getVariable(-1)->getNDArray(); - - //nd4j_printf("Last value: %f; Current value: %f\n", lastValue, arr->e(0)); - - ASSERT_NE(lastValue, arr->e(0)); - lastValue = arr->e(0); - } -} - -#endif //LIBND4J_SESSIONLOCALTESTS_H diff --git a/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp b/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp index cc13f3529df5..b8ac4c8876c5 100644 --- a/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp @@ -146,7 +146,7 @@ TEST_F(SingleDimTests, Test_Squeeze_1) { auto z = result.at(0); - ASSERT_EQ(exp.rankOf(), z->rankOf()); + ASSERT_EQ(exp.rankOf(), z.rankOf()); ASSERT_TRUE(exp.equalsTo(z)); diff --git a/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp index 16e7cf7ac310..299972c7d0fb 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp @@ -31,7 +31,7 @@ class VariableProxyTests : public testing::Test { TEST_F(VariableProxyTests, Test_Simple_1) { - auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); VariableSpace ref; ref.putVariable(119, x); @@ -45,7 +45,7 @@ TEST_F(VariableProxyTests, Test_Simple_1) { TEST_F(VariableProxyTests, Test_Simple_2) { - auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); VariableSpace ref; ASSERT_FALSE(ref.hasVariable(119)); @@ -63,8 +63,8 @@ TEST_F(VariableProxyTests, Test_Simple_2) { TEST_F(VariableProxyTests, Test_Simple_3) { - auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create_('c', {2, 2}, {4, 2, 3, 1}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 2}, {4, 2, 3, 1}); VariableSpace ref; ref.putVariable(119, x); @@ -85,14 +85,14 @@ TEST_F(VariableProxyTests, Test_Simple_3) { auto z1 = proxy.getVariable(119)->getNDArray(); ASSERT_FALSE(z0 == z1); - ASSERT_TRUE(y == z1); - ASSERT_TRUE(x == z0); + ASSERT_TRUE(y == *z1); + ASSERT_TRUE(x == *z0); } TEST_F(VariableProxyTests, Test_Simple_4) { - auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create_('c', {2, 2}, {4, 2, 3, 1}); - auto z = NDArrayFactory::create_('c', {2, 2}, {4, 1, 3, 2}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 2}, {4, 2, 3, 1}); + auto z = NDArrayFactory::create('c', {2, 2}, {4, 1, 3, 2}); VariableSpace ref; ref.putVariable(119, x); @@ -116,14 +116,14 @@ TEST_F(VariableProxyTests, Test_Simple_4) { auto z1 = proxy.getVariable(119)->getNDArray(); ASSERT_FALSE(z0 == z1); - ASSERT_TRUE(y == z1); - ASSERT_TRUE(x == z0); + ASSERT_TRUE(y == *z1); + ASSERT_TRUE(x == *z0); } TEST_F(VariableProxyTests, Test_Cast_1) { - auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create_('c', {2, 2}, {4, 2, 3, 1}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 2}, {4, 2, 3, 1}); VariableSpace ref; ref.putVariable(-119, x); @@ -145,29 +145,6 @@ TEST_F(VariableProxyTests, Test_Cast_1) { auto z1 = cast->getVariable(-119)->getNDArray(); ASSERT_FALSE(z0 == z1); - ASSERT_TRUE(y == z1); - ASSERT_TRUE(x == z0); -} - - -TEST_F(VariableProxyTests, Test_Clone_1) { - auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create_('c', {2, 2}, {4, 2, 3, 1}); - VariableSpace ref; - - ref.putVariable(118, x); - - VariableProxy proxy(&ref); - - proxy.putVariable(119, y); - - ASSERT_TRUE(proxy.hasVariable(118)); - ASSERT_TRUE(proxy.hasVariable(119)); - - auto clone = proxy.clone(); - - ASSERT_TRUE(clone->hasVariable(118)); - ASSERT_TRUE(clone->hasVariable(119)); - - delete clone; + ASSERT_TRUE(y == *z1); + ASSERT_TRUE(x == *z0); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp index 26eef254cf9f..21407178210b 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp @@ -44,8 +44,8 @@ class VariableSpaceTest : public testing::Test { TEST_F(VariableSpaceTest, SettersGettersTest1) { auto space1 = new VariableSpace(); - auto arrayA = NDArrayFactory::create_('c', {5, 5}); - auto arrayB = NDArrayFactory::create_('c', {3, 3}); + auto arrayA = NDArrayFactory::create('c', {5, 5}); + auto arrayB = NDArrayFactory::create('c', {3, 3}); space1->putVariable(1, arrayA); space1->putVariable(2, arrayB); @@ -53,8 +53,8 @@ TEST_F(VariableSpaceTest, SettersGettersTest1) { auto arrayRA = space1->getVariable(1); auto arrayRB = space1->getVariable(2); - ASSERT_TRUE(arrayA == arrayRA->getNDArray()); - ASSERT_TRUE(arrayB == arrayRB->getNDArray()); + ASSERT_TRUE(arrayA == *arrayRA->getNDArray()); + ASSERT_TRUE(arrayB == *arrayRB->getNDArray()); // we should survive this call delete space1; @@ -63,16 +63,11 @@ TEST_F(VariableSpaceTest, SettersGettersTest1) { TEST_F(VariableSpaceTest, SettersGettersTest2) { auto space1 = new VariableSpace(); - auto arrayA = NDArrayFactory::create_('c', {5, 5}); - auto arrayB = NDArrayFactory::create_('c', {3, 3}); + auto arrayA = NDArrayFactory::create('c', {5, 5}); + auto arrayB = NDArrayFactory::create('c', {3, 3}); - auto varA = new Variable(arrayA); - auto varB = new Variable(arrayB); - - varA->markExternal(true); - - space1->putVariable(-1, varA); - space1->putVariable(2, varB); + space1->putVariable(-1, arrayA); + space1->putVariable(2, arrayB); Nd4jLong expExternal = (25 * 4) + (8 * 8); Nd4jLong expInternal = (9 * 4) + (8 * 8); @@ -88,8 +83,8 @@ TEST_F(VariableSpaceTest, EqualityTest1) { std::string name("myvar"); - auto arrayA = NDArrayFactory::create_('c', {3, 3}); - auto variableA = new Variable(arrayA, name.c_str()); + auto arrayA = NDArrayFactory::create('c', {3, 3}); + auto variableA = std::make_shared(arrayA, name, 1); space.putVariable(1, variableA); @@ -110,7 +105,7 @@ TEST_F(VariableSpaceTest, EqualityTest1) { TEST_F(VariableSpaceTest, EqualityTest2) { VariableSpace space; - auto arrayA = NDArrayFactory::create_('c', {3, 3}); + auto arrayA = NDArrayFactory::create('c', {3, 3}); space.putVariable(1, arrayA); @@ -123,98 +118,4 @@ TEST_F(VariableSpaceTest, EqualityTest2) { auto rV2 = space.getVariable(pair); ASSERT_TRUE(rV1 == rV2); -} - -TEST_F(VariableSpaceTest, CloneTests_1) { - VariableSpace spaceA; - - auto arrayA = NDArrayFactory::create_('c', {3, 3}); - arrayA->assign(1.0); - - spaceA.putVariable(1, arrayA); - - auto spaceB = spaceA.clone(); - - std::pair pair(1,0); - - ASSERT_TRUE(spaceB->hasVariable(1)); - ASSERT_TRUE(spaceB->hasVariable(pair)); - - auto arrayB = spaceB->getVariable(1)->getNDArray(); - - ASSERT_TRUE(arrayA->equalsTo(arrayB)); - - arrayB->assign(2.0); - - ASSERT_FALSE(arrayA->equalsTo(arrayB)); - - delete spaceB; -} - -TEST_F(VariableSpaceTest, CloneTests_2) { - VariableSpace spaceA; - - auto arrayA = NDArrayFactory::create_('c', {3, 3}); - arrayA->assign(1.0); - - auto variableA = new Variable(arrayA, "alpha"); - - std::string str("alpha"); - std::pair pair(2, 3); - - spaceA.putVariable(pair, variableA); - - ASSERT_TRUE(spaceA.hasVariable(str)); - ASSERT_TRUE(spaceA.hasVariable(pair)); - - auto spaceB = spaceA.clone(); - - ASSERT_FALSE(spaceB->hasVariable(1)); - ASSERT_FALSE(spaceB->hasVariable(2)); - ASSERT_TRUE(spaceB->hasVariable(pair)); - ASSERT_TRUE(spaceB->hasVariable(str)); - - auto arrayB = spaceB->getVariable(pair)->getNDArray(); - - ASSERT_TRUE(arrayA->equalsTo(arrayB)); - - arrayB->assign(2.0); - - ASSERT_FALSE(arrayA->equalsTo(arrayB)); - - delete spaceB; - - ASSERT_TRUE(spaceA.hasVariable(str)); - ASSERT_TRUE(spaceA.hasVariable(pair)); -} - - -TEST_F(VariableSpaceTest, Test_DType_Conversion_1) { - /* - VariableSpace spaceA; - - auto arrayA = NDArrayFactory::create_('c', {3, 3}); - arrayA->assign(1.0); - - auto variableA = new Variable(arrayA, "alpha"); - - std::string str("alpha"); - std::pair pair(2, 3); - - spaceA.putVariable(pair, variableA); - - - auto sd = spaceA.template asT(); - auto sf = sd->template asT(); - - ASSERT_TRUE(sf->hasVariable(pair)); - - auto xf = sf->getVariable(pair)->getNDArray(); - - ASSERT_TRUE(arrayA->isSameShape(xf)); - ASSERT_TRUE(arrayA->equalsTo(xf)); - - delete sd; - delete sf; - */ } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp index acf92a0075d5..90deeab8d894 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp @@ -34,34 +34,6 @@ class VariableTests : public testing::Test { }; -TEST_F(VariableTests, TestClone_1) { - auto array1 = NDArrayFactory::create_('c', {5, 5}); - array1->assign(1.0); - - auto var1 = new Variable(array1, "alpha"); - var1->setId(119); - - - auto var2 = var1->clone(); - - ASSERT_FALSE(var1->getNDArray() == var2->getNDArray()); - auto array2 = var2->getNDArray(); - - ASSERT_TRUE(array1->equalsTo(array2)); - ASSERT_EQ(var1->id(), var2->id()); - ASSERT_EQ(var1->getName(), var2->getName()); - - delete var1; - - std::string str("alpha"); - ASSERT_EQ(var2->getName(), str); - array2->assign(2.0); - - ASSERT_NEAR(2.0, array2->meanNumber().e(0), 1e-5); - - delete var2; -} - TEST_F(VariableTests, Test_FlatVariableDataType_1) { flatbuffers::FlatBufferBuilder builder(1024); auto original = NDArrayFactory::create('c', {5, 10}); @@ -90,8 +62,8 @@ TEST_F(VariableTests, Test_FlatVariableDataType_1) { auto restoredArray = rv->getNDArray(); - ASSERT_TRUE(original.isSameShape(restoredArray)); - ASSERT_TRUE(original.equalsTo(restoredArray)); + ASSERT_TRUE(original.isSameShape(*restoredArray)); + ASSERT_TRUE(original.equalsTo(*restoredArray)); delete rv; } @@ -124,8 +96,8 @@ TEST_F(VariableTests, Test_FlatVariableDataType_2) { auto restoredArray = rv->getNDArray(); - ASSERT_TRUE(original.isSameShape(restoredArray)); - ASSERT_TRUE(original.equalsTo(restoredArray)); + ASSERT_TRUE(original.isSameShape(*restoredArray)); + ASSERT_TRUE(original.equalsTo(*restoredArray)); delete rv; } @@ -162,7 +134,7 @@ TEST_F(VariableTests, Test_FlatVariableDataType_3) { auto restoredArray = rv->getNDArray(); auto conv = restoredArray->asT(); - ASSERT_TRUE(floating.isSameShape(restoredArray)); + ASSERT_TRUE(floating.isSameShape(*restoredArray)); ASSERT_TRUE(floating.equalsTo(conv)); delete rv; @@ -202,24 +174,4 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) { delete rv; } */ -TEST_F(VariableTests, Test_Dtype_Conversion_1) { - auto x = NDArrayFactory::create_('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - Variable v(x, "alpha", 12, 3); - - auto vd = v.template asT(); - auto vf = vd->template asT(); - - ASSERT_EQ(v.getName(), vf->getName()); - ASSERT_EQ(v.id(), vf->id()); - ASSERT_EQ(v.index(), vf->index()); - - auto xf = vf->getNDArray(); - - ASSERT_TRUE(x->isSameShape(xf)); - ASSERT_TRUE(x->equalsTo(xf)); - - delete vd; - delete vf; -} - #endif //LIBND4J_VARIABLETESTS_H From b21f5f8924445afca60522593dee9497fb8ab776 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 3 Apr 2020 11:34:07 +0300 Subject: [PATCH 078/233] first set of small fixes Signed-off-by: raver119 --- libnd4j/include/graph/VariableProxy.h | 6 +++--- libnd4j/include/graph/VariableSpace.h | 8 ++++---- libnd4j/include/graph/impl/Context.cpp | 6 ++++++ libnd4j/include/graph/impl/OptimizedGraph.cpp | 2 +- libnd4j/include/graph/impl/VariableProxy.cpp | 6 +++--- libnd4j/include/graph/impl/VariableSpace.cpp | 20 +++++++++---------- 6 files changed, 27 insertions(+), 21 deletions(-) diff --git a/libnd4j/include/graph/VariableProxy.h b/libnd4j/include/graph/VariableProxy.h index b49b9f55df34..eea5ababb3cc 100644 --- a/libnd4j/include/graph/VariableProxy.h +++ b/libnd4j/include/graph/VariableProxy.h @@ -55,9 +55,9 @@ namespace sd { virtual std::shared_ptr putVariable(int id, const NDArray &array) override; virtual std::shared_ptr putVariable(int id, int idx, const NDArray &array) override; virtual std::shared_ptr putVariable(const std::string &name, int id, int idx, const NDArray &array) override; - virtual void putVariable(const std::string& name, int id, int idx, std::shared_ptr variable) override; - virtual void putVariable(const std::pair& pair, std::shared_ptr variable) override; - virtual void putVariable(int id, std::shared_ptr variable) override; + virtual void putVariable(const std::string& name, int id, int idx, const std::shared_ptr &variable) override; + virtual void putVariable(const std::pair& pair, const std::shared_ptr &variable) override; + virtual void putVariable(int id, const std::shared_ptr &variable) override; virtual void replaceVariable(std::shared_ptr variable) override; diff --git a/libnd4j/include/graph/VariableSpace.h b/libnd4j/include/graph/VariableSpace.h index c7d6d0238c86..798e350aef7d 100644 --- a/libnd4j/include/graph/VariableSpace.h +++ b/libnd4j/include/graph/VariableSpace.h @@ -58,7 +58,7 @@ namespace sd { // placeholders. must be resolved before Graph execution std::vector> _placeholders; - void silentPutVariable(const std::pair& pair, std::shared_ptr variable); + void silentPutVariable(const std::pair& pair, const std::shared_ptr &variable); int _auto_counter = -1; @@ -99,9 +99,9 @@ namespace sd { virtual std::shared_ptr putVariable(int id, int idx, const std::shared_ptr &array); virtual std::shared_ptr putVariable(int id, int idx, const NDArray &array); virtual std::shared_ptr putVariable(const std::string &name, int id, int idx, const NDArray &array); - virtual void putVariable(const std::string& name, int id, int idx, std::shared_ptr variable); - virtual void putVariable(const std::pair& pair, std::shared_ptr variable); - virtual void putVariable(int id, std::shared_ptr variable); + virtual void putVariable(const std::string& name, int id, int idx, const std::shared_ptr &variable); + virtual void putVariable(const std::pair& pair, const std::shared_ptr &variable); + virtual void putVariable(int id, const std::shared_ptr &variable); virtual void dropVariable(const std::string &pair); virtual void dropVariable(const std::pair &pair); diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index da8a5037baee..cc32059fc184 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -480,10 +480,16 @@ namespace sd { } void Context::setInputArray(int index, const std::shared_ptr &array) { + if (_fastpath_in.size() < index + 1) + _fastpath_in.resize(index+1); + _fastpath_in[index] = array; } void Context::setOutputArray(int index, const std::shared_ptr &array) { + if (_fastpath_out.size() < index + 1) + _fastpath_out.resize(index+1); + _fastpath_out[index] = array; } } diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 769a1eaf7b09..f6de752dca96 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -107,7 +107,7 @@ namespace sd { int inExCounts = 0, inInternalCounts = 0; for (auto in = inputs.begin(); in != inputs.end(); ++in) { - if (originalGraph().variableSpace()->hasVariable(in->first, 0)) { + if (originalGraph().variableSpace().hasVariable(in->first, 0)) { inExCounts++; } else { diff --git a/libnd4j/include/graph/impl/VariableProxy.cpp b/libnd4j/include/graph/impl/VariableProxy.cpp index b7b087fde3a7..d90ed10d8a56 100644 --- a/libnd4j/include/graph/impl/VariableProxy.cpp +++ b/libnd4j/include/graph/impl/VariableProxy.cpp @@ -185,12 +185,12 @@ namespace sd { } - void VariableProxy::putVariable(const std::pair& pair, std::shared_ptr variable) { + void VariableProxy::putVariable(const std::pair& pair, const std::shared_ptr &variable) { _current->putVariable(pair, variable); } - void VariableProxy::putVariable(int id, std::shared_ptr variable) { + void VariableProxy::putVariable(int id, const std::shared_ptr &variable) { _current->putVariable(id, variable); } @@ -203,7 +203,7 @@ namespace sd { return _current->putVariable(id, idx, array); } - void VariableProxy::putVariable(const std::string& name, int id, int idx, std::shared_ptr array) { + void VariableProxy::putVariable(const std::string& name, int id, int idx, const std::shared_ptr &array) { _current->putVariable(name, id, idx, array); } diff --git a/libnd4j/include/graph/impl/VariableSpace.cpp b/libnd4j/include/graph/impl/VariableSpace.cpp index b5eacfe32507..8d8d2856eea5 100644 --- a/libnd4j/include/graph/impl/VariableSpace.cpp +++ b/libnd4j/include/graph/impl/VariableSpace.cpp @@ -174,7 +174,7 @@ namespace sd { std::shared_ptr VariableSpace::putVariable(const std::pair& pair, const NDArray &array) { - auto variable = std::make_shared(array, nullptr, pair.first, pair.second); + auto variable = std::make_shared(array, "", pair.first, pair.second); this->putVariable(pair, variable); return variable; } @@ -184,23 +184,23 @@ namespace sd { return this->putVariable(pair, array); } - void VariableSpace::putVariable(const std::string& name, int node, int idx, std::shared_ptr variable) { + void VariableSpace::putVariable(const std::string& name, int node, int idx, const std::shared_ptr &variable) { std::pair pair(node, idx); variable->setName(name); this->putVariable(pair, variable); } - void VariableSpace::silentPutVariable(const std::pair& pair, std::shared_ptr variable) { + void VariableSpace::silentPutVariable(const std::pair& pair, const std::shared_ptr &variable) { std::lock_guard lock(_varmap); _paired[pair] = variable; } - void VariableSpace::putVariable(const std::pair& pair, std::shared_ptr variable) { + void VariableSpace::putVariable(const std::pair& pair, const std::shared_ptr &variable) { silentPutVariable(pair, variable); if (variable->isPlaceholder()) - _placeholders.push_back(variable); + _placeholders.emplace_back(variable); // copying duplicate for compatibility if (pair.second == 0 && !this->hasVariable(pair.first)) { @@ -213,7 +213,7 @@ namespace sd { } - void VariableSpace::putVariable(int id, std::shared_ptr variable) { + void VariableSpace::putVariable(int id, const std::shared_ptr &variable) { // we don't want to add variables more then once if (_variables.count(id) > 0) { throw std::runtime_error("VariableSpace::putVariable - duplicate found"); @@ -235,12 +235,11 @@ namespace sd { // we have special list for external variables to ensure graph completeness if (id < 0) { - //if (variable->isExternal()) - _external.push_back(variable); + _external.emplace_back(variable); _variables[id] = variable; } else { - _internal.push_back(variable); + _internal.emplace_back(variable); } } @@ -250,13 +249,14 @@ namespace sd { this->silentPutVariable(pair, variable); if (variable->isPlaceholder()) - _placeholders.push_back(variable); + _placeholders.emplace_back(variable); } } std::shared_ptr VariableSpace::putVariable(int id, const NDArray &array) { auto var = std::make_shared(array, "", id, 0); this->putVariable(id, var); + return var; } std::shared_ptr VariableSpace::getVariable(int id) const { From aee40664383340659a4caa9b3bc32d9520bc5e17 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 3 Apr 2020 14:04:41 +0300 Subject: [PATCH 079/233] no crashing tests atm Signed-off-by: raver119 --- libnd4j/include/array/NDArray.h | 9 ++- libnd4j/include/array/cpu/NDArray.cpp | 2 - libnd4j/include/array/cuda/NDArray.cu | 1 - .../array/{NDArray.hXX => impl/NDArray.cpp} | 58 ++++++++++++++++--- libnd4j/include/graph/Context.h | 7 +++ libnd4j/include/graph/impl/Context.cpp | 13 ++++- .../loops/BroadcastPairwiseConverter.h | 1 + .../generic/nn/dot_product_attention.cpp | 2 +- .../include/ops/declarable/impl/BooleanOp.cpp | 5 +- .../ops/declarable/impl/DeclarableListOp.cpp | 3 +- .../ops/declarable/impl/DeclarableOp.cpp | 11 ++-- .../ops/declarable/impl/LegacyRandomOp.cpp | 2 +- .../ops/declarable/impl/PlatformHelper.cpp | 3 + libnd4j/include/system/op_boilerplate.h | 2 +- .../layers_tests/DeclarableOpsTests1.cpp | 7 ++- .../tests_cpu/layers_tests/OneOffTests.cpp | 3 + 16 files changed, 99 insertions(+), 30 deletions(-) rename libnd4j/include/array/{NDArray.hXX => impl/NDArray.cpp} (99%) diff --git a/libnd4j/include/array/NDArray.h b/libnd4j/include/array/NDArray.h index 867eaaf76350..044eb2ff5f3c 100644 --- a/libnd4j/include/array/NDArray.h +++ b/libnd4j/include/array/NDArray.h @@ -1491,7 +1491,14 @@ namespace sd { FORCEINLINE bool isAttached(); - NDArray* detach(); + NDArray detach(); + + /** + * This method returns true if array is valid array with some shape etc + * @return + */ + bool defined() const; + bool undefined() const; FORCEINLINE bool operator==(const NDArray &other) const; diff --git a/libnd4j/include/array/cpu/NDArray.cpp b/libnd4j/include/array/cpu/NDArray.cpp index 1d97ba61cd7f..cb8b7a161340 100644 --- a/libnd4j/include/array/cpu/NDArray.cpp +++ b/libnd4j/include/array/cpu/NDArray.cpp @@ -44,8 +44,6 @@ #include #include -#include - namespace sd { diff --git a/libnd4j/include/array/cuda/NDArray.cu b/libnd4j/include/array/cuda/NDArray.cu index 6e2c7980420c..2b64463ae564 100644 --- a/libnd4j/include/array/cuda/NDArray.cu +++ b/libnd4j/include/array/cuda/NDArray.cu @@ -44,7 +44,6 @@ #include #include #include -#include #include namespace sd { diff --git a/libnd4j/include/array/NDArray.hXX b/libnd4j/include/array/impl/NDArray.cpp similarity index 99% rename from libnd4j/include/array/NDArray.hXX rename to libnd4j/include/array/impl/NDArray.cpp index d280cec37f79..1a1914d55929 100644 --- a/libnd4j/include/array/NDArray.hXX +++ b/libnd4j/include/array/impl/NDArray.cpp @@ -20,12 +20,22 @@ #ifndef __NDARRAY__HPP__ #define __NDARRAY__HPP__ +#include #include #include #include #include #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace sd { @@ -42,17 +52,25 @@ SD_EXPORT std::u32string NDArray::e(const Nd4jLong i) const; // copy constructor NDArray::NDArray(const NDArray& other) { - _context = other._context; - _offset = 0; - - setShapeInfo(ShapeDescriptor(other.dataType(), other.ordering(), other.shapeOf(), other.rankOf())); - + //setShapeInfo(ShapeDescriptor(other.dataType(), other.ordering(), other.shapeOf(), other.rankOf())); +/* if(!isEmpty()) { _buffer = std::make_shared(other.lengthOf() * other.sizeOfT(), other.dataType(), other.getContext()->getWorkspace()); this->assign(&other); } else _buffer = std::make_shared(); + */ + _buffer = other._buffer; + _shapeInfo = other._shapeInfo; + _shapeInfoD = other._shapeInfoD; + _length = other._length; + _isAttached = other._isAttached; + _isView = other._isView; + _context = other._context; + _dataType = other._dataType; + _deviceId = other._deviceId; + _offset = other._offset; } //////////////////////////////////////////////////////////////////////// @@ -74,6 +92,14 @@ NDArray::NDArray(const char order, const std::vector &shape, sd::DataT _buffer->setToZeroBuffers(); } + bool NDArray::defined() const { + return _shapeInfo != nullptr; + } + + bool NDArray::undefined() const { + return _shapeInfo == nullptr; + } + //////////////////////////////////////////////////////////////////////// NDArray::NDArray(const char order, const std::vector &shape, const std::vector& data, sd::DataType dtype, sd::LaunchContext * context) { @@ -860,6 +886,18 @@ NDArray::NDArray(const std::vector& shape, const std::vectorassign(&other); @@ -876,6 +914,8 @@ NDArray::NDArray(const std::vector& shape, const std::vector(); } + */ + return *this; } @@ -1235,16 +1275,16 @@ template SD_EXPORT void NDArray::assign(const uint64_t& value, bool allowParalle template SD_EXPORT void NDArray::assign(const bool& value, bool allowParallelism); ////////////////////////////////////////////////////////////////////////// -NDArray* NDArray::detach() { +NDArray NDArray::detach() { if (!isAttached()) - return this; + return *this; std::shared_ptr newBuffer = std::make_shared(lengthOf() * sizeOfT(), dataType()); - auto result = new NDArray(newBuffer, ShapeDescriptor(dataType(), ordering(), shapeOf(), rankOf())); + NDArray result(newBuffer, ShapeDescriptor(dataType(), ordering(), shapeOf(), rankOf())); - result->assign(*this); + result.assign(*this); return result; } diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index 646df825f5f3..ca0be0e6b399 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -126,6 +126,13 @@ namespace sd { std::shared_ptr getNDArray(int idx) const; std::shared_ptr array(int idx) const; + /** + * This is special method, used only within Graph + * @param idx + * @return + */ + NDArray* arrayForOp(int idx) const; + /** * This method fetches variable from VariableSpace DIRECTLY diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index cc32059fc184..851648cde609 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -214,9 +214,7 @@ namespace sd { _variableSpace->putVariable(pair, var); } else { auto var = _variableSpace->getVariable(pair); - if (var->hasNDArray()) { - var->setNDArray(std::make_shared(array)); - } + var->setNDArray(std::make_shared(array)); } } } @@ -308,6 +306,15 @@ namespace sd { _fastpath_in[index] = std::make_shared(array); } + NDArray *Context::arrayForOp(int idx) const { + auto ptr = array(idx); + + if (ptr.get() != nullptr && ptr->undefined()) + return nullptr; + + return ptr.get(); + } + void Context::setInputArray(int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { auto array = std::make_shared(buffer, specialBuffer, reinterpret_cast(shapeInfo)); diff --git a/libnd4j/include/loops/BroadcastPairwiseConverter.h b/libnd4j/include/loops/BroadcastPairwiseConverter.h index cb0655224d97..c6160d953129 100644 --- a/libnd4j/include/loops/BroadcastPairwiseConverter.h +++ b/libnd4j/include/loops/BroadcastPairwiseConverter.h @@ -22,6 +22,7 @@ #define SD_BROADCASTPAIRWISECONVERTER_H #include +#include #include namespace sd { diff --git a/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp b/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp index bd0cf329a8b9..10032a5f2f5b 100644 --- a/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp @@ -73,7 +73,7 @@ namespace ops { *weights /= sqrt((double)keys->sizeAt(-2)); } - if(mask != nullptr){ + if(mask != nullptr && mask->defined()){ NDArray reshapedMask; if(weights->rankOf() == 4){ reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1}); diff --git a/libnd4j/include/ops/declarable/impl/BooleanOp.cpp b/libnd4j/include/ops/declarable/impl/BooleanOp.cpp index 4c2a77308fd1..f584f803cdcb 100644 --- a/libnd4j/include/ops/declarable/impl/BooleanOp.cpp +++ b/libnd4j/include/ops/declarable/impl/BooleanOp.cpp @@ -120,9 +120,8 @@ namespace sd { int cnt = -1; std::vector in; for (auto v: args) { - auto var = std::make_shared(v); - var->markRemovable(false); - in.push_back(cnt); + auto var = std::make_shared(*v, "", cnt); + in.emplace_back(cnt); variableSpace.putVariable(cnt--, var); } diff --git a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp index 15d76c75e461..da5671069319 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp @@ -145,8 +145,7 @@ namespace sd { if (var->hasNDArray()) { auto arr = var->getNDArray(); if (arr->isAttached()) { - auto d = arr->detach(); - res.push_back(*d); + res.push_back(arr->detach()); } else { var->markRemovable(false); res.push_back(*arr); diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index d049293c172d..35be39f46f8c 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -147,6 +147,9 @@ namespace sd { } } + if (z != nullptr && z->undefined()) + return nullptr; + return z; } @@ -957,11 +960,11 @@ namespace sd { Context ctx(1); for (int e = 0; e < inputs.size(); e++) { - ctx.setInputArray(e, *inputs[e]); + ctx.setInputArray(e, inputs[e] == nullptr ? NDArray() : *inputs[e]); } for (int e = 0; e < outputs.size(); e++) { - ctx.setOutputArray(e, *outputs[e]); + ctx.setOutputArray(e, outputs[e] == nullptr ? NDArray() : *outputs[e]); } @@ -1030,7 +1033,7 @@ namespace sd { if (v == nullptr) continue; - auto var = std::make_shared(v); + auto var = std::make_shared(*v, "", cnt, 0); var->markRemovable(false); in.push_back(cnt); variableSpace.putVariable(cnt--, var); @@ -1073,7 +1076,7 @@ namespace sd { arr->setContext(sd::LaunchContext::defaultContext()); arrayList.push_back(*arr.get()); } else { - arrayList.push_back(*arr->detach()); + arrayList.push_back(arr->detach()); } } else break; diff --git a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp index e83e3755dcf0..b89b476c0998 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp @@ -409,7 +409,7 @@ namespace sd { var->markRemovable(false); arrayList.push_back(*arr.get()); } else { - arrayList.push_back(*arr->detach()); + arrayList.push_back(arr->detach()); } } else break; diff --git a/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp b/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp index 6cd7873a1539..ccf69b586991 100644 --- a/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp +++ b/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp @@ -33,6 +33,9 @@ namespace sd { sd::NDArray* PlatformHelper::getNullifiedZ(graph::Context& block, int inputId) { auto result = getZ(block, inputId); + if (result != nullptr && result->undefined()) + return nullptr; + if (result != nullptr && !block.isInplace()) result->nullify(); diff --git a/libnd4j/include/system/op_boilerplate.h b/libnd4j/include/system/op_boilerplate.h index 251babee9ddb..ece447991055 100644 --- a/libnd4j/include/system/op_boilerplate.h +++ b/libnd4j/include/system/op_boilerplate.h @@ -1511,7 +1511,7 @@ #define CHECK_STASH(NAME) block.getStash()->checkStash(block.getNodeId(), NAME); #define UNSTASH(NAME) block.getStash()->extractArray(block.getNodeId(), NAME); -#define INPUT_VARIABLE(INDEX) block.array(INDEX).get() +#define INPUT_VARIABLE(INDEX) block.arrayForOp(INDEX) #define OUTPUT_VARIABLE(INDEX) reinterpret_cast(this->getZ(block, INDEX)) #define OUTPUT_NULLIFIED(INDEX) reinterpret_cast(this->getNullifiedZ(block, INDEX)) diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 649d4db6ab3d..2d266b2ef4b4 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -2101,12 +2101,12 @@ TEST_F(DeclarableOpsTests1, IsMax1) { exp.p(2, 2, true); sd::ops::ismax ismaxOp; - auto result = ismaxOp.evaluate({ &x }, {}, { 1 }); + auto result = ismaxOp.evaluate({ &x }, { 1 }); ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto res = result.at(0); - //res->printIndexedBuffer("IS_MAX"); + res.printIndexedBuffer("IS_MAX"); ASSERT_TRUE(exp.equalsTo(res)); @@ -3312,6 +3312,9 @@ TEST_F(DeclarableOpsTests1, Test_Expose_1) { } TEST_F(DeclarableOpsTests1, Test_Expose_2) { + if (1 > 0) + throw std::runtime_error("Not implemented yet"); + auto list = new NDArrayList(0, true); auto var = std::make_shared(NDArray(), "arraylist", -1, 0); diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index eb0c8e897cca..ca99df66d837 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -279,6 +279,9 @@ TEST_F(OneOffTests, test_identity_n_2) { } TEST_F(OneOffTests, test_non2d_1) { + if (1 > 0) + throw std::runtime_error("Not implemented yet"); + auto e = NDArrayFactory::create('c', {1, 1}, {5.42746449f}); auto graph = Graph::fromFlatBuffers("./resources/non2d_1.fb"); From be7cad9bba46f6e19c2674b65b28ca05e9a14884 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 3 Apr 2020 15:48:50 +0300 Subject: [PATCH 080/233] 30 tests fixed Signed-off-by: raver119 --- libnd4j/include/graph/Context.h | 4 ++-- libnd4j/include/graph/impl/Context.cpp | 6 +++--- libnd4j/include/graph/impl/VariableSpace.cpp | 17 +++++++-------- .../declarable/generic/parity_ops/expose.cpp | 2 +- .../ops/declarable/impl/DeclarableOp.cpp | 2 +- .../tests_cpu/layers_tests/ContextTests.cpp | 12 +++++------ .../layers_tests/ConvolutionTests2.cpp | 21 ++++++++++--------- 7 files changed, 32 insertions(+), 32 deletions(-) diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index ca0be0e6b399..8d57891fa7d7 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -151,9 +151,9 @@ namespace sd { void pushNDArrayListToVariableSpace(int nodeId, int index, const NDArrayList &list, bool track = true); void pushNDArrayListToVariableSpace(const std::pair& pair, const NDArrayList &list, bool track = true); - bool isValueAvailable(int idx = 0) const; + bool isValueAvailable(const std::string &name, int id, int idx = 0) const; - std::shared_ptr ensureVariable(int idx = 0); + std::shared_ptr ensureVariable(const std::string &name, int id, int idx = 0); unsigned long width() const override; diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index 851648cde609..49a7600f98d4 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -236,7 +236,7 @@ namespace sd { } } - std::shared_ptr Context::ensureVariable(int idx) { + std::shared_ptr Context::ensureVariable(const std::string &name, int id, int idx) { std::pair pair(this->nodeId(), idx); if (_variableSpace == nullptr) @@ -257,8 +257,8 @@ namespace sd { } } - bool Context::isValueAvailable(int idx) const { - auto var = const_cast(this)->ensureVariable(idx); + bool Context::isValueAvailable(const std::string &name, int id, int idx ) const { + auto var = const_cast(this)->ensureVariable(name, id, idx); if (var->variableType() == VariableType::NDARRAY) { return var->hasNDArray(); diff --git a/libnd4j/include/graph/impl/VariableSpace.cpp b/libnd4j/include/graph/impl/VariableSpace.cpp index 8d8d2856eea5..69991c8a61f8 100644 --- a/libnd4j/include/graph/impl/VariableSpace.cpp +++ b/libnd4j/include/graph/impl/VariableSpace.cpp @@ -99,7 +99,7 @@ namespace sd { } bool VariableSpace::hasVariable(int id) const { - return _variables.count(id) == 1; + return _variables.count(id) > 0; } bool VariableSpace::hasVariable(const std::pair& id) const { @@ -205,10 +205,10 @@ namespace sd { // copying duplicate for compatibility if (pair.second == 0 && !this->hasVariable(pair.first)) { this->putVariable(pair.first, variable); - } else { - if (!variable->getName().empty()) { - _symbolic[variable->getName()] = variable; - } + } + + if (!variable->getName().empty()) { + _symbolic[variable->getName()] = variable; } } @@ -229,18 +229,17 @@ namespace sd { if (!variable->getName().empty()) { //std::pair pair(*(variable->getName()), variable); - _symbolic[variable->getName()] = variable; + _symbolic[variable->name()] = variable; } // we have special list for external variables to ensure graph completeness - if (id < 0) { _external.emplace_back(variable); - - _variables[id] = variable; } else { _internal.emplace_back(variable); } + + _variables[id] = variable; } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp index b625c56a2fdf..bb9cbe5fd222 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp @@ -32,7 +32,7 @@ namespace sd { out->assign(in); } else if (inVar->variableType() == VariableType::ARRAY_LIST) { - auto var = block.ensureVariable(e); + auto var = block.ensureVariable(block.name(), block.nodeId(), e); if (!var->hasNDArrayList()) { auto list = inVar->getNDArrayList(); diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 35be39f46f8c..f1564d11dfb6 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -305,7 +305,7 @@ namespace sd { // we need to check, if Z is really needed std::pair pair(ctx.nodeId(), cnt++); - if (!ctx.isValueAvailable(pair.second)) { + if (!ctx.isValueAvailable(ctx.name(), ctx.nodeId(), pair.second)) { if (Environment::getInstance()->isDebugAndVerbose()) shape::printShapeInfoLinear("Going to create variable with shape", out); diff --git a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp index 4baf1278b696..9176b4f77452 100644 --- a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp @@ -149,8 +149,8 @@ TEST_F(ContextTests, Basic_Test_6) { Context ctx(1, &variableSpace); - auto v0 = ctx.ensureVariable(); - auto v1 = ctx.ensureVariable(1); + auto v0 = ctx.ensureVariable("", 1, 0); + auto v1 = ctx.ensureVariable("", 1, 1); ASSERT_TRUE(variableSpace.hasVariable(1, 0)); ASSERT_TRUE(variableSpace.hasVariable(1, 1)); @@ -168,8 +168,8 @@ TEST_F(ContextTests, Basic_Test_7) { Context ctx(1, &variableSpace); - auto v0 = ctx.ensureVariable(); - auto v1 = ctx.ensureVariable(1); + auto v0 = ctx.ensureVariable("", 1, 0); + auto v1 = ctx.ensureVariable("", 1, 1); ASSERT_TRUE(variableSpace.hasVariable(1, 0)); ASSERT_TRUE(variableSpace.hasVariable(1, 1)); @@ -214,8 +214,8 @@ TEST_F(ContextTests, Basic_Test_8) { auto z0 = variableSpace.getVariable(1, 0); auto z1 = variableSpace.getVariable(1, 1); - auto v0 = ctx.ensureVariable(); - auto v1 = ctx.ensureVariable(1); + auto v0 = ctx.ensureVariable("", 1, 0); + auto v1 = ctx.ensureVariable("", 1, 1); ASSERT_TRUE(v0 == z0); ASSERT_TRUE(v1 == z1); diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index 2bd6c5445cc2..ceb9e68e8958 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -946,24 +946,25 @@ TEST_F(ConvolutionTests2, maxpool2d_2) { auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); // auto z('c',{bS,iD,oH,oW}); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); + VariableSpace variableSpace; + variableSpace.putVariable(-1, x); // variableSpace->putVariable(1, &z); - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - block->appendI({kH,kW, sH,sW, pH,pW, dH,dW, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + Context block(1, &variableSpace, false); + block.setName("alpha"); + block.fillInputs({-1}); + block.appendI({kH,kW, sH,sW, pH,pW, dH,dW, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; sd::ops::maxpool2d pooling; - Nd4jStatus status = pooling.execute(block); + Nd4jStatus status = pooling.execute(&block); ASSERT_EQ(ND4J_STATUS_OK, status); - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(variableSpace.hasVariable(block.nodeId(), 0)); + ASSERT_TRUE(variableSpace.hasVariable("alpha")); + ASSERT_TRUE(variableSpace.hasVariable(block.nodeId())); + auto result = variableSpace.getVariable(block.nodeId())->getNDArray(); // result.printShapeInfo(); ASSERT_TRUE(exp.isSameShape(*result)); - - delete variableSpace; - delete block; } ////////////////////////////////////////////////////////////////////// From 9ee47c10032e3455f5e14143f0c0594e0f88f3cd Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 3 Apr 2020 16:40:26 +0300 Subject: [PATCH 081/233] check for inplace flag validity Signed-off-by: raver119 --- .../generic/broadcastable/floormod.cpp | 2 +- .../declarable/generic/shape/reshape_as.cpp | 13 +++++-------- .../ops/declarable/impl/DeclarableOp.cpp | 6 ++++++ .../layers_tests/DeclarableOpsTests14.cpp | 19 ++++++++++--------- .../layers_tests/DeclarableOpsTests7.cpp | 14 ++++++-------- .../layers_tests/DeclarableOpsTests9.cpp | 2 +- 6 files changed, 29 insertions(+), 27 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp index fac2099055c8..1319ccfd0a45 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp @@ -66,7 +66,7 @@ namespace sd { auto gradY = OUTPUT_VARIABLE(1); gradX->assign(epsNext); - NDArray temp(*epsNext); + auto temp = epsNext->dup(); BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, &temp); if (gradY->rankOf() == gradX->rankOf()) diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp index 90e2ff398aa8..c0008cb08b24 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp @@ -35,19 +35,16 @@ namespace sd { auto z = OUTPUT_VARIABLE(0); - if (x->reshapei(y->ordering(), y->getShapeAsVector())) { + // FIXME: add validation here? + auto tmp = x->reshape(y->ordering(), y->getShapeAsVector()); + z->assign(tmp); - z->assign(x); - return Status::OK(); - } - - return ND4J_STATUS_BAD_INPUT; + return Status::OK(); } DECLARE_SYN(reshape_as, reshapeas); DECLARE_SHAPE_FN(reshapeas) { - - return SHAPELIST(ShapeBuilders::copyShapeInfo(INPUT_VARIABLE(1)->getShapeInfo(), false, block.workspace())); + return SHAPELIST(CONSTANT(ShapeBuilders::copyShapeInfo(INPUT_VARIABLE(1)->shapeInfo(), false, block.workspace()))); } DECLARE_TYPES(reshapeas) { diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index f1564d11dfb6..db542b6e0bba 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -607,6 +607,10 @@ namespace sd { if (Environment::getInstance()->isProfiling()) timeEnter = std::chrono::system_clock::now(); + // make sure we're not trying to call non-inpace op inplace + if (block->isInplace() && !this->getOpDescriptor()->allowsInplace()) + throw std::runtime_error("DeclarableOp::execute - trying to execute non-inplace op as inplace"); + // basic validation: ensure inputs are set REQUIRE_OK(this->validateNonEmptyInput(*block)); @@ -617,6 +621,8 @@ namespace sd { REQUIRE_OK(this->validateDataTypes(*block)); + + // this method will allocate output NDArrays for this op auto numOutputs = this->prepareOutputs(*block); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp index bfe9f60d4ab2..e21b81dbe497 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp @@ -2064,20 +2064,21 @@ TEST_F(DeclarableOpsTests14, Reshape1) { auto y = NDArrayFactory::create('f', yShape); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + VariableSpace variableSpace; + variableSpace.putVariable(-1, x); + variableSpace.putVariable(-2, y); + + Context block(1, &variableSpace, false); + block.fillInputs({ -1, -2 }); sd::ops::reshapeas reshape; - reshape.execute(block); + reshape.execute(&block); - ASSERT_TRUE(x.isSameShape(y)); + ASSERT_TRUE(variableSpace.hasVariable(1)); + auto z = variableSpace.getVariable(1)->getNDArray().get(); - delete variableSpace; - delete block; + ASSERT_TRUE(y.isSameShape(z)); } ////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index d11b34b5ce16..ce258d613852 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -3316,9 +3316,8 @@ auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { }); // ---------------------------------------------------------------- sd::ops::roll op; - NDArray* y = nullptr; - auto result = op.execute({&x}, {y}, {}, {38}, {}, {}, true); - ASSERT_EQ(result, Status::OK()); + auto result = op.evaluate({&x}, {}, {38}, {}, {}, true); + ASSERT_EQ(result.status(), Status::OK()); //x.printIndexedBuffer("Output 4 inplace"); //exp.printIndexedBuffer("Expect 4 inplace"); @@ -3409,8 +3408,8 @@ auto exp = NDArrayFactory::create('c', {2, 3, 2}, { // ---------------------------------------------------------------- sd::ops::roll op; NDArray* y = nullptr; - auto result = op.execute({&x}, {y}, {}, {1, 2, 1, 0}, {}, {}, true); - ASSERT_EQ(result, Status::OK()); + auto result = op.evaluate({&x}, {}, {1, 2, 1, 0}, {}, {}, true); + ASSERT_EQ(result.status(), Status::OK()); //x.printIndexedBuffer("Output"); //exp.printIndexedBuffer("Expect"); @@ -3431,9 +3430,8 @@ auto exp = NDArrayFactory::create('c', {2, 3, 3}, { }); // ---------------------------------------------------------------- sd::ops::roll op; - NDArray* y = nullptr; - auto result = op.execute({&x}, {y}, {}, {1, 1}, {}, {}, true); - ASSERT_EQ(result, Status::OK()); + auto result = op.evaluate({&x}, {}, {1, 1}, {}, {}, true); + ASSERT_EQ(result.status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(&x)); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 713cb2df4148..f342506f463f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -2350,7 +2350,7 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) { eps.assign(1.f); sd::ops::floormod_bp op; - auto result = op.evaluate({&x, &y, &eps}, {}, {}); + auto result = op.evaluate({&x, &y, &eps}); ASSERT_TRUE(result.size() == 2); auto gradX = result.at(0); From 1e9d587117ed7e80f6071f0892558377be5a982d Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 3 Apr 2020 19:12:06 +0300 Subject: [PATCH 082/233] more fixed tests JacobiSVD, QR, cholesky etc Signed-off-by: raver119 --- libnd4j/include/helpers/cpu/hhColPivQR.cpp | 43 ++++++++----------- libnd4j/include/helpers/cpu/hhSequence.cpp | 11 +++-- libnd4j/include/helpers/cpu/householder.cpp | 29 ++++++------- libnd4j/include/helpers/cpu/jacobiSVD.cpp | 8 ++-- .../ops/declarable/helpers/cpu/lup.cpp | 4 +- .../include/ops/declarable/helpers/cpu/qr.cpp | 4 +- .../ops/declarable/helpers/cpu/svd.cpp | 20 ++++----- 7 files changed, 55 insertions(+), 64 deletions(-) diff --git a/libnd4j/include/helpers/cpu/hhColPivQR.cpp b/libnd4j/include/helpers/cpu/hhColPivQR.cpp index e118b0bf1c58..1a57749d0ebe 100644 --- a/libnd4j/include/helpers/cpu/hhColPivQR.cpp +++ b/libnd4j/include/helpers/cpu/hhColPivQR.cpp @@ -84,13 +84,11 @@ void HHcolPivQR::_evalData() { if(k != biggestColIndex) { - auto temp1 = new NDArray(_qr({0,0, k,k+1}, true)); - auto temp2 = new NDArray(_qr({0,0, biggestColIndex,biggestColIndex+1}, true)); - auto temp3 = *temp1; - temp1->assign(temp2); - temp2->assign(temp3); - delete temp1; - delete temp2; + auto temp1 = _qr({0,0, k,k+1}, true); + auto temp2 = _qr({0,0, biggestColIndex,biggestColIndex+1}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); T e0 = normsUpd.e(k); T e1 = normsUpd.e(biggestColIndex); @@ -108,11 +106,12 @@ void HHcolPivQR::_evalData() { } T normX; - NDArray* qrBlock = new NDArray(_qr({k,rows, k,k+1}, true)); - T c; - Householder::evalHHmatrixDataI(*qrBlock, c, normX); - _coeffs.p(k, c); - delete qrBlock; + { + auto qrBlock = _qr({k, rows, k, k + 1}, true); + T c; + Householder::evalHHmatrixDataI(qrBlock, c, normX); + _coeffs.p(k, c); + } _qr.p(k,k, normX); @@ -121,11 +120,9 @@ void HHcolPivQR::_evalData() { maxPivot = max; if(k < rows && (k+1) < cols) { - qrBlock = new NDArray(_qr({k, rows, k+1,cols}, true)); - auto tail = new NDArray(_qr({k+1,rows, k, k+1}, true)); - Householder::mulLeft(*qrBlock, *tail, _coeffs.e(k)); - delete qrBlock; - delete tail; + auto qrBlock = _qr({k, rows, k+1,cols}, true); + auto tail = _qr({k+1,rows, k, k+1}, true); + Householder::mulLeft(qrBlock, tail, _coeffs.e(k)); } for (int j = k + 1; j < cols; ++j) { @@ -153,13 +150,11 @@ void HHcolPivQR::_evalData() { for(int k = 0; k < _diagSize; ++k) { int idx = transp.e(k); - auto temp1 = new NDArray(_permut({0,0, k, k+1}, true)); - auto temp2 = new NDArray(_permut({0,0, idx,idx+1}, true)); - auto temp3 = *temp1; - temp1->assign(temp2); - temp2->assign(temp3); - delete temp1; - delete temp2; + auto temp1 = _permut({0,0, k, k+1}, true); + auto temp2 = _permut({0,0, idx,idx+1}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); } } diff --git a/libnd4j/include/helpers/cpu/hhSequence.cpp b/libnd4j/include/helpers/cpu/hhSequence.cpp index 8a2a35329ef2..84c9e017a922 100644 --- a/libnd4j/include/helpers/cpu/hhSequence.cpp +++ b/libnd4j/include/helpers/cpu/hhSequence.cpp @@ -43,26 +43,25 @@ void HHsequence::_mulLeft(NDArray& matrix) { const int cols = _vectors.sizeAt(1); const int inRows = matrix.sizeAt(0); - NDArray* block(nullptr); + NDArray block; for(int i = _diagSize - 1; i >= 0; --i) { if(_type == 'u') { - block = new NDArray(matrix({inRows-rows+_shift+ i,inRows, 0,0}, true)); + block = matrix({inRows-rows+_shift+ i,inRows, 0,0}, true); T _x = _coeffs.e(i); - Householder::mulLeft(*block, _vectors({i + 1 + _shift, rows, i, i+1}, true), _x); + Householder::mulLeft(block, _vectors({i + 1 + _shift, rows, i, i+1}, true), _x); _coeffs.p(i, _x); } else { - block = new NDArray(matrix({inRows-cols+_shift+i,inRows, 0,0}, true)); + block = matrix({inRows-cols+_shift+i,inRows, 0,0}, true); T _x = _coeffs.e(i); - Householder::mulLeft(*block, _vectors({i, i+1, i + 1 + _shift, cols}, true), _x); + Householder::mulLeft(block, _vectors({i, i+1, i + 1 + _shift, cols}, true), _x); _coeffs.p(i, _x); } - delete block; } } diff --git a/libnd4j/include/helpers/cpu/householder.cpp b/libnd4j/include/helpers/cpu/householder.cpp index 39d97f1e1244..090cea17038c 100644 --- a/libnd4j/include/helpers/cpu/householder.cpp +++ b/libnd4j/include/helpers/cpu/householder.cpp @@ -137,8 +137,8 @@ void Householder::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff } else if(coeff != (T)0.f) { - auto bottomPart = new NDArray(matrix({1,matrix.sizeAt(0), 0,0}, true)); - auto bottomPartCopy = *bottomPart; + auto bottomPart = matrix({1,matrix.sizeAt(0), 0,0}, true); + auto bottomPartCopy = bottomPart.dup(); if(tail.isColumnVector()) { @@ -148,7 +148,7 @@ void Householder::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff auto fistRow = matrix({0,1, 0,0}, true); resultingRow += fistRow; fistRow -= resultingRow * coeff; - *bottomPart -= mmul(column, resultingRow) * coeff; + bottomPart -= mmul(column, resultingRow) * coeff; } else { @@ -158,9 +158,8 @@ void Householder::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff auto fistRow = matrix({0,1, 0,0}, true); resultingRow += fistRow; fistRow -= resultingRow * coeff; - *bottomPart -= mmul(column, resultingRow) * coeff; + bottomPart -= mmul(column, resultingRow) * coeff; } - delete bottomPart; } } @@ -177,30 +176,28 @@ void Householder::mulRight(NDArray& matrix, const NDArray& tail, const T coef else if(coeff != (T)0.f) { - auto rightPart = new NDArray(matrix({0,0, 1,matrix.sizeAt(1)}, true)); - auto rightPartCopy = *rightPart; - auto fistCol = new NDArray(matrix({0,0, 0,1}, true)); + auto rightPart = matrix({0,0, 1,matrix.sizeAt(1)}, true); + auto rightPartCopy = rightPart.dup(); + auto fistCol = matrix({0,0, 0,1}, true); if(tail.isColumnVector()) { auto column = tail; auto row = tail.transpose(); auto resultingCol = mmul(rightPartCopy, column); - resultingCol += *fistCol; - *fistCol -= resultingCol * coeff; - *rightPart -= mmul(resultingCol, row) * coeff; + resultingCol += fistCol; + fistCol -= resultingCol * coeff; + rightPart -= mmul(resultingCol, row) * coeff; } else { auto row = tail; auto column = tail.transpose(); auto resultingCol = mmul(rightPartCopy, column); - resultingCol += *fistCol; - *fistCol -= resultingCol * coeff; - *rightPart -= mmul(resultingCol, row) * coeff; + resultingCol += fistCol; + fistCol -= resultingCol * coeff; + rightPart -= mmul(resultingCol, row) * coeff; } - delete rightPart; - delete fistCol; } } diff --git a/libnd4j/include/helpers/cpu/jacobiSVD.cpp b/libnd4j/include/helpers/cpu/jacobiSVD.cpp index 1be91d50b631..8c6d1ccc7b70 100644 --- a/libnd4j/include/helpers/cpu/jacobiSVD.cpp +++ b/libnd4j/include/helpers/cpu/jacobiSVD.cpp @@ -78,7 +78,7 @@ void JacobiSVD::mulRotationOnLeft(const int i, const int j, NDArray& block, c throw std::runtime_error("ops::helpers::JacobiSVD mulRotationOnLeft: second arguments is out of array row range !"); auto pTemp = block({i,j+1,j-i, 0,0,0}, true, true); - auto temp = pTemp; + auto temp = pTemp.dup(); pTemp.assign(mmul(rotation, temp)); } else { @@ -109,7 +109,7 @@ void JacobiSVD::mulRotationOnRight(const int i, const int j, NDArray& block, throw std::runtime_error("ops::helpers::JacobiSVD mulRotationOnRight: second argument is out of array column range !"); auto pTemp = block({0,0,0, i,j+1,j-i}, true, true); - auto temp = pTemp; + auto temp = pTemp.dup(); pTemp.assign(mmul(temp, rotation)); } else { @@ -393,7 +393,7 @@ void JacobiSVD::evalData(const NDArray& matrix) { if(_calcU) { auto temp1 = _u({0,0, pos,pos+1}, true); auto temp2 = _u({0,0, i,i+1}, true); - auto temp3 = temp1; + auto temp3 = temp1.dup(); temp1.assign(temp2); temp2.assign(temp3); } @@ -401,7 +401,7 @@ void JacobiSVD::evalData(const NDArray& matrix) { if(_calcV) { auto temp1 = _v({0,0, pos, pos+1}, true); auto temp2 = _v({0,0, i, i+1}, true); - auto temp3 = temp1; + auto temp3 = temp1.dup(); temp1.assign(temp2); temp2.assign(temp3); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index 8fc2ece79e2f..0585b5399f19 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -143,7 +143,7 @@ namespace helpers { const int columnNum = input->columns(); NDArray determinant = NDArrayFactory::create(1.f, context); - NDArray compoundMatrix = *input; // copy + NDArray compoundMatrix = input->dup(); // copy NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides permutationMatrix.setIdentity(); @@ -522,7 +522,7 @@ template NDArray output = NDArrayFactory::create(0., context); if (ND4J_STATUS_OK != determinant(context, &thisMatrix, &output)) return false; if (output.e(0) <= T(0)) return 0; - NDArray reversedMatrix(thisMatrix); + NDArray reversedMatrix = thisMatrix.dup(); if (ND4J_STATUS_OK != inverse(context, &thisMatrix, &reversedMatrix)) return false; if (ND4J_STATUS_OK != determinant(context, &reversedMatrix, &output)) return false; if (output.e(0) <= T(0)) return 0; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp b/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp index c2bbdb7c870b..cab9132eb43e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp @@ -60,7 +60,7 @@ namespace helpers { auto resR = fullMatricies?R->ulike():matrix->ulike(); std::vector q(M); - NDArray z = *matrix; + NDArray z = matrix->dup(); NDArray e('c', {M}, DataTypeUtils::fromT(), Q->getContext()); // two internal buffers and scalar for squared norm for (Nd4jLong k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number @@ -88,7 +88,7 @@ namespace helpers { resQ.assign(q[0]); // // MmulHelper::matmul(&q[0], matrix, &resR, false, false); for (Nd4jLong i = 1; i < N && i < M - 1; i++) { - auto tempResQ = resQ; + auto tempResQ = resQ.ulike(); MmulHelper::matmul(&q[i], &resQ, &tempResQ, false, false); // use mmulMxM? resQ = std::move(tempResQ); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp index 23ce7888071c..a5a07776d325 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp @@ -316,14 +316,14 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh if (_calcU) { auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1}, true); auto temp2 = _u({col1,col1+len+1, col1+jac,col1+jac+1}, true); - auto temp3 = temp1; + auto temp3 = temp1.dup(); temp1.assign(temp2); temp2.assign(temp3); } else { auto temp1 = _u({0,2, col1+i, col1+i+1}, true); auto temp2 = _u({0,2, col1+jac, col1+jac+1}, true); - auto temp3 = temp1; + auto temp3 = temp1.dup(); temp1.assign(temp2); temp2.assign(temp3); } @@ -331,7 +331,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh if(_calcV) { auto temp1 = _v({row1W,row1W+len, col1W+i, col1W+i+1}, true); auto temp2 = _v({row1W,row1W+len, col1W+jac, col1W+jac+1}, true); - auto temp3 = temp1; + auto temp3 = temp1.dup(); temp1.assign(temp2); temp2.assign(temp3); } @@ -643,14 +643,14 @@ void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA auto temp1 = U({0,0, i,i+1}, true); auto temp2 = U({0,0, i+1,i+2}, true); - auto temp3 = temp1; + auto temp3 = temp1.dup(); temp1.assign(temp2); temp2.assign(temp3); if(_calcV) { auto temp1 = V({0,0, i,i+1}, true); auto temp2 = V({0,0, i+1,i+2}, true); - auto temp3 = temp1; + auto temp3 = temp1.dup(); temp1.assign(temp2); temp2.assign(temp3); } @@ -668,7 +668,7 @@ void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA for(int i = 0; i < curSize/2; ++i) { auto temp3 = temp2({0,0, i,i+1}, true); auto temp4 = temp2({0,0, curSize-1-i,curSize-i}, true); - auto temp5 = temp3; + auto temp5 = temp3.dup(); temp3.assign(temp4); temp4.assign(temp5); } @@ -678,7 +678,7 @@ void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA for(int i = 0; i < curSize/2; ++i) { auto temp3 = temp2({0,0, i,i+1}, true); auto temp4 = temp2({0,0, curSize-1-i,curSize-i}, true); - auto temp5 = temp3; + auto temp5 = temp3.dup(); temp3.assign(temp4); temp4.assign(temp5); } @@ -815,18 +815,18 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif if(_calcU) { auto pTemp = _u({col1, col1+n+1, col1,col1+n+1}, true); - auto temp = pTemp; + auto temp = pTemp.dup(); pTemp.assign(mmul(temp, UofSVD)); } else { auto pTemp = _u({0,0, col1,col1+n+1}, true); - auto temp = pTemp; + auto temp = pTemp.dup(); pTemp.assign(mmul(temp, UofSVD)); } if (_calcV) { auto pTemp = _v({row1W,row1W+n, row1W,row1W+n}, true); - auto temp = pTemp; + auto temp = pTemp.dup(); pTemp.assign(mmul(temp, VofSVD)); } From f64b42650f816022d761e65ecaf4c068148949b3 Mon Sep 17 00:00:00 2001 From: Oleg Date: Mon, 6 Apr 2020 10:34:14 +0300 Subject: [PATCH 083/233] libnd4j corrections and optimization of graph topological sort, added tests for several cases, need more tests cases Signed-off-by: Oleg --- libnd4j/include/graph/OptimizedGraph.h | 29 ++-- libnd4j/include/graph/impl/OptimizedGraph.cpp | 149 +++++++++-------- .../layers_tests/GraphAnalysisTests.cpp | 158 +++++++++++++++--- 3 files changed, 233 insertions(+), 103 deletions(-) diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index fc1c4f105f7e..27e4b9d3fc4b 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -16,7 +16,9 @@ // // @author raver119@gmail.com +// @author oleg.semeniv@gmail.com // + #ifndef SD_OPTIMIZEDGRAPH_H #define SD_OPTIMIZEDGRAPH_H @@ -29,6 +31,7 @@ namespace sd { namespace graph { + class Graph; class NodeInfo; /** @@ -97,7 +100,7 @@ namespace sd { /* * optimize original graph */ - void optimizedGraph(); + void createOptimizedGraph(); /* * Topological graph analysis * @param const start node for search @@ -120,47 +123,53 @@ namespace sd { * @param node ID * @param layer ID * @param sequence ID + * @param map of layers and max sequence * @return stop iterating */ - bool layersSeqDefine(std::unordered_map& collection, int ID, int layer, int nStartSeq) const; + bool layersSeqDefine(std::unordered_map& collection, int ID, int layer, int nStartSeq, std::unordered_map& layersMaxSeq) const; /* * Initialize container with operations and context - * @param code reference to node information collector + * @param const reference to layers and sequence collection * @param reference to opSequence collector * @return stop iterating */ - bool initOpSeqContainer(const std::unordered_map& collection, std::vector>& vOpSeq) const; + bool initOpSeqContainer(const std::unordered_map& layersMaxSeq, std::vector>& vOpSeq) const; }; class NodeInfo { private: std::set sConnections; - bool bStart; + bool bInBranching; bool bOutBranching; + bool bProcessed; + int nLayer; int nSequence; public: + + NodeInfo(){ reset(); } + ~NodeInfo(){ reset(); } - void setStart(bool bValue) { bStart = bValue; } void setInBranching(bool bValue) { bInBranching = bValue; } void setOutBranching(bool bValue) { bOutBranching = bValue; } + void setProcessed() { bProcessed = true; } - void reset() { sConnections.clear(); bStart = bInBranching = bOutBranching = false; nLayer = 0; } + void reset() { sConnections.clear(); bProcessed = bInBranching = bOutBranching = false; nLayer = 0; nSequence = -1; } - int getLayer() const { return nLayer; } + int layer() const { return nLayer; } void setLayer(int layer) { nLayer = layer; } - int getSequence() const { return nSequence; } + int sequence() const { return nSequence; } void setSequence(int sequence) { nSequence = sequence; } void addConnection(int id) { sConnections.emplace(id); } const std::set& connections() const { return sConnections; } - bool isStart() const { return bStart; } bool isInBranching() const { return bInBranching; } bool isOutBranching() const { return bOutBranching; } + bool isProcessed() const { return bProcessed; } }; diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index f6de752dca96..5f3d5cfba8f3 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -16,6 +16,7 @@ // // @author raver119@gmail.com +// @author oleg.semeniv@gmail.com // #include @@ -26,7 +27,8 @@ namespace sd { OptimizedGraph::OptimizedGraph(Graph *original) { _originalGraph = original; _memoryManager = const_cast(&original->memoryManager()); - optimizedGraph(); + // create optimized graph + createOptimizedGraph(); } OptimizedGraph::OptimizedGraph(const OptimizedGraph &other) noexcept { @@ -89,12 +91,12 @@ namespace sd { return *_originalGraph; } - - bool OptimizedGraph::opGraphProto(std::unordered_map& collector, std::set& startNodes, std::set& inBranchingNodes) const { + bool OptimizedGraph::opGraphProto(std::unordered_map& collector, std::set& startNodes, + std::set& inBranchingNodes) const { if (originalGraph().unmappedNodes().empty()) return false; - + // iterate via original graph nodes to gather node information for (const auto& it : originalGraph().unmappedNodes()) { const auto& ID = it.first; @@ -117,14 +119,10 @@ namespace sd { collector[in->first].addConnection(ID); } } - - parentNode.setInBranching((inInternalCounts == inputs.size() && inInternalCounts > 1)); - - parentNode.setStart(inExCounts == inputs.size()); - parentNode.setSequence(-1); - - if (parentNode.isStart()) { - parentNode.setLayer(0); + // if move then 1 internal input this is in-branching node + parentNode.setInBranching( inInternalCounts > 1); + // gather start and in-branching node for the loop when operations put to OpSequence for opimized graph + if (inExCounts == inputs.size()) { startNodes.emplace(ID); } else { @@ -136,28 +134,28 @@ namespace sd { } bool OptimizedGraph::topolSearch(const int startNode, const std::unordered_map& collector, - std::vector >& opSeq) const { + std::vector >& opSeq) const { if (originalGraph().unmappedNodes().empty()) return false; auto itParent = collector.find(startNode); if (itParent != collector.end()) { - + // iterate via start nodes connections in depth for (const auto& itNodes : itParent->second.connections()) { auto itChild = collector.find(itNodes); if (itChild != collector.end()) { - + // if the child is in-branching node it will be treated as start node if (itChild->second.isInBranching()) { - return true; + continue; } - + // put operation to OpSequence container const auto it = originalGraph().unmappedNodes().find(itNodes); const auto& child = itChild->second; - opSeq[child.getLayer()][child.getSequence()].append(it->second.customOp(), it->second.contextPrototype()); - + opSeq[child.layer()][child.sequence()].append(it->second.customOp(), it->second.contextPrototype()); + // go to the child node connections topolSearch(itNodes, collector, opSeq); } } @@ -165,101 +163,110 @@ namespace sd { return true; } - void OptimizedGraph::optimizedGraph() { + void OptimizedGraph::createOptimizedGraph() { std::unordered_map collector; std::set startNodes, inBranching; + std::unordered_map layersMaxSeq; - // todo check this will be empty Optimized graph + // optimizing graph prototyping + // select start nodes + // create connections between nodes + // select in-branching nodes ( more then one iternal input -> outputs from other nodes) if (!opGraphProto(collector, startNodes, inBranching)) throw std::runtime_error("OptimizedGraph::optimizedGraph() - not prototyped"); - + + // next step set the node layer and it sequence in layer + // define max layers and max sequence per layer int startSeq = 0; for (const auto& id : startNodes) { - layersSeqDefine(collector, id, 0, startSeq); + layersMaxSeq[0] = startSeq; + layersSeqDefine(collector, id, 0, startSeq, layersMaxSeq); startSeq++; } - + + // init container to collect operations per node position (layer:sequence) std::vector> vOpSeq; - initOpSeqContainer(collector, vOpSeq); - + initOpSeqContainer(layersMaxSeq, vOpSeq); + + // combine start nodes and in-branching nodes startNodes.insert(inBranching.begin(), inBranching.end()); - + // iterate via start and in-branching nodes for (const auto& id : startNodes) { - + const auto it = originalGraph().unmappedNodes().find(id); const auto& nodeInfo = collector[id]; - vOpSeq[nodeInfo.getLayer()][nodeInfo.getSequence()].append(it->second.customOp(), it->second.contextPrototype()); - + vOpSeq[nodeInfo.layer()][nodeInfo.sequence()].append(it->second.customOp(), it->second.contextPrototype()); + // search in depth via connections of "start" node topolSearch(id, collector, vOpSeq); } - + // put results to optimized graph for (auto& vSeq : vOpSeq) { this->append(vSeq); } } - bool OptimizedGraph::initOpSeqContainer(const std::unordered_map& collection, std::vector>& vOpSeq) const { + bool OptimizedGraph::initOpSeqContainer(const std::unordered_map& layersMaxSeq, std::vector>& vOpSeq) const { - if (collection.empty()) + if (layersMaxSeq.empty()) return false; - int layer = 0; - std::vector vSeq; - for (const auto& node : collection) { - - int nodeLayer = node.second.getLayer(); - int nodeSeq = node.second.getSequence(); - if (layer < nodeLayer) - layer = nodeLayer; - - if (vSeq.size() < nodeLayer + 1) { - vSeq.resize(nodeLayer + 1, 0); - } - // each layer will have it own max sequence - if (vSeq[nodeLayer] < nodeSeq) - vSeq[nodeLayer] = nodeSeq; - } - - vOpSeq.resize(layer + 1); - for (int i = 0; i <= layer; ++i) { - vOpSeq[i].resize(vSeq[i] + 1); + vOpSeq.resize(layersMaxSeq.size()); + for (const auto& it : layersMaxSeq) { + vOpSeq[it.first].resize(it.second + 1); } return true; } - bool OptimizedGraph::layersSeqDefine(std::unordered_map& collection, int ID, int layer, int startSeq) const { + bool OptimizedGraph::layersSeqDefine(std::unordered_map& collection, int ID, int layer, int startSeq, + std::unordered_map& layersMaxSeq) const { auto parent = collection.find(ID); if (parent == collection.end()) return false; - - if (parent->second.isInBranching()) { - layer++; - if (startSeq > 0) - startSeq--; + + // if node was proceed and the current layer is less of it own return + if(parent->second.isProcessed() && parent->second.layer() >= layer) + return true; + // put layer and sequence to container that collects layers and max sequence per layer + auto layerFound = layersMaxSeq.find(layer); + if(layerFound == layersMaxSeq.end()){ + layersMaxSeq[layer] = startSeq; } - - parent->second.setLayer(layer); - // sequence have to be init once - if (parent->second.getSequence() < 0) + else{ + layerFound->second = (layerFound->second < startSeq && parent->second.sequence() < 0) ? startSeq : layerFound->second; + } + // double check if the layer is higher and set node layer + if(parent->second.layer() < layer) + parent->second.setLayer(layer); + // double check if sequence was init, if not set current sequence + if(parent->second.sequence() < 0) parent->second.setSequence(startSeq); - + // set is node out-branching parent->second.setOutBranching(parent->second.connections().size() > 1); + // set that node was processed, to avoid it double processing (only for some cases it can be processed several times) + parent->second.setProcessed(); + + // if current node is out-branching it childs will be put to next layer if (parent->second.isOutBranching()) layer++; - + // for childs sequence position have to start from max defined sequence position + int seq = layersMaxSeq[layer]; for (const auto& id : parent->second.connections()) { auto child = collection.find(id); - - layersSeqDefine(collection, id, layer, startSeq); - - if (parent->second.isOutBranching()) - startSeq++; + if(child == collection.end()) + return false; + // in case parent was not out-branching node but child is in branching it will be put to next layer + if (!parent->second.isOutBranching() && child->second.isInBranching()) + layer++; + // move in depth of connections + layersSeqDefine(collection, id, layer, seq, layersMaxSeq); + + seq++; } return true; } } -} \ No newline at end of file +} diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index e7450dda1432..506f4f1121cb 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -113,7 +113,7 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { ASSERT_EQ(1, layer0.width()); auto sequence = layer0[0]; - // we expect that OpSequence has exactly 2 ops + // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, sequence.length()); ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); @@ -169,7 +169,7 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_3) { auto optimized = graph.optimizedGraph(); - // we expect that OptimizedGraph has exactly 1 layer + // we expect that OptimizedGraph has exactly 3 layer ASSERT_EQ(3, optimized.layers()); // checking first layer first @@ -215,8 +215,8 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_3) { ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); } -// currently does not work correctly -TEST_F(GraphAnalysisTests, DISABLED_basic_toposort_test_4) { + +TEST_F(GraphAnalysisTests, basic_toposort_test_4) { Graph graph; // A @@ -267,71 +267,185 @@ TEST_F(GraphAnalysisTests, DISABLED_basic_toposort_test_4) { auto optimized = graph.optimizedGraph(); - // we expect that OptimizedGraph has exactly 1 layer + // we expect that OptimizedGraph has exactly 4 layer ASSERT_EQ(4, optimized.layers()); // checking first layer first auto layer0 = optimized.layer(0); - // we expect layer has exactly 1 OpSequence + // we expect layer has exactly 2 OpSequence ASSERT_EQ(2, layer0.width()); auto sequence = layer0[0]; - // we expect that OpSequence has exactly 2 ops + // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, sequence.length()); - + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); sequence = layer0[1]; - // we expect that OpSequence has exactly 2 ops + // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, sequence.length()); - - // ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); - // ASSERT_EQ(6, sequence.at(1).protoContext().nodeId()); + ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); // checking second layer now auto layer1 = optimized.layer(1); - // we expect layer has exactly 2 OpSequences + // we expect layer has exactly 3 OpSequences ASSERT_EQ(3, layer1.width()); sequence = layer1[0]; ASSERT_EQ(1, sequence.length()); - // ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); sequence = layer1[1]; ASSERT_EQ(1, sequence.length()); - // ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(10, sequence.at(0).protoContext().nodeId()); + sequence = layer1[2]; - ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(11, sequence.at(0).protoContext().nodeId()); auto layer2 = optimized.layer(2); - // we expect layer has exactly 1 OpSequence + // we expect layer has exactly 2 OpSequence ASSERT_EQ(2, layer2.width()); sequence = layer2[0]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, sequence.length()); - - // ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(12, sequence.at(0).protoContext().nodeId()); sequence = layer2[1]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, sequence.length()); - // ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(13, sequence.at(0).protoContext().nodeId()); // checking last layer auto layer3 = optimized.layer(3); // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer2.width()); + ASSERT_EQ(1, layer3.width()); + sequence = layer3[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(14, sequence.at(0).protoContext().nodeId()); +} + + +TEST_F(GraphAnalysisTests, basic_toposort_test_5) { + Graph graph; + + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // D + graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::subtract(), "c"); + Node d(sd::ops::add(), "d"); + Node e(sd::ops::multiply(), "e"); + Node f(sd::ops::multiply(), "f"); + + + Node g(sd::ops::multiply(), "g"); + Node h(sd::ops::multiply(), "h"); + + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"C", "D"}); + + graph.addNode(c, {"a", "b"}); + graph.addNode(d, {"a", "b"}); + + graph.addNode(e, {"c", "d"}); + graph.addNode(f, {"c", "d"}); + + graph.addNode(g, {"c", "e"}); + graph.addNode(h, {"d", "f"}); + + + // we just check that nodes were really added + ASSERT_EQ(8, graph.size()); + + auto optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 3 layer + ASSERT_EQ(4, optimized.layers()); + + // checking first layer first + auto layer0 = optimized.layer(0); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer0.width()); + auto sequence = layer0[0]; + + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, sequence.length()); + + ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + + sequence = layer0[1]; + + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); + + // checking second layer now + auto layer1 = optimized.layer(1); + + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.width()); + + sequence = layer1[0]; + + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + + sequence = layer1[1]; + + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + + // checking before last layer + auto layer2 = optimized.layer(2); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer2.width()); sequence = layer2[0]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, sequence.length()); - // ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); + sequence = layer2[1]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(10, sequence.at(0).protoContext().nodeId()); + + // checking last layer + auto layer3 = optimized.layer(3); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer3.width()); + sequence = layer3[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(11, sequence.at(0).protoContext().nodeId()); + + sequence = layer3[1]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(12, sequence.at(0).protoContext().nodeId()); } From 6a2b3c5cd6ff0a616d09e68c9aea7fe4626d4e57 Mon Sep 17 00:00:00 2001 From: Oleg Date: Mon, 6 Apr 2020 17:18:16 +0300 Subject: [PATCH 084/233] libnd4j minor corrections, added one more test of graph, added description of steps Signed-off-by: Oleg --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 72 +++++++--- .../layers_tests/GraphAnalysisTests.cpp | 133 ++++++++++++++++++ 2 files changed, 182 insertions(+), 23 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 5f3d5cfba8f3..60de340c7bc2 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -93,35 +93,43 @@ namespace sd { bool OptimizedGraph::opGraphProto(std::unordered_map& collector, std::set& startNodes, std::set& inBranchingNodes) const { - + + // double check to avoid unstable behavior if (originalGraph().unmappedNodes().empty()) return false; + // iterate via original graph nodes to gather node information for (const auto& it : originalGraph().unmappedNodes()) { const auto& ID = it.first; const auto& inputs = it.second.input(); - - if (collector.find(ID) == collector.end()) + // if node info is not in collecter add it + if (collector.find(ID) == collector.end()) collector[ID] = NodeInfo(); NodeInfo& parentNode = collector[ID]; int inExCounts = 0, inInternalCounts = 0; - for (auto in = inputs.begin(); in != inputs.end(); ++in) { - if (originalGraph().variableSpace().hasVariable(in->first, 0)) { + for (const auto& in : inputs) { + if (originalGraph().variableSpace().hasVariable(in.first, 0)) { + // count external inputs, all inputs that are in + // varable space will be treated as external inputs inExCounts++; } else { + // count iternal inputs, all inputs that are not in external variable space + // will be treated as outputs from other nodes inInternalCounts++; - if (collector.find(in->first) == collector.end()) - collector[in->first] = NodeInfo(); - collector[in->first].addConnection(ID); + // if node info is not in collector add it + if (collector.find(in.first) == collector.end()) + collector[in.first] = NodeInfo(); + + collector[in.first].addConnection(ID); } } // if move then 1 internal input this is in-branching node parentNode.setInBranching( inInternalCounts > 1); - // gather start and in-branching node for the loop when operations put to OpSequence for opimized graph + // gather start and in-branching nodes for the loop when operations are put to OpSequence (topolSearch) if (inExCounts == inputs.size()) { startNodes.emplace(ID); } @@ -135,13 +143,15 @@ namespace sd { bool OptimizedGraph::topolSearch(const int startNode, const std::unordered_map& collector, std::vector >& opSeq) const { - - if (originalGraph().unmappedNodes().empty()) + + // double check to avoid unstable behavior + if (originalGraph().unmappedNodes().empty() || collector.empty() ) return false; - + + // skip nodes which are not pre-collected and pre-processed auto itParent = collector.find(startNode); if (itParent != collector.end()) { - // iterate via start nodes connections in depth + // iterate via start (in-branching) nodes connections in depth for (const auto& itNodes : itParent->second.connections()) { auto itChild = collector.find(itNodes); @@ -154,6 +164,7 @@ namespace sd { // put operation to OpSequence container const auto it = originalGraph().unmappedNodes().find(itNodes); const auto& child = itChild->second; + // the layer and sequence are pre-defined in layersSeqDefine method opSeq[child.layer()][child.sequence()].append(it->second.customOp(), it->second.contextPrototype()); // go to the child node connections topolSearch(itNodes, collector, opSeq); @@ -164,9 +175,12 @@ namespace sd { } void OptimizedGraph::createOptimizedGraph() { - + + // container to store node infor std::unordered_map collector; + // containers to store start and in-branching nodes std::set startNodes, inBranching; + // container to store max sequences per layer std::unordered_map layersMaxSeq; // optimizing graph prototyping @@ -174,20 +188,22 @@ namespace sd { // create connections between nodes // select in-branching nodes ( more then one iternal input -> outputs from other nodes) if (!opGraphProto(collector, startNodes, inBranching)) - throw std::runtime_error("OptimizedGraph::optimizedGraph() - not prototyped"); + throw std::runtime_error("OptimizedGraph::optimizedGraph() - not prototyped!"); // next step set the node layer and it sequence in layer // define max layers and max sequence per layer int startSeq = 0; for (const auto& id : startNodes) { layersMaxSeq[0] = startSeq; - layersSeqDefine(collector, id, 0, startSeq, layersMaxSeq); + if(!layersSeqDefine(collector, id, 0, startSeq, layersMaxSeq)) + throw std::runtime_error("OptimizedGraph::layersSeqDefine() - not all nodes properly prototyped!"); startSeq++; } // init container to collect operations per node position (layer:sequence) std::vector> vOpSeq; - initOpSeqContainer(layersMaxSeq, vOpSeq); + if(!initOpSeqContainer(layersMaxSeq, vOpSeq)) + throw std::runtime_error("OptimizedGraph::initOpSeqContainer() - cannot initialize OpSequence, not all nodes properly prototyped!"); // combine start nodes and in-branching nodes startNodes.insert(inBranching.begin(), inBranching.end()); @@ -198,7 +214,8 @@ namespace sd { const auto& nodeInfo = collector[id]; vOpSeq[nodeInfo.layer()][nodeInfo.sequence()].append(it->second.customOp(), it->second.contextPrototype()); // search in depth via connections of "start" node - topolSearch(id, collector, vOpSeq); + if(!topolSearch(id, collector, vOpSeq)) + throw std::runtime_error("OptimizedGraph::topolSearch() - cannot run topological search, inputs incorrect!"); } // put results to optimized graph for (auto& vSeq : vOpSeq) { @@ -207,7 +224,8 @@ namespace sd { } bool OptimizedGraph::initOpSeqContainer(const std::unordered_map& layersMaxSeq, std::vector>& vOpSeq) const { - + + // double check to avoid unstable behavior if (layersMaxSeq.empty()) return false; @@ -220,7 +238,8 @@ namespace sd { bool OptimizedGraph::layersSeqDefine(std::unordered_map& collection, int ID, int layer, int startSeq, std::unordered_map& layersMaxSeq) const { - + + // double check to avoid unstable behavior auto parent = collection.find(ID); if (parent == collection.end()) return false; @@ -228,14 +247,18 @@ namespace sd { // if node was proceed and the current layer is less of it own return if(parent->second.isProcessed() && parent->second.layer() >= layer) return true; + // put layer and sequence to container that collects layers and max sequence per layer auto layerFound = layersMaxSeq.find(layer); if(layerFound == layersMaxSeq.end()){ + // if layer was not treated before, create pair for it layersMaxSeq[layer] = startSeq; } else{ + // if node sequence position was not checked use it for max sequence selection layerFound->second = (layerFound->second < startSeq && parent->second.sequence() < 0) ? startSeq : layerFound->second; } + // double check if the layer is higher and set node layer if(parent->second.layer() < layer) parent->second.setLayer(layer); @@ -250,19 +273,22 @@ namespace sd { // if current node is out-branching it childs will be put to next layer if (parent->second.isOutBranching()) layer++; - // for childs sequence position have to start from max defined sequence position + + // childs sequence position have to start from max defined sequence position in layer int seq = layersMaxSeq[layer]; + // loop via childs (connected nodes) for (const auto& id : parent->second.connections()) { - + // double check to avoid unstable behavior auto child = collection.find(id); if(child == collection.end()) return false; + // in case parent was not out-branching node but child is in branching it will be put to next layer if (!parent->second.isOutBranching() && child->second.isInBranching()) layer++; // move in depth of connections layersSeqDefine(collection, id, layer, seq, layersMaxSeq); - + // increment sequence as childs are on the one layer seq++; } diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 506f4f1121cb..8ebd49e8ab94 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -449,3 +449,136 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_5) { ASSERT_EQ(1, sequence.length()); ASSERT_EQ(12, sequence.at(0).protoContext().nodeId()); } + +TEST_F(GraphAnalysisTests, basic_toposort_test_6) { + Graph graph; + + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // D + graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // E + graph.addVariable("E", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // F + graph.addVariable("F", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + Node a(sd::ops::multiply(), "a"); + Node b1(sd::ops::add(), "b1"); + Node b2(sd::ops::subtract(), "b2"); + + Node c1(sd::ops::add(), "c1"); + Node c2(sd::ops::multiply(), "c2"); + Node c3(sd::ops::subtract(), "c3"); + + Node d1(sd::ops::multiply(), "d1"); + Node d2(sd::ops::multiply(), "d2"); + + Node e(sd::ops::add(), "e"); + + graph.addNode(a, {"A", "B"}); + + graph.addNode(b1, {"a", "C"}); + graph.addNode(b2, {"a", "D"}); + + graph.addNode(c1, {"b1", "E"}); + graph.addNode(c2, {"b1", "b2"}); + graph.addNode(c3, {"b2", "F"}); + + graph.addNode(d1, {"c1", "c2"}); + graph.addNode(d2, {"c2", "c3"}); + + graph.addNode(e, {"d1", "d2"}); + + // we just check that nodes were really added + ASSERT_EQ(9, graph.size()); + + auto optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 3 layer + ASSERT_EQ(5, optimized.layers()); + + // checking first layer first + auto layer0 = optimized.layer(0); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer0.width()); + auto sequence = layer0[0]; + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + + // checking second layer now + auto layer1 = optimized.layer(1); + + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.width()); + + sequence = layer1[0]; + + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + + sequence = layer1[1]; + + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); + + // checking midle layer + auto layer2 = optimized.layer(2); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(3, layer2.width()); + sequence = layer2[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(10, sequence.at(0).protoContext().nodeId()); + + sequence = layer2[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(11, sequence.at(0).protoContext().nodeId()); + + sequence = layer2[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(12, sequence.at(0).protoContext().nodeId()); + + // checking before last layer + auto layer3 = optimized.layer(3); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer3.width()); + sequence = layer3[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(13, sequence.at(0).protoContext().nodeId()); + + sequence = layer3[1]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(14, sequence.at(0).protoContext().nodeId()); + + // checking last layer + auto layer4 = optimized.layer(4); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(1, layer4.width()); + sequence = layer4[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(15, sequence.at(0).protoContext().nodeId()); + +} From 1cb2d47e69bf2f368049e367a6d1d5e3ac5568ba Mon Sep 17 00:00:00 2001 From: Oleg Date: Tue, 7 Apr 2020 16:04:29 +0300 Subject: [PATCH 085/233] libnd4j: fixed behavior for several cases of directed graph, added tests Signed-off-by: Oleg --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 34 ++- .../layers_tests/GraphAnalysisTests.cpp | 200 ++++++++++++++++-- 2 files changed, 214 insertions(+), 20 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 60de340c7bc2..b7aaefacd77b 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -141,7 +141,7 @@ namespace sd { return true; } - bool OptimizedGraph::topolSearch(const int startNode, const std::unordered_map& collector, + bool OptimizedGraph::topolSearch(const int startNode, std::unordered_map& collector, std::vector >& opSeq) const { // double check to avoid unstable behavior @@ -158,14 +158,16 @@ namespace sd { if (itChild != collector.end()) { // if the child is in-branching node it will be treated as start node - if (itChild->second.isInBranching()) { + // skip processed nodes + if (itChild->second.isInBranching() || itChild->second.isProcessed()) { continue; } // put operation to OpSequence container const auto it = originalGraph().unmappedNodes().find(itNodes); - const auto& child = itChild->second; + auto& child = itChild->second; // the layer and sequence are pre-defined in layersSeqDefine method opSeq[child.layer()][child.sequence()].append(it->second.customOp(), it->second.contextPrototype()); + child.setProcessed(); // go to the child node connections topolSearch(itNodes, collector, opSeq); } @@ -205,14 +207,25 @@ namespace sd { if(!initOpSeqContainer(layersMaxSeq, vOpSeq)) throw std::runtime_error("OptimizedGraph::initOpSeqContainer() - cannot initialize OpSequence, not all nodes properly prototyped!"); + // re-init proceed NodeInfo member to avoid append sequence several times + for(auto& it : collector){ + it.second.setProcessed(false); + } + // combine start nodes and in-branching nodes startNodes.insert(inBranching.begin(), inBranching.end()); + // iterate via start and in-branching nodes for (const auto& id : startNodes) { const auto it = originalGraph().unmappedNodes().find(id); - const auto& nodeInfo = collector[id]; - vOpSeq[nodeInfo.layer()][nodeInfo.sequence()].append(it->second.customOp(), it->second.contextPrototype()); + auto& nodeInfo = collector[id]; + // check to avoid node processing twice + if(!nodeInfo.isProcessed()){ + vOpSeq[nodeInfo.layer()][nodeInfo.sequence()].append(it->second.customOp(), it->second.contextPrototype()); + nodeInfo.setProcessed(); + } + // search in depth via connections of "start" node if(!topolSearch(id, collector, vOpSeq)) throw std::runtime_error("OptimizedGraph::topolSearch() - cannot run topological search, inputs incorrect!"); @@ -252,10 +265,14 @@ namespace sd { auto layerFound = layersMaxSeq.find(layer); if(layerFound == layersMaxSeq.end()){ // if layer was not treated before, create pair for it - layersMaxSeq[layer] = startSeq; + layersMaxSeq[layer] = 0; } else{ // if node sequence position was not checked use it for max sequence selection + // double check if input sequence do not jump max value twice + if(startSeq > (layerFound->second + 1)) + startSeq = layerFound->second + 1; + layerFound->second = (layerFound->second < startSeq && parent->second.sequence() < 0) ? startSeq : layerFound->second; } @@ -282,13 +299,14 @@ namespace sd { auto child = collection.find(id); if(child == collection.end()) return false; - + // in case parent was not out-branching node but child is in branching it will be put to next layer if (!parent->second.isOutBranching() && child->second.isInBranching()) layer++; + // move in depth of connections layersSeqDefine(collection, id, layer, seq, layersMaxSeq); - // increment sequence as childs are on the one layer + // increment sequence as childs are on the one layer in case if child was not processed earlier seq++; } diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 8ebd49e8ab94..3f0370ed140e 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -150,19 +150,19 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_3) { // D graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); - Node a(sd::ops::multiply(), "multiply"); - Node b(sd::ops::add(), "add"); - Node c(sd::ops::subtract(), "subtract"); - Node d(sd::ops::add(), "add2"); - Node e(sd::ops::multiply(), "multiply2"); + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::subtract(), "c"); + Node d(sd::ops::add(), "d"); + Node e(sd::ops::multiply(), "e"); graph.addNode(a, {"A", "B"}); - graph.addNode(b, {"multiply", "C"}); + graph.addNode(b, {"a", "C"}); - graph.addNode(c, {"add", "D"}); - graph.addNode(d, {"add", "D"}); + graph.addNode(c, {"b", "D"}); + graph.addNode(d, {"b", "D"}); - graph.addNode(e, {"subtract", "add2"}); + graph.addNode(e, {"c", "d"}); // we just check that nodes were really added ASSERT_EQ(5, graph.size()); @@ -215,7 +215,6 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_3) { ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); } - TEST_F(GraphAnalysisTests, basic_toposort_test_4) { Graph graph; @@ -334,7 +333,6 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_4) { ASSERT_EQ(14, sequence.at(0).protoContext().nodeId()); } - TEST_F(GraphAnalysisTests, basic_toposort_test_5) { Graph graph; @@ -373,7 +371,6 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_5) { graph.addNode(g, {"c", "e"}); graph.addNode(h, {"d", "f"}); - // we just check that nodes were really added ASSERT_EQ(8, graph.size()); @@ -582,3 +579,182 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_6) { ASSERT_EQ(15, sequence.at(0).protoContext().nodeId()); } + +TEST_F(GraphAnalysisTests, basic_toposort_test_7) { + Graph graph; + + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::subtract(), "c"); + Node d(sd::ops::add(), "d"); + Node e(sd::ops::multiply(), "e"); + + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"a", "C"}); + + graph.addNode(c, {"a", "b"}); + graph.addNode(d, {"b", "c"}); + + graph.addNode(e, {"b", "c", "d"}); + + // we just check that nodes were really added + ASSERT_EQ(5, graph.size()); + + auto optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 3 layer + ASSERT_EQ(5, optimized.layers()); + + // checking first layer first + auto layer0 = optimized.layer(0); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer0.width()); + auto sequence = layer0[0]; + + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(4, sequence.at(0).protoContext().nodeId()); + + // checking second layer now + auto layer1 = optimized.layer(1); + + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(1, layer1.width()); + + sequence = layer1[0]; + + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + + // checking layer 2 + auto layer2 = optimized.layer(2); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer2.width()); + sequence = layer2[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); + + // checking layer 3 + auto layer3 = optimized.layer(3); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer3.width()); + sequence = layer3[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + + // checking layer 3 + auto layer4 = optimized.layer(4); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer4.width()); + sequence = layer4[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); +} + + +TEST_F(GraphAnalysisTests, basic_toposort_test_8) { + Graph graph; + + // A + graph.addVariable("A", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + + // B + graph.addVariable("B", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + + // C + graph.addVariable("C", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + + // D + graph.addVariable("D", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + + // E + graph.addVariable("E", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + + // F + graph.addVariable("F", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + + + Node a1(sd::ops::multiply(), "a1"); + Node a2(sd::ops::add(), "a2"); + Node a3(sd::ops::add(), "a3"); + + Node b1(sd::ops::subtract(), "b1"); + Node b2(sd::ops::add(), "b2"); + Node b3(sd::ops::multiply(), "b3"); + + graph.addNode(a1, { "A", "B" }); + graph.addNode(a2, { "C", "D" }); + graph.addNode(a3, { "E", "F" }); + + graph.addNode(b1, { "a1", "a2" }); + graph.addNode(b2, { "a1", "a2", "a3" }); + graph.addNode(b3, { "a2", "a3" }); + + // we just check that nodes were really added + ASSERT_EQ(6, graph.size()); + + auto optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 2 layer + ASSERT_EQ(2, optimized.layers()); + + // checking first layer first + auto layer0 = optimized.layer(0); + + // we expect layer has exactly 3 OpSequence + ASSERT_EQ(3, layer0.width()); + auto sequence = layer0[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + + sequence = layer0[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + + sequence = layer0[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); + + // checking second layer now + auto layer1 = optimized.layer(1); + + // we expect layer has exactly 3 OpSequences + ASSERT_EQ(3, layer1.width()); + + sequence = layer1[0]; + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(10, sequence.at(0).protoContext().nodeId()); + + sequence = layer1[1]; + + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(11, sequence.at(0).protoContext().nodeId()); + + sequence = layer1[2]; + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(12, sequence.at(0).protoContext().nodeId()); + +} From 9608a3fb9fd60a9953faa6c917cd13c5a66cbab5 Mon Sep 17 00:00:00 2001 From: Oleg Date: Tue, 7 Apr 2020 16:09:02 +0300 Subject: [PATCH 086/233] libnd4j minor build fixes Signed-off-by: Oleg --- libnd4j/include/graph/OptimizedGraph.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index 27e4b9d3fc4b..72f06b46b39e 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -108,7 +108,7 @@ namespace sd { * @param operation gather * @return stop iterating */ - bool topolSearch(const int startNode, const std::unordered_map& nodesConnections, std::vector>& opSeq) const; + bool topolSearch(const int startNode, std::unordered_map& nodesConnections, std::vector>& opSeq) const; /* * Optimized graph analysis prototyping, gather nodes infor * @param reference to node information collector @@ -154,7 +154,7 @@ namespace sd { void setInBranching(bool bValue) { bInBranching = bValue; } void setOutBranching(bool bValue) { bOutBranching = bValue; } - void setProcessed() { bProcessed = true; } + void setProcessed(bool bValue = true) { bProcessed = bValue; } void reset() { sConnections.clear(); bProcessed = bInBranching = bOutBranching = false; nLayer = 0; nSequence = -1; } From 5431b0ca4cb116a71903e43cc00c6a23013cdf96 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 8 Apr 2020 09:20:47 +0300 Subject: [PATCH 087/233] var fix Signed-off-by: raver119 --- libnd4j/include/ops/declarable/BroadcastableBoolOp.h | 2 +- libnd4j/include/system/op_boilerplate.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/ops/declarable/BroadcastableBoolOp.h b/libnd4j/include/ops/declarable/BroadcastableBoolOp.h index c48650294315..64e61069f741 100644 --- a/libnd4j/include/ops/declarable/BroadcastableBoolOp.h +++ b/libnd4j/include/ops/declarable/BroadcastableBoolOp.h @@ -28,7 +28,7 @@ namespace sd { namespace ops { - class ND4J_EXPORT BroadcastableBoolOp : public DeclarableCustomOp{ + class SD_EXPORT BroadcastableBoolOp : public DeclarableCustomOp{ protected: Nd4jStatus validateAndExecute(Context& block) override = 0; public: diff --git a/libnd4j/include/system/op_boilerplate.h b/libnd4j/include/system/op_boilerplate.h index 549b70878116..8d51420e48c3 100644 --- a/libnd4j/include/system/op_boilerplate.h +++ b/libnd4j/include/system/op_boilerplate.h @@ -1446,7 +1446,7 @@ };\ REGISTER_H(NAME) -#define DECLARE_BROADCASTABLE_BOOL_OP(NAME,TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::BroadcastableBoolOp { \ +#define DECLARE_BROADCASTABLE_BOOL_OP(NAME,TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::BroadcastableBoolOp { \ protected: \ void registerTypes(); \ Nd4jStatus validateAndExecute(Context& block); \ From 1b801254ff2a79ef013d9bce08e7b6f036201f27 Mon Sep 17 00:00:00 2001 From: Oleg Date: Wed, 8 Apr 2020 11:19:16 +0300 Subject: [PATCH 088/233] libnd4j fixed tests crashes in DeclarableOpsTests1 test case Signed-off-by: Oleg --- .../layers_tests/DeclarableOpsTests1.cpp | 172 +++++++++--------- 1 file changed, 81 insertions(+), 91 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 2d266b2ef4b4..0e0228702ca3 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -277,8 +277,6 @@ TEST_F(DeclarableOpsTests1, TestTensorDot2) { ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.equalsTo(out)); - - } TEST_F(DeclarableOpsTests1, TestTensorDot3) { @@ -297,8 +295,6 @@ TEST_F(DeclarableOpsTests1, TestTensorDot3) { ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.equalsTo(out)); - - } TEST_F(DeclarableOpsTests1, TestTensorDot4) { @@ -337,9 +333,6 @@ TEST_F(DeclarableOpsTests1, TestTensorDot5) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - - - } @@ -359,9 +352,6 @@ TEST_F(DeclarableOpsTests1, TestTensorDot6) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - - - } //////////////////////////////////////////////////////////////////// @@ -380,9 +370,6 @@ TEST_F(DeclarableOpsTests1, TestTensorDot7) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - - - } //////////////////////////////////////////////////////////////////// @@ -401,9 +388,6 @@ TEST_F(DeclarableOpsTests1, TestTensorDot8) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - - - } //////////////////////////////////////////////////////////////////// @@ -430,8 +414,6 @@ TEST_F(DeclarableOpsTests1, TestTensorDot9) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - - } @@ -452,8 +434,6 @@ TEST_F(DeclarableOpsTests1, TestTensorDot10) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - - } @@ -474,8 +454,6 @@ TEST_F(DeclarableOpsTests1, TestTensorDot11) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - - } //////////////////////////////////////////////////////////////////// @@ -495,8 +473,6 @@ TEST_F(DeclarableOpsTests1, TestTensorDot12) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - - } //////////////////////////////////////////////////////////////////// @@ -516,8 +492,6 @@ TEST_F(DeclarableOpsTests1, TestTensorDot13) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - - } //////////////////////////////////////////////////////////////////// @@ -537,8 +511,6 @@ TEST_F(DeclarableOpsTests1, TestTensorDot14) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - - } //////////////////////////////////////////////////////////////////// @@ -558,8 +530,6 @@ TEST_F(DeclarableOpsTests1, TestTensorDot15) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - - } //////////////////////////////////////////////////////////////////// @@ -579,7 +549,6 @@ TEST_F(DeclarableOpsTests1, TestTensorDot16) { ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.equalsTo(result)); - } //////////////////////////////////////////////////////////////////// @@ -619,13 +588,15 @@ TEST_F(DeclarableOpsTests1, AddMatrices1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + variableSpace->putVariable(1, x); + + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::add addOp; addOp.execute(block); - + ASSERT_TRUE(x.equalsTo(exp)); delete block; @@ -643,15 +614,17 @@ TEST_F(DeclarableOpsTests1, AddVectorVector1) { exp.assign(3); auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); - block->fillInputs({ -1, -2 }); + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); + block->fillInputs({ -1, -2 }); sd::ops::add addOp; - + addOp.execute(block); - + ASSERT_TRUE(x.equalsTo(exp)); delete block; @@ -671,7 +644,9 @@ TEST_F(DeclarableOpsTests1, AddMatrixScalar1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + variableSpace->putVariable(1, x); + + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::add addOp; @@ -697,7 +672,9 @@ TEST_F(DeclarableOpsTests1, AddScalarScalar1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::add addOp; @@ -723,7 +700,9 @@ TEST_F(DeclarableOpsTests1, SubtractMatrices1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::subtract subOp; @@ -749,7 +728,9 @@ TEST_F(DeclarableOpsTests1, SubtractTest_1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::subtract subOp; @@ -781,8 +762,6 @@ TEST_F(DeclarableOpsTests1, SubtractTest_2) { ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.at(0).equalsTo(&exp)); - - } TEST_F(DeclarableOpsTests1, TestRng1) { @@ -866,8 +845,8 @@ TEST_F(DeclarableOpsTests1, ClipByValue1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, std::make_shared()); - auto block = new Context(1, variableSpace, true); + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->appendT(0.0f); block->appendT(3.0f); @@ -932,7 +911,9 @@ TEST_F(DeclarableOpsTests1, SubtractVectorVector1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::subtract subOp; @@ -943,10 +924,8 @@ TEST_F(DeclarableOpsTests1, SubtractVectorVector1) { delete block; delete variableSpace; - } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractMatrixScalar1) { @@ -960,7 +939,9 @@ TEST_F(DeclarableOpsTests1, SubtractMatrixScalar1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::subtract subOp; @@ -987,7 +968,9 @@ TEST_F(DeclarableOpsTests1, SubtractScalarScalar1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::subtract subOp; @@ -1013,7 +996,9 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractMatrices1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::reversesubtract subOp; @@ -1042,8 +1027,6 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_1) { ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.at(0).equalsTo(&exp)); - - } ////////////////////////////////////////////////////////////////////// @@ -1067,8 +1050,6 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_2) { ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.at(0).equalsTo(&exp)); - - } ////////////////////////////////////////////////////////////////////// @@ -1089,8 +1070,6 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractTest_3) { auto res = subOp.evaluate({ &x, &y }); ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.at(0).equalsTo(&exp)); - - } ////////////////////////////////////////////////////////////////////// @@ -1117,8 +1096,6 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_1) { ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.at(0).equalsTo(&exp)); ASSERT_TRUE(exp.equalsTo(&z)); - - } ////////////////////////////////////////////////////////////////////// @@ -1143,8 +1120,6 @@ TEST_F(DeclarableOpsTests1, ReverseModTest_2) { ASSERT_TRUE(res.status() == ND4J_STATUS_OK); ASSERT_TRUE(res.at(0).equalsTo(&exp)); - - } ////////////////////////////////////////////////////////////////////// @@ -1160,7 +1135,9 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractVectorVector1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::reversesubtract subOp; @@ -1187,7 +1164,9 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractMatrixScalar1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::reversesubtract subOp; @@ -1214,7 +1193,9 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractScalarScalar1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::reversesubtract subOp; @@ -1240,7 +1221,9 @@ TEST_F(DeclarableOpsTests1, MultiplyMatrices1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::multiply mul; @@ -1266,7 +1249,9 @@ TEST_F(DeclarableOpsTests1, MultiplyVectorVector1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::multiply mul; @@ -1292,7 +1277,9 @@ TEST_F(DeclarableOpsTests1, MultiplyMatrixScalar) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::multiply mul; @@ -1318,7 +1305,9 @@ TEST_F(DeclarableOpsTests1, MultiplyScalarScalar1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::multiply mul; @@ -1389,8 +1378,6 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); ASSERT_TRUE(res.at(0).equalsTo(exp)); - - } ////////////////////////////////////////////////////////////////////// @@ -1408,8 +1395,6 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); ASSERT_TRUE(res.at(0).equalsTo(exp)); - - } ////////////////////////////////////////////////////////////////////// @@ -1424,8 +1409,6 @@ TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); ASSERT_TRUE(res.at(0).equalsTo(exp)); - - } ////////////////////////////////////////////////////////////////////// @@ -1450,8 +1433,6 @@ TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) { y.applyTrueBroadcast(BROADCAST(Divide), x, exp, true); ASSERT_TRUE(z.equalsTo(&exp)); - - } ////////////////////////////////////////////////////////////////////// @@ -1467,7 +1448,9 @@ TEST_F(DeclarableOpsTests1, DivideMatrices1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::divide div; @@ -1493,7 +1476,9 @@ TEST_F(DeclarableOpsTests1, DivideVectorVector1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::divide div; @@ -1519,7 +1504,9 @@ TEST_F(DeclarableOpsTests1, DivideMatrixScalar1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::divide div; @@ -1546,7 +1533,9 @@ TEST_F(DeclarableOpsTests1, DivideScalarScalar1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::divide div; @@ -1572,7 +1561,9 @@ TEST_F(DeclarableOpsTests1, ReverseDivideMatrices1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::reversedivide div; @@ -1598,7 +1589,9 @@ TEST_F(DeclarableOpsTests1, ReverseDivideVectorVector1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::reversedivide div; @@ -1624,7 +1617,9 @@ TEST_F(DeclarableOpsTests1, ReverseDivideMatrixScalar1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::reversedivide div; @@ -1650,7 +1645,9 @@ TEST_F(DeclarableOpsTests1, ReverseDivideScalarScalar1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, true); + + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::reversedivide div; @@ -1676,8 +1673,6 @@ TEST_F(DeclarableOpsTests1, Test_Cast_1) { auto z = result.at(0); ASSERT_TRUE(yExp.equalsTo(z)); - - } ////////////////////////////////////////////////////////////////////// @@ -2979,7 +2974,6 @@ TEST_F(DeclarableOpsTests1, Reverse_1) { ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); - } ////////////////////////////////////////////////////////////////////// @@ -3132,12 +3126,8 @@ TEST_F(DeclarableOpsTests1, Reverse_7) { ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); - - } - - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_8) { From a1743411983eaa2326fb6411e4bb859bd75da94e Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 8 Apr 2020 12:27:31 +0300 Subject: [PATCH 089/233] - RNG tests fixed - dynamic_partition_bp fixed - RandomGenerator got explicit copy/move stuff Signed-off-by: raver119 --- libnd4j/include/graph/RandomGenerator.h | 12 ++++ .../include/graph/impl/RandomGenerator.cpp | 57 +++++++++++++++++++ .../include/ops/declarable/LegacyRandomOp.h | 3 +- .../declarable/generic/random/multinomial.cpp | 2 +- .../generic/transforms/dynamic_parititon.cpp | 2 +- .../ops/declarable/impl/DeclarableOp.cpp | 8 +-- .../ops/declarable/impl/LegacyRandomOp.cpp | 17 +++--- libnd4j/tests_cpu/layers_tests/RNGTests.cpp | 23 ++++---- 8 files changed, 94 insertions(+), 30 deletions(-) create mode 100644 libnd4j/include/graph/impl/RandomGenerator.cpp diff --git a/libnd4j/include/graph/RandomGenerator.h b/libnd4j/include/graph/RandomGenerator.h index 755fce062910..fdd8f88010fc 100644 --- a/libnd4j/include/graph/RandomGenerator.h +++ b/libnd4j/include/graph/RandomGenerator.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -91,6 +92,17 @@ namespace sd { public: FORCEINLINE RandomGenerator(Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); + + RandomGenerator(const RandomGenerator& other) noexcept; + + RandomGenerator& operator=(const RandomGenerator& other) noexcept; + + // move constructor + RandomGenerator(RandomGenerator&& other) noexcept; + + // move assignment operator + RandomGenerator& operator=(RandomGenerator&& other) noexcept; + /** * This method allows to change graph-level state in runtime. * PLEASE NOTE: this method will change state of node as well. diff --git a/libnd4j/include/graph/impl/RandomGenerator.cpp b/libnd4j/include/graph/impl/RandomGenerator.cpp new file mode 100644 index 000000000000..07ceb77c5747 --- /dev/null +++ b/libnd4j/include/graph/impl/RandomGenerator.cpp @@ -0,0 +1,57 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { + namespace graph { + RandomGenerator::RandomGenerator(const RandomGenerator& other) noexcept { + _rootState = other._rootState; + _nodeState = other._nodeState; + } + + RandomGenerator& RandomGenerator::operator=(const RandomGenerator& other) noexcept { + if (this == &other) + return *this; + + _rootState = other._rootState; + _nodeState = other._nodeState; + + return *this; + } + + // move constructor + RandomGenerator::RandomGenerator(RandomGenerator&& other) noexcept { + _rootState = other._rootState; + _nodeState = other._nodeState; + } + + // move assignment operator + RandomGenerator& RandomGenerator::operator=(RandomGenerator&& other) noexcept { + if (this == &other) + return *this; + + _rootState = other._rootState; + _nodeState = other._nodeState; + + return *this; + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/LegacyRandomOp.h b/libnd4j/include/ops/declarable/LegacyRandomOp.h index 095a61e0d850..6dc7d8d72ee1 100644 --- a/libnd4j/include/ops/declarable/LegacyRandomOp.h +++ b/libnd4j/include/ops/declarable/LegacyRandomOp.h @@ -41,8 +41,7 @@ namespace sd { template Nd4jStatus validateAndExecute_(Context &block); - sd::ResultSet execute(sd::graph::RandomGenerator& rng, std::initializer_list inputs, std::initializer_list tArgs, std::initializer_list iArgs, bool isInplace = false); - sd::ResultSet execute(sd::graph::RandomGenerator& rng, std::vector& inputs, std::vector& tArgs, std::vector& iArgs, bool isInplace = false); + sd::ResultSet execute(sd::graph::RandomGenerator& rng, const std::vector& inputs, const std::vector& tArgs = {}, const std::vector& iArgs = {}, const std::vector& dArgs = {}, bool isInplace = false); Nd4jStatus execute(Context* block) override; diff --git a/libnd4j/include/ops/declarable/generic/random/multinomial.cpp b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp index 52f0b3070637..f3a5a20959de 100644 --- a/libnd4j/include/ops/declarable/generic/random/multinomial.cpp +++ b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp @@ -98,7 +98,7 @@ namespace sd { auto dimA = (0 == dimC) ? 1 : 0; nShape[dimA] = numOfSamples; - DataType nType = (argSize > 1) ? ( INT_ARG(1) >= 0 ? static_cast(INT_ARG(1)) : sd::DataType::INT64) : sd::DataType::INT64; + DataType nType = block.numD() ? D_ARG(0) : sd::DataType::INT64; return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(nType, input->ordering(), nShape)); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp b/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp index 8ab7ba1e053d..3100b9ad4168 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp @@ -109,7 +109,7 @@ namespace ops { } outputList[0] = OUTPUT_VARIABLE(0); outputList[1] = OUTPUT_VARIABLE(1); - NDArray originalIndices(*indices); //->ordering(), indices->shapeInfo(), indices->dataType()); + auto originalIndices = indices->dup(); //->ordering(), indices->shapeInfo(), indices->dataType()); originalIndices.linspace(0); ops::dynamic_partition op; auto res = op.evaluate({&originalIndices, indices}, {numPartition}); diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index db542b6e0bba..a2de81fc43cd 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -880,17 +880,17 @@ namespace sd { if (v == nullptr) continue; - auto var = std::make_shared(v); - var->markRemovable(false); + auto var = std::make_shared(*v, "", cnt); + in.push_back(cnt); variableSpace.putVariable(cnt--, var); } int et = 0; for (auto v: outputs) { - auto var = std::make_shared(v); - var->markRemovable(false); std::pair pair(1, et++); + auto var = std::make_shared(*v, "", pair.first, pair.second); + variableSpace.putVariable(pair, var); } diff --git a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp index b89b476c0998..380a7ae67579 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp @@ -342,7 +342,7 @@ namespace sd { } else if (DataTypeUtils::isZ(xType)) { auto zShapeArr = INPUT_VARIABLE(0); auto zShapeVector = zShapeArr->asVectorT(); - auto dtype = DataType::BFLOAT16; + auto dtype = block.numD() > 0 ? D_ARG(0) : sd::DataType::FLOAT32; newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', zShapeVector); return SHAPELIST(newShape); @@ -354,14 +354,8 @@ namespace sd { return DeclarableOp::execute(block); } - sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::initializer_list inputs, std::initializer_list tArgs, std::initializer_list iArgs, bool isInplace) { - std::vector ins(inputs); - std::vector tas(tArgs); - std::vector ias(iArgs); - return this->execute(rng, ins, tas, ias, isInplace); - } - sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, std::vector& inputs, std::vector& tArgs, std::vector& iArgs, bool isInplace) { + sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, const std::vector& inputs, const std::vector& tArgs, const std::vector& iArgs, const std::vector& dArgs, bool isInplace) { VariableSpace variableSpace; ResultSet arrayList; //ResultSet arrayList; @@ -375,8 +369,8 @@ namespace sd { if (v == nullptr) continue; - auto var = std::make_shared(v); - var->markRemovable(false); + auto var = std::make_shared(*v, "", cnt); + in.push_back(cnt); variableSpace.putVariable(cnt--, var); } @@ -394,6 +388,9 @@ namespace sd { for (int e = 0; e < iArgs.size(); e++) block.appendI(iArgs.at(e)); + for (int e = 0; e < dArgs.size(); e++) + block.appendD(dArgs.at(e)); + Nd4jStatus status = this->execute(&block); arrayList.setStatus(status); if (status != ND4J_STATUS_OK) diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 98aa4ce67289..efc50041d781 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -35,8 +35,6 @@ class RNGTests : public testing::Test { public: long _seed = 119L; - //sd::random::RandomBuffer *_rngA; - //sd::random::RandomBuffer *_rngB; sd::graph::RandomGenerator _rngA; sd::graph::RandomGenerator _rngB; @@ -575,7 +573,7 @@ TEST_F(RNGTests, Test_Uniform_2) { RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); sd::ops::LegacyRandomOp op(0); - auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}); + auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}, {sd::DataType::FLOAT32}); ASSERT_EQ(Status::OK(), result.status()); @@ -1043,14 +1041,15 @@ TEST_F(RNGTests, test_multinomial_1) { sd::ops::random_multinomial op; RandomGenerator rng(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64}, {}, {}, false) ); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0}, {}, {INT64}, false) ); + ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, sd::DataType::FLOAT32); NDArray expectedZ('c', { 3, 3 }, { 0., 0, 0, 0, 0, 0, 0, 0, 0 }, sd::DataType::INT64); - auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 }); + auto result = op.evaluate({ &probsZ, &samples }, { }, { 1 }, {}, {INT64}); auto outputZ = result.at(0); ASSERT_EQ(Status::OK(), result.status()); @@ -1068,7 +1067,7 @@ TEST_F(RNGTests, test_multinomial_2) { sd::ops::random_multinomial op; RandomGenerator rng(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false)); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0 }, {}, {INT64}, false)); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -1077,7 +1076,7 @@ TEST_F(RNGTests, test_multinomial_2) { NDArray output2('c', { 20, 3 }, sd::DataType::INT64); rng.setStates(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1, INT64 }, {}, {}, false)); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1 }, {}, {INT64}, false)); ASSERT_TRUE(expected2.isSameShape(output2)); ASSERT_TRUE(expected2.equalsTo(output2)); } @@ -1092,10 +1091,10 @@ TEST_F(RNGTests, test_multinomial_3) { sd::ops::random_multinomial op; - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, {}, false)); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0 }, {}, {INT64}, false)); rng.setStates(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false)); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0 }, {}, {INT64}, false)); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); } @@ -1109,10 +1108,10 @@ TEST_F(RNGTests, test_multinomial_4) { RandomGenerator rng(1234, 1234); sd::ops::random_multinomial op; - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 1, INT64 }, {}, {}, false)); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 1 }, {}, {INT64}, false)); rng.setStates(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1, INT64 }, {}, {}, false)); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, {INT64}, false)); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); } @@ -1210,7 +1209,7 @@ TEST_F(RNGTests, test_multinomial_6) { NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32); NDArray output('c', { batchValue, Samples }, sd::DataType::INT64); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false)); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0 }, {}, {INT64}, false)); NDArray counts('c', { ClassValue }, { 0., 0, 0, 0, 0 }, sd::DataType::DOUBLE); From 6e739e49f90593e4ab84034de2fcf9a15b5c5a70 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 8 Apr 2020 12:32:39 +0300 Subject: [PATCH 090/233] one fixed test for Oleg Signed-off-by: raver119 --- libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 0e0228702ca3..ab5d9f4be127 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -681,7 +681,9 @@ TEST_F(DeclarableOpsTests1, AddScalarScalar1) { addOp.execute(block); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; From 32418805e479c6dcfdd9b8faddf4393206f8cf26 Mon Sep 17 00:00:00 2001 From: Oleg Date: Wed, 8 Apr 2020 14:37:38 +0300 Subject: [PATCH 091/233] libnd4j replace in-place test parts by new semantics Signed-off-by: Oleg --- .../layers_tests/DeclarableOpsTests1.cpp | 205 ++++++------------ 1 file changed, 71 insertions(+), 134 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index ab5d9f4be127..12d6ddef9945 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -588,16 +588,16 @@ TEST_F(DeclarableOpsTests1, AddMatrices1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); - + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::add addOp; addOp.execute(block); - - ASSERT_TRUE(x.equalsTo(exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete block; delete variableSpace; @@ -617,15 +617,15 @@ TEST_F(DeclarableOpsTests1, AddVectorVector1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::add addOp; addOp.execute(block); - - ASSERT_TRUE(x.equalsTo(exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete block; delete variableSpace; @@ -644,7 +644,6 @@ TEST_F(DeclarableOpsTests1, AddMatrixScalar1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -652,8 +651,9 @@ TEST_F(DeclarableOpsTests1, AddMatrixScalar1) { sd::ops::add addOp; addOp.execute(block); - - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; @@ -673,7 +673,6 @@ TEST_F(DeclarableOpsTests1, AddScalarScalar1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -703,15 +702,15 @@ TEST_F(DeclarableOpsTests1, SubtractMatrices1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::subtract subOp; subOp.execute(block); - - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; @@ -731,7 +730,6 @@ TEST_F(DeclarableOpsTests1, SubtractTest_1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -739,8 +737,9 @@ TEST_F(DeclarableOpsTests1, SubtractTest_1) { subOp.execute(block); - ASSERT_TRUE(x.equalsTo(&exp)); - + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; @@ -756,7 +755,6 @@ TEST_F(DeclarableOpsTests1, SubtractTest_2) { y.assign(1); exp.assign(2); - sd::ops::subtract subOp; auto res = subOp.evaluate({ &x, &y }); @@ -847,7 +845,6 @@ TEST_F(DeclarableOpsTests1, ClipByValue1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->appendT(0.0f); @@ -859,7 +856,9 @@ TEST_F(DeclarableOpsTests1, ClipByValue1) { clip.execute(block); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; @@ -883,7 +882,6 @@ TEST_F(DeclarableOpsTests1, MergeAvgTest1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); variableSpace->putVariable(-3, z); - variableSpace->putVariable(1, NDArrayFactory::create('c', { 5, 5 })); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2, -3 }); @@ -914,15 +912,16 @@ TEST_F(DeclarableOpsTests1, SubtractVectorVector1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::subtract subOp; subOp.execute(block); - - ASSERT_TRUE(x.equalsTo(&exp)); + + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete block; delete variableSpace; @@ -942,7 +941,6 @@ TEST_F(DeclarableOpsTests1, SubtractMatrixScalar1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -950,7 +948,9 @@ TEST_F(DeclarableOpsTests1, SubtractMatrixScalar1) { subOp.execute(block); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete block; delete variableSpace; @@ -971,7 +971,6 @@ TEST_F(DeclarableOpsTests1, SubtractScalarScalar1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -979,7 +978,9 @@ TEST_F(DeclarableOpsTests1, SubtractScalarScalar1) { subOp.execute(block); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete block; delete variableSpace; @@ -999,7 +1000,6 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractMatrices1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -1007,7 +1007,9 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractMatrices1) { subOp.execute(block); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; @@ -1138,7 +1140,6 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractVectorVector1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -1146,7 +1147,9 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractVectorVector1) { subOp.execute(block); - ASSERT_TRUE(x.equalsTo(exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; @@ -1167,7 +1170,6 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractMatrixScalar1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -1175,7 +1177,9 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractMatrixScalar1) { subOp.execute(block); - ASSERT_TRUE(x.equalsTo(exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; @@ -1196,7 +1200,6 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractScalarScalar1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -1204,7 +1207,9 @@ TEST_F(DeclarableOpsTests1, ReverseSubtractScalarScalar1) { subOp.execute(block); - ASSERT_TRUE(x.equalsTo(exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; @@ -1224,7 +1229,6 @@ TEST_F(DeclarableOpsTests1, MultiplyMatrices1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -1232,7 +1236,9 @@ TEST_F(DeclarableOpsTests1, MultiplyMatrices1) { mul.execute(block); - ASSERT_TRUE(x.equalsTo(exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; @@ -1252,15 +1258,16 @@ TEST_F(DeclarableOpsTests1, MultiplyVectorVector1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); sd::ops::multiply mul; mul.execute(block); - - ASSERT_TRUE(x.equalsTo(exp)); + + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; @@ -1280,7 +1287,6 @@ TEST_F(DeclarableOpsTests1, MultiplyMatrixScalar) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -1288,7 +1294,9 @@ TEST_F(DeclarableOpsTests1, MultiplyMatrixScalar) { mul.execute(block); - ASSERT_TRUE(x.equalsTo(exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; @@ -1308,7 +1316,6 @@ TEST_F(DeclarableOpsTests1, MultiplyScalarScalar1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -1316,7 +1323,9 @@ TEST_F(DeclarableOpsTests1, MultiplyScalarScalar1) { mul.execute(block); - ASSERT_TRUE(x.equalsTo(exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete block; delete variableSpace; @@ -1479,7 +1488,6 @@ TEST_F(DeclarableOpsTests1, DivideVectorVector1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -1487,7 +1495,9 @@ TEST_F(DeclarableOpsTests1, DivideVectorVector1) { div.execute(block); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; @@ -1507,7 +1517,6 @@ TEST_F(DeclarableOpsTests1, DivideMatrixScalar1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -1515,7 +1524,9 @@ TEST_F(DeclarableOpsTests1, DivideMatrixScalar1) { div.execute(block); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete block; delete variableSpace; @@ -1564,7 +1575,6 @@ TEST_F(DeclarableOpsTests1, ReverseDivideMatrices1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -1572,7 +1582,9 @@ TEST_F(DeclarableOpsTests1, ReverseDivideMatrices1) { div.execute(block); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; @@ -1592,7 +1604,6 @@ TEST_F(DeclarableOpsTests1, ReverseDivideVectorVector1) { variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -1600,7 +1611,9 @@ TEST_F(DeclarableOpsTests1, ReverseDivideVectorVector1) { div.execute(block); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; @@ -1619,8 +1632,7 @@ TEST_F(DeclarableOpsTests1, ReverseDivideMatrixScalar1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); variableSpace->putVariable(-2, y); - - variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); block->fillInputs({ -1, -2 }); @@ -1628,7 +1640,9 @@ TEST_F(DeclarableOpsTests1, ReverseDivideMatrixScalar1) { div.execute(block); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); delete variableSpace; delete block; @@ -1815,8 +1829,7 @@ TEST_F(DeclarableOpsTests1, Transpose1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, std::make_shared()); - + auto block = new Context(1, variableSpace, false); // not-in-place block->fillInputs({ -1 }); sd::ops::transpose transpose; @@ -1850,7 +1863,6 @@ TEST_F(DeclarableOpsTests1, Permute1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, std::make_shared()); auto block = new Context(1, variableSpace, false); // not-in-place block->fillInputs({ -1 }); @@ -1882,7 +1894,6 @@ TEST_F(DeclarableOpsTests1, TestArgumentsValidation1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, std::make_shared()); auto block = new Context(1, variableSpace, false); // not-in-place block->fillInputs({ -1 }); @@ -2037,7 +2048,6 @@ TEST_F(DeclarableOpsTests1, Pnormpool2d1) { auto variableSpace = new VariableSpace(); variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); auto block = new Context(1, variableSpace, false); block->fillInputs({ -1 }); @@ -2284,8 +2294,6 @@ TEST_F(DeclarableOpsTests1, sru_bp) { ASSERT_TRUE(expGradW.equalsTo(gradW)); ASSERT_TRUE(expGradB.equalsTo(gradB)); ASSERT_TRUE(expGradInit.equalsTo(gradInit)); - - } ////////////////////////////////////////////////////////////////// @@ -2393,8 +2401,6 @@ TEST_F(DeclarableOpsTests1, ArgMax1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } @@ -2415,7 +2421,6 @@ TEST_F(DeclarableOpsTests1, ArgMax2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - } @@ -2436,8 +2441,6 @@ TEST_F(DeclarableOpsTests1, ArgMax3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } TEST_F(DeclarableOpsTests1, ArgMax4) { @@ -2479,8 +2482,6 @@ TEST_F(DeclarableOpsTests1, ArgMax5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } TEST_F(DeclarableOpsTests1, ArgMax6) { @@ -2540,8 +2541,6 @@ TEST_F(DeclarableOpsTests1, SquareTests1) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - - } TEST_F(DeclarableOpsTests1, OneHotTests_1) { @@ -2560,8 +2559,6 @@ TEST_F(DeclarableOpsTests1, OneHotTests_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } TEST_F(DeclarableOpsTests1, OneHotTests_2) { @@ -2579,8 +2576,6 @@ TEST_F(DeclarableOpsTests1, OneHotTests_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } TEST_F(DeclarableOpsTests1, OneHotTests_3) { @@ -2618,8 +2613,6 @@ TEST_F(DeclarableOpsTests1, OneHotTests_4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } TEST_F(DeclarableOpsTests1, OneHotTests_5) { @@ -2639,8 +2632,6 @@ TEST_F(DeclarableOpsTests1, OneHotTests_5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - - } TEST_F(DeclarableOpsTests1, OneHotTests_6) { @@ -2652,8 +2643,6 @@ TEST_F(DeclarableOpsTests1, OneHotTests_6) { auto z = result.at(0); ASSERT_EQ(e, z); - - } TEST_F(DeclarableOpsTests1, OneHotTests_7) { @@ -2665,8 +2654,6 @@ TEST_F(DeclarableOpsTests1, OneHotTests_7) { auto z = result.at(0); ASSERT_EQ(e, z); - - } TEST_F(DeclarableOpsTests1, FillAs_1) { @@ -2683,8 +2670,6 @@ TEST_F(DeclarableOpsTests1, FillAs_1) { ASSERT_TRUE(x.isSameShape(result.at(0))); ASSERT_NEAR(scalar, result.at(0).meanNumber().e(0), 1e-5f); - - } ////////////////////////////////////////////////////////////////////// @@ -2709,8 +2694,6 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_1) { // array->printIndexedBuffer("Range integer 1"); ASSERT_TRUE(exp.isSameShape(array)); ASSERT_TRUE(exp.equalsTo(array)); - - } @@ -2736,8 +2719,6 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_2) { ASSERT_TRUE(exp.isSameShape(array)); ASSERT_TRUE(exp.equalsTo(array)); - - } @@ -2756,8 +2737,6 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_3) { ASSERT_TRUE(exp.isSameShape(array)); ASSERT_TRUE(exp.equalsTo(array)); - - } ////////////////////////////////////////////////////////////////////// @@ -2774,8 +2753,6 @@ TEST_F(DeclarableOpsTests1, softmax_test1) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - - } ////////////////////////////////////////////////////////////////////// @@ -2790,8 +2767,6 @@ TEST_F(DeclarableOpsTests1, softmax_test2) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - - } ////////////////////////////////////////////////////////////////////// @@ -2806,8 +2781,6 @@ TEST_F(DeclarableOpsTests1, softmax_test3) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - - } ////////////////////////////////////////////////////////////////////// @@ -2822,8 +2795,6 @@ TEST_F(DeclarableOpsTests1, softmax_test4) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - - } ////////////////////////////////////////////////////////////////////// @@ -2838,8 +2809,6 @@ TEST_F(DeclarableOpsTests1, softmax_test5) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - - } ////////////////////////////////////////////////////////////////////// @@ -2854,8 +2823,6 @@ TEST_F(DeclarableOpsTests1, softmax_test6) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - - } ////////////////////////////////////////////////////////////////////// @@ -2870,8 +2837,6 @@ TEST_F(DeclarableOpsTests1, softmax_test7) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - - } ////////////////////////////////////////////////////////////////////// @@ -2886,8 +2851,6 @@ TEST_F(DeclarableOpsTests1, softmax_test8) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - - } ////////////////////////////////////////////////////////////////////// @@ -2902,8 +2865,6 @@ TEST_F(DeclarableOpsTests1, softmax_test9) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - - } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test10) { @@ -2917,8 +2878,6 @@ TEST_F(DeclarableOpsTests1, softmax_test10) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - - } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test11) { @@ -2932,8 +2891,6 @@ TEST_F(DeclarableOpsTests1, softmax_test11) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - - } ////////////////////////////////////////////////////////////////////// @@ -2951,8 +2908,6 @@ TEST_F(DeclarableOpsTests1, softmax_test12) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - - } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_1) { @@ -2975,7 +2930,6 @@ TEST_F(DeclarableOpsTests1, Reverse_1) { ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); - } ////////////////////////////////////////////////////////////////////// @@ -3100,8 +3054,6 @@ TEST_F(DeclarableOpsTests1, Reverse_6) { ASSERT_TRUE(expected.isSameShapeStrict(input)); ASSERT_TRUE(expected.equalsTo(&input)); - - } @@ -3152,8 +3104,6 @@ TEST_F(DeclarableOpsTests1, Reverse_8) { ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); - - } //////////////////////////////////////////////////////////////////// @@ -3177,8 +3127,6 @@ TEST_F(DeclarableOpsTests1, Reverse_9) { ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); - - } TEST_F(DeclarableOpsTests1, Reverse_10) { @@ -3193,8 +3141,6 @@ TEST_F(DeclarableOpsTests1, Reverse_10) { ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(z)); - - } ////////////////////////////////////////////////////////////////////// @@ -3216,8 +3162,6 @@ TEST_F(DeclarableOpsTests1, Reverse_11) { ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); - - } ////////////////////////////////////////////////////////////////////// @@ -3238,8 +3182,6 @@ TEST_F(DeclarableOpsTests1, Reverse_12) { //expected.printIndexedBuffer("Expected reverse"); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); - - } ////////////////////////////////////////////////////////////////////// @@ -3260,7 +3202,6 @@ TEST_F(DeclarableOpsTests1, Reverse_13) { ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); - } ////////////////////////////////////////////////////////////////////// @@ -3280,8 +3221,6 @@ TEST_F(DeclarableOpsTests1, Reverse_14) { ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); - - } TEST_F(DeclarableOpsTests1, Test_Expose_1) { @@ -3299,8 +3238,6 @@ TEST_F(DeclarableOpsTests1, Test_Expose_1) { ASSERT_TRUE(input0.equalsTo(z0)); ASSERT_TRUE(input1.equalsTo(z1)); - - } TEST_F(DeclarableOpsTests1, Test_Expose_2) { From cf6f214fc082c4af8362f7885cee342102c3ae67 Mon Sep 17 00:00:00 2001 From: Oleg Date: Wed, 8 Apr 2020 17:22:15 +0300 Subject: [PATCH 092/233] libnd4j fixed support of some graphs, added test case Signed-off-by: Oleg --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 24 ++-- .../layers_tests/GraphAnalysisTests.cpp | 136 ++++++++++++++++++ 2 files changed, 149 insertions(+), 11 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index b7aaefacd77b..1be0ed8122fd 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -158,8 +158,7 @@ namespace sd { if (itChild != collector.end()) { // if the child is in-branching node it will be treated as start node - // skip processed nodes - if (itChild->second.isInBranching() || itChild->second.isProcessed()) { + if (itChild->second.isInBranching() || itChild->second.isProcessed()) { // proceed continue; } // put operation to OpSequence container @@ -206,26 +205,23 @@ namespace sd { std::vector> vOpSeq; if(!initOpSeqContainer(layersMaxSeq, vOpSeq)) throw std::runtime_error("OptimizedGraph::initOpSeqContainer() - cannot initialize OpSequence, not all nodes properly prototyped!"); - + + // combine start nodes and in-branching nodes + startNodes.insert(inBranching.begin(), inBranching.end()); // re-init proceed NodeInfo member to avoid append sequence several times for(auto& it : collector){ it.second.setProcessed(false); } - // combine start nodes and in-branching nodes - startNodes.insert(inBranching.begin(), inBranching.end()); - // iterate via start and in-branching nodes for (const auto& id : startNodes) { const auto it = originalGraph().unmappedNodes().find(id); auto& nodeInfo = collector[id]; - // check to avoid node processing twice if(!nodeInfo.isProcessed()){ vOpSeq[nodeInfo.layer()][nodeInfo.sequence()].append(it->second.customOp(), it->second.contextPrototype()); nodeInfo.setProcessed(); - } - + } // search in depth via connections of "start" node if(!topolSearch(id, collector, vOpSeq)) throw std::runtime_error("OptimizedGraph::topolSearch() - cannot run topological search, inputs incorrect!"); @@ -266,10 +262,11 @@ namespace sd { if(layerFound == layersMaxSeq.end()){ // if layer was not treated before, create pair for it layersMaxSeq[layer] = 0; + // set sequence value to 0, as this is first sequence in layer + startSeq = 0; } else{ // if node sequence position was not checked use it for max sequence selection - // double check if input sequence do not jump max value twice if(startSeq > (layerFound->second + 1)) startSeq = layerFound->second + 1; @@ -292,7 +289,12 @@ namespace sd { layer++; // childs sequence position have to start from max defined sequence position in layer - int seq = layersMaxSeq[layer]; + // or if it is first node in layer from 0 + int seq = (layersMaxSeq.find(layer) == layersMaxSeq.end()) ? 0 : layersMaxSeq[layer]; + // if parent is out-branching node sequence have to be increment + // on the next stage the sequence value will be double checked with max per layer + seq = parent->second.isOutBranching() ? seq + 1 : seq; + // loop via childs (connected nodes) for (const auto& id : parent->second.connections()) { // double check to avoid unstable behavior diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 3f0370ed140e..da4e83eb9e57 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -758,3 +758,139 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_8) { ASSERT_EQ(12, sequence.at(0).protoContext().nodeId()); } + +TEST_F(GraphAnalysisTests, basic_toposort_test_9) { + + // start graph + + Graph graph; + + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + + // D + graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); + + Node a(sd::ops::multiply(), "a"); + + Node b1(sd::ops::add(), "b1"); + Node b2(sd::ops::multiply(), "b2"); + Node b3(sd::ops::subtract(), "b3"); + Node b4(sd::ops::Pow(), "b4"); + + Node c1(sd::ops::Pow(), "c1"); + Node c2(sd::ops::subtract(), "c2"); + Node c3(sd::ops::multiply(), "c3"); + Node c4(sd::ops::add(), "c4"); + + Node c5(sd::ops::Pow(), "c5"); + Node c6(sd::ops::subtract(), "c6"); + Node c7(sd::ops::multiply(), "c7"); + Node c8(sd::ops::add(), "c8"); + + graph.addNode(a, {"A", "B"}); + + graph.addNode(b1, {"a", "C"}); + graph.addNode(b2, {"a", "C"}); + graph.addNode(b3, {"a", "C"}); + graph.addNode(b4, {"a", "C"}); + + graph.addNode(c1, {"b1", "D"}); + graph.addNode(c2, {"b2", "D"}); + graph.addNode(c3, {"b3", "D"}); + graph.addNode(c4, {"b4", "D"}); + + graph.addNode(c5, {"b1", "D"}); + graph.addNode(c6, {"b2", "D"}); + graph.addNode(c7, {"b3", "D"}); + graph.addNode(c8, {"b4", "D"}); + + // we just check that nodes were really added + ASSERT_EQ(13, graph.size()); + + auto optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(3, optimized.layers()); + + auto layer = optimized.layer(0); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer.width()); + auto sequence = layer[0]; + + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + + auto layer1 = optimized.layer(1); + // we expect layer has exactly 4 OpSequence + ASSERT_EQ(4, layer1.width()); + sequence = layer1[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); + + sequence = layer1[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + + sequence = layer1[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + + sequence = layer1[3]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); + + auto layer2 = optimized.layer(2); + // we expect layer has exactly 4 OpSequence + ASSERT_EQ(8, layer2.width()); + sequence = layer2[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(10, sequence.at(0).protoContext().nodeId()); + + sequence = layer2[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(14, sequence.at(0).protoContext().nodeId()); + + sequence = layer2[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(11, sequence.at(0).protoContext().nodeId()); + + sequence = layer2[3]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(15, sequence.at(0).protoContext().nodeId()); + + sequence = layer2[4]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(12, sequence.at(0).protoContext().nodeId()); + + sequence = layer2[5]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(16, sequence.at(0).protoContext().nodeId()); + + sequence = layer2[6]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(13, sequence.at(0).protoContext().nodeId()); + + sequence = layer2[7]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(17, sequence.at(0).protoContext().nodeId()); +} From b76394715e1b851cb8719c9d55437d7c67bb8092 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 9 Apr 2020 10:48:04 +0300 Subject: [PATCH 093/233] two special graphs for Oleg Signed-off-by: raver119 --- .../layers_tests/GraphAnalysisTests.cpp | 8 ++++++++ libnd4j/tests_cpu/resources/cond_false.fb | Bin 0 -> 4384 bytes libnd4j/tests_cpu/resources/cond_true.fb | Bin 4088 -> 4384 bytes 3 files changed, 8 insertions(+) create mode 100644 libnd4j/tests_cpu/resources/cond_false.fb diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index da4e83eb9e57..db3c22a962ac 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -894,3 +894,11 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_9) { ASSERT_EQ(1, sequence.length()); ASSERT_EQ(17, sequence.at(0).protoContext().nodeId()); } + +TEST_F(GraphAnalysisTests, test_cond_1) { + auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); +} + +TEST_F(GraphAnalysisTests, test_cond_2) { + auto graph = Graph::fromFlatBuffers("resources/cond_false.fb"); +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/resources/cond_false.fb b/libnd4j/tests_cpu/resources/cond_false.fb new file mode 100644 index 0000000000000000000000000000000000000000..65629204d743620ebafff49ac0756e915584aac3 GIT binary patch literal 4384 zcmb`LPiS0K6vpqgNhfW`Hl`EEVvtz`2^dI}EQ&-Ou}BxO7MewoKs#h6c3?6Yk|#+^ z7o%j6MHWFr$)bxaLMaj?Xptgah-4v>MUX696c-|{MBHe#R>$9W-@Vf})0vVcIq=PW z_uTW&J^$``?@h{h>M}2I*=9OS!gQLX=`mf*$G|XH04dN9dVw(~km|5kM@*cV=y0s+ z$WHk=0@jZiQ$}=AlQv^|4uv`sOw*mLoit_>1$`iR&X_a6fNca;!DY}7{yGjHdhfq# zm7!bHo8Ug!0pEgqU<0fHtt|nWRAd|hr$IkpTQdh<+c~~<>XQ#o#<<2ba}fQF{8GY;>10yrxe!qvP-|Ue282On~aqu=jJSTSn8%+Y)(@+kpvy!5xn2Q z&u(p-QF&cwTm;Hx5p)8@&gbH6i;y{nwt|S=Gog_yKH#ufYab11n$&%md{($0$3}Kt7d$dX*=|B8GCc zm+%-GBcAG5cGTj1!Jg4qUxuf1VDlaFn1$4e@7JTF_E2AtFS;&cLrzXY+P95s`84L{ zgMS}_v2o3b{D-y-NyBrS>^9KxCiclHK6)aPpN=8puVCy9AT6nX`nj12ox(% zvZdW&GC%#kY!y$npz~|%?H#wkT-1smuGarbR~)F;dYB}{fYjrnb+$rWGVq2#KZx3# zU@m(+ITe0d8ylY!zdD`!$)oo5bvFFyVC(D~dW>T89iwvaFzhqk;#I&VP>yedn_w9% zfC>nJ@_U|9{-r=CxD3Ad_(3o`@zSc|B)zJ|Vxe?tM$Zne#g|{T9^;=Y@H{v z`c&J>>1igtK)Du)w+_aD_N?GzhVdZtyo1hRbnZ3JvdiZ<&3l^8Orw20@S=4n_rz;P ziuv0k@4h+8sI!{O2HDZ8Wrf~+Nw3m=f08@JMv*&4WE4N05xs+)ckdw2*cZB`<2twt zHozKK0gIpl=0FjsZscRMZZ0zS?;r9_knPRhA5~A3)8N2#KAYp(Q!Bn-kB-u>IBDE+ z3-Vuet@zx=#xsnKc}T)bI2i^@($R_+e6>_x;!**{XA? zn19Zi?5Ult$ouyR)`>=){7=ZM*EY&4Mz1@0*`5N@udg8aro1T^ic9U@$a}=KCVQQu zx9Oh7cZpk8w!O}-YshMk#^=EIO=NE&tA19qqth`TIS)hb4D`##$kgI~B`5{e-P*O~;vcmuvs$wcO;FnYJ`ty{xL3We)%Q!1iuRSEudq0sR(ZnhXO3Us{EOPD zeAmSJq^~%FSNyS{9XaF9W_<`M-loJw>?ZU-Y%cJI=he- V`L(EBdTw;!g|AcHr=BeN{{e?OWBmXC literal 0 HcmV?d00001 diff --git a/libnd4j/tests_cpu/resources/cond_true.fb b/libnd4j/tests_cpu/resources/cond_true.fb index 003f7868a2d6be3d5dc663dd64b66862b69470c3..75a354c141920bc978fe1aa58f5ec4e40e4e02eb 100644 GIT binary patch literal 4384 zcmb`LL17{_PZq?@*D8`BNsu*e>Q1Pmlf4n?A_Sfqzo3(cWOpk1;XyRg|T$&;j| zhgEXOA%~!$;yIayE1Ha6h zneTn`&G&up+xL<(zIB-wxNI{WCSf{F()5@v#$zxH3y^|-=mld=Ak|^7j+i(jwZpNh zYj((=Bd~tVm@=Y^8nhYHb12lAV3_V??W8fADCmRWIb+U%fo%j<;WG5YU&ooJz4u?W zO4qI7O}Gy`@Gabf4Oj!MErCob(vQGt=m)knbMV^E@vT#ze0Va(HKv)PFg{hv4NR5t zGhAz4=%j^Qz7!OKt6bVS$;(tLKQmhh#;;tO(4uQkTk|r}Ui;`w>hHP*TxDaRy~=-7 zs~P$s&>03H4M|9VWPFp3W=;AFrE)Plo_{M_ESgQHQ@Z*tU{kAq}$gd9d*sDU>dh@xnmqgGuI;>-F<9!Bp9pQ%&Zx z(}9f&YgZxGUugaqs*FYcpLev_`f zpWEp;dZTK2CU=3Da3FbjD&e6+`H0VraP_6P!zKEe*eZfAXufEKjo&y{2aE?*fTFv|QXsUoN zeeiLbcRJ;ty~K&ztgZg}-9g`e;x1oIOlK$a=YZ?phvAv`W1n=Er}MeOcs7vE&!}07 z^>6=FtD(Ce*P64ts5PoKuGY*de%Ik1+<`Sv{;8%^U=CDQl!qhqI!79mhZjM<@=dn1 zJ0|nf@5@%rDHrto+I)M*EihJVH6O0l|CO#dDA#%zBs4)&jf>{l3Vz8jHw6669 z_W0&h_+hPY{GIsK>D*5owXe^!;YSCXXWvj`2#z1>kIAezXAmhA)&S7-!HO{ij$2iSzfU^<`8aqU~HdA}a*rC)whzvUKm ze&x0Na~m7a&^N{*$y~z8(Ag#(ty*JENb)up{a(=fopc^l{@if&{YU%0A3G#l^<2v5 zpR*==YR^{W{rd##c%x4KC*)OY8|CGr*PXmrFL)RJ>puEz0OhF zbWh{E#4Rh^UT4=eWVJ`*@4)v>WN#s>dRDWeR%^QOyDV1L%13egKHKPnU$|cqjk(op z{o5W}Vr?1vfbH$GrE@EXYI{cfJa21kdo1l_UF8tzBEr|53X#t2O)3L`&PhPx$F8?p5t;_5G40qkU%RD=bc>Rh&5XGy5-a=Zo5@ zc-Q$aE!~gf8};7=_p2_o`liw;dK#?NxzHBN@%IYIcGOh$KZ@{Naa+SAn7?n*>l6^U9TcKZA7dw2RWzD7;v!1vC(=iYbj zx#!<|Z^Qyh*(GlJEonoRvNSNufootLjDT@q_8hZGcX!qX8Hvs$+9A0imjUq0IkVlf zW}E$H>ApbMdE|4?n;r6C6a4h7*$!xdc`yp@4w;RCEiHyN|7f>i7YDhtz#iBE+u#n^ z0vkYfHi1l;hfji0kOHs$cz*xF=bx~%St5i^Q7RNG%~JDYFUM5F1Y4+93RfCyrRKu& ztwl|SbljX@Bd$GUuXsuq=nU8aH-YxK_j|jYg-?MAAbT?)4LH7A19KFr@WT$5i%OMR zIloYRGhZ%S8ee6RNtfa!z4AevPQ#>y+-|s- zzWOpU`Od{TkTDAAX>vdBx%Sk4ijD4T+!VqEe?I+Ej z?pSa4z&_fDE!OiZ#p{6k`v>i|YU$^YkK3Hr1jrvVI^RxV(BdoM0?^nmm{d zu3qnd$tU}(_xY3yb)Gc-I)!}T`^Pr=dYMlLJ+QAkpDt6EisMzdV)XFP^`*3AE$}_~ z4yZ4E4Yt4rSO>}%<;gr;J{tuoARl=0F-;;Umnw}K&02FlUpl>Mu2i|vgM6CP}6-FruV_l>d7yP_Ie?o$Dr$%{k7Vh3hQJ`w5^krGffw^jr*M zNj@x;X!ea#wQ?Jug*vB@OP_RSK%7p@i6!;qTE5Z9E!V57OUriH13MMp#&W)&pm&z>N&NZ57wN1**Q}Zy^hDo-8O3q2UR{CN zWok+_aTT6~zYf>@xXu5czIsviw!j|P0oy?D@GYeOBvRbVui}bjg*AF^I-JjNul$iMGp&C@_2K>R>&K5R!zAR# z3&GG`m80&xTCEiIaeCkFch1?9W9;pDggsijADNdu0yGfvU=()Gxcy#Wh1<{ftx4IS|K-32S<$hfPxayV*9z6LwE24L__FiQGHfp% Vqf7rpI*Z=$NwoPo>*$H}#K(F09321v From 31557eaeeec6dc5318f6de23a014a99e781b0f91 Mon Sep 17 00:00:00 2001 From: Oleg Date: Thu, 9 Apr 2020 14:05:14 +0300 Subject: [PATCH 094/233] libnd4j: added one more test for directed graph testing, minor corrections Signed-off-by: Oleg --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 3 + .../layers_tests/GraphAnalysisTests.cpp | 67 ++++++++++++++++++- 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 1be0ed8122fd..35b6e77a99c6 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -111,6 +111,8 @@ namespace sd { int inExCounts = 0, inInternalCounts = 0; for (const auto& in : inputs) { + + // todo have to be correctly defined what input is external and what internal if (originalGraph().variableSpace().hasVariable(in.first, 0)) { // count external inputs, all inputs that are in // varable space will be treated as external inputs @@ -218,6 +220,7 @@ namespace sd { const auto it = originalGraph().unmappedNodes().find(id); auto& nodeInfo = collector[id]; + // append only not processed nodes if(!nodeInfo.isProcessed()){ vOpSeq[nodeInfo.layer()][nodeInfo.sequence()].append(it->second.customOp(), it->second.contextPrototype()); nodeInfo.setProcessed(); diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index db3c22a962ac..d1aa4ca335e1 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -895,10 +895,75 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_9) { ASSERT_EQ(17, sequence.at(0).protoContext().nodeId()); } +TEST_F(GraphAnalysisTests, basic_toposort_test_10) { + Graph graph; + + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + + // D + graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); + + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::multiply(), "c"); + Node d(sd::ops::subtract(), "d"); + + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"a", "C"}); + graph.addNode(c, {"a", "D"}); + graph.addNode(d, {"a", "b", "c"}); + + // we just check that nodes were really added + ASSERT_EQ(4, graph.size()); + + auto optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(3, optimized.layers()); + + auto layer = optimized.layer(0); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer.width()); + auto sequence = layer[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + + auto layer1 = optimized.layer(1); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer1.width()); + sequence = layer1[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); + sequence = layer1[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + + auto layer2 = optimized.layer(2); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer2.width()); + sequence = layer2[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); +} + TEST_F(GraphAnalysisTests, test_cond_1) { auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); } TEST_F(GraphAnalysisTests, test_cond_2) { auto graph = Graph::fromFlatBuffers("resources/cond_false.fb"); -} \ No newline at end of file +} From 8ca8440d0b9d86ede69ddef74370d360f1ae91c7 Mon Sep 17 00:00:00 2001 From: Oleg Date: Thu, 9 Apr 2020 16:14:33 +0300 Subject: [PATCH 095/233] libnd4j fixed special case of graph, test added Signed-off-by: Oleg --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 14 ++++- .../layers_tests/GraphAnalysisTests.cpp | 57 +++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 35b6e77a99c6..5a5995f00170 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -198,8 +198,18 @@ namespace sd { int startSeq = 0; for (const auto& id : startNodes) { layersMaxSeq[0] = startSeq; - if(!layersSeqDefine(collector, id, 0, startSeq, layersMaxSeq)) - throw std::runtime_error("OptimizedGraph::layersSeqDefine() - not all nodes properly prototyped!"); + // if only start nodes exists they have to be add to connections + if (bOnlyStartNodes) { + auto node = NodeInfo(); + node.setLayer(0); + node.setProcessed(true); + node.setSequence(startSeq); + collector[id] = node; + } + else { + if (!layersSeqDefine(collector, id, 0, startSeq, layersMaxSeq)) + throw std::runtime_error("OptimizedGraph::layersSeqDefine() - not all nodes properly prototyped!"); + } startSeq++; } diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index d1aa4ca335e1..ebf67522f43c 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -960,8 +960,65 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_10) { ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); } +TEST_F(GraphAnalysisTests, basic_toposort_test_11) { + Graph graph; + + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + + // D + graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); + + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::multiply(), "c"); + Node d(sd::ops::subtract(), "d"); + + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"A", "C"}); + graph.addNode(c, {"B", "D"}); + graph.addNode(d, {"C", "D"}); + + // we just check that nodes were really added + ASSERT_EQ(4, graph.size()); + + auto optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(1, optimized.layers()); + + auto layer = optimized.layer(0); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(4, layer.width()); + auto sequence = layer[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + sequence = layer[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); + sequence = layer[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + sequence = layer[3]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); +} + TEST_F(GraphAnalysisTests, test_cond_1) { auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); + + auto optimized = graph.optimizedGraph(); } TEST_F(GraphAnalysisTests, test_cond_2) { From 09dc51f431f272a84bc45b3529d61c284df33a03 Mon Sep 17 00:00:00 2001 From: Oleg Date: Thu, 9 Apr 2020 16:36:30 +0300 Subject: [PATCH 096/233] libnd4j fixed selection of external inputs Signed-off-by: Oleg --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 5a5995f00170..c13b436e8371 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -112,10 +112,8 @@ namespace sd { int inExCounts = 0, inInternalCounts = 0; for (const auto& in : inputs) { - // todo have to be correctly defined what input is external and what internal - if (originalGraph().variableSpace().hasVariable(in.first, 0)) { - // count external inputs, all inputs that are in - // varable space will be treated as external inputs + if (originalGraph().unmappedNodes().find(in.first) == originalGraph().unmappedNodes().end()) { + // count external inputs, all inputs which id is node in unmapped container will be treaded as external inExCounts++; } else { From 83522cb867a4ffc6f0f9813d50ff22dbc3e8e2c8 Mon Sep 17 00:00:00 2001 From: Oleg Date: Fri, 10 Apr 2020 11:36:05 +0300 Subject: [PATCH 097/233] libnd4j added support of logic operation handling in optimized graph, need some clarification and additional testing Signed-off-by: Oleg --- libnd4j/include/graph/OptimizedGraph.h | 8 +++- libnd4j/include/graph/impl/OptimizedGraph.cpp | 43 ++++++++++++------- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index 72f06b46b39e..c675ee9ff2e6 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -147,6 +147,8 @@ namespace sd { int nLayer; int nSequence; + + sd::graph::OpType opType; public: NodeInfo(){ reset(); } @@ -156,7 +158,7 @@ namespace sd { void setOutBranching(bool bValue) { bOutBranching = bValue; } void setProcessed(bool bValue = true) { bProcessed = bValue; } - void reset() { sConnections.clear(); bProcessed = bInBranching = bOutBranching = false; nLayer = 0; nSequence = -1; } + void reset() { sConnections.clear(); bProcessed = bInBranching = bOutBranching = false; nLayer = 0; nSequence = -1; opType = OpType_CUSTOM; } int layer() const { return nLayer; } void setLayer(int layer) { nLayer = layer; } @@ -167,6 +169,10 @@ namespace sd { void addConnection(int id) { sConnections.emplace(id); } const std::set& connections() const { return sConnections; } + void setType(sd::graph::OpType value){ opType = value; } + sd::graph::OpType type() const { return opType; } + bool isLogic(){ return opType == OpType_LOGIC; } + bool isInBranching() const { return bInBranching; } bool isOutBranching() const { return bOutBranching; } bool isProcessed() const { return bProcessed; } diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index c13b436e8371..4e5bbbd2a8f6 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -97,9 +97,10 @@ namespace sd { // double check to avoid unstable behavior if (originalGraph().unmappedNodes().empty()) return false; - + + const auto& unmappedNodes = originalGraph().unmappedNodes(); // iterate via original graph nodes to gather node information - for (const auto& it : originalGraph().unmappedNodes()) { + for (const auto& it : unmappedNodes) { const auto& ID = it.first; const auto& inputs = it.second.input(); @@ -112,7 +113,7 @@ namespace sd { int inExCounts = 0, inInternalCounts = 0; for (const auto& in : inputs) { - if (originalGraph().unmappedNodes().find(in.first) == originalGraph().unmappedNodes().end()) { + if (unmappedNodes.find(in.first) == unmappedNodes.end() ){ // count external inputs, all inputs which id is node in unmapped container will be treaded as external inExCounts++; } @@ -127,6 +128,8 @@ namespace sd { collector[in.first].addConnection(ID); } } + // set operation type + parentNode.setType( it.second.opType() ); // if move then 1 internal input this is in-branching node parentNode.setInBranching( inInternalCounts > 1); // gather start and in-branching nodes for the loop when operations are put to OpSequence (topolSearch) @@ -194,19 +197,19 @@ namespace sd { // next step set the node layer and it sequence in layer // define max layers and max sequence per layer int startSeq = 0; + bool bOnlyStartNodes = collector.empty(); for (const auto& id : startNodes) { layersMaxSeq[0] = startSeq; // if only start nodes exists they have to be add to connections - if (bOnlyStartNodes) { + if(bOnlyStartNodes){ auto node = NodeInfo(); node.setLayer(0); node.setProcessed(true); node.setSequence(startSeq); collector[id] = node; } - else { - if (!layersSeqDefine(collector, id, 0, startSeq, layersMaxSeq)) - throw std::runtime_error("OptimizedGraph::layersSeqDefine() - not all nodes properly prototyped!"); + else{ + layersSeqDefine(collector, id, 0, startSeq, layersMaxSeq); } startSeq++; } @@ -215,7 +218,7 @@ namespace sd { std::vector> vOpSeq; if(!initOpSeqContainer(layersMaxSeq, vOpSeq)) throw std::runtime_error("OptimizedGraph::initOpSeqContainer() - cannot initialize OpSequence, not all nodes properly prototyped!"); - + // combine start nodes and in-branching nodes startNodes.insert(inBranching.begin(), inBranching.end()); // re-init proceed NodeInfo member to avoid append sequence several times @@ -228,11 +231,11 @@ namespace sd { const auto it = originalGraph().unmappedNodes().find(id); auto& nodeInfo = collector[id]; - // append only not processed nodes if(!nodeInfo.isProcessed()){ vOpSeq[nodeInfo.layer()][nodeInfo.sequence()].append(it->second.customOp(), it->second.contextPrototype()); nodeInfo.setProcessed(); - } + } + // search in depth via connections of "start" node if(!topolSearch(id, collector, vOpSeq)) throw std::runtime_error("OptimizedGraph::topolSearch() - cannot run topological search, inputs incorrect!"); @@ -294,9 +297,9 @@ namespace sd { parent->second.setOutBranching(parent->second.connections().size() > 1); // set that node was processed, to avoid it double processing (only for some cases it can be processed several times) parent->second.setProcessed(); - + // if current node is out-branching it childs will be put to next layer - if (parent->second.isOutBranching()) + if (parent->second.isOutBranching() && !parent->second.isLogic()) layer++; // childs sequence position have to start from max defined sequence position in layer @@ -304,7 +307,8 @@ namespace sd { int seq = (layersMaxSeq.find(layer) == layersMaxSeq.end()) ? 0 : layersMaxSeq[layer]; // if parent is out-branching node sequence have to be increment // on the next stage the sequence value will be double checked with max per layer - seq = parent->second.isOutBranching() ? seq + 1 : seq; + // todo check logic part + seq = (parent->second.isOutBranching() && !parent->second.isLogic()) ? seq + 1 : seq; // loop via childs (connected nodes) for (const auto& id : parent->second.connections()) { @@ -313,17 +317,24 @@ namespace sd { if(child == collection.end()) return false; + // todo check this do we need to set op type logic for childs + if(parent->second.isLogic()) + child->second.setType(parent->second.type()); + // in case parent was not out-branching node but child is in branching it will be put to next layer - if (!parent->second.isOutBranching() && child->second.isInBranching()) + if (!parent->second.isOutBranching() && child->second.isInBranching() && !child->second.isLogic()) layer++; // move in depth of connections layersSeqDefine(collection, id, layer, seq, layersMaxSeq); - // increment sequence as childs are on the one layer in case if child was not processed earlier - seq++; + // increment sequence as childs are on the one layer in case if child was not processed earlier + if(!parent->second.isLogic()) + seq++; } return true; } + + } } From 450dce0ebbb5da91e656312b1100fde2a7e32290 Mon Sep 17 00:00:00 2001 From: Oleg Date: Fri, 10 Apr 2020 14:18:02 +0300 Subject: [PATCH 098/233] libnd4j temporary disabled one step as seems it is wrong Signed-off-by: Oleg --- libnd4j/include/graph/Node.h | 2 +- libnd4j/include/graph/impl/Node.cpp | 2 +- libnd4j/include/graph/impl/OptimizedGraph.cpp | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index d98064eb1142..c937221b3df0 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -183,7 +183,7 @@ namespace sd { void markInplace(bool reallyInplace); - OpClass getOpClass(); + OpClass getOpClass() const; // these methods are used for internal profiling void setOuterTime(Nd4jLong time); diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index b72c1277e5c3..ffa2a9fde569 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -132,7 +132,7 @@ namespace sd { _removable = reallyRemovable; } - OpClass Node::getOpClass() { + OpClass Node::getOpClass() const { return _opClass; } diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 4e5bbbd2a8f6..49ebabf16f01 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -317,8 +317,8 @@ namespace sd { if(child == collection.end()) return false; - // todo check this do we need to set op type logic for childs - if(parent->second.isLogic()) + // todo check this do we need to set child logic if parent is + if(false && parent->second.isLogic()) child->second.setType(parent->second.type()); // in case parent was not out-branching node but child is in branching it will be put to next layer From c87b184b9c9e9835f06e490c1f36926a5bd0820f Mon Sep 17 00:00:00 2001 From: Oleg Date: Tue, 14 Apr 2020 10:12:59 +0300 Subject: [PATCH 099/233] libnd4j added more comments and behavior description for easier support/update in future Signed-off-by: Oleg --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 28 ++++++++++--------- .../layers_tests/GraphAnalysisTests.cpp | 21 ++++++++++++++ 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 49ebabf16f01..1c78e99409cd 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -109,12 +109,12 @@ namespace sd { collector[ID] = NodeInfo(); NodeInfo& parentNode = collector[ID]; - + // count external and internal inputs to find out the type of the node (start, in-branching, out-branching) int inExCounts = 0, inInternalCounts = 0; for (const auto& in : inputs) { - + // find input id in original graph if (unmappedNodes.find(in.first) == unmappedNodes.end() ){ - // count external inputs, all inputs which id is node in unmapped container will be treaded as external + // count external inputs, all inputs which id is not in unmapped container will be treaded as external inExCounts++; } else { @@ -124,12 +124,13 @@ namespace sd { // if node info is not in collector add it if (collector.find(in.first) == collector.end()) collector[in.first] = NodeInfo(); - + // input node connection with discovered collector[in.first].addConnection(ID); } } // set operation type parentNode.setType( it.second.opType() ); + // if move then 1 internal input this is in-branching node parentNode.setInBranching( inInternalCounts > 1); // gather start and in-branching nodes for the loop when operations are put to OpSequence (topolSearch) @@ -158,10 +159,10 @@ namespace sd { for (const auto& itNodes : itParent->second.connections()) { auto itChild = collector.find(itNodes); - + // double check if (itChild != collector.end()) { - // if the child is in-branching node it will be treated as start node - if (itChild->second.isInBranching() || itChild->second.isProcessed()) { // proceed + // if the child is in-branching node it will be treated as start node or it was proceed + if (itChild->second.isInBranching() || itChild->second.isProcessed()) { continue; } // put operation to OpSequence container @@ -231,6 +232,7 @@ namespace sd { const auto it = originalGraph().unmappedNodes().find(id); auto& nodeInfo = collector[id]; + // append start/in-branching node operation to sequence if(!nodeInfo.isProcessed()){ vOpSeq[nodeInfo.layer()][nodeInfo.sequence()].append(it->second.customOp(), it->second.contextPrototype()); nodeInfo.setProcessed(); @@ -251,7 +253,7 @@ namespace sd { // double check to avoid unstable behavior if (layersMaxSeq.empty()) return false; - + // pre-init op-sequence size layers/per-layer sequence vOpSeq.resize(layersMaxSeq.size()); for (const auto& it : layersMaxSeq) { vOpSeq[it.first].resize(it.second + 1); @@ -281,6 +283,7 @@ namespace sd { } else{ // if node sequence position was not checked use it for max sequence selection + // sequence have to be incremented as max + 1, without any jumps if(startSeq > (layerFound->second + 1)) startSeq = layerFound->second + 1; @@ -307,7 +310,8 @@ namespace sd { int seq = (layersMaxSeq.find(layer) == layersMaxSeq.end()) ? 0 : layersMaxSeq[layer]; // if parent is out-branching node sequence have to be increment // on the next stage the sequence value will be double checked with max per layer - // todo check logic part + // todo check logic part maybe here have to be check operation class (something likke Switch, If, While etc) + // probably for each of them could be other behavior seq = (parent->second.isOutBranching() && !parent->second.isLogic()) ? seq + 1 : seq; // loop via childs (connected nodes) @@ -316,18 +320,16 @@ namespace sd { auto child = collection.find(id); if(child == collection.end()) return false; - - // todo check this do we need to set child logic if parent is - if(false && parent->second.isLogic()) - child->second.setType(parent->second.type()); // in case parent was not out-branching node but child is in branching it will be put to next layer + // todo check logic part if (!parent->second.isOutBranching() && child->second.isInBranching() && !child->second.isLogic()) layer++; // move in depth of connections layersSeqDefine(collection, id, layer, seq, layersMaxSeq); // increment sequence as childs are on the one layer in case if child was not processed earlier + // todo check logic part if(!parent->second.isLogic()) seq++; } diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index ebf67522f43c..933e3c52b894 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -1019,6 +1019,27 @@ TEST_F(GraphAnalysisTests, test_cond_1) { auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); auto optimized = graph.optimizedGraph(); + /* + some infor that would be useful for implementation + currently on optimization graph is passing next data + + Node name: cond/switch_f; ID: 11; Input: 9, 0; Operation type: 21; Operation class: -1719689536 + Node name: cond/switch_t; ID: 10; Input: 9, 1; Operation type: 21; Operation class: -1719689536 + Node name: cond/Switch; ID: 9; Input: 1, 0; Operation type: 119; Operation class: -1719689536 + Node name: cond/Switch; ID: 9; Input: 6, 0; Operation type: 119; Operation class: -1719689536 + Node name: cond/Merge; ID: 8; Input: 5, 0; Operation type: 119; Operation class: -1719689536 + Node name: cond/Merge; ID: 8; Input: 7, 0; Operation type: 119; Operation class: -1719689536 + Node name: in_0/read; ID: 6; Input: 1, 0; Operation type: 21; Operation class: -1719689536 + Node name: cond/LinSpace; ID: 7; Input: 2, 0; Operation type: 21; Operation class: -1719689536 + Node name: cond/LinSpace; ID: 7; Input: 3, 0; Operation type: 21; Operation class: -1719689536 + Node name: cond/LinSpace; ID: 7; Input: 4, 0; Operation type: 21; Operation class: -1719689536 + + as it can be seen cond/LinSpace is not connected with any switch node(s) that causes wrong results of optimization. + also maybe to cover all conditional operations will be need "Operation class", but this have to discovered deeper. + + All above is true for test_cond_2 + */ + } TEST_F(GraphAnalysisTests, test_cond_2) { From f6fc5c488f773bc8738821651614075719eb35ce Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 17 Apr 2020 17:30:14 +0300 Subject: [PATCH 100/233] changes adopted Signed-off-by: raver119 --- libnd4j/include/array/NDArrayFactory.h | 2 + libnd4j/include/array/impl/NDArrayFactory.cpp | 4 + .../ops/declarable/generic/blas/matmul.cpp | 4 +- .../ops/declarable/helpers/impl/gru.cpp | 10 +- .../ops/declarable/helpers/impl/lstmLayer.cpp | 104 +++++++++--------- .../ops/declarable/helpers/lstmLayer.h | 6 +- .../ops/declarable/platform/mkldnn/matmul.cpp | 2 +- .../layers_tests/DeclarableOpsTests10.cpp | 4 +- .../layers_tests/DeclarableOpsTests15.cpp | 2 +- .../tests_cpu/layers_tests/OneOffTests.cpp | 3 + 10 files changed, 75 insertions(+), 66 deletions(-) diff --git a/libnd4j/include/array/NDArrayFactory.h b/libnd4j/include/array/NDArrayFactory.h index 00f6aa732d9e..d73fba2bf9f2 100644 --- a/libnd4j/include/array/NDArrayFactory.h +++ b/libnd4j/include/array/NDArrayFactory.h @@ -37,6 +37,8 @@ namespace sd { template static void memcpyFromVector(void *ptr, const std::vector &vector); public: + static NDArray undefined(); + template static NDArray* empty_(sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); diff --git a/libnd4j/include/array/impl/NDArrayFactory.cpp b/libnd4j/include/array/impl/NDArrayFactory.cpp index 9e6a0c8964f6..a02e7e9c2f7e 100644 --- a/libnd4j/include/array/impl/NDArrayFactory.cpp +++ b/libnd4j/include/array/impl/NDArrayFactory.cpp @@ -36,6 +36,10 @@ namespace sd { + SD_EXPORT NDArray NDArrayFactory::undefined() { + return NDArray(); + } + //////////////////////////////////////////////////////////////////////// template <> SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context) { diff --git a/libnd4j/include/ops/declarable/generic/blas/matmul.cpp b/libnd4j/include/ops/declarable/generic/blas/matmul.cpp index 6ec31a77b5bb..887f46ff690f 100644 --- a/libnd4j/include/ops/declarable/generic/blas/matmul.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/matmul.cpp @@ -41,7 +41,7 @@ CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) { int transY = iSize > 1 ? INT_ARG(1) : 0; const int transZ = iSize > 2 ? INT_ARG(2) : 0; // optional use alpha nad beta - iSize = (int)block.getTArguments()->size(); + iSize = (int)block.numT(); double alpha = iSize > 0 ? T_ARG(0) : 1.0; double beta = iSize > 1 ? T_ARG(1) : 0.0; @@ -157,7 +157,7 @@ CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) { const int transZ = iSize > 2 ? INT_ARG(2) : 0; // optional use alpha nad beta - iSize = (int)block.getTArguments()->size(); + iSize = (int) block.numT(); double alpha = iSize > 0 ? T_ARG(0) : 1.0; double beta = iSize > 1 ? T_ARG(1) : 0.0; diff --git a/libnd4j/include/ops/declarable/helpers/impl/gru.cpp b/libnd4j/include/ops/declarable/helpers/impl/gru.cpp index 277188428cb1..7357ad862d47 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/gru.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/gru.cpp @@ -169,7 +169,7 @@ void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* h // time loop for (int t = 0; t < sL; ++t) - gruCell(context, xSet.at(t), t == 0 ? hI : hSet.at(t-1), Wx, Wh, b, &gates, hSet.at(t)); + gruCell(context, &xSet.at(t), t == 0 ? hI : &hSet.at(t-1), Wx, Wh, b, &gates, &hSet.at(t)); } ////////////////////////////////////////////////////////////////////////// @@ -528,16 +528,16 @@ void gruTimeLoopBp(sd::LaunchContext * context, auto gatesSet = gates.allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] auto dLdxSet = dLdx->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] - hSet.at(0)->assign(hI); + hSet.at(0).assign(hI); // forward time loop for (int t = 0; t < sL; ++t) - gruCell(context, xSet.at(t), hSet.at(t), Wx, Wh, b, gatesSet.at(t), hSet.at(t+1)); + gruCell(context, &xSet.at(t), &hSet.at(t), Wx, Wh, b, &gatesSet.at(t), &hSet.at(t+1)); // backward time loop for (int t = sL-1; t >= 0; --t) - gruCellBp(context, xSet.at(t), hSet.at(t), Wx, Wh, b, dLdhSet.at(t), gatesSet.at(t), - dLdxSet.at(t), dLdhI, dLdWx, dLdWh, dLdb); + gruCellBp(context, &xSet.at(t), &hSet.at(t), Wx, Wh, b, &dLdhSet.at(t), &gatesSet.at(t), + &dLdxSet.at(t), dLdhI, dLdWx, dLdWh, dLdb); } diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp index 8802ca4fafee..8c98c88d3029 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp @@ -1013,25 +1013,25 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(!seqLen) { // seqLen is absent if(hI) - hSet->at(0)->assign(hI); + hSet->at(0).assign(hI); else - hSet->at(0)->nullify(); + hSet->at(0).nullify(); if(cI) - cSet->at(0)->assign(cI); + cSet->at(0).assign(cI); else - cSet->at(0)->nullify(); + cSet->at(0).nullify(); // ff for (int t = 0; t < sL; ++t) - lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t+1), cSet->at(t+1)); + lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t), &cSet->at(t), Wp, params, &zSet->at(t), &aSet->at(t), &hSet->at(t+1), &cSet->at(t+1)); // bp for (int t = sL-1; t >= 0; --t) { - const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr; - const NDArray* dLdhhL = (t == sL-1 && dLdhL) ? dLdhL : nullptr; - const NDArray* dLdccL = (t == sL-1 && dLdcL) ? dLdcL : nullptr; - lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdhhL, dLdccL, - zSet->at(t), aSet->at(t), cSet->at(t+1), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); + const NDArray dLdhh = dLdh ? dLdhSet->at(t) : NDArrayFactory::undefined(); + const NDArray dLdhhL = (t == sL-1 && dLdhL) ? *dLdhL : NDArrayFactory::undefined(); + const NDArray dLdccL = (t == sL-1 && dLdcL) ? *dLdcL : NDArrayFactory::undefined(); + lstmLayerCellBp(&xSet->at(t), Wx, Wr, b, &hSet->at(t), &cSet->at(t), Wp, &dLdhh, &dLdhhL, &dLdccL, + &zSet->at(t), &aSet->at(t), &cSet->at(t+1), params, &dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); } } else { // seqLen is present @@ -1046,28 +1046,28 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, } if(hI) - hSet->at(e)->assign(hISet->at(e)); + hSet->at(e).assign(hISet->at(e)); else - hSet->at(e)->nullify(); + hSet->at(e).nullify(); if(cI) - cSet->at(e)->assign(cISet->at(e)); + cSet->at(e).assign(cISet->at(e)); else - cSet->at(e)->nullify(); + cSet->at(e).nullify(); // ff for (int t = 0; t < limit; ++t) - lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, params, - zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e)); + lstmLayerCell(&xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, &hSet->at(t*bS + e), &cSet->at(t*bS + e), Wp, params, + &zSet->at(t*bS + e), &aSet->at(t*bS + e), &hSet->at((t+1)*bS + e), &cSet->at((t+1)*bS + e)); // bp for (int t = limit-1; t >= 0; --t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; - const NDArray* dLdhhL = (t == limit-1 && dLdhL) ? dLdhLSet->at(e) : nullptr; - const NDArray* dLdccL = (t == limit-1 && dLdcL) ? dLdcLSet->at(e) : nullptr; - lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdhhL, dLdccL, - zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at((t+1)*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, - dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); + const NDArray dLdhh = dLdh ? dLdhSet->at(ind) : NDArrayFactory::undefined(); + const NDArray dLdhhL = (t == limit-1 && dLdhL) ? dLdhLSet->at(e) : NDArrayFactory::undefined(); + const NDArray dLdccL = (t == limit-1 && dLdcL) ? dLdcLSet->at(e) : NDArrayFactory::undefined(); + lstmLayerCellBp(&xSet->at(ind), Wx, Wr, b, &hSet->at(t*bS + e), &cSet->at(t*bS + e), Wp, &dLdhh, &dLdhhL, &dLdccL, + &zSet->at(t*bS + e), &aSet->at(t*bS + e), &cSet->at((t+1)*bS + e), params, &dLdxSet->at(ind), dLdWx, dLdWr, + &dLdh0Set->at(e), &dLdc0Set->at(e), dLdb, dLdWp); } if(limit != sL) @@ -1080,25 +1080,25 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(!seqLen) { // backward or bidirectional, seqLen is absent if(hI) - hSet->at(sL)->assign(hI); + hSet->at(sL).assign(hI); else - hSet->at(sL)->nullify(); + hSet->at(sL).nullify(); if(cI) - cSet->at(sL)->assign(cI); + cSet->at(sL).assign(cI); else - cSet->at(sL)->nullify(); + cSet->at(sL).nullify(); // ff for (int t = sL-1; t >= 0; --t) - lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t), cSet->at(t)); + lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t+1), &cSet->at(t+1), Wp, params, &zSet->at(t), &aSet->at(t), &hSet->at(t), &cSet->at(t)); // bp for (int t = 0; t < sL; ++t) { - const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr; - const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhL : nullptr; - const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcL : nullptr; - lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, dLdhh, dLdhhL, dLdccL, - zSet->at(t), aSet->at(t), cSet->at(t), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); + const NDArray dLdhh = dLdh ? dLdhSet->at(t) : NDArrayFactory::undefined(); + const NDArray dLdhhL = (t == 0 && dLdhL) ? *dLdhL : NDArrayFactory::undefined(); + const NDArray dLdccL = (t == 0 && dLdcL) ? *dLdcL : NDArrayFactory::undefined(); + lstmLayerCellBp(&xSet->at(t), Wx, Wr, b, &hSet->at(t+1), &cSet->at(t+1), Wp, &dLdhh, &dLdhhL, &dLdccL, + &zSet->at(t), &aSet->at(t), &cSet->at(t), params, &dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); } } else if(directionMode == 1) { // backward, seqLen is present @@ -1113,28 +1113,28 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, } if(hI) - hSet->at(sL*bS + e)->assign(hISet->at(e)); + hSet->at(sL*bS + e).assign(hISet->at(e)); else - hSet->at(sL*bS + e)->nullify(); + hSet->at(sL*bS + e).nullify(); if(cI) - cSet->at(sL*bS + e)->assign(cISet->at(e)); + cSet->at(sL*bS + e).assign(cISet->at(e)); else - cSet->at(sL*bS + e)->nullify(); + cSet->at(sL*bS + e).nullify(); // ff for (int t = sL - 1; t >= sL-limit; --t) - lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params, - zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e)); + lstmLayerCell(&xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, &hSet->at((t+1)*bS + e), &cSet->at((t+1)*bS + e), Wp, params, + &zSet->at(t*bS + e), &aSet->at(t*bS + e), &hSet->at(t*bS + e), &cSet->at(t*bS + e)); // bp for (int t = sL-limit; t < sL; ++t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; - const NDArray* dLdhhL = (t == sL-limit && dLdhL) ? dLdhLSet->at(e) : nullptr; - const NDArray* dLdccL = (t == sL-limit && dLdcL) ? dLdcLSet->at(e) : nullptr; - lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, - zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, - dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); + const NDArray dLdhh = dLdh ? dLdhSet->at(ind) : NDArrayFactory::undefined(); + const NDArray dLdhhL = (t == sL-limit && dLdhL) ? dLdhLSet->at(e) : NDArrayFactory::undefined(); + const NDArray dLdccL = (t == sL-limit && dLdcL) ? dLdcLSet->at(e) : NDArrayFactory::undefined(); + lstmLayerCellBp(&xSet->at(ind), Wx, Wr, b, &hSet->at((t+1)*bS + e), &cSet->at((t+1)*bS + e), Wp, &dLdhh, &dLdhhL, &dLdccL, + &zSet->at(t*bS + e), &aSet->at(t*bS + e), &cSet->at(t*bS + e), params, &dLdxSet->at(ind), dLdWx, dLdWr, + &dLdh0Set->at(e), &dLdc0Set->at(e), dLdb, dLdWp); } if(limit != sL) @@ -1163,18 +1163,18 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // ff for (int t = limit - 1; t >= 0; --t) - lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params, - zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e)); + lstmLayerCell(&xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, &hSet->at((t+1)*bS + e), &cSet->at((t+1)*bS + e), Wp, params, + &zSet->at(t*bS + e), &aSet->at(t*bS + e), &hSet->at(t*bS + e), &cSet->at(t*bS + e)); // bp for (int t = 0; t < limit; ++t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; - const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhLSet->at(e) : nullptr; - const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcLSet->at(e) : nullptr; - lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, - zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, - dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); + const NDArray dLdhh = dLdh ? dLdhSet->at(ind) : NDArrayFactory::undefined(); + const NDArray dLdhhL = (t == 0 && dLdhL) ? dLdhLSet->at(e) : NDArrayFactory::undefined(); + const NDArray dLdccL = (t == 0 && dLdcL) ? dLdcLSet->at(e) : NDArrayFactory::undefined(); + lstmLayerCellBp(&xSet->at(ind), Wx, Wr, b, &hSet->at((t+1)*bS + e), &cSet->at((t+1)*bS + e), Wp, &dLdhh, &dLdhhL, &dLdccL, + &zSet->at(t*bS + e), &aSet->at(t*bS + e), &cSet->at(t*bS + e), params, &dLdxSet->at(ind), dLdWx, dLdWr, + &dLdh0Set->at(e), &dLdc0Set->at(e), dLdb, dLdWp); } if(limit != sL) diff --git a/libnd4j/include/ops/declarable/helpers/lstmLayer.h b/libnd4j/include/ops/declarable/helpers/lstmLayer.h index a89b74909f6a..7eee8917dcf3 100644 --- a/libnd4j/include/ops/declarable/helpers/lstmLayer.h +++ b/libnd4j/include/ops/declarable/helpers/lstmLayer.h @@ -35,13 +35,13 @@ void SD_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* ////////////////////////////////////////////////////////////////////////// // this auxiliary ff should be running before backprop -void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, +void SD_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, const std::vector& params, NDArray* z, NDArray* a, NDArray* h, NDArray* c); ////////////////////////////////////////////////////////////////////////// -void ND4J_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, +void SD_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, const NDArray* z, const NDArray* a, const NDArray* c, const std::vector& params, NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp); @@ -55,7 +55,7 @@ void SD_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDAr NDArray* h, NDArray* hL, NDArray* cL); ////////////////////////////////////////////////////////////////////////// -void ND4J_EXPORT lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, +void SD_EXPORT lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp, const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, const std::vector& params, const bool forward, diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp index 1b77099651ac..e91a1af02eec 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp @@ -236,7 +236,7 @@ PLATFORM_IMPL(matmul, ENGINE_CPU) { const int transZ = iSize > 2 ? INT_ARG(2) : 0; // optional use alpha nad beta - iSize = (int)block.getTArguments()->size(); + iSize = (int) block.numT(); float alpha = iSize > 0 ? T_ARG(0) : 1.0; float beta = iSize > 1 ? T_ARG(1) : 0.0; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index b8ae5da3ee12..e2e3d17cc7a6 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -2021,7 +2021,7 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test2) { auto result = op.evaluate({}, {1, 12}, {23}); ASSERT_EQ(result.status(), ND4J_STATUS_OK); auto res = result.at(0); - ASSERT_EQ( res->dataType(), sd::DataType::FLOAT32 ); + ASSERT_EQ( res.dataType(), sd::DataType::FLOAT32 ); ASSERT_TRUE(expect.equalsTo(res)); } @@ -2035,7 +2035,7 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test3) { ASSERT_EQ(result.status(), ND4J_STATUS_OK); auto res = result.at(0); - ASSERT_EQ( res->dataType(), expect.dataType()); + ASSERT_EQ( res.dataType(), expect.dataType()); ASSERT_TRUE(expect.equalsTo(res)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 59045add1ead..2eff1150d317 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -1953,7 +1953,7 @@ TEST_F(DeclarableOpsTests15, gru_1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto* h = results.at(0); + auto h = results.at(0); ASSERT_TRUE(expH.isSameShape(h)); ASSERT_TRUE(expH.equalsTo(h)); diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index ca99df66d837..ec4d04b307b6 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -168,6 +168,9 @@ TEST_F(OneOffTests, test_tensor_array_2) { } TEST_F(OneOffTests, test_tensor_array_3) { + if (1 > 0) + throw std::runtime_error("Temporary disabled"); + auto e = NDArrayFactory::create('c', {3, 2, 3}, {7, 2, 9, 4, 3, 3, 8, 7, 0, 0, 6, 8, 7, 9, 0, 1, 1, 4}); auto graph = Graph::fromFlatBuffers("./resources/tensor_array_stack_sz3-1_int32_dynamic_name_shape.fb"); From 09399a1c5326e0e23f690c611d2fbfd6706dc3de Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 20 Apr 2020 11:34:57 +0300 Subject: [PATCH 101/233] all svd tests pass Signed-off-by: raver119 --- libnd4j/include/helpers/cpu/svd.cpp | 45 ++++++++++++++--------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/libnd4j/include/helpers/cpu/svd.cpp b/libnd4j/include/helpers/cpu/svd.cpp index 1db9d60420ec..2bae2231f2ad 100644 --- a/libnd4j/include/helpers/cpu/svd.cpp +++ b/libnd4j/include/helpers/cpu/svd.cpp @@ -216,7 +216,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh const int len = col2 + 1 - col1; - auto colVec0 = new NDArray(_m({col1+shift,col1+shift+len, col1+shift,col1+shift+1}, true)); + auto colVec0 = _m({col1+shift,col1+shift+len, col1+shift,col1+shift+1}, true); auto diagInterval = _m({col1+shift, col1+shift+len, col1+shift,col1+shift+len}, true).diagonal('c'); @@ -226,7 +226,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh maxElem = math::nd4j_abs(diagInterval.template e(0)); else maxElem = diagInterval({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e(0); - T maxElem0 = colVec0->reduceNumber(reduce::AMax).template e(0); + T maxElem0 = colVec0.reduceNumber(reduce::AMax).template e(0); T eps = math::nd4j_max(almostZero, DataTypeUtils::eps() * maxElem); T epsBig = (T)8. * DataTypeUtils::eps() * math::nd4j_max(maxElem0, maxElem); @@ -235,8 +235,8 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh diagInterval.p(Nd4jLong(0), epsBig); for(int i=1; i < len; ++i) - if(math::nd4j_abs(colVec0->template e(i)) < eps) - colVec0->p(i, 0.f); + if(math::nd4j_abs(colVec0.template e(i)) < eps) + colVec0.p(i, 0.f); for(int i=1; i < len; i++) if(diagInterval.template e(i) < epsBig) { @@ -249,7 +249,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh bool totDefl = true; for(int i=1; i < len; i++) - if(colVec0->template e(i) >= almostZero) { + if(colVec0.template e(i) >= almostZero) { totDefl = false; break; } @@ -309,23 +309,23 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh diagInterval.p(i, _e0); if(i!=0 && jac!=0) { - _e0 = colVec0->template e(jac); + _e0 = colVec0.template e(jac); //math::nd4j_swap((*colVec0)(i), (*colVec0)(jac)); - colVec0->p(jac, colVec0->template e(i)); - colVec0->p(i, _e0); + colVec0.p(jac, colVec0.template e(i)); + colVec0.p(i, _e0); } if (_calcU) { auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1}, true); auto temp2 = _u({col1,col1+len+1, col1+jac,col1+jac+1}, true); - auto temp3 = temp1; + auto temp3 = temp1.dup(); temp1.assign(temp2); temp2.assign(temp3); } else { auto temp1 = _u({0,2, col1+i, col1+i+1}, true); auto temp2 = _u({0,2, col1+jac, col1+jac+1}, true); - auto temp3 = temp1; + auto temp3 = temp1.dup(); temp1.assign(temp2); temp2.assign(temp3); } @@ -333,7 +333,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh if(_calcV) { auto temp1 = _v({row1W,row1W+len, col1W+i, col1W+i+1}, true); auto temp2 = _v({row1W,row1W+len, col1W+jac, col1W+jac+1}, true); - auto temp3 = temp1; + auto temp3 = temp1.dup(); temp1.assign(temp2); temp2.assign(temp3); } @@ -351,7 +351,7 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh { int i = len-1; - while(i > 0 && (math::nd4j_abs(diagInterval.template e(i)) < almostZero || math::nd4j_abs(colVec0->template e(i)) < almostZero)) + while(i > 0 && (math::nd4j_abs(diagInterval.template e(i)) < almostZero || math::nd4j_abs(colVec0.template e(i)) < almostZero)) --i; for(; i > 1; --i) { @@ -362,8 +362,6 @@ void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh } } } - - delete colVec0; } @@ -606,7 +604,7 @@ void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA const T almostZero = DataTypeUtils::min(); auto col0 = _m({col1, col1+size, col1, col1+1}, true); - auto diag = static_cast(_m({col1, col1+size, col1, col1+size}, true).diagonal('c')); + auto diag = _m({col1, col1+size, col1, col1+size}, true).diagonal('c').dup(); diag.p(Nd4jLong(0), T(0)); singVals = NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); @@ -644,14 +642,14 @@ void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA auto temp1 = U({0,0, i,i+1}, true); auto temp2 = U({0,0, i+1,i+2}, true); - auto temp3 = temp1; + auto temp3 = temp1.dup(); temp1.assign(temp2); temp2.assign(temp3); if(_calcV) { auto temp1 = V({0,0, i,i+1}, true); auto temp2 = V({0,0, i+1,i+2}, true); - auto temp3 = temp1; + auto temp3 = temp1.dup(); temp1.assign(temp2); temp2.assign(temp3); } @@ -669,7 +667,7 @@ void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA for(int i = 0; i < curSize/2; ++i) { auto temp3 = temp2({0,0, i,i+1}, true); auto temp4 = temp2({0,0, curSize-1-i,curSize-i}, true); - auto temp5 = temp3; + auto temp5 = temp3.dup(); temp3.assign(temp4); temp4.assign(temp5); } @@ -679,7 +677,7 @@ void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA for(int i = 0; i < curSize/2; ++i) { auto temp3 = temp2({0,0, i,i+1}, true); auto temp4 = temp2({0,0, curSize-1-i,curSize-i}, true); - auto temp5 = temp3; + auto temp5 = temp3.dup(); temp3.assign(temp4); temp4.assign(temp5); } @@ -770,8 +768,7 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif if (_calcU) { - auto temp = _u({col1,col1+k+1, col1+k,col1+k+1}, true); - NDArray q1(temp); + auto q1 = _u({col1,col1+k+1, col1+k,col1+k+1}, true).dup(); for (int i = col1 + k - 1; i >= col1; --i) { auto temp = _u({col1,col1+k+1, i+1,i+2}, true); @@ -811,18 +808,18 @@ void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif if(_calcU) { auto pTemp = _u({col1, col1+n+1, col1,col1+n+1}, true); - auto temp = pTemp; + auto temp = pTemp.dup(); pTemp.assign(mmul(temp, UofSVD)); } else { auto pTemp = _u({0,0, col1,col1+n+1}, true); - auto temp = pTemp; + auto temp = pTemp.dup(); pTemp.assign(mmul(temp, UofSVD)); } if (_calcV) { auto pTemp = _v({row1W,row1W+n, row1W,row1W+n}, true); - auto temp = pTemp; + auto temp = pTemp.dup(); pTemp.assign(mmul(temp, VofSVD)); } From cd401bfe1ccbc55625867b531290dabf7f1d9eb8 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 20 Apr 2020 14:37:41 +0300 Subject: [PATCH 102/233] some NDArrayList-related changes Signed-off-by: raver119 --- libnd4j/include/array/NDArrayList.h | 36 +++--- libnd4j/include/array/impl/NDArrayList.cpp | 104 ++++++++---------- libnd4j/include/graph/Variable.h | 1 + libnd4j/include/graph/impl/Variable.cpp | 12 +- .../include/ops/declarable/DeclarableListOp.h | 9 +- .../declarable/generic/list/clone_list.cpp | 2 +- .../declarable/generic/list/create_list.cpp | 2 +- .../declarable/generic/list/gather_list.cpp | 8 +- .../ops/declarable/generic/list/read_list.cpp | 2 +- .../declarable/generic/list/scatter_list.cpp | 8 +- .../ops/declarable/generic/list/size_list.cpp | 6 +- .../declarable/generic/list/split_list.cpp | 6 +- .../declarable/generic/list/unstack_list.cpp | 7 +- .../declarable/generic/list/write_list.cpp | 8 +- .../ops/declarable/helpers/impl/where.cpp | 5 +- .../ops/declarable/impl/DeclarableListOp.cpp | 40 ++----- .../layers_tests/ListOperationsTests.cpp | 47 ++++---- .../layers_tests/NDArrayListTests.cpp | 10 +- 18 files changed, 143 insertions(+), 170 deletions(-) diff --git a/libnd4j/include/array/NDArrayList.h b/libnd4j/include/array/NDArrayList.h index 63aaf2beafc0..3075637a6c27 100644 --- a/libnd4j/include/array/NDArrayList.h +++ b/libnd4j/include/array/NDArrayList.h @@ -44,11 +44,11 @@ namespace sd { sd::DataType _dtype; // stored chunks - MAP_IMPL _chunks; + MAP_IMPL _chunks; // just a counter, for stored elements std::atomic _elements; - std::atomic _counter; + mutable std::atomic _counter; // reference shape std::vector _shape; @@ -62,39 +62,39 @@ namespace sd { // maximum number of elements int _height = 0; public: - NDArrayList(int height, bool expandable = false); + NDArrayList(int height = 0, bool expandable = false); ~NDArrayList(); NDArrayList(const sd::NDArrayList &other); NDArrayList(sd::NDArrayList &&other); - sd::DataType dataType(); + sd::DataType dataType() const; - NDArray* read(int idx); - NDArray* readRaw(int idx); - Nd4jStatus write(int idx, NDArray* array); + NDArray read(int idx); + NDArray readRaw(int idx); + Nd4jStatus write(int idx, const NDArray &array); - NDArray* pick(std::initializer_list indices); - NDArray* pick(std::vector& indices); - bool isWritten(int index); + NDArray pick(const std::vector& indices); + bool isWritten(int index) const; - std::vector& shape(); + const std::vector& shape() const; + void setShape(const std::vector &shape); - NDArray* stack(); - void unstack(NDArray* array, int axis); + NDArray stack() const; + void unstack(const NDArray &array, int axis); - std::pair& id(); - std::string& name(); + const std::pair& id() const; + const std::string& name() const; //sd::memory::Workspace* workspace(); sd::LaunchContext * context(); NDArrayList* clone(); bool equals(NDArrayList& other); - int elements(); - int height(); + int elements() const; + int height() const; - int counter(); + int counter() const; }; } diff --git a/libnd4j/include/array/impl/NDArrayList.cpp b/libnd4j/include/array/impl/NDArrayList.cpp index eba1288f8d39..1f92ed86c8fa 100644 --- a/libnd4j/include/array/impl/NDArrayList.cpp +++ b/libnd4j/include/array/impl/NDArrayList.cpp @@ -45,22 +45,18 @@ namespace sd { } NDArrayList::~NDArrayList() { - //nd4j_printf("\nDeleting NDArrayList: [%i]\n", _chunks.size()); - for (auto const& v : _chunks) - delete v.second; - _chunks.clear(); } - NDArray* NDArrayList::read(int idx) { - return new NDArray(readRaw(idx)->dup()); + NDArray NDArrayList::read(int idx) { + return readRaw(idx); } - sd::DataType NDArrayList::dataType() { + sd::DataType NDArrayList::dataType() const { return _dtype; } - NDArray* NDArrayList::readRaw(int idx) { + NDArray NDArrayList::readRaw(int idx) { if (_chunks.count(idx) < 1) { nd4j_printf("Non-existent chunk requested: [%i]\n", idx); throw std::invalid_argument("Bad index"); @@ -69,109 +65,111 @@ namespace sd { return _chunks[idx]; } - Nd4jStatus NDArrayList::write(int idx, NDArray* array) { + Nd4jStatus NDArrayList::write(int idx, const NDArray &array) { if (_chunks.count(idx) == 0) _elements++; else { - delete _chunks[idx]; + _chunks.erase(idx); } // we store reference shape on first write if (_chunks.empty()) { - _dtype = array->dataType(); + _dtype = array.dataType(); if (_shape.empty()) { //adding leading 1 to shape _shape.emplace_back(1); - for (int e = 0; e < array->rankOf(); e++) - _shape.emplace_back(array->sizeAt(e)); + for (int e = 0; e < array.rankOf(); e++) + _shape.emplace_back(array.sizeAt(e)); } else { // if shape is inferred (say, from split_list) - if (array->rankOf() == _shape.size()) { + if (array.rankOf() == _shape.size()) { // skipping first dim for (int e = 1; e < _shape.size(); e++) { - if (_shape[e] != array->sizeAt(e)) + if (_shape[e] != array.sizeAt(e)) return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); } - } else if (array->rankOf() == _shape.size() - 1) { + } else if (array.rankOf() == _shape.size() - 1) { // case like 2d _shape, and 1D rows for (int e = 1; e < _shape.size(); e++) - if (_shape[e] != array->sizeAt(e - 1)) + if (_shape[e] != array.sizeAt(e - 1)) return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); } else return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); } } else { - if (array->dataType() != _dtype) + if (array.dataType() != _dtype) return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same data type"); // if shape is inferred (say, from split_list) - if (array->rankOf() == _shape.size()) { + if (array.rankOf() == _shape.size()) { // skipping first dim for (int e = 1; e < _shape.size(); e++) { - if (_shape[e] != array->sizeAt(e)) + if (_shape[e] != array.sizeAt(e)) return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); } - } else if (array->rankOf() == _shape.size() - 1) { + } else if (array.rankOf() == _shape.size() - 1) { // case like 2d _shape, and 1D rows for (int e = 1; e < _shape.size(); e++) - if (_shape[e] != array->sizeAt(e - 1)) + if (_shape[e] != array.sizeAt(e - 1)) return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); } else return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); } - //_elements++; - // storing reference _chunks[idx] = array; return Status::OK(); } - std::vector& NDArrayList::shape() { + const std::vector& NDArrayList::shape() const { return _shape; } - int NDArrayList::counter() { + int NDArrayList::counter() const { return _counter++; } - void NDArrayList::unstack(NDArray* array, int axis) { + void NDArrayList::unstack(const NDArray &array, int axis) { _axis = axis; std::vector args({axis}); - auto newAxis = ShapeUtils::evalDimsToExclude(array->rankOf(), args); - auto result = array->allTensorsAlongDimension(newAxis); + auto newAxis = ShapeUtils::evalDimsToExclude(array.rankOf(), args); + auto result = array.allTensorsAlongDimension(newAxis); for (int e = 0; e < result.size(); e++) { auto chunk = result.at(e); - write(e, new NDArray(chunk.dup(array->ordering()))); + write(e, chunk.dup(array.ordering())); } } - NDArray* NDArrayList::stack() { + void NDArrayList::setShape(const std::vector &shape) { + _shape = shape; + } + + NDArray NDArrayList::stack() const { // FIXME: this is bad for perf, but ok as poc int numElements = _elements.load(); std::vector inputs(numElements); for (int e = 0; e < numElements; e++) { - _chunks[e]->syncToDevice(); - inputs[e] = _chunks[e]; + _chunks.at(e).syncToDevice(); + inputs[e] = &_chunks.at(e); } auto inShapeInfo = inputs[0]->getShapeInfo(); int rank = shape::rank(inShapeInfo); - NDArray* array = nullptr; + NDArray array; if (shape::isEmpty(inShapeInfo)) { switch (rank) { case 0: { if (numElements == 1) { - array = new NDArray(inputs[0]->ordering(), {0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext()); + array = NDArray(inputs[0]->ordering(), {0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext()); } else { - array = new NDArray('c', {(Nd4jLong) numElements, 0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext() ) ; + array = NDArray('c', {(Nd4jLong) numElements, 0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext() ) ; } } } @@ -179,19 +177,19 @@ namespace sd { else{ std::vector outShape(inShapeInfo + 1, inShapeInfo + 1 + rank); outShape.insert(outShape.begin(), (Nd4jLong) numElements); - array = new NDArray( shape::order(inShapeInfo), outShape, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext()); + array = NDArray( shape::order(inShapeInfo), outShape, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext()); } - ops::helpers::stack(inputs[0]->getContext(), inputs, *array, 0); + ops::helpers::stack(inputs[0]->getContext(), inputs, array, 0); return array; } - std::pair& NDArrayList::id() { + const std::pair& NDArrayList::id() const { return _id; } - std::string& NDArrayList::name() { + const std::string& NDArrayList::name() const { return _name; } @@ -199,38 +197,30 @@ namespace sd { return _context; } - int NDArrayList::elements() { + int NDArrayList::elements() const { return _elements.load(); } - int NDArrayList::height() { - //if (_height != 0) - // return _height; - //else - return (int) _chunks.size(); + int NDArrayList::height() const { + return (int) _chunks.size(); } - bool NDArrayList::isWritten(int index) { + bool NDArrayList::isWritten(int index) const { if (_chunks.count(index) > 0) return true; else return false; } - NDArray* NDArrayList::pick(std::initializer_list indices) { - std::vector idcs(indices); - return pick(idcs); - } - - NDArray* NDArrayList::pick(std::vector &indices) { + NDArray NDArrayList::pick(const std::vector &indices) { std::vector shape(_shape); //shape.insert(shape.begin() + _axis, indices.size()); shape[_axis] = indices.size(); // do we have to enforce C order here? - auto array = new NDArray('c', shape, _chunks[0]->dataType(), _context); + NDArray array('c', shape, _chunks[0].dataType(), _context); std::vector axis = ShapeUtils::evalDimsToExclude(shape.size(), {_axis}); - auto tads = array->allTensorsAlongDimension(axis); + auto tads = array.allTensorsAlongDimension(axis); int indicesSize = indices.size(); if (tads.size() != indicesSize) @@ -251,7 +241,7 @@ namespace sd { list->_elements.store(_elements.load()); for (auto const& v : _chunks) { - list->_chunks[v.first] = new NDArray(v.second->dup()); + list->_chunks[v.first] = v.second.dup(); } return list; @@ -271,7 +261,7 @@ namespace sd { auto arrThis = _chunks[v.first]; auto arrThat = other._chunks[v.first]; - if (!arrThis->equalsTo(arrThat)) + if (!arrThis.equalsTo(arrThat)) return false; } diff --git a/libnd4j/include/graph/Variable.h b/libnd4j/include/graph/Variable.h index ae5878bc22f2..fff8fa79a53d 100644 --- a/libnd4j/include/graph/Variable.h +++ b/libnd4j/include/graph/Variable.h @@ -80,6 +80,7 @@ namespace sd { explicit Variable(const sd::NDArray &array, const std::string &name, int id, int idx = 0); explicit Variable(std::shared_ptr array, const std::string &name, int id, int idx = 0); explicit Variable(std::shared_ptr array, const char *name = nullptr); + explicit Variable(const NDArrayList &arrayList, const std::string &name, int id, int idx = 0); explicit Variable(); #ifndef __JAVACPP_HACK__ diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index 01a8a94008cb..6929258fe41e 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -28,6 +28,16 @@ namespace sd { namespace graph { + Variable::Variable(const NDArrayList &arrayList, const std::string &name, int id, int idx) { + _list = std::make_shared(arrayList); + + if (!name.empty()) + _name = name; + + _id = id; + _index = idx; + } + Variable::Variable(const NDArray &array, const std::string &name, int id, int idx) { _ndarray = std::make_shared(array); @@ -55,7 +65,7 @@ namespace sd { } bool sd::graph::Variable::hasNDArrayList() const { - return _list.get() != nullptr; + return _list != nullptr; } bool sd::graph::Variable::isPlaceholder() const { diff --git a/libnd4j/include/ops/declarable/DeclarableListOp.h b/libnd4j/include/ops/declarable/DeclarableListOp.h index 6cf3589758b6..89819864f17d 100644 --- a/libnd4j/include/ops/declarable/DeclarableListOp.h +++ b/libnd4j/include/ops/declarable/DeclarableListOp.h @@ -35,18 +35,15 @@ namespace sd { Nd4jStatus validateAndExecute(Context& block) override = 0; sd::NDArray* getZ(Context& block, int inputId) ; - void setupResult(NDArray* array, Context& block); - void setupResultList(NDArrayList* arrayList, Context& block); + void setupResult(const NDArray &array, Context& block); + void setupResultList(const NDArrayList &arrayList, Context& block); public: DeclarableListOp(int numInputs, int numOutputs, const char* opName, int tArgs, int iArgs); - Nd4jStatus execute(Context* block) override; - - ResultSet execute(NDArrayList* list, std::initializer_list inputs, std::initializer_list tArgs, std::initializer_list iArgs); - ResultSet execute(NDArrayList* list, std::vector& inputs, std::vector& tArgs, std::vector& iArgs); + ResultSet execute(const NDArrayList &list, const std::vector& inputs, const std::vector& tArgs = {}, const std::vector& iArgs = {}); ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; }; diff --git a/libnd4j/include/ops/declarable/generic/list/clone_list.cpp b/libnd4j/include/ops/declarable/generic/list/clone_list.cpp index d100153ec421..20e5a6b39404 100644 --- a/libnd4j/include/ops/declarable/generic/list/clone_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/clone_list.cpp @@ -31,7 +31,7 @@ namespace sd { auto newList = list->clone(); //OVERWRITE_RESULT(newList); - setupResultList(newList, block); + setupResultList(*newList, block); return ND4J_STATUS_OK; } DECLARE_SYN(TensorArrayIdentityV3, clone_list); diff --git a/libnd4j/include/ops/declarable/generic/list/create_list.cpp b/libnd4j/include/ops/declarable/generic/list/create_list.cpp index 08d9ff273f4f..3dde5e510df3 100644 --- a/libnd4j/include/ops/declarable/generic/list/create_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/create_list.cpp @@ -45,7 +45,7 @@ namespace sd { // we recieve input array for graph integrity purposes only auto input = INPUT_VARIABLE(0); - setupResultList(list, block); + setupResultList(*list, block); // OVERWRITE_RESULT(list); auto scalar = NDArrayFactory::create(list->counter()); diff --git a/libnd4j/include/ops/declarable/generic/list/gather_list.cpp b/libnd4j/include/ops/declarable/generic/list/gather_list.cpp index 943313ad0242..5a3ed2e9ebd2 100644 --- a/libnd4j/include/ops/declarable/generic/list/gather_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/gather_list.cpp @@ -44,13 +44,13 @@ namespace sd { // now we should fill other dimensions if (e == 0) { - for (int d = 0; d < array->rankOf(); d++) - shape.emplace_back(array->sizeAt(d)); + for (int d = 0; d < array.rankOf(); d++) + shape.emplace_back(array.sizeAt(d)); } } auto result = NDArrayFactory::create_('c', shape, list->dataType()); - std::vector indicesList((list->readRaw(0)->rankOf() + 1) * 2, 0); + std::vector indicesList((list->readRaw(0).rankOf() + 1) * 2, 0); int skipPosition = 0; for (int e = 0; e < indices->lengthOf(); e++) { auto idx = indices->e(e); @@ -65,7 +65,7 @@ namespace sd { } //OVERWRITE_RESULT(result); - setupResult(result, block); + setupResult(*result, block); return Status::OK(); } DECLARE_SYN(TensorArrayGatherV3, gather_list); diff --git a/libnd4j/include/ops/declarable/generic/list/read_list.cpp b/libnd4j/include/ops/declarable/generic/list/read_list.cpp index a1320b9b3767..0300bf42b913 100644 --- a/libnd4j/include/ops/declarable/generic/list/read_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/read_list.cpp @@ -27,7 +27,7 @@ namespace sd { namespace ops { LIST_OP_IMPL(read_list, 1, 1, 0, 0) { auto list = INPUT_LIST(0); - NDArray *result = nullptr; + NDArray result; REQUIRE_TRUE(list->height() > 0, 0, "ReadList: number of elements in list should be positive prior to Read call"); diff --git a/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp b/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp index c83f07c65069..3b2126b4785f 100644 --- a/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp @@ -57,15 +57,15 @@ namespace sd { if (idx >= tads.size()) return ND4J_STATUS_BAD_ARGUMENTS; - auto arr = new NDArray(tads.at(e).dup(array->ordering())); + auto arr = tads.at(e).dup(array->ordering()); auto res = list->write(idx, arr); - if (res != ND4J_STATUS_OK) + + if (res != Status::OK()) return res; } if (!hasList) - //OVERWRITE_RESULT(list); - setupResultList(list, block); + setupResultList(*list, block); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/list/size_list.cpp b/libnd4j/include/ops/declarable/generic/list/size_list.cpp index 9c4d7ff70ce5..7a4ab7f870d7 100644 --- a/libnd4j/include/ops/declarable/generic/list/size_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/size_list.cpp @@ -28,13 +28,13 @@ namespace sd { LIST_OP_IMPL(size_list, 1, 1, 0, 0) { auto list = INPUT_LIST(0); - auto result = NDArrayFactory::create_(list->height(), block.launchContext()); + auto result = NDArrayFactory::create(list->height(), block.launchContext()); //nd4j_printf("List size: [%i]\n", list->height()); - result->printIndexedBuffer("actual height"); + result.printIndexedBuffer("actual height"); //nd4j_printf("List size: [%i]\n", list->height()); - result->printIndexedBuffer("actual height"); + result.printIndexedBuffer("actual height"); //OVERWRITE_RESULT(result); setupResult(result, block); diff --git a/libnd4j/include/ops/declarable/generic/list/split_list.cpp b/libnd4j/include/ops/declarable/generic/list/split_list.cpp index 74a75b29fe9f..40699443f010 100644 --- a/libnd4j/include/ops/declarable/generic/list/split_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/split_list.cpp @@ -50,7 +50,7 @@ namespace sd { REQUIRE_TRUE(sizes->isZ(), 0, "split_list: sizes array must have one of integer types"); REQUIRE_TRUE(sizes->rankOf() == 1, 0, "split_list: sizes array must be 1D") - list->shape() = array->getShapeAsVector(); + list->setShape(array->getShapeAsVector()); // now let's build subarrays int cnt = 0; @@ -68,7 +68,7 @@ namespace sd { auto subarray = (*array)(indices); - auto status = list->write(e, new NDArray(subarray.dup(array->ordering()))); + auto status = list->write(e, subarray.dup(array->ordering())); if (status != ND4J_STATUS_OK) return status; @@ -76,7 +76,7 @@ namespace sd { if (!hasList) { //OVERWRITE_RESULT(list); - setupResultList(list, block); + setupResultList(*list, block); } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp b/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp index 5f452294904a..9a12bb416e5f 100644 --- a/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp @@ -32,13 +32,10 @@ namespace ops { if (outputList == nullptr) { outputList = new NDArrayList(0, true); //block.trackList(outputList); - setupResultList(outputList, block); + setupResultList(*outputList, block); } - outputList->unstack(input, INT_ARG(0)); + outputList->unstack(*input, INT_ARG(0)); - //OVERWRITE_RESULT(list); - - // return Status::OK(); } } diff --git a/libnd4j/include/ops/declarable/generic/list/write_list.cpp b/libnd4j/include/ops/declarable/generic/list/write_list.cpp index 859aecdee7fe..8952e5453201 100644 --- a/libnd4j/include/ops/declarable/generic/list/write_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/write_list.cpp @@ -39,9 +39,9 @@ namespace sd { //nd4j_printf("Writing [%i]:\n", idx->e(0)); //input->printShapeInfo("input shape"); //input->printIndexedBuffer("input buffer"); - Nd4jStatus result = list->write(idx->e(0), new NDArray(input->dup())); + Nd4jStatus result = list->write(idx->e(0), input->dup()); - auto res = NDArrayFactory::create_(list->counter(), block.launchContext()); + auto res = NDArrayFactory::create(list->counter(), block.launchContext()); //res->printShapeInfo("Write_list 2 output shape"); setupResult(res, block); @@ -52,9 +52,9 @@ namespace sd { auto input = INPUT_VARIABLE(1); auto idx = INT_ARG(0); - Nd4jStatus result = list->write(idx, new NDArray(input->dup())); + Nd4jStatus result = list->write(idx, input->dup()); - auto res = NDArrayFactory::create_(list->counter(), block.launchContext()); + auto res = NDArrayFactory::create(list->counter(), block.launchContext()); //res->printShapeInfo("Write_list 1 output shape"); //OVERWRITE_RESULT(res); setupResult(res, block); diff --git a/libnd4j/include/ops/declarable/helpers/impl/where.cpp b/libnd4j/include/ops/declarable/helpers/impl/where.cpp index df8fd1074838..7357e4c96f4c 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/where.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/where.cpp @@ -38,9 +38,9 @@ namespace sd { auto offset = shape::getOffset(condition.getShapeInfo(), idx); if (condition.e(offset)) { - auto array = NDArrayFactory::create_('c', {1, condition.rankOf()}, output.dataType(), output.getContext()); + auto array = NDArrayFactory::create('c', {1, condition.rankOf()}, output.dataType(), output.getContext()); for (int f = 0; f < condition.rankOf(); f++) - array->p(f, (T) idx[f]); + array.p(f, (T) idx[f]); list.write(cnt++, array); } @@ -48,7 +48,6 @@ namespace sd { auto s = list.stack(); output.assign(s); - delete s; } BUILD_SINGLE_TEMPLATE(template void __where,(NDArray &condition, NDArray& output, memory::Workspace *workspace), LIBND4J_TYPES); diff --git a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp index da5671069319..1b43d422dcb1 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp @@ -56,20 +56,14 @@ namespace sd { return nullptr; } - void DeclarableListOp::setupResult(NDArray* array, Context& block) { - block.pushNDArrayToVariableSpace(block.getNodeId(), 0, *array); + void DeclarableListOp::setupResult(const NDArray &array, Context& block) { + block.pushNDArrayToVariableSpace(block.getNodeId(), 0, array); } - void DeclarableListOp::setupResultList(NDArrayList* arrayList, Context& block) { - block.pushNDArrayListToVariableSpace(block.getNodeId(), 0, *arrayList); + void DeclarableListOp::setupResultList(const NDArrayList &arrayList, Context& block) { + block.pushNDArrayListToVariableSpace(block.getNodeId(), 0, arrayList); } - ResultSet DeclarableListOp::execute(NDArrayList* list, std::initializer_list inputs, std::initializer_list tArgs, std::initializer_list iArgs) { - std::vector ins(inputs); - std::vector tas(tArgs); - std::vector ias(iArgs); - return this->execute(list, ins, tas, ias); - } Nd4jStatus DeclarableListOp::execute(Context* block) { if (block == nullptr) @@ -94,33 +88,23 @@ namespace sd { return status; } - ResultSet DeclarableListOp::execute(NDArrayList* list, std::vector& inputs, std::vector& tArgs, std::vector& iArgs) { + ResultSet DeclarableListOp::execute(const NDArrayList &list, const std::vector& inputs, const std::vector& tArgs, const std::vector& iArgs) { VariableSpace varSpace; int nodeId = 119; // should be never used in practice, since in-graph NDArrayList should have id set int cnt = -1; std::vector in; - if (list != nullptr) { - if (list->id().first == 0) - list->id().first = -1; - - auto listVar = std::make_shared(); - listVar->setId(-119, 0); - //listVar->setNDArrayList(list); - //varSpace.putVariable(-1, listVar); - //in.push_back(-1); - //cnt--; - throw std::runtime_error("DeclarableListOp::execute - Not implemented yet"); - } + auto listVar = std::make_shared(list, "", -1); + varSpace.putVariable(-1, listVar); + in.push_back(-1); + cnt--; for (auto v: inputs) { - //auto var = new Variable(v); - //var->markRemovable(false); - //in.push_back(cnt); - //varSpace.putVariable(cnt--, var); - throw std::runtime_error("DeclarableListOp::execute - Not implemented yet"); + auto var = std::make_shared(*v, "", cnt); + in.push_back(cnt); + varSpace.putVariable(cnt--, var); } Context block(1, &varSpace, false); diff --git a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp index 2f5842ec9521..b3444f1fc60c 100644 --- a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp @@ -36,13 +36,13 @@ TEST_F(ListOperationsTests, BasicTest_Write_1) { sd::ops::write_list op; - auto result = op.execute(&list, {&x}, {}, {1}); + auto result = op.execute(list, {&x}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); ASSERT_EQ(1, list.elements()); - auto result2 = op.execute(&list, {&x}, {}, {2}); + auto result2 = op.execute(list, {&x}, {}, {2}); ASSERT_EQ(2, list.elements()); @@ -55,15 +55,15 @@ TEST_F(ListOperationsTests, BasicTest_Stack_1) { auto exp = NDArrayFactory::create('c', {10, 100}); auto tads = exp.allTensorsAlongDimension({1}); for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {100}); - row->assign((double) e); + auto row = NDArrayFactory::create('c', {100}); + row.assign((double) e); list.write(e, row); tads.at(e).assign(row); } sd::ops::stack_list op; - auto result = op.execute(&list, {}, {}, {1}); + auto result = op.execute(list, {}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); @@ -89,7 +89,7 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { sd::ops::unstack_list op; - auto result = op.execute(&list, {&x}, {}, {0}); + auto result = op.execute(list, {&x}, {}, {0}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); ASSERT_EQ(list.elements(), 10); @@ -100,9 +100,8 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { // ASSERT_TRUE(exp.equalsTo(z)); for (int e = 0; e < 10; e++) { auto row = list.read(e); - ASSERT_TRUE(row->equalsTo(tads.at(e))); + ASSERT_TRUE(row.equalsTo(tads.at(e))); //list.write(e, row); - delete row; } @@ -149,14 +148,14 @@ TEST_F(ListOperationsTests, BasicTest_Read_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {1, 100}); row->assign((double) e); - list.write(e, new NDArray(row->dup())); + list.write(e, row->dup()); delete row; } sd::ops::read_list op; - auto result = op.execute(&list, {}, {}, {4}); + auto result = op.execute(list, {}, {}, {4}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); @@ -175,7 +174,7 @@ TEST_F(ListOperationsTests, BasicTest_Pick_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {100}); row->assign((double) e); - list.write(e, new NDArray(row->dup())); + list.write(e, row->dup()); delete row; } @@ -188,7 +187,7 @@ TEST_F(ListOperationsTests, BasicTest_Pick_1) { sd::ops::pick_list op; - auto result = op.execute(&list, {}, {}, {1, 1, 3, 3}); + auto result = op.execute(list, {}, {}, {1, 1, 3, 3}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); @@ -206,14 +205,14 @@ TEST_F(ListOperationsTests, BasicTest_Size_1) { for (int e = 0; e < 10; e++) { auto row = NDArrayFactory::create_('c', {100}); row->assign((double) e); - list.write(e, new NDArray(row->dup())); + list.write(e, row->dup()); delete row; } sd::ops::size_list op; - auto result = op.execute(&list, {}, {}, {1}); + auto result = op.execute(list, {}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); @@ -231,7 +230,7 @@ TEST_F(ListOperationsTests, BasicTest_Create_1) { sd::ops::create_list op; - auto result = op.execute(nullptr, {&matrix}, {}, {1, 1}); + auto result = op.execute(NDArrayList(), {&matrix}, {}, {1, 1}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); @@ -280,7 +279,7 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) { } sd::ops::split_list op; - auto result = op.execute(&list, {&matrix, &lengths}, {}, {}); + auto result = op.execute(list, {&matrix, &lengths}, {}, {}); ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(3, list.height()); @@ -315,7 +314,7 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) { indices.p(e, 9 - e); sd::ops::scatter_list op; - auto result = op.execute(&list, {&indices, &matrix, &s}, {}, {}); + auto result = op.execute(list, {&indices, &matrix, &s}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); @@ -323,9 +322,9 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) { auto row = tads.at(9 - e); auto chunk = list.readRaw(e); - ASSERT_TRUE(chunk->isSameShape(row)); + ASSERT_TRUE(chunk.isSameShape(row)); - ASSERT_TRUE(chunk->equalsTo(row)); + ASSERT_TRUE(chunk.equalsTo(row)); } } @@ -363,11 +362,9 @@ TEST_F(ListOperationsTests, BasicTest_Clone_1) { TEST_F(ListOperationsTests, BasicTest_Gather_1) { NDArrayList list(0, true); for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {3}); - row->assign((double) e); - list.write(e, new NDArray(row->dup())); - - delete row; + auto row = NDArrayFactory::create('c', {3}); + row.assign((double) e); + list.write(e, row.dup()); } auto exp = NDArrayFactory::create('c', {10, 3}); @@ -381,7 +378,7 @@ TEST_F(ListOperationsTests, BasicTest_Gather_1) { indices.linspace(9, -1); sd::ops::gather_list op; - auto result = op.execute(&list, {&indices}, {}, {}); + auto result = op.execute(list, {&indices}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); ASSERT_EQ(1, result.size()); diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp index 2de3e4651377..0400f4b90bf8 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp @@ -36,7 +36,7 @@ TEST_F(NDArrayListTests, BasicTests_1) { auto x = NDArrayFactory::create('c', {1, 10}); auto y = NDArrayFactory::create('c', {1, 10}); - ASSERT_EQ(ND4J_STATUS_OK, list.write(1, new NDArray(x.dup()))); + ASSERT_EQ(ND4J_STATUS_OK, list.write(1, x.dup())); //ASSERT_EQ(ND4J_STATUS_DOUBLE_WRITE, list.write(1, &y)); } @@ -47,9 +47,9 @@ TEST_F(NDArrayListTests, BasicTests_2) { auto x = NDArrayFactory::create('c', {1, 10}); auto y = NDArrayFactory::create('c', {1, 7}); - ASSERT_EQ(ND4J_STATUS_OK, list.write(1, new NDArray(x.dup()))); + ASSERT_EQ(ND4J_STATUS_OK, list.write(1, x.dup())); - ASSERT_EQ(ND4J_STATUS_BAD_INPUT, list.write(0, &y)); + ASSERT_EQ(ND4J_STATUS_BAD_INPUT, list.write(0, y)); } @@ -59,7 +59,7 @@ TEST_F(NDArrayListTests, Test_Stack_UnStack_1) { NDArrayList list(false); - list.unstack(&input, 0); + list.unstack(input, 0); ASSERT_EQ(10, list.elements()); @@ -68,6 +68,4 @@ TEST_F(NDArrayListTests, Test_Stack_UnStack_1) { ASSERT_TRUE(input.isSameShape(array)); ASSERT_TRUE(input.equalsTo(array)); - - delete array; } \ No newline at end of file From 4c6e3212906cd05d9e6869e573dd7c41ad9a26c6 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 20 Apr 2020 17:10:19 +0300 Subject: [PATCH 103/233] all list ops pass Signed-off-by: raver119 --- libnd4j/include/array/NDArrayList.h | 59 ++++--- libnd4j/include/array/impl/NDArrayList.cpp | 155 ++++++++++-------- .../declarable/generic/list/clone_list.cpp | 2 +- .../declarable/generic/list/create_list.cpp | 10 +- .../declarable/generic/list/gather_list.cpp | 6 +- .../ops/declarable/impl/DeclarableListOp.cpp | 10 +- .../layers_tests/ListOperationsTests.cpp | 28 ++-- 7 files changed, 146 insertions(+), 124 deletions(-) diff --git a/libnd4j/include/array/NDArrayList.h b/libnd4j/include/array/NDArrayList.h index 3075637a6c27..bb9e068dadfd 100644 --- a/libnd4j/include/array/NDArrayList.h +++ b/libnd4j/include/array/NDArrayList.h @@ -32,35 +32,42 @@ namespace sd { class SD_EXPORT NDArrayList { - private: - // workspace where chunks belong to - //sd::memory::Workspace* _workspace = nullptr; - sd::LaunchContext * _context = sd::LaunchContext ::defaultContext(); + protected: + class InternalArrayList { + public: + // numeric and symbolic ids of this list + std::pair _id; + std::string _name; - // numeric and symbolic ids of this list - std::pair _id; - std::string _name; + sd::DataType _dtype; - sd::DataType _dtype; + // stored chunks + MAP_IMPL _chunks; - // stored chunks - MAP_IMPL _chunks; + // just a counter, for stored elements + std::atomic _elements; + mutable std::atomic _counter; - // just a counter, for stored elements - std::atomic _elements; - mutable std::atomic _counter; + // reference shape + std::vector _shape; - // reference shape - std::vector _shape; + // unstack axis + int _axis = 0; - // unstack axis - int _axis = 0; + // + bool _expandable = false; - // - bool _expandable = false; + // maximum number of elements + int _height = 0; + + + ////////// + InternalArrayList(int height = 0, bool expandable = false); + ~InternalArrayList() = default; + }; + + std::shared_ptr _state; - // maximum number of elements - int _height = 0; public: NDArrayList(int height = 0, bool expandable = false); ~NDArrayList(); @@ -68,6 +75,11 @@ namespace sd { NDArrayList(const sd::NDArrayList &other); NDArrayList(sd::NDArrayList &&other); + NDArrayList& operator=(const NDArrayList& other) noexcept; + + // move assignment operator + NDArrayList& operator=(NDArrayList&& other) noexcept; + sd::DataType dataType() const; NDArray read(int idx); @@ -85,9 +97,8 @@ namespace sd { const std::pair& id() const; const std::string& name() const; - //sd::memory::Workspace* workspace(); - sd::LaunchContext * context(); - NDArrayList* clone(); + + NDArrayList clone(); bool equals(NDArrayList& other); diff --git a/libnd4j/include/array/impl/NDArrayList.cpp b/libnd4j/include/array/impl/NDArrayList.cpp index 1f92ed86c8fa..5dd090dc09d2 100644 --- a/libnd4j/include/array/impl/NDArrayList.cpp +++ b/libnd4j/include/array/impl/NDArrayList.cpp @@ -26,22 +26,43 @@ #include namespace sd { - NDArrayList::NDArrayList(const NDArrayList &other) { + NDArrayList::InternalArrayList::InternalArrayList(int height, bool expandable) { + _expandable = expandable; + _elements.store(0); + _counter.store(0); + _id.first = 0; + _id.second = 0; + _height = height; + } + NDArrayList::NDArrayList(const NDArrayList &other) { + _state = other._state; } NDArrayList::NDArrayList(NDArrayList &&other) { + _state = std::move(other._state); + } + + NDArrayList &NDArrayList::operator=(const NDArrayList &other) noexcept { + if (this == &other) + return *this; + + _state = other._state; + return *this; + } + + NDArrayList &NDArrayList::operator=(NDArrayList &&other) noexcept { + if (this == &other) + return *this; + + _state = std::move(other._state); + + return *this; } NDArrayList::NDArrayList(int height, bool expandable) { - _expandable = expandable; - _elements.store(0); - _counter.store(0); - _id.first = 0; - _id.second = 0; - _height = height; - //nd4j_printf("\nCreating NDArrayList\n",""); + _state = std::make_shared(height, expandable); } NDArrayList::~NDArrayList() { @@ -53,89 +74,89 @@ namespace sd { } sd::DataType NDArrayList::dataType() const { - return _dtype; + return _state->_dtype; } NDArray NDArrayList::readRaw(int idx) { - if (_chunks.count(idx) < 1) { + if (_state->_chunks.count(idx) < 1) { nd4j_printf("Non-existent chunk requested: [%i]\n", idx); throw std::invalid_argument("Bad index"); } - return _chunks[idx]; + return _state->_chunks.at(idx); } Nd4jStatus NDArrayList::write(int idx, const NDArray &array) { - if (_chunks.count(idx) == 0) - _elements++; + if (_state->_chunks.count(idx) == 0) + _state->_elements++; else { - _chunks.erase(idx); + _state->_chunks.erase(idx); } // we store reference shape on first write - if (_chunks.empty()) { - _dtype = array.dataType(); + if (_state->_chunks.empty()) { + _state->_dtype = array.dataType(); - if (_shape.empty()) { + if (_state->_shape.empty()) { //adding leading 1 to shape - _shape.emplace_back(1); + _state->_shape.emplace_back(1); for (int e = 0; e < array.rankOf(); e++) - _shape.emplace_back(array.sizeAt(e)); + _state->_shape.emplace_back(array.sizeAt(e)); } else { // if shape is inferred (say, from split_list) - if (array.rankOf() == _shape.size()) { + if (array.rankOf() == _state->_shape.size()) { // skipping first dim - for (int e = 1; e < _shape.size(); e++) { - if (_shape[e] != array.sizeAt(e)) + for (int e = 1; e < _state->_shape.size(); e++) { + if (_state->_shape[e] != array.sizeAt(e)) return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); } - } else if (array.rankOf() == _shape.size() - 1) { + } else if (array.rankOf() == _state->_shape.size() - 1) { // case like 2d _shape, and 1D rows - for (int e = 1; e < _shape.size(); e++) - if (_shape[e] != array.sizeAt(e - 1)) + for (int e = 1; e < _state->_shape.size(); e++) + if (_state->_shape[e] != array.sizeAt(e - 1)) return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); } else return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); } } else { - if (array.dataType() != _dtype) + if (array.dataType() != _state->_dtype) return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same data type"); // if shape is inferred (say, from split_list) - if (array.rankOf() == _shape.size()) { + if (array.rankOf() == _state->_shape.size()) { // skipping first dim - for (int e = 1; e < _shape.size(); e++) { - if (_shape[e] != array.sizeAt(e)) + for (int e = 1; e < _state->_shape.size(); e++) { + if (_state->_shape[e] != array.sizeAt(e)) return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); } - } else if (array.rankOf() == _shape.size() - 1) { + } else if (array.rankOf() == _state->_shape.size() - 1) { // case like 2d _shape, and 1D rows - for (int e = 1; e < _shape.size(); e++) - if (_shape[e] != array.sizeAt(e - 1)) + for (int e = 1; e < _state->_shape.size(); e++) + if (_state->_shape[e] != array.sizeAt(e - 1)) return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); } else return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); } // storing reference - _chunks[idx] = array; + _state->_chunks.insert({idx, array}); return Status::OK(); } const std::vector& NDArrayList::shape() const { - return _shape; + return _state->_shape; } int NDArrayList::counter() const { - return _counter++; + return _state->_counter++; } void NDArrayList::unstack(const NDArray &array, int axis) { - _axis = axis; + _state->_axis = axis; std::vector args({axis}); auto newAxis = ShapeUtils::evalDimsToExclude(array.rankOf(), args); auto result = array.allTensorsAlongDimension(newAxis); @@ -146,17 +167,17 @@ namespace sd { } void NDArrayList::setShape(const std::vector &shape) { - _shape = shape; + _state->_shape = shape; } NDArray NDArrayList::stack() const { // FIXME: this is bad for perf, but ok as poc - int numElements = _elements.load(); + int numElements = _state->_elements.load(); std::vector inputs(numElements); for (int e = 0; e < numElements; e++) { - _chunks.at(e).syncToDevice(); - inputs[e] = &_chunks.at(e); + _state->_chunks.at(e).syncToDevice(); + inputs[e] = &_state->_chunks.at(e); } auto inShapeInfo = inputs[0]->getShapeInfo(); @@ -186,40 +207,36 @@ namespace sd { } const std::pair& NDArrayList::id() const { - return _id; + return _state->_id; } const std::string& NDArrayList::name() const { - return _name; - } - - sd::LaunchContext * NDArrayList::context() { - return _context; + return _state->_name; } int NDArrayList::elements() const { - return _elements.load(); + return (int) _state->_chunks.size(); } int NDArrayList::height() const { - return (int) _chunks.size(); + return (int) _state->_chunks.size(); } bool NDArrayList::isWritten(int index) const { - if (_chunks.count(index) > 0) + if (_state->_chunks.count(index) > 0) return true; else return false; } NDArray NDArrayList::pick(const std::vector &indices) { - std::vector shape(_shape); + std::vector shape(_state->_shape); //shape.insert(shape.begin() + _axis, indices.size()); - shape[_axis] = indices.size(); + shape[_state->_axis] = indices.size(); // do we have to enforce C order here? - NDArray array('c', shape, _chunks[0].dataType(), _context); - std::vector axis = ShapeUtils::evalDimsToExclude(shape.size(), {_axis}); + NDArray array('c', shape, _state->_chunks.at(0).dataType()); + std::vector axis = ShapeUtils::evalDimsToExclude(shape.size(), {_state->_axis}); auto tads = array.allTensorsAlongDimension(axis); int indicesSize = indices.size(); @@ -227,39 +244,39 @@ namespace sd { throw std::runtime_error("Number of TADs should match number of indices"); for (int e = 0; e < indicesSize; e++) - tads.at(e).assign(_chunks[indices[e]]); + tads.at(e).assign(_state->_chunks.at(indices[e])); return array; } - NDArrayList* NDArrayList::clone() { - auto list = new NDArrayList(_height, _expandable); - list->_axis = _axis; - list->_id.first = _id.first; - list->_id.second = _id.second; - list->_name = _name; - list->_elements.store(_elements.load()); + NDArrayList NDArrayList::clone() { + NDArrayList list(_state->_height, _state->_expandable); + list._state->_axis = _state->_axis; + list._state->_id.first = _state->_id.first; + list._state->_id.second = _state->_id.second; + list._state->_name = _state->_name; + list._state->_elements.store(_state->_elements.load()); - for (auto const& v : _chunks) { - list->_chunks[v.first] = v.second.dup(); + for (auto const& v : _state->_chunks) { + list._state->_chunks.insert({v.first, v.second.dup()}); } return list; } bool NDArrayList::equals(NDArrayList& other) { - if (_axis != other._axis) + if (_state->_axis != other._state->_axis) return false; - if (_chunks.size() != other._chunks.size()) + if (_state->_chunks.size() != other._state->_chunks.size()) return false; - for (auto const& v : _chunks) { - if (other._chunks.count(v.first) == 0) + for (auto const& v : _state->_chunks) { + if (other._state->_chunks.count(v.first) == 0) return false; - auto arrThis = _chunks[v.first]; - auto arrThat = other._chunks[v.first]; + auto arrThis = _state->_chunks.at(v.first); + auto arrThat = other._state->_chunks.at(v.first); if (!arrThis.equalsTo(arrThat)) return false; diff --git a/libnd4j/include/ops/declarable/generic/list/clone_list.cpp b/libnd4j/include/ops/declarable/generic/list/clone_list.cpp index 20e5a6b39404..d100153ec421 100644 --- a/libnd4j/include/ops/declarable/generic/list/clone_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/clone_list.cpp @@ -31,7 +31,7 @@ namespace sd { auto newList = list->clone(); //OVERWRITE_RESULT(newList); - setupResultList(*newList, block); + setupResultList(newList, block); return ND4J_STATUS_OK; } DECLARE_SYN(TensorArrayIdentityV3, clone_list); diff --git a/libnd4j/include/ops/declarable/generic/list/create_list.cpp b/libnd4j/include/ops/declarable/generic/list/create_list.cpp index 3dde5e510df3..fc2d22df8821 100644 --- a/libnd4j/include/ops/declarable/generic/list/create_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/create_list.cpp @@ -41,18 +41,18 @@ namespace sd { expandable = true; } - auto list = new NDArrayList(height, expandable); + NDArrayList list(height, expandable); // we recieve input array for graph integrity purposes only auto input = INPUT_VARIABLE(0); - setupResultList(*list, block); -// OVERWRITE_RESULT(list); + setupResultList(list, block); - auto scalar = NDArrayFactory::create(list->counter()); + auto scalar = NDArrayFactory::create(list.counter()); block.pushNDArrayToVariableSpace(block.nodeId(), 1, scalar); - return ND4J_STATUS_OK; + return Status::OK(); } + DECLARE_SYN(TensorArrayV3, create_list); DECLARE_SYN(tensorarrayv3, create_list); DECLARE_SYN(TensorArrayCreateV3, create_list); diff --git a/libnd4j/include/ops/declarable/generic/list/gather_list.cpp b/libnd4j/include/ops/declarable/generic/list/gather_list.cpp index 5a3ed2e9ebd2..341f6347e01a 100644 --- a/libnd4j/include/ops/declarable/generic/list/gather_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/gather_list.cpp @@ -49,7 +49,7 @@ namespace sd { } } - auto result = NDArrayFactory::create_('c', shape, list->dataType()); + auto result = NDArrayFactory::create('c', shape, list->dataType()); std::vector indicesList((list->readRaw(0).rankOf() + 1) * 2, 0); int skipPosition = 0; for (int e = 0; e < indices->lengthOf(); e++) { @@ -60,12 +60,12 @@ namespace sd { indicesList[0] = skipPosition; indicesList[1] = skipPosition++ + 1; - auto subarray = (*result)(indicesList, true); + auto subarray = (result)(indicesList, true); subarray.assign(array); } //OVERWRITE_RESULT(result); - setupResult(*result, block); + setupResult(result, block); return Status::OK(); } DECLARE_SYN(TensorArrayGatherV3, gather_list); diff --git a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp index 1b43d422dcb1..64f7b09cff74 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp @@ -96,10 +96,12 @@ namespace sd { int cnt = -1; std::vector in; - auto listVar = std::make_shared(list, "", -1); - varSpace.putVariable(-1, listVar); - in.push_back(-1); - cnt--; + // first input must be our NDArrayList, except create_list op. it creates list itself. + if (getOpName() != "create_list") { + auto listVar = std::make_shared(list, "", cnt); + varSpace.putVariable(cnt, listVar); + in.push_back(cnt--); + } for (auto v: inputs) { auto var = std::make_shared(*v, "", cnt); diff --git a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp index b3444f1fc60c..5f22d48a554e 100644 --- a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp @@ -45,9 +45,6 @@ TEST_F(ListOperationsTests, BasicTest_Write_1) { auto result2 = op.execute(list, {&x}, {}, {2}); ASSERT_EQ(2, list.elements()); - - - } TEST_F(ListOperationsTests, BasicTest_Stack_1) { @@ -91,8 +88,8 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { auto result = op.execute(list, {&x}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(list.elements(), 10); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(10, list.elements()); // auto z = result.at(0); // z->printShapeInfo("The first of"); @@ -264,8 +261,8 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) { int cnt1 = 0; int cnt2 = 0; for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {5}); - row->assign((double) e); + auto row = NDArrayFactory::create('c', {5}); + row.assign((double) e); tads.at(e).assign(row); if (e < 2) @@ -274,8 +271,6 @@ TEST_F(ListOperationsTests, BasicTest_Split_1) { tads1.at(cnt1++).assign(row); else tads2.at(cnt2++).assign(row); - - delete row; } sd::ops::split_list op; @@ -330,21 +325,18 @@ TEST_F(ListOperationsTests, BasicTest_Scatter_1) { } TEST_F(ListOperationsTests, BasicTest_Clone_1) { - auto list = new NDArrayList(0, true); + NDArrayList list(0, true); VariableSpace variableSpace; - auto var = new Variable(); - //var->setNDArrayList(list); - - //variableSpace.putVariable(-1, var); - //variableSpace.trackList(list); + auto var = std::make_shared(list, "", -1); + variableSpace.putVariable(-1, var); Context block(1, &variableSpace); block.pickInput(-1); sd::ops::clone_list op; - ASSERT_TRUE(list == block.variable(0)->getNDArrayList().get()); + //ASSERT_TRUE(list == block.variable(0)->getNDArrayList().get()); auto result = op.execute(&block); @@ -352,11 +344,11 @@ TEST_F(ListOperationsTests, BasicTest_Clone_1) { auto resVar = variableSpace.getVariable(1); - auto resList = resVar->getNDArrayList(); + auto resList = resVar->getNDArrayList().get(); ASSERT_TRUE( resList != nullptr); - ASSERT_TRUE(list->equals(*resList)); + ASSERT_TRUE(list.equals(*resList)); } TEST_F(ListOperationsTests, BasicTest_Gather_1) { From 362bec16750e143f0dc7e6988c20720e8c1e6518 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 21 Apr 2020 13:23:15 +0300 Subject: [PATCH 104/233] some references here and there Signed-off-by: raver119 --- libnd4j/include/graph/Graph.h | 6 +- libnd4j/include/graph/OptimizedGraph.h | 11 +- .../include/graph/execution/ExecutionLayer.h | 4 +- libnd4j/include/graph/execution/OpSequence.h | 6 +- .../graph/execution/impl/ExecutionLayer.cpp | 4 +- .../graph/execution/impl/OpSequence.cpp | 10 +- libnd4j/include/graph/impl/Graph.cpp | 10 +- libnd4j/include/graph/impl/OptimizedGraph.cpp | 16 + .../layers_tests/ExecutionLayerTests.cpp | 63 ++++ .../layers_tests/GraphAnalysisTests.cpp | 342 +++++++++--------- 10 files changed, 277 insertions(+), 195 deletions(-) create mode 100644 libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 94829cb44c54..f6fb62b2bd8a 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -72,6 +72,10 @@ namespace sd { void printOutNode(const Node &node) const; std::vector _placeholders; + + mutable OptimizedGraph _optimized; + + mutable std::mutex _optimizedLock; public: Graph(const FlatGraph *flatGraph = nullptr, const GraphMemoryManager &memoryManager = GraphMemoryManager()); @@ -168,7 +172,7 @@ namespace sd { FORCEINLINE bool built(); - OptimizedGraph optimizedGraph() const; + const OptimizedGraph& optimizedGraph() const; /** * This method executes this Graph instance and returns execution results diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index c675ee9ff2e6..317c5a6f1568 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -47,9 +47,12 @@ namespace sd { GraphMemoryManager *_memoryManager = nullptr; Graph *_originalGraph = nullptr; - std::mutex _mutex; + mutable std::mutex _mutex; + + mutable size_t _size = 0; public: OptimizedGraph(Graph *original); + OptimizedGraph() = default; ~OptimizedGraph() = default; OptimizedGraph(const OptimizedGraph& other) noexcept; @@ -96,6 +99,11 @@ namespace sd { */ const Graph& originalGraph() const; + /** + * This method returns number of nodes in this graph instance + * @return + */ + size_t size() const; protected: /* * optimize original graph @@ -176,7 +184,6 @@ namespace sd { bool isInBranching() const { return bInBranching; } bool isOutBranching() const { return bOutBranching; } bool isProcessed() const { return bProcessed; } - }; } diff --git a/libnd4j/include/graph/execution/ExecutionLayer.h b/libnd4j/include/graph/execution/ExecutionLayer.h index f119de8a778d..58dcfb8985a8 100644 --- a/libnd4j/include/graph/execution/ExecutionLayer.h +++ b/libnd4j/include/graph/execution/ExecutionLayer.h @@ -55,8 +55,8 @@ namespace sd { * This method returns specified OpSequence from this layer * @return */ - OpSequence at(uint64_t index) const; - OpSequence operator[](uint64_t index) const; + const OpSequence& at(uint64_t index) const; + const OpSequence& operator[](uint64_t index) const; /** * This method appends OpSequence to the end of this layer diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index 5ae2c97b759b..1268baba4fc7 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -74,8 +74,8 @@ namespace sd { * @param index * @return */ - ExecutionTask at(uint64_t index) const; - ExecutionTask operator[](uint64_t index) const; + const ExecutionTask& at(uint64_t index) const; + const ExecutionTask& operator[](uint64_t index) const; /** * This method allows to add DeclarableOp to the end of execution queue @@ -101,7 +101,7 @@ namespace sd { OpSequence & _container; public: explicit iterator(OpSequence & container, uint64_t index = 0); - ExecutionTask operator*() const; + const ExecutionTask& operator*() const; iterator & operator++(); iterator & operator++(int); bool operator!=(const iterator &) const; diff --git a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp index 64db317964e0..65b91505d8f2 100644 --- a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp +++ b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp @@ -31,11 +31,11 @@ namespace sd { return _sequences.size(); } - OpSequence ExecutionLayer::at(uint64_t index) const { + const OpSequence& ExecutionLayer::at(uint64_t index) const { return _sequences[index]; } - OpSequence ExecutionLayer::operator[](uint64_t index) const { + const OpSequence& ExecutionLayer::operator[](uint64_t index) const { return at(index); } diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 2e609863806e..e3726e64e761 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -29,12 +29,13 @@ namespace sd { OpSequence::OpSequence(const std::vector &ops, const int deviceId) { _deviceId = deviceId; - for (const auto v : ops) _ops.emplace_back(v); } OpSequence::OpSequence(const OpSequence& other) noexcept{ + _ops.clear(); + for (const auto v : other._ops) _ops.emplace_back(v); } @@ -58,6 +59,7 @@ namespace sd { if (this == &other) return *this; + _ops.clear(); for (const auto v : other._ops) _ops.emplace_back(v); @@ -68,11 +70,11 @@ namespace sd { return _deviceId; } - ExecutionTask OpSequence::at(uint64_t index) const { + const ExecutionTask& OpSequence::at(uint64_t index) const { return _ops[index]; } - ExecutionTask OpSequence::operator[](uint64_t index) const { + const ExecutionTask& OpSequence::operator[](uint64_t index) const { return at(index); } @@ -104,7 +106,7 @@ namespace sd { // } - ExecutionTask OpSequence::iterator::operator*() const { + const ExecutionTask& OpSequence::iterator::operator*() const { return _container._ops[_position]; } diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 3dfaadc5457a..694ffaf126e4 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -665,10 +665,14 @@ namespace sd { return _memoryMaager; } - OptimizedGraph Graph::optimizedGraph() const { + const OptimizedGraph& Graph::optimizedGraph() const { + std::lock_guard lock(_optimizedLock); - OptimizedGraph optGraph(const_cast(this)); - return optGraph; + // optionally rebuild optimized graph, if it's out of date + if (_optimized.size() != size()) + _optimized = OptimizedGraph(const_cast(this)); + + return _optimized; } } } diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 1c78e99409cd..563d918b3e93 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -61,6 +61,20 @@ namespace sd { return *this; } + size_t OptimizedGraph::size() const { + std::lock_guard lock(_mutex); + + std::vector seq; + if (_size == 0) + for (const auto &v:_onion) { + for (int e = 0; e < v.second.width(); e++) { + _size += v.second.at(0).length(); + } + } + + return _size; + } + uint64_t OptimizedGraph::layers() const { return _onion.size(); } @@ -72,6 +86,7 @@ namespace sd { void OptimizedGraph::append(const std::vector &layer) { std::lock_guard lock(_mutex); _onion[_onion.size()] = layer; + _size = 0; } void OptimizedGraph::append(OpSequence &sequence) { @@ -81,6 +96,7 @@ namespace sd { void OptimizedGraph::append(const ExecutionLayer &layer) { std::lock_guard lock(_mutex); _onion[_onion.size()] = layer; + _size = 0; } const GraphMemoryManager &OptimizedGraph::memoryManager() const { diff --git a/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp b/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp new file mode 100644 index 000000000000..f80178d19d6a --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp @@ -0,0 +1,63 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class ExecutionLayerTests : public testing::Test { +public: + ExecutionLayerTests() { + /// + } +}; + +TEST_F(ExecutionLayerTests, test_reassign_1) { + ExecutionLayer layer; + OpSequence sequence1, sequence2; + + ops::add op1; + ops::multiply op2; + ops::divide op3; + + Context ctx1(1); + Context ctx2(2); + Context ctx3(3); + + sequence1.append(&op1, ctx1); + sequence2.append(&op2, ctx2); + sequence2.append(&op3, ctx3); + + layer.append(sequence1); + layer.append(sequence2); + + auto seq = layer[0]; + ASSERT_EQ(1, seq.length()); + + seq = layer[1]; + ASSERT_EQ(2, seq.length()); +} + diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 933e3c52b894..71c6a73a56ba 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -103,6 +103,9 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { auto optimized = graph.optimizedGraph(); + // graph size must stay the same + ASSERT_EQ(3, graph.size()); + // we expect that OptimizedGraph has exactly 2 layers ASSERT_EQ(2, optimized.layers()); @@ -111,12 +114,12 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { // we expect layer has exactly 1 OpSequence ASSERT_EQ(1, layer0.width()); - auto sequence = layer0[0]; + ; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(1, layer0[0].length()); - ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); // checking second layer now auto layer1 = optimized.layer(1); @@ -124,15 +127,12 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_2) { // we expect layer has exactly 2 OpSequences ASSERT_EQ(2, layer1.width()); - sequence = layer1[0]; - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); - sequence = layer1[1]; - - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_3) { @@ -177,13 +177,13 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_3) { // we expect layer has exactly 1 OpSequence ASSERT_EQ(1, layer0.width()); - auto sequence = layer0[0]; + //auto sequence = layer0[0]; // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(2, sequence.length()); + ASSERT_EQ(2, layer0[0].length()); - ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); - ASSERT_EQ(6, sequence.at(1).protoContext().nodeId()); + ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); + ASSERT_EQ(6, layer0[0].at(1).protoContext().nodeId()); // checking second layer now auto layer1 = optimized.layer(1); @@ -191,15 +191,15 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_3) { // we expect layer has exactly 2 OpSequences ASSERT_EQ(2, layer1.width()); - sequence = layer1[0]; + //sequence = layer1[0]; - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); - sequence = layer1[1]; + //sequence = layer1[1]; - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); // checking last layer @@ -207,12 +207,11 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_3) { // we expect layer has exactly 1 OpSequence ASSERT_EQ(1, layer2.width()); - sequence = layer2[0]; + //sequence = layer2[0]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - - ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_4) { @@ -274,16 +273,14 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_4) { // we expect layer has exactly 2 OpSequence ASSERT_EQ(2, layer0.width()); - auto sequence = layer0[0]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); - sequence = layer0[1]; + ASSERT_EQ(1, layer0[0].length()); + ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer0[1].length()); + ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); // checking second layer now auto layer1 = optimized.layer(1); @@ -291,46 +288,37 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_4) { // we expect layer has exactly 3 OpSequences ASSERT_EQ(3, layer1.width()); - sequence = layer1[0]; - - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(9, layer1[0].at(0).protoContext().nodeId()); - sequence = layer1[1]; + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(10, layer1[1].at(0).protoContext().nodeId()); - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(10, sequence.at(0).protoContext().nodeId()); - - sequence = layer1[2]; - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(11, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[2].length()); + ASSERT_EQ(11, layer1[2].at(0).protoContext().nodeId()); auto layer2 = optimized.layer(2); // we expect layer has exactly 2 OpSequence ASSERT_EQ(2, layer2.width()); - sequence = layer2[0]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(12, sequence.at(0).protoContext().nodeId()); - - sequence = layer2[1]; + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(12, layer2[0].at(0).protoContext().nodeId()); // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(13, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(13, layer2[1].at(0).protoContext().nodeId()); // checking last layer auto layer3 = optimized.layer(3); // we expect layer has exactly 1 OpSequence ASSERT_EQ(1, layer3.width()); - sequence = layer3[0]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(14, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer3[0].length()); + ASSERT_EQ(14, layer3[0].at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_5) { @@ -384,18 +372,18 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_5) { // we expect layer has exactly 2 OpSequence ASSERT_EQ(2, layer0.width()); - auto sequence = layer0[0]; + //auto sequence = layer0[0]; // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(1, layer0[0].length()); - ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); - sequence = layer0[1]; + //sequence = layer0[1]; // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer0[1].length()); + ASSERT_EQ(6, layer0[1].at(0).protoContext().nodeId()); // checking second layer now auto layer1 = optimized.layer(1); @@ -403,48 +391,48 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_5) { // we expect layer has exactly 2 OpSequences ASSERT_EQ(2, layer1.width()); - sequence = layer1[0]; + //sequence = layer1[0]; - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); - sequence = layer1[1]; + //sequence = layer1[1]; - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); // checking before last layer auto layer2 = optimized.layer(2); // we expect layer has exactly 2 OpSequence ASSERT_EQ(2, layer2.width()); - sequence = layer2[0]; + //sequence = layer2[0]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); - sequence = layer2[1]; + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); + //sequence = layer2[1]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(10, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(10, layer2[1].at(0).protoContext().nodeId()); // checking last layer auto layer3 = optimized.layer(3); // we expect layer has exactly 2 OpSequence ASSERT_EQ(2, layer3.width()); - sequence = layer3[0]; + //sequence = layer3[0]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(11, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer3[0].length()); + ASSERT_EQ(11, layer3[0].at(0).protoContext().nodeId()); - sequence = layer3[1]; + //sequence = layer3[1]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(12, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer3[1].length()); + ASSERT_EQ(12, layer3[1].at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_6) { @@ -508,10 +496,11 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_6) { // we expect layer has exactly 1 OpSequence ASSERT_EQ(1, layer0.width()); - auto sequence = layer0[0]; + + //auto sequence = layer0[0]; // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer0[0].length()); + ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); // checking second layer now auto layer1 = optimized.layer(1); @@ -519,64 +508,61 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_6) { // we expect layer has exactly 2 OpSequences ASSERT_EQ(2, layer1.width()); - sequence = layer1[0]; - - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + //sequence = layer1[0]; + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(8, layer1[0].at(0).protoContext().nodeId()); - sequence = layer1[1]; - - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); + //sequence = layer1[1]; + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(9, layer1[1].at(0).protoContext().nodeId()); // checking midle layer auto layer2 = optimized.layer(2); // we expect layer has exactly 2 OpSequence ASSERT_EQ(3, layer2.width()); - sequence = layer2[0]; + //sequence = layer2[0]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(10, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); - sequence = layer2[1]; + //sequence = layer2[1]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(11, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(11, layer2[1].at(0).protoContext().nodeId()); - sequence = layer2[2]; + //sequence = layer2[2]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(12, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer2[2].length()); + ASSERT_EQ(12, layer2[2].at(0).protoContext().nodeId()); // checking before last layer auto layer3 = optimized.layer(3); // we expect layer has exactly 2 OpSequence ASSERT_EQ(2, layer3.width()); - sequence = layer3[0]; + //sequence = layer3[0]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(13, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer3[0].length()); + ASSERT_EQ(13, layer3[0].at(0).protoContext().nodeId()); - sequence = layer3[1]; - + //sequence = layer3[1]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(14, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer3[1].length()); + ASSERT_EQ(14, layer3[1].at(0).protoContext().nodeId()); // checking last layer auto layer4 = optimized.layer(4); // we expect layer has exactly 2 OpSequence ASSERT_EQ(1, layer4.width()); - sequence = layer4[0]; + //sequence = layer4[0]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(15, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer4[0].length()); + ASSERT_EQ(15, layer4[0].at(0).protoContext().nodeId()); } @@ -619,55 +605,54 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_7) { // we expect layer has exactly 1 OpSequence ASSERT_EQ(1, layer0.width()); - auto sequence = layer0[0]; + //auto sequence = layer0[0]; // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(4, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer0[0].length()); + ASSERT_EQ(4, layer0[0].at(0).protoContext().nodeId()); // checking second layer now auto layer1 = optimized.layer(1); // we expect layer has exactly 2 OpSequences ASSERT_EQ(1, layer1.width()); + //sequence = layer1[0]; - sequence = layer1[0]; - - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(5, layer1[0].at(0).protoContext().nodeId()); // checking layer 2 auto layer2 = optimized.layer(2); // we expect layer has exactly 1 OpSequence ASSERT_EQ(1, layer2.width()); - sequence = layer2[0]; + //sequence = layer2[0]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(6, layer2[0].at(0).protoContext().nodeId()); // checking layer 3 auto layer3 = optimized.layer(3); // we expect layer has exactly 1 OpSequence ASSERT_EQ(1, layer3.width()); - sequence = layer3[0]; + //sequence = layer3[0]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer3[0].length()); + ASSERT_EQ(7, layer3[0].at(0).protoContext().nodeId()); // checking layer 3 auto layer4 = optimized.layer(4); // we expect layer has exactly 1 OpSequence ASSERT_EQ(1, layer4.width()); - sequence = layer4[0]; + //sequence = layer4[0]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer4[0].length()); + ASSERT_EQ(8, layer4[0].at(0).protoContext().nodeId()); } @@ -722,21 +707,21 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_8) { // we expect layer has exactly 3 OpSequence ASSERT_EQ(3, layer0.width()); - auto sequence = layer0[0]; + //auto sequence = layer0[0]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer0[0].length()); + ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); - sequence = layer0[1]; + //sequence = layer0[1]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer0[1].length()); + ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); - sequence = layer0[2]; + //sequence = layer0[2]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer0[2].length()); + ASSERT_EQ(9, layer0[2].at(0).protoContext().nodeId()); // checking second layer now auto layer1 = optimized.layer(1); @@ -744,18 +729,18 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_8) { // we expect layer has exactly 3 OpSequences ASSERT_EQ(3, layer1.width()); - sequence = layer1[0]; - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(10, sequence.at(0).protoContext().nodeId()); + //sequence = layer1[0]; + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(10, layer1[0].at(0).protoContext().nodeId()); - sequence = layer1[1]; + //sequence = layer1[1]; - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(11, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(11, layer1[1].at(0).protoContext().nodeId()); - sequence = layer1[2]; - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(12, sequence.at(0).protoContext().nodeId()); + //sequence = layer1[2]; + ASSERT_EQ(1, layer1[2].length()); + ASSERT_EQ(12, layer1[2].at(0).protoContext().nodeId()); } @@ -822,77 +807,78 @@ TEST_F(GraphAnalysisTests, basic_toposort_test_9) { auto layer = optimized.layer(0); // we expect layer has exactly 1 OpSequence ASSERT_EQ(1, layer.width()); - auto sequence = layer[0]; + //auto sequence = layer[0]; // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer[0].length()); + ASSERT_EQ(5, layer[0].at(0).protoContext().nodeId()); auto layer1 = optimized.layer(1); // we expect layer has exactly 4 OpSequence ASSERT_EQ(4, layer1.width()); - sequence = layer1[0]; + //sequence = layer1[0]; + // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); - sequence = layer1[1]; + //sequence = layer1[1]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); - sequence = layer1[2]; + //sequence = layer1[2]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[2].length()); + ASSERT_EQ(8, layer1[2].at(0).protoContext().nodeId()); - sequence = layer1[3]; + //sequence = layer1[3]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(9, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[3].length()); + ASSERT_EQ(9, layer1[3].at(0).protoContext().nodeId()); auto layer2 = optimized.layer(2); // we expect layer has exactly 4 OpSequence ASSERT_EQ(8, layer2.width()); - sequence = layer2[0]; + //sequence = layer2[0]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(10, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); - sequence = layer2[1]; + //sequence = layer2[1]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(14, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(14, layer2[1].at(0).protoContext().nodeId()); - sequence = layer2[2]; + //sequence = layer2[2]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(11, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer2[2].length()); + ASSERT_EQ(11, layer2[2].at(0).protoContext().nodeId()); - sequence = layer2[3]; + //sequence = layer2[3]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(15, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer2[3].length()); + ASSERT_EQ(15, layer2[3].at(0).protoContext().nodeId()); - sequence = layer2[4]; + //sequence = layer2[4]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(12, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer2[4].length()); + ASSERT_EQ(12, layer2[4].at(0).protoContext().nodeId()); - sequence = layer2[5]; + //sequence = layer2[5]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(16, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer2[5].length()); + ASSERT_EQ(16, layer2[5].at(0).protoContext().nodeId()); - sequence = layer2[6]; + //sequence = layer2[6]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(13, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer2[6].length()); + ASSERT_EQ(13, layer2[6].at(0).protoContext().nodeId()); - sequence = layer2[7]; + //sequence = layer2[7]; // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(17, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer2[7].length()); + ASSERT_EQ(17, layer2[7].at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_10) { From 9667524f1d40cec66443228dc8e2f5e6a893485b Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 22 Apr 2020 07:52:10 +0300 Subject: [PATCH 105/233] minor tests tweaks Signed-off-by: raver119 --- libnd4j/include/graph/execution/impl/GraphExecutor.cpp | 2 +- libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 07bb0da02f43..5c4ddc523129 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -45,7 +45,7 @@ namespace sd { Nd4jStatus GraphExecutor::execute(std::shared_ptr op, const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId) const { - //auto varSpace = graph.originalGraph().variableSpace(); + //auto &varSpace = graph.originalGraph().variableSpace(); //auto ctx = prepareContext(contextPrototype, varSpace, graph.memoryManager()); //return op->execute(&ctx); throw std::runtime_error("GraphExecutor::execute - Not implemented yet"); diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index f02c9feb54f9..970e8c9d577c 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -44,7 +44,7 @@ TEST_F(GraphExecutorTests, test_basic_exec_1) { GraphMemoryManager memoryManager; Graph graph; - OptimizedGraph optimizedGraph(&graph); + OptimizedGraph optimizedGraph; OpSequence sequence; optimizedGraph.append(sequence); @@ -73,7 +73,7 @@ TEST_F(GraphExecutorTests, test_basic_exec_2) { graph.addNode(m, {"A", "B"}); graph.addNode(a, {"mul", "C"}); - OptimizedGraph optimizedGraph(&graph); + OptimizedGraph optimizedGraph; OpSequence sequence; ASSERT_EQ(2, m.protoContext().inputs().size()); From 9af3df14f74ad163e5452102412bdb283d4ce45e Mon Sep 17 00:00:00 2001 From: Yurii Date: Wed, 22 Apr 2020 17:31:00 +0300 Subject: [PATCH 106/233] - get rid of undefined ndarrays in lstmLayer, use pointers instead Signed-off-by: Yurii --- libnd4j/include/array/impl/NDArray.cpp | 32 +++++++-------- .../ops/declarable/helpers/impl/lstmLayer.cpp | 40 +++++++++---------- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/libnd4j/include/array/impl/NDArray.cpp b/libnd4j/include/array/impl/NDArray.cpp index 3880454ac536..9fca12aaa744 100644 --- a/libnd4j/include/array/impl/NDArray.cpp +++ b/libnd4j/include/array/impl/NDArray.cpp @@ -61,16 +61,16 @@ NDArray::NDArray(const NDArray& other) { else _buffer = std::make_shared(); */ - _buffer = other._buffer; - _shapeInfo = other._shapeInfo; + _buffer = other._buffer; + _shapeInfo = other._shapeInfo; _shapeInfoD = other._shapeInfoD; - _length = other._length; + _length = other._length; _isAttached = other._isAttached; - _isView = other._isView; - _context = other._context; - _dataType = other._dataType; - _deviceId = other._deviceId; - _offset = other._offset; + _isView = other._isView; + _context = other._context; + _dataType = other._dataType; + _deviceId = other._deviceId; + _offset = other._offset; } //////////////////////////////////////////////////////////////////////// @@ -885,16 +885,16 @@ NDArray::NDArray(const std::vector& shape, const std::vector= 0; --t) { - const NDArray dLdhh = dLdh ? dLdhSet->at(t) : NDArrayFactory::undefined(); - const NDArray dLdhhL = (t == sL-1 && dLdhL) ? *dLdhL : NDArrayFactory::undefined(); - const NDArray dLdccL = (t == sL-1 && dLdcL) ? *dLdcL : NDArrayFactory::undefined(); - lstmLayerCellBp(&xSet->at(t), Wx, Wr, b, &hSet->at(t), &cSet->at(t), Wp, &dLdhh, &dLdhhL, &dLdccL, + const NDArray* dLdhh = dLdh ? &dLdhSet->at(t) : nullptr; + const NDArray* dLdhhL = (t == sL-1 && dLdhL) ? dLdhL : nullptr; + const NDArray* dLdccL = (t == sL-1 && dLdcL) ? dLdcL : nullptr; + lstmLayerCellBp(&xSet->at(t), Wx, Wr, b, &hSet->at(t), &cSet->at(t), Wp, dLdhh, dLdhhL, dLdccL, &zSet->at(t), &aSet->at(t), &cSet->at(t+1), params, &dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); } } @@ -1062,10 +1062,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = limit-1; t >= 0; --t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray dLdhh = dLdh ? dLdhSet->at(ind) : NDArrayFactory::undefined(); - const NDArray dLdhhL = (t == limit-1 && dLdhL) ? dLdhLSet->at(e) : NDArrayFactory::undefined(); - const NDArray dLdccL = (t == limit-1 && dLdcL) ? dLdcLSet->at(e) : NDArrayFactory::undefined(); - lstmLayerCellBp(&xSet->at(ind), Wx, Wr, b, &hSet->at(t*bS + e), &cSet->at(t*bS + e), Wp, &dLdhh, &dLdhhL, &dLdccL, + const NDArray* dLdhh = dLdh ? &dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = (t == limit-1 && dLdhL) ? &dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = (t == limit-1 && dLdcL) ? &dLdcLSet->at(e) : nullptr; + lstmLayerCellBp(&xSet->at(ind), Wx, Wr, b, &hSet->at(t*bS + e), &cSet->at(t*bS + e), Wp, dLdhh, dLdhhL, dLdccL, &zSet->at(t*bS + e), &aSet->at(t*bS + e), &cSet->at((t+1)*bS + e), params, &dLdxSet->at(ind), dLdWx, dLdWr, &dLdh0Set->at(e), &dLdc0Set->at(e), dLdb, dLdWp); } @@ -1094,10 +1094,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = 0; t < sL; ++t) { - const NDArray dLdhh = dLdh ? dLdhSet->at(t) : NDArrayFactory::undefined(); - const NDArray dLdhhL = (t == 0 && dLdhL) ? *dLdhL : NDArrayFactory::undefined(); - const NDArray dLdccL = (t == 0 && dLdcL) ? *dLdcL : NDArrayFactory::undefined(); - lstmLayerCellBp(&xSet->at(t), Wx, Wr, b, &hSet->at(t+1), &cSet->at(t+1), Wp, &dLdhh, &dLdhhL, &dLdccL, + const NDArray* dLdhh = dLdh ? &dLdhSet->at(t) : nullptr; + const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhL : nullptr; + const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcL : nullptr; + lstmLayerCellBp(&xSet->at(t), Wx, Wr, b, &hSet->at(t+1), &cSet->at(t+1), Wp, dLdhh, dLdhhL, dLdccL, &zSet->at(t), &aSet->at(t), &cSet->at(t), params, &dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); } } @@ -1129,10 +1129,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = sL-limit; t < sL; ++t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray dLdhh = dLdh ? dLdhSet->at(ind) : NDArrayFactory::undefined(); - const NDArray dLdhhL = (t == sL-limit && dLdhL) ? dLdhLSet->at(e) : NDArrayFactory::undefined(); - const NDArray dLdccL = (t == sL-limit && dLdcL) ? dLdcLSet->at(e) : NDArrayFactory::undefined(); - lstmLayerCellBp(&xSet->at(ind), Wx, Wr, b, &hSet->at((t+1)*bS + e), &cSet->at((t+1)*bS + e), Wp, &dLdhh, &dLdhhL, &dLdccL, + const NDArray* dLdhh = dLdh ? &dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = (t == sL-limit && dLdhL) ? &dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = (t == sL-limit && dLdcL) ? &dLdcLSet->at(e) : nullptr; + lstmLayerCellBp(&xSet->at(ind), Wx, Wr, b, &hSet->at((t+1)*bS + e), &cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, &zSet->at(t*bS + e), &aSet->at(t*bS + e), &cSet->at(t*bS + e), params, &dLdxSet->at(ind), dLdWx, dLdWr, &dLdh0Set->at(e), &dLdc0Set->at(e), dLdb, dLdWp); } @@ -1169,10 +1169,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = 0; t < limit; ++t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray dLdhh = dLdh ? dLdhSet->at(ind) : NDArrayFactory::undefined(); - const NDArray dLdhhL = (t == 0 && dLdhL) ? dLdhLSet->at(e) : NDArrayFactory::undefined(); - const NDArray dLdccL = (t == 0 && dLdcL) ? dLdcLSet->at(e) : NDArrayFactory::undefined(); - lstmLayerCellBp(&xSet->at(ind), Wx, Wr, b, &hSet->at((t+1)*bS + e), &cSet->at((t+1)*bS + e), Wp, &dLdhh, &dLdhhL, &dLdccL, + const NDArray* dLdhh = dLdh ? &dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = (t == 0 && dLdhL) ? &dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = (t == 0 && dLdcL) ? &dLdcLSet->at(e) : nullptr; + lstmLayerCellBp(&xSet->at(ind), Wx, Wr, b, &hSet->at((t+1)*bS + e), &cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, &zSet->at(t*bS + e), &aSet->at(t*bS + e), &cSet->at(t*bS + e), params, &dLdxSet->at(ind), dLdWx, dLdWr, &dLdh0Set->at(e), &dLdc0Set->at(e), dLdb, dLdWp); } From 30e0542fa7b5fcdb0994b50a568ee6e062edbf93 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 24 Apr 2020 08:28:15 +0300 Subject: [PATCH 107/233] test nano fix Signed-off-by: raver119 --- libnd4j/tests_cpu/layers_tests/RNGTests.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 1176996a532a..81e92396a03d 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -786,8 +786,8 @@ TEST_F(RNGTests, Test_ExponentialDistribution_2_SGA) { ASSERT_FALSE(exp0.equalsTo(z)); // // z->printBuffer("\nExponential2+"); - auto mean = z->reduceNumber(reduce::Mean); - auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); + auto mean = z.reduceNumber(reduce::Mean); + auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); ASSERT_FALSE(nexp0->equalsTo(z)); @@ -816,8 +816,8 @@ TEST_F(RNGTests, Test_ExponentialDistribution_3_SGA) { //ASSERT_FALSE(exp0.equalsTo(z)); // // z->printBuffer("\nExponential2+"); - auto mean = z->reduceNumber(reduce::Mean); - auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); + auto mean = z.reduceNumber(reduce::Mean); + auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); mean.printBuffer("Mean"); variance.printBuffer("Variance"); ASSERT_NEAR(mean.e(0), 1.f, 1.e-2f); From f482c9802f7d74e073e2487726a162dcdfb186ac Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 24 Apr 2020 08:31:05 +0300 Subject: [PATCH 108/233] one more test fixed Signed-off-by: raver119 --- libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp index 0e13966ea87a..c4dbc2fcc2be 100644 --- a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp @@ -62,7 +62,7 @@ TEST_F(OpSequenceTests, test_iterator_1) { ASSERT_EQ(3, cnt); - OptimizedGraph optimizedGraph(&graph); + OptimizedGraph optimizedGraph; ASSERT_EQ(0, optimizedGraph.layers()); optimizedGraph.append(sequence); From 2de6bc7529da6d7587af23622e48f86b4254af57 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 27 Apr 2020 10:54:46 +0300 Subject: [PATCH 109/233] Graph::execute now handles VariableProxy creation --- libnd4j/include/graph/VariableProxy.h | 11 ++++++--- .../include/graph/execution/GraphExecutor.h | 9 ++++---- .../graph/execution/impl/GraphExecutor.cpp | 23 +++++++++---------- libnd4j/include/graph/impl/Graph.cpp | 14 +++++++---- libnd4j/include/graph/impl/OptimizedGraph.cpp | 4 ++++ libnd4j/include/graph/impl/VariableProxy.cpp | 2 +- libnd4j/include/graph/impl/VariableSpace.cpp | 4 +++- libnd4j/include/legacy/cpu/NativeOps.cpp | 2 +- .../include/ops/declarable/OpRegistrator.h | 2 ++ .../layers_tests/GraphExecutorTests.cpp | 6 +++-- .../tests_cpu/layers_tests/GraphTests2.cpp | 2 +- 11 files changed, 50 insertions(+), 29 deletions(-) diff --git a/libnd4j/include/graph/VariableProxy.h b/libnd4j/include/graph/VariableProxy.h index eea5ababb3cc..c3e427592d1e 100644 --- a/libnd4j/include/graph/VariableProxy.h +++ b/libnd4j/include/graph/VariableProxy.h @@ -18,16 +18,19 @@ // @author raver119@gmail.com // +#ifndef SD_VARIABLEPROXY_H +#define SD_VARIABLEPROXY_H + #include namespace sd { namespace graph { class SD_EXPORT VariableProxy: public VariableSpace { protected: - VariableSpace* _backed = nullptr; + const VariableSpace* _backed; VariableSpace* _current = nullptr; public: - explicit VariableProxy(VariableSpace* reference); + explicit VariableProxy(const VariableSpace* reference); ~VariableProxy(); virtual VariableSpace& operator=(const VariableSpace& other); @@ -79,4 +82,6 @@ namespace sd { virtual Stash* stash() const override; }; } -} \ No newline at end of file +} + +#endif // SD_VARIABLEPROXY_H \ No newline at end of file diff --git a/libnd4j/include/graph/execution/GraphExecutor.h b/libnd4j/include/graph/execution/GraphExecutor.h index 3d7364dce9d3..cf8475690207 100644 --- a/libnd4j/include/graph/execution/GraphExecutor.h +++ b/libnd4j/include/graph/execution/GraphExecutor.h @@ -24,6 +24,7 @@ #include #include #include +#include namespace sd { namespace graph { @@ -31,7 +32,7 @@ namespace sd { class SD_EXPORT GraphExecutor { protected: - virtual Context prepareContext(const ContextPrototype &contextPrototype, VariableSpace &variableSpace, const GraphMemoryManager &memoryManager) const; + virtual Context prepareContext(const ContextPrototype &contextPrototype, VariableProxy &variableSpace, const GraphMemoryManager &memoryManager) const; /* * preprocessor call involves: @@ -58,7 +59,7 @@ namespace sd { * @param graph * @return */ - virtual Nd4jStatus execute(const OptimizedGraph &graph) const; + virtual Nd4jStatus execute(const OptimizedGraph &graph, VariableProxy &proxy) const; /** * This method executes OpSequence @@ -66,7 +67,7 @@ namespace sd { * @param deviceId - this argument allows to override device affinity specified in OpSequence, keep it < 0 to follow OpSequence * @return */ - virtual Nd4jStatus execute(const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId = -1) const; + virtual Nd4jStatus execute(const OpSequence &sequence, const OptimizedGraph &graph, VariableProxy &proxy, int deviceId) const; /** * This method executes given op @@ -74,7 +75,7 @@ namespace sd { * @param contextPrototype * @return */ - virtual Nd4jStatus execute(std::shared_ptr op, const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId) const; + virtual Nd4jStatus execute(std::shared_ptr op, const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, VariableProxy &proxy, const int deviceId) const; }; } } diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 5c4ddc523129..94dc7b1cf10e 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -23,9 +23,9 @@ namespace sd { namespace graph { - Context GraphExecutor::prepareContext(const ContextPrototype &contextPrototype, VariableSpace &variableSpace, const GraphMemoryManager &memoryManager) const { + Context GraphExecutor::prepareContext(const ContextPrototype &contextPrototype, VariableProxy &variableProxy, const GraphMemoryManager &memoryManager) const { // TODO: maybe we'll want to do something here? - return Context(contextPrototype, &variableSpace, const_cast(&memoryManager)); + return Context(contextPrototype, &variableProxy, const_cast(&memoryManager)); } Nd4jStatus GraphExecutor::preprocess(sd::ops::DeclarableOp *op, Context &context) const { @@ -44,20 +44,19 @@ namespace sd { } - Nd4jStatus GraphExecutor::execute(std::shared_ptr op, const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId) const { - //auto &varSpace = graph.originalGraph().variableSpace(); - //auto ctx = prepareContext(contextPrototype, varSpace, graph.memoryManager()); - //return op->execute(&ctx); - throw std::runtime_error("GraphExecutor::execute - Not implemented yet"); + Nd4jStatus GraphExecutor::execute(std::shared_ptr op, const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, VariableProxy &proxy, const int deviceId) const { + auto ctx = prepareContext(contextPrototype, proxy, graph.memoryManager()); + return op->execute(&ctx); + //throw std::runtime_error("GraphExecutor::execute - Not implemented yet"); } - Nd4jStatus GraphExecutor::execute(const OpSequence &sequence, const OptimizedGraph &graph, const int deviceId) const { + Nd4jStatus GraphExecutor::execute(const OpSequence &sequence, const OptimizedGraph &graph, VariableProxy &proxy, const int deviceId) const { /* * this is a basic implementation that works without dispatching etc */ for (int e = 0; e < sequence.length(); e++) { auto v = sequence[e]; - auto result = execute(v.op(), v.protoContext(), sequence, graph, deviceId >= 0 ? deviceId : sequence.deviceId()); + auto result = execute(v.op(), v.protoContext(), sequence, graph, proxy, deviceId >= 0 ? deviceId : sequence.deviceId()); if (result != Status::OK()) return result; } @@ -65,7 +64,7 @@ namespace sd { return Status::OK(); } - Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph) const { + Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, VariableProxy &proxy) const { const auto numDevices = AffinityManager::numberOfDevices(); /* @@ -73,10 +72,10 @@ namespace sd { */ Nd4jStatus result = Status::OK(); for (uint64_t l = 0; l < graph.layers(); l++) { - auto layer = graph.layer(l); + const auto &layer = graph.layer(l); for (uint64_t o = 0; o < layer.width(); o++) { - execute(layer[o], graph); + execute(layer[o], graph, proxy, -1); } // optionally block until all sequences in this layer processed diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 694ffaf126e4..492d0fa66de8 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -557,6 +557,9 @@ namespace sd { } std::map Graph::execute(const std::map &dictionary, const std::vector &outputs, const GraphExecutor &executor) const { + // creating our proxy, we'll use it for actual execution + VariableProxy proxy(&_variableSpace); + // first of all we check existence of placeholders in dictionary int placeholdersCount = 0; for (const auto &v:dictionary) { @@ -572,6 +575,9 @@ namespace sd { if (shape != var->shape()) throw shape_mismatch_exception::build("Placeholder requires specific shape", var->shape(), shape); + // update the placeholder + proxy.putVariable(v.first, var->id(), var->index(), v.second); + // we must also check if all placeholders were resolved placeholdersCount++; } @@ -588,17 +594,17 @@ namespace sd { } // execute optimized version of this graph - auto status = executor.execute(optimizedGraph()); + auto status = executor.execute(optimizedGraph(), proxy); if (status != Status::OK()) throw graph_execution_exception("Graph execution failed, error code: ", status); - // fetch outputs from VariableSpace + // fetch outputs from our VariableProxy std::map result; for (const auto &v:outputs) { - if (!_variableSpace.hasVariable(v)) + if (!proxy.hasVariable(v)) throw unresolved_output_exception::build("Requested output doesn't exist after execution", v); - auto var = _variableSpace.getVariable(v); + auto var = proxy.getVariable(v); // TODO: we want to make sure ManagedDataBuffer doesn't leak here result[v] = *var->getNDArray(); diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 563d918b3e93..8aa8742166d7 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -34,6 +34,7 @@ namespace sd { OptimizedGraph::OptimizedGraph(const OptimizedGraph &other) noexcept { _onion = other._onion; _memoryManager = other._memoryManager; + _originalGraph = other._originalGraph; } OptimizedGraph &OptimizedGraph::operator=(const OptimizedGraph &other) noexcept { @@ -42,6 +43,7 @@ namespace sd { _onion = other._onion; _memoryManager = other._memoryManager; + _originalGraph = other._originalGraph; return *this; } @@ -49,6 +51,7 @@ namespace sd { OptimizedGraph::OptimizedGraph(OptimizedGraph &&other) noexcept { _onion = std::move(other._onion); _memoryManager = other._memoryManager; + _originalGraph = other._originalGraph; } OptimizedGraph &OptimizedGraph::operator=(OptimizedGraph &&other) noexcept { @@ -57,6 +60,7 @@ namespace sd { _onion = std::move(other._onion); _memoryManager = other._memoryManager; + _originalGraph = other._originalGraph; return *this; } diff --git a/libnd4j/include/graph/impl/VariableProxy.cpp b/libnd4j/include/graph/impl/VariableProxy.cpp index d90ed10d8a56..a89e8b5071e2 100644 --- a/libnd4j/include/graph/impl/VariableProxy.cpp +++ b/libnd4j/include/graph/impl/VariableProxy.cpp @@ -24,7 +24,7 @@ namespace sd { namespace graph { - VariableProxy::VariableProxy(VariableSpace* ref) { + VariableProxy::VariableProxy(const VariableSpace* ref) { if (ref == nullptr) _backed = new VariableSpace(); diff --git a/libnd4j/include/graph/impl/VariableSpace.cpp b/libnd4j/include/graph/impl/VariableSpace.cpp index 69991c8a61f8..818ef28b8b90 100644 --- a/libnd4j/include/graph/impl/VariableSpace.cpp +++ b/libnd4j/include/graph/impl/VariableSpace.cpp @@ -165,7 +165,9 @@ namespace sd { std::shared_ptr VariableSpace::putVariable(const std::string &name, int id, int idx, const NDArray &array) { - return std::shared_ptr(); + auto variable = std::make_shared(array, name, id, idx); + this->putVariable({id, idx}, variable); + return variable; } void VariableSpace::dropVariable(const std::string &pair) { diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp index ab9c47688c71..1d4d9f94c42b 100644 --- a/libnd4j/include/legacy/cpu/NativeOps.cpp +++ b/libnd4j/include/legacy/cpu/NativeOps.cpp @@ -2265,7 +2265,7 @@ const char* getAllOperations() { } Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs) { - + return 0; } void deleteResultWrapper(Nd4jPointer ptr) { diff --git a/libnd4j/include/ops/declarable/OpRegistrator.h b/libnd4j/include/ops/declarable/OpRegistrator.h index e055b7622a3f..a0954efe3c69 100644 --- a/libnd4j/include/ops/declarable/OpRegistrator.h +++ b/libnd4j/include/ops/declarable/OpRegistrator.h @@ -68,6 +68,7 @@ namespace sd { OpRegistrator() { nd4j_debug("OpRegistrator started\n",""); + /* #ifndef _RELEASE std::signal(SIGSEGV, &OpRegistrator::sigSegVHandler); std::signal(SIGINT, &OpRegistrator::sigIntHandler); @@ -77,6 +78,7 @@ namespace sd { std::signal(SIGTERM, &OpRegistrator::sigIntHandler); atexit(&OpRegistrator::exitHandler); #endif + */ }; MAP_IMPL _msvc; diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index 970e8c9d577c..56c3b368bed1 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -49,8 +49,9 @@ TEST_F(GraphExecutorTests, test_basic_exec_1) { optimizedGraph.append(sequence); + VariableProxy proxy(&graph.variableSpace()); GraphExecutor executor; - executor.execute(optimizedGraph); + executor.execute(optimizedGraph, proxy); } TEST_F(GraphExecutorTests, test_basic_exec_2) { @@ -87,8 +88,9 @@ TEST_F(GraphExecutorTests, test_basic_exec_2) { ASSERT_EQ(2, sequence.length()); ASSERT_EQ(1, optimizedGraph.layers()); + VariableProxy proxy(&graph.variableSpace()); GraphExecutor executor; - executor.execute(optimizedGraph); + executor.execute(optimizedGraph, proxy); // checking results by ID ASSERT_TRUE(graph.variableSpace().hasVariable(m.id())); diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp index 4af34615fcf3..117cf7265482 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp @@ -103,7 +103,7 @@ TEST_F(GraphTests2, test_placeholder_resolution_2) { graph.addPlaceholder("input", DataType::FLOAT32); - graph.addNode(Node(sd::ops::tanh(), "tanh_node"), {"input"}); + graph.addNode(Node(sd::ops::rationaltanh(), "tanh_node"), {"input"}); auto result = graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"tanh_node"}); From 94ba5f459b1df34ae2528f86fb12a1247585adda Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 27 Apr 2020 12:16:03 +0300 Subject: [PATCH 110/233] - GraphHolder changes - no more pointers there - few irrelevant tests removed --- libnd4j/include/array/impl/NDArray.cpp | 2 +- libnd4j/include/graph/GraphHolder.h | 46 ++----------- libnd4j/include/graph/impl/Graph.cpp | 5 +- libnd4j/include/graph/impl/GraphHolder.cpp | 69 +++++-------------- libnd4j/include/legacy/cpu/NativeOps.cpp | 25 +------ .../layers_tests/DeclarableOpsTests1.cpp | 2 +- .../layers_tests/GraphExecutorTests.cpp | 10 +-- .../layers_tests/GraphHolderTests.cpp | 39 +---------- .../tests_cpu/layers_tests/OneOffTests.cpp | 2 +- 9 files changed, 38 insertions(+), 162 deletions(-) diff --git a/libnd4j/include/array/impl/NDArray.cpp b/libnd4j/include/array/impl/NDArray.cpp index 9fca12aaa744..b225c0a3a733 100644 --- a/libnd4j/include/array/impl/NDArray.cpp +++ b/libnd4j/include/array/impl/NDArray.cpp @@ -1953,7 +1953,7 @@ Nd4jLong NDArray::argMax(std::initializer_list dimensions) { return max; } else - throw std::runtime_error("Not implemented yet"); + throw std::runtime_error("NDArray::argMax() - Not implemented yet"); } ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/graph/GraphHolder.h b/libnd4j/include/graph/GraphHolder.h index b200bf37cf75..a41b4457e228 100644 --- a/libnd4j/include/graph/GraphHolder.h +++ b/libnd4j/include/graph/GraphHolder.h @@ -31,64 +31,28 @@ namespace sd { class SD_EXPORT GraphHolder { private: static GraphHolder *_INSTANCE; - MAP_IMPL _graphF; + MAP_IMPL _graphs; - MAP_IMPL _locks; + std::mutex _mutex; GraphHolder() = default; ~GraphHolder() = default; public: static GraphHolder* getInstance(); - void registerGraph(Nd4jLong graphId, Graph *graph); - - Graph* cloneGraph(Nd4jLong graphId); + void registerGraph(Nd4jLong graphId, const Graph &graph); - Graph* pullGraph(Nd4jLong graphId); + Graph& graph(Nd4jLong graphId); void forgetGraph(Nd4jLong graphId); void dropGraph(Nd4jLong graphId); - void dropGraphAny(Nd4jLong graphId); - bool hasGraph(Nd4jLong graphId); - bool hasGraphAny(Nd4jLong graphId); - flatbuffers::Offset execute(Nd4jLong graphId, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request); - void replaceGraph(Nd4jLong graphId, Graph *graph); - - ///////////////////////////// - - FORCEINLINE void lockWrite(Nd4jLong graphId) { - if (_locks.count(graphId) == 0) - return; - - _locks[graphId].lockWrite(); - } - - FORCEINLINE void unlockWrite(Nd4jLong graphId) { - if (_locks.count(graphId) == 0) - return; - - _locks[graphId].unlockWrite(); - } - - FORCEINLINE void lockRead(Nd4jLong graphId) { - if (_locks.count(graphId) == 0) - return; - - _locks[graphId].lockRead(); - } - - FORCEINLINE void unlockRead(Nd4jLong graphId) { - if (_locks.count(graphId) == 0) - return; - - _locks[graphId].unlockRead(); - } + void replaceGraph(Nd4jLong graphId, const Graph &graph); }; } } \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 492d0fa66de8..552c1f8dcd37 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -105,17 +105,18 @@ namespace sd { node.pickInput(idByName(v), 0); } + // actually storing the node. Later, topological sort will be applied on this map _unmapped[node.id()] = node; } void Graph::addNode(Node &node, const std::initializer_list &inputs) { - throw std::runtime_error("not implemented yet"); + throw std::runtime_error("Graph::addNode() - Not implemented yet"); } void Graph::addNode(Node &node, const std::initializer_list> &inputs) { node.markRemovable(false); - throw std::runtime_error("not implemented yet"); + throw std::runtime_error("Graph::addNode() - Not implemented yet"); } Graph::Graph(const FlatGraph *flatGraph, const GraphMemoryManager &memoryManager) : _memoryMaager(memoryManager) { diff --git a/libnd4j/include/graph/impl/GraphHolder.cpp b/libnd4j/include/graph/impl/GraphHolder.cpp index c0b3fe59f421..2da1077117ed 100644 --- a/libnd4j/include/graph/impl/GraphHolder.cpp +++ b/libnd4j/include/graph/impl/GraphHolder.cpp @@ -25,87 +25,56 @@ namespace sd { namespace graph { GraphHolder* GraphHolder::getInstance() { - if (_INSTANCE == 0) + if (_INSTANCE == nullptr) _INSTANCE = new GraphHolder(); return _INSTANCE; }; - void GraphHolder::registerGraph(Nd4jLong graphId, Graph* graph) { - if (hasGraphAny(graphId)) + void GraphHolder::registerGraph(Nd4jLong graphId, const Graph &graph) { + if (hasGraph(graphId)) throw graph_exists_exception(graphId); - _graphF[graphId] = graph; - - sd::SimpleReadWriteLock lock; - _locks[graphId] = lock; + std::lock_guard lock(_mutex); + _graphs[graphId] = graph; } - Graph* GraphHolder::cloneGraph(Nd4jLong graphId) { + Graph& GraphHolder::graph(Nd4jLong graphId) { if (!this->hasGraph(graphId)) { nd4j_printf("GraphHolder doesn't have graph stored for [%lld]\n", graphId); throw std::runtime_error("Bad argument"); } - auto graph = _graphF[graphId]->cloneWithProxy(); - - throw std::runtime_error("GraphHolder::cloneGraph - not implemented yet"); - } - - Graph* GraphHolder::pullGraph(Nd4jLong graphId) { - if (!this->hasGraph(graphId)) { - nd4j_printf("GraphHolder doesn't have graph stored for [%lld]\n", graphId); - throw std::runtime_error("Bad argument"); - } - - auto graph = _graphF[graphId]; - - return graph; + std::lock_guard lock(_mutex); + return _graphs[graphId]; } void GraphHolder::forgetGraph(Nd4jLong graphId) { - if (this->hasGraph(graphId)) - _graphF.erase(graphId); - } - - void GraphHolder::dropGraph(Nd4jLong graphId) { if (this->hasGraph(graphId)) { - auto g = _graphF[graphId]; - forgetGraph(graphId); - delete g; + std::lock_guard lock(_mutex); + _graphs.erase(graphId); } } - void GraphHolder::dropGraphAny(Nd4jLong graphId) { - if (!hasGraphAny(graphId)) - return; - - this->lockWrite(graphId); - - this->dropGraph(graphId); - - this->unlockWrite(graphId); - } - - bool GraphHolder::hasGraphAny(Nd4jLong graphId) { - return this->hasGraph(graphId); + void GraphHolder::dropGraph(Nd4jLong graphId) { + forgetGraph(graphId); } bool GraphHolder::hasGraph(Nd4jLong graphId) { - return _graphF.count(graphId) > 0; + std::lock_guard lock(_mutex); + return _graphs.count(graphId) > 0; } - void GraphHolder::replaceGraph(Nd4jLong graphId, Graph* graph) { + void GraphHolder::replaceGraph(Nd4jLong graphId, const Graph& graph) { if (!hasGraph(graphId)) { registerGraph(graphId, graph); return; } - this->lockWrite(graphId); - - _graphF[graphId] = graph; - - this->unlockWrite(graphId); + forgetGraph(graphId); + + std::lock_guard lock(_mutex); + _graphs[graphId] = graph; } diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp index 1d4d9f94c42b..c5a6a6681473 100644 --- a/libnd4j/include/legacy/cpu/NativeOps.cpp +++ b/libnd4j/include/legacy/cpu/NativeOps.cpp @@ -2174,25 +2174,6 @@ int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flat } } -static VariablesSet* executeStoredGraphT(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) { - auto graph = sd::graph::GraphHolder::getInstance()->cloneGraph(graphId); - auto varSpace = graph->variableSpace(); - - for (int e = 0; e < numInputs; e++) { - auto idx = inputIndices[e]; - - // we'll delete this array later, together with cloned VariableSpace - - if (varSpace.hasVariable(idx)) { - auto var = varSpace.getVariable(idx); - var->setNDArray(std::make_shared(inputBuffers[e], reinterpret_cast(inputShapes[e]))); - } else - varSpace.putVariable(idx, sd::NDArray(inputBuffers[e], reinterpret_cast(inputShapes[e]))); - } - - throw std::runtime_error("executeStoredGraphT - not implemented yet"); -} - sd::graph::VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) { return nullptr; } @@ -2230,10 +2211,8 @@ void* getVariableBuffer(sd::graph::Variable* variable) { } int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) { - - sd::graph::GraphHolder::getInstance()->dropGraphAny(graphId); - - return sd::Status::OK(); + sd::graph::GraphHolder::getInstance()->forgetGraph(graphId); + return Status::OK(); } void deletePointerArray(Nd4jPointer pointer) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 12d6ddef9945..e8cb5621d477 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -3242,7 +3242,7 @@ TEST_F(DeclarableOpsTests1, Test_Expose_1) { TEST_F(DeclarableOpsTests1, Test_Expose_2) { if (1 > 0) - throw std::runtime_error("Not implemented yet"); + throw std::runtime_error("Test not implemented yet"); auto list = new NDArrayList(0, true); diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index 56c3b368bed1..f5d967469354 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -93,14 +93,14 @@ TEST_F(GraphExecutorTests, test_basic_exec_2) { executor.execute(optimizedGraph, proxy); // checking results by ID - ASSERT_TRUE(graph.variableSpace().hasVariable(m.id())); - ASSERT_TRUE(graph.variableSpace().hasVariable(a.id())); + ASSERT_TRUE(proxy.hasVariable(m.id())); + ASSERT_TRUE(proxy.hasVariable(a.id())); // checking results by name - ASSERT_TRUE(graph.variableSpace().hasVariable("mul")); - ASSERT_TRUE(graph.variableSpace().hasVariable("add")); + ASSERT_TRUE(proxy.hasVariable("mul")); + ASSERT_TRUE(proxy.hasVariable("add")); // checking if result is valid - auto result = graph.variableSpace().getVariable(a.id())->getNDArray(); + auto result = proxy.getVariable(a.id())->getNDArray(); ASSERT_EQ(exp, *result); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp index 943d525db873..9faf685abf01 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp @@ -33,48 +33,11 @@ class GraphHolderTests : public testing::Test { TEST_F(GraphHolderTests, SimpleTests_1) { Graph graph; Nd4jLong graphId = 119; - GraphHolder::getInstance()->registerGraph(graphId, &graph); + GraphHolder::getInstance()->registerGraph(graphId, graph); ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(graphId)); GraphHolder::getInstance()->forgetGraph(graphId); - ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(graphId)); -} - - - -TEST_F(GraphHolderTests, SimpleTests_2) { - Graph graph; - Nd4jLong graphId = 117; - GraphHolder::getInstance()->registerGraph(graphId, &graph); - - ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(graphId)); - - auto graph2 = GraphHolder::getInstance()->cloneGraph(graphId); - - ASSERT_TRUE(graph2 != nullptr); - - GraphHolder::getInstance()->forgetGraph(graphId); - - ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(graphId)); - - delete graph2; -} - - -TEST_F(GraphHolderTests, SimpleTests_3) { - Graph graph; - Nd4jLong graphId = 117; - GraphHolder::getInstance()->registerGraph(graphId, &graph); - - ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(graphId)); - - auto graph2 = GraphHolder::getInstance()->cloneGraph(graphId); - - ASSERT_TRUE(graph2 != nullptr); - - GraphHolder::getInstance()->dropGraph(graphId); - ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(graphId)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index ec4d04b307b6..244369dca3c7 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -283,7 +283,7 @@ TEST_F(OneOffTests, test_identity_n_2) { TEST_F(OneOffTests, test_non2d_1) { if (1 > 0) - throw std::runtime_error("Not implemented yet"); + throw std::runtime_error("Test not implemented yet"); auto e = NDArrayFactory::create('c', {1, 1}, {5.42746449f}); From 0c23a210e8310b0dc220f12d7302d80f251d2f22 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 27 Apr 2020 13:06:40 +0300 Subject: [PATCH 111/233] few minor tweaks here and themplace_backere --- libnd4j/include/graph/execution/ExecutionTask.h | 2 +- libnd4j/include/graph/execution/OpSequence.h | 6 +++--- libnd4j/include/graph/execution/impl/ExecutionTask.cpp | 2 +- libnd4j/include/graph/execution/impl/OpSequence.cpp | 2 +- libnd4j/include/graph/impl/FlatUtils.cpp | 2 +- libnd4j/tests_cpu/layers_tests/OneOffTests.cpp | 6 +++--- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/libnd4j/include/graph/execution/ExecutionTask.h b/libnd4j/include/graph/execution/ExecutionTask.h index a43c59904a3f..44b6ede5f101 100644 --- a/libnd4j/include/graph/execution/ExecutionTask.h +++ b/libnd4j/include/graph/execution/ExecutionTask.h @@ -33,7 +33,7 @@ namespace sd { const ContextPrototype &_context; public: - ExecutionTask(std::shared_ptr op, const ContextPrototype &ctx); + ExecutionTask(const std::shared_ptr &op, const ContextPrototype &ctx); ~ExecutionTask() = default; diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index 1268baba4fc7..e81f35177d1b 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -40,8 +40,8 @@ namespace sd { int _deviceId = 0; public: - explicit OpSequence(const std::vector &ops, const int deviceId = 0); - OpSequence(const int deviceId = 0); + explicit OpSequence(const std::vector &ops, int deviceId = 0); + OpSequence(int deviceId = 0); ~OpSequence() = default; OpSequence(const OpSequence& other) noexcept; @@ -82,7 +82,7 @@ namespace sd { * @param op - Op to be executed * @param ctx - ContextPrototype for this operation with inputs/outputs/args defined */ - void append(std::shared_ptr op, const sd::graph::ContextPrototype &ctx); + void append(const std::shared_ptr &op, const sd::graph::ContextPrototype &ctx); void append(sd::ops::DeclarableOp *op, const sd::graph::ContextPrototype &ctx); /** diff --git a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp index a51452696019..34d7597bed8a 100644 --- a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp +++ b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp @@ -22,7 +22,7 @@ namespace sd { namespace graph { - ExecutionTask::ExecutionTask(std::shared_ptr op, const ContextPrototype &ctx) : _op(op), _context(ctx) { + ExecutionTask::ExecutionTask(const std::shared_ptr &op, const ContextPrototype &ctx) : _op(op), _context(ctx) { // } diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index e3726e64e761..4ef81fb7f202 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -82,7 +82,7 @@ namespace sd { return _ops.size(); } - void OpSequence::append(std::shared_ptr op, const sd::graph::ContextPrototype &ctx) { + void OpSequence::append(const std::shared_ptr &op, const sd::graph::ContextPrototype &ctx) { ExecutionTask task(op, ctx); _ops.emplace_back(task); } diff --git a/libnd4j/include/graph/impl/FlatUtils.cpp b/libnd4j/include/graph/impl/FlatUtils.cpp index 0843b45870cc..2ca7b9bbf697 100644 --- a/libnd4j/include/graph/impl/FlatUtils.cpp +++ b/libnd4j/include/graph/impl/FlatUtils.cpp @@ -47,7 +47,7 @@ namespace sd { // empty arrays is special case, nothing to restore here if (shape::isEmpty(newShape)) { delete[] newShape; - return NDArrayFactory::empty(dtype, nullptr); + return NDArrayFactory::empty(dtype); } // TODO fix UTF16 and UTF32 if (dtype == UTF8) { diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index 244369dca3c7..e5d72afae220 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -169,7 +169,7 @@ TEST_F(OneOffTests, test_tensor_array_2) { TEST_F(OneOffTests, test_tensor_array_3) { if (1 > 0) - throw std::runtime_error("Temporary disabled"); + throw std::runtime_error("This test crashes"); auto e = NDArrayFactory::create('c', {3, 2, 3}, {7, 2, 9, 4, 3, 3, 8, 7, 0, 0, 6, 8, 7, 9, 0, 1, 1, 4}); @@ -282,8 +282,8 @@ TEST_F(OneOffTests, test_identity_n_2) { } TEST_F(OneOffTests, test_non2d_1) { - if (1 > 0) - throw std::runtime_error("Test not implemented yet"); + //if (1 > 0) + // throw std::runtime_error("Test not implemented yet"); auto e = NDArrayFactory::create('c', {1, 1}, {5.42746449f}); From e9fb6450ad1979b8f569a18d053466515c370f0c Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 27 Apr 2020 20:37:06 +0300 Subject: [PATCH 112/233] remove outdated Graph tests --- libnd4j/tests_cpu/layers_tests/GraphTests.cpp | 927 ------------------ 1 file changed, 927 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests.cpp index a1394c17cc86..4f77d0501574 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests.cpp @@ -44,933 +44,6 @@ class GraphTests : public testing::Test { } }; -TEST_F(GraphTests, SingleInput1) { - Graph graph; - - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(-2.0f); - - graph.variableSpace().putVariable(-1, x); - - Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); - Node nodeB(OpType_TRANSFORM_STRICT, transform::Cosine); - Node nodeC(OpType_TRANSFORM_SAME, transform::Abs); - - graph.addNode(nodeA, {-1}); - graph.addNode(nodeB, {1}); - graph.addNode(nodeC, {2}); - - ASSERT_EQ(3, graph.size()); - - graph.execute(); - - ASSERT_TRUE(graph.variableSpace().hasVariable(3)); - - auto node3 = graph.variableSpace().getVariable(3)->getNDArray(); - - ASSERT_NEAR(0.4161468, node3->reduceNumber(reduce::Mean).e(0), 1e-5); -} - -TEST_F(GraphTests, DoubleInput1) { - Graph graph; - - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(-2.0); - - auto y = NDArrayFactory::create('c', {5, 5}); - y.assign(-1.0); - - auto z = NDArrayFactory::create('c', {5, 5}); - - graph.variableSpace().putVariable(-1, x); - graph.variableSpace().putVariable(-2, y); - graph.variableSpace().putVariable(-3, z); - - Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); - Node nodeB(OpType_TRANSFORM_SAME, transform::Abs); - Node nodeC(OpType_PAIRWISE, pairwise::Add); - - graph.addNode(nodeA, {-1}); - graph.addNode(nodeB, {-2}); - graph.addNode(nodeC, {1, 2}); - - - ASSERT_EQ(3, graph.size()); - - graph.execute(); - - ASSERT_NEAR(3.0, z.reduceNumber(reduce::Mean).e(0), 1e-5); -} - -TEST_F(GraphTests, SingleInput3) { - Graph graph; - - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(-2.0); - - auto v0 = NDArrayFactory::create('c', {5, 5}); - auto v1 = NDArrayFactory::create('c', {5, 5}); - - graph.variableSpace().putVariable(-1, x); - graph.variableSpace().putVariable(-2, v0); - graph.variableSpace().putVariable(-3, v1); - - Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); - Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt); - Node nodeC(OpType_TRANSFORM_SAME, transform::Ones); - - graph.addNode(nodeA, {-1}); - graph.addNode(nodeB, {1}); - graph.addNode(nodeC, {1}); - - ASSERT_EQ(3, graph.size()); - - graph.execute(); - - ASSERT_NEAR(1.4142135, v0.reduceNumber(reduce::Mean).e(0), 1e-5); - ASSERT_NEAR(1.0, v1.reduceNumber(reduce::Mean).e(0), 1e-5); -} - -TEST_F(GraphTests, SingleInput4) { - Graph graph; - - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(-2.0); - - auto v0 = NDArrayFactory::create('c', {5, 5}); - auto v1 = NDArrayFactory::create('c', {5, 5}); - - graph.variableSpace().putVariable(-1, x); - graph.variableSpace().putVariable(-2, v0); - graph.variableSpace().putVariable(-3, v1); - - Node nodeA(OpType_TRANSFORM_SAME, transform::Abs); - Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt); - Node nodeC(OpType_TRANSFORM_SAME, transform::Neg); - - Node nodeS(OpType_TRANSFORM_SAME, transform::Ones); - Node nodeE(OpType_TRANSFORM_SAME, transform::Identity); - - graph.addNode(nodeA, {-1}); - graph.addNode(nodeB, {1}); - graph.addNode(nodeC, {2}); - graph.addNode(nodeS, {3}); - graph.addNode(nodeE, {3}); - - ASSERT_EQ(5, graph.size()); - - graph.execute(); - - ASSERT_NEAR(1.0, v0.reduceNumber(reduce::Mean).e(0), 1e-5); - ASSERT_NEAR(-1.4142135, v1.reduceNumber(reduce::Mean).e(0), 1e-5); -} - - -TEST_F(GraphTests, DoubleInput2) { - Graph graph; - - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(-2.0); - - auto y = NDArrayFactory::create('c', {5, 5}); - y.assign(-1.0); - - auto z0 = NDArrayFactory::create('c', {5, 5}); - auto z1 = NDArrayFactory::create('c', {5, 5}); - - graph.variableSpace().putVariable(-1, x); - graph.variableSpace().putVariable(-2, y); - graph.variableSpace().putVariable(-3, z0); - graph.variableSpace().putVariable(-4, z1); - - - Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); - Node nodeC(OpType_TRANSFORM_SAME, transform::Neg, 3, {2}, {-3}); - - Node nodeT(OpType_TRANSFORM_SAME, transform::Abs, 11, {-2}, {12}); - Node nodeU(OpType_TRANSFORM_FLOAT, transform::Sqrt, 12, {11}, {13}); - Node nodeV(OpType_TRANSFORM_SAME, transform::Neg, 13, {12}, {-4}); - - graph.addNode(nodeA, {-1}); - graph.addNode(nodeB, {1}); - graph.addNode(nodeC, {2}); - graph.addNode(nodeT, {-2}); - graph.addNode(nodeU, {4}); - graph.addNode(nodeV, {5}); - - ASSERT_EQ(6, graph.size()); - - graph.execute(); - - ASSERT_NEAR(-1.4142135, z0.reduceNumber(reduce::Mean).e(0), 1e-5); - ASSERT_NEAR(-1.0, z1.reduceNumber(reduce::Mean).e(0), 1e-5); -} - - -TEST_F(GraphTests, DoubleInput3) { - Graph graph; - - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(-2.0); - - auto y = NDArrayFactory::create('c', {5, 5}); - y.assign(-1.0); - - auto z0 = NDArrayFactory::create('c', {5, 5}); - auto z1 = NDArrayFactory::create('c', {5, 5}); - - - auto w = NDArrayFactory::create('c', {5, 5}); - - graph.variableSpace().putVariable(-1, x); - graph.variableSpace().putVariable(-2, y); - graph.variableSpace().putVariable(-3, z0); - graph.variableSpace().putVariable(-4, z1); - graph.variableSpace().putVariable(-5, w); - - - Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); - Node nodeC(OpType_TRANSFORM_SAME, transform::Neg, 3, {2}, {-3, 21}); - - Node nodeT(OpType_TRANSFORM_SAME, transform::Abs, 11, {-2}, {12}); - Node nodeU(OpType_TRANSFORM_FLOAT, transform::Sqrt, 12, {11}, {13}); - Node nodeV(OpType_TRANSFORM_SAME, transform::Neg, 13, {12}, {-4, 21}); - - Node nodeW(OpType_PAIRWISE, pairwise::Add, 21, {3, 13}, {22}); - Node nodeZ(OpType_TRANSFORM_SAME, transform::Abs, 22, {21}, {-5}); - - graph.addNode(nodeA, {-1}); - graph.addNode(nodeB, {1}); - graph.addNode(nodeC, {2}); - graph.addNode(nodeT, {-2}); - graph.addNode(nodeU, {4}); - graph.addNode(nodeV, {5}); - graph.addNode(nodeW, {3, 6}); - graph.addNode(nodeZ, {7}); - - ASSERT_EQ(8, graph.size()); - - graph.execute(); - - ASSERT_NEAR(-1.4142135, z0.reduceNumber(reduce::Mean).e(0), 1e-5); - ASSERT_NEAR(-1.0, z1.reduceNumber(reduce::Mean).e(0), 1e-5); - - ASSERT_NEAR(2.4142135, w.reduceNumber(reduce::Mean).e(0), 1e-5); -} - - -TEST_F(GraphTests, QuadInput1) { - Graph graph; - - auto x0 = NDArrayFactory::create('c', {5, 5}); - x0.assign(0.0); - - auto x1 = NDArrayFactory::create('c', {5, 5}); - x1.assign(-1.0); - - auto x2 = NDArrayFactory::create('c', {5, 5}); - x2.assign(-2.0); - - auto x3 = NDArrayFactory::create('c', {5, 5}); - x3.assign(-3.0); - - auto z = NDArrayFactory::create('c', {5, 5}); - z.assign(119.0); - - graph.variableSpace().putVariable(-1, x0); - graph.variableSpace().putVariable(-2, x1); - graph.variableSpace().putVariable(-3, x2); - graph.variableSpace().putVariable(-4, x3); - graph.variableSpace().putVariable(-5, z); - - Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {11}); - Node nodeB(OpType_TRANSFORM_SAME, transform::Abs, 2, {-2}, {11}); - Node nodeC(OpType_TRANSFORM_SAME, transform::Abs, 3, {-3}, {21}); - Node nodeD(OpType_TRANSFORM_SAME, transform::Abs, 4, {-4}, {21}); - - Node nodeP1(OpType_PAIRWISE, pairwise::Add, 11, {1, 2}, {31}); - Node nodeP2(OpType_PAIRWISE, pairwise::Add, 21, {3, 4}, {31}); - - Node nodeZ(OpType_PAIRWISE, pairwise::Add, 31, {11, 21}, {-5}); - - graph.addNode(nodeA, {-1}); - graph.addNode(nodeB, {-2}); - graph.addNode(nodeC, {-3}); - graph.addNode(nodeD, {-4}); - graph.addNode(nodeP1, {1, 2}); - graph.addNode(nodeP2, {3, 4}); - graph.addNode(nodeZ, {11, 21}); - - ASSERT_EQ(7, graph.size()); - - graph.execute(); - - ASSERT_NEAR(6.0, z.reduceNumber(reduce::Mean).e(0), 1e-5); -} - -TEST_F(GraphTests, InternalBranching1) { - Graph graph; - - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(0.0); - - auto z = NDArrayFactory::create('c', {5, 5}); - - graph.variableSpace().putVariable(-1, x); - graph.variableSpace().putVariable(-2, z); - - // 1.0 - Node nodeA(OpType_TRANSFORM_SAME, transform::Ones, 1, {-1}, {11, 21}); - - // -1 - Node nodeK(OpType_TRANSFORM_SAME, transform::Neg, 11, {1}, {12}); - - // 2.0 - Node nodeL(OpType_TRANSFORM_SAME, transform::OneMinus, 12, {11}, {31}); - - // -1 - Node nodeR(OpType_TRANSFORM_SAME, transform::Neg, 21, {1}, {22}); - - // 1 - Node nodeS(OpType_TRANSFORM_SAME, transform::Neg, 22, {21}, {31}); - - // 1.0 - Node nodeZ(OpType_PAIRWISE, pairwise::Add, 31, {12, 22}, {-2}); - - graph.addNode(nodeA, {-1}); - graph.addNode(nodeK, {1}); - graph.addNode(nodeL, {2}); - graph.addNode(nodeR, {1}); - graph.addNode(nodeS, {1}); - graph.addNode(nodeZ, {1, 1}); - - ASSERT_EQ(6, graph.size()); - - graph.execute(); - - ASSERT_NEAR(3.0, z.reduceNumber(reduce::Mean).e(0), 1e-5); -} - - -TEST_F(GraphTests, ReductionsTest1) { - Graph graph; - - auto x = NDArrayFactory::create('c', {5, 5}); - for (int r = 0; r < x.rows(); r++) { - for (int c = 0; c < x.columns(); c++) { - x.p(r, c, -c); - } - } - - auto z = NDArrayFactory::create('c', {5}); - - graph.variableSpace().putVariable(-1, x); - graph.variableSpace().putVariable(-2, z); - - Node nodeA(OpType_REDUCE_FLOAT, reduce::Mean, 1, {-1}, {2}, {1}, {}); - Node nodeB(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {-2}); - - graph.addNode(nodeA, {-1}); - graph.addNode(nodeB, {1}); - - ASSERT_EQ(2, graph.size()); - - graph.execute(); - - ASSERT_NEAR(2.0, z.reduceNumber(reduce::Mean).e(0), 1e-5); -} - - -TEST_F(GraphTests, IndexReductionsTest1) { - Graph graph; - - auto x = NDArrayFactory::create('c', {5, 5}); - for (int r = 0; r < x.rows(); r++) { - for (int c = 0; c < x.columns(); c++) { - x.p(r, c, -c); - } - } - - auto z = NDArrayFactory::create('c', {5, 1}); - auto axis = NDArrayFactory::create('c', {1}, {1}); - - graph.variableSpace().putVariable(-1, x); - graph.variableSpace().putVariable(-2, z); - //graph->variableSpace().putVariable(-3, axis); - - - Node nodeA(OpType_INDEX_REDUCE, indexreduce::IndexMin, 1, {-1}, {2}, {1}); - Node nodeB(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {-2}); - - graph.addNode(nodeA, {-1, -2}); - graph.addNode(nodeB, {1}); - - ASSERT_EQ(2, graph.size()); - - graph.execute(); - - ASSERT_NEAR(4.0, z.reduceNumber(reduce::Mean).e(0), 1e-5); -} - -#if 0 -TEST_F(GraphTests, AutoOutput1) { - auto graph = new Graph(); - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - graph->variableSpace().putVariable(-1, x); - - auto nodeA = new Node(OpType_TRANSFORM_FLOAT, 0, 1, {-1}, {2}); - auto nodeB = new Node(OpType_TRANSFORM_FLOAT, 35, 2, {1}, {}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(2, graph->totalNodes()); - - graph->buildGraph(); - - ASSERT_TRUE(graph->variableSpace()->getVariable(2) != nullptr); - - GraphExecutioner::execute(graph); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(1, outputs->size()); - - ASSERT_TRUE(outputs->at(0) != nullptr); - - ASSERT_NEAR(-1.0, outputs->at(0)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete outputs; - delete graph; -} - - -TEST_F(GraphTests, AutoOutput2) { - auto graph = new Graph(); - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-2.0); - - graph->variableSpace().putVariable(-1, x); - - auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}, {2, 3, -1}); - auto nodeB = new Node(OpType_TRANSFORM_SAME, 35, 2, {1}, {}); - auto nodeC = new Node(OpType_TRANSFORM_SAME, 6, 3, {1}, {}); - - graph->addNode(nodeA); - graph->addNode(nodeB); - graph->addNode(nodeC); - - ASSERT_EQ(1, graph->rootNodes()); - ASSERT_EQ(3, graph->totalNodes()); - - graph->buildGraph(); - - ASSERT_TRUE(graph->variableSpace()->getVariable(-1) != nullptr); - ASSERT_TRUE(graph->variableSpace()->getVariable(2) != nullptr); - ASSERT_TRUE(graph->variableSpace()->getVariable(3) != nullptr); - - GraphExecutioner::execute(graph); - - auto outputs = graph->fetchOutputs(); - - ASSERT_EQ(2, outputs->size()); - - ASSERT_TRUE(outputs->at(0) != nullptr); - - ASSERT_NEAR(-1.0, outputs->at(0)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); - ASSERT_NEAR(-2.0, outputs->at(1)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); - - delete graph; - delete outputs; -} -#endif - -TEST_F(GraphTests, BroadcastTest1) { - Graph graph; - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(0.f); - - auto y = NDArrayFactory::create('c', {1, 5}); - for (int e = 0; e < y.columns(); e++) { - y.p(e, (float)e+1); - } - - auto z = NDArrayFactory::create('c', {5, 5}); - - graph.variableSpace().putVariable(-1, x); - graph.variableSpace().putVariable(-2, y); - graph.variableSpace().putVariable(-3, z); - - Node nodeA(OpType_BROADCAST, broadcast::Subtract, 1, {-1, -2}, {2}, {1}); - Node nodeB(OpType_TRANSFORM_SAME, transform::Neg, 2, {1}, {-3}); - - graph.addNode(nodeA, {-1, -2}); - graph.addNode(nodeB, {1}); - - graph.execute(); - - ASSERT_NEAR(3.0, z.reduceNumber(reduce::Mean).e(0), 1e-5); -} - - -TEST_F(GraphTests, ScalarTest1) { - Graph graph; - - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(-2.0); - - auto z = NDArrayFactory::create('c', {5, 5}); - - graph.variableSpace().putVariable(-1, x); - graph.variableSpace().putVariable(-2, z); - - Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); - Node nodeE(OpType_SCALAR, scalar::Add, 3, {2}, {-2}, {}, 1.3f); - - graph.addNode(nodeA, {-1}); - graph.addNode(nodeB, {1}); - graph.addNode(nodeE, {2}); - - ASSERT_EQ(3, graph.size()); - - graph.execute(); - - ASSERT_NEAR(2.714213, z.reduceNumber(reduce::Mean).e(0), 1e-5); -} - -TEST_F(GraphTests, SymbolicLookupTest1) { - Graph graph; - - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(-2.0); - - auto z = NDArrayFactory::create('c', {5, 5}); - - std::string a("alpha"); - std::string o("omega"); - - auto vX = std::make_shared(x, a, -1); - auto vZ = std::make_shared(z, o, -1); - - graph.variableSpace().putVariable(-1, vX); - graph.variableSpace().putVariable(-2, vZ); - - Node nodeA(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); - Node nodeB(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); - - std::string p("phi"); - std::string t("theta"); - - nodeA.setName(&p); - nodeB.setName(&t); - - graph.addNode(nodeA, {-1}); - graph.addNode(nodeB, {1}); - - - auto rX = graph.variableSpace().getVariable(a); - auto rZ = graph.variableSpace().getVariable(o); - - std::string om("omicron"); - - ASSERT_TRUE(rX->getNDArray() == vX->getNDArray()); - ASSERT_TRUE(rZ->getNDArray() == vZ->getNDArray()); - ASSERT_FALSE(graph.variableSpace().hasVariable(om)); - - - ASSERT_TRUE(graph.variableSpace().hasVariable(1)); - ASSERT_TRUE(graph.variableSpace().hasVariable(2)); - - graph.execute(); - - ASSERT_TRUE(graph.variableSpace().hasVariable(p)); - ASSERT_TRUE(graph.variableSpace().hasVariable(t)); - - ASSERT_NEAR(1.4142135, z.reduceNumber(reduce::Mean).e(0), 1e-5); -} - -#if 0 -TEST_F(GraphTests, Test_Clone_1) { - auto exp = NDArrayFactory::create('c', {3}); - exp.assign(3.0); - - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - auto variableSpace = graph->variableSpace(); - //graph->buildGraph(); - - auto clone = graph->clone(); - - Nd4jStatus statusOriginal = GraphExecutioner::execute(graph); - - ASSERT_EQ(ND4J_STATUS_OK, statusOriginal); - ASSERT_TRUE(variableSpace->hasVariable(3)); - - Nd4jStatus statusClone = GraphExecutioner::execute(clone); - - ASSERT_EQ(ND4J_STATUS_OK, statusClone); - - ASSERT_TRUE(variableSpace->hasVariable(3)); - - auto z0 = variableSpace->getVariable(3)->getNDArray(); - auto z1 = clone->getVariableSpace()->getVariable(3)->getNDArray(); - - ASSERT_TRUE(exp.isSameShape(z0)); - ASSERT_TRUE(exp.equalsTo(z0)); - - ASSERT_TRUE(exp.isSameShape(z1)); - ASSERT_TRUE(exp.equalsTo(z1)); - - delete graph; - delete clone; -} - - - - -TEST_F(GraphTests, Test_Clone_2) { - auto exp = NDArrayFactory::create('c', {3}); - exp.assign(3.0); - - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - auto variableSpace = graph->variableSpace(); - graph->buildGraph(); - - auto clone = graph->clone(); - - Nd4jStatus statusOriginal = GraphExecutioner::execute(graph); - - ASSERT_EQ(ND4J_STATUS_OK, statusOriginal); - ASSERT_TRUE(variableSpace->hasVariable(3)); - - Nd4jStatus statusClone = GraphExecutioner::execute(clone); - - ASSERT_EQ(ND4J_STATUS_OK, statusClone); - - ASSERT_TRUE(variableSpace->hasVariable(3)); - - auto z0 = variableSpace->getVariable(3)->getNDArray(); - auto z1 = clone->getVariableSpace()->getVariable(3)->getNDArray(); - - ASSERT_TRUE(exp.isSameShape(z0)); - ASSERT_TRUE(exp.equalsTo(z0)); - - ASSERT_TRUE(exp.isSameShape(z1)); - ASSERT_TRUE(exp.equalsTo(z1)); - - delete graph; - delete clone; -} - -TEST_F(GraphTests, Test_Dtype_Conversion_1) { - /*auto expD = NDArrayFactory::create('c', {3}, {3.0, 3.0, 3.0}); - auto expF = NDArrayFactory::create('c', {3}, {3.0, 3.0, 3.0}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - graph->buildGraph(); - - - auto gd = graph->template asT(); - auto gf = gd->template asT(); - - // checking float graph - Nd4jStatus statusF = GraphExecutioner::execute(gf); - ASSERT_EQ(ND4J_STATUS_OK, statusF); - - ASSERT_TRUE(gf->getVariableSpace()->hasVariable(3)); - - ASSERT_TRUE(gf->getVariableSpace()->hasVariable(3)); - auto z1 = gf->getVariableSpace()->getVariable(3)->getNDArray(); - - ASSERT_TRUE(expF.isSameShape(z1)); - ASSERT_TRUE(expF.equalsTo(z1)); - - - // checking double graph - Nd4jStatus statusD = GraphExecutioner::execute(gd); - ASSERT_EQ(ND4J_STATUS_OK, statusD); - - ASSERT_TRUE(gd->getVariableSpace()->hasVariable(3)); - auto z2 = gd->getVariableSpace()->getVariable(3)->getNDArray(); - - ASSERT_TRUE(expD.isSameShape(z2)); - ASSERT_TRUE(expD.equalsTo(z2)); - - - delete graph; - delete gd; - delete gf; - */ -} - -TEST_F(GraphTests, Test_Dtype_Conversion_2) { - /* - NDArray expF('c', {5, 4}, {0.32454616f, -0.06604697f, 0.22593613f, 0.43166467f, -0.18320604f, 0.00102305f, -0.06963076f, 0.25266643f, 0.07568010f, -0.03009197f, 0.07805517f, 0.33180334f, -0.06220427f, 0.07249600f, -0.06726961f, -0.22998397f, -0.06343779f, 0.07384885f, -0.06891008f, -0.23745790f}); - NDArray expD('c', {5, 4}, {0.32454616f, -0.06604697f, 0.22593613f, 0.43166467f, -0.18320604f, 0.00102305f, -0.06963076f, 0.25266643f, 0.07568010f, -0.03009197f, 0.07805517f, 0.33180334f, -0.06220427f, 0.07249600f, -0.06726961f, -0.22998397f, -0.06343779f, 0.07384885f, -0.06891008f, -0.23745790f}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); - graph->buildGraph(); - - - auto gd = graph->template asT(); - auto gf = gd->template asT(); - - // checking float - auto resultF = GraphExecutioner::execute(gf); - ASSERT_EQ(ND4J_STATUS_OK, resultF); - ASSERT_TRUE(gf->getVariableSpace()->hasVariable(18)); - auto zF = gf->getVariableSpace()->getVariable(18)->getNDArray(); - - ASSERT_TRUE(expF.isSameShape(zF)); - ASSERT_TRUE(expF.equalsTo(zF)); - - - // checking double - auto resultD = GraphExecutioner::execute(gd); - ASSERT_EQ(ND4J_STATUS_OK, resultD); - ASSERT_TRUE(gd->getVariableSpace()->hasVariable(18)); - auto zD = gd->getVariableSpace()->getVariable(18)->getNDArray(); - - ASSERT_TRUE(expD.isSameShape(zD)); - ASSERT_TRUE(expD.equalsTo(zD)); - - delete graph; - delete gd; - delete gf; - */ -} - -TEST_F(GraphTests, Test_Hash_Function_1) { - /* - auto graph0 = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); - auto graph1 = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); - auto graph2 = GraphExecutioner::importFromFlatBuffers("./resources/conv_0.fb"); - - ASSERT_EQ(graph0->hashCode(), graph1->hashCode()); - ASSERT_NE(0L, graph1->hashCode()); - ASSERT_NE(graph0->hashCode(), graph2->hashCode()); - - auto graph0D = graph0->template asT(); - auto graph1D = graph1->template asT(); - - ASSERT_NE(graph0->hashCode(), graph0D->hashCode()); - ASSERT_EQ(graph0D->hashCode(), graph1D->hashCode()); - - delete graph0; - delete graph1; - delete graph2; - delete graph0D; - delete graph1D; - */ -} - -TEST_F(GraphTests, OpListTest_1) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); ; - - ASSERT_TRUE(graph != nullptr); - std::vector ops = graph->getOperations(); - - ASSERT_TRUE(ops.size() == 11); - GraphUtils::filterOperations(ops); - ASSERT_TRUE(ops.size() == 7); - - std::string exp(" -g \"-DSD_OPS_LIST='-DOP_rank=true -DOP_range=true -DOP_subtract=true -DOP_permute=true -DOP_matmul=true -DOP_biasadd=true -DOP_TRANSFORM{15}=true '\""); - std::string out = GraphUtils::makeCommandLine(ops); -// nd4j_printf("EXP: >%s<\n", exp.c_str()); -// nd4j_printf("OUT: >%s<\n", out.c_str()); - ASSERT_EQ(exp, out); - - delete graph; -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(GraphTests, OpListTest_2) { - auto graph0 = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); - auto graph1 = GraphExecutioner::importFromFlatBuffers("./resources/tensor_slice.fb"); - - ASSERT_TRUE(graph0 != nullptr); - ASSERT_TRUE(graph1 != nullptr); - - std::vector ops = graph0->getOperations(); - std::vector ops1 = graph1->getOperations(); - std::copy ( ops1.begin(), ops1.end(), std::back_inserter(ops)); - - ASSERT_TRUE(ops.size() == 13); - - GraphUtils::filterOperations(ops); - - std::string exp = " -g \"-DSD_OPS_LIST='-DOP_rank=true -DOP_range=true -DOP_subtract=true -DOP_permute=true -DOP_matmul=true -DOP_biasadd=true -DOP_TRANSFORM{15}=true -DOP_strided_slice=true -DOP_ACCUMULATION{1}=true '\""; - - ASSERT_TRUE(ops.size() == 9); - ASSERT_EQ(exp, GraphUtils::makeCommandLine(ops)); - - delete graph0; - delete graph1; -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(GraphTests, OpListTest_3) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); ; - - ASSERT_TRUE(graph != nullptr); - std::vector ops = graph->getOperations(); - std::vector ops2(ops); - std::copy(ops.begin(), ops.end(), std::back_inserter(ops2)); - - ASSERT_TRUE(ops.size() == 11); - ASSERT_TRUE(ops2.size() == 2 * ops.size()); - - GraphUtils::filterOperations(ops2); - GraphUtils::filterOperations(ops); - ASSERT_TRUE(ops.size() == ops2.size()); - ASSERT_TRUE(ops.size() == 7); - ASSERT_TRUE(GraphUtils::makeCommandLine(ops) == GraphUtils::makeCommandLine(ops2)); - - delete graph; -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(GraphTests, OpListTest_4) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/conv_0.fb"); ; - - ASSERT_TRUE(graph != nullptr); - std::vector ops = graph->getOperations(); - std::vector ops2(ops); - std::copy(ops.begin(), ops.end(), std::back_inserter(ops2)); - - // nd4j_printf("Total ops before %i\n", ops.size()); - ASSERT_TRUE(ops.size() == 6); - ASSERT_TRUE(ops2.size() == 2 * ops.size()); - - GraphUtils::filterOperations(ops2); - GraphUtils::filterOperations(ops); - ASSERT_TRUE(ops.size() == ops2.size()); - ASSERT_TRUE(ops.size() == 5); - ASSERT_TRUE(GraphUtils::makeCommandLine(ops) == GraphUtils::makeCommandLine(ops2)); - - delete graph; -} - - -TEST_F(GraphTests, Test_Inplace_Execution_1) { - auto exp = NDArrayFactory::create('c', {5, 4}, {0.32454616f, -0.06604697f, 0.22593613f, 0.43166467f, -0.18320604f, 0.00102305f, -0.06963076f, 0.25266643f, 0.07568010f, -0.03009197f, 0.07805517f, 0.33180334f, -0.06220427f, 0.07249600f, -0.06726961f, -0.22998397f, -0.06343779f, 0.07384885f, -0.06891008f, -0.23745790f}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); - // graph->printOut(); - graph->tagInplaceNodes(); - - ASSERT_FALSE(graph->nodeById(8)->isInplace()); - ASSERT_TRUE(graph->nodeById(9)->isInplace()); - ASSERT_TRUE(graph->nodeById(10)->isInplace()); - ASSERT_FALSE(graph->nodeById(11)->isInplace()); - ASSERT_FALSE(graph->nodeById(12)->isInplace()); - ASSERT_TRUE(graph->nodeById(17)->isInplace()); - ASSERT_TRUE(graph->nodeById(18)->isInplace()); - - auto status = GraphExecutioner::execute(graph, graph->variableSpace()); - ASSERT_EQ(Status::OK(), status); - - auto z = graph->variableSpace()->getVariable(18)->getNDArray(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - auto z_17 = graph->variableSpace()->getVariable(17)->getNDArray(); - ASSERT_TRUE(z_17 == z); - - delete graph; -} - -TEST_F(GraphTests, Test_Inplace_Execution_2) { - Graph graphA; - - auto x = NDArrayFactory::create_('c', {5, 5}); - x->assign(-5.0); - - graphA.getVariableSpace()->putVariable(-1, x); - - // abs, result is 5 - auto nodeA0 = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}, {}); - // 1-, result -4 - auto nodeA1 = new Node(OpType_TRANSFORM_SAME, 35, 2, {1}, {}); - - // graph should return 4: abs(-4) - auto nodeA2 = new Node(OpType_TRANSFORM_SAME, 0, 3, {2}, {}); - - // graph should return 1 - 4 = -3 - auto nodeA21 = new Node(OpType_TRANSFORM_SAME, 35, 5, {3}, {}); - - // 1 - -4 = 3 - auto nodeA3 = new Node(OpType_TRANSFORM_SAME, 35, 4, {2}, {}); - - // same abs = 3 - auto nodeA31 = new Node(OpType_TRANSFORM_SAME, 35, 6, {4}, {}); - - graphA.addNode(nodeA0); - graphA.addNode(nodeA1); - graphA.addNode(nodeA2); - graphA.addNode(nodeA3); - graphA.addNode(nodeA21); - graphA.addNode(nodeA31); - - graphA.buildGraph(); - graphA.tagInplaceNodes(); - - // nodes have 1 output - ASSERT_TRUE(graphA.nodeById(1)->isInplace()); - ASSERT_TRUE(graphA.nodeById(2)->isInplace()); - - // this 2 nodes share same input: node 2, so they can't be inplace - ASSERT_FALSE(graphA.nodeById(3)->isInplace()); - ASSERT_FALSE(graphA.nodeById(4)->isInplace()); - - // these 2 ops are standalone, so they can be run inplace - ASSERT_TRUE(graphA.nodeById(5)->isInplace()); - ASSERT_TRUE(graphA.nodeById(6)->isInplace()); -} -#endif - -TEST_F(GraphTests, Test_Inplace_Outputs_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto exp = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {2, 3}); - - sd::ops::test_output_reshape op; - auto result = op.execute({&x}, {&z}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); -} - -TEST_F(GraphTests, Test_Inplace_Outputs_2) { -#ifndef __APPLE_OS__ - // we dont want testing this on apple. due to try/catch - - auto x = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto exp = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {3, 3}); - - bool failed = false; - sd::ops::test_output_reshape op; - try { - op.execute({&x}, {&z}, {}, {}, {}); - - } catch (const std::runtime_error& e) { - failed = true; - } - - - ASSERT_TRUE(failed); -#endif -} /* TEST_F(GraphTests, Test_Minifier_1) { From cae1158a2ccfe3a5b5b7eb92c230bcc213692e8d Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 28 Apr 2020 08:40:04 +0300 Subject: [PATCH 113/233] printOut draft --- libnd4j/include/graph/OptimizedGraph.h | 5 +++++ libnd4j/include/graph/execution/ExecutionTask.h | 1 + libnd4j/include/graph/execution/GraphExecutor.h | 2 +- libnd4j/include/graph/execution/OpSequence.h | 5 +++++ libnd4j/include/graph/execution/impl/ExecutionTask.cpp | 8 ++++++++ libnd4j/include/graph/execution/impl/GraphExecutor.cpp | 2 +- libnd4j/include/graph/execution/impl/OpSequence.cpp | 7 ++++++- libnd4j/include/graph/impl/Graph.cpp | 7 ++++++- libnd4j/include/graph/impl/Node.cpp | 4 ++-- libnd4j/include/graph/impl/OptimizedGraph.cpp | 7 ++++++- libnd4j/tests_cpu/layers_tests/OneOffTests.cpp | 4 +++- 11 files changed, 44 insertions(+), 8 deletions(-) diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index 317c5a6f1568..b653c27943c1 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -104,6 +104,11 @@ namespace sd { * @return */ size_t size() const; + + /** + * This method prints out graph content + */ + void printOut() const; protected: /* * optimize original graph diff --git a/libnd4j/include/graph/execution/ExecutionTask.h b/libnd4j/include/graph/execution/ExecutionTask.h index 44b6ede5f101..12a0ac250b07 100644 --- a/libnd4j/include/graph/execution/ExecutionTask.h +++ b/libnd4j/include/graph/execution/ExecutionTask.h @@ -47,6 +47,7 @@ namespace sd { // move assignment operator ExecutionTask& operator=(ExecutionTask&& other) noexcept; + void printOut() const; std::shared_ptr op() const; diff --git a/libnd4j/include/graph/execution/GraphExecutor.h b/libnd4j/include/graph/execution/GraphExecutor.h index cf8475690207..f77836de56f1 100644 --- a/libnd4j/include/graph/execution/GraphExecutor.h +++ b/libnd4j/include/graph/execution/GraphExecutor.h @@ -75,7 +75,7 @@ namespace sd { * @param contextPrototype * @return */ - virtual Nd4jStatus execute(std::shared_ptr op, const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, VariableProxy &proxy, const int deviceId) const; + virtual Nd4jStatus execute(const std::shared_ptr &op, const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, VariableProxy &proxy, const int deviceId) const; }; } } diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index e81f35177d1b..ba461b011bf8 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -63,6 +63,11 @@ namespace sd { */ Nd4jStatus wait() const; + /** + * This method prints out content of the sequence + */ + void printOut() const; + /** * This method returns number of individual operations within this sequence * @return diff --git a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp index 34d7597bed8a..dfd0724b00c5 100644 --- a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp +++ b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp @@ -52,6 +52,14 @@ namespace sd { // } + void ExecutionTask::printOut() const { + if (_context.name().empty()) { + nd4j_printf("Node <%s/%i>:\n", "_", _context.nodeId()); + } else { + nd4j_printf("Node <%s/%i>:\n", _context.name().c_str(), _context.nodeId()); + } + } + ExecutionTask &ExecutionTask::operator=(ExecutionTask &&other) noexcept { if (this == &other) return *this; diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 94dc7b1cf10e..be4e4b3ff2e8 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -44,7 +44,7 @@ namespace sd { } - Nd4jStatus GraphExecutor::execute(std::shared_ptr op, const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, VariableProxy &proxy, const int deviceId) const { + Nd4jStatus GraphExecutor::execute(const std::shared_ptr &op, const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, VariableProxy &proxy, const int deviceId) const { auto ctx = prepareContext(contextPrototype, proxy, graph.memoryManager()); return op->execute(&ctx); //throw std::runtime_error("GraphExecutor::execute - Not implemented yet"); diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 4ef81fb7f202..72325564cdf2 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -60,12 +60,17 @@ namespace sd { return *this; _ops.clear(); - for (const auto v : other._ops) + for (const auto &v : other._ops) _ops.emplace_back(v); return *this; } + void OpSequence::printOut() const { + for (const auto &v: _ops) + v.printOut(); + } + int OpSequence::deviceId() const { return _deviceId; } diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 552c1f8dcd37..d70fe8c1a7f9 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -268,7 +268,12 @@ namespace sd { fflush(stdout); - throw std::runtime_error("Graph::printOut - not implemented yet"); + if (size() > 0) { + nd4j_printf("\nPrinting out Nodes...\n", ""); + + // since we need structure - we'll print out nodes of OptimizedGraph + optimizedGraph().printOut(); + } } Nd4jStatus Graph::validateNode(Node *node) { diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index ffa2a9fde569..5b16604a44c5 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -587,8 +587,8 @@ namespace sd { } } - if (node->dimensions() != nullptr && node->dimensions()->size() > 0) - throw std::runtime_error("FlatNode has dimensions defined. Graph is outdated"); + //if (node->dimensions() != nullptr && node->dimensions()->size() > 0) + // throw std::runtime_error("FlatNode has dimensions defined. Graph is outdated"); if (this->opType() == OpType_LOGIC && this->opNum() == 100L) { if (node->extraInteger()->size() < 1) { diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 8aa8742166d7..3f0e28fc5513 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -357,6 +357,11 @@ namespace sd { return true; } - + + void OptimizedGraph::printOut() const { + for (const auto &layer : _onion) + for (uint64_t l = 0; l < layer.second.width(); l++) + layer.second.at(l).printOut(); + } } } diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index e5d72afae220..8f465d25f249 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -75,7 +75,9 @@ TEST_F(OneOffTests, test_assert_scalar_float32_2) { sd::ops::noop op2; auto graph = Graph::fromFlatBuffers("./resources/assertsomething.fb"); - graph.execute(); + graph.printOut(); + + //graph.execute(); } From 88ddc7f0954c146ff13054b93e1a2e35fa7b3b9d Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 28 Apr 2020 10:54:38 +0300 Subject: [PATCH 114/233] printOut draft --- libnd4j/include/graph/ContextPrototype.h | 4 +-- .../graph/execution/impl/ExecutionTask.cpp | 27 +++++++++++++++++-- .../graph/execution/impl/OpSequence.cpp | 4 +-- .../include/graph/impl/ContextPrototype.cpp | 4 +-- libnd4j/include/graph/impl/Node.cpp | 4 +-- libnd4j/include/graph/impl/OptimizedGraph.cpp | 9 ++++--- 6 files changed, 39 insertions(+), 13 deletions(-) diff --git a/libnd4j/include/graph/ContextPrototype.h b/libnd4j/include/graph/ContextPrototype.h index 9b945b4c281f..99c62f3dfb8e 100644 --- a/libnd4j/include/graph/ContextPrototype.h +++ b/libnd4j/include/graph/ContextPrototype.h @@ -92,10 +92,10 @@ namespace sd { void pickInput(int input); void pickInput(int input, int index); - void pickInput(std::pair& p); + void pickInput(const std::pair& p); void fillInputs(std::initializer_list inputs); void fillInputs(std::vector& inputs); - std::vector>& inputs() const; + const std::vector>& inputs() const; const std::vector& getTArguments() const; const std::vector& getIArguments() const; diff --git a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp index dfd0724b00c5..81b696a6039f 100644 --- a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp +++ b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp @@ -54,10 +54,33 @@ namespace sd { void ExecutionTask::printOut() const { if (_context.name().empty()) { - nd4j_printf("Node <%s/%i>:\n", "_", _context.nodeId()); + if (_op != nullptr) + printf(" <%i:0>: {Op: %s}; ", _context.nodeId(), _op->getOpName().c_str()); + else + printf(" <%i:0>: ", _context.nodeId()); } else { - nd4j_printf("Node <%s/%i>:\n", _context.name().c_str(), _context.nodeId()); + printf(" <%s> <%i>: ", _context.name().c_str(), _context.nodeId()); } + + auto sz = _context.inputs().size(); + if (sz) { + printf(" Inputs: ["); + int cnt = 0; + for (const auto &v:_context.inputs()) { + printf("<%i:%i>", v.first, v.second); + + if (cnt < sz - 1) + printf(", "); + cnt++; + } + + printf("]; "); + } else { + printf(" No inputs; "); + } + + printf("\n"); + fflush(stdout); } ExecutionTask &ExecutionTask::operator=(ExecutionTask &&other) noexcept { diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 72325564cdf2..0d23ca97cd66 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -29,14 +29,14 @@ namespace sd { OpSequence::OpSequence(const std::vector &ops, const int deviceId) { _deviceId = deviceId; - for (const auto v : ops) + for (const auto &v : ops) _ops.emplace_back(v); } OpSequence::OpSequence(const OpSequence& other) noexcept{ _ops.clear(); - for (const auto v : other._ops) + for (const auto &v : other._ops) _ops.emplace_back(v); } diff --git a/libnd4j/include/graph/impl/ContextPrototype.cpp b/libnd4j/include/graph/impl/ContextPrototype.cpp index 3fa993bd1027..370a18fbd2a9 100644 --- a/libnd4j/include/graph/impl/ContextPrototype.cpp +++ b/libnd4j/include/graph/impl/ContextPrototype.cpp @@ -31,7 +31,7 @@ namespace sd { _opDescriptor = opDescriptor; } - void ContextPrototype::pickInput(std::pair& p) { + void ContextPrototype::pickInput(const std::pair& p) { this->_inputs.emplace_back(p); } @@ -48,7 +48,7 @@ namespace sd { this->_opNum = opNum; } - std::vector> & ContextPrototype::inputs() const { + const std::vector> & ContextPrototype::inputs() const { return const_cast> &>(_inputs); } diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 5b16604a44c5..662ecabd931d 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -658,7 +658,7 @@ namespace sd { ContextPrototype block(nullptr, this->id(), false); for (int e = 0; e < this->input().size(); e++) { - block.inputs().emplace_back(this->input().at(e)); + block.pickInput(this->input().at(e)); } // there's no other IArgs in legacy options, actually @@ -701,7 +701,7 @@ namespace sd { ContextPrototype block(nullptr, this->id()); for (int e = 0; e < this->input().size(); e++) { - block.inputs().emplace_back(this->input().at(e)); + block.pickInput(this->input().at(e)); } if (node->extraInteger() != nullptr) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 3f0e28fc5513..f24f8ee36ce8 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -359,9 +359,12 @@ namespace sd { void OptimizedGraph::printOut() const { - for (const auto &layer : _onion) - for (uint64_t l = 0; l < layer.second.width(); l++) - layer.second.at(l).printOut(); + for (uint64_t o = 0; o < _onion.size(); o++) { + const auto &layer = _onion.at(o); + printf("Layer [%lu]\n", o); + for (uint64_t l = 0; l < layer.width(); l++) + layer.at(l).printOut(); + } } } } From bb57cfe5f63c4df35f9f9a1bc65652ca10c63ca8 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 1 May 2020 07:36:15 +0300 Subject: [PATCH 115/233] minor fix --- libnd4j/include/legacy/NativeOps.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/legacy/NativeOps.h b/libnd4j/include/legacy/NativeOps.h index ec7642bad5ab..00d5cb2f4eea 100755 --- a/libnd4j/include/legacy/NativeOps.h +++ b/libnd4j/include/legacy/NativeOps.h @@ -1637,8 +1637,8 @@ SD_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc); SD_EXPORT OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth); SD_EXPORT OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth); -ND4J_EXPORT OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, Nd4jPointer primary, Nd4jPointer special); -ND4J_EXPORT OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset); +SD_EXPORT OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, Nd4jPointer primary, Nd4jPointer special); +SD_EXPORT OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset); SD_EXPORT Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer); SD_EXPORT Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer); SD_EXPORT void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements); From a99a45d38a73f9e04c516cdc0893174991a6e80a Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 11 May 2020 13:25:42 +0300 Subject: [PATCH 116/233] master merge Signed-off-by: raver119@gmail.com --- .../layers_tests/DeclarableOpsTests19.cpp | 24 +++++++++---------- .../layers_tests/DeclarableOpsTests6.cpp | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp index 641728ad3032..f8a084f43984 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -53,7 +53,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_1) { //encoded->printIndexedBuffer("ENC"); - ASSERT_EQ(exp_encoded, *encoded); + ASSERT_EQ(exp_encoded, encoded); ASSERT_EQ(exp_gradients, x); // FIXME: we need to add a way to declare individual inplace outputs @@ -75,7 +75,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_2) { auto encoded = result.at(1); - ASSERT_EQ(length + 4, encoded->lengthOf()); + ASSERT_EQ(length + 4, encoded.lengthOf()); ASSERT_EQ(exp_gradients, x); } } @@ -90,7 +90,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_boundary_1) { auto gradients = result.at(0); auto encoded = result.at(1); - ASSERT_EQ(7, encoded->lengthOf()); + ASSERT_EQ(7, encoded.lengthOf()); ASSERT_EQ(3, x.sumNumber().e(0)); } @@ -104,7 +104,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_boundary_2) { auto gradients = result.at(0); auto encoded = result.at(1); - ASSERT_EQ(104, encoded->lengthOf()); + ASSERT_EQ(104, encoded.lengthOf()); ASSERT_EQ(900, x.sumNumber().e(0)); } @@ -138,10 +138,10 @@ TEST_F(DeclarableOpsTests19, test_bitmap_encode_1) { //encoded->printIndexedBuffer("encoded"); - ASSERT_EQ(exp_c, *counter); + ASSERT_EQ(exp_c, counter); sd::ops::decode_bitmap dec; - auto status = dec.execute({&initial, encoded}, {&initial}); + auto status = dec.execute({&initial, &encoded}, {&initial}); ASSERT_EQ(Status::OK(), status); @@ -162,16 +162,16 @@ TEST_F(DeclarableOpsTests19, test_bitmap_encode_decode) { auto encoded = enc_result.at(1); // checking equality of all encoded bits - for (int e = 5; e < encoded->lengthOf() - 1; e++) { - if (encoded->e(e) != encoded->e(e - 1)) - nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, encoded->e(e)); + for (int e = 5; e < encoded.lengthOf() - 1; e++) { + if (encoded.e(e) != encoded.e(e - 1)) + nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, encoded.e(e)); } ASSERT_NE(exp, initial); ASSERT_EQ(neg, initial); sd::ops::decode_bitmap dec; - auto status = dec.execute({&initial, encoded}, {&initial}); + auto status = dec.execute({&initial, &encoded}, {&initial}); ASSERT_EQ(Status::OK(), status); // checking equality of all dedoded bits @@ -196,7 +196,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) { auto enc_result = enc.evaluate({&initial}, {0.5f}); auto encoded = enc_result.at(1); - ASSERT_EQ(256000 + 4, encoded->lengthOf()); + ASSERT_EQ(256000 + 4, encoded.lengthOf()); ASSERT_NE(exp, initial); for (int e = 0; e < initial.lengthOf(); e++) { @@ -215,7 +215,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) { //} sd::ops::decode_threshold dec; - auto status = dec.execute({&initial, encoded}, {&initial}); + auto status = dec.execute({&initial, &encoded}, {&initial}); ASSERT_EQ(Status::OK(), status); // checking equality of all dedoded bits diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index afcff6d40f93..a8451e1c4026 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -158,7 +158,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) { block->appendI(0); block->appendI(0); - auto inputShapes = new ShapeList({ones.getShapeInfo(), b.getShapeInfo(), e.getShapeInfo(), s.getShapeInfo()}); + auto inputShapes = new ShapeList({ones.shapeInfo(), b.shapeInfo(), e.shapeInfo(), s.shapeInfo()}); sd::ops::strided_slice op; From 670e57c82f19b042349a7e4f015a7620937184d8 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 11 May 2020 14:32:03 +0300 Subject: [PATCH 117/233] Google formatting enforced Signed-off-by: raver119@gmail.com --- libnd4j/include/array/ArrayOptions.h | 502 +- libnd4j/include/array/ArrayType.h | 14 +- libnd4j/include/array/ByteOrder.h | 10 +- libnd4j/include/array/ByteOrderUtils.h | 16 +- libnd4j/include/array/ConstantDataBuffer.h | 65 +- libnd4j/include/array/ConstantDescriptor.h | 67 +- libnd4j/include/array/ConstantHolder.h | 53 +- libnd4j/include/array/DataBuffer.h | 249 +- libnd4j/include/array/DataType.h | 50 +- libnd4j/include/array/DataTypeConversions.h | 282 +- libnd4j/include/array/DataTypeUtils.h | 658 +- libnd4j/include/array/ExtraArguments.h | 54 +- libnd4j/include/array/InteropDataBuffer.h | 77 +- libnd4j/include/array/ManagedDataBuffer.h | 43 +- libnd4j/include/array/NDArray.h | 3686 +++--- libnd4j/include/array/NDArrayFactory.h | 484 +- libnd4j/include/array/NDArrayLambda.hXX | 563 +- libnd4j/include/array/NDArrayList.h | 114 +- libnd4j/include/array/ResultSet.h | 61 +- libnd4j/include/array/ShapeDescriptor.h | 145 +- libnd4j/include/array/ShapeList.h | 64 +- libnd4j/include/array/SpaceType.h | 10 +- libnd4j/include/array/SparseType.h | 14 +- libnd4j/include/array/TadDescriptor.h | 70 +- libnd4j/include/array/TadPack.h | 64 +- libnd4j/include/array/cpu/DataBuffer.cpp | 151 +- .../include/array/cpu/ManagedDataBuffer.cpp | 10 +- libnd4j/include/array/cpu/NDArray.cpp | 660 +- libnd4j/include/array/cpu/NDArrayLambda.hpp | 725 +- libnd4j/include/array/cuda/DataBuffer.cu | 455 +- libnd4j/include/array/cuda/NDArray.cu | 957 +- libnd4j/include/array/impl/ByteOrderUtils.cpp | 9 +- .../include/array/impl/ConstantDataBuffer.cpp | 75 +- .../include/array/impl/ConstantDescriptor.cpp | 116 +- libnd4j/include/array/impl/ConstantHolder.cpp | 76 +- libnd4j/include/array/impl/DataBuffer.cpp | 457 +- libnd4j/include/array/impl/DataTypeUtils.cpp | 16 +- libnd4j/include/array/impl/ExtraArguments.cpp | 167 +- .../include/array/impl/InteropDataBuffer.cpp | 243 +- .../include/array/impl/ManagedDataBuffer.cpp | 28 +- libnd4j/include/array/impl/NDArray.cpp | 9465 ++++++++------ libnd4j/include/array/impl/NDArrayFactory.cpp | 1622 ++- libnd4j/include/array/impl/NDArrayList.cpp | 464 +- libnd4j/include/array/impl/ResultSet.cpp | 178 +- .../include/array/impl/ShapeDescriptor.cpp | 561 +- libnd4j/include/array/impl/ShapeList.cpp | 86 +- libnd4j/include/array/impl/TadDescriptor.cpp | 109 +- libnd4j/include/array/impl/TadPack.cpp | 80 +- libnd4j/include/cblas.h | 872 +- libnd4j/include/cblas_enum_conversion.h | 3 +- libnd4j/include/cnpy/cnpy.h | 439 +- .../include/exceptions/allocation_exception.h | 36 +- libnd4j/include/exceptions/cuda_exception.h | 28 +- .../include/exceptions/datatype_exception.h | 40 +- libnd4j/include/exceptions/graph_exception.h | 55 +- .../exceptions/graph_execution_exception.h | 27 +- .../exceptions/graph_exists_exception.h | 22 +- .../exceptions/impl/allocation_exception.cpp | 34 +- .../exceptions/impl/cuda_exception.cpp | 19 +- .../exceptions/impl/datatype_exception.cpp | 52 +- .../exceptions/impl/graph_exception.cpp | 56 +- .../impl/graph_execution_exception.cpp | 19 +- .../impl/graph_exists_exception.cpp | 11 +- .../exceptions/impl/no_results_exception.cpp | 11 +- .../impl/shape_mismatch_exception.cpp | 23 +- .../impl/unknown_graph_exception.cpp | 11 +- .../include/exceptions/no_results_exception.h | 22 +- .../exceptions/shape_mismatch_exception.h | 34 +- .../exceptions/unknown_graph_exception.h | 22 +- libnd4j/include/execution/AffinityManager.h | 31 +- libnd4j/include/execution/BlockingQueue.h | 51 +- libnd4j/include/execution/CallableInterface.h | 139 +- .../include/execution/CallableWithArguments.h | 135 +- libnd4j/include/execution/ContextBuffers.h | 73 +- libnd4j/include/execution/Engine.h | 10 +- libnd4j/include/execution/ErrorReference.h | 33 +- libnd4j/include/execution/ExecutionMode.h | 12 +- libnd4j/include/execution/Executor.h | 16 +- libnd4j/include/execution/LaunchContext.h | 143 +- libnd4j/include/execution/ThreadPool.h | 74 +- libnd4j/include/execution/Threads.h | 356 +- libnd4j/include/execution/Ticket.h | 66 +- .../include/execution/cpu/AffinityManager.cpp | 26 +- .../include/execution/cpu/ContextBuffers.cpp | 153 +- .../include/execution/cpu/LaunchContext.cpp | 89 +- .../include/execution/cuda/AffinityManager.cu | 160 +- .../include/execution/cuda/ContextBuffers.cu | 335 +- .../include/execution/cuda/LaunchContext.cu | 247 +- .../include/execution/impl/BlockingQueue.cpp | 81 +- .../execution/impl/CallableInterface.cpp | 387 +- .../execution/impl/CallableWithArguments.cpp | 152 +- .../include/execution/impl/ErrorReference.cpp | 32 +- libnd4j/include/execution/impl/ThreadPool.cpp | 293 +- libnd4j/include/execution/impl/Threads.cpp | 1370 +- libnd4j/include/execution/impl/Ticket.cpp | 152 +- libnd4j/include/graph/ArgumentsList.h | 64 +- libnd4j/include/graph/Context.h | 384 +- libnd4j/include/graph/ContextPrototype.h | 243 +- libnd4j/include/graph/ExecutionResult.h | 125 +- libnd4j/include/graph/ExecutorConfiguration.h | 44 +- libnd4j/include/graph/FlatUtils.h | 30 +- libnd4j/include/graph/FlowPath.h | 87 +- libnd4j/include/graph/FrameState.h | 164 +- libnd4j/include/graph/Graph.h | 258 +- libnd4j/include/graph/GraphHolder.h | 52 +- libnd4j/include/graph/GraphUtils.h | 21 +- libnd4j/include/graph/InferenceRequest.h | 50 +- libnd4j/include/graph/Intervals.h | 41 +- libnd4j/include/graph/Node.h | 417 +- libnd4j/include/graph/NodeState.h | 64 +- libnd4j/include/graph/OptimizedGraph.h | 351 +- libnd4j/include/graph/RandomGenerator.h | 471 +- libnd4j/include/graph/RandomGenerator.hpp | 33 +- libnd4j/include/graph/ResultWrapper.h | 31 +- libnd4j/include/graph/Scope.h | 160 +- libnd4j/include/graph/Stash.h | 90 +- libnd4j/include/graph/Status.h | 40 +- libnd4j/include/graph/TimeHolder.h | 49 +- libnd4j/include/graph/Variable.h | 177 +- libnd4j/include/graph/VariableProxy.h | 134 +- libnd4j/include/graph/VariableSpace.h | 162 +- libnd4j/include/graph/VariableType.h | 18 +- libnd4j/include/graph/VariablesSet.h | 42 +- .../impl/unresolved_input_exception.cpp | 52 +- .../impl/unresolved_output_exception.cpp | 48 +- .../exceptions/unresolved_input_exception.h | 32 +- .../exceptions/unresolved_output_exception.h | 33 +- .../include/graph/execution/ExecutionLayer.h | 71 +- .../include/graph/execution/ExecutionTask.h | 49 +- .../include/graph/execution/GraphExecutor.h | 110 +- libnd4j/include/graph/execution/OpSequence.h | 190 +- .../graph/execution/impl/ExecutionLayer.cpp | 70 +- .../graph/execution/impl/ExecutionTask.cpp | 140 +- .../graph/execution/impl/GraphExecutor.cpp | 129 +- .../graph/execution/impl/OpSequence.cpp | 215 +- .../include/graph/generated/array_generated.h | 146 +- .../graph/generated/config_generated.h | 130 +- .../include/graph/generated/graph.grpc.fb.h | 626 +- .../include/graph/generated/graph_generated.h | 240 +- .../include/graph/generated/node_generated.h | 230 +- .../graph/generated/properties_generated.h | 70 +- .../graph/generated/request_generated.h | 64 +- .../graph/generated/result_generated.h | 98 +- .../graph/generated/uigraphevents_generated.h | 367 +- .../graph/generated/uigraphstatic_generated.h | 374 +- .../include/graph/generated/utils_generated.h | 432 +- .../graph/generated/variable_generated.h | 149 +- libnd4j/include/graph/impl/ArgumentsList.cpp | 33 +- libnd4j/include/graph/impl/Context.cpp | 932 +- .../include/graph/impl/ContextPrototype.cpp | 528 +- .../include/graph/impl/ExecutionResult.cpp | 163 +- .../graph/impl/ExecutorConfiguration.cpp | 55 +- libnd4j/include/graph/impl/FlatUtils.cpp | 185 +- libnd4j/include/graph/impl/FlowPath.cpp | 181 +- libnd4j/include/graph/impl/FrameState.cpp | 58 +- libnd4j/include/graph/impl/Graph.cpp | 1155 +- libnd4j/include/graph/impl/GraphHolder.cpp | 150 +- libnd4j/include/graph/impl/GraphUtils.cpp | 243 +- .../include/graph/impl/InferenceRequest.cpp | 93 +- libnd4j/include/graph/impl/Intervals.cpp | 51 +- libnd4j/include/graph/impl/Node.cpp | 1708 ++- libnd4j/include/graph/impl/NodeState.cpp | 50 +- libnd4j/include/graph/impl/OptimizedGraph.cpp | 637 +- .../include/graph/impl/RandomGenerator.cpp | 67 +- libnd4j/include/graph/impl/ResultWrapper.cpp | 48 +- libnd4j/include/graph/impl/Scope.cpp | 63 +- libnd4j/include/graph/impl/Stash.cpp | 104 +- libnd4j/include/graph/impl/TimeHolder.cpp | 34 +- libnd4j/include/graph/impl/Variable.cpp | 557 +- libnd4j/include/graph/impl/VariableProxy.cpp | 425 +- libnd4j/include/graph/impl/VariableSpace.cpp | 628 +- libnd4j/include/graph/impl/VariablesSet.cpp | 43 +- .../include/graph/logic/LogicConditional.h | 44 +- libnd4j/include/graph/logic/LogicEnter.h | 20 +- libnd4j/include/graph/logic/LogicExecutor.h | 31 +- libnd4j/include/graph/logic/LogicExit.h | 20 +- libnd4j/include/graph/logic/LogicExpose.h | 22 +- libnd4j/include/graph/logic/LogicLoopCond.h | 20 +- libnd4j/include/graph/logic/LogicMerge.h | 20 +- .../include/graph/logic/LogicNextIteration.h | 20 +- libnd4j/include/graph/logic/LogicReturn.h | 38 +- libnd4j/include/graph/logic/LogicScope.h | 35 +- libnd4j/include/graph/logic/LogicSwitch.h | 35 +- libnd4j/include/graph/logic/LogicWhile.h | 34 +- .../graph/logic/impl/LogicConditional.cpp | 229 +- .../include/graph/logic/impl/LogicEnter.cpp | 95 +- .../graph/logic/impl/LogicExecutor.cpp | 87 +- .../include/graph/logic/impl/LogicExit.cpp | 52 +- .../include/graph/logic/impl/LogicExpose.cpp | 14 +- .../graph/logic/impl/LogicLoopCond.cpp | 68 +- .../include/graph/logic/impl/LogicMerge.cpp | 163 +- .../graph/logic/impl/LogicNextIteration.cpp | 44 +- .../include/graph/logic/impl/LogicReturn.cpp | 55 +- .../include/graph/logic/impl/LogicScope.cpp | 19 +- .../include/graph/logic/impl/LogicSwitch.cpp | 178 +- .../include/graph/logic/impl/LogicWhile.cpp | 244 +- .../graph/optimization/GraphOptimizer.h | 27 +- .../graph/optimization/NodeOptimizer.h | 63 +- .../optimization/impl/GraphOptimizer.cpp | 16 +- .../graph/optimization/impl/NodeOptimizer.cpp | 10 +- .../include/graph/profiling/GraphProfile.h | 187 +- .../graph/profiling/GraphProfilingHelper.h | 18 +- libnd4j/include/graph/profiling/NodeProfile.h | 122 +- .../graph/profiling/impl/GraphProfile.cpp | 365 +- .../profiling/impl/GraphProfilingHelper.cpp | 77 +- .../graph/profiling/impl/NodeProfile.cpp | 255 +- libnd4j/include/helpers/ArrayUtils.h | 26 +- libnd4j/include/helpers/AttentionHelper.h | 20 +- libnd4j/include/helpers/BenchmarkHelper.h | 162 +- libnd4j/include/helpers/BitwiseUtils.h | 175 +- libnd4j/include/helpers/BlasHelper.h | 685 +- libnd4j/include/helpers/ConstantHelper.h | 60 +- libnd4j/include/helpers/ConstantShapeHelper.h | 142 +- libnd4j/include/helpers/ConstantTadHelper.h | 126 +- libnd4j/include/helpers/CudaLaunchHelper.h | 19 +- libnd4j/include/helpers/DebugHelper.h | 84 +- libnd4j/include/helpers/DebugInfo.h | 62 +- libnd4j/include/helpers/EnumUtils.h | 15 +- libnd4j/include/helpers/FileUtils.h | 18 +- libnd4j/include/helpers/GradCheck.h | 72 +- libnd4j/include/helpers/LoopKind.h | 451 +- libnd4j/include/helpers/Loops.h | 2455 ++-- libnd4j/include/helpers/Loops.hpp | 25 +- libnd4j/include/helpers/LoopsCoordsHelper.h | 804 +- libnd4j/include/helpers/MKLDNNStream.h | 81 +- libnd4j/include/helpers/MmulHelper.h | 96 +- libnd4j/include/helpers/OmpLaunchHelper.h | 59 +- libnd4j/include/helpers/OpArgsHolder.h | 94 +- libnd4j/include/helpers/OpBenchmark.h | 83 +- libnd4j/include/helpers/OpTracker.h | 50 +- libnd4j/include/helpers/PointersManager.h | 64 +- libnd4j/include/helpers/RandomLauncher.h | 54 +- libnd4j/include/helpers/ShapeBuilders.h | 97 +- libnd4j/include/helpers/ShapeUtils.h | 468 +- libnd4j/include/helpers/SimpleReadWriteLock.h | 58 +- libnd4j/include/helpers/StringUtils.h | 232 +- libnd4j/include/helpers/TAD.h | 1799 +-- libnd4j/include/helpers/benchmark/BasicSuit.h | 14 +- .../helpers/benchmark/BoolParameters.h | 32 +- .../helpers/benchmark/BroadcastBenchmark.h | 186 +- .../helpers/benchmark/DeclarableBenchmark.h | 270 +- .../include/helpers/benchmark/IntParameters.h | 54 +- .../helpers/benchmark/IntPowerParameters.h | 59 +- .../helpers/benchmark/MatrixBenchmark.h | 175 +- .../helpers/benchmark/PairwiseBenchmark.h | 141 +- .../include/helpers/benchmark/Parameters.h | 54 +- .../helpers/benchmark/ParametersBatch.h | 111 +- .../helpers/benchmark/ParametersSpace.h | 31 +- .../helpers/benchmark/PredefinedParameters.h | 41 +- .../helpers/benchmark/ReductionBenchmark.h | 279 +- .../helpers/benchmark/ScalarBenchmark.h | 139 +- .../helpers/benchmark/TransformBenchmark.h | 222 +- libnd4j/include/helpers/biDiagonalUp.h | 85 +- .../include/helpers/cpu/ConstantHelper.cpp | 176 +- .../helpers/cpu/ConstantShapeHelper.cpp | 293 +- .../include/helpers/cpu/ConstantTadHelper.cpp | 180 +- libnd4j/include/helpers/cpu/MmulHelper.cpp | 1005 +- .../include/helpers/cpu/PointersManager.cpp | 22 +- libnd4j/include/helpers/cpu/biDiagonalUp.cpp | 225 +- libnd4j/include/helpers/cpu/cublasHelper.cpp | 42 +- libnd4j/include/helpers/cpu/hhColPivQR.cpp | 245 +- libnd4j/include/helpers/cpu/hhSequence.cpp | 135 +- libnd4j/include/helpers/cpu/householder.cpp | 324 +- libnd4j/include/helpers/cpu/jacobiSVD.cpp | 641 +- .../helpers/cpu/loops/IndexReductionLoops.hpp | 536 +- .../cpu/loops/IndexReductionLoops_int32_0.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int32_1.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int32_2.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int32_3.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int32_4.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int32_5.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int32_6.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int32_7.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int32_8.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int32_9.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int64_0.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int64_1.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int64_2.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int64_3.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int64_4.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int64_5.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int64_6.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int64_7.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int64_8.cpp | 9 +- .../cpu/loops/IndexReductionLoops_int64_9.cpp | 9 +- .../helpers/cpu/loops/Reduction3Loops_0.cpp | 70 +- .../helpers/cpu/loops/Reduction3Loops_1.cpp | 70 +- .../helpers/cpu/loops/Reduction3Loops_2.cpp | 70 +- .../helpers/cpu/loops/Reduction3Loops_3.cpp | 70 +- .../helpers/cpu/loops/ReductionLoops_bool.cpp | 37 +- .../cpu/loops/ReductionLoops_float_0.cpp | 43 +- .../cpu/loops/ReductionLoops_float_1.cpp | 44 +- .../cpu/loops/ReductionLoops_float_2.cpp | 39 +- .../cpu/loops/ReductionLoops_float_3.cpp | 39 +- .../helpers/cpu/loops/ReductionLoops_long.cpp | 40 +- .../helpers/cpu/loops/ReductionLoops_same.cpp | 46 +- libnd4j/include/helpers/cpu/svd.cpp | 1657 +-- libnd4j/include/helpers/cublasHelper.h | 40 +- .../include/helpers/cuda/ConstantHelper.cu | 314 +- .../helpers/cuda/ConstantShapeHelper.cu | 299 +- .../include/helpers/cuda/ConstantTadHelper.cu | 197 +- .../include/helpers/cuda/PointersManager.cu | 166 +- .../include/helpers/cuda_off/MmulHelper.cu | 1383 +- .../include/helpers/cuda_off/cublasHelper.cu | 206 +- libnd4j/include/helpers/data_gen.h | 17 +- libnd4j/include/helpers/files.h | 143 +- libnd4j/include/helpers/helper_generator.h | 1048 +- libnd4j/include/helpers/helper_hash.h | 46 +- libnd4j/include/helpers/helper_ptrmap.h | 372 +- libnd4j/include/helpers/helper_random.h | 386 +- libnd4j/include/helpers/hhColPivQR.h | 35 +- libnd4j/include/helpers/hhSequence.h | 119 +- libnd4j/include/helpers/householder.h | 172 +- libnd4j/include/helpers/impl/ArrayUtils.cpp | 63 +- .../include/helpers/impl/AttentionHelper.cpp | 119 +- .../include/helpers/impl/BenchmarkHelper.cpp | 1361 +- libnd4j/include/helpers/impl/BitwiseUtils.cpp | 73 +- libnd4j/include/helpers/impl/BlasHelper.cpp | 582 +- .../include/helpers/impl/CudaLaunchHelper.cpp | 26 +- libnd4j/include/helpers/impl/DebugHelper.cpp | 157 +- libnd4j/include/helpers/impl/EnumUtils.cpp | 136 +- libnd4j/include/helpers/impl/FileUtils.cpp | 23 +- libnd4j/include/helpers/impl/GradCheck.cpp | 264 +- libnd4j/include/helpers/impl/MmulHelper.cpp | 593 +- .../include/helpers/impl/OmpLaunchHelper.cpp | 141 +- libnd4j/include/helpers/impl/OpArgsHolder.cpp | 171 +- libnd4j/include/helpers/impl/OpBenchmark.cpp | 178 +- libnd4j/include/helpers/impl/OpTracker.cpp | 151 +- libnd4j/include/helpers/impl/Parameters.cpp | 115 +- .../include/helpers/impl/RandomLauncher.cpp | 251 +- .../include/helpers/impl/ShapeBuilders.cpp | 221 +- libnd4j/include/helpers/impl/ShapeUtils.cpp | 1903 +-- .../helpers/impl/SimpleReadWriteLock.cpp | 74 +- libnd4j/include/helpers/impl/StringUtils.cpp | 264 +- libnd4j/include/helpers/impl/TAD.cpp | 5 +- libnd4j/include/helpers/impl/helper_hash.cpp | 71 +- libnd4j/include/helpers/impl/logger.cpp | 70 +- libnd4j/include/helpers/impl/shape.cpp | 5 +- libnd4j/include/helpers/impl/unicode.cpp | 790 +- libnd4j/include/helpers/jacobiSVD.h | 58 +- libnd4j/include/helpers/logger.h | 45 +- libnd4j/include/helpers/mman.h | 355 +- libnd4j/include/helpers/shape.h | 6894 +++++----- libnd4j/include/helpers/svd.h | 96 +- libnd4j/include/helpers/threshold.h | 3 +- libnd4j/include/helpers/unicode.h | 321 +- libnd4j/include/indexing/IndicesList.h | 34 +- libnd4j/include/indexing/NDIndex.h | 77 +- libnd4j/include/indexing/impl/IndicesList.cpp | 28 +- libnd4j/include/indexing/impl/NDIndex.cpp | 77 +- libnd4j/include/legacy/NativeOpExecutioner.h | 1182 +- libnd4j/include/legacy/NativeOps.h | 1495 +-- .../legacy/cpu/NativeOpExecutioner.cpp | 2422 ++-- libnd4j/include/legacy/cpu/NativeOps.cpp | 4508 ++++--- .../include/legacy/cuda/BlasVersionHelper.cu | 12 +- .../legacy/cuda/NativeOpExecutioner.cu | 3106 ++--- libnd4j/include/legacy/cuda/NativeOps.cu | 6427 +++++----- libnd4j/include/legacy/impl/Environment.cpp | 551 +- libnd4j/include/legacy/impl/cnpy.cpp | 1095 +- .../loops/BroadcastPairwiseConverter.h | 163 +- .../include/loops/BroadcastScalarConverter.h | 71 +- libnd4j/include/loops/ReduceType.h | 12 +- libnd4j/include/loops/broadcasting.h | 270 +- libnd4j/include/loops/broadcasting_bool.h | 271 +- libnd4j/include/loops/broadcasting_int.h | 261 +- libnd4j/include/loops/cpu/broadcasting.hpp | 1612 +-- .../include/loops/cpu/broadcasting_bool.hpp | 1395 +- .../include/loops/cpu/broadcasting_int.hpp | 1355 +- .../compilation_units/broadcast_bool_p0.cpp | 9 +- .../compilation_units/broadcast_bool_p1.cpp | 9 +- .../compilation_units/broadcast_bool_p2.cpp | 9 +- .../compilation_units/broadcast_bool_p3.cpp | 9 +- .../compilation_units/broadcast_bool_p4.cpp | 9 +- .../compilation_units/broadcast_bool_p5.cpp | 9 +- .../compilation_units/broadcast_bool_p6.cpp | 9 +- .../compilation_units/broadcast_bool_p7.cpp | 9 +- .../compilation_units/broadcast_bool_p8.cpp | 9 +- .../compilation_units/broadcast_bool_p9.cpp | 9 +- .../compilation_units/broadcast_int_p0.cpp | 8 +- .../compilation_units/broadcast_int_p1.cpp | 8 +- .../compilation_units/broadcast_int_p2.cpp | 8 +- .../compilation_units/broadcast_int_p3.cpp | 8 +- .../compilation_units/broadcast_int_p4.cpp | 8 +- .../compilation_units/broadcast_int_p5.cpp | 8 +- .../compilation_units/broadcast_int_p6.cpp | 8 +- .../compilation_units/broadcast_int_p7.cpp | 8 +- .../cpu/compilation_units/broadcast_p0.cpp | 8 +- .../cpu/compilation_units/broadcast_p1.cpp | 8 +- .../cpu/compilation_units/broadcast_p10.cpp | 9 +- .../cpu/compilation_units/broadcast_p11.cpp | 9 +- .../cpu/compilation_units/broadcast_p12.cpp | 9 +- .../cpu/compilation_units/broadcast_p2.cpp | 8 +- .../cpu/compilation_units/broadcast_p3.cpp | 8 +- .../cpu/compilation_units/broadcast_p4.cpp | 8 +- .../cpu/compilation_units/broadcast_p5.cpp | 8 +- .../cpu/compilation_units/broadcast_p6.cpp | 8 +- .../cpu/compilation_units/broadcast_p7.cpp | 8 +- .../cpu/compilation_units/broadcast_p8.cpp | 8 +- .../cpu/compilation_units/broadcast_p9.cpp | 8 +- .../compilation_units/indexreduce_int32_0.cpp | 9 +- .../compilation_units/indexreduce_int32_1.cpp | 9 +- .../compilation_units/indexreduce_int32_2.cpp | 9 +- .../compilation_units/indexreduce_int32_3.cpp | 9 +- .../compilation_units/indexreduce_int32_4.cpp | 9 +- .../compilation_units/indexreduce_int32_5.cpp | 9 +- .../compilation_units/indexreduce_int32_6.cpp | 9 +- .../compilation_units/indexreduce_int32_7.cpp | 9 +- .../compilation_units/indexreduce_int32_8.cpp | 9 +- .../compilation_units/indexreduce_int32_9.cpp | 9 +- .../compilation_units/indexreduce_int64_0.cpp | 9 +- .../compilation_units/indexreduce_int64_1.cpp | 9 +- .../compilation_units/indexreduce_int64_2.cpp | 9 +- .../compilation_units/indexreduce_int64_3.cpp | 9 +- .../compilation_units/indexreduce_int64_4.cpp | 9 +- .../compilation_units/indexreduce_int64_5.cpp | 9 +- .../compilation_units/indexreduce_int64_6.cpp | 9 +- .../compilation_units/indexreduce_int64_7.cpp | 9 +- .../compilation_units/indexreduce_int64_8.cpp | 9 +- .../compilation_units/indexreduce_int64_9.cpp | 9 +- .../cpu/compilation_units/pairwise_p0.cpp | 9 +- .../cpu/compilation_units/pairwise_p1.cpp | 9 +- .../cpu/compilation_units/pairwise_p10.cpp | 9 +- .../cpu/compilation_units/pairwise_p11.cpp | 9 +- .../cpu/compilation_units/pairwise_p12.cpp | 9 +- .../cpu/compilation_units/pairwise_p2.cpp | 9 +- .../cpu/compilation_units/pairwise_p3.cpp | 9 +- .../cpu/compilation_units/pairwise_p4.cpp | 9 +- .../cpu/compilation_units/pairwise_p5.cpp | 9 +- .../cpu/compilation_units/pairwise_p6.cpp | 9 +- .../cpu/compilation_units/pairwise_p7.cpp | 9 +- .../cpu/compilation_units/pairwise_p8.cpp | 9 +- .../cpu/compilation_units/pairwise_p9.cpp | 9 +- .../loops/cpu/compilation_units/random_0.cpp | 8 +- .../loops/cpu/compilation_units/random_1.cpp | 8 +- .../loops/cpu/compilation_units/random_2.cpp | 8 +- .../loops/cpu/compilation_units/random_3.cpp | 8 +- .../compilation_units/reduce3_bfloat16_0.cpp | 9 +- .../compilation_units/reduce3_bfloat16_1.cpp | 9 +- .../compilation_units/reduce3_bfloat16_2.cpp | 9 +- .../compilation_units/reduce3_bfloat16_3.cpp | 9 +- .../compilation_units/reduce3_bfloat16_4.cpp | 9 +- .../compilation_units/reduce3_bfloat16_5.cpp | 9 +- .../compilation_units/reduce3_bfloat16_6.cpp | 9 +- .../compilation_units/reduce3_bfloat16_7.cpp | 9 +- .../compilation_units/reduce3_bfloat16_8.cpp | 9 +- .../compilation_units/reduce3_bfloat16_9.cpp | 9 +- .../compilation_units/reduce3_double_0.cpp | 9 +- .../compilation_units/reduce3_double_1.cpp | 9 +- .../compilation_units/reduce3_double_2.cpp | 9 +- .../compilation_units/reduce3_double_3.cpp | 9 +- .../compilation_units/reduce3_double_4.cpp | 9 +- .../compilation_units/reduce3_double_5.cpp | 9 +- .../compilation_units/reduce3_double_6.cpp | 9 +- .../compilation_units/reduce3_double_7.cpp | 9 +- .../compilation_units/reduce3_double_8.cpp | 9 +- .../compilation_units/reduce3_double_9.cpp | 9 +- .../compilation_units/reduce3_float16_0.cpp | 9 +- .../compilation_units/reduce3_float16_1.cpp | 9 +- .../compilation_units/reduce3_float16_2.cpp | 9 +- .../compilation_units/reduce3_float16_3.cpp | 9 +- .../compilation_units/reduce3_float16_4.cpp | 9 +- .../compilation_units/reduce3_float16_5.cpp | 9 +- .../compilation_units/reduce3_float16_6.cpp | 9 +- .../compilation_units/reduce3_float16_7.cpp | 9 +- .../compilation_units/reduce3_float16_8.cpp | 9 +- .../compilation_units/reduce3_float16_9.cpp | 9 +- .../cpu/compilation_units/reduce3_float_0.cpp | 9 +- .../cpu/compilation_units/reduce3_float_1.cpp | 9 +- .../cpu/compilation_units/reduce3_float_2.cpp | 9 +- .../cpu/compilation_units/reduce3_float_3.cpp | 9 +- .../cpu/compilation_units/reduce3_float_4.cpp | 9 +- .../cpu/compilation_units/reduce3_float_5.cpp | 9 +- .../cpu/compilation_units/reduce3_float_6.cpp | 9 +- .../cpu/compilation_units/reduce3_float_7.cpp | 9 +- .../cpu/compilation_units/reduce3_float_8.cpp | 9 +- .../cpu/compilation_units/reduce3_float_9.cpp | 9 +- .../cpu/compilation_units/reduce_float_0.cpp | 7 +- .../cpu/compilation_units/reduce_float_1.cpp | 7 +- .../cpu/compilation_units/reduce_float_2.cpp | 7 +- .../cpu/compilation_units/reduce_float_3.cpp | 7 +- .../loops/cpu/compilation_units/scalar_p0.cpp | 9 +- .../loops/cpu/compilation_units/scalar_p1.cpp | 9 +- .../cpu/compilation_units/scalar_p10.cpp | 9 +- .../cpu/compilation_units/scalar_p11.cpp | 9 +- .../cpu/compilation_units/scalar_p12.cpp | 9 +- .../loops/cpu/compilation_units/scalar_p2.cpp | 9 +- .../loops/cpu/compilation_units/scalar_p3.cpp | 9 +- .../loops/cpu/compilation_units/scalar_p4.cpp | 9 +- .../loops/cpu/compilation_units/scalar_p5.cpp | 9 +- .../loops/cpu/compilation_units/scalar_p6.cpp | 9 +- .../loops/cpu/compilation_units/scalar_p7.cpp | 9 +- .../loops/cpu/compilation_units/scalar_p8.cpp | 9 +- .../loops/cpu/compilation_units/scalar_p9.cpp | 9 +- libnd4j/include/loops/cpu/indexreduce.hpp | 200 +- libnd4j/include/loops/cpu/pairwise.hpp | 397 +- libnd4j/include/loops/cpu/pairwise_bool.cpp | 390 +- libnd4j/include/loops/cpu/pairwise_int.cpp | 393 +- libnd4j/include/loops/cpu/random.hpp | 531 +- .../include/loops/cpu/reduce/reduce_bool.cpp | 414 +- .../include/loops/cpu/reduce/reduce_float.hpp | 476 +- .../include/loops/cpu/reduce/reduce_long.cpp | 454 +- .../include/loops/cpu/reduce/reduce_same.cpp | 466 +- libnd4j/include/loops/cpu/reduce3.hpp | 416 +- libnd4j/include/loops/cpu/scalar.hpp | 325 +- libnd4j/include/loops/cpu/scalar_bool.cpp | 352 +- libnd4j/include/loops/cpu/scalar_int.cpp | 352 +- .../include/loops/cpu/summarystatsreduce.cpp | 318 +- .../loops/cpu/transform/transform_any.cpp | 60 +- .../loops/cpu/transform/transform_bool.cpp | 66 +- .../loops/cpu/transform/transform_float.cpp | 64 +- .../loops/cpu/transform/transform_same.cpp | 65 +- .../loops/cpu/transform/transform_strict.cpp | 64 +- libnd4j/include/loops/cuda/broadcasting.chpp | 553 +- libnd4j/include/loops/cuda/broadcasting.cu | 21 +- .../include/loops/cuda/broadcasting_bool.cu | 518 +- .../include/loops/cuda/broadcasting_int.cu | 480 +- .../broadcasting/broadcasting_0.cu | 8 +- .../broadcasting/broadcasting_1.cu | 8 +- .../broadcasting/broadcasting_10.cu | 9 +- .../broadcasting/broadcasting_11.cu | 9 +- .../broadcasting/broadcasting_12.cu | 9 +- .../broadcasting/broadcasting_2.cu | 8 +- .../broadcasting/broadcasting_3.cu | 8 +- .../broadcasting/broadcasting_4.cu | 8 +- .../broadcasting/broadcasting_5.cu | 8 +- .../broadcasting/broadcasting_6.cu | 8 +- .../broadcasting/broadcasting_7.cu | 8 +- .../broadcasting/broadcasting_8.cu | 8 +- .../broadcasting/broadcasting_9.cu | 8 +- .../compilation_units/pairwise/pairwise_0.cu | 9 +- .../compilation_units/pairwise/pairwise_1.cu | 9 +- .../compilation_units/pairwise/pairwise_10.cu | 9 +- .../compilation_units/pairwise/pairwise_11.cu | 9 +- .../compilation_units/pairwise/pairwise_12.cu | 9 +- .../compilation_units/pairwise/pairwise_2.cu | 9 +- .../compilation_units/pairwise/pairwise_3.cu | 9 +- .../compilation_units/pairwise/pairwise_4.cu | 9 +- .../compilation_units/pairwise/pairwise_5.cu | 9 +- .../compilation_units/pairwise/pairwise_6.cu | 9 +- .../compilation_units/pairwise/pairwise_7.cu | 9 +- .../compilation_units/pairwise/pairwise_8.cu | 9 +- .../compilation_units/pairwise/pairwise_9.cu | 9 +- .../compilation_units/reduce3/reduce3_0.cu | 9 +- .../compilation_units/reduce3/reduce3_1.cu | 9 +- .../compilation_units/reduce3/reduce3_2.cu | 9 +- .../compilation_units/reduce3/reduce3_3.cu | 9 +- .../reduce_float/reduce_float_0.cu | 9 +- .../reduce_float/reduce_float_1.cu | 9 +- .../reduce_float/reduce_float_2.cu | 9 +- .../reduce_float/reduce_float_3.cu | 9 +- .../cuda/compilation_units/scalar/scalar_0.cu | 9 +- .../cuda/compilation_units/scalar/scalar_1.cu | 9 +- .../compilation_units/scalar/scalar_10.cu | 9 +- .../compilation_units/scalar/scalar_11.cu | 9 +- .../compilation_units/scalar/scalar_12.cu | 9 +- .../cuda/compilation_units/scalar/scalar_2.cu | 9 +- .../cuda/compilation_units/scalar/scalar_3.cu | 9 +- .../cuda/compilation_units/scalar/scalar_4.cu | 9 +- .../cuda/compilation_units/scalar/scalar_5.cu | 9 +- .../cuda/compilation_units/scalar/scalar_6.cu | 9 +- .../cuda/compilation_units/scalar/scalar_7.cu | 9 +- .../cuda/compilation_units/scalar/scalar_8.cu | 9 +- .../cuda/compilation_units/scalar/scalar_9.cu | 9 +- libnd4j/include/loops/cuda/indexreduce.cu | 650 +- .../cuda/inplace_loops/reduce_same_inplace.h | 285 +- .../loops/cuda/inplace_loops/scalar_inplace.h | 101 +- .../inplace_loops/transform_strict_inplace.h | 125 +- libnd4j/include/loops/cuda/pairwise.chpp | 169 +- libnd4j/include/loops/cuda/pairwise.cu | 6 +- libnd4j/include/loops/cuda/pairwise_bool.cu | 156 +- libnd4j/include/loops/cuda/pairwise_int.cu | 154 +- libnd4j/include/loops/cuda/random.cu | 853 +- .../include/loops/cuda/reduce/reduce_bool.cu | 521 +- .../loops/cuda/reduce/reduce_float.chpp | 505 +- .../include/loops/cuda/reduce/reduce_long.cu | 531 +- .../include/loops/cuda/reduce/reduce_same.cu | 512 +- libnd4j/include/loops/cuda/reduce3.chpp | 912 +- libnd4j/include/loops/cuda/reduce3.cu | 14 +- libnd4j/include/loops/cuda/scalar.chpp | 258 +- libnd4j/include/loops/cuda/scalar.cu | 11 +- libnd4j/include/loops/cuda/scalar_bool.cu | 344 +- libnd4j/include/loops/cuda/scalar_int.cu | 341 +- .../loops/cuda/specials/accumulateKernel.cu | 90 +- .../loops/cuda/specials/averagingKernel.cu | 136 +- .../cuda/specials/bitonicArbitraryStep.cu | 326 +- .../loops/cuda/specials/bitonicSortStep.cu | 217 +- .../loops/cuda/specials/concatKernel.cu | 432 +- .../loops/cuda/specials/concatKernelHStack.cu | 124 +- .../loops/cuda/specials/concatKernelScalar.cu | 52 +- .../loops/cuda/specials/concatKernelVStack.cu | 110 +- .../loops/cuda/specials/convertHalfs.cu | 37 +- .../loops/cuda/specials/convertToHalf.cu | 37 +- .../cuda/specials/fillDimensionalIsMax.cu | 129 +- .../include/loops/cuda/specials/fillIsMax.cu | 38 +- .../include/loops/cuda/specials/flatten.cu | 62 +- libnd4j/include/loops/cuda/specials/oesTad.cu | 329 +- .../loops/cuda/specials/pullRowsKernel.cu | 113 +- .../loops/cuda/specials/setDiagonalKernel.cu | 200 +- .../loops/cuda/specials/shuffleKernel.cu | 183 +- .../loops/cuda/specials/swapUnsafeKernel.cu | 93 +- .../include/loops/cuda/specials/tearKernel.cu | 131 +- .../include/loops/cuda/specials/tileKernel.cu | 181 +- .../include/loops/cuda/summarystatsreduce.cu | 756 +- .../loops/cuda/transform/transform_any.cu | 200 +- .../loops/cuda/transform/transform_bool.cu | 209 +- .../loops/cuda/transform/transform_float.cu | 228 +- .../loops/cuda/transform/transform_same.cu | 197 +- .../loops/cuda/transform/transform_strict.cu | 204 +- .../include/loops/cuda/type_conversions.cu | 907 +- .../include/loops/impl/type_conversions.cpp | 399 +- libnd4j/include/loops/indexreduce.h | 147 +- libnd4j/include/loops/legacy_ops.h | 513 +- libnd4j/include/loops/pairwise_bool.h | 99 +- libnd4j/include/loops/pairwise_int.h | 100 +- libnd4j/include/loops/pairwise_transform.h | 98 +- libnd4j/include/loops/random.h | 137 +- libnd4j/include/loops/reduce3.h | 361 +- libnd4j/include/loops/reduce_bool.h | 266 +- libnd4j/include/loops/reduce_float.h | 272 +- libnd4j/include/loops/reduce_long.h | 268 +- libnd4j/include/loops/reduce_same.h | 278 +- libnd4j/include/loops/scalar.h | 234 +- libnd4j/include/loops/scalar_bool.h | 282 +- libnd4j/include/loops/scalar_int.h | 286 +- libnd4j/include/loops/special_kernels.h | 194 +- libnd4j/include/loops/summarystatsreduce.h | 529 +- libnd4j/include/loops/transform_any.h | 83 +- libnd4j/include/loops/transform_bool.h | 85 +- libnd4j/include/loops/transform_float.h | 113 +- libnd4j/include/loops/transform_same.h | 85 +- libnd4j/include/loops/transform_strict.h | 88 +- libnd4j/include/loops/type_conversions.h | 182 +- libnd4j/include/math/platformmath.h | 1235 +- libnd4j/include/math/templatemath.h | 2643 ++-- libnd4j/include/memory/AllocationEntry.h | 47 +- libnd4j/include/memory/ColdZoneManager.h | 41 +- libnd4j/include/memory/ExternalWorkspace.h | 48 +- libnd4j/include/memory/GraphMemoryManager.h | 59 +- libnd4j/include/memory/HotRamZoneManager.h | 23 +- libnd4j/include/memory/HotZoneManager.h | 36 +- libnd4j/include/memory/MemoryCounter.h | 240 +- libnd4j/include/memory/MemoryDescriptor.h | 45 +- libnd4j/include/memory/MemoryRegistrator.h | 73 +- libnd4j/include/memory/MemoryReport.h | 58 +- libnd4j/include/memory/MemoryTracker.h | 63 +- libnd4j/include/memory/MemoryType.h | 14 +- libnd4j/include/memory/MemoryUtils.h | 21 +- libnd4j/include/memory/MemoryZone.h | 16 +- libnd4j/include/memory/WarmZoneManager.h | 21 +- libnd4j/include/memory/Workspace.h | 119 +- libnd4j/include/memory/ZoneManager.h | 89 +- .../include/memory/cpu/ColdZoneManager.cpp | 44 +- .../include/memory/cpu/GraphMemoryManager.cpp | 52 +- libnd4j/include/memory/cpu/HotZoneManager.cpp | 19 +- libnd4j/include/memory/cpu/Workspace.cpp | 282 +- libnd4j/include/memory/cuda/Workspace.cu | 455 +- .../include/memory/impl/AllocationEntry.cpp | 31 +- .../include/memory/impl/ExternalWorkspace.cpp | 45 +- .../include/memory/impl/HotRamZoneManager.cpp | 24 +- libnd4j/include/memory/impl/MemoryCounter.cpp | 218 +- .../include/memory/impl/MemoryDescriptor.cpp | 71 +- .../include/memory/impl/MemoryRegistrator.cpp | 92 +- libnd4j/include/memory/impl/MemoryReport.cpp | 44 +- libnd4j/include/memory/impl/MemoryTracker.cpp | 260 +- libnd4j/include/memory/impl/MemoryUtils.cpp | 85 +- libnd4j/include/ops/BroadcastBoolOpsTuple.h | 50 +- libnd4j/include/ops/BroadcastIntOpsTuple.h | 49 +- libnd4j/include/ops/BroadcastOpsTuple.h | 73 +- libnd4j/include/ops/InputType.h | 18 +- libnd4j/include/ops/declarable/BooleanOp.h | 37 +- .../ops/declarable/BroadcastableBoolOp.h | 28 +- .../include/ops/declarable/BroadcastableOp.h | 28 +- .../include/ops/declarable/CustomOperations.h | 95 +- .../ops/declarable/DeclarableCustomOp.h | 31 +- .../include/ops/declarable/DeclarableListOp.h | 47 +- libnd4j/include/ops/declarable/DeclarableOp.h | 396 +- .../ops/declarable/DeclarableReductionOp.h | 31 +- .../include/ops/declarable/EmptyHandling.h | 8 +- .../ops/declarable/LegacyBroadcastBoolOp.h | 33 +- .../ops/declarable/LegacyBroadcastOp.h | 33 +- .../ops/declarable/LegacyIndexReduceOp.h | 39 +- libnd4j/include/ops/declarable/LegacyOp.h | 68 +- .../LegacyPairwiseTransformBoolOp.h | 33 +- .../declarable/LegacyPairwiseTransformOp.h | 33 +- .../include/ops/declarable/LegacyRandomOp.h | 64 +- .../include/ops/declarable/LegacyReduce3Op.h | 34 +- .../ops/declarable/LegacyReduceBoolOp.h | 27 +- .../ops/declarable/LegacyReduceFloatOp.h | 27 +- .../ops/declarable/LegacyReduceLongOp.h | 27 +- .../include/ops/declarable/LegacyReduceOp.h | 11 +- .../ops/declarable/LegacyReduceSameOp.h | 27 +- .../ops/declarable/LegacyScalarBoolOp.h | 43 +- .../include/ops/declarable/LegacyScalarOp.h | 43 +- .../include/ops/declarable/LegacyStatsOp.h | 34 +- .../ops/declarable/LegacyTransformAnyOp.h | 40 +- .../ops/declarable/LegacyTransformBoolOp.h | 34 +- .../ops/declarable/LegacyTransformFloatOp.h | 40 +- .../ops/declarable/LegacyTransformOp.h | 40 +- .../ops/declarable/LegacyTransformSameOp.h | 34 +- .../ops/declarable/LegacyTransformStrictOp.h | 34 +- libnd4j/include/ops/declarable/LogicOp.h | 45 +- libnd4j/include/ops/declarable/OpDescriptor.h | 246 +- .../include/ops/declarable/OpRegistrator.h | 195 +- libnd4j/include/ops/declarable/OpTuple.h | 52 +- .../include/ops/declarable/PlatformHelper.h | 138 +- .../declarable/generic/CustomOperations.cpp | 47 +- .../generic/bitwise/bits_hamming_distance.cpp | 57 +- .../generic/bitwise/bitwise_and.cpp | 31 +- .../declarable/generic/bitwise/bitwise_or.cpp | 31 +- .../generic/bitwise/bitwise_xor.cpp | 31 +- .../generic/bitwise/cyclic_rshift.cpp | 32 +- .../generic/bitwise/cyclic_shift.cpp | 31 +- .../ops/declarable/generic/bitwise/rshift.cpp | 31 +- .../ops/declarable/generic/bitwise/shift.cpp | 31 +- .../generic/bitwise/toggle_bits.cpp | 46 +- .../ops/declarable/generic/blas/axpy.cpp | 67 +- .../declarable/generic/blas/batched_gemm.cpp | 214 +- .../ops/declarable/generic/blas/matmul.cpp | 317 +- .../ops/declarable/generic/blas/svd.cpp | 159 +- .../declarable/generic/blas/tensormmul.cpp | 238 +- .../generic/boolean/boolean_not.cpp | 32 +- .../ops/declarable/generic/boolean/choose.cpp | 131 +- .../declarable/generic/boolean/eq_scalar.cpp | 36 +- .../declarable/generic/boolean/gt_scalar.cpp | 36 +- .../declarable/generic/boolean/gte_scalar.cpp | 36 +- .../generic/boolean/is_non_decreasing.cpp | 48 +- .../generic/boolean/is_numeric_tensor.cpp | 23 +- .../boolean/is_strictly_increasing.cpp | 48 +- .../declarable/generic/boolean/lt_scalar.cpp | 36 +- .../declarable/generic/boolean/lte_scalar.cpp | 36 +- .../declarable/generic/boolean/neq_scalar.cpp | 36 +- .../ops/declarable/generic/boolean/select.cpp | 147 +- .../ops/declarable/generic/boolean/where.cpp | 214 +- .../declarable/generic/boolean/where_np.cpp | 254 +- .../declarable/generic/broadcastable/add.cpp | 177 +- .../generic/broadcastable/assign.cpp | 166 +- .../generic/broadcastable/atan2.cpp | 44 +- .../generic/broadcastable/boolean_and.cpp | 51 +- .../generic/broadcastable/boolean_or.cpp | 51 +- .../generic/broadcastable/boolean_xor.cpp | 51 +- .../generic/broadcastable/divide.cpp | 235 +- .../generic/broadcastable/divide_no_nan.cpp | 58 +- .../generic/broadcastable/equals.cpp | 61 +- .../generic/broadcastable/floordiv.cpp | 134 +- .../generic/broadcastable/floormod.cpp | 160 +- .../generic/broadcastable/greater.cpp | 45 +- .../generic/broadcastable/greater_equal.cpp | 45 +- .../generic/broadcastable/igamma.cpp | 62 +- .../generic/broadcastable/igammac.cpp | 60 +- .../declarable/generic/broadcastable/less.cpp | 44 +- .../generic/broadcastable/less_equal.cpp | 45 +- .../generic/broadcastable/maximum.cpp | 122 +- .../generic/broadcastable/meshgrid.cpp | 98 +- .../generic/broadcastable/minimum.cpp | 119 +- .../declarable/generic/broadcastable/mod.cpp | 132 +- .../generic/broadcastable/multiply.cpp | 245 +- .../generic/broadcastable/not_equals.cpp | 45 +- .../generic/broadcastable/percentile.cpp | 142 +- .../declarable/generic/broadcastable/pow.cpp | 192 +- .../generic/broadcastable/realdiv.cpp | 179 +- .../generic/broadcastable/reverse_divide.cpp | 201 +- .../generic/broadcastable/reverse_mod.cpp | 109 +- .../broadcastable/reverse_subtract.cpp | 185 +- .../broadcastable/squared_subtract.cpp | 241 +- .../generic/broadcastable/subtract.cpp | 186 +- .../generic/broadcastable/truncatediv.cpp | 44 +- .../generic/compat/compat_sparse_to_dense.cpp | 69 +- .../generic/compat/compat_string_split.cpp | 228 +- .../declarable/generic/compression/bitmap.cpp | 101 +- .../generic/compression/threshold.cpp | 172 +- .../declarable/generic/datatypes/bitcast.cpp | 141 +- .../ops/declarable/generic/datatypes/cast.cpp | 67 +- .../generic/datatypes/to_double.cpp | 40 +- .../generic/datatypes/to_float16.cpp | 40 +- .../generic/datatypes/to_float32.cpp | 40 +- .../declarable/generic/datatypes/to_int32.cpp | 48 +- .../declarable/generic/datatypes/to_int64.cpp | 48 +- .../generic/datatypes/to_uint32.cpp | 48 +- .../generic/datatypes/to_uint64.cpp | 46 +- .../generic/flow/flow_control_ops.cpp | 111 +- .../generic/grad/broadcast_gradient_args.cpp | 28 +- .../generic/helpers/BroadcastHelper.h | 234 +- .../generic/helpers/ScatterHelper.h | 13 +- .../generic/images/adjust_contrast.cpp | 161 +- .../declarable/generic/images/adjust_hue.cpp | 95 +- .../generic/images/adjust_saturation.cpp | 88 +- .../generic/images/crop_and_resize.cpp | 117 +- .../generic/images/draw_bounding_boxes.cpp | 89 +- .../generic/images/extract_image_patches.cpp | 148 +- .../declarable/generic/images/hsvToRgb.cpp | 53 +- .../generic/images/image_resize.cpp | 111 +- .../declarable/generic/images/resize_area.cpp | 220 +- .../generic/images/resize_bicubic.cpp | 198 +- .../generic/images/resize_linear.cpp | 228 +- .../generic/images/resize_neighbor.cpp | 206 +- .../declarable/generic/images/rgbToGrs.cpp | 86 +- .../declarable/generic/images/rgbToHsv.cpp | 64 +- .../declarable/generic/images/rgbToYiq.cpp | 79 +- .../declarable/generic/images/rgbToYuv.cpp | 54 +- .../declarable/generic/images/yiqToRgb.cpp | 65 +- .../declarable/generic/images/yuvToRgb.cpp | 55 +- .../generic/kernels/knn_mindistance.cpp | 50 +- .../ops/declarable/generic/linalg/betaInc.cpp | 73 +- .../declarable/generic/linalg/cholesky.cpp | 41 +- .../ops/declarable/generic/linalg/cross.cpp | 59 +- .../ops/declarable/generic/linalg/diag.cpp | 48 +- .../declarable/generic/linalg/diagPart.cpp | 87 +- .../ops/declarable/generic/linalg/digamma.cpp | 21 +- .../ops/declarable/generic/linalg/eye.cpp | 149 +- .../ops/declarable/generic/linalg/lgamma.cpp | 21 +- .../ops/declarable/generic/linalg/log1p.cpp | 30 +- .../ops/declarable/generic/linalg/lstsq.cpp | 244 +- .../ops/declarable/generic/linalg/lup.cpp | 87 +- .../generic/linalg/matrixDiagPart.cpp | 84 +- .../generic/linalg/matrixSetDiag.cpp | 49 +- .../generic/linalg/matrix_band_part.cpp | 65 +- .../generic/linalg/matrix_determinant.cpp | 243 +- .../declarable/generic/linalg/matrix_diag.cpp | 51 +- .../generic/linalg/matrix_inverse.cpp | 32 +- .../ops/declarable/generic/linalg/moments.cpp | 113 +- .../declarable/generic/linalg/polygamma.cpp | 46 +- .../ops/declarable/generic/linalg/qr.cpp | 115 +- .../ops/declarable/generic/linalg/solve.cpp | 96 +- .../generic/linalg/sufficient_statistics.cpp | 120 +- .../ops/declarable/generic/linalg/trace.cpp | 61 +- .../ops/declarable/generic/linalg/tri.cpp | 48 +- .../generic/linalg/triangular_solve.cpp | 104 +- .../ops/declarable/generic/linalg/triu.cpp | 122 +- .../ops/declarable/generic/linalg/zeta.cpp | 50 +- .../declarable/generic/list/clone_list.cpp | 22 +- .../declarable/generic/list/create_list.cpp | 68 +- .../declarable/generic/list/gather_list.cpp | 86 +- .../ops/declarable/generic/list/pick_list.cpp | 55 +- .../ops/declarable/generic/list/read_list.cpp | 59 +- .../declarable/generic/list/scatter_list.cpp | 79 +- .../ops/declarable/generic/list/size_list.cpp | 31 +- .../declarable/generic/list/split_list.cpp | 126 +- .../declarable/generic/list/stack_list.cpp | 26 +- .../declarable/generic/list/unstack_list.cpp | 24 +- .../declarable/generic/list/write_list.cpp | 68 +- .../generic/loss/absoluteDifference.cpp | 589 +- .../generic/loss/cosineDistance.cpp | 690 +- .../ops/declarable/generic/loss/hingeLoss.cpp | 623 +- .../ops/declarable/generic/loss/huberLoss.cpp | 648 +- .../ops/declarable/generic/loss/l2_loss.cpp | 47 +- .../ops/declarable/generic/loss/logLoss.cpp | 620 +- .../generic/loss/log_poisson_loss.cpp | 641 +- .../generic/loss/meanPairWsSqErr.cpp | 786 +- .../ops/declarable/generic/loss/meanSqErr.cpp | 602 +- .../generic/loss/sigmCrossEntropy.cpp | 651 +- .../generic/loss/softmaxCrossEntropy.cpp | 806 +- .../loss/softmaxCrossEntropyWithLogits.cpp | 212 +- .../sparseSoftmaxCrossEntropyWithLogits.cpp | 277 +- .../ops/declarable/generic/nlp/cbow.cpp | 105 +- .../ops/declarable/generic/nlp/skipgram.cpp | 105 +- .../generic/nn/activations/crelu.cpp | 170 +- .../generic/nn/activations/cube.cpp | 58 +- .../declarable/generic/nn/activations/elu.cpp | 62 +- .../generic/nn/activations/hardsigmoid.cpp | 59 +- .../generic/nn/activations/hardtanh.cpp | 59 +- .../generic/nn/activations/identity.cpp | 62 +- .../generic/nn/activations/identity_n.cpp | 63 +- .../generic/nn/activations/lrelu.cpp | 63 +- .../generic/nn/activations/prelu.cpp | 250 +- .../generic/nn/activations/rationaltanh.cpp | 59 +- .../generic/nn/activations/rectifiedtanh.cpp | 59 +- .../generic/nn/activations/relu.cpp | 61 +- .../generic/nn/activations/relu6.cpp | 54 +- .../generic/nn/activations/selu.cpp | 59 +- .../generic/nn/activations/sigmoid.cpp | 59 +- .../generic/nn/activations/softplus.cpp | 61 +- .../generic/nn/activations/softsign.cpp | 61 +- .../generic/nn/activations/tanh.cpp | 61 +- .../nn/activations/thresholdedrelu.cpp | 58 +- .../ops/declarable/generic/nn/apply_sgd.cpp | 71 +- .../ops/declarable/generic/nn/batchnorm.cpp | 650 +- .../ops/declarable/generic/nn/bias_add.cpp | 100 +- .../declarable/generic/nn/convo/col2im.cpp | 127 +- .../declarable/generic/nn/convo/conv1d.cpp | 663 +- .../declarable/generic/nn/convo/conv2d.cpp | 875 +- .../declarable/generic/nn/convo/conv3d.cpp | 826 +- .../declarable/generic/nn/convo/deconv2d.cpp | 731 +- .../generic/nn/convo/deconv2d_tf.cpp | 307 +- .../declarable/generic/nn/convo/deconv3d.cpp | 798 +- .../generic/nn/convo/depthwiseConv2d.cpp | 587 +- .../generic/nn/convo/dilation2d.cpp | 163 +- .../declarable/generic/nn/convo/im2col.cpp | 272 +- .../ops/declarable/generic/nn/convo/ismax.cpp | 36 +- .../generic/nn/convo/pointwiseConv2d.cpp | 209 +- .../declarable/generic/nn/convo/sconv2d.cpp | 909 +- .../generic/nn/convo/upsampling2d.cpp | 183 +- .../generic/nn/convo/upsampling3d.cpp | 190 +- .../generic/nn/dot_product_attention.cpp | 393 +- .../generic/nn/embedding_lookup.cpp | 141 +- .../declarable/generic/nn/fusedBatchNorm.cpp | 243 +- .../ops/declarable/generic/nn/layer_norm.cpp | 245 +- .../ops/declarable/generic/nn/logSoftmax.cpp | 70 +- .../include/ops/declarable/generic/nn/lrn.cpp | 108 +- .../nn/multi_head_dot_product_attention.cpp | 603 +- .../generic/nn/pooling/avgpool2d.cpp | 401 +- .../generic/nn/pooling/avgpool3d.cpp | 398 +- .../generic/nn/pooling/maxpool2d.cpp | 403 +- .../generic/nn/pooling/maxpool3d.cpp | 438 +- .../nn/pooling/maxpool_with_argmax.cpp | 55 +- .../generic/nn/pooling/pnormpool2d.cpp | 453 +- .../nn/recurrent/dynamicBidirectionalRNN.cpp | 524 +- .../generic/nn/recurrent/dynamicRNN.cpp | 349 +- .../declarable/generic/nn/recurrent/gru.cpp | 350 +- .../generic/nn/recurrent/gruCell.cpp | 516 +- .../declarable/generic/nn/recurrent/lstm.cpp | 309 +- .../generic/nn/recurrent/lstmBlock.cpp | 203 +- .../generic/nn/recurrent/lstmBlockCell.cpp | 233 +- .../generic/nn/recurrent/lstmCell.cpp | 307 +- .../generic/nn/recurrent/lstmLayer.cpp | 1750 +-- .../generic/nn/recurrent/lstmLayerCell.cpp | 652 +- .../declarable/generic/nn/recurrent/sru.cpp | 1628 ++- .../generic/nn/recurrent/sruCell.cpp | 175 +- .../nn/recurrent/staticBidirectionalRNN.cpp | 514 +- .../generic/nn/recurrent/staticRNN.cpp | 282 +- .../ops/declarable/generic/nn/relu_layer.cpp | 81 +- .../ops/declarable/generic/nn/softmax.cpp | 70 +- .../ops/declarable/generic/nn/xw_plus_b.cpp | 217 +- .../ops/declarable/generic/parity_ops.cpp | 30 +- .../declarable/generic/parity_ops/assert.cpp | 27 +- .../generic/parity_ops/bincount.cpp | 183 +- .../parity_ops/broadcast_dynamic_shape.cpp | 113 +- .../generic/parity_ops/check_numerics.cpp | 43 +- .../parity_ops/compare_and_bitpack.cpp | 81 +- .../generic/parity_ops/confusion_matrix.cpp | 107 +- .../declarable/generic/parity_ops/expose.cpp | 81 +- .../fake_quant_with_min_max_vars.cpp | 84 +- ...ke_quant_with_min_max_vars_per_channel.cpp | 86 +- .../generic/parity_ops/in_top_k.cpp | 67 +- .../generic/parity_ops/listdiff.cpp | 76 +- .../parity_ops/non_max_suppression.cpp | 401 +- .../non_max_suppression_overlaps.cpp | 127 +- .../generic/parity_ops/normalize_moments.cpp | 91 +- .../generic/parity_ops/nth_element.cpp | 107 +- .../declarable/generic/parity_ops/onehot.cpp | 117 +- .../declarable/generic/parity_ops/rint.cpp | 32 +- .../declarable/generic/parity_ops/roll.cpp | 147 +- .../generic/parity_ops/segment_max.cpp | 170 +- .../generic/parity_ops/segment_mean.cpp | 169 +- .../generic/parity_ops/segment_min.cpp | 166 +- .../generic/parity_ops/segment_prod.cpp | 175 +- .../generic/parity_ops/segment_sum.cpp | 152 +- .../generic/parity_ops/sequence_mask.cpp | 135 +- .../declarable/generic/parity_ops/square.cpp | 32 +- .../generic/parity_ops/stop_gradient.cpp | 38 +- .../declarable/generic/parity_ops/top_k.cpp | 131 +- .../declarable/generic/parity_ops/unique.cpp | 165 +- .../parity_ops/unsorted_segment_max.cpp | 136 +- .../parity_ops/unsorted_segment_mean.cpp | 158 +- .../parity_ops/unsorted_segment_min.cpp | 139 +- .../parity_ops/unsorted_segment_prod.cpp | 196 +- .../parity_ops/unsorted_segment_sqrt_n.cpp | 156 +- .../parity_ops/unsorted_segment_sum.cpp | 149 +- .../weighted_cross_entropy_with_logits.cpp | 49 +- .../generic/parity_ops/zero_fraction.cpp | 69 +- .../declarable/generic/random/bernoulli.cpp | 64 +- .../ops/declarable/generic/random/dropout.cpp | 141 +- .../declarable/generic/random/exponential.cpp | 46 +- .../ops/declarable/generic/random/gamma.cpp | 101 +- .../declarable/generic/random/get_seed.cpp | 46 +- .../declarable/generic/random/multinomial.cpp | 186 +- .../ops/declarable/generic/random/normal.cpp | 64 +- .../ops/declarable/generic/random/poisson.cpp | 68 +- .../declarable/generic/random/random_crop.cpp | 69 +- .../generic/random/random_shuffle.cpp | 38 +- .../declarable/generic/random/set_seed.cpp | 73 +- .../ops/declarable/generic/random/uniform.cpp | 116 +- .../ops/declarable/generic/reduce/argmax.cpp | 97 +- .../ops/declarable/generic/reduce/argmin.cpp | 107 +- .../ops/declarable/generic/reduce/norm.cpp | 145 +- .../declarable/generic/reduce/reduceMean.cpp | 254 +- .../declarable/generic/reduce/reduceStDev.cpp | 301 +- .../generic/reduce/reduceVariance.cpp | 289 +- .../declarable/generic/reduce/reduce_dot.cpp | 173 +- .../generic/reduce/reduce_logsumexp.cpp | 98 +- .../declarable/generic/reduce/reduce_max.cpp | 214 +- .../declarable/generic/reduce/reduce_min.cpp | 216 +- .../generic/reduce/reduce_norm1.cpp | 251 +- .../generic/reduce/reduce_norm2.cpp | 244 +- .../generic/reduce/reduce_norm_max.cpp | 246 +- .../declarable/generic/reduce/reduce_prod.cpp | 254 +- .../generic/reduce/reduce_sqnorm.cpp | 224 +- .../declarable/generic/reduce/reduce_sum.cpp | 239 +- .../declarable/generic/shape/broadcast_to.cpp | 116 +- .../shape/evaluate_reduction_shape.cpp | 88 +- .../declarable/generic/shape/expand_dims.cpp | 132 +- .../ops/declarable/generic/shape/flatten.cpp | 69 +- .../ops/declarable/generic/shape/order.cpp | 40 +- .../ops/declarable/generic/shape/permute.cpp | 63 +- .../ops/declarable/generic/shape/rank.cpp | 48 +- .../ops/declarable/generic/shape/reshape.cpp | 237 +- .../declarable/generic/shape/reshape_as.cpp | 43 +- .../ops/declarable/generic/shape/shape.cpp | 47 +- .../ops/declarable/generic/shape/shapes.cpp | 62 +- .../ops/declarable/generic/shape/size.cpp | 47 +- .../ops/declarable/generic/shape/size_at.cpp | 57 +- .../ops/declarable/generic/shape/squeeze.cpp | 257 +- .../generic/shape/tile_to_shape.cpp | 102 +- .../declarable/generic/shape/transpose.cpp | 66 +- .../generic/strings/split_string.cpp | 42 +- .../ops/declarable/generic/tensor/create.cpp | 43 +- .../ops/declarable/generic/tensor/fill.cpp | 137 +- .../ops/declarable/generic/tensor/fill_as.cpp | 51 +- .../declarable/generic/tensor/lin_space.cpp | 98 +- .../ops/declarable/generic/tensor/ones_as.cpp | 42 +- .../ops/declarable/generic/tensor/range.cpp | 470 +- .../generic/tensor/strided_slice.cpp | 1327 +- .../declarable/generic/tensor/zeros_as.cpp | 44 +- .../ops/declarable/generic/tests/noop.cpp | 20 +- .../generic/tests/test_output_reshape.cpp | 27 +- .../declarable/generic/tests/test_scalar.cpp | 60 +- .../declarable/generic/tests/testcustom.cpp | 52 +- .../declarable/generic/tests/testop2i2o.cpp | 50 +- .../generic/tests/testreduction.cpp | 22 +- .../generic/thrid_party/firas_sparse.cpp | 113 +- .../generic/transforms/batch_to_space.cpp | 184 +- .../generic/transforms/batch_to_space_nd.cpp | 179 +- .../transforms/clip_by_averaged_norm.cpp | 32 +- .../transforms/clip_by_global_norm.cpp | 63 +- .../generic/transforms/clip_by_norm.cpp | 76 +- .../generic/transforms/clip_by_value.cpp | 52 +- .../declarable/generic/transforms/concat.cpp | 686 +- .../declarable/generic/transforms/cumprod.cpp | 262 +- .../declarable/generic/transforms/cumsum.cpp | 214 +- .../generic/transforms/depth_to_space.cpp | 120 +- .../generic/transforms/dynamic_parititon.cpp | 247 +- .../generic/transforms/dynamic_stitch.cpp | 108 +- .../declarable/generic/transforms/floor.cpp | 28 +- .../declarable/generic/transforms/gather.cpp | 232 +- .../generic/transforms/gatherNd.cpp | 129 +- .../generic/transforms/hashcode.cpp | 47 +- .../generic/transforms/histogram.cpp | 47 +- .../transforms/histogram_fixed_width.cpp | 64 +- .../generic/transforms/invertPermutation.cpp | 36 +- .../generic/transforms/merge_add.cpp | 91 +- .../generic/transforms/merge_avg.cpp | 89 +- .../generic/transforms/merge_max.cpp | 96 +- .../generic/transforms/merge_max_idx.cpp | 39 +- .../generic/transforms/mirrorPad.cpp | 182 +- .../ops/declarable/generic/transforms/pad.cpp | 175 +- .../generic/transforms/parallelStack.cpp | 67 +- .../declarable/generic/transforms/repeat.cpp | 56 +- .../declarable/generic/transforms/reverse.cpp | 154 +- .../generic/transforms/reverseSequence.cpp | 134 +- .../generic/transforms/scatter_add.cpp | 133 +- .../generic/transforms/scatter_div.cpp | 146 +- .../generic/transforms/scatter_max.cpp | 136 +- .../generic/transforms/scatter_min.cpp | 136 +- .../generic/transforms/scatter_mul.cpp | 147 +- .../generic/transforms/scatter_nd.cpp | 148 +- .../generic/transforms/scatter_nd_add.cpp | 112 +- .../generic/transforms/scatter_nd_sub.cpp | 112 +- .../generic/transforms/scatter_nd_update.cpp | 113 +- .../generic/transforms/scatter_sub.cpp | 151 +- .../generic/transforms/scatter_upd.cpp | 149 +- .../generic/transforms/scatter_update.cpp | 59 +- .../declarable/generic/transforms/slice.cpp | 389 +- .../generic/transforms/space_to_batch.cpp | 151 +- .../generic/transforms/space_to_batch_nd.cpp | 162 +- .../generic/transforms/space_to_depth.cpp | 117 +- .../declarable/generic/transforms/split.cpp | 233 +- .../declarable/generic/transforms/split_v.cpp | 170 +- .../declarable/generic/transforms/stack.cpp | 136 +- .../generic/transforms/standardize.cpp | 198 +- .../declarable/generic/transforms/tear.cpp | 81 +- .../declarable/generic/transforms/tile.cpp | 290 +- .../declarable/generic/transforms/unstack.cpp | 146 +- .../declarable/generic/tsne/cell_contains.cpp | 49 +- .../declarable/generic/tsne/edge_force.cpp | 74 +- .../ops/declarable/generic/tsne/gains.cpp | 34 +- .../declarable/generic/tsne/symmetrized.cpp | 131 +- .../generic/updaters/adaDeltaUpdater.cpp | 131 +- .../generic/updaters/adaGradUpdater.cpp | 117 +- .../generic/updaters/adaMaxUpdater.cpp | 160 +- .../generic/updaters/adamUpdater.cpp | 161 +- .../generic/updaters/amsGradUpdater.cpp | 176 +- .../generic/updaters/nadamUpdater.cpp | 161 +- .../generic/updaters/nesterovsUpdater.cpp | 114 +- .../generic/updaters/rmsPropUpdater.cpp | 127 +- .../generic/updaters/sgdUpdater.cpp | 64 +- .../generic/util/print_affinity.cpp | 51 +- .../generic/util/print_variable.cpp | 82 +- .../ops/declarable/headers/BarnesHutTsne.h | 141 +- .../ops/declarable/headers/activations.h | 341 +- .../include/ops/declarable/headers/bitwise.h | 190 +- libnd4j/include/ops/declarable/headers/blas.h | 168 +- .../include/ops/declarable/headers/boolean.h | 260 +- .../ops/declarable/headers/broadcastable.h | 746 +- .../include/ops/declarable/headers/common.h | 23 +- .../include/ops/declarable/headers/compat.h | 55 +- .../ops/declarable/headers/compression.h | 65 +- .../include/ops/declarable/headers/convo.h | 557 +- .../ops/declarable/headers/datatypes.h | 161 +- .../include/ops/declarable/headers/images.h | 97 +- .../include/ops/declarable/headers/kernels.h | 14 +- libnd4j/include/ops/declarable/headers/list.h | 203 +- libnd4j/include/ops/declarable/headers/loss.h | 686 +- libnd4j/include/ops/declarable/headers/nlp.h | 20 +- libnd4j/include/ops/declarable/headers/nn.h | 446 +- .../ops/declarable/headers/parity_ops.h | 4065 +++--- .../include/ops/declarable/headers/random.h | 129 +- .../ops/declarable/headers/recurrent.h | 818 +- .../include/ops/declarable/headers/shape.h | 185 +- .../include/ops/declarable/headers/strings.h | 27 +- .../include/ops/declarable/headers/tests.h | 42 +- .../ops/declarable/headers/third_party.h | 12 +- .../ops/declarable/headers/transforms.h | 407 +- .../include/ops/declarable/headers/updaters.h | 329 +- libnd4j/include/ops/declarable/headers/util.h | 32 +- .../ops/declarable/helpers/BarnesHutTsne.h | 31 +- .../ops/declarable/helpers/activations.h | 39 +- .../include/ops/declarable/helpers/addBias.h | 18 +- .../ops/declarable/helpers/adjust_hue.h | 179 +- .../declarable/helpers/adjust_saturation.h | 41 +- libnd4j/include/ops/declarable/helpers/axis.h | 18 +- .../ops/declarable/helpers/batched_gemm.h | 12 +- .../ops/declarable/helpers/batchnorm.h | 18 +- .../include/ops/declarable/helpers/betaInc.h | 19 +- .../include/ops/declarable/helpers/choose.h | 15 +- .../include/ops/declarable/helpers/col2im.h | 13 +- .../ops/declarable/helpers/compare_elem.h | 16 +- .../ops/declarable/helpers/compression.h | 14 +- .../ops/declarable/helpers/confusion.h | 9 +- .../ops/declarable/helpers/convolutions.h | 725 +- .../declarable/helpers/cpu/BarnesHutTsne.cpp | 423 +- .../declarable/helpers/cpu/activations.cpp | 425 +- .../ops/declarable/helpers/cpu/addBias.cpp | 1303 +- .../ops/declarable/helpers/cpu/adjust_hue.cpp | 154 +- .../helpers/cpu/adjust_saturation.cpp | 164 +- .../ops/declarable/helpers/cpu/axis.cpp | 35 +- .../declarable/helpers/cpu/batched_gemm.cpp | 217 +- .../ops/declarable/helpers/cpu/batchnorm.cpp | 354 +- .../ops/declarable/helpers/cpu/betaInc.cpp | 172 +- .../ops/declarable/helpers/cpu/clip.cpp | 440 +- .../ops/declarable/helpers/cpu/col2im.cpp | 213 +- .../declarable/helpers/cpu/compare_elem.cpp | 99 +- .../compilation_units/crop_and_resize_0.cpp | 17 +- .../compilation_units/crop_and_resize_1.cpp | 17 +- .../compilation_units/crop_and_resize_2.cpp | 17 +- .../compilation_units/crop_and_resize_3.cpp | 17 +- .../compilation_units/crop_and_resize_4.cpp | 17 +- .../compilation_units/crop_and_resize_5.cpp | 17 +- .../compilation_units/crop_and_resize_6.cpp | 17 +- .../compilation_units/crop_and_resize_7.cpp | 17 +- .../compilation_units/crop_and_resize_8.cpp | 17 +- .../compilation_units/crop_and_resize_9.cpp | 17 +- .../helpers/cpu/compression/compression.cpp | 24 +- .../helpers/cpu/compression/threshold.cpp | 80 +- .../ops/declarable/helpers/cpu/concat.cpp | 36 +- .../ops/declarable/helpers/cpu/confusion.cpp | 55 +- .../helpers/cpu/convolutions_col2vol.cpp | 240 +- .../helpers/cpu/convolutions_conv2d.cpp | 179 +- .../helpers/cpu/convolutions_conv2dBP.cpp | 228 +- .../cpu/convolutions_depthwiseConv2d.cpp | 179 +- .../cpu/convolutions_depthwiseConv2dBP.cpp | 234 +- .../helpers/cpu/convolutions_pooling2d.cpp | 426 +- .../helpers/cpu/convolutions_pooling2dBP.cpp | 577 +- .../helpers/cpu/convolutions_pooling3d.cpp | 463 +- .../helpers/cpu/convolutions_pooling3dBP.cpp | 612 +- .../helpers/cpu/convolutions_sconv2d.cpp | 108 +- .../helpers/cpu/convolutions_upsampling2d.cpp | 92 +- .../cpu/convolutions_upsampling2dBP.cpp | 115 +- .../helpers/cpu/convolutions_upsampling3d.cpp | 126 +- .../cpu/convolutions_upsampling3dBP.cpp | 134 +- .../helpers/cpu/convolutions_vol2col.cpp | 247 +- .../helpers/cpu/crop_and_resize.cpp | 41 +- .../helpers/cpu/crop_and_resize.hpp | 188 +- .../ops/declarable/helpers/cpu/cross.cpp | 48 +- .../ops/declarable/helpers/cpu/d_t_s.cpp | 171 +- .../ops/declarable/helpers/cpu/diGamma.cpp | 27 +- .../ops/declarable/helpers/cpu/diag.cpp | 50 +- .../ops/declarable/helpers/cpu/dilation2d.cpp | 136 +- .../ops/declarable/helpers/cpu/dropout.cpp | 315 +- .../ops/declarable/helpers/cpu/dynamic.cpp | 429 +- .../helpers/cpu/extract_patches.cpp | 141 +- .../ops/declarable/helpers/cpu/eye.cpp | 29 +- .../helpers/cpu/fake_quantization.cpp | 181 +- .../ops/declarable/helpers/cpu/flatten.cpp | 72 +- .../ops/declarable/helpers/cpu/gather.cpp | 290 +- .../helpers/cpu/gatherTransforms.cpp | 266 +- .../ops/declarable/helpers/cpu/gradient.cpp | 26 +- .../ops/declarable/helpers/cpu/hamming.cpp | 159 +- .../ops/declarable/helpers/cpu/hashcode.cpp | 158 +- .../ops/declarable/helpers/cpu/histogram.cpp | 81 +- .../helpers/cpu/histogramFixedWidth.cpp | 63 +- .../ops/declarable/helpers/cpu/im2col.cpp | 222 +- .../helpers/cpu/image_draw_bounding_boxes.cpp | 231 +- .../declarable/helpers/cpu/image_resize.cpp | 1889 +-- .../helpers/cpu/image_suppression.cpp | 484 +- .../declarable/helpers/cpu/imagesHelpers.cpp | 416 +- .../helpers/cpu/invertPermutation.cpp | 51 +- .../ops/declarable/helpers/cpu/ismax.cpp | 338 +- .../declarable/helpers/cpu/legacy_helper.cpp | 720 +- .../ops/declarable/helpers/cpu/lgamma.cpp | 31 +- .../ops/declarable/helpers/cpu/lrn.cpp | 616 +- .../ops/declarable/helpers/cpu/lstm.cpp | 443 +- .../ops/declarable/helpers/cpu/lstsq.cpp | 162 +- .../ops/declarable/helpers/cpu/lup.cpp | 1123 +- .../declarable/helpers/cpu/matrixSetDiag.cpp | 95 +- .../declarable/helpers/cpu/matrix_band.cpp | 98 +- .../helpers/cpu/matrix_diag_part.cpp | 59 +- .../declarable/helpers/cpu/max_pooling.cpp | 91 +- .../ops/declarable/helpers/cpu/merge.cpp | 418 +- .../ops/declarable/helpers/cpu/meshgrid.cpp | 45 +- .../ops/declarable/helpers/cpu/minimax.cpp | 263 +- .../declarable/helpers/cpu/nth_element.cpp | 97 +- .../ops/declarable/helpers/cpu/one_hot.cpp | 172 +- .../ops/declarable/helpers/cpu/pad.cpp | 472 +- .../ops/declarable/helpers/cpu/percentile.cpp | 124 +- .../ops/declarable/helpers/cpu/polyGamma.cpp | 94 +- .../ops/declarable/helpers/cpu/prefix.cpp | 238 +- .../declarable/helpers/cpu/print_variable.cpp | 15 +- .../include/ops/declarable/helpers/cpu/qr.cpp | 213 +- .../ops/declarable/helpers/cpu/random.cpp | 358 +- .../declarable/helpers/cpu/randomShuffle.cpp | 180 +- .../declarable/helpers/cpu/random_crop.cpp | 87 +- .../ops/declarable/helpers/cpu/range.cpp | 49 +- .../ops/declarable/helpers/cpu/reverse.cpp | 356 +- .../ops/declarable/helpers/cpu/roll.cpp | 261 +- .../ops/declarable/helpers/cpu/s_t_b.cpp | 674 +- .../ops/declarable/helpers/cpu/s_t_d.cpp | 172 +- .../ops/declarable/helpers/cpu/scatter.cpp | 264 +- .../helpers/cpu/scatterUpdateAndSimple.cpp | 169 +- .../ops/declarable/helpers/cpu/segment.cpp | 1998 +-- .../declarable/helpers/cpu/sequence_mask.cpp | 42 +- .../ops/declarable/helpers/cpu/sg_cb.cpp | 1311 +- .../ops/declarable/helpers/cpu/shift.cpp | 98 +- .../ops/declarable/helpers/cpu/softmax.cpp | 434 +- .../ops/declarable/helpers/cpu/solve.cpp | 143 +- .../ops/declarable/helpers/cpu/split.cpp | 175 +- .../ops/declarable/helpers/cpu/sru.cpp | 627 +- .../ops/declarable/helpers/cpu/stack.cpp | 168 +- .../ops/declarable/helpers/cpu/svd.cpp | 1748 +-- .../ops/declarable/helpers/cpu/tile.cpp | 108 +- .../declarable/helpers/cpu/toggle_bits.cpp | 30 +- .../ops/declarable/helpers/cpu/top_k.cpp | 315 +- .../ops/declarable/helpers/cpu/trace.cpp | 32 +- .../helpers/cpu/triangular_solve.cpp | 224 +- .../ops/declarable/helpers/cpu/triu.cpp | 54 +- .../helpers/cpu/updaterAdaDelta.cpp | 166 +- .../declarable/helpers/cpu/updaterAdaGrad.cpp | 118 +- .../declarable/helpers/cpu/updaterAdaMax.cpp | 177 +- .../declarable/helpers/cpu/updaterAdam.cpp | 175 +- .../declarable/helpers/cpu/updaterAmsGrad.cpp | 212 +- .../declarable/helpers/cpu/updaterNadam.cpp | 179 +- .../helpers/cpu/updaterNesterovs.cpp | 109 +- .../declarable/helpers/cpu/updaterRmsProp.cpp | 111 +- .../ops/declarable/helpers/cpu/weights.cpp | 43 +- .../ops/declarable/helpers/cpu/zeta.cpp | 80 +- .../ops/declarable/helpers/crop_and_resize.h | 31 +- .../include/ops/declarable/helpers/cross.h | 108 +- .../declarable/helpers/cuda/BarnesHutTsne.cu | 409 +- .../declarable/helpers/cuda/activations.cu | 1127 +- .../ops/declarable/helpers/cuda/addBias.cu | 224 +- .../ops/declarable/helpers/cuda/adjust_hue.cu | 214 +- .../helpers/cuda/adjust_saturation.cu | 207 +- .../ops/declarable/helpers/cuda/axis.cu | 45 +- .../declarable/helpers/cuda/batched_gemm.cu | 295 +- .../ops/declarable/helpers/cuda/batchnorm.cu | 322 +- .../ops/declarable/helpers/cuda/betaInc.cu | 300 +- .../ops/declarable/helpers/cuda/col2im.cu | 320 +- .../declarable/helpers/cuda/compare_elem.cu | 214 +- .../helpers/cuda/compression/compression.cu | 72 +- .../helpers/cuda/compression/threshold.cu | 393 +- .../ops/declarable/helpers/cuda/concat.cu | 264 +- .../ops/declarable/helpers/cuda/confusion.cu | 216 +- .../helpers/cuda/convolutions_col2vol.cu | 171 +- .../helpers/cuda/convolutions_conv2d.cu | 170 +- .../helpers/cuda/convolutions_conv2dBP.cu | 221 +- .../cuda/convolutions_depthwiseConv2d.cu | 172 +- .../cuda/convolutions_depthwiseConv2dBP.cu | 227 +- .../helpers/cuda/convolutions_pooling2d.cu | 605 +- .../helpers/cuda/convolutions_pooling2dBP.cu | 317 +- .../helpers/cuda/convolutions_pooling3d.cu | 300 +- .../helpers/cuda/convolutions_pooling3dBP.cu | 345 +- .../helpers/cuda/convolutions_sconv2d.cu | 101 +- .../helpers/cuda/convolutions_upsampling2d.cu | 112 +- .../cuda/convolutions_upsampling2dBP.cu | 121 +- .../helpers/cuda/convolutions_upsampling3d.cu | 114 +- .../cuda/convolutions_upsampling3dBP.cu | 129 +- .../helpers/cuda/convolutions_vol2col.cu | 177 +- .../ops/declarable/helpers/cuda/cross.cu | 159 +- .../ops/declarable/helpers/cuda/d_t_s.cu | 164 +- .../ops/declarable/helpers/cuda/diGamma.cu | 92 +- .../ops/declarable/helpers/cuda/diag.cu | 186 +- .../ops/declarable/helpers/cuda/dilation2d.cu | 168 +- .../ops/declarable/helpers/cuda/dropout.cu | 528 +- .../ops/declarable/helpers/cuda/dynamic.cu | 647 +- .../helpers/cuda/extract_patches.cu | 206 +- .../helpers/cuda/fake_quantization.cu | 196 +- .../ops/declarable/helpers/cuda/flatten.cu | 131 +- .../ops/declarable/helpers/cuda/gather.cu | 295 +- .../ops/declarable/helpers/cuda/gather_nd.cu | 204 +- .../ops/declarable/helpers/cuda/gradient.cu | 28 +- .../ops/declarable/helpers/cuda/hamming.cu | 157 +- .../ops/declarable/helpers/cuda/hashcode.cu | 223 +- .../ops/declarable/helpers/cuda/histogram.cu | 239 +- .../helpers/cuda/histogramFixedWidth.cu | 178 +- .../ops/declarable/helpers/cuda/im2col.cu | 123 +- .../helpers/cuda/image_draw_bounding_boxes.cu | 321 +- .../declarable/helpers/cuda/image_resize.cu | 2674 ++-- .../helpers/cuda/image_suppression.cu | 797 +- .../declarable/helpers/cuda/imagesHelpers.cu | 741 +- .../ops/declarable/helpers/cuda/ismax.cu | 133 +- .../declarable/helpers/cuda/legacy/relu.cu | 190 +- .../declarable/helpers/cuda/legacy/tanh.cu | 112 +- .../declarable/helpers/cuda/legacy_helper.cu | 423 +- .../ops/declarable/helpers/cuda/lgamma.cu | 32 +- .../ops/declarable/helpers/cuda/lrn.cu | 300 +- .../ops/declarable/helpers/cuda/lstm.cu | 360 +- .../ops/declarable/helpers/cuda/lstsq.cu | 187 +- .../ops/declarable/helpers/cuda/lup.cu | 1944 +-- .../declarable/helpers/cuda/matrixSetDiag.cu | 156 +- .../declarable/helpers/cuda/matrix_band.cu | 138 +- .../helpers/cuda/matrix_diag_part.cu | 138 +- .../declarable/helpers/cuda/max_pooling.cu | 119 +- .../ops/declarable/helpers/cuda/maximum.cu | 125 +- .../ops/declarable/helpers/cuda/merge.cu | 894 +- .../ops/declarable/helpers/cuda/meshgrid.cu | 227 +- .../ops/declarable/helpers/cuda/minimum.cu | 155 +- .../declarable/helpers/cuda/nth_element.cu | 120 +- .../ops/declarable/helpers/cuda/one_hot.cu | 159 +- .../ops/declarable/helpers/cuda/pad.cu | 501 +- .../ops/declarable/helpers/cuda/percentile.cu | 218 +- .../ops/declarable/helpers/cuda/polyGamma.cu | 138 +- .../ops/declarable/helpers/cuda/prefix.cu | 276 +- .../declarable/helpers/cuda/print_variable.cu | 78 +- .../include/ops/declarable/helpers/cuda/qr.cu | 307 +- .../ops/declarable/helpers/cuda/random.cu | 658 +- .../declarable/helpers/cuda/random_crop.cu | 33 +- .../ops/declarable/helpers/cuda/range.cu | 52 +- .../ops/declarable/helpers/cuda/reverse.cu | 443 +- .../ops/declarable/helpers/cuda/roll.cu | 546 +- .../ops/declarable/helpers/cuda/s_t_b.cu | 818 +- .../ops/declarable/helpers/cuda/s_t_d.cu | 174 +- .../ops/declarable/helpers/cuda/scatter.cu | 1468 ++- .../declarable/helpers/cuda/scatter_simple.cu | 121 +- .../declarable/helpers/cuda/scatter_update.cu | 219 +- .../ops/declarable/helpers/cuda/segment.cu | 236 +- .../declarable/helpers/cuda/segment_max.cu | 904 +- .../declarable/helpers/cuda/segment_mean.cu | 866 +- .../declarable/helpers/cuda/segment_min.cu | 856 +- .../declarable/helpers/cuda/segment_prod.cu | 780 +- .../declarable/helpers/cuda/segment_sqrtn.cu | 479 +- .../declarable/helpers/cuda/segment_sum.cu | 801 +- .../declarable/helpers/cuda/sequence_mask.cu | 80 +- .../ops/declarable/helpers/cuda/sg_cb.cu | 1657 +-- .../ops/declarable/helpers/cuda/shift.cu | 98 +- .../ops/declarable/helpers/cuda/solve.cu | 251 +- .../ops/declarable/helpers/cuda/split.cu | 259 +- .../ops/declarable/helpers/cuda/sru.cu | 942 +- .../ops/declarable/helpers/cuda/stack.cu | 341 +- .../ops/declarable/helpers/cuda/svd.cu | 1230 +- .../declarable/helpers/cuda/toggle_bits.cu | 35 +- .../ops/declarable/helpers/cuda/top_k.cu | 505 +- .../ops/declarable/helpers/cuda/transforms.cu | 1709 +-- .../helpers/cuda/triangular_solve.cu | 436 +- .../helpers/cuda/updaterAdaDelta.cu | 211 +- .../declarable/helpers/cuda/updaterAdaGrad.cu | 167 +- .../declarable/helpers/cuda/updaterAdaMax.cu | 233 +- .../declarable/helpers/cuda/updaterAdam.cu | 230 +- .../declarable/helpers/cuda/updaterAmsGrad.cu | 269 +- .../declarable/helpers/cuda/updaterNadam.cu | 226 +- .../helpers/cuda/updaterNesterovs.cu | 171 +- .../declarable/helpers/cuda/updaterRmsProp.cu | 179 +- .../ops/declarable/helpers/cuda/weights.cu | 193 +- .../ops/declarable/helpers/cuda/zeta.cu | 93 +- .../include/ops/declarable/helpers/d_t_s.h | 9 +- libnd4j/include/ops/declarable/helpers/diag.h | 14 +- .../ops/declarable/helpers/dilation2d.h | 119 +- .../include/ops/declarable/helpers/dropout.h | 25 +- .../include/ops/declarable/helpers/dynamic.h | 31 +- .../ops/declarable/helpers/extract_patches.h | 10 +- .../declarable/helpers/fake_quantization.h | 15 +- .../include/ops/declarable/helpers/flatten.h | 59 +- .../ops/declarable/helpers/gammaMathFunc.h | 152 +- .../include/ops/declarable/helpers/gather.h | 12 +- .../include/ops/declarable/helpers/gradient.h | 17 +- libnd4j/include/ops/declarable/helpers/gru.h | 64 +- .../include/ops/declarable/helpers/hamming.h | 15 +- .../include/ops/declarable/helpers/hashcode.h | 71 +- .../include/ops/declarable/helpers/helpers.h | 25 +- .../ops/declarable/helpers/histogram.h | 13 +- .../declarable/helpers/histogramFixedWidth.h | 10 +- .../include/ops/declarable/helpers/im2col.h | 11 +- .../helpers/image_draw_bounding_boxes.h | 9 +- .../ops/declarable/helpers/image_resize.h | 60 +- .../declarable/helpers/image_suppression.h | 25 +- .../ops/declarable/helpers/imagesHelpers.h | 33 +- .../ops/declarable/helpers/impl/choose.cpp | 214 +- .../ops/declarable/helpers/impl/gru.cpp | 968 +- .../helpers/impl/knn_mindistance.cpp | 81 +- .../ops/declarable/helpers/impl/listdiff.cpp | 181 +- .../ops/declarable/helpers/impl/lstm.cpp | 221 +- .../ops/declarable/helpers/impl/lstmLayer.cpp | 3010 +++-- .../declarable/helpers/impl/multiUnique.cpp | 69 +- .../ops/declarable/helpers/impl/rnn.cpp | 134 +- .../helpers/impl/sparse_to_dense.cpp | 203 +- .../ops/declarable/helpers/impl/unique.cpp | 143 +- .../ops/declarable/helpers/impl/where.cpp | 65 +- .../include/ops/declarable/helpers/ismax.h | 14 +- libnd4j/include/ops/declarable/helpers/knn.h | 13 +- .../ops/declarable/helpers/legacy_helpers.h | 113 +- .../include/ops/declarable/helpers/lgamma.h | 14 +- .../include/ops/declarable/helpers/listdiff.h | 12 +- libnd4j/include/ops/declarable/helpers/lrn.h | 15 +- libnd4j/include/ops/declarable/helpers/lstm.h | 95 +- .../ops/declarable/helpers/lstmBlock.h | 41 +- .../ops/declarable/helpers/lstmLayer.h | 73 +- .../include/ops/declarable/helpers/lstsq.h | 12 +- libnd4j/include/ops/declarable/helpers/lup.h | 34 +- .../include/ops/declarable/helpers/matmul.h | 13 +- .../ops/declarable/helpers/matrixSetDiag.h | 12 +- .../ops/declarable/helpers/matrix_band.h | 10 +- .../ops/declarable/helpers/matrix_diag_part.h | 9 +- .../ops/declarable/helpers/max_pooling.h | 10 +- .../include/ops/declarable/helpers/meshgrid.h | 14 +- .../include/ops/declarable/helpers/minimax.h | 14 +- .../ops/declarable/helpers/multiUnique.h | 9 +- .../ops/declarable/helpers/nth_element.h | 9 +- .../include/ops/declarable/helpers/one_hot.h | 18 +- .../ops/declarable/helpers/percentile.h | 13 +- .../include/ops/declarable/helpers/prefix.h | 26 +- .../ops/declarable/helpers/print_variable.h | 13 +- libnd4j/include/ops/declarable/helpers/qr.h | 10 +- .../include/ops/declarable/helpers/random.h | 23 +- .../ops/declarable/helpers/random_crop.h | 11 +- .../include/ops/declarable/helpers/range.h | 18 +- .../include/ops/declarable/helpers/reverse.h | 22 +- libnd4j/include/ops/declarable/helpers/rnn.h | 23 +- libnd4j/include/ops/declarable/helpers/roll.h | 13 +- .../include/ops/declarable/helpers/s_t_b.h | 61 +- .../include/ops/declarable/helpers/s_t_d.h | 9 +- .../include/ops/declarable/helpers/scatter.h | 34 +- .../include/ops/declarable/helpers/segment.h | 96 +- .../ops/declarable/helpers/segment_common.h | 16 +- .../ops/declarable/helpers/sequence_mask.h | 9 +- .../include/ops/declarable/helpers/sg_cb.h | 30 +- .../include/ops/declarable/helpers/shift.h | 26 +- .../include/ops/declarable/helpers/solve.h | 14 +- .../ops/declarable/helpers/sparse_to_dense.h | 13 +- libnd4j/include/ops/declarable/helpers/sru.h | 35 +- .../include/ops/declarable/helpers/stack.h | 23 +- libnd4j/include/ops/declarable/helpers/svd.h | 21 +- .../ops/declarable/helpers/threshold.h | 18 +- .../ops/declarable/helpers/toggle_bits.h | 19 +- .../include/ops/declarable/helpers/top_k.h | 14 +- .../ops/declarable/helpers/transforms.h | 113 +- .../ops/declarable/helpers/triangular_solve.h | 15 +- .../include/ops/declarable/helpers/unique.h | 13 +- .../ops/declarable/helpers/updatersHelpers.h | 52 +- .../include/ops/declarable/helpers/weights.h | 9 +- .../include/ops/declarable/helpers/where.h | 13 +- libnd4j/include/ops/declarable/helpers/zeta.h | 139 +- .../include/ops/declarable/impl/BooleanOp.cpp | 185 +- .../declarable/impl/BroadcastableBoolOp.cpp | 102 +- .../ops/declarable/impl/BroadcastableOp.cpp | 128 +- .../declarable/impl/DeclarableCustomOp.cpp | 16 +- .../ops/declarable/impl/DeclarableListOp.cpp | 237 +- .../ops/declarable/impl/DeclarableOp.cpp | 2232 ++-- .../declarable/impl/DeclarableReductionOp.cpp | 75 +- .../declarable/impl/LegacyBroadcastBoolOp.cpp | 171 +- .../ops/declarable/impl/LegacyBroadcastOp.cpp | 192 +- .../declarable/impl/LegacyIndexReduceOp.cpp | 358 +- .../include/ops/declarable/impl/LegacyOp.cpp | 65 +- .../impl/LegacyPairwiseTransformBoolOp.cpp | 87 +- .../impl/LegacyPairwiseTransformOp.cpp | 114 +- .../ops/declarable/impl/LegacyRandomOp.cpp | 875 +- .../ops/declarable/impl/LegacyReduce3Op.cpp | 231 +- .../declarable/impl/LegacyReduceBoolOp.cpp | 271 +- .../declarable/impl/LegacyReduceFloatOp.cpp | 270 +- .../declarable/impl/LegacyReduceLongOp.cpp | 272 +- .../ops/declarable/impl/LegacyReduceOp.cpp | 294 +- .../declarable/impl/LegacyReduceSameOp.cpp | 263 +- .../declarable/impl/LegacyScalarBoolOp.cpp | 143 +- .../ops/declarable/impl/LegacyScalarOp.cpp | 146 +- .../ops/declarable/impl/LegacyStatsOp.cpp | 214 +- .../declarable/impl/LegacyTransformAnyOp.cpp | 86 +- .../declarable/impl/LegacyTransformBoolOp.cpp | 80 +- .../impl/LegacyTransformFloatOp.cpp | 86 +- .../ops/declarable/impl/LegacyTransformOp.cpp | 67 +- .../declarable/impl/LegacyTransformSameOp.cpp | 86 +- .../impl/LegacyTransformStrictOp.cpp | 85 +- .../include/ops/declarable/impl/LogicOp.cpp | 31 +- .../ops/declarable/impl/OpDescriptor.cpp | 550 +- .../ops/declarable/impl/OpRegistrator.cpp | 439 +- .../include/ops/declarable/impl/OpTuple.cpp | 44 +- .../ops/declarable/impl/PlatformHelper.cpp | 150 +- .../declarable/platform/cudnn/avgpool2d.cu | 241 +- .../declarable/platform/cudnn/avgpool3d.cu | 265 +- .../declarable/platform/cudnn/batchnorm.cu | 1209 +- .../ops/declarable/platform/cudnn/conv2d.cu | 1195 +- .../ops/declarable/platform/cudnn/conv3d.cu | 1141 +- .../declarable/platform/cudnn/cudnnUtils.cu | 856 +- .../declarable/platform/cudnn/cudnnUtils.h | 165 +- .../platform/cudnn/depthwiseConv2d.cu | 1144 +- .../declarable/platform/cudnn/maxpool2d.cu | 225 +- .../declarable/platform/cudnn/maxpool3d.cu | 253 +- .../platform/mkldnn/avgpooling2d.cpp | 204 +- .../platform/mkldnn/avgpooling3d.cpp | 215 +- .../declarable/platform/mkldnn/batchnorm.cpp | 1268 +- .../ops/declarable/platform/mkldnn/conv2d.cpp | 1157 +- .../ops/declarable/platform/mkldnn/conv3d.cpp | 1256 +- .../declarable/platform/mkldnn/deconv2d.cpp | 1055 +- .../platform/mkldnn/deconv2d_tf.cpp | 444 +- .../declarable/platform/mkldnn/deconv3d.cpp | 1103 +- .../platform/mkldnn/depthwiseConv2d.cpp | 1080 +- .../ops/declarable/platform/mkldnn/lrn.cpp | 136 +- .../declarable/platform/mkldnn/lstmLayer.cpp | 1168 +- .../ops/declarable/platform/mkldnn/matmul.cpp | 648 +- .../platform/mkldnn/maxpooling2d.cpp | 193 +- .../platform/mkldnn/maxpooling3d.cpp | 205 +- .../platform/mkldnn/mkldnnUtils.cpp | 1247 +- .../declarable/platform/mkldnn/mkldnnUtils.h | 327 +- .../declarable/platform/mkldnn/softmax.cpp | 449 +- .../ops/declarable/platform/mkldnn/tanh.cpp | 397 +- .../declarable/platform/mkldnn/xw_plus_b.cpp | 888 +- libnd4j/include/ops/gemm.h | 56 +- .../ops/impl/BroadcastBoolOpsTuple.cpp | 10 +- .../include/ops/impl/BroadcastIntOpsTuple.cpp | 10 +- .../include/ops/impl/BroadcastOpsTuple.cpp | 81 +- .../compilation_units/specials_double_0.cpp | 9 +- .../compilation_units/specials_double_1.cpp | 3 +- .../compilation_units/specials_double_2.cpp | 3 +- .../compilation_units/specials_double_3.cpp | 3 +- .../compilation_units/specials_double_4.cpp | 3 +- .../compilation_units/specials_double_5.cpp | 3 +- .../compilation_units/specials_double_6.cpp | 3 +- .../compilation_units/specials_double_7.cpp | 3 +- .../compilation_units/specials_double_8.cpp | 3 +- .../compilation_units/specials_double_9.cpp | 3 +- .../compilation_units/specials_single_0.cpp | 2 +- .../compilation_units/specials_single_1.cpp | 2 +- .../compilation_units/specials_single_2.cpp | 2 +- .../compilation_units/specials_single_3.cpp | 2 +- .../compilation_units/specials_single_4.cpp | 2 +- .../compilation_units/specials_single_5.cpp | 2 +- .../compilation_units/specials_single_6.cpp | 2 +- .../compilation_units/specials_single_7.cpp | 2 +- .../compilation_units/specials_single_8.cpp | 2 +- .../compilation_units/specials_single_9.cpp | 2 +- libnd4j/include/ops/impl/gemm.cpp | 226 +- libnd4j/include/ops/impl/specials_double.hpp | 474 +- libnd4j/include/ops/impl/specials_single.hpp | 880 +- libnd4j/include/ops/impl/specials_sparse.cpp | 356 +- libnd4j/include/ops/meta_ops.h | 296 +- libnd4j/include/ops/ops.h | 8464 ++++++------ libnd4j/include/ops/random_ops.h | 538 +- libnd4j/include/ops/special_random_ops.h | 1572 +-- libnd4j/include/ops/specials.h | 127 +- libnd4j/include/ops/specials_cuda.h | 120 +- libnd4j/include/ops/specials_sparse.h | 72 +- .../performance/benchmarking/BenchmarkSuit.h | 26 +- .../benchmarking/FullBenchmarkSuit.h | 13 +- .../benchmarking/LightBenchmarkSuit.h | 13 +- .../benchmarking/impl/FullBenchmarkSuit.cpp | 3686 +++--- .../benchmarking/impl/LightBenchmarkSuit.cpp | 1202 +- libnd4j/include/samediff.h | 4 +- libnd4j/include/system/BlasVersionHelper.h | 22 +- libnd4j/include/system/Environment.h | 197 +- libnd4j/include/system/buffer.h | 279 +- libnd4j/include/system/dll.h | 6 +- libnd4j/include/system/enum_boilerplate.h | 224 +- libnd4j/include/system/msvc.h | 22 +- libnd4j/include/system/nd4jmalloc.h | 2 +- libnd4j/include/system/nd4jmemset.h | 2 +- libnd4j/include/system/op_enums.h | 201 +- libnd4j/include/system/openmp_pragmas.h | 67 +- libnd4j/include/system/optype.h | 12 +- libnd4j/include/system/pairwise_util.h | 609 +- libnd4j/include/system/platform_boilerplate.h | 64 +- libnd4j/include/system/play.h | 122 +- libnd4j/include/system/pointercast.h | 51 +- libnd4j/include/system/type_boilerplate.h | 2260 +++- libnd4j/include/system/util.h | 16 +- libnd4j/include/types/bfloat16.h | 544 +- libnd4j/include/types/float16.h | 881 +- libnd4j/include/types/float8.h | 228 +- libnd4j/include/types/impl/int16.cpp | 12 +- libnd4j/include/types/impl/pair.cpp | 18 +- libnd4j/include/types/impl/triple.cpp | 22 +- libnd4j/include/types/impl/uint16.cpp | 13 +- libnd4j/include/types/impl/uint8.cpp | 12 +- libnd4j/include/types/impl/utf8string.cpp | 88 +- libnd4j/include/types/int16.h | 88 +- libnd4j/include/types/int8.h | 85 +- libnd4j/include/types/pair.h | 25 +- libnd4j/include/types/triple.h | 35 +- libnd4j/include/types/types.h | 607 +- libnd4j/include/types/u64.h | 68 +- libnd4j/include/types/uint16.h | 91 +- libnd4j/include/types/uint8.h | 87 +- libnd4j/include/types/utf8string.h | 37 +- libnd4j/minifier/graphopt.cpp | 191 +- libnd4j/minifier/graphopt.h | 74 +- libnd4j/minifier/minifier.cpp | 218 +- libnd4j/server/GraphServer.cpp | 280 +- libnd4j/server/GraphServer.h | 49 +- libnd4j/tests_cpu/layers_tests/AllTests.cpp | 20 +- .../layers_tests/ArrayOptionsTests.cpp | 77 +- libnd4j/tests_cpu/layers_tests/AtomicTests.cu | 305 +- .../tests_cpu/layers_tests/AttentionTests.cpp | 176 +- .../tests_cpu/layers_tests/BackpropTests.cpp | 27 +- .../layers_tests/BitwiseUtilsTests.cpp | 46 +- .../layers_tests/BooleanOpsTests.cpp | 110 +- .../layers_tests/BroadcastableOpsTests.cpp | 947 +- .../tests_cpu/layers_tests/BrodcastTests.cpp | 75 +- libnd4j/tests_cpu/layers_tests/CnpyTests.cpp | 41 +- .../layers_tests/ConstantShapeHelperTests.cpp | 261 +- .../tests_cpu/layers_tests/ContextTests.cpp | 391 +- .../layers_tests/ConvolutionTests1.cpp | 7066 ++++++---- .../layers_tests/ConvolutionTests2.cpp | 8610 +++++++++---- libnd4j/tests_cpu/layers_tests/CuDnnTests.cu | 210 +- .../layers_tests/CudaBasicsTests1.cu | 6129 +++++---- .../layers_tests/CudaBasicsTests2.cu | 1652 ++- .../layers_tests/CudaExtraArgumentsTests.cu | 63 +- .../layers_tests/CudaLaunchHelperTests.cpp | 14 +- .../layers_tests/DataBufferTests.cpp | 69 +- .../layers_tests/DataBufferTestsCuda.cu | 35 +- .../layers_tests/DataTypesValidationTests.cpp | 183 +- .../layers_tests/DeclarableOpsTests1.cpp | 4265 +++--- .../layers_tests/DeclarableOpsTests10.cpp | 4737 ++++--- .../layers_tests/DeclarableOpsTests11.cpp | 6750 +++++----- .../layers_tests/DeclarableOpsTests12.cpp | 4740 +++---- .../layers_tests/DeclarableOpsTests13.cpp | 6042 +++++---- .../layers_tests/DeclarableOpsTests14.cpp | 3576 +++--- .../layers_tests/DeclarableOpsTests15.cpp | 3140 +++-- .../layers_tests/DeclarableOpsTests16.cpp | 1703 ++- .../layers_tests/DeclarableOpsTests17.cpp | 90 +- .../layers_tests/DeclarableOpsTests18.cpp | 3905 +++--- .../layers_tests/DeclarableOpsTests19.cpp | 484 +- .../layers_tests/DeclarableOpsTests2.cpp | 5786 +++++---- .../layers_tests/DeclarableOpsTests3.cpp | 4276 +++--- .../layers_tests/DeclarableOpsTests4.cpp | 3520 ++--- .../layers_tests/DeclarableOpsTests5.cpp | 4348 +++---- .../layers_tests/DeclarableOpsTests6.cpp | 4170 +++--- .../layers_tests/DeclarableOpsTests7.cpp | 10708 ++++++++-------- .../layers_tests/DeclarableOpsTests8.cpp | 5151 ++++---- .../layers_tests/DeclarableOpsTests9.cpp | 3188 ++--- .../layers_tests/DeclarableOpsTestsCuda1.cu | 68 +- libnd4j/tests_cpu/layers_tests/EmptyTests.cpp | 270 +- .../layers_tests/ExecutionLayerTests.cpp | 51 +- .../layers_tests/ExtraArgumentsTests.cpp | 52 +- .../layers_tests/FlatBuffersTests.cpp | 674 +- .../tests_cpu/layers_tests/FlatUtilsTests.cpp | 69 +- .../layers_tests/GraphAnalysisTests.cpp | 1487 ++- .../layers_tests/GraphExecutorTests.cpp | 103 +- .../layers_tests/GraphHolderTests.cpp | 18 +- .../GraphRandomGeneratorTests.cpp | 326 +- libnd4j/tests_cpu/layers_tests/GraphTests.cpp | 62 +- .../tests_cpu/layers_tests/GraphTests2.cpp | 173 +- .../tests_cpu/layers_tests/HashUtilsTests.cpp | 20 +- .../tests_cpu/layers_tests/HelpersTests1.cpp | 4788 ++++--- .../tests_cpu/layers_tests/IndexingTests.cpp | 532 +- .../layers_tests/JavaInteropCudaTests.cu | 104 +- .../layers_tests/JavaInteropTests.cpp | 2396 ++-- libnd4j/tests_cpu/layers_tests/LambdaTests.cu | 212 +- .../layers_tests/LaunchContextCudaTests.cu | 125 +- .../layers_tests/LegacyOpsCudaTests.cu | 52 +- .../tests_cpu/layers_tests/LegacyOpsTests.cpp | 1091 +- .../layers_tests/ListOperationsTests.cpp | 432 +- .../layers_tests/LoopCoordsHelperTests.cpp | 298 +- .../layers_tests/ManagedDataBufferTests.cpp | 38 +- .../layers_tests/MemoryUtilsTests.cpp | 17 +- .../tests_cpu/layers_tests/MklDnnTests.cpp | 97 +- libnd4j/tests_cpu/layers_tests/MmapTests.cpp | 34 +- .../layers_tests/MultiDataTypeTests.cpp | 2559 ++-- .../layers_tests/MultiDeviceTests.cpp | 57 +- .../layers_tests/NDArrayConstructorsTests.cu | 205 +- .../layers_tests/NDArrayCudaBasicsTests.cu | 3719 +++--- .../layers_tests/NDArrayListTests.cpp | 42 +- .../tests_cpu/layers_tests/NDArrayTests.cpp | 3249 ++--- .../tests_cpu/layers_tests/NDArrayTests2.cpp | 1573 +-- .../tests_cpu/layers_tests/NativeOpsTests.cpp | 2294 ++-- libnd4j/tests_cpu/layers_tests/NlpTests.cpp | 843 +- libnd4j/tests_cpu/layers_tests/NodeTests.cpp | 61 +- .../layers_tests/OmpLaunchHelperTests.cpp | 97 +- .../tests_cpu/layers_tests/OneOffTests.cpp | 277 +- .../layers_tests/OpSequenceTests.cpp | 64 +- .../tests_cpu/layers_tests/OpTrackerTests.cpp | 48 +- .../tests_cpu/layers_tests/OpTupleTests.cpp | 47 +- .../tests_cpu/layers_tests/PairwiseTests.cpp | 39 +- .../tests_cpu/layers_tests/ParityOpsTests.cpp | 2366 ++-- .../layers_tests/PerformanceTests.cpp | 170 +- .../layers_tests/PlaygroundTests.cpp | 533 +- .../tests_cpu/layers_tests/ProtoBufTests.cpp | 31 +- .../layers_tests/QuantizationTests.cpp | 42 +- libnd4j/tests_cpu/layers_tests/RNGTests.cpp | 2062 +-- .../tests_cpu/layers_tests/ResultSetTests.cpp | 25 +- .../tests_cpu/layers_tests/SanityTests.cpp | 30 +- .../tests_cpu/layers_tests/ScalarTests.cpp | 247 +- .../layers_tests/ServerRelatedTests.cpp | 156 +- libnd4j/tests_cpu/layers_tests/ShapeTests.cpp | 331 +- .../tests_cpu/layers_tests/ShapeTests2.cpp | 1213 +- .../layers_tests/ShapeUtilsTests.cpp | 275 +- .../tests_cpu/layers_tests/SingleDimTests.cpp | 172 +- .../tests_cpu/layers_tests/SortCpuTests.cpp | 135 +- .../tests_cpu/layers_tests/SortCudaTests.cu | 203 +- .../layers_tests/SparseUtilsTest.cpp | 168 +- libnd4j/tests_cpu/layers_tests/StashTests.cpp | 74 +- .../tests_cpu/layers_tests/StringTests.cpp | 941 +- libnd4j/tests_cpu/layers_tests/TadTests.cpp | 622 +- .../tests_cpu/layers_tests/ThreadsTests.cpp | 268 +- .../tests_cpu/layers_tests/TypeCastTests.cpp | 65 +- .../layers_tests/VariableProxyTests.cpp | 138 +- .../layers_tests/VariableSpaceTests.cpp | 118 +- .../tests_cpu/layers_tests/VariableTests.cpp | 140 +- .../tests_cpu/layers_tests/WorkspaceTests.cpp | 293 +- .../tests_cpu/layers_tests/WorkspaceTests.cu | 51 +- libnd4j/tests_cpu/layers_tests/testinclude.h | 51 +- libnd4j/tests_cpu/layers_tests/testlayers.h | 21 +- 1715 files changed, 273364 insertions(+), 233698 deletions(-) mode change 100755 => 100644 libnd4j/include/cblas.h mode change 100755 => 100644 libnd4j/include/cblas_enum_conversion.h mode change 100755 => 100644 libnd4j/include/legacy/NativeOps.h mode change 100755 => 100644 libnd4j/include/legacy/cuda/NativeOps.cu mode change 100755 => 100644 libnd4j/include/loops/broadcasting.h mode change 100755 => 100644 libnd4j/include/loops/indexreduce.h mode change 100755 => 100644 libnd4j/include/loops/pairwise_transform.h mode change 100755 => 100644 libnd4j/include/loops/reduce3.h mode change 100755 => 100644 libnd4j/include/loops/scalar.h mode change 100755 => 100644 libnd4j/include/loops/summarystatsreduce.h mode change 100755 => 100644 libnd4j/include/system/buffer.h mode change 100755 => 100644 libnd4j/include/system/pairwise_util.h diff --git a/libnd4j/include/array/ArrayOptions.h b/libnd4j/include/array/ArrayOptions.h index 0e874155d484..2b9b98d4cec3 100644 --- a/libnd4j/include/array/ArrayOptions.h +++ b/libnd4j/include/array/ArrayOptions.h @@ -21,22 +21,21 @@ #ifndef ND4J_ARRAY_OPTIONS_H #define ND4J_ARRAY_OPTIONS_H -#include -#include -#include -#include #include +#include #include #include -#include +#include +#include +#include +#include #define ARRAY_SPARSE 2 #define ARRAY_COMPRESSED 4 #define ARRAY_EMPTY 8 #define ARRAY_RAGGED 16 - #define ARRAY_CSR 32 #define ARRAY_CSC 64 #define ARRAY_COO 128 @@ -79,295 +78,304 @@ #define ARRAY_UTF16 4194304 #define ARRAY_UTF32 16777216 -// flag for extras +// flag for extras #define ARRAY_EXTRAS 2097152 - // flag for signed/unsigned integers #define ARRAY_UNSIGNED 8388608 - namespace sd { - class SD_EXPORT ArrayOptions { - - private: - static FORCEINLINE _CUDA_HD Nd4jLong& extra(Nd4jLong* shape); +class SD_EXPORT ArrayOptions { + private: + static FORCEINLINE _CUDA_HD Nd4jLong &extra(Nd4jLong *shape); - public: - static FORCEINLINE _CUDA_HD bool isNewFormat(const Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD bool hasPropertyBitSet(const Nd4jLong *shapeInfo, int property); - static FORCEINLINE _CUDA_HD bool togglePropertyBit(Nd4jLong *shapeInfo, int property); - static FORCEINLINE _CUDA_HD void unsetPropertyBit(Nd4jLong *shapeInfo, int property); - static FORCEINLINE _CUDA_HD void setPropertyBit(Nd4jLong *shapeInfo, int property); - static FORCEINLINE _CUDA_HD void setPropertyBits(Nd4jLong *shapeInfo, std::initializer_list properties); + public: + static FORCEINLINE _CUDA_HD bool isNewFormat(const Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD bool hasPropertyBitSet(const Nd4jLong *shapeInfo, + int property); + static FORCEINLINE _CUDA_HD bool togglePropertyBit(Nd4jLong *shapeInfo, + int property); + static FORCEINLINE _CUDA_HD void unsetPropertyBit(Nd4jLong *shapeInfo, + int property); + static FORCEINLINE _CUDA_HD void setPropertyBit(Nd4jLong *shapeInfo, + int property); + static FORCEINLINE _CUDA_HD void setPropertyBits( + Nd4jLong *shapeInfo, std::initializer_list properties); - static FORCEINLINE _CUDA_HD bool isSparseArray(Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD bool isUnsigned(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD bool isSparseArray(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD bool isUnsigned(Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD sd::DataType dataType(const Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD sd::DataType dataType(const Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD SpaceType spaceType(Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD SpaceType spaceType(const Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD SpaceType spaceType(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD SpaceType spaceType(const Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD ArrayType arrayType(Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD ArrayType arrayType(const Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD ArrayType arrayType(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD ArrayType arrayType(const Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD SparseType sparseType(Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD SparseType sparseType(const Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD SparseType sparseType(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD SparseType sparseType(const Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD bool hasExtraProperties(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD bool hasExtraProperties(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD void resetDataType(Nd4jLong *shapeInfo); + static FORCEINLINE _CUDA_HD void setDataType(Nd4jLong *shapeInfo, + const sd::DataType dataType); - static FORCEINLINE _CUDA_HD void resetDataType(Nd4jLong *shapeInfo); - static FORCEINLINE _CUDA_HD void setDataType(Nd4jLong *shapeInfo, const sd::DataType dataType); + static FORCEINLINE _CUDA_HD void copyDataType(Nd4jLong *to, + const Nd4jLong *from); +}; - static FORCEINLINE _CUDA_HD void copyDataType(Nd4jLong* to, const Nd4jLong* from); - }; - - FORCEINLINE _CUDA_HD Nd4jLong& ArrayOptions::extra(Nd4jLong* shape) { - return shape[shape[0] + shape[0] + 1]; - } - - FORCEINLINE _CUDA_HD bool ArrayOptions::isNewFormat(const Nd4jLong *shapeInfo) { - return (extra(const_cast(shapeInfo)) != 0); - } +FORCEINLINE _CUDA_HD Nd4jLong &ArrayOptions::extra(Nd4jLong *shape) { + return shape[shape[0] + shape[0] + 1]; +} +FORCEINLINE _CUDA_HD bool ArrayOptions::isNewFormat(const Nd4jLong *shapeInfo) { + return (extra(const_cast(shapeInfo)) != 0); +} - FORCEINLINE _CUDA_HD bool ArrayOptions::isSparseArray(Nd4jLong *shapeInfo) { - return hasPropertyBitSet(shapeInfo, ARRAY_SPARSE); - } +FORCEINLINE _CUDA_HD bool ArrayOptions::isSparseArray(Nd4jLong *shapeInfo) { + return hasPropertyBitSet(shapeInfo, ARRAY_SPARSE); +} - FORCEINLINE _CUDA_HD bool ArrayOptions::hasExtraProperties(Nd4jLong *shapeInfo) { - return hasPropertyBitSet(shapeInfo, ARRAY_EXTRAS); - } +FORCEINLINE _CUDA_HD bool ArrayOptions::hasExtraProperties( + Nd4jLong *shapeInfo) { + return hasPropertyBitSet(shapeInfo, ARRAY_EXTRAS); +} - FORCEINLINE _CUDA_HD bool ArrayOptions::hasPropertyBitSet(const Nd4jLong *shapeInfo, int property) { - if (!isNewFormat(shapeInfo)) - return false; +FORCEINLINE _CUDA_HD bool ArrayOptions::hasPropertyBitSet( + const Nd4jLong *shapeInfo, int property) { + if (!isNewFormat(shapeInfo)) return false; - return ((extra(const_cast(shapeInfo)) & property) == property); - } + return ((extra(const_cast(shapeInfo)) & property) == property); +} - FORCEINLINE _CUDA_HD bool ArrayOptions::isUnsigned(Nd4jLong *shapeInfo) { - if (!isNewFormat(shapeInfo)) - return false; +FORCEINLINE _CUDA_HD bool ArrayOptions::isUnsigned(Nd4jLong *shapeInfo) { + if (!isNewFormat(shapeInfo)) return false; - return hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED); - } + return hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED); +} - FORCEINLINE _CUDA_HD sd::DataType ArrayOptions::dataType(const Nd4jLong *shapeInfo) { - /*if (hasPropertyBitSet(shapeInfo, ARRAY_QUANTIZED)) - return sd::DataType::QINT8; - else */if (hasPropertyBitSet(shapeInfo, ARRAY_FLOAT)) - return sd::DataType::FLOAT32; - else if (hasPropertyBitSet(shapeInfo, ARRAY_DOUBLE)) - return sd::DataType::DOUBLE; - else if (hasPropertyBitSet(shapeInfo, ARRAY_HALF)) - return sd::DataType::HALF; - else if (hasPropertyBitSet(shapeInfo, ARRAY_BHALF)) - return sd::DataType::BFLOAT16; - else if (hasPropertyBitSet(shapeInfo, ARRAY_BOOL)) - return sd::DataType ::BOOL; - else if (hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED)) { - if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR)) - return sd::DataType ::UINT8; - else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT)) - return sd::DataType ::UINT16; - else if (hasPropertyBitSet(shapeInfo, ARRAY_INT)) - return sd::DataType ::UINT32; - else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) - return sd::DataType ::UINT64; - else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8)) - return sd::DataType ::UTF8; - else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16)) - return sd::DataType ::UTF16; - else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32)) - return sd::DataType ::UTF32; - else { - //shape::printShapeInfoLinear("Bad unsigned datatype (not)stored in shape", const_cast(shapeInfo)); +FORCEINLINE _CUDA_HD sd::DataType ArrayOptions::dataType( + const Nd4jLong *shapeInfo) { + /*if (hasPropertyBitSet(shapeInfo, ARRAY_QUANTIZED)) + return sd::DataType::QINT8; + else */ + if (hasPropertyBitSet(shapeInfo, ARRAY_FLOAT)) + return sd::DataType::FLOAT32; + else if (hasPropertyBitSet(shapeInfo, ARRAY_DOUBLE)) + return sd::DataType::DOUBLE; + else if (hasPropertyBitSet(shapeInfo, ARRAY_HALF)) + return sd::DataType::HALF; + else if (hasPropertyBitSet(shapeInfo, ARRAY_BHALF)) + return sd::DataType::BFLOAT16; + else if (hasPropertyBitSet(shapeInfo, ARRAY_BOOL)) + return sd::DataType ::BOOL; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED)) { + if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR)) + return sd::DataType ::UINT8; + else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT)) + return sd::DataType ::UINT16; + else if (hasPropertyBitSet(shapeInfo, ARRAY_INT)) + return sd::DataType ::UINT32; + else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) + return sd::DataType ::UINT64; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8)) + return sd::DataType ::UTF8; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16)) + return sd::DataType ::UTF16; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32)) + return sd::DataType ::UTF32; + else { + // shape::printShapeInfoLinear("Bad unsigned datatype (not)stored in + // shape", const_cast(shapeInfo)); #ifndef __CUDA_ARCH__ - throw std::runtime_error("Bad datatype A"); + throw std::runtime_error("Bad datatype A"); #endif - } - } - else if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR)) - return sd::DataType::INT8; - else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT)) - return sd::DataType::INT16; - else if (hasPropertyBitSet(shapeInfo, ARRAY_INT)) - return sd::DataType::INT32; - else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) - return sd::DataType::INT64; - else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8)) - return sd::DataType::UTF8; - else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16)) - return sd::DataType::UTF16; - else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32)) - return sd::DataType::UTF32; - else { - //shape::printShapeInfoLinear("Bad signed datatype (not)stored in shape", const_cast(shapeInfo)); + } + } else if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR)) + return sd::DataType::INT8; + else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT)) + return sd::DataType::INT16; + else if (hasPropertyBitSet(shapeInfo, ARRAY_INT)) + return sd::DataType::INT32; + else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) + return sd::DataType::INT64; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8)) + return sd::DataType::UTF8; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16)) + return sd::DataType::UTF16; + else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32)) + return sd::DataType::UTF32; + else { + // shape::printShapeInfoLinear("Bad signed datatype (not)stored in shape", + // const_cast(shapeInfo)); #ifndef __CUDA_ARCH__ - throw std::runtime_error("Bad datatype B"); + throw std::runtime_error("Bad datatype B"); #endif - } - } + } +} - FORCEINLINE _CUDA_HD SpaceType ArrayOptions::spaceType(const Nd4jLong *shapeInfo) { - return spaceType(const_cast(shapeInfo)); - } +FORCEINLINE _CUDA_HD SpaceType +ArrayOptions::spaceType(const Nd4jLong *shapeInfo) { + return spaceType(const_cast(shapeInfo)); +} - FORCEINLINE _CUDA_HD SpaceType ArrayOptions::spaceType(Nd4jLong *shapeInfo) { - if (hasPropertyBitSet(shapeInfo, ARRAY_QUANTIZED)) - return SpaceType::QUANTIZED; - if (hasPropertyBitSet(shapeInfo, ARRAY_COMPLEX)) - return SpaceType::COMPLEX; - else // by default we return continuous type here - return SpaceType::CONTINUOUS; - } +FORCEINLINE _CUDA_HD SpaceType ArrayOptions::spaceType(Nd4jLong *shapeInfo) { + if (hasPropertyBitSet(shapeInfo, ARRAY_QUANTIZED)) + return SpaceType::QUANTIZED; + if (hasPropertyBitSet(shapeInfo, ARRAY_COMPLEX)) + return SpaceType::COMPLEX; + else // by default we return continuous type here + return SpaceType::CONTINUOUS; +} - FORCEINLINE _CUDA_HD ArrayType ArrayOptions::arrayType(const Nd4jLong *shapeInfo) { - return arrayType(const_cast(shapeInfo)); - } +FORCEINLINE _CUDA_HD ArrayType +ArrayOptions::arrayType(const Nd4jLong *shapeInfo) { + return arrayType(const_cast(shapeInfo)); +} - FORCEINLINE _CUDA_HD ArrayType ArrayOptions::arrayType(Nd4jLong *shapeInfo) { - if (hasPropertyBitSet(shapeInfo, ARRAY_SPARSE)) - return ArrayType::SPARSE; - else if (hasPropertyBitSet(shapeInfo, ARRAY_COMPRESSED)) - return ArrayType::COMPRESSED; - else if (hasPropertyBitSet(shapeInfo, ARRAY_EMPTY)) - return ArrayType::EMPTY; - else if (hasPropertyBitSet(shapeInfo, ARRAY_RAGGED)) - return ArrayType::RAGGED; - else // by default we return DENSE type here - return ArrayType::DENSE; - } +FORCEINLINE _CUDA_HD ArrayType ArrayOptions::arrayType(Nd4jLong *shapeInfo) { + if (hasPropertyBitSet(shapeInfo, ARRAY_SPARSE)) + return ArrayType::SPARSE; + else if (hasPropertyBitSet(shapeInfo, ARRAY_COMPRESSED)) + return ArrayType::COMPRESSED; + else if (hasPropertyBitSet(shapeInfo, ARRAY_EMPTY)) + return ArrayType::EMPTY; + else if (hasPropertyBitSet(shapeInfo, ARRAY_RAGGED)) + return ArrayType::RAGGED; + else // by default we return DENSE type here + return ArrayType::DENSE; +} - FORCEINLINE _CUDA_HD bool ArrayOptions::togglePropertyBit(Nd4jLong *shapeInfo, int property) { - extra(shapeInfo) ^= property; +FORCEINLINE _CUDA_HD bool ArrayOptions::togglePropertyBit(Nd4jLong *shapeInfo, + int property) { + extra(shapeInfo) ^= property; - return hasPropertyBitSet(shapeInfo, property); - } + return hasPropertyBitSet(shapeInfo, property); +} - FORCEINLINE _CUDA_HD void ArrayOptions::setPropertyBit(Nd4jLong *shapeInfo, int property) { - extra(shapeInfo) |= property; - } +FORCEINLINE _CUDA_HD void ArrayOptions::setPropertyBit(Nd4jLong *shapeInfo, + int property) { + extra(shapeInfo) |= property; +} - FORCEINLINE _CUDA_HD void ArrayOptions::unsetPropertyBit(Nd4jLong *shapeInfo, int property) { - extra(shapeInfo) &= property; - } +FORCEINLINE _CUDA_HD void ArrayOptions::unsetPropertyBit(Nd4jLong *shapeInfo, + int property) { + extra(shapeInfo) &= property; +} - FORCEINLINE _CUDA_HD SparseType ArrayOptions::sparseType(const Nd4jLong *shapeInfo) { - return sparseType(const_cast(shapeInfo)); - } +FORCEINLINE _CUDA_HD SparseType +ArrayOptions::sparseType(const Nd4jLong *shapeInfo) { + return sparseType(const_cast(shapeInfo)); +} - FORCEINLINE _CUDA_HD SparseType ArrayOptions::sparseType(Nd4jLong *shapeInfo) { +FORCEINLINE _CUDA_HD SparseType ArrayOptions::sparseType(Nd4jLong *shapeInfo) { #ifndef __CUDA_ARCH__ - if (!isSparseArray(shapeInfo)) - throw std::runtime_error("Not a sparse array"); + if (!isSparseArray(shapeInfo)) throw std::runtime_error("Not a sparse array"); #endif - if (hasPropertyBitSet(shapeInfo, ARRAY_CSC)) - return SparseType::CSC; - else if (hasPropertyBitSet(shapeInfo, ARRAY_CSR)) - return SparseType::CSR; - else if (hasPropertyBitSet(shapeInfo, ARRAY_COO)) - return SparseType::COO; - else - return SparseType::LIL; - } + if (hasPropertyBitSet(shapeInfo, ARRAY_CSC)) + return SparseType::CSC; + else if (hasPropertyBitSet(shapeInfo, ARRAY_CSR)) + return SparseType::CSR; + else if (hasPropertyBitSet(shapeInfo, ARRAY_COO)) + return SparseType::COO; + else + return SparseType::LIL; +} - FORCEINLINE _CUDA_HD void ArrayOptions::setPropertyBits(Nd4jLong *shapeInfo, std::initializer_list properties) { - for (auto v: properties) { - if (!hasPropertyBitSet(shapeInfo, v)) - setPropertyBit(shapeInfo, v); - } - } +FORCEINLINE _CUDA_HD void ArrayOptions::setPropertyBits( + Nd4jLong *shapeInfo, std::initializer_list properties) { + for (auto v : properties) { + if (!hasPropertyBitSet(shapeInfo, v)) setPropertyBit(shapeInfo, v); + } +} - FORCEINLINE _CUDA_HD void ArrayOptions::resetDataType(Nd4jLong *shapeInfo) { - unsetPropertyBit(shapeInfo, ARRAY_BOOL); - unsetPropertyBit(shapeInfo, ARRAY_HALF); - unsetPropertyBit(shapeInfo, ARRAY_BHALF); - unsetPropertyBit(shapeInfo, ARRAY_FLOAT); - unsetPropertyBit(shapeInfo, ARRAY_DOUBLE); - unsetPropertyBit(shapeInfo, ARRAY_INT); - unsetPropertyBit(shapeInfo, ARRAY_LONG); - unsetPropertyBit(shapeInfo, ARRAY_CHAR); - unsetPropertyBit(shapeInfo, ARRAY_SHORT); - unsetPropertyBit(shapeInfo, ARRAY_UNSIGNED); - } +FORCEINLINE _CUDA_HD void ArrayOptions::resetDataType(Nd4jLong *shapeInfo) { + unsetPropertyBit(shapeInfo, ARRAY_BOOL); + unsetPropertyBit(shapeInfo, ARRAY_HALF); + unsetPropertyBit(shapeInfo, ARRAY_BHALF); + unsetPropertyBit(shapeInfo, ARRAY_FLOAT); + unsetPropertyBit(shapeInfo, ARRAY_DOUBLE); + unsetPropertyBit(shapeInfo, ARRAY_INT); + unsetPropertyBit(shapeInfo, ARRAY_LONG); + unsetPropertyBit(shapeInfo, ARRAY_CHAR); + unsetPropertyBit(shapeInfo, ARRAY_SHORT); + unsetPropertyBit(shapeInfo, ARRAY_UNSIGNED); +} - FORCEINLINE _CUDA_HD void ArrayOptions::setDataType(Nd4jLong *shapeInfo, const sd::DataType dataType) { - resetDataType(shapeInfo); - if (dataType == sd::DataType::UINT8 || - dataType == sd::DataType::UINT16 || - dataType == sd::DataType::UINT32 || - dataType == sd::DataType::UINT64) { - - setPropertyBit(shapeInfo, ARRAY_UNSIGNED); - } - - switch (dataType) { - case sd::DataType::BOOL: - setPropertyBit(shapeInfo, ARRAY_BOOL); - break; - case sd::DataType::HALF: - setPropertyBit(shapeInfo, ARRAY_HALF); - break; - case sd::DataType::BFLOAT16: - setPropertyBit(shapeInfo, ARRAY_BHALF); - break; - case sd::DataType::FLOAT32: - setPropertyBit(shapeInfo, ARRAY_FLOAT); - break; - case sd::DataType::DOUBLE: - setPropertyBit(shapeInfo, ARRAY_DOUBLE); - break; - case sd::DataType::INT8: - setPropertyBit(shapeInfo, ARRAY_CHAR); - break; - case sd::DataType::INT16: - setPropertyBit(shapeInfo, ARRAY_SHORT); - break; - case sd::DataType::INT32: - setPropertyBit(shapeInfo, ARRAY_INT); - break; - case sd::DataType::INT64: - setPropertyBit(shapeInfo, ARRAY_LONG); - break; - case sd::DataType::UINT8: - setPropertyBit(shapeInfo, ARRAY_CHAR); - break; - case sd::DataType::UINT16: - setPropertyBit(shapeInfo, ARRAY_SHORT); - break; - case sd::DataType::UINT32: - setPropertyBit(shapeInfo, ARRAY_INT); - break; - case sd::DataType::UINT64: - setPropertyBit(shapeInfo, ARRAY_LONG); - break; - case sd::DataType::UTF8: - setPropertyBit(shapeInfo, ARRAY_UTF8); - break; - case sd::DataType::UTF16: - setPropertyBit(shapeInfo, ARRAY_UTF16); - break; - case sd::DataType::UTF32: - setPropertyBit(shapeInfo, ARRAY_UTF32); - break; - default: +FORCEINLINE _CUDA_HD void ArrayOptions::setDataType( + Nd4jLong *shapeInfo, const sd::DataType dataType) { + resetDataType(shapeInfo); + if (dataType == sd::DataType::UINT8 || dataType == sd::DataType::UINT16 || + dataType == sd::DataType::UINT32 || dataType == sd::DataType::UINT64) { + setPropertyBit(shapeInfo, ARRAY_UNSIGNED); + } + + switch (dataType) { + case sd::DataType::BOOL: + setPropertyBit(shapeInfo, ARRAY_BOOL); + break; + case sd::DataType::HALF: + setPropertyBit(shapeInfo, ARRAY_HALF); + break; + case sd::DataType::BFLOAT16: + setPropertyBit(shapeInfo, ARRAY_BHALF); + break; + case sd::DataType::FLOAT32: + setPropertyBit(shapeInfo, ARRAY_FLOAT); + break; + case sd::DataType::DOUBLE: + setPropertyBit(shapeInfo, ARRAY_DOUBLE); + break; + case sd::DataType::INT8: + setPropertyBit(shapeInfo, ARRAY_CHAR); + break; + case sd::DataType::INT16: + setPropertyBit(shapeInfo, ARRAY_SHORT); + break; + case sd::DataType::INT32: + setPropertyBit(shapeInfo, ARRAY_INT); + break; + case sd::DataType::INT64: + setPropertyBit(shapeInfo, ARRAY_LONG); + break; + case sd::DataType::UINT8: + setPropertyBit(shapeInfo, ARRAY_CHAR); + break; + case sd::DataType::UINT16: + setPropertyBit(shapeInfo, ARRAY_SHORT); + break; + case sd::DataType::UINT32: + setPropertyBit(shapeInfo, ARRAY_INT); + break; + case sd::DataType::UINT64: + setPropertyBit(shapeInfo, ARRAY_LONG); + break; + case sd::DataType::UTF8: + setPropertyBit(shapeInfo, ARRAY_UTF8); + break; + case sd::DataType::UTF16: + setPropertyBit(shapeInfo, ARRAY_UTF16); + break; + case sd::DataType::UTF32: + setPropertyBit(shapeInfo, ARRAY_UTF32); + break; + default: #ifndef __CUDA_ARCH__ - throw std::runtime_error("Can't set unknown data type"); + throw std::runtime_error("Can't set unknown data type"); #else - printf("Can't set unknown data type"); + printf("Can't set unknown data type"); #endif - } - } + } +} //////////////////////////////////////////////////////////////////////////////// - FORCEINLINE _CUDA_HD void ArrayOptions::copyDataType(Nd4jLong* to, const Nd4jLong* from) { - setDataType(to, dataType(from)); - } +FORCEINLINE _CUDA_HD void ArrayOptions::copyDataType(Nd4jLong *to, + const Nd4jLong *from) { + setDataType(to, dataType(from)); } +} // namespace sd -#endif // ND4J_ARRAY_OPTIONS_H :) \ No newline at end of file +#endif // ND4J_ARRAY_OPTIONS_H :) \ No newline at end of file diff --git a/libnd4j/include/array/ArrayType.h b/libnd4j/include/array/ArrayType.h index 83e80bc0f068..c749c278f2ac 100644 --- a/libnd4j/include/array/ArrayType.h +++ b/libnd4j/include/array/ArrayType.h @@ -22,13 +22,13 @@ #define ND4J_ARRAY_TYPE_H namespace sd { - enum ArrayType { - DENSE = 1, - SPARSE = 2, - COMPRESSED = 3, - EMPTY = 4, - RAGGED = 5, - }; +enum ArrayType { + DENSE = 1, + SPARSE = 2, + COMPRESSED = 3, + EMPTY = 4, + RAGGED = 5, +}; } #endif \ No newline at end of file diff --git a/libnd4j/include/array/ByteOrder.h b/libnd4j/include/array/ByteOrder.h index 121be9e9dcb9..88c20b693ad4 100644 --- a/libnd4j/include/array/ByteOrder.h +++ b/libnd4j/include/array/ByteOrder.h @@ -22,10 +22,10 @@ #define LIBND4J_BYTEORDER_H namespace sd { - enum ByteOrder { - LE = 0, - BE = 1, - }; +enum ByteOrder { + LE = 0, + BE = 1, +}; } -#endif //LIBND4J_BYTEORDER_H +#endif // LIBND4J_BYTEORDER_H diff --git a/libnd4j/include/array/ByteOrderUtils.h b/libnd4j/include/array/ByteOrderUtils.h index b8d16d81119a..f34b9c1698b4 100644 --- a/libnd4j/include/array/ByteOrderUtils.h +++ b/libnd4j/include/array/ByteOrderUtils.h @@ -22,15 +22,15 @@ #define LIBND4J_BYTEORDERUTILS_H #include -#include "ByteOrder.h" #include -namespace sd { - class SD_EXPORT ByteOrderUtils { - public: - static ByteOrder fromFlatByteOrder(sd::graph::ByteOrder order); - }; -} +#include "ByteOrder.h" +namespace sd { +class SD_EXPORT ByteOrderUtils { + public: + static ByteOrder fromFlatByteOrder(sd::graph::ByteOrder order); +}; +} // namespace sd -#endif //LIBND4J_BYTEORDERUTILS_H +#endif // LIBND4J_BYTEORDERUTILS_H diff --git a/libnd4j/include/array/ConstantDataBuffer.h b/libnd4j/include/array/ConstantDataBuffer.h index 17e8a22c5de9..0f493f1d0142 100644 --- a/libnd4j/include/array/ConstantDataBuffer.h +++ b/libnd4j/include/array/ConstantDataBuffer.h @@ -23,37 +23,36 @@ #include #include - namespace sd { - class SD_EXPORT ConstantDataBuffer { - private: - Nd4jPointer _primaryBuffer = nullptr; - Nd4jPointer _specialBuffer = nullptr; - Nd4jLong _length = 0; - Nd4jLong _sizeOf = 0; - - public: - ConstantDataBuffer(Nd4jPointer primary, Nd4jPointer special, Nd4jLong numEelements, Nd4jLong sizeOf); - ConstantDataBuffer(const ConstantDataBuffer &other); - ConstantDataBuffer() = default; - ~ConstantDataBuffer() = default; - - Nd4jLong sizeOf() const; - Nd4jLong length() const; - - Nd4jPointer primary() const; - Nd4jPointer special() const; - - ConstantDataBuffer& operator=(const ConstantDataBuffer& other) = default; - ConstantDataBuffer& operator=(ConstantDataBuffer&& other) noexcept = default; - - - template - T* primaryAsT(); - - template - T* specialAsT(); - }; -} - -#endif //SD_CONSTANTDATABUFFER_H +class SD_EXPORT ConstantDataBuffer { + private: + Nd4jPointer _primaryBuffer = nullptr; + Nd4jPointer _specialBuffer = nullptr; + Nd4jLong _length = 0; + Nd4jLong _sizeOf = 0; + + public: + ConstantDataBuffer(Nd4jPointer primary, Nd4jPointer special, + Nd4jLong numEelements, Nd4jLong sizeOf); + ConstantDataBuffer(const ConstantDataBuffer& other); + ConstantDataBuffer() = default; + ~ConstantDataBuffer() = default; + + Nd4jLong sizeOf() const; + Nd4jLong length() const; + + Nd4jPointer primary() const; + Nd4jPointer special() const; + + ConstantDataBuffer& operator=(const ConstantDataBuffer& other) = default; + ConstantDataBuffer& operator=(ConstantDataBuffer&& other) noexcept = default; + + template + T* primaryAsT(); + + template + T* specialAsT(); +}; +} // namespace sd + +#endif // SD_CONSTANTDATABUFFER_H diff --git a/libnd4j/include/array/ConstantDescriptor.h b/libnd4j/include/array/ConstantDescriptor.h index 8c3f5f5ea21d..a377a001fbbd 100644 --- a/libnd4j/include/array/ConstantDescriptor.h +++ b/libnd4j/include/array/ConstantDescriptor.h @@ -21,55 +21,56 @@ #ifndef SD_CONSTANTDESCRIPTOR_H #define SD_CONSTANTDESCRIPTOR_H +#include #include +#include +#include + #include #include -#include -#include -#include namespace sd { - class SD_EXPORT ConstantDescriptor { - private: - std::vector _integerValues; - std::vector _floatValues; - public: - ConstantDescriptor(double* values, int length); - ConstantDescriptor(Nd4jLong const* values, int length); - ConstantDescriptor(std::initializer_list values); +class SD_EXPORT ConstantDescriptor { + private: + std::vector _integerValues; + std::vector _floatValues; - explicit ConstantDescriptor(std::vector &values); - explicit ConstantDescriptor(std::vector &values); + public: + ConstantDescriptor(double *values, int length); + ConstantDescriptor(Nd4jLong const *values, int length); + ConstantDescriptor(std::initializer_list values); - ~ConstantDescriptor() = default; + explicit ConstantDescriptor(std::vector &values); + explicit ConstantDescriptor(std::vector &values); - // equal to operator - bool operator==(const ConstantDescriptor &other) const; + ~ConstantDescriptor() = default; - // less than operator - bool operator<(const ConstantDescriptor &other) const; + // equal to operator + bool operator==(const ConstantDescriptor &other) const; - bool isInteger() const; - bool isFloat() const; + // less than operator + bool operator<(const ConstantDescriptor &other) const; - Nd4jLong length() const; + bool isInteger() const; + bool isFloat() const; - const std::vector& integerValues() const; - const std::vector& floatValues() const; - }; -} + Nd4jLong length() const; + + const std::vector &integerValues() const; + const std::vector &floatValues() const; +}; +} // namespace sd #ifndef __JAVACPP_HACK__ namespace std { - template<> - class SD_EXPORT hash { - public: - size_t operator()(const sd::ConstantDescriptor &k) const; - }; -} +template <> +class SD_EXPORT hash { + public: + size_t operator()(const sd::ConstantDescriptor &k) const; +}; +} // namespace std #endif - -#endif //SD_CONSTANTDESCRIPTOR_H +#endif // SD_CONSTANTDESCRIPTOR_H diff --git a/libnd4j/include/array/ConstantHolder.h b/libnd4j/include/array/ConstantHolder.h index 833549a742c1..0e606b7a421b 100644 --- a/libnd4j/include/array/ConstantHolder.h +++ b/libnd4j/include/array/ConstantHolder.h @@ -22,44 +22,45 @@ #ifndef LIBND4J_CONSTANTHOLDER_H #define LIBND4J_CONSTANTHOLDER_H -#include -#include #include +#include + +#include #include namespace sd { - class ConstantHolder { - private: - int _deviceId = 0; - std::mutex _mutex; +class ConstantHolder { + private: + int _deviceId = 0; + std::mutex _mutex; - std::map _buffers; - public: - ConstantHolder(const ConstantHolder& other); - ConstantHolder() = default; - ~ConstantHolder() = default; + std::map _buffers; - ConstantHolder& operator=(const ConstantHolder& other) = default; - ConstantHolder& operator=(ConstantHolder&& other) = default; + public: + ConstantHolder(const ConstantHolder& other); + ConstantHolder() = default; + ~ConstantHolder() = default; - bool hasBuffer(sd::DataType dataType); + ConstantHolder& operator=(const ConstantHolder& other) = default; + ConstantHolder& operator=(ConstantHolder&& other) = default; - template - bool hasBuffer(); + bool hasBuffer(sd::DataType dataType); - void addBuffer(ConstantDataBuffer &pointer, sd::DataType dataType); + template + bool hasBuffer(); - template - void addBuffer(ConstantDataBuffer &pointer); + void addBuffer(ConstantDataBuffer& pointer, sd::DataType dataType); - ConstantDataBuffer* getConstantDataBuffer(sd::DataType dataType); + template + void addBuffer(ConstantDataBuffer& pointer); - template - ConstantDataBuffer* getConstantDataBuffer(); + ConstantDataBuffer* getConstantDataBuffer(sd::DataType dataType); - std::mutex* mutex(); - }; -} + template + ConstantDataBuffer* getConstantDataBuffer(); + std::mutex* mutex(); +}; +} // namespace sd -#endif //SD_CONSTANTHOLDER_H +#endif // SD_CONSTANTHOLDER_H diff --git a/libnd4j/include/array/DataBuffer.h b/libnd4j/include/array/DataBuffer.h index 1d0f93400915..312cf9ccb60a 100644 --- a/libnd4j/include/array/DataBuffer.h +++ b/libnd4j/include/array/DataBuffer.h @@ -22,140 +22,147 @@ #ifndef SD_DATABUFFER_H #define SD_DATABUFFER_H -#include -#include -#include -#include #include -#include #include +#include +#include +#include +#include + +#include namespace sd { class SD_EXPORT DataBuffer { - - protected: - - void* _primaryBuffer = nullptr; - void* _specialBuffer = nullptr; - size_t _lenInBytes = 0; - DataType _dataType; - memory::Workspace* _workspace = nullptr; - bool _isOwnerPrimary; - bool _isOwnerSpecial; - std::atomic _deviceId; - - #ifdef __CUDABLAS__ - mutable std::atomic _counter; - mutable std::atomic _writePrimary; - mutable std::atomic _writeSpecial; - mutable std::atomic _readPrimary; - mutable std::atomic _readSpecial; - #endif - - void setCountersToZero(); - void copyCounters(const DataBuffer& other); - virtual void deleteSpecial(); - virtual void deletePrimary(); - virtual void deleteBuffers(); - void setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial = false); - void allocateBuffers(const bool allocBoth = false); - void setSpecial(void* special, const bool isOwnerSpecial); - void copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetHostBuffer = 0); - - - public: - - DataBuffer(void* primary, void* special, - const size_t lenInBytes, const DataType dataType, - const bool isOwnerPrimary = false, const bool isOwnerSpecial = false, - memory::Workspace* workspace = nullptr); - - DataBuffer(void* primary, - const size_t lenInBytes, const DataType dataType, - const bool isOwnerPrimary = false, - memory::Workspace* workspace = nullptr); - - DataBuffer(const void* hostBuffer, // copies data from hostBuffer to own memory buffer - const DataType dataType, const size_t lenInBytes, - memory::Workspace* workspace = nullptr); - - DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace = nullptr, const bool allocBoth = false); - - DataBuffer(const DataBuffer& other); - DataBuffer(DataBuffer&& other); - explicit DataBuffer(); - ~DataBuffer(); - - virtual DataBuffer& operator=(const DataBuffer& other); - virtual DataBuffer& operator=(DataBuffer&& other) noexcept; - - DataType getDataType(); - void setDataType(DataType dataType); - size_t getLenInBytes() const; - - virtual void* primary(); - virtual void* special(); - virtual void* platform(); - - virtual void allocatePrimary(); - virtual void allocateSpecial(); - - void writePrimary() const; - void writeSpecial() const; - void readPrimary() const; - void readSpecial() const; - bool isPrimaryActual() const; - bool isSpecialActual() const; - - virtual void expand(const uint64_t size); - - int deviceId() const; - void setDeviceId(int deviceId); - - virtual void migrate(); - - template FORCEINLINE T* primaryAsT(); - template FORCEINLINE T* specialAsT(); - template FORCEINLINE T* platformAsT(); - - void syncToPrimary(const LaunchContext* context, const bool forceSync = false); - void syncToSpecial(const bool forceSync = false); - - virtual void setToZeroBuffers(const bool both = false); - - void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0); - - static void memcpy(const DataBuffer &dst, const DataBuffer &src); - - virtual void setPrimaryBuffer(void *buffer, size_t length); - virtual void setSpecialBuffer(void *buffer, size_t length); - - /** - * This method deletes buffers, if we're owners - */ - virtual void close(); + protected: + void* _primaryBuffer = nullptr; + void* _specialBuffer = nullptr; + size_t _lenInBytes = 0; + DataType _dataType; + memory::Workspace* _workspace = nullptr; + bool _isOwnerPrimary; + bool _isOwnerSpecial; + std::atomic _deviceId; + +#ifdef __CUDABLAS__ + mutable std::atomic _counter; + mutable std::atomic _writePrimary; + mutable std::atomic _writeSpecial; + mutable std::atomic _readPrimary; + mutable std::atomic _readSpecial; +#endif + + void setCountersToZero(); + void copyCounters(const DataBuffer& other); + virtual void deleteSpecial(); + virtual void deletePrimary(); + virtual void deleteBuffers(); + void setAllocFlags(const bool isOwnerPrimary, + const bool isOwnerSpecial = false); + void allocateBuffers(const bool allocBoth = false); + void setSpecial(void* special, const bool isOwnerSpecial); + void copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes = 0, + const Nd4jLong offsetThis = 0, + const Nd4jLong offsetHostBuffer = 0); + + public: + DataBuffer(void* primary, void* special, const size_t lenInBytes, + const DataType dataType, const bool isOwnerPrimary = false, + const bool isOwnerSpecial = false, + memory::Workspace* workspace = nullptr); + + DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, + const bool isOwnerPrimary = false, + memory::Workspace* workspace = nullptr); + + DataBuffer(const void* hostBuffer, // copies data from hostBuffer to own + // memory buffer + const DataType dataType, const size_t lenInBytes, + memory::Workspace* workspace = nullptr); + + DataBuffer(const size_t lenInBytes, const DataType dataType, + memory::Workspace* workspace = nullptr, + const bool allocBoth = false); + + DataBuffer(const DataBuffer& other); + DataBuffer(DataBuffer&& other); + explicit DataBuffer(); + ~DataBuffer(); + + virtual DataBuffer& operator=(const DataBuffer& other); + virtual DataBuffer& operator=(DataBuffer&& other) noexcept; + + DataType getDataType(); + void setDataType(DataType dataType); + size_t getLenInBytes() const; + + virtual void* primary(); + virtual void* special(); + virtual void* platform(); + + virtual void allocatePrimary(); + virtual void allocateSpecial(); + + void writePrimary() const; + void writeSpecial() const; + void readPrimary() const; + void readSpecial() const; + bool isPrimaryActual() const; + bool isSpecialActual() const; + + virtual void expand(const uint64_t size); + + int deviceId() const; + void setDeviceId(int deviceId); + + virtual void migrate(); + + template + FORCEINLINE T* primaryAsT(); + template + FORCEINLINE T* specialAsT(); + template + FORCEINLINE T* platformAsT(); + + void syncToPrimary(const LaunchContext* context, + const bool forceSync = false); + void syncToSpecial(const bool forceSync = false); + + virtual void setToZeroBuffers(const bool both = false); + + void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, + const Nd4jLong offsetThis = 0, + const Nd4jLong offsetOther = 0); + + static void memcpy(const DataBuffer& dst, const DataBuffer& src); + + virtual void setPrimaryBuffer(void* buffer, size_t length); + virtual void setSpecialBuffer(void* buffer, size_t length); + + /** + * This method deletes buffers, if we're owners + */ + virtual void close(); }; ///// IMLEMENTATION OF INLINE METHODS ///// //////////////////////////////////////////////////////////////////////// - template - T* DataBuffer::primaryAsT() { - return reinterpret_cast(primary()); - } +template +T* DataBuffer::primaryAsT() { + return reinterpret_cast(primary()); +} //////////////////////////////////////////////////////////////////////// - template - T* DataBuffer::specialAsT() { - return reinterpret_cast(special()); - } +template +T* DataBuffer::specialAsT() { + return reinterpret_cast(special()); +} //////////////////////////////////////////////////////////////////////// - template - T* DataBuffer::platformAsT() { - return reinterpret_cast(platform()); - } +template +T* DataBuffer::platformAsT() { + return reinterpret_cast(platform()); } +} // namespace sd - -#endif //SD_DATABUFFER_H +#endif // SD_DATABUFFER_H diff --git a/libnd4j/include/array/DataType.h b/libnd4j/include/array/DataType.h index cf8baf7d0324..8aa16ffff4ac 100644 --- a/libnd4j/include/array/DataType.h +++ b/libnd4j/include/array/DataType.h @@ -22,31 +22,31 @@ #define ND4J_DATATYPE_H namespace sd { - enum DataType { - INHERIT = 0, - BOOL = 1, - FLOAT8 = 2, - HALF = 3, - HALF2 = 4, - FLOAT32 = 5, - DOUBLE = 6, - INT8 = 7, - INT16 = 8, - INT32 = 9, - INT64 = 10, - UINT8 = 11, - UINT16 = 12, - UINT32 = 13, - UINT64 = 14, - QINT8 = 15, - QINT16 = 16, - BFLOAT16 = 17, - UTF8 = 50, - UTF16 = 51, - UTF32 = 52, - ANY = 100, - AUTO = 200, - }; +enum DataType { + INHERIT = 0, + BOOL = 1, + FLOAT8 = 2, + HALF = 3, + HALF2 = 4, + FLOAT32 = 5, + DOUBLE = 6, + INT8 = 7, + INT16 = 8, + INT32 = 9, + INT64 = 10, + UINT8 = 11, + UINT16 = 12, + UINT32 = 13, + UINT64 = 14, + QINT8 = 15, + QINT16 = 16, + BFLOAT16 = 17, + UTF8 = 50, + UTF16 = 51, + UTF32 = 52, + ANY = 100, + AUTO = 200, +}; } #endif \ No newline at end of file diff --git a/libnd4j/include/array/DataTypeConversions.h b/libnd4j/include/array/DataTypeConversions.h index a3190728749f..92043ae15abc 100644 --- a/libnd4j/include/array/DataTypeConversions.h +++ b/libnd4j/include/array/DataTypeConversions.h @@ -21,168 +21,172 @@ #ifndef LIBND4J_DATATYPECONVERSIONS_H #define LIBND4J_DATATYPECONVERSIONS_H -#include -#include -#include #include -#include +#include #include +#include #include #include -#include +#include +#include +#include namespace sd { - template - class SD_EXPORT DataTypeConversions { - private: - template - static FORCEINLINE void rconv(bool isBe, bool canKeep, T *buffer, Nd4jLong length, void *src) { - if (std::is_same::value && canKeep) { - memcpy(buffer, src, length * sizeof(T)); - } else { - auto tmp = new T2[length]; - memcpy(tmp, src, length * sizeof(T2)); - +template +class SD_EXPORT DataTypeConversions { + private: + template + static FORCEINLINE void rconv(bool isBe, bool canKeep, T *buffer, + Nd4jLong length, void *src) { + if (std::is_same::value && canKeep) { + memcpy(buffer, src, length * sizeof(T)); + } else { + auto tmp = new T2[length]; + memcpy(tmp, src, length * sizeof(T2)); #if __GNUC__ <= 4 - if (!canKeep) - for (Nd4jLong e = 0; e < length; e++) - buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); - else - TypeCast::convertGeneric(nullptr, tmp, length, buffer); + if (!canKeep) + for (Nd4jLong e = 0; e < length; e++) + buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); + else + TypeCast::convertGeneric(nullptr, tmp, length, buffer); #else - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - buffer[e] = canKeep ? static_cast(tmp[e]) : BitwiseUtils::swap_bytes(static_cast(tmp[e])); - }; - - samediff::Threads::parallel_for(func, 0, length); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) + buffer[e] = canKeep + ? static_cast(tmp[e]) + : BitwiseUtils::swap_bytes(static_cast(tmp[e])); + }; + + samediff::Threads::parallel_for(func, 0, length); #endif - delete[] tmp; - } - } - - public: - static FORCEINLINE void convertType(void* vbuffer, void* src, DataType dataType, ByteOrder order, Nd4jLong length) { - auto buffer = reinterpret_cast(vbuffer); - bool isBe = BitwiseUtils::isBE(); - bool canKeep = (isBe && order == ByteOrder::BE) || (!isBe && order == ByteOrder::LE); - - switch (dataType) { - case BOOL: { - DataTypeConversions::template rconv(isBe, canKeep, buffer, length, src); - } - break; - case UINT8: { - DataTypeConversions::template rconv(isBe, canKeep, buffer, length, src); - } - break; - case INT8: { - DataTypeConversions::template rconv(isBe, canKeep, buffer, length, src); - } - break; - case INT16: { - DataTypeConversions::template rconv(isBe, canKeep, buffer, length, src); - } - break; - case INT32: { - DataTypeConversions::template rconv(isBe, canKeep, buffer, length, src); - } - break; - case INT64: { - DataTypeConversions::template rconv(isBe, canKeep, buffer, length, src); - } - break; - case FLOAT32: { - if (std::is_same::value && canKeep) { - memcpy(buffer, src, length * sizeof(T)); - } else { - auto tmp = new float[length]; - memcpy(tmp, src, length * sizeof(float)); - + delete[] tmp; + } + } + + public: + static FORCEINLINE void convertType(void *vbuffer, void *src, + DataType dataType, ByteOrder order, + Nd4jLong length) { + auto buffer = reinterpret_cast(vbuffer); + bool isBe = BitwiseUtils::isBE(); + bool canKeep = + (isBe && order == ByteOrder::BE) || (!isBe && order == ByteOrder::LE); + + switch (dataType) { + case BOOL: { + DataTypeConversions::template rconv(isBe, canKeep, buffer, + length, src); + } break; + case UINT8: { + DataTypeConversions::template rconv(isBe, canKeep, buffer, + length, src); + } break; + case INT8: { + DataTypeConversions::template rconv(isBe, canKeep, buffer, + length, src); + } break; + case INT16: { + DataTypeConversions::template rconv(isBe, canKeep, buffer, + length, src); + } break; + case INT32: { + DataTypeConversions::template rconv(isBe, canKeep, buffer, + length, src); + } break; + case INT64: { + DataTypeConversions::template rconv(isBe, canKeep, buffer, + length, src); + } break; + case FLOAT32: { + if (std::is_same::value && canKeep) { + memcpy(buffer, src, length * sizeof(T)); + } else { + auto tmp = new float[length]; + memcpy(tmp, src, length * sizeof(float)); #if __GNUC__ <= 4 - if (!canKeep) - for (Nd4jLong e = 0; e < length; e++) - buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); - else - TypeCast::convertGeneric(nullptr, tmp, length, buffer); + if (!canKeep) + for (Nd4jLong e = 0; e < length; e++) + buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); + else + TypeCast::convertGeneric(nullptr, tmp, length, buffer); #else - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - buffer[e] = canKeep ? static_cast(tmp[e]) : BitwiseUtils::swap_bytes(static_cast(tmp[e])); - }; - - samediff::Threads::parallel_for(func, 0, length); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) + buffer[e] = + canKeep ? static_cast(tmp[e]) + : BitwiseUtils::swap_bytes(static_cast(tmp[e])); + }; + + samediff::Threads::parallel_for(func, 0, length); #endif - delete[] tmp; - } - } - break; - case DOUBLE: { - if (std::is_same::value && canKeep) { - memcpy(buffer, src, length * sizeof(T)); - } else { - auto tmp = new double[length]; - memcpy(tmp, src, length * sizeof(double)); + delete[] tmp; + } + } break; + case DOUBLE: { + if (std::is_same::value && canKeep) { + memcpy(buffer, src, length * sizeof(T)); + } else { + auto tmp = new double[length]; + memcpy(tmp, src, length * sizeof(double)); #if __GNUC__ <= 4 - if (!canKeep) - for (Nd4jLong e = 0; e < length; e++) - buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); - else - TypeCast::convertGeneric(nullptr, tmp, length, buffer); - + if (!canKeep) + for (Nd4jLong e = 0; e < length; e++) + buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); + else + TypeCast::convertGeneric(nullptr, tmp, length, buffer); #else - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - buffer[e] = canKeep ? static_cast(tmp[e]) : BitwiseUtils::swap_bytes(static_cast(tmp[e])); - }; - - samediff::Threads::parallel_for(func, 0, length); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) + buffer[e] = + canKeep ? static_cast(tmp[e]) + : BitwiseUtils::swap_bytes(static_cast(tmp[e])); + }; + + samediff::Threads::parallel_for(func, 0, length); #endif - delete[] tmp; - } - } - break; - case HALF: { - - if (std::is_same::value && canKeep) { - memcpy(buffer, src, length * sizeof(T)); - } else { - auto tmp = new float16[length]; - memcpy(tmp, src, length * sizeof(float16)); + delete[] tmp; + } + } break; + case HALF: { + if (std::is_same::value && canKeep) { + memcpy(buffer, src, length * sizeof(T)); + } else { + auto tmp = new float16[length]; + memcpy(tmp, src, length * sizeof(float16)); #if __GNUC__ <= 4 - if (!canKeep) - for (Nd4jLong e = 0; e < length; e++) - buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); - else - TypeCast::convertGeneric(nullptr, tmp, length, buffer); + if (!canKeep) + for (Nd4jLong e = 0; e < length; e++) + buffer[e] = BitwiseUtils::swap_bytes(static_cast(tmp[e])); + else + TypeCast::convertGeneric(nullptr, tmp, length, buffer); #else - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - buffer[e] = canKeep ? static_cast(tmp[e]) : BitwiseUtils::swap_bytes(static_cast(tmp[e])); - }; - - samediff::Threads::parallel_for(func, 0, length); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) + buffer[e] = + canKeep ? static_cast(tmp[e]) + : BitwiseUtils::swap_bytes(static_cast(tmp[e])); + }; + + samediff::Threads::parallel_for(func, 0, length); #endif - delete[] tmp; - } - } - break; - default: { - nd4j_printf("Unsupported DataType requested: [%i]\n", static_cast(dataType)); - throw std::runtime_error("Unsupported DataType"); - } - } + delete[] tmp; } - }; -} - - - -#endif //LIBND4J_DATATYPECONVERSIONS_H + } break; + default: { + nd4j_printf("Unsupported DataType requested: [%i]\n", + static_cast(dataType)); + throw std::runtime_error("Unsupported DataType"); + } + } + } +}; +} // namespace sd + +#endif // LIBND4J_DATATYPECONVERSIONS_H diff --git a/libnd4j/include/array/DataTypeUtils.h b/libnd4j/include/array/DataTypeUtils.h index a0c40a89f312..c22af0f1341a 100644 --- a/libnd4j/include/array/DataTypeUtils.h +++ b/libnd4j/include/array/DataTypeUtils.h @@ -22,477 +22,505 @@ #ifndef DATATYPEUTILS_H #define DATATYPEUTILS_H -#include -#include +#include #include #include -#include -#include #include -#include +#include +#include +#include +#include //#include //#include #include namespace sd { - class SD_EXPORT DataTypeUtils { - public: - static int asInt(DataType type); - static DataType fromInt(int dtype); - static DataType fromFlatDataType(sd::graph::DType dtype); - FORCEINLINE static std::string asString(DataType dataType); - - template - static FORCEINLINE _CUDA_HD DataType fromT(); - static FORCEINLINE _CUDA_HD size_t sizeOfElement(DataType type); - - // returns the smallest finite value of the given type - template - FORCEINLINE static _CUDA_HD T min(); - - // returns the largest finite value of the given type - template - FORCEINLINE static _CUDA_HD T max(); - - /** - * returns inf for float/double and max for everything else - */ - template - FORCEINLINE static _CUDA_HD T infOrMax(); - - template - FORCEINLINE static _CUDA_HD T nanOrZero(); - - // returns the difference between 1.0 and the next representable value of the given floating-point type - template - FORCEINLINE static T eps(); +class SD_EXPORT DataTypeUtils { + public: + static int asInt(DataType type); + static DataType fromInt(int dtype); + static DataType fromFlatDataType(sd::graph::DType dtype); + FORCEINLINE static std::string asString(DataType dataType); + + template + static FORCEINLINE _CUDA_HD DataType fromT(); + static FORCEINLINE _CUDA_HD size_t sizeOfElement(DataType type); + + // returns the smallest finite value of the given type + template + FORCEINLINE static _CUDA_HD T min(); + + // returns the largest finite value of the given type + template + FORCEINLINE static _CUDA_HD T max(); + + /** + * returns inf for float/double and max for everything else + */ + template + FORCEINLINE static _CUDA_HD T infOrMax(); + + template + FORCEINLINE static _CUDA_HD T nanOrZero(); + + // returns the difference between 1.0 and the next representable value of the + // given floating-point type + template + FORCEINLINE static T eps(); + + FORCEINLINE static _CUDA_HD size_t sizeOf(DataType type); + FORCEINLINE static _CUDA_HD size_t sizeOf(const Nd4jLong* shapeInfo); + + FORCEINLINE static _CUDA_HD bool isR(sd::DataType dataType); + + FORCEINLINE static _CUDA_HD bool isZ(sd::DataType dataType); + + FORCEINLINE static _CUDA_HD bool isB(sd::DataType dataType); + + FORCEINLINE static _CUDA_HD bool isU(sd::DataType dataType); + + FORCEINLINE static _CUDA_HD bool isS(sd::DataType dataType); + + FORCEINLINE static sd::DataType pickPairwiseResultType(sd::DataType typeX, + sd::DataType typeY); + + FORCEINLINE static sd::DataType pickPairwiseResultType( + const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2); + + FORCEINLINE static sd::DataType pickFloatingType(sd::DataType typeX); + + template + FORCEINLINE static std::vector convertVector( + const std::vector& vector); + + template + FORCEINLINE static bool castShapeInfo(const Nd4jLong* originalShapeInfo, + T* newShapeInfo); + + template + // struct scalarTypesForNDarray { static bool const value = + // std::is_same::value || std::is_same::value || + // std::is_same::value || std::is_same::value || + // std::is_same::value || std::is_same::value; }; + struct scalarTypesForNDarray { + static bool const value = + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value; + }; + + template + struct scalarTypesForExecution { + static bool const value = + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value; + }; +}; - FORCEINLINE static _CUDA_HD size_t sizeOf(DataType type); - FORCEINLINE static _CUDA_HD size_t sizeOf(const Nd4jLong* shapeInfo); - - FORCEINLINE static _CUDA_HD bool isR(sd::DataType dataType); - - FORCEINLINE static _CUDA_HD bool isZ(sd::DataType dataType); - - FORCEINLINE static _CUDA_HD bool isB(sd::DataType dataType); - - FORCEINLINE static _CUDA_HD bool isU(sd::DataType dataType); - - FORCEINLINE static _CUDA_HD bool isS(sd::DataType dataType); - - FORCEINLINE static sd::DataType pickPairwiseResultType(sd::DataType typeX, sd::DataType typeY); - - FORCEINLINE static sd::DataType pickPairwiseResultType(const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2); - - FORCEINLINE static sd::DataType pickFloatingType(sd::DataType typeX); - - template - FORCEINLINE static std::vector convertVector(const std::vector &vector); - - template - FORCEINLINE static bool castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo); +////////////////////////////////////////////////////////////////////////// +///// IMLEMENTATION OF INLINE METHODS ///// +////////////////////////////////////////////////////////////////////////// - template - // struct scalarTypesForNDarray { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; - struct scalarTypesForNDarray { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; +FORCEINLINE sd::DataType DataTypeUtils::pickFloatingType(sd::DataType typeX) { + // if proposed dataType is already floating point - return it + if (isR(typeX)) return typeX; + return Environment::getInstance()->defaultFloatDataType(); +} - template - struct scalarTypesForExecution { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; +FORCEINLINE bool DataTypeUtils::isR(sd::DataType dataType) { + return dataType == sd::DataType::FLOAT32 || + dataType == sd::DataType::BFLOAT16 || dataType == sd::DataType::HALF || + dataType == sd::DataType::DOUBLE; +} - }; +FORCEINLINE bool DataTypeUtils::isB(sd::DataType dataType) { + return dataType == sd::DataType::BOOL; +} +FORCEINLINE bool DataTypeUtils::isS(sd::DataType dataType) { + return dataType == sd::DataType::UTF8 || dataType == sd::DataType::UTF16 || + dataType == sd::DataType::UTF32; +} -////////////////////////////////////////////////////////////////////////// -///// IMLEMENTATION OF INLINE METHODS ///// -////////////////////////////////////////////////////////////////////////// +FORCEINLINE bool DataTypeUtils::isZ(sd::DataType dataType) { + return !isR(dataType) && !isB(dataType) && !isS(dataType); +} - FORCEINLINE sd::DataType DataTypeUtils::pickFloatingType(sd::DataType typeX) { - // if proposed dataType is already floating point - return it - if (isR(typeX)) - return typeX; - return Environment::getInstance()->defaultFloatDataType(); - } +FORCEINLINE bool DataTypeUtils::isU(sd::DataType dataType) { + return dataType == sd::DataType::UINT8 || dataType == sd::DataType::UINT16 || + dataType == sd::DataType::UINT32 || dataType == sd::DataType::UINT64; +} - FORCEINLINE bool DataTypeUtils::isR(sd::DataType dataType) { - return dataType == sd::DataType::FLOAT32 || dataType == sd::DataType::BFLOAT16 || dataType == sd::DataType::HALF || dataType == sd::DataType::DOUBLE; - } +FORCEINLINE sd::DataType DataTypeUtils::pickPairwiseResultType( + sd::DataType typeX, sd::DataType typeY) { + // if both dtypes are the same - just return it + if (typeX == typeY) return typeX; + auto nd4j_max = [](sd::DataType typeX, sd::DataType typeY) { + return typeX > typeY ? typeX : typeY; + }; + auto rX = isR(typeX); + auto rY = isR(typeY); - FORCEINLINE bool DataTypeUtils::isB(sd::DataType dataType) { - return dataType == sd::DataType::BOOL; - } + // if X is float - use it + if (rX && !rY) return typeX; - FORCEINLINE bool DataTypeUtils::isS(sd::DataType dataType) { - return dataType == sd::DataType::UTF8 || dataType == sd::DataType::UTF16 || dataType == sd::DataType::UTF32; - } + // if Y is float - use it + if (!rX && rY) return typeY; - FORCEINLINE bool DataTypeUtils::isZ(sd::DataType dataType) { - return !isR(dataType) && !isB(dataType) && !isS(dataType); + // if both data types are float - return biggest one + if (rX && rY) { + // if we allow precision boost, then we pick bigger data type + if (sd::Environment::getInstance()->precisionBoostAllowed()) { + return nd4j_max(typeX, typeY); + } else { + // and we return first operand otherwise + return typeX; } - - FORCEINLINE bool DataTypeUtils::isU(sd::DataType dataType) { - return dataType == sd::DataType::UINT8 || dataType == sd::DataType::UINT16 || dataType == sd::DataType::UINT32 || dataType == sd::DataType::UINT64; + } + + // if that's not real type, we apply same rules + if (!rX && !rY) { + if (sd::Environment::getInstance()->precisionBoostAllowed()) { + return nd4j_max(typeX, typeY); + } else { + // and we return first operand otherwise + return typeX; } + } - FORCEINLINE sd::DataType DataTypeUtils::pickPairwiseResultType(sd::DataType typeX, sd::DataType typeY) { - // if both dtypes are the same - just return it - if (typeX == typeY) - return typeX; - auto nd4j_max = [](sd::DataType typeX, sd::DataType typeY) { - return typeX > typeY?typeX:typeY; - }; - auto rX = isR(typeX); - auto rY = isR(typeY); - - // if X is float - use it - if (rX && !rY) - return typeX; - - // if Y is float - use it - if (!rX && rY) - return typeY; - - // if both data types are float - return biggest one - if (rX && rY) { - // if we allow precision boost, then we pick bigger data type - if (sd::Environment::getInstance()->precisionBoostAllowed()) { - return nd4j_max(typeX, typeY); - } else { - // and we return first operand otherwise - return typeX; - } - - } - - // if that's not real type, we apply same rules - if (!rX && !rY) { - if (sd::Environment::getInstance()->precisionBoostAllowed()) { - return nd4j_max(typeX, typeY); - } else { - // and we return first operand otherwise - return typeX; - } - } - - return typeX; - } + return typeX; +} /////////////////////////////////////////////////////////////////// -FORCEINLINE sd::DataType DataTypeUtils::pickPairwiseResultType(const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2) { - - return pickPairwiseResultType(ArrayOptions::dataType(shapeInfo1), ArrayOptions::dataType(shapeInfo2)); +FORCEINLINE sd::DataType DataTypeUtils::pickPairwiseResultType( + const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2) { + return pickPairwiseResultType(ArrayOptions::dataType(shapeInfo1), + ArrayOptions::dataType(shapeInfo2)); } /////////////////////////////////////////////////////////////////// FORCEINLINE size_t DataTypeUtils::sizeOf(DataType type) { - return sizeOfElement(type); + return sizeOfElement(type); } /////////////////////////////////////////////////////////////////// FORCEINLINE size_t DataTypeUtils::sizeOf(const Nd4jLong* shapeInfo) { - return sizeOfElement(ArrayOptions::dataType(shapeInfo)); + return sizeOfElement(ArrayOptions::dataType(shapeInfo)); } // returns the smallest finite value of the given type -template<> +template <> FORCEINLINE _CUDA_HD int DataTypeUtils::min() { - return 1; + return 1; } -template<> +template <> FORCEINLINE _CUDA_HD char DataTypeUtils::min() { - return 1; + return 1; } template <> FORCEINLINE _CUDA_HD bool DataTypeUtils::min() { - return false; + return false; } -template<> +template <> FORCEINLINE _CUDA_HD Nd4jLong DataTypeUtils::min() { - return 1L; + return 1L; } -template<> +template <> FORCEINLINE _CUDA_HD uint64_t DataTypeUtils::min() { - return 1L; + return 1L; } -template<> +template <> FORCEINLINE _CUDA_HD uint32_t DataTypeUtils::min() { - return 1; + return 1; } -template<> +template <> FORCEINLINE _CUDA_HD float DataTypeUtils::min() { - return 1.175494e-38; + return 1.175494e-38; } -template<> +template <> FORCEINLINE _CUDA_HD float16 DataTypeUtils::min() { - return (float16) 6.1035e-05; + return (float16)6.1035e-05; } -template<> +template <> FORCEINLINE _CUDA_HD bfloat16 DataTypeUtils::min() { - return bfloat16::min(); + return bfloat16::min(); } -template<> +template <> FORCEINLINE _CUDA_HD double DataTypeUtils::min() { - return 2.2250738585072014e-308; + return 2.2250738585072014e-308; } /////////////////////////////////////////////////////////////////// // returns the largest finite value of the given type template <> FORCEINLINE _CUDA_HD int DataTypeUtils::max() { - return 2147483647; + return 2147483647; } template <> FORCEINLINE _CUDA_HD bool DataTypeUtils::max() { - return true; + return true; } template <> FORCEINLINE _CUDA_HD int8_t DataTypeUtils::max() { - return 127; + return 127; } template <> FORCEINLINE _CUDA_HD uint8_t DataTypeUtils::max() { - return 255; + return 255; } template <> FORCEINLINE _CUDA_HD int16_t DataTypeUtils::max() { - return 32767; + return 32767; } template <> FORCEINLINE _CUDA_HD uint16_t DataTypeUtils::max() { - return 65535; + return 65535; } template <> FORCEINLINE _CUDA_HD Nd4jLong DataTypeUtils::max() { - return 9223372036854775807LL; + return 9223372036854775807LL; } template <> FORCEINLINE _CUDA_HD uint32_t DataTypeUtils::max() { - return 4294967295; + return 4294967295; } template <> FORCEINLINE _CUDA_HD Nd4jULong DataTypeUtils::max() { - return 18446744073709551615LLU; + return 18446744073709551615LLU; } template <> FORCEINLINE _CUDA_HD float DataTypeUtils::max() { - return 3.402823e+38; + return 3.402823e+38; } template <> FORCEINLINE _CUDA_HD double DataTypeUtils::max() { - return 1.7976931348623157E308; + return 1.7976931348623157E308; } template <> FORCEINLINE _CUDA_HD float16 DataTypeUtils::max() { - return static_cast(65504.f); + return static_cast(65504.f); } template <> FORCEINLINE _CUDA_HD bfloat16 DataTypeUtils::max() { - return bfloat16::max(); + return bfloat16::max(); } template <> FORCEINLINE _CUDA_HD float DataTypeUtils::infOrMax() { - return std::numeric_limits::infinity(); + return std::numeric_limits::infinity(); } template <> FORCEINLINE _CUDA_HD double DataTypeUtils::infOrMax() { - return std::numeric_limits::infinity(); + return std::numeric_limits::infinity(); } template FORCEINLINE _CUDA_HD T DataTypeUtils::infOrMax() { - return DataTypeUtils::max(); + return DataTypeUtils::max(); } template <> FORCEINLINE _CUDA_HD float DataTypeUtils::nanOrZero() { - return std::numeric_limits::quiet_NaN(); + return std::numeric_limits::quiet_NaN(); } template <> FORCEINLINE _CUDA_HD double DataTypeUtils::nanOrZero() { - return std::numeric_limits::quiet_NaN(); + return std::numeric_limits::quiet_NaN(); } template FORCEINLINE _CUDA_HD T DataTypeUtils::nanOrZero() { - return static_cast(0); + return static_cast(0); } FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) { - switch(dataType) { - case INT8: - return std::string("INT8"); - case INT16: - return std::string("INT16"); - case INT32: - return std::string("INT32"); - case INT64: - return std::string("INT64"); - case BFLOAT16: - return std::string("BFLOAT16"); - case FLOAT32: - return std::string("FLOAT"); - case DOUBLE: - return std::string("DOUBLE"); - case HALF: - return std::string("HALF"); - case BOOL: - return std::string("BOOL"); - case UINT8: - return std::string("UINT8"); - case UINT16: - return std::string("UINT16"); - case UINT32: - return std::string("UINT32"); - case UINT64: - return std::string("UINT64"); - case UTF8: - return std::string("UTF8"); - case UTF16: - return std::string("UTF16"); - case UTF32: - return std::string("UTF32"); - default: - throw std::runtime_error("Unknown data type used"); - } + switch (dataType) { + case INT8: + return std::string("INT8"); + case INT16: + return std::string("INT16"); + case INT32: + return std::string("INT32"); + case INT64: + return std::string("INT64"); + case BFLOAT16: + return std::string("BFLOAT16"); + case FLOAT32: + return std::string("FLOAT"); + case DOUBLE: + return std::string("DOUBLE"); + case HALF: + return std::string("HALF"); + case BOOL: + return std::string("BOOL"); + case UINT8: + return std::string("UINT8"); + case UINT16: + return std::string("UINT16"); + case UINT32: + return std::string("UINT32"); + case UINT64: + return std::string("UINT64"); + case UTF8: + return std::string("UTF8"); + case UTF16: + return std::string("UTF16"); + case UTF32: + return std::string("UTF32"); + default: + throw std::runtime_error("Unknown data type used"); + } } - template -FORCEINLINE bool DataTypeUtils::castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo) { - auto shapeInfoLength = *originalShapeInfo * 2 + 4; - for (auto e = 0; e < shapeInfoLength; e++) { - if (originalShapeInfo[e] < static_cast(DataTypeUtils::max())) { - newShapeInfo[e] = static_cast(originalShapeInfo[e]); - } else - return false; - } +FORCEINLINE bool DataTypeUtils::castShapeInfo(const Nd4jLong* originalShapeInfo, + T* newShapeInfo) { + auto shapeInfoLength = *originalShapeInfo * 2 + 4; + for (auto e = 0; e < shapeInfoLength; e++) { + if (originalShapeInfo[e] < static_cast(DataTypeUtils::max())) { + newShapeInfo[e] = static_cast(originalShapeInfo[e]); + } else + return false; + } - return true; + return true; } /////////////////////////////////////////////////////////////////// -// returns the difference between 1.0 and the next representable value of the given floating-point type +// returns the difference between 1.0 and the next representable value of the +// given floating-point type template FORCEINLINE _CUDA_HD T DataTypeUtils::eps() { - if (std::is_same::value) - return std::numeric_limits::epsilon(); - else if (std::is_same::value) - return std::numeric_limits::epsilon(); - else if (std::is_same::value) - return 0.00097656; - else if (std::is_same::value) - return bfloat16::eps(); - else - return 0; -} - - - template - FORCEINLINE std::vector DataTypeUtils::convertVector(const std::vector &vector) { - std::vector result(vector.size()); - Nd4jLong vecSize = vector.size(); - for (Nd4jLong e = 0; e < vecSize; e++) - result[e] = static_cast(vector[e]); - - return result; - } - - FORCEINLINE _CUDA_HD size_t DataTypeUtils::sizeOfElement(sd::DataType type) { - switch (type) { - case sd::DataType::UINT8: - case sd::DataType::INT8: - case sd::DataType::FLOAT8: - case sd::DataType::QINT8: - case sd::DataType::BOOL: return (size_t) 1; - - case sd::DataType::BFLOAT16: - case sd::DataType::HALF: - case sd::DataType::INT16: - case sd::DataType::QINT16: - case sd::DataType::UINT16: return (size_t) 2; - - case sd::DataType::UTF8: - case sd::DataType::UTF16: - case sd::DataType::UTF32: - case sd::DataType::INT32: - case sd::DataType::UINT32: - case sd::DataType::HALF2: - case sd::DataType::FLOAT32: return (size_t) 4; - - case sd::DataType::UINT64: - case sd::DataType::INT64: - case sd::DataType::DOUBLE: return (size_t) 8; - - default: { - nd4j_printf("Unknown DataType used: [%i]\n", asInt(type)); + if (std::is_same::value) + return std::numeric_limits::epsilon(); + else if (std::is_same::value) + return std::numeric_limits::epsilon(); + else if (std::is_same::value) + return 0.00097656; + else if (std::is_same::value) + return bfloat16::eps(); + else + return 0; +} + +template +FORCEINLINE std::vector DataTypeUtils::convertVector( + const std::vector& vector) { + std::vector result(vector.size()); + Nd4jLong vecSize = vector.size(); + for (Nd4jLong e = 0; e < vecSize; e++) result[e] = static_cast(vector[e]); + + return result; +} + +FORCEINLINE _CUDA_HD size_t DataTypeUtils::sizeOfElement(sd::DataType type) { + switch (type) { + case sd::DataType::UINT8: + case sd::DataType::INT8: + case sd::DataType::FLOAT8: + case sd::DataType::QINT8: + case sd::DataType::BOOL: + return (size_t)1; + + case sd::DataType::BFLOAT16: + case sd::DataType::HALF: + case sd::DataType::INT16: + case sd::DataType::QINT16: + case sd::DataType::UINT16: + return (size_t)2; + + case sd::DataType::UTF8: + case sd::DataType::UTF16: + case sd::DataType::UTF32: + case sd::DataType::INT32: + case sd::DataType::UINT32: + case sd::DataType::HALF2: + case sd::DataType::FLOAT32: + return (size_t)4; + + case sd::DataType::UINT64: + case sd::DataType::INT64: + case sd::DataType::DOUBLE: + return (size_t)8; + + default: { + nd4j_printf("Unknown DataType used: [%i]\n", asInt(type)); #ifndef __CUDA_ARCH__ - throw std::runtime_error("Unknown DataType requested"); + throw std::runtime_error("Unknown DataType requested"); #endif - } - } - } - - template - FORCEINLINE _CUDA_HD sd::DataType sd::DataTypeUtils::fromT() { - if (std::is_same::value) { - return sd::DataType::BOOL; - } else if (std::is_same::value) { - return sd::DataType::UTF8; - } else if (std::is_same::value) { - return sd::DataType::UTF16; - } else if (std::is_same::value) { - return sd::DataType::UTF32; - } else if (std::is_same::value) { - return sd::DataType::FLOAT32; - } else if (std::is_same::value) { - return sd::DataType::HALF; - } else if (std::is_same::value) { - return sd::DataType::BFLOAT16; - } else if (std::is_same::value) { - return sd::DataType::DOUBLE; - } else if (std::is_same::value) { - return sd::DataType::INT8; - } else if (std::is_same::value) { - return sd::DataType::INT16; - } else if (std::is_same::value) { - return sd::DataType::INT32; - } else if (std::is_same::value) { - return sd::DataType::INT64; - } else if (std::is_same::value) { - return sd::DataType::UINT8; - } else if (std::is_same::value) { - return sd::DataType::UINT16; - } else if (std::is_same::value) { - return sd::DataType::UINT32; - } else if (std::is_same::value) { - return sd::DataType::UINT64; - } else { - return sd::DataType::INHERIT; - } } + } } -#endif //DATATYPEUTILS_H \ No newline at end of file +template +FORCEINLINE _CUDA_HD sd::DataType sd::DataTypeUtils::fromT() { + if (std::is_same::value) { + return sd::DataType::BOOL; + } else if (std::is_same::value) { + return sd::DataType::UTF8; + } else if (std::is_same::value) { + return sd::DataType::UTF16; + } else if (std::is_same::value) { + return sd::DataType::UTF32; + } else if (std::is_same::value) { + return sd::DataType::FLOAT32; + } else if (std::is_same::value) { + return sd::DataType::HALF; + } else if (std::is_same::value) { + return sd::DataType::BFLOAT16; + } else if (std::is_same::value) { + return sd::DataType::DOUBLE; + } else if (std::is_same::value) { + return sd::DataType::INT8; + } else if (std::is_same::value) { + return sd::DataType::INT16; + } else if (std::is_same::value) { + return sd::DataType::INT32; + } else if (std::is_same::value) { + return sd::DataType::INT64; + } else if (std::is_same::value) { + return sd::DataType::UINT8; + } else if (std::is_same::value) { + return sd::DataType::UINT16; + } else if (std::is_same::value) { + return sd::DataType::UINT32; + } else if (std::is_same::value) { + return sd::DataType::UINT64; + } else { + return sd::DataType::INHERIT; + } +} +} // namespace sd + +#endif // DATATYPEUTILS_H \ No newline at end of file diff --git a/libnd4j/include/array/ExtraArguments.h b/libnd4j/include/array/ExtraArguments.h index 9404447dce65..4fce4bb73159 100644 --- a/libnd4j/include/array/ExtraArguments.h +++ b/libnd4j/include/array/ExtraArguments.h @@ -21,45 +21,45 @@ #ifndef SD_EXTRAARGUMENTS_H #define SD_EXTRAARGUMENTS_H +#include +#include #include +#include + #include #include -#include -#include -#include namespace sd { - class SD_EXPORT ExtraArguments { - private: - std::vector _fpArgs; - std::vector _intArgs; - - std::vector _pointers; +class SD_EXPORT ExtraArguments { + private: + std::vector _fpArgs; + std::vector _intArgs; - template - void convertAndCopy(Nd4jPointer pointer, Nd4jLong offset); + std::vector _pointers; - void* allocate(size_t length, size_t elementSize); - public: - explicit ExtraArguments(std::initializer_list arguments); - explicit ExtraArguments(std::initializer_list arguments); + template + void convertAndCopy(Nd4jPointer pointer, Nd4jLong offset); - explicit ExtraArguments(const std::vector &arguments); - explicit ExtraArguments(const std::vector &arguments); - explicit ExtraArguments(const std::vector &arguments); + void* allocate(size_t length, size_t elementSize); - explicit ExtraArguments(); - ~ExtraArguments(); + public: + explicit ExtraArguments(std::initializer_list arguments); + explicit ExtraArguments(std::initializer_list arguments); - template - void* argumentsAsT(Nd4jLong offset = 0); + explicit ExtraArguments(const std::vector& arguments); + explicit ExtraArguments(const std::vector& arguments); + explicit ExtraArguments(const std::vector& arguments); - void* argumentsAsT(sd::DataType dataType, Nd4jLong offset = 0); + explicit ExtraArguments(); + ~ExtraArguments(); - size_t length(); - }; -} + template + void* argumentsAsT(Nd4jLong offset = 0); + void* argumentsAsT(sd::DataType dataType, Nd4jLong offset = 0); + size_t length(); +}; +} // namespace sd -#endif //SD_EXTRAARGUMENTS_H +#endif // SD_EXTRAARGUMENTS_H diff --git a/libnd4j/include/array/InteropDataBuffer.h b/libnd4j/include/array/InteropDataBuffer.h index 10d14b6133b4..4cafb31d0cf0 100644 --- a/libnd4j/include/array/InteropDataBuffer.h +++ b/libnd4j/include/array/InteropDataBuffer.h @@ -18,54 +18,67 @@ // @author raver119@gmail.com // -#include #include #include +#include + #include #ifndef LIBND4J_INTEROPDATABUFFER_H #define LIBND4J_INTEROPDATABUFFER_H namespace sd { - /** - * This class is a wrapper for DataBuffer, suitable for sharing DataBuffer between front-end and back-end languages - */ - class SD_EXPORT InteropDataBuffer { - private: - std::shared_ptr _dataBuffer; - uint64_t _offset = 0; - public: - InteropDataBuffer(InteropDataBuffer &dataBuffer, uint64_t length, uint64_t offset); - InteropDataBuffer(std::shared_ptr databuffer); - InteropDataBuffer(size_t elements, sd::DataType dtype, bool allocateBoth); - ~InteropDataBuffer() = default; +/** + * This class is a wrapper for DataBuffer, suitable for sharing DataBuffer + * between front-end and back-end languages + */ +class SD_EXPORT InteropDataBuffer { + private: + std::shared_ptr _dataBuffer; + uint64_t _offset = 0; + + public: + InteropDataBuffer(InteropDataBuffer& dataBuffer, uint64_t length, + uint64_t offset); + InteropDataBuffer(std::shared_ptr databuffer); + InteropDataBuffer(size_t elements, sd::DataType dtype, bool allocateBoth); + ~InteropDataBuffer() = default; #ifndef __JAVACPP_HACK__ - std::shared_ptr getDataBuffer() const; - std::shared_ptr dataBuffer(); + std::shared_ptr getDataBuffer() const; + std::shared_ptr dataBuffer(); #endif - void* primary() const; - void* special() const; - - uint64_t offset() const ; - void setOffset(uint64_t offset); + void* primary() const; + void* special() const; - void setPrimary(void* ptr, size_t length); - void setSpecial(void* ptr, size_t length); + uint64_t offset() const; + void setOffset(uint64_t offset); - void expand(size_t newlength); + void setPrimary(void* ptr, size_t length); + void setSpecial(void* ptr, size_t length); - int deviceId() const; - void setDeviceId(int deviceId); + void expand(size_t newlength); - static void registerSpecialUse(const std::vector& writeList, const std::vector& readList); - static void prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); + int deviceId() const; + void setDeviceId(int deviceId); - static void registerPrimaryUse(const std::vector& writeList, const std::vector& readList); - static void preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); - }; -} + static void registerSpecialUse( + const std::vector& writeList, + const std::vector& readList); + static void prepareSpecialUse( + const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables = false); + static void registerPrimaryUse( + const std::vector& writeList, + const std::vector& readList); + static void preparePrimaryUse( + const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables = false); +}; +} // namespace sd -#endif //LIBND4J_INTEROPDATABUFFER_H +#endif // LIBND4J_INTEROPDATABUFFER_H diff --git a/libnd4j/include/array/ManagedDataBuffer.h b/libnd4j/include/array/ManagedDataBuffer.h index 7cb32f065459..05d34d62f202 100644 --- a/libnd4j/include/array/ManagedDataBuffer.h +++ b/libnd4j/include/array/ManagedDataBuffer.h @@ -25,24 +25,25 @@ #include namespace sd { - /** - * This class provides special DataBuffer implementation for use within Graphs - */ - class SD_EXPORT ManagedDataBuffer : public DataBuffer { - private: - graph::GraphMemoryManager &_manager; - - protected: - memory::MemoryZone _zone; - MemoryDescriptor _descriptor; - - public: - ManagedDataBuffer(graph::GraphMemoryManager &manager, uint64_t numberOfBytes, DataType dtype, memory::MemoryZone zone); - ~ManagedDataBuffer(); - - void* primary() override; - void* special() override; - }; -} - -#endif //SD_MANAGEDDATABUFFER_H +/** + * This class provides special DataBuffer implementation for use within Graphs + */ +class SD_EXPORT ManagedDataBuffer : public DataBuffer { + private: + graph::GraphMemoryManager& _manager; + + protected: + memory::MemoryZone _zone; + MemoryDescriptor _descriptor; + + public: + ManagedDataBuffer(graph::GraphMemoryManager& manager, uint64_t numberOfBytes, + DataType dtype, memory::MemoryZone zone); + ~ManagedDataBuffer(); + + void* primary() override; + void* special() override; +}; +} // namespace sd + +#endif // SD_MANAGEDDATABUFFER_H diff --git a/libnd4j/include/array/NDArray.h b/libnd4j/include/array/NDArray.h index 7c4cb887ca9d..a3220e86960f 100644 --- a/libnd4j/include/array/NDArray.h +++ b/libnd4j/include/array/NDArray.h @@ -17,1726 +17,1963 @@ #ifndef NDARRAY_H #define NDARRAY_H -#include -#include -#include -#include -#include "legacy/NativeOpExecutioner.h" -#include -#include -#include -#include -#include -#include #include #include +#include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include #include -#include -#include +#include +#include +#include +#include #include #include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include + +#include +#include #include -#include -#include +#include "legacy/NativeOpExecutioner.h" namespace sd { - template ::value>::type> - SD_EXPORT NDArray operator+(const NDArray& arr, const T& scalar); - template ::value>::type> - SD_EXPORT NDArray operator+(NDArray&& arr, const T& scalar); - template ::value>::type> - SD_EXPORT NDArray operator+(const T& scalar, const NDArray& arr); - template ::value>::type> - SD_EXPORT NDArray operator+(const T& scalar, NDArray&& arr); - - template ::value>::type> - SD_EXPORT NDArray operator-(const NDArray& arr, const T& scalar); - template ::value>::type> - SD_EXPORT NDArray operator-(NDArray&& arr, const T& scalar); - template ::value>::type> - SD_EXPORT NDArray operator-(const T& scalar, const NDArray& arr); - template ::value>::type> - SD_EXPORT NDArray operator-(const T& scalar, NDArray&& arr); - - template ::value>::type> - SD_EXPORT NDArray operator*(const NDArray& arr, const T& scalar); - template ::value>::type> - SD_EXPORT NDArray operator*(NDArray&& arr, const T& scalar); - template ::value>::type> - SD_EXPORT NDArray operator*(const T& scalar, const NDArray& arr); - template ::value>::type> - SD_EXPORT NDArray operator*(const T& scalar, NDArray&& arr); - - template ::value>::type> - SD_EXPORT NDArray operator/(const NDArray& arr, const T& scalar); - template ::value>::type> - SD_EXPORT NDArray operator/(NDArray&& arr, const T& scalar); - template ::value>::type> - SD_EXPORT NDArray operator/(const T& scalar, const NDArray& arr); - template ::value>::type> - SD_EXPORT NDArray operator/(const T& scalar, NDArray&& arr); - - template ::type>::value && std::is_same::type>::value>::type> - SD_EXPORT NDArray operator+(T1&& arr1, T2&& arr2); - template ::type>::value && std::is_same::type>::value>::type> - SD_EXPORT NDArray operator-(T1&& arr1, T2&& arr2); - template ::type>::value && std::is_same::type>::value>::type> - SD_EXPORT NDArray operator*(T1&& arr1, T2&& arr2); - template ::type>::value && std::is_same::type>::value>::type> - SD_EXPORT NDArray operator/(T1&& arr1, T2&& arr2); - - - - - SD_EXPORT NDArray mmul(const NDArray&, const NDArray&); - - class SD_EXPORT NDArray { - private: - /** - * This method applies given value to the buffer, wrt templates - * @tparam T - * @tparam Y - * @param buffer - * @param indices - * @param value - */ - template - void templatedSet(void *buffer, const Nd4jLong *indices, const void *value); - - template - void templatedSet(void *buffer, const Nd4jLong xOffset, const void *value); - - template - void templatedSet(void *buffer, const Nd4jLong xOfsset, sd::DataType dtype, const void *value); - - template - void templatedAssign(void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const; - - template - void templatedDoubleAssign(void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const; - - template - FORCEINLINE R templatedGet(void const* buffer, const Nd4jLong index) const; -/* - template - R templatedGetIndex(void *buffer, Nd4jLong *indices) const; -*/ - template - void* templatedPointerShift(const Nd4jLong offset) const; - - FORCEINLINE void copyBufferStatus(const NDArray& other) const; - - protected: - - /** - * if true then array doesn't own buffer and simply points to another's buffer - */ - bool _isView = false; - - /** - * pointer on DataBuffer buffers in cpu/device memory - */ - std::shared_ptr _buffer = std::make_shared(); - - /** - * buffers offset, it is the same both for cpu and device buffers - */ - Nd4jLong _offset = 0L; - - /** - * contains shape info: matrix rank, numbers of elements per each dimension, dimensions strides, element-wise-stride, c-like or fortan-like order - */ - Nd4jLong *_shapeInfo = nullptr; - Nd4jLong *_shapeInfoD = nullptr; - - /** - * pointer on device launch context (with all data needed there). - */ - sd::LaunchContext * _context = sd::LaunchContext::defaultContext(); - - // indicates if array's buffer is within workspace - bool _isAttached = false; - - /** - * Field to store cached length - */ - Nd4jLong _length = -1L; - - /** - * type of array elements - */ - sd::DataType _dataType = FLOAT32; - - /** - * deviceID where this NDArray belongs to - */ - int _deviceId = AffinityManager::currentDeviceId(); - - template - std::string toStringValue(T value); - - public: - NDArray() = default; - - /** - * do not allocate memory, memory for array is passed from outside - */ +template ::value>::type> +SD_EXPORT NDArray operator+(const NDArray& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator+(NDArray&& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator+(const T& scalar, const NDArray& arr); +template ::value>::type> +SD_EXPORT NDArray operator+(const T& scalar, NDArray&& arr); + +template ::value>::type> +SD_EXPORT NDArray operator-(const NDArray& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator-(NDArray&& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator-(const T& scalar, const NDArray& arr); +template ::value>::type> +SD_EXPORT NDArray operator-(const T& scalar, NDArray&& arr); + +template ::value>::type> +SD_EXPORT NDArray operator*(const NDArray& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator*(NDArray&& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator*(const T& scalar, const NDArray& arr); +template ::value>::type> +SD_EXPORT NDArray operator*(const T& scalar, NDArray&& arr); + +template ::value>::type> +SD_EXPORT NDArray operator/(const NDArray& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator/(NDArray&& arr, const T& scalar); +template ::value>::type> +SD_EXPORT NDArray operator/(const T& scalar, const NDArray& arr); +template ::value>::type> +SD_EXPORT NDArray operator/(const T& scalar, NDArray&& arr); + +template < + typename T1, typename T2, + typename = typename std::enable_if< + std::is_same::type>::value && + std::is_same::type>::value>::type> +SD_EXPORT NDArray operator+(T1&& arr1, T2&& arr2); +template < + typename T1, typename T2, + typename = typename std::enable_if< + std::is_same::type>::value && + std::is_same::type>::value>::type> +SD_EXPORT NDArray operator-(T1&& arr1, T2&& arr2); +template < + typename T1, typename T2, + typename = typename std::enable_if< + std::is_same::type>::value && + std::is_same::type>::value>::type> +SD_EXPORT NDArray operator*(T1&& arr1, T2&& arr2); +template < + typename T1, typename T2, + typename = typename std::enable_if< + std::is_same::type>::value && + std::is_same::type>::value>::type> +SD_EXPORT NDArray operator/(T1&& arr1, T2&& arr2); + +SD_EXPORT NDArray mmul(const NDArray&, const NDArray&); + +class SD_EXPORT NDArray { + private: + /** + * This method applies given value to the buffer, wrt templates + * @tparam T + * @tparam Y + * @param buffer + * @param indices + * @param value + */ + template + void templatedSet(void* buffer, const Nd4jLong* indices, const void* value); + + template + void templatedSet(void* buffer, const Nd4jLong xOffset, const void* value); + + template + void templatedSet(void* buffer, const Nd4jLong xOfsset, sd::DataType dtype, + const void* value); + + template + void templatedAssign(void* xBuffer, const Nd4jLong xOffset, + const void* yBuffer, const Nd4jLong yOffset) const; + + template + void templatedDoubleAssign(void* xBuffer, const Nd4jLong xOffset, + const void* yBuffer, const Nd4jLong yOffset) const; + + template + FORCEINLINE R templatedGet(void const* buffer, const Nd4jLong index) const; + /* + template + R templatedGetIndex(void *buffer, Nd4jLong *indices) const; + */ + template + void* templatedPointerShift(const Nd4jLong offset) const; + + FORCEINLINE void copyBufferStatus(const NDArray& other) const; + + protected: + /** + * if true then array doesn't own buffer and simply points to another's + * buffer + */ + bool _isView = false; + + /** + * pointer on DataBuffer buffers in cpu/device memory + */ + std::shared_ptr _buffer = std::make_shared(); + + /** + * buffers offset, it is the same both for cpu and device buffers + */ + Nd4jLong _offset = 0L; + + /** + * contains shape info: matrix rank, numbers of elements per each dimension, + * dimensions strides, element-wise-stride, c-like or fortan-like order + */ + Nd4jLong* _shapeInfo = nullptr; + Nd4jLong* _shapeInfoD = nullptr; + + /** + * pointer on device launch context (with all data needed there). + */ + sd::LaunchContext* _context = sd::LaunchContext::defaultContext(); + + // indicates if array's buffer is within workspace + bool _isAttached = false; + + /** + * Field to store cached length + */ + Nd4jLong _length = -1L; + + /** + * type of array elements + */ + sd::DataType _dataType = FLOAT32; + + /** + * deviceID where this NDArray belongs to + */ + int _deviceId = AffinityManager::currentDeviceId(); + + template + std::string toStringValue(T value); + + public: + NDArray() = default; + + /** + * do not allocate memory, memory for array is passed from outside + */ #ifndef __JAVACPP_HACK__ - NDArray(std::shared_ptr buffer, const ShapeDescriptor& descriptor, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const Nd4jLong offset = 0); - - NDArray(std::shared_ptr buffer, char order, const std::vector &shape, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This contructors create scalar array containing string utf8 - * - */ - NDArray(const char* str, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext* context = sd::LaunchContext::defaultContext()) - : NDArray(std::string(str), dtype, context) { - } - NDArray(const std::string& string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This contructors create scalar array containing string utf16 - * - */ - NDArray(const char16_t* u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()) - : NDArray(std::u16string(u16string), dtype, context) { - } - - NDArray(const std::u16string& u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This contructors create scalar array containing string utf32 - * - */ - NDArray(const char32_t* u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()) - : NDArray(std::u32string(u32string), dtype, context) { - } - - NDArray(const std::u32string& u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This contructors create array from vector of utf8 strings - * - */ - NDArray(const std::vector& shape, const std::vector& strings, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This contructors create array from vector of utf16 strings - * - */ - NDArray(const std::vector& shape, const std::vector& strings, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This contructors create array from vector of utf32 strings - * - */ - NDArray(const std::vector& shape, const std::vector& strings, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + NDArray(std::shared_ptr buffer, const ShapeDescriptor& descriptor, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + const Nd4jLong offset = 0); + + NDArray(std::shared_ptr buffer, char order, + const std::vector& shape, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This contructors create scalar array containing string utf8 + * + */ + NDArray(const char* str, sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()) + : NDArray(std::string(str), dtype, context) {} + NDArray(const std::string& string, sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This contructors create scalar array containing string utf16 + * + */ + NDArray(const char16_t* u16string, sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()) + : NDArray(std::u16string(u16string), dtype, context) {} + + NDArray(const std::u16string& u16string, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This contructors create scalar array containing string utf32 + * + */ + NDArray(const char32_t* u32string, sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()) + : NDArray(std::u32string(u32string), dtype, context) {} + + NDArray(const std::u32string& u32string, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This contructors create array from vector of utf8 strings + * + */ + NDArray(const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + NDArray(const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This contructors create array from vector of utf16 strings + * + */ + NDArray(const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + NDArray(const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This contructors create array from vector of utf32 strings + * + */ + NDArray(const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + NDArray(const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); #endif - /** - * do not allocate memory, memory for array is passed from outside - */ - NDArray(void *buffer, Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isBuffAlloc = false); - NDArray(void *buffer, const Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isBuffAlloc = false); - - /** - * do not allocate memory, memory for array is passed from outside - * we suppose the content of both (device and host) buffers is identical - */ - NDArray(void *buffer, void *bufferD, const Nd4jLong* shapeInfo, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isBuffAlloc = false, bool isBuffDAlloc = false); - - /** - * copy constructor - */ - NDArray(const NDArray& other); - - /** - * move constructor - */ - NDArray(NDArray&& other) noexcept; - - /** - * constructor, create array stored at given workspace - */ - NDArray(sd::LaunchContext * context); - - - /** - * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently - */ - NDArray(const Nd4jLong* shapeInfo, bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool nullify = true); - - /** - * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently - * set dtype as array type - */ - NDArray(const Nd4jLong* shapeInfo, sd::DataType dtype, bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool nullify = true); - - /** - * this constructor creates new array using shape information contained in vector argument - */ - NDArray(char order, const std::vector &shape, sd::DataType dtype = DOUBLE, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This constructor creates new array with elements copied from data and using shape information stored in shape, elements from data will be casted to dtype - */ - NDArray(char order, const std::vector &shape, const std::vector& data, sd::DataType dtype = DOUBLE, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * this constructor creates new array using given buffer (without memory allocation) and shape information stored in shape - */ - NDArray(void *buffer, char order, const std::vector &shape, sd::DataType dtype, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool isBuffAlloc = false); - - /** - * This method returns new array with the same shape & data type - * @return - */ - NDArray like(); - - /** - * This method returns new uninitialized array with the same shape & data type - * @return - */ - NDArray ulike() const; - - - /** - * this constructor creates new NDArray with shape matching "other" array, - * doesn't copy "other" elements into new array !!! - */ - explicit NDArray(const NDArray* other, bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); - - /** - * this constructor creates scalar(and set its value = 0) or empty array depending on bool argument isScalar - */ - NDArray(sd::DataType dtype, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), bool isScalar = true); - - /** - * This method blocks until asynchronous operation finishes - */ - void synchronize(const char* msg) const; - - /** - * This method allows to set _isAttached flag - * @param reallyAttached - */ - void setAttached(bool reallyAttached); - - void tickWriteHost() const; - void tickWriteDevice() const; - void tickReadHost() const; - void tickReadDevice() const; - void tickBothActual() const; - bool isActualOnHostSide() const; - bool isActualOnDeviceSide() const; - void makeBothBuffersActual() const; - - void syncToHost() const; - void syncToDevice() const; - void syncShape() const; - - /** - * This method can be used on architectures that use special buffers - * @param writeList - * @param readList - */ - static void registerSpecialUse(const std::vector& writeList, const std::vector& readList); - static void prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); - - static void registerPrimaryUse(const std::vector& writeList, const std::vector& readList); - static void preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); - - /** - * This method returns buffer pointer offset by given number of elements, wrt own data type - * @param offset - * @return - */ - void const* bufferWithOffset(Nd4jLong offset) const; - void* bufferWithOffset(Nd4jLong offset); - - void const* specialBufferWithOffset(Nd4jLong offset) const; - void* specialBufferWithOffset(Nd4jLong offset); - /** - * copy assignment operator - * in particular, when _dataType != other._dataType and both shapes are the same, there will be allocation of new _buffer and _dataType acquires other._dataType - */ - NDArray& operator=(const NDArray& other); - - /** - * move assignment operator - */ - NDArray& operator=(NDArray&& other) noexcept; - - /** - * assignment operator, assigns the same scalar to all array elements - */ - template - NDArray& operator=(const T scalar); - - - /** - * operators for memory allocation and deletion - */ - void* operator new(size_t i); - void operator delete(void* p); - - - void setContext(sd::LaunchContext * context); - - /** - * create a new array by replicating current array by repeats times along given dimension - * axis - axis along which to repeat elements - * repeats - number of repetitions - */ - NDArray repeat(const int axis, const std::vector& repeats) const; - - /** - * This method fills this array with zeros - */ - void nullify(); - - /** - * This method returns quantized copy of given array - * - * @param array - * @return - */ - static NDArray quantize(const NDArray &array); - - /** - * fill target array by repeating current array - * axis - axis along which to repeat elements - * repeats - vector containing numbers of repetition for elements at given axis - */ - void repeat(const int axis, const std::vector& repeats, NDArray& target) const; - - /** - * creates array which points on certain sub-range of this array, sub-range is defined by given indices - */ - NDArray subarray(IndicesList& indices) const; - NDArray subarray(const std::initializer_list& idx) const; - NDArray subarray(const Intervals& idx) const; - - /** - * cast array elements to given dtype - */ - NDArray cast(DataType dtype) const; - - void cast(NDArray& target, DataType dtype); - - /** - * returns _context - */ - sd::LaunchContext * getContext() const { - return _context; - } + /** + * do not allocate memory, memory for array is passed from outside + */ + NDArray(void* buffer, Nd4jLong* shapeInfo, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + bool isBuffAlloc = false); + NDArray(void* buffer, const Nd4jLong* shapeInfo, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + bool isBuffAlloc = false); + + /** + * do not allocate memory, memory for array is passed from outside + * we suppose the content of both (device and host) buffers is identical + */ + NDArray(void* buffer, void* bufferD, const Nd4jLong* shapeInfo, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + bool isBuffAlloc = false, bool isBuffDAlloc = false); + + /** + * copy constructor + */ + NDArray(const NDArray& other); + + /** + * move constructor + */ + NDArray(NDArray&& other) noexcept; + + /** + * constructor, create array stored at given workspace + */ + NDArray(sd::LaunchContext* context); + + /** + * constructor creates new NDArray using shape information from "shapeInfo", + * set all elements in new array to zeros, if copyStrides is true then use + * stride values from "shapeInfo", else calculate strides independently + */ + NDArray(const Nd4jLong* shapeInfo, bool copyStrides = false, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + bool nullify = true); + + /** + * constructor creates new NDArray using shape information from "shapeInfo", + * set all elements in new array to be zeros, if copyStrides is true then use + * stride values from "shapeInfo", else calculate strides independently set + * dtype as array type + */ + NDArray(const Nd4jLong* shapeInfo, sd::DataType dtype, + bool copyStrides = false, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + bool nullify = true); + + /** + * this constructor creates new array using shape information contained in + * vector argument + */ + NDArray(char order, const std::vector& shape, + sd::DataType dtype = DOUBLE, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This constructor creates new array with elements copied from data and using + * shape information stored in shape, elements from data will be casted to + * dtype + */ + NDArray(char order, const std::vector& shape, + const std::vector& data, sd::DataType dtype = DOUBLE, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * this constructor creates new array using given buffer (without memory + * allocation) and shape information stored in shape + */ + NDArray(void* buffer, char order, const std::vector& shape, + sd::DataType dtype, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + const bool isBuffAlloc = false); + + /** + * This method returns new array with the same shape & data type + * @return + */ + NDArray like(); + + /** + * This method returns new uninitialized array with the same shape & data type + * @return + */ + NDArray ulike() const; + + /** + * this constructor creates new NDArray with shape matching "other" array, + * doesn't copy "other" elements into new array !!! + */ + explicit NDArray( + const NDArray* other, bool copyStrides = false, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + /** + * this constructor creates scalar(and set its value = 0) or empty array + * depending on bool argument isScalar + */ + NDArray(sd::DataType dtype, + sd::LaunchContext* context = sd::LaunchContext::defaultContext(), + bool isScalar = true); + + /** + * This method blocks until asynchronous operation finishes + */ + void synchronize(const char* msg) const; + + /** + * This method allows to set _isAttached flag + * @param reallyAttached + */ + void setAttached(bool reallyAttached); + + void tickWriteHost() const; + void tickWriteDevice() const; + void tickReadHost() const; + void tickReadDevice() const; + void tickBothActual() const; + bool isActualOnHostSide() const; + bool isActualOnDeviceSide() const; + void makeBothBuffersActual() const; + + void syncToHost() const; + void syncToDevice() const; + void syncShape() const; + + /** + * This method can be used on architectures that use special buffers + * @param writeList + * @param readList + */ + static void registerSpecialUse(const std::vector& writeList, + const std::vector& readList); + static void prepareSpecialUse(const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables = false); + + static void registerPrimaryUse(const std::vector& writeList, + const std::vector& readList); + static void preparePrimaryUse(const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables = false); + + /** + * This method returns buffer pointer offset by given number of elements, wrt + * own data type + * @param offset + * @return + */ + void const* bufferWithOffset(Nd4jLong offset) const; + void* bufferWithOffset(Nd4jLong offset); + + void const* specialBufferWithOffset(Nd4jLong offset) const; + void* specialBufferWithOffset(Nd4jLong offset); + /** + * copy assignment operator + * in particular, when _dataType != other._dataType and both shapes are the + * same, there will be allocation of new _buffer and _dataType acquires + * other._dataType + */ + NDArray& operator=(const NDArray& other); + + /** + * move assignment operator + */ + NDArray& operator=(NDArray&& other) noexcept; + + /** + * assignment operator, assigns the same scalar to all array elements + */ + template + NDArray& operator=(const T scalar); + + /** + * operators for memory allocation and deletion + */ + void* operator new(size_t i); + void operator delete(void* p); + + void setContext(sd::LaunchContext* context); + + /** + * create a new array by replicating current array by repeats times along + * given dimension axis - axis along which to repeat elements repeats - number + * of repetitions + */ + NDArray repeat(const int axis, const std::vector& repeats) const; + + /** + * This method fills this array with zeros + */ + void nullify(); + + /** + * This method returns quantized copy of given array + * + * @param array + * @return + */ + static NDArray quantize(const NDArray& array); + + /** + * fill target array by repeating current array + * axis - axis along which to repeat elements + * repeats - vector containing numbers of repetition for elements at given + * axis + */ + void repeat(const int axis, const std::vector& repeats, + NDArray& target) const; + + /** + * creates array which points on certain sub-range of this array, sub-range + * is defined by given indices + */ + NDArray subarray(IndicesList& indices) const; + NDArray subarray(const std::initializer_list& idx) const; + NDArray subarray(const Intervals& idx) const; + + /** + * cast array elements to given dtype + */ + NDArray cast(DataType dtype) const; + + void cast(NDArray& target, DataType dtype); + + /** + * returns _context + */ + sd::LaunchContext* getContext() const { return _context; } #ifndef __JAVACPP_HACK__ - FORCEINLINE std::shared_ptr getDataBuffer() const; - FORCEINLINE std::shared_ptr dataBuffer(); + FORCEINLINE std::shared_ptr getDataBuffer() const; + FORCEINLINE std::shared_ptr dataBuffer(); #endif - /** - * returns host buffer - */ - FORCEINLINE void* buffer(); - FORCEINLINE const void* buffer() const; - - - /** - * returns buffer offset (offset is the same for host and device buffers) - */ - FORCEINLINE Nd4jLong bufferOffset() const; - - /** - * if _bufferD==nullptr return _buffer, else return _bufferD - */ - void* specialBuffer(); - const void* specialBuffer() const; - - /** - * returns device buffer if compilation is for cuda case, otherwise returns host buffer - */ - void* platformBuffer(); - const void* platformBuffer() const; - - - - template - T* bufferAsT(); - - template - const T* bufferAsT() const; - - /** - * returns _shapeInfo - */ - FORCEINLINE const Nd4jLong* shapeInfo() const; - - - /** - * Returns True if it's legally empty NDArray, or false otherwise - * @return - */ - FORCEINLINE bool isEmpty() const; - - /** - * if _shapeInfoD==nullptr return _shapeInfo, else return _shapeInfoD - */ - FORCEINLINE const Nd4jLong* specialShapeInfo() const; - - const Nd4jLong* platformShapeInfo() const; - - /** - * permutes (in-place) the dimensions in array according to "dimensions" array - */ - bool permutei(const std::initializer_list& dimensions); - bool permutei(const std::vector& dimensions); - bool permutei(const int* dimensions, const int rank); - - bool permutei(const std::initializer_list& dimensions); - bool permutei(const std::vector& dimensions); - bool permutei(const Nd4jLong* dimensions, const int rank); - - bool isFinite(); - bool hasNaNs(); - bool hasInfs(); - - void copyBuffersContinuouslyFrom(const NDArray& other, size_t sizeToCopyInBytes = 0, Nd4jLong offsetThis = 0, Nd4jLong offsetOther = 0); - - /** - * permutes the dimensions in array according to "dimensions" array, new array points on _buffer of this array - */ - NDArray permute(const std::initializer_list& dimensions) const &; - NDArray permute(const std::vector& dimensions) const &; - NDArray permute(const int* dimensions, const int rank) const &; - NDArray permute(const std::initializer_list& dimensions) &&; - NDArray permute(const std::vector& dimensions) &&; - NDArray permute(const int* dimensions, const int rank) &&; - - void permute(const int* dimensions, const int rank, NDArray& target) const; - void permute(const std::vector& dimensions, NDArray& target) const; - - NDArray permute(const std::initializer_list& dimensions) const &; - NDArray permute(const std::vector& dimensions) const &; - NDArray permute(const Nd4jLong* dimensions, const int rank) const &; - NDArray permute(const std::initializer_list& dimensions) &&; - NDArray permute(const std::vector& dimensions) &&; - NDArray permute(const Nd4jLong* dimensions, const int rank) &&; - - void permute(const Nd4jLong* dimensions, const int rank, NDArray& target) const; - void permute(const std::vector& dimensions, NDArray& target) const; - - /** - * This method streamlines given view or permuted array, and reallocates buffer - */ - void streamline(char order = 'a'); - - /** - * prints information about array shape - * msg - message to print out - */ - void printShapeInfo(const char * msg = nullptr) const; - - /** - * prints buffer elements - * msg - message to print out - * limit - number of array elements to print out - * sync - if true check whether host buffer is actual, if it is not then make it so - */ - void printBuffer(const char* msg = nullptr, Nd4jLong limit = -1, const bool sync = true) const; - - /** - * print element by element consequently in a way they (elements) are stored in physical memory - */ - void printLinearBuffer() const; - - /** - * prints _buffer (if host = true) or _bufferD (if host = false) as it is, that is in current state without checking buffer status - */ - template - void printCurrentBuffer(const bool host = true, const char* msg = nullptr, const int precision = 1) const; - - /** - * prints buffer elements, takes into account offset between elements (element-wise-stride) - * msg - message to print out - * limit - number of array elements to print out - */ - void printIndexedBuffer(const char* msg = nullptr, Nd4jLong limit = -1) const; - - std::string asIndexedString(Nd4jLong limit = -1); - std::string asString(Nd4jLong limit = -1); - - /** - * this method assigns values of given array to this one - */ - void assign(const NDArray* other, bool allowParallelism = true); - - /** - * this method assigns values of given array to this one - */ - void assign(const NDArray& other, bool allowParallelism = true); - - /** - * this method assigns given value to all elements in array - */ - template ::value>::type> - void assign(const T& value, bool allowParallelism = true); - - /** - * returns new copy of this array, optionally in different order - */ - NDArray dup(const char newOrder = 'a') const; - - /** - * returns sum of all elements of array - */ - NDArray sumNumber() const; - - /** - * returns mean number of array - */ - NDArray meanNumber() const; + /** + * returns host buffer + */ + FORCEINLINE void* buffer(); + FORCEINLINE const void* buffer() const; + + /** + * returns buffer offset (offset is the same for host and device buffers) + */ + FORCEINLINE Nd4jLong bufferOffset() const; + + /** + * if _bufferD==nullptr return _buffer, else return _bufferD + */ + void* specialBuffer(); + const void* specialBuffer() const; + + /** + * returns device buffer if compilation is for cuda case, otherwise returns + * host buffer + */ + void* platformBuffer(); + const void* platformBuffer() const; + + template + T* bufferAsT(); + + template + const T* bufferAsT() const; + + /** + * returns _shapeInfo + */ + FORCEINLINE const Nd4jLong* shapeInfo() const; + + /** + * Returns True if it's legally empty NDArray, or false otherwise + * @return + */ + FORCEINLINE bool isEmpty() const; + + /** + * if _shapeInfoD==nullptr return _shapeInfo, else return _shapeInfoD + */ + FORCEINLINE const Nd4jLong* specialShapeInfo() const; + + const Nd4jLong* platformShapeInfo() const; + + /** + * permutes (in-place) the dimensions in array according to "dimensions" + * array + */ + bool permutei(const std::initializer_list& dimensions); + bool permutei(const std::vector& dimensions); + bool permutei(const int* dimensions, const int rank); + + bool permutei(const std::initializer_list& dimensions); + bool permutei(const std::vector& dimensions); + bool permutei(const Nd4jLong* dimensions, const int rank); + + bool isFinite(); + bool hasNaNs(); + bool hasInfs(); + + void copyBuffersContinuouslyFrom(const NDArray& other, + size_t sizeToCopyInBytes = 0, + Nd4jLong offsetThis = 0, + Nd4jLong offsetOther = 0); + + /** + * permutes the dimensions in array according to "dimensions" array, new + * array points on _buffer of this array + */ + NDArray permute(const std::initializer_list& dimensions) const&; + NDArray permute(const std::vector& dimensions) const&; + NDArray permute(const int* dimensions, const int rank) const&; + NDArray permute(const std::initializer_list& dimensions) &&; + NDArray permute(const std::vector& dimensions) &&; + NDArray permute(const int* dimensions, const int rank) &&; + + void permute(const int* dimensions, const int rank, NDArray& target) const; + void permute(const std::vector& dimensions, NDArray& target) const; + + NDArray permute(const std::initializer_list& dimensions) const&; + NDArray permute(const std::vector& dimensions) const&; + NDArray permute(const Nd4jLong* dimensions, const int rank) const&; + NDArray permute(const std::initializer_list& dimensions) &&; + NDArray permute(const std::vector& dimensions) &&; + NDArray permute(const Nd4jLong* dimensions, const int rank) &&; + + void permute(const Nd4jLong* dimensions, const int rank, + NDArray& target) const; + void permute(const std::vector& dimensions, NDArray& target) const; + + /** + * This method streamlines given view or permuted array, and reallocates + * buffer + */ + void streamline(char order = 'a'); + + /** + * prints information about array shape + * msg - message to print out + */ + void printShapeInfo(const char* msg = nullptr) const; + + /** + * prints buffer elements + * msg - message to print out + * limit - number of array elements to print out + * sync - if true check whether host buffer is actual, if it is not then make + * it so + */ + void printBuffer(const char* msg = nullptr, Nd4jLong limit = -1, + const bool sync = true) const; + + /** + * print element by element consequently in a way they (elements) are stored + * in physical memory + */ + void printLinearBuffer() const; + + /** + * prints _buffer (if host = true) or _bufferD (if host = false) as it is, + * that is in current state without checking buffer status + */ + template + void printCurrentBuffer(const bool host = true, const char* msg = nullptr, + const int precision = 1) const; + + /** + * prints buffer elements, takes into account offset between elements + * (element-wise-stride) msg - message to print out limit - number of array + * elements to print out + */ + void printIndexedBuffer(const char* msg = nullptr, Nd4jLong limit = -1) const; + + std::string asIndexedString(Nd4jLong limit = -1); + std::string asString(Nd4jLong limit = -1); + + /** + * this method assigns values of given array to this one + */ + void assign(const NDArray* other, bool allowParallelism = true); + + /** + * this method assigns values of given array to this one + */ + void assign(const NDArray& other, bool allowParallelism = true); + + /** + * this method assigns given value to all elements in array + */ + template ::value>::type> + void assign(const T& value, bool allowParallelism = true); + + /** + * returns new copy of this array, optionally in different order + */ + NDArray dup(const char newOrder = 'a') const; + + /** + * returns sum of all elements of array + */ + NDArray sumNumber() const; + + /** + * returns mean number of array + */ + NDArray meanNumber() const; #ifndef __JAVACPP_HACK__ - /** - * This method explicitly enforces new shape for this NDArray, old shape/stride information is lost - */ - void enforce(const std::initializer_list &dimensions, char order = 'a'); - void enforce(std::vector &dimensions, char order = 'a'); - - - /** - * method reduces array by excluding its shapes along dimensions present in given dimensions vector, result is stored in new array to be returned - * dimensions - array of dimensions to reduce along - * keepDims - if true then put unities in place of reduced dimensions - */ - - NDArray reduceAlongDimension(sd::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDimension(sd::reduce::FloatOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - - NDArray reduceAlongDimension(sd::reduce::SameOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDimension(sd::reduce::SameOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - - NDArray reduceAlongDimension(sd::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDimension(sd::reduce::BoolOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - - NDArray reduceAlongDimension(sd::reduce::LongOps op, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - NDArray reduceAlongDimension(sd::reduce::LongOps op, const std::initializer_list& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const; - - /** - * method reduces array by excluding its shapes along dimensions present in given dimensions vector - * target - where to save result of reducing - * dimensions - array of dimensions to reduce along - * keepDims - if true then put unities in place of reduced dimensions - * extras - extra parameters - */ - void reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; - void reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; - void reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; - void reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, const std::vector& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const; - - /** - * return variance of array elements set - * biasCorrected - if true bias correction will be applied - */ - NDArray varianceNumber(sd::variance::Ops op, bool biasCorrected = true); - - /** - * apply scalar operation to array - * extraParams - extra parameters for operation - * returns scalar array - */ - NDArray reduceNumber(sd::reduce::FloatOps ops, void *extraParams = nullptr) const; - NDArray reduceNumber(sd::reduce::SameOps ops, void *extraParams = nullptr) const; - NDArray reduceNumber(sd::reduce::BoolOps ops, void *extraParams = nullptr) const; - NDArray reduceNumber(sd::reduce::LongOps ops, void *extraParams = nullptr) const; - - void reduceNumber(sd::reduce::FloatOps ops, NDArray& target, void *extraParams = nullptr) const; - void reduceNumber(sd::reduce::SameOps ops, NDArray& target, void *extraParams = nullptr) const; - void reduceNumber(sd::reduce::BoolOps ops, NDArray& target, void *extraParams = nullptr) const; - void reduceNumber(sd::reduce::LongOps ops, NDArray& target, void *extraParams = nullptr) const; - - /** - * returns element index which corresponds to some condition imposed by operation - * extraParams - extra parameters for operation - */ - NDArray indexReduceNumber(sd::indexreduce::Ops op, ExtraArguments *extraParams = nullptr); - - /** - * returns index of max element in a given array (optionally: along given dimension(s)) - * dimensions - optional vector with dimensions - */ - Nd4jLong argMax(std::initializer_list dimensions = {}); - - // FIXME: remove this method eventually - void makeBothActual() const { syncToDevice(); syncToHost(); } - - - void applyTransform(sd::transform::FloatOps op, NDArray& target, ExtraArguments *extraParams = nullptr); - void applyTransform(sd::transform::SameOps op, NDArray& target, ExtraArguments *extraParams = nullptr); - void applyTransform(sd::transform::AnyOps op, NDArray& target, ExtraArguments *extraParams = nullptr); - void applyTransform(sd::transform::BoolOps op, NDArray& target, ExtraArguments *extraParams = nullptr); - void applyTransform(sd::transform::StrictOps op, NDArray& target, ExtraArguments *extraParams = nullptr); - - /** - * apply OpName transformation to this array and store result in new array to be returned - * extraParams - extra parameters for operation - */ - NDArray transform(sd::transform::FloatOps op, void *extraParams = nullptr) const &; - NDArray transform(sd::transform::SameOps op, void *extraParams = nullptr) const &; - NDArray transform(sd::transform::BoolOps op, void *extraParams = nullptr) const &; - NDArray transform(sd::transform::StrictOps op, void *extraParams = nullptr) const &; - NDArray transform(sd::transform::FloatOps op, void *extraParams = nullptr) &&; - NDArray transform(sd::transform::SameOps op, void *extraParams = nullptr) &&; - NDArray transform(sd::transform::BoolOps op, void *extraParams = nullptr) &&; - NDArray transform(sd::transform::StrictOps op, void *extraParams = nullptr) &&; - - /** - * apply pairwise OpName transformation based on "this" and "other" arras elements, store result in this array - * other - second array necessary for pairwise operation - * extraParams - extra parameters for operation - */ - void applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, ExtraArguments *extraParams = nullptr); - - /** - * apply pairwise OpName transformation based on "this" and "other" arras elements, store result in target array - * other - second array necessary for pairwise operation - * target - where to store result - * extraParams - extra parameters for operation - */ - void applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, NDArray& target, ExtraArguments *extraParams = nullptr) const; - - void applyPairwiseTransform(sd::pairwise::BoolOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams = nullptr) const; - - void applyPairwiseTransform(sd::pairwise::IntOps op, const NDArray& other, NDArray&target, ExtraArguments *extraParams = nullptr) const; - - /** - * apply operation which requires broadcasting, broadcast a smaller array (tad) along bigger one (this) - * tad - array to broadcast - * dimensions - dimensions array to broadcast along - * target - where to store result - * extraParams - extra parameters for operation - */ - void applyBroadcast(sd::broadcast::Ops op, const std::initializer_list dimensions, const NDArray& tad, NDArray& target, ExtraArguments* extraArgs = nullptr); - - void applyBroadcast(sd::broadcast::Ops op, const std::vector &dimensions, const NDArray &tad, NDArray &target, ExtraArguments *extraArgs = nullptr); - - void applyBroadcast(sd::broadcast::BoolOps op, const std::vector &dimensions, const NDArray &tad, NDArray &target, ExtraArguments *extraArgs = nullptr); - - void applyBroadcast(sd::broadcast::IntOps op, const std::vector &dimensions, const NDArray& tad, NDArray &target, ExtraArguments *extraArgs = nullptr); - - /** - * apply operation which requires broadcasting, broadcast one tensor along another, also this method checks the possibility of broadcasting - * other - input array - * extraParams - extra parameters for operation - */ - NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs = nullptr) const &; - NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs = nullptr) const &; - NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs = nullptr) &&; - NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs = nullptr) &&; - - /** - * apply operation which requires broadcasting, broadcast one tensor along another, also this method checks the possibility of broadcasting - * other - input array - * target - where to store result - * checkTargetShape - if true check whether target shape is suitable for broadcasting - * extraParams - extra parameters for operation - */ - void applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; - - void applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; - - void applyTrueBroadcast(sd::BroadcastIntOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; - - - /** - * apply a scalar operation to an array - * scalar - input scalar - * target - where to store result - * extraParams - extra parameters for operation - */ - template - void applyScalar(sd::scalar::Ops op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr); - - template - void applyScalar(sd::scalar::BoolOps op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; - - template - void applyScalar(sd::scalar::IntOps op, const T scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; - - /** - * apply a scalar operation to an array - * scalar - input array which is simple scalar - * target - where to store result - * extraParams - extra parameters for operation - */ - void applyScalarArr(sd::scalar::Ops op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr); - - void applyScalarArr(sd::scalar::BoolOps op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; - - void applyScalarArr(sd::scalar::IntOps op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams = nullptr) const; - -#if defined(__CUDABLAS__) //&& defined(BUILD_TESTS) - template - FORCEINLINE void applyLambda(Lambda func, NDArray& target); - - template - FORCEINLINE void applyPairwiseLambda(const NDArray& other, Lambda func, NDArray& target); - - template - FORCEINLINE void applyIndexedLambda(Lambda func, NDArray& target); - - template - FORCEINLINE void applyIndexedPairwiseLambda(NDArray& other, Lambda func, NDArray& target); - - template - FORCEINLINE void applyTriplewiseLambda(NDArray& second, NDArray& third, Lambda func, NDArray& target); + /** + * This method explicitly enforces new shape for this NDArray, old + * shape/stride information is lost + */ + void enforce(const std::initializer_list& dimensions, + char order = 'a'); + void enforce(std::vector& dimensions, char order = 'a'); + + /** + * method reduces array by excluding its shapes along dimensions present in + * given dimensions vector, result is stored in new array to be returned + * dimensions - array of dimensions to reduce along + * keepDims - if true then put unities in place of reduced dimensions + */ + + NDArray reduceAlongDimension(sd::reduce::FloatOps op, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(sd::reduce::FloatOps op, + const std::initializer_list& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + + NDArray reduceAlongDimension(sd::reduce::SameOps op, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(sd::reduce::SameOps op, + const std::initializer_list& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + + NDArray reduceAlongDimension(sd::reduce::BoolOps op, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(sd::reduce::BoolOps op, + const std::initializer_list& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + + NDArray reduceAlongDimension(sd::reduce::LongOps op, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + NDArray reduceAlongDimension(sd::reduce::LongOps op, + const std::initializer_list& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false) const; + + /** + * method reduces array by excluding its shapes along dimensions present in + * given dimensions vector target - where to save result of reducing + * dimensions - array of dimensions to reduce along + * keepDims - if true then put unities in place of reduced dimensions + * extras - extra parameters + */ + void reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false, + const bool checkTargetShape = true) const; + void reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false, + const bool checkTargetShape = true) const; + void reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false, + const bool checkTargetShape = true) const; + void reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims = false, + const bool supportOldShapes = false, + const bool checkTargetShape = true) const; + + /** + * return variance of array elements set + * biasCorrected - if true bias correction will be applied + */ + NDArray varianceNumber(sd::variance::Ops op, bool biasCorrected = true); + + /** + * apply scalar operation to array + * extraParams - extra parameters for operation + * returns scalar array + */ + NDArray reduceNumber(sd::reduce::FloatOps ops, + void* extraParams = nullptr) const; + NDArray reduceNumber(sd::reduce::SameOps ops, + void* extraParams = nullptr) const; + NDArray reduceNumber(sd::reduce::BoolOps ops, + void* extraParams = nullptr) const; + NDArray reduceNumber(sd::reduce::LongOps ops, + void* extraParams = nullptr) const; + + void reduceNumber(sd::reduce::FloatOps ops, NDArray& target, + void* extraParams = nullptr) const; + void reduceNumber(sd::reduce::SameOps ops, NDArray& target, + void* extraParams = nullptr) const; + void reduceNumber(sd::reduce::BoolOps ops, NDArray& target, + void* extraParams = nullptr) const; + void reduceNumber(sd::reduce::LongOps ops, NDArray& target, + void* extraParams = nullptr) const; + + /** + * returns element index which corresponds to some condition imposed by + * operation extraParams - extra parameters for operation + */ + NDArray indexReduceNumber(sd::indexreduce::Ops op, + ExtraArguments* extraParams = nullptr); + + /** + * returns index of max element in a given array (optionally: along given + * dimension(s)) dimensions - optional vector with dimensions + */ + Nd4jLong argMax(std::initializer_list dimensions = {}); + + // FIXME: remove this method eventually + void makeBothActual() const { + syncToDevice(); + syncToHost(); + } + + void applyTransform(sd::transform::FloatOps op, NDArray& target, + ExtraArguments* extraParams = nullptr); + void applyTransform(sd::transform::SameOps op, NDArray& target, + ExtraArguments* extraParams = nullptr); + void applyTransform(sd::transform::AnyOps op, NDArray& target, + ExtraArguments* extraParams = nullptr); + void applyTransform(sd::transform::BoolOps op, NDArray& target, + ExtraArguments* extraParams = nullptr); + void applyTransform(sd::transform::StrictOps op, NDArray& target, + ExtraArguments* extraParams = nullptr); + + /** + * apply OpName transformation to this array and store result in new array to + * be returned extraParams - extra parameters for operation + */ + NDArray transform(sd::transform::FloatOps op, + void* extraParams = nullptr) const&; + NDArray transform(sd::transform::SameOps op, + void* extraParams = nullptr) const&; + NDArray transform(sd::transform::BoolOps op, + void* extraParams = nullptr) const&; + NDArray transform(sd::transform::StrictOps op, + void* extraParams = nullptr) const&; + NDArray transform(sd::transform::FloatOps op, void* extraParams = nullptr) &&; + NDArray transform(sd::transform::SameOps op, void* extraParams = nullptr) &&; + NDArray transform(sd::transform::BoolOps op, void* extraParams = nullptr) &&; + NDArray transform(sd::transform::StrictOps op, + void* extraParams = nullptr) &&; + + /** + * apply pairwise OpName transformation based on "this" and "other" arras + * elements, store result in this array other - second array necessary for + * pairwise operation extraParams - extra parameters for operation + */ + void applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, + ExtraArguments* extraParams = nullptr); + + /** + * apply pairwise OpName transformation based on "this" and "other" arras + * elements, store result in target array other - second array necessary for + * pairwise operation target - where to store result extraParams - extra + * parameters for operation + */ + void applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, + NDArray& target, + ExtraArguments* extraParams = nullptr) const; + + void applyPairwiseTransform(sd::pairwise::BoolOps op, const NDArray& other, + NDArray& target, + ExtraArguments* extraParams = nullptr) const; + + void applyPairwiseTransform(sd::pairwise::IntOps op, const NDArray& other, + NDArray& target, + ExtraArguments* extraParams = nullptr) const; + + /** + * apply operation which requires broadcasting, broadcast a smaller array + * (tad) along bigger one (this) tad - array to broadcast dimensions - + * dimensions array to broadcast along target - where to store result + * extraParams - extra parameters for operation + */ + void applyBroadcast(sd::broadcast::Ops op, + const std::initializer_list dimensions, + const NDArray& tad, NDArray& target, + ExtraArguments* extraArgs = nullptr); + + void applyBroadcast(sd::broadcast::Ops op, const std::vector& dimensions, + const NDArray& tad, NDArray& target, + ExtraArguments* extraArgs = nullptr); + + void applyBroadcast(sd::broadcast::BoolOps op, + const std::vector& dimensions, const NDArray& tad, + NDArray& target, ExtraArguments* extraArgs = nullptr); + + void applyBroadcast(sd::broadcast::IntOps op, + const std::vector& dimensions, const NDArray& tad, + NDArray& target, ExtraArguments* extraArgs = nullptr); + + /** + * apply operation which requires broadcasting, broadcast one tensor along + * another, also this method checks the possibility of broadcasting other - + * input array extraParams - extra parameters for operation + */ + NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, + ExtraArguments* extraArgs = nullptr) const&; + NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, + ExtraArguments* extraArgs = nullptr) const&; + NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, + ExtraArguments* extraArgs = nullptr) &&; + NDArray applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, + ExtraArguments* extraArgs = nullptr) &&; + + /** + * apply operation which requires broadcasting, broadcast one tensor along + * another, also this method checks the possibility of broadcasting other - + * input array target - where to store result checkTargetShape - if true check + * whether target shape is suitable for broadcasting extraParams - extra + * parameters for operation + */ + void applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, + NDArray& target, const bool checkTargetShape = true, + ExtraArguments* extraArgs = nullptr) const; + + void applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, const NDArray& other, + NDArray& target, const bool checkTargetShape = true, + ExtraArguments* extraArgs = nullptr) const; + + void applyTrueBroadcast(sd::BroadcastIntOpsTuple op, const NDArray& other, + NDArray& target, const bool checkTargetShape = true, + ExtraArguments* extraArgs = nullptr) const; + + /** + * apply a scalar operation to an array + * scalar - input scalar + * target - where to store result + * extraParams - extra parameters for operation + */ + template + void applyScalar(sd::scalar::Ops op, const T scalar, NDArray& target, + ExtraArguments* extraParams = nullptr); + + template + void applyScalar(sd::scalar::BoolOps op, const T scalar, NDArray& target, + ExtraArguments* extraParams = nullptr) const; + + template + void applyScalar(sd::scalar::IntOps op, const T scalar, NDArray& target, + ExtraArguments* extraParams = nullptr) const; + + /** + * apply a scalar operation to an array + * scalar - input array which is simple scalar + * target - where to store result + * extraParams - extra parameters for operation + */ + void applyScalarArr(sd::scalar::Ops op, const NDArray& scalar, + NDArray& target, ExtraArguments* extraParams = nullptr); + + void applyScalarArr(sd::scalar::BoolOps op, const NDArray& scalar, + NDArray& target, + ExtraArguments* extraParams = nullptr) const; + + void applyScalarArr(sd::scalar::IntOps op, const NDArray& scalar, + NDArray& target, + ExtraArguments* extraParams = nullptr) const; + +#if defined(__CUDABLAS__) //&& defined(BUILD_TESTS) + template + FORCEINLINE void applyLambda(Lambda func, NDArray& target); + + template + FORCEINLINE void applyPairwiseLambda(const NDArray& other, Lambda func, + NDArray& target); + + template + FORCEINLINE void applyIndexedLambda(Lambda func, NDArray& target); + + template + FORCEINLINE void applyIndexedPairwiseLambda(NDArray& other, Lambda func, + NDArray& target); + + template + FORCEINLINE void applyTriplewiseLambda(NDArray& second, NDArray& third, + Lambda func, NDArray& target); #else - /** - * apply operation "func" to an array - * func - what operation to apply - * target - where to store result - */ - template - void applyLambda(const std::function& func, NDArray& target); - - /** - * apply pairwise operation "func" to an array - * other - input array - * func - what pairwise operation to apply - * target - where to store result - */ - template - void applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); - - template - void applyIndexedLambda(const std::function& func, NDArray& target); - - template - void applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); - - template - void applyTriplewiseLambda(NDArray& second, NDArray& third, const std::function& func, NDArray& target); + /** + * apply operation "func" to an array + * func - what operation to apply + * target - where to store result + */ + template + void applyLambda(const std::function& func, NDArray& target); + + /** + * apply pairwise operation "func" to an array + * other - input array + * func - what pairwise operation to apply + * target - where to store result + */ + template + void applyPairwiseLambda(const NDArray& other, + const std::function& func, NDArray& target); + + template + void applyIndexedLambda(const std::function& func, + NDArray& target); + + template + void applyIndexedPairwiseLambda(NDArray& other, + const std::function& func, + NDArray& target); + + template + void applyTriplewiseLambda(NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); #endif - /** - * reduces dimensions in this array relying on index operation OpName - * dimensions - vector of dimensions to reduce along - * extraArgs - extra parameters for operation - */ - NDArray applyIndexReduce(sd::indexreduce::Ops op, const std::vector& dimensions, const ExtraArguments *extraParams = nullptr) const; - - /** - * reduces dimensions in array relying on index operation OpName - * target - where to store result - * dimensions - vector of dimensions to reduce along - * extraArgs - extra parameters for operation - */ - void applyIndexReduce(sd::indexreduce::Ops op, NDArray& target, const std::vector& dimensions, const ExtraArguments *extraParams = nullptr) const; - - /** - * apply reduce3 operation OpName to this and other array, return result in new output array - * other - input array - * extraArgs - extra parameters for operation - */ - NDArray applyReduce3(sd::reduce3::Ops op, const NDArray& other, const ExtraArguments* extraParams = nullptr) const; - - /** - * apply reduce3 operation OpName to this and other array, return result in new output array - * other - input array - * dimensions - vector of dimensions to reduce along (tads not axis) - * extraArgs - extra parameters for operation - */ - NDArray applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams = nullptr) const; - - /** - * apply reduce3 (exec) operation OpName to this and other array, return result in new output array - * other - input array - * dimensions - vector of dimensions to reduce along (same as reduceAlongDimension) - * extraArgs - extra parameters for operation - */ - NDArray applyReduce3(sd::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams = nullptr) const; - - /** - * returns variance along given dimensions - * biasCorrected - if true bias correction will be applied - * dimensions - vector of dimensions to calculate variance along - */ - NDArray varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const; - NDArray varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, const std::initializer_list& dimensions) const; - - void varianceAlongDimension(sd::variance::Ops op, NDArray& target, const bool biasCorrected, const std::vector& dimensions) const; - void varianceAlongDimension(sd::variance::Ops op, NDArray& target, const bool biasCorrected, const std::initializer_list& dimensions) const; + /** + * reduces dimensions in this array relying on index operation OpName + * dimensions - vector of dimensions to reduce along + * extraArgs - extra parameters for operation + */ + NDArray applyIndexReduce(sd::indexreduce::Ops op, + const std::vector& dimensions, + const ExtraArguments* extraParams = nullptr) const; + + /** + * reduces dimensions in array relying on index operation OpName + * target - where to store result + * dimensions - vector of dimensions to reduce along + * extraArgs - extra parameters for operation + */ + void applyIndexReduce(sd::indexreduce::Ops op, NDArray& target, + const std::vector& dimensions, + const ExtraArguments* extraParams = nullptr) const; + + /** + * apply reduce3 operation OpName to this and other array, return result in + * new output array other - input array extraArgs - extra parameters for + * operation + */ + NDArray applyReduce3(sd::reduce3::Ops op, const NDArray& other, + const ExtraArguments* extraParams = nullptr) const; + + /** + * apply reduce3 operation OpName to this and other array, return result in + * new output array other - input array dimensions - vector of dimensions to + * reduce along (tads not axis) extraArgs - extra parameters for operation + */ + NDArray applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, + const std::vector& dimensions, + const ExtraArguments* extraParams = nullptr) const; + + /** + * apply reduce3 (exec) operation OpName to this and other array, return + * result in new output array other - input array dimensions - vector of + * dimensions to reduce along (same as reduceAlongDimension) extraArgs - extra + * parameters for operation + */ + NDArray applyReduce3(sd::reduce3::Ops op, const NDArray& other, + const std::vector& dimensions, + const ExtraArguments* extraParams = nullptr) const; + + /** + * returns variance along given dimensions + * biasCorrected - if true bias correction will be applied + * dimensions - vector of dimensions to calculate variance along + */ + NDArray varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, + const std::vector& dimensions) const; + NDArray varianceAlongDimension( + sd::variance::Ops op, const bool biasCorrected, + const std::initializer_list& dimensions) const; + + void varianceAlongDimension(sd::variance::Ops op, NDArray& target, + const bool biasCorrected, + const std::vector& dimensions) const; + void varianceAlongDimension( + sd::variance::Ops op, NDArray& target, const bool biasCorrected, + const std::initializer_list& dimensions) const; #endif - /** - * apply transpose operation to the copy of this array, that is this array remains unaffected - */ - NDArray transpose() const &; - NDArray transpose() &&; - - /** - * perform transpose operation and store result in target, this array remains unaffected - * target - where to store result - */ - void transpose(NDArray& target) const; - - /** - * apply in-place transpose operation to this array, so this array becomes transposed - */ - void transposei(); - - /** - * returns the number of arrays pointing on specified dimension(s) - * dimensions - array of dimensions to point on - */ - Nd4jLong tensorsAlongDimension(const std::initializer_list dimensions) const ; - Nd4jLong tensorsAlongDimension(const std::vector& dimensions) const ; - - /** - * returns true if elements of two arrays are equal to within given epsilon value - * other - input array to compare - * eps - epsilon, this value defines the precision of elements comparison - */ - bool equalsTo(const NDArray *other, double eps = 1e-5) const; - bool equalsTo(const NDArray &other, double eps = 1e-5) const; - - /** - * add given row vector to all rows of this array - * row - row vector to add - */ - void addiRowVector(const NDArray& row); - - /** - * add given row vector to all rows of this array, store result in target - * row - row vector to add - * target - where to store result - */ - void addRowVector(const NDArray& row, NDArray& target) const; - - /** - * subtract given row vector from all rows of this array, store result in target - * row - row vector to subtract - * target - where to store result - */ - void subRowVector(const NDArray& row, NDArray& target) const; - - /** - * multiply all rows of this array on given row vector, store result in target - * row - row vector to multiply on - * target - where to store result - */ - void mulRowVector(const NDArray &row, NDArray& target) const; - - /** - * divide all rows of this array on given row vector, store result in target - * row - row vector to divide on - * target - where to store result - */ - void divRowVector(const NDArray &row, NDArray& target) const; - - /** - * add given column vector to all columns of this array, store result in target - * column - column vector to add - * target - where to store result - */ - void addColumnVector(const NDArray &column, NDArray& target) const; - - /** - * add given column vector to all columns of this array, this array becomes affected (in-place operation) - * column - column vector to add - */ - void addiColumnVector(const NDArray &column); - - /** - * multiply all columns of this array on given column vector, this array becomes affected (in-place operation) - * column - column vector to multiply on - */ - void muliColumnVector(const NDArray &column); - - /** - * returns number of bytes used by _buffer & _shapeInfo - */ - FORCEINLINE Nd4jLong memoryFootprint(); - - /** - * these methods suited for FlatBuffers use - */ - template - std::vector getBufferAsVector(); - std::vector getShapeAsVector() const; - std::vector getShapeAsVectorInt() const; - std::vector getShapeInfoAsVector(); - std::vector getShapeInfoAsFlatVector(); - std::vector getShapeAsFlatVector(); - - /** - * set new order and shape in case of suitable array length (in-place operation) - * order - order to set - * shape - shape to set - * copyToNewBuff - if true then old buffer will be copied to new buffer if last one will be allocated after reshaping - * if there was permute applied before or there are weird strides, then new buffer is allocated for array - */ - bool reshapei(const char order, const std::initializer_list& shape, const bool copyToNewBuff = true); - bool reshapei(const char order, const std::vector& shape, const bool copyToNewBuff = true); - - bool reshapei(const std::initializer_list& shape, const bool copyToNewBuff = true); - bool reshapei(const std::vector& shape, const bool copyToNewBuff = true); - - /** - * creates new array with corresponding order and shape, new array will point on _buffer of this array - * order - order to set - * shape - shape to set - * - * if permute have been applied before or there are weird strides, then new buffer is allocated for new array - */ - NDArray reshape(const char order, const std::vector& shape, const bool copyToNewBuff = true) const &; - NDArray reshape(const char order, const std::vector& shape, const bool copyToNewBuff = true) &&; - - /** - * calculate strides and set given order - * order - order to set - */ - void updateStrides(const char order); - - /** - * change an array by repeating it the number of times given by reps (in-place operation) - * repeats - contains numbers of repetitions - */ - void tilei(const std::vector& repeats); - - /** - * returns new array which is created by repeating of this array the number of times given by reps - * repeats - contains numbers of repetitions - */ - NDArray tile(const std::vector& repeats) const; - - /** - * change an array by repeating it the number of times given by reps (in-place operation) - * repeats - contains numbers of repetitions - * target - where to store result - */ - void tile(const std::vector& repeats, NDArray& target) const; - - /** - * change an array by repeating it the number of times to acquire the new shape which is the same as target shape - * target - where to store result - */ - void tile(NDArray& target) const; - - /** - * check whether array is identity matrix - */ - bool isIdentityMatrix(); - - /** - * check whether array is unitary matrix - */ - bool isUnitary(); - - /** - * operator returns subarray with buffer pointing at this->_buffer with offset defined by given intervals - * idx - intervals of indexes which define the subarrays to point on, idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * this->rankOf()) - * when (dimStart == dimEnd) then whole range will be used for current dimension - * keepUnitiesInShape - if false then eliminate unities from resulting array shape, for example {1,a,1,b} -> {a,b} - * isStrided - if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd, - * so structure of idx is like {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} - */ - NDArray operator()(const std::vector& idx, const bool keepUnitiesInShape = false, const bool isStrided = false) const; - - /** - * evaluates subarray with buffer pointing at this->_buffer and offset defined by given sequential index subArrIdx and dimensions in dimsToExclude - * subArrIdx - index of current sub-array - * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5], and subArrIdx must be in range [0,7] - * if dimsToExclude is empty then idxRanges containing all zeros (means whole array) will be returned. - * keepUnitiesInShape - if false then eliminate unities from resulting array shape, for example {1,a,1,b} -> {a,b} - */ - NDArray operator()(const Nd4jLong subArrIdx, const std::vector& dimsToExclude, bool keepUnitiesInShape = false) const; - - /** - * processes whole set of sub-arrays - * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) and their buffer offsets (each sub-array has its own unique offset from original this-buffer) - * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] - * if dimsToExclude.size() = array rank it means sub-array is whole array and copy of original_shapeInfo will be returned and one zero offset - * subArrShapeInfo - output argument, contains shapeInfo common for all sub-arrays - * subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer - * keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} - */ - void getSubArrShapeAndOffsets(const std::vector& dimsToExclude, Nd4jLong* &subArrShapeInfo, Nd4jLong* &subArrOffsets, bool keepUnitiesInShape = false) const; - - /** - * addition unary operator array += other - * other - input array to add - */ - void operator+=(const NDArray& other); - - /** - * subtraction unary operator array -= other - * other - input array to add - */ - void operator-=(const NDArray& other); - - template - void operator+=(const T other); - - template - void operator-=(const T other); - - /** - * negative operator, it changes sign of all array elements on opposite - */ - NDArray operator-() const &; - NDArray operator-() &&; - - /** - * pairwise multiplication unary operator array *= other - * other - input array to multiply on - */ - void operator*=(const NDArray& other); - - /** - * multiplication unary operator array *= scalar - * scalar - input scalar to multiply on - */ - template - void operator*=(const T scalar); - - /** - * pairwise division unary operator: array /= other - * other - input array to divide on - */ - void operator/=(const NDArray& other); - - /** - * division unary operator: array /= scalar - * scalar - input scalar to divide on - */ - template - void operator/=(const T scalar); - - /** - * friend function which implements mathematical multiplication of two arrays - * left - input array - * right - input array - */ - friend NDArray mmul(const NDArray& left, const NDArray& right); - - /** - * return vector containing _buffer as flat binary array - */ - std::vector asByteVector(); - - /** - * makes array to be identity matrix (not necessarily square), that is set all diagonal elements = 1, rest = 0 - */ - void setIdentity(); - - /** - * swaps the contents of tow arrays, - * PLEASE NOTE: method doesn't take into account the shapes of arrays, shapes may be different except one condition: arrays lengths must be the same - */ - void swapUnsafe(NDArray& other); - - /** - * return vector with buffer which points on corresponding diagonal elements of array - * type - means of vector to be returned: column ('c') or row ('r') - */ - NDArray diagonal(const char type ) const; - - /** - * fill target matrix with given value in one or two directions from main diagonal: - * - down from main diagonal starting at subdiagonal number "lower" if direction = 'd' (down) or 'b' (both) - * - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both) - * direction - in what direction to fill matrix. There are 3 possible directions: - * 'u' - fill up, mathematically this corresponds to lower triangular matrix, subdiagonal "lower" unaffected - * 'l' - fill down, mathematically this corresponds to upper triangular matrix, superdiagonal "upper" remains unaffected - * 'b' - fill in both directions, both "lower" and "upper" are taken into account - * rest of target elements are equal to this array elements - * target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2) - */ - template - void fillAsTriangular(const float value, int lower, int upper, NDArray& target, const char direction = 'b'); - - /** - * change an array by repeating it the number of times in order to acquire new shape equal to the input shape - * - * shape - contains new shape to broadcast array to - * target - optional argument, if target != nullptr the resulting array will be placed in target, in opposite case tile operation is done in place - */ - NDArray tileToShape(const Nd4jLong* shapeInfo); - void tileToShape(const std::vector& shape, NDArray& target); + /** + * apply transpose operation to the copy of this array, that is this array + * remains unaffected + */ + NDArray transpose() const&; + NDArray transpose() &&; + + /** + * perform transpose operation and store result in target, this array remains + * unaffected target - where to store result + */ + void transpose(NDArray& target) const; + + /** + * apply in-place transpose operation to this array, so this array becomes + * transposed + */ + void transposei(); + + /** + * returns the number of arrays pointing on specified dimension(s) + * dimensions - array of dimensions to point on + */ + Nd4jLong tensorsAlongDimension( + const std::initializer_list dimensions) const; + Nd4jLong tensorsAlongDimension(const std::vector& dimensions) const; + + /** + * returns true if elements of two arrays are equal to within given epsilon + * value other - input array to compare eps - epsilon, this value defines the + * precision of elements comparison + */ + bool equalsTo(const NDArray* other, double eps = 1e-5) const; + bool equalsTo(const NDArray& other, double eps = 1e-5) const; + + /** + * add given row vector to all rows of this array + * row - row vector to add + */ + void addiRowVector(const NDArray& row); + + /** + * add given row vector to all rows of this array, store result in target + * row - row vector to add + * target - where to store result + */ + void addRowVector(const NDArray& row, NDArray& target) const; + + /** + * subtract given row vector from all rows of this array, store result in + * target row - row vector to subtract target - where to store result + */ + void subRowVector(const NDArray& row, NDArray& target) const; + + /** + * multiply all rows of this array on given row vector, store result in + * target row - row vector to multiply on target - where to store result + */ + void mulRowVector(const NDArray& row, NDArray& target) const; + + /** + * divide all rows of this array on given row vector, store result in target + * row - row vector to divide on + * target - where to store result + */ + void divRowVector(const NDArray& row, NDArray& target) const; + + /** + * add given column vector to all columns of this array, store result in + * target column - column vector to add target - where to store result + */ + void addColumnVector(const NDArray& column, NDArray& target) const; + + /** + * add given column vector to all columns of this array, this array becomes + * affected (in-place operation) column - column vector to add + */ + void addiColumnVector(const NDArray& column); + + /** + * multiply all columns of this array on given column vector, this array + * becomes affected (in-place operation) column - column vector to multiply on + */ + void muliColumnVector(const NDArray& column); + + /** + * returns number of bytes used by _buffer & _shapeInfo + */ + FORCEINLINE Nd4jLong memoryFootprint(); + + /** + * these methods suited for FlatBuffers use + */ + template + std::vector getBufferAsVector(); + std::vector getShapeAsVector() const; + std::vector getShapeAsVectorInt() const; + std::vector getShapeInfoAsVector(); + std::vector getShapeInfoAsFlatVector(); + std::vector getShapeAsFlatVector(); + + /** + * set new order and shape in case of suitable array length (in-place + * operation) order - order to set shape - shape to set copyToNewBuff - if + * true then old buffer will be copied to new buffer if last one will be + * allocated after reshaping if there was permute applied before or there are + * weird strides, then new buffer is allocated for array + */ + bool reshapei(const char order, const std::initializer_list& shape, + const bool copyToNewBuff = true); + bool reshapei(const char order, const std::vector& shape, + const bool copyToNewBuff = true); + + bool reshapei(const std::initializer_list& shape, + const bool copyToNewBuff = true); + bool reshapei(const std::vector& shape, + const bool copyToNewBuff = true); + + /** + * creates new array with corresponding order and shape, new array will point + * on _buffer of this array order - order to set shape - shape to set + * + * if permute have been applied before or there are weird strides, then new + * buffer is allocated for new array + */ + NDArray reshape(const char order, const std::vector& shape, + const bool copyToNewBuff = true) const&; + NDArray reshape(const char order, const std::vector& shape, + const bool copyToNewBuff = true) &&; + + /** + * calculate strides and set given order + * order - order to set + */ + void updateStrides(const char order); + + /** + * change an array by repeating it the number of times given by reps + * (in-place operation) repeats - contains numbers of repetitions + */ + void tilei(const std::vector& repeats); + + /** + * returns new array which is created by repeating of this array the number + * of times given by reps repeats - contains numbers of repetitions + */ + NDArray tile(const std::vector& repeats) const; + + /** + * change an array by repeating it the number of times given by reps + * (in-place operation) repeats - contains numbers of repetitions target - + * where to store result + */ + void tile(const std::vector& repeats, NDArray& target) const; + + /** + * change an array by repeating it the number of times to acquire the new + * shape which is the same as target shape target - where to store result + */ + void tile(NDArray& target) const; + + /** + * check whether array is identity matrix + */ + bool isIdentityMatrix(); + + /** + * check whether array is unitary matrix + */ + bool isUnitary(); + + /** + * operator returns subarray with buffer pointing at this->_buffer with + * offset defined by given intervals idx - intervals of indexes which define + * the subarrays to point on, idx has form {dim0Start,dim0End, + * dim1Start,dim1End, ....} and length (2 * this->rankOf()) when (dimStart == + * dimEnd) then whole range will be used for current dimension + * keepUnitiesInShape - if false then eliminate unities from resulting array + * shape, for example {1,a,1,b} -> {a,b} isStrided - if true then idx has + * length (3 * this->rankOf()) and contains additional stride numbers which + * correspond to stride between dimStart and dimEnd, so structure of idx is + * like {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} + */ + NDArray operator()(const std::vector& idx, + const bool keepUnitiesInShape = false, + const bool isStrided = false) const; + + /** + * evaluates subarray with buffer pointing at this->_buffer and offset + * defined by given sequential index subArrIdx and dimensions in dimsToExclude + * subArrIdx - index of current sub-array + * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, + * i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 + * sub-arrays with shape [3,5], and subArrIdx must be in range [0,7] if + * dimsToExclude is empty then idxRanges containing all zeros (means whole + * array) will be returned. keepUnitiesInShape - if false then eliminate + * unities from resulting array shape, for example {1,a,1,b} -> {a,b} + */ + NDArray operator()(const Nd4jLong subArrIdx, + const std::vector& dimsToExclude, + bool keepUnitiesInShape = false) const; + + /** + * processes whole set of sub-arrays + * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) + * and their buffer offsets (each sub-array has its own unique offset from + * original this-buffer) dimsToExclude - MUST BE SORTED, dimensions to + * evaluate sub-array along, i.e. when shape is [2,3,4,5] and + * dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] if + * dimsToExclude.size() = array rank it means sub-array is whole array and + * copy of original_shapeInfo will be returned and one zero offset + * subArrShapeInfo - output argument, contains shapeInfo common for all + * sub-arrays subArrOffsets - output argument, contains successive + * sub-arrays offsets from original this-buffer keepUnitiesInShape - if false + * then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> + * {a,b} + */ + void getSubArrShapeAndOffsets(const std::vector& dimsToExclude, + Nd4jLong*& subArrShapeInfo, + Nd4jLong*& subArrOffsets, + bool keepUnitiesInShape = false) const; + + /** + * addition unary operator array += other + * other - input array to add + */ + void operator+=(const NDArray& other); + + /** + * subtraction unary operator array -= other + * other - input array to add + */ + void operator-=(const NDArray& other); + + template + void operator+=(const T other); + + template + void operator-=(const T other); + + /** + * negative operator, it changes sign of all array elements on opposite + */ + NDArray operator-() const&; + NDArray operator-() &&; + + /** + * pairwise multiplication unary operator array *= other + * other - input array to multiply on + */ + void operator*=(const NDArray& other); + + /** + * multiplication unary operator array *= scalar + * scalar - input scalar to multiply on + */ + template + void operator*=(const T scalar); + + /** + * pairwise division unary operator: array /= other + * other - input array to divide on + */ + void operator/=(const NDArray& other); + + /** + * division unary operator: array /= scalar + * scalar - input scalar to divide on + */ + template + void operator/=(const T scalar); + + /** + * friend function which implements mathematical multiplication of two arrays + * left - input array + * right - input array + */ + friend NDArray mmul(const NDArray& left, const NDArray& right); + + /** + * return vector containing _buffer as flat binary array + */ + std::vector asByteVector(); + + /** + * makes array to be identity matrix (not necessarily square), that is set + * all diagonal elements = 1, rest = 0 + */ + void setIdentity(); + + /** + * swaps the contents of tow arrays, + * PLEASE NOTE: method doesn't take into account the shapes of arrays, shapes + * may be different except one condition: arrays lengths must be the same + */ + void swapUnsafe(NDArray& other); + + /** + * return vector with buffer which points on corresponding diagonal elements + * of array type - means of vector to be returned: column ('c') or row ('r') + */ + NDArray diagonal(const char type) const; + + /** + * fill target matrix with given value in one or two directions from main + * diagonal: + * - down from main diagonal starting at subdiagonal number "lower" if + * direction = 'd' (down) or 'b' (both) + * - up from main diagonal starting at superdiagonal number "upper"if + * direction = 'u' (up) or 'b' (both) direction - in what direction to fill + * matrix. There are 3 possible directions: 'u' - fill up, mathematically this + * corresponds to lower triangular matrix, subdiagonal "lower" unaffected 'l' + * - fill down, mathematically this corresponds to upper triangular matrix, + * superdiagonal "upper" remains unaffected 'b' - fill in both directions, + * both "lower" and "upper" are taken into account rest of target elements are + * equal to this array elements target and this array should have same shapes, + * except when this_rank = 1 (in that case should be target_rank = 2) + */ + template + void fillAsTriangular(const float value, int lower, int upper, + NDArray& target, const char direction = 'b'); + + /** + * change an array by repeating it the number of times in order to acquire + * new shape equal to the input shape + * + * shape - contains new shape to broadcast array to + * target - optional argument, if target != nullptr the resulting array will + * be placed in target, in opposite case tile operation is done in place + */ + NDArray tileToShape(const Nd4jLong* shapeInfo); + void tileToShape(const std::vector& shape, NDArray& target); #ifndef __JAVACPP_HACK__ - void tileToShape(const std::initializer_list& shape, NDArray& target); + void tileToShape(const std::initializer_list& shape, + NDArray& target); #endif - template - NDArray asT() const; - - template - NDArray asS() const; - - NDArray asT(DataType dtype) const; - - - void linspace(const double start); - - void linspace(const double start, const double step); - - /** - * calculates the trace of an array, that is sum of elements on main diagonal = sum array[i, i, i, ...] - */ - double getTrace() const; - - ResultSet multipleTensorsAlongDimension(const std::vector& indices, const std::vector& dimensions) const; - - ResultSet allTensorsAlongDimension(const std::initializer_list& dimensions) const; - - ResultSet allTensorsAlongDimension(const std::vector& dimensions) const; - - ResultSet allExamples()const ; - - /** - * set _shapeInfo - */ - void setShapeInfo(const Nd4jLong *shapeInfo); - void setShapeInfo(const Nd4jLong *shapeInfo, const sd::DataType dtype); - void setShapeInfo(const ShapeDescriptor& descriptor); - void setShapeInfo(const ConstantDataBuffer& shapeBuffer); - - /** - * returns absolute offset which corresponds to given sequential index - */ - Nd4jLong getOffset(const Nd4jLong i) const; - - /** - * returns reference on array element with given index - */ - template - FORCEINLINE T& t(const Nd4jLong index); - - template - FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j); - template - FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k); - template - FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w); - - - /** - * returns array element with given index - * i - element index in array - */ - template - FORCEINLINE T t(const Nd4jLong i) const; - - template - FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j) const; - template - FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const; - template - FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const; - - - /** - * default destructor - */ - ~NDArray() noexcept = default; - - /** - * set _shapeInfo - */ - FORCEINLINE void setShapeInfo(Nd4jLong *shapeInfo); - FORCEINLINE void setShapeInfo(Nd4jLong *shapeInfo, const sd::DataType dtype); - - /** - * returns the value of "dim" dimension - */ - Nd4jLong sizeAt(const int dim) const; - - /** - * returns stride of "dim" dimension - */ - Nd4jLong strideAt(const int dim) const; - - /** - * returns order of array - */ - FORCEINLINE char ordering() const; - - /** - * return _isView - */ - FORCEINLINE bool isView() const; - - /** - * returns shape portion of shapeInfo - */ - FORCEINLINE Nd4jLong* shapeOf() const; - - /** - * returns strides portion of shapeInfo - */ - FORCEINLINE Nd4jLong* stridesOf() const; - - /** - * returns rank of array - */ - FORCEINLINE int rankOf() const; - - /** - * returns length of array - */ - FORCEINLINE Nd4jLong lengthOf() const; - - /** - * returns number of rows in array - */ - FORCEINLINE Nd4jLong rows() const; - - /** - * returns number of columns in array - */ - FORCEINLINE Nd4jLong columns() const; - - /** - * returns size of array elements type - */ - FORCEINLINE size_t sizeOfT() const; - - /** - * returns element-wise-stride - */ - FORCEINLINE Nd4jLong ews() const; - - // returns true if arrays have same shape - FORCEINLINE bool isSameShape(const NDArray *other) const; - FORCEINLINE bool isSameShape(const NDArray &other) const; - FORCEINLINE bool isSameShape(const std::initializer_list& shape) const; - FORCEINLINE bool isSameShape(const std::vector& shape) const; - FORCEINLINE bool areSameShapeAndType(const NDArray& other) const; - - /** - * returns true if these two NDArrays have same rank, dimensions, strides, ews and order - */ - FORCEINLINE bool isSameShapeStrict(const NDArray& other) const; - - /** - * returns true if buffer && shapeInfo were defined (non nullptr) - */ - FORCEINLINE bool nonNull() const; - - template - T r(const Nd4jLong i) const; - - /** - * returns array element with given index from linear buffer - * i - element index in array - */ - template - T e(const Nd4jLong i) const; - - /** - * returns element with given indexes from 2D array - * i - number of row - * j - number of column - */ - template - T e(const Nd4jLong i, const Nd4jLong j) const; - - /** - * returns element with given indexes from 3D array - * i - height - * j - width - * k - depth - */ - template - T e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const; - - /** - * returns element with given indexes from DD array - */ - template - T e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l) const; - - /** - * returns array-scalar containing element of this array with given index - * i - element index in array - */ - NDArray e(const Nd4jLong i) const; - - /** - * assigns given scalar to array element by given index, regards array buffer as linear - * i - element index in array - * value - scalar value to assign - */ - template - void p(const Nd4jLong i, const T value); - - void p(const Nd4jLong i, const NDArray& value); - - /** - * assigns given scalar to 2D array element by given indexes - * i - number of row - * j - number of row - * value - scalar value to assign - */ - template - void p(const Nd4jLong i, const Nd4jLong j, const T value); - - /** - * assigns given scalar to 3D array element by given indexes - * i - height - * j - width - * k - depth - * value - scalar value to assign - */ - template - void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T value); - - template - void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const T value); - void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, NDArray const& value); - - - template - void pIdx(const Nd4jLong* indices, const T value); - - /** - * returns true if array is 2D - */ - FORCEINLINE bool isMatrix() const; - - /** - * returns true if array is vector - */ - FORCEINLINE bool isVector() const; - - /** - * returns true if array is column vector - */ - FORCEINLINE bool isColumnVector() const; - - /** - * returns true if array is row vector - */ - FORCEINLINE bool isRowVector() const; - - /** - * returns true if all dimensions of array except one are unities, for example: [1,1,n,1], [n,1,1], [n], ... - * posOfNonUnityDim - one dimension with value > 1 - */ - FORCEINLINE bool isCommonVector(int& posOfNonUnityDim) const; - - - /** - * returns true if array is scalar - */ - FORCEINLINE bool isScalar() const; - - /** - * Returns data type of this array - * @return - */ - FORCEINLINE DataType dataType() const; - - /** - * This method returns true if value is from Integer space - * @return - */ - bool isZ() const; - - /** - * This method returns true if array is from Real space - * @return - */ - bool isR() const; - - /** - * This method returns true if array is from Boolean space - * @return - */ - bool isB() const; - - /** - * This method returns true if array contains Complex numbers - * @return - */ - bool isC() const; - - /** - * This method returns true if array contains String - * @return - */ - bool isS() const; - - template - std::vector asVectorT(); - - FORCEINLINE bool isAttached(); - - NDArray detach(); - - /** - * This method returns true if array is valid array with some shape etc - * @return - */ - bool defined() const; - bool undefined() const; - - FORCEINLINE bool operator==(const NDArray &other) const; - - FORCEINLINE bool operator!=(const NDArray &other) const; - }; - - - + template + NDArray asT() const; + + template + NDArray asS() const; + + NDArray asT(DataType dtype) const; + + void linspace(const double start); + + void linspace(const double start, const double step); + + /** + * calculates the trace of an array, that is sum of elements on main diagonal + * = sum array[i, i, i, ...] + */ + double getTrace() const; + + ResultSet multipleTensorsAlongDimension( + const std::vector& indices, + const std::vector& dimensions) const; + + ResultSet allTensorsAlongDimension( + const std::initializer_list& dimensions) const; + + ResultSet allTensorsAlongDimension(const std::vector& dimensions) const; + + ResultSet allExamples() const; + + /** + * set _shapeInfo + */ + void setShapeInfo(const Nd4jLong* shapeInfo); + void setShapeInfo(const Nd4jLong* shapeInfo, const sd::DataType dtype); + void setShapeInfo(const ShapeDescriptor& descriptor); + void setShapeInfo(const ConstantDataBuffer& shapeBuffer); + + /** + * returns absolute offset which corresponds to given sequential index + */ + Nd4jLong getOffset(const Nd4jLong i) const; + + /** + * returns reference on array element with given index + */ + template + FORCEINLINE T& t(const Nd4jLong index); + + template + FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j); + template + FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k); + template + FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong w); + + /** + * returns array element with given index + * i - element index in array + */ + template + FORCEINLINE T t(const Nd4jLong i) const; + + template + FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j) const; + template + FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const; + template + FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong w) const; + + /** + * default destructor + */ + ~NDArray() noexcept = default; + + /** + * set _shapeInfo + */ + FORCEINLINE void setShapeInfo(Nd4jLong* shapeInfo); + FORCEINLINE void setShapeInfo(Nd4jLong* shapeInfo, const sd::DataType dtype); + + /** + * returns the value of "dim" dimension + */ + Nd4jLong sizeAt(const int dim) const; + + /** + * returns stride of "dim" dimension + */ + Nd4jLong strideAt(const int dim) const; + + /** + * returns order of array + */ + FORCEINLINE char ordering() const; + + /** + * return _isView + */ + FORCEINLINE bool isView() const; + + /** + * returns shape portion of shapeInfo + */ + FORCEINLINE Nd4jLong* shapeOf() const; + + /** + * returns strides portion of shapeInfo + */ + FORCEINLINE Nd4jLong* stridesOf() const; + + /** + * returns rank of array + */ + FORCEINLINE int rankOf() const; + + /** + * returns length of array + */ + FORCEINLINE Nd4jLong lengthOf() const; + + /** + * returns number of rows in array + */ + FORCEINLINE Nd4jLong rows() const; + + /** + * returns number of columns in array + */ + FORCEINLINE Nd4jLong columns() const; + + /** + * returns size of array elements type + */ + FORCEINLINE size_t sizeOfT() const; + + /** + * returns element-wise-stride + */ + FORCEINLINE Nd4jLong ews() const; + + // returns true if arrays have same shape + FORCEINLINE bool isSameShape(const NDArray* other) const; + FORCEINLINE bool isSameShape(const NDArray& other) const; + FORCEINLINE bool isSameShape( + const std::initializer_list& shape) const; + FORCEINLINE bool isSameShape(const std::vector& shape) const; + FORCEINLINE bool areSameShapeAndType(const NDArray& other) const; + + /** + * returns true if these two NDArrays have same rank, dimensions, strides, + * ews and order + */ + FORCEINLINE bool isSameShapeStrict(const NDArray& other) const; + + /** + * returns true if buffer && shapeInfo were defined (non nullptr) + */ + FORCEINLINE bool nonNull() const; + + template + T r(const Nd4jLong i) const; + + /** + * returns array element with given index from linear buffer + * i - element index in array + */ + template + T e(const Nd4jLong i) const; + + /** + * returns element with given indexes from 2D array + * i - number of row + * j - number of column + */ + template + T e(const Nd4jLong i, const Nd4jLong j) const; + + /** + * returns element with given indexes from 3D array + * i - height + * j - width + * k - depth + */ + template + T e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const; + + /** + * returns element with given indexes from DD array + */ + template + T e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong l) const; + + /** + * returns array-scalar containing element of this array with given index + * i - element index in array + */ + NDArray e(const Nd4jLong i) const; + + /** + * assigns given scalar to array element by given index, regards array buffer + * as linear i - element index in array value - scalar value to assign + */ + template + void p(const Nd4jLong i, const T value); + + void p(const Nd4jLong i, const NDArray& value); + + /** + * assigns given scalar to 2D array element by given indexes + * i - number of row + * j - number of row + * value - scalar value to assign + */ + template + void p(const Nd4jLong i, const Nd4jLong j, const T value); + + /** + * assigns given scalar to 3D array element by given indexes + * i - height + * j - width + * k - depth + * value - scalar value to assign + */ + template + void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T value); + + template + void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, + const T value); + void p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, + NDArray const& value); + + template + void pIdx(const Nd4jLong* indices, const T value); + + /** + * returns true if array is 2D + */ + FORCEINLINE bool isMatrix() const; + + /** + * returns true if array is vector + */ + FORCEINLINE bool isVector() const; + + /** + * returns true if array is column vector + */ + FORCEINLINE bool isColumnVector() const; + + /** + * returns true if array is row vector + */ + FORCEINLINE bool isRowVector() const; + + /** + * returns true if all dimensions of array except one are unities, for + * example: [1,1,n,1], [n,1,1], [n], ... posOfNonUnityDim - one dimension with + * value > 1 + */ + FORCEINLINE bool isCommonVector(int& posOfNonUnityDim) const; + + /** + * returns true if array is scalar + */ + FORCEINLINE bool isScalar() const; + + /** + * Returns data type of this array + * @return + */ + FORCEINLINE DataType dataType() const; + + /** + * This method returns true if value is from Integer space + * @return + */ + bool isZ() const; + + /** + * This method returns true if array is from Real space + * @return + */ + bool isR() const; + + /** + * This method returns true if array is from Boolean space + * @return + */ + bool isB() const; + + /** + * This method returns true if array contains Complex numbers + * @return + */ + bool isC() const; + + /** + * This method returns true if array contains String + * @return + */ + bool isS() const; + + template + std::vector asVectorT(); + + FORCEINLINE bool isAttached(); + + NDArray detach(); + + /** + * This method returns true if array is valid array with some shape etc + * @return + */ + bool defined() const; + bool undefined() const; + + FORCEINLINE bool operator==(const NDArray& other) const; + + FORCEINLINE bool operator!=(const NDArray& other) const; +}; ////////////////////////////////////////////////////////////////////////// ///// IMLEMENTATION OF INLINE METHODS ///// ////////////////////////////////////////////////////////////////////////// -bool NDArray::isAttached() { - return this->_context->getWorkspace() != nullptr; -} +bool NDArray::isAttached() { return this->_context->getWorkspace() != nullptr; } template FORCEINLINE R NDArray::templatedGet(void const* buffer, Nd4jLong index) const { - auto b = reinterpret_cast(buffer); - auto v = static_cast(b[index]); - return v; + auto b = reinterpret_cast(buffer); + auto v = static_cast(b[index]); + return v; } ////////////////////////////////////////////////////////////////////////// -void NDArray::setShapeInfo(Nd4jLong *shapeInfo) { - auto buffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo); - _shapeInfo = buffer.primaryAsT(); - _shapeInfoD = buffer.specialAsT(); - - if (shapeInfo != nullptr) { - _dataType = ArrayOptions::dataType(_shapeInfo); - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); - } - else { - _dataType = sd::DataType::INHERIT; - _length = 0; - } +void NDArray::setShapeInfo(Nd4jLong* shapeInfo) { + auto buffer = + ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo); + _shapeInfo = buffer.primaryAsT(); + _shapeInfoD = buffer.specialAsT(); + + if (shapeInfo != nullptr) { + _dataType = ArrayOptions::dataType(_shapeInfo); + if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) + _length = 0; + else + _length = shape::length(_shapeInfo); + } else { + _dataType = sd::DataType::INHERIT; + _length = 0; + } } ////////////////////////////////////////////////////////////////////////// -void NDArray::setShapeInfo(Nd4jLong *shapeInfo, const sd::DataType dtype) { - auto buffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo); - _shapeInfo = buffer.primaryAsT(); - _shapeInfoD = buffer.specialAsT(); - - if (shapeInfo != nullptr) { - _dataType = dtype; - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); - } - else { - _dataType = sd::DataType::INHERIT; - _length = 0; - } +void NDArray::setShapeInfo(Nd4jLong* shapeInfo, const sd::DataType dtype) { + auto buffer = + ConstantShapeHelper::getInstance()->bufferForShapeInfo(shapeInfo); + _shapeInfo = buffer.primaryAsT(); + _shapeInfoD = buffer.specialAsT(); + + if (shapeInfo != nullptr) { + _dataType = dtype; + if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) + _length = 0; + else + _length = shape::length(_shapeInfo); + } else { + _dataType = sd::DataType::INHERIT; + _length = 0; + } } ////////////////////////////////////////////////////////////////////////// -char NDArray::ordering() const { - return shape::order(_shapeInfo); -} +char NDArray::ordering() const { return shape::order(_shapeInfo); } ////////////////////////////////////////////////////////////////////////// -bool NDArray::isView() const { - return _isView; -} +bool NDArray::isView() const { return _isView; } ////////////////////////////////////////////////////////////////////////// -Nd4jLong* NDArray::shapeOf() const { - return shape::shapeOf(_shapeInfo); -} +Nd4jLong* NDArray::shapeOf() const { return shape::shapeOf(_shapeInfo); } ////////////////////////////////////////////////////////////////////////// -Nd4jLong* NDArray::stridesOf() const { - return shape::stride(_shapeInfo); -} +Nd4jLong* NDArray::stridesOf() const { return shape::stride(_shapeInfo); } ////////////////////////////////////////////////////////////////////////// -int NDArray::rankOf() const { - return shape::rank(_shapeInfo); -} +int NDArray::rankOf() const { return shape::rank(_shapeInfo); } ////////////////////////////////////////////////////////////////////////// -Nd4jLong NDArray::lengthOf() const { - return _length; -} +Nd4jLong NDArray::lengthOf() const { return _length; } ////////////////////////////////////////////////////////////////////////// Nd4jLong NDArray::rows() const { - if (this->rankOf() == 1) - return 1; + if (this->rankOf() == 1) return 1; - if (this->rankOf() > 2) - throw std::runtime_error("Array with rank > 2 can't have rows"); + if (this->rankOf() > 2) + throw std::runtime_error("Array with rank > 2 can't have rows"); - return shapeOf()[0]; + return shapeOf()[0]; } ////////////////////////////////////////////////////////////////////////// Nd4jLong NDArray::columns() const { - if (this->rankOf() == 1) - return this->lengthOf(); + if (this->rankOf() == 1) return this->lengthOf(); - if (this->rankOf() > 2) - throw std::runtime_error("Array with rank > 2 can't have columns"); + if (this->rankOf() > 2) + throw std::runtime_error("Array with rank > 2 can't have columns"); - return shapeOf()[1]; + return shapeOf()[1]; } ////////////////////////////////////////////////////////////////////////// size_t NDArray::sizeOfT() const { - return DataTypeUtils::sizeOfElement(_dataType); + return DataTypeUtils::sizeOfElement(_dataType); } ////////////////////////////////////////////////////////////////////////// Nd4jLong NDArray::ews() const { - if (this->isEmpty() || this->rankOf() == 0) - return 1; + if (this->isEmpty() || this->rankOf() == 0) return 1; - return shape::elementWiseStride(_shapeInfo); + return shape::elementWiseStride(_shapeInfo); } ////////////////////////////////////////////////////////////////////////// bool NDArray::nonNull() const { - if (isEmpty()) - return true; + if (isEmpty()) return true; - if(!Environment::getInstance()->isCPU()) - return getDataBuffer()->special() != nullptr && specialShapeInfo() != nullptr; + if (!Environment::getInstance()->isCPU()) + return getDataBuffer()->special() != nullptr && + specialShapeInfo() != nullptr; - return getDataBuffer()->primary() != nullptr && shapeInfo() != nullptr; + return getDataBuffer()->primary() != nullptr && shapeInfo() != nullptr; } ////////////////////////////////////////////////////////////////////////// bool NDArray::isMatrix() const { - if (isEmpty()) - return false; + if (isEmpty()) return false; - return 0 != shape::isMatrix(this->_shapeInfo); + return 0 != shape::isMatrix(this->_shapeInfo); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isVector() const { - if (isEmpty()) - return false; - if (rankOf() == 1) - return true; - return !isScalar() && shape::isVector(this->_shapeInfo); + if (isEmpty()) return false; + if (rankOf() == 1) return true; + return !isScalar() && shape::isVector(this->_shapeInfo); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isColumnVector() const { - if (isEmpty()) - return false; + if (isEmpty()) return false; - return !isScalar() && shape::isColumnVector(this->_shapeInfo); + return !isScalar() && shape::isColumnVector(this->_shapeInfo); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isRowVector() const { - if (isEmpty()) - return false; + if (isEmpty()) return false; - // 1D edge case - if (shape::rank(this->_shapeInfo) == 1) - return true; + // 1D edge case + if (shape::rank(this->_shapeInfo) == 1) return true; - return !isScalar() && shape::isRowVector(this->_shapeInfo); + return !isScalar() && shape::isRowVector(this->_shapeInfo); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isCommonVector(int& posOfNonUnityDim) const { - - return shape::isCommonVector(_shapeInfo, posOfNonUnityDim); + return shape::isCommonVector(_shapeInfo, posOfNonUnityDim); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isScalar() const { - return 0 != shape::isScalar(this->_shapeInfo); + return 0 != shape::isScalar(this->_shapeInfo); } - ////////////////////////////////////////////////////////////////////////// Nd4jLong FORCEINLINE NDArray::memoryFootprint() { - Nd4jLong size = this->lengthOf() * this->sizeOfT(); - size += shape::shapeInfoByteLength(this->rankOf()); - return size; + Nd4jLong size = this->lengthOf() * this->sizeOfT(); + size += shape::shapeInfoByteLength(this->rankOf()); + return size; } ////////////////////////////////////////////////////////////////////////// // still the definition of inline function must be in header file -bool NDArray::isSameShape(const std::vector& shape) const{ - if (this->isScalar() && shape.size() == 1 && shape[0] == 0) - return true; - if (this->rankOf() != (int) shape.size()) - return false; - for (int e = 0; e < this->rankOf(); e++) { - if (this->shapeOf()[e] != shape[e] && shape[e] != -1) - return false; - } - return true; +bool NDArray::isSameShape(const std::vector& shape) const { + if (this->isScalar() && shape.size() == 1 && shape[0] == 0) return true; + if (this->rankOf() != (int)shape.size()) return false; + for (int e = 0; e < this->rankOf(); e++) { + if (this->shapeOf()[e] != shape[e] && shape[e] != -1) return false; + } + return true; } ////////////////////////////////////////////////////////////////////////// -bool NDArray::isSameShape(const NDArray *other) const { - if (this->isEmpty() != other->isEmpty()) - return false; +bool NDArray::isSameShape(const NDArray* other) const { + if (this->isEmpty() != other->isEmpty()) return false; - return isSameShape(std::vector(other->_shapeInfo+1, other->_shapeInfo+1+other->_shapeInfo[0])); + return isSameShape(std::vector( + other->_shapeInfo + 1, other->_shapeInfo + 1 + other->_shapeInfo[0])); } ////////////////////////////////////////////////////////////////////////// -bool NDArray::isSameShape(const NDArray &other) const { - return isSameShape(&other); +bool NDArray::isSameShape(const NDArray& other) const { + return isSameShape(&other); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isSameShape(const std::initializer_list& other) const { - return isSameShape(std::vector(other)); + return isSameShape(std::vector(other)); } ////////////////////////////////////////////////////////////////////////// bool NDArray::areSameShapeAndType(const NDArray& other) const { + if (rankOf() != other.rankOf() || _dataType != other._dataType) return false; - if(rankOf() != other.rankOf() || _dataType != other._dataType) - return false; + for (int i = 0; i < rankOf(); ++i) + if (sizeAt(i) != other.sizeAt(i)) return false; - for(int i = 0; i < rankOf(); ++i) - if(sizeAt(i) != other.sizeAt(i)) - return false; - - return true; + return true; } ////////////////////////////////////////////////////////////////////////// @@ -1744,227 +1981,228 @@ bool NDArray::areSameShapeAndType(const NDArray& other) const { // still the definition of inline function must be in header file bool NDArray::isSameShapeStrict(const NDArray& other) const { - return shape::equalsStrict(_shapeInfo, other._shapeInfo); + return shape::equalsStrict(_shapeInfo, other._shapeInfo); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isEmpty() const { - if (this->_shapeInfo == nullptr) - return false; + if (this->_shapeInfo == nullptr) return false; - return ArrayOptions::arrayType(this->shapeInfo()) == ArrayType::EMPTY; + return ArrayOptions::arrayType(this->shapeInfo()) == ArrayType::EMPTY; } ////////////////////////////////////////////////////////////////////////// -bool NDArray::operator==(const NDArray &other) const { - // if (this->dataType() != other.dataType()) // this comparison is already present in equalsTo - // return false; +bool NDArray::operator==(const NDArray& other) const { + // if (this->dataType() != other.dataType()) // this comparison is already + // present in equalsTo + // return false; - if (!this->isSameShape(&other)) - return false; + if (!this->isSameShape(&other)) return false; - return this->equalsTo(&other); + return this->equalsTo(&other); } ////////////////////////////////////////////////////////////////////////// -bool NDArray::operator!=(const NDArray &other) const { - if (this->dataType() != other.dataType()) - return true; +bool NDArray::operator!=(const NDArray& other) const { + if (this->dataType() != other.dataType()) return true; - if (!this->isSameShape(&other)) - return true; + if (!this->isSameShape(&other)) return true; - return !this->equalsTo(&other); + return !this->equalsTo(&other); } ////////////////////////////////////////////////////////////////////////// DataType NDArray::dataType() const { - return _dataType; - // return ArrayOptions::dataType(_shapeInfo); + return _dataType; + // return ArrayOptions::dataType(_shapeInfo); } //////////////////////////////////////////////////////////////////////// template T& NDArray::t(const Nd4jLong i) { + // if (i >= _length) + // throw std::invalid_argument("NDArray::t(i): input index is out of array + // length !"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i): type of array is not equal to template type T!"); - // if (i >= _length) - // throw std::invalid_argument("NDArray::t(i): input index is out of array length !"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i): type of array is not equal to template type T!"); - - if(!isActualOnHostSide()) - syncToHost(); + if (!isActualOnHostSide()) syncToHost(); - tickWriteHost(); - return *(reinterpret_cast(bufferWithOffset(getOffset(i)))); + tickWriteHost(); + return *(reinterpret_cast(bufferWithOffset(getOffset(i)))); } //////////////////////////////////////////////////////////////////////// template T& NDArray::t(const Nd4jLong i, const Nd4jLong j) { - - if (rankOf() != 2 || i >= sizeAt(0) || j >= sizeAt(1)) - throw std::invalid_argument("NDArray::t(i,j): one of input indexes is out of array length or rank!=2 !"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i,j): type of array is not equal to template type T!"); - - if(!isActualOnHostSide()) - syncToHost(); - - Nd4jLong coords[2] = {i, j}; - auto offset = shape::getOffset(shapeInfo(), coords); - tickWriteHost(); - return *(reinterpret_cast(bufferWithOffset(offset))); + if (rankOf() != 2 || i >= sizeAt(0) || j >= sizeAt(1)) + throw std::invalid_argument( + "NDArray::t(i,j): one of input indexes is out of array length or " + "rank!=2 !"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i,j): type of array is not equal to template type T!"); + + if (!isActualOnHostSide()) syncToHost(); + + Nd4jLong coords[2] = {i, j}; + auto offset = shape::getOffset(shapeInfo(), coords); + tickWriteHost(); + return *(reinterpret_cast(bufferWithOffset(offset))); } template T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) { - - if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2)) - throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!"); - - if(!isActualOnHostSide()) - syncToHost(); - - Nd4jLong coords[3] = {i, j, k}; - auto offset = shape::getOffset(shapeInfo(), coords); - tickWriteHost(); - return *(reinterpret_cast(bufferWithOffset(offset))); + if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2)) + throw std::invalid_argument( + "NDArray::t(i,j,k): one of input indexes is out of array length or " + "rank!=3!"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i,j,k): type of array is not equal to template type T!"); + + if (!isActualOnHostSide()) syncToHost(); + + Nd4jLong coords[3] = {i, j, k}; + auto offset = shape::getOffset(shapeInfo(), coords); + tickWriteHost(); + return *(reinterpret_cast(bufferWithOffset(offset))); } template -T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) { - - if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || w >= sizeAt(3)) - throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4 !"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!"); - - if(!isActualOnHostSide()) - syncToHost(); - - Nd4jLong coords[4] = {i, j, k, w}; - auto offset = shape::getOffset(shapeInfo(), coords); - tickWriteHost(); - return *(reinterpret_cast(bufferWithOffset(offset))); +T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong w) { + if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || + w >= sizeAt(3)) + throw std::invalid_argument( + "NDArray::t(i,j,k,w): one of input indexes is out of array length or " + "rank!=4 !"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i,j,k,w): type of array is not equal to template type T!"); + + if (!isActualOnHostSide()) syncToHost(); + + Nd4jLong coords[4] = {i, j, k, w}; + auto offset = shape::getOffset(shapeInfo(), coords); + tickWriteHost(); + return *(reinterpret_cast(bufferWithOffset(offset))); } //////////////////////////////////////////////////////////////////////// template T NDArray::t(const Nd4jLong i) const { + // if (i >= _length) + // throw std::invalid_argument("NDArray::t(i): input index is out of array + // length !"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i): type of array is not equal to template type T!"); - // if (i >= _length) - // throw std::invalid_argument("NDArray::t(i): input index is out of array length !"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i): type of array is not equal to template type T!"); - - if(!isActualOnHostSide()) - syncToHost(); + if (!isActualOnHostSide()) syncToHost(); - tickReadHost(); - return *(reinterpret_cast(bufferWithOffset(getOffset(i)))); + tickReadHost(); + return *(reinterpret_cast(bufferWithOffset(getOffset(i)))); } //////////////////////////////////////////////////////////////////////// template T NDArray::t(const Nd4jLong i, const Nd4jLong j) const { - - if (rankOf() != 2 || i >= sizeAt(0) || j >= sizeAt(1)) - throw std::invalid_argument("NDArray::t(i,j): one of input indexes is out of array length or rank!=2 !"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i,j): type of array is not equal to template type T!"); - - if(!isActualOnHostSide()) - syncToHost(); - - Nd4jLong coords[2] = {i, j}; - auto offset = shape::getOffset(shapeInfo(), coords); - tickReadHost(); - return *(reinterpret_cast(bufferWithOffset(offset))); + if (rankOf() != 2 || i >= sizeAt(0) || j >= sizeAt(1)) + throw std::invalid_argument( + "NDArray::t(i,j): one of input indexes is out of array length or " + "rank!=2 !"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i,j): type of array is not equal to template type T!"); + + if (!isActualOnHostSide()) syncToHost(); + + Nd4jLong coords[2] = {i, j}; + auto offset = shape::getOffset(shapeInfo(), coords); + tickReadHost(); + return *(reinterpret_cast(bufferWithOffset(offset))); } - template - T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { - - if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2)) - throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!"); - - if(!isActualOnHostSide()) - syncToHost(); - - Nd4jLong coords[3] = {i, j, k}; - auto offset = shape::getOffset(shapeInfo(), coords); - tickReadHost(); - return *(reinterpret_cast(bufferWithOffset(offset))); - } - - template - T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const { - - if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || w >= sizeAt(3)) - throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4!"); - if (DataTypeUtils::fromT() != _dataType) - throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!"); - - if(!isActualOnHostSide()) - syncToHost(); +template +T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { + if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2)) + throw std::invalid_argument( + "NDArray::t(i,j,k): one of input indexes is out of array length or " + "rank!=3!"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i,j,k): type of array is not equal to template type T!"); + + if (!isActualOnHostSide()) syncToHost(); + + Nd4jLong coords[3] = {i, j, k}; + auto offset = shape::getOffset(shapeInfo(), coords); + tickReadHost(); + return *(reinterpret_cast(bufferWithOffset(offset))); +} - Nd4jLong coords[4] = {i, j, k, w}; - auto offset = shape::getOffset(shapeInfo(), coords); - tickReadHost(); - return *(reinterpret_cast(bufferWithOffset(offset))); - } +template +T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong w) const { + if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || + w >= sizeAt(3)) + throw std::invalid_argument( + "NDArray::t(i,j,k,w): one of input indexes is out of array length or " + "rank!=4!"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument( + "NDArray::t(i,j,k,w): type of array is not equal to template type T!"); + + if (!isActualOnHostSide()) syncToHost(); + + Nd4jLong coords[4] = {i, j, k, w}; + auto offset = shape::getOffset(shapeInfo(), coords); + tickReadHost(); + return *(reinterpret_cast(bufferWithOffset(offset))); +} #ifndef __JAVACPP_HACK__ //////////////////////////////////////////////////////////////////////// -std::shared_ptr NDArray::getDataBuffer() const { - return _buffer; -} +std::shared_ptr NDArray::getDataBuffer() const { return _buffer; } //////////////////////////////////////////////////////////////////////// -std::shared_ptr NDArray::dataBuffer() { - return _buffer; -} +std::shared_ptr NDArray::dataBuffer() { return _buffer; } #endif //////////////////////////////////////////////////////////////////////// const void* NDArray::buffer() const { - return _buffer->primary() != nullptr ? static_cast(_buffer->primary()) + (_offset * sizeOfT()) : nullptr; + return _buffer->primary() != nullptr + ? static_cast(_buffer->primary()) + (_offset * sizeOfT()) + : nullptr; } ////////////////////////////////////////////////////////////////////////// void* NDArray::buffer() { - return _buffer->primary() != nullptr ? static_cast(_buffer->primary()) + (_offset * sizeOfT()) : nullptr; + return _buffer->primary() != nullptr + ? static_cast(_buffer->primary()) + (_offset * sizeOfT()) + : nullptr; } ////////////////////////////////////////////////////////////////////////// -const Nd4jLong* NDArray::shapeInfo() const { - return _shapeInfo; -} +const Nd4jLong* NDArray::shapeInfo() const { return _shapeInfo; } //////////////////////////////////////////////////////////////////////// const Nd4jLong* NDArray::specialShapeInfo() const { - if (_shapeInfoD == nullptr) - return _shapeInfo; - // FIXME: this should be fixed once CUDA backend added - return _shapeInfoD; + if (_shapeInfoD == nullptr) return _shapeInfo; + // FIXME: this should be fixed once CUDA backend added + return _shapeInfoD; } //////////////////////////////////////////////////////////////////////// -Nd4jLong NDArray::bufferOffset() const { - return _offset; -} - +Nd4jLong NDArray::bufferOffset() const { return _offset; } -#if defined(__CUDACC__) //&& defined(BUILD_TESTS) +#if defined(__CUDACC__) //&& defined(BUILD_TESTS) // for CUDA we need stil stuff inline #include #endif -} +} // namespace sd #endif diff --git a/libnd4j/include/array/NDArrayFactory.h b/libnd4j/include/array/NDArrayFactory.h index d73fba2bf9f2..46688ebff078 100644 --- a/libnd4j/include/array/NDArrayFactory.h +++ b/libnd4j/include/array/NDArrayFactory.h @@ -23,171 +23,341 @@ #ifndef SD_NDARRAYFACTORY_H #define SD_NDARRAYFACTORY_H -#include -#include #include + +#include +#include //#include #include -#include +#include namespace sd { - class SD_EXPORT NDArrayFactory { - private: - template - static void memcpyFromVector(void *ptr, const std::vector &vector); - public: - static NDArray undefined(); - - template - static NDArray* empty_(sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - static NDArray* empty_(sd::DataType dataType, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray empty(sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - static NDArray empty(sd::DataType dataType, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray* valueOf(const std::initializer_list& shape, T value, char order = 'c', sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray* valueOf(const std::vector& shape, T value, char order = 'c', sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - static NDArray* valueOf(const std::vector& shape, const NDArray& value, char order = 'c', sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray* linspace(T from, T to, Nd4jLong numElements); - - - template - static NDArray* create_(const T value, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray* create_(sd::DataType dtype, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray create(const T value, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray create(sd::DataType dtype, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - template - static NDArray create(DataType type, const T scalar, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - - template - static NDArray* vector(Nd4jLong length, T startingValue = (T) 0, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray* create_(char order, const std::vector &shape, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - static NDArray* create_( char order, const std::vector &shape, sd::DataType dataType, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray* create_(char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray create(char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray create(char order, const std::vector &shape, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray create(char order, const std::vector &shape, sd::DataType dtype, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray create(const std::vector &values, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); +class SD_EXPORT NDArrayFactory { + private: + template + static void memcpyFromVector(void* ptr, const std::vector& vector); + + public: + static NDArray undefined(); + + template + static NDArray* empty_( + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + static NDArray* empty_( + sd::DataType dataType, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray empty( + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + static NDArray empty( + sd::DataType dataType, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray* valueOf( + const std::initializer_list& shape, T value, char order = 'c', + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray* valueOf( + const std::vector& shape, T value, char order = 'c', + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + static NDArray* valueOf( + const std::vector& shape, const NDArray& value, + char order = 'c', + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray* linspace(T from, T to, Nd4jLong numElements); + + template + static NDArray* create_( + const T value, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray* create_( + sd::DataType dtype, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray create( + const T value, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray create( + sd::DataType dtype, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + template + static NDArray create( + DataType type, const T scalar, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray* vector( + Nd4jLong length, T startingValue = (T)0, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray* create_( + char order, const std::vector& shape, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + static NDArray* create_( + char order, const std::vector& shape, sd::DataType dataType, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray* create_( + char order, const std::vector& shape, + const std::vector& data, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray create( + char order, const std::vector& shape, + const std::vector& data, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray create( + char order, const std::vector& shape, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray create( + char order, const std::vector& shape, sd::DataType dtype, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray create( + const std::vector& values, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); #ifndef __JAVACPP_HACK__ - // this method only available out of javacpp - /** - * This constructor creates vector of T - * - * @param values - */ - - template - static NDArray create(char order, const std::initializer_list& shape, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray create(T* buffer, char order, const std::initializer_list& shape, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - template - static NDArray create(char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - /** - * This method creates NDArray from .npy file - * @param fileName - * @return - */ - static NDArray fromNpyFile(const char *fileName); - - /** - * This factory create array from utf8 string - * @return NDArray default dataType UTF8 - */ - static NDArray string(const char *string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray* string_(const char *string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray* string_(const std::string &string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray string(const std::string& string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This factory create array from utf16 string - * @return NDArray default dataType UTF16 - */ - static NDArray string(const char16_t* u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_(const char16_t* u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_(const std::u16string& u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string(const std::u16string& u16string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This factory create array from utf32 string - * @return NDArray default dataType UTF32 - */ - static NDArray string(const char32_t* u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_(const char32_t* u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_(const std::u32string& u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string(const std::u32string& u32string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This factory create array from vector of utf8 strings - * @return NDArray default dataType UTF8 - */ - static NDArray string( const std::vector &shape, const std::initializer_list &strings, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray string( const std::vector &shape, const std::initializer_list &string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray string( const std::vector &shape, const std::vector &strings, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray string( const std::vector &shape, const std::vector &string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray* string_( const std::vector &shape, const std::initializer_list &strings, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray* string_( const std::vector &shape, const std::initializer_list &string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray* string_( const std::vector &shape, const std::vector &strings, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static NDArray* string_( const std::vector &shape, const std::vector &string, sd::DataType dtype = sd::DataType::UTF8, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - - /** - * This factory create array from vector of utf16 strings - * @return NDArray default dataType UTF16 - */ - static NDArray string( const std::vector& shape, const std::initializer_list& strings, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string( const std::vector& shape, const std::initializer_list& string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string( const std::vector& shape, const std::vector& strings, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string( const std::vector& shape, const std::vector& string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::initializer_list& strings, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::initializer_list& string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::vector& strings, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::vector& string, sd::DataType dtype = sd::DataType::UTF16, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - /** - * This factory create array from vector of utf32 strings - * @return NDArray default dataType UTF32 - */ - static NDArray string( const std::vector& shape, const std::initializer_list& strings, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string( const std::vector& shape, const std::initializer_list& string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string( const std::vector& shape, const std::vector& strings, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray string( const std::vector& shape, const std::vector& string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::initializer_list& strings, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::initializer_list& string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::vector& strings, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - static NDArray* string_( const std::vector& shape, const std::vector& string, sd::DataType dtype = sd::DataType::UTF32, sd::LaunchContext* context = sd::LaunchContext::defaultContext()); - - - static ResultSet createSetOfArrs(const Nd4jLong numOfArrs, const void* buffer, const Nd4jLong* shapeInfo, const Nd4jLong* offsets, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); + // this method only available out of javacpp + /** + * This constructor creates vector of T + * + * @param values + */ + + template + static NDArray create( + char order, const std::initializer_list& shape, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray create( + T* buffer, char order, const std::initializer_list& shape, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + template + static NDArray create( + char order, const std::vector& shape, + const std::initializer_list& data, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + /** + * This method creates NDArray from .npy file + * @param fileName + * @return + */ + static NDArray fromNpyFile(const char* fileName); + + /** + * This factory create array from utf8 string + * @return NDArray default dataType UTF8 + */ + static NDArray string( + const char* string, sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray* string_( + const char* string, sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray* string_( + const std::string& string, sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray string( + const std::string& string, sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This factory create array from utf16 string + * @return NDArray default dataType UTF16 + */ + static NDArray string( + const char16_t* u16string, sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const char16_t* u16string, sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::u16string& u16string, sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::u16string& u16string, sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This factory create array from utf32 string + * @return NDArray default dataType UTF32 + */ + static NDArray string( + const char32_t* u32string, sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const char32_t* u32string, sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::u32string& u32string, sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::u32string& u32string, sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This factory create array from vector of utf8 strings + * @return NDArray default dataType UTF8 + */ + static NDArray string( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::initializer_list& string, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::initializer_list& string, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF8, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + + /** + * This factory create array from vector of utf16 strings + * @return NDArray default dataType UTF16 + */ + static NDArray string( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::initializer_list& string, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::initializer_list& string, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF16, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + /** + * This factory create array from vector of utf32 strings + * @return NDArray default dataType UTF32 + */ + static NDArray string( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::initializer_list& string, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray string( + const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::initializer_list& string, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::vector& strings, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + static NDArray* string_( + const std::vector& shape, + const std::vector& string, + sd::DataType dtype = sd::DataType::UTF32, + sd::LaunchContext* context = sd::LaunchContext::defaultContext()); + + static ResultSet createSetOfArrs( + const Nd4jLong numOfArrs, const void* buffer, const Nd4jLong* shapeInfo, + const Nd4jLong* offsets, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); #endif - }; -} +}; +} // namespace sd -#endif //SD_NDARRAYFACTORY_H +#endif // SD_NDARRAYFACTORY_H diff --git a/libnd4j/include/array/NDArrayLambda.hXX b/libnd4j/include/array/NDArrayLambda.hXX index f213b6aa6a96..6e1289a45683 100644 --- a/libnd4j/include/array/NDArrayLambda.hXX +++ b/libnd4j/include/array/NDArrayLambda.hXX @@ -17,306 +17,401 @@ #ifndef CUDA_LAMBDA_HELPER #define CUDA_LAMBDA_HELPER -#include -#include -#include #include #include +#include +#include +#include -static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo) { - return shape::getIndexOffset(index, shapeInfo); +static Nd4jLong __device__ __noinline__ +getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo) { + return shape::getIndexOffset(index, shapeInfo); } static Nd4jLong __device__ __noinline__ length(const Nd4jLong *shapeInfo) { - return shape::length(shapeInfo); + return shape::length(shapeInfo); } -template static _CUDA_G void lambdaKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda); -template static _CUDA_G void lambdaIndexedKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda); -template static _CUDA_G void lambdaIndexedPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda); -template static _CUDA_G void lambdaPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda); -template static _CUDA_G void lambdaTriplewiseKernel(const void* vw, const Nd4jLong *wShapeInfo, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda); +template +static _CUDA_G void lambdaKernel(const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + Lambda lambda); +template +static _CUDA_G void lambdaIndexedKernel(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Lambda lambda); +template +static _CUDA_G void lambdaIndexedPairwiseKernel( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + Lambda lambda); +template +static _CUDA_G void lambdaPairwiseKernel(const void *vx, + const Nd4jLong *xShapeInfo, + const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Lambda lambda); +template +static _CUDA_G void lambdaTriplewiseKernel( + const void *vw, const Nd4jLong *wShapeInfo, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, Lambda lambda); template class LambdaHelper { -public: - - template - FORCEINLINE static void lambdaLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - lambdaKernel<<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda); - auto err = cudaStreamSynchronize(*stream); - if (err != 0) - throw std::runtime_error("NDArray::applyLambda execution failed"); - } - - template - FORCEINLINE static void lambdaIndexedLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - lambdaIndexedKernel<<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda); - auto err = cudaStreamSynchronize(*stream); - if (err != 0) - throw std::runtime_error("NDArray::applyIndexedLambda execution failed"); - } - - template - FORCEINLINE static void lambdaPairwiseLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - lambdaPairwiseKernel<<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda); - auto err = cudaStreamSynchronize(*stream); - if (err != 0) - throw std::runtime_error("NDArray::applyPairwiseLambda execution failed"); - } - - template - FORCEINLINE static void lambdaIndexedPairwiseLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - lambdaIndexedPairwiseKernel<<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda); - auto err = cudaStreamSynchronize(*stream); - if (err != 0) - throw std::runtime_error("NDArray::applyIndexedPairwiseLambda execution failed"); - } - - template - FORCEINLINE static void lambdaTriplewiseLauncher(cudaStream_t *stream,const void* vw, const Nd4jLong *wShapeInfo, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - lambdaTriplewiseKernel<<<256, 512, 1024, *stream>>>(vw, wShapeInfo, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda); - auto err = cudaStreamSynchronize(*stream); - if (err != 0) - throw std::runtime_error("NDArray::applyTriplewiseLambda execution failed"); - } + public: + template + FORCEINLINE static void lambdaLauncher(cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Lambda lambda) { + lambdaKernel + <<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda); + auto err = cudaStreamSynchronize(*stream); + if (err != 0) + throw std::runtime_error("NDArray::applyLambda execution failed"); + } + + template + FORCEINLINE static void lambdaIndexedLauncher( + cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { + lambdaIndexedKernel + <<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda); + auto err = cudaStreamSynchronize(*stream); + if (err != 0) + throw std::runtime_error("NDArray::applyIndexedLambda execution failed"); + } + + template + FORCEINLINE static void lambdaPairwiseLauncher( + cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, Lambda lambda) { + lambdaPairwiseKernel<<<256, 512, 1024, *stream>>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda); + auto err = cudaStreamSynchronize(*stream); + if (err != 0) + throw std::runtime_error("NDArray::applyPairwiseLambda execution failed"); + } + + template + FORCEINLINE static void lambdaIndexedPairwiseLauncher( + cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, Lambda lambda) { + lambdaIndexedPairwiseKernel<<<256, 512, 1024, *stream>>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda); + auto err = cudaStreamSynchronize(*stream); + if (err != 0) + throw std::runtime_error( + "NDArray::applyIndexedPairwiseLambda execution failed"); + } + + template + FORCEINLINE static void lambdaTriplewiseLauncher( + cudaStream_t *stream, const void *vw, const Nd4jLong *wShapeInfo, + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + Lambda lambda) { + lambdaTriplewiseKernel<<<256, 512, 1024, *stream>>>( + vw, wShapeInfo, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda); + auto err = cudaStreamSynchronize(*stream); + if (err != 0) + throw std::runtime_error( + "NDArray::applyTriplewiseLambda execution failed"); + } }; //////////////////////////////////////////////////////////////////////// template -static _CUDA_G void lambdaKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); +static _CUDA_G void lambdaKernel(const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + Lambda lambda) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); - auto xOrder = shape::order(xShapeInfo); - auto zOrder = shape::order(zShapeInfo); + auto xOrder = shape::order(xShapeInfo); + auto zOrder = shape::order(zShapeInfo); - auto zLength = length(zShapeInfo); + auto zLength = length(zShapeInfo); - auto tid = threadIdx.x + blockIdx.x * blockDim.x; + auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (xEws >= 1 && zEws >= 1 && xOrder == zOrder) { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) - z[e * zEws] = lambda(x[e * xEws]); - } else { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = getIndexOffset(e, xShapeInfo); - auto zOffset = getIndexOffset(e, zShapeInfo); + if (xEws >= 1 && zEws >= 1 && xOrder == zOrder) { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) + z[e * zEws] = lambda(x[e * xEws]); + } else { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { + auto xOffset = getIndexOffset(e, xShapeInfo); + auto zOffset = getIndexOffset(e, zShapeInfo); - z[zOffset] = lambda(x[xOffset]); - } + z[zOffset] = lambda(x[xOffset]); } + } } //////////////////////////////////////////////////////////////////////// template -static _CUDA_G void lambdaIndexedKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); +static _CUDA_G void lambdaIndexedKernel(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Lambda lambda) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); - auto xOrder = shape::order(xShapeInfo); - auto zOrder = shape::order(zShapeInfo); + auto xOrder = shape::order(xShapeInfo); + auto zOrder = shape::order(zShapeInfo); - auto zLength = length(zShapeInfo); + auto zLength = length(zShapeInfo); - auto tid = threadIdx.x + blockIdx.x * blockDim.x; + auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (xEws >= 1 && zEws >= 1 && xOrder == zOrder) { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) - z[e * zEws] = lambda(e, x[e * xEws]); - } else { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = getIndexOffset(e, xShapeInfo); - auto zOffset = getIndexOffset(e, zShapeInfo); + if (xEws >= 1 && zEws >= 1 && xOrder == zOrder) { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) + z[e * zEws] = lambda(e, x[e * xEws]); + } else { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { + auto xOffset = getIndexOffset(e, xShapeInfo); + auto zOffset = getIndexOffset(e, zShapeInfo); - z[zOffset] = lambda(e, x[xOffset]); - } + z[zOffset] = lambda(e, x[xOffset]); } + } } //////////////////////////////////////////////////////////////////////// template -static _CUDA_G void lambdaIndexedPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - auto xOrder = shape::order(xShapeInfo); - auto yOrder = shape::order(yShapeInfo); - auto zOrder = shape::order(zShapeInfo); - - auto zLength = length(zShapeInfo); - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == zOrder && yOrder == xOrder) { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) - z[e * zEws] = lambda(e, x[e * xEws], y[e * yEws]); - } else { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = getIndexOffset(e, xShapeInfo); - auto yOffset = getIndexOffset(e, yShapeInfo); - auto zOffset = getIndexOffset(e, zShapeInfo); - - z[zOffset] = lambda(e, x[xOffset], y[yOffset]); - } +static _CUDA_G void lambdaIndexedPairwiseKernel( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + Lambda lambda) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + auto xOrder = shape::order(xShapeInfo); + auto yOrder = shape::order(yShapeInfo); + auto zOrder = shape::order(zShapeInfo); + + auto zLength = length(zShapeInfo); + + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + + if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == zOrder && + yOrder == xOrder) { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) + z[e * zEws] = lambda(e, x[e * xEws], y[e * yEws]); + } else { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { + auto xOffset = getIndexOffset(e, xShapeInfo); + auto yOffset = getIndexOffset(e, yShapeInfo); + auto zOffset = getIndexOffset(e, zShapeInfo); + + z[zOffset] = lambda(e, x[xOffset], y[yOffset]); } + } } //////////////////////////////////////////////////////////////////////// template -static _CUDA_G void lambdaPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - auto xOrder = shape::order(xShapeInfo); - auto yOrder = shape::order(yShapeInfo); - auto zOrder = shape::order(zShapeInfo); - - auto zLength = length(zShapeInfo); - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == zOrder && yOrder == xOrder) { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) - z[e * zEws] = lambda(x[e * xEws], y[e * yEws]); - } else { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = getIndexOffset(e, xShapeInfo); - auto yOffset = getIndexOffset(e, yShapeInfo); - auto zOffset = getIndexOffset(e, zShapeInfo); - - z[zOffset] = lambda(x[xOffset], y[yOffset]); - } +static _CUDA_G void lambdaPairwiseKernel(const void *vx, + const Nd4jLong *xShapeInfo, + const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Lambda lambda) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + auto xOrder = shape::order(xShapeInfo); + auto yOrder = shape::order(yShapeInfo); + auto zOrder = shape::order(zShapeInfo); + + auto zLength = length(zShapeInfo); + + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + + if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == zOrder && + yOrder == xOrder) { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) + z[e * zEws] = lambda(x[e * xEws], y[e * yEws]); + } else { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { + auto xOffset = getIndexOffset(e, xShapeInfo); + auto yOffset = getIndexOffset(e, yShapeInfo); + auto zOffset = getIndexOffset(e, zShapeInfo); + + z[zOffset] = lambda(x[xOffset], y[yOffset]); } + } } //////////////////////////////////////////////////////////////////////// template -static _CUDA_G void lambdaTriplewiseKernel(const void* vw, const Nd4jLong *wShapeInfo, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { - auto w = reinterpret_cast(vw); - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - auto wEws = shape::elementWiseStride(wShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - auto wOrder = shape::order(wShapeInfo); - auto xOrder = shape::order(xShapeInfo); - auto yOrder = shape::order(yShapeInfo); - auto zOrder = shape::order(zShapeInfo); - - auto zLength = length(zShapeInfo); - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (wEws > 1 && xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == zOrder && yOrder == xOrder && wOrder == xOrder) { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) - z[e * zEws] = lambda(w[e * wEws], x[e * xEws], y[e * yEws]); - } else { - for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto wOffset = getIndexOffset(e, wShapeInfo); - auto xOffset = getIndexOffset(e, xShapeInfo); - auto yOffset = getIndexOffset(e, yShapeInfo); - auto zOffset = getIndexOffset(e, zShapeInfo); - - z[zOffset] = lambda(w[wOffset], x[xOffset], y[yOffset]); - } +static _CUDA_G void lambdaTriplewiseKernel( + const void *vw, const Nd4jLong *wShapeInfo, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) { + auto w = reinterpret_cast(vw); + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + auto wEws = shape::elementWiseStride(wShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + auto wOrder = shape::order(wShapeInfo); + auto xOrder = shape::order(xShapeInfo); + auto yOrder = shape::order(yShapeInfo); + auto zOrder = shape::order(zShapeInfo); + + auto zLength = length(zShapeInfo); + + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + + if (wEws > 1 && xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == zOrder && + yOrder == xOrder && wOrder == xOrder) { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) + z[e * zEws] = lambda(w[e * wEws], x[e * xEws], y[e * yEws]); + } else { + for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { + auto wOffset = getIndexOffset(e, wShapeInfo); + auto xOffset = getIndexOffset(e, xShapeInfo); + auto yOffset = getIndexOffset(e, yShapeInfo); + auto zOffset = getIndexOffset(e, zShapeInfo); + + z[zOffset] = lambda(w[wOffset], x[xOffset], y[yOffset]); } + } } #endif ////////////////////////////////////////////////////////////////////////// -template -void NDArray::applyLambda(Lambda func, NDArray& target) { - - auto dtype = this->dataType(); - - if (dtype != target.dataType()) - throw std::runtime_error("NDArray::applyLambda X/Z data types must be the same"); - //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType()); - prepareSpecialUse({&target}, {this}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({&target}, {this}); - +template +void NDArray::applyLambda(Lambda func, NDArray &target) { + auto dtype = this->dataType(); + + if (dtype != target.dataType()) + throw std::runtime_error( + "NDArray::applyLambda X/Z data types must be the same"); + // throw datatype_exception::build("NDArray::applyLambda X/Z data types must + // be the same", dtype, target.dataType()); + prepareSpecialUse({&target}, {this}); + BUILD_SINGLE_SELECTOR( + dtype, LambdaHelper, + ::lambdaLauncher(this->_context->getCudaStream(), this->specialBuffer(), + this->specialShapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), func), + LIBND4J_TYPES); + registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// -template -void NDArray::applyPairwiseLambda(const NDArray& other, Lambda func, NDArray& target) { - - auto dtype = this->dataType(); - - if (dtype != target.dataType() || dtype != other.dataType()) - throw std::runtime_error("NDArray::applyPairwiseLambda X/Y/Z data types must be the same"); - //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType()); - - prepareSpecialUse({&target}, {this, &other}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({&target}, {this, &other}); - +template +void NDArray::applyPairwiseLambda(const NDArray &other, Lambda func, + NDArray &target) { + auto dtype = this->dataType(); + + if (dtype != target.dataType() || dtype != other.dataType()) + throw std::runtime_error( + "NDArray::applyPairwiseLambda X/Y/Z data types must be the same"); + // throw datatype_exception::build("NDArray::applyLambda X/Z data types must + // be the same", dtype, target.dataType()); + + prepareSpecialUse({&target}, {this, &other}); + BUILD_SINGLE_SELECTOR( + dtype, LambdaHelper, + ::lambdaPairwiseLauncher(this->_context->getCudaStream(), + this->specialBuffer(), this->specialShapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), + target.specialBuffer(), + target.specialShapeInfo(), func), + LIBND4J_TYPES); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyIndexedLambda(Lambda func, NDArray& target) { - - auto dtype = this->dataType(); - if (dtype != target.dataType()) - throw std::runtime_error("NDArray::applyIndexedLambda X/Z data types must be the same"); - - prepareSpecialUse({&target}, {this}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({&target}, {this}); +void NDArray::applyIndexedLambda(Lambda func, NDArray &target) { + auto dtype = this->dataType(); + if (dtype != target.dataType()) + throw std::runtime_error( + "NDArray::applyIndexedLambda X/Z data types must be the same"); + + prepareSpecialUse({&target}, {this}); + BUILD_SINGLE_SELECTOR( + dtype, LambdaHelper, + ::lambdaIndexedLauncher(this->_context->getCudaStream(), + this->specialBuffer(), this->specialShapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), + func), + LIBND4J_TYPES); + registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyIndexedPairwiseLambda(NDArray& other, Lambda func, NDArray& target) { - - auto dtype = this->dataType(); - if (dtype != target.dataType() || dtype != other.dataType()) - throw std::runtime_error("NDArray::applyIndexedPairwiseLambda X/Y/Z data types must be the same"); - - prepareSpecialUse({&target}, {this, &other}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({&target}, {this, &other}); +void NDArray::applyIndexedPairwiseLambda(NDArray &other, Lambda func, + NDArray &target) { + auto dtype = this->dataType(); + if (dtype != target.dataType() || dtype != other.dataType()) + throw std::runtime_error( + "NDArray::applyIndexedPairwiseLambda X/Y/Z data types must be the " + "same"); + + prepareSpecialUse({&target}, {this, &other}); + BUILD_SINGLE_SELECTOR( + dtype, LambdaHelper, + ::lambdaIndexedPairwiseLauncher( + this->_context->getCudaStream(), this->specialBuffer(), + this->specialShapeInfo(), other.specialBuffer(), + other.specialShapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), func), + LIBND4J_TYPES); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// template -void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, Lambda func, NDArray& target) { - - auto dtype = this->dataType(); - - if (dtype != target.dataType() || dtype != second.dataType() || dtype != third.dataType()) - throw std::runtime_error("NDArray::applyTriplewiseLambda X/Y/Z data types must be the same"); - - prepareSpecialUse({&target}, {this, &second, &third}); - BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaTriplewiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), second.specialBuffer(), second.specialShapeInfo(), third.specialBuffer(), third.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES); - registerSpecialUse({&target}, {this, &second, &third}); +void NDArray::applyTriplewiseLambda(NDArray &second, NDArray &third, + Lambda func, NDArray &target) { + auto dtype = this->dataType(); + + if (dtype != target.dataType() || dtype != second.dataType() || + dtype != third.dataType()) + throw std::runtime_error( + "NDArray::applyTriplewiseLambda X/Y/Z data types must be the same"); + + prepareSpecialUse({&target}, {this, &second, &third}); + BUILD_SINGLE_SELECTOR( + dtype, LambdaHelper, + ::lambdaTriplewiseLauncher( + this->_context->getCudaStream(), this->specialBuffer(), + this->specialShapeInfo(), second.specialBuffer(), + second.specialShapeInfo(), third.specialBuffer(), + third.specialShapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), func), + LIBND4J_TYPES); + registerSpecialUse({&target}, {this, &second, &third}); } - - - - - diff --git a/libnd4j/include/array/NDArrayList.h b/libnd4j/include/array/NDArrayList.h index bb9e068dadfd..da63086474a7 100644 --- a/libnd4j/include/array/NDArrayList.h +++ b/libnd4j/include/array/NDArrayList.h @@ -23,90 +23,90 @@ #ifndef NDARRAY_LIST_H #define NDARRAY_LIST_H -#include -#include -#include #include #include #include -namespace sd { - class SD_EXPORT NDArrayList { - protected: - class InternalArrayList { - public: - // numeric and symbolic ids of this list - std::pair _id; - std::string _name; +#include +#include +#include - sd::DataType _dtype; +namespace sd { +class SD_EXPORT NDArrayList { + protected: + class InternalArrayList { + public: + // numeric and symbolic ids of this list + std::pair _id; + std::string _name; - // stored chunks - MAP_IMPL _chunks; + sd::DataType _dtype; - // just a counter, for stored elements - std::atomic _elements; - mutable std::atomic _counter; + // stored chunks + MAP_IMPL _chunks; - // reference shape - std::vector _shape; + // just a counter, for stored elements + std::atomic _elements; + mutable std::atomic _counter; - // unstack axis - int _axis = 0; + // reference shape + std::vector _shape; - // - bool _expandable = false; + // unstack axis + int _axis = 0; - // maximum number of elements - int _height = 0; + // + bool _expandable = false; + // maximum number of elements + int _height = 0; - ////////// - InternalArrayList(int height = 0, bool expandable = false); - ~InternalArrayList() = default; - }; + ////////// + InternalArrayList(int height = 0, bool expandable = false); + ~InternalArrayList() = default; + }; - std::shared_ptr _state; + std::shared_ptr _state; - public: - NDArrayList(int height = 0, bool expandable = false); - ~NDArrayList(); + public: + NDArrayList(int height = 0, bool expandable = false); + ~NDArrayList(); - NDArrayList(const sd::NDArrayList &other); - NDArrayList(sd::NDArrayList &&other); + NDArrayList(const sd::NDArrayList& other); + NDArrayList(sd::NDArrayList&& other); - NDArrayList& operator=(const NDArrayList& other) noexcept; + NDArrayList& operator=(const NDArrayList& other) noexcept; - // move assignment operator - NDArrayList& operator=(NDArrayList&& other) noexcept; + // move assignment operator + NDArrayList& operator=(NDArrayList&& other) noexcept; - sd::DataType dataType() const; + sd::DataType dataType() const; - NDArray read(int idx); - NDArray readRaw(int idx); - Nd4jStatus write(int idx, const NDArray &array); + NDArray read(int idx); + NDArray readRaw(int idx); + Nd4jStatus write(int idx, const NDArray& array); - NDArray pick(const std::vector& indices); - bool isWritten(int index) const; + NDArray pick(const std::vector& indices); + bool isWritten(int index) const; - const std::vector& shape() const; - void setShape(const std::vector &shape); + const std::vector& shape() const; + void setShape(const std::vector& shape); - NDArray stack() const; - void unstack(const NDArray &array, int axis); + NDArray stack() const; + void unstack(const NDArray& array, int axis); - const std::pair& id() const; - const std::string& name() const; + const std::pair& id() const; + const std::string& name() const; - NDArrayList clone(); + NDArrayList clone(); - bool equals(NDArrayList& other); + bool equals(NDArrayList& other); - int elements() const; - int height() const; + int elements() const; + int height() const; - int counter() const; - }; -} + int counter() const; +}; +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/array/ResultSet.h b/libnd4j/include/array/ResultSet.h index 6883a14c105a..6b6dc9c0e496 100644 --- a/libnd4j/include/array/ResultSet.h +++ b/libnd4j/include/array/ResultSet.h @@ -25,52 +25,53 @@ #ifndef LIBND4J_RESULTSET_H #define LIBND4J_RESULTSET_H -#include #include -#include #include +#include + +#include namespace sd { - class NDArray; // forward declaration of template class NDArray +class NDArray; // forward declaration of template class NDArray - class SD_EXPORT ResultSet { - private: - std::vector _content; - Nd4jStatus _status = ND4J_STATUS_OK; - bool _removable = true; +class SD_EXPORT ResultSet { + private: + std::vector _content; + Nd4jStatus _status = ND4J_STATUS_OK; + bool _removable = true; - void delContent(); + void delContent(); - public: - explicit ResultSet(); + public: + explicit ResultSet(); #ifndef __JAVACPP_HACK__ - ResultSet(const sd::graph::FlatResult* result); + ResultSet(const sd::graph::FlatResult* result); #endif - ResultSet(const ResultSet& other) noexcept; + ResultSet(const ResultSet& other) noexcept; - ResultSet& operator=(const ResultSet& other) noexcept; + ResultSet& operator=(const ResultSet& other) noexcept; - // move constructor - ResultSet(ResultSet&& other) noexcept; + // move constructor + ResultSet(ResultSet&& other) noexcept; - // move assignment operator - ResultSet& operator=(ResultSet&& other) noexcept; + // move assignment operator + ResultSet& operator=(ResultSet&& other) noexcept; - ~ResultSet(); + ~ResultSet(); - int size(); - sd::NDArray& at(const unsigned long idx) const; - sd::NDArray& operator[](const unsigned long idx) const; - void push_back(const sd::NDArray &array); + int size(); + sd::NDArray& at(const unsigned long idx) const; + sd::NDArray& operator[](const unsigned long idx) const; + void push_back(const sd::NDArray& array); - Nd4jStatus status(); - void setStatus(Nd4jStatus status); - void purge(); - void setNonRemovable(); - }; -} + Nd4jStatus status(); + void setStatus(Nd4jStatus status); + void purge(); + void setNonRemovable(); +}; +} // namespace sd -#endif //LIBND4J_RESULTSET_H +#endif // LIBND4J_RESULTSET_H diff --git a/libnd4j/include/array/ShapeDescriptor.h b/libnd4j/include/array/ShapeDescriptor.h index 196930c75f08..83b9c875bac2 100644 --- a/libnd4j/include/array/ShapeDescriptor.h +++ b/libnd4j/include/array/ShapeDescriptor.h @@ -21,83 +21,96 @@ #ifndef SD_SHAPEDESCRIPTOR_H #define SD_SHAPEDESCRIPTOR_H -#include -#include +#include #include #include -#include + #include +#include +#include namespace sd { class SD_EXPORT ShapeDescriptor { - - private: - int _rank = 0; - std::vector _shape; - std::vector _strides; - Nd4jLong _ews = 1; - char _order = 'c'; - DataType _dataType; - bool _empty = false; - - public: - ShapeDescriptor(const ShapeDescriptor &other); - ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype = true); - explicit ShapeDescriptor(const Nd4jLong *shapeInfo, const sd::DataType dtypeOverride); - explicit ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride); - explicit ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride, const Nd4jLong *orderOverride); - explicit ShapeDescriptor(const DataType type, const Nd4jLong length); - explicit ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, const int rank); - explicit ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, const Nd4jLong *strides, const int rank, Nd4jLong ews, const bool empty); - explicit ShapeDescriptor(const DataType type, const char order, const std::initializer_list &shape); - explicit ShapeDescriptor(const DataType type, const char order, const std::vector &shape); - explicit ShapeDescriptor(const DataType type, const char order, const std::vector &shape, const std::vector &strides); - explicit ShapeDescriptor(const DataType type, const char order, const std::vector &shape, const std::vector &strides, const Nd4jLong ews); - ShapeDescriptor() = default; - ~ShapeDescriptor() = default; - - int rank() const; - Nd4jLong ews() const; - Nd4jLong arrLength() const; - char order() const; - DataType dataType() const; - bool isEmpty() const; - std::vector& shape(); - std::vector& strides(); - - // we use default copy assignment operator - ShapeDescriptor& operator=(const ShapeDescriptor& other) = default; - - // we use default move assignment operator - ShapeDescriptor& operator=(ShapeDescriptor&& other) noexcept = default; - - // equal to operator - bool operator==(const ShapeDescriptor &other) const; - - // less than operator - bool operator<(const ShapeDescriptor &other) const; - - Nd4jLong* toShapeInfo() const; - - - static ShapeDescriptor emptyDescriptor(const DataType type); - static ShapeDescriptor scalarDescriptor(const DataType type); - static ShapeDescriptor vectorDescriptor(const Nd4jLong length, const DataType type); - }; -} + private: + int _rank = 0; + std::vector _shape; + std::vector _strides; + Nd4jLong _ews = 1; + char _order = 'c'; + DataType _dataType; + bool _empty = false; + + public: + ShapeDescriptor(const ShapeDescriptor &other); + ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype = true); + explicit ShapeDescriptor(const Nd4jLong *shapeInfo, + const sd::DataType dtypeOverride); + explicit ShapeDescriptor(const Nd4jLong *shapeInfo, + const Nd4jLong *dtypeOverride); + explicit ShapeDescriptor(const Nd4jLong *shapeInfo, + const Nd4jLong *dtypeOverride, + const Nd4jLong *orderOverride); + explicit ShapeDescriptor(const DataType type, const Nd4jLong length); + explicit ShapeDescriptor(const DataType type, const char order, + const Nd4jLong *shape, const int rank); + explicit ShapeDescriptor(const DataType type, const char order, + const Nd4jLong *shape, const Nd4jLong *strides, + const int rank, Nd4jLong ews, const bool empty); + explicit ShapeDescriptor(const DataType type, const char order, + const std::initializer_list &shape); + explicit ShapeDescriptor(const DataType type, const char order, + const std::vector &shape); + explicit ShapeDescriptor(const DataType type, const char order, + const std::vector &shape, + const std::vector &strides); + explicit ShapeDescriptor(const DataType type, const char order, + const std::vector &shape, + const std::vector &strides, + const Nd4jLong ews); + ShapeDescriptor() = default; + ~ShapeDescriptor() = default; + + int rank() const; + Nd4jLong ews() const; + Nd4jLong arrLength() const; + char order() const; + DataType dataType() const; + bool isEmpty() const; + std::vector &shape(); + std::vector &strides(); + + // we use default copy assignment operator + ShapeDescriptor &operator=(const ShapeDescriptor &other) = default; + + // we use default move assignment operator + ShapeDescriptor &operator=(ShapeDescriptor &&other) noexcept = default; + + // equal to operator + bool operator==(const ShapeDescriptor &other) const; + + // less than operator + bool operator<(const ShapeDescriptor &other) const; + + Nd4jLong *toShapeInfo() const; + + static ShapeDescriptor emptyDescriptor(const DataType type); + static ShapeDescriptor scalarDescriptor(const DataType type); + static ShapeDescriptor vectorDescriptor(const Nd4jLong length, + const DataType type); +}; +} // namespace sd #ifndef __JAVACPP_HACK__ namespace std { - template<> - class SD_EXPORT hash { - public: - size_t operator()(const sd::ShapeDescriptor &k) const; - }; -} +template <> +class SD_EXPORT hash { + public: + size_t operator()(const sd::ShapeDescriptor &k) const; +}; +} // namespace std #endif - -#endif //SD_SHAPEDESCRIPTOR_H +#endif // SD_SHAPEDESCRIPTOR_H diff --git a/libnd4j/include/array/ShapeList.h b/libnd4j/include/array/ShapeList.h index 601570079cef..1e5e697b9ca0 100644 --- a/libnd4j/include/array/ShapeList.h +++ b/libnd4j/include/array/ShapeList.h @@ -21,38 +21,40 @@ #ifndef LIBND4J_SHAPELIST_H #define LIBND4J_SHAPELIST_H -#include #include #include +#include + namespace sd { - class SD_EXPORT ShapeList { - protected: - std::vector _shapes; - - bool _destroyed = false; - bool _autoremovable = false; - bool _workspace = false; - public: - ShapeList(const Nd4jLong* shape = nullptr); - ShapeList(const std::vector &shapes, bool isWorkspace); - ShapeList(const std::vector& shapes); - //ShapeList(bool autoRemovable); - - ~ShapeList(); - - std::vector* asVector(); - void destroy(); - int size() const; - const Nd4jLong* at(int idx); - void push_back(const Nd4jLong *shape); - - /** - * PLEASE NOTE: This method should be called ONLY if shapes were generated at workspaces. Otherwise you'll get memory leak - */ - void detach(); - }; -} - - -#endif //LIBND4J_SHAPELIST_H +class SD_EXPORT ShapeList { + protected: + std::vector _shapes; + + bool _destroyed = false; + bool _autoremovable = false; + bool _workspace = false; + + public: + ShapeList(const Nd4jLong* shape = nullptr); + ShapeList(const std::vector& shapes, bool isWorkspace); + ShapeList(const std::vector& shapes); + // ShapeList(bool autoRemovable); + + ~ShapeList(); + + std::vector* asVector(); + void destroy(); + int size() const; + const Nd4jLong* at(int idx); + void push_back(const Nd4jLong* shape); + + /** + * PLEASE NOTE: This method should be called ONLY if shapes were generated at + * workspaces. Otherwise you'll get memory leak + */ + void detach(); +}; +} // namespace sd + +#endif // LIBND4J_SHAPELIST_H diff --git a/libnd4j/include/array/SpaceType.h b/libnd4j/include/array/SpaceType.h index b6c6dfbbcfbc..8c65f69850b3 100644 --- a/libnd4j/include/array/SpaceType.h +++ b/libnd4j/include/array/SpaceType.h @@ -22,11 +22,11 @@ #define ND4J_SPACE_TYPE_H namespace sd { - enum SpaceType { - CONTINUOUS = 1, - COMPLEX = 2, - QUANTIZED = 3, - }; +enum SpaceType { + CONTINUOUS = 1, + COMPLEX = 2, + QUANTIZED = 3, +}; } #endif \ No newline at end of file diff --git a/libnd4j/include/array/SparseType.h b/libnd4j/include/array/SparseType.h index 3b77a1626424..c9084d8292f3 100644 --- a/libnd4j/include/array/SparseType.h +++ b/libnd4j/include/array/SparseType.h @@ -22,12 +22,12 @@ #define LIBND4J_SPARSETYPE_H namespace sd { - enum SparseType { - CSR = 1, - CSC = 2, - COO = 3, - LIL = 4, - }; +enum SparseType { + CSR = 1, + CSC = 2, + COO = 3, + LIL = 4, +}; } -#endif //LIBND4J_SPARSETYPE_H +#endif // LIBND4J_SPARSETYPE_H diff --git a/libnd4j/include/array/TadDescriptor.h b/libnd4j/include/array/TadDescriptor.h index e10525c2a681..4417c2bd17be 100644 --- a/libnd4j/include/array/TadDescriptor.h +++ b/libnd4j/include/array/TadDescriptor.h @@ -21,54 +21,58 @@ #ifndef SD_TADDESCRIPTOR_H #define SD_TADDESCRIPTOR_H -#include "ShapeDescriptor.h" #include +#include "ShapeDescriptor.h" + namespace sd { - class SD_EXPORT TadDescriptor { - private: - ShapeDescriptor _originalShape; +class SD_EXPORT TadDescriptor { + private: + ShapeDescriptor _originalShape; - std::vector _axis; + std::vector _axis; - bool _unitiesInShape; + bool _unitiesInShape; - public: - explicit TadDescriptor(const Nd4jLong *originalShape, const int *dimensions, const int length, const bool keepUnitiesInShape = false); - explicit TadDescriptor(const ShapeDescriptor &descriptor, const std::vector &dimensions, const bool keepUnitiesInShape = false); - explicit TadDescriptor(const TadDescriptor &other); - ~TadDescriptor() = default; + public: + explicit TadDescriptor(const Nd4jLong *originalShape, const int *dimensions, + const int length, + const bool keepUnitiesInShape = false); + explicit TadDescriptor(const ShapeDescriptor &descriptor, + const std::vector &dimensions, + const bool keepUnitiesInShape = false); + explicit TadDescriptor(const TadDescriptor &other); + ~TadDescriptor() = default; - // we use default copy assignment operator - TadDescriptor& operator=(const TadDescriptor& other) = default; + // we use default copy assignment operator + TadDescriptor &operator=(const TadDescriptor &other) = default; - // we use default move assignment operator - TadDescriptor& operator=(TadDescriptor&& other) noexcept = default; + // we use default move assignment operator + TadDescriptor &operator=(TadDescriptor &&other) noexcept = default; - // equal to operator - bool operator==(const TadDescriptor &other) const; + // equal to operator + bool operator==(const TadDescriptor &other) const; - // less than operator - bool operator<(const TadDescriptor &other) const; + // less than operator + bool operator<(const TadDescriptor &other) const; - std::vector& axis(); - ShapeDescriptor& originalShape(); - ShapeDescriptor const& originalShapeConst() const; - bool areUnitiesinShape() const; - }; -} + std::vector &axis(); + ShapeDescriptor &originalShape(); + ShapeDescriptor const &originalShapeConst() const; + bool areUnitiesinShape() const; +}; +} // namespace sd #ifndef __JAVACPP_HACK__ namespace std { - template<> - class SD_EXPORT hash { - public: - size_t operator()(const sd::TadDescriptor &k) const; - }; -} +template <> +class SD_EXPORT hash { + public: + size_t operator()(const sd::TadDescriptor &k) const; +}; +} // namespace std #endif - -#endif //SD_TADDESCRIPTOR_H +#endif // SD_TADDESCRIPTOR_H diff --git a/libnd4j/include/array/TadPack.h b/libnd4j/include/array/TadPack.h index 42f585abb843..889d91774a4e 100644 --- a/libnd4j/include/array/TadPack.h +++ b/libnd4j/include/array/TadPack.h @@ -24,34 +24,36 @@ #include "ConstantDataBuffer.h" namespace sd { - class SD_EXPORT TadPack { - private: - ConstantDataBuffer _tadShape; - ConstantDataBuffer _tadOffsets; - Nd4jLong _numTads = 0 ; - int _shapeInfoLength = 0; - public: - explicit TadPack(ConstantDataBuffer &shapes, ConstantDataBuffer &offets, Nd4jLong numTads); - TadPack() = default; - ~TadPack() = default; - - const Nd4jLong* primaryShapeInfo() const; - const Nd4jLong* primaryOffsets() const; - - const Nd4jLong* specialShapeInfo() const; - const Nd4jLong* specialOffsets() const; - - Nd4jLong numberOfTads() const; - int shapeInfoLength() const; - - /** - * These methods return either primary or special pointers depending on platform binaries were compiled for - * @return - */ - const Nd4jLong *platformShapeInfo() const; - const Nd4jLong *platformOffsets() const; - }; -} - - -#endif //SD_TADPACK_H +class SD_EXPORT TadPack { + private: + ConstantDataBuffer _tadShape; + ConstantDataBuffer _tadOffsets; + Nd4jLong _numTads = 0; + int _shapeInfoLength = 0; + + public: + explicit TadPack(ConstantDataBuffer& shapes, ConstantDataBuffer& offets, + Nd4jLong numTads); + TadPack() = default; + ~TadPack() = default; + + const Nd4jLong* primaryShapeInfo() const; + const Nd4jLong* primaryOffsets() const; + + const Nd4jLong* specialShapeInfo() const; + const Nd4jLong* specialOffsets() const; + + Nd4jLong numberOfTads() const; + int shapeInfoLength() const; + + /** + * These methods return either primary or special pointers depending on + * platform binaries were compiled for + * @return + */ + const Nd4jLong* platformShapeInfo() const; + const Nd4jLong* platformOffsets() const; +}; +} // namespace sd + +#endif // SD_TADPACK_H diff --git a/libnd4j/include/array/cpu/DataBuffer.cpp b/libnd4j/include/array/cpu/DataBuffer.cpp index 677518746e0e..7de64cf29ee3 100644 --- a/libnd4j/include/array/cpu/DataBuffer.cpp +++ b/libnd4j/include/array/cpu/DataBuffer.cpp @@ -23,122 +23,111 @@ #include namespace sd { - void DataBuffer::expand(const uint64_t size) { - if (size > _lenInBytes) { - // allocate new buffer - int8_t *newBuffer = nullptr; - ALLOCATE(newBuffer, _workspace, size, int8_t); - - // copy data from existing buffer - std::memcpy(newBuffer, _primaryBuffer, _lenInBytes); - - if (_isOwnerPrimary) { - RELEASE(reinterpret_cast(_primaryBuffer), _workspace); - } - - _primaryBuffer = newBuffer; - _lenInBytes = size; - _isOwnerPrimary = true; - } - } +void DataBuffer::expand(const uint64_t size) { + if (size > _lenInBytes) { + // allocate new buffer + int8_t* newBuffer = nullptr; + ALLOCATE(newBuffer, _workspace, size, int8_t); -//////////////////////////////////////////////////////////////////////// -void DataBuffer::setCountersToZero() { + // copy data from existing buffer + std::memcpy(newBuffer, _primaryBuffer, _lenInBytes); + + if (_isOwnerPrimary) { + RELEASE(reinterpret_cast(_primaryBuffer), _workspace); + } + _primaryBuffer = newBuffer; + _lenInBytes = size; + _isOwnerPrimary = true; + } } //////////////////////////////////////////////////////////////////////// -void DataBuffer::copyCounters(const DataBuffer& other) { +void DataBuffer::setCountersToZero() {} -} //////////////////////////////////////////////////////////////////////// -void DataBuffer::allocateBuffers(const bool allocBoth) { // always allocate primary buffer only (cpu case) - - allocatePrimary(); -} - +void DataBuffer::copyCounters(const DataBuffer& other) {} //////////////////////////////////////////////////////////////////////// -void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { +void DataBuffer::allocateBuffers( + const bool allocBoth) { // always allocate primary buffer only (cpu case) - if(sizeToCopyinBytes == 0) - sizeToCopyinBytes = other.getLenInBytes(); - if(sizeToCopyinBytes == 0) - return; - - if(other._primaryBuffer != nullptr) - std::memcpy(static_cast(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(other._primaryBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes); + allocatePrimary(); } //////////////////////////////////////////////////////////////////////// -void DataBuffer::copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetHostBuffer) { - - if(sizeToCopyinBytes == 0) - sizeToCopyinBytes = getLenInBytes(); - if(sizeToCopyinBytes == 0) - return; - - if(hostBuffer != nullptr) - std::memcpy(static_cast(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(hostBuffer) + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), sizeToCopyinBytes); +void DataBuffer::copyBufferFrom(const DataBuffer& other, + size_t sizeToCopyinBytes, + const Nd4jLong offsetThis, + const Nd4jLong offsetOther) { + if (sizeToCopyinBytes == 0) sizeToCopyinBytes = other.getLenInBytes(); + if (sizeToCopyinBytes == 0) return; + + if (other._primaryBuffer != nullptr) + std::memcpy(static_cast(_primaryBuffer) + + offsetThis * DataTypeUtils::sizeOfElement(_dataType), + static_cast(other._primaryBuffer) + + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), + sizeToCopyinBytes); } - //////////////////////////////////////////////////////////////////////// -void DataBuffer::deleteSpecial() { - +void DataBuffer::copyBufferFromHost(const void* hostBuffer, + size_t sizeToCopyinBytes, + const Nd4jLong offsetThis, + const Nd4jLong offsetHostBuffer) { + if (sizeToCopyinBytes == 0) sizeToCopyinBytes = getLenInBytes(); + if (sizeToCopyinBytes == 0) return; + + if (hostBuffer != nullptr) + std::memcpy(static_cast(_primaryBuffer) + + offsetThis * DataTypeUtils::sizeOfElement(_dataType), + static_cast(hostBuffer) + + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), + sizeToCopyinBytes); } //////////////////////////////////////////////////////////////////////// -void DataBuffer::setSpecial(void* special, const bool isOwnerSpecail) { +void DataBuffer::deleteSpecial() {} -} +//////////////////////////////////////////////////////////////////////// +void DataBuffer::setSpecial(void* special, const bool isOwnerSpecail) {} //////////////////////////////////////////////////////////////////////// void DataBuffer::setToZeroBuffers(const bool both) { - - memset(primary(), 0, getLenInBytes()); + memset(primary(), 0, getLenInBytes()); } //////////////////////////////////////////////////////////////////////// -void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSync) { - -} +void DataBuffer::syncToPrimary(const LaunchContext* context, + const bool forceSync) {} //////////////////////////////////////////////////////////////////////// -void DataBuffer::syncToSpecial(const bool forceSync) { - -} +void DataBuffer::syncToSpecial(const bool forceSync) {} //////////////////////////////////////////////////////////////////////// -void DataBuffer::allocateSpecial() { - -} +void DataBuffer::allocateSpecial() {} //////////////////////////////////////////////////////////////////////// -void DataBuffer::migrate() { - -} +void DataBuffer::migrate() {} ///////////////////////// -void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { - if (src._lenInBytes > dst._lenInBytes) - throw std::runtime_error("DataBuffer::memcpy: Source data buffer is larger than destination"); +void DataBuffer::memcpy(const DataBuffer& dst, const DataBuffer& src) { + if (src._lenInBytes > dst._lenInBytes) + throw std::runtime_error( + "DataBuffer::memcpy: Source data buffer is larger than destination"); - std::memcpy(dst._primaryBuffer, src._primaryBuffer, src._lenInBytes); - dst.readPrimary(); + std::memcpy(dst._primaryBuffer, src._primaryBuffer, src._lenInBytes); + dst.readPrimary(); } - void *DataBuffer::platform() { - return primary(); - } - +void* DataBuffer::platform() { return primary(); } //////////////////////////////////////////////////////////////////////// -void DataBuffer::writePrimary() const { } -void DataBuffer::writeSpecial() const { } -void DataBuffer::readPrimary() const { } -void DataBuffer::readSpecial() const { } -bool DataBuffer::isPrimaryActual() const { return true;} -bool DataBuffer::isSpecialActual() const { return false;} - - -} +void DataBuffer::writePrimary() const {} +void DataBuffer::writeSpecial() const {} +void DataBuffer::readPrimary() const {} +void DataBuffer::readSpecial() const {} +bool DataBuffer::isPrimaryActual() const { return true; } +bool DataBuffer::isSpecialActual() const { return false; } + +} // namespace sd diff --git a/libnd4j/include/array/cpu/ManagedDataBuffer.cpp b/libnd4j/include/array/cpu/ManagedDataBuffer.cpp index bacb6d7886e1..610b2faf9a3b 100644 --- a/libnd4j/include/array/cpu/ManagedDataBuffer.cpp +++ b/libnd4j/include/array/cpu/ManagedDataBuffer.cpp @@ -21,11 +21,7 @@ #include namespace sd { - void *ManagedDataBuffer::primary() { - return _descriptor.address(); - } +void *ManagedDataBuffer::primary() { return _descriptor.address(); } - void *ManagedDataBuffer::special() { - return nullptr; - } -} \ No newline at end of file +void *ManagedDataBuffer::special() { return nullptr; } +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/array/cpu/NDArray.cpp b/libnd4j/include/array/cpu/NDArray.cpp index 84de96050ce6..03d74f2e9902 100644 --- a/libnd4j/include/array/cpu/NDArray.cpp +++ b/libnd4j/include/array/cpu/NDArray.cpp @@ -19,421 +19,468 @@ #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include -#include +#include +#include +#include +#include #include -#include +#include #include +#include #include -#include + #include -#include -#include -#include -#include -#include -#include -#include -#include #include -#include -#include -#include -#include -#include -#include - +#include namespace sd { //////////////////////////////////////////////////////////////////////// -void* NDArray::platformBuffer() { return buffer(); } -void const* NDArray::platformBuffer() const { return buffer(); } +void* NDArray::platformBuffer() { return buffer(); } +void const* NDArray::platformBuffer() const { return buffer(); } Nd4jLong const* NDArray::platformShapeInfo() const { return shapeInfo(); } -void NDArray::syncToDevice() const { } -void NDArray::syncToHost() const { } -void NDArray::tickWriteHost() const { } -void NDArray::tickWriteDevice() const { } -void NDArray::tickReadHost() const { } -void NDArray::tickReadDevice() const { } -void NDArray::tickBothActual() const { } -bool NDArray::isActualOnHostSide() const { return true; } -bool NDArray::isActualOnDeviceSide() const { return true; } -void NDArray::makeBothBuffersActual() const { } - +void NDArray::syncToDevice() const {} +void NDArray::syncToHost() const {} +void NDArray::tickWriteHost() const {} +void NDArray::tickWriteDevice() const {} +void NDArray::tickReadHost() const {} +void NDArray::tickReadDevice() const {} +void NDArray::tickBothActual() const {} +bool NDArray::isActualOnHostSide() const { return true; } +bool NDArray::isActualOnDeviceSide() const { return true; } +void NDArray::makeBothBuffersActual() const {} //////////////////////////////////////////////////////////////////////// template -void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& target, const char direction) { - - if (isS()) - throw std::runtime_error("NDArray::fillArrayAsTriangular: you can't use this method on String array!"); - - if(!isSameShape(target) && !(rankOf() == 1 && target.rankOf() == 2 && sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1))) - throw std::string("NDArray::fillArrayAsTriangular method: wrong shape of target array !"); - - if (direction == 'u') - lower = -target.sizeAt(-2); - else if (direction == 'l') - upper = target.sizeAt(-1); - - const T value = static_cast(val); - const auto x = reinterpret_cast(buffer()); - auto z = reinterpret_cast(target.buffer()); - - const int xRank = rankOf(); - const int zRank = target.rankOf(); - - const auto zLen = target.lengthOf(); - - const bool areSameOffsets = shape::haveSameShapeAndStrides(shapeInfo(), target.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR { - - int coords[MAX_RANK], temp; - - for (auto i = start; i < stop; i++) { - - shape::index2coordsCPU(start, i, target.shapeInfo(), coords); - const auto zOffset = shape::getOffset(target.shapeInfo(), coords); - - // if( (row + upper < col) || (row + lower > col) ) - if ((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1])) - z[zOffset] = value; - else if (this != &target) { // when this and target are different arrays - if (xRank != zRank) { - temp = coords[0]; - coords[0] = coords[1]; - } +void NDArray::fillAsTriangular(const float val, int lower, int upper, + NDArray& target, const char direction) { + if (isS()) + throw std::runtime_error( + "NDArray::fillArrayAsTriangular: you can't use this method on String " + "array!"); + + if (!isSameShape(target) && + !(rankOf() == 1 && target.rankOf() == 2 && + sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1))) + throw std::string( + "NDArray::fillArrayAsTriangular method: wrong shape of target array !"); + + if (direction == 'u') + lower = -target.sizeAt(-2); + else if (direction == 'l') + upper = target.sizeAt(-1); + + const T value = static_cast(val); + const auto x = reinterpret_cast(buffer()); + auto z = reinterpret_cast(target.buffer()); + + const int xRank = rankOf(); + const int zRank = target.rankOf(); + + const auto zLen = target.lengthOf(); + + const bool areSameOffsets = + shape::haveSameShapeAndStrides(shapeInfo(), target.shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK], temp; + + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, target.shapeInfo(), coords); + const auto zOffset = shape::getOffset(target.shapeInfo(), coords); + + // if( (row + upper < col) || (row + lower > col) ) + if ((coords[zRank - 2] + upper < coords[zRank - 1]) || + (coords[zRank - 2] + lower > coords[zRank - 1])) + z[zOffset] = value; + else if (this != &target) { // when this and target are different arrays + if (xRank != zRank) { + temp = coords[0]; + coords[0] = coords[1]; + } - const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(shapeInfo(), coords); - z[zOffset] = x[xOffset]; + const auto xOffset = + areSameOffsets ? zOffset : shape::getOffset(shapeInfo(), coords); + z[zOffset] = x[xOffset]; - if (xRank != zRank) // restore first coordinate - coords[0] = temp; - } - } - }; + if (xRank != zRank) // restore first coordinate + coords[0] = temp; + } + } + }; - samediff::Threads::parallel_for(func, 0, zLen); + samediff::Threads::parallel_for(func, 0, zLen); } -BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, + (const float val, int lower, int upper, NDArray& target, + const char direction), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void NDArray::setIdentity() { - if (isS()) - throw std::runtime_error("NDArray::setIdentity: you can't use this method on String array!"); + if (isS()) + throw std::runtime_error( + "NDArray::setIdentity: you can't use this method on String array!"); - this->nullify(); + this->nullify(); - int rank = rankOf(); - auto shape = shapeOf(); - int minDim = MAX_INT; - Nd4jLong indices[MAX_RANK]; - for(int j = 0; j < rank; ++j) - indices[j] = 1; + int rank = rankOf(); + auto shape = shapeOf(); + int minDim = MAX_INT; + Nd4jLong indices[MAX_RANK]; + for (int j = 0; j < rank; ++j) indices[j] = 1; - Nd4jLong offset = shape::getOffset(shapeInfo(), indices); + Nd4jLong offset = shape::getOffset(shapeInfo(), indices); - for(int i = 0; i < rank; ++i) - if(minDim > shape[i]) - minDim = shape[i]; + for (int i = 0; i < rank; ++i) + if (minDim > shape[i]) minDim = shape[i]; - float v = 1.0f; + float v = 1.0f; - for(int i = 0; i < minDim; ++i) - templatedSet(buffer(), i*offset, this->dataType(), &v); + for (int i = 0; i < minDim; ++i) + templatedSet(buffer(), i * offset, this->dataType(), &v); } //////////////////////////////////////////////////////////////////////// template -static void templatedSwap(void *xBuffer, void *yBuffer, Nd4jLong length) { - auto x = reinterpret_cast(xBuffer); - auto y = reinterpret_cast(yBuffer); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto temp = x[i]; - x[i] = y[i]; - y[i] = temp; - } - }; +static void templatedSwap(void* xBuffer, void* yBuffer, Nd4jLong length) { + auto x = reinterpret_cast(xBuffer); + auto y = reinterpret_cast(yBuffer); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto temp = x[i]; + x[i] = y[i]; + y[i] = temp; + } + }; - samediff::Threads::parallel_for(func, 0, length); + samediff::Threads::parallel_for(func, 0, length); } -BUILD_SINGLE_TEMPLATE(template void templatedSwap, (void *xBuffer, void *yBuffer, Nd4jLong length), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void templatedSwap, + (void* xBuffer, void* yBuffer, Nd4jLong length), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void NDArray::swapUnsafe(NDArray& other) { - auto xType = this->dataType(); + auto xType = this->dataType(); - if (xType != other.dataType()) - throw std::runtime_error("NDArray::swapUnsage method: both arrays must have the same data type"); + if (xType != other.dataType()) + throw std::runtime_error( + "NDArray::swapUnsage method: both arrays must have the same data type"); - if(buffer() == nullptr || other.buffer() == nullptr) - throw std::runtime_error("NDArray::swapUnsafe method: input array should not be empty!"); + if (buffer() == nullptr || other.buffer() == nullptr) + throw std::runtime_error( + "NDArray::swapUnsafe method: input array should not be empty!"); - if(lengthOf() != other.lengthOf()) - throw std::runtime_error("NDArray::swapUnsafe method: input arrays should have the same length!"); + if (lengthOf() != other.lengthOf()) + throw std::runtime_error( + "NDArray::swapUnsafe method: input arrays should have the same " + "length!"); - BUILD_SINGLE_SELECTOR(xType, templatedSwap, (buffer(), other.buffer(), this->lengthOf()), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, templatedSwap, + (buffer(), other.buffer(), this->lengthOf()), + LIBND4J_TYPES); } //////////////////////////////////////////////////////////////////////// void NDArray::synchronize(const char* msg) const { - // no-op + // no-op } -void NDArray::prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { - // no-op +void NDArray::prepareSpecialUse(const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables) { + // no-op } -void NDArray::registerSpecialUse(const std::vector& writeList, const std::vector& readList) { - // no-op +void NDArray::registerSpecialUse(const std::vector& writeList, + const std::vector& readList) { + // no-op } -void NDArray::preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { - // no-op +void NDArray::preparePrimaryUse(const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables) { + // no-op } -void NDArray::registerPrimaryUse(const std::vector& writeList, const std::vector& readList) { - // no-op +void NDArray::registerPrimaryUse(const std::vector& writeList, + const std::vector& readList) { + // no-op } void NDArray::syncShape() const { - // no-op + // no-op } ////////////////////////////////////////////////////////////////////////// -template -void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const { - -} +template +void NDArray::printCurrentBuffer(const bool host, const char* msg, + const int precision) const {} - //////////////////////////////////////////////////////////////////////// - void* NDArray::specialBufferWithOffset(Nd4jLong offset) { - return nullptr; - } +//////////////////////////////////////////////////////////////////////// +void* NDArray::specialBufferWithOffset(Nd4jLong offset) { return nullptr; } //////////////////////////////////////////////////////////////////////// const void* NDArray::specialBufferWithOffset(Nd4jLong offset) const { - return nullptr; + return nullptr; } //////////////////////////////////////////////////////////////////////// void* NDArray::specialBuffer() { - if (_buffer->special() == nullptr) - return buffer(); - // FIXME: this should be fixed once CUDA backend added - return static_cast(_buffer->special()) + (_offset * sizeOfT()); + if (_buffer->special() == nullptr) return buffer(); + // FIXME: this should be fixed once CUDA backend added + return static_cast(_buffer->special()) + (_offset * sizeOfT()); } //////////////////////////////////////////////////////////////////////// void const* NDArray::specialBuffer() const { - if (_buffer->special() == nullptr) - return buffer(); - // FIXME: this should be fixed once CUDA backend added - return static_cast(_buffer->special()) + (_offset * sizeOfT()); + if (_buffer->special() == nullptr) return buffer(); + // FIXME: this should be fixed once CUDA backend added + return static_cast(_buffer->special()) + (_offset * sizeOfT()); } - ////////////////////////////////////////////////////////////////////////// // change an array by repeating it the number of times given by reps. NDArray NDArray::tile(const std::vector& reps) const { - const int repsSize = reps.size(); - - Nd4jLong product = 1; - for(const auto& item : reps) - product *= item; - if(product == 0) - throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !"); - - int rankOld = rankOf(); - int diff = rankOld - repsSize; - if(product==1) { // in this case 2 possibilities are present: just reshape or nothing to do - NDArray result(*this); - if(diff < 0) { // reshape to higher dimension - std::vector shapeNew = reps; // there is requirement to have unities at first "diff" positions of new shape - memcpy(&shapeNew[-diff], result.shapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions - result.reshapei(ordering(), shapeNew); - } - return result; // nothing to do, if diff >= 0 -> identity tile + const int repsSize = reps.size(); + + Nd4jLong product = 1; + for (const auto& item : reps) product *= item; + if (product == 0) + throw std::runtime_error( + "NDArray::tile method: one of the elements in reps array is zero !"); + + int rankOld = rankOf(); + int diff = rankOld - repsSize; + if (product == 1) { // in this case 2 possibilities are present: just reshape + // or nothing to do + NDArray result(*this); + if (diff < 0) { // reshape to higher dimension + std::vector shapeNew = + reps; // there is requirement to have unities at first "diff" + // positions of new shape + memcpy( + &shapeNew[-diff], result.shapeInfo() + 1, + rankOld * + sizeof(Nd4jLong)); // put old shape numbers at rest of positions + result.reshapei(ordering(), shapeNew); } + return result; // nothing to do, if diff >= 0 -> identity tile + } + + // evaluate shapeInfo for resulting array + auto newShapeInfo = + ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); + // create new buffer, in any case the memory amount new buffer points to is + // bigger then those for old _buffer + std::shared_ptr newBuff = + std::make_shared(shape::length(newShapeInfo) * sizeOfT(), + dataType(), getContext()->getWorkspace()); + // assign new shape and new buffer to resulting array + NDArray result(newBuff, ShapeDescriptor(newShapeInfo), getContext()); + + // fill newBuff, loop through all elements of newBuff + // looping through _buffer goes automatically by means of getSubArrayIndex + // applying + const auto resultLen = result.lengthOf(); + auto xType = this->dataType(); + if (result.ordering() == 'c') { // ews == 1 always here - // evaluate shapeInfo for resulting array - auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); - // create new buffer, in any case the memory amount new buffer points to is bigger then those for old _buffer - std::shared_ptr newBuff = std::make_shared(shape::length(newShapeInfo) * sizeOfT(), dataType(), getContext()->getWorkspace()); - // assign new shape and new buffer to resulting array - NDArray result(newBuff, ShapeDescriptor(newShapeInfo), getContext()); - - // fill newBuff, loop through all elements of newBuff - // looping through _buffer goes automatically by means of getSubArrayIndex applying - const auto resultLen = result.lengthOf(); - auto xType = this->dataType(); - if(result.ordering() == 'c') { // ews == 1 always here - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto yOffset = shape::subArrayOffset(i, newShapeInfo, shapeInfo()); - BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.buffer(), i, this->buffer(), yOffset), LIBND4J_TYPES); - } - }; - - samediff::Threads::parallel_for(func, 0, resultLen); - } - else { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto yOffset = shape::subArrayOffset(i, newShapeInfo, shapeInfo()); + BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign, + (result.buffer(), i, this->buffer(), yOffset), + LIBND4J_TYPES); + } + }; - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto xOffset = result.getOffset(i); - auto yOffset = shape::subArrayOffset(i, newShapeInfo, shapeInfo()); - BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.buffer(), xOffset, this->buffer(), yOffset), LIBND4J_TYPES); - } - }; + samediff::Threads::parallel_for(func, 0, resultLen); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto xOffset = result.getOffset(i); + auto yOffset = shape::subArrayOffset(i, newShapeInfo, shapeInfo()); + BUILD_SINGLE_SELECTOR( + xType, this->template templatedAssign, + (result.buffer(), xOffset, this->buffer(), yOffset), LIBND4J_TYPES); + } + }; - samediff::Threads::parallel_for(func, 0, resultLen); - } - result.tickWriteHost(); - return result; + samediff::Threads::parallel_for(func, 0, resultLen); + } + result.tickWriteHost(); + return result; } ////////////////////////////////////////////////////////////////////////// // change an array by repeating it the number of times given by reps. void NDArray::tile(const std::vector& reps, NDArray& target) const { - - auto repProd = shape::prodLong(reps.data(), reps.size()); - if (repProd < 1) - throw std::runtime_error("NDArray::tile: reps can't contain 0s"); - - // evaluate true tile shapeInfo for comparison with target shapeInfo - auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); - if(!shape::equalsSoft(newShapeInfo, target.shapeInfo())) { - delete []newShapeInfo; - throw std::runtime_error("NDArray::tile method - shapeInfo of target array is not suitable for tile operation !"); + auto repProd = shape::prodLong(reps.data(), reps.size()); + if (repProd < 1) + throw std::runtime_error("NDArray::tile: reps can't contain 0s"); + + // evaluate true tile shapeInfo for comparison with target shapeInfo + auto newShapeInfo = + ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); + if (!shape::equalsSoft(newShapeInfo, target.shapeInfo())) { + delete[] newShapeInfo; + throw std::runtime_error( + "NDArray::tile method - shapeInfo of target array is not suitable for " + "tile operation !"); + } + + // fill newBuff, loop through all elements of newBuff + // looping through _buffer goes automatically by means of getSubArrayIndex + // applying + const int ews = target.ews(); + const auto targetLen = target.lengthOf(); + if (target.ordering() == 'c' && ews == 1) { // ews == 1 always here + //#pragma omp parallel for simd if(targetLen > + //Environment::getInstance()->elementwiseThreshold()) schedule(guided) + for (Nd4jLong i = 0; i < targetLen; ++i) { + auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo()); + BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), + templatedDoubleAssign, + (target.buffer(), i, buffer(), yOffset), + LIBND4J_TYPES, LIBND4J_TYPES); } - - // fill newBuff, loop through all elements of newBuff - // looping through _buffer goes automatically by means of getSubArrayIndex applying - const int ews = target.ews(); - const auto targetLen = target.lengthOf(); - if(target.ordering() == 'c' && ews == 1) { // ews == 1 always here -//#pragma omp parallel for simd if(targetLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided) - for(Nd4jLong i=0; i 1) { - for(Nd4jLong i=0; i 1) { + for (Nd4jLong i = 0; i < targetLen; ++i) { + auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo()); + BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), + templatedDoubleAssign, + (target.buffer(), i * ews, buffer(), yOffset), + LIBND4J_TYPES, LIBND4J_TYPES); } - else { - - for(Nd4jLong i=0; i target.rankOf()) - throw std::runtime_error("NDArray::tile method - rank of target array must be bigger or equal to the rank of this array !"); - - if(!ShapeUtils::areShapesBroadcastable(*this, target)) - throw std::runtime_error("NDArray::tile method - shapeInfo of target array is not suitable for tile operation !"); - - // fill newBuff, loop through all elements of newBuff - // looping through _buffer goes automatically by means of getSubArrayIndex applying - const auto ews = target.ews(); - const auto targetLen = target.lengthOf(); - if(target.ordering() == 'c' && ews >= 1) { - - for(Nd4jLong i=0; i target.rankOf()) + throw std::runtime_error( + "NDArray::tile method - rank of target array must be bigger or equal " + "to the rank of this array !"); + + if (!ShapeUtils::areShapesBroadcastable(*this, target)) + throw std::runtime_error( + "NDArray::tile method - shapeInfo of target array is not suitable for " + "tile operation !"); + + // fill newBuff, loop through all elements of newBuff + // looping through _buffer goes automatically by means of getSubArrayIndex + // applying + const auto ews = target.ews(); + const auto targetLen = target.lengthOf(); + if (target.ordering() == 'c' && ews >= 1) { + for (Nd4jLong i = 0; i < targetLen; ++i) { + auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo()); + BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), + templatedDoubleAssign, + (target.buffer(), i * ews, buffer(), yOffset), + LIBND4J_TYPES, LIBND4J_TYPES); } - else { - - for(Nd4jLong i=0; i -static void repeat_(const NDArray& input, NDArray& output, const std::vector& repeats, const int axis) { - - const X* x = input.bufferAsT(); - Z* z = output.bufferAsT(); - - const int rank = input.rankOf(); // xRank = zRank - const int zLen = output.lengthOf(); // xLen <= zLen - const uint repSize = repeats.size(); - - // loop through input array - auto func = PRAGMA_THREADS_FOR { - - int coords[MAX_RANK], temp; - - for (auto i = start; i < stop; i++) { - - shape::index2coordsCPU(start, i, output.shapeInfo(), coords); - const auto zOffset = shape::getOffset(output.shapeInfo(), coords); - - temp = coords[axis]; - - if (repSize > 1) { - for (uint j = 0; j < repSize; ++j) { - coords[axis] -= repeats[j]; - if (coords[axis] < 0) { - coords[axis] = j; - break; - } - } - } else - coords[axis] /= repeats[0]; +template +static void repeat_(const NDArray& input, NDArray& output, + const std::vector& repeats, const int axis) { + const X* x = input.bufferAsT(); + Z* z = output.bufferAsT(); + + const int rank = input.rankOf(); // xRank = zRank + const int zLen = output.lengthOf(); // xLen <= zLen + const uint repSize = repeats.size(); + + // loop through input array + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK], temp; + + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), coords); + const auto zOffset = shape::getOffset(output.shapeInfo(), coords); + + temp = coords[axis]; + + if (repSize > 1) { + for (uint j = 0; j < repSize; ++j) { + coords[axis] -= repeats[j]; + if (coords[axis] < 0) { + coords[axis] = j; + break; + } + } + } else + coords[axis] /= repeats[0]; - z[zOffset] = x[shape::getOffset(input.shapeInfo(), coords)]; + z[zOffset] = x[shape::getOffset(input.shapeInfo(), coords)]; - coords[axis] = temp; - } - }; + coords[axis] = temp; + } + }; - samediff::Threads::parallel_for(func, 0, zLen); + samediff::Threads::parallel_for(func, 0, zLen); } ////////////////////////////////////////////////////////////////////////// // create new array by repeating it the number of times given by repeats NDArray NDArray::repeat(const int axis, const std::vector& repeats) const { + NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), + dataType(), getContext()); - NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext()); - - BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeat_, (*this, output, repeats, axis), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeat_, + (*this, output, repeats, axis), LIBND4J_TYPES); - return output; + return output; } ////////////////////////////////////////////////////////////////////////// // fill array by repeating it the number of times given by reps -void NDArray::repeat(const int axis, const std::vector& repeats, NDArray& target) const { - - if(!target.isSameShape(ShapeUtils::evalRepeatShape(axis, repeats, *this))) - throw std::invalid_argument("NDArray::repeat(const int axis, const std::vector& repeats, NDArray& target) method: wrong shape of target array!"); - - BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), repeat_, (*this, target, repeats, axis), LIBND4J_TYPES, LIBND4J_TYPES); +void NDArray::repeat(const int axis, const std::vector& repeats, + NDArray& target) const { + if (!target.isSameShape(ShapeUtils::evalRepeatShape(axis, repeats, *this))) + throw std::invalid_argument( + "NDArray::repeat(const int axis, const std::vector& repeats, " + "NDArray& target) method: wrong shape of target array!"); + + BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), repeat_, + (*this, target, repeats, axis), LIBND4J_TYPES, + LIBND4J_TYPES); } ////////////////////////////////////////////////////////////////////////// @@ -448,9 +495,6 @@ void NDArray::repeat(const int axis, const std::vector& repeats, NDArray& t #include "NDArray.macro" #endif */ -} - - +} // namespace sd #endif - diff --git a/libnd4j/include/array/cpu/NDArrayLambda.hpp b/libnd4j/include/array/cpu/NDArrayLambda.hpp index bd8742288c80..709cd55924c7 100644 --- a/libnd4j/include/array/cpu/NDArrayLambda.hpp +++ b/libnd4j/include/array/cpu/NDArrayLambda.hpp @@ -1,332 +1,473 @@ +template +void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, + const std::function& func, + NDArray& target) { + if (dataType() != DataTypeUtils::fromT()) + throw std::runtime_error( + "NDArray::applyTriplewiseLambda method: wrong template parameter T, " + "its type should be the same as type of this array!"); + if (dataType() != second.dataType() || dataType() != third.dataType() || + dataType() != target.dataType()) + throw std::runtime_error( + "NDArray::applyTriplewiseLambda method: bother four arrays (this, " + "second, third, target) should have the same type !"); + + if (this->lengthOf() != second.lengthOf() || + this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || + !this->isSameShape(third)) { + nd4j_printf( + "applyTriplewiseLambda requires all operands to have the same shape\n", + ""); + throw std::runtime_error("Shapes mismach"); + } + + auto f = this->bufferAsT(); + auto s = second.bufferAsT(); + auto t = third.bufferAsT(); + auto z = target.bufferAsT(); + + if (this->ordering() == second.ordering() && + this->ordering() == third.ordering() && + this->ordering() == target.ordering() && + (this->ews() == 1 && target.ews() == 1) && this->ews() == second.ews() && + this->ews() == third.ews()) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) z[e] = func(f[e], s[e], t[e]); + }; + + samediff::Threads::parallel_for(loop, 0, _length); + } else { + if (f == z) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto tOffset = this->getOffset(e); + auto uOffset = second.getOffset(e); + auto vOffset = third.getOffset(e); + + f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]); + } + }; -template -void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::function& func, NDArray& target) { - - if(dataType() != DataTypeUtils::fromT()) - throw std::runtime_error("NDArray::applyTriplewiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != second.dataType() || dataType() != third.dataType() || dataType() != target.dataType()) - throw std::runtime_error("NDArray::applyTriplewiseLambda method: bother four arrays (this, second, third, target) should have the same type !"); - - if (this->lengthOf() != second.lengthOf() || this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) { - nd4j_printf("applyTriplewiseLambda requires all operands to have the same shape\n",""); - throw std::runtime_error("Shapes mismach"); - } - - auto f = this->bufferAsT(); - auto s = second.bufferAsT(); - auto t = third.bufferAsT(); - auto z = target.bufferAsT(); - - if (this->ordering() == second.ordering() && this->ordering() == third.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == second.ews() && this->ews() == third.ews()) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - z[e] = func(f[e], s[e], t[e]); - }; - - samediff::Threads::parallel_for(loop, 0, _length); + samediff::Threads::parallel_for(loop, 0, _length); } else { - if (f == z) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto tOffset = this->getOffset(e); - auto uOffset = second.getOffset(e); - auto vOffset = third.getOffset(e); - - f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]); - } - }; - - samediff::Threads::parallel_for(loop, 0, _length); - } else { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto tOffset = this->getOffset(e); - auto uOffset = second.getOffset(e); - auto vOffset = third.getOffset(e); - auto zOffset = target.getOffset(e); - - z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]); - } - }; - - samediff::Threads::parallel_for(loop, 0, _length); + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto tOffset = this->getOffset(e); + auto uOffset = second.getOffset(e); + auto vOffset = third.getOffset(e); + auto zOffset = target.getOffset(e); + + z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]); } + }; + + samediff::Threads::parallel_for(loop, 0, _length); } + } } -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); -template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, + NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, NDArray& target); +template void NDArray::applyTriplewiseLambda( + NDArray& second, NDArray& third, + const std::function& func, NDArray& target); ////////////////////////////////////////////////////////////////////////// -template -void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target) { - - if(dataType() != DataTypeUtils::fromT()) - throw std::runtime_error("NDArray::applyPairwiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != other.dataType() || dataType() != target.dataType()) - throw std::runtime_error("NDArray::applyPairwiseLambda method: all three arrays (this, other, target) must have the same type !"); - - if (this->lengthOf() != other.lengthOf()) { - nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n",""); - throw std::runtime_error("Shapes mismach"); - } - - auto f = this->bufferAsT(); - auto s = other.bufferAsT(); - auto z = target.bufferAsT(); - - if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - z[e] = func(f[e], s[e]); - }; +template +void NDArray::applyPairwiseLambda(const NDArray& other, + const std::function& func, + NDArray& target) { + if (dataType() != DataTypeUtils::fromT()) + throw std::runtime_error( + "NDArray::applyPairwiseLambda method: wrong template parameter T, " + "its type should be the same as type of this array!"); + if (dataType() != other.dataType() || dataType() != target.dataType()) + throw std::runtime_error( + "NDArray::applyPairwiseLambda method: all three arrays (this, " + "other, target) must have the same type !"); + + if (this->lengthOf() != other.lengthOf()) { + nd4j_printf( + "applyPairwiseLambda requires both operands to have the same shape\n", + ""); + throw std::runtime_error("Shapes mismach"); + } + + auto f = this->bufferAsT(); + auto s = other.bufferAsT(); + auto z = target.bufferAsT(); + + if (this->ordering() == other.ordering() && + this->ordering() == target.ordering() && + (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) z[e] = func(f[e], s[e]); + }; + + samediff::Threads::parallel_for(loop, 0, _length); + } else { + if (f == z) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + auto yOffset = other.getOffset(e); + + f[xOffset] = func(f[xOffset], s[yOffset]); + } + }; - samediff::Threads::parallel_for(loop, 0, _length); + samediff::Threads::parallel_for(loop, 0, _length); } else { - if (f == z) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - auto yOffset = other.getOffset(e); - - f[xOffset] = func(f[xOffset], s[yOffset]); - } - }; + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + auto yOffset = other.getOffset(e); + auto zOffset = target.getOffset(e); - samediff::Threads::parallel_for(loop, 0, _length); - } else { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - auto yOffset = other.getOffset(e); - auto zOffset = target.getOffset(e); - - z[zOffset] = func(f[xOffset], s[yOffset]); - } - }; - - samediff::Threads::parallel_for(loop, 0, _length); + z[zOffset] = func(f[xOffset], s[yOffset]); } + }; + + samediff::Threads::parallel_for(loop, 0, _length); } + } } -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, + const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, + const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, + const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, + const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, + const std::function& func, NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyPairwiseLambda( + const NDArray& other, const std::function& func, + NDArray& target); ////////////////////////////////////////////////////////////////////////// -template +template void NDArray::applyLambda(const std::function& func, NDArray& target) { + if (dataType() != DataTypeUtils::fromT()) + throw std::runtime_error( + "NDArray::applyLambda method: wrong template parameter T, its type " + "should be the same as type of this array!"); + if (dataType() != target.dataType()) + throw std::runtime_error( + "NDArray::applyLambda method: types of this and target array should " + "match !"); + + auto f = this->bufferAsT(); + auto z = target.bufferAsT(); + + if (this->ordering() == target.ordering() && + (this->ews() == 1 && target.ews() == 1)) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) z[e] = func(f[e]); + }; + + samediff::Threads::parallel_for(loop, 0, _length); + } else { + if (f == z) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + + f[xOffset] = func(f[xOffset]); + } + }; - if(dataType() != DataTypeUtils::fromT()) - throw std::runtime_error("NDArray::applyLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != target.dataType()) - throw std::runtime_error("NDArray::applyLambda method: types of this and target array should match !"); - - auto f = this->bufferAsT(); - auto z = target.bufferAsT(); - - if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - z[e] = func(f[e]); - }; - - samediff::Threads::parallel_for(loop, 0, _length); + samediff::Threads::parallel_for(loop, 0, _length); } else { - if (f == z) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - - f[xOffset] = func(f[xOffset]); - } - }; + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + auto zOffset = target.getOffset(e); - samediff::Threads::parallel_for(loop, 0, _length); - } else { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - auto zOffset = target.getOffset(e); - - z[zOffset] = func(f[xOffset]); - } - }; - - samediff::Threads::parallel_for(loop, 0, _length); + z[zOffset] = func(f[xOffset]); } + }; + + samediff::Threads::parallel_for(loop, 0, _length); } + } } -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); -template void NDArray::applyLambda(const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); +template void NDArray::applyLambda( + const std::function& func, NDArray& target); +template void NDArray::applyLambda( + const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); +template void NDArray::applyLambda( + const std::function& func, NDArray& target); +template void NDArray::applyLambda( + const std::function& func, NDArray& target); +template void NDArray::applyLambda( + const std::function& func, NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); +template void NDArray::applyLambda(const std::function& func, + NDArray& target); ////////////////////////////////////////////////////////////////////////// -template -void NDArray::applyIndexedLambda(const std::function& func, NDArray& target) { - - if(dataType() != DataTypeUtils::fromT()) - throw std::runtime_error("NDArray::applyIndexedLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != target.dataType()) - throw std::runtime_error("NDArray::applyIndexedLambda method: types of this and target array should match !"); - - auto f = this->bufferAsT(); - auto z = target.bufferAsT(); - - if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - z[e] = func(e, f[e]); - }; +template +void NDArray::applyIndexedLambda(const std::function& func, + NDArray& target) { + if (dataType() != DataTypeUtils::fromT()) + throw std::runtime_error( + "NDArray::applyIndexedLambda method: wrong template parameter T, " + "its type should be the same as type of this array!"); + if (dataType() != target.dataType()) + throw std::runtime_error( + "NDArray::applyIndexedLambda method: types of this and target array " + "should match !"); + + auto f = this->bufferAsT(); + auto z = target.bufferAsT(); + + if (this->ordering() == target.ordering() && + (this->ews() == 1 && target.ews() == 1)) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) z[e] = func(e, f[e]); + }; + + samediff::Threads::parallel_for(loop, 0, _length); + } else { + if (f == z) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + + f[xOffset] = func(e, f[xOffset]); + } + }; - samediff::Threads::parallel_for(loop, 0, _length); + samediff::Threads::parallel_for(loop, 0, _length); } else { - if (f == z) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - - f[xOffset] = func(e, f[xOffset]); - } - }; + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + auto zOffset = target.getOffset(e); - samediff::Threads::parallel_for(loop, 0, _length); - } else { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - auto zOffset = target.getOffset(e); - - z[zOffset] = func(e, f[xOffset]); - } - }; - - samediff::Threads::parallel_for(loop, 0, _length); + z[zOffset] = func(e, f[xOffset]); } + }; + + samediff::Threads::parallel_for(loop, 0, _length); } + } } -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); +template void NDArray::applyIndexedLambda( + const std::function& func, NDArray& target); ////////////////////////////////////////////////////////////////////////// -template -void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target) { - - if(dataType() != DataTypeUtils::fromT()) - throw std::runtime_error("NDArray::applyIndexedPairwiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != target.dataType()) - throw std::runtime_error("NDArray::applyIndexedPairwiseLambda method: types of this and target array should match !"); - if (this->lengthOf() != other.lengthOf()) { - nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n",""); - throw std::runtime_error("Shapes mismach"); - } - - auto f = this->bufferAsT(); - auto s = other.bufferAsT(); - auto z = target.bufferAsT(); - - if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - z[e] = func((Nd4jLong) e, f[e], s[e]); - }; +template +void NDArray::applyIndexedPairwiseLambda( + NDArray& other, const std::function& func, + NDArray& target) { + if (dataType() != DataTypeUtils::fromT()) + throw std::runtime_error( + "NDArray::applyIndexedPairwiseLambda method: wrong template " + "parameter T, its type should be the same as type of this array!"); + if (dataType() != target.dataType()) + throw std::runtime_error( + "NDArray::applyIndexedPairwiseLambda method: types of this and " + "target array should match !"); + if (this->lengthOf() != other.lengthOf()) { + nd4j_printf( + "applyIndexedPairwiseLambda requires both operands to have the same " + "shape\n", + ""); + throw std::runtime_error("Shapes mismach"); + } + + auto f = this->bufferAsT(); + auto s = other.bufferAsT(); + auto z = target.bufferAsT(); + + if (this->ordering() == other.ordering() && + this->ordering() == target.ordering() && + (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) z[e] = func((Nd4jLong)e, f[e], s[e]); + }; + + samediff::Threads::parallel_for(loop, 0, _length); + } else { + if (f == z) { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + auto yOffset = other.getOffset(e); + + f[xOffset] = func((Nd4jLong)e, f[xOffset], s[yOffset]); + } + }; - samediff::Threads::parallel_for(loop, 0, _length); + samediff::Threads::parallel_for(loop, 0, _length); } else { - if (f == z) { - - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - auto yOffset = other.getOffset(e); - - f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]); - } - }; - - samediff::Threads::parallel_for(loop, 0, _length); - } else { + auto loop = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto xOffset = this->getOffset(e); + auto yOffset = other.getOffset(e); + auto zOffset = target.getOffset(e); - auto loop = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto xOffset = this->getOffset(e); - auto yOffset = other.getOffset(e); - auto zOffset = target.getOffset(e); - - z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]); - } - }; - - samediff::Threads::parallel_for(loop, 0, _length); + z[zOffset] = func((Nd4jLong)e, f[xOffset], s[yOffset]); } + }; + + samediff::Threads::parallel_for(loop, 0, _length); } + } } -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); -template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function& func, NDArray& target); \ No newline at end of file +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, + const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, const std::function& func, + NDArray& target); +template void NDArray::applyIndexedPairwiseLambda( + NDArray& other, const std::function& func, + NDArray& target); \ No newline at end of file diff --git a/libnd4j/include/array/cuda/DataBuffer.cu b/libnd4j/include/array/cuda/DataBuffer.cu index 9660d38669e7..81ba2f98b1bf 100644 --- a/libnd4j/include/array/cuda/DataBuffer.cu +++ b/libnd4j/include/array/cuda/DataBuffer.cu @@ -19,278 +19,321 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include "../DataBuffer.h" #include -#include +#include #include #include #include -#include +#include + +#include "../DataBuffer.h" namespace sd { - void DataBuffer::expand(const uint64_t size) { - if (size > _lenInBytes) { - // allocate new buffer - int8_t *newBuffer = nullptr; - int8_t *newSpecialBuffer = nullptr; - ALLOCATE_SPECIAL(newSpecialBuffer, _workspace, size, int8_t); - - // copy data from existing buffer - if (_primaryBuffer != nullptr) { - // there's non-zero chance that primary buffer doesn't exist yet - ALLOCATE(newBuffer, _workspace, size, int8_t); - std::memcpy(newBuffer, _primaryBuffer, _lenInBytes); - - if (_isOwnerPrimary) { - auto ipb = reinterpret_cast(_primaryBuffer); - RELEASE(ipb, _workspace); - } - - _primaryBuffer = newBuffer; - _isOwnerPrimary = true; - } - - cudaMemcpy(newSpecialBuffer, _specialBuffer, _lenInBytes, cudaMemcpyDeviceToDevice); - - if (_isOwnerSpecial) { - auto isb = reinterpret_cast(_specialBuffer); - RELEASE_SPECIAL(isb, _workspace); - } - - _specialBuffer = newSpecialBuffer; - _lenInBytes = size; - _isOwnerSpecial = true; - } +void DataBuffer::expand(const uint64_t size) { + if (size > _lenInBytes) { + // allocate new buffer + int8_t* newBuffer = nullptr; + int8_t* newSpecialBuffer = nullptr; + ALLOCATE_SPECIAL(newSpecialBuffer, _workspace, size, int8_t); + + // copy data from existing buffer + if (_primaryBuffer != nullptr) { + // there's non-zero chance that primary buffer doesn't exist yet + ALLOCATE(newBuffer, _workspace, size, int8_t); + std::memcpy(newBuffer, _primaryBuffer, _lenInBytes); + + if (_isOwnerPrimary) { + auto ipb = reinterpret_cast(_primaryBuffer); + RELEASE(ipb, _workspace); + } + + _primaryBuffer = newBuffer; + _isOwnerPrimary = true; } -//////////////////////////////////////////////////////////////////////// -void DataBuffer::allocateSpecial() { - - if (_specialBuffer == nullptr && getLenInBytes() > 0) { - auto deviceId = sd::AffinityManager::currentDeviceId(); - - if (_workspace == nullptr) - if (!sd::memory::MemoryCounter::getInstance()->validate(getLenInBytes())) - throw sd::allocation_exception::build("Requested amount exceeds device limits", sd::memory::MemoryCounter::getInstance()->deviceLimit(deviceId), getLenInBytes()); - + cudaMemcpy(newSpecialBuffer, _specialBuffer, _lenInBytes, + cudaMemcpyDeviceToDevice); - ALLOCATE_SPECIAL(_specialBuffer, _workspace, getLenInBytes(), int8_t); - _isOwnerSpecial = true; - - if (_workspace == nullptr) { - sd::memory::MemoryCounter::getInstance()->countIn(deviceId, getLenInBytes()); - sd::memory::MemoryCounter::getInstance()->countIn(sd::memory::MemoryType::DEVICE, getLenInBytes()); - } + if (_isOwnerSpecial) { + auto isb = reinterpret_cast(_specialBuffer); + RELEASE_SPECIAL(isb, _workspace); } + + _specialBuffer = newSpecialBuffer; + _lenInBytes = size; + _isOwnerSpecial = true; + } } //////////////////////////////////////////////////////////////////////// -void DataBuffer::syncToPrimary(const LaunchContext* context, const bool forceSync) { - if(isPrimaryActual() && !forceSync) { - return; - } - - allocatePrimary(); +void DataBuffer::allocateSpecial() { + if (_specialBuffer == nullptr && getLenInBytes() > 0) { + auto deviceId = sd::AffinityManager::currentDeviceId(); - auto res = cudaStreamSynchronize(*context->getCudaStream()); - if (res != 0) - throw cuda_exception::build("DataBuffer::syncToPrimary failed to to some previous kernel failre", res); + if (_workspace == nullptr) + if (!sd::memory::MemoryCounter::getInstance()->validate(getLenInBytes())) + throw sd::allocation_exception::build( + "Requested amount exceeds device limits", + sd::memory::MemoryCounter::getInstance()->deviceLimit(deviceId), + getLenInBytes()); - res = cudaMemcpy(_primaryBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToHost); - if (res != 0) - throw cuda_exception::build("DataBuffer::syncToPrimary cudaMemcpy failed", res); + ALLOCATE_SPECIAL(_specialBuffer, _workspace, getLenInBytes(), int8_t); + _isOwnerSpecial = true; - readPrimary(); + if (_workspace == nullptr) { + sd::memory::MemoryCounter::getInstance()->countIn(deviceId, + getLenInBytes()); + sd::memory::MemoryCounter::getInstance()->countIn( + sd::memory::MemoryType::DEVICE, getLenInBytes()); + } + } } +//////////////////////////////////////////////////////////////////////// +void DataBuffer::syncToPrimary(const LaunchContext* context, + const bool forceSync) { + if (isPrimaryActual() && !forceSync) { + return; + } + + allocatePrimary(); + + auto res = cudaStreamSynchronize(*context->getCudaStream()); + if (res != 0) + throw cuda_exception::build( + "DataBuffer::syncToPrimary failed to to some previous kernel failre", + res); + + res = cudaMemcpy(_primaryBuffer, _specialBuffer, getLenInBytes(), + cudaMemcpyDeviceToHost); + if (res != 0) + throw cuda_exception::build("DataBuffer::syncToPrimary cudaMemcpy failed", + res); + + readPrimary(); +} //////////////////////////////////////////////////////////////////////// void DataBuffer::syncToSpecial(const bool forceSync) { - // in this case there's nothing to do here - if (_primaryBuffer == nullptr) - return; + // in this case there's nothing to do here + if (_primaryBuffer == nullptr) return; - if(isSpecialActual() && !forceSync) { - return; - } + if (isSpecialActual() && !forceSync) { + return; + } - allocateSpecial(); + allocateSpecial(); - auto res = cudaMemcpy(_specialBuffer, _primaryBuffer, getLenInBytes(), cudaMemcpyHostToDevice); - if (res != 0) - throw cuda_exception::build("DataBuffer::syncToSpecial cudaMemcpy failed", res); + auto res = cudaMemcpy(_specialBuffer, _primaryBuffer, getLenInBytes(), + cudaMemcpyHostToDevice); + if (res != 0) + throw cuda_exception::build("DataBuffer::syncToSpecial cudaMemcpy failed", + res); - readSpecial(); + readSpecial(); } - //////////////////////////////////////////////////////////////////////// void DataBuffer::deleteSpecial() { - - if(_isOwnerSpecial && _specialBuffer != nullptr && getLenInBytes() != 0) { - auto p = reinterpret_cast(_specialBuffer); - RELEASE_SPECIAL(p, _workspace); - _specialBuffer = nullptr; - _isOwnerSpecial = false; - - // count out towards DataBuffer device, only if we're not in workspace - if (_workspace == nullptr) { - sd::memory::MemoryCounter::getInstance()->countOut(_deviceId, getLenInBytes()); - sd::memory::MemoryCounter::getInstance()->countOut(sd::memory::MemoryType::DEVICE, getLenInBytes()); - } + if (_isOwnerSpecial && _specialBuffer != nullptr && getLenInBytes() != 0) { + auto p = reinterpret_cast(_specialBuffer); + RELEASE_SPECIAL(p, _workspace); + _specialBuffer = nullptr; + _isOwnerSpecial = false; + + // count out towards DataBuffer device, only if we're not in workspace + if (_workspace == nullptr) { + sd::memory::MemoryCounter::getInstance()->countOut(_deviceId, + getLenInBytes()); + sd::memory::MemoryCounter::getInstance()->countOut( + sd::memory::MemoryType::DEVICE, getLenInBytes()); } + } } - //////////////////////////////////////////////////////////////////////// void DataBuffer::setCountersToZero() { - - _counter.store(0L); - _writePrimary.store(0L); - _writeSpecial.store(0L); - _readPrimary.store(0L); - _readSpecial.store(0L); + _counter.store(0L); + _writePrimary.store(0L); + _writeSpecial.store(0L); + _readPrimary.store(0L); + _readSpecial.store(0L); } //////////////////////////////////////////////////////////////////////// void DataBuffer::copyCounters(const DataBuffer& other) { - - _counter.store(other._counter); - _writePrimary.store(other._readSpecial); - _writeSpecial.store(other._readPrimary); - _readPrimary.store(other._writeSpecial); - _readSpecial.store(other._writePrimary); + _counter.store(other._counter); + _writePrimary.store(other._readSpecial); + _writeSpecial.store(other._readPrimary); + _readPrimary.store(other._writeSpecial); + _readSpecial.store(other._writePrimary); } //////////////////////////////////////////////////////////////////////// -void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { // copies only to special buffer - - if(other._primaryBuffer == nullptr && other._specialBuffer == nullptr) - return; - - if(sizeToCopyinBytes == 0) - sizeToCopyinBytes = other.getLenInBytes(); - if(sizeToCopyinBytes == 0) - return; - - if(other.isPrimaryActual()) { - auto res = cudaMemcpy(static_cast(_specialBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(other._primaryBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes, cudaMemcpyHostToDevice); - if (res != 0) - throw cuda_exception::build("DataBuffer::copyBufferFrom: cudaMemcpy_cudaMemcpyHostToDevice failed!", res); - other.readPrimary(); - } - else { - auto res = cudaMemcpy(static_cast(_specialBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(other._specialBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes, cudaMemcpyDeviceToDevice); - if (res != 0) - throw cuda_exception::build("DataBuffer::copyBufferFrom: cudaMemcpy_cudaMemcpyDeviceToDevice failed!", res); - other.readSpecial(); - } - - writeSpecial(); +void DataBuffer::copyBufferFrom( + const DataBuffer& other, size_t sizeToCopyinBytes, + const Nd4jLong offsetThis, + const Nd4jLong offsetOther) { // copies only to special buffer + + if (other._primaryBuffer == nullptr && other._specialBuffer == nullptr) + return; + + if (sizeToCopyinBytes == 0) sizeToCopyinBytes = other.getLenInBytes(); + if (sizeToCopyinBytes == 0) return; + + if (other.isPrimaryActual()) { + auto res = cudaMemcpy( + static_cast(_specialBuffer) + + offsetThis * DataTypeUtils::sizeOfElement(_dataType), + static_cast(other._primaryBuffer) + + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), + sizeToCopyinBytes, cudaMemcpyHostToDevice); + if (res != 0) + throw cuda_exception::build( + "DataBuffer::copyBufferFrom: cudaMemcpy_cudaMemcpyHostToDevice " + "failed!", + res); + other.readPrimary(); + } else { + auto res = cudaMemcpy( + static_cast(_specialBuffer) + + offsetThis * DataTypeUtils::sizeOfElement(_dataType), + static_cast(other._specialBuffer) + + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), + sizeToCopyinBytes, cudaMemcpyDeviceToDevice); + if (res != 0) + throw cuda_exception::build( + "DataBuffer::copyBufferFrom: cudaMemcpy_cudaMemcpyDeviceToDevice " + "failed!", + res); + other.readSpecial(); + } + + writeSpecial(); } //////////////////////////////////////////////////////////////////////// -void DataBuffer::copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetHostBuffer) { // copies only to special buffer - - if(hostBuffer == nullptr) - return; - - if(sizeToCopyinBytes == 0) - sizeToCopyinBytes = getLenInBytes(); - if(sizeToCopyinBytes == 0) - return; - - auto res = cudaMemcpy(static_cast(_specialBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(hostBuffer) + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), sizeToCopyinBytes, cudaMemcpyHostToDevice); - if (res != 0) - throw cuda_exception::build("DataBuffer::copyBufferFromHost: cudaMemcpy_cudaMemcpyHostToDevice failed!", res); - - writeSpecial(); +void DataBuffer::copyBufferFromHost( + const void* hostBuffer, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, + const Nd4jLong offsetHostBuffer) { // copies only to special buffer + + if (hostBuffer == nullptr) return; + + if (sizeToCopyinBytes == 0) sizeToCopyinBytes = getLenInBytes(); + if (sizeToCopyinBytes == 0) return; + + auto res = + cudaMemcpy(static_cast(_specialBuffer) + + offsetThis * DataTypeUtils::sizeOfElement(_dataType), + static_cast(hostBuffer) + + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), + sizeToCopyinBytes, cudaMemcpyHostToDevice); + if (res != 0) + throw cuda_exception::build( + "DataBuffer::copyBufferFromHost: cudaMemcpy_cudaMemcpyHostToDevice " + "failed!", + res); + + writeSpecial(); } //////////////////////////////////////////////////////////////////////// void DataBuffer::setSpecial(void* special, const bool isOwnerSpecial) { - - deleteSpecial(); - _specialBuffer = special; - _isOwnerSpecial = isOwnerSpecial; + deleteSpecial(); + _specialBuffer = special; + _isOwnerSpecial = isOwnerSpecial; } - //////////////////////////////////////////////////////////////////////// -void DataBuffer::allocateBuffers(const bool allocBoth) { // always allocate special buffer only (cuda case) +void DataBuffer::allocateBuffers( + const bool allocBoth) { // always allocate special buffer only (cuda case) - allocateSpecial(); + allocateSpecial(); - if(allocBoth) - allocatePrimary(); + if (allocBoth) allocatePrimary(); } //////////////////////////////////////////////////////////////////////// void DataBuffer::setToZeroBuffers(const bool both) { - cudaMemsetAsync(special(), 0, getLenInBytes(), *LaunchContext::defaultContext()->getCudaStream()); - auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); - if (res != 0) - throw cuda_exception::build("DataBuffer::setToZeroBuffers: streamSync failed!", res); - - writeSpecial(); - - if(both) { - memset(primary(), 0, getLenInBytes()); - readPrimary(); - } + cudaMemsetAsync(special(), 0, getLenInBytes(), + *LaunchContext::defaultContext()->getCudaStream()); + auto res = + cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); + if (res != 0) + throw cuda_exception::build( + "DataBuffer::setToZeroBuffers: streamSync failed!", res); + + writeSpecial(); + + if (both) { + memset(primary(), 0, getLenInBytes()); + readPrimary(); + } } ///////////////////////// -void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { - if (src._lenInBytes > dst._lenInBytes) - throw std::runtime_error("DataBuffer::memcpy: Source data buffer is larger than destination"); - - - int res = 0; - if (src.isSpecialActual()) { - res = cudaMemcpyAsync(dst._specialBuffer, src._specialBuffer, src.getLenInBytes(), cudaMemcpyDeviceToDevice, *LaunchContext::defaultContext()->getCudaStream()); - } else if (src.isPrimaryActual()) { - res = cudaMemcpyAsync(dst._specialBuffer, src._primaryBuffer, src.getLenInBytes(), cudaMemcpyHostToDevice, *LaunchContext::defaultContext()->getCudaStream()); - } - - if (res != 0) - throw cuda_exception::build("DataBuffer::memcpy: cudaMemcpyAsync failed!", res); - - res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); - if (res != 0) - throw cuda_exception::build("DataBuffer::memcpy: streamSync failed!", res); - - dst.writeSpecial(); +void DataBuffer::memcpy(const DataBuffer& dst, const DataBuffer& src) { + if (src._lenInBytes > dst._lenInBytes) + throw std::runtime_error( + "DataBuffer::memcpy: Source data buffer is larger than destination"); + + int res = 0; + if (src.isSpecialActual()) { + res = cudaMemcpyAsync(dst._specialBuffer, src._specialBuffer, + src.getLenInBytes(), cudaMemcpyDeviceToDevice, + *LaunchContext::defaultContext()->getCudaStream()); + } else if (src.isPrimaryActual()) { + res = cudaMemcpyAsync(dst._specialBuffer, src._primaryBuffer, + src.getLenInBytes(), cudaMemcpyHostToDevice, + *LaunchContext::defaultContext()->getCudaStream()); + } + + if (res != 0) + throw cuda_exception::build("DataBuffer::memcpy: cudaMemcpyAsync failed!", + res); + + res = + cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); + if (res != 0) + throw cuda_exception::build("DataBuffer::memcpy: streamSync failed!", res); + + dst.writeSpecial(); } //////////////////////////////////////////////////////////////////////// void DataBuffer::migrate() { - memory::Workspace* newWorkspace = nullptr; - void* newBuffer; - ALLOCATE_SPECIAL(newBuffer, newWorkspace, getLenInBytes(), int8_t); - auto res = cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToDevice); - if (res != 0) - throw cuda_exception::build("DataBuffer::migrate: cudaMemcpyAsync failed!", res); - - if (_isOwnerSpecial) { - // now we're releasing original buffer - RELEASE_SPECIAL(_specialBuffer, _workspace); - } - - _isOwnerSpecial = true; - _specialBuffer = newBuffer; + memory::Workspace* newWorkspace = nullptr; + void* newBuffer; + ALLOCATE_SPECIAL(newBuffer, newWorkspace, getLenInBytes(), int8_t); + auto res = cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), + cudaMemcpyDeviceToDevice); + if (res != 0) + throw cuda_exception::build("DataBuffer::migrate: cudaMemcpyAsync failed!", + res); + + if (_isOwnerSpecial) { + // now we're releasing original buffer + RELEASE_SPECIAL(_specialBuffer, _workspace); + } + + _isOwnerSpecial = true; + _specialBuffer = newBuffer; } - void *DataBuffer::platform() { - return special(); - } +void* DataBuffer::platform() { return special(); } //////////////////////////////////////////////////////////////////////// -void DataBuffer::writePrimary() const {_writePrimary = ++_counter; } -void DataBuffer::writeSpecial() const { _writeSpecial = ++_counter; } -void DataBuffer::readPrimary() const { _readPrimary = ++_counter; } -void DataBuffer::readSpecial() const { _readSpecial = ++_counter; } -bool DataBuffer::isPrimaryActual() const { return (_writePrimary.load() > _writeSpecial.load() || _readPrimary.load() > _writeSpecial.load()); } -bool DataBuffer::isSpecialActual() const { return (_writeSpecial.load() > _writePrimary.load() || _readSpecial.load() > _writePrimary.load()); } - +void DataBuffer::writePrimary() const { _writePrimary = ++_counter; } +void DataBuffer::writeSpecial() const { _writeSpecial = ++_counter; } +void DataBuffer::readPrimary() const { _readPrimary = ++_counter; } +void DataBuffer::readSpecial() const { _readSpecial = ++_counter; } +bool DataBuffer::isPrimaryActual() const { + return (_writePrimary.load() > _writeSpecial.load() || + _readPrimary.load() > _writeSpecial.load()); } +bool DataBuffer::isSpecialActual() const { + return (_writeSpecial.load() > _writePrimary.load() || + _readSpecial.load() > _writePrimary.load()); +} + +} // namespace sd diff --git a/libnd4j/include/array/cuda/NDArray.cu b/libnd4j/include/array/cuda/NDArray.cu index 88c016356174..e9a39a3314f5 100644 --- a/libnd4j/include/array/cuda/NDArray.cu +++ b/libnd4j/include/array/cuda/NDArray.cu @@ -19,556 +19,672 @@ #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include -#include +#include +#include +#include +#include +#include #include -#include +#include #include +#include +#include #include -#include + #include -#include -#include -#include -#include -#include -#include -#include -#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include namespace sd { -void* NDArray::platformBuffer() { return specialBuffer(); } -void const* NDArray::platformBuffer() const { return specialBuffer(); } +void* NDArray::platformBuffer() { return specialBuffer(); } +void const* NDArray::platformBuffer() const { return specialBuffer(); } -Nd4jLong const* NDArray::platformShapeInfo() const { return specialShapeInfo(); } -//Nd4jLong const* NDArray::platformShapeInfo() { return specialShapeInfo(); } +Nd4jLong const* NDArray::platformShapeInfo() const { + return specialShapeInfo(); +} +// Nd4jLong const* NDArray::platformShapeInfo() { return +// specialShapeInfo(); } -void NDArray::syncToDevice() const { - auto currentDeviceId = AffinityManager::currentDeviceId(); - if (currentDeviceId != _deviceId) { - // first of all we update shapeInfo - const_cast(this)->setShapeInfo(this->shapeInfo()); +void NDArray::syncToDevice() const { + auto currentDeviceId = AffinityManager::currentDeviceId(); + if (currentDeviceId != _deviceId) { + // first of all we update shapeInfo + const_cast(this)->setShapeInfo(this->shapeInfo()); - // now we actually migrate data buffer - _buffer->migrate(); - } + // now we actually migrate data buffer + _buffer->migrate(); + } - _buffer->syncToSpecial(); + _buffer->syncToSpecial(); } -void NDArray::syncToHost() const { _buffer->syncToPrimary(getContext()); } -void NDArray::tickWriteHost() const { _buffer->writePrimary(); } -void NDArray::tickWriteDevice() const { _buffer->writeSpecial(); } -void NDArray::tickReadHost() const { _buffer->readPrimary(); } -void NDArray::tickReadDevice() const { _buffer->readSpecial(); } -void NDArray::tickBothActual() const { _buffer->writePrimary(); _buffer->readSpecial(); } -bool NDArray::isActualOnHostSide() const { return _buffer->isPrimaryActual(); } -bool NDArray::isActualOnDeviceSide() const { return _buffer->isSpecialActual(); } -void NDArray::makeBothBuffersActual() const { if(!isActualOnHostSide()) syncToHost(); if(!isActualOnDeviceSide()) syncToDevice(); } +void NDArray::syncToHost() const { _buffer->syncToPrimary(getContext()); } +void NDArray::tickWriteHost() const { _buffer->writePrimary(); } +void NDArray::tickWriteDevice() const { _buffer->writeSpecial(); } +void NDArray::tickReadHost() const { _buffer->readPrimary(); } +void NDArray::tickReadDevice() const { _buffer->readSpecial(); } +void NDArray::tickBothActual() const { + _buffer->writePrimary(); + _buffer->readSpecial(); +} +bool NDArray::isActualOnHostSide() const { return _buffer->isPrimaryActual(); } +bool NDArray::isActualOnDeviceSide() const { + return _buffer->isSpecialActual(); +} +void NDArray::makeBothBuffersActual() const { + if (!isActualOnHostSide()) syncToHost(); + if (!isActualOnDeviceSide()) syncToDevice(); +} /////////////////////////////////////////////////////////////////// -template -__global__ static void fillAsTriangularCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const T val, const int lower, const int upper) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ int zRank, xRank, areSameOffsets, *sharedMem; // xRank == zRank always, except when xRank = 1, in this case zRank = 2 - __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - xRank = shape::rank(xShapeInfo); - zRank = shape::rank(zShapeInfo); - zLen = shape::length(zShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * zRank; - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { - - shape::index2coords(i, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - // if( (row + upper < col) || (row + lower > col) ) - if((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1])) - z[zOffset] = val; - else if(vx != vz) { // when x and z are different arrays - if(xRank != zRank) - coords[0] = coords[1]; - const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(xShapeInfo, coords); - z[zOffset] = x[xOffset]; - } +template +__global__ static void fillAsTriangularCuda( + const void* vx, const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const T val, const int lower, const int upper) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ int zRank, xRank, areSameOffsets, + *sharedMem; // xRank == zRank always, except when xRank = 1, in this case + // zRank = 2 + __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = + // 1, in this case zLen = 2*xLen + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + xRank = shape::rank(xShapeInfo); + zRank = shape::rank(zShapeInfo); + zLen = shape::length(zShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); + + auto coords = sharedMem + threadIdx.x * zRank; + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); + + // if( (row + upper < col) || (row + lower > col) ) + if ((coords[zRank - 2] + upper < coords[zRank - 1]) || + (coords[zRank - 2] + lower > coords[zRank - 1])) + z[zOffset] = val; + else if (vx != vz) { // when x and z are different arrays + if (xRank != zRank) coords[0] = coords[1]; + const auto xOffset = + areSameOffsets ? zOffset : shape::getOffset(xShapeInfo, coords); + z[zOffset] = x[xOffset]; } + } } /////////////////////////////////////////////////////////////////// -template -void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& target, const char direction) { - - if (isS()) - throw std::runtime_error("NDArray::fillAsTriangular: you can't use this method on String array!"); - - if(!isSameShape(target) && !(rankOf() == 1 && target.rankOf() == 2 && sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1))) - throw std::string("NDArray::fillAsTriangular method: wrong shape of target array !"); - - if (direction == 'u') - lower = -target.sizeAt(-2); - else if (direction == 'l') - upper = target.sizeAt(-1); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (target.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * target.rankOf() + 128; - - PointersManager manager(getContext(), "NDArray::fillAsTriangular"); - - NDArray::prepareSpecialUse({&target}, {this}); - fillAsTriangularCuda<<getCudaStream()>>>(platformBuffer(), platformShapeInfo(), target.platformBuffer(), target.platformShapeInfo(), static_cast(val), lower, upper); - NDArray::registerSpecialUse({&target}, {this}); - - manager.synchronize(); +template +void NDArray::fillAsTriangular(const float val, int lower, int upper, + NDArray& target, const char direction) { + if (isS()) + throw std::runtime_error( + "NDArray::fillAsTriangular: you can't use this method on String " + "array!"); + + if (!isSameShape(target) && + !(rankOf() == 1 && target.rankOf() == 2 && + sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1))) + throw std::string( + "NDArray::fillAsTriangular method: wrong shape of target array !"); + + if (direction == 'u') + lower = -target.sizeAt(-2); + else if (direction == 'l') + upper = target.sizeAt(-1); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (target.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * target.rankOf() + 128; + + PointersManager manager(getContext(), "NDArray::fillAsTriangular"); + + NDArray::prepareSpecialUse({&target}, {this}); + fillAsTriangularCuda<<getCudaStream()>>>( + platformBuffer(), platformShapeInfo(), target.platformBuffer(), + target.platformShapeInfo(), static_cast(val), lower, upper); + NDArray::registerSpecialUse({&target}, {this}); + + manager.synchronize(); } -BUILD_SINGLE_TEMPLATE(template SD_EXPORT void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void NDArray::fillAsTriangular, + (const float val, int lower, int upper, NDArray& target, + const char direction), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// -template -__global__ static void identityMatrixCuda(void* vx, const Nd4jLong* xShapeInfo, const T val) { - - auto x = reinterpret_cast(vx); - - __shared__ int rank, *sharedMem; - __shared__ Nd4jLong len, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - rank = shape::rank(xShapeInfo); - len = shape::length(xShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * rank; - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < len; i += totalThreads) { - - shape::index2coords(i, xShapeInfo, coords); - const auto offset = shape::getOffset(xShapeInfo, coords); - - if(coords[rank - 2] == coords[rank - 1]) // row == col -> on diagonal - x[offset] = val; - else - x[offset] = static_cast(0); - } +template +__global__ static void identityMatrixCuda(void* vx, const Nd4jLong* xShapeInfo, + const T val) { + auto x = reinterpret_cast(vx); + + __shared__ int rank, *sharedMem; + __shared__ Nd4jLong len, totalThreads; // xLen == zLen, except when xRank = + // 1, in this case zLen = 2*xLen + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + rank = shape::rank(xShapeInfo); + len = shape::length(xShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); + + auto coords = sharedMem + threadIdx.x * rank; + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < len; i += totalThreads) { + shape::index2coords(i, xShapeInfo, coords); + const auto offset = shape::getOffset(xShapeInfo, coords); + + if (coords[rank - 2] == coords[rank - 1]) // row == col -> on diagonal + x[offset] = val; + else + x[offset] = static_cast(0); + } } /////////////////////////////////////////////////////////////////// -template -static void identityMatrixCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, void* vx, const Nd4jLong *xShapeInfo, const float val) { - - identityMatrixCuda<<>>(vx, xShapeInfo, static_cast(val)); +template +static void identityMatrixCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + const int sharedMem, + const cudaStream_t* stream, void* vx, + const Nd4jLong* xShapeInfo, + const float val) { + identityMatrixCuda<<>>( + vx, xShapeInfo, static_cast(val)); } -BUILD_SINGLE_TEMPLATE(template void identityMatrixCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, void* vx, const Nd4jLong *xShapeInfo, const float val), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void identityMatrixCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + void* vx, const Nd4jLong* xShapeInfo, const float val), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void NDArray::setIdentity() { - if (isS()) - throw std::runtime_error("NDArray::setIdentity: you can't use this method on String array!"); - - // if (rankOf() != 2) - // throw std::runtime_error("NDArray::setIdentity: method should work only for 2D tensors. But " + toStringValue(rankOf()) + " was given."); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * rankOf() + 128; - - PointersManager manager(getContext(), "NDArray::setIdentity"); - - syncToDevice(); - BUILD_SINGLE_SELECTOR(dataType(), identityMatrixCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), platformBuffer(), platformShapeInfo(), 1.f), LIBND4J_TYPES); - tickWriteDevice(); - - manager.synchronize(); + if (isS()) + throw std::runtime_error( + "NDArray::setIdentity: you can't use this method on String array!"); + + // if (rankOf() != 2) + // throw std::runtime_error("NDArray::setIdentity: method should work only + // for 2D tensors. But " + toStringValue(rankOf()) + " was given."); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * rankOf() + 128; + + PointersManager manager(getContext(), "NDArray::setIdentity"); + + syncToDevice(); + BUILD_SINGLE_SELECTOR( + dataType(), identityMatrixCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), + platformBuffer(), platformShapeInfo(), 1.f), + LIBND4J_TYPES); + tickWriteDevice(); + + manager.synchronize(); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// void NDArray::swapUnsafe(NDArray& other) { - auto xType = this->dataType(); - - if (xType != other.dataType()) - throw std::runtime_error("NDArray::swapUnsage method: both arrays must have the same data type"); - - if(specialBuffer() == nullptr || other.specialBuffer() == nullptr) - throw std::runtime_error("NDArray::swapUnsafe method: input array should not be empty!"); - - if(lengthOf() != other.lengthOf()) - throw std::runtime_error("NDArray::swapUnsafe method: input arrays should have the same length!"); - - BUILD_SINGLE_SELECTOR(xType, templatedSwapUnsafe, (specialBuffer(), specialShapeInfo(), other.specialBuffer(), other.specialShapeInfo(), getContext()->getCudaStream()), LIBND4J_TYPES); + auto xType = this->dataType(); + + if (xType != other.dataType()) + throw std::runtime_error( + "NDArray::swapUnsage method: both arrays must have the same data type"); + + if (specialBuffer() == nullptr || other.specialBuffer() == nullptr) + throw std::runtime_error( + "NDArray::swapUnsafe method: input array should not be empty!"); + + if (lengthOf() != other.lengthOf()) + throw std::runtime_error( + "NDArray::swapUnsafe method: input arrays should have the same " + "length!"); + + BUILD_SINGLE_SELECTOR( + xType, templatedSwapUnsafe, + (specialBuffer(), specialShapeInfo(), other.specialBuffer(), + other.specialShapeInfo(), getContext()->getCudaStream()), + LIBND4J_TYPES); } //////////////////////////////////////////////////////////////////////// void NDArray::synchronize(const char* msg) const { - auto res = cudaStreamSynchronize(*(getContext()->getCudaStream())); - if (res != 0) - throw std::runtime_error(msg + std::string(": synchronization failed !")); + auto res = cudaStreamSynchronize(*(getContext()->getCudaStream())); + if (res != 0) + throw std::runtime_error(msg + std::string(": synchronization failed !")); } //////////////////////////////////////////////////////////////////////// -void NDArray::prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { - - for (const auto& a : readList) - if(a != nullptr) - a->syncToDevice(); - - for (const auto& a : writeList) { - if (a != nullptr) { - a->getDataBuffer()->allocateSpecial(); - if (synchronizeWritables) - a->syncToDevice(); - } +void NDArray::prepareSpecialUse(const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables) { + for (const auto& a : readList) + if (a != nullptr) a->syncToDevice(); + + for (const auto& a : writeList) { + if (a != nullptr) { + a->getDataBuffer()->allocateSpecial(); + if (synchronizeWritables) a->syncToDevice(); } + } } //////////////////////////////////////////////////////////////////////// -void NDArray::registerSpecialUse(const std::vector& writeList, const std::vector& readList) { - - for (const auto& p : readList) - if(p != nullptr) - p->tickReadDevice(); +void NDArray::registerSpecialUse(const std::vector& writeList, + const std::vector& readList) { + for (const auto& p : readList) + if (p != nullptr) p->tickReadDevice(); - for (const auto& p : writeList) - if (p != nullptr) - p->tickWriteDevice(); + for (const auto& p : writeList) + if (p != nullptr) p->tickWriteDevice(); } //////////////////////////////////////////////////////////////////////// -void NDArray::preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { - - for (const auto& a : readList) - if(a != nullptr) - a->syncToHost(); - - for (const auto& a : writeList) { - if (a != nullptr) { - a->getDataBuffer()->allocatePrimary(); - if (synchronizeWritables) - a->syncToHost(); - } +void NDArray::preparePrimaryUse(const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables) { + for (const auto& a : readList) + if (a != nullptr) a->syncToHost(); + + for (const auto& a : writeList) { + if (a != nullptr) { + a->getDataBuffer()->allocatePrimary(); + if (synchronizeWritables) a->syncToHost(); } + } } //////////////////////////////////////////////////////////////////////// -void NDArray::registerPrimaryUse(const std::vector& writeList, const std::vector& readList) { - - for (const auto& p : readList) - if(p != nullptr) - p->tickReadHost(); +void NDArray::registerPrimaryUse(const std::vector& writeList, + const std::vector& readList) { + for (const auto& p : readList) + if (p != nullptr) p->tickReadHost(); - for (const auto& p : writeList) - if (p != nullptr) - p->tickWriteHost(); + for (const auto& p : writeList) + if (p != nullptr) p->tickWriteHost(); } ////////////////////////////////////////////////////////////////////////// void NDArray::syncShape() const { - cudaMemcpy(const_cast(specialShapeInfo()), shapeInfo(), shape::shapeInfoByteLength(shapeInfo()), cudaMemcpyHostToDevice); + cudaMemcpy(const_cast(specialShapeInfo()), shapeInfo(), + shape::shapeInfoByteLength(shapeInfo()), cudaMemcpyHostToDevice); } ////////////////////////////////////////////////////////////////////////// void const* NDArray::specialBufferWithOffset(Nd4jLong offset) const { - return specialBuffer() != nullptr ? static_cast(specialBuffer()) + (offset * sizeOfT()) : nullptr; + return specialBuffer() != nullptr + ? static_cast(specialBuffer()) + + (offset * sizeOfT()) + : nullptr; } -void* NDArray::specialBufferWithOffset(Nd4jLong offset){ - return specialBuffer() != nullptr ? static_cast(specialBuffer()) + (offset * sizeOfT()) : nullptr; +void* NDArray::specialBufferWithOffset(Nd4jLong offset) { + return specialBuffer() != nullptr + ? static_cast(specialBuffer()) + (offset * sizeOfT()) + : nullptr; } ////////////////////////////////////////////////////////////////////////// // change an array by repeating it the number of times given by reps. NDArray NDArray::tile(const std::vector& reps) const { - int dim = reps.size(); - Nd4jLong product = 1; - for(const auto& item : reps) - product *= item; - - if(product < 1) - throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !"); - - int rankOld = rankOf(); - int diff = rankOld - dim; - if(product==1) { // in this case 2 possibilities are present: just reshape or nothing to do - NDArray result(*this); - if(diff < 0) { // reshape to higher dimension - std::vector shapeNew = reps; // need to have unities at first "diff" positions of new shape - memcpy(&shapeNew[-diff], result.shapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions - result.reshapei(ordering(), shapeNew); - } - return result; // nothing to do, if diff >= 0 -> identity tile + int dim = reps.size(); + Nd4jLong product = 1; + for (const auto& item : reps) product *= item; + + if (product < 1) + throw std::runtime_error( + "NDArray::tile method: one of the elements in reps array is zero !"); + + int rankOld = rankOf(); + int diff = rankOld - dim; + if (product == 1) { // in this case 2 possibilities are present: just reshape + // or nothing to do + NDArray result(*this); + if (diff < 0) { // reshape to higher dimension + std::vector shapeNew = + reps; // need to have unities at first "diff" positions of new shape + memcpy( + &shapeNew[-diff], result.shapeInfo() + 1, + rankOld * + sizeof(Nd4jLong)); // put old shape numbers at rest of positions + result.reshapei(ordering(), shapeNew); } - - // evaluate shapeInfo for resulting array - auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); - // create new buffer, in any case the memory amount new buffer points to is bigger then those for old _buffer - std::shared_ptr newBuff = std::make_shared(shape::length(newShapeInfo) * sizeOfT(), dataType(), getContext()->getWorkspace(), true); - // assign new shape and new buffer to resulting array - NDArray result(newBuff, ShapeDescriptor(newShapeInfo), getContext()); - - // fill newBuff, loop through all elements of newBuff - // looping through buffer() goes automatically by means of getSubArrayIndex applying - const auto resultLen = result.lengthOf(); - auto xType = this->dataType(); - auto stream = getContext()->getCudaStream(); - - prepareSpecialUse({&result}, {this}); - BUILD_SINGLE_SELECTOR(xType, tileKernelH, (this->specialBuffer(), this->specialShapeInfo(), result.specialBuffer(), result.specialShapeInfo(), resultLen, stream), LIBND4J_TYPES); - registerSpecialUse({&result}, {this}); - - return result; + return result; // nothing to do, if diff >= 0 -> identity tile + } + + // evaluate shapeInfo for resulting array + auto newShapeInfo = + ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); + // create new buffer, in any case the memory amount new buffer points to is + // bigger then those for old _buffer + std::shared_ptr newBuff = std::make_shared( + shape::length(newShapeInfo) * sizeOfT(), dataType(), + getContext()->getWorkspace(), true); + // assign new shape and new buffer to resulting array + NDArray result(newBuff, ShapeDescriptor(newShapeInfo), getContext()); + + // fill newBuff, loop through all elements of newBuff + // looping through buffer() goes automatically by means of getSubArrayIndex + // applying + const auto resultLen = result.lengthOf(); + auto xType = this->dataType(); + auto stream = getContext()->getCudaStream(); + + prepareSpecialUse({&result}, {this}); + BUILD_SINGLE_SELECTOR( + xType, tileKernelH, + (this->specialBuffer(), this->specialShapeInfo(), result.specialBuffer(), + result.specialShapeInfo(), resultLen, stream), + LIBND4J_TYPES); + registerSpecialUse({&result}, {this}); + + return result; } ////////////////////////////////////////////////////////////////////////// // change an array by repeating it the number of times given by reps. void NDArray::tile(const std::vector& reps, NDArray& target) const { - - auto repProd = shape::prodLong(reps.data(), reps.size()); - if (repProd < 1) - throw std::runtime_error("NDArray::tile: reps can't contain 0s"); - - // evaluate true tile shapeInfo for comparison with target shapeInfo - auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); - if(!shape::equalsSoft(newShapeInfo, target.shapeInfo())) { - throw std::runtime_error("NDArray::tile method - shapeInfo of target array is not suitable for tile operation !"); - } - - // fill newBuff, loop through all elements of newBuff - // looping through buffer() goes automatically by means of getSubArrayIndex applying - const int ews = target.ews(); - const int targetLen = target.lengthOf(); - auto stream = getContext()->getCudaStream(); - - prepareSpecialUse({&target}, {this}); - BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (specialBuffer(), specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES); - registerSpecialUse({&target}, {this}); + auto repProd = shape::prodLong(reps.data(), reps.size()); + if (repProd < 1) + throw std::runtime_error("NDArray::tile: reps can't contain 0s"); + + // evaluate true tile shapeInfo for comparison with target shapeInfo + auto newShapeInfo = + ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); + if (!shape::equalsSoft(newShapeInfo, target.shapeInfo())) { + throw std::runtime_error( + "NDArray::tile method - shapeInfo of target array is not suitable for " + "tile operation !"); + } + + // fill newBuff, loop through all elements of newBuff + // looping through buffer() goes automatically by means of getSubArrayIndex + // applying + const int ews = target.ews(); + const int targetLen = target.lengthOf(); + auto stream = getContext()->getCudaStream(); + + prepareSpecialUse({&target}, {this}); + BUILD_SINGLE_SELECTOR_TWICE( + target.dataType(), tileKernelHH, + (specialBuffer(), specialShapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), targetLen, ews, stream), + LIBND4J_TYPES); + registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// void NDArray::tile(NDArray& target) const { - if(rankOf() > target.rankOf()) - throw std::runtime_error("NDArray::tile method - rank of target array must be bigger or equal to the rank of this array !"); - - if(!ShapeUtils::areShapesBroadcastable(*this, target)) - throw std::runtime_error("NDArray::tile method - shapeInfo of target array is not suitable for tile operation !"); - - // fill newBuff, loop through all elements of newBuff - // looping through getBuffer() goes automatically by means of getSubArrayIndex applying - const auto ews = target.ews(); - const auto targetLen = target.lengthOf(); - auto stream = getContext()->getCudaStream(); - - prepareSpecialUse({&target}, {this}); - BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (specialBuffer(), specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES); - registerSpecialUse({&target}, {this}); + if (rankOf() > target.rankOf()) + throw std::runtime_error( + "NDArray::tile method - rank of target array must be bigger or equal " + "to the rank of this array !"); + + if (!ShapeUtils::areShapesBroadcastable(*this, target)) + throw std::runtime_error( + "NDArray::tile method - shapeInfo of target array is not suitable for " + "tile operation !"); + + // fill newBuff, loop through all elements of newBuff + // looping through getBuffer() goes automatically by means of getSubArrayIndex + // applying + const auto ews = target.ews(); + const auto targetLen = target.lengthOf(); + auto stream = getContext()->getCudaStream(); + + prepareSpecialUse({&target}, {this}); + BUILD_SINGLE_SELECTOR_TWICE( + target.dataType(), tileKernelHH, + (specialBuffer(), specialShapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), targetLen, ews, stream), + LIBND4J_TYPES); + registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -template +template __global__ static void repeatCuda(const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const int* repeats, const int repSize, const int axis) { + const X* x = reinterpret_cast(vx); + Z* z = reinterpret_cast(vz); - const X* x = reinterpret_cast(vx); - Z* z = reinterpret_cast(vz); + __shared__ int rank, *sharedMem; + __shared__ Nd4jLong zLen, totalThreads; // xLen = zLen - __shared__ int rank, *sharedMem; - __shared__ Nd4jLong zLen, totalThreads; // xLen = zLen + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - if (threadIdx.x == 0) { + rank = shape::rank(zShapeInfo); // xRank = zRank + zLen = shape::length(zShapeInfo); // xLen <= zLen - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + totalThreads = gridDim.x * blockDim.x; + } - rank = shape::rank(zShapeInfo); // xRank = zRank - zLen = shape::length(zShapeInfo); // xLen <= zLen - - totalThreads = gridDim.x * blockDim.x; - } + __syncthreads(); - __syncthreads(); + auto coords = sharedMem + threadIdx.x * rank; - auto coords = sharedMem + threadIdx.x * rank; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, coords); - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + const auto zOffset = shape::getOffset(zShapeInfo, coords); - shape::index2coords(i, zShapeInfo, coords); - - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - if(repSize > 1) { - for (uint j = 0; j < repSize; ++j) { - coords[axis] -= repeats[j]; - if (coords[axis] < 0) { - coords[axis] = j; - break; - } - } + if (repSize > 1) { + for (uint j = 0; j < repSize; ++j) { + coords[axis] -= repeats[j]; + if (coords[axis] < 0) { + coords[axis] = j; + break; } - else - coords[axis] /= repeats[0]; + } + } else + coords[axis] /= repeats[0]; - z[zOffset] = x[shape::getOffset(xShapeInfo, coords)]; - } + z[zOffset] = x[shape::getOffset(xShapeInfo, coords)]; + } } ////////////////////////////////////////////////////////////////////////// -template -static void repeatCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int* repeats, const int repSize, - const int axis) { - - repeatCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, repeats, repSize, axis); +template +static void repeatCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int* repeats, + const int repSize, const int axis) { + repeatCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, repeats, repSize, axis); } -BUILD_DOUBLE_TEMPLATE(template void repeatCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int* repeats, const int repSize, const int axis), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template void repeatCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int* repeats, + const int repSize, const int axis), + LIBND4J_TYPES, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // create new array by repeating it the number of times given by repeats NDArray NDArray::repeat(const int axis, const std::vector& repeats) const { + NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), + dataType(), getContext()); - NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext()); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = output.rankOf() * sizeof(int) * threadsPerBlock + 128; + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = output.rankOf() * sizeof(int) * threadsPerBlock + 128; - PointersManager manager(getContext(), "NDArray::repeat(const int axis, const std::vector& repeats)"); + PointersManager manager( + getContext(), + "NDArray::repeat(const int axis, const std::vector& repeats)"); - const int* reps = reinterpret_cast(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int))); + const int* reps = reinterpret_cast( + manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int))); - prepareSpecialUse({&output}, {this}); - BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), specialBuffer(), specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES); - prepareSpecialUse({&output}, {this}); + prepareSpecialUse({&output}, {this}); + BUILD_SINGLE_SELECTOR_TWICE( + dataType(), repeatCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), + specialBuffer(), specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), reps, repeats.size(), axis), + LIBND4J_TYPES); + prepareSpecialUse({&output}, {this}); - manager.synchronize(); + manager.synchronize(); - return output; + return output; } ////////////////////////////////////////////////////////////////////////// // fill array by repeating it the number of times given by repeats -void NDArray::repeat(const int axis, const std::vector& repeats, NDArray& target) const { - - if(!target.isSameShape(ShapeUtils::evalRepeatShape(axis, repeats, *this))) - throw std::invalid_argument("NDArray::repeat(const int axis, const std::vector& repeats, NDArray& target) method: wrong shape of target array!"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (target.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = target.rankOf() * sizeof(int) * threadsPerBlock + 128; - - PointersManager manager(getContext(), "NDArray::repeat(const int axis, const std::vector& repeats)"); - - const int* reps = reinterpret_cast(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int))); - - prepareSpecialUse({&target}, {this}); - BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), specialBuffer(), specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES, LIBND4J_TYPES); - prepareSpecialUse({&target}, {this}); - - manager.synchronize(); +void NDArray::repeat(const int axis, const std::vector& repeats, + NDArray& target) const { + if (!target.isSameShape(ShapeUtils::evalRepeatShape(axis, repeats, *this))) + throw std::invalid_argument( + "NDArray::repeat(const int axis, const std::vector& repeats, " + "NDArray& target) method: wrong shape of target array!"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (target.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = target.rankOf() * sizeof(int) * threadsPerBlock + 128; + + PointersManager manager( + getContext(), + "NDArray::repeat(const int axis, const std::vector& repeats)"); + + const int* reps = reinterpret_cast( + manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int))); + + prepareSpecialUse({&target}, {this}); + BUILD_DOUBLE_SELECTOR( + dataType(), target.dataType(), repeatCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), + specialBuffer(), specialShapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), reps, repeats.size(), axis), + LIBND4J_TYPES, LIBND4J_TYPES); + prepareSpecialUse({&target}, {this}); + + manager.synchronize(); } - //////////////////////////////////////////////////////////////////////// void* NDArray::specialBuffer() { - - if (_buffer->special() == nullptr) { - syncToDevice(); - tickReadHost(); - } - // FIXME: this should be fixed once CUDA backend added - return static_cast(_buffer->special()) + (_offset * sizeOfT()); + if (_buffer->special() == nullptr) { + syncToDevice(); + tickReadHost(); + } + // FIXME: this should be fixed once CUDA backend added + return static_cast(_buffer->special()) + (_offset * sizeOfT()); } //////////////////////////////////////////////////////////////////////// void const* NDArray::specialBuffer() const { - if (_buffer->special() == nullptr) { - syncToDevice(); - tickReadHost(); - } - // FIXME: this should be fixed once CUDA backend added - return static_cast(_buffer->special()) + (_offset * sizeOfT()); + if (_buffer->special() == nullptr) { + syncToDevice(); + tickReadHost(); + } + // FIXME: this should be fixed once CUDA backend added + return static_cast(_buffer->special()) + (_offset * sizeOfT()); } ////////////////////////////////////////////////////////////////////////// -template -void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const { - - if(_length == 0) - { printf("NDArray::printActualBuffer: array length is zero !\n"); return; } - - if(msg) - printf("%s", msg); - - if(host) { - if(buffer() == nullptr || _length == 0) - { printf("NDArray::printActualBuffer: host buffer is nullptr !\n"); return; } - - const T* buff = bufferAsT(); - for (uint i = 0; i < _length; i++) - printf("%.*f, ", precision, (double)buff[getOffset(i)]); - printf("\n"); +template +void NDArray::printCurrentBuffer(const bool host, const char* msg, + const int precision) const { + if (_length == 0) { + printf("NDArray::printActualBuffer: array length is zero !\n"); + return; + } + + if (msg) printf("%s", msg); + + if (host) { + if (buffer() == nullptr || _length == 0) { + printf("NDArray::printActualBuffer: host buffer is nullptr !\n"); + return; } - else { - if(specialBuffer() == nullptr || _length == 0) - { printf("NDArray::printSpecialBuffer: special buffer is nullptr !\n"); return; } - void* pHost = operator new(sizeof(T) * _length); - - if (ews() != 1) { - for (uint i = 0; i < _length; i++) - cudaMemcpyAsync(reinterpret_cast(pHost) + i, specialBufferWithOffset(i), sizeof(T), cudaMemcpyDeviceToHost, *(getContext()->getCudaStream())); - } - else - cudaMemcpyAsync(pHost, specialBuffer(), sizeOfT() * _length, cudaMemcpyDeviceToHost, *getContext()->getCudaStream()); - - cudaError_t cudaResult = cudaStreamSynchronize(*getContext()->getCudaStream()); - if(cudaResult != 0) - throw std::runtime_error("NDArray::printSpecialBuffer: cudaStreamSynchronize failed!"); - - for (uint i = 0; i < _length; i++) - printf("%.*f, ", precision, (double)reinterpret_cast(pHost)[i]); - printf("\n"); - - operator delete(pHost); + const T* buff = bufferAsT(); + for (uint i = 0; i < _length; i++) + printf("%.*f, ", precision, (double)buff[getOffset(i)]); + printf("\n"); + } else { + if (specialBuffer() == nullptr || _length == 0) { + printf("NDArray::printSpecialBuffer: special buffer is nullptr !\n"); + return; } -} -template void NDArray::printCurrentBuffer(const bool host,const char* msg, const int precision) const; -template void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const; -template void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const; + void* pHost = operator new(sizeof(T) * _length); + + if (ews() != 1) { + for (uint i = 0; i < _length; i++) + cudaMemcpyAsync(reinterpret_cast(pHost) + i, + specialBufferWithOffset(i), sizeof(T), + cudaMemcpyDeviceToHost, + *(getContext()->getCudaStream())); + } else + cudaMemcpyAsync(pHost, specialBuffer(), sizeOfT() * _length, + cudaMemcpyDeviceToHost, *getContext()->getCudaStream()); + + cudaError_t cudaResult = + cudaStreamSynchronize(*getContext()->getCudaStream()); + if (cudaResult != 0) + throw std::runtime_error( + "NDArray::printSpecialBuffer: cudaStreamSynchronize failed!"); + + for (uint i = 0; i < _length; i++) + printf("%.*f, ", precision, (double)reinterpret_cast(pHost)[i]); + printf("\n"); + + operator delete(pHost); + } +} +template void NDArray::printCurrentBuffer(const bool host, const char* msg, + const int precision) const; +template void NDArray::printCurrentBuffer(const bool host, + const char* msg, + const int precision) const; +template void NDArray::printCurrentBuffer(const bool host, + const char* msg, + const int precision) const; #if defined(__CUDACC__) && !defined(BUILD_TESTS) @@ -576,6 +692,5 @@ template void NDArray::printCurrentBuffer(const bool host, const char* m #endif -} // end namespace sd +} // end namespace sd #endif - diff --git a/libnd4j/include/array/impl/ByteOrderUtils.cpp b/libnd4j/include/array/impl/ByteOrderUtils.cpp index 0220ccac807b..5ce7f0f44988 100644 --- a/libnd4j/include/array/impl/ByteOrderUtils.cpp +++ b/libnd4j/include/array/impl/ByteOrderUtils.cpp @@ -20,9 +20,8 @@ #include - namespace sd { - ByteOrder ByteOrderUtils::fromFlatByteOrder(sd::graph::ByteOrder order) { - return (ByteOrder) order; - } -} \ No newline at end of file +ByteOrder ByteOrderUtils::fromFlatByteOrder(sd::graph::ByteOrder order) { + return (ByteOrder)order; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/array/impl/ConstantDataBuffer.cpp b/libnd4j/include/array/impl/ConstantDataBuffer.cpp index 5a56a50d02ee..d1e50912d851 100644 --- a/libnd4j/include/array/impl/ConstantDataBuffer.cpp +++ b/libnd4j/include/array/impl/ConstantDataBuffer.cpp @@ -21,52 +21,45 @@ #include "../ConstantDataBuffer.h" namespace sd { - ConstantDataBuffer::ConstantDataBuffer(Nd4jPointer primary, Nd4jPointer special, Nd4jLong numEelements, Nd4jLong sizeOf) { - _primaryBuffer = primary; - _specialBuffer = special; - _length = numEelements; - _sizeOf = sizeOf; - } - - Nd4jPointer ConstantDataBuffer::primary() const { - return _primaryBuffer; - } +ConstantDataBuffer::ConstantDataBuffer(Nd4jPointer primary, Nd4jPointer special, + Nd4jLong numEelements, Nd4jLong sizeOf) { + _primaryBuffer = primary; + _specialBuffer = special; + _length = numEelements; + _sizeOf = sizeOf; +} - Nd4jPointer ConstantDataBuffer::special() const { - return _specialBuffer; - } +Nd4jPointer ConstantDataBuffer::primary() const { return _primaryBuffer; } - Nd4jLong ConstantDataBuffer::sizeOf() const { - return _sizeOf; - } +Nd4jPointer ConstantDataBuffer::special() const { return _specialBuffer; } - Nd4jLong ConstantDataBuffer::length() const { - return _length; - } +Nd4jLong ConstantDataBuffer::sizeOf() const { return _sizeOf; } - ConstantDataBuffer::ConstantDataBuffer(const ConstantDataBuffer &other) { - _primaryBuffer = other._primaryBuffer; - _specialBuffer = other._specialBuffer; - _length = other._length; - _sizeOf = other._sizeOf; - } +Nd4jLong ConstantDataBuffer::length() const { return _length; } - template - T* ConstantDataBuffer::primaryAsT() { - return reinterpret_cast(_primaryBuffer); - } - template SD_EXPORT float* ConstantDataBuffer::primaryAsT(); - template SD_EXPORT double* ConstantDataBuffer::primaryAsT(); - template SD_EXPORT int* ConstantDataBuffer::primaryAsT(); - template SD_EXPORT Nd4jLong* ConstantDataBuffer::primaryAsT(); +ConstantDataBuffer::ConstantDataBuffer(const ConstantDataBuffer& other) { + _primaryBuffer = other._primaryBuffer; + _specialBuffer = other._specialBuffer; + _length = other._length; + _sizeOf = other._sizeOf; +} - template - T* ConstantDataBuffer::specialAsT() { - return reinterpret_cast(_specialBuffer); - } - template SD_EXPORT float* ConstantDataBuffer::specialAsT(); - template SD_EXPORT double* ConstantDataBuffer::specialAsT(); - template SD_EXPORT int* ConstantDataBuffer::specialAsT(); - template SD_EXPORT Nd4jLong* ConstantDataBuffer::specialAsT(); +template +T* ConstantDataBuffer::primaryAsT() { + return reinterpret_cast(_primaryBuffer); +} +template SD_EXPORT float* ConstantDataBuffer::primaryAsT(); +template SD_EXPORT double* ConstantDataBuffer::primaryAsT(); +template SD_EXPORT int* ConstantDataBuffer::primaryAsT(); +template SD_EXPORT Nd4jLong* ConstantDataBuffer::primaryAsT(); +template +T* ConstantDataBuffer::specialAsT() { + return reinterpret_cast(_specialBuffer); } +template SD_EXPORT float* ConstantDataBuffer::specialAsT(); +template SD_EXPORT double* ConstantDataBuffer::specialAsT(); +template SD_EXPORT int* ConstantDataBuffer::specialAsT(); +template SD_EXPORT Nd4jLong* ConstantDataBuffer::specialAsT(); + +} // namespace sd diff --git a/libnd4j/include/array/impl/ConstantDescriptor.cpp b/libnd4j/include/array/impl/ConstantDescriptor.cpp index 829ac5b343f4..1e2b22f5c22d 100644 --- a/libnd4j/include/array/impl/ConstantDescriptor.cpp +++ b/libnd4j/include/array/impl/ConstantDescriptor.cpp @@ -20,80 +20,80 @@ #include #include + #include namespace sd { - ConstantDescriptor::ConstantDescriptor(double* values, int length) { - for (int e = 0; e < length; e++) - _floatValues.emplace_back(values[e]); - } +ConstantDescriptor::ConstantDescriptor(double *values, int length) { + for (int e = 0; e < length; e++) _floatValues.emplace_back(values[e]); +} - ConstantDescriptor::ConstantDescriptor(Nd4jLong const* values, int length) { - for (int e = 0; e < length; e++) - _integerValues.emplace_back(values[e]); - } +ConstantDescriptor::ConstantDescriptor(Nd4jLong const *values, int length) { + for (int e = 0; e < length; e++) _integerValues.emplace_back(values[e]); +} - ConstantDescriptor::ConstantDescriptor(std::initializer_list values) { - _floatValues = values; - } +ConstantDescriptor::ConstantDescriptor(std::initializer_list values) { + _floatValues = values; +} - ConstantDescriptor::ConstantDescriptor(std::vector &values) { - _integerValues = values; - } +ConstantDescriptor::ConstantDescriptor(std::vector &values) { + _integerValues = values; +} - ConstantDescriptor::ConstantDescriptor(std::vector &values) { - _floatValues = values; - } +ConstantDescriptor::ConstantDescriptor(std::vector &values) { + _floatValues = values; +} - // equal to operator - bool ConstantDescriptor::operator==(const ConstantDescriptor &other) const { - return std::tie(_floatValues, _integerValues) == std::tie(other._floatValues, other._integerValues); - } +// equal to operator +bool ConstantDescriptor::operator==(const ConstantDescriptor &other) const { + return std::tie(_floatValues, _integerValues) == + std::tie(other._floatValues, other._integerValues); +} - // less than operator - bool ConstantDescriptor::operator<(const ConstantDescriptor &other) const { - return std::tie(_floatValues, _integerValues) < std::tie(other._floatValues, other._integerValues); - } +// less than operator +bool ConstantDescriptor::operator<(const ConstantDescriptor &other) const { + return std::tie(_floatValues, _integerValues) < + std::tie(other._floatValues, other._integerValues); +} - bool ConstantDescriptor::isInteger() const { - return !_integerValues.empty(); - } +bool ConstantDescriptor::isInteger() const { return !_integerValues.empty(); } - bool ConstantDescriptor::isFloat() const { - return !_floatValues.empty(); - } +bool ConstantDescriptor::isFloat() const { return !_floatValues.empty(); } - const std::vector& ConstantDescriptor::integerValues() const { - return _integerValues; - } +const std::vector &ConstantDescriptor::integerValues() const { + return _integerValues; +} - const std::vector& ConstantDescriptor::floatValues() const { - return _floatValues; - } +const std::vector &ConstantDescriptor::floatValues() const { + return _floatValues; +} - Nd4jLong ConstantDescriptor::length() const { - return isInteger() ? _integerValues.size() : isFloat() ? _floatValues.size() : 0L; - } +Nd4jLong ConstantDescriptor::length() const { + return isInteger() ? _integerValues.size() + : isFloat() ? _floatValues.size() : 0L; } +} // namespace sd namespace std { - size_t hash::operator()(const sd::ConstantDescriptor &k) const { - using std::hash; - // Compute individual hash values for first, - // second and third and combine them using XOR - // and bit shifting: - size_t hashVal = 0; - size_t i = 0; - if (k.isInteger()) { - for (auto v: k.integerValues()) { - hashVal ^= std::hash()(v) + 0x9e3779b9 + (hashVal << 6) + (hashVal >> 2); - } - } - else { - for (auto v: k.floatValues()) { - hashVal ^= std::hash()(v) + 0x9e3779b9 + (hashVal << 6) + (hashVal >> 2); - } - } - return hashVal; +size_t hash::operator()( + const sd::ConstantDescriptor &k) const { + using std::hash; + // Compute individual hash values for first, + // second and third and combine them using XOR + // and bit shifting: + size_t hashVal = 0; + size_t i = 0; + if (k.isInteger()) { + for (auto v : k.integerValues()) { + hashVal ^= std::hash()(v) + 0x9e3779b9 + (hashVal << 6) + + (hashVal >> 2); + } + } else { + for (auto v : k.floatValues()) { + hashVal ^= + std::hash()(v) + 0x9e3779b9 + (hashVal << 6) + (hashVal >> 2); } + } + return hashVal; } +} // namespace std diff --git a/libnd4j/include/array/impl/ConstantHolder.cpp b/libnd4j/include/array/impl/ConstantHolder.cpp index 204ff879e136..b8146a1aeee2 100644 --- a/libnd4j/include/array/impl/ConstantHolder.cpp +++ b/libnd4j/include/array/impl/ConstantHolder.cpp @@ -1,5 +1,5 @@ /** -* Copyright (c) 2019 Konduit K.K. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,50 +18,54 @@ // Created by raver on 5/17/2019. // -#include #include +#include #include namespace sd { - ConstantHolder::ConstantHolder(const ConstantHolder& other) { - _buffers = other._buffers; - _deviceId = other._deviceId; - } +ConstantHolder::ConstantHolder(const ConstantHolder& other) { + _buffers = other._buffers; + _deviceId = other._deviceId; +} - bool ConstantHolder::hasBuffer(sd::DataType dataType) { - return _buffers.count(dataType) > 0; - } +bool ConstantHolder::hasBuffer(sd::DataType dataType) { + return _buffers.count(dataType) > 0; +} - std::mutex* ConstantHolder::mutex() { - return &_mutex; - } +std::mutex* ConstantHolder::mutex() { return &_mutex; } - template - bool ConstantHolder::hasBuffer() { - return hasBuffer(DataTypeUtils::fromT()); - } - BUILD_SINGLE_TEMPLATE(template SD_EXPORT bool ConstantHolder::hasBuffer, (void), LIBND4J_TYPES); +template +bool ConstantHolder::hasBuffer() { + return hasBuffer(DataTypeUtils::fromT()); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT bool ConstantHolder::hasBuffer, (void), + LIBND4J_TYPES); - void ConstantHolder::addBuffer(ConstantDataBuffer &pointer, sd::DataType dataType) { - _buffers[dataType] = pointer; - } +void ConstantHolder::addBuffer(ConstantDataBuffer& pointer, + sd::DataType dataType) { + _buffers[dataType] = pointer; +} - template - void ConstantHolder::addBuffer(ConstantDataBuffer &pointer) { - addBuffer(pointer, DataTypeUtils::fromT()); - } - BUILD_SINGLE_TEMPLATE(template SD_EXPORT void ConstantHolder::addBuffer, (ConstantDataBuffer& cb), LIBND4J_TYPES); +template +void ConstantHolder::addBuffer(ConstantDataBuffer& pointer) { + addBuffer(pointer, DataTypeUtils::fromT()); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void ConstantHolder::addBuffer, + (ConstantDataBuffer & cb), LIBND4J_TYPES); - ConstantDataBuffer* ConstantHolder::getConstantDataBuffer(sd::DataType dataType) { - if (!hasBuffer(dataType)) - throw std::runtime_error("Requested dataType is absent in storage"); +ConstantDataBuffer* ConstantHolder::getConstantDataBuffer( + sd::DataType dataType) { + if (!hasBuffer(dataType)) + throw std::runtime_error("Requested dataType is absent in storage"); - return &_buffers[dataType]; - } + return &_buffers[dataType]; +} - template - ConstantDataBuffer* ConstantHolder::getConstantDataBuffer() { - return getConstantDataBuffer(DataTypeUtils::fromT()); - } - BUILD_SINGLE_TEMPLATE(template SD_EXPORT ConstantDataBuffer* ConstantHolder::getConstantDataBuffer, (), LIBND4J_TYPES); -} \ No newline at end of file +template +ConstantDataBuffer* ConstantHolder::getConstantDataBuffer() { + return getConstantDataBuffer(DataTypeUtils::fromT()); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT ConstantDataBuffer* + ConstantHolder::getConstantDataBuffer, + (), LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/array/impl/DataBuffer.cpp b/libnd4j/include/array/impl/DataBuffer.cpp index 262460e8c5d5..033669d967c8 100644 --- a/libnd4j/include/array/impl/DataBuffer.cpp +++ b/libnd4j/include/array/impl/DataBuffer.cpp @@ -20,319 +20,308 @@ // #include -#include #include +#include #include +#include #include -#include namespace sd { - ///// IMLEMENTATION OF COMMON METHODS ///// - +///// IMLEMENTATION OF COMMON METHODS ///// //////////////////////////////////////////////////////////////////////// // default constructor - DataBuffer::DataBuffer() { - - _primaryBuffer = nullptr; - _specialBuffer = nullptr; - _lenInBytes = 0; - _dataType = INT8; - _workspace = nullptr; - _isOwnerPrimary = false; - _isOwnerSpecial = false; - _deviceId = sd::AffinityManager::currentDeviceId(); - - setCountersToZero(); - } +DataBuffer::DataBuffer() { + _primaryBuffer = nullptr; + _specialBuffer = nullptr; + _lenInBytes = 0; + _dataType = INT8; + _workspace = nullptr; + _isOwnerPrimary = false; + _isOwnerSpecial = false; + _deviceId = sd::AffinityManager::currentDeviceId(); + + setCountersToZero(); +} //////////////////////////////////////////////////////////////////////// // copy constructor - DataBuffer::DataBuffer(const DataBuffer &other) { +DataBuffer::DataBuffer(const DataBuffer& other) { + throw std::runtime_error( + "DataBuffer copy constructor: we don't expect using of this " + "constructor!"); - throw std::runtime_error("DataBuffer copy constructor: we don't expect using of this constructor!"); + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; + _primaryBuffer = nullptr; + _specialBuffer = nullptr; - _primaryBuffer = nullptr; - _specialBuffer = nullptr; + _deviceId.store(other._deviceId.load()); - _deviceId.store(other._deviceId.load()); + setCountersToZero(); - setCountersToZero(); - - allocateBuffers(); - copyBufferFrom(other); - } + allocateBuffers(); + copyBufferFrom(other); +} //////////////////////////////////////////////////////////////////////// - DataBuffer::DataBuffer(void* primary, void* special, - const size_t lenInBytes, const DataType dataType, - const bool isOwnerPrimary, const bool isOwnerSpecial, - memory::Workspace* workspace) { - - if (primary == nullptr && special == nullptr) - throw std::runtime_error("DataBuffer constructor: can't be initialized with both nullptr buffers !"); - - _primaryBuffer = primary; - _specialBuffer = special; - _lenInBytes = lenInBytes; - _dataType = dataType; - _workspace = workspace; - _isOwnerPrimary = isOwnerPrimary; - _isOwnerSpecial = isOwnerSpecial; - _deviceId = sd::AffinityManager::currentDeviceId(); - - setCountersToZero(); - - if(primary != nullptr) - readPrimary(); - if(special != nullptr) - readSpecial(); - } +DataBuffer::DataBuffer(void* primary, void* special, const size_t lenInBytes, + const DataType dataType, const bool isOwnerPrimary, + const bool isOwnerSpecial, + memory::Workspace* workspace) { + if (primary == nullptr && special == nullptr) + throw std::runtime_error( + "DataBuffer constructor: can't be initialized with both nullptr " + "buffers !"); + + _primaryBuffer = primary; + _specialBuffer = special; + _lenInBytes = lenInBytes; + _dataType = dataType; + _workspace = workspace; + _isOwnerPrimary = isOwnerPrimary; + _isOwnerSpecial = isOwnerSpecial; + _deviceId = sd::AffinityManager::currentDeviceId(); + + setCountersToZero(); + + if (primary != nullptr) readPrimary(); + if (special != nullptr) readSpecial(); +} //////////////////////////////////////////////////////////////////////// - DataBuffer::DataBuffer(void* primary, const size_t lenInBytes, const DataType dataType, const bool isOwnerPrimary, memory::Workspace* workspace): - DataBuffer(primary, nullptr, lenInBytes, dataType, isOwnerPrimary, false, workspace) { - - syncToSpecial(true); - } +DataBuffer::DataBuffer(void* primary, const size_t lenInBytes, + const DataType dataType, const bool isOwnerPrimary, + memory::Workspace* workspace) + : DataBuffer(primary, nullptr, lenInBytes, dataType, isOwnerPrimary, false, + workspace) { + syncToSpecial(true); +} //////////////////////////////////////////////////////////////////////// // copies data from hostBuffer to own memory buffer - DataBuffer::DataBuffer(const void* hostBuffer, const DataType dataType, const size_t lenInBytes, memory::Workspace* workspace) { - - if (hostBuffer == nullptr) - throw std::runtime_error("DataBuffer constructor: can't be initialized with nullptr host buffer !"); - if (lenInBytes == 0) - throw std::runtime_error("DataBuffer constructor: can't be initialized with zero length !"); +DataBuffer::DataBuffer(const void* hostBuffer, const DataType dataType, + const size_t lenInBytes, memory::Workspace* workspace) { + if (hostBuffer == nullptr) + throw std::runtime_error( + "DataBuffer constructor: can't be initialized with nullptr host buffer " + "!"); + if (lenInBytes == 0) + throw std::runtime_error( + "DataBuffer constructor: can't be initialized with zero length !"); - _primaryBuffer = nullptr; - _specialBuffer = nullptr; - _lenInBytes = lenInBytes; - _dataType = dataType; - _workspace = workspace; + _primaryBuffer = nullptr; + _specialBuffer = nullptr; + _lenInBytes = lenInBytes; + _dataType = dataType; + _workspace = workspace; - _deviceId = sd::AffinityManager::currentDeviceId(); + _deviceId = sd::AffinityManager::currentDeviceId(); - setCountersToZero(); + setCountersToZero(); - allocateBuffers(); + allocateBuffers(); - copyBufferFromHost(hostBuffer, lenInBytes); - } + copyBufferFromHost(hostBuffer, lenInBytes); +} //////////////////////////////////////////////////////////////////////// - DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, memory::Workspace* workspace, const bool allocBoth) { +DataBuffer::DataBuffer(const size_t lenInBytes, const DataType dataType, + memory::Workspace* workspace, const bool allocBoth) { + _dataType = dataType; + _workspace = workspace; + _lenInBytes = lenInBytes; - _dataType = dataType; - _workspace = workspace; - _lenInBytes = lenInBytes; + _primaryBuffer = nullptr; + _specialBuffer = nullptr; - _primaryBuffer = nullptr; - _specialBuffer = nullptr; + _deviceId = sd::AffinityManager::currentDeviceId(); - _deviceId = sd::AffinityManager::currentDeviceId(); + setCountersToZero(); - setCountersToZero(); - - if(lenInBytes != 0) { - allocateBuffers(allocBoth); - writeSpecial(); - } - } + if (lenInBytes != 0) { + allocateBuffers(allocBoth); + writeSpecial(); + } +} //////////////////////////////////////////////////////////////////////// // move constructor - DataBuffer::DataBuffer(DataBuffer&& other) { - - _primaryBuffer = other._primaryBuffer; - _specialBuffer = other._specialBuffer; - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; - _isOwnerPrimary = other._isOwnerPrimary; - _isOwnerSpecial = other._isOwnerSpecial; - _deviceId.store(other._deviceId); - - copyCounters(other); - - other._primaryBuffer = other._specialBuffer = nullptr; - other.setAllocFlags(false, false); - other._lenInBytes = 0; - } +DataBuffer::DataBuffer(DataBuffer&& other) { + _primaryBuffer = other._primaryBuffer; + _specialBuffer = other._specialBuffer; + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; + _isOwnerPrimary = other._isOwnerPrimary; + _isOwnerSpecial = other._isOwnerSpecial; + _deviceId.store(other._deviceId); + + copyCounters(other); + + other._primaryBuffer = other._specialBuffer = nullptr; + other.setAllocFlags(false, false); + other._lenInBytes = 0; +} //////////////////////////////////////////////////////////////////////// // assignment operator - DataBuffer& DataBuffer::operator=(const DataBuffer& other) { +DataBuffer& DataBuffer::operator=(const DataBuffer& other) { + if (this == &other) return *this; - if (this == &other) - return *this; + deleteBuffers(); - deleteBuffers(); + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; + allocateBuffers(); + copyBufferFrom(other); - allocateBuffers(); - copyBufferFrom(other); - - return *this; - } + return *this; +} //////////////////////////////////////////////////////////////////////// // move assignment operator - DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept { - - if (this == &other) - return *this; - - deleteBuffers(); +DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept { + if (this == &other) return *this; - _primaryBuffer = other._primaryBuffer; - _specialBuffer = other._specialBuffer; - _lenInBytes = other._lenInBytes; - _dataType = other._dataType; - _workspace = other._workspace; - _isOwnerPrimary = other._isOwnerPrimary; - _isOwnerSpecial = other._isOwnerSpecial; + deleteBuffers(); - copyCounters(other); + _primaryBuffer = other._primaryBuffer; + _specialBuffer = other._specialBuffer; + _lenInBytes = other._lenInBytes; + _dataType = other._dataType; + _workspace = other._workspace; + _isOwnerPrimary = other._isOwnerPrimary; + _isOwnerSpecial = other._isOwnerSpecial; - other._primaryBuffer = other._specialBuffer = nullptr; - other.setAllocFlags(false, false); - other._lenInBytes = 0; + copyCounters(other); - return *this; - } + other._primaryBuffer = other._specialBuffer = nullptr; + other.setAllocFlags(false, false); + other._lenInBytes = 0; -//////////////////////////////////////////////////////////////////////// - void* DataBuffer::primary() { - return _primaryBuffer; - } + return *this; +} //////////////////////////////////////////////////////////////////////// - void* DataBuffer::special() { - return _specialBuffer; - } +void* DataBuffer::primary() { return _primaryBuffer; } //////////////////////////////////////////////////////////////////////// - DataType DataBuffer::getDataType() { - return _dataType; - } +void* DataBuffer::special() { return _specialBuffer; } //////////////////////////////////////////////////////////////////////// - size_t DataBuffer::getLenInBytes() const { - return _lenInBytes; - } - +DataType DataBuffer::getDataType() { return _dataType; } //////////////////////////////////////////////////////////////////////// - void DataBuffer::allocatePrimary() { - - if (_primaryBuffer == nullptr && getLenInBytes() > 0) { - auto deviceId = sd::AffinityManager::currentDeviceId(); - // check if this allocation won't bring us above limit - if (_workspace == nullptr) { - if (Environment::getInstance()->isCPU()) { - // on cpu backend we validate against device 0 for now - if (!sd::memory::MemoryCounter::getInstance()->validate(getLenInBytes())) - throw sd::allocation_exception::build("Requested amount exceeds HOST device limits", sd::memory::MemoryCounter::getInstance()->deviceLimit(deviceId), getLenInBytes()); - } else { - // in heterogenous mode we valdate against device group - if (!sd::memory::MemoryCounter::getInstance()->validateGroup(sd::memory::MemoryType::HOST, getLenInBytes())) - throw sd::allocation_exception::build("Requested amount exceeds HOST group limits", sd::memory::MemoryCounter::getInstance()->groupLimit(sd::memory::MemoryType::HOST), getLenInBytes()); - } - } - - ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t); - _isOwnerPrimary = true; - - // count in towards current deviceId if we're not in workspace mode - if (_workspace == nullptr) { - if (Environment::getInstance()->isCPU()) // we don't want this counter to be added to CUDA device - sd::memory::MemoryCounter::getInstance()->countIn(deviceId, getLenInBytes()); - - sd::memory::MemoryCounter::getInstance()->countIn(sd::memory::MemoryType::HOST, getLenInBytes()); - } - } - } +size_t DataBuffer::getLenInBytes() const { return _lenInBytes; } //////////////////////////////////////////////////////////////////////// - void DataBuffer::setAllocFlags(const bool isOwnerPrimary, const bool isOwnerSpecial) { - _isOwnerPrimary = isOwnerPrimary; - _isOwnerSpecial = isOwnerSpecial; +void DataBuffer::allocatePrimary() { + if (_primaryBuffer == nullptr && getLenInBytes() > 0) { + auto deviceId = sd::AffinityManager::currentDeviceId(); + // check if this allocation won't bring us above limit + if (_workspace == nullptr) { + if (Environment::getInstance()->isCPU()) { + // on cpu backend we validate against device 0 for now + if (!sd::memory::MemoryCounter::getInstance()->validate( + getLenInBytes())) + throw sd::allocation_exception::build( + "Requested amount exceeds HOST device limits", + sd::memory::MemoryCounter::getInstance()->deviceLimit(deviceId), + getLenInBytes()); + } else { + // in heterogenous mode we valdate against device group + if (!sd::memory::MemoryCounter::getInstance()->validateGroup( + sd::memory::MemoryType::HOST, getLenInBytes())) + throw sd::allocation_exception::build( + "Requested amount exceeds HOST group limits", + sd::memory::MemoryCounter::getInstance()->groupLimit( + sd::memory::MemoryType::HOST), + getLenInBytes()); + } } -//////////////////////////////////////////////////////////////////////// - void DataBuffer::deletePrimary() { - - if(_isOwnerPrimary && _primaryBuffer != nullptr && getLenInBytes() != 0) { - auto p = reinterpret_cast(_primaryBuffer); - RELEASE(p, _workspace); - _primaryBuffer = nullptr; - _isOwnerPrimary = false; + ALLOCATE(_primaryBuffer, _workspace, getLenInBytes(), int8_t); + _isOwnerPrimary = true; + // count in towards current deviceId if we're not in workspace mode + if (_workspace == nullptr) { + if (Environment::getInstance()->isCPU()) // we don't want this counter to + // be added to CUDA device + sd::memory::MemoryCounter::getInstance()->countIn(deviceId, + getLenInBytes()); - // count out towards DataBuffer device, only if we're not in workspace - if (_workspace == nullptr) { - if (Environment::getInstance()->isCPU()) - sd::memory::MemoryCounter::getInstance()->countOut(_deviceId, getLenInBytes()); - - sd::memory::MemoryCounter::getInstance()->countOut(sd::memory::MemoryType::HOST, getLenInBytes()); - } - } + sd::memory::MemoryCounter::getInstance()->countIn( + sd::memory::MemoryType::HOST, getLenInBytes()); } + } +} //////////////////////////////////////////////////////////////////////// - void DataBuffer::deleteBuffers() { +void DataBuffer::setAllocFlags(const bool isOwnerPrimary, + const bool isOwnerSpecial) { + _isOwnerPrimary = isOwnerPrimary; + _isOwnerSpecial = isOwnerSpecial; +} - deletePrimary(); - deleteSpecial(); - _lenInBytes = 0; +//////////////////////////////////////////////////////////////////////// +void DataBuffer::deletePrimary() { + if (_isOwnerPrimary && _primaryBuffer != nullptr && getLenInBytes() != 0) { + auto p = reinterpret_cast(_primaryBuffer); + RELEASE(p, _workspace); + _primaryBuffer = nullptr; + _isOwnerPrimary = false; + + // count out towards DataBuffer device, only if we're not in workspace + if (_workspace == nullptr) { + if (Environment::getInstance()->isCPU()) + sd::memory::MemoryCounter::getInstance()->countOut(_deviceId, + getLenInBytes()); + + sd::memory::MemoryCounter::getInstance()->countOut( + sd::memory::MemoryType::HOST, getLenInBytes()); } + } +} //////////////////////////////////////////////////////////////////////// - DataBuffer::~DataBuffer() { +void DataBuffer::deleteBuffers() { + deletePrimary(); + deleteSpecial(); + _lenInBytes = 0; +} - deleteBuffers(); - } +//////////////////////////////////////////////////////////////////////// +DataBuffer::~DataBuffer() { deleteBuffers(); } - void DataBuffer::setPrimaryBuffer(void *buffer, size_t length) { - if (_primaryBuffer != nullptr && _isOwnerPrimary) { - deletePrimary(); - } +void DataBuffer::setPrimaryBuffer(void* buffer, size_t length) { + if (_primaryBuffer != nullptr && _isOwnerPrimary) { + deletePrimary(); + } - _primaryBuffer = buffer; - _isOwnerPrimary = false; - _lenInBytes = length * DataTypeUtils::sizeOf(_dataType); - } + _primaryBuffer = buffer; + _isOwnerPrimary = false; + _lenInBytes = length * DataTypeUtils::sizeOf(_dataType); +} - void DataBuffer::setSpecialBuffer(void *buffer, size_t length) { - if (_specialBuffer != nullptr && _isOwnerSpecial) { - deleteSpecial(); - } +void DataBuffer::setSpecialBuffer(void* buffer, size_t length) { + if (_specialBuffer != nullptr && _isOwnerSpecial) { + deleteSpecial(); + } - this->setSpecial(buffer, false); - _lenInBytes = length * DataTypeUtils::sizeOf(_dataType); - } + this->setSpecial(buffer, false); + _lenInBytes = length * DataTypeUtils::sizeOf(_dataType); +} - void DataBuffer::setDataType(DataType dataType) { - _dataType = dataType; - } +void DataBuffer::setDataType(DataType dataType) { _dataType = dataType; } - int DataBuffer::deviceId() const { - return _deviceId.load(); - } +int DataBuffer::deviceId() const { return _deviceId.load(); } - void DataBuffer::close() { - this->deleteBuffers(); - } +void DataBuffer::close() { this->deleteBuffers(); } - void DataBuffer::setDeviceId(int deviceId) { - _deviceId = deviceId; - } -} +void DataBuffer::setDeviceId(int deviceId) { _deviceId = deviceId; } +} // namespace sd diff --git a/libnd4j/include/array/impl/DataTypeUtils.cpp b/libnd4j/include/array/impl/DataTypeUtils.cpp index 481fa4149716..9bfe90eff901 100644 --- a/libnd4j/include/array/impl/DataTypeUtils.cpp +++ b/libnd4j/include/array/impl/DataTypeUtils.cpp @@ -23,15 +23,11 @@ #include namespace sd { - DataType DataTypeUtils::fromInt(int val) { - return (DataType) val; - } +DataType DataTypeUtils::fromInt(int val) { return (DataType)val; } - DataType DataTypeUtils::fromFlatDataType(sd::graph::DType dtype) { - return (DataType) dtype; - } +DataType DataTypeUtils::fromFlatDataType(sd::graph::DType dtype) { + return (DataType)dtype; +} - int DataTypeUtils::asInt(DataType type) { - return (int) type; - } -} \ No newline at end of file +int DataTypeUtils::asInt(DataType type) { return (int)type; } +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/array/impl/ExtraArguments.cpp b/libnd4j/include/array/impl/ExtraArguments.cpp index 2f512cf50b84..41ac67d8beb0 100644 --- a/libnd4j/include/array/impl/ExtraArguments.cpp +++ b/libnd4j/include/array/impl/ExtraArguments.cpp @@ -18,123 +18,124 @@ // @author raver119@gmail.com // -#include #include #include -#include +#include #include +#include + #ifdef __CUDABLAS__ #include #include #endif namespace sd { - ExtraArguments::ExtraArguments(std::initializer_list arguments) { - _fpArgs = arguments; - } +ExtraArguments::ExtraArguments(std::initializer_list arguments) { + _fpArgs = arguments; +} - ExtraArguments::ExtraArguments(std::initializer_list arguments) { - _intArgs = arguments; - } +ExtraArguments::ExtraArguments(std::initializer_list arguments) { + _intArgs = arguments; +} - ExtraArguments::ExtraArguments(const std::vector &arguments) { - _fpArgs = arguments; - } +ExtraArguments::ExtraArguments(const std::vector &arguments) { + _fpArgs = arguments; +} - ExtraArguments::ExtraArguments(const std::vector &arguments) { - _intArgs = arguments; - } +ExtraArguments::ExtraArguments(const std::vector &arguments) { + _intArgs = arguments; +} - ExtraArguments::ExtraArguments(const std::vector &arguments) { - for (const auto &v:arguments) - _intArgs.emplace_back(static_cast(v)); - } +ExtraArguments::ExtraArguments(const std::vector &arguments) { + for (const auto &v : arguments) + _intArgs.emplace_back(static_cast(v)); +} - ExtraArguments::ExtraArguments() { - // no-op - } +ExtraArguments::ExtraArguments() { + // no-op +} - ExtraArguments::~ExtraArguments() { - for (auto p:_pointers) { +ExtraArguments::~ExtraArguments() { + for (auto p : _pointers) { #ifdef __CUDABLAS__ - cudaFree(p); -#else // CPU branch - delete[] reinterpret_cast(p); + cudaFree(p); +#else // CPU branch + delete[] reinterpret_cast(p); #endif - } - } + } +} - template - void ExtraArguments::convertAndCopy(Nd4jPointer pointer, Nd4jLong offset) { - auto length = this->length(); - auto target = reinterpret_cast(pointer); +template +void ExtraArguments::convertAndCopy(Nd4jPointer pointer, Nd4jLong offset) { + auto length = this->length(); + auto target = reinterpret_cast(pointer); #ifdef __CUDABLAS__ - target = new T[length]; + target = new T[length]; #endif - if (!_fpArgs.empty()) { - for (int e = offset; e < _fpArgs.size(); e++) { - target[e] = static_cast(_fpArgs[e]); - } - } else if (_intArgs.empty()) { - for (int e = offset; e < _intArgs.size(); e++) { - target[e] = static_cast(_intArgs[e]); - } - } + if (!_fpArgs.empty()) { + for (int e = offset; e < _fpArgs.size(); e++) { + target[e] = static_cast(_fpArgs[e]); + } + } else if (_intArgs.empty()) { + for (int e = offset; e < _intArgs.size(); e++) { + target[e] = static_cast(_intArgs[e]); + } + } #ifdef __CUDABLAS__ - // TODO: maybe make it asynchronous eventually? - cudaMemcpy(pointer, target, length * DataTypeUtils::sizeOf(DataTypeUtils::fromT()), cudaMemcpyHostToDevice); - delete[] target; + // TODO: maybe make it asynchronous eventually? + cudaMemcpy(pointer, target, + length * DataTypeUtils::sizeOf(DataTypeUtils::fromT()), + cudaMemcpyHostToDevice); + delete[] target; #endif - } - BUILD_SINGLE_TEMPLATE(template SD_EXPORT void ExtraArguments::convertAndCopy, (Nd4jPointer pointer, Nd4jLong offset), LIBND4J_TYPES); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void ExtraArguments::convertAndCopy, + (Nd4jPointer pointer, Nd4jLong offset), LIBND4J_TYPES); - void* ExtraArguments::allocate(size_t length, size_t elementSize) { +void *ExtraArguments::allocate(size_t length, size_t elementSize) { #ifdef __CUDABLAS__ - Nd4jPointer ptr; - auto res = cudaMalloc(reinterpret_cast(&ptr), length * elementSize); - if (res != 0) - throw std::runtime_error("Can't allocate CUDA memory"); -#else // CPU branch - auto ptr = new int8_t[length * elementSize]; - if (!ptr) - throw std::runtime_error("Can't allocate memory"); + Nd4jPointer ptr; + auto res = cudaMalloc(reinterpret_cast(&ptr), length * elementSize); + if (res != 0) throw std::runtime_error("Can't allocate CUDA memory"); +#else // CPU branch + auto ptr = new int8_t[length * elementSize]; + if (!ptr) throw std::runtime_error("Can't allocate memory"); #endif - return ptr; - } - - size_t ExtraArguments::length() { - if (!_fpArgs.empty()) - return _fpArgs.size(); - else if (!_intArgs.empty()) - return _intArgs.size(); - else - return 0; - } + return ptr; +} - template - void* ExtraArguments::argumentsAsT(Nd4jLong offset) { - return argumentsAsT(DataTypeUtils::fromT(), offset); - } - BUILD_SINGLE_TEMPLATE(template SD_EXPORT void *ExtraArguments::argumentsAsT, (Nd4jLong offset), LIBND4J_TYPES); +size_t ExtraArguments::length() { + if (!_fpArgs.empty()) + return _fpArgs.size(); + else if (!_intArgs.empty()) + return _intArgs.size(); + else + return 0; +} +template +void *ExtraArguments::argumentsAsT(Nd4jLong offset) { + return argumentsAsT(DataTypeUtils::fromT(), offset); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void *ExtraArguments::argumentsAsT, + (Nd4jLong offset), LIBND4J_TYPES); - void* ExtraArguments::argumentsAsT(sd::DataType dataType, Nd4jLong offset) { - if (_fpArgs.empty() && _intArgs.empty()) - return nullptr; +void *ExtraArguments::argumentsAsT(sd::DataType dataType, Nd4jLong offset) { + if (_fpArgs.empty() && _intArgs.empty()) return nullptr; - // we allocate pointer - auto ptr = allocate(length() - offset, DataTypeUtils::sizeOf(dataType)); + // we allocate pointer + auto ptr = allocate(length() - offset, DataTypeUtils::sizeOf(dataType)); - // fill it with data - BUILD_SINGLE_SELECTOR(dataType, convertAndCopy, (ptr, offset), LIBND4J_TYPES); + // fill it with data + BUILD_SINGLE_SELECTOR(dataType, convertAndCopy, (ptr, offset), LIBND4J_TYPES); - // store it internally for future release - _pointers.emplace_back(ptr); + // store it internally for future release + _pointers.emplace_back(ptr); - return ptr; - } + return ptr; } +} // namespace sd diff --git a/libnd4j/include/array/impl/InteropDataBuffer.cpp b/libnd4j/include/array/impl/InteropDataBuffer.cpp index d0a38161285f..02d1cfc21dad 100644 --- a/libnd4j/include/array/impl/InteropDataBuffer.cpp +++ b/libnd4j/include/array/impl/InteropDataBuffer.cpp @@ -18,129 +18,132 @@ // @author raver119@gmail.com // -#include #include +#include #include #include namespace sd { - InteropDataBuffer::InteropDataBuffer(InteropDataBuffer &dataBuffer, uint64_t length, uint64_t offset) { - _dataBuffer = dataBuffer.getDataBuffer(); - - // offset is always absolute to the original buffer - _offset = offset; - - if (_offset + length > _dataBuffer->getLenInBytes()) { - throw std::runtime_error("offset + length is higher than original length"); - } - } - - InteropDataBuffer::InteropDataBuffer(std::shared_ptr databuffer) { - _dataBuffer = databuffer; - } - - InteropDataBuffer::InteropDataBuffer(size_t elements, sd::DataType dtype, bool allocateBoth) { - if (elements == 0) { - _dataBuffer = std::make_shared(); - _dataBuffer->setDataType(dtype); - } else { - _dataBuffer = std::make_shared(elements, dtype, nullptr, allocateBoth); - } - } - - std::shared_ptr InteropDataBuffer::getDataBuffer() const { - return _dataBuffer; - } - - std::shared_ptr InteropDataBuffer::dataBuffer() { - return _dataBuffer; - } - - void* InteropDataBuffer::primary() const { - return reinterpret_cast(_dataBuffer->primary()) + _offset; - } - - void* InteropDataBuffer::special() const { - return reinterpret_cast(_dataBuffer->special()) + _offset; - } - - void InteropDataBuffer::setPrimary(void* ptr, size_t length) { - _dataBuffer->setPrimaryBuffer(ptr, length); - } - - void InteropDataBuffer::setSpecial(void* ptr, size_t length) { - _dataBuffer->setSpecialBuffer(ptr, length); - } - - uint64_t InteropDataBuffer::offset() const { - return _offset; - } - - void InteropDataBuffer::setOffset(uint64_t offset) { - _offset = offset; - } - - int InteropDataBuffer::deviceId() const { - return _dataBuffer->deviceId(); - } - - - void InteropDataBuffer::registerSpecialUse(const std::vector& writeList, const std::vector& readList) { - for (const auto &v:writeList) { - if (v == nullptr) - continue; - - v->getDataBuffer()->writeSpecial(); - } - } - - void InteropDataBuffer::prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { - auto currentDeviceId = sd::AffinityManager::currentDeviceId(); - for (const auto &v:readList) { - if (v == nullptr) - continue; - - if (v->getDataBuffer()->deviceId() != currentDeviceId) - v->getDataBuffer()->migrate(); - - v->getDataBuffer()->syncToSpecial(); - } - - // we don't tick write list, only ensure the same device affinity - for (const auto &v:writeList) { - if (v == nullptr) - continue; - - // special case for legacy ops - views can be updated on host side, thus original array can be not updated - if (!v->getDataBuffer()->isSpecialActual()) - v->getDataBuffer()->syncToSpecial(); - - if (v->getDataBuffer()->deviceId() != currentDeviceId) - v->getDataBuffer()->migrate(); - } - } - - void InteropDataBuffer::registerPrimaryUse(const std::vector& writeList, const std::vector& readList) { - for (const auto &v:writeList) { - if (v == nullptr) - continue; - } - } - - void InteropDataBuffer::preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { - for (const auto &v:readList) { - if (v == nullptr) - continue; - - v->getDataBuffer()->syncToPrimary(LaunchContext::defaultContext()); - } - } - - void InteropDataBuffer::expand(size_t newlength) { - _dataBuffer->expand(newlength * DataTypeUtils::sizeOf(_dataBuffer->getDataType())); - } - - void InteropDataBuffer::setDeviceId(int deviceId) { - _dataBuffer->setDeviceId(deviceId); - } +InteropDataBuffer::InteropDataBuffer(InteropDataBuffer& dataBuffer, + uint64_t length, uint64_t offset) { + _dataBuffer = dataBuffer.getDataBuffer(); + + // offset is always absolute to the original buffer + _offset = offset; + + if (_offset + length > _dataBuffer->getLenInBytes()) { + throw std::runtime_error("offset + length is higher than original length"); + } +} + +InteropDataBuffer::InteropDataBuffer(std::shared_ptr databuffer) { + _dataBuffer = databuffer; +} + +InteropDataBuffer::InteropDataBuffer(size_t elements, sd::DataType dtype, + bool allocateBoth) { + if (elements == 0) { + _dataBuffer = std::make_shared(); + _dataBuffer->setDataType(dtype); + } else { + _dataBuffer = + std::make_shared(elements, dtype, nullptr, allocateBoth); + } +} + +std::shared_ptr InteropDataBuffer::getDataBuffer() const { + return _dataBuffer; +} + +std::shared_ptr InteropDataBuffer::dataBuffer() { + return _dataBuffer; +} + +void* InteropDataBuffer::primary() const { + return reinterpret_cast(_dataBuffer->primary()) + _offset; +} + +void* InteropDataBuffer::special() const { + return reinterpret_cast(_dataBuffer->special()) + _offset; +} + +void InteropDataBuffer::setPrimary(void* ptr, size_t length) { + _dataBuffer->setPrimaryBuffer(ptr, length); +} + +void InteropDataBuffer::setSpecial(void* ptr, size_t length) { + _dataBuffer->setSpecialBuffer(ptr, length); +} + +uint64_t InteropDataBuffer::offset() const { return _offset; } + +void InteropDataBuffer::setOffset(uint64_t offset) { _offset = offset; } + +int InteropDataBuffer::deviceId() const { return _dataBuffer->deviceId(); } + +void InteropDataBuffer::registerSpecialUse( + const std::vector& writeList, + const std::vector& readList) { + for (const auto& v : writeList) { + if (v == nullptr) continue; + + v->getDataBuffer()->writeSpecial(); + } +} + +void InteropDataBuffer::prepareSpecialUse( + const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables) { + auto currentDeviceId = sd::AffinityManager::currentDeviceId(); + for (const auto& v : readList) { + if (v == nullptr) continue; + + if (v->getDataBuffer()->deviceId() != currentDeviceId) + v->getDataBuffer()->migrate(); + + v->getDataBuffer()->syncToSpecial(); + } + + // we don't tick write list, only ensure the same device affinity + for (const auto& v : writeList) { + if (v == nullptr) continue; + + // special case for legacy ops - views can be updated on host side, thus + // original array can be not updated + if (!v->getDataBuffer()->isSpecialActual()) + v->getDataBuffer()->syncToSpecial(); + + if (v->getDataBuffer()->deviceId() != currentDeviceId) + v->getDataBuffer()->migrate(); + } +} + +void InteropDataBuffer::registerPrimaryUse( + const std::vector& writeList, + const std::vector& readList) { + for (const auto& v : writeList) { + if (v == nullptr) continue; + } +} + +void InteropDataBuffer::preparePrimaryUse( + const std::vector& writeList, + const std::vector& readList, + bool synchronizeWritables) { + for (const auto& v : readList) { + if (v == nullptr) continue; + + v->getDataBuffer()->syncToPrimary(LaunchContext::defaultContext()); + } +} + +void InteropDataBuffer::expand(size_t newlength) { + _dataBuffer->expand(newlength * + DataTypeUtils::sizeOf(_dataBuffer->getDataType())); +} + +void InteropDataBuffer::setDeviceId(int deviceId) { + _dataBuffer->setDeviceId(deviceId); } +} // namespace sd diff --git a/libnd4j/include/array/impl/ManagedDataBuffer.cpp b/libnd4j/include/array/impl/ManagedDataBuffer.cpp index 8f2191fa85b2..8f486f27faca 100644 --- a/libnd4j/include/array/impl/ManagedDataBuffer.cpp +++ b/libnd4j/include/array/impl/ManagedDataBuffer.cpp @@ -21,17 +21,19 @@ #include namespace sd { - ManagedDataBuffer::ManagedDataBuffer(graph::GraphMemoryManager &manager, uint64_t numberOfBytes, DataType dtype, memory::MemoryZone zone) : - _manager(manager), - _zone(zone), - _descriptor(manager.allocate(numberOfBytes, zone)) { +ManagedDataBuffer::ManagedDataBuffer(graph::GraphMemoryManager &manager, + uint64_t numberOfBytes, DataType dtype, + memory::MemoryZone zone) + : _manager(manager), + _zone(zone), + _descriptor(manager.allocate(numberOfBytes, zone)) { + _lenInBytes = numberOfBytes; + _dataType = dtype; +} - _lenInBytes = numberOfBytes; - _dataType = dtype; - } - - ManagedDataBuffer::~ManagedDataBuffer() { - // if we know that MDB can be released - it means that all NDArrays were released, so it's really safe to release - _manager.release(_descriptor); - } -} \ No newline at end of file +ManagedDataBuffer::~ManagedDataBuffer() { + // if we know that MDB can be released - it means that all NDArrays were + // released, so it's really safe to release + _manager.release(_descriptor); +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/array/impl/NDArray.cpp b/libnd4j/include/array/impl/NDArray.cpp index 542b54c16e28..fc3422f416d3 100644 --- a/libnd4j/include/array/impl/NDArray.cpp +++ b/libnd4j/include/array/impl/NDArray.cpp @@ -21,20 +21,19 @@ #define __NDARRAY__HPP__ #include +#include #include -#include +#include +#include #include #include -#include +#include #include -#include #include -#include -#include #include -#include -#include #include +#include +#include #include namespace sd { @@ -51,1118 +50,1222 @@ SD_EXPORT std::u32string NDArray::e(const Nd4jLong i) const; //////////////////////////////////////////////////////////////////////// // copy constructor NDArray::NDArray(const NDArray& other) { - - //setShapeInfo(ShapeDescriptor(other.dataType(), other.ordering(), other.shapeOf(), other.rankOf())); -/* - if(!isEmpty()) { - _buffer = std::make_shared(other.lengthOf() * other.sizeOfT(), other.dataType(), other.getContext()->getWorkspace()); - this->assign(&other); - } - else - _buffer = std::make_shared(); - */ - _buffer = other._buffer; - _shapeInfo = other._shapeInfo; - _shapeInfoD = other._shapeInfoD; - _length = other._length; - _isAttached = other._isAttached; - _isView = other._isView; - _context = other._context; - _dataType = other._dataType; - _deviceId = other._deviceId; - _offset = other._offset; + // setShapeInfo(ShapeDescriptor(other.dataType(), other.ordering(), + // other.shapeOf(), other.rankOf())); + /* + if(!isEmpty()) { + _buffer = std::make_shared(other.lengthOf() * + other.sizeOfT(), other.dataType(), other.getContext()->getWorkspace()); + this->assign(&other); + } + else + _buffer = std::make_shared(); + */ + _buffer = other._buffer; + _shapeInfo = other._shapeInfo; + _shapeInfoD = other._shapeInfoD; + _length = other._length; + _isAttached = other._isAttached; + _isView = other._isView; + _context = other._context; + _dataType = other._dataType; + _deviceId = other._deviceId; + _offset = other._offset; } //////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const char order, const std::vector &shape, sd::DataType dtype, sd::LaunchContext * context) { - - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("Rank of NDArray can't exceed 32"); - - _context = context; - _isAttached = _context->getWorkspace() != nullptr; - _offset = 0; - - if (shape.empty()) - setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)); - else - setShapeInfo(ShapeDescriptor(dtype, order, shape)); +NDArray::NDArray(const char order, const std::vector& shape, + sd::DataType dtype, sd::LaunchContext* context) { + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument("Rank of NDArray can't exceed 32"); + + _context = context; + _isAttached = _context->getWorkspace() != nullptr; + _offset = 0; + + if (shape.empty()) + setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)); + else + setShapeInfo(ShapeDescriptor(dtype, order, shape)); - _buffer = std::make_shared(lengthOf() * DataTypeUtils::sizeOf(dtype), dtype, getContext()->getWorkspace()); - _buffer->setToZeroBuffers(); + _buffer = + std::make_shared(lengthOf() * DataTypeUtils::sizeOf(dtype), + dtype, getContext()->getWorkspace()); + _buffer->setToZeroBuffers(); } - bool NDArray::defined() const { - return _shapeInfo != nullptr; - } +bool NDArray::defined() const { return _shapeInfo != nullptr; } - bool NDArray::undefined() const { - return _shapeInfo == nullptr; - } +bool NDArray::undefined() const { return _shapeInfo == nullptr; } //////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const char order, const std::vector &shape, const std::vector& data, sd::DataType dtype, sd::LaunchContext * context) { - - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("Rank of NDArray can't exceed 32"); - - _context = context; - _offset = 0; - - if (shape.size() == 0) { - if (data.size() == 0) - setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)); - else - setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); - } else { - setShapeInfo(ShapeDescriptor(dtype, order, shape)); - } - - if (lengthOf() != data.size()) { - nd4j_printf("NDArray constructor: data size [%i] doesn't match shape length [%i]\n", data.size(), lengthOf()); - throw std::runtime_error("Data size doesn't match shape"); - } - - _buffer = std::make_shared(lengthOf() * DataTypeUtils::sizeOf(dtype), dtype, getContext()->getWorkspace(), true); - - for(Nd4jLong i=0; i < lengthOf(); ++i) { - BUILD_SINGLE_PARTIAL_SELECTOR(dtype, templatedDoubleAssign<, double>(buffer(), i, reinterpret_cast(data.data()), i), LIBND4J_TYPES); - } - tickWriteHost(); - syncToDevice(); +NDArray::NDArray(const char order, const std::vector& shape, + const std::vector& data, sd::DataType dtype, + sd::LaunchContext* context) { + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument("Rank of NDArray can't exceed 32"); + + _context = context; + _offset = 0; + + if (shape.size() == 0) { + if (data.size() == 0) + setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)); + else + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + } else { + setShapeInfo(ShapeDescriptor(dtype, order, shape)); + } + + if (lengthOf() != data.size()) { + nd4j_printf( + "NDArray constructor: data size [%i] doesn't match shape length [%i]\n", + data.size(), lengthOf()); + throw std::runtime_error("Data size doesn't match shape"); + } + + _buffer = + std::make_shared(lengthOf() * DataTypeUtils::sizeOf(dtype), + dtype, getContext()->getWorkspace(), true); + + for (Nd4jLong i = 0; i < lengthOf(); ++i) { + BUILD_SINGLE_PARTIAL_SELECTOR( + dtype, + templatedDoubleAssign<, double>( + buffer(), i, reinterpret_cast(data.data()), i), + LIBND4J_TYPES); + } + tickWriteHost(); + syncToDevice(); } - //////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const NDArray *other, const bool copyStrides, sd::LaunchContext* context) { - - _context = context; - _offset = 0; - _isAttached = getContext()->getWorkspace() != nullptr; +NDArray::NDArray(const NDArray* other, const bool copyStrides, + sd::LaunchContext* context) { + _context = context; + _offset = 0; + _isAttached = getContext()->getWorkspace() != nullptr; - if (copyStrides) - setShapeInfo(ShapeDescriptor(other->_shapeInfo)); - else - setShapeInfo(ShapeDescriptor(other->dataType(), other->ordering(), other->shapeOf(), other->rankOf())); + if (copyStrides) + setShapeInfo(ShapeDescriptor(other->_shapeInfo)); + else + setShapeInfo(ShapeDescriptor(other->dataType(), other->ordering(), + other->shapeOf(), other->rankOf())); - if (!isEmpty()) - _buffer = std::make_shared(lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace()); + if (!isEmpty()) + _buffer = std::make_shared(lengthOf() * sizeOfT(), dataType(), + getContext()->getWorkspace()); } //////////////////////////////////////////////////////////////////////// -NDArray::NDArray(void* buffer, const char order, const std::vector &shape, sd::DataType dtype, sd::LaunchContext * context, const bool isBuffAlloc) { - - if (shape.empty()) - throw std::runtime_error("NDArray constructor: input shape is empty !"); +NDArray::NDArray(void* buffer, const char order, + const std::vector& shape, sd::DataType dtype, + sd::LaunchContext* context, const bool isBuffAlloc) { + if (shape.empty()) + throw std::runtime_error("NDArray constructor: input shape is empty !"); - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("Rank of NDArray can't exceed 32"); + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument("Rank of NDArray can't exceed 32"); - _context = context; - _offset = 0; - _isAttached = getContext()->getWorkspace() != nullptr; + _context = context; + _offset = 0; + _isAttached = getContext()->getWorkspace() != nullptr; - setShapeInfo(ShapeDescriptor(dtype, order, shape)); + setShapeInfo(ShapeDescriptor(dtype, order, shape)); - _buffer = std::make_shared(buffer, lengthOf() * sizeOfT(), dataType(), isBuffAlloc, getContext()->getWorkspace()); + _buffer = + std::make_shared(buffer, lengthOf() * sizeOfT(), dataType(), + isBuffAlloc, getContext()->getWorkspace()); } //////////////////////////////////////////////////////////////////////// -// creates new NDArray using shape information from "shapeInfo" array, set all elements in new array to be zeros -NDArray::NDArray(const Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides, sd::LaunchContext * context, const bool nullify) { - - if (shapeInfo == nullptr) - throw std::runtime_error("NDArray constructor: can't be initalized without shapeinfo"); +// creates new NDArray using shape information from "shapeInfo" array, set all +// elements in new array to be zeros +NDArray::NDArray(const Nd4jLong* shapeInfo, const sd::DataType dtype, + const bool copyStrides, sd::LaunchContext* context, + const bool nullify) { + if (shapeInfo == nullptr) + throw std::runtime_error( + "NDArray constructor: can't be initalized without shapeinfo"); - if ((int) shapeInfo[0] > MAX_RANK) - throw std::invalid_argument("Rank of NDArray can't exceed 32"); + if ((int)shapeInfo[0] > MAX_RANK) + throw std::invalid_argument("Rank of NDArray can't exceed 32"); - _context = context; - _offset = 0; + _context = context; + _offset = 0; - if (copyStrides) - setShapeInfo(ShapeDescriptor(shapeInfo, dtype)); - else - setShapeInfo(ShapeDescriptor(dtype, shape::order(shapeInfo), shape::shapeOf(shapeInfo), shape::rank(shapeInfo))); + if (copyStrides) + setShapeInfo(ShapeDescriptor(shapeInfo, dtype)); + else + setShapeInfo(ShapeDescriptor(dtype, shape::order(shapeInfo), + shape::shapeOf(shapeInfo), + shape::rank(shapeInfo))); - if (!isEmpty()) { - _buffer = std::make_shared(lengthOf() * sizeOfT(), dtype, getContext()->getWorkspace()); + if (!isEmpty()) { + _buffer = std::make_shared(lengthOf() * sizeOfT(), dtype, + getContext()->getWorkspace()); - if (nullify) - _buffer->setToZeroBuffers(); - } + if (nullify) _buffer->setToZeroBuffers(); + } } //////////////////////////////////////////////////////////////////////// // scalar constructor -NDArray::NDArray(sd::DataType dtype, sd::LaunchContext* context, const bool isScalar) { - - _context = context; - _offset = 0; - _isAttached = getContext()->getWorkspace() != nullptr; +NDArray::NDArray(sd::DataType dtype, sd::LaunchContext* context, + const bool isScalar) { + _context = context; + _offset = 0; + _isAttached = getContext()->getWorkspace() != nullptr; - if (isScalar) { - setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); - _buffer = std::make_shared(sizeOfT(), dtype, getContext()->getWorkspace()); - _buffer->setToZeroBuffers(); - } - else - setShapeInfo(ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype)); + if (isScalar) { + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + _buffer = std::make_shared(sizeOfT(), dtype, + getContext()->getWorkspace()); + _buffer->setToZeroBuffers(); + } else + setShapeInfo(ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype)); } ////////////////////////////////////////////////////////////////////////// // move constructor NDArray::NDArray(NDArray&& other) noexcept { + _isView = other._isView; + _buffer = other._buffer; + _shapeInfo = other._shapeInfo; + _shapeInfoD = other._shapeInfoD; + _context = other._context; + _dataType = other._dataType; + _length = other._length; + _offset = other._offset; - _isView = other._isView; - _buffer = other._buffer; - _shapeInfo = other._shapeInfo; - _shapeInfoD = other._shapeInfoD; - _context = other._context; - _dataType = other._dataType; - _length = other._length; - _offset = other._offset; - - other._buffer = std::make_shared(); - other._shapeInfo = other._shapeInfoD = nullptr; - other._length = 0; + other._buffer = std::make_shared(); + other._shapeInfo = other._shapeInfoD = nullptr; + other._length = 0; } //////////////////////////////////////////////////////////////////////// -//constructor, create empty array at given workspace -NDArray::NDArray(sd::LaunchContext * context) { - _buffer = std::make_shared(); - _shapeInfo = nullptr; - _shapeInfoD = nullptr; - _offset = 0; - _context = context; - _length = 0; +// constructor, create empty array at given workspace +NDArray::NDArray(sd::LaunchContext* context) { + _buffer = std::make_shared(); + _shapeInfo = nullptr; + _shapeInfoD = nullptr; + _offset = 0; + _context = context; + _length = 0; } - //////////////////////////////////////////////////////////////////////// - // creates new NDArray using shape information from "shapeInfo" array, set all elements in new array to be zeros, set dtype as array type - NDArray::NDArray(const Nd4jLong* shapeInfo, const bool copyStrides, sd::LaunchContext * context, const bool nullify): - NDArray(shapeInfo, ArrayOptions::dataType(shapeInfo), copyStrides, context) { - } - - //////////////////////////////////////////////////////////////////////// - NDArray::NDArray(std::shared_ptr buffer, const ShapeDescriptor& descriptor, sd::LaunchContext* context, const Nd4jLong offset) { - - _context = context; - _offset = offset; - - setShapeInfo(descriptor); - - _buffer = buffer; - - _isView = offset > 0 || _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes(); - } - - NDArray::NDArray(void *buffer, Nd4jLong *shapeInfo, sd::LaunchContext * context, const bool isBuffAlloc) : NDArray::NDArray(buffer, const_cast(shapeInfo), context, isBuffAlloc) { - // - } - - //////////////////////////////////////////////////////////////////////// - // do not allocate memory, memory for array is passed from outside - NDArray::NDArray(void *buffer, const Nd4jLong *shapeInfo, sd::LaunchContext * context, const bool isBuffAlloc) { - - if (buffer == nullptr && ArrayOptions::arrayType(shapeInfo) != ArrayType::EMPTY) - throw std::runtime_error("NDArray constructor: can't be initalized with nullptr buffer !"); +//////////////////////////////////////////////////////////////////////// +// creates new NDArray using shape information from "shapeInfo" array, set all +// elements in new array to be zeros, set dtype as array type +NDArray::NDArray(const Nd4jLong* shapeInfo, const bool copyStrides, + sd::LaunchContext* context, const bool nullify) + : NDArray(shapeInfo, ArrayOptions::dataType(shapeInfo), copyStrides, + context) {} - if (shapeInfo == nullptr) - throw std::runtime_error("NDArray constructor: can't be initalized without shapeinfo !"); +//////////////////////////////////////////////////////////////////////// +NDArray::NDArray(std::shared_ptr buffer, + const ShapeDescriptor& descriptor, sd::LaunchContext* context, + const Nd4jLong offset) { + _context = context; + _offset = offset; - if ((int) shapeInfo[0] > MAX_RANK) - throw std::invalid_argument("NDArray constructor: rank of NDArray can't exceed 32 !"); + setShapeInfo(descriptor); - _context = context; - _isAttached = getContext()->getWorkspace() != nullptr; - _offset = 0; + _buffer = buffer; - setShapeInfo(ShapeDescriptor(shapeInfo)); + _isView = offset > 0 || _length * DataTypeUtils::sizeOf(_dataType) < + buffer->getLenInBytes(); +} - if (this->isEmpty()) { - tickReadDevice(); - tickReadHost(); - } - else { - _buffer = std::make_shared(buffer, lengthOf() * sizeOfT(), dataType(), isBuffAlloc, getContext()->getWorkspace()); - } - } +NDArray::NDArray(void* buffer, Nd4jLong* shapeInfo, sd::LaunchContext* context, + const bool isBuffAlloc) + : NDArray::NDArray(buffer, const_cast(shapeInfo), context, + isBuffAlloc) { + // +} - //////////////////////////////////////////////////////////////////////// - // do not allocate memory, memory for array is passed from outside - // we suppose the content of both (device and host) buffers is identical - NDArray::NDArray(void *buffer, void* bufferD, const Nd4jLong *shapeInfo, sd::LaunchContext * context, const bool isBuffAlloc, const bool isBuffDAlloc) { +//////////////////////////////////////////////////////////////////////// +// do not allocate memory, memory for array is passed from outside +NDArray::NDArray(void* buffer, const Nd4jLong* shapeInfo, + sd::LaunchContext* context, const bool isBuffAlloc) { + if (buffer == nullptr && + ArrayOptions::arrayType(shapeInfo) != ArrayType::EMPTY) + throw std::runtime_error( + "NDArray constructor: can't be initalized with nullptr buffer !"); + + if (shapeInfo == nullptr) + throw std::runtime_error( + "NDArray constructor: can't be initalized without shapeinfo !"); + + if ((int)shapeInfo[0] > MAX_RANK) + throw std::invalid_argument( + "NDArray constructor: rank of NDArray can't exceed 32 !"); + + _context = context; + _isAttached = getContext()->getWorkspace() != nullptr; + _offset = 0; + + setShapeInfo(ShapeDescriptor(shapeInfo)); + + if (this->isEmpty()) { + tickReadDevice(); + tickReadHost(); + } else { + _buffer = + std::make_shared(buffer, lengthOf() * sizeOfT(), dataType(), + isBuffAlloc, getContext()->getWorkspace()); + } +} - if (shapeInfo == nullptr) - throw std::runtime_error("NDArray constructor cuda: can't be initalized without shapeinfo"); +//////////////////////////////////////////////////////////////////////// +// do not allocate memory, memory for array is passed from outside +// we suppose the content of both (device and host) buffers is identical +NDArray::NDArray(void* buffer, void* bufferD, const Nd4jLong* shapeInfo, + sd::LaunchContext* context, const bool isBuffAlloc, + const bool isBuffDAlloc) { + if (shapeInfo == nullptr) + throw std::runtime_error( + "NDArray constructor cuda: can't be initalized without shapeinfo"); - if ((int) shapeInfo[0] > MAX_RANK) - throw std::invalid_argument("NDArray constructor cuda: rank of NDArray can't exceed 32"); + if ((int)shapeInfo[0] > MAX_RANK) + throw std::invalid_argument( + "NDArray constructor cuda: rank of NDArray can't exceed 32"); - _context = context; - _offset = 0; + _context = context; + _offset = 0; - setShapeInfo(ShapeDescriptor(shapeInfo)); + setShapeInfo(ShapeDescriptor(shapeInfo)); - if (!isEmpty()) - _buffer = std::make_shared(buffer, bufferD, lengthOf() * sizeOfT(), dataType(), isBuffAlloc, isBuffDAlloc, getContext()->getWorkspace()); - } + if (!isEmpty()) + _buffer = std::make_shared( + buffer, bufferD, lengthOf() * sizeOfT(), dataType(), isBuffAlloc, + isBuffDAlloc, getContext()->getWorkspace()); +} ////////////////////////////////////////////////////////////////////////// -NDArray::NDArray(std::shared_ptr buffer, const char order, const std::vector &shape, sd::LaunchContext* context) { +NDArray::NDArray(std::shared_ptr buffer, const char order, + const std::vector& shape, + sd::LaunchContext* context) { + if (shape.empty()) + throw std::runtime_error("NDArray constructor: input shape is empty !"); - if (shape.empty()) - throw std::runtime_error("NDArray constructor: input shape is empty !"); + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument( + "NDArray constructor: rank of NDArray can't exceed 32"); - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("NDArray constructor: rank of NDArray can't exceed 32"); + _context = context; + _offset = 0; - _context = context; - _offset = 0; + setShapeInfo(ShapeDescriptor(buffer->getDataType(), order, shape)); - setShapeInfo(ShapeDescriptor(buffer->getDataType(), order, shape)); + _buffer = buffer; - _buffer = buffer; - - _isView = _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes(); + _isView = + _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes(); } ///////////////////////////////////////////////////////////////////////// // u16 string constructors -NDArray::NDArray(const std::u16string& u16string, sd::DataType dtype, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dtype)) { - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - } +NDArray::NDArray(const std::u16string& u16string, sd::DataType dtype, + sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dtype)) { + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + } - if (!unicode::isStringValidU16(u16string.data(), u16string.data() + u16string.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } + if (!unicode::isStringValidU16(u16string.data(), + u16string.data() + u16string.size())) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } - // one word that is why used 1 - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1); + // one word that is why used 1 + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1); - Nd4jLong dataLength = [&] { - if (dtype == DataType::UTF16) { - return static_cast(u16string.size() * sizeof(uint16_t)); - } - if (dtype == DataType::UTF32) { - return unicode::offsetUtf16StringInUtf32(u16string.data(), u16string.size()); - } - return unicode::offsetUtf16StringInUtf8(u16string.data(), u16string.size()); - }(); + Nd4jLong dataLength = [&] { + if (dtype == DataType::UTF16) { + return static_cast(u16string.size() * sizeof(uint16_t)); + } + if (dtype == DataType::UTF32) { + return unicode::offsetUtf16StringInUtf32(u16string.data(), + u16string.size()); + } + return unicode::offsetUtf16StringInUtf8(u16string.data(), u16string.size()); + }(); - Nd4jLong offsets[2] = { 0 , dataLength }; + Nd4jLong offsets[2] = {0, dataLength}; - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + _buffer = std::make_shared(headerLength + dataLength, dtype, + context->getWorkspace(), true); - _context = context; - _isAttached = getContext()->getWorkspace() != nullptr; - _offset = 0; + _context = context; + _isAttached = getContext()->getWorkspace() != nullptr; + _offset = 0; - setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); - memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); + memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); - auto data = reinterpret_cast(bufferAsT() + headerLength); - if (dtype == DataType::UTF8) { - unicode::utf16to8(u16string.data(), data, u16string.size()); - } - else if (dtype == DataType::UTF16) { - memcpy(data, u16string.data(), dataLength); - } - else { - unicode::utf16to32(u16string.data(), data, u16string.size()); - } + auto data = reinterpret_cast(bufferAsT() + headerLength); + if (dtype == DataType::UTF8) { + unicode::utf16to8(u16string.data(), data, u16string.size()); + } else if (dtype == DataType::UTF16) { + memcpy(data, u16string.data(), dataLength); + } else { + unicode::utf16to32(u16string.data(), data, u16string.size()); + } - tickWriteHost(); - syncToDevice(); + tickWriteHost(); + syncToDevice(); } ///////////////////////////////////////////////////////////////////////// // u32 string constructors -NDArray::NDArray(const std::u32string& u32string, sd::DataType dtype, sd::LaunchContext* context) { +NDArray::NDArray(const std::u32string& u32string, sd::DataType dtype, + sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dtype)) { + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + } - if (!DataTypeUtils::isS(dtype)) { - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - } + if (!unicode::isStringValidU32(u32string.data(), + u32string.data() + u32string.size())) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + // one word that is why used 1 + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1); - if (!unicode::isStringValidU32(u32string.data(), u32string.data() + u32string.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + Nd4jLong dataLength = [&] { + if (dtype == DataType::UTF16) { + return unicode::offsetUtf32StringInUtf16(u32string.data(), + u32string.size()); } - // one word that is why used 1 - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1); - - Nd4jLong dataLength = [&] { - if (dtype == DataType::UTF16) { - return unicode::offsetUtf32StringInUtf16(u32string.data(), u32string.size()); - } - if (dtype == DataType::UTF32) { - return static_cast(sizeof(uint32_t) * u32string.size()); - } - return unicode::offsetUtf32StringInUtf8(u32string.data(), u32string.size()); - }(); + if (dtype == DataType::UTF32) { + return static_cast(sizeof(uint32_t) * u32string.size()); + } + return unicode::offsetUtf32StringInUtf8(u32string.data(), u32string.size()); + }(); - Nd4jLong offsets[2] = { 0 , dataLength }; + Nd4jLong offsets[2] = {0, dataLength}; - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + _buffer = std::make_shared(headerLength + dataLength, dtype, + context->getWorkspace(), true); - _context = context; - _isAttached = getContext()->getWorkspace() != nullptr; - _offset = 0; + _context = context; + _isAttached = getContext()->getWorkspace() != nullptr; + _offset = 0; - setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); - memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); + memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); - auto data = reinterpret_cast(bufferAsT() + headerLength); - if (dtype == DataType::UTF8) { - unicode::utf32to8(u32string.data(), data, u32string.size()); - } - else if (dtype == DataType::UTF16) { - unicode::utf32to16(u32string.data(), data, u32string.size()); - } - else { - memcpy(data, u32string.data(), u32string.size() * sizeof(uint32_t)); - } + auto data = reinterpret_cast(bufferAsT() + headerLength); + if (dtype == DataType::UTF8) { + unicode::utf32to8(u32string.data(), data, u32string.size()); + } else if (dtype == DataType::UTF16) { + unicode::utf32to16(u32string.data(), data, u32string.size()); + } else { + memcpy(data, u32string.data(), u32string.size() * sizeof(uint32_t)); + } - tickWriteHost(); - syncToDevice(); + tickWriteHost(); + syncToDevice(); } ///////////////////////////////////////////////////////////////////////// // u8 string constructors -NDArray::NDArray(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dtype)) { - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - } +NDArray::NDArray(const std::string& str, sd::DataType dtype, + sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dtype)) { + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + } - if (!unicode::isStringValidU8(str.data(), str.data() + str.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } + if (!unicode::isStringValidU8(str.data(), str.data() + str.size())) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } - // one word that is why used 1 - auto headerLength = ShapeUtils::stringBufferHeaderRequirements(1); + // one word that is why used 1 + auto headerLength = ShapeUtils::stringBufferHeaderRequirements(1); - Nd4jLong dataLength = [&] { - if (dtype == DataType::UTF16) { - return unicode::offsetUtf8StringInUtf16(str.data(), str.size()); - } - if (dtype == DataType::UTF32) { - return unicode::offsetUtf8StringInUtf32(str.data(), str.size()); - } - return static_cast(str.size()); - }(); + Nd4jLong dataLength = [&] { + if (dtype == DataType::UTF16) { + return unicode::offsetUtf8StringInUtf16(str.data(), str.size()); + } + if (dtype == DataType::UTF32) { + return unicode::offsetUtf8StringInUtf32(str.data(), str.size()); + } + return static_cast(str.size()); + }(); - Nd4jLong offsets[2] = { 0 , dataLength }; + Nd4jLong offsets[2] = {0, dataLength}; - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + _buffer = std::make_shared(headerLength + dataLength, dtype, + context->getWorkspace(), true); - _context = context; - _isAttached = getContext()->getWorkspace() != nullptr; - _offset = 0; + _context = context; + _isAttached = getContext()->getWorkspace() != nullptr; + _offset = 0; - setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); - memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); + memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); - auto data = reinterpret_cast(bufferAsT() + headerLength); + auto data = reinterpret_cast(bufferAsT() + headerLength); - if (dtype == DataType::UTF8) { - memcpy(data, str.data(), str.size()); - } - else if (dtype == DataType::UTF16) { - unicode::utf8to16(str.data(), data, str.size()); - } - else { - unicode::utf8to32(str.data(), data, str.size()); - } + if (dtype == DataType::UTF8) { + memcpy(data, str.data(), str.size()); + } else if (dtype == DataType::UTF16) { + unicode::utf8to16(str.data(), data, str.size()); + } else { + unicode::utf8to32(str.data(), data, str.size()); + } - tickWriteHost(); - syncToDevice(); + tickWriteHost(); + syncToDevice(); } ///////////////////////////////////////////////////////////////////////// // constructors for vector of strings -NDArray::NDArray(const std::vector& shape, const std::vector& string, const sd::DataType dataType, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dataType)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - - for (const auto& str : string) { - if (!unicode::isStringValidU8(str, str + std::char_traits::length(str)) ) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - } - - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - - std::vector offsets(string.size() + 1); - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dataType == DataType::UTF16) - return unicode::offsetUtf8StringInUtf16(string[e], std::char_traits::length(string[e])); - if (dataType == DataType::UTF32) - return unicode::offsetUtf8StringInUtf32(string[e], std::char_traits::length(string[e])); - return static_cast(std::char_traits::length(string[e])); - }(); - } - offsets[string.size()] = dataLength; +NDArray::NDArray(const std::vector& shape, + const std::vector& string, + const sd::DataType dataType, sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dataType)) + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument( + "NDArray::NDArray: Number of strings should match length of array"); + + for (const auto& str : string) { + if (!unicode::isStringValidU8(str, + str + std::char_traits::length(str))) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = + ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dataType == DataType::UTF16) + return unicode::offsetUtf8StringInUtf16( + string[e], std::char_traits::length(string[e])); + if (dataType == DataType::UTF32) + return unicode::offsetUtf8StringInUtf32( + string[e], std::char_traits::length(string[e])); + return static_cast(std::char_traits::length(string[e])); + }(); + } + offsets[string.size()] = dataLength; - _buffer = std::make_shared(headerLength + dataLength, dataType, context->getWorkspace(), true); + _buffer = std::make_shared(headerLength + dataLength, dataType, + context->getWorkspace(), true); - _context = context; - _offset = 0; + _context = context; + _offset = 0; - setShapeInfo(ShapeDescriptor(dataType, 'c', shape)); + setShapeInfo(ShapeDescriptor(dataType, 'c', shape)); - _isView = false; + _isView = false; - setAttached(context->getWorkspace() != nullptr); + setAttached(context->getWorkspace() != nullptr); - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + memcpy(bufferAsT(), offsets.data(), + offsets.size() * sizeof(Nd4jLong)); - auto data = reinterpret_cast(bufferAsT() + headerLength); + auto data = reinterpret_cast(bufferAsT() + headerLength); - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dataType == DataType::UTF16) { - unicode::utf8to16(string[e], cdata, std::char_traits::length(string[e])); - } - else if (dataType == DataType::UTF32) { - unicode::utf8to32(string[e], cdata, std::char_traits::length(string[e])); - } - else { - memcpy(cdata, string[e], std::char_traits::length(string[e])); - } - } - }; + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dataType == DataType::UTF16) { + unicode::utf8to16(string[e], cdata, + std::char_traits::length(string[e])); + } else if (dataType == DataType::UTF32) { + unicode::utf8to32(string[e], cdata, + std::char_traits::length(string[e])); + } else { + memcpy(cdata, string[e], std::char_traits::length(string[e])); + } + } + }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - tickWriteHost(); - syncToDevice(); + tickWriteHost(); + syncToDevice(); } ///////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const std::vector& shape, const std::vector& string, const sd::DataType dataType, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dataType)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - - for (const auto& str : string) { - if (!unicode::isStringValidU8(str.data(), str.data() + str.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - } - - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - - std::vector offsets(string.size() + 1); - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dataType == DataType::UTF16) - return unicode::offsetUtf8StringInUtf16(string[e].data(), string[e].size()); - if (dataType == DataType::UTF32) - return unicode::offsetUtf8StringInUtf32(string[e].data(), string[e].size()); - return static_cast(string[e].size()); - }(); - } - - offsets[string.size()] = dataLength; +NDArray::NDArray(const std::vector& shape, + const std::vector& string, + const sd::DataType dataType, sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dataType)) + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument( + "NDArray::NDArray: Number of strings should match length of array"); + + for (const auto& str : string) { + if (!unicode::isStringValidU8(str.data(), str.data() + str.size())) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = + ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dataType == DataType::UTF16) + return unicode::offsetUtf8StringInUtf16(string[e].data(), + string[e].size()); + if (dataType == DataType::UTF32) + return unicode::offsetUtf8StringInUtf32(string[e].data(), + string[e].size()); + return static_cast(string[e].size()); + }(); + } - _buffer = std::make_shared(headerLength + dataLength, dataType, context->getWorkspace(), true); + offsets[string.size()] = dataLength; - _context = context; - _offset = 0; + _buffer = std::make_shared(headerLength + dataLength, dataType, + context->getWorkspace(), true); - setShapeInfo(ShapeDescriptor(dataType, 'c', shape)); + _context = context; + _offset = 0; - _isView = false; + setShapeInfo(ShapeDescriptor(dataType, 'c', shape)); - setAttached(context->getWorkspace() != nullptr); + _isView = false; - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + setAttached(context->getWorkspace() != nullptr); - auto data = reinterpret_cast(bufferAsT() + headerLength); + memcpy(bufferAsT(), offsets.data(), + offsets.size() * sizeof(Nd4jLong)); - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dataType == DataType::UTF16) { - unicode::utf8to16(string[e].data(), cdata, string[e].size()); - } - else if (dataType == DataType::UTF32) { - unicode::utf8to32(string[e].data(), cdata, string[e].size()); - } - else { - memcpy(cdata, string[e].data(), string[e].size()); - } - } - }; + auto data = reinterpret_cast(bufferAsT() + headerLength); - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dataType == DataType::UTF16) { + unicode::utf8to16(string[e].data(), cdata, string[e].size()); + } else if (dataType == DataType::UTF32) { + unicode::utf8to32(string[e].data(), cdata, string[e].size()); + } else { + memcpy(cdata, string[e].data(), string[e].size()); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - tickWriteHost(); - syncToDevice(); + tickWriteHost(); + syncToDevice(); } ///////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dtype)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - - for (const auto& str : string) { - if (!unicode::isStringValidU16(str.data(), str.data() + str.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - } - - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - - std::vector offsets(string.size() + 1); - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dtype == DataType::UTF16) - return static_cast(sizeof(uint16_t) * string[e].size()); - if (dtype == DataType::UTF32) - return unicode::offsetUtf16StringInUtf32(string[e].data(), string[e].size()); - return unicode::offsetUtf16StringInUtf8(string[e].data(), string[e].size()); - }(); - } - offsets[string.size()] = dataLength; +NDArray::NDArray(const std::vector& shape, + const std::vector& string, sd::DataType dtype, + sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument( + "NDArray::NDArray: Number of strings should match length of array"); + + for (const auto& str : string) { + if (!unicode::isStringValidU16(str.data(), str.data() + str.size())) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = + ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dtype == DataType::UTF16) + return static_cast(sizeof(uint16_t) * string[e].size()); + if (dtype == DataType::UTF32) + return unicode::offsetUtf16StringInUtf32(string[e].data(), + string[e].size()); + return unicode::offsetUtf16StringInUtf8(string[e].data(), + string[e].size()); + }(); + } + offsets[string.size()] = dataLength; - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + _buffer = std::make_shared(headerLength + dataLength, dtype, + context->getWorkspace(), true); - _context = context; - _offset = 0; + _context = context; + _offset = 0; - setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); - _isView = false; + _isView = false; - setAttached(context->getWorkspace() != nullptr); + setAttached(context->getWorkspace() != nullptr); - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + memcpy(bufferAsT(), offsets.data(), + offsets.size() * sizeof(Nd4jLong)); - auto data = reinterpret_cast(bufferAsT() + headerLength); + auto data = reinterpret_cast(bufferAsT() + headerLength); - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dtype == DataType::UTF16) { - memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint16_t)); - } - else if (dtype == DataType::UTF32) { - unicode::utf16to32(string[e].data(), cdata, string[e].size()); - } - else { - unicode::utf16to8(string[e].data(), cdata, string[e].size()); - } - } - }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint16_t)); + } else if (dtype == DataType::UTF32) { + unicode::utf16to32(string[e].data(), cdata, string[e].size()); + } else { + unicode::utf16to8(string[e].data(), cdata, string[e].size()); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - tickWriteHost(); - syncToDevice(); + tickWriteHost(); + syncToDevice(); } ///////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dtype)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - - for (const auto& str : string) { - if (!unicode::isStringValidU16(str, str + std::char_traits::length(str))) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - } - - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - - std::vector offsets(string.size() + 1); - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dtype == DataType::UTF16) - return static_cast(sizeof(uint16_t) * std::char_traits::length(string[e])); - if (dtype == DataType::UTF32) - return unicode::offsetUtf16StringInUtf32(string[e], std::char_traits::length(string[e])); - return unicode::offsetUtf16StringInUtf8(string[e], std::char_traits::length(string[e])); - }(); - } - offsets[string.size()] = dataLength; - - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); +NDArray::NDArray(const std::vector& shape, + const std::vector& string, sd::DataType dtype, + sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument( + "NDArray::NDArray: Number of strings should match length of array"); + + for (const auto& str : string) { + if (!unicode::isStringValidU16( + str, str + std::char_traits::length(str))) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = + ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dtype == DataType::UTF16) + return static_cast( + sizeof(uint16_t) * std::char_traits::length(string[e])); + if (dtype == DataType::UTF32) + return unicode::offsetUtf16StringInUtf32( + string[e], std::char_traits::length(string[e])); + return unicode::offsetUtf16StringInUtf8( + string[e], std::char_traits::length(string[e])); + }(); + } + offsets[string.size()] = dataLength; - _context = context; - _offset = 0; + _buffer = std::make_shared(headerLength + dataLength, dtype, + context->getWorkspace(), true); - setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + _context = context; + _offset = 0; - _isView = false; + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); - setAttached(context->getWorkspace() != nullptr); + _isView = false; - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + setAttached(context->getWorkspace() != nullptr); - auto data = reinterpret_cast(bufferAsT() + headerLength); + memcpy(bufferAsT(), offsets.data(), + offsets.size() * sizeof(Nd4jLong)); + auto data = reinterpret_cast(bufferAsT() + headerLength); - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dtype == DataType::UTF16) { - memcpy(cdata, string[e], std::char_traits::length(string[e]) * sizeof(uint16_t)); - } - else if (dtype == DataType::UTF32) { - unicode::utf16to32(string[e], cdata, std::char_traits::length(string[e])); - } - else { - unicode::utf16to8(string[e], cdata, std::char_traits::length(string[e])); - } - } - }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + memcpy( + cdata, string[e], + std::char_traits::length(string[e]) * sizeof(uint16_t)); + } else if (dtype == DataType::UTF32) { + unicode::utf16to32(string[e], cdata, + std::char_traits::length(string[e])); + } else { + unicode::utf16to8(string[e], cdata, + std::char_traits::length(string[e])); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - tickWriteHost(); - syncToDevice(); + tickWriteHost(); + syncToDevice(); } ///////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dtype)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - - for (auto str : string) { - if (!unicode::isStringValidU32(str.data(), str.data() + str.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - } - - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - - std::vector offsets(string.size() + 1); - - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dtype == DataType::UTF16) - return unicode::offsetUtf32StringInUtf16(string[e].data(), string[e].size()); - if (dtype == DataType::UTF32) - return static_cast(sizeof(uint32_t) * string[e].size()); - return unicode::offsetUtf32StringInUtf16(string[e].data(), string[e].size()); - }(); - } - offsets[string.size()] = dataLength; +NDArray::NDArray(const std::vector& shape, + const std::vector& string, sd::DataType dtype, + sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument( + "NDArray::NDArray: invalid DataType, only string dataTypes have to be " + "used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument( + "NDArray::NDArray: Number of strings should match length of array"); + + for (auto str : string) { + if (!unicode::isStringValidU32(str.data(), str.data() + str.size())) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = + ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dtype == DataType::UTF16) + return unicode::offsetUtf32StringInUtf16(string[e].data(), + string[e].size()); + if (dtype == DataType::UTF32) + return static_cast(sizeof(uint32_t) * string[e].size()); + return unicode::offsetUtf32StringInUtf16(string[e].data(), + string[e].size()); + }(); + } + offsets[string.size()] = dataLength; - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + _buffer = std::make_shared(headerLength + dataLength, dtype, + context->getWorkspace(), true); - _context = context; - _offset = 0; + _context = context; + _offset = 0; - setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); - _isView = false; + _isView = false; - setAttached(context->getWorkspace() != nullptr); + setAttached(context->getWorkspace() != nullptr); - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + memcpy(bufferAsT(), offsets.data(), + offsets.size() * sizeof(Nd4jLong)); - auto data = reinterpret_cast(bufferAsT() + headerLength); + auto data = reinterpret_cast(bufferAsT() + headerLength); - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dtype == DataType::UTF16) { - unicode::utf32to16(string[e].data(), cdata, string[e].size()); - } - else if (dtype == DataType::UTF32) { - memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint32_t)); - } - else { - unicode::utf32to8(string[e].data(), cdata, string[e].size()); - } - } - }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + unicode::utf32to16(string[e].data(), cdata, string[e].size()); + } else if (dtype == DataType::UTF32) { + memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint32_t)); + } else { + unicode::utf32to8(string[e].data(), cdata, string[e].size()); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - tickWriteHost(); - syncToDevice(); + tickWriteHost(); + syncToDevice(); } ///////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dtype)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType used"); - - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - - for (const auto& str : string) { - if (!unicode::isStringValidU32(str, str + std::char_traits::length(str))) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - } - - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - - std::vector offsets(string.size() + 1); - - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dtype == DataType::UTF16) - return unicode::offsetUtf32StringInUtf16(string[e], std::char_traits::length(string[e])); - if (dtype == DataType::UTF32) - return static_cast(sizeof(uint32_t) * std::char_traits::length(string[e])); - return unicode::offsetUtf32StringInUtf16(string[e], std::char_traits::length(string[e])); - }(); - } - offsets[string.size()] = dataLength; +NDArray::NDArray(const std::vector& shape, + const std::vector& string, sd::DataType dtype, + sd::LaunchContext* context) { + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument("NDArray::NDArray: invalid DataType used"); + + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument( + "NDArray::NDArray: Number of strings should match length of array"); + + for (const auto& str : string) { + if (!unicode::isStringValidU32( + str, str + std::char_traits::length(str))) { + throw std::invalid_argument( + "NDArray::NDArray: invalid character in input string"); + } + } + + Nd4jLong headerLength = + ShapeUtils::stringBufferHeaderRequirements(string.size()); + + std::vector offsets(string.size() + 1); + + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dtype == DataType::UTF16) + return unicode::offsetUtf32StringInUtf16( + string[e], std::char_traits::length(string[e])); + if (dtype == DataType::UTF32) + return static_cast( + sizeof(uint32_t) * std::char_traits::length(string[e])); + return unicode::offsetUtf32StringInUtf16( + string[e], std::char_traits::length(string[e])); + }(); + } + offsets[string.size()] = dataLength; - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + _buffer = std::make_shared(headerLength + dataLength, dtype, + context->getWorkspace(), true); - _context = context; - _offset = 0; + _context = context; + _offset = 0; - setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); - _isView = _length * DataTypeUtils::sizeOf(_dataType) < _buffer->getLenInBytes(); + _isView = + _length * DataTypeUtils::sizeOf(_dataType) < _buffer->getLenInBytes(); - setAttached(context->getWorkspace() != nullptr); + setAttached(context->getWorkspace() != nullptr); - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + memcpy(bufferAsT(), offsets.data(), + offsets.size() * sizeof(Nd4jLong)); - auto data = reinterpret_cast(bufferAsT() + headerLength); + auto data = reinterpret_cast(bufferAsT() + headerLength); - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dtype == DataType::UTF16) { - unicode::utf32to16(string[e], cdata, std::char_traits::length(string[e])); - } - else if (dtype == DataType::UTF32) { - memcpy(cdata, string[e], std::char_traits::length(string[e]) * sizeof(uint32_t)); - } - else { - unicode::utf32to8(string[e], cdata, std::char_traits::length(string[e])); - } - } - }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + unicode::utf32to16(string[e], cdata, + std::char_traits::length(string[e])); + } else if (dtype == DataType::UTF32) { + memcpy( + cdata, string[e], + std::char_traits::length(string[e]) * sizeof(uint32_t)); + } else { + unicode::utf32to8(string[e], cdata, + std::char_traits::length(string[e])); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - tickWriteHost(); - syncToDevice(); + tickWriteHost(); + syncToDevice(); } //////////////////////////////////////////////////////////////////////// // assignment operator - NDArray& NDArray::operator=(const NDArray& other) { - - if (this == &other || (_shapeInfo == other._shapeInfo && _shapeInfo == nullptr)) - return *this; - - _buffer = other._buffer; - _shapeInfo = other._shapeInfo; - _shapeInfoD = other._shapeInfoD; - _length = other._length; - _isAttached = other._isAttached; - _isView = other._isView; - _context = other._context; - _dataType = other._dataType; - _deviceId = other._deviceId; - _offset = other._offset; - -/* - if (_shapeInfo != nullptr && shape::equalsTypesAndShapesSoft(_shapeInfo, other._shapeInfo)) { - if(!other.isEmpty()) - this->assign(&other); - } - else { - _context = other._context; - _offset = 0; - setShapeInfo(ShapeDescriptor(other.dataType(), other.ordering(), other.shapeOf(), other.rankOf())); - - if(!other.isEmpty()) { - _buffer = std::make_shared(other.lengthOf() * other.sizeOfT(), other.dataType(), other.getContext()->getWorkspace()); - this->assign(&other); - } - else - _buffer = std::make_shared(); - } - */ - +NDArray& NDArray::operator=(const NDArray& other) { + if (this == &other || + (_shapeInfo == other._shapeInfo && _shapeInfo == nullptr)) return *this; -} + _buffer = other._buffer; + _shapeInfo = other._shapeInfo; + _shapeInfoD = other._shapeInfoD; + _length = other._length; + _isAttached = other._isAttached; + _isView = other._isView; + _context = other._context; + _dataType = other._dataType; + _deviceId = other._deviceId; + _offset = other._offset; + + /* + if (_shapeInfo != nullptr && shape::equalsTypesAndShapesSoft(_shapeInfo, + other._shapeInfo)) { if(!other.isEmpty()) this->assign(&other); + } + else { + _context = other._context; + _offset = 0; + setShapeInfo(ShapeDescriptor(other.dataType(), other.ordering(), + other.shapeOf(), other.rankOf())); + + if(!other.isEmpty()) { + _buffer = std::make_shared(other.lengthOf() * + other.sizeOfT(), other.dataType(), other.getContext()->getWorkspace()); + this->assign(&other); + } + else + _buffer = std::make_shared(); + } + */ + + return *this; +} ////////////////////////////////////////////////////////////////////////// bool NDArray::isC() const { - // TODO: this method must be implemented once we add support for complex numbers - return false; + // TODO: this method must be implemented once we add support for complex + // numbers + return false; } ////////////////////////////////////////////////////////////////////////// bool NDArray::isS() const { - return (dataType() == DataType::UTF8 || - dataType() == DataType::UTF16 || - dataType() == DataType::UTF32); + return (dataType() == DataType::UTF8 || dataType() == DataType::UTF16 || + dataType() == DataType::UTF32); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isR() const { - auto xType = ArrayOptions::dataType(this->_shapeInfo); - return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8 || xType == BFLOAT16; + auto xType = ArrayOptions::dataType(this->_shapeInfo); + return xType == FLOAT32 || xType == HALF || xType == DOUBLE || + xType == FLOAT8 || xType == BFLOAT16; } ////////////////////////////////////////////////////////////////////////// bool NDArray::isZ() const { - // TODO: decide if we really want to exclude Bool here - return !isC() && !isR() && !isB() && !isS(); + // TODO: decide if we really want to exclude Bool here + return !isC() && !isR() && !isB() && !isS(); } ////////////////////////////////////////////////////////////////////////// bool NDArray::isB() const { - return ArrayOptions::dataType(this->_shapeInfo) == BOOL; + return ArrayOptions::dataType(this->_shapeInfo) == BOOL; } ////////////////////////////////////////////////////////////////////////// -template +template std::string NDArray::toStringValue(T value) { - std::ostringstream os ; - //throw the value into the string stream - os << value ; - //convert the string stream into a string and return - return os.str() ; + std::ostringstream os; + // throw the value into the string stream + os << value; + // convert the string stream into a string and return + return os.str(); } ////////////////////////////////////////////////////////////////////////// -template<> +template <> std::string NDArray::toStringValue(float16 value) { - std::ostringstream os ; - //throw the value into the string stream - os << (float) value ; - //convert the string stream into a string and return - return os.str() ; + std::ostringstream os; + // throw the value into the string stream + os << (float)value; + // convert the string stream into a string and return + return os.str(); } ////////////////////////////////////////////////////////////////////////// -template<> +template <> std::string NDArray::toStringValue(bfloat16 value) { - std::ostringstream os ; - //throw the value into the string stream - os << (float) value ; - //convert the string stream into a string and return - return os.str() ; + std::ostringstream os; + // throw the value into the string stream + os << (float)value; + // convert the string stream into a string and return + return os.str(); } ////////////////////////////////////////////////////////////////////////// std::string NDArray::asIndexedString(Nd4jLong limit) { - std::ostringstream os; - os << "["; - if (limit < 1 || limit > this->lengthOf()) - limit = this->lengthOf(); - for (Nd4jLong e = 0; e < limit; e++) { - os << toStringValue(this->e(e)); - if (e < limit - 1) - os << ", "; - } - os << "]"; - return os.str(); + std::ostringstream os; + os << "["; + if (limit < 1 || limit > this->lengthOf()) limit = this->lengthOf(); + for (Nd4jLong e = 0; e < limit; e++) { + os << toStringValue(this->e(e)); + if (e < limit - 1) os << ", "; + } + os << "]"; + return os.str(); } ////////////////////////////////////////////////////////////////////////// std::string NDArray::asString(Nd4jLong limit) { - std::ostringstream os; - os << "["; - if (limit < 1 || limit > this->lengthOf()) - limit = this->lengthOf(); - for (Nd4jLong e = 0; e < limit; e++) { - if (this->isR()) - os << toStringValue(this->e(e)); - else if (this->isZ()) - os << toStringValue(this->e(e)); - else if (this->isB()) - os << toStringValue(this->e(e)); - else if (this->isS()) // todo add utf16 and utf32 - os << this->e(e); - if (e < limit - 1) - os << ", "; - } - os << "]"; - return os.str(); + std::ostringstream os; + os << "["; + if (limit < 1 || limit > this->lengthOf()) limit = this->lengthOf(); + for (Nd4jLong e = 0; e < limit; e++) { + if (this->isR()) + os << toStringValue(this->e(e)); + else if (this->isZ()) + os << toStringValue(this->e(e)); + else if (this->isB()) + os << toStringValue(this->e(e)); + else if (this->isS()) // todo add utf16 and utf32 + os << this->e(e); + if (e < limit - 1) os << ", "; + } + os << "]"; + return os.str(); } //////////////////////////////////////////////////////////////////////// -template +template std::vector NDArray::getBufferAsVector() { - std::vector vector(lengthOf()); - for (Nd4jLong e = 0; e < lengthOf(); e++) - vector[e] = this->e(e); - return vector; + std::vector vector(lengthOf()); + for (Nd4jLong e = 0; e < lengthOf(); e++) vector[e] = this->e(e); + return vector; } -BUILD_SINGLE_TEMPLATE(template SD_EXPORT std::vector, NDArray::getBufferAsVector(), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT std::vector, + NDArray::getBufferAsVector(), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// std::vector NDArray::getShapeAsFlatVector() { - std::vector vector(this->rankOf()); - for (int e = 0; e < this->rankOf(); e++) - vector[e] = static_cast(this->sizeAt(e)); - return vector; + std::vector vector(this->rankOf()); + for (int e = 0; e < this->rankOf(); e++) + vector[e] = static_cast(this->sizeAt(e)); + return vector; } //////////////////////////////////////////////////////////////////////// std::vector NDArray::getShapeAsVector() const { + std::vector vector(this->rankOf()); + for (int e = 0; e < this->rankOf(); e++) vector[e] = this->sizeAt(e); - std::vector vector(this->rankOf()); - for (int e = 0; e < this->rankOf(); e++) - vector[e] = this->sizeAt(e); - - return vector; + return vector; } //////////////////////////////////////////////////////////////////////// std::vector NDArray::getShapeAsVectorInt() const { + std::vector vector(this->rankOf()); + for (int e = 0; e < this->rankOf(); e++) + vector[e] = static_cast(this->sizeAt(e)); - std::vector vector(this->rankOf()); - for (int e = 0; e < this->rankOf(); e++) - vector[e] = static_cast(this->sizeAt(e)); - - return vector; + return vector; } //////////////////////////////////////////////////////////////////////// std::vector NDArray::getShapeInfoAsFlatVector() { - int magicNumber = shape::shapeInfoLength(this->rankOf()); - std::vector vector(magicNumber); + int magicNumber = shape::shapeInfoLength(this->rankOf()); + std::vector vector(magicNumber); - for (int e = 0; e < magicNumber; e++) - vector[e] = static_cast(_shapeInfo[e]); + for (int e = 0; e < magicNumber; e++) + vector[e] = static_cast(_shapeInfo[e]); - return vector; + return vector; } //////////////////////////////////////////////////////////////////////// std::vector NDArray::getShapeInfoAsVector() { - int magicNumber = shape::shapeInfoLength(this->rankOf()); - std::vector vector(magicNumber); - for (int e = 0; e < magicNumber; e++) - vector[e] = this->_shapeInfo[e]; - return vector; + int magicNumber = shape::shapeInfoLength(this->rankOf()); + std::vector vector(magicNumber); + for (int e = 0; e < magicNumber; e++) vector[e] = this->_shapeInfo[e]; + return vector; } //////////////////////////////////////////////////////////////////////// std::vector NDArray::asByteVector() { + if (isS()) { + // string data type requires special treatment + syncToHost(); + auto numWords = this->lengthOf(); + auto offsetsBuffer = this->bufferAsT(); + auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numWords); + auto dataLength = offsetsBuffer[numWords]; + std::vector result(headerLength + dataLength); - if (isS()) { - // string data type requires special treatment - syncToHost(); - auto numWords = this->lengthOf(); - auto offsetsBuffer = this->bufferAsT(); - auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numWords); - auto dataLength = offsetsBuffer[numWords]; - std::vector result(headerLength + dataLength); - - memcpy(result.data(), buffer(), headerLength + dataLength); + memcpy(result.data(), buffer(), headerLength + dataLength); - return result; + return result; + } else { + // all other types are linear + std::vector result((unsigned long long)this->lengthOf() * + sizeOfT()); + + if (this->isView()) { + auto tmp = this->dup(this->ordering()); + syncToHost(); + memcpy(result.data(), tmp.buffer(), + (unsigned long long)lengthOf() * sizeOfT()); } else { - // all other types are linear - std::vector result((unsigned long long) this->lengthOf() * sizeOfT()); - - if (this->isView()) { - auto tmp = this->dup(this->ordering()); - syncToHost(); - memcpy(result.data(), tmp.buffer(), (unsigned long long) lengthOf() * sizeOfT()); - } else { - syncToHost(); - memcpy(result.data(), buffer(), (unsigned long long) lengthOf() * sizeOfT()); - } - return result; + syncToHost(); + memcpy(result.data(), buffer(), + (unsigned long long)lengthOf() * sizeOfT()); } + return result; + } } ////////////////////////////////////////////////////////////////////////// -void NDArray::linspace(const double start) { - linspace(start, 1); -} +void NDArray::linspace(const double start) { linspace(start, 1); } ////////////////////////////////////////////////////////////////////////// void NDArray::linspace(const double start, const double step) { - if (isS()) - throw std::runtime_error("NDArray::linspace: you can't use this method on String array!"); - Nd4jLong numElements = this->lengthOf(); - for (Nd4jLong e = 0; e < numElements; e++) - this->p(e, start + (step * e)); + if (isS()) + throw std::runtime_error( + "NDArray::linspace: you can't use this method on String array!"); + Nd4jLong numElements = this->lengthOf(); + for (Nd4jLong e = 0; e < numElements; e++) this->p(e, start + (step * e)); } //////////////////////////////////////////////////////////////////////// void NDArray::streamline(char o) { - char order = o == 'a' ? this->ordering() : o; - syncToDevice(); - std::shared_ptr newBuffer = std::make_shared(this->lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace()); - auto shapeBuffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(dataType(), order, rankOf(), shapeOf()); - NativeOpExecutioner::execTransformSame(getContext(), transform::Copy, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), newBuffer->primary(), static_cast(shapeBuffer.primary()), newBuffer->special(), static_cast(shapeBuffer.special()), nullptr, nullptr, nullptr); - setShapeInfo(static_cast(shapeBuffer.primary())); - _buffer = newBuffer; - _offset = 0; - tickWriteDevice(); + char order = o == 'a' ? this->ordering() : o; + syncToDevice(); + std::shared_ptr newBuffer = std::make_shared( + this->lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace()); + auto shapeBuffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo( + dataType(), order, rankOf(), shapeOf()); + NativeOpExecutioner::execTransformSame( + getContext(), transform::Copy, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), newBuffer->primary(), + static_cast(shapeBuffer.primary()), newBuffer->special(), + static_cast(shapeBuffer.special()), nullptr, nullptr, nullptr); + setShapeInfo(static_cast(shapeBuffer.primary())); + _buffer = newBuffer; + _offset = 0; + tickWriteDevice(); } //////////////////////////////////////////////////////////////////////// // move assignment operator NDArray& NDArray::operator=(NDArray&& other) noexcept { - if (this == &other) - return *this; - - _isView = other._isView; - _buffer = other._buffer; - _shapeInfo = other._shapeInfo; - _shapeInfoD = other._shapeInfoD; - _context = other._context; - _dataType = other._dataType; - _length = other._length; - _offset = other._offset; - - other._buffer = std::make_shared(); - other._shapeInfo = other._shapeInfoD = nullptr; - other._length = 0; + if (this == &other) return *this; - return *this; + _isView = other._isView; + _buffer = other._buffer; + _shapeInfo = other._shapeInfo; + _shapeInfoD = other._shapeInfoD; + _context = other._context; + _dataType = other._dataType; + _length = other._length; + _offset = other._offset; + + other._buffer = std::make_shared(); + other._shapeInfo = other._shapeInfoD = nullptr; + other._length = 0; + + return *this; } //////////////////////////////////////////////////////////////////////// -template +template NDArray& NDArray::operator=(const T scalar) { - this->assign(scalar); - return *this; + this->assign(scalar); + return *this; } template SD_EXPORT NDArray& NDArray::operator=(const double scalar); template SD_EXPORT NDArray& NDArray::operator=(const float scalar); @@ -1179,1469 +1282,1693 @@ template SD_EXPORT NDArray& NDArray::operator=(const int16_t scalar); template SD_EXPORT NDArray& NDArray::operator=(const bool scalar); ////////////////////////////////////////////////////////////////////////// -void NDArray::copyBuffersContinuouslyFrom(const NDArray& other, size_t sizeToCopyInBytes, Nd4jLong offsetThis, Nd4jLong offsetOther) { - - if(offsetThis == 0) - offsetThis = bufferOffset(); - if(offsetOther == 0) - offsetOther = other.bufferOffset(); +void NDArray::copyBuffersContinuouslyFrom(const NDArray& other, + size_t sizeToCopyInBytes, + Nd4jLong offsetThis, + Nd4jLong offsetOther) { + if (offsetThis == 0) offsetThis = bufferOffset(); + if (offsetOther == 0) offsetOther = other.bufferOffset(); - dataBuffer()->copyBufferFrom(*other.getDataBuffer(), sizeToCopyInBytes, offsetThis, offsetOther); + dataBuffer()->copyBufferFrom(*other.getDataBuffer(), sizeToCopyInBytes, + offsetThis, offsetOther); } //////////////////////////////////////////////////////////////////// // This method assigns values of given NDArray to this one void NDArray::assign(const NDArray& other, bool allowParallelism) { + if (this == &other) return; - if (this == &other) - return; - - if (other.isEmpty()) { - if (!isEmpty()) { - throw std::runtime_error("Cannot assign empty array to non-empty array"); - } - return; - } - - if(isEmpty()) { - *this = other; - return; - } - - if (other.lengthOf() == 1) { - - if(lengthOf() == 1) { - NDArray::preparePrimaryUse({this}, {&other}); - BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.buffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {&other}); - this->syncToDevice(); - } - else { - if (dataType() != other.dataType()) { - auto tmp = other.cast(dataType()); - NDArray::prepareSpecialUse({this}, {&tmp}); - NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {}); - } - else { - NDArray::prepareSpecialUse({this}, {&other}); - NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&other}); - } - } - } - else { - if (other.lengthOf() != lengthOf()) { - auto shapeThis = ShapeUtils::shapeAsString(this); - auto shapeThat = ShapeUtils::shapeAsString(&other); - nd4j_printf("Can't assign array: this shape %s; other shape: %s\n", shapeThis.c_str(), shapeThat.c_str()); - throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched"); - } - + if (other.isEmpty()) { + if (!isEmpty()) { + throw std::runtime_error("Cannot assign empty array to non-empty array"); + } + return; + } + + if (isEmpty()) { + *this = other; + return; + } + + if (other.lengthOf() == 1) { + if (lengthOf() == 1) { + NDArray::preparePrimaryUse({this}, {&other}); + BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, + (buffer(), 0, other.buffer(), 0), LIBND4J_TYPES, + LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {&other}); + this->syncToDevice(); + } else { + if (dataType() != other.dataType()) { + auto tmp = other.cast(dataType()); + NDArray::prepareSpecialUse({this}, {&tmp}); + NativeOpExecutioner::execScalar( + getContext(), scalar::CopyPws, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), + tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr, + allowParallelism); + NDArray::registerSpecialUse({this}, {}); + } else { NDArray::prepareSpecialUse({this}, {&other}); - NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr, allowParallelism); + NativeOpExecutioner::execScalar( + getContext(), scalar::CopyPws, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), + other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), + nullptr, allowParallelism); NDArray::registerSpecialUse({this}, {&other}); + } } + } else { + if (other.lengthOf() != lengthOf()) { + auto shapeThis = ShapeUtils::shapeAsString(this); + auto shapeThat = ShapeUtils::shapeAsString(&other); + nd4j_printf("Can't assign array: this shape %s; other shape: %s\n", + shapeThis.c_str(), shapeThat.c_str()); + throw std::runtime_error( + "NDArray::assign: lengths of arrays are mismatched"); + } + + NDArray::prepareSpecialUse({this}, {&other}); + NativeOpExecutioner::execTransformAny( + getContext(), transform::Assign, other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr, + allowParallelism); + NDArray::registerSpecialUse({this}, {&other}); + } } ////////////////////////////////////////////////////////////////////////// // This method assigns values of given NDArray to this one, wrt order -void NDArray::assign(const NDArray *other, bool allowParallelism) { - assign(*other, allowParallelism); +void NDArray::assign(const NDArray* other, bool allowParallelism) { + assign(*other, allowParallelism); } ////////////////////////////////////////////////////////////////////////// template void NDArray::assign(const T& value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(dataType(), value, this->getContext()); - - NDArray::prepareSpecialUse({this}, {&temp}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.specialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&temp}); -} -template SD_EXPORT void NDArray::assign(const double& value, bool allowParallelism); -template SD_EXPORT void NDArray::assign(const float& value, bool allowParallelism); -template SD_EXPORT void NDArray::assign(const float16& value, bool allowParallelism); -template SD_EXPORT void NDArray::assign(const bfloat16& value, bool allowParallelism); -template SD_EXPORT void NDArray::assign(const Nd4jLong& value, bool allowParallelism); -template SD_EXPORT void NDArray::assign(const int& value, bool allowParallelism); -template SD_EXPORT void NDArray::assign(const int8_t& value, bool allowParallelism); -template SD_EXPORT void NDArray::assign(const int16_t& value, bool allowParallelism); -template SD_EXPORT void NDArray::assign(const uint8_t& value, bool allowParallelism); -template SD_EXPORT void NDArray::assign(const uint16_t& value, bool allowParallelism); -template SD_EXPORT void NDArray::assign(const uint32_t& value, bool allowParallelism); -template SD_EXPORT void NDArray::assign(const uint64_t& value, bool allowParallelism); -template SD_EXPORT void NDArray::assign(const bool& value, bool allowParallelism); + // just fire scalar + auto temp = NDArrayFactory::create(dataType(), value, this->getContext()); + + NDArray::prepareSpecialUse({this}, {&temp}); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), + temp.specialShapeInfo(), nullptr, allowParallelism); + NDArray::registerSpecialUse({this}, {&temp}); +} +template SD_EXPORT void NDArray::assign(const double& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const float& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const float16& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const bfloat16& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const Nd4jLong& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const int& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const int8_t& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const int16_t& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const uint8_t& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const uint16_t& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const uint32_t& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const uint64_t& value, + bool allowParallelism); +template SD_EXPORT void NDArray::assign(const bool& value, + bool allowParallelism); ////////////////////////////////////////////////////////////////////////// NDArray NDArray::detach() { + if (!isAttached()) return *this; - if (!isAttached()) - return *this; - - std::shared_ptr newBuffer = std::make_shared(lengthOf() * sizeOfT(), dataType()); + std::shared_ptr newBuffer = + std::make_shared(lengthOf() * sizeOfT(), dataType()); - NDArray result(newBuffer, ShapeDescriptor(dataType(), ordering(), shapeOf(), rankOf())); + NDArray result(newBuffer, + ShapeDescriptor(dataType(), ordering(), shapeOf(), rankOf())); - result.assign(*this); + result.assign(*this); - return result; + return result; } ////////////////////////////////////////////////////////////////////////// NDArray NDArray::varianceNumber(sd::variance::Ops op, bool biasCorrected) { + NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext()); - NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext()); - - NDArray::prepareSpecialUse({&res}, {this}); - NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo(), biasCorrected); - NDArray::registerSpecialUse({&res}, {this}); + NDArray::prepareSpecialUse({&res}, {this}); + NativeOpExecutioner::execSummaryStatsScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), + res.specialBuffer(), res.specialShapeInfo(), biasCorrected); + NDArray::registerSpecialUse({&res}, {this}); - return res; + return res; } ////////////////////////////////////////////////////////////////////////// // This method returns sum of all elements of this NDArray NDArray NDArray::sumNumber() const { - if (isS()) - throw std::runtime_error("NDArray::sumNumber: you can't use this method on String array!"); - NDArray res(dataType(), getContext()); + if (isS()) + throw std::runtime_error( + "NDArray::sumNumber: you can't use this method on String array!"); + NDArray res(dataType(), getContext()); - NDArray::prepareSpecialUse({&res}, {this}); - NativeOpExecutioner::execReduceSameScalar(getContext(), sd::reduce::SameOps::Sum, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); - NDArray::registerSpecialUse({&res}, {this}); + NDArray::prepareSpecialUse({&res}, {this}); + NativeOpExecutioner::execReduceSameScalar( + getContext(), sd::reduce::SameOps::Sum, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), + res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); + NDArray::registerSpecialUse({&res}, {this}); - return res; + return res; } ////////////////////////////////////////////////////////////////////////// // This method returns mean number of this NDArray NDArray NDArray::meanNumber() const { + if (isS()) + throw std::runtime_error( + "NDArray::meanNumber: you can't use this method on String array!"); + NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext()); - if (isS()) - throw std::runtime_error("NDArray::meanNumber: you can't use this method on String array!"); - NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext()); - - NDArray::prepareSpecialUse({&res}, {this}); - NativeOpExecutioner::execReduceFloatScalar(getContext(), sd::reduce::FloatOps::Mean, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); - NDArray::registerSpecialUse({&res}, {this}); - return res; + NDArray::prepareSpecialUse({&res}, {this}); + NativeOpExecutioner::execReduceFloatScalar( + getContext(), sd::reduce::FloatOps::Mean, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), + res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); + NDArray::registerSpecialUse({&res}, {this}); + return res; } ////////////////////////////////////////////////////////////////////////// bool NDArray::hasNaNs() { - if (isS()) - throw std::runtime_error("NDArray::hasNaNs: you can't use this method on String array!"); - return this->reduceNumber(sd::reduce::IsNan, nullptr).e(0) > 0; + if (isS()) + throw std::runtime_error( + "NDArray::hasNaNs: you can't use this method on String array!"); + return this->reduceNumber(sd::reduce::IsNan, nullptr).e(0) > 0; } ////////////////////////////////////////////////////////////////////////// bool NDArray::hasInfs() { - if (isS()) - throw std::runtime_error("NDArray::hasInfs: you can't use this method on String array!"); - return this->reduceNumber(sd::reduce::IsInf, nullptr).e(0) > 0; + if (isS()) + throw std::runtime_error( + "NDArray::hasInfs: you can't use this method on String array!"); + return this->reduceNumber(sd::reduce::IsInf, nullptr).e(0) > 0; } ////////////////////////////////////////////////////////////////////////// bool NDArray::isFinite() { - if (isS()) - throw std::runtime_error("NDArray::isFinite: you can't use this method on String array!"); - return this->reduceNumber(sd::reduce::IsInfOrNan, nullptr).e(0) == 0; + if (isS()) + throw std::runtime_error( + "NDArray::isFinite: you can't use this method on String array!"); + return this->reduceNumber(sd::reduce::IsInfOrNan, nullptr).e(0) == 0; } ////////////////////////////////////////////////////////////////////////// template -void NDArray::templatedSet(void *buffer, const Nd4jLong *indices, const void *value) { - auto t = reinterpret_cast(buffer); - const auto y = *(reinterpret_cast(value)); +void NDArray::templatedSet(void* buffer, const Nd4jLong* indices, + const void* value) { + auto t = reinterpret_cast(buffer); + const auto y = *(reinterpret_cast(value)); - auto xOffset = shape::getOffset(shapeInfo(), indices); - t[xOffset] = static_cast(y); + auto xOffset = shape::getOffset(shapeInfo(), indices); + t[xOffset] = static_cast(y); } -BUILD_DOUBLE_TEMPLATE(template SD_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong *indices, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template SD_EXPORT void NDArray::templatedSet, + (void* buffer, const Nd4jLong* indices, + const void* value), + LIBND4J_TYPES, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// template -void NDArray::templatedSet(void *buffer, const Nd4jLong offset, const void *value) { - auto t = reinterpret_cast(buffer); - const auto y = *(reinterpret_cast(value)); +void NDArray::templatedSet(void* buffer, const Nd4jLong offset, + const void* value) { + auto t = reinterpret_cast(buffer); + const auto y = *(reinterpret_cast(value)); - t[offset] = static_cast(y); + t[offset] = static_cast(y); } -BUILD_DOUBLE_TEMPLATE(template SD_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong offset, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template SD_EXPORT void NDArray::templatedSet, + (void* buffer, const Nd4jLong offset, const void* value), + LIBND4J_TYPES, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -void NDArray::setContext(sd::LaunchContext *context) { - - _context = context; - if (getContext() == nullptr) - _context = sd::LaunchContext ::defaultContext(); // empty context for default cases +void NDArray::setContext(sd::LaunchContext* context) { + _context = context; + if (getContext() == nullptr) + _context = sd::LaunchContext ::defaultContext(); // empty context for + // default cases } ////////////////////////////////////////////////////////////////////////// void const* NDArray::bufferWithOffset(Nd4jLong offset) const { - return const_cast(buffer() != nullptr ? static_cast(buffer()) + (offset * sizeOfT()) : nullptr); + return const_cast(buffer() != nullptr + ? static_cast(buffer()) + + (offset * sizeOfT()) + : nullptr); } ////////////////////////////////////////////////////////////////////////// void* NDArray::bufferWithOffset(Nd4jLong offset) { - return const_cast(buffer() != nullptr ? static_cast(buffer()) + (offset * sizeOfT()) : nullptr); + return const_cast(buffer() != nullptr + ? static_cast(buffer()) + + (offset * sizeOfT()) + : nullptr); } ////////////////////////////////////////////////////////////////////////// -// eventually method reduces array by excluding its shapes along axes present in dimensions vector -NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { +// eventually method reduces array by excluding its shapes along axes present in +// dimensions vector +NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes) const { + std::vector copy(dimensions); - std::vector copy(dimensions); + auto newShape = ShapeUtils::evalReduceShapeInfo( + 'c', copy, *this, + isR() ? dataType() : Environment::getInstance()->defaultFloatDataType(), + keepDims, supportOldShapes, getContext()->getWorkspace()); - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, isR() ? dataType() : Environment::getInstance()->defaultFloatDataType(), keepDims, supportOldShapes, getContext()->getWorkspace()); + NDArray result(newShape, true, getContext()); - NDArray result(newShape, true, getContext()); + this->reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, + false); - this->reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); - - return result; + return result; } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { - - std::vector copy(dimensions); +NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes) const { + std::vector copy(dimensions); - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); + auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, keepDims, + supportOldShapes, + getContext()->getWorkspace()); - NDArray result(newShape, true, getContext()); + NDArray result(newShape, true, getContext()); - reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); + reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); - return result; + return result; } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { - - std::vector copy(dimensions); +NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes) const { + std::vector copy(dimensions); - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::BOOL, keepDims, supportOldShapes, getContext()->getWorkspace()); + auto newShape = ShapeUtils::evalReduceShapeInfo( + 'c', copy, *this, DataType::BOOL, keepDims, supportOldShapes, + getContext()->getWorkspace()); - NDArray result(newShape, true, getContext()); + NDArray result(newShape, true, getContext()); - reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); + reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); - return result; + return result; } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes) const { +NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes) const { + std::vector copy(dimensions); - std::vector copy(dimensions); - - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, keepDims, supportOldShapes, getContext()->getWorkspace()); + auto newShape = ShapeUtils::evalReduceShapeInfo( + 'c', copy, *this, DataType::INT64, keepDims, supportOldShapes, + getContext()->getWorkspace()); - NDArray result(newShape, true, getContext()); + NDArray result(newShape, true, getContext()); - reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); + reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false); - return result; + return result; } ////////////////////////////////////////////////////////////////////////// -// method reduces array by excluding its shapes along axes present in dimensions vector -NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { - return reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); +// method reduces array by excluding its shapes along axes present in dimensions +// vector +NDArray NDArray::reduceAlongDimension( + sd::reduce::FloatOps op, const std::initializer_list& dimensions, + const bool keepDims, const bool supportOldShapes) const { + return reduceAlongDimension(op, std::vector(dimensions), keepDims, + supportOldShapes); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { - return reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); +NDArray NDArray::reduceAlongDimension( + sd::reduce::SameOps op, const std::initializer_list& dimensions, + const bool keepDims, const bool supportOldShapes) const { + return reduceAlongDimension(op, std::vector(dimensions), keepDims, + supportOldShapes); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { - return reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); +NDArray NDArray::reduceAlongDimension( + sd::reduce::BoolOps op, const std::initializer_list& dimensions, + const bool keepDims, const bool supportOldShapes) const { + return reduceAlongDimension(op, std::vector(dimensions), keepDims, + supportOldShapes); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::initializer_list& dimensions, const bool keepDims, const bool supportOldShapes) const { - return reduceAlongDimension(op, std::vector(dimensions), keepDims, supportOldShapes); +NDArray NDArray::reduceAlongDimension( + sd::reduce::LongOps op, const std::initializer_list& dimensions, + const bool keepDims, const bool supportOldShapes) const { + return reduceAlongDimension(op, std::vector(dimensions), keepDims, + supportOldShapes); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceNumber(sd::reduce::FloatOps op, void *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::reduceNumber FloatOps: you can't use this method on String array!"); +NDArray NDArray::reduceNumber(sd::reduce::FloatOps op, + void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber FloatOps: you can't use this method on String " + "array!"); - auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataTypeUtils::pickFloatingType(dataType())); - NDArray result(shape, true, this->getContext()); + auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo( + DataTypeUtils::pickFloatingType(dataType())); + NDArray result(shape, true, this->getContext()); - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execReduceFloatScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - NDArray::registerSpecialUse({&result}, {this}); + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execReduceFloatScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this}); - return result; + return result; } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceNumber(sd::reduce::SameOps op, void *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::reduceNumber SameOps: you can't use this method on String array!"); +NDArray NDArray::reduceNumber(sd::reduce::SameOps op, void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber SameOps: you can't use this method on String " + "array!"); - NDArray result(dataType(), getContext()); + NDArray result(dataType(), getContext()); - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - NDArray::registerSpecialUse({&result}, {this}); + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execReduceSameScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this}); - return result; + return result; } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceNumber(sd::reduce::BoolOps op, void *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::reduceNumber BoolOps: you can't use this method on String array!"); +NDArray NDArray::reduceNumber(sd::reduce::BoolOps op, void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber BoolOps: you can't use this method on String " + "array!"); - auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL); - NDArray result(shape, true, this->getContext()); + auto shape = + ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL); + NDArray result(shape, true, this->getContext()); - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - NDArray::registerSpecialUse({&result}, {this}); + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execReduceBoolScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this}); - return result; + return result; } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceNumber(sd::reduce::LongOps op, void *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::reduceNumber LongOps: you can't use this method on String array!"); +NDArray NDArray::reduceNumber(sd::reduce::LongOps op, void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber LongOps: you can't use this method on String " + "array!"); - auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64); - NDArray result(shape, true, this->getContext()); + auto shape = + ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64); + NDArray result(shape, true, this->getContext()); - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - NDArray::registerSpecialUse({&result}, {this}); + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execReduceLongScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this}); - return result; + return result; } ////////////////////////////////////////////////////////////////////////// -void NDArray::reduceNumber(sd::reduce::FloatOps op, NDArray& target, void *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::reduceNumber FloatOps: you can't use this method on String array!"); - if(target.lengthOf() != 1 || target.dataType() != DataTypeUtils::pickFloatingType(dataType())) - throw std::invalid_argument("NDArray::reduceNumber FloatOps: target array should be scalar and have corresponding float type!"); +void NDArray::reduceNumber(sd::reduce::FloatOps op, NDArray& target, + void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber FloatOps: you can't use this method on String " + "array!"); + if (target.lengthOf() != 1 || + target.dataType() != DataTypeUtils::pickFloatingType(dataType())) + throw std::invalid_argument( + "NDArray::reduceNumber FloatOps: target array should be scalar and " + "have corresponding float type!"); - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execReduceFloatScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - NDArray::registerSpecialUse({&target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execReduceFloatScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + NDArray::registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::reduceNumber(sd::reduce::SameOps op, NDArray& target, void *extraParams) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceNumber SameOps: you can't use this method on String array!"); - if(target.lengthOf() != 1 || target.dataType() != dataType()) - throw std::invalid_argument("NDArray::reduceNumber SameOps: target array should be scalar and have same type as this array!"); +void NDArray::reduceNumber(sd::reduce::SameOps op, NDArray& target, + void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber SameOps: you can't use this method on String " + "array!"); + if (target.lengthOf() != 1 || target.dataType() != dataType()) + throw std::invalid_argument( + "NDArray::reduceNumber SameOps: target array should be scalar and have " + "same type as this array!"); - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - NDArray::registerSpecialUse({&target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execReduceSameScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + NDArray::registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::reduceNumber(sd::reduce::BoolOps op, NDArray& target, void *extraParams) const { +void NDArray::reduceNumber(sd::reduce::BoolOps op, NDArray& target, + void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber BoolOps: you can't use this method on String " + "array!"); + if (target.lengthOf() != 1 || target.dataType() != DataType::BOOL) + throw std::invalid_argument( + "NDArray::reduceNumber BoolOps: target array should be scalar and have " + "bool type!"); - if (isS()) - throw std::runtime_error("NDArray::reduceNumber BoolOps: you can't use this method on String array!"); - if(target.lengthOf() != 1 || target.dataType() != DataType::BOOL) - throw std::invalid_argument("NDArray::reduceNumber BoolOps: target array should be scalar and have bool type!"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - NDArray::registerSpecialUse({&target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execReduceBoolScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + NDArray::registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// -void NDArray::reduceNumber(sd::reduce::LongOps op, NDArray& target, void *extraParams) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceNumber LongOps: you can't use this method on String array!"); - if(target.lengthOf() != 1 || target.dataType() != DataType::INT64) - throw std::invalid_argument("NDArray::reduceNumber LongOps: target array should be scalar and have long type!"); +void NDArray::reduceNumber(sd::reduce::LongOps op, NDArray& target, + void* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceNumber LongOps: you can't use this method on String " + "array!"); + if (target.lengthOf() != 1 || target.dataType() != DataType::INT64) + throw std::invalid_argument( + "NDArray::reduceNumber LongOps: target array should be scalar and have " + "long type!"); - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - NDArray::registerSpecialUse({&target}, {this}); + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execReduceLongScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + NDArray::registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::indexReduceNumber(sd::indexreduce::Ops op, ExtraArguments *extraParams) { - if (isS()) - throw std::runtime_error("NDArray::indexReduceNumber: you can't use this method on String array!"); +NDArray NDArray::indexReduceNumber(sd::indexreduce::Ops op, + ExtraArguments* extraParams) { + if (isS()) + throw std::runtime_error( + "NDArray::indexReduceNumber: you can't use this method on String " + "array!"); - auto res = NDArrayFactory::create(0); + auto res = NDArrayFactory::create(0); - NDArray::NDArray::prepareSpecialUse({&res}, {this}); - NativeOpExecutioner::execIndexReduceScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams == nullptr ? nullptr : extraParams->argumentsAsT(this->dataType()), res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); - NDArray::NDArray::registerSpecialUse({&res}, {this}); + NDArray::NDArray::prepareSpecialUse({&res}, {this}); + NativeOpExecutioner::execIndexReduceScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), + extraParams == nullptr ? nullptr + : extraParams->argumentsAsT(this->dataType()), + res.buffer(), res.shapeInfo(), res.specialBuffer(), + res.specialShapeInfo()); + NDArray::NDArray::registerSpecialUse({&res}, {this}); - return res; + return res; } ////////////////////////////////////////////////////////////////////////// -Nd4jLong NDArray::tensorsAlongDimension(std::initializer_list dimensions) const { - return tensorsAlongDimension(std::vector(dimensions)); +Nd4jLong NDArray::tensorsAlongDimension( + std::initializer_list dimensions) const { + return tensorsAlongDimension(std::vector(dimensions)); } ////////////////////////////////////////////////////////////////////////// -Nd4jLong NDArray::tensorsAlongDimension(const std::vector& dimensions) const { - std::vector copy(dimensions); - shape::checkDimensions(rankOf(), copy); +Nd4jLong NDArray::tensorsAlongDimension( + const std::vector& dimensions) const { + std::vector copy(dimensions); + shape::checkDimensions(rankOf(), copy); - Nd4jLong tadLength = shape::tadLength(this->_shapeInfo, copy.data(), copy.size()); - Nd4jLong numTads = this->lengthOf() / tadLength; + Nd4jLong tadLength = + shape::tadLength(this->_shapeInfo, copy.data(), copy.size()); + Nd4jLong numTads = this->lengthOf() / tadLength; - return numTads; + return numTads; } ////////////////////////////////////////////////////////////////////////// -void NDArray::printShapeInfo(const char * msg) const { - - int rank = shape::rank(_shapeInfo); - int lim = shape::shapeInfoLength(rank); +void NDArray::printShapeInfo(const char* msg) const { + int rank = shape::rank(_shapeInfo); + int lim = shape::shapeInfoLength(rank); - if(msg != nullptr) - printf("shapeInfo %s: [", msg); - else - printf("shapeInfo: ["); + if (msg != nullptr) + printf("shapeInfo %s: [", msg); + else + printf("shapeInfo: ["); - printf("%i, ", rank); - for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){ - if(i == rank + 1) - printf(" "); - printf("%lld,", _shapeInfo[i]); - } - printf(" %lld,", shape::type(_shapeInfo)); - printf("%lld,", shape::elementWiseStride(_shapeInfo)); - printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo)); + printf("%i, ", rank); + for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++) { + if (i == rank + 1) printf(" "); + printf("%lld,", _shapeInfo[i]); + } + printf(" %lld,", shape::type(_shapeInfo)); + printf("%lld,", shape::elementWiseStride(_shapeInfo)); + printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo)); - fflush(stdout); + fflush(stdout); } ////////////////////////////////////////////////////////////////////////// -void NDArray::printBuffer(const char* msg, Nd4jLong limit, const bool sync) const{ - if (sync) - syncToHost(); +void NDArray::printBuffer(const char* msg, Nd4jLong limit, + const bool sync) const { + if (sync) syncToHost(); - if (limit == -1) - limit = (int) this->lengthOf(); + if (limit == -1) limit = (int)this->lengthOf(); - if (msg != nullptr) - printf("%s: [", msg); - else - printf("["); - if (this->isR()) { - for (Nd4jLong e = 0; e < limit; e++) { - if (e) - printf(", "); - printf("%f", this->e(e)); - } - } - else if (this->isZ()) { - for (Nd4jLong e = 0; e < limit; e++) { - if (this->dataType() != sd::DataType::INT64 && this->dataType() != sd::DataType::UINT64) - printf("%d", this->e(e)); - else - printf("%llu", this->e(e)); - if (e < limit - 1) - printf(", "); - } + if (msg != nullptr) + printf("%s: [", msg); + else + printf("["); + if (this->isR()) { + for (Nd4jLong e = 0; e < limit; e++) { + if (e) printf(", "); + printf("%f", this->e(e)); } - else if (this->isB()) { - for (Nd4jLong e = 0; e < limit; e++) { - if (this->e(e)) - printf("true"); - else - printf("false"); - if (e < limit - 1) - printf(", "); - } + } else if (this->isZ()) { + for (Nd4jLong e = 0; e < limit; e++) { + if (this->dataType() != sd::DataType::INT64 && + this->dataType() != sd::DataType::UINT64) + printf("%d", this->e(e)); + else + printf("%llu", this->e(e)); + if (e < limit - 1) printf(", "); + } + } else if (this->isB()) { + for (Nd4jLong e = 0; e < limit; e++) { + if (this->e(e)) + printf("true"); + else + printf("false"); + if (e < limit - 1) printf(", "); + } + } else if (this->isS()) { + // todo do we need this print offsets + /* + for (Nd4jLong e = 0; e < limit; e++) { + printf("\"%lld\"", this->getOffset(e)); + if (e < limit - 1) + printf(", "); } - else if (this->isS()) { - // todo do we need this print offsets - /* - for (Nd4jLong e = 0; e < limit; e++) { - printf("\"%lld\"", this->getOffset(e)); - if (e < limit - 1) - printf(", "); - } - printf("]\n["); - */ - for (Nd4jLong e = 0; e < limit; e++) { - printf("\"%s\"", this->e(e).c_str()); - if (e < limit - 1) - printf(", "); - } + printf("]\n["); + */ + for (Nd4jLong e = 0; e < limit; e++) { + printf("\"%s\"", this->e(e).c_str()); + if (e < limit - 1) printf(", "); } - printf("]\n"); - fflush(stdout); + } + printf("]\n"); + fflush(stdout); } ////////////////////////////////////////////////////////////////////////// -// print element by element consequently in a way they (elements) are stored in physical memory +// print element by element consequently in a way they (elements) are stored in +// physical memory void NDArray::printLinearBuffer() const { + syncToHost(); - syncToHost(); - - const auto ews = this->ews() > 0 ? this->ews() : 1; - const auto len = this->lengthOf(); + const auto ews = this->ews() > 0 ? this->ews() : 1; + const auto len = this->lengthOf(); - printf("["); + printf("["); - if (this->dataType() == sd::DataType::INT32) { - for(Nd4jLong e = 0; e < len; e++) - printf("%d, ", this->bufferAsT()[e * ews]); - } - else if(this->dataType() == sd::DataType::INT64) { - for(Nd4jLong e = 0; e < len; e++) - printf("%lld, ", this->bufferAsT()[e * ews]); - } - else if(this->dataType() == sd::DataType::FLOAT32) { - for(Nd4jLong e = 0; e < len; e++) - printf("%.3f, ", this->bufferAsT()[e * ews]); - } - else if(this->dataType() == sd::DataType::DOUBLE) { - for(Nd4jLong e = 0; e < len; e++) - printf("%.3f, ", this->bufferAsT()[e * ews]); - } - else - throw std::invalid_argument("NDArray::printLinearBuffer: not implemented yet for this data type !"); + if (this->dataType() == sd::DataType::INT32) { + for (Nd4jLong e = 0; e < len; e++) + printf("%d, ", this->bufferAsT()[e * ews]); + } else if (this->dataType() == sd::DataType::INT64) { + for (Nd4jLong e = 0; e < len; e++) + printf("%lld, ", this->bufferAsT()[e * ews]); + } else if (this->dataType() == sd::DataType::FLOAT32) { + for (Nd4jLong e = 0; e < len; e++) + printf("%.3f, ", this->bufferAsT()[e * ews]); + } else if (this->dataType() == sd::DataType::DOUBLE) { + for (Nd4jLong e = 0; e < len; e++) + printf("%.3f, ", this->bufferAsT()[e * ews]); + } else + throw std::invalid_argument( + "NDArray::printLinearBuffer: not implemented yet for this data type !"); - printf("]\n"); - fflush(stdout); + printf("]\n"); + fflush(stdout); } ////////////////////////////////////////////////////////////////////////// static void printFormatted(NDArray const* arr, int depth, int limit) { - - if (arr->rankOf() == 1) { - printf("[ "); - for (Nd4jLong i = 0; i < arr->lengthOf(); ++i) { - if (arr->isR()) - printf("%f, ", arr->e(i)); - else if (arr->isZ()) - printf("%lld, ", arr->e(i)); - else if (arr->isB()) - printf("%s, ", arr->e(i)?"true":"false"); - else if (arr->isS()) { - printf("\"%s\", ", arr->e(i).c_str()); - } - } - printf("]\n"); + if (arr->rankOf() == 1) { + printf("[ "); + for (Nd4jLong i = 0; i < arr->lengthOf(); ++i) { + if (arr->isR()) + printf("%f, ", arr->e(i)); + else if (arr->isZ()) + printf("%lld, ", arr->e(i)); + else if (arr->isB()) + printf("%s, ", arr->e(i) ? "true" : "false"); + else if (arr->isS()) { + printf("\"%s\", ", arr->e(i).c_str()); + } } - else if (arr->rankOf() == 2) { - Nd4jLong rows = arr->rows(); - Nd4jLong cols = arr->columns(); - char* padding = new char[depth + 1]; - memset(padding, ' ', depth); - padding[depth] = 0; - printf("["); - for (Nd4jLong row = 0; row < rows; ++row) { - if (row && depth > 0) - printf("%s", padding); - printf("["); - Nd4jLong colLimit = cols > limit?cols:limit; - for (Nd4jLong col = 0; col < colLimit; ++col) { - if (col) - printf(", "); - if (arr->isR()) - printf("%f", arr->e(row, col)); - else if (arr->isZ()) - printf("%lld", arr->e(row, col)); - else if (arr->isB()) - printf("%s", arr->e(row, col)?"true":"false"); - else if (arr->isS()) { - printf("\"%s\"", arr->e(row * cols + col).c_str()); - } - } - if (row < rows - 1) - printf("]\n"); - else - printf("]"); + printf("]\n"); + } else if (arr->rankOf() == 2) { + Nd4jLong rows = arr->rows(); + Nd4jLong cols = arr->columns(); + char* padding = new char[depth + 1]; + memset(padding, ' ', depth); + padding[depth] = 0; + printf("["); + for (Nd4jLong row = 0; row < rows; ++row) { + if (row && depth > 0) printf("%s", padding); + printf("["); + Nd4jLong colLimit = cols > limit ? cols : limit; + for (Nd4jLong col = 0; col < colLimit; ++col) { + if (col) printf(", "); + if (arr->isR()) + printf("%f", arr->e(row, col)); + else if (arr->isZ()) + printf("%lld", arr->e(row, col)); + else if (arr->isB()) + printf("%s", arr->e(row, col) ? "true" : "false"); + else if (arr->isS()) { + printf("\"%s\"", arr->e(row * cols + col).c_str()); } + } + if (row < rows - 1) + printf("]\n"); + else printf("]"); - if (padding) - delete [] padding; } - else { - //std::unique_ptr arrs(arr->allTensorsAlongDimension({0})); - size_t restCount = 2; - printf("["); - restCount = ShapeUtils::getNumOfSubArrs(arr->shapeInfo(), {0}); - for (size_t arrIndex = 0; arrIndex < restCount; ++arrIndex) { - NDArray subArr = (*arr)(arrIndex, {0}); - printFormatted(&subArr, depth + 1, limit); - if (arrIndex < restCount - 1) { - for (Nd4jLong i = 1; i < arr->rankOf(); ++i) - printf("\n"); - for (Nd4jLong i = 0; i < depth - 2; ++i) - printf(" "); - } - } - printf("]"); + printf("]"); + if (padding) delete[] padding; + } else { + // std::unique_ptr arrs(arr->allTensorsAlongDimension({0})); + size_t restCount = 2; + printf("["); + restCount = ShapeUtils::getNumOfSubArrs(arr->shapeInfo(), {0}); + for (size_t arrIndex = 0; arrIndex < restCount; ++arrIndex) { + NDArray subArr = (*arr)(arrIndex, {0}); + printFormatted(&subArr, depth + 1, limit); + if (arrIndex < restCount - 1) { + for (Nd4jLong i = 1; i < arr->rankOf(); ++i) printf("\n"); + for (Nd4jLong i = 0; i < depth - 2; ++i) printf(" "); + } } + printf("]"); + } } ////////////////////////////////////////////////////////////////////////// void NDArray::printIndexedBuffer(const char* msg, Nd4jLong limit) const { + syncToHost(); - syncToHost(); - - Nd4jLong rank = this->rankOf(); + Nd4jLong rank = this->rankOf(); - bool rowFlag = (rank < 2) || (rank == 2 && this->sizeAt(0) == 1); + bool rowFlag = (rank < 2) || (rank == 2 && this->sizeAt(0) == 1); - if (msg) - printf("%s: ", msg); + if (msg) printf("%s: ", msg); - if (this->isEmpty()) { - printf("Empty\n"); - } - else if (this->rankOf() == 0) { - if (this->isZ()) - printf("%lld\n", this->e(0)); - else if (this->isR()) - printf("%f\n", this->e(0)); - else if (this->isB()) { - printf("%s\n", this->e(0)?"true":"false"); - } - else if (this->isS()) { - // todo do we need this - // printf("\"%lld\"\n", this->getOffset(e)); - printf("\"%s\"\n", this->e(0).c_str()); - } - } - else if (rowFlag && ews()==1) - printBuffer(nullptr, limit); - else { - if (msg) - printf("\n"); - printFormatted(this, 1, limit); - printf("\n"); + if (this->isEmpty()) { + printf("Empty\n"); + } else if (this->rankOf() == 0) { + if (this->isZ()) + printf("%lld\n", this->e(0)); + else if (this->isR()) + printf("%f\n", this->e(0)); + else if (this->isB()) { + printf("%s\n", this->e(0) ? "true" : "false"); + } else if (this->isS()) { + // todo do we need this + // printf("\"%lld\"\n", this->getOffset(e)); + printf("\"%s\"\n", this->e(0).c_str()); } - fflush(stdout); + } else if (rowFlag && ews() == 1) + printBuffer(nullptr, limit); + else { + if (msg) printf("\n"); + printFormatted(this, 1, limit); + printf("\n"); + } + fflush(stdout); } ////////////////////////////////////////////////////////////////////////// template void* NDArray::templatedPointerShift(const Nd4jLong offset) const { - return const_cast(reinterpret_cast(buffer()) + offset); + return const_cast(reinterpret_cast(buffer()) + offset); } -BUILD_SINGLE_TEMPLATE(template SD_EXPORT void* NDArray::templatedPointerShift, (const Nd4jLong offset) const, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void* NDArray::templatedPointerShift, + (const Nd4jLong offset) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -// method makes copy of this array and applies to the copy transpose operation, this array remains unaffected -NDArray NDArray::transpose() const &{ - NDArray newArr(getDataBuffer(), ShapeDescriptor(shapeInfo()), getContext(), bufferOffset()); - newArr.transposei(); +// method makes copy of this array and applies to the copy transpose operation, +// this array remains unaffected +NDArray NDArray::transpose() const& { + NDArray newArr(getDataBuffer(), ShapeDescriptor(shapeInfo()), getContext(), + bufferOffset()); + newArr.transposei(); - return newArr; + return newArr; } ////////////////////////////////////////////////////////////////////////// -// method makes copy of this array and applies to the copy transpose operation, this array remains unaffected +// method makes copy of this array and applies to the copy transpose operation, +// this array remains unaffected NDArray NDArray::transpose() && { - - this->transposei(); - return std::move(*this); + this->transposei(); + return std::move(*this); } //////////////////////////////////////////////////////////////////////// -// method performs transpose operation based on this array and store result in target, this array remains unaffected +// method performs transpose operation based on this array and store result in +// target, this array remains unaffected void NDArray::transpose(NDArray& target) const { + auto correctShape = + ShapeUtils::evalTranspShapeInfo(*this, getContext()->getWorkspace()); + if (!shape::equalsStrict(correctShape, target.shapeInfo())) + throw std::runtime_error( + "NDArray::transpose method: the shapeInfo of target array is wrong !"); - auto correctShape = ShapeUtils::evalTranspShapeInfo(*this, getContext()->getWorkspace()); - if(!shape::equalsStrict(correctShape, target.shapeInfo())) - throw std::runtime_error("NDArray::transpose method: the shapeInfo of target array is wrong !"); - - target._buffer = _buffer; - target._offset = _offset; - target._isView = true; + target._buffer = _buffer; + target._offset = _offset; + target._isView = true; } //////////////////////////////////////////////////////////////////////// -// This method applies in-place transpose to this array, so this array becomes transposed +// This method applies in-place transpose to this array, so this array becomes +// transposed void NDArray::transposei() { - std::vector perm; - for (int e = this->rankOf() - 1; e >= 0; e--) - perm.emplace_back(e); + std::vector perm; + for (int e = this->rankOf() - 1; e >= 0; e--) perm.emplace_back(e); - this->permutei(perm); + this->permutei(perm); } //////////////////////////////////////////////////////////////////////// -bool NDArray::equalsTo(const NDArray &other, double eps) const { - return equalsTo(&other, eps); +bool NDArray::equalsTo(const NDArray& other, double eps) const { + return equalsTo(&other, eps); } ////////////////////////////////////////////////////////////////////////// void NDArray::setAttached(bool reallyAttached) { - _isAttached = reallyAttached; + _isAttached = reallyAttached; }; ////////////////////////////////////////////////////////////////////////// // calculate strides void NDArray::updateStrides(const char order) { - shape::updateStrides(_shapeInfo, order); - syncShape(); + shape::updateStrides(_shapeInfo, order); + syncShape(); } ////////////////////////////////////////////////////////////////////////// // set new order and shape in case of suitable array length -bool NDArray::reshapei(const char order, const std::initializer_list& shape, const bool copyToNewBuff) { - std::vector vShape(shape); - return reshapei(order, vShape, copyToNewBuff); +bool NDArray::reshapei(const char order, + const std::initializer_list& shape, + const bool copyToNewBuff) { + std::vector vShape(shape); + return reshapei(order, vShape, copyToNewBuff); } ////////////////////////////////////////////////////////////////////////// -bool NDArray::reshapei(const std::initializer_list& shape, const bool copyToNewBuff) { - return reshapei(ordering(), shape, copyToNewBuff); +bool NDArray::reshapei(const std::initializer_list& shape, + const bool copyToNewBuff) { + return reshapei(ordering(), shape, copyToNewBuff); } ////////////////////////////////////////////////////////////////////////// -bool NDArray::reshapei(const std::vector& shape, const bool copyToNewBuff) { - return reshapei(ordering(), shape, copyToNewBuff); +bool NDArray::reshapei(const std::vector& shape, + const bool copyToNewBuff) { + return reshapei(ordering(), shape, copyToNewBuff); } ////////////////////////////////////////////////////////////////////////// -void NDArray::enforce(const std::initializer_list &dimensions, char order) { - std::vector dims(dimensions); - enforce(dims, order); +void NDArray::enforce(const std::initializer_list& dimensions, + char order) { + std::vector dims(dimensions); + enforce(dims, order); } ////////////////////////////////////////////////////////////////////////// -void NDArray::enforce(std::vector &dimensions, char o) { - - Nd4jLong prod = 1; - for (int e = 0; e < dimensions.size(); e++) - prod *= dimensions[e]; +void NDArray::enforce(std::vector& dimensions, char o) { + Nd4jLong prod = 1; + for (int e = 0; e < dimensions.size(); e++) prod *= dimensions[e]; - if (prod != this->lengthOf()) { - std::string current = ShapeUtils::shapeAsString(this); - std::string enforced = ShapeUtils::shapeAsString(dimensions); - nd4j_printf("Can't enforce new shape, lengths mismatch. Original shape: %s; Requested shape: %s\n", current.c_str(), enforced.c_str()); - throw std::runtime_error("Incompatible shape"); - } + if (prod != this->lengthOf()) { + std::string current = ShapeUtils::shapeAsString(this); + std::string enforced = ShapeUtils::shapeAsString(dimensions); + nd4j_printf( + "Can't enforce new shape, lengths mismatch. Original shape: %s; " + "Requested shape: %s\n", + current.c_str(), enforced.c_str()); + throw std::runtime_error("Incompatible shape"); + } - char order = o == 'a' ? this->ordering() : o; - setShapeInfo(ShapeDescriptor(dataType(), order, dimensions)); + char order = o == 'a' ? this->ordering() : o; + setShapeInfo(ShapeDescriptor(dataType(), order, dimensions)); } ////////////////////////////////////////////////////////////////////////// Nd4jLong NDArray::argMax(std::initializer_list dimensions) { + if (isS()) + throw std::runtime_error( + "NDArray::argMax: you can't use this method on String array!"); - if (isS()) - throw std::runtime_error("NDArray::argMax: you can't use this method on String array!"); - - if (dimensions.size() == 0) { - Nd4jLong max = 0; - auto mv = -DataTypeUtils::max(); - for (Nd4jLong e = 0; e < this->lengthOf(); e++) { - auto val = this->e(e); - if (mv < val) { - mv = val; - max = e; - } - } - return max; + if (dimensions.size() == 0) { + Nd4jLong max = 0; + auto mv = -DataTypeUtils::max(); + for (Nd4jLong e = 0; e < this->lengthOf(); e++) { + auto val = this->e(e); + if (mv < val) { + mv = val; + max = e; + } } - else - throw std::runtime_error("NDArray::argMax() - Not implemented yet"); + return max; + } else + throw std::runtime_error("NDArray::argMax() - Not implemented yet"); } ////////////////////////////////////////////////////////////////////////// -// create new array with corresponding order and shape, new array will point to the same _buffer as this array -NDArray NDArray::reshape(const char order, const std::vector& shape, const bool copyToNewBuff) const & { - - NDArray newArr(getDataBuffer(), ShapeDescriptor(shapeInfo()), getContext(), bufferOffset()); - newArr.reshapei(order, shape, copyToNewBuff); +// create new array with corresponding order and shape, new array will point to +// the same _buffer as this array +NDArray NDArray::reshape(const char order, const std::vector& shape, + const bool copyToNewBuff) const& { + NDArray newArr(getDataBuffer(), ShapeDescriptor(shapeInfo()), getContext(), + bufferOffset()); + newArr.reshapei(order, shape, copyToNewBuff); - return newArr; + return newArr; } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reshape(const char order, const std::vector& shape, const bool copyToNewBuff) && { - - this->reshapei(order, shape, copyToNewBuff); - return std::move(*this); +NDArray NDArray::reshape(const char order, const std::vector& shape, + const bool copyToNewBuff) && { + this->reshapei(order, shape, copyToNewBuff); + return std::move(*this); } ////////////////////////////////////////////////////////////////////////// // change an array by repeating it the number of times given by reps. void NDArray::tilei(const std::vector& reps) { - *this = this->tile(reps); + *this = this->tile(reps); } ////////////////////////////////////////////////////////////////////////// Nd4jLong NDArray::sizeAt(const int dim) const { + if (dim >= this->rankOf() || dim < -this->rankOf()) + throw std::runtime_error("NDArray::sizeAt: bad size index requested"); - if (dim >= this->rankOf() || dim < -this->rankOf()) - throw std::runtime_error("NDArray::sizeAt: bad size index requested"); - - if (dim >= 0) - return shape::shapeOf(_shapeInfo)[dim]; - else - return shape::shapeOf(_shapeInfo)[this->rankOf() + dim]; + if (dim >= 0) + return shape::shapeOf(_shapeInfo)[dim]; + else + return shape::shapeOf(_shapeInfo)[this->rankOf() + dim]; } ////////////////////////////////////////////////////////////////////////// Nd4jLong NDArray::strideAt(const int dim) const { + if (dim >= this->rankOf() || dim < -this->rankOf()) + throw std::runtime_error("NDArray::strideAt: Bad size index requested"); - if (dim >= this->rankOf() || dim < -this->rankOf()) - throw std::runtime_error("NDArray::strideAt: Bad size index requested"); - - if (dim >= 0) - return shape::stride(_shapeInfo)[dim]; - else - return shape::stride(_shapeInfo)[this->rankOf() + dim]; + if (dim >= 0) + return shape::stride(_shapeInfo)[dim]; + else + return shape::stride(_shapeInfo)[this->rankOf() + dim]; } ////////////////////////////////////////////////////////////////////////// bool NDArray::permutei(const std::initializer_list& dimensions) { - std::vector vec(dimensions); - return permutei(vec); + std::vector vec(dimensions); + return permutei(vec); } ////////////////////////////////////////////////////////////////////////// bool NDArray::permutei(const std::vector& dimensions) { - return permutei(dimensions.data(), rankOf()); + return permutei(dimensions.data(), rankOf()); } ////////////////////////////////////////////////////////////////////////// bool NDArray::permutei(const std::initializer_list& dimensions) { - std::vector vec(dimensions); - std::vector ivec(dimensions.size()); + std::vector vec(dimensions); + std::vector ivec(dimensions.size()); - for (int e = 0; e < vec.size(); e++) - ivec[e] = static_cast(vec[e]); + for (int e = 0; e < vec.size(); e++) ivec[e] = static_cast(vec[e]); - return permutei(ivec); + return permutei(ivec); } ////////////////////////////////////////////////////////////////////////// bool NDArray::permutei(const std::vector& dimensions) { + std::vector ivec(dimensions.size()); - std::vector ivec(dimensions.size()); + for (int e = 0; e < dimensions.size(); e++) ivec[e] = dimensions[e]; - for (int e = 0; e < dimensions.size(); e++) - ivec[e] = dimensions[e]; - - return permutei(ivec.data(), rankOf()); + return permutei(ivec.data(), rankOf()); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const int* dimensions, const int rank) const & { - - // evaluate shapeInfo for output (permuted) array ret - auto shapeInfoPermuted = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace()); - NDArray ret(getDataBuffer(), ShapeDescriptor(shapeInfoPermuted), getContext(), bufferOffset()); - ret._isView = true; - return ret; +NDArray NDArray::permute(const int* dimensions, const int rank) const& { + // evaluate shapeInfo for output (permuted) array ret + auto shapeInfoPermuted = ShapeUtils::evalPermShapeInfo( + dimensions, rank, *this, getContext()->getWorkspace()); + NDArray ret(getDataBuffer(), ShapeDescriptor(shapeInfoPermuted), getContext(), + bufferOffset()); + ret._isView = true; + return ret; } ////////////////////////////////////////////////////////////////////////// NDArray NDArray::permute(const int* dimensions, const int rank) && { - - this->permutei(dimensions, rank); - return std::move(*this); + this->permutei(dimensions, rank); + return std::move(*this); } ///////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) const &{ - int tempDims[MAX_RANK]; - shape::convertT(const_cast(dimensions), tempDims, rank); - return permute(tempDims, rank); +NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) const& { + int tempDims[MAX_RANK]; + shape::convertT(const_cast(dimensions), tempDims, + rank); + return permute(tempDims, rank); } ///////////////////////////////////////////////////////////////////////// NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) && { - - this->permutei(dimensions, rank); - return std::move(*this); + this->permutei(dimensions, rank); + return std::move(*this); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::vector& dimensions) const &{ - - return permute(dimensions.data(), rankOf()); +NDArray NDArray::permute(const std::vector& dimensions) const& { + return permute(dimensions.data(), rankOf()); } ////////////////////////////////////////////////////////////////////////// NDArray NDArray::permute(const std::vector& dimensions) && { - - this->permutei(dimensions); - return std::move(*this); + this->permutei(dimensions); + return std::move(*this); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::vector& dimensions) const & { - - return permute(dimensions.data(), rankOf()); +NDArray NDArray::permute(const std::vector& dimensions) const& { + return permute(dimensions.data(), rankOf()); } ////////////////////////////////////////////////////////////////////////// NDArray NDArray::permute(const std::vector& dimensions) && { - - this->permutei(dimensions); - return std::move(*this); + this->permutei(dimensions); + return std::move(*this); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::initializer_list& dimensions) const &{ - - std::vector vec(dimensions); - return permute(vec); +NDArray NDArray::permute(const std::initializer_list& dimensions) const& { + std::vector vec(dimensions); + return permute(vec); } ////////////////////////////////////////////////////////////////////////// NDArray NDArray::permute(const std::initializer_list& dimensions) && { - - this->permutei(dimensions); - return std::move(*this); + this->permutei(dimensions); + return std::move(*this); } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::initializer_list& dimensions) const & { - std::vector vec(dimensions); - return permute(vec); +NDArray NDArray::permute( + const std::initializer_list& dimensions) const& { + std::vector vec(dimensions); + return permute(vec); } ////////////////////////////////////////////////////////////////////////// NDArray NDArray::permute(const std::initializer_list& dimensions) && { - - this->permutei(dimensions); - return std::move(*this); + this->permutei(dimensions); + return std::move(*this); } ////////////////////////////////////////////////////////////////////////// -void NDArray::permute(const int* dimensions, const int rank, NDArray& target) const { - if (!nonNull() || !target.nonNull() || rank != rankOf() || rank != target.rankOf() ) - throw std::runtime_error("NDArray::permute method: either arrays are nullptr or ranks are not suitable!"); +void NDArray::permute(const int* dimensions, const int rank, + NDArray& target) const { + if (!nonNull() || !target.nonNull() || rank != rankOf() || + rank != target.rankOf()) + throw std::runtime_error( + "NDArray::permute method: either arrays are nullptr or ranks are " + "not suitable!"); - auto shapeInfoNew = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, target.getContext()->getWorkspace()); + auto shapeInfoNew = ShapeUtils::evalPermShapeInfo( + dimensions, rank, *this, target.getContext()->getWorkspace()); - target.setShapeInfo(shapeInfoNew); - target._buffer = _buffer; - target._offset = _offset; + target.setShapeInfo(shapeInfoNew); + target._buffer = _buffer; + target._offset = _offset; } ////////////////////////////////////////////////////////////////////////// -void NDArray::permute(const Nd4jLong *dimensions, const int rank, NDArray& target) const { - if (!nonNull() || !target.nonNull() || rank != rankOf() || rank != target.rankOf() ) - throw std::runtime_error("NDArray::permute method: either arrays are nullptr or ranks are not suitable!"); +void NDArray::permute(const Nd4jLong* dimensions, const int rank, + NDArray& target) const { + if (!nonNull() || !target.nonNull() || rank != rankOf() || + rank != target.rankOf()) + throw std::runtime_error( + "NDArray::permute method: either arrays are nullptr or ranks are " + "not suitable!"); - auto shapeInfoNew = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, target.getContext()->getWorkspace()); + auto shapeInfoNew = ShapeUtils::evalPermShapeInfo( + dimensions, rank, *this, target.getContext()->getWorkspace()); - target.setShapeInfo(shapeInfoNew); - target._buffer = _buffer; - target._offset = _offset; + target.setShapeInfo(shapeInfoNew); + target._buffer = _buffer; + target._offset = _offset; } ////////////////////////////////////////////////////////////////////////// -void NDArray::permute(const std::vector& dimensions, NDArray& target) const { - permute(dimensions.data(), rankOf(), target); +void NDArray::permute(const std::vector& dimensions, + NDArray& target) const { + permute(dimensions.data(), rankOf(), target); } ////////////////////////////////////////////////////////////////////////// -void NDArray::permute(const std::vector& dimensions, NDArray& target) const { - permute(dimensions.data(), rankOf(), target); +void NDArray::permute(const std::vector& dimensions, + NDArray& target) const { + permute(dimensions.data(), rankOf(), target); } ////////////////////////////////////////////////////////////////////////// // check whether array is identity matrix bool NDArray::isIdentityMatrix() { - if (isS()) - throw std::runtime_error("NDArray::isIdentityMatrix: you can't use this method on String array!"); - if(rankOf() !=2 || rows() != columns()) - throw std::runtime_error("isIdentityMatrix method: matrix must be square and have rank = 2 !"); - - const double eps = 1e-5f; - for(Nd4jLong i=0; i(i,i) - 1.f) > eps) - return false; - - for(Nd4jLong i=0; i(i,j)) > eps) - return false; - } + if (isS()) + throw std::runtime_error( + "NDArray::isIdentityMatrix: you can't use this method on String " + "array!"); + if (rankOf() != 2 || rows() != columns()) + throw std::runtime_error( + "isIdentityMatrix method: matrix must be square and have rank = 2 !"); + + const double eps = 1e-5f; + for (Nd4jLong i = 0; i < rows(); ++i) + if (sd::math::nd4j_abs(e(i, i) - 1.f) > eps) return false; + + for (Nd4jLong i = 0; i < rows(); ++i) { + for (Nd4jLong j = 0; j < columns(); ++j) { + if (i == j) continue; + if (sd::math::nd4j_abs(e(i, j)) > eps) return false; } - return true; + } + return true; } ////////////////////////////////////////////////////////////////////////// // check whether array is unitary matrix bool NDArray::isUnitary() { - if (isS()) - throw std::runtime_error("NDArray::isUnitary: you can't use this method on String array!"); - if(rankOf() != 2 || rows() != columns()) - throw std::runtime_error("isUnitary method: matrix must be square and have rank = 2 !"); + if (isS()) + throw std::runtime_error( + "NDArray::isUnitary: you can't use this method on String array!"); + if (rankOf() != 2 || rows() != columns()) + throw std::runtime_error( + "isUnitary method: matrix must be square and have rank = 2 !"); - auto tr = this->transpose(); - auto trMul = MmulHelper::mmul(this, &tr, nullptr, 1.f, 0.f); + auto tr = this->transpose(); + auto trMul = MmulHelper::mmul(this, &tr, nullptr, 1.f, 0.f); - bool result = trMul->isIdentityMatrix(); - delete trMul; + bool result = trMul->isIdentityMatrix(); + delete trMul; - return result; + return result; } ////////////////////////////////////////////////////////////////////////// template <> const std::string* SD_EXPORT NDArray::bufferAsT() const { - throw std::runtime_error("This method is NOT supposed to be used"); + throw std::runtime_error("This method is NOT supposed to be used"); } ////////////////////////////////////////////////////////////////////////// template const T* NDArray::bufferAsT() const { - // FIXME: do we REALLY want sync here? - syncToHost(); + // FIXME: do we REALLY want sync here? + syncToHost(); - return reinterpret_cast(buffer()); + return reinterpret_cast(buffer()); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT const, * NDArray::bufferAsT() const, LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT const, + *NDArray::bufferAsT() const, LIBND4J_TYPES); template T* NDArray::bufferAsT() { - syncToHost(); - return reinterpret_cast(buffer()); + syncToHost(); + return reinterpret_cast(buffer()); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT, * NDArray::bufferAsT(), LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT, *NDArray::bufferAsT(), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// NDArray NDArray::subarray(IndicesList& idx) const { - - const int idxSize = idx.size(); - if (idxSize != this->rankOf()) - throw std::runtime_error("NDArray::subarray: number of indices should match"); - - std::vector indexes(3 * idxSize); - - // convert IndicesList to vector - for (int d = 0; d < idxSize; ++d) { - - if (idx.at(d)->isAll()) { - indexes[3 * d] = 0; // first - indexes[3 * d + 1] = 0; // last - indexes[3 * d + 2] = 1; // stride - } - else if (idx.at(d)->isPoint()) { - indexes[3 * d] = idx.at(d)->getIndices().at(0); // first - indexes[3 * d + 1] = indexes[3 * d] + 1; // last - indexes[3 * d + 2] = 1; // stride - } - else if (idx.at(d)->isInterval()) { - indexes[3 * d] = idx.at(d)->getIndices().at(0); // first - indexes[3 * d + 1] = idx.at(d)->getIndices().size();// last - indexes[3 * d + 2] = idx.at(d)->stride(); // stride - } - else { - indexes[3 * d] = idx.at(d)->getIndices().at(0); // first - indexes[3 * d + 1] = idx.at(d)->getIndices().at(1); // last - indexes[3 * d + 2] = idx.at(d)->getIndices().at(2); // stride - } + const int idxSize = idx.size(); + if (idxSize != this->rankOf()) + throw std::runtime_error( + "NDArray::subarray: number of indices should match"); + + std::vector indexes(3 * idxSize); + + // convert IndicesList to vector + for (int d = 0; d < idxSize; ++d) { + if (idx.at(d)->isAll()) { + indexes[3 * d] = 0; // first + indexes[3 * d + 1] = 0; // last + indexes[3 * d + 2] = 1; // stride + } else if (idx.at(d)->isPoint()) { + indexes[3 * d] = idx.at(d)->getIndices().at(0); // first + indexes[3 * d + 1] = indexes[3 * d] + 1; // last + indexes[3 * d + 2] = 1; // stride + } else if (idx.at(d)->isInterval()) { + indexes[3 * d] = idx.at(d)->getIndices().at(0); // first + indexes[3 * d + 1] = idx.at(d)->getIndices().size(); // last + indexes[3 * d + 2] = idx.at(d)->stride(); // stride + } else { + indexes[3 * d] = idx.at(d)->getIndices().at(0); // first + indexes[3 * d + 1] = idx.at(d)->getIndices().at(1); // last + indexes[3 * d + 2] = idx.at(d)->getIndices().at(2); // stride } - return NDArray((*this)(indexes, true, true)); + } + return NDArray((*this)(indexes, true, true)); } //////////////////////////////////////////////////////////////////////// NDArray NDArray::subarray(const std::initializer_list& idx) const { - - const int idxSize = idx.size(); - if (idxSize != this->rankOf()) - throw std::runtime_error("NDArray::subarray: number of indices should match the array rank"); - - std::vector indexes(3 * idxSize); - - // convert NDIndex to vector - int d = 0; - for (const auto& item : idx) { - - if (item->isAll()) { - indexes[3 * d] = 0; // first - indexes[3 * d + 1] = 0; // last - indexes[3 * d + 2] = 1; // stride - } - else if (item->isPoint()) { - indexes[3 * d] = item->getIndices().at(0); // first - indexes[3 * d + 1] = indexes[3 * d] + 1; // last - indexes[3 * d + 2] = 1; // stride - } - else if (item->isInterval()) { - indexes[3 * d] = item->getIndices().at(0); // first - indexes[3 * d + 1] = item->getIndices().size(); // last - indexes[3 * d + 2] = item->stride(); // stride - } - else { - indexes[3 * d] = item->getIndices().at(0); // first - indexes[3 * d + 1] = item->getIndices().at(1); // last - indexes[3 * d + 2] = item->getIndices().at(2); // stride - } - ++d; + const int idxSize = idx.size(); + if (idxSize != this->rankOf()) + throw std::runtime_error( + "NDArray::subarray: number of indices should match the array rank"); + + std::vector indexes(3 * idxSize); + + // convert NDIndex to vector + int d = 0; + for (const auto& item : idx) { + if (item->isAll()) { + indexes[3 * d] = 0; // first + indexes[3 * d + 1] = 0; // last + indexes[3 * d + 2] = 1; // stride + } else if (item->isPoint()) { + indexes[3 * d] = item->getIndices().at(0); // first + indexes[3 * d + 1] = indexes[3 * d] + 1; // last + indexes[3 * d + 2] = 1; // stride + } else if (item->isInterval()) { + indexes[3 * d] = item->getIndices().at(0); // first + indexes[3 * d + 1] = item->getIndices().size(); // last + indexes[3 * d + 2] = item->stride(); // stride + } else { + indexes[3 * d] = item->getIndices().at(0); // first + indexes[3 * d + 1] = item->getIndices().at(1); // last + indexes[3 * d + 2] = item->getIndices().at(2); // stride } + ++d; + } - // release NDIndices - for (auto i: idx) - delete i; + // release NDIndices + for (auto i : idx) delete i; - return NDArray((*this)(indexes, true, true)); + return NDArray((*this)(indexes, true, true)); } //////////////////////////////////////////////////////////////////////// NDArray NDArray::subarray(const Intervals& idx) const { - - const int idxSize = idx.size(); - if (idxSize != this->rankOf()) - throw std::runtime_error("NDArray::subarray: number of indices should match the rank of array!"); - - std::vector indexes(2 * idxSize); - - // convert Intervals to vector - for (int d = 0; d < idxSize; ++d) { - - if (idx[d].empty()) { - indexes[2 * d] = 0; // first - indexes[2 * d + 1] = 0; // last - } - else { - indexes[2 * d] = idx[d][0]; // first - indexes[2 * d + 1] = idx[d][1]; // last - } + const int idxSize = idx.size(); + if (idxSize != this->rankOf()) + throw std::runtime_error( + "NDArray::subarray: number of indices should match the rank of array!"); + + std::vector indexes(2 * idxSize); + + // convert Intervals to vector + for (int d = 0; d < idxSize; ++d) { + if (idx[d].empty()) { + indexes[2 * d] = 0; // first + indexes[2 * d + 1] = 0; // last + } else { + indexes[2 * d] = idx[d][0]; // first + indexes[2 * d + 1] = idx[d][1]; // last } + } - return NDArray((*this)(indexes, true)); + return NDArray((*this)(indexes, true)); } ////////////////////////////////////////////////////////////////////////// template -NDArray NDArray::asT() const{ +NDArray NDArray::asT() const { + auto result = isScalar() + ? NDArray('c', {}, std::vector{0.}, + DataTypeUtils::fromT(), this->getContext()) + : NDArray(ordering(), getShapeAsVector(), + DataTypeUtils::fromT(), this->getContext()); - auto result = isScalar() ? NDArray('c', {}, std::vector{0.}, DataTypeUtils::fromT(), this->getContext()) : NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT(), this->getContext()); + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformAny( + getContext(), transform::AnyOps::Assign, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo(), nullptr, nullptr, + nullptr); + NDArray::registerSpecialUse({&result}, {this}); - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); - - return result; + return result; } -BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArray::asT, () const, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArray::asT, () const, + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// template NDArray NDArray::asS() const { + if (!isS()) + throw std::runtime_error( + "NDArray::asS: you can use this method only for String array!"); - if (!isS()) - throw std::runtime_error("NDArray::asS: you can use this method only for String array!"); - - auto dtype = DataTypeUtils::fromT(); - - if (!(DataTypeUtils::isS(dtype))) - throw std::invalid_argument("NDArray::asS: invalid DataType used"); + auto dtype = DataTypeUtils::fromT(); - if (dtype == dataType()) { + if (!(DataTypeUtils::isS(dtype))) + throw std::invalid_argument("NDArray::asS: invalid DataType used"); - Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); - const auto nInputoffsets = bufferAsT(); - std::shared_ptr pBuffer = std::make_shared(offsetsLength + nInputoffsets[lengthOf()], dtype, getContext()->getWorkspace(), true); + if (dtype == dataType()) { + Nd4jLong offsetsLength = + ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + const auto nInputoffsets = bufferAsT(); + std::shared_ptr pBuffer = + std::make_shared(offsetsLength + nInputoffsets[lengthOf()], + dtype, getContext()->getWorkspace(), true); - NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext()); - res.setAttached(getContext()->getWorkspace() != nullptr); + NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), + getContext()); + res.setAttached(getContext()->getWorkspace() != nullptr); - preparePrimaryUse({ &res }, { this }); - memcpy(res.bufferAsT(), nInputoffsets, offsetsLength); - auto data = res.bufferAsT() + offsetsLength; - const auto inData = bufferAsT() + offsetsLength; - memcpy(data, inData, nInputoffsets[lengthOf()]); + preparePrimaryUse({&res}, {this}); + memcpy(res.bufferAsT(), nInputoffsets, offsetsLength); + auto data = res.bufferAsT() + offsetsLength; + const auto inData = bufferAsT() + offsetsLength; + memcpy(data, inData, nInputoffsets[lengthOf()]); - registerPrimaryUse({ &res }, { this }); - return res; - } + registerPrimaryUse({&res}, {this}); + return res; + } - Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + Nd4jLong offsetsLength = + ShapeUtils::stringBufferHeaderRequirements(lengthOf()); - std::vector offsets(lengthOf() + 1); + std::vector offsets(lengthOf() + 1); - const auto nInputoffsets = bufferAsT(); + const auto nInputoffsets = bufferAsT(); - Nd4jLong start = 0, stop = 0; - Nd4jLong dataLength = 0; + Nd4jLong start = 0, stop = 0; + Nd4jLong dataLength = 0; - auto data = bufferAsT() + offsetsLength; - for (Nd4jLong e = 0; e < lengthOf(); e++) { - offsets[e] = dataLength; - start = nInputoffsets[e]; - stop = nInputoffsets[e + 1]; - if (dataType() == DataType::UTF8) { - dataLength += (dtype == DataType::UTF16) ? unicode::offsetUtf8StringInUtf16(data + start, stop) + auto data = bufferAsT() + offsetsLength; + for (Nd4jLong e = 0; e < lengthOf(); e++) { + offsets[e] = dataLength; + start = nInputoffsets[e]; + stop = nInputoffsets[e + 1]; + if (dataType() == DataType::UTF8) { + dataLength += (dtype == DataType::UTF16) + ? unicode::offsetUtf8StringInUtf16(data + start, stop) : unicode::offsetUtf8StringInUtf32(data + start, stop); - } - else if (dataType() == DataType::UTF16) { - dataLength += (dtype == DataType::UTF32) ? unicode::offsetUtf16StringInUtf32(data + start, (stop / sizeof(char16_t)) ) - : unicode::offsetUtf16StringInUtf8(data + start, (stop / sizeof(char16_t))); - } - else { - dataLength += (dtype == DataType::UTF16) ? unicode::offsetUtf32StringInUtf16(data + start, (stop / sizeof(char32_t))) - : unicode::offsetUtf32StringInUtf8(data + start, (stop / sizeof(char32_t))); - } + } else if (dataType() == DataType::UTF16) { + dataLength += (dtype == DataType::UTF32) + ? unicode::offsetUtf16StringInUtf32( + data + start, (stop / sizeof(char16_t))) + : unicode::offsetUtf16StringInUtf8( + data + start, (stop / sizeof(char16_t))); + } else { + dataLength += (dtype == DataType::UTF16) + ? unicode::offsetUtf32StringInUtf16( + data + start, (stop / sizeof(char32_t))) + : unicode::offsetUtf32StringInUtf8( + data + start, (stop / sizeof(char32_t))); } - offsets[lengthOf()] = dataLength; + } + offsets[lengthOf()] = dataLength; - std::shared_ptr pBuffer = std::make_shared(offsetsLength + dataLength, dtype, getContext()->getWorkspace(), true); + std::shared_ptr pBuffer = std::make_shared( + offsetsLength + dataLength, dtype, getContext()->getWorkspace(), true); - NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext()); - res.setAttached(getContext()->getWorkspace() != nullptr); + NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), + getContext()); + res.setAttached(getContext()->getWorkspace() != nullptr); - preparePrimaryUse({ &res }, { this }); + preparePrimaryUse({&res}, {this}); - memcpy(res.bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + memcpy(res.bufferAsT(), offsets.data(), + offsets.size() * sizeof(Nd4jLong)); - auto outData = res.bufferAsT() + offsetsLength; - const auto inData = bufferAsT() + offsetsLength; + auto outData = res.bufferAsT() + offsetsLength; + const auto inData = bufferAsT() + offsetsLength; - auto func = PRAGMA_THREADS_FOR{ - for (int e = start; e < stop; e++) { - auto cdata = outData + offsets[e]; - auto end = nInputoffsets[e + 1]; - auto idata = inData + nInputoffsets[e]; - if (dtype == DataType::UTF16) { - if (dataType() == DataType::UTF8) { - unicode::utf8to16(idata, outData, end); - } - else { - unicode::utf32to16(idata, outData, (end / sizeof(char32_t))); - } - } - else if (dtype == DataType::UTF32) { - if (dataType() == DataType::UTF8) { - unicode::utf8to32(idata, cdata, end); - } - else { - unicode::utf16to32(idata, outData, (end / sizeof(char16_t))); - } - } - else { - if (dataType() == DataType::UTF16) { - unicode::utf16to8(idata, outData, (end / sizeof(char16_t))); - } - else { - unicode::utf32to8(idata, outData, (end / sizeof(char32_t))); - } - } + auto func = PRAGMA_THREADS_FOR { + for (int e = start; e < stop; e++) { + auto cdata = outData + offsets[e]; + auto end = nInputoffsets[e + 1]; + auto idata = inData + nInputoffsets[e]; + if (dtype == DataType::UTF16) { + if (dataType() == DataType::UTF8) { + unicode::utf8to16(idata, outData, end); + } else { + unicode::utf32to16(idata, outData, (end / sizeof(char32_t))); } - }; + } else if (dtype == DataType::UTF32) { + if (dataType() == DataType::UTF8) { + unicode::utf8to32(idata, cdata, end); + } else { + unicode::utf16to32(idata, outData, (end / sizeof(char16_t))); + } + } else { + if (dataType() == DataType::UTF16) { + unicode::utf16to8(idata, outData, (end / sizeof(char16_t))); + } else { + unicode::utf32to8(idata, outData, (end / sizeof(char32_t))); + } + } + } + }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - registerPrimaryUse({ &res }, { this }); + registerPrimaryUse({&res}, {this}); - return res; + return res; } -BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArray::asS, () const, LIBND4J_STRINGTYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArray::asS, () const, + LIBND4J_STRINGTYPES); //////////////////////////////////////////////////////////////////////// NDArray NDArray::asT(DataType dtype) const { + if (isS() && !DataTypeUtils::isS(dtype)) + throw std::runtime_error( + "NDArray::asT: you can't use this method on String array with not " + "string DataType!"); - if (isS() && !DataTypeUtils::isS(dtype)) - throw std::runtime_error("NDArray::asT: you can't use this method on String array with not string DataType!"); - - if (!isS() && DataTypeUtils::isS(dtype)) - throw std::runtime_error("NDArray::asT: you can't use this method on not String array with string DataType!"); + if (!isS() && DataTypeUtils::isS(dtype)) + throw std::runtime_error( + "NDArray::asT: you can't use this method on not String array with " + "string DataType!"); - if (isS()){ - BUILD_SINGLE_SELECTOR(dtype, return asS, (), LIBND4J_STRINGTYPES); - } else { - BUILD_SINGLE_SELECTOR(dtype, return asT, (), LIBND4J_TYPES); - } + if (isS()) { + BUILD_SINGLE_SELECTOR(dtype, return asS, (), LIBND4J_STRINGTYPES); + } else { + BUILD_SINGLE_SELECTOR(dtype, return asT, (), LIBND4J_TYPES); + } - return NDArray(); + return NDArray(); } //////////////////////////////////////////////////////////////////////// NDArray NDArray::cast(DataType dtype) const { + if (isS() && !DataTypeUtils::isS(dtype)) + throw std::runtime_error( + "NDArray::cast: you can't use this method on String array with not " + "string DataType!"); - if (isS() && !DataTypeUtils::isS(dtype)) - throw std::runtime_error("NDArray::cast: you can't use this method on String array with not string DataType!"); + if (!isS() && DataTypeUtils::isS(dtype)) + throw std::runtime_error( + "NDArray::cast: you can't use this method on not String array with " + "string DataType!"); - if (!isS() && DataTypeUtils::isS(dtype)) - throw std::runtime_error("NDArray::cast: you can't use this method on not String array with string DataType!"); - - return this->asT(dtype); + return this->asT(dtype); } //////////////////////////////////////////////////////////////////////// void NDArray::cast(NDArray& target, DataType dtype) { - if (isS()) - throw std::runtime_error("NDArray::cast: you can't use this method on String array!"); - // TODO: to be implemented properly - target.assign(this); + if (isS()) + throw std::runtime_error( + "NDArray::cast: you can't use this method on String array!"); + // TODO: to be implemented properly + target.assign(this); } //////////////////////////////////////////////////////////////////////// void NDArray::operator+=(const NDArray& other) { - - if (isS()) - throw std::runtime_error("NDArray::operator+=: you can't use this method on String array!"); - if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) - throw sd::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType()); - - if (this->lengthOf() != 1 && other.lengthOf() == 1) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else{ - const Nd4jLong *bShape = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace())) - throw std::invalid_argument("NDArray::operator+=: the shapes of this and other arrays are not suitable for broadcast operation !"); - - if(shape::equalsTypesAndShapesSoft(shapeInfo(), bShape)) { - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), other, *this, false); - } - else { - NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), other, result, false); - *this = std::move(result); // move assignment operator, zero cost copy - } + if (isS()) + throw std::runtime_error( + "NDArray::operator+=: you can't use this method on String array!"); + if (!Environment::getInstance()->isExperimentalBuild() && + this->dataType() != other.dataType() && + (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) + throw sd::datatype_exception::build( + "NDArray operator+=: Cannot add different types", this->dataType(), + other.dataType()); + + if (this->lengthOf() != 1 && other.lengthOf() == 1) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Add, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else if (other.lengthOf() == lengthOf() && + this->rankOf() == other.rankOf()) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execPairwiseTransform( + getContext(), sd::pairwise::Add, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else { + const Nd4jLong* bShape = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, + getContext()->getWorkspace())) + throw std::invalid_argument( + "NDArray::operator+=: the shapes of this and other arrays are not " + "suitable for broadcast operation !"); + + if (shape::equalsTypesAndShapesSoft(shapeInfo(), bShape)) { + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), other, *this, + false); + } else { + NDArray result(bShape, true, getContext()); + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), other, result, + false); + *this = std::move(result); // move assignment operator, zero cost copy } + } } //////////////////////////////////////////////////////////////////////// void NDArray::operator-=(const NDArray& other) { - if (isS()) - throw std::runtime_error("NDArray::operator-=: you can't use this method on String array!"); - - if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) - throw sd::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType()); - - if (lengthOf() != 1 && other.lengthOf() == 1) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else{ - const Nd4jLong *bShape = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace())) - throw std::invalid_argument("NDArray::operator-=: the shapes of this and other arrays are not suitable for broadcast operation !"); - - if(shape::equalsTypesAndShapesSoft(shapeInfo(), bShape)) { - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), other, *this, false); - } - else { - NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), other, result, false); - *this = std::move(result); // move assignment operator, zero cost copy - } + if (isS()) + throw std::runtime_error( + "NDArray::operator-=: you can't use this method on String array!"); + + if (!Environment::getInstance()->isExperimentalBuild() && + this->dataType() != other.dataType() && + (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) + throw sd::datatype_exception::build( + "NDArray operator-=: Cannot subtract different types", this->dataType(), + other.dataType()); + + if (lengthOf() != 1 && other.lengthOf() == 1) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Subtract, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else if (other.lengthOf() == lengthOf() && + this->rankOf() == other.rankOf()) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execPairwiseTransform( + getContext(), sd::pairwise::Subtract, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else { + const Nd4jLong* bShape = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, + getContext()->getWorkspace())) + throw std::invalid_argument( + "NDArray::operator-=: the shapes of this and other arrays are not " + "suitable for broadcast operation !"); + + if (shape::equalsTypesAndShapesSoft(shapeInfo(), bShape)) { + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), other, *this, + false); + } else { + NDArray result(bShape, true, getContext()); + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), other, result, + false); + *this = std::move(result); // move assignment operator, zero cost copy } + } } //////////////////////////////////////////////////////////////////////// void NDArray::operator*=(const NDArray& other) { - if (isS()) - throw std::runtime_error("NDArray::operator*=: you can't use this method on String array!"); - if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) - throw sd::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType()); - - if (lengthOf() != 1 && other.lengthOf() == 1) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else{ - const Nd4jLong *bShape = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace())) - throw std::invalid_argument("NDArray::operator*=: the shapes of this and other arrays are not suitable for broadcast operation !"); - - if(shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) { - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), other, *this, false); - } - else { - NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), other, result, false); - *this = std::move(result); // move assignment operator, zero cost copy - } + if (isS()) + throw std::runtime_error( + "NDArray::operator*=: you can't use this method on String array!"); + if (!Environment::getInstance()->isExperimentalBuild() && + this->dataType() != other.dataType() && + (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) + throw sd::datatype_exception::build( + "NDArray operator*=: Cannot multiply different types", this->dataType(), + other.dataType()); + + if (lengthOf() != 1 && other.lengthOf() == 1) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Multiply, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else if (other.lengthOf() == lengthOf() && + this->rankOf() == other.rankOf()) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execPairwiseTransform( + getContext(), sd::pairwise::Multiply, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else { + const Nd4jLong* bShape = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, + getContext()->getWorkspace())) + throw std::invalid_argument( + "NDArray::operator*=: the shapes of this and other arrays are not " + "suitable for broadcast operation !"); + + if (shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) { + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), other, *this, + false); + } else { + NDArray result(bShape, true, getContext()); + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), other, result, + false); + *this = std::move(result); // move assignment operator, zero cost copy } + } } //////////////////////////////////////////////////////////////////////// void NDArray::operator/=(const NDArray& other) { - if (isS() || other.isS()) - throw std::runtime_error("NDArray::operator/=: you can't use this method on String array!"); - if (other.isB()) - throw std::runtime_error("NDArray::operator/=: you can't divide by bool array!"); - - if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType()) { - throw sd::datatype_exception::build("NDArray operator/=: Cannot divide different types", this->dataType(), other.dataType()); - } - - if (lengthOf() != 1 && other.lengthOf() == 1) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else{ - const Nd4jLong *bShape = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace())) - throw std::invalid_argument("NDArray::operator/=: the shapes of this and other arrays are not suitable for broadcast operation !"); - - if(shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) { - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), other, *this, false); - } - else { - NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), other, result, false); - *this = std::move(result); // move assignment operator, zero cost copy - } + if (isS() || other.isS()) + throw std::runtime_error( + "NDArray::operator/=: you can't use this method on String array!"); + if (other.isB()) + throw std::runtime_error( + "NDArray::operator/=: you can't divide by bool array!"); + + if (!Environment::getInstance()->isExperimentalBuild() && + this->dataType() != other.dataType()) { + throw sd::datatype_exception::build( + "NDArray operator/=: Cannot divide different types", this->dataType(), + other.dataType()); + } + + if (lengthOf() != 1 && other.lengthOf() == 1) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Divide, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else if (other.lengthOf() == lengthOf() && + this->rankOf() == other.rankOf()) { + NDArray::prepareSpecialUse({this}, {this, &other}); + NativeOpExecutioner::execPairwiseTransform( + getContext(), sd::pairwise::Divide, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {this, &other}); + } else { + const Nd4jLong* bShape = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, + getContext()->getWorkspace())) + throw std::invalid_argument( + "NDArray::operator/=: the shapes of this and other arrays are not " + "suitable for broadcast operation !"); + + if (shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) { + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), other, *this, + false); + } else { + NDArray result(bShape, true, getContext()); + this->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), other, result, + false); + *this = std::move(result); // move assignment operator, zero cost copy } + } } //////////////////////////////////////////////////////////////////////// template void NDArray::operator+=(const T value) { - if (isS()) - throw std::runtime_error("NDArray::operator+=: you can't use this method on String array!"); + if (isS()) + throw std::runtime_error( + "NDArray::operator+=: you can't use this method on String array!"); - auto other = NDArrayFactory::create(this->dataType(), value, getContext()); + auto other = NDArrayFactory::create(this->dataType(), value, getContext()); - NDArray::prepareSpecialUse({this}, {&other}); + NDArray::prepareSpecialUse({this}, {&other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Add, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {}); + NDArray::registerSpecialUse({this}, {}); } template SD_EXPORT void NDArray::operator+=(const double value); template SD_EXPORT void NDArray::operator+=(const float value); @@ -2652,18 +2979,23 @@ template SD_EXPORT void NDArray::operator+=(const int value); template SD_EXPORT void NDArray::operator+=(const bool value); //////////////////////////////////////////////////////////////////////// -template +template void NDArray::operator-=(const T value) { - if (isS()) - throw std::runtime_error("NDArray::operator-=: you can't use this method on String array!"); + if (isS()) + throw std::runtime_error( + "NDArray::operator-=: you can't use this method on String array!"); - auto other = NDArrayFactory::create(dataType(), value, getContext()); + auto other = NDArrayFactory::create(dataType(), value, getContext()); - NDArray::prepareSpecialUse({this}, {&other}); + NDArray::prepareSpecialUse({this}, {&other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Subtract, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {}); + NDArray::registerSpecialUse({this}, {}); } template SD_EXPORT void NDArray::operator-=(const double value); template SD_EXPORT void NDArray::operator-=(const float value); @@ -2674,16 +3006,21 @@ template SD_EXPORT void NDArray::operator-=(const int value); template SD_EXPORT void NDArray::operator-=(const bool value); //////////////////////////////////////////////////////////////////////// -template +template void NDArray::operator*=(const T scalar) { - if (isS()) - throw std::runtime_error("NDArray::operator*=: you can't use this method on String array!"); + if (isS()) + throw std::runtime_error( + "NDArray::operator*=: you can't use this method on String array!"); - auto other = NDArrayFactory::create(this->dataType(), scalar, getContext()); - NDArray::prepareSpecialUse({this}, {&other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); + auto other = NDArrayFactory::create(this->dataType(), scalar, getContext()); + NDArray::prepareSpecialUse({this}, {&other}); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Multiply, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {}); + NDArray::registerSpecialUse({this}, {}); } template SD_EXPORT void NDArray::operator*=(const double scalar); template SD_EXPORT void NDArray::operator*=(const float scalar); @@ -2697,15 +3034,20 @@ template SD_EXPORT void NDArray::operator*=(const uint8_t scalar); template SD_EXPORT void NDArray::operator*=(const bool scalar); //////////////////////////////////////////////////////////////////////// -template +template void NDArray::operator/=(const T scalar) { - if (isS()) - throw std::runtime_error("NDArray::operator/=: you can't use this method on String array!"); - - auto other = NDArrayFactory::create(this->dataType(), scalar, getContext()); - NDArray::prepareSpecialUse({this}, {&other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {}); + if (isS()) + throw std::runtime_error( + "NDArray::operator/=: you can't use this method on String array!"); + + auto other = NDArrayFactory::create(this->dataType(), scalar, getContext()); + NDArray::prepareSpecialUse({this}, {&other}); + NativeOpExecutioner::execScalar( + getContext(), sd::scalar::Divide, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({this}, {}); } template SD_EXPORT void NDArray::operator/=(const double scalar); template SD_EXPORT void NDArray::operator/=(const float scalar); @@ -2720,1712 +3062,2385 @@ template SD_EXPORT void NDArray::operator/=(const bool scalar); //////////////////////////////////////////////////////////////////////// // negative operator, it makes all array elements = -elements -NDArray NDArray::operator-() const & { - if (isS()) - throw std::runtime_error("NDArray::negative-: you can't use this method on String array!"); +NDArray NDArray::operator-() const& { + if (isS()) + throw std::runtime_error( + "NDArray::negative-: you can't use this method on String array!"); - NDArray result(shapeInfo(), false, getContext()); + NDArray result(shapeInfo(), false, getContext()); - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), sd::transform::Neg, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformSame( + getContext(), sd::transform::Neg, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo(), nullptr, nullptr, + nullptr); + NDArray::registerSpecialUse({&result}, {this}); - return result; + return result; } //////////////////////////////////////////////////////////////////////// NDArray NDArray::operator-() && { - if (isS()) - throw std::runtime_error("NDArray::negative-: you can't use this method on String array!"); + if (isS()) + throw std::runtime_error( + "NDArray::negative-: you can't use this method on String array!"); - NDArray::prepareSpecialUse({this}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), sd::transform::Neg, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({this}, {this}); + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformSame( + getContext(), sd::transform::Neg, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); - return std::move(*this); + return std::move(*this); } //////////////////////////////////////////////////////////////////////// // mathematical multiplication of two arrays NDArray mmul(const NDArray& left, const NDArray& right) { - if (left.isS() || right.isS()) - throw std::runtime_error("mmul friend function: you can't use this function on String array!"); - auto ptr = MmulHelper::mmul(const_cast(&left), const_cast(&right), nullptr, 1., 0.); - NDArray result(std::move(*ptr)); - delete ptr; - return result; + if (left.isS() || right.isS()) + throw std::runtime_error( + "mmul friend function: you can't use this function on String array!"); + auto ptr = MmulHelper::mmul(const_cast(&left), + const_cast(&right), nullptr, 1., 0.); + NDArray result(std::move(*ptr)); + delete ptr; + return result; } //////////////////////////////////////////////////////////////////////// void NDArray::tileToShape(const std::vector& shape, NDArray& target) { - if(&target != this) { - this->tile(target); - return; - } + if (&target != this) { + this->tile(target); + return; + } - std::vector thisShape(rankOf()); - for(int i = 0; i < rankOf(); ++i) - thisShape[i] = sizeAt(i); + std::vector thisShape(rankOf()); + for (int i = 0; i < rankOf(); ++i) thisShape[i] = sizeAt(i); - if(!ShapeUtils::areShapesBroadcastable(shape, thisShape)) - throw std::runtime_error("NDArray::tileToShape method: the shape of this array and input shape are not suitable for broadcast operation !"); + if (!ShapeUtils::areShapesBroadcastable(shape, thisShape)) + throw std::runtime_error( + "NDArray::tileToShape method: the shape of this array and input shape " + "are not suitable for broadcast operation !"); - const int newRank = shape.size(); - std::vector repeats(newRank); + const int newRank = shape.size(); + std::vector repeats(newRank); - for(int i = 1; i <= newRank; ++i) { - if(i > rankOf()) - repeats[newRank-i] = shape[newRank - i]; - else - repeats[newRank-i] = shape[newRank - i] / thisShape[rankOf() - i]; - } + for (int i = 1; i <= newRank; ++i) { + if (i > rankOf()) + repeats[newRank - i] = shape[newRank - i]; + else + repeats[newRank - i] = shape[newRank - i] / thisShape[rankOf() - i]; + } - tilei(repeats); + tilei(repeats); } //////////////////////////////////////////////////////////////////////// -void NDArray::tileToShape(const std::initializer_list& shape, NDArray& target) { - tileToShape(std::vector(shape), target); +void NDArray::tileToShape(const std::initializer_list& shape, + NDArray& target) { + tileToShape(std::vector(shape), target); } //////////////////////////////////////////////////////////////////////// NDArray NDArray::tileToShape(const Nd4jLong* shapeInfo) { - - NDArray result(const_cast(shapeInfo), false, getContext()); - tile(result); - return result; + NDArray result(const_cast(shapeInfo), false, getContext()); + tile(result); + return result; } //////////////////////////////////////////////////////////////////////// double NDArray::getTrace() const { - if (isS()) - throw std::runtime_error("NDArray::getTrace: you can't use this method on String array!"); + if (isS()) + throw std::runtime_error( + "NDArray::getTrace: you can't use this method on String array!"); - int rank = rankOf(); - auto shape = shapeOf(); - int minDim = 100000000; + int rank = rankOf(); + auto shape = shapeOf(); + int minDim = 100000000; - Nd4jLong indices[MAX_RANK]; - for(int j = 0; j < rank; ++j) - indices[j] = 1; + Nd4jLong indices[MAX_RANK]; + for (int j = 0; j < rank; ++j) indices[j] = 1; - auto offset = shape::getOffset(shapeInfo(), indices); + auto offset = shape::getOffset(shapeInfo(), indices); - for(int i = 0; i < rank; ++i) - if(minDim > shape[i]) - minDim = shape[i]; + for (int i = 0; i < rank; ++i) + if (minDim > shape[i]) minDim = shape[i]; - double sum = 0.; + double sum = 0.; - for(int i = 0; i < minDim; ++i) - sum += e(i * offset); + for (int i = 0; i < minDim; ++i) sum += e(i * offset); - return sum; + return sum; } //////////////////////////////////////////////////////////////////////// NDArray NDArray::quantize(const NDArray& array) { + if (!array.isR()) + throw std::invalid_argument( + "NDArray::quantize: type of array should be from real space!"); - if(!array.isR()) - throw std::invalid_argument("NDArray::quantize: type of array should be from real space!"); + auto ws = array.getContext()->getWorkspace(); - auto ws = array.getContext()->getWorkspace(); + Nd4jLong* shapeInfo = + ShapeBuilders::copyShapeInfo(array.shapeInfo(), true, ws); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_QUANTIZED); - Nd4jLong* shapeInfo = ShapeBuilders::copyShapeInfo(array.shapeInfo(), true, ws); - ArrayOptions::setPropertyBit(shapeInfo, ARRAY_QUANTIZED); + std::shared_ptr buffer = std::make_shared( + TypeCast::estimateQuantizedSize(array.lengthOf()), + ArrayOptions::dataType(shapeInfo), ws); - std::shared_ptr buffer = std::make_shared(TypeCast::estimateQuantizedSize(array.lengthOf()), ArrayOptions::dataType(shapeInfo), ws); + NDArray result(buffer, ShapeDescriptor(shapeInfo), array.getContext()); - NDArray result(buffer, ShapeDescriptor(shapeInfo), array.getContext()); - - return result; + return result; } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { +void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, + NDArray& target, const bool checkTargetShape, + ExtraArguments* extraArgs) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyTrueBroadcast: you can't use this method on String " + "array!"); - if (isS()) - throw std::runtime_error("NDArray::applyTrueBroadcast: you can't use this method on String array!"); + if (((op.s == scalar::Divide || op.s == scalar::FloorDiv || + op.s == scalar::FloorMod) && + other.isB()) || + (op.s == scalar::ReverseDivide && this->isB())) + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: you can't divide by bool array !"); - if(((op.s == scalar::Divide || op.s == scalar::FloorDiv || op.s == scalar::FloorMod) && other.isB()) || (op.s == scalar::ReverseDivide && this->isB())) - throw std::runtime_error("NDArray::applyTrueBroadcast method: you can't divide by bool array !"); + if (isEmpty() || other.isEmpty()) return; - if (isEmpty() || other.isEmpty()) - return; + // if (lengthOf() == 1) { + // target.assign(this); + // target.applyPairwiseTransform(op.p, other, extraArgs); + // return; + // } + // if (other.lengthOf() == 1) { + // const_cast(this)->applyScalarArr(op.s, other, target, + // extraArgs); return; + // } - // if (lengthOf() == 1) { - // target.assign(this); - // target.applyPairwiseTransform(op.p, other, extraArgs); - // return; - // } - // if (other.lengthOf() == 1) { - // const_cast(this)->applyScalarArr(op.s, other, target, extraArgs); - // return; - // } - - if(checkTargetShape) { - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - if(!shape::equalsTypesAndShapesSoft(target.shapeInfo(), newShapeInfo)) - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shape or type of target array is wrong !"); - } + if (checkTargetShape) { + const Nd4jLong* newShapeInfo = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo( + *this, other, true, newShapeInfo, + getContext()->getWorkspace())) // the rank of target array must be + // equal to max->rankOf)() + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shapes of this and other " + "arrays are not suitable for broadcast operation !"); + if (!shape::equalsTypesAndShapesSoft(target.shapeInfo(), newShapeInfo)) + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shape or type of target " + "array is wrong !"); + } + + Nd4jLong const* xShapeInfoH = shapeInfo(); + Nd4jLong const* yShapeInfoH = other.shapeInfo(); + Nd4jLong const* xShapeInfoD = specialShapeInfo(); + Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); + + if (!isSameShape(target)) { + auto xPack = + ConstantShapeHelper::getInstance() + ->createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if (!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance() + ->createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), other.shapeInfo(), + other.getContext()->getWorkspace()); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); + } + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcast( + getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, + other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, + target.buffer(), target.shapeInfo(), target.specialBuffer(), + target.specialShapeInfo()); + registerSpecialUse({&target}, {this, &other}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, + const NDArray& other, NDArray& target, + const bool checkTargetShape, + ExtraArguments* extraArgs) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyTrueBroadcast bool: you can't use this method on String " + "array!"); + + if (isEmpty() || other.isEmpty()) return; + + // if (lengthOf() == 1) { + // NDArray temp(target._shapeInfo, dataType(), false, getContext()); + // temp.assign(this); + // temp.applyPairwiseTransform(op.p, other, target, extraArgs); + // return; + // } + // if (other.lengthOf() == 1) { + // this->applyScalarArr(op.s, other, target, extraArgs); + // return; + // } + + if (checkTargetShape) { + const Nd4jLong* newShapeInfo = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo( + *this, other, true, newShapeInfo, + getContext()->getWorkspace())) // the rank of target array must be + // equal to max->rankOf)() + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shapes of this and other " + "arrays are not suitable for broadcast operation !"); + if (!shape::equalsSoft(target._shapeInfo, newShapeInfo) || + target.dataType() != DataType::BOOL) + throw std::runtime_error( + "NDArray::applyTrueBroadcast bool method: the shape or type of " + "target array is wrong !"); + if (dataType() != other.dataType()) + throw std::invalid_argument( + "NDArray::applyTrueBroadcast bool method: this and other arrays must " + "have the same type !"); + } + + Nd4jLong const* xShapeInfoH = shapeInfo(); + Nd4jLong const* yShapeInfoH = other.shapeInfo(); + Nd4jLong const* xShapeInfoD = specialShapeInfo(); + Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); + + if (!isSameShape(target)) { + auto xPack = + ConstantShapeHelper::getInstance() + ->createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if (!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance() + ->createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), other.shapeInfo(), + other.getContext()->getWorkspace()); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); + } + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcastBool( + getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, + other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, + target.buffer(), target.shapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), nullptr); + registerSpecialUse({&target}, {this, &other}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyTrueBroadcast(sd::BroadcastIntOpsTuple op, + const NDArray& other, NDArray& target, + const bool checkTargetShape, + ExtraArguments* extraArgs) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyTrueBroadcast bool: you can't use this method on String " + "array!"); + + if (isEmpty() || other.isEmpty()) return; + + // if (lengthOf() == 1) { + // NDArray temp(target._shapeInfo, dataType(), false, getContext()); + // temp.assign(this); + // temp.applyPairwiseTransform(op.p, other, target, extraArgs); + // return; + // } + // if (other.lengthOf() == 1) { + // this->applyScalarArr(op.s, other, target, extraArgs); + // return; + // } + + if (checkTargetShape) { + const Nd4jLong* newShapeInfo = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo( + *this, other, false, newShapeInfo, + getContext()->getWorkspace())) // the rank of target array must be + // equal to max->rankOf)() + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shapes of this and other " + "arrays are not suitable for broadcast operation !"); + if (!shape::equalsSoft(target._shapeInfo, newShapeInfo) || + target.dataType() != this->dataType()) + throw std::runtime_error( + "NDArray::applyTrueBroadcast int method: the shape or type of target " + "array is wrong !"); + if (dataType() != other.dataType()) + throw std::invalid_argument( + "NDArray::applyTrueBroadcast int method: this and other arrays must " + "have the same type !"); + } + + Nd4jLong const* xShapeInfoH = shapeInfo(); + Nd4jLong const* yShapeInfoH = other.shapeInfo(); + Nd4jLong const* xShapeInfoD = specialShapeInfo(); + Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); + + if (!isSameShape(target)) { + auto xPack = + ConstantShapeHelper::getInstance() + ->createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if (!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance() + ->createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), other.shapeInfo(), + other.getContext()->getWorkspace()); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); + } + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcastInt( + getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, + other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, + target.buffer(), target.shapeInfo(), target.specialBuffer(), + target.specialShapeInfo()); + registerSpecialUse({&target}, {this, &other}); +} + +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, + const NDArray& other, + ExtraArguments* extraArgs) const& { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } - Nd4jLong const* xShapeInfoH = shapeInfo(); - Nd4jLong const* yShapeInfoH = other.shapeInfo(); - Nd4jLong const* xShapeInfoD = specialShapeInfo(); - Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); + const Nd4jLong* newShapeInfo = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo( + *this, other, true, newShapeInfo, + getContext() + ->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shapes of this and other " + "arrays are not suitable for broadcast operation !"); + NDArray result(newShapeInfo, true, getContext()); - if(!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); - xShapeInfoH = reinterpret_cast(xPack.primary()); - xShapeInfoD = reinterpret_cast(xPack.special()); - } - if(!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace()); - yShapeInfoH = reinterpret_cast(yPack.primary()); - yShapeInfoD = reinterpret_cast(yPack.special()); - } + this->applyTrueBroadcast(op, other, result, false, extraArgs); - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execBroadcast(getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - registerSpecialUse({&target}, {this, &other}); + return result; } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { - - if (isS()) - throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!"); +NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, + ExtraArguments* extraArgs) const& { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } + + const Nd4jLong* newShapeInfo = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo( + *this, other, true, newShapeInfo, + getContext() + ->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shapes of this and other " + "arrays are not suitable for broadcast operation !"); + + if (!shape::shapeEquals(newShapeInfo, other.shapeInfo())) { + NDArray result(newShapeInfo, true, getContext()); + this->applyTrueBroadcast(op, other, result, false, extraArgs); + return std::move(result); + } - if (isEmpty() || other.isEmpty()) - return; + this->applyTrueBroadcast(op, other, other, false, extraArgs); + return std::move(other); +} - // if (lengthOf() == 1) { - // NDArray temp(target._shapeInfo, dataType(), false, getContext()); - // temp.assign(this); - // temp.applyPairwiseTransform(op.p, other, target, extraArgs); - // return; - // } - // if (other.lengthOf() == 1) { - // this->applyScalarArr(op.s, other, target, extraArgs); - // return; - // } +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, + const NDArray& other, + ExtraArguments* extraArgs) && { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } + + const Nd4jLong* newShapeInfo = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo( + *this, other, true, newShapeInfo, + getContext() + ->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shapes of this and other " + "arrays are not suitable for broadcast operation !"); + + if (!shape::shapeEquals(newShapeInfo, shapeInfo())) { + NDArray result(newShapeInfo, true, getContext()); + this->applyTrueBroadcast(op, other, result, false, extraArgs); + return std::move(result); + } - if(checkTargetShape) { - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - if(!shape::equalsSoft(target._shapeInfo, newShapeInfo) || target.dataType() != DataType::BOOL) - throw std::runtime_error("NDArray::applyTrueBroadcast bool method: the shape or type of target array is wrong !"); - if(dataType() != other.dataType()) - throw std::invalid_argument("NDArray::applyTrueBroadcast bool method: this and other arrays must have the same type !"); - } + this->applyTrueBroadcast(op, other, *this, false, extraArgs); + return std::move(*this); +} - Nd4jLong const* xShapeInfoH = shapeInfo(); - Nd4jLong const* yShapeInfoH = other.shapeInfo(); - Nd4jLong const* xShapeInfoD = specialShapeInfo(); - Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); +////////////////////////////////////////////////////////////////////////// +NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, + ExtraArguments* extraArgs) && { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } + + const Nd4jLong* newShapeInfo = nullptr; + if (!ShapeUtils::evalBroadcastShapeInfo( + *this, other, true, newShapeInfo, + getContext() + ->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error( + "NDArray::applyTrueBroadcast method: the shapes of this and other " + "arrays are not suitable for broadcast operation !"); + + const bool thisMove = shape::shapeEquals(newShapeInfo, shapeInfo()); + const bool otherMove = shape::shapeEquals(newShapeInfo, other.shapeInfo()); + + if (!thisMove && !otherMove) { + NDArray result(newShapeInfo, true, getContext()); + this->applyTrueBroadcast(op, other, result, false, extraArgs); + return std::move(result); + } - if(!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); - xShapeInfoH = reinterpret_cast(xPack.primary()); - xShapeInfoD = reinterpret_cast(xPack.special()); - } - if(!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace()); - yShapeInfoH = reinterpret_cast(yPack.primary()); - yShapeInfoD = reinterpret_cast(yPack.special()); - } + if (thisMove) { + this->applyTrueBroadcast(op, other, *this, false, extraArgs); + return std::move(*this); + } + + // otherMove + this->applyTrueBroadcast(op, other, other, false, extraArgs); + return std::move(other); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyBroadcast(sd::broadcast::Ops op, + const std::vector& dimensions, + const NDArray& other, NDArray& target, + ExtraArguments* extraArgs) { + if (dimensions.size() == 0) return; + + if (isS()) + throw std::runtime_error( + "NDArray::applyBroadcast: you can't use this method on String array!"); + if (((op == broadcast::Divide || op == broadcast::FloorDiv || + op == broadcast::FloorMod) && + other.isB()) || + (op == broadcast::ReverseDivide && this->isB())) + throw std::runtime_error( + "NDArray::applyBroadcast: you can't divide by array!"); + if (isEmpty() || other.isEmpty()) { + if (!target.isEmpty()) + throw std::runtime_error( + "NDArray::applyBroadcast method: when some of input arrays (or both) " + "is empty, target array must be empty as well !"); + return; + } + + // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { + // NDArray::prepareSpecialUse({&target}, {this, &other}); + // NativeOpExecutioner::execPairwiseTransform(getContext(), + // fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), + // specialShapeInfo(), other.buffer(), other.shapeInfo(), + // other.specialBuffer(), other.specialShapeInfo(), target.buffer(), + // target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + // nullptr); NDArray::registerSpecialUse({&target}, {this, &other}); + // return; + // } + + if (target.dataType() != + DataTypeUtils::pickPairwiseResultType(shapeInfo(), other.shapeInfo())) + throw std::invalid_argument( + "NDArray::applyBroadcast method: wrong type of target array !"); + if (!target.isSameShape(this) && !target.isSameShape(other)) + throw std::invalid_argument( + "NDArray::applyBroadcast method: one of of two input arrays (this or " + "other) should has the same shape as target array!"); + + std::vector copy(dimensions); + + if (dimensions.size() > 1) std::sort(copy.begin(), copy.end()); + + Nd4jLong const* xShapeInfoH = shapeInfo(); + Nd4jLong const* yShapeInfoH = other.shapeInfo(); + Nd4jLong const* xShapeInfoD = specialShapeInfo(); + Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); + + if (!isSameShape(target)) { + auto xPack = ConstantShapeHelper::getInstance() + ->createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), shapeInfo(), + getContext()->getWorkspace(), copy); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if (!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance() + ->createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), other.shapeInfo(), + other.getContext()->getWorkspace(), copy); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); + } + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcast( + getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, + other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, + target.buffer(), target.shapeInfo(), target.specialBuffer(), + target.specialShapeInfo()); + registerSpecialUse({&target}, {this, &other}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyBroadcast(sd::broadcast::BoolOps op, + const std::vector& dimensions, + const NDArray& other, NDArray& target, + ExtraArguments* extraArgs) { + if (dimensions.size() == 0) return; + + if (isS()) + throw std::runtime_error( + "NDArray::applyBroadcast BoolOps: you can't use this method on String " + "array!"); + if (isEmpty() || other.isEmpty()) { + if (!target.isEmpty()) + throw std::runtime_error( + "NDArray::applyBroadcast BoolOps: when some of input arrays (or " + "both) is empty, target array must be empty as well !"); + return; + } + + // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { + // NDArray::prepareSpecialUse({&target}, {this, &other}); + // NativeOpExecutioner::execPairwiseBoolTransform(getContext(), + // fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), + // specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), + // other.specialBuffer(), other.specialShapeInfo(), target.buffer(), + // target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + // nullptr); NDArray::registerSpecialUse({&target}, {this, &other}); + // return; + // } + + if (target.dataType() != DataType::BOOL) + throw std::invalid_argument( + "NDArray::applyBroadcast bool method: type of target array must be " + "BOOL!"); + if (!target.isSameShape(this) && !target.isSameShape(other)) + throw std::invalid_argument( + "NDArray::applyBroadcast bool method: one of of two input arrays (this " + "or other) should has the same shape as target array!"); + if (_dataType != other._dataType) + throw std::invalid_argument( + "NDArray::applyBroadcast bool method: this and other arrays must have " + "the same type !"); + + std::vector copy(dimensions); + + if (dimensions.size() > 1) std::sort(copy.begin(), copy.end()); + + Nd4jLong const* xShapeInfoH = shapeInfo(); + Nd4jLong const* yShapeInfoH = other.shapeInfo(); + Nd4jLong const* xShapeInfoD = specialShapeInfo(); + Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); + + if (!isSameShape(target)) { + auto xPack = ConstantShapeHelper::getInstance() + ->createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), shapeInfo(), + getContext()->getWorkspace(), copy); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if (!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance() + ->createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), other.shapeInfo(), + other.getContext()->getWorkspace(), copy); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); + } + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcastBool( + getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, + other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, + target.buffer(), target.shapeInfo(), target.specialBuffer(), + target.specialShapeInfo(), nullptr); + registerSpecialUse({&target}, {this, &other}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyBroadcast(sd::broadcast::IntOps op, + const std::vector& dimensions, + const NDArray& other, NDArray& target, + ExtraArguments* extraArgs) { + if (dimensions.empty()) return; + + if (!isZ()) + throw std::runtime_error( + "NDArray::applyBroadcast IntOps: you can't use this method on " + "non-Integer array!"); + if (isEmpty() || other.isEmpty()) { + if (!target.isEmpty()) + throw std::runtime_error( + "NDArray::applyBroadcast IntOps: when some of input arrays (or both) " + "is empty, target array must be empty as well !"); + return; + } + + // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { + // NDArray::prepareSpecialUse({&target}, {this, &other}); + // NativeOpExecutioner::execPairwiseIntTransform(getContext(), + // fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), + // specialShapeInfo(), other.buffer(), other.shapeInfo(), + // other.specialBuffer(), other.specialShapeInfo(), target.buffer(), + // target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + // nullptr); NDArray::registerSpecialUse({&target}, {this, &other}); + // return; + // } + + if (target.dataType() != dataType()) + throw std::invalid_argument( + "NDArray::applyBroadcast int method: type of target array must be the " + "same as input!"); + if (!target.isSameShape(this) && !target.isSameShape(other)) + throw std::invalid_argument( + "NDArray::applyBroadcast int method: one of of two input arrays (this " + "or other) should has the same shape as target array!"); + if (_dataType != other._dataType) + throw std::invalid_argument( + "NDArray::applyBroadcast int method: this and other arrays must have " + "the same type !"); + + std::vector copy(dimensions); + + if (dimensions.size() > 1) std::sort(copy.begin(), copy.end()); + + Nd4jLong const* xShapeInfoH = shapeInfo(); + Nd4jLong const* yShapeInfoH = other.shapeInfo(); + Nd4jLong const* xShapeInfoD = specialShapeInfo(); + Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); + + if (!isSameShape(target)) { + auto xPack = ConstantShapeHelper::getInstance() + ->createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), shapeInfo(), + getContext()->getWorkspace(), copy); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if (!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance() + ->createShapeInfoWithUnitiesForBroadcast( + target.shapeInfo(), other.shapeInfo(), + other.getContext()->getWorkspace(), copy); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); + } + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcastInt( + getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, + other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, + target.buffer(), target.shapeInfo(), target.specialBuffer(), + target.specialShapeInfo()); + registerSpecialUse({&target}, {this, &other}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyBroadcast(sd::broadcast::Ops op, + const std::initializer_list dimensions, + const NDArray& tadArray, NDArray& target, + ExtraArguments* extraArgs) { + std::vector vec(dimensions); + applyBroadcast(op, vec, tadArray, target, extraArgs); +} - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execBroadcastBool(getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); - registerSpecialUse({&target}, {this, &other}); +//////////////////////////////////////////////////////////////////////// +void* NDArray::operator new(size_t i) { + if (sd::memory::MemoryRegistrator::getInstance()->hasWorkspaceAttached()) { + sd::memory::Workspace* ws = + sd::memory::MemoryRegistrator::getInstance()->getWorkspace(); + return ws->allocateBytes((Nd4jLong)i); + } else { + auto p = malloc(i); + CHECK_ALLOC(p, "Failed to allocate new NDArray", i); + return p; + } } -////////////////////////////////////////////////////////////////////////// -void NDArray::applyTrueBroadcast(sd::BroadcastIntOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { +//////////////////////////////////////////////////////////////////////// +void NDArray::operator delete(void* p) { + if (!sd::memory::MemoryRegistrator::getInstance()->hasWorkspaceAttached()) + free(p); +} - if (isS()) - throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!"); +//////////////////////////////////////////////////////////////////////// +template +std::vector NDArray::asVectorT() { + std::vector result(this->lengthOf()); - if (isEmpty() || other.isEmpty()) - return; + PRAGMA_OMP_SIMD + for (int e = 0; e < this->lengthOf(); e++) result[e] = this->e(e); - // if (lengthOf() == 1) { - // NDArray temp(target._shapeInfo, dataType(), false, getContext()); - // temp.assign(this); - // temp.applyPairwiseTransform(op.p, other, target, extraArgs); - // return; - // } - // if (other.lengthOf() == 1) { - // this->applyScalarArr(op.s, other, target, extraArgs); - // return; - // } - - if(checkTargetShape) { - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - if(!shape::equalsSoft(target._shapeInfo, newShapeInfo) || target.dataType() != this->dataType()) - throw std::runtime_error("NDArray::applyTrueBroadcast int method: the shape or type of target array is wrong !"); - if(dataType() != other.dataType()) - throw std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !"); - } - - Nd4jLong const* xShapeInfoH = shapeInfo(); - Nd4jLong const* yShapeInfoH = other.shapeInfo(); - Nd4jLong const* xShapeInfoD = specialShapeInfo(); - Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); - - if(!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); - xShapeInfoH = reinterpret_cast(xPack.primary()); - xShapeInfoD = reinterpret_cast(xPack.special()); - } - if(!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace()); - yShapeInfoH = reinterpret_cast(yPack.primary()); - yShapeInfoD = reinterpret_cast(yPack.special()); - } - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execBroadcastInt(getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - registerSpecialUse({&target}, {this, &other}); -} + return result; +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT std::vector, NDArray::asVectorT(), + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const & { - if (isEmpty() || other.isEmpty()) { - if (isEmpty()) - return NDArray(*this); - else - return NDArray(other); - } +// set new order and shape in case of suitable array length +bool NDArray::reshapei(const char order, const std::vector& cshape, + const bool copyToNewBuff) { + // check firstly whether cshape is identical to shape of array, if yes then + // reshape is unnecessary + if (order == ordering() && + shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data())) + return true; - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - NDArray result(newShapeInfo, true, getContext()); + const bool isOutShapeEmpty = + std::find(cshape.begin(), cshape.end(), 0) != cshape.end(); + + if (isEmpty() && !isOutShapeEmpty) + throw std::invalid_argument( + "NDArray::reshapei: can't reshape empty array to non-empty !"); + if (!isEmpty() && isOutShapeEmpty) + throw std::invalid_argument( + "NDArray::reshapei: can't reshape non-empty array to empty !"); + if (isEmpty() && isOutShapeEmpty) { + Nd4jLong* shapeInfoNew = ShapeBuilders::emptyShapeInfo( + dataType(), order, cshape, getContext()->getWorkspace()); + setShapeInfo(shapeInfoNew); + RELEASE(shapeInfoNew, getContext()->getWorkspace()); + return true; + } - this->applyTrueBroadcast(op, other, result, false, extraArgs); + std::vector shape(cshape); + int rank = shape.size(); - return result; -} + // looking for negative in shape -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs) const & { - if (isEmpty() || other.isEmpty()) { - if (isEmpty()) - return NDArray(*this); - else - return NDArray(other); - } + int numberNegativesOnes = 0; - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); + Nd4jLong* shape_ = shape.data(); + for (int i = 0; i < (int)shape.size(); i++) { + if (shape[i] < 0) { + if (numberNegativesOnes >= 1) + throw std::runtime_error( + "NDArray::reshapei: only one dimension can be negative at once"); - if(!shape::shapeEquals(newShapeInfo, other.shapeInfo())) { + numberNegativesOnes++; - NDArray result(newShapeInfo, true, getContext()); - this->applyTrueBroadcast(op, other, result, false, extraArgs); - return std::move(result); - } + int shapeLength = 1; + for (int j = 0; j < (int)shape.size(); j++) + if (i != j) shapeLength *= shape_[j]; - this->applyTrueBroadcast(op, other, other, false, extraArgs); - return std::move(other); -} + Nd4jLong realShape = sd::math::nd4j_abs(lengthOf() / shapeLength); + auto thisNewShape = new Nd4jLong[shape.size()]; -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) && { - if (isEmpty() || other.isEmpty()) { - if (isEmpty()) - return NDArray(*this); + for (int j = 0; j < (int)shape.size(); j++) + if (i != j) + thisNewShape[j] = shape_[j]; else - return NDArray(other); - } - - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - - if(!shape::shapeEquals(newShapeInfo, shapeInfo())) { - - NDArray result(newShapeInfo, true, getContext()); - this->applyTrueBroadcast(op, other, result, false, extraArgs); - return std::move(result); - } - - this->applyTrueBroadcast(op, other, *this, false, extraArgs); - return std::move(*this); -} + thisNewShape[j] = realShape; -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs) && { - if (isEmpty() || other.isEmpty()) { - if (isEmpty()) - return NDArray(*this); - else - return NDArray(other); + shape_ = thisNewShape; } + } - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); + for (int e = 0; e < (int)shape.size(); e++) shape[e] = shape_[e]; - const bool thisMove = shape::shapeEquals(newShapeInfo, shapeInfo()); - const bool otherMove = shape::shapeEquals(newShapeInfo, other.shapeInfo()); + if (numberNegativesOnes > 0) delete[] shape_; - if(!thisMove && !otherMove) { + Nd4jLong arrLength = 1; + for (const auto& item : shape) arrLength *= item; - NDArray result(newShapeInfo, true, getContext()); - this->applyTrueBroadcast(op, other, result, false, extraArgs); - return std::move(result); - } + if (platformBuffer() == nullptr || arrLength != this->lengthOf()) { + this->printShapeInfo("Mismatched shape"); + sd::Logger::printv("Shape requested: ", shape); + nd4j_debug("Requested length in reshape: %i; Existing length: %i;\n", + arrLength, this->lengthOf()); + throw std::runtime_error("NDArray::reshapei: bad input shape!"); + } - if(thisMove) { - this->applyTrueBroadcast(op, other, *this, false, extraArgs); - return std::move(*this); - } + Nd4jLong* shapeInfoNew; + ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), + shape::shapeInfoLength(rank), Nd4jLong); - // otherMove - this->applyTrueBroadcast(op, other, other, false, extraArgs); - return std::move(other); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(sd::broadcast::Ops op, const std::vector& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) { + bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), + shape.data(), shapeInfoNew); - if (dimensions.size() == 0) - return; + if (canReshape) { + setShapeInfo(shapeInfoNew); + } else { + NDArray temp(order, shape, dataType(), getContext()); + if (copyToNewBuff) this->applyTransform(transform::Assign, temp, nullptr); + *this = std::move(temp); + } - if (isS()) - throw std::runtime_error("NDArray::applyBroadcast: you can't use this method on String array!"); - if(((op == broadcast::Divide || op == broadcast::FloorDiv || op == broadcast::FloorMod) && other.isB()) || (op == broadcast::ReverseDivide && this->isB())) - throw std::runtime_error("NDArray::applyBroadcast: you can't divide by array!"); - if(isEmpty() || other.isEmpty()) { - if(!target.isEmpty()) - throw std::runtime_error("NDArray::applyBroadcast method: when some of input arrays (or both) is empty, target array must be empty as well !"); - return; - } - - // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - // NDArray::prepareSpecialUse({&target}, {this, &other}); - // NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); - // NDArray::registerSpecialUse({&target}, {this, &other}); - // return; - // } - - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), other.shapeInfo())) - throw std::invalid_argument("NDArray::applyBroadcast method: wrong type of target array !"); - if(!target.isSameShape(this) && !target.isSameShape(other)) - throw std::invalid_argument("NDArray::applyBroadcast method: one of of two input arrays (this or other) should has the same shape as target array!"); - - std::vector copy(dimensions); - - if (dimensions.size() > 1) - std::sort(copy.begin(), copy.end()); - - Nd4jLong const* xShapeInfoH = shapeInfo(); - Nd4jLong const* yShapeInfoH = other.shapeInfo(); - Nd4jLong const* xShapeInfoD = specialShapeInfo(); - Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); - - if(!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy); - xShapeInfoH = reinterpret_cast(xPack.primary()); - xShapeInfoD = reinterpret_cast(xPack.special()); - } - if(!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy); - yShapeInfoH = reinterpret_cast(yPack.primary()); - yShapeInfoD = reinterpret_cast(yPack.special()); - } - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execBroadcast(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - registerSpecialUse({&target}, {this, &other}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(sd::broadcast::BoolOps op, const std::vector& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) { - - if (dimensions.size() == 0) - return; - - if (isS()) - throw std::runtime_error("NDArray::applyBroadcast BoolOps: you can't use this method on String array!"); - if(isEmpty() || other.isEmpty()) { - if(!target.isEmpty()) - throw std::runtime_error("NDArray::applyBroadcast BoolOps: when some of input arrays (or both) is empty, target array must be empty as well !"); - return; - } - - // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - // NDArray::prepareSpecialUse({&target}, {this, &other}); - // NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); - // NDArray::registerSpecialUse({&target}, {this, &other}); - // return; - // } - - if(target.dataType() != DataType::BOOL) - throw std::invalid_argument("NDArray::applyBroadcast bool method: type of target array must be BOOL!"); - if(!target.isSameShape(this) && !target.isSameShape(other)) - throw std::invalid_argument("NDArray::applyBroadcast bool method: one of of two input arrays (this or other) should has the same shape as target array!"); - if(_dataType != other._dataType) - throw std::invalid_argument("NDArray::applyBroadcast bool method: this and other arrays must have the same type !"); - - std::vector copy(dimensions); - - if (dimensions.size() > 1) - std::sort(copy.begin(), copy.end()); - - Nd4jLong const* xShapeInfoH = shapeInfo(); - Nd4jLong const* yShapeInfoH = other.shapeInfo(); - Nd4jLong const* xShapeInfoD = specialShapeInfo(); - Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); - - if(!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy); - xShapeInfoH = reinterpret_cast(xPack.primary()); - xShapeInfoD = reinterpret_cast(xPack.special()); - } - if(!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy); - yShapeInfoH = reinterpret_cast(yPack.primary()); - yShapeInfoD = reinterpret_cast(yPack.special()); - } - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execBroadcastBool(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); - registerSpecialUse({&target}, {this, &other}); -} - - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(sd::broadcast::IntOps op, const std::vector& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) { - - if (dimensions.empty()) - return; - - if (!isZ()) - throw std::runtime_error("NDArray::applyBroadcast IntOps: you can't use this method on non-Integer array!"); - if(isEmpty() || other.isEmpty()) { - if(!target.isEmpty()) - throw std::runtime_error("NDArray::applyBroadcast IntOps: when some of input arrays (or both) is empty, target array must be empty as well !"); - return; - } - - // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - // NDArray::prepareSpecialUse({&target}, {this, &other}); - // NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); - // NDArray::registerSpecialUse({&target}, {this, &other}); - // return; - // } - - if(target.dataType() != dataType()) - throw std::invalid_argument("NDArray::applyBroadcast int method: type of target array must be the same as input!"); - if(!target.isSameShape(this) && !target.isSameShape(other)) - throw std::invalid_argument("NDArray::applyBroadcast int method: one of of two input arrays (this or other) should has the same shape as target array!"); - if(_dataType != other._dataType) - throw std::invalid_argument("NDArray::applyBroadcast int method: this and other arrays must have the same type !"); - - std::vector copy(dimensions); - - if (dimensions.size() > 1) - std::sort(copy.begin(), copy.end()); - - Nd4jLong const* xShapeInfoH = shapeInfo(); - Nd4jLong const* yShapeInfoH = other.shapeInfo(); - Nd4jLong const* xShapeInfoD = specialShapeInfo(); - Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); - - if(!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy); - xShapeInfoH = reinterpret_cast(xPack.primary()); - xShapeInfoD = reinterpret_cast(xPack.special()); - } - if(!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy); - yShapeInfoH = reinterpret_cast(yPack.primary()); - yShapeInfoD = reinterpret_cast(yPack.special()); - } + RELEASE(shapeInfoNew, getContext()->getWorkspace()); - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execBroadcastInt(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - registerSpecialUse({&target}, {this, &other}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyBroadcast(sd::broadcast::Ops op, const std::initializer_list dimensions, const NDArray& tadArray, NDArray& target, ExtraArguments* extraArgs) { - std::vector vec(dimensions); - applyBroadcast(op, vec, tadArray, target, extraArgs); -} - -//////////////////////////////////////////////////////////////////////// -void* NDArray::operator new(size_t i) { - if (sd::memory::MemoryRegistrator::getInstance()->hasWorkspaceAttached()) { - sd::memory::Workspace* ws = sd::memory::MemoryRegistrator::getInstance()->getWorkspace(); - return ws->allocateBytes((Nd4jLong) i); - } - else { - auto p = malloc(i); - CHECK_ALLOC(p, "Failed to allocate new NDArray", i); - return p; - } -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::operator delete(void* p) { - if (!sd::memory::MemoryRegistrator::getInstance()->hasWorkspaceAttached()) - free(p); -} - -//////////////////////////////////////////////////////////////////////// -template -std::vector NDArray::asVectorT() { - - std::vector result(this->lengthOf()); - - PRAGMA_OMP_SIMD - for (int e = 0; e < this->lengthOf(); e++) - result[e] = this->e(e); - - return result; -} -BUILD_SINGLE_TEMPLATE(template SD_EXPORT std::vector, NDArray::asVectorT(), LIBND4J_TYPES); - -////////////////////////////////////////////////////////////////////////// -// set new order and shape in case of suitable array length -bool NDArray::reshapei(const char order, const std::vector& cshape, const bool copyToNewBuff) { - - // check firstly whether cshape is identical to shape of array, if yes then reshape is unnecessary - if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data())) - return true; - - const bool isOutShapeEmpty = std::find(cshape.begin(), cshape.end(), 0) != cshape.end(); - - if(isEmpty() && !isOutShapeEmpty) - throw std::invalid_argument("NDArray::reshapei: can't reshape empty array to non-empty !"); - if(!isEmpty() && isOutShapeEmpty) - throw std::invalid_argument("NDArray::reshapei: can't reshape non-empty array to empty !"); - if(isEmpty() && isOutShapeEmpty) { - Nd4jLong* shapeInfoNew = ShapeBuilders::emptyShapeInfo(dataType(), order, cshape, getContext()->getWorkspace()); - setShapeInfo(shapeInfoNew); - RELEASE(shapeInfoNew, getContext()->getWorkspace()); - return true; - } - - std::vector shape(cshape); - int rank = shape.size(); - - // looking for negative in shape - - int numberNegativesOnes = 0; - - Nd4jLong* shape_ = shape.data(); - for (int i = 0; i < (int) shape.size(); i++) { - if (shape[i] < 0) { - if (numberNegativesOnes >= 1) - throw std::runtime_error("NDArray::reshapei: only one dimension can be negative at once"); - - numberNegativesOnes++; - - int shapeLength = 1; - for (int j = 0; j < (int) shape.size(); j++) - if (i != j) - shapeLength *= shape_[j]; - - Nd4jLong realShape = sd::math::nd4j_abs(lengthOf() / shapeLength); - auto thisNewShape = new Nd4jLong[shape.size()]; - - for (int j = 0; j < (int) shape.size(); j++) - if (i != j) - thisNewShape[j] = shape_[j]; - else - thisNewShape[j] = realShape; - - shape_ = thisNewShape; - } - } - - for (int e = 0; e < (int) shape.size(); e++) - shape[e] = shape_[e]; - - if (numberNegativesOnes > 0) - delete[] shape_; - - Nd4jLong arrLength = 1; - for(const auto& item : shape) - arrLength *= item; - - if(platformBuffer() == nullptr || arrLength != this->lengthOf()) { - this->printShapeInfo("Mismatched shape"); - sd::Logger::printv("Shape requested: ", shape); - nd4j_debug("Requested length in reshape: %i; Existing length: %i;\n", arrLength, this->lengthOf()); - throw std::runtime_error("NDArray::reshapei: bad input shape!"); - } - - Nd4jLong *shapeInfoNew; - ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); - - bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), shape.data(), shapeInfoNew); - - if (canReshape) { - setShapeInfo(shapeInfoNew); - } - else { - NDArray temp(order, shape, dataType(), getContext()); - if(copyToNewBuff) - this->applyTransform(transform::Assign, temp, nullptr); - *this = std::move(temp); - } - - RELEASE(shapeInfoNew, getContext()->getWorkspace()); - - return canReshape; + return canReshape; } ////////////////////////////////////////////////////////////////////////// void NDArray::nullify() { - if (isEmpty()) - return; - - if (isView() || ews() != 1) - assign(0); - else - _buffer->setToZeroBuffers(); + if (isEmpty()) return; + if (isView() || ews() != 1) + assign(0); + else + _buffer->setToZeroBuffers(); } //////////////////////////////////////////////////////////////////////// template -void NDArray::templatedSet(void *buffer, const Nd4jLong xOfsset, sd::DataType dtype, const void *value) { - BUILD_SINGLE_PARTIAL_SELECTOR(dtype, templatedSet< , T>(buffer, xOfsset, value), LIBND4J_TYPES); +void NDArray::templatedSet(void* buffer, const Nd4jLong xOfsset, + sd::DataType dtype, const void* value) { + BUILD_SINGLE_PARTIAL_SELECTOR( + dtype, templatedSet<, T>(buffer, xOfsset, value), LIBND4J_TYPES); } -BUILD_SINGLE_TEMPLATE(template SD_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong xOfsset, sd::DataType dtype, const void *value), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void NDArray::templatedSet, + (void* buffer, const Nd4jLong xOfsset, sd::DataType dtype, + const void* value), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// -void NDArray::applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const{ - if (isS()) - throw std::runtime_error("NDArray::applyPairwiseTransform: you can't use this method on String array!"); - if (other.lengthOf() != target.lengthOf()) - throw std::invalid_argument("NDArray::applyPairwiseTransform method - lengths of arrays are mismatched"); - if (target.dataType() != this->dataType() && target.dataType() != other.dataType()) - throw std::invalid_argument("NDArray::applyPairwiseTransform method - type of target array must be the same as type of this or other array !"); - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr); - NDArray::registerSpecialUse({&target}, {this, &other}); - - if (extraParams != nullptr) - synchronize("NDArray::applyPairwiseTransform"); +void NDArray::applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, + NDArray& target, + ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyPairwiseTransform: you can't use this method on String " + "array!"); + if (other.lengthOf() != target.lengthOf()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform method - lengths of arrays are " + "mismatched"); + if (target.dataType() != this->dataType() && + target.dataType() != other.dataType()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform method - type of target array must be " + "the same as type of this or other array !"); + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseTransform( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); + + if (extraParams != nullptr) synchronize("NDArray::applyPairwiseTransform"); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyPairwiseTransform(sd::pairwise::BoolOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const{ - if (isS()) - throw std::runtime_error("NDArray::applyPairwiseTransform BoolOps: you can't use this method on String array!"); - if (other.lengthOf() != target.lengthOf()) - throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - lengths of arrays are mismatched"); - if (!target.isB()) - throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - result must have bool type"); - if (dataType() != other.dataType()) - throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - this and other arrays must have the same type !"); - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr); - NDArray::registerSpecialUse({&target}, {this, &other}); +void NDArray::applyPairwiseTransform(sd::pairwise::BoolOps op, + const NDArray& other, NDArray& target, + ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyPairwiseTransform BoolOps: you can't use this method on " + "String array!"); + if (other.lengthOf() != target.lengthOf()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform BoolOps method - lengths of arrays " + "are mismatched"); + if (!target.isB()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform BoolOps method - result must have " + "bool type"); + if (dataType() != other.dataType()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform BoolOps method - this and other " + "arrays must have the same type !"); + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseBoolTransform( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyPairwiseTransform(sd::pairwise::IntOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const{ - if (isS()) - throw std::runtime_error("NDArray::applyPairwiseTransform IntOps: you can't use this method on String array!"); - if (other.lengthOf() != target.lengthOf()) - throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - lengths of arrays are mismatched"); - if (!target.isZ()) - throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - result must have bool type"); - if (dataType() != other.dataType()) - throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - this and other arrays must have the same type !"); - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execPairwiseIntTransform(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr); - NDArray::registerSpecialUse({&target}, {this, &other}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, ExtraArguments *extraParams) { - applyPairwiseTransform(op, other, *this, extraParams); +void NDArray::applyPairwiseTransform(sd::pairwise::IntOps op, + const NDArray& other, NDArray& target, + ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyPairwiseTransform IntOps: you can't use this method on " + "String array!"); + if (other.lengthOf() != target.lengthOf()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform IntOps method - lengths of arrays are " + "mismatched"); + if (!target.isZ()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform IntOps method - result must have bool " + "type"); + if (dataType() != other.dataType()) + throw std::invalid_argument( + "NDArray::applyPairwiseTransform IntOps method - this and other arrays " + "must have the same type !"); + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execPairwiseIntTransform( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, + ExtraArguments* extraParams) { + applyPairwiseTransform(op, other, *this, extraParams); } //////////////////////////////////////////////////////////////////////// template -void NDArray::templatedDoubleAssign(void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const { - auto x = reinterpret_cast(xBuffer); - const auto y = reinterpret_cast(yBuffer); - x[xOffset] = static_cast(y[yOffset]); -} -BUILD_DOUBLE_TEMPLATE(template SD_EXPORT void NDArray::templatedDoubleAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES, LIBND4J_TYPES); +void NDArray::templatedDoubleAssign(void* xBuffer, const Nd4jLong xOffset, + const void* yBuffer, + const Nd4jLong yOffset) const { + auto x = reinterpret_cast(xBuffer); + const auto y = reinterpret_cast(yBuffer); + x[xOffset] = static_cast(y[yOffset]); +} +BUILD_DOUBLE_TEMPLATE(template SD_EXPORT void NDArray::templatedDoubleAssign, + (void* xBuffer, const Nd4jLong xOffset, + const void* yBuffer, const Nd4jLong yOffset) const, + LIBND4J_TYPES, LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// -void NDArray::varianceAlongDimension(sd::variance::Ops op, NDArray& target, const bool biasCorrected, const std::vector& dimensions) const { - - if (isS()) - throw std::runtime_error("NDArray::varianceAlongDimension: you can't use this method on String array!"); - - if (!target.isR()) - throw std::runtime_error("NDArray::varianceAlongDimension: target array must have FLOAT type"); - - NDArray::prepareSpecialUse({&target}, {this}); - - if(rankOf() == dimensions.size() || dimensions.empty()) - NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), biasCorrected); - else { - std::vector copy(dimensions); - auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimensions); - NativeOpExecutioner::execSummaryStats(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, dimensions.size(), packX.platformShapeInfo(), packX.platformOffsets(), biasCorrected); - synchronize("NDArray::varianceAlongDimension"); - } - - NDArray::registerSpecialUse({&target}, {this}); +void NDArray::varianceAlongDimension(sd::variance::Ops op, NDArray& target, + const bool biasCorrected, + const std::vector& dimensions) const { + if (isS()) + throw std::runtime_error( + "NDArray::varianceAlongDimension: you can't use this method on String " + "array!"); + + if (!target.isR()) + throw std::runtime_error( + "NDArray::varianceAlongDimension: target array must have FLOAT type"); + + NDArray::prepareSpecialUse({&target}, {this}); + + if (rankOf() == dimensions.size() || dimensions.empty()) + NativeOpExecutioner::execSummaryStatsScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), biasCorrected); + else { + std::vector copy(dimensions); + auto pDims = + sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + this->shapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), pDims, + dimensions.size(), packX.platformShapeInfo(), packX.platformOffsets(), + biasCorrected); + synchronize("NDArray::varianceAlongDimension"); + } + + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const { - if (isS()) - throw std::runtime_error("NDArray::varianceAlongDimension: you can't use this method on String array!"); +NDArray NDArray::varianceAlongDimension( + sd::variance::Ops op, const bool biasCorrected, + const std::vector& dimensions) const { + if (isS()) + throw std::runtime_error( + "NDArray::varianceAlongDimension: you can't use this method on String " + "array!"); - std::vector copy(dimensions); - if (copy.size() > 1) - std::sort(copy.begin(), copy.end()); + std::vector copy(dimensions); + if (copy.size() > 1) std::sort(copy.begin(), copy.end()); - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), false, false, getContext()->getWorkspace()); - NDArray result(newShape, true, getContext()); + auto newShape = ShapeUtils::evalReduceShapeInfo( + 'c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), false, + false, getContext()->getWorkspace()); + NDArray result(newShape, true, getContext()); - this->varianceAlongDimension(op, result, biasCorrected, dimensions); + this->varianceAlongDimension(op, result, biasCorrected, dimensions); - return result; + return result; } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, const std::initializer_list& dimensions) const { - return varianceAlongDimension(op, biasCorrected, std::vector(dimensions)); +NDArray NDArray::varianceAlongDimension( + sd::variance::Ops op, const bool biasCorrected, + const std::initializer_list& dimensions) const { + return varianceAlongDimension(op, biasCorrected, + std::vector(dimensions)); } //////////////////////////////////////////////////////////////////////// -void NDArray::varianceAlongDimension(sd::variance::Ops op, NDArray &target, const bool biasCorrected, const std::initializer_list& dimensions) const { - varianceAlongDimension(op, target, biasCorrected, std::vector(dimensions)); +void NDArray::varianceAlongDimension( + sd::variance::Ops op, NDArray& target, const bool biasCorrected, + const std::initializer_list& dimensions) const { + varianceAlongDimension(op, target, biasCorrected, + std::vector(dimensions)); } //////////////////////////////////////////////////////////////////////// // This method returns new copy of this NDArray, optionally in different order NDArray NDArray::dup(const char newOrder) const { + if (isEmpty()) return NDArrayFactory::empty(dataType(), getContext()); - if (isEmpty()) - return NDArrayFactory::empty(dataType(), getContext()); + char order = newOrder == 'a' ? ordering() : newOrder; - char order = newOrder == 'a' ? ordering() : newOrder; + // for now string arrays require special treatment + if (isS()) { + if (dataType() == DataType::UTF8) { + std::vector strings(lengthOf()); - // for now string arrays require special treatment - if (isS()) { - if (dataType() == DataType::UTF8) { - std::vector strings(lengthOf()); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + strings[i] = std::move(this->e(i)); + } + }; - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - strings[i] = std::move(this->e(i)); - } - }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + return NDArray(getShapeAsVector(), strings, dataType(), getContext()); + } + if (dataType() == DataType::UTF16) { + std::vector strings(lengthOf()); - return NDArray(getShapeAsVector(), strings, dataType(), getContext()); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + strings[i] = std::move(this->e(i)); } - if (dataType() == DataType::UTF16) { - std::vector strings(lengthOf()); + }; - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - strings[i] = std::move(this->e(i)); - } - }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - - return NDArray(getShapeAsVector(), strings, dataType(), getContext()); - } + return NDArray(getShapeAsVector(), strings, dataType(), getContext()); + } - std::vector strings(lengthOf()); - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - strings[i] = std::move(this->e(i)); - } - }; + std::vector strings(lengthOf()); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + strings[i] = std::move(this->e(i)); + } + }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - return NDArray(getShapeAsVector(), strings, dataType(), getContext()); - } + return NDArray(getShapeAsVector(), strings, dataType(), getContext()); + } - NDArray result(order, isScalar() ? std::vector({0}) : getShapeAsVector(), dataType(), getContext()); - result.assign(*this); + NDArray result(order, + isScalar() ? std::vector({0}) : getShapeAsVector(), + dataType(), getContext()); + result.assign(*this); - return result; + return result; } //////////////////////////////////////////////////////////////////////// -// This method returns true if two arrays are equal, with custom or default Eps value of 1e-5, false otherwise -bool NDArray::equalsTo(const NDArray *other, double eps) const { +// This method returns true if two arrays are equal, with custom or default Eps +// value of 1e-5, false otherwise +bool NDArray::equalsTo(const NDArray* other, double eps) const { + if (dataType() != other->dataType() || lengthOf() != other->lengthOf()) + return false; - if (dataType() != other->dataType() || lengthOf() != other->lengthOf()) - return false; + // we need to be able to compare [1, len] to [len] + if ((rankOf() == 1 && other->rankOf() == 2) || + (rankOf() == 2 && other->rankOf() == 1)) { + // FIXME: do something here? + } else if (!shape::equalsSoft(shapeInfo(), other->shapeInfo())) + return false; - // we need to be able to compare [1, len] to [len] - if ((rankOf() == 1 && other->rankOf() == 2) || (rankOf() == 2 && other->rankOf() == 1)) { - // FIXME: do something here? - } else if (!shape::equalsSoft(shapeInfo(), other->shapeInfo())) - return false; + if (isS()) { + // string is special case, we'll compare them one by one, considering both + // arrays are guaranteed to have the same length - if (isS()) { - // string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length + if (dataType() == DataType::UTF8) { + for (Nd4jLong e = 0; e < this->lengthOf(); e++) { + auto s1 = this->e(e); + auto s2 = other->e(e); - if (dataType() == DataType::UTF8) { - for (Nd4jLong e = 0; e < this->lengthOf(); e++) { - auto s1 = this->e(e); - auto s2 = other->e(e); + if (s1 != s2) return false; + } + } else if (dataType() == DataType::UTF16) { + for (Nd4jLong e = 0; e < this->lengthOf(); e++) { + auto s1 = this->e(e); + auto s2 = other->e(e); - if (s1 != s2) - return false; - } - } - else if (dataType() == DataType::UTF16) { - for (Nd4jLong e = 0; e < this->lengthOf(); e++) { - auto s1 = this->e(e); - auto s2 = other->e(e); - - if (s1 != s2) - return false; - } - } - else { - for (Nd4jLong e = 0; e < this->lengthOf(); e++) { - auto s1 = this->e(e); - auto s2 = other->e(e); - - if (s1 != s2) - return false; - } - } - - return true; + if (s1 != s2) return false; + } } else { - // regular numeric types - NDArray tmp(sd::DataType::FLOAT32, getContext()); // scalar = 0 + for (Nd4jLong e = 0; e < this->lengthOf(); e++) { + auto s1 = this->e(e); + auto s2 = other->e(e); - ExtraArguments extras({0.0, 0.0, eps}); + if (s1 != s2) return false; + } + } - NDArray::prepareSpecialUse({&tmp}, {this, other}); - NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, buffer(), shapeInfo(), - specialBuffer(), specialShapeInfo(), - extras.argumentsAsT(DataType::FLOAT32), other->buffer(), - other->shapeInfo(), other->specialBuffer(), - other->specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), - tmp.specialBuffer(), tmp.specialShapeInfo()); - NDArray::registerSpecialUse({&tmp}, {this, other}); + return true; + } else { + // regular numeric types + NDArray tmp(sd::DataType::FLOAT32, getContext()); // scalar = 0 - synchronize("NDArray::equalsTo"); + ExtraArguments extras({0.0, 0.0, eps}); - if (tmp.e(0) != 0) - return false; + NDArray::prepareSpecialUse({&tmp}, {this, other}); + NativeOpExecutioner::execReduce3Scalar( + getContext(), reduce3::EqualsWithEps, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), + extras.argumentsAsT(DataType::FLOAT32), other->buffer(), + other->shapeInfo(), other->specialBuffer(), other->specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo()); + NDArray::registerSpecialUse({&tmp}, {this, other}); - return true; - } + synchronize("NDArray::equalsTo"); + + if (tmp.e(0) != 0) return false; + + return true; + } } ////////////////////////////////////////////////////////////////////////// template <> std::string NDArray::e(const Nd4jLong i) const { + if (!isS()) + throw std::runtime_error("Can't get std::string out of non-string array"); - if (!isS()) - throw std::runtime_error("Can't get std::string out of non-string array"); - - if (i == lengthOf()) - throw std::runtime_error("Can't get std::string for index out of range"); + if (i == lengthOf()) + throw std::runtime_error("Can't get std::string for index out of range"); + if (this->dataType() == DataType::UTF16) { + auto u16 = this->e(i); + std::string s; + StringUtils::u16StringToU8String(u16, s); + return s; + } - if (this->dataType() == DataType::UTF16) { - auto u16 = this->e(i); - std::string s; - StringUtils::u16StringToU8String(u16, s); - return s; - } - - if (this->dataType() == DataType::UTF32) { - auto u32 = this->e(i); - std::string s; - StringUtils::u32StringToU8String(u32, s); - return s; - } + if (this->dataType() == DataType::UTF32) { + auto u32 = this->e(i); + std::string s; + StringUtils::u32StringToU8String(u32, s); + return s; + } - NDArray::preparePrimaryUse({}, {this}); + NDArray::preparePrimaryUse({}, {this}); - auto offsets = bufferAsT(); - auto offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); - auto start = offsets[i]; - auto end = offsets[i + 1]; - auto data = bufferAsT() + offsetsLength + start; + auto offsets = bufferAsT(); + auto offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + auto start = offsets[i]; + auto end = offsets[i + 1]; + auto data = bufferAsT() + offsetsLength + start; - std::string r(reinterpret_cast(data), (end - start)); + std::string r(reinterpret_cast(data), (end - start)); - registerPrimaryUse({}, {this}); + registerPrimaryUse({}, {this}); - return r; + return r; } template <> std::u16string NDArray::e(const Nd4jLong i) const { + if (!isS()) + throw std::runtime_error( + "Can't get std::u16string out of non-string array"); - if (!isS()) - throw std::runtime_error("Can't get std::u16string out of non-string array"); + if (i == lengthOf()) + throw std::runtime_error("Can't get std::u16string for index out of range"); - if(i == lengthOf()) - throw std::runtime_error("Can't get std::u16string for index out of range"); + if (this->dataType() == DataType::UTF8) { + auto u = this->e(i); + std::u16string s; + StringUtils::u8StringToU16String(u, s); + return s; + } - if (this->dataType() == DataType::UTF8) { - auto u = this->e(i); - std::u16string s; - StringUtils::u8StringToU16String(u, s); - return s; - } - - if (this->dataType() == DataType::UTF32) { - auto u32 = this->e(i); - std::u16string s; - StringUtils::u32StringToU16String(u32, s); - return s; - } + if (this->dataType() == DataType::UTF32) { + auto u32 = this->e(i); + std::u16string s; + StringUtils::u32StringToU16String(u32, s); + return s; + } - NDArray::preparePrimaryUse({}, { this }); + NDArray::preparePrimaryUse({}, {this}); - auto offsets = bufferAsT(); - Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); - Nd4jLong start = offsets[i]; - Nd4jLong end = offsets[i + 1]; - auto data = bufferAsT() + offsetsLength + start; + auto offsets = bufferAsT(); + Nd4jLong offsetsLength = + ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + Nd4jLong start = offsets[i]; + Nd4jLong end = offsets[i + 1]; + auto data = bufferAsT() + offsetsLength + start; - std::u16string r(reinterpret_cast(data), (end - start) / sizeof(char16_t)); + std::u16string r(reinterpret_cast(data), + (end - start) / sizeof(char16_t)); - registerPrimaryUse({}, { this }); + registerPrimaryUse({}, {this}); - return r; + return r; } template <> std::u32string NDArray::e(const Nd4jLong i) const { + if (!isS()) + throw std::runtime_error( + "Can't get std::u32string out of non-string array"); - if (!isS()) - throw std::runtime_error("Can't get std::u32string out of non-string array"); + if (i == lengthOf()) + throw std::runtime_error("Can't get std::u32string for index out of range"); - if (i == lengthOf()) - throw std::runtime_error("Can't get std::u32string for index out of range"); + if (this->dataType() == DataType::UTF8) { + auto u = this->e(i); + std::u32string s; + StringUtils::u8StringToU32String(u, s); + return s; + } - if (this->dataType() == DataType::UTF8) { - auto u = this->e(i); - std::u32string s; - StringUtils::u8StringToU32String(u, s); - return s; - } + if (this->dataType() == DataType::UTF16) { + auto u16 = this->e(i); + std::u32string s; + StringUtils::u16StringToU32String(u16, s); + return s; + } - if (this->dataType() == DataType::UTF16) { - auto u16 = this->e(i); - std::u32string s; - StringUtils::u16StringToU32String(u16, s); - return s; - } + NDArray::preparePrimaryUse({}, {this}); - NDArray::preparePrimaryUse({}, { this }); + auto offsets = bufferAsT(); + Nd4jLong offsetsLength = + ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + Nd4jLong start = offsets[i]; + Nd4jLong end = offsets[i + 1]; - auto offsets = bufferAsT(); - Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); - Nd4jLong start = offsets[i]; - Nd4jLong end = offsets[i + 1]; + auto data = bufferAsT() + offsetsLength + start; - auto data = bufferAsT() + offsetsLength + start; + std::u32string r(reinterpret_cast(data), + (end - start) / sizeof(char32_t)); - std::u32string r(reinterpret_cast(data), (end - start) / sizeof(char32_t)); + registerPrimaryUse({}, {this}); - registerPrimaryUse({}, { this }); - - return r; + return r; } ////////////////////////////////////////////////////////////////////////// template <> utf8string NDArray::e(const Nd4jLong i) const { + if (!isS()) + throw std::runtime_error("This method is available for String arrays only"); - if (!isS()) - throw std::runtime_error("This method is available for String arrays only"); - - auto rp = getOffset(i); + auto rp = getOffset(i); - syncToHost(); - tickReadHost(); + syncToHost(); + tickReadHost(); - return *(reinterpret_cast(buffer())[rp]); + return *(reinterpret_cast(buffer())[rp]); } ///////////////////////////////////////////////////////////////////////// template T NDArray::e(const Nd4jLong i) const { + const auto rp = getOffset(i); - const auto rp = getOffset(i); - - NDArray::preparePrimaryUse({}, {this}); - NDArray::registerPrimaryUse({}, {this}); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), rp), LIBND4J_TYPES); - + NDArray::preparePrimaryUse({}, {this}); + NDArray::registerPrimaryUse({}, {this}); + BUILD_SINGLE_PARTIAL_SELECTOR( + dataType(), return templatedGet<, T>(buffer(), rp), LIBND4J_TYPES); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT , NDArray::e(const Nd4jLong) const, LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT, + NDArray::e(const Nd4jLong) const, + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // Returns value from 2D matrix by coordinates/indexes template T NDArray::e(const Nd4jLong i, const Nd4jLong j) const { + if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) + throw std::invalid_argument( + "NDArray::e(i,j): one of input indexes is out of array length or " + "rank!=2 !"); - if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) - throw std::invalid_argument("NDArray::e(i,j): one of input indexes is out of array length or rank!=2 !"); + const Nd4jLong coords[2] = {i, j}; + const auto xOffset = shape::getOffset(shapeInfo(), coords); - const Nd4jLong coords[2] = {i, j}; - const auto xOffset = shape::getOffset(shapeInfo(), coords); + NDArray::preparePrimaryUse({}, {this}); + NDArray::registerPrimaryUse({}, {this}); - NDArray::preparePrimaryUse({}, {this}); - NDArray::registerPrimaryUse({}, {this}); + BUILD_SINGLE_PARTIAL_SELECTOR( + dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); - - return static_cast(119); + return static_cast(119); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT, + NDArray::e(const Nd4jLong, const Nd4jLong) + const, + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // returns value from 3D tensor by coordinates template T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { + if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || + k >= shapeOf()[2]) + throw std::invalid_argument( + "NDArray::e(i,j,k): one of input indexes is out of array length or " + "rank!=3 !"); - if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2]) - throw std::invalid_argument("NDArray::e(i,j,k): one of input indexes is out of array length or rank!=3 !"); - - const Nd4jLong coords[3] = {i, j, k}; - const auto xOffset = shape::getOffset(shapeInfo(), coords); + const Nd4jLong coords[3] = {i, j, k}; + const auto xOffset = shape::getOffset(shapeInfo(), coords); - NDArray::preparePrimaryUse({}, {this}); - NDArray::registerPrimaryUse({}, {this}); + NDArray::preparePrimaryUse({}, {this}); + NDArray::registerPrimaryUse({}, {this}); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); + BUILD_SINGLE_PARTIAL_SELECTOR( + dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); - return static_cast(119); + return static_cast(119); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT, + NDArray::e(const Nd4jLong, const Nd4jLong, + const Nd4jLong) const, + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // returns value from 3D tensor by coordinates template -T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l) const { - - if (rankOf() != 4 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2] || l >= shapeOf()[3]) - throw std::invalid_argument("NDArray::e(i,j,k,l): one of input indexes is out of array length or rank!=4 !"); +T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong l) const { + if (rankOf() != 4 || i >= shapeOf()[0] || j >= shapeOf()[1] || + k >= shapeOf()[2] || l >= shapeOf()[3]) + throw std::invalid_argument( + "NDArray::e(i,j,k,l): one of input indexes is out of array length or " + "rank!=4 !"); - const Nd4jLong coords[4] = {i, j, k, l}; - const auto xOffset = shape::getOffset(shapeInfo(), coords); + const Nd4jLong coords[4] = {i, j, k, l}; + const auto xOffset = shape::getOffset(shapeInfo(), coords); - NDArray::preparePrimaryUse({}, {this}); - NDArray::registerPrimaryUse({}, {this}); + NDArray::preparePrimaryUse({}, {this}); + NDArray::registerPrimaryUse({}, {this}); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); + BUILD_SINGLE_PARTIAL_SELECTOR( + dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); - return static_cast(119); + return static_cast(119); } -BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); +BUILD_SINGLE_UNCHAINED_TEMPLATE(template SD_EXPORT, + NDArray::e(const Nd4jLong, const Nd4jLong, + const Nd4jLong, const Nd4jLong) + const, + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// NDArray NDArray::e(const Nd4jLong i) const { + const auto offset = getOffset(i); - const auto offset = getOffset(i); + NDArray scalar(dataType(), getContext()); - NDArray scalar(dataType(), getContext()); + scalar.copyBuffersContinuouslyFrom(*this, sizeOfT(), 0, + bufferOffset() + offset); - scalar.copyBuffersContinuouslyFrom(*this, sizeOfT(), 0, bufferOffset() + offset); - - return scalar; + return scalar; } ////////////////////////////////////////////////////////////////////////// // perform array transformation -void NDArray::applyTransform(sd::transform::FloatOps op, NDArray& target, ExtraArguments *extraParams) { - - if (isS()) - throw std::runtime_error("NDArray::applyTransform FloatOps: you can't use this method on String array!"); - - if (!target.isR()) - throw std::runtime_error("NDArray::applyTransform FloatOps: target array must have one of FLOAT types"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this}); +void NDArray::applyTransform(sd::transform::FloatOps op, NDArray& target, + ExtraArguments* extraParams) { + if (isS()) + throw std::runtime_error( + "NDArray::applyTransform FloatOps: you can't use this method on String " + "array!"); + + if (!target.isR()) + throw std::runtime_error( + "NDArray::applyTransform FloatOps: target array must have one of FLOAT " + "types"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformFloat( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr, + nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(sd::transform::AnyOps op, NDArray& target, ExtraArguments *extraParams) { - - if (isS()) - throw std::runtime_error("NDArray::applyTransform AnyOps: you can't use this method on String array!"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execTransformAny(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this}); +void NDArray::applyTransform(sd::transform::AnyOps op, NDArray& target, + ExtraArguments* extraParams) { + if (isS()) + throw std::runtime_error( + "NDArray::applyTransform AnyOps: you can't use this method on String " + "array!"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformAny( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr, + nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(sd::transform::SameOps op, NDArray& target, ExtraArguments *extraParams) { - - if (isS()) - throw std::runtime_error("NDArray::applyTransform SameOps: you can't use this method on String array!"); - - if (target.dataType() != dataType()) - throw std::runtime_error("NDArray::applyTransform SameOps: target array must have the same data type as original array"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this}); +void NDArray::applyTransform(sd::transform::SameOps op, NDArray& target, + ExtraArguments* extraParams) { + if (isS()) + throw std::runtime_error( + "NDArray::applyTransform SameOps: you can't use this method on String " + "array!"); + + if (target.dataType() != dataType()) + throw std::runtime_error( + "NDArray::applyTransform SameOps: target array must have the same data " + "type as original array"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformSame( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr, + nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(sd::transform::StrictOps op, NDArray& target, ExtraArguments *extraParams) { - if (isS()) - throw std::runtime_error("NDArray::applyTransform StrictOps: you can't use this method on String array!"); - - if (!this->isR() || !target.isR() || (this->dataType() != target.dataType())) - throw std::runtime_error("NDArray::applyTransform StrictOps: both Source and Target array must have same FLOAT type !"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this}); +void NDArray::applyTransform(sd::transform::StrictOps op, NDArray& target, + ExtraArguments* extraParams) { + if (isS()) + throw std::runtime_error( + "NDArray::applyTransform StrictOps: you can't use this method on " + "String array!"); + + if (!this->isR() || !target.isR() || (this->dataType() != target.dataType())) + throw std::runtime_error( + "NDArray::applyTransform StrictOps: both Source and Target array must " + "have same FLOAT type !"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformStrict( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr, + nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(sd::transform::BoolOps op, NDArray& target, ExtraArguments *extraParams) { - if (isS()) - throw std::runtime_error("NDArray::applyTransform BoolOps: you can't use this method on String array!"); - - if (!target.isB()) - throw std::runtime_error("NDArray::applyTransform BoolOps: target array must have one of BOOL types"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this}); +void NDArray::applyTransform(sd::transform::BoolOps op, NDArray& target, + ExtraArguments* extraParams) { + if (isS()) + throw std::runtime_error( + "NDArray::applyTransform BoolOps: you can't use this method on String " + "array!"); + + if (!target.isB()) + throw std::runtime_error( + "NDArray::applyTransform BoolOps: target array must have one of BOOL " + "types"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformBool( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr, + nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::FloatOps op, void *extraParams) const & { - if (isS()) - throw std::runtime_error("NDArray::transform FloatOps: you can't use this method on String array!"); +NDArray NDArray::transform(sd::transform::FloatOps op, + void* extraParams) const& { + if (isS()) + throw std::runtime_error( + "NDArray::transform FloatOps: you can't use this method on String " + "array!"); - NDArray result(ordering(), getShapeAsVector(), DataTypeUtils::pickFloatingType(dataType()), getContext()); + NDArray result(ordering(), getShapeAsVector(), + DataTypeUtils::pickFloatingType(dataType()), getContext()); - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformFloat( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, + nullptr); + NDArray::registerSpecialUse({&result}, {this}); - return result; + return result; } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::FloatOps op, void *extraParams) && { - if (isS()) - throw std::runtime_error("NDArray::transform SameOps: you can't use this method on String array!"); +NDArray NDArray::transform(sd::transform::FloatOps op, void* extraParams) && { + if (isS()) + throw std::runtime_error( + "NDArray::transform SameOps: you can't use this method on String " + "array!"); - NDArray::prepareSpecialUse({this}, {this}); - NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({this}, {this}); + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformFloat( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); - return std::move(*this); + return std::move(*this); } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::SameOps op, void *extraParams) const & { - if (isS()) - throw std::runtime_error("NDArray::transform SameOps: you can't use this method on String array!"); +NDArray NDArray::transform(sd::transform::SameOps op, + void* extraParams) const& { + if (isS()) + throw std::runtime_error( + "NDArray::transform SameOps: you can't use this method on String " + "array!"); - NDArray result(shapeInfo(), false, getContext()); + NDArray result(shapeInfo(), false, getContext()); - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformSame( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, + nullptr); + NDArray::registerSpecialUse({&result}, {this}); - return result; + return result; } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::SameOps op, void *extraParams) && { - if (isS()) - throw std::runtime_error("NDArray::transform SameOps: you can't use this method on String array!"); +NDArray NDArray::transform(sd::transform::SameOps op, void* extraParams) && { + if (isS()) + throw std::runtime_error( + "NDArray::transform SameOps: you can't use this method on String " + "array!"); - NDArray::prepareSpecialUse({this}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({this}, {this}); + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformSame( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); - return std::move(*this); + return std::move(*this); } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::StrictOps op, void *extraParams) const & { - if (!this->isR()) - throw std::runtime_error("Source array must have one of FLOAT types"); +NDArray NDArray::transform(sd::transform::StrictOps op, + void* extraParams) const& { + if (!this->isR()) + throw std::runtime_error("Source array must have one of FLOAT types"); - NDArray result(shapeInfo(), false, getContext()); + NDArray result(shapeInfo(), false, getContext()); - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformStrict( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, + nullptr); + NDArray::registerSpecialUse({&result}, {this}); - return result; + return result; } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::StrictOps op, void *extraParams) && { - if (!this->isR()) - throw std::runtime_error("Source array must have one of FLOAT types"); +NDArray NDArray::transform(sd::transform::StrictOps op, void* extraParams) && { + if (!this->isR()) + throw std::runtime_error("Source array must have one of FLOAT types"); - NDArray::prepareSpecialUse({this}, {this}); - NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({this}, {this}); + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformStrict( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); - return std::move(*this); + return std::move(*this); } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::BoolOps op, void *extraParams) const & { - if (isS()) - throw std::runtime_error("NDArray::transform BoolOps: you can't use this method on String array!"); +NDArray NDArray::transform(sd::transform::BoolOps op, + void* extraParams) const& { + if (isS()) + throw std::runtime_error( + "NDArray::transform BoolOps: you can't use this method on String " + "array!"); - NDArray result(ordering(), getShapeAsVector(), sd::DataType::BOOL, getContext()); + NDArray result(ordering(), getShapeAsVector(), sd::DataType::BOOL, + getContext()); - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformBool( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), result.buffer(), result.shapeInfo(), + result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, + nullptr); + NDArray::registerSpecialUse({&result}, {this}); - return result; + return result; } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::BoolOps op, void *extraParams) && { - if (isS()) - throw std::runtime_error("NDArray::transform BoolOps: you can't use this method on String array!"); - - NDArray::prepareSpecialUse({this}, {this}); - NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({this}, {this}); - - return std::move(*this); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyScalarArr(sd::scalar::Ops op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams) { - if (isS()) - throw std::runtime_error("NDArray::applyScalarArr: you can't use this method on String array!"); - if (scalar.lengthOf() != 1) - throw std::invalid_argument("NDArray::applyScalarArr method: operand is not a scalar!"); - - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), scalar.shapeInfo()) && !(target.dataType() == dataType() || target.dataType() == scalar.dataType())) - throw std::invalid_argument("NDArray::applyScalarArr method: wrong type of target array!"); - - NDArray::prepareSpecialUse({&target}, {this, &scalar}); - NativeOpExecutioner::execScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()): nullptr); - NDArray::registerSpecialUse({&target}, {this, &scalar}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyScalarArr(sd::scalar::BoolOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::applyScalarArr BoolOps: you can't use this method on String array!"); - if (!target.isB()) - throw std::invalid_argument("NDArray::applyScalarArr bool method: target has not bool type!"); - if (dataType() != scalar.dataType()) { - nd4j_printf("NDArray::applyScalarArr BoolOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar.dataType()); - throw std::invalid_argument("NDArray::applyScalarArr bool method: this and scalar arrays must have the same type!"); - } - - NDArray::prepareSpecialUse({&target}, {this, &scalar}); - NativeOpExecutioner::execScalarBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()): nullptr); - NDArray::registerSpecialUse({&target}, {this, &scalar}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyScalarArr(sd::scalar::IntOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::applyScalarArr IntOps: you can't use this method on String array!"); - - if (target.dataType() != this->dataType()) - throw std::invalid_argument("NDArray::applyScalarArr int method: target has not bool type!"); - if (dataType() != scalar.dataType()) { - nd4j_printf("NDArray::applyScalarArr IntOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar.dataType()); - throw std::invalid_argument("NDArray::applyScalarArr int method: this and scalar arrays must have the same type!"); - } - - NDArray::prepareSpecialUse({&target}, {this, &scalar}); - NativeOpExecutioner::execScalarInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()): nullptr); - NDArray::registerSpecialUse({&target}, {this, &scalar}); +NDArray NDArray::transform(sd::transform::BoolOps op, void* extraParams) && { + if (isS()) + throw std::runtime_error( + "NDArray::transform BoolOps: you can't use this method on String " + "array!"); + + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformBool( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); + + return std::move(*this); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyScalarArr(sd::scalar::Ops op, const NDArray& scalar, + NDArray& target, ExtraArguments* extraParams) { + if (isS()) + throw std::runtime_error( + "NDArray::applyScalarArr: you can't use this method on String array!"); + if (scalar.lengthOf() != 1) + throw std::invalid_argument( + "NDArray::applyScalarArr method: operand is not a scalar!"); + + if (target.dataType() != DataTypeUtils::pickPairwiseResultType( + shapeInfo(), scalar.shapeInfo()) && + !(target.dataType() == dataType() || + target.dataType() == scalar.dataType())) + throw std::invalid_argument( + "NDArray::applyScalarArr method: wrong type of target array!"); + + NDArray::prepareSpecialUse({&target}, {this, &scalar}); + NativeOpExecutioner::execScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), + scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr); + NDArray::registerSpecialUse({&target}, {this, &scalar}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyScalarArr(sd::scalar::BoolOps op, const NDArray& scalar, + NDArray& target, + ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyScalarArr BoolOps: you can't use this method on String " + "array!"); + if (!target.isB()) + throw std::invalid_argument( + "NDArray::applyScalarArr bool method: target has not bool type!"); + if (dataType() != scalar.dataType()) { + nd4j_printf( + "NDArray::applyScalarArr BoolOps: this dtype: [%i]; scalar dtype: " + "[%i]\n", + this->dataType(), scalar.dataType()); + throw std::invalid_argument( + "NDArray::applyScalarArr bool method: this and scalar arrays must have " + "the same type!"); + } + + NDArray::prepareSpecialUse({&target}, {this, &scalar}); + NativeOpExecutioner::execScalarBool( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), + scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr); + NDArray::registerSpecialUse({&target}, {this, &scalar}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::applyScalarArr(sd::scalar::IntOps op, const NDArray& scalar, + NDArray& target, + ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyScalarArr IntOps: you can't use this method on String " + "array!"); + + if (target.dataType() != this->dataType()) + throw std::invalid_argument( + "NDArray::applyScalarArr int method: target has not bool type!"); + if (dataType() != scalar.dataType()) { + nd4j_printf( + "NDArray::applyScalarArr IntOps: this dtype: [%i]; scalar dtype: " + "[%i]\n", + this->dataType(), scalar.dataType()); + throw std::invalid_argument( + "NDArray::applyScalarArr int method: this and scalar arrays must have " + "the same type!"); + } + + NDArray::prepareSpecialUse({&target}, {this, &scalar}); + NativeOpExecutioner::execScalarInt( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), + scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), + extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) + : nullptr); + NDArray::registerSpecialUse({&target}, {this, &scalar}); } //////////////////////////////////////////////////////////////////////// template -void NDArray::applyScalar(sd::scalar::IntOps op, const T scalar, NDArray& target, ExtraArguments *extraParams) const { - - NDArray scalarArr = NDArrayFactory::create(this->dataType(), scalar, getContext()); - applyScalarArr(op, scalarArr, target, extraParams); +void NDArray::applyScalar(sd::scalar::IntOps op, const T scalar, + NDArray& target, ExtraArguments* extraParams) const { + NDArray scalarArr = + NDArrayFactory::create(this->dataType(), scalar, getContext()); + applyScalarArr(op, scalarArr, target, extraParams); } -template <> SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const; +template <> +SD_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, + const NDArray& scalar, NDArray& target, + ExtraArguments* extraParams) const { + throw std::runtime_error( + "NDArray::applyScalar method: do not use me!"); +} +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const double scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const float scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const float16 scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const bfloat16 scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const Nd4jLong scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const int scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const int16_t scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const int8_t scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const uint8_t scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::IntOps op, const bool scalar, NDArray& target, + ExtraArguments* extraParams) const; //////////////////////////////////////////////////////////////////////// template -void NDArray::applyScalar(sd::scalar::Ops op, const T scalar, NDArray& target, ExtraArguments *extraParams) { - - auto scalarArr = NDArrayFactory::create(dataType(), scalar, this->getContext()); - applyScalarArr(op, scalarArr, target, extraParams); -} -template <> SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const double scalar, NDArray &target, ExtraArguments *extraParams); -template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const float scalar, NDArray &target, ExtraArguments *extraParams); -template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const float16 scalar, NDArray &target, ExtraArguments *extraParams); -template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams); -template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams); -template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int scalar, NDArray &target, ExtraArguments *extraParams); -template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams); -template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams); -template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams); -template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const bool scalar, NDArray &target, ExtraArguments *extraParams); +void NDArray::applyScalar(sd::scalar::Ops op, const T scalar, NDArray& target, + ExtraArguments* extraParams) { + auto scalarArr = + NDArrayFactory::create(dataType(), scalar, this->getContext()); + applyScalarArr(op, scalarArr, target, extraParams); +} +template <> +SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const NDArray& scalar, + NDArray& target, + ExtraArguments* extraParams) { + throw std::runtime_error( + "NDArray::applyScalar method: do not use me!"); +} +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const double scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const float scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const float16 scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const bfloat16 scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const Nd4jLong scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const int scalar, NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const int16_t scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const int8_t scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const uint8_t scalar, + NDArray& target, + ExtraArguments* extraParams); +template SD_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, + const bool scalar, NDArray& target, + ExtraArguments* extraParams); //////////////////////////////////////////////////////////////////////// template -void NDArray::applyScalar(sd::scalar::BoolOps op, const T scalar, NDArray &target, ExtraArguments *extraParams) const { - - NDArray scalarArr = NDArrayFactory::create(scalar, getContext()); - applyScalarArr(op, scalarArr, target, extraParams); +void NDArray::applyScalar(sd::scalar::BoolOps op, const T scalar, + NDArray& target, ExtraArguments* extraParams) const { + NDArray scalarArr = NDArrayFactory::create(scalar, getContext()); + applyScalarArr(op, scalarArr, target, extraParams); } -template <> SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const; - -//////////////////////////////////////////////////////////////////////// -void NDArray::applyIndexReduce(sd::indexreduce::Ops op, NDArray& target, const std::vector& dimensions, const ExtraArguments *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::applyIndexReduce: you can't use this method on String array!"); - - if (target.dataType() != sd::DataType::INT64 && target.dataType() != sd::DataType::INT32) - throw std::runtime_error("NDArray::applyIndexReduce operations return INT32/INT64"); - - void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(this->dataType()) : nullptr; - - NDArray::prepareSpecialUse({&target}, {this}); - - if (target.lengthOf() == 1) { - NativeOpExecutioner::execIndexReduceScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - } - else { - std::vector copy = dimensions; - shape::checkDimensions(rankOf(), copy); - auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(shapeInfo(), copy); - NativeOpExecutioner::execIndexReduce(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); - synchronize("NDArray::applyIndexReduce"); - } - - registerSpecialUse({&target}, {this}); -} +template <> +SD_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, + const NDArray& scalar, NDArray& target, + ExtraArguments* extraParams) const { + throw std::runtime_error( + "NDArray::applyScalar method: do not use me!"); +} +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const double scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const float scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const float16 scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const bfloat16 scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const Nd4jLong scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const int scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const int16_t scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const int8_t scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const uint8_t scalar, NDArray& target, + ExtraArguments* extraParams) const; +template SD_EXPORT void NDArray::applyScalar( + sd::scalar::BoolOps op, const bool scalar, NDArray& target, + ExtraArguments* extraParams) const; //////////////////////////////////////////////////////////////////////// -// reduce dimensions in this array relying on index operations -NDArray NDArray::applyIndexReduce(sd::indexreduce::Ops op, const std::vector& dimensions, const ExtraArguments* extraParams ) const { - +void NDArray::applyIndexReduce(sd::indexreduce::Ops op, NDArray& target, + const std::vector& dimensions, + const ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyIndexReduce: you can't use this method on String " + "array!"); + + if (target.dataType() != sd::DataType::INT64 && + target.dataType() != sd::DataType::INT32) + throw std::runtime_error( + "NDArray::applyIndexReduce operations return INT32/INT64"); + + void* params = extraParams != nullptr + ? const_cast(extraParams) + ->argumentsAsT(this->dataType()) + : nullptr; + + NDArray::prepareSpecialUse({&target}, {this}); + + if (target.lengthOf() == 1) { + NativeOpExecutioner::execIndexReduceScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), params, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + } else { std::vector copy = dimensions; - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, false, false, getContext()->getWorkspace()); - NDArray result(newShape, true, getContext()); - - applyIndexReduce(op, result, copy, extraParams); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -// apply reduce3 operations to this and other array, return result in new output array -NDArray NDArray::applyReduce3(sd::reduce3::Ops op, const NDArray& other, const ExtraArguments* extraParams) const { - - if (isS()) - throw std::runtime_error("NDArray::applyReduce3 method: you can't use this method on String array!"); - if(dataType() != other.dataType()) - throw std::runtime_error("NDArray::applyReduce3 method: the types of this and other arrays must be the same !"); - // check shapes consistency - if(!isSameShape(other)) - throw std::runtime_error("NDArray::applyReduce3 method: the shapes of this and other arrays must be the same !"); - // create shapeInfo for scalar - auto newShape = ShapeBuilders::createScalarShapeInfo(DataTypeUtils::pickFloatingType(dataType()), getContext()->getWorkspace()); - // create output array (scalar) - NDArray result(newShape, true, getContext()); - RELEASE(newShape, getContext()->getWorkspace()); - // create dynamic array of extra parameters if array extraParams is empty (==nullptr) - void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; - - NDArray::prepareSpecialUse({&result}, {this, &other}); - NativeOpExecutioner::execReduce3Scalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - NDArray::registerSpecialUse({&result}, {this, &other}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -// apply reduce3 (exec) operations to this and other array, return result in new output array -NDArray NDArray::applyReduce3(sd::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams) const { - - if (isS()) - throw std::runtime_error("NDArray::applyReduce3: you can't use this method on String array!"); - if(dataType() != other.dataType()) - throw std::runtime_error("NDArray::applyReduce3 method: the types of this and other arrays must be the same !"); - - std::vector copy(dimensions); shape::checkDimensions(rankOf(), copy); - shape::checkDimensions(other.rankOf(), copy); - - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), false, false, getContext()->getWorkspace()); - NDArray result(newShape, true, getContext()); - // create temporary dynamic array of extra parameters if array extraParams is empty (==nullptr) - void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; - - NDArray::prepareSpecialUse({&result}, {this, &other}); - - // perform calculations - if(rankOf() == copy.size() && other.rankOf() == copy.size()) { - NativeOpExecutioner::execReduce3Scalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - } - else { - - auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(shapeInfo(), copy); - auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions(other.shapeInfo(), copy); + auto pDims = + sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + shapeInfo(), copy); + NativeOpExecutioner::execIndexReduce( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), params, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), + packX.platformShapeInfo(), packX.platformOffsets()); + synchronize("NDArray::applyIndexReduce"); + } - if(!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo()) || (packX.numberOfTads() != packY.numberOfTads() && packX.numberOfTads() != 1 && packY.numberOfTads() != 1)) - throw std::runtime_error("NDArray::applyReduce3 cuda method: arrays tads are inconsistent !"); - - NativeOpExecutioner::execReduce3(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); - } - - registerSpecialUse({&result}, {this, &other}); - - return result; + registerSpecialUse({&target}, {this}); } //////////////////////////////////////////////////////////////////////// -// apply reduce3 (execAll) operations to this and other array, return result in new output array -NDArray NDArray::applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::applyAllReduce3: you can't use this method on String array!"); - if(dataType() != other.dataType()) - throw std::runtime_error("NDArray::applyAllReduce3 method: the types of this and other arrays must be the same !"); - - // be careful, copy array may undergo changes (sort, transformation of negative dimensions to positive, duplicates removing ) - std::vector copy(dimensions); - shape::checkDimensions(rankOf(), copy); - shape::checkDimensions(other.rankOf(), copy); - - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(shapeInfo(), copy); - auto packY = ConstantTadHelper::getInstance()->tadForDimensions(other.shapeInfo(), copy); - - // check tads shapes - if(!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo())) - throw std::runtime_error("NDArray::applyAllReduce3 method: the shapes of array tads are different !"); - - // set newShape for output array - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataTypeUtils::pickFloatingType(dataType()), 'c', {packX.numberOfTads(), packY.numberOfTads()}); - - // create output array - NDArray result(newShape, true, getContext()); - - // create dynamic array of extra parameters if array extraParams is empty (==nullptr) - void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; - - auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; - - NDArray::prepareSpecialUse({&result}, {this, &other}); - NativeOpExecutioner::execReduce3All(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); - NDArray::registerSpecialUse({&result}, {this, &other}); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -// method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: you can't use this method on String array!"); - if (!target.isR()) - throw std::invalid_argument("NDArray::reduceAlongDimension FloatOps: requires target array to be present and have type form real space!"); - - std::vector copy(dimensions); - - if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target.shapeInfo())) - throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: wrong target shape!"); - } - - NDArray::prepareSpecialUse({&target}, {this}); - - if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceFloatScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(),nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - } - else { - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(shapeInfo(), copy); - NativeOpExecutioner::execReduceFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); - } - synchronize("NDArray::reduceAlongDimension FloatOps"); - - NDArray::registerSpecialUse({&target}, {this}); -} - -////////////////////////////////////////////////////////////////////////// -// method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceAlongDimension SameOps: you can't use this method on String array!"); - if (target.dataType() != dataType()) - throw std::runtime_error("NDArray::reduceAlongDimension SameOps: requires target array to be present and have same dtype as input"); - - std::vector copy(dimensions); - - if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target.shapeInfo())) - throw std::runtime_error("NDArray::reduceAlongDimension SameOps: wrong target shape!"); - } - - NDArray::prepareSpecialUse({&target}, {this}); +// reduce dimensions in this array relying on index operations +NDArray NDArray::applyIndexReduce(sd::indexreduce::Ops op, + const std::vector& dimensions, + const ExtraArguments* extraParams) const { + std::vector copy = dimensions; + auto newShape = + ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, false, + false, getContext()->getWorkspace()); + NDArray result(newShape, true, getContext()); - if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - } - else { //if (!isEmpty()) { - auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), copy); - NativeOpExecutioner::execReduceSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); - } - synchronize("NDArray::reduceAlongDimension SameOps"); + applyIndexReduce(op, result, copy, extraParams); - NDArray::registerSpecialUse({&target}, {this}); + return result; } -////////////////////////////////////////////////////////////////////////// -// method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceAlongDimension LongOps: you can't use this method on String array!"); - if (target.dataType() != DataType::INT64) - throw std::runtime_error("NDArray::reduceAlongDimension LongOps: requires target array to be present and have type of INT64"); - - std::vector copy(dimensions); - - if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target.shapeInfo())) - throw std::runtime_error("NDArray::reduceAlongDimension LongOps: wrong target shape!"); - } - - NDArray::prepareSpecialUse({&target}, {this}); - - if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - } - else { - auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), copy); - NativeOpExecutioner::execReduceLong(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); - } - synchronize("NDArray::reduceAlongDimension LongOps"); - - NDArray::registerSpecialUse({&target}, {this}); +//////////////////////////////////////////////////////////////////////// +// apply reduce3 operations to this and other array, return result in new output +// array +NDArray NDArray::applyReduce3(sd::reduce3::Ops op, const NDArray& other, + const ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyReduce3 method: you can't use this method on String " + "array!"); + if (dataType() != other.dataType()) + throw std::runtime_error( + "NDArray::applyReduce3 method: the types of this and other arrays must " + "be the same !"); + // check shapes consistency + if (!isSameShape(other)) + throw std::runtime_error( + "NDArray::applyReduce3 method: the shapes of this and other arrays " + "must be the same !"); + // create shapeInfo for scalar + auto newShape = ShapeBuilders::createScalarShapeInfo( + DataTypeUtils::pickFloatingType(dataType()), + getContext()->getWorkspace()); + // create output array (scalar) + NDArray result(newShape, true, getContext()); + RELEASE(newShape, getContext()->getWorkspace()); + // create dynamic array of extra parameters if array extraParams is empty + // (==nullptr) + void* params = + extraParams != nullptr + ? const_cast(extraParams)->argumentsAsT(dataType()) + : nullptr; + + NDArray::prepareSpecialUse({&result}, {this, &other}); + NativeOpExecutioner::execReduce3Scalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), params, other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this, &other}); + + return result; } -////////////////////////////////////////////////////////////////////////// -// method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: you can't use this method on String array!"); - if (!target.isB()) - throw std::invalid_argument("NDArray::reduceAlongDimension BoolOps cuda: requires target array to be present and have BOOL type!"); - - std::vector copy(dimensions); - - if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target.shapeInfo())) - throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: wrong target shape!"); - } - - NDArray::prepareSpecialUse({&target}, {this}); - - if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - } - else { - auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), copy); - NativeOpExecutioner::execReduceBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); - } - synchronize("NDArray::reduceAlongDimension LongOps"); - - NDArray::registerSpecialUse({&target}, {this}); +//////////////////////////////////////////////////////////////////////// +// apply reduce3 (exec) operations to this and other array, return result in new +// output array +NDArray NDArray::applyReduce3(sd::reduce3::Ops op, const NDArray& other, + const std::vector& dimensions, + const ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyReduce3: you can't use this method on String array!"); + if (dataType() != other.dataType()) + throw std::runtime_error( + "NDArray::applyReduce3 method: the types of this and other arrays must " + "be the same !"); + + std::vector copy(dimensions); + shape::checkDimensions(rankOf(), copy); + shape::checkDimensions(other.rankOf(), copy); + + auto newShape = ShapeUtils::evalReduceShapeInfo( + 'c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), false, + false, getContext()->getWorkspace()); + NDArray result(newShape, true, getContext()); + // create temporary dynamic array of extra parameters if array extraParams is + // empty (==nullptr) + void* params = + extraParams != nullptr + ? const_cast(extraParams)->argumentsAsT(dataType()) + : nullptr; + + NDArray::prepareSpecialUse({&result}, {this, &other}); + + // perform calculations + if (rankOf() == copy.size() && other.rankOf() == copy.size()) { + NativeOpExecutioner::execReduce3Scalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), params, other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); + } else { + auto pDims = + sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + shapeInfo(), copy); + auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions( + other.shapeInfo(), copy); + + if (!shape::equalsSoft(packX.primaryShapeInfo(), + packY.primaryShapeInfo()) || + (packX.numberOfTads() != packY.numberOfTads() && + packX.numberOfTads() != 1 && packY.numberOfTads() != 1)) + throw std::runtime_error( + "NDArray::applyReduce3 cuda method: arrays tads are inconsistent !"); + + NativeOpExecutioner::execReduce3( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), params, other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), + pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), + packY.platformShapeInfo(), packY.platformOffsets()); + } + + registerSpecialUse({&result}, {this, &other}); + + return result; +} + +//////////////////////////////////////////////////////////////////////// +// apply reduce3 (execAll) operations to this and other array, return result in +// new output array +NDArray NDArray::applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, + const std::vector& dimensions, + const ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error( + "NDArray::applyAllReduce3: you can't use this method on String array!"); + if (dataType() != other.dataType()) + throw std::runtime_error( + "NDArray::applyAllReduce3 method: the types of this and other arrays " + "must be the same !"); + + // be careful, copy array may undergo changes (sort, transformation of + // negative dimensions to positive, duplicates removing ) + std::vector copy(dimensions); + shape::checkDimensions(rankOf(), copy); + shape::checkDimensions(other.rankOf(), copy); + + auto packX = + ConstantTadHelper::getInstance()->tadForDimensions(shapeInfo(), copy); + auto packY = ConstantTadHelper::getInstance()->tadForDimensions( + other.shapeInfo(), copy); + + // check tads shapes + if (!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo())) + throw std::runtime_error( + "NDArray::applyAllReduce3 method: the shapes of array tads are " + "different !"); + + // set newShape for output array + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + DataTypeUtils::pickFloatingType(dataType()), 'c', + {packX.numberOfTads(), packY.numberOfTads()}); + + // create output array + NDArray result(newShape, true, getContext()); + + // create dynamic array of extra parameters if array extraParams is empty + // (==nullptr) + void* params = + extraParams != nullptr + ? const_cast(extraParams)->argumentsAsT(dataType()) + : nullptr; + + auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; + + NDArray::prepareSpecialUse({&result}, {this, &other}); + NativeOpExecutioner::execReduce3All( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), params, other.buffer(), other.shapeInfo(), + other.specialBuffer(), other.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), + pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), + packY.platformShapeInfo(), packY.platformOffsets()); + NDArray::registerSpecialUse({&result}, {this, &other}); + + return result; +} + +////////////////////////////////////////////////////////////////////////// +// method reduces array by excluding its shapes along axes present in dimensions +// vector +void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes, + const bool checkTargetShape) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceAlongDimension FloatOps: you can't use this method on " + "String array!"); + if (!target.isR()) + throw std::invalid_argument( + "NDArray::reduceAlongDimension FloatOps: requires target array to be " + "present and have type form real space!"); + + std::vector copy(dimensions); + + if (checkTargetShape) { + auto newShape = ShapeUtils::evalReduceShapeInfo( + target.ordering(), copy, *this, keepDims, supportOldShapes, + getContext()->getWorkspace()); + if (!shape::shapeEquals(newShape, target.shapeInfo())) + throw std::runtime_error( + "NDArray::reduceAlongDimension FloatOps: wrong target shape!"); + } + + NDArray::prepareSpecialUse({&target}, {this}); + + if (rankOf() == copy.size() || copy.empty()) { + NativeOpExecutioner::execReduceFloatScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + } else { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + shapeInfo(), copy); + NativeOpExecutioner::execReduceFloat( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), copy.data(), + copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); + } + synchronize("NDArray::reduceAlongDimension FloatOps"); + + NDArray::registerSpecialUse({&target}, {this}); +} + +////////////////////////////////////////////////////////////////////////// +// method reduces array by excluding its shapes along axes present in dimensions +// vector +void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes, + const bool checkTargetShape) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceAlongDimension SameOps: you can't use this method on " + "String array!"); + if (target.dataType() != dataType()) + throw std::runtime_error( + "NDArray::reduceAlongDimension SameOps: requires target array to be " + "present and have same dtype as input"); + + std::vector copy(dimensions); + + if (checkTargetShape) { + auto newShape = ShapeUtils::evalReduceShapeInfo( + target.ordering(), copy, *this, keepDims, supportOldShapes, + getContext()->getWorkspace()); + if (!shape::shapeEquals(newShape, target.shapeInfo())) + throw std::runtime_error( + "NDArray::reduceAlongDimension SameOps: wrong target shape!"); + } + + NDArray::prepareSpecialUse({&target}, {this}); + + if (rankOf() == copy.size() || copy.empty()) { + NativeOpExecutioner::execReduceSameScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + } else { // if (!isEmpty()) { + auto pDims = + sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + this->shapeInfo(), copy); + NativeOpExecutioner::execReduceSame( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), + packX.platformShapeInfo(), packX.platformOffsets()); + } + synchronize("NDArray::reduceAlongDimension SameOps"); + + NDArray::registerSpecialUse({&target}, {this}); +} + +////////////////////////////////////////////////////////////////////////// +// method reduces array by excluding its shapes along axes present in dimensions +// vector +void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes, + const bool checkTargetShape) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceAlongDimension LongOps: you can't use this method on " + "String array!"); + if (target.dataType() != DataType::INT64) + throw std::runtime_error( + "NDArray::reduceAlongDimension LongOps: requires target array to be " + "present and have type of INT64"); + + std::vector copy(dimensions); + + if (checkTargetShape) { + auto newShape = ShapeUtils::evalReduceShapeInfo( + target.ordering(), copy, *this, keepDims, supportOldShapes, + getContext()->getWorkspace()); + if (!shape::shapeEquals(newShape, target.shapeInfo())) + throw std::runtime_error( + "NDArray::reduceAlongDimension LongOps: wrong target shape!"); + } + + NDArray::prepareSpecialUse({&target}, {this}); + + if (rankOf() == copy.size() || copy.empty()) { + NativeOpExecutioner::execReduceLongScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + } else { + auto pDims = + sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + this->shapeInfo(), copy); + NativeOpExecutioner::execReduceLong( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), + packX.platformShapeInfo(), packX.platformOffsets()); + } + synchronize("NDArray::reduceAlongDimension LongOps"); + + NDArray::registerSpecialUse({&target}, {this}); +} + +////////////////////////////////////////////////////////////////////////// +// method reduces array by excluding its shapes along axes present in dimensions +// vector +void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, + const std::vector& dimensions, + const bool keepDims, + const bool supportOldShapes, + const bool checkTargetShape) const { + if (isS()) + throw std::runtime_error( + "NDArray::reduceAlongDimension BoolOps cuda: you can't use this method " + "on String array!"); + if (!target.isB()) + throw std::invalid_argument( + "NDArray::reduceAlongDimension BoolOps cuda: requires target array to " + "be present and have BOOL type!"); + + std::vector copy(dimensions); + + if (checkTargetShape) { + auto newShape = ShapeUtils::evalReduceShapeInfo( + target.ordering(), copy, *this, keepDims, supportOldShapes, + getContext()->getWorkspace()); + if (!shape::shapeEquals(newShape, target.shapeInfo())) + throw std::runtime_error( + "NDArray::reduceAlongDimension BoolOps cuda: wrong target shape!"); + } + + NDArray::prepareSpecialUse({&target}, {this}); + + if (rankOf() == copy.size() || copy.empty()) { + NativeOpExecutioner::execReduceBoolScalar( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo()); + } else { + auto pDims = + sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + this->shapeInfo(), copy); + NativeOpExecutioner::execReduceBool( + getContext(), op, buffer(), shapeInfo(), specialBuffer(), + specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), + target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), + packX.platformShapeInfo(), packX.platformOffsets()); + } + synchronize("NDArray::reduceAlongDimension LongOps"); + + NDArray::registerSpecialUse({&target}, {this}); } ////////////////////////////////////////////////////////////////////////// // This method sets value in linear buffer to position i template void NDArray::p(const Nd4jLong i, const T value) { + if (i >= lengthOf()) + throw std::invalid_argument( + "NDArray::p(i, value): input index is out of array length !"); - if (i >= lengthOf()) - throw std::invalid_argument("NDArray::p(i, value): input index is out of array length !"); + auto rp = getOffset(i); + const void* pV = reinterpret_cast(const_cast(&value)); - auto rp = getOffset(i); - const void *pV = reinterpret_cast(const_cast(&value)); - - NDArray::preparePrimaryUse({this}, {}, true); - BUILD_SINGLE_PARTIAL_SELECTOR(this->dataType(), templatedSet<, T>(this->buffer(), rp, pV), LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {}); + NDArray::preparePrimaryUse({this}, {}, true); + BUILD_SINGLE_PARTIAL_SELECTOR(this->dataType(), + templatedSet<, T>(this->buffer(), rp, pV), + LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {}); } template SD_EXPORT void NDArray::p(const Nd4jLong i, const double value); @@ -4446,641 +5461,832 @@ template SD_EXPORT void NDArray::p(const Nd4jLong i, const bool value); // This method sets value in 2D matrix to position i, j template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const T value) { - - if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) - throw std::invalid_argument("NDArray:pe(i,j, value): one of input indexes is out of array length or rank!=2 !"); - - void *p = reinterpret_cast(const_cast(&value)); - Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(shapeInfo(), coords); - - NDArray::preparePrimaryUse({this}, {}, true); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {}); -} -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float16 value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bfloat16 value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int8_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint8_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint16_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint32_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint64_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int16_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bool value); + if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) + throw std::invalid_argument( + "NDArray:pe(i,j, value): one of input indexes is out of array length " + "or rank!=2 !"); + + void* p = reinterpret_cast(const_cast(&value)); + Nd4jLong coords[2] = {i, j}; + auto xOffset = shape::getOffset(shapeInfo(), coords); + + NDArray::preparePrimaryUse({this}, {}, true); + BUILD_SINGLE_PARTIAL_SELECTOR( + dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {}); +} +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const double value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const float value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const float16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const bfloat16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const int value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const int8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const uint8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const uint16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const uint32_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const uint64_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const int16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const bool value); ////////////////////////////////////////////////////////////////////////// // This method sets value in 3D matrix to position i,j,k template -void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T value) { - //(*this)(i,j,k) = value; - if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2]) - throw std::invalid_argument("NDArray:pe(i,j,k, value): one of input indexes is out of array length or rank!=3 !"); - - NDArray::preparePrimaryUse({this}, {}, true); - - void *p = reinterpret_cast(const_cast(&value)); - Nd4jLong coords[3] = {i, j, k}; - auto xOffset = shape::getOffset(shapeInfo(), coords); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {}); -} -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float16 value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bfloat16 value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int8_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint8_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint16_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint32_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint64_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int16_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bool value); +void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const T value) { + //(*this)(i,j,k) = value; + if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || + k >= shapeOf()[2]) + throw std::invalid_argument( + "NDArray:pe(i,j,k, value): one of input indexes is out of array length " + "or rank!=3 !"); + + NDArray::preparePrimaryUse({this}, {}, true); + + void* p = reinterpret_cast(const_cast(&value)); + Nd4jLong coords[3] = {i, j, k}; + auto xOffset = shape::getOffset(shapeInfo(), coords); + BUILD_SINGLE_PARTIAL_SELECTOR( + dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {}); +} +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const double value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const float value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const float16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const bfloat16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const int value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const int8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const uint8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const uint16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const uint32_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const uint64_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const int16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const bool value); ////////////////////////////////////////////////////////////////////////// template -void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const T value) { - //(*this)(i,j,k) = value; - if (rankOf() != 4 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2] || l >= shapeOf()[3]) - throw std::invalid_argument("NDArray::p(i,j,k,l, value): one of input indexes is out of array length or rank!=4 !"); - - void *p = reinterpret_cast(const_cast(&value)); - Nd4jLong coords[4] = {i, j, k, l}; - auto xOffset = shape::getOffset(shapeInfo(), coords); - - NDArray::preparePrimaryUse({this}, {}, true); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {}); -} -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float16 value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bfloat16 value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const Nd4jLong value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int8_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint8_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint16_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint32_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint64_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int16_t value); -template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bool value); +void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong l, const T value) { + //(*this)(i,j,k) = value; + if (rankOf() != 4 || i >= shapeOf()[0] || j >= shapeOf()[1] || + k >= shapeOf()[2] || l >= shapeOf()[3]) + throw std::invalid_argument( + "NDArray::p(i,j,k,l, value): one of input indexes is out of array " + "length or rank!=4 !"); + + void* p = reinterpret_cast(const_cast(&value)); + Nd4jLong coords[4] = {i, j, k, l}; + auto xOffset = shape::getOffset(shapeInfo(), coords); + + NDArray::preparePrimaryUse({this}, {}, true); + BUILD_SINGLE_PARTIAL_SELECTOR( + dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {}); +} +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const double value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const float value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const float16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const bfloat16 value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const Nd4jLong value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const int value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const int8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const uint8_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const uint16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const uint32_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const uint64_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const int16_t value); +template SD_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, + const Nd4jLong k, const Nd4jLong l, + const bool value); //////////////////////////////////////////////////////////////////////// void NDArray::p(const Nd4jLong i, const NDArray& scalar) { + if (scalar.lengthOf() != 1) + throw std::invalid_argument( + "NDArray::p method: input array must be scalar!"); + if (i >= _length) + throw std::invalid_argument( + "NDArray::p(i, NDArray_scalar): input index is out of array length !"); - if(scalar.lengthOf() != 1) - throw std::invalid_argument("NDArray::p method: input array must be scalar!"); - if (i >= _length) - throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !"); - - NDArray::preparePrimaryUse({this}, {&scalar}, true); - auto rp = getOffset(i); - BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, (buffer(), rp, scalar.dataType(), scalar.buffer()), LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {&scalar}); + NDArray::preparePrimaryUse({this}, {&scalar}, true); + auto rp = getOffset(i); + BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, + (buffer(), rp, scalar.dataType(), scalar.buffer()), + LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {&scalar}); } //////////////////////////////////////////////////////////////////////// - void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const NDArray& scalar) { - - if(scalar.lengthOf() != 1) - throw std::invalid_argument("NDArray::p method: input array must be scalar!"); - if (i >= _length) - throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !"); - -// void *p = reinterpret_cast(scalar.buffer()); - Nd4jLong coords[4] = {i, j, k, l}; - auto xOffset = shape::getOffset(shapeInfo(), coords); - - NDArray::preparePrimaryUse({this}, {&scalar}, true); -// BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); - BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, (this->buffer(), xOffset, scalar.dataType(), scalar.buffer()), LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {&scalar}); - } +void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, + const Nd4jLong l, const NDArray& scalar) { + if (scalar.lengthOf() != 1) + throw std::invalid_argument( + "NDArray::p method: input array must be scalar!"); + if (i >= _length) + throw std::invalid_argument( + "NDArray::p(i, NDArray_scalar): input index is out of array length !"); + + // void *p = reinterpret_cast(scalar.buffer()); + Nd4jLong coords[4] = {i, j, k, l}; + auto xOffset = shape::getOffset(shapeInfo(), coords); + + NDArray::preparePrimaryUse({this}, {&scalar}, true); + // BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, + // T>(this->buffer(), xOffset, p), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR( + scalar.dataType(), templatedSet, + (this->buffer(), xOffset, scalar.dataType(), scalar.buffer()), + LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {&scalar}); +} ////////////////////////////////////////////////////////////////////////// void NDArray::addRowVector(const NDArray& row, NDArray& target) const { - - if (isS()) - throw std::runtime_error("NDArray::addRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.lengthOf()) - throw std::invalid_argument("NDArray::addRowVector: wrong arguments !"); - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()) && !(isR() && row.isR() && target.isR())) - throw std::invalid_argument("NDArray::addRowVector: wrong type of target array !"); - - int dimension = 1; - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension); - - NDArray::prepareSpecialUse({&target}, {this, &row}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this, &row}); + if (isS()) + throw std::runtime_error( + "NDArray::addRowVector: you can't use this method on String array!"); + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || + columns() != target.columns() || !row.isRowVector() || + columns() != row.lengthOf()) + throw std::invalid_argument("NDArray::addRowVector: wrong arguments !"); + if (target.dataType() != + DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()) && + !(isR() && row.isR() && target.isR())) + throw std::invalid_argument( + "NDArray::addRowVector: wrong type of target array !"); + + int dimension = 1; + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + this->shapeInfo(), dimension); + + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), + row.specialBuffer(), row.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, + nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); } ////////////////////////////////////////////////////////////////////////// void NDArray::subRowVector(const NDArray& row, NDArray& target) const { - - if (isS()) - throw std::runtime_error("NDArray::addRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.lengthOf()) - throw std::invalid_argument("NDArray::addRowVector: wrong arguments !"); - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()) && !(isR() && row.isR() && target.isR())) - throw std::invalid_argument("NDArray::addRowVector: wrong type of target array !"); - - int dimension = 1; - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension); - - NDArray::prepareSpecialUse({&target}, {this, &row}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), &dimension, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this, &row}); -} - - -////////////////////////////////////////////////////////////////////////// -void NDArray::mulRowVector(const NDArray &row, NDArray &target) const { - - if (isS()) - throw std::runtime_error("NDArray::mulRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.columns()) - throw std::invalid_argument("NDArray::divRowVector: wrong arguments !"); - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType())) - throw std::invalid_argument("NDArray::mulRowVector: wrong type of target array !"); - - int dimension = 1; - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension); - - NDArray::prepareSpecialUse({&target}, {this, &row}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this, &row}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::divRowVector(const NDArray &row, NDArray &target) const { - - if (isS()) - throw std::runtime_error("NDArray::divRowVector: you can't use this method on String array!"); - if (row.isB()) - throw std::runtime_error("NDArray::divRowVector: you can't divide by bool row!"); - if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.columns()) - throw std::invalid_argument("NDArray::divRowVector: wrong arguments !"); - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType())) - throw std::invalid_argument("NDArray::divRowVector: wrong type of target array !"); - - int dimension = 1; - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension); - - NDArray::prepareSpecialUse({&target}, {this, &row}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this, &row}); -} - -////////////////////////////////////////////////////////////////////////// -// This method adds given row to all rows in this NDArray, this array becomes affected + if (isS()) + throw std::runtime_error( + "NDArray::addRowVector: you can't use this method on String array!"); + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || + columns() != target.columns() || !row.isRowVector() || + columns() != row.lengthOf()) + throw std::invalid_argument("NDArray::addRowVector: wrong arguments !"); + if (target.dataType() != + DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()) && + !(isR() && row.isR() && target.isR())) + throw std::invalid_argument( + "NDArray::addRowVector: wrong type of target array !"); + + int dimension = 1; + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + this->shapeInfo(), dimension); + + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Ops::Subtract, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), + row.specialBuffer(), row.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + &dimension, 1, packX.platformShapeInfo(), packX.platformOffsets(), + nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::mulRowVector(const NDArray& row, NDArray& target) const { + if (isS()) + throw std::runtime_error( + "NDArray::mulRowVector: you can't use this method on String array!"); + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || + columns() != target.columns() || !row.isRowVector() || + columns() != row.columns()) + throw std::invalid_argument("NDArray::divRowVector: wrong arguments !"); + if (target.dataType() != + DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType())) + throw std::invalid_argument( + "NDArray::mulRowVector: wrong type of target array !"); + + int dimension = 1; + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + this->shapeInfo(), dimension); + + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Ops::Multiply, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), + row.specialBuffer(), row.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, + nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::divRowVector(const NDArray& row, NDArray& target) const { + if (isS()) + throw std::runtime_error( + "NDArray::divRowVector: you can't use this method on String array!"); + if (row.isB()) + throw std::runtime_error( + "NDArray::divRowVector: you can't divide by bool row!"); + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || + columns() != target.columns() || !row.isRowVector() || + columns() != row.columns()) + throw std::invalid_argument("NDArray::divRowVector: wrong arguments !"); + if (target.dataType() != + DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType())) + throw std::invalid_argument( + "NDArray::divRowVector: wrong type of target array !"); + + int dimension = 1; + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + this->shapeInfo(), dimension); + + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Divide, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), + row.specialBuffer(), row.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, + nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); +} + +////////////////////////////////////////////////////////////////////////// +// This method adds given row to all rows in this NDArray, this array becomes +// affected void NDArray::addiRowVector(const NDArray& row) { - - if (isS()) - throw std::runtime_error("NDArray::addiRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || !row.isRowVector() || columns() != row.lengthOf()) - throw std::invalid_argument("NDArray::addiRowVector: wrong arguments !"); - - int dimension = 1; - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension); - - NDArray::prepareSpecialUse({this}, {&row}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({this}, {&row}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::addColumnVector(const NDArray &column, NDArray &target) const { - if (isS()) - throw std::runtime_error("NDArray::addColumnVector: you can't use this method on String array!"); - if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !column.isColumnVector() || rows() != column.lengthOf()) - throw std::invalid_argument("NDArray::addColumnVector: wrong arguments !"); - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), column.dataType())) - throw std::invalid_argument("NDArray::addColumnVector: wrong type of target array !"); - - int dimension = 0; - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension); - - NDArray::prepareSpecialUse({&target}, {this, &column}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), column.specialBuffer(), column.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this, &column}); -} - -////////////////////////////////////////////////////////////////////////// -// This method adds given column to all columns in this NDArray, this array becomes affected -void NDArray::addiColumnVector(const NDArray &column) { - if (isS()) - throw std::runtime_error("NDArray::addiColumnVector: you can't use this method on String array!"); - if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) - throw std::invalid_argument("NDArray::addiColumnVector: wrong arguments !"); - - int dimension = 0; - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension); - - NDArray::prepareSpecialUse({this}, {&column}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), column.specialBuffer(), column.specialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({this}, {&column}); -} - -////////////////////////////////////////////////////////////////////////// -// This method multiplies each column of this array by given argument-column, this array becomes affected + if (isS()) + throw std::runtime_error( + "NDArray::addiRowVector: you can't use this method on String array!"); + if (rankOf() != 2 || !row.isRowVector() || columns() != row.lengthOf()) + throw std::invalid_argument("NDArray::addiRowVector: wrong arguments !"); + + int dimension = 1; + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + this->shapeInfo(), dimension); + + NDArray::prepareSpecialUse({this}, {&row}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), + row.specialBuffer(), row.specialShapeInfo(), this->buffer(), + this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), + nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, + nullptr); + NDArray::registerSpecialUse({this}, {&row}); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::addColumnVector(const NDArray& column, NDArray& target) const { + if (isS()) + throw std::runtime_error( + "NDArray::addColumnVector: you can't use this method on String array!"); + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || + columns() != target.columns() || !column.isColumnVector() || + rows() != column.lengthOf()) + throw std::invalid_argument("NDArray::addColumnVector: wrong arguments !"); + if (target.dataType() != + DataTypeUtils::pickPairwiseResultType(dataType(), column.dataType())) + throw std::invalid_argument( + "NDArray::addColumnVector: wrong type of target array !"); + + int dimension = 0; + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + this->shapeInfo(), dimension); + + NDArray::prepareSpecialUse({&target}, {this, &column}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), + column.specialBuffer(), column.specialShapeInfo(), target.buffer(), + target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), + nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, + nullptr); + NDArray::registerSpecialUse({&target}, {this, &column}); +} + +////////////////////////////////////////////////////////////////////////// +// This method adds given column to all columns in this NDArray, this array +// becomes affected +void NDArray::addiColumnVector(const NDArray& column) { + if (isS()) + throw std::runtime_error( + "NDArray::addiColumnVector: you can't use this method on String " + "array!"); + if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) + throw std::invalid_argument("NDArray::addiColumnVector: wrong arguments !"); + + int dimension = 0; + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + this->shapeInfo(), dimension); + + NDArray::prepareSpecialUse({this}, {&column}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), + column.specialBuffer(), column.specialShapeInfo(), this->buffer(), + this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), + nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, + nullptr); + NDArray::registerSpecialUse({this}, {&column}); +} + +////////////////////////////////////////////////////////////////////////// +// This method multiplies each column of this array by given argument-column, +// this array becomes affected void NDArray::muliColumnVector(const NDArray& column) { - if (isS()) - throw std::runtime_error("NDArray::muliColumnVector: you can't use this method on String array!"); - if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) - throw std::invalid_argument("NDArray::muliColumnVector: wrong arguments !"); + if (isS()) + throw std::runtime_error( + "NDArray::muliColumnVector: you can't use this method on String " + "array!"); + if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) + throw std::invalid_argument("NDArray::muliColumnVector: wrong arguments !"); - int dimension = 0; + int dimension = 0; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->shapeInfo(), dimension); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + this->shapeInfo(), dimension); - NDArray::prepareSpecialUse({this}, {&column}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), column.specialBuffer(), column.specialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({this}, {&column}); + NDArray::prepareSpecialUse({this}, {&column}); + NativeOpExecutioner::execBroadcast( + getContext(), sd::broadcast::Ops::Multiply, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), + column.specialBuffer(), column.specialShapeInfo(), this->buffer(), + this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), + nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, + nullptr); + NDArray::registerSpecialUse({this}, {&column}); } ////////////////////////////////////////////////////////////////////////// template -void NDArray::templatedAssign(void *xBuffer, Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const { - if (xBuffer != nullptr && yBuffer != nullptr) - *(reinterpret_cast(xBuffer) + xOffset) = *(reinterpret_cast(yBuffer) + yOffset); +void NDArray::templatedAssign(void* xBuffer, Nd4jLong xOffset, + const void* yBuffer, + const Nd4jLong yOffset) const { + if (xBuffer != nullptr && yBuffer != nullptr) + *(reinterpret_cast(xBuffer) + xOffset) = + *(reinterpret_cast(yBuffer) + yOffset); } -BUILD_SINGLE_TEMPLATE(template SD_EXPORT void NDArray::templatedAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES); - +BUILD_SINGLE_TEMPLATE(template SD_EXPORT void NDArray::templatedAssign, + (void* xBuffer, const Nd4jLong xOffset, + const void* yBuffer, const Nd4jLong yOffset) const, + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// bool NDArray::permutei(const int* dimensions, const int rank) { + auto shapeInfo = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, + getContext()->getWorkspace()); + setShapeInfo(shapeInfo); - auto shapeInfo = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace()); - setShapeInfo(shapeInfo); - - return true; + return true; } ////////////////////////////////////////////////////////////////////////// bool NDArray::permutei(const Nd4jLong* dimensions, const int rank) { + auto shapeInfo = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, + getContext()->getWorkspace()); + setShapeInfo(shapeInfo); - auto shapeInfo = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace()); - setShapeInfo(shapeInfo); - - return true; + return true; } //////////////////////////////////////////////////////////////////////// -ResultSet NDArray::multipleTensorsAlongDimension(const std::vector &indices, const std::vector &dimensions) const { - ResultSet result; - - if (indices.size() == 0) - return result; +ResultSet NDArray::multipleTensorsAlongDimension( + const std::vector& indices, const std::vector& dimensions) const { + ResultSet result; - auto pack = ConstantTadHelper::getInstance()->tadForDimensions(shapeInfo(), const_cast(dimensions.data()), dimensions.size()); + if (indices.size() == 0) return result; - auto tadLength = shape::length(pack.primaryShapeInfo()); - auto numTads = lengthOf() / tadLength; + auto pack = ConstantTadHelper::getInstance()->tadForDimensions( + shapeInfo(), const_cast(dimensions.data()), dimensions.size()); - for (auto idx: indices) { - if (idx >= numTads) { - nd4j_printf("NDArray::multipleTensorsAlongDimension: index %i is higher then number of TADs: %i\n", idx, numTads); - throw std::runtime_error("Bad index"); - } + auto tadLength = shape::length(pack.primaryShapeInfo()); + auto numTads = lengthOf() / tadLength; - NDArray array(getDataBuffer(), ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + bufferOffset()); - result.push_back(array); + for (auto idx : indices) { + if (idx >= numTads) { + nd4j_printf( + "NDArray::multipleTensorsAlongDimension: index %i is higher then " + "number of TADs: %i\n", + idx, numTads); + throw std::runtime_error("Bad index"); } - return result; + NDArray array(getDataBuffer(), ShapeDescriptor(pack.primaryShapeInfo()), + getContext(), pack.primaryOffsets()[idx] + bufferOffset()); + result.push_back(array); + } + + return result; } //////////////////////////////////////////////////////////////////////// -ResultSet NDArray::allTensorsAlongDimension(const std::initializer_list& dimensions) const { - return allTensorsAlongDimension(std::vector(dimensions)); +ResultSet NDArray::allTensorsAlongDimension( + const std::initializer_list& dimensions) const { + return allTensorsAlongDimension(std::vector(dimensions)); } //////////////////////////////////////////////////////////////////////// ResultSet NDArray::allExamples() const { - std::vector dimensions(rankOf() - 1); - for (int e = 1; e < rankOf(); e++) - dimensions[e-1] = e; + std::vector dimensions(rankOf() - 1); + for (int e = 1; e < rankOf(); e++) dimensions[e - 1] = e; - return allTensorsAlongDimension(dimensions); + return allTensorsAlongDimension(dimensions); } //////////////////////////////////////////////////////////////////////// Nd4jLong NDArray::getOffset(const Nd4jLong i) const { + if (i >= lengthOf()) + throw std::invalid_argument( + "NDArray::getOffset: input index is out of array length !"); - if (i >= lengthOf()) - throw std::invalid_argument("NDArray::getOffset: input index is out of array length !"); - - return shape::getIndexOffset(i, _shapeInfo); + return shape::getIndexOffset(i, _shapeInfo); } //////////////////////////////////////////////////////////////////////// NDArray NDArray::like() { - - return NDArray(shapeInfo(), this->dataType(), false, getContext()); + return NDArray(shapeInfo(), this->dataType(), false, getContext()); } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::ulike() const{ - - return NDArray(this, false, getContext()); -} +NDArray NDArray::ulike() const { return NDArray(this, false, getContext()); } //////////////////////////////////////////////////////////////////////// NDArray NDArray::diagonal(const char type) const { + if (isS()) + throw std::runtime_error( + "NDArray::diagonal: you can't use this method on String array!"); + + const char order = ordering(); + const int rank = rankOf(); + Nd4jLong* outShapeInfo; + ALLOCATE(outShapeInfo, getContext()->getWorkspace(), 8, Nd4jLong); + outShapeInfo[0] = 2; + outShapeInfo[5] = 0; + + if (isVector() || isScalar()) { + outShapeInfo[1] = outShapeInfo[2] = outShapeInfo[3] = outShapeInfo[4] = 1; + outShapeInfo[6] = 1; + outShapeInfo[7] = (int)order; + } else { + int diagSize = 100000000; + Nd4jLong indices[MAX_RANK]; - if (isS()) - throw std::runtime_error("NDArray::diagonal: you can't use this method on String array!"); - - const char order = ordering(); - const int rank = rankOf(); - Nd4jLong *outShapeInfo; - ALLOCATE(outShapeInfo, getContext()->getWorkspace(), 8, Nd4jLong); - outShapeInfo[0] = 2; - outShapeInfo[5] = 0; - - if(isVector() || isScalar()) { - - outShapeInfo[1] = outShapeInfo[2] = outShapeInfo[3] = outShapeInfo[4] = 1; - outShapeInfo[6] = 1; - outShapeInfo[7] = (int)order; + for (int i = 0; i < rank; ++i) { + if (diagSize > shapeOf()[i]) diagSize = shapeOf()[i]; + indices[i] = 1; } - else { - - int diagSize = 100000000; - Nd4jLong indices[MAX_RANK]; - - for(int i = 0; i < rank; ++i) { - if(diagSize > shapeOf()[i]) - diagSize = shapeOf()[i]; - indices[i] = 1; - } - auto step = shape::getOffset(shapeInfo(), indices); + auto step = shape::getOffset(shapeInfo(), indices); - if(type == 'c') { - outShapeInfo[1] = diagSize; - outShapeInfo[2] = 1; - } - else { - outShapeInfo[1] = 1; - outShapeInfo[2] = diagSize; - } - shape::updateStrides(outShapeInfo, order); - - outShapeInfo[3] *= step; - outShapeInfo[4] *= step; - outShapeInfo[6] = 0; + if (type == 'c') { + outShapeInfo[1] = diagSize; + outShapeInfo[2] = 1; + } else { + outShapeInfo[1] = 1; + outShapeInfo[2] = diagSize; } + shape::updateStrides(outShapeInfo, order); - ArrayOptions::setDataType(outShapeInfo, this->dataType()); + outShapeInfo[3] *= step; + outShapeInfo[4] *= step; + outShapeInfo[6] = 0; + } - NDArray result(_buffer, ShapeDescriptor(outShapeInfo), getContext(), bufferOffset()); + ArrayOptions::setDataType(outShapeInfo, this->dataType()); - RELEASE(outShapeInfo, getContext()->getWorkspace()); + NDArray result(_buffer, ShapeDescriptor(outShapeInfo), getContext(), + bufferOffset()); - return result; + RELEASE(outShapeInfo, getContext()->getWorkspace()); + + return result; } //////////////////////////////////////////////////////////////////////// -ResultSet NDArray::allTensorsAlongDimension(const std::vector &dimensions) const { - - ResultSet result; - - if(dimensions.size() == 0) - return result; +ResultSet NDArray::allTensorsAlongDimension( + const std::vector& dimensions) const { + ResultSet result; - if(dimensions.back() >= rankOf()) - throw std::runtime_error("NDArray::allTensorsAlongDimension static function: all input dimensions must be smaller than rank of input array !"); + if (dimensions.size() == 0) return result; + if (dimensions.back() >= rankOf()) + throw std::runtime_error( + "NDArray::allTensorsAlongDimension static function: all input " + "dimensions must be smaller than rank of input array !"); - auto pack = ConstantTadHelper::getInstance()->tadForDimensions(_shapeInfo, const_cast(dimensions.data()), dimensions.size()); - auto numTads = pack.numberOfTads(); + auto pack = ConstantTadHelper::getInstance()->tadForDimensions( + _shapeInfo, const_cast(dimensions.data()), dimensions.size()); + auto numTads = pack.numberOfTads(); - for (Nd4jLong idx = 0; idx < numTads; idx++ ) { - NDArray array(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + bufferOffset()); - array._isView = true; - result.push_back(array); - } + for (Nd4jLong idx = 0; idx < numTads; idx++) { + NDArray array(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), + getContext(), pack.primaryOffsets()[idx] + bufferOffset()); + array._isView = true; + result.push_back(array); + } - return result; + return result; } //////////////////////////////////////////////////////////////////////// -// operator returns sub-array with buffer pointing at this->_buffer + certain offset -NDArray NDArray::operator()(const std::vector& idx, const bool keepUnitiesInShape, const bool isStrided) const { - - if(isEmpty()) - throw std::invalid_argument("NDArray::operator(sub-arrays): array is empty !"); - - // Nd4jLong *outShapeInfo = nullptr; - // ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo), Nd4jLong); +// operator returns sub-array with buffer pointing at this->_buffer + certain +// offset +NDArray NDArray::operator()(const std::vector& idx, + const bool keepUnitiesInShape, + const bool isStrided) const { + if (isEmpty()) + throw std::invalid_argument( + "NDArray::operator(sub-arrays): array is empty !"); - int numOfUntiesInSubArrShape = 0; + // Nd4jLong *outShapeInfo = nullptr; + // ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo), + // Nd4jLong); - Nd4jLong* subArrShapeInfo = nullptr; + int numOfUntiesInSubArrShape = 0; - if(!keepUnitiesInShape) { + Nd4jLong* subArrShapeInfo = nullptr; - int n(isStrided ? 3 : 2), first, last; + if (!keepUnitiesInShape) { + int n(isStrided ? 3 : 2), first, last; - // calculate the number of unities in shape - for (uint d = 0; d < rankOf(); ++d) { - - if (idx[n * d] != idx[n * d + 1]) { - - first = idx[n * d] >= 0 ? idx[n * d] : idx[n * d] + sizeAt(d) + 1; - last = idx[n * d + 1] >= 0 ? idx[n * d + 1] : idx[n * d + 1] + sizeAt(d) + 1; - if(last - first == 1) - ++numOfUntiesInSubArrShape; - } - } + // calculate the number of unities in shape + for (uint d = 0; d < rankOf(); ++d) { + if (idx[n * d] != idx[n * d + 1]) { + first = idx[n * d] >= 0 ? idx[n * d] : idx[n * d] + sizeAt(d) + 1; + last = idx[n * d + 1] >= 0 ? idx[n * d + 1] + : idx[n * d + 1] + sizeAt(d) + 1; + if (last - first == 1) ++numOfUntiesInSubArrShape; + } } + } - ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), shape::shapeInfoLength(rankOf() - numOfUntiesInSubArrShape), Nd4jLong); + ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), + shape::shapeInfoLength(rankOf() - numOfUntiesInSubArrShape), + Nd4jLong); - Nd4jLong offset; + Nd4jLong offset; - shape::calcSubArrShapeInfoAndOffset(idx.data(), shapeInfo(), subArrShapeInfo, offset, keepUnitiesInShape, isStrided, numOfUntiesInSubArrShape); + shape::calcSubArrShapeInfoAndOffset(idx.data(), shapeInfo(), subArrShapeInfo, + offset, keepUnitiesInShape, isStrided, + numOfUntiesInSubArrShape); - NDArray result(_buffer, ShapeDescriptor(subArrShapeInfo), getContext(), offset + bufferOffset()); - result._isView = true; + NDArray result(_buffer, ShapeDescriptor(subArrShapeInfo), getContext(), + offset + bufferOffset()); + result._isView = true; - RELEASE(subArrShapeInfo, getContext()->getWorkspace()); + RELEASE(subArrShapeInfo, getContext()->getWorkspace()); - return result; + return result; } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::operator()(const Nd4jLong subArrIdx, const std::vector& dimsToExclude, bool keepUnitiesInShape) const { +NDArray NDArray::operator()(const Nd4jLong subArrIdx, + const std::vector& dimsToExclude, + bool keepUnitiesInShape) const { + std::vector idxRanges(2 * rankOf()); - std::vector idxRanges(2 * rankOf()); + const auto rank = rankOf(); + const auto subArrRank = static_cast(dimsToExclude.size()); - const auto rank = rankOf(); - const auto subArrRank = static_cast(dimsToExclude.size()); + if (subArrRank > rank) + throw std::invalid_argument( + "NDArray::operator(const Nd4jLong subArrIdx, const std::vector& " + "dimsToExclude, bool keepUnitiesInShape): static method: dimsToExclude " + "is empty or has size > rank of array !"); - if(subArrRank > rank) - throw std::invalid_argument("NDArray::operator(const Nd4jLong subArrIdx, const std::vector& dimsToExclude, bool keepUnitiesInShape): static method: dimsToExclude is empty or has size > rank of array !"); + memset(idxRanges.data(), 0, 2 * rank * sizeof(Nd4jLong)); - memset(idxRanges.data(), 0, 2 * rank * sizeof(Nd4jLong)); + // subArrRank == 0 means whole array, idxRanges should contain zeros only - // subArrRank == 0 means whole array, idxRanges should contain zeros only + if (subArrRank != 0) { + std::vector shapeOfSubArr(subArrRank), indexes(subArrRank); + for (int i = 0; i < subArrRank; ++i) + shapeOfSubArr[i] = sizeAt(dimsToExclude[i]); - if(subArrRank != 0) { + shape::index2coords(subArrIdx, subArrRank, shapeOfSubArr.data(), + indexes.data()); - std::vector shapeOfSubArr(subArrRank), indexes(subArrRank); - for(int i = 0; i < subArrRank; ++i) - shapeOfSubArr[i] = sizeAt(dimsToExclude[i]); - - shape::index2coords(subArrIdx, subArrRank, shapeOfSubArr.data(), indexes.data()); - - for(int i = 0; i < subArrRank; ++i) { - int currIdx = 2 * dimsToExclude[i]; - idxRanges[currIdx] = indexes[i]; - idxRanges[currIdx + 1] = indexes[i] + 1; - } + for (int i = 0; i < subArrRank; ++i) { + int currIdx = 2 * dimsToExclude[i]; + idxRanges[currIdx] = indexes[i]; + idxRanges[currIdx + 1] = indexes[i] + 1; } + } - return (*this)(idxRanges, keepUnitiesInShape); + return (*this)(idxRanges, keepUnitiesInShape); } //////////////////////////////////////////////////////////////////////// -void NDArray::getSubArrShapeAndOffsets(const std::vector& dimsToExclude, Nd4jLong* &subArrShapeInfo, Nd4jLong* &subArrOffsets, bool keepUnitiesInShape) const { - - if(isEmpty()) - throw std::invalid_argument("NDArray::getSubArrShapeAndOffsets: array is empty !"); - - const int rank = rankOf(); - const int subArrRank = (rank == dimsToExclude.size() || keepUnitiesInShape) ? rank : rank - dimsToExclude.size(); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(_shapeInfo, dimsToExclude); - - // allocate memory - ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), shape::shapeInfoLength(subArrRank), Nd4jLong); - ALLOCATE(subArrOffsets, getContext()->getWorkspace(), numOfSubArrs, Nd4jLong); - - shape::calcSubArrsShapeInfoAndOffsets(_shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo, subArrOffsets, keepUnitiesInShape); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::setShapeInfo(const Nd4jLong *shapeInfo) { - - if (shapeInfo != nullptr) { - - ShapeDescriptor descriptor(shapeInfo); - auto shapeBuffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor); - - _shapeInfo = reinterpret_cast(shapeBuffer.primary()); - #ifdef __CUDABLAS__ - _shapeInfoD = reinterpret_cast(shapeBuffer.special()); - #endif +void NDArray::getSubArrShapeAndOffsets(const std::vector& dimsToExclude, + Nd4jLong*& subArrShapeInfo, + Nd4jLong*& subArrOffsets, + bool keepUnitiesInShape) const { + if (isEmpty()) + throw std::invalid_argument( + "NDArray::getSubArrShapeAndOffsets: array is empty !"); + + const int rank = rankOf(); + const int subArrRank = (rank == dimsToExclude.size() || keepUnitiesInShape) + ? rank + : rank - dimsToExclude.size(); + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(_shapeInfo, dimsToExclude); + + // allocate memory + ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), + shape::shapeInfoLength(subArrRank), Nd4jLong); + ALLOCATE(subArrOffsets, getContext()->getWorkspace(), numOfSubArrs, Nd4jLong); + + shape::calcSubArrsShapeInfoAndOffsets( + _shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), + subArrShapeInfo, subArrOffsets, keepUnitiesInShape); +} + +////////////////////////////////////////////////////////////////////////// +void NDArray::setShapeInfo(const Nd4jLong* shapeInfo) { + if (shapeInfo != nullptr) { + ShapeDescriptor descriptor(shapeInfo); + auto shapeBuffer = + ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor); + + _shapeInfo = reinterpret_cast(shapeBuffer.primary()); +#ifdef __CUDABLAS__ + _shapeInfoD = reinterpret_cast(shapeBuffer.special()); +#endif - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); + if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) + _length = 0; + else + _length = shape::length(_shapeInfo); - _dataType = ArrayOptions::dataType(_shapeInfo); - } - else { - _dataType = sd::DataType::INHERIT; - _shapeInfoD = _shapeInfo = nullptr; - } + _dataType = ArrayOptions::dataType(_shapeInfo); + } else { + _dataType = sd::DataType::INHERIT; + _shapeInfoD = _shapeInfo = nullptr; + } } //////////////////////////////////////////////////////////////////////// -void NDArray::setShapeInfo(const Nd4jLong *shapeInfo, const sd::DataType dtype) { - - if (shapeInfo != nullptr) { - - Nd4jLong* shapeInfoTemp = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dtype, true, getContext()->getWorkspace()); - ShapeDescriptor descriptor(shapeInfoTemp); - auto shapeBuffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor); - - _shapeInfo = reinterpret_cast(shapeBuffer.primary()); - #ifdef __CUDABLAS__ - _shapeInfoD = reinterpret_cast(shapeBuffer.special()); - #endif +void NDArray::setShapeInfo(const Nd4jLong* shapeInfo, + const sd::DataType dtype) { + if (shapeInfo != nullptr) { + Nd4jLong* shapeInfoTemp = ShapeBuilders::copyShapeInfoAndType( + shapeInfo, dtype, true, getContext()->getWorkspace()); + ShapeDescriptor descriptor(shapeInfoTemp); + auto shapeBuffer = + ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor); + + _shapeInfo = reinterpret_cast(shapeBuffer.primary()); +#ifdef __CUDABLAS__ + _shapeInfoD = reinterpret_cast(shapeBuffer.special()); +#endif - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); + if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) + _length = 0; + else + _length = shape::length(_shapeInfo); - _dataType = dtype; - } - else { - _dataType = sd::DataType::INHERIT; - _shapeInfoD = _shapeInfo = nullptr; - } + _dataType = dtype; + } else { + _dataType = sd::DataType::INHERIT; + _shapeInfoD = _shapeInfo = nullptr; + } } ////////////////////////////////////////////////////////////////////////// void NDArray::setShapeInfo(const ShapeDescriptor& descriptor) { + auto shapeBuffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo( + const_cast(descriptor)); - auto shapeBuffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(const_cast(descriptor)); - - _shapeInfo = reinterpret_cast(shapeBuffer.primary()); - #ifdef __CUDABLAS__ - _shapeInfoD = reinterpret_cast(shapeBuffer.special()); - #endif + _shapeInfo = reinterpret_cast(shapeBuffer.primary()); +#ifdef __CUDABLAS__ + _shapeInfoD = reinterpret_cast(shapeBuffer.special()); +#endif - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); + if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) + _length = 0; + else + _length = shape::length(_shapeInfo); - _dataType = ArrayOptions::dataType(_shapeInfo); + _dataType = ArrayOptions::dataType(_shapeInfo); } ////////////////////////////////////////////////////////////////////////// void NDArray::setShapeInfo(const ConstantDataBuffer& shapeBuffer) { + _shapeInfo = reinterpret_cast( + const_cast(shapeBuffer).primary()); +#ifdef __CUDABLAS__ + _shapeInfoD = reinterpret_cast( + const_cast(shapeBuffer).special()); +#endif - _shapeInfo = reinterpret_cast(const_cast(shapeBuffer).primary()); - #ifdef __CUDABLAS__ - _shapeInfoD = reinterpret_cast(const_cast(shapeBuffer).special()); - #endif - - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); + if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) + _length = 0; + else + _length = shape::length(_shapeInfo); - _dataType = ArrayOptions::dataType(_shapeInfo); + _dataType = ArrayOptions::dataType(_shapeInfo); } /////////////////////////////////////////////////////////////////////// // addition operator array + scalar template NDArray operator+(NDArray&& arr, const T& scalar) { - - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(arr + scalar); // arr is lvalue inside function body - - if (arr.isS()) - throw std::runtime_error("operator+(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) - throw std::runtime_error("operator+(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - - NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Add, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); - - return std::move(arr); + if (arr.isView()) // do not use resources of arrays which use buffers of + // other original arrays + return std::move(arr + scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error( + "operator+(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error( + "operator+(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Add, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); } template SD_EXPORT NDArray operator+(NDArray&& arr, const double& scalar); template SD_EXPORT NDArray operator+(NDArray&& arr, const float& scalar); @@ -5091,29 +6297,39 @@ template SD_EXPORT NDArray operator+(NDArray&& arr, const int& scalar); //////////////////////////////////////////////////////////////////////// template NDArray operator+(const NDArray& arr, const T& scalar) { - - if (arr.isS()) - throw std::runtime_error("operator+(const NDArray& arr, const T& scalar): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Add, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; + if (arr.isS()) + throw std::runtime_error( + "operator+(const NDArray& arr, const T& scalar): you can't use this " + "method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT()), + false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Add, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; } template SD_EXPORT NDArray operator+(const NDArray& arr, const double& scalar); template SD_EXPORT NDArray operator+(const NDArray& arr, const float& scalar); template SD_EXPORT NDArray operator+(const NDArray& arr, const float16& scalar); -template SD_EXPORT NDArray operator+(const NDArray& arr, const bfloat16& scalar); +template SD_EXPORT NDArray operator+(const NDArray& arr, + const bfloat16& scalar); template SD_EXPORT NDArray operator+(const NDArray& arr, const int& scalar); //////////////////////////////////////////////////////////////////////// template NDArray operator+(const T& scalar, NDArray&& arr) { - return std::move(arr) + scalar; + return std::move(arr) + scalar; } template SD_EXPORT NDArray operator+(const double& scalar, NDArray&& arr); template SD_EXPORT NDArray operator+(const float& scalar, NDArray&& arr); @@ -5121,11 +6337,10 @@ template SD_EXPORT NDArray operator+(const float16& scalar, NDArray&& arr); template SD_EXPORT NDArray operator+(const bfloat16& scalar, NDArray&& arr); template SD_EXPORT NDArray operator+(const int& scalar, NDArray&& arr); - //////////////////////////////////////////////////////////////////////// template NDArray operator+(const T& scalar, const NDArray& arr) { - return arr + scalar; + return arr + scalar; } template SD_EXPORT NDArray operator+(const double& scalar, const NDArray& arr); template SD_EXPORT NDArray operator+(const float& scalar, const NDArray& arr); @@ -5135,22 +6350,32 @@ template SD_EXPORT NDArray operator+(const int& scalar, const NDArray& arr); // addition operator array - scalar template NDArray operator-(NDArray&& arr, const T& scalar) { - - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(arr - scalar); // arr is lvalue inside function body - - if (arr.isS()) - throw std::runtime_error("operator-(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) - throw std::runtime_error("operator-(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - - NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Subtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); - - return std::move(arr); + if (arr.isView()) // do not use resources of arrays which use buffers of + // other original arrays + return std::move(arr - scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error( + "operator-(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error( + "operator-(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Subtract, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); } template SD_EXPORT NDArray operator-(NDArray&& arr, const double& scalar); template SD_EXPORT NDArray operator-(NDArray&& arr, const float& scalar); @@ -5158,43 +6383,59 @@ template SD_EXPORT NDArray operator-(NDArray&& arr, const float& scalar); //////////////////////////////////////////////////////////////////////// template NDArray operator-(const NDArray& arr, const T& scalar) { - - if (arr.isS()) - throw std::runtime_error("operator-(const NDArray& arr, const T& scalar): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Subtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; + if (arr.isS()) + throw std::runtime_error( + "operator-(const NDArray& arr, const T& scalar): you can't use this " + "method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT()), + false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Subtract, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; } template SD_EXPORT NDArray operator-(const NDArray& arr, const double& scalar); template SD_EXPORT NDArray operator-(const NDArray& arr, const float& scalar); template SD_EXPORT NDArray operator-(const NDArray& arr, const float16& scalar); -template SD_EXPORT NDArray operator-(const NDArray& arr, const bfloat16& scalar); +template SD_EXPORT NDArray operator-(const NDArray& arr, + const bfloat16& scalar); template SD_EXPORT NDArray operator-(const NDArray& arr, const int& scalar); //////////////////////////////////////////////////////////////////////// template NDArray operator-(const T& scalar, NDArray&& arr) { + if (arr.isView()) // do not use resources of arrays which use buffers of + // other original arrays + return std::move(scalar - arr); // arr is lvalue inside function body - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(scalar - arr); // arr is lvalue inside function body - - if (arr.isS()) - throw std::runtime_error("operator-(const T& scalar, NDArray&& arr): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + if (arr.isS()) + throw std::runtime_error( + "operator-(const T& scalar, NDArray&& arr): you can't use this method " + "on String array!"); - NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseSubtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - return std::move(arr); + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::ReverseSubtract, arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), + arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), + tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + return std::move(arr); } template SD_EXPORT NDArray operator-(const double& scalar, NDArray&& arr); template SD_EXPORT NDArray operator-(const float& scalar, NDArray&& arr); @@ -5205,18 +6446,27 @@ template SD_EXPORT NDArray operator-(const int& scalar, NDArray&& arr); //////////////////////////////////////////////////////////////////////// template NDArray operator-(const T& scalar, const NDArray& arr) { - - if (arr.isS()) - throw std::runtime_error("operator-(const T& scalar, const NDArray& arr): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseSubtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; + if (arr.isS()) + throw std::runtime_error( + "operator-(const T& scalar, const NDArray& arr): you can't use this " + "method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT()), + false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::ReverseSubtract, arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + result.buffer(), result.shapeInfo(), result.specialBuffer(), + result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), + tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; } template SD_EXPORT NDArray operator-(const double& scalar, const NDArray& arr); template SD_EXPORT NDArray operator-(const float& scalar, const NDArray& arr); @@ -5226,22 +6476,32 @@ template SD_EXPORT NDArray operator-(const int& scalar, const NDArray& arr); // addition operator array + scalar template NDArray operator*(NDArray&& arr, const T& scalar) { - - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(arr * scalar); // arr is lvalue inside function body - - if (arr.isS()) - throw std::runtime_error("operator*(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) - throw std::runtime_error("operator*(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - - NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Multiply, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); - - return std::move(arr); + if (arr.isView()) // do not use resources of arrays which use buffers of + // other original arrays + return std::move(arr * scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error( + "operator*(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error( + "operator*(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Multiply, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); } template SD_EXPORT NDArray operator*(NDArray&& arr, const double& scalar); template SD_EXPORT NDArray operator*(NDArray&& arr, const float& scalar); @@ -5253,31 +6513,42 @@ template SD_EXPORT NDArray operator*(NDArray&& arr, const long long& scalar); //////////////////////////////////////////////////////////////////////// template NDArray operator*(const NDArray& arr, const T& scalar) { + if (arr.isS()) + throw std::runtime_error( + "operator*(const NDArray& arr, const T& scalar): you can't use this " + "method on String array!"); - if (arr.isS()) - throw std::runtime_error("operator*(const NDArray& arr, const T& scalar): you can't use this method on String array!"); + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT()), + false, arr.getContext()); - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Multiply, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Multiply, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; + return result; } template SD_EXPORT NDArray operator*(const NDArray& arr, const double& scalar); template SD_EXPORT NDArray operator*(const NDArray& arr, const float& scalar); template SD_EXPORT NDArray operator*(const NDArray& arr, const float16& scalar); -template SD_EXPORT NDArray operator*(const NDArray& arr, const bfloat16& scalar); +template SD_EXPORT NDArray operator*(const NDArray& arr, + const bfloat16& scalar); template SD_EXPORT NDArray operator*(const NDArray& arr, const int& scalar); -template SD_EXPORT NDArray operator*(const NDArray& arr, const long long& scalar); +template SD_EXPORT NDArray operator*(const NDArray& arr, + const long long& scalar); //////////////////////////////////////////////////////////////////////// template NDArray operator*(const T& scalar, NDArray&& arr) { - return std::move(arr) * scalar; + return std::move(arr) * scalar; } template SD_EXPORT NDArray operator*(const double& scalar, NDArray&& arr); template SD_EXPORT NDArray operator*(const float& scalar, NDArray&& arr); @@ -5286,38 +6557,49 @@ template SD_EXPORT NDArray operator*(const bfloat16& scalar, NDArray&& arr); template SD_EXPORT NDArray operator*(const int& scalar, NDArray&& arr); template SD_EXPORT NDArray operator*(const long long& scalar, NDArray&& arr); - //////////////////////////////////////////////////////////////////////// template NDArray operator*(const T& scalar, const NDArray& arr) { - return arr * scalar; + return arr * scalar; } template SD_EXPORT NDArray operator*(const double& scalar, const NDArray& arr); template SD_EXPORT NDArray operator*(const float& scalar, const NDArray& arr); template SD_EXPORT NDArray operator*(const float16& scalar, const NDArray& arr); -template SD_EXPORT NDArray operator*(const bfloat16& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator*(const bfloat16& scalar, + const NDArray& arr); template SD_EXPORT NDArray operator*(const int& scalar, const NDArray& arr); -template SD_EXPORT NDArray operator*(const long long& scalar, const NDArray& arr); +template SD_EXPORT NDArray operator*(const long long& scalar, + const NDArray& arr); /////////////////////////////////////////////////////////////////////// template NDArray operator/(NDArray&& arr, const T& scalar) { - - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(arr / scalar); // arr is lvalue inside function body - - if (arr.isS()) - throw std::runtime_error("operator/(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) - throw std::runtime_error("operator/(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - - NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Divide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); - - return std::move(arr); + if (arr.isView()) // do not use resources of arrays which use buffers of + // other original arrays + return std::move(arr / scalar); // arr is lvalue inside function body + + if (arr.isS()) + throw std::runtime_error( + "operator/(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error( + "operator/(NDArray&& arr, const T& scalar): you can't use this method " + "on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Divide, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + + return std::move(arr); } template SD_EXPORT NDArray operator/(NDArray&& arr, const double& scalar); template SD_EXPORT NDArray operator/(NDArray&& arr, const float& scalar); @@ -5328,44 +6610,61 @@ template SD_EXPORT NDArray operator/(NDArray&& arr, const long long& scalar); //////////////////////////////////////////////////////////////////////// template NDArray operator/(const NDArray& arr, const T& scalar) { - - if (arr.isS()) - throw std::runtime_error("operator/(const NDArray& arr, const T& scalar): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Divide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; + if (arr.isS()) + throw std::runtime_error( + "operator/(const NDArray& arr, const T& scalar): you can't use this " + "method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT()), + false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::Divide, arr.buffer(), arr.shapeInfo(), + arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), + result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), + tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; } template SD_EXPORT NDArray operator/(const NDArray& arr, const double& scalar); template SD_EXPORT NDArray operator/(const NDArray& arr, const float& scalar); template SD_EXPORT NDArray operator/(const NDArray& arr, const float16& scalar); -template SD_EXPORT NDArray operator/(const NDArray& arr, const bfloat16& scalar); +template SD_EXPORT NDArray operator/(const NDArray& arr, + const bfloat16& scalar); template SD_EXPORT NDArray operator/(const NDArray& arr, const int& scalar); -template SD_EXPORT NDArray operator/(const NDArray& arr, const long long& scalar); +template SD_EXPORT NDArray operator/(const NDArray& arr, + const long long& scalar); //////////////////////////////////////////////////////////////////////// template NDArray operator/(const T& scalar, NDArray&& arr) { + if (arr.isView()) // do not use resources of arrays which use buffers of + // other original arrays + return std::move(scalar / arr); // arr is lvalue inside function body - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(scalar / arr); // arr is lvalue inside function body - - if (arr.isS()) - throw std::runtime_error("operator/(const T& scalar, NDArray&& arr): you can't use this method on String array!"); + if (arr.isS()) + throw std::runtime_error( + "operator/(const T& scalar, NDArray&& arr): you can't use this method " + "on String array!"); - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseDivide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); - - return std::move(arr); + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::ReverseDivide, arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), + arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), + tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + return std::move(arr); } template SD_EXPORT NDArray operator/(const double& scalar, NDArray&& arr); template SD_EXPORT NDArray operator/(const float& scalar, NDArray&& arr); @@ -5373,22 +6672,30 @@ template SD_EXPORT NDArray operator/(const float16& scalar, NDArray&& arr); template SD_EXPORT NDArray operator/(const bfloat16& scalar, NDArray&& arr); template SD_EXPORT NDArray operator/(const int& scalar, NDArray&& arr); - //////////////////////////////////////////////////////////////////////// template NDArray operator/(const T& scalar, const NDArray& arr) { - - if (arr.isS()) - throw std::runtime_error("operator/(const T& scalar, const NDArray& arr): you can't use this method on String array!"); - - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseDivide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - - return result; + if (arr.isS()) + throw std::runtime_error( + "operator/(const T& scalar, const NDArray& arr): you can't use this " + "method on String array!"); + + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr.dataType(), DataTypeUtils::fromT()), + false, arr.getContext()); + + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar( + arr.getContext(), sd::scalar::ReverseDivide, arr.buffer(), + arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), + result.buffer(), result.shapeInfo(), result.specialBuffer(), + result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), + tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + + return result; } template SD_EXPORT NDArray operator/(const double& scalar, const NDArray& arr); template SD_EXPORT NDArray operator/(const float& scalar, const NDArray& arr); @@ -5398,238 +6705,326 @@ template SD_EXPORT NDArray operator/(const int& scalar, const NDArray& arr); // addition operator array + array template NDArray operator+(T1&& arr1, T2&& arr2) { - - if (arr1.isS() || arr2.isS()) - throw std::runtime_error("operator+(T&& arr1, T&& arr2): you can't use this method on String arrays!"); - if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) - throw sd::datatype_exception::build("operator+(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType()); - - PointersManager pointersManager(arr1.getContext(), "operator+(T&& arr1, T&& arr2)"); - - if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { - - const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); - const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); - - NDArray* result = nullptr; - if(isArr1Rvalue) - result = const_cast(&arr1); - else if(isArr2Rvalue) - result = const_cast(&arr2); - else - result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext()); - - NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); - NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Add, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {&arr1, &arr2}); - - if(!isArr1Rvalue && !isArr2Rvalue) { - NDArray res = std::move(*result); - delete result; - return std::move(res); - } - - return std::move(*result); - } - - return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), std::forward(arr2)); -} -template SD_EXPORT NDArray operator+(NDArray& arr1, NDArray& arr2); -template SD_EXPORT NDArray operator+(NDArray& arr1, NDArray&& arr2); -template SD_EXPORT NDArray operator+(NDArray&& arr1, NDArray& arr2); -template SD_EXPORT NDArray operator+(NDArray& arr1, const NDArray& arr2); -template SD_EXPORT NDArray operator+(const NDArray& arr1, NDArray& arr2); -template SD_EXPORT NDArray operator+(const NDArray& arr1, NDArray&& arr2); -template SD_EXPORT NDArray operator+(const NDArray& arr1, const NDArray& arr2); -template SD_EXPORT NDArray operator+(NDArray&& arr1, const NDArray& arr2); -template SD_EXPORT NDArray operator+(NDArray&& arr1, NDArray&& arr2); + if (arr1.isS() || arr2.isS()) + throw std::runtime_error( + "operator+(T&& arr1, T&& arr2): you can't use this method on String " + "arrays!"); + if (!Environment::getInstance()->isExperimentalBuild() && + arr1.dataType() != arr2.dataType() && + (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw sd::datatype_exception::build( + "operator+(T&& arr1, T&& arr2): Cannot multiply different types", + arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), + "operator+(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if (isArr1Rvalue) + result = const_cast(&arr1); + else if (isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr1.shapeInfo(), arr2.shapeInfo()), + false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform( + arr1.getContext(), sd::pairwise::Add, arr1.buffer(), arr1.shapeInfo(), + arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), + arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), + result->buffer(), result->shapeInfo(), result->specialBuffer(), + result->specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if (!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), + std::forward(arr2)); +} +template SD_EXPORT NDArray operator+ + (NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator+ + (NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator+ + (NDArray&& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator+ + (NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator+ + (const NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator+ + (const NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator+ + (const NDArray& arr1, + const NDArray& arr2); +template SD_EXPORT NDArray operator+ + (NDArray&& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator+ + (NDArray&& arr1, NDArray&& arr2); //////////////////////////////////////////////////////////////////////// // addition operator array - array template NDArray operator-(T1&& arr1, T2&& arr2) { - - if (arr1.isS() || arr2.isS()) - throw std::runtime_error("operator-(T&& arr1, T&& arr2): you can't use this method on String arrays!"); - if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) - throw sd::datatype_exception::build("operator-(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType()); - - PointersManager pointersManager(arr1.getContext(), "operator-(T&& arr1, T&& arr2)"); - - if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { - - const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); - const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); - - NDArray* result = nullptr; - if(isArr1Rvalue) - result = const_cast(&arr1); - else if(isArr2Rvalue) - result = const_cast(&arr2); - else - result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext()); - - NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); - NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Subtract, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {&arr1, &arr2}); - - if(!isArr1Rvalue && !isArr2Rvalue) { - NDArray res = std::move(*result); - delete result; - return std::move(res); - } - - return std::move(*result); - } - - return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), std::forward(arr2)); -} -template SD_EXPORT NDArray operator-(NDArray& arr1, NDArray& arr2); -template SD_EXPORT NDArray operator-(NDArray& arr1, NDArray&& arr2); -template SD_EXPORT NDArray operator-(NDArray&& arr1, NDArray& arr2); -template SD_EXPORT NDArray operator-(NDArray& arr1, const NDArray& arr2); -template SD_EXPORT NDArray operator-(const NDArray& arr1, NDArray& arr2); -template SD_EXPORT NDArray operator-(const NDArray& arr1, NDArray&& arr2); -template SD_EXPORT NDArray operator-(const NDArray& arr1, const NDArray& arr2); -template SD_EXPORT NDArray operator-(NDArray&& arr1, const NDArray& arr2); -template SD_EXPORT NDArray operator-(NDArray&& arr1, NDArray&& arr2); + if (arr1.isS() || arr2.isS()) + throw std::runtime_error( + "operator-(T&& arr1, T&& arr2): you can't use this method on String " + "arrays!"); + if (!Environment::getInstance()->isExperimentalBuild() && + arr1.dataType() != arr2.dataType() && + (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw sd::datatype_exception::build( + "operator-(T&& arr1, T&& arr2): Cannot multiply different types", + arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), + "operator-(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if (isArr1Rvalue) + result = const_cast(&arr1); + else if (isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr1.shapeInfo(), arr2.shapeInfo()), + false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform( + arr1.getContext(), sd::pairwise::Subtract, arr1.buffer(), + arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), + arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), + arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), + result->specialBuffer(), result->specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if (!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast( + sd::BroadcastOpsTuple::Subtract(), std::forward(arr2)); +} +template SD_EXPORT NDArray operator- + (NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator- + (NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator- + (NDArray&& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator- + (NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator- + (const NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator- + (const NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator- + (const NDArray& arr1, + const NDArray& arr2); +template SD_EXPORT NDArray operator- + (NDArray&& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator- + (NDArray&& arr1, NDArray&& arr2); //////////////////////////////////////////////////////////////////////// // multiplication operator array*array template NDArray operator*(T1&& arr1, T2&& arr2) { - - if (arr1.isS() || arr2.isS()) - throw std::runtime_error("operator*(T&& arr1, T&& arr2): you can't use this method on String arrays!"); - if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) - throw sd::datatype_exception::build("operator*(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType()); - - PointersManager pointersManager(arr1.getContext(), "operator*(T&& arr1, T&& arr2)"); - - if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { - - const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); - const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); - - NDArray* result = nullptr; - if(isArr1Rvalue) - result = const_cast(&arr1); - else if(isArr2Rvalue) - result = const_cast(&arr2); - else - result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext()); - - NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); - NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Multiply, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {&arr1, &arr2}); - - if(!isArr1Rvalue && !isArr2Rvalue) { - NDArray res = std::move(*result); - delete result; - return std::move(res); - } - - return std::move(*result); - } - - return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), std::forward(arr2)); -} -template SD_EXPORT NDArray operator*(NDArray& arr1, NDArray& arr2); -template SD_EXPORT NDArray operator*(NDArray& arr1, NDArray&& arr2); -template SD_EXPORT NDArray operator*(NDArray&& arr1, NDArray& arr2); -template SD_EXPORT NDArray operator*(NDArray& arr1, const NDArray& arr2); -template SD_EXPORT NDArray operator*(const NDArray& arr1, NDArray& arr2); -template SD_EXPORT NDArray operator*(const NDArray& arr1, NDArray&& arr2); -template SD_EXPORT NDArray operator*(const NDArray& arr1, const NDArray& arr2); -template SD_EXPORT NDArray operator*(NDArray&& arr1, const NDArray& arr2); -template SD_EXPORT NDArray operator*(NDArray&& arr1, NDArray&& arr2); + if (arr1.isS() || arr2.isS()) + throw std::runtime_error( + "operator*(T&& arr1, T&& arr2): you can't use this method on String " + "arrays!"); + if (!Environment::getInstance()->isExperimentalBuild() && + arr1.dataType() != arr2.dataType() && + (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw sd::datatype_exception::build( + "operator*(T&& arr1, T&& arr2): Cannot multiply different types", + arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), + "operator*(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if (isArr1Rvalue) + result = const_cast(&arr1); + else if (isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr1.shapeInfo(), arr2.shapeInfo()), + false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform( + arr1.getContext(), sd::pairwise::Multiply, arr1.buffer(), + arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), + arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), + arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), + result->specialBuffer(), result->specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if (!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast( + sd::BroadcastOpsTuple::Multiply(), std::forward(arr2)); +} +template SD_EXPORT NDArray operator* + (NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator* + (NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator* + (NDArray&& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator* + (NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator* + (const NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator* + (const NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator* + (const NDArray& arr1, + const NDArray& arr2); +template SD_EXPORT NDArray operator* + (NDArray&& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator* + (NDArray&& arr1, NDArray&& arr2); //////////////////////////////////////////////////////////////////////// // multiplication operator array*array template NDArray operator/(T1&& arr1, T2&& arr2) { - - if (arr1.isS() || arr2.isS()) - throw std::runtime_error("operator/(T&& arr1, T&& arr2): you can't use this method on String arrays!"); - if (!Environment::getInstance()->isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) - throw sd::datatype_exception::build("operator/(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType()); - - PointersManager pointersManager(arr1.getContext(), "operator/(T&& arr1, T&& arr2)"); - - if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { - - const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); - const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); - - NDArray* result = nullptr; - if(isArr1Rvalue) - result = const_cast(&arr1); - else if(isArr2Rvalue) - result = const_cast(&arr2); - else - result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext()); - - NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); - NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Divide, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {&arr1, &arr2}); - - if(!isArr1Rvalue && !isArr2Rvalue) { - NDArray res = std::move(*result); - delete result; - return std::move(res); - } - - return std::move(*result); - } - - return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), std::forward(arr2)); -} -template SD_EXPORT NDArray operator/(NDArray& arr1, NDArray& arr2); -template SD_EXPORT NDArray operator/(NDArray& arr1, NDArray&& arr2); -template SD_EXPORT NDArray operator/(NDArray&& arr1, NDArray& arr2); -template SD_EXPORT NDArray operator/(NDArray& arr1, const NDArray& arr2); -template SD_EXPORT NDArray operator/(const NDArray& arr1, NDArray& arr2); -template SD_EXPORT NDArray operator/(const NDArray& arr1, NDArray&& arr2); -template SD_EXPORT NDArray operator/(const NDArray& arr1, const NDArray& arr2); -template SD_EXPORT NDArray operator/(NDArray&& arr1, const NDArray& arr2); -template SD_EXPORT NDArray operator/(NDArray&& arr1, NDArray&& arr2); - + if (arr1.isS() || arr2.isS()) + throw std::runtime_error( + "operator/(T&& arr1, T&& arr2): you can't use this method on String " + "arrays!"); + if (!Environment::getInstance()->isExperimentalBuild() && + arr1.dataType() != arr2.dataType() && + (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw sd::datatype_exception::build( + "operator/(T&& arr1, T&& arr2): Cannot multiply different types", + arr1.dataType(), arr2.dataType()); + + PointersManager pointersManager(arr1.getContext(), + "operator/(T&& arr1, T&& arr2)"); + + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + + NDArray* result = nullptr; + if (isArr1Rvalue) + result = const_cast(&arr1); + else if (isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.shapeInfo(), + DataTypeUtils::pickPairwiseResultType( + arr1.shapeInfo(), arr2.shapeInfo()), + false, arr1.getContext()); + + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform( + arr1.getContext(), sd::pairwise::Divide, arr1.buffer(), + arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), + arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), + arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), + result->specialBuffer(), result->specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + + if (!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); + } + + return std::forward(arr1).applyTrueBroadcast( + sd::BroadcastOpsTuple::Divide(), std::forward(arr2)); +} +template SD_EXPORT NDArray operator/ + (NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator/ + (NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator/ + (NDArray&& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator/ + (NDArray& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator/ + (const NDArray& arr1, NDArray& arr2); +template SD_EXPORT NDArray operator/ + (const NDArray& arr1, NDArray&& arr2); +template SD_EXPORT NDArray operator/ + (const NDArray& arr1, + const NDArray& arr2); +template SD_EXPORT NDArray operator/ + (NDArray&& arr1, const NDArray& arr2); +template SD_EXPORT NDArray operator/ + (NDArray&& arr1, NDArray&& arr2); /* #ifndef __CLION_IDE__ #include "NDArray.macro" #endif */ -} +} // namespace sd #endif - - - - - - - - - - - ////////////////////////////////////////////////////////////////////////// // check whether array's rows (arg=0) or columns (arg=1) create orthogonal basis // bool NDArray::hasOrthonormalBasis(const int arg) { // if (isS()) -// throw std::runtime_error("NDArray::hasOrthonormalBasis: you can't use this method on String array!"); +// throw std::runtime_error("NDArray::hasOrthonormalBasis: you can't use +// this method on String array!"); // if(rankOf() !=2 ) -// throw std::runtime_error("NDArray::hasOrthBasis method: rank of ndarray is not equal 2 !"); +// throw std::runtime_error("NDArray::hasOrthBasis method: rank of +// ndarray is not equal 2 !"); // if(arg!=0 && arg!=1) -// throw std::runtime_error("NDArray::hasOrthBasis method: input argument is not equal to 0 or 1 !"); +// throw std::runtime_error("NDArray::hasOrthBasis method: input +// argument is not equal to 0 or 1 !"); // const double eps = 1e-5; // double dot = 0.f; -// if(arg) { // check whether columns create orthogonal basis +// if(arg) { // check whether columns create orthogonal +// basis // for(int j=0; j(NDArray&& arr1, NDA // dot = 0.f; // } -// for(int j=0; j(i,j)*e(i,j); -// if(dot != 0.f && sd::math::nd4j_abs(sd::math::nd4j_sqrt(dot) - 1.f) > eps) +// if(dot != 0.f && sd::math::nd4j_abs(sd::math::nd4j_sqrt(dot) - 1.f) > eps) // return false; // dot = 0.f; @@ -5662,11 +7059,13 @@ template SD_EXPORT NDArray operator/(NDArray&& arr1, NDA // dot = 0.; // } -// for(int i=0; i(i,j)*e(i,j); -// if(dot!= 0. && sd::math::nd4j_abs(sd::math::nd4j_sqrt(dot) - 1.) > eps) +// if(dot!= 0. && sd::math::nd4j_abs(sd::math::nd4j_sqrt(dot) - 1.) > eps) // return false; // dot = 0.; // } diff --git a/libnd4j/include/array/impl/NDArrayFactory.cpp b/libnd4j/include/array/impl/NDArrayFactory.cpp index 8d13adf15175..23ae2e070571 100644 --- a/libnd4j/include/array/impl/NDArrayFactory.cpp +++ b/libnd4j/include/array/impl/NDArrayFactory.cpp @@ -24,698 +24,1104 @@ #include #include #include +#include #include +#include +#include + #include +namespace sd { +SD_EXPORT NDArray NDArrayFactory::undefined() { return NDArray(); } +//////////////////////////////////////////////////////////////////////// +template <> +SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context) { + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument( + "NDArrayFactory::create: rank of NDArray can't exceed 32 !"); + + ShapeDescriptor descriptor(sd::DataType::BOOL, order, shape); + + if (descriptor.arrLength() != data.size()) { + nd4j_printf( + "NDArrayFactory::create: data size [%i] doesn't match shape length " + "[%lld]\n", + data.size(), descriptor.arrLength()); + throw std::runtime_error( + "NDArrayFactory::create: data size doesn't match shape"); + } + + bool* hostBuffer = nullptr; + ALLOCATE(hostBuffer, context->getWorkspace(), data.size(), bool); + std::copy(data.begin(), data.end(), hostBuffer); + + std::shared_ptr buffer = std::make_shared( + hostBuffer, data.size() * sizeof(bool), sd::DataType::BOOL, true, + context->getWorkspace()); + + NDArray result(buffer, descriptor, context); + + return result; +} -#include -#include -#include +//////////////////////////////////////////////////////////////////////// +template +NDArray NDArrayFactory::create(const char order, + const std::vector& shape, + const std::vector& data, + sd::LaunchContext* context) { + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument( + "NDArrayFactory::create: rank of NDArray can't exceed 32 !"); + + ShapeDescriptor descriptor(DataTypeUtils::fromT(), order, shape); + + if (descriptor.arrLength() != data.size()) { + nd4j_printf( + "NDArrayFactory::create: data size [%i] doesn't match shape length " + "[%lld]\n", + data.size(), descriptor.arrLength()); + throw std::runtime_error( + "NDArrayFactory::create: data size doesn't match shape"); + } + + std::shared_ptr buffer = std::make_shared( + data.data(), DataTypeUtils::fromT(), + descriptor.arrLength() * sizeof(T), context->getWorkspace()); + + NDArray result(buffer, descriptor, context); + + return result; +} +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); -namespace sd { +//////////////////////////////////////////////////////////////////////// +template +NDArray* NDArrayFactory::create_(const char order, + const std::vector& shape, + sd::LaunchContext* context) { + return create_(order, shape, DataTypeUtils::fromT(), context); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray* NDArrayFactory::create_, + (const char order, const std::vector& shape, + sd::LaunchContext* context), + LIBND4J_TYPES); - SD_EXPORT NDArray NDArrayFactory::undefined() { - return NDArray(); - } +//////////////////////////////////////////////////////////////////////// +template +void NDArrayFactory::memcpyFromVector(void* ptr, const std::vector& vector) { + memcpy(ptr, vector.data(), vector.size() * sizeof(T)); +} - //////////////////////////////////////////////////////////////////////// - template <> - SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context) { +template <> +void SD_EXPORT +NDArrayFactory::memcpyFromVector(void* ptr, const std::vector& vector) { + auto p = reinterpret_cast(ptr); + for (Nd4jLong e = 0; e < vector.size(); e++) p[e] = vector[e]; +} - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !"); +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); +template SD_EXPORT void NDArrayFactory::memcpyFromVector( + void* ptr, const std::vector& vector); - ShapeDescriptor descriptor(sd::DataType::BOOL, order, shape); +#ifndef __JAVACPP_HACK__ +//////////////////////////////////////////////////////////////////////// +template +NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, + const T value, const char order, + sd::LaunchContext* context) { + return valueOf(std::vector(shape), value, order); +} +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const double value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const float value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const float16 value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const bfloat16 value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const Nd4jLong value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const int value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const uint8_t value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const int8_t value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const int16_t value, + const char order, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::initializer_list& shape, const bool value, + const char order, sd::LaunchContext* context); - if (descriptor.arrLength() != data.size()) { - nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength()); - throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape"); - } +//////////////////////////////////////////////////////////////////////// +template +NDArray NDArrayFactory::create(const char order, + const std::vector& shape, + const std::initializer_list& data, + sd::LaunchContext* context) { + std::vector vec(data); + return create(order, shape, vec, context); +} +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray +NDArrayFactory::create(const char order, const std::vector& shape, + const std::initializer_list& data, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const char order, const std::vector& shape, + const std::initializer_list& data, sd::LaunchContext* context); - bool* hostBuffer = nullptr; - ALLOCATE(hostBuffer, context->getWorkspace(), data.size(), bool); - std::copy(data.begin(), data.end(), hostBuffer); +#endif - std::shared_ptr buffer = std::make_shared(hostBuffer, data.size() * sizeof(bool), sd::DataType::BOOL, true, context->getWorkspace()); +//////////////////////////////////////////////////////////////////////// +template +NDArray* NDArrayFactory::create_(const T scalar, sd::LaunchContext* context) { + std::shared_ptr buffer = std::make_shared( + 1 * sizeof(T), DataTypeUtils::fromT(), context->getWorkspace(), true); + + NDArray* res = new NDArray( + buffer, ShapeDescriptor::scalarDescriptor(DataTypeUtils::fromT()), + context); - NDArray result(buffer, descriptor, context); + res->bufferAsT()[0] = scalar; - return result; - } + res->tickWriteHost(); + res->syncToDevice(); + + return res; +} +template SD_EXPORT NDArray* NDArrayFactory::create_(const double scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const float scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const float16 scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const bfloat16 scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const Nd4jLong scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const int scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const bool scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const int8_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const uint8_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const uint16_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const uint32_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const uint64_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_(const int16_t scalar, + sd::LaunchContext* context); - //////////////////////////////////////////////////////////////////////// - template - NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context) { +template +NDArray NDArrayFactory::create(sd::DataType type, const T scalar, + sd::LaunchContext* context) { + if (type == DataTypeUtils::fromT()) + return NDArrayFactory::create(scalar, context); - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !"); + NDArray res(type, context); + res.p(0, scalar); + res.syncToDevice(); - ShapeDescriptor descriptor(DataTypeUtils::fromT(), order, shape); + return res; +} +// BUILD_DOUBLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::create, +// (DataType type, const T scalar, sd::LaunchContext * context), +// LIBND4J_TYPES); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const double scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const float scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const float16 scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const bfloat16 scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const Nd4jLong scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const int scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const int8_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const uint8_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const uint16_t scalar, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const uint32_t scalar, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const uint64_t scalar, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const int16_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(DataType type, + const bool scalar, + sd::LaunchContext* context); - if (descriptor.arrLength() != data.size()) { - nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength()); - throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape"); - } +template +NDArray NDArrayFactory::create(const T scalar, sd::LaunchContext* context) { + std::shared_ptr buffer = std::make_shared( + 1 * sizeof(T), DataTypeUtils::fromT(), context->getWorkspace(), true); - std::shared_ptr buffer = std::make_shared(data.data(), DataTypeUtils::fromT(), descriptor.arrLength() * sizeof(T), context->getWorkspace()); + NDArray res(buffer, + ShapeDescriptor::scalarDescriptor(DataTypeUtils::fromT()), + context); - NDArray result(buffer, descriptor, context); + res.bufferAsT()[0] = scalar; - return result; - } - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::vector& data, sd::LaunchContext * context); + res.tickWriteHost(); + res.syncToDevice(); -//////////////////////////////////////////////////////////////////////// -template -NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, sd::LaunchContext * context) { - return create_(order, shape, DataTypeUtils::fromT(), context); + return res; } -BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray* NDArrayFactory::create_, (const char order, const std::vector &shape, sd::LaunchContext * context), LIBND4J_TYPES); +template SD_EXPORT NDArray NDArrayFactory::create(const double scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const float scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const float16 scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const bfloat16 scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const Nd4jLong scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const int scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const int8_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const uint8_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const int16_t scalar, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create(const uint16_t scalar, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray NDArrayFactory::create(const uint32_t scalar, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray NDArrayFactory::create(const uint64_t scalar, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray NDArrayFactory::create(const bool scalar, + sd::LaunchContext* context); //////////////////////////////////////////////////////////////////////// template -void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector) { - - memcpy(ptr, vector.data(), vector.size() * sizeof(T)); +NDArray* NDArrayFactory::create_(const char order, + const std::vector& shape, + const std::vector& data, + sd::LaunchContext* context) { + return new NDArray(NDArrayFactory::create(order, shape, data, context)); } +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::create_( + const char order, const std::vector& shape, + const std::vector& data, sd::LaunchContext* context); +//////////////////////////////////////////////////////////////////////// template <> -void SD_EXPORT NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector) { - auto p = reinterpret_cast(ptr); - for (Nd4jLong e = 0; e < vector.size(); e++) - p[e] = vector[e]; +SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, + NDArray* value, const char order, + sd::LaunchContext* context) { + auto result = create_(order, shape, value->dataType(), context); + result->assign(*value); + return result; } -template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); -template SD_EXPORT void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &vector); - +template <> +SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, + NDArray& value, const char order, + sd::LaunchContext* context) { + auto result = create_(order, shape, value.dataType(), context); + result->assign(value); + return result; +} -#ifndef __JAVACPP_HACK__ - //////////////////////////////////////////////////////////////////////// - template - NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const T value, const char order, sd::LaunchContext * context) { - return valueOf(std::vector(shape), value, order); - } - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const double value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const float16 value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bfloat16 value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const Nd4jLong value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const uint8_t value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int8_t value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const int16_t value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::initializer_list& shape, const bool value, const char order, sd::LaunchContext * context); +template +NDArray* NDArrayFactory::valueOf(const std::vector& shape, + const T value, const char order, + sd::LaunchContext* context) { + auto result = create_(order, shape, DataTypeUtils::fromT()); + result->assign(value); + return result; +} +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const double value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const float value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const float16 value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const bfloat16 value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const Nd4jLong value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const int value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const int16_t value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const int8_t value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const uint8_t value, const char order, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::valueOf( + const std::vector& shape, const bool value, const char order, + sd::LaunchContext* context); //////////////////////////////////////////////////////////////////////// - template - NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context) { - std::vector vec(data); - return create(order, shape, vec, context); - } - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, sd::LaunchContext * context); - -#endif +template +NDArray* NDArrayFactory::linspace(const T from, const T to, + const Nd4jLong numElements) { + NDArray* result = NDArrayFactory::vector(numElements); + // TO DO: linspace should be executed on DEVICE, but only CPU version + // implemnted! + for (Nd4jLong e = 0; e < numElements; e++) { + T step = (T)e / ((T)numElements - (T)1); + result->p(e, (from * ((T)1 - step) + step * to)); + } + result->syncToDevice(); + + return result; +} +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const double from, const double to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const float from, const float to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const float16 from, const float16 to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const bfloat16 from, const bfloat16 to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const Nd4jLong from, const Nd4jLong to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const int from, const int to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const int16_t from, const int16_t to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const uint8_t from, const uint8_t to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const uint16_t from, const uint16_t to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const uint32_t from, const uint32_t to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const uint64_t from, const uint64_t to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const int8_t from, const int8_t to, const Nd4jLong numElements); +template SD_EXPORT NDArray* NDArrayFactory::linspace( + const bool from, const bool to, const Nd4jLong numElements); //////////////////////////////////////////////////////////////////////// - template - NDArray* NDArrayFactory::create_(const T scalar, sd::LaunchContext * context) { - - std::shared_ptr buffer = std::make_shared(1 * sizeof(T), DataTypeUtils::fromT(), context->getWorkspace(), true); - - NDArray* res = new NDArray(buffer, ShapeDescriptor::scalarDescriptor(DataTypeUtils::fromT()), context); - - res->bufferAsT()[0] = scalar; - - res->tickWriteHost(); - res->syncToDevice(); - - return res; - } - template SD_EXPORT NDArray* NDArrayFactory::create_(const double scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::create_(const float scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::create_(const float16 scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::create_(const bfloat16 scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::create_(const Nd4jLong scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::create_(const int scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::create_(const bool scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::create_(const int8_t scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::create_(const uint8_t scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::create_(const uint16_t scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::create_(const uint32_t scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::create_(const uint64_t scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::create_(const int16_t scalar, sd::LaunchContext * context); - - template - NDArray NDArrayFactory::create(sd::DataType type, const T scalar, sd::LaunchContext * context) { - - if (type == DataTypeUtils::fromT()) - return NDArrayFactory::create(scalar, context); - - NDArray res(type, context); - res.p(0, scalar); - res.syncToDevice(); - - return res; - } -// BUILD_DOUBLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::create, (DataType type, const T scalar, sd::LaunchContext * context), LIBND4J_TYPES); - template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const double scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const float scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const float16 scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const bfloat16 scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const Nd4jLong scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const int scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const int8_t scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const uint8_t scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const uint16_t scalar, sd::LaunchContext* workspace); - template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const uint32_t scalar, sd::LaunchContext* workspace); - template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const uint64_t scalar, sd::LaunchContext* workspace); - template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const int16_t scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(DataType type, const bool scalar, sd::LaunchContext * context); - - template - NDArray NDArrayFactory::create(const T scalar, sd::LaunchContext * context) { - - std::shared_ptr buffer = std::make_shared(1 * sizeof(T), DataTypeUtils::fromT(), context->getWorkspace(), true); - - NDArray res(buffer, ShapeDescriptor::scalarDescriptor(DataTypeUtils::fromT()), context); - - res.bufferAsT()[0] = scalar; - - res.tickWriteHost(); - res.syncToDevice(); - - return res; - } - template SD_EXPORT NDArray NDArrayFactory::create(const double scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const float scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const float16 scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const bfloat16 scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const Nd4jLong scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const int scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const int8_t scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const uint8_t scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const int16_t scalar, sd::LaunchContext * context); - template SD_EXPORT NDArray NDArrayFactory::create(const uint16_t scalar, sd::LaunchContext* workspace); - template SD_EXPORT NDArray NDArrayFactory::create(const uint32_t scalar, sd::LaunchContext* workspace); - template SD_EXPORT NDArray NDArrayFactory::create(const uint64_t scalar, sd::LaunchContext* workspace); - template SD_EXPORT NDArray NDArrayFactory::create(const bool scalar, sd::LaunchContext * context); - +template +NDArray* NDArrayFactory::vector(Nd4jLong length, const T value, + sd::LaunchContext* context) { + std::shared_ptr buffer = std::make_shared( + length * sizeof(T), DataTypeUtils::fromT(), context->getWorkspace(), + true); + + auto res = new NDArray( + buffer, + ShapeDescriptor::vectorDescriptor(length, DataTypeUtils::fromT()), + context); + + if (value == (T)0.0f) + res->nullify(); + else + res->assign(value); + + return res; +} +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const double startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const float startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const float16 startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const bfloat16 startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const Nd4jLong startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const int startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const uint8_t startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector( + Nd4jLong length, const uint16_t startingValue, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray* NDArrayFactory::vector( + Nd4jLong length, const uint32_t startingValue, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray* NDArrayFactory::vector( + Nd4jLong length, const uint64_t startingValue, + sd::LaunchContext* workspace); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const int8_t startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const int16_t startingValue, + sd::LaunchContext* context); +template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, + const bool startingValue, + sd::LaunchContext* context); //////////////////////////////////////////////////////////////////////// -template -NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context) { - - return new NDArray(NDArrayFactory::create(order, shape, data, context)); -} -template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); -template SD_EXPORT NDArray* NDArrayFactory::create_(const char order, const std::vector &shape, const std::vector &data, sd::LaunchContext * context); - - - //////////////////////////////////////////////////////////////////////// - template <> - SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, NDArray* value, const char order, sd::LaunchContext * context) { - auto result = create_(order, shape, value->dataType(), context); - result->assign(*value); - return result; - } - - template <> - SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, NDArray& value, const char order, sd::LaunchContext * context) { - auto result = create_(order, shape, value.dataType(), context); - result->assign(value); - return result; - } - - template - NDArray* NDArrayFactory::valueOf(const std::vector& shape, const T value, const char order, sd::LaunchContext * context) { - auto result = create_(order, shape, DataTypeUtils::fromT()); - result->assign(value); - return result; - } - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const double value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const float value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const float16 value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const bfloat16 value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const Nd4jLong value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int16_t value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const int8_t value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const uint8_t value, const char order, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::valueOf(const std::vector& shape, const bool value, const char order, sd::LaunchContext * context); - - - //////////////////////////////////////////////////////////////////////// - template - NDArray* NDArrayFactory::linspace(const T from, const T to, const Nd4jLong numElements) { - NDArray* result = NDArrayFactory::vector(numElements); - //TO DO: linspace should be executed on DEVICE, but only CPU version implemnted! - for (Nd4jLong e = 0; e < numElements; e++) { - T step = (T) e / ((T) numElements - (T) 1); - result->p(e, (from * ((T) 1 - step) + step * to)); - } - result->syncToDevice(); - - return result; - } - template SD_EXPORT NDArray* NDArrayFactory::linspace(const double from, const double to, const Nd4jLong numElements); - template SD_EXPORT NDArray* NDArrayFactory::linspace(const float from, const float to, const Nd4jLong numElements); - template SD_EXPORT NDArray* NDArrayFactory::linspace(const float16 from, const float16 to, const Nd4jLong numElements); - template SD_EXPORT NDArray* NDArrayFactory::linspace(const bfloat16 from, const bfloat16 to, const Nd4jLong numElements); - template SD_EXPORT NDArray* NDArrayFactory::linspace(const Nd4jLong from, const Nd4jLong to, const Nd4jLong numElements); - template SD_EXPORT NDArray* NDArrayFactory::linspace(const int from, const int to, const Nd4jLong numElements); - template SD_EXPORT NDArray* NDArrayFactory::linspace(const int16_t from, const int16_t to, const Nd4jLong numElements); - template SD_EXPORT NDArray* NDArrayFactory::linspace(const uint8_t from, const uint8_t to, const Nd4jLong numElements); - template SD_EXPORT NDArray* NDArrayFactory::linspace(const uint16_t from, const uint16_t to, const Nd4jLong numElements); - template SD_EXPORT NDArray* NDArrayFactory::linspace(const uint32_t from, const uint32_t to, const Nd4jLong numElements); - template SD_EXPORT NDArray* NDArrayFactory::linspace(const uint64_t from, const uint64_t to, const Nd4jLong numElements); - template SD_EXPORT NDArray* NDArrayFactory::linspace(const int8_t from, const int8_t to, const Nd4jLong numElements); - template SD_EXPORT NDArray* NDArrayFactory::linspace(const bool from, const bool to, const Nd4jLong numElements); +template +NDArray NDArrayFactory::create(const char order, + const std::initializer_list& shape, + sd::LaunchContext* context) { + std::vector vec(shape); + return create(order, vec, context); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::create, + (const char, const std::initializer_list&, + sd::LaunchContext* context), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// - template - NDArray* NDArrayFactory::vector(Nd4jLong length, const T value, sd::LaunchContext * context) { - - std::shared_ptr buffer = std::make_shared(length * sizeof(T), DataTypeUtils::fromT(), context->getWorkspace(), true); - - auto res = new NDArray(buffer, ShapeDescriptor::vectorDescriptor(length, DataTypeUtils::fromT()), context); - - if (value == (T)0.0f) - res->nullify(); - else - res->assign(value); - - return res; - } - template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const double startingValue, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const float startingValue, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const float16 startingValue, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const bfloat16 startingValue, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const Nd4jLong startingValue, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const int startingValue, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint8_t startingValue, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint16_t startingValue, sd::LaunchContext *workspace); - template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint32_t startingValue, sd::LaunchContext *workspace); - template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const uint64_t startingValue, sd::LaunchContext *workspace); - template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const int8_t startingValue, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const int16_t startingValue, sd::LaunchContext * context); - template SD_EXPORT NDArray* NDArrayFactory::vector(Nd4jLong length, const bool startingValue, sd::LaunchContext * context); +template +NDArray NDArrayFactory::create(const char order, + const std::vector& shape, + sd::LaunchContext* context) { + return create(order, shape, DataTypeUtils::fromT(), context); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::create, + (const char order, const std::vector& shape, + sd::LaunchContext* context), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// - template - NDArray NDArrayFactory::create(const char order, const std::initializer_list& shape, sd::LaunchContext * context) { - std::vector vec(shape); - return create(order, vec, context); - } - BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::create, (const char, const std::initializer_list&, sd::LaunchContext * context), LIBND4J_TYPES); +NDArray NDArrayFactory::create(const char order, + const std::vector& shape, + sd::DataType dtype, sd::LaunchContext* context) { + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument( + "NDArrayFactory::create: rank of NDArray can't exceed 32"); -//////////////////////////////////////////////////////////////////////// - template - NDArray NDArrayFactory::create(const char order, const std::vector &shape, sd::LaunchContext * context) { - return create(order, shape, DataTypeUtils::fromT(), context); - } - BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::create, (const char order, const std::vector &shape, sd::LaunchContext * context), LIBND4J_TYPES); + ShapeDescriptor descriptor(dtype, order, shape); -//////////////////////////////////////////////////////////////////////// -NDArray NDArrayFactory::create(const char order, const std::vector &shape, sd::DataType dtype, sd::LaunchContext* context) { + std::shared_ptr buffer = std::make_shared( + descriptor.arrLength() * DataTypeUtils::sizeOfElement(dtype), dtype, + context->getWorkspace()); - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32"); + NDArray result(buffer, descriptor, context); - ShapeDescriptor descriptor(dtype, order, shape); + result.nullify(); - std::shared_ptr buffer = std::make_shared(descriptor.arrLength() * DataTypeUtils::sizeOfElement(dtype), dtype, context->getWorkspace()); + return result; +} - NDArray result(buffer, descriptor, context); +//////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::create(sd::DataType dtype, sd::LaunchContext* context) { + std::shared_ptr buffer = + std::make_shared(DataTypeUtils::sizeOfElement(dtype), dtype, + context->getWorkspace(), true); - result.nullify(); + NDArray res(buffer, ShapeDescriptor::scalarDescriptor(dtype), context); - return result; + res.nullify(); + + return res; } +NDArray* NDArrayFactory::create_(sd::DataType dtype, + sd::LaunchContext* context) { + auto result = new NDArray(); + *result = NDArrayFactory::create(dtype, context); + return result; +} //////////////////////////////////////////////////////////////////////// -NDArray NDArrayFactory::create(sd::DataType dtype, sd::LaunchContext * context) { - - std::shared_ptr buffer = std::make_shared(DataTypeUtils::sizeOfElement(dtype), dtype, context->getWorkspace(), true); +template +NDArray NDArrayFactory::create(const std::vector& values, + sd::LaunchContext* context) { + std::shared_ptr buffer = std::make_shared( + values.size() * sizeof(T), DataTypeUtils::fromT(), + context->getWorkspace(), true); - NDArray res(buffer, ShapeDescriptor::scalarDescriptor(dtype), context); + NDArray res(buffer, + ShapeDescriptor::vectorDescriptor(values.size(), + DataTypeUtils::fromT()), + context); - res.nullify(); + memcpyFromVector(res.buffer(), values); - return res; -} + res.tickWriteHost(); + res.syncToDevice(); -NDArray* NDArrayFactory::create_(sd::DataType dtype, sd::LaunchContext * context) { - auto result = new NDArray(); - *result = NDArrayFactory::create(dtype, context); - return result; + return res; } +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + const std::vector& values, sd::LaunchContext* context); //////////////////////////////////////////////////////////////////////// template -NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context) { +NDArray* NDArrayFactory::empty_(sd::LaunchContext* context) { + auto shapeInfo = ShapeBuilders::createScalarShapeInfo( + DataTypeUtils::fromT(), context->getWorkspace()); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); + auto result = new NDArray(nullptr, shapeInfo, context, false); + + RELEASE(shapeInfo, context->getWorkspace()); - std::shared_ptr buffer = std::make_shared(values.size() * sizeof(T), DataTypeUtils::fromT(), context->getWorkspace(), true); + return result; +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray* NDArrayFactory::empty_, + (sd::LaunchContext * context), LIBND4J_TYPES); - NDArray res(buffer, ShapeDescriptor::vectorDescriptor(values.size(), DataTypeUtils::fromT()), context); +NDArray* NDArrayFactory::empty_(sd::DataType dataType, + sd::LaunchContext* context) { + if (context == nullptr) context = sd::LaunchContext ::defaultContext(); - memcpyFromVector(res.buffer(), values); + auto shapeInfo = + ShapeBuilders::createScalarShapeInfo(dataType, context->getWorkspace()); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); + auto result = new NDArray(nullptr, shapeInfo, context, false); - res.tickWriteHost(); - res.syncToDevice(); + RELEASE(shapeInfo, context->getWorkspace()); - return res; + return result; } -template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(const std::vector &values, sd::LaunchContext * context); //////////////////////////////////////////////////////////////////////// - template - NDArray* NDArrayFactory::empty_(sd::LaunchContext * context) { - auto shapeInfo = ShapeBuilders::createScalarShapeInfo(DataTypeUtils::fromT(), context->getWorkspace()); - ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); - auto result = new NDArray(nullptr, shapeInfo, context, false); +template +NDArray NDArrayFactory::empty(sd::LaunchContext* context) { + return empty(DataTypeUtils::fromT(), context); +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::empty, + (sd::LaunchContext * context), LIBND4J_TYPES); - RELEASE(shapeInfo, context->getWorkspace()); +//////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::empty(sd::DataType dataType, + sd::LaunchContext* context) { + auto shapeInfo = + ShapeBuilders::createScalarShapeInfo(dataType, context->getWorkspace()); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); + NDArray result(nullptr, shapeInfo, context, false); - return result; - } - BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray* NDArrayFactory::empty_, (sd::LaunchContext * context), LIBND4J_TYPES); + RELEASE(shapeInfo, context->getWorkspace()); - NDArray* NDArrayFactory::empty_(sd::DataType dataType, sd::LaunchContext * context) { - if (context == nullptr) - context = sd::LaunchContext ::defaultContext(); + return result; +} - auto shapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, context->getWorkspace()); - ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); - auto result = new NDArray(nullptr, shapeInfo, context, false); +//////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::valueOf(const std::vector& shape, + const NDArray& value, const char order, + sd::LaunchContext* context) { + auto res = NDArrayFactory::create_(order, shape, value.dataType(), context); + res->assign(const_cast(value)); + return res; +} - RELEASE(shapeInfo, context->getWorkspace()); +//////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::create_(const char order, + const std::vector& shape, + sd::DataType dataType, + sd::LaunchContext* context) { + return new NDArray(order, shape, dataType, context); +} - return result; - } +//////////////////////////////////////////////////////////////////////// +template +NDArray NDArrayFactory::create(T* buffer, const char order, + const std::initializer_list& shape, + sd::LaunchContext* context) { + if ((int)shape.size() > MAX_RANK) + throw std::invalid_argument( + "NDArrayFactory::create: Rank of NDArray can't exceed 32"); - //////////////////////////////////////////////////////////////////////// - template - NDArray NDArrayFactory::empty(sd::LaunchContext * context) { - return empty(DataTypeUtils::fromT(), context); - } - BUILD_SINGLE_TEMPLATE(template SD_EXPORT NDArray NDArrayFactory::empty, (sd::LaunchContext * context), LIBND4J_TYPES); + std::vector shp(shape); + ShapeDescriptor descriptor(DataTypeUtils::fromT(), order, shp); - //////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::empty(sd::DataType dataType, sd::LaunchContext * context) { - auto shapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, context->getWorkspace()); - ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); - NDArray result(nullptr, shapeInfo, context, false); + std::shared_ptr pBuffer = std::make_shared( + buffer, descriptor.arrLength() * sizeof(T), descriptor.dataType(), false, + context->getWorkspace()); - RELEASE(shapeInfo, context->getWorkspace()); + NDArray result(pBuffer, descriptor, context); - return result; - } + return result; +} -//////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::valueOf(const std::vector& shape, const NDArray& value, const char order, sd::LaunchContext * context) { - auto res = NDArrayFactory::create_(order, shape, value.dataType(), context); - res->assign(const_cast(value)); - return res; - } +template SD_EXPORT NDArray NDArrayFactory::create( + double* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + float* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + float16* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + bfloat16* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + Nd4jLong* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + int* buffer, const char order, const std::initializer_list& shape, + sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + bool* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + uint8_t* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + int8_t* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); +template SD_EXPORT NDArray NDArrayFactory::create( + int16_t* buffer, const char order, + const std::initializer_list& shape, sd::LaunchContext* context); + +///////////////////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const char16_t* u16string, sd::DataType dtype, + sd::LaunchContext* context) { + return NDArray(u16string, dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const char16_t* u16string, sd::DataType dtype, + sd::LaunchContext* context) { + return string_(std::u16string(u16string), dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::u16string& u16string, + sd::DataType dtype, + sd::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray(u16string, dtype, context); + return res; +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::u16string& u16string, + sd::DataType dtype, sd::LaunchContext* context) { + return NDArray(u16string, dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const char32_t* u32string, sd::DataType dtype, + sd::LaunchContext* context) { + return NDArray(u32string, dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const char32_t* u32string, sd::DataType dtype, + sd::LaunchContext* context) { + return string_(std::u32string(u32string), dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::u32string& u32string, + sd::DataType dtype, + sd::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray(u32string, dtype, context); + return res; +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::u32string& u32string, + sd::DataType dtype, sd::LaunchContext* context) { + return NDArray(u32string, dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const char* str, sd::DataType dtype, + sd::LaunchContext* context) { + return NDArray(str, dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const char* str, sd::DataType dtype, + sd::LaunchContext* context) { + return string_(std::string(str), dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::string& str, sd::DataType dtype, + sd::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray(str, dtype, context); + return res; +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::string& str, sd::DataType dtype, + sd::LaunchContext* context) { + return NDArray(str, dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string( + const std::vector& shape, + const std::initializer_list& strings, sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, std::vector(strings), dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::vector& shape, + const std::vector& strings, + sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, strings, dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::vector& shape, + const std::initializer_list& string, + sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, std::vector(string), dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_( + const std::vector& shape, + const std::initializer_list& strings, sd::DataType dataType, + sd::LaunchContext* context) { + return NDArrayFactory::string_(shape, std::vector(strings), + dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::vector& shape, + const std::vector& strings, + sd::DataType dataType, + sd::LaunchContext* context) { + std::vector vec(strings.size()); + int cnt = 0; + for (auto s : strings) vec[cnt++] = std::string(s); + + return NDArrayFactory::string_(shape, vec, dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_( + const std::vector& shape, + const std::initializer_list& string, sd::DataType dataType, + sd::LaunchContext* context) { + return NDArrayFactory::string_(shape, std::vector(string), + dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::vector& shape, + const std::vector& string, + sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, string, dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::vector& shape, + const std::vector& string, + sd::DataType dataType, + sd::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray(shape, string, dataType, context); + return res; +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dataType, sd::LaunchContext* context) { + return NDArray(shape, std::vector(strings), dataType, + context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::vector& shape, + const std::vector& strings, + sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, strings, dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string( + const std::vector& shape, + const std::initializer_list& string, sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, std::vector(string), dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dataType, sd::LaunchContext* context) { + return NDArrayFactory::string_(shape, std::vector(strings), + dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::vector& shape, + const std::vector& strings, + sd::DataType dataType, + sd::LaunchContext* context) { + std::vector vec(strings.size()); + int cnt = 0; + for (auto s : strings) vec[cnt++] = std::u16string(s); + + return NDArrayFactory::string_(shape, vec, dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_( + const std::vector& shape, + const std::initializer_list& string, sd::DataType dataType, + sd::LaunchContext* context) { + return NDArrayFactory::string_(shape, std::vector(string), + dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::vector& shape, + const std::vector& string, + sd::DataType dataType, + sd::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray(shape, string, dataType, context); + return res; +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::vector& shape, + const std::vector& string, + sd::DataType dtype, sd::LaunchContext* context) { + return NDArray(shape, string, dtype, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dataType, sd::LaunchContext* context) { + return NDArray(shape, std::vector(strings), dataType, + context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::vector& shape, + const std::vector& strings, + sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, strings, dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string( + const std::vector& shape, + const std::initializer_list& string, sd::DataType dataType, + sd::LaunchContext* context) { + return NDArray(shape, std::vector(string), dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_( + const std::vector& shape, + const std::initializer_list& strings, + sd::DataType dataType, sd::LaunchContext* context) { + return NDArrayFactory::string_(shape, std::vector(strings), + dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::vector& shape, + const std::vector& strings, + sd::DataType dataType, + sd::LaunchContext* context) { + std::vector vec(strings.size()); + int cnt = 0; + for (auto s : strings) vec[cnt++] = std::u32string(s); + return NDArrayFactory::string_(shape, vec, dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_( + const std::vector& shape, + const std::initializer_list& string, sd::DataType dataType, + sd::LaunchContext* context) { + return NDArrayFactory::string_(shape, std::vector(string), + dataType, context); +} +///////////////////////////////////////////////////////////////////////// +NDArray* NDArrayFactory::string_(const std::vector& shape, + const std::vector& string, + sd::DataType dataType, + sd::LaunchContext* context) { + auto res = new NDArray(); + *res = NDArray(shape, string, dataType, context); + return res; +} +///////////////////////////////////////////////////////////////////////// +NDArray NDArrayFactory::string(const std::vector& shape, + const std::vector& string, + sd::DataType dtype, sd::LaunchContext* context) { + return NDArray(shape, string, dtype, context); +} -//////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::create_( const char order, const std::vector &shape, sd::DataType dataType, sd::LaunchContext * context) { +NDArray NDArrayFactory::fromNpyFile(const char* fileName) { + if (!FileUtils::fileExists(fileName)) + throw std::runtime_error("File doesn't exit"); - return new NDArray(order, shape, dataType, context); - } + auto pNPY = reinterpret_cast(::numpyFromFile(std::string(fileName))); -//////////////////////////////////////////////////////////////////////// -template -NDArray NDArrayFactory::create(T* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context) { - - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("NDArrayFactory::create: Rank of NDArray can't exceed 32"); - - std::vector shp(shape); - ShapeDescriptor descriptor(DataTypeUtils::fromT(), order, shp); - - std::shared_ptr pBuffer = std::make_shared(buffer, descriptor.arrLength() * sizeof(T), descriptor.dataType(), false, context->getWorkspace()); - - NDArray result(pBuffer, descriptor, context); - - return result; -} - -template SD_EXPORT NDArray NDArrayFactory::create(double* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(float* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(float16* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(bfloat16* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(Nd4jLong * buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(int* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(bool* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(uint8_t * buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(int8_t* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); -template SD_EXPORT NDArray NDArrayFactory::create(int16_t* buffer, const char order, const std::initializer_list& shape, sd::LaunchContext * context); - - ///////////////////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const char16_t* u16string, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray(u16string, dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_(const char16_t* u16string, sd::DataType dtype, sd::LaunchContext* context) { - return string_(std::u16string(u16string), dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_(const std::u16string& u16string, sd::DataType dtype, sd::LaunchContext* context) { - auto res = new NDArray(); - *res = NDArray(u16string, dtype, context); - return res; - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const std::u16string& u16string, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray(u16string, dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const char32_t* u32string, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray(u32string, dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_(const char32_t* u32string, sd::DataType dtype, sd::LaunchContext* context) { - return string_(std::u32string(u32string), dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_(const std::u32string& u32string, sd::DataType dtype, sd::LaunchContext* context) { - auto res = new NDArray(); - *res = NDArray(u32string, dtype, context); - return res; - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const std::u32string& u32string, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray(u32string, dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const char* str, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray(str, dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_(const char* str, sd::DataType dtype, sd::LaunchContext* context) { - return string_(std::string(str), dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) { - auto res = new NDArray(); - *res = NDArray(str, dtype, context); - return res; - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray(str, dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const std::vector &shape, const std::initializer_list &strings, sd::DataType dataType, sd::LaunchContext * context) { - return NDArray(shape, std::vector(strings), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector &shape, const std::vector &strings, sd::DataType dataType, sd::LaunchContext * context) { - return NDArray( shape, strings, dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector &shape, const std::initializer_list &string, sd::DataType dataType, sd::LaunchContext * context) { - return NDArray( shape, std::vector(string), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector &shape, const std::initializer_list &strings, sd::DataType dataType, sd::LaunchContext * context) { - return NDArrayFactory::string_( shape, std::vector(strings), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector &shape, const std::vector &strings, sd::DataType dataType, sd::LaunchContext * context) { - std::vector vec(strings.size()); - int cnt = 0; - for (auto s:strings) - vec[cnt++] = std::string(s); - - return NDArrayFactory::string_( shape, vec, dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector &shape, const std::initializer_list &string, sd::DataType dataType, sd::LaunchContext * context) { - return NDArrayFactory::string_( shape, std::vector(string), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector &shape, const std::vector &string, sd::DataType dataType, sd::LaunchContext * context) { - return NDArray(shape, string, dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_(const std::vector &shape, const std::vector &string, sd::DataType dataType, sd::LaunchContext * context) { - auto res = new NDArray(); - *res = NDArray( shape, string, dataType, context); - return res; - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const std::vector& shape, const std::initializer_list& strings, sd::DataType dataType, sd::LaunchContext* context) { - return NDArray( shape, std::vector(strings), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector& shape, const std::vector& strings, sd::DataType dataType, sd::LaunchContext* context) { - return NDArray( shape, strings, dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector& shape, const std::initializer_list& string, sd::DataType dataType, sd::LaunchContext* context) { - return NDArray( shape, std::vector(string), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::initializer_list& strings, sd::DataType dataType, sd::LaunchContext* context) { - return NDArrayFactory::string_( shape, std::vector(strings), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::vector& strings, sd::DataType dataType, sd::LaunchContext* context) { - std::vector vec(strings.size()); - int cnt = 0; - for (auto s : strings) - vec[cnt++] = std::u16string(s); - - return NDArrayFactory::string_( shape, vec, dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::initializer_list& string, sd::DataType dataType, sd::LaunchContext* context) { - return NDArrayFactory::string_( shape, std::vector(string), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::vector& string, sd::DataType dataType, sd::LaunchContext* context) { - auto res = new NDArray(); - *res = NDArray( shape, string, dataType, context); - return res; - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray( shape, string, dtype, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector& shape, const std::initializer_list& strings, sd::DataType dataType, sd::LaunchContext* context) { - return NDArray( shape, std::vector(strings), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector& shape, const std::vector& strings, sd::DataType dataType, sd::LaunchContext* context) { - return NDArray( shape, strings, dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string( const std::vector& shape, const std::initializer_list& string, sd::DataType dataType, sd::LaunchContext* context) { - return NDArray(shape, std::vector(string), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::initializer_list& strings, sd::DataType dataType, sd::LaunchContext* context) { - return NDArrayFactory::string_( shape, std::vector(strings), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::vector& strings, sd::DataType dataType, sd::LaunchContext* context) { - std::vector vec(strings.size()); - int cnt = 0; - for (auto s : strings) - vec[cnt++] = std::u32string(s); - return NDArrayFactory::string_( shape, vec, dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::initializer_list& string, sd::DataType dataType, sd::LaunchContext* context) { - return NDArrayFactory::string_( shape, std::vector(string), dataType, context); - } - ///////////////////////////////////////////////////////////////////////// - NDArray* NDArrayFactory::string_( const std::vector& shape, const std::vector& string, sd::DataType dataType, sd::LaunchContext* context) { - auto res = new NDArray(); - *res = NDArray( shape, string, dataType, context); - return res; - } - ///////////////////////////////////////////////////////////////////////// - NDArray NDArrayFactory::string(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - return NDArray( shape, string, dtype, context); - } - - - NDArray NDArrayFactory::fromNpyFile(const char *fileName) { - if (!FileUtils::fileExists(fileName)) - throw std::runtime_error("File doesn't exit"); - - auto pNPY = reinterpret_cast(::numpyFromFile(std::string(fileName))); - - auto nBuffer = reinterpret_cast(::dataPointForNumpy(pNPY)); - auto shape = reinterpret_cast(::shapeBufferForNumpy(pNPY)); - - auto length = shape::length(shape); - int8_t *buffer = nullptr; - sd::memory::Workspace *workspace = nullptr; - auto byteLen = length * DataTypeUtils::sizeOfElement(ArrayOptions::dataType(shape)); - - ALLOCATE(buffer, workspace, byteLen, int8_t); - memcpy(buffer, nBuffer, byteLen); - - free(pNPY); - - return NDArray(buffer, shape, LaunchContext::defaultContext(), true); - } + auto nBuffer = reinterpret_cast(::dataPointForNumpy(pNPY)); + auto shape = reinterpret_cast(::shapeBufferForNumpy(pNPY)); + + auto length = shape::length(shape); + int8_t* buffer = nullptr; + sd::memory::Workspace* workspace = nullptr; + auto byteLen = + length * DataTypeUtils::sizeOfElement(ArrayOptions::dataType(shape)); + + ALLOCATE(buffer, workspace, byteLen, int8_t); + memcpy(buffer, nBuffer, byteLen); + + free(pNPY); + + return NDArray(buffer, shape, LaunchContext::defaultContext(), true); } +} // namespace sd diff --git a/libnd4j/include/array/impl/NDArrayList.cpp b/libnd4j/include/array/impl/NDArrayList.cpp index 356390e0d369..4f17f4479cb4 100644 --- a/libnd4j/include/array/impl/NDArrayList.cpp +++ b/libnd4j/include/array/impl/NDArrayList.cpp @@ -18,270 +18,262 @@ // @author raver119@gmail.com // - -#include #include #include #include -#include - -namespace sd { - NDArrayList::InternalArrayList::InternalArrayList(int height, bool expandable) { - _expandable = expandable; - _elements.store(0); - _counter.store(0); - _id.first = 0; - _id.second = 0; - _height = height; - } - - NDArrayList::NDArrayList(const NDArrayList &other) { - _state = other._state; - } - - NDArrayList::NDArrayList(NDArrayList &&other) { - _state = std::move(other._state); - } - - NDArrayList &NDArrayList::operator=(const NDArrayList &other) noexcept { - if (this == &other) - return *this; - - _state = other._state; - - return *this; - } - - NDArrayList &NDArrayList::operator=(NDArrayList &&other) noexcept { - if (this == &other) - return *this; +#include - _state = std::move(other._state); - - return *this; - } - - NDArrayList::NDArrayList(int height, bool expandable) { - _state = std::make_shared(height, expandable); - } - - NDArrayList::~NDArrayList() { - - } - - NDArray NDArrayList::read(int idx) { - return readRaw(idx); - } - - sd::DataType NDArrayList::dataType() const { - return _state->_dtype; - } +#include - NDArray NDArrayList::readRaw(int idx) { - if (_state->_chunks.count(idx) < 1) { - nd4j_printf("Non-existent chunk requested: [%i]\n", idx); - throw std::invalid_argument("Bad index"); +namespace sd { +NDArrayList::InternalArrayList::InternalArrayList(int height, bool expandable) { + _expandable = expandable; + _elements.store(0); + _counter.store(0); + _id.first = 0; + _id.second = 0; + _height = height; +} + +NDArrayList::NDArrayList(const NDArrayList &other) { _state = other._state; } + +NDArrayList::NDArrayList(NDArrayList &&other) { + _state = std::move(other._state); +} + +NDArrayList &NDArrayList::operator=(const NDArrayList &other) noexcept { + if (this == &other) return *this; + + _state = other._state; + + return *this; +} + +NDArrayList &NDArrayList::operator=(NDArrayList &&other) noexcept { + if (this == &other) return *this; + + _state = std::move(other._state); + + return *this; +} + +NDArrayList::NDArrayList(int height, bool expandable) { + _state = std::make_shared(height, expandable); +} + +NDArrayList::~NDArrayList() {} + +NDArray NDArrayList::read(int idx) { return readRaw(idx); } + +sd::DataType NDArrayList::dataType() const { return _state->_dtype; } + +NDArray NDArrayList::readRaw(int idx) { + if (_state->_chunks.count(idx) < 1) { + nd4j_printf("Non-existent chunk requested: [%i]\n", idx); + throw std::invalid_argument("Bad index"); + } + + return _state->_chunks.at(idx); +} + +Nd4jStatus NDArrayList::write(int idx, const NDArray &array) { + if (_state->_chunks.count(idx) == 0) + _state->_elements++; + else { + _state->_chunks.erase(idx); + } + + // we store reference shape on first write + if (_state->_chunks.empty()) { + _state->_dtype = array.dataType(); + + if (_state->_shape.empty()) { + // adding leading 1 to shape + _state->_shape.emplace_back(1); + for (int e = 0; e < array.rankOf(); e++) + _state->_shape.emplace_back(array.sizeAt(e)); + } else { + // if shape is inferred (say, from split_list) + if (array.rankOf() == _state->_shape.size()) { + // skipping first dim + for (int e = 1; e < _state->_shape.size(); e++) { + if (_state->_shape[e] != array.sizeAt(e)) + return Status::CODE(ND4J_STATUS_BAD_INPUT, + "NDArrayList: all arrays must have same size " + "along inner dimensions"); } - - return _state->_chunks.at(idx); + } else if (array.rankOf() == _state->_shape.size() - 1) { + // case like 2d _shape, and 1D rows + for (int e = 1; e < _state->_shape.size(); e++) + if (_state->_shape[e] != array.sizeAt(e - 1)) + return Status::CODE(ND4J_STATUS_BAD_INPUT, + "NDArrayList: all arrays must have same size " + "along inner dimensions"); + } else + return Status::CODE(ND4J_STATUS_BAD_INPUT, + "NDArrayList: all arrays must have same size along " + "inner dimensions"); } - - Nd4jStatus NDArrayList::write(int idx, const NDArray &array) { - if (_state->_chunks.count(idx) == 0) - _state->_elements++; - else { - _state->_chunks.erase(idx); - } - - - // we store reference shape on first write - if (_state->_chunks.empty()) { - _state->_dtype = array.dataType(); - - if (_state->_shape.empty()) { - //adding leading 1 to shape - _state->_shape.emplace_back(1); - for (int e = 0; e < array.rankOf(); e++) - _state->_shape.emplace_back(array.sizeAt(e)); - } else { - // if shape is inferred (say, from split_list) - if (array.rankOf() == _state->_shape.size()) { - // skipping first dim - for (int e = 1; e < _state->_shape.size(); e++) { - if (_state->_shape[e] != array.sizeAt(e)) - return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); - } - } else if (array.rankOf() == _state->_shape.size() - 1) { - // case like 2d _shape, and 1D rows - for (int e = 1; e < _state->_shape.size(); e++) - if (_state->_shape[e] != array.sizeAt(e - 1)) - return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); - } else - return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); - - } + } else { + if (array.dataType() != _state->_dtype) + return Status::CODE(ND4J_STATUS_BAD_INPUT, + "NDArrayList: all arrays must have same data type"); + + // if shape is inferred (say, from split_list) + if (array.rankOf() == _state->_shape.size()) { + // skipping first dim + for (int e = 1; e < _state->_shape.size(); e++) { + if (_state->_shape[e] != array.sizeAt(e)) + return Status::CODE(ND4J_STATUS_BAD_INPUT, + "NDArrayList: all arrays must have same size " + "along inner dimensions"); + } + } else if (array.rankOf() == _state->_shape.size() - 1) { + // case like 2d _shape, and 1D rows + for (int e = 1; e < _state->_shape.size(); e++) + if (_state->_shape[e] != array.sizeAt(e - 1)) + return Status::CODE(ND4J_STATUS_BAD_INPUT, + "NDArrayList: all arrays must have same size " + "along inner dimensions"); + } else + return Status::CODE( + ND4J_STATUS_BAD_INPUT, + "NDArrayList: all arrays must have same size along inner dimensions"); + } + + // storing reference + _state->_chunks.insert({idx, array}); + + return Status::OK(); +} + +const std::vector &NDArrayList::shape() const { + return _state->_shape; +} + +int NDArrayList::counter() const { return _state->_counter++; } + +void NDArrayList::unstack(const NDArray &array, int axis) { + _state->_axis = axis; + std::vector args({axis}); + auto newAxis = ShapeUtils::evalDimsToExclude(array.rankOf(), args); + auto result = array.allTensorsAlongDimension(newAxis); + for (int e = 0; e < result.size(); e++) { + auto chunk = result.at(e); + write(e, chunk.dup(array.ordering())); + } +} + +void NDArrayList::setShape(const std::vector &shape) { + _state->_shape = shape; +} + +NDArray NDArrayList::stack() const { + // FIXME: this is bad for perf, but ok as poc + + int numElements = _state->_elements.load(); + std::vector inputs(numElements); + for (int e = 0; e < numElements; e++) { + _state->_chunks.at(e).syncToDevice(); + inputs[e] = &_state->_chunks.at(e); + } + + auto inShapeInfo = inputs[0]->shapeInfo(); + int rank = shape::rank(inShapeInfo); + NDArray array; + + if (shape::isEmpty(inShapeInfo)) { + switch (rank) { + case 0: { + if (numElements == 1) { + array = NDArray(inputs[0]->ordering(), {0}, + ArrayOptions::dataType(inShapeInfo), + inputs[0]->getContext()); } else { - if (array.dataType() != _state->_dtype) - return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same data type"); - - - // if shape is inferred (say, from split_list) - if (array.rankOf() == _state->_shape.size()) { - // skipping first dim - for (int e = 1; e < _state->_shape.size(); e++) { - if (_state->_shape[e] != array.sizeAt(e)) - return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); - } - } else if (array.rankOf() == _state->_shape.size() - 1) { - // case like 2d _shape, and 1D rows - for (int e = 1; e < _state->_shape.size(); e++) - if (_state->_shape[e] != array.sizeAt(e - 1)) - return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); - } else - return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); + array = NDArray('c', {(Nd4jLong)numElements, 0}, + ArrayOptions::dataType(inShapeInfo), + inputs[0]->getContext()); } - - // storing reference - _state->_chunks.insert({idx, array}); - - return Status::OK(); + } } + } else { + std::vector outShape(inShapeInfo + 1, inShapeInfo + 1 + rank); + outShape.insert(outShape.begin(), (Nd4jLong)numElements); + array = + NDArray(shape::order(inShapeInfo), outShape, + ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext()); + } - const std::vector& NDArrayList::shape() const { - return _state->_shape; - } + ops::helpers::stack(inputs[0]->getContext(), inputs, array, 0); - int NDArrayList::counter() const { - return _state->_counter++; - } + return array; +} - void NDArrayList::unstack(const NDArray &array, int axis) { - _state->_axis = axis; - std::vector args({axis}); - auto newAxis = ShapeUtils::evalDimsToExclude(array.rankOf(), args); - auto result = array.allTensorsAlongDimension(newAxis); - for (int e = 0; e < result.size(); e++) { - auto chunk = result.at(e); - write(e, chunk.dup(array.ordering())); - } - } +const std::pair &NDArrayList::id() const { return _state->_id; } - void NDArrayList::setShape(const std::vector &shape) { - _state->_shape = shape; - } +const std::string &NDArrayList::name() const { return _state->_name; } - NDArray NDArrayList::stack() const { - // FIXME: this is bad for perf, but ok as poc - - int numElements = _state->_elements.load(); - std::vector inputs(numElements); - for (int e = 0; e < numElements; e++) { - _state->_chunks.at(e).syncToDevice(); - inputs[e] = &_state->_chunks.at(e); - } +int NDArrayList::elements() const { return (int)_state->_chunks.size(); } - auto inShapeInfo = inputs[0]->shapeInfo(); - int rank = shape::rank(inShapeInfo); - NDArray array; - - if (shape::isEmpty(inShapeInfo)) { - switch (rank) { - case 0: { - if (numElements == 1) { - array = NDArray(inputs[0]->ordering(), {0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext()); - } else { - array = NDArray('c', {(Nd4jLong) numElements, 0}, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext() ) ; - } - } - } - } - else{ - std::vector outShape(inShapeInfo + 1, inShapeInfo + 1 + rank); - outShape.insert(outShape.begin(), (Nd4jLong) numElements); - array = NDArray( shape::order(inShapeInfo), outShape, ArrayOptions::dataType(inShapeInfo), inputs[0]->getContext()); - } - - ops::helpers::stack(inputs[0]->getContext(), inputs, array, 0); - - return array; - } +int NDArrayList::height() const { return (int)_state->_chunks.size(); } - const std::pair& NDArrayList::id() const { - return _state->_id; - } +bool NDArrayList::isWritten(int index) const { + if (_state->_chunks.count(index) > 0) + return true; + else + return false; +} - const std::string& NDArrayList::name() const { - return _state->_name; - } +NDArray NDArrayList::pick(const std::vector &indices) { + std::vector shape(_state->_shape); - int NDArrayList::elements() const { - return (int) _state->_chunks.size(); - } + // shape.insert(shape.begin() + _axis, indices.size()); + shape[_state->_axis] = indices.size(); + // do we have to enforce C order here? + NDArray array('c', shape, _state->_chunks.at(0).dataType()); + std::vector axis = + ShapeUtils::evalDimsToExclude(shape.size(), {_state->_axis}); + auto tads = array.allTensorsAlongDimension(axis); + int indicesSize = indices.size(); - int NDArrayList::height() const { - return (int) _state->_chunks.size(); - } + if (tads.size() != indicesSize) + throw std::runtime_error("Number of TADs should match number of indices"); - bool NDArrayList::isWritten(int index) const { - if (_state->_chunks.count(index) > 0) - return true; - else - return false; - } + for (int e = 0; e < indicesSize; e++) + tads.at(e).assign(_state->_chunks.at(indices[e])); - NDArray NDArrayList::pick(const std::vector &indices) { - std::vector shape(_state->_shape); + return array; +} - //shape.insert(shape.begin() + _axis, indices.size()); - shape[_state->_axis] = indices.size(); - // do we have to enforce C order here? - NDArray array('c', shape, _state->_chunks.at(0).dataType()); - std::vector axis = ShapeUtils::evalDimsToExclude(shape.size(), {_state->_axis}); - auto tads = array.allTensorsAlongDimension(axis); - int indicesSize = indices.size(); +NDArrayList NDArrayList::clone() { + NDArrayList list(_state->_height, _state->_expandable); + list._state->_axis = _state->_axis; + list._state->_id.first = _state->_id.first; + list._state->_id.second = _state->_id.second; + list._state->_name = _state->_name; + list._state->_elements.store(_state->_elements.load()); - if (tads.size() != indicesSize) - throw std::runtime_error("Number of TADs should match number of indices"); + for (auto const &v : _state->_chunks) { + list._state->_chunks.insert({v.first, v.second.dup()}); + } - for (int e = 0; e < indicesSize; e++) - tads.at(e).assign(_state->_chunks.at(indices[e])); + return list; +} - return array; - } +bool NDArrayList::equals(NDArrayList &other) { + if (_state->_axis != other._state->_axis) return false; - NDArrayList NDArrayList::clone() { - NDArrayList list(_state->_height, _state->_expandable); - list._state->_axis = _state->_axis; - list._state->_id.first = _state->_id.first; - list._state->_id.second = _state->_id.second; - list._state->_name = _state->_name; - list._state->_elements.store(_state->_elements.load()); + if (_state->_chunks.size() != other._state->_chunks.size()) return false; - for (auto const& v : _state->_chunks) { - list._state->_chunks.insert({v.first, v.second.dup()}); - } + for (auto const &v : _state->_chunks) { + if (other._state->_chunks.count(v.first) == 0) return false; - return list; - } + auto arrThis = _state->_chunks.at(v.first); + auto arrThat = other._state->_chunks.at(v.first); - bool NDArrayList::equals(NDArrayList& other) { - if (_state->_axis != other._state->_axis) - return false; + if (!arrThis.equalsTo(arrThat)) return false; + } - if (_state->_chunks.size() != other._state->_chunks.size()) - return false; - - for (auto const& v : _state->_chunks) { - if (other._state->_chunks.count(v.first) == 0) - return false; - - auto arrThis = _state->_chunks.at(v.first); - auto arrThat = other._state->_chunks.at(v.first); - - if (!arrThis.equalsTo(arrThat)) - return false; - } - - return true; - } -} \ No newline at end of file + return true; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/array/impl/ResultSet.cpp b/libnd4j/include/array/impl/ResultSet.cpp index 53cde399e6b8..41c218954d04 100644 --- a/libnd4j/include/array/impl/ResultSet.cpp +++ b/libnd4j/include/array/impl/ResultSet.cpp @@ -22,133 +22,115 @@ #include namespace sd { - ResultSet::ResultSet() { - // - } - - ResultSet::ResultSet(const sd::graph::FlatResult* result) { - for (int e = 0; e < result->variables()->size(); e++) { - auto var = result->variables()->Get(e); +ResultSet::ResultSet() { + // +} - NDArray array; +ResultSet::ResultSet(const sd::graph::FlatResult* result) { + for (int e = 0; e < result->variables()->size(); e++) { + auto var = result->variables()->Get(e); - if (var->ndarray() != nullptr) { - array = sd::graph::FlatUtils::fromFlatArray(var->ndarray()); - } else if (var->shape() != nullptr) { - std::vector shapeInfo; - for (int i = 0; i < var->shape()->size(); i++) { - shapeInfo.emplace_back(var->shape()->Get(i)); - } + NDArray array; - // we just create empty array here - int s0 = shapeInfo.at(0); + if (var->ndarray() != nullptr) { + array = sd::graph::FlatUtils::fromFlatArray(var->ndarray()); + } else if (var->shape() != nullptr) { + std::vector shapeInfo; + for (int i = 0; i < var->shape()->size(); i++) { + shapeInfo.emplace_back(var->shape()->Get(i)); + } - std::vector shape; - for (int i = 0; i < s0; i++) { - shape.emplace_back(shapeInfo.at(i + 1)); - } + // we just create empty array here + int s0 = shapeInfo.at(0); - array = NDArray((char) shapeInfo.at(shapeInfo.size() - 1), shape, DataTypeUtils::fromFlatDataType(var->dtype())); - } else { - nd4j_printf("Either shape or NDArray should be defined in FlatResult variable\n",""); - throw std::runtime_error("Empty variable"); - } + std::vector shape; + for (int i = 0; i < s0; i++) { + shape.emplace_back(shapeInfo.at(i + 1)); + } - _content.push_back(array); - } + array = NDArray((char)shapeInfo.at(shapeInfo.size() - 1), shape, + DataTypeUtils::fromFlatDataType(var->dtype())); + } else { + nd4j_printf( + "Either shape or NDArray should be defined in FlatResult variable\n", + ""); + throw std::runtime_error("Empty variable"); } - ResultSet::ResultSet(const ResultSet& other) noexcept{ - for (const auto v:other._content) - _content.emplace_back(v); + _content.push_back(array); + } +} - _status = other._status; - _removable = false; - } +ResultSet::ResultSet(const ResultSet& other) noexcept { + for (const auto v : other._content) _content.emplace_back(v); - //////////////////////////////////////////////////////////////////////// - // move constructor - ResultSet::ResultSet(ResultSet&& other) noexcept { + _status = other._status; + _removable = false; +} - _content = std::move(other._content); - _status = other._status; - _removable = other._removable; - other._removable = false; - } +//////////////////////////////////////////////////////////////////////// +// move constructor +ResultSet::ResultSet(ResultSet&& other) noexcept { + _content = std::move(other._content); + _status = other._status; + _removable = other._removable; + other._removable = false; +} - //////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////// // move assignment operator - ResultSet& ResultSet::operator=(ResultSet&& other) noexcept { +ResultSet& ResultSet::operator=(ResultSet&& other) noexcept { + if (this == &other) return *this; - if (this == &other) - return *this; + delContent(); - delContent(); + _content = std::move(other._content); - _content = std::move(other._content); + _status = other._status; + _removable = other._removable; + other._removable = false; - _status = other._status; - _removable = other._removable; - other._removable = false; - - return *this; - } - - ResultSet& ResultSet::operator=(const ResultSet& other) noexcept { - - if (this == &other) - return *this; - - delContent(); + return *this; +} - for (const auto v : other._content) - _content.push_back(v); +ResultSet& ResultSet::operator=(const ResultSet& other) noexcept { + if (this == &other) return *this; - _status = other._status; - _removable = false; + delContent(); - return *this; - } + for (const auto v : other._content) _content.push_back(v); - void ResultSet::delContent() { - // - } + _status = other._status; + _removable = false; - ResultSet::~ResultSet() { + return *this; +} - delContent(); - } +void ResultSet::delContent() { + // +} - void ResultSet::setNonRemovable() { - _removable = false; - } +ResultSet::~ResultSet() { delContent(); } - int ResultSet::size() { - return (int) _content.size(); - } +void ResultSet::setNonRemovable() { _removable = false; } - sd::NDArray& ResultSet::at(const unsigned long idx) const { - return const_cast(_content[idx]); - } +int ResultSet::size() { return (int)_content.size(); } - sd::NDArray& ResultSet::operator[](const unsigned long idx) const { - return const_cast(_content[idx]); - } +sd::NDArray& ResultSet::at(const unsigned long idx) const { + return const_cast(_content[idx]); +} - void ResultSet::push_back(const sd::NDArray &array) { - _content.emplace_back(array); - } +sd::NDArray& ResultSet::operator[](const unsigned long idx) const { + return const_cast(_content[idx]); +} - Nd4jStatus ResultSet::status() { - return _status; - } +void ResultSet::push_back(const sd::NDArray& array) { + _content.emplace_back(array); +} - void ResultSet::setStatus(Nd4jStatus status) { - _status = status; - } +Nd4jStatus ResultSet::status() { return _status; } - void ResultSet::purge() { - _content.clear(); - } -} +void ResultSet::setStatus(Nd4jStatus status) { _status = status; } +void ResultSet::purge() { _content.clear(); } +} // namespace sd diff --git a/libnd4j/include/array/impl/ShapeDescriptor.cpp b/libnd4j/include/array/impl/ShapeDescriptor.cpp index 3ef096312d35..a097ce276a08 100644 --- a/libnd4j/include/array/impl/ShapeDescriptor.cpp +++ b/libnd4j/include/array/impl/ShapeDescriptor.cpp @@ -19,362 +19,337 @@ // #include -#include #include +#include namespace sd { ////////////////////////////////////////////////////////////////////////// // equal to operator - bool ShapeDescriptor::operator==(const ShapeDescriptor &other) const { - - if (_empty != other._empty) - return false; - if (_rank != other._rank) - return false; - if (_order != other._order) - return false; - if (_dataType != other._dataType) - return false; - if (_ews != other._ews) - return false; - - if (_shape != other._shape) - return false; - - if (_strides != other._strides) - return false; - - return true; - } +bool ShapeDescriptor::operator==(const ShapeDescriptor &other) const { + if (_empty != other._empty) return false; + if (_rank != other._rank) return false; + if (_order != other._order) return false; + if (_dataType != other._dataType) return false; + if (_ews != other._ews) return false; + + if (_shape != other._shape) return false; + + if (_strides != other._strides) return false; + + return true; +} ////////////////////////////////////////////////////////////////////////// // less than operator - bool ShapeDescriptor::operator<(const ShapeDescriptor &other) const { - return std::tie(_empty, _rank, _dataType, _ews, _order, _shape, _strides) < - std::tie(other._empty, other._rank, other._dataType, other._ews, other._order, other._shape, - other._strides); +bool ShapeDescriptor::operator<(const ShapeDescriptor &other) const { + return std::tie(_empty, _rank, _dataType, _ews, _order, _shape, _strides) < + std::tie(other._empty, other._rank, other._dataType, other._ews, + other._order, other._shape, other._strides); +} + +Nd4jLong *ShapeDescriptor::toShapeInfo() const { + if (_empty) { + if (_rank == 0) + return ShapeBuilders::emptyShapeInfo(_dataType); + else { + return ShapeBuilders::emptyShapeInfo(_dataType, _order, _shape); } + } - Nd4jLong *ShapeDescriptor::toShapeInfo() const { - if (_empty) { - if (_rank == 0) - return ShapeBuilders::emptyShapeInfo(_dataType); - else { - return ShapeBuilders::emptyShapeInfo(_dataType, _order, _shape); - } - } - - - switch (_rank) { - case 0: { - auto shapeInfo = ShapeBuilders::createScalarShapeInfo(_dataType); - shapeInfo[2] = _ews; - return shapeInfo; - } - case 1: { - auto shapeInfo = ShapeBuilders::createVectorShapeInfo(_dataType, _shape[0]); - shapeInfo[2 + _rank * 2] = _ews; - shapeInfo[2] = _strides[0]; - shapeInfo[2 + _rank * 2 + 1] = _order; - return shapeInfo; - } - default: { - auto shapeInfo = ShapeBuilders::createShapeInfo(_dataType, _order, _shape); - - for (int e = 0; e < _rank; e++) - shapeInfo[e + 1 + _rank] = _strides[e]; - - shapeInfo[2 + _rank * 2] = _ews; - - return shapeInfo; - } - } + switch (_rank) { + case 0: { + auto shapeInfo = ShapeBuilders::createScalarShapeInfo(_dataType); + shapeInfo[2] = _ews; + return shapeInfo; } + case 1: { + auto shapeInfo = + ShapeBuilders::createVectorShapeInfo(_dataType, _shape[0]); + shapeInfo[2 + _rank * 2] = _ews; + shapeInfo[2] = _strides[0]; + shapeInfo[2 + _rank * 2 + 1] = _order; + return shapeInfo; + } + default: { + auto shapeInfo = + ShapeBuilders::createShapeInfo(_dataType, _order, _shape); - ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, const int rank) - : _dataType(type), _order(order), _rank(rank), _ews(1) { - _shape.resize(rank); - _strides.resize(rank); + for (int e = 0; e < _rank; e++) shapeInfo[e + 1 + _rank] = _strides[e]; - for (int e = 0; e < rank; e++) - _shape[e] = shape[e]; + shapeInfo[2 + _rank * 2] = _ews; - if (order == 'c') - shape::calcStrides(_shape.data(), _shape.size(), _strides.data()); - else - shape::calcStridesFortran(_shape.data(), _shape.size(), _strides.data()); + return shapeInfo; + } + } +} +ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, + const Nd4jLong *shape, const int rank) + : _dataType(type), _order(order), _rank(rank), _ews(1) { + _shape.resize(rank); + _strides.resize(rank); - for (auto v:_shape) { - if (v == 0) { - _empty = true; - break; - } - } - } + for (int e = 0; e < rank; e++) _shape[e] = shape[e]; - ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, - const Nd4jLong *strides, const int rank, Nd4jLong ews, const bool empty) { - _shape.resize(rank); - _strides.resize(rank); + if (order == 'c') + shape::calcStrides(_shape.data(), _shape.size(), _strides.data()); + else + shape::calcStridesFortran(_shape.data(), _shape.size(), _strides.data()); - _dataType = type; - _order = order; - _rank = rank; - _empty = empty; - _ews = ews; + for (auto v : _shape) { + if (v == 0) { + _empty = true; + break; + } + } +} + +ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, + const Nd4jLong *shape, const Nd4jLong *strides, + const int rank, Nd4jLong ews, + const bool empty) { + _shape.resize(rank); + _strides.resize(rank); - for (int e = 0; e < rank; e++) - _shape[e] = shape[e]; + _dataType = type; + _order = order; + _rank = rank; + _empty = empty; + _ews = ews; - for (int e = 0; e < rank; e++) - _strides[e] = strides[e]; + for (int e = 0; e < rank; e++) _shape[e] = shape[e]; + for (int e = 0; e < rank; e++) _strides[e] = strides[e]; - for (auto v:_shape) { - if (v == 0) { - _empty = true; - break; - } - } + for (auto v : _shape) { + if (v == 0) { + _empty = true; + break; } + } +} ////////////////////////////////////////////////////////////////////////// - ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector &shape) - : _dataType(type), _order(order), _shape(shape) { - _rank = shape.size(); - _ews = 1; - - if (_rank > 0) { - _strides.resize(_rank); - - for (auto v:_shape) { - if (v == 0) { - _empty = true; - break; - } - } - - // no point calculating strides for empty arrays - if (!_empty) { - if (order == 'c') - shape::calcStrides(_shape.data(), shape.size(), _strides.data()); - else - shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); - } else { - // all strides set to 0 - memset(_strides.data(), 0, sizeof(Nd4jLong) * shape.size()); - } - } +ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, + const std::vector &shape) + : _dataType(type), _order(order), _shape(shape) { + _rank = shape.size(); + _ews = 1; + + if (_rank > 0) { + _strides.resize(_rank); + + for (auto v : _shape) { + if (v == 0) { + _empty = true; + break; + } } -////////////////////////////////////////////////////////////////////////// - ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, - const std::initializer_list &shape) : _dataType(type), _order(order), - _shape(shape) { - _rank = shape.size(); - _ews = 1; - - _strides.resize(shape.size()); - if (order == 'c') - shape::calcStrides(_shape.data(), shape.size(), _strides.data()); - else - shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); - - for (auto v:_shape) { - if (v == 0) { - _empty = true; - break; - } - } + // no point calculating strides for empty arrays + if (!_empty) { + if (order == 'c') + shape::calcStrides(_shape.data(), shape.size(), _strides.data()); + else + shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); + } else { + // all strides set to 0 + memset(_strides.data(), 0, sizeof(Nd4jLong) * shape.size()); } + } +} ////////////////////////////////////////////////////////////////////////// - ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector &shape, - const std::vector &strides, const Nd4jLong ews) : ShapeDescriptor(type, - order, - shape, - strides) { - _ews = ews; +ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, + const std::initializer_list &shape) + : _dataType(type), _order(order), _shape(shape) { + _rank = shape.size(); + _ews = 1; + + _strides.resize(shape.size()); + if (order == 'c') + shape::calcStrides(_shape.data(), shape.size(), _strides.data()); + else + shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); + + for (auto v : _shape) { + if (v == 0) { + _empty = true; + break; } + } +} - ShapeDescriptor::ShapeDescriptor(const DataType type, const Nd4jLong length) : _dataType(type), _ews(1), - _order('c'), _rank(1), - _empty(false) { - _shape = {length}; - _strides = {1}; - } +////////////////////////////////////////////////////////////////////////// +ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, + const std::vector &shape, + const std::vector &strides, + const Nd4jLong ews) + : ShapeDescriptor(type, order, shape, strides) { + _ews = ews; +} - ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype) { - _order = shape::order(shapeInfo); - _ews = shape::elementWiseStride(shapeInfo); - _rank = shape::rank(shapeInfo); +ShapeDescriptor::ShapeDescriptor(const DataType type, const Nd4jLong length) + : _dataType(type), _ews(1), _order('c'), _rank(1), _empty(false) { + _shape = {length}; + _strides = {1}; +} - if (inheritDtype) - _dataType = ArrayOptions::dataType(shapeInfo); +ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype) { + _order = shape::order(shapeInfo); + _ews = shape::elementWiseStride(shapeInfo); + _rank = shape::rank(shapeInfo); - _empty = shape::isEmpty(shapeInfo); + if (inheritDtype) _dataType = ArrayOptions::dataType(shapeInfo); - for (int e = 0; e < _rank; e++) { - _shape.emplace_back(shapeInfo[e + 1]); - if (shapeInfo[e + 1] == 0) - _empty = true; - } + _empty = shape::isEmpty(shapeInfo); - for (int e = 0; e < _rank; e++) - _strides.emplace_back(shapeInfo[e + 1 + _rank]); - } + for (int e = 0; e < _rank; e++) { + _shape.emplace_back(shapeInfo[e + 1]); + if (shapeInfo[e + 1] == 0) _empty = true; + } - ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const sd::DataType dtypeOverride) - : ShapeDescriptor::ShapeDescriptor(shapeInfo, false) { - _dataType = dtypeOverride; - } + for (int e = 0; e < _rank; e++) + _strides.emplace_back(shapeInfo[e + 1 + _rank]); +} - ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride) - : ShapeDescriptor::ShapeDescriptor(shapeInfo, ArrayOptions::dataType(dtypeOverride)) { - // - } +ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, + const sd::DataType dtypeOverride) + : ShapeDescriptor::ShapeDescriptor(shapeInfo, false) { + _dataType = dtypeOverride; +} - ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride, - const Nd4jLong *orderOverride) : ShapeDescriptor::ShapeDescriptor(shapeInfo, - ArrayOptions::dataType( - dtypeOverride)) { - _order = shape::order(orderOverride); - } +ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, + const Nd4jLong *dtypeOverride) + : ShapeDescriptor::ShapeDescriptor(shapeInfo, + ArrayOptions::dataType(dtypeOverride)) { + // +} - int ShapeDescriptor::rank() const { - return _rank; - } +ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, + const Nd4jLong *dtypeOverride, + const Nd4jLong *orderOverride) + : ShapeDescriptor::ShapeDescriptor(shapeInfo, + ArrayOptions::dataType(dtypeOverride)) { + _order = shape::order(orderOverride); +} - Nd4jLong ShapeDescriptor::ews() const { - return _ews; - } +int ShapeDescriptor::rank() const { return _rank; } - Nd4jLong ShapeDescriptor::arrLength() const { +Nd4jLong ShapeDescriptor::ews() const { return _ews; } - Nd4jLong len = 1; - for (const auto &dim : const_cast(this)->shape()) - len *= dim; - return len; - } +Nd4jLong ShapeDescriptor::arrLength() const { + Nd4jLong len = 1; + for (const auto &dim : const_cast(this)->shape()) + len *= dim; + return len; +} - char ShapeDescriptor::order() const { - return _order; - } +char ShapeDescriptor::order() const { return _order; } - DataType ShapeDescriptor::dataType() const { - return _dataType; - } +DataType ShapeDescriptor::dataType() const { return _dataType; } - bool ShapeDescriptor::isEmpty() const { - return _empty; - } +bool ShapeDescriptor::isEmpty() const { return _empty; } - std::vector &ShapeDescriptor::shape() { - return _shape; - } +std::vector &ShapeDescriptor::shape() { return _shape; } - std::vector &ShapeDescriptor::strides() { - return _strides; - } +std::vector &ShapeDescriptor::strides() { return _strides; } - ShapeDescriptor::ShapeDescriptor(const ShapeDescriptor &other) { - _rank = other._rank; - _ews = other._ews; - _empty = other._empty; - _dataType = other._dataType; - _order = other._order; - _shape = other._shape; - _strides = other._strides; - } +ShapeDescriptor::ShapeDescriptor(const ShapeDescriptor &other) { + _rank = other._rank; + _ews = other._ews; + _empty = other._empty; + _dataType = other._dataType; + _order = other._order; + _shape = other._shape; + _strides = other._strides; +} ////////////////////////////////////////////////////////////////////////// - ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector &shape, - const std::vector &strides) : _dataType(type), _order(order), - _shape(shape) { - - if (strides.empty() && !shape.empty()) { - _strides.resize(shape.size()); - if (order == 'c') - shape::calcStrides(_shape.data(), shape.size(), _strides.data()); - else - shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); - } else { - _strides = strides; - } - - - for (auto v:_shape) { - if (v == 0) { - _empty = true; - break; - } - } +ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, + const std::vector &shape, + const std::vector &strides) + : _dataType(type), _order(order), _shape(shape) { + if (strides.empty() && !shape.empty()) { + _strides.resize(shape.size()); + if (order == 'c') + shape::calcStrides(_shape.data(), shape.size(), _strides.data()); + else + shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); + } else { + _strides = strides; + } + + for (auto v : _shape) { + if (v == 0) { + _empty = true; + break; } + } +} - ShapeDescriptor ShapeDescriptor::emptyDescriptor(const DataType type) { - ShapeDescriptor descriptor; - descriptor._dataType = type; - descriptor._empty = true; - descriptor._rank = 0; - descriptor._order = 'c'; - descriptor._ews = 1; +ShapeDescriptor ShapeDescriptor::emptyDescriptor(const DataType type) { + ShapeDescriptor descriptor; + descriptor._dataType = type; + descriptor._empty = true; + descriptor._rank = 0; + descriptor._order = 'c'; + descriptor._ews = 1; - return descriptor; - } + return descriptor; +} - ShapeDescriptor ShapeDescriptor::scalarDescriptor(const DataType type) { - ShapeDescriptor descriptor; - descriptor._dataType = type; - descriptor._empty = false; - descriptor._rank = 0; - descriptor._order = 'c'; - descriptor._ews = 1; +ShapeDescriptor ShapeDescriptor::scalarDescriptor(const DataType type) { + ShapeDescriptor descriptor; + descriptor._dataType = type; + descriptor._empty = false; + descriptor._rank = 0; + descriptor._order = 'c'; + descriptor._ews = 1; - return descriptor; - } + return descriptor; +} - ShapeDescriptor ShapeDescriptor::vectorDescriptor(const Nd4jLong length, const DataType type) { - ShapeDescriptor descriptor; - descriptor._dataType = type; - descriptor._shape.emplace_back(length); +ShapeDescriptor ShapeDescriptor::vectorDescriptor(const Nd4jLong length, + const DataType type) { + ShapeDescriptor descriptor; + descriptor._dataType = type; + descriptor._shape.emplace_back(length); - if (length > 0) - descriptor._strides.emplace_back(1); - else { - descriptor._strides.emplace_back(0); - descriptor._empty = true; - } + if (length > 0) + descriptor._strides.emplace_back(1); + else { + descriptor._strides.emplace_back(0); + descriptor._empty = true; + } - descriptor._order = 'c'; - descriptor._ews = 1; - descriptor._rank = 1; + descriptor._order = 'c'; + descriptor._ews = 1; + descriptor._rank = 1; - return descriptor; - } + return descriptor; } +} // namespace sd namespace std { - size_t hash::operator()(const sd::ShapeDescriptor &k) const { - auto res = std::hash()(k.arrLength()); - res ^= std::hash()(k.order()) + 0x9e3779b9 + (res << 6) + (res >> 2); - res ^= k.dataType() + 0x9e3779b9 + (res << 6) + (res >> 2); - res ^= std::hash()(k.rank()) + 0x9e3779b9 + (res << 6) + (res >> 2); - res ^= std::hash()(k.ews()) + 0x9e3779b9 + (res << 6) + (res >> 2); - auto shapes = const_cast(k).shape(); - auto strides = const_cast(k).strides(); - for (auto s: shapes) { - res ^= std::hash()(s) + 0x9e3779b9 + (res << 6) + (res >> 2); - } - - for (auto s: strides) { - res ^= std::hash()(s) + 0x9e3779b9 + (res << 6) + (res >> 2); - } - - return res; - } +size_t hash::operator()( + const sd::ShapeDescriptor &k) const { + auto res = std::hash()(k.arrLength()); + res ^= std::hash()(k.order()) + 0x9e3779b9 + (res << 6) + (res >> 2); + res ^= k.dataType() + 0x9e3779b9 + (res << 6) + (res >> 2); + res ^= std::hash()(k.rank()) + 0x9e3779b9 + (res << 6) + (res >> 2); + res ^= std::hash()(k.ews()) + 0x9e3779b9 + (res << 6) + (res >> 2); + auto shapes = const_cast(k).shape(); + auto strides = const_cast(k).strides(); + for (auto s : shapes) { + res ^= std::hash()(s) + 0x9e3779b9 + (res << 6) + (res >> 2); + } + + for (auto s : strides) { + res ^= std::hash()(s) + 0x9e3779b9 + (res << 6) + (res >> 2); + } + + return res; } - - - +} // namespace std diff --git a/libnd4j/include/array/impl/ShapeList.cpp b/libnd4j/include/array/impl/ShapeList.cpp index d26132516227..ce610e23805e 100644 --- a/libnd4j/include/array/impl/ShapeList.cpp +++ b/libnd4j/include/array/impl/ShapeList.cpp @@ -18,69 +18,61 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - //ShapeList::ShapeList(bool autoRemovable) { +// ShapeList::ShapeList(bool autoRemovable) { // _autoremovable = autoRemovable; // } - ShapeList::ShapeList(const Nd4jLong* shape) { - if (shape != nullptr) - _shapes.push_back(shape); - } +ShapeList::ShapeList(const Nd4jLong* shape) { + if (shape != nullptr) _shapes.push_back(shape); +} - ShapeList::~ShapeList() { - if (_autoremovable) - destroy(); - } +ShapeList::~ShapeList() { + if (_autoremovable) destroy(); +} - ShapeList::ShapeList(const std::vector &shapes, bool isWorkspace) : ShapeList(shapes){ - _workspace = isWorkspace; - } +ShapeList::ShapeList(const std::vector& shapes, + bool isWorkspace) + : ShapeList(shapes) { + _workspace = isWorkspace; +} - ShapeList::ShapeList(const std::vector& shapes) { - _shapes = shapes; - } +ShapeList::ShapeList(const std::vector& shapes) { + _shapes = shapes; +} - std::vector* ShapeList::asVector() { - return &_shapes; - } +std::vector* ShapeList::asVector() { return &_shapes; } - void ShapeList::destroy() { - if (_destroyed) - return; +void ShapeList::destroy() { + if (_destroyed) return; - if (!_workspace) - for (auto v:_shapes) - if(v != nullptr) - delete[] v; + if (!_workspace) + for (auto v : _shapes) + if (v != nullptr) delete[] v; - _destroyed = true; - } + _destroyed = true; +} - int ShapeList::size() const { - return (int) _shapes.size(); - } +int ShapeList::size() const { return (int)_shapes.size(); } - const Nd4jLong* ShapeList::at(int idx) { - if (_shapes.size() <= idx) - throw std::runtime_error("Can't find requested variable by index"); +const Nd4jLong* ShapeList::at(int idx) { + if (_shapes.size() <= idx) + throw std::runtime_error("Can't find requested variable by index"); - return _shapes.at(idx); - } + return _shapes.at(idx); +} - void ShapeList::push_back(const Nd4jLong *shape) { - _shapes.push_back(shape); - } +void ShapeList::push_back(const Nd4jLong* shape) { _shapes.push_back(shape); } - void ShapeList::detach() { - for (int e = 0; e < _shapes.size(); e++) { - _shapes[e] = shape::detachShape(_shapes[e]); - } +void ShapeList::detach() { + for (int e = 0; e < _shapes.size(); e++) { + _shapes[e] = shape::detachShape(_shapes[e]); + } - _autoremovable = true; - _workspace = false; - } -} \ No newline at end of file + _autoremovable = true; + _workspace = false; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/array/impl/TadDescriptor.cpp b/libnd4j/include/array/impl/TadDescriptor.cpp index e2ec7480ee8f..8ceb6e2d4bb3 100644 --- a/libnd4j/include/array/impl/TadDescriptor.cpp +++ b/libnd4j/include/array/impl/TadDescriptor.cpp @@ -18,77 +18,74 @@ // @author raver119@gmail.com // +#include "../TadDescriptor.h" #include -#include "../TadDescriptor.h" namespace sd { - TadDescriptor::TadDescriptor(const TadDescriptor &other) { - _originalShape = other._originalShape; - _axis = other._axis; - _unitiesInShape = other._unitiesInShape; - } - - TadDescriptor::TadDescriptor(const Nd4jLong *originalShape, const int *dimensions, const int length, const bool keepUnitiesInShape) { - ShapeDescriptor descriptor(originalShape); - - _axis.resize(length); - for (int e = 0; e < length; e++) - _axis[e] = dimensions[e]; +TadDescriptor::TadDescriptor(const TadDescriptor &other) { + _originalShape = other._originalShape; + _axis = other._axis; + _unitiesInShape = other._unitiesInShape; +} - if (length > 1) - std::sort(_axis.begin(), _axis.end()); +TadDescriptor::TadDescriptor(const Nd4jLong *originalShape, + const int *dimensions, const int length, + const bool keepUnitiesInShape) { + ShapeDescriptor descriptor(originalShape); - _originalShape = descriptor; - _unitiesInShape = keepUnitiesInShape; - } + _axis.resize(length); + for (int e = 0; e < length; e++) _axis[e] = dimensions[e]; - TadDescriptor::TadDescriptor(const ShapeDescriptor &descriptor, const std::vector &dimensions, const bool keepUnitiesInShape) { - _originalShape = descriptor; - _axis = dimensions; - _unitiesInShape = keepUnitiesInShape; + if (length > 1) std::sort(_axis.begin(), _axis.end()); - if (_axis.size() > 1) - std::sort(_axis.begin(), _axis.end()); - } + _originalShape = descriptor; + _unitiesInShape = keepUnitiesInShape; +} - bool TadDescriptor::operator==(const TadDescriptor &other) const { - return std::tie(_originalShape, _axis, _unitiesInShape) == std::tie(other._originalShape, other._axis, other._unitiesInShape); - } +TadDescriptor::TadDescriptor(const ShapeDescriptor &descriptor, + const std::vector &dimensions, + const bool keepUnitiesInShape) { + _originalShape = descriptor; + _axis = dimensions; + _unitiesInShape = keepUnitiesInShape; + if (_axis.size() > 1) std::sort(_axis.begin(), _axis.end()); +} - bool TadDescriptor::operator<(const TadDescriptor &other) const { - return std::tie(_originalShape, _axis, _unitiesInShape) < std::tie(other._originalShape, other._axis, other._unitiesInShape); - } +bool TadDescriptor::operator==(const TadDescriptor &other) const { + return std::tie(_originalShape, _axis, _unitiesInShape) == + std::tie(other._originalShape, other._axis, other._unitiesInShape); +} - std::vector& TadDescriptor::axis() { - return _axis; - } +bool TadDescriptor::operator<(const TadDescriptor &other) const { + return std::tie(_originalShape, _axis, _unitiesInShape) < + std::tie(other._originalShape, other._axis, other._unitiesInShape); +} - ShapeDescriptor& TadDescriptor::originalShape(){ - return _originalShape; - } +std::vector &TadDescriptor::axis() { return _axis; } - ShapeDescriptor const& TadDescriptor::originalShapeConst() const{ - return _originalShape; - } +ShapeDescriptor &TadDescriptor::originalShape() { return _originalShape; } - bool TadDescriptor::areUnitiesinShape() const { - return _unitiesInShape; - } +ShapeDescriptor const &TadDescriptor::originalShapeConst() const { + return _originalShape; } +bool TadDescriptor::areUnitiesinShape() const { return _unitiesInShape; } +} // namespace sd + namespace std { - size_t hash::operator()(const sd::TadDescriptor &k) const { - // Compute individual hash values for first, - // second and third and combine them using XOR - // and bit shifting: - auto res = std::hash()((int)k.areUnitiesinShape()); - res ^= std::hash()(k.originalShapeConst()) + 0x9e3779b9 + (res << 6) + (res >> 2); - auto axes = const_cast(k).axis(); - for (auto a: axes) { - res ^= std::hash()(a) + 0x9e3779b9 + (res << 6) + (res >> 2); - } - return res; - } -} \ No newline at end of file +size_t hash::operator()(const sd::TadDescriptor &k) const { + // Compute individual hash values for first, + // second and third and combine them using XOR + // and bit shifting: + auto res = std::hash()((int)k.areUnitiesinShape()); + res ^= std::hash()(k.originalShapeConst()) + 0x9e3779b9 + + (res << 6) + (res >> 2); + auto axes = const_cast(k).axis(); + for (auto a : axes) { + res ^= std::hash()(a) + 0x9e3779b9 + (res << 6) + (res >> 2); + } + return res; +} +} // namespace std \ No newline at end of file diff --git a/libnd4j/include/array/impl/TadPack.cpp b/libnd4j/include/array/impl/TadPack.cpp index 7a3bdbe364a2..5603a7947683 100644 --- a/libnd4j/include/array/impl/TadPack.cpp +++ b/libnd4j/include/array/impl/TadPack.cpp @@ -19,45 +19,47 @@ // #include "../TadPack.h" -#include + #include +#include namespace sd { - TadPack::TadPack(ConstantDataBuffer &shapes, ConstantDataBuffer &offets, Nd4jLong numTads) { - _tadShape = shapes; - _tadOffsets = offets; - _numTads = numTads; - } - - const Nd4jLong* TadPack::primaryShapeInfo() const { - return reinterpret_cast(_tadShape.primary()); - } - - const Nd4jLong* TadPack::primaryOffsets() const { - return reinterpret_cast(_tadOffsets.primary()); - } - - const Nd4jLong* TadPack::specialShapeInfo() const { - return reinterpret_cast(_tadShape.special()); - } - - const Nd4jLong* TadPack::specialOffsets() const { - return reinterpret_cast(_tadOffsets.special()); - } - - Nd4jLong TadPack::numberOfTads() const { - return _numTads; - } - - const Nd4jLong* TadPack::platformShapeInfo() const { - return sd::Environment::getInstance()->isCPU() ? primaryShapeInfo() : specialShapeInfo(); - } - - const Nd4jLong* TadPack::platformOffsets() const { - return sd::Environment::getInstance()->isCPU() ? primaryOffsets() : specialOffsets(); - } - - int TadPack::shapeInfoLength() const { - return (int) shape::shapeInfoLength(primaryShapeInfo()); - } -} \ No newline at end of file +TadPack::TadPack(ConstantDataBuffer& shapes, ConstantDataBuffer& offets, + Nd4jLong numTads) { + _tadShape = shapes; + _tadOffsets = offets; + _numTads = numTads; +} + +const Nd4jLong* TadPack::primaryShapeInfo() const { + return reinterpret_cast(_tadShape.primary()); +} + +const Nd4jLong* TadPack::primaryOffsets() const { + return reinterpret_cast(_tadOffsets.primary()); +} + +const Nd4jLong* TadPack::specialShapeInfo() const { + return reinterpret_cast(_tadShape.special()); +} + +const Nd4jLong* TadPack::specialOffsets() const { + return reinterpret_cast(_tadOffsets.special()); +} + +Nd4jLong TadPack::numberOfTads() const { return _numTads; } + +const Nd4jLong* TadPack::platformShapeInfo() const { + return sd::Environment::getInstance()->isCPU() ? primaryShapeInfo() + : specialShapeInfo(); +} + +const Nd4jLong* TadPack::platformOffsets() const { + return sd::Environment::getInstance()->isCPU() ? primaryOffsets() + : specialOffsets(); +} + +int TadPack::shapeInfoLength() const { + return (int)shape::shapeInfoLength(primaryShapeInfo()); +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/cblas.h b/libnd4j/include/cblas.h old mode 100755 new mode 100644 index 18970a9b074f..27c7c605bbc4 --- a/libnd4j/include/cblas.h +++ b/libnd4j/include/cblas.h @@ -55,12 +55,16 @@ extern "C" { #endif #ifndef CBLAS_ENUM_DEFINED_H #define CBLAS_ENUM_DEFINED_H -enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102 }; -enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113, - AtlasConj=114}; -enum CBLAS_UPLO {CblasUpper=121, CblasLower=122}; -enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132}; -enum CBLAS_SIDE {CblasLeft=141, CblasRight=142}; +enum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 }; +enum CBLAS_TRANSPOSE { + CblasNoTrans = 111, + CblasTrans = 112, + CblasConjTrans = 113, + AtlasConj = 114 +}; +enum CBLAS_UPLO { CblasUpper = 121, CblasLower = 122 }; +enum CBLAS_DIAG { CblasNonUnit = 131, CblasUnit = 132 }; +enum CBLAS_SIDE { CblasLeft = 141, CblasRight = 142 }; #endif #ifndef CBLAS_ENUM_ONLY @@ -68,7 +72,7 @@ enum CBLAS_SIDE {CblasLeft=141, CblasRight=142}; #define CBLAS_INDEX int int cblas_errprn(int ierr, int info, char *form, ...); -void cblas_xerbla(int p, char *rout, char *form, ...); +void cblas_xerbla(int p, char *rout, char *form, ...); #ifdef __MKL void MKL_Set_Num_Threads(int num); @@ -80,57 +84,46 @@ void openblas_set_num_threads(int num); // do nothing #endif - /* * =========================================================================== * Prototypes for level 1 BLAS functions (complex are recast as routines) * =========================================================================== */ -float cblas_sdsdot(int N, float alpha, float *X, - int incX, float *Y, int incY); -double cblas_dsdot(int N, float *X, int incX, float *Y, - int incY); -float cblas_sdot(int N, float *X, int incX, - float *Y, int incY); -double cblas_ddot(int N, double *X, int incX, - double *Y, int incY); +float cblas_sdsdot(int N, float alpha, float *X, int incX, float *Y, int incY); +double cblas_dsdot(int N, float *X, int incX, float *Y, int incY); +float cblas_sdot(int N, float *X, int incX, float *Y, int incY); +double cblas_ddot(int N, double *X, int incX, double *Y, int incY); /* * Functions having prefixes Z and C only */ -void cblas_cdotu_sub(int N, void *X, int incX, - void *Y, int incY, void *dotu); -void cblas_cdotc_sub(int N, void *X, int incX, - void *Y, int incY, void *dotc); - -void cblas_zdotu_sub(int N, void *X, int incX, - void *Y, int incY, void *dotu); -void cblas_zdotc_sub(int N, void *X, int incX, - void *Y, int incY, void *dotc); +void cblas_cdotu_sub(int N, void *X, int incX, void *Y, int incY, void *dotu); +void cblas_cdotc_sub(int N, void *X, int incX, void *Y, int incY, void *dotc); +void cblas_zdotu_sub(int N, void *X, int incX, void *Y, int incY, void *dotu); +void cblas_zdotc_sub(int N, void *X, int incX, void *Y, int incY, void *dotc); /* * Functions having prefixes S D SC DZ */ -float cblas_snrm2(int N, float *X, int incX); -float cblas_sasum(int N, float *X, int incX); - -double cblas_dnrm2(int N, double *X, int incX); -double cblas_dasum(int N, double *X, int incX); +float cblas_snrm2(int N, float *X, int incX); +float cblas_sasum(int N, float *X, int incX); -float cblas_scnrm2(int N, void *X, int incX); -float cblas_scasum(int N, void *X, int incX); +double cblas_dnrm2(int N, double *X, int incX); +double cblas_dasum(int N, double *X, int incX); -double cblas_dznrm2(int N, void *X, int incX); -double cblas_dzasum(int N, void *X, int incX); +float cblas_scnrm2(int N, void *X, int incX); +float cblas_scasum(int N, void *X, int incX); +double cblas_dznrm2(int N, void *X, int incX); +double cblas_dzasum(int N, void *X, int incX); /* * Functions having standard 4 prefixes (S D C Z) */ -CBLAS_INDEX cblas_isamax(int N, float *X, int incX); -CBLAS_INDEX cblas_idamax(int N, double *X, int incX); -CBLAS_INDEX cblas_icamax(int N, void *X, int incX); -CBLAS_INDEX cblas_izamax(int N, void *X, int incX); +CBLAS_INDEX cblas_isamax(int N, float *X, int incX); +CBLAS_INDEX cblas_idamax(int N, double *X, int incX); +CBLAS_INDEX cblas_icamax(int N, void *X, int incX); +CBLAS_INDEX cblas_izamax(int N, void *X, int incX); /* * =========================================================================== @@ -141,88 +134,67 @@ CBLAS_INDEX cblas_izamax(int N, void *X, int incX); /* * Routines with standard 4 prefixes (s, d, c, z) */ -void cblas_sswap(int N, float *X, int incX, - float *Y, int incY); -void cblas_scopy(int N, float *X, int incX, - float *Y, int incY); -void cblas_saxpy(int N, float alpha, float *X, - int incX, float *Y, int incY); -void catlas_saxpby(int N, float alpha, float *X, - int incX, float beta, float *Y, int incY); -void catlas_sset - (int N, float alpha, float *X, int incX); - -void cblas_dswap(int N, double *X, int incX, - double *Y, int incY); -void cblas_dcopy(int N, double *X, int incX, - double *Y, int incY); -void cblas_daxpy(int N, double alpha, double *X, - int incX, double *Y, int incY); -void catlas_daxpby(int N, double alpha, double *X, - int incX, double beta, double *Y, int incY); -void catlas_dset - (int N, double alpha, double *X, int incX); - -void cblas_cswap(int N, void *X, int incX, - void *Y, int incY); -void cblas_ccopy(int N, void *X, int incX, - void *Y, int incY); -void cblas_caxpy(int N, void *alpha, void *X, - int incX, void *Y, int incY); -void catlas_caxpby(int N, void *alpha, void *X, - int incX, void *beta, void *Y, int incY); -void catlas_cset - (int N, void *alpha, void *X, int incX); - -void cblas_zswap(int N, void *X, int incX, - void *Y, int incY); -void cblas_zcopy(int N, void *X, int incX, - void *Y, int incY); -void cblas_zaxpy(int N, void *alpha, void *X, - int incX, void *Y, int incY); -void catlas_zaxpby(int N, void *alpha, void *X, - int incX, void *beta, void *Y, int incY); -void catlas_zset - (int N, void *alpha, void *X, int incX); +void cblas_sswap(int N, float *X, int incX, float *Y, int incY); +void cblas_scopy(int N, float *X, int incX, float *Y, int incY); +void cblas_saxpy(int N, float alpha, float *X, int incX, float *Y, int incY); +void catlas_saxpby(int N, float alpha, float *X, int incX, float beta, float *Y, + int incY); +void catlas_sset(int N, float alpha, float *X, int incX); + +void cblas_dswap(int N, double *X, int incX, double *Y, int incY); +void cblas_dcopy(int N, double *X, int incX, double *Y, int incY); +void cblas_daxpy(int N, double alpha, double *X, int incX, double *Y, int incY); +void catlas_daxpby(int N, double alpha, double *X, int incX, double beta, + double *Y, int incY); +void catlas_dset(int N, double alpha, double *X, int incX); +void cblas_cswap(int N, void *X, int incX, void *Y, int incY); +void cblas_ccopy(int N, void *X, int incX, void *Y, int incY); +void cblas_caxpy(int N, void *alpha, void *X, int incX, void *Y, int incY); +void catlas_caxpby(int N, void *alpha, void *X, int incX, void *beta, void *Y, + int incY); +void catlas_cset(int N, void *alpha, void *X, int incX); + +void cblas_zswap(int N, void *X, int incX, void *Y, int incY); +void cblas_zcopy(int N, void *X, int incX, void *Y, int incY); +void cblas_zaxpy(int N, void *alpha, void *X, int incX, void *Y, int incY); +void catlas_zaxpby(int N, void *alpha, void *X, int incX, void *beta, void *Y, + int incY); +void catlas_zset(int N, void *alpha, void *X, int incX); /* * Routines with S and D prefix only */ void cblas_srotg(float *a, float *b, float *c, float *s); -void cblas_srotmg(float *d1, float *d2, float *b1, float b2, float *P); -void cblas_srot(int N, float *X, int incX, - float *Y, int incY, float c, float s); -void cblas_srotm(int N, float *X, int incX, - float *Y, int incY, float *P); +void cblas_srotmg(float *d1, float *d2, float *b1, float b2, float *P); +void cblas_srot(int N, float *X, int incX, float *Y, int incY, float c, + float s); +void cblas_srotm(int N, float *X, int incX, float *Y, int incY, float *P); void cblas_drotg(double *a, double *b, double *c, double *s); -void cblas_drotmg(double *d1, double *d2, double *b1, double b2, double *P); -void cblas_drot(int N, double *X, int incX, - double *Y, int incY, double c, double s); -void cblas_drotm(int N, double *X, int incX, - double *Y, int incY, double *P); - +void cblas_drotmg(double *d1, double *d2, double *b1, double b2, double *P); +void cblas_drot(int N, double *X, int incX, double *Y, int incY, double c, + double s); +void cblas_drotm(int N, double *X, int incX, double *Y, int incY, double *P); /* * Routines with S D C Z CS and ZD prefixes */ -void cblas_sscal(int N, float alpha, float *X, int incX); -void cblas_dscal(int N, double alpha, double *X, int incX); -void cblas_cscal(int N, void *alpha, void *X, int incX); -void cblas_zscal(int N, void *alpha, void *X, int incX); -void cblas_csscal(int N, float alpha, void *X, int incX); -void cblas_zdscal(int N, double alpha, void *X, int incX); +void cblas_sscal(int N, float alpha, float *X, int incX); +void cblas_dscal(int N, double alpha, double *X, int incX); +void cblas_cscal(int N, void *alpha, void *X, int incX); +void cblas_zscal(int N, void *alpha, void *X, int incX); +void cblas_csscal(int N, float alpha, void *X, int incX); +void cblas_zdscal(int N, double alpha, void *X, int incX); /* * Extra reference routines provided by ATLAS, but not mandated by the standard */ void cblas_crotg(void *a, void *b, void *c, void *s); void cblas_zrotg(void *a, void *b, void *c, void *s); -void cblas_csrot(int N, void *X, int incX, void *Y, int incY, - float c, float s); -void cblas_zdrot(int N, void *X, int incX, void *Y, int incY, - double c, double s); +void cblas_csrot(int N, void *X, int incX, void *Y, int incY, float c, float s); +void cblas_zdrot(int N, void *X, int incX, void *Y, int incY, double c, + double s); /* * =========================================================================== @@ -233,265 +205,200 @@ void cblas_zdrot(int N, void *X, int incX, void *Y, int incY, /* * Routines with standard 4 prefixes (S, D, C, Z) */ -void cblas_sgemv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - float alpha, float *A, int lda, - float *X, int incX, float beta, - float *Y, int incY); -void cblas_sgbmv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - int KL, int KU, float alpha, - float *A, int lda, float *X, - int incX, float beta, float *Y, int incY); -void cblas_strmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, float *A, int lda, - float *X, int incX); -void cblas_stbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, float *A, int lda, - float *X, int incX); -void cblas_stpmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, float *Ap, float *X, int incX); -void cblas_strsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, float *A, int lda, float *X, - int incX); -void cblas_stbsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, float *A, int lda, - float *X, int incX); -void cblas_stpsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, float *Ap, float *X, int incX); - -void cblas_dgemv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - double alpha, double *A, int lda, - double *X, int incX, double beta, - double *Y, int incY); -void cblas_dgbmv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - int KL, int KU, double alpha, - double *A, int lda, double *X, - int incX, double beta, double *Y, int incY); -void cblas_dtrmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, double *A, int lda, - double *X, int incX); -void cblas_dtbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, double *A, int lda, - double *X, int incX); -void cblas_dtpmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, double *Ap, double *X, int incX); -void cblas_dtrsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, double *A, int lda, double *X, - int incX); -void cblas_dtbsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, double *A, int lda, - double *X, int incX); -void cblas_dtpsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, double *Ap, double *X, int incX); - -void cblas_cgemv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - void *alpha, void *A, int lda, - void *X, int incX, void *beta, - void *Y, int incY); -void cblas_cgbmv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - int KL, int KU, void *alpha, - void *A, int lda, void *X, - int incX, void *beta, void *Y, int incY); -void cblas_ctrmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *A, int lda, - void *X, int incX); -void cblas_ctbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, void *A, int lda, - void *X, int incX); -void cblas_ctpmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *Ap, void *X, int incX); -void cblas_ctrsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *A, int lda, void *X, - int incX); -void cblas_ctbsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, void *A, int lda, - void *X, int incX); -void cblas_ctpsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *Ap, void *X, int incX); - -void cblas_zgemv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - void *alpha, void *A, int lda, - void *X, int incX, void *beta, - void *Y, int incY); -void cblas_zgbmv( enum CBLAS_ORDER Order, - enum CBLAS_TRANSPOSE TransA, int M, int N, - int KL, int KU, void *alpha, - void *A, int lda, void *X, - int incX, void *beta, void *Y, int incY); -void cblas_ztrmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *A, int lda, - void *X, int incX); -void cblas_ztbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, void *A, int lda, - void *X, int incX); -void cblas_ztpmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *Ap, void *X, int incX); -void cblas_ztrsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *A, int lda, void *X, - int incX); -void cblas_ztbsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, int K, void *A, int lda, - void *X, int incX); -void cblas_ztpsv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, - int N, void *Ap, void *X, int incX); - +void cblas_sgemv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, float alpha, float *A, int lda, float *X, int incX, + float beta, float *Y, int incY); +void cblas_sgbmv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, int KL, int KU, float alpha, float *A, int lda, + float *X, int incX, float beta, float *Y, int incY); +void cblas_strmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + float *A, int lda, float *X, int incX); +void cblas_stbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, float *A, int lda, float *X, int incX); +void cblas_stpmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + float *Ap, float *X, int incX); +void cblas_strsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + float *A, int lda, float *X, int incX); +void cblas_stbsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, float *A, int lda, float *X, int incX); +void cblas_stpsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + float *Ap, float *X, int incX); + +void cblas_dgemv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, double alpha, double *A, int lda, double *X, int incX, + double beta, double *Y, int incY); +void cblas_dgbmv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, int KL, int KU, double alpha, double *A, int lda, + double *X, int incX, double beta, double *Y, int incY); +void cblas_dtrmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + double *A, int lda, double *X, int incX); +void cblas_dtbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, double *A, int lda, double *X, int incX); +void cblas_dtpmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + double *Ap, double *X, int incX); +void cblas_dtrsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + double *A, int lda, double *X, int incX); +void cblas_dtbsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, double *A, int lda, double *X, int incX); +void cblas_dtpsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + double *Ap, double *X, int incX); + +void cblas_cgemv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, void *alpha, void *A, int lda, void *X, int incX, + void *beta, void *Y, int incY); +void cblas_cgbmv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, int KL, int KU, void *alpha, void *A, int lda, void *X, + int incX, void *beta, void *Y, int incY); +void cblas_ctrmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *A, int lda, void *X, int incX); +void cblas_ctbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, void *A, int lda, void *X, int incX); +void cblas_ctpmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *Ap, void *X, int incX); +void cblas_ctrsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *A, int lda, void *X, int incX); +void cblas_ctbsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, void *A, int lda, void *X, int incX); +void cblas_ctpsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *Ap, void *X, int incX); + +void cblas_zgemv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, void *alpha, void *A, int lda, void *X, int incX, + void *beta, void *Y, int incY); +void cblas_zgbmv(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, int M, + int N, int KL, int KU, void *alpha, void *A, int lda, void *X, + int incX, void *beta, void *Y, int incY); +void cblas_ztrmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *A, int lda, void *X, int incX); +void cblas_ztbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, void *A, int lda, void *X, int incX); +void cblas_ztpmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *Ap, void *X, int incX); +void cblas_ztrsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *A, int lda, void *X, int incX); +void cblas_ztbsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + int K, void *A, int lda, void *X, int incX); +void cblas_ztpsv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_DIAG Diag, int N, + void *Ap, void *X, int incX); /* * Routines with S and D prefixes only */ -void cblas_ssymv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, float *A, - int lda, float *X, int incX, - float beta, float *Y, int incY); -void cblas_ssbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, int K, float alpha, float *A, - int lda, float *X, int incX, - float beta, float *Y, int incY); -void cblas_sspmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, float *Ap, - float *X, int incX, - float beta, float *Y, int incY); -void cblas_sger( enum CBLAS_ORDER Order, int M, int N, - float alpha, float *X, int incX, - float *Y, int incY, float *A, int lda); -void cblas_ssyr( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, float *X, - int incX, float *A, int lda); -void cblas_sspr( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, float *X, - int incX, float *Ap); -void cblas_ssyr2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, float *X, - int incX, float *Y, int incY, float *A, - int lda); -void cblas_sspr2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, float *X, - int incX, float *Y, int incY, float *A); - -void cblas_dsymv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, double *A, - int lda, double *X, int incX, - double beta, double *Y, int incY); -void cblas_dsbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, int K, double alpha, double *A, - int lda, double *X, int incX, - double beta, double *Y, int incY); -void cblas_dspmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, double *Ap, - double *X, int incX, - double beta, double *Y, int incY); -void cblas_dger( enum CBLAS_ORDER Order, int M, int N, - double alpha, double *X, int incX, - double *Y, int incY, double *A, int lda); -void cblas_dsyr( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, double *X, - int incX, double *A, int lda); -void cblas_dspr( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, double *X, - int incX, double *Ap); -void cblas_dsyr2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, double *X, - int incX, double *Y, int incY, double *A, +void cblas_ssymv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, float *A, int lda, float *X, int incX, float beta, + float *Y, int incY); +void cblas_ssbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, int K, + float alpha, float *A, int lda, float *X, int incX, float beta, + float *Y, int incY); +void cblas_sspmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, float *Ap, float *X, int incX, float beta, + float *Y, int incY); +void cblas_sger(enum CBLAS_ORDER Order, int M, int N, float alpha, float *X, + int incX, float *Y, int incY, float *A, int lda); +void cblas_ssyr(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, float *X, int incX, float *A, int lda); +void cblas_sspr(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, float *X, int incX, float *Ap); +void cblas_ssyr2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, float *X, int incX, float *Y, int incY, float *A, int lda); -void cblas_dspr2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, double *X, - int incX, double *Y, int incY, double *A); - +void cblas_sspr2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, float *X, int incX, float *Y, int incY, float *A); + +void cblas_dsymv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, double *A, int lda, double *X, int incX, + double beta, double *Y, int incY); +void cblas_dsbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, int K, + double alpha, double *A, int lda, double *X, int incX, + double beta, double *Y, int incY); +void cblas_dspmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, double *Ap, double *X, int incX, double beta, + double *Y, int incY); +void cblas_dger(enum CBLAS_ORDER Order, int M, int N, double alpha, double *X, + int incX, double *Y, int incY, double *A, int lda); +void cblas_dsyr(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, double *X, int incX, double *A, int lda); +void cblas_dspr(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, double *X, int incX, double *Ap); +void cblas_dsyr2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, double *X, int incX, double *Y, int incY, + double *A, int lda); +void cblas_dspr2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, double *X, int incX, double *Y, int incY, + double *A); /* * Routines with C and Z prefixes only */ -void cblas_chemv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, void *alpha, void *A, - int lda, void *X, int incX, - void *beta, void *Y, int incY); -void cblas_chbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, int K, void *alpha, void *A, - int lda, void *X, int incX, - void *beta, void *Y, int incY); -void cblas_chpmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, void *alpha, void *Ap, - void *X, int incX, - void *beta, void *Y, int incY); -void cblas_cgeru( enum CBLAS_ORDER Order, int M, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *A, int lda); -void cblas_cgerc( enum CBLAS_ORDER Order, int M, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *A, int lda); -void cblas_cher( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, void *X, int incX, - void *A, int lda); -void cblas_chpr( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, float alpha, void *X, - int incX, void *A); -void cblas_cher2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *A, int lda); -void cblas_chpr2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *Ap); - -void cblas_zhemv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, void *alpha, void *A, - int lda, void *X, int incX, - void *beta, void *Y, int incY); -void cblas_zhbmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, int K, void *alpha, void *A, - int lda, void *X, int incX, - void *beta, void *Y, int incY); -void cblas_zhpmv( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, void *alpha, void *Ap, - void *X, int incX, - void *beta, void *Y, int incY); -void cblas_zgeru( enum CBLAS_ORDER Order, int M, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *A, int lda); -void cblas_zgerc( enum CBLAS_ORDER Order, int M, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *A, int lda); -void cblas_zher( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, void *X, int incX, - void *A, int lda); -void cblas_zhpr( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - int N, double alpha, void *X, - int incX, void *A); -void cblas_zher2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *A, int lda); -void cblas_zhpr2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, - void *alpha, void *X, int incX, - void *Y, int incY, void *Ap); +void cblas_chemv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *A, int lda, void *X, int incX, void *beta, + void *Y, int incY); +void cblas_chbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, int K, + void *alpha, void *A, int lda, void *X, int incX, void *beta, + void *Y, int incY); +void cblas_chpmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *Ap, void *X, int incX, void *beta, void *Y, + int incY); +void cblas_cgeru(enum CBLAS_ORDER Order, int M, int N, void *alpha, void *X, + int incX, void *Y, int incY, void *A, int lda); +void cblas_cgerc(enum CBLAS_ORDER Order, int M, int N, void *alpha, void *X, + int incX, void *Y, int incY, void *A, int lda); +void cblas_cher(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, void *X, int incX, void *A, int lda); +void cblas_chpr(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + float alpha, void *X, int incX, void *A); +void cblas_cher2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *X, int incX, void *Y, int incY, void *A, + int lda); +void cblas_chpr2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *X, int incX, void *Y, int incY, void *Ap); + +void cblas_zhemv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *A, int lda, void *X, int incX, void *beta, + void *Y, int incY); +void cblas_zhbmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, int K, + void *alpha, void *A, int lda, void *X, int incX, void *beta, + void *Y, int incY); +void cblas_zhpmv(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *Ap, void *X, int incX, void *beta, void *Y, + int incY); +void cblas_zgeru(enum CBLAS_ORDER Order, int M, int N, void *alpha, void *X, + int incX, void *Y, int incY, void *A, int lda); +void cblas_zgerc(enum CBLAS_ORDER Order, int M, int N, void *alpha, void *X, + int incX, void *Y, int incY, void *A, int lda); +void cblas_zher(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, void *X, int incX, void *A, int lda); +void cblas_zhpr(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + double alpha, void *X, int incX, void *A); +void cblas_zher2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *X, int incX, void *Y, int incY, void *A, + int lda); +void cblas_zhpr2(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, + void *alpha, void *X, int incX, void *Y, int incY, void *Ap); /* * =========================================================================== @@ -502,163 +409,126 @@ void cblas_zhpr2( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, int N, /* * Routines with standard 4 prefixes (S, D, C, Z) */ -void cblas_sgemm( enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_TRANSPOSE TransB, int M, int N, - int K, float alpha, float *A, - int lda, float *B, int ldb, - float beta, float *C, int ldc); -void cblas_ssymm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, int M, int N, - float alpha, float *A, int lda, - float *B, int ldb, float beta, - float *C, int ldc); -void cblas_ssyrk( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - float alpha, float *A, int lda, - float beta, float *C, int ldc); -void cblas_ssyr2k( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - float alpha, float *A, int lda, - float *B, int ldb, float beta, - float *C, int ldc); -void cblas_strmm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - float alpha, float *A, int lda, - float *B, int ldb); -void cblas_strsm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - float alpha, float *A, int lda, - float *B, int ldb); - -void cblas_dgemm( enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_TRANSPOSE TransB, int M, int N, - int K, double alpha, double *A, - int lda, double *B, int ldb, - double beta, double *C, int ldc); -void cblas_dsymm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, int M, int N, - double alpha, double *A, int lda, - double *B, int ldb, double beta, - double *C, int ldc); -void cblas_dsyrk( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - double alpha, double *A, int lda, - double beta, double *C, int ldc); -void cblas_dsyr2k( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - double alpha, double *A, int lda, - double *B, int ldb, double beta, +void cblas_sgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, int M, int N, int K, float alpha, + float *A, int lda, float *B, int ldb, float beta, float *C, + int ldc); +void cblas_ssymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, int M, int N, float alpha, float *A, + int lda, float *B, int ldb, float beta, float *C, int ldc); +void cblas_ssyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, float alpha, + float *A, int lda, float beta, float *C, int ldc); +void cblas_ssyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, float alpha, + float *A, int lda, float *B, int ldb, float beta, float *C, + int ldc); +void cblas_strmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, float alpha, float *A, + int lda, float *B, int ldb); +void cblas_strsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, float alpha, float *A, + int lda, float *B, int ldb); + +void cblas_dgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, int M, int N, int K, double alpha, + double *A, int lda, double *B, int ldb, double beta, double *C, + int ldc); +void cblas_dsymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, int M, int N, double alpha, double *A, + int lda, double *B, int ldb, double beta, double *C, int ldc); +void cblas_dsyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, double alpha, + double *A, int lda, double beta, double *C, int ldc); +void cblas_dsyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, double alpha, + double *A, int lda, double *B, int ldb, double beta, double *C, int ldc); -void cblas_dtrmm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - double alpha, double *A, int lda, - double *B, int ldb); -void cblas_dtrsm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - double alpha, double *A, int lda, - double *B, int ldb); - -void cblas_cgemm( enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_TRANSPOSE TransB, int M, int N, - int K, void *alpha, void *A, - int lda, void *B, int ldb, - void *beta, void *C, int ldc); -void cblas_csymm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb, void *beta, - void *C, int ldc); -void cblas_csyrk( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - void *alpha, void *A, int lda, - void *beta, void *C, int ldc); -void cblas_csyr2k( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - void *alpha, void *A, int lda, - void *B, int ldb, void *beta, - void *C, int ldc); -void cblas_ctrmm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb); -void cblas_ctrsm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb); - -void cblas_zgemm( enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_TRANSPOSE TransB, int M, int N, - int K, void *alpha, void *A, - int lda, void *B, int ldb, - void *beta, void *C, int ldc); -void cblas_zsymm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb, void *beta, - void *C, int ldc); -void cblas_zsyrk( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - void *alpha, void *A, int lda, - void *beta, void *C, int ldc); -void cblas_zsyr2k( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - void *alpha, void *A, int lda, - void *B, int ldb, void *beta, - void *C, int ldc); -void cblas_ztrmm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb); -void cblas_ztrsm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, - enum CBLAS_DIAG Diag, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb); - +void cblas_dtrmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, double alpha, double *A, + int lda, double *B, int ldb); +void cblas_dtrsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, double alpha, double *A, + int lda, double *B, int ldb); + +void cblas_cgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, int M, int N, int K, void *alpha, + void *A, int lda, void *B, int ldb, void *beta, void *C, + int ldc); +void cblas_csymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb, void *beta, void *C, int ldc); +void cblas_csyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, void *alpha, void *A, + int lda, void *beta, void *C, int ldc); +void cblas_csyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, void *alpha, + void *A, int lda, void *B, int ldb, void *beta, void *C, + int ldc); +void cblas_ctrmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb); +void cblas_ctrsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb); + +void cblas_zgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_TRANSPOSE TransB, int M, int N, int K, void *alpha, + void *A, int lda, void *B, int ldb, void *beta, void *C, + int ldc); +void cblas_zsymm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb, void *beta, void *C, int ldc); +void cblas_zsyrk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, void *alpha, void *A, + int lda, void *beta, void *C, int ldc); +void cblas_zsyr2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, void *alpha, + void *A, int lda, void *B, int ldb, void *beta, void *C, + int ldc); +void cblas_ztrmm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb); +void cblas_ztrsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, enum CBLAS_TRANSPOSE TransA, + enum CBLAS_DIAG Diag, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb); /* * Routines with prefixes C and Z only */ -void cblas_chemm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb, void *beta, - void *C, int ldc); -void cblas_cherk( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - float alpha, void *A, int lda, - float beta, void *C, int ldc); -void cblas_cher2k( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - void *alpha, void *A, int lda, - void *B, int ldb, float beta, - void *C, int ldc); -void cblas_zhemm( enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, - enum CBLAS_UPLO Uplo, int M, int N, - void *alpha, void *A, int lda, - void *B, int ldb, void *beta, - void *C, int ldc); -void cblas_zherk( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - double alpha, void *A, int lda, - double beta, void *C, int ldc); -void cblas_zher2k( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, - enum CBLAS_TRANSPOSE Trans, int N, int K, - void *alpha, void *A, int lda, - void *B, int ldb, double beta, - void *C, int ldc); +void cblas_chemm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb, void *beta, void *C, int ldc); +void cblas_cherk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, float alpha, void *A, + int lda, float beta, void *C, int ldc); +void cblas_cher2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, void *alpha, + void *A, int lda, void *B, int ldb, float beta, void *C, + int ldc); +void cblas_zhemm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, + enum CBLAS_UPLO Uplo, int M, int N, void *alpha, void *A, + int lda, void *B, int ldb, void *beta, void *C, int ldc); +void cblas_zherk(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, double alpha, + void *A, int lda, double beta, void *C, int ldc); +void cblas_zher2k(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE Trans, int N, int K, void *alpha, + void *A, int lda, void *B, int ldb, double beta, void *C, + int ldc); int cblas_errprn(int ierr, int info, char *form, ...); #ifdef __cplusplus } #endif -#endif /* end #ifdef CBLAS_ENUM_ONLY */ +#endif /* end #ifdef CBLAS_ENUM_ONLY */ #endif -#endif //NATIVEOPERATIONS_CBLAS_H +#endif // NATIVEOPERATIONS_CBLAS_H diff --git a/libnd4j/include/cblas_enum_conversion.h b/libnd4j/include/cblas_enum_conversion.h old mode 100755 new mode 100644 index 6ff6fe557f96..987b5e80557a --- a/libnd4j/include/cblas_enum_conversion.h +++ b/libnd4j/include/cblas_enum_conversion.h @@ -78,5 +78,4 @@ CBLAS_SIDE convertSide(int from); } #endif - -#endif //NATIVEOPERATIONS_CBLAS_ENUM_CONVERSION_H_H +#endif // NATIVEOPERATIONS_CBLAS_ENUM_CONVERSION_H_H diff --git a/libnd4j/include/cnpy/cnpy.h b/libnd4j/include/cnpy/cnpy.h index 62fab018bdd0..a225e5791c02 100644 --- a/libnd4j/include/cnpy/cnpy.h +++ b/libnd4j/include/cnpy/cnpy.h @@ -22,270 +22,251 @@ * THE SOFTWARE. ******************************************************************************/ -//Copyright (C) 2011 Carl Rogers -//Released under MIT License -//license available in LICENSE file, or at http://www.opensource.org/licenses/mit-license.php +// Copyright (C) 2011 Carl Rogers +// Released under MIT License +// license available in LICENSE file, or at +// http://www.opensource.org/licenses/mit-license.php #ifndef LIBCNPY_H_ #define LIBCNPY_H_ - /** * */ -#include -#include -#include -#include -#include -#include +#include #include -#include -#include -#include -#include -#include +#include +#include -#include +#include +#include +#include +#include #include +#include +#include #include -#include -#include -#include - +#include +#include +#include +#include namespace cnpy { - /** - * The numpy array - */ - struct SD_EXPORT NpyArray { - char* data; - std::vector shape; - unsigned int wordSize; - bool fortranOrder; - void destruct() { - delete[] data; - } - }; - - struct SD_EXPORT npz_t : public std::unordered_map { - void destruct() { - npz_t::iterator it = this->begin(); - for(; it != this->end(); ++it) (*it).second.destruct(); - } - }; - - /** - * - * @param path - * @return - */ - SD_EXPORT char* loadFile(const char *path); - - - - /** - * - * @return - */ - char BigEndianTest(); - +/** + * The numpy array + */ +struct SD_EXPORT NpyArray { + char *data; + std::vector shape; + unsigned int wordSize; + bool fortranOrder; + void destruct() { delete[] data; } +}; + +struct SD_EXPORT npz_t : public std::unordered_map { + void destruct() { + npz_t::iterator it = this->begin(); + for (; it != this->end(); ++it) (*it).second.destruct(); + } +}; - /** - * - * @param t - * @return - */ - SD_EXPORT char mapType(const std::type_info &t); +/** + * + * @param path + * @return + */ +SD_EXPORT char *loadFile(const char *path); - template - SD_EXPORT char mapType(); +/** + * + * @return + */ +char BigEndianTest(); - /** - * - * @param T the type of the ndarray - * @param data the data for the ndarray - * @param shape the shape of the ndarray - * @param ndims the rank of the ndarray - * @return - */ - template - SD_EXPORT std::vector createNpyHeader(const void *data, - const unsigned int *shape, - const unsigned int ndims, - unsigned int wordSize = 4); - /** - * Parse the numpy header from - * the given file - * based on the pointers passed in - * @param fp the file to parse from - * @param wordSize the size of - * the individual elements - * @param shape - * @param ndims - * @param fortranOrder - */ - SD_EXPORT void parseNpyHeader(FILE *fp, - unsigned int &wordSize, - unsigned int *&shape, - unsigned int &ndims, - bool &fortranOrder); +/** + * + * @param t + * @return + */ +SD_EXPORT char mapType(const std::type_info &t); - /** - * Parse the numpy header from - * the given file - * based on the pointers passed in - * @param header the file to parse from - * @param word_size the size of - * the individual elements - * @param shape - * @param ndims - * @param fortran_order - */ - SD_EXPORT void parseNpyHeaderPointer( - const char *header, - unsigned int& word_size, - unsigned int*& shape, - unsigned int& ndims, - bool& fortran_order); - /** - * - * @param fp - * @param nrecs - * @param global_header_size - * @param global_header_offset - */ - SD_EXPORT void parseZipFooter(FILE *fp, - unsigned short &nrecs, - unsigned int &global_header_size, - unsigned int &global_header_offset); +template +SD_EXPORT char mapType(); - /** - * - * @param fname - * @param varname - * @return - */ - SD_EXPORT NpyArray npzLoad(std::string fname, std::string varname); +/** + * + * @param T the type of the ndarray + * @param data the data for the ndarray + * @param shape the shape of the ndarray + * @param ndims the rank of the ndarray + * @return + */ +template +SD_EXPORT std::vector createNpyHeader(const void *data, + const unsigned int *shape, + const unsigned int ndims, + unsigned int wordSize = 4); +/** + * Parse the numpy header from + * the given file + * based on the pointers passed in + * @param fp the file to parse from + * @param wordSize the size of + * the individual elements + * @param shape + * @param ndims + * @param fortranOrder + */ +SD_EXPORT void parseNpyHeader(FILE *fp, unsigned int &wordSize, + unsigned int *&shape, unsigned int &ndims, + bool &fortranOrder); - /** - * - * @param fname - * @return - */ - SD_EXPORT NpyArray npyLoad(std::string fname); +/** + * Parse the numpy header from + * the given file + * based on the pointers passed in + * @param header the file to parse from + * @param word_size the size of + * the individual elements + * @param shape + * @param ndims + * @param fortran_order + */ +SD_EXPORT void parseNpyHeaderPointer(const char *header, + unsigned int &word_size, + unsigned int *&shape, unsigned int &ndims, + bool &fortran_order); +/** + * + * @param fp + * @param nrecs + * @param global_header_size + * @param global_header_offset + */ +SD_EXPORT void parseZipFooter(FILE *fp, unsigned short &nrecs, + unsigned int &global_header_size, + unsigned int &global_header_offset); - /** - * Parse the numpy header from - * the given file - * based on the pointers passed in - * @param fp the file to parse from - * @param wordSize the size of - * the individual elements - * @param shape - * @param ndims - * @param fortranOrder - */ - SD_EXPORT void parseNpyHeaderStr(std::string header, - unsigned int &wordSize, - unsigned int *&shape, - unsigned int &ndims, - bool &fortranOrder); +/** + * + * @param fname + * @param varname + * @return + */ +SD_EXPORT NpyArray npzLoad(std::string fname, std::string varname); +/** + * + * @param fname + * @return + */ +SD_EXPORT NpyArray npyLoad(std::string fname); - /** - * - * @param fp - * @return - */ - SD_EXPORT int* shapeFromFile(FILE *fp); +/** + * Parse the numpy header from + * the given file + * based on the pointers passed in + * @param fp the file to parse from + * @param wordSize the size of + * the individual elements + * @param shape + * @param ndims + * @param fortranOrder + */ +SD_EXPORT void parseNpyHeaderStr(std::string header, unsigned int &wordSize, + unsigned int *&shape, unsigned int &ndims, + bool &fortranOrder); - /** - * - * @param data - * @return - */ - SD_EXPORT int* shapeFromPointer(char *data); +/** + * + * @param fp + * @return + */ +SD_EXPORT int *shapeFromFile(FILE *fp); - /** - * Load the numpy array from the given file. - * @param fp the file to load - * @return the loaded array - */ - SD_EXPORT NpyArray loadNpyFromFile(FILE *fp); +/** + * + * @param data + * @return + */ +SD_EXPORT int *shapeFromPointer(char *data); - /** - * Load the numpy array archive from the given file. - * @param fp the file to load - * @return the loaded archive - */ - SD_EXPORT npz_t npzLoad(FILE* fp); - /** - * - * @param data - * @return - */ - SD_EXPORT NpyArray loadNpyFromPointer(char *data); +/** + * Load the numpy array from the given file. + * @param fp the file to load + * @return the loaded array + */ +SD_EXPORT NpyArray loadNpyFromFile(FILE *fp); - /** - * - * @param data - * @return - */ - SD_EXPORT NpyArray loadNpyFromHeader(char *data); +/** + * Load the numpy array archive from the given file. + * @param fp the file to load + * @return the loaded archive + */ +SD_EXPORT npz_t npzLoad(FILE *fp); +/** + * + * @param data + * @return + */ +SD_EXPORT NpyArray loadNpyFromPointer(char *data); +/** + * + * @param data + * @return + */ +SD_EXPORT NpyArray loadNpyFromHeader(char *data); - SD_EXPORT npz_t npzLoad(std::string fname); +SD_EXPORT npz_t npzLoad(std::string fname); - SD_EXPORT sd::DataType dataTypeFromHeader(char *data); +SD_EXPORT sd::DataType dataTypeFromHeader(char *data); /** -* Parse the numpy header from -* the given file -* based on the pointers passed in -* @param fp the file to parse from -* @param word_size the size of -* the individual elements -* @param shape -* @param ndims -* @param fortran_order -*/ - SD_EXPORT void parseNpyHeader(std::string header, - unsigned int &word_size, - unsigned int *&shape, - unsigned int &ndims, - bool &fortran_order); - - /** - * - * @tparam T - * @param i - * @param pad - * @param padval - * @return - */ - template - FORCEINLINE std::string tostring(T i, int pad = 0, char padval = ' ') { - std::stringstream s; - s << i; - return s.str(); - } + * Parse the numpy header from + * the given file + * based on the pointers passed in + * @param fp the file to parse from + * @param word_size the size of + * the individual elements + * @param shape + * @param ndims + * @param fortran_order + */ +SD_EXPORT void parseNpyHeader(std::string header, unsigned int &word_size, + unsigned int *&shape, unsigned int &ndims, + bool &fortran_order); +/** + * + * @tparam T + * @param i + * @param pad + * @param padval + * @return + */ +template +FORCEINLINE std::string tostring(T i, int pad = 0, char padval = ' ') { + std::stringstream s; + s << i; + return s.str(); +} - template - SD_EXPORT void npy_save(std::string fname, const T* data, const unsigned int* shape, const unsigned int ndims, std::string mode = "w"); +template +SD_EXPORT void npy_save(std::string fname, const T *data, + const unsigned int *shape, const unsigned int ndims, + std::string mode = "w"); -} +} // namespace cnpy /** - * - * @tparam T - * @param lhs - * @param rhs - * @return - */ - template - SD_EXPORT std::vector& operator+=(std::vector& lhs, const T rhs); - + * + * @tparam T + * @param lhs + * @param rhs + * @return + */ +template +SD_EXPORT std::vector &operator+=(std::vector &lhs, const T rhs); #endif diff --git a/libnd4j/include/exceptions/allocation_exception.h b/libnd4j/include/exceptions/allocation_exception.h index 156f6937ca3b..beec84e40d3d 100644 --- a/libnd4j/include/exceptions/allocation_exception.h +++ b/libnd4j/include/exceptions/allocation_exception.h @@ -21,28 +21,30 @@ #ifndef SD_ALLOCATION_EXCEPTION_H #define SD_ALLOCATION_EXCEPTION_H -#include -#include -#include #include +#include + +#include +#include #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class SD_EXPORT allocation_exception : public std::runtime_error { - public: - allocation_exception(std::string message); - ~allocation_exception() = default; - - static allocation_exception build(std::string message, Nd4jLong bytes); - static allocation_exception build(std::string message, Nd4jLong limit, Nd4jLong bytes); - }; -} - - -#endif //SD_ALLOCATION_EXCEPTION_H +class SD_EXPORT allocation_exception : public std::runtime_error { + public: + allocation_exception(std::string message); + ~allocation_exception() = default; + + static allocation_exception build(std::string message, Nd4jLong bytes); + static allocation_exception build(std::string message, Nd4jLong limit, + Nd4jLong bytes); +}; +} // namespace sd + +#endif // SD_ALLOCATION_EXCEPTION_H diff --git a/libnd4j/include/exceptions/cuda_exception.h b/libnd4j/include/exceptions/cuda_exception.h index b4e6f591c002..6c8f288d3ee6 100644 --- a/libnd4j/include/exceptions/cuda_exception.h +++ b/libnd4j/include/exceptions/cuda_exception.h @@ -21,27 +21,27 @@ #ifndef SD_CUDA_EXCEPTION_H #define SD_CUDA_EXCEPTION_H -#include -#include #include +#include +#include + #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class SD_EXPORT cuda_exception : public std::runtime_error { - public: - cuda_exception(std::string message); - ~cuda_exception() = default; - - static cuda_exception build(std::string message, int errorCode); - }; -} - +class SD_EXPORT cuda_exception : public std::runtime_error { + public: + cuda_exception(std::string message); + ~cuda_exception() = default; + static cuda_exception build(std::string message, int errorCode); +}; +} // namespace sd -#endif //SD_CUDA_EXCEPTION_H +#endif // SD_CUDA_EXCEPTION_H diff --git a/libnd4j/include/exceptions/datatype_exception.h b/libnd4j/include/exceptions/datatype_exception.h index a2e9ed96fbb7..b3effeaf4e31 100644 --- a/libnd4j/include/exceptions/datatype_exception.h +++ b/libnd4j/include/exceptions/datatype_exception.h @@ -21,30 +21,34 @@ #ifndef SD_DATATYPE_EXCEPTION_H #define SD_DATATYPE_EXCEPTION_H -#include -#include #include #include +#include +#include + #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class SD_EXPORT datatype_exception : public std::runtime_error { - public: - datatype_exception(const std::string &message); - ~datatype_exception() = default; - - - static datatype_exception build(const std::string &message, sd::DataType actual); - static datatype_exception build(const std::string &message, sd::DataType expected, sd::DataType actual); - static datatype_exception build(const std::string &message, sd::DataType expected, sd::DataType actualX, sd::DataType actualY); - }; -} - - -#endif //SD_DATATYPE_EXCEPTION_H +class SD_EXPORT datatype_exception : public std::runtime_error { + public: + datatype_exception(const std::string &message); + ~datatype_exception() = default; + + static datatype_exception build(const std::string &message, + sd::DataType actual); + static datatype_exception build(const std::string &message, + sd::DataType expected, sd::DataType actual); + static datatype_exception build(const std::string &message, + sd::DataType expected, sd::DataType actualX, + sd::DataType actualY); +}; +} // namespace sd + +#endif // SD_DATATYPE_EXCEPTION_H diff --git a/libnd4j/include/exceptions/graph_exception.h b/libnd4j/include/exceptions/graph_exception.h index 7866dc1bb909..66bb5aad4792 100644 --- a/libnd4j/include/exceptions/graph_exception.h +++ b/libnd4j/include/exceptions/graph_exception.h @@ -21,37 +21,40 @@ #ifndef LIBND4J_GRAPH_EXCEPTION_H #define LIBND4J_GRAPH_EXCEPTION_H -#include -#include -#include #include +#include + +#include +#include #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class SD_EXPORT graph_exception : public std::runtime_error { - protected: - Nd4jLong _graphId; - std::string _message; - std::string _description; - public: - graph_exception(std::string message, Nd4jLong graphId); - graph_exception(std::string message, std::string description, Nd4jLong graphId); - graph_exception(std::string message, const char *description, Nd4jLong graphId); - ~graph_exception() = default; - - Nd4jLong graphId(); - - const char * message(); - const char * description(); - }; -} - - - -#endif //SD_GRAPH_EXCEPTION_H +class SD_EXPORT graph_exception : public std::runtime_error { + protected: + Nd4jLong _graphId; + std::string _message; + std::string _description; + + public: + graph_exception(std::string message, Nd4jLong graphId); + graph_exception(std::string message, std::string description, + Nd4jLong graphId); + graph_exception(std::string message, const char *description, + Nd4jLong graphId); + ~graph_exception() = default; + + Nd4jLong graphId(); + + const char *message(); + const char *description(); +}; +} // namespace sd + +#endif // SD_GRAPH_EXCEPTION_H diff --git a/libnd4j/include/exceptions/graph_execution_exception.h b/libnd4j/include/exceptions/graph_execution_exception.h index 3bf3b3311faa..2b9ce263af12 100644 --- a/libnd4j/include/exceptions/graph_execution_exception.h +++ b/libnd4j/include/exceptions/graph_execution_exception.h @@ -21,25 +21,28 @@ #ifndef SD_GRAPH_EXECUTION_EXCEPTION_H #define SD_GRAPH_EXECUTION_EXCEPTION_H +#include +#include #include #include + #include -#include -#include #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class SD_EXPORT graph_execution_exception: public graph_exception { - public: - explicit graph_execution_exception(Nd4jLong graphId); - explicit graph_execution_exception(const std::string &message, Nd4jStatus status); - }; -} - -#endif //SD_UNKNOWN_GRAPH_EXCEPTION_H +class SD_EXPORT graph_execution_exception : public graph_exception { + public: + explicit graph_execution_exception(Nd4jLong graphId); + explicit graph_execution_exception(const std::string &message, + Nd4jStatus status); +}; +} // namespace sd + +#endif // SD_UNKNOWN_GRAPH_EXCEPTION_H diff --git a/libnd4j/include/exceptions/graph_exists_exception.h b/libnd4j/include/exceptions/graph_exists_exception.h index edec72303d2c..0023a49fa0a3 100644 --- a/libnd4j/include/exceptions/graph_exists_exception.h +++ b/libnd4j/include/exceptions/graph_exists_exception.h @@ -21,24 +21,26 @@ #ifndef SD_GRAPH_EXISTS_EXCEPTION_H #define SD_GRAPH_EXISTS_EXCEPTION_H +#include +#include #include #include + #include -#include -#include #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class SD_EXPORT graph_exists_exception: public graph_exception { - public: - explicit graph_exists_exception(Nd4jLong graphId); - }; -} +class SD_EXPORT graph_exists_exception : public graph_exception { + public: + explicit graph_exists_exception(Nd4jLong graphId); +}; +} // namespace sd -#endif //SD_UNKNOWN_GRAPH_EXCEPTION_H +#endif // SD_UNKNOWN_GRAPH_EXCEPTION_H diff --git a/libnd4j/include/exceptions/impl/allocation_exception.cpp b/libnd4j/include/exceptions/impl/allocation_exception.cpp index 46f2ef5c86ae..f29f4d0d07f7 100644 --- a/libnd4j/include/exceptions/impl/allocation_exception.cpp +++ b/libnd4j/include/exceptions/impl/allocation_exception.cpp @@ -22,20 +22,24 @@ #include namespace sd { - allocation_exception::allocation_exception(std::string message) : std::runtime_error(message){ - // - } +allocation_exception::allocation_exception(std::string message) + : std::runtime_error(message) { + // +} - allocation_exception allocation_exception::build(std::string message, Nd4jLong numBytes) { - auto bytes = StringUtils::valueToString(numBytes); - message += "; Requested bytes: [" + bytes + "]"; - return allocation_exception(message); - } +allocation_exception allocation_exception::build(std::string message, + Nd4jLong numBytes) { + auto bytes = StringUtils::valueToString(numBytes); + message += "; Requested bytes: [" + bytes + "]"; + return allocation_exception(message); +} - allocation_exception allocation_exception::build(std::string message, Nd4jLong limit, Nd4jLong numBytes) { - auto bytes = StringUtils::valueToString(numBytes); - auto lim = StringUtils::valueToString(limit); - message += "; Limit bytes: [" + lim + "]; Requested bytes: [" + bytes + "]"; - return allocation_exception(message); - } -} \ No newline at end of file +allocation_exception allocation_exception::build(std::string message, + Nd4jLong limit, + Nd4jLong numBytes) { + auto bytes = StringUtils::valueToString(numBytes); + auto lim = StringUtils::valueToString(limit); + message += "; Limit bytes: [" + lim + "]; Requested bytes: [" + bytes + "]"; + return allocation_exception(message); +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/exceptions/impl/cuda_exception.cpp b/libnd4j/include/exceptions/impl/cuda_exception.cpp index 91de6c2516a5..93854168a338 100644 --- a/libnd4j/include/exceptions/impl/cuda_exception.cpp +++ b/libnd4j/include/exceptions/impl/cuda_exception.cpp @@ -22,13 +22,14 @@ #include namespace sd { - cuda_exception::cuda_exception(std::string message) : std::runtime_error(message){ - // - } +cuda_exception::cuda_exception(std::string message) + : std::runtime_error(message) { + // +} - cuda_exception cuda_exception::build(std::string message, int errorCode) { - auto ec = StringUtils::valueToString(errorCode); - message += "; Error code: [" + ec + "]"; - return cuda_exception(message); - } -} \ No newline at end of file +cuda_exception cuda_exception::build(std::string message, int errorCode) { + auto ec = StringUtils::valueToString(errorCode); + message += "; Error code: [" + ec + "]"; + return cuda_exception(message); +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/exceptions/impl/datatype_exception.cpp b/libnd4j/include/exceptions/impl/datatype_exception.cpp index ee4bd1a5a77a..4b7e36010291 100644 --- a/libnd4j/include/exceptions/impl/datatype_exception.cpp +++ b/libnd4j/include/exceptions/impl/datatype_exception.cpp @@ -22,28 +22,36 @@ #include namespace sd { - datatype_exception::datatype_exception(const std::string &message) : std::runtime_error(message){ - // - } +datatype_exception::datatype_exception(const std::string &message) + : std::runtime_error(message) { + // +} - datatype_exception datatype_exception::build(const std::string &message, sd::DataType expected, sd::DataType actual) { - auto exp = DataTypeUtils::asString(expected); - auto act = DataTypeUtils::asString(actual); - auto fmessage = message + "; Expected: [" + exp + "]; Actual: [" + act + "]"; - return datatype_exception(fmessage); - } +datatype_exception datatype_exception::build(const std::string &message, + sd::DataType expected, + sd::DataType actual) { + auto exp = DataTypeUtils::asString(expected); + auto act = DataTypeUtils::asString(actual); + auto fmessage = message + "; Expected: [" + exp + "]; Actual: [" + act + "]"; + return datatype_exception(fmessage); +} - datatype_exception datatype_exception::build(const std::string &message, sd::DataType expected, sd::DataType actualX, sd::DataType actualY) { - auto exp = DataTypeUtils::asString(expected); - auto actX = DataTypeUtils::asString(actualX); - auto actY = DataTypeUtils::asString(actualY); - auto fmessage = message + "; Expected: [" + exp + "]; Actual: [" + actX + ", " + actY + "]"; - return datatype_exception(fmessage); - } +datatype_exception datatype_exception::build(const std::string &message, + sd::DataType expected, + sd::DataType actualX, + sd::DataType actualY) { + auto exp = DataTypeUtils::asString(expected); + auto actX = DataTypeUtils::asString(actualX); + auto actY = DataTypeUtils::asString(actualY); + auto fmessage = message + "; Expected: [" + exp + "]; Actual: [" + actX + + ", " + actY + "]"; + return datatype_exception(fmessage); +} - datatype_exception datatype_exception::build(const std::string &message, sd::DataType actual) { - auto act = DataTypeUtils::asString(actual); - auto fmessage = message + "; Actual: [" + act + "]"; - return datatype_exception(fmessage); - } -} \ No newline at end of file +datatype_exception datatype_exception::build(const std::string &message, + sd::DataType actual) { + auto act = DataTypeUtils::asString(actual); + auto fmessage = message + "; Actual: [" + act + "]"; + return datatype_exception(fmessage); +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/exceptions/impl/graph_exception.cpp b/libnd4j/include/exceptions/impl/graph_exception.cpp index fa2210a1d729..bb289ad54de4 100644 --- a/libnd4j/include/exceptions/impl/graph_exception.cpp +++ b/libnd4j/include/exceptions/impl/graph_exception.cpp @@ -22,33 +22,31 @@ #include namespace sd { - graph_exception::graph_exception(std::string message, Nd4jLong graphId) : std::runtime_error(message) { - this->_message = message; - this->_graphId = graphId; - } - - graph_exception::graph_exception(std::string message, std::string description, Nd4jLong graphId) : std::runtime_error(message) { - this->_message = message; - this->_description = description; - this->_graphId = graphId; - } - - graph_exception::graph_exception(std::string message, const char *description, Nd4jLong graphId) : std::runtime_error(message) { - this->_message = message; - this->_description = description; - this->_graphId = graphId; - } - - - Nd4jLong graph_exception::graphId() { - return _graphId; - } - - const char* graph_exception::message() { - return _message.c_str(); - } - - const char* graph_exception::description() { - return _description.c_str(); - } +graph_exception::graph_exception(std::string message, Nd4jLong graphId) + : std::runtime_error(message) { + this->_message = message; + this->_graphId = graphId; } + +graph_exception::graph_exception(std::string message, std::string description, + Nd4jLong graphId) + : std::runtime_error(message) { + this->_message = message; + this->_description = description; + this->_graphId = graphId; +} + +graph_exception::graph_exception(std::string message, const char* description, + Nd4jLong graphId) + : std::runtime_error(message) { + this->_message = message; + this->_description = description; + this->_graphId = graphId; +} + +Nd4jLong graph_exception::graphId() { return _graphId; } + +const char* graph_exception::message() { return _message.c_str(); } + +const char* graph_exception::description() { return _description.c_str(); } +} // namespace sd diff --git a/libnd4j/include/exceptions/impl/graph_execution_exception.cpp b/libnd4j/include/exceptions/impl/graph_execution_exception.cpp index aff8f76638e5..340340f34398 100644 --- a/libnd4j/include/exceptions/impl/graph_execution_exception.cpp +++ b/libnd4j/include/exceptions/impl/graph_execution_exception.cpp @@ -18,15 +18,20 @@ // Created by raver on 8/31/2018. // -#include #include +#include namespace sd { - graph_execution_exception::graph_execution_exception(Nd4jLong graphId) : graph_exception(StringUtils::buildGraphErrorMessage("Caught exception during graph execution", graphId), graphId) { - _graphId = graphId; - } +graph_execution_exception::graph_execution_exception(Nd4jLong graphId) + : graph_exception(StringUtils::buildGraphErrorMessage( + "Caught exception during graph execution", graphId), + graphId) { + _graphId = graphId; +} - graph_execution_exception::graph_execution_exception(const std::string &message, Nd4jStatus status) : graph_exception(message, status) { - // - } +graph_execution_exception::graph_execution_exception(const std::string &message, + Nd4jStatus status) + : graph_exception(message, status) { + // } +} // namespace sd diff --git a/libnd4j/include/exceptions/impl/graph_exists_exception.cpp b/libnd4j/include/exceptions/impl/graph_exists_exception.cpp index 535a74a6a0fe..8889aed2a628 100644 --- a/libnd4j/include/exceptions/impl/graph_exists_exception.cpp +++ b/libnd4j/include/exceptions/impl/graph_exists_exception.cpp @@ -18,11 +18,14 @@ // Created by raver on 8/31/2018. // -#include #include +#include namespace sd { - graph_exists_exception::graph_exists_exception(Nd4jLong graphId) : graph_exception(StringUtils::buildGraphErrorMessage("Graph with given ID already exists", graphId), graphId) { - _graphId = graphId; - } +graph_exists_exception::graph_exists_exception(Nd4jLong graphId) + : graph_exception(StringUtils::buildGraphErrorMessage( + "Graph with given ID already exists", graphId), + graphId) { + _graphId = graphId; } +} // namespace sd diff --git a/libnd4j/include/exceptions/impl/no_results_exception.cpp b/libnd4j/include/exceptions/impl/no_results_exception.cpp index ce3122ffbd8a..a20cd77c97fa 100644 --- a/libnd4j/include/exceptions/impl/no_results_exception.cpp +++ b/libnd4j/include/exceptions/impl/no_results_exception.cpp @@ -18,11 +18,14 @@ // Created by raver on 8/31/2018. // -#include #include +#include namespace sd { - no_results_exception::no_results_exception(Nd4jLong graphId) : graph_exception(StringUtils::buildGraphErrorMessage("Got no results after graph execution", graphId), graphId) { - _graphId = graphId; - } +no_results_exception::no_results_exception(Nd4jLong graphId) + : graph_exception(StringUtils::buildGraphErrorMessage( + "Got no results after graph execution", graphId), + graphId) { + _graphId = graphId; } +} // namespace sd diff --git a/libnd4j/include/exceptions/impl/shape_mismatch_exception.cpp b/libnd4j/include/exceptions/impl/shape_mismatch_exception.cpp index 5b53ccc561ba..da8ee1241eda 100644 --- a/libnd4j/include/exceptions/impl/shape_mismatch_exception.cpp +++ b/libnd4j/include/exceptions/impl/shape_mismatch_exception.cpp @@ -22,15 +22,18 @@ #include namespace sd { - shape_mismatch_exception::shape_mismatch_exception(const std::string &message) : std::runtime_error(message) { - // - } +shape_mismatch_exception::shape_mismatch_exception(const std::string &message) + : std::runtime_error(message) { + // +} - shape_mismatch_exception - shape_mismatch_exception::build(const std::string &message, const std::vector &expected, const std::vector &actual) { - auto exp = ShapeUtils::shapeAsString(expected); - auto act = ShapeUtils::shapeAsString(actual); - auto fmessage = message + "; Expected shape: " + exp + "; Actual shape: " + act + ";"; - return shape_mismatch_exception(fmessage); - } +shape_mismatch_exception shape_mismatch_exception::build( + const std::string &message, const std::vector &expected, + const std::vector &actual) { + auto exp = ShapeUtils::shapeAsString(expected); + auto act = ShapeUtils::shapeAsString(actual); + auto fmessage = + message + "; Expected shape: " + exp + "; Actual shape: " + act + ";"; + return shape_mismatch_exception(fmessage); } +} // namespace sd diff --git a/libnd4j/include/exceptions/impl/unknown_graph_exception.cpp b/libnd4j/include/exceptions/impl/unknown_graph_exception.cpp index ad73f3d33353..5cb64defe169 100644 --- a/libnd4j/include/exceptions/impl/unknown_graph_exception.cpp +++ b/libnd4j/include/exceptions/impl/unknown_graph_exception.cpp @@ -18,11 +18,14 @@ // Created by raver on 8/31/2018. // -#include #include +#include namespace sd { - unknown_graph_exception::unknown_graph_exception(Nd4jLong graphId) : graph_exception(StringUtils::buildGraphErrorMessage("Unknown graph", graphId), graphId) { - _graphId = graphId; - } +unknown_graph_exception::unknown_graph_exception(Nd4jLong graphId) + : graph_exception( + StringUtils::buildGraphErrorMessage("Unknown graph", graphId), + graphId) { + _graphId = graphId; } +} // namespace sd diff --git a/libnd4j/include/exceptions/no_results_exception.h b/libnd4j/include/exceptions/no_results_exception.h index 97a281826d26..415a76018d91 100644 --- a/libnd4j/include/exceptions/no_results_exception.h +++ b/libnd4j/include/exceptions/no_results_exception.h @@ -21,24 +21,26 @@ #ifndef SD_NO_RESULTS_EXCEPTION_H #define SD_NO_RESULTS_EXCEPTION_H +#include +#include #include #include + #include -#include -#include #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class SD_EXPORT no_results_exception: public graph_exception { - public: - explicit no_results_exception(Nd4jLong graphId); - }; -} +class SD_EXPORT no_results_exception : public graph_exception { + public: + explicit no_results_exception(Nd4jLong graphId); +}; +} // namespace sd -#endif //SD_UNKNOWN_GRAPH_EXCEPTION_H +#endif // SD_UNKNOWN_GRAPH_EXCEPTION_H diff --git a/libnd4j/include/exceptions/shape_mismatch_exception.h b/libnd4j/include/exceptions/shape_mismatch_exception.h index 8a9c2092e860..48868f892697 100644 --- a/libnd4j/include/exceptions/shape_mismatch_exception.h +++ b/libnd4j/include/exceptions/shape_mismatch_exception.h @@ -18,33 +18,35 @@ // @author raver119@gmail.com // - #ifndef SD_SHAPE_MISMATCH_EXCEPTION_H #define SD_SHAPE_MISMATCH_EXCEPTION_H +#include +#include #include #include + #include -#include -#include #include #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class SD_EXPORT shape_mismatch_exception : public std::runtime_error { - public: - shape_mismatch_exception(const std::string &message); - ~shape_mismatch_exception() = default; - - static shape_mismatch_exception build(const std::string &message, const std::vector &expected, const std::vector &actual); - }; -} - - -#endif //SD_SHAPE_MISMATCH_EXCEPTION_H +class SD_EXPORT shape_mismatch_exception : public std::runtime_error { + public: + shape_mismatch_exception(const std::string &message); + ~shape_mismatch_exception() = default; + + static shape_mismatch_exception build(const std::string &message, + const std::vector &expected, + const std::vector &actual); +}; +} // namespace sd + +#endif // SD_SHAPE_MISMATCH_EXCEPTION_H diff --git a/libnd4j/include/exceptions/unknown_graph_exception.h b/libnd4j/include/exceptions/unknown_graph_exception.h index 39b63cda25d2..2ca79c6d272d 100644 --- a/libnd4j/include/exceptions/unknown_graph_exception.h +++ b/libnd4j/include/exceptions/unknown_graph_exception.h @@ -21,24 +21,26 @@ #ifndef SD_UNKNOWN_GRAPH_EXCEPTION_H #define SD_UNKNOWN_GRAPH_EXCEPTION_H +#include +#include #include #include + #include -#include -#include #if defined(_MSC_VER) -// we're ignoring warning about non-exportable parent class, since std::runtime_error is a part of Standard C++ Library -#pragma warning( disable : 4275 ) +// we're ignoring warning about non-exportable parent class, since +// std::runtime_error is a part of Standard C++ Library +#pragma warning(disable : 4275) #endif namespace sd { - class SD_EXPORT unknown_graph_exception: public graph_exception { - public: - explicit unknown_graph_exception(Nd4jLong graphId); - }; -} +class SD_EXPORT unknown_graph_exception : public graph_exception { + public: + explicit unknown_graph_exception(Nd4jLong graphId); +}; +} // namespace sd -#endif //SD_UNKNOWN_GRAPH_EXCEPTION_H +#endif // SD_UNKNOWN_GRAPH_EXCEPTION_H diff --git a/libnd4j/include/execution/AffinityManager.h b/libnd4j/include/execution/AffinityManager.h index d054dc885bd9..ae525b9c955e 100644 --- a/libnd4j/include/execution/AffinityManager.h +++ b/libnd4j/include/execution/AffinityManager.h @@ -23,24 +23,25 @@ #include #include + #include #include namespace sd { - class SD_EXPORT AffinityManager { - private: - static std::atomic _lastDevice; - static int _numberOfDevices; - static std::mutex _currentMutex; - static std::mutex _numberMutex; +class SD_EXPORT AffinityManager { + private: + static std::atomic _lastDevice; + static int _numberOfDevices; + static std::mutex _currentMutex; + static std::mutex _numberMutex; - public: - static int currentNativeDeviceId(); - static int currentDeviceId(); - static int numberOfDevices(); - static void setCurrentDevice(int deviceId); - static void setCurrentNativeDevice(int deviceId); - }; -} + public: + static int currentNativeDeviceId(); + static int currentDeviceId(); + static int numberOfDevices(); + static void setCurrentDevice(int deviceId); + static void setCurrentNativeDevice(int deviceId); +}; +} // namespace sd -#endif //SD_AFFINITYMANAGER_H +#endif // SD_AFFINITYMANAGER_H diff --git a/libnd4j/include/execution/BlockingQueue.h b/libnd4j/include/execution/BlockingQueue.h index b3c2c654c773..de6957f13063 100644 --- a/libnd4j/include/execution/BlockingQueue.h +++ b/libnd4j/include/execution/BlockingQueue.h @@ -21,32 +21,33 @@ #ifndef SAMEDIFF_BLOCKINGQUEUE_H #define SAMEDIFF_BLOCKINGQUEUE_H -#include -#include -#include #include #include +#include +#include +#include namespace samediff { - template - class BlockingQueue { - private: - std::queue _queue; - std::mutex _lock; - std::atomic _size; - std::atomic _available; - - std::condition_variable _condition; - public: - BlockingQueue(int queueSize); - ~BlockingQueue() = default; - T poll(); - void put(const T &t); - - bool available(); - void markAvailable(); - void markUnavailable(); - }; -} - -#endif //SD_BLOCKINGQUEUE_H +template +class BlockingQueue { + private: + std::queue _queue; + std::mutex _lock; + std::atomic _size; + std::atomic _available; + + std::condition_variable _condition; + + public: + BlockingQueue(int queueSize); + ~BlockingQueue() = default; + T poll(); + void put(const T &t); + + bool available(); + void markAvailable(); + void markUnavailable(); +}; +} // namespace samediff + +#endif // SD_BLOCKINGQUEUE_H diff --git a/libnd4j/include/execution/CallableInterface.h b/libnd4j/include/execution/CallableInterface.h index c5053ecbc933..06a4a7ca7243 100644 --- a/libnd4j/include/execution/CallableInterface.h +++ b/libnd4j/include/execution/CallableInterface.h @@ -22,73 +22,82 @@ #define SAMEDIFF_CALLABLEINTERFACE_H #include + +#include +#include +#include #include #include -#include -#include #include -#include namespace samediff { - /** - * This class is suited for passing functions to execution threads without queues - */ - class CallableInterface { - private: - // parallel_for functions - FUNC_1D _function_1d; - FUNC_2D _function_2d; - FUNC_3D _function_3d; - - // parallel function - FUNC_DO _function_do; - - // reduction functions - FUNC_RL _function_rl; - FUNC_RD _function_rd; - - std::array _arguments; - - volatile int _branch = 0; - volatile uint32_t _thread_id = 0; - volatile uint32_t _num_threads = 0; - - std::atomic _finished; - std::atomic _filled; - std::atomic _available; - - std::condition_variable _starter; - std::condition_variable _finisher; - - int64_t* _lptr = nullptr; - double* _dptr = nullptr; - - std::mutex _ms; - std::mutex _mf; - public: - CallableInterface(); - ~CallableInterface() = default; - - void waitForTask(); - void waitForCompletion(); - - void fill(int thread_id, int num_threads, int64_t *lpt, FUNC_RL func, int64_t start_x, int64_t stop_x, int64_t inc_x); - void fill(int thread_id, int num_threads, double *dpt, FUNC_RD func, int64_t start_x, int64_t stop_x, int64_t inc_x); - - void fill(int thread_id, int num_threads, FUNC_DO func); - void fill(int thread_id, int num_threads, FUNC_1D func, int64_t start_x, int64_t stop_x, int64_t inc_x); - void fill(int thread_id, int num_threads, FUNC_2D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y); - void fill(int thread_id, int num_threads, FUNC_3D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z); - - bool available(); - void markAvailable(); - void markUnavailable(); - - void finish(); - - void execute(); - }; -} - - -#endif //SD_CALLABLEINTERFACE_H +/** + * This class is suited for passing functions to execution threads without + * queues + */ +class CallableInterface { + private: + // parallel_for functions + FUNC_1D _function_1d; + FUNC_2D _function_2d; + FUNC_3D _function_3d; + + // parallel function + FUNC_DO _function_do; + + // reduction functions + FUNC_RL _function_rl; + FUNC_RD _function_rd; + + std::array _arguments; + + volatile int _branch = 0; + volatile uint32_t _thread_id = 0; + volatile uint32_t _num_threads = 0; + + std::atomic _finished; + std::atomic _filled; + std::atomic _available; + + std::condition_variable _starter; + std::condition_variable _finisher; + + int64_t* _lptr = nullptr; + double* _dptr = nullptr; + + std::mutex _ms; + std::mutex _mf; + + public: + CallableInterface(); + ~CallableInterface() = default; + + void waitForTask(); + void waitForCompletion(); + + void fill(int thread_id, int num_threads, int64_t* lpt, FUNC_RL func, + int64_t start_x, int64_t stop_x, int64_t inc_x); + void fill(int thread_id, int num_threads, double* dpt, FUNC_RD func, + int64_t start_x, int64_t stop_x, int64_t inc_x); + + void fill(int thread_id, int num_threads, FUNC_DO func); + void fill(int thread_id, int num_threads, FUNC_1D func, int64_t start_x, + int64_t stop_x, int64_t inc_x); + void fill(int thread_id, int num_threads, FUNC_2D func, int64_t start_x, + int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, + int64_t inc_y); + void fill(int thread_id, int num_threads, FUNC_3D func, int64_t start_x, + int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, + int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z); + + bool available(); + void markAvailable(); + void markUnavailable(); + + void finish(); + + void execute(); +}; +} // namespace samediff + +#endif // SD_CALLABLEINTERFACE_H diff --git a/libnd4j/include/execution/CallableWithArguments.h b/libnd4j/include/execution/CallableWithArguments.h index 84a3be3f8b8e..f926492df491 100644 --- a/libnd4j/include/execution/CallableWithArguments.h +++ b/libnd4j/include/execution/CallableWithArguments.h @@ -21,72 +21,77 @@ #ifndef SD_CALLABLEWITHARGUMENTS_H #define SD_CALLABLEWITHARGUMENTS_H -#include -#include +#include + #include #include -#include +#include +#include namespace samediff { - class CallableWithArguments { - FUNC_DO _function_do; - FUNC_1D _function_1d; - FUNC_2D _function_2d; - FUNC_3D _function_3d; - - std::vector _arguments; - - std::atomic _finished; - - std::condition_variable _condition; - - std::mutex _lock; - - int _dimensions = 0; - - uint64_t _threadId; - uint64_t _numThreads; - public: - CallableWithArguments(FUNC_DO func, uint64_t thread_id, uint64_t numThreads); - CallableWithArguments(FUNC_1D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x); - CallableWithArguments(FUNC_2D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y); - CallableWithArguments(FUNC_3D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y, int64_t start_z, int64_t stop_z, int64_t increment_z); - - - /** - * This method returns number of dimensions - * @return - */ - int dimensions(); - - /** - * This method checks if this callable is finished - * @return - */ - bool finished(); - - /** - * this method marks this Callable as finished - */ - void finish(); - - /** - * This method blocks until callable is finished - */ - void waitUntilFinished(); - - std::vector& arguments(); - FUNC_DO function_do(); - FUNC_1D function_1d(); - FUNC_2D function_2d(); - FUNC_3D function_3d(); - - - uint64_t threadId(); - - uint64_t numThreads(); - }; -} - - -#endif //SD_CALLABLEWITHARGUMENTS_H +class CallableWithArguments { + FUNC_DO _function_do; + FUNC_1D _function_1d; + FUNC_2D _function_2d; + FUNC_3D _function_3d; + + std::vector _arguments; + + std::atomic _finished; + + std::condition_variable _condition; + + std::mutex _lock; + + int _dimensions = 0; + + uint64_t _threadId; + uint64_t _numThreads; + + public: + CallableWithArguments(FUNC_DO func, uint64_t thread_id, uint64_t numThreads); + CallableWithArguments(FUNC_1D func, uint64_t thread_id, int64_t start_x, + int64_t stop_x, int64_t increment_x); + CallableWithArguments(FUNC_2D func, uint64_t thread_id, int64_t start_x, + int64_t stop_x, int64_t increment_x, int64_t start_y, + int64_t stop_y, int64_t increment_y); + CallableWithArguments(FUNC_3D func, uint64_t thread_id, int64_t start_x, + int64_t stop_x, int64_t increment_x, int64_t start_y, + int64_t stop_y, int64_t increment_y, int64_t start_z, + int64_t stop_z, int64_t increment_z); + + /** + * This method returns number of dimensions + * @return + */ + int dimensions(); + + /** + * This method checks if this callable is finished + * @return + */ + bool finished(); + + /** + * this method marks this Callable as finished + */ + void finish(); + + /** + * This method blocks until callable is finished + */ + void waitUntilFinished(); + + std::vector& arguments(); + FUNC_DO function_do(); + FUNC_1D function_1d(); + FUNC_2D function_2d(); + FUNC_3D function_3d(); + + uint64_t threadId(); + + uint64_t numThreads(); +}; +} // namespace samediff + +#endif // SD_CALLABLEWITHARGUMENTS_H diff --git a/libnd4j/include/execution/ContextBuffers.h b/libnd4j/include/execution/ContextBuffers.h index 8b1b88bfac89..8cb2891febdc 100644 --- a/libnd4j/include/execution/ContextBuffers.h +++ b/libnd4j/include/execution/ContextBuffers.h @@ -21,56 +21,57 @@ #ifndef LIBND4J_CONTEXTBUFFERS_H #define LIBND4J_CONTEXTBUFFERS_H +#include #include #include -#include namespace sd { - class SD_EXPORT ContextBuffers { - private: - void* _reductionPointer = nullptr; - void* _scalarPointer = nullptr; - void* _allocationPointer = nullptr; - void* _execStream = nullptr; - void* _specialStream = nullptr; - sd::ErrorReference _errorReference; - bool _allocated = false; - bool _initialized = false; +class SD_EXPORT ContextBuffers { + private: + void* _reductionPointer = nullptr; + void* _scalarPointer = nullptr; + void* _allocationPointer = nullptr; + void* _execStream = nullptr; + void* _specialStream = nullptr; + sd::ErrorReference _errorReference; + bool _allocated = false; + bool _initialized = false; - int _deviceId = -1; + int _deviceId = -1; - void initialize(); - public: - ContextBuffers(); - ContextBuffers(const ContextBuffers &other); - ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner = false); - ~ContextBuffers(); + void initialize(); - ContextBuffers& operator=(const ContextBuffers& other); - ContextBuffers& operator=(ContextBuffers&& other); + public: + ContextBuffers(); + ContextBuffers(const ContextBuffers& other); + ContextBuffers(void* rPointer, void* sPointer, void* aPointer, + bool isOwner = false); + ~ContextBuffers(); - void release(); + ContextBuffers& operator=(const ContextBuffers& other); + ContextBuffers& operator=(ContextBuffers&& other); - void* reductionBuffer(); - void* scalarBuffer(); - void* allocationBuffer(); + void release(); - void* execStream(); - void* specialStream(); + void* reductionBuffer(); + void* scalarBuffer(); + void* allocationBuffer(); - void setReductionBuffer(void* pointer); - void setScalarBuffer(void* pointer); - void setAllocationBuffer(void* pointer); + void* execStream(); + void* specialStream(); - sd::ErrorReference* errorReference(); + void setReductionBuffer(void* pointer); + void setScalarBuffer(void* pointer); + void setAllocationBuffer(void* pointer); - void triggerOwnership(bool isOwner); + sd::ErrorReference* errorReference(); - int deviceId(); + void triggerOwnership(bool isOwner); - bool isInitialized(); - }; -} + int deviceId(); + bool isInitialized(); +}; +} // namespace sd -#endif //SD_CONTEXTBUFFERS_H +#endif // SD_CONTEXTBUFFERS_H diff --git a/libnd4j/include/execution/Engine.h b/libnd4j/include/execution/Engine.h index cd30867a9bb2..69b16570491a 100644 --- a/libnd4j/include/execution/Engine.h +++ b/libnd4j/include/execution/Engine.h @@ -22,10 +22,10 @@ #define SD_ENGINE_H namespace samediff { - enum Engine { - ENGINE_CPU = 0, - ENGINE_CUDA = 1, - }; +enum Engine { + ENGINE_CPU = 0, + ENGINE_CUDA = 1, +}; } -#endif //SD_ENGINE_H +#endif // SD_ENGINE_H diff --git a/libnd4j/include/execution/ErrorReference.h b/libnd4j/include/execution/ErrorReference.h index 108878cc5f65..2cbbeef4c66b 100644 --- a/libnd4j/include/execution/ErrorReference.h +++ b/libnd4j/include/execution/ErrorReference.h @@ -21,26 +21,27 @@ #ifndef SD_ERRORREFERENCE_H #define SD_ERRORREFERENCE_H -#include #include +#include + namespace sd { - class SD_EXPORT ErrorReference { - private: - int _errorCode = 0; - std::string _errorMessage; - public: - ErrorReference() = default; - ~ErrorReference() = default; +class SD_EXPORT ErrorReference { + private: + int _errorCode = 0; + std::string _errorMessage; - int errorCode(); - const char* errorMessage(); + public: + ErrorReference() = default; + ~ErrorReference() = default; - void setErrorCode(int errorCode); - void setErrorMessage(std::string message); - void setErrorMessage(const char* message); - }; -} + int errorCode(); + const char* errorMessage(); + void setErrorCode(int errorCode); + void setErrorMessage(std::string message); + void setErrorMessage(const char* message); +}; +} // namespace sd -#endif //SD_ERRORREFERENCE_H +#endif // SD_ERRORREFERENCE_H diff --git a/libnd4j/include/execution/ExecutionMode.h b/libnd4j/include/execution/ExecutionMode.h index ea97e3fc9bdf..0c37cdfdbb83 100644 --- a/libnd4j/include/execution/ExecutionMode.h +++ b/libnd4j/include/execution/ExecutionMode.h @@ -22,11 +22,11 @@ #define SD_EXECUTIONMODE_H namespace samediff { - enum ExecutionMode { - MODE_UNDEFINED = 0, - MODE_TRAINING = 1, - MODE_INFERENCE = 2, - }; +enum ExecutionMode { + MODE_UNDEFINED = 0, + MODE_TRAINING = 1, + MODE_INFERENCE = 2, +}; } -#endif //SD_EXECUTIONMODE_H +#endif // SD_EXECUTIONMODE_H diff --git a/libnd4j/include/execution/Executor.h b/libnd4j/include/execution/Executor.h index a9eaa6ad36c0..02646cf73f46 100644 --- a/libnd4j/include/execution/Executor.h +++ b/libnd4j/include/execution/Executor.h @@ -22,12 +22,12 @@ #define SD_EXECUTOR_H namespace sd { - class Executor { - public: - static void execute() { - // - } - }; -} +class Executor { + public: + static void execute() { + // + } +}; +} // namespace sd -#endif //SD_EXECUTOR_H +#endif // SD_EXECUTOR_H diff --git a/libnd4j/include/execution/LaunchContext.h b/libnd4j/include/execution/LaunchContext.h index 308f028498bc..49910af3c093 100644 --- a/libnd4j/include/execution/LaunchContext.h +++ b/libnd4j/include/execution/LaunchContext.h @@ -21,12 +21,12 @@ #ifndef LIBND4J_CUDACONTEXT_H #define LIBND4J_CUDACONTEXT_H - #ifdef __CUDABLAS__ #include -#include -#include #include +#include +#include + #include "config.h" #endif @@ -35,101 +35,100 @@ #include "config.h" #endif -#include -#include -#include -#include -#include -#include #include #include +#include +#include +#include +#include +#include +#include - -namespace sd { +namespace sd { class SD_EXPORT LaunchContext { + private: + static std::vector> _contexts; + static std::mutex _mutex; - private: - static std::vector> _contexts; - static std::mutex _mutex; - - static MAP_IMPL _deviceMutexes; + static MAP_IMPL _deviceMutexes; - // used for MKLDNN - void *_engine = nullptr; + // used for MKLDNN + void* _engine = nullptr; #ifdef __CUDABLAS__ #ifndef __JAVACPP_HACK__ - void* _cublasHandle = nullptr; - void* _cusolverHandle = nullptr; + void* _cublasHandle = nullptr; + void* _cusolverHandle = nullptr; -#endif // JCPP +#endif // JCPP - bool _isAllocated = false; -#endif // CUDA - sd::memory::Workspace* _workspace = nullptr; - int _deviceID = 0; + bool _isAllocated = false; +#endif // CUDA + sd::memory::Workspace* _workspace = nullptr; + int _deviceID = 0; - public: + public: #ifdef __CUDABLAS__ #ifndef __JAVACPP_HACK__ - LaunchContext(cudaStream_t* cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer = nullptr, void* scalarPointer = nullptr, int* allocationPointer = nullptr); - - void* getReductionPointer () const; - void* getScalarPointer() const; - int* getAllocationPointer() const; - void* getCublasHandle() const; - void* getCusolverHandle() const; - void* getCuDnnHandle() const; - cudaStream_t* getCudaStream() const; - cudaStream_t* getCudaSpecialStream() const; - - void setReductionPointer (void* reductionPointer); - void setScalarPointer(void* scalarPointer); - void setAllocationPointer(int* allocationPointer); - void setCudaStream(cudaStream_t* cudaStream); - void setCudaSpecialStream(cudaStream_t* cudaStream); - void setCublasHandle(void *handle); - -#endif // JCPP - -#endif // CUDA - LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer = nullptr, Nd4jPointer scalarPointer = nullptr, Nd4jPointer allocationPointer = nullptr); - LaunchContext(); - ~LaunchContext(); - sd::memory::Workspace* getWorkspace() const { return _workspace; } - void setWorkspace(sd::memory::Workspace* theWorkspace) { - _workspace = theWorkspace; - } - - void* engine(); - - int getDeviceID() const {return _deviceID;} - void setDeviceID(int deviceID) { _deviceID = deviceID; } - sd::ErrorReference* errorReference(); + LaunchContext(cudaStream_t* cudaStream, cudaStream_t& specialCudaStream, + void* reductionPointer = nullptr, void* scalarPointer = nullptr, + int* allocationPointer = nullptr); + + void* getReductionPointer() const; + void* getScalarPointer() const; + int* getAllocationPointer() const; + void* getCublasHandle() const; + void* getCusolverHandle() const; + void* getCuDnnHandle() const; + cudaStream_t* getCudaStream() const; + cudaStream_t* getCudaSpecialStream() const; + + void setReductionPointer(void* reductionPointer); + void setScalarPointer(void* scalarPointer); + void setAllocationPointer(int* allocationPointer); + void setCudaStream(cudaStream_t* cudaStream); + void setCudaSpecialStream(cudaStream_t* cudaStream); + void setCublasHandle(void* handle); + +#endif // JCPP + +#endif // CUDA + LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer = nullptr, + Nd4jPointer scalarPointer = nullptr, + Nd4jPointer allocationPointer = nullptr); + LaunchContext(); + ~LaunchContext(); + sd::memory::Workspace* getWorkspace() const { return _workspace; } + void setWorkspace(sd::memory::Workspace* theWorkspace) { + _workspace = theWorkspace; + } + + void* engine(); + + int getDeviceID() const { return _deviceID; } + void setDeviceID(int deviceID) { _deviceID = deviceID; } + sd::ErrorReference* errorReference(); #ifndef __JAVACPP_HACK__ - // this method returns mutex shared between all threads that use the same device - static std::mutex* deviceMutex(); + // this method returns mutex shared between all threads that use the same + // device + static std::mutex* deviceMutex(); #endif - static bool isInitialized(); - static void releaseBuffers(); - + static bool isInitialized(); + static void releaseBuffers(); - static LaunchContext* defaultContext(); - - - static void swapContextBuffers(ContextBuffers &buffers); + static LaunchContext* defaultContext(); + static void swapContextBuffers(ContextBuffers& buffers); }; -} - +} // namespace sd -#endif //LIBND4J_CUDACONTEXT_H +#endif // LIBND4J_CUDACONTEXT_H diff --git a/libnd4j/include/execution/ThreadPool.h b/libnd4j/include/execution/ThreadPool.h index 050ce5dfbac2..0c83fc23dd86 100644 --- a/libnd4j/include/execution/ThreadPool.h +++ b/libnd4j/include/execution/ThreadPool.h @@ -21,51 +21,55 @@ #ifndef SAMEDIFF_THREADPOOL_H #define SAMEDIFF_THREADPOOL_H -#include -#include -#include -#include -#include #include -#include #include +#include #include + +#include +#include +#include #include +#include +#include namespace samediff { - class SD_EXPORT ThreadPool { - private: - static ThreadPool* _INSTANCE; +class SD_EXPORT ThreadPool { + private: + static ThreadPool* _INSTANCE; + + std::vector _threads; + std::vector*> _queues; + std::vector _interfaces; - std::vector _threads; - std::vector*> _queues; - std::vector _interfaces; + std::mutex _lock; + std::atomic _available; + std::queue _tickets; - std::mutex _lock; - std::atomic _available; - std::queue _tickets; - protected: - ThreadPool(); - ~ThreadPool(); - public: - static ThreadPool* getInstance(); + protected: + ThreadPool(); + ~ThreadPool(); - /** - * This method returns list of pointers to threads ONLY if num_threads of threads were available upon request, returning empty list otherwise - * @param num_threads - * @return - */ - Ticket* tryAcquire(int num_threads); + public: + static ThreadPool* getInstance(); - /** - * This method marks specified number of threads as released, and available for use - * @param num_threads - */ - void release(int num_threads = 1); + /** + * This method returns list of pointers to threads ONLY if num_threads of + * threads were available upon request, returning empty list otherwise + * @param num_threads + * @return + */ + Ticket* tryAcquire(int num_threads); - void release(Ticket *ticket); - }; -} + /** + * This method marks specified number of threads as released, and available + * for use + * @param num_threads + */ + void release(int num_threads = 1); + void release(Ticket* ticket); +}; +} // namespace samediff -#endif //SD_THREADPOOL_H +#endif // SD_THREADPOOL_H diff --git a/libnd4j/include/execution/Threads.h b/libnd4j/include/execution/Threads.h index 26c0c0afd542..147b84bf85dc 100644 --- a/libnd4j/include/execution/Threads.h +++ b/libnd4j/include/execution/Threads.h @@ -14,167 +14,209 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author raver119@gmail.com - // +// +// @author raver119@gmail.com +// #ifndef SAMEDIFF_THREADS_H #define SAMEDIFF_THREADS_H -#include -#include -#include #include +#include #include +#include + +#include namespace samediff { - class SD_EXPORT ThreadsHelper { - public: - static int numberOfThreads(int maxThreads, uint64_t numberOfElements); - static int numberOfThreads2d(int maxThreads, uint64_t iters_x, uint64_t iters_y); - static int numberOfThreads3d(int maxThreads, uint64_t iters_x, uint64_t iters_y, uint64_t iters_z); - static int pickLoop2d(int numThreads, uint64_t iters_x, uint64_t iters_y); - static int pickLoop3d(int numThreads, uint64_t iters_x, uint64_t iters_y, uint64_t iters_z); - }; - - class SD_EXPORT Span { - private: - int64_t _startX, _stopX, _incX; - public: - Span(int64_t start_x, int64_t stop_x, int64_t inc_x); - ~Span() = default; - - int64_t startX() const; - int64_t stopX() const; - int64_t incX() const; - - static Span build(uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x); - }; - - class SD_EXPORT Span2 { - private: - int64_t _startX, _stopX, _incX; - int64_t _startY, _stopY, _incY; - public: - Span2(int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y); - ~Span2() = default; - - int64_t startX() const; - int64_t startY() const; - - int64_t stopX() const; - int64_t stopY() const; - - int64_t incX() const; - int64_t incY() const; - - static Span2 build(int loop, uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y); - }; - - class SD_EXPORT Span3 { - private: - int64_t _startX, _stopX, _incX; - int64_t _startY, _stopY, _incY; - int64_t _startZ, _stopZ, _incZ; - public: - Span3(int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z); - ~Span3() = default; - - int64_t startX() const; - int64_t startY() const; - int64_t startZ() const; - - int64_t stopX() const; - int64_t stopY() const; - int64_t stopZ() const; - - int64_t incX() const; - int64_t incY() const; - int64_t incZ() const; - - static Span3 build(int loop, uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z); - }; - - class SD_EXPORT Threads { - public: - /** - * This function executes 1 dimensional loop for a given number of threads - * PLEASE NOTE: this function can use smaller number of threads than requested. - * - * @param function - * @param numThreads - * @param start - * @param stop - * @param increment - * @return - */ - static int parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = sd::Environment::getInstance()->maxMasterThreads()); - - /** - * This function executes 1 dimensional loop for a given number of threads - * - * @param function - * @param start - * @param stop - * @param increment - * @param numThreads - * @return - */ - static int parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = sd::Environment::getInstance()->maxMasterThreads()); - - /** - * This method will execute function splitting 2 nested loops space with multiple threads - * - * @param function - * @param numThreads - * @param start_x - * @param stop_x - * @param inc_x - * @param start_y - * @param stop_y - * @param inc_y - * @return - */ - static int parallel_for(FUNC_2D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, uint64_t numThreads = sd::Environment::getInstance()->maxMasterThreads(), bool debug = false); - - /** - * This method will execute function splitting 3 nested loops space with multiple threads - * - * @param function - * @param numThreads - * @param start_x - * @param stop_x - * @param inc_x - * @param start_y - * @param stop_y - * @param inc_y - * @param start_z - * @param stop_z - * @param inc_z - * @return - */ - static int parallel_for(FUNC_3D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z, uint64_t numThreads = sd::Environment::getInstance()->maxMasterThreads()); - - /** - * - * @param function - * @param numThreads - * @return - */ - static int parallel_do(FUNC_DO function, uint64_t numThreads = sd::Environment::getInstance()->maxMasterThreads()); - - static int64_t parallel_long(FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = sd::Environment::getInstance()->maxMasterThreads()); - - static double parallel_double(FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = sd::Environment::getInstance()->maxMasterThreads()); - - /** - * This method will execute function in parallel preserving the parts to be aligned increment size - * PLEASE NOTE: this function can use smaller number of threads than requested. - * - */ - static int parallel_aligned_increment(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, size_t type_size = sizeof(float), uint32_t req_numThreads = sd::Environment::getInstance()->maxMasterThreads()); - - }; -} - - -#endif //SAMEDIFF_THREADS_H +class SD_EXPORT ThreadsHelper { + public: + static int numberOfThreads(int maxThreads, uint64_t numberOfElements); + static int numberOfThreads2d(int maxThreads, uint64_t iters_x, + uint64_t iters_y); + static int numberOfThreads3d(int maxThreads, uint64_t iters_x, + uint64_t iters_y, uint64_t iters_z); + static int pickLoop2d(int numThreads, uint64_t iters_x, uint64_t iters_y); + static int pickLoop3d(int numThreads, uint64_t iters_x, uint64_t iters_y, + uint64_t iters_z); +}; + +class SD_EXPORT Span { + private: + int64_t _startX, _stopX, _incX; + + public: + Span(int64_t start_x, int64_t stop_x, int64_t inc_x); + ~Span() = default; + + int64_t startX() const; + int64_t stopX() const; + int64_t incX() const; + + static Span build(uint64_t thread_id, uint64_t num_threads, int64_t start_x, + int64_t stop_x, int64_t inc_x); +}; + +class SD_EXPORT Span2 { + private: + int64_t _startX, _stopX, _incX; + int64_t _startY, _stopY, _incY; + + public: + Span2(int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, + int64_t stop_y, int64_t inc_y); + ~Span2() = default; + + int64_t startX() const; + int64_t startY() const; + + int64_t stopX() const; + int64_t stopY() const; + + int64_t incX() const; + int64_t incY() const; + + static Span2 build(int loop, uint64_t thread_id, uint64_t num_threads, + int64_t start_x, int64_t stop_x, int64_t inc_x, + int64_t start_y, int64_t stop_y, int64_t inc_y); +}; + +class SD_EXPORT Span3 { + private: + int64_t _startX, _stopX, _incX; + int64_t _startY, _stopY, _incY; + int64_t _startZ, _stopZ, _incZ; + + public: + Span3(int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, + int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, + int64_t inc_z); + ~Span3() = default; + + int64_t startX() const; + int64_t startY() const; + int64_t startZ() const; + + int64_t stopX() const; + int64_t stopY() const; + int64_t stopZ() const; + + int64_t incX() const; + int64_t incY() const; + int64_t incZ() const; + + static Span3 build(int loop, uint64_t thread_id, uint64_t num_threads, + int64_t start_x, int64_t stop_x, int64_t inc_x, + int64_t start_y, int64_t stop_y, int64_t inc_y, + int64_t start_z, int64_t stop_z, int64_t inc_z); +}; + +class SD_EXPORT Threads { + public: + /** + * This function executes 1 dimensional loop for a given number of threads + * PLEASE NOTE: this function can use smaller number of threads than + * requested. + * + * @param function + * @param numThreads + * @param start + * @param stop + * @param increment + * @return + */ + static int parallel_for( + FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, + uint32_t numThreads = sd::Environment::getInstance()->maxMasterThreads()); + + /** + * This function executes 1 dimensional loop for a given number of threads + * + * @param function + * @param start + * @param stop + * @param increment + * @param numThreads + * @return + */ + static int parallel_tad( + FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, + uint32_t numThreads = sd::Environment::getInstance()->maxMasterThreads()); + + /** + * This method will execute function splitting 2 nested loops space with + * multiple threads + * + * @param function + * @param numThreads + * @param start_x + * @param stop_x + * @param inc_x + * @param start_y + * @param stop_y + * @param inc_y + * @return + */ + static int parallel_for( + FUNC_2D function, int64_t start_x, int64_t stop_x, int64_t inc_x, + int64_t start_y, int64_t stop_y, int64_t inc_y, + uint64_t numThreads = sd::Environment::getInstance()->maxMasterThreads(), + bool debug = false); + + /** + * This method will execute function splitting 3 nested loops space with + * multiple threads + * + * @param function + * @param numThreads + * @param start_x + * @param stop_x + * @param inc_x + * @param start_y + * @param stop_y + * @param inc_y + * @param start_z + * @param stop_z + * @param inc_z + * @return + */ + static int parallel_for( + FUNC_3D function, int64_t start_x, int64_t stop_x, int64_t inc_x, + int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, + int64_t stop_z, int64_t inc_z, + uint64_t numThreads = sd::Environment::getInstance()->maxMasterThreads()); + + /** + * + * @param function + * @param numThreads + * @return + */ + static int parallel_do( + FUNC_DO function, + uint64_t numThreads = sd::Environment::getInstance()->maxMasterThreads()); + + static int64_t parallel_long( + FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, + int64_t increment = 1, + uint64_t numThreads = sd::Environment::getInstance()->maxMasterThreads()); + + static double parallel_double( + FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, + int64_t increment = 1, + uint64_t numThreads = sd::Environment::getInstance()->maxMasterThreads()); + + /** + * This method will execute function in parallel preserving the parts to be + * aligned increment size PLEASE NOTE: this function can use smaller number of + * threads than requested. + * + */ + static int parallel_aligned_increment( + FUNC_1D function, int64_t start, int64_t stop, int64_t increment, + size_t type_size = sizeof(float), + uint32_t req_numThreads = + sd::Environment::getInstance()->maxMasterThreads()); +}; +} // namespace samediff + +#endif // SAMEDIFF_THREADS_H diff --git a/libnd4j/include/execution/Ticket.h b/libnd4j/include/execution/Ticket.h index e6f8fdf6c003..d66039c6620e 100644 --- a/libnd4j/include/execution/Ticket.h +++ b/libnd4j/include/execution/Ticket.h @@ -21,47 +21,57 @@ #ifndef SAMEDIFF_TICKET_H #define SAMEDIFF_TICKET_H -#include #include -#include #include +#include + #include #include +#include namespace samediff { - class SD_EXPORT Ticket { - private: - bool _acquired = false; - std::vector*> _queues; - std::vector _callables; - std::vector _interfaces; +class SD_EXPORT Ticket { + private: + bool _acquired = false; + std::vector *> _queues; + std::vector _callables; + std::vector _interfaces; - uint32_t _acquiredThreads = 0; - public: - explicit Ticket(const std::vector*> &queues); - Ticket(); - ~Ticket() = default; + uint32_t _acquiredThreads = 0; - bool acquired(); + public: + explicit Ticket( + const std::vector *> &queues); + Ticket(); + ~Ticket() = default; - void acquiredThreads(uint32_t threads); + bool acquired(); - void attach(uint32_t thread_id, CallableInterface *interface); + void acquiredThreads(uint32_t threads); - // deprecated one - void enqueue(int thread_id, CallableWithArguments* callable); + void attach(uint32_t thread_id, CallableInterface *interface); - void enqueue(uint32_t thread_id, uint32_t num_threads, int64_t *lpt, FUNC_RL func, int64_t start_x, int64_t stop_x, int64_t inc_x); - void enqueue(uint32_t thread_id, uint32_t num_threads, double *lpt, FUNC_RD func, int64_t start_x, int64_t stop_x, int64_t inc_x); + // deprecated one + void enqueue(int thread_id, CallableWithArguments *callable); - void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_DO func); - void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_1D func, int64_t start_x, int64_t stop_x, int64_t inc_x); - void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_2D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y); - void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_3D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_, int64_t stop_z, int64_t inc_z); + void enqueue(uint32_t thread_id, uint32_t num_threads, int64_t *lpt, + FUNC_RL func, int64_t start_x, int64_t stop_x, int64_t inc_x); + void enqueue(uint32_t thread_id, uint32_t num_threads, double *lpt, + FUNC_RD func, int64_t start_x, int64_t stop_x, int64_t inc_x); - void waitAndRelease(); - }; -} + void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_DO func); + void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_1D func, + int64_t start_x, int64_t stop_x, int64_t inc_x); + void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_2D func, + int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, + int64_t stop_y, int64_t inc_y); + void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_3D func, + int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, + int64_t stop_y, int64_t inc_y, int64_t start_, int64_t stop_z, + int64_t inc_z); + void waitAndRelease(); +}; +} // namespace samediff -#endif //SD_TICKET_H +#endif // SD_TICKET_H diff --git a/libnd4j/include/execution/cpu/AffinityManager.cpp b/libnd4j/include/execution/cpu/AffinityManager.cpp index 32df63d0d1e8..f00c1e35eb85 100644 --- a/libnd4j/include/execution/cpu/AffinityManager.cpp +++ b/libnd4j/include/execution/cpu/AffinityManager.cpp @@ -21,23 +21,17 @@ #include namespace sd { - int AffinityManager::currentDeviceId() { - return 0; - } +int AffinityManager::currentDeviceId() { return 0; } - int AffinityManager::currentNativeDeviceId() { - return 0; - } +int AffinityManager::currentNativeDeviceId() { return 0; } - int AffinityManager::numberOfDevices() { - return 1; - } +int AffinityManager::numberOfDevices() { return 1; } - void AffinityManager::setCurrentDevice(int deviceId) { - // no-op - } +void AffinityManager::setCurrentDevice(int deviceId) { + // no-op +} - void AffinityManager::setCurrentNativeDevice(int deviceId) { - // no-op - } -} \ No newline at end of file +void AffinityManager::setCurrentNativeDevice(int deviceId) { + // no-op +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/execution/cpu/ContextBuffers.cpp b/libnd4j/include/execution/cpu/ContextBuffers.cpp index 3b1c566a837d..0e1d6a8d2242 100644 --- a/libnd4j/include/execution/cpu/ContextBuffers.cpp +++ b/libnd4j/include/execution/cpu/ContextBuffers.cpp @@ -17,90 +17,75 @@ // // @author raver119@gmail.com // -#include #include +#include namespace sd { - ContextBuffers::ContextBuffers() { - _deviceId = AffinityManager::currentDeviceId(); - } - - ContextBuffers::~ContextBuffers() { - // no-op - } - - ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) { - _reductionPointer = rPointer; - _scalarPointer = sPointer; - _allocationPointer = aPointer; - _allocated = isOwner; - } - - ContextBuffers::ContextBuffers(const ContextBuffers &other) { - // - } - - void ContextBuffers::initialize() { - // no-op - } - - void* ContextBuffers::reductionBuffer() { - return _reductionPointer; - } - - void* ContextBuffers::scalarBuffer() { - return _scalarPointer; - } - - void* ContextBuffers::allocationBuffer() { - return _allocationPointer; - } - - void ContextBuffers::setReductionBuffer(void* pointer) { - _reductionPointer = pointer; - } - - void ContextBuffers::setScalarBuffer(void* pointer) { - _scalarPointer = pointer; - } - - void ContextBuffers::setAllocationBuffer(void* pointer) { - _allocationPointer = pointer; - } - - void ContextBuffers::triggerOwnership(bool isOwner) { - _allocated = isOwner; - } - - int ContextBuffers::deviceId() { - return _deviceId; - } - - void* ContextBuffers::execStream() { - return _execStream; - } - - void* ContextBuffers::specialStream() { - return _specialStream; - } - - bool ContextBuffers::isInitialized() { - return true; - } - - void ContextBuffers::release() { - // - } - - ContextBuffers& ContextBuffers::operator=(const ContextBuffers& other) { - return *this; - } - - ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) { - return *this; - } - - sd::ErrorReference* ContextBuffers::errorReference() { - return &_errorReference; - } -} \ No newline at end of file +ContextBuffers::ContextBuffers() { + _deviceId = AffinityManager::currentDeviceId(); +} + +ContextBuffers::~ContextBuffers() { + // no-op +} + +ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, + bool isOwner) { + _reductionPointer = rPointer; + _scalarPointer = sPointer; + _allocationPointer = aPointer; + _allocated = isOwner; +} + +ContextBuffers::ContextBuffers(const ContextBuffers& other) { + // +} + +void ContextBuffers::initialize() { + // no-op +} + +void* ContextBuffers::reductionBuffer() { return _reductionPointer; } + +void* ContextBuffers::scalarBuffer() { return _scalarPointer; } + +void* ContextBuffers::allocationBuffer() { return _allocationPointer; } + +void ContextBuffers::setReductionBuffer(void* pointer) { + _reductionPointer = pointer; +} + +void ContextBuffers::setScalarBuffer(void* pointer) { + _scalarPointer = pointer; +} + +void ContextBuffers::setAllocationBuffer(void* pointer) { + _allocationPointer = pointer; +} + +void ContextBuffers::triggerOwnership(bool isOwner) { _allocated = isOwner; } + +int ContextBuffers::deviceId() { return _deviceId; } + +void* ContextBuffers::execStream() { return _execStream; } + +void* ContextBuffers::specialStream() { return _specialStream; } + +bool ContextBuffers::isInitialized() { return true; } + +void ContextBuffers::release() { + // +} + +ContextBuffers& ContextBuffers::operator=(const ContextBuffers& other) { + return *this; +} + +ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) { + return *this; +} + +sd::ErrorReference* ContextBuffers::errorReference() { + return &_errorReference; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/execution/cpu/LaunchContext.cpp b/libnd4j/include/execution/cpu/LaunchContext.cpp index 23e78c350a05..95ad7716820e 100644 --- a/libnd4j/include/execution/cpu/LaunchContext.cpp +++ b/libnd4j/include/execution/cpu/LaunchContext.cpp @@ -18,13 +18,15 @@ // Created by raver119 on 30.11.17. // -#include +#include #include +#include #include -#include + #include -#if defined(SD_IOS_BUILD) || defined(SD_APPLE_BUILD) || defined(SD_ANDROID_BUILD) +#if defined(SD_IOS_BUILD) || defined(SD_APPLE_BUILD) || \ + defined(SD_ANDROID_BUILD) sd::ContextBuffers contextBuffers = sd::ContextBuffers(); #else thread_local sd::ContextBuffers contextBuffers = sd::ContextBuffers(); @@ -36,62 +38,59 @@ thread_local sd::ContextBuffers contextBuffers = sd::ContextBuffers(); namespace sd { - LaunchContext::~LaunchContext() { +LaunchContext::~LaunchContext() { #ifdef HAVE_MKLDNN - delete reinterpret_cast(_engine); + delete reinterpret_cast(_engine); #endif - } +} - std::vector> LaunchContext::_contexts = std::vector>(); - MAP_IMPL LaunchContext::_deviceMutexes; - std::mutex LaunchContext::_mutex; +std::vector> LaunchContext::_contexts = + std::vector>(); +MAP_IMPL LaunchContext::_deviceMutexes; +std::mutex LaunchContext::_mutex; //////////////////////////////////////////////////////////////////////// - LaunchContext::LaunchContext() { - // default constructor, just to make clang/ranlib happy - _workspace = nullptr; - _deviceID = 0; +LaunchContext::LaunchContext() { + // default constructor, just to make clang/ranlib happy + _workspace = nullptr; + _deviceID = 0; #ifdef HAVE_MKLDNN - _engine = new dnnl::engine(dnnl::engine::kind::cpu, 0); + _engine = new dnnl::engine(dnnl::engine::kind::cpu, 0); #endif - } - - LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) { +} - } +LaunchContext::LaunchContext(Nd4jPointer cudaStream, + Nd4jPointer reductionPointer, + Nd4jPointer scalarPointer, + Nd4jPointer allocationPointer) {} - LaunchContext* LaunchContext::defaultContext() { - // TODO: we need it to be device-aware, but only once we add NUMA support for cpu - if (LaunchContext::_contexts.empty()) { - LaunchContext::_contexts.emplace_back(std::make_shared()); - } +LaunchContext* LaunchContext::defaultContext() { + // TODO: we need it to be device-aware, but only once we add NUMA support for + // cpu + if (LaunchContext::_contexts.empty()) { + LaunchContext::_contexts.emplace_back(std::make_shared()); + } - // return context for current device - return LaunchContext::_contexts[0].get(); - } + // return context for current device + return LaunchContext::_contexts[0].get(); +} - std::mutex* LaunchContext::deviceMutex() { - return &_mutex; - } +std::mutex* LaunchContext::deviceMutex() { return &_mutex; } - void LaunchContext::swapContextBuffers(ContextBuffers &buffers) { - // - } +void LaunchContext::swapContextBuffers(ContextBuffers& buffers) { + // +} - bool LaunchContext::isInitialized() { - return true; - } +bool LaunchContext::isInitialized() { return true; } - void LaunchContext::releaseBuffers() { - // - } +void LaunchContext::releaseBuffers() { + // +} - sd::ErrorReference* LaunchContext::errorReference() { - return contextBuffers.errorReference(); - } +sd::ErrorReference* LaunchContext::errorReference() { + return contextBuffers.errorReference(); +} - void* LaunchContext::engine() { - return _engine; - } -} \ No newline at end of file +void* LaunchContext::engine() { return _engine; } +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/execution/cuda/AffinityManager.cu b/libnd4j/include/execution/cuda/AffinityManager.cu index cdfe7c1079d8..d0c9f95b39b9 100644 --- a/libnd4j/include/execution/cuda/AffinityManager.cu +++ b/libnd4j/include/execution/cuda/AffinityManager.cu @@ -18,111 +18,109 @@ // @author raver119@gmail.com // -#include -#include #include +#include #include +#include thread_local int globalThreadToDevice = -1; namespace sd { - std::mutex AffinityManager::_currentMutex; - std::mutex AffinityManager::_numberMutex; - int AffinityManager::_numberOfDevices = -1; +std::mutex AffinityManager::_currentMutex; +std::mutex AffinityManager::_numberMutex; +int AffinityManager::_numberOfDevices = -1; + +int AffinityManager::currentDeviceId() { + // if there's no affinity set - set it now + if (globalThreadToDevice < 0) { + // this block must be thread-local + _currentMutex.lock(); + + globalThreadToDevice = _lastDevice++; + + // we need to check if we've got deviceId >= number of actual devices, and + // reset to zero otherwise + if (globalThreadToDevice >= numberOfDevices()) { + globalThreadToDevice = 0; + _lastDevice = numberOfDevices() > 1 ? 1 : 0; + } - int AffinityManager::currentDeviceId() { - // if there's no affinity set - set it now - if (globalThreadToDevice < 0) { + _currentMutex.unlock(); - // this block must be thread-local - _currentMutex.lock(); + setCurrentNativeDevice(globalThreadToDevice); + } - globalThreadToDevice = _lastDevice++; + // if we already know affinity - just return it + if (globalThreadToDevice >= 0) return globalThreadToDevice; - // we need to check if we've got deviceId >= number of actual devices, and reset to zero otherwise - if (globalThreadToDevice >= numberOfDevices()) { - globalThreadToDevice = 0; - _lastDevice = numberOfDevices() > 1 ? 1 : 0; - } + int dev = 0; + auto res = cudaGetDevice(&dev); - _currentMutex.unlock(); + if (res != 0) throw cuda_exception::build("cudaGetDevice failed", res); - setCurrentNativeDevice(globalThreadToDevice); - } + return dev; +} - // if we already know affinity - just return it - if (globalThreadToDevice >= 0) - return globalThreadToDevice; +int AffinityManager::currentNativeDeviceId() { + int dev = 0; + auto res = cudaGetDevice(&dev); - int dev = 0; - auto res = cudaGetDevice(&dev); + if (res != 0) throw cuda_exception::build("cudaGetDevice failed", res); - if (res != 0) - throw cuda_exception::build("cudaGetDevice failed", res); + return dev; +} - return dev; - } +int AffinityManager::numberOfDevices() { + _numberMutex.lock(); + // we want to cache number of devices + if (_numberOfDevices <= 0) { + int dev = 0; + auto res = cudaGetDeviceCount(&dev); - int AffinityManager::currentNativeDeviceId() { - int dev = 0; - auto res = cudaGetDevice(&dev); + if (res != 0) throw cuda_exception::build("cudaGetDeviceCount failed", res); - if (res != 0) - throw cuda_exception::build("cudaGetDevice failed", res); + _numberOfDevices = dev; + } + _numberMutex.unlock(); - return dev; - } + return _numberOfDevices; +} - int AffinityManager::numberOfDevices() { - _numberMutex.lock(); - // we want to cache number of devices - if (_numberOfDevices <= 0) { - int dev = 0; - auto res = cudaGetDeviceCount(&dev); +void AffinityManager::setCurrentNativeDevice(int deviceId) { + auto res = cudaSetDevice(deviceId); + if (res != 0) throw cuda_exception::build("setCurrentDevice failed", res); +} - if (res != 0) - throw cuda_exception::build("cudaGetDeviceCount failed", res); +void AffinityManager::setCurrentDevice(int deviceId) { + auto previousDeviceId = globalThreadToDevice; + if (previousDeviceId >= 0 && LaunchContext::isInitialized()) { + auto res = cudaStreamSynchronize( + *LaunchContext::defaultContext()->getCudaStream()); + if (res != 0) + throw cuda_exception::build("setCurrentDevice -> sync failed", res); - _numberOfDevices = dev; - } - _numberMutex.unlock(); + res = cudaStreamSynchronize( + *LaunchContext::defaultContext()->getCudaSpecialStream()); + if (res != 0) + throw cuda_exception::build("setCurrentDevice -> specialSync failed", + res); - return _numberOfDevices; + if (deviceId != previousDeviceId) { + // discard existing stuff + // nd4j_printf("AffinityManager::setCurrentDevice() was invoked, releasing + // buffers\n", ""); + LaunchContext::releaseBuffers(); } + } - void AffinityManager::setCurrentNativeDevice(int deviceId) { - auto res = cudaSetDevice(deviceId); - if (res != 0) - throw cuda_exception::build("setCurrentDevice failed", res); - } + if (deviceId != previousDeviceId) { + auto res = cudaSetDevice(deviceId); + if (res != 0) throw cuda_exception::build("cudaSetDevice failed", res); + } - void AffinityManager::setCurrentDevice(int deviceId) { - auto previousDeviceId = globalThreadToDevice; - if (previousDeviceId >= 0 && LaunchContext::isInitialized()) { - auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); - if (res != 0) - throw cuda_exception::build("setCurrentDevice -> sync failed", res); - - res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaSpecialStream()); - if (res != 0) - throw cuda_exception::build("setCurrentDevice -> specialSync failed", res); - - if (deviceId != previousDeviceId) { - // discard existing stuff - //nd4j_printf("AffinityManager::setCurrentDevice() was invoked, releasing buffers\n", ""); - LaunchContext::releaseBuffers(); - } - } - - if (deviceId != previousDeviceId) { - auto res = cudaSetDevice(deviceId); - if (res != 0) - throw cuda_exception::build("cudaSetDevice failed", res); - } - - // update thread-device affinity - globalThreadToDevice = deviceId; - } + // update thread-device affinity + globalThreadToDevice = deviceId; +} - std::atomic AffinityManager::_lastDevice;// = std::atomic(initialV); -} \ No newline at end of file +std::atomic AffinityManager::_lastDevice; // = std::atomic(initialV); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/execution/cuda/ContextBuffers.cu b/libnd4j/include/execution/cuda/ContextBuffers.cu index 9411a27d5fea..9b1518b43b86 100644 --- a/libnd4j/include/execution/cuda/ContextBuffers.cu +++ b/libnd4j/include/execution/cuda/ContextBuffers.cu @@ -18,220 +18,211 @@ // @author raver119@gmail.com // -#include -#include -#include -#include - #include -#include -#include #include +#include +#include +#include +#include +#include +#include namespace sd { - ContextBuffers::ContextBuffers() { - //nd4j_printf("Creating ContextBuffers for device [%i]\n", AffinityManager::currentDeviceId()); - _deviceId = AffinityManager::currentDeviceId(); - } +ContextBuffers::ContextBuffers() { + // nd4j_printf("Creating ContextBuffers for device [%i]\n", + // AffinityManager::currentDeviceId()); + _deviceId = AffinityManager::currentDeviceId(); +} + +ContextBuffers::ContextBuffers(const ContextBuffers& other) { + release(); - ContextBuffers::ContextBuffers(const ContextBuffers &other) { - release(); - - this->_initialized = other._initialized; - this->_allocated = other._allocated; - this->_deviceId = other._deviceId; - - this->_specialStream = other._specialStream; - this->_execStream = other._execStream; - this->_allocationPointer = other._allocationPointer; - this->_reductionPointer = other._reductionPointer; - this->_scalarPointer = other._scalarPointer; - } + this->_initialized = other._initialized; + this->_allocated = other._allocated; + this->_deviceId = other._deviceId; - ContextBuffers& ContextBuffers::operator=(const ContextBuffers& other) { - release(); + this->_specialStream = other._specialStream; + this->_execStream = other._execStream; + this->_allocationPointer = other._allocationPointer; + this->_reductionPointer = other._reductionPointer; + this->_scalarPointer = other._scalarPointer; +} - this->_initialized = other._initialized; - this->_allocated = other._allocated; - this->_deviceId = other._deviceId; +ContextBuffers& ContextBuffers::operator=(const ContextBuffers& other) { + release(); - this->_specialStream = other._specialStream; - this->_execStream = other._execStream; - this->_allocationPointer = other._allocationPointer; - this->_reductionPointer = other._reductionPointer; - this->_scalarPointer = other._scalarPointer; - - return *this; - } + this->_initialized = other._initialized; + this->_allocated = other._allocated; + this->_deviceId = other._deviceId; - ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) { - release(); + this->_specialStream = other._specialStream; + this->_execStream = other._execStream; + this->_allocationPointer = other._allocationPointer; + this->_reductionPointer = other._reductionPointer; + this->_scalarPointer = other._scalarPointer; - this->_initialized = other._initialized; - this->_allocated = other._allocated; - this->_deviceId = other._deviceId; + return *this; +} - this->_specialStream = other._specialStream; - this->_execStream = other._execStream; - this->_allocationPointer = other._allocationPointer; - this->_reductionPointer = other._reductionPointer; - this->_scalarPointer = other._scalarPointer; +ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) { + release(); - return *this; - } + this->_initialized = other._initialized; + this->_allocated = other._allocated; + this->_deviceId = other._deviceId; - void ContextBuffers::release() { - if (_allocated) { - //nd4j_printf("Releasing ContextBuffers on device [%i]\n", _deviceId); + this->_specialStream = other._specialStream; + this->_execStream = other._execStream; + this->_allocationPointer = other._allocationPointer; + this->_reductionPointer = other._reductionPointer; + this->_scalarPointer = other._scalarPointer; - if (_allocationPointer != nullptr) - cudaFree(_allocationPointer); + return *this; +} - if (_scalarPointer != nullptr) - cudaFree(_scalarPointer); +void ContextBuffers::release() { + if (_allocated) { + // nd4j_printf("Releasing ContextBuffers on device [%i]\n", _deviceId); - if (_allocationPointer != nullptr) - cudaFree(_reductionPointer); + if (_allocationPointer != nullptr) cudaFree(_allocationPointer); - auto _cudaStream = reinterpret_cast(_execStream); - auto _cudaSpecialStream = reinterpret_cast(_specialStream); + if (_scalarPointer != nullptr) cudaFree(_scalarPointer); - cudaStreamSynchronize(*_cudaStream); - cudaStreamSynchronize(*_cudaSpecialStream); + if (_allocationPointer != nullptr) cudaFree(_reductionPointer); - cudaStreamDestroy(*_cudaStream); - cudaStreamDestroy(*_cudaSpecialStream); + auto _cudaStream = reinterpret_cast(_execStream); + auto _cudaSpecialStream = reinterpret_cast(_specialStream); - delete _cudaStream; - delete _cudaSpecialStream; + cudaStreamSynchronize(*_cudaStream); + cudaStreamSynchronize(*_cudaSpecialStream); - ////// - _allocated = false; - _deviceId = -1; - - this->_specialStream = nullptr; - this->_execStream = nullptr; - this->_allocationPointer = nullptr; - this->_reductionPointer = nullptr; - this->_scalarPointer = nullptr; - } + cudaStreamDestroy(*_cudaStream); + cudaStreamDestroy(*_cudaSpecialStream); - _initialized = false; - } - - ContextBuffers::~ContextBuffers() { - release(); - } - - ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) { - _reductionPointer = rPointer; - _scalarPointer = sPointer; - _allocationPointer = aPointer; - _allocated = isOwner; - } - - void ContextBuffers::initialize() { - _deviceId = AffinityManager::currentNativeDeviceId(); - //nd4j_printf("Initializing buffers on deviceId [%i]\n", _deviceId); + delete _cudaStream; + delete _cudaSpecialStream; - auto res = cudaMalloc(reinterpret_cast(&_reductionPointer), 1024 * 1024 * 8); - if (res != 0) - throw cuda_exception::build("_reductionPointer allocation failed", res); + ////// + _allocated = false; + _deviceId = -1; - res = cudaHostAlloc(reinterpret_cast(&_scalarPointer), 16, cudaHostAllocDefault); - if (res != 0) - throw cuda_exception::build("_scalarPointer allocation failed", res); + this->_specialStream = nullptr; + this->_execStream = nullptr; + this->_allocationPointer = nullptr; + this->_reductionPointer = nullptr; + this->_scalarPointer = nullptr; + } - res = cudaMalloc(reinterpret_cast(&_allocationPointer), 1024 * 1024 * 8); - if (res != 0) - throw cuda_exception::build("_allocationPointer allocation failed", res); + _initialized = false; +} - _execStream = new cudaStream_t(); - _specialStream = new cudaStream_t(); - if (nullptr == _execStream || nullptr == _specialStream) - throw std::runtime_error("Failed to allocate memory for new CUDA stream"); +ContextBuffers::~ContextBuffers() { release(); } - res = cudaStreamCreate(reinterpret_cast(_execStream)); - if (res != 0) - throw cuda_exception::build("Failed to create default CUDA stream with launch context", res); +ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, + bool isOwner) { + _reductionPointer = rPointer; + _scalarPointer = sPointer; + _allocationPointer = aPointer; + _allocated = isOwner; +} - res = cudaStreamCreate(reinterpret_cast(_specialStream)); - if (res != 0) - throw cuda_exception::build("Failed to create special CUDA stream with launch context", res); +void ContextBuffers::initialize() { + _deviceId = AffinityManager::currentNativeDeviceId(); + // nd4j_printf("Initializing buffers on deviceId [%i]\n", _deviceId); + + auto res = + cudaMalloc(reinterpret_cast(&_reductionPointer), 1024 * 1024 * 8); + if (res != 0) + throw cuda_exception::build("_reductionPointer allocation failed", res); + + res = cudaHostAlloc(reinterpret_cast(&_scalarPointer), 16, + cudaHostAllocDefault); + if (res != 0) + throw cuda_exception::build("_scalarPointer allocation failed", res); + + res = cudaMalloc(reinterpret_cast(&_allocationPointer), + 1024 * 1024 * 8); + if (res != 0) + throw cuda_exception::build("_allocationPointer allocation failed", res); + + _execStream = new cudaStream_t(); + _specialStream = new cudaStream_t(); + if (nullptr == _execStream || nullptr == _specialStream) + throw std::runtime_error("Failed to allocate memory for new CUDA stream"); + + res = cudaStreamCreate(reinterpret_cast(_execStream)); + if (res != 0) + throw cuda_exception::build( + "Failed to create default CUDA stream with launch context", res); + + res = cudaStreamCreate(reinterpret_cast(_specialStream)); + if (res != 0) + throw cuda_exception::build( + "Failed to create special CUDA stream with launch context", res); + + _allocated = true; + _initialized = true; +} - _allocated = true; - _initialized = true; - } +void* ContextBuffers::reductionBuffer() { + if (!_initialized) initialize(); - void* ContextBuffers::reductionBuffer() { - if (!_initialized) - initialize(); + return _reductionPointer; +} - return _reductionPointer; - } +void* ContextBuffers::scalarBuffer() { + if (!_initialized) initialize(); - void* ContextBuffers::scalarBuffer() { - if (!_initialized) - initialize(); + return _scalarPointer; +} - return _scalarPointer; - } +void* ContextBuffers::allocationBuffer() { + if (!_initialized) initialize(); - void* ContextBuffers::allocationBuffer() { - if (!_initialized) - initialize(); + return _allocationPointer; +} - return _allocationPointer; - } +void ContextBuffers::setReductionBuffer(void* pointer) { + _reductionPointer = pointer; +} - void ContextBuffers::setReductionBuffer(void* pointer) { - _reductionPointer = pointer; - } +void ContextBuffers::setScalarBuffer(void* pointer) { + _scalarPointer = pointer; +} - void ContextBuffers::setScalarBuffer(void* pointer) { - _scalarPointer = pointer; - } +void ContextBuffers::setAllocationBuffer(void* pointer) { + _allocationPointer = pointer; +} - void ContextBuffers::setAllocationBuffer(void* pointer) { - _allocationPointer = pointer; - } +void ContextBuffers::triggerOwnership(bool isOwner) { _allocated = isOwner; } - void ContextBuffers::triggerOwnership(bool isOwner) { - _allocated = isOwner; - } - - int ContextBuffers::deviceId() { - return _deviceId; - } - - void* ContextBuffers::execStream() { - if (!_initialized) { - //nd4j_printf("execStream not initialized\n", ""); - initialize(); - } else { - //nd4j_printf("execStream is initialized\n", ""); - } - - return _execStream; - } +int ContextBuffers::deviceId() { return _deviceId; } - void* ContextBuffers::specialStream() { - if (!_initialized) { - //nd4j_printf("specialStream not initialized\n", ""); - initialize(); - } else { - //nd4j_printf("specialStream is initialized\n", ""); - } +void* ContextBuffers::execStream() { + if (!_initialized) { + // nd4j_printf("execStream not initialized\n", ""); + initialize(); + } else { + // nd4j_printf("execStream is initialized\n", ""); + } - return _specialStream; - } + return _execStream; +} - bool ContextBuffers::isInitialized() { - return _initialized; - } +void* ContextBuffers::specialStream() { + if (!_initialized) { + // nd4j_printf("specialStream not initialized\n", ""); + initialize(); + } else { + // nd4j_printf("specialStream is initialized\n", ""); + } - sd::ErrorReference* ContextBuffers::errorReference() { - return &_errorReference; - } + return _specialStream; } +bool ContextBuffers::isInitialized() { return _initialized; } + +sd::ErrorReference* ContextBuffers::errorReference() { + return &_errorReference; +} +} // namespace sd diff --git a/libnd4j/include/execution/cuda/LaunchContext.cu b/libnd4j/include/execution/cuda/LaunchContext.cu index 8380e50bf495..8106d654de3c 100644 --- a/libnd4j/include/execution/cuda/LaunchContext.cu +++ b/libnd4j/include/execution/cuda/LaunchContext.cu @@ -18,171 +18,170 @@ // Created by raver119 on 30.11.17. // -#include -#include #include +#include +#include #include +#include + #include -#include thread_local sd::ContextBuffers contextBuffers = sd::ContextBuffers(); namespace sd { - std::vector> LaunchContext::_contexts = std::vector>(); - std::mutex LaunchContext::_mutex; - MAP_IMPL LaunchContext::_deviceMutexes; +std::vector> LaunchContext::_contexts = + std::vector>(); +std::mutex LaunchContext::_mutex; +MAP_IMPL LaunchContext::_deviceMutexes; //////////////////////////////////////////////////////////////////////// -LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) { - - //_cudaStream = cudaStream; - //_cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream; - //_reductionPointer = reductionPointer; - //_scalarPointer = scalarPointer; - //_allocationPointer = allocationPointer; - _workspace = nullptr; - _isAllocated = false; +LaunchContext::LaunchContext(cudaStream_t* cudaStream, + cudaStream_t& specialCudaStream, + void* reductionPointer, void* scalarPointer, + int* allocationPointer) { + //_cudaStream = cudaStream; + //_cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; + //*_cudaSpecialStream = specialCudaStream; _reductionPointer = + //reductionPointer; _scalarPointer = scalarPointer; _allocationPointer = + //allocationPointer; + _workspace = nullptr; + _isAllocated = false; } - std::mutex* LaunchContext::deviceMutex() { - auto deviceId = AffinityManager::currentDeviceId(); - return _deviceMutexes[deviceId]; - } +std::mutex* LaunchContext::deviceMutex() { + auto deviceId = AffinityManager::currentDeviceId(); + return _deviceMutexes[deviceId]; +} LaunchContext::~LaunchContext() { - if (_isAllocated) { - - } + if (_isAllocated) { + } } //////////////////////////////////////////////////////////////////////// LaunchContext::LaunchContext() { - // default constructor, just to make clang/ranlib happy - _workspace = nullptr; - _deviceID = 0; + // default constructor, just to make clang/ranlib happy + _workspace = nullptr; + _deviceID = 0; - _isAllocated = true; + _isAllocated = true; } - LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) { - _isAllocated = false; - //_cudaStream = reinterpret_cast(cudaStream); - // _cudaSpecialStream = reinterpret_cast(cudaStream); - //_reductionPointer = reductionPointer; - //_scalarPointer = scalarPointer; - //_allocationPointer = reinterpret_cast(allocationPointer); - } +LaunchContext::LaunchContext(Nd4jPointer cudaStream, + Nd4jPointer reductionPointer, + Nd4jPointer scalarPointer, + Nd4jPointer allocationPointer) { + _isAllocated = false; + //_cudaStream = reinterpret_cast(cudaStream); + // _cudaSpecialStream = reinterpret_cast(cudaStream); + //_reductionPointer = reductionPointer; + //_scalarPointer = scalarPointer; + //_allocationPointer = reinterpret_cast(allocationPointer); +} - LaunchContext* LaunchContext::defaultContext() { - /** - * This method returns LaunchContext, that has multiple entities within: - * 1) temporary buffers. they must be per-thread - * 2) CUDA stream. it must be either per-thread or per-device - * 3) cuBLAS handle. it must be per-device - */ - auto deviceId = AffinityManager::currentDeviceId(); - - // we need this block synchronous, to avoid double initialization etc - _mutex.lock(); - if (LaunchContext::_contexts.empty()) { - // create one context per device - auto numDevices = AffinityManager::numberOfDevices(); - - _contexts.resize(numDevices); - for (int e = 0; e < numDevices; e++) { - _deviceMutexes[e] = new std::mutex(); - - AffinityManager::setCurrentNativeDevice(e); - - LaunchContext::_contexts[e] = std::make_shared(); - } - - // don't forget to restore device back again - AffinityManager::setCurrentNativeDevice(deviceId); - } - _mutex.unlock(); - - // return context for current device - return LaunchContext::_contexts[deviceId].get(); +LaunchContext* LaunchContext::defaultContext() { + /** + * This method returns LaunchContext, that has multiple entities within: + * 1) temporary buffers. they must be per-thread + * 2) CUDA stream. it must be either per-thread or per-device + * 3) cuBLAS handle. it must be per-device + */ + auto deviceId = AffinityManager::currentDeviceId(); + + // we need this block synchronous, to avoid double initialization etc + _mutex.lock(); + if (LaunchContext::_contexts.empty()) { + // create one context per device + auto numDevices = AffinityManager::numberOfDevices(); + + _contexts.resize(numDevices); + for (int e = 0; e < numDevices; e++) { + _deviceMutexes[e] = new std::mutex(); + + AffinityManager::setCurrentNativeDevice(e); + + LaunchContext::_contexts[e] = std::make_shared(); } + // don't forget to restore device back again + AffinityManager::setCurrentNativeDevice(deviceId); + } + _mutex.unlock(); - void* LaunchContext::getReductionPointer () const { - return contextBuffers.reductionBuffer(); - }; + // return context for current device + return LaunchContext::_contexts[deviceId].get(); +} - void* LaunchContext::getScalarPointer() const { - return contextBuffers.scalarBuffer(); - }; +void* LaunchContext::getReductionPointer() const { + return contextBuffers.reductionBuffer(); +}; - int* LaunchContext::getAllocationPointer() const { - return reinterpret_cast(contextBuffers.allocationBuffer()); - }; +void* LaunchContext::getScalarPointer() const { + return contextBuffers.scalarBuffer(); +}; - void* LaunchContext::getCublasHandle() const { - return CublasHelper::getInstance()->handle(); - }; +int* LaunchContext::getAllocationPointer() const { + return reinterpret_cast(contextBuffers.allocationBuffer()); +}; - void* LaunchContext::getCusolverHandle() const { - return CublasHelper::getInstance()->solver(); - }; +void* LaunchContext::getCublasHandle() const { + return CublasHelper::getInstance()->handle(); +}; - cudaStream_t* LaunchContext::getCudaStream() const { - return reinterpret_cast(contextBuffers.execStream()); - }; +void* LaunchContext::getCusolverHandle() const { + return CublasHelper::getInstance()->solver(); +}; - cudaStream_t* LaunchContext::getCudaSpecialStream() const { - return reinterpret_cast(contextBuffers.specialStream());; - }; +cudaStream_t* LaunchContext::getCudaStream() const { + return reinterpret_cast(contextBuffers.execStream()); +}; +cudaStream_t* LaunchContext::getCudaSpecialStream() const { + return reinterpret_cast(contextBuffers.specialStream()); + ; +}; - void LaunchContext::setReductionPointer (void* reductionPointer) { - contextBuffers.setReductionBuffer(reductionPointer); - }; +void LaunchContext::setReductionPointer(void* reductionPointer) { + contextBuffers.setReductionBuffer(reductionPointer); +}; - void LaunchContext::setScalarPointer(void* scalarPointer) { - contextBuffers.setScalarBuffer(scalarPointer); - }; +void LaunchContext::setScalarPointer(void* scalarPointer) { + contextBuffers.setScalarBuffer(scalarPointer); +}; - void LaunchContext::setAllocationPointer(int* allocationPointer) { - contextBuffers.setAllocationBuffer(allocationPointer); - }; +void LaunchContext::setAllocationPointer(int* allocationPointer) { + contextBuffers.setAllocationBuffer(allocationPointer); +}; - void LaunchContext::setCudaStream(cudaStream_t* cudaStream) { - //_cudaStream = cudaStream; - }; +void LaunchContext::setCudaStream(cudaStream_t* cudaStream){ + //_cudaStream = cudaStream; +}; - void LaunchContext::setCudaSpecialStream(cudaStream_t* cudaStream) { - //_cudaSpecialStream = cudaStream; - }; +void LaunchContext::setCudaSpecialStream(cudaStream_t* cudaStream){ + //_cudaSpecialStream = cudaStream; +}; - void LaunchContext::setCublasHandle(void *handle) { - _cublasHandle = handle; - }; +void LaunchContext::setCublasHandle(void* handle) { _cublasHandle = handle; }; - void LaunchContext::swapContextBuffers(ContextBuffers &buffers) { - contextBuffers = buffers; - }; +void LaunchContext::swapContextBuffers(ContextBuffers& buffers) { + contextBuffers = buffers; +}; - void LaunchContext::releaseBuffers() { - //nd4j_printf("LaunchContext::releaseBuffers() was invoked\n", ""); - contextBuffers.release(); - } +void LaunchContext::releaseBuffers() { + // nd4j_printf("LaunchContext::releaseBuffers() was invoked\n", ""); + contextBuffers.release(); +} - bool LaunchContext::isInitialized() { - return contextBuffers.isInitialized(); - } +bool LaunchContext::isInitialized() { return contextBuffers.isInitialized(); } - void* LaunchContext::getCuDnnHandle() const { - return CublasHelper::getInstance()->cudnn(); - } +void* LaunchContext::getCuDnnHandle() const { + return CublasHelper::getInstance()->cudnn(); +} - sd::ErrorReference* LaunchContext::errorReference() { - return contextBuffers.errorReference(); - } +sd::ErrorReference* LaunchContext::errorReference() { + return contextBuffers.errorReference(); +} - void* LaunchContext::engine() { - return _engine; - } -} \ No newline at end of file +void* LaunchContext::engine() { return _engine; } +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/execution/impl/BlockingQueue.cpp b/libnd4j/include/execution/impl/BlockingQueue.cpp index 21c3b4c6a5f4..80f5f9a58280 100644 --- a/libnd4j/include/execution/impl/BlockingQueue.cpp +++ b/libnd4j/include/execution/impl/BlockingQueue.cpp @@ -20,54 +20,55 @@ #include #include + #include namespace samediff { - template - BlockingQueue::BlockingQueue(int queueSize) { - _size = 0; - _available = true; - } - - template - T BlockingQueue::poll() { - // locking untill there's something within queue - std::unique_lock lock(_lock); - _condition.wait(lock, [&]{ return this->_size.load() != 0; }); +template +BlockingQueue::BlockingQueue(int queueSize) { + _size = 0; + _available = true; +} - T t(std::move(_queue.front())); - _queue.pop(); - _size--; - return t; - } +template +T BlockingQueue::poll() { + // locking untill there's something within queue + std::unique_lock lock(_lock); + _condition.wait(lock, [&] { return this->_size.load() != 0; }); - template - void BlockingQueue::put(const T &t) { - { - // locking before push, unlocking after - std::unique_lock lock(_lock); - _queue.push(t); - _size++; - } + T t(std::move(_queue.front())); + _queue.pop(); + _size--; + return t; +} - // notifying condition - _condition.notify_one(); - } +template +void BlockingQueue::put(const T &t) { + { + // locking before push, unlocking after + std::unique_lock lock(_lock); + _queue.push(t); + _size++; + } - template - bool BlockingQueue::available() { - return _available.load(); - } + // notifying condition + _condition.notify_one(); +} - template - void BlockingQueue::markAvailable() { - _available = true; - } +template +bool BlockingQueue::available() { + return _available.load(); +} - template - void BlockingQueue::markUnavailable() { - _available = false; - } +template +void BlockingQueue::markAvailable() { + _available = true; +} - template class BlockingQueue; +template +void BlockingQueue::markUnavailable() { + _available = false; } + +template class BlockingQueue; +} // namespace samediff diff --git a/libnd4j/include/execution/impl/CallableInterface.cpp b/libnd4j/include/execution/impl/CallableInterface.cpp index a719af848576..8e8c121ac7ab 100644 --- a/libnd4j/include/execution/impl/CallableInterface.cpp +++ b/libnd4j/include/execution/impl/CallableInterface.cpp @@ -22,192 +22,201 @@ #include namespace samediff { - CallableInterface::CallableInterface() { - // initial state is available - _available = true; - _filled = false; - _finished = false; - } - - bool CallableInterface::available() { - return _available.load(); - } - - void CallableInterface::markUnavailable() { - _available = false; - } - - void CallableInterface::markAvailable() { - _available = true; - } - - void CallableInterface::fill(int threadID, int numThreads, FUNC_DO func) { - _function_do = std::move(func); - - _branch = 0; - _num_threads = numThreads; - _thread_id = threadID; - _finished = false; - { - std::unique_lock l(_ms); - _filled = true; - } - _starter.notify_one(); - } - - void CallableInterface::fill(int threadID, int numThreads, FUNC_1D func, int64_t startX, int64_t stopX, int64_t incX) { - _function_1d = std::move(func); - _arguments[0] = startX; - _arguments[1] = stopX; - _arguments[2] = incX; - - _branch = 1; - _num_threads = numThreads; - _thread_id = threadID; - _finished = false; - - { - std::unique_lock l(_ms); - _filled = true; - } - _starter.notify_one(); - } - - void CallableInterface::fill(int threadID, int numThreads, FUNC_2D func, int64_t startX, int64_t stopX, int64_t incX, int64_t start_y, int64_t stop_y, int64_t inc_y) { - _function_2d = std::move(func); - _arguments[0] = startX; - _arguments[1] = stopX; - _arguments[2] = incX; - _arguments[3] = start_y; - _arguments[4] = stop_y; - _arguments[5] = inc_y; - - _branch = 2; - _num_threads = numThreads; - _thread_id = threadID; - _finished = false; - - { - std::unique_lock l(_ms); - _filled = true; - } - _starter.notify_one(); - } - - void CallableInterface::fill(int threadID, int numThreads, FUNC_3D func, int64_t startX, int64_t stopX, int64_t incX, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z) { - _function_3d = std::move(func); - _arguments[0] = startX; - _arguments[1] = stopX; - _arguments[2] = incX; - _arguments[3] = start_y; - _arguments[4] = stop_y; - _arguments[5] = inc_y; - _arguments[6] = start_z; - _arguments[7] = stop_z; - _arguments[8] = inc_z; - - _branch = 3; - _num_threads = numThreads; - _thread_id = threadID; - _finished = false; - - { - std::unique_lock l(_ms); - _filled = true; - } - _starter.notify_one(); - } - - void CallableInterface::fill(int threadID, int numThreads, int64_t *lptr, FUNC_RL func, int64_t startX, int64_t stopX, int64_t incX) { - _function_rl = std::move(func); - _arguments[0] = startX; - _arguments[1] = stopX; - _arguments[2] = incX; - - _lptr = lptr; - - _branch = 4; - _num_threads = numThreads; - _thread_id = threadID; - _finished = false; - - { - std::unique_lock l(_ms); - _filled = true; - } - _starter.notify_one(); - } - - void CallableInterface::fill(int threadID, int numThreads, double *dptr, FUNC_RD func, int64_t startX, int64_t stopX, int64_t incX) { - _function_rd = std::move(func); - _arguments[0] = startX; - _arguments[1] = stopX; - _arguments[2] = incX; - - _dptr = dptr; - - _branch = 5; - _num_threads = numThreads; - _thread_id = threadID; - _finished = false; - - { - std::unique_lock l(_ms); - _filled = true; - } - _starter.notify_one(); - } - - void CallableInterface::waitForTask() { - // block until task is available - std::unique_lock lock(_ms); - _starter.wait(lock, [&]{ return _filled.load(); }); - } - - void CallableInterface::waitForCompletion() { - //while (!_finished.load()); - - // block until finished - std::unique_lock lock(_mf); - _finisher.wait(lock, [&] { return _finished.load(); }); - } - - void CallableInterface::finish() { - // mark as finished - { - std::unique_lock l(_mf); - _finished.store(true); - } - _finisher.notify_one(); - } - - void CallableInterface::execute() { - // mark it as consumed - _filled = false; - - // actually executing op - switch (_branch) { - case 0: - _function_do(_thread_id, _num_threads); - break; - case 1: - _function_1d(_thread_id, _arguments[0], _arguments[1], _arguments[2]); - break; - case 2: - _function_2d(_thread_id, _arguments[0], _arguments[1], _arguments[2], _arguments[3], _arguments[4], _arguments[5]); - break; - case 3: - _function_3d(_thread_id, _arguments[0], _arguments[1], _arguments[2], _arguments[3], _arguments[4], _arguments[5], _arguments[6], _arguments[7], _arguments[8]); - break; - case 4: - _lptr[0] = _function_rl(_thread_id, _arguments[0], _arguments[1], _arguments[2]); - break; - case 5: - _dptr[0] = _function_rd(_thread_id, _arguments[0], _arguments[1], _arguments[2]); - break; - } - - // notify that thread finished the job - this->finish(); - } -} \ No newline at end of file +CallableInterface::CallableInterface() { + // initial state is available + _available = true; + _filled = false; + _finished = false; +} + +bool CallableInterface::available() { return _available.load(); } + +void CallableInterface::markUnavailable() { _available = false; } + +void CallableInterface::markAvailable() { _available = true; } + +void CallableInterface::fill(int threadID, int numThreads, FUNC_DO func) { + _function_do = std::move(func); + + _branch = 0; + _num_threads = numThreads; + _thread_id = threadID; + _finished = false; + { + std::unique_lock l(_ms); + _filled = true; + } + _starter.notify_one(); +} + +void CallableInterface::fill(int threadID, int numThreads, FUNC_1D func, + int64_t startX, int64_t stopX, int64_t incX) { + _function_1d = std::move(func); + _arguments[0] = startX; + _arguments[1] = stopX; + _arguments[2] = incX; + + _branch = 1; + _num_threads = numThreads; + _thread_id = threadID; + _finished = false; + + { + std::unique_lock l(_ms); + _filled = true; + } + _starter.notify_one(); +} + +void CallableInterface::fill(int threadID, int numThreads, FUNC_2D func, + int64_t startX, int64_t stopX, int64_t incX, + int64_t start_y, int64_t stop_y, int64_t inc_y) { + _function_2d = std::move(func); + _arguments[0] = startX; + _arguments[1] = stopX; + _arguments[2] = incX; + _arguments[3] = start_y; + _arguments[4] = stop_y; + _arguments[5] = inc_y; + + _branch = 2; + _num_threads = numThreads; + _thread_id = threadID; + _finished = false; + + { + std::unique_lock l(_ms); + _filled = true; + } + _starter.notify_one(); +} + +void CallableInterface::fill(int threadID, int numThreads, FUNC_3D func, + int64_t startX, int64_t stopX, int64_t incX, + int64_t start_y, int64_t stop_y, int64_t inc_y, + int64_t start_z, int64_t stop_z, int64_t inc_z) { + _function_3d = std::move(func); + _arguments[0] = startX; + _arguments[1] = stopX; + _arguments[2] = incX; + _arguments[3] = start_y; + _arguments[4] = stop_y; + _arguments[5] = inc_y; + _arguments[6] = start_z; + _arguments[7] = stop_z; + _arguments[8] = inc_z; + + _branch = 3; + _num_threads = numThreads; + _thread_id = threadID; + _finished = false; + + { + std::unique_lock l(_ms); + _filled = true; + } + _starter.notify_one(); +} + +void CallableInterface::fill(int threadID, int numThreads, int64_t *lptr, + FUNC_RL func, int64_t startX, int64_t stopX, + int64_t incX) { + _function_rl = std::move(func); + _arguments[0] = startX; + _arguments[1] = stopX; + _arguments[2] = incX; + + _lptr = lptr; + + _branch = 4; + _num_threads = numThreads; + _thread_id = threadID; + _finished = false; + + { + std::unique_lock l(_ms); + _filled = true; + } + _starter.notify_one(); +} + +void CallableInterface::fill(int threadID, int numThreads, double *dptr, + FUNC_RD func, int64_t startX, int64_t stopX, + int64_t incX) { + _function_rd = std::move(func); + _arguments[0] = startX; + _arguments[1] = stopX; + _arguments[2] = incX; + + _dptr = dptr; + + _branch = 5; + _num_threads = numThreads; + _thread_id = threadID; + _finished = false; + + { + std::unique_lock l(_ms); + _filled = true; + } + _starter.notify_one(); +} + +void CallableInterface::waitForTask() { + // block until task is available + std::unique_lock lock(_ms); + _starter.wait(lock, [&] { return _filled.load(); }); +} + +void CallableInterface::waitForCompletion() { + // while (!_finished.load()); + + // block until finished + std::unique_lock lock(_mf); + _finisher.wait(lock, [&] { return _finished.load(); }); +} + +void CallableInterface::finish() { + // mark as finished + { + std::unique_lock l(_mf); + _finished.store(true); + } + _finisher.notify_one(); +} + +void CallableInterface::execute() { + // mark it as consumed + _filled = false; + + // actually executing op + switch (_branch) { + case 0: + _function_do(_thread_id, _num_threads); + break; + case 1: + _function_1d(_thread_id, _arguments[0], _arguments[1], _arguments[2]); + break; + case 2: + _function_2d(_thread_id, _arguments[0], _arguments[1], _arguments[2], + _arguments[3], _arguments[4], _arguments[5]); + break; + case 3: + _function_3d(_thread_id, _arguments[0], _arguments[1], _arguments[2], + _arguments[3], _arguments[4], _arguments[5], _arguments[6], + _arguments[7], _arguments[8]); + break; + case 4: + _lptr[0] = + _function_rl(_thread_id, _arguments[0], _arguments[1], _arguments[2]); + break; + case 5: + _dptr[0] = + _function_rd(_thread_id, _arguments[0], _arguments[1], _arguments[2]); + break; + } + + // notify that thread finished the job + this->finish(); +} +} // namespace samediff \ No newline at end of file diff --git a/libnd4j/include/execution/impl/CallableWithArguments.cpp b/libnd4j/include/execution/impl/CallableWithArguments.cpp index 8f17622b733b..a05fcf1f8eba 100644 --- a/libnd4j/include/execution/impl/CallableWithArguments.cpp +++ b/libnd4j/include/execution/impl/CallableWithArguments.cpp @@ -21,83 +21,75 @@ #include namespace samediff { - CallableWithArguments::CallableWithArguments(FUNC_DO func, uint64_t thread_id, uint64_t numThreads) { - _function_do = func; - _finished = false; - _threadId = thread_id; - _numThreads = numThreads; - _dimensions = 0; - } - - CallableWithArguments::CallableWithArguments(FUNC_3D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y, int64_t start_z, int64_t stop_z, int64_t increment_z) { - _function_3d = func; - _arguments = {start_x, stop_x, increment_x, start_y, stop_y, increment_y, start_z, stop_z, increment_z}; - _finished = false; - _threadId = thread_id; - _dimensions = 3; - } - - CallableWithArguments::CallableWithArguments(FUNC_1D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x) { - _function_1d = func; - _arguments = {start_x, stop_x, increment_x}; - _finished = false; - _threadId = thread_id; - _dimensions = 1; - } - - CallableWithArguments::CallableWithArguments(FUNC_2D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y) { - _function_2d = func; - _arguments = {start_x, stop_x, increment_x, start_y, stop_y, increment_y}; - _finished = false; - _threadId = thread_id; - _dimensions = 2; - } - - int CallableWithArguments::dimensions() { - return _dimensions; - } - - std::vector& CallableWithArguments::arguments() { - return _arguments; - } - - bool CallableWithArguments::finished() { - return _finished.load(); - } - - void CallableWithArguments::finish() { - std::lock_guard lock(_lock); - _finished = true; - _condition.notify_one(); - } - - void CallableWithArguments::waitUntilFinished() { - std::unique_lock lock(_lock); - _condition.wait(lock, [&]{ return _finished.load(); }); - } - - - FUNC_1D CallableWithArguments::function_1d() { - return _function_1d; - } - - FUNC_2D CallableWithArguments::function_2d() { - return _function_2d; - } - - FUNC_DO CallableWithArguments::function_do() { - return _function_do; - } - - FUNC_3D CallableWithArguments::function_3d() { - return _function_3d; - } - - uint64_t CallableWithArguments::threadId() { - return _threadId; - } - - uint64_t CallableWithArguments::numThreads() { - return _numThreads; - } -} \ No newline at end of file +CallableWithArguments::CallableWithArguments(FUNC_DO func, uint64_t thread_id, + uint64_t numThreads) { + _function_do = func; + _finished = false; + _threadId = thread_id; + _numThreads = numThreads; + _dimensions = 0; +} + +CallableWithArguments::CallableWithArguments( + FUNC_3D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, + int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y, + int64_t start_z, int64_t stop_z, int64_t increment_z) { + _function_3d = func; + _arguments = {start_x, stop_x, increment_x, start_y, stop_y, + increment_y, start_z, stop_z, increment_z}; + _finished = false; + _threadId = thread_id; + _dimensions = 3; +} + +CallableWithArguments::CallableWithArguments(FUNC_1D func, uint64_t thread_id, + int64_t start_x, int64_t stop_x, + int64_t increment_x) { + _function_1d = func; + _arguments = {start_x, stop_x, increment_x}; + _finished = false; + _threadId = thread_id; + _dimensions = 1; +} + +CallableWithArguments::CallableWithArguments(FUNC_2D func, uint64_t thread_id, + int64_t start_x, int64_t stop_x, + int64_t increment_x, + int64_t start_y, int64_t stop_y, + int64_t increment_y) { + _function_2d = func; + _arguments = {start_x, stop_x, increment_x, start_y, stop_y, increment_y}; + _finished = false; + _threadId = thread_id; + _dimensions = 2; +} + +int CallableWithArguments::dimensions() { return _dimensions; } + +std::vector& CallableWithArguments::arguments() { return _arguments; } + +bool CallableWithArguments::finished() { return _finished.load(); } + +void CallableWithArguments::finish() { + std::lock_guard lock(_lock); + _finished = true; + _condition.notify_one(); +} + +void CallableWithArguments::waitUntilFinished() { + std::unique_lock lock(_lock); + _condition.wait(lock, [&] { return _finished.load(); }); +} + +FUNC_1D CallableWithArguments::function_1d() { return _function_1d; } + +FUNC_2D CallableWithArguments::function_2d() { return _function_2d; } + +FUNC_DO CallableWithArguments::function_do() { return _function_do; } + +FUNC_3D CallableWithArguments::function_3d() { return _function_3d; } + +uint64_t CallableWithArguments::threadId() { return _threadId; } + +uint64_t CallableWithArguments::numThreads() { return _numThreads; } +} // namespace samediff \ No newline at end of file diff --git a/libnd4j/include/execution/impl/ErrorReference.cpp b/libnd4j/include/execution/impl/ErrorReference.cpp index 7b3409aa1339..307ed874c602 100644 --- a/libnd4j/include/execution/impl/ErrorReference.cpp +++ b/libnd4j/include/execution/impl/ErrorReference.cpp @@ -18,29 +18,25 @@ // @author raver119@gmail.com // - #include namespace sd { - int ErrorReference::errorCode() { - return _errorCode; - } +int ErrorReference::errorCode() { return _errorCode; } - const char* ErrorReference::errorMessage() { - // since we're fetching error message - error code will be assumed consumed & nullified - _errorCode = 0; - return _errorMessage.c_str(); - } +const char* ErrorReference::errorMessage() { + // since we're fetching error message - error code will be assumed consumed & + // nullified + _errorCode = 0; + return _errorMessage.c_str(); +} - void ErrorReference::setErrorCode(int errorCode) { - _errorCode = errorCode; - } +void ErrorReference::setErrorCode(int errorCode) { _errorCode = errorCode; } - void ErrorReference::setErrorMessage(std::string message) { - _errorMessage = message; - } +void ErrorReference::setErrorMessage(std::string message) { + _errorMessage = message; +} - void ErrorReference::setErrorMessage(const char* message) { - _errorMessage = std::string(message); - } +void ErrorReference::setErrorMessage(const char* message) { + _errorMessage = std::string(message); } +} // namespace sd diff --git a/libnd4j/include/execution/impl/ThreadPool.cpp b/libnd4j/include/execution/impl/ThreadPool.cpp index b02c4c4d5f4c..e9cbe328b5c5 100644 --- a/libnd4j/include/execution/impl/ThreadPool.cpp +++ b/libnd4j/include/execution/impl/ThreadPool.cpp @@ -19,176 +19,177 @@ // #include -#include #include +#include + #if defined(_WIN32) || defined(_WIN64) //#include #endif namespace samediff { - // this function executed once per thread, it polls functions from queue, and executes them via wrapper - static void executionLoop_(int thread_id, BlockingQueue *queue) { - while (true) { - // this method blocks until there's something within queue - auto c = queue->poll(); - //nd4j_printf("ThreadPool: starting thread %i\n", c->threadId()); - switch (c->dimensions()) { - case 0: { - c->function_do()(c->threadId(), c->numThreads()); - c->finish(); - } - break; - case 1: { - auto args = c->arguments(); - c->function_1d()(c->threadId(), args[0], args[1], args[2]); - c->finish(); - } - break; - case 2: { - auto args = c->arguments(); - c->function_2d()(c->threadId(), args[0], args[1], args[2], args[3], args[4], args[5]); - c->finish(); - //nd4j_printf("ThreadPool: finished thread %i\n", c->threadId()); - } - break; - case 3: { - auto args = c->arguments(); - c->function_3d()(c->threadId(), args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8]); - c->finish(); - } - break; - default: - throw std::runtime_error("Don't know what to do with provided Callable"); - } - } +// this function executed once per thread, it polls functions from queue, and +// executes them via wrapper +static void executionLoop_(int thread_id, + BlockingQueue *queue) { + while (true) { + // this method blocks until there's something within queue + auto c = queue->poll(); + // nd4j_printf("ThreadPool: starting thread %i\n", c->threadId()); + switch (c->dimensions()) { + case 0: { + c->function_do()(c->threadId(), c->numThreads()); + c->finish(); + } break; + case 1: { + auto args = c->arguments(); + c->function_1d()(c->threadId(), args[0], args[1], args[2]); + c->finish(); + } break; + case 2: { + auto args = c->arguments(); + c->function_2d()(c->threadId(), args[0], args[1], args[2], args[3], + args[4], args[5]); + c->finish(); + // nd4j_printf("ThreadPool: finished thread %i\n", c->threadId()); + } break; + case 3: { + auto args = c->arguments(); + c->function_3d()(c->threadId(), args[0], args[1], args[2], args[3], + args[4], args[5], args[6], args[7], args[8]); + c->finish(); + } break; + default: + throw std::runtime_error( + "Don't know what to do with provided Callable"); } + } +} - static void executionLoopWithInterface_(int thread_id, CallableInterface *c) { - while (true) { - // blocking here until there's something to do - c->waitForTask(); +static void executionLoopWithInterface_(int thread_id, CallableInterface *c) { + while (true) { + // blocking here until there's something to do + c->waitForTask(); - // execute whatever we have - c->execute(); - } - } + // execute whatever we have + c->execute(); + } +} - ThreadPool::ThreadPool() { - // TODO: number of threads must reflect number of cores for UMA system. In case of NUMA it should be per-device pool - // FIXME: on mobile phones this feature must NOT be used - _available = sd::Environment::getInstance()->maxThreads(); - - _queues.resize(_available.load()); - _threads.resize(_available.load()); - _interfaces.resize(_available.load()); - - // creating threads here - for (int e = 0; e < _available.load(); e++) { - _queues[e] = new BlockingQueue(2); - _interfaces[e] = new CallableInterface(); - _threads[e] = new std::thread(executionLoopWithInterface_, e, _interfaces[e]); - _tickets.push(new Ticket()); - // _threads[e] = new std::thread(executionLoop_, e, _queues[e]); - - // TODO: add other platforms here as well - // now we must set affinity, and it's going to be platform-specific thing +ThreadPool::ThreadPool() { + // TODO: number of threads must reflect number of cores for UMA system. In + // case of NUMA it should be per-device pool + // FIXME: on mobile phones this feature must NOT be used + _available = sd::Environment::getInstance()->maxThreads(); + + _queues.resize(_available.load()); + _threads.resize(_available.load()); + _interfaces.resize(_available.load()); + + // creating threads here + for (int e = 0; e < _available.load(); e++) { + _queues[e] = new BlockingQueue(2); + _interfaces[e] = new CallableInterface(); + _threads[e] = + new std::thread(executionLoopWithInterface_, e, _interfaces[e]); + _tickets.push(new Ticket()); + // _threads[e] = new std::thread(executionLoop_, e, _queues[e]); + + // TODO: add other platforms here as well + // now we must set affinity, and it's going to be platform-specific thing #ifdef LINUX_BUILD - cpu_set_t cpuset; - CPU_ZERO(&cpuset); - CPU_SET(e, &cpuset); - int rc = pthread_setaffinity_np(_threads[e]->native_handle(), sizeof(cpu_set_t), &cpuset); - if (rc != 0) - throw std::runtime_error("Failed to set pthread affinity"); + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(e, &cpuset); + int rc = pthread_setaffinity_np(_threads[e]->native_handle(), + sizeof(cpu_set_t), &cpuset); + if (rc != 0) throw std::runtime_error("Failed to set pthread affinity"); #endif - /* + /* #if defined(_WIN32) || defined(_WIN64) - // we can't set affinity to more than 64 cores - if (e <= 64) { - auto mask = (static_cast(1) << e); - auto result = SetThreadAffinityMask(_threads[e]->native_handle(), mask); - if (!result) - throw std::runtime_error("Failed to set pthread affinity"); - } - - // that's fine. no need for time_critical here - SetThreadPriority(_threads[e]->native_handle(), THREAD_PRIORITY_HIGHEST); -#endif - */ - } + // we can't set affinity to more than 64 cores + if (e <= 64) { + auto mask = (static_cast(1) << e); + auto result = SetThreadAffinityMask(_threads[e]->native_handle(), mask); + if (!result) + throw std::runtime_error("Failed to set pthread affinity"); } - ThreadPool::~ThreadPool() { - // TODO: implement this one properly - for (int e = 0; e < _queues.size(); e++) { - // stop each and every thread - - // release queue and thread - //delete _queues[e]; - //delete _threads[e]; - } - } + // that's fine. no need for time_critical here + SetThreadPriority(_threads[e]->native_handle(), THREAD_PRIORITY_HIGHEST); +#endif + */ + } +} - static std::mutex _lmutex; +ThreadPool::~ThreadPool() { + // TODO: implement this one properly + for (int e = 0; e < _queues.size(); e++) { + // stop each and every thread - ThreadPool* ThreadPool::getInstance() { - std::unique_lock lock(_lmutex); - if (!_INSTANCE) - _INSTANCE = new ThreadPool(); + // release queue and thread + // delete _queues[e]; + // delete _threads[e]; + } +} - return _INSTANCE; - } +static std::mutex _lmutex; - void ThreadPool::release(int numThreads) { - _available += numThreads; - } +ThreadPool *ThreadPool::getInstance() { + std::unique_lock lock(_lmutex); + if (!_INSTANCE) _INSTANCE = new ThreadPool(); - Ticket* ThreadPool::tryAcquire(int numThreads) { - //std::vector*> queues; - - Ticket *t = nullptr; - // we check for threads availability first - bool threaded = false; - { - // we lock before checking availability - std::unique_lock lock(_lock); - if (_available >= numThreads) { - threaded = true; - _available -= numThreads; - - // getting a ticket from the queue - t = _tickets.front(); - _tickets.pop(); - - // ticket must contain information about number of threads for the current session - t->acquiredThreads(numThreads); - - // filling ticket with executable interfaces - for (int e = 0, i = 0; e < _queues.size() && i < numThreads; e++) { - if (_interfaces[e]->available()) { - t->attach(i++, _interfaces[e]); - _interfaces[e]->markUnavailable(); - } - } - } - } + return _INSTANCE; +} - // we either dispatch tasks to threads, or run single-threaded - if (threaded) { - return t; - } else { - // if there's no threads available - return nullptr - return nullptr; +void ThreadPool::release(int numThreads) { _available += numThreads; } + +Ticket *ThreadPool::tryAcquire(int numThreads) { + // std::vector*> queues; + + Ticket *t = nullptr; + // we check for threads availability first + bool threaded = false; + { + // we lock before checking availability + std::unique_lock lock(_lock); + if (_available >= numThreads) { + threaded = true; + _available -= numThreads; + + // getting a ticket from the queue + t = _tickets.front(); + _tickets.pop(); + + // ticket must contain information about number of threads for the current + // session + t->acquiredThreads(numThreads); + + // filling ticket with executable interfaces + for (int e = 0, i = 0; e < _queues.size() && i < numThreads; e++) { + if (_interfaces[e]->available()) { + t->attach(i++, _interfaces[e]); + _interfaces[e]->markUnavailable(); } + } } + } + + // we either dispatch tasks to threads, or run single-threaded + if (threaded) { + return t; + } else { + // if there's no threads available - return nullptr + return nullptr; + } +} - void ThreadPool::release(samediff::Ticket *ticket) { - // returning ticket back to the queue - std::unique_lock lock(_lock); - _tickets.push(ticket); - } - - - ThreadPool* ThreadPool::_INSTANCE = 0; +void ThreadPool::release(samediff::Ticket *ticket) { + // returning ticket back to the queue + std::unique_lock lock(_lock); + _tickets.push(ticket); } + +ThreadPool *ThreadPool::_INSTANCE = 0; +} // namespace samediff diff --git a/libnd4j/include/execution/impl/Threads.cpp b/libnd4j/include/execution/impl/Threads.cpp index 51339abf1261..f62f5a0a35df 100644 --- a/libnd4j/include/execution/impl/Threads.cpp +++ b/libnd4j/include/execution/impl/Threads.cpp @@ -17,707 +17,685 @@ // // @author raver119@gmail.com // -#include #include -#include -#include +#include #include -#include #include +#include +#include +#include namespace samediff { - int ThreadsHelper::numberOfThreads(int maxThreads, uint64_t numberOfElements) { - // let's see how many threads we actually need first - auto optimalThreads = sd::math::nd4j_max(1, numberOfElements / 1024); - - // now return the smallest value - return sd::math::nd4j_min(optimalThreads, maxThreads); - } - - Span3::Span3(int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ) { - _startX = startX; - _startY = startY; - _startZ = startZ; - _stopX = stopX; - _stopY = stopY; - _stopZ = stopZ; - _incX = incX; - _incY = incY; - _incZ = incZ; - } - - Span3 Span3::build(int loop, uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ) { - switch (loop) { - case 1: { - auto span = (stopX - startX) / numThreads; - auto s = span * threadID; - auto e = s + span; - if (threadID == numThreads - 1) - e = stopX; - - return Span3(s, e, incX, startY, stopY, incY, startZ, stopZ, incZ); - } - break; - case 2: { - auto span = (stopY - startY) / numThreads; - auto s = span * threadID; - auto e = s + span; - if (threadID == numThreads - 1) - e = stopY; - - return Span3(startX, stopX, incX, s, e, incY, startZ, stopZ, incZ); - } - break; - case 3: { - auto span = (stopZ - startZ) / numThreads; - auto s = span * threadID; - auto e = s + span; - if (threadID == numThreads - 1) - e = stopZ; - - return Span3(startX, stopX, incX, startY, stopY, incY, s, e, incZ); - } - break; - default: - throw std::runtime_error(""); - } - return Span3(startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); - } - - Span::Span(int64_t startX, int64_t stopX, int64_t incX) { - _startX = startX; - _stopX = stopX; - _incX = incX; - } - - Span Span::build(uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX) { - auto span = (stopX - startX) / numThreads; - auto s = span * threadID; - auto e = s + span; - if (threadID == numThreads - 1) - e = stopX; - - return Span(s, e, incX); - } - - Span2::Span2(int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY) { - _startX = startX; - _startY = startY; - _stopX = stopX; - _stopY = stopY; - _incX = incX; - _incY = incY; - } - - - Span2 Span2::build(int loop, uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY) { - - switch (loop) { - case 1: { - auto span = (stopX - startX) / numThreads; - auto s = span * threadID; - auto e = s + span; - if (threadID == numThreads - 1) - e = stopX; - - return Span2(s, e, incX, startY, stopY, incY); - } - break; - case 2: { - auto span = (stopY - startY) / numThreads; - auto s = span * threadID; - auto e = s + span; - if (threadID == numThreads - 1) - e = stopY; - - return Span2(startX, stopX, incX, s, e, incY); - } - break; - default: - throw std::runtime_error(""); - } - } - - int64_t Span::startX() const { - return _startX; - } - - int64_t Span::stopX() const { - return _stopX; - } - - int64_t Span::incX() const { - return _incX; - } - - int64_t Span2::startX() const { - return _startX; - } - - int64_t Span2::startY() const { - return _startY; - } - - int64_t Span2::stopX() const { - return _stopX; - } - - int64_t Span2::stopY() const { - return _stopY; - } - - int64_t Span2::incX() const { - return _incX; - } - - int64_t Span2::incY() const { - return _incY; - } - - int64_t Span3::startX() const { - return _startX; - } - - int64_t Span3::startY() const { - return _startY; - } - - int64_t Span3::startZ() const { - return _startZ; - } - - int64_t Span3::stopX() const { - return _stopX; - } - - int64_t Span3::stopY() const { - return _stopY; - } - - int64_t Span3::stopZ() const { - return _stopZ; - } - - int64_t Span3::incX() const { - return _incX; - } - - int64_t Span3::incY() const { - return _incY; - } - - int64_t Span3::incZ() const { - return _incZ; - } - - int ThreadsHelper::pickLoop2d(int numThreads, uint64_t itersX, uint64_t itersY) { - // if one of dimensions is definitely too small - we just pick the other one - if (itersX < numThreads && itersY >= numThreads) - return 2; - if (itersY < numThreads && itersX >= numThreads) - return 1; - - // next step - we pick the most balanced dimension - auto remX = itersX % numThreads; - auto remY = itersY % numThreads; - auto splitY = itersY / numThreads; - - // if there's no remainder left in some dimension - we're picking that dimension, because it'll be the most balanced work distribution - if (remX == 0) - return 1; - if (remY == 0) - return 2; - - // if there's no loop without a remainder - we're picking one with smaller remainder - if (remX < remY) - return 1; - if (remY < remX && splitY >= 64) // we don't want too small splits over last dimension, or vectorization will fail - return 2; - // if loops are equally sized - give the preference to the first thread - return 1; - } - - - static int threads_(int maxThreads, uint64_t elements) { - - if (elements == maxThreads) { - return maxThreads; - } - else if (elements > maxThreads) { - // if we have full load across thread, or at least half of threads can be utilized - auto rem = elements % maxThreads; - if (rem == 0 || rem >= maxThreads / 3) - return maxThreads; - else - return threads_(maxThreads - 1, elements); - - } - else if (elements < maxThreads) { - return elements; - } - - return 1; - } - - int ThreadsHelper::numberOfThreads2d(int maxThreads, uint64_t iters_x, uint64_t iters_y) { - // in some cases there's nothing to think about, part 1 - if (iters_x < maxThreads && iters_y < maxThreads) - return sd::math::nd4j_max(iters_x, iters_y); - - auto remX = iters_x % maxThreads; - auto remY = iters_y % maxThreads; - - // in some cases there's nothing to think about, part 2 - if ((iters_x >= maxThreads && remX == 0 )|| (iters_y >= maxThreads && remY == 0)) - return maxThreads; - - // at this point we suppose that there's no loop perfectly matches number of our threads - // so let's pick something as equal as possible - if (iters_x > maxThreads || iters_y > maxThreads) - return maxThreads; - else - return numberOfThreads2d(maxThreads - 1, iters_x, iters_y); - } - - int ThreadsHelper::numberOfThreads3d(int maxThreads, uint64_t itersX, uint64_t itersY, uint64_t itersZ) { - // we don't want to run underloaded threads - if (itersX * itersY * itersZ <= 32) - return 1; - - auto remX = itersX % maxThreads; - auto remY = itersY % maxThreads; - auto remZ = itersZ % maxThreads; - - // if we have perfect balance across one of dimensions - just go for it - if ((itersX >= maxThreads && remX == 0) || (itersY >= maxThreads && remY == 0) || (itersZ >= maxThreads && remZ == 0)) - return maxThreads; - - int threadsX = 0, threadsY = 0, threadsZ = 0; - - // now we look into possible number of - threadsX = threads_(maxThreads, itersX); - threadsY = threads_(maxThreads, itersY); - threadsZ = threads_(maxThreads, itersZ); - - // we want to split as close to outer loop as possible, so checking it out first - if (threadsX >= threadsY && threadsX >= threadsZ) - return threadsX; - else if (threadsY >= threadsX && threadsY >= threadsZ) - return threadsY; - else if (threadsZ >= threadsX && threadsZ >= threadsY) - return threadsZ; - - return 1; - } - - int ThreadsHelper::pickLoop3d(int numThreads, uint64_t itersX, uint64_t itersY, uint64_t itersZ) { - auto remX = itersX % numThreads; - auto remY = itersY % numThreads; - auto remZ = itersZ % numThreads; - - auto splitX = itersX / numThreads; - auto splitY = itersY / numThreads; - auto splitZ = itersZ / numThreads; - - // if there's no remainder left in some dimension - we're picking that dimension, because it'll be the most balanced work distribution - if (remX == 0) - return 1; - else if (remY == 0) - return 2; - else if (remZ == 0) // TODO: we don't want too smal splits over last dimension? or we do? - return 3; - - if (itersX > numThreads) - return 1; - else if (itersY > numThreads) - return 2; - else if (itersZ > numThreads) - return 3; - - return 1; - } - - int Threads::parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) { - if (start > stop) - throw std::runtime_error("Threads::parallel_for got start > stop"); - - auto delta = (stop - start); - - if (numThreads > delta) - numThreads = delta; - - if (numThreads == 0) - return 0; - - // shortcut - if (numThreads == 1) { - function(0, start, stop, increment); - return 1; - } - - auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads); - if (ticket != nullptr) { - // if we got our threads - we'll run our jobs here - auto span = delta / numThreads; - - for (uint32_t e = 0; e < numThreads; e++) { - auto start_ = span * e + start; - auto stop_ = start_ + span; - - // last thread will process tail - if (e == numThreads - 1) - stop_ = stop; - - // putting the task into the queue for a given thread - ticket->enqueue(e, numThreads, function, start_, stop_, increment); - } - - // block and wait till all threads finished the job - ticket->waitAndRelease(); - - // we tell that parallelism request succeeded - return numThreads; - } else { - // if there were no threads available - we'll execute function right within current thread - function(0, start, stop, increment); - - // we tell that parallelism request declined - return 1; - } - } - - int Threads::parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) { - if (start > stop) - throw std::runtime_error("Threads::parallel_for got start > stop"); - - auto delta = (stop - start); - - // in some cases we just fire func as is - if (delta == 0 || numThreads == 1) { - function(0, start, stop, increment); - return 1; - } - - auto numElements = delta / increment; - - // we decide what's optimal number of threads we need here, and execute it in parallel_tad. - numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements); - return parallel_tad(function, start, stop, increment, numThreads); - } - - int Threads::parallel_for(FUNC_2D function, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, uint64_t numThreads, bool debug) { - if (startX > stopX) - throw std::runtime_error("Threads::parallel_for got startX > stopX"); - - if (startY > stopY) - throw std::runtime_error("Threads::parallel_for got startY > stopY"); - - // number of elements per loop - auto delta_x = (stopX - startX); - auto delta_y = (stopY - startY); - - // number of iterations per loop - auto itersX = delta_x / incX; - auto itersY = delta_y / incY; - - // total number of iterations - auto iters_t = itersX * itersY; - - // we are checking the case of number of requested threads was smaller - numThreads = ThreadsHelper::numberOfThreads2d(numThreads, itersX, itersY); - - // basic shortcut for no-threading cases - if (numThreads == 1) { - function(0, startX, stopX, incX, startY, stopY, incY); - return 1; - } - - // We have couple of scenarios: - // either we split workload along 1st loop, or 2nd - auto splitLoop = ThreadsHelper::pickLoop2d(numThreads, itersX, itersY); - - // for debug mode we execute things inplace, without any threads - if (debug) { - for (int e = 0; e < numThreads; e++) { - auto span = Span2::build(splitLoop, e, numThreads, startX, stopX, incX, startY, stopY, incY); - - function(e, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY()); - } - - // but we still mimic multithreaded execution - return numThreads; - } else { - auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads); - if (ticket != nullptr) { - - for (int e = 0; e < numThreads; e++) { - auto threadId = numThreads - e - 1; - auto span = Span2::build(splitLoop, threadId, numThreads, startX, stopX, incX, startY, stopY, incY); - - ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY()); - } - - // block until all threads finish their job - ticket->waitAndRelease(); - - return numThreads; - } else { - // if there were no threads available - we'll execute function right within current thread - function(0, startX, stopX, incX, startY, stopY, incY); - - // we tell that parallelism request declined - return 1; - } - }; - } - - - int Threads::parallel_for(FUNC_3D function, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ, uint64_t numThreads) { - if (startX > stopX) - throw std::runtime_error("Threads::parallel_for got startX > stopX"); - - if (startY > stopY) - throw std::runtime_error("Threads::parallel_for got startY > stopY"); - - if (startZ > stopZ) - throw std::runtime_error("Threads::parallel_for got startZ > stopZ"); - - auto delta_x = stopX - startX; - auto delta_y = stopY - startY; - auto delta_z = stopZ - startZ; - - auto itersX = delta_x / incX; - auto itersY = delta_y / incY; - auto itersZ = delta_z / incZ; - - numThreads = ThreadsHelper::numberOfThreads3d(numThreads, itersX, itersY, itersZ); - if (numThreads == 1) { - // loop is too small - executing function as is - function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); - return 1; - } - - auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads); - if (ticket != nullptr) { - auto splitLoop = ThreadsHelper::pickLoop3d(numThreads, itersX, itersY, itersZ); - - for (int e = 0; e < numThreads; e++) { - auto thread_id = numThreads - e - 1; - auto span = Span3::build(splitLoop, thread_id, numThreads, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); - - ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY(), span.startZ(), span.stopZ(), span.incZ()); - } - - // block until we're done - ticket->waitAndRelease(); - - // we tell that parallelism request succeeded - return numThreads; - } else { - // if there were no threads available - we'll execute function right within current thread - function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); - - // we tell that parallelism request declined - return 1; - } - - } - - int Threads::parallel_do(FUNC_DO function, uint64_t numThreads) { - auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1); - if (ticket != nullptr) { - - // submit tasks one by one - for (uint64_t e = 0; e < numThreads - 1; e++) - ticket->enqueue(e, numThreads, function); - - function(numThreads - 1, numThreads); - - ticket->waitAndRelease(); - - return numThreads; - } else { - // if there's no threads available - we'll execute function sequentially one by one - for (uint64_t e = 0; e < numThreads; e++) - function(e, numThreads); - - return numThreads; - } - - - return numThreads; - } - - int64_t Threads::parallel_long(FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, int64_t increment, uint64_t numThreads) { - if (start > stop) - throw std::runtime_error("Threads::parallel_long got start > stop"); - - auto delta = (stop - start); - if (delta == 0 || numThreads == 1) - return function(0, start, stop, increment); - - auto numElements = delta / increment; - - // we decide what's optimal number of threads we need here, and execute it - numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements); - if (numThreads == 1) - return function(0, start, stop, increment); - - auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1); - if (ticket == nullptr) - return function(0, start, stop, increment); - - // create temporary array - int64_t intermediatery[256]; - auto span = (numElements / numThreads) - (numElements % numThreads); - - // execute threads in parallel - for (uint32_t e = 0; e < numThreads; e++) { - auto start_ = span * e + start; - auto stop_ = span * (e + 1) + start; - - if (e == numThreads - 1) - intermediatery[e] = function(e, start_, stop, increment); - else - ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment); - } - - ticket->waitAndRelease(); - - // aggregate results in single thread - for (uint64_t e = 1; e < numThreads; e++) - intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]); - - // return accumulated result - return intermediatery[0]; - } - - double Threads::parallel_double(FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, int64_t increment, uint64_t numThreads) { - if (start > stop) - throw std::runtime_error("Threads::parallel_long got start > stop"); - - auto delta = (stop - start); - if (delta == 0 || numThreads == 1) - return function(0, start, stop, increment); - - auto numElements = delta / increment; - - // we decide what's optimal number of threads we need here, and execute it - numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements); - if (numThreads == 1) - return function(0, start, stop, increment); - - auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1); - if (ticket == nullptr) - return function(0, start, stop, increment); - - // create temporary array - double intermediatery[256]; - auto span = (numElements / numThreads) - (numElements % numThreads); - - // execute threads in parallel - for (uint32_t e = 0; e < numThreads; e++) { - auto start_ = span * e + start; - auto stop_ = span * (e + 1) + start; - - if (e == numThreads - 1) - intermediatery[e] = function(e, start_, stop, increment); - else - ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment); - } - - ticket->waitAndRelease(); - - // aggregate results in single thread - for (uint64_t e = 1; e < numThreads; e++) - intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]); - - // return accumulated result - return intermediatery[0]; - } - - - int Threads::parallel_aligned_increment(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, size_t type_size , uint32_t req_numThreads) { - if (start > stop) - throw std::runtime_error("Threads::parallel_for got start > stop"); - auto num_elements = (stop - start); - //this way we preserve increment starts offset - //so we will parition considering delta but not total elements - auto delta = (stop - start) / increment; - - // in some cases we just fire func as is - if (delta == 0 || req_numThreads == 1) { - function(0, start, stop, increment); - return 1; - } - int numThreads = 0; - - int adjusted_numThreads = samediff::ThreadsHelper::numberOfThreads(req_numThreads, (num_elements * sizeof(double)) / (200 * type_size)); - - if (adjusted_numThreads > delta) - adjusted_numThreads = delta; - // shortcut - if (adjusted_numThreads <= 1) { - function(0, start, stop, increment); - return 1; - } - //take span as ceil - auto spand = std::ceil((double)delta / (double)adjusted_numThreads); - numThreads = static_cast(std::ceil((double)delta / spand)); - auto span = static_cast(spand); - - auto ticket = samediff::ThreadPool::getInstance()->tryAcquire(numThreads); - if (ticket != nullptr) { - //tail_add is additional value of the last part - //it could be negative or positive - //we will spread that value across - auto tail_add = delta - numThreads * span; - Nd4jLong begin = 0; - Nd4jLong end = 0; - - //we will try enqueu bigger parts first - decltype(span) span1, span2; - int last = 0; - if (tail_add >= 0) { - //for span == 1 , tail_add is 0 - last = tail_add; - span1 = span + 1; - span2 = span; - } - else { - last = numThreads + tail_add;// -std::abs(tail_add); - span1 = span; - span2 = span - 1; - } - for (int i = 0; i < last; i++) { - end = begin + span1 * increment; - // putting the task into the queue for a given thread - ticket->enqueue(i, numThreads, function, begin, end, increment); - begin = end; - } - for (int i = last; i < numThreads - 1; i++) { - end = begin + span2 * increment; - // putting the task into the queue for a given thread - ticket->enqueue(i, numThreads, function, begin, end, increment); - begin = end; - } - //for last one enqueue last offset as stop - //we need it in case our ((stop-start) % increment ) > 0 - ticket->enqueue(numThreads - 1, numThreads, function, begin, stop, increment); - // block and wait till all threads finished the job - ticket->waitAndRelease(); - // we tell that parallelism request succeeded - return numThreads; - } - else { - // if there were no threads available - we'll execute function right within current thread - function(0, start, stop, increment); - // we tell that parallelism request declined - return 1; - } - } - +int ThreadsHelper::numberOfThreads(int maxThreads, uint64_t numberOfElements) { + // let's see how many threads we actually need first + auto optimalThreads = + sd::math::nd4j_max(1, numberOfElements / 1024); + + // now return the smallest value + return sd::math::nd4j_min(optimalThreads, maxThreads); +} + +Span3::Span3(int64_t startX, int64_t stopX, int64_t incX, int64_t startY, + int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, + int64_t incZ) { + _startX = startX; + _startY = startY; + _startZ = startZ; + _stopX = stopX; + _stopY = stopY; + _stopZ = stopZ; + _incX = incX; + _incY = incY; + _incZ = incZ; +} + +Span3 Span3::build(int loop, uint64_t threadID, uint64_t numThreads, + int64_t startX, int64_t stopX, int64_t incX, int64_t startY, + int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, + int64_t incZ) { + switch (loop) { + case 1: { + auto span = (stopX - startX) / numThreads; + auto s = span * threadID; + auto e = s + span; + if (threadID == numThreads - 1) e = stopX; + + return Span3(s, e, incX, startY, stopY, incY, startZ, stopZ, incZ); + } break; + case 2: { + auto span = (stopY - startY) / numThreads; + auto s = span * threadID; + auto e = s + span; + if (threadID == numThreads - 1) e = stopY; + + return Span3(startX, stopX, incX, s, e, incY, startZ, stopZ, incZ); + } break; + case 3: { + auto span = (stopZ - startZ) / numThreads; + auto s = span * threadID; + auto e = s + span; + if (threadID == numThreads - 1) e = stopZ; + + return Span3(startX, stopX, incX, startY, stopY, incY, s, e, incZ); + } break; + default: + throw std::runtime_error(""); + } + return Span3(startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); +} + +Span::Span(int64_t startX, int64_t stopX, int64_t incX) { + _startX = startX; + _stopX = stopX; + _incX = incX; +} + +Span Span::build(uint64_t threadID, uint64_t numThreads, int64_t startX, + int64_t stopX, int64_t incX) { + auto span = (stopX - startX) / numThreads; + auto s = span * threadID; + auto e = s + span; + if (threadID == numThreads - 1) e = stopX; + + return Span(s, e, incX); +} + +Span2::Span2(int64_t startX, int64_t stopX, int64_t incX, int64_t startY, + int64_t stopY, int64_t incY) { + _startX = startX; + _startY = startY; + _stopX = stopX; + _stopY = stopY; + _incX = incX; + _incY = incY; +} + +Span2 Span2::build(int loop, uint64_t threadID, uint64_t numThreads, + int64_t startX, int64_t stopX, int64_t incX, int64_t startY, + int64_t stopY, int64_t incY) { + switch (loop) { + case 1: { + auto span = (stopX - startX) / numThreads; + auto s = span * threadID; + auto e = s + span; + if (threadID == numThreads - 1) e = stopX; + + return Span2(s, e, incX, startY, stopY, incY); + } break; + case 2: { + auto span = (stopY - startY) / numThreads; + auto s = span * threadID; + auto e = s + span; + if (threadID == numThreads - 1) e = stopY; + + return Span2(startX, stopX, incX, s, e, incY); + } break; + default: + throw std::runtime_error(""); + } +} + +int64_t Span::startX() const { return _startX; } + +int64_t Span::stopX() const { return _stopX; } + +int64_t Span::incX() const { return _incX; } + +int64_t Span2::startX() const { return _startX; } + +int64_t Span2::startY() const { return _startY; } + +int64_t Span2::stopX() const { return _stopX; } + +int64_t Span2::stopY() const { return _stopY; } + +int64_t Span2::incX() const { return _incX; } + +int64_t Span2::incY() const { return _incY; } + +int64_t Span3::startX() const { return _startX; } + +int64_t Span3::startY() const { return _startY; } + +int64_t Span3::startZ() const { return _startZ; } + +int64_t Span3::stopX() const { return _stopX; } + +int64_t Span3::stopY() const { return _stopY; } + +int64_t Span3::stopZ() const { return _stopZ; } + +int64_t Span3::incX() const { return _incX; } + +int64_t Span3::incY() const { return _incY; } + +int64_t Span3::incZ() const { return _incZ; } + +int ThreadsHelper::pickLoop2d(int numThreads, uint64_t itersX, + uint64_t itersY) { + // if one of dimensions is definitely too small - we just pick the other one + if (itersX < numThreads && itersY >= numThreads) return 2; + if (itersY < numThreads && itersX >= numThreads) return 1; + + // next step - we pick the most balanced dimension + auto remX = itersX % numThreads; + auto remY = itersY % numThreads; + auto splitY = itersY / numThreads; + + // if there's no remainder left in some dimension - we're picking that + // dimension, because it'll be the most balanced work distribution + if (remX == 0) return 1; + if (remY == 0) return 2; + + // if there's no loop without a remainder - we're picking one with smaller + // remainder + if (remX < remY) return 1; + if (remY < remX && splitY >= 64) // we don't want too small splits over last + // dimension, or vectorization will fail + return 2; + // if loops are equally sized - give the preference to the first thread + return 1; +} + +static int threads_(int maxThreads, uint64_t elements) { + if (elements == maxThreads) { + return maxThreads; + } else if (elements > maxThreads) { + // if we have full load across thread, or at least half of threads can be + // utilized + auto rem = elements % maxThreads; + if (rem == 0 || rem >= maxThreads / 3) + return maxThreads; + else + return threads_(maxThreads - 1, elements); + + } else if (elements < maxThreads) { + return elements; + } + + return 1; +} + +int ThreadsHelper::numberOfThreads2d(int maxThreads, uint64_t iters_x, + uint64_t iters_y) { + // in some cases there's nothing to think about, part 1 + if (iters_x < maxThreads && iters_y < maxThreads) + return sd::math::nd4j_max(iters_x, iters_y); + + auto remX = iters_x % maxThreads; + auto remY = iters_y % maxThreads; + + // in some cases there's nothing to think about, part 2 + if ((iters_x >= maxThreads && remX == 0) || + (iters_y >= maxThreads && remY == 0)) + return maxThreads; + + // at this point we suppose that there's no loop perfectly matches number of + // our threads so let's pick something as equal as possible + if (iters_x > maxThreads || iters_y > maxThreads) + return maxThreads; + else + return numberOfThreads2d(maxThreads - 1, iters_x, iters_y); +} -} \ No newline at end of file +int ThreadsHelper::numberOfThreads3d(int maxThreads, uint64_t itersX, + uint64_t itersY, uint64_t itersZ) { + // we don't want to run underloaded threads + if (itersX * itersY * itersZ <= 32) return 1; + + auto remX = itersX % maxThreads; + auto remY = itersY % maxThreads; + auto remZ = itersZ % maxThreads; + + // if we have perfect balance across one of dimensions - just go for it + if ((itersX >= maxThreads && remX == 0) || + (itersY >= maxThreads && remY == 0) || + (itersZ >= maxThreads && remZ == 0)) + return maxThreads; + + int threadsX = 0, threadsY = 0, threadsZ = 0; + + // now we look into possible number of + threadsX = threads_(maxThreads, itersX); + threadsY = threads_(maxThreads, itersY); + threadsZ = threads_(maxThreads, itersZ); + + // we want to split as close to outer loop as possible, so checking it out + // first + if (threadsX >= threadsY && threadsX >= threadsZ) + return threadsX; + else if (threadsY >= threadsX && threadsY >= threadsZ) + return threadsY; + else if (threadsZ >= threadsX && threadsZ >= threadsY) + return threadsZ; + + return 1; +} + +int ThreadsHelper::pickLoop3d(int numThreads, uint64_t itersX, uint64_t itersY, + uint64_t itersZ) { + auto remX = itersX % numThreads; + auto remY = itersY % numThreads; + auto remZ = itersZ % numThreads; + + auto splitX = itersX / numThreads; + auto splitY = itersY / numThreads; + auto splitZ = itersZ / numThreads; + + // if there's no remainder left in some dimension - we're picking that + // dimension, because it'll be the most balanced work distribution + if (remX == 0) + return 1; + else if (remY == 0) + return 2; + else if (remZ == 0) // TODO: we don't want too smal splits over last + // dimension? or we do? + return 3; + + if (itersX > numThreads) + return 1; + else if (itersY > numThreads) + return 2; + else if (itersZ > numThreads) + return 3; + + return 1; +} + +int Threads::parallel_tad(FUNC_1D function, int64_t start, int64_t stop, + int64_t increment, uint32_t numThreads) { + if (start > stop) + throw std::runtime_error("Threads::parallel_for got start > stop"); + + auto delta = (stop - start); + + if (numThreads > delta) numThreads = delta; + + if (numThreads == 0) return 0; + + // shortcut + if (numThreads == 1) { + function(0, start, stop, increment); + return 1; + } + + auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads); + if (ticket != nullptr) { + // if we got our threads - we'll run our jobs here + auto span = delta / numThreads; + + for (uint32_t e = 0; e < numThreads; e++) { + auto start_ = span * e + start; + auto stop_ = start_ + span; + + // last thread will process tail + if (e == numThreads - 1) stop_ = stop; + + // putting the task into the queue for a given thread + ticket->enqueue(e, numThreads, function, start_, stop_, increment); + } + + // block and wait till all threads finished the job + ticket->waitAndRelease(); + + // we tell that parallelism request succeeded + return numThreads; + } else { + // if there were no threads available - we'll execute function right within + // current thread + function(0, start, stop, increment); + + // we tell that parallelism request declined + return 1; + } +} + +int Threads::parallel_for(FUNC_1D function, int64_t start, int64_t stop, + int64_t increment, uint32_t numThreads) { + if (start > stop) + throw std::runtime_error("Threads::parallel_for got start > stop"); + + auto delta = (stop - start); + + // in some cases we just fire func as is + if (delta == 0 || numThreads == 1) { + function(0, start, stop, increment); + return 1; + } + + auto numElements = delta / increment; + + // we decide what's optimal number of threads we need here, and execute it in + // parallel_tad. + numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements); + return parallel_tad(function, start, stop, increment, numThreads); +} + +int Threads::parallel_for(FUNC_2D function, int64_t startX, int64_t stopX, + int64_t incX, int64_t startY, int64_t stopY, + int64_t incY, uint64_t numThreads, bool debug) { + if (startX > stopX) + throw std::runtime_error("Threads::parallel_for got startX > stopX"); + + if (startY > stopY) + throw std::runtime_error("Threads::parallel_for got startY > stopY"); + + // number of elements per loop + auto delta_x = (stopX - startX); + auto delta_y = (stopY - startY); + + // number of iterations per loop + auto itersX = delta_x / incX; + auto itersY = delta_y / incY; + + // total number of iterations + auto iters_t = itersX * itersY; + + // we are checking the case of number of requested threads was smaller + numThreads = ThreadsHelper::numberOfThreads2d(numThreads, itersX, itersY); + + // basic shortcut for no-threading cases + if (numThreads == 1) { + function(0, startX, stopX, incX, startY, stopY, incY); + return 1; + } + + // We have couple of scenarios: + // either we split workload along 1st loop, or 2nd + auto splitLoop = ThreadsHelper::pickLoop2d(numThreads, itersX, itersY); + + // for debug mode we execute things inplace, without any threads + if (debug) { + for (int e = 0; e < numThreads; e++) { + auto span = Span2::build(splitLoop, e, numThreads, startX, stopX, incX, + startY, stopY, incY); + + function(e, span.startX(), span.stopX(), span.incX(), span.startY(), + span.stopY(), span.incY()); + } + + // but we still mimic multithreaded execution + return numThreads; + } else { + auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads); + if (ticket != nullptr) { + for (int e = 0; e < numThreads; e++) { + auto threadId = numThreads - e - 1; + auto span = Span2::build(splitLoop, threadId, numThreads, startX, stopX, + incX, startY, stopY, incY); + + ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), + span.incX(), span.startY(), span.stopY(), span.incY()); + } + + // block until all threads finish their job + ticket->waitAndRelease(); + + return numThreads; + } else { + // if there were no threads available - we'll execute function right + // within current thread + function(0, startX, stopX, incX, startY, stopY, incY); + + // we tell that parallelism request declined + return 1; + } + }; +} + +int Threads::parallel_for(FUNC_3D function, int64_t startX, int64_t stopX, + int64_t incX, int64_t startY, int64_t stopY, + int64_t incY, int64_t startZ, int64_t stopZ, + int64_t incZ, uint64_t numThreads) { + if (startX > stopX) + throw std::runtime_error("Threads::parallel_for got startX > stopX"); + + if (startY > stopY) + throw std::runtime_error("Threads::parallel_for got startY > stopY"); + + if (startZ > stopZ) + throw std::runtime_error("Threads::parallel_for got startZ > stopZ"); + + auto delta_x = stopX - startX; + auto delta_y = stopY - startY; + auto delta_z = stopZ - startZ; + + auto itersX = delta_x / incX; + auto itersY = delta_y / incY; + auto itersZ = delta_z / incZ; + + numThreads = + ThreadsHelper::numberOfThreads3d(numThreads, itersX, itersY, itersZ); + if (numThreads == 1) { + // loop is too small - executing function as is + function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); + return 1; + } + + auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads); + if (ticket != nullptr) { + auto splitLoop = + ThreadsHelper::pickLoop3d(numThreads, itersX, itersY, itersZ); + + for (int e = 0; e < numThreads; e++) { + auto thread_id = numThreads - e - 1; + auto span = Span3::build(splitLoop, thread_id, numThreads, startX, stopX, + incX, startY, stopY, incY, startZ, stopZ, incZ); + + ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), + span.incX(), span.startY(), span.stopY(), span.incY(), + span.startZ(), span.stopZ(), span.incZ()); + } + + // block until we're done + ticket->waitAndRelease(); + + // we tell that parallelism request succeeded + return numThreads; + } else { + // if there were no threads available - we'll execute function right within + // current thread + function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); + + // we tell that parallelism request declined + return 1; + } +} + +int Threads::parallel_do(FUNC_DO function, uint64_t numThreads) { + auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1); + if (ticket != nullptr) { + // submit tasks one by one + for (uint64_t e = 0; e < numThreads - 1; e++) + ticket->enqueue(e, numThreads, function); + + function(numThreads - 1, numThreads); + + ticket->waitAndRelease(); + + return numThreads; + } else { + // if there's no threads available - we'll execute function sequentially one + // by one + for (uint64_t e = 0; e < numThreads; e++) function(e, numThreads); + + return numThreads; + } + + return numThreads; +} + +int64_t Threads::parallel_long(FUNC_RL function, FUNC_AL aggregator, + int64_t start, int64_t stop, int64_t increment, + uint64_t numThreads) { + if (start > stop) + throw std::runtime_error("Threads::parallel_long got start > stop"); + + auto delta = (stop - start); + if (delta == 0 || numThreads == 1) return function(0, start, stop, increment); + + auto numElements = delta / increment; + + // we decide what's optimal number of threads we need here, and execute it + numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements); + if (numThreads == 1) return function(0, start, stop, increment); + + auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1); + if (ticket == nullptr) return function(0, start, stop, increment); + + // create temporary array + int64_t intermediatery[256]; + auto span = (numElements / numThreads) - (numElements % numThreads); + + // execute threads in parallel + for (uint32_t e = 0; e < numThreads; e++) { + auto start_ = span * e + start; + auto stop_ = span * (e + 1) + start; + + if (e == numThreads - 1) + intermediatery[e] = function(e, start_, stop, increment); + else + ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, + stop_, increment); + } + + ticket->waitAndRelease(); + + // aggregate results in single thread + for (uint64_t e = 1; e < numThreads; e++) + intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]); + + // return accumulated result + return intermediatery[0]; +} + +double Threads::parallel_double(FUNC_RD function, FUNC_AD aggregator, + int64_t start, int64_t stop, int64_t increment, + uint64_t numThreads) { + if (start > stop) + throw std::runtime_error("Threads::parallel_long got start > stop"); + + auto delta = (stop - start); + if (delta == 0 || numThreads == 1) return function(0, start, stop, increment); + + auto numElements = delta / increment; + + // we decide what's optimal number of threads we need here, and execute it + numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements); + if (numThreads == 1) return function(0, start, stop, increment); + + auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1); + if (ticket == nullptr) return function(0, start, stop, increment); + + // create temporary array + double intermediatery[256]; + auto span = (numElements / numThreads) - (numElements % numThreads); + + // execute threads in parallel + for (uint32_t e = 0; e < numThreads; e++) { + auto start_ = span * e + start; + auto stop_ = span * (e + 1) + start; + + if (e == numThreads - 1) + intermediatery[e] = function(e, start_, stop, increment); + else + ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, + stop_, increment); + } + + ticket->waitAndRelease(); + + // aggregate results in single thread + for (uint64_t e = 1; e < numThreads; e++) + intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]); + + // return accumulated result + return intermediatery[0]; +} + +int Threads::parallel_aligned_increment(FUNC_1D function, int64_t start, + int64_t stop, int64_t increment, + size_t type_size, + uint32_t req_numThreads) { + if (start > stop) + throw std::runtime_error("Threads::parallel_for got start > stop"); + auto num_elements = (stop - start); + // this way we preserve increment starts offset + // so we will parition considering delta but not total elements + auto delta = (stop - start) / increment; + + // in some cases we just fire func as is + if (delta == 0 || req_numThreads == 1) { + function(0, start, stop, increment); + return 1; + } + int numThreads = 0; + + int adjusted_numThreads = samediff::ThreadsHelper::numberOfThreads( + req_numThreads, (num_elements * sizeof(double)) / (200 * type_size)); + + if (adjusted_numThreads > delta) adjusted_numThreads = delta; + // shortcut + if (adjusted_numThreads <= 1) { + function(0, start, stop, increment); + return 1; + } + // take span as ceil + auto spand = std::ceil((double)delta / (double)adjusted_numThreads); + numThreads = static_cast(std::ceil((double)delta / spand)); + auto span = static_cast(spand); + + auto ticket = samediff::ThreadPool::getInstance()->tryAcquire(numThreads); + if (ticket != nullptr) { + // tail_add is additional value of the last part + // it could be negative or positive + // we will spread that value across + auto tail_add = delta - numThreads * span; + Nd4jLong begin = 0; + Nd4jLong end = 0; + + // we will try enqueu bigger parts first + decltype(span) span1, span2; + int last = 0; + if (tail_add >= 0) { + // for span == 1 , tail_add is 0 + last = tail_add; + span1 = span + 1; + span2 = span; + } else { + last = numThreads + tail_add; // -std::abs(tail_add); + span1 = span; + span2 = span - 1; + } + for (int i = 0; i < last; i++) { + end = begin + span1 * increment; + // putting the task into the queue for a given thread + ticket->enqueue(i, numThreads, function, begin, end, increment); + begin = end; + } + for (int i = last; i < numThreads - 1; i++) { + end = begin + span2 * increment; + // putting the task into the queue for a given thread + ticket->enqueue(i, numThreads, function, begin, end, increment); + begin = end; + } + // for last one enqueue last offset as stop + // we need it in case our ((stop-start) % increment ) > 0 + ticket->enqueue(numThreads - 1, numThreads, function, begin, stop, + increment); + // block and wait till all threads finished the job + ticket->waitAndRelease(); + // we tell that parallelism request succeeded + return numThreads; + } else { + // if there were no threads available - we'll execute function right within + // current thread + function(0, start, stop, increment); + // we tell that parallelism request declined + return 1; + } +} + +} // namespace samediff \ No newline at end of file diff --git a/libnd4j/include/execution/impl/Ticket.cpp b/libnd4j/include/execution/impl/Ticket.cpp index 98cb053762b6..b56e10794d96 100644 --- a/libnd4j/include/execution/impl/Ticket.cpp +++ b/libnd4j/include/execution/impl/Ticket.cpp @@ -18,77 +18,91 @@ // @author raver119@gmail.com // -#include #include +#include #include + #include namespace samediff { - Ticket::Ticket(const std::vector*> &queues) { - _acquired = true; - _queues = queues; - } - - Ticket::Ticket() { - _acquired = true; - _interfaces.resize(sd::Environment::getInstance()->maxThreads()); - } - - bool Ticket::acquired() { - return _acquired; - } - - void Ticket::enqueue(int thread_id, samediff::CallableWithArguments *callable) { - _queues[thread_id]->put(callable); - _callables.emplace_back(callable); - } - - void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_DO func) { - _interfaces[thread_id]->fill(thread_id, num_threads, func); - } - - void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_1D func, int64_t start_x, int64_t stop_x, int64_t inc_x) { - _interfaces[thread_id]->fill(thread_id, num_threads, func, start_x, stop_x, inc_x); - } - - void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, int64_t *lpt, FUNC_RL func, int64_t start_x, int64_t stop_x, int64_t inc_x) { - _interfaces[thread_id]->fill(thread_id, num_threads, lpt, func, start_x, stop_x, inc_x); - } - - void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, double *dpt, FUNC_RD func, int64_t start_x, int64_t stop_x, int64_t inc_x) { - _interfaces[thread_id]->fill(thread_id, num_threads, dpt, func, start_x, stop_x, inc_x); - } - - void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_2D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y) { - _interfaces[thread_id]->fill(thread_id, num_threads, std::move(func), start_x, stop_x, inc_x, start_y, stop_y, inc_y); - } - - void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_3D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z) { - _interfaces[thread_id]->fill(thread_id, num_threads, func, start_x, stop_x, inc_x, start_y, stop_y, inc_y, start_z, stop_z, inc_z); - } - - void Ticket::acquiredThreads(uint32_t threads) { - _acquiredThreads = threads; - } - - void Ticket::waitAndRelease() { - for (uint32_t e = 0; e < this->_acquiredThreads; e++) { - // block until finished - _interfaces[e]->waitForCompletion(); - - // mark available - _interfaces[e]->markAvailable(); - - // increment availability counter - ThreadPool::getInstance()->release(); - } - - // return this ticket back to the pool - ThreadPool::getInstance()->release(this); - } - - - void Ticket::attach(uint32_t thread_id, samediff::CallableInterface *interface) { - _interfaces[thread_id] = interface; - } -} \ No newline at end of file +Ticket::Ticket( + const std::vector *> &queues) { + _acquired = true; + _queues = queues; +} + +Ticket::Ticket() { + _acquired = true; + _interfaces.resize(sd::Environment::getInstance()->maxThreads()); +} + +bool Ticket::acquired() { return _acquired; } + +void Ticket::enqueue(int thread_id, samediff::CallableWithArguments *callable) { + _queues[thread_id]->put(callable); + _callables.emplace_back(callable); +} + +void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_DO func) { + _interfaces[thread_id]->fill(thread_id, num_threads, func); +} + +void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_1D func, + int64_t start_x, int64_t stop_x, int64_t inc_x) { + _interfaces[thread_id]->fill(thread_id, num_threads, func, start_x, stop_x, + inc_x); +} + +void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, int64_t *lpt, + FUNC_RL func, int64_t start_x, int64_t stop_x, + int64_t inc_x) { + _interfaces[thread_id]->fill(thread_id, num_threads, lpt, func, start_x, + stop_x, inc_x); +} + +void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, double *dpt, + FUNC_RD func, int64_t start_x, int64_t stop_x, + int64_t inc_x) { + _interfaces[thread_id]->fill(thread_id, num_threads, dpt, func, start_x, + stop_x, inc_x); +} + +void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_2D func, + int64_t start_x, int64_t stop_x, int64_t inc_x, + int64_t start_y, int64_t stop_y, int64_t inc_y) { + _interfaces[thread_id]->fill(thread_id, num_threads, std::move(func), start_x, + stop_x, inc_x, start_y, stop_y, inc_y); +} + +void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_3D func, + int64_t start_x, int64_t stop_x, int64_t inc_x, + int64_t start_y, int64_t stop_y, int64_t inc_y, + int64_t start_z, int64_t stop_z, int64_t inc_z) { + _interfaces[thread_id]->fill(thread_id, num_threads, func, start_x, stop_x, + inc_x, start_y, stop_y, inc_y, start_z, stop_z, + inc_z); +} + +void Ticket::acquiredThreads(uint32_t threads) { _acquiredThreads = threads; } + +void Ticket::waitAndRelease() { + for (uint32_t e = 0; e < this->_acquiredThreads; e++) { + // block until finished + _interfaces[e]->waitForCompletion(); + + // mark available + _interfaces[e]->markAvailable(); + + // increment availability counter + ThreadPool::getInstance()->release(); + } + + // return this ticket back to the pool + ThreadPool::getInstance()->release(this); +} + +void Ticket::attach(uint32_t thread_id, + samediff::CallableInterface *interface) { + _interfaces[thread_id] = interface; +} +} // namespace samediff \ No newline at end of file diff --git a/libnd4j/include/graph/ArgumentsList.h b/libnd4j/include/graph/ArgumentsList.h index f9c7ecad3e03..79fd49d4c0ec 100644 --- a/libnd4j/include/graph/ArgumentsList.h +++ b/libnd4j/include/graph/ArgumentsList.h @@ -21,40 +21,42 @@ #ifndef LIBND4J_INPUTLIST_H #define LIBND4J_INPUTLIST_H +#include #include #include -#include -#include #include +#include + namespace sd { namespace graph { - class SD_EXPORT ArgumentsList { - protected: - std::vector _arguments; - public: - explicit ArgumentsList() = default; - ArgumentsList(std::initializer_list arguments); - ArgumentsList(std::initializer_list arguments); - - ~ArgumentsList() = default; - - /** - * This method returns number of argument pairs available - * - * @return - */ - int size(); - - /** - * This method returns Pair at specified index - * - * @param index - * @return - */ - Pair &at(int index); - }; -} -} - -#endif //LIBND4J_INPUTLIST_H +class SD_EXPORT ArgumentsList { + protected: + std::vector _arguments; + + public: + explicit ArgumentsList() = default; + ArgumentsList(std::initializer_list arguments); + ArgumentsList(std::initializer_list arguments); + + ~ArgumentsList() = default; + + /** + * This method returns number of argument pairs available + * + * @return + */ + int size(); + + /** + * This method returns Pair at specified index + * + * @param index + * @return + */ + Pair &at(int index); +}; +} // namespace graph +} // namespace sd + +#endif // LIBND4J_INPUTLIST_H diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index d7f30cce6277..8a36aefac87e 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -22,206 +22,220 @@ #ifndef LIBND4J_CONTEXT_H #define LIBND4J_CONTEXT_H -#include #include +#include +#include #include #include -#include #include #include -#include + +#include // CUDA-specific includes #ifdef __CUDACC__ #include -#include -#include #include +#include +#include #endif namespace sd { - namespace graph { - /** - * This class defines input desired for any given node/operation within graph - */ - class SD_EXPORT Context : public sd::graph::ContextPrototype { - protected: - sd::graph::GraphMemoryManager *_memoryManager = nullptr; - sd::memory::Workspace* _workspace = nullptr; - - sd::graph::VariableSpace* _variableSpace = nullptr; - std::pair _executionTime; - - sd::DataType _dataType = sd::DataType::FLOAT32; - // branch for divergent_op - int _branch = 0; - - // temporary context for standalone ops execution - LaunchContext* _context = nullptr; - - std::vector _dataTypes; - - // fields for fast execution (out-of-graph ops use) - std::vector> _fastpath_in; - std::vector> _fastpath_out; - - bool _helpersAllowed = true; - - // in some cases we might be able to skip shape function for validation purposes - bool _shapeFunctionOverride = false; - - // special flag used during conversion from Graph exec to FastPath exec - bool _forbidFastPath = false; - public: - Context(const ContextPrototype &prototype, VariableSpace* variableSpace, GraphMemoryManager *memoryManager = nullptr); - - explicit Context(int nodeId, VariableSpace *variableSpace = nullptr); - Context(int nodeId, VariableSpace *variableSpace, bool isInplace); - - // default destructor - ~Context(); - - // these methods are for execution timing - void setOuterTime(Nd4jLong time); - void setInnerTime(Nd4jLong time); - Nd4jLong outerTime() const; - Nd4jLong innerTime() const; - - // these methods are related to Workspace abstraction - bool hasWorkspaceProvided() const; - void attachWorkspace(sd::memory::Workspace* workspace); - - sd::memory::Workspace* workspace() const; - - - void setVariableSpace(VariableSpace* variableSpace); - - void setTargetEngine(samediff::Engine engine); - - VariableSpace *getVariableSpace(); - - LaunchContext* launchContext(); - - /** - * - * @return - */ - Stash* stash() const; - - /** - * This method returns variable for a given input index for this block - * @param idx - * @return - */ - std::shared_ptr getVariable(int idx) const; - std::shared_ptr variable(int idx) const; - - /** - * This method is shortcut to getVariable(int idx); - * - * + it check fastpath for array availability (preferred) - * @return - */ - std::shared_ptr getNDArray(int idx) const; - std::shared_ptr array(int idx) const; - - /** - * This is special method, used only within Graph - * @param idx - * @return - */ - NDArray* arrayForOp(int idx) const; - - - /** - * This method fetches variable from VariableSpace DIRECTLY - * @param p - * @return - */ - std::shared_ptr variable(int node, int index) const; - std::shared_ptr variable(const std::pair& p) const; - std::shared_ptr variable(std::initializer_list p) const; - - - void pushNDArrayToVariableSpace(int nodeId, int index, const NDArray &array); - void pushNDArrayToVariableSpace(const std::pair& pair, const NDArray &array); - - void pushNDArrayListToVariableSpace(int nodeId, int index, std::shared_ptr list); - void pushNDArrayListToVariableSpace(int nodeId, int index, const NDArrayList &list, bool track = true); - void pushNDArrayListToVariableSpace(const std::pair& pair, const NDArrayList &list, bool track = true); - - bool isValueAvailable(const std::string &name, int id, int idx = 0) const; - - std::shared_ptr ensureVariable(const std::string &name, int id, int idx = 0); - - unsigned long width() const override; - - // methods used in java interop - /** - * This method checks if Context uses fastpath variable access - * @return - */ - bool isFastPath() const; - - /** - * Method allows to forbid FastPath execution - * @param reallyForbid - */ - void forbidFastPath(bool reallyForbid); +namespace graph { +/** + * This class defines input desired for any given node/operation within graph + */ +class SD_EXPORT Context : public sd::graph::ContextPrototype { + protected: + sd::graph::GraphMemoryManager *_memoryManager = nullptr; + sd::memory::Workspace *_workspace = nullptr; + + sd::graph::VariableSpace *_variableSpace = nullptr; + std::pair _executionTime; + + sd::DataType _dataType = sd::DataType::FLOAT32; + // branch for divergent_op + int _branch = 0; + + // temporary context for standalone ops execution + LaunchContext *_context = nullptr; + + std::vector _dataTypes; + + // fields for fast execution (out-of-graph ops use) + std::vector> _fastpath_in; + std::vector> _fastpath_out; + + bool _helpersAllowed = true; + + // in some cases we might be able to skip shape function for validation + // purposes + bool _shapeFunctionOverride = false; + + // special flag used during conversion from Graph exec to FastPath exec + bool _forbidFastPath = false; + + public: + Context(const ContextPrototype &prototype, VariableSpace *variableSpace, + GraphMemoryManager *memoryManager = nullptr); + + explicit Context(int nodeId, VariableSpace *variableSpace = nullptr); + Context(int nodeId, VariableSpace *variableSpace, bool isInplace); + + // default destructor + ~Context(); + + // these methods are for execution timing + void setOuterTime(Nd4jLong time); + void setInnerTime(Nd4jLong time); + Nd4jLong outerTime() const; + Nd4jLong innerTime() const; + + // these methods are related to Workspace abstraction + bool hasWorkspaceProvided() const; + void attachWorkspace(sd::memory::Workspace *workspace); + + sd::memory::Workspace *workspace() const; + + void setVariableSpace(VariableSpace *variableSpace); + + void setTargetEngine(samediff::Engine engine); + + VariableSpace *getVariableSpace(); + + LaunchContext *launchContext(); + + /** + * + * @return + */ + Stash *stash() const; + + /** + * This method returns variable for a given input index for this block + * @param idx + * @return + */ + std::shared_ptr getVariable(int idx) const; + std::shared_ptr variable(int idx) const; + + /** + * This method is shortcut to getVariable(int idx); + * + * + it check fastpath for array availability (preferred) + * @return + */ + std::shared_ptr getNDArray(int idx) const; + std::shared_ptr array(int idx) const; + + /** + * This is special method, used only within Graph + * @param idx + * @return + */ + NDArray *arrayForOp(int idx) const; + + /** + * This method fetches variable from VariableSpace DIRECTLY + * @param p + * @return + */ + std::shared_ptr variable(int node, int index) const; + std::shared_ptr variable(const std::pair &p) const; + std::shared_ptr variable(std::initializer_list p) const; + + void pushNDArrayToVariableSpace(int nodeId, int index, const NDArray &array); + void pushNDArrayToVariableSpace(const std::pair &pair, + const NDArray &array); + + void pushNDArrayListToVariableSpace(int nodeId, int index, + std::shared_ptr list); + void pushNDArrayListToVariableSpace(int nodeId, int index, + const NDArrayList &list, + bool track = true); + void pushNDArrayListToVariableSpace(const std::pair &pair, + const NDArrayList &list, + bool track = true); + + bool isValueAvailable(const std::string &name, int id, int idx = 0) const; + + std::shared_ptr ensureVariable(const std::string &name, int id, + int idx = 0); + + unsigned long width() const override; + + // methods used in java interop + /** + * This method checks if Context uses fastpath variable access + * @return + */ + bool isFastPath() const; + + /** + * Method allows to forbid FastPath execution + * @param reallyForbid + */ + void forbidFastPath(bool reallyForbid); #ifndef __JAVACPP_HACK__ - const std::vector>& fastpath_in() const; - const std::vector>& fastpath_out() const; + const std::vector> &fastpath_in() const; + const std::vector> &fastpath_out() const; #endif - void setInputArray(int index, const std::shared_ptr &array); - void setInputArray(int index, const NDArray &array); - void setInputArray(int index, void *buffer, void const* shapeInfo, void *specialBuffer, void const* specialShapeInfo); - void setInputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo); - void setInputArray(int index, void *databuffer, void const* shapeInfo, void const* specialShapeInfo); - - void setOutputArray(int index, const std::shared_ptr &array); - void setOutputArray(int index, const NDArray &array); - void setOutputArray(int index, void *buffer, const void * shapeInfo, void *specialBuffer, const void * specialShapeInfo); - void setOutputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo); - void setOutputArray(int index, void *databuffer, void const* shapeInfo, void const* specialShapeInfo); - - void setTArguments(double *arguments, int numberOfArguments); - void setIArguments(Nd4jLong *arguments, int numberOfArguments); - void setBArguments(bool *arguments, int numberOfArguments); - void setDArguments(sd::DataType *arguments, int numberOfArguments); - - void setTArguments(const std::vector &tArgs); - void setIArguments(const std::vector &tArgs); - void setBArguments(const std::vector &tArgs); - void setDArguments(const std::vector &dArgs); - - /** - * This method purges fastpath in/out contents and releases all the handles. - * - * PLEASE NOTE: I/T/B/D args will stay intact - */ - void clearFastPath(); - - void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer); - - void allowHelpers(bool reallyAllow); - bool helpersAllowed() const; - - void setShapeFunctionOverride(bool reallyOverride); - bool shapeFunctionOverride() const; - - samediff::ExecutionMode executionMode() const; - void setExecutionMode(samediff::ExecutionMode executionMode); - - bool isTraining() const; - bool isInference() const; - - const GraphMemoryManager& memoryManager() const; - }; - } -} - - -#endif //LIBND4J_BLOCK_H + void setInputArray(int index, const std::shared_ptr &array); + void setInputArray(int index, const NDArray &array); + void setInputArray(int index, void *buffer, void const *shapeInfo, + void *specialBuffer, void const *specialShapeInfo); + void setInputArray(int index, void *buffer, void *shapeInfo, + void *specialBuffer, void *specialShapeInfo); + void setInputArray(int index, void *databuffer, void const *shapeInfo, + void const *specialShapeInfo); + + void setOutputArray(int index, const std::shared_ptr &array); + void setOutputArray(int index, const NDArray &array); + void setOutputArray(int index, void *buffer, const void *shapeInfo, + void *specialBuffer, const void *specialShapeInfo); + void setOutputArray(int index, void *buffer, void *shapeInfo, + void *specialBuffer, void *specialShapeInfo); + void setOutputArray(int index, void *databuffer, void const *shapeInfo, + void const *specialShapeInfo); + + void setTArguments(double *arguments, int numberOfArguments); + void setIArguments(Nd4jLong *arguments, int numberOfArguments); + void setBArguments(bool *arguments, int numberOfArguments); + void setDArguments(sd::DataType *arguments, int numberOfArguments); + + void setTArguments(const std::vector &tArgs); + void setIArguments(const std::vector &tArgs); + void setBArguments(const std::vector &tArgs); + void setDArguments(const std::vector &dArgs); + + /** + * This method purges fastpath in/out contents and releases all the handles. + * + * PLEASE NOTE: I/T/B/D args will stay intact + */ + void clearFastPath(); + + void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, + Nd4jPointer allocationPointer); + + void allowHelpers(bool reallyAllow); + bool helpersAllowed() const; + + void setShapeFunctionOverride(bool reallyOverride); + bool shapeFunctionOverride() const; + + samediff::ExecutionMode executionMode() const; + void setExecutionMode(samediff::ExecutionMode executionMode); + + bool isTraining() const; + bool isInference() const; + + const GraphMemoryManager &memoryManager() const; +}; +} // namespace graph +} // namespace sd + +#endif // LIBND4J_BLOCK_H diff --git a/libnd4j/include/graph/ContextPrototype.h b/libnd4j/include/graph/ContextPrototype.h index 99c62f3dfb8e..eea5bcf6188b 100644 --- a/libnd4j/include/graph/ContextPrototype.h +++ b/libnd4j/include/graph/ContextPrototype.h @@ -22,136 +22,143 @@ #ifndef ND4J_CONTEXT_PROTOTYPE_H #define ND4J_CONTEXT_PROTOTYPE_H -#include -#include #include -#include -#include -#include #include #include +#include +#include +#include +#include + +#include #ifndef __STANDALONE_BUILD__ #include #endif namespace sd { - namespace graph { - - class SD_EXPORT ContextPrototype { - protected: - // int ids of the input nodes - std::vector> _inputs; - int _nodeId; - std::string _name; - std::vector _tArgs; - std::vector _iArgs; - std::vector _bArgs; - std::vector _axis; - std::vector _dArgs; - - bool _isInplace; - - // opNum for legacy XYZ ops - int _opNum = -1; - uint64_t _rootSeed; - RandomGenerator _randomGenerator; - - sd::ops::OpDescriptor* _opDescriptor; - bool _useMKLDNN = sd::Environment::getInstance()->isUseMKLDNN(); - - // target engine for execution - samediff::Engine _engine = DEFAULT_ENGINE; - - samediff::ExecutionMode _execMode = samediff::ExecutionMode::MODE_UNDEFINED; - public: - explicit ContextPrototype(sd::ops::OpDescriptor* opDescriptor = nullptr, int nodeId = 1, bool inPlace = false); - ~ContextPrototype() = default; - - ContextPrototype(const ContextPrototype& other) noexcept; - - ContextPrototype& operator=(const ContextPrototype& other) noexcept; - - // move constructor - ContextPrototype(ContextPrototype&& other) noexcept; - - // move assignment operator - ContextPrototype& operator=(ContextPrototype&& other) noexcept; - - int getNodeId() const; - int nodeId() const; - void setNodeId(int id); - - // this method returns true, if inputs are defined - bool hasVariablesFilled() const; - - void setOpDescriptor(sd::ops::OpDescriptor* opDescriptor); - - bool isInplace() const; - void markInplace(bool reallyInplace); - - void pickInput(int input); - void pickInput(int input, int index); - void pickInput(const std::pair& p); - void fillInputs(std::initializer_list inputs); - void fillInputs(std::vector& inputs); - const std::vector>& inputs() const; - - const std::vector& getTArguments() const; - const std::vector& getIArguments() const; - const std::vector& getBArguments() const; - const std::vector& getDArguments() const; - const std::vector& getAxis() const; - - void appendI(const std::vector &value); - void appendT(const std::vector &value); - void appendB(const std::vector &value); - void appendD(const std::vector &value); - - void appendA(Nd4jLong value); - void appendI(Nd4jLong value); - void appendT(double value); - void appendB(bool value); - void appendD(DataType value); - - samediff::Engine engine() const; - - size_t numT() const; - size_t numI() const; - size_t numB() const; - size_t numD() const; - - const std::pair& input(int idx) const; - - int opNum() const; - void setOpNum(int opNum); - - bool isUseMKLDNN() const { return _useMKLDNN; } - void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN = useMKLDNN; } +namespace graph { - std::string name() const; - void setName(const std::string &name); +class SD_EXPORT ContextPrototype { + protected: + // int ids of the input nodes + std::vector> _inputs; + int _nodeId; + std::string _name; + std::vector _tArgs; + std::vector _iArgs; + std::vector _bArgs; + std::vector _axis; + std::vector _dArgs; - /** - * This method returns number of inputs available in this block - * @return - */ - virtual unsigned long width() const; + bool _isInplace; - // just a clone - ContextPrototype* clone(); + // opNum for legacy XYZ ops + int _opNum = -1; + uint64_t _rootSeed; + RandomGenerator _randomGenerator; + + sd::ops::OpDescriptor* _opDescriptor; + bool _useMKLDNN = sd::Environment::getInstance()->isUseMKLDNN(); + + // target engine for execution + samediff::Engine _engine = DEFAULT_ENGINE; + + samediff::ExecutionMode _execMode = samediff::ExecutionMode::MODE_UNDEFINED; + + public: + explicit ContextPrototype(sd::ops::OpDescriptor* opDescriptor = nullptr, + int nodeId = 1, bool inPlace = false); + ~ContextPrototype() = default; + + ContextPrototype(const ContextPrototype& other) noexcept; + + ContextPrototype& operator=(const ContextPrototype& other) noexcept; + + // move constructor + ContextPrototype(ContextPrototype&& other) noexcept; + + // move assignment operator + ContextPrototype& operator=(ContextPrototype&& other) noexcept; + + int getNodeId() const; + int nodeId() const; + void setNodeId(int id); + + // this method returns true, if inputs are defined + bool hasVariablesFilled() const; + + void setOpDescriptor(sd::ops::OpDescriptor* opDescriptor); + + bool isInplace() const; + void markInplace(bool reallyInplace); + + void pickInput(int input); + void pickInput(int input, int index); + void pickInput(const std::pair& p); + void fillInputs(std::initializer_list inputs); + void fillInputs(std::vector& inputs); + const std::vector>& inputs() const; + + const std::vector& getTArguments() const; + const std::vector& getIArguments() const; + const std::vector& getBArguments() const; + const std::vector& getDArguments() const; + const std::vector& getAxis() const; + + void appendI(const std::vector& value); + void appendT(const std::vector& value); + void appendB(const std::vector& value); + void appendD(const std::vector& value); + + void appendA(Nd4jLong value); + void appendI(Nd4jLong value); + void appendT(double value); + void appendB(bool value); + void appendD(DataType value); + + samediff::Engine engine() const; + + size_t numT() const; + size_t numI() const; + size_t numB() const; + size_t numD() const; + + const std::pair& input(int idx) const; + + int opNum() const; + void setOpNum(int opNum); + + bool isUseMKLDNN() const { return _useMKLDNN; } + void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN = useMKLDNN; } + + std::string name() const; + void setName(const std::string& name); + + /** + * This method returns number of inputs available in this block + * @return + */ + virtual unsigned long width() const; + + // just a clone + ContextPrototype* clone(); - template - ContextPrototype* asT(); + template + ContextPrototype* asT(); - RandomGenerator& randomGenerator() {return _randomGenerator;} - RandomGenerator const& getRng()const { return _randomGenerator; } - void setRng(RandomGenerator const& anotherRng) { _randomGenerator = anotherRng; } - void setRandomGenerator(RandomGenerator const& anotherRng) { _randomGenerator = anotherRng; } - uint64_t randomSeed() const { return _rootSeed; } - void setRandomSeed(uint64_t seed) { _rootSeed = seed; } - }; - } -} + RandomGenerator& randomGenerator() { return _randomGenerator; } + RandomGenerator const& getRng() const { return _randomGenerator; } + void setRng(RandomGenerator const& anotherRng) { + _randomGenerator = anotherRng; + } + void setRandomGenerator(RandomGenerator const& anotherRng) { + _randomGenerator = anotherRng; + } + uint64_t randomSeed() const { return _rootSeed; } + void setRandomSeed(uint64_t seed) { _rootSeed = seed; } +}; +} // namespace graph +} // namespace sd -#endif //ND4J_CONTEXT_PROTOTYPE_H +#endif // ND4J_CONTEXT_PROTOTYPE_H diff --git a/libnd4j/include/graph/ExecutionResult.h b/libnd4j/include/graph/ExecutionResult.h index b1f16032c314..b67320ef546d 100644 --- a/libnd4j/include/graph/ExecutionResult.h +++ b/libnd4j/include/graph/ExecutionResult.h @@ -21,74 +21,77 @@ #ifndef LIBND4J_EXECUTION_RESULT #define LIBND4J_EXECUTION_RESULT -#include +#include +#include + #include -#include #include #include -#include -#include +#include +#include namespace sd { - namespace graph { - class ExecutionResult { - private: - std::vector _variables; - MAP_IMPL _stringIdMap; - MAP_IMPL, Variable *> _pairIdMap; - - // this flag is used to optionally release variables - bool _releasable = false; - public: - ExecutionResult(const FlatResult* flatResult); - ExecutionResult(std::initializer_list variables); - ExecutionResult() = default; - ~ExecutionResult(); - - /** - * This method adds variable pointer to result - */ - void emplace_back(Variable *variable); - - /** - * This method returns Variable by its position in output - */ - Variable* at(int position); - - /** - * This method returns Variable by its string id - */ - Variable* byId(std::string &id); - - /** - * This method returns Variable by its string id - */ - Variable* byId(const char *str); - - /** - * This method returns Variable by its numeric id:index pair - */ - Variable* byId(std::pair &id); - - /** - * This method returns Variable by its numeric id with index 0 - */ - Variable* byId(int id); - - /** - * This method returns number of elements stored in this entity - * @return - */ - Nd4jLong size(); +namespace graph { +class ExecutionResult { + private: + std::vector _variables; + MAP_IMPL _stringIdMap; + MAP_IMPL, Variable *> _pairIdMap; + + // this flag is used to optionally release variables + bool _releasable = false; + + public: + ExecutionResult(const FlatResult *flatResult); + ExecutionResult(std::initializer_list variables); + ExecutionResult() = default; + ~ExecutionResult(); + + /** + * This method adds variable pointer to result + */ + void emplace_back(Variable *variable); + + /** + * This method returns Variable by its position in output + */ + Variable *at(int position); + + /** + * This method returns Variable by its string id + */ + Variable *byId(std::string &id); + + /** + * This method returns Variable by its string id + */ + Variable *byId(const char *str); + + /** + * This method returns Variable by its numeric id:index pair + */ + Variable *byId(std::pair &id); + + /** + * This method returns Variable by its numeric id with index 0 + */ + Variable *byId(int id); + + /** + * This method returns number of elements stored in this entity + * @return + */ + Nd4jLong size(); #ifndef __JAVACPP_HACK__ - /** - * This method converts ExecutionResult entity to FlatResult - */ - flatbuffers::Offset asFlatResult(flatbuffers::FlatBufferBuilder &builder); + /** + * This method converts ExecutionResult entity to FlatResult + */ + flatbuffers::Offset asFlatResult( + flatbuffers::FlatBufferBuilder &builder); #endif - }; - } -} +}; +} // namespace graph +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/graph/ExecutorConfiguration.h b/libnd4j/include/graph/ExecutorConfiguration.h index 8a9d3e831700..758300595e1e 100644 --- a/libnd4j/include/graph/ExecutorConfiguration.h +++ b/libnd4j/include/graph/ExecutorConfiguration.h @@ -22,31 +22,33 @@ #define LIBND4J_EXECUTORCONFIGURATION_H #include -#include #include +#include namespace sd { - namespace graph { - class SD_EXPORT ExecutorConfiguration { - public: - sd::graph::ProfilingMode _profilingMode; - sd::graph::ExecutionMode _executionMode; - sd::graph::OutputMode _outputMode; - bool _timestats; - Nd4jLong _footprintForward = 0L; - Nd4jLong _footprintBackward = 0L; - Direction _direction = Direction_FORWARD_ONLY; - - explicit ExecutorConfiguration(const sd::graph::FlatConfiguration *conf = nullptr); - ~ExecutorConfiguration() = default; - - ExecutorConfiguration clone() const; +namespace graph { +class SD_EXPORT ExecutorConfiguration { + public: + sd::graph::ProfilingMode _profilingMode; + sd::graph::ExecutionMode _executionMode; + sd::graph::OutputMode _outputMode; + bool _timestats; + Nd4jLong _footprintForward = 0L; + Nd4jLong _footprintBackward = 0L; + Direction _direction = Direction_FORWARD_ONLY; + + explicit ExecutorConfiguration( + const sd::graph::FlatConfiguration *conf = nullptr); + ~ExecutorConfiguration() = default; + + ExecutorConfiguration clone() const; #ifndef __JAVACPP_HACK__ - flatbuffers::Offset asFlatConfiguration(flatbuffers::FlatBufferBuilder &builder); + flatbuffers::Offset asFlatConfiguration( + flatbuffers::FlatBufferBuilder &builder); #endif - }; - } -} +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_EXECUTORCONFIGURATION_H +#endif // LIBND4J_EXECUTORCONFIGURATION_H diff --git a/libnd4j/include/graph/FlatUtils.h b/libnd4j/include/graph/FlatUtils.h index 0d1e54025272..180c141a3d2f 100644 --- a/libnd4j/include/graph/FlatUtils.h +++ b/libnd4j/include/graph/FlatUtils.h @@ -21,25 +21,27 @@ #ifndef LIBND4J_FLATUTILS_H #define LIBND4J_FLATUTILS_H -#include -#include +#include #include #include -#include +#include + +#include namespace sd { - namespace graph { - class SD_EXPORT FlatUtils { - public: - static std::pair fromIntPair(IntPair* pair); +namespace graph { +class SD_EXPORT FlatUtils { + public: + static std::pair fromIntPair(IntPair* pair); - static std::pair fromLongPair(LongPair* pair); + static std::pair fromLongPair(LongPair* pair); - static NDArray fromFlatArray(const sd::graph::FlatArray* flatArray); + static NDArray fromFlatArray(const sd::graph::FlatArray* flatArray); - static flatbuffers::Offset toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array); - }; - } -} + static flatbuffers::Offset toFlatArray( + flatbuffers::FlatBufferBuilder& builder, NDArray& array); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_FLATUTILS_H +#endif // LIBND4J_FLATUTILS_H diff --git a/libnd4j/include/graph/FlowPath.h b/libnd4j/include/graph/FlowPath.h index 70bbc727dda1..e9e8ddc9122f 100644 --- a/libnd4j/include/graph/FlowPath.h +++ b/libnd4j/include/graph/FlowPath.h @@ -21,67 +21,68 @@ #ifndef LIBND4J_FLOWPATH_H #define LIBND4J_FLOWPATH_H -#include -#include -#include -#include -#include #include +#include #include #include +#include +#include + +#include +#include namespace sd { - namespace graph { - class SD_EXPORT FlowPath { - private: - MAP_IMPL _states; - MAP_IMPL _frames; +namespace graph { +class SD_EXPORT FlowPath { + private: + MAP_IMPL _states; + MAP_IMPL _frames; - void ensureNode(int nodeId); - void ensureFrame(int nodeId); + void ensureNode(int nodeId); + void ensureFrame(int nodeId); - GraphProfile _profile; - public: - FlowPath() = default; - ~FlowPath() = default; + GraphProfile _profile; - void setInnerTime(int nodeId, Nd4jLong time); - void setOuterTime(int nodeId, Nd4jLong time); + public: + FlowPath() = default; + ~FlowPath() = default; - Nd4jLong innerTime(int nodeId); - Nd4jLong outerTime(int nodeId); + void setInnerTime(int nodeId, Nd4jLong time); + void setOuterTime(int nodeId, Nd4jLong time); - bool isNodeActive(int nodeId); - void markNodeActive(int nodeId, bool isActive); + Nd4jLong innerTime(int nodeId); + Nd4jLong outerTime(int nodeId); - bool wasExecuted(int nodeId); - void markExecuted(int nodeId, bool wasExecuted); + bool isNodeActive(int nodeId); + void markNodeActive(int nodeId, bool isActive); - int branch(int nodeId); - void markBranch(int nodeId, int index); + bool wasExecuted(int nodeId); + void markExecuted(int nodeId, bool wasExecuted); - // Frame-related methods + int branch(int nodeId); + void markBranch(int nodeId, int index); - void registerFrame(Nd4jLong frameId); - void forgetFrame(Nd4jLong frameId); + // Frame-related methods - bool isFrameActive(Nd4jLong frameId); - void markFrameActive(Nd4jLong frameId, bool isActive); + void registerFrame(Nd4jLong frameId); + void forgetFrame(Nd4jLong frameId); - bool isRewindPlanned(Nd4jLong frameId); - void planRewind(Nd4jLong frameId, bool reallyRewind); + bool isFrameActive(Nd4jLong frameId); + void markFrameActive(Nd4jLong frameId, bool isActive); - int getRewindPosition(Nd4jLong frameId); - void setRewindPosition(Nd4jLong frameId, int position); - void setRewindPositionOnce(Nd4jLong frameId, int position); + bool isRewindPlanned(Nd4jLong frameId); + void planRewind(Nd4jLong frameId, bool reallyRewind); - void incrementNumberOfCycles(Nd4jLong frameId); - Nd4jLong getNumberOfCycles(Nd4jLong frameId); + int getRewindPosition(Nd4jLong frameId); + void setRewindPosition(Nd4jLong frameId, int position); + void setRewindPositionOnce(Nd4jLong frameId, int position); - GraphProfile* profile(); - }; - } -} + void incrementNumberOfCycles(Nd4jLong frameId); + Nd4jLong getNumberOfCycles(Nd4jLong frameId); + GraphProfile* profile(); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_FLOWPATH_H +#endif // LIBND4J_FLOWPATH_H diff --git a/libnd4j/include/graph/FrameState.h b/libnd4j/include/graph/FrameState.h index 8f97404a0cae..564cf5750ebe 100644 --- a/libnd4j/include/graph/FrameState.h +++ b/libnd4j/include/graph/FrameState.h @@ -21,87 +21,89 @@ #ifndef LIBND4J_FRAMESTATE_H #define LIBND4J_FRAMESTATE_H -#include -#include #include +#include + +#include namespace sd { - namespace graph { - class SD_EXPORT FrameState { - private: - std::string _name; - Nd4jLong _id = 0; - int _numberOfCycles = 0; - bool _activated = false; - - bool _rewindPlanned = false; - int _rewindPosition = -1; - public: - FrameState(Nd4jLong id = 0); - ~FrameState() = default; - - /** - * This method returns number of cycles passed for this Frame - * - * @return - */ - int getNumberOfCycles(); - - /** - * This method increments number of cycles by 1 for this Frame - */ - void incrementNumberOfCycles(); - - /** - * This method returns TRUE is frame was activated at LoopCond - * @return - */ - bool wasActivated(); - - /** - * This method allows to toggle activated state of this Frame - * @param reallyActivated - */ - void markActivated(bool reallyActivated); - - /** - * This method returns of this Frame (if it's set) - * @return - */ - std::string& getFrameName(); - - /** - * This method returns TRUE if reset is planned for this Frame - * @return - */ - bool isRewindPlanned(); - - /** - * This method allows you to toggle flag for planned rewind - * @param reallyPlanning - */ - void planRewind(bool reallyPlanning); - - /** - * This method returns planned reset position for given Frame - * @return - */ - int getRewindPosition(); - - /** - * This method allows to set rewind position for this Frame - * @param pos - */ - void setRewindPosition(int pos); - - /** - * This method allows to set rewind position for this Frame, but only if it wasn't set earlier - * @param pos - */ - void setRewindPositionOnce(int pos); - }; - } -} - - -#endif //LIBND4J_FRAMESTATE_H +namespace graph { +class SD_EXPORT FrameState { + private: + std::string _name; + Nd4jLong _id = 0; + int _numberOfCycles = 0; + bool _activated = false; + + bool _rewindPlanned = false; + int _rewindPosition = -1; + + public: + FrameState(Nd4jLong id = 0); + ~FrameState() = default; + + /** + * This method returns number of cycles passed for this Frame + * + * @return + */ + int getNumberOfCycles(); + + /** + * This method increments number of cycles by 1 for this Frame + */ + void incrementNumberOfCycles(); + + /** + * This method returns TRUE is frame was activated at LoopCond + * @return + */ + bool wasActivated(); + + /** + * This method allows to toggle activated state of this Frame + * @param reallyActivated + */ + void markActivated(bool reallyActivated); + + /** + * This method returns of this Frame (if it's set) + * @return + */ + std::string& getFrameName(); + + /** + * This method returns TRUE if reset is planned for this Frame + * @return + */ + bool isRewindPlanned(); + + /** + * This method allows you to toggle flag for planned rewind + * @param reallyPlanning + */ + void planRewind(bool reallyPlanning); + + /** + * This method returns planned reset position for given Frame + * @return + */ + int getRewindPosition(); + + /** + * This method allows to set rewind position for this Frame + * @param pos + */ + void setRewindPosition(int pos); + + /** + * This method allows to set rewind position for this Frame, but only if it + * wasn't set earlier + * @param pos + */ + void setRewindPositionOnce(int pos); +}; +} // namespace graph +} // namespace sd + +#endif // LIBND4J_FRAMESTATE_H diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index f6fb62b2bd8a..bef4ff065e75 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -21,171 +21,181 @@ #ifndef LIBND4J_GRAPH_H #define LIBND4J_GRAPH_H -#include #include -#include +#include #include +#include //#include +#include #include -#include +#include #include +#include #include #include -#include -#include -#include -#include -#include #include -#include +#include +#include +#include #include +#include namespace sd { - namespace graph { - - class NodeInfo; - class SD_EXPORT Graph { - protected: - ExecutorConfiguration _configuration; - VariableSpace _variableSpace; - memory::Workspace _workspace; - Stash _stash; - - MAP_IMPL _unmapped; - - // string -> id conversion table - MAP_IMPL _symbolicLookupTable; - - std::mutex _mutexPreprocessing; - std::atomic _built; - - // we want to know last node id - int _maxId = 1; - - const GraphMemoryManager &_memoryMaager; - -//////////////////////////////////////// - Nd4jStatus validateNode(Node *node); - - int idByName(const std::string &nodeName) const; +namespace graph { - void printOutNode(const Node &node) const; +class NodeInfo; +class SD_EXPORT Graph { + protected: + ExecutorConfiguration _configuration; + VariableSpace _variableSpace; + memory::Workspace _workspace; + Stash _stash; - std::vector _placeholders; + MAP_IMPL _unmapped; - mutable OptimizedGraph _optimized; + // string -> id conversion table + MAP_IMPL _symbolicLookupTable; - mutable std::mutex _optimizedLock; - public: - Graph(const FlatGraph *flatGraph = nullptr, const GraphMemoryManager &memoryManager = GraphMemoryManager()); + std::mutex _mutexPreprocessing; + std::atomic _built; - ~Graph(); + // we want to know last node id + int _maxId = 1; - Graph(const Graph& other); + const GraphMemoryManager &_memoryMaager; - Graph& operator=(const Graph& other) noexcept; + //////////////////////////////////////// + Nd4jStatus validateNode(Node *node); - // move constructor - Graph(Graph&& other); + int idByName(const std::string &nodeName) const; - // move assignment operator - Graph& operator=(Graph&& other) noexcept; + void printOutNode(const Node &node) const; - /** - * Methods that allow Graph imports - */ - static Graph importFromTensorFlow(const char *fileName); - static Graph fromFlatBuffers(const char *fileName, const GraphMemoryManager &memoryManager = GraphMemoryManager()); - static Graph fromFlatPointer(void *ptr, const GraphMemoryManager &memoryManager = GraphMemoryManager()); + std::vector _placeholders; - // method that'll print out graph - Nd4jStatus validate(); + mutable OptimizedGraph _optimized; - // this method returns total number of nodes in this graph - int size() const; + mutable std::mutex _optimizedLock; - int numberOfPlaceholders() const; + public: + Graph(const FlatGraph *flatGraph = nullptr, + const GraphMemoryManager &memoryManager = GraphMemoryManager()); - const std::vector>& placeholders() const; + ~Graph(); - /** - * This method returns pointer to thread_local VariableSpace - * @return - */ - VariableSpace& variableSpace() const; - /** - * This method returns unmapped nodes - * @return - */ - const MAP_IMPL& unmappedNodes() const { return _unmapped; }; + Graph(const Graph &other); - const GraphMemoryManager& memoryManager() const; + Graph &operator=(const Graph &other) noexcept; - /** - * These methods add given node to the graph - * @param node - */ - void addNode(Node &&node, const std::initializer_list &inputs); + // move constructor + Graph(Graph &&other); - void addNode(Node &node, const std::initializer_list &inputs); - void addNode(Node &node, const std::initializer_list &inputs); - void addNode(Node &node, const std::initializer_list> &inputs); + // move assignment operator + Graph &operator=(Graph &&other) noexcept; + /** + * Methods that allow Graph imports + */ + static Graph importFromTensorFlow(const char *fileName); + static Graph fromFlatBuffers( + const char *fileName, + const GraphMemoryManager &memoryManager = GraphMemoryManager()); + static Graph fromFlatPointer( + void *ptr, + const GraphMemoryManager &memoryManager = GraphMemoryManager()); - void addVariable(const std::string &name, NDArray &array); - void addVariable(const std::string &name, NDArray &&array); + // method that'll print out graph + Nd4jStatus validate(); - /** - * This method allows to add placeholder with some pre-defined properties - */ - void addPlaceholder(const std::string &nodeName, const DataType dataType = sd::DataType::ANY, const std::vector &shape = {}); + // this method returns total number of nodes in this graph + int size() const; + int numberOfPlaceholders() const; - /** - * This method returns pointer to ExecutorConfiguration - * - * @return - */ - const ExecutorConfiguration& getExecutorConfiguration() const; + const std::vector> &placeholders() const; - /** - * This method prints out Graph op-by-op, and respective inputs - */ - void printOut(); + /** + * This method returns pointer to thread_local VariableSpace + * @return + */ + VariableSpace &variableSpace() const; + /** + * This method returns unmapped nodes + * @return + */ + const MAP_IMPL &unmappedNodes() const { return _unmapped; }; - /** - * This method returns clone of the graph - */ - Graph* clone() const; + const GraphMemoryManager &memoryManager() const; - /** - * This method returns clone of the graph, backed by VariableProxy instead of VariableSpace - */ - Graph cloneWithProxy() const; + /** + * These methods add given node to the graph + * @param node + */ + void addNode(Node &&node, const std::initializer_list &inputs); - /** - * This method returns hash of given Graph instance - */ - Nd4jLong hashCode() const; + void addNode(Node &node, const std::initializer_list &inputs); + void addNode(Node &node, const std::initializer_list &inputs); + void addNode(Node &node, + const std::initializer_list> &inputs); - void replaceState(VariableSpace *state, const ExecutorConfiguration &configuration); + void addVariable(const std::string &name, NDArray &array); + void addVariable(const std::string &name, NDArray &&array); - FORCEINLINE bool built(); + /** + * This method allows to add placeholder with some pre-defined properties + */ + void addPlaceholder(const std::string &nodeName, + const DataType dataType = sd::DataType::ANY, + const std::vector &shape = {}); - const OptimizedGraph& optimizedGraph() const; + /** + * This method returns pointer to ExecutorConfiguration + * + * @return + */ + const ExecutorConfiguration &getExecutorConfiguration() const; - /** - * This method executes this Graph instance and returns execution results - * @param dictionary - * @return - */ - std::map execute(const std::map &dictionary = {}, const std::vector &outputs = {}, const GraphExecutor &executor = GraphExecutor()) const; - }; + /** + * This method prints out Graph op-by-op, and respective inputs + */ + void printOut(); - FORCEINLINE bool Graph::built() { - return _built.load(); - } - } -} + /** + * This method returns clone of the graph + */ + Graph *clone() const; -#endif //LIBND4J_GRAPH_H + /** + * This method returns clone of the graph, backed by VariableProxy instead of + * VariableSpace + */ + Graph cloneWithProxy() const; + + /** + * This method returns hash of given Graph instance + */ + Nd4jLong hashCode() const; + + void replaceState(VariableSpace *state, + const ExecutorConfiguration &configuration); + + FORCEINLINE bool built(); + + const OptimizedGraph &optimizedGraph() const; + + /** + * This method executes this Graph instance and returns execution results + * @param dictionary + * @return + */ + std::map execute( + const std::map &dictionary = {}, + const std::vector &outputs = {}, + const GraphExecutor &executor = GraphExecutor()) const; +}; + +FORCEINLINE bool Graph::built() { return _built.load(); } +} // namespace graph +} // namespace sd + +#endif // LIBND4J_GRAPH_H diff --git a/libnd4j/include/graph/GraphHolder.h b/libnd4j/include/graph/GraphHolder.h index a41b4457e228..4ba3d809adc6 100644 --- a/libnd4j/include/graph/GraphHolder.h +++ b/libnd4j/include/graph/GraphHolder.h @@ -18,41 +18,45 @@ // @author raver119@gmail.com // +#include +#include +#include #include #include -#include + #include -#include -#include -#include +#include namespace sd { - namespace graph { - class SD_EXPORT GraphHolder { - private: - static GraphHolder *_INSTANCE; - MAP_IMPL _graphs; +namespace graph { +class SD_EXPORT GraphHolder { + private: + static GraphHolder *_INSTANCE; + MAP_IMPL _graphs; + + std::mutex _mutex; - std::mutex _mutex; + GraphHolder() = default; + ~GraphHolder() = default; - GraphHolder() = default; - ~GraphHolder() = default; - public: - static GraphHolder* getInstance(); + public: + static GraphHolder *getInstance(); - void registerGraph(Nd4jLong graphId, const Graph &graph); + void registerGraph(Nd4jLong graphId, const Graph &graph); - Graph& graph(Nd4jLong graphId); + Graph &graph(Nd4jLong graphId); - void forgetGraph(Nd4jLong graphId); + void forgetGraph(Nd4jLong graphId); - void dropGraph(Nd4jLong graphId); + void dropGraph(Nd4jLong graphId); - bool hasGraph(Nd4jLong graphId); + bool hasGraph(Nd4jLong graphId); - flatbuffers::Offset execute(Nd4jLong graphId, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request); + flatbuffers::Offset execute( + Nd4jLong graphId, flatbuffers::FlatBufferBuilder &builder, + const FlatInferenceRequest *request); - void replaceGraph(Nd4jLong graphId, const Graph &graph); - }; - } -} \ No newline at end of file + void replaceGraph(Nd4jLong graphId, const Graph &graph); +}; +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/GraphUtils.h b/libnd4j/include/graph/GraphUtils.h index 42328a98c6da..21e79e051a3d 100644 --- a/libnd4j/include/graph/GraphUtils.h +++ b/libnd4j/include/graph/GraphUtils.h @@ -21,23 +21,24 @@ #ifndef __H__GRAPH_UTILS__ #define __H__GRAPH_UTILS__ -#include -#include #include +#include + +#include namespace sd { namespace graph { class SD_EXPORT GraphUtils { -public: - typedef std::vector OpList; + public: + typedef std::vector OpList; -public: - static bool filterOperations(OpList& ops); - static std::string makeCommandLine(OpList& ops); - static int runPreprocessor(char const* input, char const* output); + public: + static bool filterOperations(OpList& ops); + static std::string makeCommandLine(OpList& ops); + static int runPreprocessor(char const* input, char const* output); }; -} -} +} // namespace graph +} // namespace sd #endif diff --git a/libnd4j/include/graph/InferenceRequest.h b/libnd4j/include/graph/InferenceRequest.h index 75418d7cbe3b..ea195eaf4539 100644 --- a/libnd4j/include/graph/InferenceRequest.h +++ b/libnd4j/include/graph/InferenceRequest.h @@ -20,40 +20,42 @@ #ifndef SD_INFERENCEREQUEST_H #define SD_INFERENCEREQUEST_H +#include +#include #include #include -#include -#include + #include "ExecutorConfiguration.h" namespace sd { - namespace graph { - class SD_EXPORT InferenceRequest { - private: - Nd4jLong _id; - std::vector> _variables; +namespace graph { +class SD_EXPORT InferenceRequest { + private: + Nd4jLong _id; + std::vector> _variables; - ExecutorConfiguration _configuration; + ExecutorConfiguration _configuration; - void insertVariable(std::shared_ptr variable); - public: + void insertVariable(std::shared_ptr variable); - InferenceRequest(Nd4jLong graphId, const ExecutorConfiguration &configuration); - ~InferenceRequest(); + public: + InferenceRequest(Nd4jLong graphId, + const ExecutorConfiguration &configuration); + ~InferenceRequest(); - void appendVariable(int id, const NDArray &array); - void appendVariable(int id, int index, const NDArray &array); - void appendVariable(const std::string &name, const NDArray &array); - void appendVariable(const std::string &name, int id, int index, const NDArray &array); - void appendVariable(std::shared_ptr variable); + void appendVariable(int id, const NDArray &array); + void appendVariable(int id, int index, const NDArray &array); + void appendVariable(const std::string &name, const NDArray &array); + void appendVariable(const std::string &name, int id, int index, + const NDArray &array); + void appendVariable(std::shared_ptr variable); #ifndef __JAVACPP_HACK__ - flatbuffers::Offset asFlatInferenceRequest(flatbuffers::FlatBufferBuilder &builder); + flatbuffers::Offset asFlatInferenceRequest( + flatbuffers::FlatBufferBuilder &builder); #endif - }; - } -} - - +}; +} // namespace graph +} // namespace sd -#endif //SD_INFERENCEREQUEST_H +#endif // SD_INFERENCEREQUEST_H diff --git a/libnd4j/include/graph/Intervals.h b/libnd4j/include/graph/Intervals.h index 939a65d0538c..e0ec274eb690 100644 --- a/libnd4j/include/graph/Intervals.h +++ b/libnd4j/include/graph/Intervals.h @@ -21,36 +21,33 @@ #ifndef LIBND4J_INTERVALS_H #define LIBND4J_INTERVALS_H +#include #include -#include + #include -#include +#include namespace sd { - class SD_EXPORT Intervals { - - private: - std::vector> _content; - - public: +class SD_EXPORT Intervals { + private: + std::vector> _content; - // default constructor - Intervals(); - - // constructor - Intervals(const std::initializer_list>& content ); - Intervals(const std::vector>& content ); - - // accessing operator - std::vector operator[](const Nd4jLong i) const; + public: + // default constructor + Intervals(); - // returns size of _content - int size() const; + // constructor + Intervals(const std::initializer_list>& content); + Intervals(const std::vector>& content); - }; + // accessing operator + std::vector operator[](const Nd4jLong i) const; + // returns size of _content + int size() const; +}; -} +} // namespace sd -#endif //LIBND4J_INTERVALS_H +#endif // LIBND4J_INTERVALS_H diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index c937221b3df0..1dc5ad7a9526 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -21,211 +21,240 @@ #ifndef LIBND4J_GNODE_H #define LIBND4J_GNODE_H -#include -#include -#include #include -#include "Context.h" -#include #include +#include +#include +#include +#include -namespace sd { - namespace graph { - - - class Graph; - - class SD_EXPORT Node { - protected: - // TODO: this field must be removed - sd::DataType _dataType; - - OpType _opType; - ContextPrototype _protoContext; - Nd4jLong _opNum; - int _id = 0; - std::vector> _input; - std::vector> _output; - std::vector _dimensions; - - std::vector _referencedBy; - - std::string _name; - - // many ops require extra parameters to run - double *_extraParams = nullptr; - - bool _hasExternalOutputs; - bool _hasExternalInputs; - bool _hasInternalOutputs; - bool _hasInternalInputs; - - // this field is used to check, if op should be used in-place (so it can/will modify its inputs) - bool _isInplace = false; - - OpClass _opClass; - - // these fields are used to store embedded CustomOps and Graph in case of Graph-in-Graph scenario - Graph * _graph= nullptr; - std::shared_ptr _customOp; - - // each node can be active or inactive, if used with divergents, like IF statements - bool _active = true; - - // meh - mutable bool _removable = true; - - // these fields contain information about Scope these ops are related to - int _scope_id = 0; - std::string _scope_name; - - // TODO: these 3 fields should be removed - int _rewindNode = -1; - std::pair _rewindLayer = {-1, -1}; - Nd4jLong _frameId = -1; - - public: - - explicit Node(const sd::ops::DeclarableOp &op, const std::string &nodeName = {}, const std::vector &tArgs = {}, const std::vector &iArgs = {}, const std::vector &bArgs = {}, const std::vector &dArgs = {}); - explicit Node(const std::string &opName, const std::string &nodeName = {}, const std::vector &tArgs = {}, const std::vector &iArgs = {}, const std::vector &bArgs = {}, const std::vector &dArgs = {}); - explicit Node(const FlatNode *node); - ~Node(); - - /* - * FIXME: deprecated methods, to be removed - */ - explicit Node(const std::string &opName, const std::string &nodeName, const int id, const std::vector &inputs = {}, const std::vector &tArgs = {}, const std::vector &iArgs = {}); - explicit Node(const std::string &opName, const int id = 0, const std::vector> &inputs = {}, const std::vector &tArgs = {}, const std::vector &iArgs = {}); - explicit Node(sd::ops::DeclarableOp *customOp, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); - explicit Node(std::shared_ptr customOp, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); - explicit Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, int id = 0, std::initializer_list input = {}, std::initializer_list output = {}, std::initializer_list dimensions = {}, float scalar = 0.0f, std::initializer_list tArgs = {}, std::initializer_list iArgs = {}); - - - Node(const Node& other) noexcept; - - Node& operator=(const Node& other) noexcept; - - // move constructor - Node(Node&& other) noexcept; - - // move assignment operator - Node& operator=(Node&& other) noexcept; - - bool equals(Node *other) const; - - sd::DataType dataType(); - const ContextPrototype& protoContext() const; - OpType opType() const; - Nd4jLong opNum() const; - int id() const; - const std::vector>& input() const; - const std::vector>& output() const; - - Nd4jLong getFrameId(); - void setFrameId(Nd4jLong frameId); - - int getRewindNode(); - void setRewindNode(int nodeId); - - std::pair& getRewindLayer(); - void setRewindLayer(int layerId, int stepId = 0); - - void setId(int id); - - double *extraParams(); - - bool isMultiInput(); - bool isMultiOutput(); - - bool isRemovable() const; - void markRemovable(bool reallyRemovable) const; - - bool isDivergencePoint(); - void setActive(bool reallyActive); - bool isActive(); - - bool hasExternalOutputs(); - bool hasExternalInputs(); - bool hasInternalOutputs(); - bool hasInternalInputs(); - - void pickOutputOnce(int outputId); - void pickOutput(int outputId); - void pickOutput(int nodeId, int outputId); - void pickExternalOutput(int outputId); - void pickInput(int inputId); - void pickInput(int nodeId, int outputId); - void pickInput(std::pair& id); - void pickInput(const std::string &id); - - void setName(std::string *name); - void setName(const std::string& name); - const std::string& getName() const; - const std::string& name() const; - - int totalReferences(); - void addReference(int nodeId); - - void setContextPrototype(const ContextPrototype &block); - const ContextPrototype& contextPrototype() const; - bool hasBlockAttached(); - - void setCustomOp(std::shared_ptr customOp); - std::shared_ptr customOp() const; - bool hasCustomOp() const; - - void setGraph(Graph* graph = nullptr); - Graph* graph() const; - bool hasGraphEmbedded() const; - - bool isInplace(); - void markInplace(bool reallyInplace); - - - OpClass getOpClass() const; - - // these methods are used for internal profiling - void setOuterTime(Nd4jLong time); - void setInnerTime(Nd4jLong time); +#include "Context.h" - // methods related to scopes - bool isScoped(); - void setScopeInfo(int id, const char* name = nullptr); - int scopeId(); - std::string* scopeName(); +namespace sd { +namespace graph { + +class Graph; + +class SD_EXPORT Node { + protected: + // TODO: this field must be removed + sd::DataType _dataType; + + OpType _opType; + ContextPrototype _protoContext; + Nd4jLong _opNum; + int _id = 0; + std::vector> _input; + std::vector> _output; + std::vector _dimensions; + + std::vector _referencedBy; + + std::string _name; + + // many ops require extra parameters to run + double *_extraParams = nullptr; + + bool _hasExternalOutputs; + bool _hasExternalInputs; + bool _hasInternalOutputs; + bool _hasInternalInputs; + + // this field is used to check, if op should be used in-place (so it can/will + // modify its inputs) + bool _isInplace = false; + + OpClass _opClass; + + // these fields are used to store embedded CustomOps and Graph in case of + // Graph-in-Graph scenario + Graph *_graph = nullptr; + std::shared_ptr _customOp; + + // each node can be active or inactive, if used with divergents, like IF + // statements + bool _active = true; + + // meh + mutable bool _removable = true; + + // these fields contain information about Scope these ops are related to + int _scope_id = 0; + std::string _scope_name; + + // TODO: these 3 fields should be removed + int _rewindNode = -1; + std::pair _rewindLayer = {-1, -1}; + Nd4jLong _frameId = -1; + + public: + explicit Node(const sd::ops::DeclarableOp &op, + const std::string &nodeName = {}, + const std::vector &tArgs = {}, + const std::vector &iArgs = {}, + const std::vector &bArgs = {}, + const std::vector &dArgs = {}); + explicit Node(const std::string &opName, const std::string &nodeName = {}, + const std::vector &tArgs = {}, + const std::vector &iArgs = {}, + const std::vector &bArgs = {}, + const std::vector &dArgs = {}); + explicit Node(const FlatNode *node); + ~Node(); + + /* + * FIXME: deprecated methods, to be removed + */ + explicit Node(const std::string &opName, const std::string &nodeName, + const int id, const std::vector &inputs = {}, + const std::vector &tArgs = {}, + const std::vector &iArgs = {}); + explicit Node(const std::string &opName, const int id = 0, + const std::vector> &inputs = {}, + const std::vector &tArgs = {}, + const std::vector &iArgs = {}); + explicit Node(sd::ops::DeclarableOp *customOp, int id = 0, + std::initializer_list input = {}, + std::initializer_list output = {}, + std::initializer_list dimensions = {}, float scalar = 0.0f, + std::initializer_list tArgs = {}, + std::initializer_list iArgs = {}); + explicit Node(std::shared_ptr customOp, int id = 0, + std::initializer_list input = {}, + std::initializer_list output = {}, + std::initializer_list dimensions = {}, float scalar = 0.0f, + std::initializer_list tArgs = {}, + std::initializer_list iArgs = {}); + explicit Node(OpType opType = OpType_TRANSFORM_SAME, int opNum = 0, + int id = 0, std::initializer_list input = {}, + std::initializer_list output = {}, + std::initializer_list dimensions = {}, float scalar = 0.0f, + std::initializer_list tArgs = {}, + std::initializer_list iArgs = {}); + + Node(const Node &other) noexcept; + + Node &operator=(const Node &other) noexcept; + + // move constructor + Node(Node &&other) noexcept; + + // move assignment operator + Node &operator=(Node &&other) noexcept; + + bool equals(Node *other) const; + + sd::DataType dataType(); + const ContextPrototype &protoContext() const; + OpType opType() const; + Nd4jLong opNum() const; + int id() const; + const std::vector> &input() const; + const std::vector> &output() const; + + Nd4jLong getFrameId(); + void setFrameId(Nd4jLong frameId); + + int getRewindNode(); + void setRewindNode(int nodeId); + + std::pair &getRewindLayer(); + void setRewindLayer(int layerId, int stepId = 0); + + void setId(int id); + + double *extraParams(); + + bool isMultiInput(); + bool isMultiOutput(); + + bool isRemovable() const; + void markRemovable(bool reallyRemovable) const; + + bool isDivergencePoint(); + void setActive(bool reallyActive); + bool isActive(); + + bool hasExternalOutputs(); + bool hasExternalInputs(); + bool hasInternalOutputs(); + bool hasInternalInputs(); + + void pickOutputOnce(int outputId); + void pickOutput(int outputId); + void pickOutput(int nodeId, int outputId); + void pickExternalOutput(int outputId); + void pickInput(int inputId); + void pickInput(int nodeId, int outputId); + void pickInput(std::pair &id); + void pickInput(const std::string &id); + + void setName(std::string *name); + void setName(const std::string &name); + const std::string &getName() const; + const std::string &name() const; + + int totalReferences(); + void addReference(int nodeId); + + void setContextPrototype(const ContextPrototype &block); + const ContextPrototype &contextPrototype() const; + bool hasBlockAttached(); + + void setCustomOp(std::shared_ptr customOp); + std::shared_ptr customOp() const; + bool hasCustomOp() const; + + void setGraph(Graph *graph = nullptr); + Graph *graph() const; + bool hasGraphEmbedded() const; + + bool isInplace(); + void markInplace(bool reallyInplace); + + OpClass getOpClass() const; + + // these methods are used for internal profiling + void setOuterTime(Nd4jLong time); + void setInnerTime(Nd4jLong time); + + // methods related to scopes + bool isScoped(); + void setScopeInfo(int id, const char *name = nullptr); + int scopeId(); + std::string *scopeName(); - void setOpType(OpType opType); + void setOpType(OpType opType); - // clone Node - Node* clone(); + // clone Node + Node *clone(); - template - Node* asT(); + template + Node *asT(); - FORCEINLINE void pullValues(Node *other) { - this->_dataType = other->dataType(); - this->_protoContext = other->protoContext(); - this->_hasExternalInputs = other->hasExternalInputs(); - this->_hasExternalOutputs = other->hasExternalOutputs(); - this->_hasInternalInputs = other->hasInternalInputs(); - this->_hasInternalOutputs = other->hasInternalOutputs(); + FORCEINLINE void pullValues(Node *other) { + this->_dataType = other->dataType(); + this->_protoContext = other->protoContext(); + this->_hasExternalInputs = other->hasExternalInputs(); + this->_hasExternalOutputs = other->hasExternalOutputs(); + this->_hasInternalInputs = other->hasInternalInputs(); + this->_hasInternalOutputs = other->hasInternalOutputs(); - this->markInplace(other->isInplace()); - this->setActive(other->isActive()); - this->setScopeInfo(other->scopeId(), other->scopeName()->c_str()); + this->markInplace(other->isInplace()); + this->setActive(other->isActive()); + this->setScopeInfo(other->scopeId(), other->scopeName()->c_str()); - for (auto &v: other->input()) - this->_input.emplace_back(v); + for (auto &v : other->input()) this->_input.emplace_back(v); - for (auto &v: other->output()) - this->_output.emplace_back(v); - } + for (auto &v : other->output()) this->_output.emplace_back(v); + } - static std::shared_ptr buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum); - static void deleteOpByType(OpType opType, void *op); - }; - } -} + static std::shared_ptr buildOpByType( + OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum); + static void deleteOpByType(OpType opType, void *op); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_GNODE_H +#endif // LIBND4J_GNODE_H diff --git a/libnd4j/include/graph/NodeState.h b/libnd4j/include/graph/NodeState.h index 3ed3020fb4de..199bdcb6aa9f 100644 --- a/libnd4j/include/graph/NodeState.h +++ b/libnd4j/include/graph/NodeState.h @@ -21,49 +21,49 @@ #ifndef LIBND4J_NODESTATE_H #define LIBND4J_NODESTATE_H -#include #include +#include namespace sd { - namespace graph { - class SD_EXPORT NodeState { - private: - // inner time spent on specific node - Nd4jLong _inner = 0; +namespace graph { +class SD_EXPORT NodeState { + private: + // inner time spent on specific node + Nd4jLong _inner = 0; + + // outer time spent on specific node + Nd4jLong _outer = 0; - // outer time spent on specific node - Nd4jLong _outer = 0; - - // flag that shows if node is active or disabled (i.e. after Switch op) - bool _active = true; + // flag that shows if node is active or disabled (i.e. after Switch op) + bool _active = true; - bool _executed = false; + bool _executed = false; - // active divergence branch - int _branch = 0; + // active divergence branch + int _branch = 0; - int _id = 0; - public: - NodeState(int id = 0); - ~NodeState() = default; + int _id = 0; - void setInnerTime(Nd4jLong time); - void setOuterTime(Nd4jLong time); + public: + NodeState(int id = 0); + ~NodeState() = default; - Nd4jLong innerTime(); - Nd4jLong outerTime(); + void setInnerTime(Nd4jLong time); + void setOuterTime(Nd4jLong time); - void markActive(bool isActive); - bool isActive(); + Nd4jLong innerTime(); + Nd4jLong outerTime(); - int branch(); - void markBranch(int index); + void markActive(bool isActive); + bool isActive(); - bool wasExecuted(); - void markExecuted(bool wasExecuted); - }; - } -} + int branch(); + void markBranch(int index); + bool wasExecuted(); + void markExecuted(bool wasExecuted); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_NODESTATE_H +#endif // LIBND4J_NODESTATE_H diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index b653c27943c1..fc20c414e630 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -22,177 +22,192 @@ #ifndef SD_OPTIMIZEDGRAPH_H #define SD_OPTIMIZEDGRAPH_H -#include #include +#include #include -#include + #include #include +#include namespace sd { - namespace graph { - - class Graph; - class NodeInfo; - /** - * This class acts as a topologically sorted & optimized Graph representation, ready for execution - */ - class SD_EXPORT OptimizedGraph { - protected: - // here we store independent OpSequences - // Graph starts from layer 0, and goes deeper step by step - // on each layer we can have 1+ OpSequences that can be executed independent - std::map _onion; - - GraphMemoryManager *_memoryManager = nullptr; - Graph *_originalGraph = nullptr; - - mutable std::mutex _mutex; - - mutable size_t _size = 0; - public: - OptimizedGraph(Graph *original); - OptimizedGraph() = default; - ~OptimizedGraph() = default; - - OptimizedGraph(const OptimizedGraph& other) noexcept; - - OptimizedGraph& operator=(const OptimizedGraph& other) noexcept; - - // move constructor - OptimizedGraph(OptimizedGraph&& other) noexcept; - - // move assignment operator - OptimizedGraph& operator=(OptimizedGraph&& other) noexcept; - - - /** - * This method returns number of layers within OptimizedGraph - * @return - */ - uint64_t layers() const; - - /** - * This method returns OpSequences stored in a given layer - * @param index - * @return - */ - const ExecutionLayer& layer(uint64_t index) const; - - /** - * This method allows to append layer to this OptimizedGraph instance - */ - // FIXME: this method should be removed or made private - void append(const std::vector &layer); - void append(const ExecutionLayer &layer); - void append(OpSequence &sequence); - - /** - * This method returns GraphMemoryManager instance that manages this Graph - * @return - */ - const GraphMemoryManager& memoryManager() const; - - /** - * This method returns pointer to original Graph - * @return - */ - const Graph& originalGraph() const; - - /** - * This method returns number of nodes in this graph instance - * @return - */ - size_t size() const; - - /** - * This method prints out graph content - */ - void printOut() const; - protected: - /* - * optimize original graph - */ - void createOptimizedGraph(); - /* - * Topological graph analysis - * @param const start node for search - * @param const reference for nodes infor container - * @param operation gather - * @return stop iterating - */ - bool topolSearch(const int startNode, std::unordered_map& nodesConnections, std::vector>& opSeq) const; - /* - * Optimized graph analysis prototyping, gather nodes infor - * @param reference to node information collector - * @param reference to start nodes - * @param reference to input branching nodes (input branching node - atleast 2 internal inputs) - * @return stop iterating - */ - bool opGraphProto(std::unordered_map& collector, std::set& startNodes, std::set& inBranchingNodes) const; - /* - * Define layers and sequence positions based on nodes infor - * @param reference to node information collector - * @param node ID - * @param layer ID - * @param sequence ID - * @param map of layers and max sequence - * @return stop iterating - */ - bool layersSeqDefine(std::unordered_map& collection, int ID, int layer, int nStartSeq, std::unordered_map& layersMaxSeq) const; - /* - * Initialize container with operations and context - * @param const reference to layers and sequence collection - * @param reference to opSequence collector - * @return stop iterating - */ - bool initOpSeqContainer(const std::unordered_map& layersMaxSeq, std::vector>& vOpSeq) const; - - }; - - class NodeInfo { - private: - std::set sConnections; - - bool bInBranching; - bool bOutBranching; - bool bProcessed; - - int nLayer; - int nSequence; - - sd::graph::OpType opType; - public: - - NodeInfo(){ reset(); } - ~NodeInfo(){ reset(); } - - void setInBranching(bool bValue) { bInBranching = bValue; } - void setOutBranching(bool bValue) { bOutBranching = bValue; } - void setProcessed(bool bValue = true) { bProcessed = bValue; } - - void reset() { sConnections.clear(); bProcessed = bInBranching = bOutBranching = false; nLayer = 0; nSequence = -1; opType = OpType_CUSTOM; } - - int layer() const { return nLayer; } - void setLayer(int layer) { nLayer = layer; } - - int sequence() const { return nSequence; } - void setSequence(int sequence) { nSequence = sequence; } - - void addConnection(int id) { sConnections.emplace(id); } - const std::set& connections() const { return sConnections; } - - void setType(sd::graph::OpType value){ opType = value; } - sd::graph::OpType type() const { return opType; } - bool isLogic(){ return opType == OpType_LOGIC; } - - bool isInBranching() const { return bInBranching; } - bool isOutBranching() const { return bOutBranching; } - bool isProcessed() const { return bProcessed; } - }; - - } -} - - -#endif //SD_OPTIMIZEDGRAPH_H +namespace graph { + +class Graph; +class NodeInfo; +/** + * This class acts as a topologically sorted & optimized Graph representation, + * ready for execution + */ +class SD_EXPORT OptimizedGraph { + protected: + // here we store independent OpSequences + // Graph starts from layer 0, and goes deeper step by step + // on each layer we can have 1+ OpSequences that can be executed independent + std::map _onion; + + GraphMemoryManager* _memoryManager = nullptr; + Graph* _originalGraph = nullptr; + + mutable std::mutex _mutex; + + mutable size_t _size = 0; + + public: + OptimizedGraph(Graph* original); + OptimizedGraph() = default; + ~OptimizedGraph() = default; + + OptimizedGraph(const OptimizedGraph& other) noexcept; + + OptimizedGraph& operator=(const OptimizedGraph& other) noexcept; + + // move constructor + OptimizedGraph(OptimizedGraph&& other) noexcept; + + // move assignment operator + OptimizedGraph& operator=(OptimizedGraph&& other) noexcept; + + /** + * This method returns number of layers within OptimizedGraph + * @return + */ + uint64_t layers() const; + + /** + * This method returns OpSequences stored in a given layer + * @param index + * @return + */ + const ExecutionLayer& layer(uint64_t index) const; + + /** + * This method allows to append layer to this OptimizedGraph instance + */ + // FIXME: this method should be removed or made private + void append(const std::vector& layer); + void append(const ExecutionLayer& layer); + void append(OpSequence& sequence); + + /** + * This method returns GraphMemoryManager instance that manages this Graph + * @return + */ + const GraphMemoryManager& memoryManager() const; + + /** + * This method returns pointer to original Graph + * @return + */ + const Graph& originalGraph() const; + + /** + * This method returns number of nodes in this graph instance + * @return + */ + size_t size() const; + + /** + * This method prints out graph content + */ + void printOut() const; + + protected: + /* + * optimize original graph + */ + void createOptimizedGraph(); + /* + * Topological graph analysis + * @param const start node for search + * @param const reference for nodes infor container + * @param operation gather + * @return stop iterating + */ + bool topolSearch(const int startNode, + std::unordered_map& nodesConnections, + std::vector>& opSeq) const; + /* + * Optimized graph analysis prototyping, gather nodes infor + * @param reference to node information collector + * @param reference to start nodes + * @param reference to input branching nodes (input branching node - atleast 2 + * internal inputs) + * @return stop iterating + */ + bool opGraphProto(std::unordered_map& collector, + std::set& startNodes, + std::set& inBranchingNodes) const; + /* + * Define layers and sequence positions based on nodes infor + * @param reference to node information collector + * @param node ID + * @param layer ID + * @param sequence ID + * @param map of layers and max sequence + * @return stop iterating + */ + bool layersSeqDefine(std::unordered_map& collection, int ID, + int layer, int nStartSeq, + std::unordered_map& layersMaxSeq) const; + /* + * Initialize container with operations and context + * @param const reference to layers and sequence collection + * @param reference to opSequence collector + * @return stop iterating + */ + bool initOpSeqContainer(const std::unordered_map& layersMaxSeq, + std::vector>& vOpSeq) const; +}; + +class NodeInfo { + private: + std::set sConnections; + + bool bInBranching; + bool bOutBranching; + bool bProcessed; + + int nLayer; + int nSequence; + + sd::graph::OpType opType; + + public: + NodeInfo() { reset(); } + ~NodeInfo() { reset(); } + + void setInBranching(bool bValue) { bInBranching = bValue; } + void setOutBranching(bool bValue) { bOutBranching = bValue; } + void setProcessed(bool bValue = true) { bProcessed = bValue; } + + void reset() { + sConnections.clear(); + bProcessed = bInBranching = bOutBranching = false; + nLayer = 0; + nSequence = -1; + opType = OpType_CUSTOM; + } + + int layer() const { return nLayer; } + void setLayer(int layer) { nLayer = layer; } + + int sequence() const { return nSequence; } + void setSequence(int sequence) { nSequence = sequence; } + + void addConnection(int id) { sConnections.emplace(id); } + const std::set& connections() const { return sConnections; } + + void setType(sd::graph::OpType value) { opType = value; } + sd::graph::OpType type() const { return opType; } + bool isLogic() { return opType == OpType_LOGIC; } + + bool isInBranching() const { return bInBranching; } + bool isOutBranching() const { return bOutBranching; } + bool isProcessed() const { return bProcessed; } +}; + +} // namespace graph +} // namespace sd + +#endif // SD_OPTIMIZEDGRAPH_H diff --git a/libnd4j/include/graph/RandomGenerator.h b/libnd4j/include/graph/RandomGenerator.h index fdd8f88010fc..57d3cfd0c82d 100644 --- a/libnd4j/include/graph/RandomGenerator.h +++ b/libnd4j/include/graph/RandomGenerator.h @@ -22,13 +22,14 @@ #ifndef LIBND4J_GRAPH_RNG_H #define LIBND4J_GRAPH_RNG_H -#include -#include -#include -#include -#include #include #include +#include +#include +#include +#include + +#include #include #ifdef __CUDACC__ @@ -37,273 +38,265 @@ #endif namespace sd { - namespace graph { +namespace graph { #ifdef __CUDACC__ - class SD_EXPORT CudaManagedRandomGenerator { - private: +class SD_EXPORT CudaManagedRandomGenerator { + private: + protected: + void* devHolder; + + public: + void* operator new(size_t len) { + void* ptr; + auto res = cudaHostAlloc(&ptr, len, cudaHostAllocDefault); + if (res != 0) + throw std::runtime_error( + "CudaManagedRandomGenerator: failed to allocate memory"); + + return ptr; + } + + void operator delete(void* ptr) { cudaFreeHost(ptr); } +}; + +class SD_EXPORT RandomGenerator : public CudaManagedRandomGenerator { +#else +class SD_EXPORT RandomGenerator { +#endif + private: +#ifndef __CUDACC__ + void* placeHolder; +#endif + // GRAPH-LEVEL STATE + u64 _rootState; - protected: - void *devHolder; + // NODE-LEVEL STATE + u64 _nodeState; - public: - void *operator new(size_t len) { - void *ptr; - auto res = cudaHostAlloc(&ptr, len, cudaHostAllocDefault); - if (res != 0) - throw std::runtime_error("CudaManagedRandomGenerator: failed to allocate memory"); + /** + * Utility method, returns number of milliseconds since 1970 + * Leave this static if possible to avoid problems in constructor + */ + static FORCEINLINE Nd4jLong currentMilliseconds(); - return ptr; - } + FORCEINLINE _CUDA_HD uint32_t xoroshiro32(Nd4jLong index); + FORCEINLINE _CUDA_HD uint64_t xoroshiro64(Nd4jLong index); - void operator delete(void *ptr) { - cudaFreeHost(ptr); - } - }; + /** + * This method returns integer value between 0 and MAX_UINT + */ + // uint32_t relativeUInt32(Nd4jLong index); - class SD_EXPORT RandomGenerator : public CudaManagedRandomGenerator { -#else - class SD_EXPORT RandomGenerator { -#endif - private: -#ifndef __CUDACC__ - void *placeHolder; -#endif - // GRAPH-LEVEL STATE - u64 _rootState; + public: + FORCEINLINE RandomGenerator(Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); - // NODE-LEVEL STATE - u64 _nodeState; + RandomGenerator(const RandomGenerator& other) noexcept; - /** - * Utility method, returns number of milliseconds since 1970 - * Leave this static if possible to avoid problems in constructor - */ - static FORCEINLINE Nd4jLong currentMilliseconds(); + RandomGenerator& operator=(const RandomGenerator& other) noexcept; + // move constructor + RandomGenerator(RandomGenerator&& other) noexcept; - FORCEINLINE _CUDA_HD uint32_t xoroshiro32(Nd4jLong index); - FORCEINLINE _CUDA_HD uint64_t xoroshiro64(Nd4jLong index); + // move assignment operator + RandomGenerator& operator=(RandomGenerator&& other) noexcept; - /** - * This method returns integer value between 0 and MAX_UINT - */ - //uint32_t relativeUInt32(Nd4jLong index); + /** + * This method allows to change graph-level state in runtime. + * PLEASE NOTE: this method will change state of node as well. + */ + FORCEINLINE _CUDA_H void setStates(Nd4jLong rootSeed, Nd4jLong nodeState = 0); - public: - FORCEINLINE RandomGenerator(Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); + /** + * This method returns T value between from and to + */ + template + FORCEINLINE _CUDA_HD T relativeT(Nd4jLong index, T from, T to); + /** + * This method returns T value between 0 and MAX_T + */ + template + FORCEINLINE _CUDA_HD T relativeT(Nd4jLong index); - RandomGenerator(const RandomGenerator& other) noexcept; + /** + * These two methods are made for JVM + * @param index + * @return + */ + FORCEINLINE _CUDA_HD int relativeInt(Nd4jLong index); + FORCEINLINE _CUDA_HD Nd4jLong relativeLong(Nd4jLong index); - RandomGenerator& operator=(const RandomGenerator& other) noexcept; + FORCEINLINE _CUDA_HD void rewindH(Nd4jLong steps); - // move constructor - RandomGenerator(RandomGenerator&& other) noexcept; - - // move assignment operator - RandomGenerator& operator=(RandomGenerator&& other) noexcept; + /** + * These methods set up only node states, with non-changed root ones + */ + FORCEINLINE _CUDA_H void setSeed(int seed) { + _nodeState._ulong = static_cast(seed); + } - /** - * This method allows to change graph-level state in runtime. - * PLEASE NOTE: this method will change state of node as well. - */ - FORCEINLINE _CUDA_H void setStates(Nd4jLong rootSeed, Nd4jLong nodeState = 0); - - - - /** - * This method returns T value between from and to - */ - template - FORCEINLINE _CUDA_HD T relativeT(Nd4jLong index, T from, T to); - - /** - * This method returns T value between 0 and MAX_T - */ - template - FORCEINLINE _CUDA_HD T relativeT(Nd4jLong index); - - /** - * These two methods are made for JVM - * @param index - * @return - */ - FORCEINLINE _CUDA_HD int relativeInt(Nd4jLong index); - FORCEINLINE _CUDA_HD Nd4jLong relativeLong(Nd4jLong index); - - FORCEINLINE _CUDA_HD void rewindH(Nd4jLong steps); - - /** - * These methods set up only node states, with non-changed root ones - */ - FORCEINLINE _CUDA_H void setSeed(int seed) { - _nodeState._ulong = static_cast(seed); - } - - FORCEINLINE _CUDA_H void setSeed(uint64_t seed) { - _nodeState._ulong = seed; - } - - FORCEINLINE _CUDA_HD Nd4jLong rootState() { - return _rootState._long; - } - - FORCEINLINE _CUDA_HD Nd4jLong nodeState() { - return _nodeState._long; - } - }; - - - FORCEINLINE RandomGenerator::RandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { - // this seed is used graph-level state - if (rootSeed == 0) - rootSeed = currentMilliseconds(); - - // graph-level state is just first seed - _rootState._long = rootSeed; - - // used to build second, node state - _nodeState._long = (nodeSeed != 0 ? nodeSeed: 1298567341LL); - } - - FORCEINLINE void RandomGenerator::setStates(Nd4jLong rootSeed, Nd4jLong nodeSeed) { - // this seed is used graph-level state - if (rootSeed == 0) - rootSeed = currentMilliseconds(); - - // graph-level state is just first seed - _rootState._long = rootSeed; - - // used to build second, node state - _nodeState._long = (nodeSeed != 0 ? nodeSeed: 1298567341LL); - } - - FORCEINLINE Nd4jLong RandomGenerator::currentMilliseconds() { - auto s = std::chrono::system_clock::now().time_since_epoch(); - auto v = std::chrono::duration_cast(s).count(); - return v; - } - - template <> - _CUDA_HD FORCEINLINE uint64_t RandomGenerator::relativeT(Nd4jLong index) { - return this->xoroshiro64(index); - } - - template <> - _CUDA_HD FORCEINLINE uint32_t RandomGenerator::relativeT(Nd4jLong index) { - return this->xoroshiro32(index); - } - - template <> - _CUDA_HD FORCEINLINE int RandomGenerator::relativeT(Nd4jLong index) { - auto x = this->relativeT(index); - auto r = static_cast(x % DataTypeUtils::max()); - return r; - } - - template <> - _CUDA_HD FORCEINLINE Nd4jLong RandomGenerator::relativeT(Nd4jLong index) { - auto x = this->relativeT(index); - auto r = static_cast(x % DataTypeUtils::max()); - return r; - } - - template - _CUDA_HD FORCEINLINE T RandomGenerator::relativeT(Nd4jLong index, T from, T to) { - auto t = this->relativeT(index); - auto z = from + T(t * (to - from)); - return z; - } - - template <> - _CUDA_HD FORCEINLINE Nd4jLong RandomGenerator::relativeT(Nd4jLong index, Nd4jLong from, Nd4jLong to) { - auto t = this->relativeT(index); - auto z = from + Nd4jLong(t * (to - from)); - return z; - } - - template <> - _CUDA_HD FORCEINLINE int RandomGenerator::relativeT(Nd4jLong index, int from, int to) { - auto t = this->relativeT(index); - auto z = from + float(t * (to - from)); - return z; - } - - template - _CUDA_HD FORCEINLINE T RandomGenerator::relativeT(Nd4jLong index) { - // This is default implementation for floating point types -#ifdef __DOUBLE_RNG__ - auto i = static_cast(this->relativeT(index)); - auto r = i / static_cast(DataTypeUtils::max()); - return static_cast(r); -#else - auto i = static_cast(this->relativeT(index)); - auto r = i / static_cast(DataTypeUtils::max()); - return static_cast(r); -#endif - } + FORCEINLINE _CUDA_H void setSeed(uint64_t seed) { _nodeState._ulong = seed; } + + FORCEINLINE _CUDA_HD Nd4jLong rootState() { return _rootState._long; } + FORCEINLINE _CUDA_HD Nd4jLong nodeState() { return _nodeState._long; } +}; - _CUDA_HD FORCEINLINE int RandomGenerator::relativeInt(Nd4jLong index) { - return relativeT(index); - } +FORCEINLINE RandomGenerator::RandomGenerator(Nd4jLong rootSeed, + Nd4jLong nodeSeed) { + // this seed is used graph-level state + if (rootSeed == 0) rootSeed = currentMilliseconds(); - _CUDA_HD FORCEINLINE Nd4jLong RandomGenerator::relativeLong(Nd4jLong index) { - return relativeT(index); - } + // graph-level state is just first seed + _rootState._long = rootSeed; - ////// - static FORCEINLINE _CUDA_HD uint32_t rotl(const uint32_t x, int k) { - return (x << k) | (x >> (32 - k)); - } + // used to build second, node state + _nodeState._long = (nodeSeed != 0 ? nodeSeed : 1298567341LL); +} - static FORCEINLINE _CUDA_HD uint64_t rotl(const uint64_t x, int k) { - return (x << k) | (x >> (64 - k)); - } +FORCEINLINE void RandomGenerator::setStates(Nd4jLong rootSeed, + Nd4jLong nodeSeed) { + // this seed is used graph-level state + if (rootSeed == 0) rootSeed = currentMilliseconds(); - _CUDA_HD FORCEINLINE uint32_t RandomGenerator::xoroshiro32(Nd4jLong index) { + // graph-level state is just first seed + _rootState._long = rootSeed; - auto s0 = _rootState._ulong; - auto s1 = _nodeState._ulong; + // used to build second, node state + _nodeState._long = (nodeSeed != 0 ? nodeSeed : 1298567341LL); +} - // xor by idx - s0 |= ((index + 2) * (s1 + 24243287)); - s1 ^= ((index + 2) * (s0 + 723829)); - - unsigned long val = 0; - val = s1 ^ s0; - int* pHalf = reinterpret_cast(&val); +FORCEINLINE Nd4jLong RandomGenerator::currentMilliseconds() { + auto s = std::chrono::system_clock::now().time_since_epoch(); + auto v = std::chrono::duration_cast(s).count(); + return v; +} - return rotl(*pHalf * 0x9E3779BB, 5) * 5; - } +template <> +_CUDA_HD FORCEINLINE uint64_t +RandomGenerator::relativeT(Nd4jLong index) { + return this->xoroshiro64(index); +} - _CUDA_HD FORCEINLINE uint64_t RandomGenerator::xoroshiro64(Nd4jLong index) { - auto s0 = _rootState._ulong; - auto s1 = _nodeState._ulong; +template <> +_CUDA_HD FORCEINLINE uint32_t +RandomGenerator::relativeT(Nd4jLong index) { + return this->xoroshiro32(index); +} - // xor by idx - s0 |= ((index + 2) * (s1 + 24243287)); - s1 ^= ((index + 2) * (s0 + 723829)); +template <> +_CUDA_HD FORCEINLINE int RandomGenerator::relativeT(Nd4jLong index) { + auto x = this->relativeT(index); + auto r = static_cast(x % DataTypeUtils::max()); + return r; +} - // since we're not modifying state - do rotl step right here - s1 ^= s0; - s0 = rotl(s0, 55) ^ s1 ^ (s1 << 14); - s1 = rotl(s1, 36); +template <> +_CUDA_HD FORCEINLINE Nd4jLong +RandomGenerator::relativeT(Nd4jLong index) { + auto x = this->relativeT(index); + auto r = static_cast(x % DataTypeUtils::max()); + return r; +} - return s0 + s1; - } +template +_CUDA_HD FORCEINLINE T RandomGenerator::relativeT(Nd4jLong index, T from, + T to) { + auto t = this->relativeT(index); + auto z = from + T(t * (to - from)); + return z; +} + +template <> +_CUDA_HD FORCEINLINE Nd4jLong RandomGenerator::relativeT(Nd4jLong index, + Nd4jLong from, + Nd4jLong to) { + auto t = this->relativeT(index); + auto z = from + Nd4jLong(t * (to - from)); + return z; +} + +template <> +_CUDA_HD FORCEINLINE int RandomGenerator::relativeT(Nd4jLong index, int from, + int to) { + auto t = this->relativeT(index); + auto z = from + float(t * (to - from)); + return z; +} + +template +_CUDA_HD FORCEINLINE T RandomGenerator::relativeT(Nd4jLong index) { + // This is default implementation for floating point types +#ifdef __DOUBLE_RNG__ + auto i = static_cast(this->relativeT(index)); + auto r = i / static_cast(DataTypeUtils::max()); + return static_cast(r); +#else + auto i = static_cast(this->relativeT(index)); + auto r = i / static_cast(DataTypeUtils::max()); + return static_cast(r); +#endif +} + +_CUDA_HD FORCEINLINE int RandomGenerator::relativeInt(Nd4jLong index) { + return relativeT(index); +} + +_CUDA_HD FORCEINLINE Nd4jLong RandomGenerator::relativeLong(Nd4jLong index) { + return relativeT(index); +} + +////// +static FORCEINLINE _CUDA_HD uint32_t rotl(const uint32_t x, int k) { + return (x << k) | (x >> (32 - k)); +} + +static FORCEINLINE _CUDA_HD uint64_t rotl(const uint64_t x, int k) { + return (x << k) | (x >> (64 - k)); +} + +_CUDA_HD FORCEINLINE uint32_t RandomGenerator::xoroshiro32(Nd4jLong index) { + auto s0 = _rootState._ulong; + auto s1 = _nodeState._ulong; + + // xor by idx + s0 |= ((index + 2) * (s1 + 24243287)); + s1 ^= ((index + 2) * (s0 + 723829)); + + unsigned long val = 0; + val = s1 ^ s0; + int* pHalf = reinterpret_cast(&val); + + return rotl(*pHalf * 0x9E3779BB, 5) * 5; +} + +_CUDA_HD FORCEINLINE uint64_t RandomGenerator::xoroshiro64(Nd4jLong index) { + auto s0 = _rootState._ulong; + auto s1 = _nodeState._ulong; + + // xor by idx + s0 |= ((index + 2) * (s1 + 24243287)); + s1 ^= ((index + 2) * (s0 + 723829)); + + // since we're not modifying state - do rotl step right here + s1 ^= s0; + s0 = rotl(s0, 55) ^ s1 ^ (s1 << 14); + s1 = rotl(s1, 36); + + return s0 + s1; +} - _CUDA_HD FORCEINLINE void RandomGenerator::rewindH(Nd4jLong steps) { - auto s0 = _nodeState._du32._v0; - auto s1 = _nodeState._du32._v1; +_CUDA_HD FORCEINLINE void RandomGenerator::rewindH(Nd4jLong steps) { + auto s0 = _nodeState._du32._v0; + auto s1 = _nodeState._du32._v1; - s1 ^= s0; - _nodeState._du32._v0 = rotl(s0, 26) ^ s1 ^ (s1 << 9); // a, b - _nodeState._du32._v1 = rotl(s1, 13); // c + s1 ^= s0; + _nodeState._du32._v0 = rotl(s0, 26) ^ s1 ^ (s1 << 9); // a, b + _nodeState._du32._v1 = rotl(s1, 13); // c - _nodeState._long ^= (steps ^ 0xdeadbeef); - } - } + _nodeState._long ^= (steps ^ 0xdeadbeef); } +} // namespace graph +} // namespace sd #endif diff --git a/libnd4j/include/graph/RandomGenerator.hpp b/libnd4j/include/graph/RandomGenerator.hpp index fbbc8bad1749..5ddd3e0cd19b 100644 --- a/libnd4j/include/graph/RandomGenerator.hpp +++ b/libnd4j/include/graph/RandomGenerator.hpp @@ -19,26 +19,27 @@ // // relies on xoroshiro64** and xoroshiro128 implementations +#include +#include +#include #include #include -#include + #include -#include -#include namespace sd { - namespace graph { - - +namespace graph { - template _CUDA_HD int RandomGenerator::relativeT(Nd4jLong, int, int); - template _CUDA_HD float16 RandomGenerator::relativeT(Nd4jLong, float16, float16); - template _CUDA_HD float RandomGenerator::relativeT(Nd4jLong, float, float); - template _CUDA_HD double RandomGenerator::relativeT(Nd4jLong, double, double); - template _CUDA_HD Nd4jLong RandomGenerator::relativeT(Nd4jLong, Nd4jLong, Nd4jLong); +template _CUDA_HD int RandomGenerator::relativeT(Nd4jLong, int, int); +template _CUDA_HD float16 RandomGenerator::relativeT(Nd4jLong, float16, + float16); +template _CUDA_HD float RandomGenerator::relativeT(Nd4jLong, float, float); +template _CUDA_HD double RandomGenerator::relativeT(Nd4jLong, double, double); +template _CUDA_HD Nd4jLong RandomGenerator::relativeT(Nd4jLong, Nd4jLong, + Nd4jLong); - template _CUDA_HD float16 RandomGenerator::relativeT(Nd4jLong); - template _CUDA_HD float RandomGenerator::relativeT(Nd4jLong); - template _CUDA_HD double RandomGenerator::relativeT(Nd4jLong); - } -} \ No newline at end of file +template _CUDA_HD float16 RandomGenerator::relativeT(Nd4jLong); +template _CUDA_HD float RandomGenerator::relativeT(Nd4jLong); +template _CUDA_HD double RandomGenerator::relativeT(Nd4jLong); +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/ResultWrapper.h b/libnd4j/include/graph/ResultWrapper.h index 4b21d946fa69..9ee2bcb6405f 100644 --- a/libnd4j/include/graph/ResultWrapper.h +++ b/libnd4j/include/graph/ResultWrapper.h @@ -21,27 +21,26 @@ #ifndef LIBND4J_RESULTWRAPPER_H #define LIBND4J_RESULTWRAPPER_H +#include #include #include -#include namespace sd { - namespace graph { - class SD_EXPORT ResultWrapper { - private: - Nd4jLong _size = 0L; - Nd4jPointer _pointer = nullptr; - - public: - ResultWrapper(Nd4jLong size, Nd4jPointer ptr); - ~ResultWrapper(); +namespace graph { +class SD_EXPORT ResultWrapper { + private: + Nd4jLong _size = 0L; + Nd4jPointer _pointer = nullptr; - Nd4jLong size(); + public: + ResultWrapper(Nd4jLong size, Nd4jPointer ptr); + ~ResultWrapper(); - Nd4jPointer pointer(); - }; - } -} + Nd4jLong size(); + Nd4jPointer pointer(); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_RESULTWRAPPER_H +#endif // LIBND4J_RESULTWRAPPER_H diff --git a/libnd4j/include/graph/Scope.h b/libnd4j/include/graph/Scope.h index a78f9fe97471..660578503e45 100644 --- a/libnd4j/include/graph/Scope.h +++ b/libnd4j/include/graph/Scope.h @@ -21,86 +21,88 @@ #ifndef LIBND4J_SCOPE_H #define LIBND4J_SCOPE_H +#include + #include #include -#include namespace sd { - namespace graph { - - /** - * Scope holds sequential list of operations, and made suitable for continuous - * re-execution of multiple operations. - * - * @tparam T - */ - class SD_EXPORT Scope { - protected: - // Graph-unique IDs for Scope instances - int _id; - std::string _name; - - // list of nodes to run, always sequential - // Graph takes care of topo sort - std::vector _nodes; - public: - // attach GiG here, with shared namespace? - // or just rebuilt graph leaf? - // ¯\_(ツ)_/¯ - - // default consructor - explicit Scope(int id, const char* name = nullptr); - - // default destructor - ~Scope(); - - /** - * this method adds op node to the scope - * - * PLEASE NOTE: We assume that ops are being added ORDERED - */ - void push_back(Node* node); - - /** - * This method returns list of ops stored earlier, ready for execution - * - * PLEASE NOTE: If the scope is conditional - last op in list should be BooleanOp - * @return - */ - std::vector * nodes(); - - /** - * This function returns number of nodes in this scope - * - * @return - */ - int size(); - - /** - * Returns ID of this scope - * @return - */ - int id(); - - /** - * Returns name of this scope - * - * @return - */ - std::string* name(); - - /** - * This method returns clone of this Scope - */ - Scope* clone(); - - /** - * This method removes all Nodes from this scope - */ - void forgetNodes(); - }; - } -} - - -#endif //LIBND4J_SCOPE_H +namespace graph { + +/** + * Scope holds sequential list of operations, and made suitable for continuous + * re-execution of multiple operations. + * + * @tparam T + */ +class SD_EXPORT Scope { + protected: + // Graph-unique IDs for Scope instances + int _id; + std::string _name; + + // list of nodes to run, always sequential + // Graph takes care of topo sort + std::vector _nodes; + + public: + // attach GiG here, with shared namespace? + // or just rebuilt graph leaf? + // ¯\_(ツ)_/¯ + + // default consructor + explicit Scope(int id, const char* name = nullptr); + + // default destructor + ~Scope(); + + /** + * this method adds op node to the scope + * + * PLEASE NOTE: We assume that ops are being added ORDERED + */ + void push_back(Node* node); + + /** + * This method returns list of ops stored earlier, ready for execution + * + * PLEASE NOTE: If the scope is conditional - last op in list should be + * BooleanOp + * @return + */ + std::vector* nodes(); + + /** + * This function returns number of nodes in this scope + * + * @return + */ + int size(); + + /** + * Returns ID of this scope + * @return + */ + int id(); + + /** + * Returns name of this scope + * + * @return + */ + std::string* name(); + + /** + * This method returns clone of this Scope + */ + Scope* clone(); + + /** + * This method removes all Nodes from this scope + */ + void forgetNodes(); +}; +} // namespace graph +} // namespace sd + +#endif // LIBND4J_SCOPE_H diff --git a/libnd4j/include/graph/Stash.h b/libnd4j/include/graph/Stash.h index 2e9e832eaf4f..d5d375d8d351 100644 --- a/libnd4j/include/graph/Stash.h +++ b/libnd4j/include/graph/Stash.h @@ -23,72 +23,70 @@ //#include #include -#include -#include -#include +#include + #include #include -#include +#include +#include +#include namespace sd { - namespace graph { - class SD_EXPORT KeyPair { - int _node; - std::string _name; - public: - KeyPair(int node = 0, const char *name = nullptr); +namespace graph { +class SD_EXPORT KeyPair { + int _node; + std::string _name; + + public: + KeyPair(int node = 0, const char *name = nullptr); - bool operator<(const KeyPair &other) const; + bool operator<(const KeyPair &other) const; - bool operator==(const KeyPair &other) const { - return _node == other._node; - } + bool operator==(const KeyPair &other) const { return _node == other._node; } - int key() const { return _node; } - std::string name() const { return _name; } - }; - } -} + int key() const { return _node; } + std::string name() const { return _name; } +}; +} // namespace graph +} // namespace sd #ifndef __JAVACPP_HACK__ namespace std { - template <> - class SD_EXPORT hash { - public: - size_t operator()(const sd::graph::KeyPair& k) const; - }; +template <> +class SD_EXPORT hash { + public: + size_t operator()(const sd::graph::KeyPair &k) const; }; +}; // namespace std #endif namespace sd { - namespace graph { - class SD_EXPORT Stash { - protected: - std::map _stash; - std::vector _handles; - - public: - Stash(); - ~Stash(); +namespace graph { +class SD_EXPORT Stash { + protected: + std::map _stash; + std::vector _handles; - //void storeArray(sd::graph::Block& block, const char *name, sd::NDArray *array); - void storeArray(int nodeId, const char *name, sd::NDArray *array); + public: + Stash(); + ~Stash(); - //bool checkStash(sd::graph::Block& block, const char *name); - bool checkStash(int nodeId, const char *name); + // void storeArray(sd::graph::Block& block, const char *name, + // sd::NDArray *array); + void storeArray(int nodeId, const char *name, sd::NDArray *array); - //sd::NDArray* extractArray(sd::graph::Block& block, const char *name); - sd::NDArray* extractArray(int nodeId, const char *name); - - void clear(); - }; - } - -} + // bool checkStash(sd::graph::Block& block, const char *name); + bool checkStash(int nodeId, const char *name); + // sd::NDArray* extractArray(sd::graph::Block& block, const char *name); + sd::NDArray *extractArray(int nodeId, const char *name); + void clear(); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_STASH_H +#endif // LIBND4J_STASH_H diff --git a/libnd4j/include/graph/Status.h b/libnd4j/include/graph/Status.h index 2f055dc4dbdf..93a957846c8c 100644 --- a/libnd4j/include/graph/Status.h +++ b/libnd4j/include/graph/Status.h @@ -21,30 +21,28 @@ #ifndef ND4J_STATUS_H #define ND4J_STATUS_H -#include -#include -#include #include +#include +#include +#include namespace sd { - class SD_EXPORT Status { - public: - static FORCEINLINE Nd4jStatus OK() { - return ND4J_STATUS_OK; - }; +class SD_EXPORT Status { + public: + static FORCEINLINE Nd4jStatus OK() { return ND4J_STATUS_OK; }; - static FORCEINLINE Nd4jStatus CODE(Nd4jStatus code, const char *message) { - nd4j_printf("%s\n", message); - return code; - } + static FORCEINLINE Nd4jStatus CODE(Nd4jStatus code, const char *message) { + nd4j_printf("%s\n", message); + return code; + } - static FORCEINLINE Nd4jStatus THROW(const char *message = nullptr) { - if (message != nullptr) { - nd4j_printf("%s\n", message); - } - return ND4J_STATUS_KERNEL_FAILURE; - } - }; -} + static FORCEINLINE Nd4jStatus THROW(const char *message = nullptr) { + if (message != nullptr) { + nd4j_printf("%s\n", message); + } + return ND4J_STATUS_KERNEL_FAILURE; + } +}; +} // namespace sd -#endif // STATUS_H \ No newline at end of file +#endif // STATUS_H \ No newline at end of file diff --git a/libnd4j/include/graph/TimeHolder.h b/libnd4j/include/graph/TimeHolder.h index 110cdf2f5d28..9d332226cbad 100644 --- a/libnd4j/include/graph/TimeHolder.h +++ b/libnd4j/include/graph/TimeHolder.h @@ -21,32 +21,29 @@ #ifndef LIBND4J_TIMEHOLDER_H #define LIBND4J_TIMEHOLDER_H -#include -#include #include +#include -namespace sd { - namespace graph { - class SD_EXPORT TimeHolder { - private: - std::map _outer; - std::map _inner; - - - public: - - TimeHolder() = default; - ~TimeHolder() = default; - - - void setOuterTime(int nodeId, Nd4jLong time); - void setInnerTime(int nodeId, Nd4jLong time); - - - Nd4jLong outerTime(int nodeId); - Nd4jLong innerTime(int nodeId); - }; - } -} +#include -#endif //LIBND4J_TIMEHOLDER_H +namespace sd { +namespace graph { +class SD_EXPORT TimeHolder { + private: + std::map _outer; + std::map _inner; + + public: + TimeHolder() = default; + ~TimeHolder() = default; + + void setOuterTime(int nodeId, Nd4jLong time); + void setInnerTime(int nodeId, Nd4jLong time); + + Nd4jLong outerTime(int nodeId); + Nd4jLong innerTime(int nodeId); +}; +} // namespace graph +} // namespace sd + +#endif // LIBND4J_TIMEHOLDER_H diff --git a/libnd4j/include/graph/Variable.h b/libnd4j/include/graph/Variable.h index fff8fa79a53d..4a3e241addb3 100644 --- a/libnd4j/include/graph/Variable.h +++ b/libnd4j/include/graph/Variable.h @@ -21,121 +21,126 @@ #ifndef LIBND4J_VARIABLE_H #define LIBND4J_VARIABLE_H -#include #include #include #include #include -#include #include +#include + +#include #ifndef __JAVACPP_HACK__ namespace std { - template <> - class SD_EXPORT hash> { - public: - size_t operator()(const std::pair& k) const; - }; - - template <> - class SD_EXPORT hash { - public: - size_t operator()(const bfloat16& k) const; - }; - - template <> - class SD_EXPORT hash { - public: - size_t operator()(const float16& k) const; - }; +template <> +class SD_EXPORT hash> { + public: + size_t operator()(const std::pair &k) const; +}; + +template <> +class SD_EXPORT hash { + public: + size_t operator()(const bfloat16 &k) const; +}; + +template <> +class SD_EXPORT hash { + public: + size_t operator()(const float16 &k) const; }; +}; // namespace std #endif namespace sd { - namespace graph { - class SD_EXPORT Variable { - protected: - int _id = 0; - int _index = 0; - std::shared_ptr _ndarray; - std::string _name; - - std::vector _shape; - DataType _dtype; - - bool _external = false; - bool _readOnly = false; - bool _placeholder = false; - bool _removable = true; - - std::shared_ptr _list; - - VariableType _variableType = VariableType::NDARRAY; - - public: - explicit Variable(bool placeHolder, DataType dataType = DataType::ANY, const std::vector &shape = {}); - explicit Variable(const sd::NDArray &array, const std::string &name, int id, int idx = 0); - explicit Variable(std::shared_ptr array, const std::string &name, int id, int idx = 0); - explicit Variable(std::shared_ptr array, const char *name = nullptr); - explicit Variable(const NDArrayList &arrayList, const std::string &name, int id, int idx = 0); - explicit Variable(); +namespace graph { +class SD_EXPORT Variable { + protected: + int _id = 0; + int _index = 0; + std::shared_ptr _ndarray; + std::string _name; + + std::vector _shape; + DataType _dtype; + + bool _external = false; + bool _readOnly = false; + bool _placeholder = false; + bool _removable = true; + + std::shared_ptr _list; + + VariableType _variableType = VariableType::NDARRAY; + + public: + explicit Variable(bool placeHolder, DataType dataType = DataType::ANY, + const std::vector &shape = {}); + explicit Variable(const sd::NDArray &array, const std::string &name, int id, + int idx = 0); + explicit Variable(std::shared_ptr array, const std::string &name, + int id, int idx = 0); + explicit Variable(std::shared_ptr array, + const char *name = nullptr); + explicit Variable(const NDArrayList &arrayList, const std::string &name, + int id, int idx = 0); + explicit Variable(); #ifndef __JAVACPP_HACK__ - explicit Variable(const sd::graph::FlatVariable *flatVariable); + explicit Variable(const sd::graph::FlatVariable *flatVariable); #endif - ~Variable(); + ~Variable(); + bool hasNDArray() const; + std::shared_ptr getNDArray() const; + void setNDArray(std::shared_ptr array); - bool hasNDArray() const; - std::shared_ptr getNDArray() const; - void setNDArray(std::shared_ptr array); + bool hasNDArrayList() const; + std::shared_ptr getNDArrayList() const; + void setNDArrayList(std::shared_ptr list); - bool hasNDArrayList() const; - std::shared_ptr getNDArrayList() const; - void setNDArrayList(std::shared_ptr list); + bool isExternal() const; + bool isReadOnly() const; + bool isEmpty() const; + bool isRemovable() const; - bool isExternal() const; - bool isReadOnly() const; - bool isEmpty() const; - bool isRemovable() const; + bool isPlaceholder() const; - bool isPlaceholder() const; + VariableType variableType() const; + void setVariableType(VariableType variableType); - VariableType variableType() const; - void setVariableType(VariableType variableType); + void markExternal(bool reallyExternal); + void markReadOnly(bool reallyReadOnly); + void markRemovable(bool reallyRemovable); - void markExternal(bool reallyExternal); - void markReadOnly(bool reallyReadOnly); - void markRemovable(bool reallyRemovable); + int id() const; + int index() const; + void setIndex(int index); + void setId(int id); + void setId(int id, int idx); - int id() const; - int index() const; - void setIndex(int index); - void setId(int id); - void setId(int id, int idx); + const std::string &name() const; + const std::string &getName() const; + void setName(const std::string &name); - const std::string& name() const; - const std::string& getName() const; - void setName(const std::string &name); - - const std::vector& shape() const; - DataType dataType() const; + const std::vector &shape() const; + DataType dataType() const; #ifndef __JAVACPP_HACK__ - /** - * This method returns offset to this Variable in FlatBuffer - * @param builder - * @return - */ - flatbuffers::Offset asFlatVariable(flatbuffers::FlatBufferBuilder &builder); + /** + * This method returns offset to this Variable in FlatBuffer + * @param builder + * @return + */ + flatbuffers::Offset asFlatVariable( + flatbuffers::FlatBufferBuilder &builder); #endif - }; - } -} - +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_VARIABLE_H +#endif // LIBND4J_VARIABLE_H diff --git a/libnd4j/include/graph/VariableProxy.h b/libnd4j/include/graph/VariableProxy.h index c3e427592d1e..729d171629a0 100644 --- a/libnd4j/include/graph/VariableProxy.h +++ b/libnd4j/include/graph/VariableProxy.h @@ -24,64 +24,76 @@ #include namespace sd { - namespace graph { - class SD_EXPORT VariableProxy: public VariableSpace { - protected: - const VariableSpace* _backed; - VariableSpace* _current = nullptr; - public: - explicit VariableProxy(const VariableSpace* reference); - ~VariableProxy(); - - virtual VariableSpace& operator=(const VariableSpace& other); - - virtual int numberOfPlaceholders() const override; - virtual const std::vector>& placeholders() const override; - - virtual bool hasExternalVariable(int it) const override; - virtual bool hasExternalVariable(const std::pair& pair) const override; - virtual bool hasExternalVariable(const std::string &symbol) const override; - - virtual bool hasVariable(int id) const override; - virtual bool hasVariable(int id, int idx) const override; - virtual bool hasVariable(const std::pair& pair) const override; - virtual bool hasVariable(const std::string &symbol) const override; - - virtual std::shared_ptr getVariable(int id) const override; - virtual std::shared_ptr getVariable(int id, int idx) const override; - virtual std::shared_ptr getVariable(const std::pair& pair) const override; - virtual std::shared_ptr getVariable(const std::string &symbol) const override; - - virtual std::vector> variables() const override; - - virtual std::shared_ptr putVariable(const std::pair& pair, const NDArray &array) override; - virtual std::shared_ptr putVariable(int id, const NDArray &array) override; - virtual std::shared_ptr putVariable(int id, int idx, const NDArray &array) override; - virtual std::shared_ptr putVariable(const std::string &name, int id, int idx, const NDArray &array) override; - virtual void putVariable(const std::string& name, int id, int idx, const std::shared_ptr &variable) override; - virtual void putVariable(const std::pair& pair, const std::shared_ptr &variable) override; - virtual void putVariable(int id, const std::shared_ptr &variable) override; - - virtual void replaceVariable(std::shared_ptr variable) override; - - virtual void dropVariable(const std::pair &pair) override; - virtual void dropVariable(int id, int idx) override; - - virtual void putOutputVariable(std::shared_ptr variable) override; - - // memory-related statistics - virtual Nd4jLong externalMemory() const override; - virtual Nd4jLong internalMemory() const override; - virtual Nd4jLong totalMemory() const override; - - virtual int externalEntries() const override; - virtual int internalEntries() const override; - virtual int totalEntries() const override; - - - virtual Stash* stash() const override; - }; - } -} - -#endif // SD_VARIABLEPROXY_H \ No newline at end of file +namespace graph { +class SD_EXPORT VariableProxy : public VariableSpace { + protected: + const VariableSpace* _backed; + VariableSpace* _current = nullptr; + + public: + explicit VariableProxy(const VariableSpace* reference); + ~VariableProxy(); + + virtual VariableSpace& operator=(const VariableSpace& other); + + virtual int numberOfPlaceholders() const override; + virtual const std::vector>& placeholders() + const override; + + virtual bool hasExternalVariable(int it) const override; + virtual bool hasExternalVariable( + const std::pair& pair) const override; + virtual bool hasExternalVariable(const std::string& symbol) const override; + + virtual bool hasVariable(int id) const override; + virtual bool hasVariable(int id, int idx) const override; + virtual bool hasVariable(const std::pair& pair) const override; + virtual bool hasVariable(const std::string& symbol) const override; + + virtual std::shared_ptr getVariable(int id) const override; + virtual std::shared_ptr getVariable(int id, int idx) const override; + virtual std::shared_ptr getVariable( + const std::pair& pair) const override; + virtual std::shared_ptr getVariable( + const std::string& symbol) const override; + + virtual std::vector> variables() const override; + + virtual std::shared_ptr putVariable(const std::pair& pair, + const NDArray& array) override; + virtual std::shared_ptr putVariable(int id, + const NDArray& array) override; + virtual std::shared_ptr putVariable(int id, int idx, + const NDArray& array) override; + virtual std::shared_ptr putVariable(const std::string& name, int id, + int idx, + const NDArray& array) override; + virtual void putVariable(const std::string& name, int id, int idx, + const std::shared_ptr& variable) override; + virtual void putVariable(const std::pair& pair, + const std::shared_ptr& variable) override; + virtual void putVariable(int id, + const std::shared_ptr& variable) override; + + virtual void replaceVariable(std::shared_ptr variable) override; + + virtual void dropVariable(const std::pair& pair) override; + virtual void dropVariable(int id, int idx) override; + + virtual void putOutputVariable(std::shared_ptr variable) override; + + // memory-related statistics + virtual Nd4jLong externalMemory() const override; + virtual Nd4jLong internalMemory() const override; + virtual Nd4jLong totalMemory() const override; + + virtual int externalEntries() const override; + virtual int internalEntries() const override; + virtual int totalEntries() const override; + + virtual Stash* stash() const override; +}; +} // namespace graph +} // namespace sd + +#endif // SD_VARIABLEPROXY_H \ No newline at end of file diff --git a/libnd4j/include/graph/VariableSpace.h b/libnd4j/include/graph/VariableSpace.h index 798e350aef7d..6a9a9a0b5e90 100644 --- a/libnd4j/include/graph/VariableSpace.h +++ b/libnd4j/include/graph/VariableSpace.h @@ -21,112 +21,120 @@ #ifndef LIBND4J_VARIABLESPACE_H #define LIBND4J_VARIABLESPACE_H -#include -#include -#include -#include -#include -#include -#include #include #include +#include +#include #include +#include +#include #include -#include -#include +#include +#include +#include +#include +#include namespace sd { - namespace graph { - class SD_EXPORT VariableSpace { - protected: - // stash is NOT cloned - Stash _stash; - - // lookup tables: by name, by id, by id:idx - MAP_IMPL, std::shared_ptr> _paired; - MAP_IMPL> _symbolic; - MAP_IMPL> _variables; - - // direct references to external variables and internally-generated variables - std::vector> _external; - std::vector> _internal; +namespace graph { +class SD_EXPORT VariableSpace { + protected: + // stash is NOT cloned + Stash _stash; - // meh - std::vector> _lists; + // lookup tables: by name, by id, by id:idx + MAP_IMPL, std::shared_ptr> _paired; + MAP_IMPL> _symbolic; + MAP_IMPL> _variables; - // placeholders. must be resolved before Graph execution - std::vector> _placeholders; + // direct references to external variables and internally-generated variables + std::vector> _external; + std::vector> _internal; - void silentPutVariable(const std::pair& pair, const std::shared_ptr &variable); + // meh + std::vector> _lists; - int _auto_counter = -1; + // placeholders. must be resolved before Graph execution + std::vector> _placeholders; - std::mutex _varmap; + void silentPutVariable(const std::pair &pair, + const std::shared_ptr &variable); - public: - VariableSpace(); - virtual ~VariableSpace(); + int _auto_counter = -1; - VariableSpace(const sd::graph::VariableSpace &variableSpace); - VariableSpace(sd::graph::VariableSpace &&variableSpace); + std::mutex _varmap; - virtual VariableSpace& operator=(const VariableSpace& other); - virtual VariableSpace& operator=(VariableSpace&& other); + public: + VariableSpace(); + virtual ~VariableSpace(); - virtual int numberOfPlaceholders() const; + VariableSpace(const sd::graph::VariableSpace &variableSpace); + VariableSpace(sd::graph::VariableSpace &&variableSpace); - virtual const std::vector>& placeholders() const; + virtual VariableSpace &operator=(const VariableSpace &other); + virtual VariableSpace &operator=(VariableSpace &&other); - virtual bool hasExternalVariable(int it) const; - virtual bool hasExternalVariable(const std::pair& pair) const; - virtual bool hasExternalVariable(const std::string &symbol) const; + virtual int numberOfPlaceholders() const; - virtual bool hasVariable(int id) const; - virtual bool hasVariable(int id, int idx) const; - virtual bool hasVariable(const std::pair& pair) const; - virtual bool hasVariable(const std::string &symbol) const; + virtual const std::vector> &placeholders() const; - virtual std::shared_ptr getVariable(int id) const; - virtual std::shared_ptr getVariable(int id, int idx) const; - virtual std::shared_ptr getVariable(const std::pair& pair) const; - virtual std::shared_ptr getVariable(const std::string &symbol) const; + virtual bool hasExternalVariable(int it) const; + virtual bool hasExternalVariable(const std::pair &pair) const; + virtual bool hasExternalVariable(const std::string &symbol) const; - virtual std::vector> variables() const; + virtual bool hasVariable(int id) const; + virtual bool hasVariable(int id, int idx) const; + virtual bool hasVariable(const std::pair &pair) const; + virtual bool hasVariable(const std::string &symbol) const; - virtual std::shared_ptr putVariable(const std::pair& pair, const NDArray &array); - virtual std::shared_ptr putVariable(int id, const NDArray &array); - virtual std::shared_ptr putVariable(int id, int idx, const std::shared_ptr &array); - virtual std::shared_ptr putVariable(int id, int idx, const NDArray &array); - virtual std::shared_ptr putVariable(const std::string &name, int id, int idx, const NDArray &array); - virtual void putVariable(const std::string& name, int id, int idx, const std::shared_ptr &variable); - virtual void putVariable(const std::pair& pair, const std::shared_ptr &variable); - virtual void putVariable(int id, const std::shared_ptr &variable); + virtual std::shared_ptr getVariable(int id) const; + virtual std::shared_ptr getVariable(int id, int idx) const; + virtual std::shared_ptr getVariable( + const std::pair &pair) const; + virtual std::shared_ptr getVariable( + const std::string &symbol) const; - virtual void dropVariable(const std::string &pair); - virtual void dropVariable(const std::pair &pair); - virtual void dropVariable(int id, int idx); + virtual std::vector> variables() const; - virtual void putOutputVariable(std::shared_ptr variable); + virtual std::shared_ptr putVariable(const std::pair &pair, + const NDArray &array); + virtual std::shared_ptr putVariable(int id, const NDArray &array); + virtual std::shared_ptr putVariable( + int id, int idx, const std::shared_ptr &array); + virtual std::shared_ptr putVariable(int id, int idx, + const NDArray &array); + virtual std::shared_ptr putVariable(const std::string &name, int id, + int idx, const NDArray &array); + virtual void putVariable(const std::string &name, int id, int idx, + const std::shared_ptr &variable); + virtual void putVariable(const std::pair &pair, + const std::shared_ptr &variable); + virtual void putVariable(int id, const std::shared_ptr &variable); - virtual void replaceVariable(std::shared_ptr variable); + virtual void dropVariable(const std::string &pair); + virtual void dropVariable(const std::pair &pair); + virtual void dropVariable(int id, int idx); - // memory-related statistics - virtual Nd4jLong externalMemory() const; - virtual Nd4jLong internalMemory() const; - virtual Nd4jLong totalMemory() const; + virtual void putOutputVariable(std::shared_ptr variable); - virtual int externalEntries() const; - virtual int internalEntries() const; - virtual int totalEntries() const; + virtual void replaceVariable(std::shared_ptr variable); - void injectVariable(const std::pair &pair, std::shared_ptr variable); + // memory-related statistics + virtual Nd4jLong externalMemory() const; + virtual Nd4jLong internalMemory() const; + virtual Nd4jLong totalMemory() const; - virtual Stash* stash() const; + virtual int externalEntries() const; + virtual int internalEntries() const; + virtual int totalEntries() const; - }; - } -} + void injectVariable(const std::pair &pair, + std::shared_ptr variable); + virtual Stash *stash() const; +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_VARIABLESPACE_H +#endif // LIBND4J_VARIABLESPACE_H diff --git a/libnd4j/include/graph/VariableType.h b/libnd4j/include/graph/VariableType.h index 28883f9b1e52..1e9c580e5c00 100644 --- a/libnd4j/include/graph/VariableType.h +++ b/libnd4j/include/graph/VariableType.h @@ -22,15 +22,15 @@ #define ND4J_VARIABLE_TYPE_H namespace sd { - namespace graph { - enum VariableType { - NDARRAY = 0, - ARRAY_LIST = 1, - FLOW = 2, - CONSTANT = 3, - PLACEHOLDER = 4, - }; - } +namespace graph { +enum VariableType { + NDARRAY = 0, + ARRAY_LIST = 1, + FLOW = 2, + CONSTANT = 3, + PLACEHOLDER = 4, +}; } +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/graph/VariablesSet.h b/libnd4j/include/graph/VariablesSet.h index b3e4844df60f..b5e5a4d5e87f 100644 --- a/libnd4j/include/graph/VariablesSet.h +++ b/libnd4j/include/graph/VariablesSet.h @@ -21,35 +21,33 @@ #ifndef LIBND4J_VARIABLESSET_H #define LIBND4J_VARIABLESSET_H -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace graph { - class SD_EXPORT VariablesSet { - protected: - std::vector _holder; - Nd4jStatus _status; - public: - VariablesSet(Nd4jStatus status = ND4J_STATUS_OK); - ~VariablesSet(); - - Nd4jStatus status(); - - int size(); +namespace graph { +class SD_EXPORT VariablesSet { + protected: + std::vector _holder; + Nd4jStatus _status; - void push_back(Variable* variable); + public: + VariablesSet(Nd4jStatus status = ND4J_STATUS_OK); + ~VariablesSet(); - Variable* at(int index); + Nd4jStatus status(); - }; - } -} + int size(); + void push_back(Variable* variable); + Variable* at(int index); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_VARIABLESSET_H +#endif // LIBND4J_VARIABLESSET_H diff --git a/libnd4j/include/graph/exceptions/impl/unresolved_input_exception.cpp b/libnd4j/include/graph/exceptions/impl/unresolved_input_exception.cpp index 5b25bbc0ad27..78b8c058533f 100644 --- a/libnd4j/include/graph/exceptions/impl/unresolved_input_exception.cpp +++ b/libnd4j/include/graph/exceptions/impl/unresolved_input_exception.cpp @@ -22,29 +22,35 @@ #include namespace sd { - namespace graph { - unresolved_input_exception::unresolved_input_exception(const std::string &message) : std::runtime_error(message) { - // - } +namespace graph { +unresolved_input_exception::unresolved_input_exception( + const std::string &message) + : std::runtime_error(message) { + // +} - unresolved_input_exception unresolved_input_exception::build(const std::string & message, int nodeId, std::pair &varIndex) { - auto node = StringUtils::valueToString(nodeId); - auto varId = StringUtils::valueToString(varIndex.first); - auto outputIdx = StringUtils::valueToString(varIndex.second); - auto rmessage = message + "; Node: [" + node +":0]; Variable: [" + varId + ":" + outputIdx + "]"; - return unresolved_input_exception(rmessage); - } +unresolved_input_exception unresolved_input_exception::build( + const std::string &message, int nodeId, std::pair &varIndex) { + auto node = StringUtils::valueToString(nodeId); + auto varId = StringUtils::valueToString(varIndex.first); + auto outputIdx = StringUtils::valueToString(varIndex.second); + auto rmessage = message + "; Node: [" + node + ":0]; Variable: [" + varId + + ":" + outputIdx + "]"; + return unresolved_input_exception(rmessage); +} - unresolved_input_exception unresolved_input_exception::build(const std::string &message, std::pair &varIndex) { - auto nodeId = StringUtils::valueToString(varIndex.first); - auto outputIdx = StringUtils::valueToString(varIndex.second); - auto rmessage = message + "; Variable: [" + nodeId + ":" + outputIdx + "]"; - return unresolved_input_exception(rmessage); - } +unresolved_input_exception unresolved_input_exception::build( + const std::string &message, std::pair &varIndex) { + auto nodeId = StringUtils::valueToString(varIndex.first); + auto outputIdx = StringUtils::valueToString(varIndex.second); + auto rmessage = message + "; Variable: [" + nodeId + ":" + outputIdx + "]"; + return unresolved_input_exception(rmessage); +} - unresolved_input_exception unresolved_input_exception::build(const std::string &message, const std::string &varName) { - auto rmessage = message + "; Variable: [" + varName + "]"; - return unresolved_input_exception(rmessage); - } - } -} \ No newline at end of file +unresolved_input_exception unresolved_input_exception::build( + const std::string &message, const std::string &varName) { + auto rmessage = message + "; Variable: [" + varName + "]"; + return unresolved_input_exception(rmessage); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/exceptions/impl/unresolved_output_exception.cpp b/libnd4j/include/graph/exceptions/impl/unresolved_output_exception.cpp index d7dede17ca08..49e08b56e472 100644 --- a/libnd4j/include/graph/exceptions/impl/unresolved_output_exception.cpp +++ b/libnd4j/include/graph/exceptions/impl/unresolved_output_exception.cpp @@ -20,30 +20,36 @@ #include #include + #include namespace sd { - namespace graph { - unresolved_output_exception::unresolved_output_exception(const std::string &message) : std::runtime_error(message) { - // - } +namespace graph { +unresolved_output_exception::unresolved_output_exception( + const std::string &message) + : std::runtime_error(message) { + // +} - unresolved_output_exception unresolved_output_exception::build(const std::string &message, std::pair &varIndex) { - auto nodeId = StringUtils::valueToString(varIndex.first); - auto outputIdx = StringUtils::valueToString(varIndex.second); - auto rmessage = message + "; Variable: [" + nodeId + ":" + outputIdx + "]"; - return unresolved_output_exception(rmessage); - } +unresolved_output_exception unresolved_output_exception::build( + const std::string &message, std::pair &varIndex) { + auto nodeId = StringUtils::valueToString(varIndex.first); + auto outputIdx = StringUtils::valueToString(varIndex.second); + auto rmessage = message + "; Variable: [" + nodeId + ":" + outputIdx + "]"; + return unresolved_output_exception(rmessage); +} - unresolved_output_exception unresolved_output_exception::build(const std::string &message, int nodeId, int outputIndex) { - std::pair p(nodeId, outputIndex); - return build(message, p); - } +unresolved_output_exception unresolved_output_exception::build( + const std::string &message, int nodeId, int outputIndex) { + std::pair p(nodeId, outputIndex); + return build(message, p); +} - unresolved_output_exception unresolved_output_exception::build(const std::string &message, const std::string &varName, int outputIndex) { - auto outputIdx = StringUtils::valueToString(outputIndex); - auto rmessage = message +"; Variable: [" + varName + ":" + outputIdx + "]"; - return unresolved_output_exception(rmessage); - } - } -} \ No newline at end of file +unresolved_output_exception unresolved_output_exception::build( + const std::string &message, const std::string &varName, int outputIndex) { + auto outputIdx = StringUtils::valueToString(outputIndex); + auto rmessage = message + "; Variable: [" + varName + ":" + outputIdx + "]"; + return unresolved_output_exception(rmessage); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/exceptions/unresolved_input_exception.h b/libnd4j/include/graph/exceptions/unresolved_input_exception.h index 024798a95290..6cea61bacc23 100644 --- a/libnd4j/include/graph/exceptions/unresolved_input_exception.h +++ b/libnd4j/include/graph/exceptions/unresolved_input_exception.h @@ -21,22 +21,26 @@ #ifndef SD_UNRESOLVED_INPUT_H #define SD_UNRESOLVED_INPUT_H -#include -#include #include +#include +#include namespace sd { - namespace graph { - class unresolved_input_exception : public std::runtime_error { - public: - unresolved_input_exception(const std::string &message); - ~unresolved_input_exception() = default; +namespace graph { +class unresolved_input_exception : public std::runtime_error { + public: + unresolved_input_exception(const std::string &message); + ~unresolved_input_exception() = default; - static unresolved_input_exception build(const std::string &message, int nodeId, std::pair &varIndex); - static unresolved_input_exception build(const std::string &message, std::pair &varIndex); - static unresolved_input_exception build(const std::string &message, const std::string &varName); - }; - } -} + static unresolved_input_exception build(const std::string &message, + int nodeId, + std::pair &varIndex); + static unresolved_input_exception build(const std::string &message, + std::pair &varIndex); + static unresolved_input_exception build(const std::string &message, + const std::string &varName); +}; +} // namespace graph +} // namespace sd -#endif //SD_UNRESOLVED_INPUT_H +#endif // SD_UNRESOLVED_INPUT_H diff --git a/libnd4j/include/graph/exceptions/unresolved_output_exception.h b/libnd4j/include/graph/exceptions/unresolved_output_exception.h index c923da785e2f..688028472d12 100644 --- a/libnd4j/include/graph/exceptions/unresolved_output_exception.h +++ b/libnd4j/include/graph/exceptions/unresolved_output_exception.h @@ -21,23 +21,26 @@ #ifndef SD_UNRESOLVED_OUTPUT_H #define SD_UNRESOLVED_OUTPUT_H -#include -#include #include +#include +#include namespace sd { - namespace graph { - class unresolved_output_exception : public std::runtime_error { - public: - unresolved_output_exception(const std::string &message); - ~unresolved_output_exception() = default; - - static unresolved_output_exception build(const std::string &message, int nodeId, int outputIndex); - static unresolved_output_exception build(const std::string &message, std::pair &varIndex); - static unresolved_output_exception build(const std::string &message, const std::string &varName, int outputIndex = 0); - }; - } -} +namespace graph { +class unresolved_output_exception : public std::runtime_error { + public: + unresolved_output_exception(const std::string &message); + ~unresolved_output_exception() = default; + static unresolved_output_exception build(const std::string &message, + int nodeId, int outputIndex); + static unresolved_output_exception build(const std::string &message, + std::pair &varIndex); + static unresolved_output_exception build(const std::string &message, + const std::string &varName, + int outputIndex = 0); +}; +} // namespace graph +} // namespace sd -#endif //SD_UNRESOLVED_INPUT_H +#endif // SD_UNRESOLVED_INPUT_H diff --git a/libnd4j/include/graph/execution/ExecutionLayer.h b/libnd4j/include/graph/execution/ExecutionLayer.h index 58dcfb8985a8..04dcb2bbd538 100644 --- a/libnd4j/include/graph/execution/ExecutionLayer.h +++ b/libnd4j/include/graph/execution/ExecutionLayer.h @@ -18,53 +18,54 @@ // @author raver119@gmail.com // - #ifndef SD_EXECUTIONLAYER_H #define SD_EXECUTIONLAYER_H -#include #include +#include + #include namespace sd { - namespace graph { - class SD_EXPORT ExecutionLayer { - protected: - std::vector _sequences; - public: - ExecutionLayer(const std::vector &sequences = {}); - ~ExecutionLayer() = default; +namespace graph { +class SD_EXPORT ExecutionLayer { + protected: + std::vector _sequences; + + public: + ExecutionLayer(const std::vector& sequences = {}); + ~ExecutionLayer() = default; - ExecutionLayer(const ExecutionLayer& other) noexcept; + ExecutionLayer(const ExecutionLayer& other) noexcept; - ExecutionLayer& operator=(const ExecutionLayer& other) noexcept; + ExecutionLayer& operator=(const ExecutionLayer& other) noexcept; - // move constructor - ExecutionLayer(ExecutionLayer&& other) noexcept; + // move constructor + ExecutionLayer(ExecutionLayer&& other) noexcept; - // move assignment operator - ExecutionLayer& operator=(ExecutionLayer&& other) noexcept; + // move assignment operator + ExecutionLayer& operator=(ExecutionLayer&& other) noexcept; - /** - * This method returns number of sequences in this layer - * @return - */ - uint64_t width() const; + /** + * This method returns number of sequences in this layer + * @return + */ + uint64_t width() const; - /** - * This method returns specified OpSequence from this layer - * @return - */ - const OpSequence& at(uint64_t index) const; - const OpSequence& operator[](uint64_t index) const; + /** + * This method returns specified OpSequence from this layer + * @return + */ + const OpSequence& at(uint64_t index) const; + const OpSequence& operator[](uint64_t index) const; - /** - * This method appends OpSequence to the end of this layer - * @param sequence - */ - void append(const OpSequence &sequence); - }; - } -} + /** + * This method appends OpSequence to the end of this layer + * @param sequence + */ + void append(const OpSequence& sequence); +}; +} // namespace graph +} // namespace sd -#endif //SD_EXECUTIONLAYER_H +#endif // SD_EXECUTIONLAYER_H diff --git a/libnd4j/include/graph/execution/ExecutionTask.h b/libnd4j/include/graph/execution/ExecutionTask.h index 12a0ac250b07..9e9d22f2f304 100644 --- a/libnd4j/include/graph/execution/ExecutionTask.h +++ b/libnd4j/include/graph/execution/ExecutionTask.h @@ -21,40 +21,41 @@ #ifndef SD_EXECUTIONTASK_H #define SD_EXECUTIONTASK_H -#include -#include #include +#include -namespace sd { - namespace graph { - class SD_EXPORT ExecutionTask { - protected: - std::shared_ptr _op; - const ContextPrototype &_context; +#include - public: - ExecutionTask(const std::shared_ptr &op, const ContextPrototype &ctx); +namespace sd { +namespace graph { +class SD_EXPORT ExecutionTask { + protected: + std::shared_ptr _op; + const ContextPrototype& _context; - ~ExecutionTask() = default; + public: + ExecutionTask(const std::shared_ptr& op, + const ContextPrototype& ctx); - ExecutionTask(const ExecutionTask& other); + ~ExecutionTask() = default; - ExecutionTask& operator=(const ExecutionTask& other) noexcept; + ExecutionTask(const ExecutionTask& other); - // move constructor - ExecutionTask(ExecutionTask&& other); + ExecutionTask& operator=(const ExecutionTask& other) noexcept; - // move assignment operator - ExecutionTask& operator=(ExecutionTask&& other) noexcept; + // move constructor + ExecutionTask(ExecutionTask&& other); - void printOut() const; + // move assignment operator + ExecutionTask& operator=(ExecutionTask&& other) noexcept; - std::shared_ptr op() const; + void printOut() const; - const ContextPrototype &protoContext() const; - }; - } -} + std::shared_ptr op() const; + const ContextPrototype& protoContext() const; +}; +} // namespace graph +} // namespace sd -#endif //SD_EXECUTIONTASK_H +#endif // SD_EXECUTIONTASK_H diff --git a/libnd4j/include/graph/execution/GraphExecutor.h b/libnd4j/include/graph/execution/GraphExecutor.h index f77836de56f1..da83c0b19043 100644 --- a/libnd4j/include/graph/execution/GraphExecutor.h +++ b/libnd4j/include/graph/execution/GraphExecutor.h @@ -21,64 +21,76 @@ #ifndef SD_GRAPHEXECUTOR_H #define SD_GRAPHEXECUTOR_H -#include #include -#include #include +#include +#include namespace sd { - namespace graph { - class Graph; - - class SD_EXPORT GraphExecutor { - protected: - virtual Context prepareContext(const ContextPrototype &contextPrototype, VariableProxy &variableSpace, const GraphMemoryManager &memoryManager) const; +namespace graph { +class Graph; - /* - * preprocessor call involves: - * - ensure all inputs reside in HOT memory zone - * - shape function call - * - open workspace - */ - virtual Nd4jStatus preprocess(sd::ops::DeclarableOp *op, Context &context) const; +class SD_EXPORT GraphExecutor { + protected: + virtual Context prepareContext(const ContextPrototype &contextPrototype, + VariableProxy &variableSpace, + const GraphMemoryManager &memoryManager) const; - /** - * postporcessor call involves: - * - remove all inputs that are not going to be used later from HOT memory zone - * - close workspace - * @return - */ - virtual Nd4jStatus postprocess(sd::ops::DeclarableOp *op, Context *context) const; + /* + * preprocessor call involves: + * - ensure all inputs reside in HOT memory zone + * - shape function call + * - open workspace + */ + virtual Nd4jStatus preprocess(sd::ops::DeclarableOp *op, + Context &context) const; - public: - GraphExecutor() = default; - virtual ~GraphExecutor() = default; + /** + * postporcessor call involves: + * - remove all inputs that are not going to be used later from HOT memory + * zone + * - close workspace + * @return + */ + virtual Nd4jStatus postprocess(sd::ops::DeclarableOp *op, + Context *context) const; - /** - * This method executes OptimizedGraph instance - * @param graph - * @return - */ - virtual Nd4jStatus execute(const OptimizedGraph &graph, VariableProxy &proxy) const; + public: + GraphExecutor() = default; + virtual ~GraphExecutor() = default; - /** - * This method executes OpSequence - * @param sequence - * @param deviceId - this argument allows to override device affinity specified in OpSequence, keep it < 0 to follow OpSequence - * @return - */ - virtual Nd4jStatus execute(const OpSequence &sequence, const OptimizedGraph &graph, VariableProxy &proxy, int deviceId) const; + /** + * This method executes OptimizedGraph instance + * @param graph + * @return + */ + virtual Nd4jStatus execute(const OptimizedGraph &graph, + VariableProxy &proxy) const; - /** - * This method executes given op - * @param op - * @param contextPrototype - * @return - */ - virtual Nd4jStatus execute(const std::shared_ptr &op, const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, VariableProxy &proxy, const int deviceId) const; - }; - } -} + /** + * This method executes OpSequence + * @param sequence + * @param deviceId - this argument allows to override device affinity + * specified in OpSequence, keep it < 0 to follow OpSequence + * @return + */ + virtual Nd4jStatus execute(const OpSequence &sequence, + const OptimizedGraph &graph, VariableProxy &proxy, + int deviceId) const; + /** + * This method executes given op + * @param op + * @param contextPrototype + * @return + */ + virtual Nd4jStatus execute(const std::shared_ptr &op, + const ContextPrototype &contextPrototype, + const OpSequence &sequence, + const OptimizedGraph &graph, VariableProxy &proxy, + const int deviceId) const; +}; +} // namespace graph +} // namespace sd -#endif //SD_GRAPHEXECUTOR_H +#endif // SD_GRAPHEXECUTOR_H diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index ba461b011bf8..eafd49f699e5 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -21,99 +21,105 @@ #ifndef SD_OPSEQUENCE_H #define SD_OPSEQUENCE_H -#include -#include #include +#include +#include namespace sd { - namespace graph { - /** - * This class represents independent and immutable sequence of operations - */ - class SD_EXPORT OpSequence : public std::iterator { - // our internal iterator for OpSequence - class iterator; - protected: - // main thing here. sorted list of operations and their contexts - std::vector _ops; - - int _deviceId = 0; - public: - explicit OpSequence(const std::vector &ops, int deviceId = 0); - OpSequence(int deviceId = 0); - ~OpSequence() = default; - - OpSequence(const OpSequence& other) noexcept; - - OpSequence& operator=(const OpSequence& other) noexcept; - - // move constructor - OpSequence(OpSequence&& other) noexcept; - - // move assignment operator - OpSequence& operator=(OpSequence&& other) noexcept; - - - int deviceId() const; - - /** - * This method blocks until all operations within sequence are processed - * @return - */ - Nd4jStatus wait() const; - - /** - * This method prints out content of the sequence - */ - void printOut() const; - - /** - * This method returns number of individual operations within this sequence - * @return - */ - uint64_t length() const; - - /** - * This method returns specific Op/ContextPrototype pair for specified index - * @param index - * @return - */ - const ExecutionTask& at(uint64_t index) const; - const ExecutionTask& operator[](uint64_t index) const; - - /** - * This method allows to add DeclarableOp to the end of execution queue - * @param op - Op to be executed - * @param ctx - ContextPrototype for this operation with inputs/outputs/args defined - */ - void append(const std::shared_ptr &op, const sd::graph::ContextPrototype &ctx); - void append(sd::ops::DeclarableOp *op, const sd::graph::ContextPrototype &ctx); - - /** - * Iterator functionality for OpSequence - * @return - */ - - OpSequence::iterator begin(); - OpSequence::iterator end(); - - // additional private section - private: - class iterator : public std::iterator { - private: - uint64_t _position = 0; - OpSequence & _container; - public: - explicit iterator(OpSequence & container, uint64_t index = 0); - const ExecutionTask& operator*() const; - iterator & operator++(); - iterator & operator++(int); - bool operator!=(const iterator &) const; - }; - }; - } -} - - -#endif //SD_OPSEQUENCE_H +namespace graph { +/** + * This class represents independent and immutable sequence of operations + */ +class SD_EXPORT OpSequence + : public std::iterator { + // our internal iterator for OpSequence + class iterator; + + protected: + // main thing here. sorted list of operations and their contexts + std::vector _ops; + + int _deviceId = 0; + + public: + explicit OpSequence(const std::vector& ops, int deviceId = 0); + OpSequence(int deviceId = 0); + ~OpSequence() = default; + + OpSequence(const OpSequence& other) noexcept; + + OpSequence& operator=(const OpSequence& other) noexcept; + + // move constructor + OpSequence(OpSequence&& other) noexcept; + + // move assignment operator + OpSequence& operator=(OpSequence&& other) noexcept; + + int deviceId() const; + + /** + * This method blocks until all operations within sequence are processed + * @return + */ + Nd4jStatus wait() const; + + /** + * This method prints out content of the sequence + */ + void printOut() const; + + /** + * This method returns number of individual operations within this sequence + * @return + */ + uint64_t length() const; + + /** + * This method returns specific Op/ContextPrototype pair for specified index + * @param index + * @return + */ + const ExecutionTask& at(uint64_t index) const; + const ExecutionTask& operator[](uint64_t index) const; + + /** + * This method allows to add DeclarableOp to the end of execution queue + * @param op - Op to be executed + * @param ctx - ContextPrototype for this operation with inputs/outputs/args + * defined + */ + void append(const std::shared_ptr& op, + const sd::graph::ContextPrototype& ctx); + void append(sd::ops::DeclarableOp* op, + const sd::graph::ContextPrototype& ctx); + + /** + * Iterator functionality for OpSequence + * @return + */ + + OpSequence::iterator begin(); + OpSequence::iterator end(); + + // additional private section + private: + class iterator + : public std::iterator { + private: + uint64_t _position = 0; + OpSequence& _container; + + public: + explicit iterator(OpSequence& container, uint64_t index = 0); + const ExecutionTask& operator*() const; + iterator& operator++(); + iterator& operator++(int); + bool operator!=(const iterator&) const; + }; +}; +} // namespace graph +} // namespace sd + +#endif // SD_OPSEQUENCE_H diff --git a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp index 65b91505d8f2..127b1436f99b 100644 --- a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp +++ b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp @@ -18,55 +18,51 @@ // @author raver119@gmail.com // - #include namespace sd { - namespace graph { - ExecutionLayer::ExecutionLayer(const std::vector &sequences) { - _sequences = sequences; - } +namespace graph { +ExecutionLayer::ExecutionLayer(const std::vector &sequences) { + _sequences = sequences; +} - uint64_t ExecutionLayer::width() const { - return _sequences.size(); - } +uint64_t ExecutionLayer::width() const { return _sequences.size(); } - const OpSequence& ExecutionLayer::at(uint64_t index) const { - return _sequences[index]; - } +const OpSequence &ExecutionLayer::at(uint64_t index) const { + return _sequences[index]; +} - const OpSequence& ExecutionLayer::operator[](uint64_t index) const { - return at(index); - } +const OpSequence &ExecutionLayer::operator[](uint64_t index) const { + return at(index); +} - void ExecutionLayer::append(const OpSequence &sequence) { - _sequences.emplace_back(sequence); - } +void ExecutionLayer::append(const OpSequence &sequence) { + _sequences.emplace_back(sequence); +} - ExecutionLayer::ExecutionLayer(const ExecutionLayer &other) noexcept { - _sequences = other._sequences; - } +ExecutionLayer::ExecutionLayer(const ExecutionLayer &other) noexcept { + _sequences = other._sequences; +} - ExecutionLayer &ExecutionLayer::operator=(const ExecutionLayer &other) noexcept { - if (this == &other) - return *this; +ExecutionLayer &ExecutionLayer::operator=( + const ExecutionLayer &other) noexcept { + if (this == &other) return *this; - _sequences = other._sequences; + _sequences = other._sequences; - return *this; - } + return *this; +} - ExecutionLayer::ExecutionLayer(ExecutionLayer &&other) noexcept { - _sequences = std::move(other._sequences); - } +ExecutionLayer::ExecutionLayer(ExecutionLayer &&other) noexcept { + _sequences = std::move(other._sequences); +} - ExecutionLayer &ExecutionLayer::operator=(ExecutionLayer &&other) noexcept { - if (this == &other) - return *this; +ExecutionLayer &ExecutionLayer::operator=(ExecutionLayer &&other) noexcept { + if (this == &other) return *this; - _sequences = std::move(other._sequences); + _sequences = std::move(other._sequences); - return *this; - } - } -} \ No newline at end of file + return *this; +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp index 81b696a6039f..c2067041e0f8 100644 --- a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp +++ b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp @@ -21,76 +21,74 @@ #include namespace sd { - namespace graph { - ExecutionTask::ExecutionTask(const std::shared_ptr &op, const ContextPrototype &ctx) : _op(op), _context(ctx) { - // - } - - std::shared_ptr ExecutionTask::op() const { - return _op; - } - - const ContextPrototype &ExecutionTask::protoContext() const { - return _context; - } - - ExecutionTask::ExecutionTask(const ExecutionTask &other) : _op(other._op), _context(other._context) { - // - } - - ExecutionTask &ExecutionTask::operator=(const ExecutionTask &other) noexcept { - if (this == &other) - return *this; - - _op = other._op; - const_cast(_context) = other._context; - - return *this; - } - - ExecutionTask::ExecutionTask(ExecutionTask &&other) : _op(other._op), _context(other._context) { - // - } - - void ExecutionTask::printOut() const { - if (_context.name().empty()) { - if (_op != nullptr) - printf(" <%i:0>: {Op: %s}; ", _context.nodeId(), _op->getOpName().c_str()); - else - printf(" <%i:0>: ", _context.nodeId()); - } else { - printf(" <%s> <%i>: ", _context.name().c_str(), _context.nodeId()); - } - - auto sz = _context.inputs().size(); - if (sz) { - printf(" Inputs: ["); - int cnt = 0; - for (const auto &v:_context.inputs()) { - printf("<%i:%i>", v.first, v.second); - - if (cnt < sz - 1) - printf(", "); - cnt++; - } - - printf("]; "); - } else { - printf(" No inputs; "); - } - - printf("\n"); - fflush(stdout); - } - - ExecutionTask &ExecutionTask::operator=(ExecutionTask &&other) noexcept { - if (this == &other) - return *this; - - _op = std::move(other._op); - const_cast(_context) = std::move(other._context); - - return *this; - } +namespace graph { +ExecutionTask::ExecutionTask(const std::shared_ptr &op, + const ContextPrototype &ctx) + : _op(op), _context(ctx) { + // +} + +std::shared_ptr ExecutionTask::op() const { return _op; } + +const ContextPrototype &ExecutionTask::protoContext() const { return _context; } + +ExecutionTask::ExecutionTask(const ExecutionTask &other) + : _op(other._op), _context(other._context) { + // +} + +ExecutionTask &ExecutionTask::operator=(const ExecutionTask &other) noexcept { + if (this == &other) return *this; + + _op = other._op; + const_cast(_context) = other._context; + + return *this; +} + +ExecutionTask::ExecutionTask(ExecutionTask &&other) + : _op(other._op), _context(other._context) { + // +} + +void ExecutionTask::printOut() const { + if (_context.name().empty()) { + if (_op != nullptr) + printf(" <%i:0>: {Op: %s}; ", _context.nodeId(), + _op->getOpName().c_str()); + else + printf(" <%i:0>: ", _context.nodeId()); + } else { + printf(" <%s> <%i>: ", _context.name().c_str(), _context.nodeId()); + } + + auto sz = _context.inputs().size(); + if (sz) { + printf(" Inputs: ["); + int cnt = 0; + for (const auto &v : _context.inputs()) { + printf("<%i:%i>", v.first, v.second); + + if (cnt < sz - 1) printf(", "); + cnt++; } + + printf("]; "); + } else { + printf(" No inputs; "); + } + + printf("\n"); + fflush(stdout); +} + +ExecutionTask &ExecutionTask::operator=(ExecutionTask &&other) noexcept { + if (this == &other) return *this; + + _op = std::move(other._op); + const_cast(_context) = std::move(other._context); + + return *this; } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index be4e4b3ff2e8..20059bbb75f1 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -18,76 +18,89 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace graph { - Context GraphExecutor::prepareContext(const ContextPrototype &contextPrototype, VariableProxy &variableProxy, const GraphMemoryManager &memoryManager) const { - // TODO: maybe we'll want to do something here? - return Context(contextPrototype, &variableProxy, const_cast(&memoryManager)); - } - - Nd4jStatus GraphExecutor::preprocess(sd::ops::DeclarableOp *op, Context &context) const { - // time to allocate outputs, if that's not inplace op - // inplace case is covered there - op->prepareOutputs(context); - - // once prepareOutputs method was called - we don't need shape function anymore - context.setShapeFunctionOverride(true); +namespace graph { +Context GraphExecutor::prepareContext( + const ContextPrototype &contextPrototype, VariableProxy &variableProxy, + const GraphMemoryManager &memoryManager) const { + // TODO: maybe we'll want to do something here? + return Context(contextPrototype, &variableProxy, + const_cast(&memoryManager)); +} - return Status::OK(); - } +Nd4jStatus GraphExecutor::preprocess(sd::ops::DeclarableOp *op, + Context &context) const { + // time to allocate outputs, if that's not inplace op + // inplace case is covered there + op->prepareOutputs(context); - Nd4jStatus GraphExecutor::postprocess(sd::ops::DeclarableOp *op, Context *context) const { - return Status::OK(); - } + // once prepareOutputs method was called - we don't need shape function + // anymore + context.setShapeFunctionOverride(true); + return Status::OK(); +} - Nd4jStatus GraphExecutor::execute(const std::shared_ptr &op, const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, VariableProxy &proxy, const int deviceId) const { - auto ctx = prepareContext(contextPrototype, proxy, graph.memoryManager()); - return op->execute(&ctx); - //throw std::runtime_error("GraphExecutor::execute - Not implemented yet"); - } +Nd4jStatus GraphExecutor::postprocess(sd::ops::DeclarableOp *op, + Context *context) const { + return Status::OK(); +} - Nd4jStatus GraphExecutor::execute(const OpSequence &sequence, const OptimizedGraph &graph, VariableProxy &proxy, const int deviceId) const { - /* - * this is a basic implementation that works without dispatching etc - */ - for (int e = 0; e < sequence.length(); e++) { - auto v = sequence[e]; - auto result = execute(v.op(), v.protoContext(), sequence, graph, proxy, deviceId >= 0 ? deviceId : sequence.deviceId()); - if (result != Status::OK()) - return result; - } +Nd4jStatus GraphExecutor::execute( + const std::shared_ptr &op, + const ContextPrototype &contextPrototype, const OpSequence &sequence, + const OptimizedGraph &graph, VariableProxy &proxy, + const int deviceId) const { + auto ctx = prepareContext(contextPrototype, proxy, graph.memoryManager()); + return op->execute(&ctx); + // throw std::runtime_error("GraphExecutor::execute - Not implemented yet"); +} - return Status::OK(); - } +Nd4jStatus GraphExecutor::execute(const OpSequence &sequence, + const OptimizedGraph &graph, + VariableProxy &proxy, + const int deviceId) const { + /* + * this is a basic implementation that works without dispatching etc + */ + for (int e = 0; e < sequence.length(); e++) { + auto v = sequence[e]; + auto result = execute(v.op(), v.protoContext(), sequence, graph, proxy, + deviceId >= 0 ? deviceId : sequence.deviceId()); + if (result != Status::OK()) return result; + } + + return Status::OK(); +} - Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, VariableProxy &proxy) const { - const auto numDevices = AffinityManager::numberOfDevices(); +Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, + VariableProxy &proxy) const { + const auto numDevices = AffinityManager::numberOfDevices(); - /* - * this is a basic exection logic: roll through layers and sequences and execute them one by one sequentially - */ - Nd4jStatus result = Status::OK(); - for (uint64_t l = 0; l < graph.layers(); l++) { - const auto &layer = graph.layer(l); + /* + * this is a basic exection logic: roll through layers and sequences and + * execute them one by one sequentially + */ + Nd4jStatus result = Status::OK(); + for (uint64_t l = 0; l < graph.layers(); l++) { + const auto &layer = graph.layer(l); - for (uint64_t o = 0; o < layer.width(); o++) { - execute(layer[o], graph, proxy, -1); - } + for (uint64_t o = 0; o < layer.width(); o++) { + execute(layer[o], graph, proxy, -1); + } - // optionally block until all sequences in this layer processed - if (layer.width() > 0 && numDevices > 1) - for (uint64_t o = 0; o < layer.width(); o++) { - result = layer[o].wait(); - if (result != Status::OK()) - return result; - } - } + // optionally block until all sequences in this layer processed + if (layer.width() > 0 && numDevices > 1) + for (uint64_t o = 0; o < layer.width(); o++) { + result = layer[o].wait(); + if (result != Status::OK()) return result; + } + } - return result; - } - } + return result; } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 0d23ca97cd66..0dd4b908e4a9 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -22,115 +22,108 @@ #include namespace sd { - namespace graph { - OpSequence::OpSequence(const int deviceId) : _deviceId(deviceId) { - // - } - - OpSequence::OpSequence(const std::vector &ops, const int deviceId) { - _deviceId = deviceId; - for (const auto &v : ops) - _ops.emplace_back(v); - } - - OpSequence::OpSequence(const OpSequence& other) noexcept{ - _ops.clear(); - - for (const auto &v : other._ops) - _ops.emplace_back(v); - } - - //////////////////////////////////////////////////////////////////////// - // move constructor - OpSequence::OpSequence(OpSequence&& other) noexcept { - _ops = std::move(other._ops); - } - - OpSequence& OpSequence::operator=(OpSequence&& other) noexcept { - if (this == &other) - return *this; - - _ops = std::move(other._ops); - - return *this; - } - - OpSequence& OpSequence::operator=(const OpSequence& other) noexcept { - if (this == &other) - return *this; - - _ops.clear(); - for (const auto &v : other._ops) - _ops.emplace_back(v); - - return *this; - } - - void OpSequence::printOut() const { - for (const auto &v: _ops) - v.printOut(); - } - - int OpSequence::deviceId() const { - return _deviceId; - } - - const ExecutionTask& OpSequence::at(uint64_t index) const { - return _ops[index]; - } - - const ExecutionTask& OpSequence::operator[](uint64_t index) const { - return at(index); - } - - uint64_t OpSequence::length() const { - return _ops.size(); - } - - void OpSequence::append(const std::shared_ptr &op, const sd::graph::ContextPrototype &ctx) { - ExecutionTask task(op, ctx); - _ops.emplace_back(task); - } - - void OpSequence::append(sd::ops::DeclarableOp *op, const ContextPrototype &ctx) { - auto rop = sd::ops::OpRegistrator::getInstance()->getOperation(op->getOpHash()); - append(rop, ctx); - } - - OpSequence::iterator - OpSequence::begin() { - return OpSequence::iterator(*this, 0); - } - - OpSequence::iterator - OpSequence::end() { - return OpSequence::iterator(*this, length()); - } - - OpSequence::iterator::iterator(OpSequence &container, uint64_t index) :_container(container), _position(index) { - // - } - - const ExecutionTask& OpSequence::iterator::operator*() const { - return _container._ops[_position]; - } - - OpSequence::iterator &OpSequence::iterator::operator++() { - _position++; - return *this; - } - - OpSequence::iterator &OpSequence::iterator::operator++(int inc) { - return ++(*this); - } - - bool OpSequence::iterator::operator!=(const OpSequence::iterator &other) const { - return _position != other._position; - } - - Nd4jStatus OpSequence::wait() const { - // TODO: to be implemented - return Status::OK(); - } - } +namespace graph { +OpSequence::OpSequence(const int deviceId) : _deviceId(deviceId) { + // } + +OpSequence::OpSequence(const std::vector &ops, + const int deviceId) { + _deviceId = deviceId; + for (const auto &v : ops) _ops.emplace_back(v); +} + +OpSequence::OpSequence(const OpSequence &other) noexcept { + _ops.clear(); + + for (const auto &v : other._ops) _ops.emplace_back(v); +} + +//////////////////////////////////////////////////////////////////////// +// move constructor +OpSequence::OpSequence(OpSequence &&other) noexcept { + _ops = std::move(other._ops); +} + +OpSequence &OpSequence::operator=(OpSequence &&other) noexcept { + if (this == &other) return *this; + + _ops = std::move(other._ops); + + return *this; +} + +OpSequence &OpSequence::operator=(const OpSequence &other) noexcept { + if (this == &other) return *this; + + _ops.clear(); + for (const auto &v : other._ops) _ops.emplace_back(v); + + return *this; +} + +void OpSequence::printOut() const { + for (const auto &v : _ops) v.printOut(); +} + +int OpSequence::deviceId() const { return _deviceId; } + +const ExecutionTask &OpSequence::at(uint64_t index) const { + return _ops[index]; +} + +const ExecutionTask &OpSequence::operator[](uint64_t index) const { + return at(index); +} + +uint64_t OpSequence::length() const { return _ops.size(); } + +void OpSequence::append(const std::shared_ptr &op, + const sd::graph::ContextPrototype &ctx) { + ExecutionTask task(op, ctx); + _ops.emplace_back(task); +} + +void OpSequence::append(sd::ops::DeclarableOp *op, + const ContextPrototype &ctx) { + auto rop = + sd::ops::OpRegistrator::getInstance()->getOperation(op->getOpHash()); + append(rop, ctx); +} + +OpSequence::iterator OpSequence::begin() { + return OpSequence::iterator(*this, 0); +} + +OpSequence::iterator OpSequence::end() { + return OpSequence::iterator(*this, length()); +} + +OpSequence::iterator::iterator(OpSequence &container, uint64_t index) + : _container(container), _position(index) { + // +} + +const ExecutionTask &OpSequence::iterator::operator*() const { + return _container._ops[_position]; +} + +OpSequence::iterator &OpSequence::iterator::operator++() { + _position++; + return *this; +} + +OpSequence::iterator &OpSequence::iterator::operator++(int inc) { + return ++(*this); +} + +bool OpSequence::iterator::operator!=(const OpSequence::iterator &other) const { + return _position != other._position; +} + +Nd4jStatus OpSequence::wait() const { + // TODO: to be implemented + return Status::OK(); +} +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/generated/array_generated.h b/libnd4j/include/graph/generated/array_generated.h index 5c4c0d7af947..1d1feaefab9f 100644 --- a/libnd4j/include/graph/generated/array_generated.h +++ b/libnd4j/include/graph/generated/array_generated.h @@ -1,6 +1,5 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_ARRAY_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_ARRAY_ND4J_GRAPH_H_ @@ -19,19 +18,12 @@ enum ByteOrder { }; inline const ByteOrder (&EnumValuesByteOrder())[2] { - static const ByteOrder values[] = { - ByteOrder_LE, - ByteOrder_BE - }; + static const ByteOrder values[] = {ByteOrder_LE, ByteOrder_BE}; return values; } -inline const char * const *EnumNamesByteOrder() { - static const char * const names[] = { - "LE", - "BE", - nullptr - }; +inline const char *const *EnumNamesByteOrder() { + static const char *const names[] = {"LE", "BE", nullptr}; return names; } @@ -68,88 +60,24 @@ enum DType { inline const DType (&EnumValuesDType())[21] { static const DType values[] = { - DType_INHERIT, - DType_BOOL, - DType_FLOAT8, - DType_HALF, - DType_HALF2, - DType_FLOAT, - DType_DOUBLE, - DType_INT8, - DType_INT16, - DType_INT32, - DType_INT64, - DType_UINT8, - DType_UINT16, - DType_UINT32, - DType_UINT64, - DType_QINT8, - DType_QINT16, - DType_BFLOAT16, - DType_UTF8, - DType_UTF16, - DType_UTF32 - }; + DType_INHERIT, DType_BOOL, DType_FLOAT8, DType_HALF, DType_HALF2, + DType_FLOAT, DType_DOUBLE, DType_INT8, DType_INT16, DType_INT32, + DType_INT64, DType_UINT8, DType_UINT16, DType_UINT32, DType_UINT64, + DType_QINT8, DType_QINT16, DType_BFLOAT16, DType_UTF8, DType_UTF16, + DType_UTF32}; return values; } -inline const char * const *EnumNamesDType() { - static const char * const names[] = { - "INHERIT", - "BOOL", - "FLOAT8", - "HALF", - "HALF2", - "FLOAT", - "DOUBLE", - "INT8", - "INT16", - "INT32", - "INT64", - "UINT8", - "UINT16", - "UINT32", - "UINT64", - "QINT8", - "QINT16", - "BFLOAT16", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "UTF8", - "UTF16", - "UTF32", - nullptr - }; +inline const char *const *EnumNamesDType() { + static const char *const names[] = { + "INHERIT", "BOOL", "FLOAT8", "HALF", "HALF2", "FLOAT", "DOUBLE", + "INT8", "INT16", "INT32", "INT64", "UINT8", "UINT16", "UINT32", + "UINT64", "QINT8", "QINT16", "BFLOAT16", "", "", "", + "", "", "", "", "", "", "", + "", "", "", "", "", "", "", + "", "", "", "", "", "", "", + "", "", "", "", "", "", "", + "", "UTF8", "UTF16", "UTF32", nullptr}; return names; } @@ -159,12 +87,7 @@ inline const char *EnumNameDType(DType e) { } struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_SHAPE = 4, - VT_BUFFER = 6, - VT_DTYPE = 8, - VT_BYTEORDER = 10 - }; + enum { VT_SHAPE = 4, VT_BUFFER = 6, VT_DTYPE = 8, VT_BYTEORDER = 10 }; const flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); } @@ -178,14 +101,12 @@ struct FlatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return static_cast(GetField(VT_BYTEORDER, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_SHAPE) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && verifier.VerifyVector(shape()) && VerifyOffset(verifier, VT_BUFFER) && verifier.VerifyVector(buffer()) && VerifyField(verifier, VT_DTYPE) && - VerifyField(verifier, VT_BYTEORDER) && - verifier.EndTable(); + VerifyField(verifier, VT_BYTEORDER) && verifier.EndTable(); } }; @@ -202,10 +123,10 @@ struct FlatArrayBuilder { fbb_.AddElement(FlatArray::VT_DTYPE, static_cast(dtype), 0); } void add_byteOrder(ByteOrder byteOrder) { - fbb_.AddElement(FlatArray::VT_BYTEORDER, static_cast(byteOrder), 0); + fbb_.AddElement(FlatArray::VT_BYTEORDER, + static_cast(byteOrder), 0); } - explicit FlatArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit FlatArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatArrayBuilder &operator=(const FlatArrayBuilder &); @@ -220,8 +141,7 @@ inline flatbuffers::Offset CreateFlatArray( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset> shape = 0, flatbuffers::Offset> buffer = 0, - DType dtype = DType_INHERIT, - ByteOrder byteOrder = ByteOrder_LE) { + DType dtype = DType_INHERIT, ByteOrder byteOrder = ByteOrder_LE) { FlatArrayBuilder builder_(_fbb); builder_.add_buffer(buffer); builder_.add_shape(shape); @@ -233,15 +153,11 @@ inline flatbuffers::Offset CreateFlatArray( inline flatbuffers::Offset CreateFlatArrayDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *shape = nullptr, - const std::vector *buffer = nullptr, - DType dtype = DType_INHERIT, + const std::vector *buffer = nullptr, DType dtype = DType_INHERIT, ByteOrder byteOrder = ByteOrder_LE) { return sd::graph::CreateFlatArray( - _fbb, - shape ? _fbb.CreateVector(*shape) : 0, - buffer ? _fbb.CreateVector(*buffer) : 0, - dtype, - byteOrder); + _fbb, shape ? _fbb.CreateVector(*shape) : 0, + buffer ? _fbb.CreateVector(*buffer) : 0, dtype, byteOrder); } inline const sd::graph::FlatArray *GetFlatArray(const void *buf) { @@ -252,13 +168,11 @@ inline const sd::graph::FlatArray *GetSizePrefixedFlatArray(const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatArrayBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatArrayBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } -inline bool VerifySizePrefixedFlatArrayBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifySizePrefixedFlatArrayBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifySizePrefixedBuffer(nullptr); } diff --git a/libnd4j/include/graph/generated/config_generated.h b/libnd4j/include/graph/generated/config_generated.h index 2c12027a2b0c..505fd5767056 100644 --- a/libnd4j/include/graph/generated/config_generated.h +++ b/libnd4j/include/graph/generated/config_generated.h @@ -1,6 +1,5 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_CONFIG_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_CONFIG_ND4J_GRAPH_H_ @@ -22,22 +21,14 @@ enum ProfilingMode { inline const ProfilingMode (&EnumValuesProfilingMode())[4] { static const ProfilingMode values[] = { - ProfilingMode_NONE, - ProfilingMode_NAN_PANIC, - ProfilingMode_INF_PANIC, - ProfilingMode_ANY_PANIC - }; + ProfilingMode_NONE, ProfilingMode_NAN_PANIC, ProfilingMode_INF_PANIC, + ProfilingMode_ANY_PANIC}; return values; } -inline const char * const *EnumNamesProfilingMode() { - static const char * const names[] = { - "NONE", - "NAN_PANIC", - "INF_PANIC", - "ANY_PANIC", - nullptr - }; +inline const char *const *EnumNamesProfilingMode() { + static const char *const names[] = {"NONE", "NAN_PANIC", "INF_PANIC", + "ANY_PANIC", nullptr}; return names; } @@ -56,20 +47,12 @@ enum ExecutionMode { inline const ExecutionMode (&EnumValuesExecutionMode())[3] { static const ExecutionMode values[] = { - ExecutionMode_SEQUENTIAL, - ExecutionMode_STRICT, - ExecutionMode_AUTO - }; + ExecutionMode_SEQUENTIAL, ExecutionMode_STRICT, ExecutionMode_AUTO}; return values; } -inline const char * const *EnumNamesExecutionMode() { - static const char * const names[] = { - "SEQUENTIAL", - "STRICT", - "AUTO", - nullptr - }; +inline const char *const *EnumNamesExecutionMode() { + static const char *const names[] = {"SEQUENTIAL", "STRICT", "AUTO", nullptr}; return names; } @@ -89,25 +72,17 @@ enum OutputMode { }; inline const OutputMode (&EnumValuesOutputMode())[5] { - static const OutputMode values[] = { - OutputMode_IMPLICIT, - OutputMode_EXPLICIT, - OutputMode_EXPLICIT_AND_IMPLICIT, - OutputMode_VARIABLE_SPACE, - OutputMode_OPTIMIZED - }; + static const OutputMode values[] = {OutputMode_IMPLICIT, OutputMode_EXPLICIT, + OutputMode_EXPLICIT_AND_IMPLICIT, + OutputMode_VARIABLE_SPACE, + OutputMode_OPTIMIZED}; return values; } -inline const char * const *EnumNamesOutputMode() { - static const char * const names[] = { - "IMPLICIT", - "EXPLICIT", - "EXPLICIT_AND_IMPLICIT", - "VARIABLE_SPACE", - "OPTIMIZED", - nullptr - }; +inline const char *const *EnumNamesOutputMode() { + static const char *const names[] = { + "IMPLICIT", "EXPLICIT", "EXPLICIT_AND_IMPLICIT", + "VARIABLE_SPACE", "OPTIMIZED", nullptr}; return names; } @@ -125,21 +100,15 @@ enum Direction { }; inline const Direction (&EnumValuesDirection())[3] { - static const Direction values[] = { - Direction_FORWARD_ONLY, - Direction_FORWARD_AND_BACKWARD, - Direction_BACKWARD_ONLY - }; + static const Direction values[] = {Direction_FORWARD_ONLY, + Direction_FORWARD_AND_BACKWARD, + Direction_BACKWARD_ONLY}; return values; } -inline const char * const *EnumNamesDirection() { - static const char * const names[] = { - "FORWARD_ONLY", - "FORWARD_AND_BACKWARD", - "BACKWARD_ONLY", - nullptr - }; +inline const char *const *EnumNamesDirection() { + static const char *const names[] = {"FORWARD_ONLY", "FORWARD_AND_BACKWARD", + "BACKWARD_ONLY", nullptr}; return names; } @@ -159,9 +128,7 @@ struct FlatConfiguration FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_FOOTPRINTBACKWARD = 16, VT_DIRECTION = 18 }; - int64_t id() const { - return GetField(VT_ID, 0); - } + int64_t id() const { return GetField(VT_ID, 0); } ExecutionMode executionMode() const { return static_cast(GetField(VT_EXECUTIONMODE, 0)); } @@ -171,9 +138,7 @@ struct FlatConfiguration FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { OutputMode outputMode() const { return static_cast(GetField(VT_OUTPUTMODE, 0)); } - bool timestats() const { - return GetField(VT_TIMESTATS, 0) != 0; - } + bool timestats() const { return GetField(VT_TIMESTATS, 0) != 0; } int64_t footprintForward() const { return GetField(VT_FOOTPRINTFORWARD, 0); } @@ -192,8 +157,7 @@ struct FlatConfiguration FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_TIMESTATS) && VerifyField(verifier, VT_FOOTPRINTFORWARD) && VerifyField(verifier, VT_FOOTPRINTBACKWARD) && - VerifyField(verifier, VT_DIRECTION) && - verifier.EndTable(); + VerifyField(verifier, VT_DIRECTION) && verifier.EndTable(); } }; @@ -204,28 +168,35 @@ struct FlatConfigurationBuilder { fbb_.AddElement(FlatConfiguration::VT_ID, id, 0); } void add_executionMode(ExecutionMode executionMode) { - fbb_.AddElement(FlatConfiguration::VT_EXECUTIONMODE, static_cast(executionMode), 0); + fbb_.AddElement(FlatConfiguration::VT_EXECUTIONMODE, + static_cast(executionMode), 0); } void add_profilingMode(ProfilingMode profilingMode) { - fbb_.AddElement(FlatConfiguration::VT_PROFILINGMODE, static_cast(profilingMode), 0); + fbb_.AddElement(FlatConfiguration::VT_PROFILINGMODE, + static_cast(profilingMode), 0); } void add_outputMode(OutputMode outputMode) { - fbb_.AddElement(FlatConfiguration::VT_OUTPUTMODE, static_cast(outputMode), 0); + fbb_.AddElement(FlatConfiguration::VT_OUTPUTMODE, + static_cast(outputMode), 0); } void add_timestats(bool timestats) { - fbb_.AddElement(FlatConfiguration::VT_TIMESTATS, static_cast(timestats), 0); + fbb_.AddElement(FlatConfiguration::VT_TIMESTATS, + static_cast(timestats), 0); } void add_footprintForward(int64_t footprintForward) { - fbb_.AddElement(FlatConfiguration::VT_FOOTPRINTFORWARD, footprintForward, 0); + fbb_.AddElement(FlatConfiguration::VT_FOOTPRINTFORWARD, + footprintForward, 0); } void add_footprintBackward(int64_t footprintBackward) { - fbb_.AddElement(FlatConfiguration::VT_FOOTPRINTBACKWARD, footprintBackward, 0); + fbb_.AddElement(FlatConfiguration::VT_FOOTPRINTBACKWARD, + footprintBackward, 0); } void add_direction(Direction direction) { - fbb_.AddElement(FlatConfiguration::VT_DIRECTION, static_cast(direction), 0); + fbb_.AddElement(FlatConfiguration::VT_DIRECTION, + static_cast(direction), 0); } explicit FlatConfigurationBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatConfigurationBuilder &operator=(const FlatConfigurationBuilder &); @@ -237,14 +208,11 @@ struct FlatConfigurationBuilder { }; inline flatbuffers::Offset CreateFlatConfiguration( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0, + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0, ExecutionMode executionMode = ExecutionMode_SEQUENTIAL, ProfilingMode profilingMode = ProfilingMode_NONE, - OutputMode outputMode = OutputMode_IMPLICIT, - bool timestats = false, - int64_t footprintForward = 0, - int64_t footprintBackward = 0, + OutputMode outputMode = OutputMode_IMPLICIT, bool timestats = false, + int64_t footprintForward = 0, int64_t footprintBackward = 0, Direction direction = Direction_FORWARD_ONLY) { FlatConfigurationBuilder builder_(_fbb); builder_.add_footprintBackward(footprintBackward); @@ -258,22 +226,24 @@ inline flatbuffers::Offset CreateFlatConfiguration( return builder_.Finish(); } -inline const sd::graph::FlatConfiguration *GetFlatConfiguration(const void *buf) { +inline const sd::graph::FlatConfiguration *GetFlatConfiguration( + const void *buf) { return flatbuffers::GetRoot(buf); } -inline const sd::graph::FlatConfiguration *GetSizePrefixedFlatConfiguration(const void *buf) { +inline const sd::graph::FlatConfiguration *GetSizePrefixedFlatConfiguration( + const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatConfigurationBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatConfigurationBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } inline bool VerifySizePrefixedFlatConfigurationBuffer( flatbuffers::Verifier &verifier) { - return verifier.VerifySizePrefixedBuffer(nullptr); + return verifier.VerifySizePrefixedBuffer( + nullptr); } inline void FinishFlatConfigurationBuffer( diff --git a/libnd4j/include/graph/generated/graph.grpc.fb.h b/libnd4j/include/graph/generated/graph.grpc.fb.h index 0167f48d568c..7f9e27d42163 100644 --- a/libnd4j/include/graph/generated/graph.grpc.fb.h +++ b/libnd4j/include/graph/generated/graph.grpc.fb.h @@ -4,9 +4,6 @@ #ifndef GRPC_graph__INCLUDED #define GRPC_graph__INCLUDED -#include "graph_generated.h" -#include "flatbuffers/grpc.h" - #include #include #include @@ -17,6 +14,9 @@ #include #include +#include "flatbuffers/grpc.h" +#include "graph_generated.h" + namespace grpc { class CompletionQueue; class Channel; @@ -35,196 +35,449 @@ class GraphInferenceServer final { class StubInterface { public: virtual ~StubInterface() {} - virtual ::grpc::Status RegisterGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) = 0; - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> AsyncRegisterGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(AsyncRegisterGraphRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> PrepareAsyncRegisterGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(PrepareAsyncRegisterGraphRaw(context, request, cq)); - } - virtual ::grpc::Status ForgetGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) = 0; - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> AsyncForgetGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(AsyncForgetGraphRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> PrepareAsyncForgetGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(PrepareAsyncForgetGraphRaw(context, request, cq)); - } - virtual ::grpc::Status ReplaceGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) = 0; - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> AsyncReplaceGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(AsyncReplaceGraphRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> PrepareAsyncReplaceGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(PrepareAsyncReplaceGraphRaw(context, request, cq)); - } - virtual ::grpc::Status InferenceRequest(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) = 0; - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> AsyncInferenceRequest(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(AsyncInferenceRequestRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>> PrepareAsyncInferenceRequest(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>>(PrepareAsyncInferenceRequestRaw(context, request, cq)); - } - private: - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* AsyncRegisterGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* PrepareAsyncRegisterGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* AsyncForgetGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* PrepareAsyncForgetGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* AsyncReplaceGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* PrepareAsyncReplaceGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* AsyncInferenceRequestRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< flatbuffers::grpc::Message>* PrepareAsyncInferenceRequestRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::Status RegisterGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) = 0; + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + AsyncRegisterGraph(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + AsyncRegisterGraphRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + PrepareAsyncRegisterGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + PrepareAsyncRegisterGraphRaw(context, request, cq)); + } + virtual ::grpc::Status ForgetGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) = 0; + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + AsyncForgetGraph(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + AsyncForgetGraphRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + PrepareAsyncForgetGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + PrepareAsyncForgetGraphRaw(context, request, cq)); + } + virtual ::grpc::Status ReplaceGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) = 0; + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + AsyncReplaceGraph(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + AsyncReplaceGraphRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + PrepareAsyncReplaceGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + PrepareAsyncReplaceGraphRaw(context, request, cq)); + } + virtual ::grpc::Status InferenceRequest( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) = 0; + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + AsyncInferenceRequest( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + AsyncInferenceRequestRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>> + PrepareAsyncInferenceRequest( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>>( + PrepareAsyncInferenceRequestRaw(context, request, cq)); + } + + private: + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + AsyncRegisterGraphRaw(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + PrepareAsyncRegisterGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + AsyncForgetGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + PrepareAsyncForgetGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + AsyncReplaceGraphRaw(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + PrepareAsyncReplaceGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + AsyncInferenceRequestRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + flatbuffers::grpc::Message>* + PrepareAsyncInferenceRequestRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) = 0; }; class Stub final : public StubInterface { public: - Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel); - ::grpc::Status RegisterGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) override; - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> AsyncRegisterGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(AsyncRegisterGraphRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> PrepareAsyncRegisterGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(PrepareAsyncRegisterGraphRaw(context, request, cq)); + Stub(const std::shared_ptr<::grpc::ChannelInterface>& channel); + ::grpc::Status RegisterGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) override; + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + AsyncRegisterGraph(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + AsyncRegisterGraphRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + PrepareAsyncRegisterGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + PrepareAsyncRegisterGraphRaw(context, request, cq)); + } + ::grpc::Status ForgetGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) override; + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + AsyncForgetGraph(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + AsyncForgetGraphRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + PrepareAsyncForgetGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + PrepareAsyncForgetGraphRaw(context, request, cq)); + } + ::grpc::Status ReplaceGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) override; + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + AsyncReplaceGraph(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + AsyncReplaceGraphRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + PrepareAsyncReplaceGraph( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + PrepareAsyncReplaceGraphRaw(context, request, cq)); + } + ::grpc::Status InferenceRequest( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + flatbuffers::grpc::Message* response) override; + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + AsyncInferenceRequest( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + AsyncInferenceRequestRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>> + PrepareAsyncInferenceRequest( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader< + flatbuffers::grpc::Message>>( + PrepareAsyncInferenceRequestRaw(context, request, cq)); } - ::grpc::Status ForgetGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) override; - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> AsyncForgetGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(AsyncForgetGraphRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> PrepareAsyncForgetGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(PrepareAsyncForgetGraphRaw(context, request, cq)); - } - ::grpc::Status ReplaceGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) override; - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> AsyncReplaceGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(AsyncReplaceGraphRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> PrepareAsyncReplaceGraph(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(PrepareAsyncReplaceGraphRaw(context, request, cq)); - } - ::grpc::Status InferenceRequest(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, flatbuffers::grpc::Message* response) override; - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> AsyncInferenceRequest(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(AsyncInferenceRequestRaw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>> PrepareAsyncInferenceRequest(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>>(PrepareAsyncInferenceRequestRaw(context, request, cq)); - } - + private: - std::shared_ptr< ::grpc::ChannelInterface> channel_; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* AsyncRegisterGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* PrepareAsyncRegisterGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* AsyncForgetGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* PrepareAsyncForgetGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* AsyncReplaceGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* PrepareAsyncReplaceGraphRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* AsyncInferenceRequestRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< flatbuffers::grpc::Message>* PrepareAsyncInferenceRequestRaw(::grpc::ClientContext* context, const flatbuffers::grpc::Message& request, ::grpc::CompletionQueue* cq) override; + std::shared_ptr<::grpc::ChannelInterface> channel_; + ::grpc::ClientAsyncResponseReader>* + AsyncRegisterGraphRaw(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader>* + PrepareAsyncRegisterGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader>* + AsyncForgetGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader>* + PrepareAsyncForgetGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader>* + AsyncReplaceGraphRaw(::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader>* + PrepareAsyncReplaceGraphRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader>* + AsyncInferenceRequestRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader>* + PrepareAsyncInferenceRequestRaw( + ::grpc::ClientContext* context, + const flatbuffers::grpc::Message& request, + ::grpc::CompletionQueue* cq) override; const ::grpc::internal::RpcMethod rpcmethod_RegisterGraph_; const ::grpc::internal::RpcMethod rpcmethod_ForgetGraph_; const ::grpc::internal::RpcMethod rpcmethod_ReplaceGraph_; const ::grpc::internal::RpcMethod rpcmethod_InferenceRequest_; }; - static std::unique_ptr NewStub(const std::shared_ptr< ::grpc::ChannelInterface>& channel, const ::grpc::StubOptions& options = ::grpc::StubOptions()); - + static std::unique_ptr NewStub( + const std::shared_ptr<::grpc::ChannelInterface>& channel, + const ::grpc::StubOptions& options = ::grpc::StubOptions()); + class Service : public ::grpc::Service { public: Service(); virtual ~Service(); - virtual ::grpc::Status RegisterGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response); - virtual ::grpc::Status ForgetGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response); - virtual ::grpc::Status ReplaceGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response); - virtual ::grpc::Status InferenceRequest(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response); + virtual ::grpc::Status RegisterGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response); + virtual ::grpc::Status ForgetGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response); + virtual ::grpc::Status ReplaceGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response); + virtual ::grpc::Status InferenceRequest( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response); }; template class WithAsyncMethod_RegisterGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithAsyncMethod_RegisterGraph() { - ::grpc::Service::MarkMethodAsync(0); - } + WithAsyncMethod_RegisterGraph() { ::grpc::Service::MarkMethodAsync(0); } ~WithAsyncMethod_RegisterGraph() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status RegisterGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status RegisterGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestRegisterGraph(::grpc::ServerContext* context, flatbuffers::grpc::Message* request, ::grpc::ServerAsyncResponseWriter< flatbuffers::grpc::Message>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncUnary(0, context, request, response, new_call_cq, notification_cq, tag); + void RequestRegisterGraph( + ::grpc::ServerContext* context, + flatbuffers::grpc::Message* request, + ::grpc::ServerAsyncResponseWriter< + flatbuffers::grpc::Message>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(0, context, request, response, + new_call_cq, notification_cq, tag); } }; template class WithAsyncMethod_ForgetGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithAsyncMethod_ForgetGraph() { - ::grpc::Service::MarkMethodAsync(1); - } + WithAsyncMethod_ForgetGraph() { ::grpc::Service::MarkMethodAsync(1); } ~WithAsyncMethod_ForgetGraph() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status ForgetGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status ForgetGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestForgetGraph(::grpc::ServerContext* context, flatbuffers::grpc::Message* request, ::grpc::ServerAsyncResponseWriter< flatbuffers::grpc::Message>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncUnary(1, context, request, response, new_call_cq, notification_cq, tag); + void RequestForgetGraph( + ::grpc::ServerContext* context, + flatbuffers::grpc::Message* request, + ::grpc::ServerAsyncResponseWriter< + flatbuffers::grpc::Message>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(1, context, request, response, + new_call_cq, notification_cq, tag); } }; template class WithAsyncMethod_ReplaceGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithAsyncMethod_ReplaceGraph() { - ::grpc::Service::MarkMethodAsync(2); - } + WithAsyncMethod_ReplaceGraph() { ::grpc::Service::MarkMethodAsync(2); } ~WithAsyncMethod_ReplaceGraph() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status ReplaceGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status ReplaceGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestReplaceGraph(::grpc::ServerContext* context, flatbuffers::grpc::Message* request, ::grpc::ServerAsyncResponseWriter< flatbuffers::grpc::Message>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncUnary(2, context, request, response, new_call_cq, notification_cq, tag); + void RequestReplaceGraph( + ::grpc::ServerContext* context, + flatbuffers::grpc::Message* request, + ::grpc::ServerAsyncResponseWriter< + flatbuffers::grpc::Message>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(2, context, request, response, + new_call_cq, notification_cq, tag); } }; template class WithAsyncMethod_InferenceRequest : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithAsyncMethod_InferenceRequest() { - ::grpc::Service::MarkMethodAsync(3); - } + WithAsyncMethod_InferenceRequest() { ::grpc::Service::MarkMethodAsync(3); } ~WithAsyncMethod_InferenceRequest() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status InferenceRequest(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status InferenceRequest( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestInferenceRequest(::grpc::ServerContext* context, flatbuffers::grpc::Message* request, ::grpc::ServerAsyncResponseWriter< flatbuffers::grpc::Message>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncUnary(3, context, request, response, new_call_cq, notification_cq, tag); + void RequestInferenceRequest( + ::grpc::ServerContext* context, + flatbuffers::grpc::Message* request, + ::grpc::ServerAsyncResponseWriter< + flatbuffers::grpc::Message>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(3, context, request, response, + new_call_cq, notification_cq, tag); } }; - typedef WithAsyncMethod_RegisterGraph< WithAsyncMethod_ForgetGraph< WithAsyncMethod_ReplaceGraph< WithAsyncMethod_InferenceRequest< Service > > > > AsyncService; + typedef WithAsyncMethod_RegisterGraph>>> + AsyncService; template class WithGenericMethod_RegisterGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithGenericMethod_RegisterGraph() { - ::grpc::Service::MarkMethodGeneric(0); - } + WithGenericMethod_RegisterGraph() { ::grpc::Service::MarkMethodGeneric(0); } ~WithGenericMethod_RegisterGraph() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status RegisterGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status RegisterGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -232,16 +485,18 @@ class GraphInferenceServer final { template class WithGenericMethod_ForgetGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithGenericMethod_ForgetGraph() { - ::grpc::Service::MarkMethodGeneric(1); - } + WithGenericMethod_ForgetGraph() { ::grpc::Service::MarkMethodGeneric(1); } ~WithGenericMethod_ForgetGraph() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status ForgetGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status ForgetGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -249,16 +504,18 @@ class GraphInferenceServer final { template class WithGenericMethod_ReplaceGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithGenericMethod_ReplaceGraph() { - ::grpc::Service::MarkMethodGeneric(2); - } + WithGenericMethod_ReplaceGraph() { ::grpc::Service::MarkMethodGeneric(2); } ~WithGenericMethod_ReplaceGraph() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status ReplaceGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status ReplaceGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -266,7 +523,8 @@ class GraphInferenceServer final { template class WithGenericMethod_InferenceRequest : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: WithGenericMethod_InferenceRequest() { ::grpc::Service::MarkMethodGeneric(3); @@ -275,7 +533,10 @@ class GraphInferenceServer final { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status InferenceRequest(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status InferenceRequest( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -283,90 +544,147 @@ class GraphInferenceServer final { template class WithStreamedUnaryMethod_RegisterGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: WithStreamedUnaryMethod_RegisterGraph() { - ::grpc::Service::MarkMethodStreamed(0, - new ::grpc::internal::StreamedUnaryHandler< flatbuffers::grpc::Message, flatbuffers::grpc::Message>(std::bind(&WithStreamedUnaryMethod_RegisterGraph::StreamedRegisterGraph, this, std::placeholders::_1, std::placeholders::_2))); + ::grpc::Service::MarkMethodStreamed( + 0, new ::grpc::internal::StreamedUnaryHandler< + flatbuffers::grpc::Message, + flatbuffers::grpc::Message>(std::bind( + &WithStreamedUnaryMethod_RegisterGraph< + BaseClass>::StreamedRegisterGraph, + this, std::placeholders::_1, std::placeholders::_2))); } ~WithStreamedUnaryMethod_RegisterGraph() override { BaseClassMustBeDerivedFromService(this); } // disable regular version of this method - ::grpc::Status RegisterGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status RegisterGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } // replace default version of method with streamed unary - virtual ::grpc::Status StreamedRegisterGraph(::grpc::ServerContext* context, ::grpc::ServerUnaryStreamer< flatbuffers::grpc::Message,flatbuffers::grpc::Message>* server_unary_streamer) = 0; + virtual ::grpc::Status StreamedRegisterGraph( + ::grpc::ServerContext* context, + ::grpc::ServerUnaryStreamer, + flatbuffers::grpc::Message>* + server_unary_streamer) = 0; }; template class WithStreamedUnaryMethod_ForgetGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: WithStreamedUnaryMethod_ForgetGraph() { - ::grpc::Service::MarkMethodStreamed(1, - new ::grpc::internal::StreamedUnaryHandler< flatbuffers::grpc::Message, flatbuffers::grpc::Message>(std::bind(&WithStreamedUnaryMethod_ForgetGraph::StreamedForgetGraph, this, std::placeholders::_1, std::placeholders::_2))); + ::grpc::Service::MarkMethodStreamed( + 1, new ::grpc::internal::StreamedUnaryHandler< + flatbuffers::grpc::Message, + flatbuffers::grpc::Message>(std::bind( + &WithStreamedUnaryMethod_ForgetGraph< + BaseClass>::StreamedForgetGraph, + this, std::placeholders::_1, std::placeholders::_2))); } ~WithStreamedUnaryMethod_ForgetGraph() override { BaseClassMustBeDerivedFromService(this); } // disable regular version of this method - ::grpc::Status ForgetGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status ForgetGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } // replace default version of method with streamed unary - virtual ::grpc::Status StreamedForgetGraph(::grpc::ServerContext* context, ::grpc::ServerUnaryStreamer< flatbuffers::grpc::Message,flatbuffers::grpc::Message>* server_unary_streamer) = 0; + virtual ::grpc::Status StreamedForgetGraph( + ::grpc::ServerContext* context, + ::grpc::ServerUnaryStreamer, + flatbuffers::grpc::Message>* + server_unary_streamer) = 0; }; template class WithStreamedUnaryMethod_ReplaceGraph : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: WithStreamedUnaryMethod_ReplaceGraph() { - ::grpc::Service::MarkMethodStreamed(2, - new ::grpc::internal::StreamedUnaryHandler< flatbuffers::grpc::Message, flatbuffers::grpc::Message>(std::bind(&WithStreamedUnaryMethod_ReplaceGraph::StreamedReplaceGraph, this, std::placeholders::_1, std::placeholders::_2))); + ::grpc::Service::MarkMethodStreamed( + 2, new ::grpc::internal::StreamedUnaryHandler< + flatbuffers::grpc::Message, + flatbuffers::grpc::Message>(std::bind( + &WithStreamedUnaryMethod_ReplaceGraph< + BaseClass>::StreamedReplaceGraph, + this, std::placeholders::_1, std::placeholders::_2))); } ~WithStreamedUnaryMethod_ReplaceGraph() override { BaseClassMustBeDerivedFromService(this); } // disable regular version of this method - ::grpc::Status ReplaceGraph(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status ReplaceGraph( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } // replace default version of method with streamed unary - virtual ::grpc::Status StreamedReplaceGraph(::grpc::ServerContext* context, ::grpc::ServerUnaryStreamer< flatbuffers::grpc::Message,flatbuffers::grpc::Message>* server_unary_streamer) = 0; + virtual ::grpc::Status StreamedReplaceGraph( + ::grpc::ServerContext* context, + ::grpc::ServerUnaryStreamer, + flatbuffers::grpc::Message>* + server_unary_streamer) = 0; }; template class WithStreamedUnaryMethod_InferenceRequest : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: WithStreamedUnaryMethod_InferenceRequest() { - ::grpc::Service::MarkMethodStreamed(3, - new ::grpc::internal::StreamedUnaryHandler< flatbuffers::grpc::Message, flatbuffers::grpc::Message>(std::bind(&WithStreamedUnaryMethod_InferenceRequest::StreamedInferenceRequest, this, std::placeholders::_1, std::placeholders::_2))); + ::grpc::Service::MarkMethodStreamed( + 3, new ::grpc::internal::StreamedUnaryHandler< + flatbuffers::grpc::Message, + flatbuffers::grpc::Message>(std::bind( + &WithStreamedUnaryMethod_InferenceRequest< + BaseClass>::StreamedInferenceRequest, + this, std::placeholders::_1, std::placeholders::_2))); } ~WithStreamedUnaryMethod_InferenceRequest() override { BaseClassMustBeDerivedFromService(this); } // disable regular version of this method - ::grpc::Status InferenceRequest(::grpc::ServerContext* context, const flatbuffers::grpc::Message* request, flatbuffers::grpc::Message* response) final override { + ::grpc::Status InferenceRequest( + ::grpc::ServerContext* context, + const flatbuffers::grpc::Message* request, + flatbuffers::grpc::Message* response) final override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } // replace default version of method with streamed unary - virtual ::grpc::Status StreamedInferenceRequest(::grpc::ServerContext* context, ::grpc::ServerUnaryStreamer< flatbuffers::grpc::Message,flatbuffers::grpc::Message>* server_unary_streamer) = 0; + virtual ::grpc::Status StreamedInferenceRequest( + ::grpc::ServerContext* context, + ::grpc::ServerUnaryStreamer< + flatbuffers::grpc::Message, + flatbuffers::grpc::Message>* server_unary_streamer) = 0; }; - typedef WithStreamedUnaryMethod_RegisterGraph< WithStreamedUnaryMethod_ForgetGraph< WithStreamedUnaryMethod_ReplaceGraph< WithStreamedUnaryMethod_InferenceRequest< Service > > > > StreamedUnaryService; - typedef Service SplitStreamedService; - typedef WithStreamedUnaryMethod_RegisterGraph< WithStreamedUnaryMethod_ForgetGraph< WithStreamedUnaryMethod_ReplaceGraph< WithStreamedUnaryMethod_InferenceRequest< Service > > > > StreamedService; + typedef WithStreamedUnaryMethod_RegisterGraph< + WithStreamedUnaryMethod_ForgetGraph>>> + StreamedUnaryService; + typedef Service SplitStreamedService; + typedef WithStreamedUnaryMethod_RegisterGraph< + WithStreamedUnaryMethod_ForgetGraph>>> + StreamedService; }; } // namespace graph } // namespace sd - #endif // GRPC_graph__INCLUDED diff --git a/libnd4j/include/graph/generated/graph_generated.h b/libnd4j/include/graph/generated/graph_generated.h index 1285e4607e91..6421ccdbee29 100644 --- a/libnd4j/include/graph/generated/graph_generated.h +++ b/libnd4j/include/graph/generated/graph_generated.h @@ -1,13 +1,11 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_GRAPH_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_GRAPH_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" #include "config_generated.h" +#include "flatbuffers/flatbuffers.h" #include "node_generated.h" #include "properties_generated.h" #include "request_generated.h" @@ -27,23 +25,24 @@ struct FlatDropRequest; struct FlatResponse; struct UpdaterState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_PARAMNAME = 4, - VT_UPDATERSTATEKEYS = 6, - VT_UPDATERSTATEVALUES = 8 - }; + enum { VT_PARAMNAME = 4, VT_UPDATERSTATEKEYS = 6, VT_UPDATERSTATEVALUES = 8 }; const flatbuffers::String *paramName() const { return GetPointer(VT_PARAMNAME); } - const flatbuffers::Vector> *updaterStateKeys() const { - return GetPointer> *>(VT_UPDATERSTATEKEYS); + const flatbuffers::Vector> + *updaterStateKeys() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_UPDATERSTATEKEYS); } - const flatbuffers::Vector> *updaterStateValues() const { - return GetPointer> *>(VT_UPDATERSTATEVALUES); + const flatbuffers::Vector> + *updaterStateValues() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_UPDATERSTATEVALUES); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_PARAMNAME) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_PARAMNAME) && verifier.VerifyString(paramName()) && VerifyOffset(verifier, VT_UPDATERSTATEKEYS) && verifier.VerifyVector(updaterStateKeys()) && @@ -61,14 +60,19 @@ struct UpdaterStateBuilder { void add_paramName(flatbuffers::Offset paramName) { fbb_.AddOffset(UpdaterState::VT_PARAMNAME, paramName); } - void add_updaterStateKeys(flatbuffers::Offset>> updaterStateKeys) { + void add_updaterStateKeys( + flatbuffers::Offset< + flatbuffers::Vector>> + updaterStateKeys) { fbb_.AddOffset(UpdaterState::VT_UPDATERSTATEKEYS, updaterStateKeys); } - void add_updaterStateValues(flatbuffers::Offset>> updaterStateValues) { + void add_updaterStateValues( + flatbuffers::Offset>> + updaterStateValues) { fbb_.AddOffset(UpdaterState::VT_UPDATERSTATEVALUES, updaterStateValues); } explicit UpdaterStateBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UpdaterStateBuilder &operator=(const UpdaterStateBuilder &); @@ -82,8 +86,11 @@ struct UpdaterStateBuilder { inline flatbuffers::Offset CreateUpdaterState( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset paramName = 0, - flatbuffers::Offset>> updaterStateKeys = 0, - flatbuffers::Offset>> updaterStateValues = 0) { + flatbuffers::Offset< + flatbuffers::Vector>> + updaterStateKeys = 0, + flatbuffers::Offset>> + updaterStateValues = 0) { UpdaterStateBuilder builder_(_fbb); builder_.add_updaterStateValues(updaterStateValues); builder_.add_updaterStateKeys(updaterStateKeys); @@ -92,15 +99,20 @@ inline flatbuffers::Offset CreateUpdaterState( } inline flatbuffers::Offset CreateUpdaterStateDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const char *paramName = nullptr, - const std::vector> *updaterStateKeys = nullptr, - const std::vector> *updaterStateValues = nullptr) { + flatbuffers::FlatBufferBuilder &_fbb, const char *paramName = nullptr, + const std::vector> + *updaterStateKeys = nullptr, + const std::vector> *updaterStateValues = + nullptr) { return sd::graph::CreateUpdaterState( - _fbb, - paramName ? _fbb.CreateString(paramName) : 0, - updaterStateKeys ? _fbb.CreateVector>(*updaterStateKeys) : 0, - updaterStateValues ? _fbb.CreateVector>(*updaterStateValues) : 0); + _fbb, paramName ? _fbb.CreateString(paramName) : 0, + updaterStateKeys + ? _fbb.CreateVector>( + *updaterStateKeys) + : 0, + updaterStateValues ? _fbb.CreateVector>( + *updaterStateValues) + : 0); } struct FlatGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -115,32 +127,44 @@ struct FlatGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_TRAININGCONFIG = 18, VT_UPDATERSTATE = 20 }; - int64_t id() const { - return GetField(VT_ID, 0); - } - const flatbuffers::Vector> *variables() const { - return GetPointer> *>(VT_VARIABLES); + int64_t id() const { return GetField(VT_ID, 0); } + const flatbuffers::Vector> *variables() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_VARIABLES); } const flatbuffers::Vector> *nodes() const { - return GetPointer> *>(VT_NODES); + return GetPointer< + const flatbuffers::Vector> *>(VT_NODES); } const flatbuffers::Vector> *outputs() const { - return GetPointer> *>(VT_OUTPUTS); + return GetPointer< + const flatbuffers::Vector> *>(VT_OUTPUTS); } const FlatConfiguration *configuration() const { return GetPointer(VT_CONFIGURATION); } - const flatbuffers::Vector> *placeholders() const { - return GetPointer> *>(VT_PLACEHOLDERS); + const flatbuffers::Vector> + *placeholders() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_PLACEHOLDERS); } - const flatbuffers::Vector> *lossVariables() const { - return GetPointer> *>(VT_LOSSVARIABLES); + const flatbuffers::Vector> + *lossVariables() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_LOSSVARIABLES); } const flatbuffers::String *trainingConfig() const { return GetPointer(VT_TRAININGCONFIG); } - const flatbuffers::Vector> *updaterState() const { - return GetPointer> *>(VT_UPDATERSTATE); + const flatbuffers::Vector> *updaterState() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_UPDATERSTATE); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -148,8 +172,7 @@ struct FlatGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_VARIABLES) && verifier.VerifyVector(variables()) && verifier.VerifyVectorOfTables(variables()) && - VerifyOffset(verifier, VT_NODES) && - verifier.VerifyVector(nodes()) && + VerifyOffset(verifier, VT_NODES) && verifier.VerifyVector(nodes()) && verifier.VerifyVectorOfTables(nodes()) && VerifyOffset(verifier, VT_OUTPUTS) && verifier.VerifyVector(outputs()) && @@ -166,43 +189,54 @@ struct FlatGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyString(trainingConfig()) && VerifyOffset(verifier, VT_UPDATERSTATE) && verifier.VerifyVector(updaterState()) && - verifier.VerifyVectorOfTables(updaterState()) && - verifier.EndTable(); + verifier.VerifyVectorOfTables(updaterState()) && verifier.EndTable(); } }; struct FlatGraphBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_id(int64_t id) { - fbb_.AddElement(FlatGraph::VT_ID, id, 0); - } - void add_variables(flatbuffers::Offset>> variables) { + void add_id(int64_t id) { fbb_.AddElement(FlatGraph::VT_ID, id, 0); } + void add_variables(flatbuffers::Offset< + flatbuffers::Vector>> + variables) { fbb_.AddOffset(FlatGraph::VT_VARIABLES, variables); } - void add_nodes(flatbuffers::Offset>> nodes) { + void add_nodes( + flatbuffers::Offset>> + nodes) { fbb_.AddOffset(FlatGraph::VT_NODES, nodes); } - void add_outputs(flatbuffers::Offset>> outputs) { + void add_outputs( + flatbuffers::Offset>> + outputs) { fbb_.AddOffset(FlatGraph::VT_OUTPUTS, outputs); } void add_configuration(flatbuffers::Offset configuration) { fbb_.AddOffset(FlatGraph::VT_CONFIGURATION, configuration); } - void add_placeholders(flatbuffers::Offset>> placeholders) { + void add_placeholders( + flatbuffers::Offset< + flatbuffers::Vector>> + placeholders) { fbb_.AddOffset(FlatGraph::VT_PLACEHOLDERS, placeholders); } - void add_lossVariables(flatbuffers::Offset>> lossVariables) { + void add_lossVariables( + flatbuffers::Offset< + flatbuffers::Vector>> + lossVariables) { fbb_.AddOffset(FlatGraph::VT_LOSSVARIABLES, lossVariables); } - void add_trainingConfig(flatbuffers::Offset trainingConfig) { + void add_trainingConfig( + flatbuffers::Offset trainingConfig) { fbb_.AddOffset(FlatGraph::VT_TRAININGCONFIG, trainingConfig); } - void add_updaterState(flatbuffers::Offset>> updaterState) { + void add_updaterState(flatbuffers::Offset< + flatbuffers::Vector>> + updaterState) { fbb_.AddOffset(FlatGraph::VT_UPDATERSTATE, updaterState); } - explicit FlatGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit FlatGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatGraphBuilder &operator=(const FlatGraphBuilder &); @@ -214,16 +248,23 @@ struct FlatGraphBuilder { }; inline flatbuffers::Offset CreateFlatGraph( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0, - flatbuffers::Offset>> variables = 0, - flatbuffers::Offset>> nodes = 0, - flatbuffers::Offset>> outputs = 0, + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0, + flatbuffers::Offset>> + variables = 0, + flatbuffers::Offset>> + nodes = 0, + flatbuffers::Offset>> + outputs = 0, flatbuffers::Offset configuration = 0, - flatbuffers::Offset>> placeholders = 0, - flatbuffers::Offset>> lossVariables = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + placeholders = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + lossVariables = 0, flatbuffers::Offset trainingConfig = 0, - flatbuffers::Offset>> updaterState = 0) { + flatbuffers::Offset>> + updaterState = 0) { FlatGraphBuilder builder_(_fbb); builder_.add_id(id); builder_.add_updaterState(updaterState); @@ -238,40 +279,46 @@ inline flatbuffers::Offset CreateFlatGraph( } inline flatbuffers::Offset CreateFlatGraphDirect( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0, + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0, const std::vector> *variables = nullptr, const std::vector> *nodes = nullptr, const std::vector> *outputs = nullptr, flatbuffers::Offset configuration = 0, - const std::vector> *placeholders = nullptr, - const std::vector> *lossVariables = nullptr, + const std::vector> *placeholders = + nullptr, + const std::vector> *lossVariables = + nullptr, const char *trainingConfig = nullptr, - const std::vector> *updaterState = nullptr) { + const std::vector> *updaterState = + nullptr) { return sd::graph::CreateFlatGraph( - _fbb, - id, - variables ? _fbb.CreateVector>(*variables) : 0, + _fbb, id, + variables + ? _fbb.CreateVector>(*variables) + : 0, nodes ? _fbb.CreateVector>(*nodes) : 0, outputs ? _fbb.CreateVector>(*outputs) : 0, configuration, - placeholders ? _fbb.CreateVector>(*placeholders) : 0, - lossVariables ? _fbb.CreateVector>(*lossVariables) : 0, + placeholders + ? _fbb.CreateVector>( + *placeholders) + : 0, + lossVariables + ? _fbb.CreateVector>( + *lossVariables) + : 0, trainingConfig ? _fbb.CreateString(trainingConfig) : 0, - updaterState ? _fbb.CreateVector>(*updaterState) : 0); + updaterState + ? _fbb.CreateVector>(*updaterState) + : 0); } struct FlatDropRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_ID = 4 - }; - int64_t id() const { - return GetField(VT_ID, 0); - } + enum { VT_ID = 4 }; + int64_t id() const { return GetField(VT_ID, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_ID) && - verifier.EndTable(); + VerifyField(verifier, VT_ID) && verifier.EndTable(); } }; @@ -282,7 +329,7 @@ struct FlatDropRequestBuilder { fbb_.AddElement(FlatDropRequest::VT_ID, id, 0); } explicit FlatDropRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatDropRequestBuilder &operator=(const FlatDropRequestBuilder &); @@ -294,24 +341,18 @@ struct FlatDropRequestBuilder { }; inline flatbuffers::Offset CreateFlatDropRequest( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0) { + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0) { FlatDropRequestBuilder builder_(_fbb); builder_.add_id(id); return builder_.Finish(); } struct FlatResponse FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_STATUS = 4 - }; - int32_t status() const { - return GetField(VT_STATUS, 0); - } + enum { VT_STATUS = 4 }; + int32_t status() const { return GetField(VT_STATUS, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_STATUS) && - verifier.EndTable(); + VerifyField(verifier, VT_STATUS) && verifier.EndTable(); } }; @@ -322,7 +363,7 @@ struct FlatResponseBuilder { fbb_.AddElement(FlatResponse::VT_STATUS, status, 0); } explicit FlatResponseBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatResponseBuilder &operator=(const FlatResponseBuilder &); @@ -334,8 +375,7 @@ struct FlatResponseBuilder { }; inline flatbuffers::Offset CreateFlatResponse( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t status = 0) { + flatbuffers::FlatBufferBuilder &_fbb, int32_t status = 0) { FlatResponseBuilder builder_(_fbb); builder_.add_status(status); return builder_.Finish(); @@ -349,13 +389,11 @@ inline const sd::graph::FlatGraph *GetSizePrefixedFlatGraph(const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatGraphBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatGraphBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } -inline bool VerifySizePrefixedFlatGraphBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifySizePrefixedFlatGraphBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifySizePrefixedBuffer(nullptr); } diff --git a/libnd4j/include/graph/generated/node_generated.h b/libnd4j/include/graph/generated/node_generated.h index a39f2490c9d2..503cded18db6 100644 --- a/libnd4j/include/graph/generated/node_generated.h +++ b/libnd4j/include/graph/generated/node_generated.h @@ -1,12 +1,10 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_NODE_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_NODE_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" +#include "flatbuffers/flatbuffers.h" #include "properties_generated.h" #include "utils_generated.h" @@ -41,26 +39,27 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_CONTROLDEPFOR = 46, VT_EXTRATYPES = 48 }; - int32_t id() const { - return GetField(VT_ID, 0); - } + int32_t id() const { return GetField(VT_ID, 0); } const flatbuffers::String *name() const { return GetPointer(VT_NAME); } OpType opType() const { return static_cast(GetField(VT_OPTYPE, 0)); } - int64_t opNum() const { - return GetField(VT_OPNUM, 0); - } - const flatbuffers::Vector> *properties() const { - return GetPointer> *>(VT_PROPERTIES); + int64_t opNum() const { return GetField(VT_OPNUM, 0); } + const flatbuffers::Vector> *properties() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_PROPERTIES); } const flatbuffers::Vector *input() const { return GetPointer *>(VT_INPUT); } const flatbuffers::Vector> *inputPaired() const { - return GetPointer> *>(VT_INPUTPAIRED); + return GetPointer< + const flatbuffers::Vector> *>( + VT_INPUTPAIRED); } const flatbuffers::Vector *output() const { return GetPointer *>(VT_OUTPUT); @@ -77,17 +76,16 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector *dimensions() const { return GetPointer *>(VT_DIMENSIONS); } - int32_t device() const { - return GetField(VT_DEVICE, 0); - } - int32_t scope_id() const { - return GetField(VT_SCOPE_ID, 0); - } + int32_t device() const { return GetField(VT_DEVICE, 0); } + int32_t scope_id() const { return GetField(VT_SCOPE_ID, 0); } const flatbuffers::String *scope_name() const { return GetPointer(VT_SCOPE_NAME); } - const flatbuffers::Vector> *outputNames() const { - return GetPointer> *>(VT_OUTPUTNAMES); + const flatbuffers::Vector> + *outputNames() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_OUTPUTNAMES); } const flatbuffers::String *opName() const { return GetPointer(VT_OPNAME); @@ -98,14 +96,23 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const FlatArray *scalar() const { return GetPointer(VT_SCALAR); } - const flatbuffers::Vector> *controlDeps() const { - return GetPointer> *>(VT_CONTROLDEPS); + const flatbuffers::Vector> + *controlDeps() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPS); } - const flatbuffers::Vector> *varControlDeps() const { - return GetPointer> *>(VT_VARCONTROLDEPS); + const flatbuffers::Vector> + *varControlDeps() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_VARCONTROLDEPS); } - const flatbuffers::Vector> *controlDepFor() const { - return GetPointer> *>(VT_CONTROLDEPFOR); + const flatbuffers::Vector> + *controlDepFor() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPFOR); } const flatbuffers::Vector *extraTypes() const { return GetPointer *>(VT_EXTRATYPES); @@ -113,15 +120,13 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_ID) && - VerifyOffset(verifier, VT_NAME) && - verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_NAME) && verifier.VerifyString(name()) && VerifyField(verifier, VT_OPTYPE) && VerifyField(verifier, VT_OPNUM) && VerifyOffset(verifier, VT_PROPERTIES) && verifier.VerifyVector(properties()) && verifier.VerifyVectorOfTables(properties()) && - VerifyOffset(verifier, VT_INPUT) && - verifier.VerifyVector(input()) && + VerifyOffset(verifier, VT_INPUT) && verifier.VerifyVector(input()) && VerifyOffset(verifier, VT_INPUTPAIRED) && verifier.VerifyVector(inputPaired()) && verifier.VerifyVectorOfTables(inputPaired()) && @@ -158,48 +163,54 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyVector(controlDepFor()) && verifier.VerifyVectorOfStrings(controlDepFor()) && VerifyOffset(verifier, VT_EXTRATYPES) && - verifier.VerifyVector(extraTypes()) && - verifier.EndTable(); + verifier.VerifyVector(extraTypes()) && verifier.EndTable(); } }; struct FlatNodeBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_id(int32_t id) { - fbb_.AddElement(FlatNode::VT_ID, id, 0); - } + void add_id(int32_t id) { fbb_.AddElement(FlatNode::VT_ID, id, 0); } void add_name(flatbuffers::Offset name) { fbb_.AddOffset(FlatNode::VT_NAME, name); } void add_opType(OpType opType) { - fbb_.AddElement(FlatNode::VT_OPTYPE, static_cast(opType), 0); + fbb_.AddElement(FlatNode::VT_OPTYPE, static_cast(opType), + 0); } void add_opNum(int64_t opNum) { fbb_.AddElement(FlatNode::VT_OPNUM, opNum, 0); } - void add_properties(flatbuffers::Offset>> properties) { + void add_properties(flatbuffers::Offset< + flatbuffers::Vector>> + properties) { fbb_.AddOffset(FlatNode::VT_PROPERTIES, properties); } void add_input(flatbuffers::Offset> input) { fbb_.AddOffset(FlatNode::VT_INPUT, input); } - void add_inputPaired(flatbuffers::Offset>> inputPaired) { + void add_inputPaired( + flatbuffers::Offset>> + inputPaired) { fbb_.AddOffset(FlatNode::VT_INPUTPAIRED, inputPaired); } void add_output(flatbuffers::Offset> output) { fbb_.AddOffset(FlatNode::VT_OUTPUT, output); } - void add_extraParams(flatbuffers::Offset> extraParams) { + void add_extraParams( + flatbuffers::Offset> extraParams) { fbb_.AddOffset(FlatNode::VT_EXTRAPARAMS, extraParams); } - void add_extraInteger(flatbuffers::Offset> extraInteger) { + void add_extraInteger( + flatbuffers::Offset> extraInteger) { fbb_.AddOffset(FlatNode::VT_EXTRAINTEGER, extraInteger); } - void add_extraBools(flatbuffers::Offset> extraBools) { + void add_extraBools( + flatbuffers::Offset> extraBools) { fbb_.AddOffset(FlatNode::VT_EXTRABOOLS, extraBools); } - void add_dimensions(flatbuffers::Offset> dimensions) { + void add_dimensions( + flatbuffers::Offset> dimensions) { fbb_.AddOffset(FlatNode::VT_DIMENSIONS, dimensions); } void add_device(int32_t device) { @@ -211,32 +222,45 @@ struct FlatNodeBuilder { void add_scope_name(flatbuffers::Offset scope_name) { fbb_.AddOffset(FlatNode::VT_SCOPE_NAME, scope_name); } - void add_outputNames(flatbuffers::Offset>> outputNames) { + void add_outputNames( + flatbuffers::Offset< + flatbuffers::Vector>> + outputNames) { fbb_.AddOffset(FlatNode::VT_OUTPUTNAMES, outputNames); } void add_opName(flatbuffers::Offset opName) { fbb_.AddOffset(FlatNode::VT_OPNAME, opName); } - void add_outputTypes(flatbuffers::Offset> outputTypes) { + void add_outputTypes( + flatbuffers::Offset> outputTypes) { fbb_.AddOffset(FlatNode::VT_OUTPUTTYPES, outputTypes); } void add_scalar(flatbuffers::Offset scalar) { fbb_.AddOffset(FlatNode::VT_SCALAR, scalar); } - void add_controlDeps(flatbuffers::Offset>> controlDeps) { + void add_controlDeps( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps) { fbb_.AddOffset(FlatNode::VT_CONTROLDEPS, controlDeps); } - void add_varControlDeps(flatbuffers::Offset>> varControlDeps) { + void add_varControlDeps( + flatbuffers::Offset< + flatbuffers::Vector>> + varControlDeps) { fbb_.AddOffset(FlatNode::VT_VARCONTROLDEPS, varControlDeps); } - void add_controlDepFor(flatbuffers::Offset>> controlDepFor) { + void add_controlDepFor( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepFor) { fbb_.AddOffset(FlatNode::VT_CONTROLDEPFOR, controlDepFor); } - void add_extraTypes(flatbuffers::Offset> extraTypes) { + void add_extraTypes( + flatbuffers::Offset> extraTypes) { fbb_.AddOffset(FlatNode::VT_EXTRATYPES, extraTypes); } - explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatNodeBuilder &operator=(const FlatNodeBuilder &); @@ -248,29 +272,37 @@ struct FlatNodeBuilder { }; inline flatbuffers::Offset CreateFlatNode( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t id = 0, + flatbuffers::FlatBufferBuilder &_fbb, int32_t id = 0, flatbuffers::Offset name = 0, - OpType opType = OpType_TRANSFORM_FLOAT, - int64_t opNum = 0, - flatbuffers::Offset>> properties = 0, + OpType opType = OpType_TRANSFORM_FLOAT, int64_t opNum = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + properties = 0, flatbuffers::Offset> input = 0, - flatbuffers::Offset>> inputPaired = 0, + flatbuffers::Offset>> + inputPaired = 0, flatbuffers::Offset> output = 0, flatbuffers::Offset> extraParams = 0, flatbuffers::Offset> extraInteger = 0, flatbuffers::Offset> extraBools = 0, flatbuffers::Offset> dimensions = 0, - int32_t device = 0, - int32_t scope_id = 0, + int32_t device = 0, int32_t scope_id = 0, flatbuffers::Offset scope_name = 0, - flatbuffers::Offset>> outputNames = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + outputNames = 0, flatbuffers::Offset opName = 0, flatbuffers::Offset> outputTypes = 0, flatbuffers::Offset scalar = 0, - flatbuffers::Offset>> controlDeps = 0, - flatbuffers::Offset>> varControlDeps = 0, - flatbuffers::Offset>> controlDepFor = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + varControlDeps = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepFor = 0, flatbuffers::Offset> extraTypes = 0) { FlatNodeBuilder builder_(_fbb); builder_.add_opNum(opNum); @@ -300,54 +332,62 @@ inline flatbuffers::Offset CreateFlatNode( } inline flatbuffers::Offset CreateFlatNodeDirect( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t id = 0, - const char *name = nullptr, - OpType opType = OpType_TRANSFORM_FLOAT, + flatbuffers::FlatBufferBuilder &_fbb, int32_t id = 0, + const char *name = nullptr, OpType opType = OpType_TRANSFORM_FLOAT, int64_t opNum = 0, - const std::vector> *properties = nullptr, + const std::vector> *properties = + nullptr, const std::vector *input = nullptr, const std::vector> *inputPaired = nullptr, const std::vector *output = nullptr, const std::vector *extraParams = nullptr, const std::vector *extraInteger = nullptr, const std::vector *extraBools = nullptr, - const std::vector *dimensions = nullptr, - int32_t device = 0, - int32_t scope_id = 0, - const char *scope_name = nullptr, - const std::vector> *outputNames = nullptr, + const std::vector *dimensions = nullptr, int32_t device = 0, + int32_t scope_id = 0, const char *scope_name = nullptr, + const std::vector> *outputNames = + nullptr, const char *opName = nullptr, const std::vector *outputTypes = nullptr, flatbuffers::Offset scalar = 0, - const std::vector> *controlDeps = nullptr, - const std::vector> *varControlDeps = nullptr, - const std::vector> *controlDepFor = nullptr, + const std::vector> *controlDeps = + nullptr, + const std::vector> + *varControlDeps = nullptr, + const std::vector> *controlDepFor = + nullptr, const std::vector *extraTypes = nullptr) { return sd::graph::CreateFlatNode( - _fbb, - id, - name ? _fbb.CreateString(name) : 0, - opType, - opNum, - properties ? _fbb.CreateVector>(*properties) : 0, + _fbb, id, name ? _fbb.CreateString(name) : 0, opType, opNum, + properties + ? _fbb.CreateVector>(*properties) + : 0, input ? _fbb.CreateVector(*input) : 0, - inputPaired ? _fbb.CreateVector>(*inputPaired) : 0, + inputPaired + ? _fbb.CreateVector>(*inputPaired) + : 0, output ? _fbb.CreateVector(*output) : 0, extraParams ? _fbb.CreateVector(*extraParams) : 0, extraInteger ? _fbb.CreateVector(*extraInteger) : 0, extraBools ? _fbb.CreateVector(*extraBools) : 0, - dimensions ? _fbb.CreateVector(*dimensions) : 0, - device, - scope_id, - scope_name ? _fbb.CreateString(scope_name) : 0, - outputNames ? _fbb.CreateVector>(*outputNames) : 0, + dimensions ? _fbb.CreateVector(*dimensions) : 0, device, + scope_id, scope_name ? _fbb.CreateString(scope_name) : 0, + outputNames ? _fbb.CreateVector>( + *outputNames) + : 0, opName ? _fbb.CreateString(opName) : 0, - outputTypes ? _fbb.CreateVector(*outputTypes) : 0, - scalar, - controlDeps ? _fbb.CreateVector>(*controlDeps) : 0, - varControlDeps ? _fbb.CreateVector>(*varControlDeps) : 0, - controlDepFor ? _fbb.CreateVector>(*controlDepFor) : 0, + outputTypes ? _fbb.CreateVector(*outputTypes) : 0, scalar, + controlDeps ? _fbb.CreateVector>( + *controlDeps) + : 0, + varControlDeps + ? _fbb.CreateVector>( + *varControlDeps) + : 0, + controlDepFor + ? _fbb.CreateVector>( + *controlDepFor) + : 0, extraTypes ? _fbb.CreateVector(*extraTypes) : 0); } @@ -359,13 +399,11 @@ inline const sd::graph::FlatNode *GetSizePrefixedFlatNode(const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatNodeBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatNodeBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } -inline bool VerifySizePrefixedFlatNodeBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifySizePrefixedFlatNodeBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifySizePrefixedBuffer(nullptr); } diff --git a/libnd4j/include/graph/generated/properties_generated.h b/libnd4j/include/graph/generated/properties_generated.h index 34138fe86264..21986255d6e1 100644 --- a/libnd4j/include/graph/generated/properties_generated.h +++ b/libnd4j/include/graph/generated/properties_generated.h @@ -1,12 +1,10 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_PROPERTIES_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_PROPERTIES_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" +#include "flatbuffers/flatbuffers.h" namespace sd { namespace graph { @@ -37,37 +35,32 @@ struct FlatProperties FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return GetPointer *>(VT_D); } const flatbuffers::Vector> *a() const { - return GetPointer> *>(VT_A); + return GetPointer< + const flatbuffers::Vector> *>(VT_A); } const flatbuffers::Vector *b() const { return GetPointer *>(VT_B); } - const flatbuffers::Vector> *s() const { - return GetPointer> *>(VT_S); + const flatbuffers::Vector> *s() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_S); } const flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_NAME) && - verifier.VerifyString(name()) && - VerifyOffset(verifier, VT_I) && - verifier.VerifyVector(i()) && - VerifyOffset(verifier, VT_L) && - verifier.VerifyVector(l()) && - VerifyOffset(verifier, VT_D) && - verifier.VerifyVector(d()) && - VerifyOffset(verifier, VT_A) && - verifier.VerifyVector(a()) && - verifier.VerifyVectorOfTables(a()) && - VerifyOffset(verifier, VT_B) && - verifier.VerifyVector(b()) && - VerifyOffset(verifier, VT_S) && - verifier.VerifyVector(s()) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && VerifyOffset(verifier, VT_I) && + verifier.VerifyVector(i()) && VerifyOffset(verifier, VT_L) && + verifier.VerifyVector(l()) && VerifyOffset(verifier, VT_D) && + verifier.VerifyVector(d()) && VerifyOffset(verifier, VT_A) && + verifier.VerifyVector(a()) && verifier.VerifyVectorOfTables(a()) && + VerifyOffset(verifier, VT_B) && verifier.VerifyVector(b()) && + VerifyOffset(verifier, VT_S) && verifier.VerifyVector(s()) && verifier.VerifyVectorOfStrings(s()) && - VerifyOffset(verifier, VT_SHAPE) && - verifier.VerifyVector(shape()) && + VerifyOffset(verifier, VT_SHAPE) && verifier.VerifyVector(shape()) && verifier.EndTable(); } }; @@ -87,20 +80,24 @@ struct FlatPropertiesBuilder { void add_d(flatbuffers::Offset> d) { fbb_.AddOffset(FlatProperties::VT_D, d); } - void add_a(flatbuffers::Offset>> a) { + void add_a( + flatbuffers::Offset>> + a) { fbb_.AddOffset(FlatProperties::VT_A, a); } void add_b(flatbuffers::Offset> b) { fbb_.AddOffset(FlatProperties::VT_B, b); } - void add_s(flatbuffers::Offset>> s) { + void add_s(flatbuffers::Offset< + flatbuffers::Vector>> + s) { fbb_.AddOffset(FlatProperties::VT_S, s); } void add_shape(flatbuffers::Offset> shape) { fbb_.AddOffset(FlatProperties::VT_SHAPE, shape); } explicit FlatPropertiesBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatPropertiesBuilder &operator=(const FlatPropertiesBuilder &); @@ -117,9 +114,12 @@ inline flatbuffers::Offset CreateFlatProperties( flatbuffers::Offset> i = 0, flatbuffers::Offset> l = 0, flatbuffers::Offset> d = 0, - flatbuffers::Offset>> a = 0, + flatbuffers::Offset>> a = + 0, flatbuffers::Offset> b = 0, - flatbuffers::Offset>> s = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + s = 0, flatbuffers::Offset> shape = 0) { FlatPropertiesBuilder builder_(_fbb); builder_.add_shape(shape); @@ -134,8 +134,7 @@ inline flatbuffers::Offset CreateFlatProperties( } inline flatbuffers::Offset CreateFlatPropertiesDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const char *name = nullptr, + flatbuffers::FlatBufferBuilder &_fbb, const char *name = nullptr, const std::vector *i = nullptr, const std::vector *l = nullptr, const std::vector *d = nullptr, @@ -144,8 +143,7 @@ inline flatbuffers::Offset CreateFlatPropertiesDirect( const std::vector> *s = nullptr, const std::vector *shape = nullptr) { return sd::graph::CreateFlatProperties( - _fbb, - name ? _fbb.CreateString(name) : 0, + _fbb, name ? _fbb.CreateString(name) : 0, i ? _fbb.CreateVector(*i) : 0, l ? _fbb.CreateVector(*l) : 0, d ? _fbb.CreateVector(*d) : 0, @@ -159,12 +157,12 @@ inline const sd::graph::FlatProperties *GetFlatProperties(const void *buf) { return flatbuffers::GetRoot(buf); } -inline const sd::graph::FlatProperties *GetSizePrefixedFlatProperties(const void *buf) { +inline const sd::graph::FlatProperties *GetSizePrefixedFlatProperties( + const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatPropertiesBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatPropertiesBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } diff --git a/libnd4j/include/graph/generated/request_generated.h b/libnd4j/include/graph/generated/request_generated.h index 00c782311b05..970336dd3393 100644 --- a/libnd4j/include/graph/generated/request_generated.h +++ b/libnd4j/include/graph/generated/request_generated.h @@ -1,13 +1,11 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_REQUEST_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_REQUEST_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" #include "config_generated.h" +#include "flatbuffers/flatbuffers.h" #include "utils_generated.h" #include "variable_generated.h" @@ -16,17 +14,15 @@ namespace graph { struct FlatInferenceRequest; -struct FlatInferenceRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_ID = 4, - VT_VARIABLES = 6, - VT_CONFIGURATION = 8 - }; - int64_t id() const { - return GetField(VT_ID, 0); - } - const flatbuffers::Vector> *variables() const { - return GetPointer> *>(VT_VARIABLES); +struct FlatInferenceRequest FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { + enum { VT_ID = 4, VT_VARIABLES = 6, VT_CONFIGURATION = 8 }; + int64_t id() const { return GetField(VT_ID, 0); } + const flatbuffers::Vector> *variables() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_VARIABLES); } const FlatConfiguration *configuration() const { return GetPointer(VT_CONFIGURATION); @@ -38,8 +34,7 @@ struct FlatInferenceRequest FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table verifier.VerifyVector(variables()) && verifier.VerifyVectorOfTables(variables()) && VerifyOffset(verifier, VT_CONFIGURATION) && - verifier.VerifyTable(configuration()) && - verifier.EndTable(); + verifier.VerifyTable(configuration()) && verifier.EndTable(); } }; @@ -49,14 +44,16 @@ struct FlatInferenceRequestBuilder { void add_id(int64_t id) { fbb_.AddElement(FlatInferenceRequest::VT_ID, id, 0); } - void add_variables(flatbuffers::Offset>> variables) { + void add_variables(flatbuffers::Offset< + flatbuffers::Vector>> + variables) { fbb_.AddOffset(FlatInferenceRequest::VT_VARIABLES, variables); } void add_configuration(flatbuffers::Offset configuration) { fbb_.AddOffset(FlatInferenceRequest::VT_CONFIGURATION, configuration); } explicit FlatInferenceRequestBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatInferenceRequestBuilder &operator=(const FlatInferenceRequestBuilder &); @@ -68,9 +65,9 @@ struct FlatInferenceRequestBuilder { }; inline flatbuffers::Offset CreateFlatInferenceRequest( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0, - flatbuffers::Offset>> variables = 0, + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0, + flatbuffers::Offset>> + variables = 0, flatbuffers::Offset configuration = 0) { FlatInferenceRequestBuilder builder_(_fbb); builder_.add_id(id); @@ -79,34 +76,37 @@ inline flatbuffers::Offset CreateFlatInferenceRequest( return builder_.Finish(); } -inline flatbuffers::Offset CreateFlatInferenceRequestDirect( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0, +inline flatbuffers::Offset +CreateFlatInferenceRequestDirect( + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0, const std::vector> *variables = nullptr, flatbuffers::Offset configuration = 0) { return sd::graph::CreateFlatInferenceRequest( - _fbb, - id, - variables ? _fbb.CreateVector>(*variables) : 0, + _fbb, id, + variables + ? _fbb.CreateVector>(*variables) + : 0, configuration); } -inline const sd::graph::FlatInferenceRequest *GetFlatInferenceRequest(const void *buf) { +inline const sd::graph::FlatInferenceRequest *GetFlatInferenceRequest( + const void *buf) { return flatbuffers::GetRoot(buf); } -inline const sd::graph::FlatInferenceRequest *GetSizePrefixedFlatInferenceRequest(const void *buf) { +inline const sd::graph::FlatInferenceRequest * +GetSizePrefixedFlatInferenceRequest(const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatInferenceRequestBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatInferenceRequestBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } inline bool VerifySizePrefixedFlatInferenceRequestBuffer( flatbuffers::Verifier &verifier) { - return verifier.VerifySizePrefixedBuffer(nullptr); + return verifier.VerifySizePrefixedBuffer( + nullptr); } inline void FinishFlatInferenceRequestBuffer( diff --git a/libnd4j/include/graph/generated/result_generated.h b/libnd4j/include/graph/generated/result_generated.h index 04c458a9fde1..b05a0ed44300 100644 --- a/libnd4j/include/graph/generated/result_generated.h +++ b/libnd4j/include/graph/generated/result_generated.h @@ -1,12 +1,10 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_RESULT_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_RESULT_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" +#include "flatbuffers/flatbuffers.h" #include "node_generated.h" #include "properties_generated.h" #include "utils_generated.h" @@ -20,14 +18,8 @@ struct FlatTiming; struct FlatResult; struct FlatTiming FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_ID = 4, - VT_NAME = 6, - VT_TIMING = 8 - }; - int32_t id() const { - return GetField(VT_ID, 0); - } + enum { VT_ID = 4, VT_NAME = 6, VT_TIMING = 8 }; + int32_t id() const { return GetField(VT_ID, 0); } const flatbuffers::String *name() const { return GetPointer(VT_NAME); } @@ -37,11 +29,9 @@ struct FlatTiming FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_ID) && - VerifyOffset(verifier, VT_NAME) && - verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_NAME) && verifier.VerifyString(name()) && VerifyOffset(verifier, VT_TIMING) && - verifier.VerifyTable(timing()) && - verifier.EndTable(); + verifier.VerifyTable(timing()) && verifier.EndTable(); } }; @@ -58,7 +48,7 @@ struct FlatTimingBuilder { fbb_.AddOffset(FlatTiming::VT_TIMING, timing); } explicit FlatTimingBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatTimingBuilder &operator=(const FlatTimingBuilder &); @@ -70,8 +60,7 @@ struct FlatTimingBuilder { }; inline flatbuffers::Offset CreateFlatTiming( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t id = 0, + flatbuffers::FlatBufferBuilder &_fbb, int32_t id = 0, flatbuffers::Offset name = 0, flatbuffers::Offset timing = 0) { FlatTimingBuilder builder_(_fbb); @@ -82,15 +71,10 @@ inline flatbuffers::Offset CreateFlatTiming( } inline flatbuffers::Offset CreateFlatTimingDirect( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t id = 0, - const char *name = nullptr, - flatbuffers::Offset timing = 0) { + flatbuffers::FlatBufferBuilder &_fbb, int32_t id = 0, + const char *name = nullptr, flatbuffers::Offset timing = 0) { return sd::graph::CreateFlatTiming( - _fbb, - id, - name ? _fbb.CreateString(name) : 0, - timing); + _fbb, id, name ? _fbb.CreateString(name) : 0, timing); } struct FlatResult FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -101,14 +85,17 @@ struct FlatResult FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_FOOTPRINTFORWARD = 10, VT_FOOTPRINTBACKWARD = 12 }; - int64_t id() const { - return GetField(VT_ID, 0); - } - const flatbuffers::Vector> *variables() const { - return GetPointer> *>(VT_VARIABLES); + int64_t id() const { return GetField(VT_ID, 0); } + const flatbuffers::Vector> *variables() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_VARIABLES); } const flatbuffers::Vector> *timing() const { - return GetPointer> *>(VT_TIMING); + return GetPointer< + const flatbuffers::Vector> *>( + VT_TIMING); } int64_t footprintForward() const { return GetField(VT_FOOTPRINTFORWARD, 0); @@ -137,20 +124,26 @@ struct FlatResultBuilder { void add_id(int64_t id) { fbb_.AddElement(FlatResult::VT_ID, id, 0); } - void add_variables(flatbuffers::Offset>> variables) { + void add_variables(flatbuffers::Offset< + flatbuffers::Vector>> + variables) { fbb_.AddOffset(FlatResult::VT_VARIABLES, variables); } - void add_timing(flatbuffers::Offset>> timing) { + void add_timing( + flatbuffers::Offset>> + timing) { fbb_.AddOffset(FlatResult::VT_TIMING, timing); } void add_footprintForward(int64_t footprintForward) { - fbb_.AddElement(FlatResult::VT_FOOTPRINTFORWARD, footprintForward, 0); + fbb_.AddElement(FlatResult::VT_FOOTPRINTFORWARD, footprintForward, + 0); } void add_footprintBackward(int64_t footprintBackward) { - fbb_.AddElement(FlatResult::VT_FOOTPRINTBACKWARD, footprintBackward, 0); + fbb_.AddElement(FlatResult::VT_FOOTPRINTBACKWARD, + footprintBackward, 0); } explicit FlatResultBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatResultBuilder &operator=(const FlatResultBuilder &); @@ -162,12 +155,12 @@ struct FlatResultBuilder { }; inline flatbuffers::Offset CreateFlatResult( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0, - flatbuffers::Offset>> variables = 0, - flatbuffers::Offset>> timing = 0, - int64_t footprintForward = 0, - int64_t footprintBackward = 0) { + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0, + flatbuffers::Offset>> + variables = 0, + flatbuffers::Offset>> + timing = 0, + int64_t footprintForward = 0, int64_t footprintBackward = 0) { FlatResultBuilder builder_(_fbb); builder_.add_footprintBackward(footprintBackward); builder_.add_footprintForward(footprintForward); @@ -178,19 +171,17 @@ inline flatbuffers::Offset CreateFlatResult( } inline flatbuffers::Offset CreateFlatResultDirect( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t id = 0, + flatbuffers::FlatBufferBuilder &_fbb, int64_t id = 0, const std::vector> *variables = nullptr, const std::vector> *timing = nullptr, - int64_t footprintForward = 0, - int64_t footprintBackward = 0) { + int64_t footprintForward = 0, int64_t footprintBackward = 0) { return sd::graph::CreateFlatResult( - _fbb, - id, - variables ? _fbb.CreateVector>(*variables) : 0, + _fbb, id, + variables + ? _fbb.CreateVector>(*variables) + : 0, timing ? _fbb.CreateVector>(*timing) : 0, - footprintForward, - footprintBackward); + footprintForward, footprintBackward); } inline const sd::graph::FlatResult *GetFlatResult(const void *buf) { @@ -201,8 +192,7 @@ inline const sd::graph::FlatResult *GetSizePrefixedFlatResult(const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatResultBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatResultBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } diff --git a/libnd4j/include/graph/generated/uigraphevents_generated.h b/libnd4j/include/graph/generated/uigraphevents_generated.h index b3430a5c7ca8..5d10f9d75906 100644 --- a/libnd4j/include/graph/generated/uigraphevents_generated.h +++ b/libnd4j/include/graph/generated/uigraphevents_generated.h @@ -1,12 +1,10 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_UIGRAPHEVENTS_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_UIGRAPHEVENTS_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" +#include "flatbuffers/flatbuffers.h" namespace sd { namespace graph { @@ -40,33 +38,29 @@ enum UIEventType { }; inline const UIEventType (&EnumValuesUIEventType())[9] { - static const UIEventType values[] = { - UIEventType_ADD_NAME, - UIEventType_SCALAR, - UIEventType_ARRAY, - UIEventType_ARRAY_LIST, - UIEventType_HISTOGRAM, - UIEventType_IMAGE, - UIEventType_SUMMARY_STATISTICS, - UIEventType_OP_TIMING, - UIEventType_HARDWARE_STATE - }; + static const UIEventType values[] = {UIEventType_ADD_NAME, + UIEventType_SCALAR, + UIEventType_ARRAY, + UIEventType_ARRAY_LIST, + UIEventType_HISTOGRAM, + UIEventType_IMAGE, + UIEventType_SUMMARY_STATISTICS, + UIEventType_OP_TIMING, + UIEventType_HARDWARE_STATE}; return values; } -inline const char * const *EnumNamesUIEventType() { - static const char * const names[] = { - "ADD_NAME", - "SCALAR", - "ARRAY", - "ARRAY_LIST", - "HISTOGRAM", - "IMAGE", - "SUMMARY_STATISTICS", - "OP_TIMING", - "HARDWARE_STATE", - nullptr - }; +inline const char *const *EnumNamesUIEventType() { + static const char *const names[] = {"ADD_NAME", + "SCALAR", + "ARRAY", + "ARRAY_LIST", + "HISTOGRAM", + "IMAGE", + "SUMMARY_STATISTICS", + "OP_TIMING", + "HARDWARE_STATE", + nullptr}; return names; } @@ -92,34 +86,19 @@ enum UIEventSubtype { inline const UIEventSubtype (&EnumValuesUIEventSubtype())[10] { static const UIEventSubtype values[] = { - UIEventSubtype_NONE, - UIEventSubtype_EVALUATION, - UIEventSubtype_LOSS, - UIEventSubtype_LEARNING_RATE, - UIEventSubtype_TUNING_METRIC, - UIEventSubtype_PERFORMANCE, - UIEventSubtype_PROFILING, - UIEventSubtype_FEATURE_LABEL, - UIEventSubtype_PREDICTION, - UIEventSubtype_USER_CUSTOM - }; + UIEventSubtype_NONE, UIEventSubtype_EVALUATION, + UIEventSubtype_LOSS, UIEventSubtype_LEARNING_RATE, + UIEventSubtype_TUNING_METRIC, UIEventSubtype_PERFORMANCE, + UIEventSubtype_PROFILING, UIEventSubtype_FEATURE_LABEL, + UIEventSubtype_PREDICTION, UIEventSubtype_USER_CUSTOM}; return values; } -inline const char * const *EnumNamesUIEventSubtype() { - static const char * const names[] = { - "NONE", - "EVALUATION", - "LOSS", - "LEARNING_RATE", - "TUNING_METRIC", - "PERFORMANCE", - "PROFILING", - "FEATURE_LABEL", - "PREDICTION", - "USER_CUSTOM", - nullptr - }; +inline const char *const *EnumNamesUIEventSubtype() { + static const char *const names[] = { + "NONE", "EVALUATION", "LOSS", "LEARNING_RATE", + "TUNING_METRIC", "PERFORMANCE", "PROFILING", "FEATURE_LABEL", + "PREDICTION", "USER_CUSTOM", nullptr}; return names; } @@ -137,21 +116,15 @@ enum UIHistogramType { }; inline const UIHistogramType (&EnumValuesUIHistogramType())[3] { - static const UIHistogramType values[] = { - UIHistogramType_DISCRETE, - UIHistogramType_EQUAL_SPACING, - UIHistogramType_CUSTOM - }; + static const UIHistogramType values[] = {UIHistogramType_DISCRETE, + UIHistogramType_EQUAL_SPACING, + UIHistogramType_CUSTOM}; return values; } -inline const char * const *EnumNamesUIHistogramType() { - static const char * const names[] = { - "DISCRETE", - "EQUAL_SPACING", - "CUSTOM", - nullptr - }; +inline const char *const *EnumNamesUIHistogramType() { + static const char *const names[] = {"DISCRETE", "EQUAL_SPACING", "CUSTOM", + nullptr}; return names; } @@ -178,27 +151,15 @@ struct UIEvent FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { UIEventSubtype eventSubType() const { return static_cast(GetField(VT_EVENTSUBTYPE, 0)); } - int32_t nameIdx() const { - return GetField(VT_NAMEIDX, 0); - } - int64_t timestamp() const { - return GetField(VT_TIMESTAMP, 0); - } - int32_t iteration() const { - return GetField(VT_ITERATION, 0); - } - int32_t epoch() const { - return GetField(VT_EPOCH, 0); - } - int16_t variableId() const { - return GetField(VT_VARIABLEID, 0); - } + int32_t nameIdx() const { return GetField(VT_NAMEIDX, 0); } + int64_t timestamp() const { return GetField(VT_TIMESTAMP, 0); } + int32_t iteration() const { return GetField(VT_ITERATION, 0); } + int32_t epoch() const { return GetField(VT_EPOCH, 0); } + int16_t variableId() const { return GetField(VT_VARIABLEID, 0); } const FrameIteration *frameIter() const { return GetPointer(VT_FRAMEITER); } - uint16_t plugin() const { - return GetField(VT_PLUGIN, 0); - } + uint16_t plugin() const { return GetField(VT_PLUGIN, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_EVENTTYPE) && @@ -210,8 +171,7 @@ struct UIEvent FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_VARIABLEID) && VerifyOffset(verifier, VT_FRAMEITER) && verifier.VerifyTable(frameIter()) && - VerifyField(verifier, VT_PLUGIN) && - verifier.EndTable(); + VerifyField(verifier, VT_PLUGIN) && verifier.EndTable(); } }; @@ -219,10 +179,12 @@ struct UIEventBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_eventType(UIEventType eventType) { - fbb_.AddElement(UIEvent::VT_EVENTTYPE, static_cast(eventType), 0); + fbb_.AddElement(UIEvent::VT_EVENTTYPE, + static_cast(eventType), 0); } void add_eventSubType(UIEventSubtype eventSubType) { - fbb_.AddElement(UIEvent::VT_EVENTSUBTYPE, static_cast(eventSubType), 0); + fbb_.AddElement(UIEvent::VT_EVENTSUBTYPE, + static_cast(eventSubType), 0); } void add_nameIdx(int32_t nameIdx) { fbb_.AddElement(UIEvent::VT_NAMEIDX, nameIdx, 0); @@ -245,8 +207,7 @@ struct UIEventBuilder { void add_plugin(uint16_t plugin) { fbb_.AddElement(UIEvent::VT_PLUGIN, plugin, 0); } - explicit UIEventBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit UIEventBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIEventBuilder &operator=(const UIEventBuilder &); @@ -260,13 +221,9 @@ struct UIEventBuilder { inline flatbuffers::Offset CreateUIEvent( flatbuffers::FlatBufferBuilder &_fbb, UIEventType eventType = UIEventType_ADD_NAME, - UIEventSubtype eventSubType = UIEventSubtype_NONE, - int32_t nameIdx = 0, - int64_t timestamp = 0, - int32_t iteration = 0, - int32_t epoch = 0, - int16_t variableId = 0, - flatbuffers::Offset frameIter = 0, + UIEventSubtype eventSubType = UIEventSubtype_NONE, int32_t nameIdx = 0, + int64_t timestamp = 0, int32_t iteration = 0, int32_t epoch = 0, + int16_t variableId = 0, flatbuffers::Offset frameIter = 0, uint16_t plugin = 0) { UIEventBuilder builder_(_fbb); builder_.add_timestamp(timestamp); @@ -282,22 +239,15 @@ inline flatbuffers::Offset CreateUIEvent( } struct FrameIteration FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_FRAME = 4, - VT_ITERATION = 6 - }; + enum { VT_FRAME = 4, VT_ITERATION = 6 }; const flatbuffers::String *frame() const { return GetPointer(VT_FRAME); } - uint16_t iteration() const { - return GetField(VT_ITERATION, 0); - } + uint16_t iteration() const { return GetField(VT_ITERATION, 0); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_FRAME) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_FRAME) && verifier.VerifyString(frame()) && - VerifyField(verifier, VT_ITERATION) && - verifier.EndTable(); + VerifyField(verifier, VT_ITERATION) && verifier.EndTable(); } }; @@ -311,7 +261,7 @@ struct FrameIterationBuilder { fbb_.AddElement(FrameIteration::VT_ITERATION, iteration, 0); } explicit FrameIterationBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FrameIterationBuilder &operator=(const FrameIterationBuilder &); @@ -333,31 +283,22 @@ inline flatbuffers::Offset CreateFrameIteration( } inline flatbuffers::Offset CreateFrameIterationDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const char *frame = nullptr, + flatbuffers::FlatBufferBuilder &_fbb, const char *frame = nullptr, uint16_t iteration = 0) { return sd::graph::CreateFrameIteration( - _fbb, - frame ? _fbb.CreateString(frame) : 0, - iteration); + _fbb, frame ? _fbb.CreateString(frame) : 0, iteration); } struct UIAddName FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_NAMEIDX = 4, - VT_NAME = 6 - }; - int32_t nameIdx() const { - return GetField(VT_NAMEIDX, 0); - } + enum { VT_NAMEIDX = 4, VT_NAME = 6 }; + int32_t nameIdx() const { return GetField(VT_NAMEIDX, 0); } const flatbuffers::String *name() const { return GetPointer(VT_NAME); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_NAMEIDX) && - VerifyOffset(verifier, VT_NAME) && - verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_NAME) && verifier.VerifyString(name()) && verifier.EndTable(); } }; @@ -371,8 +312,7 @@ struct UIAddNameBuilder { void add_name(flatbuffers::Offset name) { fbb_.AddOffset(UIAddName::VT_NAME, name); } - explicit UIAddNameBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit UIAddNameBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIAddNameBuilder &operator=(const UIAddNameBuilder &); @@ -384,8 +324,7 @@ struct UIAddNameBuilder { }; inline flatbuffers::Offset CreateUIAddName( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t nameIdx = 0, + flatbuffers::FlatBufferBuilder &_fbb, int32_t nameIdx = 0, flatbuffers::Offset name = 0) { UIAddNameBuilder builder_(_fbb); builder_.add_name(name); @@ -394,39 +333,35 @@ inline flatbuffers::Offset CreateUIAddName( } inline flatbuffers::Offset CreateUIAddNameDirect( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t nameIdx = 0, + flatbuffers::FlatBufferBuilder &_fbb, int32_t nameIdx = 0, const char *name = nullptr) { - return sd::graph::CreateUIAddName( - _fbb, - nameIdx, - name ? _fbb.CreateString(name) : 0); + return sd::graph::CreateUIAddName(_fbb, nameIdx, + name ? _fbb.CreateString(name) : 0); } struct FlatArrayList FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_LIST = 4 - }; + enum { VT_LIST = 4 }; const flatbuffers::Vector> *list() const { - return GetPointer> *>(VT_LIST); + return GetPointer< + const flatbuffers::Vector> *>(VT_LIST); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_LIST) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_LIST) && verifier.VerifyVector(list()) && - verifier.VerifyVectorOfTables(list()) && - verifier.EndTable(); + verifier.VerifyVectorOfTables(list()) && verifier.EndTable(); } }; struct FlatArrayListBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_list(flatbuffers::Offset>> list) { + void add_list( + flatbuffers::Offset>> + list) { fbb_.AddOffset(FlatArrayList::VT_LIST, list); } explicit FlatArrayListBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatArrayListBuilder &operator=(const FlatArrayListBuilder &); @@ -439,7 +374,8 @@ struct FlatArrayListBuilder { inline flatbuffers::Offset CreateFlatArrayList( flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset>> list = 0) { + flatbuffers::Offset>> + list = 0) { FlatArrayListBuilder builder_(_fbb); builder_.add_list(list); return builder_.Finish(); @@ -464,30 +400,26 @@ struct UIHistogram FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { UIHistogramType type() const { return static_cast(GetField(VT_TYPE, 0)); } - uint32_t numbins() const { - return GetField(VT_NUMBINS, 0); - } + uint32_t numbins() const { return GetField(VT_NUMBINS, 0); } const FlatArray *binranges() const { return GetPointer(VT_BINRANGES); } - const FlatArray *y() const { - return GetPointer(VT_Y); - } - const flatbuffers::Vector> *binlabels() const { - return GetPointer> *>(VT_BINLABELS); + const FlatArray *y() const { return GetPointer(VT_Y); } + const flatbuffers::Vector> + *binlabels() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_BINLABELS); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_TYPE) && VerifyField(verifier, VT_NUMBINS) && VerifyOffset(verifier, VT_BINRANGES) && - verifier.VerifyTable(binranges()) && - VerifyOffset(verifier, VT_Y) && - verifier.VerifyTable(y()) && - VerifyOffset(verifier, VT_BINLABELS) && + verifier.VerifyTable(binranges()) && VerifyOffset(verifier, VT_Y) && + verifier.VerifyTable(y()) && VerifyOffset(verifier, VT_BINLABELS) && verifier.VerifyVector(binlabels()) && - verifier.VerifyVectorOfStrings(binlabels()) && - verifier.EndTable(); + verifier.VerifyVectorOfStrings(binlabels()) && verifier.EndTable(); } }; @@ -506,11 +438,14 @@ struct UIHistogramBuilder { void add_y(flatbuffers::Offset y) { fbb_.AddOffset(UIHistogram::VT_Y, y); } - void add_binlabels(flatbuffers::Offset>> binlabels) { + void add_binlabels( + flatbuffers::Offset< + flatbuffers::Vector>> + binlabels) { fbb_.AddOffset(UIHistogram::VT_BINLABELS, binlabels); } explicit UIHistogramBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIHistogramBuilder &operator=(const UIHistogramBuilder &); @@ -523,11 +458,12 @@ struct UIHistogramBuilder { inline flatbuffers::Offset CreateUIHistogram( flatbuffers::FlatBufferBuilder &_fbb, - UIHistogramType type = UIHistogramType_DISCRETE, - uint32_t numbins = 0, + UIHistogramType type = UIHistogramType_DISCRETE, uint32_t numbins = 0, flatbuffers::Offset binranges = 0, flatbuffers::Offset y = 0, - flatbuffers::Offset>> binlabels = 0) { + flatbuffers::Offset< + flatbuffers::Vector>> + binlabels = 0) { UIHistogramBuilder builder_(_fbb); builder_.add_binlabels(binlabels); builder_.add_y(y); @@ -539,21 +475,20 @@ inline flatbuffers::Offset CreateUIHistogram( inline flatbuffers::Offset CreateUIHistogramDirect( flatbuffers::FlatBufferBuilder &_fbb, - UIHistogramType type = UIHistogramType_DISCRETE, - uint32_t numbins = 0, + UIHistogramType type = UIHistogramType_DISCRETE, uint32_t numbins = 0, flatbuffers::Offset binranges = 0, flatbuffers::Offset y = 0, - const std::vector> *binlabels = nullptr) { + const std::vector> *binlabels = + nullptr) { return sd::graph::CreateUIHistogram( - _fbb, - type, - numbins, - binranges, - y, - binlabels ? _fbb.CreateVector>(*binlabels) : 0); + _fbb, type, numbins, binranges, y, + binlabels ? _fbb.CreateVector>( + *binlabels) + : 0); } -struct UISummaryStatistics FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { +struct UISummaryStatistics FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { enum { VT_BITMASK = 4, VT_MIN = 6, @@ -566,51 +501,32 @@ struct UISummaryStatistics FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table VT_COUNTNAN = 20, VT_COUNTINF = 22 }; - uint32_t bitmask() const { - return GetField(VT_BITMASK, 0); - } - const FlatArray *min() const { - return GetPointer(VT_MIN); - } - const FlatArray *max() const { - return GetPointer(VT_MAX); - } - double mean() const { - return GetField(VT_MEAN, 0.0); - } - double stdev() const { - return GetField(VT_STDEV, 0.0); - } - int64_t countzero() const { - return GetField(VT_COUNTZERO, 0); - } + uint32_t bitmask() const { return GetField(VT_BITMASK, 0); } + const FlatArray *min() const { return GetPointer(VT_MIN); } + const FlatArray *max() const { return GetPointer(VT_MAX); } + double mean() const { return GetField(VT_MEAN, 0.0); } + double stdev() const { return GetField(VT_STDEV, 0.0); } + int64_t countzero() const { return GetField(VT_COUNTZERO, 0); } int64_t countpositive() const { return GetField(VT_COUNTPOSITIVE, 0); } int64_t countnegative() const { return GetField(VT_COUNTNEGATIVE, 0); } - int64_t countnan() const { - return GetField(VT_COUNTNAN, 0); - } - int64_t countinf() const { - return GetField(VT_COUNTINF, 0); - } + int64_t countnan() const { return GetField(VT_COUNTNAN, 0); } + int64_t countinf() const { return GetField(VT_COUNTINF, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_BITMASK) && - VerifyOffset(verifier, VT_MIN) && - verifier.VerifyTable(min()) && - VerifyOffset(verifier, VT_MAX) && - verifier.VerifyTable(max()) && + VerifyOffset(verifier, VT_MIN) && verifier.VerifyTable(min()) && + VerifyOffset(verifier, VT_MAX) && verifier.VerifyTable(max()) && VerifyField(verifier, VT_MEAN) && VerifyField(verifier, VT_STDEV) && VerifyField(verifier, VT_COUNTZERO) && VerifyField(verifier, VT_COUNTPOSITIVE) && VerifyField(verifier, VT_COUNTNEGATIVE) && VerifyField(verifier, VT_COUNTNAN) && - VerifyField(verifier, VT_COUNTINF) && - verifier.EndTable(); + VerifyField(verifier, VT_COUNTINF) && verifier.EndTable(); } }; @@ -636,10 +552,12 @@ struct UISummaryStatisticsBuilder { fbb_.AddElement(UISummaryStatistics::VT_COUNTZERO, countzero, 0); } void add_countpositive(int64_t countpositive) { - fbb_.AddElement(UISummaryStatistics::VT_COUNTPOSITIVE, countpositive, 0); + fbb_.AddElement(UISummaryStatistics::VT_COUNTPOSITIVE, + countpositive, 0); } void add_countnegative(int64_t countnegative) { - fbb_.AddElement(UISummaryStatistics::VT_COUNTNEGATIVE, countnegative, 0); + fbb_.AddElement(UISummaryStatistics::VT_COUNTNEGATIVE, + countnegative, 0); } void add_countnan(int64_t countnan) { fbb_.AddElement(UISummaryStatistics::VT_COUNTNAN, countnan, 0); @@ -648,7 +566,7 @@ struct UISummaryStatisticsBuilder { fbb_.AddElement(UISummaryStatistics::VT_COUNTINF, countinf, 0); } explicit UISummaryStatisticsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UISummaryStatisticsBuilder &operator=(const UISummaryStatisticsBuilder &); @@ -660,17 +578,11 @@ struct UISummaryStatisticsBuilder { }; inline flatbuffers::Offset CreateUISummaryStatistics( - flatbuffers::FlatBufferBuilder &_fbb, - uint32_t bitmask = 0, + flatbuffers::FlatBufferBuilder &_fbb, uint32_t bitmask = 0, flatbuffers::Offset min = 0, - flatbuffers::Offset max = 0, - double mean = 0.0, - double stdev = 0.0, - int64_t countzero = 0, - int64_t countpositive = 0, - int64_t countnegative = 0, - int64_t countnan = 0, - int64_t countinf = 0) { + flatbuffers::Offset max = 0, double mean = 0.0, + double stdev = 0.0, int64_t countzero = 0, int64_t countpositive = 0, + int64_t countnegative = 0, int64_t countnan = 0, int64_t countinf = 0) { UISummaryStatisticsBuilder builder_(_fbb); builder_.add_countinf(countinf); builder_.add_countnan(countnan); @@ -686,36 +598,30 @@ inline flatbuffers::Offset CreateUISummaryStatistics( } struct UIHardwareState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_GPUMEMORY = 4, - VT_HOSTMEMORY = 6 - }; + enum { VT_GPUMEMORY = 4, VT_HOSTMEMORY = 6 }; const flatbuffers::Vector *gpuMemory() const { return GetPointer *>(VT_GPUMEMORY); } - int64_t hostMemory() const { - return GetField(VT_HOSTMEMORY, 0); - } + int64_t hostMemory() const { return GetField(VT_HOSTMEMORY, 0); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_GPUMEMORY) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_GPUMEMORY) && verifier.VerifyVector(gpuMemory()) && - VerifyField(verifier, VT_HOSTMEMORY) && - verifier.EndTable(); + VerifyField(verifier, VT_HOSTMEMORY) && verifier.EndTable(); } }; struct UIHardwareStateBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_gpuMemory(flatbuffers::Offset> gpuMemory) { + void add_gpuMemory( + flatbuffers::Offset> gpuMemory) { fbb_.AddOffset(UIHardwareState::VT_GPUMEMORY, gpuMemory); } void add_hostMemory(int64_t hostMemory) { fbb_.AddElement(UIHardwareState::VT_HOSTMEMORY, hostMemory, 0); } explicit UIHardwareStateBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIHardwareStateBuilder &operator=(const UIHardwareStateBuilder &); @@ -738,12 +644,9 @@ inline flatbuffers::Offset CreateUIHardwareState( inline flatbuffers::Offset CreateUIHardwareStateDirect( flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *gpuMemory = nullptr, - int64_t hostMemory = 0) { + const std::vector *gpuMemory = nullptr, int64_t hostMemory = 0) { return sd::graph::CreateUIHardwareState( - _fbb, - gpuMemory ? _fbb.CreateVector(*gpuMemory) : 0, - hostMemory); + _fbb, gpuMemory ? _fbb.CreateVector(*gpuMemory) : 0, hostMemory); } } // namespace graph diff --git a/libnd4j/include/graph/generated/uigraphstatic_generated.h b/libnd4j/include/graph/generated/uigraphstatic_generated.h index b6545f53a5e6..e0e0e8a6547d 100644 --- a/libnd4j/include/graph/generated/uigraphstatic_generated.h +++ b/libnd4j/include/graph/generated/uigraphstatic_generated.h @@ -1,12 +1,10 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_UIGRAPHSTATIC_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_UIGRAPHSTATIC_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" +#include "flatbuffers/flatbuffers.h" #include "utils_generated.h" #include "variable_generated.h" @@ -32,21 +30,15 @@ enum UIInfoType { }; inline const UIInfoType (&EnumValuesUIInfoType())[3] { - static const UIInfoType values[] = { - UIInfoType_GRAPH_STRUCTURE, - UIInfoType_SYTEM_INFO, - UIInfoType_START_EVENTS - }; + static const UIInfoType values[] = {UIInfoType_GRAPH_STRUCTURE, + UIInfoType_SYTEM_INFO, + UIInfoType_START_EVENTS}; return values; } -inline const char * const *EnumNamesUIInfoType() { - static const char * const names[] = { - "GRAPH_STRUCTURE", - "SYTEM_INFO", - "START_EVENTS", - nullptr - }; +inline const char *const *EnumNamesUIInfoType() { + static const char *const names[] = {"GRAPH_STRUCTURE", "SYTEM_INFO", + "START_EVENTS", nullptr}; return names; } @@ -56,16 +48,13 @@ inline const char *EnumNameUIInfoType(UIInfoType e) { } struct UIStaticInfoRecord FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_INFOTYPE = 4 - }; + enum { VT_INFOTYPE = 4 }; UIInfoType infoType() const { return static_cast(GetField(VT_INFOTYPE, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_INFOTYPE) && - verifier.EndTable(); + VerifyField(verifier, VT_INFOTYPE) && verifier.EndTable(); } }; @@ -73,10 +62,11 @@ struct UIStaticInfoRecordBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_infoType(UIInfoType infoType) { - fbb_.AddElement(UIStaticInfoRecord::VT_INFOTYPE, static_cast(infoType), 0); + fbb_.AddElement(UIStaticInfoRecord::VT_INFOTYPE, + static_cast(infoType), 0); } explicit UIStaticInfoRecordBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIStaticInfoRecordBuilder &operator=(const UIStaticInfoRecordBuilder &); @@ -96,9 +86,7 @@ inline flatbuffers::Offset CreateUIStaticInfoRecord( } struct UISystemInfo FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_PHYSICALCORES = 4 - }; + enum { VT_PHYSICALCORES = 4 }; int32_t physicalCores() const { return GetField(VT_PHYSICALCORES, 0); } @@ -116,7 +104,7 @@ struct UISystemInfoBuilder { fbb_.AddElement(UISystemInfo::VT_PHYSICALCORES, physicalCores, 0); } explicit UISystemInfoBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UISystemInfoBuilder &operator=(const UISystemInfoBuilder &); @@ -128,8 +116,7 @@ struct UISystemInfoBuilder { }; inline flatbuffers::Offset CreateUISystemInfo( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t physicalCores = 0) { + flatbuffers::FlatBufferBuilder &_fbb, int32_t physicalCores = 0) { UISystemInfoBuilder builder_(_fbb); builder_.add_physicalCores(physicalCores); return builder_.Finish(); @@ -143,24 +130,35 @@ struct UIGraphStructure FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_VARIABLES = 10, VT_OPS = 12 }; - const flatbuffers::Vector> *inputs() const { - return GetPointer> *>(VT_INPUTS); + const flatbuffers::Vector> *inputs() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_INPUTS); } const flatbuffers::Vector> *inputsPair() const { - return GetPointer> *>(VT_INPUTSPAIR); - } - const flatbuffers::Vector> *outputs() const { - return GetPointer> *>(VT_OUTPUTS); - } - const flatbuffers::Vector> *variables() const { - return GetPointer> *>(VT_VARIABLES); + return GetPointer< + const flatbuffers::Vector> *>( + VT_INPUTSPAIR); + } + const flatbuffers::Vector> *outputs() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_OUTPUTS); + } + const flatbuffers::Vector> *variables() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_VARIABLES); } const flatbuffers::Vector> *ops() const { - return GetPointer> *>(VT_OPS); + return GetPointer> *>( + VT_OPS); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_INPUTS) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_INPUTS) && verifier.VerifyVector(inputs()) && verifier.VerifyVectorOfStrings(inputs()) && VerifyOffset(verifier, VT_INPUTSPAIR) && @@ -172,33 +170,41 @@ struct UIGraphStructure FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_VARIABLES) && verifier.VerifyVector(variables()) && verifier.VerifyVectorOfTables(variables()) && - VerifyOffset(verifier, VT_OPS) && - verifier.VerifyVector(ops()) && - verifier.VerifyVectorOfTables(ops()) && - verifier.EndTable(); + VerifyOffset(verifier, VT_OPS) && verifier.VerifyVector(ops()) && + verifier.VerifyVectorOfTables(ops()) && verifier.EndTable(); } }; struct UIGraphStructureBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_inputs(flatbuffers::Offset>> inputs) { + void add_inputs(flatbuffers::Offset< + flatbuffers::Vector>> + inputs) { fbb_.AddOffset(UIGraphStructure::VT_INPUTS, inputs); } - void add_inputsPair(flatbuffers::Offset>> inputsPair) { + void add_inputsPair( + flatbuffers::Offset>> + inputsPair) { fbb_.AddOffset(UIGraphStructure::VT_INPUTSPAIR, inputsPair); } - void add_outputs(flatbuffers::Offset>> outputs) { + void add_outputs( + flatbuffers::Offset< + flatbuffers::Vector>> + outputs) { fbb_.AddOffset(UIGraphStructure::VT_OUTPUTS, outputs); } - void add_variables(flatbuffers::Offset>> variables) { + void add_variables( + flatbuffers::Offset>> + variables) { fbb_.AddOffset(UIGraphStructure::VT_VARIABLES, variables); } - void add_ops(flatbuffers::Offset>> ops) { + void add_ops( + flatbuffers::Offset>> ops) { fbb_.AddOffset(UIGraphStructure::VT_OPS, ops); } explicit UIGraphStructureBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIGraphStructureBuilder &operator=(const UIGraphStructureBuilder &); @@ -211,11 +217,18 @@ struct UIGraphStructureBuilder { inline flatbuffers::Offset CreateUIGraphStructure( flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset>> inputs = 0, - flatbuffers::Offset>> inputsPair = 0, - flatbuffers::Offset>> outputs = 0, - flatbuffers::Offset>> variables = 0, - flatbuffers::Offset>> ops = 0) { + flatbuffers::Offset< + flatbuffers::Vector>> + inputs = 0, + flatbuffers::Offset>> + inputsPair = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + outputs = 0, + flatbuffers::Offset>> + variables = 0, + flatbuffers::Offset>> ops = + 0) { UIGraphStructureBuilder builder_(_fbb); builder_.add_ops(ops); builder_.add_variables(variables); @@ -227,17 +240,25 @@ inline flatbuffers::Offset CreateUIGraphStructure( inline flatbuffers::Offset CreateUIGraphStructureDirect( flatbuffers::FlatBufferBuilder &_fbb, - const std::vector> *inputs = nullptr, + const std::vector> *inputs = + nullptr, const std::vector> *inputsPair = nullptr, - const std::vector> *outputs = nullptr, + const std::vector> *outputs = + nullptr, const std::vector> *variables = nullptr, const std::vector> *ops = nullptr) { return sd::graph::CreateUIGraphStructure( _fbb, - inputs ? _fbb.CreateVector>(*inputs) : 0, - inputsPair ? _fbb.CreateVector>(*inputsPair) : 0, - outputs ? _fbb.CreateVector>(*outputs) : 0, - variables ? _fbb.CreateVector>(*variables) : 0, + inputs + ? _fbb.CreateVector>(*inputs) + : 0, + inputsPair ? _fbb.CreateVector>(*inputsPair) + : 0, + outputs ? _fbb.CreateVector>( + *outputs) + : 0, + variables ? _fbb.CreateVector>(*variables) + : 0, ops ? _fbb.CreateVector>(*ops) : 0); } @@ -257,9 +278,7 @@ struct UIVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_UILABELEXTRA = 26, VT_CONSTANTVALUE = 28 }; - const IntPair *id() const { - return GetPointer(VT_ID); - } + const IntPair *id() const { return GetPointer(VT_ID); } const flatbuffers::String *name() const { return GetPointer(VT_NAME); } @@ -272,20 +291,32 @@ struct UIVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); } - const flatbuffers::Vector> *controlDeps() const { - return GetPointer> *>(VT_CONTROLDEPS); + const flatbuffers::Vector> + *controlDeps() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPS); } const flatbuffers::String *outputOfOp() const { return GetPointer(VT_OUTPUTOFOP); } - const flatbuffers::Vector> *inputsForOp() const { - return GetPointer> *>(VT_INPUTSFOROP); + const flatbuffers::Vector> + *inputsForOp() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_INPUTSFOROP); } - const flatbuffers::Vector> *controlDepsForOp() const { - return GetPointer> *>(VT_CONTROLDEPSFOROP); + const flatbuffers::Vector> + *controlDepsForOp() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPSFOROP); } - const flatbuffers::Vector> *controlDepsForVar() const { - return GetPointer> *>(VT_CONTROLDEPSFORVAR); + const flatbuffers::Vector> + *controlDepsForVar() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPSFORVAR); } const flatbuffers::String *gradientVariable() const { return GetPointer(VT_GRADIENTVARIABLE); @@ -297,15 +328,12 @@ struct UIVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return GetPointer(VT_CONSTANTVALUE); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_ID) && - verifier.VerifyTable(id()) && - VerifyOffset(verifier, VT_NAME) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_ID) && + verifier.VerifyTable(id()) && VerifyOffset(verifier, VT_NAME) && verifier.VerifyString(name()) && VerifyField(verifier, VT_TYPE) && VerifyField(verifier, VT_DATATYPE) && - VerifyOffset(verifier, VT_SHAPE) && - verifier.VerifyVector(shape()) && + VerifyOffset(verifier, VT_SHAPE) && verifier.VerifyVector(shape()) && VerifyOffset(verifier, VT_CONTROLDEPS) && verifier.VerifyVector(controlDeps()) && verifier.VerifyVectorOfStrings(controlDeps()) && @@ -325,8 +353,7 @@ struct UIVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_UILABELEXTRA) && verifier.VerifyString(uiLabelExtra()) && VerifyOffset(verifier, VT_CONSTANTVALUE) && - verifier.VerifyTable(constantValue()) && - verifier.EndTable(); + verifier.VerifyTable(constantValue()) && verifier.EndTable(); } }; @@ -343,27 +370,41 @@ struct UIVariableBuilder { fbb_.AddElement(UIVariable::VT_TYPE, static_cast(type), 0); } void add_datatype(DType datatype) { - fbb_.AddElement(UIVariable::VT_DATATYPE, static_cast(datatype), 0); + fbb_.AddElement(UIVariable::VT_DATATYPE, + static_cast(datatype), 0); } void add_shape(flatbuffers::Offset> shape) { fbb_.AddOffset(UIVariable::VT_SHAPE, shape); } - void add_controlDeps(flatbuffers::Offset>> controlDeps) { + void add_controlDeps( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps) { fbb_.AddOffset(UIVariable::VT_CONTROLDEPS, controlDeps); } void add_outputOfOp(flatbuffers::Offset outputOfOp) { fbb_.AddOffset(UIVariable::VT_OUTPUTOFOP, outputOfOp); } - void add_inputsForOp(flatbuffers::Offset>> inputsForOp) { + void add_inputsForOp( + flatbuffers::Offset< + flatbuffers::Vector>> + inputsForOp) { fbb_.AddOffset(UIVariable::VT_INPUTSFOROP, inputsForOp); } - void add_controlDepsForOp(flatbuffers::Offset>> controlDepsForOp) { + void add_controlDepsForOp( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepsForOp) { fbb_.AddOffset(UIVariable::VT_CONTROLDEPSFOROP, controlDepsForOp); } - void add_controlDepsForVar(flatbuffers::Offset>> controlDepsForVar) { + void add_controlDepsForVar( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepsForVar) { fbb_.AddOffset(UIVariable::VT_CONTROLDEPSFORVAR, controlDepsForVar); } - void add_gradientVariable(flatbuffers::Offset gradientVariable) { + void add_gradientVariable( + flatbuffers::Offset gradientVariable) { fbb_.AddOffset(UIVariable::VT_GRADIENTVARIABLE, gradientVariable); } void add_uiLabelExtra(flatbuffers::Offset uiLabelExtra) { @@ -373,7 +414,7 @@ struct UIVariableBuilder { fbb_.AddOffset(UIVariable::VT_CONSTANTVALUE, constantValue); } explicit UIVariableBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIVariableBuilder &operator=(const UIVariableBuilder &); @@ -385,17 +426,23 @@ struct UIVariableBuilder { }; inline flatbuffers::Offset CreateUIVariable( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset id = 0, + flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset id = 0, flatbuffers::Offset name = 0, - VarType type = VarType_VARIABLE, - DType datatype = DType_INHERIT, + VarType type = VarType_VARIABLE, DType datatype = DType_INHERIT, flatbuffers::Offset> shape = 0, - flatbuffers::Offset>> controlDeps = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps = 0, flatbuffers::Offset outputOfOp = 0, - flatbuffers::Offset>> inputsForOp = 0, - flatbuffers::Offset>> controlDepsForOp = 0, - flatbuffers::Offset>> controlDepsForVar = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + inputsForOp = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepsForOp = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepsForVar = 0, flatbuffers::Offset gradientVariable = 0, flatbuffers::Offset uiLabelExtra = 0, flatbuffers::Offset constantValue = 0) { @@ -417,35 +464,40 @@ inline flatbuffers::Offset CreateUIVariable( } inline flatbuffers::Offset CreateUIVariableDirect( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset id = 0, - const char *name = nullptr, - VarType type = VarType_VARIABLE, - DType datatype = DType_INHERIT, - const std::vector *shape = nullptr, - const std::vector> *controlDeps = nullptr, + flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset id = 0, + const char *name = nullptr, VarType type = VarType_VARIABLE, + DType datatype = DType_INHERIT, const std::vector *shape = nullptr, + const std::vector> *controlDeps = + nullptr, const char *outputOfOp = nullptr, - const std::vector> *inputsForOp = nullptr, - const std::vector> *controlDepsForOp = nullptr, - const std::vector> *controlDepsForVar = nullptr, - const char *gradientVariable = nullptr, - const char *uiLabelExtra = nullptr, + const std::vector> *inputsForOp = + nullptr, + const std::vector> + *controlDepsForOp = nullptr, + const std::vector> + *controlDepsForVar = nullptr, + const char *gradientVariable = nullptr, const char *uiLabelExtra = nullptr, flatbuffers::Offset constantValue = 0) { return sd::graph::CreateUIVariable( - _fbb, - id, - name ? _fbb.CreateString(name) : 0, - type, - datatype, + _fbb, id, name ? _fbb.CreateString(name) : 0, type, datatype, shape ? _fbb.CreateVector(*shape) : 0, - controlDeps ? _fbb.CreateVector>(*controlDeps) : 0, + controlDeps ? _fbb.CreateVector>( + *controlDeps) + : 0, outputOfOp ? _fbb.CreateString(outputOfOp) : 0, - inputsForOp ? _fbb.CreateVector>(*inputsForOp) : 0, - controlDepsForOp ? _fbb.CreateVector>(*controlDepsForOp) : 0, - controlDepsForVar ? _fbb.CreateVector>(*controlDepsForVar) : 0, + inputsForOp ? _fbb.CreateVector>( + *inputsForOp) + : 0, + controlDepsForOp + ? _fbb.CreateVector>( + *controlDepsForOp) + : 0, + controlDepsForVar + ? _fbb.CreateVector>( + *controlDepsForVar) + : 0, gradientVariable ? _fbb.CreateString(gradientVariable) : 0, - uiLabelExtra ? _fbb.CreateString(uiLabelExtra) : 0, - constantValue); + uiLabelExtra ? _fbb.CreateString(uiLabelExtra) : 0, constantValue); } struct UIOp FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -463,23 +515,30 @@ struct UIOp FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::String *opName() const { return GetPointer(VT_OPNAME); } - const flatbuffers::Vector> *inputs() const { - return GetPointer> *>(VT_INPUTS); + const flatbuffers::Vector> *inputs() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_INPUTS); } - const flatbuffers::Vector> *outputs() const { - return GetPointer> *>(VT_OUTPUTS); + const flatbuffers::Vector> *outputs() + const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_OUTPUTS); } - const flatbuffers::Vector> *controlDeps() const { - return GetPointer> *>(VT_CONTROLDEPS); + const flatbuffers::Vector> + *controlDeps() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPS); } const flatbuffers::String *uiLabelExtra() const { return GetPointer(VT_UILABELEXTRA); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_NAME) && - verifier.VerifyString(name()) && - VerifyOffset(verifier, VT_OPNAME) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && VerifyOffset(verifier, VT_OPNAME) && verifier.VerifyString(opName()) && VerifyOffset(verifier, VT_INPUTS) && verifier.VerifyVector(inputs()) && @@ -491,8 +550,7 @@ struct UIOp FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyVector(controlDeps()) && verifier.VerifyVectorOfStrings(controlDeps()) && VerifyOffset(verifier, VT_UILABELEXTRA) && - verifier.VerifyString(uiLabelExtra()) && - verifier.EndTable(); + verifier.VerifyString(uiLabelExtra()) && verifier.EndTable(); } }; @@ -505,20 +563,27 @@ struct UIOpBuilder { void add_opName(flatbuffers::Offset opName) { fbb_.AddOffset(UIOp::VT_OPNAME, opName); } - void add_inputs(flatbuffers::Offset>> inputs) { + void add_inputs(flatbuffers::Offset< + flatbuffers::Vector>> + inputs) { fbb_.AddOffset(UIOp::VT_INPUTS, inputs); } - void add_outputs(flatbuffers::Offset>> outputs) { + void add_outputs( + flatbuffers::Offset< + flatbuffers::Vector>> + outputs) { fbb_.AddOffset(UIOp::VT_OUTPUTS, outputs); } - void add_controlDeps(flatbuffers::Offset>> controlDeps) { + void add_controlDeps( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps) { fbb_.AddOffset(UIOp::VT_CONTROLDEPS, controlDeps); } void add_uiLabelExtra(flatbuffers::Offset uiLabelExtra) { fbb_.AddOffset(UIOp::VT_UILABELEXTRA, uiLabelExtra); } - explicit UIOpBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit UIOpBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } UIOpBuilder &operator=(const UIOpBuilder &); @@ -533,9 +598,15 @@ inline flatbuffers::Offset CreateUIOp( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset name = 0, flatbuffers::Offset opName = 0, - flatbuffers::Offset>> inputs = 0, - flatbuffers::Offset>> outputs = 0, - flatbuffers::Offset>> controlDeps = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + inputs = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + outputs = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps = 0, flatbuffers::Offset uiLabelExtra = 0) { UIOpBuilder builder_(_fbb); builder_.add_uiLabelExtra(uiLabelExtra); @@ -548,20 +619,27 @@ inline flatbuffers::Offset CreateUIOp( } inline flatbuffers::Offset CreateUIOpDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const char *name = nullptr, + flatbuffers::FlatBufferBuilder &_fbb, const char *name = nullptr, const char *opName = nullptr, - const std::vector> *inputs = nullptr, - const std::vector> *outputs = nullptr, - const std::vector> *controlDeps = nullptr, + const std::vector> *inputs = + nullptr, + const std::vector> *outputs = + nullptr, + const std::vector> *controlDeps = + nullptr, const char *uiLabelExtra = nullptr) { return sd::graph::CreateUIOp( - _fbb, - name ? _fbb.CreateString(name) : 0, + _fbb, name ? _fbb.CreateString(name) : 0, opName ? _fbb.CreateString(opName) : 0, - inputs ? _fbb.CreateVector>(*inputs) : 0, - outputs ? _fbb.CreateVector>(*outputs) : 0, - controlDeps ? _fbb.CreateVector>(*controlDeps) : 0, + inputs + ? _fbb.CreateVector>(*inputs) + : 0, + outputs ? _fbb.CreateVector>( + *outputs) + : 0, + controlDeps ? _fbb.CreateVector>( + *controlDeps) + : 0, uiLabelExtra ? _fbb.CreateString(uiLabelExtra) : 0); } diff --git a/libnd4j/include/graph/generated/utils_generated.h b/libnd4j/include/graph/generated/utils_generated.h index 8e7896bb4956..24853534d86a 100644 --- a/libnd4j/include/graph/generated/utils_generated.h +++ b/libnd4j/include/graph/generated/utils_generated.h @@ -1,6 +1,5 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_UTILS_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_UTILS_ND4J_GRAPH_H_ @@ -50,160 +49,144 @@ enum OpType { inline const OpType (&EnumValuesOpType())[26] { static const OpType values[] = { - OpType_TRANSFORM_FLOAT, - OpType_TRANSFORM_SAME, - OpType_TRANSFORM_BOOL, - OpType_TRANSFORM_STRICT, - OpType_TRANSFORM_ANY, - OpType_REDUCE_FLOAT, - OpType_REDUCE_SAME, - OpType_REDUCE_LONG, - OpType_REDUCE_BOOL, - OpType_INDEX_REDUCE, - OpType_SCALAR, - OpType_SCALAR_BOOL, - OpType_BROADCAST, - OpType_BROADCAST_BOOL, - OpType_PAIRWISE, - OpType_PAIRWISE_BOOL, - OpType_REDUCE_3, - OpType_SUMMARYSTATS, - OpType_SHAPE, - OpType_AGGREGATION, - OpType_RANDOM, - OpType_CUSTOM, - OpType_GRAPH, - OpType_VARIABLE, - OpType_BOOLEAN, - OpType_LOGIC - }; + OpType_TRANSFORM_FLOAT, OpType_TRANSFORM_SAME, + OpType_TRANSFORM_BOOL, OpType_TRANSFORM_STRICT, + OpType_TRANSFORM_ANY, OpType_REDUCE_FLOAT, + OpType_REDUCE_SAME, OpType_REDUCE_LONG, + OpType_REDUCE_BOOL, OpType_INDEX_REDUCE, + OpType_SCALAR, OpType_SCALAR_BOOL, + OpType_BROADCAST, OpType_BROADCAST_BOOL, + OpType_PAIRWISE, OpType_PAIRWISE_BOOL, + OpType_REDUCE_3, OpType_SUMMARYSTATS, + OpType_SHAPE, OpType_AGGREGATION, + OpType_RANDOM, OpType_CUSTOM, + OpType_GRAPH, OpType_VARIABLE, + OpType_BOOLEAN, OpType_LOGIC}; return values; } -inline const char * const *EnumNamesOpType() { - static const char * const names[] = { - "TRANSFORM_FLOAT", - "TRANSFORM_SAME", - "TRANSFORM_BOOL", - "TRANSFORM_STRICT", - "TRANSFORM_ANY", - "REDUCE_FLOAT", - "REDUCE_SAME", - "REDUCE_LONG", - "REDUCE_BOOL", - "INDEX_REDUCE", - "SCALAR", - "SCALAR_BOOL", - "BROADCAST", - "BROADCAST_BOOL", - "PAIRWISE", - "PAIRWISE_BOOL", - "REDUCE_3", - "SUMMARYSTATS", - "SHAPE", - "AGGREGATION", - "RANDOM", - "CUSTOM", - "GRAPH", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "VARIABLE", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "BOOLEAN", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "LOGIC", - nullptr - }; +inline const char *const *EnumNamesOpType() { + static const char *const names[] = {"TRANSFORM_FLOAT", + "TRANSFORM_SAME", + "TRANSFORM_BOOL", + "TRANSFORM_STRICT", + "TRANSFORM_ANY", + "REDUCE_FLOAT", + "REDUCE_SAME", + "REDUCE_LONG", + "REDUCE_BOOL", + "INDEX_REDUCE", + "SCALAR", + "SCALAR_BOOL", + "BROADCAST", + "BROADCAST_BOOL", + "PAIRWISE", + "PAIRWISE_BOOL", + "REDUCE_3", + "SUMMARYSTATS", + "SHAPE", + "AGGREGATION", + "RANDOM", + "CUSTOM", + "GRAPH", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "VARIABLE", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "BOOLEAN", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "LOGIC", + nullptr}; return names; } @@ -224,24 +207,15 @@ enum InputType { inline const InputType (&EnumValuesInputType())[5] { static const InputType values[] = { - InputType_UNDEFINED, - InputType_NUMERIC, - InputType_STRINGULAR, - InputType_NUMERIC_SET, - InputType_STRINGULAR_SET - }; + InputType_UNDEFINED, InputType_NUMERIC, InputType_STRINGULAR, + InputType_NUMERIC_SET, InputType_STRINGULAR_SET}; return values; } -inline const char * const *EnumNamesInputType() { - static const char * const names[] = { - "UNDEFINED", - "NUMERIC", - "STRINGULAR", - "NUMERIC_SET", - "STRINGULAR_SET", - nullptr - }; +inline const char *const *EnumNamesInputType() { + static const char *const names[] = {"UNDEFINED", "NUMERIC", + "STRINGULAR", "NUMERIC_SET", + "STRINGULAR_SET", nullptr}; return names; } @@ -262,27 +236,16 @@ enum OpClass { }; inline const OpClass (&EnumValuesOpClass())[6] { - static const OpClass values[] = { - OpClass_TRANSFORM, - OpClass_REDUCTION, - OpClass_MULTIPLICATOR, - OpClass_GRAPH, - OpClass_CONDITIONAL, - OpClass_LOOP - }; + static const OpClass values[] = {OpClass_TRANSFORM, OpClass_REDUCTION, + OpClass_MULTIPLICATOR, OpClass_GRAPH, + OpClass_CONDITIONAL, OpClass_LOOP}; return values; } -inline const char * const *EnumNamesOpClass() { - static const char * const names[] = { - "TRANSFORM", - "REDUCTION", - "MULTIPLICATOR", - "GRAPH", - "CONDITIONAL", - "LOOP", - nullptr - }; +inline const char *const *EnumNamesOpClass() { + static const char *const names[] = { + "TRANSFORM", "REDUCTION", "MULTIPLICATOR", "GRAPH", + "CONDITIONAL", "LOOP", nullptr}; return names; } @@ -292,21 +255,13 @@ inline const char *EnumNameOpClass(OpClass e) { } struct LongPair FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_FIRST = 4, - VT_SECOND = 6 - }; - int64_t first() const { - return GetField(VT_FIRST, 0); - } - int64_t second() const { - return GetField(VT_SECOND, 0); - } + enum { VT_FIRST = 4, VT_SECOND = 6 }; + int64_t first() const { return GetField(VT_FIRST, 0); } + int64_t second() const { return GetField(VT_SECOND, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FIRST) && - VerifyField(verifier, VT_SECOND) && - verifier.EndTable(); + VerifyField(verifier, VT_SECOND) && verifier.EndTable(); } }; @@ -319,8 +274,7 @@ struct LongPairBuilder { void add_second(int64_t second) { fbb_.AddElement(LongPair::VT_SECOND, second, 0); } - explicit LongPairBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit LongPairBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } LongPairBuilder &operator=(const LongPairBuilder &); @@ -332,8 +286,7 @@ struct LongPairBuilder { }; inline flatbuffers::Offset CreateLongPair( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t first = 0, + flatbuffers::FlatBufferBuilder &_fbb, int64_t first = 0, int64_t second = 0) { LongPairBuilder builder_(_fbb); builder_.add_second(second); @@ -342,26 +295,15 @@ inline flatbuffers::Offset CreateLongPair( } struct LongTriple FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_FIRST = 4, - VT_SECOND = 6, - VT_THIRD = 8 - }; - int64_t first() const { - return GetField(VT_FIRST, 0); - } - int64_t second() const { - return GetField(VT_SECOND, 0); - } - int64_t third() const { - return GetField(VT_THIRD, 0); - } + enum { VT_FIRST = 4, VT_SECOND = 6, VT_THIRD = 8 }; + int64_t first() const { return GetField(VT_FIRST, 0); } + int64_t second() const { return GetField(VT_SECOND, 0); } + int64_t third() const { return GetField(VT_THIRD, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FIRST) && VerifyField(verifier, VT_SECOND) && - VerifyField(verifier, VT_THIRD) && - verifier.EndTable(); + VerifyField(verifier, VT_THIRD) && verifier.EndTable(); } }; @@ -378,7 +320,7 @@ struct LongTripleBuilder { fbb_.AddElement(LongTriple::VT_THIRD, third, 0); } explicit LongTripleBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } LongTripleBuilder &operator=(const LongTripleBuilder &); @@ -390,9 +332,7 @@ struct LongTripleBuilder { }; inline flatbuffers::Offset CreateLongTriple( - flatbuffers::FlatBufferBuilder &_fbb, - int64_t first = 0, - int64_t second = 0, + flatbuffers::FlatBufferBuilder &_fbb, int64_t first = 0, int64_t second = 0, int64_t third = 0) { LongTripleBuilder builder_(_fbb); builder_.add_third(third); @@ -402,21 +342,13 @@ inline flatbuffers::Offset CreateLongTriple( } struct IntPair FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_FIRST = 4, - VT_SECOND = 6 - }; - int32_t first() const { - return GetField(VT_FIRST, 0); - } - int32_t second() const { - return GetField(VT_SECOND, 0); - } + enum { VT_FIRST = 4, VT_SECOND = 6 }; + int32_t first() const { return GetField(VT_FIRST, 0); } + int32_t second() const { return GetField(VT_SECOND, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FIRST) && - VerifyField(verifier, VT_SECOND) && - verifier.EndTable(); + VerifyField(verifier, VT_SECOND) && verifier.EndTable(); } }; @@ -429,8 +361,7 @@ struct IntPairBuilder { void add_second(int32_t second) { fbb_.AddElement(IntPair::VT_SECOND, second, 0); } - explicit IntPairBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit IntPairBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } IntPairBuilder &operator=(const IntPairBuilder &); @@ -442,8 +373,7 @@ struct IntPairBuilder { }; inline flatbuffers::Offset CreateIntPair( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t first = 0, + flatbuffers::FlatBufferBuilder &_fbb, int32_t first = 0, int32_t second = 0) { IntPairBuilder builder_(_fbb); builder_.add_second(second); @@ -452,26 +382,15 @@ inline flatbuffers::Offset CreateIntPair( } struct IntTriple FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - enum { - VT_FIRST = 4, - VT_SECOND = 6, - VT_THIRD = 8 - }; - int32_t first() const { - return GetField(VT_FIRST, 0); - } - int32_t second() const { - return GetField(VT_SECOND, 0); - } - int32_t third() const { - return GetField(VT_THIRD, 0); - } + enum { VT_FIRST = 4, VT_SECOND = 6, VT_THIRD = 8 }; + int32_t first() const { return GetField(VT_FIRST, 0); } + int32_t second() const { return GetField(VT_SECOND, 0); } + int32_t third() const { return GetField(VT_THIRD, 0); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FIRST) && VerifyField(verifier, VT_SECOND) && - VerifyField(verifier, VT_THIRD) && - verifier.EndTable(); + VerifyField(verifier, VT_THIRD) && verifier.EndTable(); } }; @@ -487,8 +406,7 @@ struct IntTripleBuilder { void add_third(int32_t third) { fbb_.AddElement(IntTriple::VT_THIRD, third, 0); } - explicit IntTripleBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit IntTripleBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } IntTripleBuilder &operator=(const IntTripleBuilder &); @@ -500,9 +418,7 @@ struct IntTripleBuilder { }; inline flatbuffers::Offset CreateIntTriple( - flatbuffers::FlatBufferBuilder &_fbb, - int32_t first = 0, - int32_t second = 0, + flatbuffers::FlatBufferBuilder &_fbb, int32_t first = 0, int32_t second = 0, int32_t third = 0) { IntTripleBuilder builder_(_fbb); builder_.add_third(third); diff --git a/libnd4j/include/graph/generated/variable_generated.h b/libnd4j/include/graph/generated/variable_generated.h index a0e43a5af7b3..61fe314ee156 100644 --- a/libnd4j/include/graph/generated/variable_generated.h +++ b/libnd4j/include/graph/generated/variable_generated.h @@ -1,12 +1,10 @@ // automatically generated by the FlatBuffers compiler, do not modify - #ifndef FLATBUFFERS_GENERATED_VARIABLE_ND4J_GRAPH_H_ #define FLATBUFFERS_GENERATED_VARIABLE_ND4J_GRAPH_H_ -#include "flatbuffers/flatbuffers.h" - #include "array_generated.h" +#include "flatbuffers/flatbuffers.h" #include "utils_generated.h" namespace sd { @@ -24,23 +22,14 @@ enum VarType { }; inline const VarType (&EnumValuesVarType())[4] { - static const VarType values[] = { - VarType_VARIABLE, - VarType_CONSTANT, - VarType_ARRAY, - VarType_PLACEHOLDER - }; + static const VarType values[] = {VarType_VARIABLE, VarType_CONSTANT, + VarType_ARRAY, VarType_PLACEHOLDER}; return values; } -inline const char * const *EnumNamesVarType() { - static const char * const names[] = { - "VARIABLE", - "CONSTANT", - "ARRAY", - "PLACEHOLDER", - nullptr - }; +inline const char *const *EnumNamesVarType() { + static const char *const names[] = {"VARIABLE", "CONSTANT", "ARRAY", + "PLACEHOLDER", nullptr}; return names; } @@ -62,9 +51,7 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_CONTROLDEPFOROP = 20, VT_CONTROLDEPSFORVAR = 22 }; - const IntPair *id() const { - return GetPointer(VT_ID); - } + const IntPair *id() const { return GetPointer(VT_ID); } const flatbuffers::String *name() const { return GetPointer(VT_NAME); } @@ -77,30 +64,34 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const FlatArray *ndarray() const { return GetPointer(VT_NDARRAY); } - int32_t device() const { - return GetField(VT_DEVICE, 0); - } + int32_t device() const { return GetField(VT_DEVICE, 0); } VarType variabletype() const { return static_cast(GetField(VT_VARIABLETYPE, 0)); } - const flatbuffers::Vector> *controlDeps() const { - return GetPointer> *>(VT_CONTROLDEPS); + const flatbuffers::Vector> + *controlDeps() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPS); } - const flatbuffers::Vector> *controlDepForOp() const { - return GetPointer> *>(VT_CONTROLDEPFOROP); + const flatbuffers::Vector> + *controlDepForOp() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPFOROP); } - const flatbuffers::Vector> *controlDepsForVar() const { - return GetPointer> *>(VT_CONTROLDEPSFORVAR); + const flatbuffers::Vector> + *controlDepsForVar() const { + return GetPointer< + const flatbuffers::Vector> *>( + VT_CONTROLDEPSFORVAR); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_ID) && - verifier.VerifyTable(id()) && - VerifyOffset(verifier, VT_NAME) && + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_ID) && + verifier.VerifyTable(id()) && VerifyOffset(verifier, VT_NAME) && verifier.VerifyString(name()) && VerifyField(verifier, VT_DTYPE) && - VerifyOffset(verifier, VT_SHAPE) && - verifier.VerifyVector(shape()) && + VerifyOffset(verifier, VT_SHAPE) && verifier.VerifyVector(shape()) && VerifyOffset(verifier, VT_NDARRAY) && verifier.VerifyTable(ndarray()) && VerifyField(verifier, VT_DEVICE) && @@ -128,7 +119,8 @@ struct FlatVariableBuilder { fbb_.AddOffset(FlatVariable::VT_NAME, name); } void add_dtype(DType dtype) { - fbb_.AddElement(FlatVariable::VT_DTYPE, static_cast(dtype), 0); + fbb_.AddElement(FlatVariable::VT_DTYPE, static_cast(dtype), + 0); } void add_shape(flatbuffers::Offset> shape) { fbb_.AddOffset(FlatVariable::VT_SHAPE, shape); @@ -140,19 +132,29 @@ struct FlatVariableBuilder { fbb_.AddElement(FlatVariable::VT_DEVICE, device, 0); } void add_variabletype(VarType variabletype) { - fbb_.AddElement(FlatVariable::VT_VARIABLETYPE, static_cast(variabletype), 0); + fbb_.AddElement(FlatVariable::VT_VARIABLETYPE, + static_cast(variabletype), 0); } - void add_controlDeps(flatbuffers::Offset>> controlDeps) { + void add_controlDeps( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps) { fbb_.AddOffset(FlatVariable::VT_CONTROLDEPS, controlDeps); } - void add_controlDepForOp(flatbuffers::Offset>> controlDepForOp) { + void add_controlDepForOp( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepForOp) { fbb_.AddOffset(FlatVariable::VT_CONTROLDEPFOROP, controlDepForOp); } - void add_controlDepsForVar(flatbuffers::Offset>> controlDepsForVar) { + void add_controlDepsForVar( + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepsForVar) { fbb_.AddOffset(FlatVariable::VT_CONTROLDEPSFORVAR, controlDepsForVar); } explicit FlatVariableBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FlatVariableBuilder &operator=(const FlatVariableBuilder &); @@ -164,17 +166,21 @@ struct FlatVariableBuilder { }; inline flatbuffers::Offset CreateFlatVariable( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset id = 0, + flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset id = 0, flatbuffers::Offset name = 0, DType dtype = DType_INHERIT, flatbuffers::Offset> shape = 0, - flatbuffers::Offset ndarray = 0, - int32_t device = 0, + flatbuffers::Offset ndarray = 0, int32_t device = 0, VarType variabletype = VarType_VARIABLE, - flatbuffers::Offset>> controlDeps = 0, - flatbuffers::Offset>> controlDepForOp = 0, - flatbuffers::Offset>> controlDepsForVar = 0) { + flatbuffers::Offset< + flatbuffers::Vector>> + controlDeps = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepForOp = 0, + flatbuffers::Offset< + flatbuffers::Vector>> + controlDepsForVar = 0) { FlatVariableBuilder builder_(_fbb); builder_.add_controlDepsForVar(controlDepsForVar); builder_.add_controlDepForOp(controlDepForOp); @@ -190,41 +196,44 @@ inline flatbuffers::Offset CreateFlatVariable( } inline flatbuffers::Offset CreateFlatVariableDirect( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset id = 0, - const char *name = nullptr, - DType dtype = DType_INHERIT, + flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset id = 0, + const char *name = nullptr, DType dtype = DType_INHERIT, const std::vector *shape = nullptr, - flatbuffers::Offset ndarray = 0, - int32_t device = 0, + flatbuffers::Offset ndarray = 0, int32_t device = 0, VarType variabletype = VarType_VARIABLE, - const std::vector> *controlDeps = nullptr, - const std::vector> *controlDepForOp = nullptr, - const std::vector> *controlDepsForVar = nullptr) { + const std::vector> *controlDeps = + nullptr, + const std::vector> + *controlDepForOp = nullptr, + const std::vector> + *controlDepsForVar = nullptr) { return sd::graph::CreateFlatVariable( - _fbb, - id, - name ? _fbb.CreateString(name) : 0, - dtype, - shape ? _fbb.CreateVector(*shape) : 0, - ndarray, - device, + _fbb, id, name ? _fbb.CreateString(name) : 0, dtype, + shape ? _fbb.CreateVector(*shape) : 0, ndarray, device, variabletype, - controlDeps ? _fbb.CreateVector>(*controlDeps) : 0, - controlDepForOp ? _fbb.CreateVector>(*controlDepForOp) : 0, - controlDepsForVar ? _fbb.CreateVector>(*controlDepsForVar) : 0); + controlDeps ? _fbb.CreateVector>( + *controlDeps) + : 0, + controlDepForOp + ? _fbb.CreateVector>( + *controlDepForOp) + : 0, + controlDepsForVar + ? _fbb.CreateVector>( + *controlDepsForVar) + : 0); } inline const sd::graph::FlatVariable *GetFlatVariable(const void *buf) { return flatbuffers::GetRoot(buf); } -inline const sd::graph::FlatVariable *GetSizePrefixedFlatVariable(const void *buf) { +inline const sd::graph::FlatVariable *GetSizePrefixedFlatVariable( + const void *buf) { return flatbuffers::GetSizePrefixedRoot(buf); } -inline bool VerifyFlatVariableBuffer( - flatbuffers::Verifier &verifier) { +inline bool VerifyFlatVariableBuffer(flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(nullptr); } diff --git a/libnd4j/include/graph/impl/ArgumentsList.cpp b/libnd4j/include/graph/impl/ArgumentsList.cpp index 71ba8c479b91..a73f9aeb52ab 100644 --- a/libnd4j/include/graph/impl/ArgumentsList.cpp +++ b/libnd4j/include/graph/impl/ArgumentsList.cpp @@ -22,25 +22,20 @@ namespace sd { namespace graph { - ArgumentsList::ArgumentsList(std::initializer_list arguments) { - _arguments = arguments; - } - - ArgumentsList::ArgumentsList(std::initializer_list arguments) { - std::vector args(arguments); - for (int e = 0; e < args.size(); e++) { - Pair pair(args[e]); - _arguments.emplace_back(pair); - } +ArgumentsList::ArgumentsList(std::initializer_list arguments) { + _arguments = arguments; +} - } +ArgumentsList::ArgumentsList(std::initializer_list arguments) { + std::vector args(arguments); + for (int e = 0; e < args.size(); e++) { + Pair pair(args[e]); + _arguments.emplace_back(pair); + } +} - int ArgumentsList::size() { - return (int) _arguments.size(); - } +int ArgumentsList::size() { return (int)_arguments.size(); } - Pair& ArgumentsList::at(int index) { - return _arguments.at(index); - } -} -} +Pair& ArgumentsList::at(int index) { return _arguments.at(index); } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index 6856dd3d046d..1fa660e59062 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -18,495 +18,503 @@ // @author raver119@gmail.com // +#include #include #include -#include -#include - namespace sd { - namespace graph { - Context::Context(const ContextPrototype& prototype, VariableSpace* variableSpace, GraphMemoryManager *memoryManager) { - _memoryManager = memoryManager; - _variableSpace = variableSpace; - - for (const auto &v: prototype.inputs()) { - this->_inputs.push_back(v); - } - - for (const auto &v: prototype.getTArguments()) { - this->_tArgs.push_back(v); - } - - for (const auto &v: prototype.getIArguments()) { - this->_iArgs.push_back(v); - } - - for (const auto &v: prototype.getBArguments()) { - this->_bArgs.push_back(v); - } - - for (const auto &v: prototype.getAxis()) { - this->_axis.push_back(v); - } - - this->_opNum = prototype.opNum(); - this->_isInplace = prototype.isInplace(); - this->_nodeId = prototype.nodeId(); - this->_name = prototype.name(); - this->_useMKLDNN = prototype.isUseMKLDNN(); - } - - Context::Context(int nodeId, VariableSpace *variableSpace) { - this->_nodeId = nodeId; - this->_variableSpace = variableSpace; - this->_isInplace = false; - this->_workspace = nullptr; - - this->_executionTime.first = 0; - this->_executionTime.second = 0; - } - - Context::Context(int nodeId, VariableSpace *variableSpace, bool isInplace) : Context(nodeId, variableSpace) { - this->_isInplace = isInplace; - } - - Context::~Context() { - this->_iArgs.clear(); - this->_tArgs.clear(); - this->_inputs.clear(); - this->_fastpath_in.clear(); - this->_fastpath_out.clear(); - - if (_context != nullptr) - delete _context; - } - - void Context::setTargetEngine(samediff::Engine engine) { - _engine = engine; - } - - void Context::attachWorkspace(sd::memory::Workspace* workspace) { - this->_workspace = workspace; - } - - void Context::setVariableSpace(VariableSpace *variableSpace) { - this->_variableSpace = variableSpace; - } - - const std::vector>& Context::fastpath_in() const { - return _fastpath_in; - } - - const std::vector>& Context::fastpath_out() const { - return _fastpath_out; - } - - bool Context::isFastPath() const { - auto ie = _fastpath_in.empty(); - auto io = _fastpath_out.empty(); - // two options here. - // either both IN/OUT are filled - auto b1 = (!ie && !io) || (!ie && _isInplace); - - // or at least something is filled, and FastPath is NOT forbidden - auto b2 = (!ie || !io) && !_forbidFastPath; - return b1 || b2; - } - - void Context::forbidFastPath(bool reallyForbid) { - _forbidFastPath = reallyForbid; - } - - VariableSpace *Context::getVariableSpace() { - return _variableSpace; - } - - sd::memory::Workspace* Context::workspace() const { - return _workspace; - } - - Stash* Context::stash() const { - return _variableSpace->stash(); - } - - Nd4jLong sd::graph::Context::outerTime() const { - return this->_executionTime.first; - } - - Nd4jLong sd::graph::Context::innerTime() const { - return this->_executionTime.second; - } - - void sd::graph::Context::setOuterTime(Nd4jLong time){ - this->_executionTime.first = time; - } - - void sd::graph::Context::setInnerTime(Nd4jLong time){ - this->_executionTime.second = time; - } - - - std::shared_ptr Context::getVariable(int idx) const { - if (idx >= this->_inputs.size()) { - nd4j_printf("Node %i; Variable [%i] requested, but only %i inputs available\n", this->_nodeId, idx, this->_inputs.size()); - throw std::runtime_error("Context: bad Variable index"); - } - - auto p = this->_inputs[idx]; - - auto v = variable(p); - - if (Environment::getInstance()->isDebugAndVerbose() && v != nullptr && v->getNDArray() != nullptr) { - auto array = v->getNDArray(); - std::string shape_ = ShapeUtils::shapeAsString(array.get()); - auto type = DataTypeUtils::asString(array->dataType()); - float m = std::numeric_limits::quiet_NaN(); - if (!array->isEmpty()) { - auto values = array->asIndexedString(16); - - nd4j_printf("Debug info for node_%i input[%i]; shape: %s; ews: [%i]; order: [%i]; dtype: [%s]; first values: %s\n", this->_nodeId, idx, shape_.c_str(), array->ews(), array->ordering(), type.c_str(), values.c_str()); - } else { - nd4j_printf("Debug info for node_%i input[%i]; shape: %s; ews: [%i]; order: [%i]; dtype: [%s]; mean value: [%f]\n", this->_nodeId, idx, shape_.c_str(), array->ews(), array->ordering(), type.c_str(), m); - } - } - - return v; - } - - std::shared_ptr Context::variable(int idx) const { - return getVariable(idx); - } - - std::shared_ptr Context::variable(std::initializer_list p) const { - if (p.size() != 2) - throw std::runtime_error("Variable address should have size of 2"); - - // FIXME: lol - std::vector vec(p); - std::pair pair(vec[0], vec[1]); - return variable(pair); - } - - std::shared_ptr Context::variable(int node, int idx) const { - std::pair pair(node, idx); - return variable(pair); - } - - std::shared_ptr Context::variable(const std::pair& p) const { - try { - return _variableSpace->getVariable(p); - } catch (std::exception &e) { - nd4j_printf("Node %i; Non-existent variable requested: [%i:%i]\n", this->_nodeId, p.first, p.second); - throw std::runtime_error("Bad variable"); - } - } - - void Context::pushNDArrayToVariableSpace(int nodeId, int index, const NDArray &array) { - std::pair pair(nodeId, index); - pushNDArrayToVariableSpace(pair, array); - } - - void Context::pushNDArrayToVariableSpace(const std::pair &pair, const NDArray &array) { - if (_variableSpace != nullptr) { - if (!_variableSpace->hasVariable(pair)) { - auto var = std::make_shared(array, "", pair.first, pair.second); - _variableSpace->putVariable(pair, var); - } else { - auto var = _variableSpace->getVariable(pair); - var->setNDArray(std::make_shared(array)); - } - } - } - - void Context::pushNDArrayListToVariableSpace(int nodeId, int index, const NDArrayList &list, bool track) { - std::pair pair(nodeId, index); - pushNDArrayListToVariableSpace(pair, list, track); - } - - void Context::pushNDArrayListToVariableSpace(const std::pair& pair, const NDArrayList &list, bool track) { - if (!_variableSpace->hasVariable(pair)) { - auto var = std::make_shared(); - var->setId(pair.first, pair.second); - var->setNDArrayList(std::make_shared(list)); - _variableSpace->putVariable(pair, var); - } else { - auto var = _variableSpace->getVariable(pair); - var->setNDArrayList(std::make_shared(list)); - } - } - - std::shared_ptr Context::ensureVariable(const std::string &name, int id, int idx) { - std::pair pair(this->nodeId(), idx); - - if (_variableSpace == nullptr) - throw std::runtime_error("Context::ensureVariable VariableSpace is NULL!"); - - if (!_variableSpace->hasVariable(pair)) { - auto var = std::make_shared(); - var->setId(this->nodeId(), idx); - auto name = this->name(); - - if (!name.empty()) - var->setName(name); - - _variableSpace->putVariable(pair, var); - return var; - } else { - return _variableSpace->getVariable(pair); - } - } - - bool Context::isValueAvailable(const std::string &name, int id, int idx ) const { - auto var = const_cast(this)->ensureVariable(name, id, idx); - - if (var->variableType() == VariableType::NDARRAY) { - return var->hasNDArray(); - } else if (var->variableType() == VariableType::ARRAY_LIST) { - return var->hasNDArrayList(); - } - - return false; - } - - std::shared_ptr Context::getNDArray(int idx) const { - return array(idx); - } - - std::shared_ptr Context::array(int idx) const { - // we check for fastpath first - if (!_fastpath_in.empty() && _fastpath_in.size() > idx) { - return _fastpath_in[idx]; - } - - // if no luck for fastpath - return whatever is available - return getVariable(idx)->getNDArray(); - } - - LaunchContext* Context::launchContext() { - //FIXME: we need proper context to be shared here - if (_context == nullptr) { - return LaunchContext::defaultContext(); - } else { - return _context; - } - } - - unsigned long Context::width() const { - if (!_fastpath_in.empty()) - return _fastpath_in.size(); - else - return _inputs.size(); - } - - void Context::setInputArray(int index, const NDArray &array) { - if (_fastpath_in.size() < index + 1) - _fastpath_in.resize(index+1); - - _fastpath_in[index] = std::make_shared(array); - } - - NDArray *Context::arrayForOp(int idx) const { - auto ptr = array(idx); - - if (ptr.get() != nullptr && ptr->undefined()) - return nullptr; - - return ptr.get(); - } - - void Context::setInputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo) { - this->setInputArray(index, buffer, const_cast(shapeInfo), specialBuffer, const_cast(specialShapeInfo)); - } - - void Context::setInputArray(int index, void *buffer, void const* shapeInfo, void *specialBuffer, void const* specialShapeInfo) { - auto array = std::make_shared(buffer, specialBuffer, reinterpret_cast(shapeInfo)); - - if (_fastpath_in.size() < index + 1) - _fastpath_in.resize(index+1); - - _fastpath_in[index] = array; - - if (_context != nullptr) - array->setContext(_context); - } - - void Context::setOutputArray(int index, const NDArray &array) { - if (_fastpath_out.size() < index + 1) - _fastpath_out.resize(index+1); +namespace graph { +Context::Context(const ContextPrototype &prototype, + VariableSpace *variableSpace, + GraphMemoryManager *memoryManager) { + _memoryManager = memoryManager; + _variableSpace = variableSpace; + + for (const auto &v : prototype.inputs()) { + this->_inputs.push_back(v); + } + + for (const auto &v : prototype.getTArguments()) { + this->_tArgs.push_back(v); + } + + for (const auto &v : prototype.getIArguments()) { + this->_iArgs.push_back(v); + } + + for (const auto &v : prototype.getBArguments()) { + this->_bArgs.push_back(v); + } + + for (const auto &v : prototype.getAxis()) { + this->_axis.push_back(v); + } + + this->_opNum = prototype.opNum(); + this->_isInplace = prototype.isInplace(); + this->_nodeId = prototype.nodeId(); + this->_name = prototype.name(); + this->_useMKLDNN = prototype.isUseMKLDNN(); +} + +Context::Context(int nodeId, VariableSpace *variableSpace) { + this->_nodeId = nodeId; + this->_variableSpace = variableSpace; + this->_isInplace = false; + this->_workspace = nullptr; + + this->_executionTime.first = 0; + this->_executionTime.second = 0; +} + +Context::Context(int nodeId, VariableSpace *variableSpace, bool isInplace) + : Context(nodeId, variableSpace) { + this->_isInplace = isInplace; +} + +Context::~Context() { + this->_iArgs.clear(); + this->_tArgs.clear(); + this->_inputs.clear(); + this->_fastpath_in.clear(); + this->_fastpath_out.clear(); + + if (_context != nullptr) delete _context; +} + +void Context::setTargetEngine(samediff::Engine engine) { _engine = engine; } + +void Context::attachWorkspace(sd::memory::Workspace *workspace) { + this->_workspace = workspace; +} + +void Context::setVariableSpace(VariableSpace *variableSpace) { + this->_variableSpace = variableSpace; +} + +const std::vector> &Context::fastpath_in() const { + return _fastpath_in; +} + +const std::vector> &Context::fastpath_out() const { + return _fastpath_out; +} + +bool Context::isFastPath() const { + auto ie = _fastpath_in.empty(); + auto io = _fastpath_out.empty(); + // two options here. + // either both IN/OUT are filled + auto b1 = (!ie && !io) || (!ie && _isInplace); + + // or at least something is filled, and FastPath is NOT forbidden + auto b2 = (!ie || !io) && !_forbidFastPath; + return b1 || b2; +} + +void Context::forbidFastPath(bool reallyForbid) { + _forbidFastPath = reallyForbid; +} + +VariableSpace *Context::getVariableSpace() { return _variableSpace; } + +sd::memory::Workspace *Context::workspace() const { return _workspace; } + +Stash *Context::stash() const { return _variableSpace->stash(); } + +Nd4jLong sd::graph::Context::outerTime() const { + return this->_executionTime.first; +} + +Nd4jLong sd::graph::Context::innerTime() const { + return this->_executionTime.second; +} - _fastpath_out[index] = std::make_shared(array); - } +void sd::graph::Context::setOuterTime(Nd4jLong time) { + this->_executionTime.first = time; +} + +void sd::graph::Context::setInnerTime(Nd4jLong time) { + this->_executionTime.second = time; +} + +std::shared_ptr Context::getVariable(int idx) const { + if (idx >= this->_inputs.size()) { + nd4j_printf( + "Node %i; Variable [%i] requested, but only %i inputs available\n", + this->_nodeId, idx, this->_inputs.size()); + throw std::runtime_error("Context: bad Variable index"); + } + + auto p = this->_inputs[idx]; + + auto v = variable(p); + + if (Environment::getInstance()->isDebugAndVerbose() && v != nullptr && + v->getNDArray() != nullptr) { + auto array = v->getNDArray(); + std::string shape_ = ShapeUtils::shapeAsString(array.get()); + auto type = DataTypeUtils::asString(array->dataType()); + float m = std::numeric_limits::quiet_NaN(); + if (!array->isEmpty()) { + auto values = array->asIndexedString(16); + + nd4j_printf( + "Debug info for node_%i input[%i]; shape: %s; ews: [%i]; order: " + "[%i]; dtype: [%s]; first values: %s\n", + this->_nodeId, idx, shape_.c_str(), array->ews(), array->ordering(), + type.c_str(), values.c_str()); + } else { + nd4j_printf( + "Debug info for node_%i input[%i]; shape: %s; ews: [%i]; order: " + "[%i]; dtype: [%s]; mean value: [%f]\n", + this->_nodeId, idx, shape_.c_str(), array->ews(), array->ordering(), + type.c_str(), m); + } + } + + return v; +} + +std::shared_ptr Context::variable(int idx) const { + return getVariable(idx); +} + +std::shared_ptr Context::variable( + std::initializer_list p) const { + if (p.size() != 2) + throw std::runtime_error("Variable address should have size of 2"); + + // FIXME: lol + std::vector vec(p); + std::pair pair(vec[0], vec[1]); + return variable(pair); +} + +std::shared_ptr Context::variable(int node, int idx) const { + std::pair pair(node, idx); + return variable(pair); +} + +std::shared_ptr Context::variable( + const std::pair &p) const { + try { + return _variableSpace->getVariable(p); + } catch (std::exception &e) { + nd4j_printf("Node %i; Non-existent variable requested: [%i:%i]\n", + this->_nodeId, p.first, p.second); + throw std::runtime_error("Bad variable"); + } +} + +void Context::pushNDArrayToVariableSpace(int nodeId, int index, + const NDArray &array) { + std::pair pair(nodeId, index); + pushNDArrayToVariableSpace(pair, array); +} - void Context::setOutputArray(int index, void *buffer, void * shapeInfo, void *specialBuffer, void * specialShapeInfo) { - this->setOutputArray(index, buffer, const_cast(shapeInfo), specialBuffer, const_cast(specialShapeInfo)); - } +void Context::pushNDArrayToVariableSpace(const std::pair &pair, + const NDArray &array) { + if (_variableSpace != nullptr) { + if (!_variableSpace->hasVariable(pair)) { + auto var = std::make_shared(array, "", pair.first, pair.second); + _variableSpace->putVariable(pair, var); + } else { + auto var = _variableSpace->getVariable(pair); + var->setNDArray(std::make_shared(array)); + } + } +} - void Context::setOutputArray(int index, void *buffer, const void * shapeInfo, void *specialBuffer, const void * specialShapeInfo) { - if (_fastpath_out.size() < index + 1) - _fastpath_out.resize(index+1); +void Context::pushNDArrayListToVariableSpace(int nodeId, int index, + const NDArrayList &list, + bool track) { + std::pair pair(nodeId, index); + pushNDArrayListToVariableSpace(pair, list, track); +} - auto array = std::make_shared(buffer, specialBuffer, reinterpret_cast(shapeInfo)); - - _fastpath_out[index] = array; +void Context::pushNDArrayListToVariableSpace(const std::pair &pair, + const NDArrayList &list, + bool track) { + if (!_variableSpace->hasVariable(pair)) { + auto var = std::make_shared(); + var->setId(pair.first, pair.second); + var->setNDArrayList(std::make_shared(list)); + _variableSpace->putVariable(pair, var); + } else { + auto var = _variableSpace->getVariable(pair); + var->setNDArrayList(std::make_shared(list)); + } +} - if (_context != nullptr) - array->setContext(_context); - } +std::shared_ptr Context::ensureVariable(const std::string &name, + int id, int idx) { + std::pair pair(this->nodeId(), idx); - void Context::setInputArray(int index, void *vdatabuffer, void const* shapeInfo, void const* specialShapeInfo) { - auto dataBuffer = reinterpret_cast(vdatabuffer); + if (_variableSpace == nullptr) + throw std::runtime_error("Context::ensureVariable VariableSpace is NULL!"); - if (_fastpath_in.size() < index + 1) - _fastpath_in.resize(index+1); + if (!_variableSpace->hasVariable(pair)) { + auto var = std::make_shared(); + var->setId(this->nodeId(), idx); + auto name = this->name(); - std::shared_ptr array; - if (dataBuffer != nullptr) - array = std::make_shared(dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast(shapeInfo)))); - else - array = std::make_shared(nullptr, nullptr, reinterpret_cast(shapeInfo)); + if (!name.empty()) var->setName(name); - _fastpath_in[index] = array; + _variableSpace->putVariable(pair, var); + return var; + } else { + return _variableSpace->getVariable(pair); + } +} - if (_context != nullptr) - array->setContext(_context); - } +bool Context::isValueAvailable(const std::string &name, int id, int idx) const { + auto var = const_cast(this)->ensureVariable(name, id, idx); - void Context::setOutputArray(int index, void *vdatabuffer, void const* shapeInfo, void const* specialShapeInfo) { - auto dataBuffer = reinterpret_cast(vdatabuffer); + if (var->variableType() == VariableType::NDARRAY) { + return var->hasNDArray(); + } else if (var->variableType() == VariableType::ARRAY_LIST) { + return var->hasNDArrayList(); + } - if (_fastpath_out.size() < index + 1) - _fastpath_out.resize(index+1); + return false; +} - std::shared_ptr array; - if (dataBuffer != nullptr) - array = std::make_shared(dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), sd::LaunchContext::defaultContext(), dataBuffer->offset() / DataTypeUtils::sizeOf(ArrayOptions::dataType(reinterpret_cast(shapeInfo)))); - else - array = std::make_shared(nullptr, nullptr, reinterpret_cast(shapeInfo)); +std::shared_ptr Context::getNDArray(int idx) const { + return array(idx); +} - _fastpath_out[index] = array; +std::shared_ptr Context::array(int idx) const { + // we check for fastpath first + if (!_fastpath_in.empty() && _fastpath_in.size() > idx) { + return _fastpath_in[idx]; + } - if (_context != nullptr) - array->setContext(_context); - } + // if no luck for fastpath - return whatever is available + return getVariable(idx)->getNDArray(); +} - void Context::setTArguments(double *arguments, int numberOfArguments) { - _tArgs.clear(); - _tArgs.reserve(numberOfArguments); - for (int e = 0; e < numberOfArguments; e++) - _tArgs.push_back(arguments[e]); - } +LaunchContext *Context::launchContext() { + // FIXME: we need proper context to be shared here + if (_context == nullptr) { + return LaunchContext::defaultContext(); + } else { + return _context; + } +} + +unsigned long Context::width() const { + if (!_fastpath_in.empty()) + return _fastpath_in.size(); + else + return _inputs.size(); +} + +void Context::setInputArray(int index, const NDArray &array) { + if (_fastpath_in.size() < index + 1) _fastpath_in.resize(index + 1); - void Context::setIArguments(Nd4jLong *arguments, int numberOfArguments) { - _iArgs.clear(); - _iArgs.reserve(numberOfArguments); - for (int e = 0; e < numberOfArguments; e++) - _iArgs.push_back(arguments[e]); - } + _fastpath_in[index] = std::make_shared(array); +} + +NDArray *Context::arrayForOp(int idx) const { + auto ptr = array(idx); - void Context::setBArguments(bool *arguments, int numberOfArguments) { - _bArgs.clear(); - _bArgs.reserve(numberOfArguments); - for (int e = 0; e < numberOfArguments; e++) - _bArgs.push_back(arguments[e]); - } + if (ptr.get() != nullptr && ptr->undefined()) return nullptr; + + return ptr.get(); +} - void Context::setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer) { +void Context::setInputArray(int index, void *buffer, void *shapeInfo, + void *specialBuffer, void *specialShapeInfo) { + this->setInputArray(index, buffer, const_cast(shapeInfo), + specialBuffer, + const_cast(specialShapeInfo)); +} + +void Context::setInputArray(int index, void *buffer, void const *shapeInfo, + void *specialBuffer, void const *specialShapeInfo) { + auto array = std::make_shared( + buffer, specialBuffer, reinterpret_cast(shapeInfo)); + + if (_fastpath_in.size() < index + 1) _fastpath_in.resize(index + 1); + + _fastpath_in[index] = array; + + if (_context != nullptr) array->setContext(_context); +} + +void Context::setOutputArray(int index, const NDArray &array) { + if (_fastpath_out.size() < index + 1) _fastpath_out.resize(index + 1); + + _fastpath_out[index] = std::make_shared(array); +} + +void Context::setOutputArray(int index, void *buffer, void *shapeInfo, + void *specialBuffer, void *specialShapeInfo) { + this->setOutputArray(index, buffer, const_cast(shapeInfo), + specialBuffer, + const_cast(specialShapeInfo)); +} + +void Context::setOutputArray(int index, void *buffer, const void *shapeInfo, + void *specialBuffer, + const void *specialShapeInfo) { + if (_fastpath_out.size() < index + 1) _fastpath_out.resize(index + 1); + + auto array = std::make_shared( + buffer, specialBuffer, reinterpret_cast(shapeInfo)); + + _fastpath_out[index] = array; + + if (_context != nullptr) array->setContext(_context); +} + +void Context::setInputArray(int index, void *vdatabuffer, void const *shapeInfo, + void const *specialShapeInfo) { + auto dataBuffer = reinterpret_cast(vdatabuffer); + + if (_fastpath_in.size() < index + 1) _fastpath_in.resize(index + 1); + + std::shared_ptr array; + if (dataBuffer != nullptr) + array = std::make_shared( + dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), + sd::LaunchContext::defaultContext(), + dataBuffer->offset() / + DataTypeUtils::sizeOf(ArrayOptions::dataType( + reinterpret_cast(shapeInfo)))); + else + array = std::make_shared( + nullptr, nullptr, reinterpret_cast(shapeInfo)); + + _fastpath_in[index] = array; + + if (_context != nullptr) array->setContext(_context); +} + +void Context::setOutputArray(int index, void *vdatabuffer, + void const *shapeInfo, + void const *specialShapeInfo) { + auto dataBuffer = reinterpret_cast(vdatabuffer); + + if (_fastpath_out.size() < index + 1) _fastpath_out.resize(index + 1); + + std::shared_ptr array; + if (dataBuffer != nullptr) + array = std::make_shared( + dataBuffer->dataBuffer(), reinterpret_cast(shapeInfo), + sd::LaunchContext::defaultContext(), + dataBuffer->offset() / + DataTypeUtils::sizeOf(ArrayOptions::dataType( + reinterpret_cast(shapeInfo)))); + else + array = std::make_shared( + nullptr, nullptr, reinterpret_cast(shapeInfo)); + + _fastpath_out[index] = array; + + if (_context != nullptr) array->setContext(_context); +} + +void Context::setTArguments(double *arguments, int numberOfArguments) { + _tArgs.clear(); + _tArgs.reserve(numberOfArguments); + for (int e = 0; e < numberOfArguments; e++) _tArgs.push_back(arguments[e]); +} + +void Context::setIArguments(Nd4jLong *arguments, int numberOfArguments) { + _iArgs.clear(); + _iArgs.reserve(numberOfArguments); + for (int e = 0; e < numberOfArguments; e++) _iArgs.push_back(arguments[e]); +} + +void Context::setBArguments(bool *arguments, int numberOfArguments) { + _bArgs.clear(); + _bArgs.reserve(numberOfArguments); + for (int e = 0; e < numberOfArguments; e++) _bArgs.push_back(arguments[e]); +} + +void Context::setCudaContext(Nd4jPointer cudaStream, + Nd4jPointer reductionPointer, + Nd4jPointer allocationPointer) { #ifdef __CUDABLAS__ - _context = new LaunchContext(cudaStream, reductionPointer, allocationPointer); + _context = new LaunchContext(cudaStream, reductionPointer, allocationPointer); - // FIXME: either pass handle from outside, or make sure outside we use the same handle - _context->setCublasHandle(LaunchContext::defaultContext()->getCublasHandle()); + // FIXME: either pass handle from outside, or make sure outside we use the + // same handle + _context->setCublasHandle(LaunchContext::defaultContext()->getCublasHandle()); - for (auto v: _fastpath_out) - v->setContext(_context); + for (auto v : _fastpath_out) v->setContext(_context); - for (auto v: _fastpath_in) - v->setContext(_context); + for (auto v : _fastpath_in) v->setContext(_context); #endif - } - - void Context::allowHelpers(bool reallyAllow) { - _helpersAllowed = reallyAllow; - } - - bool Context::helpersAllowed() const { - return _helpersAllowed; - } - - void Context::setTArguments(const std::vector &tArgs) { - for (auto t:tArgs) - _tArgs.emplace_back(t); - } - - void Context::setIArguments(const std::vector &iArgs) { - for (auto i:iArgs) - _iArgs.emplace_back(i); - } - - void Context::setBArguments(const std::vector &bArgs) { - for (auto b:bArgs) - _bArgs.push_back(b); - } - - void Context::setShapeFunctionOverride(bool reallyOverride) { - _shapeFunctionOverride = reallyOverride; - } - - bool Context::shapeFunctionOverride() const { - return _shapeFunctionOverride; - } - - samediff::ExecutionMode Context::executionMode() const { - return _execMode; - } - - void Context::setExecutionMode(samediff::ExecutionMode executionMode) { - _execMode = executionMode; - } - - bool Context::isTraining() const { - return _execMode == samediff::ExecutionMode::MODE_TRAINING; - } - - bool Context::isInference() const { - return _execMode == samediff::ExecutionMode::MODE_INFERENCE; - } - - void Context::setDArguments(sd::DataType *arguments, int numberOfArguments) { - _dArgs.clear(); - for (int e = 0; e < numberOfArguments; e++) - _dArgs.emplace_back(arguments[e]); - } - - void Context::setDArguments(const std::vector &dArgs) { - _dArgs.clear(); - for (auto d:dArgs) - _dArgs.emplace_back(d); - } - - void Context::clearFastPath() { - _fastpath_in.clear(); - _fastpath_out.clear(); - } - - const GraphMemoryManager &Context::memoryManager() const { - return *_memoryManager; - } - - void Context::setInputArray(int index, const std::shared_ptr &array) { - if (_fastpath_in.size() < index + 1) - _fastpath_in.resize(index+1); - - _fastpath_in[index] = array; - } - - void Context::setOutputArray(int index, const std::shared_ptr &array) { - if (_fastpath_out.size() < index + 1) - _fastpath_out.resize(index+1); - - _fastpath_out[index] = array; - } - } } +void Context::allowHelpers(bool reallyAllow) { _helpersAllowed = reallyAllow; } + +bool Context::helpersAllowed() const { return _helpersAllowed; } + +void Context::setTArguments(const std::vector &tArgs) { + for (auto t : tArgs) _tArgs.emplace_back(t); +} + +void Context::setIArguments(const std::vector &iArgs) { + for (auto i : iArgs) _iArgs.emplace_back(i); +} + +void Context::setBArguments(const std::vector &bArgs) { + for (auto b : bArgs) _bArgs.push_back(b); +} + +void Context::setShapeFunctionOverride(bool reallyOverride) { + _shapeFunctionOverride = reallyOverride; +} + +bool Context::shapeFunctionOverride() const { return _shapeFunctionOverride; } + +samediff::ExecutionMode Context::executionMode() const { return _execMode; } + +void Context::setExecutionMode(samediff::ExecutionMode executionMode) { + _execMode = executionMode; +} + +bool Context::isTraining() const { + return _execMode == samediff::ExecutionMode::MODE_TRAINING; +} + +bool Context::isInference() const { + return _execMode == samediff::ExecutionMode::MODE_INFERENCE; +} + +void Context::setDArguments(sd::DataType *arguments, int numberOfArguments) { + _dArgs.clear(); + for (int e = 0; e < numberOfArguments; e++) _dArgs.emplace_back(arguments[e]); +} + +void Context::setDArguments(const std::vector &dArgs) { + _dArgs.clear(); + for (auto d : dArgs) _dArgs.emplace_back(d); +} + +void Context::clearFastPath() { + _fastpath_in.clear(); + _fastpath_out.clear(); +} + +const GraphMemoryManager &Context::memoryManager() const { + return *_memoryManager; +} + +void Context::setInputArray(int index, const std::shared_ptr &array) { + if (_fastpath_in.size() < index + 1) _fastpath_in.resize(index + 1); + + _fastpath_in[index] = array; +} + +void Context::setOutputArray(int index, const std::shared_ptr &array) { + if (_fastpath_out.size() < index + 1) _fastpath_out.resize(index + 1); + + _fastpath_out[index] = array; +} +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/ContextPrototype.cpp b/libnd4j/include/graph/impl/ContextPrototype.cpp index 370a18fbd2a9..dc3409ec974e 100644 --- a/libnd4j/include/graph/impl/ContextPrototype.cpp +++ b/libnd4j/include/graph/impl/ContextPrototype.cpp @@ -18,293 +18,249 @@ // @author raver119@gmail.com // -#include +#include #include +#include #include -#include namespace sd { - namespace graph { - ContextPrototype::ContextPrototype(sd::ops::OpDescriptor* opDescriptor, int nodeId, bool inPlace) { - _nodeId = nodeId; - _isInplace = inPlace; - _opDescriptor = opDescriptor; - } - - void ContextPrototype::pickInput(const std::pair& p) { - this->_inputs.emplace_back(p); - } - - void ContextPrototype::pickInput(int input, int index) { - std::pair pair(input, index); - pickInput(pair); - } - - int ContextPrototype::opNum() const { - return this->_opNum; - } - - void ContextPrototype::setOpNum(int opNum) { - this->_opNum = opNum; - } - - const std::vector> & ContextPrototype::inputs() const { - return const_cast> &>(_inputs); - } - - void ContextPrototype::fillInputs(std::vector& inputs) { - for (int e = 0; e < inputs.size(); e++) { - auto v = inputs.at(e); - pickInput(v); - } - } - - samediff::Engine ContextPrototype::engine() const { - return _engine; - } - - bool ContextPrototype::hasVariablesFilled() const { - return this->_inputs.size() > 0; - } - - bool ContextPrototype::isInplace() const { - return this->_isInplace; - } - - const std::vector & ContextPrototype::getTArguments() const { - return const_cast&>(_tArgs); - } - - const std::vector & ContextPrototype::getIArguments() const { - return const_cast&>(_iArgs); - } - - const std::vector & ContextPrototype::getBArguments() const { - return const_cast&>(_bArgs); - } - - const std::vector & ContextPrototype::getAxis() const { - return const_cast&>(_axis); - } - - void ContextPrototype::pickInput(int input) { - std::pair pair(input, 0); - this->_inputs.emplace_back(pair); - } - - const std::pair& ContextPrototype::input(int idx) const { - return this->_inputs.at(idx); - } - - void ContextPrototype::fillInputs(std::initializer_list inputs) { - for (auto v: inputs) { - pickInput(v); - } - } - - int ContextPrototype::nodeId() const { - return getNodeId(); - } - - size_t ContextPrototype::numT() const { - return (int) _tArgs.size(); - } - - size_t ContextPrototype::numI() const { - return (int) _iArgs.size(); - } - - size_t ContextPrototype::numB() const { - return (int) _bArgs.size(); - } - - int ContextPrototype::getNodeId() const { - return this->_nodeId; - } - - /** - * This method returns number of inputs available in this block - * @return - */ - unsigned long ContextPrototype::width() const { - return this->_inputs.size(); - }; - - void ContextPrototype::markInplace(bool reallyInplace) { - this->_isInplace = reallyInplace; - } - - template - ContextPrototype* ContextPrototype::asT() { - auto clone = new ContextPrototype(_opDescriptor, _nodeId, _isInplace); - - return clone; - } - - void ContextPrototype::setOpDescriptor(sd::ops::OpDescriptor* opDescriptor) { - _opDescriptor = opDescriptor; - } - - ContextPrototype* ContextPrototype::clone() { - auto clone = new ContextPrototype(_opDescriptor, _nodeId, _isInplace); - clone->_opNum = _opNum; - - for (auto v: _inputs) - clone->_inputs.emplace_back(v); - - for (auto v: _tArgs) - clone->_tArgs.emplace_back(v); - - for (auto v: _iArgs) - clone->_iArgs.emplace_back(v); - - return clone; - } - - const std::vector & ContextPrototype::getDArguments() const { - return const_cast&>(_dArgs); - } - - size_t ContextPrototype::numD() const { - return _dArgs.size(); - } - - void ContextPrototype::appendI(const std::vector &value) { - for (auto v:value) - _iArgs.emplace_back(v); - } - - void ContextPrototype::appendT(const std::vector &value) { - for (auto v:value) - _tArgs.emplace_back(v); - } - - void ContextPrototype::appendB(const std::vector &value) { - for (auto v:value) - _bArgs.emplace_back(v); - } - - void ContextPrototype::appendD(const std::vector &value) { - for (auto v:value) - _dArgs.emplace_back(v); - } - - void ContextPrototype::appendA(Nd4jLong value) { - _axis.emplace_back(value); - } - - void ContextPrototype::appendI(Nd4jLong value) { - _iArgs.emplace_back(value); - } - - void ContextPrototype::appendT(double value) { - _tArgs.emplace_back(value); - } - - void ContextPrototype::appendB(bool value) { - _bArgs.emplace_back(value); - } - - void ContextPrototype::appendD(DataType value) { - _dArgs.emplace_back(value); - } - - ContextPrototype::ContextPrototype(const ContextPrototype &other) noexcept { - _inputs = other._inputs; - _tArgs = other._tArgs; - _iArgs = other._iArgs; - _bArgs = other._bArgs; - _dArgs = other._dArgs; - _name = other._name; - - _nodeId = other._nodeId; - _isInplace = other._isInplace; - _opNum = other._opNum; - _rootSeed = other._rootSeed; - _randomGenerator = other._randomGenerator; - _opDescriptor = other._opDescriptor; - _useMKLDNN = other._useMKLDNN; - _engine = other._engine; - _execMode = other._execMode; - } - - ContextPrototype &ContextPrototype::operator=(const ContextPrototype &other) noexcept { - if (this == &other) - return *this; - - _inputs = other._inputs; - _tArgs = other._tArgs; - _iArgs = other._iArgs; - _bArgs = other._bArgs; - _dArgs = other._dArgs; - _name = other._name; - - _nodeId = other._nodeId; - _isInplace = other._isInplace; - _opNum = other._opNum; - _rootSeed = other._rootSeed; - _randomGenerator = other._randomGenerator; - _opDescriptor = other._opDescriptor; - _useMKLDNN = other._useMKLDNN; - _engine = other._engine; - _execMode = other._execMode; - - return *this; - } - - ContextPrototype::ContextPrototype(ContextPrototype &&other) noexcept { - _inputs = std::move(other._inputs); - _tArgs = std::move(other._tArgs); - _iArgs = std::move(other._iArgs); - _bArgs = std::move(other._bArgs); - _dArgs = std::move(other._dArgs); - _name = std::move(other._name); - - _nodeId = other._nodeId; - _isInplace = other._isInplace; - _opNum = other._opNum; - _rootSeed = other._rootSeed; - _randomGenerator = other._randomGenerator; - _opDescriptor = other._opDescriptor; - _useMKLDNN = other._useMKLDNN; - _engine = other._engine; - _execMode = other._execMode; - } - - ContextPrototype &ContextPrototype::operator=(ContextPrototype &&other) noexcept { - if (this == &other) - return *this; - - _inputs = std::move(other._inputs); - _tArgs = std::move(other._tArgs); - _iArgs = std::move(other._iArgs); - _bArgs = std::move(other._bArgs); - _dArgs = std::move(other._dArgs); - _name = std::move(other._name); - - _nodeId = other._nodeId; - _isInplace = other._isInplace; - _opNum = other._opNum; - _rootSeed = other._rootSeed; - _randomGenerator = other._randomGenerator; - _opDescriptor = other._opDescriptor; - _useMKLDNN = other._useMKLDNN; - _engine = other._engine; - _execMode = other._execMode; - - return *this; - } - - void ContextPrototype::setNodeId(int id) { - _nodeId = id; - } - - std::string ContextPrototype::name() const { - return _name; - } - - void ContextPrototype::setName(const std::string &name) { - _name = name; - } - } -} \ No newline at end of file +namespace graph { +ContextPrototype::ContextPrototype(sd::ops::OpDescriptor *opDescriptor, + int nodeId, bool inPlace) { + _nodeId = nodeId; + _isInplace = inPlace; + _opDescriptor = opDescriptor; +} + +void ContextPrototype::pickInput(const std::pair &p) { + this->_inputs.emplace_back(p); +} + +void ContextPrototype::pickInput(int input, int index) { + std::pair pair(input, index); + pickInput(pair); +} + +int ContextPrototype::opNum() const { return this->_opNum; } + +void ContextPrototype::setOpNum(int opNum) { this->_opNum = opNum; } + +const std::vector> &ContextPrototype::inputs() const { + return const_cast> &>(_inputs); +} + +void ContextPrototype::fillInputs(std::vector &inputs) { + for (int e = 0; e < inputs.size(); e++) { + auto v = inputs.at(e); + pickInput(v); + } +} + +samediff::Engine ContextPrototype::engine() const { return _engine; } + +bool ContextPrototype::hasVariablesFilled() const { + return this->_inputs.size() > 0; +} + +bool ContextPrototype::isInplace() const { return this->_isInplace; } + +const std::vector &ContextPrototype::getTArguments() const { + return const_cast &>(_tArgs); +} + +const std::vector &ContextPrototype::getIArguments() const { + return const_cast &>(_iArgs); +} + +const std::vector &ContextPrototype::getBArguments() const { + return const_cast &>(_bArgs); +} + +const std::vector &ContextPrototype::getAxis() const { + return const_cast &>(_axis); +} + +void ContextPrototype::pickInput(int input) { + std::pair pair(input, 0); + this->_inputs.emplace_back(pair); +} + +const std::pair &ContextPrototype::input(int idx) const { + return this->_inputs.at(idx); +} + +void ContextPrototype::fillInputs(std::initializer_list inputs) { + for (auto v : inputs) { + pickInput(v); + } +} + +int ContextPrototype::nodeId() const { return getNodeId(); } + +size_t ContextPrototype::numT() const { return (int)_tArgs.size(); } + +size_t ContextPrototype::numI() const { return (int)_iArgs.size(); } + +size_t ContextPrototype::numB() const { return (int)_bArgs.size(); } + +int ContextPrototype::getNodeId() const { return this->_nodeId; } + +/** + * This method returns number of inputs available in this block + * @return + */ +unsigned long ContextPrototype::width() const { return this->_inputs.size(); }; + +void ContextPrototype::markInplace(bool reallyInplace) { + this->_isInplace = reallyInplace; +} + +template +ContextPrototype *ContextPrototype::asT() { + auto clone = new ContextPrototype(_opDescriptor, _nodeId, _isInplace); + + return clone; +} + +void ContextPrototype::setOpDescriptor(sd::ops::OpDescriptor *opDescriptor) { + _opDescriptor = opDescriptor; +} + +ContextPrototype *ContextPrototype::clone() { + auto clone = new ContextPrototype(_opDescriptor, _nodeId, _isInplace); + clone->_opNum = _opNum; + + for (auto v : _inputs) clone->_inputs.emplace_back(v); + + for (auto v : _tArgs) clone->_tArgs.emplace_back(v); + + for (auto v : _iArgs) clone->_iArgs.emplace_back(v); + + return clone; +} + +const std::vector &ContextPrototype::getDArguments() const { + return const_cast &>(_dArgs); +} + +size_t ContextPrototype::numD() const { return _dArgs.size(); } + +void ContextPrototype::appendI(const std::vector &value) { + for (auto v : value) _iArgs.emplace_back(v); +} + +void ContextPrototype::appendT(const std::vector &value) { + for (auto v : value) _tArgs.emplace_back(v); +} + +void ContextPrototype::appendB(const std::vector &value) { + for (auto v : value) _bArgs.emplace_back(v); +} + +void ContextPrototype::appendD(const std::vector &value) { + for (auto v : value) _dArgs.emplace_back(v); +} + +void ContextPrototype::appendA(Nd4jLong value) { _axis.emplace_back(value); } + +void ContextPrototype::appendI(Nd4jLong value) { _iArgs.emplace_back(value); } + +void ContextPrototype::appendT(double value) { _tArgs.emplace_back(value); } + +void ContextPrototype::appendB(bool value) { _bArgs.emplace_back(value); } + +void ContextPrototype::appendD(DataType value) { _dArgs.emplace_back(value); } + +ContextPrototype::ContextPrototype(const ContextPrototype &other) noexcept { + _inputs = other._inputs; + _tArgs = other._tArgs; + _iArgs = other._iArgs; + _bArgs = other._bArgs; + _dArgs = other._dArgs; + _name = other._name; + + _nodeId = other._nodeId; + _isInplace = other._isInplace; + _opNum = other._opNum; + _rootSeed = other._rootSeed; + _randomGenerator = other._randomGenerator; + _opDescriptor = other._opDescriptor; + _useMKLDNN = other._useMKLDNN; + _engine = other._engine; + _execMode = other._execMode; +} + +ContextPrototype &ContextPrototype::operator=( + const ContextPrototype &other) noexcept { + if (this == &other) return *this; + + _inputs = other._inputs; + _tArgs = other._tArgs; + _iArgs = other._iArgs; + _bArgs = other._bArgs; + _dArgs = other._dArgs; + _name = other._name; + + _nodeId = other._nodeId; + _isInplace = other._isInplace; + _opNum = other._opNum; + _rootSeed = other._rootSeed; + _randomGenerator = other._randomGenerator; + _opDescriptor = other._opDescriptor; + _useMKLDNN = other._useMKLDNN; + _engine = other._engine; + _execMode = other._execMode; + + return *this; +} + +ContextPrototype::ContextPrototype(ContextPrototype &&other) noexcept { + _inputs = std::move(other._inputs); + _tArgs = std::move(other._tArgs); + _iArgs = std::move(other._iArgs); + _bArgs = std::move(other._bArgs); + _dArgs = std::move(other._dArgs); + _name = std::move(other._name); + + _nodeId = other._nodeId; + _isInplace = other._isInplace; + _opNum = other._opNum; + _rootSeed = other._rootSeed; + _randomGenerator = other._randomGenerator; + _opDescriptor = other._opDescriptor; + _useMKLDNN = other._useMKLDNN; + _engine = other._engine; + _execMode = other._execMode; +} + +ContextPrototype &ContextPrototype::operator=( + ContextPrototype &&other) noexcept { + if (this == &other) return *this; + + _inputs = std::move(other._inputs); + _tArgs = std::move(other._tArgs); + _iArgs = std::move(other._iArgs); + _bArgs = std::move(other._bArgs); + _dArgs = std::move(other._dArgs); + _name = std::move(other._name); + + _nodeId = other._nodeId; + _isInplace = other._isInplace; + _opNum = other._opNum; + _rootSeed = other._rootSeed; + _randomGenerator = other._randomGenerator; + _opDescriptor = other._opDescriptor; + _useMKLDNN = other._useMKLDNN; + _engine = other._engine; + _execMode = other._execMode; + + return *this; +} + +void ContextPrototype::setNodeId(int id) { _nodeId = id; } + +std::string ContextPrototype::name() const { return _name; } + +void ContextPrototype::setName(const std::string &name) { _name = name; } +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/ExecutionResult.cpp b/libnd4j/include/graph/impl/ExecutionResult.cpp index 3f0fbdf7fdd4..586bc151ab6a 100644 --- a/libnd4j/include/graph/impl/ExecutionResult.cpp +++ b/libnd4j/include/graph/impl/ExecutionResult.cpp @@ -18,91 +18,88 @@ // @author raver119@gmail.com // +#include +#include #include #include -#include -#include namespace sd { - namespace graph { - ExecutionResult::ExecutionResult(const FlatResult* flatResult) { - if (flatResult->variables() != nullptr) { - for (int e = 0; e < flatResult->variables()->size(); e++) { - auto fv = flatResult->variables()->Get(e); - auto v = new Variable(fv); - this->emplace_back(v); - } - - _releasable = true; - } - } - - ExecutionResult::~ExecutionResult(){ - if (_releasable) - for (auto v : _variables) - delete v; - } - - Nd4jLong ExecutionResult::size() { - return _variables.size(); - } - - ExecutionResult::ExecutionResult(std::initializer_list variables) { - for (auto v: variables) - this->emplace_back(v); - } - - void ExecutionResult::emplace_back(Variable *variable) { - _variables.emplace_back(variable); - - if (!variable->getName().empty()) - _stringIdMap[variable->getName()] = variable; - - std::pair p(variable->id(), variable->index()); - _pairIdMap[p] = variable; - } - - Variable* ExecutionResult::at(int position) { - if (position >= _variables.size()) - throw std::runtime_error("Position index is higher then number of variables stored"); - - return _variables.at(position); - } - - Variable* ExecutionResult::byId(std::string &id) { - if (_stringIdMap.count(id) == 0) - throw std::runtime_error("Can't find specified ID"); - - return _stringIdMap.at(id); - } - - Variable* ExecutionResult::byId(std::pair &id) { - if (_pairIdMap.count(id) == 0) - throw std::runtime_error("Can't find specified ID"); - - return _pairIdMap.at(id); - } - - Variable* ExecutionResult::byId(int id) { - std::pair p(id, 0); - return byId(p); - } - - Variable* ExecutionResult::byId(const char *str) { - std::string p(str); - return byId(p); - } - - flatbuffers::Offset ExecutionResult::asFlatResult(flatbuffers::FlatBufferBuilder &builder) { - - std::vector> vec; - for (Variable* v : _variables) { - vec.emplace_back(v->asFlatVariable(builder)); - } - - auto vecOffset = builder.CreateVector(vec); - - return CreateFlatResult(builder, 0, vecOffset); - } +namespace graph { +ExecutionResult::ExecutionResult(const FlatResult* flatResult) { + if (flatResult->variables() != nullptr) { + for (int e = 0; e < flatResult->variables()->size(); e++) { + auto fv = flatResult->variables()->Get(e); + auto v = new Variable(fv); + this->emplace_back(v); } -} \ No newline at end of file + + _releasable = true; + } +} + +ExecutionResult::~ExecutionResult() { + if (_releasable) + for (auto v : _variables) delete v; +} + +Nd4jLong ExecutionResult::size() { return _variables.size(); } + +ExecutionResult::ExecutionResult(std::initializer_list variables) { + for (auto v : variables) this->emplace_back(v); +} + +void ExecutionResult::emplace_back(Variable* variable) { + _variables.emplace_back(variable); + + if (!variable->getName().empty()) + _stringIdMap[variable->getName()] = variable; + + std::pair p(variable->id(), variable->index()); + _pairIdMap[p] = variable; +} + +Variable* ExecutionResult::at(int position) { + if (position >= _variables.size()) + throw std::runtime_error( + "Position index is higher then number of variables stored"); + + return _variables.at(position); +} + +Variable* ExecutionResult::byId(std::string& id) { + if (_stringIdMap.count(id) == 0) + throw std::runtime_error("Can't find specified ID"); + + return _stringIdMap.at(id); +} + +Variable* ExecutionResult::byId(std::pair& id) { + if (_pairIdMap.count(id) == 0) + throw std::runtime_error("Can't find specified ID"); + + return _pairIdMap.at(id); +} + +Variable* ExecutionResult::byId(int id) { + std::pair p(id, 0); + return byId(p); +} + +Variable* ExecutionResult::byId(const char* str) { + std::string p(str); + return byId(p); +} + +flatbuffers::Offset ExecutionResult::asFlatResult( + flatbuffers::FlatBufferBuilder& builder) { + std::vector> vec; + for (Variable* v : _variables) { + vec.emplace_back(v->asFlatVariable(builder)); + } + + auto vecOffset = builder.CreateVector(vec); + + return CreateFlatResult(builder, 0, vecOffset); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/ExecutorConfiguration.cpp b/libnd4j/include/graph/impl/ExecutorConfiguration.cpp index 0a578848980f..f189fa850d3c 100644 --- a/libnd4j/include/graph/impl/ExecutorConfiguration.cpp +++ b/libnd4j/include/graph/impl/ExecutorConfiguration.cpp @@ -21,30 +21,35 @@ #include namespace sd { - namespace graph { - ExecutorConfiguration::ExecutorConfiguration(const sd::graph::FlatConfiguration *conf) { - if (conf != nullptr) { - _profilingMode = conf->profilingMode(); - _executionMode = conf->executionMode(); - _outputMode = conf->outputMode(); - _timestats = conf->timestats(); - _footprintForward = conf->footprintForward(); - _footprintBackward = conf->footprintBackward(); - _direction = conf->direction(); - } else { - _profilingMode = ProfilingMode_NONE; - _executionMode = ExecutionMode_SEQUENTIAL; - _outputMode = OutputMode_IMPLICIT; - _timestats = false; - } - }; +namespace graph { +ExecutorConfiguration::ExecutorConfiguration( + const sd::graph::FlatConfiguration *conf) { + if (conf != nullptr) { + _profilingMode = conf->profilingMode(); + _executionMode = conf->executionMode(); + _outputMode = conf->outputMode(); + _timestats = conf->timestats(); + _footprintForward = conf->footprintForward(); + _footprintBackward = conf->footprintBackward(); + _direction = conf->direction(); + } else { + _profilingMode = ProfilingMode_NONE; + _executionMode = ExecutionMode_SEQUENTIAL; + _outputMode = OutputMode_IMPLICIT; + _timestats = false; + } +}; - ExecutorConfiguration ExecutorConfiguration::clone() const { - return ExecutorConfiguration(*this); - }; +ExecutorConfiguration ExecutorConfiguration::clone() const { + return ExecutorConfiguration(*this); +}; - flatbuffers::Offset ExecutorConfiguration::asFlatConfiguration(flatbuffers::FlatBufferBuilder &builder) { - return CreateFlatConfiguration(builder, 0, _executionMode, _profilingMode, _outputMode, _timestats, _footprintBackward, _footprintBackward); - } - } -} \ No newline at end of file +flatbuffers::Offset +ExecutorConfiguration::asFlatConfiguration( + flatbuffers::FlatBufferBuilder &builder) { + return CreateFlatConfiguration(builder, 0, _executionMode, _profilingMode, + _outputMode, _timestats, _footprintBackward, + _footprintBackward); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/FlatUtils.cpp b/libnd4j/include/graph/impl/FlatUtils.cpp index 2ca7b9bbf697..40378847cff3 100644 --- a/libnd4j/include/graph/impl/FlatUtils.cpp +++ b/libnd4j/include/graph/impl/FlatUtils.cpp @@ -18,99 +18,106 @@ // Created by raver119 on 22.11.2017. // -#include #include +#include #include #include -#include #include - +#include namespace sd { - namespace graph { - std::pair FlatUtils::fromIntPair(IntPair *pair) { - return std::pair(pair->first(), pair->second()); - } - - std::pair FlatUtils::fromLongPair(LongPair *pair) { - return std::pair(pair->first(), pair->second()); - } - - NDArray FlatUtils::fromFlatArray(const sd::graph::FlatArray *flatArray) { - auto rank = static_cast(flatArray->shape()->Get(0)); - auto newShape = new Nd4jLong[shape::shapeInfoLength(rank)]; - memcpy(newShape, flatArray->shape()->data(), shape::shapeInfoByteLength(rank)); - - auto length = shape::length(newShape); - auto dtype = DataTypeUtils::fromFlatDataType(flatArray->dtype()); - - // empty arrays is special case, nothing to restore here - if (shape::isEmpty(newShape)) { - delete[] newShape; - return NDArrayFactory::empty(dtype); - } - // TODO fix UTF16 and UTF32 - if (dtype == UTF8) { - bool isBe = BitwiseUtils::isBE(); - bool canKeep = (isBe && flatArray->byteOrder() == sd::graph::ByteOrder_BE) || (!isBe && flatArray->byteOrder() == sd::graph::ByteOrder_LE); - - std::vector substrings(length); - std::vector shapeVector(rank); - for (int e = 0; e < rank; e++) - shapeVector[e] = newShape[e+1]; - - auto rawPtr = (void *)flatArray->buffer()->data(); - auto longPtr = reinterpret_cast(rawPtr); - auto charPtr = reinterpret_cast(longPtr + length + 1); - auto offsets = new Nd4jLong[length+1]; - for (Nd4jLong e = 0; e <= length; e++) { - auto o = longPtr[e]; - // FIXME: BE vs LE on partials - //auto v = canKeep ? o : BitwiseUtils::swap_bytes(o); - offsets[e] = o; - } - - for (Nd4jLong e = 0; e < length; e++) { - auto start = offsets[e]; - auto end = offsets[e+1]; - auto len = end - start; - - auto c = (char *) malloc(len+1); - CHECK_ALLOC(c, "Failed temp allocation", len + 1); - memset(c, '\0', len + 1); - memcpy(c, charPtr + start, len); - - std::string val(c); - substrings[e] = val; - free(c); - } - - delete[] offsets; - delete[] newShape; - // string order always 'c' - return NDArrayFactory::string(shapeVector, substrings); - } - - - auto newBuffer = new int8_t[length * DataTypeUtils::sizeOf(dtype)]; - - BUILD_SINGLE_SELECTOR(dtype, DataTypeConversions, ::convertType(newBuffer, (void *)flatArray->buffer()->data(), dtype, ByteOrderUtils::fromFlatByteOrder(flatArray->byteOrder()), length), LIBND4J_TYPES); - - NDArray array(newBuffer, newShape, sd::LaunchContext::defaultContext(), true); - - delete[] newShape; - return array; - } - - flatbuffers::Offset FlatUtils::toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array) { - auto byteVector = array.asByteVector(); - - auto fBuffer = builder.CreateVector(byteVector); - auto fShape = builder.CreateVector(array.getShapeInfoAsFlatVector()); - - auto bo = static_cast(BitwiseUtils::asByteOrder()); - - return CreateFlatArray(builder, fShape, fBuffer, static_cast(array.dataType()), bo); - } +namespace graph { +std::pair FlatUtils::fromIntPair(IntPair *pair) { + return std::pair(pair->first(), pair->second()); +} + +std::pair FlatUtils::fromLongPair(LongPair *pair) { + return std::pair(pair->first(), pair->second()); +} + +NDArray FlatUtils::fromFlatArray(const sd::graph::FlatArray *flatArray) { + auto rank = static_cast(flatArray->shape()->Get(0)); + auto newShape = new Nd4jLong[shape::shapeInfoLength(rank)]; + memcpy(newShape, flatArray->shape()->data(), + shape::shapeInfoByteLength(rank)); + + auto length = shape::length(newShape); + auto dtype = DataTypeUtils::fromFlatDataType(flatArray->dtype()); + + // empty arrays is special case, nothing to restore here + if (shape::isEmpty(newShape)) { + delete[] newShape; + return NDArrayFactory::empty(dtype); + } + // TODO fix UTF16 and UTF32 + if (dtype == UTF8) { + bool isBe = BitwiseUtils::isBE(); + bool canKeep = + (isBe && flatArray->byteOrder() == sd::graph::ByteOrder_BE) || + (!isBe && flatArray->byteOrder() == sd::graph::ByteOrder_LE); + + std::vector substrings(length); + std::vector shapeVector(rank); + for (int e = 0; e < rank; e++) shapeVector[e] = newShape[e + 1]; + + auto rawPtr = (void *)flatArray->buffer()->data(); + auto longPtr = reinterpret_cast(rawPtr); + auto charPtr = reinterpret_cast(longPtr + length + 1); + auto offsets = new Nd4jLong[length + 1]; + for (Nd4jLong e = 0; e <= length; e++) { + auto o = longPtr[e]; + // FIXME: BE vs LE on partials + // auto v = canKeep ? o : BitwiseUtils::swap_bytes(o); + offsets[e] = o; + } + + for (Nd4jLong e = 0; e < length; e++) { + auto start = offsets[e]; + auto end = offsets[e + 1]; + auto len = end - start; + + auto c = (char *)malloc(len + 1); + CHECK_ALLOC(c, "Failed temp allocation", len + 1); + memset(c, '\0', len + 1); + memcpy(c, charPtr + start, len); + + std::string val(c); + substrings[e] = val; + free(c); } -} \ No newline at end of file + + delete[] offsets; + delete[] newShape; + // string order always 'c' + return NDArrayFactory::string(shapeVector, substrings); + } + + auto newBuffer = new int8_t[length * DataTypeUtils::sizeOf(dtype)]; + + BUILD_SINGLE_SELECTOR( + dtype, DataTypeConversions, + ::convertType(newBuffer, (void *)flatArray->buffer()->data(), dtype, + ByteOrderUtils::fromFlatByteOrder(flatArray->byteOrder()), + length), + LIBND4J_TYPES); + + NDArray array(newBuffer, newShape, sd::LaunchContext::defaultContext(), true); + + delete[] newShape; + return array; +} + +flatbuffers::Offset FlatUtils::toFlatArray( + flatbuffers::FlatBufferBuilder &builder, NDArray &array) { + auto byteVector = array.asByteVector(); + + auto fBuffer = builder.CreateVector(byteVector); + auto fShape = builder.CreateVector(array.getShapeInfoAsFlatVector()); + + auto bo = static_cast(BitwiseUtils::asByteOrder()); + + return CreateFlatArray(builder, fShape, fBuffer, + static_cast(array.dataType()), bo); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/FlowPath.cpp b/libnd4j/include/graph/impl/FlowPath.cpp index 79fe67b30110..f4119896289b 100644 --- a/libnd4j/include/graph/impl/FlowPath.cpp +++ b/libnd4j/include/graph/impl/FlowPath.cpp @@ -21,131 +21,124 @@ #include namespace sd { - namespace graph { +namespace graph { - void FlowPath::ensureNode(int nodeId) { - if (_states.count(nodeId) == 0) { - NodeState state(nodeId); - _states[nodeId] = state; - } - } +void FlowPath::ensureNode(int nodeId) { + if (_states.count(nodeId) == 0) { + NodeState state(nodeId); + _states[nodeId] = state; + } +} - void FlowPath::ensureFrame(int frameId) { - if (_frames.count(frameId) == 0) { - FrameState state(frameId); - _frames[frameId] = state; - } - } +void FlowPath::ensureFrame(int frameId) { + if (_frames.count(frameId) == 0) { + FrameState state(frameId); + _frames[frameId] = state; + } +} - void FlowPath::setInnerTime(int nodeId, Nd4jLong time) { - ensureNode(nodeId); +void FlowPath::setInnerTime(int nodeId, Nd4jLong time) { + ensureNode(nodeId); - _states[nodeId].setInnerTime(time); - } + _states[nodeId].setInnerTime(time); +} - void FlowPath::setOuterTime(int nodeId, Nd4jLong time) { - ensureNode(nodeId); +void FlowPath::setOuterTime(int nodeId, Nd4jLong time) { + ensureNode(nodeId); - _states[nodeId].setOuterTime(time); - } + _states[nodeId].setOuterTime(time); +} - Nd4jLong FlowPath::innerTime(int nodeId) { - ensureNode(nodeId); +Nd4jLong FlowPath::innerTime(int nodeId) { + ensureNode(nodeId); - return _states[nodeId].innerTime(); - } + return _states[nodeId].innerTime(); +} - Nd4jLong FlowPath::outerTime(int nodeId) { - ensureNode(nodeId); +Nd4jLong FlowPath::outerTime(int nodeId) { + ensureNode(nodeId); - return _states[nodeId].outerTime(); - } + return _states[nodeId].outerTime(); +} - bool FlowPath::isNodeActive(int nodeId) { - ensureNode(nodeId); +bool FlowPath::isNodeActive(int nodeId) { + ensureNode(nodeId); - return _states[nodeId].isActive(); - } - - void FlowPath::markNodeActive(int nodeId, bool isActive) { - ensureNode(nodeId); + return _states[nodeId].isActive(); +} - _states[nodeId].markActive(isActive); - } +void FlowPath::markNodeActive(int nodeId, bool isActive) { + ensureNode(nodeId); - int FlowPath::branch(int nodeId){ - ensureNode(nodeId); + _states[nodeId].markActive(isActive); +} - return _states[nodeId].branch(); - } +int FlowPath::branch(int nodeId) { + ensureNode(nodeId); - void FlowPath::markBranch(int nodeId, int index) { - ensureNode(nodeId); + return _states[nodeId].branch(); +} - _states[nodeId].markBranch(index); - } +void FlowPath::markBranch(int nodeId, int index) { + ensureNode(nodeId); - bool FlowPath::isFrameActive(Nd4jLong frameId) { - ensureFrame(frameId); + _states[nodeId].markBranch(index); +} - return _frames[frameId].wasActivated(); - } +bool FlowPath::isFrameActive(Nd4jLong frameId) { + ensureFrame(frameId); - void FlowPath::markFrameActive(Nd4jLong frameId, bool isActive) { - ensureFrame(frameId); + return _frames[frameId].wasActivated(); +} - _frames[frameId].markActivated(isActive); - } +void FlowPath::markFrameActive(Nd4jLong frameId, bool isActive) { + ensureFrame(frameId); - bool FlowPath::isRewindPlanned(Nd4jLong frameId) { - return _frames[frameId].isRewindPlanned(); - } + _frames[frameId].markActivated(isActive); +} - void FlowPath::planRewind(Nd4jLong frameId, bool reallyRewind) { - _frames[frameId].planRewind(reallyRewind); - } +bool FlowPath::isRewindPlanned(Nd4jLong frameId) { + return _frames[frameId].isRewindPlanned(); +} - int FlowPath::getRewindPosition(Nd4jLong frameId) { - return _frames[frameId].getRewindPosition(); - } +void FlowPath::planRewind(Nd4jLong frameId, bool reallyRewind) { + _frames[frameId].planRewind(reallyRewind); +} - void FlowPath::setRewindPosition(Nd4jLong frameId, int position) { - _frames[frameId].setRewindPosition(position); - } +int FlowPath::getRewindPosition(Nd4jLong frameId) { + return _frames[frameId].getRewindPosition(); +} - void FlowPath::setRewindPositionOnce(Nd4jLong frameId, int position) { - _frames[frameId].setRewindPositionOnce(position); - } +void FlowPath::setRewindPosition(Nd4jLong frameId, int position) { + _frames[frameId].setRewindPosition(position); +} - void FlowPath::registerFrame(Nd4jLong frameId) { - if (_frames.count(frameId) == 0) - ensureFrame(frameId); - } +void FlowPath::setRewindPositionOnce(Nd4jLong frameId, int position) { + _frames[frameId].setRewindPositionOnce(position); +} - void FlowPath::forgetFrame(Nd4jLong frameId) { - if (_frames.count(frameId) > 0) - _frames.erase(frameId); - } +void FlowPath::registerFrame(Nd4jLong frameId) { + if (_frames.count(frameId) == 0) ensureFrame(frameId); +} - void FlowPath::incrementNumberOfCycles(Nd4jLong frameId) { - _frames[frameId].incrementNumberOfCycles(); - } +void FlowPath::forgetFrame(Nd4jLong frameId) { + if (_frames.count(frameId) > 0) _frames.erase(frameId); +} - Nd4jLong FlowPath::getNumberOfCycles(Nd4jLong frameId) { - return _frames[frameId].getNumberOfCycles(); - } +void FlowPath::incrementNumberOfCycles(Nd4jLong frameId) { + _frames[frameId].incrementNumberOfCycles(); +} +Nd4jLong FlowPath::getNumberOfCycles(Nd4jLong frameId) { + return _frames[frameId].getNumberOfCycles(); +} - bool FlowPath::wasExecuted(int nodeId) { - return _states[nodeId].wasExecuted(); - } +bool FlowPath::wasExecuted(int nodeId) { return _states[nodeId].wasExecuted(); } - void FlowPath::markExecuted(int nodeId, bool wasExecuted) { - _states[nodeId].markExecuted(wasExecuted); - } +void FlowPath::markExecuted(int nodeId, bool wasExecuted) { + _states[nodeId].markExecuted(wasExecuted); +} - GraphProfile* FlowPath::profile() { - return &_profile; - } - } -} \ No newline at end of file +GraphProfile* FlowPath::profile() { return &_profile; } +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/FrameState.cpp b/libnd4j/include/graph/impl/FrameState.cpp index d312a4f39069..fc3fbb9d4b73 100644 --- a/libnd4j/include/graph/impl/FrameState.cpp +++ b/libnd4j/include/graph/impl/FrameState.cpp @@ -20,52 +20,34 @@ #include - namespace sd { - namespace graph { - FrameState::FrameState(Nd4jLong id) { - this->_id = id; - } +namespace graph { +FrameState::FrameState(Nd4jLong id) { this->_id = id; } - int FrameState::getNumberOfCycles() { - return _numberOfCycles; - } +int FrameState::getNumberOfCycles() { return _numberOfCycles; } - void FrameState::incrementNumberOfCycles() { - ++_numberOfCycles; - } +void FrameState::incrementNumberOfCycles() { ++_numberOfCycles; } - bool FrameState::wasActivated() { - return _activated; - } +bool FrameState::wasActivated() { return _activated; } - void FrameState::markActivated(bool reallyActivated) { - _activated = reallyActivated; - } +void FrameState::markActivated(bool reallyActivated) { + _activated = reallyActivated; +} - std::string &FrameState::getFrameName() { - return _name; - } +std::string &FrameState::getFrameName() { return _name; } - bool FrameState::isRewindPlanned() { - return _rewindPlanned; - } +bool FrameState::isRewindPlanned() { return _rewindPlanned; } - int FrameState::getRewindPosition() { - return _rewindPosition; - } +int FrameState::getRewindPosition() { return _rewindPosition; } - void FrameState::setRewindPosition(int pos) { - _rewindPosition = pos; - } +void FrameState::setRewindPosition(int pos) { _rewindPosition = pos; } - void FrameState::setRewindPositionOnce(int pos) { - if (_rewindPosition < 0) - _rewindPosition = pos; - } +void FrameState::setRewindPositionOnce(int pos) { + if (_rewindPosition < 0) _rewindPosition = pos; +} - void FrameState::planRewind(bool reallyPlanning) { - _rewindPlanned = reallyPlanning; - } - } -} \ No newline at end of file +void FrameState::planRewind(bool reallyPlanning) { + _rewindPlanned = reallyPlanning; +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index d70fe8c1a7f9..65912d6d9e61 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -18,674 +18,691 @@ // @author raver119@gmail.com // -#include #include -#include +#include +#include +#include +#include #include -#include -#include -#include -#include +#include #include -#include #include #include +#include #include -#include -#include -#include - -namespace sd { - namespace graph { - const std::vector>& Graph::placeholders() const { - return _variableSpace.placeholders(); - } - - int Graph::numberOfPlaceholders() const { - return _variableSpace.numberOfPlaceholders(); - }; - - const ExecutorConfiguration& Graph::getExecutorConfiguration() const { - return _configuration; - } - - VariableSpace& Graph::variableSpace() const { - return const_cast(_variableSpace); - } - - Graph::~Graph() { - - } - - int Graph::idByName(const std::string &nodeName) const { - if (_symbolicLookupTable.count(nodeName) == 0) - throw std::runtime_error("Can't find node [" + nodeName + "]"); - - return _symbolicLookupTable.at(nodeName); - } - - void Graph::addVariable(const std::string &name, NDArray &array) { - int id = _maxId++; - _symbolicLookupTable[name] = id; - _variableSpace.putVariable(id, 0, array); - } - - void Graph::addVariable(const std::string &name, NDArray &&array) { - auto lvalue = array; - addVariable(name, lvalue); - } - - void Graph::addNode(Node &&node, const std::initializer_list &inputs) { - auto lvalue = std::move(node); - addNode(lvalue, inputs); - } - - void Graph::addNode(Node &node, const std::initializer_list &inputs) { - // temporary check. basically we're okay if Node has id defined - if (node.id() != 0) - throw std::runtime_error("Graph::addNode - Node has id defined"); - - if (node.name().empty()) { - // if name is empty we'll make up a name based on Op name - } else { - if (_symbolicLookupTable.count(node.name()) > 0) - throw std::runtime_error("Graph::addNode - Graph alread has Node [" + node.name() + "] defined"); - } - - // node must have numeric id - node.setId(_maxId++); - _symbolicLookupTable[node.name()] = node.id(); - - // converting string ids to numeric ones - for (auto &v:inputs) { - // we don't allow self-references - if (v == node.name()) - throw unresolved_input_exception::build("Graph::addNode - Node references itself", v); - - node.pickInput(idByName(v), 0); - } - - // actually storing the node. Later, topological sort will be applied on this map - _unmapped[node.id()] = node; - } +#include +#include +#include - void Graph::addNode(Node &node, const std::initializer_list &inputs) { - throw std::runtime_error("Graph::addNode() - Not implemented yet"); - } +#include - void Graph::addNode(Node &node, const std::initializer_list> &inputs) { - node.markRemovable(false); +namespace sd { +namespace graph { +const std::vector> &Graph::placeholders() const { + return _variableSpace.placeholders(); +} - throw std::runtime_error("Graph::addNode() - Not implemented yet"); - } +int Graph::numberOfPlaceholders() const { + return _variableSpace.numberOfPlaceholders(); +}; - Graph::Graph(const FlatGraph *flatGraph, const GraphMemoryManager &memoryManager) : _memoryMaager(memoryManager) { - bool trusted = flatGraph != nullptr; - - // if there was no exec configuration in flatgraph - create default one - if (flatGraph != nullptr && flatGraph->configuration() != nullptr) { - _configuration = ExecutorConfiguration(flatGraph->configuration()); - } else - _configuration = ExecutorConfiguration(); - - // if memory reqs were set - initialize workspace - if (_configuration._footprintForward > 0) { - _workspace.expandBy(_configuration._footprintForward); - } - - // parsing variables here - if (flatGraph != nullptr && flatGraph->variables() != nullptr && flatGraph->variables()->size() > 0) { - for (unsigned int e = 0; e < flatGraph->variables()->size(); e++) { - auto flatVar = flatGraph->variables()->Get(e); - std::pair pair(flatVar->id()->first(), flatVar->id()->second()); - - auto var = std::make_shared(flatVar); - if (flatVar->name() != nullptr) { - var->setName(flatVar->name()->str()); - _symbolicLookupTable[var->name()] = pair.first; - } - - _variableSpace.putVariable(pair, var); - } - } - - // at this point we expect all variables are already registered - // we're saving outputs only if explicit mode is set - if (_configuration._outputMode == OutputMode_EXPLICIT || _configuration._outputMode == OutputMode_EXPLICIT_AND_IMPLICIT) { - if (flatGraph != nullptr && flatGraph->outputs() != nullptr) { - for (unsigned int e = 0; e < flatGraph->outputs()->size(); e++) { - auto out = flatGraph->outputs()->Get(e); - std::pair vp(out->first(), out->second()); - if (!_variableSpace.hasVariable(vp)) { - nd4j_verbose("Non-existent variable requested: %i\n", out); - throw std::runtime_error("Non-existent variable requested"); - } - } - } - } - - // rolling through nodes - if (flatGraph != nullptr && flatGraph->nodes() != nullptr && flatGraph->nodes()->size() > 0) { - for (unsigned int e = 0; e < flatGraph->nodes()->size(); e++) { - auto node = flatGraph->nodes()->Get(e); - - if (node->output() == nullptr || node->output()->size() == 0) { - nd4j_verbose("Orphan node detected: %i; AutoOutput to be considered\n", node->id()); - } - - nd4j_debug("Node name: [%s]\n", node->name()->c_str()); - Node nnode(node); - // just filling list of nodes - _unmapped[nnode.id()] = nnode; - - if (!nnode.name().empty()) - _symbolicLookupTable[nnode.name()] = nnode.id(); - } - } - } +const ExecutorConfiguration &Graph::getExecutorConfiguration() const { + return _configuration; +} - /** - * This method returns total number of nodes in this graph - * @return - */ - int Graph::size() const { - return _unmapped.size(); - } +VariableSpace &Graph::variableSpace() const { + return const_cast(_variableSpace); +} - Nd4jStatus Graph::validate() { - throw std::runtime_error("Graph::validate - method not implemented"); - }; - - void Graph::printOutNode(const Node &node) const { - nd4j_printf("%i. ", node.id()); - switch(node.opType()) { - case OpType_CUSTOM: { - printf("%s; ", node.customOp()->getOpName().c_str()); - } - break; - case OpType_LOGIC: { - printf("%s; ", EnumUtils::_LogicOpToString(node.opNum())); - } - break; - default: { - printf("%s:{%i}; ", EnumUtils::_OpTypeToString(node.opType()), (int) node.opNum()); - } - } - - nd4j_printf("Inputs: [", ""); - //auto block = node->getBlock(); - for (int e = 0; e < node.input().size(); e++) { - - auto in = node.input()[e]; - printf("{%i:%i}", in.first, in.second); - if (e < node.input().size() - 1) - nd4j_printf(", ", ""); - } - - if (node.opType() == OpType_CUSTOM) { - auto ctx = node.protoContext(); - if (ctx.numI() > 0) { - printf("]; iArgs: ["); - - for (int e = 0; e < ctx.numI(); e++) { - printf("%i", ctx.getIArguments().at(e)); - if (e < ctx.getIArguments().size() - 1) - nd4j_printf(", ", ""); - } - } - } - - nd4j_printf("]; \n", ""); - - -// printf("\n"); - fflush(stdout); - } +Graph::~Graph() {} - void Graph::printOut() { - // print variables first - if (_variableSpace.totalEntries() > 0) { - nd4j_printf("\nPrinting out Variables...\n", ""); - auto vars = _variableSpace.variables(); - - for (auto &v: vars) { - if (v->hasNDArray()) { - auto shape = ShapeUtils::shapeAsString(v->getNDArray().get()); - auto values = v->getNDArray()->asString(16); - auto dtype = DataTypeUtils::asString(v->getNDArray()->dataType()); - - if (!v->getName().empty()) { - nd4j_printf("<%s> <%i:%i> dtype: %s; shape: %s; values: %s;\n", v->getName().c_str(), v->id(), v->index(), dtype.c_str(), shape.c_str(), values.c_str()); - } else { - nd4j_printf("<%i:%i> dtype: %s; shape: %s; values: %s;\n", v->id(), v->index(), dtype.c_str(), shape.c_str(), values.c_str()); - } - } else if (v->hasNDArrayList()) { - // TODO: add better NDArrayList printout - nd4j_printf("<%i:%i> holds ArrayList", v->id(), v->index()); - } - } - } - - fflush(stdout); - - if (size() > 0) { - nd4j_printf("\nPrinting out Nodes...\n", ""); - - // since we need structure - we'll print out nodes of OptimizedGraph - optimizedGraph().printOut(); - } - } +int Graph::idByName(const std::string &nodeName) const { + if (_symbolicLookupTable.count(nodeName) == 0) + throw std::runtime_error("Can't find node [" + nodeName + "]"); - Nd4jStatus Graph::validateNode(Node *node) { - // TODO: to be implemented - return ND4J_STATUS_OK; - } + return _symbolicLookupTable.at(nodeName); +} - void Graph::replaceState(VariableSpace *state, const ExecutorConfiguration &configuration) { - _variableSpace = *state; - _configuration = configuration; - } +void Graph::addVariable(const std::string &name, NDArray &array) { + int id = _maxId++; + _symbolicLookupTable[name] = id; + _variableSpace.putVariable(id, 0, array); +} - Graph Graph::cloneWithProxy() const { - Graph clone; +void Graph::addVariable(const std::string &name, NDArray &&array) { + auto lvalue = array; + addVariable(name, lvalue); +} - //clone.replaceState(new VariableProxy(&this->_variableSpace), this->_configuration); +void Graph::addNode(Node &&node, + const std::initializer_list &inputs) { + auto lvalue = std::move(node); + addNode(lvalue, inputs); +} - //return clone; - throw std::runtime_error("Graph::cloneWithProxy - Not implemented yet"); - } +void Graph::addNode(Node &node, + const std::initializer_list &inputs) { + // temporary check. basically we're okay if Node has id defined + if (node.id() != 0) + throw std::runtime_error("Graph::addNode - Node has id defined"); + + if (node.name().empty()) { + // if name is empty we'll make up a name based on Op name + } else { + if (_symbolicLookupTable.count(node.name()) > 0) + throw std::runtime_error("Graph::addNode - Graph alread has Node [" + + node.name() + "] defined"); + } + + // node must have numeric id + node.setId(_maxId++); + _symbolicLookupTable[node.name()] = node.id(); + + // converting string ids to numeric ones + for (auto &v : inputs) { + // we don't allow self-references + if (v == node.name()) + throw unresolved_input_exception::build( + "Graph::addNode - Node references itself", v); + + node.pickInput(idByName(v), 0); + } + + // actually storing the node. Later, topological sort will be applied on this + // map + _unmapped[node.id()] = node; +} - Graph* Graph::clone() const { - auto clone = new Graph(); +void Graph::addNode(Node &node, const std::initializer_list &inputs) { + throw std::runtime_error("Graph::addNode() - Not implemented yet"); +} - //clone->replaceState(&this->_variableSpace, this->_configuration.clone()); +void Graph::addNode(Node &node, + const std::initializer_list> &inputs) { + node.markRemovable(false); - throw std::runtime_error("Graph::clone - not implemented yet"); - } + throw std::runtime_error("Graph::addNode() - Not implemented yet"); +} - Nd4jLong Graph::hashCode() const { - throw std::runtime_error("Graph::hashCode - not implemented yet"); +Graph::Graph(const FlatGraph *flatGraph, + const GraphMemoryManager &memoryManager) + : _memoryMaager(memoryManager) { + bool trusted = flatGraph != nullptr; + + // if there was no exec configuration in flatgraph - create default one + if (flatGraph != nullptr && flatGraph->configuration() != nullptr) { + _configuration = ExecutorConfiguration(flatGraph->configuration()); + } else + _configuration = ExecutorConfiguration(); + + // if memory reqs were set - initialize workspace + if (_configuration._footprintForward > 0) { + _workspace.expandBy(_configuration._footprintForward); + } + + // parsing variables here + if (flatGraph != nullptr && flatGraph->variables() != nullptr && + flatGraph->variables()->size() > 0) { + for (unsigned int e = 0; e < flatGraph->variables()->size(); e++) { + auto flatVar = flatGraph->variables()->Get(e); + std::pair pair(flatVar->id()->first(), flatVar->id()->second()); + + auto var = std::make_shared(flatVar); + if (flatVar->name() != nullptr) { + var->setName(flatVar->name()->str()); + _symbolicLookupTable[var->name()] = pair.first; + } + + _variableSpace.putVariable(pair, var); + } + } + + // at this point we expect all variables are already registered + // we're saving outputs only if explicit mode is set + if (_configuration._outputMode == OutputMode_EXPLICIT || + _configuration._outputMode == OutputMode_EXPLICIT_AND_IMPLICIT) { + if (flatGraph != nullptr && flatGraph->outputs() != nullptr) { + for (unsigned int e = 0; e < flatGraph->outputs()->size(); e++) { + auto out = flatGraph->outputs()->Get(e); + std::pair vp(out->first(), out->second()); + if (!_variableSpace.hasVariable(vp)) { + nd4j_verbose("Non-existent variable requested: %i\n", out); + throw std::runtime_error("Non-existent variable requested"); } + } + } + } + + // rolling through nodes + if (flatGraph != nullptr && flatGraph->nodes() != nullptr && + flatGraph->nodes()->size() > 0) { + for (unsigned int e = 0; e < flatGraph->nodes()->size(); e++) { + auto node = flatGraph->nodes()->Get(e); + + if (node->output() == nullptr || node->output()->size() == 0) { + nd4j_verbose("Orphan node detected: %i; AutoOutput to be considered\n", + node->id()); + } + + nd4j_debug("Node name: [%s]\n", node->name()->c_str()); + Node nnode(node); + // just filling list of nodes + _unmapped[nnode.id()] = nnode; + + if (!nnode.name().empty()) + _symbolicLookupTable[nnode.name()] = nnode.id(); + } + } +} +/** + * This method returns total number of nodes in this graph + * @return + */ +int Graph::size() const { return _unmapped.size(); } + +Nd4jStatus Graph::validate() { + throw std::runtime_error("Graph::validate - method not implemented"); +}; + +void Graph::printOutNode(const Node &node) const { + nd4j_printf("%i. ", node.id()); + switch (node.opType()) { + case OpType_CUSTOM: { + printf("%s; ", node.customOp()->getOpName().c_str()); + } break; + case OpType_LOGIC: { + printf("%s; ", EnumUtils::_LogicOpToString(node.opNum())); + } break; + default: { + printf("%s:{%i}; ", EnumUtils::_OpTypeToString(node.opType()), + (int)node.opNum()); + } + } + + nd4j_printf("Inputs: [", ""); + // auto block = node->getBlock(); + for (int e = 0; e < node.input().size(); e++) { + auto in = node.input()[e]; + printf("{%i:%i}", in.first, in.second); + if (e < node.input().size() - 1) nd4j_printf(", ", ""); + } + + if (node.opType() == OpType_CUSTOM) { + auto ctx = node.protoContext(); + if (ctx.numI() > 0) { + printf("]; iArgs: ["); + + for (int e = 0; e < ctx.numI(); e++) { + printf("%i", ctx.getIArguments().at(e)); + if (e < ctx.getIArguments().size() - 1) nd4j_printf(", ", ""); + } + } + } - Graph Graph::fromFlatBuffers(const char* fileName, const GraphMemoryManager &memoryManager) { - // check if file exists - if (!FileUtils::fileExists(fileName)) - throw std::runtime_error("Graph file doesn't exist"); - - // get file size - auto fsize = FileUtils::fileSize(fileName); - Nd4jLong *ref; - void *ptrGraph; - - // TODO: check if mmap is supported - if (true) { - // mmap this file - ref = ::mmapFile(nullptr, fileName, fsize); - ptrGraph = reinterpret_cast(ref[0]); - } else { - // if mmap is not supported - load it directly - - ptrGraph = new uint8_t[fsize]; - auto data = reinterpret_cast(ptrGraph); - - FILE *in = fopen(fileName, "rb"); - int cnt = 0; - int b = 0; - while (cnt < fsize) { - b = fread(data + cnt, 1, fsize < 16384 ? fsize : 16384, in); + nd4j_printf("]; \n", ""); - cnt += b; - } - fclose(in); - } + // printf("\n"); + fflush(stdout); +} - return fromFlatPointer(ptrGraph, memoryManager); +void Graph::printOut() { + // print variables first + if (_variableSpace.totalEntries() > 0) { + nd4j_printf("\nPrinting out Variables...\n", ""); + auto vars = _variableSpace.variables(); + + for (auto &v : vars) { + if (v->hasNDArray()) { + auto shape = ShapeUtils::shapeAsString(v->getNDArray().get()); + auto values = v->getNDArray()->asString(16); + auto dtype = DataTypeUtils::asString(v->getNDArray()->dataType()); + + if (!v->getName().empty()) { + nd4j_printf("<%s> <%i:%i> dtype: %s; shape: %s; values: %s;\n", + v->getName().c_str(), v->id(), v->index(), dtype.c_str(), + shape.c_str(), values.c_str()); + } else { + nd4j_printf("<%i:%i> dtype: %s; shape: %s; values: %s;\n", v->id(), + v->index(), dtype.c_str(), shape.c_str(), values.c_str()); } + } else if (v->hasNDArrayList()) { + // TODO: add better NDArrayList printout + nd4j_printf("<%i:%i> holds ArrayList", v->id(), v->index()); + } + } + } - Graph Graph::fromFlatPointer(void *ptr, const GraphMemoryManager &memoryManager) { - // get FlatGraph out of it - auto fg = GetFlatGraph(reinterpret_cast(ptr)); + fflush(stdout); - // return Graph from this FlatGraph - return Graph(fg, memoryManager); - } + if (size() > 0) { + nd4j_printf("\nPrinting out Nodes...\n", ""); - Graph Graph::importFromTensorFlow(const char *fileName) { - throw std::runtime_error("Graph::importFromTensorFlow() not implemented yet"); - /* - if (fileName == nullptr) - return nullptr; + // since we need structure - we'll print out nodes of OptimizedGraph + optimizedGraph().printOut(); + } +} - int fd = open(fileName, O_RDONLY); +Nd4jStatus Graph::validateNode(Node *node) { + // TODO: to be implemented + return ND4J_STATUS_OK; +} - if (fd < 0) { - nd4j_printf("File not found: [%s]\n", fileName); - return nullptr; - } +void Graph::replaceState(VariableSpace *state, + const ExecutorConfiguration &configuration) { + _variableSpace = *state; + _configuration = configuration; +} - nd4j_verbose("Trying to load TF GraphDef from file [%s]\n", fileName); +Graph Graph::cloneWithProxy() const { + Graph clone; - tensorflow::GraphDef graphDef; - bool res = graphDef.ParseFromFileDescriptor(fd); + // clone.replaceState(new VariableProxy(&this->_variableSpace), + // this->_configuration); - // trying to read graph as text - if(!res) { - close(fd); - fd = open(fileName, O_RDONLY); + // return clone; + throw std::runtime_error("Graph::cloneWithProxy - Not implemented yet"); +} - google::protobuf::io::FileInputStream fileInput(fd); - fileInput.SetCloseOnDelete(true); +Graph *Graph::clone() const { + auto clone = new Graph(); - if (!google::protobuf::TextFormat::Parse(&fileInput, &graphDef)) { - nd4j_printf("Failed to read file\n",""); - } else { - res = true; - } - } - - close(fd); - - if (!res) - return nullptr; - - auto graph = new Graph(); - auto variableSpace = graph->variableSpace(); - - std::map variablesMap; - - int variablesCounter = 0; - int nodesCounter = 0; - nd4j_verbose("Number of nodes in graphDef: %i\n", graphDef.node_size()); - for (int n = 0; n < graphDef.node_size(); n++) { - auto node = graphDef.node(n); - - // if that's external variable - we put it to variable space - if (strcmp(TF_VAR, node.op().c_str()) == 0 || strcmp(TF_CONST, node.op().c_str()) == 0 || strcmp(TF_INPUT, node.op().c_str()) == 0) { - nd4j_printf("Variable found: %s\n", node.name().c_str()); - auto variable = new Variable(); - variable->setName(new std::string(node.name().c_str())); - variable->setId(--variablesCounter); - variableSpace->putVariable(variable->id(), variable); - - std::pair pair(node.name(), variable->id()); - variablesMap.insert(pair); - - // TODO: we might want to have something like that. - // it basically just gives input validation option, since settles expectations for input - if (strcmp(TF_INPUT, node.op().c_str()) == 0) - continue; - - // checking shape, not applicable to input, since it can vary - if (node.attr().count("shape")) { - auto attr = node.attr().at("shape"); - int dims = attr.shape().dim_size(); - - if (dims > 0) { - std::vector __shape; - - // we don't have rank1 arrays. vector is 2d. - if (dims == 1) - __shape.push_back(1); - - // roll through dimensions - for (auto s: attr.shape().dim()) { - __shape.push_back((int) s.size()) ; - } + // clone->replaceState(&this->_variableSpace, this->_configuration.clone()); - variable->setNDArray(new NDArray('c', __shape)); + throw std::runtime_error("Graph::clone - not implemented yet"); +} - nd4j_printf("Shape found: %i dims;\n", dims); - variable->getNDArray()->printShapeInfo(); - } - } +Nd4jLong Graph::hashCode() const { + throw std::runtime_error("Graph::hashCode - not implemented yet"); +} - // checking tensor attached - if (node.attr().count("value")) { - auto attr = node.attr().at("value"); +Graph Graph::fromFlatBuffers(const char *fileName, + const GraphMemoryManager &memoryManager) { + // check if file exists + if (!FileUtils::fileExists(fileName)) + throw std::runtime_error("Graph file doesn't exist"); + + // get file size + auto fsize = FileUtils::fileSize(fileName); + Nd4jLong *ref; + void *ptrGraph; + + // TODO: check if mmap is supported + if (true) { + // mmap this file + ref = ::mmapFile(nullptr, fileName, fsize); + ptrGraph = reinterpret_cast(ref[0]); + } else { + // if mmap is not supported - load it directly + + ptrGraph = new uint8_t[fsize]; + auto data = reinterpret_cast(ptrGraph); + + FILE *in = fopen(fileName, "rb"); + int cnt = 0; + int b = 0; + while (cnt < fsize) { + b = fread(data + cnt, 1, fsize < 16384 ? fsize : 16384, in); + + cnt += b; + } + fclose(in); + } - // int - if (attr.tensor().dtype() == ::tensorflow::DataType::DT_INT32) { - nd4j_verbose("Int size: %i\n", attr.tensor().int_val_size()); + return fromFlatPointer(ptrGraph, memoryManager); +} - Nd4jLong __length = 0; +Graph Graph::fromFlatPointer(void *ptr, + const GraphMemoryManager &memoryManager) { + // get FlatGraph out of it + auto fg = GetFlatGraph(reinterpret_cast(ptr)); - nd4j_verbose("Tensor has shape: %i\n", attr.tensor().has_tensor_shape()); - if (attr.tensor().has_tensor_shape()) { - auto shape = attr.tensor().tensor_shape(); - int dims = shape.dim_size(); + // return Graph from this FlatGraph + return Graph(fg, memoryManager); +} - if (dims > 0) { - std::vector __shape; - // we don't have rank1 arrays. vector is 2d. - if (dims == 1) - __shape.push_back(1); +Graph Graph::importFromTensorFlow(const char *fileName) { + throw std::runtime_error("Graph::importFromTensorFlow() not implemented yet"); + /* + if (fileName == nullptr) + return nullptr; - // roll through dimensions - for (auto s: shape.dim()) { - __shape.push_back((int) s.size()); - } + int fd = open(fileName, O_RDONLY); - variable->setNDArray(new NDArray('c', __shape)); - __length = variable->getNDArray()->lengthOf(); + if (fd < 0) { + nd4j_printf("File not found: [%s]\n", fileName); + return nullptr; + } - nd4j_printf("Tensor shape found: %i dims;\n", dims); - variable->getNDArray()->printShapeInfo(); - } - } + nd4j_verbose("Trying to load TF GraphDef from file [%s]\n", fileName); - // it can be valueOf array - if (attr.tensor().int_val_size() == 1 && __length > 0) { - variable->getNDArray()->assign((T) attr.tensor().int_val(0)); - } - } - } - } else { - nd4j_verbose("Node id: [%i]; name: [%s]; opName: [%s]\n", n + 1, node.name().c_str(), - node.op().c_str()); + tensorflow::GraphDef graphDef; + bool res = graphDef.ParseFromFileDescriptor(fd); - sd::ops::DeclarableOp *op = sd::ops::OpRegistrator::getInstance()->getOperationFloat(node.op().c_str()); + // trying to read graph as text + if(!res) { + close(fd); + fd = open(fileName, O_RDONLY); - if (op == nullptr) { - nd4j_verbose("Op wasn't found: %s\n", node.op().c_str()); - return nullptr; - } + google::protobuf::io::FileInputStream fileInput(fd); + fileInput.SetCloseOnDelete(true); - auto jNode = new Node(); - jNode->setName(node.name()); - jNode->setId(++nodesCounter); - jNode->setCustomOp(op); - jNode->setBlock(new Block(jNode->id(), variableSpace)); - - std::pair pair(node.name(), jNode->id()); - variablesMap.insert(pair); + if (!google::protobuf::TextFormat::Parse(&fileInput, &graphDef)) { + nd4j_printf("Failed to read file\n",""); + } else { + res = true; + } + } + + close(fd); + + if (!res) + return nullptr; + + auto graph = new Graph(); + auto variableSpace = graph->variableSpace(); + + std::map variablesMap; + + int variablesCounter = 0; + int nodesCounter = 0; + nd4j_verbose("Number of nodes in graphDef: %i\n", graphDef.node_size()); + for (int n = 0; n < graphDef.node_size(); n++) { + auto node = graphDef.node(n); + + // if that's external variable - we put it to variable space + if (strcmp(TF_VAR, node.op().c_str()) == 0 || strcmp(TF_CONST, + node.op().c_str()) == 0 || strcmp(TF_INPUT, node.op().c_str()) == 0) { + nd4j_printf("Variable found: %s\n", node.name().c_str()); + auto variable = new Variable(); + variable->setName(new std::string(node.name().c_str())); + variable->setId(--variablesCounter); + variableSpace->putVariable(variable->id(), variable); + + std::pair pair(node.name(), variable->id()); + variablesMap.insert(pair); + + // TODO: we might want to have something like that. + // it basically just gives input validation option, since settles + expectations for input if (strcmp(TF_INPUT, node.op().c_str()) == 0) continue; + + // checking shape, not applicable to input, since it can vary + if (node.attr().count("shape")) { + auto attr = node.attr().at("shape"); + int dims = attr.shape().dim_size(); + + if (dims > 0) { + std::vector __shape; + + // we don't have rank1 arrays. vector is 2d. + if (dims == 1) + __shape.push_back(1); + + // roll through dimensions + for (auto s: attr.shape().dim()) { + __shape.push_back((int) s.size()) ; + } - // multi-output nodes require special treatment - for (int e = 0; e < op->getOpDescriptor()->getNumberOfOutputs(); e++) { - std::string deepName(node.name()); - deepName += ":" + std::to_string(e); - auto deepVar = new Variable(); - deepVar->setName(&deepName); + variable->setNDArray(new NDArray('c', __shape)); - if (e > 0) - deepVar->setId(--variablesCounter); - else - deepVar->setId(jNode->id()); + nd4j_printf("Shape found: %i dims;\n", dims); + variable->getNDArray()->printShapeInfo(); + } + } - std::pair pair(deepName, deepVar->id()); - variablesMap.insert(pair); + // checking tensor attached + if (node.attr().count("value")) { + auto attr = node.attr().at("value"); - variableSpace->putVariable(deepVar->id(), deepVar); + // int + if (attr.tensor().dtype() == ::tensorflow::DataType::DT_INT32) { + nd4j_verbose("Int size: %i\n", attr.tensor().int_val_size()); - std::pair nodepair(jNode->id(), e); - variableSpace->putVariable(nodepair, deepVar); - } + Nd4jLong __length = 0; + nd4j_verbose("Tensor has shape: %i\n", + attr.tensor().has_tensor_shape()); if (attr.tensor().has_tensor_shape()) { + auto shape = attr.tensor().tensor_shape(); + int dims = shape.dim_size(); - printf(" Inputs: ["); - for (int i = 0; i < node.input_size(); i++) { - nd4j_printf("Trying input: %s\n", node.input(i).c_str()); + if (dims > 0) { + std::vector __shape; + // we don't have rank1 arrays. vector is 2d. + if (dims == 1) + __shape.push_back(1); - // if this fails - we're probably on partial input :) - if (!variablesMap.count(node.input(i))) - return nullptr; + // roll through dimensions + for (auto s: shape.dim()) { + __shape.push_back((int) s.size()); + } - printf("%s (%i)", node.input(i).c_str(), variablesMap.at(node.input(i))); + variable->setNDArray(new NDArray('c', __shape)); + __length = variable->getNDArray()->lengthOf(); + nd4j_printf("Tensor shape found: %i dims;\n", dims); + variable->getNDArray()->printShapeInfo(); + } + } - jNode->pickInput(variablesMap.at(node.input(i))); - jNode->getBlock()->pickInput(variablesMap.at(node.input(i))); + // it can be valueOf array + if (attr.tensor().int_val_size() == 1 && __length > 0) { + variable->getNDArray()->assign((T) + attr.tensor().int_val(0)); + } + } + } + } else { + nd4j_verbose("Node id: [%i]; name: [%s]; opName: [%s]\n", n + 1, + node.name().c_str(), node.op().c_str()); + sd::ops::DeclarableOp *op = + sd::ops::OpRegistrator::getInstance()->getOperationFloat(node.op().c_str()); - if (i < node.input_size() + 1) - printf(", "); - } - printf("]\n"); + if (op == nullptr) { + nd4j_verbose("Op wasn't found: %s\n", node.op().c_str()); + return nullptr; + } + + auto jNode = new Node(); + jNode->setName(node.name()); + jNode->setId(++nodesCounter); + jNode->setCustomOp(op); + jNode->setBlock(new Block(jNode->id(), variableSpace)); - graph->addNode(jNode); - } - } + std::pair pair(node.name(), jNode->id()); + variablesMap.insert(pair); + + // multi-output nodes require special treatment + for (int e = 0; e < op->getOpDescriptor()->getNumberOfOutputs(); e++) + { std::string deepName(node.name()); deepName += ":" + std::to_string(e); auto + deepVar = new Variable(); deepVar->setName(&deepName); - return graph; - */ - } + if (e > 0) + deepVar->setId(--variablesCounter); + else + deepVar->setId(jNode->id()); - void Graph::addPlaceholder(const std::string &nodeName, DataType dataType, const std::vector &shape) { - int id = _maxId++; + std::pair pair(deepName, deepVar->id()); + variablesMap.insert(pair); - _symbolicLookupTable[nodeName] = id; + variableSpace->putVariable(deepVar->id(), deepVar); - auto var = std::make_shared(true, dataType, shape); - var->setName(nodeName); - _variableSpace.putVariable(id, var); + std::pair nodepair(jNode->id(), e); + variableSpace->putVariable(nodepair, deepVar); + } - _placeholders.emplace_back(nodeName); - } - std::map Graph::execute(const std::map &dictionary, const std::vector &outputs, const GraphExecutor &executor) const { - // creating our proxy, we'll use it for actual execution - VariableProxy proxy(&_variableSpace); + printf(" Inputs: ["); + for (int i = 0; i < node.input_size(); i++) { + nd4j_printf("Trying input: %s\n", node.input(i).c_str()); - // first of all we check existence of placeholders in dictionary - int placeholdersCount = 0; - for (const auto &v:dictionary) { - if (_symbolicLookupTable.count(v.first) == 0) - throw unresolved_input_exception::build("Dictionary entry doesn't exist", v.first); + // if this fails - we're probably on partial input :) + if (!variablesMap.count(node.input(i))) + return nullptr; - // we also check if arrays provided here do match placeholder restrictions of shape and dtype - auto var = _variableSpace.getVariable(v.first); - if (var->dataType() != DataType::ANY && var->dataType() != v.second.dataType()) - throw datatype_exception::build("Placeholder requires another data type", var->dataType(), v.second.dataType()); + printf("%s (%i)", node.input(i).c_str(), + variablesMap.at(node.input(i))); - auto shape = v.second.getShapeAsVector(); - if (shape != var->shape()) - throw shape_mismatch_exception::build("Placeholder requires specific shape", var->shape(), shape); - // update the placeholder - proxy.putVariable(v.first, var->id(), var->index(), v.second); + jNode->pickInput(variablesMap.at(node.input(i))); + jNode->getBlock()->pickInput(variablesMap.at(node.input(i))); - // we must also check if all placeholders were resolved - placeholdersCount++; - } - // TODO: it would be nice if we'll print out unresolved placeholders - if (placeholdersCount != _placeholders.size()) - throw std::runtime_error("Some placeholders were not resolved"); + if (i < node.input_size() + 1) + printf(", "); + } + printf("]\n"); + graph->addNode(jNode); + } + } - // we also must check existence of requested outputs - for (const auto &v:outputs) { - if (_symbolicLookupTable.count(v) == 0) - throw unresolved_output_exception::build("Requested output doesn't exist", v); - } + return graph; + */ +} - // execute optimized version of this graph - auto status = executor.execute(optimizedGraph(), proxy); - if (status != Status::OK()) - throw graph_execution_exception("Graph execution failed, error code: ", status); +void Graph::addPlaceholder(const std::string &nodeName, DataType dataType, + const std::vector &shape) { + int id = _maxId++; - // fetch outputs from our VariableProxy - std::map result; - for (const auto &v:outputs) { - if (!proxy.hasVariable(v)) - throw unresolved_output_exception::build("Requested output doesn't exist after execution", v); + _symbolicLookupTable[nodeName] = id; - auto var = proxy.getVariable(v); + auto var = std::make_shared(true, dataType, shape); + var->setName(nodeName); + _variableSpace.putVariable(id, var); - // TODO: we want to make sure ManagedDataBuffer doesn't leak here - result[v] = *var->getNDArray(); - } + _placeholders.emplace_back(nodeName); +} - return result; - } +std::map Graph::execute( + const std::map &dictionary, + const std::vector &outputs, + const GraphExecutor &executor) const { + // creating our proxy, we'll use it for actual execution + VariableProxy proxy(&_variableSpace); + + // first of all we check existence of placeholders in dictionary + int placeholdersCount = 0; + for (const auto &v : dictionary) { + if (_symbolicLookupTable.count(v.first) == 0) + throw unresolved_input_exception::build("Dictionary entry doesn't exist", + v.first); + + // we also check if arrays provided here do match placeholder restrictions + // of shape and dtype + auto var = _variableSpace.getVariable(v.first); + if (var->dataType() != DataType::ANY && + var->dataType() != v.second.dataType()) + throw datatype_exception::build("Placeholder requires another data type", + var->dataType(), v.second.dataType()); + + auto shape = v.second.getShapeAsVector(); + if (shape != var->shape()) + throw shape_mismatch_exception::build( + "Placeholder requires specific shape", var->shape(), shape); + + // update the placeholder + proxy.putVariable(v.first, var->id(), var->index(), v.second); + + // we must also check if all placeholders were resolved + placeholdersCount++; + } + + // TODO: it would be nice if we'll print out unresolved placeholders + if (placeholdersCount != _placeholders.size()) + throw std::runtime_error("Some placeholders were not resolved"); + + // we also must check existence of requested outputs + for (const auto &v : outputs) { + if (_symbolicLookupTable.count(v) == 0) + throw unresolved_output_exception::build("Requested output doesn't exist", + v); + } + + // execute optimized version of this graph + auto status = executor.execute(optimizedGraph(), proxy); + if (status != Status::OK()) + throw graph_execution_exception("Graph execution failed, error code: ", + status); + + // fetch outputs from our VariableProxy + std::map result; + for (const auto &v : outputs) { + if (!proxy.hasVariable(v)) + throw unresolved_output_exception::build( + "Requested output doesn't exist after execution", v); + + auto var = proxy.getVariable(v); + + // TODO: we want to make sure ManagedDataBuffer doesn't leak here + result[v] = *var->getNDArray(); + } + + return result; +} - Graph::Graph(const Graph &other) : _memoryMaager(other._memoryMaager) { - _configuration = other._configuration; - _variableSpace = other._variableSpace; - _stash = other._stash; - _unmapped = other._unmapped; - _symbolicLookupTable = other._symbolicLookupTable; - _built = false; - _maxId = _maxId; - } +Graph::Graph(const Graph &other) : _memoryMaager(other._memoryMaager) { + _configuration = other._configuration; + _variableSpace = other._variableSpace; + _stash = other._stash; + _unmapped = other._unmapped; + _symbolicLookupTable = other._symbolicLookupTable; + _built = false; + _maxId = _maxId; +} - Graph &Graph::operator=(const Graph &other) noexcept { - if (this == &other) - return *this; +Graph &Graph::operator=(const Graph &other) noexcept { + if (this == &other) return *this; - _configuration = other._configuration; - _variableSpace = other._variableSpace; - _stash = other._stash; - _unmapped = other._unmapped; - _symbolicLookupTable = other._symbolicLookupTable; - _built = false; - _maxId = _maxId; + _configuration = other._configuration; + _variableSpace = other._variableSpace; + _stash = other._stash; + _unmapped = other._unmapped; + _symbolicLookupTable = other._symbolicLookupTable; + _built = false; + _maxId = _maxId; - return *this; - } + return *this; +} - Graph::Graph(Graph &&other) : _memoryMaager(other._memoryMaager) { - _configuration = other._configuration; - _variableSpace = other._variableSpace; - _stash = other._stash; +Graph::Graph(Graph &&other) : _memoryMaager(other._memoryMaager) { + _configuration = other._configuration; + _variableSpace = other._variableSpace; + _stash = other._stash; - _unmapped = std::move(other._unmapped); - _symbolicLookupTable = std::move(other._symbolicLookupTable); + _unmapped = std::move(other._unmapped); + _symbolicLookupTable = std::move(other._symbolicLookupTable); - _built = false; - _maxId = _maxId; - } + _built = false; + _maxId = _maxId; +} - Graph &Graph::operator=(Graph &&other) noexcept { - if (this == &other) - return *this; +Graph &Graph::operator=(Graph &&other) noexcept { + if (this == &other) return *this; - _configuration = other._configuration; - _variableSpace = other._variableSpace; - _stash = other._stash; + _configuration = other._configuration; + _variableSpace = other._variableSpace; + _stash = other._stash; - _unmapped = std::move(other._unmapped); - _symbolicLookupTable = std::move(other._symbolicLookupTable); + _unmapped = std::move(other._unmapped); + _symbolicLookupTable = std::move(other._symbolicLookupTable); - _built = false; - _maxId = _maxId; + _built = false; + _maxId = _maxId; - return *this; - } + return *this; +} - const GraphMemoryManager &Graph::memoryManager() const { - return _memoryMaager; - } +const GraphMemoryManager &Graph::memoryManager() const { return _memoryMaager; } - const OptimizedGraph& Graph::optimizedGraph() const { - std::lock_guard lock(_optimizedLock); +const OptimizedGraph &Graph::optimizedGraph() const { + std::lock_guard lock(_optimizedLock); - // optionally rebuild optimized graph, if it's out of date - if (_optimized.size() != size()) - _optimized = OptimizedGraph(const_cast(this)); + // optionally rebuild optimized graph, if it's out of date + if (_optimized.size() != size()) + _optimized = OptimizedGraph(const_cast(this)); - return _optimized; - } - } + return _optimized; } - +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/GraphHolder.cpp b/libnd4j/include/graph/impl/GraphHolder.cpp index 2da1077117ed..58c34cdad523 100644 --- a/libnd4j/include/graph/impl/GraphHolder.cpp +++ b/libnd4j/include/graph/impl/GraphHolder.cpp @@ -18,85 +18,79 @@ // @author raver119@gmail.com // -#include -#include #include +#include +#include namespace sd { - namespace graph { - GraphHolder* GraphHolder::getInstance() { - if (_INSTANCE == nullptr) - _INSTANCE = new GraphHolder(); - - return _INSTANCE; - }; - - void GraphHolder::registerGraph(Nd4jLong graphId, const Graph &graph) { - if (hasGraph(graphId)) - throw graph_exists_exception(graphId); - - std::lock_guard lock(_mutex); - _graphs[graphId] = graph; - } - - Graph& GraphHolder::graph(Nd4jLong graphId) { - if (!this->hasGraph(graphId)) { - nd4j_printf("GraphHolder doesn't have graph stored for [%lld]\n", graphId); - throw std::runtime_error("Bad argument"); - } - - std::lock_guard lock(_mutex); - return _graphs[graphId]; - } - - void GraphHolder::forgetGraph(Nd4jLong graphId) { - if (this->hasGraph(graphId)) { - std::lock_guard lock(_mutex); - _graphs.erase(graphId); - } - } - - void GraphHolder::dropGraph(Nd4jLong graphId) { - forgetGraph(graphId); - } - - bool GraphHolder::hasGraph(Nd4jLong graphId) { - std::lock_guard lock(_mutex); - return _graphs.count(graphId) > 0; - } - - void GraphHolder::replaceGraph(Nd4jLong graphId, const Graph& graph) { - if (!hasGraph(graphId)) { - registerGraph(graphId, graph); - return; - } - - forgetGraph(graphId); - - std::lock_guard lock(_mutex); - _graphs[graphId] = graph; - } - - - - - flatbuffers::Offset GraphHolder::execute(Nd4jLong graphId, flatbuffers::FlatBufferBuilder &builder, const FlatInferenceRequest* request) { - if (!hasGraph(graphId)) - throw unknown_graph_exception(graphId); -/* - lockRead(graphId); - - auto graph = cloneGraph(graphId); - auto res = GraphExecutioner::execute(graph, builder, request); - delete graph; - - unlockRead(graphId); - - return res; - */ - throw std::runtime_error("GraphHolder::execute - not implemented yet"); - } - - GraphHolder* GraphHolder::_INSTANCE = 0; - } +namespace graph { +GraphHolder* GraphHolder::getInstance() { + if (_INSTANCE == nullptr) _INSTANCE = new GraphHolder(); + + return _INSTANCE; +}; + +void GraphHolder::registerGraph(Nd4jLong graphId, const Graph& graph) { + if (hasGraph(graphId)) throw graph_exists_exception(graphId); + + std::lock_guard lock(_mutex); + _graphs[graphId] = graph; } + +Graph& GraphHolder::graph(Nd4jLong graphId) { + if (!this->hasGraph(graphId)) { + nd4j_printf("GraphHolder doesn't have graph stored for [%lld]\n", graphId); + throw std::runtime_error("Bad argument"); + } + + std::lock_guard lock(_mutex); + return _graphs[graphId]; +} + +void GraphHolder::forgetGraph(Nd4jLong graphId) { + if (this->hasGraph(graphId)) { + std::lock_guard lock(_mutex); + _graphs.erase(graphId); + } +} + +void GraphHolder::dropGraph(Nd4jLong graphId) { forgetGraph(graphId); } + +bool GraphHolder::hasGraph(Nd4jLong graphId) { + std::lock_guard lock(_mutex); + return _graphs.count(graphId) > 0; +} + +void GraphHolder::replaceGraph(Nd4jLong graphId, const Graph& graph) { + if (!hasGraph(graphId)) { + registerGraph(graphId, graph); + return; + } + + forgetGraph(graphId); + + std::lock_guard lock(_mutex); + _graphs[graphId] = graph; +} + +flatbuffers::Offset GraphHolder::execute( + Nd4jLong graphId, flatbuffers::FlatBufferBuilder& builder, + const FlatInferenceRequest* request) { + if (!hasGraph(graphId)) throw unknown_graph_exception(graphId); + /* + lockRead(graphId); + + auto graph = cloneGraph(graphId); + auto res = GraphExecutioner::execute(graph, builder, request); + delete graph; + + unlockRead(graphId); + + return res; + */ + throw std::runtime_error("GraphHolder::execute - not implemented yet"); +} + +GraphHolder* GraphHolder::_INSTANCE = 0; +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/GraphUtils.cpp b/libnd4j/include/graph/impl/GraphUtils.cpp index 15f674ce1c70..de8c5a06d6d6 100644 --- a/libnd4j/include/graph/impl/GraphUtils.cpp +++ b/libnd4j/include/graph/impl/GraphUtils.cpp @@ -23,14 +23,15 @@ #endif #include -#include + #include +#include -#ifdef __linux__ //_WIN32 -#include +#ifdef __linux__ //_WIN32 +#include #include #include -#include +#include //#eldef __APPLE__ //#include //#include @@ -39,81 +40,81 @@ namespace sd { namespace graph { bool GraphUtils::filterOperations(GraphUtils::OpList& ops) { - bool modified = false; - - std::vector filtered(ops); - - std::sort(filtered.begin(), filtered.end(), [](ops::OpDescriptor a, ops::OpDescriptor b) { - return a.getOpName()->compare(*(b.getOpName())) < 0; - }); - std::string name = *(filtered[0].getOpName()); - - for (int e = 1; e < filtered.size(); e++) { -// nd4j_printf(">%s<, %lu %lu\n", name.c_str(), ops.size(), filtered.size()); - if (0 == filtered[e].getOpName()->compare(name)) { - // there is a match - auto fi = std::find_if(ops.begin(), ops.end(), - [name](ops::OpDescriptor a) { - return a.getOpName()->compare(name) == 0; + bool modified = false; + + std::vector filtered(ops); + + std::sort(filtered.begin(), filtered.end(), + [](ops::OpDescriptor a, ops::OpDescriptor b) { + return a.getOpName()->compare(*(b.getOpName())) < 0; }); - if (fi != ops.end()) - ops.erase(fi); - modified = true; - } - name = *(filtered[e].getOpName()); + std::string name = *(filtered[0].getOpName()); + + for (int e = 1; e < filtered.size(); e++) { + // nd4j_printf(">%s<, %lu %lu\n", name.c_str(), ops.size(), + // filtered.size()); + if (0 == filtered[e].getOpName()->compare(name)) { + // there is a match + auto fi = + std::find_if(ops.begin(), ops.end(), [name](ops::OpDescriptor a) { + return a.getOpName()->compare(name) == 0; + }); + if (fi != ops.end()) ops.erase(fi); + modified = true; } - return modified; + name = *(filtered[e].getOpName()); + } + return modified; } std::string GraphUtils::makeCommandLine(GraphUtils::OpList& ops) { - std::string res; - - if (!ops.empty()) { - res += std::string(" -g \"-DSD_OPS_LIST='"); - //res += *(ops[0].getOpName()); - for (int i = 0; i < ops.size(); i++) { - res += std::string("-DOP_"); - res += *(ops[i].getOpName()); - res += "=true "; - } - res += "'\""; + std::string res; + + if (!ops.empty()) { + res += std::string(" -g \"-DSD_OPS_LIST='"); + // res += *(ops[0].getOpName()); + for (int i = 0; i < ops.size(); i++) { + res += std::string("-DOP_"); + res += *(ops[i].getOpName()); + res += "=true "; } + res += "'\""; + } - return res; + return res; } -int -GraphUtils::runPreprocessor(char const* input, char const* output) { - int status = 0; +int GraphUtils::runPreprocessor(char const* input, char const* output) { + int status = 0; -#ifdef __linux__ //_WIN32 - int pipefd[2]; - status = pipe(pipefd); - pid_t pid = fork(); - if (pid == 0) - { - close(pipefd[0]); // close reading end in the child - - dup2(pipefd[1], 1); // send stdout to the pipe - dup2(pipefd[1], 2); // send stderr to the pipe +#ifdef __linux__ //_WIN32 + int pipefd[2]; + status = pipe(pipefd); + pid_t pid = fork(); + if (pid == 0) { + close(pipefd[0]); // close reading end in the child - close(pipefd[1]); // this descriptor is no longer needed + dup2(pipefd[1], 1); // send stdout to the pipe + dup2(pipefd[1], 2); // send stderr to the pipe - #if __CNUC__ < 4 && __GNUC_MINOR__ < 9 - #pragma error "Compiler version should be greater then 4.9" - #endif + close(pipefd[1]); // this descriptor is no longer needed + +#if __CNUC__ < 4 && __GNUC_MINOR__ < 9 +#pragma error "Compiler version should be greater then 4.9" +#endif // just stacking everything together -// std::string cmdline = "./buildnativeoperations.sh " + -/// std::string(name_arg) + -// std::string(build_arg) + -/// std::string(arch_arg) + -// std::string(opts_arg); - - FILE *f = popen("which c++", "r"); - if(f == NULL) { - std::cerr << "Cannot find c++ compiler with 'which' command." << std::endl; - exit(1); + // std::string cmdline = "./buildnativeoperations.sh " + + /// std::string(name_arg) + + // std::string(build_arg) + + /// std::string(arch_arg) + + // std::string(opts_arg); + + FILE* f = popen("which c++", "r"); + if (f == NULL) { + std::cerr << "Cannot find c++ compiler with 'which' command." + << std::endl; + exit(1); } #if _POSIX_C_SOURCE >= 200809L char* line = nullptr; @@ -121,11 +122,11 @@ GraphUtils::runPreprocessor(char const* input, char const* output) { ssize_t len; if ((len = getdelim(&line, &size, '\n', f)) < 2) { - std::cerr << "Cannot find c++ compiler with 'which' command." << std::endl; - exit(2); + std::cerr << "Cannot find c++ compiler with 'which' command." + << std::endl; + exit(2); } - if (line[len - 1] == '\n') - line[len - 1] = '\0'; + if (line[len - 1] == '\n') line[len - 1] = '\0'; std::string cmd(line); @@ -135,38 +136,39 @@ GraphUtils::runPreprocessor(char const* input, char const* output) { #else std::string cmd; { - - char szLine[PATH_MAX]; - if (NULL == fgets(szLine, sizeof(szLine), f)) { - std::cerr << "Cannot find c++ compiler with 'which' command." << std::endl; - exit(3); - } - char* p = strchr(szLine, '\n'); - if (p) { - *p = '\0'; - } - cmd = szLine; + char szLine[PATH_MAX]; + if (NULL == fgets(szLine, sizeof(szLine), f)) { + std::cerr << "Cannot find c++ compiler with 'which' command." + << std::endl; + exit(3); + } + char* p = strchr(szLine, '\n'); + if (p) { + *p = '\0'; + } + cmd = szLine; } #endif - char const* cxx = cmd.c_str(); //;getenv("CXX"); -// if (cxx == nullptr) { -// nd4j_printf("Cannot retrieve mandatory environment variable 'CXX'. Please set up the variable and try again.", ""); -// exit(3); -// } - //char* pathEnv = getenv("PATH"); - //std::string pathStr("PATH=./;"); - //pathStr += pathEnv; - - //nd4j_printf("%s\n", pathStr.c_str()); -// char const* env[] = {// "HOME=/tmp", -// pathStr.c_str(), -// (char *)0 }; - -// to retrieve c++ version (hardcoded 6): c++ -v 2>&1 | tail -1 | awk '{v = int($3); print v;}' - - std::vector params;//(9); - std::vector args;//(9); + char const* cxx = cmd.c_str(); //;getenv("CXX"); + // if (cxx == nullptr) { + // nd4j_printf("Cannot retrieve mandatory environment variable 'CXX'. + // Please set up the variable and try again.", ""); exit(3); + // } + // char* pathEnv = getenv("PATH"); + // std::string pathStr("PATH=./;"); + // pathStr += pathEnv; + + // nd4j_printf("%s\n", pathStr.c_str()); + // char const* env[] = {// "HOME=/tmp", + // pathStr.c_str(), + // (char *)0 }; + + // to retrieve c++ version (hardcoded 6): c++ -v 2>&1 | tail -1 | awk '{v = + // int($3); print v;}' + + std::vector params; //(9); + std::vector args; //(9); args.emplace_back(cmd); args.emplace_back(std::string("-E")); args.emplace_back(std::string("-P")); @@ -182,7 +184,7 @@ GraphUtils::runPreprocessor(char const* input, char const* output) { args.emplace_back(std::string("-I../include/types")); args.emplace_back(std::string("-I../include/array")); args.emplace_back(std::string("-I../include/cnpy")); - args.emplace_back(std::string("-I../include/graph")); + args.emplace_back(std::string("-I../include/graph")); args.emplace_back(std::string("-I../include/ops/declarable")); #ifdef MKLDNN_PATH args.emplace_back(std::string("-I" MKLDNN_PATH "/include")); @@ -197,14 +199,13 @@ GraphUtils::runPreprocessor(char const* input, char const* output) { std::string preprocessorCmd(cxx); bool skip = true; - for (auto& arg: args) { - if (!skip) { - preprocessorCmd += ' '; - preprocessorCmd += arg; - } - else - skip = false; - params.emplace_back(const_cast(arg.data())); + for (auto& arg : args) { + if (!skip) { + preprocessorCmd += ' '; + preprocessorCmd += arg; + } else + skip = false; + params.emplace_back(const_cast(arg.data())); } params.emplace_back(nullptr); nd4j_printf("Run: \n\t %s\n", preprocessorCmd.c_str()); @@ -212,26 +213,24 @@ GraphUtils::runPreprocessor(char const* input, char const* output) { int err = execvp(cmd.c_str(), ¶ms[0]); if (err < 0) { - perror("\nCannot run Preprocessor properly due \n"); + perror("\nCannot run Preprocessor properly due \n"); } status = err; nd4j_printf("Header file %s was generated.\n", output); -// nd4j_printf("Running build script\n%s\n", cmdline.c_str()); - } - else - { + // nd4j_printf("Running build script\n%s\n", cmdline.c_str()); + } else { // parent - char buffer[1024]; - close(pipefd[1]); // close the write end of the pipe in the parent - memset(buffer, 0, sizeof(buffer)); - while (read(pipefd[0], buffer, sizeof(buffer)) != 0) { - printf("%s\n", buffer); - } - waitpid(pid, &status, 0); + char buffer[1024]; + close(pipefd[1]); // close the write end of the pipe in the parent + memset(buffer, 0, sizeof(buffer)); + while (read(pipefd[0], buffer, sizeof(buffer)) != 0) { + printf("%s\n", buffer); } + waitpid(pid, &status, 0); + } #endif - return status; + return status; } -} -} +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/InferenceRequest.cpp b/libnd4j/include/graph/impl/InferenceRequest.cpp index 6aff5a96efb2..079491eb39d5 100644 --- a/libnd4j/include/graph/impl/InferenceRequest.cpp +++ b/libnd4j/include/graph/impl/InferenceRequest.cpp @@ -20,58 +20,65 @@ #include - namespace sd { - namespace graph { - InferenceRequest::InferenceRequest(Nd4jLong graphId, const ExecutorConfiguration &configuration) { - this->_id = graphId; - this->_configuration = configuration; - } +namespace graph { +InferenceRequest::InferenceRequest(Nd4jLong graphId, + const ExecutorConfiguration &configuration) { + this->_id = graphId; + this->_configuration = configuration; +} - InferenceRequest::~InferenceRequest() { - // - } +InferenceRequest::~InferenceRequest() { + // +} - void InferenceRequest::appendVariable(int id, const NDArray &array) { - appendVariable(id, 0, array); - } +void InferenceRequest::appendVariable(int id, const NDArray &array) { + appendVariable(id, 0, array); +} - void InferenceRequest::appendVariable(int id, int index, const NDArray &array) { - auto v = std::make_shared(std::make_shared(array), nullptr, id, index); - insertVariable(v); - } +void InferenceRequest::appendVariable(int id, int index, const NDArray &array) { + auto v = std::make_shared(std::make_shared(array), + nullptr, id, index); + insertVariable(v); +} - void InferenceRequest::appendVariable(const std::string &id, const NDArray &array) { - auto v = std::make_shared(std::make_shared(array), id.c_str()); - insertVariable(v); - } +void InferenceRequest::appendVariable(const std::string &id, + const NDArray &array) { + auto v = std::make_shared(std::make_shared(array), + id.c_str()); + insertVariable(v); +} - void InferenceRequest::appendVariable(const std::string &name, int id, int index, const NDArray &array) { - auto v = std::make_shared(std::make_shared(array), name, id, index); - insertVariable(v); - } +void InferenceRequest::appendVariable(const std::string &name, int id, + int index, const NDArray &array) { + auto v = std::make_shared(std::make_shared(array), + name, id, index); + insertVariable(v); +} - void InferenceRequest::insertVariable(std::shared_ptr variable) { - variable->markRemovable(false); - variable->markReadOnly(true); - _variables.emplace_back(variable); - } +void InferenceRequest::insertVariable(std::shared_ptr variable) { + variable->markRemovable(false); + variable->markReadOnly(true); + _variables.emplace_back(variable); +} - void InferenceRequest::appendVariable(std::shared_ptr variable) { - _variables.emplace_back(variable); - } +void InferenceRequest::appendVariable(std::shared_ptr variable) { + _variables.emplace_back(variable); +} - flatbuffers::Offset InferenceRequest::asFlatInferenceRequest(flatbuffers::FlatBufferBuilder &builder) { - std::vector> vec; - for (const auto &v : _variables) { - vec.emplace_back(v->asFlatVariable(builder)); - } +flatbuffers::Offset +InferenceRequest::asFlatInferenceRequest( + flatbuffers::FlatBufferBuilder &builder) { + std::vector> vec; + for (const auto &v : _variables) { + vec.emplace_back(v->asFlatVariable(builder)); + } - auto confOffset = _configuration.asFlatConfiguration(builder); + auto confOffset = _configuration.asFlatConfiguration(builder); - auto vecOffset = builder.CreateVector(vec); + auto vecOffset = builder.CreateVector(vec); - return CreateFlatInferenceRequest(builder, _id, vecOffset, confOffset); - } - } -} \ No newline at end of file + return CreateFlatInferenceRequest(builder, _id, vecOffset, confOffset); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Intervals.cpp b/libnd4j/include/graph/impl/Intervals.cpp index 1a89c797fc92..b8f0ff1a63e8 100644 --- a/libnd4j/include/graph/impl/Intervals.cpp +++ b/libnd4j/include/graph/impl/Intervals.cpp @@ -21,33 +21,30 @@ namespace sd { - // default constructor - Intervals::Intervals(): _content({{}}) {} - - // constructor - Intervals::Intervals(const std::initializer_list>& content ): _content(content) {} - Intervals::Intervals(const std::vector>& content ): _content(content) {} - - ////////////////////////////////////////////////////////////////////////// - // accessing operator - std::vector Intervals::operator[](const Nd4jLong i) const { - - return *(_content.begin() + i); - } - - ////////////////////////////////////////////////////////////////////////// - // returns size of _content - int Intervals::size() const { - - return _content.size(); - } - - ////////////////////////////////////////////////////////////////////////// - // modifying operator - // std::vector& Intervals::operator()(const int i) { - // return _content[i]; - // } +// default constructor +Intervals::Intervals() : _content({{}}) {} + +// constructor +Intervals::Intervals( + const std::initializer_list>& content) + : _content(content) {} +Intervals::Intervals(const std::vector>& content) + : _content(content) {} + +////////////////////////////////////////////////////////////////////////// +// accessing operator +std::vector Intervals::operator[](const Nd4jLong i) const { + return *(_content.begin() + i); +} +////////////////////////////////////////////////////////////////////////// +// returns size of _content +int Intervals::size() const { return _content.size(); } -} +////////////////////////////////////////////////////////////////////////// +// modifying operator +// std::vector& Intervals::operator()(const int i) { +// return _content[i]; +// } +} // namespace sd diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 662ecabd931d..aaae8d129576 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -18,1022 +18,940 @@ // @author raver119@gmail.com // +#include +#include #include -#include -#include -#include -#include -#include -#include -#include -#include +#include #include -#include +#include +#include +#include #include #include -#include -#include +#include #include -#include +#include +#include +#include #include -#include -#include +#include +#include #include -#include -#include +#include +#include +#include +#include namespace sd { - namespace graph { - Node::Node(const ops::DeclarableOp &opName, const std::string &nodeName, const std::vector &tArgs, - const std::vector &iArgs, const std::vector &bArgs, - const std::vector &dArgs) { - auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName.getOpHash()); - - this->_name = nodeName; - this->_opType = OpType_CUSTOM; - this->_opNum = customOp->getOpHash(); - this->_extraParams = nullptr; - this->_dataType = sd::DataType::FLOAT32; // float as default - this->_customOp = customOp; - - _hasExternalInputs = false; - _hasExternalOutputs = false; - _hasInternalInputs = false; - _hasInternalOutputs = false; - - ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); - block.setName(nodeName); - - block.appendI(iArgs); - block.appendT(tArgs); - block.appendB(bArgs); - block.appendD(dArgs); - - this->setContextPrototype(block); - } - - Node::Node(const std::string &opName, const std::string &nodeName, const std::vector &tArgs, - const std::vector &iArgs, const std::vector &bArgs, - const std::vector &dArgs) { - - auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName); - - this->_name = nodeName; - this->_opType = OpType_CUSTOM; - this->_opNum = customOp->getOpHash(); - this->_extraParams = nullptr; - this->_dataType = sd::DataType::FLOAT32; // float as default - this->_customOp = customOp; - - _hasExternalInputs = false; - _hasExternalOutputs = false; - _hasInternalInputs = false; - _hasInternalOutputs = false; - - ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); - block.setName(nodeName); - - block.appendI(iArgs); - block.appendT(tArgs); - block.appendB(bArgs); - block.appendD(dArgs); +namespace graph { +Node::Node(const ops::DeclarableOp &opName, const std::string &nodeName, + const std::vector &tArgs, const std::vector &iArgs, + const std::vector &bArgs, const std::vector &dArgs) { + auto customOp = + ops::OpRegistrator::getInstance()->getOperation(opName.getOpHash()); + + this->_name = nodeName; + this->_opType = OpType_CUSTOM; + this->_opNum = customOp->getOpHash(); + this->_extraParams = nullptr; + this->_dataType = sd::DataType::FLOAT32; // float as default + this->_customOp = customOp; + + _hasExternalInputs = false; + _hasExternalOutputs = false; + _hasInternalInputs = false; + _hasInternalOutputs = false; + + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), + false); + block.setName(nodeName); + + block.appendI(iArgs); + block.appendT(tArgs); + block.appendB(bArgs); + block.appendD(dArgs); + + this->setContextPrototype(block); +} - this->setContextPrototype(block); - } +Node::Node(const std::string &opName, const std::string &nodeName, + const std::vector &tArgs, const std::vector &iArgs, + const std::vector &bArgs, const std::vector &dArgs) { + auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName); + + this->_name = nodeName; + this->_opType = OpType_CUSTOM; + this->_opNum = customOp->getOpHash(); + this->_extraParams = nullptr; + this->_dataType = sd::DataType::FLOAT32; // float as default + this->_customOp = customOp; + + _hasExternalInputs = false; + _hasExternalOutputs = false; + _hasInternalInputs = false; + _hasInternalOutputs = false; + + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), + false); + block.setName(nodeName); + + block.appendI(iArgs); + block.appendT(tArgs); + block.appendB(bArgs); + block.appendD(dArgs); + + this->setContextPrototype(block); +} - void Node::setOuterTime(Nd4jLong time){ -// if (hasBlockAttached()) -// _block->setOuterTime(time); - } +void Node::setOuterTime(Nd4jLong time) { + // if (hasBlockAttached()) + // _block->setOuterTime(time); +} - void Node::setInnerTime(Nd4jLong time){ -// if (hasBlockAttached()) -// _block->setInnerTime(time); - } +void Node::setInnerTime(Nd4jLong time) { + // if (hasBlockAttached()) + // _block->setInnerTime(time); +} - void Node::setGraph(Graph* graph) { - _graph = graph; - } +void Node::setGraph(Graph *graph) { _graph = graph; } - Graph* Node::graph() const { - return _graph; - } +Graph *Node::graph() const { return _graph; } - void Node::markInplace(bool reallyInplace) { - _isInplace = reallyInplace; - _protoContext.markInplace(reallyInplace); - } +void Node::markInplace(bool reallyInplace) { + _isInplace = reallyInplace; + _protoContext.markInplace(reallyInplace); +} - bool Node::isRemovable() const { - return _removable; - } +bool Node::isRemovable() const { return _removable; } - void Node::markRemovable(bool reallyRemovable) const { - _removable = reallyRemovable; - } +void Node::markRemovable(bool reallyRemovable) const { + _removable = reallyRemovable; +} - OpClass Node::getOpClass() const { - return _opClass; - } +OpClass Node::getOpClass() const { return _opClass; } - bool Node::hasBlockAttached() { - return true; - } +bool Node::hasBlockAttached() { return true; } - bool Node::isInplace() { - return _isInplace; - } +bool Node::isInplace() { return _isInplace; } - bool Node::isDivergencePoint() { - if (hasCustomOp()) { - return _customOp->getOpDescriptor()->isDivergent(); - } else if (opType() == OpType_LOGIC && opNum() == 30) - return true; - else - return false; - } +bool Node::isDivergencePoint() { + if (hasCustomOp()) { + return _customOp->getOpDescriptor()->isDivergent(); + } else if (opType() == OpType_LOGIC && opNum() == 30) + return true; + else + return false; +} - void Node::setActive(bool reallyActive) { - _active = reallyActive; - } +void Node::setActive(bool reallyActive) { _active = reallyActive; } - bool Node::isActive() { - return _active; - } +bool Node::isActive() { return _active; } - Nd4jLong Node::getFrameId() { - return _frameId; - } +Nd4jLong Node::getFrameId() { return _frameId; } - void Node::setFrameId(Nd4jLong frameId) { - _frameId = frameId; - } +void Node::setFrameId(Nd4jLong frameId) { _frameId = frameId; } - const ContextPrototype& Node::contextPrototype() const { - return _protoContext; - } +const ContextPrototype &Node::contextPrototype() const { return _protoContext; } - void Node::setContextPrototype(const ContextPrototype &block) { - _protoContext = block; - } +void Node::setContextPrototype(const ContextPrototype &block) { + _protoContext = block; +} - void Node::setId(int id) { - _id = id; - _protoContext.setNodeId(id); - } +void Node::setId(int id) { + _id = id; + _protoContext.setNodeId(id); +} - std::shared_ptr Node::customOp() const { - return _customOp; - } +std::shared_ptr Node::customOp() const { + return _customOp; +} - void Node::setCustomOp(std::shared_ptr customOp) { - _customOp = customOp; +void Node::setCustomOp(std::shared_ptr customOp) { + _customOp = customOp; - // divergent ops (Switch etc) are always inplace, they don't allocate anything - if (_customOp.get() != nullptr && _customOp->getOpDescriptor()->isDivergent()) - _isInplace = true; - } + // divergent ops (Switch etc) are always inplace, they don't allocate anything + if (_customOp.get() != nullptr && _customOp->getOpDescriptor()->isDivergent()) + _isInplace = true; +} - bool Node::hasCustomOp() const { - return _customOp != nullptr; - } +bool Node::hasCustomOp() const { return _customOp != nullptr; } - const std::string & Node::name() const { - return this->getName(); - } +const std::string &Node::name() const { return this->getName(); } - const std::string & Node::getName() const { - return _name; - } +const std::string &Node::getName() const { return _name; } - void Node::setName(const std::string& name) { - _name = name; - } +void Node::setName(const std::string &name) { _name = name; } - void Node::setName(std::string *name) { - _name = *name; - } +void Node::setName(std::string *name) { _name = *name; } +void Node::pickInput(std::pair &pair) { + _input.push_back(pair); + _protoContext.pickInput(pair); +} - void Node::pickInput(std::pair& pair) { - _input.push_back(pair); - _protoContext.pickInput(pair); - } +void Node::pickInput(const std::string &id) { + throw std::runtime_error("Node::pickInput - Not implemented yet"); +} - void Node::pickInput(const std::string &id) { - throw std::runtime_error("Node::pickInput - Not implemented yet"); - } +void Node::pickInput(int inputId, int outputId) { + std::pair p(inputId, outputId); + pickInput(p); +} - void Node::pickInput(int inputId, int outputId) { - std::pair p(inputId,outputId); - pickInput(p); - } +void Node::pickInput(int inputId) { + pickInput(inputId, 0); - void Node::pickInput(int inputId) { - pickInput(inputId, 0); + if (inputId < 0) + _hasExternalInputs = true; + else + _hasInternalInputs = true; +} - if (inputId < 0) - _hasExternalInputs = true; - else - _hasInternalInputs = true; - } +void Node::pickExternalOutput(int outputId) { + std::pair pair(outputId, 0); + _output.push_back(pair); - void Node::pickExternalOutput(int outputId) { - std::pair pair(outputId, 0); - _output.push_back(pair); + _hasExternalOutputs = true; +} - _hasExternalOutputs = true; - } +void Node::pickOutputOnce(int outputId) { + std::pair pair(outputId, 0); + if (std::find(_output.begin(), _output.end(), pair) == _output.end()) + pickOutput(outputId); +} - void Node::pickOutputOnce(int outputId) { - std::pair pair(outputId, 0); - if (std::find(_output.begin(), _output.end(), pair) == _output.end()) - pickOutput(outputId); - } +void Node::pickOutput(int nodeId, int outputId) { + std::pair pair(nodeId, outputId); + _output.emplace_back(pair); +} - void Node::pickOutput(int nodeId, int outputId) { - std::pair pair(nodeId, outputId); - _output.emplace_back(pair); - } +void Node::pickOutput(int outputId) { + std::pair pair(outputId, 0); + _output.emplace_back(pair); - void Node::pickOutput(int outputId) { - std::pair pair(outputId, 0); - _output.emplace_back(pair); + if (outputId < 0) + _hasExternalOutputs = true; + else + _hasInternalOutputs = true; +} - if (outputId < 0) - _hasExternalOutputs = true; - else - _hasInternalOutputs = true; - } +bool Node::hasExternalOutputs() { return _hasExternalOutputs; } - bool Node::hasExternalOutputs() { - return _hasExternalOutputs; - } +bool Node::hasExternalInputs() { return _hasExternalInputs; } - bool Node::hasExternalInputs() { - return _hasExternalInputs; - } +bool Node::hasInternalOutputs() { return _hasInternalOutputs; } - bool Node::hasInternalOutputs() { - return _hasInternalOutputs; - } +bool Node::hasInternalInputs() { return _hasInternalInputs; } - bool Node::hasInternalInputs() { - return _hasInternalInputs; - } +bool Node::isMultiInput() { return _input.size() > 1; } - bool Node::isMultiInput() { - return _input.size() > 1; - } +bool Node::isMultiOutput() { return _output.size() > 1; } - bool Node::isMultiOutput() { - return _output.size() > 1; - } +double *Node::extraParams() { return _extraParams; } - double * Node::extraParams() { - return _extraParams; - } +int Node::totalReferences() { return _referencedBy.size(); } - int Node::totalReferences() { - return _referencedBy.size(); - } +void Node::addReference(int nodeId) { _referencedBy.emplace_back(nodeId); } - void Node::addReference(int nodeId) { - _referencedBy.emplace_back(nodeId); - } +OpType Node::opType() const { return _opType; } - OpType Node::opType() const { - return _opType; - } +int Node::id() const { return _id; } - int Node::id() const { - return _id; - } +Nd4jLong Node::opNum() const { return _opNum; } - Nd4jLong Node::opNum() const { - return _opNum; - } +const std::vector> &Node::input() const { return _input; } - const std::vector>& Node::input() const { - return _input; - } +const std::vector> &Node::output() const { return _output; } - const std::vector>& Node::output() const { - return _output; - } +bool Node::isScoped() { return _scope_id != 0; } - bool Node::isScoped() { - return _scope_id != 0; - } +void Node::setScopeInfo(int id, const char *name) { + _scope_id = id; - void Node::setScopeInfo(int id, const char* name) { - _scope_id = id; + if (name != nullptr) _scope_name = name; +} - if (name != nullptr) - _scope_name = name; - } +int Node::scopeId() { return _scope_id; } - int Node::scopeId() { - return _scope_id; - } +std::string *Node::scopeName() { return &_scope_name; } - std::string* Node::scopeName() { - return &_scope_name; - } +template +Node *Node::asT() { + auto node = this->clone(); + node->_dataType = DataTypeUtils::fromT(); + return node; +} +BUILD_SINGLE_TEMPLATE(template SD_EXPORT Node *Node::asT, (), LIBND4J_TYPES); - template - Node* Node::asT() { - auto node = this->clone(); - node->_dataType = DataTypeUtils::fromT(); - return node; - } - BUILD_SINGLE_TEMPLATE(template SD_EXPORT Node* Node::asT, (), LIBND4J_TYPES); +Node::Node(const std::string &opName, const std::string &nodeName, const int id, + const std::vector &inputs, + const std::vector &tArgs, + const std::vector &iArgs) { + auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName); - Node::Node(const std::string &opName, const std::string &nodeName, const int id, const std::vector &inputs, const std::vector &tArgs, const std::vector &iArgs) { - auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName); + this->_opType = OpType_CUSTOM; + this->_id = id; + this->_opNum = customOp->getOpHash(); + this->_extraParams = nullptr; + this->_dataType = sd::DataType::FLOAT32; // float as default + this->_customOp = customOp; - this->_opType = OpType_CUSTOM; - this->_id = id; - this->_opNum = customOp->getOpHash(); - this->_extraParams = nullptr; - this->_dataType = sd::DataType::FLOAT32; // float as default - this->_customOp = customOp; + _hasExternalInputs = false; + _hasExternalOutputs = false; + _hasInternalInputs = false; + _hasInternalOutputs = false; - _hasExternalInputs = false; - _hasExternalOutputs = false; - _hasInternalInputs = false; - _hasInternalOutputs = false; + for (auto i : inputs) pickInput(i); - for (auto i: inputs) - pickInput(i); + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), + false); - ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); + block.appendI(iArgs); + block.appendT(tArgs); - block.appendI(iArgs); - block.appendT(tArgs); + this->setContextPrototype(block); +} - this->setContextPrototype(block); - } +Node::Node(const std::string &opName, const int id, + const std::vector> &inputs, + const std::vector &tArgs, + const std::vector &iArgs) { + auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName); - Node::Node(const std::string &opName, const int id, const std::vector> &inputs, const std::vector &tArgs, const std::vector &iArgs) { - auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName); + this->_opType = OpType_CUSTOM; + this->_id = id; + this->_opNum = customOp->getOpHash(); + this->_extraParams = nullptr; + this->_dataType = sd::DataType::FLOAT32; // float as default + this->_customOp = customOp; - this->_opType = OpType_CUSTOM; - this->_id = id; - this->_opNum = customOp->getOpHash(); - this->_extraParams = nullptr; - this->_dataType = sd::DataType::FLOAT32; // float as default - this->_customOp = customOp; + _hasExternalInputs = false; + _hasExternalOutputs = false; + _hasInternalInputs = false; + _hasInternalOutputs = false; - _hasExternalInputs = false; - _hasExternalOutputs = false; - _hasInternalInputs = false; - _hasInternalOutputs = false; + for (auto i : inputs) pickInput(i); - for (auto i: inputs) - pickInput(i); + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), + false); - ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); + block.appendI(iArgs); + block.appendT(tArgs); - block.appendI(iArgs); - block.appendT(tArgs); + this->setContextPrototype(block); +} - this->setContextPrototype(block); - } +Node::Node(sd::ops::DeclarableOp *customOp, int id, + std::initializer_list input, std::initializer_list output, + std::initializer_list dimensions, float scalar, + std::initializer_list tArgs, + std::initializer_list iArgs) { + this->_opType = OpType_CUSTOM; + this->_id = id; + this->_opNum = customOp->getOpHash(); + this->_extraParams = nullptr; + this->_dataType = sd::DataType::FLOAT32; // float as default - Node::Node(sd::ops::DeclarableOp *customOp, int id, std::initializer_list input, std::initializer_list output, std::initializer_list dimensions, float scalar, std::initializer_list tArgs, std::initializer_list iArgs) { - this->_opType = OpType_CUSTOM; - this->_id = id; - this->_opNum = customOp->getOpHash(); - this->_extraParams = nullptr; - this->_dataType = sd::DataType::FLOAT32; // float as default + // if custom op is a registered one - pull it from cache, otherwise - clone + // locally + if (sd::ops::OpRegistrator::getInstance()->hasOperation(_opNum)) + this->_customOp = + sd::ops::OpRegistrator::getInstance()->getOperation(_opNum); + else + throw std::runtime_error( + "Can't create a node with custom operation within"); - // if custom op is a registered one - pull it from cache, otherwise - clone locally - if (sd::ops::OpRegistrator::getInstance()->hasOperation(_opNum)) - this->_customOp = sd::ops::OpRegistrator::getInstance()->getOperation(_opNum); - else - throw std::runtime_error("Can't create a node with custom operation within"); + _hasExternalInputs = false; + _hasExternalOutputs = false; + _hasInternalInputs = false; + _hasInternalOutputs = false; - _hasExternalInputs = false; - _hasExternalOutputs = false; - _hasInternalInputs = false; - _hasInternalOutputs = false; + for (auto i : input) pickInput(i); - for (auto i: input) - pickInput(i); + for (auto o : output) pickOutput(o); - for (auto o: output) - pickOutput(o); + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), + false); - ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); + for (auto v : dimensions) block.appendA(v); - for (auto v: dimensions) - block.appendA(v); + for (auto v : iArgs) block.appendI(v); - for (auto v: iArgs) - block.appendI(v); + for (auto v : tArgs) block.appendT(v); - for (auto v: tArgs) - block.appendT(v); + this->setContextPrototype(block); +} - this->setContextPrototype(block); - } +void Node::setOpType(OpType opType) { this->_opType = opType; } + +Node::Node(OpType opType, int opNum, int id, std::initializer_list input, + std::initializer_list output, + std::initializer_list dimensions, float scalar, + std::initializer_list tArgs, + std::initializer_list iArgs) { + this->_opType = opType; + this->_id = id; + this->_opNum = opNum; + this->_extraParams = nullptr; + this->_dataType = sd::DataType::FLOAT32; // float as default + + _hasExternalInputs = false; + _hasExternalOutputs = false; + _hasInternalInputs = false; + _hasInternalOutputs = false; + + for (auto i : input) pickInput(i); + + for (auto o : output) pickOutput(o); + + // these ops allow in-place execution by design + if (opType == OpType_TRANSFORM_SAME || opType == OpType_TRANSFORM_FLOAT || + opType == OpType_TRANSFORM_STRICT || opType == OpType_TRANSFORM_BOOL || + opType == OpType_SCALAR || opType == OpType_BROADCAST) { + if (_output.size() <= 1) { + _isInplace = true; + } + _opClass = OpClass_TRANSFORM; + } else if (opType == OpType_REDUCE_SAME || opType == OpType_REDUCE_FLOAT || + opType == OpType_REDUCE_BOOL || opType == OpType_REDUCE_LONG || + opType == OpType_SUMMARYSTATS) { + _opClass = OpClass_REDUCTION; + } + + if (opType == OpType_BROADCAST || opType == OpType_BROADCAST_BOOL || + opType == OpType_INDEX_REDUCE || opType == OpType_SUMMARYSTATS || + opType == OpType_REDUCE_BOOL || opType == OpType_REDUCE_SAME || + opType == OpType_REDUCE_FLOAT || opType == OpType_REDUCE_3 || + opType == OpType_TRANSFORM_STRICT || opType == OpType_TRANSFORM_SAME || + opType == OpType_TRANSFORM_FLOAT || opType == OpType_TRANSFORM_BOOL || + opType == OpType_RANDOM || opType == OpType_PAIRWISE || + opType == OpType_PAIRWISE_BOOL || opType == OpType_SCALAR_BOOL || + opType == OpType_SCALAR) { + ContextPrototype block(nullptr, this->id(), false); + + for (auto v : dimensions) block.appendA(v); + + for (auto v : iArgs) block.appendI(v); + + for (auto v : tArgs) block.appendT(v); + + this->setContextPrototype(block); + + this->setCustomOp(Node::buildOpByType( + opType, (int)input.size(), (int)block.getIArguments().size(), + (int)block.getTArguments().size(), opNum)); + block.setOpDescriptor(this->customOp()->getOpDescriptor()); + } else if (opType == OpType_CUSTOM) { + if (this->customOp()) { + ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), + false); + + for (auto v : dimensions) block.appendA(v); + + for (auto v : iArgs) block.appendI(v); + + for (auto v : tArgs) block.appendT(v); + + this->setContextPrototype(block); + } else + throw std::runtime_error("wrong custom operation given"); + } +}; + +Node::Node(const FlatNode *node) { + _hasExternalInputs = false; + _hasExternalOutputs = false; + _hasInternalInputs = false; + _hasInternalOutputs = false; + _extraParams = nullptr; + + _dataType = sd::DataType::FLOAT32; // float as default + if (node->scope_id() != 0) this->_scope_id = node->scope_id(); + + if (node->scope_name() != nullptr && node->scope_name()->size() > 0) + this->_scope_name = node->scope_name()->str(); + + if (node->scalar() != nullptr) + throw std::runtime_error("FlatNode has scalar defined, it's deprecated"); + + if (node != nullptr) { + this->_id = node->id(); + // this->_dataType = DataTypeUtils::fromFlatDataType(node->dataType()); + this->_opNum = node->opNum(); + this->_opType = node->opType(); + + if (node->name() != nullptr && node->name()->c_str() != nullptr) { + this->_name = node->name()->str(); + } - void Node::setOpType(OpType opType) { - this->_opType = opType; - } + if (node->inputPaired() != nullptr && node->inputPaired()->size() > 0) { + for (int e = 0; e < (int)node->inputPaired()->size(); e++) { + auto pair = node->inputPaired()->Get(e); + pickInput(pair->first(), pair->second()); + } + } else if (node->input() != nullptr && node->input()->size() > 0) { + for (int e = 0; e < (int)node->input()->size(); e++) + pickInput(node->input()->Get(e)); + } else { + if (this->opType() != OpType_LOGIC) { + if (this->_name.size() > 0) { + nd4j_debug("Node [%i:<%s>] has no inputs defined\n", this->_id, + this->_name.c_str()); + } else { + nd4j_debug("Node [%i:] has no inputs defined\n", this->_id); + } + } + } - Node::Node(OpType opType, int opNum, int id, std::initializer_list input, std::initializer_list output, std::initializer_list dimensions, float scalar, std::initializer_list tArgs, std::initializer_list iArgs) { - this->_opType = opType; - this->_id = id; - this->_opNum = opNum; - this->_extraParams = nullptr; - this->_dataType = sd::DataType::FLOAT32; // float as default - - _hasExternalInputs = false; - _hasExternalOutputs = false; - _hasInternalInputs = false; - _hasInternalOutputs = false; - - for (auto i: input) - pickInput(i); - - for (auto o: output) - pickOutput(o); - - // these ops allow in-place execution by design - if (opType == OpType_TRANSFORM_SAME || opType == OpType_TRANSFORM_FLOAT || opType == OpType_TRANSFORM_STRICT || opType == OpType_TRANSFORM_BOOL || opType == OpType_SCALAR || opType == OpType_BROADCAST) { - if (_output.size() <= 1) { - _isInplace = true; - } - _opClass = OpClass_TRANSFORM; - } else if (opType == OpType_REDUCE_SAME || opType == OpType_REDUCE_FLOAT || opType == OpType_REDUCE_BOOL || opType == OpType_REDUCE_LONG || opType == OpType_SUMMARYSTATS) { - _opClass = OpClass_REDUCTION; + /* + if (node->output() != nullptr) + for (int e = 0; e < (int) node->output()->size(); e++) { + auto oid = node->output()->Get(e); + if (oid != this->_id && oid != 0) { + nd4j_verbose("Picking output: %i\n", node->output()->Get(e)); + pickOutput(oid); } + } + */ + if (node->extraParams() != nullptr && node->extraParams()->size() > 0) { + _extraParams = new double[node->extraParams()->size()]; + for (int e = 0; e < (int)node->extraParams()->size(); e++) { + _extraParams[e] = static_cast(node->extraParams()->Get(e)); + } + } - if (opType == OpType_BROADCAST || - opType == OpType_BROADCAST_BOOL || - opType == OpType_INDEX_REDUCE || - opType == OpType_SUMMARYSTATS || - opType == OpType_REDUCE_BOOL || - opType == OpType_REDUCE_SAME || - opType == OpType_REDUCE_FLOAT || - opType == OpType_REDUCE_3 || - opType == OpType_TRANSFORM_STRICT || - opType == OpType_TRANSFORM_SAME || - opType == OpType_TRANSFORM_FLOAT || - opType == OpType_TRANSFORM_BOOL || - opType == OpType_RANDOM || - opType == OpType_PAIRWISE || - opType == OpType_PAIRWISE_BOOL || - opType == OpType_SCALAR_BOOL || - opType == OpType_SCALAR) { - - ContextPrototype block(nullptr, this->id(), false); - - for (auto v: dimensions) - block.appendA(v); - - for (auto v: iArgs) - block.appendI(v); + // if (node->dimensions() != nullptr && node->dimensions()->size() > 0) + // throw std::runtime_error("FlatNode has dimensions defined. Graph is + // outdated"); - for (auto v: tArgs) - block.appendT(v); + if (this->opType() == OpType_LOGIC && this->opNum() == 100L) { + if (node->extraInteger()->size() < 1) { + nd4j_printf("Node_%i is type of Enter, but has no FrameID defined\n", + this->id()); + throw std::runtime_error("Enter node must have FrameID specified"); + } - this->setContextPrototype(block); + this->setFrameId(node->extraInteger()->Get(0)); + } - this->setCustomOp(Node::buildOpByType(opType, (int) input.size(), (int) block.getIArguments().size(), (int) block.getTArguments().size(), opNum)); - block.setOpDescriptor(this->customOp()->getOpDescriptor()); - } else if (opType == OpType_CUSTOM) { - if (this->customOp()) { - ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), false); + // these ops allow in-place execution by design + if (_opType == OpType_BROADCAST || _opType == OpType_BROADCAST_BOOL || + _opType == OpType_INDEX_REDUCE || _opType == OpType_SUMMARYSTATS || + _opType == OpType_REDUCE_BOOL || _opType == OpType_REDUCE_SAME || + _opType == OpType_REDUCE_FLOAT || _opType == OpType_REDUCE_3 || + _opType == OpType_TRANSFORM_STRICT || + _opType == OpType_TRANSFORM_SAME || _opType == OpType_TRANSFORM_FLOAT || + _opType == OpType_TRANSFORM_BOOL || _opType == OpType_RANDOM || + _opType == OpType_PAIRWISE || _opType == OpType_PAIRWISE_BOOL || + _opType == OpType_SCALAR_BOOL || _opType == OpType_SCALAR) { + if (_output.size() <= 1) { + _isInplace = true; + } + + if (node->input() != nullptr && node->input()->size() > 0) { + ContextPrototype block(nullptr, this->id(), false); + + for (auto v : _dimensions) block.appendA(v); + + if (node->extraParams() != nullptr && node->extraParams()->size() > 0) + for (int e = 0; e < (int)node->extraParams()->size(); e++) { + block.appendT(static_cast(node->extraParams()->Get(e))); + } + + if (node->extraBools() != nullptr && node->extraBools()->size() > 0) + for (int e = 0; e < (int)node->extraBools()->size(); e++) { + block.appendB(node->extraBools()->Get(e)); + } + + if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0) + for (int e = 0; e < (int)node->extraInteger()->size(); e++) { + block.appendI(node->extraInteger()->Get(e)); + } + + if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { + for (int e = 0; e < (int)node->extraTypes()->size(); e++) { + block.appendD((sd::DataType)node->extraTypes()->Get(e)); + } + } + + this->setContextPrototype(block); + this->setCustomOp(Node::buildOpByType( + _opType, (int)node->input()->size(), + (int)block.getIArguments().size(), + (int)block.getTArguments().size(), (int)_opNum)); + block.setOpDescriptor(this->customOp()->getOpDescriptor()); + } else if (node->inputPaired() != nullptr && + node->inputPaired()->size() > 0) { + ContextPrototype block(nullptr, this->id(), false); + + for (int e = 0; e < this->input().size(); e++) { + block.pickInput(this->input().at(e)); + } + + // there's no other IArgs in legacy options, actually + for (auto v : _dimensions) block.appendA(v); + + if (node->extraParams() != nullptr && node->extraParams()->size() > 0) + for (int e = 0; e < (int)node->extraParams()->size(); e++) { + block.appendT(static_cast(node->extraParams()->Get(e))); + } + + if (node->extraBools() != nullptr && node->extraBools()->size() > 0) + for (int e = 0; e < (int)node->extraBools()->size(); e++) { + block.appendB(node->extraBools()->Get(e)); + } + + if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0) + for (int e = 0; e < (int)node->extraInteger()->size(); e++) { + block.appendI(node->extraInteger()->Get(e)); + } + + if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { + for (int e = 0; e < (int)node->extraTypes()->size(); e++) { + block.appendD((sd::DataType)node->extraTypes()->Get(e)); + } + } + + this->setContextPrototype(block); + + this->setCustomOp(Node::buildOpByType( + _opType, (int)node->inputPaired()->size(), + (int)block.getIArguments().size(), + (int)block.getTArguments().size(), (int)_opNum)); + block.setOpDescriptor(this->customOp()->getOpDescriptor()); + } + } else if (this->_opType == OpType_CUSTOM) { + auto op = + sd::ops::OpRegistrator::getInstance()->getOperation(this->opNum()); + if (op == nullptr) { + nd4j_verbose("Can't find operation: %lld\n", this->opNum()); + throw std::runtime_error("Can't find requested operation"); + } + + ContextPrototype block(nullptr, this->id()); + + for (int e = 0; e < this->input().size(); e++) { + block.pickInput(this->input().at(e)); + } + + if (node->extraInteger() != nullptr) + for (uint32_t e = 0; e < node->extraInteger()->size(); e++) { + auto v = node->extraInteger()->Get(e); + // FIXME: remove this static_cast, iArgs should be Nd4jLong + block.appendI(static_cast(v)); + } + + if (node->extraParams() != nullptr) + for (uint32_t e = 0; e < node->extraParams()->size(); e++) + block.appendT(static_cast(node->extraParams()->Get(e))); - for (auto v: dimensions) - block.appendA(v); + if (node->extraBools() != nullptr && node->extraBools()->size() > 0) + for (int e = 0; e < (int)node->extraBools()->size(); e++) { + block.appendB(node->extraBools()->Get(e)); + } - for (auto v: iArgs) - block.appendI(v); + if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { + for (int e = 0; e < (int)node->extraTypes()->size(); e++) { + block.appendD((sd::DataType)node->extraTypes()->Get(e)); + } + } - for (auto v: tArgs) - block.appendT(v); + for (auto v : _dimensions) block.appendA(v); - this->setContextPrototype(block); - } else throw std::runtime_error("wrong custom operation given"); - } - }; - - Node::Node(const FlatNode *node) { - _hasExternalInputs = false; - _hasExternalOutputs = false; - _hasInternalInputs = false; - _hasInternalOutputs = false; - _extraParams = nullptr; - - _dataType = sd::DataType::FLOAT32; // float as default - if (node->scope_id() != 0) - this->_scope_id = node->scope_id(); - - if (node->scope_name() != nullptr && node->scope_name()->size() > 0) - this->_scope_name = node->scope_name()->str(); - - if (node->scalar() != nullptr) - throw std::runtime_error("FlatNode has scalar defined, it's deprecated"); - - if (node != nullptr) { - this->_id = node->id(); - //this->_dataType = DataTypeUtils::fromFlatDataType(node->dataType()); - this->_opNum = node->opNum(); - this->_opType = node->opType(); - - if (node->name() != nullptr && node->name()->c_str() != nullptr) { - this->_name = node->name()->str(); - } - - if (node->inputPaired() != nullptr && node->inputPaired()->size() > 0) { - for (int e = 0; e < (int) node->inputPaired()->size(); e++) { - auto pair = node->inputPaired()->Get(e); - pickInput(pair->first(), pair->second()); - } - } else if (node->input() != nullptr && node->input()->size() > 0) { - for (int e = 0; e < (int) node->input()->size(); e++) - pickInput(node->input()->Get(e)); - } else { - if (this->opType() != OpType_LOGIC) { - if (this->_name.size() > 0) { - nd4j_debug("Node [%i:<%s>] has no inputs defined\n", this->_id, this->_name.c_str()); - } else { - nd4j_debug("Node [%i:] has no inputs defined\n", this->_id); - } - } - } - - /* - if (node->output() != nullptr) - for (int e = 0; e < (int) node->output()->size(); e++) { - auto oid = node->output()->Get(e); - if (oid != this->_id && oid != 0) { - nd4j_verbose("Picking output: %i\n", node->output()->Get(e)); - pickOutput(oid); - } - } - */ - - - if (node->extraParams() != nullptr && node->extraParams()->size() > 0) { - _extraParams = new double[node->extraParams()->size()]; - for (int e = 0; e < (int) node->extraParams()->size(); e++) { - _extraParams[e] = static_cast(node->extraParams()->Get(e)); - } - } - - //if (node->dimensions() != nullptr && node->dimensions()->size() > 0) - // throw std::runtime_error("FlatNode has dimensions defined. Graph is outdated"); - - if (this->opType() == OpType_LOGIC && this->opNum() == 100L) { - if (node->extraInteger()->size() < 1) { - nd4j_printf("Node_%i is type of Enter, but has no FrameID defined\n", this->id()); - throw std::runtime_error("Enter node must have FrameID specified"); - } - - this->setFrameId(node->extraInteger()->Get(0)); - } - - - // these ops allow in-place execution by design - if (_opType == OpType_BROADCAST || - _opType == OpType_BROADCAST_BOOL || - _opType == OpType_INDEX_REDUCE || - _opType == OpType_SUMMARYSTATS || - _opType == OpType_REDUCE_BOOL || - _opType == OpType_REDUCE_SAME || - _opType == OpType_REDUCE_FLOAT || - _opType == OpType_REDUCE_3 || - _opType == OpType_TRANSFORM_STRICT || - _opType == OpType_TRANSFORM_SAME || - _opType == OpType_TRANSFORM_FLOAT || - _opType == OpType_TRANSFORM_BOOL || - _opType == OpType_RANDOM || - _opType == OpType_PAIRWISE || - _opType == OpType_PAIRWISE_BOOL || - _opType == OpType_SCALAR_BOOL || - _opType == OpType_SCALAR) { - - if (_output.size() <= 1) { - _isInplace = true; - } - - if (node->input() != nullptr && node->input()->size() > 0) { - ContextPrototype block(nullptr, this->id(), false); - - - for (auto v: _dimensions) - block.appendA(v); - - if (node->extraParams() != nullptr && node->extraParams()->size() > 0) - for (int e = 0; e < (int) node->extraParams()->size(); e++) { - block.appendT(static_cast(node->extraParams()->Get(e))); - } - - if (node->extraBools() != nullptr && node->extraBools()->size() > 0) - for (int e = 0; e < (int) node->extraBools()->size(); e++) { - block.appendB(node->extraBools()->Get(e)); - } - - if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0) - for (int e = 0; e < (int) node->extraInteger()->size(); e++) { - block.appendI(node->extraInteger()->Get(e)); - } - - if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { - for (int e = 0; e < (int) node->extraTypes()->size(); e++) { - block.appendD((sd::DataType) node->extraTypes()->Get(e)); - } - } - - this->setContextPrototype(block); - this->setCustomOp(Node::buildOpByType(_opType, (int) node->input()->size(), (int) block.getIArguments().size(), (int) block.getTArguments().size(), (int) _opNum)); - block.setOpDescriptor(this->customOp()->getOpDescriptor()); - } else if (node->inputPaired() != nullptr && node->inputPaired()->size() > 0) { - ContextPrototype block(nullptr, this->id(), false); - - for (int e = 0; e < this->input().size(); e++) { - block.pickInput(this->input().at(e)); - } - - // there's no other IArgs in legacy options, actually - for (auto v: _dimensions) - block.appendA(v); - - if (node->extraParams() != nullptr && node->extraParams()->size() > 0) - for (int e = 0; e < (int) node->extraParams()->size(); e++) { - block.appendT(static_cast(node->extraParams()->Get(e))); - } - - if (node->extraBools() != nullptr && node->extraBools()->size() > 0) - for (int e = 0; e < (int) node->extraBools()->size(); e++) { - block.appendB(node->extraBools()->Get(e)); - } - - if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0) - for (int e = 0; e < (int) node->extraInteger()->size(); e++) { - block.appendI(node->extraInteger()->Get(e)); - } - - if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { - for (int e = 0; e < (int) node->extraTypes()->size(); e++) { - block.appendD((sd::DataType) node->extraTypes()->Get(e)); - } - } - - this->setContextPrototype(block); - - this->setCustomOp(Node::buildOpByType(_opType, (int) node->inputPaired()->size(), (int) block.getIArguments().size(), (int) block.getTArguments().size(), (int) _opNum)); - block.setOpDescriptor(this->customOp()->getOpDescriptor()); - } - } else if (this->_opType == OpType_CUSTOM) { - auto op = sd::ops::OpRegistrator::getInstance()->getOperation(this->opNum()); - if (op == nullptr) { - nd4j_verbose("Can't find operation: %lld\n", this->opNum()); - throw std::runtime_error("Can't find requested operation"); - } - - ContextPrototype block(nullptr, this->id()); - - for (int e = 0; e < this->input().size(); e++) { - block.pickInput(this->input().at(e)); - } - - if (node->extraInteger() != nullptr) - for (uint32_t e = 0; e < node->extraInteger()->size(); e++) { - auto v = node->extraInteger()->Get(e); - // FIXME: remove this static_cast, iArgs should be Nd4jLong - block.appendI(static_cast(v)); - } - - if (node->extraParams() != nullptr) - for (uint32_t e = 0; e < node->extraParams()->size(); e++) - block.appendT(static_cast(node->extraParams()->Get(e))); - - if (node->extraBools() != nullptr && node->extraBools()->size() > 0) - for (int e = 0; e < (int) node->extraBools()->size(); e++) { - block.appendB(node->extraBools()->Get(e)); - } - - if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { - for (int e = 0; e < (int) node->extraTypes()->size(); e++) { - block.appendD((sd::DataType) node->extraTypes()->Get(e)); - } - } - - for (auto v: _dimensions) - block.appendA(v); - - this->setContextPrototype(block); - this->setCustomOp(op); - block.setOpDescriptor(this->customOp()->getOpDescriptor()); - } - } else { - // empty dynamic node, tests probably - } - } + this->setContextPrototype(block); + this->setCustomOp(op); + block.setOpDescriptor(this->customOp()->getOpDescriptor()); + } + } else { + // empty dynamic node, tests probably + } +} - sd::DataType Node::dataType() { - return _dataType; - } +sd::DataType Node::dataType() { return _dataType; } - const ContextPrototype& Node::protoContext() const { - return _protoContext; - } +const ContextPrototype &Node::protoContext() const { return _protoContext; } - Node::~Node() { - if (_extraParams != nullptr) - delete[] _extraParams; - } +Node::~Node() { + if (_extraParams != nullptr) delete[] _extraParams; +} - int Node::getRewindNode() { - return _rewindNode; - } +int Node::getRewindNode() { return _rewindNode; } - void Node::setRewindNode(int nodeId) { - _rewindNode = nodeId; - } +void Node::setRewindNode(int nodeId) { _rewindNode = nodeId; } - std::pair& Node::getRewindLayer() { - return _rewindLayer; - }; +std::pair &Node::getRewindLayer() { return _rewindLayer; }; - void Node::setRewindLayer(int layerId, int stepId) { - _rewindLayer.first = layerId; - _rewindLayer.second = stepId; - } +void Node::setRewindLayer(int layerId, int stepId) { + _rewindLayer.first = layerId; + _rewindLayer.second = stepId; +} - bool Node::equals(Node *other) const { - if (_opType == other->_opType && _dataType == other->_dataType && _opNum == other->_opNum) - return true; +bool Node::equals(Node *other) const { + if (_opType == other->_opType && _dataType == other->_dataType && + _opNum == other->_opNum) + return true; - return false; - } + return false; +} - Node::Node(const Node &other) noexcept { - _dataType = other._dataType; - _opType = other._opType; - _opClass = other._opClass; - _opNum = other._opNum; - _customOp = other._customOp; - _name = other._name; - _scope_id = other._scope_id; - _scope_name = other._scope_name; - _rewindNode = other._rewindNode; - - _hasExternalOutputs = other._hasExternalOutputs; - _hasExternalInputs = other._hasExternalInputs; - _hasInternalOutputs = other._hasInternalOutputs; - _hasInternalInputs = other._hasInternalInputs; - _isInplace = other._isInplace; - _active = other._active; - _removable = other._removable; - - _graph = other._graph; - _customOp = other._customOp; - _extraParams = other._extraParams; - _protoContext = other._protoContext; - - _input = other._input; - _output = other._output; - _dimensions = other._dimensions; - _rewindLayer = other._rewindLayer; - _referencedBy = other._referencedBy; - } +Node::Node(const Node &other) noexcept { + _dataType = other._dataType; + _opType = other._opType; + _opClass = other._opClass; + _opNum = other._opNum; + _customOp = other._customOp; + _name = other._name; + _scope_id = other._scope_id; + _scope_name = other._scope_name; + _rewindNode = other._rewindNode; + + _hasExternalOutputs = other._hasExternalOutputs; + _hasExternalInputs = other._hasExternalInputs; + _hasInternalOutputs = other._hasInternalOutputs; + _hasInternalInputs = other._hasInternalInputs; + _isInplace = other._isInplace; + _active = other._active; + _removable = other._removable; + + _graph = other._graph; + _customOp = other._customOp; + _extraParams = other._extraParams; + _protoContext = other._protoContext; + + _input = other._input; + _output = other._output; + _dimensions = other._dimensions; + _rewindLayer = other._rewindLayer; + _referencedBy = other._referencedBy; +} - Node &Node::operator=(const Node &other) noexcept { - if (this == &other) - return *this; - - _dataType = other._dataType; - _opType = other._opType; - _opClass = other._opClass; - _opNum = other._opNum; - _customOp = other._customOp; - _name = other._name; - _scope_id = other._scope_id; - _scope_name = other._scope_name; - _rewindNode = other._rewindNode; - - _hasExternalOutputs = other._hasExternalOutputs; - _hasExternalInputs = other._hasExternalInputs; - _hasInternalOutputs = other._hasInternalOutputs; - _hasInternalInputs = other._hasInternalInputs; - _isInplace = other._isInplace; - _active = other._active; - _removable = other._removable; - - _graph = other._graph; - _customOp = other._customOp; - _extraParams = other._extraParams; - _protoContext = other._protoContext; - - _input = other._input; - _output = other._output; - _dimensions = other._dimensions; - _rewindLayer = other._rewindLayer; - _referencedBy = other._referencedBy; - - return *this; - } +Node &Node::operator=(const Node &other) noexcept { + if (this == &other) return *this; + + _dataType = other._dataType; + _opType = other._opType; + _opClass = other._opClass; + _opNum = other._opNum; + _customOp = other._customOp; + _name = other._name; + _scope_id = other._scope_id; + _scope_name = other._scope_name; + _rewindNode = other._rewindNode; + + _hasExternalOutputs = other._hasExternalOutputs; + _hasExternalInputs = other._hasExternalInputs; + _hasInternalOutputs = other._hasInternalOutputs; + _hasInternalInputs = other._hasInternalInputs; + _isInplace = other._isInplace; + _active = other._active; + _removable = other._removable; + + _graph = other._graph; + _customOp = other._customOp; + _extraParams = other._extraParams; + _protoContext = other._protoContext; + + _input = other._input; + _output = other._output; + _dimensions = other._dimensions; + _rewindLayer = other._rewindLayer; + _referencedBy = other._referencedBy; + + return *this; +} - Node::Node(Node &&other) noexcept { - _dataType = other._dataType; - _opType = other._opType; - _opClass = other._opClass; - _opNum = other._opNum; - _customOp = other._customOp; - _scope_id = other._scope_id; - _name = std::move(other._name); - _scope_name = std::move(other._scope_name); - _rewindNode = other._rewindNode; - - _hasExternalOutputs = other._hasExternalOutputs; - _hasExternalInputs = other._hasExternalInputs; - _hasInternalOutputs = other._hasInternalOutputs; - _hasInternalInputs = other._hasInternalInputs; - _isInplace = other._isInplace; - _active = other._active; - _removable = other._removable; - - _graph = other._graph; - _extraParams = other._extraParams; - _protoContext = other._protoContext; - - _customOp = std::move(other._customOp); - _input = std::move(other._input); - _output = std::move(other._output); - _dimensions = std::move(other._dimensions); - _rewindLayer = std::move(other._rewindLayer); - _referencedBy = std::move(other._referencedBy); - - other._customOp = nullptr; - } +Node::Node(Node &&other) noexcept { + _dataType = other._dataType; + _opType = other._opType; + _opClass = other._opClass; + _opNum = other._opNum; + _customOp = other._customOp; + _scope_id = other._scope_id; + _name = std::move(other._name); + _scope_name = std::move(other._scope_name); + _rewindNode = other._rewindNode; + + _hasExternalOutputs = other._hasExternalOutputs; + _hasExternalInputs = other._hasExternalInputs; + _hasInternalOutputs = other._hasInternalOutputs; + _hasInternalInputs = other._hasInternalInputs; + _isInplace = other._isInplace; + _active = other._active; + _removable = other._removable; + + _graph = other._graph; + _extraParams = other._extraParams; + _protoContext = other._protoContext; + + _customOp = std::move(other._customOp); + _input = std::move(other._input); + _output = std::move(other._output); + _dimensions = std::move(other._dimensions); + _rewindLayer = std::move(other._rewindLayer); + _referencedBy = std::move(other._referencedBy); + + other._customOp = nullptr; +} - Node &Node::operator=(Node &&other) noexcept { - if (this == &other) - return *this; - - _dataType = other._dataType; - _opType = other._opType; - _opClass = other._opClass; - _opNum = other._opNum; - _customOp = other._customOp; - _scope_id = other._scope_id; - _name = std::move(other._name); - _scope_name = std::move(other._scope_name); - _rewindNode = other._rewindNode; - - _hasExternalOutputs = other._hasExternalOutputs; - _hasExternalInputs = other._hasExternalInputs; - _hasInternalOutputs = other._hasInternalOutputs; - _hasInternalInputs = other._hasInternalInputs; - _isInplace = other._isInplace; - _active = other._active; - _removable = other._removable; - - _graph = other._graph; - _extraParams = other._extraParams; - _protoContext = other._protoContext; - - _customOp = std::move(other._customOp); - _input = std::move(other._input); - _output = std::move(other._output); - _dimensions = std::move(other._dimensions); - _rewindLayer = std::move(other._rewindLayer); - _referencedBy = std::move(other._referencedBy); - - return *this; - } +Node &Node::operator=(Node &&other) noexcept { + if (this == &other) return *this; + + _dataType = other._dataType; + _opType = other._opType; + _opClass = other._opClass; + _opNum = other._opNum; + _customOp = other._customOp; + _scope_id = other._scope_id; + _name = std::move(other._name); + _scope_name = std::move(other._scope_name); + _rewindNode = other._rewindNode; + + _hasExternalOutputs = other._hasExternalOutputs; + _hasExternalInputs = other._hasExternalInputs; + _hasInternalOutputs = other._hasInternalOutputs; + _hasInternalInputs = other._hasInternalInputs; + _isInplace = other._isInplace; + _active = other._active; + _removable = other._removable; + + _graph = other._graph; + _extraParams = other._extraParams; + _protoContext = other._protoContext; + + _customOp = std::move(other._customOp); + _input = std::move(other._input); + _output = std::move(other._output); + _dimensions = std::move(other._dimensions); + _rewindLayer = std::move(other._rewindLayer); + _referencedBy = std::move(other._referencedBy); + + return *this; +} - void Node::deleteOpByType(OpType opType, void *op) { - switch (opType) { - case OpType_PAIRWISE: - delete reinterpret_cast(op); - break; - case OpType_PAIRWISE_BOOL: - delete reinterpret_cast(op); - break; - case OpType_TRANSFORM_STRICT: - delete reinterpret_cast(op); - break; - case OpType_TRANSFORM_SAME: - delete reinterpret_cast(op); - break; - case OpType_TRANSFORM_FLOAT: - delete reinterpret_cast(op); - break; - case OpType_TRANSFORM_BOOL: - delete reinterpret_cast(op); - break; - case OpType_SCALAR: - delete reinterpret_cast(op); - break; - case OpType_SCALAR_BOOL: - delete reinterpret_cast(op); - break; - case OpType_REDUCE_3: - delete reinterpret_cast(op); - break; - case OpType_REDUCE_SAME: - delete reinterpret_cast(op); - break; - case OpType_REDUCE_FLOAT: - delete reinterpret_cast(op); - break; - case OpType_REDUCE_LONG: - delete reinterpret_cast(op); - break; - case OpType_REDUCE_BOOL: - delete reinterpret_cast(op); - break; - case OpType_INDEX_REDUCE: - delete reinterpret_cast(op); - break; - case OpType_SUMMARYSTATS: - delete reinterpret_cast(op); - break; - case OpType_RANDOM: - delete reinterpret_cast(op); - break; - case OpType_BROADCAST: - delete reinterpret_cast(op); - break; - case OpType_BROADCAST_BOOL: - delete reinterpret_cast(op); - break; - case OpType_CUSTOM: - delete reinterpret_cast(op); - break; - default: - throw std::runtime_error("Bad opType passed in"); - } - } +void Node::deleteOpByType(OpType opType, void *op) { + switch (opType) { + case OpType_PAIRWISE: + delete reinterpret_cast(op); + break; + case OpType_PAIRWISE_BOOL: + delete reinterpret_cast(op); + break; + case OpType_TRANSFORM_STRICT: + delete reinterpret_cast(op); + break; + case OpType_TRANSFORM_SAME: + delete reinterpret_cast(op); + break; + case OpType_TRANSFORM_FLOAT: + delete reinterpret_cast(op); + break; + case OpType_TRANSFORM_BOOL: + delete reinterpret_cast(op); + break; + case OpType_SCALAR: + delete reinterpret_cast(op); + break; + case OpType_SCALAR_BOOL: + delete reinterpret_cast(op); + break; + case OpType_REDUCE_3: + delete reinterpret_cast(op); + break; + case OpType_REDUCE_SAME: + delete reinterpret_cast(op); + break; + case OpType_REDUCE_FLOAT: + delete reinterpret_cast(op); + break; + case OpType_REDUCE_LONG: + delete reinterpret_cast(op); + break; + case OpType_REDUCE_BOOL: + delete reinterpret_cast(op); + break; + case OpType_INDEX_REDUCE: + delete reinterpret_cast(op); + break; + case OpType_SUMMARYSTATS: + delete reinterpret_cast(op); + break; + case OpType_RANDOM: + delete reinterpret_cast(op); + break; + case OpType_BROADCAST: + delete reinterpret_cast(op); + break; + case OpType_BROADCAST_BOOL: + delete reinterpret_cast(op); + break; + case OpType_CUSTOM: + delete reinterpret_cast(op); + break; + default: + throw std::runtime_error("Bad opType passed in"); + } +} - std::shared_ptr Node::buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum) { - switch (opType) { - case OpType_PAIRWISE: - return std::make_shared(opNum); - case OpType_PAIRWISE_BOOL: - return std::make_shared(opNum); - case OpType_TRANSFORM_STRICT: - return std::make_shared(opNum); - case OpType_TRANSFORM_SAME: - return std::make_shared(opNum); - case OpType_TRANSFORM_FLOAT: - return std::make_shared(opNum); - case OpType_TRANSFORM_BOOL: - return std::make_shared(opNum); - case OpType_SCALAR: - return std::make_shared(opNum); - case OpType_SCALAR_BOOL: - return std::make_shared(opNum); - case OpType_REDUCE_3: - return std::make_shared(opNum); - case OpType_REDUCE_SAME: - return std::make_shared(opNum); - case OpType_REDUCE_FLOAT: - return std::make_shared(opNum); - case OpType_REDUCE_LONG: - return std::make_shared(opNum); - case OpType_REDUCE_BOOL: - return std::make_shared(opNum); - case OpType_INDEX_REDUCE: - return std::make_shared(opNum); - case OpType_SUMMARYSTATS: - return std::make_shared(opNum); - case OpType_RANDOM: - return std::make_shared(opNum); - case OpType_BROADCAST: - return std::make_shared(opNum); - case OpType_BROADCAST_BOOL: - return std::make_shared(opNum); - default: - throw std::runtime_error("Bad opType passed in"); - } - } +std::shared_ptr Node::buildOpByType( + OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum) { + switch (opType) { + case OpType_PAIRWISE: + return std::make_shared(opNum); + case OpType_PAIRWISE_BOOL: + return std::make_shared(opNum); + case OpType_TRANSFORM_STRICT: + return std::make_shared(opNum); + case OpType_TRANSFORM_SAME: + return std::make_shared(opNum); + case OpType_TRANSFORM_FLOAT: + return std::make_shared(opNum); + case OpType_TRANSFORM_BOOL: + return std::make_shared(opNum); + case OpType_SCALAR: + return std::make_shared(opNum); + case OpType_SCALAR_BOOL: + return std::make_shared(opNum); + case OpType_REDUCE_3: + return std::make_shared(opNum); + case OpType_REDUCE_SAME: + return std::make_shared(opNum); + case OpType_REDUCE_FLOAT: + return std::make_shared(opNum); + case OpType_REDUCE_LONG: + return std::make_shared(opNum); + case OpType_REDUCE_BOOL: + return std::make_shared(opNum); + case OpType_INDEX_REDUCE: + return std::make_shared(opNum); + case OpType_SUMMARYSTATS: + return std::make_shared(opNum); + case OpType_RANDOM: + return std::make_shared(opNum); + case OpType_BROADCAST: + return std::make_shared(opNum); + case OpType_BROADCAST_BOOL: + return std::make_shared(opNum); + default: + throw std::runtime_error("Bad opType passed in"); + } +} - Node* Node::clone() { - if (this->_customOp && this->_opType == OpType_CUSTOM) { - auto clone = new Node(_customOp.get(), _id); - clone->pullValues(this); - return clone; - } - else { - auto clone = new Node(_opType, _opNum, _id); +Node *Node::clone() { + if (this->_customOp && this->_opType == OpType_CUSTOM) { + auto clone = new Node(_customOp.get(), _id); + clone->pullValues(this); + return clone; + } else { + auto clone = new Node(_opType, _opNum, _id); - clone->pullValues(this); + clone->pullValues(this); - // op time - clone->_customOp = _customOp; + // op time + clone->_customOp = _customOp; - return clone; - } - } - } + return clone; + } } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/NodeState.cpp b/libnd4j/include/graph/impl/NodeState.cpp index d09de9c57bbc..09c4ba461f5c 100644 --- a/libnd4j/include/graph/impl/NodeState.cpp +++ b/libnd4j/include/graph/impl/NodeState.cpp @@ -21,49 +21,27 @@ #include namespace sd { - namespace graph { - NodeState::NodeState(int id) { - _id = id; - } +namespace graph { +NodeState::NodeState(int id) { _id = id; } - void NodeState::setInnerTime(Nd4jLong time) { - _inner = time; - } +void NodeState::setInnerTime(Nd4jLong time) { _inner = time; } - void NodeState::setOuterTime(Nd4jLong time) { - _outer = time; - } +void NodeState::setOuterTime(Nd4jLong time) { _outer = time; } - Nd4jLong NodeState::innerTime() { - return _inner; - } +Nd4jLong NodeState::innerTime() { return _inner; } - Nd4jLong NodeState::outerTime() { - return _outer; - } +Nd4jLong NodeState::outerTime() { return _outer; } - void NodeState::markActive(bool isActive) { - _active = isActive; - } +void NodeState::markActive(bool isActive) { _active = isActive; } - bool NodeState::isActive() { - return _active; - } +bool NodeState::isActive() { return _active; } - int NodeState::branch() { - return _branch; - } +int NodeState::branch() { return _branch; } - void NodeState::markBranch(int index) { - _branch = index; - } +void NodeState::markBranch(int index) { _branch = index; } - bool NodeState::wasExecuted() { - return _executed; - } +bool NodeState::wasExecuted() { return _executed; } - void NodeState::markExecuted(bool wasExecuted) { - _executed = wasExecuted; - } - } -} \ No newline at end of file +void NodeState::markExecuted(bool wasExecuted) { _executed = wasExecuted; } +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index f24f8ee36ce8..7a8059b3d573 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -19,352 +19,349 @@ // @author oleg.semeniv@gmail.com // -#include #include +#include namespace sd { - namespace graph { - OptimizedGraph::OptimizedGraph(Graph *original) { - _originalGraph = original; - _memoryManager = const_cast(&original->memoryManager()); - // create optimized graph - createOptimizedGraph(); - } - - OptimizedGraph::OptimizedGraph(const OptimizedGraph &other) noexcept { - _onion = other._onion; - _memoryManager = other._memoryManager; - _originalGraph = other._originalGraph; - } +namespace graph { +OptimizedGraph::OptimizedGraph(Graph* original) { + _originalGraph = original; + _memoryManager = const_cast(&original->memoryManager()); + // create optimized graph + createOptimizedGraph(); +} - OptimizedGraph &OptimizedGraph::operator=(const OptimizedGraph &other) noexcept { - if (this == &other) - return *this; +OptimizedGraph::OptimizedGraph(const OptimizedGraph& other) noexcept { + _onion = other._onion; + _memoryManager = other._memoryManager; + _originalGraph = other._originalGraph; +} - _onion = other._onion; - _memoryManager = other._memoryManager; - _originalGraph = other._originalGraph; +OptimizedGraph& OptimizedGraph::operator=( + const OptimizedGraph& other) noexcept { + if (this == &other) return *this; - return *this; - } + _onion = other._onion; + _memoryManager = other._memoryManager; + _originalGraph = other._originalGraph; - OptimizedGraph::OptimizedGraph(OptimizedGraph &&other) noexcept { - _onion = std::move(other._onion); - _memoryManager = other._memoryManager; - _originalGraph = other._originalGraph; - } + return *this; +} - OptimizedGraph &OptimizedGraph::operator=(OptimizedGraph &&other) noexcept { - if (this == &other) - return *this; +OptimizedGraph::OptimizedGraph(OptimizedGraph&& other) noexcept { + _onion = std::move(other._onion); + _memoryManager = other._memoryManager; + _originalGraph = other._originalGraph; +} - _onion = std::move(other._onion); - _memoryManager = other._memoryManager; - _originalGraph = other._originalGraph; +OptimizedGraph& OptimizedGraph::operator=(OptimizedGraph&& other) noexcept { + if (this == &other) return *this; - return *this; - } + _onion = std::move(other._onion); + _memoryManager = other._memoryManager; + _originalGraph = other._originalGraph; - size_t OptimizedGraph::size() const { - std::lock_guard lock(_mutex); + return *this; +} - std::vector seq; - if (_size == 0) - for (const auto &v:_onion) { - for (int e = 0; e < v.second.width(); e++) { - _size += v.second.at(0).length(); - } - } +size_t OptimizedGraph::size() const { + std::lock_guard lock(_mutex); - return _size; - } + std::vector seq; + if (_size == 0) + for (const auto& v : _onion) { + for (int e = 0; e < v.second.width(); e++) { + _size += v.second.at(0).length(); + } + } - uint64_t OptimizedGraph::layers() const { - return _onion.size(); - } + return _size; +} - const ExecutionLayer &OptimizedGraph::layer(uint64_t index) const { - return _onion.at(index); - } +uint64_t OptimizedGraph::layers() const { return _onion.size(); } - void OptimizedGraph::append(const std::vector &layer) { - std::lock_guard lock(_mutex); - _onion[_onion.size()] = layer; - _size = 0; - } +const ExecutionLayer& OptimizedGraph::layer(uint64_t index) const { + return _onion.at(index); +} - void OptimizedGraph::append(OpSequence &sequence) { - append(ExecutionLayer({sequence})); - } +void OptimizedGraph::append(const std::vector& layer) { + std::lock_guard lock(_mutex); + _onion[_onion.size()] = layer; + _size = 0; +} - void OptimizedGraph::append(const ExecutionLayer &layer) { - std::lock_guard lock(_mutex); - _onion[_onion.size()] = layer; - _size = 0; - } +void OptimizedGraph::append(OpSequence& sequence) { + append(ExecutionLayer({sequence})); +} - const GraphMemoryManager &OptimizedGraph::memoryManager() const { - return *_memoryManager; - } +void OptimizedGraph::append(const ExecutionLayer& layer) { + std::lock_guard lock(_mutex); + _onion[_onion.size()] = layer; + _size = 0; +} - const Graph &OptimizedGraph::originalGraph() const { - return *_originalGraph; - } +const GraphMemoryManager& OptimizedGraph::memoryManager() const { + return *_memoryManager; +} - bool OptimizedGraph::opGraphProto(std::unordered_map& collector, std::set& startNodes, - std::set& inBranchingNodes) const { - - // double check to avoid unstable behavior - if (originalGraph().unmappedNodes().empty()) - return false; - - const auto& unmappedNodes = originalGraph().unmappedNodes(); - // iterate via original graph nodes to gather node information - for (const auto& it : unmappedNodes) { - - const auto& ID = it.first; - const auto& inputs = it.second.input(); - // if node info is not in collecter add it - if (collector.find(ID) == collector.end()) - collector[ID] = NodeInfo(); - - NodeInfo& parentNode = collector[ID]; - // count external and internal inputs to find out the type of the node (start, in-branching, out-branching) - int inExCounts = 0, inInternalCounts = 0; - for (const auto& in : inputs) { - // find input id in original graph - if (unmappedNodes.find(in.first) == unmappedNodes.end() ){ - // count external inputs, all inputs which id is not in unmapped container will be treaded as external - inExCounts++; - } - else { - // count iternal inputs, all inputs that are not in external variable space - // will be treated as outputs from other nodes - inInternalCounts++; - // if node info is not in collector add it - if (collector.find(in.first) == collector.end()) - collector[in.first] = NodeInfo(); - // input node connection with discovered - collector[in.first].addConnection(ID); - } - } - // set operation type - parentNode.setType( it.second.opType() ); - - // if move then 1 internal input this is in-branching node - parentNode.setInBranching( inInternalCounts > 1); - // gather start and in-branching nodes for the loop when operations are put to OpSequence (topolSearch) - if (inExCounts == inputs.size()) { - startNodes.emplace(ID); - } - else { - if (parentNode.isInBranching()) - inBranchingNodes.emplace(ID); - } - } - return true; - } +const Graph& OptimizedGraph::originalGraph() const { return *_originalGraph; } + +bool OptimizedGraph::opGraphProto(std::unordered_map& collector, + std::set& startNodes, + std::set& inBranchingNodes) const { + // double check to avoid unstable behavior + if (originalGraph().unmappedNodes().empty()) return false; + + const auto& unmappedNodes = originalGraph().unmappedNodes(); + // iterate via original graph nodes to gather node information + for (const auto& it : unmappedNodes) { + const auto& ID = it.first; + const auto& inputs = it.second.input(); + // if node info is not in collecter add it + if (collector.find(ID) == collector.end()) collector[ID] = NodeInfo(); + + NodeInfo& parentNode = collector[ID]; + // count external and internal inputs to find out the type of the node + // (start, in-branching, out-branching) + int inExCounts = 0, inInternalCounts = 0; + for (const auto& in : inputs) { + // find input id in original graph + if (unmappedNodes.find(in.first) == unmappedNodes.end()) { + // count external inputs, all inputs which id is not in unmapped + // container will be treaded as external + inExCounts++; + } else { + // count iternal inputs, all inputs that are not in external variable + // space will be treated as outputs from other nodes + inInternalCounts++; + // if node info is not in collector add it + if (collector.find(in.first) == collector.end()) + collector[in.first] = NodeInfo(); + // input node connection with discovered + collector[in.first].addConnection(ID); + } + } + // set operation type + parentNode.setType(it.second.opType()); + + // if move then 1 internal input this is in-branching node + parentNode.setInBranching(inInternalCounts > 1); + // gather start and in-branching nodes for the loop when operations are put + // to OpSequence (topolSearch) + if (inExCounts == inputs.size()) { + startNodes.emplace(ID); + } else { + if (parentNode.isInBranching()) inBranchingNodes.emplace(ID); + } + } + return true; +} - bool OptimizedGraph::topolSearch(const int startNode, std::unordered_map& collector, - std::vector >& opSeq) const { - - // double check to avoid unstable behavior - if (originalGraph().unmappedNodes().empty() || collector.empty() ) - return false; - - // skip nodes which are not pre-collected and pre-processed - auto itParent = collector.find(startNode); - if (itParent != collector.end()) { - // iterate via start (in-branching) nodes connections in depth - for (const auto& itNodes : itParent->second.connections()) { - - auto itChild = collector.find(itNodes); - // double check - if (itChild != collector.end()) { - // if the child is in-branching node it will be treated as start node or it was proceed - if (itChild->second.isInBranching() || itChild->second.isProcessed()) { - continue; - } - // put operation to OpSequence container - const auto it = originalGraph().unmappedNodes().find(itNodes); - auto& child = itChild->second; - // the layer and sequence are pre-defined in layersSeqDefine method - opSeq[child.layer()][child.sequence()].append(it->second.customOp(), it->second.contextPrototype()); - child.setProcessed(); - // go to the child node connections - topolSearch(itNodes, collector, opSeq); - } - } - } - return true; +bool OptimizedGraph::topolSearch( + const int startNode, std::unordered_map& collector, + std::vector>& opSeq) const { + // double check to avoid unstable behavior + if (originalGraph().unmappedNodes().empty() || collector.empty()) + return false; + + // skip nodes which are not pre-collected and pre-processed + auto itParent = collector.find(startNode); + if (itParent != collector.end()) { + // iterate via start (in-branching) nodes connections in depth + for (const auto& itNodes : itParent->second.connections()) { + auto itChild = collector.find(itNodes); + // double check + if (itChild != collector.end()) { + // if the child is in-branching node it will be treated as start node or + // it was proceed + if (itChild->second.isInBranching() || itChild->second.isProcessed()) { + continue; } + // put operation to OpSequence container + const auto it = originalGraph().unmappedNodes().find(itNodes); + auto& child = itChild->second; + // the layer and sequence are pre-defined in layersSeqDefine method + opSeq[child.layer()][child.sequence()].append( + it->second.customOp(), it->second.contextPrototype()); + child.setProcessed(); + // go to the child node connections + topolSearch(itNodes, collector, opSeq); + } + } + } + return true; +} - void OptimizedGraph::createOptimizedGraph() { - - // container to store node infor - std::unordered_map collector; - // containers to store start and in-branching nodes - std::set startNodes, inBranching; - // container to store max sequences per layer - std::unordered_map layersMaxSeq; - - // optimizing graph prototyping - // select start nodes - // create connections between nodes - // select in-branching nodes ( more then one iternal input -> outputs from other nodes) - if (!opGraphProto(collector, startNodes, inBranching)) - throw std::runtime_error("OptimizedGraph::optimizedGraph() - not prototyped!"); - - // next step set the node layer and it sequence in layer - // define max layers and max sequence per layer - int startSeq = 0; - bool bOnlyStartNodes = collector.empty(); - for (const auto& id : startNodes) { - layersMaxSeq[0] = startSeq; - // if only start nodes exists they have to be add to connections - if(bOnlyStartNodes){ - auto node = NodeInfo(); - node.setLayer(0); - node.setProcessed(true); - node.setSequence(startSeq); - collector[id] = node; - } - else{ - layersSeqDefine(collector, id, 0, startSeq, layersMaxSeq); - } - startSeq++; - } - - // init container to collect operations per node position (layer:sequence) - std::vector> vOpSeq; - if(!initOpSeqContainer(layersMaxSeq, vOpSeq)) - throw std::runtime_error("OptimizedGraph::initOpSeqContainer() - cannot initialize OpSequence, not all nodes properly prototyped!"); - - // combine start nodes and in-branching nodes - startNodes.insert(inBranching.begin(), inBranching.end()); - // re-init proceed NodeInfo member to avoid append sequence several times - for(auto& it : collector){ - it.second.setProcessed(false); - } - - // iterate via start and in-branching nodes - for (const auto& id : startNodes) { - - const auto it = originalGraph().unmappedNodes().find(id); - auto& nodeInfo = collector[id]; - // append start/in-branching node operation to sequence - if(!nodeInfo.isProcessed()){ - vOpSeq[nodeInfo.layer()][nodeInfo.sequence()].append(it->second.customOp(), it->second.contextPrototype()); - nodeInfo.setProcessed(); - } - - // search in depth via connections of "start" node - if(!topolSearch(id, collector, vOpSeq)) - throw std::runtime_error("OptimizedGraph::topolSearch() - cannot run topological search, inputs incorrect!"); - } - // put results to optimized graph - for (auto& vSeq : vOpSeq) { - this->append(vSeq); - } - } +void OptimizedGraph::createOptimizedGraph() { + // container to store node infor + std::unordered_map collector; + // containers to store start and in-branching nodes + std::set startNodes, inBranching; + // container to store max sequences per layer + std::unordered_map layersMaxSeq; + + // optimizing graph prototyping + // select start nodes + // create connections between nodes + // select in-branching nodes ( more then one iternal input -> outputs from + // other nodes) + if (!opGraphProto(collector, startNodes, inBranching)) + throw std::runtime_error( + "OptimizedGraph::optimizedGraph() - not prototyped!"); + + // next step set the node layer and it sequence in layer + // define max layers and max sequence per layer + int startSeq = 0; + bool bOnlyStartNodes = collector.empty(); + for (const auto& id : startNodes) { + layersMaxSeq[0] = startSeq; + // if only start nodes exists they have to be add to connections + if (bOnlyStartNodes) { + auto node = NodeInfo(); + node.setLayer(0); + node.setProcessed(true); + node.setSequence(startSeq); + collector[id] = node; + } else { + layersSeqDefine(collector, id, 0, startSeq, layersMaxSeq); + } + startSeq++; + } + + // init container to collect operations per node position (layer:sequence) + std::vector> vOpSeq; + if (!initOpSeqContainer(layersMaxSeq, vOpSeq)) + throw std::runtime_error( + "OptimizedGraph::initOpSeqContainer() - cannot initialize OpSequence, " + "not all nodes properly prototyped!"); + + // combine start nodes and in-branching nodes + startNodes.insert(inBranching.begin(), inBranching.end()); + // re-init proceed NodeInfo member to avoid append sequence several times + for (auto& it : collector) { + it.second.setProcessed(false); + } + + // iterate via start and in-branching nodes + for (const auto& id : startNodes) { + const auto it = originalGraph().unmappedNodes().find(id); + auto& nodeInfo = collector[id]; + // append start/in-branching node operation to sequence + if (!nodeInfo.isProcessed()) { + vOpSeq[nodeInfo.layer()][nodeInfo.sequence()].append( + it->second.customOp(), it->second.contextPrototype()); + nodeInfo.setProcessed(); + } - bool OptimizedGraph::initOpSeqContainer(const std::unordered_map& layersMaxSeq, std::vector>& vOpSeq) const { - - // double check to avoid unstable behavior - if (layersMaxSeq.empty()) - return false; - // pre-init op-sequence size layers/per-layer sequence - vOpSeq.resize(layersMaxSeq.size()); - for (const auto& it : layersMaxSeq) { - vOpSeq[it.first].resize(it.second + 1); - } - return true; - } + // search in depth via connections of "start" node + if (!topolSearch(id, collector, vOpSeq)) + throw std::runtime_error( + "OptimizedGraph::topolSearch() - cannot run topological search, " + "inputs incorrect!"); + } + // put results to optimized graph + for (auto& vSeq : vOpSeq) { + this->append(vSeq); + } +} - bool OptimizedGraph::layersSeqDefine(std::unordered_map& collection, int ID, int layer, int startSeq, - std::unordered_map& layersMaxSeq) const { - - // double check to avoid unstable behavior - auto parent = collection.find(ID); - if (parent == collection.end()) - return false; - - // if node was proceed and the current layer is less of it own return - if(parent->second.isProcessed() && parent->second.layer() >= layer) - return true; - - // put layer and sequence to container that collects layers and max sequence per layer - auto layerFound = layersMaxSeq.find(layer); - if(layerFound == layersMaxSeq.end()){ - // if layer was not treated before, create pair for it - layersMaxSeq[layer] = 0; - // set sequence value to 0, as this is first sequence in layer - startSeq = 0; - } - else{ - // if node sequence position was not checked use it for max sequence selection - // sequence have to be incremented as max + 1, without any jumps - if(startSeq > (layerFound->second + 1)) - startSeq = layerFound->second + 1; - - layerFound->second = (layerFound->second < startSeq && parent->second.sequence() < 0) ? startSeq : layerFound->second; - } - - // double check if the layer is higher and set node layer - if(parent->second.layer() < layer) - parent->second.setLayer(layer); - // double check if sequence was init, if not set current sequence - if(parent->second.sequence() < 0) - parent->second.setSequence(startSeq); - // set is node out-branching - parent->second.setOutBranching(parent->second.connections().size() > 1); - // set that node was processed, to avoid it double processing (only for some cases it can be processed several times) - parent->second.setProcessed(); - - // if current node is out-branching it childs will be put to next layer - if (parent->second.isOutBranching() && !parent->second.isLogic()) - layer++; - - // childs sequence position have to start from max defined sequence position in layer - // or if it is first node in layer from 0 - int seq = (layersMaxSeq.find(layer) == layersMaxSeq.end()) ? 0 : layersMaxSeq[layer]; - // if parent is out-branching node sequence have to be increment - // on the next stage the sequence value will be double checked with max per layer - // todo check logic part maybe here have to be check operation class (something likke Switch, If, While etc) - // probably for each of them could be other behavior - seq = (parent->second.isOutBranching() && !parent->second.isLogic()) ? seq + 1 : seq; - - // loop via childs (connected nodes) - for (const auto& id : parent->second.connections()) { - // double check to avoid unstable behavior - auto child = collection.find(id); - if(child == collection.end()) - return false; - - // in case parent was not out-branching node but child is in branching it will be put to next layer - // todo check logic part - if (!parent->second.isOutBranching() && child->second.isInBranching() && !child->second.isLogic()) - layer++; - - // move in depth of connections - layersSeqDefine(collection, id, layer, seq, layersMaxSeq); - // increment sequence as childs are on the one layer in case if child was not processed earlier - // todo check logic part - if(!parent->second.isLogic()) - seq++; - } - - return true; - } +bool OptimizedGraph::initOpSeqContainer( + const std::unordered_map& layersMaxSeq, + std::vector>& vOpSeq) const { + // double check to avoid unstable behavior + if (layersMaxSeq.empty()) return false; + // pre-init op-sequence size layers/per-layer sequence + vOpSeq.resize(layersMaxSeq.size()); + for (const auto& it : layersMaxSeq) { + vOpSeq[it.first].resize(it.second + 1); + } + return true; +} +bool OptimizedGraph::layersSeqDefine( + std::unordered_map& collection, int ID, int layer, + int startSeq, std::unordered_map& layersMaxSeq) const { + // double check to avoid unstable behavior + auto parent = collection.find(ID); + if (parent == collection.end()) return false; + + // if node was proceed and the current layer is less of it own return + if (parent->second.isProcessed() && parent->second.layer() >= layer) + return true; + + // put layer and sequence to container that collects layers and max sequence + // per layer + auto layerFound = layersMaxSeq.find(layer); + if (layerFound == layersMaxSeq.end()) { + // if layer was not treated before, create pair for it + layersMaxSeq[layer] = 0; + // set sequence value to 0, as this is first sequence in layer + startSeq = 0; + } else { + // if node sequence position was not checked use it for max sequence + // selection sequence have to be incremented as max + 1, without any jumps + if (startSeq > (layerFound->second + 1)) startSeq = layerFound->second + 1; + + layerFound->second = + (layerFound->second < startSeq && parent->second.sequence() < 0) + ? startSeq + : layerFound->second; + } + + // double check if the layer is higher and set node layer + if (parent->second.layer() < layer) parent->second.setLayer(layer); + // double check if sequence was init, if not set current sequence + if (parent->second.sequence() < 0) parent->second.setSequence(startSeq); + // set is node out-branching + parent->second.setOutBranching(parent->second.connections().size() > 1); + // set that node was processed, to avoid it double processing (only for some + // cases it can be processed several times) + parent->second.setProcessed(); + + // if current node is out-branching it childs will be put to next layer + if (parent->second.isOutBranching() && !parent->second.isLogic()) layer++; + + // childs sequence position have to start from max defined sequence position + // in layer or if it is first node in layer from 0 + int seq = (layersMaxSeq.find(layer) == layersMaxSeq.end()) + ? 0 + : layersMaxSeq[layer]; + // if parent is out-branching node sequence have to be increment + // on the next stage the sequence value will be double checked with max per + // layer todo check logic part maybe here have to be check operation class + // (something likke Switch, If, While etc) probably for each of them could be + // other behavior + seq = (parent->second.isOutBranching() && !parent->second.isLogic()) ? seq + 1 + : seq; + + // loop via childs (connected nodes) + for (const auto& id : parent->second.connections()) { + // double check to avoid unstable behavior + auto child = collection.find(id); + if (child == collection.end()) return false; + + // in case parent was not out-branching node but child is in branching it + // will be put to next layer todo check logic part + if (!parent->second.isOutBranching() && child->second.isInBranching() && + !child->second.isLogic()) + layer++; + + // move in depth of connections + layersSeqDefine(collection, id, layer, seq, layersMaxSeq); + // increment sequence as childs are on the one layer in case if child was + // not processed earlier todo check logic part + if (!parent->second.isLogic()) seq++; + } + + return true; +} - void OptimizedGraph::printOut() const { - for (uint64_t o = 0; o < _onion.size(); o++) { - const auto &layer = _onion.at(o); - printf("Layer [%lu]\n", o); - for (uint64_t l = 0; l < layer.width(); l++) - layer.at(l).printOut(); - } - } - } +void OptimizedGraph::printOut() const { + for (uint64_t o = 0; o < _onion.size(); o++) { + const auto& layer = _onion.at(o); + printf("Layer [%lu]\n", o); + for (uint64_t l = 0; l < layer.width(); l++) layer.at(l).printOut(); + } } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/RandomGenerator.cpp b/libnd4j/include/graph/impl/RandomGenerator.cpp index 07ceb77c5747..c6e1b7b81f59 100644 --- a/libnd4j/include/graph/impl/RandomGenerator.cpp +++ b/libnd4j/include/graph/impl/RandomGenerator.cpp @@ -21,37 +21,36 @@ #include namespace sd { - namespace graph { - RandomGenerator::RandomGenerator(const RandomGenerator& other) noexcept { - _rootState = other._rootState; - _nodeState = other._nodeState; - } - - RandomGenerator& RandomGenerator::operator=(const RandomGenerator& other) noexcept { - if (this == &other) - return *this; - - _rootState = other._rootState; - _nodeState = other._nodeState; - - return *this; - } - - // move constructor - RandomGenerator::RandomGenerator(RandomGenerator&& other) noexcept { - _rootState = other._rootState; - _nodeState = other._nodeState; - } - - // move assignment operator - RandomGenerator& RandomGenerator::operator=(RandomGenerator&& other) noexcept { - if (this == &other) - return *this; - - _rootState = other._rootState; - _nodeState = other._nodeState; - - return *this; - } - } -} \ No newline at end of file +namespace graph { +RandomGenerator::RandomGenerator(const RandomGenerator& other) noexcept { + _rootState = other._rootState; + _nodeState = other._nodeState; +} + +RandomGenerator& RandomGenerator::operator=( + const RandomGenerator& other) noexcept { + if (this == &other) return *this; + + _rootState = other._rootState; + _nodeState = other._nodeState; + + return *this; +} + +// move constructor +RandomGenerator::RandomGenerator(RandomGenerator&& other) noexcept { + _rootState = other._rootState; + _nodeState = other._nodeState; +} + +// move assignment operator +RandomGenerator& RandomGenerator::operator=(RandomGenerator&& other) noexcept { + if (this == &other) return *this; + + _rootState = other._rootState; + _nodeState = other._nodeState; + + return *this; +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/ResultWrapper.cpp b/libnd4j/include/graph/impl/ResultWrapper.cpp index 277644acf913..204445148c47 100644 --- a/libnd4j/include/graph/impl/ResultWrapper.cpp +++ b/libnd4j/include/graph/impl/ResultWrapper.cpp @@ -19,33 +19,27 @@ // #include -#include +#include namespace sd { - namespace graph { - ResultWrapper::ResultWrapper(Nd4jLong size, Nd4jPointer ptr) { - if (size <= 0) - throw std::runtime_error("FlatResult size should be > 0"); - - _size = size; - _pointer = ptr; - } - - ResultWrapper::~ResultWrapper() { - if (_pointer != nullptr && _size > 0) { - auto ptr = reinterpret_cast(_pointer); - delete[] ptr; - } - } - - - Nd4jLong ResultWrapper::size() { - return _size; - } - - Nd4jPointer ResultWrapper::pointer() { - return _pointer; - } - } -} \ No newline at end of file +namespace graph { +ResultWrapper::ResultWrapper(Nd4jLong size, Nd4jPointer ptr) { + if (size <= 0) throw std::runtime_error("FlatResult size should be > 0"); + + _size = size; + _pointer = ptr; +} + +ResultWrapper::~ResultWrapper() { + if (_pointer != nullptr && _size > 0) { + auto ptr = reinterpret_cast(_pointer); + delete[] ptr; + } +} + +Nd4jLong ResultWrapper::size() { return _size; } + +Nd4jPointer ResultWrapper::pointer() { return _pointer; } +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/Scope.cpp b/libnd4j/include/graph/impl/Scope.cpp index 84a8f2f0d074..9332537d4021 100644 --- a/libnd4j/include/graph/impl/Scope.cpp +++ b/libnd4j/include/graph/impl/Scope.cpp @@ -21,53 +21,38 @@ #include namespace sd { - namespace graph { - Scope::Scope(int id, const char *name) { - _id = id; - - if (name != nullptr) - _name = name; - else - name = ""; - } +namespace graph { +Scope::Scope(int id, const char* name) { + _id = id; + + if (name != nullptr) + _name = name; + else + name = ""; +} - Scope::~Scope() { - for (auto v: _nodes) - delete v; - } +Scope::~Scope() { + for (auto v : _nodes) delete v; +} - void Scope::push_back(Node *node) { - _nodes.emplace_back(node); - } +void Scope::push_back(Node* node) { _nodes.emplace_back(node); } - std::vector* Scope::nodes() { - return &_nodes; - } +std::vector* Scope::nodes() { return &_nodes; } - int Scope::size() { - return (int) _nodes.size(); - } +int Scope::size() { return (int)_nodes.size(); } - int Scope::id() { - return _id; - } +int Scope::id() { return _id; } - std::string* Scope::name() { - return &_name; - } +std::string* Scope::name() { return &_name; } - void Scope::forgetNodes() { - _nodes.clear(); - } +void Scope::forgetNodes() { _nodes.clear(); } - Scope* Scope::clone() { - auto clone = new Scope(_id, _name.c_str()); +Scope* Scope::clone() { + auto clone = new Scope(_id, _name.c_str()); - for (auto v: _nodes) - clone->_nodes.emplace_back(v->clone()); + for (auto v : _nodes) clone->_nodes.emplace_back(v->clone()); - return clone; - } - } + return clone; } - +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/Stash.cpp b/libnd4j/include/graph/impl/Stash.cpp index 618e01c8745c..5215e988a4dd 100644 --- a/libnd4j/include/graph/impl/Stash.cpp +++ b/libnd4j/include/graph/impl/Stash.cpp @@ -20,40 +20,38 @@ #include - namespace std { - size_t hash::operator()(const sd::graph::KeyPair& k) const { - using std::hash; - auto res = std::hash()(k.name()); - res ^= std::hash()(k.key()) + 0x9e3779b9 + (res << 6) + (res >> 2); - return res; - } +size_t hash::operator()(const sd::graph::KeyPair &k) const { + using std::hash; + auto res = std::hash()(k.name()); + res ^= std::hash()(k.key()) + 0x9e3779b9 + (res << 6) + (res >> 2); + return res; } +} // namespace std namespace sd { - namespace graph { - sd::graph::KeyPair::KeyPair(int node, const char * name) { - _node = node; - _name = std::string(name); - } +namespace graph { +sd::graph::KeyPair::KeyPair(int node, const char *name) { + _node = node; + _name = std::string(name); +} - bool sd::graph::KeyPair::operator<(const KeyPair& other) const { - if (_node < other._node) - return true; - else if (_node > other._node) - return false; - else - return _name < other._name; - } +bool sd::graph::KeyPair::operator<(const KeyPair &other) const { + if (_node < other._node) + return true; + else if (_node > other._node) + return false; + else + return _name < other._name; +} - sd::graph::Stash::Stash() { - // - } +sd::graph::Stash::Stash() { + // +} - sd::graph::Stash::~Stash() { - if (_handles.size() > 0) - this->clear(); - } +sd::graph::Stash::~Stash() { + if (_handles.size() > 0) this->clear(); +} /* bool sd::graph::Stash::checkStash(sd::graph::Block& block, const char *name) { @@ -61,40 +59,40 @@ bool sd::graph::Stash::checkStash(sd::graph::Block& block, const char *name) { } */ - bool sd::graph::Stash::checkStash(int nodeId, const char *name) { - KeyPair kp(nodeId, name); - return _stash.count(kp) > 0; - } +bool sd::graph::Stash::checkStash(int nodeId, const char *name) { + KeyPair kp(nodeId, name); + return _stash.count(kp) > 0; +} /* -sd::NDArray* sd::graph::Stash::extractArray(sd::graph::Block& block, const char *name) { - return extractArray(block.getNodeId(), name); +sd::NDArray* sd::graph::Stash::extractArray(sd::graph::Block& block, const char +*name) { return extractArray(block.getNodeId(), name); } */ - sd::NDArray* sd::graph::Stash::extractArray(int nodeId, const char *name) { - KeyPair kp(nodeId, name); - return _stash[kp]; - } +sd::NDArray *sd::graph::Stash::extractArray(int nodeId, const char *name) { + KeyPair kp(nodeId, name); + return _stash[kp]; +} /* -void sd::graph::Stash::storeArray(sd::graph::Block& block, const char *name, sd::NDArray *array) { - storeArray(block.getNodeId(), name, array); +void sd::graph::Stash::storeArray(sd::graph::Block& block, const char *name, +sd::NDArray *array) { storeArray(block.getNodeId(), name, array); } */ - void sd::graph::Stash::storeArray(int nodeId, const char *name, sd::NDArray *array) { - KeyPair kp(nodeId, name); - _stash[kp] = array; +void sd::graph::Stash::storeArray(int nodeId, const char *name, + sd::NDArray *array) { + KeyPair kp(nodeId, name); + _stash[kp] = array; - // storing reference to delete it once it's not needed anymore - _handles.push_back(array); - } + // storing reference to delete it once it's not needed anymore + _handles.push_back(array); +} - void sd::graph::Stash::clear() { - for (auto v: _handles) - delete v; +void sd::graph::Stash::clear() { + for (auto v : _handles) delete v; - _handles.clear(); - _stash.clear(); - } - } -} \ No newline at end of file + _handles.clear(); + _stash.clear(); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/TimeHolder.cpp b/libnd4j/include/graph/impl/TimeHolder.cpp index a292a89974cf..a01ae29f7bd4 100644 --- a/libnd4j/include/graph/impl/TimeHolder.cpp +++ b/libnd4j/include/graph/impl/TimeHolder.cpp @@ -21,28 +21,26 @@ #include namespace sd { - namespace graph { +namespace graph { - void TimeHolder::setOuterTime(int nodeId, Nd4jLong time) { - _outer[nodeId] = time; - } +void TimeHolder::setOuterTime(int nodeId, Nd4jLong time) { + _outer[nodeId] = time; +} - void TimeHolder::setInnerTime(int nodeId, Nd4jLong time) { - _inner[nodeId] = time; - } +void TimeHolder::setInnerTime(int nodeId, Nd4jLong time) { + _inner[nodeId] = time; +} - Nd4jLong TimeHolder::outerTime(int nodeId) { - if (_outer.count(nodeId) == 0) - return 0; +Nd4jLong TimeHolder::outerTime(int nodeId) { + if (_outer.count(nodeId) == 0) return 0; - return _outer[nodeId]; - } + return _outer[nodeId]; +} - Nd4jLong TimeHolder::innerTime(int nodeId) { - if (_inner.count(nodeId) == 0) - return 0; +Nd4jLong TimeHolder::innerTime(int nodeId) { + if (_inner.count(nodeId) == 0) return 0; - return _inner[nodeId]; - } - } + return _inner[nodeId]; } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index 6929258fe41e..ad80bce30203 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -18,327 +18,306 @@ // @author raver119@gmail.com // -#include -#include -#include #include #include +#include #include +#include +#include #include namespace sd { - namespace graph { - Variable::Variable(const NDArrayList &arrayList, const std::string &name, int id, int idx) { - _list = std::make_shared(arrayList); - - if (!name.empty()) - _name = name; +namespace graph { +Variable::Variable(const NDArrayList &arrayList, const std::string &name, + int id, int idx) { + _list = std::make_shared(arrayList); - _id = id; - _index = idx; - } - - Variable::Variable(const NDArray &array, const std::string &name, int id, int idx) { - _ndarray = std::make_shared(array); - - if (!name.empty()) - _name = name; - - _id = id; - _index = idx; - } - - Variable::Variable() { - // - } - - void sd::graph::Variable::setIndex(int index) { - _index = index; - } - - bool sd::graph::Variable::hasNDArray() const { - return _ndarray.get() != nullptr; - } - - void sd::graph::Variable::setVariableType(VariableType variableType) { - _variableType = variableType; - } - - bool sd::graph::Variable::hasNDArrayList() const { - return _list != nullptr; - } - - bool sd::graph::Variable::isPlaceholder() const { - return _placeholder; - } + if (!name.empty()) _name = name; - const std::string& sd::graph::Variable::name() const { - return _name; - } + _id = id; + _index = idx; +} - const std::string& sd::graph::Variable::getName() const { - return _name; - } +Variable::Variable(const NDArray &array, const std::string &name, int id, + int idx) { + _ndarray = std::make_shared(array); - void sd::graph::Variable::setName(const std::string &name) { - _name = name; - } + if (!name.empty()) _name = name; - int sd::graph::Variable::id() const { - return _id; - } + _id = id; + _index = idx; +} + +Variable::Variable() { + // +} + +void sd::graph::Variable::setIndex(int index) { _index = index; } + +bool sd::graph::Variable::hasNDArray() const { + return _ndarray.get() != nullptr; +} - int sd::graph::Variable::index() const { - return _index; - } +void sd::graph::Variable::setVariableType(VariableType variableType) { + _variableType = variableType; +} - void sd::graph::Variable::setId(int id) { - _id = id; - } +bool sd::graph::Variable::hasNDArrayList() const { return _list != nullptr; } - bool sd::graph::Variable::isEmpty() const { - if (_variableType == VariableType::NDARRAY) - return _ndarray == nullptr || !_ndarray->nonNull(); - else if (_variableType == VariableType::ARRAY_LIST) - return _list == nullptr; +bool sd::graph::Variable::isPlaceholder() const { return _placeholder; } - return false; - } +const std::string &sd::graph::Variable::name() const { return _name; } - bool sd::graph::Variable::isExternal() const { - return _external; - } +const std::string &sd::graph::Variable::getName() const { return _name; } - bool sd::graph::Variable::isReadOnly() const { - return _readOnly; - } +void sd::graph::Variable::setName(const std::string &name) { _name = name; } - void sd::graph::Variable::markExternal(bool reallyExternal) { - this->_external = reallyExternal; - } +int sd::graph::Variable::id() const { return _id; } - void sd::graph::Variable::markRemovable(bool reallyRemovable) { - if (!reallyRemovable) - nd4j_debug("",""); - this->_removable = reallyRemovable; - } - - void sd::graph::Variable::markReadOnly(bool reallyReadOnly) { - this->_readOnly = reallyReadOnly; - } - - std::shared_ptr sd::graph::Variable::getNDArray() const { - if (_variableType != VariableType::NDARRAY) { - nd4j_printf("Variable[%i:%i/<%s>] is has [%s] type, but NDArray was requested\n", this->_id, this->_index, this->_name.c_str(), EnumUtils::_VariableTypeToString(_variableType)); - } - - if (this->_ndarray.get() == nullptr) { - if (_name.empty()) { - auto nodeId = StringUtils::valueToString(this->id()); - auto outputIndex = StringUtils::valueToString(this->index()); - throw std::runtime_error("Array doesn't exist for Variable <" + nodeId + ":" + outputIndex + ">"); - } else { - auto outputIndex = StringUtils::valueToString(this->index()); - throw std::runtime_error("Array doesn't exist for Variable <" + this->_name + ":" + outputIndex+ ">"); - } - } - - return this->_ndarray; - } - - std::shared_ptr sd::graph::Variable::getNDArrayList() const { - if (_variableType != VariableType::ARRAY_LIST) { - nd4j_debug("Variable[%i:%i/<%s>] is has [%s] type, but NDArrayList was requested\n", this->_id, this->_index, this->_name.c_str(), EnumUtils::_VariableTypeToString(_variableType)); - } - return this->_list; - } - - - bool Variable::isRemovable() const { - return _removable; - } - - - void sd::graph::Variable::setNDArrayList(std::shared_ptr list) { - this->_variableType = VariableType::ARRAY_LIST; - this->_list = list; - } - - - void sd::graph::Variable::setNDArray(std::shared_ptr array) { - this->_variableType = VariableType::NDARRAY; - this->_ndarray = array; - } - - - VariableType sd::graph::Variable::variableType() const { - return _variableType; - } - - - sd::graph::Variable::Variable(const sd::graph::FlatVariable *flatVariable) { - auto vid = flatVariable->id(); - this->_id = vid->first(); - this->_index = vid->second(); - - if (flatVariable->name() != nullptr && flatVariable->name()->size() != 0) - this->_name = flatVariable->name()->str(); - - _external = true; - _readOnly = false; - - int8_t *buffer = nullptr; - - switch (flatVariable->variabletype()) { - case VarType_VARIABLE: { - - // ????? - if (flatVariable->ndarray() != nullptr) { - auto ar = flatVariable->ndarray(); - _ndarray = std::make_shared(sd::graph::FlatUtils::fromFlatArray(ar)); - } - - _variableType = VariableType::NDARRAY; - } - break; - case VarType_CONSTANT: { - if (flatVariable->ndarray() == nullptr) - throw std::runtime_error("CONSTANT variable must have NDArray bundled"); - - auto ar = flatVariable->ndarray(); - if (ar->dtype() == DType_UTF8) { - _ndarray = std::make_shared(sd::graph::FlatUtils::fromFlatArray(ar)); - } else { - _ndarray = std::make_shared(sd::graph::FlatUtils::fromFlatArray(ar)); - } +int sd::graph::Variable::index() const { return _index; } - _variableType = VariableType::NDARRAY; - } - break; - case VarType_ARRAY: { +void sd::graph::Variable::setId(int id) { _id = id; } - // ????? - if (flatVariable->ndarray() != nullptr) { - auto ar = flatVariable->ndarray(); - _ndarray = std::make_shared(sd::graph::FlatUtils::fromFlatArray(ar)); - // _ndarray->triggerAllocationFlag(true); - } - - _variableType = VariableType::NDARRAY; - } - break; - case VarType_PLACEHOLDER: { - if (flatVariable->shape() == nullptr && flatVariable->ndarray() == nullptr) - throw std::runtime_error("PLACEHOLDER variable must have shape defined"); - - if (flatVariable->ndarray() != nullptr) { - auto ar = flatVariable->ndarray(); - _ndarray = std::make_shared(sd::graph::FlatUtils::fromFlatArray(ar)); - // _ndarray->triggerAllocationFlag(true); - - _variableType = VariableType::NDARRAY; - } +bool sd::graph::Variable::isEmpty() const { + if (_variableType == VariableType::NDARRAY) + return _ndarray == nullptr || !_ndarray->nonNull(); + else if (_variableType == VariableType::ARRAY_LIST) + return _list == nullptr; - if (flatVariable->shape() != nullptr) { - int shapeLen = flatVariable->shape()->Length(); - for (int i = 0; i < flatVariable->shape()->size(); i++) - _shape.emplace_back(flatVariable->shape()->Get(i)); + return false; +} - if (_ndarray == nullptr) - _variableType = VariableType::PLACEHOLDER; - } - } - break; - default: - throw std::runtime_error("Unknown variable type used"); - } - } +bool sd::graph::Variable::isExternal() const { return _external; } - const std::vector& sd::graph::Variable::shape() const { - return _shape; - } +bool sd::graph::Variable::isReadOnly() const { return _readOnly; } - sd::graph::Variable::Variable(bool placeholder, DataType dataType, const std::vector &shape) { - _placeholder = placeholder; - _dtype = dataType; - _shape = shape; - } +void sd::graph::Variable::markExternal(bool reallyExternal) { + this->_external = reallyExternal; +} +void sd::graph::Variable::markRemovable(bool reallyRemovable) { + if (!reallyRemovable) nd4j_debug("", ""); + this->_removable = reallyRemovable; +} - sd::graph::Variable::Variable(std::shared_ptr array, const char *name ) { - _ndarray = array; - - _external = false; - _readOnly = false; - - if (name != nullptr) - _name = std::string(name); - - if (_ndarray != nullptr) - _variableType = VariableType::NDARRAY; - } - - DataType Variable::dataType() const { - return _dtype; - } - - sd::graph::Variable::Variable(std::shared_ptr array, const std::string &name, int id, int idx) : Variable(array, name.c_str()) { - _id = id; - _index = idx; - } - - - sd::graph::Variable::~Variable() { - // - } - - - void Variable::setId(int id, int idx) { - _id = id; - _index = idx; - } - - - flatbuffers::Offset Variable::asFlatVariable(flatbuffers::FlatBufferBuilder &builder) { - if (this->hasNDArray()) { - auto array = this->getNDArray(); - auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector()); - - auto fBuffer = builder.CreateVector(array->asByteVector()); - - // packing array - auto fArray = CreateFlatArray(builder, fShape, fBuffer, (sd::graph::DType) array->dataType()); - - // packing id/index of this var - auto fVid = CreateIntPair(builder, this->_id, this->_index); - - // name is still optional - flatbuffers::Offset stringId = 0; - if (!this->_name.empty()) - stringId = builder.CreateString(this->_name); +void sd::graph::Variable::markReadOnly(bool reallyReadOnly) { + this->_readOnly = reallyReadOnly; +} - // returning array - return CreateFlatVariable(builder, fVid, stringId, static_cast(array->dataType()), 0, fArray); - } else { - throw std::runtime_error("Variable::asFlatVariable isn't possible for NDArrayList"); - } - } +std::shared_ptr sd::graph::Variable::getNDArray() const { + if (_variableType != VariableType::NDARRAY) { + nd4j_printf( + "Variable[%i:%i/<%s>] is has [%s] type, but NDArray was requested\n", + this->_id, this->_index, this->_name.c_str(), + EnumUtils::_VariableTypeToString(_variableType)); + } + + if (this->_ndarray.get() == nullptr) { + if (_name.empty()) { + auto nodeId = StringUtils::valueToString(this->id()); + auto outputIndex = StringUtils::valueToString(this->index()); + throw std::runtime_error("Array doesn't exist for Variable <" + nodeId + + ":" + outputIndex + ">"); + } else { + auto outputIndex = StringUtils::valueToString(this->index()); + throw std::runtime_error("Array doesn't exist for Variable <" + + this->_name + ":" + outputIndex + ">"); } + } + + return this->_ndarray; +} + +std::shared_ptr sd::graph::Variable::getNDArrayList() const { + if (_variableType != VariableType::ARRAY_LIST) { + nd4j_debug( + "Variable[%i:%i/<%s>] is has [%s] type, but NDArrayList was " + "requested\n", + this->_id, this->_index, this->_name.c_str(), + EnumUtils::_VariableTypeToString(_variableType)); + } + return this->_list; +} + +bool Variable::isRemovable() const { return _removable; } + +void sd::graph::Variable::setNDArrayList( + std::shared_ptr list) { + this->_variableType = VariableType::ARRAY_LIST; + this->_list = list; +} + +void sd::graph::Variable::setNDArray(std::shared_ptr array) { + this->_variableType = VariableType::NDARRAY; + this->_ndarray = array; } +VariableType sd::graph::Variable::variableType() const { return _variableType; } + +sd::graph::Variable::Variable(const sd::graph::FlatVariable *flatVariable) { + auto vid = flatVariable->id(); + this->_id = vid->first(); + this->_index = vid->second(); + + if (flatVariable->name() != nullptr && flatVariable->name()->size() != 0) + this->_name = flatVariable->name()->str(); + + _external = true; + _readOnly = false; + + int8_t *buffer = nullptr; + + switch (flatVariable->variabletype()) { + case VarType_VARIABLE: { + // ????? + if (flatVariable->ndarray() != nullptr) { + auto ar = flatVariable->ndarray(); + _ndarray = std::make_shared( + sd::graph::FlatUtils::fromFlatArray(ar)); + } + + _variableType = VariableType::NDARRAY; + } break; + case VarType_CONSTANT: { + if (flatVariable->ndarray() == nullptr) + throw std::runtime_error("CONSTANT variable must have NDArray bundled"); + + auto ar = flatVariable->ndarray(); + if (ar->dtype() == DType_UTF8) { + _ndarray = std::make_shared( + sd::graph::FlatUtils::fromFlatArray(ar)); + } else { + _ndarray = std::make_shared( + sd::graph::FlatUtils::fromFlatArray(ar)); + } + + _variableType = VariableType::NDARRAY; + } break; + case VarType_ARRAY: { + // ????? + if (flatVariable->ndarray() != nullptr) { + auto ar = flatVariable->ndarray(); + _ndarray = std::make_shared( + sd::graph::FlatUtils::fromFlatArray(ar)); + // _ndarray->triggerAllocationFlag(true); + } + + _variableType = VariableType::NDARRAY; + } break; + case VarType_PLACEHOLDER: { + if (flatVariable->shape() == nullptr && + flatVariable->ndarray() == nullptr) + throw std::runtime_error( + "PLACEHOLDER variable must have shape defined"); + + if (flatVariable->ndarray() != nullptr) { + auto ar = flatVariable->ndarray(); + _ndarray = std::make_shared( + sd::graph::FlatUtils::fromFlatArray(ar)); + // _ndarray->triggerAllocationFlag(true); + + _variableType = VariableType::NDARRAY; + } + + if (flatVariable->shape() != nullptr) { + int shapeLen = flatVariable->shape()->Length(); + for (int i = 0; i < flatVariable->shape()->size(); i++) + _shape.emplace_back(flatVariable->shape()->Get(i)); + + if (_ndarray == nullptr) _variableType = VariableType::PLACEHOLDER; + } + } break; + default: + throw std::runtime_error("Unknown variable type used"); + } +} + +const std::vector &sd::graph::Variable::shape() const { + return _shape; +} + +sd::graph::Variable::Variable(bool placeholder, DataType dataType, + const std::vector &shape) { + _placeholder = placeholder; + _dtype = dataType; + _shape = shape; +} + +sd::graph::Variable::Variable(std::shared_ptr array, + const char *name) { + _ndarray = array; + + _external = false; + _readOnly = false; + + if (name != nullptr) _name = std::string(name); + + if (_ndarray != nullptr) _variableType = VariableType::NDARRAY; +} + +DataType Variable::dataType() const { return _dtype; } + +sd::graph::Variable::Variable(std::shared_ptr array, + const std::string &name, int id, int idx) + : Variable(array, name.c_str()) { + _id = id; + _index = idx; +} + +sd::graph::Variable::~Variable() { + // +} + +void Variable::setId(int id, int idx) { + _id = id; + _index = idx; +} + +flatbuffers::Offset Variable::asFlatVariable( + flatbuffers::FlatBufferBuilder &builder) { + if (this->hasNDArray()) { + auto array = this->getNDArray(); + auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector()); + + auto fBuffer = builder.CreateVector(array->asByteVector()); + + // packing array + auto fArray = CreateFlatArray(builder, fShape, fBuffer, + (sd::graph::DType)array->dataType()); + + // packing id/index of this var + auto fVid = CreateIntPair(builder, this->_id, this->_index); + + // name is still optional + flatbuffers::Offset stringId = 0; + if (!this->_name.empty()) stringId = builder.CreateString(this->_name); + + // returning array + return CreateFlatVariable(builder, fVid, stringId, + static_cast(array->dataType()), + 0, fArray); + } else { + throw std::runtime_error( + "Variable::asFlatVariable isn't possible for NDArrayList"); + } +} +} // namespace graph +} // namespace sd + namespace std { - size_t hash>::operator()(const std::pair& k) const { - auto v = std::hash()(k.first); - v ^= std::hash()(k.second) + 0x9e3779b9 + (v << 6) + (v >> 2); - return v; - } +size_t hash>::operator()( + const std::pair &k) const { + auto v = std::hash()(k.first); + v ^= std::hash()(k.second) + 0x9e3779b9 + (v << 6) + (v >> 2); + return v; +} - size_t hash::operator()(const bfloat16& k) const { - return std::hash()((float)k); - } +size_t hash::operator()(const bfloat16 &k) const { + return std::hash()((float)k); +} - size_t hash::operator()(const float16& k) const { - return std::hash()((float)k); - } -} \ No newline at end of file +size_t hash::operator()(const float16 &k) const { + return std::hash()((float)k); +} +} // namespace std \ No newline at end of file diff --git a/libnd4j/include/graph/impl/VariableProxy.cpp b/libnd4j/include/graph/impl/VariableProxy.cpp index a89e8b5071e2..d511dc9e4140 100644 --- a/libnd4j/include/graph/impl/VariableProxy.cpp +++ b/libnd4j/include/graph/impl/VariableProxy.cpp @@ -18,234 +18,205 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace graph { - - VariableProxy::VariableProxy(const VariableSpace* ref) { - if (ref == nullptr) - _backed = new VariableSpace(); - - _backed = ref; - _current = new VariableSpace(); - } - - - VariableProxy::~VariableProxy() { - delete _current; - } - - - int VariableProxy::numberOfPlaceholders() const { - return _backed->numberOfPlaceholders(); - } - - - const std::vector>& VariableProxy::placeholders() const { - return _backed->placeholders(); - } - - bool VariableProxy::hasExternalVariable(int it) const { - return _backed->hasExternalVariable(it); - } - - - bool VariableProxy::hasExternalVariable(const std::pair& pair) const { - return _backed->hasExternalVariable(pair); - } - - - bool VariableProxy::hasExternalVariable(const std::string &symbol) const { - return _backed->hasExternalVariable(symbol); - } - - - bool VariableProxy::hasVariable(int id) const { - return _current->hasVariable(id) || _backed->hasVariable(id); - } - - - bool VariableProxy::hasVariable(int id, int idx) const { - return _current->hasVariable(id, idx) || _backed->hasVariable(id, idx); - } - - - bool VariableProxy::hasVariable(const std::pair& pair) const { - return _current->hasVariable(pair) || _backed->hasVariable(pair); - } - - - void VariableProxy::dropVariable(const std::pair &pair) { - dropVariable(pair.first, pair.second); - } - - - void VariableProxy::dropVariable(int id, int idx) { - assert(_current->hasVariable(id, idx)); - - _current->dropVariable(id, idx); - } - - - std::vector> VariableProxy::variables() const { - std::vector> result; - - auto b = _backed->variables(); - auto c = _current->variables(); - - for (auto v: b) - result.emplace_back(v); - - for (auto v: c) - result.emplace_back(v); - - return result; - } - - - bool VariableProxy::hasVariable(const std::string &symbol) const { - return _current->hasVariable(symbol) || _backed->hasVariable(symbol); - } - - - std::shared_ptr VariableProxy::getVariable(int id) const { - if (_current->hasVariable(id)) - return _current->getVariable(id); - - if (_backed->hasVariable(id)) - return _backed->getVariable(id); - - nd4j_printf("Unable to get Variable from proxy: [%i]\n", id); - throw std::runtime_error("Bad arguments"); - } - - - std::shared_ptr VariableProxy::getVariable(int id, int idx) const { - if (_current->hasVariable(id, idx)) - return _current->getVariable(id, idx); - - if (_backed->hasVariable(id, idx)) - return _backed->getVariable(id, idx); - - nd4j_printf("Unable to get Variable from proxy: [%i:%i]\n", id, idx); - throw std::runtime_error("Bad arguments"); - } - - - std::shared_ptr VariableProxy::getVariable(const std::pair& pair) const { - if (_current->hasVariable(pair)) - return _current->getVariable(pair); - - if (_backed->hasVariable(pair)) - return _backed->getVariable(pair); - - nd4j_printf("Unable to get Variable from proxy: [%i:%i]\n", pair.first, pair.second); - throw std::runtime_error("Bad arguments"); - } - - - std::shared_ptr VariableProxy::getVariable(const std::string &symbol) const { - if (_current->hasVariable(symbol)) - return _current->getVariable(symbol); - - if (_backed->hasVariable(symbol)) - return _backed->getVariable(symbol); - - nd4j_printf("Unable to get Variable from proxy: [%s]\n", symbol.c_str()); - throw std::runtime_error("Bad arguments"); - } - - - void VariableProxy::replaceVariable(std::shared_ptr variable) { - if (!variable->getName().empty()) { - // if variable has name defined - we should resolve it via backing var space - if (_backed->hasVariable(variable->getName())) { - auto origVar = _backed->getVariable(variable->getName()); - variable->setId(origVar->id(), origVar->index()); - _current->replaceVariable(variable); - } else - _current->replaceVariable(variable); - } else // if proxy has variable - that's one story - _current->replaceVariable(variable); - } - - std::shared_ptr - VariableProxy::putVariable(const std::string &name, int id, int idx, const NDArray &array) { - return _current->putVariable(name, id, idx, array); - } - - void VariableProxy::putOutputVariable(std::shared_ptr variable) { - _current->putOutputVariable(variable); - } - - std::shared_ptr VariableProxy::putVariable(const std::pair& pair, const NDArray &array) { - return _current->putVariable(pair, array); - } - - - void VariableProxy::putVariable(const std::pair& pair, const std::shared_ptr &variable) { - _current->putVariable(pair, variable); - } - - - void VariableProxy::putVariable(int id, const std::shared_ptr &variable) { - _current->putVariable(id, variable); - } - - - std::shared_ptr VariableProxy::putVariable(int id, const NDArray &array) { - return _current->putVariable(id, array); - } - - std::shared_ptr VariableProxy::putVariable(int id, int idx, const NDArray &array) { - return _current->putVariable(id, idx, array); - } - - void VariableProxy::putVariable(const std::string& name, int id, int idx, const std::shared_ptr &array) { - _current->putVariable(name, id, idx, array); - } - - Stash* VariableProxy::stash() const { - return _current->stash(); - } - - Nd4jLong VariableProxy::externalMemory() const { - return _backed->externalMemory() + _current->externalMemory(); - } - - - Nd4jLong VariableProxy::internalMemory() const { - return _backed->internalMemory() + _current->internalMemory(); - } - - - Nd4jLong VariableProxy::totalMemory() const { - return _backed->totalMemory() + _current->totalMemory(); - } - - - int VariableProxy::externalEntries() const { - return _backed->externalEntries() + _current->externalEntries(); - } - - - int VariableProxy::internalEntries() const { - return _backed->internalEntries() + _current->internalEntries(); - } - - - int VariableProxy::totalEntries() const { - return _backed->totalEntries() + _current->totalEntries(); - } - - VariableSpace& VariableProxy::operator=(const VariableSpace& other) { - if (this == &other) return *this; - - nd4j_printf("VariableProxy = not implemented\n",""); - - return *this; - } - } +namespace graph { + +VariableProxy::VariableProxy(const VariableSpace *ref) { + if (ref == nullptr) _backed = new VariableSpace(); + + _backed = ref; + _current = new VariableSpace(); +} + +VariableProxy::~VariableProxy() { delete _current; } + +int VariableProxy::numberOfPlaceholders() const { + return _backed->numberOfPlaceholders(); +} + +const std::vector> &VariableProxy::placeholders() + const { + return _backed->placeholders(); +} + +bool VariableProxy::hasExternalVariable(int it) const { + return _backed->hasExternalVariable(it); +} + +bool VariableProxy::hasExternalVariable(const std::pair &pair) const { + return _backed->hasExternalVariable(pair); +} + +bool VariableProxy::hasExternalVariable(const std::string &symbol) const { + return _backed->hasExternalVariable(symbol); +} + +bool VariableProxy::hasVariable(int id) const { + return _current->hasVariable(id) || _backed->hasVariable(id); +} + +bool VariableProxy::hasVariable(int id, int idx) const { + return _current->hasVariable(id, idx) || _backed->hasVariable(id, idx); +} + +bool VariableProxy::hasVariable(const std::pair &pair) const { + return _current->hasVariable(pair) || _backed->hasVariable(pair); +} + +void VariableProxy::dropVariable(const std::pair &pair) { + dropVariable(pair.first, pair.second); +} + +void VariableProxy::dropVariable(int id, int idx) { + assert(_current->hasVariable(id, idx)); + + _current->dropVariable(id, idx); +} + +std::vector> VariableProxy::variables() const { + std::vector> result; + + auto b = _backed->variables(); + auto c = _current->variables(); + + for (auto v : b) result.emplace_back(v); + + for (auto v : c) result.emplace_back(v); + + return result; +} + +bool VariableProxy::hasVariable(const std::string &symbol) const { + return _current->hasVariable(symbol) || _backed->hasVariable(symbol); +} + +std::shared_ptr VariableProxy::getVariable(int id) const { + if (_current->hasVariable(id)) return _current->getVariable(id); + + if (_backed->hasVariable(id)) return _backed->getVariable(id); + + nd4j_printf("Unable to get Variable from proxy: [%i]\n", id); + throw std::runtime_error("Bad arguments"); +} + +std::shared_ptr VariableProxy::getVariable(int id, int idx) const { + if (_current->hasVariable(id, idx)) return _current->getVariable(id, idx); + + if (_backed->hasVariable(id, idx)) return _backed->getVariable(id, idx); + + nd4j_printf("Unable to get Variable from proxy: [%i:%i]\n", id, idx); + throw std::runtime_error("Bad arguments"); +} + +std::shared_ptr VariableProxy::getVariable( + const std::pair &pair) const { + if (_current->hasVariable(pair)) return _current->getVariable(pair); + + if (_backed->hasVariable(pair)) return _backed->getVariable(pair); + + nd4j_printf("Unable to get Variable from proxy: [%i:%i]\n", pair.first, + pair.second); + throw std::runtime_error("Bad arguments"); +} + +std::shared_ptr VariableProxy::getVariable( + const std::string &symbol) const { + if (_current->hasVariable(symbol)) return _current->getVariable(symbol); + + if (_backed->hasVariable(symbol)) return _backed->getVariable(symbol); + + nd4j_printf("Unable to get Variable from proxy: [%s]\n", symbol.c_str()); + throw std::runtime_error("Bad arguments"); +} + +void VariableProxy::replaceVariable(std::shared_ptr variable) { + if (!variable->getName().empty()) { + // if variable has name defined - we should resolve it via backing var space + if (_backed->hasVariable(variable->getName())) { + auto origVar = _backed->getVariable(variable->getName()); + variable->setId(origVar->id(), origVar->index()); + _current->replaceVariable(variable); + } else + _current->replaceVariable(variable); + } else // if proxy has variable - that's one story + _current->replaceVariable(variable); +} + +std::shared_ptr VariableProxy::putVariable(const std::string &name, + int id, int idx, + const NDArray &array) { + return _current->putVariable(name, id, idx, array); +} + +void VariableProxy::putOutputVariable(std::shared_ptr variable) { + _current->putOutputVariable(variable); +} + +std::shared_ptr VariableProxy::putVariable( + const std::pair &pair, const NDArray &array) { + return _current->putVariable(pair, array); +} + +void VariableProxy::putVariable(const std::pair &pair, + const std::shared_ptr &variable) { + _current->putVariable(pair, variable); +} + +void VariableProxy::putVariable(int id, + const std::shared_ptr &variable) { + _current->putVariable(id, variable); +} + +std::shared_ptr VariableProxy::putVariable(int id, + const NDArray &array) { + return _current->putVariable(id, array); +} + +std::shared_ptr VariableProxy::putVariable(int id, int idx, + const NDArray &array) { + return _current->putVariable(id, idx, array); +} + +void VariableProxy::putVariable(const std::string &name, int id, int idx, + const std::shared_ptr &array) { + _current->putVariable(name, id, idx, array); +} + +Stash *VariableProxy::stash() const { return _current->stash(); } + +Nd4jLong VariableProxy::externalMemory() const { + return _backed->externalMemory() + _current->externalMemory(); +} + +Nd4jLong VariableProxy::internalMemory() const { + return _backed->internalMemory() + _current->internalMemory(); +} + +Nd4jLong VariableProxy::totalMemory() const { + return _backed->totalMemory() + _current->totalMemory(); +} + +int VariableProxy::externalEntries() const { + return _backed->externalEntries() + _current->externalEntries(); +} + +int VariableProxy::internalEntries() const { + return _backed->internalEntries() + _current->internalEntries(); +} + +int VariableProxy::totalEntries() const { + return _backed->totalEntries() + _current->totalEntries(); +} + +VariableSpace &VariableProxy::operator=(const VariableSpace &other) { + if (this == &other) return *this; + + nd4j_printf("VariableProxy = not implemented\n", ""); + + return *this; } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/VariableSpace.cpp b/libnd4j/include/graph/impl/VariableSpace.cpp index 818ef28b8b90..0e4307e69097 100644 --- a/libnd4j/include/graph/impl/VariableSpace.cpp +++ b/libnd4j/include/graph/impl/VariableSpace.cpp @@ -22,370 +22,362 @@ #include namespace sd { - namespace graph { - Stash* VariableSpace::stash() const { - return const_cast(&_stash); - } - - void VariableSpace::injectVariable(const std::pair &pair, std::shared_ptr variable) { - if (pair.second == 0) { - this->_variables[pair.first] = variable; - } - - if (!variable->getName().empty()) - this->_symbolic[variable->getName()] = variable; - - this->_paired[pair] = variable; - } - - const std::vector>& VariableSpace::placeholders() const { - return _placeholders; - } - - int VariableSpace::numberOfPlaceholders() const { - return _placeholders.size(); - } - - bool VariableSpace::hasVariable(const std::string &symbol) const { - return _symbolic.count(symbol) > 0; - } - - std::shared_ptr VariableSpace::getVariable(const std::string &symbol) const { - return _symbolic.at(symbol); - } - - bool VariableSpace::hasVariable(int id, int index) const { - std::pair pair(id, index); - return hasVariable(pair); - } - - bool VariableSpace::hasExternalVariable(int id) const { - if (!hasVariable(id)) - return false; - - auto var = getVariable(id); - return var->isExternal(); - } - - bool VariableSpace::hasExternalVariable(const std::pair& pair) const { - if (!hasVariable(pair)) - return false; - - auto var = getVariable(pair); - return var->isExternal(); - } - - bool VariableSpace::hasExternalVariable(const std::string &symbol) const { - if (!hasVariable(symbol)) - return false; - - auto var = getVariable(symbol); - return var->isExternal(); - } - - std::shared_ptr VariableSpace::getVariable(int id, int index) const { - std::pair pair(id, index); - return getVariable(pair); - } - - std::shared_ptr VariableSpace::getVariable(const std::pair& pair) const { - if (pair.first < 0) - return getVariable(pair.first); - else - return _paired.at(pair); - - nd4j_printf("Unknown variable requested: [%i,%i]\n", pair.first, pair.second); - throw std::runtime_error("Unknown variable requested"); - } - - bool VariableSpace::hasVariable(int id) const { - return _variables.count(id) > 0; - } - - bool VariableSpace::hasVariable(const std::pair& id) const { - return _paired.count(id) > 0; - } - - void VariableSpace::putOutputVariable(std::shared_ptr variable) { - //putVariable(_auto_counter--, variable); - putVariable(variable->id(), variable); - } - - int VariableSpace::externalEntries() const { - return _external.size(); - } - - int VariableSpace::internalEntries() const { - return _internal.size(); - } - - int VariableSpace::totalEntries() const { - return externalEntries() + internalEntries(); - } - - Nd4jLong VariableSpace::externalMemory() const { - Nd4jLong size = 0; - for (auto n: _external) { - size += n->getNDArray()->memoryFootprint(); - } +namespace graph { +Stash *VariableSpace::stash() const { return const_cast(&_stash); } - return size; - } +void VariableSpace::injectVariable(const std::pair &pair, + std::shared_ptr variable) { + if (pair.second == 0) { + this->_variables[pair.first] = variable; + } - std::vector> VariableSpace::variables() const { - std::vector> result; + if (!variable->getName().empty()) + this->_symbolic[variable->getName()] = variable; - for (auto v: _internal) - result.emplace_back(v); - - for (auto v: _external) - result.emplace_back(v); - - return result; - } - - Nd4jLong VariableSpace::internalMemory() const { - Nd4jLong size = 0; - for (auto n: _internal) { - size += n->getNDArray()->memoryFootprint(); - } - - return size; - } - - Nd4jLong VariableSpace::totalMemory() const { - return externalMemory() + internalMemory(); - } - - std::shared_ptr VariableSpace::putVariable(int id, int idx, const std::shared_ptr &array) { - auto variable = std::make_shared(array, "", id, idx); - this->putVariable({id, idx}, variable); - return variable; - } - - std::shared_ptr - VariableSpace::putVariable(const std::string &name, int id, int idx, const NDArray &array) { - auto variable = std::make_shared(array, name, id, idx); - this->putVariable({id, idx}, variable); - return variable; - } - - void VariableSpace::dropVariable(const std::string &pair) { - throw std::runtime_error("VariableSpace::dropVariable - not implemented yet"); - } + this->_paired[pair] = variable; +} +const std::vector> &VariableSpace::placeholders() + const { + return _placeholders; +} - std::shared_ptr VariableSpace::putVariable(const std::pair& pair, const NDArray &array) { - auto variable = std::make_shared(array, "", pair.first, pair.second); - this->putVariable(pair, variable); - return variable; - } +int VariableSpace::numberOfPlaceholders() const { return _placeholders.size(); } - std::shared_ptr VariableSpace::putVariable(int node, int idx, const NDArray &array) { - std::pair pair(node, idx); - return this->putVariable(pair, array); - } +bool VariableSpace::hasVariable(const std::string &symbol) const { + return _symbolic.count(symbol) > 0; +} - void VariableSpace::putVariable(const std::string& name, int node, int idx, const std::shared_ptr &variable) { - std::pair pair(node, idx); - variable->setName(name); - this->putVariable(pair, variable); - } +std::shared_ptr VariableSpace::getVariable( + const std::string &symbol) const { + return _symbolic.at(symbol); +} - void VariableSpace::silentPutVariable(const std::pair& pair, const std::shared_ptr &variable) { - std::lock_guard lock(_varmap); +bool VariableSpace::hasVariable(int id, int index) const { + std::pair pair(id, index); + return hasVariable(pair); +} - _paired[pair] = variable; - } +bool VariableSpace::hasExternalVariable(int id) const { + if (!hasVariable(id)) return false; - void VariableSpace::putVariable(const std::pair& pair, const std::shared_ptr &variable) { - silentPutVariable(pair, variable); + auto var = getVariable(id); + return var->isExternal(); +} - if (variable->isPlaceholder()) - _placeholders.emplace_back(variable); +bool VariableSpace::hasExternalVariable(const std::pair &pair) const { + if (!hasVariable(pair)) return false; - // copying duplicate for compatibility - if (pair.second == 0 && !this->hasVariable(pair.first)) { - this->putVariable(pair.first, variable); - } - - if (!variable->getName().empty()) { - _symbolic[variable->getName()] = variable; - } - } - - - void VariableSpace::putVariable(int id, const std::shared_ptr &variable) { - // we don't want to add variables more then once - if (_variables.count(id) > 0) { - throw std::runtime_error("VariableSpace::putVariable - duplicate found"); - } - - { - std::lock_guard lock(_varmap); - - if (_auto_counter >= id) - _auto_counter = id - 1; - - variable->setId(id); - - if (!variable->getName().empty()) { - //std::pair pair(*(variable->getName()), variable); - _symbolic[variable->name()] = variable; - } - - // we have special list for external variables to ensure graph completeness - if (id < 0) { - _external.emplace_back(variable); - } else { - _internal.emplace_back(variable); - } - - _variables[id] = variable; - } - - - std::pair pair(id, 0); - if (!hasVariable(pair)) { - this->silentPutVariable(pair, variable); - - if (variable->isPlaceholder()) - _placeholders.emplace_back(variable); - } - } - - std::shared_ptr VariableSpace::putVariable(int id, const NDArray &array) { - auto var = std::make_shared(array, "", id, 0); - this->putVariable(id, var); - return var; - } - - std::shared_ptr VariableSpace::getVariable(int id) const { - return _variables.at(id); - } + auto var = getVariable(pair); + return var->isExternal(); +} + +bool VariableSpace::hasExternalVariable(const std::string &symbol) const { + if (!hasVariable(symbol)) return false; + + auto var = getVariable(symbol); + return var->isExternal(); +} + +std::shared_ptr VariableSpace::getVariable(int id, int index) const { + std::pair pair(id, index); + return getVariable(pair); +} + +std::shared_ptr VariableSpace::getVariable( + const std::pair &pair) const { + if (pair.first < 0) + return getVariable(pair.first); + else + return _paired.at(pair); + + nd4j_printf("Unknown variable requested: [%i,%i]\n", pair.first, pair.second); + throw std::runtime_error("Unknown variable requested"); +} + +bool VariableSpace::hasVariable(int id) const { + return _variables.count(id) > 0; +} + +bool VariableSpace::hasVariable(const std::pair &id) const { + return _paired.count(id) > 0; +} + +void VariableSpace::putOutputVariable(std::shared_ptr variable) { + // putVariable(_auto_counter--, variable); + putVariable(variable->id(), variable); +} + +int VariableSpace::externalEntries() const { return _external.size(); } + +int VariableSpace::internalEntries() const { return _internal.size(); } + +int VariableSpace::totalEntries() const { + return externalEntries() + internalEntries(); +} + +Nd4jLong VariableSpace::externalMemory() const { + Nd4jLong size = 0; + for (auto n : _external) { + size += n->getNDArray()->memoryFootprint(); + } + + return size; +} + +std::vector> VariableSpace::variables() const { + std::vector> result; + + for (auto v : _internal) result.emplace_back(v); + + for (auto v : _external) result.emplace_back(v); + + return result; +} + +Nd4jLong VariableSpace::internalMemory() const { + Nd4jLong size = 0; + for (auto n : _internal) { + size += n->getNDArray()->memoryFootprint(); + } + + return size; +} + +Nd4jLong VariableSpace::totalMemory() const { + return externalMemory() + internalMemory(); +} + +std::shared_ptr VariableSpace::putVariable( + int id, int idx, const std::shared_ptr &array) { + auto variable = std::make_shared(array, "", id, idx); + this->putVariable({id, idx}, variable); + return variable; +} + +std::shared_ptr VariableSpace::putVariable(const std::string &name, + int id, int idx, + const NDArray &array) { + auto variable = std::make_shared(array, name, id, idx); + this->putVariable({id, idx}, variable); + return variable; +} + +void VariableSpace::dropVariable(const std::string &pair) { + throw std::runtime_error("VariableSpace::dropVariable - not implemented yet"); +} + +std::shared_ptr VariableSpace::putVariable( + const std::pair &pair, const NDArray &array) { + auto variable = + std::make_shared(array, "", pair.first, pair.second); + this->putVariable(pair, variable); + return variable; +} + +std::shared_ptr VariableSpace::putVariable(int node, int idx, + const NDArray &array) { + std::pair pair(node, idx); + return this->putVariable(pair, array); +} + +void VariableSpace::putVariable(const std::string &name, int node, int idx, + const std::shared_ptr &variable) { + std::pair pair(node, idx); + variable->setName(name); + this->putVariable(pair, variable); +} + +void VariableSpace::silentPutVariable( + const std::pair &pair, + const std::shared_ptr &variable) { + std::lock_guard lock(_varmap); + + _paired[pair] = variable; +} + +void VariableSpace::putVariable(const std::pair &pair, + const std::shared_ptr &variable) { + silentPutVariable(pair, variable); + + if (variable->isPlaceholder()) _placeholders.emplace_back(variable); + + // copying duplicate for compatibility + if (pair.second == 0 && !this->hasVariable(pair.first)) { + this->putVariable(pair.first, variable); + } + + if (!variable->getName().empty()) { + _symbolic[variable->getName()] = variable; + } +} + +void VariableSpace::putVariable(int id, + const std::shared_ptr &variable) { + // we don't want to add variables more then once + if (_variables.count(id) > 0) { + throw std::runtime_error("VariableSpace::putVariable - duplicate found"); + } + + { + std::lock_guard lock(_varmap); + + if (_auto_counter >= id) _auto_counter = id - 1; + + variable->setId(id); + + if (!variable->getName().empty()) { + // std::pair pair(*(variable->getName()), + // variable); + _symbolic[variable->name()] = variable; + } - VariableSpace::~VariableSpace() { - // - } + // we have special list for external variables to ensure graph completeness + if (id < 0) { + _external.emplace_back(variable); + } else { + _internal.emplace_back(variable); + } - VariableSpace::VariableSpace(const VariableSpace &other) { - _stash = other._stash; + _variables[id] = variable; + } - _paired = other._paired; - _symbolic = other._symbolic; - _variables = other._variables; + std::pair pair(id, 0); + if (!hasVariable(pair)) { + this->silentPutVariable(pair, variable); - _external = other._external; - _internal = other._internal; + if (variable->isPlaceholder()) _placeholders.emplace_back(variable); + } +} - _lists = other._lists; - _placeholders = other._placeholders; +std::shared_ptr VariableSpace::putVariable(int id, + const NDArray &array) { + auto var = std::make_shared(array, "", id, 0); + this->putVariable(id, var); + return var; +} +std::shared_ptr VariableSpace::getVariable(int id) const { + return _variables.at(id); +} - _auto_counter = other._auto_counter; - } +VariableSpace::~VariableSpace() { + // +} - VariableSpace::VariableSpace(VariableSpace &&other) { - _stash = std::move(other._stash); +VariableSpace::VariableSpace(const VariableSpace &other) { + _stash = other._stash; - _paired = std::move(other._paired); - _symbolic = std::move(other._symbolic); - _variables = std::move(other._variables); + _paired = other._paired; + _symbolic = other._symbolic; + _variables = other._variables; - _external = std::move(other._external); - _internal = std::move(other._internal); + _external = other._external; + _internal = other._internal; - _lists = std::move(other._lists); - _placeholders = std::move(other._placeholders); + _lists = other._lists; + _placeholders = other._placeholders; + _auto_counter = other._auto_counter; +} - _auto_counter = other._auto_counter; - } +VariableSpace::VariableSpace(VariableSpace &&other) { + _stash = std::move(other._stash); - VariableSpace& VariableSpace::operator=(VariableSpace &&other) { - if (this == &other) return *this; + _paired = std::move(other._paired); + _symbolic = std::move(other._symbolic); + _variables = std::move(other._variables); - _stash = std::move(other._stash); + _external = std::move(other._external); + _internal = std::move(other._internal); - _paired = std::move(other._paired); - _symbolic = std::move(other._symbolic); - _variables = std::move(other._variables); + _lists = std::move(other._lists); + _placeholders = std::move(other._placeholders); - _external = std::move(other._external); - _internal = std::move(other._internal); + _auto_counter = other._auto_counter; +} - _lists = std::move(other._lists); - _placeholders = std::move(other._placeholders); +VariableSpace &VariableSpace::operator=(VariableSpace &&other) { + if (this == &other) return *this; + _stash = std::move(other._stash); - _auto_counter = other._auto_counter; + _paired = std::move(other._paired); + _symbolic = std::move(other._symbolic); + _variables = std::move(other._variables); - return *this; - } + _external = std::move(other._external); + _internal = std::move(other._internal); - VariableSpace& VariableSpace::operator=(const VariableSpace& other) { - if (this == &other) return *this; + _lists = std::move(other._lists); + _placeholders = std::move(other._placeholders); - _stash = other._stash; + _auto_counter = other._auto_counter; - _paired = other._paired; - _symbolic = other._symbolic; - _variables = other._variables; + return *this; +} - _external = other._external; - _internal = other._internal; +VariableSpace &VariableSpace::operator=(const VariableSpace &other) { + if (this == &other) return *this; - _lists = other._lists; - _placeholders = other._placeholders; + _stash = other._stash; + _paired = other._paired; + _symbolic = other._symbolic; + _variables = other._variables; - _auto_counter = other._auto_counter; + _external = other._external; + _internal = other._internal; - return *this; - } + _lists = other._lists; + _placeholders = other._placeholders; - void VariableSpace::replaceVariable(std::shared_ptr variable) { - bool replaced = false; - // trying name first - if (!variable->getName().empty()) { - nd4j_printf("Trying to replace variable by name: [%s]\n", variable->getName().c_str()); - if (hasVariable(variable->getName())) { - nd4j_printf("Replacing by name: [%s]\n", variable->getName().c_str()); - auto vs = getVariable(variable->getName()); - dropVariable(vs->id(), vs->index()); + _auto_counter = other._auto_counter; - putVariable({vs->id(), vs->index()}, variable); - //delete vs; - replaced = true; - } - } else { - nd4j_printf("Trying to replace variable by id: [%i:%i]\n", variable->id(), variable->index()); - if (hasVariable(variable->id(), variable->index())) { - nd4j_printf("Replacing by id: [%i:%i]\n", variable->id(), variable->index()); - auto vs = getVariable(variable->id(), variable->index()); - dropVariable(variable->id(), variable->index()); - putVariable({vs->id(), vs->index()}, variable); - //delete vs; - replaced = true; - } - } + return *this; +} - if (!replaced) { - nd4j_printf("wasn't able to replace variable, putting\n", ""); - putVariable({variable->id(), variable->index()}, variable); - } - } +void VariableSpace::replaceVariable(std::shared_ptr variable) { + bool replaced = false; + // trying name first + if (!variable->getName().empty()) { + nd4j_printf("Trying to replace variable by name: [%s]\n", + variable->getName().c_str()); + if (hasVariable(variable->getName())) { + nd4j_printf("Replacing by name: [%s]\n", variable->getName().c_str()); + auto vs = getVariable(variable->getName()); + dropVariable(vs->id(), vs->index()); - void VariableSpace::dropVariable(const std::pair &pair) { - dropVariable(pair.first, pair.second); - } + putVariable({vs->id(), vs->index()}, variable); + // delete vs; + replaced = true; + } + } else { + nd4j_printf("Trying to replace variable by id: [%i:%i]\n", variable->id(), + variable->index()); + if (hasVariable(variable->id(), variable->index())) { + nd4j_printf("Replacing by id: [%i:%i]\n", variable->id(), + variable->index()); + auto vs = getVariable(variable->id(), variable->index()); + dropVariable(variable->id(), variable->index()); + putVariable({vs->id(), vs->index()}, variable); + // delete vs; + replaced = true; + } + } - void VariableSpace::dropVariable(int id, int idx) { + if (!replaced) { + nd4j_printf("wasn't able to replace variable, putting\n", ""); + putVariable({variable->id(), variable->index()}, variable); + } +} - } +void VariableSpace::dropVariable(const std::pair &pair) { + dropVariable(pair.first, pair.second); +} - VariableSpace::VariableSpace() { +void VariableSpace::dropVariable(int id, int idx) {} - } - } -} \ No newline at end of file +VariableSpace::VariableSpace() {} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/impl/VariablesSet.cpp b/libnd4j/include/graph/impl/VariablesSet.cpp index 80f8e3728949..2920014cdd44 100644 --- a/libnd4j/include/graph/impl/VariablesSet.cpp +++ b/libnd4j/include/graph/impl/VariablesSet.cpp @@ -21,30 +21,21 @@ #include namespace sd { - namespace graph { - Nd4jStatus VariablesSet::status() { - return _status; - } - - int VariablesSet::size() { - return _holder.size(); - } - - void VariablesSet::push_back(Variable *variable) { - _holder.push_back(variable); - } - - Variable *VariablesSet::at(int index) { - return _holder.at(index); - } - - VariablesSet::VariablesSet(Nd4jStatus status) { - _status = status; - } - - VariablesSet::~VariablesSet() { - for (auto v: _holder) - delete v; - } - } +namespace graph { +Nd4jStatus VariablesSet::status() { return _status; } + +int VariablesSet::size() { return _holder.size(); } + +void VariablesSet::push_back(Variable *variable) { + _holder.push_back(variable); +} + +Variable *VariablesSet::at(int index) { return _holder.at(index); } + +VariablesSet::VariablesSet(Nd4jStatus status) { _status = status; } + +VariablesSet::~VariablesSet() { + for (auto v : _holder) delete v; } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/logic/LogicConditional.h b/libnd4j/include/graph/logic/LogicConditional.h index ffaf6f098f99..d84aa584c76f 100644 --- a/libnd4j/include/graph/logic/LogicConditional.h +++ b/libnd4j/include/graph/logic/LogicConditional.h @@ -21,29 +21,29 @@ #ifndef LIBND4J_LOGICCONDITIONAL_H #define LIBND4J_LOGICCONDITIONAL_H -#include -#include #include +#include +#include namespace sd { - namespace graph { - /** - * This class is responsible for execution logic of Conditional logical abstraction - * - * TL/DR: Class takes 2 ops/scopes with the same number of inputs/outputs and condtion. - * Condition is evaluated, and based on its result - one of ops/scopes is executed. - * Results of this execution will be copied to Conditional node, and every other op - * in the graph will be sure that it's Conditional own result, both alternative nodes will - * stay in disguise. - * - * @tparam T - */ - class LogicConditional { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - +namespace graph { +/** + * This class is responsible for execution logic of Conditional logical + * abstraction + * + * TL/DR: Class takes 2 ops/scopes with the same number of inputs/outputs and + * condtion. Condition is evaluated, and based on its result - one of ops/scopes + * is executed. Results of this execution will be copied to Conditional node, + * and every other op in the graph will be sure that it's Conditional own + * result, both alternative nodes will stay in disguise. + * + * @tparam T + */ +class LogicConditional { + public: + static Nd4jStatus processNode(Graph* graph, Node* node); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICCONDITIONAL_H +#endif // LIBND4J_LOGICCONDITIONAL_H diff --git a/libnd4j/include/graph/logic/LogicEnter.h b/libnd4j/include/graph/logic/LogicEnter.h index d770ff10a443..f0ba6a439767 100644 --- a/libnd4j/include/graph/logic/LogicEnter.h +++ b/libnd4j/include/graph/logic/LogicEnter.h @@ -21,18 +21,16 @@ #ifndef LIBND4J_LOGICENTER_H #define LIBND4J_LOGICENTER_H -#include #include +#include namespace sd { - namespace graph { - class LogicEnter { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - +namespace graph { +class LogicEnter { + public: + static Nd4jStatus processNode(Graph* graph, Node* node); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICEXIT_H +#endif // LIBND4J_LOGICEXIT_H diff --git a/libnd4j/include/graph/logic/LogicExecutor.h b/libnd4j/include/graph/logic/LogicExecutor.h index 541b3fc8425b..3dea3ec2e9b9 100644 --- a/libnd4j/include/graph/logic/LogicExecutor.h +++ b/libnd4j/include/graph/logic/LogicExecutor.h @@ -21,23 +21,22 @@ #ifndef LIBND4J_LOGICEXECUTOR_H #define LIBND4J_LOGICEXECUTOR_H -#include -#include #include +#include +#include namespace sd { - namespace graph { - /** - * This class acts as switch for picking logic execution based on opNum, unique for each logical op - * @tparam T - */ - class LogicExecutor { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - +namespace graph { +/** + * This class acts as switch for picking logic execution based on opNum, unique + * for each logical op + * @tparam T + */ +class LogicExecutor { + public: + static Nd4jStatus processNode(Graph* graph, Node* node); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICEXECUTOR_H +#endif // LIBND4J_LOGICEXECUTOR_H diff --git a/libnd4j/include/graph/logic/LogicExit.h b/libnd4j/include/graph/logic/LogicExit.h index d182e26fbf39..216409c38251 100644 --- a/libnd4j/include/graph/logic/LogicExit.h +++ b/libnd4j/include/graph/logic/LogicExit.h @@ -21,18 +21,16 @@ #ifndef LIBND4J_LOGICEXIT_H #define LIBND4J_LOGICEXIT_H -#include #include +#include namespace sd { - namespace graph { - class LogicExit { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - +namespace graph { +class LogicExit { + public: + static Nd4jStatus processNode(Graph* graph, Node* node); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICEXIT_H +#endif // LIBND4J_LOGICEXIT_H diff --git a/libnd4j/include/graph/logic/LogicExpose.h b/libnd4j/include/graph/logic/LogicExpose.h index 046f3e64e9a6..6e4bb5e1b937 100644 --- a/libnd4j/include/graph/logic/LogicExpose.h +++ b/libnd4j/include/graph/logic/LogicExpose.h @@ -21,19 +21,17 @@ #ifndef LIBND4J_LOGICEXPOSE_H #define LIBND4J_LOGICEXPOSE_H -#include -#include #include +#include +#include namespace sd { - namespace graph { - class LogicExpose { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - +namespace graph { +class LogicExpose { + public: + static Nd4jStatus processNode(Graph* graph, Node* node); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICEXPOSE_H +#endif // LIBND4J_LOGICEXPOSE_H diff --git a/libnd4j/include/graph/logic/LogicLoopCond.h b/libnd4j/include/graph/logic/LogicLoopCond.h index 36693232be90..670c0d07faee 100644 --- a/libnd4j/include/graph/logic/LogicLoopCond.h +++ b/libnd4j/include/graph/logic/LogicLoopCond.h @@ -21,18 +21,16 @@ #ifndef LIBND4J_LOGICLOOPCOND_H #define LIBND4J_LOGICLOOPCOND_H -#include #include +#include namespace sd { - namespace graph { - class LogicLoopCond { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - +namespace graph { +class LogicLoopCond { + public: + static Nd4jStatus processNode(Graph* graph, Node* node); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICLOOPCOND_H +#endif // LIBND4J_LOGICLOOPCOND_H diff --git a/libnd4j/include/graph/logic/LogicMerge.h b/libnd4j/include/graph/logic/LogicMerge.h index fe20c9d660ed..8bd8cbe7d5a1 100644 --- a/libnd4j/include/graph/logic/LogicMerge.h +++ b/libnd4j/include/graph/logic/LogicMerge.h @@ -21,18 +21,16 @@ #ifndef LIBND4J_LOGICMERGE_H #define LIBND4J_LOGICMERGE_H -#include #include +#include namespace sd { - namespace graph { - class LogicMerge { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - +namespace graph { +class LogicMerge { + public: + static Nd4jStatus processNode(Graph* graph, Node* node); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICMERGE_H +#endif // LIBND4J_LOGICMERGE_H diff --git a/libnd4j/include/graph/logic/LogicNextIteration.h b/libnd4j/include/graph/logic/LogicNextIteration.h index 5b9600909ea9..415b44f6dbcc 100644 --- a/libnd4j/include/graph/logic/LogicNextIteration.h +++ b/libnd4j/include/graph/logic/LogicNextIteration.h @@ -21,18 +21,16 @@ #ifndef LIBND4J_LOGICNEXTITERATION_H #define LIBND4J_LOGICNEXTITERATION_H -#include #include +#include namespace sd { - namespace graph { - class LogicNextIeration { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - +namespace graph { +class LogicNextIeration { + public: + static Nd4jStatus processNode(Graph* graph, Node* node); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICNEXTITERATION_H +#endif // LIBND4J_LOGICNEXTITERATION_H diff --git a/libnd4j/include/graph/logic/LogicReturn.h b/libnd4j/include/graph/logic/LogicReturn.h index 2cc6107c5f6b..8c342b091417 100644 --- a/libnd4j/include/graph/logic/LogicReturn.h +++ b/libnd4j/include/graph/logic/LogicReturn.h @@ -21,26 +21,24 @@ #ifndef LIBND4J_LOGICRETURN_H #define LIBND4J_LOGICRETURN_H - -#include -#include #include +#include +#include namespace sd { - namespace graph { - /** - * This class is responsible for execution logic of Return logical abstraction - * - * Basically we're just transferring input variable(s) to output variable(s), nothing beyond that - * @tparam T - */ - class LogicReturn { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - - - -#endif //LIBND4J_LOGICRETURN_H +namespace graph { +/** + * This class is responsible for execution logic of Return logical abstraction + * + * Basically we're just transferring input variable(s) to output variable(s), + * nothing beyond that + * @tparam T + */ +class LogicReturn { + public: + static Nd4jStatus processNode(Graph* graph, Node* node); +}; +} // namespace graph +} // namespace sd + +#endif // LIBND4J_LOGICRETURN_H diff --git a/libnd4j/include/graph/logic/LogicScope.h b/libnd4j/include/graph/logic/LogicScope.h index a7a8d6b7a9c6..17d13d83ac91 100644 --- a/libnd4j/include/graph/logic/LogicScope.h +++ b/libnd4j/include/graph/logic/LogicScope.h @@ -21,25 +21,24 @@ #ifndef LIBND4J_LOGICSCOPE_H #define LIBND4J_LOGICSCOPE_H -#include -#include #include +#include +#include namespace sd { - namespace graph { - /** - * This class is responsible for execution logic of Scope logical abstraction - * - * It's ultra-simple. It does nothing, and can't be executed directly. - * - * @tparam T - */ - class LogicScope { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - +namespace graph { +/** + * This class is responsible for execution logic of Scope logical abstraction + * + * It's ultra-simple. It does nothing, and can't be executed directly. + * + * @tparam T + */ +class LogicScope { + public: + static Nd4jStatus processNode(Graph* graph, Node* node); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICSCOPE_H +#endif // LIBND4J_LOGICSCOPE_H diff --git a/libnd4j/include/graph/logic/LogicSwitch.h b/libnd4j/include/graph/logic/LogicSwitch.h index d91959d91eff..d74ce87e4908 100644 --- a/libnd4j/include/graph/logic/LogicSwitch.h +++ b/libnd4j/include/graph/logic/LogicSwitch.h @@ -21,25 +21,24 @@ #ifndef LIBND4J_LOGICSWITCH_H #define LIBND4J_LOGICSWITCH_H -#include -#include #include +#include +#include namespace sd { - namespace graph { - /** - * This class is responsible for execution logic of Switch logical abstraction - * - * It's ultra-simple. It does nothing, and can't be executed directly. - * - * @tparam T - */ - class LogicSwitch { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - +namespace graph { +/** + * This class is responsible for execution logic of Switch logical abstraction + * + * It's ultra-simple. It does nothing, and can't be executed directly. + * + * @tparam T + */ +class LogicSwitch { + public: + static Nd4jStatus processNode(Graph* graph, Node* node); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICSWITCH_H +#endif // LIBND4J_LOGICSWITCH_H diff --git a/libnd4j/include/graph/logic/LogicWhile.h b/libnd4j/include/graph/logic/LogicWhile.h index 6e4b2ea3ae24..e80d742cbf41 100644 --- a/libnd4j/include/graph/logic/LogicWhile.h +++ b/libnd4j/include/graph/logic/LogicWhile.h @@ -21,24 +21,24 @@ #ifndef LIBND4J_LOGICWHILE_H #define LIBND4J_LOGICWHILE_H -#include -#include #include +#include +#include namespace sd { - namespace graph { - /** - * This class is responsible for execution logic of While logical abstraction - * - * Basic idea is simple: we take 2 scopes, one for condition and other one for body. and we re-execute body as long, as condition scope evaluates to TRUE - * @tparam T - */ - class LogicWhile { - public: - static Nd4jStatus processNode(Graph* graph, Node* node); - }; - } -} - +namespace graph { +/** + * This class is responsible for execution logic of While logical abstraction + * + * Basic idea is simple: we take 2 scopes, one for condition and other one for + * body. and we re-execute body as long, as condition scope evaluates to TRUE + * @tparam T + */ +class LogicWhile { + public: + static Nd4jStatus processNode(Graph* graph, Node* node); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_LOGICWHILE_H +#endif // LIBND4J_LOGICWHILE_H diff --git a/libnd4j/include/graph/logic/impl/LogicConditional.cpp b/libnd4j/include/graph/logic/impl/LogicConditional.cpp index a1ba5a9d0007..ec1437d7bbd5 100644 --- a/libnd4j/include/graph/logic/impl/LogicConditional.cpp +++ b/libnd4j/include/graph/logic/impl/LogicConditional.cpp @@ -18,121 +18,122 @@ // Created by raver119 on 20.10.2017. // +#include #include #include -#include - namespace sd { - namespace graph { - Nd4jStatus LogicConditional::processNode(Graph *graph, Node *node) { - throw std::runtime_error("LogicConditional::processNode - not implemented yet"); - /* - auto __variableSpace = graph->variableSpace(); - - auto size = node->input()->size(); - - // propagating inputs (optional) - for (int e = 0; e < size - 3; e++) { - std::pair pair(node->id(), e); - if (!__variableSpace->hasVariable(pair)) { - __variableSpace->putVariable(pair, new Variable(nullptr, nullptr, node->id(), e)); - } - - auto va = node->input()->at(e); - - auto inputVar = __variableSpace->getVariable(va); - - auto innerVar = __variableSpace->getVariable(pair); - if (innerVar->hasNDArray()) { - // TODO: ??? - } else { - // FIXME: in some cases it's possible to have no NDArray - if (inputVar->hasNDArray()) - innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup())); - } - } - - - int scopeConditionIndex = node->input()->at(size - 3).first; - int scopeFalseIndex = node->input()->at(size - 2).first; - int scopeTrueIndex = node->input()->at(size - 1).first; - - auto scopeCondition = graph->scopeById(scopeConditionIndex); - int lastNode = 0; - for (auto v: *scopeCondition->nodes()) { - GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - lastNode = v->id(); - } - - // now we should take result of the Scope run, and evaluate it - //nd4j_debug("", ""); - auto result = __variableSpace->getVariable(lastNode)->getNDArray(); - //result->printBuffer("Result of the last node:"); - - bool isReturn = false; - - // now we're executing one of the scopes, depending on condition evaluation - if (result->e(0) == 0) { - auto scopeFalse = graph->scopeById(scopeFalseIndex); - lastNode = 0; - int nodes = scopeFalse->nodes()->size(); - for (int e = 0; e < nodes - 1; e++) { - auto v = scopeFalse->nodes()->at(e); - GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - lastNode = v->id(); - } - - // last node is either return or just last op - auto *node = scopeFalse->nodes()->at(nodes -1); - if (node->opType() == OpType_LOGIC && node->opNum() == 40) { - isReturn = true; - LogicReturn::processNode(graph, node); - } else { - GraphExecutioner::executeFlatNode(graph, node, __variableSpace); - lastNode = node->id(); - } - } else { - auto scopeTrue = graph->scopeById(scopeTrueIndex); - lastNode = 0; - int nodes = scopeTrue->nodes()->size(); - for (int e = 0; e < nodes - 1; e++) { - auto v = scopeTrue->nodes()->at(e); - GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - lastNode = v->id(); - } - - // last node is either return or just last op - auto node = scopeTrue->nodes()->at(nodes -1); - if (node->opType() == OpType_LOGIC && node->opNum() == 40) { - isReturn = true; - LogicReturn::processNode(graph, node); - } else { - GraphExecutioner::executeFlatNode(graph, node, __variableSpace); - lastNode = node->id(); - } - } - - // now fetch and transfer variables to Conditional node - // but only if return wasn't called at the end of scope - if (!isReturn) { - for (int e = 0; e < DataTypeUtils::max(); e++) { - std::pair pair(lastNode, e); - std::pair pairNew(node->id(), e); - if (__variableSpace->hasVariable(pair)) { - auto array = __variableSpace->getVariable(pair)->getNDArray(); - auto newVar = new Variable(array); - newVar->setId(lastNode, e); - newVar->markRemovable(false); - - __variableSpace->putVariable(pairNew, newVar); - } else - break; - } - } - - return sd::Status::OK(); - */ - } - } -} \ No newline at end of file +namespace graph { +Nd4jStatus LogicConditional::processNode(Graph *graph, Node *node) { + throw std::runtime_error( + "LogicConditional::processNode - not implemented yet"); + /* + auto __variableSpace = graph->variableSpace(); + + auto size = node->input()->size(); + + // propagating inputs (optional) + for (int e = 0; e < size - 3; e++) { + std::pair pair(node->id(), e); + if (!__variableSpace->hasVariable(pair)) { + __variableSpace->putVariable(pair, new Variable(nullptr, nullptr, + node->id(), e)); + } + + auto va = node->input()->at(e); + + auto inputVar = __variableSpace->getVariable(va); + + auto innerVar = __variableSpace->getVariable(pair); + if (innerVar->hasNDArray()) { + // TODO: ??? + } else { + // FIXME: in some cases it's possible to have no NDArray + if (inputVar->hasNDArray()) + innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup())); + } + } + + + int scopeConditionIndex = node->input()->at(size - 3).first; + int scopeFalseIndex = node->input()->at(size - 2).first; + int scopeTrueIndex = node->input()->at(size - 1).first; + + auto scopeCondition = graph->scopeById(scopeConditionIndex); + int lastNode = 0; + for (auto v: *scopeCondition->nodes()) { + GraphExecutioner::executeFlatNode(graph, v, __variableSpace); + lastNode = v->id(); + } + + // now we should take result of the Scope run, and evaluate it + //nd4j_debug("", ""); + auto result = __variableSpace->getVariable(lastNode)->getNDArray(); + //result->printBuffer("Result of the last node:"); + + bool isReturn = false; + + // now we're executing one of the scopes, depending on condition evaluation + if (result->e(0) == 0) { + auto scopeFalse = graph->scopeById(scopeFalseIndex); + lastNode = 0; + int nodes = scopeFalse->nodes()->size(); + for (int e = 0; e < nodes - 1; e++) { + auto v = scopeFalse->nodes()->at(e); + GraphExecutioner::executeFlatNode(graph, v, __variableSpace); + lastNode = v->id(); + } + + // last node is either return or just last op + auto *node = scopeFalse->nodes()->at(nodes -1); + if (node->opType() == OpType_LOGIC && node->opNum() == 40) { + isReturn = true; + LogicReturn::processNode(graph, node); + } else { + GraphExecutioner::executeFlatNode(graph, node, __variableSpace); + lastNode = node->id(); + } + } else { + auto scopeTrue = graph->scopeById(scopeTrueIndex); + lastNode = 0; + int nodes = scopeTrue->nodes()->size(); + for (int e = 0; e < nodes - 1; e++) { + auto v = scopeTrue->nodes()->at(e); + GraphExecutioner::executeFlatNode(graph, v, __variableSpace); + lastNode = v->id(); + } + + // last node is either return or just last op + auto node = scopeTrue->nodes()->at(nodes -1); + if (node->opType() == OpType_LOGIC && node->opNum() == 40) { + isReturn = true; + LogicReturn::processNode(graph, node); + } else { + GraphExecutioner::executeFlatNode(graph, node, __variableSpace); + lastNode = node->id(); + } + } + + // now fetch and transfer variables to Conditional node + // but only if return wasn't called at the end of scope + if (!isReturn) { + for (int e = 0; e < DataTypeUtils::max(); e++) { + std::pair pair(lastNode, e); + std::pair pairNew(node->id(), e); + if (__variableSpace->hasVariable(pair)) { + auto array = __variableSpace->getVariable(pair)->getNDArray(); + auto newVar = new Variable(array); + newVar->setId(lastNode, e); + newVar->markRemovable(false); + + __variableSpace->putVariable(pairNew, newVar); + } else + break; + } + } + + return sd::Status::OK(); + */ +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp index 95788197c47b..da4721bee3bd 100644 --- a/libnd4j/include/graph/logic/impl/LogicEnter.cpp +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -18,60 +18,63 @@ // @author raver119@gmail.com // -#include #include - +#include namespace sd { - namespace graph { - Nd4jStatus LogicEnter::processNode(Graph *graph, Node *node) { - throw std::runtime_error("LogicEnter::processNode - not implemented yet"); - /* - // this op replicates input variable into the frame. basically happens once for single loop. - // sure, if there's inner loop within outer loop, it'll be called once for outer loop and multiple times for inner loop +namespace graph { +Nd4jStatus LogicEnter::processNode(Graph *graph, Node *node) { + throw std::runtime_error("LogicEnter::processNode - not implemented yet"); + /* + // this op replicates input variable into the frame. basically happens once + for single loop. + // sure, if there's inner loop within outer loop, it'll be called once for + outer loop and multiple times for inner loop - auto __variableSpace = graph->variableSpace(); - auto __flowPath = __variableSpace->flowPath(); + auto __variableSpace = graph->variableSpace(); + auto __flowPath = __variableSpace->flowPath(); - // basically, first non-null variable is our target - for (int e = 0; e < node->input()->size(); e++) { - auto inputAddr = node->input()->at(e); + // basically, first non-null variable is our target + for (int e = 0; e < node->input()->size(); e++) { + auto inputAddr = node->input()->at(e); - if (__variableSpace->hasVariable(inputAddr)) { - auto var = __variableSpace->getVariable(inputAddr); - if (var->hasNDArray()) { - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName().c_str(), node->id(), 0); + if (__variableSpace->hasVariable(inputAddr)) { + auto var = __variableSpace->getVariable(inputAddr); + if (var->hasNDArray()) { + Variable *lvar = nullptr; + if (__variableSpace->hasVariable(node->id(), 0)) + lvar = __variableSpace->getVariable(node->id(), 0); + else + lvar = new Variable(nullptr, node->getName().c_str(), + node->id(), 0); - auto array = var->getNDArray(); - lvar->setNDArray(array); - lvar->markReadOnly(true); + auto array = var->getNDArray(); + lvar->setNDArray(array); + lvar->markReadOnly(true); - break; - } else if (var->hasNDArrayList()) { - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName().c_str(), node->id(), 0); + break; + } else if (var->hasNDArrayList()) { + Variable *lvar = nullptr; + if (__variableSpace->hasVariable(node->id(), 0)) + lvar = __variableSpace->getVariable(node->id(), 0); + else + lvar = new Variable(nullptr, node->getName().c_str(), + node->id(), 0); - auto list = var->getNDArrayList(); - lvar->setNDArrayList(list); - lvar->markReadOnly(true); + auto list = var->getNDArrayList(); + lvar->setNDArrayList(list); + lvar->markReadOnly(true); - break; - } else { - // FIXME: can we really have third case here? - continue; - } - } - } + break; + } else { + // FIXME: can we really have third case here? + continue; + } + } + } - return sd::Status::OK(); - */ - } - } -} \ No newline at end of file + return sd::Status::OK(); + */ +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp index 70195eb22be3..2e8fbba2b195 100644 --- a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp @@ -18,54 +18,55 @@ // Created by raver119 on 20.10.2017. // -#include -#include -#include -#include #include -#include -#include -#include #include +#include #include +#include #include +#include #include - +#include +#include +#include +#include namespace sd { - namespace graph { - Nd4jStatus LogicExecutor::processNode(Graph *graph, Node *node) { - switch (node->opNum()) { - case sd::logic::While: - return LogicWhile::processNode(graph, node); - case sd::logic::Scope: - return LogicScope::processNode(graph, node); - case sd::logic::Conditional: - return LogicConditional::processNode(graph, node); - case sd::logic::Switch: - return LogicSwitch::processNode(graph, node); - case sd::logic::Return: - return LogicReturn::processNode(graph, node); - case sd::logic::Expose: - return LogicExpose::processNode(graph, node); - case sd::logic::Merge: - return LogicMerge::processNode(graph, node); - case sd::logic::LoopCond: - return LogicLoopCond::processNode(graph, node); - case sd::logic::NextIteration: - return LogicNextIeration::processNode(graph, node); - case sd::logic::Exit: - return LogicExit::processNode(graph, node); - case sd::logic::Enter: - return LogicEnter::processNode(graph, node); - } +namespace graph { +Nd4jStatus LogicExecutor::processNode(Graph *graph, Node *node) { + switch (node->opNum()) { + case sd::logic::While: + return LogicWhile::processNode(graph, node); + case sd::logic::Scope: + return LogicScope::processNode(graph, node); + case sd::logic::Conditional: + return LogicConditional::processNode(graph, node); + case sd::logic::Switch: + return LogicSwitch::processNode(graph, node); + case sd::logic::Return: + return LogicReturn::processNode(graph, node); + case sd::logic::Expose: + return LogicExpose::processNode(graph, node); + case sd::logic::Merge: + return LogicMerge::processNode(graph, node); + case sd::logic::LoopCond: + return LogicLoopCond::processNode(graph, node); + case sd::logic::NextIteration: + return LogicNextIeration::processNode(graph, node); + case sd::logic::Exit: + return LogicExit::processNode(graph, node); + case sd::logic::Enter: + return LogicEnter::processNode(graph, node); + } - if (node->getName().empty()) { - nd4j_printf("Unknown LogicOp used at node [%i]: [%i]\n", node->id(), node->opNum()); - } else { - nd4j_printf("Unknown LogicOp used at node [%i:<%s>]: [%i]\n", node->id(), node->getName().c_str(), node->opNum()); - } - return ND4J_STATUS_BAD_INPUT; - } - } -} \ No newline at end of file + if (node->getName().empty()) { + nd4j_printf("Unknown LogicOp used at node [%i]: [%i]\n", node->id(), + node->opNum()); + } else { + nd4j_printf("Unknown LogicOp used at node [%i:<%s>]: [%i]\n", node->id(), + node->getName().c_str(), node->opNum()); + } + return ND4J_STATUS_BAD_INPUT; +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicExit.cpp b/libnd4j/include/graph/logic/impl/LogicExit.cpp index dfabcb8ac861..8753ed56c184 100644 --- a/libnd4j/include/graph/logic/impl/LogicExit.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExit.cpp @@ -20,30 +20,30 @@ #include - namespace sd { - namespace graph { - Nd4jStatus LogicExit::processNode(Graph *graph, Node *node) { - // this op is basically no-op - // we just know it exists - throw std::runtime_error("LogicExit::processNode - Not implemented yet"); -/* - auto __variableSpace = graph->variableSpace(); - auto __flowPath = __variableSpace->flowPath(); - - Context ctx(node->protoContext(), __variableSpace); - auto input = ctx.variable(0)->getNDArray(); - - std::pair pair0(node->id(), 0); - - if (!__variableSpace->hasVariable(pair0)) - __variableSpace->putVariable(pair0, new Variable(nullptr, nullptr, node->id(), 0)); - - __variableSpace->getVariable(pair0)->setNDArray(input); - __variableSpace->getVariable(pair0)->markRemovable(false); - - return ND4J_STATUS_OK; - */ - } - } -} \ No newline at end of file +namespace graph { +Nd4jStatus LogicExit::processNode(Graph *graph, Node *node) { + // this op is basically no-op + // we just know it exists + throw std::runtime_error("LogicExit::processNode - Not implemented yet"); + /* + auto __variableSpace = graph->variableSpace(); + auto __flowPath = __variableSpace->flowPath(); + + Context ctx(node->protoContext(), __variableSpace); + auto input = ctx.variable(0)->getNDArray(); + + std::pair pair0(node->id(), 0); + + if (!__variableSpace->hasVariable(pair0)) + __variableSpace->putVariable(pair0, new Variable(nullptr, + nullptr, node->id(), 0)); + + __variableSpace->getVariable(pair0)->setNDArray(input); + __variableSpace->getVariable(pair0)->markRemovable(false); + + return ND4J_STATUS_OK; + */ +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicExpose.cpp b/libnd4j/include/graph/logic/impl/LogicExpose.cpp index 06ddbc61d773..3717adab45f2 100644 --- a/libnd4j/include/graph/logic/impl/LogicExpose.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExpose.cpp @@ -21,10 +21,10 @@ #include namespace sd { - namespace graph { - Nd4jStatus LogicExpose::processNode(Graph *graph, Node *node) { - // do we really want this? - return ND4J_STATUS_OK; - } - } -} \ No newline at end of file +namespace graph { +Nd4jStatus LogicExpose::processNode(Graph *graph, Node *node) { + // do we really want this? + return ND4J_STATUS_OK; +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp index 8c464f231765..b9bc803b8ec9 100644 --- a/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp +++ b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp @@ -20,38 +20,38 @@ #include - namespace sd { - namespace graph { - Nd4jStatus LogicLoopCond::processNode(Graph *graph, Node *node) { - throw std::runtime_error("LogicLoopCond::processNode - Not implemented yet"); - /* - auto __variableSpace = graph->variableSpace(); - auto __flowPath = __variableSpace->flowPath(); - - Context ctx(node->protoContext(), __variableSpace); - auto input = ctx.variable(0)->getNDArray(); - - std::pair pair0(node->id(), 0); - - if (!__variableSpace->hasVariable(pair0)) - __variableSpace->putVariable(pair0, new Variable(nullptr, nullptr, node->id(), 0)); - - __variableSpace->getVariable(pair0)->setNDArray(input); - __variableSpace->getVariable(pair0)->markRemovable(false); - - // pass further - if (input->e(0) > 0) { - // if condition is TRUE body will be invoked some time soon - // __flowPath->markFrameActive(node->getFrameId(), true); - //__flowPath->i - } else { - // body won't be activated - // __flowPath->markFrameActive(node->getFrameId(), false); - } - - return ND4J_STATUS_OK; - */ - } - } -} \ No newline at end of file +namespace graph { +Nd4jStatus LogicLoopCond::processNode(Graph *graph, Node *node) { + throw std::runtime_error("LogicLoopCond::processNode - Not implemented yet"); + /* + auto __variableSpace = graph->variableSpace(); + auto __flowPath = __variableSpace->flowPath(); + + Context ctx(node->protoContext(), __variableSpace); + auto input = ctx.variable(0)->getNDArray(); + + std::pair pair0(node->id(), 0); + + if (!__variableSpace->hasVariable(pair0)) + __variableSpace->putVariable(pair0, new Variable(nullptr, nullptr, +node->id(), 0)); + + __variableSpace->getVariable(pair0)->setNDArray(input); + __variableSpace->getVariable(pair0)->markRemovable(false); + + // pass further + if (input->e(0) > 0) { + // if condition is TRUE body will be invoked some time soon +// __flowPath->markFrameActive(node->getFrameId(), true); + //__flowPath->i + } else { + // body won't be activated +// __flowPath->markFrameActive(node->getFrameId(), false); + } + + return ND4J_STATUS_OK; + */ +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicMerge.cpp b/libnd4j/include/graph/logic/impl/LogicMerge.cpp index a0068c968ce4..4c9f1d2baa81 100644 --- a/libnd4j/include/graph/logic/impl/LogicMerge.cpp +++ b/libnd4j/include/graph/logic/impl/LogicMerge.cpp @@ -18,120 +18,123 @@ // Created by raver119 on 30.01.18. // -#include #include +#include namespace sd { - namespace graph { - Nd4jStatus LogicMerge::processNode(Graph *graph, Node *node) { - throw std::runtime_error("LogicMerge::processNode - not implemented yet"); - /* - // at merge node only one of inputs exist if that's just switch and other node isn't LogicNextItration - auto __variableSpace = graph->variableSpace(); - auto __flowPath = __variableSpace->flowPath(); +namespace graph { +Nd4jStatus LogicMerge::processNode(Graph *graph, Node *node) { + throw std::runtime_error("LogicMerge::processNode - not implemented yet"); + /* + // at merge node only one of inputs exist if that's just switch and other node +isn't LogicNextItration auto __variableSpace = graph->variableSpace(); auto +__flowPath = __variableSpace->flowPath(); - // merge MUST have 2 inputs - auto inputAddr0 = node->input().at(0); - auto inputAddr1 = node->input().at(1); + // merge MUST have 2 inputs + auto inputAddr0 = node->input().at(0); + auto inputAddr1 = node->input().at(1); - bool isWhile = false; + bool isWhile = false; - // now we want to check if second input is NextIteration - if (graph->hasNode(inputAddr1.first)) { - auto secondNode = graph->nodeById(inputAddr1.first); + // now we want to check if second input is NextIteration + if (graph->hasNode(inputAddr1.first)) { + auto secondNode = graph->nodeById(inputAddr1.first); - // checking for NextIteration - if (secondNode->opType() == OpType_LOGIC && secondNode->opNum() == 80L) { - isWhile = true; + // checking for NextIteration + if (secondNode->opType() == OpType_LOGIC && secondNode->opNum() == 80L) { + isWhile = true; - // notifying NextIteration node for rewind index - secondNode->setRewindLayer(node->getLayer()); - secondNode->setRewindNode(node->id()); - } + // notifying NextIteration node for rewind index + secondNode->setRewindLayer(node->getLayer()); + secondNode->setRewindNode(node->id()); + } - } + } - // FIXME: we don't need this check. Just last input should survive, IF it exists - if (isWhile){ + // FIXME: we don't need this check. Just last input should survive, IF it +exists if (isWhile){ - if (node->getFrameId() >= 0) - __flowPath->markFrameActive(node->getFrameId(), true); + if (node->getFrameId() >= 0) + __flowPath->markFrameActive(node->getFrameId(), true); - bool hasVar = __variableSpace->hasVariable(inputAddr1); - if ( hasVar && __flowPath->wasExecuted(inputAddr1.first)) { - nd4j_debug("Node_%i: propagating second input\n", node->id()); - auto var = __variableSpace->getVariable(inputAddr1); + bool hasVar = __variableSpace->hasVariable(inputAddr1); + if ( hasVar && __flowPath->wasExecuted(inputAddr1.first)) { + nd4j_debug("Node_%i: propagating second input\n", node->id()); + auto var = __variableSpace->getVariable(inputAddr1); - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName().c_str(), node->id(), 0); + Variable *lvar = nullptr; + if (__variableSpace->hasVariable(node->id(), 0)) + lvar = __variableSpace->getVariable(node->id(), 0); + else + lvar = new Variable(nullptr, node->getName().c_str(), node->id(), +0); // if (lvar->hasNDArray()) // delete lvar->getNDArray(); - auto array = var->getNDArray(); + auto array = var->getNDArray(); - //array->printIndexedBuffer("propagated"); + //array->printIndexedBuffer("propagated"); - lvar->setNDArray(array); - lvar->markReadOnly(true); + lvar->setNDArray(array); + lvar->markReadOnly(true); - __flowPath->markExecuted(inputAddr1.first, false); + __flowPath->markExecuted(inputAddr1.first, false); - } else { - nd4j_debug("Node_%i: propagating first input\n", node->id()); - auto var = __variableSpace->getVariable(inputAddr0); + } else { + nd4j_debug("Node_%i: propagating first input\n", node->id()); + auto var = __variableSpace->getVariable(inputAddr0); - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName().c_str(), node->id(), 0); + Variable *lvar = nullptr; + if (__variableSpace->hasVariable(node->id(), 0)) + lvar = __variableSpace->getVariable(node->id(), 0); + else + lvar = new Variable(nullptr, node->getName().c_str(), node->id(), +0); // if (lvar->hasNDArray()) // delete lvar->getNDArray(); - auto array = var->getNDArray(); - lvar->setNDArray(array); - lvar->markReadOnly(true); + auto array = var->getNDArray(); + lvar->setNDArray(array); + lvar->markReadOnly(true); - } - } else { + } + } else { - // basically, first non-null variable is our target - for (int e = 0; e < node->input().size(); e++) { - auto inputAddr = node->input().at(e); + // basically, first non-null variable is our target + for (int e = 0; e < node->input().size(); e++) { + auto inputAddr = node->input().at(e); - if (__variableSpace->hasVariable(inputAddr)) { - auto var = __variableSpace->getVariable(inputAddr); - if (!var->hasNDArray() || !__flowPath->isNodeActive(inputAddr.first)) - continue; + if (__variableSpace->hasVariable(inputAddr)) { + auto var = __variableSpace->getVariable(inputAddr); + if (!var->hasNDArray() || +!__flowPath->isNodeActive(inputAddr.first)) continue; - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName().c_str(), node->id(), 0); + Variable *lvar = nullptr; + if (__variableSpace->hasVariable(node->id(), 0)) + lvar = __variableSpace->getVariable(node->id(), 0); + else + lvar = new Variable(nullptr, node->getName().c_str(), +node->id(), 0); - if (lvar->hasNDArray()) - delete lvar->getNDArray(); + if (lvar->hasNDArray()) + delete lvar->getNDArray(); - auto array = var->getNDArray(); - lvar->setNDArray(array); - lvar->markReadOnly(true); - //lvar->markExternal(false);h + auto array = var->getNDArray(); + lvar->setNDArray(array); + lvar->markReadOnly(true); + //lvar->markExternal(false);h - break; - } - } - } + break; + } + } + } - return Status::OK(); - */ - } - } + return Status::OK(); + */ } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp index 4a24d2fe0288..1a0eafd94395 100644 --- a/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp +++ b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp @@ -20,34 +20,34 @@ #include - namespace sd { - namespace graph { - Nd4jStatus LogicNextIeration::processNode(Graph *graph, Node *node) { - throw std::runtime_error("LogicNextIeration::processNode - not implemented yet"); - /* - auto __variableSpace = graph->variableSpace(); - auto __flowPath = __variableSpace->flowPath(); +namespace graph { +Nd4jStatus LogicNextIeration::processNode(Graph *graph, Node *node) { + throw std::runtime_error( + "LogicNextIeration::processNode - not implemented yet"); + /* + auto __variableSpace = graph->variableSpace(); + auto __flowPath = __variableSpace->flowPath(); - auto inputAddr = node->input()->at(0); + auto inputAddr = node->input()->at(0); - auto var = __variableSpace->getVariable(inputAddr); + auto var = __variableSpace->getVariable(inputAddr); - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName().c_str(), node->id(), 0); + Variable *lvar = nullptr; + if (__variableSpace->hasVariable(node->id(), 0)) + lvar = __variableSpace->getVariable(node->id(), 0); + else + lvar = new Variable(nullptr, node->getName().c_str(), node->id(), 0); // if (lvar->hasNDArray()) // delete lvar->getNDArray(); - auto array = var->getNDArray(); - lvar->setNDArray(array); - lvar->markReadOnly(true); + auto array = var->getNDArray(); + lvar->setNDArray(array); + lvar->markReadOnly(true); - return ND4J_STATUS_OK; - */ - } - } -} \ No newline at end of file + return ND4J_STATUS_OK; + */ +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicReturn.cpp b/libnd4j/include/graph/logic/impl/LogicReturn.cpp index e7112a40d109..27a8fc76c78d 100644 --- a/libnd4j/include/graph/logic/impl/LogicReturn.cpp +++ b/libnd4j/include/graph/logic/impl/LogicReturn.cpp @@ -19,40 +19,45 @@ // #include "graph/logic/LogicReturn.h" -#include + #include +#include namespace sd { - namespace graph { - Nd4jStatus LogicReturn::processNode(Graph *graph, Node *node) { - throw std::runtime_error("LogicReturn::processNode - not implemented yet"); - /* - auto __variableSpace = graph->variableSpace(); +namespace graph { +Nd4jStatus LogicReturn::processNode(Graph *graph, Node *node) { + throw std::runtime_error("LogicReturn::processNode - not implemented yet"); + /* + auto __variableSpace = graph->variableSpace(); - for (int e = 0; e < node->input()->size(); e++) { - auto inputAddr = node->input()->at(e); - auto outputAddr = node->output()->at(e); + for (int e = 0; e < node->input()->size(); e++) { + auto inputAddr = node->input()->at(e); + auto outputAddr = node->output()->at(e); - // FIXME!! - outputAddr.second = e; + // FIXME!! + outputAddr.second = e; - if (Environment::getInstance()->isDebugAndVerbose()) - nd4j_debug("Return input: <%i, %i>; Return output: <%i, %i>\n", inputAddr.first, inputAddr.second, outputAddr.first, outputAddr.second); + if (Environment::getInstance()->isDebugAndVerbose()) + nd4j_debug("Return input: <%i, %i>; Return output: <%i, %i>\n", + inputAddr.first, inputAddr.second, outputAddr.first, outputAddr.second); - auto varIn = __variableSpace->getVariable(inputAddr); - auto varOut = __variableSpace->getVariable(outputAddr); + auto varIn = __variableSpace->getVariable(inputAddr); + auto varOut = __variableSpace->getVariable(outputAddr); - nd4j_debug("Returning varType: [%s]\n", EnumUtils::_VariableTypeToString(varIn->variableType())); + nd4j_debug("Returning varType: [%s]\n", + EnumUtils::_VariableTypeToString(varIn->variableType())); - // FIXME: this is obviously wrong, we should keep depth track for backprop here - varOut->getNDArray()->assign(varIn->getNDArray()); + // FIXME: this is obviously wrong, we should keep depth track for backprop + here varOut->getNDArray()->assign(varIn->getNDArray()); - if (Environment::getInstance()->isDebugAndVerbose()) - nd4j_debug("In after: [%f]; Out after: [%f]\n", varIn->getNDArray()->meanNumber().e(0), varOut->getNDArray()->meanNumber().e(0)); - } + if (Environment::getInstance()->isDebugAndVerbose()) + nd4j_debug("In after: [%f]; Out after: [%f]\n", + varIn->getNDArray()->meanNumber().e(0), + varOut->getNDArray()->meanNumber().e(0)); + } - return sd::Status::OK(); - */ - } - } + return sd::Status::OK(); + */ } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/logic/impl/LogicScope.cpp b/libnd4j/include/graph/logic/impl/LogicScope.cpp index 5319397d6969..89738507edf9 100644 --- a/libnd4j/include/graph/logic/impl/LogicScope.cpp +++ b/libnd4j/include/graph/logic/impl/LogicScope.cpp @@ -18,16 +18,15 @@ // Created by raver119 on 20.10.2017. // -#include #include - +#include namespace sd { - namespace graph { - Nd4jStatus LogicScope::processNode(Graph *graph, Node *node) { - // this op is basically no-op - // we just know it exists - return sd::Status::OK(); - } - } -} \ No newline at end of file +namespace graph { +Nd4jStatus LogicScope::processNode(Graph *graph, Node *node) { + // this op is basically no-op + // we just know it exists + return sd::Status::OK(); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp index e58b42534fa5..8de39b462b10 100644 --- a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp +++ b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp @@ -18,93 +18,97 @@ // Created by raver119 on 21.10.17. // -#include -#include #include +#include +#include namespace sd { - namespace graph { - Nd4jStatus LogicSwitch::processNode(Graph* graph, Node* node) { - throw std::runtime_error("LogicSwitch::processNode - not implemented yet"); - /* - auto __variableSpace = graph->variableSpace(); - auto __flowPath = __variableSpace->flowPath(); - - Context ctx(node->getContextPrototype(), __variableSpace); - - // this can be either our format, or compatible format. - if (graph->hasScope(node->input()->at(0).first)) { - nd4j_debug("Node_%i: Scoped mode.\n", node->id()); - // first input is Scope, so it's ours - int scopeConditionIndex = node->input()->at(0).first; - auto input = ctx.variable(1); - - auto scopeCondition = graph->scopeById(scopeConditionIndex); - int lastNode = 0; - for (auto v: *scopeCondition->nodes()) { - GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - lastNode = v->id(); - } - - // now we should take result of the Scope run, and evaluate it - auto result = __variableSpace->getVariable(lastNode)->getNDArray(); - //result->printBuffer("Result of the last node"); - - - std::pair pair0(node->id(), 0); - std::pair pair1(node->id(), 1); - - if (!__variableSpace->hasVariable(pair0)) - __variableSpace->putVariable(pair0, new Variable(nullptr, nullptr, node->id(), 0)); - - if (!__variableSpace->hasVariable(pair1)) - __variableSpace->putVariable(pair1, new Variable(nullptr, nullptr, node->id(), 1)); - - if (!result->e(0)) { - __flowPath->markBranch(node->id(), 0); - __variableSpace->getVariable(pair0)->setNDArray(input->getNDArray()); - __variableSpace->getVariable(pair0)->markRemovable(false); - } else { - __flowPath->markBranch(node->id(), 1); - __variableSpace->getVariable(pair1)->setNDArray(input->getNDArray()); - __variableSpace->getVariable(pair1)->markRemovable(false); - } - } else { - // first input is NOT a Scope, so it's compatible format - nd4j_debug("Node_%i: Compatible mode.\n", node->id()); - - auto input = ctx.variable(0)->getNDArray(); - auto boolean = ctx.variable(1)->getNDArray(); - - //input->printIndexedBuffer("0"); - //boolean->printIndexedBuffer("1"); - - std::pair pair0(node->id(), 0); - std::pair pair1(node->id(), 1); - - if (!__variableSpace->hasVariable(pair0)) - __variableSpace->putVariable(pair0, new Variable(nullptr, nullptr, node->id(), 0)); - - if (!__variableSpace->hasVariable(pair1)) - __variableSpace->putVariable(pair1, new Variable(nullptr, nullptr, node->id(), 1)); - - if (!boolean->e(0)) { - // false - nd4j_debug("Node_%i: FALSE branch active\n", node->id()); - __flowPath->markBranch(node->id(), 0); - __variableSpace->getVariable(pair0)->setNDArray(input); - __variableSpace->getVariable(pair0)->markRemovable(false); - } else { - //true - nd4j_debug("Node_%i: TRUE branch active\n", node->id()); - __flowPath->markBranch(node->id(), 1); - __variableSpace->getVariable(pair1)->setNDArray(input); - __variableSpace->getVariable(pair1)->markRemovable(false); - } - } - - return sd::Status::OK(); - */ - }; - } -} +namespace graph { +Nd4jStatus LogicSwitch::processNode(Graph* graph, Node* node) { + throw std::runtime_error("LogicSwitch::processNode - not implemented yet"); + /* + auto __variableSpace = graph->variableSpace(); + auto __flowPath = __variableSpace->flowPath(); + + Context ctx(node->getContextPrototype(), __variableSpace); + + // this can be either our format, or compatible format. + if (graph->hasScope(node->input()->at(0).first)) { + nd4j_debug("Node_%i: Scoped mode.\n", node->id()); + // first input is Scope, so it's ours + int scopeConditionIndex = node->input()->at(0).first; + auto input = ctx.variable(1); + + auto scopeCondition = graph->scopeById(scopeConditionIndex); + int lastNode = 0; + for (auto v: *scopeCondition->nodes()) { + GraphExecutioner::executeFlatNode(graph, v, __variableSpace); + lastNode = v->id(); + } + + // now we should take result of the Scope run, and evaluate it + auto result = __variableSpace->getVariable(lastNode)->getNDArray(); + //result->printBuffer("Result of the last node"); + + + std::pair pair0(node->id(), 0); + std::pair pair1(node->id(), 1); + + if (!__variableSpace->hasVariable(pair0)) + __variableSpace->putVariable(pair0, new Variable(nullptr, nullptr, + node->id(), 0)); + + if (!__variableSpace->hasVariable(pair1)) + __variableSpace->putVariable(pair1, new Variable(nullptr, nullptr, + node->id(), 1)); + + if (!result->e(0)) { + __flowPath->markBranch(node->id(), 0); + __variableSpace->getVariable(pair0)->setNDArray(input->getNDArray()); + __variableSpace->getVariable(pair0)->markRemovable(false); + } else { + __flowPath->markBranch(node->id(), 1); + __variableSpace->getVariable(pair1)->setNDArray(input->getNDArray()); + __variableSpace->getVariable(pair1)->markRemovable(false); + } + } else { + // first input is NOT a Scope, so it's compatible format + nd4j_debug("Node_%i: Compatible mode.\n", node->id()); + + auto input = ctx.variable(0)->getNDArray(); + auto boolean = ctx.variable(1)->getNDArray(); + + //input->printIndexedBuffer("0"); + //boolean->printIndexedBuffer("1"); + + std::pair pair0(node->id(), 0); + std::pair pair1(node->id(), 1); + + if (!__variableSpace->hasVariable(pair0)) + __variableSpace->putVariable(pair0, new Variable(nullptr, nullptr, + node->id(), 0)); + + if (!__variableSpace->hasVariable(pair1)) + __variableSpace->putVariable(pair1, new Variable(nullptr, nullptr, + node->id(), 1)); + + if (!boolean->e(0)) { + // false + nd4j_debug("Node_%i: FALSE branch active\n", node->id()); + __flowPath->markBranch(node->id(), 0); + __variableSpace->getVariable(pair0)->setNDArray(input); + __variableSpace->getVariable(pair0)->markRemovable(false); + } else { + //true + nd4j_debug("Node_%i: TRUE branch active\n", node->id()); + __flowPath->markBranch(node->id(), 1); + __variableSpace->getVariable(pair1)->setNDArray(input); + __variableSpace->getVariable(pair1)->markRemovable(false); + } + } + + return sd::Status::OK(); + */ +}; +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/logic/impl/LogicWhile.cpp b/libnd4j/include/graph/logic/impl/LogicWhile.cpp index 073c8c09dac7..c53d4dd86a03 100644 --- a/libnd4j/include/graph/logic/impl/LogicWhile.cpp +++ b/libnd4j/include/graph/logic/impl/LogicWhile.cpp @@ -18,129 +18,129 @@ // Created by raver119 on 20.10.2017. // -#include -#include -#include #include - +#include +#include +#include namespace sd { - namespace graph { - Nd4jStatus LogicWhile::processNode(Graph *graph, Node *node) { - throw std::runtime_error("LogicWhile::processNode - not implemented yet"); - /* - auto __variableSpace = graph->variableSpace(); - - nd4j_debug("Starting on WHILE loop: [%i]\n", node->id()); - - // total number of inputs. 2 last inputs are scopes - int inputs = node->input()->size(); - - if (inputs < 3) { - nd4j_printf("While [%i]: loop should have at least 1 external variable announced\n", node->id()); - return ND4J_STATUS_BAD_INPUT; - } - - for (int e = 0; e < inputs - 2; e++) { - std::pair pair(node->id(), e); - if (!__variableSpace->hasVariable(pair)) { - __variableSpace->putVariable(pair, new Variable(nullptr, nullptr, node->id(), e)); - } - - auto va = node->input()->at(e); - - auto inputVar = __variableSpace->getVariable(va); - - auto innerVar = __variableSpace->getVariable(pair); - if (innerVar->hasNDArray()) { - // TODO: ??? - } else { - // FIXME: in some cases it's possible to have no NDArray - if (inputVar->hasNDArray()) - innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup())); - } - } - - int scopeConditionIndex = node->input()->at(inputs - 2).first; - int scopeBodyIndex = node->input()->at(inputs - 1).first; - - nd4j_debug("While [%i]: got [%i] inputs\n", node->id(), node->input()->size()); - - // we're running condition nodes now - auto scope = graph->scopeById(scopeConditionIndex); - int breaker = 0; - while (true && breaker < 10000000) { - int lastNode = 0; - // we're running condition scope first - nd4j_debug("While [%i]: got [%i] ops in condition scope [%i]\n", node->id(), scope->nodes()->size(), scopeConditionIndex); - - for (Node* v: *scope->nodes()) { - //v->getBlock()->updateVariables(); - if (v->opType() == OpType_LOGIC) { - nd4j_debug("Falling back to logic\n",""); - LogicExecutor::processNode(graph, v); - } else { - nd4j_debug("Op [<%s>]\n", v->getName().c_str()); - Nd4jStatus status = GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - if (status != ND4J_STATUS_OK) - return status; - } - - lastNode = v->id(); - } - - if (!__variableSpace->hasVariable(lastNode)) { - nd4j_printf("While [%i]: got no results out of conditional loop\n", node->id()); - return ND4J_STATUS_KERNEL_FAILURE; - } - - // now we should take result of the Scope run, and evaluate it - auto result = __variableSpace->getVariable(lastNode)->getNDArray(); - - if (Environment::getInstance()->isDebugAndVerbose()) - result->printBuffer("Result of the last node:"); - - // if result evaluates to 0.0 - condition returned FALSE - if (result->e(0) == 0) - break; - else { - auto scopeBody = graph->scopeById(scopeBodyIndex); - int lastNode = 0; - int e = 0; - nd4j_debug("While [%i] got [%i] ops in body scope [%i]\n", node->id(), scopeBody->nodes()->size(), scopeBodyIndex); - for (; e < scopeBody->nodes()->size() - 1; e++) { - Node* v = scopeBody->nodes()->at(e); - - if (v->opType() == OpType_LOGIC) { - nd4j_debug("Falling back to logic\n",""); - LogicExecutor::processNode(graph, v); - } else { - nd4j_debug("Op [<%s>]\n", v->getName().c_str()); - //v->getBlock()->updateVariables(); - Nd4jStatus status = GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - if (status != ND4J_STATUS_OK) - return status; - } - - lastNode = v->id(); - } - - // now execute return statement - Node* ret = scopeBody->nodes()->at(e); - LogicReturn::processNode(graph, ret); - } - - breaker++; - } - - // if we've hit breaker limit - we should notify about that - if (breaker >= 10000000) { - nd4j_printf("While condition seems to be never ending, aborting...\n", breaker); - return ND4J_STATUS_KERNEL_FAILURE; - } - - return sd::Status::OK(); - */ - } - } +namespace graph { +Nd4jStatus LogicWhile::processNode(Graph *graph, Node *node) { + throw std::runtime_error("LogicWhile::processNode - not implemented yet"); + /* + auto __variableSpace = graph->variableSpace(); + + nd4j_debug("Starting on WHILE loop: [%i]\n", node->id()); + + // total number of inputs. 2 last inputs are scopes + int inputs = node->input()->size(); + + if (inputs < 3) { + nd4j_printf("While [%i]: loop should have at least 1 external variable + announced\n", node->id()); return ND4J_STATUS_BAD_INPUT; + } + + for (int e = 0; e < inputs - 2; e++) { + std::pair pair(node->id(), e); + if (!__variableSpace->hasVariable(pair)) { + __variableSpace->putVariable(pair, new Variable(nullptr, nullptr, + node->id(), e)); + } + + auto va = node->input()->at(e); + + auto inputVar = __variableSpace->getVariable(va); + + auto innerVar = __variableSpace->getVariable(pair); + if (innerVar->hasNDArray()) { + // TODO: ??? + } else { + // FIXME: in some cases it's possible to have no NDArray + if (inputVar->hasNDArray()) + innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup())); + } + } + + int scopeConditionIndex = node->input()->at(inputs - 2).first; + int scopeBodyIndex = node->input()->at(inputs - 1).first; + + nd4j_debug("While [%i]: got [%i] inputs\n", node->id(), + node->input()->size()); + + // we're running condition nodes now + auto scope = graph->scopeById(scopeConditionIndex); + int breaker = 0; + while (true && breaker < 10000000) { + int lastNode = 0; + // we're running condition scope first + nd4j_debug("While [%i]: got [%i] ops in condition scope [%i]\n", + node->id(), scope->nodes()->size(), scopeConditionIndex); + + for (Node* v: *scope->nodes()) { + //v->getBlock()->updateVariables(); + if (v->opType() == OpType_LOGIC) { + nd4j_debug("Falling back to logic\n",""); + LogicExecutor::processNode(graph, v); + } else { + nd4j_debug("Op [<%s>]\n", v->getName().c_str()); + Nd4jStatus status = GraphExecutioner::executeFlatNode(graph, v, + __variableSpace); if (status != ND4J_STATUS_OK) return status; + } + + lastNode = v->id(); + } + + if (!__variableSpace->hasVariable(lastNode)) { + nd4j_printf("While [%i]: got no results out of conditional loop\n", + node->id()); return ND4J_STATUS_KERNEL_FAILURE; + } + + // now we should take result of the Scope run, and evaluate it + auto result = __variableSpace->getVariable(lastNode)->getNDArray(); + + if (Environment::getInstance()->isDebugAndVerbose()) + result->printBuffer("Result of the last node:"); + + // if result evaluates to 0.0 - condition returned FALSE + if (result->e(0) == 0) + break; + else { + auto scopeBody = graph->scopeById(scopeBodyIndex); + int lastNode = 0; + int e = 0; + nd4j_debug("While [%i] got [%i] ops in body scope [%i]\n", node->id(), + scopeBody->nodes()->size(), scopeBodyIndex); for (; e < + scopeBody->nodes()->size() - 1; e++) { Node* v = scopeBody->nodes()->at(e); + + if (v->opType() == OpType_LOGIC) { + nd4j_debug("Falling back to logic\n",""); + LogicExecutor::processNode(graph, v); + } else { + nd4j_debug("Op [<%s>]\n", v->getName().c_str()); + //v->getBlock()->updateVariables(); + Nd4jStatus status = GraphExecutioner::executeFlatNode(graph, + v, __variableSpace); if (status != ND4J_STATUS_OK) return status; + } + + lastNode = v->id(); + } + + // now execute return statement + Node* ret = scopeBody->nodes()->at(e); + LogicReturn::processNode(graph, ret); + } + + breaker++; + } + + // if we've hit breaker limit - we should notify about that + if (breaker >= 10000000) { + nd4j_printf("While condition seems to be never ending, aborting...\n", + breaker); return ND4J_STATUS_KERNEL_FAILURE; + } + + return sd::Status::OK(); + */ } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/optimization/GraphOptimizer.h b/libnd4j/include/graph/optimization/GraphOptimizer.h index df3a6f559baf..5432b6f39d47 100644 --- a/libnd4j/include/graph/optimization/GraphOptimizer.h +++ b/libnd4j/include/graph/optimization/GraphOptimizer.h @@ -24,18 +24,17 @@ #include namespace sd { - namespace graph { - class SD_EXPORT GraphOptimizer { - public: - /** - * This method optimizes given Graph and returns independent cloned Graph - * @param graph - * @return - */ - static Graph* optimize(const Graph &graph); - }; - } -} +namespace graph { +class SD_EXPORT GraphOptimizer { + public: + /** + * This method optimizes given Graph and returns independent cloned Graph + * @param graph + * @return + */ + static Graph* optimize(const Graph& graph); +}; +} // namespace graph +} // namespace sd - -#endif //SD_GRAPHOPTIMIZER_H +#endif // SD_GRAPHOPTIMIZER_H diff --git a/libnd4j/include/graph/optimization/NodeOptimizer.h b/libnd4j/include/graph/optimization/NodeOptimizer.h index 0befb67e5dbb..8fb39c012898 100644 --- a/libnd4j/include/graph/optimization/NodeOptimizer.h +++ b/libnd4j/include/graph/optimization/NodeOptimizer.h @@ -21,38 +21,39 @@ #ifndef SD_NODEOPTIMIZER_H #define SD_NODEOPTIMIZER_H -#include #include +#include + #include namespace sd { - namespace graph { - /** - * This abstract class defines basic methods needed for Inputs/Outputs optimizations. I.e. weight format changes or data types changes for a specific backend - */ - class SD_EXPORT NodeOptimizer { - protected: - std::string _target = {}; - - public: - NodeOptimizer() = default; - virtual ~NodeOptimizer() = default; - - /** - * This method applu - * @param node - */ - virtual void optimize(Node &node) = 0; - - /** - * This method returns target Op name for this optimizer - * @return - */ - const std::string& targetOp() const; - }; - } -} - - - -#endif //DEV_TESTS_NODEOPTIMIZER_H +namespace graph { +/** + * This abstract class defines basic methods needed for Inputs/Outputs + * optimizations. I.e. weight format changes or data types changes for a + * specific backend + */ +class SD_EXPORT NodeOptimizer { + protected: + std::string _target = {}; + + public: + NodeOptimizer() = default; + virtual ~NodeOptimizer() = default; + + /** + * This method applu + * @param node + */ + virtual void optimize(Node& node) = 0; + + /** + * This method returns target Op name for this optimizer + * @return + */ + const std::string& targetOp() const; +}; +} // namespace graph +} // namespace sd + +#endif // DEV_TESTS_NODEOPTIMIZER_H diff --git a/libnd4j/include/graph/optimization/impl/GraphOptimizer.cpp b/libnd4j/include/graph/optimization/impl/GraphOptimizer.cpp index f838a22e790e..1170b0b6191c 100644 --- a/libnd4j/include/graph/optimization/impl/GraphOptimizer.cpp +++ b/libnd4j/include/graph/optimization/impl/GraphOptimizer.cpp @@ -21,13 +21,13 @@ #include namespace sd { - namespace graph { - Graph *GraphOptimizer::optimize(const Graph &graph) { - auto clone = graph.clone(); +namespace graph { +Graph *GraphOptimizer::optimize(const Graph &graph) { + auto clone = graph.clone(); - //TODO: implement this method + // TODO: implement this method - return clone; - } - } -} \ No newline at end of file + return clone; +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/optimization/impl/NodeOptimizer.cpp b/libnd4j/include/graph/optimization/impl/NodeOptimizer.cpp index 7b9e681c8c34..b8ed74ef84ae 100644 --- a/libnd4j/include/graph/optimization/impl/NodeOptimizer.cpp +++ b/libnd4j/include/graph/optimization/impl/NodeOptimizer.cpp @@ -21,9 +21,7 @@ #include namespace sd { - namespace graph { - const std::string &NodeOptimizer::targetOp() const { - return _target; - } - } -} \ No newline at end of file +namespace graph { +const std::string &NodeOptimizer::targetOp() const { return _target; } +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/profiling/GraphProfile.h b/libnd4j/include/graph/profiling/GraphProfile.h index 8e76b0436683..5c89ce2ff13e 100644 --- a/libnd4j/include/graph/profiling/GraphProfile.h +++ b/libnd4j/include/graph/profiling/GraphProfile.h @@ -21,100 +21,105 @@ #ifndef ND4J_GRAPH_PROFILE_H #define ND4J_GRAPH_PROFILE_H -#include "NodeProfile.h" -#include #include -#include -#include -#include +#include + #include +#include +#include +#include + +#include "NodeProfile.h" namespace sd { - namespace graph { - class SD_EXPORT GraphProfile { - private: - // this variable - Nd4jLong _merges = 1L; - - /** - * This is global memory values - */ - Nd4jLong _memoryTotal = 0L; - Nd4jLong _memoryActivations = 0L; - Nd4jLong _memoryTemporary = 0L; - Nd4jLong _memoryObjects = 0L; - - // time spent for graph construction - Nd4jLong _buildTime = 0L; - - // time spent for graph execution - Nd4jLong _executionTime = 0L; - - // collection of pointers to profile results - std::vector _profiles; - std::map _profilesById; - - // collection of various timing reports - std::map _timings; - std::chrono::time_point _last; - - std::map> _timers; - - void updateLast(); - public: - GraphProfile(); - ~GraphProfile(); - - /** - * These methods just adding amount of bytes to various counters - */ - void addToTotal(Nd4jLong bytes); - void addToActivations(Nd4jLong bytes); - void addToTemporary(Nd4jLong bytes); - void addToObjects(Nd4jLong bytes); - - /** - * This method allows to set graph construction (i.e. deserialization) time in nanoseconds - */ - void setBuildTime(Nd4jLong nanos); - - /** - * This method sets graph execution time in nanoseconds. - */ - void setExecutionTime(Nd4jLong nanos); - - void startEvent(const char *name); - void recordEvent(const char *name); - void deleteEvent(const char *name); - - /** - * This method saves time as delta from last saved time - */ - void spotEvent(const char *name); - - /** - * This method returns pointer to NodeProfile by ID - * PLEASE NOTE: this method will create new NodeProfile if there's none - */ - NodeProfile* nodeById(int id, const char *name = nullptr); - bool nodeExists(int id); - - /** - * This method merges values from other profile report - * @param other - */ - void merge(GraphProfile *other); - void assign(GraphProfile *other); - - /** - * These methods are just utility methods for time - */ - static Nd4jLong currentTime(); - static Nd4jLong relativeTime(Nd4jLong time); - - void printOut(); - }; - } -} +namespace graph { +class SD_EXPORT GraphProfile { + private: + // this variable + Nd4jLong _merges = 1L; + + /** + * This is global memory values + */ + Nd4jLong _memoryTotal = 0L; + Nd4jLong _memoryActivations = 0L; + Nd4jLong _memoryTemporary = 0L; + Nd4jLong _memoryObjects = 0L; + + // time spent for graph construction + Nd4jLong _buildTime = 0L; + + // time spent for graph execution + Nd4jLong _executionTime = 0L; + + // collection of pointers to profile results + std::vector _profiles; + std::map _profilesById; + + // collection of various timing reports + std::map _timings; + std::chrono::time_point _last; + + std::map> + _timers; + + void updateLast(); + + public: + GraphProfile(); + ~GraphProfile(); + + /** + * These methods just adding amount of bytes to various counters + */ + void addToTotal(Nd4jLong bytes); + void addToActivations(Nd4jLong bytes); + void addToTemporary(Nd4jLong bytes); + void addToObjects(Nd4jLong bytes); + + /** + * This method allows to set graph construction (i.e. deserialization) time in + * nanoseconds + */ + void setBuildTime(Nd4jLong nanos); + + /** + * This method sets graph execution time in nanoseconds. + */ + void setExecutionTime(Nd4jLong nanos); + + void startEvent(const char *name); + void recordEvent(const char *name); + void deleteEvent(const char *name); + + /** + * This method saves time as delta from last saved time + */ + void spotEvent(const char *name); + + /** + * This method returns pointer to NodeProfile by ID + * PLEASE NOTE: this method will create new NodeProfile if there's none + */ + NodeProfile *nodeById(int id, const char *name = nullptr); + bool nodeExists(int id); + + /** + * This method merges values from other profile report + * @param other + */ + void merge(GraphProfile *other); + void assign(GraphProfile *other); + + /** + * These methods are just utility methods for time + */ + static Nd4jLong currentTime(); + static Nd4jLong relativeTime(Nd4jLong time); + + void printOut(); +}; +} // namespace graph +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/graph/profiling/GraphProfilingHelper.h b/libnd4j/include/graph/profiling/GraphProfilingHelper.h index d32d99374660..e87f6a509663 100644 --- a/libnd4j/include/graph/profiling/GraphProfilingHelper.h +++ b/libnd4j/include/graph/profiling/GraphProfilingHelper.h @@ -21,17 +21,17 @@ #ifndef LIBND4J_GRAPHPROFILINGHELPER_H #define LIBND4J_GRAPHPROFILINGHELPER_H - #include + #include "GraphProfile.h" namespace sd { - namespace graph { - class GraphProfilingHelper { - public: - static GraphProfile* profile(Graph *graph, int iterations); - }; - } -} +namespace graph { +class GraphProfilingHelper { + public: + static GraphProfile* profile(Graph* graph, int iterations); +}; +} // namespace graph +} // namespace sd -#endif //LIBND4J_GRAPHPROFILINGHELPER_H +#endif // LIBND4J_GRAPHPROFILINGHELPER_H diff --git a/libnd4j/include/graph/profiling/NodeProfile.h b/libnd4j/include/graph/profiling/NodeProfile.h index 2b6d302e95e6..6c1f639ce726 100644 --- a/libnd4j/include/graph/profiling/NodeProfile.h +++ b/libnd4j/include/graph/profiling/NodeProfile.h @@ -21,91 +21,93 @@ #ifndef LIBND4J_NODE_PROFILE_H #define LIBND4J_NODE_PROFILE_H -#include #include +#include + #include #include namespace sd { - namespace graph { - class SD_EXPORT NodeProfile { - private: - int _id; - std::string _name; +namespace graph { +class SD_EXPORT NodeProfile { + private: + int _id; + std::string _name; + + Nd4jLong _merges = 1L; + + // time spent during deserialization + Nd4jLong _buildTime = 0L; - Nd4jLong _merges = 1L; + // time spent before op execution + Nd4jLong _preparationTime = 0L; - // time spent during deserialization - Nd4jLong _buildTime = 0L; - - // time spent before op execution - Nd4jLong _preparationTime = 0L; + // time spent for op execution + Nd4jLong _executionTime = 0L; - // time spent for op execution - Nd4jLong _executionTime = 0L; + // total time spent during node execution + Nd4jLong _totalTime = 0L; - // total time spent during node execution - Nd4jLong _totalTime = 0L; + // time spent for output shape creation + Nd4jLong _shapeTime = 0L; - // time spent for output shape creation - Nd4jLong _shapeTime = 0L; + // time spent for output arrays creation + Nd4jLong _arrayTime = 0L; - // time spent for output arrays creation - Nd4jLong _arrayTime = 0L; + Nd4jLong _inputTime = 0L; - Nd4jLong _inputTime = 0L; + // amount of memory used for outputs + Nd4jLong _memoryActivations = 0L; - // amount of memory used for outputs - Nd4jLong _memoryActivations = 0L; + // amount of memory used internally for temporary arrays + Nd4jLong _memoryTemporary = 0L; - // amount of memory used internally for temporary arrays - Nd4jLong _memoryTemporary = 0L; + // amount of memory used internally for objects + Nd4jLong _memoryObjects = 0L; - // amount of memory used internally for objects - Nd4jLong _memoryObjects = 0L; + // total amount of memory used during execution + Nd4jLong _memoryTotal = 0L; - // total amount of memory used during execution - Nd4jLong _memoryTotal = 0L; + std::vector _inputShapes; + std::vector _outputShapes; - std::vector _inputShapes; - std::vector _outputShapes; - public: - NodeProfile() = default; - ~NodeProfile() = default; + public: + NodeProfile() = default; + ~NodeProfile() = default; - explicit NodeProfile(int id, const char *name); + explicit NodeProfile(int id, const char* name); - void setBuildTime(Nd4jLong time); - void setPreparationTime(Nd4jLong time); - void setExecutionTime(Nd4jLong time); - void setTotalTime(Nd4jLong time); - void setShapeFunctionTime(Nd4jLong time); - void setArrayTime(Nd4jLong time); - void setInputTime(Nd4jLong time); + void setBuildTime(Nd4jLong time); + void setPreparationTime(Nd4jLong time); + void setExecutionTime(Nd4jLong time); + void setTotalTime(Nd4jLong time); + void setShapeFunctionTime(Nd4jLong time); + void setArrayTime(Nd4jLong time); + void setInputTime(Nd4jLong time); - void setActivationsSize(Nd4jLong bytes); - void setTemporarySize(Nd4jLong bytes); - void setObjectsSize(Nd4jLong bytes); - void setTotalSize(Nd4jLong bytes); + void setActivationsSize(Nd4jLong bytes); + void setTemporarySize(Nd4jLong bytes); + void setObjectsSize(Nd4jLong bytes); + void setTotalSize(Nd4jLong bytes); - void addInputShape(Nd4jLong const* shapeInfo); - void addOutputShape(Nd4jLong const* shapeInfo); + void addInputShape(Nd4jLong const* shapeInfo); + void addOutputShape(Nd4jLong const* shapeInfo); - Nd4jLong getActivationsSize() const; - Nd4jLong getTemporarySize() const; - Nd4jLong getObjectsSize() const; - Nd4jLong getTotalSize() const; + Nd4jLong getActivationsSize() const; + Nd4jLong getTemporarySize() const; + Nd4jLong getObjectsSize() const; + Nd4jLong getTotalSize() const; - Nd4jLong getExecutionTime() const; + Nd4jLong getExecutionTime() const; - std::string& name(); + std::string& name(); - void merge(NodeProfile *other); - void assign(NodeProfile *other); + void merge(NodeProfile* other); + void assign(NodeProfile* other); - void printOut(); - }; - } -} + void printOut(); +}; +} // namespace graph +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/graph/profiling/impl/GraphProfile.cpp b/libnd4j/include/graph/profiling/impl/GraphProfile.cpp index 14ce54a0f645..bb37c0764c78 100644 --- a/libnd4j/include/graph/profiling/impl/GraphProfile.cpp +++ b/libnd4j/include/graph/profiling/impl/GraphProfile.cpp @@ -20,198 +20,183 @@ #include #include -#include #include + #include +#include namespace sd { - namespace graph { - GraphProfile::GraphProfile() { - updateLast(); - } - - GraphProfile::~GraphProfile() { - // releasing NodeProfile pointers - for (auto v: _profiles) - delete v; - - _timings.clear(); - } - - void GraphProfile::addToTotal(Nd4jLong bytes) { - _memoryTotal += bytes; - } - - void GraphProfile::addToActivations(Nd4jLong bytes) { - _memoryActivations += bytes; - } - - void GraphProfile::addToTemporary(Nd4jLong bytes) { - _memoryTemporary += bytes; - } - - void GraphProfile::addToObjects(Nd4jLong bytes) { - _memoryObjects += bytes; - } - - void GraphProfile::setBuildTime(Nd4jLong nanos) { - _buildTime = nanos; - } - - void GraphProfile::setExecutionTime(Nd4jLong nanos) { - _executionTime = nanos; - } - - - Nd4jLong GraphProfile::currentTime() { - auto t = std::chrono::system_clock::now(); - auto v = std::chrono::time_point_cast (t); - auto epoch = v.time_since_epoch(); - return (Nd4jLong) std::chrono::duration_cast(epoch).count(); - } - - Nd4jLong GraphProfile::relativeTime(Nd4jLong time) { - auto t1 = currentTime(); - return t1 - time; - } - - void GraphProfile::updateLast() { - _last = std::chrono::system_clock::now(); - } - - void GraphProfile::startEvent(const char *name) { - std::string k = name; - _timers[k] = std::chrono::system_clock::now(); - } - - void GraphProfile::recordEvent(const char *name) { - std::string k = name; - if (_timers.count(k) == 0) { - nd4j_printf("Can't find timer key: [%s]", name); - throw std::runtime_error("Missing timer key"); - } - auto t0 = _timers[k]; - auto t1 = std::chrono::system_clock::now(); - auto v = (Nd4jLong) std::chrono::duration_cast(t1 - t0).count(); - - _timings[k] = v; - _timers.erase(k); - } - - void GraphProfile::deleteEvent(const char *name) { - std::string k = name; - _timers.erase(k); - } - - void GraphProfile::spotEvent(const char *name) { - auto t = std::chrono::system_clock::now(); - auto d = (Nd4jLong) std::chrono::duration_cast(t - _last).count(); - std::string k = name; - _timings[k] = d; - updateLast(); - } - - NodeProfile* GraphProfile::nodeById(int id, const char *name) { - if (_profilesById.count(id) == 0) { - auto node = new NodeProfile(id, name); - _profiles.emplace_back(node); - _profilesById[id] = node; - return node; - } - - return _profilesById[id]; - } - - void GraphProfile::merge(GraphProfile *other) { - _merges += other->_merges; - _memoryActivations += other->_memoryActivations; - _memoryTemporary += other->_memoryTemporary; - _memoryTotal += other->_memoryTotal; - _memoryObjects += other->_memoryObjects; - - _executionTime += other->_executionTime; - _buildTime += other->_buildTime; - - - for (auto v:_profilesById) { - if (!other->nodeExists(v.first)) - continue; - - v.second->merge(other->nodeById(v.first)); - } - } - - void GraphProfile::assign(GraphProfile *other) { - _merges = other->_merges; - _memoryActivations = other->_memoryActivations; - _memoryTemporary = other->_memoryTemporary; - _memoryTotal = other->_memoryTotal; - _memoryObjects = other->_memoryObjects; - - _executionTime = other->_executionTime; - _buildTime = other->_buildTime; - - - for (auto v: other->_profilesById) { - nodeById(v.first, v.second->name().c_str())->assign(v.second); - } - } - - bool GraphProfile::nodeExists(int id) { - return _profilesById.count(id) > 0; - } - - void GraphProfile::printOut() { - nd4j_printf("Graph profile: %i executions\n", _merges); - nd4j_printf("\nMemory:\n", ""); - - Nd4jLong tmp = 0L; - Nd4jLong obj = 0L; - Nd4jLong act = 0L; - Nd4jLong ttl = 0L; - for (auto v: _profiles) { - tmp += v->getTemporarySize(); - obj += v->getObjectsSize(); - act += v->getActivationsSize(); - ttl += v->getTotalSize(); - } - - nd4j_printf("ACT: %lld; TMP: %lld; OBJ: %lld; TTL: %lld;\n", act / _merges, tmp / _merges, obj / _merges, ttl / _merges); - - nd4j_printf("\nTime:\n", ""); - nd4j_printf("Construction time: %lld ns;\n", _buildTime / _merges); - nd4j_printf("Execution time: %lld ns;\n", _executionTime / _merges); - - nd4j_printf("\nPer-node reports:\n", ""); - if (_profiles.empty()) - nd4j_printf("No nodes in graph\n",""); - - // printint out stuff - std::vector sorted; - for (auto v: _profiles) { - v->printOut(); - sorted.emplace_back(v); - } - - if (_profiles.size() > 1) { - // building hot spots - std::sort(sorted.begin(), sorted.end(), [](const NodeProfile *a, const NodeProfile *b) -> bool { - return a->getExecutionTime() > b->getExecutionTime(); - }); - - nd4j_printf("\nTop 50 reports by EXEC:\n", ""); - auto limit = sd::math::nd4j_min(50, sorted.size()); - for (int e = 0; e < limit; e++) { - sorted[e]->printOut(); - } - } - - nd4j_printf("\nSpecial timers:\n", ""); - if (_timings.empty()) - nd4j_printf("No special timers were set\n",""); - - for (auto v: _timings) - nd4j_printf("%s: %lld ns;\n", v.first.c_str(), v.second); - } +namespace graph { +GraphProfile::GraphProfile() { updateLast(); } + +GraphProfile::~GraphProfile() { + // releasing NodeProfile pointers + for (auto v : _profiles) delete v; + + _timings.clear(); +} + +void GraphProfile::addToTotal(Nd4jLong bytes) { _memoryTotal += bytes; } + +void GraphProfile::addToActivations(Nd4jLong bytes) { + _memoryActivations += bytes; +} + +void GraphProfile::addToTemporary(Nd4jLong bytes) { _memoryTemporary += bytes; } + +void GraphProfile::addToObjects(Nd4jLong bytes) { _memoryObjects += bytes; } + +void GraphProfile::setBuildTime(Nd4jLong nanos) { _buildTime = nanos; } + +void GraphProfile::setExecutionTime(Nd4jLong nanos) { _executionTime = nanos; } + +Nd4jLong GraphProfile::currentTime() { + auto t = std::chrono::system_clock::now(); + auto v = std::chrono::time_point_cast(t); + auto epoch = v.time_since_epoch(); + return (Nd4jLong)std::chrono::duration_cast(epoch) + .count(); +} + +Nd4jLong GraphProfile::relativeTime(Nd4jLong time) { + auto t1 = currentTime(); + return t1 - time; +} + +void GraphProfile::updateLast() { _last = std::chrono::system_clock::now(); } + +void GraphProfile::startEvent(const char *name) { + std::string k = name; + _timers[k] = std::chrono::system_clock::now(); +} + +void GraphProfile::recordEvent(const char *name) { + std::string k = name; + if (_timers.count(k) == 0) { + nd4j_printf("Can't find timer key: [%s]", name); + throw std::runtime_error("Missing timer key"); + } + auto t0 = _timers[k]; + auto t1 = std::chrono::system_clock::now(); + auto v = + (Nd4jLong)std::chrono::duration_cast(t1 - t0) + .count(); + + _timings[k] = v; + _timers.erase(k); +} + +void GraphProfile::deleteEvent(const char *name) { + std::string k = name; + _timers.erase(k); +} + +void GraphProfile::spotEvent(const char *name) { + auto t = std::chrono::system_clock::now(); + auto d = + (Nd4jLong)std::chrono::duration_cast(t - _last) + .count(); + std::string k = name; + _timings[k] = d; + updateLast(); +} + +NodeProfile *GraphProfile::nodeById(int id, const char *name) { + if (_profilesById.count(id) == 0) { + auto node = new NodeProfile(id, name); + _profiles.emplace_back(node); + _profilesById[id] = node; + return node; + } + + return _profilesById[id]; +} + +void GraphProfile::merge(GraphProfile *other) { + _merges += other->_merges; + _memoryActivations += other->_memoryActivations; + _memoryTemporary += other->_memoryTemporary; + _memoryTotal += other->_memoryTotal; + _memoryObjects += other->_memoryObjects; + + _executionTime += other->_executionTime; + _buildTime += other->_buildTime; + + for (auto v : _profilesById) { + if (!other->nodeExists(v.first)) continue; + + v.second->merge(other->nodeById(v.first)); + } +} + +void GraphProfile::assign(GraphProfile *other) { + _merges = other->_merges; + _memoryActivations = other->_memoryActivations; + _memoryTemporary = other->_memoryTemporary; + _memoryTotal = other->_memoryTotal; + _memoryObjects = other->_memoryObjects; + + _executionTime = other->_executionTime; + _buildTime = other->_buildTime; + + for (auto v : other->_profilesById) { + nodeById(v.first, v.second->name().c_str())->assign(v.second); + } +} + +bool GraphProfile::nodeExists(int id) { return _profilesById.count(id) > 0; } + +void GraphProfile::printOut() { + nd4j_printf("Graph profile: %i executions\n", _merges); + nd4j_printf("\nMemory:\n", ""); + + Nd4jLong tmp = 0L; + Nd4jLong obj = 0L; + Nd4jLong act = 0L; + Nd4jLong ttl = 0L; + for (auto v : _profiles) { + tmp += v->getTemporarySize(); + obj += v->getObjectsSize(); + act += v->getActivationsSize(); + ttl += v->getTotalSize(); + } + + nd4j_printf("ACT: %lld; TMP: %lld; OBJ: %lld; TTL: %lld;\n", act / _merges, + tmp / _merges, obj / _merges, ttl / _merges); + + nd4j_printf("\nTime:\n", ""); + nd4j_printf("Construction time: %lld ns;\n", _buildTime / _merges); + nd4j_printf("Execution time: %lld ns;\n", _executionTime / _merges); + + nd4j_printf("\nPer-node reports:\n", ""); + if (_profiles.empty()) nd4j_printf("No nodes in graph\n", ""); + + // printint out stuff + std::vector sorted; + for (auto v : _profiles) { + v->printOut(); + sorted.emplace_back(v); + } + + if (_profiles.size() > 1) { + // building hot spots + std::sort(sorted.begin(), sorted.end(), + [](const NodeProfile *a, const NodeProfile *b) -> bool { + return a->getExecutionTime() > b->getExecutionTime(); + }); + + nd4j_printf("\nTop 50 reports by EXEC:\n", ""); + auto limit = sd::math::nd4j_min(50, sorted.size()); + for (int e = 0; e < limit; e++) { + sorted[e]->printOut(); } -} \ No newline at end of file + } + + nd4j_printf("\nSpecial timers:\n", ""); + if (_timings.empty()) nd4j_printf("No special timers were set\n", ""); + + for (auto v : _timings) + nd4j_printf("%s: %lld ns;\n", v.first.c_str(), v.second); +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp b/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp index 001a8517a344..bf72392e21a4 100644 --- a/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp +++ b/libnd4j/include/graph/profiling/impl/GraphProfilingHelper.cpp @@ -20,53 +20,52 @@ #include - namespace sd { - namespace graph { - GraphProfile *GraphProfilingHelper::profile(Graph *graph, int iterations) { - if (1 > 0) - throw std::runtime_error("GraphProfilingHelper::profile - Not implemented yet"); - - // saving original workspace - //auto varSpace = graph->variableSpace(); +namespace graph { +GraphProfile *GraphProfilingHelper::profile(Graph *graph, int iterations) { + if (1 > 0) + throw std::runtime_error( + "GraphProfilingHelper::profile - Not implemented yet"); - // printing out graph structure - // graph->printOut(); + // saving original workspace + // auto varSpace = graph->variableSpace(); - // warm up - for (int e = 0; e < iterations; e++) { - FlowPath fp; + // printing out graph structure + // graph->printOut(); - //auto _vs = varSpace->clone(); - //_vs->workspace()->expandTo(100000); - //_vs->setFlowPath(&fp); - //GraphExecutioner::execute(graph, _vs); + // warm up + for (int e = 0; e < iterations; e++) { + FlowPath fp; - //delete _vs; - } + // auto _vs = varSpace->clone(); + //_vs->workspace()->expandTo(100000); + //_vs->setFlowPath(&fp); + // GraphExecutioner::execute(graph, _vs); + // delete _vs; + } - auto profile = new GraphProfile(); - for (int e = 0; e < iterations; e++) { - FlowPath fp; -/* - // we're always starting from "fresh" varspace here - auto _vs = varSpace->clone(); - //_vs->workspace()->expandTo(100000); - _vs->setFlowPath(&fp); - //GraphExecutioner::execute(graph, _vs); + auto profile = new GraphProfile(); + for (int e = 0; e < iterations; e++) { + FlowPath fp; + /* + // we're always starting from "fresh" varspace here + auto _vs = varSpace->clone(); + //_vs->workspace()->expandTo(100000); + _vs->setFlowPath(&fp); + //GraphExecutioner::execute(graph, _vs); - auto p = fp.profile(); - if (e == 0) - profile->assign(p); - else - profile->merge(p); + auto p = fp.profile(); + if (e == 0) + profile->assign(p); + else + profile->merge(p); - delete _vs; - */ - } + delete _vs; + */ + } - return profile; - } - } + return profile; } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/profiling/impl/NodeProfile.cpp b/libnd4j/include/graph/profiling/impl/NodeProfile.cpp index 8db4472e6bc9..b78d8535ee93 100644 --- a/libnd4j/include/graph/profiling/impl/NodeProfile.cpp +++ b/libnd4j/include/graph/profiling/impl/NodeProfile.cpp @@ -18,150 +18,119 @@ // @author raver119@gmail.com // -#include #include #include +#include namespace sd { - namespace graph { - NodeProfile::NodeProfile(int id, const char *name) { - _id = id; - - if (name != nullptr) - _name = name; - }; - - void NodeProfile::printOut() { - nd4j_printf("Node: <%i:%s>\n", _id, _name.c_str()); - nd4j_printf(" Memory: ACT: %lld; TMP: %lld; OBJ: %lld; TTL: %lld;\n", _memoryActivations / _merges, _memoryTemporary / _merges, _memoryObjects / _merges, _memoryTotal / _merges); - nd4j_printf(" Time: PREP: %lld ns; EXEC: %lld ns; TTL: %lld ns;\n", _preparationTime / _merges, _executionTime / _merges, _totalTime / _merges); - nd4j_printf(" PREP: INPUT: %lld ns; SHAPE: %lld ns; ARRAY: %lld ns;\n", _inputTime / _merges, _shapeTime / _merges, _arrayTime / _merges); - - std::string inputs; - std::string outputs; - - int cnt = 0; - for (const auto &v: _inputShapes) - inputs += v + " "; - - for (const auto &v: _outputShapes) - outputs += v + " "; - - - nd4j_printf(" Inputs: %s\n", inputs.c_str()); - nd4j_printf(" Outputs: %s\n", outputs.c_str()); - }; - - Nd4jLong NodeProfile::getActivationsSize() const { - return _memoryActivations; - } - - void NodeProfile::setShapeFunctionTime(Nd4jLong time) { - _shapeTime = time; - } - - void NodeProfile::setArrayTime(Nd4jLong time) { - _arrayTime = time; - } - - void NodeProfile::setInputTime(Nd4jLong time) { - _inputTime = time; - } - - Nd4jLong NodeProfile::getTemporarySize() const{ - return _memoryTemporary; - } - - Nd4jLong NodeProfile::getObjectsSize() const{ - return _memoryObjects; - } - - Nd4jLong NodeProfile::getTotalSize() const{ - return _memoryTotal; - } - - void NodeProfile::setBuildTime(Nd4jLong time) { - _buildTime = time; - } - - void NodeProfile::setPreparationTime(Nd4jLong time) { - _preparationTime = time; - } - - void NodeProfile::setExecutionTime(Nd4jLong time) { - _executionTime = time; - } - - void NodeProfile::setTotalTime(Nd4jLong time) { - _totalTime = time; - } - - void NodeProfile::setActivationsSize(Nd4jLong bytes) { - _memoryActivations = bytes; - } - - void NodeProfile::setTemporarySize(Nd4jLong bytes) { - _memoryTemporary = bytes; - } - - void NodeProfile::setObjectsSize(Nd4jLong bytes) { - _memoryObjects = bytes; - } - - void NodeProfile::setTotalSize(Nd4jLong bytes) { - _memoryTotal = bytes; - } - - Nd4jLong NodeProfile::getExecutionTime() const { - return _executionTime; - } - - void NodeProfile::addInputShape(Nd4jLong const* shapeInfo) { - _inputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo)); - } - - void NodeProfile::addOutputShape(Nd4jLong const*shapeInfo) { - _outputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo)); - } - - void NodeProfile::merge(NodeProfile *other) { - _merges += other->_merges; - _memoryObjects += other->_memoryObjects; - _memoryActivations += other->_memoryActivations; - _memoryTemporary += other->_memoryTemporary; - _memoryTotal += other->_memoryTotal; - - _preparationTime += other->_preparationTime; - _executionTime += other->_executionTime; - _totalTime += other->_totalTime; - _shapeTime += other->_shapeTime; - _arrayTime += other->_arrayTime; - _inputTime += other->_inputTime; - - _inputShapes = other->_inputShapes; - _outputShapes = other->_outputShapes; - } - - std::string& NodeProfile::name() { - return _name; - } - - void NodeProfile::assign(NodeProfile *other) { - _merges = other->_merges; - _memoryObjects = other->_memoryObjects; - _memoryActivations = other->_memoryActivations; - _memoryTemporary = other->_memoryTemporary; - _memoryTotal = other->_memoryTotal; - - _preparationTime = other->_preparationTime; - _executionTime = other->_executionTime; - _totalTime = other->_totalTime; - _shapeTime = other->_shapeTime; - _arrayTime = other->_arrayTime; - _inputTime = other->_inputTime; - - _inputShapes = other->_inputShapes; - _outputShapes = other->_outputShapes; - } - } -} \ No newline at end of file +namespace graph { +NodeProfile::NodeProfile(int id, const char *name) { + _id = id; + + if (name != nullptr) _name = name; +}; + +void NodeProfile::printOut() { + nd4j_printf("Node: <%i:%s>\n", _id, _name.c_str()); + nd4j_printf(" Memory: ACT: %lld; TMP: %lld; OBJ: %lld; TTL: %lld;\n", + _memoryActivations / _merges, _memoryTemporary / _merges, + _memoryObjects / _merges, _memoryTotal / _merges); + nd4j_printf(" Time: PREP: %lld ns; EXEC: %lld ns; TTL: %lld ns;\n", + _preparationTime / _merges, _executionTime / _merges, + _totalTime / _merges); + nd4j_printf(" PREP: INPUT: %lld ns; SHAPE: %lld ns; ARRAY: %lld ns;\n", + _inputTime / _merges, _shapeTime / _merges, _arrayTime / _merges); + + std::string inputs; + std::string outputs; + + int cnt = 0; + for (const auto &v : _inputShapes) inputs += v + " "; + + for (const auto &v : _outputShapes) outputs += v + " "; + + nd4j_printf(" Inputs: %s\n", inputs.c_str()); + nd4j_printf(" Outputs: %s\n", outputs.c_str()); +}; + +Nd4jLong NodeProfile::getActivationsSize() const { return _memoryActivations; } + +void NodeProfile::setShapeFunctionTime(Nd4jLong time) { _shapeTime = time; } + +void NodeProfile::setArrayTime(Nd4jLong time) { _arrayTime = time; } + +void NodeProfile::setInputTime(Nd4jLong time) { _inputTime = time; } + +Nd4jLong NodeProfile::getTemporarySize() const { return _memoryTemporary; } + +Nd4jLong NodeProfile::getObjectsSize() const { return _memoryObjects; } + +Nd4jLong NodeProfile::getTotalSize() const { return _memoryTotal; } + +void NodeProfile::setBuildTime(Nd4jLong time) { _buildTime = time; } + +void NodeProfile::setPreparationTime(Nd4jLong time) { _preparationTime = time; } + +void NodeProfile::setExecutionTime(Nd4jLong time) { _executionTime = time; } + +void NodeProfile::setTotalTime(Nd4jLong time) { _totalTime = time; } + +void NodeProfile::setActivationsSize(Nd4jLong bytes) { + _memoryActivations = bytes; +} + +void NodeProfile::setTemporarySize(Nd4jLong bytes) { _memoryTemporary = bytes; } + +void NodeProfile::setObjectsSize(Nd4jLong bytes) { _memoryObjects = bytes; } + +void NodeProfile::setTotalSize(Nd4jLong bytes) { _memoryTotal = bytes; } + +Nd4jLong NodeProfile::getExecutionTime() const { return _executionTime; } + +void NodeProfile::addInputShape(Nd4jLong const *shapeInfo) { + _inputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo)); +} + +void NodeProfile::addOutputShape(Nd4jLong const *shapeInfo) { + _outputShapes.emplace_back(ShapeUtils::shapeInfoAsString(shapeInfo)); +} + +void NodeProfile::merge(NodeProfile *other) { + _merges += other->_merges; + _memoryObjects += other->_memoryObjects; + _memoryActivations += other->_memoryActivations; + _memoryTemporary += other->_memoryTemporary; + _memoryTotal += other->_memoryTotal; + + _preparationTime += other->_preparationTime; + _executionTime += other->_executionTime; + _totalTime += other->_totalTime; + _shapeTime += other->_shapeTime; + _arrayTime += other->_arrayTime; + _inputTime += other->_inputTime; + + _inputShapes = other->_inputShapes; + _outputShapes = other->_outputShapes; +} + +std::string &NodeProfile::name() { return _name; } + +void NodeProfile::assign(NodeProfile *other) { + _merges = other->_merges; + _memoryObjects = other->_memoryObjects; + _memoryActivations = other->_memoryActivations; + _memoryTemporary = other->_memoryTemporary; + _memoryTotal = other->_memoryTotal; + + _preparationTime = other->_preparationTime; + _executionTime = other->_executionTime; + _totalTime = other->_totalTime; + _shapeTime = other->_shapeTime; + _arrayTime = other->_arrayTime; + _inputTime = other->_inputTime; + + _inputShapes = other->_inputShapes; + _outputShapes = other->_outputShapes; +} +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/ArrayUtils.h b/libnd4j/include/helpers/ArrayUtils.h index 2ecebeb4aa98..6392cce00d32 100644 --- a/libnd4j/include/helpers/ArrayUtils.h +++ b/libnd4j/include/helpers/ArrayUtils.h @@ -21,23 +21,23 @@ #ifndef LIBND4J_ARRAYUTILS_H #define LIBND4J_ARRAYUTILS_H +#include + +#include #include #include -#include -#include namespace sd { - namespace ArrayUtils { - void toIntPtr(std::initializer_list list, int* target); - void toIntPtr(std::vector& list, int* target); - - void toLongPtr(std::initializer_list list, Nd4jLong* target); - void toLongPtr(std::vector& list, Nd4jLong* target); +namespace ArrayUtils { +void toIntPtr(std::initializer_list list, int* target); +void toIntPtr(std::vector& list, int* target); +void toLongPtr(std::initializer_list list, Nd4jLong* target); +void toLongPtr(std::vector& list, Nd4jLong* target); - std::vector toLongVector(std::vector vec); - std::vector toLongVector(std::vector vec); - } -} +std::vector toLongVector(std::vector vec); +std::vector toLongVector(std::vector vec); +} // namespace ArrayUtils +} // namespace sd -#endif //LIBND4J_ARRAYUTILS_H +#endif // LIBND4J_ARRAYUTILS_H diff --git a/libnd4j/include/helpers/AttentionHelper.h b/libnd4j/include/helpers/AttentionHelper.h index a2c52c61b403..d6794b45db53 100644 --- a/libnd4j/include/helpers/AttentionHelper.h +++ b/libnd4j/include/helpers/AttentionHelper.h @@ -24,13 +24,17 @@ #include "array/NDArray.h" namespace sd { - class SD_EXPORT AttentionHelper { - - public: - static sd::NDArray multiHeadProject(const sd::NDArray* input, const sd::NDArray* projectionMatrix, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - static void multiHeadProjectBp(const sd::NDArray* input, const sd::NDArray* projectionMatrix, const sd::NDArray* eps, sd::NDArray* dLdInput, sd::NDArray* dLdProjectionMatrix, sd::LaunchContext * context = sd::LaunchContext ::defaultContext()); - }; -} - +class SD_EXPORT AttentionHelper { + public: + static sd::NDArray multiHeadProject( + const sd::NDArray* input, const sd::NDArray* projectionMatrix, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); + static void multiHeadProjectBp( + const sd::NDArray* input, const sd::NDArray* projectionMatrix, + const sd::NDArray* eps, sd::NDArray* dLdInput, + sd::NDArray* dLdProjectionMatrix, + sd::LaunchContext* context = sd::LaunchContext ::defaultContext()); +}; +} // namespace sd #endif diff --git a/libnd4j/include/helpers/BenchmarkHelper.h b/libnd4j/include/helpers/BenchmarkHelper.h index 38e3826535c1..31411d8a2ff2 100644 --- a/libnd4j/include/helpers/BenchmarkHelper.h +++ b/libnd4j/include/helpers/BenchmarkHelper.h @@ -21,69 +21,117 @@ #ifndef LIBND4J_BENCHMARKHELPER_H #define LIBND4J_BENCHMARKHELPER_H - +#include +#include +#include #include -#include -#include -#include -#include +#include +#include #include +#include +#include #include -#include -#include -#include -#include +#include #include -#include #include #include -#include -#include -#include -#include +#include +#include +#include +#include +#include namespace sd { - class SD_EXPORT BenchmarkHelper { - private: - unsigned int _wIterations; - unsigned int _rIterations; - - protected: - std::string benchmarkOperation(OpBenchmark &benchmark); - - void benchmarkScalarOperation(scalar::Ops op, std::string testName, double value, NDArray &x, NDArray &z); - - void benchmarkDeclarableOp(sd::ops::DeclarableOp &op, std::string testName, Context &context); - - void benchmarkGEMM(char orderA, std::initializer_list shapeA, char orderB, std::initializer_list shapeB, char orderC, std::initializer_list shapeC); - - std::string printHeader(); - public: - BenchmarkHelper(unsigned int warmUpIterations = 10, unsigned int runIterations = 100); - - std::string runOperationSuit(std::initializer_list benchmarks, const char *msg = nullptr); - std::string runOperationSuit(std::vector &benchmarks, bool postHeaders, const char *msg = nullptr); - std::string runOperationSuit(OpBenchmark* benchmark); - - std::string runOperationSuit(ScalarBenchmark *op, const std::function& func, const char *message = nullptr); - std::string runOperationSuit(TransformBenchmark *op, const std::function& func, const char *message = nullptr); - std::string runOperationSuit(ReductionBenchmark *op, const std::function& func, const char *message = nullptr); - std::string runOperationSuit(ReductionBenchmark *op, const std::function& func, const char *message = nullptr); - std::string runOperationSuit(PairwiseBenchmark *op, const std::function& func, const char *message = nullptr); - - - std::string runOperationSuit(TransformBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - std::string runOperationSuit(ScalarBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - std::string runOperationSuit(ReductionBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - std::string runOperationSuit(ReductionBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - std::string runOperationSuit(BroadcastBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - std::string runOperationSuit(PairwiseBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - std::string runOperationSuit(MatrixBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - - std::string runOperationSuit(DeclarableBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message = nullptr); - }; -} - - -#endif //SD_BENCHMARKHELPER_H +class SD_EXPORT BenchmarkHelper { + private: + unsigned int _wIterations; + unsigned int _rIterations; + + protected: + std::string benchmarkOperation(OpBenchmark &benchmark); + + void benchmarkScalarOperation(scalar::Ops op, std::string testName, + double value, NDArray &x, NDArray &z); + + void benchmarkDeclarableOp(sd::ops::DeclarableOp &op, std::string testName, + Context &context); + + void benchmarkGEMM(char orderA, std::initializer_list shapeA, + char orderB, std::initializer_list shapeB, + char orderC, std::initializer_list shapeC); + + std::string printHeader(); + + public: + BenchmarkHelper(unsigned int warmUpIterations = 10, + unsigned int runIterations = 100); + + std::string runOperationSuit(std::initializer_list benchmarks, + const char *msg = nullptr); + std::string runOperationSuit(std::vector &benchmarks, + bool postHeaders, const char *msg = nullptr); + std::string runOperationSuit(OpBenchmark *benchmark); + + std::string runOperationSuit( + ScalarBenchmark *op, + const std::function &func, + const char *message = nullptr); + std::string runOperationSuit( + TransformBenchmark *op, + const std::function &func, + const char *message = nullptr); + std::string runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + const char *message = nullptr); + std::string runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + const char *message = nullptr); + std::string runOperationSuit( + PairwiseBenchmark *op, + const std::function &func, + const char *message = nullptr); + + std::string runOperationSuit( + TransformBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); + std::string runOperationSuit( + ScalarBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); + std::string runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); + std::string runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); + std::string runOperationSuit( + BroadcastBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); + std::string runOperationSuit( + PairwiseBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); + std::string runOperationSuit( + MatrixBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); + + std::string runOperationSuit( + DeclarableBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message = nullptr); +}; +} // namespace sd + +#endif // SD_BENCHMARKHELPER_H diff --git a/libnd4j/include/helpers/BitwiseUtils.h b/libnd4j/include/helpers/BitwiseUtils.h index 38bcd1716ed4..931a47d51f51 100644 --- a/libnd4j/include/helpers/BitwiseUtils.h +++ b/libnd4j/include/helpers/BitwiseUtils.h @@ -21,109 +21,90 @@ #ifndef LIBND4J_BITWISEUTILS_H #define LIBND4J_BITWISEUTILS_H -#include #include #include #include + #include +#include namespace sd { - class SD_EXPORT BitwiseUtils { - public: - - - /** - * This method returns first non-zero bit index - * @param holder - * @return - */ - static int valueBit(int holder); - - /** - * This method returns vector representation of bits. - * - * PLEASE NOTE: Result is ALWAYS left-to-right - */ - static std::vector valueBits(int holder); - - /** - * This method returns TRUE if it's called on Big-Endian system, and false otherwise - */ - static bool isBE(); - - /** - * This method returns enum - * @return - */ - static sd::ByteOrder asByteOrder(); - - /** - * This method swaps bytes: LE vs BE - * @tparam T - * @param v - * @return - */ - template - static FORCEINLINE T swap_bytes(T v) { - static_assert (CHAR_BIT == 8, "CHAR_BIT != 8"); - - union S { - T v; - unsigned char u8[sizeof(T)]; - S(T val) { - v = val; - } - }; - - S source(v); - S dest(v); - - for (size_t k = 0; k < sizeof(T); k++) - dest.u8[k] = source.u8[sizeof(T) - k - 1]; - - return dest.v; - } - - /** - * This method flips bits in given value - * - * @tparam T - * @param v - * @return - */ - static int FORCEINLINE flip_bits(int v) { - return ~v; - } - - static int8_t FORCEINLINE flip_bits(int8_t v) { - return ~v; - } - - static int16_t FORCEINLINE flip_bits(int16_t v) { - return ~v; - } - - static uint8_t FORCEINLINE flip_bits(uint8_t v) { - return ~v; - } - - static uint16_t FORCEINLINE flip_bits(uint16_t v) { - return ~v; - } - - static uint32_t FORCEINLINE flip_bits(uint32_t v) { - return ~v; - } - - static uint64_t FORCEINLINE flip_bits(uint64_t v) { - return ~v; - } - - static Nd4jLong FORCEINLINE flip_bits(Nd4jLong v) { - return ~v; - } +class SD_EXPORT BitwiseUtils { + public: + /** + * This method returns first non-zero bit index + * @param holder + * @return + */ + static int valueBit(int holder); + + /** + * This method returns vector representation of bits. + * + * PLEASE NOTE: Result is ALWAYS left-to-right + */ + static std::vector valueBits(int holder); + + /** + * This method returns TRUE if it's called on Big-Endian system, and false + * otherwise + */ + static bool isBE(); + + /** + * This method returns enum + * @return + */ + static sd::ByteOrder asByteOrder(); + + /** + * This method swaps bytes: LE vs BE + * @tparam T + * @param v + * @return + */ + template + static FORCEINLINE T swap_bytes(T v) { + static_assert(CHAR_BIT == 8, "CHAR_BIT != 8"); + + union S { + T v; + unsigned char u8[sizeof(T)]; + S(T val) { v = val; } }; -} + S source(v); + S dest(v); + + for (size_t k = 0; k < sizeof(T); k++) + dest.u8[k] = source.u8[sizeof(T) - k - 1]; + + return dest.v; + } + + /** + * This method flips bits in given value + * + * @tparam T + * @param v + * @return + */ + static int FORCEINLINE flip_bits(int v) { return ~v; } + + static int8_t FORCEINLINE flip_bits(int8_t v) { return ~v; } + + static int16_t FORCEINLINE flip_bits(int16_t v) { return ~v; } + + static uint8_t FORCEINLINE flip_bits(uint8_t v) { return ~v; } + + static uint16_t FORCEINLINE flip_bits(uint16_t v) { return ~v; } + + static uint32_t FORCEINLINE flip_bits(uint32_t v) { return ~v; } + + static uint64_t FORCEINLINE flip_bits(uint64_t v) { return ~v; } + + static Nd4jLong FORCEINLINE flip_bits(Nd4jLong v) { return ~v; } +}; +} // namespace sd -#endif //LIBND4J_BITWISEUTILS_H +#endif // LIBND4J_BITWISEUTILS_H diff --git a/libnd4j/include/helpers/BlasHelper.h b/libnd4j/include/helpers/BlasHelper.h index b2fe7b60ccb6..18b03021b724 100644 --- a/libnd4j/include/helpers/BlasHelper.h +++ b/libnd4j/include/helpers/BlasHelper.h @@ -21,98 +21,90 @@ #ifndef LIBND4J_BLAS_HELPER_H #define LIBND4J_BLAS_HELPER_H -#include -#include #include #include +#include +#include #ifdef _WIN32 #define CUBLASWINAPI __stdcall #define CUSOLVERAPI __stdcall #else -#define CUBLASWINAPI -#define CUSOLVERAPI +#define CUBLASWINAPI +#define CUSOLVERAPI #endif namespace sd { - typedef enum{ - CUBLAS_STATUS_SUCCESS =0, - CUBLAS_STATUS_NOT_INITIALIZED =1, - CUBLAS_STATUS_ALLOC_FAILED =3, - CUBLAS_STATUS_INVALID_VALUE =7, - CUBLAS_STATUS_ARCH_MISMATCH =8, - CUBLAS_STATUS_MAPPING_ERROR =11, - CUBLAS_STATUS_EXECUTION_FAILED=13, - CUBLAS_STATUS_INTERNAL_ERROR =14, - CUBLAS_STATUS_NOT_SUPPORTED =15, - CUBLAS_STATUS_LICENSE_ERROR =16 - } cublasStatus_t; - - typedef enum { - CUBLAS_OP_N=0, - CUBLAS_OP_T=1, - CUBLAS_OP_C=2 - } cublasOperation_t; - - struct cublasContext; - typedef struct cublasContext *cublasHandle_t; - - typedef enum - { - CUDA_R_16F= 2, /* real as a half */ - CUDA_C_16F= 6, /* complex as a pair of half numbers */ - CUDA_R_32F= 0, /* real as a float */ - CUDA_C_32F= 4, /* complex as a pair of float numbers */ - CUDA_R_64F= 1, /* real as a double */ - CUDA_C_64F= 5, /* complex as a pair of double numbers */ - CUDA_R_8I = 3, /* real as a signed char */ - CUDA_C_8I = 7, /* complex as a pair of signed char numbers */ - CUDA_R_8U = 8, /* real as a unsigned char */ - CUDA_C_8U = 9, /* complex as a pair of unsigned char numbers */ - CUDA_R_32I= 10, /* real as a signed int */ - CUDA_C_32I= 11, /* complex as a pair of signed int numbers */ - CUDA_R_32U= 12, /* real as a unsigned int */ - CUDA_C_32U= 13 /* complex as a pair of unsigned int numbers */ - } cublasDataType_t; - - typedef void (*CblasSgemv)(CBLAS_ORDER Layout, - CBLAS_TRANSPOSE TransA, int M, int N, - float alpha, float *A, int lda, - float *X, int incX, float beta, - float *Y, int incY); - - typedef void (*CblasDgemv)(CBLAS_ORDER Layout, - CBLAS_TRANSPOSE TransA, int M, int N, - double alpha, double *A, int lda, - double *X, int incX, double beta, - double *Y, int incY); - - - typedef void (*CblasSgemm)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, int M, int N, - int K, float alpha, float *A, - int lda, float *B, int ldb, - float beta, float *C, int ldc); - - typedef void (*CblasDgemm)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, int M, int N, - int K, double alpha, double *A, - int lda, double *B, int ldb, - double beta, double *C, int ldc); - - typedef void (*CblasSgemmBatch)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE *TransA_Array, - CBLAS_TRANSPOSE *TransB_Array, int *M_Array, int *N_Array, - int *K_Array, float *alpha_Array, float **A_Array, - int *lda_Array, float **B_Array, int *ldb_Array, - float *beta_Array, float **C_Array, int *ldc_Array, - int group_count, int *group_size); - - typedef void (*CblasDgemmBatch)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE *TransA_Array, - CBLAS_TRANSPOSE *TransB_Array, int *M_Array, int *N_Array, - int *K_Array, double *alpha_Array, double **A_Array, - int *lda_Array, double **B_Array, int* ldb_Array, - double *beta_Array, double **C_Array, int *ldc_Array, - int group_count, int *group_size); +typedef enum { + CUBLAS_STATUS_SUCCESS = 0, + CUBLAS_STATUS_NOT_INITIALIZED = 1, + CUBLAS_STATUS_ALLOC_FAILED = 3, + CUBLAS_STATUS_INVALID_VALUE = 7, + CUBLAS_STATUS_ARCH_MISMATCH = 8, + CUBLAS_STATUS_MAPPING_ERROR = 11, + CUBLAS_STATUS_EXECUTION_FAILED = 13, + CUBLAS_STATUS_INTERNAL_ERROR = 14, + CUBLAS_STATUS_NOT_SUPPORTED = 15, + CUBLAS_STATUS_LICENSE_ERROR = 16 +} cublasStatus_t; + +typedef enum { + CUBLAS_OP_N = 0, + CUBLAS_OP_T = 1, + CUBLAS_OP_C = 2 +} cublasOperation_t; + +struct cublasContext; +typedef struct cublasContext *cublasHandle_t; + +typedef enum { + CUDA_R_16F = 2, /* real as a half */ + CUDA_C_16F = 6, /* complex as a pair of half numbers */ + CUDA_R_32F = 0, /* real as a float */ + CUDA_C_32F = 4, /* complex as a pair of float numbers */ + CUDA_R_64F = 1, /* real as a double */ + CUDA_C_64F = 5, /* complex as a pair of double numbers */ + CUDA_R_8I = 3, /* real as a signed char */ + CUDA_C_8I = 7, /* complex as a pair of signed char numbers */ + CUDA_R_8U = 8, /* real as a unsigned char */ + CUDA_C_8U = 9, /* complex as a pair of unsigned char numbers */ + CUDA_R_32I = 10, /* real as a signed int */ + CUDA_C_32I = 11, /* complex as a pair of signed int numbers */ + CUDA_R_32U = 12, /* real as a unsigned int */ + CUDA_C_32U = 13 /* complex as a pair of unsigned int numbers */ +} cublasDataType_t; + +typedef void (*CblasSgemv)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE TransA, int M, + int N, float alpha, float *A, int lda, float *X, + int incX, float beta, float *Y, int incY); + +typedef void (*CblasDgemv)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE TransA, int M, + int N, double alpha, double *A, int lda, double *X, + int incX, double beta, double *Y, int incY); + +typedef void (*CblasSgemm)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, int M, int N, int K, + float alpha, float *A, int lda, float *B, int ldb, + float beta, float *C, int ldc); + +typedef void (*CblasDgemm)(CBLAS_ORDER Layout, CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, int M, int N, int K, + double alpha, double *A, int lda, double *B, int ldb, + double beta, double *C, int ldc); + +typedef void (*CblasSgemmBatch)( + CBLAS_ORDER Layout, CBLAS_TRANSPOSE *TransA_Array, + CBLAS_TRANSPOSE *TransB_Array, int *M_Array, int *N_Array, int *K_Array, + float *alpha_Array, float **A_Array, int *lda_Array, float **B_Array, + int *ldb_Array, float *beta_Array, float **C_Array, int *ldc_Array, + int group_count, int *group_size); + +typedef void (*CblasDgemmBatch)( + CBLAS_ORDER Layout, CBLAS_TRANSPOSE *TransA_Array, + CBLAS_TRANSPOSE *TransB_Array, int *M_Array, int *N_Array, int *K_Array, + double *alpha_Array, double **A_Array, int *lda_Array, double **B_Array, + int *ldb_Array, double *beta_Array, double **C_Array, int *ldc_Array, + int group_count, int *group_size); #ifdef LAPACK_ROW_MAJOR #undef LAPACK_ROW_MAJOR @@ -121,324 +113,215 @@ namespace sd { #ifdef LAPACK_COL_MAJOR #undef LAPACK_COL_MAJOR #endif - enum LAPACK_LAYOUT { LAPACK_ROW_MAJOR=101, LAPACK_COL_MAJOR=102 }; - - typedef int (*LapackeSgesvd)(LAPACK_LAYOUT matrix_layout, char jobu, char jobvt, - int m, int n, float* a, int lda, - float* s, float* u, int ldu, float* vt, - int ldvt, float* superb); - - typedef int (*LapackeDgesvd)(LAPACK_LAYOUT matrix_layout, char jobu, char jobvt, - int m, int n, double* a, - int lda, double* s, double* u, int ldu, - double* vt, int ldvt, double* superb); - - typedef int (*LapackeSgesdd)(LAPACK_LAYOUT matrix_layout, char jobz, int m, - int n, float* a, int lda, float* s, - float* u, int ldu, float* vt, - int ldvt); - typedef int (*LapackeDgesdd)(LAPACK_LAYOUT matrix_layout, char jobz, int m, - int n, double* a, int lda, double* s, - double* u, int ldu, double* vt, - int ldvt); - - typedef cublasStatus_t (CUBLASWINAPI *CublasSgemv)(cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - float *alpha, /* host or device pointer */ - float *A, - int lda, - float *x, - int incx, - float *beta, /* host or device pointer */ - float *y, - int incy); - - typedef cublasStatus_t (CUBLASWINAPI *CublasDgemv)(cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - double *alpha, /* host or device pointer */ - double *A, - int lda, - double *x, - int incx, - double *beta, /* host or device pointer */ - double *y, - int incy); - - typedef cublasStatus_t (CUBLASWINAPI *CublasHgemm)(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - __half *alpha, /* host or device pointer */ - __half *A, - int lda, - __half *B, - int ldb, - __half *beta, /* host or device pointer */ - __half *C, - int ldc); - - typedef cublasStatus_t (CUBLASWINAPI *CublasSgemm)(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - float *alpha, /* host or device pointer */ - float *A, - int lda, - float *B, - int ldb, - float *beta, /* host or device pointer */ - float *C, - int ldc); - - typedef cublasStatus_t (CUBLASWINAPI *CublasDgemm)(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - double *alpha, /* host or device pointer */ - double *A, - int lda, - double *B, - int ldb, - double *beta, /* host or device pointer */ - double *C, - int ldc); - - typedef cublasStatus_t (CUBLASWINAPI *CublasSgemmEx)(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - float *alpha, /* host or device pointer */ - void *A, - cublasDataType_t Atype, - int lda, - void *B, - cublasDataType_t Btype, - int ldb, - float *beta, /* host or device pointer */ - void *C, - cublasDataType_t Ctype, - int ldc); - - typedef cublasStatus_t (CUBLASWINAPI *CublasHgemmBatched)(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - __half *alpha, /* host or device pointer */ - __half *Aarray[], - int lda, - __half *Barray[], - int ldb, - __half *beta, /* host or device pointer */ - __half *Carray[], - int ldc, - int batchCount); - - typedef cublasStatus_t (CUBLASWINAPI *CublasSgemmBatched)(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - float *alpha, /* host or device pointer */ - float *Aarray[], - int lda, - float *Barray[], - int ldb, - float *beta, /* host or device pointer */ - float *Carray[], - int ldc, - int batchCount); - - typedef cublasStatus_t (CUBLASWINAPI *CublasDgemmBatched)(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - double *alpha, /* host or device pointer */ - double *Aarray[], - int lda, - double *Barray[], - int ldb, - double *beta, /* host or device pointer */ - double *Carray[], - int ldc, - int batchCount); - - typedef enum{ - CUSOLVER_STATUS_SUCCESS=0, - CUSOLVER_STATUS_NOT_INITIALIZED=1, - CUSOLVER_STATUS_ALLOC_FAILED=2, - CUSOLVER_STATUS_INVALID_VALUE=3, - CUSOLVER_STATUS_ARCH_MISMATCH=4, - CUSOLVER_STATUS_MAPPING_ERROR=5, - CUSOLVER_STATUS_EXECUTION_FAILED=6, - CUSOLVER_STATUS_INTERNAL_ERROR=7, - CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED=8, - CUSOLVER_STATUS_NOT_SUPPORTED = 9, - CUSOLVER_STATUS_ZERO_PIVOT=10, - CUSOLVER_STATUS_INVALID_LICENSE=11 - } cusolverStatus_t; - - typedef enum { - CUSOLVER_EIG_TYPE_1=1, - CUSOLVER_EIG_TYPE_2=2, - CUSOLVER_EIG_TYPE_3=3 - } cusolverEigType_t ; - - typedef enum { - CUSOLVER_EIG_MODE_NOVECTOR=0, - CUSOLVER_EIG_MODE_VECTOR=1 - } cusolverEigMode_t ; - - struct cusolverDnContext; - typedef struct cusolverDnContext *cusolverDnHandle_t; - - typedef cusolverStatus_t (CUSOLVERAPI *CusolverDnSgesvdBufferSize)( - cusolverDnHandle_t handle, - int m, - int n, - int *lwork); - - typedef cusolverStatus_t (CUSOLVERAPI *CusolverDnDgesvdBufferSize)( - cusolverDnHandle_t handle, - int m, - int n, - int *lwork); - - typedef cusolverStatus_t (CUSOLVERAPI *CusolverDnSgesvd)( - cusolverDnHandle_t handle, - signed char jobu, - signed char jobvt, - int m, - int n, - float *A, - int lda, - float *S, - float *U, - int ldu, - float *VT, - int ldvt, - float *work, - int lwork, - float *rwork, - int *info); - - typedef cusolverStatus_t (CUSOLVERAPI *CusolverDnDgesvd)( - cusolverDnHandle_t handle, - signed char jobu, - signed char jobvt, - int m, - int n, - double *A, - int lda, - double *S, - double *U, - int ldu, - double *VT, - int ldvt, - double *work, - int lwork, - double *rwork, - int *info); - - - enum BlasFunctions { - GEMV = 0, - GEMM = 1, - }; - - class BlasHelper { - private: - static BlasHelper* _instance; - - bool _hasHgemv = false; - bool _hasHgemm = false; - bool _hasHgemmBatch = false; - - bool _hasSgemv = false; - bool _hasSgemm = false; - bool _hasSgemmBatch = false; - - bool _hasDgemv = false; - bool _hasDgemm = false; - bool _hasDgemmBatch = false; - - CblasSgemv cblasSgemv; - CblasDgemv cblasDgemv; - CblasSgemm cblasSgemm; - CblasDgemm cblasDgemm; - CblasSgemmBatch cblasSgemmBatch; - CblasDgemmBatch cblasDgemmBatch; - LapackeSgesvd lapackeSgesvd; - LapackeDgesvd lapackeDgesvd; - LapackeSgesdd lapackeSgesdd; - LapackeDgesdd lapackeDgesdd; - - CublasSgemv cublasSgemv; - CublasDgemv cublasDgemv; - CublasHgemm cublasHgemm; - CublasSgemm cublasSgemm; - CublasDgemm cublasDgemm; - CublasSgemmEx cublasSgemmEx; - CublasHgemmBatched cublasHgemmBatched; - CublasSgemmBatched cublasSgemmBatched; - CublasDgemmBatched cublasDgemmBatched; - CusolverDnSgesvdBufferSize cusolverDnSgesvdBufferSize; - CusolverDnDgesvdBufferSize cusolverDnDgesvdBufferSize; - CusolverDnSgesvd cusolverDnSgesvd; - CusolverDnDgesvd cusolverDnDgesvd; - - public: - static BlasHelper* getInstance(); - - void initializeFunctions(Nd4jPointer *functions); - void initializeDeviceFunctions(Nd4jPointer *functions); - - template - bool hasGEMV(); - - template - bool hasGEMM(); - - bool hasGEMM(const sd::DataType dtype); - bool hasGEMV(const sd::DataType dtype); - - template - bool hasBatchedGEMM(); - - CblasSgemv sgemv(); - CblasDgemv dgemv(); - - CblasSgemm sgemm(); - CblasDgemm dgemm(); - - CblasSgemmBatch sgemmBatched(); - CblasDgemmBatch dgemmBatched(); - - LapackeSgesvd sgesvd(); - LapackeDgesvd dgesvd(); - - LapackeSgesdd sgesdd(); - LapackeDgesdd dgesdd(); - - // destructor - ~BlasHelper() noexcept; - }; -} +enum LAPACK_LAYOUT { LAPACK_ROW_MAJOR = 101, LAPACK_COL_MAJOR = 102 }; + +typedef int (*LapackeSgesvd)(LAPACK_LAYOUT matrix_layout, char jobu, char jobvt, + int m, int n, float *a, int lda, float *s, + float *u, int ldu, float *vt, int ldvt, + float *superb); + +typedef int (*LapackeDgesvd)(LAPACK_LAYOUT matrix_layout, char jobu, char jobvt, + int m, int n, double *a, int lda, double *s, + double *u, int ldu, double *vt, int ldvt, + double *superb); + +typedef int (*LapackeSgesdd)(LAPACK_LAYOUT matrix_layout, char jobz, int m, + int n, float *a, int lda, float *s, float *u, + int ldu, float *vt, int ldvt); +typedef int (*LapackeDgesdd)(LAPACK_LAYOUT matrix_layout, char jobz, int m, + int n, double *a, int lda, double *s, double *u, + int ldu, double *vt, int ldvt); + +typedef cublasStatus_t(CUBLASWINAPI *CublasSgemv)( + cublasHandle_t handle, cublasOperation_t trans, int m, int n, + float *alpha, /* host or device pointer */ + float *A, int lda, float *x, int incx, + float *beta, /* host or device pointer */ + float *y, int incy); + +typedef cublasStatus_t(CUBLASWINAPI *CublasDgemv)( + cublasHandle_t handle, cublasOperation_t trans, int m, int n, + double *alpha, /* host or device pointer */ + double *A, int lda, double *x, int incx, + double *beta, /* host or device pointer */ + double *y, int incy); + +typedef cublasStatus_t(CUBLASWINAPI *CublasHgemm)( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, __half *alpha, /* host or device pointer */ + __half *A, int lda, __half *B, int ldb, + __half *beta, /* host or device pointer */ + __half *C, int ldc); + +typedef cublasStatus_t(CUBLASWINAPI *CublasSgemm)( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, float *alpha, /* host or device pointer */ + float *A, int lda, float *B, int ldb, + float *beta, /* host or device pointer */ + float *C, int ldc); + +typedef cublasStatus_t(CUBLASWINAPI *CublasDgemm)( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, double *alpha, /* host or device pointer */ + double *A, int lda, double *B, int ldb, + double *beta, /* host or device pointer */ + double *C, int ldc); + +typedef cublasStatus_t(CUBLASWINAPI *CublasSgemmEx)( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, float *alpha, /* host or device pointer */ + void *A, cublasDataType_t Atype, int lda, void *B, cublasDataType_t Btype, + int ldb, float *beta, /* host or device pointer */ + void *C, cublasDataType_t Ctype, int ldc); + +typedef cublasStatus_t(CUBLASWINAPI *CublasHgemmBatched)( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, __half *alpha, /* host or device pointer */ + __half *Aarray[], int lda, __half *Barray[], int ldb, + __half *beta, /* host or device pointer */ + __half *Carray[], int ldc, int batchCount); + +typedef cublasStatus_t(CUBLASWINAPI *CublasSgemmBatched)( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, float *alpha, /* host or device pointer */ + float *Aarray[], int lda, float *Barray[], int ldb, + float *beta, /* host or device pointer */ + float *Carray[], int ldc, int batchCount); + +typedef cublasStatus_t(CUBLASWINAPI *CublasDgemmBatched)( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, double *alpha, /* host or device pointer */ + double *Aarray[], int lda, double *Barray[], int ldb, + double *beta, /* host or device pointer */ + double *Carray[], int ldc, int batchCount); + +typedef enum { + CUSOLVER_STATUS_SUCCESS = 0, + CUSOLVER_STATUS_NOT_INITIALIZED = 1, + CUSOLVER_STATUS_ALLOC_FAILED = 2, + CUSOLVER_STATUS_INVALID_VALUE = 3, + CUSOLVER_STATUS_ARCH_MISMATCH = 4, + CUSOLVER_STATUS_MAPPING_ERROR = 5, + CUSOLVER_STATUS_EXECUTION_FAILED = 6, + CUSOLVER_STATUS_INTERNAL_ERROR = 7, + CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED = 8, + CUSOLVER_STATUS_NOT_SUPPORTED = 9, + CUSOLVER_STATUS_ZERO_PIVOT = 10, + CUSOLVER_STATUS_INVALID_LICENSE = 11 +} cusolverStatus_t; + +typedef enum { + CUSOLVER_EIG_TYPE_1 = 1, + CUSOLVER_EIG_TYPE_2 = 2, + CUSOLVER_EIG_TYPE_3 = 3 +} cusolverEigType_t; + +typedef enum { + CUSOLVER_EIG_MODE_NOVECTOR = 0, + CUSOLVER_EIG_MODE_VECTOR = 1 +} cusolverEigMode_t; + +struct cusolverDnContext; +typedef struct cusolverDnContext *cusolverDnHandle_t; + +typedef cusolverStatus_t(CUSOLVERAPI *CusolverDnSgesvdBufferSize)( + cusolverDnHandle_t handle, int m, int n, int *lwork); + +typedef cusolverStatus_t(CUSOLVERAPI *CusolverDnDgesvdBufferSize)( + cusolverDnHandle_t handle, int m, int n, int *lwork); + +typedef cusolverStatus_t(CUSOLVERAPI *CusolverDnSgesvd)( + cusolverDnHandle_t handle, signed char jobu, signed char jobvt, int m, + int n, float *A, int lda, float *S, float *U, int ldu, float *VT, int ldvt, + float *work, int lwork, float *rwork, int *info); + +typedef cusolverStatus_t(CUSOLVERAPI *CusolverDnDgesvd)( + cusolverDnHandle_t handle, signed char jobu, signed char jobvt, int m, + int n, double *A, int lda, double *S, double *U, int ldu, double *VT, + int ldvt, double *work, int lwork, double *rwork, int *info); + +enum BlasFunctions { + GEMV = 0, + GEMM = 1, +}; + +class BlasHelper { + private: + static BlasHelper *_instance; + + bool _hasHgemv = false; + bool _hasHgemm = false; + bool _hasHgemmBatch = false; + + bool _hasSgemv = false; + bool _hasSgemm = false; + bool _hasSgemmBatch = false; + + bool _hasDgemv = false; + bool _hasDgemm = false; + bool _hasDgemmBatch = false; + + CblasSgemv cblasSgemv; + CblasDgemv cblasDgemv; + CblasSgemm cblasSgemm; + CblasDgemm cblasDgemm; + CblasSgemmBatch cblasSgemmBatch; + CblasDgemmBatch cblasDgemmBatch; + LapackeSgesvd lapackeSgesvd; + LapackeDgesvd lapackeDgesvd; + LapackeSgesdd lapackeSgesdd; + LapackeDgesdd lapackeDgesdd; + + CublasSgemv cublasSgemv; + CublasDgemv cublasDgemv; + CublasHgemm cublasHgemm; + CublasSgemm cublasSgemm; + CublasDgemm cublasDgemm; + CublasSgemmEx cublasSgemmEx; + CublasHgemmBatched cublasHgemmBatched; + CublasSgemmBatched cublasSgemmBatched; + CublasDgemmBatched cublasDgemmBatched; + CusolverDnSgesvdBufferSize cusolverDnSgesvdBufferSize; + CusolverDnDgesvdBufferSize cusolverDnDgesvdBufferSize; + CusolverDnSgesvd cusolverDnSgesvd; + CusolverDnDgesvd cusolverDnDgesvd; + + public: + static BlasHelper *getInstance(); + + void initializeFunctions(Nd4jPointer *functions); + void initializeDeviceFunctions(Nd4jPointer *functions); + + template + bool hasGEMV(); + + template + bool hasGEMM(); + + bool hasGEMM(const sd::DataType dtype); + bool hasGEMV(const sd::DataType dtype); + + template + bool hasBatchedGEMM(); + + CblasSgemv sgemv(); + CblasDgemv dgemv(); + + CblasSgemm sgemm(); + CblasDgemm dgemm(); + + CblasSgemmBatch sgemmBatched(); + CblasDgemmBatch dgemmBatched(); + + LapackeSgesvd sgesvd(); + LapackeDgesvd dgesvd(); + + LapackeSgesdd sgesdd(); + LapackeDgesdd dgesdd(); + + // destructor + ~BlasHelper() noexcept; +}; +} // namespace sd #endif diff --git a/libnd4j/include/helpers/ConstantHelper.h b/libnd4j/include/helpers/ConstantHelper.h index 44e4a71c7d14..cda0097d27dc 100644 --- a/libnd4j/include/helpers/ConstantHelper.h +++ b/libnd4j/include/helpers/ConstantHelper.h @@ -21,44 +21,48 @@ #ifndef SD_CONSTANTHELPER_H #define SD_CONSTANTHELPER_H -#include +#include +#include +#include +#include #include +#include #include -#include -#include + #include #include -#include -#include -#include +#include namespace sd { - class SD_EXPORT ConstantHelper { - private: - static ConstantHelper* _INSTANCE; - ConstantHelper(); +class SD_EXPORT ConstantHelper { + private: + static ConstantHelper* _INSTANCE; + ConstantHelper(); + + std::vector> _cache; - std::vector> _cache; + // tracking of per-device constant memory buffers (CUDA only atm) + std::vector _devicePointers; + std::vector _deviceOffsets; + std::mutex _mutex; + std::mutex _mutexHolder; - // tracking of per-device constant memory buffers (CUDA only atm) - std::vector _devicePointers; - std::vector _deviceOffsets; - std::mutex _mutex; - std::mutex _mutexHolder; + std::vector _counters; - std::vector _counters; - public: - ~ConstantHelper() = default; + public: + ~ConstantHelper() = default; - static ConstantHelper* getInstance(); - static int getCurrentDevice(); - static int getNumberOfDevices(); - void* replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace = nullptr); + static ConstantHelper* getInstance(); + static int getCurrentDevice(); + static int getNumberOfDevices(); + void* replicatePointer(void* src, size_t numBytes, + memory::Workspace* workspace = nullptr); - ConstantDataBuffer* constantBuffer(const ConstantDescriptor &descriptor, sd::DataType dataType); + ConstantDataBuffer* constantBuffer(const ConstantDescriptor& descriptor, + sd::DataType dataType); - Nd4jLong getCachedAmount(int deviceId); - }; -} + Nd4jLong getCachedAmount(int deviceId); +}; +} // namespace sd -#endif //SD_CONSTANTHELPER_H +#endif // SD_CONSTANTHELPER_H diff --git a/libnd4j/include/helpers/ConstantShapeHelper.h b/libnd4j/include/helpers/ConstantShapeHelper.h index 1c7ae9bfdb56..cf8b732150ad 100644 --- a/libnd4j/include/helpers/ConstantShapeHelper.h +++ b/libnd4j/include/helpers/ConstantShapeHelper.h @@ -21,78 +21,86 @@ #ifndef SD_CONSTANTSHAPEHELPER_H #define SD_CONSTANTSHAPEHELPER_H +#include +#include +#include #include +#include #include + #include #include #include -#include -#include -#include -#include namespace sd { - class SD_EXPORT ConstantShapeHelper { - private: - static ConstantShapeHelper *_INSTANCE; - - std::mutex _mutex; - std::vector> _cache; - - - ConstantShapeHelper(); - public: - ~ConstantShapeHelper() = default; - - static ConstantShapeHelper* getInstance(); - - - ConstantDataBuffer bufferForShapeInfo(sd::DataType dataType, char order, const std::vector &shape); - ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor &descriptor); - ConstantDataBuffer bufferForShapeInfo(const Nd4jLong *shapeInfo); - ConstantDataBuffer bufferForShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape); - ConstantDataBuffer createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace = nullptr, const std::vector &dimensions = {}); - - - const Nd4jLong* emptyShapeInfo(sd::DataType dataType); - const Nd4jLong* scalarShapeInfo(sd::DataType dataType); - const Nd4jLong* vectorShapeInfo(Nd4jLong length, sd::DataType dataType); - const Nd4jLong* createShapeInfo(const ShapeDescriptor &descriptor); - const Nd4jLong* createShapeInfo(sd::DataType dataType, char order, const std::vector &shape); - const Nd4jLong* createShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape); - const Nd4jLong* createShapeInfo(sd::DataType dataType, const Nd4jLong* shapeInfo); - - const Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace); - const Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal = true); - - bool checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor); - - - /** - * This method returns number of cached TAD shapes/offsets on specific device - * @return - */ - FORCEINLINE int cachedEntriesForDevice(int deviceId) { - if (deviceId > _cache.size()) - throw std::runtime_error("deviceId > number of actual devices"); - - return _cache[deviceId].size(); - } - - /** - * This method returns total number of cached TAD shapes/offsets on all devices - * @return - */ - FORCEINLINE int totalCachedEntries() { - int total = 0; - - for (int e = 0; e < _cache.size(); e++) - total += _cache[e].size(); - - return total; - } - }; -} - -#endif //SD_CONSTANTSHAPEHELPER_H +class SD_EXPORT ConstantShapeHelper { + private: + static ConstantShapeHelper* _INSTANCE; + + std::mutex _mutex; + std::vector> _cache; + + ConstantShapeHelper(); + + public: + ~ConstantShapeHelper() = default; + + static ConstantShapeHelper* getInstance(); + + ConstantDataBuffer bufferForShapeInfo(sd::DataType dataType, char order, + const std::vector& shape); + ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor& descriptor); + ConstantDataBuffer bufferForShapeInfo(const Nd4jLong* shapeInfo); + ConstantDataBuffer bufferForShapeInfo(sd::DataType dataType, char order, + int rank, const Nd4jLong* shape); + ConstantDataBuffer createShapeInfoWithUnitiesForBroadcast( + const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, + sd::memory::Workspace* workspace = nullptr, + const std::vector& dimensions = {}); + + const Nd4jLong* emptyShapeInfo(sd::DataType dataType); + const Nd4jLong* scalarShapeInfo(sd::DataType dataType); + const Nd4jLong* vectorShapeInfo(Nd4jLong length, sd::DataType dataType); + const Nd4jLong* createShapeInfo(const ShapeDescriptor& descriptor); + const Nd4jLong* createShapeInfo(sd::DataType dataType, char order, + const std::vector& shape); + const Nd4jLong* createShapeInfo(sd::DataType dataType, char order, int rank, + const Nd4jLong* shape); + const Nd4jLong* createShapeInfo(sd::DataType dataType, + const Nd4jLong* shapeInfo); + + const Nd4jLong* createFromExisting(Nd4jLong* shapeInfo, + sd::memory::Workspace* workspace); + const Nd4jLong* createFromExisting(Nd4jLong* shapeInfo, + bool destroyOriginal = true); + + bool checkBufferExistenceForShapeInfo(ShapeDescriptor& descriptor); + + /** + * This method returns number of cached TAD shapes/offsets on specific device + * @return + */ + FORCEINLINE int cachedEntriesForDevice(int deviceId) { + if (deviceId > _cache.size()) + throw std::runtime_error("deviceId > number of actual devices"); + + return _cache[deviceId].size(); + } + + /** + * This method returns total number of cached TAD shapes/offsets on all + * devices + * @return + */ + FORCEINLINE int totalCachedEntries() { + int total = 0; + + for (int e = 0; e < _cache.size(); e++) total += _cache[e].size(); + + return total; + } +}; +} // namespace sd + +#endif // SD_CONSTANTSHAPEHELPER_H diff --git a/libnd4j/include/helpers/ConstantTadHelper.h b/libnd4j/include/helpers/ConstantTadHelper.h index ef1aa4157c89..edaaaeeea37b 100644 --- a/libnd4j/include/helpers/ConstantTadHelper.h +++ b/libnd4j/include/helpers/ConstantTadHelper.h @@ -18,72 +18,80 @@ // @author raver119@gmail.com // - #ifndef SD_CONSTANTTADHELPER_H #define SD_CONSTANTTADHELPER_H +#include +#include +#include #include #include #include + #include -#include #include -#include -#include -#include +#include namespace sd { - class SD_EXPORT ConstantTadHelper { - private: - static ConstantTadHelper *_INSTANCE; - - std::mutex _mutex; - std::vector> _cache; - - ConstantTadHelper(); - public: - ~ConstantTadHelper() = default; - - static ConstantTadHelper* getInstance(); - - /** - * These methods calculate Tensor-Along-Dimension(s) shape and offsets - * - * @param originalShape - * @param dimensions - * @param keepUnitiesInShape - * @return - */ - TadPack tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape = false); - TadPack tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false); - TadPack tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false); - TadPack tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape = false); - TadPack tadForDimensions(TadDescriptor &descriptor); - - /** - * This method returns number of cached TAD shapes/offsets on specific device - * @return - */ - FORCEINLINE int cachedEntriesForDevice(int deviceId) { - if (deviceId > _cache.size()) - throw std::runtime_error("deviceId > number of actual devices"); - - return _cache[deviceId].size(); - } - - /** - * This method returns total number of cached TAD shapes/offsets on all devices - * @return - */ - FORCEINLINE int totalCachedEntries() { - int total = 0; - - for (int e = 0; e < _cache.size(); e++) - total += _cache[e].size(); - - return total; - } - }; -} - -#endif //SD_CONSTANTTADHELPER_H +class SD_EXPORT ConstantTadHelper { + private: + static ConstantTadHelper *_INSTANCE; + + std::mutex _mutex; + std::vector> _cache; + + ConstantTadHelper(); + + public: + ~ConstantTadHelper() = default; + + static ConstantTadHelper *getInstance(); + + /** + * These methods calculate Tensor-Along-Dimension(s) shape and offsets + * + * @param originalShape + * @param dimensions + * @param keepUnitiesInShape + * @return + */ + TadPack tadForDimensions(const Nd4jLong *originalShape, + const std::vector &dimensions, + const bool keepUnitiesInShape = false); + TadPack tadForDimensions(const Nd4jLong *originalShape, int *dimensions, + int dimLength, + const bool keepUnitiesInShape = false); + TadPack tadForDimensions(const Nd4jLong *originalShape, int dimensions, + const bool keepUnitiesInShape = false); + TadPack tadForDimensions(ShapeDescriptor &descriptor, + std::vector &dimensions, + const bool keepUnitiesInShape = false); + TadPack tadForDimensions(TadDescriptor &descriptor); + + /** + * This method returns number of cached TAD shapes/offsets on specific device + * @return + */ + FORCEINLINE int cachedEntriesForDevice(int deviceId) { + if (deviceId > _cache.size()) + throw std::runtime_error("deviceId > number of actual devices"); + + return _cache[deviceId].size(); + } + + /** + * This method returns total number of cached TAD shapes/offsets on all + * devices + * @return + */ + FORCEINLINE int totalCachedEntries() { + int total = 0; + + for (int e = 0; e < _cache.size(); e++) total += _cache[e].size(); + + return total; + } +}; +} // namespace sd + +#endif // SD_CONSTANTTADHELPER_H diff --git a/libnd4j/include/helpers/CudaLaunchHelper.h b/libnd4j/include/helpers/CudaLaunchHelper.h index 3e933c546476..8d5cc889f397 100644 --- a/libnd4j/include/helpers/CudaLaunchHelper.h +++ b/libnd4j/include/helpers/CudaLaunchHelper.h @@ -21,19 +21,18 @@ #ifndef LIBND4J_CUDALAUNCHHELPER_H #define LIBND4J_CUDALAUNCHHELPER_H - -#include #include #include +#include #include namespace sd { - class SD_EXPORT CudaLaunchHelper { - public: - static Triple getFlatLaunchParams(Nd4jLong length, int SM, int CORES, int SHARED_MEMORY); - static int getReductionBlocks(Nd4jLong xLength, int blockSize = 512); - }; -} - +class SD_EXPORT CudaLaunchHelper { + public: + static Triple getFlatLaunchParams(Nd4jLong length, int SM, int CORES, + int SHARED_MEMORY); + static int getReductionBlocks(Nd4jLong xLength, int blockSize = 512); +}; +} // namespace sd -#endif //LIBND4J_CUDALAUNCHHELPER_H +#endif // LIBND4J_CUDALAUNCHHELPER_H diff --git a/libnd4j/include/helpers/DebugHelper.h b/libnd4j/include/helpers/DebugHelper.h index 0e1bd28374ec..75204644d748 100644 --- a/libnd4j/include/helpers/DebugHelper.h +++ b/libnd4j/include/helpers/DebugHelper.h @@ -21,60 +21,64 @@ #ifndef LIBND4J_DEBUGHELPER_H #define LIBND4J_DEBUGHELPER_H -#include -#include -#include #include -#include +#include +#include +#include +#include #ifdef __CUDACC__ #include -#include #include +#include #endif #include namespace sd { - class NDArray; - class SD_EXPORT DebugHelper { - public: - - // cuda-specific debug functions +class NDArray; +class SD_EXPORT DebugHelper { + public: + // cuda-specific debug functions #ifdef __CUDACC__ - static FORCEINLINE void checkErrorCode(cudaStream_t *stream, int opNum = 0) { - if (Environment::getInstance()->isDebug()) { - cudaError_t res = cudaStreamSynchronize(*stream); + static FORCEINLINE void checkErrorCode(cudaStream_t* stream, int opNum = 0) { + if (Environment::getInstance()->isDebug()) { + cudaError_t res = cudaStreamSynchronize(*stream); - if (res != 0) { - //PRINT_FIRST("Kernel OpNum failed: [%i]\n", opNum); - std::string op = "Kernel OpNum failed: ["; - op += StringUtils::valueToString(opNum); - op += "]"; + if (res != 0) { + // PRINT_FIRST("Kernel OpNum failed: [%i]\n", opNum); + std::string op = "Kernel OpNum failed: ["; + op += StringUtils::valueToString(opNum); + op += "]"; - throw std::runtime_error(op); - } - } - } + throw std::runtime_error(op); + } + } + } - static FORCEINLINE void checkErrorCode(cudaStream_t *stream, const char *failMessage = nullptr) { - cudaError_t res = cudaStreamSynchronize(*stream); - if (res != 0) { - if (failMessage == nullptr) { - std::string op = "CUDA call ended with error code [" + StringUtils::valueToString(res) + std::string("]"); - throw std::runtime_error(op); - } else { - std::string op = std::string(failMessage) + std::string("Error code [") + StringUtils::valueToString(res) + std::string("]"); - throw std::runtime_error(op); - } - } - } + static FORCEINLINE void checkErrorCode(cudaStream_t* stream, + const char* failMessage = nullptr) { + cudaError_t res = cudaStreamSynchronize(*stream); + if (res != 0) { + if (failMessage == nullptr) { + std::string op = "CUDA call ended with error code [" + + StringUtils::valueToString(res) + + std::string("]"); + throw std::runtime_error(op); + } else { + std::string op = + std::string(failMessage) + std::string("Error code [") + + StringUtils::valueToString(res) + std::string("]"); + throw std::runtime_error(op); + } + } + } #endif - static DebugInfo debugStatistics(NDArray const* input); - static void retrieveDebugStatistics(DebugInfo* statistics, NDArray const* input); - }; -} - + static DebugInfo debugStatistics(NDArray const* input); + static void retrieveDebugStatistics(DebugInfo* statistics, + NDArray const* input); +}; +} // namespace sd -#endif //LIBND4J_DEBUGHELPER_H +#endif // LIBND4J_DEBUGHELPER_H diff --git a/libnd4j/include/helpers/DebugInfo.h b/libnd4j/include/helpers/DebugInfo.h index 6a2487c537b5..3d82928d2bcb 100644 --- a/libnd4j/include/helpers/DebugInfo.h +++ b/libnd4j/include/helpers/DebugInfo.h @@ -21,49 +21,49 @@ #ifndef LIBND4J__DEBUG_INFO_HELPER__H #define LIBND4J__DEBUG_INFO_HELPER__H -#include -#include -#include #include -#include -#include #include +#include +#include +#include +#include + +#include #ifdef __CUDACC__ #include -#include #include +#include #endif namespace sd { - struct SD_EXPORT DebugInfo { - double _minValue; - double _maxValue; - double _meanValue; - double _stdDevValue; - Nd4jLong _zeroCount; - Nd4jLong _positiveCount; - Nd4jLong _negativeCount; - Nd4jLong _infCount; - Nd4jLong _nanCount; - }; - - FORCEINLINE bool operator==(DebugInfo const& first, DebugInfo const& second) { - return sd::math::nd4j_abs(first._minValue - second._minValue) < 0.000001 && - sd::math::nd4j_abs(first._maxValue - second._maxValue) < 0.000001 && - sd::math::nd4j_abs(first._meanValue - second._meanValue) < 0.000001 && - sd::math::nd4j_abs(first._stdDevValue - second._stdDevValue) < 0.000001 && - first._zeroCount == second._zeroCount && - first._positiveCount == second._positiveCount && - first._negativeCount == second._negativeCount && - first._infCount == second._infCount && - first._nanCount == second._nanCount; - - } +struct SD_EXPORT DebugInfo { + double _minValue; + double _maxValue; + double _meanValue; + double _stdDevValue; + Nd4jLong _zeroCount; + Nd4jLong _positiveCount; + Nd4jLong _negativeCount; + Nd4jLong _infCount; + Nd4jLong _nanCount; +}; +FORCEINLINE bool operator==(DebugInfo const& first, DebugInfo const& second) { + return sd::math::nd4j_abs(first._minValue - second._minValue) < 0.000001 && + sd::math::nd4j_abs(first._maxValue - second._maxValue) < 0.000001 && + sd::math::nd4j_abs(first._meanValue - second._meanValue) < 0.000001 && + sd::math::nd4j_abs(first._stdDevValue - second._stdDevValue) < + 0.000001 && + first._zeroCount == second._zeroCount && + first._positiveCount == second._positiveCount && + first._negativeCount == second._negativeCount && + first._infCount == second._infCount && + first._nanCount == second._nanCount; } +} // namespace sd -#endif //LIBND4J_DEBUGHELPER_H +#endif // LIBND4J_DEBUGHELPER_H diff --git a/libnd4j/include/helpers/EnumUtils.h b/libnd4j/include/helpers/EnumUtils.h index 6138117c7e60..6d6fe1facbc8 100644 --- a/libnd4j/include/helpers/EnumUtils.h +++ b/libnd4j/include/helpers/EnumUtils.h @@ -25,12 +25,13 @@ #include namespace sd { - class EnumUtils { - public: - static const char * _VariableTypeToString(sd::graph::VariableType variableType); - static const char * _OpTypeToString(sd::graph::OpType opType); - static const char * _LogicOpToString(int opNum); - }; -} +class EnumUtils { + public: + static const char* _VariableTypeToString( + sd::graph::VariableType variableType); + static const char* _OpTypeToString(sd::graph::OpType opType); + static const char* _LogicOpToString(int opNum); +}; +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/helpers/FileUtils.h b/libnd4j/include/helpers/FileUtils.h index aea296119fb2..cbcda199a4da 100644 --- a/libnd4j/include/helpers/FileUtils.h +++ b/libnd4j/include/helpers/FileUtils.h @@ -21,17 +21,17 @@ #ifndef SD_FILEUTILS_H #define SD_FILEUTILS_H -#include #include -namespace sd { - class SD_EXPORT FileUtils { - public: - static bool fileExists(const char *filename); +#include - static int64_t fileSize(const char *filename); - }; -} +namespace sd { +class SD_EXPORT FileUtils { + public: + static bool fileExists(const char *filename); + static int64_t fileSize(const char *filename); +}; +} // namespace sd -#endif //SD_FILEUTILS_H +#endif // SD_FILEUTILS_H diff --git a/libnd4j/include/helpers/GradCheck.h b/libnd4j/include/helpers/GradCheck.h index 001b8be495d8..f7abcaee50cc 100644 --- a/libnd4j/include/helpers/GradCheck.h +++ b/libnd4j/include/helpers/GradCheck.h @@ -21,50 +21,54 @@ #ifndef LIBND4J_GRADCHECK_H #define LIBND4J_GRADCHECK_H - #include #include namespace sd { class SD_EXPORT GradCheck { - - public: - enum LossFunc {MEAN = 0, SUM = 1}; - private: - static constexpr double EPSILON = 1e-5; - static constexpr double MAXRELERR = 1e-5; - static constexpr double MINABSERR = 1e-6; - static void fillGradArrays(const LossFunc loss, const std::vector& gradArrs); - - - public: - - /** - * performs numerical check of gradients in back prop - * - * opFF - feed forward operation - * opBP - back propagation operation - * argsHolderFF - argument holder for feed forward operation - * argsHolderBP - argument holder for back propagation operation - * whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty std::vector which means to check all arrays - * IdxRange - specifies indexes range over which array elements will be checked, for example {0.2, 0.7} means range [0.2*array_length, 0.7*array_length), default value is {0., 1.} - * loss - type of scalar loss function, it specifies what elements values will be filled into input gradient arrays automatically, default value is SUM - */ - static bool checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, - const std::vector& whatArrsToCheck = std::vector(), const std::vector& IdxRange = {0., 1.}, const LossFunc loss = SUM); + public: + enum LossFunc { MEAN = 0, SUM = 1 }; + + private: + static constexpr double EPSILON = 1e-5; + static constexpr double MAXRELERR = 1e-5; + static constexpr double MINABSERR = 1e-6; + static void fillGradArrays(const LossFunc loss, + const std::vector& gradArrs); + + public: + /** + * performs numerical check of gradients in back prop + * + * opFF - feed forward operation + * opBP - back propagation operation + * argsHolderFF - argument holder for feed forward operation + * argsHolderBP - argument holder for back propagation operation + * whatArrsToCheck - specifies what output gradient arrays to check, for + * example {0, 1, 0} means that only second output gradient array will be + * checked, default value is empty std::vector which means to check all arrays + * IdxRange - specifies indexes range over which array elements will be + * checked, for example {0.2, 0.7} means range [0.2*array_length, + * 0.7*array_length), default value is {0., 1.} loss - type of scalar loss + * function, it specifies what elements values will be filled into input + * gradient arrays automatically, default value is SUM + */ + static bool checkGrad( + ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, + const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, + const std::vector& whatArrsToCheck = std::vector(), + const std::vector& IdxRange = {0., 1.}, + const LossFunc loss = SUM); }; - - - - // ////////////////////////////////////////////////////////////////////////// // ///// IMLEMENTATION OF INLINE METHODS ///// // ////////////////////////////////////////////////////////////////////////// // template -// FORCEINLINE bool ShapeUtils::isPermutNecessary(const std::vector& permut) { +// FORCEINLINE bool ShapeUtils::isPermutNecessary(const std::vector& +// permut) { // for(int i=0; i #include // #include @@ -33,242 +32,262 @@ namespace sd { - class SD_EXPORT LoopKind { - public: - enum Kind { SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y, BROADCAST_3D, BROADCAST_4D, BROADCAST_5D }; - - static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo); - static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo); - static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo); - static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo); - static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo); - + public: + enum Kind { + SMALLARR2DX, + EWS1, + EWSNONZERO, + RANK1, + RANK2, + RANK3, + RANK4, + RANK5, + X_EWSNONZERO, + Y_EWSNONZERO, + Z_EWSNONZERO, + COMMON, + BROADCAST_SCALAR_X, + BROADCAST_SCALAR_Y, + BROADCAST_3D, + BROADCAST_4D, + BROADCAST_5D + }; + + static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, + const Nd4jLong* zShapeInfo); + static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, + const Nd4jLong* yShapeInfo, + const Nd4jLong* zShapeInfo); + static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo); + static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, + const Nd4jLong* yTadShapeInfo, + const Nd4jLong* zShapeInfo); + static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, + const Nd4jLong* yShapeInfo, + const Nd4jLong* zShapeInfo); }; ////////////////////////////////////////////////////////////////////////////// -LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo) { - - const int xRank = shape::rank(xShapeInfo); - const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo); - const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); - - const char xOrder = shape::order(xShapeInfo); - const char zOrder = shape::order(zShapeInfo); - - int temp; - const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c'; - const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; - const bool shapesSame = shape::shapeEquals(xShapeInfo, zShapeInfo); - - if (xEws == 1 && zEws == 1 && xOrder == zOrder && (shapesSame || xOrder == 'c')) - return EWS1; - if(xEws > 0 && zEws > 0 && ((xOrder == zOrder && (shapesSame || xOrder == 'c')) || (xVectorOrC && zVectorOrC))) - return EWSNONZERO; - if(xRank == 1 && shapesSame) - return RANK1; - if(xRank == 2 && shapesSame) - return RANK2; - if(xRank == 3 && shapesSame) - return RANK3; - if(xRank == 4 && shapesSame) - return RANK4; - if(xRank == 5 && shapesSame) - return RANK5; - if(xEws > 0 && xVectorOrC) - return X_EWSNONZERO; - if(zEws > 0 && zVectorOrC) - return Z_EWSNONZERO; - return COMMON; +LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, + const Nd4jLong* zShapeInfo) { + const int xRank = shape::rank(xShapeInfo); + const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo); + const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); + + const char xOrder = shape::order(xShapeInfo); + const char zOrder = shape::order(zShapeInfo); + + int temp; + const bool xVectorOrC = + shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c'; + const bool zVectorOrC = + shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; + const bool shapesSame = shape::shapeEquals(xShapeInfo, zShapeInfo); + + if (xEws == 1 && zEws == 1 && xOrder == zOrder && + (shapesSame || xOrder == 'c')) + return EWS1; + if (xEws > 0 && zEws > 0 && + ((xOrder == zOrder && (shapesSame || xOrder == 'c')) || + (xVectorOrC && zVectorOrC))) + return EWSNONZERO; + if (xRank == 1 && shapesSame) return RANK1; + if (xRank == 2 && shapesSame) return RANK2; + if (xRank == 3 && shapesSame) return RANK3; + if (xRank == 4 && shapesSame) return RANK4; + if (xRank == 5 && shapesSame) return RANK5; + if (xEws > 0 && xVectorOrC) return X_EWSNONZERO; + if (zEws > 0 && zVectorOrC) return Z_EWSNONZERO; + return COMMON; } -LoopKind::Kind LoopKind::deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) { - auto xRank = shape::rank(xShapeInfo); - auto yRank = shape::rank(yShapeInfo); - auto zRank = shape::rank(zShapeInfo); - - auto xOrder = shape::order(xShapeInfo); - auto yOrder = shape::order(yShapeInfo); - auto zOrder = shape::order(zShapeInfo); - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - bool bNDLoopsRanks = (xRank == zRank && yRank <= xRank && yRank >= 2); - - int countUnityDimsInY = 0, countUnityDimsInX = 0; - for (int i = 0; i < xRank; i++) { - if (i < yRank) - countUnityDimsInY += (1 == shape::sizeAt(yShapeInfo, i)) ? 1 : 0; - countUnityDimsInX += (1 == shape::sizeAt(xShapeInfo, i)) ? 1 : 0; +LoopKind::Kind LoopKind::deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, + const Nd4jLong* yShapeInfo, + const Nd4jLong* zShapeInfo) { + auto xRank = shape::rank(xShapeInfo); + auto yRank = shape::rank(yShapeInfo); + auto zRank = shape::rank(zShapeInfo); + + auto xOrder = shape::order(xShapeInfo); + auto yOrder = shape::order(yShapeInfo); + auto zOrder = shape::order(zShapeInfo); + + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + bool bNDLoopsRanks = (xRank == zRank && yRank <= xRank && yRank >= 2); + + int countUnityDimsInY = 0, countUnityDimsInX = 0; + for (int i = 0; i < xRank; i++) { + if (i < yRank) + countUnityDimsInY += (1 == shape::sizeAt(yShapeInfo, i)) ? 1 : 0; + countUnityDimsInX += (1 == shape::sizeAt(xShapeInfo, i)) ? 1 : 0; + } + + bool bNotCommonVectorCase = + (countUnityDimsInY != yRank - 1) && (countUnityDimsInX != xRank - 1); + + if (bNDLoopsRanks && bNotCommonVectorCase) { + // case x[3,4,5] * y[1,4,5] = z[3,4,5] or reverse x[1,4,5] + y[3,4,5] = + // z[3,4,5] + if (sd::LoopKind::EWS1 == + deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo) && + (1 == shape::sizeAt(yShapeInfo, 0) || + 1 == shape::sizeAt(xShapeInfo, 0))) { + return EWS1; } - bool bNotCommonVectorCase = (countUnityDimsInY != yRank - 1) && (countUnityDimsInX != xRank - 1); - - - if (bNDLoopsRanks && bNotCommonVectorCase) { - // case x[3,4,5] * y[1,4,5] = z[3,4,5] or reverse x[1,4,5] + y[3,4,5] = z[3,4,5] - if (sd::LoopKind::EWS1 == deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo) - && (1 == shape::sizeAt(yShapeInfo, 0) || 1 == shape::sizeAt(xShapeInfo, 0))) { - return EWS1; - } - - if (3 == xRank) - return sd::LoopKind::BROADCAST_3D; - if (4 == xRank) - return sd::LoopKind::BROADCAST_4D; - if (5 == xRank) - return sd::LoopKind::BROADCAST_5D; + if (3 == xRank) return sd::LoopKind::BROADCAST_3D; + if (4 == xRank) return sd::LoopKind::BROADCAST_4D; + if (5 == xRank) return sd::LoopKind::BROADCAST_5D; + } + if (xRank == yRank && xRank == zRank && xOrder == 'c' && yOrder == 'c' && + zOrder == 'c' && xEws == 1 && yEws == 1 && zEws == 1 && xRank >= 2) { + // we validate that shapes are equal till the last dim + for (int e = 0; e < xRank - 1; e++) { + if (xShapeInfo[e + 1] != yShapeInfo[e + 1]) return COMMON; } + // now, if one of the shapes has 1 as last dim + auto detect = + xShapeInfo[xRank] == 1 ? -1 : (yShapeInfo[xRank] == 1) ? 1 : 0; - if (xRank == yRank && xRank == zRank && xOrder == 'c' && yOrder == 'c' && zOrder == 'c' && xEws == 1 && yEws == 1 && zEws == 1 && xRank >= 2) { - // we validate that shapes are equal till the last dim - for (int e = 0; e < xRank - 1; e++) { - if (xShapeInfo[e+1] != yShapeInfo[e+1]) - return COMMON; - } - - // now, if one of the shapes has 1 as last dim - auto detect = xShapeInfo[xRank] == 1 ? -1 : (yShapeInfo[xRank] == 1) ? 1 : 0; + if (detect == 1) + return sd::LoopKind::BROADCAST_SCALAR_Y; + else if (detect == -1) + return sd::LoopKind::BROADCAST_SCALAR_X; + } - if (detect == 1) - return sd::LoopKind::BROADCAST_SCALAR_Y; - else if (detect == -1) - return sd::LoopKind::BROADCAST_SCALAR_X; - } - - return sd::LoopKind::COMMON; + return sd::LoopKind::COMMON; } ////////////////////////////////////////////////////////////////////////////// -LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) { - - const int xRank = shape::rank(xShapeInfo); - const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo); - const Nd4jLong yEws = shape::elementWiseStride(yShapeInfo); - const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); - - const char xOrder = shape::order(xShapeInfo); - const char yOrder = shape::order(yShapeInfo); - const char zOrder = shape::order(zShapeInfo); - - int temp; - const bool xVectorOrC = shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c'; - const bool yVectorOrC = shape::isCommonVector(yShapeInfo, temp) || yOrder == 'c'; - const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; - const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo); - - if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (shapesSame || xOrder == 'c')) - return EWS1; - if(xEws > 0 && yEws > 0 && zEws > 0 && ((xOrder == yOrder && xOrder == zOrder && (shapesSame || xOrder == 'c')) || (xVectorOrC && yVectorOrC && zVectorOrC))) - return EWSNONZERO; - if(xRank == 1 && shapesSame) - return RANK1; - if(xRank == 2 && shapesSame) - return RANK2; - if(xRank == 3 && shapesSame) - return RANK3; - if(xRank == 4 && shapesSame) - return RANK4; - if(xRank == 5 && shapesSame) - return RANK5; - if(xEws > 0 && xVectorOrC) - return X_EWSNONZERO; - if(yEws > 0 && yVectorOrC) - return Y_EWSNONZERO; - if(zEws > 0 && zVectorOrC) - return Z_EWSNONZERO; - return COMMON; +LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, + const Nd4jLong* yShapeInfo, + const Nd4jLong* zShapeInfo) { + const int xRank = shape::rank(xShapeInfo); + const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo); + const Nd4jLong yEws = shape::elementWiseStride(yShapeInfo); + const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); + + const char xOrder = shape::order(xShapeInfo); + const char yOrder = shape::order(yShapeInfo); + const char zOrder = shape::order(zShapeInfo); + + int temp; + const bool xVectorOrC = + shape::isCommonVector(xShapeInfo, temp) || xOrder == 'c'; + const bool yVectorOrC = + shape::isCommonVector(yShapeInfo, temp) || yOrder == 'c'; + const bool zVectorOrC = + shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; + const bool shapesSame = + shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo); + + if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && + xOrder == zOrder && (shapesSame || xOrder == 'c')) + return EWS1; + if (xEws > 0 && yEws > 0 && zEws > 0 && + ((xOrder == yOrder && xOrder == zOrder && + (shapesSame || xOrder == 'c')) || + (xVectorOrC && yVectorOrC && zVectorOrC))) + return EWSNONZERO; + if (xRank == 1 && shapesSame) return RANK1; + if (xRank == 2 && shapesSame) return RANK2; + if (xRank == 3 && shapesSame) return RANK3; + if (xRank == 4 && shapesSame) return RANK4; + if (xRank == 5 && shapesSame) return RANK5; + if (xEws > 0 && xVectorOrC) return X_EWSNONZERO; + if (yEws > 0 && yVectorOrC) return Y_EWSNONZERO; + if (zEws > 0 && zVectorOrC) return Z_EWSNONZERO; + return COMMON; } - ////////////////////////////////////////////////////////////////////////////// -LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo) { - - const int xRank = shape::rank(xShapeInfo); - const int tRank = shape::rank(tadShapeInfo); - - const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo); - const Nd4jLong tEws = shape::elementWiseStride(tadShapeInfo); - const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); - - const char xOrder = shape::order(xShapeInfo); - const char tOrder = shape::order(tadShapeInfo); - const char zOrder = shape::order(zShapeInfo); - - const bool allC = (tOrder == zOrder && zOrder == 'c'); - - int temp; - const bool tVectorOrC = shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c'; - const bool zVectorOrC = shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c';; - - if(shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= Environment::getInstance()->elementwiseThreshold() && xEws == 1 && xOrder == 'c' && xRank == 2 && - tEws > 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) - return SMALLARR2DX; - if(tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) - return EWS1; - if(tEws > 0 && zEws > 0 && (allC || (tVectorOrC && zVectorOrC))) - return EWSNONZERO; - if(tRank == 1 && zEws == 1 && zVectorOrC) - return RANK1; - if(tRank == 2 && zEws == 1 && zVectorOrC) - return RANK2; - if(tRank == 3 && zEws == 1 && zVectorOrC) - return RANK3; - if(tRank == 4 && zEws == 1 && zVectorOrC) - return RANK4; - if(tRank == 5 && zEws == 1 && zVectorOrC) - return RANK5; - if(tEws > 0 && tVectorOrC && zEws == 0) - return X_EWSNONZERO; - if(zEws > 0 && zVectorOrC && tEws == 0) - return Z_EWSNONZERO; - return COMMON; +LoopKind::Kind LoopKind::deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo) { + const int xRank = shape::rank(xShapeInfo); + const int tRank = shape::rank(tadShapeInfo); + + const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo); + const Nd4jLong tEws = shape::elementWiseStride(tadShapeInfo); + const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); + + const char xOrder = shape::order(xShapeInfo); + const char tOrder = shape::order(tadShapeInfo); + const char zOrder = shape::order(zShapeInfo); + + const bool allC = (tOrder == zOrder && zOrder == 'c'); + + int temp; + const bool tVectorOrC = + shape::isCommonVector(tadShapeInfo, temp) || tOrder == 'c'; + const bool zVectorOrC = + shape::isCommonVector(zShapeInfo, temp) || zOrder == 'c'; + ; + + if (shape::length(tadShapeInfo) * shape::length(zShapeInfo) <= + Environment::getInstance()->elementwiseThreshold() && + xEws == 1 && xOrder == 'c' && xRank == 2 && tEws > 1 && zEws == 1 && + (allC || (tVectorOrC && zVectorOrC))) + return SMALLARR2DX; + if (tEws == 1 && zEws == 1 && (allC || (tVectorOrC && zVectorOrC))) + return EWS1; + if (tEws > 0 && zEws > 0 && (allC || (tVectorOrC && zVectorOrC))) + return EWSNONZERO; + if (tRank == 1 && zEws == 1 && zVectorOrC) return RANK1; + if (tRank == 2 && zEws == 1 && zVectorOrC) return RANK2; + if (tRank == 3 && zEws == 1 && zVectorOrC) return RANK3; + if (tRank == 4 && zEws == 1 && zVectorOrC) return RANK4; + if (tRank == 5 && zEws == 1 && zVectorOrC) return RANK5; + if (tEws > 0 && tVectorOrC && zEws == 0) return X_EWSNONZERO; + if (zEws > 0 && zVectorOrC && tEws == 0) return Z_EWSNONZERO; + return COMMON; } ////////////////////////////////////////////////////////////////////////////// -LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo) { - - // both tad shapes are the same, but strides and ews may be different - - const int tadRank = shape::rank(xTadShapeInfo); - - const Nd4jLong xTadEws = shape::elementWiseStride(xTadShapeInfo); - const Nd4jLong yTadEws = shape::elementWiseStride(yTadShapeInfo); - const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); - - const char xTadOrder = shape::order(xTadShapeInfo); - const char yTadOrder = shape::order(xTadShapeInfo); - const char zOrder = shape::order(zShapeInfo); - - int position; - const bool xTadVectorOrC = shape::isCommonVector(xTadShapeInfo, position) || xTadOrder == 'c'; - const bool yTadVectorOrC = shape::isCommonVector(yTadShapeInfo, position) || yTadOrder == 'c'; - const bool zVectorOrC = shape::isCommonVector(zShapeInfo, position) || zOrder == 'c'; - const bool allC = (xTadOrder == yTadOrder && xTadOrder == zOrder && zOrder == 'c'); - - if(xTadEws == 1 && yTadEws == 1 && zEws == 1 && allC) - return EWS1; - if(xTadEws > 0 && yTadEws > 0 && zEws > 0 && (allC || (xTadVectorOrC && yTadVectorOrC && zVectorOrC))) - return EWSNONZERO; - if(tadRank == 1 && zEws > 0 && zVectorOrC) - return RANK1; - if(tadRank == 2 && zEws > 0 && zVectorOrC) - return RANK2; - if(tadRank == 3 && zEws > 0 && zVectorOrC) - return RANK3; - if(tadRank == 4 && zEws > 0 && zVectorOrC) - return RANK4; - if(tadRank == 5 && zEws > 0 && zVectorOrC) - return RANK5; - return COMMON; +LoopKind::Kind LoopKind::deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, + const Nd4jLong* yTadShapeInfo, + const Nd4jLong* zShapeInfo) { + // both tad shapes are the same, but strides and ews may be different + + const int tadRank = shape::rank(xTadShapeInfo); + + const Nd4jLong xTadEws = shape::elementWiseStride(xTadShapeInfo); + const Nd4jLong yTadEws = shape::elementWiseStride(yTadShapeInfo); + const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo); + + const char xTadOrder = shape::order(xTadShapeInfo); + const char yTadOrder = shape::order(xTadShapeInfo); + const char zOrder = shape::order(zShapeInfo); + + int position; + const bool xTadVectorOrC = + shape::isCommonVector(xTadShapeInfo, position) || xTadOrder == 'c'; + const bool yTadVectorOrC = + shape::isCommonVector(yTadShapeInfo, position) || yTadOrder == 'c'; + const bool zVectorOrC = + shape::isCommonVector(zShapeInfo, position) || zOrder == 'c'; + const bool allC = + (xTadOrder == yTadOrder && xTadOrder == zOrder && zOrder == 'c'); + + if (xTadEws == 1 && yTadEws == 1 && zEws == 1 && allC) return EWS1; + if (xTadEws > 0 && yTadEws > 0 && zEws > 0 && + (allC || (xTadVectorOrC && yTadVectorOrC && zVectorOrC))) + return EWSNONZERO; + if (tadRank == 1 && zEws > 0 && zVectorOrC) return RANK1; + if (tadRank == 2 && zEws > 0 && zVectorOrC) return RANK2; + if (tadRank == 3 && zEws > 0 && zVectorOrC) return RANK3; + if (tadRank == 4 && zEws > 0 && zVectorOrC) return RANK4; + if (tadRank == 5 && zEws > 0 && zVectorOrC) return RANK5; + return COMMON; } - - - -} -#endif //LIBND4J_LOOPKIND_H +} // namespace sd +#endif // LIBND4J_LOOPKIND_H diff --git a/libnd4j/include/helpers/Loops.h b/libnd4j/include/helpers/Loops.h index 55f3c27e452e..d51d1ac5236d 100644 --- a/libnd4j/include/helpers/Loops.h +++ b/libnd4j/include/helpers/Loops.h @@ -14,1239 +14,1392 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Yurii Shyrma (iuriish@yahoo.com), created on 14.03.2019 - // +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 14.03.2019 +// #ifndef LIBND4J_LOOPS_H #define LIBND4J_LOOPS_H -#include -#include -#include +#include +#include +#include #include #include -#include -#include +#include #include -#include +#include #include -#include - -namespace sd { - - template - class SD_EXPORT ReductionLoops { - protected: - public: - - template - static FORCEINLINE void loopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, E* extraParams, int64_t start, int64_t stop); - }; - - template - class ReductionFloatLoops : public ReductionLoops { - public: - static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop); - - template - static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop); - }; - - template - class SD_EXPORT ReductionBoolLoops : public ReductionLoops { - public: - static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); - - template - static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); - }; - - template - class SD_EXPORT ReductionLongLoops : public ReductionLoops { - public: - static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); - - template - static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); - }; - - template - class SD_EXPORT ReductionSameLoops : public ReductionLoops { - public: - static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); - - template - static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop); - }; - - - template - class SD_EXPORT IndexReductionLoops { - private: - public: - static void wrapIndexReduce(int opNum, const void* x, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* extraParams); - - template - static void loopIndexReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams); - }; - - - template - class SD_EXPORT TransformLoops { - - public: - - template - static FORCEINLINE void loopTransform(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, E* extraParams, uint64_t threadId, uint64_t numThreads); - }; - - template - class SD_EXPORT Reduction3Loops { - public: - - template - static FORCEINLINE void loopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop); - - template - static FORCEINLINE void loopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop); - - static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop); - - static void wrapperAll(int opNum, const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop); - - template - static void innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop); - - template - static void innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop); - }; - - - - - /* - ////////////////////////////////////////////////////////////////////////////// - template - void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, - const Y* y, const Nd4jLong* yShapeInfo, - Z* z, const Nd4jLong* zShapeInfo, - Z* extraParams, - std::function op) { - - const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); - - const Nd4jLong* xShape = shape::shapeOf(xShapeInfo); - const Nd4jLong* xStride = shape::stride(xShapeInfo); - const Nd4jLong* yStride = shape::stride(yShapeInfo); - const Nd4jLong* zStride = shape::stride(zShapeInfo); - - const Nd4jLong len = shape::length(xShapeInfo); - - OmpLaunchHelper threadsInfo(len); - - switch (kindOfLoop) { - - case LoopKind::EWS1: { - PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads) - { - const auto threadNum = omp_get_thread_num(); - const auto threadOffset = threadsInfo.getThreadOffset(threadNum); - const auto lenPerThread = static_cast(threadsInfo.getItersPerThread(threadNum)); - - const auto xi = x + threadOffset; - const auto yi = y + threadOffset; - auto zi = z + threadOffset; - - PRAGMA_OMP_SIMD - for (uint i = 0; i < lenPerThread; i++) - zi[i] = op(xi[i], yi[i], extraParams); - } - } - break; - - case LoopKind::EWSNONZERO: { - const uint xEws = shape::elementWiseStride(xShapeInfo); - const uint yEws = shape::elementWiseStride(yShapeInfo); - const uint zEws = shape::elementWiseStride(zShapeInfo); - - PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads) - { - const auto threadNum = omp_get_thread_num(); - const auto threadOffset = threadsInfo.getThreadOffset(threadNum); - const auto lenPerThread = static_cast(threadsInfo.getItersPerThread(threadNum)); - const auto xi = x + threadOffset * xEws; - const auto yi = y + threadOffset * yEws; - auto zi = z + threadOffset * zEws; - - PRAGMA_OMP_SIMD - for (uint i = 0; i < lenPerThread; i++) - zi[i*zEws] = op(xi[i*xEws], yi[i*yEws], extraParams); - } - } - break; - - case LoopKind::RANK1: { - PRAGMA_OMP_PARALLEL_FOR - for (uint i0 = 0; i0 < len; ++i0) - z[i0 * zStride[0]] = op(x[i0 * xStride[0]], y[i0 * yStride[0]], extraParams); - } - break; - - case LoopKind::RANK2: { - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (uint i0 = 0; i0 < xShape[0]; ++i0) - for (uint i1 = 0; i1 < xShape[1]; ++i1) - z[i0 * zStride[0] + i1 * zStride[1]] = op(x[i0 * xStride[0] + i1 * xStride[1]], y[i0 * yStride[0] + i1 * yStride[1]], extraParams); - } - break; - - case LoopKind::RANK3: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(2) - for (uint i0 = 0; i0 < xShape[0]; ++i0) - for (uint i1 = 0; i1 < xShape[1]; ++i1) - for (uint i2 = 0; i2 < xShape[2]; ++i2) - z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]] = op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]], y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]], extraParams); - } - break; - - case LoopKind::RANK4: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(3) - for (uint i0 = 0; i0 < xShape[0]; ++i0) - for (uint i1 = 0; i1 < xShape[1]; ++i1) - for (uint i2 = 0; i2 < xShape[2]; ++i2) - for (uint i3 = 0; i3 < xShape[3]; ++i3) - z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]+i3*zStride[3]] = op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]+i3*xStride[3]], y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]+i3*yStride[3]], extraParams); - } - break; - - case LoopKind::RANK5: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(4) - for (uint i0 = 0; i0 < xShape[0]; ++i0) - for (uint i1 = 0; i1 < xShape[1]; ++i1) - for (uint i2 = 0; i2 < xShape[2]; ++i2) - for (uint i3 = 0; i3 < xShape[3]; ++i3) - for (uint i4 = 0; i4 < xShape[4]; ++i4) - z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]+i3*zStride[3]+i4*zStride[4]] = op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]+i3*xStride[3]+i4*xStride[4]], y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]+i3*yStride[3]+i4*yStride[4]], extraParams); - } - break; - - default: { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - - bool canCastX = DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - bool canCastZ = DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads) - { - auto threadNum = omp_get_thread_num(); - auto threadOffset = threadsInfo.getThreadOffset(threadNum); - auto lenPerThread = static_cast(threadsInfo.getItersPerThread(threadNum)); - PRAGMA_OMP_SIMD - for (uint i = 0; i < lenPerThread; i++) { - auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = op(x[xOffset], y[yOffset], extraParams); - } - } - } - } - } - */ - - - - ////////////////////////////////////////////////////////////////////////////// - template - template - void sd::ReductionLoops::loopReduce(const X* x, const Nd4jLong* xShapeInfo, - Z* z, const Nd4jLong* zShapeInfo, - const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, - E* extraParams, - int64_t start, int64_t stop) { - - const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo); - - const Nd4jLong zLen = shape::length(zShapeInfo); - const Nd4jLong tadLen = shape::length(tadShapeInfo); - - const uint tadEws = shape::elementWiseStride(tadShapeInfo); - const uint zEws = shape::elementWiseStride(zShapeInfo); - - const Nd4jLong* tadShape = shape::shapeOf(tadShapeInfo); - const Nd4jLong* tadStride = shape::stride(tadShapeInfo); - - int numThreads = OmpLaunchHelper::tadThreads(tadLen, zLen); - - switch (kindOfLoop) { - - //*********************************************// - // case LoopKind::SMALLARR2DX: { - // shape::printShapeInfoLinear(xShapeInfo); - // shape::printShapeInfoLinear(zShapeInfo); - // const auto xLen = zLen * tadLen; - // for (uint i = 0; i < xLen; ++i) { - // const auto zOffset = shape::subArrayOffset(i, xShapeInfo, zShapeInfo, dimsToExclude, dimsLen); - // const uint tadInd = (i / tadEws) % tadLen; - // auto startVal = tadInd ? z[zOffset] : static_cast(OpType::startingValue(x)); - // z[zOffset] = OpType::update(startVal, OpType::op(x[i], extraParams), extraParams); - // if(tadInd == tadLen - 1) - // z[zOffset] = OpType::postProcess(z[zOffset], tadLen, extraParams); - // printf("%u - %lld\n", i, zOffset); - // } - // } - case LoopKind::SMALLARR2DX: { - const auto uTadLen = static_cast(tadLen); - const auto uZLenMinusOne = static_cast(zLen - 1); - const auto xLen = static_cast(zLen * uTadLen); - const auto sv = static_cast(OpType::startingValue(x)); - - for (uint i = 0; i <= uZLenMinusOne; i++) - z[i] = OpType::startingValue(x); - - uint zOffset = 0; - for (uint i = 0; i < xLen; ++i) { - z[zOffset] = OpType::update(z[zOffset], OpType::op(x[i], extraParams), extraParams); - zOffset = zOffset == uZLenMinusOne ? 0 : zOffset + 1; - } - - for (uint i = 0; i <= uZLenMinusOne; i++) - z[i] = OpType::postProcess(z[i], tadLen, extraParams); - } - break; - - //*********************************************// - case LoopKind::EWS1: { - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) - s = OpType::update(s, OpType::op(tad[j], extraParams), extraParams); - - z[i] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::EWSNONZERO: { - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) - s = OpType::update(s, OpType::op(tad[j * tadEws], extraParams), extraParams); - - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::RANK1: { - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) - s = OpType::update(s, OpType::op(tad[i0 * tadStride[0]], extraParams), extraParams); - - z[i] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::RANK2: { - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) - s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1]], extraParams), extraParams); - - z[i] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::RANK3: { - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) - s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2]], extraParams), extraParams); - - z[i] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::RANK4: { - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) - s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3]], extraParams), extraParams); - - z[i] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::RANK5: { - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) - for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) - s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3] + i4 * tadStride[4]], extraParams), extraParams); - - z[i] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::X_EWSNONZERO: { - uint castZShapeInfo[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); - - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) - s = OpType::update(s, OpType::op(tad[j * tadEws], extraParams), extraParams); - - auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); - z[zOffset] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - case LoopKind::Z_EWSNONZERO: { - uint castTadShapeInfo[MAX_RANK]; - const bool canCastTad = sd::DataTypeUtils::castShapeInfo(tadShapeInfo, castTadShapeInfo); - - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) { - auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad); - s = OpType::update(s, OpType::op(tad[tadOffset], extraParams), extraParams); - } - - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - default: { - auto innertadOffsets = new Nd4jLong[tadLen]; - shape::calcOffsets(tadShapeInfo, innertadOffsets); - - uint castZShapeInfo[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); - - for (auto i = start; i < stop; i++) { - auto tad = x + tadOffsets[i]; - auto s = OpType::startingValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) - s = OpType::update(s, OpType::op(tad[innertadOffsets[j]], extraParams), extraParams); - - auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); - z[zOffset] = OpType::postProcess(s, tadLen, extraParams); - }; - - delete[] innertadOffsets; - } - } - } - - - - ////////////////////////////////////////////////////////////////////////////// - template - template - void sd::TransformLoops::loopTransform(const X* x, const Nd4jLong* xShapeInfo, - Z* z, const Nd4jLong* zShapeInfo, - E* extraParams, - uint64_t threadId, uint64_t numThreads) { - - const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); - - const Nd4jLong* xShape = shape::shapeOf(const_cast(xShapeInfo)); - const Nd4jLong* xStride = shape::stride(const_cast(xShapeInfo)); - const Nd4jLong* zStride = shape::stride(const_cast(zShapeInfo)); +#include - const Nd4jLong len = shape::length(xShapeInfo); +#include - if (len == 0) - return; +namespace sd { - switch (kindOfLoop) { +template +class SD_EXPORT ReductionLoops { + protected: + public: + template + static FORCEINLINE void loopReduce(const X* x, const Nd4jLong* xShapeInfo, + Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, E* extraParams, + int64_t start, int64_t stop); +}; + +template +class ReductionFloatLoops : public ReductionLoops { + public: + static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, + int64_t stop); + + template + static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, Z* extraParams, + int64_t start, int64_t stop); +}; + +template +class SD_EXPORT ReductionBoolLoops : public ReductionLoops { + public: + static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams, int64_t start, + int64_t stop); + + template + static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams, + int64_t start, int64_t stop); +}; + +template +class SD_EXPORT ReductionLongLoops : public ReductionLoops { + public: + static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams, int64_t start, + int64_t stop); + + template + static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams, + int64_t start, int64_t stop); +}; + +template +class SD_EXPORT ReductionSameLoops : public ReductionLoops { + public: + static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, X* z, + const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams, int64_t start, + int64_t stop); + + template + static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, X* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams, + int64_t start, int64_t stop); +}; + +template +class SD_EXPORT IndexReductionLoops { + private: + public: + static void wrapIndexReduce(int opNum, const void* x, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, void* extraParams); + + template + static void loopIndexReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams); +}; + +template +class SD_EXPORT TransformLoops { + public: + template + static FORCEINLINE void loopTransform(const X* x, const Nd4jLong* xShapeInfo, + Z* z, const Nd4jLong* zShapeInfo, + E* extraParams, uint64_t threadId, + uint64_t numThreads); +}; + +template +class SD_EXPORT Reduction3Loops { + public: + template + static FORCEINLINE void loopReduce3(const X* x, const Nd4jLong* xShapeInfo, + const X* y, const Nd4jLong* yShapeInfo, + Z* z, const Nd4jLong* zShapeInfo, + int* dims, int dimsLen, Z* extraParams, + int64_t start, int64_t stop); + + template + static FORCEINLINE void loopReduce3All( + const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, + Z* extraParams, int64_t start, int64_t stop); + + static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, + const X* y, const Nd4jLong* yShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, int* dims, int dimsLen, + Z* extraParams, int64_t start, int64_t stop); + + static void wrapperAll(int opNum, const X* x, const Nd4jLong* xShapeInfo, + const X* y, const Nd4jLong* yShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, + const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, + const Nd4jLong* yTadOffsets, Z* extraParams, + int64_t start, int64_t stop); + + template + static void innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, + const X* y, const Nd4jLong* yShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, int* dims, + int dimsLen, Z* extraParams, int64_t start, + int64_t stop); + + template + static void innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, + const X* y, const Nd4jLong* yShapeInfo, Z* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, + const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, + const Nd4jLong* yTadOffsets, Z* extraParams, + int64_t start, int64_t stop); +}; + +/* +////////////////////////////////////////////////////////////////////////////// +template +void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, + const Y* y, const Nd4jLong* yShapeInfo, + Z* z, const Nd4jLong* zShapeInfo, + Z* extraParams, + std::function op) { + + const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopXYZ(xShapeInfo, +yShapeInfo, zShapeInfo); + + const Nd4jLong* xShape = shape::shapeOf(xShapeInfo); + const Nd4jLong* xStride = shape::stride(xShapeInfo); + const Nd4jLong* yStride = shape::stride(yShapeInfo); + const Nd4jLong* zStride = shape::stride(zShapeInfo); + + const Nd4jLong len = shape::length(xShapeInfo); + + OmpLaunchHelper threadsInfo(len); + + switch (kindOfLoop) { - //*********************************************// case LoopKind::EWS1: { - auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); - int64_t start = span.startX(), stop = span.stopX(); - - for (auto i = start; i < stop; i++) - z[i] = OpType::op(x[i], extraParams); + PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads) + { + const auto threadNum = omp_get_thread_num(); + const auto threadOffset = +threadsInfo.getThreadOffset(threadNum); const auto lenPerThread = +static_cast(threadsInfo.getItersPerThread(threadNum)); + + const auto xi = x + threadOffset; + const auto yi = y + threadOffset; + auto zi = z + threadOffset; + + PRAGMA_OMP_SIMD + for (uint i = 0; i < lenPerThread; i++) + zi[i] = op(xi[i], yi[i], extraParams); + } } - break; + break; - //*********************************************// case LoopKind::EWSNONZERO: { const uint xEws = shape::elementWiseStride(xShapeInfo); + const uint yEws = shape::elementWiseStride(yShapeInfo); const uint zEws = shape::elementWiseStride(zShapeInfo); - auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); - int64_t start = span.startX(), stop = span.stopX(); - - for (auto i = start; i < stop; i++) - z[i * zEws] = OpType::op(x[i * xEws], extraParams); - } - break; - - //*********************************************// - case LoopKind::Z_EWSNONZERO: { - const uint zEws = shape::elementWiseStride(zShapeInfo); - uint castXShapeInfo[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, castXShapeInfo); - - auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); - int64_t start = span.startX(), stop = span.stopX(); - - if (zEws > 1) { - for (auto i = start; i < stop; i++) { - const auto xOffset = shape::indexOffset(i, xShapeInfo, castXShapeInfo, canCastX); - z[i * zEws] = OpType::op(x[xOffset], extraParams); - } - } - else { - for (auto i = start; i < stop; i++) { - const auto xOffset = shape::indexOffset(i, xShapeInfo, castXShapeInfo, canCastX); - z[i] = OpType::op(x[xOffset], extraParams); - } + PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads) + { + const auto threadNum = omp_get_thread_num(); + const auto threadOffset = +threadsInfo.getThreadOffset(threadNum); const auto lenPerThread = +static_cast(threadsInfo.getItersPerThread(threadNum)); const auto xi = x + +threadOffset * xEws; const auto yi = y + threadOffset * yEws; auto zi = z + +threadOffset * zEws; + + PRAGMA_OMP_SIMD + for (uint i = 0; i < lenPerThread; i++) + zi[i*zEws] = op(xi[i*xEws], yi[i*yEws], extraParams); } } - break; + break; - //*********************************************// case LoopKind::RANK1: { - auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); - - for (auto i0 = span.startX(); i0 < span.stopX(); i0++) - z[i0 * zStride[0]] = OpType::op(x[i0 * xStride[0]], extraParams); + PRAGMA_OMP_PARALLEL_FOR + for (uint i0 = 0; i0 < len; ++i0) + z[i0 * zStride[0]] = op(x[i0 * xStride[0]], y[i0 * yStride[0]], +extraParams); } - break; + break; - //*********************************************// case LoopKind::RANK2: { - auto uXShape0 = static_cast(xShape[0]); - auto uXShape1 = static_cast(xShape[1]); - - auto loop = samediff::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1); - auto span = samediff::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1); - - for (auto i0 = span.startX(); i0 < span.stopX(); i0++) { - auto z0 = i0 * zStride[0]; - auto x0 = i0 * xStride[0]; - - for (auto i1 = span.startY(); i1 < span.stopY(); ++i1) - z[z0 + i1 * zStride[1]] = OpType::op(x[x0 + i1 * xStride[1]], extraParams); - } + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (uint i0 = 0; i0 < xShape[0]; ++i0) + for (uint i1 = 0; i1 < xShape[1]; ++i1) + z[i0 * zStride[0] + i1 * zStride[1]] = op(x[i0 * xStride[0] ++ i1 * xStride[1]], y[i0 * yStride[0] + i1 * yStride[1]], extraParams); } - break; + break; - //*********************************************// case LoopKind::RANK3: { - auto uXShape0 = xShape[0]; - auto uXShape1 = xShape[1]; - auto uXShape2 = xShape[2]; - - auto loop = samediff::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1); - auto span = samediff::Span2::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1); - - - for (auto i0 = span.startX(); i0 < span.stopX(); i0++) - for (auto i1 = span.startY(); i1 < span.stopY(); i1++) { - auto z0 = i0 * zStride[0] + i1 * zStride[1]; - auto x0 = i0 * xStride[0] + i1 * xStride[1]; - - for (Nd4jLong i2 = 0; i2 < uXShape2; ++i2) - z[z0 + i2 * zStride[2]] = OpType::op(x[x0 + i2 * xStride[2]], extraParams); - } + PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(2) + for (uint i0 = 0; i0 < xShape[0]; ++i0) + for (uint i1 = 0; i1 < xShape[1]; ++i1) + for (uint i2 = 0; i2 < xShape[2]; ++i2) + z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]] = +op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]], +y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]], extraParams); } - break; + break; - //*********************************************// case LoopKind::RANK4: { - auto uXShape0 = xShape[0]; - auto uXShape1 = xShape[1]; - auto uXShape2 = xShape[2]; - auto uXShape3 = xShape[3]; - - auto loop = samediff::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2); - auto span = samediff::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1); - - for (auto i0 = span.startX(); i0 < span.stopX(); i0++) - for (auto i1 = span.startY(); i1 < span.stopY(); i1++) - for (auto i2 = span.startZ(); i2 < span.stopZ(); i2++) { - auto x0 = i0 * xStride[0] + i1 * xStride[1] + i2 * xStride[2]; - auto z0 = i0 * zStride[0] + i1 * zStride[1] + i2 * zStride[2]; - - for (Nd4jLong i3 = 0; i3 < uXShape3; ++i3) - z[z0 + i3 * zStride[3]] = OpType::op(x[x0 + i3 * xStride[3]], extraParams); - } + PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(3) + for (uint i0 = 0; i0 < xShape[0]; ++i0) + for (uint i1 = 0; i1 < xShape[1]; ++i1) + for (uint i2 = 0; i2 < xShape[2]; ++i2) + for (uint i3 = 0; i3 < xShape[3]; ++i3) + z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]+i3*zStride[3]] += op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]+i3*xStride[3]], +y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]+i3*yStride[3]], extraParams); } - break; + break; - //*********************************************// case LoopKind::RANK5: { - auto uXShape0 = xShape[0]; - auto uXShape1 = xShape[1]; - auto uXShape2 = xShape[2]; - auto uXShape3 = xShape[3]; - auto uXShape4 = xShape[4]; - - auto loop = samediff::ThreadsHelper::pickLoop3d(numThreads, uXShape0, uXShape1, uXShape2); - auto span = samediff::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, uXShape1, 1, 0, uXShape2, 1); - - - for (auto i0 = span.startX(); i0 < span.stopX(); i0++) - for (auto i1 = span.startY(); i1 < span.stopY(); i1++) - for (auto i2 = span.startZ(); i2 < span.stopZ(); i2++) { - auto z0 = i0 * zStride[0] + i1 * zStride[1] + i2 * zStride[2]; - auto x0 = i0 * xStride[0] + i1 * xStride[1] + i2 * xStride[2]; - - for (Nd4jLong i3 = 0; i3 < uXShape3; ++i3) { - - auto z1 = z0 + i3 * zStride[3]; - auto x1 = x0 + i3 * xStride[3]; - - for (Nd4jLong i4 = 0; i4 < uXShape4; ++i4) - z[z1 + i4 * zStride[4]] = OpType::op(x[x1 + i4 * xStride[4]], extraParams); - - } - } - + PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(4) + for (uint i0 = 0; i0 < xShape[0]; ++i0) + for (uint i1 = 0; i1 < xShape[1]; ++i1) + for (uint i2 = 0; i2 < xShape[2]; ++i2) + for (uint i3 = 0; i3 < xShape[3]; ++i3) + for (uint i4 = 0; i4 < xShape[4]; ++i4) + z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]+i3*zStride[3]+i4*zStride[4]] += op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]+i3*xStride[3]+i4*xStride[4]], +y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]+i3*yStride[3]+i4*yStride[4]], +extraParams); } - break; + break; - //*********************************************// default: { uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK]; - bool canCastX = DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastZ = DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); - - for (auto i = span.startX(); i < span.stopX(); i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], extraParams); + bool canCastX = DataTypeUtils::castShapeInfo(xShapeInfo, +xShapeInfoCast); bool canCastY = DataTypeUtils::castShapeInfo(yShapeInfo, +yShapeInfoCast); bool canCastZ = DataTypeUtils::castShapeInfo(zShapeInfo, +zShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = threadsInfo.getThreadOffset(threadNum); + auto lenPerThread = +static_cast(threadsInfo.getItersPerThread(threadNum)); PRAGMA_OMP_SIMD for +(uint i = 0; i < lenPerThread; i++) { auto xOffset = shape::indexOffset(i + +threadOffset, xShapeInfo, xShapeInfoCast, canCastX); auto yOffset = +shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY); auto +zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, +canCastZ); z[zOffset] = op(x[xOffset], y[yOffset], extraParams); + } } } - - } } - - - ////////////////////////////////////////////////////////////////////////////// - template - template - void sd::Reduction3Loops::loopReduce3(const X* x, const Nd4jLong* xShapeInfo, - const X* y, const Nd4jLong* yShapeInfo, - Z* z, const Nd4jLong* zShapeInfo, - int* dims, int dimsLen, - Z* extraParameters, int64_t start, int64_t stop) { - - // both tads have same shape, however strides and ews may differ - - Z param0(OpType::startingValue(x)), param1(OpType::startingValue(x)), param2(extraParameters ? extraParameters[0] : OpType::startingValue(x)); - - const Nd4jLong xLen = shape::length(xShapeInfo); - const Nd4jLong yLen = shape::length(yShapeInfo); - - const Nd4jLong* xTadShapeInfo = nullptr, * yTadShapeInfo = nullptr, * xTadOffsets = nullptr, * yTadOffsets = nullptr; - TadPack tadPackX, tadPackY; - std::vector zeroOffsets; - - if (xLen == yLen) { - tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dims, dimsLen); - tadPackY = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dims, dimsLen); - xTadShapeInfo = tadPackX.primaryShapeInfo(); - yTadShapeInfo = tadPackY.primaryShapeInfo(); - xTadOffsets = tadPackX.primaryOffsets(); - yTadOffsets = tadPackY.primaryOffsets(); - } - else if (yLen > xLen) { - tadPackY = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dims, dimsLen); - xTadShapeInfo = xShapeInfo; - yTadShapeInfo = tadPackY.primaryShapeInfo(); - yTadOffsets = tadPackY.primaryOffsets(); - } - else { - tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dims, dimsLen); - yTadShapeInfo = yShapeInfo; - xTadShapeInfo = tadPackX.primaryShapeInfo(); - xTadOffsets = tadPackX.primaryOffsets(); +} +*/ + +////////////////////////////////////////////////////////////////////////////// +template +template +void sd::ReductionLoops::loopReduce( + const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, E* extraParams, + int64_t start, int64_t stop) { + const LoopKind::Kind kindOfLoop = + LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo); + + const Nd4jLong zLen = shape::length(zShapeInfo); + const Nd4jLong tadLen = shape::length(tadShapeInfo); + + const uint tadEws = shape::elementWiseStride(tadShapeInfo); + const uint zEws = shape::elementWiseStride(zShapeInfo); + + const Nd4jLong* tadShape = shape::shapeOf(tadShapeInfo); + const Nd4jLong* tadStride = shape::stride(tadShapeInfo); + + int numThreads = OmpLaunchHelper::tadThreads(tadLen, zLen); + + switch (kindOfLoop) { + //*********************************************// + // case LoopKind::SMALLARR2DX: { + // shape::printShapeInfoLinear(xShapeInfo); + // shape::printShapeInfoLinear(zShapeInfo); + // const auto xLen = zLen * tadLen; + // for (uint i = 0; i < xLen; ++i) { + // const auto zOffset = shape::subArrayOffset(i, xShapeInfo, + // zShapeInfo, dimsToExclude, dimsLen); const uint tadInd = (i / + // tadEws) % tadLen; auto startVal = tadInd ? z[zOffset] : + // static_cast(OpType::startingValue(x)); z[zOffset] = + // OpType::update(startVal, OpType::op(x[i], extraParams), + // extraParams); if(tadInd == tadLen - 1) + // z[zOffset] = OpType::postProcess(z[zOffset], tadLen, + // extraParams); + // printf("%u - %lld\n", i, zOffset); + // } + // } + case LoopKind::SMALLARR2DX: { + const auto uTadLen = static_cast(tadLen); + const auto uZLenMinusOne = static_cast(zLen - 1); + const auto xLen = static_cast(zLen * uTadLen); + const auto sv = static_cast(OpType::startingValue(x)); + + for (uint i = 0; i <= uZLenMinusOne; i++) z[i] = OpType::startingValue(x); + + uint zOffset = 0; + for (uint i = 0; i < xLen; ++i) { + z[zOffset] = OpType::update(z[zOffset], OpType::op(x[i], extraParams), + extraParams); + zOffset = zOffset == uZLenMinusOne ? 0 : zOffset + 1; + } + + for (uint i = 0; i <= uZLenMinusOne; i++) + z[i] = OpType::postProcess(z[i], tadLen, extraParams); + } break; + + //*********************************************// + case LoopKind::EWS1: { + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) + s = OpType::update(s, OpType::op(tad[j], extraParams), extraParams); + + z[i] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::EWSNONZERO: { + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) + s = OpType::update(s, OpType::op(tad[j * tadEws], extraParams), + extraParams); + + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK1: { + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) + s = OpType::update(s, OpType::op(tad[i0 * tadStride[0]], extraParams), + extraParams); + + z[i] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK2: { + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) + s = OpType::update( + s, + OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1]], + extraParams), + extraParams); + + z[i] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK3: { + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) + s = OpType::update( + s, + OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + + i2 * tadStride[2]], + extraParams), + extraParams); + + z[i] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK4: { + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) + s = OpType::update( + s, + OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + + i2 * tadStride[2] + i3 * tadStride[3]], + extraParams), + extraParams); + + z[i] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK5: { + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) + for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) + s = OpType::update( + s, + OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + + i2 * tadStride[2] + i3 * tadStride[3] + + i4 * tadStride[4]], + extraParams), + extraParams); + + z[i] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::X_EWSNONZERO: { + uint castZShapeInfo[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); + + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) + s = OpType::update(s, OpType::op(tad[j * tadEws], extraParams), + extraParams); + + auto zOffset = + shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); + z[zOffset] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::Z_EWSNONZERO: { + uint castTadShapeInfo[MAX_RANK]; + const bool canCastTad = sd::DataTypeUtils::castShapeInfo( + tadShapeInfo, castTadShapeInfo); + + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) { + auto tadOffset = + shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad); + s = OpType::update(s, OpType::op(tad[tadOffset], extraParams), + extraParams); } + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; - const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopTadXYZ(xTadShapeInfo, yTadShapeInfo, zShapeInfo); - - const auto xTadEws = shape::elementWiseStride(xTadShapeInfo); - const auto yTadEws = shape::elementWiseStride(yTadShapeInfo); - const auto zEws = shape::elementWiseStride(zShapeInfo); + //*********************************************// + default: { + auto innertadOffsets = new Nd4jLong[tadLen]; + shape::calcOffsets(tadShapeInfo, innertadOffsets); - const auto zLen = shape::length(zShapeInfo); - const auto tadLen = shape::length(xTadShapeInfo); + uint castZShapeInfo[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); - const auto tadShape = shape::shapeOf(xTadShapeInfo); - const auto xTadStride = shape::stride(xTadShapeInfo); - const auto yTadStride = shape::stride(xTadShapeInfo); + for (auto i = start; i < stop; i++) { + auto tad = x + tadOffsets[i]; + auto s = OpType::startingValue(tad); - int numThreads = OmpLaunchHelper::tadThreads(tadLen, zLen); + for (Nd4jLong j = 0; j < tadLen; j++) + s = OpType::update( + s, OpType::op(tad[innertadOffsets[j]], extraParams), extraParams); - switch (kindOfLoop) { + auto zOffset = + shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); + z[zOffset] = OpType::postProcess(s, tadLen, extraParams); + }; - //*********************************************// - case LoopKind::EWS1: { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong j = 0; j < tadLen; ++j) - s = OpType::update(s, OpType::op(xTad[j], yTad[j], extraParams), extraParams); + delete[] innertadOffsets; + } + } +} - z[i] = OpType::postProcess(s, tadLen, extraParams); - }; +////////////////////////////////////////////////////////////////////////////// +template +template +void sd::TransformLoops::loopTransform( + const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + E* extraParams, uint64_t threadId, uint64_t numThreads) { + const LoopKind::Kind kindOfLoop = + LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); + + const Nd4jLong* xShape = shape::shapeOf(const_cast(xShapeInfo)); + const Nd4jLong* xStride = shape::stride(const_cast(xShapeInfo)); + const Nd4jLong* zStride = shape::stride(const_cast(zShapeInfo)); + + const Nd4jLong len = shape::length(xShapeInfo); + + if (len == 0) return; + + switch (kindOfLoop) { + //*********************************************// + case LoopKind::EWS1: { + auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); + int64_t start = span.startX(), stop = span.stopX(); + + for (auto i = start; i < stop; i++) z[i] = OpType::op(x[i], extraParams); + } break; + + //*********************************************// + case LoopKind::EWSNONZERO: { + const uint xEws = shape::elementWiseStride(xShapeInfo); + const uint zEws = shape::elementWiseStride(zShapeInfo); + + auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); + int64_t start = span.startX(), stop = span.stopX(); + + for (auto i = start; i < stop; i++) + z[i * zEws] = OpType::op(x[i * xEws], extraParams); + } break; + + //*********************************************// + case LoopKind::Z_EWSNONZERO: { + const uint zEws = shape::elementWiseStride(zShapeInfo); + uint castXShapeInfo[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, castXShapeInfo); + + auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); + int64_t start = span.startX(), stop = span.stopX(); + + if (zEws > 1) { + for (auto i = start; i < stop; i++) { + const auto xOffset = + shape::indexOffset(i, xShapeInfo, castXShapeInfo, canCastX); + z[i * zEws] = OpType::op(x[xOffset], extraParams); + } + } else { + for (auto i = start; i < stop; i++) { + const auto xOffset = + shape::indexOffset(i, xShapeInfo, castXShapeInfo, canCastX); + z[i] = OpType::op(x[xOffset], extraParams); } - break; + } + } break; + + //*********************************************// + case LoopKind::RANK1: { + auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); + + for (auto i0 = span.startX(); i0 < span.stopX(); i0++) + z[i0 * zStride[0]] = OpType::op(x[i0 * xStride[0]], extraParams); + } break; + + //*********************************************// + case LoopKind::RANK2: { + auto uXShape0 = static_cast(xShape[0]); + auto uXShape1 = static_cast(xShape[1]); + + auto loop = + samediff::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1); + auto span = samediff::Span2::build(loop, threadId, numThreads, 0, + uXShape0, 1, 0, uXShape1, 1); + + for (auto i0 = span.startX(); i0 < span.stopX(); i0++) { + auto z0 = i0 * zStride[0]; + auto x0 = i0 * xStride[0]; + + for (auto i1 = span.startY(); i1 < span.stopY(); ++i1) + z[z0 + i1 * zStride[1]] = + OpType::op(x[x0 + i1 * xStride[1]], extraParams); + } + } break; + + //*********************************************// + case LoopKind::RANK3: { + auto uXShape0 = xShape[0]; + auto uXShape1 = xShape[1]; + auto uXShape2 = xShape[2]; + + auto loop = + samediff::ThreadsHelper::pickLoop2d(numThreads, uXShape0, uXShape1); + auto span = samediff::Span2::build(loop, threadId, numThreads, 0, + uXShape0, 1, 0, uXShape1, 1); + + for (auto i0 = span.startX(); i0 < span.stopX(); i0++) + for (auto i1 = span.startY(); i1 < span.stopY(); i1++) { + auto z0 = i0 * zStride[0] + i1 * zStride[1]; + auto x0 = i0 * xStride[0] + i1 * xStride[1]; + + for (Nd4jLong i2 = 0; i2 < uXShape2; ++i2) + z[z0 + i2 * zStride[2]] = + OpType::op(x[x0 + i2 * xStride[2]], extraParams); + } + } break; + + //*********************************************// + case LoopKind::RANK4: { + auto uXShape0 = xShape[0]; + auto uXShape1 = xShape[1]; + auto uXShape2 = xShape[2]; + auto uXShape3 = xShape[3]; + + auto loop = samediff::ThreadsHelper::pickLoop3d(numThreads, uXShape0, + uXShape1, uXShape2); + auto span = + samediff::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, + uXShape1, 1, 0, uXShape2, 1); + + for (auto i0 = span.startX(); i0 < span.stopX(); i0++) + for (auto i1 = span.startY(); i1 < span.stopY(); i1++) + for (auto i2 = span.startZ(); i2 < span.stopZ(); i2++) { + auto x0 = i0 * xStride[0] + i1 * xStride[1] + i2 * xStride[2]; + auto z0 = i0 * zStride[0] + i1 * zStride[1] + i2 * zStride[2]; + + for (Nd4jLong i3 = 0; i3 < uXShape3; ++i3) + z[z0 + i3 * zStride[3]] = + OpType::op(x[x0 + i3 * xStride[3]], extraParams); + } + } break; + + //*********************************************// + case LoopKind::RANK5: { + auto uXShape0 = xShape[0]; + auto uXShape1 = xShape[1]; + auto uXShape2 = xShape[2]; + auto uXShape3 = xShape[3]; + auto uXShape4 = xShape[4]; + + auto loop = samediff::ThreadsHelper::pickLoop3d(numThreads, uXShape0, + uXShape1, uXShape2); + auto span = + samediff::Span3::build(loop, threadId, numThreads, 0, uXShape0, 1, 0, + uXShape1, 1, 0, uXShape2, 1); + + for (auto i0 = span.startX(); i0 < span.stopX(); i0++) + for (auto i1 = span.startY(); i1 < span.stopY(); i1++) + for (auto i2 = span.startZ(); i2 < span.stopZ(); i2++) { + auto z0 = i0 * zStride[0] + i1 * zStride[1] + i2 * zStride[2]; + auto x0 = i0 * xStride[0] + i1 * xStride[1] + i2 * xStride[2]; + + for (Nd4jLong i3 = 0; i3 < uXShape3; ++i3) { + auto z1 = z0 + i3 * zStride[3]; + auto x1 = x0 + i3 * xStride[3]; + + for (Nd4jLong i4 = 0; i4 < uXShape4; ++i4) + z[z1 + i4 * zStride[4]] = + OpType::op(x[x1 + i4 * xStride[4]], extraParams); + } + } - //*********************************************// - case LoopKind::EWSNONZERO: { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; + } break; - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); + //*********************************************// + default: { + uint xShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; - for (Nd4jLong j = 0; j < tadLen; ++j) - s = OpType::update(s, OpType::op(xTad[j * xTadEws], yTad[j * yTadEws], extraParams), extraParams); + bool canCastX = DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastZ = DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; + auto span = samediff::Span::build(threadId, numThreads, 0, len, 1); - //*********************************************// - case LoopKind::RANK1: { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) { - const auto xTadOffset = i0 * xTadStride[0]; - const auto yTadOffset = i0 * yTadStride[0]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } + for (auto i = span.startX(); i < span.stopX(); i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], extraParams); + } + } + } +} - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; +////////////////////////////////////////////////////////////////////////////// +template +template +void sd::Reduction3Loops::loopReduce3( + const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, + int dimsLen, Z* extraParameters, int64_t start, int64_t stop) { + // both tads have same shape, however strides and ews may differ + + Z param0(OpType::startingValue(x)), param1(OpType::startingValue(x)), + param2(extraParameters ? extraParameters[0] : OpType::startingValue(x)); + + const Nd4jLong xLen = shape::length(xShapeInfo); + const Nd4jLong yLen = shape::length(yShapeInfo); + + const Nd4jLong *xTadShapeInfo = nullptr, *yTadShapeInfo = nullptr, + *xTadOffsets = nullptr, *yTadOffsets = nullptr; + TadPack tadPackX, tadPackY; + std::vector zeroOffsets; + + if (xLen == yLen) { + tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dims, dimsLen); + tadPackY = sd::ConstantTadHelper::getInstance()->tadForDimensions( + yShapeInfo, dims, dimsLen); + xTadShapeInfo = tadPackX.primaryShapeInfo(); + yTadShapeInfo = tadPackY.primaryShapeInfo(); + xTadOffsets = tadPackX.primaryOffsets(); + yTadOffsets = tadPackY.primaryOffsets(); + } else if (yLen > xLen) { + tadPackY = sd::ConstantTadHelper::getInstance()->tadForDimensions( + yShapeInfo, dims, dimsLen); + xTadShapeInfo = xShapeInfo; + yTadShapeInfo = tadPackY.primaryShapeInfo(); + yTadOffsets = tadPackY.primaryOffsets(); + } else { + tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dims, dimsLen); + yTadShapeInfo = yShapeInfo; + xTadShapeInfo = tadPackX.primaryShapeInfo(); + xTadOffsets = tadPackX.primaryOffsets(); + } + + const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopTadXYZ( + xTadShapeInfo, yTadShapeInfo, zShapeInfo); + + const auto xTadEws = shape::elementWiseStride(xTadShapeInfo); + const auto yTadEws = shape::elementWiseStride(yTadShapeInfo); + const auto zEws = shape::elementWiseStride(zShapeInfo); + + const auto zLen = shape::length(zShapeInfo); + const auto tadLen = shape::length(xTadShapeInfo); + + const auto tadShape = shape::shapeOf(xTadShapeInfo); + const auto xTadStride = shape::stride(xTadShapeInfo); + const auto yTadStride = shape::stride(xTadShapeInfo); + + int numThreads = OmpLaunchHelper::tadThreads(tadLen, zLen); + + switch (kindOfLoop) { + //*********************************************// + case LoopKind::EWS1: { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong j = 0; j < tadLen; ++j) + s = OpType::update(s, OpType::op(xTad[j], yTad[j], extraParams), + extraParams); + + z[i] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::EWSNONZERO: { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong j = 0; j < tadLen; ++j) + s = OpType::update( + s, OpType::op(xTad[j * xTadEws], yTad[j * yTadEws], extraParams), + extraParams); + + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK1: { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) { + const auto xTadOffset = i0 * xTadStride[0]; + const auto yTadOffset = i0 * yTadStride[0]; + s = OpType::update( + s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); } - break; - //*********************************************// - case LoopKind::RANK2: { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK2: { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1]; + const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1]; + s = OpType::update( + s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); + } } - break; - - //*********************************************// - case LoopKind::RANK3: { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1] + i2 * xTadStride[2]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1] + i2 * yTadStride[2]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - } - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK3: { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + const auto xTadOffset = + i0 * xTadStride[0] + i1 * xTadStride[1] + i2 * xTadStride[2]; + const auto yTadOffset = + i0 * yTadStride[0] + i1 * yTadStride[1] + i2 * yTadStride[2]; + s = OpType::update( + s, + OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); + } + } } - break; - - //*********************************************// - case LoopKind::RANK4: { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1] + i2 * xTadStride[2] + i3 * xTadStride[3]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1] + i2 * yTadStride[2] + i3 * yTadStride[3]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - } - } - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK4: { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { + const auto xTadOffset = i0 * xTadStride[0] + + i1 * xTadStride[1] + + i2 * xTadStride[2] + i3 * xTadStride[3]; + const auto yTadOffset = i0 * yTadStride[0] + + i1 * yTadStride[1] + + i2 * yTadStride[2] + i3 * yTadStride[3]; + s = OpType::update( + s, + OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); + } + } + } } - break; - - //*********************************************// - case LoopKind::RANK5: { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { - for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1] + i2 * xTadStride[2] + i3 * xTadStride[3] + i4 * xTadStride[4]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1] + i2 * yTadStride[2] + i3 * yTadStride[3] + i4 * yTadStride[4]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - } - } + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK5: { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { + for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) { + const auto xTadOffset = + i0 * xTadStride[0] + i1 * xTadStride[1] + + i2 * xTadStride[2] + i3 * xTadStride[3] + + i4 * xTadStride[4]; + const auto yTadOffset = + i0 * yTadStride[0] + i1 * yTadStride[1] + + i2 * yTadStride[2] + i3 * yTadStride[3] + + i4 * yTadStride[4]; + s = OpType::update(s, + OpType::op(xTad[xTadOffset], + yTad[yTadOffset], extraParams), + extraParams); } - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; - } - break; - - //*********************************************// - default: { - uint castXTadShapeInfo[MAX_RANK]; - const bool canCastXTad = sd::DataTypeUtils::castShapeInfo(xTadShapeInfo, castXTadShapeInfo); - - if (shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) { - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong j = 0; j < tadLen; ++j) { - const auto tadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); - s = OpType::update(s, OpType::op(xTad[tadOffset], yTad[tadOffset], extraParams), extraParams); - } - - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; + } } - else { - uint castYTadShapeInfo[MAX_RANK]; - const bool canCastYTad = sd::DataTypeUtils::castShapeInfo(yTadShapeInfo, castYTadShapeInfo); - - Z extraParams[3]; - for (auto i = start; i < stop; i++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; - const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; - auto s = OpType::startingValue(xTad); - - for (Nd4jLong j = 0; j < tadLen; ++j) { - const auto xTadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); - const auto yTadOffset = shape::indexOffset(j, yTadShapeInfo, castYTadShapeInfo, canCastYTad); - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); - }; - } - } + } } + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + default: { + uint castXTadShapeInfo[MAX_RANK]; + const bool canCastXTad = sd::DataTypeUtils::castShapeInfo( + xTadShapeInfo, castXTadShapeInfo); + + if (shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) { + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong j = 0; j < tadLen; ++j) { + const auto tadOffset = shape::indexOffset( + j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); + s = OpType::update( + s, OpType::op(xTad[tadOffset], yTad[tadOffset], extraParams), + extraParams); + } + + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } else { + uint castYTadShapeInfo[MAX_RANK]; + const bool canCastYTad = sd::DataTypeUtils::castShapeInfo( + yTadShapeInfo, castYTadShapeInfo); + + Z extraParams[3]; + for (auto i = start; i < stop; i++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = xTadOffsets ? x + xTadOffsets[i] : x; + const auto yTad = yTadOffsets ? y + yTadOffsets[i] : y; + auto s = OpType::startingValue(xTad); + + for (Nd4jLong j = 0; j < tadLen; ++j) { + const auto xTadOffset = shape::indexOffset( + j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); + const auto yTadOffset = shape::indexOffset( + j, yTadShapeInfo, castYTadShapeInfo, canCastYTad); + s = OpType::update( + s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); + } + z[i * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } } + } +} - ////////////////////////////////////////////////////////////////////////////// - template - template - void sd::Reduction3Loops::loopReduce3All(const X* x, const Nd4jLong* xShapeInfo, - const X* y, const Nd4jLong* yShapeInfo, - Z* z, const Nd4jLong* zShapeInfo, - const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, - const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, - Z* extraParameters, - int64_t start, int64_t stop) { - - // both tads have same shape, however strides and ews may differ - - Z param0(OpType::startingValue(x)), param1(OpType::startingValue(x)), param2(extraParameters ? extraParameters[0] : OpType::startingValue(x)); - - const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopTadXYZ(xTadShapeInfo, yTadShapeInfo, zShapeInfo); - - const auto xTadEws = shape::elementWiseStride(xTadShapeInfo); - const auto yTadEws = shape::elementWiseStride(yTadShapeInfo); - const auto zEws = shape::elementWiseStride(zShapeInfo); - - const auto zLen = shape::length(zShapeInfo); - const auto tadLen = shape::length(xTadShapeInfo); - - const auto numXTads = shape::length(xShapeInfo) / tadLen; - const auto numYTads = shape::length(yShapeInfo) / tadLen; - - const auto tadShape = shape::shapeOf(xTadShapeInfo); - const auto xTadStride = shape::stride(xTadShapeInfo); - const auto yTadStride = shape::stride(yTadShapeInfo); - - const auto startVal = OpType::startingValue(x); - - int numThreads = OmpLaunchHelper::tadThreads(tadLen, numXTads * numYTads); - - switch (kindOfLoop) { - //*********************************************// - case LoopKind::EWS1: { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong j = 0; j < tadLen; ++j) - s = OpType::update(s, OpType::op(xTad[j], yTad[j], extraParams), extraParams); - - z[zInd] = OpType::postProcess(s, tadLen, extraParams); - } - }; +////////////////////////////////////////////////////////////////////////////// +template +template +void sd::Reduction3Loops::loopReduce3All( + const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, + Z* extraParameters, int64_t start, int64_t stop) { + // both tads have same shape, however strides and ews may differ + + Z param0(OpType::startingValue(x)), param1(OpType::startingValue(x)), + param2(extraParameters ? extraParameters[0] : OpType::startingValue(x)); + + const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopTadXYZ( + xTadShapeInfo, yTadShapeInfo, zShapeInfo); + + const auto xTadEws = shape::elementWiseStride(xTadShapeInfo); + const auto yTadEws = shape::elementWiseStride(yTadShapeInfo); + const auto zEws = shape::elementWiseStride(zShapeInfo); + + const auto zLen = shape::length(zShapeInfo); + const auto tadLen = shape::length(xTadShapeInfo); + + const auto numXTads = shape::length(xShapeInfo) / tadLen; + const auto numYTads = shape::length(yShapeInfo) / tadLen; + + const auto tadShape = shape::shapeOf(xTadShapeInfo); + const auto xTadStride = shape::stride(xTadShapeInfo); + const auto yTadStride = shape::stride(yTadShapeInfo); + + const auto startVal = OpType::startingValue(x); + + int numThreads = OmpLaunchHelper::tadThreads(tadLen, numXTads * numYTads); + + switch (kindOfLoop) { + //*********************************************// + case LoopKind::EWS1: { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong j = 0; j < tadLen; ++j) + s = OpType::update(s, OpType::op(xTad[j], yTad[j], extraParams), + extraParams); + + z[zInd] = OpType::postProcess(s, tadLen, extraParams); } - break; - - //*********************************************// - case LoopKind::EWSNONZERO: { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong j = 0; j < tadLen; ++j) - s = OpType::update(s, OpType::op(xTad[j * xTadEws], yTad[j * yTadEws], extraParams), extraParams); - - z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); - } - }; + }; + } break; + + //*********************************************// + case LoopKind::EWSNONZERO: { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong j = 0; j < tadLen; ++j) + s = OpType::update( + s, + OpType::op(xTad[j * xTadEws], yTad[j * yTadEws], extraParams), + extraParams); + + z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); } - break; - - //*********************************************// - case LoopKind::RANK1: { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) { - const auto xTadOffset = i0 * xTadStride[0]; - const auto yTadOffset = i0 * yTadStride[0]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); - } - }; + }; + } break; + + //*********************************************// + case LoopKind::RANK1: { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) { + const auto xTadOffset = i0 * xTadStride[0]; + const auto yTadOffset = i0 * yTadStride[0]; + s = OpType::update( + s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); + } + z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); } - break; - - //*********************************************// - case LoopKind::RANK2: { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); - } - }; + }; + } break; + + //*********************************************// + case LoopKind::RANK2: { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1]; + const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1]; + s = OpType::update( + s, + OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); + } + } + z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); } - break; - - //*********************************************// - case LoopKind::RANK3: { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1] + i2 * xTadStride[2]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1] + i2 * yTadStride[2]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - } - z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); - } - }; + }; + } break; + + //*********************************************// + case LoopKind::RANK3: { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + const auto xTadOffset = i0 * xTadStride[0] + + i1 * xTadStride[1] + i2 * xTadStride[2]; + const auto yTadOffset = i0 * yTadStride[0] + + i1 * yTadStride[1] + i2 * yTadStride[2]; + s = OpType::update( + s, + OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); + } + } + } + z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); } - break; - - //*********************************************// - case LoopKind::RANK4: { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1] + i2 * xTadStride[2] + i3 * xTadStride[3]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1] + i2 * yTadStride[2] + i3 * yTadStride[3]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - } - } - z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK4: { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { + const auto xTadOffset = + i0 * xTadStride[0] + i1 * xTadStride[1] + + i2 * xTadStride[2] + i3 * xTadStride[3]; + const auto yTadOffset = + i0 * yTadStride[0] + i1 * yTadStride[1] + + i2 * yTadStride[2] + i3 * yTadStride[3]; + s = OpType::update(s, + OpType::op(xTad[xTadOffset], + yTad[yTadOffset], extraParams), + extraParams); } - }; + } + } + } + z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); } - break; - - //*********************************************// - case LoopKind::RANK5: { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { - for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) { - const auto xTadOffset = i0 * xTadStride[0] + i1 * xTadStride[1] + i2 * xTadStride[2] + i3 * xTadStride[3] + i4 * xTadStride[4]; - const auto yTadOffset = i0 * yTadStride[0] + i1 * yTadStride[1] + i2 * yTadStride[2] + i3 * yTadStride[3] + i4 * yTadStride[4]; - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - } - } - } - } - z[zInd * zEws] = OpType::postProcess(start, tadLen, extraParams); + }; + } break; + + //*********************************************// + case LoopKind::RANK5: { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { + for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) { + const auto xTadOffset = + i0 * xTadStride[0] + i1 * xTadStride[1] + + i2 * xTadStride[2] + i3 * xTadStride[3] + + i4 * xTadStride[4]; + const auto yTadOffset = + i0 * yTadStride[0] + i1 * yTadStride[1] + + i2 * yTadStride[2] + i3 * yTadStride[3] + + i4 * yTadStride[4]; + s = OpType::update( + s, + OpType::op(xTad[xTadOffset], yTad[yTadOffset], + extraParams), + extraParams); + } } - }; + } + } + } + z[zInd * zEws] = OpType::postProcess(start, tadLen, extraParams); } - break; - - //*********************************************// - default: { - uint castXTadShapeInfo[MAX_RANK]; - const bool canCastXTad = sd::DataTypeUtils::castShapeInfo(xTadShapeInfo, castXTadShapeInfo); - - if (shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) { - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong j = 0; j < tadLen; ++j) { - const auto tadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); - s = OpType::update(s, OpType::op(xTad[tadOffset], yTad[tadOffset], extraParams), extraParams); - } - z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); - } - }; + }; + } break; + + //*********************************************// + default: { + uint castXTadShapeInfo[MAX_RANK]; + const bool canCastXTad = sd::DataTypeUtils::castShapeInfo( + xTadShapeInfo, castXTadShapeInfo); + + if (shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) { + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong j = 0; j < tadLen; ++j) { + const auto tadOffset = shape::indexOffset( + j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); + s = OpType::update( + s, OpType::op(xTad[tadOffset], yTad[tadOffset], extraParams), + extraParams); } - else { - uint castYTadShapeInfo[MAX_RANK]; - const bool canCastYTad = sd::DataTypeUtils::castShapeInfo(yTadShapeInfo, castYTadShapeInfo); - - Z extraParams[3]; - for (Nd4jLong ix = 0; ix < numXTads; ix++) { - for (Nd4jLong iy = 0; iy < numYTads; iy++) { - extraParams[0] = param0; - extraParams[1] = param1; - extraParams[2] = param2; - - const auto xTad = x + xTadOffsets[ix]; - const auto yTad = y + yTadOffsets[iy]; - const auto zInd = ix * numYTads + iy; - auto s = startVal; - - for (Nd4jLong j = 0; j < tadLen; ++j) { - const auto xTadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); - const auto yTadOffset = shape::indexOffset(j, yTadShapeInfo, castYTadShapeInfo, canCastYTad); - s = OpType::update(s, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams); - } - - z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); - } - }; + z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); + } + }; + } else { + uint castYTadShapeInfo[MAX_RANK]; + const bool canCastYTad = sd::DataTypeUtils::castShapeInfo( + yTadShapeInfo, castYTadShapeInfo); + + Z extraParams[3]; + for (Nd4jLong ix = 0; ix < numXTads; ix++) { + for (Nd4jLong iy = 0; iy < numYTads; iy++) { + extraParams[0] = param0; + extraParams[1] = param1; + extraParams[2] = param2; + + const auto xTad = x + xTadOffsets[ix]; + const auto yTad = y + yTadOffsets[iy]; + const auto zInd = ix * numYTads + iy; + auto s = startVal; + + for (Nd4jLong j = 0; j < tadLen; ++j) { + const auto xTadOffset = shape::indexOffset( + j, xTadShapeInfo, castXTadShapeInfo, canCastXTad); + const auto yTadOffset = shape::indexOffset( + j, yTadShapeInfo, castYTadShapeInfo, canCastYTad); + s = OpType::update( + s, + OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), + extraParams); } - } - } - } - - + z[zInd * zEws] = OpType::postProcess(s, tadLen, extraParams); + } + }; + } + } + } } +} // namespace sd -#endif //LIBND4J_LOOPS_H +#endif // LIBND4J_LOOPS_H diff --git a/libnd4j/include/helpers/Loops.hpp b/libnd4j/include/helpers/Loops.hpp index 852ef4808c3d..24d08221805e 100644 --- a/libnd4j/include/helpers/Loops.hpp +++ b/libnd4j/include/helpers/Loops.hpp @@ -22,17 +22,22 @@ //#define LIBND4J_LOOPS_CPP #include -#include -#include #include +#include +#include - -namespace sd { - - -} -//template void Loops::loopReduce(const double* x, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, double* z, const Nd4jLong* zShapeInfo, double* extraParams, std::function startVal, std::function update, std::function op, std::function postPr); -//template void Loops::loopReduce(const float* x, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, float* z, const Nd4jLong* zShapeInfo, float* extraParams, std::function startVal, std::function update, std::function op, std::function postPr); +namespace sd {} +// template void Loops::loopReduce(const double* x, const +// Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, double* z, const Nd4jLong* +// zShapeInfo, double* extraParams, std::function +// startVal, std::function update, +// std::function op, +// std::function postPr); template void +// Loops::loopReduce(const float* x, const Nd4jLong* tadShapeInfo, +// const Nd4jLong* tadOffsets, float* z, const Nd4jLong* zShapeInfo, float* +// extraParams, std::function startVal, +// std::function update, +// std::function op, +// std::function postPr); //#endif // LIBND4J_LOOPS_CPP - diff --git a/libnd4j/include/helpers/LoopsCoordsHelper.h b/libnd4j/include/helpers/LoopsCoordsHelper.h index cd578b62abee..d22ed6a96de7 100644 --- a/libnd4j/include/helpers/LoopsCoordsHelper.h +++ b/libnd4j/include/helpers/LoopsCoordsHelper.h @@ -14,427 +14,415 @@ * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author AbdelRauf - // +// +// @author AbdelRauf +// #ifndef LIBND4J_LOOPCOORDSHELPER_H #define LIBND4J_LOOPCOORDSHELPER_H +#include +#include + #include #include #include -#include -#include namespace sd { #if defined(__GNUC__) -#define likely(x) __builtin_expect( (x), 1) -#define unlikely(x) __builtin_expect( (x), 0) +#define likely(x) __builtin_expect((x), 1) +#define unlikely(x) __builtin_expect((x), 0) #else -#define likely(x) (x) -#define unlikely(x) (x) +#define likely(x) (x) +#define unlikely(x) (x) #endif - using zip_size_t = std::pair; - - template - struct CoordsState :CoordsState { - Nd4jLong coord; - Nd4jLong last_num; - Nd4jLong stride; - Nd4jLong adjust; - CoordsState() :CoordsState() {} - }; - - template<> - struct CoordsState<0> { - Nd4jLong coord; - Nd4jLong last_num; - Nd4jLong stride; - Nd4jLong adjust; - CoordsState() {} - }; - - - template - struct ZipCoordsState :ZipCoordsState { - Nd4jLong coord; - Nd4jLong last_num; - Nd4jLong stride1; - Nd4jLong stride2; - Nd4jLong adjust1; - Nd4jLong adjust2; - ZipCoordsState() : ZipCoordsState() {} - }; - - template<> - struct ZipCoordsState<0> { - Nd4jLong coord; - Nd4jLong last_num; - Nd4jLong stride1; - Nd4jLong stride2; - Nd4jLong adjust1; - Nd4jLong adjust2; - ZipCoordsState() {} - }; - -#define COORDS(x,index) ((x).::sd::CoordsState<(index)>::coord) -#define STRIDE(x,index) ((x).::sd::CoordsState<(index)>::stride) -#define LAST_NUM(x,index) ((x).::sd::CoordsState<(index)>::last_num) -#define OF_ADJUST(x,index) ((x).::sd::CoordsState<(index)>::adjust) -#define ZIP_LAST_NUM(x,index) ((x).::sd::ZipCoordsState<(index)>::last_num) -#define ZIP_COORDS(x,index) ((x).::sd::ZipCoordsState<(index)>::coord) -#define ZIP_STRIDE1(x,index) ((x).::sd::ZipCoordsState<(index)>::stride1) -#define ZIP_STRIDE2(x,index) ((x).::sd::ZipCoordsState<(index)>::stride2) -#define ZIP_OF_ADJUST1(x,index) ((x).::sd::ZipCoordsState<(index)>::adjust1) -#define ZIP_OF_ADJUST2(x,index) ((x).::sd::ZipCoordsState<(index)>::adjust2) - - - FORCEINLINE void index2coords_C(Nd4jLong index, const Nd4jLong rank, const Nd4jLong* bases, Nd4jLong* coords) { - for (size_t i = rank - 1; i > 0; --i) { - coords[i] = index % bases[i]; - index /= bases[i]; - } - coords[0] = index; // last iteration - } - - FORCEINLINE void index2coords_F(Nd4jLong index, const Nd4jLong rank, const Nd4jLong* bases, Nd4jLong* coords) { - - for (size_t i = 0; i < rank - 1; i++) { - coords[i] = index % bases[i]; - index /= bases[i]; - } - coords[rank - 1] = index; // last iteration - } - - FORCEINLINE size_t offset_from_coords(const Nd4jLong* strides, const Nd4jLong* coords, const Nd4jLong& rank) { - - size_t offset = 0; - size_t rank_4 = rank & -4; - for (int i = 0; i < rank_4; i += 4) { - offset = offset - + coords[i] * strides[i] - + coords[i + 1] * strides[i + 1] - + coords[i + 2] * strides[i + 2] - + coords[i + 3] * strides[i + 3]; - } - for (int i = rank_4; i < rank; i++) { - offset += coords[i] * strides[i]; - } - return offset; - } - - - FORCEINLINE zip_size_t offset_from_coords(const Nd4jLong*& x_strides, const Nd4jLong*& z_strides, const Nd4jLong* coords, const Nd4jLong& rank) { - - zip_size_t offset = { 0,0 }; - size_t rank_4 = rank & -4; - for (int i = 0; i < rank_4; i += 4) { - offset.first = offset.first - + coords[i] * x_strides[i] - + coords[i + 1] * x_strides[i + 1] - + coords[i + 2] * x_strides[i + 2] - + coords[i + 3] * x_strides[i + 3]; - offset.second = offset.second - + coords[i] * z_strides[i] - + coords[i + 1] * z_strides[i + 1] - + coords[i + 2] * z_strides[i + 2] - + coords[i + 3] * z_strides[i + 3]; - } - for (int i = rank_4; i < rank; i++) { - offset.first += coords[i] * x_strides[i]; - offset.second += coords[i] * z_strides[i]; - } - return offset; - } - - template - constexpr size_t StridesOrderInd() { - return Last_Index_Faster ? Rank - Index - 1 : Index; - } - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 == Index), size_t>::type - coord_inc_n(CoordsState& cbs, size_t last_offset) { - - constexpr size_t Ind = StridesOrderInd(); - - if (likely(COORDS(cbs, Ind) < LAST_NUM(cbs, Ind))) { - last_offset += cbs.CoordsState::stride; - COORDS(cbs, Ind) = COORDS(cbs, Ind) + 1; - return last_offset; - } - //overflow case should not happen - COORDS(cbs, Ind) = 0; - //last_offset = 0;// last_offset + strides[Ind] - adjust_stride; - return 0; - } - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 != Index), size_t >::type - coord_inc_n(CoordsState& cbs, size_t last_offset) { - - constexpr size_t Ind = StridesOrderInd(); - - if (likely(COORDS(cbs, Ind) < LAST_NUM(cbs, Ind))) { - last_offset = last_offset + cbs.CoordsState::stride; - COORDS(cbs, Ind) = COORDS(cbs, Ind) + 1; - } - else { - //lets adjust offset - last_offset -= OF_ADJUST(cbs, Ind); - COORDS(cbs, Ind) = 0; - last_offset = coord_inc_n(cbs, last_offset); - } - - return last_offset; - - } - - template - FORCEINLINE size_t inc_coords(CoordsState& cbs, size_t last_offset) { - - return coord_inc_n(cbs,/* 1,*/ last_offset/*, 0*/); - } - - template - FORCEINLINE size_t inc_coords_ews(CoordsState& cbs, size_t last_offset, size_t ews) { - if (ews == 1) { - constexpr size_t Ind = StridesOrderInd(); - return last_offset + STRIDE(cbs, Ind); - } - return coord_inc_n(cbs,/* 1,*/ last_offset/*, 0*/); - } - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 == rankIndex), zip_size_t>::type - coord_inc_n(ZipCoordsState& cbs, zip_size_t last_offset) { - - constexpr size_t Ind = StridesOrderInd(); - - if (likely(ZIP_COORDS(cbs, Ind) < ZIP_LAST_NUM(cbs, Ind))) { - last_offset.first += ZIP_STRIDE1(cbs, Ind); - last_offset.second += ZIP_STRIDE2(cbs, Ind); - ZIP_COORDS(cbs, Ind) = ZIP_COORDS(cbs, Ind) + 1; - return last_offset; - } - //overflow case should not happen - ZIP_COORDS(cbs, Ind) = 0; - //last_offset = 0;// last_offset + strides[Ind] - adjust_stride; - return { 0,0 }; - } - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 != rankIndex), zip_size_t >::type - coord_inc_n(ZipCoordsState& cbs, zip_size_t last_offset) { - - constexpr size_t Ind = StridesOrderInd(); - - if (likely(ZIP_COORDS(cbs, Ind) < ZIP_LAST_NUM(cbs, Ind))) { - last_offset.first += ZIP_STRIDE1(cbs, Ind); - last_offset.second += ZIP_STRIDE2(cbs, Ind); - ZIP_COORDS(cbs, Ind) = ZIP_COORDS(cbs, Ind) + 1; - } - else { - - //lets adjust offset - last_offset.first -= ZIP_OF_ADJUST1(cbs, Ind); - last_offset.second -= ZIP_OF_ADJUST2(cbs, Ind); - ZIP_COORDS(cbs, Ind) = 0; - last_offset = coord_inc_n(cbs, last_offset); - } - - return last_offset; - - } - - template - FORCEINLINE zip_size_t inc_coords(ZipCoordsState& cbs, zip_size_t last_offset) { - - return coord_inc_n(cbs, last_offset); - } - - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 == rankIndex), size_t>::type - init_coords(CoordsState& cbs, const Nd4jLong index, const Nd4jLong* bases, const Nd4jLong* strides, size_t offset = 0) { - constexpr size_t Ind = StridesOrderInd(); - COORDS(cbs, Ind) = index % bases[Ind]; - LAST_NUM(cbs, Ind) = bases[Ind] - 1; - STRIDE(cbs, Ind) = strides[Ind]; - OF_ADJUST(cbs, Ind) = bases[Ind] * strides[Ind] - strides[Ind]; - offset += COORDS(cbs, Ind) * strides[Ind]; - return offset; - } - - - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 != rankIndex), size_t>::type - init_coords(CoordsState& cbs, const Nd4jLong index, const Nd4jLong* bases, const Nd4jLong* strides, size_t offset = 0) { - constexpr size_t Ind = StridesOrderInd(); - COORDS(cbs, Ind) = index % bases[Ind]; - LAST_NUM(cbs, Ind) = bases[Ind] - 1; - STRIDE(cbs, Ind) = strides[Ind]; - OF_ADJUST(cbs, Ind) = bases[Ind] * strides[Ind] - strides[Ind]; - offset += COORDS(cbs, Ind) * strides[Ind]; - return init_coords(cbs, index / bases[Ind], bases, strides, offset); - } - - - - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 == rankIndex), bool>::type - eq_coords(CoordsState& cbs, const Nd4jLong* coords) { - return COORDS(cbs, rankIndex) == coords[rankIndex]; - } - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 != rankIndex), bool>::type - eq_coords(CoordsState& cbs, const Nd4jLong* coords) { - return COORDS(cbs, rankIndex) == coords[rankIndex] && eq_coords(cbs, coords); - } - - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 == rankIndex), bool>::type - eq_zip_coords(ZipCoordsState& cbs, const Nd4jLong* coords) { - return ZIP_COORDS(cbs, rankIndex) == coords[rankIndex]; - } - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 != rankIndex), bool>::type - eq_zip_coords(ZipCoordsState& cbs, const Nd4jLong* coords) { - return ZIP_COORDS(cbs, rankIndex) == coords[rankIndex] && eq_zip_coords(cbs, coords); - } - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 == rankIndex), zip_size_t>::type - init_coords(ZipCoordsState& cbs, const Nd4jLong index, const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, zip_size_t offset = {}) { - constexpr size_t Ind = StridesOrderInd(); - ZIP_COORDS(cbs, Ind) = index % bases[Ind]; - ZIP_LAST_NUM(cbs, Ind) = bases[Ind] - 1; - ZIP_STRIDE1(cbs, Ind) = x_strides[Ind]; - ZIP_STRIDE2(cbs, Ind) = z_strides[Ind]; - ZIP_OF_ADJUST1(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); - ZIP_OF_ADJUST2(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); - offset.first += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); - offset.second += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); - return offset; - } - - template - FORCEINLINE - typename std::enable_if<(Rank - 1 != rankIndex), zip_size_t>::type - init_coords(ZipCoordsState& cbs, const Nd4jLong index, const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, zip_size_t offset = {}) { - constexpr size_t Ind = StridesOrderInd(); - ZIP_COORDS(cbs, Ind) = index % bases[Ind]; - ZIP_LAST_NUM(cbs, Ind) = bases[Ind] - 1; - ZIP_STRIDE1(cbs, Ind) = x_strides[Ind]; - ZIP_STRIDE2(cbs, Ind) = z_strides[Ind]; - ZIP_OF_ADJUST1(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); - ZIP_OF_ADJUST2(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); - offset.first += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); - offset.second += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); - return init_coords(cbs, index / bases[Ind], bases, x_strides, z_strides, offset); - } - - - //inc coords for non constant Ranks - template - FORCEINLINE size_t inc_coords(const Nd4jLong* bases, const Nd4jLong* strides, Nd4jLong* coords, size_t last_offset, const size_t rank, const size_t skip = 0) { - - Nd4jLong val; - for (int i = rank - skip - 1; i >= 0; i--) { - val = coords[i] + 1; - if (likely(val < bases[i])) { - coords[i] = val; - last_offset += strides[i]; - break; - } - else { - last_offset -= coords[i] * strides[i]; - coords[i] = 0; - } - } - return last_offset; - } - - template<> - FORCEINLINE size_t inc_coords(const Nd4jLong* bases, const Nd4jLong* strides, Nd4jLong* coords, size_t last_offset, const size_t rank, const size_t skip) { - - Nd4jLong val; - for (int i = skip; i < rank; i++) { - val = coords[i] + 1; - if (likely(val < bases[i])) { - coords[i] = val; - last_offset += strides[i]; - break; - } - else { - last_offset -= coords[i] * strides[i]; - coords[i] = 0; - } - } - return last_offset; - } - - - template - FORCEINLINE zip_size_t inc_coords(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, Nd4jLong* coords, zip_size_t last_offset, const size_t rank, const size_t skip = 0) { - - Nd4jLong val = 0; - for (int i = rank - skip - 1; i >= 0; i--) { - val = coords[i] + 1; - if (likely(val < bases[i])) { - coords[i] = val; - last_offset.first += x_strides[i]; - last_offset.second += z_strides[i]; - break; - } - else { - last_offset.first -= coords[i] * x_strides[i]; - last_offset.second -= coords[i] * z_strides[i]; - coords[i] = 0; - } - } - return last_offset; - } - - template<> - FORCEINLINE zip_size_t inc_coords(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, Nd4jLong* coords, zip_size_t last_offset, const size_t rank, const size_t skip) { - - Nd4jLong val = 0; - for (int i = skip; i < rank; i++) { - val = coords[i] + 1; - if (likely(val < bases[i])) { - coords[i] = val; - - last_offset.first += x_strides[i]; - last_offset.second += z_strides[i]; - break; - } - else { - last_offset.first -= coords[i] * x_strides[i]; - last_offset.second -= coords[i] * z_strides[i]; - coords[i] = 0; - } - } - return last_offset; - } +using zip_size_t = std::pair; + +template +struct CoordsState : CoordsState { + Nd4jLong coord; + Nd4jLong last_num; + Nd4jLong stride; + Nd4jLong adjust; + CoordsState() : CoordsState() {} +}; + +template <> +struct CoordsState<0> { + Nd4jLong coord; + Nd4jLong last_num; + Nd4jLong stride; + Nd4jLong adjust; + CoordsState() {} +}; + +template +struct ZipCoordsState : ZipCoordsState { + Nd4jLong coord; + Nd4jLong last_num; + Nd4jLong stride1; + Nd4jLong stride2; + Nd4jLong adjust1; + Nd4jLong adjust2; + ZipCoordsState() : ZipCoordsState() {} +}; + +template <> +struct ZipCoordsState<0> { + Nd4jLong coord; + Nd4jLong last_num; + Nd4jLong stride1; + Nd4jLong stride2; + Nd4jLong adjust1; + Nd4jLong adjust2; + ZipCoordsState() {} +}; + +#define COORDS(x, index) ((x).::sd::CoordsState<(index)>::coord) +#define STRIDE(x, index) ((x).::sd::CoordsState<(index)>::stride) +#define LAST_NUM(x, index) ((x).::sd::CoordsState<(index)>::last_num) +#define OF_ADJUST(x, index) ((x).::sd::CoordsState<(index)>::adjust) +#define ZIP_LAST_NUM(x, index) ((x).::sd::ZipCoordsState<(index)>::last_num) +#define ZIP_COORDS(x, index) ((x).::sd::ZipCoordsState<(index)>::coord) +#define ZIP_STRIDE1(x, index) ((x).::sd::ZipCoordsState<(index)>::stride1) +#define ZIP_STRIDE2(x, index) ((x).::sd::ZipCoordsState<(index)>::stride2) +#define ZIP_OF_ADJUST1(x, index) ((x).::sd::ZipCoordsState<(index)>::adjust1) +#define ZIP_OF_ADJUST2(x, index) ((x).::sd::ZipCoordsState<(index)>::adjust2) + +FORCEINLINE void index2coords_C(Nd4jLong index, const Nd4jLong rank, + const Nd4jLong* bases, Nd4jLong* coords) { + for (size_t i = rank - 1; i > 0; --i) { + coords[i] = index % bases[i]; + index /= bases[i]; + } + coords[0] = index; // last iteration +} + +FORCEINLINE void index2coords_F(Nd4jLong index, const Nd4jLong rank, + const Nd4jLong* bases, Nd4jLong* coords) { + for (size_t i = 0; i < rank - 1; i++) { + coords[i] = index % bases[i]; + index /= bases[i]; + } + coords[rank - 1] = index; // last iteration +} + +FORCEINLINE size_t offset_from_coords(const Nd4jLong* strides, + const Nd4jLong* coords, + const Nd4jLong& rank) { + size_t offset = 0; + size_t rank_4 = rank & -4; + for (int i = 0; i < rank_4; i += 4) { + offset = offset + coords[i] * strides[i] + coords[i + 1] * strides[i + 1] + + coords[i + 2] * strides[i + 2] + coords[i + 3] * strides[i + 3]; + } + for (int i = rank_4; i < rank; i++) { + offset += coords[i] * strides[i]; + } + return offset; +} + +FORCEINLINE zip_size_t offset_from_coords(const Nd4jLong*& x_strides, + const Nd4jLong*& z_strides, + const Nd4jLong* coords, + const Nd4jLong& rank) { + zip_size_t offset = {0, 0}; + size_t rank_4 = rank & -4; + for (int i = 0; i < rank_4; i += 4) { + offset.first = offset.first + coords[i] * x_strides[i] + + coords[i + 1] * x_strides[i + 1] + + coords[i + 2] * x_strides[i + 2] + + coords[i + 3] * x_strides[i + 3]; + offset.second = offset.second + coords[i] * z_strides[i] + + coords[i + 1] * z_strides[i + 1] + + coords[i + 2] * z_strides[i + 2] + + coords[i + 3] * z_strides[i + 3]; + } + for (int i = rank_4; i < rank; i++) { + offset.first += coords[i] * x_strides[i]; + offset.second += coords[i] * z_strides[i]; + } + return offset; +} + +template +constexpr size_t StridesOrderInd() { + return Last_Index_Faster ? Rank - Index - 1 : Index; +} + +template +FORCEINLINE typename std::enable_if<(Rank - 1 == Index), size_t>::type +coord_inc_n(CoordsState& cbs, size_t last_offset) { + constexpr size_t Ind = StridesOrderInd(); + + if (likely(COORDS(cbs, Ind) < LAST_NUM(cbs, Ind))) { + last_offset += cbs.CoordsState::stride; + COORDS(cbs, Ind) = COORDS(cbs, Ind) + 1; + return last_offset; + } + // overflow case should not happen + COORDS(cbs, Ind) = 0; + // last_offset = 0;// last_offset + strides[Ind] - adjust_stride; + return 0; +} + +template +FORCEINLINE typename std::enable_if<(Rank - 1 != Index), size_t>::type +coord_inc_n(CoordsState& cbs, size_t last_offset) { + constexpr size_t Ind = StridesOrderInd(); + + if (likely(COORDS(cbs, Ind) < LAST_NUM(cbs, Ind))) { + last_offset = last_offset + cbs.CoordsState::stride; + COORDS(cbs, Ind) = COORDS(cbs, Ind) + 1; + } else { + // lets adjust offset + last_offset -= OF_ADJUST(cbs, Ind); + COORDS(cbs, Ind) = 0; + last_offset = + coord_inc_n(cbs, last_offset); + } + + return last_offset; +} + +template +FORCEINLINE size_t inc_coords(CoordsState& cbs, size_t last_offset) { + return coord_inc_n( + cbs, /* 1,*/ last_offset /*, 0*/); +} + +template +FORCEINLINE size_t inc_coords_ews(CoordsState& cbs, + size_t last_offset, size_t ews) { + if (ews == 1) { + constexpr size_t Ind = + StridesOrderInd(); + return last_offset + STRIDE(cbs, Ind); + } + return coord_inc_n( + cbs, /* 1,*/ last_offset /*, 0*/); +} + +template +FORCEINLINE typename std::enable_if<(Rank - 1 == rankIndex), zip_size_t>::type +coord_inc_n(ZipCoordsState& cbs, zip_size_t last_offset) { + constexpr size_t Ind = StridesOrderInd(); + + if (likely(ZIP_COORDS(cbs, Ind) < ZIP_LAST_NUM(cbs, Ind))) { + last_offset.first += ZIP_STRIDE1(cbs, Ind); + last_offset.second += ZIP_STRIDE2(cbs, Ind); + ZIP_COORDS(cbs, Ind) = ZIP_COORDS(cbs, Ind) + 1; + return last_offset; + } + // overflow case should not happen + ZIP_COORDS(cbs, Ind) = 0; + // last_offset = 0;// last_offset + strides[Ind] - adjust_stride; + return {0, 0}; +} + +template +FORCEINLINE typename std::enable_if<(Rank - 1 != rankIndex), zip_size_t>::type +coord_inc_n(ZipCoordsState& cbs, zip_size_t last_offset) { + constexpr size_t Ind = StridesOrderInd(); + + if (likely(ZIP_COORDS(cbs, Ind) < ZIP_LAST_NUM(cbs, Ind))) { + last_offset.first += ZIP_STRIDE1(cbs, Ind); + last_offset.second += ZIP_STRIDE2(cbs, Ind); + ZIP_COORDS(cbs, Ind) = ZIP_COORDS(cbs, Ind) + 1; + } else { + // lets adjust offset + last_offset.first -= ZIP_OF_ADJUST1(cbs, Ind); + last_offset.second -= ZIP_OF_ADJUST2(cbs, Ind); + ZIP_COORDS(cbs, Ind) = 0; + last_offset = + coord_inc_n(cbs, last_offset); + } + + return last_offset; +} +template +FORCEINLINE zip_size_t inc_coords(ZipCoordsState& cbs, + zip_size_t last_offset) { + return coord_inc_n(cbs, last_offset); } +template +FORCEINLINE typename std::enable_if<(Rank - 1 == rankIndex), size_t>::type +init_coords(CoordsState& cbs, const Nd4jLong index, + const Nd4jLong* bases, const Nd4jLong* strides, size_t offset = 0) { + constexpr size_t Ind = StridesOrderInd(); + COORDS(cbs, Ind) = index % bases[Ind]; + LAST_NUM(cbs, Ind) = bases[Ind] - 1; + STRIDE(cbs, Ind) = strides[Ind]; + OF_ADJUST(cbs, Ind) = bases[Ind] * strides[Ind] - strides[Ind]; + offset += COORDS(cbs, Ind) * strides[Ind]; + return offset; +} + +template +FORCEINLINE typename std::enable_if<(Rank - 1 != rankIndex), size_t>::type +init_coords(CoordsState& cbs, const Nd4jLong index, + const Nd4jLong* bases, const Nd4jLong* strides, size_t offset = 0) { + constexpr size_t Ind = StridesOrderInd(); + COORDS(cbs, Ind) = index % bases[Ind]; + LAST_NUM(cbs, Ind) = bases[Ind] - 1; + STRIDE(cbs, Ind) = strides[Ind]; + OF_ADJUST(cbs, Ind) = bases[Ind] * strides[Ind] - strides[Ind]; + offset += COORDS(cbs, Ind) * strides[Ind]; + return init_coords( + cbs, index / bases[Ind], bases, strides, offset); +} + +template +FORCEINLINE typename std::enable_if<(Rank - 1 == rankIndex), bool>::type +eq_coords(CoordsState& cbs, const Nd4jLong* coords) { + return COORDS(cbs, rankIndex) == coords[rankIndex]; +} + +template +FORCEINLINE typename std::enable_if<(Rank - 1 != rankIndex), bool>::type +eq_coords(CoordsState& cbs, const Nd4jLong* coords) { + return COORDS(cbs, rankIndex) == coords[rankIndex] && + eq_coords(cbs, coords); +} + +template +FORCEINLINE typename std::enable_if<(Rank - 1 == rankIndex), bool>::type +eq_zip_coords(ZipCoordsState& cbs, const Nd4jLong* coords) { + return ZIP_COORDS(cbs, rankIndex) == coords[rankIndex]; +} + +template +FORCEINLINE typename std::enable_if<(Rank - 1 != rankIndex), bool>::type +eq_zip_coords(ZipCoordsState& cbs, const Nd4jLong* coords) { + return ZIP_COORDS(cbs, rankIndex) == coords[rankIndex] && + eq_zip_coords(cbs, coords); +} + +template +FORCEINLINE typename std::enable_if<(Rank - 1 == rankIndex), zip_size_t>::type +init_coords(ZipCoordsState& cbs, const Nd4jLong index, + const Nd4jLong* bases, const Nd4jLong* x_strides, + const Nd4jLong* z_strides, zip_size_t offset = {}) { + constexpr size_t Ind = StridesOrderInd(); + ZIP_COORDS(cbs, Ind) = index % bases[Ind]; + ZIP_LAST_NUM(cbs, Ind) = bases[Ind] - 1; + ZIP_STRIDE1(cbs, Ind) = x_strides[Ind]; + ZIP_STRIDE2(cbs, Ind) = z_strides[Ind]; + ZIP_OF_ADJUST1(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); + ZIP_OF_ADJUST2(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); + offset.first += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); + offset.second += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); + return offset; +} + +template +FORCEINLINE typename std::enable_if<(Rank - 1 != rankIndex), zip_size_t>::type +init_coords(ZipCoordsState& cbs, const Nd4jLong index, + const Nd4jLong* bases, const Nd4jLong* x_strides, + const Nd4jLong* z_strides, zip_size_t offset = {}) { + constexpr size_t Ind = StridesOrderInd(); + ZIP_COORDS(cbs, Ind) = index % bases[Ind]; + ZIP_LAST_NUM(cbs, Ind) = bases[Ind] - 1; + ZIP_STRIDE1(cbs, Ind) = x_strides[Ind]; + ZIP_STRIDE2(cbs, Ind) = z_strides[Ind]; + ZIP_OF_ADJUST1(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); + ZIP_OF_ADJUST2(cbs, Ind) = ZIP_LAST_NUM(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); + offset.first += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE1(cbs, Ind); + offset.second += ZIP_COORDS(cbs, Ind) * ZIP_STRIDE2(cbs, Ind); + return init_coords( + cbs, index / bases[Ind], bases, x_strides, z_strides, offset); +} + +// inc coords for non constant Ranks +template +FORCEINLINE size_t inc_coords(const Nd4jLong* bases, const Nd4jLong* strides, + Nd4jLong* coords, size_t last_offset, + const size_t rank, const size_t skip = 0) { + Nd4jLong val; + for (int i = rank - skip - 1; i >= 0; i--) { + val = coords[i] + 1; + if (likely(val < bases[i])) { + coords[i] = val; + last_offset += strides[i]; + break; + } else { + last_offset -= coords[i] * strides[i]; + coords[i] = 0; + } + } + return last_offset; +} + +template <> +FORCEINLINE size_t inc_coords(const Nd4jLong* bases, + const Nd4jLong* strides, Nd4jLong* coords, + size_t last_offset, const size_t rank, + const size_t skip) { + Nd4jLong val; + for (int i = skip; i < rank; i++) { + val = coords[i] + 1; + if (likely(val < bases[i])) { + coords[i] = val; + last_offset += strides[i]; + break; + } else { + last_offset -= coords[i] * strides[i]; + coords[i] = 0; + } + } + return last_offset; +} + +template +FORCEINLINE zip_size_t inc_coords(const Nd4jLong* bases, + const Nd4jLong* x_strides, + const Nd4jLong* z_strides, Nd4jLong* coords, + zip_size_t last_offset, const size_t rank, + const size_t skip = 0) { + Nd4jLong val = 0; + for (int i = rank - skip - 1; i >= 0; i--) { + val = coords[i] + 1; + if (likely(val < bases[i])) { + coords[i] = val; + last_offset.first += x_strides[i]; + last_offset.second += z_strides[i]; + break; + } else { + last_offset.first -= coords[i] * x_strides[i]; + last_offset.second -= coords[i] * z_strides[i]; + coords[i] = 0; + } + } + return last_offset; +} + +template <> +FORCEINLINE zip_size_t inc_coords(const Nd4jLong* bases, + const Nd4jLong* x_strides, + const Nd4jLong* z_strides, + Nd4jLong* coords, + zip_size_t last_offset, + const size_t rank, const size_t skip) { + Nd4jLong val = 0; + for (int i = skip; i < rank; i++) { + val = coords[i] + 1; + if (likely(val < bases[i])) { + coords[i] = val; + + last_offset.first += x_strides[i]; + last_offset.second += z_strides[i]; + break; + } else { + last_offset.first -= coords[i] * x_strides[i]; + last_offset.second -= coords[i] * z_strides[i]; + coords[i] = 0; + } + } + return last_offset; +} + +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/helpers/MKLDNNStream.h b/libnd4j/include/helpers/MKLDNNStream.h index f575c48d9f92..0b5caa28ceda 100644 --- a/libnd4j/include/helpers/MKLDNNStream.h +++ b/libnd4j/include/helpers/MKLDNNStream.h @@ -27,48 +27,55 @@ #if defined(HAVE_MKLDNN) +#include +#include +#include + namespace sd { - class MKLDNNStream { - protected: - std::string _opName; +class MKLDNNStream { + protected: + std::string _opName; - std::vector _inputs; - std::vector _outputs; - std::vector _floatArguments; - std::vector _intArguments; + std::vector _inputs; + std::vector _outputs; + std::vector _floatArguments; + std::vector _intArguments; - public: - template - static bool isSupported() { - // FIXME: strict float support doesn't work anymore - return typeid(X) == typeid(float) && typeid(Y) == typeid(float); - } + public: + template + static bool isSupported() { + // FIXME: strict float support doesn't work anymore + return typeid(X) == typeid(float) && typeid(Y) == typeid(float); + } - static bool isSupported(const std::vector &arrays) { - // FIXME: strict float support doesn't work anymore - for (auto v:arrays) { - if (v != nullptr && v->dataType() != sd::DataType::FLOAT32) { - return false; - } - } - return true; - } + static bool isSupported(const std::vector &arrays) { + // FIXME: strict float support doesn't work anymore + for (auto v : arrays) { + if (v != nullptr && v->dataType() != sd::DataType::FLOAT32) { + return false; + } + } + return true; + } - explicit MKLDNNStream(const std::string &opName) : _opName(opName) { } + explicit MKLDNNStream(const std::string &opName) : _opName(opName) {} - bool checkAndReset(const std::vector &inputs, const std::vector &outputs, - const std::vector &floatArguments, const std::vector &intArguments) { - if (inputs != _inputs || outputs != _outputs || floatArguments != _floatArguments || intArguments != _intArguments) { - _inputs = inputs; - _outputs = outputs; - _floatArguments = floatArguments; - _intArguments = intArguments; - return true; - } - return false; - } - }; -} + bool checkAndReset(const std::vector &inputs, + const std::vector &outputs, + const std::vector &floatArguments, + const std::vector &intArguments) { + if (inputs != _inputs || outputs != _outputs || + floatArguments != _floatArguments || intArguments != _intArguments) { + _inputs = inputs; + _outputs = outputs; + _floatArguments = floatArguments; + _intArguments = intArguments; + return true; + } + return false; + } +}; +} // namespace sd #endif -#endif //LIBND4J_MKLDNNSTREAM_H +#endif // LIBND4J_MKLDNNSTREAM_H diff --git a/libnd4j/include/helpers/MmulHelper.h b/libnd4j/include/helpers/MmulHelper.h index 33d71795ff83..549d20ea7c68 100644 --- a/libnd4j/include/helpers/MmulHelper.h +++ b/libnd4j/include/helpers/MmulHelper.h @@ -25,44 +25,68 @@ #include "array/NDArray.h" namespace sd { - class SD_EXPORT MmulHelper { - - private: - - // multiptication N-dimensions tensor on other N-dimensions one - static sd::NDArray* mmulNxN(const sd::NDArray* A, const sd::NDArray* B, sd::NDArray* C, const double alpha = 1.0, const double beta = 0.0, const char outOrder = 'f'); - - // dot product of vectors (X * Y) = Z[0] - static sd::NDArray* dot(const sd::NDArray* X, const sd::NDArray* Y, sd::NDArray* Z, const double alpha = 1.0, const double beta = 0.0); - - // multiptication Matrix to Matrix - static sd::NDArray* mmulMxM(const sd::NDArray* A, const sd::NDArray* B, sd::NDArray* C, double alpha = 1.0, double beta = 0.0, const char outOrder = 'f'); - - // multiptication Matrix to vector - static sd::NDArray* mmulMxV(const sd::NDArray* A, const sd::NDArray* B, sd::NDArray* C, double alpha = 1.0, double beta = 0.0, const char outOrder = 'f'); - - public: - - static sd::NDArray* mmul(const sd::NDArray* A, const sd::NDArray* B, sd::NDArray* C = nullptr, const double alpha = 1.0, const double beta = 0.0, const char outOrder = 'f'); - - static sd::NDArray* tensorDot(const sd::NDArray* A, const sd::NDArray* B, const std::initializer_list& axesA, const std::initializer_list& axesB = {}); - - static sd::NDArray* tensorDot(const sd::NDArray* A, const sd::NDArray* B, const std::vector& axesA, const std::vector& axesB); - - static void tensorDot(const sd::NDArray* a, const sd::NDArray* b, sd::NDArray* c, const std::vector& axes_a, const std::vector& axes_b, const std::vector& permutForC = {}); - +class SD_EXPORT MmulHelper { + private: + // multiptication N-dimensions tensor on other N-dimensions one + static sd::NDArray* mmulNxN(const sd::NDArray* A, const sd::NDArray* B, + sd::NDArray* C, const double alpha = 1.0, + const double beta = 0.0, + const char outOrder = 'f'); + + // dot product of vectors (X * Y) = Z[0] + static sd::NDArray* dot(const sd::NDArray* X, const sd::NDArray* Y, + sd::NDArray* Z, const double alpha = 1.0, + const double beta = 0.0); + + // multiptication Matrix to Matrix + static sd::NDArray* mmulMxM(const sd::NDArray* A, const sd::NDArray* B, + sd::NDArray* C, double alpha = 1.0, + double beta = 0.0, const char outOrder = 'f'); + + // multiptication Matrix to vector + static sd::NDArray* mmulMxV(const sd::NDArray* A, const sd::NDArray* B, + sd::NDArray* C, double alpha = 1.0, + double beta = 0.0, const char outOrder = 'f'); + + public: + static sd::NDArray* mmul(const sd::NDArray* A, const sd::NDArray* B, + sd::NDArray* C = nullptr, const double alpha = 1.0, + const double beta = 0.0, const char outOrder = 'f'); + + static sd::NDArray* tensorDot(const sd::NDArray* A, const sd::NDArray* B, + const std::initializer_list& axesA, + const std::initializer_list& axesB = {}); + + static sd::NDArray* tensorDot(const sd::NDArray* A, const sd::NDArray* B, + const std::vector& axesA, + const std::vector& axesB); + + static void tensorDot(const sd::NDArray* a, const sd::NDArray* b, + sd::NDArray* c, const std::vector& axes_a, + const std::vector& axes_b, + const std::vector& permutForC = {}); #ifndef __JAVACPP_HACK__ - /** - * modif - (can be empty) vector containing a subsequence of permutation/reshaping arrays (in any order), user must take care of correctness of such arrays by himself - */ - static void tensorDot(const sd::NDArray* a, const sd::NDArray* b, sd::NDArray* c, const std::vector>& modifA, const std::vector>& modifB, const std::vector>& modifC); - static sd::NDArray* tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector>& modifA, const std::vector>& modifB); + /** + * modif - (can be empty) vector containing a subsequence of + * permutation/reshaping arrays (in any order), user must take care of + * correctness of such arrays by himself + */ + static void tensorDot(const sd::NDArray* a, const sd::NDArray* b, + sd::NDArray* c, + const std::vector>& modifA, + const std::vector>& modifB, + const std::vector>& modifC); + static sd::NDArray* tensorDot( + const sd::NDArray* a, const sd::NDArray* b, + const std::vector>& modifA, + const std::vector>& modifB); #endif - static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha = 1.0, double beta = 0.0); - }; -} - + static void matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, + const bool transX, const bool transY, double alpha = 1.0, + double beta = 0.0); +}; +} // namespace sd -#endif //LIBND4J_MMULHELPER_H \ No newline at end of file +#endif // LIBND4J_MMULHELPER_H \ No newline at end of file diff --git a/libnd4j/include/helpers/OmpLaunchHelper.h b/libnd4j/include/helpers/OmpLaunchHelper.h index 574a3182ceec..289f91253b63 100644 --- a/libnd4j/include/helpers/OmpLaunchHelper.h +++ b/libnd4j/include/helpers/OmpLaunchHelper.h @@ -22,49 +22,48 @@ #ifndef LIBND4J_OMPLAUNCHHELPER_H #define LIBND4J_OMPLAUNCHHELPER_H -#include -#include #include +#include + +#include namespace sd { class SD_EXPORT OmpLaunchHelper { - - public: - - OmpLaunchHelper() = delete; - - OmpLaunchHelper(const Nd4jLong N, float desiredNumThreads = -1); - - FORCEINLINE Nd4jLong getThreadOffset(const int threadNum); - FORCEINLINE Nd4jLong getItersPerThread(const int threadNum); - - static Nd4jLong betterSpan(Nd4jLong N); - static Nd4jLong betterSpan(Nd4jLong N, Nd4jLong numThreads); - - static int betterThreads(Nd4jLong N); - static int betterThreads(Nd4jLong N, int maxThreads); - - static int tadThreads(Nd4jLong tadLength, Nd4jLong numTads); - - int _numThreads; - unsigned int _itersPerThread; - unsigned int _remainder; + public: + OmpLaunchHelper() = delete; + + OmpLaunchHelper(const Nd4jLong N, float desiredNumThreads = -1); + + FORCEINLINE Nd4jLong getThreadOffset(const int threadNum); + FORCEINLINE Nd4jLong getItersPerThread(const int threadNum); + + static Nd4jLong betterSpan(Nd4jLong N); + static Nd4jLong betterSpan(Nd4jLong N, Nd4jLong numThreads); + + static int betterThreads(Nd4jLong N); + static int betterThreads(Nd4jLong N, int maxThreads); + + static int tadThreads(Nd4jLong tadLength, Nd4jLong numTads); + + int _numThreads; + unsigned int _itersPerThread; + unsigned int _remainder; }; //////////////////////////////////////////////////////////////////////////////// FORCEINLINE Nd4jLong OmpLaunchHelper::getThreadOffset(const int threadNum) { - - return threadNum * _itersPerThread; + return threadNum * _itersPerThread; } //////////////////////////////////////////////////////////////////////////////// FORCEINLINE Nd4jLong OmpLaunchHelper::getItersPerThread(const int threadNum) { - - return (threadNum == _numThreads - 1) ? _itersPerThread + _remainder : _itersPerThread; // last thread may contain bigger number of iterations -} - + return (threadNum == _numThreads - 1) + ? _itersPerThread + _remainder + : _itersPerThread; // last thread may contain bigger number of + // iterations } +} // namespace sd -#endif //LIBND4J_OMPLAUNCHHELPER_H +#endif // LIBND4J_OMPLAUNCHHELPER_H diff --git a/libnd4j/include/helpers/OpArgsHolder.h b/libnd4j/include/helpers/OpArgsHolder.h index 0850181a0634..f19569a017f7 100644 --- a/libnd4j/include/helpers/OpArgsHolder.h +++ b/libnd4j/include/helpers/OpArgsHolder.h @@ -21,85 +21,71 @@ #ifndef LIBND4J_OPARGSHOLDER_H #define LIBND4J_OPARGSHOLDER_H - #include #include namespace sd { class SD_EXPORT OpArgsHolder { + private: + std::vector _inArrs = std::vector(); + std::vector _tArgs = std::vector(); + std::vector _iArgs = std::vector(); + std::vector _bArgs = std::vector(); -private: - - std::vector _inArrs = std::vector(); - std::vector _tArgs = std::vector(); - std::vector _iArgs = std::vector(); - std::vector _bArgs = std::vector(); - - std::vector _isArrAlloc = std::vector(); - - int _numInArrs = _inArrs.size(); - int _numTArgs = _tArgs.size(); - int _numIArgs = _iArgs.size(); - int _numBArgs = _bArgs.size(); - -public: + std::vector _isArrAlloc = std::vector(); - // default constructor - OpArgsHolder(); + int _numInArrs = _inArrs.size(); + int _numTArgs = _tArgs.size(); + int _numIArgs = _iArgs.size(); + int _numBArgs = _bArgs.size(); - // copy constructor - OpArgsHolder(const OpArgsHolder& other); + public: + // default constructor + OpArgsHolder(); - // constructor - OpArgsHolder(const std::vector& inArrs, const std::vector& tArgs = std::vector(), const std::vector& iArgs = std::vector(), const std::vector& bArgs = std::vector()); + // copy constructor + OpArgsHolder(const OpArgsHolder& other); - // move constructor - OpArgsHolder(OpArgsHolder&& other) noexcept; + // constructor + OpArgsHolder(const std::vector& inArrs, + const std::vector& tArgs = std::vector(), + const std::vector& iArgs = std::vector(), + const std::vector& bArgs = std::vector()); - // assignment operator - OpArgsHolder& operator=(const OpArgsHolder& other); + // move constructor + OpArgsHolder(OpArgsHolder&& other) noexcept; - // move assignment operator - OpArgsHolder& operator=(OpArgsHolder&& other) noexcept; + // assignment operator + OpArgsHolder& operator=(const OpArgsHolder& other); - const std::vector& getInArrs() const - {return _inArrs; } + // move assignment operator + OpArgsHolder& operator=(OpArgsHolder&& other) noexcept; - const std::vector& getTArgs() const - {return _tArgs; } + const std::vector& getInArrs() const { return _inArrs; } - const std::vector& getIArgs() const - {return _iArgs; } + const std::vector& getTArgs() const { return _tArgs; } - const std::vector& getBArgs() const - {return _bArgs; } + const std::vector& getIArgs() const { return _iArgs; } - const std::vector& getAllocInfo() const - {return _isArrAlloc; } + const std::vector& getBArgs() const { return _bArgs; } - int getNumInArrs() const - {return _numInArrs; } + const std::vector& getAllocInfo() const { return _isArrAlloc; } - int getNumTArgs() const - {return _numTArgs; } + int getNumInArrs() const { return _numInArrs; } - int getNumIArgs() const - {return _numIArgs; } + int getNumTArgs() const { return _numTArgs; } - int getNumBArgs() const - {return _numBArgs; } + int getNumIArgs() const { return _numIArgs; } - OpArgsHolder createArgsHolderForBP(const std::vector& inGradArrs, const bool isInPlace = false) const; + int getNumBArgs() const { return _numBArgs; } - ~OpArgsHolder() noexcept; + OpArgsHolder createArgsHolderForBP(const std::vector& inGradArrs, + const bool isInPlace = false) const; + ~OpArgsHolder() noexcept; }; +} // namespace sd - - - -} - -#endif //LIBND4J_OPARGSHOLDER_H +#endif // LIBND4J_OPARGSHOLDER_H diff --git a/libnd4j/include/helpers/OpBenchmark.h b/libnd4j/include/helpers/OpBenchmark.h index 4a4996929926..9ffa67de93f8 100644 --- a/libnd4j/include/helpers/OpBenchmark.h +++ b/libnd4j/include/helpers/OpBenchmark.h @@ -21,54 +21,57 @@ #ifndef SD_OPEXECUTIONER_H #define SD_OPEXECUTIONER_H -#include #include -#include -#include #include +#include +#include +#include namespace sd { - class SD_EXPORT OpBenchmark { - protected: - int _opNum = 0; - std::string _testName; - NDArray _x; - NDArray _y; - NDArray _z; - std::vector _axis; - public: - OpBenchmark() = default; - OpBenchmark(const std::string& name, const NDArray &x, const NDArray &y, const NDArray &z); - OpBenchmark(const std::string& name, const NDArray &x, const NDArray &z); - OpBenchmark(const std::string& name, const NDArray &x, const NDArray &z, const std::vector &axis); - OpBenchmark(const std::string& name, const NDArray &x, const NDArray &y, const NDArray &z, const std::vector &axis); +class SD_EXPORT OpBenchmark { + protected: + int _opNum = 0; + std::string _testName; + NDArray _x; + NDArray _y; + NDArray _z; + std::vector _axis; - void setOpNum(int opNum); - void setTestName(const std::string &testName); - void setX(const NDArray &array); - void setY(const NDArray &array); - void setZ(const NDArray &array); - void setAxis(std::vector axis); - void setAxis(std::initializer_list axis); + public: + OpBenchmark() = default; + OpBenchmark(const std::string &name, const NDArray &x, const NDArray &y, + const NDArray &z); + OpBenchmark(const std::string &name, const NDArray &x, const NDArray &z); + OpBenchmark(const std::string &name, const NDArray &x, const NDArray &z, + const std::vector &axis); + OpBenchmark(const std::string &name, const NDArray &x, const NDArray &y, + const NDArray &z, const std::vector &axis); - NDArray& x(); - int opNum() const; - const std::string& testName() const; - std::vector getAxis(); + void setOpNum(int opNum); + void setTestName(const std::string &testName); + void setX(const NDArray &array); + void setY(const NDArray &array); + void setZ(const NDArray &array); + void setAxis(std::vector axis); + void setAxis(std::initializer_list axis); - virtual std::string extra(); - virtual std::string dataType(); - virtual std::string axis() = 0; - virtual std::string orders() = 0; - virtual std::string strides() = 0; - virtual std::string shape(); - virtual std::string inplace() = 0; + NDArray &x(); + int opNum() const; + const std::string &testName() const; + std::vector getAxis(); - virtual void executeOnce() = 0; + virtual std::string extra(); + virtual std::string dataType(); + virtual std::string axis() = 0; + virtual std::string orders() = 0; + virtual std::string strides() = 0; + virtual std::string shape(); + virtual std::string inplace() = 0; - virtual OpBenchmark* clone() = 0; - }; -} + virtual void executeOnce() = 0; + virtual OpBenchmark *clone() = 0; +}; +} // namespace sd -#endif //SD_OPEXECUTIONER_H +#endif // SD_OPEXECUTIONER_H diff --git a/libnd4j/include/helpers/OpTracker.h b/libnd4j/include/helpers/OpTracker.h index 1d38fabe38a3..d62c94851007 100644 --- a/libnd4j/include/helpers/OpTracker.h +++ b/libnd4j/include/helpers/OpTracker.h @@ -21,40 +21,44 @@ #ifndef LIBND4J_OP_TRACKER_H #define LIBND4J_OP_TRACKER_H -#include -#include -#include -#include #include #include #include +#include + +#include +#include +#include namespace sd { - class SD_EXPORT OpTracker { - private: - static OpTracker* _INSTANCE; +class SD_EXPORT OpTracker { + private: + static OpTracker* _INSTANCE; + + std::string _export; - std::string _export; + int _operations = 0; + std::map> _map; - int _operations = 0; - std::map> _map; + OpTracker() = default; + ~OpTracker() = default; - OpTracker() = default; - ~OpTracker() = default; + template + std::string local_to_string(T value); - template - std::string local_to_string(T value); - public: - static OpTracker* getInstance(); + public: + static OpTracker* getInstance(); - int totalGroups(); - int totalOperations(); + int totalGroups(); + int totalOperations(); - void storeOperation(sd::graph::OpType opType, const sd::ops::OpDescriptor& descriptor); - void storeOperation(sd::graph::OpType opType, const char* opName, const Nd4jLong opNum); + void storeOperation(sd::graph::OpType opType, + const sd::ops::OpDescriptor& descriptor); + void storeOperation(sd::graph::OpType opType, const char* opName, + const Nd4jLong opNum); - const char* exportOperations(); - }; -} + const char* exportOperations(); +}; +} // namespace sd #endif diff --git a/libnd4j/include/helpers/PointersManager.h b/libnd4j/include/helpers/PointersManager.h index ba01713df739..6572d7a2ee0c 100644 --- a/libnd4j/include/helpers/PointersManager.h +++ b/libnd4j/include/helpers/PointersManager.h @@ -22,60 +22,56 @@ #ifndef CUDAMANAGER_H #define CUDAMANAGER_H -#include -#include #include - #include +#include +#include + namespace sd { class SD_EXPORT PointersManager { + private: + sd::LaunchContext* _context; + std::vector _pOnGlobMem; + std::string _funcName; - private: - - sd::LaunchContext *_context; - std::vector _pOnGlobMem; - std::string _funcName; + public: + PointersManager(const sd::LaunchContext* context, + const std::string& funcName = ""); - public: + ~PointersManager(); - PointersManager(const sd::LaunchContext* context, const std::string& funcName = ""); + void* replicatePointer(const void* src, const size_t size); - ~PointersManager(); - - void* replicatePointer(const void* src, const size_t size); - - void synchronize() const; - - template - void printDevContentOnHost(const void* pDev, const Nd4jLong len) const; + void synchronize() const; + template + void printDevContentOnHost(const void* pDev, const Nd4jLong len) const; #ifdef __CUDABLAS__ - template - static void printDevContentOnDevFromHost(const void* pDev, const Nd4jLong len, const int tid = 0); + template + static void printDevContentOnDevFromHost(const void* pDev, const Nd4jLong len, + const int tid = 0); #endif #ifdef __CUDACC__ - template - static FORCEINLINE __device__ void printDevContentOnDev(const void* pDev, const Nd4jLong len, const int tid = 0) { - if(blockIdx.x * blockDim.x + threadIdx.x != tid) - return; + template + static FORCEINLINE __device__ void printDevContentOnDev(const void* pDev, + const Nd4jLong len, + const int tid = 0) { + if (blockIdx.x * blockDim.x + threadIdx.x != tid) return; - printf("device print out: \n"); - for(Nd4jLong i = 0; i < len; ++i) - printf("%f, ", (double)reinterpret_cast(pDev)[i]); + printf("device print out: \n"); + for (Nd4jLong i = 0; i < len; ++i) + printf("%f, ", (double)reinterpret_cast(pDev)[i]); - printf("\n"); - } + printf("\n"); + } #endif - }; -} - - +} // namespace sd -#endif // CUDAMANAGER_H +#endif // CUDAMANAGER_H diff --git a/libnd4j/include/helpers/RandomLauncher.h b/libnd4j/include/helpers/RandomLauncher.h index 5eec7e0bd11d..0e22eda3b079 100644 --- a/libnd4j/include/helpers/RandomLauncher.h +++ b/libnd4j/include/helpers/RandomLauncher.h @@ -19,29 +19,51 @@ // #include -#include -#include #include +#include +#include namespace sd { - class SD_EXPORT RandomLauncher { - public: - static void applyDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr); - static void applyInvertedDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr); - static void applyAlphaDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z = nullptr); +class SD_EXPORT RandomLauncher { + public: + static void applyDropOut(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + double retainProb, NDArray* z = nullptr); + static void applyInvertedDropOut(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double retainProb, + NDArray* z = nullptr); + static void applyAlphaDropOut(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + double retainProb, double alpha, double beta, + double alphaPrime, NDArray* z = nullptr); - static void fillUniform(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double from, double to); + static void fillUniform(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + double from, double to); - static void fillGaussian(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev); + static void fillGaussian(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + double mean, double stdev); - static void fillExponential(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double lambda); + static void fillExponential(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + double lambda); - static void fillLogNormal(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev); + static void fillLogNormal(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + double mean, double stdev); - static void fillTruncatedNormal(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev); + static void fillTruncatedNormal(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double mean, double stdev); - static void fillBinomial(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, int trials, double prob); + static void fillBinomial(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + int trials, double prob); - static void fillBernoulli(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double prob); - }; -} \ No newline at end of file + static void fillBernoulli(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, NDArray* array, + double prob); +}; +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/ShapeBuilders.h b/libnd4j/include/helpers/ShapeBuilders.h index fbbfb71500ef..455a4fdb0600 100644 --- a/libnd4j/include/helpers/ShapeBuilders.h +++ b/libnd4j/include/helpers/ShapeBuilders.h @@ -21,49 +21,76 @@ #ifndef SD_SHAPEBUILDERS_H #define SD_SHAPEBUILDERS_H -#include +#include +#include #include -#include #include -#include -#include - -namespace sd { - class SD_EXPORT ShapeBuilders { - public: - static Nd4jLong* createScalarShapeInfo(sd::DataType dataType, sd::memory::Workspace* workspace = nullptr); +#include - static Nd4jLong* createVectorShapeInfo(const sd::DataType dataType, const Nd4jLong length, sd::memory::Workspace* workspace = nullptr); +#include - /** - * create shapeInfo for given order basing on shape stored in shapeOnly vector - * memory allocation for shapeInfo is on given workspace - */ - static Nd4jLong* createShapeInfo(const sd::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace = nullptr); - static Nd4jLong* createShapeInfo(const sd::DataType dataType, const char order, const std::vector& shapeOnly, memory::Workspace* workspace = nullptr); - static Nd4jLong* createShapeInfo(const sd::DataType dataType, const char order, const std::initializer_list& shapeOnly, memory::Workspace* workspace = nullptr); +namespace sd { +class SD_EXPORT ShapeBuilders { + public: + static Nd4jLong* createScalarShapeInfo( + sd::DataType dataType, sd::memory::Workspace* workspace = nullptr); - /** - * allocates memory for new shapeInfo and copy all information from inShapeInfo to new shapeInfo - * if copyStrides is false then strides for new shapeInfo are recalculated - */ - static Nd4jLong* copyShapeInfo(const Nd4jLong* inShapeInfo, const bool copyStrides, memory::Workspace* workspace = nullptr); - static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const DataType dtype, const bool copyStrides, memory::Workspace* workspace = nullptr); - static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr); + static Nd4jLong* createVectorShapeInfo( + const sd::DataType dataType, const Nd4jLong length, + sd::memory::Workspace* workspace = nullptr); - /** - * allocates memory for new shapeInfo and copy all information from inShapeInfo to new shapeInfo except dimensions in dimsToExclude (unit dimensions) and corresponding strides - * for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {2,3}, dimsSize = 2 - * then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99} - */ - static Nd4jLong* copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace = nullptr); + /** + * create shapeInfo for given order basing on shape stored in shapeOnly + * vector memory allocation for shapeInfo is on given workspace + */ + static Nd4jLong* createShapeInfo(const sd::DataType dataType, + const char order, int rank, + const Nd4jLong* shapeOnly, + memory::Workspace* workspace = nullptr); + static Nd4jLong* createShapeInfo(const sd::DataType dataType, + const char order, + const std::vector& shapeOnly, + memory::Workspace* workspace = nullptr); + static Nd4jLong* createShapeInfo( + const sd::DataType dataType, const char order, + const std::initializer_list& shapeOnly, + memory::Workspace* workspace = nullptr); - static Nd4jLong* emptyShapeInfo(const sd::DataType dataType, memory::Workspace* workspace = nullptr); + /** + * allocates memory for new shapeInfo and copy all information from + * inShapeInfo to new shapeInfo if copyStrides is false then strides for new + * shapeInfo are recalculated + */ + static Nd4jLong* copyShapeInfo(const Nd4jLong* inShapeInfo, + const bool copyStrides, + memory::Workspace* workspace = nullptr); + static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, + const DataType dtype, + const bool copyStrides, + memory::Workspace* workspace = nullptr); + static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, + const Nd4jLong* shapeInfoToGetTypeFrom, + const bool copyStrides, + memory::Workspace* workspace = nullptr); - static Nd4jLong* emptyShapeInfo(const sd::DataType dataType, const char order, const std::vector &shape, memory::Workspace* workspace = nullptr); + /** + * allocates memory for new shapeInfo and copy all information from + * inShapeInfo to new shapeInfo except dimensions in dimsToExclude (unit + * dimensions) and corresponding strides for example inShapeInfo is {3, + * 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {2,3}, dimsSize = 2 + * then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99} + */ + static Nd4jLong* copyShapeInfoWithoutUnites( + const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, + memory::Workspace* workspace = nullptr); - }; -} + static Nd4jLong* emptyShapeInfo(const sd::DataType dataType, + memory::Workspace* workspace = nullptr); + static Nd4jLong* emptyShapeInfo(const sd::DataType dataType, const char order, + const std::vector& shape, + memory::Workspace* workspace = nullptr); +}; +} // namespace sd -#endif //SD_SHAPEBUILDERS_H +#endif // SD_SHAPEBUILDERS_H diff --git a/libnd4j/include/helpers/ShapeUtils.h b/libnd4j/include/helpers/ShapeUtils.h index 16be4aeb69f6..a4573268e5f0 100644 --- a/libnd4j/include/helpers/ShapeUtils.h +++ b/libnd4j/include/helpers/ShapeUtils.h @@ -21,204 +21,296 @@ #ifndef LIBND4J_SHAPEUTILS_H #define LIBND4J_SHAPEUTILS_H -#include #include -namespace sd { - - class SD_EXPORT ShapeUtils { - - public: - - // evaluate shape for array resulting from tensorDot operation, also evaluate shapes and permutation dimensions for transposition of two input arrays - static std::vector evalShapeForTensorDot(const Nd4jLong* aShapeInfo, const Nd4jLong* bShapeInfo, std::vector axesA, std::vector axesB, std::vector& permutAt, std::vector& permutBt, std::vector& shapeAt, std::vector& shapeBt); - static std::vector evalShapeForTensorDot(const NDArray* a, const NDArray* b, const std::vector& axesA, const std::vector& axesB, std::vector& permutAt, std::vector& permutBt, std::vector& shapeAt, std::vector& shapeBt); - - // evaluate resulting shape after reduce operation - static const Nd4jLong* evalReduceShapeInfo(const char order, const std::vector& dimensions, const NDArray& arr, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); - static const Nd4jLong* evalReduceShapeInfo(const char order, const std::vector& dimensions, const Nd4jLong* shapeInfo, const sd::DataType dataType, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); - static const Nd4jLong* evalReduceShapeInfo(const char order, const std::vector& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); - static const Nd4jLong* evalReduceShapeInfo(const char order, const std::vector& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr); - - /** - * evaluate output shape for reduce operation when input shape is empty - * behavior is analogous to tf - */ - static const Nd4jLong* evalReduceShapeInfoEmpty(const char order, const std::vector& dimensions, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace); - - // evaluate shape for array which is result of repeat operation applied to arr - static std::vector evalRepeatShape(int axis, const std::vector& repeats, const NDArray& arr); - - // evaluate shapeInfo of permuted array - // if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order - static const Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides = false); - static const Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace); - - // evaluate shapeInfo of transposed array - // if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order - static const Nd4jLong* evalTranspShapeInfo(const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides = false); - - static bool copyVectorPart(std::vector& target, std::vector& source, int rank, int offset); - - // return new (shorter) sorted dimensions array without dimensions that are present in input vector - static std::vector evalDimsToExclude(const int rank, const int dimsLen, const int* dimensions); - static std::vector evalDimsToExclude(const int rank, const std::vector& dimensions); - - // check whether 2 arrays have mutually broadcastable shapes - // shape comparison starts from the end - static bool areShapesBroadcastable(const NDArray &arr1, const NDArray &arr2); - static bool areShapesBroadcastable(const Nd4jLong* shapeX, const Nd4jLong* shapeY); - static bool areShapesBroadcastable(const std::vector& shape1, const std::vector& shape2); - - // check the possibility of broadcast operation, if true then return shapeInfo of resulting array - // if evalMinMax == false then array with larger rank has to be passed as first argument - static bool evalBroadcastShapeInfo(const NDArray& max, const NDArray& min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace); - static bool evalBroadcastShapeInfo(const Nd4jLong *max, const Nd4jLong *min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace); - - // evaluate sorted vector of max axes to create tads along in case of simple broadcast operation - // if simple broadcast is not possible then empty vector is returned - // PLEASE NOTE: condition (rank_max >= rank_min) should be satisfied ! - static std::vector tadAxesForSimpleBroadcast(const NDArray& max, const NDArray& min); - - // check the possibility of broadcast operation for set of arrays, if true then return resulting broadcasted shapeInfo - static bool evalCommonBroadcastShapeInfo(const std::vector& arrays, Nd4jLong*& resultShapeInfo, memory::Workspace* workspace = nullptr); - - // return sorted vector of dimensions common (same) for two arrays, dimensions values corresponds to array with bigger rank - // for example if arr1{2,7}, arr2{2,5,4,7} then vector = {0,3} - static std::vector getDimsWithSameShape(const NDArray& max, const NDArray& min); - - // evaluate shapeInfo for resulting array of tile operation - static const Nd4jLong* evalTileShapeInfo(const NDArray& arr, const std::vector& reps, sd::memory::Workspace* workspace); - - // returns shape part of shapeInfo as std::vector - static std::vector pullShapeFromShapeInfo(const Nd4jLong *shapeInfo); - - static std::string shapeAsString(const NDArray &array); - static std::string shapeAsString(const NDArray* array); - static std::string shapeAsString(const std::vector& shape); - static std::string shapeAsString(const Nd4jLong* shapeInfo); - static std::string shapeAsString(const int rank, const Nd4jLong* shapeInfo); - static std::string strideAsString(const NDArray& array); - static std::string strideAsString(const NDArray* array); - - static std::string shapeInfoAsString(const Nd4jLong* shapeInfo); - - static std::vector shapeAsVector(const Nd4jLong* shapeInfo); - - // evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal - static const Nd4jLong* evalDiagShapeInfo(const Nd4jLong* shapeInfo, sd::memory::Workspace* workspace); - - static std::vector evalBroadcastBackwardAxis(const Nd4jLong *operand, const Nd4jLong *result); - - // utility to calculate matrix product shape with give source shapes and additional params - // returns ShapeList pointer with result shape - static const Nd4jLong* matrixProductShape(const Nd4jLong* theFirstShape, const Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, sd::DataType dtype, sd::memory::Workspace* workspace); - - /** - * This method evaluates permutation vector necessary for reducing of shapeFrom to shapeTo - * if shapeFrom is identical to shapeTo (permutation is unnecessary) then empty vector is returned - * in case of permutation is impossible an exception is thrown - */ - static std::vector evalPermutFromTo(const std::vector& shapeFrom, const std::vector& shapeTo); - - /** - * This method composes shape (shape only, not whole shapeInfo!) using dimensions values and corresponding indexes, - * please note: the size of input vector dimsAndIdx must always be even, since the numbers of dimensions and indexes are the same, - * for example if dimsAndIdx = {dimC,dimB,dimA, 2,1,0} then output vector = {dimA,dimB,dimC} - */ - static std::vector composeShapeUsingDimsAndIdx(const std::vector& dimsAndIdx); - - /** - * x * y = c, evaluate shape for array resulting from mmul operation - * possible cases: dot product (xRank=yRank=1), matrix-vector product (xRank=2, yRank=1), vector-matrix product (xRank=1, yRank=2), matrix-matrix product (xRank=yRank and rank >=2) - */ - static std::vector evalShapeForMatmul(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const bool transX, const bool transY); - - /** - * evaluate number of sub-arrays along dimensions stored in dimsToExclude - * i.e. if shape is [2,3,4,5] and dimsToExclude={0,2}, then number of sub-arrays = 8 - */ - static Nd4jLong getNumOfSubArrs(const Nd4jLong* shapeInfo, const std::vector& dimsToExclude); - - /** - * return shape without unities, for example if shape is [1,2,1,3] then [2,3] will be returned - * if unities are not present in given shapeInfo then exactly identical shape will be returned, for example [2,3] -> [2,3] - * edge case: if given shape is [1,1,1,...,1] (all dims are unities) then output will be empty and means scalar - */ - static std::vector evalDimsWithoutUnities(const Nd4jLong* shapeInfo); - - /** - * method returns false if permut == {0,1,2,...permut.size()-1} - in that case permutation is unnecessary - */ - FORCEINLINE static bool isPermutNecessary(const std::vector& permut); - - /** - * calculates strides using "dest" shape and given "order", also copies data type from "source" to "dest" - */ - static void updateStridesAndType(Nd4jLong* dest, const Nd4jLong* source, const char order); - - /** - * calculates strides using "dest" shape and "order", also set "dtype" into "dest" - */ - static void updateStridesAndType(Nd4jLong* dest, const DataType dtype, const char order); - - /** - * This method retuns number of bytes required for string tensor - * @param numStrings - * @return - */ - static FORCEINLINE Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings) { - // we store +1 offset - return (numStrings + 1) * sizeof(Nd4jLong); - } - - /** - * This method selects strides based on dimentions required for broadcasting - * @param const pointer to input (Y) shape info for strides selection - * @param rank of input (X) to broadcasting - * @param dimentions size - * @param const pointer to dimentions for broadcasting - * @param pointer to output strides have to be pre allocated by 0 - * @return - */ - static void copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, const int nRank, const int dimsSize, const int* dims, Nd4jLong* outStrides); - - /* - * check whether arr1/arr2 is sub-array of arr2/arr1, - * this method do not evaluate what array is sub-array, it returns true if arr1 is sub-array of arr2 or arr2 is sub-array of arr1 - * sameDims is filled (and sorted) with dimensions values that match both in arr1 and arr2 shapes (unities are ignored) - * for example: - * if arr1{2,3} and arr2{2,4,3,7} then return true and sameDims contains {0,2} - * if arr1{1,1,3,1,3,1,1} and arr2{1,2,3,1,3} then return true and sameDims contains {2,4} - * if arr1{2,1,4,1,7,5} and arr2{1,1,4,5} then return true and sameDims contains {2,5} - - static bool isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector& sameDims); - */ - - /* - * comparing of shapes, not strides - */ - static bool areShapesEqual(const Nd4jLong* shapeInfo, const std::vector& shapeOnly); - }; - - +#include +namespace sd { +class SD_EXPORT ShapeUtils { + public: + // evaluate shape for array resulting from tensorDot operation, also evaluate + // shapes and permutation dimensions for transposition of two input arrays + static std::vector evalShapeForTensorDot( + const Nd4jLong* aShapeInfo, const Nd4jLong* bShapeInfo, + std::vector axesA, std::vector axesB, + std::vector& permutAt, std::vector& permutBt, + std::vector& shapeAt, std::vector& shapeBt); + static std::vector evalShapeForTensorDot( + const NDArray* a, const NDArray* b, const std::vector& axesA, + const std::vector& axesB, std::vector& permutAt, + std::vector& permutBt, std::vector& shapeAt, + std::vector& shapeBt); + + // evaluate resulting shape after reduce operation + static const Nd4jLong* evalReduceShapeInfo( + const char order, const std::vector& dimensions, const NDArray& arr, + const sd::DataType dataType, const bool keepDims = false, + const bool supportOldShapes = false, + sd::memory::Workspace* workspace = nullptr); + static const Nd4jLong* evalReduceShapeInfo( + const char order, const std::vector& dimensions, + const Nd4jLong* shapeInfo, const sd::DataType dataType, + const bool keepDims = false, const bool supportOldShapes = false, + sd::memory::Workspace* workspace = nullptr); + static const Nd4jLong* evalReduceShapeInfo( + const char order, const std::vector& dimensions, const NDArray& arr, + const bool keepDims = false, const bool supportOldShapes = false, + sd::memory::Workspace* workspace = nullptr); + static const Nd4jLong* evalReduceShapeInfo( + const char order, const std::vector& dimensions, + const Nd4jLong* shapeInfo, const bool keepDims = false, + const bool supportOldShapes = false, + sd::memory::Workspace* workspace = nullptr); + + /** + * evaluate output shape for reduce operation when input shape is empty + * behavior is analogous to tf + */ + static const Nd4jLong* evalReduceShapeInfoEmpty( + const char order, const std::vector& dimensions, + const Nd4jLong* shapeInfo, const sd::DataType dataType, + const bool keepDims, sd::memory::Workspace* workspace); + + // evaluate shape for array which is result of repeat operation applied to arr + static std::vector evalRepeatShape(int axis, + const std::vector& repeats, + const NDArray& arr); + + // evaluate shapeInfo of permuted array + // if setContigStrides = true, then set contiguous strides in output shapeInfo + // in accordance with arr order + static const Nd4jLong* evalPermShapeInfo(const int* dimensions, + const int rank, const NDArray& arr, + sd::memory::Workspace* workspace, + const bool setContigStrides = false); + static const Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, + const int rank, const NDArray& arr, + sd::memory::Workspace* workspace); + + // evaluate shapeInfo of transposed array + // if setContigStrides = true, then set contiguous strides in output shapeInfo + // in accordance with arr order + static const Nd4jLong* evalTranspShapeInfo( + const NDArray& arr, sd::memory::Workspace* workspace, + const bool setContigStrides = false); + + static bool copyVectorPart(std::vector& target, std::vector& source, + int rank, int offset); + + // return new (shorter) sorted dimensions array without dimensions that are + // present in input vector + static std::vector evalDimsToExclude(const int rank, const int dimsLen, + const int* dimensions); + static std::vector evalDimsToExclude(const int rank, + const std::vector& dimensions); + + // check whether 2 arrays have mutually broadcastable shapes + // shape comparison starts from the end + static bool areShapesBroadcastable(const NDArray& arr1, const NDArray& arr2); + static bool areShapesBroadcastable(const Nd4jLong* shapeX, + const Nd4jLong* shapeY); + static bool areShapesBroadcastable(const std::vector& shape1, + const std::vector& shape2); + + // check the possibility of broadcast operation, if true then return shapeInfo + // of resulting array if evalMinMax == false then array with larger rank has + // to be passed as first argument + static bool evalBroadcastShapeInfo(const NDArray& max, const NDArray& min, + const bool evalMinMax, + const Nd4jLong*& resultShapeInfo, + sd::memory::Workspace* workspace); + static bool evalBroadcastShapeInfo(const Nd4jLong* max, const Nd4jLong* min, + const bool evalMinMax, + const Nd4jLong*& resultShapeInfo, + sd::memory::Workspace* workspace); + + // evaluate sorted vector of max axes to create tads along in case of simple + // broadcast operation if simple broadcast is not possible then empty vector + // is returned PLEASE NOTE: condition (rank_max >= rank_min) should be + // satisfied ! + static std::vector tadAxesForSimpleBroadcast(const NDArray& max, + const NDArray& min); + + // check the possibility of broadcast operation for set of arrays, if true + // then return resulting broadcasted shapeInfo + static bool evalCommonBroadcastShapeInfo( + const std::vector& arrays, Nd4jLong*& resultShapeInfo, + memory::Workspace* workspace = nullptr); + + // return sorted vector of dimensions common (same) for two arrays, dimensions + // values corresponds to array with bigger rank for example if arr1{2,7}, + // arr2{2,5,4,7} then vector = {0,3} + static std::vector getDimsWithSameShape(const NDArray& max, + const NDArray& min); + + // evaluate shapeInfo for resulting array of tile operation + static const Nd4jLong* evalTileShapeInfo(const NDArray& arr, + const std::vector& reps, + sd::memory::Workspace* workspace); + + // returns shape part of shapeInfo as std::vector + static std::vector pullShapeFromShapeInfo( + const Nd4jLong* shapeInfo); + + static std::string shapeAsString(const NDArray& array); + static std::string shapeAsString(const NDArray* array); + static std::string shapeAsString(const std::vector& shape); + static std::string shapeAsString(const Nd4jLong* shapeInfo); + static std::string shapeAsString(const int rank, const Nd4jLong* shapeInfo); + static std::string strideAsString(const NDArray& array); + static std::string strideAsString(const NDArray* array); + + static std::string shapeInfoAsString(const Nd4jLong* shapeInfo); + + static std::vector shapeAsVector(const Nd4jLong* shapeInfo); + + // evaluate shapeInfo for diagonal array which is made using input arr + // elements as diagonal + static const Nd4jLong* evalDiagShapeInfo(const Nd4jLong* shapeInfo, + sd::memory::Workspace* workspace); + + static std::vector evalBroadcastBackwardAxis(const Nd4jLong* operand, + const Nd4jLong* result); + + // utility to calculate matrix product shape with give source shapes and + // additional params returns ShapeList pointer with result shape + static const Nd4jLong* matrixProductShape(const Nd4jLong* theFirstShape, + const Nd4jLong* theSecondShape, + bool shouldTranspondFirst, + bool shouldTranspondSecond, + sd::DataType dtype, + sd::memory::Workspace* workspace); + + /** + * This method evaluates permutation vector necessary for reducing of + * shapeFrom to shapeTo if shapeFrom is identical to shapeTo (permutation is + * unnecessary) then empty vector is returned in case of permutation is + * impossible an exception is thrown + */ + static std::vector evalPermutFromTo( + const std::vector& shapeFrom, + const std::vector& shapeTo); + + /** + * This method composes shape (shape only, not whole shapeInfo!) using + * dimensions values and corresponding indexes, please note: the size of input + * vector dimsAndIdx must always be even, since the numbers of dimensions and + * indexes are the same, for example if dimsAndIdx = {dimC,dimB,dimA, 2,1,0} + * then output vector = {dimA,dimB,dimC} + */ + static std::vector composeShapeUsingDimsAndIdx( + const std::vector& dimsAndIdx); + + /** + * x * y = c, evaluate shape for array resulting from mmul operation + * possible cases: dot product (xRank=yRank=1), matrix-vector product + * (xRank=2, yRank=1), vector-matrix product (xRank=1, yRank=2), matrix-matrix + * product (xRank=yRank and rank >=2) + */ + static std::vector evalShapeForMatmul(const Nd4jLong* xShapeInfo, + const Nd4jLong* yShapeInfo, + const bool transX, + const bool transY); + + /** + * evaluate number of sub-arrays along dimensions stored in dimsToExclude + * i.e. if shape is [2,3,4,5] and dimsToExclude={0,2}, then number of + * sub-arrays = 8 + */ + static Nd4jLong getNumOfSubArrs(const Nd4jLong* shapeInfo, + const std::vector& dimsToExclude); + + /** + * return shape without unities, for example if shape is [1,2,1,3] then + * [2,3] will be returned if unities are not present in given shapeInfo then + * exactly identical shape will be returned, for example [2,3] -> [2,3] edge + * case: if given shape is [1,1,1,...,1] (all dims are unities) then output + * will be empty and means scalar + */ + static std::vector evalDimsWithoutUnities( + const Nd4jLong* shapeInfo); + + /** + * method returns false if permut == {0,1,2,...permut.size()-1} - in that + * case permutation is unnecessary + */ + FORCEINLINE static bool isPermutNecessary(const std::vector& permut); + + /** + * calculates strides using "dest" shape and given "order", also copies data + * type from "source" to "dest" + */ + static void updateStridesAndType(Nd4jLong* dest, const Nd4jLong* source, + const char order); + + /** + * calculates strides using "dest" shape and "order", also set "dtype" into + * "dest" + */ + static void updateStridesAndType(Nd4jLong* dest, const DataType dtype, + const char order); + + /** + * This method retuns number of bytes required for string tensor + * @param numStrings + * @return + */ + static FORCEINLINE Nd4jLong + stringBufferHeaderRequirements(Nd4jLong numStrings) { + // we store +1 offset + return (numStrings + 1) * sizeof(Nd4jLong); + } + + /** + * This method selects strides based on dimentions required for broadcasting + * @param const pointer to input (Y) shape info for strides selection + * @param rank of input (X) to broadcasting + * @param dimentions size + * @param const pointer to dimentions for broadcasting + * @param pointer to output strides have to be pre allocated by 0 + * @return + */ + static void copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, + const int nRank, + const int dimsSize, + const int* dims, + Nd4jLong* outStrides); + + /* + * check whether arr1/arr2 is sub-array of arr2/arr1, + * this method do not evaluate what array is sub-array, it returns true if arr1 + is sub-array of arr2 or arr2 is sub-array of arr1 + * sameDims is filled (and sorted) with dimensions values that match both in + arr1 and arr2 shapes (unities are ignored) + * for example: + * if arr1{2,3} and arr2{2,4,3,7} then return true and sameDims contains {0,2} + * if arr1{1,1,3,1,3,1,1} and arr2{1,2,3,1,3} then return true and sameDims + contains {2,4} + * if arr1{2,1,4,1,7,5} and arr2{1,1,4,5} then return true and sameDims + contains {2,5} + + static bool isSubArrayCase(const NDArray& arr1, const NDArray& arr2, + std::vector& sameDims); + */ + + /* + * comparing of shapes, not strides + */ + static bool areShapesEqual(const Nd4jLong* shapeInfo, + const std::vector& shapeOnly); +}; ////////////////////////////////////////////////////////////////////////// ///// IMLEMENTATION OF INLINE METHODS ///// ////////////////////////////////////////////////////////////////////////// FORCEINLINE bool ShapeUtils::isPermutNecessary(const std::vector& permut) { + for (int i = 0; i < permut.size(); ++i) + if (permut[i] != i) return true; - for(int i=0; i + #include #include -#include /** - * This class provides PRIMITIVE read-write lock, and should NOT be used outside of GraphServer due to its inefficiency. - * However, since GraphServer isn't supposed to have Reads/Writes ration even close to 1.0, it'll work just fine. + * This class provides PRIMITIVE read-write lock, and should NOT be used outside + * of GraphServer due to its inefficiency. However, since GraphServer isn't + * supposed to have Reads/Writes ration even close to 1.0, it'll work just fine. * * Basic idea: write lock won't be obtained before all read requests served */ namespace sd { - class SD_EXPORT SimpleReadWriteLock { - private: - std::atomic _read_locks; - std::atomic _write_locks; - std::mutex _mutex; - - public: - - explicit SimpleReadWriteLock(); - SimpleReadWriteLock(const SimpleReadWriteLock& other); - ~SimpleReadWriteLock() = default; - - // read lock - void lockRead(); - void unlockRead(); - - // write lock - void lockWrite(); - void unlockWrite(); - - SimpleReadWriteLock& operator= ( const SimpleReadWriteLock &other); - }; -} - - -#endif //SD_READWRITELOCK_H +class SD_EXPORT SimpleReadWriteLock { + private: + std::atomic _read_locks; + std::atomic _write_locks; + std::mutex _mutex; + + public: + explicit SimpleReadWriteLock(); + SimpleReadWriteLock(const SimpleReadWriteLock& other); + ~SimpleReadWriteLock() = default; + + // read lock + void lockRead(); + void unlockRead(); + + // write lock + void lockWrite(); + void unlockWrite(); + + SimpleReadWriteLock& operator=(const SimpleReadWriteLock& other); +}; +} // namespace sd + +#endif // SD_READWRITELOCK_H diff --git a/libnd4j/include/helpers/StringUtils.h b/libnd4j/include/helpers/StringUtils.h index fe96e0287da9..eac8b90116b7 100644 --- a/libnd4j/include/helpers/StringUtils.h +++ b/libnd4j/include/helpers/StringUtils.h @@ -23,122 +23,126 @@ #ifndef LIBND4J_STRINGUTILS_H #define LIBND4J_STRINGUTILS_H -#include +#include +#include #include -#include +#include + #include +#include #include -#include -#include namespace sd { - class SD_EXPORT StringUtils { - public: - template - static FORCEINLINE std::string valueToString(T value) { - std::ostringstream os; - - os << value ; - - //convert the string stream into a string and return - return os.str(); - } - - /** - * This method just concatenates error message with a given graphId - * @param message - * @param graphId - * @return - */ - static FORCEINLINE std::string buildGraphErrorMessage(const char *message, Nd4jLong graphId) { - std::string result(message); - result += " ["; - result += valueToString(graphId); - result += "]"; - - return result; - } - - /** - * This method returns number of needle matches within haystack - * PLEASE NOTE: this method operates on 8-bit arrays interpreted as uint8 - * - * @param haystack - * @param haystackLength - * @param needle - * @param needleLength - * @return - */ - static uint64_t countSubarrays(const void *haystack, uint64_t haystackLength, const void *needle, uint64_t needleLength); - - /** - * This method returns number of bytes used for string NDArrays content - * PLEASE NOTE: this doesn't include header - * - * @param array - * @return - */ - static uint64_t byteLength(const NDArray &array); - - /** - * This method splits a string into substring by delimiter - * - * @param haystack - * @param delimiter - * @return - */ - static std::vector split(const std::string &haystack, const std::string &delimiter); - - - /** - * This method convert u8 string to u16 - * @param const reference to input string - * @param reference to output u16string - * @return boolean status - */ - static bool u8StringToU16String(const std::string& u8, std::u16string& u16); - - /** - * This method convert u8 string to u32 - * @param const reference to input string - * @param reference to output u32string - * @return boolean status - */ - static bool u8StringToU32String(const std::string& u8, std::u32string& u32); - - /** - * This method convert u16 string to u32 - * @param const reference to input u16string - * @param reference to output u32string - * @return boolean status - */ - static bool u16StringToU32String(const std::u16string& u16, std::u32string& u32); - - /** - * This method convert u16 string to u8 string - * @param const reference to input u16string - * @param reference to output string - * @return boolean status - */ - static bool u16StringToU8String(const std::u16string& u16, std::string& u8); - - /** - * This method convert u32 string to u16 string - * @param const reference to input u32string - * @param reference to output u16string - * @return boolean status - */ - static bool u32StringToU16String(const std::u32string& u32, std::u16string& u16); - - /** - * This method convert u32 string to u8 string - * @param const reference to input u32string - * @param reference to output string - * @return boolean status - */ - static bool u32StringToU8String(const std::u32string& u32, std::string& u8); - }; -} - - -#endif //LIBND4J_STRINGUTILS_H +class SD_EXPORT StringUtils { + public: + template + static FORCEINLINE std::string valueToString(T value) { + std::ostringstream os; + + os << value; + + // convert the string stream into a string and return + return os.str(); + } + + /** + * This method just concatenates error message with a given graphId + * @param message + * @param graphId + * @return + */ + static FORCEINLINE std::string buildGraphErrorMessage(const char* message, + Nd4jLong graphId) { + std::string result(message); + result += " ["; + result += valueToString(graphId); + result += "]"; + + return result; + } + + /** + * This method returns number of needle matches within haystack + * PLEASE NOTE: this method operates on 8-bit arrays interpreted as uint8 + * + * @param haystack + * @param haystackLength + * @param needle + * @param needleLength + * @return + */ + static uint64_t countSubarrays(const void* haystack, uint64_t haystackLength, + const void* needle, uint64_t needleLength); + + /** + * This method returns number of bytes used for string NDArrays content + * PLEASE NOTE: this doesn't include header + * + * @param array + * @return + */ + static uint64_t byteLength(const NDArray& array); + + /** + * This method splits a string into substring by delimiter + * + * @param haystack + * @param delimiter + * @return + */ + static std::vector split(const std::string& haystack, + const std::string& delimiter); + + /** + * This method convert u8 string to u16 + * @param const reference to input string + * @param reference to output u16string + * @return boolean status + */ + static bool u8StringToU16String(const std::string& u8, std::u16string& u16); + + /** + * This method convert u8 string to u32 + * @param const reference to input string + * @param reference to output u32string + * @return boolean status + */ + static bool u8StringToU32String(const std::string& u8, std::u32string& u32); + + /** + * This method convert u16 string to u32 + * @param const reference to input u16string + * @param reference to output u32string + * @return boolean status + */ + static bool u16StringToU32String(const std::u16string& u16, + std::u32string& u32); + + /** + * This method convert u16 string to u8 string + * @param const reference to input u16string + * @param reference to output string + * @return boolean status + */ + static bool u16StringToU8String(const std::u16string& u16, std::string& u8); + + /** + * This method convert u32 string to u16 string + * @param const reference to input u32string + * @param reference to output u16string + * @return boolean status + */ + static bool u32StringToU16String(const std::u32string& u32, + std::u16string& u16); + + /** + * This method convert u32 string to u8 string + * @param const reference to input u32string + * @param reference to output string + * @return boolean status + */ + static bool u32StringToU8String(const std::u32string& u32, std::string& u8); +}; +} // namespace sd + +#endif // LIBND4J_STRINGUTILS_H diff --git a/libnd4j/include/helpers/TAD.h b/libnd4j/include/helpers/TAD.h index cd58e421e5e8..40d811701bce 100644 --- a/libnd4j/include/helpers/TAD.h +++ b/libnd4j/include/helpers/TAD.h @@ -21,239 +21,242 @@ #ifndef LIBND4J_TAD_H #define LIBND4J_TAD_H - #include #include - namespace shape { - /** - * Dimension collapse is an algorithm - * for collapsing singular dimensions. - * This algorithm will adjust the dimensions - * wrt the original. - * - * The algorithm has 3 components: - * trailing ones - * middle ones - * beginning ones - * - * dimensions that are specified to reduce along - * that are singular should be truncated - * - * dimensions that are specified that are singular - * at the beginning should be removed with middle dimensions - * decremented. - * - * For any time there is a no op, a collapse will - * set the first dimension to be -1. - * - * - */ - class TAD { - public: - Nd4jLong tadIndex = 0; - int dimensionLength; - int* dimension = nullptr; - Nd4jLong const* shapeInfo = nullptr; - Nd4jLong* tadOnlyShapeInfo = nullptr; - Nd4jLong numTads = 0; - int tadRank = 0; - Nd4jLong* tadShape = nullptr; - Nd4jLong* tadStride = nullptr; - Nd4jLong* tadOffsets = nullptr; - Nd4jLong tadOffsetForBlock = 0; - int rank = 0; - int numOnes = 0; - //pointers to original - int originalDimensionLength; - int const* originalDimension = nullptr; - Nd4jLong const* originalShapeInfo = nullptr; - bool squeezed = false; - bool newSqueezeDimensions = false; - int numOnesInMiddle = 0; - bool wholeThing = false; - //need to track whether we create a new dimension array or not, we could have just moved the pointer forward - //due to leading ones - bool createdNewDimension = false; - - // special case for CUDA, we're passing in __shared__ memory pointers to be used instead of new/malloc - void *ptrManager = nullptr; - int *ptrOutput = nullptr; - - INLINEDEF bool dimensionsDescending(int rank, int const* dimensions, int length); +/** + * Dimension collapse is an algorithm + * for collapsing singular dimensions. + * This algorithm will adjust the dimensions + * wrt the original. + * + * The algorithm has 3 components: + * trailing ones + * middle ones + * beginning ones + * + * dimensions that are specified to reduce along + * that are singular should be truncated + * + * dimensions that are specified that are singular + * at the beginning should be removed with middle dimensions + * decremented. + * + * For any time there is a no op, a collapse will + * set the first dimension to be -1. + * + * + */ +class TAD { + public: + Nd4jLong tadIndex = 0; + int dimensionLength; + int *dimension = nullptr; + Nd4jLong const *shapeInfo = nullptr; + Nd4jLong *tadOnlyShapeInfo = nullptr; + Nd4jLong numTads = 0; + int tadRank = 0; + Nd4jLong *tadShape = nullptr; + Nd4jLong *tadStride = nullptr; + Nd4jLong *tadOffsets = nullptr; + Nd4jLong tadOffsetForBlock = 0; + int rank = 0; + int numOnes = 0; + // pointers to original + int originalDimensionLength; + int const *originalDimension = nullptr; + Nd4jLong const *originalShapeInfo = nullptr; + bool squeezed = false; + bool newSqueezeDimensions = false; + int numOnesInMiddle = 0; + bool wholeThing = false; + // need to track whether we create a new dimension array or not, we could have + // just moved the pointer forward due to leading ones + bool createdNewDimension = false; + + // special case for CUDA, we're passing in __shared__ memory pointers to be + // used instead of new/malloc + void *ptrManager = nullptr; + int *ptrOutput = nullptr; + + INLINEDEF bool dimensionsDescending(int rank, int const *dimensions, + int length); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF TAD() {} - + INLINEDEF + TAD() { + } #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void setExternalBuffers(void *ptrManager); - - + INLINEDEF void + setExternalBuffers(void *ptrManager); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void setOutputBuffer(int *ptrOutput); + INLINEDEF void + setOutputBuffer(int *ptrOutput); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - /** - * This method is for GPU mostly, it allows to initialize TAD instance with precalculated tadOnlyShapeInfo - */ - INLINEDEF void initWithExternalTAD(Nd4jLong *existingTAD, Nd4jLong *originalShape, int *dimension, int dimensionLength); - - + /** + * This method is for GPU mostly, it allows to initialize TAD instance + * with precalculated tadOnlyShapeInfo + */ + INLINEDEF void + initWithExternalTAD(Nd4jLong *existingTAD, Nd4jLong *originalShape, + int *dimension, int dimensionLength); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void init(Nd4jLong const* shapeInfo,int const* dimension,int dimensionLength); + INLINEDEF void + init(Nd4jLong const *shapeInfo, int const *dimension, + int dimensionLength); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void init(int index, Nd4jLong const* shapeInfo,int const* dimension,int dimensionLength); - - + INLINEDEF void + init(int index, Nd4jLong const *shapeInfo, int const *dimension, + int dimensionLength); - template + template #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void printTADsND(T *x); - - + INLINEDEF void + printTADsND(T *x); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void permuteShapeBufferInPlace(Nd4jLong const* shapeBuffer, int const* rearrange, Nd4jLong *out); + INLINEDEF void + permuteShapeBufferInPlace(Nd4jLong const *shapeBuffer, + int const *rearrange, Nd4jLong *out); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong* permuteShapeBuffer(Nd4jLong const* shapeBuffer, int *rearrange); - - - + INLINEDEF Nd4jLong * + permuteShapeBuffer(Nd4jLong const *shapeBuffer, int *rearrange); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void createTadOnlyShapeInfo(); - + INLINEDEF void + createTadOnlyShapeInfo(); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong lengthPerSlice(Nd4jLong const* shapeBuffer); - + INLINEDEF Nd4jLong + lengthPerSlice(Nd4jLong const *shapeBuffer); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong* tad2Sub(Nd4jLong index); - - + INLINEDEF Nd4jLong * + tad2Sub(Nd4jLong index); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF ~TAD(); - + INLINEDEF ~TAD(); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF int* permuteDims(); - - - /** - * Compute the tad offset given a dimension. - * - * The general pattern for computing a tad offset is as follows: - * Every $STRIDE that was removed (the first dimension) - * do a jump by the major stride of the parent array - * (stride[0] of the parent array) - * - * For example given a c ordered 2,2,3,2 with stride 12,6,2,1 - * A tad of dimension 1 will jump 12 every 6 tads. - * - * You then end up with offsets of: - * 0 - * 1 - * 2 - * 3 - * 4 - * 5 - * 12 - * 13 - * 14 - * 15 - * 16 - * 17 - * - * notice there are 12 tads here. This same incremental jump will happen - * every time. - * Note here that by default the - * stride of element wise stride is used for the hops. - * - * Sometimes a jump doesn't happen. If there are less tads - * than the stride of the dimension you removed, the - * element wise stride will always be used. - * - * For example in a dimension of 0,1, you end up with offsets of: - * 0,1,2,3,4,5 - * - * Given that the inner most stride of the dimensions that was removed (1) - * had a stride of 6, we never need to do a major stride jump. - * - */ + INLINEDEF int * + permuteDims(); + + /** + * Compute the tad offset given a dimension. + * + * The general pattern for computing a tad offset is as follows: + * Every $STRIDE that was removed (the first dimension) + * do a jump by the major stride of the parent array + * (stride[0] of the parent array) + * + * For example given a c ordered 2,2,3,2 with stride 12,6,2,1 + * A tad of dimension 1 will jump 12 every 6 tads. + * + * You then end up with offsets of: + * 0 + * 1 + * 2 + * 3 + * 4 + * 5 + * 12 + * 13 + * 14 + * 15 + * 16 + * 17 + * + * notice there are 12 tads here. This same incremental jump will happen + * every time. + * Note here that by default the + * stride of element wise stride is used for the hops. + * + * Sometimes a jump doesn't happen. If there are less tads + * than the stride of the dimension you removed, the + * element wise stride will always be used. + * + * For example in a dimension of 0,1, you end up with offsets of: + * 0,1,2,3,4,5 + * + * Given that the inner most stride of the dimensions that was removed (1) + * had a stride of 6, we never need to do a major stride jump. + * + */ #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong tadOffset(Nd4jLong index); - + INLINEDEF Nd4jLong + tadOffset(Nd4jLong index); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong* tensorShape(); + INLINEDEF Nd4jLong * + tensorShape(); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong* tad2Sub(Nd4jLong index, void *ptrManager); - + INLINEDEF Nd4jLong * + tad2Sub(Nd4jLong index, void *ptrManager); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void createOffsets(); - + INLINEDEF void + createOffsets(); #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong* shapeInfoOnlyShapeAndStride(); - + INLINEDEF Nd4jLong * + shapeInfoOnlyShapeAndStride(); - - /** - * Length of a tad given - * the shape information - */ + /** + * Length of a tad given + * the shape information + */ #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong tadLength(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength); + INLINEDEF Nd4jLong + tadLength(Nd4jLong const *shapeInfo, int const *dimension, + int dimensionLength); /** * Computes the number @@ -261,42 +264,33 @@ namespace shape { * a given dimension */ #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF Nd4jLong tensorsAlongDimension(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength); - + INLINEDEF Nd4jLong + tensorsAlongDimension(Nd4jLong const *shapeInfo, int const *dimension, + int dimensionLength); #ifdef __CUDACC__ - __host__ __device__ - INLINEDEF void createOffsetForBlock(int blockIdx) { - this->tadOffsetForBlock = this->tadOffset(blockIdx); - } + __host__ __device__ INLINEDEF void createOffsetForBlock(int blockIdx) { + this->tadOffsetForBlock = this->tadOffset(blockIdx); + } #endif - #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - INLINEDEF void collapse(); - }; - + INLINEDEF void + collapse(); +}; - - - - - - - - - //// +//// /* #ifdef __CUDACC__ __host__ __device__ #endif - INLINEDEF TAD::TAD(int tadIndex,Nd4jLong *shapeInfo,int *dimension,int dimensionLength) { - this->tadIndex = tadIndex; - this->init(shapeInfo, dimension, dimensionLength); + INLINEDEF TAD::TAD(int tadIndex,Nd4jLong *shapeInfo,int *dimension,int +dimensionLength) { this->tadIndex = tadIndex; this->init(shapeInfo, dimension, +dimensionLength); } @@ -308,785 +302,796 @@ namespace shape { } */ - INLINEDEF void TAD::setExternalBuffers(void *ptrManager) { - this->ptrManager = ptrManager; - } - - INLINEDEF void TAD::setOutputBuffer(int *ptrOutput) { - this->ptrOutput = ptrOutput; - } - - INLINEDEF void TAD::initWithExternalTAD(Nd4jLong *existingTAD, Nd4jLong *originalShape, int *dimension, int dimensionLength) { - this->tadOnlyShapeInfo = existingTAD; - this->rank = shape::rank(originalShape); - - this->originalShapeInfo = originalShape; - this->originalDimension = dimension; - this->originalDimensionLength = dimensionLength; - - this->shapeInfo = originalShape; - this->dimension = dimension; - this->dimensionLength = dimensionLength; +INLINEDEF void TAD::setExternalBuffers(void *ptrManager) { + this->ptrManager = ptrManager; +} - this->tadShape = shape::shapeOf(existingTAD); - this->tadStride = shape::stride(existingTAD); +INLINEDEF void TAD::setOutputBuffer(int *ptrOutput) { + this->ptrOutput = ptrOutput; +} - Nd4jLong ews = shape::elementWiseStride(originalShape); +INLINEDEF void TAD::initWithExternalTAD(Nd4jLong *existingTAD, + Nd4jLong *originalShape, int *dimension, + int dimensionLength) { + this->tadOnlyShapeInfo = existingTAD; + this->rank = shape::rank(originalShape); + + this->originalShapeInfo = originalShape; + this->originalDimension = dimension; + this->originalDimensionLength = dimensionLength; + + this->shapeInfo = originalShape; + this->dimension = dimension; + this->dimensionLength = dimensionLength; + + this->tadShape = shape::shapeOf(existingTAD); + this->tadStride = shape::stride(existingTAD); + + Nd4jLong ews = shape::elementWiseStride(originalShape); + + this->numTads = shape::length(originalShape) / + shape::length(existingTAD); // this->tensorsAlongDimension(this->shapeInfo, + // this->dimension, + // this->dimensionLength);//shape::length(originalShape) + // / shape::length(existingTAD); + this->wholeThing = this->numTads == 1 || + ((this->dimensionLength == this->rank || + this->numTads == shape::length(this->shapeInfo)) && + ews == 1); +} - this->numTads = shape::length(originalShape) / shape::length(existingTAD); // this->tensorsAlongDimension(this->shapeInfo, this->dimension, this->dimensionLength);//shape::length(originalShape) / shape::length(existingTAD); - this->wholeThing = this->numTads == 1 || ((this->dimensionLength == this->rank || this->numTads == shape::length(this->shapeInfo)) && ews == 1); - } +INLINEDEF void TAD::init(int tadIndex, Nd4jLong const *shapeInfo, + int const *dimension, int dimensionLength) { + this->tadIndex = tadIndex; + this->init(shapeInfo, dimension, dimensionLength); +} - INLINEDEF void TAD::init(int tadIndex, Nd4jLong const* shapeInfo,int const* dimension,int dimensionLength) { - this->tadIndex = tadIndex; - this->init(shapeInfo, dimension, dimensionLength); +INLINEDEF void TAD::init(Nd4jLong const *shapeInfo, int const *dimension, + int dimensionLength) { + this->originalShapeInfo = shapeInfo; + this->originalDimension = dimension; + this->originalDimensionLength = dimensionLength; + // start off as original references + this->shapeInfo = shapeInfo; + this->dimensionLength = dimensionLength; + this->dimension = const_cast(dimension); + this->rank = shape::rank(shapeInfo); + this->numTads = dimensionLength == 0 ? 1 + : this->tensorsAlongDimension( + this->shapeInfo, this->dimension, + this->dimensionLength); + + Nd4jLong ews = shape::elementWiseStride(shapeInfo); + + if (dimensionLength == 0) { + wholeThing = true; + } else if (!shape::isVector(shapeInfo)) { + wholeThing = + this->numTads == + 1 // if number of TADs is 1, we just have input shape == TAD shape + || + ((this->dimensionLength == + this->rank // if number of dimensions is the same as input rank, + // that'll be wholeTad too, but only if EWS==1 (aka - + // not a View) + || (this->numTads == shape::length(shapeInfo) && + shape::order(shapeInfo) == + 'c')) // OR number of tads equals to shapeInfo length AND + // input is in C order. if order is F - we'll have to + // calculate offsets + && + ews == + 1); // as mentioned above - last 2 rules apply only to non-views + } else if (shape::isScalar(shapeInfo)) { + wholeThing = true; + // vector case + } else { + // if(dimensionLength == 1 && shape::shapeOf(shapeInfo)[dimension[0]] == 1) + // { + // if(dimension == 0 && ) { + if (dimensionLength != 0 && dimension != nullptr && + shape::shapeOf(shapeInfo)[dimension[0]] == 1) { + wholeThing = true; } + } +} - INLINEDEF void TAD::init(Nd4jLong const* shapeInfo, int const* dimension,int dimensionLength) { - this->originalShapeInfo = shapeInfo; - this->originalDimension = dimension; - this->originalDimensionLength = dimensionLength; - //start off as original references - this->shapeInfo = shapeInfo; - this->dimensionLength = dimensionLength; - this->dimension = const_cast(dimension); - this->rank = shape::rank(shapeInfo); - this->numTads = dimensionLength == 0 ? 1 : this->tensorsAlongDimension(this->shapeInfo, this->dimension, this->dimensionLength); - - Nd4jLong ews = shape::elementWiseStride(shapeInfo); - - if (dimensionLength == 0) { - wholeThing = true; - } else if(!shape::isVector(shapeInfo)) { - wholeThing = this->numTads == 1 // if number of TADs is 1, we just have input shape == TAD shape - || ((this->dimensionLength == this->rank // if number of dimensions is the same as input rank, that'll be wholeTad too, but only if EWS==1 (aka - not a View) - || (this->numTads == shape::length(shapeInfo) && shape::order(shapeInfo) == 'c')) // OR number of tads equals to shapeInfo length AND input is in C order. if order is F - we'll have to calculate offsets - && ews == 1); // as mentioned above - last 2 rules apply only to non-views - } else if(shape::isScalar(shapeInfo)) { - wholeThing = true; - //vector case - } else { - // if(dimensionLength == 1 && shape::shapeOf(shapeInfo)[dimension[0]] == 1) { - //if(dimension == 0 && ) { - if(dimensionLength != 0 && dimension != nullptr && shape::shapeOf(shapeInfo)[dimension[0]] == 1) { - wholeThing = true; - } - } +template +INLINEDEF void TAD::printTADsND(T *x) { + if (wholeThing) { + for (int i = 0; i < shape::length(tadOnlyShapeInfo); i++) { + printf(" %f ", x[i]); } - - template - INLINEDEF void TAD::printTADsND(T *x) { - if(wholeThing) { - for(int i = 0; i < shape::length(tadOnlyShapeInfo); i++) { - printf(" %f ",x[i]); - } - printf("\n"); - } - else { - for (int i = 0; i < numTads; i++) { - auto offset = tadOffsets[i]; - Nd4jLong shapeIter[MAX_RANK]; - Nd4jLong coord[MAX_RANK]; - int dim; - int rankIter = shape::rank(tadOnlyShapeInfo); - Nd4jLong xStridesIter[MAX_RANK]; - T *xPointer = x + offset; - if (PrepareOneRawArrayIter(rankIter, - shape::shapeOf(tadOnlyShapeInfo), - xPointer, - shape::stride(tadOnlyShapeInfo), - &rankIter, - shapeIter, - &xPointer, - xStridesIter) >= 0) { - ND4J_RAW_ITER_START(dim, shape::rank(tadOnlyShapeInfo), coord, shapeIter); { - /* Process the innermost dimension */ - printf(" %f ",xPointer[0]); - } - ND4J_RAW_ITER_ONE_NEXT(dim, - rankIter, - coord, - shapeIter, - xPointer, - xStridesIter); - printf("\n"); - - } - else { - printf("Unable to prepare array\n"); - } - } + printf("\n"); + } else { + for (int i = 0; i < numTads; i++) { + auto offset = tadOffsets[i]; + Nd4jLong shapeIter[MAX_RANK]; + Nd4jLong coord[MAX_RANK]; + int dim; + int rankIter = shape::rank(tadOnlyShapeInfo); + Nd4jLong xStridesIter[MAX_RANK]; + T *xPointer = x + offset; + if (PrepareOneRawArrayIter(rankIter, shape::shapeOf(tadOnlyShapeInfo), + xPointer, shape::stride(tadOnlyShapeInfo), + &rankIter, shapeIter, &xPointer, + xStridesIter) >= 0) { + ND4J_RAW_ITER_START(dim, shape::rank(tadOnlyShapeInfo), coord, + shapeIter); + { + /* Process the innermost dimension */ + printf(" %f ", xPointer[0]); } - } + ND4J_RAW_ITER_ONE_NEXT(dim, rankIter, coord, shapeIter, xPointer, + xStridesIter); + printf("\n"); - - INLINEDEF void TAD::permuteShapeBufferInPlace(Nd4jLong const* shapeBuffer, int const* rearrange, Nd4jLong* out) { - memcpy(out, shapeBuffer, sizeof(Nd4jLong) * shape::shapeInfoLength(this->rank)); - doPermuteShapeInfo(out, rearrange); - } - - INLINEDEF Nd4jLong* TAD::permuteShapeBuffer(Nd4jLong const* shapeBuffer, int *rearrange) { - int len = shape::shapeInfoLength(this->rank); - Nd4jLong *copy = shape::copyOf(len,shapeBuffer); - doPermuteShapeInfo(copy,rearrange); - return copy; + } else { + printf("Unable to prepare array\n"); + } } + } +} - INLINEDEF bool TAD::dimensionsDescending(int rank, int const* dimensions, int length) { - int desired = rank - 1; - for (int e = length - 1; e >= 0; e--) { - if (dimensions[e] != desired--) - return false; - } - return true; - } +INLINEDEF void TAD::permuteShapeBufferInPlace(Nd4jLong const *shapeBuffer, + int const *rearrange, + Nd4jLong *out) { + memcpy(out, shapeBuffer, + sizeof(Nd4jLong) * shape::shapeInfoLength(this->rank)); + doPermuteShapeInfo(out, rearrange); +} - INLINEDEF void TAD::createTadOnlyShapeInfo() { - this->tadOnlyShapeInfo = this->shapeInfoOnlyShapeAndStride(); - sd::ArrayOptions::setDataType(this->tadOnlyShapeInfo, sd::ArrayOptions::dataType(this->originalShapeInfo)); +INLINEDEF Nd4jLong *TAD::permuteShapeBuffer(Nd4jLong const *shapeBuffer, + int *rearrange) { + int len = shape::shapeInfoLength(this->rank); + Nd4jLong *copy = shape::copyOf(len, shapeBuffer); + doPermuteShapeInfo(copy, rearrange); + return copy; +} - // possible optimization goes here - if (shape::order(this->originalShapeInfo) == 'c' - && shape::strideDescendingCAscendingF(this->originalShapeInfo) - && dimensionsDescending(shape::rank(this->originalShapeInfo), this->originalDimension, this->originalDimensionLength)) { - // for C order, if outer dimensions are used, continuous layout is preserved - this->tadOnlyShapeInfo[shape::shapeInfoLength(this->tadOnlyShapeInfo) - 2] = this->originalShapeInfo[shape::shapeInfoLength(this->originalShapeInfo) - 2]; - } +INLINEDEF bool TAD::dimensionsDescending(int rank, int const *dimensions, + int length) { + int desired = rank - 1; + for (int e = length - 1; e >= 0; e--) { + if (dimensions[e] != desired--) return false; + } + return true; +} - // do not swap order if positive elementwise stride preserved - if (shape::elementWiseStride(this->tadOnlyShapeInfo) >= 1) { - this->tadOnlyShapeInfo[shape::shapeInfoLength(this->tadOnlyShapeInfo) - 1] = shape::order(this->originalShapeInfo); - } +INLINEDEF void TAD::createTadOnlyShapeInfo() { + this->tadOnlyShapeInfo = this->shapeInfoOnlyShapeAndStride(); + sd::ArrayOptions::setDataType( + this->tadOnlyShapeInfo, + sd::ArrayOptions::dataType(this->originalShapeInfo)); + + // possible optimization goes here + if (shape::order(this->originalShapeInfo) == 'c' && + shape::strideDescendingCAscendingF(this->originalShapeInfo) && + dimensionsDescending(shape::rank(this->originalShapeInfo), + this->originalDimension, + this->originalDimensionLength)) { + // for C order, if outer dimensions are used, continuous layout is preserved + this->tadOnlyShapeInfo[shape::shapeInfoLength(this->tadOnlyShapeInfo) - 2] = + this->originalShapeInfo[shape::shapeInfoLength( + this->originalShapeInfo) - + 2]; + } + + // do not swap order if positive elementwise stride preserved + if (shape::elementWiseStride(this->tadOnlyShapeInfo) >= 1) { + this->tadOnlyShapeInfo[shape::shapeInfoLength(this->tadOnlyShapeInfo) - 1] = + shape::order(this->originalShapeInfo); + } + + if (this->tadShape != nullptr) delete[] this->tadShape; + + this->tadShape = shape::shapeOf(this->tadOnlyShapeInfo); + this->tadStride = shape::stride(this->tadOnlyShapeInfo); +} - if (this->tadShape != nullptr) - delete[] this->tadShape; +INLINEDEF Nd4jLong TAD::lengthPerSlice(Nd4jLong const *shapeBuffer) { + int dimension = 0; + Nd4jLong *remove = shape::removeIndex(shape::shapeOf(shapeBuffer), &dimension, + shape::rank(shapeBuffer), 1); + Nd4jLong prod = shape::prodLong(remove, shape::rank(shapeBuffer) - 1); + delete[] remove; + return prod; +} - this->tadShape = shape::shapeOf(this->tadOnlyShapeInfo); - this->tadStride = shape::stride(this->tadOnlyShapeInfo); +INLINEDEF Nd4jLong *TAD::tad2Sub(Nd4jLong index) { + Nd4jLong *shape = shape::shapeOf(shapeInfo); + int rank = shape::rank(shapeInfo); + int leftOverIndexLen = rank - originalDimensionLength; + + Nd4jLong *ret = new Nd4jLong[rank]; + // shape of the tad + Nd4jLong *tadShape = new Nd4jLong[leftOverIndexLen]; + Nd4jLong *leftOverIndexes = new Nd4jLong[leftOverIndexLen]; + Nd4jLong *sub = new Nd4jLong[rank]; + + // indexes not specified in the tad indexes + + // every coordinate starts as zero + memset(ret, 0, shape::shapeInfoByteLength(rank)); + + // find the length of the elements we + // are iterating over + Nd4jLong len = 1; + // left over index cursor for initializing elements + int leftOverIndex = 0; + for (int i = 0; i < rank; i++) { + // look for dimensions NOT found in dimension length (basically compute + // shape - dimension (set difference) + bool found = false; + for (int j = 0; j < originalDimensionLength; j++) { + // skip over specified dimensions when computing left over length + if (i == originalDimension[j]) { + found = true; + break; + } } - INLINEDEF Nd4jLong TAD::lengthPerSlice(Nd4jLong const* shapeBuffer) { - int dimension = 0; - Nd4jLong *remove = shape::removeIndex(shape::shapeOf(shapeBuffer),&dimension,shape::rank(shapeBuffer),1); - Nd4jLong prod = shape::prodLong(remove, shape::rank(shapeBuffer) - 1); - delete[] remove; - return prod; + // add to the indexes that aren't specified as part of the tad dimension + // indexes + if (!found) { + // accumulate the list of indexes left over used for initializing the + // return value + leftOverIndexes[leftOverIndex] = i; + // accumulate the tad shape + tadShape[leftOverIndex] = shape[i]; + // accumulate the length (product) of the indexes that will be iterated + // over + len *= shape[i]; + leftOverIndex++; } + } + // sub for indices + /* int *sub = new int[leftOverIndexLen]; + shape::ind2subOrder(tadShape,index,len,sub); + */ + shape::index2coords(index, leftOverIndexLen, tadShape, sub); - INLINEDEF Nd4jLong* TAD::tad2Sub(Nd4jLong index) { - Nd4jLong *shape = shape::shapeOf(shapeInfo); - int rank = shape::rank(shapeInfo); - int leftOverIndexLen = rank - originalDimensionLength; - - Nd4jLong *ret = new Nd4jLong[rank]; - //shape of the tad - Nd4jLong *tadShape = new Nd4jLong[leftOverIndexLen]; - Nd4jLong *leftOverIndexes = new Nd4jLong[leftOverIndexLen]; - Nd4jLong *sub = new Nd4jLong[rank]; - - //indexes not specified in the tad indexes - - //every coordinate starts as zero - memset(ret,0, shape::shapeInfoByteLength(rank)); - - //find the length of the elements we - //are iterating over - Nd4jLong len = 1; - //left over index cursor for initializing elements - int leftOverIndex = 0; - for(int i = 0; i < rank; i++) { - //look for dimensions NOT found in dimension length (basically compute shape - dimension (set difference) - bool found = false; - for(int j = 0; j < originalDimensionLength; j++) { - //skip over specified dimensions when computing left over length - if(i == originalDimension[j]) { - found = true; - break; - } - - } - - //add to the indexes that aren't specified as part of the tad dimension - //indexes - if(!found) { - //accumulate the list of indexes left over used for initializing the return value - leftOverIndexes[leftOverIndex] = i; - //accumulate the tad shape - tadShape[leftOverIndex] = shape[i]; - //accumulate the length (product) of the indexes that will be iterated over - len *= shape[i]; - leftOverIndex++; - - } - } + for (int i = 0; i < leftOverIndexLen; i++) { + ret[leftOverIndexes[i]] = sub[i]; + } + if (ptrManager == nullptr) { + delete[] tadShape; + delete[] leftOverIndexes; + delete[] sub; + } - //sub for indices - /* int *sub = new int[leftOverIndexLen]; - shape::ind2subOrder(tadShape,index,len,sub); - */ - shape::index2coords(index, leftOverIndexLen,tadShape, sub); + return ret; +} +INLINEDEF TAD::~TAD() { + // we may have just moved the pointer forward, we may not need to delete the + // pointer here + if (originalDimension != this->dimension && createdNewDimension) { + delete[] this->dimension; + } + if (this->originalShapeInfo != this->shapeInfo) { + delete[] this->shapeInfo; + } + if (this->tadOffsets != nullptr) { + delete[] this->tadOffsets; + } + + if (this->tadOnlyShapeInfo != nullptr && + this->tadOnlyShapeInfo != shapeInfo) { + delete[] this->tadOnlyShapeInfo; + } +} - for(int i = 0; i < leftOverIndexLen; i++) { - ret[leftOverIndexes[i]] = sub[i]; - } +INLINEDEF int *TAD::permuteDims() { + // permute dimensions for tad + int dimIdx = 0; + // loop backwards assuming dimension is sorted - if (ptrManager == nullptr) { - delete[] tadShape; - delete[] leftOverIndexes; - delete[] sub; - } + int *permuteDims = new int[shape::rank(shapeInfo)]; - return ret; + for (int i = 0; i < shape::rank(shapeInfo); i++) { + bool found = false; + for (int j = 0; j < originalDimensionLength; j++) { + if (i == originalDimension[j]) { + found = true; + break; + } } + // not found, append it to the end for permute + if (!found) permuteDims[dimIdx++] = i; + } - INLINEDEF TAD::~TAD() { - //we may have just moved the pointer forward, we may not need to delete the pointer here - if(originalDimension != this->dimension && createdNewDimension) { - delete[] this->dimension; - } - if(this->originalShapeInfo != this->shapeInfo) { - delete[] this->shapeInfo; - } - if(this->tadOffsets != nullptr) { - delete[] this->tadOffsets; - } + for (int i = originalDimensionLength - 1; i >= 0; i--) { + permuteDims[dimIdx++] = originalDimension[i]; + } - if(this->tadOnlyShapeInfo != nullptr && this->tadOnlyShapeInfo != shapeInfo) { - delete[] this->tadOnlyShapeInfo; - } - } + /* + for (int i = 0; i < originalDimensionLength; i++) { + permuteDims[i] = originalDimension[i]; + } + */ - INLINEDEF int* TAD::permuteDims() { - //permute dimensions for tad - int dimIdx = 0; - //loop backwards assuming dimension is sorted - - int *permuteDims = new int[shape::rank(shapeInfo)]; - - for(int i = 0; i < shape::rank(shapeInfo); i++) { - bool found = false; - for(int j = 0; j < originalDimensionLength; j++) { - if(i == originalDimension[j]) { - found = true; - break; - } - } - - //not found, append it to the end for permute - if(!found) - permuteDims[dimIdx++] = i; - } + // permute dimensions for tad + return permuteDims; +} +INLINEDEF Nd4jLong TAD::tadOffset(Nd4jLong index) { + if (tadOnlyShapeInfo == nullptr) { + this->createTadOnlyShapeInfo(); + } + if (wholeThing) return index; - for(int i = originalDimensionLength - 1; i >= 0; i--) { - permuteDims[dimIdx++] = originalDimension[i]; - } + if (dimensionLength > 1) { + Nd4jLong *tad2Sub = this->tad2Sub(index, ptrManager); -/* - for (int i = 0; i < originalDimensionLength; i++) { - permuteDims[i] = originalDimension[i]; - } -*/ + Nd4jLong ret = shape::getOffset(shapeInfo, tad2Sub); - //permute dimensions for tad - return permuteDims; + if (ret < 0) { + if (ptrManager == nullptr) delete[] tad2Sub; + return -1; } + if (ptrManager == nullptr) delete[] tad2Sub; + return ret; - INLINEDEF Nd4jLong TAD::tadOffset(Nd4jLong index) { - if(tadOnlyShapeInfo == nullptr) { - this->createTadOnlyShapeInfo(); - } - - if(wholeThing) - return index; - - if(dimensionLength > 1) { - Nd4jLong *tad2Sub = this->tad2Sub(index, ptrManager); + } else { + Nd4jLong *tad2Sub = this->tad2Sub(index, ptrManager); - Nd4jLong ret = shape::getOffset(shapeInfo, tad2Sub); + Nd4jLong ret = shape::getOffset(shapeInfo, tad2Sub); - if(ret < 0) { - if (ptrManager == nullptr) - delete[] tad2Sub; - return -1; - } - if (ptrManager == nullptr) - delete[] tad2Sub; + if (ptrManager == nullptr) delete[] tad2Sub; - return ret; - - } - else { - Nd4jLong *tad2Sub = this->tad2Sub(index, ptrManager); + return ret; + } +} - Nd4jLong ret = shape::getOffset(shapeInfo, tad2Sub); +INLINEDEF Nd4jLong *TAD::tensorShape() { + if (this->tadShape != nullptr) return this->tadShape; - if (ptrManager == nullptr) - delete[] tad2Sub; + Nd4jLong *theShape = shape::shapeOf(shapeInfo); + Nd4jLong *tensorShape = shape::keep(theShape, this->dimension, + dimensionLength, shape::rank(shapeInfo)); + this->tadShape = tensorShape; + this->tadRank = dimensionLength; + return tensorShape; +} - return ret; - } +INLINEDEF Nd4jLong *TAD::tad2Sub(Nd4jLong index, void *ptrManager) { + auto shape = shape::shapeOf(shapeInfo); + int rank = shape::rank(shapeInfo); + int leftOverIndexLen = rank - originalDimensionLength; + Nd4jLong *tadShape; + Nd4jLong *leftOverIndexes; + Nd4jLong *sub; + Nd4jLong *ret; + + ret = new Nd4jLong[rank]; + // shape of the tad + leftOverIndexes = new Nd4jLong[leftOverIndexLen]; + sub = new Nd4jLong[rank]; + tadShape = new Nd4jLong[leftOverIndexLen]; + + // indexes not specified in the tad indexes + + // every coordinate starts as zero + memset(ret, 0, sizeof(Nd4jLong) * rank); + + // find the length of the elements we + // are iterating over + Nd4jLong len = 1; + // left over index cursor for initializing elements + int leftOverIndex = 0; + for (int i = 0; i < rank; i++) { + // look for dimensions NOT found in dimension length (basically compute + // shape - dimension (set difference) + bool found = false; + for (int j = 0; j < originalDimensionLength; j++) { + // skip over specified dimensions when computing left over length + if (i == originalDimension[j]) { + found = true; + break; + } } - - INLINEDEF Nd4jLong* TAD::tensorShape(){ - if(this->tadShape != nullptr) - return this->tadShape; - - Nd4jLong *theShape = shape::shapeOf(shapeInfo); - Nd4jLong *tensorShape = shape::keep(theShape, this->dimension, dimensionLength,shape::rank(shapeInfo)); - this->tadShape = tensorShape; - this->tadRank = dimensionLength; - return tensorShape; + // add to the indexes that aren't specified as part of the tad dimension + // indexes + if (!found) { + // accumulate the list of indexes left over used for initializing the + // return value + leftOverIndexes[leftOverIndex] = i; + // accumulate the tad shape + tadShape[leftOverIndex] = shape[i]; + // accumulate the length (product) of the indexes that will be iterated + // over + leftOverIndex++; + len *= shape[i]; } + } - INLINEDEF Nd4jLong* TAD::tad2Sub(Nd4jLong index, void *ptrManager) { - auto shape = shape::shapeOf(shapeInfo); - int rank = shape::rank(shapeInfo); - int leftOverIndexLen = rank - originalDimensionLength; - Nd4jLong *tadShape; - Nd4jLong *leftOverIndexes; - Nd4jLong *sub; - Nd4jLong *ret; - - ret = new Nd4jLong[rank]; - //shape of the tad - leftOverIndexes = new Nd4jLong[leftOverIndexLen]; - sub = new Nd4jLong[rank]; - tadShape = new Nd4jLong[leftOverIndexLen]; - - //indexes not specified in the tad indexes - - //every coordinate starts as zero - memset(ret,0,sizeof(Nd4jLong) * rank); - - - //find the length of the elements we - //are iterating over - Nd4jLong len = 1; - //left over index cursor for initializing elements - int leftOverIndex = 0; - for(int i = 0; i < rank; i++) { - //look for dimensions NOT found in dimension length (basically compute shape - dimension (set difference) - bool found = false; - for(int j = 0; j < originalDimensionLength; j++) { - //skip over specified dimensions when computing left over length - if(i == originalDimension[j]) { - found = true; - break; - } - - } - - //add to the indexes that aren't specified as part of the tad dimension - //indexes - if(!found) { - //accumulate the list of indexes left over used for initializing the return value - leftOverIndexes[leftOverIndex] = i; - //accumulate the tad shape - tadShape[leftOverIndex] = shape[i]; - //accumulate the length (product) of the indexes that will be iterated over - leftOverIndex++; - len *= shape[i]; - - } - } + // sub for indices + /* int *sub = new int[leftOverIndexLen]; + shape::ind2subOrder(tadShape,index,len,sub); + */ + shape::index2coords(index, leftOverIndexLen, tadShape, sub); + for (int i = 0; i < leftOverIndexLen; i++) { + ret[leftOverIndexes[i]] = sub[i]; + } - //sub for indices - /* int *sub = new int[leftOverIndexLen]; - shape::ind2subOrder(tadShape,index,len,sub); - */ - shape::index2coords(index, leftOverIndexLen,tadShape, sub); + if (ptrManager == nullptr) { + delete[] leftOverIndexes; + delete[] tadShape; + delete[] sub; + } - for(int i = 0; i < leftOverIndexLen; i++) { - ret[leftOverIndexes[i]] = sub[i]; - } + return ret; +} - if (ptrManager == nullptr) { - delete[] leftOverIndexes; - delete[] tadShape; - delete[] sub; - } +INLINEDEF void TAD::createOffsets() { + this->tadOffsets = new Nd4jLong[this->numTads]; + uint nT = this->numTads; - return ret; - } - - INLINEDEF void TAD::createOffsets() { - this->tadOffsets = new Nd4jLong[this->numTads]; - uint nT = this->numTads; + for (uint i = 0; i < nT; i++) this->tadOffsets[i] = this->tadOffset(i); +} - for(uint i = 0; i < nT; i++) - this->tadOffsets[i] = this->tadOffset(i); +INLINEDEF Nd4jLong *TAD::shapeInfoOnlyShapeAndStride() { + // if(wholeThing || (dimensionLength == 1 && dimension[0] == MAX_DIMENSION) || + // shape::isScalar(shapeInfo)) + // return shape::createScalarShapeInfo(); + + // ensure tad shapes get setup right for vectors + if (dimensionLength > 1 && shape::isVector(shapeInfo)) + return shape::copyOf(shape::shapeInfoLength(shape::rank(shapeInfo)), + shapeInfo); + + // case when tad coincides with whole array + if (this->numTads == 1 && + ((shape::rank(originalShapeInfo) == originalDimensionLength) || + originalDimensionLength == 0)) { + // we might have special case here: skipped dimensions might be just full of + // ones + Nd4jLong *ret = shape::copyOf( + shape::shapeInfoLength(shape::rank(shapeInfo)), shapeInfo); + if (shape::isDimPermuted( + dimension, + (Nd4jLong)dimensionLength)) // check whether we need permutation + doPermuteShapeInfo(ret, dimension); + + return ret; + } + + Nd4jLong *theShape = shape::shapeOf(shapeInfo); + int rank = shape::rank(shapeInfo); + + if (dimensionLength == 1) { + if (dimension[0] == 0 && shape::isVector(shapeInfo) && theShape[1] == 1) { + int permuted[2] = {1, 0}; + Nd4jLong *permutedRet2 = shape::permuteShapeBuffer(shapeInfo, permuted); + return permutedRet2; + } else if (dimension[0] == 1 && shape::isVector(shapeInfo) && + theShape[0] == 1) { + Nd4jLong *ret = shape::copyOf( + shape::shapeInfoLength(shape::rank(shapeInfo)), shapeInfo); + return ret; + } else if (shape::shapeOf(shapeInfo)[dimension[0]] == 1) { + Nd4jLong *scalarInfo = shape::createScalarShapeInfo(); + scalarInfo[shape::shapeInfoLength(shape::rank(scalarInfo)) - 3] = + this->tadIndex; + return scalarInfo; } - - - INLINEDEF Nd4jLong* TAD::shapeInfoOnlyShapeAndStride() { - //if(wholeThing || (dimensionLength == 1 && dimension[0] == MAX_DIMENSION) || shape::isScalar(shapeInfo)) - // return shape::createScalarShapeInfo(); - - //ensure tad shapes get setup right for vectors - if(dimensionLength > 1 && shape::isVector(shapeInfo)) - return shape::copyOf(shape::shapeInfoLength(shape::rank(shapeInfo)),shapeInfo); - - // case when tad coincides with whole array - if( this->numTads == 1 && ((shape::rank(originalShapeInfo) == originalDimensionLength) || originalDimensionLength == 0)) { - // we might have special case here: skipped dimensions might be just full of ones - Nd4jLong *ret = shape::copyOf(shape::shapeInfoLength(shape::rank(shapeInfo)), shapeInfo); - if (shape::isDimPermuted(dimension, (Nd4jLong) dimensionLength)) // check whether we need permutation - doPermuteShapeInfo(ret, dimension); - - return ret; - } - - Nd4jLong *theShape = shape::shapeOf(shapeInfo); - int rank = shape::rank(shapeInfo); - - if(dimensionLength == 1) { - if(dimension[0] == 0 && shape::isVector(shapeInfo) && theShape[1] == 1) { - int permuted[2] = {1,0}; - Nd4jLong *permutedRet2 = shape::permuteShapeBuffer(shapeInfo, permuted); - return permutedRet2; - } else if(dimension[0] == 1 && shape::isVector(shapeInfo) && theShape[0] == 1) { - Nd4jLong *ret = shape::copyOf(shape::shapeInfoLength(shape::rank(shapeInfo)),shapeInfo); - return ret; - } - else if(shape::shapeOf(shapeInfo)[dimension[0]] == 1) { - Nd4jLong *scalarInfo = shape::createScalarShapeInfo(); - scalarInfo[shape::shapeInfoLength(shape::rank(scalarInfo)) - 3] = this->tadIndex; - return scalarInfo; - } - } - - Nd4jLong *tensorShape = this->tensorShape(); - int *reverseDimensions = shape::reverseCopy(dimension, dimensionLength); - int *rankRange = shape::range(0, rank); - int *remove = shape::removeIndex(rankRange, dimension, (Nd4jLong) rank, (Nd4jLong) dimensionLength); - //concat is wrong here with the length - int *newPermuteDims = shape::concat(remove,rank - dimensionLength,reverseDimensions,dimensionLength); - - Nd4jLong* permuted = shape::permuteShapeBuffer(shapeInfo,newPermuteDims); - - - Nd4jLong sliceIndex = shape::sliceOffsetForTensor(shape::rank(permuted), - this->tadIndex, - shape::shapeOf(shapeInfo), - tensorShape, - dimensionLength, - dimension, - dimensionLength); - - - - Nd4jLong *ret2 = shape::sliceOfShapeBuffer(sliceIndex, permuted); - Nd4jLong tensorLength = shape::prodLong(tensorShape,tadRank); - - Nd4jLong compLength = shape::isVector(ret2) ? shape::length(ret2) : shape::prodLong(tensorShape,tadRank); - // int temp; - // const bool isLikeVector = shape::isLikeVector(ret2, temp); - - // if(dimensionLength == tadRank && compLength == shape::length(ret2) && !isLikeVector) { - if(dimensionLength == tadRank && compLength == shape::length(ret2)) { - if(dimensionLength == 1 && shape::isVector(ret2) && shape::shapeOf(ret2)[0] == 1) { - //go to the bottom and return ret2 after proper freeing of pointers - //basic idea; we *don't* permute row vectors - } - else if(dimensionLength > 1) { - //permute *then* return ret2 - int *finalPermuteDims = new int[shape::rank(ret2)]; - int forward = 0; - for(int i = shape::rank(ret2) - 1; i >= 0; i--) { - finalPermuteDims[forward++] = i; - } - shape::permuteShapeBufferInPlace(ret2,finalPermuteDims,ret2); - delete[] finalPermuteDims; - - } - - } - else { - Nd4jLong length = tensorLength; - Nd4jLong lengthPerSlice = this->lengthPerSlice(ret2); - if(lengthPerSlice < 1) { - return ret2; - } - - Nd4jLong offset = tadIndex * tensorLength /lengthPerSlice; - if(sliceIndex == 0 && length == lengthPerSlice) { - Nd4jLong *newRet2 = shape::sliceOfShapeBuffer(offset, ret2); - delete[] ret2; - ret2 = newRet2; - int *finalPermuteDims = new int[shape::rank(ret2)]; - int forward = 0; - for(int i = shape::rank(ret2) - 1; i >= 0; i--) { - finalPermuteDims[forward++] = i; - } - // bool isRowVector2 = shape::isRowVector(ret2) && !isLikeVector; - bool isRowVector2 = shape::isRowVector(ret2); - if(isRowVector2 == false) { - shape::permuteShapeBufferInPlace(ret2, finalPermuteDims, ret2); - } - - delete[] finalPermuteDims; - - } - else if(length == lengthPerSlice) { - offset -= shape::slices(ret2) * (offset / shape::slices(ret2)); - Nd4jLong *newRet2 = shape::sliceOfShapeBuffer(offset,ret2); - delete[] ret2; - ret2 = newRet2; - if(dimensionLength == 1 && shape::isVector(ret2) && shape::shapeOf(ret2)[0] == 1) { - //go to the bottom and return ret2 after proper freeing of pointers - //basic idea; we *don't* permute row vectors - } - else { - int *finalPermuteDims = new int[shape::rank(ret2)]; - int forward = 0; - for(int i = shape::rank(ret2) - 1; i >= 0; i--) { - finalPermuteDims[forward++] = i; - } - Nd4jLong *newRet = shape::permuteShapeBuffer(ret2, finalPermuteDims); - delete[] ret2; - delete[] finalPermuteDims; - ret2 = newRet; - - } - - } - else { - //execute final part, note that this is mainly so delete[] gets called - //at the bottom of the method - while(shape::length(ret2) > length) { - auto lengthPerSlice2 = this->lengthPerSlice(ret2); - sliceIndex = sliceOffsetForTensor(sliceIndex,shape::length(ret2),lengthPerSlice2); - sliceIndex -= shape::slices(ret2) * (sliceIndex / shape::slices(ret2)); - auto newRet2 = shape::sliceOfShapeBuffer(sliceIndex,ret2); - delete[] ret2; - ret2 = newRet2; - } - - //don't permute on a row vector - if(dimensionLength == 1 && shape::isVector(ret2) && shape::shapeOf(ret2)[0] == 1) { - //go to the bottom and return ret2 after proper freeing of pointers - //basic idea; we *don't* permute row vectors - } - else if(dimensionLength > 1){ - //permute *then* return ret - int *finalPermuteDims = new int[shape::rank(ret2)]; - int forward = 0; - for(int i = shape::rank(ret2) - 1; i >= 0; i--) { - finalPermuteDims[forward++] = i; - } - auto newPermute = shape::permuteShapeBuffer(ret2,finalPermuteDims); - delete[] ret2; - delete[] finalPermuteDims; - ret2 = newPermute; - } - - } - } - - - delete[] permuted; - delete[] newPermuteDims; - delete[] rankRange; - delete[] remove; - delete[] reverseDimensions; - return ret2; + } + + Nd4jLong *tensorShape = this->tensorShape(); + int *reverseDimensions = shape::reverseCopy(dimension, dimensionLength); + int *rankRange = shape::range(0, rank); + int *remove = shape::removeIndex(rankRange, dimension, (Nd4jLong)rank, + (Nd4jLong)dimensionLength); + // concat is wrong here with the length + int *newPermuteDims = shape::concat(remove, rank - dimensionLength, + reverseDimensions, dimensionLength); + + Nd4jLong *permuted = shape::permuteShapeBuffer(shapeInfo, newPermuteDims); + + Nd4jLong sliceIndex = shape::sliceOffsetForTensor( + shape::rank(permuted), this->tadIndex, shape::shapeOf(shapeInfo), + tensorShape, dimensionLength, dimension, dimensionLength); + + Nd4jLong *ret2 = shape::sliceOfShapeBuffer(sliceIndex, permuted); + Nd4jLong tensorLength = shape::prodLong(tensorShape, tadRank); + + Nd4jLong compLength = shape::isVector(ret2) + ? shape::length(ret2) + : shape::prodLong(tensorShape, tadRank); + // int temp; + // const bool isLikeVector = shape::isLikeVector(ret2, temp); + + // if(dimensionLength == tadRank && compLength == shape::length(ret2) && + // !isLikeVector) { + if (dimensionLength == tadRank && compLength == shape::length(ret2)) { + if (dimensionLength == 1 && shape::isVector(ret2) && + shape::shapeOf(ret2)[0] == 1) { + // go to the bottom and return ret2 after proper freeing of pointers + // basic idea; we *don't* permute row vectors + } else if (dimensionLength > 1) { + // permute *then* return ret2 + int *finalPermuteDims = new int[shape::rank(ret2)]; + int forward = 0; + for (int i = shape::rank(ret2) - 1; i >= 0; i--) { + finalPermuteDims[forward++] = i; + } + shape::permuteShapeBufferInPlace(ret2, finalPermuteDims, ret2); + delete[] finalPermuteDims; } + } else { + Nd4jLong length = tensorLength; + Nd4jLong lengthPerSlice = this->lengthPerSlice(ret2); + if (lengthPerSlice < 1) { + return ret2; + } - INLINEDEF Nd4jLong TAD::tadLength(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength) { - if(dimensionLength == 1) { - return shape::shapeOf(shapeInfo)[dimension[0]]; + Nd4jLong offset = tadIndex * tensorLength / lengthPerSlice; + if (sliceIndex == 0 && length == lengthPerSlice) { + Nd4jLong *newRet2 = shape::sliceOfShapeBuffer(offset, ret2); + delete[] ret2; + ret2 = newRet2; + int *finalPermuteDims = new int[shape::rank(ret2)]; + int forward = 0; + for (int i = shape::rank(ret2) - 1; i >= 0; i--) { + finalPermuteDims[forward++] = i; + } + // bool isRowVector2 = shape::isRowVector(ret2) && !isLikeVector; + bool isRowVector2 = shape::isRowVector(ret2); + if (isRowVector2 == false) { + shape::permuteShapeBufferInPlace(ret2, finalPermuteDims, ret2); + } + + delete[] finalPermuteDims; + + } else if (length == lengthPerSlice) { + offset -= shape::slices(ret2) * (offset / shape::slices(ret2)); + Nd4jLong *newRet2 = shape::sliceOfShapeBuffer(offset, ret2); + delete[] ret2; + ret2 = newRet2; + if (dimensionLength == 1 && shape::isVector(ret2) && + shape::shapeOf(ret2)[0] == 1) { + // go to the bottom and return ret2 after proper freeing of pointers + // basic idea; we *don't* permute row vectors + } else { + int *finalPermuteDims = new int[shape::rank(ret2)]; + int forward = 0; + for (int i = shape::rank(ret2) - 1; i >= 0; i--) { + finalPermuteDims[forward++] = i; } - else { - Nd4jLong ret = 1; - for(int i = 0; i < shape::rank(shapeInfo); i++) { - for(int j = 0; j < dimensionLength; j++) { - if(i == dimension[j]) - ret *= shape::shapeOf(shapeInfo)[dimension[j]]; - } - } - return ret; + Nd4jLong *newRet = shape::permuteShapeBuffer(ret2, finalPermuteDims); + delete[] ret2; + delete[] finalPermuteDims; + ret2 = newRet; + } + + } else { + // execute final part, note that this is mainly so delete[] gets called + // at the bottom of the method + while (shape::length(ret2) > length) { + auto lengthPerSlice2 = this->lengthPerSlice(ret2); + sliceIndex = sliceOffsetForTensor(sliceIndex, shape::length(ret2), + lengthPerSlice2); + sliceIndex -= shape::slices(ret2) * (sliceIndex / shape::slices(ret2)); + auto newRet2 = shape::sliceOfShapeBuffer(sliceIndex, ret2); + delete[] ret2; + ret2 = newRet2; + } + + // don't permute on a row vector + if (dimensionLength == 1 && shape::isVector(ret2) && + shape::shapeOf(ret2)[0] == 1) { + // go to the bottom and return ret2 after proper freeing of pointers + // basic idea; we *don't* permute row vectors + } else if (dimensionLength > 1) { + // permute *then* return ret + int *finalPermuteDims = new int[shape::rank(ret2)]; + int forward = 0; + for (int i = shape::rank(ret2) - 1; i >= 0; i--) { + finalPermuteDims[forward++] = i; } + auto newPermute = shape::permuteShapeBuffer(ret2, finalPermuteDims); + delete[] ret2; + delete[] finalPermuteDims; + ret2 = newPermute; + } } + } + + delete[] permuted; + delete[] newPermuteDims; + delete[] rankRange; + delete[] remove; + delete[] reverseDimensions; + return ret2; +} - - INLINEDEF Nd4jLong TAD::tensorsAlongDimension(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength) { - return shape::length(shapeInfo) / this->tadLength(shapeInfo,dimension,dimensionLength); +INLINEDEF Nd4jLong TAD::tadLength(Nd4jLong const *shapeInfo, + int const *dimension, int dimensionLength) { + if (dimensionLength == 1) { + return shape::shapeOf(shapeInfo)[dimension[0]]; + } else { + Nd4jLong ret = 1; + for (int i = 0; i < shape::rank(shapeInfo); i++) { + for (int j = 0; j < dimensionLength; j++) { + if (i == dimension[j]) ret *= shape::shapeOf(shapeInfo)[dimension[j]]; + } } + return ret; + } +} +INLINEDEF Nd4jLong TAD::tensorsAlongDimension(Nd4jLong const *shapeInfo, + int const *dimension, + int dimensionLength) { + return shape::length(shapeInfo) / + this->tadLength(shapeInfo, dimension, dimensionLength); +} - INLINEDEF void TAD::collapse() { - auto shape = shape::shapeOf(shapeInfo); - //handle negative dimensions/backwards indexing - for(int i = 0; i < dimensionLength; i++) { - if((dimension)[i] < 0) - (dimension)[i] += shape::rank(this->shapeInfo); - } - - this->dimension = new int[dimensionLength]; - memcpy(this->dimension,this->originalDimension, sizeof(int) * dimensionLength); - - //we can drop trailing dimensions where it's all singular for example: - // shape: 4,3,1,2 - //dimension: 0,2 - // the problem for 0,2 is equivalent to: 0 - //the rest of the algorithm handles cases suchas - //shape: 4,1,1,2 - //dimension: 0,1 - //when this happens there are other dimensions (eg: at the end) that matter - int trailingOneDimensions = 0; - //trailing ones - for(int i = dimensionLength - 1; i >= 0; i--) { - if(shape[dimension[i]] != 1) { - break; - } - else if(shape[dimension[i]] == 1) - trailingOneDimensions++; - } +INLINEDEF void TAD::collapse() { + auto shape = shape::shapeOf(shapeInfo); + // handle negative dimensions/backwards indexing + for (int i = 0; i < dimensionLength; i++) { + if ((dimension)[i] < 0) (dimension)[i] += shape::rank(this->shapeInfo); + } + + this->dimension = new int[dimensionLength]; + memcpy(this->dimension, this->originalDimension, + sizeof(int) * dimensionLength); + + // we can drop trailing dimensions where it's all singular for example: + // shape: 4,3,1,2 + // dimension: 0,2 + // the problem for 0,2 is equivalent to: 0 + // the rest of the algorithm handles cases suchas + // shape: 4,1,1,2 + // dimension: 0,1 + // when this happens there are other dimensions (eg: at the end) that matter + int trailingOneDimensions = 0; + // trailing ones + for (int i = dimensionLength - 1; i >= 0; i--) { + if (shape[dimension[i]] != 1) { + break; + } else if (shape[dimension[i]] == 1) + trailingOneDimensions++; + } + + dimensionLength -= trailingOneDimensions; + + int leadingOneDimensions = 0; + // trailing ones + for (int i = 0; i < dimensionLength; i++) { + if (shape[dimension[i]] != 1) { + break; + } else if (shape[dimension[i]] == 1) + leadingOneDimensions++; + } + + // bump the dimension pointer forward for however many leadingones there are + dimension += leadingOneDimensions; + // decrease the dimension length by the amount of leading ones + dimensionLength -= leadingOneDimensions; + + bool preConverged = true; + for (int i = 0; i < dimensionLength; i++) { + if (shape[dimension[i]] == 1) { + preConverged = false; + break; + } + } + + // we took away all the singular dimensions, we can just return + if (preConverged) return; + + // no more singular dimensions specified + bool done = false; + int onesDecrement = 0; + bool changed = false; + while (!done) { + // terminate early: only singular dimensions specified for reduce + if ((dimensionLength) < 1) { + done = true; + // signal as a no op + dimension[0] = -1; + break; + } + // captures intermediary result from the for loop + traceNew(3); - dimensionLength -= trailingOneDimensions; + int intermediaryResult[MAX_RANK]; + for (int i = 0; i < dimensionLength; i++) { + intermediaryResult[i] = (dimension)[i]; + } - int leadingOneDimensions = 0; - //trailing ones - for(int i = 0; i < dimensionLength; i++) { - if(shape[dimension[i]] != 1) { - break; - } - else if(shape[dimension[i]] == 1) - leadingOneDimensions++; + bool oneEncountered = false; + bool nonOneEncountered = false; + bool hitBeginning = false; + // assume intermediate collapsing of dimensions + bool collapseMiddleDimensions = true; + // note here that dimension length MAY end up being zero + for (int i = (dimensionLength)-1; i >= 0; i--) { + if (shape[(dimension)[i]] == 1) { + oneEncountered = true; + // trailing ones + if (!nonOneEncountered) { + // just drop trailing ones + dimensionLength--; + nonOneEncountered = false; + collapseMiddleDimensions = false; + // intermediary result just needs to have the results copied from + // dimension since we're just removing the tail + memcpy(intermediaryResult, dimension, sizeof(int) * dimensionLength); + changed = true; + // break the for loop and force it to go back around starting from the + // new index + break; + } else { + // already decremented all dimensions + // this was a result of hitting beginning ones + // we will only need to loop once + if (i == 0) { + hitBeginning = true; + } + // will need to shift dimensions that aren't trailing ones + // back by onesDecrement + // mark the intermediary result as -1 for non inclusion + intermediaryResult[i] = -1; + onesDecrement++; } + } else { + intermediaryResult[i] = (dimension)[i]; + nonOneEncountered = true; + } + } - //bump the dimension pointer forward for however many leadingones there are - dimension += leadingOneDimensions; - //decrease the dimension length by the amount of leading ones - dimensionLength -= leadingOneDimensions; - - - bool preConverged = true; - for(int i = 0; i < dimensionLength; i++) { - if(shape[dimension[i]] == 1) { - preConverged = false; - break; - } + if (collapseMiddleDimensions && oneEncountered) { + // collapse dimensions + int newIntermediary[MAX_RANK]; + int idx = 0; + for (int i = 0; i < dimensionLength; i++) { + // of note: dimension will decrease by the number of ones encountered + if (intermediaryResult[i] >= 0) { + // dimension 0 doesn't need to be decremented + if (intermediaryResult[i] > 0) + newIntermediary[idx++] = intermediaryResult[i] - onesDecrement; + else + newIntermediary[idx++] = intermediaryResult[i]; } + } - //we took away all the singular dimensions, we can just return - if(preConverged) - return; - - //no more singular dimensions specified - bool done = false; - int onesDecrement = 0; - bool changed = false; - while(!done) { - //terminate early: only singular dimensions specified for reduce - if((dimensionLength) < 1) { - done = true; - //signal as a no op - dimension[0] = -1; - break; - } - //captures intermediary result from the for loop - traceNew(3); - - int intermediaryResult[MAX_RANK]; - for(int i = 0; i < dimensionLength; i++) { - intermediaryResult[i] = (dimension)[i]; - } - - bool oneEncountered = false; - bool nonOneEncountered = false; - bool hitBeginning = false; - //assume intermediate collapsing of dimensions - bool collapseMiddleDimensions = true; - //note here that dimension length MAY end up being zero - for(int i = (dimensionLength) - 1; i >= 0; i--) { - if(shape[(dimension)[i]] == 1) { - oneEncountered = true; - //trailing ones - if(!nonOneEncountered) { - //just drop trailing ones - dimensionLength--; - nonOneEncountered = false; - collapseMiddleDimensions = false; - //intermediary result just needs to have the results copied from dimension since we're just removing the tail - memcpy(intermediaryResult,dimension, sizeof(int) * dimensionLength); - changed = true; - //break the for loop and force it to go back around starting from the new index - break; - } - else { - //already decremented all dimensions - //this was a result of hitting beginning ones - //we will only need to loop once - if(i == 0) { - hitBeginning = true; - } - //will need to shift dimensions that aren't trailing ones - //back by onesDecrement - //mark the intermediary result as -1 for non inclusion - intermediaryResult[i] = -1; - onesDecrement++; - } - } - else { - intermediaryResult[i] = (dimension)[i]; - nonOneEncountered = true; - } - } - - if(collapseMiddleDimensions && oneEncountered) { - //collapse dimensions - int newIntermediary[MAX_RANK]; - int idx = 0; - for(int i = 0; i < dimensionLength; i++) { - //of note: dimension will decrease by the number of ones encountered - if(intermediaryResult[i] >= 0) { - //dimension 0 doesn't need to be decremented - if(intermediaryResult[i] > 0) - newIntermediary[idx++] = intermediaryResult[i] - onesDecrement; - else - newIntermediary[idx++] = intermediaryResult[i]; - } - } - - - //decrement by the number of dimensions where ones appeared - (dimensionLength) -= onesDecrement; - //update to current result - memcpy(dimension,newIntermediary, sizeof(int) * (dimensionLength)); - changed = true; - - } - //converged: no need to change result - else { - //update to current result - memcpy(dimension,intermediaryResult, sizeof(int) * dimensionLength); - } - - //converge when there are no singular dimensions specified in the reduce - done = (!oneEncountered && nonOneEncountered) || hitBeginning; - //delete[] intermediaryResult; - } + // decrement by the number of dimensions where ones appeared + (dimensionLength) -= onesDecrement; + // update to current result + memcpy(dimension, newIntermediary, sizeof(int) * (dimensionLength)); + changed = true; - //nothing changed but need to collapse dimension - if(!changed && this->numOnes > 0) { - for(int i = 0; i < dimensionLength ;i++) { - dimension[i] -= numOnes; - } - } + } + // converged: no need to change result + else { + // update to current result + memcpy(dimension, intermediaryResult, sizeof(int) * dimensionLength); + } + // converge when there are no singular dimensions specified in the reduce + done = (!oneEncountered && nonOneEncountered) || hitBeginning; + // delete[] intermediaryResult; + } + // nothing changed but need to collapse dimension + if (!changed && this->numOnes > 0) { + for (int i = 0; i < dimensionLength; i++) { + dimension[i] -= numOnes; } + } } +} // namespace shape -#endif //LIBND4J_TAD_H +#endif // LIBND4J_TAD_H diff --git a/libnd4j/include/helpers/benchmark/BasicSuit.h b/libnd4j/include/helpers/benchmark/BasicSuit.h index 8e06f4f0e452..dda0593d09d7 100644 --- a/libnd4j/include/helpers/benchmark/BasicSuit.h +++ b/libnd4j/include/helpers/benchmark/BasicSuit.h @@ -22,12 +22,10 @@ #define SD_BASICSUIT_H namespace sd { - class BasicSuit { - protected: +class BasicSuit { + protected: + public: +}; +} // namespace sd - public: - - }; -} - -#endif //SD_BASICSUIT_H \ No newline at end of file +#endif // SD_BASICSUIT_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/BoolParameters.h b/libnd4j/include/helpers/benchmark/BoolParameters.h index 20d547cee591..ce98a1bbac57 100644 --- a/libnd4j/include/helpers/benchmark/BoolParameters.h +++ b/libnd4j/include/helpers/benchmark/BoolParameters.h @@ -22,27 +22,25 @@ #define SD_BOOLPARAMETERS_H #include -#include #include +#include + #include "Parameters.h" #include "ParametersSpace.h" namespace sd { - class BoolParameters : public ParametersSpace { - protected: - - public: - BoolParameters(std::string name) : ParametersSpace() { - _name = name; - } +class BoolParameters : public ParametersSpace { + protected: + public: + BoolParameters(std::string name) : ParametersSpace() { _name = name; } - std::vector evaluate() override { - std::vector result; - result.emplace_back(0); - result.emplace_back(1); - return result; - } - }; -} + std::vector evaluate() override { + std::vector result; + result.emplace_back(0); + result.emplace_back(1); + return result; + } +}; +} // namespace sd -#endif //SD_PARAMETERSPACE_H \ No newline at end of file +#endif // SD_PARAMETERSPACE_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h b/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h index 2429df5441ec..0d5af63a5ecb 100644 --- a/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h +++ b/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h @@ -19,94 +19,116 @@ // #include "../OpBenchmark.h" +#include #ifndef SD_BROADCASTBENCHMARK_H #define SD_BROADCASTBENCHMARK_H namespace sd { - class SD_EXPORT BroadcastBenchmark : public OpBenchmark { - public: - BroadcastBenchmark() : OpBenchmark() { - // - } - - BroadcastBenchmark(broadcast::Ops op, const std::string &testName, const NDArray &x, const NDArray &y, const NDArray &z, const std::vector &axis) : OpBenchmark(testName, x, y, z, axis) { - _opNum = (int) op; - } - - - BroadcastBenchmark(broadcast::Ops op, const std::string &name, const std::vector &axis) : OpBenchmark() { - _opNum = (int) op; - _testName = name; - _axis = axis; - } - - ~BroadcastBenchmark(){ - // - } - -void executeOnce() override { +class SD_EXPORT BroadcastBenchmark : public OpBenchmark { + public: + BroadcastBenchmark() : OpBenchmark() { + // + } + + BroadcastBenchmark(broadcast::Ops op, const std::string &testName, + const NDArray &x, const NDArray &y, const NDArray &z, + const std::vector &axis) + : OpBenchmark(testName, x, y, z, axis) { + _opNum = (int)op; + } + + BroadcastBenchmark(broadcast::Ops op, const std::string &name, + const std::vector &axis) + : OpBenchmark() { + _opNum = (int)op; + _testName = name; + _axis = axis; + } + + ~BroadcastBenchmark() { + // + } + + void executeOnce() override { PointersManager manager(LaunchContext::defaultContext(), "BroadcastBM"); - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(_x.shapeInfo(), _axis); - auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(_z.shapeInfo(), _axis); - - auto tadOnlyShapeInfo = Environment::getInstance()->isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); - auto tadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); - - auto tadOnlyShapeInfoZ = Environment::getInstance()->isCPU() ? packZ.primaryShapeInfo() : packZ.specialShapeInfo(); - auto tadOffsetsZ = Environment::getInstance()->isCPU() ? packZ.primaryOffsets() : packZ.specialOffsets(); - - NativeOpExecutioner::execBroadcast(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _y.buffer(), _y.shapeInfo(), _y.specialBuffer(), _y.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr, _axis.size(), - /*Nd4jLong **/ tadOnlyShapeInfo, /*Nd4jLong */ tadOffsets, /*Nd4jLong */ tadOnlyShapeInfoZ, /*Nd4jLong */ tadOffsetsZ); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions( + _x.shapeInfo(), _axis); + auto packZ = ConstantTadHelper::getInstance()->tadForDimensions( + _z.shapeInfo(), _axis); + + auto tadOnlyShapeInfo = Environment::getInstance()->isCPU() + ? packX.primaryShapeInfo() + : packX.specialShapeInfo(); + auto tadOffsets = Environment::getInstance()->isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); + + auto tadOnlyShapeInfoZ = Environment::getInstance()->isCPU() + ? packZ.primaryShapeInfo() + : packZ.specialShapeInfo(); + auto tadOffsetsZ = Environment::getInstance()->isCPU() + ? packZ.primaryOffsets() + : packZ.specialOffsets(); + + NativeOpExecutioner::execBroadcast( + LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), + _x.specialBuffer(), _x.specialShapeInfo(), _y.buffer(), _y.shapeInfo(), + _y.specialBuffer(), _y.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), + _z.specialBuffer(), _z.specialShapeInfo(), nullptr, _axis.size(), + /*Nd4jLong **/ tadOnlyShapeInfo, /*Nd4jLong */ tadOffsets, + /*Nd4jLong */ tadOnlyShapeInfoZ, /*Nd4jLong */ tadOffsetsZ); manager.synchronize(); - } - - std::string axis() override { - if (_axis.empty()) - return ""; - else { - std::string result; - for (auto v:_axis) { - auto s = StringUtils::valueToString(v); - result += s; - result += ","; - } - return result; - } - } - - std::string inplace() override { - std::string result; - result += (_x == _z ? "true" : "false"); - return result; - } - - std::string orders() override { - std::string result; - result += _x.ordering(); - result += "/"; - result += _y.ordering(); - result += "/"; - result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); - return result; - } - - std::string strides() override { - std::string result; - result += ShapeUtils::strideAsString(_x); - result += "/"; - result += ShapeUtils::strideAsString(_y); - result += "/"; - result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); - return result; - } - - OpBenchmark* clone() override { - return new BroadcastBenchmark((broadcast::Ops) _opNum, _testName, _x, _y, _z, _axis); - } - }; -} - -#endif //SD_BROADCASTBENCHMARK_H \ No newline at end of file + } + + std::string axis() override { + if (_axis.empty()) + return ""; + else { + std::string result; + for (auto v : _axis) { + auto s = StringUtils::valueToString(v); + result += s; + result += ","; + } + return result; + } + } + + std::string inplace() override { + std::string result; + result += (_x == _z ? "true" : "false"); + return result; + } + + std::string orders() override { + std::string result; + result += _x.ordering(); + result += "/"; + result += _y.ordering(); + result += "/"; + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); + return result; + } + + std::string strides() override { + std::string result; + result += ShapeUtils::strideAsString(_x); + result += "/"; + result += ShapeUtils::strideAsString(_y); + result += "/"; + result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) + : ShapeUtils::strideAsString(_z); + return result; + } + + OpBenchmark *clone() override { + return new BroadcastBenchmark((broadcast::Ops)_opNum, _testName, _x, _y, _z, + _axis); + } +}; +} // namespace sd + +#endif // SD_BROADCASTBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h b/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h index fb414a837520..96f940626138 100644 --- a/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h +++ b/libnd4j/include/helpers/benchmark/DeclarableBenchmark.h @@ -14,7 +14,6 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // Created by raver on 3/2/2019. // @@ -25,153 +24,144 @@ #include #include #include +#include #include #include -#include namespace sd { - class SD_EXPORT DeclarableBenchmark : public OpBenchmark { - protected: - sd::ops::DeclarableOp *_op = nullptr; - sd::graph::Context *_context = nullptr; - public: - DeclarableBenchmark(sd::ops::DeclarableOp &op, std::string name = 0) : OpBenchmark() { - _op = &op; //ops::OpRegistrator::getInstance()->getOperation(op.getOpHash()); - _testName = name; - } - - void setContext(sd::graph::Context *ctx) { - _context = ctx; - } - - std::string axis() override { - return "N/A"; - } - - std::string orders() override { - if(_context != nullptr && _context->isFastPath()){ - auto& ins = _context->fastpath_in(); - std::string s; - for( int i=0; i 0){ - s += "/"; - } - s += ShapeUtils::strideAsString(_context->getNDArray(i).get()); - } - return s; - } - return "N/A"; - } - - std::string strides() override { - if (_context != nullptr && _context->isFastPath()) { - auto& ins = _context->fastpath_in(); - std::string s(""); - for( int i=0; i 0){ - s += "/"; - } - s += ShapeUtils::strideAsString(_context->getNDArray(i).get()); - } - return s; - } else - return "N/A"; +class SD_EXPORT DeclarableBenchmark : public OpBenchmark { + protected: + sd::ops::DeclarableOp *_op = nullptr; + sd::graph::Context *_context = nullptr; + + public: + DeclarableBenchmark(sd::ops::DeclarableOp &op, std::string name = 0) + : OpBenchmark() { + _op = + &op; // ops::OpRegistrator::getInstance()->getOperation(op.getOpHash()); + _testName = name; + } + + void setContext(sd::graph::Context *ctx) { _context = ctx; } + + std::string axis() override { return "N/A"; } + + std::string orders() override { + if (_context != nullptr && _context->isFastPath()) { + auto &ins = _context->fastpath_in(); + std::string s; + for (int i = 0; i < ins.size(); i++) { + if (i > 0) { + s += "/"; } - - std::string inplace() override { - return "N/A"; + s += ShapeUtils::strideAsString(_context->getNDArray(i).get()); + } + return s; + } + return "N/A"; + } + + std::string strides() override { + if (_context != nullptr && _context->isFastPath()) { + auto &ins = _context->fastpath_in(); + std::string s(""); + for (int i = 0; i < ins.size(); i++) { + if (i > 0) { + s += "/"; } - - void executeOnce() override { - PointersManager pm(LaunchContext::defaultContext(), "DeclarableBenchmark"); - _op->execute(_context); - pm.synchronize(); - } - - OpBenchmark *clone() override { - return new DeclarableBenchmark(*_op, _testName); + s += ShapeUtils::strideAsString(_context->getNDArray(i).get()); + } + return s; + } else + return "N/A"; + } + + std::string inplace() override { return "N/A"; } + + void executeOnce() override { + PointersManager pm(LaunchContext::defaultContext(), "DeclarableBenchmark"); + _op->execute(_context); + pm.synchronize(); + } + + OpBenchmark *clone() override { + return new DeclarableBenchmark(*_op, _testName); + } + + std::string shape() override { + if (_context != nullptr && _context->isFastPath()) { + auto &ins = _context->fastpath_in(); + std::string s; + for (int i = 0; i < ins.size(); i++) { + if (i > 0) { + s += "/"; } - - std::string shape() override { - if (_context != nullptr && _context->isFastPath()) { - auto& ins = _context->fastpath_in(); - std::string s; - for( int i=0; i 0){ - s += "/"; - } - s += ShapeUtils::shapeAsString(_context->getNDArray(i).get()); - } - return s; - } else - return "N/A"; + s += ShapeUtils::shapeAsString(_context->getNDArray(i).get()); + } + return s; + } else + return "N/A"; + } + + std::string dataType() override { + if (_context != nullptr && _context->isFastPath()) { + auto &ins = _context->fastpath_in(); + std::string s; + for (int i = 0; i < ins.size(); i++) { + if (i > 0) { + s += "/"; } - - std::string dataType() override { - if (_context != nullptr && _context->isFastPath()){ - auto& ins = _context->fastpath_in(); - std::string s; - for( int i=0; i 0){ - s += "/"; - } - s += DataTypeUtils::asString(_context->getNDArray(i)->dataType()); - } - return s; - } else - return "N/A"; + s += DataTypeUtils::asString(_context->getNDArray(i)->dataType()); + } + return s; + } else + return "N/A"; + } + + std::string extra() override { + if (_context != nullptr) { + auto iargs = _context->getIArguments(); + auto targs = _context->getTArguments(); + auto bargs = _context->getBArguments(); + std::string e; + bool any = false; + if (!iargs.empty()) { + e += "iargs=["; + for (int i = 0; i < iargs.size(); i++) { + if (i > 0) e += ","; + e += std::to_string(iargs.at(i)); } - - std::string extra() override { - if(_context != nullptr){ - auto iargs = _context->getIArguments(); - auto targs = _context->getTArguments(); - auto bargs = _context->getBArguments(); - std::string e; - bool any = false; - if(!iargs.empty()){ - e += "iargs=["; - for( int i=0; i 0) - e += ","; - e += std::to_string(iargs.at(i)); - } - e += "]"; - any = true; - } - if(!targs.empty()){ - if(any) - e += ","; - e += "targs=["; - for( int i=0; i 0) - e += ","; - e += std::to_string(targs.at(i)); - } - e += "]"; - any = true; - } - if(!bargs.empty()){ - if(any) - e += ","; - e += "bargs=["; - for( int i=0; i 0) - e += ","; - e += std::to_string(bargs.at(i)); - } - e += "]"; - } - return e; - } - return "N/A"; + e += "]"; + any = true; + } + if (!targs.empty()) { + if (any) e += ","; + e += "targs=["; + for (int i = 0; i < targs.size(); i++) { + if (i > 0) e += ","; + e += std::to_string(targs.at(i)); } - - ~DeclarableBenchmark() { - if (_context != nullptr) - delete _context; + e += "]"; + any = true; + } + if (!bargs.empty()) { + if (any) e += ","; + e += "bargs=["; + for (int i = 0; i < bargs.size(); i++) { + if (i > 0) e += ","; + e += std::to_string(bargs.at(i)); } - }; -} - -#endif //SD_DECLARABLEBENCHMARKS_H \ No newline at end of file + e += "]"; + } + return e; + } + return "N/A"; + } + + ~DeclarableBenchmark() { + if (_context != nullptr) delete _context; + } +}; +} // namespace sd + +#endif // SD_DECLARABLEBENCHMARKS_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/IntParameters.h b/libnd4j/include/helpers/benchmark/IntParameters.h index 3f4e4cc344f8..d630becbe0ae 100644 --- a/libnd4j/include/helpers/benchmark/IntParameters.h +++ b/libnd4j/include/helpers/benchmark/IntParameters.h @@ -22,34 +22,36 @@ #define SD_INTPARAMETERS_H #include -#include #include +#include + #include "Parameters.h" #include "ParametersSpace.h" namespace sd { - class IntParameters : public ParametersSpace { - protected: - int _start; - int _stop; - int _step; - - public: - IntParameters(std::string name, int start, int stop, int step = 1) : ParametersSpace() { - _start = start; - _stop = stop; - _step = step; - _name = name; - } - - std::vector evaluate() override { - std::vector result; - for (int e = _start; e <= _stop; e += _step) { - result.emplace_back(e); - } - return result; - } - }; -} - -#endif //SD_INTPARAMETERS_H \ No newline at end of file +class IntParameters : public ParametersSpace { + protected: + int _start; + int _stop; + int _step; + + public: + IntParameters(std::string name, int start, int stop, int step = 1) + : ParametersSpace() { + _start = start; + _stop = stop; + _step = step; + _name = name; + } + + std::vector evaluate() override { + std::vector result; + for (int e = _start; e <= _stop; e += _step) { + result.emplace_back(e); + } + return result; + } +}; +} // namespace sd + +#endif // SD_INTPARAMETERS_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/IntPowerParameters.h b/libnd4j/include/helpers/benchmark/IntPowerParameters.h index 29667ae43d64..d9bac1cd24df 100644 --- a/libnd4j/include/helpers/benchmark/IntPowerParameters.h +++ b/libnd4j/include/helpers/benchmark/IntPowerParameters.h @@ -22,36 +22,39 @@ #define SD_INTPOWERPARAMETERS_H #include -#include #include +#include + #include "Parameters.h" #include "ParametersSpace.h" namespace sd { - class IntPowerParameters : public ParametersSpace { - protected: - int _base; - int _start; - int _stop; - int _step; - - public: - IntPowerParameters(std::string name, int base, int start, int stop, int step = 1) : ParametersSpace() { - _base = base; - _start = start; - _stop = stop; - _step = step; - _name = name; - } - - std::vector evaluate() override { - std::vector result; - for (int e = _start; e <= _stop; e += _step) { - result.emplace_back(sd::math::nd4j_pow(_base, e)); - } - return result; - } - }; -} - -#endif //SD_INTPOWERPARAMETERS_H \ No newline at end of file +class IntPowerParameters : public ParametersSpace { + protected: + int _base; + int _start; + int _stop; + int _step; + + public: + IntPowerParameters(std::string name, int base, int start, int stop, + int step = 1) + : ParametersSpace() { + _base = base; + _start = start; + _stop = stop; + _step = step; + _name = name; + } + + std::vector evaluate() override { + std::vector result; + for (int e = _start; e <= _stop; e += _step) { + result.emplace_back(sd::math::nd4j_pow(_base, e)); + } + return result; + } +}; +} // namespace sd + +#endif // SD_INTPOWERPARAMETERS_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/MatrixBenchmark.h b/libnd4j/include/helpers/benchmark/MatrixBenchmark.h index 6d040b1fd6d6..c1d81df89df1 100644 --- a/libnd4j/include/helpers/benchmark/MatrixBenchmark.h +++ b/libnd4j/include/helpers/benchmark/MatrixBenchmark.h @@ -18,95 +18,98 @@ // @author raver119@gmail.com // -#include #include +#include #ifndef SD_MATRIXBENCHMARK_H #define SD_MATRIXBENCHMARK_H namespace sd { - class SD_EXPORT MatrixBenchmark : public OpBenchmark { - private: - float _alpha = 1.0f; - float _beta = 0.0f; - bool _tA; - bool _tB; - public: - MatrixBenchmark() : OpBenchmark() { - // - } - - MatrixBenchmark(float alpha, float beta, const std::string &testName, const NDArray &x, const NDArray &y, const NDArray &z) : OpBenchmark(testName, x, y, z) { - _alpha = alpha; - _beta = beta; - _tA = false; - _tB = false; - } - - MatrixBenchmark(float alpha, float beta, bool tA, bool tB, const std::string &name) : OpBenchmark() { - _testName = name; - _alpha = alpha; - _beta = beta; - _tA = tA; - _tB = tB; - } - - ~MatrixBenchmark(){ - // - } - - void executeOnce() override { - auto xT = (_tA ? _x.transpose() : _x); - auto yT = (_tB ? _y.transpose() : _y); - - MmulHelper::mmul(&xT, &yT, &_z, _alpha, _beta); - } - - std::string axis() override { - return "N/A"; - } - - std::string inplace() override { - return "N/A"; - } - - std::string orders() override { - std::string result; - result += _x.ordering(); - result += "/"; - result += _y.ordering(); - result += "/"; - result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); - return result; - } - - std::string strides() override { - std::string result; - result += ShapeUtils::strideAsString(_x); - result += "/"; - result += ShapeUtils::strideAsString(_y); - result += "/"; - result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); - return result; - } - - std::string shape() override { - std::string result; - result += ShapeUtils::shapeAsString(_x); - result += "x"; - result += ShapeUtils::shapeAsString(_y); - result += "="; - result += _z.shapeInfo() == nullptr ? "" : ShapeUtils::shapeAsString(_z); - return result; - } - - OpBenchmark* clone() override { - MatrixBenchmark* mb = new MatrixBenchmark(_alpha, _beta, _testName, _x, _y, _z); - mb->_tA = _tA; - mb->_tB = _tB; - return mb; - } - }; -} - -#endif //SD_SCALARBENCHMARK_H \ No newline at end of file +class SD_EXPORT MatrixBenchmark : public OpBenchmark { + private: + float _alpha = 1.0f; + float _beta = 0.0f; + bool _tA; + bool _tB; + + public: + MatrixBenchmark() : OpBenchmark() { + // + } + + MatrixBenchmark(float alpha, float beta, const std::string &testName, + const NDArray &x, const NDArray &y, const NDArray &z) + : OpBenchmark(testName, x, y, z) { + _alpha = alpha; + _beta = beta; + _tA = false; + _tB = false; + } + + MatrixBenchmark(float alpha, float beta, bool tA, bool tB, + const std::string &name) + : OpBenchmark() { + _testName = name; + _alpha = alpha; + _beta = beta; + _tA = tA; + _tB = tB; + } + + ~MatrixBenchmark() { + // + } + + void executeOnce() override { + auto xT = (_tA ? _x.transpose() : _x); + auto yT = (_tB ? _y.transpose() : _y); + + MmulHelper::mmul(&xT, &yT, &_z, _alpha, _beta); + } + + std::string axis() override { return "N/A"; } + + std::string inplace() override { return "N/A"; } + + std::string orders() override { + std::string result; + result += _x.ordering(); + result += "/"; + result += _y.ordering(); + result += "/"; + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); + return result; + } + + std::string strides() override { + std::string result; + result += ShapeUtils::strideAsString(_x); + result += "/"; + result += ShapeUtils::strideAsString(_y); + result += "/"; + result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) + : ShapeUtils::strideAsString(_z); + return result; + } + + std::string shape() override { + std::string result; + result += ShapeUtils::shapeAsString(_x); + result += "x"; + result += ShapeUtils::shapeAsString(_y); + result += "="; + result += _z.shapeInfo() == nullptr ? "" : ShapeUtils::shapeAsString(_z); + return result; + } + + OpBenchmark *clone() override { + MatrixBenchmark *mb = + new MatrixBenchmark(_alpha, _beta, _testName, _x, _y, _z); + mb->_tA = _tA; + mb->_tB = _tB; + return mb; + } +}; +} // namespace sd + +#endif // SD_SCALARBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h b/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h index 5cc686238c94..bc1b8c053be5 100644 --- a/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h +++ b/libnd4j/include/helpers/benchmark/PairwiseBenchmark.h @@ -26,71 +26,76 @@ using namespace sd::graph; namespace sd { - class SD_EXPORT PairwiseBenchmark : public OpBenchmark { - public: - PairwiseBenchmark() : OpBenchmark() { - // - } - - PairwiseBenchmark(pairwise::Ops op, const std::string &testName, const NDArray &x, const NDArray &y, const NDArray &z) : OpBenchmark(testName, x, y, z) { - _opNum = (int) op; - } - - PairwiseBenchmark(pairwise::Ops op, std::string name) : OpBenchmark() { - _opNum = (int) op; - _testName = name; - } - - ~PairwiseBenchmark(){ - // - } - - void executeOnce() override { - PointersManager manager(LaunchContext::defaultContext(), "PairwiseBM"); - - NativeOpExecutioner::execPairwiseTransform(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _y.buffer(), _y.shapeInfo(), _y.specialBuffer(), _y.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr); - - manager.synchronize(); - } - - std::string axis() override { - return "N/A"; - } - - std::string inplace() override { - std::string result; - result += (_x.platformBuffer() == _y.platformBuffer() ? "x==y" : "x!=y"); - result += "/"; - result += (_x.platformBuffer() == _z.platformBuffer() ? "x==z" : "x!=z"); - result += "/"; - result += (_y.platformBuffer() == _z.platformBuffer() ? "y==z" : "y!=z"); - return result; - } - - std::string orders() override { - std::string result; - result += _x.ordering(); - result += "/"; - result += _y.ordering(); - result += "/"; - result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); - return result; - } - - std::string strides() override { - std::string result; - result += ShapeUtils::strideAsString(_x); - result += "/"; - result += ShapeUtils::strideAsString(_y); - result += "/"; - result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); - return result; - } - - OpBenchmark* clone() override { - return new PairwiseBenchmark((pairwise::Ops) _opNum, _testName, _x, _y, _z); - } - }; -} - -#endif //SD_SCALARBENCHMARK_H \ No newline at end of file +class SD_EXPORT PairwiseBenchmark : public OpBenchmark { + public: + PairwiseBenchmark() : OpBenchmark() { + // + } + + PairwiseBenchmark(pairwise::Ops op, const std::string &testName, + const NDArray &x, const NDArray &y, const NDArray &z) + : OpBenchmark(testName, x, y, z) { + _opNum = (int)op; + } + + PairwiseBenchmark(pairwise::Ops op, std::string name) : OpBenchmark() { + _opNum = (int)op; + _testName = name; + } + + ~PairwiseBenchmark() { + // + } + + void executeOnce() override { + PointersManager manager(LaunchContext::defaultContext(), "PairwiseBM"); + + NativeOpExecutioner::execPairwiseTransform( + LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), + _x.specialBuffer(), _x.specialShapeInfo(), _y.buffer(), _y.shapeInfo(), + _y.specialBuffer(), _y.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), + _z.specialBuffer(), _z.specialShapeInfo(), nullptr); + + manager.synchronize(); + } + + std::string axis() override { return "N/A"; } + + std::string inplace() override { + std::string result; + result += (_x.platformBuffer() == _y.platformBuffer() ? "x==y" : "x!=y"); + result += "/"; + result += (_x.platformBuffer() == _z.platformBuffer() ? "x==z" : "x!=z"); + result += "/"; + result += (_y.platformBuffer() == _z.platformBuffer() ? "y==z" : "y!=z"); + return result; + } + + std::string orders() override { + std::string result; + result += _x.ordering(); + result += "/"; + result += _y.ordering(); + result += "/"; + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); + return result; + } + + std::string strides() override { + std::string result; + result += ShapeUtils::strideAsString(_x); + result += "/"; + result += ShapeUtils::strideAsString(_y); + result += "/"; + result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) + : ShapeUtils::strideAsString(_z); + return result; + } + + OpBenchmark *clone() override { + return new PairwiseBenchmark((pairwise::Ops)_opNum, _testName, _x, _y, _z); + } +}; +} // namespace sd + +#endif // SD_SCALARBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/Parameters.h b/libnd4j/include/helpers/benchmark/Parameters.h index 6810edec2b67..54efd9be2e0d 100644 --- a/libnd4j/include/helpers/benchmark/Parameters.h +++ b/libnd4j/include/helpers/benchmark/Parameters.h @@ -26,27 +26,33 @@ #include namespace sd { - class Parameters { - private: - std::map _intParams; - std::map _boolParams; - std::map> _arrayParams; - public: - Parameters() = default; - - Parameters* addIntParam(std::string string, int param); - Parameters* addIntParam(std::initializer_list strings, std::initializer_list params); - - Parameters* addBoolParam(std::string string, bool param); - Parameters* addBoolParam(std::initializer_list strings, std::initializer_list params); - - Parameters* addArrayParam(std::string string, std::initializer_list param); - Parameters* addArrayParam(std::initializer_list strings, std::initializer_list> params); - - int getIntParam(std::string string) const ; - bool getBoolParam(std::string string) const; - std::vector getArrayParam(std::string string) const; - }; -} - -#endif //SD_PARAMETERS_H \ No newline at end of file +class Parameters { + private: + std::map _intParams; + std::map _boolParams; + std::map> _arrayParams; + + public: + Parameters() = default; + + Parameters* addIntParam(std::string string, int param); + Parameters* addIntParam(std::initializer_list strings, + std::initializer_list params); + + Parameters* addBoolParam(std::string string, bool param); + Parameters* addBoolParam(std::initializer_list strings, + std::initializer_list params); + + Parameters* addArrayParam(std::string string, + std::initializer_list param); + Parameters* addArrayParam( + std::initializer_list strings, + std::initializer_list> params); + + int getIntParam(std::string string) const; + bool getBoolParam(std::string string) const; + std::vector getArrayParam(std::string string) const; +}; +} // namespace sd + +#endif // SD_PARAMETERS_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/ParametersBatch.h b/libnd4j/include/helpers/benchmark/ParametersBatch.h index 7706492183a3..235cd5f6e8b3 100644 --- a/libnd4j/include/helpers/benchmark/ParametersBatch.h +++ b/libnd4j/include/helpers/benchmark/ParametersBatch.h @@ -22,63 +22,60 @@ #define SD_PARAMETERSBATCH_H #include -#include #include +#include + namespace sd { - class ParametersBatch { - protected: - std::vector _spaces; - public: - ParametersBatch() = default; - ParametersBatch(std::initializer_list spaces) { - _spaces = spaces; - } - - ParametersBatch(std::vector spaces) { - _spaces = spaces; - } - - - std::vector parameters() { - std::vector result; - std::vector> vectors; - int totalIterations = 1; - - // hehe - int xCoords[MAX_RANK]; - Nd4jLong xShape[MAX_RANK]; - int xRank = _spaces.size(); - - for (int e = 0; e < _spaces.size(); e++) { - auto space = _spaces[e]; - auto values = space->evaluate(); - vectors.emplace_back(values); - - totalIterations *= values.size(); - xShape[e] = values.size(); - } - - //nd4j_printf("Total Iterations: %i\n", totalIterations); - - for (int i = 0; i < totalIterations; i++) { - if (xRank > 0) - shape::index2coords(i, xRank, xShape, xCoords); - - Parameters params; - for (int j = 0; j < xRank; j++) { - int value = vectors[j][xCoords[j]]; - std::string name = _spaces[j]->name(); - params.addIntParam(name, value); - } - - result.emplace_back(params); - } - - - return result; - } - }; -} - -#endif //SD_PARAMETERSBATCH_H \ No newline at end of file +class ParametersBatch { + protected: + std::vector _spaces; + + public: + ParametersBatch() = default; + ParametersBatch(std::initializer_list spaces) { + _spaces = spaces; + } + + ParametersBatch(std::vector spaces) { _spaces = spaces; } + + std::vector parameters() { + std::vector result; + std::vector> vectors; + int totalIterations = 1; + + // hehe + int xCoords[MAX_RANK]; + Nd4jLong xShape[MAX_RANK]; + int xRank = _spaces.size(); + + for (int e = 0; e < _spaces.size(); e++) { + auto space = _spaces[e]; + auto values = space->evaluate(); + vectors.emplace_back(values); + + totalIterations *= values.size(); + xShape[e] = values.size(); + } + + // nd4j_printf("Total Iterations: %i\n", totalIterations); + + for (int i = 0; i < totalIterations; i++) { + if (xRank > 0) shape::index2coords(i, xRank, xShape, xCoords); + + Parameters params; + for (int j = 0; j < xRank; j++) { + int value = vectors[j][xCoords[j]]; + std::string name = _spaces[j]->name(); + params.addIntParam(name, value); + } + + result.emplace_back(params); + } + + return result; + } +}; +} // namespace sd + +#endif // SD_PARAMETERSBATCH_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/ParametersSpace.h b/libnd4j/include/helpers/benchmark/ParametersSpace.h index a71b03d962d1..42438bfc97c2 100644 --- a/libnd4j/include/helpers/benchmark/ParametersSpace.h +++ b/libnd4j/include/helpers/benchmark/ParametersSpace.h @@ -24,19 +24,18 @@ #include namespace sd { - class ParametersSpace { - protected: - std::string _name; - public: - ParametersSpace() = default; - ~ParametersSpace() = default; - - std::string name() { - return _name; - } - - virtual std::vector evaluate() = 0; - }; -} - -#endif //SD_PARAMETERSPACE_H \ No newline at end of file +class ParametersSpace { + protected: + std::string _name; + + public: + ParametersSpace() = default; + ~ParametersSpace() = default; + + std::string name() { return _name; } + + virtual std::vector evaluate() = 0; +}; +} // namespace sd + +#endif // SD_PARAMETERSPACE_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/PredefinedParameters.h b/libnd4j/include/helpers/benchmark/PredefinedParameters.h index d26b7bf0e7b4..ef4e55dc7868 100644 --- a/libnd4j/include/helpers/benchmark/PredefinedParameters.h +++ b/libnd4j/include/helpers/benchmark/PredefinedParameters.h @@ -24,23 +24,24 @@ #include "ParametersSpace.h" namespace sd { - class PredefinedParameters : public ParametersSpace{ - std::vector _params; - public: - PredefinedParameters(std::string name, std::initializer_list parameters) : ParametersSpace() { - _name = name; - _params = parameters; - } - - PredefinedParameters(std::string name, std::vector parameters) : ParametersSpace() { - _name = name; - _params = parameters; - } - - std::vector evaluate() override { - return _params; - } - }; -} - -#endif //SD_PREDEFINEDPARAMETERS_H \ No newline at end of file +class PredefinedParameters : public ParametersSpace { + std::vector _params; + + public: + PredefinedParameters(std::string name, std::initializer_list parameters) + : ParametersSpace() { + _name = name; + _params = parameters; + } + + PredefinedParameters(std::string name, std::vector parameters) + : ParametersSpace() { + _name = name; + _params = parameters; + } + + std::vector evaluate() override { return _params; } +}; +} // namespace sd + +#endif // SD_PREDEFINEDPARAMETERS_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/ReductionBenchmark.h b/libnd4j/include/helpers/benchmark/ReductionBenchmark.h index 491d0254859f..febbba1cfd0f 100644 --- a/libnd4j/include/helpers/benchmark/ReductionBenchmark.h +++ b/libnd4j/include/helpers/benchmark/ReductionBenchmark.h @@ -20,6 +20,7 @@ #include #include + #include "../OpBenchmark.h" #ifndef SD_REDUCEBENCHMARK_H @@ -28,124 +29,160 @@ using namespace sd::graph; namespace sd { - class SD_EXPORT ReductionBenchmark : public OpBenchmark { - protected: - int _opType; //0=Float, 1=Same - public: - ReductionBenchmark() : OpBenchmark() { - // - } - - ReductionBenchmark(reduce::FloatOps op, const std::string &testName, const NDArray &x, const NDArray &z, std::initializer_list axis) : OpBenchmark(testName, x, z, axis) { - _opNum = (int) op; - _opType = 0; - } - - ReductionBenchmark(reduce::SameOps op, const std::string &testName, const NDArray &x, const NDArray &z, std::initializer_list axis) : OpBenchmark(testName, x, z, axis) { - _opNum = (int) op; - _opType = 1; - } - - - ReductionBenchmark(reduce::FloatOps op) : OpBenchmark() { - _opNum = (int) op; - _opType = 0; - } - - ReductionBenchmark(reduce::FloatOps op, const std::string &testName) : OpBenchmark() { - _opNum = (int) op; - _opType = 0; - _testName = testName; - } - - ReductionBenchmark(reduce::SameOps op) : OpBenchmark() { - _opNum = (int) op; - _opType = 1; - } - - ReductionBenchmark(reduce::SameOps op, const std::string &testName) : OpBenchmark() { - _opNum = (int) op; - _opType = 1; - _testName = testName; - } - - ReductionBenchmark(reduce::FloatOps op, const std::string &testName, const NDArray &x, const NDArray &z, const std::vector &axis) : OpBenchmark(testName ,x, z, axis) { - _opNum = (int) op; - _opType = 0; - } - - ReductionBenchmark(reduce::SameOps op, const std::string &testName, const NDArray &x, const NDArray &z, const std::vector &axis) : OpBenchmark(testName ,x, z, axis) { - _opNum = (int) op; - _opType = 1; - } - - void executeOnce() override { - PointersManager manager(LaunchContext::defaultContext(), "reductionBM"); - - if (_z.isScalar() || _y.shapeInfo() == nullptr) - if (_opType == 0) - NativeOpExecutioner::execReduceFloatScalar(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo()); - else - NativeOpExecutioner::execReduceSameScalar(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo()); - else { - auto pack = ConstantTadHelper::getInstance()->tadForDimensions(_x.shapeInfo(), _axis); - - auto tadOnlyShapeInfo = Environment::getInstance()->isCPU() ? pack.primaryShapeInfo() : pack.specialShapeInfo(); - auto tadOffsets = Environment::getInstance()->isCPU() ? pack.primaryOffsets() : pack.specialOffsets(); - - if (_opType == 0) - NativeOpExecutioner::execReduceFloat(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, tadOffsets); - else - NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, tadOffsets); - } - - manager.synchronize(); - } - - std::string orders() override { - std::string result; - result += _x.ordering(); - result += "/"; - result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); - return result; - } - - std::string strides() override { - std::string result; - result += ShapeUtils::strideAsString(_x); - return result; - } - - std::string inplace() override { - return "n/a"; - } - - ~ReductionBenchmark(){ - // - } - - std::string axis() override { - if (_axis.empty()) - return "ALL"; - else { - std::string result; - for (auto v:_axis) { - auto s = StringUtils::valueToString(v); - result += s; - result += ","; - } - - return result; - } - } - - OpBenchmark* clone() override { - if (_opType == 0) - return new ReductionBenchmark((reduce::FloatOps) _opNum, _testName, _x, _z, _axis); - else - return new ReductionBenchmark((reduce::SameOps) _opNum, _testName, _x, _z, _axis); - } - }; -} - -#endif //SD_SCALARBENCHMARK_H \ No newline at end of file +class SD_EXPORT ReductionBenchmark : public OpBenchmark { + protected: + int _opType; // 0=Float, 1=Same + public: + ReductionBenchmark() : OpBenchmark() { + // + } + + ReductionBenchmark(reduce::FloatOps op, const std::string &testName, + const NDArray &x, const NDArray &z, + std::initializer_list axis) + : OpBenchmark(testName, x, z, axis) { + _opNum = (int)op; + _opType = 0; + } + + ReductionBenchmark(reduce::SameOps op, const std::string &testName, + const NDArray &x, const NDArray &z, + std::initializer_list axis) + : OpBenchmark(testName, x, z, axis) { + _opNum = (int)op; + _opType = 1; + } + + ReductionBenchmark(reduce::FloatOps op) : OpBenchmark() { + _opNum = (int)op; + _opType = 0; + } + + ReductionBenchmark(reduce::FloatOps op, const std::string &testName) + : OpBenchmark() { + _opNum = (int)op; + _opType = 0; + _testName = testName; + } + + ReductionBenchmark(reduce::SameOps op) : OpBenchmark() { + _opNum = (int)op; + _opType = 1; + } + + ReductionBenchmark(reduce::SameOps op, const std::string &testName) + : OpBenchmark() { + _opNum = (int)op; + _opType = 1; + _testName = testName; + } + + ReductionBenchmark(reduce::FloatOps op, const std::string &testName, + const NDArray &x, const NDArray &z, + const std::vector &axis) + : OpBenchmark(testName, x, z, axis) { + _opNum = (int)op; + _opType = 0; + } + + ReductionBenchmark(reduce::SameOps op, const std::string &testName, + const NDArray &x, const NDArray &z, + const std::vector &axis) + : OpBenchmark(testName, x, z, axis) { + _opNum = (int)op; + _opType = 1; + } + + void executeOnce() override { + PointersManager manager(LaunchContext::defaultContext(), "reductionBM"); + + if (_z.isScalar() || _y.shapeInfo() == nullptr) + if (_opType == 0) + NativeOpExecutioner::execReduceFloatScalar( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo()); + else + NativeOpExecutioner::execReduceSameScalar( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo()); + else { + auto pack = ConstantTadHelper::getInstance()->tadForDimensions( + _x.shapeInfo(), _axis); + + auto tadOnlyShapeInfo = Environment::getInstance()->isCPU() + ? pack.primaryShapeInfo() + : pack.specialShapeInfo(); + auto tadOffsets = Environment::getInstance()->isCPU() + ? pack.primaryOffsets() + : pack.specialOffsets(); + + if (_opType == 0) + NativeOpExecutioner::execReduceFloat( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, + tadOffsets); + else + NativeOpExecutioner::execReduceSame( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), nullptr, + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, + tadOffsets); + } + + manager.synchronize(); + } + + std::string orders() override { + std::string result; + result += _x.ordering(); + result += "/"; + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); + return result; + } + + std::string strides() override { + std::string result; + result += ShapeUtils::strideAsString(_x); + return result; + } + + std::string inplace() override { return "n/a"; } + + ~ReductionBenchmark() { + // + } + + std::string axis() override { + if (_axis.empty()) + return "ALL"; + else { + std::string result; + for (auto v : _axis) { + auto s = StringUtils::valueToString(v); + result += s; + result += ","; + } + + return result; + } + } + + OpBenchmark *clone() override { + if (_opType == 0) + return new ReductionBenchmark((reduce::FloatOps)_opNum, _testName, _x, _z, + _axis); + else + return new ReductionBenchmark((reduce::SameOps)_opNum, _testName, _x, _z, + _axis); + } +}; +} // namespace sd + +#endif // SD_SCALARBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/ScalarBenchmark.h b/libnd4j/include/helpers/benchmark/ScalarBenchmark.h index 9b1da8b33d72..97292124f606 100644 --- a/libnd4j/include/helpers/benchmark/ScalarBenchmark.h +++ b/libnd4j/include/helpers/benchmark/ScalarBenchmark.h @@ -25,68 +25,77 @@ using namespace sd::graph; namespace sd { - class SD_EXPORT ScalarBenchmark : public OpBenchmark { - public: - ScalarBenchmark() : OpBenchmark() { - // - } - - ~ScalarBenchmark(){ - - } - - ScalarBenchmark(scalar::Ops op) : OpBenchmark() { - _opNum = (int) op; - } - - ScalarBenchmark(scalar::Ops op, const std::string &testName) : OpBenchmark() { - _opNum = (int) op; - _testName = testName; - } - - ScalarBenchmark(scalar::Ops op, const std::string &testName, const NDArray &x, const NDArray &y, const NDArray &z) : OpBenchmark(testName, x, y, z) { - _opNum = (int) op; - } - - void executeOnce() override { - PointersManager manager(LaunchContext::defaultContext(), "ScalarBM"); - - if (_z.shapeInfo() == nullptr) - NativeOpExecutioner::execScalar(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _y.buffer(), _y.shapeInfo(), _y.specialBuffer(), _y.specialShapeInfo(), nullptr); - else - NativeOpExecutioner::execScalar(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), _y.buffer(), _y.shapeInfo(), _y.specialBuffer(), _y.specialShapeInfo(), nullptr); - - manager.synchronize(); - } - - std::string orders() override { - std::string result; - result += _x.ordering(); - result += "/"; - result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); - return result; - } - - std::string strides() override { - std::string result; - result += ShapeUtils::strideAsString(_x); - result += "/"; - result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); - return result; - } - - std::string axis() override { - return "N/A"; - } - - std::string inplace() override { - return _x == _z ? "true" : "false"; - } - - OpBenchmark* clone() override { - return new ScalarBenchmark((scalar::Ops) _opNum, _testName, _x.shapeInfo() == nullptr ? _x : NDArray(_x.dup()) , _y.shapeInfo() == nullptr ? _y : NDArray(_y.dup()), _z.shapeInfo() == nullptr ? _z : NDArray(_z.dup())); - } - }; -} - -#endif //SD_SCALARBENCHMARK_H \ No newline at end of file +class SD_EXPORT ScalarBenchmark : public OpBenchmark { + public: + ScalarBenchmark() : OpBenchmark() { + // + } + + ~ScalarBenchmark() {} + + ScalarBenchmark(scalar::Ops op) : OpBenchmark() { _opNum = (int)op; } + + ScalarBenchmark(scalar::Ops op, const std::string &testName) : OpBenchmark() { + _opNum = (int)op; + _testName = testName; + } + + ScalarBenchmark(scalar::Ops op, const std::string &testName, const NDArray &x, + const NDArray &y, const NDArray &z) + : OpBenchmark(testName, x, y, z) { + _opNum = (int)op; + } + + void executeOnce() override { + PointersManager manager(LaunchContext::defaultContext(), "ScalarBM"); + + if (_z.shapeInfo() == nullptr) + NativeOpExecutioner::execScalar( + LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), + _x.specialBuffer(), _x.specialShapeInfo(), _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), + _y.buffer(), _y.shapeInfo(), _y.specialBuffer(), + _y.specialShapeInfo(), nullptr); + else + NativeOpExecutioner::execScalar( + LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), + _x.specialBuffer(), _x.specialShapeInfo(), _z.buffer(), + _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), + _y.buffer(), _y.shapeInfo(), _y.specialBuffer(), + _y.specialShapeInfo(), nullptr); + + manager.synchronize(); + } + + std::string orders() override { + std::string result; + result += _x.ordering(); + result += "/"; + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); + return result; + } + + std::string strides() override { + std::string result; + result += ShapeUtils::strideAsString(_x); + result += "/"; + result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) + : ShapeUtils::strideAsString(_z); + return result; + } + + std::string axis() override { return "N/A"; } + + std::string inplace() override { return _x == _z ? "true" : "false"; } + + OpBenchmark *clone() override { + return new ScalarBenchmark( + (scalar::Ops)_opNum, _testName, + _x.shapeInfo() == nullptr ? _x : NDArray(_x.dup()), + _y.shapeInfo() == nullptr ? _y : NDArray(_y.dup()), + _z.shapeInfo() == nullptr ? _z : NDArray(_z.dup())); + } +}; +} // namespace sd + +#endif // SD_SCALARBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/benchmark/TransformBenchmark.h b/libnd4j/include/helpers/benchmark/TransformBenchmark.h index 1476fd95284e..b5f1874737fc 100644 --- a/libnd4j/include/helpers/benchmark/TransformBenchmark.h +++ b/libnd4j/include/helpers/benchmark/TransformBenchmark.h @@ -25,105 +25,123 @@ using namespace sd::graph; namespace sd { - class SD_EXPORT TransformBenchmark : public OpBenchmark { - - protected: - int _opType; // 0=StrictOps, 1=Same, 2=Any, 3=Float - - public: - TransformBenchmark() : OpBenchmark() { - // - } - - TransformBenchmark(int opNum, int opType, const std::string &testName, const NDArray &x, const NDArray &z) : OpBenchmark(testName, x, z) { - _opNum = opNum; - _opType = opType; - } - - TransformBenchmark(transform::StrictOps op, const std::string &testName, const NDArray &x, const NDArray &z) : OpBenchmark(testName, x, z) { - _opNum = (int) op; - _opType = 0; - } - - TransformBenchmark(transform::StrictOps op, const std::string &name) : OpBenchmark() { - _opNum = (int) op; - _opType = 0; - _testName = name; - } - - TransformBenchmark(transform::SameOps op, const std::string &name) : OpBenchmark() { - _opNum = (int) op; - _opType = 1; - _testName = name; - } - - TransformBenchmark(transform::AnyOps op, const std::string &name) : OpBenchmark() { - _opNum = (int) op; - _opType = 2; - _testName = name; - } - - TransformBenchmark(transform::FloatOps op, const std::string &name) : OpBenchmark() { - _opNum = (int) op; - _opType = 3; - _testName = name; - } - - ~TransformBenchmark(){ - - } - - void executeOnce() override { - PointersManager manager(LaunchContext::defaultContext(), "TransformBM"); - - auto z = _z.shapeInfo() == nullptr ? _x : _z; - - switch (_opType) { - case 0: - NativeOpExecutioner::execTransformStrict(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr, nullptr, nullptr); - break; - case 1: - NativeOpExecutioner::execTransformSame(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr, nullptr, nullptr); - break; - case 2: - NativeOpExecutioner::execTransformAny(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr, nullptr, nullptr); - break; - case 3: - NativeOpExecutioner::execTransformFloat(LaunchContext::defaultContext(), _opNum, _x.buffer(), _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), _z.specialShapeInfo(), nullptr, nullptr, nullptr); - break; - } - - manager.synchronize(); - } - - std::string axis() override { - return "N/A"; - } - - std::string orders() override { - std::string result; - result += _x.ordering(); - result += "/"; - result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); - return result; - } - - std::string strides() override { - std::string result; - result += ShapeUtils::strideAsString(_x); - result += "/"; - result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) : ShapeUtils::strideAsString(_z); - return result; - } - - std::string inplace() override { - return _x == _z ? "true" : "false"; - } - - OpBenchmark* clone() override { - return new TransformBenchmark(_opNum, _opType, _testName, _x, _z); - } - }; -} - -#endif //SD_SCALARBENCHMARK_H \ No newline at end of file +class SD_EXPORT TransformBenchmark : public OpBenchmark { + protected: + int _opType; // 0=StrictOps, 1=Same, 2=Any, 3=Float + + public: + TransformBenchmark() : OpBenchmark() { + // + } + + TransformBenchmark(int opNum, int opType, const std::string &testName, + const NDArray &x, const NDArray &z) + : OpBenchmark(testName, x, z) { + _opNum = opNum; + _opType = opType; + } + + TransformBenchmark(transform::StrictOps op, const std::string &testName, + const NDArray &x, const NDArray &z) + : OpBenchmark(testName, x, z) { + _opNum = (int)op; + _opType = 0; + } + + TransformBenchmark(transform::StrictOps op, const std::string &name) + : OpBenchmark() { + _opNum = (int)op; + _opType = 0; + _testName = name; + } + + TransformBenchmark(transform::SameOps op, const std::string &name) + : OpBenchmark() { + _opNum = (int)op; + _opType = 1; + _testName = name; + } + + TransformBenchmark(transform::AnyOps op, const std::string &name) + : OpBenchmark() { + _opNum = (int)op; + _opType = 2; + _testName = name; + } + + TransformBenchmark(transform::FloatOps op, const std::string &name) + : OpBenchmark() { + _opNum = (int)op; + _opType = 3; + _testName = name; + } + + ~TransformBenchmark() {} + + void executeOnce() override { + PointersManager manager(LaunchContext::defaultContext(), "TransformBM"); + + auto z = _z.shapeInfo() == nullptr ? _x : _z; + + switch (_opType) { + case 0: + NativeOpExecutioner::execTransformStrict( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo(), nullptr, nullptr, nullptr); + break; + case 1: + NativeOpExecutioner::execTransformSame( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo(), nullptr, nullptr, nullptr); + break; + case 2: + NativeOpExecutioner::execTransformAny( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo(), nullptr, nullptr, nullptr); + break; + case 3: + NativeOpExecutioner::execTransformFloat( + LaunchContext::defaultContext(), _opNum, _x.buffer(), + _x.shapeInfo(), _x.specialBuffer(), _x.specialShapeInfo(), + _z.buffer(), _z.shapeInfo(), _z.specialBuffer(), + _z.specialShapeInfo(), nullptr, nullptr, nullptr); + break; + } + + manager.synchronize(); + } + + std::string axis() override { return "N/A"; } + + std::string orders() override { + std::string result; + result += _x.ordering(); + result += "/"; + result += _z.shapeInfo() == nullptr ? _x.ordering() : _z.ordering(); + return result; + } + + std::string strides() override { + std::string result; + result += ShapeUtils::strideAsString(_x); + result += "/"; + result += _z.shapeInfo() == nullptr ? ShapeUtils::strideAsString(_x) + : ShapeUtils::strideAsString(_z); + return result; + } + + std::string inplace() override { return _x == _z ? "true" : "false"; } + + OpBenchmark *clone() override { + return new TransformBenchmark(_opNum, _opType, _testName, _x, _z); + } +}; +} // namespace sd + +#endif // SD_SCALARBENCHMARK_H \ No newline at end of file diff --git a/libnd4j/include/helpers/biDiagonalUp.h b/libnd4j/include/helpers/biDiagonalUp.h index aaf64d41def1..79644f350948 100644 --- a/libnd4j/include/helpers/biDiagonalUp.h +++ b/libnd4j/include/helpers/biDiagonalUp.h @@ -21,59 +21,52 @@ #ifndef LIBND4J_BIDIAGONALUP_H #define LIBND4J_BIDIAGONALUP_H -#include #include +#include namespace sd { namespace ops { namespace helpers { - class BiDiagonalUp { - - public: - - NDArray _HHmatrix; // 2D Householder matrix - NDArray _HHbidiag; // vector which contains Householder coefficients - - /** - * constructor - * - * matrix - input matrix expected to be bi-diagonalized, remains unaffected - */ - BiDiagonalUp(const NDArray& matrix); - - /** - * this method evaluates data (coeff, normX, tail) used in Householder transformation - * formula for Householder matrix: P = identity_matrix - coeff * w * w^T - * P * x = [normX, 0, 0 , 0, ...] - * coeff - scalar - * w = [1, w1, w2, w3, ...], "tail" is w except first unity element, that is "tail" = [w1, w2, w3, ...] - * tail and coeff are stored in _HHmatrix - * normX are stored in _HHbidiag - */ - template - void _evalData(); - - void evalData(); - - /** - * this method evaluates product of Householder sequence matrices (transformations) acting on columns - * - * type - type of sequence, type = 'u' (acting on columns) or type = 'v' (acting on rows) - */ - template - HHsequence makeHHsequence_(const char type) const; - - HHsequence makeHHsequence(const char type) const; - + public: + NDArray _HHmatrix; // 2D Householder matrix + NDArray _HHbidiag; // vector which contains Householder coefficients + + /** + * constructor + * + * matrix - input matrix expected to be bi-diagonalized, remains unaffected + */ + BiDiagonalUp(const NDArray& matrix); + + /** + * this method evaluates data (coeff, normX, tail) used in Householder + * transformation formula for Householder matrix: P = identity_matrix - coeff + * * w * w^T P * x = [normX, 0, 0 , 0, ...] coeff - scalar w = [1, w1, w2, w3, + * ...], "tail" is w except first unity element, that is "tail" = [w1, w2, w3, + * ...] tail and coeff are stored in _HHmatrix normX are stored in _HHbidiag + */ + template + void _evalData(); + + void evalData(); + + /** + * this method evaluates product of Householder sequence matrices + * (transformations) acting on columns + * + * type - type of sequence, type = 'u' (acting on columns) or type = 'v' + * (acting on rows) + */ + template + HHsequence makeHHsequence_(const char type) const; + + HHsequence makeHHsequence(const char type) const; }; +} // namespace helpers +} // namespace ops +} // namespace sd - -} -} -} - - -#endif //LIBND4J_BIDIAGONALUP_H +#endif // LIBND4J_BIDIAGONALUP_H diff --git a/libnd4j/include/helpers/cpu/ConstantHelper.cpp b/libnd4j/include/helpers/cpu/ConstantHelper.cpp index 10b8a52c32df..c1a014e5131b 100644 --- a/libnd4j/include/helpers/cpu/ConstantHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantHelper.cpp @@ -21,109 +21,119 @@ #ifndef __CUDABLAS__ -#include #include -#include +#include #include #include +#include + #include namespace sd { - ConstantHelper::ConstantHelper() { - int numDevices = getNumberOfDevices(); - _cache.resize(numDevices); - _counters.resize(numDevices); - for (int e = 0; e < numDevices; e++) { - MAP_IMPL map; - - _cache[e] = map; - _counters[e] = 0L; - } - } - - ConstantHelper* ConstantHelper::getInstance() { - if (!_INSTANCE) - _INSTANCE = new sd::ConstantHelper(); - - return _INSTANCE; - } - - void* ConstantHelper::replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace) { - if (workspace == nullptr) { - auto deviceId = getCurrentDevice(); - _counters[deviceId] += numBytes; - } - - int8_t *ptr = nullptr; - ALLOCATE(ptr, workspace, numBytes, int8_t); - - std::memcpy(ptr, src, numBytes); - return ptr; - } - - int ConstantHelper::getCurrentDevice() { - return AffinityManager::currentDeviceId(); - } - - int ConstantHelper::getNumberOfDevices() { - return AffinityManager::numberOfDevices(); - } - - ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, sd::DataType dataType) { - const auto deviceId = getCurrentDevice(); +ConstantHelper::ConstantHelper() { + int numDevices = getNumberOfDevices(); + _cache.resize(numDevices); + _counters.resize(numDevices); + for (int e = 0; e < numDevices; e++) { + MAP_IMPL map; - // we're locking away cache modification - _mutexHolder.lock(); - - if (_cache[deviceId].count(descriptor) == 0) { - _cache[deviceId][descriptor] = new ConstantHolder(); - } + _cache[e] = map; + _counters[e] = 0L; + } +} - auto holder = _cache[deviceId][descriptor]; +ConstantHelper *ConstantHelper::getInstance() { + if (!_INSTANCE) _INSTANCE = new sd::ConstantHelper(); - // releasing cache lock - _mutexHolder.unlock(); + return _INSTANCE; +} +void *ConstantHelper::replicatePointer(void *src, size_t numBytes, + memory::Workspace *workspace) { + if (workspace == nullptr) { + auto deviceId = getCurrentDevice(); + _counters[deviceId] += numBytes; + } - ConstantDataBuffer* result; + int8_t *ptr = nullptr; + ALLOCATE(ptr, workspace, numBytes, int8_t); - // access to this holder instance is synchronous - holder->mutex()->lock(); + std::memcpy(ptr, src, numBytes); + return ptr; +} - if (holder->hasBuffer(dataType)) - result = holder->getConstantDataBuffer(dataType); - else { - auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType); - auto cbuff = new int8_t[size]; - _counters[deviceId] += size; +int ConstantHelper::getCurrentDevice() { + return AffinityManager::currentDeviceId(); +} - // create buffer with this dtype - if (descriptor.isFloat()) { - BUILD_DOUBLE_SELECTOR(sd::DataType::DOUBLE, dataType, sd::TypeCast::convertGeneric, (nullptr, const_cast(descriptor.floatValues().data()), descriptor.length(), cbuff), (sd::DataType::DOUBLE, double), LIBND4J_TYPES); - } else if (descriptor.isInteger()) { - BUILD_DOUBLE_SELECTOR(sd::DataType::INT64, dataType, sd::TypeCast::convertGeneric, (nullptr, const_cast(descriptor.integerValues().data()), descriptor.length(), cbuff), (sd::DataType::INT64, Nd4jLong), LIBND4J_TYPES); - } +int ConstantHelper::getNumberOfDevices() { + return AffinityManager::numberOfDevices(); +} - ConstantDataBuffer dataBuffer(cbuff, nullptr, descriptor.length(), DataTypeUtils::sizeOf(dataType)); - holder->addBuffer(dataBuffer, dataType); +ConstantDataBuffer *ConstantHelper::constantBuffer( + const ConstantDescriptor &descriptor, sd::DataType dataType) { + const auto deviceId = getCurrentDevice(); + + // we're locking away cache modification + _mutexHolder.lock(); + + if (_cache[deviceId].count(descriptor) == 0) { + _cache[deviceId][descriptor] = new ConstantHolder(); + } + + auto holder = _cache[deviceId][descriptor]; + + // releasing cache lock + _mutexHolder.unlock(); + + ConstantDataBuffer *result; + + // access to this holder instance is synchronous + holder->mutex()->lock(); + + if (holder->hasBuffer(dataType)) + result = holder->getConstantDataBuffer(dataType); + else { + auto size = descriptor.length() * DataTypeUtils::sizeOf(dataType); + auto cbuff = new int8_t[size]; + _counters[deviceId] += size; + + // create buffer with this dtype + if (descriptor.isFloat()) { + BUILD_DOUBLE_SELECTOR( + sd::DataType::DOUBLE, dataType, sd::TypeCast::convertGeneric, + (nullptr, const_cast(descriptor.floatValues().data()), + descriptor.length(), cbuff), + (sd::DataType::DOUBLE, double), LIBND4J_TYPES); + } else if (descriptor.isInteger()) { + BUILD_DOUBLE_SELECTOR( + sd::DataType::INT64, dataType, sd::TypeCast::convertGeneric, + (nullptr, const_cast(descriptor.integerValues().data()), + descriptor.length(), cbuff), + (sd::DataType::INT64, Nd4jLong), LIBND4J_TYPES); + } - result = holder->getConstantDataBuffer(dataType); - } - holder->mutex()->unlock(); + ConstantDataBuffer dataBuffer(cbuff, nullptr, descriptor.length(), + DataTypeUtils::sizeOf(dataType)); + holder->addBuffer(dataBuffer, dataType); - return result; - } + result = holder->getConstantDataBuffer(dataType); + } + holder->mutex()->unlock(); - Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { - int numDevices = getNumberOfDevices(); - if (deviceId > numDevices || deviceId < 0) - return 0L; - else - return _counters[deviceId]; - } + return result; +} - sd::ConstantHelper* sd::ConstantHelper::_INSTANCE = 0; +Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { + int numDevices = getNumberOfDevices(); + if (deviceId > numDevices || deviceId < 0) + return 0L; + else + return _counters[deviceId]; } +sd::ConstantHelper *sd::ConstantHelper::_INSTANCE = 0; +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp index fc8abe8aa0a7..9952c3f02293 100644 --- a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp @@ -21,174 +21,191 @@ #ifndef __CUDABLAS__ #include -#include #include #include +#include namespace sd { - ConstantShapeHelper::ConstantShapeHelper() { - _cache.resize(32); - for (int e = 0; e < 32; e++) { - MAP_IMPL cache; - _cache[e] = cache; - } - } - - ConstantShapeHelper* ConstantShapeHelper::getInstance() { - if (!_INSTANCE) - _INSTANCE = new ConstantShapeHelper(); - - return _INSTANCE; - } - - ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(sd::DataType dataType, char order, const std::vector &shape) { - ShapeDescriptor descriptor(dataType, order, shape); - return bufferForShapeInfo(descriptor); - } - - ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { - ShapeDescriptor descriptor(dataType, order, shape, rank); - return bufferForShapeInfo(descriptor); - } +ConstantShapeHelper::ConstantShapeHelper() { + _cache.resize(32); + for (int e = 0; e < 32; e++) { + MAP_IMPL cache; + _cache[e] = cache; + } +} +ConstantShapeHelper* ConstantShapeHelper::getInstance() { + if (!_INSTANCE) _INSTANCE = new ConstantShapeHelper(); - ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { - int deviceId = 0; + return _INSTANCE; +} - std::lock_guard lock(_mutex); +ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo( + sd::DataType dataType, char order, const std::vector& shape) { + ShapeDescriptor descriptor(dataType, order, shape); + return bufferForShapeInfo(descriptor); +} - if (_cache[deviceId].count(descriptor) == 0) { - auto hPtr = descriptor.toShapeInfo(); - ConstantDataBuffer buffer(hPtr, nullptr, shape::shapeInfoLength(hPtr)*sizeof(Nd4jLong), DataType::INT64); - ShapeDescriptor descriptor1(descriptor); - _cache[deviceId][descriptor1] = buffer; - return _cache[deviceId][descriptor1]; - } else { - return _cache[deviceId].at(descriptor); - } - } +ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo( + const sd::DataType dataType, const char order, const int rank, + const Nd4jLong* shape) { + ShapeDescriptor descriptor(dataType, order, shape, rank); + return bufferForShapeInfo(descriptor); +} - ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { - ShapeDescriptor descriptor(shapeInfo); - return bufferForShapeInfo(descriptor); - } +ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo( + const ShapeDescriptor& descriptor) { + int deviceId = 0; + + std::lock_guard lock(_mutex); + + if (_cache[deviceId].count(descriptor) == 0) { + auto hPtr = descriptor.toShapeInfo(); + ConstantDataBuffer buffer(hPtr, nullptr, + shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), + DataType::INT64); + ShapeDescriptor descriptor1(descriptor); + _cache[deviceId][descriptor1] = buffer; + return _cache[deviceId][descriptor1]; + } else { + return _cache[deviceId].at(descriptor); + } +} - bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) { - bool result; - int deviceId = 0; - std::lock_guard lock(_mutex); +ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo( + const Nd4jLong* shapeInfo) { + ShapeDescriptor descriptor(shapeInfo); + return bufferForShapeInfo(descriptor); +} - return _cache[deviceId].count(descriptor) != 0; - } +bool ConstantShapeHelper::checkBufferExistenceForShapeInfo( + ShapeDescriptor& descriptor) { + bool result; + int deviceId = 0; + std::lock_guard lock(_mutex); - const Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { - ShapeDescriptor descriptor(dataType, order, shape, rank); - return bufferForShapeInfo(descriptor).primaryAsT(); - } + return _cache[deviceId].count(descriptor) != 0; +} - const Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo) { - return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast(shapeInfo))); - } +const Nd4jLong* ConstantShapeHelper::createShapeInfo( + const sd::DataType dataType, const char order, const int rank, + const Nd4jLong* shape) { + ShapeDescriptor descriptor(dataType, order, shape, rank); + return bufferForShapeInfo(descriptor).primaryAsT(); +} - const Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const sd::DataType dataType) { - auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); - return bufferForShapeInfo(descriptor).primaryAsT(); - } +const Nd4jLong* ConstantShapeHelper::createShapeInfo( + const sd::DataType dataType, const Nd4jLong* shapeInfo) { + return ConstantShapeHelper::createShapeInfo( + dataType, shape::order(shapeInfo), shape::rank(shapeInfo), + shape::shapeOf(const_cast(shapeInfo))); +} - const Nd4jLong* ConstantShapeHelper::scalarShapeInfo(const sd::DataType dataType) { - auto descriptor = ShapeDescriptor::scalarDescriptor(dataType); - return bufferForShapeInfo(descriptor).primaryAsT(); - } +const Nd4jLong* ConstantShapeHelper::emptyShapeInfo( + const sd::DataType dataType) { + auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); + return bufferForShapeInfo(descriptor).primaryAsT(); +} - const Nd4jLong* ConstantShapeHelper::vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType) { - auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType); - return bufferForShapeInfo(descriptor).primaryAsT(); - } +const Nd4jLong* ConstantShapeHelper::scalarShapeInfo( + const sd::DataType dataType) { + auto descriptor = ShapeDescriptor::scalarDescriptor(dataType); + return bufferForShapeInfo(descriptor).primaryAsT(); +} - const Nd4jLong* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const std::vector &shape) { - ShapeDescriptor descriptor(dataType, order, shape); - return bufferForShapeInfo(descriptor).primaryAsT(); - } +const Nd4jLong* ConstantShapeHelper::vectorShapeInfo( + const Nd4jLong length, const sd::DataType dataType) { + auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType); + return bufferForShapeInfo(descriptor).primaryAsT(); +} - const Nd4jLong* ConstantShapeHelper::createShapeInfo(const ShapeDescriptor &descriptor) { - return bufferForShapeInfo(descriptor).primaryAsT(); - } +const Nd4jLong* ConstantShapeHelper::createShapeInfo( + const sd::DataType dataType, const char order, + const std::vector& shape) { + ShapeDescriptor descriptor(dataType, order, shape); + return bufferForShapeInfo(descriptor).primaryAsT(); +} - const Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal) { - ShapeDescriptor descriptor(shapeInfo); - auto result = createShapeInfo(descriptor); +const Nd4jLong* ConstantShapeHelper::createShapeInfo( + const ShapeDescriptor& descriptor) { + return bufferForShapeInfo(descriptor).primaryAsT(); +} - if (destroyOriginal) - RELEASE(shapeInfo, nullptr) +const Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong* shapeInfo, + bool destroyOriginal) { + ShapeDescriptor descriptor(shapeInfo); + auto result = createShapeInfo(descriptor); - return result; - } + if (destroyOriginal) RELEASE(shapeInfo, nullptr) - const Nd4jLong* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace) { - ShapeDescriptor descriptor(shapeInfo); - auto result = createShapeInfo(descriptor); + return result; +} - RELEASE(shapeInfo, workspace); +const Nd4jLong* ConstantShapeHelper::createFromExisting( + Nd4jLong* shapeInfo, sd::memory::Workspace* workspace) { + ShapeDescriptor descriptor(shapeInfo); + auto result = createShapeInfo(descriptor); - return result; - } + RELEASE(shapeInfo, workspace); + return result; +} //////////////////////////////////////////////////////////////////////// -ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector &dimensions) { - - Nd4jLong* newShapeInfo = nullptr; - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong); - - newShapeInfo[0] = shape::rank(maxShapeInfo); - - sd::ArrayOptions::copyDataType(newShapeInfo, minShapeInfo); // type - newShapeInfo[2 * newShapeInfo[0] + 2] = shape::elementWiseStride(minShapeInfo); // ews - newShapeInfo[2 * newShapeInfo[0] + 3] = shape::order(minShapeInfo); // order - - if(!dimensions.empty()) { - - for(uint k = 0, j = 0, i = 0; i < shape::rank(maxShapeInfo); ++i) { - - if(j < dimensions.size() && dimensions[j] == i) { - shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[k]; - shape::stride(newShapeInfo)[i] = shape::stride(minShapeInfo)[k++]; - ++j; - } - else{ - shape::shapeOf(newShapeInfo)[i] = 1; - shape::stride(newShapeInfo)[i] = 0; - if(shape::sizeAt(minShapeInfo, k) == 1 && dimensions.size() != shape::rank(minShapeInfo)) - ++k; - } - } - } - else{ - - for(int j = shape::rank(minShapeInfo) - 1, i = shape::rank(maxShapeInfo) - 1; i >=0 ; --i) { - - if(j >= 0) { - shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j]; - shape::stride(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j] == 1 ? 0 : shape::stride(minShapeInfo)[j]; - --j; - } - else { - shape::shapeOf(newShapeInfo)[i] = 1; - shape::stride(newShapeInfo)[i] = 0; - } - } - } - - ShapeDescriptor descriptor(newShapeInfo); - - RELEASE(newShapeInfo, workspace); - - return bufferForShapeInfo(descriptor); +ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast( + const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, + sd::memory::Workspace* workspace, const std::vector& dimensions) { + Nd4jLong* newShapeInfo = nullptr; + ALLOCATE(newShapeInfo, workspace, + shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong); + + newShapeInfo[0] = shape::rank(maxShapeInfo); + + sd::ArrayOptions::copyDataType(newShapeInfo, minShapeInfo); // type + newShapeInfo[2 * newShapeInfo[0] + 2] = + shape::elementWiseStride(minShapeInfo); // ews + newShapeInfo[2 * newShapeInfo[0] + 3] = shape::order(minShapeInfo); // order + + if (!dimensions.empty()) { + for (uint k = 0, j = 0, i = 0; i < shape::rank(maxShapeInfo); ++i) { + if (j < dimensions.size() && dimensions[j] == i) { + shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[k]; + shape::stride(newShapeInfo)[i] = shape::stride(minShapeInfo)[k++]; + ++j; + } else { + shape::shapeOf(newShapeInfo)[i] = 1; + shape::stride(newShapeInfo)[i] = 0; + if (shape::sizeAt(minShapeInfo, k) == 1 && + dimensions.size() != shape::rank(minShapeInfo)) + ++k; + } + } + } else { + for (int j = shape::rank(minShapeInfo) - 1, + i = shape::rank(maxShapeInfo) - 1; + i >= 0; --i) { + if (j >= 0) { + shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j]; + shape::stride(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j] == 1 + ? 0 + : shape::stride(minShapeInfo)[j]; + --j; + } else { + shape::shapeOf(newShapeInfo)[i] = 1; + shape::stride(newShapeInfo)[i] = 0; + } + } + } + + ShapeDescriptor descriptor(newShapeInfo); + + RELEASE(newShapeInfo, workspace); + + return bufferForShapeInfo(descriptor); } - sd::ConstantShapeHelper* sd::ConstantShapeHelper::_INSTANCE = 0; -} +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp b/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp index ea32db7e6973..3f53a7be9a51 100644 --- a/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp @@ -19,101 +19,117 @@ // #include "../ConstantTadHelper.h" -#include + #include +#include #ifndef __CUDABLAS__ - namespace sd { - ConstantTadHelper::ConstantTadHelper() { - MAP_IMPL pack; - _cache.emplace_back(pack); - } - - ConstantTadHelper* ConstantTadHelper::getInstance() { - if (!_INSTANCE) - _INSTANCE = new ConstantTadHelper(); - - return _INSTANCE; - } - - TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { - return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); - } - - TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape) { - return tadForDimensions(originalShape, const_cast(dimensions.data()), dimensions.size(), keepUnitiesInShape); - } - - TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { - TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape); - return tadForDimensions(tadDescriptor); - } - - TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape) { - TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); - return tadForDimensions(tadDescriptor); - } - - TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { - const int deviceId = 0; - - _mutex.lock(); - if (_cache[deviceId].count(descriptor) == 0) { - - const auto shapeInfo = descriptor.originalShape().toShapeInfo(); - const int rank = shape::rank(shapeInfo); - const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(rank, descriptor.axis()); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude); - const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size(); - - auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; // shape of sub-arrays (same for all for them) - auto oPtr = new Nd4jLong[numOfSubArrs]; - - if (numOfSubArrs > 0) - shape::calcSubArrsShapeInfoAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape()); - - - ConstantDataBuffer shapesBuffer(sPtr, nullptr, shape::shapeInfoLength(subArrRank)*sizeof(Nd4jLong), DataType::INT64); - ConstantDataBuffer offsetsBuffer(oPtr, nullptr, numOfSubArrs*sizeof(Nd4jLong), DataType::INT64); - TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs); - - - - // auto shapeInfo = descriptor.originalShape().toShapeInfo(); - // shape::TAD tad; - // tad.init(shapeInfo, descriptor.axis().data(), descriptor.axis().size()); - // tad.createTadOnlyShapeInfo(); - // tad.createOffsets(); - - // auto sPtr = new Nd4jLong[shape::shapeInfoLength(tad.tadOnlyShapeInfo)]; - // auto oPtr = new Nd4jLong[tad.numTads]; - - // memcpy(sPtr, tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - // memcpy(oPtr, tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); - - // TadPack t(shapesBuffer, offsetsBuffer, tad.numTads); +ConstantTadHelper::ConstantTadHelper() { + MAP_IMPL pack; + _cache.emplace_back(pack); +} +ConstantTadHelper *ConstantTadHelper::getInstance() { + if (!_INSTANCE) _INSTANCE = new ConstantTadHelper(); - _cache[deviceId][descriptor] = t; + return _INSTANCE; +} - TadPack &r = _cache[deviceId][descriptor]; - _mutex.unlock(); +TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, + int dimension, + const bool keepUnitiesInShape) { + return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); +} - delete[] shapeInfo; +TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, + const std::vector &dimensions, + const bool keepUnitiesInShape) { + return tadForDimensions(originalShape, const_cast(dimensions.data()), + dimensions.size(), keepUnitiesInShape); +} - return r; - } else { - TadPack r = _cache[deviceId][descriptor]; - _mutex.unlock(); +TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, + int *dimensions, int dimLength, + const bool keepUnitiesInShape) { + TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, + keepUnitiesInShape); + return tadForDimensions(tadDescriptor); +} - return r; - } - } +TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, + std::vector &dimensions, + const bool keepUnitiesInShape) { + TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); + return tadForDimensions(tadDescriptor); +} - sd::ConstantTadHelper* sd::ConstantTadHelper::_INSTANCE = 0; +TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { + const int deviceId = 0; + + _mutex.lock(); + if (_cache[deviceId].count(descriptor) == 0) { + const auto shapeInfo = descriptor.originalShape().toShapeInfo(); + const int rank = shape::rank(shapeInfo); + const std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(rank, descriptor.axis()); + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude); + const int subArrRank = + (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) + ? rank + : rank - dimsToExclude.size(); + + auto sPtr = new Nd4jLong[shape::shapeInfoLength( + subArrRank)]; // shape of sub-arrays (same for all for them) + auto oPtr = new Nd4jLong[numOfSubArrs]; + + if (numOfSubArrs > 0) + shape::calcSubArrsShapeInfoAndOffsets( + shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), + sPtr, oPtr, descriptor.areUnitiesinShape()); + + ConstantDataBuffer shapesBuffer( + sPtr, nullptr, shape::shapeInfoLength(subArrRank) * sizeof(Nd4jLong), + DataType::INT64); + ConstantDataBuffer offsetsBuffer( + oPtr, nullptr, numOfSubArrs * sizeof(Nd4jLong), DataType::INT64); + TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs); + + // auto shapeInfo = descriptor.originalShape().toShapeInfo(); + // shape::TAD tad; + // tad.init(shapeInfo, descriptor.axis().data(), descriptor.axis().size()); + // tad.createTadOnlyShapeInfo(); + // tad.createOffsets(); + + // auto sPtr = new Nd4jLong[shape::shapeInfoLength(tad.tadOnlyShapeInfo)]; + // auto oPtr = new Nd4jLong[tad.numTads]; + + // memcpy(sPtr, tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); memcpy(oPtr, + // tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); + + // TadPack t(shapesBuffer, offsetsBuffer, tad.numTads); + + _cache[deviceId][descriptor] = t; + + TadPack &r = _cache[deviceId][descriptor]; + _mutex.unlock(); + + delete[] shapeInfo; + + return r; + } else { + TadPack r = _cache[deviceId][descriptor]; + _mutex.unlock(); + + return r; + } } +sd::ConstantTadHelper *sd::ConstantTadHelper::_INSTANCE = 0; +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp index 26a6643c34df..043ddf1725f7 100644 --- a/libnd4j/include/helpers/cpu/MmulHelper.cpp +++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp @@ -19,384 +19,429 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // #include "../MmulHelper.h" + #include -#include -#include #include #include - +#include +#include namespace sd { ////////////////////////////////////////////////////////////////////////////// // MXK x KxN = MxN -> actual sequence of axes doesn't matter template -static void usualGemm(const NDArray* vA, const NDArray* vB, NDArray* vC, - const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, - const double alpha, const double beta) { - - const T1* A = vA->bufferAsT(); - const T2* B = vB->bufferAsT(); - T3* C = vC->bufferAsT(); - - const T3 alphaZ = alpha; - const T3 betaZ = beta; - - const bool betaPersent = beta; +static void usualGemm(const NDArray* vA, const NDArray* vB, NDArray* vC, + const int aMaxis, const int aKaxis, const int bKaxis, + const int bNaxis, const int cMaxis, const int cNaxis, + const double alpha, const double beta) { + const T1* A = vA->bufferAsT(); + const T2* B = vB->bufferAsT(); + T3* C = vC->bufferAsT(); - const Nd4jLong* aShapeInfo = vA->shapeInfo(); - const Nd4jLong* bShapeInfo = vB->shapeInfo(); - const Nd4jLong* cShapeInfo = vC->shapeInfo(); + const T3 alphaZ = alpha; + const T3 betaZ = beta; - const int aRank = vA->rankOf(); - const int bRank = vB->rankOf(); - const int cRank = vC->rankOf(); + const bool betaPersent = beta; - const Nd4jLong cLen = vC->lengthOf(); + const Nd4jLong* aShapeInfo = vA->shapeInfo(); + const Nd4jLong* bShapeInfo = vB->shapeInfo(); + const Nd4jLong* cShapeInfo = vC->shapeInfo(); - const int K = vA->sizeAt(aKaxis); + const int aRank = vA->rankOf(); + const int bRank = vB->rankOf(); + const int cRank = vC->rankOf(); - auto func = PRAGMA_THREADS_FOR { + const Nd4jLong cLen = vC->lengthOf(); - std::vector aCoords(2), bCoords(2), cCoords(2); + const int K = vA->sizeAt(aKaxis); - for (auto i = start; i < stop; ++i) { + auto func = PRAGMA_THREADS_FOR { + std::vector aCoords(2), bCoords(2), cCoords(2); - // evaluate C coordinates - shape::index2coordsCPU(start, i, cShapeInfo, cCoords.data()); + for (auto i = start; i < stop; ++i) { + // evaluate C coordinates + shape::index2coordsCPU(start, i, cShapeInfo, cCoords.data()); - // evaluate A coordinates - aCoords[aMaxis] = cCoords[cMaxis]; - aCoords[aKaxis] = 0; + // evaluate A coordinates + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; - // evaluate B coordinates - bCoords[bKaxis] = 0; - bCoords[bNaxis] = cCoords[cNaxis]; + // evaluate B coordinates + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; - auto aOffset = shape::getOffset(aShapeInfo, aCoords.data()); - auto bOffset = shape::getOffset(bShapeInfo, bCoords.data()); + auto aOffset = shape::getOffset(aShapeInfo, aCoords.data()); + auto bOffset = shape::getOffset(bShapeInfo, bCoords.data()); - T3 val = A[aOffset] * B[bOffset]; // first iteration + T3 val = A[aOffset] * B[bOffset]; // first iteration - for (int j = 1; j < K; ++j) { // rest iterations - aOffset += shape::stride(aShapeInfo)[aKaxis]; - bOffset += shape::stride(bShapeInfo)[bKaxis]; - val = val + A[aOffset] * B[bOffset]; - } + for (int j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; + } - auto cOffset = shape::getOffset(cShapeInfo, cCoords.data()); + auto cOffset = shape::getOffset(cShapeInfo, cCoords.data()); - if(betaPersent) - C[cOffset] = alphaZ * val + betaZ * C[cOffset]; - else - C[cOffset] = alphaZ * val; - } - }; + if (betaPersent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; + else + C[cOffset] = alphaZ * val; + } + }; - samediff::Threads::parallel_tad(func, 0, cLen); + samediff::Threads::parallel_tad(func, 0, cLen); } - ////////////////////////////////////////////////////////////////////////////// // MXN x N = M -> actual sequence of {M,N} axes doesn't matter template -static void usualGemv(const NDArray* vA, const NDArray* vX, NDArray* vY, const int incx, const int incy, const int aMaxis, const double alpha, const double beta) { +static void usualGemv(const NDArray* vA, const NDArray* vX, NDArray* vY, + const int incx, const int incy, const int aMaxis, + const double alpha, const double beta) { + const T1* A = vA->bufferAsT(); + const T2* X = vX->bufferAsT(); + T3* Y = vY->bufferAsT(); - const T1* A = vA->bufferAsT(); - const T2* X = vX->bufferAsT(); - T3* Y = vY->bufferAsT(); + const T3 alphaZ = alpha; + const T3 betaZ = beta; - const T3 alphaZ = alpha; - const T3 betaZ = beta; + const bool betaPersent = beta; - const bool betaPersent = beta; + const Nd4jLong* aShapeInfo = vA->shapeInfo(); + const Nd4jLong* xShapeInfo = vX->shapeInfo(); + const Nd4jLong* yShapeInfo = vY->shapeInfo(); - const Nd4jLong* aShapeInfo = vA->shapeInfo(); - const Nd4jLong* xShapeInfo = vX->shapeInfo(); - const Nd4jLong* yShapeInfo = vY->shapeInfo(); + const int N = vX->lengthOf(); + const int M = vY->lengthOf(); - const int N = vX->lengthOf(); - const int M = vY->lengthOf(); + const auto aMstride = vA->strideAt(aMaxis); + const auto aNstride = vA->strideAt(aMaxis == 0 ? 1 : 0); - const auto aMstride = vA->strideAt(aMaxis); - const auto aNstride = vA->strideAt(aMaxis == 0 ? 1 : 0); - - auto func = PRAGMA_THREADS_FOR { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; ++i) { + // evaluate offsets + auto aOffset = i * aMstride; + auto xOffset = 0; - for (auto i = start; i < stop; ++i) { + T3 val = A[aOffset] * X[xOffset]; // first iteration - // evaluate offsets - auto aOffset = i * aMstride; - auto xOffset = 0; + for (int j = 1; j < N; ++j) { // rest iterations + aOffset += aNstride; + xOffset += incx; + val = val + A[aOffset] * X[xOffset]; + } - T3 val = A[aOffset] * X[xOffset]; // first iteration + auto yOffset = i * incy; - for (int j = 1; j < N; ++j) { // rest iterations - aOffset += aNstride; - xOffset += incx; - val = val + A[aOffset] * X[xOffset]; - } - - auto yOffset = i * incy; - - if(betaPersent) - Y[yOffset] = alphaZ * val + betaZ * Y[yOffset]; - else - Y[yOffset] = alphaZ * val; - } - }; + if (betaPersent) + Y[yOffset] = alphaZ * val + betaZ * Y[yOffset]; + else + Y[yOffset] = alphaZ * val; + } + }; - samediff::Threads::parallel_tad(func, 0, M); + samediff::Threads::parallel_tad(func, 0, M); } ////////////////////////////////////////////////////////////////////////////// // (X*Y) = Z[0] template -static void usualDot(const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ) { - - T1* X = reinterpret_cast(const_cast(vX)); - T2* Y = reinterpret_cast(const_cast(vY)); - T3* Z = reinterpret_cast(vZ); - T3 alphaZ(alpha), betaZ(beta); - - const bool betaPersent = beta; - - T3 sum = 0; - PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum)) - for(Nd4jLong i = 0; i < length; ++i) - sum += X[i * incx] * Y[i * incy]; - - if(betaPersent) - *Z = alphaZ * sum + betaZ * *Z; - else - *Z = alphaZ * sum; +static void usualDot(const Nd4jLong length, const double alpha, const void* vX, + const Nd4jLong incx, const void* vY, const Nd4jLong incy, + const double beta, void* vZ) { + T1* X = reinterpret_cast(const_cast(vX)); + T2* Y = reinterpret_cast(const_cast(vY)); + T3* Z = reinterpret_cast(vZ); + T3 alphaZ(alpha), betaZ(beta); + + const bool betaPersent = beta; + + T3 sum = 0; + PRAGMA_OMP_PARALLEL_FOR_ARGS( + OMP_IF(length > Environment::getInstance()->elementwiseThreshold()) + schedule(guided) reduction(OMP_SUMT + : sum)) + for (Nd4jLong i = 0; i < length; ++i) sum += X[i * incx] * Y[i * incy]; + + if (betaPersent) + *Z = alphaZ * sum + betaZ * *Z; + else + *Z = alphaZ * sum; } ////////////////////////////////////////////////////////////////////////////// // MXK x KxN = MxN -NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { - if (A->dataType() != B->dataType()) - throw datatype_exception::build("mmulMxM expects all data types to be the same", A->dataType(), B->dataType()); - - if (C != nullptr && A->dataType() != C->dataType()) - throw datatype_exception::build("mmulMxM expects all data types to be the same", A->dataType(), C->dataType()); - - if(A->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM: rank of A array is not equal 2 !"); - if(B->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM: rank of B array is not equal 2 !"); - - const auto M = A->sizeAt(0); - const auto K = A->sizeAt(1); - const auto N = B->sizeAt(1); - - if(C != nullptr && C->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM: rank of C array is not equal 2 !"); - if(B->sizeAt(0) != K) - throw std::runtime_error("MmulHelper::mmulMxM: B array has wrong number of rows !"); - if(C != nullptr && C->sizeAt(0) != M) - throw std::runtime_error("MmulHelper::mmulMxM: C array has wrong number of rows !"); - if(C != nullptr && C->sizeAt(1) != N) - throw std::runtime_error("MmulHelper::mmulMxM: C array has wrong number of columns !"); - - if(C == nullptr) - C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); - - if (C->isEmpty()) - return C; - - const auto aType = A->dataType(); - const auto bType = B->dataType(); - const auto cType = C->dataType(); - - const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); - const bool hasGemm = BlasHelper::getInstance()->hasGEMM(aType); - - const bool typeDouble = hasGemm && ABC && aType == DataType::DOUBLE; - const bool typeFloat = hasGemm && ABC && aType == DataType::FLOAT32; - - if(!typeFloat && !typeDouble) { - BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (A, B, C, 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES); - // BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (A, B, C, 0, 1, 0, 1, 0, 1, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, + const double alpha, const double beta, + const char outOrder) { + if (A->dataType() != B->dataType()) + throw datatype_exception::build( + "mmulMxM expects all data types to be the same", A->dataType(), + B->dataType()); + + if (C != nullptr && A->dataType() != C->dataType()) + throw datatype_exception::build( + "mmulMxM expects all data types to be the same", A->dataType(), + C->dataType()); + + if (A->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxM: rank of A array is not equal 2 !"); + if (B->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxM: rank of B array is not equal 2 !"); + + const auto M = A->sizeAt(0); + const auto K = A->sizeAt(1); + const auto N = B->sizeAt(1); + + if (C != nullptr && C->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxM: rank of C array is not equal 2 !"); + if (B->sizeAt(0) != K) + throw std::runtime_error( + "MmulHelper::mmulMxM: B array has wrong number of rows !"); + if (C != nullptr && C->sizeAt(0) != M) + throw std::runtime_error( + "MmulHelper::mmulMxM: C array has wrong number of rows !"); + if (C != nullptr && C->sizeAt(1) != N) + throw std::runtime_error( + "MmulHelper::mmulMxM: C array has wrong number of columns !"); + + if (C == nullptr) + C = new NDArray( + outOrder, {M, N}, + DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), + A->getContext()); + + if (C->isEmpty()) return C; + + const auto aType = A->dataType(); + const auto bType = B->dataType(); + const auto cType = C->dataType(); + + const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); + const bool hasGemm = BlasHelper::getInstance()->hasGEMM(aType); + + const bool typeDouble = hasGemm && ABC && aType == DataType::DOUBLE; + const bool typeFloat = hasGemm && ABC && aType == DataType::FLOAT32; + + if (!typeFloat && !typeDouble) { + BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, + (A, B, C, 0, 1, 0, 1, 0, 1, alpha, beta), + NUMERIC_TYPES); + // BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (A, B, C, 0, 1, 0, + // 1, 0, 1, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + } else { + std::vector toDelete; + + NDArray *pA(const_cast(A)), *pB(const_cast(B)), + *pC(const_cast(C)); + + bool aMcont = M == 1 || A->strideAt(0) == 1; + bool aKcont = K == 1 || A->strideAt(1) == 1; + bool bKcont = K == 1 || B->strideAt(0) == 1; + bool bNcont = N == 1 || B->strideAt(1) == 1; + bool cMcont = M == 1 || C->strideAt(0) == 1; + bool cNcont = N == 1 || C->strideAt(1) == 1; + + if (!aMcont && !aKcont) { + pA = new NDArray(A->dup('f')); + toDelete.push_back(pA); + aMcont = true; + } + if (!bKcont && !bNcont) { + pB = new NDArray(B->dup('f')); + toDelete.push_back(pB); + bKcont = true; + } + if (!cMcont && !cNcont) { + pC = new NDArray(C->dup('f')); + toDelete.push_back(pC); + cMcont = true; } - else { - - std::vector toDelete; - - NDArray *pA(const_cast(A)), *pB(const_cast(B)), *pC(const_cast(C)); - - bool aMcont = M == 1 || A->strideAt(0) == 1; - bool aKcont = K == 1 || A->strideAt(1) == 1; - bool bKcont = K == 1 || B->strideAt(0) == 1; - bool bNcont = N == 1 || B->strideAt(1) == 1; - bool cMcont = M == 1 || C->strideAt(0) == 1; - bool cNcont = N == 1 || C->strideAt(1) == 1; - - if(!aMcont && !aKcont) { - pA = new NDArray(A->dup('f')); - toDelete.push_back(pA); - aMcont = true; - } - if(!bKcont && !bNcont) { - pB = new NDArray(B->dup('f')); - toDelete.push_back(pB); - bKcont = true; - } - if(!cMcont && !cNcont) { - pC = new NDArray(C->dup('f')); - toDelete.push_back(pC); - cMcont = true; - } - - const CBLAS_ORDER blasOrder = cMcont ? CblasColMajor : CblasRowMajor; - - const bool transA = (!aMcont && cMcont) || (aMcont && !cMcont); - const bool transB = (!bKcont && cMcont) || (bKcont && !cMcont); - - const CBLAS_TRANSPOSE transAblas = transA ? CblasTrans : CblasNoTrans; - const CBLAS_TRANSPOSE transBblas = transB ? CblasTrans : CblasNoTrans; - - const int lda = (aMcont && aKcont) ? M : !aMcont ? pA->strideAt(0) : pA->strideAt(1); - const int ldb = (bKcont && bNcont) ? K : !bKcont ? pB->strideAt(0) : pB->strideAt(1); - const int ldc = (cMcont && cNcont) ? M : !cMcont ? pC->strideAt(0) : pC->strideAt(1); - if(typeFloat) { - BlasHelper::getInstance()->sgemm()(blasOrder, transAblas, transBblas, M, N, K, (float) alpha, pA->bufferAsT(), lda, pB->bufferAsT(), ldb, (float) beta, pC->bufferAsT(), ldc); - } - else if(typeDouble) { - BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, pA->bufferAsT(), lda, pB->bufferAsT(), ldb, (double) beta, pC->bufferAsT(), ldc); - } + const CBLAS_ORDER blasOrder = cMcont ? CblasColMajor : CblasRowMajor; + + const bool transA = (!aMcont && cMcont) || (aMcont && !cMcont); + const bool transB = (!bKcont && cMcont) || (bKcont && !cMcont); + + const CBLAS_TRANSPOSE transAblas = transA ? CblasTrans : CblasNoTrans; + const CBLAS_TRANSPOSE transBblas = transB ? CblasTrans : CblasNoTrans; + + const int lda = + (aMcont && aKcont) ? M : !aMcont ? pA->strideAt(0) : pA->strideAt(1); + const int ldb = + (bKcont && bNcont) ? K : !bKcont ? pB->strideAt(0) : pB->strideAt(1); + const int ldc = + (cMcont && cNcont) ? M : !cMcont ? pC->strideAt(0) : pC->strideAt(1); + + if (typeFloat) { + BlasHelper::getInstance()->sgemm()( + blasOrder, transAblas, transBblas, M, N, K, (float)alpha, + pA->bufferAsT(), lda, pB->bufferAsT(), ldb, (float)beta, + pC->bufferAsT(), ldc); + } else if (typeDouble) { + BlasHelper::getInstance()->dgemm()( + blasOrder, transAblas, transBblas, M, N, K, (double)alpha, + pA->bufferAsT(), lda, pB->bufferAsT(), ldb, + (double)beta, pC->bufferAsT(), ldc); + } - if(pC != C) { - C->assign(pC); - delete pC; - } - if(pA != A) - delete pA; - if(pB != B) - delete pB; + if (pC != C) { + C->assign(pC); + delete pC; } + if (pA != A) delete pA; + if (pB != B) delete pB; + } - return C; + return C; } //////////////////////////////////////////////////////////////////////////// // MXN x N = M -NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y, const double alpha, const double beta, const char outOrder) { - - if (X->dataType() != A->dataType()) - throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), X->dataType()); - - if (Y != nullptr && X->dataType() != Y->dataType()) - throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), Y->dataType()); - - int xLenDim, yLenDim(0); - - if(A->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxV: rank of A array is not equal 2 !"); - if(!shape::isCommonVector(X->shapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::mmulMxV: X array must be vector !"); - - const auto M = A->sizeAt(0); - const auto N = A->sizeAt(1); - - if(Y != nullptr && !shape::isCommonVector(Y->shapeInfo(), yLenDim)) - throw std::runtime_error("MmulHelper::mmulMxV: Y array must be vector !"); - if(X->lengthOf() != N) - throw std::runtime_error("MmulHelper::mmulMxV: X vector has wrong length !"); - if(Y != nullptr && Y->lengthOf() != M) - throw std::runtime_error("MmulHelper::mmulMxV: Y array has wrong length !"); - - if(Y == nullptr) - Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); - - if (Y->isEmpty()) - return Y; - - const int incx = X->stridesOf()[xLenDim]; - const int incy = Y->stridesOf()[yLenDim]; - - const auto aType = A->dataType(); - const auto xType = X->dataType(); - const auto yType = Y->dataType(); - - const bool AX(aType == xType), AY(aType == yType), AXY(AX && AY); - const bool hasGemv = BlasHelper::getInstance()->hasGEMV(aType); - - const bool typeDouble = hasGemv && AXY && aType == DataType::DOUBLE; - const bool typeFloat = hasGemv && AXY && aType == DataType::FLOAT32; - - if(!typeDouble && !typeFloat) { - BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemv, (A, X, Y, incx, incy, 0, alpha, beta), NUMERIC_TYPES); - // BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (A, X, Y, incx, incy, 0, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y, + const double alpha, const double beta, + const char outOrder) { + if (X->dataType() != A->dataType()) + throw datatype_exception::build( + "mmulMxV expects all data types to be the same", A->dataType(), + X->dataType()); + + if (Y != nullptr && X->dataType() != Y->dataType()) + throw datatype_exception::build( + "mmulMxV expects all data types to be the same", A->dataType(), + Y->dataType()); + + int xLenDim, yLenDim(0); + + if (A->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxV: rank of A array is not equal 2 !"); + if (!shape::isCommonVector(X->shapeInfo(), xLenDim)) + throw std::runtime_error("MmulHelper::mmulMxV: X array must be vector !"); + + const auto M = A->sizeAt(0); + const auto N = A->sizeAt(1); + + if (Y != nullptr && !shape::isCommonVector(Y->shapeInfo(), yLenDim)) + throw std::runtime_error("MmulHelper::mmulMxV: Y array must be vector !"); + if (X->lengthOf() != N) + throw std::runtime_error( + "MmulHelper::mmulMxV: X vector has wrong length !"); + if (Y != nullptr && Y->lengthOf() != M) + throw std::runtime_error("MmulHelper::mmulMxV: Y array has wrong length !"); + + if (Y == nullptr) + Y = new NDArray( + outOrder, {M}, + DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), + A->getContext()); + + if (Y->isEmpty()) return Y; + + const int incx = X->stridesOf()[xLenDim]; + const int incy = Y->stridesOf()[yLenDim]; + + const auto aType = A->dataType(); + const auto xType = X->dataType(); + const auto yType = Y->dataType(); + + const bool AX(aType == xType), AY(aType == yType), AXY(AX && AY); + const bool hasGemv = BlasHelper::getInstance()->hasGEMV(aType); + + const bool typeDouble = hasGemv && AXY && aType == DataType::DOUBLE; + const bool typeFloat = hasGemv && AXY && aType == DataType::FLOAT32; + + if (!typeDouble && !typeFloat) { + BUILD_SINGLE_SELECTOR_THRICE( + aType, usualGemv, (A, X, Y, incx, incy, 0, alpha, beta), NUMERIC_TYPES); + // BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (A, X, Y, incx, + // incy, 0, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + } else { + NDArray* pA(const_cast(A)); + + bool aMcont = M == 1 || A->strideAt(0) == 1; + bool aNcont = N == 1 || A->strideAt(1) == 1; + + if (!aMcont && !aNcont) { + pA = new NDArray(A->dup('f')); + aMcont = true; } - else { - - NDArray *pA(const_cast(A)); - - bool aMcont = M == 1 || A->strideAt(0) == 1; - bool aNcont = N == 1 || A->strideAt(1) == 1; - - if(!aMcont && !aNcont) { - pA = new NDArray(A->dup('f')); - aMcont = true; - } - const CBLAS_ORDER blasOrder = aMcont ? CblasColMajor : CblasRowMajor; - - const int lda = (aMcont && aNcont) ? M : !aMcont ? pA->strideAt(0) : pA->strideAt(1); - - // choose appropriate cuda gemm api depending on data types - if(typeDouble) { - BlasHelper::getInstance()->dgemv()(blasOrder, CblasNoTrans, M, N, alpha, (double*)pA->buffer(), lda, (double*)X->buffer(), incx, beta, (double*)Y->buffer(), incy); - } - else if(typeFloat) { - BlasHelper::getInstance()->sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->buffer(), lda, (float*)X->buffer(), incx, (float)beta, (float*)Y->buffer(), incy); - } - - if(pA != A) - delete pA; + const CBLAS_ORDER blasOrder = aMcont ? CblasColMajor : CblasRowMajor; + + const int lda = + (aMcont && aNcont) ? M : !aMcont ? pA->strideAt(0) : pA->strideAt(1); + + // choose appropriate cuda gemm api depending on data types + if (typeDouble) { + BlasHelper::getInstance()->dgemv()( + blasOrder, CblasNoTrans, M, N, alpha, (double*)pA->buffer(), lda, + (double*)X->buffer(), incx, beta, (double*)Y->buffer(), incy); + } else if (typeFloat) { + BlasHelper::getInstance()->sgemv()( + blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->buffer(), + lda, (float*)X->buffer(), incx, (float)beta, (float*)Y->buffer(), + incy); } - return Y; + if (pA != A) delete pA; + } + + return Y; } //////////////////////////////////////////////////////////////////////////// // (X * Y) = Z[0] -NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, const double alpha, const double beta) { - if (X->dataType() != Y->dataType()) - throw datatype_exception::build("Dot expects all data types to be the same", X->dataType(), Y->dataType()); - - if (Z != nullptr && X->dataType() != Z->dataType()) - throw datatype_exception::build("Dot expects all data types to be the same", X->dataType(), Z->dataType()); - - int xLenDim(0), yLenDim(0); - - if(!shape::isCommonVector(X->shapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::dot: X array must be vector !"); - if(!shape::isCommonVector(Y->shapeInfo(), yLenDim)) - throw std::runtime_error("MmulHelper::dot: Y array must be vector !"); - if(Z != nullptr && !Z->isScalar()) - throw std::runtime_error("MmulHelper::dot: Z array must be scalar !"); - - const auto length = X->lengthOf(); - - if(Y->lengthOf() != length) - throw std::runtime_error("MmulHelper::dot: lengths of input vectors are different !"); - - if(Z == nullptr) - Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext()); - - const Nd4jLong incx = X->stridesOf()[xLenDim]; - const Nd4jLong incy = Y->stridesOf()[yLenDim]; - - const auto xType = X->dataType(); - const auto yType = Y->dataType(); - const auto zType = Z->dataType(); - - BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (length, alpha, X->buffer(), incx, Y->buffer(), incy, beta, Z->buffer()), NUMERIC_TYPES); - //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->buffer(), incx, Y->buffer(), incy, beta, Z->buffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); - - return Z; +NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, + const double alpha, const double beta) { + if (X->dataType() != Y->dataType()) + throw datatype_exception::build("Dot expects all data types to be the same", + X->dataType(), Y->dataType()); + + if (Z != nullptr && X->dataType() != Z->dataType()) + throw datatype_exception::build("Dot expects all data types to be the same", + X->dataType(), Z->dataType()); + + int xLenDim(0), yLenDim(0); + + if (!shape::isCommonVector(X->shapeInfo(), xLenDim)) + throw std::runtime_error("MmulHelper::dot: X array must be vector !"); + if (!shape::isCommonVector(Y->shapeInfo(), yLenDim)) + throw std::runtime_error("MmulHelper::dot: Y array must be vector !"); + if (Z != nullptr && !Z->isScalar()) + throw std::runtime_error("MmulHelper::dot: Z array must be scalar !"); + + const auto length = X->lengthOf(); + + if (Y->lengthOf() != length) + throw std::runtime_error( + "MmulHelper::dot: lengths of input vectors are different !"); + + if (Z == nullptr) + Z = new NDArray( + DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), + X->getContext()); + + const Nd4jLong incx = X->stridesOf()[xLenDim]; + const Nd4jLong incy = Y->stridesOf()[yLenDim]; + + const auto xType = X->dataType(); + const auto yType = Y->dataType(); + const auto zType = Z->dataType(); + + BUILD_SINGLE_SELECTOR_THRICE( + xType, usualDot, + (length, alpha, X->buffer(), incx, Y->buffer(), incy, beta, Z->buffer()), + NUMERIC_TYPES); + // BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, + // X->buffer(), incx, Y->buffer(), incy, beta, Z->buffer()), LIBND4J_TYPES, + // FLOAT_TYPES, FLOAT_TYPES); + + return Z; } ////////////////////////////////////////////////////////////////////////////// @@ -405,79 +450,81 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con // [M,K] x [bS,K,N] = [bS,M,N] // bS could stand for several axes template -static void batchedGemm(const NDArray* vA, const NDArray* vB, NDArray* vC, - const int* aBatchDims, const int* bBatchDims, const int* cBatchDims, - const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, - const double alpha, const double beta) { - - const T1* A = vA->bufferAsT(); - const T2* B = vB->bufferAsT(); - T3* C = vC->bufferAsT(); - - const T3 alphaZ = alpha; - const T3 betaZ = beta; - - const bool betaPersent = beta; - - const Nd4jLong* aShapeInfo = vA->shapeInfo(); - const Nd4jLong* bShapeInfo = vB->shapeInfo(); - const Nd4jLong* cShapeInfo = vC->shapeInfo(); - - const int aRank = vA->rankOf(); - const int bRank = vB->rankOf(); - const int cRank = vC->rankOf(); - - const Nd4jLong cLen = vC->lengthOf(); - - const int K = vA->sizeAt(aKaxis); - - auto func = PRAGMA_THREADS_FOR { - - std::vector aCoords(aRank), bCoords(bRank), cCoords(cRank); - - for (auto i = start; i < stop; ++i) { - - // evaluate C coordinates - shape::index2coordsCPU(start, i, cShapeInfo, cCoords.data()); - - // calculate index of current batch - Nd4jLong batchInd; - if(cRank > 2) - batchInd = shape::coords2index(cShapeInfo, cCoords.data(), cRank - 2, cBatchDims); - - // evaluate A coordinates - if(aRank > 2) - shape::index2coords(batchInd, aShapeInfo, aCoords.data(), aRank - 2, aBatchDims); - aCoords[aMaxis] = cCoords[cMaxis]; - aCoords[aKaxis] = 0; - - // evaluate B coordinates - if(bRank > 2) - shape::index2coords(batchInd, bShapeInfo, bCoords.data(), bRank - 2, bBatchDims); - bCoords[bKaxis] = 0; - bCoords[bNaxis] = cCoords[cNaxis]; - - auto aOffset = shape::getOffset(aShapeInfo, aCoords.data()); - auto bOffset = shape::getOffset(bShapeInfo, bCoords.data()); - - T3 val = A[aOffset] * B[bOffset]; // first iteration - - for (int j = 1; j < K; ++j) { // rest iterations - aOffset += shape::stride(aShapeInfo)[aKaxis]; - bOffset += shape::stride(bShapeInfo)[bKaxis]; - val = val + A[aOffset] * B[bOffset]; - } - - auto cOffset = shape::getOffset(cShapeInfo, cCoords.data()); - - if(betaPersent) - C[cOffset] = alphaZ * val + betaZ * C[cOffset]; - else - C[cOffset] = alphaZ * val; - } - }; +static void batchedGemm(const NDArray* vA, const NDArray* vB, NDArray* vC, + const int* aBatchDims, const int* bBatchDims, + const int* cBatchDims, const int aMaxis, + const int aKaxis, const int bKaxis, const int bNaxis, + const int cMaxis, const int cNaxis, const double alpha, + const double beta) { + const T1* A = vA->bufferAsT(); + const T2* B = vB->bufferAsT(); + T3* C = vC->bufferAsT(); + + const T3 alphaZ = alpha; + const T3 betaZ = beta; + + const bool betaPersent = beta; + + const Nd4jLong* aShapeInfo = vA->shapeInfo(); + const Nd4jLong* bShapeInfo = vB->shapeInfo(); + const Nd4jLong* cShapeInfo = vC->shapeInfo(); + + const int aRank = vA->rankOf(); + const int bRank = vB->rankOf(); + const int cRank = vC->rankOf(); + + const Nd4jLong cLen = vC->lengthOf(); + + const int K = vA->sizeAt(aKaxis); + + auto func = PRAGMA_THREADS_FOR { + std::vector aCoords(aRank), bCoords(bRank), cCoords(cRank); + + for (auto i = start; i < stop; ++i) { + // evaluate C coordinates + shape::index2coordsCPU(start, i, cShapeInfo, cCoords.data()); + + // calculate index of current batch + Nd4jLong batchInd; + if (cRank > 2) + batchInd = shape::coords2index(cShapeInfo, cCoords.data(), cRank - 2, + cBatchDims); + + // evaluate A coordinates + if (aRank > 2) + shape::index2coords(batchInd, aShapeInfo, aCoords.data(), aRank - 2, + aBatchDims); + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; + + // evaluate B coordinates + if (bRank > 2) + shape::index2coords(batchInd, bShapeInfo, bCoords.data(), bRank - 2, + bBatchDims); + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; + + auto aOffset = shape::getOffset(aShapeInfo, aCoords.data()); + auto bOffset = shape::getOffset(bShapeInfo, bCoords.data()); + + T3 val = A[aOffset] * B[bOffset]; // first iteration + + for (int j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; + } + + auto cOffset = shape::getOffset(cShapeInfo, cCoords.data()); + + if (betaPersent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; + else + C[cOffset] = alphaZ * val; + } + }; - samediff::Threads::parallel_tad(func, 0, cLen); + samediff::Threads::parallel_tad(func, 0, cLen); } ////////////////////////////////////////////////////////////////////////// @@ -485,89 +532,109 @@ static void batchedGemm(const NDArray* vA, const NDArray* vB, NDArray* vC, // [bS,M,K] x [K,N] = [bS,M,N] // [M,K] x [bS,K,N] = [bS,M,N] // bS could stand for several axes -NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { - - const int aRank = A->rankOf(); - const int bRank = B->rankOf(); - - // input ranks validation - if(aRank > bRank && bRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); - else if(bRank > aRank && aRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); - else if (aRank == bRank ) { - for(int i = 0; i < aRank - 2; ++i) - if(A->sizeAt(i) != B->sizeAt(i)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); - } - - if(A->sizeAt(-1) != B->sizeAt(-2)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); - - // validation of C array - std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); - cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); - cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); - - if(C != nullptr ) { - if(!C->isSameShape(cExpectedShape)) - throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); - } - else { - C = new NDArray(outOrder, cExpectedShape, B->dataType()); - } - - if (C->isEmpty()) - return C; - - const int cRank = C->rankOf(); - - const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1); - - std::vector aBatchDims, bBatchDims, cBatchDims; - - if(aRank > 2) - aBatchDims = ShapeUtils::evalDimsToExclude(aRank, {aMaxis, aKaxis}); - if(bRank > 2) - bBatchDims = ShapeUtils::evalDimsToExclude(bRank, {bKaxis, bNaxis}); - if(cRank > 2) - cBatchDims = ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}); - - // BUILD_TRIPLE_SELECTOR(A->dataType(), B->dataType(), C->dataType(), batchedGemm, (A, B, C, aBatchDims.data(), bBatchDims.data(), cBatchDims.data(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(A->dataType(), batchedGemm, (A, B, C, aBatchDims.data(), bBatchDims.data(), cBatchDims.data(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES); - - return C; +NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, + const double alpha, const double beta, + const char outOrder) { + const int aRank = A->rankOf(); + const int bRank = B->rankOf(); + + // input ranks validation + if (aRank > bRank && bRank != 2) + throw std::runtime_error( + "MmulHelper::mmulNxN: rank of B array should be equal 2 !"); + else if (bRank > aRank && aRank != 2) + throw std::runtime_error( + "MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + else if (aRank == bRank) { + for (int i = 0; i < aRank - 2; ++i) + if (A->sizeAt(i) != B->sizeAt(i)) + throw std::runtime_error( + "MmulHelper::mmulNxN: shapes of A and B arrays are not suitable " + "for matrix multiplication !"); + } + + if (A->sizeAt(-1) != B->sizeAt(-2)) + throw std::runtime_error( + "MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for " + "matrix multiplication !"); + + // validation of C array + std::vector cExpectedShape = + aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); + cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); + cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + + if (C != nullptr) { + if (!C->isSameShape(cExpectedShape)) + throw std::runtime_error( + "MmulHelper::mmulNxN: shape of C array is not suitable for AxB " + "matrix multiplication !"); + } else { + C = new NDArray(outOrder, cExpectedShape, B->dataType()); + } + + if (C->isEmpty()) return C; + + const int cRank = C->rankOf(); + + const int aMaxis(aRank - 2), aKaxis(aRank - 1), bKaxis(bRank - 2), + bNaxis(bRank - 1), cMaxis(cRank - 2), cNaxis(cRank - 1); + + std::vector aBatchDims, bBatchDims, cBatchDims; + + if (aRank > 2) + aBatchDims = ShapeUtils::evalDimsToExclude(aRank, {aMaxis, aKaxis}); + if (bRank > 2) + bBatchDims = ShapeUtils::evalDimsToExclude(bRank, {bKaxis, bNaxis}); + if (cRank > 2) + cBatchDims = ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}); + + // BUILD_TRIPLE_SELECTOR(A->dataType(), B->dataType(), C->dataType(), + // batchedGemm, (A, B, C, aBatchDims.data(), bBatchDims.data(), + // cBatchDims.data(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, + // beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + A->dataType(), batchedGemm, + (A, B, C, aBatchDims.data(), bBatchDims.data(), cBatchDims.data(), aMaxis, + aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), + NUMERIC_TYPES); + + return C; } /* ////////////////////////////////////////////////////////////////////////// -NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { +NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, +const double alpha, const double beta, const char outOrder) { const int aRank = A->rankOf(); const int bRank = B->rankOf(); // input ranks validation if(aRank > bRank && bRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); - else if(bRank > aRank && aRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be +equal 2 !"); else if(bRank > aRank && aRank != 2) throw +std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); else if (aRank == bRank ) { for(int i = 0; i < aRank - 2; ++i) if(A->sizeAt(i) != B->sizeAt(i)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B +arrays are not suitable for matrix multiplication !"); } if(A->sizeAt(-1) != B->sizeAt(-2)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays +are not suitable for matrix multiplication !"); // validation of C array - std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); - cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); - cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() +: B->getShapeAsVector(); cExpectedShape[cExpectedShape.size() - 2] = +A->sizeAt(-2); cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); if(C != nullptr ) { if(!C->isSameShape(cExpectedShape)) - throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is +not suitable for AxB matrix multiplication !"); } else { C = new NDArray(outOrder, cExpectedShape, B->dataType()); @@ -575,15 +642,16 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con // multiplication - const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->shapeInfo(), dimsToExclude); + const std::vector dimsToExclude = +ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); const Nd4jLong +numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->shapeInfo(), dimsToExclude); std::vector idxRanges(2 * C->rankOf()); // #pragma omp parallel for schedule(guided) firstprivate(idxRanges) for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { - ShapeUtils::evalIdxRangesForSubArr(i, C->shapeInfo(), dimsToExclude, idxRanges.data()); - NDArray cSubArr = (*C)(idxRanges); + ShapeUtils::evalIdxRangesForSubArr(i, C->shapeInfo(), dimsToExclude, +idxRanges.data()); NDArray cSubArr = (*C)(idxRanges); if(aRank > bRank) { NDArray aSubArr = (*A)(idxRanges); @@ -606,7 +674,10 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con ////////////////////////////////////////////////////////////////////////////// // MXK x KxN = MxN template -static void usualGemm(const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { +static void usualGemm(const char cOrder, const bool transA, const bool transB, +const int M, const int N, const int K, const double alpha, const void* vA, const +int lda, const void* vB, const int ldb, const double beta, void* vC, const int +ldc) { T1* A = reinterpret_cast(const_cast(vA)); T2* B = reinterpret_cast(const_cast(vB)); @@ -617,7 +688,8 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c const bool flagA = (flagC && transA) || (!flagC && !transA); const bool flagB = (flagC && transB) || (!flagC && !transB); - // PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) + // PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > +Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // for(uint row = 0; row < M; ++row) { // T3* c = flagC ? (C + row) : (C + row * ldc); @@ -633,7 +705,8 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c // if(flagC) { // for(uint col = 0; col < N; ++col) { // if(betaZ) - // c[col * ldc] += a * b[flagB ? col : col * ldb] + betaZ * c[col * ldc]; + // c[col * ldc] += a * b[flagB ? col : col * ldb] + +betaZ * c[col * ldc]; // else // c[col * ldc] += a * b[flagB ? col : col * ldb]; // } @@ -641,7 +714,8 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c // else { // for(uint col = 0; col < N; ++col) { // if(betaZ) - // c[col] += a * b[flagB ? col : col * ldb] + betaZ * c[col]; + // c[col] += a * b[flagB ? col : col * ldb] + betaZ * +c[col]; // else // c[col] += a * b[flagB ? col : col * ldb]; // } @@ -675,7 +749,9 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c ////////////////////////////////////////////////////////////////////////////// // MXN x N = M template -static void usualGemv(const char aOrder, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vX, const int incx, const double beta, void* vY, const int incy) { +static void usualGemv(const char aOrder, const int M, const int N, const double +alpha, const void* vA, const int lda, const void* vX, const int incx, const +double beta, void* vY, const int incy) { T1* A = reinterpret_cast(const_cast(vA)); T2* X = reinterpret_cast(const_cast(vX)); @@ -707,8 +783,17 @@ static void usualGemv(const char aOrder, const int M, const int N, const double } */ -//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); -//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); -//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); - -} +// BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool +// transA, const bool transB, const int M, const int N, const int K, const double +// alpha, const void* A, const int lda, const void* B, const int ldb, const +// double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, +// FLOAT_TYPES); BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char +// aOrder, const int M, const int N, const double alpha, const void* A, const int +// lda, const void* B, const int incx, const double beta, void* C, const int +// incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +// BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const +// double alpha, const void* vX, const Nd4jLong incx, const void* vY, const +// Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, +// FLOAT_TYPES); + +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/PointersManager.cpp b/libnd4j/include/helpers/cpu/PointersManager.cpp index 61eb7b2ecccf..839474254393 100644 --- a/libnd4j/include/helpers/cpu/PointersManager.cpp +++ b/libnd4j/include/helpers/cpu/PointersManager.cpp @@ -20,35 +20,37 @@ #ifndef __CUDABLAS__ -#include #include +#include #include #include namespace sd { ////////////////////////////////////////////////////////////////////////// -PointersManager::PointersManager(const sd::LaunchContext *context, const std::string& funcName) { - _context = const_cast(context); - _funcName = funcName; +PointersManager::PointersManager(const sd::LaunchContext* context, + const std::string& funcName) { + _context = const_cast(context); + _funcName = funcName; } ////////////////////////////////////////////////////////////////////////// -void* PointersManager::replicatePointer(const void* src, const size_t numberOfBytes) { - // no-op - return const_cast(src); +void* PointersManager::replicatePointer(const void* src, + const size_t numberOfBytes) { + // no-op + return const_cast(src); } ////////////////////////////////////////////////////////////////////////// void PointersManager::synchronize() const { - // no-op + // no-op } ////////////////////////////////////////////////////////////////////////// PointersManager::~PointersManager() { - // no-op + // no-op } -} +} // namespace sd #endif diff --git a/libnd4j/include/helpers/cpu/biDiagonalUp.cpp b/libnd4j/include/helpers/cpu/biDiagonalUp.cpp index 4623a93ad292..38d99d1c01a4 100644 --- a/libnd4j/include/helpers/cpu/biDiagonalUp.cpp +++ b/libnd4j/include/helpers/cpu/biDiagonalUp.cpp @@ -18,163 +18,170 @@ // Created by Yurii Shyrma on 18.12.2017 // - -#include -#include #include - +#include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// -BiDiagonalUp::BiDiagonalUp(const NDArray& matrix): _HHmatrix(sd::NDArrayFactory::create(matrix.ordering(), {matrix.sizeAt(0), matrix.sizeAt(1)}, matrix.dataType(), matrix.getContext())), - _HHbidiag(sd::NDArrayFactory::create(matrix.ordering(), {matrix.sizeAt(1), matrix.sizeAt(1)}, matrix.dataType(), matrix.getContext())) { - - // input validation - if(matrix.rankOf() != 2 || matrix.isScalar()) - throw std::runtime_error("ops::helpers::biDiagonalizeUp constructor: input array must be 2D matrix !"); - - _HHmatrix.assign(&matrix); - _HHbidiag.assign(0.); - - evalData(); - +BiDiagonalUp::BiDiagonalUp(const NDArray& matrix) + : _HHmatrix(sd::NDArrayFactory::create( + matrix.ordering(), {matrix.sizeAt(0), matrix.sizeAt(1)}, + matrix.dataType(), matrix.getContext())), + _HHbidiag(sd::NDArrayFactory::create( + matrix.ordering(), {matrix.sizeAt(1), matrix.sizeAt(1)}, + matrix.dataType(), matrix.getContext())) { + // input validation + if (matrix.rankOf() != 2 || matrix.isScalar()) + throw std::runtime_error( + "ops::helpers::biDiagonalizeUp constructor: input array must be 2D " + "matrix !"); + + _HHmatrix.assign(&matrix); + _HHbidiag.assign(0.); + + evalData(); } - template - void BiDiagonalUp::_evalData() { - - const auto rows = _HHmatrix.sizeAt(0); - const auto cols = _HHmatrix.sizeAt(1); - - if(rows < cols) - throw std::runtime_error("ops::helpers::BiDiagonalizeUp::evalData method: this procedure is applicable only for input matrix with rows >= cols !"); +template +void BiDiagonalUp::_evalData() { + const auto rows = _HHmatrix.sizeAt(0); + const auto cols = _HHmatrix.sizeAt(1); - NDArray* bottomRightCorner(nullptr), *column(nullptr), *row(nullptr); - T coeff, normX; + if (rows < cols) + throw std::runtime_error( + "ops::helpers::BiDiagonalizeUp::evalData method: this procedure is " + "applicable only for input matrix with rows >= cols !"); - T _x, _y; + NDArray *bottomRightCorner(nullptr), *column(nullptr), *row(nullptr); + T coeff, normX; - for(Nd4jLong i = 0; i < cols-1; ++i ) { + T _x, _y; - // evaluate Householder matrix nullifying columns - column = new NDArray(_HHmatrix({i,rows, i,i+1}, true)); + for (Nd4jLong i = 0; i < cols - 1; ++i) { + // evaluate Householder matrix nullifying columns + column = new NDArray(_HHmatrix({i, rows, i, i + 1}, true)); - _x = _HHmatrix.e(i,i); - _y = _HHbidiag.e(i,i); + _x = _HHmatrix.e(i, i); + _y = _HHbidiag.e(i, i); - Householder::evalHHmatrixDataI(*column, _x, _y); + Householder::evalHHmatrixDataI(*column, _x, _y); - _HHmatrix.p(i, i, _x); - _HHbidiag.p(i, i, _y); + _HHmatrix.p(i, i, _x); + _HHbidiag.p(i, i, _y); - // multiply corresponding matrix block on householder matrix from the left: P * bottomRightCorner - bottomRightCorner = new NDArray(_HHmatrix({i,rows, i+1,cols}, true)); // {i, cols} - Householder::mulLeft(*bottomRightCorner, _HHmatrix({i+1,rows, i,i+1}, true), _HHmatrix.e(i,i)); + // multiply corresponding matrix block on householder matrix from the left: + // P * bottomRightCorner + bottomRightCorner = + new NDArray(_HHmatrix({i, rows, i + 1, cols}, true)); // {i, cols} + Householder::mulLeft(*bottomRightCorner, + _HHmatrix({i + 1, rows, i, i + 1}, true), + _HHmatrix.e(i, i)); - delete bottomRightCorner; - delete column; + delete bottomRightCorner; + delete column; - if(i == cols-2) - continue; // do not apply right multiplying at last iteration + if (i == cols - 2) + continue; // do not apply right multiplying at last iteration - // evaluate Householder matrix nullifying rows - row = new NDArray(_HHmatrix({i,i+1, i+1,cols}, true)); + // evaluate Householder matrix nullifying rows + row = new NDArray(_HHmatrix({i, i + 1, i + 1, cols}, true)); - _x = _HHmatrix.e(i,i+1); - _y = _HHbidiag.e(i,i+1); + _x = _HHmatrix.e(i, i + 1); + _y = _HHbidiag.e(i, i + 1); - Householder::evalHHmatrixDataI(*row, _x, _y); + Householder::evalHHmatrixDataI(*row, _x, _y); - _HHmatrix.p(i, i+1, _x); - _HHbidiag.p(i, i+1, _y); + _HHmatrix.p(i, i + 1, _x); + _HHbidiag.p(i, i + 1, _y); - // multiply corresponding matrix block on householder matrix from the right: bottomRightCorner * P - bottomRightCorner = new NDArray(_HHmatrix({i+1,rows, i+1,cols}, true)); // {i, rows} + // multiply corresponding matrix block on householder matrix from the right: + // bottomRightCorner * P + bottomRightCorner = + new NDArray(_HHmatrix({i + 1, rows, i + 1, cols}, true)); // {i, rows} - Householder::mulRight(*bottomRightCorner, _HHmatrix({i,i+1, i+2,cols}, true), _HHmatrix.e(i,i+1)); + Householder::mulRight(*bottomRightCorner, + _HHmatrix({i, i + 1, i + 2, cols}, true), + _HHmatrix.e(i, i + 1)); - delete bottomRightCorner; - delete row; - } + delete bottomRightCorner; + delete row; + } - row = new NDArray(_HHmatrix({cols-2,cols-1, cols-1,cols}, true)); + row = new NDArray(_HHmatrix({cols - 2, cols - 1, cols - 1, cols}, true)); - _x = _HHmatrix.e(cols-2,cols-1); - _y = _HHbidiag.e(cols-2,cols-1); + _x = _HHmatrix.e(cols - 2, cols - 1); + _y = _HHbidiag.e(cols - 2, cols - 1); - Householder::evalHHmatrixDataI(*row, _x, _y); + Householder::evalHHmatrixDataI(*row, _x, _y); - _HHmatrix.p(cols-2,cols-1, _x); - _HHbidiag.p(cols-2,cols-1, _y); + _HHmatrix.p(cols - 2, cols - 1, _x); + _HHbidiag.p(cols - 2, cols - 1, _y); - delete row; + delete row; - column = new NDArray(_HHmatrix({cols-1,rows, cols-1,cols}, true)); + column = new NDArray(_HHmatrix({cols - 1, rows, cols - 1, cols}, true)); - _x = _HHmatrix.e(cols-1,cols-1); - _y = _HHbidiag.e(cols-1,cols-1); + _x = _HHmatrix.e(cols - 1, cols - 1); + _y = _HHbidiag.e(cols - 1, cols - 1); - Householder::evalHHmatrixDataI(*column, _x, _y); + Householder::evalHHmatrixDataI(*column, _x, _y); - _HHmatrix.p(cols-1, cols-1, _x); - _HHbidiag.p(cols-1, cols-1, _y); + _HHmatrix.p(cols - 1, cols - 1, _x); + _HHbidiag.p(cols - 1, cols - 1, _y); - delete column; - } + delete column; +} ////////////////////////////////////////////////////////////////////////// void BiDiagonalUp::evalData() { - auto xType = _HHmatrix.dataType(); + auto xType = _HHmatrix.dataType(); - BUILD_SINGLE_SELECTOR(xType, _evalData, ();, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(xType, _evalData, ();, FLOAT_TYPES); } - ////////////////////////////////////////////////////////////////////////// template HHsequence BiDiagonalUp::makeHHsequence_(const char type) const { - - if(type == 'u') { - - const int diagSize = _HHbidiag.sizeAt(0); - auto colOfCoeffs = NDArrayFactory::create(_HHmatrix.ordering(), {diagSize, 1}, _HHmatrix.dataType(), _HHmatrix.getContext()); - - for(int i = 0; i < diagSize; ++i) - colOfCoeffs.p(i, _HHmatrix.e(i,i)); - - return HHsequence(_HHmatrix, colOfCoeffs, type); - } - else { - - const int diagUpSize = _HHbidiag.sizeAt(0) - 1; - NDArray colOfCoeffs = NDArrayFactory::create(_HHmatrix.ordering(), {diagUpSize, 1}, _HHmatrix.dataType(), _HHmatrix.getContext()); - - for(int i = 0; i < diagUpSize; ++i) - colOfCoeffs.p(i, _HHmatrix.e(i,i+1)); - - HHsequence result(_HHmatrix, colOfCoeffs, type); - result._diagSize = diagUpSize; - result._shift = 1; - - return result; - } + if (type == 'u') { + const int diagSize = _HHbidiag.sizeAt(0); + auto colOfCoeffs = + NDArrayFactory::create(_HHmatrix.ordering(), {diagSize, 1}, + _HHmatrix.dataType(), _HHmatrix.getContext()); + + for (int i = 0; i < diagSize; ++i) colOfCoeffs.p(i, _HHmatrix.e(i, i)); + + return HHsequence(_HHmatrix, colOfCoeffs, type); + } else { + const int diagUpSize = _HHbidiag.sizeAt(0) - 1; + NDArray colOfCoeffs = + NDArrayFactory::create(_HHmatrix.ordering(), {diagUpSize, 1}, + _HHmatrix.dataType(), _HHmatrix.getContext()); + + for (int i = 0; i < diagUpSize; ++i) + colOfCoeffs.p(i, _HHmatrix.e(i, i + 1)); + + HHsequence result(_HHmatrix, colOfCoeffs, type); + result._diagSize = diagUpSize; + result._shift = 1; + + return result; + } } - HHsequence BiDiagonalUp::makeHHsequence(const char type) const { - auto xType = _HHmatrix.dataType(); - - BUILD_SINGLE_SELECTOR(xType, return makeHHsequence_, (type);, FLOAT_TYPES); - } - +HHsequence BiDiagonalUp::makeHHsequence(const char type) const { + auto xType = _HHmatrix.dataType(); + BUILD_SINGLE_SELECTOR(xType, return makeHHsequence_, (type);, FLOAT_TYPES); +} BUILD_SINGLE_TEMPLATE(template void BiDiagonalUp::_evalData, (), FLOAT_TYPES); -BUILD_SINGLE_TEMPLATE(template HHsequence BiDiagonalUp::makeHHsequence_, (const char type) const, FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template HHsequence BiDiagonalUp::makeHHsequence_, + (const char type) const, FLOAT_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/cublasHelper.cpp b/libnd4j/include/helpers/cpu/cublasHelper.cpp index f6f718702e39..6def885b666d 100644 --- a/libnd4j/include/helpers/cpu/cublasHelper.cpp +++ b/libnd4j/include/helpers/cpu/cublasHelper.cpp @@ -21,41 +21,25 @@ #include "../cublasHelper.h" namespace sd { - static void* handle_() { - return nullptr; - } +static void* handle_() { return nullptr; } - static void destroyHandle_(void* handle) { +static void destroyHandle_(void* handle) {} - } +CublasHelper::CublasHelper() {} - CublasHelper::CublasHelper() { +CublasHelper::~CublasHelper() {} - } +CublasHelper* CublasHelper::getInstance() { + if (!_INSTANCE) _INSTANCE = new sd::CublasHelper(); - CublasHelper::~CublasHelper() { + return _INSTANCE; +} - } +void* CublasHelper::handle() { return nullptr; } - CublasHelper* CublasHelper::getInstance() { - if (!_INSTANCE) - _INSTANCE = new sd::CublasHelper(); +void* CublasHelper::solver() { return nullptr; } - return _INSTANCE; - } +void* CublasHelper::handle(int deviceId) { return nullptr; } - void* CublasHelper::handle() { - return nullptr; - } - - void* CublasHelper::solver() { - return nullptr; - } - - void* CublasHelper::handle(int deviceId) { - return nullptr; - } - - - sd::CublasHelper* sd::CublasHelper::_INSTANCE = 0; -} \ No newline at end of file +sd::CublasHelper* sd::CublasHelper::_INSTANCE = 0; +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/hhColPivQR.cpp b/libnd4j/include/helpers/cpu/hhColPivQR.cpp index 1a57749d0ebe..1a2cd41f53ca 100644 --- a/libnd4j/include/helpers/cpu/hhColPivQR.cpp +++ b/libnd4j/include/helpers/cpu/hhColPivQR.cpp @@ -18,149 +18,152 @@ // Created by Yurii Shyrma on 11.01.2018 // +#include #include #include -#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// HHcolPivQR::HHcolPivQR(const NDArray& matrix) { + _qr = matrix; + _diagSize = math::nd4j_min(matrix.sizeAt(0), matrix.sizeAt(1)); + _coeffs = NDArrayFactory::create(matrix.ordering(), {1, _diagSize}, + matrix.dataType(), matrix.getContext()); - _qr = matrix; - _diagSize = math::nd4j_min(matrix.sizeAt(0), matrix.sizeAt(1)); - _coeffs = NDArrayFactory::create(matrix.ordering(), {1, _diagSize}, matrix.dataType(), matrix.getContext()); - - _permut = NDArrayFactory::create(matrix.ordering(), {matrix.sizeAt(1), matrix.sizeAt(1)}, matrix.dataType(), matrix.getContext()); + _permut = NDArrayFactory::create(matrix.ordering(), + {matrix.sizeAt(1), matrix.sizeAt(1)}, + matrix.dataType(), matrix.getContext()); - evalData(); + evalData(); } - void HHcolPivQR::evalData() { - BUILD_SINGLE_SELECTOR(_qr.dataType(), _evalData, (), FLOAT_TYPES); - } +void HHcolPivQR::evalData() { + BUILD_SINGLE_SELECTOR(_qr.dataType(), _evalData, (), FLOAT_TYPES); +} ////////////////////////////////////////////////////////////////////////// template void HHcolPivQR::_evalData() { - - int rows = _qr.sizeAt(0); - int cols = _qr.sizeAt(1); - - auto transp = NDArrayFactory::create(_qr.ordering(), {1, cols}, _qr.dataType(), _qr.getContext()); - auto normsUpd = NDArrayFactory::create(_qr.ordering(), {1, cols}, _qr.dataType(), _qr.getContext()); - auto normsDir = NDArrayFactory::create(_qr.ordering(), {1, cols}, _qr.dataType(), _qr.getContext()); - - int transpNum = 0; - - for (int k = 0; k < cols; ++k) { - - T norm = _qr({0,0, k,k+1}).reduceNumber(reduce::Norm2).e(0); - normsDir.p(k, norm); - normsUpd.p(k, norm); + int rows = _qr.sizeAt(0); + int cols = _qr.sizeAt(1); + + auto transp = NDArrayFactory::create(_qr.ordering(), {1, cols}, + _qr.dataType(), _qr.getContext()); + auto normsUpd = NDArrayFactory::create(_qr.ordering(), {1, cols}, + _qr.dataType(), _qr.getContext()); + auto normsDir = NDArrayFactory::create(_qr.ordering(), {1, cols}, + _qr.dataType(), _qr.getContext()); + + int transpNum = 0; + + for (int k = 0; k < cols; ++k) { + T norm = _qr({0, 0, k, k + 1}).reduceNumber(reduce::Norm2).e(0); + normsDir.p(k, norm); + normsUpd.p(k, norm); + } + + T normScaled = + (normsUpd.reduceNumber(reduce::Max)).e(0) * DataTypeUtils::eps(); + T threshold1 = normScaled * normScaled / (T)rows; + T threshold2 = math::nd4j_sqrt(DataTypeUtils::eps()); + + T nonZeroPivots = _diagSize; + T maxPivot = 0.; + + for (int k = 0; k < _diagSize; ++k) { + int biggestColIndex = normsUpd({0, 0, k, -1}) + .indexReduceNumber(indexreduce::IndexMax) + .e(0); + T biggestColNorm = + normsUpd({0, 0, k, -1}).reduceNumber(reduce::Max).e(0); + T biggestColSqNorm = biggestColNorm * biggestColNorm; + biggestColIndex += k; + + if (nonZeroPivots == (T)_diagSize && + biggestColSqNorm < threshold1 * (T)(rows - k)) + nonZeroPivots = k; + + transp.p(k, (T)biggestColIndex); + + if (k != biggestColIndex) { + auto temp1 = _qr({0, 0, k, k + 1}, true); + auto temp2 = _qr({0, 0, biggestColIndex, biggestColIndex + 1}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); + + T e0 = normsUpd.e(k); + T e1 = normsUpd.e(biggestColIndex); + normsUpd.p(k, e1); + normsUpd.p(biggestColIndex, e0); + // math::nd4j_swap(normsUpd(k), normsUpd(biggestColIndex)); + + e0 = normsDir.e(k); + e1 = normsDir.e(biggestColIndex); + normsDir.p(k, e1); + normsDir.p(biggestColIndex, e0); + // math::nd4j_swap(normsDir(k), normsDir(biggestColIndex)); + + ++transpNum; } - T normScaled = (normsUpd.reduceNumber(reduce::Max)).e(0) * DataTypeUtils::eps(); - T threshold1 = normScaled * normScaled / (T)rows; - T threshold2 = math::nd4j_sqrt(DataTypeUtils::eps()); - - T nonZeroPivots = _diagSize; - T maxPivot = 0.; - - for(int k = 0; k < _diagSize; ++k) { - - int biggestColIndex = normsUpd({0,0, k,-1}).indexReduceNumber(indexreduce::IndexMax).e(0); - T biggestColNorm = normsUpd({0,0, k,-1}).reduceNumber(reduce::Max).e(0); - T biggestColSqNorm = biggestColNorm * biggestColNorm; - biggestColIndex += k; - - if(nonZeroPivots == (T)_diagSize && biggestColSqNorm < threshold1 * (T)(rows-k)) - nonZeroPivots = k; - - transp.p(k, (T)biggestColIndex); - - if(k != biggestColIndex) { - - auto temp1 = _qr({0,0, k,k+1}, true); - auto temp2 = _qr({0,0, biggestColIndex,biggestColIndex+1}, true); - auto temp3 = temp1.dup(); - temp1.assign(temp2); - temp2.assign(temp3); - - T e0 = normsUpd.e(k); - T e1 = normsUpd.e(biggestColIndex); - normsUpd.p(k, e1); - normsUpd.p(biggestColIndex, e0); - //math::nd4j_swap(normsUpd(k), normsUpd(biggestColIndex)); - - e0 = normsDir.e(k); - e1 = normsDir.e(biggestColIndex); - normsDir.p(k, e1); - normsDir.p(biggestColIndex, e0); - //math::nd4j_swap(normsDir(k), normsDir(biggestColIndex)); - - ++transpNum; - } - - T normX; - { - auto qrBlock = _qr({k, rows, k, k + 1}, true); - T c; - Householder::evalHHmatrixDataI(qrBlock, c, normX); - _coeffs.p(k, c); - } - - _qr.p(k,k, normX); - - T max = math::nd4j_abs(normX); - if(max > maxPivot) - maxPivot = max; - - if(k < rows && (k+1) < cols) { - auto qrBlock = _qr({k, rows, k+1,cols}, true); - auto tail = _qr({k+1,rows, k, k+1}, true); - Householder::mulLeft(qrBlock, tail, _coeffs.e(k)); - } - - for (int j = k + 1; j < cols; ++j) { - - if (normsUpd.e(j) != (T)0.f) { - T temp = math::nd4j_abs(_qr.e(k, j)) / normsUpd.e(j); - temp = (1. + temp) * (1. - temp); - temp = temp < (T)0. ? (T)0. : temp; - T temp2 = temp * normsUpd.e(j) * normsUpd.e(j) / (normsDir.e(j)*normsDir.e(j)); - - if (temp2 <= threshold2) { - if(k+1 < rows && j < cols) - normsDir.p(j, _qr({k+1,rows, j,j+1}).reduceNumber(reduce::Norm2).e(0)); - - normsUpd.p(j, normsDir.e(j)); - } - else - normsUpd.p(j, normsUpd.e(j) * math::nd4j_sqrt(temp)); - } - } + T normX; + { + auto qrBlock = _qr({k, rows, k, k + 1}, true); + T c; + Householder::evalHHmatrixDataI(qrBlock, c, normX); + _coeffs.p(k, c); } - _permut.setIdentity(); - - for(int k = 0; k < _diagSize; ++k) { - - int idx = transp.e(k); - auto temp1 = _permut({0,0, k, k+1}, true); - auto temp2 = _permut({0,0, idx,idx+1}, true); - auto temp3 = temp1.dup(); - temp1.assign(temp2); - temp2.assign(temp3); - } -} + _qr.p(k, k, normX); - BUILD_SINGLE_TEMPLATE(template void HHcolPivQR::_evalData, (), FLOAT_TYPES); + T max = math::nd4j_abs(normX); + if (max > maxPivot) maxPivot = max; + if (k < rows && (k + 1) < cols) { + auto qrBlock = _qr({k, rows, k + 1, cols}, true); + auto tail = _qr({k + 1, rows, k, k + 1}, true); + Householder::mulLeft(qrBlock, tail, _coeffs.e(k)); + } + + for (int j = k + 1; j < cols; ++j) { + if (normsUpd.e(j) != (T)0.f) { + T temp = math::nd4j_abs(_qr.e(k, j)) / normsUpd.e(j); + temp = (1. + temp) * (1. - temp); + temp = temp < (T)0. ? (T)0. : temp; + T temp2 = temp * normsUpd.e(j) * normsUpd.e(j) / + (normsDir.e(j) * normsDir.e(j)); + + if (temp2 <= threshold2) { + if (k + 1 < rows && j < cols) + normsDir.p(j, _qr({k + 1, rows, j, j + 1}) + .reduceNumber(reduce::Norm2) + .e(0)); + + normsUpd.p(j, normsDir.e(j)); + } else + normsUpd.p(j, normsUpd.e(j) * math::nd4j_sqrt(temp)); + } + } + } + + _permut.setIdentity(); + + for (int k = 0; k < _diagSize; ++k) { + int idx = transp.e(k); + auto temp1 = _permut({0, 0, k, k + 1}, true); + auto temp2 = _permut({0, 0, idx, idx + 1}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); + } } -} -} +BUILD_SINGLE_TEMPLATE(template void HHcolPivQR::_evalData, (), FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/hhSequence.cpp b/libnd4j/include/helpers/cpu/hhSequence.cpp index 84c9e017a922..8a53024723de 100644 --- a/libnd4j/include/helpers/cpu/hhSequence.cpp +++ b/libnd4j/include/helpers/cpu/hhSequence.cpp @@ -18,106 +18,99 @@ // Created by Yurii Shyrma on 02.01.2018 // +#include #include #include -#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// -HHsequence::HHsequence(const NDArray& vectors, const NDArray& coeffs, const char type): _vectors(vectors), _coeffs(coeffs) { - - _diagSize = sd::math::nd4j_min(_vectors.sizeAt(0), _vectors.sizeAt(1)); - _shift = 0; - _type = type; +HHsequence::HHsequence(const NDArray& vectors, const NDArray& coeffs, + const char type) + : _vectors(vectors), _coeffs(coeffs) { + _diagSize = sd::math::nd4j_min(_vectors.sizeAt(0), _vectors.sizeAt(1)); + _shift = 0; + _type = type; } ////////////////////////////////////////////////////////////////////////// template void HHsequence::_mulLeft(NDArray& matrix) { - - const int rows = _vectors.sizeAt(0); - const int cols = _vectors.sizeAt(1); - const int inRows = matrix.sizeAt(0); - - NDArray block; - - for(int i = _diagSize - 1; i >= 0; --i) { - - if(_type == 'u') { - - block = matrix({inRows-rows+_shift+ i,inRows, 0,0}, true); - T _x = _coeffs.e(i); - Householder::mulLeft(block, _vectors({i + 1 + _shift, rows, i, i+1}, true), _x); - _coeffs.p(i, _x); - } - else { - - block = matrix({inRows-cols+_shift+i,inRows, 0,0}, true); - T _x = _coeffs.e(i); - Householder::mulLeft(block, _vectors({i, i+1, i + 1 + _shift, cols}, true), _x); - _coeffs.p(i, _x); - } - + const int rows = _vectors.sizeAt(0); + const int cols = _vectors.sizeAt(1); + const int inRows = matrix.sizeAt(0); + + NDArray block; + + for (int i = _diagSize - 1; i >= 0; --i) { + if (_type == 'u') { + block = matrix({inRows - rows + _shift + i, inRows, 0, 0}, true); + T _x = _coeffs.e(i); + Householder::mulLeft( + block, _vectors({i + 1 + _shift, rows, i, i + 1}, true), _x); + _coeffs.p(i, _x); + } else { + block = matrix({inRows - cols + _shift + i, inRows, 0, 0}, true); + T _x = _coeffs.e(i); + Householder::mulLeft( + block, _vectors({i, i + 1, i + 1 + _shift, cols}, true), _x); + _coeffs.p(i, _x); } + } } - ////////////////////////////////////////////////////////////////////////// NDArray HHsequence::getTail(const int idx) const { + int first = idx + 1 + _shift; - - int first = idx + 1 + _shift; - - if(_type == 'u') - return _vectors({first, -1, idx, idx+1}, true); - else - return _vectors({idx, idx+1, first, -1}, true); + if (_type == 'u') + return _vectors({first, -1, idx, idx + 1}, true); + else + return _vectors({idx, idx + 1, first, -1}, true); } - ////////////////////////////////////////////////////////////////////////// template void HHsequence::_applyTo(NDArray& dest) { - - int size = _type == 'u' ? _vectors.sizeAt(0) : _vectors.sizeAt(1); - - if(dest.rankOf() != 2 || (dest.sizeAt(0) != size && dest.sizeAt(1) != size)) - dest = NDArrayFactory::create(dest.ordering(), {size, size}, dest.dataType(), dest.getContext()); - dest.setIdentity(); - - for(int k = _diagSize - 1; k >= 0; --k) { - - int curNum = size - k - _shift; - if(curNum < 1 || (k + 1 + _shift) >= size ) - continue; - auto block = dest({dest.sizeAt(0)-curNum,dest.sizeAt(0), dest.sizeAt(1)-curNum,dest.sizeAt(1)}, true); - T _x = _coeffs.e(k); - - Householder::mulLeft(block, getTail(k), _x); - - _coeffs.p(k, _x); - } -} + int size = _type == 'u' ? _vectors.sizeAt(0) : _vectors.sizeAt(1); + if (dest.rankOf() != 2 || (dest.sizeAt(0) != size && dest.sizeAt(1) != size)) + dest = NDArrayFactory::create(dest.ordering(), {size, size}, + dest.dataType(), dest.getContext()); + dest.setIdentity(); - void HHsequence::applyTo(NDArray& dest) { - auto xType = _coeffs.dataType(); + for (int k = _diagSize - 1; k >= 0; --k) { + int curNum = size - k - _shift; + if (curNum < 1 || (k + 1 + _shift) >= size) continue; + auto block = dest({dest.sizeAt(0) - curNum, dest.sizeAt(0), + dest.sizeAt(1) - curNum, dest.sizeAt(1)}, + true); + T _x = _coeffs.e(k); - BUILD_SINGLE_SELECTOR(xType, _applyTo, (dest), FLOAT_TYPES); - } + Householder::mulLeft(block, getTail(k), _x); - void HHsequence::mulLeft(NDArray& matrix) { - auto xType = _coeffs.dataType(); + _coeffs.p(k, _x); + } +} - BUILD_SINGLE_SELECTOR(xType, _mulLeft, (matrix), FLOAT_TYPES); - } +void HHsequence::applyTo(NDArray& dest) { + auto xType = _coeffs.dataType(); - BUILD_SINGLE_TEMPLATE(template void HHsequence::_applyTo, (sd::NDArray &dest), FLOAT_TYPES); - BUILD_SINGLE_TEMPLATE(template void HHsequence::_mulLeft, (NDArray& matrix), FLOAT_TYPES); -} + BUILD_SINGLE_SELECTOR(xType, _applyTo, (dest), FLOAT_TYPES); } + +void HHsequence::mulLeft(NDArray& matrix) { + auto xType = _coeffs.dataType(); + + BUILD_SINGLE_SELECTOR(xType, _mulLeft, (matrix), FLOAT_TYPES); } + +BUILD_SINGLE_TEMPLATE(template void HHsequence::_applyTo, (sd::NDArray & dest), + FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void HHsequence::_mulLeft, (NDArray & matrix), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/householder.cpp b/libnd4j/include/helpers/cpu/householder.cpp index 090cea17038c..538dbbd9d66a 100644 --- a/libnd4j/include/helpers/cpu/householder.cpp +++ b/libnd4j/include/helpers/cpu/householder.cpp @@ -18,201 +18,197 @@ // Created by Yurii Shyrma on 18.12.2017 // -#include #include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// template NDArray Householder::evalHHmatrix(const NDArray& x) { - - // input validation - if(!x.isVector() && !x.isScalar()) - throw std::runtime_error("ops::helpers::Householder::evalHHmatrix method: input array must be vector or scalar!"); - - auto w = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), 1}, x.dataType(), x.getContext()); // column-vector - auto wT = NDArrayFactory::create(x.ordering(), {1, (int)x.lengthOf()}, x.dataType(), x.getContext()); // row-vector (transposed w) - - T coeff; - T normX = x.reduceNumber(reduce::Norm2).e(0); - - if(normX*normX - x.e(0) * x.e(0) <= DataTypeUtils::min() || x.lengthOf() == 1) { - - normX = x.e(0); - coeff = 0.f; - w = 0.f; - - } - else { - - if(x.e(0) >= (T)0.f) - normX = -normX; // choose opposite sign to lessen roundoff error - - T u0 = x.e(0) - normX; - coeff = -u0 / normX; - w.assign(x / u0); - } - - w.p(Nd4jLong(0), 1.f); - wT.assign(&w); - - NDArray identity = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), (int)x.lengthOf()}, x.dataType(), x.getContext()); - identity.setIdentity(); // identity matrix - - return identity - mmul(w, wT) * coeff; + // input validation + if (!x.isVector() && !x.isScalar()) + throw std::runtime_error( + "ops::helpers::Householder::evalHHmatrix method: input array must be " + "vector or scalar!"); + + auto w = + NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), 1}, x.dataType(), + x.getContext()); // column-vector + auto wT = + NDArrayFactory::create(x.ordering(), {1, (int)x.lengthOf()}, x.dataType(), + x.getContext()); // row-vector (transposed w) + + T coeff; + T normX = x.reduceNumber(reduce::Norm2).e(0); + + if (normX * normX - x.e(0) * x.e(0) <= DataTypeUtils::min() || + x.lengthOf() == 1) { + normX = x.e(0); + coeff = 0.f; + w = 0.f; + + } else { + if (x.e(0) >= (T)0.f) + normX = -normX; // choose opposite sign to lessen roundoff error + + T u0 = x.e(0) - normX; + coeff = -u0 / normX; + w.assign(x / u0); + } + + w.p(Nd4jLong(0), 1.f); + wT.assign(&w); + + NDArray identity = NDArrayFactory::create( + x.ordering(), {(int)x.lengthOf(), (int)x.lengthOf()}, x.dataType(), + x.getContext()); + identity.setIdentity(); // identity matrix + + return identity - mmul(w, wT) * coeff; } ////////////////////////////////////////////////////////////////////////// template -void Householder::evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff, T& normX) { - - // input validation - if(!x.isVector() && !x.isScalar()) - throw std::runtime_error("ops::helpers::Householder::evalHHmatrixData method: input array must be vector or scalar!"); - - if(!x.isScalar() && x.lengthOf() != tail.lengthOf() + 1) - throw std::runtime_error("ops::helpers::Householder::evalHHmatrixData method: input tail vector must have length less than unity compared to input x vector!"); - - normX = x.reduceNumber(reduce::Norm2, nullptr).e(0); - - if(normX*normX - x.e(0) * x.e(0) <= DataTypeUtils::min() || x.lengthOf() == 1) { - - normX = x.e(0); - coeff = (T)0.f; - tail = (T)0.f; - } - else { - - if(x.e(0) >= (T)0.f) - normX = -normX; // choose opposite sign to lessen roundoff error - - T u0 = x.e(0) - normX; - coeff = -u0 / normX; - - if(x.isRowVector()) - tail.assign(static_cast(x({0,0, 1,-1})) / u0); - else - tail.assign(static_cast(x({1,-1, 0,0,})) / u0); - } +void Householder::evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff, + T& normX) { + // input validation + if (!x.isVector() && !x.isScalar()) + throw std::runtime_error( + "ops::helpers::Householder::evalHHmatrixData method: input array must " + "be vector or scalar!"); + + if (!x.isScalar() && x.lengthOf() != tail.lengthOf() + 1) + throw std::runtime_error( + "ops::helpers::Householder::evalHHmatrixData method: input tail vector " + "must have length less than unity compared to input x vector!"); + + normX = x.reduceNumber(reduce::Norm2, nullptr).e(0); + + if (normX * normX - x.e(0) * x.e(0) <= DataTypeUtils::min() || + x.lengthOf() == 1) { + normX = x.e(0); + coeff = (T)0.f; + tail = (T)0.f; + } else { + if (x.e(0) >= (T)0.f) + normX = -normX; // choose opposite sign to lessen roundoff error + + T u0 = x.e(0) - normX; + coeff = -u0 / normX; + + if (x.isRowVector()) + tail.assign(static_cast(x({0, 0, 1, -1})) / u0); + else + tail.assign(static_cast(x({ + 1, + -1, + 0, + 0, + })) / + u0); + } } ////////////////////////////////////////////////////////////////////////// template void Householder::evalHHmatrixDataI(const NDArray& x, T& coeff, T& normX) { - - int rows = (int)x.lengthOf()-1; - int num = 1; - - if(rows == 0) { - rows = 1; - num = 0; - } - - auto tail = NDArrayFactory::create(x.ordering(), {rows, 1}, x.dataType(), x.getContext()); - evalHHmatrixData(x, tail, coeff, normX); - - if(x.isRowVector()) { - auto temp = x({0,0, num, x.sizeAt(1)}, true); - temp.assign(tail); - } - else { - auto temp = x({num,x.sizeAt(0), 0,0}, true); - temp.assign(tail); - } + int rows = (int)x.lengthOf() - 1; + int num = 1; + + if (rows == 0) { + rows = 1; + num = 0; + } + + auto tail = NDArrayFactory::create(x.ordering(), {rows, 1}, x.dataType(), + x.getContext()); + evalHHmatrixData(x, tail, coeff, normX); + + if (x.isRowVector()) { + auto temp = x({0, 0, num, x.sizeAt(1)}, true); + temp.assign(tail); + } else { + auto temp = x({num, x.sizeAt(0), 0, 0}, true); + temp.assign(tail); + } } ////////////////////////////////////////////////////////////////////////// template -void Householder::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff) { - - // if(matrix.rankOf() != 2) - // throw "ops::helpers::Householder::mulLeft method: input array must be 2D matrix !"; - - if(matrix.sizeAt(0) == 1) { - matrix *= (T) 1.f - coeff; +void Householder::mulLeft(NDArray& matrix, const NDArray& tail, + const T coeff) { + // if(matrix.rankOf() != 2) + // throw "ops::helpers::Householder::mulLeft method: input array must be 2D + // matrix !"; + + if (matrix.sizeAt(0) == 1) { + matrix *= (T)1.f - coeff; + } else if (coeff != (T)0.f) { + auto bottomPart = matrix({1, matrix.sizeAt(0), 0, 0}, true); + auto bottomPartCopy = bottomPart.dup(); + + if (tail.isColumnVector()) { + auto column = tail; + auto row = tail.transpose(); + auto resultingRow = mmul(row, bottomPartCopy); + auto fistRow = matrix({0, 1, 0, 0}, true); + resultingRow += fistRow; + fistRow -= resultingRow * coeff; + bottomPart -= mmul(column, resultingRow) * coeff; + } else { + auto row = tail; + auto column = tail.transpose(); + auto resultingRow = mmul(row, bottomPartCopy); + auto fistRow = matrix({0, 1, 0, 0}, true); + resultingRow += fistRow; + fistRow -= resultingRow * coeff; + bottomPart -= mmul(column, resultingRow) * coeff; } - else if(coeff != (T)0.f) { - - auto bottomPart = matrix({1,matrix.sizeAt(0), 0,0}, true); - auto bottomPartCopy = bottomPart.dup(); - - if(tail.isColumnVector()) { - - auto column = tail; - auto row = tail.transpose(); - auto resultingRow = mmul(row, bottomPartCopy); - auto fistRow = matrix({0,1, 0,0}, true); - resultingRow += fistRow; - fistRow -= resultingRow * coeff; - bottomPart -= mmul(column, resultingRow) * coeff; - } - else { - - auto row = tail; - auto column = tail.transpose(); - auto resultingRow = mmul(row, bottomPartCopy); - auto fistRow = matrix({0,1, 0,0}, true); - resultingRow += fistRow; - fistRow -= resultingRow * coeff; - bottomPart -= mmul(column, resultingRow) * coeff; - } - } + } } - ////////////////////////////////////////////////////////////////////////// template -void Householder::mulRight(NDArray& matrix, const NDArray& tail, const T coeff) { - - // if(matrix.rankOf() != 2) - // throw "ops::helpers::Householder::mulRight method: input array must be 2D matrix !"; - - if(matrix.sizeAt(1) == 1) - matrix *= (T)1.f - coeff; - - else if(coeff != (T)0.f) { - - auto rightPart = matrix({0,0, 1,matrix.sizeAt(1)}, true); - auto rightPartCopy = rightPart.dup(); - auto fistCol = matrix({0,0, 0,1}, true); - - if(tail.isColumnVector()) { - - auto column = tail; - auto row = tail.transpose(); - auto resultingCol = mmul(rightPartCopy, column); - resultingCol += fistCol; - fistCol -= resultingCol * coeff; - rightPart -= mmul(resultingCol, row) * coeff; - } - else { - - auto row = tail; - auto column = tail.transpose(); - auto resultingCol = mmul(rightPartCopy, column); - resultingCol += fistCol; - fistCol -= resultingCol * coeff; - rightPart -= mmul(resultingCol, row) * coeff; - } - } +void Householder::mulRight(NDArray& matrix, const NDArray& tail, + const T coeff) { + // if(matrix.rankOf() != 2) + // throw "ops::helpers::Householder::mulRight method: input array must be + // 2D matrix !"; + + if (matrix.sizeAt(1) == 1) + matrix *= (T)1.f - coeff; + + else if (coeff != (T)0.f) { + auto rightPart = matrix({0, 0, 1, matrix.sizeAt(1)}, true); + auto rightPartCopy = rightPart.dup(); + auto fistCol = matrix({0, 0, 0, 1}, true); + + if (tail.isColumnVector()) { + auto column = tail; + auto row = tail.transpose(); + auto resultingCol = mmul(rightPartCopy, column); + resultingCol += fistCol; + fistCol -= resultingCol * coeff; + rightPart -= mmul(resultingCol, row) * coeff; + } else { + auto row = tail; + auto column = tail.transpose(); + auto resultingCol = mmul(rightPartCopy, column); + resultingCol += fistCol; + fistCol -= resultingCol * coeff; + rightPart -= mmul(resultingCol, row) * coeff; + } + } } - template class SD_EXPORT Householder; template class SD_EXPORT Householder; template class SD_EXPORT Householder; template class SD_EXPORT Householder; - - - - - - -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/jacobiSVD.cpp b/libnd4j/include/helpers/cpu/jacobiSVD.cpp index 8c6d1ccc7b70..5d403c5456bf 100644 --- a/libnd4j/include/helpers/cpu/jacobiSVD.cpp +++ b/libnd4j/include/helpers/cpu/jacobiSVD.cpp @@ -18,412 +18,401 @@ // Created by Yurii Shyrma on 11.01.2018 // -#include -#include #include - +#include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// template -JacobiSVD::JacobiSVD(const NDArray& matrix, const bool calcU, const bool calcV, const bool fullUV) { - - if(matrix.rankOf() != 2 || matrix.isScalar()) - throw std::runtime_error("ops::helpers::JacobiSVD constructor: input array must be 2D matrix !"); - - _rows = static_cast(matrix.sizeAt(0)); - _cols = static_cast(matrix.sizeAt(1)); - _diagSize = math::nd4j_min(_rows, _cols); - - _calcU = calcU; - _calcV = calcV; - _fullUV = fullUV; - - _s = NDArrayFactory::create(matrix.ordering(), {_diagSize, 1}, matrix.dataType(), matrix.getContext()); - - if(_calcU) { - if(_fullUV) - _u = NDArrayFactory::create(matrix.ordering(), {_rows, _rows}, matrix.dataType(), matrix.getContext()); - else - _u = NDArrayFactory::create(matrix.ordering(), {_rows, _diagSize}, matrix.dataType(), matrix.getContext()); - } +JacobiSVD::JacobiSVD(const NDArray& matrix, const bool calcU, + const bool calcV, const bool fullUV) { + if (matrix.rankOf() != 2 || matrix.isScalar()) + throw std::runtime_error( + "ops::helpers::JacobiSVD constructor: input array must be 2D matrix !"); + + _rows = static_cast(matrix.sizeAt(0)); + _cols = static_cast(matrix.sizeAt(1)); + _diagSize = math::nd4j_min(_rows, _cols); + + _calcU = calcU; + _calcV = calcV; + _fullUV = fullUV; + + _s = NDArrayFactory::create(matrix.ordering(), {_diagSize, 1}, + matrix.dataType(), matrix.getContext()); + + if (_calcU) { + if (_fullUV) + _u = NDArrayFactory::create(matrix.ordering(), {_rows, _rows}, + matrix.dataType(), matrix.getContext()); else - _u = NDArrayFactory::create(matrix.ordering(), {_rows, 1}, matrix.dataType(), matrix.getContext()); - - if(_calcV) { - if(_fullUV) - _v = NDArrayFactory::create(matrix.ordering(), {_cols, _cols}, matrix.dataType(), matrix.getContext()); - else - _v = NDArrayFactory::create(matrix.ordering(), {_cols, _diagSize}, matrix.dataType(), matrix.getContext()); - } + _u = NDArrayFactory::create(matrix.ordering(), {_rows, _diagSize}, + matrix.dataType(), matrix.getContext()); + } else + _u = NDArrayFactory::create(matrix.ordering(), {_rows, 1}, + matrix.dataType(), matrix.getContext()); + + if (_calcV) { + if (_fullUV) + _v = NDArrayFactory::create(matrix.ordering(), {_cols, _cols}, + matrix.dataType(), matrix.getContext()); else - _v = NDArrayFactory::create(matrix.ordering(), {_cols, 1}, matrix.dataType(), matrix.getContext()); + _v = NDArrayFactory::create(matrix.ordering(), {_cols, _diagSize}, + matrix.dataType(), matrix.getContext()); + } else + _v = NDArrayFactory::create(matrix.ordering(), {_cols, 1}, + matrix.dataType(), matrix.getContext()); - _m = NDArrayFactory::create(matrix.ordering(), {_diagSize, _diagSize}, matrix.dataType(), matrix.getContext()); + _m = NDArrayFactory::create(matrix.ordering(), {_diagSize, _diagSize}, + matrix.dataType(), matrix.getContext()); - evalData(matrix); + evalData(matrix); } ////////////////////////////////////////////////////////////////////////// template -void JacobiSVD::mulRotationOnLeft(const int i, const int j, NDArray& block, const NDArray& rotation) { - - if(i < j) { - - if(j+1 > block.sizeAt(0)) - throw std::runtime_error("ops::helpers::JacobiSVD mulRotationOnLeft: second arguments is out of array row range !"); - - auto pTemp = block({i,j+1,j-i, 0,0,0}, true, true); - auto temp = pTemp.dup(); - pTemp.assign(mmul(rotation, temp)); - } - else { - - if(j+1 > block.sizeAt(0) || i+1 > block.sizeAt(0)) - throw std::runtime_error("ops::helpers::JacobiSVD mulRotationOnLeft: some or both integer arguments are out of array row range !"); - - auto temp = NDArrayFactory::create(block.ordering(), {2, block.sizeAt(1)}, block.dataType(), block.getContext()); - auto row1 = block({i,i+1, 0,0}, true); - auto row2 = block({j,j+1, 0,0}, true); - auto rowTemp1 = temp({0,1, 0,0}, true); - auto rowTemp2 = temp({1,2, 0,0}, true); - rowTemp1.assign(row1); - rowTemp2.assign(row2); - temp.assign(mmul(rotation, temp)); - row1.assign(rowTemp1); - row2.assign(rowTemp2); - } +void JacobiSVD::mulRotationOnLeft(const int i, const int j, NDArray& block, + const NDArray& rotation) { + if (i < j) { + if (j + 1 > block.sizeAt(0)) + throw std::runtime_error( + "ops::helpers::JacobiSVD mulRotationOnLeft: second arguments is out " + "of array row range !"); + + auto pTemp = block({i, j + 1, j - i, 0, 0, 0}, true, true); + auto temp = pTemp.dup(); + pTemp.assign(mmul(rotation, temp)); + } else { + if (j + 1 > block.sizeAt(0) || i + 1 > block.sizeAt(0)) + throw std::runtime_error( + "ops::helpers::JacobiSVD mulRotationOnLeft: some or both integer " + "arguments are out of array row range !"); + + auto temp = NDArrayFactory::create(block.ordering(), {2, block.sizeAt(1)}, + block.dataType(), block.getContext()); + auto row1 = block({i, i + 1, 0, 0}, true); + auto row2 = block({j, j + 1, 0, 0}, true); + auto rowTemp1 = temp({0, 1, 0, 0}, true); + auto rowTemp2 = temp({1, 2, 0, 0}, true); + rowTemp1.assign(row1); + rowTemp2.assign(row2); + temp.assign(mmul(rotation, temp)); + row1.assign(rowTemp1); + row2.assign(rowTemp2); + } } ////////////////////////////////////////////////////////////////////////// template -void JacobiSVD::mulRotationOnRight(const int i, const int j, NDArray& block, const NDArray& rotation) { - - if(i < j) { - - if(j+1 > block.sizeAt(1)) - throw std::runtime_error("ops::helpers::JacobiSVD mulRotationOnRight: second argument is out of array column range !"); - - auto pTemp = block({0,0,0, i,j+1,j-i}, true, true); - auto temp = pTemp.dup(); - pTemp.assign(mmul(temp, rotation)); - } - else { - - if(j+1 > block.sizeAt(1) || i+1 > block.sizeAt(1)) - throw std::runtime_error("ops::helpers::JacobiSVD mulRotationOnRight: some or both integer arguments are out of array column range !"); - - auto temp = NDArrayFactory::create(block.ordering(), {block.sizeAt(0), 2}, block.dataType(), block.getContext()); - auto col1 = block({0,0, i,i+1}, true); - auto col2 = block({0,0, j,j+1}, true); - auto colTemp1 = temp({0,0, 0,1}, true); - auto colTemp2 = temp({0,0, 1,2}, true); - colTemp1.assign(col1); - colTemp2.assign(col2); - temp.assign(mmul(temp, rotation)); - col1.assign(colTemp1); - col2.assign(colTemp2); - } +void JacobiSVD::mulRotationOnRight(const int i, const int j, NDArray& block, + const NDArray& rotation) { + if (i < j) { + if (j + 1 > block.sizeAt(1)) + throw std::runtime_error( + "ops::helpers::JacobiSVD mulRotationOnRight: second argument is out " + "of array column range !"); + + auto pTemp = block({0, 0, 0, i, j + 1, j - i}, true, true); + auto temp = pTemp.dup(); + pTemp.assign(mmul(temp, rotation)); + } else { + if (j + 1 > block.sizeAt(1) || i + 1 > block.sizeAt(1)) + throw std::runtime_error( + "ops::helpers::JacobiSVD mulRotationOnRight: some or both integer " + "arguments are out of array column range !"); + + auto temp = NDArrayFactory::create(block.ordering(), {block.sizeAt(0), 2}, + block.dataType(), block.getContext()); + auto col1 = block({0, 0, i, i + 1}, true); + auto col2 = block({0, 0, j, j + 1}, true); + auto colTemp1 = temp({0, 0, 0, 1}, true); + auto colTemp2 = temp({0, 0, 1, 2}, true); + colTemp1.assign(col1); + colTemp2.assign(col2); + temp.assign(mmul(temp, rotation)); + col1.assign(colTemp1); + col2.assign(colTemp2); + } } ////////////////////////////////////////////////////////////////////////// template bool JacobiSVD::isBlock2x2NotDiag(NDArray& block, int p, int q, T& maxElem) { + auto rotation = NDArrayFactory::create(_m.ordering(), {2, 2}, _m.dataType(), + _m.getContext()); + T n = math::nd4j_sqrt(block.e(p, p) * block.e(p, p) + + block.e(q, p) * block.e(q, p)); - auto rotation = NDArrayFactory::create(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext()); - T n = math::nd4j_sqrt(block.e(p,p) * block.e(p,p) + block.e(q,p) * block.e(q,p)); - - const T almostZero = DataTypeUtils::min(); - const T precision = DataTypeUtils::eps(); + const T almostZero = DataTypeUtils::min(); + const T precision = DataTypeUtils::eps(); - if(n == (T)0.f) { - block.p(p, p, 0.f); - block.p(q, p, 0.f); - } else { - T v = block.e(p, p) / n; + if (n == (T)0.f) { + block.p(p, p, 0.f); + block.p(q, p, 0.f); + } else { + T v = block.e(p, p) / n; - rotation.p(0, 0, v); - rotation.p(1,1, v); + rotation.p(0, 0, v); + rotation.p(1, 1, v); - v = block.e(q,p) / n; - rotation.p(0, 1, v); + v = block.e(q, p) / n; + rotation.p(0, 1, v); - rotation.p(1,0, -rotation.template e(0, 1)); - mulRotationOnLeft(p, q, block, rotation); + rotation.p(1, 0, -rotation.template e(0, 1)); + mulRotationOnLeft(p, q, block, rotation); - if(_calcU) { - auto temp2 = rotation.transpose(); - mulRotationOnRight(p, q, _u, temp2); - } + if (_calcU) { + auto temp2 = rotation.transpose(); + mulRotationOnRight(p, q, _u, temp2); } + } - maxElem = math::nd4j_max(maxElem, math::nd4j_max(math::nd4j_abs(block.e(p,p)), math::nd4j_abs(block.e(q,q)))); - T threshold = math::nd4j_max(almostZero, precision * maxElem); - const bool condition1 = math::nd4j_abs(block.e(p,q)) > threshold; - const bool condition2 = math::nd4j_abs(block.e(q,p)) > threshold; + maxElem = math::nd4j_max( + maxElem, math::nd4j_max(math::nd4j_abs(block.e(p, p)), + math::nd4j_abs(block.e(q, q)))); + T threshold = math::nd4j_max(almostZero, precision * maxElem); + const bool condition1 = math::nd4j_abs(block.e(p, q)) > threshold; + const bool condition2 = math::nd4j_abs(block.e(q, p)) > threshold; - return condition1 || condition2; + return condition1 || condition2; } ////////////////////////////////////////////////////////////////////////// template -bool JacobiSVD::createJacobiRotation(const T& x, const T& y, const T& z, NDArray& rotation) { - - T denom = 2.* math::nd4j_abs(y); - - if(denom < DataTypeUtils::min()) { - - rotation.p(0,0, 1.f); - rotation.p(1,1, 1.f); - rotation.p(0,1, 0.f); - rotation.p(1,0, 0.f); - return false; - } - else { - - T tau = (x-z)/denom; - T w = math::nd4j_sqrt(tau*tau + 1.); - T t; - - if(tau > (T)0.) - t = 1. / (tau + w); - else - t = 1. / (tau - w); +bool JacobiSVD::createJacobiRotation(const T& x, const T& y, const T& z, + NDArray& rotation) { + T denom = 2. * math::nd4j_abs(y); + + if (denom < DataTypeUtils::min()) { + rotation.p(0, 0, 1.f); + rotation.p(1, 1, 1.f); + rotation.p(0, 1, 0.f); + rotation.p(1, 0, 0.f); + return false; + } else { + T tau = (x - z) / denom; + T w = math::nd4j_sqrt(tau * tau + 1.); + T t; + + if (tau > (T)0.) + t = 1. / (tau + w); + else + t = 1. / (tau - w); - T sign = t > (T)0. ? 1. : -1.; - T n = 1. / math::nd4j_sqrt(t*t + 1.f); - rotation.p(0,0, n); - rotation.p(1,1, n); + T sign = t > (T)0. ? 1. : -1.; + T n = 1. / math::nd4j_sqrt(t * t + 1.f); + rotation.p(0, 0, n); + rotation.p(1, 1, n); - rotation.p(0,1, -sign * (y / math::nd4j_abs(y)) * math::nd4j_abs(t) * n); - rotation.p(1,0, -rotation.e(0,1)); + rotation.p(0, 1, + -sign * (y / math::nd4j_abs(y)) * math::nd4j_abs(t) * n); + rotation.p(1, 0, -rotation.e(0, 1)); - return true; - } + return true; + } } ////////////////////////////////////////////////////////////////////////// template -void JacobiSVD::svd2x2(const NDArray& block, int p, int q, NDArray& left, NDArray& right) { - - auto m = NDArrayFactory::create(block.ordering(), {2, 2}, block.dataType(), block.getContext()); - m.p(0,0, block.e(p,p)); - m.p(0,1, block.e(p,q)); - m.p(1,0, block.e(q,p)); - m.p(1,1, block.e(q,q)); - - auto rotation = NDArrayFactory::create(block.ordering(), {2, 2}, block.dataType(), block.getContext()); - T t = m.e(0,0) + m.e(1,1); - T d = m.e(1,0) - m.e(0,1); - - if(math::nd4j_abs(d) < DataTypeUtils::min()) { - - rotation.p(0,0, 1.f); - rotation.p(1,1, 1.f); - rotation.p(0,1, 0.f); - rotation.p(1,0, 0.f); - } - else { - - T u = t / d; - T tmp = math::nd4j_sqrt(1. + u*u); - rotation.p(0,0, u / tmp); - rotation.p(1,1, u / tmp); - rotation.p(0,1, 1.f / tmp); - rotation.p(1,0, -rotation.e(0,1)); - } - - m.assign(mmul(rotation, m)); - - auto _x = m.e(0,0); - auto _y = m.e(0,1); - auto _z = m.e(1,1); - - createJacobiRotation(_x, _y, _z, right); - - m.p(0, 0, _x); - m.p(0, 1, _y); - m.p(1, 1, _z); - - auto temp = right.transpose(); - left.assign(mmul(rotation, temp)); +void JacobiSVD::svd2x2(const NDArray& block, int p, int q, NDArray& left, + NDArray& right) { + auto m = NDArrayFactory::create(block.ordering(), {2, 2}, block.dataType(), + block.getContext()); + m.p(0, 0, block.e(p, p)); + m.p(0, 1, block.e(p, q)); + m.p(1, 0, block.e(q, p)); + m.p(1, 1, block.e(q, q)); + + auto rotation = NDArrayFactory::create(block.ordering(), {2, 2}, + block.dataType(), block.getContext()); + T t = m.e(0, 0) + m.e(1, 1); + T d = m.e(1, 0) - m.e(0, 1); + + if (math::nd4j_abs(d) < DataTypeUtils::min()) { + rotation.p(0, 0, 1.f); + rotation.p(1, 1, 1.f); + rotation.p(0, 1, 0.f); + rotation.p(1, 0, 0.f); + } else { + T u = t / d; + T tmp = math::nd4j_sqrt(1. + u * u); + rotation.p(0, 0, u / tmp); + rotation.p(1, 1, u / tmp); + rotation.p(0, 1, 1.f / tmp); + rotation.p(1, 0, -rotation.e(0, 1)); + } + + m.assign(mmul(rotation, m)); + + auto _x = m.e(0, 0); + auto _y = m.e(0, 1); + auto _z = m.e(1, 1); + + createJacobiRotation(_x, _y, _z, right); + + m.p(0, 0, _x); + m.p(0, 1, _y); + m.p(1, 1, _z); + + auto temp = right.transpose(); + left.assign(mmul(rotation, temp)); } - ////////////////////////////////////////////////////////////////////////// template void JacobiSVD::evalData(const NDArray& matrix) { + const T precision = (T)2.f * DataTypeUtils::eps(); + const T almostZero = DataTypeUtils::min(); - const T precision = (T)2.f * DataTypeUtils::eps(); - const T almostZero = DataTypeUtils::min(); - - T scale = matrix.reduceNumber(reduce::AMax).e(0); - if(scale== (T)0.f) - scale = (T)1.f; - - if(_rows > _cols) { - - HHcolPivQR qr(matrix / scale); - _m.assign(qr._qr({0,_cols, 0,_cols})); - _m.fillAsTriangular(0., 0, 0, _m, 'l'); - - HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); - - if(_fullUV) - hhSeg.applyTo(_u); - else if(_calcU) { - _u.setIdentity(); - hhSeg.mulLeft(_u); - } - - if(_calcV) - _v.assign(qr._permut); - } - else if(_rows < _cols) { - - auto matrixT = matrix.transpose(); - HHcolPivQR qr(matrixT / scale); - _m.assign(qr._qr({0,_rows, 0,_rows})); - _m.fillAsTriangular(0., 0, 0, _m, 'l'); - _m.transposei(); + T scale = matrix.reduceNumber(reduce::AMax).e(0); + if (scale == (T)0.f) scale = (T)1.f; - HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); // type = 'u' is not mistake here ! + if (_rows > _cols) { + HHcolPivQR qr(matrix / scale); + _m.assign(qr._qr({0, _cols, 0, _cols})); + _m.fillAsTriangular(0., 0, 0, _m, 'l'); - if(_fullUV) - hhSeg.applyTo(_v); - else if(_calcV) { - _v.setIdentity(); - hhSeg.mulLeft(_v); - } - - if(_calcU) - _u.assign(qr._permut); - } - else { + HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); - _m.assign(static_cast(matrix({0,_diagSize, 0,_diagSize})) / scale); - - if(_calcU) - _u.setIdentity(); - - if(_calcV) - _v.setIdentity(); + if (_fullUV) + hhSeg.applyTo(_u); + else if (_calcU) { + _u.setIdentity(); + hhSeg.mulLeft(_u); } - T maxDiagElem = 0.; - for(int i = 0; i < _diagSize; ++i) { - T current = math::nd4j_abs(_m.e(i,i)); - if(maxDiagElem < current ) - maxDiagElem = current; + if (_calcV) _v.assign(qr._permut); + } else if (_rows < _cols) { + auto matrixT = matrix.transpose(); + HHcolPivQR qr(matrixT / scale); + _m.assign(qr._qr({0, _rows, 0, _rows})); + _m.fillAsTriangular(0., 0, 0, _m, 'l'); + _m.transposei(); + + HHsequence hhSeg(qr._qr, qr._coeffs, + 'u'); // type = 'u' is not mistake here ! + + if (_fullUV) + hhSeg.applyTo(_v); + else if (_calcV) { + _v.setIdentity(); + hhSeg.mulLeft(_v); } - bool stop = false; + if (_calcU) _u.assign(qr._permut); + } else { + _m.assign( + static_cast(matrix({0, _diagSize, 0, _diagSize})) / + scale); - while(!stop) { + if (_calcU) _u.setIdentity(); - stop = true; + if (_calcV) _v.setIdentity(); + } - for(int p = 1; p < _diagSize; ++p) { + T maxDiagElem = 0.; + for (int i = 0; i < _diagSize; ++i) { + T current = math::nd4j_abs(_m.e(i, i)); + if (maxDiagElem < current) maxDiagElem = current; + } - for(int q = 0; q < p; ++q) { + bool stop = false; - T threshold = math::nd4j_max(almostZero, precision * maxDiagElem); + while (!stop) { + stop = true; - if(math::nd4j_abs(_m.e(p,q)) > threshold || math::nd4j_abs(_m.e(q,p)) > threshold){ + for (int p = 1; p < _diagSize; ++p) { + for (int q = 0; q < p; ++q) { + T threshold = math::nd4j_max(almostZero, precision * maxDiagElem); - stop = false; + if (math::nd4j_abs(_m.e(p, q)) > threshold || + math::nd4j_abs(_m.e(q, p)) > threshold) { + stop = false; - // if(isBlock2x2NotDiag(_m, p, q, maxDiagElem)) - { - auto rotLeft = NDArrayFactory::create(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext()); - auto rotRight = NDArrayFactory::create(_m.ordering(), {2, 2}, _m.dataType(), _m.getContext()); - svd2x2(_m, p, q, rotLeft, rotRight); + // if(isBlock2x2NotDiag(_m, p, q, maxDiagElem)) + { + auto rotLeft = NDArrayFactory::create( + _m.ordering(), {2, 2}, _m.dataType(), _m.getContext()); + auto rotRight = NDArrayFactory::create( + _m.ordering(), {2, 2}, _m.dataType(), _m.getContext()); + svd2x2(_m, p, q, rotLeft, rotRight); - mulRotationOnLeft(p, q, _m, rotLeft); + mulRotationOnLeft(p, q, _m, rotLeft); - if(_calcU) { - auto temp = rotLeft.transpose(); - mulRotationOnRight(p, q, _u, temp); - } + if (_calcU) { + auto temp = rotLeft.transpose(); + mulRotationOnRight(p, q, _u, temp); + } - mulRotationOnRight(p, q, _m, rotRight); + mulRotationOnRight(p, q, _m, rotRight); - if(_calcV) - mulRotationOnRight(p, q, _v, rotRight); + if (_calcV) mulRotationOnRight(p, q, _v, rotRight); - maxDiagElem = math::nd4j_max(maxDiagElem, math::nd4j_max(math::nd4j_abs(_m.e(p,p)), math::nd4j_abs(_m.e(q,q)))); - } - } - } + maxDiagElem = math::nd4j_max( + maxDiagElem, + math::nd4j_max(math::nd4j_abs(_m.e(p, p)), + math::nd4j_abs(_m.e(q, q)))); + } } + } } + } - for(int i = 0; i < _diagSize; ++i) { - _s.p(i, math::nd4j_abs(_m.e(i,i))); - if(_calcU && _m.e(i,i) < (T)0.) { - auto temp = _u({0,0, i,i+1}, true); - temp.applyTransform(transform::Neg, temp, nullptr); - } + for (int i = 0; i < _diagSize; ++i) { + _s.p(i, math::nd4j_abs(_m.e(i, i))); + if (_calcU && _m.e(i, i) < (T)0.) { + auto temp = _u({0, 0, i, i + 1}, true); + temp.applyTransform(transform::Neg, temp, nullptr); } - - _s *= scale; - - for(int i = 0; i < _diagSize; i++) { - - int pos = (_s({i,-1, 0,0}).indexReduceNumber(indexreduce::IndexMax, nullptr)).template e(0); - T maxSingVal = _s({i,-1, 0,0}).reduceNumber(reduce::Max).template e(0); - - if(maxSingVal == (T)0.) - break; - - if(pos) { - - pos += i; - - T _e0 = _s.e(i); - T _e1 = _s.e(pos); - _s.p(pos, _e0); - _s.p(i, _e1); - //math::nd4j_swap(_s(i), _s(pos)); - - if(_calcU) { - auto temp1 = _u({0,0, pos,pos+1}, true); - auto temp2 = _u({0,0, i,i+1}, true); - auto temp3 = temp1.dup(); - temp1.assign(temp2); - temp2.assign(temp3); - } - - if(_calcV) { - auto temp1 = _v({0,0, pos, pos+1}, true); - auto temp2 = _v({0,0, i, i+1}, true); - auto temp3 = temp1.dup(); - temp1.assign(temp2); - temp2.assign(temp3); - } - } + } + + _s *= scale; + + for (int i = 0; i < _diagSize; i++) { + int pos = + (_s({i, -1, 0, 0}).indexReduceNumber(indexreduce::IndexMax, nullptr)) + .template e(0); + T maxSingVal = _s({i, -1, 0, 0}).reduceNumber(reduce::Max).template e(0); + + if (maxSingVal == (T)0.) break; + + if (pos) { + pos += i; + + T _e0 = _s.e(i); + T _e1 = _s.e(pos); + _s.p(pos, _e0); + _s.p(i, _e1); + // math::nd4j_swap(_s(i), _s(pos)); + + if (_calcU) { + auto temp1 = _u({0, 0, pos, pos + 1}, true); + auto temp2 = _u({0, 0, i, i + 1}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); + } + + if (_calcV) { + auto temp1 = _v({0, 0, pos, pos + 1}, true); + auto temp2 = _v({0, 0, i, i + 1}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); + } } + } } - - - template class SD_EXPORT JacobiSVD; template class SD_EXPORT JacobiSVD; template class SD_EXPORT JacobiSVD; template class SD_EXPORT JacobiSVD; - - - - - - -} -} -} - +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.hpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.hpp index a64f0fc913ad..bfcca574dcb2 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.hpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.hpp @@ -22,293 +22,309 @@ using namespace simdOps; - ////////////////////////////////////////////////////////////////////////////// template template -void sd::IndexReductionLoops::loopIndexReduce(const X* x, const Nd4jLong* xShapeInfo, - Z* z, const Nd4jLong* zShapeInfo, - const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, - X* extraParams) { - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo); - if(kindOfLoop == sd::LoopKind::SMALLARR2DX) - kindOfLoop = sd::LoopKind::EWSNONZERO; - - const Nd4jLong zLen = shape::length(zShapeInfo); - const Nd4jLong tadLen = shape::length(tadShapeInfo); - - const uint tadEws = shape::elementWiseStride(tadShapeInfo); - const uint zEws = shape::elementWiseStride(zShapeInfo); - - const Nd4jLong* tadShape = shape::shapeOf(const_cast(tadShapeInfo)); - const Nd4jLong* tadStride = shape::stride(const_cast(tadShapeInfo)); - - switch (kindOfLoop) { - //*********************************************// - case sd::LoopKind::EWS1: { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) { - functions::indexreduce::IndexValue comp(tad[j], j); - indexValue = OpType::update(indexValue, comp, extraParams); - } - - z[i] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); +void sd::IndexReductionLoops::loopIndexReduce( + const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams) { + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo); + if (kindOfLoop == sd::LoopKind::SMALLARR2DX) + kindOfLoop = sd::LoopKind::EWSNONZERO; + + const Nd4jLong zLen = shape::length(zShapeInfo); + const Nd4jLong tadLen = shape::length(tadShapeInfo); + + const uint tadEws = shape::elementWiseStride(tadShapeInfo); + const uint zEws = shape::elementWiseStride(zShapeInfo); + + const Nd4jLong* tadShape = + shape::shapeOf(const_cast(tadShapeInfo)); + const Nd4jLong* tadStride = + shape::stride(const_cast(tadShapeInfo)); + + switch (kindOfLoop) { + //*********************************************// + case sd::LoopKind::EWS1: { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) { + functions::indexreduce::IndexValue comp(tad[j], j); + indexValue = OpType::update(indexValue, comp, extraParams); + } + + z[i] = (Z)indexValue.index; } - break; + }; - //*********************************************// - case sd::LoopKind::EWSNONZERO: { + samediff::Threads::parallel_tad(func, 0, zLen); + } break; - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); + //*********************************************// + case sd::LoopKind::EWSNONZERO: { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); - for (Nd4jLong j = 0; j < tadLen; j++) { - functions::indexreduce::IndexValue comp(tad[j * tadEws], j); - indexValue = OpType::update(indexValue, comp, extraParams); - } + for (Nd4jLong j = 0; j < tadLen; j++) { + functions::indexreduce::IndexValue comp(tad[j * tadEws], j); + indexValue = OpType::update(indexValue, comp, extraParams); + } - z[i * zEws] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); + z[i * zEws] = (Z)indexValue.index; } - break; + }; - //*********************************************// - case sd::LoopKind::RANK1: { + samediff::Threads::parallel_tad(func, 0, zLen); + } break; - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); + //*********************************************// + case sd::LoopKind::RANK1: { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); - for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) { - functions::indexreduce::IndexValue comp(tad[i0 * tadStride[0]], i0); - indexValue = OpType::update(indexValue, comp, extraParams); - } + for (Nd4jLong i0 = 0; i0 < tadLen; ++i0) { + functions::indexreduce::IndexValue comp(tad[i0 * tadStride[0]], + i0); + indexValue = OpType::update(indexValue, comp, extraParams); + } - z[i] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); + z[i] = (Z)indexValue.index; } - break; - - //*********************************************// - case sd::LoopKind::RANK2: { - Nd4jLong newStride[2]; - shape::updateStrides(2, tadShape, newStride, 'c'); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1]; - const auto tadIndex = i0 * newStride[0] + i1; - functions::indexreduce::IndexValue comp(tad[tadOffset], tadIndex); - indexValue = OpType::update(indexValue, comp, extraParams); - } - } - - z[i] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); + }; + + samediff::Threads::parallel_tad(func, 0, zLen); + } break; + + //*********************************************// + case sd::LoopKind::RANK2: { + Nd4jLong newStride[2]; + shape::updateStrides(2, tadShape, newStride, 'c'); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1]; + const auto tadIndex = i0 * newStride[0] + i1; + functions::indexreduce::IndexValue comp(tad[tadOffset], + tadIndex); + indexValue = OpType::update(indexValue, comp, extraParams); + } + } + + z[i] = (Z)indexValue.index; } - break; - - //*********************************************// - case sd::LoopKind::RANK3: { - Nd4jLong newStride[3]; - shape::updateStrides(3, tadShape, newStride, 'c'); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2]; - const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2; - functions::indexreduce::IndexValue comp(tad[tadOffset], tadIndex); - indexValue = OpType::update(indexValue, comp, extraParams); - } - } - } - - z[i] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); + }; + + samediff::Threads::parallel_tad(func, 0, zLen); + } break; + + //*********************************************// + case sd::LoopKind::RANK3: { + Nd4jLong newStride[3]; + shape::updateStrides(3, tadShape, newStride, 'c'); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + const auto tadOffset = + i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2]; + const auto tadIndex = + i0 * newStride[0] + i1 * newStride[1] + i2; + functions::indexreduce::IndexValue comp(tad[tadOffset], + tadIndex); + indexValue = OpType::update(indexValue, comp, extraParams); + } + } + } + + z[i] = (Z)indexValue.index; } - break; - - //*********************************************// - case sd::LoopKind::RANK4: { - Nd4jLong newStride[4]; - shape::updateStrides(4, tadShape, newStride, 'c'); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { - const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3]; - const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2 * newStride[2] + i3; - functions::indexreduce::IndexValue comp(tad[tadOffset], tadIndex); - indexValue = OpType::update(indexValue, comp, extraParams); - } - } - } - } - - z[i] = (Z) indexValue.index; + }; + + samediff::Threads::parallel_tad(func, 0, zLen); + } break; + + //*********************************************// + case sd::LoopKind::RANK4: { + Nd4jLong newStride[4]; + shape::updateStrides(4, tadShape, newStride, 'c'); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { + const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + + i2 * tadStride[2] + i3 * tadStride[3]; + const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + + i2 * newStride[2] + i3; + functions::indexreduce::IndexValue comp(tad[tadOffset], + tadIndex); + indexValue = OpType::update(indexValue, comp, extraParams); } - }; + } + } + } - samediff::Threads::parallel_tad(func, 0, zLen); + z[i] = (Z)indexValue.index; } - break; - - //*********************************************// - case sd::LoopKind::RANK5: { - Nd4jLong newStride[5]; - shape::updateStrides(5, tadShape, newStride, 'c'); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { - for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { - for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { - for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { - for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) { - const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3] + i4 * tadStride[4]; - const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2 * newStride[2] + i3 * newStride[3] + i4; - functions::indexreduce::IndexValue comp(tad[tadOffset], tadIndex); - indexValue = OpType::update(indexValue, comp, extraParams); - } - } - } - } - } - - z[i] = (Z) indexValue.index; + }; + + samediff::Threads::parallel_tad(func, 0, zLen); + } break; + + //*********************************************// + case sd::LoopKind::RANK5: { + Nd4jLong newStride[5]; + shape::updateStrides(5, tadShape, newStride, 'c'); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0) { + for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1) { + for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2) { + for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3) { + for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4) { + const auto tadOffset = + i0 * tadStride[0] + i1 * tadStride[1] + + i2 * tadStride[2] + i3 * tadStride[3] + + i4 * tadStride[4]; + const auto tadIndex = + i0 * newStride[0] + i1 * newStride[1] + + i2 * newStride[2] + i3 * newStride[3] + i4; + functions::indexreduce::IndexValue comp(tad[tadOffset], + tadIndex); + indexValue = OpType::update(indexValue, comp, extraParams); + } } - }; + } + } + } - samediff::Threads::parallel_tad(func, 0, zLen); + z[i] = (Z)indexValue.index; } - break; - - //*********************************************// - case sd::LoopKind::X_EWSNONZERO: { - uint castZShapeInfo[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) { - functions::indexreduce::IndexValue comp(tad[j * tadEws], j); - indexValue = OpType::update(indexValue, comp, extraParams); - } - - auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); - z[zOffset] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); + }; + + samediff::Threads::parallel_tad(func, 0, zLen); + } break; + + //*********************************************// + case sd::LoopKind::X_EWSNONZERO: { + uint castZShapeInfo[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) { + functions::indexreduce::IndexValue comp(tad[j * tadEws], j); + indexValue = OpType::update(indexValue, comp, extraParams); + } + + auto zOffset = + shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); + z[zOffset] = (Z)indexValue.index; } - break; - - //*********************************************// - case sd::LoopKind::Z_EWSNONZERO: { - uint castTadShapeInfo[MAX_RANK]; - const bool canCastTad = sd::DataTypeUtils::castShapeInfo(tadShapeInfo, castTadShapeInfo); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) { - auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad); - functions::indexreduce::IndexValue comp(tad[tadOffset], j); - indexValue = OpType::update(indexValue, comp, extraParams); - } - - z[i * zEws] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); + }; + + samediff::Threads::parallel_tad(func, 0, zLen); + } break; + + //*********************************************// + case sd::LoopKind::Z_EWSNONZERO: { + uint castTadShapeInfo[MAX_RANK]; + const bool canCastTad = sd::DataTypeUtils::castShapeInfo( + tadShapeInfo, castTadShapeInfo); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) { + auto tadOffset = shape::indexOffset(j, tadShapeInfo, + castTadShapeInfo, canCastTad); + functions::indexreduce::IndexValue comp(tad[tadOffset], j); + indexValue = OpType::update(indexValue, comp, extraParams); + } + + z[i * zEws] = (Z)indexValue.index; } - break; - - //*********************************************// - default: { - uint castTadShapeInfo[MAX_RANK]; - uint castZShapeInfo[MAX_RANK]; - const bool canCastTad = sd::DataTypeUtils::castShapeInfo(tadShapeInfo, castTadShapeInfo); - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto tad = const_cast(x) + tadOffsets[i]; - auto indexValue = OpType::startingIndexValue(tad); - - for (Nd4jLong j = 0; j < tadLen; j++) { - auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad); - functions::indexreduce::IndexValue comp(tad[tadOffset], j); - indexValue = OpType::update(indexValue, comp, extraParams); - } - - auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); - z[zOffset] = (Z) indexValue.index; - } - }; - - samediff::Threads::parallel_tad(func, 0, zLen); + }; + + samediff::Threads::parallel_tad(func, 0, zLen); + } break; + + //*********************************************// + default: { + uint castTadShapeInfo[MAX_RANK]; + uint castZShapeInfo[MAX_RANK]; + const bool canCastTad = sd::DataTypeUtils::castShapeInfo( + tadShapeInfo, castTadShapeInfo); + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto tad = const_cast(x) + tadOffsets[i]; + auto indexValue = OpType::startingIndexValue(tad); + + for (Nd4jLong j = 0; j < tadLen; j++) { + auto tadOffset = shape::indexOffset(j, tadShapeInfo, + castTadShapeInfo, canCastTad); + functions::indexreduce::IndexValue comp(tad[tadOffset], j); + indexValue = OpType::update(indexValue, comp, extraParams); + } + + auto zOffset = + shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); + z[zOffset] = (Z)indexValue.index; } + }; + + samediff::Threads::parallel_tad(func, 0, zLen); } + } } template -void sd::IndexReductionLoops::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - DISPATCH_BY_OPNUM_TT(loopIndexReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams), INDEX_REDUCE_OPS); +void sd::IndexReductionLoops::wrapIndexReduce( + const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, void* vextraParams) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + DISPATCH_BY_OPNUM_TT(loopIndexReduce, + PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, + tadOffsets, extraParams), + INDEX_REDUCE_OPS); } \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_0.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_0.cpp index 97318dae8f95..4f08e744e07f 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_0.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_0.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_0, (sd::DataType::INT32, int32_t)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_0, (sd::DataType::INT32, int32_t)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_1.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_1.cpp index 680bf7a64597..a9ad241db44d 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_1.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_1.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_1, (sd::DataType::INT32, int32_t)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_1, (sd::DataType::INT32, int32_t)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_2.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_2.cpp index e22635b85a64..438aff6c8602 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_2.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_2.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_2, (sd::DataType::INT32, int32_t)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_2, (sd::DataType::INT32, int32_t)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_3.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_3.cpp index f85096f0a4e1..cac7092bcd11 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_3.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_3.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_3, (sd::DataType::INT32, int32_t)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_3, (sd::DataType::INT32, int32_t)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_4.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_4.cpp index 5272eba7e8c8..52b5367d43c5 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_4.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_4.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_4, (sd::DataType::INT32, int32_t)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_4, (sd::DataType::INT32, int32_t)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_5.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_5.cpp index 683d6d0c08d7..fe2b40b4b809 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_5.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_5.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_5, (sd::DataType::INT32, int32_t)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_5, (sd::DataType::INT32, int32_t)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_6.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_6.cpp index 0ff70b7b516a..82a08a95e69c 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_6.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_6.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_6, (sd::DataType::INT32, int32_t)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_6, (sd::DataType::INT32, int32_t)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_7.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_7.cpp index 64d93c5e3fa8..27be7b9e15b0 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_7.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_7.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_7, (sd::DataType::INT32, int32_t)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_7, (sd::DataType::INT32, int32_t)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_8.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_8.cpp index dd586ab2632d..27b4809d8cff 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_8.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_8.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_8, (sd::DataType::INT32, int32_t)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_8, (sd::DataType::INT32, int32_t)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_9.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_9.cpp index bb7ef80f759e..29232ce187f6 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_9.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int32_9.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_9, (sd::DataType::INT32, int32_t)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_9, (sd::DataType::INT32, int32_t)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_0.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_0.cpp index 8d0c55ce1454..b5963ce846c6 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_0.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_0.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_0, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_0, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_1.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_1.cpp index 7c58245595ee..5e56d7a0910b 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_1.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_1.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_1, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_1, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_2.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_2.cpp index 3bb6e6b7c51c..ea80d9f3186b 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_2.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_2.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_2, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_2, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_3.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_3.cpp index 49f977901fcc..60b7349072eb 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_3.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_3.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_3, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_3, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_4.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_4.cpp index 73f0e9872113..44d5fcb8dae7 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_4.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_4.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_4, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_4, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_5.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_5.cpp index b27aaf341652..3b1d9242d777 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_5.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_5.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_5, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_5, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_6.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_6.cpp index 452184acdf1f..819ecbd86982 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_6.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_6.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_6, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_6, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_7.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_7.cpp index 59cbc51cf226..6a6461b59d36 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_7.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_7.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_7, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_7, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_8.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_8.cpp index 51fc49ceaffd..e4b8a64ef2f5 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_8.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_8.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_8, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_8, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_9.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_9.cpp index b774dde52e3c..4f9a06fb17fd 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_9.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops_int64_9.cpp @@ -21,4 +21,11 @@ #include "./IndexReductionLoops.hpp" -BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, ::wrapIndexReduce(const int opNum, const void* vx, const Nd4jLong* xShapeInfo, void* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES_9, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void sd::IndexReductionLoops, + ::wrapIndexReduce(const int opNum, const void* vx, + const Nd4jLong* xShapeInfo, void* z, + const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, + void* vextraParams), + LIBND4J_TYPES_9, (sd::DataType::INT64, Nd4jLong)); \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_0.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_0.cpp index f44b23ec6056..91a938654f27 100644 --- a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_0.cpp +++ b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_0.cpp @@ -26,35 +26,65 @@ using namespace simdOps; namespace sd { - template - template - void Reduction3Loops::innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) { +template +template +void Reduction3Loops::innerloopReduce3( + const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, + int dimsLen, Z* extraParams, int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - Reduction3Loops::template loopReduce3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop); + Reduction3Loops::template loopReduce3( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, + start, stop); #endif - } +} - template - template - void Reduction3Loops::innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) { +template +template +void Reduction3Loops::innerloopReduce3All( + const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - Reduction3Loops::template loopReduce3All(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop); + Reduction3Loops::template loopReduce3All( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, + yTadShapeInfo, yTadOffsets, extraParams, start, stop); #endif - } +} - template - void Reduction3Loops::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) { +template +void Reduction3Loops::wrapper(const int opNum, const X* x, + const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Y* z, + const Nd4jLong* zShapeInfo, int* dims, + int dimsLen, Y* extraParams, int64_t start, + int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS); + DISPATCH_BY_OPNUM_TT(innerloopReduce3, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, + dimsLen, extraParams, start, stop), + REDUCE3_OPS); #endif - } +} - template - void Reduction3Loops::wrapperAll(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) { +template +void Reduction3Loops::wrapperAll( + const int opNum, const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Y* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS); + DISPATCH_BY_OPNUM_TT( + innerloopReduce3All, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, + xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), + REDUCE3_OPS); #endif - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_0); } + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduction3Loops, , LIBND4J_TYPES, + FLOAT_TYPES_0); +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_1.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_1.cpp index 62368452113c..619f9928d489 100644 --- a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_1.cpp +++ b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_1.cpp @@ -26,35 +26,65 @@ using namespace simdOps; namespace sd { - template - template - void Reduction3Loops::innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) { +template +template +void Reduction3Loops::innerloopReduce3( + const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, + int dimsLen, Z* extraParams, int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - Reduction3Loops::template loopReduce3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop); + Reduction3Loops::template loopReduce3( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, + start, stop); #endif - } +} - template - template - void Reduction3Loops::innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) { +template +template +void Reduction3Loops::innerloopReduce3All( + const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - Reduction3Loops::template loopReduce3All(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop); + Reduction3Loops::template loopReduce3All( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, + yTadShapeInfo, yTadOffsets, extraParams, start, stop); #endif - } +} - template - void Reduction3Loops::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) { +template +void Reduction3Loops::wrapper(const int opNum, const X* x, + const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Y* z, + const Nd4jLong* zShapeInfo, int* dims, + int dimsLen, Y* extraParams, int64_t start, + int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS); + DISPATCH_BY_OPNUM_TT(innerloopReduce3, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, + dimsLen, extraParams, start, stop), + REDUCE3_OPS); #endif - } +} - template - void Reduction3Loops::wrapperAll(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) { +template +void Reduction3Loops::wrapperAll( + const int opNum, const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Y* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS); + DISPATCH_BY_OPNUM_TT( + innerloopReduce3All, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, + xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), + REDUCE3_OPS); #endif - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_1); } + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduction3Loops, , LIBND4J_TYPES, + FLOAT_TYPES_1); +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_2.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_2.cpp index 1ce18ef2b6a2..46029aed7ef7 100644 --- a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_2.cpp +++ b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_2.cpp @@ -26,35 +26,65 @@ using namespace simdOps; namespace sd { - template - template - void Reduction3Loops::innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) { +template +template +void Reduction3Loops::innerloopReduce3( + const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, + int dimsLen, Z* extraParams, int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - Reduction3Loops::template loopReduce3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop); + Reduction3Loops::template loopReduce3( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, + start, stop); #endif - } +} - template - template - void Reduction3Loops::innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) { +template +template +void Reduction3Loops::innerloopReduce3All( + const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - Reduction3Loops::template loopReduce3All(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop); + Reduction3Loops::template loopReduce3All( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, + yTadShapeInfo, yTadOffsets, extraParams, start, stop); #endif - } +} - template - void Reduction3Loops::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) { +template +void Reduction3Loops::wrapper(const int opNum, const X* x, + const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Y* z, + const Nd4jLong* zShapeInfo, int* dims, + int dimsLen, Y* extraParams, int64_t start, + int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS); + DISPATCH_BY_OPNUM_TT(innerloopReduce3, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, + dimsLen, extraParams, start, stop), + REDUCE3_OPS); #endif - } +} - template - void Reduction3Loops::wrapperAll(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) { +template +void Reduction3Loops::wrapperAll( + const int opNum, const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Y* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS); + DISPATCH_BY_OPNUM_TT( + innerloopReduce3All, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, + xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), + REDUCE3_OPS); #endif - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_2); } + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduction3Loops, , LIBND4J_TYPES, + FLOAT_TYPES_2); +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_3.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_3.cpp index 9763a6116074..37e3ebcf9328 100644 --- a/libnd4j/include/helpers/cpu/loops/Reduction3Loops_3.cpp +++ b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_3.cpp @@ -26,35 +26,65 @@ using namespace simdOps; namespace sd { - template - template - void Reduction3Loops::innerloopReduce3(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) { +template +template +void Reduction3Loops::innerloopReduce3( + const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, int* dims, + int dimsLen, Z* extraParams, int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - Reduction3Loops::template loopReduce3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop); + Reduction3Loops::template loopReduce3( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, + start, stop); #endif - } +} - template - template - void Reduction3Loops::innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) { +template +template +void Reduction3Loops::innerloopReduce3All( + const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - Reduction3Loops::template loopReduce3All(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop); + Reduction3Loops::template loopReduce3All( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, + yTadShapeInfo, yTadOffsets, extraParams, start, stop); #endif - } +} - template - void Reduction3Loops::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) { +template +void Reduction3Loops::wrapper(const int opNum, const X* x, + const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Y* z, + const Nd4jLong* zShapeInfo, int* dims, + int dimsLen, Y* extraParams, int64_t start, + int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS); + DISPATCH_BY_OPNUM_TT(innerloopReduce3, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, + dimsLen, extraParams, start, stop), + REDUCE3_OPS); #endif - } +} - template - void Reduction3Loops::wrapperAll(const int opNum, const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Y *z, const Nd4jLong *zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) { +template +void Reduction3Loops::wrapperAll( + const int opNum, const X* x, const Nd4jLong* xShapeInfo, const X* y, + const Nd4jLong* yShapeInfo, Y* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, + const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Y* extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS); + DISPATCH_BY_OPNUM_TT( + innerloopReduce3All, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, + xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), + REDUCE3_OPS); #endif - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_3); } + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduction3Loops, , LIBND4J_TYPES, + FLOAT_TYPES_3); +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp index 61674020e179..553a5f906f0d 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp @@ -24,25 +24,32 @@ using namespace simdOps; namespace sd { - template - template - void ReductionBoolLoops::innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) { +template +template +void ReductionBoolLoops::innerloopReduce( + const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, + const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); + ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } +} - template - void ReductionBoolLoops::wrapper(const int opNum, - const X *x, const Nd4jLong *xShapeInfo, - Y *z, const Nd4jLong *zShapeInfo, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - X *extraParams, int64_t start, int64_t stop) { +template +void ReductionBoolLoops::wrapper( + const int opNum, const X* x, const Nd4jLong* xShapeInfo, Y* z, + const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, + const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_BOOL_OPS); + DISPATCH_BY_OPNUM_TT(innerloopReduce, + PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, + tadOffsets, extraParams, start, stop), + REDUCE_BOOL_OPS); #endif - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionBoolLoops, , LIBND4J_TYPES, BOOL_TYPES); } +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionBoolLoops, , + LIBND4J_TYPES, BOOL_TYPES); +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_0.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_0.cpp index 9ca2c1556bfe..30fde65c5b97 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_0.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_0.cpp @@ -18,34 +18,41 @@ // @author raver119@gmail.com // -#include "ReductionLoops.hpp" #include #include +#include "ReductionLoops.hpp" + using namespace simdOps; namespace sd { - template - template - void ReductionFloatLoops::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) { +template +template +void ReductionFloatLoops::innerloopReduce( + const X *x, const Nd4jLong *xShapeInfo, Z *z, const Nd4jLong *zShapeInfo, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, Z *extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); + ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } - - template - void ReductionFloatLoops::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, - Y *z, const Nd4jLong *zShapeInfo, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - Y *extraParams, - int64_t start, int64_t stop) { +} + +template +void ReductionFloatLoops::wrapper( + const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z, + const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS); + DISPATCH_BY_OPNUM_TT(innerloopReduce, + PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, + tadOffsets, extraParams, start, stop), + REDUCE_FLOAT_OPS); #endif - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_0); } - +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionFloatLoops, , + LIBND4J_TYPES, FLOAT_TYPES_0); +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_1.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_1.cpp index 124440045e20..cd1ef1343bfa 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_1.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_1.cpp @@ -18,35 +18,41 @@ // @author raver119@gmail.com // -#include "ReductionLoops.hpp" #include #include +#include "ReductionLoops.hpp" + using namespace simdOps; namespace sd { - template - template - void ReductionFloatLoops::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) { +template +template +void ReductionFloatLoops::innerloopReduce( + const X *x, const Nd4jLong *xShapeInfo, Z *z, const Nd4jLong *zShapeInfo, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, Z *extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); + ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } - - template - void ReductionFloatLoops::wrapper(const int opNum, - const X *x, const Nd4jLong *xShapeInfo, - Y *z, const Nd4jLong *zShapeInfo, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - Y *extraParams, - int64_t start, int64_t stop) { +} + +template +void ReductionFloatLoops::wrapper( + const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z, + const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS); + DISPATCH_BY_OPNUM_TT(innerloopReduce, + PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, + tadOffsets, extraParams, start, stop), + REDUCE_FLOAT_OPS); #endif - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_1); } - +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionFloatLoops, , + LIBND4J_TYPES, FLOAT_TYPES_1); +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_2.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_2.cpp index 283238421f7c..d6ab82cf8f75 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_2.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_2.cpp @@ -18,32 +18,41 @@ // @author raver119@gmail.com // -#include "ReductionLoops.hpp" #include #include +#include "ReductionLoops.hpp" + using namespace simdOps; namespace sd { - template - template - void ReductionFloatLoops::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) { +template +template +void ReductionFloatLoops::innerloopReduce( + const X *x, const Nd4jLong *xShapeInfo, Z *z, const Nd4jLong *zShapeInfo, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, Z *extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); + ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } +} - template - void ReductionFloatLoops::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z, - const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, - const Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) { +template +void ReductionFloatLoops::wrapper( + const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z, + const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS); + DISPATCH_BY_OPNUM_TT(innerloopReduce, + PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, + tadOffsets, extraParams, start, stop), + REDUCE_FLOAT_OPS); #endif - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_2); } - +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionFloatLoops, , + LIBND4J_TYPES, FLOAT_TYPES_2); +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_3.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_3.cpp index 36b946355098..02092bfa3dab 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_3.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_3.cpp @@ -18,32 +18,41 @@ // @author raver119@gmail.com // -#include "ReductionLoops.hpp" #include #include +#include "ReductionLoops.hpp" + using namespace simdOps; namespace sd { - template - template - void ReductionFloatLoops::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) { +template +template +void ReductionFloatLoops::innerloopReduce( + const X *x, const Nd4jLong *xShapeInfo, Z *z, const Nd4jLong *zShapeInfo, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, Z *extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); + ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } +} - template - void ReductionFloatLoops::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z, - const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, - const Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) { +template +void ReductionFloatLoops::wrapper( + const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z, + const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS); + DISPATCH_BY_OPNUM_TT(innerloopReduce, + PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, + tadOffsets, extraParams, start, stop), + REDUCE_FLOAT_OPS); #endif - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_3); } - +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionFloatLoops, , + LIBND4J_TYPES, FLOAT_TYPES_3); +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp index caefeeae13a9..81da1b348d8f 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp @@ -22,31 +22,41 @@ using namespace simdOps; - -#include "ReductionLoops.hpp" #include #include +#include "ReductionLoops.hpp" + using namespace simdOps; namespace sd { - template - template - void ReductionLongLoops::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z *z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) { +template +template +void ReductionLongLoops::innerloopReduce( + const X *x, const Nd4jLong *xShapeInfo, Z *z, const Nd4jLong *zShapeInfo, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, X *extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); + ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } +} - template - void ReductionLongLoops::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z, - const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, - const Nd4jLong *tadOffsets, X *extraParams, int64_t start, int64_t stop) { +template +void ReductionLongLoops::wrapper( + const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z, + const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, X *extraParams, int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_LONG_OPS); + DISPATCH_BY_OPNUM_TT(innerloopReduce, + PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, + tadOffsets, extraParams, start, stop), + REDUCE_LONG_OPS); #endif - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionLongLoops, , LIBND4J_TYPES, LONG_TYPES); } + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionLongLoops, , + LIBND4J_TYPES, LONG_TYPES); +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp index 53725de83846..382cabd60680 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp @@ -24,27 +24,37 @@ using namespace simdOps; namespace sd { - template - template - void ReductionSameLoops::innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) { +template +template +void ReductionSameLoops::innerloopReduce( + const X *x, const Nd4jLong *xShapeInfo, X *z, const Nd4jLong *zShapeInfo, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, X *extraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop); + ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } +} - template - void ReductionSameLoops::wrapper(const int opNum, const X *vx, const Nd4jLong *xShapeInfo, X *vz, - const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo, - const Nd4jLong *tadOffsets, - X *vextraParams, int64_t start, int64_t stop) { +template +void ReductionSameLoops::wrapper(const int opNum, const X *vx, + const Nd4jLong *xShapeInfo, X *vz, + const Nd4jLong *zShapeInfo, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, X *vextraParams, + int64_t start, int64_t stop) { #ifndef INLINE_LOOPS - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - DISPATCH_BY_OPNUM_T(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_SAME_OPS); + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + DISPATCH_BY_OPNUM_T(innerloopReduce, + PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, + tadOffsets, extraParams, start, stop), + REDUCE_SAME_OPS); #endif - } - - BUILD_SINGLE_TEMPLATE(template class ReductionSameLoops, , LIBND4J_TYPES); } + +BUILD_SINGLE_TEMPLATE(template class ReductionSameLoops, , LIBND4J_TYPES); +} // namespace sd diff --git a/libnd4j/include/helpers/cpu/svd.cpp b/libnd4j/include/helpers/cpu/svd.cpp index 2bae2231f2ad..83f22a647695 100644 --- a/libnd4j/include/helpers/cpu/svd.cpp +++ b/libnd4j/include/helpers/cpu/svd.cpp @@ -18,916 +18,933 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 03.01.2018 // -#include -#include -#include -#include #include - +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// template -SVD::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const bool calcV, const bool fullUV ) { - - if(matrix.rankOf() != 2 || matrix.isScalar()) - throw std::runtime_error("ops::helpers::SVD constructor: input array must be 2D matrix !"); - - const int rows = matrix.sizeAt(0); - const int cols = matrix.sizeAt(1); - - if(cols > rows) { - - _transp = true; - _diagSize = rows; - } - else { - - _transp = false; - _diagSize = cols; - } - - _switchSize = switchSize; - _calcU = calcU; - _calcV = calcV; - _fullUV = fullUV; - - if (_transp) - math::nd4j_swap(_calcU, _calcV); - - _s = NDArrayFactory::create(matrix.ordering(), {_diagSize, 1}, matrix.getContext()); - _m = NDArrayFactory::create(matrix.ordering(), {_diagSize + 1, _diagSize}, matrix.getContext()); - _m.assign(0.); - - if (_calcU) - _u = NDArrayFactory::create(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext()); - else - _u = NDArrayFactory::create(matrix.ordering(), {2, _diagSize + 1}, matrix.getContext()); - _u.assign(0.); - - if (_calcV) { - _v = NDArrayFactory::create(matrix.ordering(), {_diagSize, _diagSize}, matrix.getContext()); - _v.assign(0.); - } - - evalData(matrix); +SVD::SVD(const NDArray& matrix, const int switchSize, const bool calcU, + const bool calcV, const bool fullUV) { + if (matrix.rankOf() != 2 || matrix.isScalar()) + throw std::runtime_error( + "ops::helpers::SVD constructor: input array must be 2D matrix !"); + + const int rows = matrix.sizeAt(0); + const int cols = matrix.sizeAt(1); + + if (cols > rows) { + _transp = true; + _diagSize = rows; + } else { + _transp = false; + _diagSize = cols; + } + + _switchSize = switchSize; + _calcU = calcU; + _calcV = calcV; + _fullUV = fullUV; + + if (_transp) math::nd4j_swap(_calcU, _calcV); + + _s = NDArrayFactory::create(matrix.ordering(), {_diagSize, 1}, + matrix.getContext()); + _m = NDArrayFactory::create(matrix.ordering(), {_diagSize + 1, _diagSize}, + matrix.getContext()); + _m.assign(0.); + + if (_calcU) + _u = NDArrayFactory::create( + matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext()); + else + _u = NDArrayFactory::create(matrix.ordering(), {2, _diagSize + 1}, + matrix.getContext()); + _u.assign(0.); + + if (_calcV) { + _v = NDArrayFactory::create(matrix.ordering(), {_diagSize, _diagSize}, + matrix.getContext()); + _v.assign(0.); + } + + evalData(matrix); } ////////////////////////////////////////////////////////////////////////// template -SVD::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const bool calcV, const bool fullUV, const char t) { - - if(matrix.rankOf() != 2 || matrix.isScalar()) - throw std::runtime_error("ops::helpers::SVD constructor: input array must be 2D matrix !"); - - const int rows = matrix.sizeAt(0); - const int cols = matrix.sizeAt(1); - - if(cols > rows) { - - _transp = true; - _diagSize = rows; - } - else { - - _transp = false; - _diagSize = cols; - } - - _switchSize = switchSize; - _calcU = calcU; - _calcV = calcV; - _fullUV = fullUV; - - if (_transp) - math::nd4j_swap(_calcU, _calcV); - - _s = NDArrayFactory::create(matrix.ordering(), {_diagSize, 1}, matrix.getContext()); - _m = NDArrayFactory::create(matrix.ordering(), {_diagSize + 1, _diagSize}, matrix.getContext()); - _m.assign(0.f); - - if (_calcU) - _u = NDArrayFactory::create(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext()); - else - _u = NDArrayFactory::create(matrix.ordering(), {2, _diagSize + 1}, matrix.getContext()); - _u.assign(0.); - - if (_calcV) { - _v = NDArrayFactory::create(matrix.ordering(), {_diagSize, _diagSize}, matrix.getContext()); - _v.assign(0.); - } +SVD::SVD(const NDArray& matrix, const int switchSize, const bool calcU, + const bool calcV, const bool fullUV, const char t) { + if (matrix.rankOf() != 2 || matrix.isScalar()) + throw std::runtime_error( + "ops::helpers::SVD constructor: input array must be 2D matrix !"); + + const int rows = matrix.sizeAt(0); + const int cols = matrix.sizeAt(1); + + if (cols > rows) { + _transp = true; + _diagSize = rows; + } else { + _transp = false; + _diagSize = cols; + } + + _switchSize = switchSize; + _calcU = calcU; + _calcV = calcV; + _fullUV = fullUV; + + if (_transp) math::nd4j_swap(_calcU, _calcV); + + _s = NDArrayFactory::create(matrix.ordering(), {_diagSize, 1}, + matrix.getContext()); + _m = NDArrayFactory::create(matrix.ordering(), {_diagSize + 1, _diagSize}, + matrix.getContext()); + _m.assign(0.f); + + if (_calcU) + _u = NDArrayFactory::create( + matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext()); + else + _u = NDArrayFactory::create(matrix.ordering(), {2, _diagSize + 1}, + matrix.getContext()); + _u.assign(0.); + + if (_calcV) { + _v = NDArrayFactory::create(matrix.ordering(), {_diagSize, _diagSize}, + matrix.getContext()); + _v.assign(0.); + } } - ////////////////////////////////////////////////////////////////////////// template void SVD::deflation1(int col1, int shift, int ind, int size) { - - if(ind <= 0) - throw std::runtime_error("ops::helpers::SVD::deflation1 method: input int must satisfy condition ind > 0 !"); - - int first = col1 + shift; - T cos = _m.e(first, first); - T sin = _m.e(first+ind, first); - T denom = math::nd4j_sqrt(cos*cos + sin*sin); - - if (denom == (T)0.) { - - _m.p(first+ind, first+ind, 0.f); - return; - } - - cos /= denom; - sin /= denom; - - _m.p(first,first, denom); - _m.p(first+ind, first, 0.f); - _m.p(first+ind, first+ind, 0.f); - - auto rotation = NDArrayFactory::create(_m.ordering(), {2, 2}, _m.getContext()); - rotation.p(0, 0, cos); - rotation.p(0, 1, -sin); - rotation.p(1, 0, sin); - rotation.p(1, 1, cos); - - if (_calcU) { - auto temp = _u({col1,col1+size+1, 0,0}, true); - JacobiSVD::mulRotationOnRight(col1, col1+ind, temp, rotation); - } - else - JacobiSVD::mulRotationOnRight(col1, col1+ind, _u, rotation); + if (ind <= 0) + throw std::runtime_error( + "ops::helpers::SVD::deflation1 method: input int must satisfy " + "condition ind > 0 !"); + + int first = col1 + shift; + T cos = _m.e(first, first); + T sin = _m.e(first + ind, first); + T denom = math::nd4j_sqrt(cos * cos + sin * sin); + + if (denom == (T)0.) { + _m.p(first + ind, first + ind, 0.f); + return; + } + + cos /= denom; + sin /= denom; + + _m.p(first, first, denom); + _m.p(first + ind, first, 0.f); + _m.p(first + ind, first + ind, 0.f); + + auto rotation = + NDArrayFactory::create(_m.ordering(), {2, 2}, _m.getContext()); + rotation.p(0, 0, cos); + rotation.p(0, 1, -sin); + rotation.p(1, 0, sin); + rotation.p(1, 1, cos); + + if (_calcU) { + auto temp = _u({col1, col1 + size + 1, 0, 0}, true); + JacobiSVD::mulRotationOnRight(col1, col1 + ind, temp, rotation); + } else + JacobiSVD::mulRotationOnRight(col1, col1 + ind, _u, rotation); } ////////////////////////////////////////////////////////////////////////// template -void SVD::deflation2(int col1U , int col1M, int row1W, int col1W, int ind1, int ind2, int size) { - - if(ind1 >= ind2) - throw std::runtime_error("ops::helpers::SVD::deflation2 method: input intes must satisfy condition ind1 < ind2 !"); - - if(size <= 0) - throw std::runtime_error("ops::helpers::SVD::deflation2 method: input size must satisfy condition size > 0 !"); - - T cos = _m.e(col1M+ind1, col1M); - T sin = _m.e(col1M+ind2, col1M); - T denom = math::nd4j_sqrt(cos*cos + sin*sin); - - if (denom == (T)0.) { - - _m.p(col1M + ind1, col1M + ind1, _m.e(col1M + ind2, col1M + ind2)); - return; - } - - cos /= denom; - sin /= denom; - _m.p(col1M + ind1, col1M, denom); - _m.p(col1M + ind2, col1M + ind2, _m.e(col1M + ind1, col1M + ind1)); - _m.p(col1M + ind2, col1M, 0.f); - - auto rotation = NDArrayFactory::create(_m.ordering(), {2, 2}, _m.getContext()); - rotation.p(0,0, cos); - rotation.p(1,1, cos); - - rotation.p(0,1, -sin); - rotation.p(1,0, sin); - - if (_calcU) { - auto temp = _u({col1U,col1U+size+1, 0,0}, true); - JacobiSVD::mulRotationOnRight(col1U+ind1, col1U+ind2, temp, rotation); - } - else - JacobiSVD::mulRotationOnRight(col1U+ind1, col1U+ind2, _u, rotation); - - if (_calcV) { - auto temp = _v({row1W,row1W+size, 0,0}, true); - JacobiSVD::mulRotationOnRight(col1W+ind1, col1W+ind2, temp, rotation); - } +void SVD::deflation2(int col1U, int col1M, int row1W, int col1W, int ind1, + int ind2, int size) { + if (ind1 >= ind2) + throw std::runtime_error( + "ops::helpers::SVD::deflation2 method: input intes must satisfy " + "condition ind1 < ind2 !"); + + if (size <= 0) + throw std::runtime_error( + "ops::helpers::SVD::deflation2 method: input size must satisfy " + "condition size > 0 !"); + + T cos = _m.e(col1M + ind1, col1M); + T sin = _m.e(col1M + ind2, col1M); + T denom = math::nd4j_sqrt(cos * cos + sin * sin); + + if (denom == (T)0.) { + _m.p(col1M + ind1, col1M + ind1, _m.e(col1M + ind2, col1M + ind2)); + return; + } + + cos /= denom; + sin /= denom; + _m.p(col1M + ind1, col1M, denom); + _m.p(col1M + ind2, col1M + ind2, _m.e(col1M + ind1, col1M + ind1)); + _m.p(col1M + ind2, col1M, 0.f); + + auto rotation = + NDArrayFactory::create(_m.ordering(), {2, 2}, _m.getContext()); + rotation.p(0, 0, cos); + rotation.p(1, 1, cos); + + rotation.p(0, 1, -sin); + rotation.p(1, 0, sin); + + if (_calcU) { + auto temp = _u({col1U, col1U + size + 1, 0, 0}, true); + JacobiSVD::mulRotationOnRight(col1U + ind1, col1U + ind2, temp, + rotation); + } else + JacobiSVD::mulRotationOnRight(col1U + ind1, col1U + ind2, _u, rotation); + + if (_calcV) { + auto temp = _v({row1W, row1W + size, 0, 0}, true); + JacobiSVD::mulRotationOnRight(col1W + ind1, col1W + ind2, temp, + rotation); + } } ////////////////////////////////////////////////////////////////////////// -// has effect on block from (col1+shift, col1+shift) to (col2+shift, col2+shift) inclusively +// has effect on block from (col1+shift, col1+shift) to (col2+shift, col2+shift) +// inclusively template -void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int shift) -{ - - const int len = col2 + 1 - col1; - - auto colVec0 = _m({col1+shift,col1+shift+len, col1+shift,col1+shift+1}, true); - - auto diagInterval = _m({col1+shift, col1+shift+len, col1+shift,col1+shift+len}, true).diagonal('c'); - - const T almostZero = DataTypeUtils::min(); - T maxElem; - if(len == 1) - maxElem = math::nd4j_abs(diagInterval.template e(0)); - else - maxElem = diagInterval({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e(0); - T maxElem0 = colVec0.reduceNumber(reduce::AMax).template e(0); - - T eps = math::nd4j_max(almostZero, DataTypeUtils::eps() * maxElem); - T epsBig = (T)8. * DataTypeUtils::eps() * math::nd4j_max(maxElem0, maxElem); - - if(diagInterval.template e(0) < epsBig) - diagInterval.p(Nd4jLong(0), epsBig); - - for(int i=1; i < len; ++i) - if(math::nd4j_abs(colVec0.template e(i)) < eps) - colVec0.p(i, 0.f); - - for(int i=1; i < len; i++) - if(diagInterval.template e(i) < epsBig) { - deflation1(col1, shift, i, len); - for(int i = 0; i < len; ++i) - diagInterval.p(i, _m.e(col1+shift+i,col1+shift+i)); - } - - { - - bool totDefl = true; - for(int i=1; i < len; i++) - if(colVec0.template e(i) >= almostZero) { - totDefl = false; - break; - } - - int* permut = nullptr; - ALLOCATE(permut, _m.getContext()->getWorkspace(), 3*_diagSize, int); - { - permut[0] = 0; - int p = 1; - - for(int i=1; i(diagInterval.template e(i)) < almostZero) - permut[p++] = i; - - int k = 1, m = ind+1; - - for( ; p < len; ++p) { - if(k > ind) - permut[p] = m++; - else if(m >= len) - permut[p] = k++; - else if(diagInterval.template e(k) < diagInterval.template e(m)) - permut[p] = m++; - else - permut[p] = k++; - } - } - - if(totDefl) { - for(int i=1; i(diagInterval.template e(ki)) < almostZero || diagInterval.template e(0) < diagInterval.template e(ki)) - permut[i-1] = permut[i]; - else { - permut[i-1] = 0; - break; - } - } - } - - int *tInd = permut + len; - int *tCol = permut + 2*len; - - for(int m = 0; m < len; m++) { - tCol[m] = m; - tInd[m] = m; - } - - for(int i = totDefl ? 0 : 1; i < len; i++) { - - const int ki = permut[len - (totDefl ? i+1 : i)]; - const int jac = tCol[ki]; - - T _e0 = diagInterval.template e(jac); - //math::nd4j_swap(diagInterval)(i), (*diagInterval)(jac)); - diagInterval.p(jac, diagInterval.template e(i)); - diagInterval.p(i, _e0); - - if(i!=0 && jac!=0) { - _e0 = colVec0.template e(jac); - //math::nd4j_swap((*colVec0)(i), (*colVec0)(jac)); - colVec0.p(jac, colVec0.template e(i)); - colVec0.p(i, _e0); - } - - if (_calcU) { - auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1}, true); - auto temp2 = _u({col1,col1+len+1, col1+jac,col1+jac+1}, true); - auto temp3 = temp1.dup(); - temp1.assign(temp2); - temp2.assign(temp3); - } - else { - auto temp1 = _u({0,2, col1+i, col1+i+1}, true); - auto temp2 = _u({0,2, col1+jac, col1+jac+1}, true); - auto temp3 = temp1.dup(); - temp1.assign(temp2); - temp2.assign(temp3); - } - - if(_calcV) { - auto temp1 = _v({row1W,row1W+len, col1W+i, col1W+i+1}, true); - auto temp2 = _v({row1W,row1W+len, col1W+jac, col1W+jac+1}, true); - auto temp3 = temp1.dup(); - temp1.assign(temp2); - temp2.assign(temp3); - } - - const int tI = tInd[i]; - tCol[tI] = jac; - tCol[ki] = i; - tInd[jac] = tI; - tInd[i] = ki; - } - - RELEASE(permut, _m.getContext()->getWorkspace()); - } - +void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, + int shift) { + const int len = col2 + 1 - col1; + + auto colVec0 = _m( + {col1 + shift, col1 + shift + len, col1 + shift, col1 + shift + 1}, true); + + auto diagInterval = + _m({col1 + shift, col1 + shift + len, col1 + shift, col1 + shift + len}, + true) + .diagonal('c'); + + const T almostZero = DataTypeUtils::min(); + T maxElem; + if (len == 1) + maxElem = math::nd4j_abs(diagInterval.template e(0)); + else + maxElem = diagInterval({1, -1, 0, 0}, true) + .reduceNumber(reduce::AMax) + .template e(0); + T maxElem0 = colVec0.reduceNumber(reduce::AMax).template e(0); + + T eps = math::nd4j_max(almostZero, DataTypeUtils::eps() * maxElem); + T epsBig = + (T)8. * DataTypeUtils::eps() * math::nd4j_max(maxElem0, maxElem); + + if (diagInterval.template e(0) < epsBig) + diagInterval.p(Nd4jLong(0), epsBig); + + for (int i = 1; i < len; ++i) + if (math::nd4j_abs(colVec0.template e(i)) < eps) colVec0.p(i, 0.f); + + for (int i = 1; i < len; i++) + if (diagInterval.template e(i) < epsBig) { + deflation1(col1, shift, i, len); + for (int i = 0; i < len; ++i) + diagInterval.p(i, _m.e(col1 + shift + i, col1 + shift + i)); + } + + { + bool totDefl = true; + for (int i = 1; i < len; i++) + if (colVec0.template e(i) >= almostZero) { + totDefl = false; + break; + } + + int* permut = nullptr; + ALLOCATE(permut, _m.getContext()->getWorkspace(), 3 * _diagSize, int); { - int i = len-1; - - while(i > 0 && (math::nd4j_abs(diagInterval.template e(i)) < almostZero || math::nd4j_abs(colVec0.template e(i)) < almostZero)) - --i; - - for(; i > 1; --i) { - if( (diagInterval.template e(i) - diagInterval.template e(i-1)) < DataTypeUtils::eps()*maxElem ) { - if (math::nd4j_abs(diagInterval.template e(i) - diagInterval.template e(i-1)) >= epsBig) - throw std::runtime_error("ops::helpers::SVD::deflation: diagonal elements are not properly sorted !"); - deflation2(col1, col1 + shift, row1W, col1W, i-1, i, len); - } + permut[0] = 0; + int p = 1; + + for (int i = 1; i < len; ++i) + if (math::nd4j_abs(diagInterval.template e(i)) < almostZero) + permut[p++] = i; + + int k = 1, m = ind + 1; + + for (; p < len; ++p) { + if (k > ind) + permut[p] = m++; + else if (m >= len) + permut[p] = k++; + else if (diagInterval.template e(k) < diagInterval.template e(m)) + permut[p] = m++; + else + permut[p] = k++; + } + } + + if (totDefl) { + for (int i = 1; i < len; ++i) { + int ki = permut[i]; + if (math::nd4j_abs(diagInterval.template e(ki)) < almostZero || + diagInterval.template e(0) < diagInterval.template e(ki)) + permut[i - 1] = permut[i]; + else { + permut[i - 1] = 0; + break; } - } + } + } + + int* tInd = permut + len; + int* tCol = permut + 2 * len; + + for (int m = 0; m < len; m++) { + tCol[m] = m; + tInd[m] = m; + } + + for (int i = totDefl ? 0 : 1; i < len; i++) { + const int ki = permut[len - (totDefl ? i + 1 : i)]; + const int jac = tCol[ki]; + + T _e0 = diagInterval.template e(jac); + // math::nd4j_swap(diagInterval)(i), (*diagInterval)(jac)); + diagInterval.p(jac, diagInterval.template e(i)); + diagInterval.p(i, _e0); + + if (i != 0 && jac != 0) { + _e0 = colVec0.template e(jac); + // math::nd4j_swap((*colVec0)(i), (*colVec0)(jac)); + colVec0.p(jac, colVec0.template e(i)); + colVec0.p(i, _e0); + } + + if (_calcU) { + auto temp1 = _u({col1, col1 + len + 1, col1 + i, col1 + i + 1}, true); + auto temp2 = + _u({col1, col1 + len + 1, col1 + jac, col1 + jac + 1}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); + } else { + auto temp1 = _u({0, 2, col1 + i, col1 + i + 1}, true); + auto temp2 = _u({0, 2, col1 + jac, col1 + jac + 1}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); + } + + if (_calcV) { + auto temp1 = _v({row1W, row1W + len, col1W + i, col1W + i + 1}, true); + auto temp2 = + _v({row1W, row1W + len, col1W + jac, col1W + jac + 1}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); + } + + const int tI = tInd[i]; + tCol[tI] = jac; + tCol[ki] = i; + tInd[jac] = tI; + tInd[i] = ki; + } + + RELEASE(permut, _m.getContext()->getWorkspace()); + } + + { + int i = len - 1; + + while (i > 0 && + (math::nd4j_abs(diagInterval.template e(i)) < almostZero || + math::nd4j_abs(colVec0.template e(i)) < almostZero)) + --i; + + for (; i > 1; --i) { + if ((diagInterval.template e(i) - diagInterval.template e(i - 1)) < + DataTypeUtils::eps() * maxElem) { + if (math::nd4j_abs(diagInterval.template e(i) - + diagInterval.template e(i - 1)) >= epsBig) + throw std::runtime_error( + "ops::helpers::SVD::deflation: diagonal elements are not " + "properly sorted !"); + deflation2(col1, col1 + shift, row1W, col1W, i - 1, i, len); + } + } + } } - ////////////////////////////////////////////////////////////////////////// template -T SVD::secularEq(const T diff, const NDArray& col0, const NDArray& diag, const NDArray& permut, const NDArray& diagShifted, const T shift) { - - auto len = permut.lengthOf(); - T res = 1.; - T item; - for(Nd4jLong i=0; i(i); - item = col0.e(j) / ((diagShifted.e(j) - diff) * (diag.e(j) + shift + diff)); - res += item * col0.e(j); - } - - return res; +T SVD::secularEq(const T diff, const NDArray& col0, const NDArray& diag, + const NDArray& permut, const NDArray& diagShifted, + const T shift) { + auto len = permut.lengthOf(); + T res = 1.; + T item; + for (Nd4jLong i = 0; i < len; ++i) { + auto j = permut.e(i); + item = col0.e(j) / + ((diagShifted.e(j) - diff) * (diag.e(j) + shift + diff)); + res += item * col0.e(j); + } + + return res; } ////////////////////////////////////////////////////////////////////////// template -void SVD::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArray& permut, NDArray& singVals, NDArray& shifts, NDArray& mus) { - - auto len = col0.lengthOf(); - auto curLen = len; - - while(curLen > 1 && col0.e(curLen-1) == (T)0.f) - --curLen; - - for (Nd4jLong k = 0; k < len; ++k) { - - if (col0.e(k) == (T)0.f || curLen==1) { - - singVals.p(k, k==0 ? col0.e(0) : diag.e(k)); - mus.p(k, 0.f); - shifts.p(k, k==0 ? col0.e(0) : diag.e(k)); - continue; - } +void SVD::calcSingVals(const NDArray& col0, const NDArray& diag, + const NDArray& permut, NDArray& singVals, + NDArray& shifts, NDArray& mus) { + auto len = col0.lengthOf(); + auto curLen = len; - T left = diag.e(k); - T right; + while (curLen > 1 && col0.e(curLen - 1) == (T)0.f) --curLen; - if(k==curLen-1) - right = diag.e(curLen-1) + col0.reduceNumber(reduce::Norm2).e(0); - else { - - int l = k+1; - while(col0.e(l) == (T)0.f) { - ++l; - if(l >= curLen) - throw std::runtime_error("ops::helpers::SVD::calcSingVals method: l >= curLen !"); - } - - right = diag.e(l); - } - - T mid = left + (right - left) / (T)2.; - T fMid = secularEq(mid, col0, diag, permut, diag, 0.); - T shift = (k == curLen-1 || fMid > (T)0.) ? left : right; + for (Nd4jLong k = 0; k < len; ++k) { + if (col0.e(k) == (T)0.f || curLen == 1) { + singVals.p(k, k == 0 ? col0.e(0) : diag.e(k)); + mus.p(k, 0.f); + shifts.p(k, k == 0 ? col0.e(0) : diag.e(k)); + continue; + } - auto diagShifted = diag - shift; + T left = diag.e(k); + T right; - T muPrev, muCur; - if (shift == left) { - muPrev = (right - left) * 0.1; - if (k == curLen-1) - muCur = right - left; - else - muCur = (right - left) * 0.5; - } + if (k == curLen - 1) + right = diag.e(curLen - 1) + col0.reduceNumber(reduce::Norm2).e(0); + else { + int l = k + 1; + while (col0.e(l) == (T)0.f) { + ++l; + if (l >= curLen) + throw std::runtime_error( + "ops::helpers::SVD::calcSingVals method: l >= curLen !"); + } + + right = diag.e(l); + } + + T mid = left + (right - left) / (T)2.; + T fMid = secularEq(mid, col0, diag, permut, diag, 0.); + T shift = (k == curLen - 1 || fMid > (T)0.) ? left : right; + + auto diagShifted = diag - shift; + + T muPrev, muCur; + if (shift == left) { + muPrev = (right - left) * 0.1; + if (k == curLen - 1) + muCur = right - left; + else + muCur = (right - left) * 0.5; + } else { + muPrev = -(right - left) * 0.1; + muCur = -(right - left) * 0.5; + } + + T fPrev = secularEq(muPrev, col0, diag, permut, diagShifted, shift); + T fCur = secularEq(muCur, col0, diag, permut, diagShifted, shift); + + if (math::nd4j_abs(fPrev) < math::nd4j_abs(fCur)) { + math::nd4j_swap(fPrev, fCur); + math::nd4j_swap(muPrev, muCur); + } + + bool useBisection = fPrev * fCur > (T)0.; + while (fCur != (T).0 && + math::nd4j_abs(muCur - muPrev) > + (T)8. * DataTypeUtils::eps() * + math::nd4j_max(math::nd4j_abs(muCur), + math::nd4j_abs(muPrev)) && + math::nd4j_abs(fCur - fPrev) > DataTypeUtils::eps() && + !useBisection) { + T a = (fCur - fPrev) / ((T)1. / muCur - (T)1. / muPrev); + T jac = fCur - a / muCur; + T muZero = -a / jac; + T fZero = secularEq(muZero, col0, diag, permut, diagShifted, shift); + + muPrev = muCur; + fPrev = fCur; + muCur = muZero; + fCur = fZero; + + if (shift == left && (muCur < (T)0. || muCur > right - left)) + useBisection = true; + if (shift == right && (muCur < -(right - left) || muCur > (T)0.)) + useBisection = true; + if (math::nd4j_abs(fCur) > math::nd4j_abs(fPrev) && + math::nd4j_abs(fCur - fPrev) > (T)16. * DataTypeUtils::eps()) + useBisection = true; + } + + if (useBisection) { + T leftShifted, rightShifted; + if (shift == left) { + leftShifted = DataTypeUtils::min(); + rightShifted = (k == curLen - 1) ? right : ((right - left) * (T)0.6); + } else { + leftShifted = -(right - left) * (T)0.6; + rightShifted = -DataTypeUtils::min(); + } + + T fLeft = secularEq(leftShifted, col0, diag, permut, diagShifted, shift); + T fRight = + secularEq(rightShifted, col0, diag, permut, diagShifted, shift); + // if(fLeft * fRight >= (T)0.) + // throw "ops::helpers::SVD::calcSingVals method: fLeft * fRight >= (T)0. + // !"; + + while (rightShifted - leftShifted > + (T)2.f * DataTypeUtils::eps() * + math::nd4j_max(math::nd4j_abs(leftShifted), + math::nd4j_abs(rightShifted))) { + T midShifted = (leftShifted + rightShifted) / (T)2.; + fMid = secularEq(midShifted, col0, diag, permut, diagShifted, shift); + if (fLeft * fMid < (T)0.) + rightShifted = midShifted; else { - muPrev = -(right - left) * 0.1; - muCur = -(right - left) * 0.5; - } - - T fPrev = secularEq(muPrev, col0, diag, permut, diagShifted, shift); - T fCur = secularEq(muCur, col0, diag, permut, diagShifted, shift); - - if (math::nd4j_abs(fPrev) < math::nd4j_abs(fCur)) { - math::nd4j_swap(fPrev, fCur); - math::nd4j_swap(muPrev, muCur); - } - - bool useBisection = fPrev * fCur > (T)0.; - while (fCur != (T).0 && - math::nd4j_abs(muCur - muPrev) > (T)8. * DataTypeUtils::eps() * math::nd4j_max(math::nd4j_abs(muCur), math::nd4j_abs(muPrev)) - && math::nd4j_abs(fCur - fPrev) > DataTypeUtils::eps() && !useBisection) { - - T a = (fCur - fPrev) / ((T)1./muCur - (T)1./muPrev); - T jac = fCur - a / muCur; - T muZero = -a/jac; - T fZero = secularEq(muZero, col0, diag, permut, diagShifted, shift); - - muPrev = muCur; - fPrev = fCur; - muCur = muZero; - fCur = fZero; - - if (shift == left && (muCur < (T)0. || muCur > right - left)) - useBisection = true; - if (shift == right && (muCur < -(right - left) || muCur > (T)0.)) - useBisection = true; - if (math::nd4j_abs(fCur) > math::nd4j_abs(fPrev) && math::nd4j_abs(fCur - fPrev) > (T)16. * DataTypeUtils::eps()) - useBisection = true; + leftShifted = midShifted; + fLeft = fMid; } - - - if (useBisection) { - - T leftShifted, rightShifted; - if (shift == left) { - leftShifted = DataTypeUtils::min(); - rightShifted = (k==curLen-1) ? right : ((right - left) * (T)0.6); - } - else { - - leftShifted = -(right - left) * (T)0.6; - rightShifted = -DataTypeUtils::min(); - } - - T fLeft = secularEq(leftShifted, col0, diag, permut, diagShifted, shift); - T fRight = secularEq(rightShifted, col0, diag, permut, diagShifted, shift); - // if(fLeft * fRight >= (T)0.) - // throw "ops::helpers::SVD::calcSingVals method: fLeft * fRight >= (T)0. !"; - - while (rightShifted - leftShifted > (T)2.f * DataTypeUtils::eps() * math::nd4j_max(math::nd4j_abs(leftShifted), math::nd4j_abs(rightShifted))) { - - T midShifted = (leftShifted + rightShifted) / (T)2.; - fMid = secularEq(midShifted, col0, diag, permut, diagShifted, shift); - if (fLeft * fMid < (T)0.) - rightShifted = midShifted; - else { - leftShifted = midShifted; - fLeft = fMid; - } - } - muCur = (leftShifted + rightShifted) / (T)2.; - } - singVals.p(k, shift + muCur); - shifts.p(k, shift); - mus.p(k, muCur); + } + muCur = (leftShifted + rightShifted) / (T)2.; } - + singVals.p(k, shift + muCur); + shifts.p(k, shift); + mus.p(k, muCur); + } } - ////////////////////////////////////////////////////////////////////////// template -void SVD::perturb(const NDArray& col0, const NDArray& diag, const NDArray& permut, const NDArray& singVals, const NDArray& shifts, const NDArray& mus, NDArray& zhat) { - - int n = col0.lengthOf(); - int m = permut.lengthOf(); - if(m==0) { - zhat.assign(0.); - return; - } - - int last = permut.e(m-1); - - for (int k = 0; k < n; ++k) { - - if (col0.e(k) == (T)0.f) - zhat.p(k, (T)0.f); - else { - T dk = diag.e(k); - T prod = (singVals.e(last) + dk) * (mus.e(last) + (shifts.e(last) - dk)); - - for(int l = 0; l(l); - if(i!=k) { - int j = i(l-1); - prod *= ((singVals.e(j)+dk) / ((diag.e(i)+dk))) * ((mus.e(j)+(shifts.e(j)-dk)) / ((diag.e(i)-dk))); - } - } - T tmp = math::nd4j_sqrt(prod); - zhat.p(k, col0.e(k) > (T)0.f ? tmp : -tmp); +void SVD::perturb(const NDArray& col0, const NDArray& diag, + const NDArray& permut, const NDArray& singVals, + const NDArray& shifts, const NDArray& mus, NDArray& zhat) { + int n = col0.lengthOf(); + int m = permut.lengthOf(); + if (m == 0) { + zhat.assign(0.); + return; + } + + int last = permut.e(m - 1); + + for (int k = 0; k < n; ++k) { + if (col0.e(k) == (T)0.f) + zhat.p(k, (T)0.f); + else { + T dk = diag.e(k); + T prod = (singVals.e(last) + dk) * + (mus.e(last) + (shifts.e(last) - dk)); + + for (int l = 0; l < m; ++l) { + int i = permut.e(l); + if (i != k) { + int j = i < k ? i : permut.e(l - 1); + prod *= + ((singVals.e(j) + dk) / ((diag.e(i) + dk))) * + ((mus.e(j) + (shifts.e(j) - dk)) / ((diag.e(i) - dk))); } + } + T tmp = math::nd4j_sqrt(prod); + zhat.p(k, col0.e(k) > (T)0.f ? tmp : -tmp); } + } } - ////////////////////////////////////////////////////////////////////////// template -void SVD::calcSingVecs(const NDArray& zhat, const NDArray& diag, const NDArray& perm, const NDArray& singVals, - const NDArray& shifts, const NDArray& mus, NDArray& U, NDArray& V) { - - int n = zhat.lengthOf(); - int m = perm.lengthOf(); +void SVD::calcSingVecs(const NDArray& zhat, const NDArray& diag, + const NDArray& perm, const NDArray& singVals, + const NDArray& shifts, const NDArray& mus, NDArray& U, + NDArray& V) { + int n = zhat.lengthOf(); + int m = perm.lengthOf(); + + for (int k = 0; k < n; ++k) { + auto colU = new NDArray(U({0, 0, k, k + 1}, true)); + *colU = 0.; + NDArray* colV = nullptr; - for (int k = 0; k < n; ++k) { - - auto colU = new NDArray(U({0,0, k,k+1}, true)); - *colU = 0.; - NDArray* colV = nullptr; - - if (_calcV) { - colV = new NDArray(V({0,0, k,k+1}, true)); - *colV = 0.; - } - - if (zhat.e(k) == (T)0.f) { - colU->p(k, 1.f); - - if (_calcV) - colV->p(k, 1.f); - } - else { - - for(int l = 0; l < m; ++l) { - int i = perm.e(l); - U.p(i,k, zhat.e(i)/(((diag.e(i) - shifts.e(k)) - mus.e(k)) )/( (diag.e(i) + singVals.e(k)))); - } - U.p(n,k, 0.f); - *colU /= colU->reduceNumber(reduce::Norm2); - - if (_calcV) { - - for(int l = 1; l < m; ++l){ - int i = perm.e(l); - V.p(i,k, diag.e(i) * zhat.e(i) / (((diag.e(i) - shifts.e(k)) - mus.e(k)) )/( (diag.e(i) + singVals.e(k)))); - } - V.p(0,k, -1.f); - *colV /= colV->reduceNumber(reduce::Norm2); - } + if (_calcV) { + colV = new NDArray(V({0, 0, k, k + 1}, true)); + *colV = 0.; + } + + if (zhat.e(k) == (T)0.f) { + colU->p(k, 1.f); + + if (_calcV) colV->p(k, 1.f); + } else { + for (int l = 0; l < m; ++l) { + int i = perm.e(l); + U.p(i, k, + zhat.e(i) / (((diag.e(i) - shifts.e(k)) - mus.e(k))) / + ((diag.e(i) + singVals.e(k)))); + } + U.p(n, k, 0.f); + *colU /= colU->reduceNumber(reduce::Norm2); + + if (_calcV) { + for (int l = 1; l < m; ++l) { + int i = perm.e(l); + V.p(i, k, + diag.e(i) * zhat.e(i) / + (((diag.e(i) - shifts.e(k)) - mus.e(k))) / + ((diag.e(i) + singVals.e(k)))); } - delete colU; - if (_calcV) - delete colV; + V.p(0, k, -1.f); + *colV /= colV->reduceNumber(reduce::Norm2); + } } + delete colU; + if (_calcV) delete colV; + } - auto colU = U({0,0, n,n+1}, true); - colU = 0.; - colU.p(n, 1.); + auto colU = U({0, 0, n, n + 1}, true); + colU = 0.; + colU.p(n, 1.); } - ////////////////////////////////////////////////////////////////////////// template -void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDArray& V) { - - const T almostZero = DataTypeUtils::min(); - auto col0 = _m({col1, col1+size, col1, col1+1}, true); - auto diag = _m({col1, col1+size, col1, col1+size}, true).diagonal('c').dup(); - - diag.p(Nd4jLong(0), T(0)); - singVals = NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); - U = NDArrayFactory::create(_u.ordering(), {size+1, size+1}, _u.getContext()); - if (_calcV) - V = NDArrayFactory::create(_v.ordering(), {size, size}, _v.getContext()); - - int curSize = size; - while(curSize > 1 && diag.template e(curSize-1) == (T)0.f) - --curSize; - - int m = 0; - std::vector indices; - for(int k = 0; k < curSize; ++k) - if(math::nd4j_abs(col0.template e(k)) > almostZero) - indices.push_back((T)k); - - auto permut = NDArrayFactory::create(_m.ordering(), {1, (int)indices.size()}, indices, _m.getContext()); - auto shifts = NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); - auto mus = NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); - auto zhat = NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); - - calcSingVals(col0, diag, permut, singVals, shifts, mus); - perturb(col0, diag, permut, singVals, shifts, mus, zhat); - calcSingVecs(zhat, diag, permut, singVals, shifts, mus, U, V); - - for(int i=0; i(i) > singVals.e(i+1)) { - T _e0 = singVals.e(i); - T _e1 = singVals.e(i+1); - //math::nd4j_swap(singVals(i),singVals(i+1)); - singVals.p(i, _e1); - singVals.p(i+1, _e0); - - auto temp1 = U({0,0, i,i+1}, true); - auto temp2 = U({0,0, i+1,i+2}, true); - auto temp3 = temp1.dup(); - temp1.assign(temp2); - temp2.assign(temp3); - - if(_calcV) { - auto temp1 = V({0,0, i,i+1}, true); - auto temp2 = V({0,0, i+1,i+2}, true); - auto temp3 = temp1.dup(); - temp1.assign(temp2); - temp2.assign(temp3); - } - } - } - - auto temp1 = singVals({0,curSize, 0,0}, true); - for (int e = 0; e < curSize / 2; ++e) { - T tmp = temp1.e(e); - temp1.p(e, temp1.e(curSize-1-e)); - temp1.p(curSize-1-e, tmp); - } - - auto temp2 = U({0,0, 0,curSize}, true); - for(int i = 0; i < curSize/2; ++i) { - auto temp3 = temp2({0,0, i,i+1}, true); - auto temp4 = temp2({0,0, curSize-1-i,curSize-i}, true); - auto temp5 = temp3.dup(); - temp3.assign(temp4); - temp4.assign(temp5); - } - - if (_calcV) { - auto temp2 = V({0,0, 0,curSize}, true); - for(int i = 0; i < curSize/2; ++i) { - auto temp3 = temp2({0,0, i,i+1}, true); - auto temp4 = temp2({0,0, curSize-1-i,curSize-i}, true); - auto temp5 = temp3.dup(); - temp3.assign(temp4); - temp4.assign(temp5); - } - } +void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, + NDArray& V) { + const T almostZero = DataTypeUtils::min(); + auto col0 = _m({col1, col1 + size, col1, col1 + 1}, true); + auto diag = + _m({col1, col1 + size, col1, col1 + size}, true).diagonal('c').dup(); + + diag.p(Nd4jLong(0), T(0)); + singVals = + NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); + U = NDArrayFactory::create(_u.ordering(), {size + 1, size + 1}, + _u.getContext()); + if (_calcV) + V = NDArrayFactory::create(_v.ordering(), {size, size}, _v.getContext()); + + int curSize = size; + while (curSize > 1 && diag.template e(curSize - 1) == (T)0.f) --curSize; + + int m = 0; + std::vector indices; + for (int k = 0; k < curSize; ++k) + if (math::nd4j_abs(col0.template e(k)) > almostZero) + indices.push_back((T)k); + + auto permut = NDArrayFactory::create( + _m.ordering(), {1, (int)indices.size()}, indices, _m.getContext()); + auto shifts = + NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); + auto mus = + NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); + auto zhat = + NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); + + calcSingVals(col0, diag, permut, singVals, shifts, mus); + perturb(col0, diag, permut, singVals, shifts, mus, zhat); + calcSingVecs(zhat, diag, permut, singVals, shifts, mus, U, V); + + for (int i = 0; i < curSize - 1; ++i) { + if (singVals.e(i) > singVals.e(i + 1)) { + T _e0 = singVals.e(i); + T _e1 = singVals.e(i + 1); + // math::nd4j_swap(singVals(i),singVals(i+1)); + singVals.p(i, _e1); + singVals.p(i + 1, _e0); + + auto temp1 = U({0, 0, i, i + 1}, true); + auto temp2 = U({0, 0, i + 1, i + 2}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); + + if (_calcV) { + auto temp1 = V({0, 0, i, i + 1}, true); + auto temp2 = V({0, 0, i + 1, i + 2}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); + } + } + } + + auto temp1 = singVals({0, curSize, 0, 0}, true); + for (int e = 0; e < curSize / 2; ++e) { + T tmp = temp1.e(e); + temp1.p(e, temp1.e(curSize - 1 - e)); + temp1.p(curSize - 1 - e, tmp); + } + + auto temp2 = U({0, 0, 0, curSize}, true); + for (int i = 0; i < curSize / 2; ++i) { + auto temp3 = temp2({0, 0, i, i + 1}, true); + auto temp4 = temp2({0, 0, curSize - 1 - i, curSize - i}, true); + auto temp5 = temp3.dup(); + temp3.assign(temp4); + temp4.assign(temp5); + } + + if (_calcV) { + auto temp2 = V({0, 0, 0, curSize}, true); + for (int i = 0; i < curSize / 2; ++i) { + auto temp3 = temp2({0, 0, i, i + 1}, true); + auto temp4 = temp2({0, 0, curSize - 1 - i, curSize - i}, true); + auto temp5 = temp3.dup(); + temp3.assign(temp4); + temp4.assign(temp5); + } + } } - ////////////////////////////////////////////////////////////////////////// -template -void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shift) { - - // requires rows = cols + 1; - const int n = col2 - col1 + 1; - const int k = n/2; - const T almostZero = DataTypeUtils::min(); - T alphaK; - T betaK; - T r0; - T lambda, phi, c0, s0; - auto l = NDArrayFactory::create(_u.ordering(), {1, k}, _u.getContext()); - auto f = NDArrayFactory::create(_u.ordering(), {1, n-k-1}, _u.getContext()); - - if(n < _switchSize) { - - JacobiSVD jac(_m({col1,col1+n+1, col1,col1+n}, true), _calcU, _calcV, _fullUV); - - if (_calcU) { - auto temp = _u({col1,col1+n+1, col1,col1+n+1}, true); - temp.assign(jac._u); - } - else { - auto temp1 = _u({0,1, col1,col1+n+1}, true); - temp1.assign(jac._u({0,1, 0,0}, true)); - auto temp2 = _u({1,2, col1,col1+n+1}, true); - temp2.assign(jac._u({n,n+1, 0,0}, true)); - } - - if (_calcV) { - auto temp = _v({row1W,row1W+n, col1W,col1W+n}, true); - temp.assign(jac._v); - } - - auto temp = _m({col1+shift,col1+shift+n+1, col1+shift,col1+shift+n}, true); - temp.assign(0.); - auto diag = _m.diagonal('c'); - diag({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true)); - - return; - } - - alphaK = _m.e(col1 + k, col1 + k); - betaK = _m.e(col1 + k + 1, col1 + k); - - DivideAndConquer(k + 1 + col1, col2, k + 1 + row1W, k + 1 + col1W, shift); - DivideAndConquer(col1, k - 1 + col1, row1W, col1W + 1, shift + 1); - - if (_calcU) { - lambda = _u.e(col1 + k, col1 + k); - phi = _u.e(col1 + k + 1, col2 + 1); - } - else { - lambda = _u.e(1, col1 + k); - phi = _u.e(0, col2 + 1); - } - - r0 = math::nd4j_sqrt((math::nd4j_abs(alphaK * lambda) * math::nd4j_abs(alphaK * lambda)) + math::nd4j_abs(betaK * phi) * math::nd4j_abs(betaK * phi)); - - if(_calcU) { - l.assign(_u({col1+k, col1+k+1, col1,col1+k}, true)); - f.assign(_u({col1+k+1,col1+k+2, col1+k+1,col1+n}, true)); - } - else { - l.assign(_u({1,2, col1, col1+k}, true)); - f.assign(_u({0,1, col1+k+1, col1+n}, true)); - } - - if (_calcV) - _v.p(row1W+k, col1W, 1.f); - - if (r0 < almostZero){ - c0 = 1.; - s0 = 0.; - } - else { - c0 = alphaK * lambda / r0; - s0 = betaK * phi / r0; - } +template +void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, + int shift) { + // requires rows = cols + 1; + const int n = col2 - col1 + 1; + const int k = n / 2; + const T almostZero = DataTypeUtils::min(); + T alphaK; + T betaK; + T r0; + T lambda, phi, c0, s0; + auto l = NDArrayFactory::create(_u.ordering(), {1, k}, _u.getContext()); + auto f = + NDArrayFactory::create(_u.ordering(), {1, n - k - 1}, _u.getContext()); + + if (n < _switchSize) { + JacobiSVD jac(_m({col1, col1 + n + 1, col1, col1 + n}, true), _calcU, + _calcV, _fullUV); if (_calcU) { - - auto q1 = _u({col1,col1+k+1, col1+k,col1+k+1}, true).dup(); - - for (int i = col1 + k - 1; i >= col1; --i) { - auto temp = _u({col1,col1+k+1, i+1,i+2}, true); - temp.assign(_u({col1, col1+k+1, i, i+1}, true)); - } - - _u({col1,col1+k+1, col1,col1+1}, true).assign(q1 * c0); - _u({col1,col1+k+1, col2+1,col2+2}, true).assign(q1 * (-s0)); - _u({col1+k+1,col1+n+1, col1, col1+1}, true).assign(static_cast(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true)) * s0); - _u({col1+k+1,col1+n+1, col2+1,col2+2}, true) *= c0; - } - else { - - T q1 = _u.e(0, col1 + k); - - for (int i = col1 + k - 1; i >= col1; --i) - _u.p(0, i+1, _u.e(0, i)); - - _u.p(0, col1, q1 * c0); - _u.p(0, col2+1, -q1*s0); - _u.p(1, col1, _u.e(1, col2+1) * s0); - _u.p(1, col2 + 1, _u.e(1, col2 + 1) * c0); - _u({1,2, col1+1, col1+k+1}, true) = 0.f; - _u({0,1, col1+k+1, col1+n}, true) = 0.f; - } - - _m.p(col1 + shift, col1 + shift, r0); - auto temp1 = _m({col1+shift+1,col1+shift+k+1, col1+shift,col1+shift+1}, true); - temp1.assign(l*alphaK); - auto temp2 = _m({col1+shift+k+1,col1+shift+n, col1+shift,col1+shift+1}, true); - temp2.assign(f*betaK); - - deflation(col1, col2, k, row1W, col1W, shift); - - NDArray UofSVD, VofSVD, singVals; - calcBlockSVD(col1 + shift, n, UofSVD, singVals, VofSVD); - - if(_calcU) { - auto pTemp = _u({col1, col1+n+1, col1,col1+n+1}, true); - auto temp = pTemp.dup(); - pTemp.assign(mmul(temp, UofSVD)); - } - else { - auto pTemp = _u({0,0, col1,col1+n+1}, true); - auto temp = pTemp.dup(); - pTemp.assign(mmul(temp, UofSVD)); + auto temp = _u({col1, col1 + n + 1, col1, col1 + n + 1}, true); + temp.assign(jac._u); + } else { + auto temp1 = _u({0, 1, col1, col1 + n + 1}, true); + temp1.assign(jac._u({0, 1, 0, 0}, true)); + auto temp2 = _u({1, 2, col1, col1 + n + 1}, true); + temp2.assign(jac._u({n, n + 1, 0, 0}, true)); } if (_calcV) { - auto pTemp = _v({row1W,row1W+n, row1W,row1W+n}, true); - auto temp = pTemp.dup(); - pTemp.assign(mmul(temp, VofSVD)); - } - - auto blockM = _m({col1+shift,col1+shift+n, col1+shift,col1+shift+n}, true); - blockM = 0.f; - auto diag = blockM.diagonal('c'); - diag.assign(singVals); + auto temp = _v({row1W, row1W + n, col1W, col1W + n}, true); + temp.assign(jac._v); + } + + auto temp = + _m({col1 + shift, col1 + shift + n + 1, col1 + shift, col1 + shift + n}, + true); + temp.assign(0.); + auto diag = _m.diagonal('c'); + diag({col1 + shift, col1 + shift + n, 0, 0}, true) + .assign(jac._s({0, n, 0, 0}, true)); + + return; + } + + alphaK = _m.e(col1 + k, col1 + k); + betaK = _m.e(col1 + k + 1, col1 + k); + + DivideAndConquer(k + 1 + col1, col2, k + 1 + row1W, k + 1 + col1W, shift); + DivideAndConquer(col1, k - 1 + col1, row1W, col1W + 1, shift + 1); + + if (_calcU) { + lambda = _u.e(col1 + k, col1 + k); + phi = _u.e(col1 + k + 1, col2 + 1); + } else { + lambda = _u.e(1, col1 + k); + phi = _u.e(0, col2 + 1); + } + + r0 = math::nd4j_sqrt((math::nd4j_abs(alphaK * lambda) * + math::nd4j_abs(alphaK * lambda)) + + math::nd4j_abs(betaK * phi) * + math::nd4j_abs(betaK * phi)); + + if (_calcU) { + l.assign(_u({col1 + k, col1 + k + 1, col1, col1 + k}, true)); + f.assign(_u({col1 + k + 1, col1 + k + 2, col1 + k + 1, col1 + n}, true)); + } else { + l.assign(_u({1, 2, col1, col1 + k}, true)); + f.assign(_u({0, 1, col1 + k + 1, col1 + n}, true)); + } + + if (_calcV) _v.p(row1W + k, col1W, 1.f); + + if (r0 < almostZero) { + c0 = 1.; + s0 = 0.; + } else { + c0 = alphaK * lambda / r0; + s0 = betaK * phi / r0; + } + + if (_calcU) { + auto q1 = _u({col1, col1 + k + 1, col1 + k, col1 + k + 1}, true).dup(); + + for (int i = col1 + k - 1; i >= col1; --i) { + auto temp = _u({col1, col1 + k + 1, i + 1, i + 2}, true); + temp.assign(_u({col1, col1 + k + 1, i, i + 1}, true)); + } + + _u({col1, col1 + k + 1, col1, col1 + 1}, true).assign(q1 * c0); + _u({col1, col1 + k + 1, col2 + 1, col2 + 2}, true).assign(q1 * (-s0)); + _u({col1 + k + 1, col1 + n + 1, col1, col1 + 1}, true) + .assign(static_cast(_u( + {col1 + k + 1, col1 + n + 1, col2 + 1, col2 + 2}, true)) * + s0); + _u({col1 + k + 1, col1 + n + 1, col2 + 1, col2 + 2}, true) *= c0; + } else { + T q1 = _u.e(0, col1 + k); + + for (int i = col1 + k - 1; i >= col1; --i) _u.p(0, i + 1, _u.e(0, i)); + + _u.p(0, col1, q1 * c0); + _u.p(0, col2 + 1, -q1 * s0); + _u.p(1, col1, _u.e(1, col2 + 1) * s0); + _u.p(1, col2 + 1, _u.e(1, col2 + 1) * c0); + _u({1, 2, col1 + 1, col1 + k + 1}, true) = 0.f; + _u({0, 1, col1 + k + 1, col1 + n}, true) = 0.f; + } + + _m.p(col1 + shift, col1 + shift, r0); + auto temp1 = _m( + {col1 + shift + 1, col1 + shift + k + 1, col1 + shift, col1 + shift + 1}, + true); + temp1.assign(l * alphaK); + auto temp2 = _m( + {col1 + shift + k + 1, col1 + shift + n, col1 + shift, col1 + shift + 1}, + true); + temp2.assign(f * betaK); + + deflation(col1, col2, k, row1W, col1W, shift); + + NDArray UofSVD, VofSVD, singVals; + calcBlockSVD(col1 + shift, n, UofSVD, singVals, VofSVD); + + if (_calcU) { + auto pTemp = _u({col1, col1 + n + 1, col1, col1 + n + 1}, true); + auto temp = pTemp.dup(); + pTemp.assign(mmul(temp, UofSVD)); + } else { + auto pTemp = _u({0, 0, col1, col1 + n + 1}, true); + auto temp = pTemp.dup(); + pTemp.assign(mmul(temp, UofSVD)); + } + + if (_calcV) { + auto pTemp = _v({row1W, row1W + n, row1W, row1W + n}, true); + auto temp = pTemp.dup(); + pTemp.assign(mmul(temp, VofSVD)); + } + + auto blockM = _m( + {col1 + shift, col1 + shift + n, col1 + shift, col1 + shift + n}, true); + blockM = 0.f; + auto diag = blockM.diagonal('c'); + diag.assign(singVals); } ////////////////////////////////////////////////////////////////////////// -template -void SVD::exchangeUV(const HHsequence& hhU, const HHsequence& hhV, const NDArray& U, const NDArray& V) { - - if (_calcU) { - - int colsU = _fullUV ? hhU.rows() : _diagSize; - auto temp1 = NDArrayFactory::create(_u.ordering(), {hhU.rows(), colsU}, _u.getContext()); - temp1.setIdentity(); - _u = temp1; - - auto temp2 = _u({0,_diagSize, 0,_diagSize}, true); - temp2.assign(V({0,_diagSize, 0,_diagSize}, true)); - const_cast(hhU).mulLeft(_u); - } - - if (_calcV) { - - int colsV = _fullUV ? hhV.rows() : _diagSize; - auto temp1 = NDArrayFactory::create(_v.ordering(), {hhV.rows(), colsV}, _v.getContext()); - temp1.setIdentity(); - _v = temp1; - - auto temp2 = _v({0,_diagSize, 0,_diagSize}, true); - temp2.assign(U({0,_diagSize, 0,_diagSize}, true)); - const_cast(hhV).mulLeft(_v); - } +template +void SVD::exchangeUV(const HHsequence& hhU, const HHsequence& hhV, + const NDArray& U, const NDArray& V) { + if (_calcU) { + int colsU = _fullUV ? hhU.rows() : _diagSize; + auto temp1 = NDArrayFactory::create(_u.ordering(), {hhU.rows(), colsU}, + _u.getContext()); + temp1.setIdentity(); + _u = temp1; + + auto temp2 = _u({0, _diagSize, 0, _diagSize}, true); + temp2.assign(V({0, _diagSize, 0, _diagSize}, true)); + const_cast(hhU).mulLeft(_u); + } + + if (_calcV) { + int colsV = _fullUV ? hhV.rows() : _diagSize; + auto temp1 = NDArrayFactory::create(_v.ordering(), {hhV.rows(), colsV}, + _v.getContext()); + temp1.setIdentity(); + _v = temp1; + + auto temp2 = _v({0, _diagSize, 0, _diagSize}, true); + temp2.assign(U({0, _diagSize, 0, _diagSize}, true)); + const_cast(hhV).mulLeft(_v); + } } ////////////////////////////////////////////////////////////////////////// template void SVD::evalData(const NDArray& matrix) { + const T almostZero = DataTypeUtils::min(); - const T almostZero = DataTypeUtils::min(); + if (matrix.sizeAt(1) < _switchSize) { + JacobiSVD jac(matrix, _calcU, _calcV, _fullUV); - if(matrix.sizeAt(1) < _switchSize) { + if (_calcU) _u = jac._u; + if (_calcV) _v = jac._v; - JacobiSVD jac(matrix, _calcU, _calcV, _fullUV); + _s.assign(jac._s); - if(_calcU) - _u = jac._u; - if(_calcV) - _v = jac._v; + return; + } - _s.assign(jac._s); + T scale = matrix.reduceNumber(reduce::AMax).e(0); - return; - } - - T scale = matrix.reduceNumber(reduce::AMax).e(0); + if (scale == (T)0.) scale = 1.; - if(scale == (T)0.) - scale = 1.; + NDArray copy; + if (_transp) + copy = matrix.transpose(); + else + copy = matrix / scale; - NDArray copy; - if(_transp) - copy = matrix.transpose(); - else - copy = matrix / scale; + BiDiagonalUp biDiag(copy); - BiDiagonalUp biDiag(copy); + _u = 0.; + _v = 0.; - _u = 0.; - _v = 0.; + auto temp1 = biDiag._HHbidiag.transpose(); + auto temp2 = _m({0, _diagSize, 0, 0}, true); + temp2.assign(temp1); - auto temp1 = biDiag._HHbidiag.transpose(); - auto temp2 = _m({0,_diagSize, 0,0}, true); - temp2.assign(temp1); + auto temp3 = _m({_m.sizeAt(0) - 1, _m.sizeAt(0), 0, 0}, true); + temp3.assign(0.); + DivideAndConquer(0, _diagSize - 1, 0, 0, 0); - auto temp3 = _m({_m.sizeAt(0)-1,_m.sizeAt(0), 0,0}, true); - temp3.assign(0.); + for (int i = 0; i < _diagSize; ++i) { + T a = math::nd4j_abs(_m.e(i, i)); + _s.p(i, a * scale); + if (a < almostZero) { + auto temp = _s({i + 1, _diagSize, 0, 0}, true); + temp.assign(0.); + break; + } else if (i == _diagSize - 1) + break; + } - DivideAndConquer(0, _diagSize - 1, 0, 0, 0); - - for (int i = 0; i < _diagSize; ++i) { - T a = math::nd4j_abs(_m.e(i, i)); - _s.p(i, a * scale); - if (a < almostZero) { - auto temp = _s({i+1,_diagSize, 0,0}, true); - temp.assign(0.); - break; - } - else if (i == _diagSize-1) - break; - } - - if(_transp) - exchangeUV(biDiag.makeHHsequence('v'), biDiag.makeHHsequence('u'), _v, _u); - else - exchangeUV(biDiag.makeHHsequence('u'), biDiag.makeHHsequence('v'), _u, _v); + if (_transp) + exchangeUV(biDiag.makeHHsequence('v'), biDiag.makeHHsequence('u'), _v, _u); + else + exchangeUV(biDiag.makeHHsequence('u'), biDiag.makeHHsequence('v'), _u, _v); } +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT SVD, , FLOAT_TYPES); -BUILD_SINGLE_TEMPLATE(template class SD_EXPORT SVD,,FLOAT_TYPES); - - - -} -} -} - +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/helpers/cublasHelper.h b/libnd4j/include/helpers/cublasHelper.h index 26de981f8b4b..6fa88f9c1538 100644 --- a/libnd4j/include/helpers/cublasHelper.h +++ b/libnd4j/include/helpers/cublasHelper.h @@ -23,30 +23,32 @@ #include #include -#include + #include +#include namespace sd { - class SD_EXPORT CublasHelper { - private: - static CublasHelper *_INSTANCE; - static std::mutex _mutex; +class SD_EXPORT CublasHelper { + private: + static CublasHelper* _INSTANCE; + static std::mutex _mutex; + + std::vector _cache; + std::vector _solvers; + std::vector _cudnn; - std::vector _cache; - std::vector _solvers; - std::vector _cudnn; + CublasHelper(); + ~CublasHelper(); - CublasHelper(); - ~CublasHelper(); - public: - static CublasHelper* getInstance(); + public: + static CublasHelper* getInstance(); - void* cudnn(); - void* solver(); + void* cudnn(); + void* solver(); - void* handle(); - void* handle(int deviceId); - }; -} + void* handle(); + void* handle(int deviceId); +}; +} // namespace sd -#endif //SD_CUBLASHELPER_H +#endif // SD_CUBLASHELPER_H diff --git a/libnd4j/include/helpers/cuda/ConstantHelper.cu b/libnd4j/include/helpers/cuda/ConstantHelper.cu index 62d932489734..dcdf23bc2966 100644 --- a/libnd4j/include/helpers/cuda/ConstantHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantHelper.cu @@ -1,6 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2019 Konduit K.K. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,171 +19,181 @@ // @author raver119@gmail.com // +#include +#include +#include #include +#include +#include #include -#include +#include #include -#include #include -#include -#include -#include -#include #define CONSTANT_LIMIT 49152 __constant__ char deviceConstantMemory[CONSTANT_LIMIT]; namespace sd { - static void* getConstantSpace() { - Nd4jPointer dConstAddr; - auto dZ = cudaGetSymbolAddress(reinterpret_cast(&dConstAddr), deviceConstantMemory); - - if (dZ != 0) - throw cuda_exception::build("cudaGetSymbolAddress(...) failed", dZ); - - return dConstAddr; - } - - int ConstantHelper::getCurrentDevice() { - return AffinityManager::currentDeviceId(); +static void *getConstantSpace() { + Nd4jPointer dConstAddr; + auto dZ = cudaGetSymbolAddress(reinterpret_cast(&dConstAddr), + deviceConstantMemory); + + if (dZ != 0) + throw cuda_exception::build("cudaGetSymbolAddress(...) failed", dZ); + + return dConstAddr; +} + +int ConstantHelper::getCurrentDevice() { + return AffinityManager::currentDeviceId(); +} + +int ConstantHelper::getNumberOfDevices() { + return AffinityManager::numberOfDevices(); +} + +ConstantHelper::ConstantHelper() { + auto initialDevice = getCurrentDevice(); + + auto numDevices = getNumberOfDevices(); + _devicePointers.resize(numDevices); + _deviceOffsets.resize(numDevices); + _cache.resize(numDevices); + _counters.resize(numDevices); + + // filling all pointers + for (int e = 0; e < numDevices; e++) { + auto res = cudaSetDevice(e); + if (res != 0) throw cuda_exception::build("cudaSetDevice failed", res); + auto constant = getConstantSpace(); + + MAP_IMPL devCache; + + _devicePointers[e] = constant; + _deviceOffsets[e] = 0; + _cache[e] = devCache; + _counters[e] = 0L; + } + + // + auto res = cudaSetDevice(initialDevice); + if (res != 0) throw cuda_exception::build("Final cudaSetDevice failed", res); +} + +ConstantHelper *ConstantHelper::getInstance() { + if (!_INSTANCE) _INSTANCE = new sd::ConstantHelper(); + + return _INSTANCE; +} + +void *ConstantHelper::replicatePointer(void *src, size_t numBytes, + memory::Workspace *workspace) { + std::lock_guard lock(_mutex); + + auto deviceId = getCurrentDevice(); + Nd4jPointer constantPtr = nullptr; + Nd4jLong constantOffset = 0L; + if (_devicePointers[deviceId] == 0) { + auto constant = getConstantSpace(); + + // filling default ptr, which will be 0 probably + _devicePointers[deviceId] = constant; + _deviceOffsets[deviceId] = 0; + constantPtr = constant; + } else { + constantPtr = _devicePointers[deviceId]; + constantOffset = _deviceOffsets[deviceId]; + } + + if (constantOffset + numBytes >= CONSTANT_LIMIT) { + int8_t *ptr = nullptr; + ALLOCATE_SPECIAL(ptr, workspace, numBytes, int8_t); + auto res = cudaMemcpy(ptr, src, numBytes, cudaMemcpyHostToDevice); + if (res != 0) throw cuda_exception::build("cudaMemcpy failed", res); + + return ptr; + } else { + auto originalBytes = numBytes; + auto rem = numBytes % 8; + if (rem != 0) numBytes += 8 - rem; + + _deviceOffsets[deviceId] += numBytes; + + auto res = cudaMemcpyToSymbol(deviceConstantMemory, + const_cast(src), originalBytes, + constantOffset, cudaMemcpyHostToDevice); + if (res != 0) throw cuda_exception::build("cudaMemcpyToSymbol failed", res); + + return reinterpret_cast(constantPtr) + constantOffset; + } +} + +ConstantDataBuffer *ConstantHelper::constantBuffer( + const ConstantDescriptor &descriptor, sd::DataType dataType) { + const auto deviceId = getCurrentDevice(); + + // all cache modifications are synchronous + _mutexHolder.lock(); + + if (_cache[deviceId].count(descriptor) == 0) { + _cache[deviceId][descriptor] = new ConstantHolder(); + } + auto holder = _cache[deviceId][descriptor]; + + // release cache lock + _mutexHolder.unlock(); + + ConstantDataBuffer *result; + + // access to this holder instance is synchronous + std::lock_guard lock(*holder->mutex()); + + if (holder->hasBuffer(dataType)) { + result = holder->getConstantDataBuffer(dataType); + } else { + auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType); + auto cbuff = new int8_t[numBytes]; + _counters[deviceId] += numBytes; + + // create buffer with this dtype + if (descriptor.isFloat()) { + BUILD_DOUBLE_SELECTOR( + sd::DataType::DOUBLE, dataType, + sd::SpecialTypeConverter::convertGeneric, + (nullptr, const_cast(descriptor.floatValues().data()), + descriptor.length(), cbuff), + (sd::DataType::DOUBLE, double), LIBND4J_TYPES); + } else if (descriptor.isInteger()) { + BUILD_DOUBLE_SELECTOR( + sd::DataType::INT64, dataType, + sd::SpecialTypeConverter::convertGeneric, + (nullptr, const_cast(descriptor.integerValues().data()), + descriptor.length(), cbuff), + (sd::DataType::INT64, Nd4jLong), LIBND4J_TYPES); } - int ConstantHelper::getNumberOfDevices() { - return AffinityManager::numberOfDevices(); - } - - - ConstantHelper::ConstantHelper() { - auto initialDevice = getCurrentDevice(); + auto dbuff = replicatePointer( + cbuff, descriptor.length() * DataTypeUtils::sizeOf(dataType)); - auto numDevices = getNumberOfDevices(); - _devicePointers.resize(numDevices); - _deviceOffsets.resize(numDevices); - _cache.resize(numDevices); - _counters.resize(numDevices); + ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), + DataTypeUtils::sizeOf(dataType)); - // filling all pointers - for (int e = 0; e < numDevices; e++) { - auto res = cudaSetDevice(e); - if (res != 0) - throw cuda_exception::build("cudaSetDevice failed", res); - auto constant = getConstantSpace(); + holder->addBuffer(dataBuffer, dataType); + result = holder->getConstantDataBuffer(dataType); + } - MAP_IMPL devCache; + return result; +} - _devicePointers[e] = constant; - _deviceOffsets[e] = 0; - _cache[e] = devCache; - _counters[e] = 0L; - } - - // - auto res = cudaSetDevice(initialDevice); - if (res != 0) - throw cuda_exception::build("Final cudaSetDevice failed", res); - } - - ConstantHelper* ConstantHelper::getInstance() { - if (!_INSTANCE) - _INSTANCE = new sd::ConstantHelper(); - - return _INSTANCE; - } - - void* ConstantHelper::replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace) { - std::lock_guard lock(_mutex); - - auto deviceId = getCurrentDevice(); - Nd4jPointer constantPtr = nullptr; - Nd4jLong constantOffset = 0L; - if (_devicePointers[deviceId] == 0) { - auto constant = getConstantSpace(); - - // filling default ptr, which will be 0 probably - _devicePointers[deviceId] = constant; - _deviceOffsets[deviceId] = 0; - constantPtr = constant; - } else { - constantPtr = _devicePointers[deviceId]; - constantOffset = _deviceOffsets[deviceId]; - } - - if (constantOffset + numBytes >= CONSTANT_LIMIT) { - int8_t *ptr = nullptr; - ALLOCATE_SPECIAL(ptr, workspace, numBytes, int8_t); - auto res = cudaMemcpy(ptr, src, numBytes, cudaMemcpyHostToDevice); - if (res != 0) - throw cuda_exception::build("cudaMemcpy failed", res); - - return ptr; - } else { - auto originalBytes = numBytes; - auto rem = numBytes % 8; - if (rem != 0) - numBytes += 8 - rem; - - _deviceOffsets[deviceId] += numBytes; - - auto res = cudaMemcpyToSymbol(deviceConstantMemory, const_cast(src), originalBytes, constantOffset, cudaMemcpyHostToDevice); - if (res != 0) - throw cuda_exception::build("cudaMemcpyToSymbol failed", res); - - return reinterpret_cast(constantPtr) + constantOffset; - } - } - - ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, sd::DataType dataType) { - const auto deviceId = getCurrentDevice(); - - // all cache modifications are synchronous - _mutexHolder.lock(); - - if (_cache[deviceId].count(descriptor) == 0) { - _cache[deviceId][descriptor] = new ConstantHolder(); - } - auto holder = _cache[deviceId][descriptor]; - - // release cache lock - _mutexHolder.unlock(); - - ConstantDataBuffer* result; - - // access to this holder instance is synchronous - std::lock_guard lock(*holder->mutex()); - - if (holder->hasBuffer(dataType)) { - result = holder->getConstantDataBuffer(dataType); - } else { - auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType); - auto cbuff = new int8_t[numBytes]; - _counters[deviceId] += numBytes; - - // create buffer with this dtype - if (descriptor.isFloat()) { - BUILD_DOUBLE_SELECTOR(sd::DataType::DOUBLE, dataType, sd::SpecialTypeConverter::convertGeneric, (nullptr, const_cast(descriptor.floatValues().data()), descriptor.length(), cbuff), (sd::DataType::DOUBLE, double), LIBND4J_TYPES); - } else if (descriptor.isInteger()) { - BUILD_DOUBLE_SELECTOR(sd::DataType::INT64, dataType, sd::SpecialTypeConverter::convertGeneric, (nullptr, const_cast(descriptor.integerValues().data()), descriptor.length(), cbuff), (sd::DataType::INT64, Nd4jLong), LIBND4J_TYPES); - } - - auto dbuff = replicatePointer(cbuff, descriptor.length() * DataTypeUtils::sizeOf(dataType)); - - ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), DataTypeUtils::sizeOf(dataType)); - - holder->addBuffer(dataBuffer, dataType); - result = holder->getConstantDataBuffer(dataType); - } - - return result; - } - - Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { - int numDevices = getNumberOfDevices(); - if (deviceId > numDevices || deviceId < 0) - return 0L; - else - return _counters[deviceId]; - } +Nd4jLong ConstantHelper::getCachedAmount(int deviceId) { + int numDevices = getNumberOfDevices(); + if (deviceId > numDevices || deviceId < 0) + return 0L; + else + return _counters[deviceId]; +} - sd::ConstantHelper* sd::ConstantHelper::_INSTANCE = 0; -} \ No newline at end of file +sd::ConstantHelper *sd::ConstantHelper::_INSTANCE = 0; +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu index 2026dbb04569..4bb34961a7f5 100644 --- a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu @@ -18,176 +18,197 @@ // @author raver119@gmail.com // -#include "../ConstantShapeHelper.h" -#include #include -#include +#include #include #include +#include + +#include "../ConstantShapeHelper.h" namespace sd { - ConstantShapeHelper::ConstantShapeHelper() { - auto numDevices = AffinityManager::numberOfDevices(); +ConstantShapeHelper::ConstantShapeHelper() { + auto numDevices = AffinityManager::numberOfDevices(); - _cache.resize(numDevices); - for (int e = 0; e < numDevices; e++) { - MAP_IMPL cache; - _cache[e] = cache; - } - } + _cache.resize(numDevices); + for (int e = 0; e < numDevices; e++) { + MAP_IMPL cache; + _cache[e] = cache; + } +} - ConstantShapeHelper* ConstantShapeHelper::getInstance() { - if (!_INSTANCE) - _INSTANCE = new ConstantShapeHelper(); +ConstantShapeHelper* ConstantShapeHelper::getInstance() { + if (!_INSTANCE) _INSTANCE = new ConstantShapeHelper(); - return _INSTANCE; - } + return _INSTANCE; +} - ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(sd::DataType dataType, char order, const std::vector &shape) { - ShapeDescriptor descriptor(dataType, order, shape); - return bufferForShapeInfo(descriptor); - } +ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo( + sd::DataType dataType, char order, const std::vector& shape) { + ShapeDescriptor descriptor(dataType, order, shape); + return bufferForShapeInfo(descriptor); +} - ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { - ShapeDescriptor descriptor(dataType, order, shape, rank); - return bufferForShapeInfo(descriptor); - } +ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo( + const sd::DataType dataType, const char order, const int rank, + const Nd4jLong* shape) { + ShapeDescriptor descriptor(dataType, order, shape, rank); + return bufferForShapeInfo(descriptor); +} - ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { - int deviceId = AffinityManager::currentDeviceId(); - - std::lock_guard lock(_mutex); - - if (_cache[deviceId].count(descriptor) == 0) { - auto hPtr = descriptor.toShapeInfo(); - auto dPtr = ConstantHelper::getInstance()->replicatePointer(hPtr, shape::shapeInfoByteLength(hPtr)); - ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64); - ShapeDescriptor descriptor1(descriptor); - _cache[deviceId][descriptor1] = buffer; - return _cache[deviceId][descriptor1]; - } else { - return _cache[deviceId].at(descriptor); - } - } +ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo( + const ShapeDescriptor& descriptor) { + int deviceId = AffinityManager::currentDeviceId(); + + std::lock_guard lock(_mutex); + + if (_cache[deviceId].count(descriptor) == 0) { + auto hPtr = descriptor.toShapeInfo(); + auto dPtr = ConstantHelper::getInstance()->replicatePointer( + hPtr, shape::shapeInfoByteLength(hPtr)); + ConstantDataBuffer buffer(hPtr, dPtr, + shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), + DataType::INT64); + ShapeDescriptor descriptor1(descriptor); + _cache[deviceId][descriptor1] = buffer; + return _cache[deviceId][descriptor1]; + } else { + return _cache[deviceId].at(descriptor); + } +} - ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const Nd4jLong *shapeInfo) { - ShapeDescriptor descriptor(shapeInfo); - return bufferForShapeInfo(descriptor); - } +ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo( + const Nd4jLong* shapeInfo) { + ShapeDescriptor descriptor(shapeInfo); + return bufferForShapeInfo(descriptor); +} - bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) { - auto deviceId = AffinityManager::currentDeviceId(); - std::lock_guard lock(_mutex); +bool ConstantShapeHelper::checkBufferExistenceForShapeInfo( + ShapeDescriptor& descriptor) { + auto deviceId = AffinityManager::currentDeviceId(); + std::lock_guard lock(_mutex); - return _cache[deviceId].count(descriptor) != 0; - } + return _cache[deviceId].count(descriptor) != 0; +} - Nd4jLong const* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { - ShapeDescriptor descriptor(dataType, order, shape, rank); - return bufferForShapeInfo(descriptor).primaryAsT(); - } +Nd4jLong const* ConstantShapeHelper::createShapeInfo( + const sd::DataType dataType, const char order, const int rank, + const Nd4jLong* shape) { + ShapeDescriptor descriptor(dataType, order, shape, rank); + return bufferForShapeInfo(descriptor).primaryAsT(); +} - Nd4jLong const* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const Nd4jLong* shapeInfo) { - return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast(shapeInfo))); - } +Nd4jLong const* ConstantShapeHelper::createShapeInfo( + const sd::DataType dataType, const Nd4jLong* shapeInfo) { + return ConstantShapeHelper::createShapeInfo( + dataType, shape::order(shapeInfo), shape::rank(shapeInfo), + shape::shapeOf(const_cast(shapeInfo))); +} - Nd4jLong const* ConstantShapeHelper::emptyShapeInfo(const sd::DataType dataType) { - auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); - return bufferForShapeInfo(descriptor).primaryAsT(); - } +Nd4jLong const* ConstantShapeHelper::emptyShapeInfo( + const sd::DataType dataType) { + auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); + return bufferForShapeInfo(descriptor).primaryAsT(); +} - Nd4jLong const* ConstantShapeHelper::scalarShapeInfo(const sd::DataType dataType) { - auto descriptor = ShapeDescriptor::scalarDescriptor(dataType); - return bufferForShapeInfo(descriptor).primaryAsT(); - } +Nd4jLong const* ConstantShapeHelper::scalarShapeInfo( + const sd::DataType dataType) { + auto descriptor = ShapeDescriptor::scalarDescriptor(dataType); + return bufferForShapeInfo(descriptor).primaryAsT(); +} - Nd4jLong const* ConstantShapeHelper::vectorShapeInfo(const Nd4jLong length, const sd::DataType dataType) { - auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType); - return bufferForShapeInfo(descriptor).primaryAsT(); - } +Nd4jLong const* ConstantShapeHelper::vectorShapeInfo( + const Nd4jLong length, const sd::DataType dataType) { + auto descriptor = ShapeDescriptor::vectorDescriptor(length, dataType); + return bufferForShapeInfo(descriptor).primaryAsT(); +} - Nd4jLong const* ConstantShapeHelper::createShapeInfo(const sd::DataType dataType, const char order, const std::vector &shape) { - ShapeDescriptor descriptor(dataType, order, shape); - return bufferForShapeInfo(descriptor).primaryAsT(); - } +Nd4jLong const* ConstantShapeHelper::createShapeInfo( + const sd::DataType dataType, const char order, + const std::vector& shape) { + ShapeDescriptor descriptor(dataType, order, shape); + return bufferForShapeInfo(descriptor).primaryAsT(); +} - Nd4jLong const* ConstantShapeHelper::createShapeInfo(const ShapeDescriptor &descriptor) { - return bufferForShapeInfo(descriptor).primaryAsT(); - } +Nd4jLong const* ConstantShapeHelper::createShapeInfo( + const ShapeDescriptor& descriptor) { + return bufferForShapeInfo(descriptor).primaryAsT(); +} - Nd4jLong const* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal) { - ShapeDescriptor descriptor(shapeInfo); - auto result = createShapeInfo(descriptor); +Nd4jLong const* ConstantShapeHelper::createFromExisting(Nd4jLong* shapeInfo, + bool destroyOriginal) { + ShapeDescriptor descriptor(shapeInfo); + auto result = createShapeInfo(descriptor); - if (destroyOriginal) - RELEASE(shapeInfo, nullptr); + if (destroyOriginal) RELEASE(shapeInfo, nullptr); - return result; - } + return result; +} - Nd4jLong const* ConstantShapeHelper::createFromExisting(Nd4jLong *shapeInfo, sd::memory::Workspace *workspace) { - ShapeDescriptor descriptor(shapeInfo); - auto result = createShapeInfo(descriptor); +Nd4jLong const* ConstantShapeHelper::createFromExisting( + Nd4jLong* shapeInfo, sd::memory::Workspace* workspace) { + ShapeDescriptor descriptor(shapeInfo); + auto result = createShapeInfo(descriptor); - RELEASE(shapeInfo, workspace); + RELEASE(shapeInfo, workspace); - return result; - } + return result; +} //////////////////////////////////////////////////////////////////////// -ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector& dimensions) { - - Nd4jLong* newShapeInfo = nullptr; - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong); - - newShapeInfo[0] = shape::rank(maxShapeInfo); - - sd::ArrayOptions::copyDataType(newShapeInfo, minShapeInfo); // type - newShapeInfo[2 * newShapeInfo[0] + 2] = shape::elementWiseStride(minShapeInfo); // ews - newShapeInfo[2 * newShapeInfo[0] + 3] = shape::order(minShapeInfo); // order - - if(!dimensions.empty()) { - - for(uint k = 0, j = 0, i = 0; i < shape::rank(maxShapeInfo); ++i) { - - if(j < dimensions.size() && dimensions[j] == i) { - shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[k]; - shape::stride(newShapeInfo)[i] = shape::stride(minShapeInfo)[k++]; - ++j; - } - else{ - shape::shapeOf(newShapeInfo)[i] = 1; - shape::stride(newShapeInfo)[i] = 0; - if(shape::sizeAt(minShapeInfo, k) == 1 && dimensions.size() != shape::rank(minShapeInfo)) - ++k; - } - } - } - else{ - - for(int j = shape::rank(minShapeInfo) - 1, i = shape::rank(maxShapeInfo) - 1; i >=0 ; --i) { - - if(j >= 0) { - shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j]; - shape::stride(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j] == 1 ? 0 : shape::stride(minShapeInfo)[j]; - --j; - } - else { - shape::shapeOf(newShapeInfo)[i] = 1; - shape::stride(newShapeInfo)[i] = 0; - } - } - } - - ShapeDescriptor descriptor(newShapeInfo); - - RELEASE(newShapeInfo, workspace); - - return bufferForShapeInfo(descriptor); +ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast( + const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, + sd::memory::Workspace* workspace, const std::vector& dimensions) { + Nd4jLong* newShapeInfo = nullptr; + ALLOCATE(newShapeInfo, workspace, + shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong); + + newShapeInfo[0] = shape::rank(maxShapeInfo); + + sd::ArrayOptions::copyDataType(newShapeInfo, minShapeInfo); // type + newShapeInfo[2 * newShapeInfo[0] + 2] = + shape::elementWiseStride(minShapeInfo); // ews + newShapeInfo[2 * newShapeInfo[0] + 3] = shape::order(minShapeInfo); // order + + if (!dimensions.empty()) { + for (uint k = 0, j = 0, i = 0; i < shape::rank(maxShapeInfo); ++i) { + if (j < dimensions.size() && dimensions[j] == i) { + shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[k]; + shape::stride(newShapeInfo)[i] = shape::stride(minShapeInfo)[k++]; + ++j; + } else { + shape::shapeOf(newShapeInfo)[i] = 1; + shape::stride(newShapeInfo)[i] = 0; + if (shape::sizeAt(minShapeInfo, k) == 1 && + dimensions.size() != shape::rank(minShapeInfo)) + ++k; + } + } + } else { + for (int j = shape::rank(minShapeInfo) - 1, + i = shape::rank(maxShapeInfo) - 1; + i >= 0; --i) { + if (j >= 0) { + shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j]; + shape::stride(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j] == 1 + ? 0 + : shape::stride(minShapeInfo)[j]; + --j; + } else { + shape::shapeOf(newShapeInfo)[i] = 1; + shape::stride(newShapeInfo)[i] = 0; + } + } + } + + ShapeDescriptor descriptor(newShapeInfo); + + RELEASE(newShapeInfo, workspace); + + return bufferForShapeInfo(descriptor); } - sd::ConstantShapeHelper* sd::ConstantShapeHelper::_INSTANCE = 0; -} \ No newline at end of file +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu index 8463bab9c5bb..253d68a025b1 100644 --- a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu @@ -18,95 +18,118 @@ // @author raver119@gmail.com // -#include "../ConstantTadHelper.h" -#include -#include -#include #include +#include #include +#include #include +#include -namespace sd { - ConstantTadHelper::ConstantTadHelper() { - auto numDevices = AffinityManager::numberOfDevices(); - - for (int e = 0; e < numDevices; e++) { - MAP_IMPL pack; - _cache.emplace_back(pack); - } - } - - ConstantTadHelper* ConstantTadHelper::getInstance() { - if (!_INSTANCE) - _INSTANCE = new ConstantTadHelper(); - - return _INSTANCE; - } - - TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { - return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); - } - - TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector &dimensions, const bool keepUnitiesInShape) { - return tadForDimensions(originalShape, const_cast(dimensions.data()), dimensions.size(), keepUnitiesInShape); - } - - TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { - TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape); - return tadForDimensions(tadDescriptor); - } - - TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector &dimensions, const bool keepUnitiesInShape) { - TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); - return tadForDimensions(tadDescriptor); - } - - TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { - const int deviceId = AffinityManager::currentDeviceId(); - - std::lock_guard lock(_mutex); - - if (_cache[deviceId].count(descriptor) == 0) { - const auto shapeInfo = descriptor.originalShape().toShapeInfo(); - const int rank = shape::rank(shapeInfo); - const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(rank, descriptor.axis()); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude); - const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size(); - - auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; - auto oPtr = new Nd4jLong[numOfSubArrs]; - - if (numOfSubArrs > 0) - shape::calcSubArrsShapeInfoAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape()); - - Nd4jPointer soPtr; - auto res = cudaMalloc(reinterpret_cast(&soPtr), numOfSubArrs * sizeof(Nd4jLong)); - if (res != 0) - throw cuda_exception::build("Memory allocation for tadOffsets failed", res); - - res = cudaMemcpy(soPtr, oPtr, numOfSubArrs * sizeof(Nd4jLong), cudaMemcpyHostToDevice); - if (res != 0) - throw cuda_exception::build("tadOffsets copy failed", res); - - auto ssPtr = ConstantHelper::getInstance()->replicatePointer(sPtr, shape::shapeInfoByteLength(subArrRank)); - - ConstantDataBuffer shapesBuffer(sPtr, ssPtr, shape::shapeInfoLength(subArrRank) * sizeof(Nd4jLong), DataType::INT64); - ConstantDataBuffer offsetsBuffer(oPtr, soPtr, numOfSubArrs * sizeof(Nd4jLong), DataType::INT64); - - TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs); - _cache[deviceId][descriptor] = t; - - TadPack r = _cache[deviceId][descriptor]; - - delete[] shapeInfo; - - return r; - } else { - TadPack r = _cache[deviceId][descriptor]; - - return r; - } - } +#include "../ConstantTadHelper.h" - sd::ConstantTadHelper* sd::ConstantTadHelper::_INSTANCE = 0; -} \ No newline at end of file +namespace sd { +ConstantTadHelper::ConstantTadHelper() { + auto numDevices = AffinityManager::numberOfDevices(); + + for (int e = 0; e < numDevices; e++) { + MAP_IMPL pack; + _cache.emplace_back(pack); + } +} + +ConstantTadHelper *ConstantTadHelper::getInstance() { + if (!_INSTANCE) _INSTANCE = new ConstantTadHelper(); + + return _INSTANCE; +} + +TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, + int dimension, + const bool keepUnitiesInShape) { + return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); +} + +TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, + const std::vector &dimensions, + const bool keepUnitiesInShape) { + return tadForDimensions(originalShape, const_cast(dimensions.data()), + dimensions.size(), keepUnitiesInShape); +} + +TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, + int *dimensions, int dimLength, + const bool keepUnitiesInShape) { + TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, + keepUnitiesInShape); + return tadForDimensions(tadDescriptor); +} + +TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, + std::vector &dimensions, + const bool keepUnitiesInShape) { + TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); + return tadForDimensions(tadDescriptor); +} + +TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { + const int deviceId = AffinityManager::currentDeviceId(); + + std::lock_guard lock(_mutex); + + if (_cache[deviceId].count(descriptor) == 0) { + const auto shapeInfo = descriptor.originalShape().toShapeInfo(); + const int rank = shape::rank(shapeInfo); + const std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(rank, descriptor.axis()); + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude); + const int subArrRank = + (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) + ? rank + : rank - dimsToExclude.size(); + + auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; + auto oPtr = new Nd4jLong[numOfSubArrs]; + + if (numOfSubArrs > 0) + shape::calcSubArrsShapeInfoAndOffsets( + shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), + sPtr, oPtr, descriptor.areUnitiesinShape()); + + Nd4jPointer soPtr; + auto res = cudaMalloc(reinterpret_cast(&soPtr), + numOfSubArrs * sizeof(Nd4jLong)); + if (res != 0) + throw cuda_exception::build("Memory allocation for tadOffsets failed", + res); + + res = cudaMemcpy(soPtr, oPtr, numOfSubArrs * sizeof(Nd4jLong), + cudaMemcpyHostToDevice); + if (res != 0) throw cuda_exception::build("tadOffsets copy failed", res); + + auto ssPtr = ConstantHelper::getInstance()->replicatePointer( + sPtr, shape::shapeInfoByteLength(subArrRank)); + + ConstantDataBuffer shapesBuffer( + sPtr, ssPtr, shape::shapeInfoLength(subArrRank) * sizeof(Nd4jLong), + DataType::INT64); + ConstantDataBuffer offsetsBuffer( + oPtr, soPtr, numOfSubArrs * sizeof(Nd4jLong), DataType::INT64); + + TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs); + _cache[deviceId][descriptor] = t; + + TadPack r = _cache[deviceId][descriptor]; + + delete[] shapeInfo; + + return r; + } else { + TadPack r = _cache[deviceId][descriptor]; + + return r; + } +} + +sd::ConstantTadHelper *sd::ConstantTadHelper::_INSTANCE = 0; +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/cuda/PointersManager.cu b/libnd4j/include/helpers/cuda/PointersManager.cu index dc5fe15f50b1..990097c39cc1 100644 --- a/libnd4j/include/helpers/cuda/PointersManager.cu +++ b/libnd4j/include/helpers/cuda/PointersManager.cu @@ -19,8 +19,8 @@ // @author raver119@gmail.com // -#include #include +#include #include #include #include @@ -28,97 +28,121 @@ namespace sd { ////////////////////////////////////////////////////////////////////////// -PointersManager::PointersManager(const sd::LaunchContext* context, const std::string& funcName) { - _context = const_cast(context); - _funcName = funcName; +PointersManager::PointersManager(const sd::LaunchContext* context, + const std::string& funcName) { + _context = const_cast(context); + _funcName = funcName; } ////////////////////////////////////////////////////////////////////////// -void* PointersManager::replicatePointer(const void* src, const size_t numberOfBytes) { - - void* dst = nullptr; - if (_context->getWorkspace() == nullptr) { - cudaError_t cudaResult = cudaMalloc(reinterpret_cast(&dst), numberOfBytes); - if (cudaResult != 0) - throw cuda_exception::build(_funcName + ": cannot allocate global memory on device!", cudaResult); - } else { - dst = _context->getWorkspace()->allocateBytes(sd::memory::MemoryType::DEVICE, numberOfBytes); - } - - if (_context != nullptr) - cudaMemcpyAsync(dst, src, numberOfBytes, cudaMemcpyHostToDevice, *_context->getCudaStream()); - else - cudaMemcpy(dst, src, numberOfBytes, cudaMemcpyHostToDevice); - - _pOnGlobMem.emplace_back(dst); - - return dst; +void* PointersManager::replicatePointer(const void* src, + const size_t numberOfBytes) { + void* dst = nullptr; + if (_context->getWorkspace() == nullptr) { + cudaError_t cudaResult = + cudaMalloc(reinterpret_cast(&dst), numberOfBytes); + if (cudaResult != 0) + throw cuda_exception::build( + _funcName + ": cannot allocate global memory on device!", cudaResult); + } else { + dst = _context->getWorkspace()->allocateBytes( + sd::memory::MemoryType::DEVICE, numberOfBytes); + } + + if (_context != nullptr) + cudaMemcpyAsync(dst, src, numberOfBytes, cudaMemcpyHostToDevice, + *_context->getCudaStream()); + else + cudaMemcpy(dst, src, numberOfBytes, cudaMemcpyHostToDevice); + + _pOnGlobMem.emplace_back(dst); + + return dst; } ////////////////////////////////////////////////////////////////////////// void PointersManager::synchronize() const { - if (_context != nullptr) { - cudaError_t cudaResult = cudaStreamSynchronize(*_context->getCudaStream()); - if (cudaResult != 0) - throw cuda_exception::build(_funcName + ": cuda stream synchronization failed !", cudaResult); - } else { - nd4j_printf("<%s> syncStream isn't possible: no stream set!", _funcName.c_str()); - } + if (_context != nullptr) { + cudaError_t cudaResult = cudaStreamSynchronize(*_context->getCudaStream()); + if (cudaResult != 0) + throw cuda_exception::build( + _funcName + ": cuda stream synchronization failed !", cudaResult); + } else { + nd4j_printf("<%s> syncStream isn't possible: no stream set!", + _funcName.c_str()); + } } ////////////////////////////////////////////////////////////////////////// PointersManager::~PointersManager() { - - for (auto& p :_pOnGlobMem) - cudaFree(p); + for (auto& p : _pOnGlobMem) cudaFree(p); } - //////////////////////////////////////////////////////////////////////// template -static __global__ void printDevContentOnDev_(const void* pDev, const Nd4jLong len, const int tid) { - - PointersManager::printDevContentOnDev(pDev, len, tid); +static __global__ void printDevContentOnDev_(const void* pDev, + const Nd4jLong len, + const int tid) { + PointersManager::printDevContentOnDev(pDev, len, tid); } //////////////////////////////////////////////////////////////////////// -template -void PointersManager::printDevContentOnDevFromHost(const void* pDev, const Nd4jLong len, const int tid) { - printDevContentOnDev_<<<512, 512, 1024, *sd::LaunchContext ::defaultContext()->getCudaStream()>>>(pDev, len, tid); - auto res = cudaStreamSynchronize(*sd::LaunchContext ::defaultContext()->getCudaStream()); - if (res != 0) - throw std::runtime_error("PointersManager::printDevContentOnDevFromHost: cudaStreamSynchronize failed!"); +template +void PointersManager::printDevContentOnDevFromHost(const void* pDev, + const Nd4jLong len, + const int tid) { + printDevContentOnDev_ + <<<512, 512, 1024, + *sd::LaunchContext ::defaultContext()->getCudaStream()>>>(pDev, len, + tid); + auto res = cudaStreamSynchronize( + *sd::LaunchContext ::defaultContext()->getCudaStream()); + if (res != 0) + throw std::runtime_error( + "PointersManager::printDevContentOnDevFromHost: cudaStreamSynchronize " + "failed!"); } -template void PointersManager::printDevContentOnDevFromHost(const void* pDev, const Nd4jLong len, const int tid); -template void PointersManager::printDevContentOnDevFromHost(const void* pDev, const Nd4jLong len, const int tid); -template void PointersManager::printDevContentOnDevFromHost(const void* pDev, const Nd4jLong len, const int tid); -template void PointersManager::printDevContentOnDevFromHost(const void* pDev, const Nd4jLong len, const int tid); - -//BUILD_SINGLE_TEMPLATE(template void PointersManager::printDevContentOnDevFromHost, (void* pDev, Nd4jLong len, int tid), LIBND4J_TYPES); +template void PointersManager::printDevContentOnDevFromHost( + const void* pDev, const Nd4jLong len, const int tid); +template void PointersManager::printDevContentOnDevFromHost( + const void* pDev, const Nd4jLong len, const int tid); +template void PointersManager::printDevContentOnDevFromHost( + const void* pDev, const Nd4jLong len, const int tid); +template void PointersManager::printDevContentOnDevFromHost( + const void* pDev, const Nd4jLong len, const int tid); + +// BUILD_SINGLE_TEMPLATE(template void +// PointersManager::printDevContentOnDevFromHost, (void* pDev, Nd4jLong len, int +// tid), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// -template -void PointersManager::printDevContentOnHost(const void* pDev, const Nd4jLong len) const { - printf("host print out\n"); - void* pHost = operator new(sizeof(T) * len); - - cudaMemcpyAsync(pHost, pDev, sizeof(T) * len, cudaMemcpyDeviceToHost, *_context->getCudaStream()); - cudaError_t cudaResult = cudaStreamSynchronize(*_context->getCudaStream()); - if(cudaResult != 0) - throw std::runtime_error("PointersManager::printCudaHost: cudaStreamSynchronize failed!"); - - for(Nd4jLong i = 0; i < len; ++i) - printf("%f, ", (double)reinterpret_cast(pHost)[i]); - printf("\n"); - - operator delete(pHost); +template +void PointersManager::printDevContentOnHost(const void* pDev, + const Nd4jLong len) const { + printf("host print out\n"); + void* pHost = operator new(sizeof(T) * len); + + cudaMemcpyAsync(pHost, pDev, sizeof(T) * len, cudaMemcpyDeviceToHost, + *_context->getCudaStream()); + cudaError_t cudaResult = cudaStreamSynchronize(*_context->getCudaStream()); + if (cudaResult != 0) + throw std::runtime_error( + "PointersManager::printCudaHost: cudaStreamSynchronize failed!"); + + for (Nd4jLong i = 0; i < len; ++i) + printf("%f, ", (double)reinterpret_cast(pHost)[i]); + printf("\n"); + + operator delete(pHost); } +template void PointersManager::printDevContentOnHost( + const void* pDev, const Nd4jLong len) const; +template void PointersManager::printDevContentOnHost( + const void* pDev, const Nd4jLong len) const; +template void PointersManager::printDevContentOnHost( + const void* pDev, const Nd4jLong len) const; +template void PointersManager::printDevContentOnHost( + const void* pDev, const Nd4jLong len) const; -template void PointersManager::printDevContentOnHost(const void* pDev, const Nd4jLong len) const; -template void PointersManager::printDevContentOnHost(const void* pDev, const Nd4jLong len) const; -template void PointersManager::printDevContentOnHost(const void* pDev, const Nd4jLong len) const; -template void PointersManager::printDevContentOnHost(const void* pDev, const Nd4jLong len) const; - - -} +} // namespace sd diff --git a/libnd4j/include/helpers/cuda_off/MmulHelper.cu b/libnd4j/include/helpers/cuda_off/MmulHelper.cu index 0a3b466bc2a7..572701cc159a 100644 --- a/libnd4j/include/helpers/cuda_off/MmulHelper.cu +++ b/libnd4j/include/helpers/cuda_off/MmulHelper.cu @@ -19,502 +19,597 @@ // @author raver119@gmail.com // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include -#include "../MmulHelper.h" -#include -#include +#include #include +#include +#include + #include +#include "../MmulHelper.h" + namespace sd { ////////////////////////////////////////////////////////////////////////////// // MXK x KxN = MxN -> actual sequence of axes doesn't matter template -static __global__ void usualCudaGemm(const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, - const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, +static __global__ void usualCudaGemm(const void* vA, const Nd4jLong* aShapeInfo, + const void* vB, const Nd4jLong* bShapeInfo, + void* vC, const Nd4jLong* cShapeInfo, + const int aMaxis, const int aKaxis, + const int bKaxis, const int bNaxis, + const int cMaxis, const int cNaxis, const double alpha, const double beta) { + const T1* A = reinterpret_cast(vA); + const T2* B = reinterpret_cast(vB); + T3* C = reinterpret_cast(vC); - const T1* A = reinterpret_cast(vA); - const T2* B = reinterpret_cast(vB); - T3* C = reinterpret_cast< T3*>(vC); - - __shared__ int K, *coords; - __shared__ bool betaPresent; - __shared__ Nd4jLong cLen, totalThreads; - __shared__ T3 alphaZ, betaZ; - - if (threadIdx.x == 0) { + __shared__ int K, *coords; + __shared__ bool betaPresent; + __shared__ Nd4jLong cLen, totalThreads; + __shared__ T3 alphaZ, betaZ; - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); - cLen = shape::length(cShapeInfo); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + cLen = shape::length(cShapeInfo); - K = shape::shapeOf(const_cast(aShapeInfo))[aKaxis]; + K = shape::shapeOf(const_cast(aShapeInfo))[aKaxis]; - betaPresent = beta; + betaPresent = beta; - totalThreads = gridDim.x * blockDim.x; + totalThreads = gridDim.x * blockDim.x; - alphaZ = alpha; - betaZ = beta; - } - __syncthreads(); - - auto aCoords = coords + threadIdx.x * 6; // 6 = (aRank + bRank + cRank) - auto bCoords = aCoords + 2; - auto cCoords = bCoords + 2; + alphaZ = alpha; + betaZ = beta; + } + __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto aCoords = coords + threadIdx.x * 6; // 6 = (aRank + bRank + cRank) + auto bCoords = aCoords + 2; + auto cCoords = bCoords + 2; - for (Nd4jLong i = tid; i < cLen; i += totalThreads) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - // evaluate C coordinates - shape::index2coords(i, cShapeInfo, cCoords); + for (Nd4jLong i = tid; i < cLen; i += totalThreads) { + // evaluate C coordinates + shape::index2coords(i, cShapeInfo, cCoords); - // evaluate A coordinates - aCoords[aMaxis] = cCoords[cMaxis]; - aCoords[aKaxis] = 0; + // evaluate A coordinates + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; - // evaluate B coordinates - bCoords[bKaxis] = 0; - bCoords[bNaxis] = cCoords[cNaxis]; + // evaluate B coordinates + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; - auto aOffset = shape::getOffset(aShapeInfo, aCoords); - auto bOffset = shape::getOffset(bShapeInfo, bCoords); + auto aOffset = shape::getOffset(aShapeInfo, aCoords); + auto bOffset = shape::getOffset(bShapeInfo, bCoords); - T3 val = A[aOffset] * B[bOffset]; // first iteration + T3 val = A[aOffset] * B[bOffset]; // first iteration - for (uint j = 1; j < K; ++j) { // rest iterations - aOffset += shape::stride(aShapeInfo)[aKaxis]; - bOffset += shape::stride(bShapeInfo)[bKaxis]; - val = val + A[aOffset] * B[bOffset]; - } + for (uint j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; + } - auto cOffset = shape::getOffset(cShapeInfo, cCoords); + auto cOffset = shape::getOffset(cShapeInfo, cCoords); - if(betaPresent) - C[cOffset] = alphaZ * val + betaZ * C[cOffset]; - else - C[cOffset] = alphaZ * val; - } + if (betaPresent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; + else + C[cOffset] = alphaZ * val; + } } //////////////////////////////////////////////////////////////////////// template -__host__ static void usualGemm(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, cudaStream_t *stream, const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, const double alpha, const double beta) { - - usualCudaGemm<<>>(vA, aShapeInfo, vB, bShapeInfo, vC, cShapeInfo, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta); +__host__ static void usualGemm( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + cudaStream_t* stream, const void* vA, const Nd4jLong* aShapeInfo, + const void* vB, const Nd4jLong* bShapeInfo, void* vC, + const Nd4jLong* cShapeInfo, const int aMaxis, const int aKaxis, + const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, + const double alpha, const double beta) { + usualCudaGemm + <<>>( + vA, aShapeInfo, vB, bShapeInfo, vC, cShapeInfo, aMaxis, aKaxis, + bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta); } //////////////////////////////////////////////////////////////////////// // MXN x N = M -> actual sequence of {M,N} axes doesn't matter template -static __global__ void usualCudaGemv(const void* vA, const Nd4jLong* aShapeInfo, const void* vX, const Nd4jLong* xShapeInfo, void* vY, const Nd4jLong* yShapeInfo, - const int incx, const int incy, const int aMaxis, const double alpha, const double beta) { - - const T1* A = reinterpret_cast(vA); - const T2* X = reinterpret_cast(vX); - T3* Y = reinterpret_cast< T3*>(vY); - - __shared__ int M, N; - __shared__ bool betaPresent; - __shared__ Nd4jLong cLen, totalThreads, aNstride, aMstride; - __shared__ T3 alphaZ, betaZ; - - if (threadIdx.x == 0) { - - N = shape::length(xShapeInfo); - M = shape::length(yShapeInfo); - - aMstride = shape::stride(aShapeInfo)[aMaxis]; - aNstride = shape::stride(aShapeInfo)[aMaxis == 0 ? 1 : 0]; - - totalThreads = gridDim.x * blockDim.x; - - betaPresent = beta; - - alphaZ = alpha; - betaZ = beta; +static __global__ void usualCudaGemv(const void* vA, const Nd4jLong* aShapeInfo, + const void* vX, const Nd4jLong* xShapeInfo, + void* vY, const Nd4jLong* yShapeInfo, + const int incx, const int incy, + const int aMaxis, const double alpha, + const double beta) { + const T1* A = reinterpret_cast(vA); + const T2* X = reinterpret_cast(vX); + T3* Y = reinterpret_cast(vY); + + __shared__ int M, N; + __shared__ bool betaPresent; + __shared__ Nd4jLong cLen, totalThreads, aNstride, aMstride; + __shared__ T3 alphaZ, betaZ; + + if (threadIdx.x == 0) { + N = shape::length(xShapeInfo); + M = shape::length(yShapeInfo); + + aMstride = shape::stride(aShapeInfo)[aMaxis]; + aNstride = shape::stride(aShapeInfo)[aMaxis == 0 ? 1 : 0]; + + totalThreads = gridDim.x * blockDim.x; + + betaPresent = beta; + + alphaZ = alpha; + betaZ = beta; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < M; i += totalThreads) { + // evaluate offsets + auto aOffset = i * aMstride; + auto xOffset = 0; + + T3 val = A[aOffset] * X[xOffset]; // first iteration + + for (uint j = 1; j < N; ++j) { // rest iterations + aOffset += aNstride; + xOffset += incx; + val = val + A[aOffset] * X[xOffset]; } - __syncthreads(); - - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < M; i += totalThreads) { - // evaluate offsets - auto aOffset = i * aMstride; - auto xOffset = 0; + auto yOffset = i * incy; - T3 val = A[aOffset] * X[xOffset]; // first iteration - - for (uint j = 1; j < N; ++j) { // rest iterations - aOffset += aNstride; - xOffset += incx; - val = val + A[aOffset] * X[xOffset]; - } - - auto yOffset = i * incy; - - if(betaPresent) - Y[yOffset] = alphaZ * val + betaZ * Y[yOffset]; - else - Y[yOffset] = alphaZ * val; - } + if (betaPresent) + Y[yOffset] = alphaZ * val + betaZ * Y[yOffset]; + else + Y[yOffset] = alphaZ * val; + } } //////////////////////////////////////////////////////////////////////// template -__host__ static void usualGemv(const int blocksPerGrid, const int threadsPerBlock, cudaStream_t *stream, const void* vA, const Nd4jLong* aShapeInfo, const void* vX, const Nd4jLong* xShapeInfo, void* vY, const Nd4jLong* yShapeInfo, const int incx, const int incy, const int aMaxis, const double alpha, const double beta) { - - usualCudaGemv<<>>(vA, aShapeInfo, vX, xShapeInfo, vY, yShapeInfo, incx, incy, aMaxis, alpha, beta); +__host__ static void usualGemv(const int blocksPerGrid, + const int threadsPerBlock, cudaStream_t* stream, + const void* vA, const Nd4jLong* aShapeInfo, + const void* vX, const Nd4jLong* xShapeInfo, + void* vY, const Nd4jLong* yShapeInfo, + const int incx, const int incy, const int aMaxis, + const double alpha, const double beta) { + usualCudaGemv<<>>( + vA, aShapeInfo, vX, xShapeInfo, vY, yShapeInfo, incx, incy, aMaxis, alpha, + beta); } - ////////////////////////////////////////////////////////////////////////////// template -static __global__ void usualCudaDot(const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ) { - - T1* X = reinterpret_cast(const_cast(vX)); - T2* Y = reinterpret_cast(const_cast(vY)); - T3* Z = reinterpret_cast(vZ); +static __global__ void usualCudaDot(const Nd4jLong length, const double alpha, + const void* vX, const Nd4jLong incx, + const void* vY, const Nd4jLong incy, + const double beta, void* vZ) { + T1* X = reinterpret_cast(const_cast(vX)); + T2* Y = reinterpret_cast(const_cast(vY)); + T3* Z = reinterpret_cast(vZ); - extern __shared__ unsigned char shmem[]; - auto pairwiseMul = reinterpret_cast(shmem); + extern __shared__ unsigned char shmem[]; + auto pairwiseMul = reinterpret_cast(shmem); - const int tid = blockIdx.x * blockDim.x + threadIdx.x; - if(tid < length) - pairwiseMul[tid] = X[tid * incx] * Y[tid * incy]; + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < length) pairwiseMul[tid] = X[tid * incx] * Y[tid * incy]; - __syncthreads(); + __syncthreads(); - if(tid == 0) { - T3 sum = 0; - for(Nd4jLong i = 0; i < length; ++i) - sum = sum + pairwiseMul[i]; + if (tid == 0) { + T3 sum = 0; + for (Nd4jLong i = 0; i < length; ++i) sum = sum + pairwiseMul[i]; - if(beta) - *Z = (T3)alpha * sum + (T3)beta * *Z; - else - *Z = (T3)alpha * sum; - } + if (beta) + *Z = (T3)alpha * sum + (T3)beta * *Z; + else + *Z = (T3)alpha * sum; + } } //////////////////////////////////////////////////////////////////////// template -__host__ static void usualDot(const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ) { - - usualCudaDot<<>>(length, alpha, vX, incx, vY, incy, beta, vZ); +__host__ static void usualDot(const dim3& blocksPerGrid, + const dim3& threadsPerBlock, cudaStream_t* stream, + const Nd4jLong length, const double alpha, + const void* vX, const Nd4jLong incx, + const void* vY, const Nd4jLong incy, + const double beta, void* vZ) { + usualCudaDot + <<>>( + length, alpha, vX, incx, vY, incy, beta, vZ); } ////////////////////////////////////////////////////////////////////////////// // MXK x KxN = MxN -NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, double alpha, double beta, const char outOrder) { - - if(A->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of A array is not equal 2 !"); - if(B->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of B array is not equal 2 !"); - - const auto M = A->sizeAt(0); - const auto K = A->sizeAt(1); - const auto N = B->sizeAt(1); - - if(C != nullptr && C->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of C array is not equal 2 !"); - if(B->sizeAt(0) != K) - throw std::runtime_error("MmulHelper::mmulMxM cuda: B array has wrong number of rows !"); - if(C != nullptr && C->sizeAt(0) != M) - throw std::runtime_error("MmulHelper::mmulMxM cuda: C array has wrong number of rows !"); - if(C != nullptr && C->sizeAt(1) != N) - throw std::runtime_error("MmulHelper::mmulMxM cuda: C array has wrong number of columns !"); - - if(C == nullptr) - C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); - - if (C->isEmpty()) - return C; - - const int major = Environment::getInstance()->capabilities()[AffinityManager::currentDeviceId()].first(); - - const auto aType = A->dataType(); - const auto bType = B->dataType(); - const auto cType = C->dataType(); - - const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); - - const bool typeDouble = ABC && aType == DataType::DOUBLE; - const bool typeFloat = ABC && aType == DataType::FLOAT32; - const bool typeHalf = ABC && aType == DataType::HALF && major >= 6; - const bool typeIntFloat = AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6; - const bool typeHalfFloat = AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6; - - std::lock_guard lock(*LaunchContext::deviceMutex()); +NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, + double alpha, double beta, const char outOrder) { + if (A->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxM cuda: rank of A array is not equal 2 !"); + if (B->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxM cuda: rank of B array is not equal 2 !"); + + const auto M = A->sizeAt(0); + const auto K = A->sizeAt(1); + const auto N = B->sizeAt(1); + + if (C != nullptr && C->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxM cuda: rank of C array is not equal 2 !"); + if (B->sizeAt(0) != K) + throw std::runtime_error( + "MmulHelper::mmulMxM cuda: B array has wrong number of rows !"); + if (C != nullptr && C->sizeAt(0) != M) + throw std::runtime_error( + "MmulHelper::mmulMxM cuda: C array has wrong number of rows !"); + if (C != nullptr && C->sizeAt(1) != N) + throw std::runtime_error( + "MmulHelper::mmulMxM cuda: C array has wrong number of columns !"); + + if (C == nullptr) + C = new NDArray( + outOrder, {M, N}, + DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), + A->getContext()); + + if (C->isEmpty()) return C; + + const int major = Environment::getInstance() + ->capabilities()[AffinityManager::currentDeviceId()] + .first(); + + const auto aType = A->dataType(); + const auto bType = B->dataType(); + const auto cType = C->dataType(); + + const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); + + const bool typeDouble = ABC && aType == DataType::DOUBLE; + const bool typeFloat = ABC && aType == DataType::FLOAT32; + const bool typeHalf = ABC && aType == DataType::HALF && major >= 6; + const bool typeIntFloat = + AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6; + const bool typeHalfFloat = + AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6; + + std::lock_guard lock(*LaunchContext::deviceMutex()); + + auto handle = + reinterpret_cast(A->getContext()->getCublasHandle()); + auto stream = A->getContext()->getCudaStream(); + + auto status = cublasSetStream_v2(*handle, *stream); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + + if (!typeDouble && !typeFloat && !typeHalf && !typeIntFloat && + !typeHalfFloat) { + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (C->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + threadsPerBlock * sizeof(int) * 6 + 128; // 6 = aRank + bRank + cRank - auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); - auto stream = A->getContext()->getCudaStream(); - - auto status = cublasSetStream_v2(*handle, *stream); - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + NDArray::prepareSpecialUse({C}, {A, B}); + // BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, + // threadsPerBlock, sharedMem, stream, A->specialBuffer(), + // A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), + // C->specialBuffer(), C->specialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, + // beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + aType, usualGemm, + (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->specialBuffer(), + A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), + C->specialBuffer(), C->specialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, + beta), + NUMERIC_TYPES) + NDArray::registerSpecialUse({C}, {A, B}); - if(!typeDouble && !typeFloat && !typeHalf && !typeIntFloat && !typeHalfFloat) { + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", + cudaResult); + } else { + std::vector toDelete; - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (C->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * 6 + 128; // 6 = aRank + bRank + cRank + NDArray *pA(const_cast(A)), *pB(const_cast(B)), + *pC(const_cast(C)); - NDArray::prepareSpecialUse({C}, {A, B}); - // BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, stream, A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES) - NDArray::registerSpecialUse({C}, {A, B}); + bool aMcont = M == 1 || A->strideAt(0) == 1; + bool aKcont = K == 1 || A->strideAt(1) == 1; + bool bKcont = K == 1 || B->strideAt(0) == 1; + bool bNcont = N == 1 || B->strideAt(1) == 1; + bool cMcont = M == 1 || C->strideAt(0) == 1; + bool cNcont = N == 1 || C->strideAt(1) == 1; - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) - throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult); + if (!aMcont && !aKcont) { + pA = new NDArray(A->dup('f')); + toDelete.push_back(pA); + aMcont = true; + } + if (!bKcont && !bNcont) { + pB = new NDArray(B->dup('f')); + toDelete.push_back(pB); + bKcont = true; + } + if (!cMcont) { + pC = new NDArray(C->dup('f')); + toDelete.push_back(pC); + cMcont = true; } - else { - - std::vector toDelete; - - NDArray *pA(const_cast(A)), *pB(const_cast(B)), *pC(const_cast(C)); - - bool aMcont = M == 1 || A->strideAt(0) == 1; - bool aKcont = K == 1 || A->strideAt(1) == 1; - bool bKcont = K == 1 || B->strideAt(0) == 1; - bool bNcont = N == 1 || B->strideAt(1) == 1; - bool cMcont = M == 1 || C->strideAt(0) == 1; - bool cNcont = N == 1 || C->strideAt(1) == 1; - if(!aMcont && !aKcont) { - pA = new NDArray(A->dup('f')); - toDelete.push_back(pA); - aMcont = true; - } - if(!bKcont && !bNcont) { - pB = new NDArray(B->dup('f')); - toDelete.push_back(pB); - bKcont = true; - } - if(!cMcont) { - pC = new NDArray(C->dup('f')); - toDelete.push_back(pC); - cMcont = true; - } + const bool transA = !aMcont; + const bool transB = !bKcont; - const bool transA = !aMcont; - const bool transB = !bKcont; + const int lda = + (aMcont && aKcont) ? M : transA ? pA->strideAt(0) : pA->strideAt(1); + const int ldb = + (bKcont && bNcont) ? K : transB ? pB->strideAt(0) : pB->strideAt(1); + const int ldc = (cMcont && cNcont) ? M : pC->strideAt(1); - const int lda = (aMcont && aKcont) ? M : transA ? pA->strideAt(0) : pA->strideAt(1); - const int ldb = (bKcont && bNcont) ? K : transB ? pB->strideAt(0) : pB->strideAt(1); - const int ldc = (cMcont && cNcont) ? M : pC->strideAt(1); + const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t transBblas = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - const cublasOperation_t transBblas = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - - NDArray::prepareSpecialUse({pC}, {pA, pB}); + NDArray::prepareSpecialUse({pC}, {pA, pB}); - // choose appropriate cuda gemm api depending on data types - if(typeDouble) { - status = cublasDgemm(*handle, transAblas, transBblas, M, N, K, &alpha, (double*)pA->specialBuffer(), lda, (double*)pB->specialBuffer(), ldb, &beta, (double*)pC->specialBuffer(), ldc); - } - else if(typeFloat) { - float alphaF(alpha), betaF(beta); - status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->specialBuffer(), lda, (float*)pB->specialBuffer(), ldb, &betaF, (float*)pC->specialBuffer(), ldc); - } - else if(typeHalf) { - float16 alphaH(alpha), betaH(beta); - status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->specialBuffer(), lda, (__half*)pB->specialBuffer(), ldb, &betaH.data, (__half*)pC->specialBuffer(), ldc); - } - else if(typeIntFloat) { - float alphaF(alpha), betaF(beta); - status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->specialBuffer(), CUDA_R_8I, lda, pB->specialBuffer(), CUDA_R_8I, ldb, &betaF, pC->specialBuffer(), CUDA_R_32F, ldc); - } - else if(typeHalfFloat) { - float alphaF(alpha), betaF(beta); - status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->specialBuffer(), CUDA_R_16F, lda, pB->specialBuffer(), CUDA_R_16F, ldb, &betaF, pC->specialBuffer(), CUDA_R_32F, ldc); - } + // choose appropriate cuda gemm api depending on data types + if (typeDouble) { + status = cublasDgemm(*handle, transAblas, transBblas, M, N, K, &alpha, + (double*)pA->specialBuffer(), lda, + (double*)pB->specialBuffer(), ldb, &beta, + (double*)pC->specialBuffer(), ldc); + } else if (typeFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, + (float*)pA->specialBuffer(), lda, + (float*)pB->specialBuffer(), ldb, &betaF, + (float*)pC->specialBuffer(), ldc); + } else if (typeHalf) { + float16 alphaH(alpha), betaH(beta); + status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, + &alphaH.data, (__half*)pA->specialBuffer(), lda, + (__half*)pB->specialBuffer(), ldb, &betaH.data, + (__half*)pC->specialBuffer(), ldc); + } else if (typeIntFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, + pA->specialBuffer(), CUDA_R_8I, lda, + pB->specialBuffer(), CUDA_R_8I, ldb, &betaF, + pC->specialBuffer(), CUDA_R_32F, ldc); + } else if (typeHalfFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, + pA->specialBuffer(), CUDA_R_16F, lda, + pB->specialBuffer(), CUDA_R_16F, ldb, &betaF, + pC->specialBuffer(), CUDA_R_32F, ldc); + } - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); - NDArray::registerSpecialUse({pC}, {pA, pB}); + NDArray::registerSpecialUse({pC}, {pA, pB}); - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) - throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult); + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", + cudaResult); - if(C != pC) - C->assign(pC); + if (C != pC) C->assign(pC); - for(int i = toDelete.size() - 1; i >= 0; --i) - delete toDelete[i]; - } + for (int i = toDelete.size() - 1; i >= 0; --i) delete toDelete[i]; + } - return C; + return C; } //////////////////////////////////////////////////////////////////////////// // MXN x N = M -NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y, const double alpha, const double beta, const char outOrder) { +NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y, + const double alpha, const double beta, + const char outOrder) { + int xLenDim, yLenDim(0); + + if (A->rankOf() != 2) + throw std::runtime_error( + "MmulHelper::mmulMxV cuda: rank of A array is not equal 2 !"); + if (!shape::isCommonVector(X->shapeInfo(), xLenDim)) + throw std::runtime_error( + "MmulHelper::mmulMxV cuda: X array must be vector !"); + + const auto M = A->sizeAt(0); + const auto N = A->sizeAt(1); + + if (Y != nullptr && !shape::isCommonVector(Y->shapeInfo(), yLenDim)) + throw std::runtime_error( + "MmulHelper::mmulMxV cuda: Y array must be vector !"); + if (X->lengthOf() != N) + throw std::runtime_error( + "MmulHelper::mmulMxV cuda: X vector has wrong length !"); + if (Y != nullptr && Y->lengthOf() != M) + throw std::runtime_error( + "MmulHelper::mmulMxV cuda: Y array has wrong length !"); + + if (Y == nullptr) + Y = new NDArray( + outOrder, {M}, + DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), + A->getContext()); + + if (Y->isEmpty()) return Y; + + const int incx = X->strideAt(xLenDim); + const int incy = Y->strideAt(yLenDim); + + const auto aType = A->dataType(); + const auto xType = X->dataType(); + const auto yType = Y->dataType(); + + const bool AX(aType == xType), AY(aType == yType), AXY(AX && AY); + + const bool typeDouble = AXY && aType == DataType::DOUBLE; + const bool typeFloat = AXY && aType == DataType::FLOAT32; + + std::lock_guard lock(*LaunchContext::deviceMutex()); + + auto handle = + reinterpret_cast(A->getContext()->getCublasHandle()); + auto stream = A->getContext()->getCudaStream(); + + auto status = cublasSetStream_v2(*handle, *stream); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); + + if (!typeDouble && !typeFloat) { + const int threadsPerBlock = MAX_NUM_THREADS; + const int blocksPerGrid = (M + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({Y}, {A, X}); + // BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, + // threadsPerBlock, stream, A->specialBuffer(), A->specialShapeInfo(), + // X->specialBuffer(), X->specialShapeInfo(), Y->specialBuffer(), + // Y->specialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, + // NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, usualGemv, + (blocksPerGrid, threadsPerBlock, stream, A->specialBuffer(), + A->specialShapeInfo(), X->specialBuffer(), X->specialShapeInfo(), + Y->specialBuffer(), Y->specialShapeInfo(), incx, incy, 0, alpha, beta), + NUMERIC_TYPES) + NDArray::registerSpecialUse({Y}, {A, X}); - int xLenDim, yLenDim(0); - - if(A->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxV cuda: rank of A array is not equal 2 !"); - if(!shape::isCommonVector(X->shapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::mmulMxV cuda: X array must be vector !"); - - const auto M = A->sizeAt(0); - const auto N = A->sizeAt(1); - - if(Y != nullptr && !shape::isCommonVector(Y->shapeInfo(), yLenDim)) - throw std::runtime_error("MmulHelper::mmulMxV cuda: Y array must be vector !"); - if(X->lengthOf() != N) - throw std::runtime_error("MmulHelper::mmulMxV cuda: X vector has wrong length !"); - if(Y != nullptr && Y->lengthOf() != M) - throw std::runtime_error("MmulHelper::mmulMxV cuda: Y array has wrong length !"); - - if(Y == nullptr) - Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); - - if (Y->isEmpty()) - return Y; - - const int incx = X->strideAt(xLenDim); - const int incy = Y->strideAt(yLenDim); - - const auto aType = A->dataType(); - const auto xType = X->dataType(); - const auto yType = Y->dataType(); - - const bool AX(aType == xType), AY(aType == yType), AXY(AX && AY); - - const bool typeDouble = AXY && aType == DataType::DOUBLE; - const bool typeFloat = AXY && aType == DataType::FLOAT32; - - std::lock_guard lock(*LaunchContext::deviceMutex()); - - auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); - auto stream = A->getContext()->getCudaStream(); - - auto status = cublasSetStream_v2(*handle, *stream); - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); - - if(!typeDouble && !typeFloat) { + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", + cudaResult); - const int threadsPerBlock = MAX_NUM_THREADS; - const int blocksPerGrid = (M + threadsPerBlock - 1) / threadsPerBlock; + } else { + NDArray* pA(const_cast(A)); - NDArray::prepareSpecialUse({Y}, {A, X}); - // BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->specialBuffer(), A->specialShapeInfo(), X->specialBuffer(), X->specialShapeInfo(), Y->specialBuffer(), Y->specialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, A->specialBuffer(), A->specialShapeInfo(), X->specialBuffer(), X->specialShapeInfo(), Y->specialBuffer(), Y->specialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES) - NDArray::registerSpecialUse({Y}, {A, X}); - - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) - throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult); + bool aMcont = M == 1 || A->strideAt(0) == 1; + bool aNcont = N == 1 || A->strideAt(1) == 1; + if (!aMcont && !aNcont) { + pA = new NDArray(A->dup('f')); + aMcont = true; } - else { - NDArray *pA(const_cast(A)); + const bool transA = !aMcont; - bool aMcont = M == 1 || A->strideAt(0) == 1; - bool aNcont = N == 1 || A->strideAt(1) == 1; + const int lda = + (aMcont && aNcont) ? M : transA ? pA->strideAt(0) : pA->strideAt(1); - if(!aMcont && !aNcont) { - pA = new NDArray(A->dup('f')); - aMcont = true; - } + const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - const bool transA = !aMcont; - - const int lda = (aMcont && aNcont) ? M : transA ? pA->strideAt(0) : pA->strideAt(1); - - const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - - NDArray::prepareSpecialUse({Y}, {pA, X}); + NDArray::prepareSpecialUse({Y}, {pA, X}); - // choose appropriate cuda gemm api depending on data types - if(typeDouble) { - status = cublasDgemv(*handle, transAblas, transA ? N : M, transA ? M : N, &alpha, (double*)pA->specialBuffer(), lda, (double*)X->specialBuffer(), incx, &beta, (double*)Y->specialBuffer(), incy); - } - else if(typeFloat) { - float alphaF(alpha), betaF(beta); - status = cublasSgemv(*handle, transAblas, transA ? N : M, transA ? M : N, &alphaF, (float*)pA->specialBuffer(), lda, (float*)X->specialBuffer(), incx, &betaF, (float*)Y->specialBuffer(), incy); - } + // choose appropriate cuda gemm api depending on data types + if (typeDouble) { + status = cublasDgemv(*handle, transAblas, transA ? N : M, transA ? M : N, + &alpha, (double*)pA->specialBuffer(), lda, + (double*)X->specialBuffer(), incx, &beta, + (double*)Y->specialBuffer(), incy); + } else if (typeFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemv(*handle, transAblas, transA ? N : M, transA ? M : N, + &alphaF, (float*)pA->specialBuffer(), lda, + (float*)X->specialBuffer(), incx, &betaF, + (float*)Y->specialBuffer(), incy); + } - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) - throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult); + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", + cudaResult); - NDArray::registerSpecialUse({Y}, {pA, X}); + NDArray::registerSpecialUse({Y}, {pA, X}); - if(pA != A) - delete pA; - } + if (pA != A) delete pA; + } - return Y; + return Y; } //////////////////////////////////////////////////////////////////////////// // (X * Y) = Z[0] -NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, const double alpha, const double beta) { - - int xLenDim(0), yLenDim(0); - - if(!shape::isCommonVector(X->shapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !"); - if(!shape::isCommonVector(Y->shapeInfo(), yLenDim)) - throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !"); - if(Z != nullptr && !Z->isScalar()) - throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !"); - - const auto length = X->lengthOf(); - - if(Y->lengthOf() != length) - throw std::runtime_error("MmulHelper::dot cuda: lengths of input vectors are different !"); - - if(Z == nullptr) - Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext()); - - const Nd4jLong incx = X->strideAt(xLenDim); - const Nd4jLong incy = Y->strideAt(yLenDim); - - const auto xType = X->dataType(); - const auto yType = Y->dataType(); - const auto zType = Z->dataType(); - - if(!X->isActualOnDeviceSide()) X->syncToDevice(); - if(!Y->isActualOnDeviceSide()) Y->syncToDevice(); - if(!Z->isActualOnDeviceSide()) Z->syncToDevice(); - - cudaStream_t* stream = X->getContext()->getCudaStream(); - - dim3 threadsPerBlock(512); - dim3 blocksPerGrid(1); - if (length > 512) - threadsPerBlock.x = math::nd4j_ceil(static_cast(length) / 512); - - NDArray::prepareSpecialUse({Z}, {X, Y}); - - //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->specialBuffer(), incx, Y->specialBuffer(), incy, beta, Z->specialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->specialBuffer(), incx, Y->specialBuffer(), incy, beta, Z->specialBuffer()), NUMERIC_TYPES) - - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult); - - NDArray::registerSpecialUse({Z}, {X, Y}); - - return Z; +NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, + const double alpha, const double beta) { + int xLenDim(0), yLenDim(0); + + if (!shape::isCommonVector(X->shapeInfo(), xLenDim)) + throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !"); + if (!shape::isCommonVector(Y->shapeInfo(), yLenDim)) + throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !"); + if (Z != nullptr && !Z->isScalar()) + throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !"); + + const auto length = X->lengthOf(); + + if (Y->lengthOf() != length) + throw std::runtime_error( + "MmulHelper::dot cuda: lengths of input vectors are different !"); + + if (Z == nullptr) + Z = new NDArray( + DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), + X->getContext()); + + const Nd4jLong incx = X->strideAt(xLenDim); + const Nd4jLong incy = Y->strideAt(yLenDim); + + const auto xType = X->dataType(); + const auto yType = Y->dataType(); + const auto zType = Z->dataType(); + + if (!X->isActualOnDeviceSide()) X->syncToDevice(); + if (!Y->isActualOnDeviceSide()) Y->syncToDevice(); + if (!Z->isActualOnDeviceSide()) Z->syncToDevice(); + + cudaStream_t* stream = X->getContext()->getCudaStream(); + + dim3 threadsPerBlock(512); + dim3 blocksPerGrid(1); + if (length > 512) + threadsPerBlock.x = + math::nd4j_ceil(static_cast(length) / 512); + + NDArray::prepareSpecialUse({Z}, {X, Y}); + + // BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, + // threadsPerBlock, stream, length, alpha, X->specialBuffer(), incx, + // Y->specialBuffer(), incy, beta, Z->specialBuffer()), NUMERIC_TYPES, + // NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, usualDot, + (blocksPerGrid, threadsPerBlock, stream, length, alpha, + X->specialBuffer(), incx, Y->specialBuffer(), incy, beta, + Z->specialBuffer()), + NUMERIC_TYPES) + + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult); + + NDArray::registerSpecialUse({Z}, {X, Y}); + + return Z; } ////////////////////////////////////////////////////////////////////////////// @@ -523,165 +618,208 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con // [M,K] x [bS,K,N] = [bS,M,N] // bS could stand for several axes template -static __global__ void batchedCudaGemm(const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, - const int* aBatchDims, const int* bBatchDims, const int* cBatchDims, - const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, - const double alpha, const double beta) { - - const T1* A = reinterpret_cast(vA); - const T2* B = reinterpret_cast(vB); - T3* C = reinterpret_cast< T3*>(vC); - - __shared__ bool betaPresent; - __shared__ int aRank, bRank, cRank, K, *coords; - __shared__ Nd4jLong cLen, totalThreads; - __shared__ T3 alphaZ, betaZ; - - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); - cLen = shape::length(cShapeInfo); +static __global__ void batchedCudaGemm( + const void* vA, const Nd4jLong* aShapeInfo, const void* vB, + const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, + const int* aBatchDims, const int* bBatchDims, const int* cBatchDims, + const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, + const int cMaxis, const int cNaxis, const double alpha, const double beta) { + const T1* A = reinterpret_cast(vA); + const T2* B = reinterpret_cast(vB); + T3* C = reinterpret_cast(vC); + + __shared__ bool betaPresent; + __shared__ int aRank, bRank, cRank, K, *coords; + __shared__ Nd4jLong cLen, totalThreads; + __shared__ T3 alphaZ, betaZ; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + cLen = shape::length(cShapeInfo); - K = shape::shapeOf(const_cast(aShapeInfo))[aKaxis]; + K = shape::shapeOf(const_cast(aShapeInfo))[aKaxis]; - totalThreads = gridDim.x * blockDim.x; - aRank = shape::rank(aShapeInfo); - bRank = shape::rank(bShapeInfo); - cRank = shape::rank(cShapeInfo); + totalThreads = gridDim.x * blockDim.x; + aRank = shape::rank(aShapeInfo); + bRank = shape::rank(bShapeInfo); + cRank = shape::rank(cShapeInfo); - betaPresent = beta; + betaPresent = beta; - alphaZ = alpha; - betaZ = beta; - } - __syncthreads(); + alphaZ = alpha; + betaZ = beta; + } + __syncthreads(); - auto aCoords = coords + threadIdx.x * (aRank + bRank + cRank); - auto bCoords = aCoords + aRank; - auto cCoords = bCoords + bRank; + auto aCoords = coords + threadIdx.x * (aRank + bRank + cRank); + auto bCoords = aCoords + aRank; + auto cCoords = bCoords + bRank; - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong i = tid; i < cLen; i += totalThreads) { + for (Nd4jLong i = tid; i < cLen; i += totalThreads) { + // evaluate C coordinates + shape::index2coords(i, cShapeInfo, cCoords); - // evaluate C coordinates - shape::index2coords(i, cShapeInfo, cCoords); + // calculate index of current batch + Nd4jLong batchInd; + if (cBatchDims != nullptr) + batchInd = + shape::coords2index(cShapeInfo, cCoords, cRank - 2, cBatchDims); - // calculate index of current batch - Nd4jLong batchInd; - if(cBatchDims != nullptr) - batchInd = shape::coords2index(cShapeInfo, cCoords, cRank - 2, cBatchDims); + // evaluate A coordinates + if (aBatchDims != nullptr) + shape::index2coords(batchInd, aShapeInfo, aCoords, aRank - 2, aBatchDims); + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; - // evaluate A coordinates - if(aBatchDims != nullptr) - shape::index2coords(batchInd, aShapeInfo, aCoords, aRank - 2, aBatchDims); - aCoords[aMaxis] = cCoords[cMaxis]; - aCoords[aKaxis] = 0; + // evaluate B coordinates + if (bBatchDims != nullptr) + shape::index2coords(batchInd, bShapeInfo, bCoords, bRank - 2, bBatchDims); + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; - // evaluate B coordinates - if(bBatchDims != nullptr) - shape::index2coords(batchInd, bShapeInfo, bCoords, bRank - 2, bBatchDims); - bCoords[bKaxis] = 0; - bCoords[bNaxis] = cCoords[cNaxis]; + auto aOffset = shape::getOffset(aShapeInfo, aCoords); + auto bOffset = shape::getOffset(bShapeInfo, bCoords); - auto aOffset = shape::getOffset(aShapeInfo, aCoords); - auto bOffset = shape::getOffset(bShapeInfo, bCoords); + T3 val = A[aOffset] * B[bOffset]; // first iteration - T3 val = A[aOffset] * B[bOffset]; // first iteration - - for (uint j = 1; j < K; ++j) { // rest iterations - aOffset += shape::stride(aShapeInfo)[aKaxis]; - bOffset += shape::stride(bShapeInfo)[bKaxis]; - val = val + A[aOffset] * B[bOffset]; - } + for (uint j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; + } - auto cOffset = shape::getOffset(cShapeInfo, cCoords); + auto cOffset = shape::getOffset(cShapeInfo, cCoords); - if(betaPresent) - C[cOffset] = alphaZ * val + betaZ * C[cOffset]; - else - C[cOffset] = alphaZ * val; - } + if (betaPresent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; + else + C[cOffset] = alphaZ * val; + } } //////////////////////////////////////////////////////////////////////// template -__host__ static void batchedGemm(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, cudaStream_t *stream, const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, const int* aBatchDims, const int* bBatchDims, const int* cBatchDims, const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, const double alpha, const double beta) { - - batchedCudaGemm<<>>(vA, aShapeInfo, vB, bShapeInfo, vC, cShapeInfo, aBatchDims, bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta); +__host__ static void batchedGemm( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + cudaStream_t* stream, const void* vA, const Nd4jLong* aShapeInfo, + const void* vB, const Nd4jLong* bShapeInfo, void* vC, + const Nd4jLong* cShapeInfo, const int* aBatchDims, const int* bBatchDims, + const int* cBatchDims, const int aMaxis, const int aKaxis, const int bKaxis, + const int bNaxis, const int cMaxis, const int cNaxis, const double alpha, + const double beta) { + batchedCudaGemm + <<>>( + vA, aShapeInfo, vB, bShapeInfo, vC, cShapeInfo, aBatchDims, + bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, + cNaxis, alpha, beta); } /////////////////////////////////////////////////////////////////// -NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { - - const int aRank = A->rankOf(); - const int bRank = B->rankOf(); - - // input ranks validation - if(aRank > bRank && bRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); - else if(bRank > aRank && aRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); - else if (aRank == bRank ) { - for(int i = 0; i < aRank - 2; ++i) - if(A->sizeAt(i) != B->sizeAt(i)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); - } - - if(A->sizeAt(-1) != B->sizeAt(-2)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); - - // validation of C array - std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); - cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); - cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); - - if(C != nullptr ) { - if(!C->isSameShape(cExpectedShape)) - throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); - } - else - C = new NDArray(outOrder, cExpectedShape, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); - - if (C->isEmpty()) - return C; - - const int cRank = C->rankOf(); - - const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1); - - const int threadsPerBlock = MAX_NUM_THREADS / 8; - const int blocksPerGrid = (C->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * (aRank + bRank + cRank) + 128; - - PointersManager manager(A->getContext(), "MmulHelper::mmulNxN"); - - const int *aBatchDims(nullptr), *bBatchDims(nullptr), *cBatchDims(nullptr); - - if(aRank > 2) - aBatchDims = reinterpret_cast(manager.replicatePointer(ShapeUtils::evalDimsToExclude(aRank, {aMaxis, aKaxis}).data(), (aRank - 2) * sizeof(int))); - if(bRank > 2) - bBatchDims = reinterpret_cast(manager.replicatePointer(ShapeUtils::evalDimsToExclude(bRank, {bKaxis, bNaxis}).data(), (bRank - 2) * sizeof(int))); - if(cRank > 2) - cBatchDims = reinterpret_cast(manager.replicatePointer(ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}).data(), (cRank - 2) * sizeof(int))); - - NDArray::prepareSpecialUse({C}, {A, B}); - // BUILD_TRIPLE_SELECTOR(A->dataType(), b->dataType(), C->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(A->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), C->specialBuffer(), C->specialShapeInfo(), aBatchDims, bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES) - NDArray::registerSpecialUse({C}, {A, B}); - - manager.synchronize(); - - return C; +NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, + const double alpha, const double beta, + const char outOrder) { + const int aRank = A->rankOf(); + const int bRank = B->rankOf(); + + // input ranks validation + if (aRank > bRank && bRank != 2) + throw std::runtime_error( + "MmulHelper::mmulNxN: rank of B array should be equal 2 !"); + else if (bRank > aRank && aRank != 2) + throw std::runtime_error( + "MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + else if (aRank == bRank) { + for (int i = 0; i < aRank - 2; ++i) + if (A->sizeAt(i) != B->sizeAt(i)) + throw std::runtime_error( + "MmulHelper::mmulNxN: shapes of A and B arrays are not suitable " + "for matrix multiplication !"); + } + + if (A->sizeAt(-1) != B->sizeAt(-2)) + throw std::runtime_error( + "MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for " + "matrix multiplication !"); + + // validation of C array + std::vector cExpectedShape = + aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); + cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); + cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + + if (C != nullptr) { + if (!C->isSameShape(cExpectedShape)) + throw std::runtime_error( + "MmulHelper::mmulNxN: shape of C array is not suitable for AxB " + "matrix multiplication !"); + } else + C = new NDArray( + outOrder, cExpectedShape, + DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), + A->getContext()); + + if (C->isEmpty()) return C; + + const int cRank = C->rankOf(); + + const int aMaxis(aRank - 2), aKaxis(aRank - 1), bKaxis(bRank - 2), + bNaxis(bRank - 1), cMaxis(cRank - 2), cNaxis(cRank - 1); + + const int threadsPerBlock = MAX_NUM_THREADS / 8; + const int blocksPerGrid = + (C->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + threadsPerBlock * sizeof(int) * (aRank + bRank + cRank) + 128; + + PointersManager manager(A->getContext(), "MmulHelper::mmulNxN"); + + const int *aBatchDims(nullptr), *bBatchDims(nullptr), *cBatchDims(nullptr); + + if (aRank > 2) + aBatchDims = reinterpret_cast(manager.replicatePointer( + ShapeUtils::evalDimsToExclude(aRank, {aMaxis, aKaxis}).data(), + (aRank - 2) * sizeof(int))); + if (bRank > 2) + bBatchDims = reinterpret_cast(manager.replicatePointer( + ShapeUtils::evalDimsToExclude(bRank, {bKaxis, bNaxis}).data(), + (bRank - 2) * sizeof(int))); + if (cRank > 2) + cBatchDims = reinterpret_cast(manager.replicatePointer( + ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}).data(), + (cRank - 2) * sizeof(int))); + + NDArray::prepareSpecialUse({C}, {A, B}); + // BUILD_TRIPLE_SELECTOR(A->dataType(), b->dataType(), C->dataType(), + // batchedGemm, (blocksPerGrid, threadsPerBlock, + // A->getContext()->getCudaStream(), A->specialBuffer(), + // A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), + // C->specialBuffer(), C->specialShapeInfo(), aMaxis, aKaxis, bKaxis, bNaxis, + // cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + A->dataType(), batchedGemm, + (blocksPerGrid, threadsPerBlock, sharedMem, + A->getContext()->getCudaStream(), A->specialBuffer(), + A->specialShapeInfo(), B->specialBuffer(), B->specialShapeInfo(), + C->specialBuffer(), C->specialShapeInfo(), aBatchDims, bBatchDims, + cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), + NUMERIC_TYPES) + NDArray::registerSpecialUse({C}, {A, B}); + + manager.synchronize(); + + return C; } - /* ////////////////////////////////////////////////////////////////////////////// // MXN x N = M template -static __global__ void usualCudaGemv(const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vX, const int incx, const double beta, void* vY, const int incy) { +static __global__ void usualCudaGemv(const bool transA, const int M, const int +N, const double alpha, const void* vA, const int lda, const void* vX, const int +incx, const double beta, void* vY, const int incy) { T1* A = reinterpret_cast(const_cast(vA)); T2* X = reinterpret_cast(const_cast(vX)); @@ -697,7 +835,8 @@ static __global__ void usualCudaGemv(const bool transA, const int M, const int N alphaZ = alpha; betaZ = beta; - if(transA) { strideArow = lda; strideAcol = 1; } else { strideArow = 1; strideAcol = lda; } + if(transA) { strideArow = lda; strideAcol = 1; } else { strideArow = 1; +strideAcol = lda; } } __syncthreads(); @@ -712,9 +851,13 @@ static __global__ void usualCudaGemv(const bool transA, const int M, const int N //////////////////////////////////////////////////////////////////////// template -__host__ static void usualGemv(const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vX, const int incx, const double beta, void* vY, const int incy) { +__host__ static void usualGemv(const dim3 &blocksPerGrid, const dim3 +&threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const +int N, const double alpha, const void* vA, const int lda, const void* vX, const +int incx, const double beta, void* vY, const int incy) { - usualCudaGemv<<>>(transA, M, N, alpha, vA, lda, vX, incx, beta, vY, incy); + usualCudaGemv<<>>(transA, M, N, alpha, vA, lda, vX, incx, beta, vY, incy); } */ /* @@ -722,7 +865,10 @@ __host__ static void usualGemv(const dim3 &blocksPerGrid, const dim3 &threadsPer MXK x KxN = MxN C array must be in f order template -static __global__ void usualCudaGemm(const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { +static __global__ void usualCudaGemm(const bool transA, const bool transB, const +int M, const int N, const int K, const double alpha, const void* vA, const int +lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) +{ T1* A = reinterpret_cast(const_cast(vA)); T2* B = reinterpret_cast(const_cast(vB)); @@ -739,8 +885,9 @@ static __global__ void usualCudaGemm(const bool transA, const bool transB, const alphaZ = alpha; betaZ = beta; - if(transA) { strideArow = lda; strideAcol = 1; } else { strideArow = 1; strideAcol = lda; } - if(transB) { strideBrow = ldb; strideBcol = 1; } else { strideBrow = 1; strideBcol = ldb; } + if(transA) { strideArow = lda; strideAcol = 1; } else { strideArow = 1; +strideAcol = lda; } if(transB) { strideBrow = ldb; strideBcol = 1; } else { +strideBrow = 1; strideBcol = ldb; } } __syncthreads(); @@ -748,47 +895,57 @@ static __global__ void usualCudaGemm(const bool transA, const bool transB, const T3 val = 0; if (row < M && col < N) for (int i = 0; i < K; i++) - val = val + A[row * strideArow + i * strideAcol] * B[i * strideBrow + col * strideBcol]; + val = val + A[row * strideArow + i * strideAcol] * B[i * strideBrow ++ col * strideBcol]; C[row + col * ldc] = alphaZ * val + betaZ * C[row + col * ldc]; } ////////////////////////////////////////////////////////////////////////////// template -__host__ static void usualGemm(const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { - - usualCudaGemm<<>>(transA, transB, M, N, K, alpha, vA, lda, vB, ldb, beta, vC, ldc); +__host__ static void usualGemm(const dim3 &blocksPerGrid, const dim3 +&threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, +const int M, const int N, const int K, const double alpha, const void* vA, const +int lda, const void* vB, const int ldb, const double beta, void* vC, const int +ldc) { + + usualCudaGemm<<>>(transA, transB, M, N, K, alpha, vA, lda, vB, ldb, beta, vC, ldc); } */ ////////////////////////////////////////////////////////////////////////// /* -NDArray* MmulHelper::mmulNxNold1(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { +NDArray* MmulHelper::mmulNxNold1(const NDArray* A, const NDArray* B, NDArray* C, +const double alpha, const double beta, const char outOrder) { const int aRank = A->rankOf(); const int bRank = B->rankOf(); // input ranks validation if(aRank > bRank && bRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); - else if(bRank > aRank && aRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be +equal 2 !"); else if(bRank > aRank && aRank != 2) throw +std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); else if (aRank == bRank ) { for(int i = 0; i < aRank - 2; ++i) if(A->sizeAt(i) != B->sizeAt(i)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B +arrays are not suitable for matrix multiplication !"); } if(A->sizeAt(-1) != B->sizeAt(-2)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays +are not suitable for matrix multiplication !"); // validation of C array - std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); - cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); - cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() +: B->getShapeAsVector(); cExpectedShape[cExpectedShape.size() - 2] = +A->sizeAt(-2); cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); if(C != nullptr ) { if(!C->isSameShape(cExpectedShape)) - throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is +not suitable for AxB matrix multiplication !"); } else { C = new NDArray(outOrder, cExpectedShape, B->dataType()); @@ -796,15 +953,16 @@ NDArray* MmulHelper::mmulNxNold1(const NDArray* A, const NDArray* B, NDArray* C, // multiplication - const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->shapeInfo(), dimsToExclude); + const std::vector dimsToExclude = +ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); const Nd4jLong +numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->shapeInfo(), dimsToExclude); std::vector idxRanges(2 * C->rankOf()); // #pragma omp parallel for schedule(guided) firstprivate(idxRanges) for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { - ShapeUtils::evalIdxRangesForSubArr(i, C->shapeInfo(), dimsToExclude, idxRanges.data()); - NDArray cSubArr = (*C)(idxRanges); + ShapeUtils::evalIdxRangesForSubArr(i, C->shapeInfo(), dimsToExclude, +idxRanges.data()); NDArray cSubArr = (*C)(idxRanges); if(aRank > bRank) { NDArray aSubArr = (*A)(idxRanges); @@ -831,33 +989,37 @@ NDArray* MmulHelper::mmulNxNold1(const NDArray* A, const NDArray* B, NDArray* C, // [M,K] x [bS,K,N] = [bS,M,N] // bS could stand for several axes /* -NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { +NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, +const double alpha, const double beta, const char outOrder) { const int aRank = A->rankOf(); const int bRank = B->rankOf(); // input ranks validation if(aRank > bRank && bRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); - else if(bRank > aRank && aRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be +equal 2 !"); else if(bRank > aRank && aRank != 2) throw +std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); else if (aRank == bRank ) { for(int i = 0; i < aRank - 2; ++i) if(A->sizeAt(i) != B->sizeAt(i)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B +arrays are not suitable for matrix multiplication !"); } if(A->sizeAt(-1) != B->sizeAt(-2)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays +are not suitable for matrix multiplication !"); // validation of C array - std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); - cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); - cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() +: B->getShapeAsVector(); cExpectedShape[cExpectedShape.size() - 2] = +A->sizeAt(-2); cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); if(C != nullptr ) { if(!C->isSameShape(cExpectedShape)) - throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); + throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is +not suitable for AxB matrix multiplication !"); } else C = new NDArray(outOrder, cExpectedShape, B->dataType()); @@ -868,8 +1030,8 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, const auto K = A->sizeAt(-1); const auto N = B->sizeAt(-1); - NDArray *pA(const_cast(A)), *pB(const_cast(B)), *pC(const_cast(C)); - std::vector toDelete; + NDArray *pA(const_cast(A)), *pB(const_cast(B)), +*pC(const_cast(C)); std::vector toDelete; bool aMcont = M == 1 || A->strideAt(-2) == 1; bool aKcont = K == 1 || A->strideAt(-1) == 1; @@ -892,9 +1054,9 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, if(!cMcont) { std::iota(permut.begin(), permut.end(), 0); permut[cRank - 2] = cRank - 1; - permut[cRank - 1] = cRank - 2; // swap two last dimensions [..., M,N] -> [..., N,M] - auto Cpermut = C->permute(permut); - pC = new NDArray('c', Cpermut.getShapeAsVector(), Cpermut.dataType(), A->getContext()); + permut[cRank - 1] = cRank - 2; // swap two last dimensions [..., M,N] +-> [..., N,M] auto Cpermut = C->permute(permut); pC = new NDArray('c', +Cpermut.getShapeAsVector(), Cpermut.dataType(), A->getContext()); pC->assign(Cpermut); toDelete.push_back(pC); cMcont = true; @@ -932,43 +1094,53 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, const int bS = pC->lengthOf() / (M*N); - const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(cRank, {-2, -1}); + const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(cRank, +{-2, -1}); NDArray::prepareSpecialUse({pC}, {pA, pB}); if(!badTypes) { std::vector subArrOffsets(bS); - std::vector subArrShapeInfo(shape::shapeInfoLength(2)); // all sub-arrays have rank = 2 + std::vector subArrShapeInfo(shape::shapeInfoLength(2)); // all +sub-arrays have rank = 2 std::vector aSubArrs(bS), bSubArrs(bS), cSubArrs(bS); if(aRank > 2) - shape::calcSubArrsShapeInfoAndOffsets(pA->shapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data()); - for (int i = 0; i < bS; ++i) - aSubArrs[i] = aRank == 2 ? pA->specialBuffer() : pA->specialBuffer() + subArrOffsets[i] * pA->sizeOfT(); + shape::calcSubArrsShapeInfoAndOffsets(pA->shapeInfo(), bS, +dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), +subArrOffsets.data()); for (int i = 0; i < bS; ++i) aSubArrs[i] = aRank == 2 ? +pA->specialBuffer() : pA->specialBuffer() + subArrOffsets[i] * pA->sizeOfT(); if(bRank > 2) - shape::calcSubArrsShapeInfoAndOffsets(pB->shapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data()); - for (int i = 0; i < bS; ++i) - bSubArrs[i] = bRank == 2 ? pB->specialBuffer() : pB->specialBuffer() + subArrOffsets[i] * pB->sizeOfT(); + shape::calcSubArrsShapeInfoAndOffsets(pB->shapeInfo(), bS, +dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), +subArrOffsets.data()); for (int i = 0; i < bS; ++i) bSubArrs[i] = bRank == 2 ? +pB->specialBuffer() : pB->specialBuffer() + subArrOffsets[i] * pB->sizeOfT(); - shape::calcSubArrsShapeInfoAndOffsets(pC->shapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data()); - for (int i = 0; i < bS; ++i) - cSubArrs[i] = pC->specialBuffer() + subArrOffsets[i] * pC->sizeOfT(); + shape::calcSubArrsShapeInfoAndOffsets(pC->shapeInfo(), bS, +dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), +subArrOffsets.data()); for (int i = 0; i < bS; ++i) cSubArrs[i] = +pC->specialBuffer() + subArrOffsets[i] * pC->sizeOfT(); PointersManager manager(A->getContext(), "mmulNxN"); - const void** aSubArrsCuda = reinterpret_cast(manager.replicatePointer(aSubArrs.data(), aSubArrs.size() * sizeof(void*))); - const void** bSubArrsCuda = reinterpret_cast(manager.replicatePointer(bSubArrs.data(), bSubArrs.size() * sizeof(void*))); - void** cSubArrsCuda = reinterpret_cast< void **>(manager.replicatePointer(cSubArrs.data(), cSubArrs.size() * sizeof(void*))); + const void** aSubArrsCuda = reinterpret_cast(manager.replicatePointer(aSubArrs.data(), aSubArrs.size() * +sizeof(void*))); const void** bSubArrsCuda = reinterpret_cast(manager.replicatePointer(bSubArrs.data(), bSubArrs.size() * +sizeof(void*))); void** cSubArrsCuda = reinterpret_cast< void +**>(manager.replicatePointer(cSubArrs.data(), cSubArrs.size() * +sizeof(void*))); const bool transA = !aMcont; const bool transB = !bKcont; - const int lda = (aMcont && aKcont) ? M : transA ? pA->strideAt(-2) : pA->strideAt(-1); - const int ldb = (bKcont && bNcont) ? K : transB ? pB->strideAt(-2) : pB->strideAt(-1); - const int ldc = (cMcont && cNcont) ? M : C != pC ? pC->strideAt(-2) : pC->strideAt(-1); + const int lda = (aMcont && aKcont) ? M : transA ? pA->strideAt(-2) : +pA->strideAt(-1); const int ldb = (bKcont && bNcont) ? K : transB ? +pB->strideAt(-2) : pB->strideAt(-1); const int ldc = (cMcont && cNcont) ? M : C +!= pC ? pC->strideAt(-2) : pC->strideAt(-1); const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; const cublasOperation_t transBblas = transB ? CUBLAS_OP_T : CUBLAS_OP_N; @@ -989,21 +1161,27 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, uBeta._d = beta; } - auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); - auto stream = A->getContext()->getCudaStream(); + auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); auto stream = +A->getContext()->getCudaStream(); auto status = cublasSetStream_v2(*handle, *stream); if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", status); + throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", +status); - status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, &uAlpha, aSubArrsCuda, cudaAType, lda, bSubArrsCuda, cudaBType, ldb, &uBeta, cSubArrsCuda, cudaCType, ldc, bS, cudaType, CUBLAS_GEMM_DEFAULT); + status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, +&uAlpha, aSubArrsCuda, cudaAType, lda, bSubArrsCuda, cudaBType, ldb, &uBeta, +cSubArrsCuda, cudaCType, ldc, bS, cudaType, CUBLAS_GEMM_DEFAULT); if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", status); + throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", +status); auto cudaResult = cudaStreamSynchronize(*stream); if (cudaResult != 0) - throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", cudaResult); + throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", +cudaResult); } else { @@ -1011,8 +1189,8 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, for(Nd4jLong i = 0; i < bS; ++i) { - ShapeUtils::evalIdxRangesForSubArr(i, pC->shapeInfo(), dimsToExclude, idxRanges.data()); - NDArray cSubArr = (*pC)(idxRanges); + ShapeUtils::evalIdxRangesForSubArr(i, pC->shapeInfo(), +dimsToExclude, idxRanges.data()); NDArray cSubArr = (*pC)(idxRanges); if(aRank > bRank) { NDArray aSubArr = (*pA)(idxRanges); @@ -1042,8 +1220,19 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, } */ -//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); -//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); -//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - -} \ No newline at end of file +// BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, +// const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const +// bool transB, const int M, const int N, const int K, const double alpha, const +// void* vA, const int lda, const void* vB, const int ldb, const double beta, +// void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); +// BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, +// const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const +// int M, const int N, const double alpha, const void* vA, const int lda, const +// void* vB, const int incx, const double beta, void* vC, const int incy), +// NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); BUILD_TRIPLE_TEMPLATE(template +// void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, +// cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* +// vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double +// beta, void* vZ), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/cuda_off/cublasHelper.cu b/libnd4j/include/helpers/cuda_off/cublasHelper.cu index 7ab2d7d63115..aae4b46bcade 100644 --- a/libnd4j/include/helpers/cuda_off/cublasHelper.cu +++ b/libnd4j/include/helpers/cuda_off/cublasHelper.cu @@ -18,13 +18,13 @@ // @author raver119@gmail.com // - #include #include -#include "../cublasHelper.h" #include -#include #include +#include + +#include "../cublasHelper.h" #include "config.h" #ifdef HAVE_CUDNN @@ -34,111 +34,111 @@ #endif namespace sd { - std::mutex CublasHelper::_mutex; +std::mutex CublasHelper::_mutex; - static void* handle_() { - auto _handle = new cublasHandle_t(); - auto status = cublasCreate_v2(_handle); // initialize CUBLAS context - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("cuBLAS handle creation failed !", status); +static void* handle_() { + auto _handle = new cublasHandle_t(); + auto status = cublasCreate_v2(_handle); // initialize CUBLAS context + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("cuBLAS handle creation failed !", status); - return reinterpret_cast(_handle); - } + return reinterpret_cast(_handle); +} - static void* solver_() { - auto cusolverH = new cusolverDnHandle_t(); - auto status = cusolverDnCreate(cusolverH); - if (status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("cuSolver handle creation failed !", status); +static void* solver_() { + auto cusolverH = new cusolverDnHandle_t(); + auto status = cusolverDnCreate(cusolverH); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("cuSolver handle creation failed !", status); - return cusolverH; - } + return cusolverH; +} - static void* cudnn_() { +static void* cudnn_() { #ifdef HAVE_CUDNN - auto cudnnH = new cudnnHandle_t(); - auto status = cudnnCreate(cudnnH); - if (status != CUDNN_STATUS_SUCCESS) - throw cuda_exception::build("cuDNN handle creation failed !", status); + auto cudnnH = new cudnnHandle_t(); + auto status = cudnnCreate(cudnnH); + if (status != CUDNN_STATUS_SUCCESS) + throw cuda_exception::build("cuDNN handle creation failed !", status); - return cudnnH; + return cudnnH; #endif - return nullptr; - } - - static void destroyHandle_(void* handle) { - auto ch = reinterpret_cast(handle); - auto status = cublasDestroy_v2(*ch); - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("cuBLAS handle destruction failed !", status); - - delete ch; - } - - CublasHelper::CublasHelper() { - //nd4j_printf("Initializing cuBLAS\n",""); - auto numDevices = AffinityManager::numberOfDevices(); - auto currentDevice = AffinityManager::currentDeviceId(); - _cache.resize(numDevices); - _solvers.resize(numDevices); - _cudnn.resize(numDevices); - for (int e = 0; e < numDevices; e++) { - AffinityManager::setCurrentNativeDevice(e); - - _cache[e] = handle_(); - _solvers[e] = solver_(); - _cudnn[e] = cudnn_(); - } - - // don't forget to restore back original device - AffinityManager::setCurrentNativeDevice(currentDevice); - } - - CublasHelper::~CublasHelper() { - nd4j_printf("Releasing cuBLAS\n",""); - auto numDevices = AffinityManager::numberOfDevices(); - - for (int e = 0; e < numDevices; e++) - destroyHandle_(_cache[e]); - } - - CublasHelper* CublasHelper::getInstance() { - _mutex.lock(); - if (!_INSTANCE) - _INSTANCE = new sd::CublasHelper(); - _mutex.unlock(); - - return _INSTANCE; - } - - void* CublasHelper::cudnn() { - auto deviceId = AffinityManager::currentDeviceId(); - if (deviceId < 0 || deviceId > _cudnn.size()) - throw cuda_exception::build("requested deviceId doesn't look valid", deviceId); - - return _cudnn[deviceId]; - } - - void* CublasHelper::handle() { - auto deviceId = AffinityManager::currentDeviceId(); - return handle(deviceId); - } - - void* CublasHelper::solver() { - auto deviceId = AffinityManager::currentDeviceId(); - if (deviceId < 0 || deviceId > _solvers.size()) - throw cuda_exception::build("requested deviceId doesn't look valid", deviceId); - - return _solvers[deviceId]; - } - - void* CublasHelper::handle(int deviceId) { - if (deviceId < 0 || deviceId > _cache.size()) - throw cuda_exception::build("requested deviceId doesn't look valid", deviceId); - - return _cache[deviceId]; - } - - - sd::CublasHelper* sd::CublasHelper::_INSTANCE = 0; -} \ No newline at end of file + return nullptr; +} + +static void destroyHandle_(void* handle) { + auto ch = reinterpret_cast(handle); + auto status = cublasDestroy_v2(*ch); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("cuBLAS handle destruction failed !", status); + + delete ch; +} + +CublasHelper::CublasHelper() { + // nd4j_printf("Initializing cuBLAS\n",""); + auto numDevices = AffinityManager::numberOfDevices(); + auto currentDevice = AffinityManager::currentDeviceId(); + _cache.resize(numDevices); + _solvers.resize(numDevices); + _cudnn.resize(numDevices); + for (int e = 0; e < numDevices; e++) { + AffinityManager::setCurrentNativeDevice(e); + + _cache[e] = handle_(); + _solvers[e] = solver_(); + _cudnn[e] = cudnn_(); + } + + // don't forget to restore back original device + AffinityManager::setCurrentNativeDevice(currentDevice); +} + +CublasHelper::~CublasHelper() { + nd4j_printf("Releasing cuBLAS\n", ""); + auto numDevices = AffinityManager::numberOfDevices(); + + for (int e = 0; e < numDevices; e++) destroyHandle_(_cache[e]); +} + +CublasHelper* CublasHelper::getInstance() { + _mutex.lock(); + if (!_INSTANCE) _INSTANCE = new sd::CublasHelper(); + _mutex.unlock(); + + return _INSTANCE; +} + +void* CublasHelper::cudnn() { + auto deviceId = AffinityManager::currentDeviceId(); + if (deviceId < 0 || deviceId > _cudnn.size()) + throw cuda_exception::build("requested deviceId doesn't look valid", + deviceId); + + return _cudnn[deviceId]; +} + +void* CublasHelper::handle() { + auto deviceId = AffinityManager::currentDeviceId(); + return handle(deviceId); +} + +void* CublasHelper::solver() { + auto deviceId = AffinityManager::currentDeviceId(); + if (deviceId < 0 || deviceId > _solvers.size()) + throw cuda_exception::build("requested deviceId doesn't look valid", + deviceId); + + return _solvers[deviceId]; +} + +void* CublasHelper::handle(int deviceId) { + if (deviceId < 0 || deviceId > _cache.size()) + throw cuda_exception::build("requested deviceId doesn't look valid", + deviceId); + + return _cache[deviceId]; +} + +sd::CublasHelper* sd::CublasHelper::_INSTANCE = 0; +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/data_gen.h b/libnd4j/include/helpers/data_gen.h index 1640396178a9..4145de70e0ad 100644 --- a/libnd4j/include/helpers/data_gen.h +++ b/libnd4j/include/helpers/data_gen.h @@ -29,15 +29,14 @@ * @return the linear spaced array */ template -T * linspace(int lower, int upper, int num) { - T *data = new T[num]; - for (int i = 0; i < num; i++) { - T t = (T) i / (num - 1); - data[i] = lower * (1 - t) + t * upper; +T *linspace(int lower, int upper, int num) { + T *data = new T[num]; + for (int i = 0; i < num; i++) { + T t = (T)i / (num - 1); + data[i] = lower * (1 - t) + t * upper; + } - } - - return data; + return data; } -#endif //LIBND4J_DATA_GEN_H_H +#endif // LIBND4J_DATA_GEN_H_H diff --git a/libnd4j/include/helpers/files.h b/libnd4j/include/helpers/files.h index c85b98be2870..913f7bfc0a6b 100644 --- a/libnd4j/include/helpers/files.h +++ b/libnd4j/include/helpers/files.h @@ -16,112 +16,111 @@ // // Methods to lookup files in $PATH -// adopted from https://stackoverflow.com/questions/2718915/check-if-file-exists-including-on-path +// adopted from +// https://stackoverflow.com/questions/2718915/check-if-file-exists-including-on-path // #ifndef LIBND4J_FILES_H #define LIBND4J_FILES_H -#include -#include #include - +#include +#include void *malloc_check(const char *what, size_t n); char *strsave(const char *s, const char *lim); -char ** shellpath(void); -void freeshellpath (char *shellpath[]); +char **shellpath(void); +void freeshellpath(char *shellpath[]); unsigned maxpathlen(char *path[], const char *base); bool file_exists(const char *name); void *malloc_check(const char *what, size_t n) { - void *p = malloc(n); - if (p == NULL) { - fprintf(stderr, "Cannot allocate %zu bytes to %s\n", n, what); - exit(2); - } - return p; + void *p = malloc(n); + if (p == NULL) { + fprintf(stderr, "Cannot allocate %zu bytes to %s\n", n, what); + exit(2); + } + return p; } char *strsave(const char *s, const char *lim) { - if (lim == NULL) - lim = s + strlen(s); - char *p = (char *) malloc_check("save string", lim - s + 1); - strncpy(p, s, lim-s); - p[lim-s] = '\0'; - return p; + if (lim == NULL) lim = s + strlen(s); + char *p = (char *)malloc_check("save string", lim - s + 1); + strncpy(p, s, lim - s); + p[lim - s] = '\0'; + return p; } -char ** shellpath(void) { - const char *path = getenv("PATH"); - if (!path) - path = "./"; +char **shellpath(void) { + const char *path = getenv("PATH"); + if (!path) path = "./"; - char **vector = // size is overkill - (char **) malloc_check("hold path elements", strlen(path) * sizeof(*vector)); - const char *p = path; - int next = 0; - while (p) { + char **vector = // size is overkill + (char **)malloc_check("hold path elements", + strlen(path) * sizeof(*vector)); + const char *p = path; + int next = 0; + while (p) { #ifdef _WIN32 - char *q = strchr(const_cast(p), ';'); // windows uses ; as delimiter + char *q = + strchr(const_cast(p), ';'); // windows uses ; as delimiter #else - char *q = strchr(const_cast(p), ':'); // linux and derivatives use : as delimiter + char *q = strchr(const_cast(p), + ':'); // linux and derivatives use : as delimiter #endif - vector[next++] = strsave(p, q); - p = q ? q + 1 : NULL; - } - vector[next] = NULL; - return vector; + vector[next++] = strsave(p, q); + p = q ? q + 1 : NULL; + } + vector[next] = NULL; + return vector; } -void freeshellpath (char *shellpath[]) { - for (int i = 0; shellpath[i]; i++) - free(shellpath[i]); - free(shellpath); +void freeshellpath(char *shellpath[]) { + for (int i = 0; shellpath[i]; i++) free(shellpath[i]); + free(shellpath); } unsigned maxpathlen(char *path[], const char *base) { - unsigned blen = strlen(base); - unsigned n = 0; - for (int i = 0; path[i]; i++) { - unsigned pn = strlen(path[i]); - if (pn > n) n = pn; - } - return blen+n+1; + unsigned blen = strlen(base); + unsigned n = 0; + for (int i = 0; path[i]; i++) { + unsigned pn = strlen(path[i]); + if (pn > n) n = pn; + } + return blen + n + 1; } -bool file_exists(const char *name){ - //printf("Trying file: [%s]\n", name); - FILE *file; - if (file = fopen(name, "r")) { - fclose(file); - return true; - } - return false; +bool file_exists(const char *name) { + // printf("Trying file: [%s]\n", name); + FILE *file; + if (file = fopen(name, "r")) { + fclose(file); + return true; + } + return false; } bool checkFileInPath(const char *file) { - char *path = getenv("PATH"); - char **listed = shellpath(); - size_t maxlen = maxpathlen(listed, file)+1; - char *buf = (char *) malloc_check("hold path", maxlen); - bool found = false; - for (int i = 0; listed[i]; i++) { - if (strlen(listed[i]) > 0) { + char *path = getenv("PATH"); + char **listed = shellpath(); + size_t maxlen = maxpathlen(listed, file) + 1; + char *buf = (char *)malloc_check("hold path", maxlen); + bool found = false; + for (int i = 0; listed[i]; i++) { + if (strlen(listed[i]) > 0) { #ifdef _WIN32 - snprintf(buf, maxlen, "%s\\%s", listed[i], file); + snprintf(buf, maxlen, "%s\\%s", listed[i], file); #else - snprintf(buf, maxlen, "%s/%s", listed[i], file); + snprintf(buf, maxlen, "%s/%s", listed[i], file); #endif - if (file_exists(buf)) { - found = true; - break; - } - } + if (file_exists(buf)) { + found = true; + break; + } } - free(buf); - freeshellpath(listed); + } + free(buf); + freeshellpath(listed); - return found; + return found; } - -#endif //LIBND4J_FILES_H +#endif // LIBND4J_FILES_H diff --git a/libnd4j/include/helpers/helper_generator.h b/libnd4j/include/helpers/helper_generator.h index 760cd0e24c45..e6ce41cc968b 100644 --- a/libnd4j/include/helpers/helper_generator.h +++ b/libnd4j/include/helpers/helper_generator.h @@ -21,10 +21,10 @@ #ifndef LIBND4J_HELPER_GENERATOR_H #define LIBND4J_HELPER_GENERATOR_H -#include -#include #include #include +#include +#include #ifdef _MSC_VER // include for uint64_t on MSVC @@ -34,587 +34,541 @@ #ifndef UINT64_C #if defined(__LP64__) -#define UINT64_C(c) c ## UL +#define UINT64_C(c) c##UL #else -#define UINT64_C(c) c ## ULL -#endif //LP64 -#endif // UINT64 - -#endif // MSVC/ANDROID +#define UINT64_C(c) c##ULL +#endif // LP64 +#endif // UINT64 +#endif // MSVC/ANDROID #ifdef __GNUC__ #include #endif - namespace sd { - namespace random { +namespace random { #ifdef __CUDACC__ - class SD_EXPORT CudaManaged { - private: - - protected: - void *devHolder; - - public: - void *operator new(size_t len) { - void *ptr; - cudaHostAlloc(&ptr, len, cudaHostAllocDefault); - return ptr; - } - - void operator delete(void *ptr) { - cudaFreeHost(ptr); - } - }; - - class SD_EXPORT RandomBuffer : public CudaManaged { +class SD_EXPORT CudaManaged { + private: + protected: + void *devHolder; + + public: + void *operator new(size_t len) { + void *ptr; + cudaHostAlloc(&ptr, len, cudaHostAllocDefault); + return ptr; + } + + void operator delete(void *ptr) { cudaFreeHost(ptr); } +}; + +class SD_EXPORT RandomBuffer : public CudaManaged { #else - class SD_EXPORT RandomBuffer { +class SD_EXPORT RandomBuffer { #endif - private: - void *devHolder; - Nd4jLong size; - uint64_t *buffer; - uint64_t *devBuffer; - Nd4jLong offset; - Nd4jLong seed; - Nd4jLong position; - Nd4jLong generation; - Nd4jLong currentPosition; - Nd4jLong amplifier; - unsigned int synchronizer; + private: + void *devHolder; + Nd4jLong size; + uint64_t *buffer; + uint64_t *devBuffer; + Nd4jLong offset; + Nd4jLong seed; + Nd4jLong position; + Nd4jLong generation; + Nd4jLong currentPosition; + Nd4jLong amplifier; + unsigned int synchronizer; #ifdef __CUDACC__ - curandGenerator_t gen; + curandGenerator_t gen; #endif - public: - /** - * This method allocates buffer of size * sizeof(Nd4jLong) - * - * @param size - * @return - */ + public: + /** + * This method allocates buffer of size * sizeof(Nd4jLong) + * + * @param size + * @return + */ #ifdef __CUDACC__ - __host__ - RandomBuffer(Nd4jLong seed, Nd4jLong size, uint64_t *hostBuffer, uint64_t *devBuffer) { - this->buffer = hostBuffer; - this->seed = seed; - this->size = size; - this->generation = 1; - this->currentPosition = 0; - this->offset = 0; - this->amplifier = seed; - this->synchronizer = 0; - this->devBuffer = devBuffer; - - cudaMalloc(&devHolder, sizeof(sd::random::RandomBuffer)); - } - - __host__ - Nd4jPointer getDevicePointer() { - return reinterpret_cast(devHolder); - } - - __host__ - ~RandomBuffer() { - cudaFree(devHolder); - } - - __host__ - void propagateToDevice(sd::random::RandomBuffer *buffer, cudaStream_t stream) { - cudaMemcpyAsync(devHolder, buffer, sizeof(sd::random::RandomBuffer), cudaMemcpyHostToDevice, stream); - } - - __host__ __device__ + __host__ RandomBuffer(Nd4jLong seed, Nd4jLong size, uint64_t *hostBuffer, + uint64_t *devBuffer) { + this->buffer = hostBuffer; + this->seed = seed; + this->size = size; + this->generation = 1; + this->currentPosition = 0; + this->offset = 0; + this->amplifier = seed; + this->synchronizer = 0; + this->devBuffer = devBuffer; + + cudaMalloc(&devHolder, sizeof(sd::random::RandomBuffer)); + } + + __host__ Nd4jPointer getDevicePointer() { + return reinterpret_cast(devHolder); + } + + __host__ ~RandomBuffer() { cudaFree(devHolder); } + + __host__ void propagateToDevice(sd::random::RandomBuffer *buffer, + cudaStream_t stream) { + cudaMemcpyAsync(devHolder, buffer, sizeof(sd::random::RandomBuffer), + cudaMemcpyHostToDevice, stream); + } + + __host__ __device__ #endif - RandomBuffer(Nd4jLong seed, Nd4jLong size, uint64_t *buffer) { - this->buffer = buffer; - this->seed = seed; - this->size = size; - this->generation = 1; - this->currentPosition = 0; - this->offset = 0; - this->amplifier = seed; - this->synchronizer = 0; - this->devBuffer = buffer; - } - - inline _CUDA_HD uint64_t *getBuffer() { - return this->buffer; - } - - inline _CUDA_HD uint64_t *getDeviceBuffer() { - return this->devBuffer; - } + RandomBuffer(Nd4jLong seed, Nd4jLong size, uint64_t *buffer) { + this->buffer = buffer; + this->seed = seed; + this->size = size; + this->generation = 1; + this->currentPosition = 0; + this->offset = 0; + this->amplifier = seed; + this->synchronizer = 0; + this->devBuffer = buffer; + } + + inline _CUDA_HD uint64_t *getBuffer() { return this->buffer; } + + inline _CUDA_HD uint64_t *getDeviceBuffer() { return this->devBuffer; } #ifdef __CUDACC__ - _CUDA_HD curandGenerator_t *getGeneratorPointer() { - return &gen; - } - - _CUDA_HD curandGenerator_t getGenerator() { - return gen; - } + _CUDA_HD curandGenerator_t *getGeneratorPointer() { return &gen; } + _CUDA_HD curandGenerator_t getGenerator() { return gen; } - _CUDA_H void setBuffer(uint64_t *ptr) { - this->buffer = ptr; - } + _CUDA_H void setBuffer(uint64_t *ptr) { this->buffer = ptr; } #endif - inline _CUDA_HD Nd4jLong getSize() { - return this->size; - } - - inline _CUDA_HD Nd4jLong getSeed() { - return this->seed; - } - - void _CUDA_HD setSeed(Nd4jLong seed) { - this->seed = seed; - this->amplifier = seed; - this->generation = 1; - } - - Nd4jLong _CUDA_HD getAllocatedSize() { - return this->size * sizeof(double); - } - - inline _CUDA_HD Nd4jLong getOffset() { - return this->currentPosition; - } - - void _CUDA_HD setOffset(Nd4jLong offset) { - this->currentPosition = offset; - } - - void _CUDA_HD reSeed(Nd4jLong amplifier) { - this->amplifier = amplifier; - } - - inline _CUDA_D uint64_t getElement(Nd4jLong position) { - Nd4jLong actualPosition = this->getOffset() + position; - Nd4jLong tempGen = generation; - if (actualPosition >= this->size) { - tempGen += actualPosition / this->size; - actualPosition = actualPosition % this->size; - } + inline _CUDA_HD Nd4jLong getSize() { return this->size; } + + inline _CUDA_HD Nd4jLong getSeed() { return this->seed; } + + void _CUDA_HD setSeed(Nd4jLong seed) { + this->seed = seed; + this->amplifier = seed; + this->generation = 1; + } + + Nd4jLong _CUDA_HD getAllocatedSize() { return this->size * sizeof(double); } + + inline _CUDA_HD Nd4jLong getOffset() { return this->currentPosition; } + + void _CUDA_HD setOffset(Nd4jLong offset) { this->currentPosition = offset; } + + void _CUDA_HD reSeed(Nd4jLong amplifier) { this->amplifier = amplifier; } + + inline _CUDA_D uint64_t getElement(Nd4jLong position) { + Nd4jLong actualPosition = this->getOffset() + position; + Nd4jLong tempGen = generation; + if (actualPosition >= this->size) { + tempGen += actualPosition / this->size; + actualPosition = actualPosition % this->size; + } #ifdef __CUDACC__ -// __syncthreads(); + // __syncthreads(); - auto ret = static_cast(devBuffer[actualPosition]); + auto ret = static_cast(devBuffer[actualPosition]); #else - auto ret = static_cast(buffer[actualPosition]); + auto ret = static_cast(buffer[actualPosition]); #endif - if (tempGen != generation) - ret = safeShift(ret, tempGen); + if (tempGen != generation) ret = safeShift(ret, tempGen); - if(generation > 1) - ret = safeShift(ret, generation); + if (generation > 1) ret = safeShift(ret, generation); - if (amplifier != seed) - ret = safeShift(ret, amplifier); + if (amplifier != seed) ret = safeShift(ret, amplifier); #ifdef __CUDACC__ // __syncthreads(); #endif - if (amplifier != seed || generation > 1 || tempGen != generation) - ret = next64(seedConv(static_cast(ret))); - - return ret; - } - - uint64_t _CUDA_HD next64(uint64_t shiftedSeed) { - const auto s0 = static_cast(shiftedSeed); - auto s1 = static_cast(shiftedSeed) % sd::DataTypeUtils::max() + 11; - uint64_t r0, r1; - - s1 ^= s0; - r0 = rotl(s0, 55) ^ s1 ^ (s1 << 14); // a, b - r1 = rotl(s1, 36); // c - - return r0 + r1; - } - - static _CUDA_HD inline uint64_t rotl(const uint64_t x, uint64_t k) { - return (x << k) | (x >> (64 - k)); - } - - uint64_t static _CUDA_HD inline safeShift(uint64_t x, uint64_t y) { - if (y != 0 && x > sd::DataTypeUtils::max() / y) { - return x / y + 11; - } else return (x * y) + 11; - } - - uint64_t _CUDA_HD seedConv(Nd4jLong seed) { - uint64_t x = static_cast(seed); - uint64_t z = (x += UINT64_C(0x9E3779B97F4A7C15)); - z = (z ^ (z >> 30)) * UINT64_C(0xBF58476D1CE4E5B9); - z = (z ^ (z >> 27)) * UINT64_C(0x94D049BB133111EB); - return z ^ (z >> 31); - } - - void _CUDA_HD incrementGeneration() { - this->generation++; - } - - Nd4jLong _CUDA_HD getNextIndex() { - currentPosition++; - if (currentPosition >= size) { - currentPosition = 0; - generation++; - } - Nd4jLong ret = currentPosition; - - return ret; - } - - uint64_t _CUDA_HD getNextElement() { - // TODO: proper implementation needed here - return generation == 1 ? buffer[getNextIndex()] : buffer[getNextIndex()] * generation; - } - - - /** - * This method skips X elements from buffer - * - * @param numberOfElements number of elements to skip - */ + if (amplifier != seed || generation > 1 || tempGen != generation) + ret = next64(seedConv(static_cast(ret))); + + return ret; + } + + uint64_t _CUDA_HD next64(uint64_t shiftedSeed) { + const auto s0 = static_cast(shiftedSeed); + auto s1 = + static_cast(shiftedSeed) % sd::DataTypeUtils::max() + 11; + uint64_t r0, r1; + + s1 ^= s0; + r0 = rotl(s0, 55) ^ s1 ^ (s1 << 14); // a, b + r1 = rotl(s1, 36); // c + + return r0 + r1; + } + + static _CUDA_HD inline uint64_t rotl(const uint64_t x, uint64_t k) { + return (x << k) | (x >> (64 - k)); + } + + uint64_t static _CUDA_HD inline safeShift(uint64_t x, uint64_t y) { + if (y != 0 && x > sd::DataTypeUtils::max() / y) { + return x / y + 11; + } else + return (x * y) + 11; + } + + uint64_t _CUDA_HD seedConv(Nd4jLong seed) { + uint64_t x = static_cast(seed); + uint64_t z = (x += UINT64_C(0x9E3779B97F4A7C15)); + z = (z ^ (z >> 30)) * UINT64_C(0xBF58476D1CE4E5B9); + z = (z ^ (z >> 27)) * UINT64_C(0x94D049BB133111EB); + return z ^ (z >> 31); + } + + void _CUDA_HD incrementGeneration() { this->generation++; } + + Nd4jLong _CUDA_HD getNextIndex() { + currentPosition++; + if (currentPosition >= size) { + currentPosition = 0; + generation++; + } + Nd4jLong ret = currentPosition; + + return ret; + } + + uint64_t _CUDA_HD getNextElement() { + // TODO: proper implementation needed here + return generation == 1 ? buffer[getNextIndex()] + : buffer[getNextIndex()] * generation; + } + + /** + * This method skips X elements from buffer + * + * @param numberOfElements number of elements to skip + */ #ifdef __CUDACC__ - __device__ - void rewind(Nd4jLong numberOfElements) { - if (gridDim.x > 1) { - __shared__ bool amLast; - - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&synchronizer, gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - __syncthreads(); - - if (amLast) { - if (threadIdx.x == 0) { - synchronizer = 0; - - Nd4jLong newPos = this->getOffset() + numberOfElements; - if (newPos > this->getSize()) { - generation += newPos / this->size; - newPos = newPos % this->size; - } else if (newPos == this->getSize()) { - newPos = 0; - generation++; - } - - this->setOffset(newPos); - } - } - } else { - if (threadIdx.x == 0) { - Nd4jLong newPos = this->getOffset() + numberOfElements; - if (newPos > this->getSize()) { - generation += newPos / this->size; - newPos = newPos % this->size; - } else if (newPos == this->getSize()) { - generation++; - newPos = 0; - } - - this->setOffset(newPos); - } - } - } + __device__ void rewind(Nd4jLong numberOfElements) { + if (gridDim.x > 1) { + __shared__ bool amLast; + + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&synchronizer, gridDim.x); + amLast = (ticket == gridDim.x - 1); + } + __syncthreads(); + + if (amLast) { + if (threadIdx.x == 0) { + synchronizer = 0; + + Nd4jLong newPos = this->getOffset() + numberOfElements; + if (newPos > this->getSize()) { + generation += newPos / this->size; + newPos = newPos % this->size; + } else if (newPos == this->getSize()) { + newPos = 0; + generation++; + } + + this->setOffset(newPos); + } + } + } else { + if (threadIdx.x == 0) { + Nd4jLong newPos = this->getOffset() + numberOfElements; + if (newPos > this->getSize()) { + generation += newPos / this->size; + newPos = newPos % this->size; + } else if (newPos == this->getSize()) { + generation++; + newPos = 0; + } + + this->setOffset(newPos); + } + } + } #endif - void rewindH(Nd4jLong numberOfElements) { - Nd4jLong newPos = this->getOffset() + numberOfElements; - if (newPos > this->getSize()) { - generation += newPos / this->size; - newPos = newPos % this->size; - } - else if (newPos == this->getSize()) { - generation++; - newPos = 0; - } - - this->setOffset(newPos); - } - - /** - * This method returns random int in range [0..MAX_INT] - * @return - */ - int _CUDA_D nextInt() { - auto u = nextUInt64(); - return u <= sd::DataTypeUtils::max() ? static_cast(u) : static_cast(u % sd::DataTypeUtils::max()); - }; - - uint64_t _CUDA_D nextUInt64() { - return getNextElement(); - } - - /** - * This method returns random int in range [0..to] - * @param to - * @return - */ - int _CUDA_D nextInt(int to) { - int r = nextInt(); - int m = to - 1; - if ((to & m) == 0) // i.e., bound is a power of 2 - r = ((to * (Nd4jLong) r) >> 31); - else { - for (int u = r; - u - (r = u % to) + m < 0; - u = nextInt()); - } - return r; - }; - - /** - * This method returns random int in range [from..to] - * @param from - * @param to - * @return - */ - int _CUDA_D nextInt(int from, int to) { - if (from == 0) - return nextInt(to); - - return from + nextInt(to - from); - }; - - - /** - * This method returns random T in range of [0..1] - * @return - */ - template - _CUDA_D T nextT() { - auto u = static_cast(nextUInt64()); - auto m = static_cast(sd::DataTypeUtils::max()); - return static_cast(u / m); - } - - /** - * This method returns random T in range of [0..to] - * @param to - * @return - */ - template - _CUDA_D T nextT(T to) { - if (to == static_cast(1.0f)) - return nextT(); - - return nextT(static_cast(0.0f), to); - } - - /** - * This method returns random T in range [from..to] - * @param from - * @param to - * @return - */ - template - _CUDA_D T inline nextT(T from, T to) { - return from + (nextT() * (to - from)); - } - - inline _CUDA_D uint64_t relativeUInt64(Nd4jLong index) { - return getElement(index); - } - - /** - * relative methods are made as workaround for lock-free concurrent execution - */ - inline int _CUDA_D relativeInt(Nd4jLong index) { - auto u = relativeUInt64(index); - return u <= sd::DataTypeUtils::max() ? static_cast(u) : static_cast(u % sd::DataTypeUtils::max()); - } - - /** - * This method returns random int within [0..to] - * - * @param index - * @param to - * @return - */ - inline int _CUDA_D relativeInt(Nd4jLong index, int to) { - auto rel = relativeInt(index); - return rel % to; - } - - /** - * This method returns random int within [from..to] - * - * @param index - * @param to - * @param from - * @return - */ - inline _CUDA_D int relativeInt(Nd4jLong index, int from, int to) { - if (from == 0) - return relativeInt(index, to); - - return from + relativeInt(index, to - from); - } - - /** - * This method returns random T within [0..1] - * - * @param index - * @return - */ - template - inline _CUDA_D T relativeT(Nd4jLong index) { - /** - * Basically we just get float u/m value, and convert into to - * - * FIXME: once we add support for additional datatypes this code must be tweaked - */ - auto u = static_cast(relativeUInt64(index)); - auto m = static_cast (sd::DataTypeUtils::max()); - return static_cast(u / m); - } - -/** - * This method returns random T within [0..to] - * - * @param index - * @param to - * @return - */ - - template - _CUDA_D T relativeT(Nd4jLong index, T to) { - if (to == static_cast(1.0f)) - return relativeT(index); - - return relativeT(index, static_cast(0.0f), to); - } + void rewindH(Nd4jLong numberOfElements) { + Nd4jLong newPos = this->getOffset() + numberOfElements; + if (newPos > this->getSize()) { + generation += newPos / this->size; + newPos = newPos % this->size; + } else if (newPos == this->getSize()) { + generation++; + newPos = 0; + } -/** - * This method returns random T within [from..to] - * - * @param index - * @param from - * @param to - * @return - */ - template - _CUDA_D T relativeT(Nd4jLong index, T from, T to) { - return from + (relativeT(index) * (to - from)); - } - - }; - - class SD_EXPORT IGenerator { - protected: - Nd4jLong limit; - Nd4jLong seed; - uint64_t *buffer; - sd::random::RandomBuffer *realBuffer; - - public: - - _CUDA_HD IGenerator(sd::random::RandomBuffer *buffer) { - this->limit = buffer->getSize(); - this->buffer = reinterpret_cast(buffer->getBuffer()); - this->realBuffer = buffer; - this->seed = buffer->getSeed(); - } - - - _CUDA_HD RandomBuffer *getBuffer() { - return realBuffer; - } - - _CUDA_HD void setOffset(Nd4jLong offset) { - this->realBuffer->setOffset(offset); - } - - _CUDA_HD Nd4jLong getElementAbsolute(Nd4jLong position) { - return buffer[position]; - } - - _CUDA_HD Nd4jLong getElementRelative(Nd4jLong position) { - return buffer[realBuffer->getOffset() + position]; - } - - virtual _CUDA_HD void refreshBuffer() = 0; - }; - - - - class SD_EXPORT Xoroshiro128 : public IGenerator { - protected: - uint64_t state[2]; - - static inline _CUDA_HD uint64_t rotl(const uint64_t x, int k) { - return (x << k) | (x >> (64 - k)); - } - - /** - * This method returns 64 random bits - * @return - */ - uint64_t _CUDA_HD next64() { - const uint64_t s0 = state[0]; - uint64_t s1 = state[1]; - const uint64_t result = s0 + s1; - - s1 ^= s0; - state[0] = rotl(s0, 55) ^ s1 ^ (s1 << 14); // a, b - state[1] = rotl(s1, 36); // c - - return result; - } - - uint64_t _CUDA_HD seedConv(Nd4jLong seed) { - uint64_t x = static_cast(seed); - uint64_t z = (x += UINT64_C(0x9E3779B97F4A7C15)); - z = (z ^ (z >> 30)) * UINT64_C(0xBF58476D1CE4E5B9); - z = (z ^ (z >> 27)) * UINT64_C(0x94D049BB133111EB); - return z ^ (z >> 31); - } - - void _CUDA_H jump(void) { - static const uint64_t JUMP[] = { 0xbeac0467eba5facb, 0xd86b048b86aa9922 }; - - uint64_t s0 = 0; - uint64_t s1 = 0; - for(unsigned int i = 0; i < sizeof JUMP / sizeof *JUMP; i++) - for(int b = 0; b < 64; b++) { - if (JUMP[i] & 1ULL << b) { - s0 ^= state[0]; - s1 ^= state[1]; - } - next64(); - } - - state[0] = s0; - state[1] = s1; - } - - public: - _CUDA_HD Xoroshiro128(sd::random::RandomBuffer *buffer) : IGenerator(buffer) { - // - } - - _CUDA_HD void refreshBuffer() { - state[0] = seedConv(this->seed); - state[1] = seedConv(this->seed * 119 + 3); - - int fd = 3 + 3; - - for (Nd4jLong i = 0; i < limit; i++) { - buffer[i] = next64(); - } - } - }; + this->setOffset(newPos); + } + + /** + * This method returns random int in range [0..MAX_INT] + * @return + */ + int _CUDA_D nextInt() { + auto u = nextUInt64(); + return u <= sd::DataTypeUtils::max() + ? static_cast(u) + : static_cast(u % sd::DataTypeUtils::max()); + }; + + uint64_t _CUDA_D nextUInt64() { return getNextElement(); } + + /** + * This method returns random int in range [0..to] + * @param to + * @return + */ + int _CUDA_D nextInt(int to) { + int r = nextInt(); + int m = to - 1; + if ((to & m) == 0) // i.e., bound is a power of 2 + r = ((to * (Nd4jLong)r) >> 31); + else { + for (int u = r; u - (r = u % to) + m < 0; u = nextInt()) + ; + } + return r; + }; + + /** + * This method returns random int in range [from..to] + * @param from + * @param to + * @return + */ + int _CUDA_D nextInt(int from, int to) { + if (from == 0) return nextInt(to); + + return from + nextInt(to - from); + }; + + /** + * This method returns random T in range of [0..1] + * @return + */ + template + _CUDA_D T nextT() { + auto u = static_cast(nextUInt64()); + auto m = static_cast(sd::DataTypeUtils::max()); + return static_cast(u / m); + } + + /** + * This method returns random T in range of [0..to] + * @param to + * @return + */ + template + _CUDA_D T nextT(T to) { + if (to == static_cast(1.0f)) return nextT(); + + return nextT(static_cast(0.0f), to); + } + + /** + * This method returns random T in range [from..to] + * @param from + * @param to + * @return + */ + template + _CUDA_D T inline nextT(T from, T to) { + return from + (nextT() * (to - from)); + } + + inline _CUDA_D uint64_t relativeUInt64(Nd4jLong index) { + return getElement(index); + } + + /** + * relative methods are made as workaround for lock-free concurrent execution + */ + inline int _CUDA_D relativeInt(Nd4jLong index) { + auto u = relativeUInt64(index); + return u <= sd::DataTypeUtils::max() + ? static_cast(u) + : static_cast(u % sd::DataTypeUtils::max()); + } + + /** + * This method returns random int within [0..to] + * + * @param index + * @param to + * @return + */ + inline int _CUDA_D relativeInt(Nd4jLong index, int to) { + auto rel = relativeInt(index); + return rel % to; + } + + /** + * This method returns random int within [from..to] + * + * @param index + * @param to + * @param from + * @return + */ + inline _CUDA_D int relativeInt(Nd4jLong index, int from, int to) { + if (from == 0) return relativeInt(index, to); + + return from + relativeInt(index, to - from); + } + + /** + * This method returns random T within [0..1] + * + * @param index + * @return + */ + template + inline _CUDA_D T relativeT(Nd4jLong index) { + /** + * Basically we just get float u/m value, and convert into to + * + * FIXME: once we add support for additional datatypes this code must be + * tweaked + */ + auto u = static_cast(relativeUInt64(index)); + auto m = static_cast(sd::DataTypeUtils::max()); + return static_cast(u / m); + } + + /** + * This method returns random T within [0..to] + * + * @param index + * @param to + * @return + */ + + template + _CUDA_D T relativeT(Nd4jLong index, T to) { + if (to == static_cast(1.0f)) return relativeT(index); + + return relativeT(index, static_cast(0.0f), to); + } + + /** + * This method returns random T within [from..to] + * + * @param index + * @param from + * @param to + * @return + */ + template + _CUDA_D T relativeT(Nd4jLong index, T from, T to) { + return from + (relativeT(index) * (to - from)); + } +}; + +class SD_EXPORT IGenerator { + protected: + Nd4jLong limit; + Nd4jLong seed; + uint64_t *buffer; + sd::random::RandomBuffer *realBuffer; + + public: + _CUDA_HD IGenerator(sd::random::RandomBuffer *buffer) { + this->limit = buffer->getSize(); + this->buffer = reinterpret_cast(buffer->getBuffer()); + this->realBuffer = buffer; + this->seed = buffer->getSeed(); + } + + _CUDA_HD RandomBuffer *getBuffer() { return realBuffer; } + + _CUDA_HD void setOffset(Nd4jLong offset) { + this->realBuffer->setOffset(offset); + } + + _CUDA_HD Nd4jLong getElementAbsolute(Nd4jLong position) { + return buffer[position]; + } + + _CUDA_HD Nd4jLong getElementRelative(Nd4jLong position) { + return buffer[realBuffer->getOffset() + position]; + } + + virtual _CUDA_HD void refreshBuffer() = 0; +}; + +class SD_EXPORT Xoroshiro128 : public IGenerator { + protected: + uint64_t state[2]; + + static inline _CUDA_HD uint64_t rotl(const uint64_t x, int k) { + return (x << k) | (x >> (64 - k)); + } + + /** + * This method returns 64 random bits + * @return + */ + uint64_t _CUDA_HD next64() { + const uint64_t s0 = state[0]; + uint64_t s1 = state[1]; + const uint64_t result = s0 + s1; + + s1 ^= s0; + state[0] = rotl(s0, 55) ^ s1 ^ (s1 << 14); // a, b + state[1] = rotl(s1, 36); // c + + return result; + } + + uint64_t _CUDA_HD seedConv(Nd4jLong seed) { + uint64_t x = static_cast(seed); + uint64_t z = (x += UINT64_C(0x9E3779B97F4A7C15)); + z = (z ^ (z >> 30)) * UINT64_C(0xBF58476D1CE4E5B9); + z = (z ^ (z >> 27)) * UINT64_C(0x94D049BB133111EB); + return z ^ (z >> 31); + } + + void _CUDA_H jump(void) { + static const uint64_t JUMP[] = {0xbeac0467eba5facb, 0xd86b048b86aa9922}; + + uint64_t s0 = 0; + uint64_t s1 = 0; + for (unsigned int i = 0; i < sizeof JUMP / sizeof *JUMP; i++) + for (int b = 0; b < 64; b++) { + if (JUMP[i] & 1ULL << b) { + s0 ^= state[0]; + s1 ^= state[1]; + } + next64(); + } + + state[0] = s0; + state[1] = s1; + } + + public: + _CUDA_HD Xoroshiro128(sd::random::RandomBuffer *buffer) : IGenerator(buffer) { + // + } + + _CUDA_HD void refreshBuffer() { + state[0] = seedConv(this->seed); + state[1] = seedConv(this->seed * 119 + 3); + + int fd = 3 + 3; + + for (Nd4jLong i = 0; i < limit; i++) { + buffer[i] = next64(); } -} -#endif //LIBND4J_HELPER_GENERATOR_H + } +}; +} // namespace random +} // namespace sd +#endif // LIBND4J_HELPER_GENERATOR_H diff --git a/libnd4j/include/helpers/helper_hash.h b/libnd4j/include/helpers/helper_hash.h index 10a22f2237d4..ca449867aaa1 100644 --- a/libnd4j/include/helpers/helper_hash.h +++ b/libnd4j/include/helpers/helper_hash.h @@ -15,36 +15,38 @@ ******************************************************************************/ // -// Stronger 64-bit hash function helper, as described here: http://www.javamex.com/tutorials/collections/strong_hash_code_implementation.shtml +// Stronger 64-bit hash function helper, as described here: +// http://www.javamex.com/tutorials/collections/strong_hash_code_implementation.shtml // @author raver119@gmail.com // #ifndef LIBND4J_HELPER_HASH_H #define LIBND4J_HELPER_HASH_H -#include #include #include + #include +#include namespace sd { - namespace ops { - class SD_EXPORT HashHelper { - private: - static HashHelper* _INSTANCE; - - Nd4jLong _byteTable[256]; - const Nd4jLong HSTART = 0xBB40E64DA205B064L; - const Nd4jLong HMULT = 7664345821815920749L; - - bool _isInit = false; - std::mutex _locker; - - public: - static HashHelper* getInstance(); - Nd4jLong getLongHash(const std::string& str); - }; - } -} - -#endif //LIBND4J_HELPER_HASH_H +namespace ops { +class SD_EXPORT HashHelper { + private: + static HashHelper* _INSTANCE; + + Nd4jLong _byteTable[256]; + const Nd4jLong HSTART = 0xBB40E64DA205B064L; + const Nd4jLong HMULT = 7664345821815920749L; + + bool _isInit = false; + std::mutex _locker; + + public: + static HashHelper* getInstance(); + Nd4jLong getLongHash(const std::string& str); +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_HELPER_HASH_H diff --git a/libnd4j/include/helpers/helper_ptrmap.h b/libnd4j/include/helpers/helper_ptrmap.h index 4f2ec128c9ec..bc115af776b9 100644 --- a/libnd4j/include/helpers/helper_ptrmap.h +++ b/libnd4j/include/helpers/helper_ptrmap.h @@ -29,191 +29,199 @@ namespace sd { - /** - * This class is a simple wrapper to represent batch arguments as single surface of parameters. - * So we pass batch parameters as single surface, and then we use this helper to extract arguments for each aggregates. - * - * Surface map format is simple: - * [0] we put numbers for num*Arguments - * [1] then we put indexing arguments, since their size is constant - * [2] here we put block of JVM IntArrays by value, batchLimit * maxIntArrays * maxArraySize; - * [3] then we put real arguments - * [4] then we put arguments pointers - * [5] then we put shape pointers - * - */ - template - class PointersHelper { - private: - int aggregates; - void *ptrGeneral; - - // we enforce maximal batch size limit, to simplify +/** + * This class is a simple wrapper to represent batch arguments as single surface + * of parameters. So we pass batch parameters as single surface, and then we use + * this helper to extract arguments for each aggregates. + * + * Surface map format is simple: + * [0] we put numbers for num*Arguments + * [1] then we put indexing arguments, since their size is constant + * [2] here we put block of JVM IntArrays by value, batchLimit * maxIntArrays * + * maxArraySize; [3] then we put real arguments [4] then we put arguments + * pointers [5] then we put shape pointers + * + */ +template +class PointersHelper { + private: + int aggregates; + void *ptrGeneral; + + // we enforce maximal batch size limit, to simplify #ifdef __CUDACC__ - const int batchLimit = 8192; + const int batchLimit = 8192; #else - const int batchLimit = 512; + const int batchLimit = 512; #endif - // we have 5 diff kinds of arguments: arguments, shapeArguments, intArrayArguments, indexArguments, realArguments - const int argTypes = 5; - - int maxIntArrays; - int maxArraySize; - - // right now we hardcode maximas, but we'll probably change that later - int maxIndexArguments; - int maxRealArguments; - - // since that's pointers (which is 64-bit on 64bit systems), we limit number of maximum arguments to 1/2 of maxIndex arguments - int maxArguments; - int maxShapeArguments; - - int sizeT; - int sizePtr; - public: - - /** - * We accept single memory chunk and number of jobs stored. - * - * @param ptrToParams pointer to "surface" - * @param numAggregates actual number of aggregates being passed in - * @return - */ + // we have 5 diff kinds of arguments: arguments, shapeArguments, + // intArrayArguments, indexArguments, realArguments + const int argTypes = 5; + + int maxIntArrays; + int maxArraySize; + + // right now we hardcode maximas, but we'll probably change that later + int maxIndexArguments; + int maxRealArguments; + + // since that's pointers (which is 64-bit on 64bit systems), we limit number + // of maximum arguments to 1/2 of maxIndex arguments + int maxArguments; + int maxShapeArguments; + + int sizeT; + int sizePtr; + + public: + /** + * We accept single memory chunk and number of jobs stored. + * + * @param ptrToParams pointer to "surface" + * @param numAggregates actual number of aggregates being passed in + * @return + */ #ifdef __CUDACC__ - __host__ __device__ + __host__ __device__ #endif - PointersHelper(void *ptrToParams, int numAggregates, int maxArgs, int maxShapes, int maxIntArrays, int maxIntArraySize, int maxIdx, int maxReals) { - aggregates = numAggregates; - ptrGeneral = ptrToParams; - - // ptrSize for hypothetical 32-bit compatibility - sizePtr = sizeof(ptrToParams); - - // unfortunately we have to know sizeOf(T) - sizeT = sizeof(T); - - this->maxIntArrays = maxIntArrays; - this->maxArraySize = maxIntArraySize; - this->maxIndexArguments = maxIdx; - this->maxArguments = maxArgs; - this->maxShapeArguments = maxShapes; - this->maxRealArguments = maxReals; - } - - /** - * This method returns point - * - * @param aggregateIdx - * @return - */ - - ptr_def T **getArguments(int aggregateIdx) { - T **aPtr = (T **) getRealArguments(batchLimit); - - return aPtr + (aggregateIdx * maxArguments); - } - - /** - * This method returns number of array arguments for specified aggregate - * - * @param aggregateIdx - * @return - */ - ptr_def int getNumArguments(int aggregateIdx) { - int *tPtr = (int *) ptrGeneral; - return tPtr[aggregateIdx * argTypes]; - } - - /** - * This method returns set of pointers to shape aruments for specified aggregates - * - * @param aggregateIdx - * @return - */ - ptr_def Nd4jLong **getShapeArguments(int aggregateIdx) { - Nd4jLong **sPtr = (Nd4jLong **)getArguments(batchLimit); - - return sPtr + (aggregateIdx * maxShapeArguments); - } - - /** - * This methor returns number of shape arguments for specified aggregate - * - * @param aggregateIdx - * @return - */ - ptr_def int getNumShapeArguments(int aggregateIdx) { - int *tPtr = (int *) ptrGeneral; - return tPtr[aggregateIdx * argTypes + 1]; - } - - /** - * This method returns pointer to array of int/index arguments for specified aggregate - * - * @param aggregateIdx - * @return - */ - ptr_def int *getIndexArguments(int aggregateIdx) { - // we skip first numeric num*arguments - int *ptr = ((int *) ptrGeneral) + (batchLimit * argTypes); - - // and return index for requested aggregate - return ptr + (aggregateIdx * maxIndexArguments) ; - } - - /** - * This method returns number of int/index arguments for specified aggregate - * - * @param aggregateIdx - * @return - */ - ptr_def int getNumIndexArguments(int aggregateIdx) { - int *tPtr = (int *) ptrGeneral; - return tPtr[aggregateIdx * argTypes + 2]; - } - - /** - * This method returns pointer to array of jvm IntArray arguments - */ - ptr_def int *getIntArrayArguments(int aggregateIdx, int argumentIdx) { - int *ptr = (int * )getIndexArguments(batchLimit); - - return ptr + (aggregateIdx * maxIntArrays * maxArraySize) + (argumentIdx * maxArraySize); - } - - /** - * This method returns number of jvm IntArray arguments - */ - ptr_def int getNumIntArrayArguments(int aggregateIdx) { - int *tPtr = (int *) ptrGeneral; - return tPtr[aggregateIdx * argTypes + 4]; - } - - /** - * This method returns real arguments for specific aggregate - * - * @param aggregateIdx - * @return - */ - ptr_def T *getRealArguments(int aggregateIdx) { - // we get pointer for last batchElement + 1, so that'll be pointer for 0 realArgument - T *ptr = (T * ) getIntArrayArguments(batchLimit, 0); - - return ptr + (aggregateIdx * maxRealArguments); - } - - /** - * This methor returns number of real arguments for specified aggregate - * - * @param aggregateIdx - * @return - */ - ptr_def int getNumRealArguments(int aggregateIdx) { - int *tPtr = (int *) ptrGeneral; - return tPtr[aggregateIdx * argTypes + 3]; - } - }; -} - -#endif //LIBND4J_HELPER_PTRMAP_H + PointersHelper(void *ptrToParams, int numAggregates, int maxArgs, + int maxShapes, int maxIntArrays, int maxIntArraySize, + int maxIdx, int maxReals) { + aggregates = numAggregates; + ptrGeneral = ptrToParams; + + // ptrSize for hypothetical 32-bit compatibility + sizePtr = sizeof(ptrToParams); + + // unfortunately we have to know sizeOf(T) + sizeT = sizeof(T); + + this->maxIntArrays = maxIntArrays; + this->maxArraySize = maxIntArraySize; + this->maxIndexArguments = maxIdx; + this->maxArguments = maxArgs; + this->maxShapeArguments = maxShapes; + this->maxRealArguments = maxReals; + } + + /** + * This method returns point + * + * @param aggregateIdx + * @return + */ + + ptr_def T **getArguments(int aggregateIdx) { + T **aPtr = (T **)getRealArguments(batchLimit); + + return aPtr + (aggregateIdx * maxArguments); + } + + /** + * This method returns number of array arguments for specified aggregate + * + * @param aggregateIdx + * @return + */ + ptr_def int getNumArguments(int aggregateIdx) { + int *tPtr = (int *)ptrGeneral; + return tPtr[aggregateIdx * argTypes]; + } + + /** + * This method returns set of pointers to shape aruments for specified + * aggregates + * + * @param aggregateIdx + * @return + */ + ptr_def Nd4jLong **getShapeArguments(int aggregateIdx) { + Nd4jLong **sPtr = (Nd4jLong **)getArguments(batchLimit); + + return sPtr + (aggregateIdx * maxShapeArguments); + } + + /** + * This methor returns number of shape arguments for specified aggregate + * + * @param aggregateIdx + * @return + */ + ptr_def int getNumShapeArguments(int aggregateIdx) { + int *tPtr = (int *)ptrGeneral; + return tPtr[aggregateIdx * argTypes + 1]; + } + + /** + * This method returns pointer to array of int/index arguments for specified + * aggregate + * + * @param aggregateIdx + * @return + */ + ptr_def int *getIndexArguments(int aggregateIdx) { + // we skip first numeric num*arguments + int *ptr = ((int *)ptrGeneral) + (batchLimit * argTypes); + + // and return index for requested aggregate + return ptr + (aggregateIdx * maxIndexArguments); + } + + /** + * This method returns number of int/index arguments for specified aggregate + * + * @param aggregateIdx + * @return + */ + ptr_def int getNumIndexArguments(int aggregateIdx) { + int *tPtr = (int *)ptrGeneral; + return tPtr[aggregateIdx * argTypes + 2]; + } + + /** + * This method returns pointer to array of jvm IntArray arguments + */ + ptr_def int *getIntArrayArguments(int aggregateIdx, int argumentIdx) { + int *ptr = (int *)getIndexArguments(batchLimit); + + return ptr + (aggregateIdx * maxIntArrays * maxArraySize) + + (argumentIdx * maxArraySize); + } + + /** + * This method returns number of jvm IntArray arguments + */ + ptr_def int getNumIntArrayArguments(int aggregateIdx) { + int *tPtr = (int *)ptrGeneral; + return tPtr[aggregateIdx * argTypes + 4]; + } + + /** + * This method returns real arguments for specific aggregate + * + * @param aggregateIdx + * @return + */ + ptr_def T *getRealArguments(int aggregateIdx) { + // we get pointer for last batchElement + 1, so that'll be pointer for 0 + // realArgument + T *ptr = (T *)getIntArrayArguments(batchLimit, 0); + + return ptr + (aggregateIdx * maxRealArguments); + } + + /** + * This methor returns number of real arguments for specified aggregate + * + * @param aggregateIdx + * @return + */ + ptr_def int getNumRealArguments(int aggregateIdx) { + int *tPtr = (int *)ptrGeneral; + return tPtr[aggregateIdx * argTypes + 3]; + } +}; +} // namespace sd + +#endif // LIBND4J_HELPER_PTRMAP_H diff --git a/libnd4j/include/helpers/helper_random.h b/libnd4j/include/helpers/helper_random.h index 6f2523e05fa6..829d33209f5a 100644 --- a/libnd4j/include/helpers/helper_random.h +++ b/libnd4j/include/helpers/helper_random.h @@ -33,204 +33,194 @@ #endif - namespace sd { - namespace random { - - template - class RandomHelper { - private: - sd::random::IGenerator *generator; - sd::random::RandomBuffer *buffer; - - - public: - - _CUDA_HD RandomHelper(sd::random::IGenerator *generator) { - this->generator = generator; - this->buffer = generator->getBuffer(); - } - - _CUDA_HD RandomHelper(sd::random::RandomBuffer *buffer) { - this->buffer = buffer; - } - - - /** - * This method returns random int in range [0..MAX_INT] - * @return - */ - inline _CUDA_D int nextInt() { - int r = (int) nextUInt(); - return r < 0 ? -1 * r : r; - }; - - inline _CUDA_D uint64_t nextUInt() { - return buffer->getNextElement(); - } - - /** - * This method returns random int in range [0..to] - * @param to - * @return - */ - inline _CUDA_D int nextInt(int to) { - int r = nextInt(); - int m = to - 1; - if ((to & m) == 0) // i.e., bound is a power of 2 - r = (int) ((to * (long) r) >> 31); - else { - for (int u = r; - u - (r = u % to) + m < 0; - u = nextInt()); - } - return r; - }; - - /** - * This method returns random int in range [from..to] - * @param from - * @param to - * @return - */ - inline _CUDA_D int nextInt(int from, int to) { - if (from == 0) - return nextInt(to); - - return from + nextInt(to - from); - }; - - - /** - * This method returns random T in range of [0..MAX_FLOAT] - * @return - */ - inline _CUDA_D T nextMaxT() { - T rnd = (T) buffer->getNextElement(); - return rnd < 0 ? -1 * rnd : rnd; - }; - - - /** - * This method returns random T in range of [0..1] - * @return - */ - inline _CUDA_D T nextT() { - return (T) nextUInt() / (T) sd::DataTypeUtils::max(); - } - - /** - * This method returns random T in range of [0..to] - * @param to - * @return - */ - inline _CUDA_D T nextT(T to) { - if (to == (T) 1.0f) - return nextT(); - - return nextT((T) 0.0f, to); - }; - - /** - * This method returns random T in range [from..to] - * @param from - * @param to - * @return - */ - inline _CUDA_D T nextT(T from, T to) { - return from + (nextT() * (to - from)); - } - - inline _CUDA_D uint64_t relativeUInt(Nd4jLong index) { - return buffer->getElement(index); - } - - /** - * relative methods are made as workaround for lock-free concurrent execution - */ - inline _CUDA_D int relativeInt(Nd4jLong index) { - return (int) (relativeUInt(index) % (sd::DataTypeUtils::max() + 1)); - } - - /** - * This method returns random int within [0..to] - * - * @param index - * @param to - * @return - */ - inline _CUDA_D int relativeInt(Nd4jLong index, int to) { - int rel = relativeInt(index); - return rel % to; - } - - /** - * This method returns random int within [from..to] - * - * @param index - * @param to - * @param from - * @return - */ - inline int _CUDA_D relativeInt(Nd4jLong index, int to, int from) { - if (from == 0) - return relativeInt(index, to); - - return from + relativeInt(index, to - from); - } - - /** - * This method returns random T within [0..1] - * - * @param index - * @return - */ - - inline _CUDA_D T relativeT(Nd4jLong index) { - if (sizeof(T) < 4) { - // FIXME: this is fast hack for short types, like fp16. This should be improved. - return (T)((float) relativeUInt(index) / (float) sd::DataTypeUtils::max()); - } else return (T) relativeUInt(index) / (T) sd::DataTypeUtils::max(); - } - - /** - * This method returns random T within [0..to] - * - * @param index - * @param to - * @return - */ - inline _CUDA_D T relativeT(Nd4jLong index, T to) { - if (to == (T) 1.0f) - return relativeT(index); - - return relativeT(index, (T) 0.0f, to); - } - - /** - * This method returns random T within [from..to] - * - * @param index - * @param from - * @param to - * @return - */ - inline _CUDA_D T relativeT(Nd4jLong index, T from, T to) { - return from + (relativeT(index) * (to - from)); - } - - - /** - * This method skips X elements from buffer - * - * @param numberOfElements number of elements to skip - */ - inline _CUDA_D void rewind(Nd4jLong numberOfElements) { - buffer->rewindH(numberOfElements); - } - }; +namespace random { + +template +class RandomHelper { + private: + sd::random::IGenerator *generator; + sd::random::RandomBuffer *buffer; + + public: + _CUDA_HD RandomHelper(sd::random::IGenerator *generator) { + this->generator = generator; + this->buffer = generator->getBuffer(); + } + + _CUDA_HD RandomHelper(sd::random::RandomBuffer *buffer) { + this->buffer = buffer; + } + + /** + * This method returns random int in range [0..MAX_INT] + * @return + */ + inline _CUDA_D int nextInt() { + int r = (int)nextUInt(); + return r < 0 ? -1 * r : r; + }; + + inline _CUDA_D uint64_t nextUInt() { return buffer->getNextElement(); } + + /** + * This method returns random int in range [0..to] + * @param to + * @return + */ + inline _CUDA_D int nextInt(int to) { + int r = nextInt(); + int m = to - 1; + if ((to & m) == 0) // i.e., bound is a power of 2 + r = (int)((to * (long)r) >> 31); + else { + for (int u = r; u - (r = u % to) + m < 0; u = nextInt()) + ; } -} - -#endif //LIBND4J_HELPER_RANDOM_H + return r; + }; + + /** + * This method returns random int in range [from..to] + * @param from + * @param to + * @return + */ + inline _CUDA_D int nextInt(int from, int to) { + if (from == 0) return nextInt(to); + + return from + nextInt(to - from); + }; + + /** + * This method returns random T in range of [0..MAX_FLOAT] + * @return + */ + inline _CUDA_D T nextMaxT() { + T rnd = (T)buffer->getNextElement(); + return rnd < 0 ? -1 * rnd : rnd; + }; + + /** + * This method returns random T in range of [0..1] + * @return + */ + inline _CUDA_D T nextT() { + return (T)nextUInt() / (T)sd::DataTypeUtils::max(); + } + + /** + * This method returns random T in range of [0..to] + * @param to + * @return + */ + inline _CUDA_D T nextT(T to) { + if (to == (T)1.0f) return nextT(); + + return nextT((T)0.0f, to); + }; + + /** + * This method returns random T in range [from..to] + * @param from + * @param to + * @return + */ + inline _CUDA_D T nextT(T from, T to) { + return from + (nextT() * (to - from)); + } + + inline _CUDA_D uint64_t relativeUInt(Nd4jLong index) { + return buffer->getElement(index); + } + + /** + * relative methods are made as workaround for lock-free concurrent execution + */ + inline _CUDA_D int relativeInt(Nd4jLong index) { + return (int)(relativeUInt(index) % + (sd::DataTypeUtils::max() + 1)); + } + + /** + * This method returns random int within [0..to] + * + * @param index + * @param to + * @return + */ + inline _CUDA_D int relativeInt(Nd4jLong index, int to) { + int rel = relativeInt(index); + return rel % to; + } + + /** + * This method returns random int within [from..to] + * + * @param index + * @param to + * @param from + * @return + */ + inline int _CUDA_D relativeInt(Nd4jLong index, int to, int from) { + if (from == 0) return relativeInt(index, to); + + return from + relativeInt(index, to - from); + } + + /** + * This method returns random T within [0..1] + * + * @param index + * @return + */ + + inline _CUDA_D T relativeT(Nd4jLong index) { + if (sizeof(T) < 4) { + // FIXME: this is fast hack for short types, like fp16. This should be + // improved. + return (T)((float)relativeUInt(index) / + (float)sd::DataTypeUtils::max()); + } else + return (T)relativeUInt(index) / (T)sd::DataTypeUtils::max(); + } + + /** + * This method returns random T within [0..to] + * + * @param index + * @param to + * @return + */ + inline _CUDA_D T relativeT(Nd4jLong index, T to) { + if (to == (T)1.0f) return relativeT(index); + + return relativeT(index, (T)0.0f, to); + } + + /** + * This method returns random T within [from..to] + * + * @param index + * @param from + * @param to + * @return + */ + inline _CUDA_D T relativeT(Nd4jLong index, T from, T to) { + return from + (relativeT(index) * (to - from)); + } + + /** + * This method skips X elements from buffer + * + * @param numberOfElements number of elements to skip + */ + inline _CUDA_D void rewind(Nd4jLong numberOfElements) { + buffer->rewindH(numberOfElements); + } +}; +} // namespace random +} // namespace sd + +#endif // LIBND4J_HELPER_RANDOM_H diff --git a/libnd4j/include/helpers/hhColPivQR.h b/libnd4j/include/helpers/hhColPivQR.h index 28dd42f64a0b..7cbff757270a 100644 --- a/libnd4j/include/helpers/hhColPivQR.h +++ b/libnd4j/include/helpers/hhColPivQR.h @@ -21,36 +21,31 @@ #ifndef LIBND4J_HHCOLPICQR_H #define LIBND4J_HHCOLPICQR_H -#include #include +#include namespace sd { namespace ops { namespace helpers { class HHcolPivQR { + public: + NDArray _qr; + NDArray _coeffs; + NDArray _permut; + int _diagSize; - public: - - NDArray _qr; - NDArray _coeffs; - NDArray _permut; - int _diagSize; - - HHcolPivQR() = delete; - HHcolPivQR(const NDArray& matrix); + HHcolPivQR() = delete; + HHcolPivQR(const NDArray& matrix); - template - void _evalData(); + template + void _evalData(); - void evalData(); + void evalData(); }; +} // namespace helpers +} // namespace ops +} // namespace sd - -} -} -} - - -#endif //LIBND4J_HHCOLPICQR_H +#endif // LIBND4J_HHCOLPICQR_H diff --git a/libnd4j/include/helpers/hhSequence.h b/libnd4j/include/helpers/hhSequence.h index 31855a86cb2f..56db076af006 100644 --- a/libnd4j/include/helpers/hhSequence.h +++ b/libnd4j/include/helpers/hhSequence.h @@ -27,74 +27,67 @@ namespace sd { namespace ops { namespace helpers { - class HHsequence { - - public: - - /* - * matrix containing the Householder vectors - */ - NDArray _vectors; - - /* - * vector containing the Householder coefficients - */ - NDArray _coeffs; - - /* - * shift of the Householder sequence - */ - int _shift; - - /* - * length of the Householder sequence - */ - int _diagSize; - - /* - * type of sequence, type = 'u' (acting on columns, left) or type = 'v' (acting on rows, right) - */ - char _type; - - /* - * constructor - */ - HHsequence(const NDArray& vectors, const NDArray& coeffs, const char type); - - /** - * this method mathematically multiplies input matrix on Householder sequence from the left H0*H1*...Hn * matrix - * - * matrix - input matrix to be multiplied - */ - template - void _mulLeft(NDArray& matrix); - - void mulLeft(NDArray& matrix); - - NDArray getTail(const int idx) const; - - template - void _applyTo(NDArray& dest); - - void applyTo(NDArray& dest); - - FORCEINLINE int rows() const; - + public: + /* + * matrix containing the Householder vectors + */ + NDArray _vectors; + + /* + * vector containing the Householder coefficients + */ + NDArray _coeffs; + + /* + * shift of the Householder sequence + */ + int _shift; + + /* + * length of the Householder sequence + */ + int _diagSize; + + /* + * type of sequence, type = 'u' (acting on columns, left) or type = 'v' + * (acting on rows, right) + */ + char _type; + + /* + * constructor + */ + HHsequence(const NDArray& vectors, const NDArray& coeffs, const char type); + + /** + * this method mathematically multiplies input matrix on Householder sequence + * from the left H0*H1*...Hn * matrix + * + * matrix - input matrix to be multiplied + */ + template + void _mulLeft(NDArray& matrix); + + void mulLeft(NDArray& matrix); + + NDArray getTail(const int idx) const; + + template + void _applyTo(NDArray& dest); + + void applyTo(NDArray& dest); + + FORCEINLINE int rows() const; }; - ////////////////////////////////////////////////////////////////////////// FORCEINLINE int HHsequence::rows() const { - - return _type == 'u' ? _vectors.sizeAt(0) : _vectors.sizeAt(1); -} - - - -} -} + return _type == 'u' ? _vectors.sizeAt(0) : _vectors.sizeAt(1); } +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_HHSEQUENCE_H +#endif // LIBND4J_HHSEQUENCE_H diff --git a/libnd4j/include/helpers/householder.h b/libnd4j/include/helpers/householder.h index e7176990135b..d28ce20871bf 100644 --- a/libnd4j/include/helpers/householder.h +++ b/libnd4j/include/helpers/householder.h @@ -21,7 +21,6 @@ #ifndef LIBND4J_HOUSEHOLDER_H #define LIBND4J_HOUSEHOLDER_H - #include "array/NDArray.h" namespace sd { @@ -30,92 +29,89 @@ namespace helpers { template class Householder { - - public: - - /** - * this method calculates Householder matrix P = identity_matrix - coeff * w * w^T - * P * x = [normX, 0, 0 , 0, ...] - * coeff - scalar - * w = [1, w1, w2, w3, ...] - * w = u / u0 - * u = x - |x|*e0 - * u0 = x0 - |x| - * e0 = [1, 0, 0 , 0, ...] - * - * x - input vector, remains unaffected - */ - static NDArray evalHHmatrix(const NDArray& x); - - /** - * this method evaluates data required for calculation of Householder matrix P = identity_matrix - coeff * w * w^T - * P * x = [normX, 0, 0 , 0, ...] - * coeff - scalar - * w = [1, w1, w2, w3, ...] - * w = u / u0 - * u = x - |x|*e0 - * u0 = x0 - |x| - * e0 = [1, 0, 0 , 0, ...] - * - * x - input vector, remains unaffected - * tail - the essential part of the vector w: [w1, w2, w3, ...] - * normX - this scalar is the first non-zero element in vector resulting from Householder transformation -> (P*x) - * coeff - scalar, scaling factor in Householder matrix formula - */ - static void evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff, T& normX); - - static void evalHHmatrixDataI(const NDArray& x, T& coeff, T& normX); - - /** - * this method mathematically multiplies input matrix on Householder from the left P * matrix - * - * matrix - input matrix - * tail - the essential part of the Householder vector w: [w1, w2, w3, ...] - * coeff - scalar, scaling factor in Householder matrix formula - */ - static void mulLeft(NDArray& matrix, const NDArray& tail, const T coeff); - - /** - * this method mathematically multiplies input matrix on Householder from the right matrix * P - * - * matrix - input matrix - * tail - the essential part of the Householder vector w: [w1, w2, w3, ...] - * coeff - scalar, scaling factor in Householder matrix formula - */ - static void mulRight(NDArray& matrix, const NDArray& tail, const T coeff); - - - + public: + /** + * this method calculates Householder matrix P = identity_matrix - coeff * w + * * w^T P * x = [normX, 0, 0 , 0, ...] coeff - scalar w = [1, w1, w2, w3, + * ...] w = u / u0 u = x - |x|*e0 u0 = x0 - |x| e0 = [1, 0, 0 , 0, ...] + * + * x - input vector, remains unaffected + */ + static NDArray evalHHmatrix(const NDArray& x); + + /** + * this method evaluates data required for calculation of Householder matrix + * P = identity_matrix - coeff * w * w^T P * x = [normX, 0, 0 , 0, ...] coeff + * - scalar w = [1, w1, w2, w3, ...] w = u / u0 u = x - |x|*e0 u0 = x0 - |x| + * e0 = [1, 0, 0 , 0, ...] + * + * x - input vector, remains unaffected + * tail - the essential part of the vector w: [w1, w2, w3, ...] + * normX - this scalar is the first non-zero element in vector resulting from + * Householder transformation -> (P*x) coeff - scalar, scaling factor in + * Householder matrix formula + */ + static void evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff, + T& normX); + + static void evalHHmatrixDataI(const NDArray& x, T& coeff, T& normX); + + /** + * this method mathematically multiplies input matrix on Householder from the + * left P * matrix + * + * matrix - input matrix + * tail - the essential part of the Householder vector w: [w1, w2, w3, ...] + * coeff - scalar, scaling factor in Householder matrix formula + */ + static void mulLeft(NDArray& matrix, const NDArray& tail, const T coeff); + + /** + * this method mathematically multiplies input matrix on Householder from the + * right matrix * P + * + * matrix - input matrix + * tail - the essential part of the Householder vector w: [w1, w2, w3, ...] + * coeff - scalar, scaling factor in Householder matrix formula + */ + static void mulRight(NDArray& matrix, const NDArray& tail, const T coeff); }; - - // /** - // * this function reduce given matrix to upper bidiagonal form (in-place operation), matrix must satisfy following condition rows >= cols - // * - // * matrix - input 2D matrix to be reduced to upper bidiagonal from - // */ - // template - // void biDiagonalizeUp(NDArray& matrix); - - // /** - // * given a matrix [m,n], this function computes its singular value decomposition matrix = u * s * v^T - // * - // * matrix - input 2D matrix to decompose, [m, n] - // * u - unitary matrix containing left singular vectors of input matrix, [m, m] - // * s - diagonal matrix with singular values of input matrix (non-negative) on the diagonal sorted in decreasing order, - // * actually the mathematically correct dimension of s is [m, n], however for memory saving we work with s as vector [1, p], where p is smaller among m and n - // * v - unitary matrix containing right singular vectors of input matrix, [n, n] - // * calcUV - if true then u and v will be computed, in opposite case function works significantly faster - // * fullUV - if false then only p (p is smaller among m and n) first columns of u and v will be calculated and their dimensions in this case are [m, p] and [n, p] - // * - // */ - // void svd(const NDArray& matrix, NDArray& u, NDArray& s, NDArray& v, const bool calcUV = false, const bool fullUV = false) - - - -} -} -} - - -#endif //LIBND4J_HOUSEHOLDER_H +// /** +// * this function reduce given matrix to upper bidiagonal form (in-place +// operation), matrix must satisfy following condition rows >= cols +// * +// * matrix - input 2D matrix to be reduced to upper bidiagonal from +// */ +// template +// void biDiagonalizeUp(NDArray& matrix); + +// /** +// * given a matrix [m,n], this function computes its singular value +// decomposition matrix = u * s * v^T +// * +// * matrix - input 2D matrix to decompose, [m, n] +// * u - unitary matrix containing left singular vectors of input matrix, [m, +// m] +// * s - diagonal matrix with singular values of input matrix (non-negative) on +// the diagonal sorted in decreasing order, +// * actually the mathematically correct dimension of s is [m, n], however +// for memory saving we work with s as vector [1, p], where p is smaller among m +// and n +// * v - unitary matrix containing right singular vectors of input matrix, [n, +// n] +// * calcUV - if true then u and v will be computed, in opposite case function +// works significantly faster +// * fullUV - if false then only p (p is smaller among m and n) first columns +// of u and v will be calculated and their dimensions in this case are [m, p] +// and [n, p] +// * +// */ +// void svd(const NDArray& matrix, NDArray& u, NDArray& s, NDArray& v, const +// bool calcUV = false, const bool fullUV = false) + +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // LIBND4J_HOUSEHOLDER_H diff --git a/libnd4j/include/helpers/impl/ArrayUtils.cpp b/libnd4j/include/helpers/impl/ArrayUtils.cpp index 004cb15469cc..04ea00a06f10 100644 --- a/libnd4j/include/helpers/impl/ArrayUtils.cpp +++ b/libnd4j/include/helpers/impl/ArrayUtils.cpp @@ -21,37 +21,34 @@ #include namespace sd { - namespace ArrayUtils { - void toIntPtr(std::initializer_list list, int* target) { - std::vector vec(list); - toIntPtr(vec, target); - } - - void toIntPtr(std::vector& list, int* target) { - memcpy(target, list.data(), list.size() * sizeof(int)); - } - - void toLongPtr(std::initializer_list list, Nd4jLong* target) { - std::vector vec(list); - toLongPtr(vec, target); - } - - void toLongPtr(std::vector& list, Nd4jLong* target) { - memcpy(target, list.data(), list.size() * sizeof(Nd4jLong)); - } - - std::vector toLongVector(std::vector vec) { - std::vector result(vec.size()); - Nd4jLong vecSize = vec.size(); - - for (Nd4jLong e = 0; e < vecSize; e++) - result[e] = vec[e]; - - return result; - } - - std::vector toLongVector(std::vector vec) { - return vec; - } - } +namespace ArrayUtils { +void toIntPtr(std::initializer_list list, int* target) { + std::vector vec(list); + toIntPtr(vec, target); } + +void toIntPtr(std::vector& list, int* target) { + memcpy(target, list.data(), list.size() * sizeof(int)); +} + +void toLongPtr(std::initializer_list list, Nd4jLong* target) { + std::vector vec(list); + toLongPtr(vec, target); +} + +void toLongPtr(std::vector& list, Nd4jLong* target) { + memcpy(target, list.data(), list.size() * sizeof(Nd4jLong)); +} + +std::vector toLongVector(std::vector vec) { + std::vector result(vec.size()); + Nd4jLong vecSize = vec.size(); + + for (Nd4jLong e = 0; e < vecSize; e++) result[e] = vec[e]; + + return result; +} + +std::vector toLongVector(std::vector vec) { return vec; } +} // namespace ArrayUtils +} // namespace sd diff --git a/libnd4j/include/helpers/impl/AttentionHelper.cpp b/libnd4j/include/helpers/impl/AttentionHelper.cpp index bd5d006f2b17..00ec60aad546 100644 --- a/libnd4j/include/helpers/impl/AttentionHelper.cpp +++ b/libnd4j/include/helpers/impl/AttentionHelper.cpp @@ -21,61 +21,82 @@ #ifndef LIBND4J_ATTENTIONHELPER_CPP #define LIBND4J_ATTENTIONHELPER_CPP -#include - #include "../AttentionHelper.h" + +#include #include namespace sd { - sd::NDArray AttentionHelper::multiHeadProject(const sd::NDArray *input, const sd::NDArray *projectionMatrix, sd::LaunchContext * context) { - auto miniBatchSize = input->sizeAt(0); - auto seqLength = input->sizeAt(2); - auto numHeads = projectionMatrix->sizeAt(0); - auto projectedSize = projectionMatrix->sizeAt(1); - - auto inputPerm = input->permute({1, 0, 2}); //[batch, nIn, timeSteps] -> [nIn, batch, timeSteps] - auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)}); //[nIn, batch*timeSteps] - auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); //[nHeads, hS, nIn] -> [nHeads*hS, nIn] - - NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context); //[nHeads*hS, batch*timeSteps] - sd::ops::matmul mmul; - mmul.execute({&projectionPrep, &inputPrep}, {&projected}); - - projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength}); - projected.permutei({2, 0, 1, 3}); //[minibatch, numHeads, projectedSize, seqLength] - - return projected; - } - - void AttentionHelper::multiHeadProjectBp(const sd::NDArray *input, const sd::NDArray *projectionMatrix, - const sd::NDArray *eps, sd::NDArray *dLdInput, - sd::NDArray *dLdProjectionMatrix, sd::LaunchContext * context) { - auto miniBatchSize = input->sizeAt(0); - auto seqLength = input->sizeAt(2); - auto numHeads = projectionMatrix->sizeAt(0); - auto projectedSize = projectionMatrix->sizeAt(1); - - auto epsPerm = eps->permute({1, 2, 0, 3}); - auto epsReshaped = epsPerm.reshape('c', {numHeads * projectedSize, miniBatchSize * seqLength}); - - auto inputPerm = input->permute({1, 0, 2}); - auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)}); - auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); - - sd::ops::matmul_bp mmulBp; - NDArray dLdProjectionPrep(projectionPrep.shapeInfo(), false, context); - NDArray dLdInputPrep(inputPrep.shapeInfo(), false, context); - mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, std::vector{&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {}); - - dLdProjectionPrep.reshapei({numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); - dLdProjectionMatrix->assign(dLdProjectionPrep); - - dLdInputPrep.reshapei({input->sizeAt(1), miniBatchSize, seqLength}); - dLdInputPrep.permutei({1, 0, 2}); - dLdInput->assign(dLdInputPrep); - } +sd::NDArray AttentionHelper::multiHeadProject( + const sd::NDArray *input, const sd::NDArray *projectionMatrix, + sd::LaunchContext *context) { + auto miniBatchSize = input->sizeAt(0); + auto seqLength = input->sizeAt(2); + auto numHeads = projectionMatrix->sizeAt(0); + auto projectedSize = projectionMatrix->sizeAt(1); + + auto inputPerm = input->permute( + {1, 0, 2}); //[batch, nIn, timeSteps] -> [nIn, batch, timeSteps] + auto inputPrep = inputPerm.reshape( + 'c', {input->sizeAt(1), + (miniBatchSize * seqLength)}); //[nIn, batch*timeSteps] + auto projectionPrep = projectionMatrix->reshape( + 'c', + {numHeads * projectionMatrix->sizeAt(1), + projectionMatrix->sizeAt(2)}); //[nHeads, hS, nIn] -> [nHeads*hS, nIn] + + NDArray projected( + 'c', + {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, + input->dataType(), context); //[nHeads*hS, batch*timeSteps] + sd::ops::matmul mmul; + mmul.execute({&projectionPrep, &inputPrep}, {&projected}); + + projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength}); + projected.permutei( + {2, 0, 1, 3}); //[minibatch, numHeads, projectedSize, seqLength] + + return projected; } +void AttentionHelper::multiHeadProjectBp(const sd::NDArray *input, + const sd::NDArray *projectionMatrix, + const sd::NDArray *eps, + sd::NDArray *dLdInput, + sd::NDArray *dLdProjectionMatrix, + sd::LaunchContext *context) { + auto miniBatchSize = input->sizeAt(0); + auto seqLength = input->sizeAt(2); + auto numHeads = projectionMatrix->sizeAt(0); + auto projectedSize = projectionMatrix->sizeAt(1); + + auto epsPerm = eps->permute({1, 2, 0, 3}); + auto epsReshaped = epsPerm.reshape( + 'c', {numHeads * projectedSize, miniBatchSize * seqLength}); + + auto inputPerm = input->permute({1, 0, 2}); + auto inputPrep = + inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)}); + auto projectionPrep = projectionMatrix->reshape( + 'c', + {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); + + sd::ops::matmul_bp mmulBp; + NDArray dLdProjectionPrep(projectionPrep.shapeInfo(), false, context); + NDArray dLdInputPrep(inputPrep.shapeInfo(), false, context); + mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, + std::vector{&dLdProjectionPrep, &dLdInputPrep}, {}, + {}, {}); + + dLdProjectionPrep.reshapei( + {numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); + dLdProjectionMatrix->assign(dLdProjectionPrep); + + dLdInputPrep.reshapei({input->sizeAt(1), miniBatchSize, seqLength}); + dLdInputPrep.permutei({1, 0, 2}); + dLdInput->assign(dLdInputPrep); +} +} // namespace sd #endif diff --git a/libnd4j/include/helpers/impl/BenchmarkHelper.cpp b/libnd4j/include/helpers/impl/BenchmarkHelper.cpp index df97b08c6e8d..a672a0a99979 100644 --- a/libnd4j/include/helpers/impl/BenchmarkHelper.cpp +++ b/libnd4j/include/helpers/impl/BenchmarkHelper.cpp @@ -18,683 +18,756 @@ * @author raver119@gmail.com */ - #include "../BenchmarkHelper.h" + #include -#include #include +#include + namespace sd { - BenchmarkHelper::BenchmarkHelper(unsigned int warmUpIterations, unsigned int runIterations) { - _wIterations = warmUpIterations; - _rIterations = runIterations; - } +BenchmarkHelper::BenchmarkHelper(unsigned int warmUpIterations, + unsigned int runIterations) { + _wIterations = warmUpIterations; + _rIterations = runIterations; +} + +std::string BenchmarkHelper::printHeader() { + return std::string( + "TestName\tOpNum\tWarmup\tNumIter\tDataType\tInplace\tShape\tStrides\tAxi" + "s\tOrders\tavg (us)\tmedian (us)\tmin (us)\tmax (us)\tstdev (us)\n"); +} + +std::string BenchmarkHelper::benchmarkOperation(OpBenchmark &benchmark) { + for (uint i = 0; i < _wIterations; i++) benchmark.executeOnce(); + + std::vector timings(_rIterations); + double sumT = 0.0; + + for (uint i = 0; i < _rIterations; i++) { + auto timeStart = std::chrono::system_clock::now(); + + benchmark.executeOnce(); + + auto timeEnd = std::chrono::system_clock::now(); + auto loopTime = std::chrono::duration_cast( + (timeEnd - timeStart)) + .count(); + timings[i] = loopTime; + sumT += loopTime; + } + sumT /= _rIterations; + + std::sort(timings.begin(), timings.end()); + Nd4jLong median = timings[_rIterations / 2]; + + auto n = NDArrayFactory::create(timings, LaunchContext::defaultContext()); + + auto stdev = + n.varianceNumber(sd::variance::SummaryStatsStandardDeviation, false) + .e(0); + auto min = n.reduceNumber(sd::reduce::Min).e(0); + auto max = n.reduceNumber(sd::reduce::Max).e(0); + + // opNum, DataType, Shape, average time, median time + auto t = benchmark.dataType(); + auto s = benchmark.shape(); + auto strides = benchmark.strides(); + auto o = benchmark.orders(); + auto a = benchmark.axis(); + auto inpl = benchmark.inplace(); + + std::string temp; + temp.resize(65536); + + // printing out stuff + snprintf( + const_cast(temp.data()), temp.length(), + "%s\t%i\t%i\t%i\t%s\t%s\t%s\t%s\t%s\t%s\t%lld\t%lld\t%lld\t%lld\t%.2f\n", + benchmark.testName().c_str(), benchmark.opNum(), _wIterations, + _rIterations, t.c_str(), inpl.c_str(), s.c_str(), strides.c_str(), + a.c_str(), o.c_str(), sd::math::nd4j_floor(sumT), + median, min, max, stdev); + + auto pos = temp.find('\n'); + return temp.substr(0, pos + 1); +} + +void BenchmarkHelper::benchmarkScalarOperation(scalar::Ops op, + std::string testName, + double value, NDArray &x, + NDArray &z) { + auto y = NDArrayFactory::create(x.dataType(), value); + + // for (uint i = 0; i < _wIterations; i++) + // NativeOpExecutioner::execScalar(op, x.buffer(), x.shapeInfo(), z.buffer(), + // z.shapeInfo(), y.buffer(), y.shapeInfo(), nullptr); + + std::vector timings(_rIterations); + double sumT = 0.0; + + for (uint i = 0; i < _rIterations; i++) { + auto timeStart = std::chrono::system_clock::now(); + + // NativeOpExecutioner::execScalar(op, x.buffer(), x.shapeInfo(), + // z.buffer(), z.shapeInfo(), y.buffer(), y.shapeInfo(), nullptr); + + auto timeEnd = std::chrono::system_clock::now(); + auto loopTime = std::chrono::duration_cast( + (timeEnd - timeStart)) + .count(); + timings[i] = loopTime; + sumT += loopTime; + } + sumT /= _rIterations; + + std::sort(timings.begin(), timings.end()); + Nd4jLong median = timings[_rIterations / 2]; + + NDArray n = NDArrayFactory::create(timings, nullptr); + double stdev = + n.varianceNumber(sd::variance::SummaryStatsStandardDeviation, false) + .e(0); + Nd4jLong min = n.reduceNumber(sd::reduce::Min).e(0); + Nd4jLong max = n.reduceNumber(sd::reduce::Max).e(0); + + // opNum, DataType, Shape, average time, median time + auto t = DataTypeUtils::asString(x.dataType()); + auto s = ShapeUtils::shapeAsString(&x); + auto stride = ShapeUtils::strideAsString(&x); + stride += "/"; + stride += ShapeUtils::strideAsString(&z); + std::string o; + o += x.ordering(); + o += "/"; + o += z.ordering(); + std::string inpl; + inpl += (x == z ? "true" : "false"); + + // printing out stuff + nd4j_printf( + "%s\t%i\t%i\t%i\t%s\t%s\t%s\t%s\t%s\tn/a\t%lld\t%lld\t%lld\t%lld\t%.2f\n", + testName.c_str(), op, _wIterations, _rIterations, t.c_str(), inpl.c_str(), + s.c_str(), stride.c_str(), o.c_str(), + sd::math::nd4j_floor(sumT), median, min, max, stdev); +} + +std::string BenchmarkHelper::runOperationSuit( + std::initializer_list benchmarks, const char *msg) { + std::vector ops(benchmarks); + return runOperationSuit(ops, msg); +} + +std::string BenchmarkHelper::runOperationSuit(OpBenchmark *benchmark) { + return benchmarkOperation(*benchmark); +} + +std::string BenchmarkHelper::runOperationSuit( + std::vector &benchmarks, bool postHeaders, const char *msg) { + std::string result; + + if (msg != nullptr && postHeaders) { + result += "\n"; + result += msg; + result += "\n"; + } + + if (postHeaders) result += printHeader(); + + for (auto v : benchmarks) result += benchmarkOperation(*v); + + return result; +} + +std::string BenchmarkHelper::runOperationSuit( + DeclarableBenchmark *op, const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + auto parameters = parametersBatch.parameters(); + std::string result; + + if (message != nullptr) { + result += "\n"; + result += message; + result += "\n"; + } + + result += printHeader(); + + std::vector list; + + for (auto &p : parameters) { + auto ctx = func(p); + + auto clone = reinterpret_cast(op->clone()); + clone->setContext(ctx); + + result += runOperationSuit(clone); + + delete clone; + } + + return result; +} + +std::string BenchmarkHelper::runOperationSuit( + ScalarBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + auto parameters = parametersBatch.parameters(); + std::string output; + + if (message != nullptr) { + output += "\n"; + output += message; + output += "\n"; + } + + output += printHeader(); + + for (auto &p : parameters) { + ResultSet x; + x.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(p, x, z); + std::vector result; + + if (x.size() != z.size()) + throw std::runtime_error( + "ScalarBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); - std::string BenchmarkHelper::printHeader() { - return std::string("TestName\tOpNum\tWarmup\tNumIter\tDataType\tInplace\tShape\tStrides\tAxis\tOrders\tavg (us)\tmedian (us)\tmin (us)\tmax (us)\tstdev (us)\n"); + result.emplace_back(clone); } - std::string BenchmarkHelper::benchmarkOperation(OpBenchmark &benchmark) { - - for (uint i = 0; i < _wIterations; i++) - benchmark.executeOnce(); - - std::vector timings(_rIterations); - double sumT = 0.0; - - for (uint i = 0; i < _rIterations; i++) { - auto timeStart = std::chrono::system_clock::now(); - - benchmark.executeOnce(); - - auto timeEnd = std::chrono::system_clock::now(); - auto loopTime = std::chrono::duration_cast ((timeEnd - timeStart)).count(); - timings[i] = loopTime; - sumT += loopTime; - } - sumT /= _rIterations; - - std::sort(timings.begin(), timings.end()); - Nd4jLong median = timings[_rIterations / 2]; - - auto n = NDArrayFactory::create(timings, LaunchContext::defaultContext()); - - auto stdev = n.varianceNumber(sd::variance::SummaryStatsStandardDeviation, false).e(0); - auto min = n.reduceNumber(sd::reduce::Min).e(0); - auto max = n.reduceNumber(sd::reduce::Max).e(0); - - // opNum, DataType, Shape, average time, median time - auto t = benchmark.dataType(); - auto s = benchmark.shape(); - auto strides = benchmark.strides(); - auto o = benchmark.orders(); - auto a = benchmark.axis(); - auto inpl = benchmark.inplace(); + output += runOperationSuit(result, false); - std::string temp; - temp.resize(65536); - - // printing out stuff - snprintf(const_cast(temp.data()), temp.length(), "%s\t%i\t%i\t%i\t%s\t%s\t%s\t%s\t%s\t%s\t%lld\t%lld\t%lld\t%lld\t%.2f\n", benchmark.testName().c_str(), benchmark.opNum(), - _wIterations, _rIterations, t.c_str(), inpl.c_str(), s.c_str(), strides.c_str(), a.c_str(), o.c_str(), - sd::math::nd4j_floor(sumT), median, min, max, stdev); - - auto pos = temp.find('\n'); - return temp.substr(0, pos + 1); + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); } - - void BenchmarkHelper::benchmarkScalarOperation(scalar::Ops op, std::string testName, double value, NDArray &x, NDArray &z) { - auto y = NDArrayFactory::create(x.dataType(), value); - - //for (uint i = 0; i < _wIterations; i++) - //NativeOpExecutioner::execScalar(op, x.buffer(), x.shapeInfo(), z.buffer(), z.shapeInfo(), y.buffer(), y.shapeInfo(), nullptr); - - - std::vector timings(_rIterations); - double sumT = 0.0; - - for (uint i = 0; i < _rIterations; i++) { - auto timeStart = std::chrono::system_clock::now(); - - //NativeOpExecutioner::execScalar(op, x.buffer(), x.shapeInfo(), z.buffer(), z.shapeInfo(), y.buffer(), y.shapeInfo(), nullptr); - - auto timeEnd = std::chrono::system_clock::now(); - auto loopTime = std::chrono::duration_cast ((timeEnd - timeStart)).count(); - timings[i] = loopTime; - sumT += loopTime; - } - sumT /= _rIterations; - - std::sort(timings.begin(), timings.end()); - Nd4jLong median = timings[_rIterations / 2]; - - NDArray n = NDArrayFactory::create(timings, nullptr); - double stdev = n.varianceNumber(sd::variance::SummaryStatsStandardDeviation, false).e(0); - Nd4jLong min = n.reduceNumber(sd::reduce::Min).e(0); - Nd4jLong max = n.reduceNumber(sd::reduce::Max).e(0); - - // opNum, DataType, Shape, average time, median time - auto t = DataTypeUtils::asString(x.dataType()); - auto s = ShapeUtils::shapeAsString(&x); - auto stride = ShapeUtils::strideAsString(&x); - stride += "/"; - stride += ShapeUtils::strideAsString(&z); - std::string o; - o += x.ordering(); - o += "/"; - o += z.ordering(); - std::string inpl; - inpl += (x == z ? "true" : "false"); - - // printing out stuff - nd4j_printf("%s\t%i\t%i\t%i\t%s\t%s\t%s\t%s\t%s\tn/a\t%lld\t%lld\t%lld\t%lld\t%.2f\n", testName.c_str(), op, - _wIterations, _rIterations, t.c_str(), inpl.c_str(), s.c_str(), stride.c_str(), o.c_str(), - sd::math::nd4j_floor(sumT), median, min, max, stdev); + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + ScalarBenchmark *op, + const std::function &func, + const char *message) { + std::string output; + + ResultSet x; + x.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(x, z); + std::vector result; + + if (x.size() != z.size()) + throw std::runtime_error( + "ScalarBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); + + result.emplace_back(clone); + } + + output += runOperationSuit(result, message); + + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + TransformBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + auto parameters = parametersBatch.parameters(); + std::string output; + + if (message != nullptr) { + output += "\n"; + output += message; + output += "\n"; + } + + output += printHeader(); + + for (auto &p : parameters) { + ResultSet x; + x.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(p, x, z); + std::vector result; + + if (x.size() != z.size()) + throw std::runtime_error( + "TransformBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); + + result.emplace_back(clone); } - std::string BenchmarkHelper::runOperationSuit(std::initializer_list benchmarks, const char *msg) { - std::vector ops(benchmarks); - return runOperationSuit(ops, msg); - } + output += runOperationSuit(result, false); - std::string BenchmarkHelper::runOperationSuit(OpBenchmark* benchmark) { - return benchmarkOperation(*benchmark); + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); } - - std::string BenchmarkHelper::runOperationSuit(std::vector &benchmarks, bool postHeaders, const char *msg) { - std::string result; - - if (msg != nullptr && postHeaders) { - result += "\n"; - result += msg; - result += "\n"; - } - - if (postHeaders) - result += printHeader(); - - for (auto v:benchmarks) - result += benchmarkOperation(*v); - - return result; + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + TransformBenchmark *op, + const std::function &func, + const char *message) { + std::string output; + + ResultSet x; + x.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(x, z); + std::vector result; + + if (x.size() != z.size()) + throw std::runtime_error( + "TransformBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); + + result.emplace_back(clone); + } + + output += runOperationSuit(result, message); + + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + std::string output; + auto parameters = parametersBatch.parameters(); + + if (message != nullptr) { + output += "\n"; + output += message; + output += "\n"; + } + + output += printHeader(); + + for (auto &p : parameters) { + ResultSet x; + x.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(p, x, z); + std::vector result; + + if (x.size() != z.size()) + throw std::runtime_error( + "ReductionBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); + + result.emplace_back(clone); } - std::string BenchmarkHelper::runOperationSuit(DeclarableBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - auto parameters = parametersBatch.parameters(); - std::string result; - - if (message != nullptr) { - result += "\n"; - result += message; - result += "\n"; - } - - result += printHeader(); - - std::vector list; - - for (auto &p : parameters) { - auto ctx = func(p); - - auto clone = reinterpret_cast(op->clone()); - clone->setContext(ctx); - - result += runOperationSuit(clone); + output += runOperationSuit(result, false); - delete clone; - } - - return result; + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); } - - std::string BenchmarkHelper::runOperationSuit(ScalarBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - auto parameters = parametersBatch.parameters(); - std::string output; - - if (message != nullptr) { - output += "\n"; - output += message; - output += "\n"; - } - - output += printHeader(); - - for (auto &p: parameters) { - ResultSet x; - x.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(p, x, z); - std::vector result; - - if (x.size() != z.size()) - throw std::runtime_error("ScalarBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, false); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - } - - return output; + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + const char *message) { + std::string output; + ResultSet x; + x.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(x, z); + std::vector result; + + if (x.size() != z.size()) + throw std::runtime_error( + "ReductionBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); + + result.emplace_back(clone); + } + + output += runOperationSuit(result, message); + + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + auto parameters = parametersBatch.parameters(); + std::string output; + + if (message != nullptr) { + output += "\n"; + output += message; + output += "\n"; + } + + printHeader(); + + for (auto &p : parameters) { + ResultSet x; + x.setNonRemovable(); + ResultSet y; + y.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(p, x, y, z); + std::vector result; + + if (x.size() != z.size() || x.size() != y.size()) + throw std::runtime_error( + "ReductionBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto y_ = y.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); + + if (y_.shapeInfo() != nullptr) { + clone->setAxis(y_.asVectorT()); + } + + result.emplace_back(clone); } - std::string BenchmarkHelper::runOperationSuit(ScalarBenchmark *op, const std::function& func, const char *message) { - std::string output; - - ResultSet x; - x.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(x, z); - std::vector result; - - if (x.size() != z.size()) - throw std::runtime_error("ScalarBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto z_ = z.at(e); + output += runOperationSuit(result, false); - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, message); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - - return output; + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); } - - std::string BenchmarkHelper::runOperationSuit(TransformBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - auto parameters = parametersBatch.parameters(); - std::string output; - - if (message != nullptr) { - output += "\n"; - output += message; - output += "\n"; - } - - output += printHeader(); - - for (auto &p: parameters) { - ResultSet x; - x.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(p, x, z); - std::vector result; - - if (x.size() != z.size()) - throw std::runtime_error("TransformBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, false); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - } - - return output; + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + ReductionBenchmark *op, + const std::function &func, + const char *message) { + std::string output; + + ResultSet x; + x.setNonRemovable(); + ResultSet y; + y.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(x, y, z); + std::vector result; + + if (x.size() != z.size() || x.size() != y.size()) + throw std::runtime_error( + "ReductionBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto y_ = y.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setZ(z_); + + if (y_.shapeInfo() != nullptr) { + clone->setAxis(y_.asVectorT()); } - - std::string BenchmarkHelper::runOperationSuit(TransformBenchmark *op, const std::function& func, const char *message) { - std::string output; - - ResultSet x; - x.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(x, z); - std::vector result; - - if (x.size() != z.size()) - throw std::runtime_error("TransformBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, message); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - - return output; + result.emplace_back(clone); + } + + output += runOperationSuit(result, message); + + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + BroadcastBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + auto parameters = parametersBatch.parameters(); + std::string output; + + if (message != nullptr) { + output += "\n"; + output += message; + output += "\n"; + } + + output += printHeader(); + + for (auto &p : parameters) { + ResultSet x; + x.setNonRemovable(); + ResultSet y; + y.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(p, x, y, z); + std::vector result; + + if (x.size() != z.size()) + throw std::runtime_error( + "BroadcastBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto y_ = y.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setY(y_); + clone->setZ(z_); + + clone->setAxis(op->getAxis()); + result.emplace_back(clone); } - std::string BenchmarkHelper::runOperationSuit(ReductionBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - std::string output; - auto parameters = parametersBatch.parameters(); - - if (message != nullptr) { - output += "\n"; - output += message; - output += "\n"; - } - - output += printHeader(); - - for (auto &p: parameters) { - ResultSet x; - x.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(p, x, z); - std::vector result; - - if (x.size() != z.size()) - throw std::runtime_error("ReductionBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto z_ = z.at(e); + output += runOperationSuit(result, false); - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, false); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - } - - return output; - } - - std::string BenchmarkHelper::runOperationSuit(ReductionBenchmark *op, const std::function& func, const char *message) { - std::string output; - ResultSet x; - x.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(x, z); - std::vector result; - - if (x.size() != z.size()) - throw std::runtime_error("ReductionBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, message); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - - return output; + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); } - - std::string BenchmarkHelper::runOperationSuit(ReductionBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - auto parameters = parametersBatch.parameters(); - std::string output; - - if (message != nullptr) { - output += "\n"; - output += message; - output += "\n"; - } - - printHeader(); - - for (auto &p: parameters) { - ResultSet x; - x.setNonRemovable(); - ResultSet y; - y.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(p, x, y, z); - std::vector result; - - if (x.size() != z.size() || x.size() != y.size()) - throw std::runtime_error("ReductionBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto y_ = y.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); - - if (y_.shapeInfo() != nullptr) { - clone->setAxis(y_.asVectorT()); - } - - result.emplace_back(clone); - } - - output += runOperationSuit(result, false); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - } - - return output; + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + PairwiseBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + auto parameters = parametersBatch.parameters(); + std::string output; + + if (message != nullptr) { + output += "\n"; + output += message; + output += "\n"; + } + + output += printHeader(); + + for (auto &p : parameters) { + ResultSet x; + x.setNonRemovable(); + ResultSet y; + y.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(p, x, y, z); + std::vector result; + + if (x.size() != z.size() || x.size() != y.size()) + throw std::runtime_error( + "PairwiseBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto y_ = y.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setY(y_); + clone->setZ(z_); + + result.emplace_back(clone); } - std::string BenchmarkHelper::runOperationSuit(ReductionBenchmark *op, const std::function& func, const char *message) { - std::string output; - - ResultSet x; - x.setNonRemovable(); - ResultSet y; - y.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(x, y, z); - std::vector result; - - if (x.size() != z.size() || x.size() != y.size()) - throw std::runtime_error("ReductionBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto y_ = y.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setZ(z_); + output += runOperationSuit(result, false); - if (y_.shapeInfo() != nullptr) { - clone->setAxis(y_.asVectorT()); - } - result.emplace_back(clone); - } - - output += runOperationSuit(result, message); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - - return output; + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); } - - std::string BenchmarkHelper::runOperationSuit(BroadcastBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - auto parameters = parametersBatch.parameters(); - std::string output; - - if (message != nullptr) { - output += "\n"; - output += message; - output += "\n"; - } - - output += printHeader(); - - for (auto &p: parameters) { - ResultSet x; - x.setNonRemovable(); - ResultSet y; - y.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(p, x, y, z); - std::vector result; - - if (x.size() != z.size() ) - throw std::runtime_error("BroadcastBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto y_ = y.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setY(y_); - clone->setZ(z_); - - clone->setAxis(op->getAxis()); - result.emplace_back(clone); - } - - output += runOperationSuit(result, false); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - } - - return output; + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + PairwiseBenchmark *op, + const std::function &func, + const char *message) { + std::string output; + + ResultSet x; + x.setNonRemovable(); + ResultSet y; + y.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(x, y, z); + std::vector result; + + if (x.size() != z.size() || x.size() != y.size()) + throw std::runtime_error( + "PairwiseBenchmark: number of X and Z arrays should match"); + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto y_ = y.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setY(y_); + clone->setZ(z_); + + result.emplace_back(clone); + } + + output += runOperationSuit(result, message); + + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); + } + + return output; +} + +std::string BenchmarkHelper::runOperationSuit( + MatrixBenchmark *op, + const std::function &func, + ParametersBatch ¶metersBatch, const char *message) { + auto parameters = parametersBatch.parameters(); + std::string output; + + if (message != nullptr) { + output += "\n"; + output += message; + output += "\n"; + } + + output += printHeader(); + + for (auto &p : parameters) { + ResultSet x; + x.setNonRemovable(); + ResultSet y; + y.setNonRemovable(); + ResultSet z; + z.setNonRemovable(); + func(p, x, y, z); + std::vector result; + + for (int e = 0; e < x.size(); e++) { + auto x_ = x.at(e); + auto y_ = y.at(e); + auto z_ = z.at(e); + + auto clone = op->clone(); + clone->setX(x_); + clone->setY(y_); + clone->setZ(z_); + + result.emplace_back(clone); } - std::string BenchmarkHelper::runOperationSuit(PairwiseBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - auto parameters = parametersBatch.parameters(); - std::string output; - - if (message != nullptr) { - output += "\n"; - output += message; - output += "\n"; - } - - output += printHeader(); - - for (auto &p: parameters) { - ResultSet x; - x.setNonRemovable(); - ResultSet y; - y.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(p, x, y, z); - std::vector result; - - if (x.size() != z.size() || x.size() != y.size()) - throw std::runtime_error("PairwiseBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto y_ = y.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setY(y_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, false); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - } - - return output; - } - - std::string BenchmarkHelper::runOperationSuit(PairwiseBenchmark *op, const std::function& func, const char *message) { - std::string output; - - ResultSet x; - x.setNonRemovable(); - ResultSet y; - y.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(x, y, z); - std::vector result; + output += runOperationSuit(result, false); - if (x.size() != z.size() || x.size() != y.size()) - throw std::runtime_error("PairwiseBenchmark: number of X and Z arrays should match"); - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto y_ = y.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setY(y_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, message); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - - return output; + // removing everything + for (auto v : result) { + delete reinterpret_cast(v); } + } - std::string BenchmarkHelper::runOperationSuit(MatrixBenchmark *op, const std::function& func, ParametersBatch ¶metersBatch, const char *message) { - auto parameters = parametersBatch.parameters(); - std::string output; - - if (message != nullptr) { - output += "\n"; - output += message; - output += "\n"; - } - - output += printHeader(); - - for (auto &p: parameters) { - ResultSet x; - x.setNonRemovable(); - ResultSet y; - y.setNonRemovable(); - ResultSet z; - z.setNonRemovable(); - func(p, x, y, z); - std::vector result; - - for (int e = 0; e < x.size(); e++) { - auto x_ = x.at(e); - auto y_ = y.at(e); - auto z_ = z.at(e); - - auto clone = op->clone(); - clone->setX(x_); - clone->setY(y_); - clone->setZ(z_); - - result.emplace_back(clone); - } - - output += runOperationSuit(result, false); - - // removing everything - for (auto v:result) { - delete reinterpret_cast(v); - } - } - - return output; - } -} \ No newline at end of file + return output; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/BitwiseUtils.cpp b/libnd4j/include/helpers/impl/BitwiseUtils.cpp index e3f4ce92a0d5..3a560a8c7c0e 100644 --- a/libnd4j/include/helpers/impl/BitwiseUtils.cpp +++ b/libnd4j/include/helpers/impl/BitwiseUtils.cpp @@ -18,65 +18,60 @@ // Created by raver119 on 10.11.2017. // -#include #include +#include #include #include namespace sd { - bool BitwiseUtils::isBE() { - short int word = 0x0001; - char *byte = (char *) &word; - return(byte[0] ? false : true); - } +bool BitwiseUtils::isBE() { + short int word = 0x0001; + char *byte = (char *)&word; + return (byte[0] ? false : true); +} - int BitwiseUtils::valueBit(int holder) { - if (holder == 0) - return -1; +int BitwiseUtils::valueBit(int holder) { + if (holder == 0) return -1; #ifdef REVERSE_BITS - for (int e = 32; e >= 0; e--) { + for (int e = 32; e >= 0; e--) { #else - for (int e = 0; e < 32; e++) { + for (int e = 0; e < 32; e++) { #endif - bool isOne = (holder & 1 << e) != 0; - - if (isOne) - return e; - } + bool isOne = (holder & 1 << e) != 0; - return -1; - } + if (isOne) return e; + } + return -1; +} - std::vector BitwiseUtils::valueBits(int holder) { - std::vector bits; - if (holder == 0) { - for (int e = 0; e < 32; e++) - bits.emplace_back(0); - - return bits; - } +std::vector BitwiseUtils::valueBits(int holder) { + std::vector bits; + if (holder == 0) { + for (int e = 0; e < 32; e++) bits.emplace_back(0); + return bits; + } #ifdef REVERSE_BITS - for (int e = 32; e >= 0; e--) { + for (int e = 32; e >= 0; e--) { #else - for (int e = 0; e < 32; e++) { + for (int e = 0; e < 32; e++) { #endif - bool isOne = (holder & 1 << e) != 0; + bool isOne = (holder & 1 << e) != 0; - if (isOne) - bits.emplace_back(1); - else - bits.emplace_back(0); - } + if (isOne) + bits.emplace_back(1); + else + bits.emplace_back(0); + } - return bits; - } + return bits; +} - sd::ByteOrder BitwiseUtils::asByteOrder() { - return isBE() ? ByteOrder::BE : ByteOrder::LE; - } +sd::ByteOrder BitwiseUtils::asByteOrder() { + return isBE() ? ByteOrder::BE : ByteOrder::LE; } +} // namespace sd diff --git a/libnd4j/include/helpers/impl/BlasHelper.cpp b/libnd4j/include/helpers/impl/BlasHelper.cpp index 378c8a6f1de6..1bf82b0b5eb1 100644 --- a/libnd4j/include/helpers/impl/BlasHelper.cpp +++ b/libnd4j/include/helpers/impl/BlasHelper.cpp @@ -20,348 +20,322 @@ #include namespace sd { - BlasHelper* BlasHelper::getInstance() { - if (_instance == 0) - _instance = new BlasHelper(); - return _instance; - } - - - void BlasHelper::initializeFunctions(Nd4jPointer *functions) { - nd4j_debug("Initializing BLAS\n",""); - - _hasSgemv = functions[0] != nullptr; - _hasSgemm = functions[2] != nullptr; - - _hasDgemv = functions[1] != nullptr; - _hasDgemm = functions[3] != nullptr; - - _hasSgemmBatch = functions[4] != nullptr; - _hasDgemmBatch = functions[5] != nullptr; - - this->cblasSgemv = (CblasSgemv)functions[0]; - this->cblasDgemv = (CblasDgemv)functions[1]; - this->cblasSgemm = (CblasSgemm)functions[2]; - this->cblasDgemm = (CblasDgemm)functions[3]; - this->cblasSgemmBatch = (CblasSgemmBatch)functions[4]; - this->cblasDgemmBatch = (CblasDgemmBatch)functions[5]; - this->lapackeSgesvd = (LapackeSgesvd)functions[6]; - this->lapackeDgesvd = (LapackeDgesvd)functions[7]; - this->lapackeSgesdd = (LapackeSgesdd)functions[8]; - this->lapackeDgesdd = (LapackeDgesdd)functions[9]; - } - - void BlasHelper::initializeDeviceFunctions(Nd4jPointer *functions) { - nd4j_debug("Initializing device BLAS\n",""); - - /* - this->cublasSgemv = (CublasSgemv)functions[0]; - this->cublasDgemv = (CublasDgemv)functions[1]; - this->cublasHgemm = (CublasHgemm)functions[2]; - this->cublasSgemm = (CublasSgemm)functions[3]; - this->cublasDgemm = (CublasDgemm)functions[4]; - this->cublasSgemmEx = (CublasSgemmEx)functions[5]; - this->cublasHgemmBatched = (CublasHgemmBatched)functions[6]; - this->cublasSgemmBatched = (CublasSgemmBatched)functions[7]; - this->cublasDgemmBatched = (CublasDgemmBatched)functions[8]; - this->cusolverDnSgesvdBufferSize = (CusolverDnSgesvdBufferSize)functions[9]; - this->cusolverDnDgesvdBufferSize = (CusolverDnDgesvdBufferSize)functions[10]; - this->cusolverDnSgesvd = (CusolverDnSgesvd)functions[11]; - this->cusolverDnDgesvd = (CusolverDnDgesvd)functions[12]; - */ - } - - - template <> - bool BlasHelper::hasGEMV() { - if (sd::Environment::getInstance()->blasFallback()) - return false; +BlasHelper* BlasHelper::getInstance() { + if (_instance == 0) _instance = new BlasHelper(); + return _instance; +} + +void BlasHelper::initializeFunctions(Nd4jPointer* functions) { + nd4j_debug("Initializing BLAS\n", ""); + + _hasSgemv = functions[0] != nullptr; + _hasSgemm = functions[2] != nullptr; + + _hasDgemv = functions[1] != nullptr; + _hasDgemm = functions[3] != nullptr; + + _hasSgemmBatch = functions[4] != nullptr; + _hasDgemmBatch = functions[5] != nullptr; + + this->cblasSgemv = (CblasSgemv)functions[0]; + this->cblasDgemv = (CblasDgemv)functions[1]; + this->cblasSgemm = (CblasSgemm)functions[2]; + this->cblasDgemm = (CblasDgemm)functions[3]; + this->cblasSgemmBatch = (CblasSgemmBatch)functions[4]; + this->cblasDgemmBatch = (CblasDgemmBatch)functions[5]; + this->lapackeSgesvd = (LapackeSgesvd)functions[6]; + this->lapackeDgesvd = (LapackeDgesvd)functions[7]; + this->lapackeSgesdd = (LapackeSgesdd)functions[8]; + this->lapackeDgesdd = (LapackeDgesdd)functions[9]; +} + +void BlasHelper::initializeDeviceFunctions(Nd4jPointer* functions) { + nd4j_debug("Initializing device BLAS\n", ""); + + /* + this->cublasSgemv = (CublasSgemv)functions[0]; + this->cublasDgemv = (CublasDgemv)functions[1]; + this->cublasHgemm = (CublasHgemm)functions[2]; + this->cublasSgemm = (CublasSgemm)functions[3]; + this->cublasDgemm = (CublasDgemm)functions[4]; + this->cublasSgemmEx = (CublasSgemmEx)functions[5]; + this->cublasHgemmBatched = (CublasHgemmBatched)functions[6]; + this->cublasSgemmBatched = (CublasSgemmBatched)functions[7]; + this->cublasDgemmBatched = (CublasDgemmBatched)functions[8]; + this->cusolverDnSgesvdBufferSize = (CusolverDnSgesvdBufferSize)functions[9]; + this->cusolverDnDgesvdBufferSize = (CusolverDnDgesvdBufferSize)functions[10]; + this->cusolverDnSgesvd = (CusolverDnSgesvd)functions[11]; + this->cusolverDnDgesvd = (CusolverDnDgesvd)functions[12]; + */ +} + +template <> +bool BlasHelper::hasGEMV() { + if (sd::Environment::getInstance()->blasFallback()) return false; #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; + return true; #else - return _hasSgemv; + return _hasSgemv; #endif - } +} - template <> - bool BlasHelper::hasGEMV() { - if (sd::Environment::getInstance()->blasFallback()) - return false; +template <> +bool BlasHelper::hasGEMV() { + if (sd::Environment::getInstance()->blasFallback()) return false; #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; + return true; +#else + return _hasDgemv; +#endif +} + +template <> +bool BlasHelper::hasGEMV() { + return false; +} + +template <> +bool BlasHelper::hasGEMV() { + return false; +} + +template <> +bool BlasHelper::hasGEMV() { + return false; +} + +template <> +bool BlasHelper::hasGEMV() { + return false; +} + +template <> +bool BlasHelper::hasGEMV() { + return false; +} + +template <> +bool BlasHelper::hasGEMV() { + return false; +} + +template <> +bool BlasHelper::hasGEMV() { + return false; +} + +template <> +bool BlasHelper::hasGEMV() { + return false; +} + +bool BlasHelper::hasGEMV(const sd::DataType dtype) { + if (dtype == DataType::FLOAT32) { + if (sd::Environment::getInstance()->blasFallback()) return false; + +#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) + return true; +#else + return _hasSgemv; +#endif + } + if (dtype == DataType::DOUBLE) { + if (sd::Environment::getInstance()->blasFallback()) return false; + +#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) + return true; #else return _hasDgemv; #endif - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - template <> - bool BlasHelper::hasGEMV() { - return false; - } - - bool BlasHelper::hasGEMV(const sd::DataType dtype) { - if(dtype == DataType::FLOAT32) { - if (sd::Environment::getInstance()->blasFallback()) - return false; - - #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; - #else - return _hasSgemv; - #endif - } - if(dtype == DataType::DOUBLE) { - if (sd::Environment::getInstance()->blasFallback()) - return false; - - #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; - #else - return _hasDgemv; - #endif - } - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - if (sd::Environment::getInstance()->blasFallback()) - return false; + } + return false; +} + +template <> +bool BlasHelper::hasGEMM() { + if (sd::Environment::getInstance()->blasFallback()) return false; #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; + return true; #else - return _hasSgemm; + return _hasSgemm; +#endif +} + +template <> +bool BlasHelper::hasGEMM() { + if (sd::Environment::getInstance()->blasFallback()) return false; + +#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) + return true; +#else + return _hasDgemm; #endif - } +} + +template <> +bool BlasHelper::hasGEMM() { + return false; +} - template <> - bool BlasHelper::hasGEMM() { - if (sd::Environment::getInstance()->blasFallback()) - return false; +template <> +bool BlasHelper::hasGEMM() { + return false; +} + +template <> +bool BlasHelper::hasGEMM() { + return false; +} + +template <> +bool BlasHelper::hasGEMM() { + return false; +} + +template <> +bool BlasHelper::hasGEMM() { + return false; +} + +template <> +bool BlasHelper::hasGEMM() { + return false; +} + +template <> +bool BlasHelper::hasGEMM() { + return false; +} + +template <> +bool BlasHelper::hasGEMM() { + return false; +} + +bool BlasHelper::hasGEMM(const sd::DataType dtype) { + if (dtype == DataType::FLOAT32) { + if (sd::Environment::getInstance()->blasFallback()) return false; + +#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) + return true; +#else + return _hasSgemm; +#endif + } + if (dtype == DataType::DOUBLE) { + if (sd::Environment::getInstance()->blasFallback()) return false; #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; + return true; #else return _hasDgemm; #endif - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - template <> - bool BlasHelper::hasGEMM() { - return false; - } - - bool BlasHelper:: hasGEMM(const sd::DataType dtype) { - if(dtype == DataType::FLOAT32) { - if (sd::Environment::getInstance()->blasFallback()) - return false; - - #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; - #else - return _hasSgemm; - #endif - } - if(dtype == DataType::DOUBLE) { - if (sd::Environment::getInstance()->blasFallback()) - return false; - - #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return true; - #else - return _hasDgemm; - #endif - } - return false; - } - - - template <> - bool BlasHelper::hasBatchedGEMM() { - if (sd::Environment::getInstance()->blasFallback()) - return false; - - return _hasSgemmBatch; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - if (sd::Environment::getInstance()->blasFallback()) - return false; - - return _hasDgemmBatch; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - template <> - bool BlasHelper::hasBatchedGEMM() { - return false; - } - - CblasSgemv BlasHelper::sgemv() { -#if defined(__EXTERNAL_BLAS__)|| defined(HAVE_OPENBLAS) - return (CblasSgemv)&cblas_sgemv; + } + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + if (sd::Environment::getInstance()->blasFallback()) return false; + + return _hasSgemmBatch; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + if (sd::Environment::getInstance()->blasFallback()) return false; + + return _hasDgemmBatch; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +template <> +bool BlasHelper::hasBatchedGEMM() { + return false; +} + +CblasSgemv BlasHelper::sgemv() { +#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) + return (CblasSgemv)&cblas_sgemv; #else - return this->cblasSgemv; + return this->cblasSgemv; #endif - } - CblasDgemv BlasHelper::dgemv() { +} +CblasDgemv BlasHelper::dgemv() { #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return (CblasDgemv)&cblas_dgemv; + return (CblasDgemv)&cblas_dgemv; #else - return this->cblasDgemv; + return this->cblasDgemv; #endif - } +} - CblasSgemm BlasHelper::sgemm() { +CblasSgemm BlasHelper::sgemm() { #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return (CblasSgemm)&cblas_sgemm; + return (CblasSgemm)&cblas_sgemm; #else - return this->cblasSgemm; + return this->cblasSgemm; #endif - } +} - CblasDgemm BlasHelper::dgemm() { +CblasDgemm BlasHelper::dgemm() { #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) - return (CblasDgemm)&cblas_dgemm; + return (CblasDgemm)&cblas_dgemm; #else - return this->cblasDgemm; + return this->cblasDgemm; #endif - } +} - CblasSgemmBatch BlasHelper::sgemmBatched() { - return this->cblasSgemmBatch; - } +CblasSgemmBatch BlasHelper::sgemmBatched() { return this->cblasSgemmBatch; } - CblasDgemmBatch BlasHelper::dgemmBatched() { - return this->cblasDgemmBatch; - } +CblasDgemmBatch BlasHelper::dgemmBatched() { return this->cblasDgemmBatch; } - LapackeSgesvd BlasHelper::sgesvd() { - return this->lapackeSgesvd; - } +LapackeSgesvd BlasHelper::sgesvd() { return this->lapackeSgesvd; } - LapackeDgesvd BlasHelper::dgesvd() { - return this->lapackeDgesvd; - } +LapackeDgesvd BlasHelper::dgesvd() { return this->lapackeDgesvd; } - LapackeSgesdd BlasHelper::sgesdd() { - return this->lapackeSgesdd; - } +LapackeSgesdd BlasHelper::sgesdd() { return this->lapackeSgesdd; } - LapackeDgesdd BlasHelper::dgesdd() { - return this->lapackeDgesdd; - } +LapackeDgesdd BlasHelper::dgesdd() { return this->lapackeDgesdd; } - // destructor - BlasHelper::~BlasHelper() noexcept { } +// destructor +BlasHelper::~BlasHelper() noexcept {} - BlasHelper* BlasHelper::_instance = 0; -} +BlasHelper* BlasHelper::_instance = 0; +} // namespace sd diff --git a/libnd4j/include/helpers/impl/CudaLaunchHelper.cpp b/libnd4j/include/helpers/impl/CudaLaunchHelper.cpp index d0bcce11e62b..3c4b91d7fb5d 100644 --- a/libnd4j/include/helpers/impl/CudaLaunchHelper.cpp +++ b/libnd4j/include/helpers/impl/CudaLaunchHelper.cpp @@ -22,20 +22,20 @@ #include namespace sd { - Triple CudaLaunchHelper::getFlatLaunchParams(Nd4jLong length, int SM, int CORES, int SHARED_MEMORY) { - // TODO: to be implemented - Triple triple(1, 2, 3); +Triple CudaLaunchHelper::getFlatLaunchParams(Nd4jLong length, int SM, int CORES, + int SHARED_MEMORY) { + // TODO: to be implemented + Triple triple(1, 2, 3); - return triple; - } + return triple; +} - int CudaLaunchHelper::getReductionBlocks(Nd4jLong xLength, int blockSize) { - int div = xLength / blockSize; - int can = sd::math::nd4j_max(div, 1); - if (xLength % blockSize != 0 && xLength > blockSize) - can++; +int CudaLaunchHelper::getReductionBlocks(Nd4jLong xLength, int blockSize) { + int div = xLength / blockSize; + int can = sd::math::nd4j_max(div, 1); + if (xLength % blockSize != 0 && xLength > blockSize) can++; - // not more then 512 blocks - return sd::math::nd4j_min(can, 512); - } + // not more then 512 blocks + return sd::math::nd4j_min(can, 512); } +} // namespace sd diff --git a/libnd4j/include/helpers/impl/DebugHelper.cpp b/libnd4j/include/helpers/impl/DebugHelper.cpp index d24068a65484..67d8c4816aad 100644 --- a/libnd4j/include/helpers/impl/DebugHelper.cpp +++ b/libnd4j/include/helpers/impl/DebugHelper.cpp @@ -18,92 +18,97 @@ // Created by raver119 on 20/04/18. // -#include #include #include -#include -#include #include +#include +#include +#include namespace sd { - DebugInfo DebugHelper::debugStatistics(NDArray const* input) { - DebugInfo info; - DebugHelper::retrieveDebugStatistics(&info, input); - return info; - } - void - DebugHelper::retrieveDebugStatistics(DebugInfo* info, NDArray const* input) { - if (nullptr == info) - return; - - info->_minValue = 0.; - info->_maxValue = -1; - info->_meanValue = 0.; - info->_stdDevValue = 1.; - info->_zeroCount = 0; - info->_positiveCount = 0; - info->_negativeCount = 0; - info->_infCount = 0; - info->_nanCount = 0; - if (input->lengthOf() == 1) { // scalar case - info->_minValue = input->e(0); - info->_maxValue = info->_minValue; - info->_meanValue = info->_minValue; - info->_stdDevValue = info->_minValue; - info->_zeroCount = sd::math::nd4j_abs(input->e(0)) > 0.00001? 0: 1; - info->_positiveCount = input->e(0) > 0?1:0; - info->_negativeCount = input->e(0) < 0?1:0; - info->_infCount = sd::math::nd4j_isinf(input->e(0)); - info->_nanCount = sd::math::nd4j_isnan(input->e(0)); - } - else if (input->lengthOf() > 0) { - // TO DO: here processing for all elements with array - auto _minValue = input->e(0); - auto _maxValue = input->e(0); - auto _meanValue = input->e(0); - auto _stdDevValue = 0.; //info->_minValue; - auto _zeroCount = sd::math::nd4j_abs(input->e(0)) > 0.00001? 0L : 1L; - auto _positiveCount = input->e(0) > 0? 1L : 0L; - auto _negativeCount = input->e(0) < 0? 1L : 0L; - auto _infCount = sd::math::nd4j_isinf(input->e(0)) ? 1L : 0L; - auto _nanCount = sd::math::nd4j_isnan(input->e(0)) ? 1L : 0L; +DebugInfo DebugHelper::debugStatistics(NDArray const* input) { + DebugInfo info; + DebugHelper::retrieveDebugStatistics(&info, input); + return info; +} +void DebugHelper::retrieveDebugStatistics(DebugInfo* info, + NDArray const* input) { + if (nullptr == info) return; -PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) reduction(+:_nanCount,_infCount,_meanValue,_zeroCount,_positiveCount,_negativeCount) reduction(min:_minValue) reduction(max:_maxValue)) - for (Nd4jLong e = 1; e < input->lengthOf(); e++) { - auto current = input->e(e); - auto n = e + 1.; -// auto delta = current - _meanValue; -// auto delta2 = delta * delta; - _minValue = sd::math::nd4j_min(current, _minValue); - _maxValue = sd::math::nd4j_max(current, _maxValue); + info->_minValue = 0.; + info->_maxValue = -1; + info->_meanValue = 0.; + info->_stdDevValue = 1.; + info->_zeroCount = 0; + info->_positiveCount = 0; + info->_negativeCount = 0; + info->_infCount = 0; + info->_nanCount = 0; + if (input->lengthOf() == 1) { // scalar case + info->_minValue = input->e(0); + info->_maxValue = info->_minValue; + info->_meanValue = info->_minValue; + info->_stdDevValue = info->_minValue; + info->_zeroCount = + sd::math::nd4j_abs(input->e(0)) > 0.00001 ? 0 : 1; + info->_positiveCount = input->e(0) > 0 ? 1 : 0; + info->_negativeCount = input->e(0) < 0 ? 1 : 0; + info->_infCount = sd::math::nd4j_isinf(input->e(0)); + info->_nanCount = sd::math::nd4j_isnan(input->e(0)); + } else if (input->lengthOf() > 0) { + // TO DO: here processing for all elements with array + auto _minValue = input->e(0); + auto _maxValue = input->e(0); + auto _meanValue = input->e(0); + auto _stdDevValue = 0.; // info->_minValue; + auto _zeroCount = + sd::math::nd4j_abs(input->e(0)) > 0.00001 ? 0L : 1L; + auto _positiveCount = input->e(0) > 0 ? 1L : 0L; + auto _negativeCount = input->e(0) < 0 ? 1L : 0L; + auto _infCount = sd::math::nd4j_isinf(input->e(0)) ? 1L : 0L; + auto _nanCount = sd::math::nd4j_isnan(input->e(0)) ? 1L : 0L; - _meanValue += current; - //_meanValue += delta / n; // this is a perfect formula but not working with omp in this notation - //_stdDevValue += delta2 * e / n; + PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) reduction(+:_nanCount,_infCount,_meanValue,_zeroCount,_positiveCount,_negativeCount) reduction(min:_minValue) reduction(max:_maxValue)) + for (Nd4jLong e = 1; e < input->lengthOf(); e++) { + auto current = input->e(e); + auto n = e + 1.; + // auto delta = current - _meanValue; + // auto delta2 = delta * delta; + _minValue = sd::math::nd4j_min(current, _minValue); + _maxValue = sd::math::nd4j_max(current, _maxValue); - _zeroCount += sd::math::nd4j_abs(current) > 0.00001 ? 0 : 1; - _positiveCount += current > 0 ? 1 : 0; - _negativeCount += current < 0 ? 1 : 0; - _infCount += sd::math::nd4j_isinf(current); - _nanCount += sd::math::nd4j_isnan(current); - } - *info = {_minValue, _maxValue, _meanValue / input->lengthOf(), _stdDevValue, _zeroCount, _positiveCount, _negativeCount, _infCount, _nanCount}; - _stdDevValue = 0; //math::nd4j_sqrt(info->_stdDevValue / (input->lengthOf() - 1)); + _meanValue += current; + //_meanValue += delta / n; // this is a perfect formula but not working + //with omp in this notation _stdDevValue += delta2 * e / n; - auto func = PRAGMA_REDUCE_DOUBLE { - auto _stdDevValue = 0.0; - for (auto e = start; e < stop; e++) { - double current = input->e(e); - _stdDevValue += (info->_meanValue - current) * (info->_meanValue - current); //info->_minValue; - } + _zeroCount += sd::math::nd4j_abs(current) > 0.00001 ? 0 : 1; + _positiveCount += current > 0 ? 1 : 0; + _negativeCount += current < 0 ? 1 : 0; + _infCount += sd::math::nd4j_isinf(current); + _nanCount += sd::math::nd4j_isnan(current); + } + *info = {_minValue, _maxValue, _meanValue / input->lengthOf(), + _stdDevValue, _zeroCount, _positiveCount, + _negativeCount, _infCount, _nanCount}; + _stdDevValue = 0; // math::nd4j_sqrt(info->_stdDevValue / + // (input->lengthOf() - 1)); - return _stdDevValue; - }; - _stdDevValue = samediff::Threads::parallel_double(func, LAMBDA_AD { return _old + _new; }, 0, input->lengthOf()); + auto func = PRAGMA_REDUCE_DOUBLE { + auto _stdDevValue = 0.0; + for (auto e = start; e < stop; e++) { + double current = input->e(e); + _stdDevValue += (info->_meanValue - current) * + (info->_meanValue - current); // info->_minValue; + } - info->_stdDevValue = math::nd4j_sqrt(_stdDevValue / input->lengthOf()); + return _stdDevValue; + }; + _stdDevValue = samediff::Threads::parallel_double( + func, LAMBDA_AD { return _old + _new; }, 0, input->lengthOf()); - } -// else - no statistics for empty - } + info->_stdDevValue = + math::nd4j_sqrt(_stdDevValue / input->lengthOf()); + } + // else - no statistics for empty } +} // namespace sd diff --git a/libnd4j/include/helpers/impl/EnumUtils.cpp b/libnd4j/include/helpers/impl/EnumUtils.cpp index a18592d0b962..c848a7407d94 100644 --- a/libnd4j/include/helpers/impl/EnumUtils.cpp +++ b/libnd4j/include/helpers/impl/EnumUtils.cpp @@ -24,55 +24,91 @@ using namespace sd::graph; namespace sd { - const char * EnumUtils::_VariableTypeToString(sd::graph::VariableType variableType) { - switch (variableType) { - case NDARRAY: return "NDARRAY"; - case ARRAY_LIST: return "ARRAY_LIST"; - case FLOW: return "FLOW"; - default: return "UNKNOWN VariableType"; - } - } +const char* EnumUtils::_VariableTypeToString( + sd::graph::VariableType variableType) { + switch (variableType) { + case NDARRAY: + return "NDARRAY"; + case ARRAY_LIST: + return "ARRAY_LIST"; + case FLOW: + return "FLOW"; + default: + return "UNKNOWN VariableType"; + } +} - const char * EnumUtils::_OpTypeToString(sd::graph::OpType opType) { - switch(opType) { - case OpType_REDUCE_SAME: return "REDUCE_SAME"; - case OpType_REDUCE_BOOL: return "REDUCE_BOOL"; - case OpType_REDUCE_LONG: return "REDUCE_LONG"; - case OpType_REDUCE_FLOAT: return "REDUCE_FLOAT"; - case OpType_BOOLEAN: return "BOOLEAN"; - case OpType_BROADCAST: return "BROADCAST"; - case OpType_BROADCAST_BOOL: return "BROADCAST_BOOL"; - case OpType_PAIRWISE: return "PAIRWISE"; - case OpType_PAIRWISE_BOOL: return "PAIRWISE_BOOL"; - case OpType_CUSTOM: return "CUSTOM"; - case OpType_LOGIC: return "LOGIC"; - case OpType_TRANSFORM_SAME: return "TRANSFORM_SAME"; - case OpType_TRANSFORM_FLOAT: return "TRANSFORM_FLOAT"; - case OpType_TRANSFORM_BOOL: return "TRANSFORM_BOOL"; - case OpType_TRANSFORM_STRICT: return "TRANSFORM_STRICT"; - case OpType_TRANSFORM_ANY: return "TRANSFORM_ANY"; - case OpType_INDEX_REDUCE: return "INDEX_ACCUMULATION"; - case OpType_SCALAR: return "SCALAR"; - case OpType_SCALAR_BOOL: return "SCALAR_BOOL"; - case OpType_SHAPE: return "SHAPE"; - default: return "UNKNOWN OpType"; - } - } +const char* EnumUtils::_OpTypeToString(sd::graph::OpType opType) { + switch (opType) { + case OpType_REDUCE_SAME: + return "REDUCE_SAME"; + case OpType_REDUCE_BOOL: + return "REDUCE_BOOL"; + case OpType_REDUCE_LONG: + return "REDUCE_LONG"; + case OpType_REDUCE_FLOAT: + return "REDUCE_FLOAT"; + case OpType_BOOLEAN: + return "BOOLEAN"; + case OpType_BROADCAST: + return "BROADCAST"; + case OpType_BROADCAST_BOOL: + return "BROADCAST_BOOL"; + case OpType_PAIRWISE: + return "PAIRWISE"; + case OpType_PAIRWISE_BOOL: + return "PAIRWISE_BOOL"; + case OpType_CUSTOM: + return "CUSTOM"; + case OpType_LOGIC: + return "LOGIC"; + case OpType_TRANSFORM_SAME: + return "TRANSFORM_SAME"; + case OpType_TRANSFORM_FLOAT: + return "TRANSFORM_FLOAT"; + case OpType_TRANSFORM_BOOL: + return "TRANSFORM_BOOL"; + case OpType_TRANSFORM_STRICT: + return "TRANSFORM_STRICT"; + case OpType_TRANSFORM_ANY: + return "TRANSFORM_ANY"; + case OpType_INDEX_REDUCE: + return "INDEX_ACCUMULATION"; + case OpType_SCALAR: + return "SCALAR"; + case OpType_SCALAR_BOOL: + return "SCALAR_BOOL"; + case OpType_SHAPE: + return "SHAPE"; + default: + return "UNKNOWN OpType"; + } +} - - const char * EnumUtils::_LogicOpToString(int opNum) { - switch(opNum) { - case 0: return "WHILE"; - case 10: return "SCOPE"; - case 20: return "CONDITIONAL"; - case 30: return "SWITCH"; - case 40: return "RETURN"; - case 60: return "MERGE"; - case 70: return "LOOP_COND"; - case 80: return "NEXT_ITERATION"; - case 90: return "EXIT"; - case 100: return "ENTER"; - default: return "UNKNOWN OPERATION"; - } - } -} \ No newline at end of file +const char* EnumUtils::_LogicOpToString(int opNum) { + switch (opNum) { + case 0: + return "WHILE"; + case 10: + return "SCOPE"; + case 20: + return "CONDITIONAL"; + case 30: + return "SWITCH"; + case 40: + return "RETURN"; + case 60: + return "MERGE"; + case 70: + return "LOOP_COND"; + case 80: + return "NEXT_ITERATION"; + case 90: + return "EXIT"; + case 100: + return "ENTER"; + default: + return "UNKNOWN OPERATION"; + } +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/FileUtils.cpp b/libnd4j/include/helpers/impl/FileUtils.cpp index 15c426a4b9da..c0e02db27225 100644 --- a/libnd4j/include/helpers/impl/FileUtils.cpp +++ b/libnd4j/include/helpers/impl/FileUtils.cpp @@ -20,20 +20,19 @@ #include #include -#include #include +#include namespace sd { - bool FileUtils::fileExists(const char *filename) { - if (filename == nullptr) - return false; +bool FileUtils::fileExists(const char *filename) { + if (filename == nullptr) return false; - return file_exists(filename); - } + return file_exists(filename); +} - int64_t FileUtils::fileSize(const char *filename) { - struct stat stat_buf; - int rc = stat(filename, &stat_buf); - return rc == 0 ? stat_buf.st_size : -1; - } -} \ No newline at end of file +int64_t FileUtils::fileSize(const char *filename) { + struct stat stat_buf; + int rc = stat(filename, &stat_buf); + return rc == 0 ? stat_buf.st_size : -1; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/GradCheck.cpp b/libnd4j/include/helpers/impl/GradCheck.cpp index 1db001b544d6..e9ac5f0d5c98 100644 --- a/libnd4j/include/helpers/impl/GradCheck.cpp +++ b/libnd4j/include/helpers/impl/GradCheck.cpp @@ -18,134 +18,156 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 16.07.2018 // -#include #include - +#include namespace sd { ////////////////////////////////////////////////////////////////////////// -void GradCheck::fillGradArrays(const LossFunc loss, const std::vector& gradArrs) { - - const int numInGradArrs = gradArrs.size(); - - // fill input gradient arrays in accordance to type of loss function - switch(loss) { - - case MEAN: - for(int i = 0; i < numInGradArrs; ++i) - *gradArrs[i] = 1. / gradArrs[i]->lengthOf(); - break; - - case SUM: - for(int i = 0; i < numInGradArrs; ++i) - *gradArrs[i] = 1.; - break; - - default: - throw std::invalid_argument("GradCheck::fillGradArrays: invalid type of loss function !"); - } +void GradCheck::fillGradArrays(const LossFunc loss, + const std::vector& gradArrs) { + const int numInGradArrs = gradArrs.size(); + + // fill input gradient arrays in accordance to type of loss function + switch (loss) { + case MEAN: + for (int i = 0; i < numInGradArrs; ++i) + *gradArrs[i] = 1. / gradArrs[i]->lengthOf(); + break; + + case SUM: + for (int i = 0; i < numInGradArrs; ++i) *gradArrs[i] = 1.; + break; + + default: + throw std::invalid_argument( + "GradCheck::fillGradArrays: invalid type of loss function !"); + } } ////////////////////////////////////////////////////////////////////////// -bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, - const std::vector& whatArrsToCheck, const std::vector& idxRange, const LossFunc loss) { - - const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP - const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP - const std::vector& inArrsFF = argsHolderFF.getInArrs(); - const std::vector& inArrsBP = argsHolderBP.getInArrs(); - - // fill input gradient arrays in accordance to kind of loss function - fillGradArrays(loss, std::vector(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP])); - - // back prop pass - ResultSet outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF; - - NDArray tmpScalar(sd::DataType::DOUBLE, inArrsFF[0]->getContext()); // scalar = 0 - - for(int i = 0; i < numInArrsFF; ++i) { // loop through input array - - if(!whatArrsToCheck.empty() && static_cast(whatArrsToCheck[i]) == false) - continue; - - const Nd4jLong idxStart = static_cast(idxRange[0] * inArrsFF[i]->lengthOf()); - const Nd4jLong idxEnd = static_cast(idxRange[1] * inArrsFF[i]->lengthOf()); - - for(Nd4jLong j = idxStart; j < idxEnd; ++j) { // loop through all elements for current array - - const double orig = inArrsFF[i]->e(j); - - // add epsilon, feed forward - inArrsFF[i]->p(j, orig + EPSILON); - ResultSet outArrsFF = opFF.execute(argsHolderFF); - int numOutArrs = outArrsFF.size(); - double scorePlus = 0.; - - for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays - if(loss == SUM) - outArrsFF.at(k).reduceNumber(reduce::Sum, tmpScalar); - else - outArrsFF.at(k).reduceNumber(reduce::Mean, tmpScalar); - scorePlus += tmpScalar.e(0); - } - - // subtract epsilon, feed forward - inArrsFF[i]->p(j, orig - EPSILON); - outArrsFF = opFF.execute(argsHolderFF); - double scoreMinus = 0.; - - for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays - if(loss == SUM) - outArrsFF.at(k).reduceNumber(reduce::Sum, tmpScalar); - else - outArrsFF.at(k).reduceNumber(reduce::Mean, tmpScalar); - scoreMinus += tmpScalar.e(0); - } - - // restore initial element value - inArrsFF[i]->p(j, orig); - - // calculate numerical gradient - const double numericalGrad = (scorePlus - scoreMinus) / (2 * EPSILON); - if(std::isnan(numericalGrad) || std::isinf(numericalGrad)) { - printf("GradCheck::checkGrad: got wrong value for numerical gradient for input array # %i and its element at position %lld ! \n", i, j); - throw std::runtime_error(""); - } - - // get analytical gradient - const double analyticGrad = outArrsBP.at(i).e(j); - if(std::isnan(analyticGrad) || std::isinf(analyticGrad)) { - printf("GradCheck::checkGrad: got wrong value for analytical gradient for input array # %i and its element at position %lld ! \n", i, j); - throw std::runtime_error(""); - } - - // printf("%lld: num = %.15f, ana = %.15f\n", j, numericalGrad, analyticGrad); - - // calculate relative error - double relError; - if(numericalGrad == 0. && analyticGrad == 0.) - relError = 0.; - else - relError = math::nd4j_abs(analyticGrad - numericalGrad) / (math::nd4j_abs(analyticGrad) + math::nd4j_abs(numericalGrad)); - - // verify result - if(relError > MAXRELERR || std::isnan(relError)) { - - if(math::nd4j_abs(analyticGrad - numericalGrad) < MINABSERR) - continue; - printf("numericalGrad = %.15f, analyticGrad = %.15f \n", numericalGrad, analyticGrad); - printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j); - return false; - } - } - } - - return true; -} - - - +bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, + const OpArgsHolder& argsHolderFF, + const OpArgsHolder& argsHolderBP, + const std::vector& whatArrsToCheck, + const std::vector& idxRange, + const LossFunc loss) { + const int numInArrsFF = + argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of + // output arrays in opBP + const int numInGradArrsBP = + argsHolderBP.getNumInArrs() - + numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + + // numInGradArrsBP + const std::vector& inArrsFF = argsHolderFF.getInArrs(); + const std::vector& inArrsBP = argsHolderBP.getInArrs(); + + // fill input gradient arrays in accordance to kind of loss function + fillGradArrays( + loss, std::vector(&inArrsBP[numInArrsFF], + &inArrsBP[numInArrsFF + numInGradArrsBP])); + + // back prop pass + ResultSet outArrsBP = opBP.execute( + argsHolderBP); // number of output arrays in back prop = numInArrsFF; + + NDArray tmpScalar(sd::DataType::DOUBLE, + inArrsFF[0]->getContext()); // scalar = 0 + + for (int i = 0; i < numInArrsFF; ++i) { // loop through input array + + if (!whatArrsToCheck.empty() && + static_cast(whatArrsToCheck[i]) == false) + continue; + + const Nd4jLong idxStart = + static_cast(idxRange[0] * inArrsFF[i]->lengthOf()); + const Nd4jLong idxEnd = + static_cast(idxRange[1] * inArrsFF[i]->lengthOf()); + + for (Nd4jLong j = idxStart; j < idxEnd; + ++j) { // loop through all elements for current array + + const double orig = inArrsFF[i]->e(j); + + // add epsilon, feed forward + inArrsFF[i]->p(j, orig + EPSILON); + ResultSet outArrsFF = opFF.execute(argsHolderFF); + int numOutArrs = outArrsFF.size(); + double scorePlus = 0.; + + for (int k = 0; k < numOutArrs; ++k) { // loop through output arrays + if (loss == SUM) + outArrsFF.at(k).reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k).reduceNumber(reduce::Mean, tmpScalar); + scorePlus += tmpScalar.e(0); + } + + // subtract epsilon, feed forward + inArrsFF[i]->p(j, orig - EPSILON); + outArrsFF = opFF.execute(argsHolderFF); + double scoreMinus = 0.; + + for (int k = 0; k < numOutArrs; ++k) { // loop through output arrays + if (loss == SUM) + outArrsFF.at(k).reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k).reduceNumber(reduce::Mean, tmpScalar); + scoreMinus += tmpScalar.e(0); + } + + // restore initial element value + inArrsFF[i]->p(j, orig); + + // calculate numerical gradient + const double numericalGrad = (scorePlus - scoreMinus) / (2 * EPSILON); + if (std::isnan(numericalGrad) || std::isinf(numericalGrad)) { + printf( + "GradCheck::checkGrad: got wrong value for numerical gradient for " + "input array # %i and its element at position %lld ! \n", + i, j); + throw std::runtime_error(""); + } + + // get analytical gradient + const double analyticGrad = outArrsBP.at(i).e(j); + if (std::isnan(analyticGrad) || std::isinf(analyticGrad)) { + printf( + "GradCheck::checkGrad: got wrong value for analytical gradient for " + "input array # %i and its element at position %lld ! \n", + i, j); + throw std::runtime_error(""); + } + + // printf("%lld: num = %.15f, ana = %.15f\n", j, numericalGrad, + // analyticGrad); + + // calculate relative error + double relError; + if (numericalGrad == 0. && analyticGrad == 0.) + relError = 0.; + else + relError = math::nd4j_abs(analyticGrad - numericalGrad) / + (math::nd4j_abs(analyticGrad) + + math::nd4j_abs(numericalGrad)); + + // verify result + if (relError > MAXRELERR || std::isnan(relError)) { + if (math::nd4j_abs(analyticGrad - numericalGrad) < MINABSERR) + continue; + printf("numericalGrad = %.15f, analyticGrad = %.15f \n", numericalGrad, + analyticGrad); + printf( + "GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for " + "input array # %i and its element at position %lld ! \n", + relError, MAXRELERR, i, j); + return false; + } + } + } + + return true; } - +} // namespace sd diff --git a/libnd4j/include/helpers/impl/MmulHelper.cpp b/libnd4j/include/helpers/impl/MmulHelper.cpp index 8e37fd530af7..f56755800b4d 100644 --- a/libnd4j/include/helpers/impl/MmulHelper.cpp +++ b/libnd4j/include/helpers/impl/MmulHelper.cpp @@ -22,292 +22,377 @@ #define LIBND4J_MMULHELPER_CPP #include "../MmulHelper.h" -#include -#include + #include +#include +#include namespace sd { ////////////////////////////////////////////////////////////////////////// -sd::NDArray* sd::MmulHelper::tensorDot(const sd::NDArray* A, const sd::NDArray* B, const std::initializer_list& axesA, const std::initializer_list& axesB) { - std::vector aA(axesA); - std::vector aB(axesB); - return tensorDot(A, B, aA, aB); +sd::NDArray* sd::MmulHelper::tensorDot( + const sd::NDArray* A, const sd::NDArray* B, + const std::initializer_list& axesA, + const std::initializer_list& axesB) { + std::vector aA(axesA); + std::vector aB(axesB); + return tensorDot(A, B, aA, aB); } ////////////////////////////////////////////////////////////////////////// -sd::NDArray* sd::MmulHelper::tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector& axes_0, const std::vector& axes_1) { - - std::vector permutAt, permutBt; - std::vector shapeAt, shapeBt; - - auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt); - - // check whether permutation is necessary - const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt)); - const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt)); - - // check whether reshape is necessary - const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt)); - const NDArray* bPR = bP->isSameShape(shapeAt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt)); - - NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0); - - c->reshapei(outShape); - - if(aP != aPR) - delete aPR; - if(bP != bPR) - delete bPR; - if(a != aP) - delete aP; - if(b != bP) - delete bP; - - return c; +sd::NDArray* sd::MmulHelper::tensorDot(const sd::NDArray* a, + const sd::NDArray* b, + const std::vector& axes_0, + const std::vector& axes_1) { + std::vector permutAt, permutBt; + std::vector shapeAt, shapeBt; + + auto outShape = ShapeUtils::evalShapeForTensorDot( + a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt); + + // check whether permutation is necessary + const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt)); + const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt)); + + // check whether reshape is necessary + const NDArray* aPR = aP->isSameShape(shapeAt) + ? aP + : new NDArray(aP->reshape(aP->ordering(), shapeAt)); + const NDArray* bPR = bP->isSameShape(shapeAt) + ? bP + : new NDArray(bP->reshape(bP->ordering(), shapeBt)); + + NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0); + + c->reshapei(outShape); + + if (aP != aPR) delete aPR; + if (bP != bPR) delete bPR; + if (a != aP) delete aP; + if (b != bP) delete bP; + + return c; } ////////////////////////////////////////////////////////////////////////// -void sd::MmulHelper::tensorDot(const sd::NDArray* a, const sd::NDArray* b, sd::NDArray* c, const std::vector& axes_a, const std::vector& axes_b, const std::vector& permutForC) { - - std::vector permutAt, permutBt; - std::vector shapeAt, shapeBt; - ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, shapeAt, shapeBt); - - // check whether permutation is required - NDArray* cP = permutForC.empty() ? c : new NDArray(c->permute(permutForC)); - - // check whether permutation is necessary - const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt)); - const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt)); - - // check whether reshape is necessary - const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt)); - const NDArray* bPR = bP->isSameShape(shapeAt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt)); - - std::vector requiredCshape = {aPR->sizeAt(0), bPR->sizeAt(1)}; - - NDArray* cPR = cP->isSameShape(requiredCshape) ? cP : new NDArray(cP->reshape(cP->ordering(), requiredCshape, false)); - - mmul(aPR, bPR, cPR, 1.0, 0.0); - - if(cPR->buffer() != cP->buffer() || cPR->specialBuffer() != cP->specialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->buffer() - cP->assign(cPR); - - if(aP != aPR) - delete aPR; - if(bP != bPR) - delete bPR; - if(a != aP) - delete aP; - if(b != bP) - delete bP; - - if(cP != cPR) - delete cPR; - if(c != cP) - delete cP; +void sd::MmulHelper::tensorDot(const sd::NDArray* a, const sd::NDArray* b, + sd::NDArray* c, const std::vector& axes_a, + const std::vector& axes_b, + const std::vector& permutForC) { + std::vector permutAt, permutBt; + std::vector shapeAt, shapeBt; + ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, + shapeAt, shapeBt); + + // check whether permutation is required + NDArray* cP = permutForC.empty() ? c : new NDArray(c->permute(permutForC)); + + // check whether permutation is necessary + const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt)); + const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt)); + + // check whether reshape is necessary + const NDArray* aPR = aP->isSameShape(shapeAt) + ? aP + : new NDArray(aP->reshape(aP->ordering(), shapeAt)); + const NDArray* bPR = bP->isSameShape(shapeAt) + ? bP + : new NDArray(bP->reshape(bP->ordering(), shapeBt)); + + std::vector requiredCshape = {aPR->sizeAt(0), bPR->sizeAt(1)}; + + NDArray* cPR = + cP->isSameShape(requiredCshape) + ? cP + : new NDArray(cP->reshape(cP->ordering(), requiredCshape, false)); + + mmul(aPR, bPR, cPR, 1.0, 0.0); + + if (cPR->buffer() != cP->buffer() || + cPR->specialBuffer() != + cP->specialBuffer()) // this means both permute and reshape have been + // performed on c, cP always points on + // c->buffer() + cP->assign(cPR); + + if (aP != aPR) delete aPR; + if (bP != bPR) delete bPR; + if (a != aP) delete aP; + if (b != bP) delete bP; + + if (cP != cPR) delete cPR; + if (c != cP) delete cP; } - #ifndef __JAVACPP_HACK__ ////////////////////////////////////////////////////////////////////////// -void sd::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c, const std::vector>& modifA, const std::vector>& modifB, const std::vector>& modifC) { - - NDArray *aPR(const_cast(a)), *bPR(const_cast(b)); - std::string whatToDoWithA, whatToDoWithB, whatToDoWithC; // "" - nothing; "p" - permutation; "r" - reshaping; "pr" - permutation+reshaping; "rp" - reshaping/permutation, and so on; if another string is produced - throw exception - - for(const auto& arr : modifA) - whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array - for(const auto& arr : modifB) - whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r"; - for(const auto& arr : modifC) - whatToDoWithC = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithC + "p" : whatToDoWithC + "r"; - - // first step for a array - if(!whatToDoWithA.empty()) - aPR = (whatToDoWithA[0] == 'p') ? new NDArray(a->permute(modifA[0])) : new NDArray(a->reshape(a->ordering(), modifA[0])); - // first step for b array - if(!whatToDoWithB.empty()) - bPR = (whatToDoWithB[0] == 'p') ? new NDArray(b->permute(modifB[0])) : new NDArray(b->reshape(b->ordering(), modifB[0])); - // rest steps for a array - for(int i = 1; i < whatToDoWithA.size(); ++i) - if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]); - // rest steps for b array - for(int i = 1; i < whatToDoWithB.size(); ++i) - if(whatToDoWithB[i] == 'p') bPR->permutei(modifB[i]); else bPR->reshapei(modifB[i]); - - // now work with c array - std::vector cArrs = {c}; - if(!whatToDoWithC.empty()) { - cArrs = std::vector(whatToDoWithC.size()+1, c); - for(int i = 0; i < cArrs.size()-1; ++i) - cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i], false)); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c +void sd::MmulHelper::tensorDot( + const NDArray* a, const NDArray* b, NDArray* c, + const std::vector>& modifA, + const std::vector>& modifB, + const std::vector>& modifC) { + NDArray *aPR(const_cast(a)), *bPR(const_cast(b)); + std::string whatToDoWithA, whatToDoWithB, + whatToDoWithC; // "" - nothing; "p" - permutation; "r" - reshaping; "pr" + // - permutation+reshaping; "rp" - reshaping/permutation, + // and so on; if another string is produced - throw + // exception + + for (const auto& arr : modifA) + whatToDoWithA = + (std::find(arr.begin(), arr.end(), 0) != arr.end()) + ? whatToDoWithA + "p" + : whatToDoWithA + + "r"; // when 0 is present in arr then it is permutation + // array, otherwise - it is reshaping array + for (const auto& arr : modifB) + whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) + ? whatToDoWithB + "p" + : whatToDoWithB + "r"; + for (const auto& arr : modifC) + whatToDoWithC = (std::find(arr.begin(), arr.end(), 0) != arr.end()) + ? whatToDoWithC + "p" + : whatToDoWithC + "r"; + + // first step for a array + if (!whatToDoWithA.empty()) + aPR = (whatToDoWithA[0] == 'p') + ? new NDArray(a->permute(modifA[0])) + : new NDArray(a->reshape(a->ordering(), modifA[0])); + // first step for b array + if (!whatToDoWithB.empty()) + bPR = (whatToDoWithB[0] == 'p') + ? new NDArray(b->permute(modifB[0])) + : new NDArray(b->reshape(b->ordering(), modifB[0])); + // rest steps for a array + for (int i = 1; i < whatToDoWithA.size(); ++i) + if (whatToDoWithA[i] == 'p') + aPR->permutei(modifA[i]); + else + aPR->reshapei(modifA[i]); + // rest steps for b array + for (int i = 1; i < whatToDoWithB.size(); ++i) + if (whatToDoWithB[i] == 'p') + bPR->permutei(modifB[i]); + else + bPR->reshapei(modifB[i]); + + // now work with c array + std::vector cArrs = {c}; + if (!whatToDoWithC.empty()) { + cArrs = std::vector(whatToDoWithC.size() + 1, c); + for (int i = 0; i < cArrs.size() - 1; ++i) + cArrs[i + 1] = + (whatToDoWithC[i] == 'p') + ? new NDArray(cArrs[i]->permute(modifC[i])) + : new NDArray(cArrs[i]->reshape( + c->ordering(), modifC[i], + false)); // since we ignore first element in cArrs (that is + // cArrs[0]) then it is always equal to c + } + + mmul(aPR, bPR, cArrs[cArrs.size() - 1], 1.0, 0.0); + + // check whether new buffer allocation was happened for c array + if (!whatToDoWithC.empty()) { + for (int i = cArrs.size() - 1; i > 0; --i) { + if (cArrs[i]->buffer() != cArrs[i - 1]->buffer() || + cArrs[i]->specialBuffer() != cArrs[i - 1]->specialBuffer()) + cArrs[i - 1]->assign(cArrs[i]); + delete cArrs[i]; } + } - mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0); - - // check whether new buffer allocation was happened for c array - if(!whatToDoWithC.empty()) { - for(int i = cArrs.size()-1; i > 0; --i) { - if(cArrs[i]->buffer() != cArrs[i-1]->buffer() || cArrs[i]->specialBuffer() != cArrs[i-1]->specialBuffer()) - cArrs[i-1]->assign(cArrs[i]); - delete cArrs[i]; - } - } - - if(aPR != a) - delete aPR; - if(bPR != b) - delete bPR; + if (aPR != a) delete aPR; + if (bPR != b) delete bPR; } ////////////////////////////////////////////////////////////////////////// -NDArray* sd::MmulHelper::tensorDot(const sd::NDArray* a, const sd::NDArray* b, const std::vector>& modifA, const std::vector>& modifB) { - - NDArray *aPR(const_cast(a)), *bPR(const_cast(b)); - std::string whatToDoWithA, whatToDoWithB; // "" - nothing; "p" - permutation only; "r" - reshaping only; "pr" - permutation+reshaping; "rp" - reshaping/permutation; another string - throw exception - - for(const auto& arr : modifA) - whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array - for(const auto& arr : modifB) - whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r"; - - // first step for a array - if(!whatToDoWithA.empty()) - aPR = (whatToDoWithA[0] == 'p') ? new NDArray(a->permute(modifA[0])) : new NDArray(a->reshape(a->ordering(), modifA[0])); - // first step for b array - if(!whatToDoWithB.empty()) - bPR = (whatToDoWithB[0] == 'p') ? new NDArray(b->permute(modifB[0])) : new NDArray(b->reshape(b->ordering(), modifB[0])); - // rest steps for a array - for(int i = 1; i < whatToDoWithA.size(); ++i) - if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]); - // rest steps for b array - for(int i = 1; i < whatToDoWithB.size(); ++i) - if(whatToDoWithB[i] == 'p') bPR->permutei(modifB[i]); else bPR->reshapei(modifB[i]); - - NDArray* result = mmul(aPR, bPR, nullptr, 1.0, 0.0); - - if(aPR != a) - delete aPR; - if(bPR != b) - delete bPR; - return result; +NDArray* sd::MmulHelper::tensorDot( + const sd::NDArray* a, const sd::NDArray* b, + const std::vector>& modifA, + const std::vector>& modifB) { + NDArray *aPR(const_cast(a)), *bPR(const_cast(b)); + std::string whatToDoWithA, + whatToDoWithB; // "" - nothing; "p" - permutation only; "r" - reshaping + // only; "pr" - permutation+reshaping; "rp" - + // reshaping/permutation; another string - throw exception + + for (const auto& arr : modifA) + whatToDoWithA = + (std::find(arr.begin(), arr.end(), 0) != arr.end()) + ? whatToDoWithA + "p" + : whatToDoWithA + + "r"; // when 0 is present in arr then it is permutation + // array, otherwise - it is reshaping array + for (const auto& arr : modifB) + whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) + ? whatToDoWithB + "p" + : whatToDoWithB + "r"; + + // first step for a array + if (!whatToDoWithA.empty()) + aPR = (whatToDoWithA[0] == 'p') + ? new NDArray(a->permute(modifA[0])) + : new NDArray(a->reshape(a->ordering(), modifA[0])); + // first step for b array + if (!whatToDoWithB.empty()) + bPR = (whatToDoWithB[0] == 'p') + ? new NDArray(b->permute(modifB[0])) + : new NDArray(b->reshape(b->ordering(), modifB[0])); + // rest steps for a array + for (int i = 1; i < whatToDoWithA.size(); ++i) + if (whatToDoWithA[i] == 'p') + aPR->permutei(modifA[i]); + else + aPR->reshapei(modifA[i]); + // rest steps for b array + for (int i = 1; i < whatToDoWithB.size(); ++i) + if (whatToDoWithB[i] == 'p') + bPR->permutei(modifB[i]); + else + bPR->reshapei(modifB[i]); + + NDArray* result = mmul(aPR, bPR, nullptr, 1.0, 0.0); + + if (aPR != a) delete aPR; + if (bPR != b) delete bPR; + return result; } #endif - ////////////////////////////////////////////////////////////////////////// -sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::NDArray* C , const double alpha, const double beta, const char outOrder) { - - int lenDim; - const int aRank = A->rankOf(); - const int bRank = B->rankOf(); - const bool isAVector = shape::isCommonVector(A->shapeInfo(), lenDim); - const bool isBVector = shape::isCommonVector(B->shapeInfo(), lenDim); - - // dot product of 2 vectors - if(isAVector && isBVector && (aRank != 2 || aRank == 2 && (A->isSameShape(B) || bRank == 1 && A->sizeAt(1) == 1))) // (1x1x1 * 1x1) or (1x4 * 1*4) or (4x1 * 4x1) or (4x1 * 4) - return dot(A, B, C, alpha, beta); - - // matrix x matrix - if(aRank == 2 && bRank == 2) - return mmulMxM(A, B, C, alpha, beta, outOrder); - - // matrix x vector - if(aRank == 2 && isBVector) - return mmulMxV(A, B, C, alpha, beta, outOrder); - - // vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm - if(isAVector && bRank == 2) { - NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M} - NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()}, false)) : nullptr; // C{N} -> C2{1,N} - auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N} - delete A2; - delete C2; - - if(!C) { - result->reshapei({result->lengthOf()}); // result{1,N} -> result{N} - return result; - } - return C; +sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, + sd::NDArray* C, const double alpha, + const double beta, const char outOrder) { + int lenDim; + const int aRank = A->rankOf(); + const int bRank = B->rankOf(); + const bool isAVector = shape::isCommonVector(A->shapeInfo(), lenDim); + const bool isBVector = shape::isCommonVector(B->shapeInfo(), lenDim); + + // dot product of 2 vectors + if (isAVector && isBVector && + (aRank != 2 || + aRank == 2 && + (A->isSameShape(B) || + bRank == 1 && A->sizeAt(1) == 1))) // (1x1x1 * 1x1) or (1x4 * 1*4) + // or (4x1 * 4x1) or (4x1 * 4) + return dot(A, B, C, alpha, beta); + + // matrix x matrix + if (aRank == 2 && bRank == 2) return mmulMxM(A, B, C, alpha, beta, outOrder); + + // matrix x vector + if (aRank == 2 && isBVector) return mmulMxV(A, B, C, alpha, beta, outOrder); + + // vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} + // x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm + if (isAVector && bRank == 2) { + NDArray* A2 = new NDArray( + A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M} + NDArray* C2 = + C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()}, false)) + : nullptr; // C{N} -> C2{1,N} + auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N} + delete A2; + delete C2; + + if (!C) { + result->reshapei({result->lengthOf()}); // result{1,N} -> result{N} + return result; } + return C; + } - // batched matrix multiplication - return mmulNxN(A, B, C, alpha, beta, outOrder); + // batched matrix multiplication + return mmulNxN(A, B, C, alpha, beta, outOrder); } - ////////////////////////////////////////////////////////////////////////// - void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, sd::NDArray* z, const bool transX, const bool transY, double alpha, double beta) { - int xRank = x->rankOf(); - int yRank = y->rankOf(); - - auto outShape = ShapeUtils::evalShapeForMatmul(x->shapeInfo(), y->shapeInfo(), transX, transY); - if(!z->isSameShape(outShape)) { - nd4j_printf("NDArrayFactory::matmul static method: input shape of output array is wrong, actual is %s and expected is %s ! \n", ShapeUtils::shapeAsString(z).c_str(), ShapeUtils::shapeAsString(outShape).c_str()); - throw std::invalid_argument(""); - } - - if (z->isEmpty()) - return; - - NDArray* xT(const_cast(x)), *yT(const_cast(y)), *zT(z); - - if((transX && xRank > 1) || (transY && yRank > 1)) { - const int rank = xRank >= yRank ? xRank : yRank; - std::vector permut(rank); - for (int i = 0; i < rank-2; ++i) - permut[i] = i; - permut[rank-2] = rank - 1; - permut[rank-1] = rank - 2; - - if(transX) - xT = new NDArray(x->permute(permut)); - - if(transY) - yT = new NDArray(y->permute(permut)); - } - - if(xRank <= 2 && yRank <= 2) { // dot (1Dx1D), vector-matrix (1Dx2D), matrix-vector (2Dx1D), matrix-matrix (2Dx2D) product cases - - if(xRank == 1 && yRank == 2) { // reduce vector-matrix to matrix-matrix case - xT = new NDArray(x->reshape(x->ordering(), {1, x->lengthOf()})); // please note x is not transposed in this case (since xRank=1) - zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()})); - } - - mmul(xT, yT, zT, alpha, beta); - } - else { // rest cases - batched mmul - - const int batchRank = xRank - 2; - std::vector dimsToExclude(batchRank); - for(int i = 0; i < batchRank; ++i) - dimsToExclude[i] = i; - - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(xT->shapeInfo(), dimsToExclude); - -//PRAGMA_OMP_PARALLEL_FOR - for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { - auto xSubArr = (*xT)(i, dimsToExclude); - auto ySubArr = (*yT)(i, dimsToExclude); - auto zSubArr = (*zT)(i, dimsToExclude); - mmul(&xSubArr, &ySubArr, &zSubArr, alpha, beta); - } - } - - if(xT != x) - delete xT; - if(yT != y) - delete yT; - if(zT != z) - delete zT; +void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, + sd::NDArray* z, const bool transX, const bool transY, + double alpha, double beta) { + int xRank = x->rankOf(); + int yRank = y->rankOf(); + + auto outShape = ShapeUtils::evalShapeForMatmul(x->shapeInfo(), y->shapeInfo(), + transX, transY); + if (!z->isSameShape(outShape)) { + nd4j_printf( + "NDArrayFactory::matmul static method: input shape of output array is " + "wrong, actual is %s and expected is %s ! \n", + ShapeUtils::shapeAsString(z).c_str(), + ShapeUtils::shapeAsString(outShape).c_str()); + throw std::invalid_argument(""); + } + + if (z->isEmpty()) return; + + NDArray *xT(const_cast(x)), *yT(const_cast(y)), *zT(z); + + if ((transX && xRank > 1) || (transY && yRank > 1)) { + const int rank = xRank >= yRank ? xRank : yRank; + std::vector permut(rank); + for (int i = 0; i < rank - 2; ++i) permut[i] = i; + permut[rank - 2] = rank - 1; + permut[rank - 1] = rank - 2; + + if (transX) xT = new NDArray(x->permute(permut)); + + if (transY) yT = new NDArray(y->permute(permut)); + } + + if (xRank <= 2 && + yRank <= 2) { // dot (1Dx1D), vector-matrix (1Dx2D), matrix-vector + // (2Dx1D), matrix-matrix (2Dx2D) product cases + + if (xRank == 1 && + yRank == 2) { // reduce vector-matrix to matrix-matrix case + xT = new NDArray( + x->reshape(x->ordering(), + {1, x->lengthOf()})); // please note x is not transposed + // in this case (since xRank=1) + zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()})); } -//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); -//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); -//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + mmul(xT, yT, zT, alpha, beta); + } else { // rest cases - batched mmul + + const int batchRank = xRank - 2; + std::vector dimsToExclude(batchRank); + for (int i = 0; i < batchRank; ++i) dimsToExclude[i] = i; + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(xT->shapeInfo(), dimsToExclude); + + // PRAGMA_OMP_PARALLEL_FOR + for (Nd4jLong i = 0; i < numOfSubArrs; ++i) { + auto xSubArr = (*xT)(i, dimsToExclude); + auto ySubArr = (*yT)(i, dimsToExclude); + auto zSubArr = (*zT)(i, dimsToExclude); + mmul(&xSubArr, &ySubArr, &zSubArr, alpha, beta); + } + } + + if (xT != x) delete xT; + if (yT != y) delete yT; + if (zT != z) delete zT; } +// BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool +// transA, const bool transB, const int M, const int N, const int K, const double +// alpha, const void* A, const int lda, const void* B, const int ldb, const +// double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, +// FLOAT_TYPES); BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char +// aOrder, const int M, const int N, const double alpha, const void* A, const int +// lda, const void* B, const int incx, const double beta, void* C, const int +// incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +// BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const +// double alpha, const void* vX, const Nd4jLong incx, const void* vY, const +// Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, +// FLOAT_TYPES); + +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/OmpLaunchHelper.cpp b/libnd4j/include/helpers/impl/OmpLaunchHelper.cpp index 0e409a952496..c382a6e8ca16 100644 --- a/libnd4j/include/helpers/impl/OmpLaunchHelper.cpp +++ b/libnd4j/include/helpers/impl/OmpLaunchHelper.cpp @@ -20,92 +20,93 @@ // #include -#include #include +#include #ifdef _OPENMP #include #endif namespace sd { - //////////////////////////////////////////////////////////////////////////////// -OmpLaunchHelper::OmpLaunchHelper(const Nd4jLong N, float desiredNumThreads) { - - auto maxItersPerThread = Environment::getInstance()->elementwiseThreshold(); - - if(N < maxItersPerThread) - _numThreads = 1; - else { - #ifdef _OPENMP - if(desiredNumThreads == -1) - desiredNumThreads = omp_get_max_threads(); - else if(desiredNumThreads < 1) - desiredNumThreads = 1; - else - desiredNumThreads = sd::math::nd4j_min(omp_get_max_threads(), desiredNumThreads); - #else - desiredNumThreads = sd::Environment::getInstance()->maxThreads(); - #endif - _numThreads = sd::math::nd4j_min(N / maxItersPerThread, desiredNumThreads); - } - - _itersPerThread = N / _numThreads; - _remainder = N % _numThreads; // last thread may contain bigger number of iterations -} +OmpLaunchHelper::OmpLaunchHelper(const Nd4jLong N, float desiredNumThreads) { + auto maxItersPerThread = Environment::getInstance()->elementwiseThreshold(); + if (N < maxItersPerThread) + _numThreads = 1; + else { +#ifdef _OPENMP + if (desiredNumThreads == -1) + desiredNumThreads = omp_get_max_threads(); + else if (desiredNumThreads < 1) + desiredNumThreads = 1; + else + desiredNumThreads = + sd::math::nd4j_min(omp_get_max_threads(), desiredNumThreads); +#else + desiredNumThreads = sd::Environment::getInstance()->maxThreads(); +#endif + _numThreads = + sd::math::nd4j_min(N / maxItersPerThread, desiredNumThreads); + } + + _itersPerThread = N / _numThreads; + _remainder = + N % _numThreads; // last thread may contain bigger number of iterations +} Nd4jLong OmpLaunchHelper::betterSpan(Nd4jLong N) { - return OmpLaunchHelper::betterSpan(N, OmpLaunchHelper::betterThreads(N)); - } - - Nd4jLong OmpLaunchHelper::betterSpan(Nd4jLong N, Nd4jLong numThreads) { - auto r = N % numThreads; - auto t = N / numThreads; - - if (r == 0) - return t; - else { - // breaks alignment - return t + 1; - } - } - - int OmpLaunchHelper::betterThreads(Nd4jLong N) { - #ifdef _OPENMP - return betterThreads(N, omp_get_max_threads()); - #else - return betterThreads(N, sd::Environment::getInstance()->maxThreads());; - #endif - } - - int OmpLaunchHelper::betterThreads(Nd4jLong N, int maxThreads) { - auto t = Environment::getInstance()->elementwiseThreshold(); - if (N < t) - return 1; - else { - return static_cast(sd::math::nd4j_min(N / t, maxThreads)); - } - } - - int OmpLaunchHelper::tadThreads(Nd4jLong tadLength, Nd4jLong numTads) { + return OmpLaunchHelper::betterSpan(N, OmpLaunchHelper::betterThreads(N)); +} + +Nd4jLong OmpLaunchHelper::betterSpan(Nd4jLong N, Nd4jLong numThreads) { + auto r = N % numThreads; + auto t = N / numThreads; + + if (r == 0) + return t; + else { + // breaks alignment + return t + 1; + } +} + +int OmpLaunchHelper::betterThreads(Nd4jLong N) { +#ifdef _OPENMP + return betterThreads(N, omp_get_max_threads()); +#else + return betterThreads(N, sd::Environment::getInstance()->maxThreads()); + ; +#endif +} + +int OmpLaunchHelper::betterThreads(Nd4jLong N, int maxThreads) { + auto t = Environment::getInstance()->elementwiseThreshold(); + if (N < t) + return 1; + else { + return static_cast(sd::math::nd4j_min(N / t, maxThreads)); + } +} + +int OmpLaunchHelper::tadThreads(Nd4jLong tadLength, Nd4jLong numTads) { #ifdef _OPENMP - auto maxThreads = omp_get_max_threads(); + auto maxThreads = omp_get_max_threads(); #else - auto maxThreads = sd::Environment::getInstance()->maxThreads(); + auto maxThreads = sd::Environment::getInstance()->maxThreads(); #endif - // if there's only 1 thread allowed - nothing to do here - if (maxThreads <= 1) - return 1; + // if there's only 1 thread allowed - nothing to do here + if (maxThreads <= 1) return 1; - auto totalLength = tadLength * numTads; + auto totalLength = tadLength * numTads; - // if array is tiny - no need to spawn any threeds - if (totalLength < Environment::getInstance()->elementwiseThreshold()) - return 1; + // if array is tiny - no need to spawn any threeds + if (totalLength < Environment::getInstance()->elementwiseThreshold()) + return 1; - // by default we're spawning as many threads we can, but not more than number of TADs - return sd::math::nd4j_min(numTads, maxThreads); - } + // by default we're spawning as many threads we can, but not more than number + // of TADs + return sd::math::nd4j_min(numTads, maxThreads); } +} // namespace sd diff --git a/libnd4j/include/helpers/impl/OpArgsHolder.cpp b/libnd4j/include/helpers/impl/OpArgsHolder.cpp index 7b82a85d937b..ab1a0986a722 100644 --- a/libnd4j/include/helpers/impl/OpArgsHolder.cpp +++ b/libnd4j/include/helpers/impl/OpArgsHolder.cpp @@ -20,140 +20,129 @@ #include - namespace sd { //////////////////////////////////////////////////////////////////////// // default constructor OpArgsHolder::OpArgsHolder() { + _inArrs = std::vector(); + _tArgs = std::vector(); + _iArgs = std::vector(); + _bArgs = std::vector(); - _inArrs = std::vector(); - _tArgs = std::vector(); - _iArgs = std::vector(); - _bArgs = std::vector(); - - _isArrAlloc = std::vector(); + _isArrAlloc = std::vector(); - _numInArrs = 0; - _numTArgs = 0; - _numIArgs = 0; - _numBArgs = 0; + _numInArrs = 0; + _numTArgs = 0; + _numIArgs = 0; + _numBArgs = 0; } //////////////////////////////////////////////////////////////////////// // copy constructor OpArgsHolder::OpArgsHolder(const OpArgsHolder& other) { - - throw std::runtime_error("OpArgsHolder::OpArgsHolder copy constructor: don't use me !"); + throw std::runtime_error( + "OpArgsHolder::OpArgsHolder copy constructor: don't use me !"); } - //////////////////////////////////////////////////////////////////////// // constructor OpArgsHolder::OpArgsHolder(const std::vector& inArrs, - const std::vector& tArgs, - const std::vector& iArgs, - const std::vector& bArgs) { - _inArrs = inArrs; - _tArgs = tArgs; - _iArgs = iArgs; - _bArgs = bArgs; - - _isArrAlloc = std::vector(); - - _numInArrs = _inArrs.size(); - _numTArgs = _tArgs.size(); - _numIArgs = _iArgs.size(); - _numBArgs = _bArgs.size(); + const std::vector& tArgs, + const std::vector& iArgs, + const std::vector& bArgs) { + _inArrs = inArrs; + _tArgs = tArgs; + _iArgs = iArgs; + _bArgs = bArgs; + + _isArrAlloc = std::vector(); + + _numInArrs = _inArrs.size(); + _numTArgs = _tArgs.size(); + _numIArgs = _iArgs.size(); + _numBArgs = _bArgs.size(); } //////////////////////////////////////////////////////////////////////// // move constructor -OpArgsHolder::OpArgsHolder(OpArgsHolder&& other) noexcept: _inArrs(std::move(other._inArrs)), - _tArgs(std::move(other._tArgs)), - _iArgs(std::move(other._iArgs)), - _bArgs(std::move(other._bArgs)), - _isArrAlloc(std::move(other._isArrAlloc)) { - - other._isArrAlloc = std::vector(); - - _numInArrs = _inArrs.size(); - _numTArgs = _tArgs.size(); - _numIArgs = _iArgs.size(); - _numBArgs = _bArgs.size(); +OpArgsHolder::OpArgsHolder(OpArgsHolder&& other) noexcept + : _inArrs(std::move(other._inArrs)), + _tArgs(std::move(other._tArgs)), + _iArgs(std::move(other._iArgs)), + _bArgs(std::move(other._bArgs)), + _isArrAlloc(std::move(other._isArrAlloc)) { + other._isArrAlloc = std::vector(); + + _numInArrs = _inArrs.size(); + _numTArgs = _tArgs.size(); + _numIArgs = _iArgs.size(); + _numBArgs = _bArgs.size(); } //////////////////////////////////////////////////////////////////////// // assignment operator OpArgsHolder& OpArgsHolder::operator=(const OpArgsHolder& other) { - - throw std::runtime_error("OpArgsHolder::OpArgsHolder assignment operator: don't use me !"); + throw std::runtime_error( + "OpArgsHolder::OpArgsHolder assignment operator: don't use me !"); } - //////////////////////////////////////////////////////////////////////// // move assignment operator OpArgsHolder& OpArgsHolder::operator=(OpArgsHolder&& other) noexcept { + if (this == &other) return *this; - if (this == &other) - return *this; - - for (int i = 0; i < _isArrAlloc.size(); ++i) // delete arrays if necessary - if(_isArrAlloc[i]) - delete _inArrs[i]; + for (int i = 0; i < _isArrAlloc.size(); ++i) // delete arrays if necessary + if (_isArrAlloc[i]) delete _inArrs[i]; - _inArrs = std::move(other._inArrs); - _tArgs = std::move(other._tArgs); - _iArgs = std::move(other._iArgs); - _bArgs = std::move(other._bArgs); - _isArrAlloc = std::move(other._isArrAlloc); + _inArrs = std::move(other._inArrs); + _tArgs = std::move(other._tArgs); + _iArgs = std::move(other._iArgs); + _bArgs = std::move(other._bArgs); + _isArrAlloc = std::move(other._isArrAlloc); - other._isArrAlloc = std::vector(); + other._isArrAlloc = std::vector(); - _numInArrs = _inArrs.size(); - _numTArgs = _tArgs.size(); - _numIArgs = _iArgs.size(); - _numBArgs = _bArgs.size(); + _numInArrs = _inArrs.size(); + _numTArgs = _tArgs.size(); + _numIArgs = _iArgs.size(); + _numBArgs = _bArgs.size(); - return *this; + return *this; } //////////////////////////////////////////////////////////////////////// -OpArgsHolder OpArgsHolder::createArgsHolderForBP(const std::vector& inGradArrs, const bool isInPlace) const { - - const int numInGradArrs = inGradArrs.size(); - - OpArgsHolder result(std::vector(_numInArrs + numInGradArrs, nullptr), _tArgs, _iArgs); - - if(isInPlace) - result._isArrAlloc = std::vector(_numInArrs + numInGradArrs, false); - - for (int i = 0; i < _numInArrs; ++i) { - - if(isInPlace) { - result._inArrs[i] = new NDArray(*_inArrs[i]); // make copy - result._isArrAlloc[i] = true; - } - else - result._inArrs[i] = _inArrs[i]; - } - - // input gradients - for (int i = 0; i < numInGradArrs; ++i) - result._inArrs[_numInArrs + i] = inGradArrs[i]; - - return result; +OpArgsHolder OpArgsHolder::createArgsHolderForBP( + const std::vector& inGradArrs, const bool isInPlace) const { + const int numInGradArrs = inGradArrs.size(); + + OpArgsHolder result( + std::vector(_numInArrs + numInGradArrs, nullptr), _tArgs, + _iArgs); + + if (isInPlace) + result._isArrAlloc = std::vector(_numInArrs + numInGradArrs, false); + + for (int i = 0; i < _numInArrs; ++i) { + if (isInPlace) { + result._inArrs[i] = new NDArray(*_inArrs[i]); // make copy + result._isArrAlloc[i] = true; + } else + result._inArrs[i] = _inArrs[i]; + } + + // input gradients + for (int i = 0; i < numInGradArrs; ++i) + result._inArrs[_numInArrs + i] = inGradArrs[i]; + + return result; } //////////////////////////////////////////////////////////////////////// // default destructor OpArgsHolder::~OpArgsHolder() noexcept { - - for (int i = 0; i < _isArrAlloc.size(); ++i) - if(_isArrAlloc[i]) - delete _inArrs[i]; -} - + for (int i = 0; i < _isArrAlloc.size(); ++i) + if (_isArrAlloc[i]) delete _inArrs[i]; } - +} // namespace sd diff --git a/libnd4j/include/helpers/impl/OpBenchmark.cpp b/libnd4j/include/helpers/impl/OpBenchmark.cpp index 1eba37d03561..8fe4dd500f0c 100644 --- a/libnd4j/include/helpers/impl/OpBenchmark.cpp +++ b/libnd4j/include/helpers/impl/OpBenchmark.cpp @@ -21,103 +21,81 @@ #include "../OpBenchmark.h" namespace sd { - OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, const NDArray &y, const NDArray &z) { - _testName = name; - _x = x; - _y = y; - _z = z; - } - - OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, const NDArray &z) { - _testName = name; - _x = x; - _z = z; - } - - OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, const NDArray &z, const std::vector &axis) { - _testName = name; - _x = x; - _z = z; - _axis = axis; - - if (_axis.size() > 1) - std::sort(_axis.begin(), _axis.end()); - } - - OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, const NDArray &y, const NDArray &z, const std::vector &axis) { - _testName = name; - _x = x; - _y = y; - _z = z; - _axis = axis; - - if (_axis.size() > 1) - std::sort(_axis.begin(), _axis.end()); - } - - - NDArray& OpBenchmark::x() { - return _x; - } - - int OpBenchmark::opNum() const { - return _opNum; - } - const std::string& OpBenchmark::testName() const{ - return _testName; - } - - void OpBenchmark::setOpNum(int opNum) { - _opNum = opNum; - } - - void OpBenchmark::setTestName(const std::string &name){ - _testName = name; - } - - void OpBenchmark::setX(const NDArray &array) { - _x = array; - } - - void OpBenchmark::setY(const NDArray &array) { - _y = array; - } - - void OpBenchmark::setZ(const NDArray &array) { - _z = array; - } - - void OpBenchmark::setAxis(std::vector axis) { - _axis = axis; - } - - void OpBenchmark::setAxis(std::initializer_list axis) { - _axis = axis; - } - - std::vector OpBenchmark::getAxis(){ - return _axis; - } - - std::string OpBenchmark::extra() { - return "N/A"; - } - - std::string OpBenchmark::shape() { - if (_x.shapeInfo() != nullptr) - return ShapeUtils::shapeAsString(_x); - else if (_z.shapeInfo() != nullptr) - return ShapeUtils::shapeAsString(_z); - else - return "N/A"; - } - - std::string OpBenchmark::dataType() { - if (_x.shapeInfo() != nullptr) - return DataTypeUtils::asString(_x.dataType()); - else if (_z.shapeInfo() != nullptr) - return DataTypeUtils::asString(_z.dataType()); - else - return "N/A"; - } -} \ No newline at end of file +OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, + const NDArray &y, const NDArray &z) { + _testName = name; + _x = x; + _y = y; + _z = z; +} + +OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, + const NDArray &z) { + _testName = name; + _x = x; + _z = z; +} + +OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, + const NDArray &z, const std::vector &axis) { + _testName = name; + _x = x; + _z = z; + _axis = axis; + + if (_axis.size() > 1) std::sort(_axis.begin(), _axis.end()); +} + +OpBenchmark::OpBenchmark(const std::string &name, const NDArray &x, + const NDArray &y, const NDArray &z, + const std::vector &axis) { + _testName = name; + _x = x; + _y = y; + _z = z; + _axis = axis; + + if (_axis.size() > 1) std::sort(_axis.begin(), _axis.end()); +} + +NDArray &OpBenchmark::x() { return _x; } + +int OpBenchmark::opNum() const { return _opNum; } +const std::string &OpBenchmark::testName() const { return _testName; } + +void OpBenchmark::setOpNum(int opNum) { _opNum = opNum; } + +void OpBenchmark::setTestName(const std::string &name) { _testName = name; } + +void OpBenchmark::setX(const NDArray &array) { _x = array; } + +void OpBenchmark::setY(const NDArray &array) { _y = array; } + +void OpBenchmark::setZ(const NDArray &array) { _z = array; } + +void OpBenchmark::setAxis(std::vector axis) { _axis = axis; } + +void OpBenchmark::setAxis(std::initializer_list axis) { _axis = axis; } + +std::vector OpBenchmark::getAxis() { return _axis; } + +std::string OpBenchmark::extra() { return "N/A"; } + +std::string OpBenchmark::shape() { + if (_x.shapeInfo() != nullptr) + return ShapeUtils::shapeAsString(_x); + else if (_z.shapeInfo() != nullptr) + return ShapeUtils::shapeAsString(_z); + else + return "N/A"; +} + +std::string OpBenchmark::dataType() { + if (_x.shapeInfo() != nullptr) + return DataTypeUtils::asString(_x.dataType()); + else if (_z.shapeInfo() != nullptr) + return DataTypeUtils::asString(_z.dataType()); + else + return "N/A"; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/OpTracker.cpp b/libnd4j/include/helpers/impl/OpTracker.cpp index bb82ab0d1a04..b2e7217d055d 100644 --- a/libnd4j/include/helpers/impl/OpTracker.cpp +++ b/libnd4j/include/helpers/impl/OpTracker.cpp @@ -19,105 +19,102 @@ // #include -#include #include #include +#include using namespace sd::ops; using namespace sd::graph; namespace sd { - - OpTracker* OpTracker::getInstance() { - if (_INSTANCE == 0) - _INSTANCE = new OpTracker(); - return _INSTANCE; - } +OpTracker* OpTracker::getInstance() { + if (_INSTANCE == 0) _INSTANCE = new OpTracker(); - void OpTracker::storeOperation(sd::graph::OpType opType, const OpDescriptor& descriptor) { - // check out CPU features - if (!::isMinimalRequirementsMet()) { - - auto binaryLevel = ::binaryLevel(); - auto optimalLevel = ::optimalLevel(); - - switch (binaryLevel) { - case 3: { - nd4j_printf("libnd4j binary was built with AVX512 support, but current CPU doesn't have this instruction set. Exiting now...",""); - } - break; - case 2: { - nd4j_printf("libnd4j binary was built with AVX/AVX2 support, but current CPU doesn't have this instruction set. Exiting now...",""); - } - break; - default: { - nd4j_printf("Unknown binary validation error. Exiting now...",""); - } - break; - } - - // we're exiting now - exit(119); - } - // - if (_map.count(opType) < 1) { - std::vector vec; - _map[opType] = vec; - } - - _operations++; - - auto vec = _map[opType]; - - if (std::find(vec.begin(), vec.end(), descriptor) == vec.end()) - _map[opType].emplace_back(descriptor); + return _INSTANCE; +} + +void OpTracker::storeOperation(sd::graph::OpType opType, + const OpDescriptor& descriptor) { + // check out CPU features + if (!::isMinimalRequirementsMet()) { + auto binaryLevel = ::binaryLevel(); + auto optimalLevel = ::optimalLevel(); + + switch (binaryLevel) { + case 3: { + nd4j_printf( + "libnd4j binary was built with AVX512 support, but current CPU " + "doesn't have this instruction set. Exiting now...", + ""); + } break; + case 2: { + nd4j_printf( + "libnd4j binary was built with AVX/AVX2 support, but current CPU " + "doesn't have this instruction set. Exiting now...", + ""); + } break; + default: { + nd4j_printf("Unknown binary validation error. Exiting now...", ""); + } break; } - void OpTracker::storeOperation(sd::graph::OpType opType, const char* opName, const Nd4jLong opNum) { - OpDescriptor descriptor(0, opName, false); - descriptor.setOpNum((int) opNum); - descriptor.setHash(-1); + // we're exiting now + exit(119); + } + // + if (_map.count(opType) < 1) { + std::vector vec; + _map[opType] = vec; + } - storeOperation(opType, descriptor); - } + _operations++; + auto vec = _map[opType]; - template - std::string OpTracker::local_to_string(T value) { - std::ostringstream os ; - os << value ; - return os.str() ; - } + if (std::find(vec.begin(), vec.end(), descriptor) == vec.end()) + _map[opType].emplace_back(descriptor); +} +void OpTracker::storeOperation(sd::graph::OpType opType, const char* opName, + const Nd4jLong opNum) { + OpDescriptor descriptor(0, opName, false); + descriptor.setOpNum((int)opNum); + descriptor.setHash(-1); - int OpTracker::totalGroups() { - return (int) _map.size(); - } + storeOperation(opType, descriptor); +} - int OpTracker::totalOperations() { - return _operations; - } +template +std::string OpTracker::local_to_string(T value) { + std::ostringstream os; + os << value; + return os.str(); +} + +int OpTracker::totalGroups() { return (int)_map.size(); } - const char* OpTracker::exportOperations() { - if (_export.length() == 0) { - for (auto &v: _map) { - std::string block = local_to_string(v.first) + " "; +int OpTracker::totalOperations() { return _operations; } - for (auto &i: v.second) { - block += local_to_string(i.getHash()) + ":"; - block += local_to_string(i.getOpNum()) + ":"; - block += *i.getOpName() + "<<"; - } +const char* OpTracker::exportOperations() { + if (_export.length() == 0) { + for (auto& v : _map) { + std::string block = local_to_string(v.first) + " "; - block += ">>"; - _export += block; - } - } + for (auto& i : v.second) { + block += local_to_string(i.getHash()) + ":"; + block += local_to_string(i.getOpNum()) + ":"; + block += *i.getOpName() + "<<"; + } - return _export.c_str(); + block += ">>"; + _export += block; } + } - sd::OpTracker* sd::OpTracker::_INSTANCE = 0; + return _export.c_str(); } + +sd::OpTracker* sd::OpTracker::_INSTANCE = 0; +} // namespace sd diff --git a/libnd4j/include/helpers/impl/Parameters.cpp b/libnd4j/include/helpers/impl/Parameters.cpp index 356ad5a5aa7e..a609ce4cb302 100644 --- a/libnd4j/include/helpers/impl/Parameters.cpp +++ b/libnd4j/include/helpers/impl/Parameters.cpp @@ -19,81 +19,88 @@ // #include "../benchmark/Parameters.h" + #include namespace sd { - Parameters* Parameters::addIntParam(std::string string, int param) { - _intParams[string] = param; - return this; - } +Parameters* Parameters::addIntParam(std::string string, int param) { + _intParams[string] = param; + return this; +} - int Parameters::getIntParam(std::string string) const { - if (_intParams.count(string) == 0) - throw std::runtime_error("Not available intParameter requested"); +int Parameters::getIntParam(std::string string) const { + if (_intParams.count(string) == 0) + throw std::runtime_error("Not available intParameter requested"); - return _intParams.at(string); - } + return _intParams.at(string); +} - Parameters* Parameters::addIntParam(std::initializer_list strings, std::initializer_list params) { - std::vector s(strings); - std::vector p(params); +Parameters* Parameters::addIntParam(std::initializer_list strings, + std::initializer_list params) { + std::vector s(strings); + std::vector p(params); - if (s.size() != p.size()) - throw std::runtime_error("addIntParam: number of keys and values should match"); + if (s.size() != p.size()) + throw std::runtime_error( + "addIntParam: number of keys and values should match"); - for (int e = 0; e < s.size(); e++) - _intParams[s[e]] = p[e]; + for (int e = 0; e < s.size(); e++) _intParams[s[e]] = p[e]; - return this; - } + return this; +} - Parameters* Parameters::addBoolParam(std::string string, bool param) { - _boolParams[string] = param; - return this; - } +Parameters* Parameters::addBoolParam(std::string string, bool param) { + _boolParams[string] = param; + return this; +} - Parameters* Parameters::addBoolParam(std::initializer_list strings, std::initializer_list params) { - std::vector s(strings); - std::vector p(params); +Parameters* Parameters::addBoolParam(std::initializer_list strings, + std::initializer_list params) { + std::vector s(strings); + std::vector p(params); - if (s.size() != p.size()) - throw std::runtime_error("addIntParam: number of keys and values should match"); + if (s.size() != p.size()) + throw std::runtime_error( + "addIntParam: number of keys and values should match"); - for (int e = 0; e < s.size(); e++) - _boolParams[s[e]] = p[e]; + for (int e = 0; e < s.size(); e++) _boolParams[s[e]] = p[e]; - return this; - } + return this; +} - Parameters* Parameters::addArrayParam(std::string string, std::initializer_list param) { - _arrayParams[string] = std::vector(param); - return this; - } +Parameters* Parameters::addArrayParam(std::string string, + std::initializer_list param) { + _arrayParams[string] = std::vector(param); + return this; +} - Parameters* Parameters::addArrayParam(std::initializer_list strings, std::initializer_list> params) { - std::vector s(strings); - std::vector> p(params); +Parameters* Parameters::addArrayParam( + std::initializer_list strings, + std::initializer_list> params) { + std::vector s(strings); + std::vector> p(params); - if (s.size() != p.size()) - throw std::runtime_error("addIntParam: number of keys and values should match"); + if (s.size() != p.size()) + throw std::runtime_error( + "addIntParam: number of keys and values should match"); - for (int e = 0; e < s.size(); e++) - _arrayParams[s[e]] = std::vector(p[e]); + for (int e = 0; e < s.size(); e++) + _arrayParams[s[e]] = std::vector(p[e]); - return this; - } + return this; +} - bool Parameters::getBoolParam(std::string string) const { - if (_boolParams.count(string) == 0) - throw std::runtime_error("Not available boolParameter requested"); +bool Parameters::getBoolParam(std::string string) const { + if (_boolParams.count(string) == 0) + throw std::runtime_error("Not available boolParameter requested"); - return _boolParams.at(string); - } + return _boolParams.at(string); +} - std::vector Parameters::getArrayParam(std::string string) const { - if (_arrayParams.count(string) == 0) - throw std::runtime_error("Not available arrayParameter requested"); +std::vector Parameters::getArrayParam(std::string string) const { + if (_arrayParams.count(string) == 0) + throw std::runtime_error("Not available arrayParameter requested"); - return _arrayParams.at(string); - } + return _arrayParams.at(string); } +} // namespace sd diff --git a/libnd4j/include/helpers/impl/RandomLauncher.cpp b/libnd4j/include/helpers/impl/RandomLauncher.cpp index f7cdd0f3aea9..4bb9ca952c35 100644 --- a/libnd4j/include/helpers/impl/RandomLauncher.cpp +++ b/libnd4j/include/helpers/impl/RandomLauncher.cpp @@ -18,140 +18,207 @@ // @author raver119@gmail.com // -#include -#include -#include #include +#include +#include +#include //#include #include namespace sd { - void RandomLauncher::applyDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) { - if (z == nullptr) - z = array; - - ExtraArguments arguments({retainProb}); - PointersManager pm(context, "applyDropOut"); - - NDArray::prepareSpecialUse({z}, {array}); - - NativeOpExecutioner::execRandom(context, random::DropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); - pm.synchronize(); +void RandomLauncher::applyDropOut(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double retainProb, + NDArray* z) { + if (z == nullptr) z = array; - NDArray::registerSpecialUse({z}, {array}); - } + ExtraArguments arguments({retainProb}); + PointersManager pm(context, "applyDropOut"); - void RandomLauncher::applyInvertedDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) { - if (z == nullptr) - z = array; + NDArray::prepareSpecialUse({z}, {array}); - ExtraArguments arguments({retainProb}); - PointersManager pm(context, "applyInvertedDropOut"); + NativeOpExecutioner::execRandom( + context, random::DropOut, &rng, array->buffer(), array->shapeInfo(), + array->specialBuffer(), array->specialShapeInfo(), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + arguments.argumentsAsT(z->dataType())); + pm.synchronize(); - NDArray::prepareSpecialUse({z}, {array}); + NDArray::registerSpecialUse({z}, {array}); +} - NativeOpExecutioner::execRandom(context, random::DropOutInverted, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); - pm.synchronize(); +void RandomLauncher::applyInvertedDropOut(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double retainProb, + NDArray* z) { + if (z == nullptr) z = array; - NDArray::registerSpecialUse({z}, {array}); - } + ExtraArguments arguments({retainProb}); + PointersManager pm(context, "applyInvertedDropOut"); - void RandomLauncher::applyAlphaDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z) { - if (z == nullptr) - z = array; + NDArray::prepareSpecialUse({z}, {array}); - ExtraArguments arguments({retainProb, alpha, beta, alphaPrime}); - PointersManager pm(context, "applyAlphaDropOut"); + NativeOpExecutioner::execRandom( + context, random::DropOutInverted, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + arguments.argumentsAsT(z->dataType())); + pm.synchronize(); - NDArray::prepareSpecialUse({z}, {array}); + NDArray::registerSpecialUse({z}, {array}); +} - NativeOpExecutioner::execRandom(context, random::AlphaDropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); - pm.synchronize(); +void RandomLauncher::applyAlphaDropOut(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double retainProb, + double alpha, double beta, + double alphaPrime, NDArray* z) { + if (z == nullptr) z = array; - NDArray::registerSpecialUse({z}, {array}); - } + ExtraArguments arguments({retainProb, alpha, beta, alphaPrime}); + PointersManager pm(context, "applyAlphaDropOut"); - void RandomLauncher::fillBernoulli(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double prob) { - ExtraArguments arguments({prob}); - PointersManager pm(context, "fillBernoulli"); + NDArray::prepareSpecialUse({z}, {array}); - NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom( + context, random::AlphaDropOut, &rng, array->buffer(), array->shapeInfo(), + array->specialBuffer(), array->specialShapeInfo(), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + arguments.argumentsAsT(z->dataType())); + pm.synchronize(); - NativeOpExecutioner::execRandom(context, random::BernoulliDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); + NDArray::registerSpecialUse({z}, {array}); +} - NDArray::registerSpecialUse({array}, {}); - } +void RandomLauncher::fillBernoulli(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double prob) { + ExtraArguments arguments({prob}); + PointersManager pm(context, "fillBernoulli"); - void RandomLauncher::fillUniform(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double from, double to) { - ExtraArguments arguments({from, to}); - PointersManager pm(context, "fillUniform"); + NDArray::prepareSpecialUse({array}, {}); - NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom( + context, random::BernoulliDistribution, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); - NativeOpExecutioner::execRandom(context, random::UniformDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); + NDArray::registerSpecialUse({array}, {}); +} - NDArray::registerSpecialUse({array}, {}); - } +void RandomLauncher::fillUniform(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double from, double to) { + ExtraArguments arguments({from, to}); + PointersManager pm(context, "fillUniform"); - void RandomLauncher::fillGaussian(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { - ExtraArguments arguments({mean, stdev}); - PointersManager pm(context, "fillGaussian"); + NDArray::prepareSpecialUse({array}, {}); - NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom( + context, random::UniformDistribution, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); - NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); + NDArray::registerSpecialUse({array}, {}); +} - NDArray::registerSpecialUse({array}, {}); - } +void RandomLauncher::fillGaussian(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double mean, double stdev) { + ExtraArguments arguments({mean, stdev}); + PointersManager pm(context, "fillGaussian"); - void RandomLauncher::fillExponential(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double lambda) { - ExtraArguments arguments({lambda}); - PointersManager pm(context, "fillExponential"); + NDArray::prepareSpecialUse({array}, {}); - NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom( + context, random::GaussianDistribution, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + array->buffer(), array->shapeInfo(), array->specialBuffer(), + array->specialShapeInfo(), array->buffer(), array->shapeInfo(), + array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); - NativeOpExecutioner::execRandom(context, random::ExponentialDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); + NDArray::registerSpecialUse({array}, {}); +} - NDArray::registerSpecialUse({array}, {}); - } +void RandomLauncher::fillExponential(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double lambda) { + ExtraArguments arguments({lambda}); + PointersManager pm(context, "fillExponential"); - void RandomLauncher::fillLogNormal(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { - ExtraArguments arguments({mean, stdev}); - PointersManager pm(context, "fillLogNormal"); + NDArray::prepareSpecialUse({array}, {}); - NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom( + context, random::ExponentialDistribution, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); - NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); + NDArray::registerSpecialUse({array}, {}); +} - NDArray::registerSpecialUse({array}, {}); - } +void RandomLauncher::fillLogNormal(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double mean, double stdev) { + ExtraArguments arguments({mean, stdev}); + PointersManager pm(context, "fillLogNormal"); - void RandomLauncher::fillTruncatedNormal(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { - ExtraArguments arguments({mean, stdev}); - PointersManager pm(context, "fillTruncatedNormal"); + NDArray::prepareSpecialUse({array}, {}); - NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom( + context, random::GaussianDistribution, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + array->buffer(), array->shapeInfo(), array->specialBuffer(), + array->specialShapeInfo(), array->buffer(), array->shapeInfo(), + array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); - NativeOpExecutioner::execRandom(context, random::TruncatedNormalDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); + NDArray::registerSpecialUse({array}, {}); +} - NDArray::registerSpecialUse({array}, {}); - } +void RandomLauncher::fillTruncatedNormal(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, double mean, + double stdev) { + ExtraArguments arguments({mean, stdev}); + PointersManager pm(context, "fillTruncatedNormal"); + + NDArray::prepareSpecialUse({array}, {}); + + NativeOpExecutioner::execRandom( + context, random::TruncatedNormalDistribution, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + array->buffer(), array->shapeInfo(), array->specialBuffer(), + array->specialShapeInfo(), array->buffer(), array->shapeInfo(), + array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); +} - void RandomLauncher::fillBinomial(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, int trials, double prob) { - ExtraArguments arguments({(double) trials, prob}); - PointersManager pm(context, "fillBinomial"); +void RandomLauncher::fillBinomial(sd::LaunchContext* context, + sd::graph::RandomGenerator& rng, + NDArray* array, int trials, double prob) { + ExtraArguments arguments({(double)trials, prob}); + PointersManager pm(context, "fillBinomial"); - NDArray::prepareSpecialUse({array}, {}); + NDArray::prepareSpecialUse({array}, {}); - NativeOpExecutioner::execRandom(context, random::BinomialDistributionEx, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); + NativeOpExecutioner::execRandom( + context, random::BinomialDistributionEx, &rng, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + array->buffer(), array->shapeInfo(), array->specialBuffer(), + array->specialShapeInfo(), array->buffer(), array->shapeInfo(), + array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); - NDArray::registerSpecialUse({array}, {}); - } + NDArray::registerSpecialUse({array}, {}); } +} // namespace sd diff --git a/libnd4j/include/helpers/impl/ShapeBuilders.cpp b/libnd4j/include/helpers/impl/ShapeBuilders.cpp index 7c0c7fed6577..2fde57f5f7e4 100644 --- a/libnd4j/include/helpers/impl/ShapeBuilders.cpp +++ b/libnd4j/include/helpers/impl/ShapeBuilders.cpp @@ -22,132 +22,155 @@ namespace sd { +Nd4jLong* ShapeBuilders::createScalarShapeInfo( + const sd::DataType dataType, sd::memory::Workspace* workspace) { + Nd4jLong* newShape; + ALLOCATE(newShape, workspace, shape::shapeInfoLength(0), Nd4jLong); + newShape[0] = 0; + newShape[1] = 0; + newShape[2] = 1; + newShape[3] = 99; + + sd::ArrayOptions::setDataType(newShape, dataType); + + return newShape; +} - Nd4jLong* ShapeBuilders::createScalarShapeInfo(const sd::DataType dataType, sd::memory::Workspace* workspace) { - Nd4jLong *newShape; - ALLOCATE(newShape, workspace, shape::shapeInfoLength(0), Nd4jLong); - newShape[0] = 0; - newShape[1] = 0; - newShape[2] = 1; - newShape[3] = 99; - - sd::ArrayOptions::setDataType(newShape, dataType); - - return newShape; - } +Nd4jLong* ShapeBuilders::createVectorShapeInfo( + const sd::DataType dataType, const Nd4jLong length, + sd::memory::Workspace* workspace) { + Nd4jLong* newShape; + ALLOCATE(newShape, workspace, shape::shapeInfoLength(1), Nd4jLong); - Nd4jLong* ShapeBuilders::createVectorShapeInfo(const sd::DataType dataType, const Nd4jLong length, sd::memory::Workspace* workspace) { - Nd4jLong *newShape; - ALLOCATE(newShape, workspace, shape::shapeInfoLength(1), Nd4jLong); + newShape[0] = 1; + newShape[1] = length; + newShape[2] = 1; + newShape[3] = 0; + newShape[4] = 1; + newShape[5] = 99; - newShape[0] = 1; - newShape[1] = length; - newShape[2] = 1; - newShape[3] = 0; - newShape[4] = 1; - newShape[5] = 99; + sd::ArrayOptions::setDataType(newShape, dataType); - sd::ArrayOptions::setDataType(newShape, dataType); + return newShape; +} - return newShape; +//////////////////////////////////////////////////////////////////////////////// +Nd4jLong* ShapeBuilders::createShapeInfo(const sd::DataType dataType, + const char order, int rank, + const Nd4jLong* shapeOnly, + memory::Workspace* workspace) { + Nd4jLong* shapeInfo = nullptr; + + if (rank == 0) { // scalar case + shapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace); + } else { + ALLOCATE(shapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); + shapeInfo[0] = rank; + bool isEmpty = false; + for (int i = 0; i < rank; ++i) { + shapeInfo[i + 1] = shapeOnly[i]; + + if (shapeOnly[i] == 0) isEmpty = true; } - //////////////////////////////////////////////////////////////////////////////// - Nd4jLong* ShapeBuilders::createShapeInfo(const sd::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace) { - Nd4jLong* shapeInfo = nullptr; - - if(rank == 0) { // scalar case - shapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace); - } - else { - ALLOCATE(shapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); - shapeInfo[0] = rank; - bool isEmpty = false; - for(int i = 0; i < rank; ++i) { - shapeInfo[i + 1] = shapeOnly[i]; - - if (shapeOnly[i] == 0) - isEmpty = true; - } - - if (!isEmpty) { - shape::updateStrides(shapeInfo, order); - } - else { - shapeInfo[shape::shapeInfoLength(rank) - 1] = order; - memset(shape::stride(shapeInfo), 0, rank * sizeof(Nd4jLong)); - ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); - } - - sd::ArrayOptions::setDataType(shapeInfo, dataType); - } - - return shapeInfo; + if (!isEmpty) { + shape::updateStrides(shapeInfo, order); + } else { + shapeInfo[shape::shapeInfoLength(rank) - 1] = order; + memset(shape::stride(shapeInfo), 0, rank * sizeof(Nd4jLong)); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); } - Nd4jLong* ShapeBuilders::emptyShapeInfo(const sd::DataType dataType, memory::Workspace* workspace) { - auto shapeInfo = createScalarShapeInfo(dataType, workspace); - ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); - return shapeInfo; - } + sd::ArrayOptions::setDataType(shapeInfo, dataType); + } - Nd4jLong* ShapeBuilders::emptyShapeInfo(const sd::DataType dataType, const char order, const std::vector &shape, memory::Workspace* workspace) { - auto shapeInfo = createShapeInfo(dataType, order, shape, workspace); - memset(shape::stride(shapeInfo), 0, shape.size() * sizeof(Nd4jLong)); - ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); - return shapeInfo; - } + return shapeInfo; +} -//////////////////////////////////////////////////////////////////////////////// - Nd4jLong* ShapeBuilders::createShapeInfo(const sd::DataType dataType, const char order, const std::vector& shapeOnly, memory::Workspace* workspace) { +Nd4jLong* ShapeBuilders::emptyShapeInfo(const sd::DataType dataType, + memory::Workspace* workspace) { + auto shapeInfo = createScalarShapeInfo(dataType, workspace); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); + return shapeInfo; +} - return ShapeBuilders::createShapeInfo(dataType, order, shapeOnly.size(), shapeOnly.data(), workspace); - } +Nd4jLong* ShapeBuilders::emptyShapeInfo(const sd::DataType dataType, + const char order, + const std::vector& shape, + memory::Workspace* workspace) { + auto shapeInfo = createShapeInfo(dataType, order, shape, workspace); + memset(shape::stride(shapeInfo), 0, shape.size() * sizeof(Nd4jLong)); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); + return shapeInfo; +} //////////////////////////////////////////////////////////////////////////////// - Nd4jLong* ShapeBuilders::createShapeInfo(const sd::DataType dataType, const char order, const std::initializer_list& shapeOnly, memory::Workspace* workspace) { - - return ShapeBuilders::createShapeInfo(dataType, order, std::vector(shapeOnly), workspace); - } +Nd4jLong* ShapeBuilders::createShapeInfo(const sd::DataType dataType, + const char order, + const std::vector& shapeOnly, + memory::Workspace* workspace) { + return ShapeBuilders::createShapeInfo(dataType, order, shapeOnly.size(), + shapeOnly.data(), workspace); +} //////////////////////////////////////////////////////////////////////////////// - Nd4jLong* ShapeBuilders::copyShapeInfo(const Nd4jLong* inShapeInfo, const bool copyStrides, memory::Workspace* workspace) { +Nd4jLong* ShapeBuilders::createShapeInfo( + const sd::DataType dataType, const char order, + const std::initializer_list& shapeOnly, + memory::Workspace* workspace) { + return ShapeBuilders::createShapeInfo( + dataType, order, std::vector(shapeOnly), workspace); +} - Nd4jLong *outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo), Nd4jLong); +//////////////////////////////////////////////////////////////////////////////// +Nd4jLong* ShapeBuilders::copyShapeInfo(const Nd4jLong* inShapeInfo, + const bool copyStrides, + memory::Workspace* workspace) { + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo), + Nd4jLong); - memcpy(outShapeInfo, inShapeInfo, shape::shapeInfoByteLength(inShapeInfo)); + memcpy(outShapeInfo, inShapeInfo, shape::shapeInfoByteLength(inShapeInfo)); - if(!copyStrides) - shape::updateStrides(outShapeInfo, shape::order(outShapeInfo)); + if (!copyStrides) + shape::updateStrides(outShapeInfo, shape::order(outShapeInfo)); - return outShapeInfo; - } + return outShapeInfo; +} //////////////////////////////////////////////////////////////////////////////// - Nd4jLong* ShapeBuilders::copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const DataType dtype, const bool copyStrides, memory::Workspace* workspace) { - - Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfo(inShapeInfo, copyStrides, workspace); - ArrayOptions::setDataType(outShapeInfo, dtype); - - return outShapeInfo; - } +Nd4jLong* ShapeBuilders::copyShapeInfoAndType(const Nd4jLong* inShapeInfo, + const DataType dtype, + const bool copyStrides, + memory::Workspace* workspace) { + Nd4jLong* outShapeInfo = + ShapeBuilders::copyShapeInfo(inShapeInfo, copyStrides, workspace); + ArrayOptions::setDataType(outShapeInfo, dtype); + + return outShapeInfo; +} //////////////////////////////////////////////////////////////////////////////// - Nd4jLong* ShapeBuilders::copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace) { - - return ShapeBuilders::copyShapeInfoAndType(inShapeInfo, ArrayOptions::dataType(shapeInfoToGetTypeFrom), copyStrides, workspace); - } +Nd4jLong* ShapeBuilders::copyShapeInfoAndType( + const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, + const bool copyStrides, memory::Workspace* workspace) { + return ShapeBuilders::copyShapeInfoAndType( + inShapeInfo, ArrayOptions::dataType(shapeInfoToGetTypeFrom), copyStrides, + workspace); +} //////////////////////////////////////////////////////////////////////////////// -Nd4jLong* ShapeBuilders::copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace) { - - Nd4jLong *outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo[0] - dimsSize), Nd4jLong); +Nd4jLong* ShapeBuilders::copyShapeInfoWithoutUnites( + const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, + memory::Workspace* workspace) { + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, workspace, + shape::shapeInfoLength(inShapeInfo[0] - dimsSize), Nd4jLong); - shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsSize, dimsToExclude, outShapeInfo); + shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsSize, dimsToExclude, + outShapeInfo); - return outShapeInfo; + return outShapeInfo; } -} \ No newline at end of file +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/ShapeUtils.cpp b/libnd4j/include/helpers/impl/ShapeUtils.cpp index 9cc559d89260..7bfc527a6ef9 100644 --- a/libnd4j/include/helpers/impl/ShapeUtils.cpp +++ b/libnd4j/include/helpers/impl/ShapeUtils.cpp @@ -18,1063 +18,1213 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include + +#include #include #include -#include #include -#include - namespace sd { ////////////////////////////////////////////////////////////////////////// -// evaluate shape for array resulting from tensorDot operation, also evaluate shapes and dimensions permutations for transposition of two input arrays -std::vector ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeInfo, const Nd4jLong* bShapeInfo, std::vector axesA, std::vector axesB, std::vector& permutAt, std::vector& permutBt, std::vector& shapeAt, std::vector& shapeBt) { - - int axeAsize = (int) axesA.size(); - int axeBsize = (int) axesB.size(); - int aRank = aShapeInfo[0]; - int bRank = bShapeInfo[0]; - - if(axeAsize != axeBsize) - throw std::runtime_error("ShapeUtils::evalShapeForTensorDot method: the numbers of a axes and b axes to make dot product along must have identical values !"); - if(axeAsize > aRank || axeBsize > bRank) - throw std::runtime_error("ShapeUtils::evalShapeForTensorDot method: the length of vector of a or b axes is larger than array rank !"); - - // axes validation - for (int i = 0; i < axeBsize; i++) { - if (axesA[i] < 0) - axesA[i] += aRank; - if (axesB[i] < 0) - axesB[i] += bRank; - if (aShapeInfo[axesA[i] + 1] != bShapeInfo[axesB[i] + 1]) - throw std::runtime_error("ShapeUtils::evalShapeForTensorDot method: the dimensions at given axes for both input arrays must be the same !"); - } - - // check whether axesA and axesB contain only unique numbers - std::set uniqueElems(axesA.begin(), axesA.end()); - if((int)uniqueElems.size() != axeAsize) - throw std::runtime_error("ShapeUtils::evalShapeForTensorDot method: the vector of a axes contains duplicates !"); - uniqueElems.clear(); - uniqueElems = std::set(axesB.begin(), axesB.end()); - if((int)uniqueElems.size() != axeBsize) - throw std::runtime_error("ShapeUtils::evalShapeForTensorDot method: the vector of b axes contains duplicates !"); - - std::vector list_A, list_B; - for (int i = 0; i < aRank; i++) - if (std::find(axesA.begin(), axesA.end(), i) == axesA.end()) - list_A.emplace_back(i); - for (int i = 0; i < bRank; i++) - if (std::find(axesB.begin(), axesB.end(), i) == axesB.end()) - list_B.emplace_back(i); - - permutAt = list_A; - permutAt.insert(permutAt.end(), axesA.begin(), axesA.end()); - permutBt = axesB; - permutBt.insert(permutBt.end(), list_B.begin(), list_B.end()); - - // if permut contains something like {0,1,2,..rank-1}, then there is no need to make permutation and we return empty vector in this case - uint i1, i2; - for(i1 = 0; i1 < aRank; ++i1) - if(permutAt[i1] != i1) - break; - if(i1 == aRank) - permutAt = {}; - for(i2 = 0; i2 < bRank; ++i2) - if(permutBt[i2] != i2) - break; - if(i2 == bRank) - permutBt = {}; - - Nd4jLong n2 = 1; - for (int i = 0; i < axeAsize; i++) - n2 *= aShapeInfo[axesA[i] + 1]; - shapeAt = {shape::length(aShapeInfo) / n2, n2}; - - std::vector oldShapeA; - oldShapeA.resize(list_A.size()); - for (int i = 0; i < oldShapeA.size(); ++i) - oldShapeA[i] = aShapeInfo[list_A[i] + 1]; - - - Nd4jLong n3 = 1; - for (int i = 0; i < axeBsize; i++) - n3 *= bShapeInfo[axesB[i] + 1]; - shapeBt = {n3, shape::length(bShapeInfo) / n3}; - - std::vector oldShapeB; - oldShapeB.resize(list_B.size()); - for (int i = 0; i < oldShapeB.size(); i++) - oldShapeB[i] = bShapeInfo[list_B[i] + 1]; - - std::vector aPlusB(oldShapeA); - aPlusB.insert(aPlusB.end(), oldShapeB.begin(), oldShapeB.end()); - - return aPlusB; +// evaluate shape for array resulting from tensorDot operation, also evaluate +// shapes and dimensions permutations for transposition of two input arrays +std::vector ShapeUtils::evalShapeForTensorDot( + const Nd4jLong* aShapeInfo, const Nd4jLong* bShapeInfo, + std::vector axesA, std::vector axesB, std::vector& permutAt, + std::vector& permutBt, std::vector& shapeAt, + std::vector& shapeBt) { + int axeAsize = (int)axesA.size(); + int axeBsize = (int)axesB.size(); + int aRank = aShapeInfo[0]; + int bRank = bShapeInfo[0]; + + if (axeAsize != axeBsize) + throw std::runtime_error( + "ShapeUtils::evalShapeForTensorDot method: the numbers of a axes and b " + "axes to make dot product along must have identical values !"); + if (axeAsize > aRank || axeBsize > bRank) + throw std::runtime_error( + "ShapeUtils::evalShapeForTensorDot method: the length of vector of a " + "or b axes is larger than array rank !"); + + // axes validation + for (int i = 0; i < axeBsize; i++) { + if (axesA[i] < 0) axesA[i] += aRank; + if (axesB[i] < 0) axesB[i] += bRank; + if (aShapeInfo[axesA[i] + 1] != bShapeInfo[axesB[i] + 1]) + throw std::runtime_error( + "ShapeUtils::evalShapeForTensorDot method: the dimensions at given " + "axes for both input arrays must be the same !"); + } + + // check whether axesA and axesB contain only unique numbers + std::set uniqueElems(axesA.begin(), axesA.end()); + if ((int)uniqueElems.size() != axeAsize) + throw std::runtime_error( + "ShapeUtils::evalShapeForTensorDot method: the vector of a axes " + "contains duplicates !"); + uniqueElems.clear(); + uniqueElems = std::set(axesB.begin(), axesB.end()); + if ((int)uniqueElems.size() != axeBsize) + throw std::runtime_error( + "ShapeUtils::evalShapeForTensorDot method: the vector of b axes " + "contains duplicates !"); + + std::vector list_A, list_B; + for (int i = 0; i < aRank; i++) + if (std::find(axesA.begin(), axesA.end(), i) == axesA.end()) + list_A.emplace_back(i); + for (int i = 0; i < bRank; i++) + if (std::find(axesB.begin(), axesB.end(), i) == axesB.end()) + list_B.emplace_back(i); + + permutAt = list_A; + permutAt.insert(permutAt.end(), axesA.begin(), axesA.end()); + permutBt = axesB; + permutBt.insert(permutBt.end(), list_B.begin(), list_B.end()); + + // if permut contains something like {0,1,2,..rank-1}, then there is no need + // to make permutation and we return empty vector in this case + uint i1, i2; + for (i1 = 0; i1 < aRank; ++i1) + if (permutAt[i1] != i1) break; + if (i1 == aRank) permutAt = {}; + for (i2 = 0; i2 < bRank; ++i2) + if (permutBt[i2] != i2) break; + if (i2 == bRank) permutBt = {}; + + Nd4jLong n2 = 1; + for (int i = 0; i < axeAsize; i++) n2 *= aShapeInfo[axesA[i] + 1]; + shapeAt = {shape::length(aShapeInfo) / n2, n2}; + + std::vector oldShapeA; + oldShapeA.resize(list_A.size()); + for (int i = 0; i < oldShapeA.size(); ++i) + oldShapeA[i] = aShapeInfo[list_A[i] + 1]; + + Nd4jLong n3 = 1; + for (int i = 0; i < axeBsize; i++) n3 *= bShapeInfo[axesB[i] + 1]; + shapeBt = {n3, shape::length(bShapeInfo) / n3}; + + std::vector oldShapeB; + oldShapeB.resize(list_B.size()); + for (int i = 0; i < oldShapeB.size(); i++) + oldShapeB[i] = bShapeInfo[list_B[i] + 1]; + + std::vector aPlusB(oldShapeA); + aPlusB.insert(aPlusB.end(), oldShapeB.begin(), oldShapeB.end()); + + return aPlusB; } ////////////////////////////////////////////////////////////////////////// -std::vector ShapeUtils::evalShapeForTensorDot(const NDArray* a, const NDArray* b, const std::vector& axesA, const std::vector& axesB, std::vector& permutAt, std::vector& permutBt, std::vector& shapeAt, std::vector& shapeBt) { - - return evalShapeForTensorDot(a->shapeInfo(), b->shapeInfo(), axesA, axesB, permutAt, permutBt, shapeAt, shapeBt); +std::vector ShapeUtils::evalShapeForTensorDot( + const NDArray* a, const NDArray* b, const std::vector& axesA, + const std::vector& axesB, std::vector& permutAt, + std::vector& permutBt, std::vector& shapeAt, + std::vector& shapeBt) { + return evalShapeForTensorDot(a->shapeInfo(), b->shapeInfo(), axesA, axesB, + permutAt, permutBt, shapeAt, shapeBt); } - ////////////////////////////////////////////////////////////////////////// // evaluate output shape for reduce operation when input shape is empty - const Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, const std::vector& vdimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace) { - auto dimsToExclude = vdimsToExclude; - - if (dimsToExclude.size() == 0) { // return copy of input shape - Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dataType, true, workspace); - ShapeDescriptor descriptor(outShapeInfo, dataType); - RELEASE(outShapeInfo, workspace); - return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT(); - } - - const int rank = shape::rank(shapeInfo); - Nd4jLong* outShapeInfo = nullptr; - - if (dimsToExclude.size() == rank) { // return scalar or shape filled with unities +const Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty( + const char order, const std::vector& vdimsToExclude, + const Nd4jLong* shapeInfo, const sd::DataType dataType, const bool keepDims, + sd::memory::Workspace* workspace) { + auto dimsToExclude = vdimsToExclude; + + if (dimsToExclude.size() == 0) { // return copy of input shape + Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType( + shapeInfo, dataType, true, workspace); + ShapeDescriptor descriptor(outShapeInfo, dataType); + RELEASE(outShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(descriptor) + .primaryAsT(); + } - if(!keepDims) - outShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace); - else - outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, std::vector(rank, 1), workspace); - } - else { + const int rank = shape::rank(shapeInfo); + Nd4jLong* outShapeInfo = nullptr; - shape::checkDimensions(rank, dimsToExclude); + if (dimsToExclude.size() == + rank) { // return scalar or shape filled with unities - std::vector outShape; + if (!keepDims) + outShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace); + else + outShapeInfo = ShapeBuilders::createShapeInfo( + dataType, order, std::vector(rank, 1), workspace); + } else { + shape::checkDimensions(rank, dimsToExclude); - if(keepDims) { - outShape.assign(shapeInfo + 1, shapeInfo + 1 + rank); - for(const auto& dim : dimsToExclude) - outShape[dim] = 1; - } - else { - for (uint i = 0, j = 0; i < rank; ++i) { - if(j < dimsToExclude.size() && i == dimsToExclude[j]) - ++j; - else - outShape.emplace_back(shapeInfo[i + 1]); - } - } + std::vector outShape; - outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, outShape, workspace); + if (keepDims) { + outShape.assign(shapeInfo + 1, shapeInfo + 1 + rank); + for (const auto& dim : dimsToExclude) outShape[dim] = 1; + } else { + for (uint i = 0, j = 0; i < rank; ++i) { + if (j < dimsToExclude.size() && i == dimsToExclude[j]) + ++j; + else + outShape.emplace_back(shapeInfo[i + 1]); + } } - ShapeDescriptor descriptor(outShapeInfo, dataType); - RELEASE(outShapeInfo, workspace); - return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT(); + outShapeInfo = + ShapeBuilders::createShapeInfo(dataType, order, outShape, workspace); + } + + ShapeDescriptor descriptor(outShapeInfo, dataType); + RELEASE(outShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(descriptor) + .primaryAsT(); } - const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, const std::vector& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { - return evalReduceShapeInfo(order, dimsToExclude, arr, arr.dataType(), keepDims, supportOldShapes, workspace); - } +const Nd4jLong* ShapeUtils::evalReduceShapeInfo( + const char order, const std::vector& dimsToExclude, const NDArray& arr, + const bool keepDims, const bool supportOldShapes, + sd::memory::Workspace* workspace) { + return evalReduceShapeInfo(order, dimsToExclude, arr, arr.dataType(), + keepDims, supportOldShapes, workspace); +} - const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, const std::vector& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { - return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace); - } +const Nd4jLong* ShapeUtils::evalReduceShapeInfo( + const char order, const std::vector& dimsToExclude, + const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, + sd::memory::Workspace* workspace) { + return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, + ArrayOptions::dataType(shapeInfo), keepDims, + supportOldShapes, workspace); +} ////////////////////////////////////////////////////////////////////////// - const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, const std::vector& dimsToExclude, const NDArray& arr, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { - return evalReduceShapeInfo(order, dimsToExclude, arr.shapeInfo(), dataType, keepDims, supportOldShapes, workspace); - } +const Nd4jLong* ShapeUtils::evalReduceShapeInfo( + const char order, const std::vector& dimsToExclude, const NDArray& arr, + const sd::DataType dataType, const bool keepDims, + const bool supportOldShapes, sd::memory::Workspace* workspace) { + return evalReduceShapeInfo(order, dimsToExclude, arr.shapeInfo(), dataType, + keepDims, supportOldShapes, workspace); +} ////////////////////////////////////////////////////////////////////////// // evaluate shape resulting from reduce operation - const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, const std::vector& vdimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) { - auto dimsToExclude = vdimsToExclude; - - if(ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY) - return ShapeUtils::evalReduceShapeInfoEmpty(order, dimsToExclude, shapeInfo, dataType, keepDims, workspace); - - Nd4jLong* newShapeInfo = nullptr; - - int rank = shape::rank(const_cast(shapeInfo)); - - if (dimsToExclude.size() == 0) { // return scalar or array with len=1 in this case - - if(keepDims && rank > 1) { - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); - newShapeInfo[0] = rank; - for(int i = 0; i < rank; ++i) - newShapeInfo[i+1] = 1; - ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order); - ArrayOptions::setDataType(newShapeInfo, dataType); - - ShapeDescriptor descriptor(newShapeInfo, dataType); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT(); - } - else if(supportOldShapes) { - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); - shape::shapeOldScalar(dataType, newShapeInfo, 'c'); - ShapeDescriptor descriptor(newShapeInfo, dataType); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT(); - } - else { - newShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace); - ShapeDescriptor descriptor(newShapeInfo, dataType); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT(); - } +const Nd4jLong* ShapeUtils::evalReduceShapeInfo( + const char order, const std::vector& vdimsToExclude, + const Nd4jLong* shapeInfo, const sd::DataType dataType, const bool keepDims, + const bool supportOldShapes, sd::memory::Workspace* workspace) { + auto dimsToExclude = vdimsToExclude; + + if (ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY) + return ShapeUtils::evalReduceShapeInfoEmpty(order, dimsToExclude, shapeInfo, + dataType, keepDims, workspace); + + Nd4jLong* newShapeInfo = nullptr; + + int rank = shape::rank(const_cast(shapeInfo)); + + if (dimsToExclude.size() == + 0) { // return scalar or array with len=1 in this case + + if (keepDims && rank > 1) { + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); + newShapeInfo[0] = rank; + for (int i = 0; i < rank; ++i) newShapeInfo[i + 1] = 1; + ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order); + ArrayOptions::setDataType(newShapeInfo, dataType); + + ShapeDescriptor descriptor(newShapeInfo, dataType); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(descriptor) + .primaryAsT(); + } else if (supportOldShapes) { + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); + shape::shapeOldScalar(dataType, newShapeInfo, 'c'); + ShapeDescriptor descriptor(newShapeInfo, dataType); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(descriptor) + .primaryAsT(); + } else { + newShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace); + ShapeDescriptor descriptor(newShapeInfo, dataType); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(descriptor) + .primaryAsT(); } - - shape::checkDimensions(rank, dimsToExclude); - - int dimSize = dimsToExclude.size(); - - if(keepDims) { - - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); - newShapeInfo[0] = rank; - for(int i = 0; i < rank; ++i) - if (std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i)) // dimsToExclude is already sorted after shape::checkDimensions() has been applied - newShapeInfo[i+1] = 1; - else - newShapeInfo[i+1] = shapeInfo[i+1]; - - ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order); - ShapeDescriptor descriptor(newShapeInfo, dataType); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT(); + } + + shape::checkDimensions(rank, dimsToExclude); + + int dimSize = dimsToExclude.size(); + + if (keepDims) { + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); + newShapeInfo[0] = rank; + for (int i = 0; i < rank; ++i) + if (std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), + i)) // dimsToExclude is already sorted after + // shape::checkDimensions() has been applied + newShapeInfo[i + 1] = 1; + else + newShapeInfo[i + 1] = shapeInfo[i + 1]; + + ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order); + ShapeDescriptor descriptor(newShapeInfo, dataType); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(descriptor) + .primaryAsT(); + } + + int newRank = rank - dimSize; + if (newRank == 0 || + (dimSize == 1 && + dimsToExclude[0] == INT_MAX)) { // check whether given dimension is + // meant for the whole dimension + + if (supportOldShapes) { + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); + shape::shapeOldScalar(ArrayOptions::dataType(shapeInfo), newShapeInfo, + 'c'); + ShapeDescriptor descriptor(newShapeInfo, dataType); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(descriptor) + .primaryAsT(); + } else { + newShapeInfo = ShapeBuilders::createScalarShapeInfo( + ArrayOptions::dataType(shapeInfo), workspace); + ShapeDescriptor descriptor(newShapeInfo, dataType); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(descriptor) + .primaryAsT(); } - - int newRank = rank - dimSize; - if (newRank==0 || (dimSize==1 && dimsToExclude[0]==INT_MAX)) { // check whether given dimension is meant for the whole dimension - - if(supportOldShapes) { - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); - shape::shapeOldScalar(ArrayOptions::dataType(shapeInfo), newShapeInfo, 'c'); - ShapeDescriptor descriptor(newShapeInfo, dataType); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT(); - } - else { - newShapeInfo = ShapeBuilders::createScalarShapeInfo(ArrayOptions::dataType(shapeInfo), workspace); - ShapeDescriptor descriptor(newShapeInfo, dataType); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT(); - } - } - - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(newRank), Nd4jLong); - newShapeInfo[0] = newRank; // set rank - int j=1; - for(int i = 0; i < rank; ++i) - if (!std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i)) // dimsToExclude is already sorted after shape::checkDimensions() has been applied - newShapeInfo[j++] = shapeInfo[i+1]; - - //ensure whether vector has proper shape for old shape type - if (newRank == 1 && supportOldShapes) { - int oldValue = newShapeInfo[1]; - RELEASE(newShapeInfo, workspace); - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); // set newRank = 2 - newShapeInfo[0] = 2; - if (dimsToExclude[0] == 0) { - newShapeInfo[1] = 1; - newShapeInfo[2] = oldValue; - } - else { - newShapeInfo[1] = oldValue; - newShapeInfo[2] = 1; - } + } + + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(newRank), Nd4jLong); + newShapeInfo[0] = newRank; // set rank + int j = 1; + for (int i = 0; i < rank; ++i) + if (!std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), + i)) // dimsToExclude is already sorted after + // shape::checkDimensions() has been applied + newShapeInfo[j++] = shapeInfo[i + 1]; + + // ensure whether vector has proper shape for old shape type + if (newRank == 1 && supportOldShapes) { + int oldValue = newShapeInfo[1]; + RELEASE(newShapeInfo, workspace); + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), + Nd4jLong); // set newRank = 2 + newShapeInfo[0] = 2; + if (dimsToExclude[0] == 0) { + newShapeInfo[1] = 1; + newShapeInfo[2] = oldValue; + } else { + newShapeInfo[1] = oldValue; + newShapeInfo[2] = 1; } + } - ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order); + ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order); - ShapeDescriptor descriptor(newShapeInfo, dataType); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT(); + ShapeDescriptor descriptor(newShapeInfo, dataType); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(descriptor) + .primaryAsT(); } ////////////////////////////////////////////////////////////////////////// // evaluate shape for array which is result of repeat operation applied to arr -std::vector ShapeUtils::evalRepeatShape(int axis, const std::vector& repeats, const NDArray& arr) { - - if (axis < 0) - axis += arr.rankOf(); +std::vector ShapeUtils::evalRepeatShape( + int axis, const std::vector& repeats, const NDArray& arr) { + if (axis < 0) axis += arr.rankOf(); - if(repeats.size() != 1 && repeats.size() != arr.sizeAt(axis)) - throw std::invalid_argument("ShapeUtils::evalRepeatShape: size of repeats vector must be 1 or equal to dimension at given axis !"); + if (repeats.size() != 1 && repeats.size() != arr.sizeAt(axis)) + throw std::invalid_argument( + "ShapeUtils::evalRepeatShape: size of repeats vector must be 1 or " + "equal to dimension at given axis !"); - std::vector outShape = arr.getShapeAsVector(); + std::vector outShape = arr.getShapeAsVector(); - if(repeats.size() == 1) - outShape[axis] *= repeats[0]; - else - outShape[axis] = std::accumulate(repeats.begin(), repeats.end(), 0); + if (repeats.size() == 1) + outShape[axis] *= repeats[0]; + else + outShape[axis] = std::accumulate(repeats.begin(), repeats.end(), 0); - return outShape; + return outShape; } ////////////////////////////////////////////////////////////////////////// // evaluate shapeInfo of permuted array - const Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides) { - - if (!arr.nonNull()) - throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!"); +const Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, + const int rank, + const NDArray& arr, + sd::memory::Workspace* workspace, + const bool setContigStrides) { + if (!arr.nonNull()) + throw std::runtime_error( + "ShapeUtils::evalPermShapeInfo static method: wrong arguments: array " + "is nullptr!"); - if (rank != arr.rankOf()) - throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!"); + if (rank != arr.rankOf()) + throw std::runtime_error( + "ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is " + "not suitable!"); - auto shapeInfoLength = shape::shapeInfoLength(rank); + auto shapeInfoLength = shape::shapeInfoLength(rank); - // allocate memory for new array - shapeInfo - Nd4jLong *shapeInfoNew = nullptr; - ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong); + // allocate memory for new array - shapeInfo + Nd4jLong* shapeInfoNew = nullptr; + ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong); - // copy arr _shapeInfo into new array - memcpy(shapeInfoNew, arr.shapeInfo(), shape::shapeInfoByteLength(rank)); + // copy arr _shapeInfo into new array + memcpy(shapeInfoNew, arr.shapeInfo(), shape::shapeInfoByteLength(rank)); - // perform buffer permutation - shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf()); + // perform buffer permutation + shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf()); - if(setContigStrides) - shape::updateStrides(shapeInfoNew, arr.ordering()); + if (setContigStrides) shape::updateStrides(shapeInfoNew, arr.ordering()); - ShapeDescriptor descriptor(shapeInfoNew); + ShapeDescriptor descriptor(shapeInfoNew); - RELEASE(shapeInfoNew, workspace); - - return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT(); - } + RELEASE(shapeInfoNew, workspace); - ////////////////////////////////////////////////////////////////////////// - // evaluate shapeInfo of permuted array - const Nd4jLong* ShapeUtils::evalPermShapeInfo(const Nd4jLong *dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace) { + return ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(descriptor) + .primaryAsT(); +} - std::vector dims(dimensions, dimensions + rank); - return evalPermShapeInfo(dims.data(), rank, arr, workspace); - } +////////////////////////////////////////////////////////////////////////// +// evaluate shapeInfo of permuted array +const Nd4jLong* ShapeUtils::evalPermShapeInfo( + const Nd4jLong* dimensions, const int rank, const NDArray& arr, + sd::memory::Workspace* workspace) { + std::vector dims(dimensions, dimensions + rank); + return evalPermShapeInfo(dims.data(), rank, arr, workspace); +} ////////////////////////////////////////////////////////////////////////// // evaluate shapeInfo of transposed array - const Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides) { - - int rank = arr.rankOf(); - std::vector dimensions(rank); - for (int i = 0; i < rank; ++i) - dimensions[i] = rank - 1 - i; - - return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace, setContigStrides); - } +const Nd4jLong* ShapeUtils::evalTranspShapeInfo( + const NDArray& arr, sd::memory::Workspace* workspace, + const bool setContigStrides) { + int rank = arr.rankOf(); + std::vector dimensions(rank); + for (int i = 0; i < rank; ++i) dimensions[i] = rank - 1 - i; + + return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace, + setContigStrides); +} ////////////////////////////////////////////////////////////////////////// - bool ShapeUtils::copyVectorPart(std::vector& target, std::vector& source, int rank, int offset) { - if (source.size() < offset + rank) - return false; - - for (int e = offset; e < offset + rank; e++) - target.push_back(source[e]); +bool ShapeUtils::copyVectorPart(std::vector& target, + std::vector& source, int rank, + int offset) { + if (source.size() < offset + rank) return false; - return true; - } + for (int e = offset; e < offset + rank; e++) target.push_back(source[e]); + return true; +} ////////////////////////////////////////////////////////////////////////// -// return new (shorter) sorted dimensions array without dimensions that are present in input vector -std::vector ShapeUtils::evalDimsToExclude(const int rank, const int dimsLen, const int* dimensions) { - - std::vector newDimensions; - if(dimsLen == 0) { // if input vector is empty then return whole shape range - newDimensions.resize(rank); - std::iota(newDimensions.begin(), newDimensions.end(), 0); // fill with 0, 1, ... rank-1 - } - else { - bool isAbsent; - for(int i=0; i= 0 ? dimensions[j] : dimensions[j] + rank; - if(i == dim) { - isAbsent = false; - break; - } - } - if(isAbsent) - newDimensions.emplace_back(i); +// return new (shorter) sorted dimensions array without dimensions that are +// present in input vector +std::vector ShapeUtils::evalDimsToExclude(const int rank, + const int dimsLen, + const int* dimensions) { + std::vector newDimensions; + if (dimsLen == 0) { // if input vector is empty then return whole shape range + newDimensions.resize(rank); + std::iota(newDimensions.begin(), newDimensions.end(), + 0); // fill with 0, 1, ... rank-1 + } else { + bool isAbsent; + for (int i = 0; i < rank; ++i) { + isAbsent = true; + for (int j = 0; j < dimsLen; ++j) { + int dim = dimensions[j] >= 0 ? dimensions[j] : dimensions[j] + rank; + if (i == dim) { + isAbsent = false; + break; } + } + if (isAbsent) newDimensions.emplace_back(i); } + } - return newDimensions; + return newDimensions; } ////////////////////////////////////////////////////////////////////////// -std::vector ShapeUtils::evalDimsToExclude(const int rank, const std::vector& dimensions) { - - return ShapeUtils::evalDimsToExclude(rank, dimensions.size(), dimensions.data()); +std::vector ShapeUtils::evalDimsToExclude( + const int rank, const std::vector& dimensions) { + return ShapeUtils::evalDimsToExclude(rank, dimensions.size(), + dimensions.data()); } ////////////////////////////////////////////////////////////////////////// // check whether 2 arrays have mutually broadcastable shapes // shape comparison starts from the end -bool ShapeUtils::areShapesBroadcastable(const NDArray &arr1, const NDArray &arr2) { - return areShapesBroadcastable(arr1.shapeInfo(), arr2.shapeInfo()); +bool ShapeUtils::areShapesBroadcastable(const NDArray& arr1, + const NDArray& arr2) { + return areShapesBroadcastable(arr1.shapeInfo(), arr2.shapeInfo()); } -bool ShapeUtils::areShapesBroadcastable(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2) { - int minRank = shape::rank(shapeInfo1) < shape::rank(shapeInfo2) ? shape::rank(shapeInfo1) : shape::rank(shapeInfo2); +bool ShapeUtils::areShapesBroadcastable(const Nd4jLong* shapeInfo1, + const Nd4jLong* shapeInfo2) { + int minRank = shape::rank(shapeInfo1) < shape::rank(shapeInfo2) + ? shape::rank(shapeInfo1) + : shape::rank(shapeInfo2); - for (int i = -1; i >= -minRank; --i) - if (shape::sizeAt(shapeInfo1, i) != shape::sizeAt(shapeInfo2, i) && shape::sizeAt(shapeInfo1, i) != 1 && shape::sizeAt(shapeInfo2, i) != 1) - return false; + for (int i = -1; i >= -minRank; --i) + if (shape::sizeAt(shapeInfo1, i) != shape::sizeAt(shapeInfo2, i) && + shape::sizeAt(shapeInfo1, i) != 1 && shape::sizeAt(shapeInfo2, i) != 1) + return false; - return true; + return true; } - bool ShapeUtils::areShapesBroadcastable(const std::vector& shape1, const std::vector& shape2) { - - const auto rank1 = shape1.size(); - const auto rank2 = shape2.size(); - const int minRank = rank1 < rank2 ? rank1 : rank2; - - for (int i = 1; i <= minRank; ++i) - if (shape1[rank1-i] != shape2[rank2-i] && shape1[rank1-i] != 1 && shape2[rank2-i] != 1) - return false; - - return true; - } - - ////////////////////////////////////////////////////////////////////////// - // check the possibility of broadcast operation, if true then return shapeInfo of resulting array - // if evalMinMax == false the array with larger rank has to be passed as first argument - bool ShapeUtils::evalBroadcastShapeInfo(const NDArray &max, const NDArray &min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace) { - return evalBroadcastShapeInfo(max.shapeInfo(), min.shapeInfo(), evalMinMax, resultShapeInfo, workspace); - } - - bool ShapeUtils::evalBroadcastShapeInfo(const Nd4jLong *max, const Nd4jLong *min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace) { - - // check whether broadcast operation is possible for input arrays - if(!areShapesBroadcastable(max, min)) - return false; - - auto maxShapeInfo = max; //max.shapeInfo(); - auto minShapeInfo = min; //min.shapeInfo(); - - if(evalMinMax && (shape::rank(max) < shape::rank(min))) { - maxShapeInfo = min; - minShapeInfo = max; - } - - const auto maxRank = shape::rank(maxShapeInfo); - const auto minRank = shape::rank(minShapeInfo); - - // evaluate shapeInfo for resulting array - if(resultShapeInfo != nullptr) - throw std::runtime_error("std::runtime_error(ShapeUtils::evalBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) !"); - - Nd4jLong *tmpShapeInfo = nullptr; - ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong); - - // FIXME: get rid of memcpy here - memcpy(tmpShapeInfo, maxShapeInfo, shape::shapeInfoByteLength(maxRank)); - for (int i = 0; i < minRank; ++i) - if((maxShapeInfo[maxRank-i] != 0 && maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i]) || minShapeInfo[minRank-i] == 0) - tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i]; - - ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo)); - - if (shape::isEmpty(max) || shape::isEmpty(min)) { - ArrayOptions::setPropertyBit(tmpShapeInfo, ARRAY_EMPTY); - memset(shape::stride(tmpShapeInfo), 0, shape::rank(tmpShapeInfo) * sizeof(Nd4jLong)); - } - - ShapeDescriptor descriptor(tmpShapeInfo); - RELEASE(tmpShapeInfo, workspace); - resultShapeInfo = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT(); - - return true; - } - - ////////////////////////////////////////////////////////////////////////// - // check the possibility of broadcast operation for set of arrays, if true then return resulting broadcasted shapeInfo - bool ShapeUtils::evalCommonBroadcastShapeInfo(const std::vector& arrays, Nd4jLong*& resultShapeInfo, memory::Workspace* workspace) { - - if(resultShapeInfo != nullptr) - throw std::runtime_error("ShapeUtils::evalCommonBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) !"); - - int size = arrays.size(); - int maxRank = arrays[size - 1]->rankOf(); - - for(int i = 0; i < size - 1; ++i) { - if(arrays[i]->rankOf() > maxRank) - maxRank = arrays[i]->rankOf(); - for(int j = i + 1; j < size; ++j) - if(!areShapesBroadcastable(*arrays[i], *arrays[j])) - return false; - } - - Nd4jLong *tmpShapeInfo = nullptr; - ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong); - memset(tmpShapeInfo, 0, shape::shapeInfoByteLength(maxRank)); - tmpShapeInfo[0] = maxRank; - - for(const auto& item : arrays ) { - for(int i = -1; i >= -item->rankOf(); --i) - if(tmpShapeInfo[i + 1 + maxRank] < item->sizeAt(i)) - tmpShapeInfo[i + 1 + maxRank] = item->sizeAt(i); - } - - shape::updateStrides(tmpShapeInfo, arrays[0]->ordering()); - ArrayOptions::setDataType(tmpShapeInfo, arrays[0]->dataType()); - - ShapeDescriptor descriptor(tmpShapeInfo); - RELEASE(tmpShapeInfo, workspace); - resultShapeInfo = const_cast(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor)); - - return true; - } - - - ////////////////////////////////////////////////////////////////////////// - // return sorted vector of dimensions common (same) for two arrays, dimensions values corresponds to array with bigger rank - // for example if arr1{2,7}, arr2{2,5,4,7} then vector = {0,3} - std::vector ShapeUtils::getDimsWithSameShape(const NDArray& arr1, const NDArray& arr2) { +bool ShapeUtils::areShapesBroadcastable(const std::vector& shape1, + const std::vector& shape2) { + const auto rank1 = shape1.size(); + const auto rank2 = shape2.size(); + const int minRank = rank1 < rank2 ? rank1 : rank2; - const NDArray *min, *max; + for (int i = 1; i <= minRank; ++i) + if (shape1[rank1 - i] != shape2[rank2 - i] && shape1[rank1 - i] != 1 && + shape2[rank2 - i] != 1) + return false; - if(arr1.rankOf() >= arr2.rankOf()) { - max = &arr1; - min = &arr2; - } - else { - max = &arr2; - min = &arr1; - } - - const int rankDiff = max->rankOf() - min->rankOf(); - - std::vector dims; - - for (int i = 0; i < min->rankOf(); ++i) - if (min->sizeAt(i) == max->sizeAt(rankDiff + i)) - dims.emplace_back(rankDiff + i); + return true; +} - return dims; - } +////////////////////////////////////////////////////////////////////////// +// check the possibility of broadcast operation, if true then return shapeInfo +// of resulting array if evalMinMax == false the array with larger rank has to +// be passed as first argument +bool ShapeUtils::evalBroadcastShapeInfo(const NDArray& max, const NDArray& min, + const bool evalMinMax, + const Nd4jLong*& resultShapeInfo, + sd::memory::Workspace* workspace) { + return evalBroadcastShapeInfo(max.shapeInfo(), min.shapeInfo(), evalMinMax, + resultShapeInfo, workspace); +} - ////////////////////////////////////////////////////////////////////////// - // evaluate shapeInfo for resulting array from tile operation - const Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector& reps, sd::memory::Workspace* workspace) { - // check whether reps contains at least one zero (then throw exception) or whether all elements in reps are unities (then simply reshape or do nothing) - int repsSize = reps.size(); - Nd4jLong product = 1; - for(const auto& item : reps) - product *= item; - if(product == 0) - throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !"); - - int rankOld = arr.rankOf(); - int diff = rankOld - repsSize; - - // evaluate new shapeInfo - Nd4jLong* newShapeInfo = nullptr; - if(diff < 0) { - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(repsSize), Nd4jLong); - newShapeInfo[0] = repsSize; // set new rank - for(int i=1; i <= -diff; ++i) - newShapeInfo[i] = 1; // set unities to be new dimensions at left-hand side of newShapeInfo shape place - memcpy(newShapeInfo + 1 - diff, arr.shapeInfo() + 1, rankOld*sizeof(Nd4jLong)); // copy old dimensions to the right-hand side of newShapeInfo shape place - for(int i=1; i <= repsSize; ++i) - newShapeInfo[i] *= reps[i - 1]; // set new shape by multiplying old dimensions by corresponding numbers from reps - } - else { - ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rankOld), Nd4jLong); - memcpy(newShapeInfo, arr.shapeInfo(), shape::shapeInfoByteLength(rankOld)); // copy all elements of _shapeInfo to newShapeInfo - for(int i=1; i <= repsSize; ++i) - newShapeInfo[rankOld + 1 - i] *= reps[repsSize - i]; // set new shape by multiplying old dimensions by corresponding numbers from reps - } - shape::updateStrides(newShapeInfo, arr.ordering()); - ArrayOptions::setDataType(newShapeInfo, arr.dataType()); +bool ShapeUtils::evalBroadcastShapeInfo(const Nd4jLong* max, + const Nd4jLong* min, + const bool evalMinMax, + const Nd4jLong*& resultShapeInfo, + sd::memory::Workspace* workspace) { + // check whether broadcast operation is possible for input arrays + if (!areShapesBroadcastable(max, min)) return false; + + auto maxShapeInfo = max; // max.shapeInfo(); + auto minShapeInfo = min; // min.shapeInfo(); + + if (evalMinMax && (shape::rank(max) < shape::rank(min))) { + maxShapeInfo = min; + minShapeInfo = max; + } + + const auto maxRank = shape::rank(maxShapeInfo); + const auto minRank = shape::rank(minShapeInfo); + + // evaluate shapeInfo for resulting array + if (resultShapeInfo != nullptr) + throw std::runtime_error( + "std::runtime_error(ShapeUtils::evalBroadcastShapeInfo method: the " + "input pointer on shapeInfo must be empty (=nullptr) !"); + + Nd4jLong* tmpShapeInfo = nullptr; + ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong); + + // FIXME: get rid of memcpy here + memcpy(tmpShapeInfo, maxShapeInfo, shape::shapeInfoByteLength(maxRank)); + for (int i = 0; i < minRank; ++i) + if ((maxShapeInfo[maxRank - i] != 0 && + maxShapeInfo[maxRank - i] < minShapeInfo[minRank - i]) || + minShapeInfo[minRank - i] == 0) + tmpShapeInfo[maxRank - i] = minShapeInfo[minRank - i]; + + ShapeUtils::updateStridesAndType( + tmpShapeInfo, + DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), + shape::order(maxShapeInfo)); + + if (shape::isEmpty(max) || shape::isEmpty(min)) { + ArrayOptions::setPropertyBit(tmpShapeInfo, ARRAY_EMPTY); + memset(shape::stride(tmpShapeInfo), 0, + shape::rank(tmpShapeInfo) * sizeof(Nd4jLong)); + } + + ShapeDescriptor descriptor(tmpShapeInfo); + RELEASE(tmpShapeInfo, workspace); + resultShapeInfo = ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(descriptor) + .primaryAsT(); + + return true; +} - ShapeDescriptor descriptor(newShapeInfo); - RELEASE(newShapeInfo, workspace); - return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT(); - } +////////////////////////////////////////////////////////////////////////// +// check the possibility of broadcast operation for set of arrays, if true then +// return resulting broadcasted shapeInfo +bool ShapeUtils::evalCommonBroadcastShapeInfo( + const std::vector& arrays, Nd4jLong*& resultShapeInfo, + memory::Workspace* workspace) { + if (resultShapeInfo != nullptr) + throw std::runtime_error( + "ShapeUtils::evalCommonBroadcastShapeInfo method: the input pointer on " + "shapeInfo must be empty (=nullptr) !"); + + int size = arrays.size(); + int maxRank = arrays[size - 1]->rankOf(); + + for (int i = 0; i < size - 1; ++i) { + if (arrays[i]->rankOf() > maxRank) maxRank = arrays[i]->rankOf(); + for (int j = i + 1; j < size; ++j) + if (!areShapesBroadcastable(*arrays[i], *arrays[j])) return false; + } + + Nd4jLong* tmpShapeInfo = nullptr; + ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong); + memset(tmpShapeInfo, 0, shape::shapeInfoByteLength(maxRank)); + tmpShapeInfo[0] = maxRank; + + for (const auto& item : arrays) { + for (int i = -1; i >= -item->rankOf(); --i) + if (tmpShapeInfo[i + 1 + maxRank] < item->sizeAt(i)) + tmpShapeInfo[i + 1 + maxRank] = item->sizeAt(i); + } + + shape::updateStrides(tmpShapeInfo, arrays[0]->ordering()); + ArrayOptions::setDataType(tmpShapeInfo, arrays[0]->dataType()); + + ShapeDescriptor descriptor(tmpShapeInfo); + RELEASE(tmpShapeInfo, workspace); + resultShapeInfo = const_cast( + ConstantShapeHelper::getInstance()->createShapeInfo(descriptor)); + + return true; +} - std::vector ShapeUtils::pullShapeFromShapeInfo(const Nd4jLong *shapeInfo) { - std::vector shape(shape::rank(shapeInfo)); - int shapeSize = shape.size(); +////////////////////////////////////////////////////////////////////////// +// return sorted vector of dimensions common (same) for two arrays, dimensions +// values corresponds to array with bigger rank for example if arr1{2,7}, +// arr2{2,5,4,7} then vector = {0,3} +std::vector ShapeUtils::getDimsWithSameShape(const NDArray& arr1, + const NDArray& arr2) { + const NDArray *min, *max; + + if (arr1.rankOf() >= arr2.rankOf()) { + max = &arr1; + min = &arr2; + } else { + max = &arr2; + min = &arr1; + } + + const int rankDiff = max->rankOf() - min->rankOf(); + + std::vector dims; + + for (int i = 0; i < min->rankOf(); ++i) + if (min->sizeAt(i) == max->sizeAt(rankDiff + i)) + dims.emplace_back(rankDiff + i); + + return dims; +} - for (int e = 0; e < shapeSize; e++) - shape[e] = shape::shapeOf(shapeInfo)[e]; +////////////////////////////////////////////////////////////////////////// +// evaluate shapeInfo for resulting array from tile operation +const Nd4jLong* ShapeUtils::evalTileShapeInfo( + const NDArray& arr, const std::vector& reps, + sd::memory::Workspace* workspace) { + // check whether reps contains at least one zero (then throw exception) or + // whether all elements in reps are unities (then simply reshape or do + // nothing) + int repsSize = reps.size(); + Nd4jLong product = 1; + for (const auto& item : reps) product *= item; + if (product == 0) + throw std::runtime_error( + "NDArray::tile method: one of the elements in reps array is zero !"); + + int rankOld = arr.rankOf(); + int diff = rankOld - repsSize; + + // evaluate new shapeInfo + Nd4jLong* newShapeInfo = nullptr; + if (diff < 0) { + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(repsSize), + Nd4jLong); + newShapeInfo[0] = repsSize; // set new rank + for (int i = 1; i <= -diff; ++i) + newShapeInfo[i] = 1; // set unities to be new dimensions at left-hand + // side of newShapeInfo shape place + memcpy( + newShapeInfo + 1 - diff, arr.shapeInfo() + 1, + rankOld * sizeof(Nd4jLong)); // copy old dimensions to the right-hand + // side of newShapeInfo shape place + for (int i = 1; i <= repsSize; ++i) + newShapeInfo[i] *= + reps[i - 1]; // set new shape by multiplying old dimensions by + // corresponding numbers from reps + } else { + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rankOld), + Nd4jLong); + memcpy(newShapeInfo, arr.shapeInfo(), + shape::shapeInfoByteLength( + rankOld)); // copy all elements of _shapeInfo to newShapeInfo + for (int i = 1; i <= repsSize; ++i) + newShapeInfo[rankOld + 1 - i] *= + reps[repsSize - i]; // set new shape by multiplying old dimensions by + // corresponding numbers from reps + } + shape::updateStrides(newShapeInfo, arr.ordering()); + ArrayOptions::setDataType(newShapeInfo, arr.dataType()); + + ShapeDescriptor descriptor(newShapeInfo); + RELEASE(newShapeInfo, workspace); + return ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(descriptor) + .primaryAsT(); +} - return shape; - } +std::vector ShapeUtils::pullShapeFromShapeInfo( + const Nd4jLong* shapeInfo) { + std::vector shape(shape::rank(shapeInfo)); + int shapeSize = shape.size(); - std::string ShapeUtils::shapeAsString(const NDArray* array) { - std::string result; + for (int e = 0; e < shapeSize; e++) shape[e] = shape::shapeOf(shapeInfo)[e]; - result.append("["); - for (int e = 0; e < array->rankOf(); e++) { - result += flatbuffers::NumToString(array->sizeAt(e)); - if (e < array->rankOf() - 1) - result.append(", "); - } - result.append("]"); + return shape; +} - return result; - } +std::string ShapeUtils::shapeAsString(const NDArray* array) { + std::string result; - std::string ShapeUtils::shapeAsString(const NDArray& array) { - return shapeAsString(&array); - } + result.append("["); + for (int e = 0; e < array->rankOf(); e++) { + result += flatbuffers::NumToString(array->sizeAt(e)); + if (e < array->rankOf() - 1) result.append(", "); + } + result.append("]"); - std::string ShapeUtils::strideAsString(const NDArray& array) { - return strideAsString(&array); - } + return result; +} - std::string ShapeUtils::strideAsString(const NDArray* array) { - std::string result; - - auto shapeBuffer = array->shapeInfo(); //Nd4jLong* - int rank = (int)*shapeBuffer; - result.append("["); - for (int e = 0; e < rank; e++) { - if (e > 0) - result.append(","); - Nd4jLong stride = *(shapeBuffer + rank+1+e); - result += flatbuffers::NumToString(stride); - } - result.append("]"); +std::string ShapeUtils::shapeAsString(const NDArray& array) { + return shapeAsString(&array); +} - return result; - } +std::string ShapeUtils::strideAsString(const NDArray& array) { + return strideAsString(&array); +} - std::string ShapeUtils::shapeAsString(const std::vector& shape) { - std::string result; +std::string ShapeUtils::strideAsString(const NDArray* array) { + std::string result; - result.append("["); - for (int e = 0; e < shape.size(); e++) { - result += flatbuffers::NumToString(shape.at(e)); - if (e < shape.size() - 1) - result.append(", "); - } - result.append("]"); + auto shapeBuffer = array->shapeInfo(); // Nd4jLong* + int rank = (int)*shapeBuffer; + result.append("["); + for (int e = 0; e < rank; e++) { + if (e > 0) result.append(","); + Nd4jLong stride = *(shapeBuffer + rank + 1 + e); + result += flatbuffers::NumToString(stride); + } + result.append("]"); - return result; - } + return result; +} - std::string ShapeUtils::shapeAsString(const Nd4jLong* shapeInfo) { +std::string ShapeUtils::shapeAsString(const std::vector& shape) { + std::string result; - if(!shapeInfo) - throw std::runtime_error("ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr !"); + result.append("["); + for (int e = 0; e < shape.size(); e++) { + result += flatbuffers::NumToString(shape.at(e)); + if (e < shape.size() - 1) result.append(", "); + } + result.append("]"); - std::string result; + return result; +} - result.append("["); - for (int e = 0; e < shapeInfo[0]; e++) { - result += flatbuffers::NumToString(shapeInfo[e+1]); - if (e < shapeInfo[0] - 1) - result.append(", "); - } - result.append("]"); +std::string ShapeUtils::shapeAsString(const Nd4jLong* shapeInfo) { + if (!shapeInfo) + throw std::runtime_error( + "ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr " + "!"); - return result; - } + std::string result; - std::string ShapeUtils::shapeInfoAsString(const Nd4jLong* shapeInfo) { + result.append("["); + for (int e = 0; e < shapeInfo[0]; e++) { + result += flatbuffers::NumToString(shapeInfo[e + 1]); + if (e < shapeInfo[0] - 1) result.append(", "); + } + result.append("]"); - if(!shapeInfo) - throw std::runtime_error("ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr !"); + return result; +} - std::string result; +std::string ShapeUtils::shapeInfoAsString(const Nd4jLong* shapeInfo) { + if (!shapeInfo) + throw std::runtime_error( + "ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr " + "!"); - int len = shape::shapeInfoLength(shapeInfo[0]); + std::string result; - result.append("["); - for (int e = 0; e < len; e++) { - result += flatbuffers::NumToString(shapeInfo[e]); - if (e < len - 1) - result.append(", "); - } - result.append("]"); + int len = shape::shapeInfoLength(shapeInfo[0]); - return result; - } + result.append("["); + for (int e = 0; e < len; e++) { + result += flatbuffers::NumToString(shapeInfo[e]); + if (e < len - 1) result.append(", "); + } + result.append("]"); + return result; +} - std::string ShapeUtils::shapeAsString(const int rank, const Nd4jLong* shapeInfo) { - if(!shapeInfo) - throw std::runtime_error("ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr !"); +std::string ShapeUtils::shapeAsString(const int rank, + const Nd4jLong* shapeInfo) { + if (!shapeInfo) + throw std::runtime_error( + "ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr " + "!"); - std::string result; + std::string result; - result.append("["); - for (int e = 0; e < rank; e++) { - result += flatbuffers::NumToString(shapeInfo[e]); - if (e < rank - 1) - result.append(", "); - } - result.append("]"); + result.append("["); + for (int e = 0; e < rank; e++) { + result += flatbuffers::NumToString(shapeInfo[e]); + if (e < rank - 1) result.append(", "); + } + result.append("]"); - return result; - } + return result; +} ////////////////////////////////////////////////////////////////////////// std::vector ShapeUtils::shapeAsVector(const Nd4jLong* shapeInfo) { + if (!shapeInfo) + throw std::runtime_error( + "ShapeUtils::shapeAsVector method: input shapeInfo must not be nullptr " + "!"); - if(!shapeInfo) - throw std::runtime_error("ShapeUtils::shapeAsVector method: input shapeInfo must not be nullptr !"); - - std::vector vector(shapeInfo[0]); + std::vector vector(shapeInfo[0]); - for (uint e = 0; e < shapeInfo[0]; e++) - vector[e] = shapeInfo[e + 1]; + for (uint e = 0; e < shapeInfo[0]; e++) vector[e] = shapeInfo[e + 1]; - return vector; + return vector; } ////////////////////////////////////////////////////////////////////////// -// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal - const Nd4jLong* ShapeUtils::evalDiagShapeInfo(const Nd4jLong* shapeInfoConst, sd::memory::Workspace* workspace){ - auto shapeInfo = const_cast(shapeInfoConst); - - const auto rank = shape::rank(shapeInfo); - - Nd4jLong* outputShapeInfo = nullptr; - - if(shape::isVector(shapeInfo) || shape::isScalar(shapeInfo)) { - ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); - outputShapeInfo[0] = 2; - outputShapeInfo[1] = outputShapeInfo[2] = shape::length(shapeInfo); - } - else { - ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2*rank), Nd4jLong); - outputShapeInfo[0] = 2*rank; - for(int i = 1; i <= rank; ++i) - outputShapeInfo[i] = outputShapeInfo[i + rank] = shapeInfo[i]; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, shapeInfo, shape::order(shapeInfo)); - - auto result = ConstantShapeHelper::getInstance()->createShapeInfo(outputShapeInfo); - RELEASE(outputShapeInfo, workspace); - return result; - } +// evaluate shapeInfo for diagonal array which is made using input arr elements +// as diagonal +const Nd4jLong* ShapeUtils::evalDiagShapeInfo( + const Nd4jLong* shapeInfoConst, sd::memory::Workspace* workspace) { + auto shapeInfo = const_cast(shapeInfoConst); + + const auto rank = shape::rank(shapeInfo); + + Nd4jLong* outputShapeInfo = nullptr; + + if (shape::isVector(shapeInfo) || shape::isScalar(shapeInfo)) { + ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); + outputShapeInfo[0] = 2; + outputShapeInfo[1] = outputShapeInfo[2] = shape::length(shapeInfo); + } else { + ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2 * rank), + Nd4jLong); + outputShapeInfo[0] = 2 * rank; + for (int i = 1; i <= rank; ++i) + outputShapeInfo[i] = outputShapeInfo[i + rank] = shapeInfo[i]; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, shapeInfo, + shape::order(shapeInfo)); + + auto result = + ConstantShapeHelper::getInstance()->createShapeInfo(outputShapeInfo); + RELEASE(outputShapeInfo, workspace); + return result; +} -std::vector ShapeUtils::evalBroadcastBackwardAxis(const Nd4jLong *operandShapeInfo, const Nd4jLong *resultShapeInfo) { - // rRank >= oRank always !! - const auto oRank = shape::rank(operandShapeInfo); - const auto rRank = shape::rank(resultShapeInfo); - const auto diff = rRank - oRank; - std::vector axis; +std::vector ShapeUtils::evalBroadcastBackwardAxis( + const Nd4jLong* operandShapeInfo, const Nd4jLong* resultShapeInfo) { + // rRank >= oRank always !! + const auto oRank = shape::rank(operandShapeInfo); + const auto rRank = shape::rank(resultShapeInfo); + const auto diff = rRank - oRank; + std::vector axis; - for(int i = 0; i < rRank; ++i) - if(i < diff || shape::sizeAt(operandShapeInfo, i - diff) != shape::sizeAt(resultShapeInfo, i)) - axis.push_back(i); + for (int i = 0; i < rRank; ++i) + if (i < diff || shape::sizeAt(operandShapeInfo, i - diff) != + shape::sizeAt(resultShapeInfo, i)) + axis.push_back(i); - return axis; + return axis; } //////////////////////////////////////////////////////////////////////////////// - const Nd4jLong* ShapeUtils::matrixProductShape(const Nd4jLong* theFirstShape, const Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, sd::DataType dtype, sd::memory::Workspace* workspace) { - auto inA = theFirstShape; - auto inB = theSecondShape; - Nd4jLong *shape; - ALLOCATE(shape, workspace, shape::shapeInfoLength(2), Nd4jLong); - - Nd4jLong* tmpA = ShapeBuilders::copyShapeInfo(inA, true, workspace); - Nd4jLong* tmpB = ShapeBuilders::copyShapeInfo(inB, true, workspace); - - if (shouldTranspondFirst) - shape::transposeInplace(tmpA); - - if (shouldTranspondSecond) - shape::transposeInplace(tmpB); - - - if (shape::rank(tmpA) == 1 && shape::isMatrix(tmpB)) { - // special case here - shape[0] = 1; - shape[1] = tmpB[2]; - Nd4jLong *newShape = ShapeBuilders::createShapeInfo(dtype, 'f', 2, shape, workspace); - - RELEASE(shape, workspace); - RELEASE(tmpA, workspace); - RELEASE(tmpB, workspace); - - return newShape; - } else if (shape::isScalar(tmpA) && shape::isScalar(tmpB)) { - // just scalar vs scalar - shape[0] = 1; - shape[1] = 1; - } else if (shape::isMatrix(tmpA) && shape::isVector(tmpB)) { - // gemv case - if (shape::rank(tmpB) == 2) { - shape[0] = tmpA[1]; - shape[1] = tmpB[2]; - } else { - // we have new 1D shape here - auto newShape = ShapeBuilders::createVectorShapeInfo(dtype, tmpA[1], workspace); - - RELEASE(shape, workspace); - RELEASE(tmpA, workspace); - RELEASE(tmpB, workspace); - - return newShape; - } - } else if ((shape::isMatrix(tmpA) && shape::isMatrix(tmpB)) || - (shape::isVector(tmpA) && shape::isMatrix(tmpB)) || - (shape::isColumnVector(tmpA) && shape::isVector(tmpB))) { - // gemm case - shape[0] = tmpA[1]; - shape[1] = tmpB[2]; - } else if ((shape::isVector(tmpA) && shape::isScalar(tmpB)) || - (shape::isScalar(tmpA) && shape::isVector(tmpB))) { - // element-wise - shape[0] = 1; - shape[1] = (int) sd::math::nd4j_max(shape::length(tmpA), shape::length(tmpB)); - } else if (shape::isRowVector(tmpA) && shape::isRowVector(tmpB)) { - // dot case - shape[0] = 1; - shape[1] = 1; - } else if (shape::isRowVector(tmpA) && shape::isColumnVector(tmpB)) { - // dot case - shape[0] = 1; - shape[1] = 1; - } - - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'f', 2, shape); - - RELEASE(shape, workspace); - - RELEASE(tmpA, workspace); - RELEASE(tmpB, workspace); - return newShape; +const Nd4jLong* ShapeUtils::matrixProductShape( + const Nd4jLong* theFirstShape, const Nd4jLong* theSecondShape, + bool shouldTranspondFirst, bool shouldTranspondSecond, sd::DataType dtype, + sd::memory::Workspace* workspace) { + auto inA = theFirstShape; + auto inB = theSecondShape; + Nd4jLong* shape; + ALLOCATE(shape, workspace, shape::shapeInfoLength(2), Nd4jLong); + + Nd4jLong* tmpA = ShapeBuilders::copyShapeInfo(inA, true, workspace); + Nd4jLong* tmpB = ShapeBuilders::copyShapeInfo(inB, true, workspace); + + if (shouldTranspondFirst) shape::transposeInplace(tmpA); + + if (shouldTranspondSecond) shape::transposeInplace(tmpB); + + if (shape::rank(tmpA) == 1 && shape::isMatrix(tmpB)) { + // special case here + shape[0] = 1; + shape[1] = tmpB[2]; + Nd4jLong* newShape = + ShapeBuilders::createShapeInfo(dtype, 'f', 2, shape, workspace); + + RELEASE(shape, workspace); + RELEASE(tmpA, workspace); + RELEASE(tmpB, workspace); + + return newShape; + } else if (shape::isScalar(tmpA) && shape::isScalar(tmpB)) { + // just scalar vs scalar + shape[0] = 1; + shape[1] = 1; + } else if (shape::isMatrix(tmpA) && shape::isVector(tmpB)) { + // gemv case + if (shape::rank(tmpB) == 2) { + shape[0] = tmpA[1]; + shape[1] = tmpB[2]; + } else { + // we have new 1D shape here + auto newShape = + ShapeBuilders::createVectorShapeInfo(dtype, tmpA[1], workspace); + + RELEASE(shape, workspace); + RELEASE(tmpA, workspace); + RELEASE(tmpB, workspace); + + return newShape; } + } else if ((shape::isMatrix(tmpA) && shape::isMatrix(tmpB)) || + (shape::isVector(tmpA) && shape::isMatrix(tmpB)) || + (shape::isColumnVector(tmpA) && shape::isVector(tmpB))) { + // gemm case + shape[0] = tmpA[1]; + shape[1] = tmpB[2]; + } else if ((shape::isVector(tmpA) && shape::isScalar(tmpB)) || + (shape::isScalar(tmpA) && shape::isVector(tmpB))) { + // element-wise + shape[0] = 1; + shape[1] = (int)sd::math::nd4j_max(shape::length(tmpA), + shape::length(tmpB)); + } else if (shape::isRowVector(tmpA) && shape::isRowVector(tmpB)) { + // dot case + shape[0] = 1; + shape[1] = 1; + } else if (shape::isRowVector(tmpA) && shape::isColumnVector(tmpB)) { + // dot case + shape[0] = 1; + shape[1] = 1; + } + + auto newShape = + ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'f', 2, shape); + + RELEASE(shape, workspace); + + RELEASE(tmpA, workspace); + RELEASE(tmpB, workspace); + return newShape; +} //////////////////////////////////////////////////////////////////////////////// -std::vector ShapeUtils::evalPermutFromTo(const std::vector& shapeFrom, const std::vector& shapeTo) { - auto rank = shapeFrom.size(); - if(rank != shapeTo.size()) - throw std::runtime_error("ShapeUtils::evalPermutFromTo static method: the input shapes are not suitable for mutual permutation !"); - - if (std::equal(begin(shapeFrom), end(shapeFrom), begin(shapeTo))) // if shapes are identical (permutation is unnecessary) then return empty vector - return std::vector(); - - std::vector permutation(rank, -2); // vector to be returned - std::vector shapeTo2(shapeTo); // make copy of const vector since we will change the content of shapeTo - - for(int i=0; i ShapeUtils::evalPermutFromTo( + const std::vector& shapeFrom, + const std::vector& shapeTo) { + auto rank = shapeFrom.size(); + if (rank != shapeTo.size()) + throw std::runtime_error( + "ShapeUtils::evalPermutFromTo static method: the input shapes are not " + "suitable for mutual permutation !"); + + if (std::equal(begin(shapeFrom), end(shapeFrom), + begin(shapeTo))) // if shapes are identical (permutation is + // unnecessary) then return empty vector + return std::vector(); + + std::vector permutation(rank, -2); // vector to be returned + std::vector shapeTo2( + shapeTo); // make copy of const vector since we will change the content + // of shapeTo + + for (int i = 0; i < rank; ++i) + for (int j = 0; j < rank; ++j) + if (shapeFrom[i] == shapeTo2[j]) { + permutation[j] = i; + shapeTo2[j] = -2; // mark coincidence as -2 in order to not account + // index of shapeTo twice + break; + } + + if (std::find(begin(permutation), end(permutation), -2) != + end(permutation)) // if -2 is still present in vector then permutation is + // impossible + throw std::runtime_error( + "ShapeUtils::evalPermutFromTo static method: the input shapes are not " + "suitable for mutual permutation !"); + + return permutation; } - //////////////////////////////////////////////////////////////////////////////// -std::vector ShapeUtils::composeShapeUsingDimsAndIdx(const std::vector& dimsAndIdx) { - auto size = dimsAndIdx.size(); - if(size % 2 != 0) - throw std::runtime_error("ShapeUtils::composeShapeUsingDimsAndIdx static method: the size of input vector must be even !"); - - size /= 2; - - std::vector shape(size); - int index; - - for(int i = 0; i < size; ++i) { - index = dimsAndIdx[i + size]; - if(index > size-1) - throw std::runtime_error("ShapeUtils::composeShapeUsingDimsAndIdx static method: input index is too large !"); - shape[index] = dimsAndIdx[i]; - } - - return shape; +std::vector ShapeUtils::composeShapeUsingDimsAndIdx( + const std::vector& dimsAndIdx) { + auto size = dimsAndIdx.size(); + if (size % 2 != 0) + throw std::runtime_error( + "ShapeUtils::composeShapeUsingDimsAndIdx static method: the size of " + "input vector must be even !"); + + size /= 2; + + std::vector shape(size); + int index; + + for (int i = 0; i < size; ++i) { + index = dimsAndIdx[i + size]; + if (index > size - 1) + throw std::runtime_error( + "ShapeUtils::composeShapeUsingDimsAndIdx static method: input index " + "is too large !"); + shape[index] = dimsAndIdx[i]; + } + + return shape; } - //////////////////////////////////////////////////////////////////////////////// -std::vector ShapeUtils::evalShapeForMatmul(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const bool transX, const bool transY) { - - const auto xRank = xShapeInfo[0]; - const auto yRank = yShapeInfo[0]; - - const Nd4jLong x0Dim = transX ? xShapeInfo[xRank] : xShapeInfo[xRank-1]; - const Nd4jLong y0Dim = transY ? yShapeInfo[yRank] : yShapeInfo[yRank-1]; - const Nd4jLong x1Dim = transX ? xShapeInfo[xRank-1] : xShapeInfo[xRank]; - const Nd4jLong y1Dim = transY ? yShapeInfo[yRank-1] : yShapeInfo[yRank]; - - - if(xRank == 1 && yRank == 1) { // dot case, output is scalar - if(xShapeInfo[1] != yShapeInfo[1]) { - nd4j_printf("ShapeUtils::evalShapeForMatmul method: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", xShapeInfo[1], yShapeInfo[1]); - throw std::invalid_argument(""); - } - return std::vector({}); +std::vector ShapeUtils::evalShapeForMatmul(const Nd4jLong* xShapeInfo, + const Nd4jLong* yShapeInfo, + const bool transX, + const bool transY) { + const auto xRank = xShapeInfo[0]; + const auto yRank = yShapeInfo[0]; + + const Nd4jLong x0Dim = transX ? xShapeInfo[xRank] : xShapeInfo[xRank - 1]; + const Nd4jLong y0Dim = transY ? yShapeInfo[yRank] : yShapeInfo[yRank - 1]; + const Nd4jLong x1Dim = transX ? xShapeInfo[xRank - 1] : xShapeInfo[xRank]; + const Nd4jLong y1Dim = transY ? yShapeInfo[yRank - 1] : yShapeInfo[yRank]; + + if (xRank == 1 && yRank == 1) { // dot case, output is scalar + if (xShapeInfo[1] != yShapeInfo[1]) { + nd4j_printf( + "ShapeUtils::evalShapeForMatmul method: since input arrays are " + "vectors they must have the same length, but got x length = %i, y " + "length = %i !", + xShapeInfo[1], yShapeInfo[1]); + throw std::invalid_argument(""); } - - - if(xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector - if(xShapeInfo[1] != y0Dim) { - nd4j_printf("ShapeUtils::evalShapeForMatmul method: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(xShapeInfo).c_str(), ShapeUtils::shapeAsString(yShapeInfo).c_str()); - throw std::invalid_argument(""); - } - return std::vector({y1Dim}); - } - - - if(xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector - if(x1Dim != yShapeInfo[1]) { - nd4j_printf("ShapeUtils::evalShapeForMatmul method: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(xShapeInfo).c_str(), ShapeUtils::shapeAsString(yShapeInfo).c_str()); - throw std::invalid_argument(""); - } - return std::vector({x0Dim}); + return std::vector({}); + } + + if (xRank == 1 && + yRank == + 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector + if (xShapeInfo[1] != y0Dim) { + nd4j_printf( + "ShapeUtils::evalShapeForMatmul method: input arrays have " + "inconsistent shapes for vector-matrix product: x %s, y %s !", + ShapeUtils::shapeAsString(xShapeInfo).c_str(), + ShapeUtils::shapeAsString(yShapeInfo).c_str()); + throw std::invalid_argument(""); } - - - // rest cases - usual 2Dx2D or batched mmul - if(xRank != yRank) { - nd4j_printf("ShapeUtils::evalShapeForMatmul static method: the ranks of arrays must be the same, but got xRank = %i and yRank = %i ! \n", xRank, yRank); - throw std::invalid_argument(""); + return std::vector({y1Dim}); + } + + if (xRank == 2 && + yRank == + 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector + if (x1Dim != yShapeInfo[1]) { + nd4j_printf( + "ShapeUtils::evalShapeForMatmul method: input arrays have " + "inconsistent shapes for vector-matrix product: x %s, y %s !", + ShapeUtils::shapeAsString(xShapeInfo).c_str(), + ShapeUtils::shapeAsString(yShapeInfo).c_str()); + throw std::invalid_argument(""); } - - if(x1Dim != y0Dim) { - nd4j_printf("ShapeUtils::evalShapeForMatmul static method: input shapes are inconsistent: xDim %i != yDim %i \n", x1Dim, y0Dim); - throw std::invalid_argument(""); + return std::vector({x0Dim}); + } + + // rest cases - usual 2Dx2D or batched mmul + if (xRank != yRank) { + nd4j_printf( + "ShapeUtils::evalShapeForMatmul static method: the ranks of arrays " + "must be the same, but got xRank = %i and yRank = %i ! \n", + xRank, yRank); + throw std::invalid_argument(""); + } + + if (x1Dim != y0Dim) { + nd4j_printf( + "ShapeUtils::evalShapeForMatmul static method: input shapes are " + "inconsistent: xDim %i != yDim %i \n", + x1Dim, y0Dim); + throw std::invalid_argument(""); + } + + for (int i = 0; i < xRank - 2; ++i) + if (xShapeInfo[i + 1] != yShapeInfo[i + 1]) { + nd4j_printf( + "ShapeUtils::evalShapeForMatmul static method: input shapes are " + "inconsistent: xShape = %s, yShape = %s ! \n", + ShapeUtils::shapeAsString(xShapeInfo).c_str(), + ShapeUtils::shapeAsString(yShapeInfo).c_str()); + throw std::invalid_argument(""); } - for(int i = 0; i < xRank - 2; ++i) - if(xShapeInfo[i+1] != yShapeInfo[i+1]) { - nd4j_printf("ShapeUtils::evalShapeForMatmul static method: input shapes are inconsistent: xShape = %s, yShape = %s ! \n", ShapeUtils::shapeAsString(xShapeInfo).c_str(), ShapeUtils::shapeAsString(yShapeInfo).c_str()); - throw std::invalid_argument(""); - } + std::vector cShape(xRank); - std::vector cShape(xRank); + // copy batch part of shape (if present) + for (int i = 0; i < xRank - 2; ++i) cShape[i] = xShapeInfo[i + 1]; + // copy rest part of shape (two dims: multiplication part) + cShape[xRank - 2] = x0Dim; + cShape[xRank - 1] = y1Dim; - // copy batch part of shape (if present) - for(int i = 0; i < xRank - 2; ++i) - cShape[i] = xShapeInfo[i+1]; - // copy rest part of shape (two dims: multiplication part) - cShape[xRank-2] = x0Dim; - cShape[xRank-1] = y1Dim; - - return cShape; + return cShape; } //////////////////////////////////////////////////////////////////////////////// -Nd4jLong ShapeUtils::getNumOfSubArrs(const Nd4jLong* shapeInfo, const std::vector& dimsToExclude) { - - Nd4jLong numOfSubArrs = 1; +Nd4jLong ShapeUtils::getNumOfSubArrs(const Nd4jLong* shapeInfo, + const std::vector& dimsToExclude) { + Nd4jLong numOfSubArrs = 1; - if(dimsToExclude.size() == shape::rank(shapeInfo) || dimsToExclude.size() == 0) // means there is only one sub-array and it coincides with whole array - return numOfSubArrs; + if (dimsToExclude.size() == shape::rank(shapeInfo) || + dimsToExclude.size() == 0) // means there is only one sub-array and it + // coincides with whole array + return numOfSubArrs; - for(const auto& dim : dimsToExclude) - numOfSubArrs *= shapeInfo[dim + 1]; + for (const auto& dim : dimsToExclude) numOfSubArrs *= shapeInfo[dim + 1]; - return numOfSubArrs; + return numOfSubArrs; } //////////////////////////////////////////////////////////////////////////////// -std::vector ShapeUtils::evalDimsWithoutUnities(const Nd4jLong* shapeInfo) { - - std::vector result; - for(int i = 1; i <= shapeInfo[0]; ++i) - if(shapeInfo[i] != 1) - result.push_back(shapeInfo[i]); +std::vector ShapeUtils::evalDimsWithoutUnities( + const Nd4jLong* shapeInfo) { + std::vector result; + for (int i = 1; i <= shapeInfo[0]; ++i) + if (shapeInfo[i] != 1) result.push_back(shapeInfo[i]); - return result; + return result; } //////////////////////////////////////////////////////////////////////////////// -void ShapeUtils::updateStridesAndType(Nd4jLong* dest, const Nd4jLong* source, const char order) { - - shape::updateStrides(dest, order); - ArrayOptions::copyDataType(dest, source); +void ShapeUtils::updateStridesAndType(Nd4jLong* dest, const Nd4jLong* source, + const char order) { + shape::updateStrides(dest, order); + ArrayOptions::copyDataType(dest, source); } //////////////////////////////////////////////////////////////////////////////// -void ShapeUtils::updateStridesAndType(Nd4jLong* dest, const DataType dtype, const char order) { - - shape::updateStrides(dest, order); - ArrayOptions::setDataType(dest, dtype); +void ShapeUtils::updateStridesAndType(Nd4jLong* dest, const DataType dtype, + const char order) { + shape::updateStrides(dest, order); + ArrayOptions::setDataType(dest, dtype); } //////////////////////////////////////////////////////////////////////////////// -std::vector ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, const NDArray& min) { - - const int maxRank = max.rankOf(); - const int minRank = min.rankOf(); - const int diff = maxRank - minRank; - - Nd4jLong numOfMinTads(1), numOfMaxTads(1); - std::vector maxTadDims; - - for(int i = 0; i < minRank; ++i) { - if(min.sizeAt(i) == max.sizeAt(diff + i)) - maxTadDims.push_back(diff + i); - else { - numOfMinTads *= min.sizeAt(i); - numOfMaxTads *= max.sizeAt(i); - } +std::vector ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, + const NDArray& min) { + const int maxRank = max.rankOf(); + const int minRank = min.rankOf(); + const int diff = maxRank - minRank; + + Nd4jLong numOfMinTads(1), numOfMaxTads(1); + std::vector maxTadDims; + + for (int i = 0; i < minRank; ++i) { + if (min.sizeAt(i) == max.sizeAt(diff + i)) + maxTadDims.push_back(diff + i); + else { + numOfMinTads *= min.sizeAt(i); + numOfMaxTads *= max.sizeAt(i); } + } - if(min.lengthOf() > max.lengthOf()) { // in this case tad is max array - for(int i = 0; i < diff; ++i) - numOfMaxTads *= max.sizeAt(i); + if (min.lengthOf() > max.lengthOf()) { // in this case tad is max array + for (int i = 0; i < diff; ++i) numOfMaxTads *= max.sizeAt(i); - return numOfMaxTads == 1 ? maxTadDims : std::vector(); - } + return numOfMaxTads == 1 ? maxTadDims : std::vector(); + } - return numOfMinTads == 1 ? maxTadDims : std::vector(); + return numOfMinTads == 1 ? maxTadDims : std::vector(); } -void ShapeUtils::copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, const int nRank, const int dimsSize, const int* dims, Nd4jLong* outStrides) { - - int yRank = shape::rank(inShapeInfo); - auto yOrigStride = shape::stride(inShapeInfo); - - if (yRank == nRank) { - for (int i = 0; i < yRank; ++i) { - // x[2,3,4] * y[2,1,4] = z[2,3,4] - outStrides[i] = (1 == shape::sizeAt(inShapeInfo, i)) ? 0 : yOrigStride[i]; - } +void ShapeUtils::copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, + const int nRank, + const int dimsSize, + const int* dims, + Nd4jLong* outStrides) { + int yRank = shape::rank(inShapeInfo); + auto yOrigStride = shape::stride(inShapeInfo); + + if (yRank == nRank) { + for (int i = 0; i < yRank; ++i) { + // x[2,3,4] * y[2,1,4] = z[2,3,4] + outStrides[i] = (1 == shape::sizeAt(inShapeInfo, i)) ? 0 : yOrigStride[i]; } - else { + } else { + auto dimEx = sd::ShapeUtils::evalDimsToExclude(nRank, dimsSize, dims); - auto dimEx = sd::ShapeUtils::evalDimsToExclude(nRank, dimsSize, dims); - - for (int i = 0, it = 0; i < nRank; ++i) { - auto nCount = std::count(dimEx.cbegin(), dimEx.cend(), i); - outStrides[i] = (0 == nCount) ? yOrigStride[it++] : 0; - if (it == yRank) - break; - } + for (int i = 0, it = 0; i < nRank; ++i) { + auto nCount = std::count(dimEx.cbegin(), dimEx.cend(), i); + outStrides[i] = (0 == nCount) ? yOrigStride[it++] : 0; + if (it == yRank) break; } + } } -bool ShapeUtils::areShapesEqual(const Nd4jLong* shapeInfo, const std::vector& shapeOnly) { - - if(shape::rank(shapeInfo) != shapeOnly.size()) - return false; +bool ShapeUtils::areShapesEqual(const Nd4jLong* shapeInfo, + const std::vector& shapeOnly) { + if (shape::rank(shapeInfo) != shapeOnly.size()) return false; - for(uint i = 0; i < shape::rank(shapeInfo); ++i) - if(shape::shapeOf(shapeInfo)[i] != shapeOnly[i]) - return false; + for (uint i = 0; i < shape::rank(shapeInfo); ++i) + if (shape::shapeOf(shapeInfo)[i] != shapeOnly[i]) return false; - return true; + return true; } //////////////////////////////////////////////////////////////////////////////// /* -bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector& sameDims) { +bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, +std::vector& sameDims) { if(!sameDims.empty()) sameDims.clear(); @@ -1089,7 +1239,8 @@ bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::v int numUnitiesInMin = 0; - for (int iMax = -1, iMin = -1; iMax >= -max->rankOf() && iMin >= -min->rankOf(); ) { + for (int iMax = -1, iMin = -1; iMax >= -max->rankOf() && iMin >= +-min->rankOf(); ) { if(max->sizeAt(iMax) == 1) { // ignore unities in shape --iMax; @@ -1114,6 +1265,4 @@ bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::v } */ -} - - +} // namespace sd diff --git a/libnd4j/include/helpers/impl/SimpleReadWriteLock.cpp b/libnd4j/include/helpers/impl/SimpleReadWriteLock.cpp index 52682b925fee..b6af27fbd813 100644 --- a/libnd4j/include/helpers/impl/SimpleReadWriteLock.cpp +++ b/libnd4j/include/helpers/impl/SimpleReadWriteLock.cpp @@ -20,51 +20,47 @@ #include - namespace sd { - SimpleReadWriteLock::SimpleReadWriteLock(const SimpleReadWriteLock& other) { - _read_locks.store(other._read_locks.load()); - _write_locks.store(other._write_locks.load()); - } +SimpleReadWriteLock::SimpleReadWriteLock(const SimpleReadWriteLock& other) { + _read_locks.store(other._read_locks.load()); + _write_locks.store(other._write_locks.load()); +} - SimpleReadWriteLock::SimpleReadWriteLock(){ - _read_locks.store(0); - _write_locks.store(0); - } +SimpleReadWriteLock::SimpleReadWriteLock() { + _read_locks.store(0); + _write_locks.store(0); +} - void SimpleReadWriteLock::lockRead() { - _mutex.lock(); - _read_locks++; - while(_write_locks.load() > 0) { - // just loop - } - _mutex.unlock(); - } +void SimpleReadWriteLock::lockRead() { + _mutex.lock(); + _read_locks++; + while (_write_locks.load() > 0) { + // just loop + } + _mutex.unlock(); +} - void SimpleReadWriteLock::unlockRead() { - _read_locks--; - } +void SimpleReadWriteLock::unlockRead() { _read_locks--; } - // write lock - void SimpleReadWriteLock::lockWrite() { - _mutex.lock(); - _write_locks++; - while (_read_locks.load() > 0) { - // just loop - } - _mutex.unlock(); - } +// write lock +void SimpleReadWriteLock::lockWrite() { + _mutex.lock(); + _write_locks++; + while (_read_locks.load() > 0) { + // just loop + } + _mutex.unlock(); +} - void SimpleReadWriteLock::unlockWrite() { - _write_locks--; - } +void SimpleReadWriteLock::unlockWrite() { _write_locks--; } - SimpleReadWriteLock& SimpleReadWriteLock::operator= ( const SimpleReadWriteLock &other) { - if (this == &other) return *this; +SimpleReadWriteLock& SimpleReadWriteLock::operator=( + const SimpleReadWriteLock& other) { + if (this == &other) return *this; - this->_write_locks.store(other._write_locks.load()); - this->_read_locks.store(other._read_locks.load()); + this->_write_locks.store(other._write_locks.load()); + this->_read_locks.store(other._read_locks.load()); - return *this; - } -} \ No newline at end of file + return *this; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/StringUtils.cpp b/libnd4j/include/helpers/impl/StringUtils.cpp index 5ac2fd8cc231..da60a46e2bab 100644 --- a/libnd4j/include/helpers/impl/StringUtils.cpp +++ b/libnd4j/include/helpers/impl/StringUtils.cpp @@ -20,139 +20,141 @@ // @author Oleg Semeniv // -#include #include +#include namespace sd { - static FORCEINLINE bool match(const uint8_t *haystack, const uint8_t *needle, uint64_t length) { - for (int e = 0; e < length; e++) - if (haystack[e] != needle[e]) - return false; - - return true; - } - - uint64_t StringUtils::countSubarrays(const void *vhaystack, uint64_t haystackLength, const void *vneedle, uint64_t needleLength) { - auto haystack = reinterpret_cast(vhaystack); - auto needle = reinterpret_cast(vneedle); - - uint64_t number = 0; - - for (uint64_t e = 0; e < haystackLength - needleLength; e++) { - if (match(&haystack[e], needle, needleLength)) - number++; - } - - return number; - } - - - uint64_t StringUtils::byteLength(const NDArray &array) { - if (!array.isS()) - throw sd::datatype_exception::build("StringUtils::byteLength expects one of String types;", array.dataType()); - - auto buffer = array.bufferAsT(); - return buffer[array.lengthOf()]; - } - - std::vector StringUtils::split(const std::string &haystack, const std::string &delimiter) { - std::vector output; - - std::string::size_type prev_pos = 0, pos = 0; - - // iterating through the haystack till the end - while((pos = haystack.find(delimiter, pos)) != std::string::npos) { - output.emplace_back(haystack.substr(prev_pos, pos-prev_pos)); - prev_pos = ++pos; - } - - output.emplace_back(haystack.substr(prev_pos, pos - prev_pos)); // Last word - - return output; - } - - bool StringUtils::u8StringToU16String(const std::string& u8, std::u16string& u16) { - - if (u8.empty()) - return false; - - u16.resize(unicode::offsetUtf8StringInUtf16(u8.data(), u8.size()) / sizeof(char16_t)); - if (u8.size() == u16.size()) - u16.assign(u8.begin(), u8.end()); - else - return unicode::utf8to16(u8.data(), &u16[0], u8.size()); - - return true; - } - - bool StringUtils::u8StringToU32String(const std::string& u8, std::u32string& u32) { - - if (u8.empty()) - return false; - - u32.resize( unicode::offsetUtf8StringInUtf32(u8.data(), u8.size()) / sizeof(char32_t) ); - if (u8.size() == u32.size()) - u32.assign(u8.begin(), u8.end()); - else - return unicode::utf8to32(u8.data(), &u32[0], u8.size()); - - return true; - } - - bool StringUtils::u16StringToU32String(const std::u16string& u16, std::u32string& u32) { - - if (u16.empty()) - return false; - - u32.resize(unicode::offsetUtf16StringInUtf32(u16.data(), u16.size()) / sizeof(char32_t)); - if (u16.size() == u32.size()) - u32.assign(u16.begin(), u16.end()); - else - return unicode::utf16to32(u16.data(), &u32[0], u16.size()); - - return true; - } - - bool StringUtils::u16StringToU8String(const std::u16string& u16, std::string& u8) { - - if (u16.empty()) - return false; - - u8.resize(unicode::offsetUtf16StringInUtf8(u16.data(), u16.size())); - if (u16.size() == u8.size()) - u8.assign(u16.begin(), u16.end()); - else - return unicode::utf16to8(u16.data(), &u8[0], u16.size()); - - return true; - } - - bool StringUtils::u32StringToU16String(const std::u32string& u32, std::u16string& u16) { - - if (u32.empty()) - return false; - - u16.resize(unicode::offsetUtf32StringInUtf16(u32.data(), u32.size()) / sizeof(char16_t)); - if (u32.size() == u16.size()) - u16.assign(u32.begin(), u32.end()); - else - return unicode::utf32to16(u32.data(), &u16[0], u32.size()); - - return true; - } - - bool StringUtils::u32StringToU8String(const std::u32string& u32, std::string& u8) { - - if (u32.empty()) - return false; - - u8.resize(unicode::offsetUtf32StringInUtf8(u32.data(), u32.size())); - if (u32.size() == u8.size()) - u8.assign(u32.begin(), u32.end()); - else - return unicode::utf32to8(u32.data(), &u8[0], u32.size()); - - return true; - } +static FORCEINLINE bool match(const uint8_t* haystack, const uint8_t* needle, + uint64_t length) { + for (int e = 0; e < length; e++) + if (haystack[e] != needle[e]) return false; + return true; } + +uint64_t StringUtils::countSubarrays(const void* vhaystack, + uint64_t haystackLength, + const void* vneedle, + uint64_t needleLength) { + auto haystack = reinterpret_cast(vhaystack); + auto needle = reinterpret_cast(vneedle); + + uint64_t number = 0; + + for (uint64_t e = 0; e < haystackLength - needleLength; e++) { + if (match(&haystack[e], needle, needleLength)) number++; + } + + return number; +} + +uint64_t StringUtils::byteLength(const NDArray& array) { + if (!array.isS()) + throw sd::datatype_exception::build( + "StringUtils::byteLength expects one of String types;", + array.dataType()); + + auto buffer = array.bufferAsT(); + return buffer[array.lengthOf()]; +} + +std::vector StringUtils::split(const std::string& haystack, + const std::string& delimiter) { + std::vector output; + + std::string::size_type prev_pos = 0, pos = 0; + + // iterating through the haystack till the end + while ((pos = haystack.find(delimiter, pos)) != std::string::npos) { + output.emplace_back(haystack.substr(prev_pos, pos - prev_pos)); + prev_pos = ++pos; + } + + output.emplace_back(haystack.substr(prev_pos, pos - prev_pos)); // Last word + + return output; +} + +bool StringUtils::u8StringToU16String(const std::string& u8, + std::u16string& u16) { + if (u8.empty()) return false; + + u16.resize(unicode::offsetUtf8StringInUtf16(u8.data(), u8.size()) / + sizeof(char16_t)); + if (u8.size() == u16.size()) + u16.assign(u8.begin(), u8.end()); + else + return unicode::utf8to16(u8.data(), &u16[0], u8.size()); + + return true; +} + +bool StringUtils::u8StringToU32String(const std::string& u8, + std::u32string& u32) { + if (u8.empty()) return false; + + u32.resize(unicode::offsetUtf8StringInUtf32(u8.data(), u8.size()) / + sizeof(char32_t)); + if (u8.size() == u32.size()) + u32.assign(u8.begin(), u8.end()); + else + return unicode::utf8to32(u8.data(), &u32[0], u8.size()); + + return true; +} + +bool StringUtils::u16StringToU32String(const std::u16string& u16, + std::u32string& u32) { + if (u16.empty()) return false; + + u32.resize(unicode::offsetUtf16StringInUtf32(u16.data(), u16.size()) / + sizeof(char32_t)); + if (u16.size() == u32.size()) + u32.assign(u16.begin(), u16.end()); + else + return unicode::utf16to32(u16.data(), &u32[0], u16.size()); + + return true; +} + +bool StringUtils::u16StringToU8String(const std::u16string& u16, + std::string& u8) { + if (u16.empty()) return false; + + u8.resize(unicode::offsetUtf16StringInUtf8(u16.data(), u16.size())); + if (u16.size() == u8.size()) + u8.assign(u16.begin(), u16.end()); + else + return unicode::utf16to8(u16.data(), &u8[0], u16.size()); + + return true; +} + +bool StringUtils::u32StringToU16String(const std::u32string& u32, + std::u16string& u16) { + if (u32.empty()) return false; + + u16.resize(unicode::offsetUtf32StringInUtf16(u32.data(), u32.size()) / + sizeof(char16_t)); + if (u32.size() == u16.size()) + u16.assign(u32.begin(), u32.end()); + else + return unicode::utf32to16(u32.data(), &u16[0], u32.size()); + + return true; +} + +bool StringUtils::u32StringToU8String(const std::u32string& u32, + std::string& u8) { + if (u32.empty()) return false; + + u8.resize(unicode::offsetUtf32StringInUtf8(u32.data(), u32.size())); + if (u32.size() == u8.size()) + u8.assign(u32.begin(), u32.end()); + else + return unicode::utf32to8(u32.data(), &u8[0], u32.size()); + + return true; +} + +} // namespace sd diff --git a/libnd4j/include/helpers/impl/TAD.cpp b/libnd4j/include/helpers/impl/TAD.cpp index 5d31827da27d..79261707954a 100644 --- a/libnd4j/include/helpers/impl/TAD.cpp +++ b/libnd4j/include/helpers/impl/TAD.cpp @@ -18,10 +18,7 @@ // @author Adam Gibson // - #include #include -namespace shape { - -} \ No newline at end of file +namespace shape {} \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/helper_hash.cpp b/libnd4j/include/helpers/impl/helper_hash.cpp index 0f45c8cb3616..73a4d60fecd5 100644 --- a/libnd4j/include/helpers/impl/helper_hash.cpp +++ b/libnd4j/include/helpers/impl/helper_hash.cpp @@ -22,50 +22,47 @@ #include namespace sd { - namespace ops { +namespace ops { - HashHelper* HashHelper::getInstance() { - if (_INSTANCE == 0) - _INSTANCE = new HashHelper(); +HashHelper* HashHelper::getInstance() { + if (_INSTANCE == 0) _INSTANCE = new HashHelper(); - return _INSTANCE; - } - - Nd4jLong HashHelper::getLongHash(const std::string& str) { - _locker.lock(); - if (!_isInit) { - nd4j_verbose("Building HashUtil table\n",""); - - Nd4jLong h = 0x544B2FBACAAF1684L; - for (int i = 0; i < 256; i++) { - for (int j = 0; j < 31; j++) { - h = (((unsigned long long) h) >> 7) ^ h; - h = (h << 11) ^ h; - h = (((unsigned long long) h) >> 10) ^ h; - } - _byteTable[i] = h; - } + return _INSTANCE; +} +Nd4jLong HashHelper::getLongHash(const std::string& str) { + _locker.lock(); + if (!_isInit) { + nd4j_verbose("Building HashUtil table\n", ""); - _isInit = true; - } + Nd4jLong h = 0x544B2FBACAAF1684L; + for (int i = 0; i < 256; i++) { + for (int j = 0; j < 31; j++) { + h = (((unsigned long long)h) >> 7) ^ h; + h = (h << 11) ^ h; + h = (((unsigned long long)h) >> 10) ^ h; + } + _byteTable[i] = h; + } - _locker.unlock(); + _isInit = true; + } - Nd4jLong h = HSTART; - Nd4jLong hmult = HMULT; - Nd4jLong len = str.size(); - for (int i = 0; i < len; i++) { - char ch = str.at(i); - auto uch = (unsigned char) ch; - h = (h * hmult) ^ _byteTable[ch & 0xff]; - h = (h * hmult) ^ _byteTable[(uch >> 8) & 0xff]; - } + _locker.unlock(); - return h; - } + Nd4jLong h = HSTART; + Nd4jLong hmult = HMULT; + Nd4jLong len = str.size(); + for (int i = 0; i < len; i++) { + char ch = str.at(i); + auto uch = (unsigned char)ch; + h = (h * hmult) ^ _byteTable[ch & 0xff]; + h = (h * hmult) ^ _byteTable[(uch >> 8) & 0xff]; + } - sd::ops::HashHelper* sd::ops::HashHelper::_INSTANCE = 0; - } + return h; } +sd::ops::HashHelper* sd::ops::HashHelper::_INSTANCE = 0; +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/helpers/impl/logger.cpp b/libnd4j/include/helpers/impl/logger.cpp index 59d8f98bc1de..07110a2b5957 100644 --- a/libnd4j/include/helpers/impl/logger.cpp +++ b/libnd4j/include/helpers/impl/logger.cpp @@ -22,48 +22,48 @@ namespace sd { - #ifdef __CUDACC__ - __host__ +__host__ #endif - void Logger::info(const char *format, ...) { - va_list args; - va_start(args, format); + void + Logger::info(const char *format, ...) { + va_list args; + va_start(args, format); - vprintf(format, args); + vprintf(format, args); - va_end(args); + va_end(args); - fflush(stdout); - } + fflush(stdout); +} #ifdef __CUDACC__ - __host__ +__host__ #endif - void Logger::printv(const char *format, const std::vector& vec) { - printf("%s: {", format); - for(int e = 0; e < vec.size(); e++) { - auto v = vec[e]; - printf("%i", v); - if (e < vec.size() - 1) - printf(", "); - } - printf("}\n"); - fflush(stdout); - } + void + Logger::printv(const char *format, const std::vector &vec) { + printf("%s: {", format); + for (int e = 0; e < vec.size(); e++) { + auto v = vec[e]; + printf("%i", v); + if (e < vec.size() - 1) printf(", "); + } + printf("}\n"); + fflush(stdout); +} - #ifdef __CUDACC__ - __host__ +#ifdef __CUDACC__ +__host__ #endif - void Logger::printv(const char *format, const std::vector& vec) { - printf("%s: {", format); - for(int e = 0; e < vec.size(); e++) { - auto v = vec[e]; - printf("%lld", (long long) v); - if (e < vec.size() - 1) - printf(", "); - } - printf("}\n"); - fflush(stdout); - } -} \ No newline at end of file + void + Logger::printv(const char *format, const std::vector &vec) { + printf("%s: {", format); + for (int e = 0; e < vec.size(); e++) { + auto v = vec[e]; + printf("%lld", (long long)v); + if (e < vec.size() - 1) printf(", "); + } + printf("}\n"); + fflush(stdout); +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/shape.cpp b/libnd4j/include/helpers/impl/shape.cpp index 33042e0b76f2..680d775ad683 100644 --- a/libnd4j/include/helpers/impl/shape.cpp +++ b/libnd4j/include/helpers/impl/shape.cpp @@ -20,7 +20,4 @@ #include -namespace shape { - -} - +namespace shape {} diff --git a/libnd4j/include/helpers/impl/unicode.cpp b/libnd4j/include/helpers/impl/unicode.cpp index 6ebbe7c1b478..0c8593ef8f6b 100644 --- a/libnd4j/include/helpers/impl/unicode.cpp +++ b/libnd4j/include/helpers/impl/unicode.cpp @@ -23,434 +23,420 @@ namespace sd { namespace unicode { - constexpr uint32_t ONEBYTEBOUND = 0x00000080; - constexpr uint32_t TWOBYTEBOUND = 0x00000800; - constexpr uint32_t THREEBYTEBOUND = 0x00010000; - constexpr uint16_t HIGHBYTEMIN = 0xd800u; - constexpr uint16_t HIGHBYTEMAX = 0xdbffu; - constexpr uint16_t TRAILBYTEMIN = 0xdc00u; - constexpr uint16_t TRAILBYTEMAX = 0xdfffu; - constexpr uint16_t HIGHBYTEOFFSET = HIGHBYTEMIN - (0x10000 >> 10); - constexpr uint32_t BYTEOFFSET = 0x10000u - (HIGHBYTEMIN << 10) - TRAILBYTEMIN; - // Maximum valid value for a Unicode code point - constexpr uint32_t CODEPOINTMAX = 0x0010ffffu; - - template - FORCEINLINE uint8_t castToU8(const T cp) { - return static_cast(0xff & cp); - } +constexpr uint32_t ONEBYTEBOUND = 0x00000080; +constexpr uint32_t TWOBYTEBOUND = 0x00000800; +constexpr uint32_t THREEBYTEBOUND = 0x00010000; +constexpr uint16_t HIGHBYTEMIN = 0xd800u; +constexpr uint16_t HIGHBYTEMAX = 0xdbffu; +constexpr uint16_t TRAILBYTEMIN = 0xdc00u; +constexpr uint16_t TRAILBYTEMAX = 0xdfffu; +constexpr uint16_t HIGHBYTEOFFSET = HIGHBYTEMIN - (0x10000 >> 10); +constexpr uint32_t BYTEOFFSET = 0x10000u - (HIGHBYTEMIN << 10) - TRAILBYTEMIN; +// Maximum valid value for a Unicode code point +constexpr uint32_t CODEPOINTMAX = 0x0010ffffu; + +template +FORCEINLINE uint8_t castToU8(const T cp) { + return static_cast(0xff & cp); +} - template - FORCEINLINE uint16_t castToU16(const T cp) { - return static_cast(0xffff & cp); - } +template +FORCEINLINE uint16_t castToU16(const T cp) { + return static_cast(0xffff & cp); +} - template - FORCEINLINE uint32_t castToU32(const T cp) { - return static_cast(0xffffff & cp); - } +template +FORCEINLINE uint32_t castToU32(const T cp) { + return static_cast(0xffffff & cp); +} - template - FORCEINLINE bool isTrail(const T cp) { - return ((castToU8(cp) >> 6) == 0x2); - } +template +FORCEINLINE bool isTrail(const T cp) { + return ((castToU8(cp) >> 6) == 0x2); +} - template - FORCEINLINE bool isHighSurrogate(const T cp) { - return (cp & 0xfffffc00) == 0xd800; - } +template +FORCEINLINE bool isHighSurrogate(const T cp) { + return (cp & 0xfffffc00) == 0xd800; +} - template - bool isLowSurrogate(const T cp) { - return (cp & 0xfffffc00) == 0xdc00; - } +template +bool isLowSurrogate(const T cp) { + return (cp & 0xfffffc00) == 0xdc00; +} - template - FORCEINLINE bool isLeadSurrogate(const T cp) { - return (cp >= HIGHBYTEMIN && cp <= HIGHBYTEMAX); - } +template +FORCEINLINE bool isLeadSurrogate(const T cp) { + return (cp >= HIGHBYTEMIN && cp <= HIGHBYTEMAX); +} - template - FORCEINLINE bool isTrailSurrogate(const T cp) { - return (cp >= TRAILBYTEMIN && cp <= TRAILBYTEMAX); - } +template +FORCEINLINE bool isTrailSurrogate(const T cp) { + return (cp >= TRAILBYTEMIN && cp <= TRAILBYTEMAX); +} - template - FORCEINLINE bool isSurrogateU8(const T cp) { - return (cp >= HIGHBYTEMIN && cp <= TRAILBYTEMAX); - } +template +FORCEINLINE bool isSurrogateU8(const T cp) { + return (cp >= HIGHBYTEMIN && cp <= TRAILBYTEMAX); +} - template - FORCEINLINE bool isSurrogateU16(const T cp) { - return ((cp - 0xd800u) < 2048u); - } +template +FORCEINLINE bool isSurrogateU16(const T cp) { + return ((cp - 0xd800u) < 2048u); +} - template - FORCEINLINE bool isSymbolU8Valid(const T cp) { - return (cp <= CODEPOINTMAX && !isSurrogateU8(cp)); - } +template +FORCEINLINE bool isSymbolU8Valid(const T cp) { + return (cp <= CODEPOINTMAX && !isSurrogateU8(cp)); +} - template - FORCEINLINE bool isSymbolValid(const T cp) { - return (cp <= CODEPOINTMAX); - } +template +FORCEINLINE bool isSymbolValid(const T cp) { + return (cp <= CODEPOINTMAX); +} - template - FORCEINLINE uint32_t surrogateU32(const T& high, const T& low) { - return (high << 10) + low - 0x35fdc00; - } +template +FORCEINLINE uint32_t surrogateU32(const T& high, const T& low) { + return (high << 10) + low - 0x35fdc00; +} - template - Nd4jLong symbolLength(const T* it) { - uint8_t lead = castToU8(*it); - if (lead < 0x80) - return 1; - else if ((lead >> 5) == 0x6) - return 2; - else if ((lead >> 4) == 0xe) - return 3; - else if ((lead >> 3) == 0x1e) - return 4; - else - return 0; - } +template +Nd4jLong symbolLength(const T* it) { + uint8_t lead = castToU8(*it); + if (lead < 0x80) + return 1; + else if ((lead >> 5) == 0x6) + return 2; + else if ((lead >> 4) == 0xe) + return 3; + else if ((lead >> 3) == 0x1e) + return 4; + else + return 0; +} - template - Nd4jLong symbolLength32(const T* it) { - auto lead = castToU32(*it); - if (lead < ONEBYTEBOUND) - return 1; - else if (lead < TWOBYTEBOUND) - return 2; - else if (lead < THREEBYTEBOUND) - return 3; - else if (lead <= CODEPOINTMAX) - return 4; - else - return 0; - } +template +Nd4jLong symbolLength32(const T* it) { + auto lead = castToU32(*it); + if (lead < ONEBYTEBOUND) + return 1; + else if (lead < TWOBYTEBOUND) + return 2; + else if (lead < THREEBYTEBOUND) + return 3; + else if (lead <= CODEPOINTMAX) + return 4; + else + return 0; +} - template - Nd4jLong symbolLength16(const T* it) { - - uint32_t lead = castToU16(*it); - if (!isLeadSurrogate(lead)) { - if (lead < ONEBYTEBOUND) - return 1; - else if (lead < TWOBYTEBOUND) - return 2; - else if (lead < THREEBYTEBOUND) - return 3; - else - return 0; - } - else { - return 4; - } - } +template +Nd4jLong symbolLength16(const T* it) { + uint32_t lead = castToU16(*it); + if (!isLeadSurrogate(lead)) { + if (lead < ONEBYTEBOUND) + return 1; + else if (lead < TWOBYTEBOUND) + return 2; + else if (lead < THREEBYTEBOUND) + return 3; + else + return 0; + } else { + return 4; + } +} + +Nd4jLong offsetUtf8StringInUtf32(const void* start, const void* end) { + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end; it++) { + auto length = symbolLength(it); + it += (length > 0) ? (length - 1) : 0; + count += 1; + } + return static_cast(count * sizeof(char32_t)); +} + +Nd4jLong offsetUtf16StringInUtf32(const void* start, const void* end) { + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end;) { + auto length = symbolLength16(it); + it += (4 == length) ? 2 : 1; + count += 1; + } + return static_cast(count * sizeof(char32_t)); +} + +Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end) { + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end; it++) { + auto length = symbolLength(it); + auto step = ((length > 0) ? (length - 1) : 0); + it += step; + count += (4 == length) ? 2 : 1; + } + return static_cast(count * sizeof(char16_t)); +} + +Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end) { + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end;) { + auto length = symbolLength16(it); + it += (4 == length) ? 2 : 1; + count += length; + } + return static_cast(count); +} + +Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end) { + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end; it++) { + auto length = symbolLength32(it); + count += (4 == length) ? 2 : 1; + ; + } + return static_cast(count * sizeof(char16_t)); +} - Nd4jLong offsetUtf8StringInUtf32(const void* start, const void* end) { - - Nd4jLong count = 0; - for (auto it = static_cast(start); it != end; it++) { - auto length = symbolLength(it); - it += (length > 0) ? (length - 1) : 0; - count += 1; - } - return static_cast(count * sizeof(char32_t)); +Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end) { + Nd4jLong count = 0; + for (auto it = static_cast(start); it != end; it++) { + count += symbolLength32(it); + } + return count; +} + +bool isStringValidU8(const void* start, const void* stop) { + for (auto it = static_cast(start); it != stop; it++) { + if (!isSymbolU8Valid(castToU8(*it))) { + return false; } - - Nd4jLong offsetUtf16StringInUtf32(const void* start, const void* end) { - - Nd4jLong count = 0; - for (auto it = static_cast(start); it != end;) { - auto length = symbolLength16(it); - it += (4 == length) ? 2 : 1; - count += 1; - } - return static_cast(count*sizeof(char32_t)); + } + return true; +} + +bool isStringValidU16(const void* start, const void* stop) { + for (auto it = static_cast(start); it != stop; it++) { + if (!isSymbolValid(castToU32(*it))) { + return false; } - - Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end) { - - Nd4jLong count = 0; - for (auto it = static_cast(start); it != end; it++) { - auto length = symbolLength(it); - auto step = ((length > 0) ? (length - 1) : 0); - it += step; - count += (4 == length) ? 2 : 1; - } - return static_cast(count*sizeof(char16_t)); + } + return true; +} + +bool isStringValidU32(const void* start, const void* stop) { + for (auto it = static_cast(start); it != stop; it++) { + if (!isSymbolValid(castToU32(*it))) { + return false; } - - Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end) { - - Nd4jLong count = 0; - for (auto it = static_cast(start); it != end;) { - auto length = symbolLength16(it); - it += (4 == length) ? 2 : 1; - count += length; - } - return static_cast(count); + } + return true; +} + +void* utf16to8Ptr(const void* start, const void* end, void* res) { + auto result = static_cast(res); + // result have to be pre-allocated + for (auto it = static_cast(start); it != end;) { + uint32_t cp = castToU16(*it++); + if (!isLeadSurrogate(cp)) { + if (cp < 0x80) { // for one byte + *(result++) = static_cast(cp); + } else if (cp < 0x800) { // for two bytes + *(result++) = static_cast((cp >> 6) | 0xc0); + *(result++) = static_cast((cp & 0x3f) | 0x80); + } else { // for three bytes + *(result++) = static_cast((cp >> 12) | 0xe0); + *(result++) = static_cast(((cp >> 6) & 0x3f) | 0x80); + *(result++) = static_cast((cp & 0x3f) | 0x80); + } + } else { + if (it != end) { + uint32_t trail_surrogate = castToU16(*it++); + if (isTrailSurrogate(trail_surrogate)) + cp = (cp << 10) + trail_surrogate + BYTEOFFSET; + } + // for four bytes + *(result++) = static_cast((cp >> 18) | 0xf0); + *(result++) = static_cast(((cp >> 12) & 0x3f) | 0x80); + *(result++) = static_cast(((cp >> 6) & 0x3f) | 0x80); + *(result++) = static_cast((cp & 0x3f) | 0x80); } - - Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end) { - - Nd4jLong count = 0; - for (auto it = static_cast(start); it != end; it++) { - auto length = symbolLength32(it); - count += (4 == length) ? 2 : 1;; - } - return static_cast(count*sizeof(char16_t)); + } + return result; +} + +void* utf8to16Ptr(const void* start, const void* end, void* res) { + auto result = static_cast(res); + // result have to be pre-allocated + for (auto it = static_cast(start); it != end;) { + auto nLength = symbolLength(it); + uint32_t cp = castToU8(*it++); + if (4 != nLength) { + if (2 == nLength) { + cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f); + } else if (3 == nLength) { + cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff); + cp += (*it++) & 0x3f; + } + *(result++) = static_cast(cp); + } else { + cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff); + cp += (castToU8(*it++) << 6) & 0xfff; + cp += (*it++) & 0x3f; + // make a surrogate pair + *(result++) = static_cast((cp >> 10) + HIGHBYTEOFFSET); + *(result++) = static_cast((cp & 0x3ff) + TRAILBYTEMIN); } - - Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end) { - - Nd4jLong count = 0; - for (auto it = static_cast(start); it != end; it++) { - count += symbolLength32(it); - } - return count; + } + return result; +} + +void* utf32to8Ptr(const void* start, const void* end, void* result) { + auto res = static_cast(result); + // result have to be pre-allocated + for (auto it = static_cast(start); it != end; it++) { + if (*it < 0x80) // for one byte + *(res++) = static_cast(*it); + else if (*it < 0x800) { // for two bytes + *(res++) = static_cast((*it >> 6) | 0xc0); + *(res++) = static_cast((*it & 0x3f) | 0x80); + } else if (*it < 0x10000) { // for three bytes + *(res++) = static_cast((*it >> 12) | 0xe0); + *(res++) = static_cast(((*it >> 6) & 0x3f) | 0x80); + *(res++) = static_cast((*it & 0x3f) | 0x80); + } else { // for four bytes + *(res++) = static_cast((*it >> 18) | 0xf0); + *(res++) = static_cast(((*it >> 12) & 0x3f) | 0x80); + *(res++) = static_cast(((*it >> 6) & 0x3f) | 0x80); + *(res++) = static_cast((*it & 0x3f) | 0x80); } - - bool isStringValidU8(const void* start, const void* stop) { - for (auto it = static_cast(start); it != stop; it++) { - if (!isSymbolU8Valid( castToU8(*it) )) { - return false; - } - } - return true; + } + return result; +} + +void* utf8to32Ptr(const void* start, const void* end, void* res) { + auto result = static_cast(res); + // result have to be pre-allocated + for (auto it = static_cast(start); it != end;) { + auto nLength = symbolLength(it); + uint32_t cp = castToU8(*it++); + if (2 == nLength) { + cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f); + } else if (3 == nLength) { + cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff); + cp += (*it++) & 0x3f; + } else if (4 == nLength) { + cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff); + cp += (castToU8(*it++) << 6) & 0xfff; + cp += (*it++) & 0x3f; } - - bool isStringValidU16(const void* start, const void* stop) { - for (auto it = static_cast(start); it != stop; it++) { - if (!isSymbolValid( castToU32(*it) )) { - return false; - } - } - return true; + (*result++) = cp; + } + return result; +} + +void* utf16to32Ptr(const void* start, const void* end, void* res) { + auto result = static_cast(res); + // result have to be pre-allocated + for (auto it = static_cast(start); it != end; it++) { + uint32_t cpHigh = castToU32(*it); + if (!isSurrogateU16(cpHigh)) { + *result++ = cpHigh; + } else { + it++; + uint32_t cpLow = castToU32(*it); + if (isHighSurrogate(cpHigh) && it != end && isLowSurrogate(cpLow)) { + *result++ = surrogateU32(cpHigh, cpLow); + } } - - bool isStringValidU32(const void* start, const void* stop) { - for (auto it = static_cast(start); it != stop; it++) { - if (!isSymbolValid( castToU32(*it) )) { - return false; - } - } - return true; + } + return result; +} + +void* utf32to16Ptr(const void* start, const void* end, void* res) { + auto result = static_cast(res); + // result have to be pre-allocate + for (auto it = static_cast(start); it != end; it++) { + uint32_t cpHigh = castToU32(*it); + // todo check do we need this as we have pre-validation, if yes find out how + // to check u16 + if (cpHigh < 0 || cpHigh > 0x10FFFF || + (cpHigh >= 0xD800 && cpHigh <= 0xDFFF)) { + // Invalid code point. Replace with sentinel, per Unicode standard: + *result++ = u'\uFFFD'; + } else if (cpHigh < 0x10000UL) { // In the BMP. + *result++ = static_cast(cpHigh); + } else { + *result++ = + static_cast(((cpHigh - 0x10000UL) / 0x400U) + 0xD800U); + *result++ = + static_cast(((cpHigh - 0x10000UL) % 0x400U) + 0xDC00U); } + } + return result; +} - void* utf16to8Ptr(const void* start, const void* end, void* res) { - - auto result = static_cast(res); - // result have to be pre-allocated - for (auto it = static_cast(start); it != end;) { - uint32_t cp = castToU16(*it++); - if (!isLeadSurrogate(cp)) { - if (cp < 0x80) { // for one byte - *(result++) = static_cast(cp); - } - else if (cp < 0x800) { // for two bytes - *(result++) = static_cast((cp >> 6) | 0xc0); - *(result++) = static_cast((cp & 0x3f) | 0x80); - } - else{ // for three bytes - *(result++) = static_cast((cp >> 12) | 0xe0); - *(result++) = static_cast(((cp >> 6) & 0x3f) | 0x80); - *(result++) = static_cast((cp & 0x3f) | 0x80); - } - } - else { - if (it != end) { - uint32_t trail_surrogate = castToU16(*it++); - if (isTrailSurrogate(trail_surrogate)) - cp = (cp << 10) + trail_surrogate + BYTEOFFSET; - } - // for four bytes - *(result++) = static_cast((cp >> 18) | 0xf0); - *(result++) = static_cast(((cp >> 12) & 0x3f) | 0x80); - *(result++) = static_cast(((cp >> 6) & 0x3f) | 0x80); - *(result++) = static_cast((cp & 0x3f) | 0x80); - } - } - return result; - } - - void* utf8to16Ptr(const void* start, const void* end, void* res) { - - auto result = static_cast(res); - // result have to be pre-allocated - for (auto it = static_cast(start); it != end;) { - - auto nLength = symbolLength(it); - uint32_t cp = castToU8(*it++); - if (4 != nLength) { - if (2 == nLength) { - cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f); - } - else if (3 == nLength) { - cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff); - cp += (*it++) & 0x3f; - } - *(result++) = static_cast(cp); - } - else { - cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff); - cp += (castToU8(*it++) << 6) & 0xfff; - cp += (*it++) & 0x3f; - //make a surrogate pair - *(result++) = static_cast((cp >> 10) + HIGHBYTEOFFSET); - *(result++) = static_cast((cp & 0x3ff) + TRAILBYTEMIN); - } - } - return result; - } - - void* utf32to8Ptr( const void* start, const void* end, void* result) { - - auto res = static_cast(result); - // result have to be pre-allocated - for (auto it = static_cast(start); it != end; it++) { - - if (*it < 0x80) // for one byte - *(res++) = static_cast(*it); - else if (*it < 0x800) { // for two bytes - *(res++) = static_cast((*it >> 6) | 0xc0); - *(res++) = static_cast((*it & 0x3f) | 0x80); - } - else if (*it < 0x10000) { // for three bytes - *(res++) = static_cast((*it >> 12) | 0xe0); - *(res++) = static_cast(((*it >> 6) & 0x3f) | 0x80); - *(res++) = static_cast((*it & 0x3f) | 0x80); - } - else { // for four bytes - *(res++) = static_cast((*it >> 18) | 0xf0); - *(res++) = static_cast(((*it >> 12) & 0x3f) | 0x80); - *(res++) = static_cast(((*it >> 6) & 0x3f) | 0x80); - *(res++) = static_cast((*it & 0x3f) | 0x80); - } - } - return result; - } - - void* utf8to32Ptr(const void* start, const void* end, void* res) { - - auto result = static_cast(res); - // result have to be pre-allocated - for (auto it = static_cast(start); it != end;) { - - auto nLength = symbolLength(it); - uint32_t cp = castToU8(*it++); - if (2 == nLength) { - cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f); - } - else if (3 == nLength) { - cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff); - cp += (*it++) & 0x3f; - } - else if (4 == nLength) { - cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff); - cp += (castToU8(*it++) << 6) & 0xfff; - cp += (*it++) & 0x3f; - } - (*result++) = cp; - } - return result; - } - - void* utf16to32Ptr(const void* start, const void* end, void* res) { - - auto result = static_cast(res); - // result have to be pre-allocated - for (auto it = static_cast(start); it != end; it++) { - - uint32_t cpHigh = castToU32(*it); - if (!isSurrogateU16(cpHigh)) { - *result++ = cpHigh; - } - else { - it++; - uint32_t cpLow = castToU32(*it); - if (isHighSurrogate(cpHigh) && it != end && isLowSurrogate(cpLow)) { - *result++ = surrogateU32(cpHigh, cpLow); - } - } - } - return result; - } - - void* utf32to16Ptr(const void* start, const void* end, void* res) { - - auto result = static_cast(res); - // result have to be pre-allocate - for (auto it = static_cast(start); it != end; it++) { - - uint32_t cpHigh = castToU32(*it); - // todo check do we need this as we have pre-validation, if yes find out how to check u16 - if (cpHigh < 0 || cpHigh > 0x10FFFF || (cpHigh >= 0xD800 && cpHigh <= 0xDFFF)) { - // Invalid code point. Replace with sentinel, per Unicode standard: - *result++ = u'\uFFFD'; - } - else if (cpHigh < 0x10000UL) { // In the BMP. - *result++ = static_cast(cpHigh); - } - else { - *result++ = static_cast(((cpHigh - 0x10000UL) / 0x400U) + 0xD800U); - *result++ = static_cast(((cpHigh - 0x10000UL) % 0x400U) + 0xDC00U); - } - } - return result; - } - - Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize) { - return offsetUtf8StringInUtf32(input, static_cast(input) + nInputSize); - } - - Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize) { - return offsetUtf16StringInUtf32(input, static_cast(input) + nInputSize); - } - - Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize) { - return offsetUtf8StringInUtf16(input, static_cast(input) + nInputSize); - } - - Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize) { - return offsetUtf16StringInUtf8(input, static_cast(input) + nInputSize); - } - - Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize) { - return offsetUtf32StringInUtf8(input, static_cast(input) + nInputSize); - } - - Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize) { - return offsetUtf32StringInUtf16(input, static_cast(input) + nInputSize); - } - - bool utf8to16(const void* input, void* output, uint32_t nInputSize) { - return utf8to16Ptr(input, static_cast(input) + nInputSize, output); - } - - bool utf8to32(const void* input, void* output, uint32_t nInputSize) { - return utf8to32Ptr(input, static_cast(input) + nInputSize, output); - } - - bool utf16to32(const void* input, void* output, uint32_t nInputSize) { - return utf16to32Ptr(input, static_cast(input) + nInputSize, output); - } - - bool utf16to8(const void* input, void* output, uint32_t nInputSize) { - return utf16to8Ptr(input, static_cast(input) + nInputSize, output); - } - - bool utf32to16(const void* input, void* output, uint32_t nInputSize) { - return utf32to16Ptr(input, static_cast(input) + nInputSize, output); - } - - bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize) { - return utf32to8Ptr(input, static_cast(input) + nInputSize, output); - } - - } +Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize) { + return offsetUtf8StringInUtf32( + input, static_cast(input) + nInputSize); +} +Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize) { + return offsetUtf16StringInUtf32( + input, static_cast(input) + nInputSize); } +Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize) { + return offsetUtf8StringInUtf16( + input, static_cast(input) + nInputSize); +} + +Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize) { + return offsetUtf16StringInUtf8( + input, static_cast(input) + nInputSize); +} + +Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize) { + return offsetUtf32StringInUtf8( + input, static_cast(input) + nInputSize); +} + +Nd4jLong offsetUtf32StringInUtf16(const void* input, + const uint32_t nInputSize) { + return offsetUtf32StringInUtf16( + input, static_cast(input) + nInputSize); +} + +bool utf8to16(const void* input, void* output, uint32_t nInputSize) { + return utf8to16Ptr(input, static_cast(input) + nInputSize, + output); +} + +bool utf8to32(const void* input, void* output, uint32_t nInputSize) { + return utf8to32Ptr(input, static_cast(input) + nInputSize, + output); +} + +bool utf16to32(const void* input, void* output, uint32_t nInputSize) { + return utf16to32Ptr(input, static_cast(input) + nInputSize, + output); +} + +bool utf16to8(const void* input, void* output, uint32_t nInputSize) { + return utf16to8Ptr(input, static_cast(input) + nInputSize, + output); +} + +bool utf32to16(const void* input, void* output, uint32_t nInputSize) { + return utf32to16Ptr(input, static_cast(input) + nInputSize, + output); +} + +bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize) { + return utf32to8Ptr(input, static_cast(input) + nInputSize, + output); +} + +} // namespace unicode + +} // namespace sd diff --git a/libnd4j/include/helpers/jacobiSVD.h b/libnd4j/include/helpers/jacobiSVD.h index f6f161bbb228..57f46dec2a55 100644 --- a/libnd4j/include/helpers/jacobiSVD.h +++ b/libnd4j/include/helpers/jacobiSVD.h @@ -21,8 +21,8 @@ #ifndef LIBND4J_JACOBISVD_H #define LIBND4J_JACOBISVD_H -#include #include +#include namespace sd { namespace ops { @@ -30,43 +30,43 @@ namespace helpers { template class JacobiSVD { + public: + NDArray _m; + NDArray _s; // vector with singular values + NDArray _u; + NDArray _v; - public: + int _diagSize; + int _rows; + int _cols; - NDArray _m; - NDArray _s; // vector with singular values - NDArray _u; - NDArray _v; - - int _diagSize; - int _rows; - int _cols; + // bool _transp; + bool _calcU; + bool _calcV; + bool _fullUV; - // bool _transp; - bool _calcU; - bool _calcV; - bool _fullUV; + JacobiSVD(const NDArray& matrix, const bool calcU, const bool calcV, + const bool fullUV); - JacobiSVD(const NDArray& matrix, const bool calcU, const bool calcV, const bool fullUV); + bool isBlock2x2NotDiag(NDArray& block, int p, int q, T& maxElem); - bool isBlock2x2NotDiag(NDArray& block, int p, int q, T& maxElem); + static bool createJacobiRotation(const T& x, const T& y, const T& z, + NDArray& rotation); - static bool createJacobiRotation(const T& x, const T& y, const T& z, NDArray& rotation); - - static void svd2x2(const NDArray& block, int p, int q, NDArray& left, NDArray& right); + static void svd2x2(const NDArray& block, int p, int q, NDArray& left, + NDArray& right); - static void mulRotationOnLeft(const int i, const int j, NDArray& block, const NDArray& rotation); + static void mulRotationOnLeft(const int i, const int j, NDArray& block, + const NDArray& rotation); - static void mulRotationOnRight(const int i, const int j, NDArray& block, const NDArray& rotation); + static void mulRotationOnRight(const int i, const int j, NDArray& block, + const NDArray& rotation); - void evalData(const NDArray& matrix); + void evalData(const NDArray& matrix); }; +} // namespace helpers +} // namespace ops +} // namespace sd - -} -} -} - - -#endif //LIBND4J_JACOBISVD_H +#endif // LIBND4J_JACOBISVD_H diff --git a/libnd4j/include/helpers/logger.h b/libnd4j/include/helpers/logger.h index 625a549318be..99e5cc1f3bd9 100644 --- a/libnd4j/include/helpers/logger.h +++ b/libnd4j/include/helpers/logger.h @@ -21,22 +21,31 @@ #ifndef LIBND4J_LOGGER_H #define LIBND4J_LOGGER_H -#include -#include -#include -#include #include +#include +#include #include #include #include +#include +#include + #ifndef __CUDA_ARCH__ -#define nd4j_debug(FORMAT, ...) if (sd::Environment::getInstance()->isDebug() && sd::Environment::getInstance()->isVerbose()) sd::Logger::info(FORMAT, __VA_ARGS__); -#define nd4j_logger(FORMAT, ...) if (sd::Environment::getInstance()->isDebug() && sd::Environment::getInstance()->isVerbose()) sd::Logger::info(FORMAT, __VA_ARGS__); -#define nd4j_verbose(FORMAT, ...) if (sd::Environment::getInstance()->isVerbose()) sd::Logger::info(FORMAT, __VA_ARGS__); +#define nd4j_debug(FORMAT, ...) \ + if (sd::Environment::getInstance()->isDebug() && \ + sd::Environment::getInstance()->isVerbose()) \ + sd::Logger::info(FORMAT, __VA_ARGS__); +#define nd4j_logger(FORMAT, ...) \ + if (sd::Environment::getInstance()->isDebug() && \ + sd::Environment::getInstance()->isVerbose()) \ + sd::Logger::info(FORMAT, __VA_ARGS__); +#define nd4j_verbose(FORMAT, ...) \ + if (sd::Environment::getInstance()->isVerbose()) \ + sd::Logger::info(FORMAT, __VA_ARGS__); #define nd4j_printf(FORMAT, ...) sd::Logger::info(FORMAT, __VA_ARGS__); -#define nd4j_printv(FORMAT, VECTOR) sd::Logger::printv(FORMAT, VECTOR); +#define nd4j_printv(FORMAT, VECTOR) sd::Logger::printv(FORMAT, VECTOR); #else @@ -49,17 +58,15 @@ #endif namespace sd { - class SD_EXPORT Logger { - - public: - - static void _CUDA_H info(const char *format, ...); - - static void _CUDA_H printv(const char *format, const std::vector& vec); - static void _CUDA_H printv(const char *format, const std::vector& vec); - }; +class SD_EXPORT Logger { + public: + static void _CUDA_H info(const char *format, ...); -} + static void _CUDA_H printv(const char *format, const std::vector &vec); + static void _CUDA_H printv(const char *format, + const std::vector &vec); +}; +} // namespace sd -#endif //LIBND4J_LOGGER_H +#endif // LIBND4J_LOGGER_H diff --git a/libnd4j/include/helpers/mman.h b/libnd4j/include/helpers/mman.h index 618ee23c3960..4a09ab1ddcb4 100644 --- a/libnd4j/include/helpers/mman.h +++ b/libnd4j/include/helpers/mman.h @@ -22,16 +22,18 @@ #ifndef _SYS_MMAN_H_ #define _SYS_MMAN_H_ -#include #include #include +#include #ifndef FILE_MAP_EXECUTE -#define FILE_MAP_EXECUTE 0x0020 +#define FILE_MAP_EXECUTE 0x0020 #endif /* FILE_MAP_EXECUTE */ -#ifndef _WIN32_WINNT // Allow use of features specific to Windows XP or later. -#define _WIN32_WINNT 0x0501 // Change this to the appropriate value to target other versions of Windows. +#ifndef _WIN32_WINNT // Allow use of features specific to Windows XP or later. +#define _WIN32_WINNT \ + 0x0501 // Change this to the appropriate value to target other versions of + // Windows. #endif /* All the headers include this file. */ @@ -53,269 +55,252 @@ typedef uint32_t OffsetType; extern "C" { #endif -#define PROT_NONE 0 -#define PROT_READ 1 -#define PROT_WRITE 2 -#define PROT_EXEC 4 +#define PROT_NONE 0 +#define PROT_READ 1 +#define PROT_WRITE 2 +#define PROT_EXEC 4 -#define MAP_FILE 0 -#define MAP_SHARED 1 -#define MAP_PRIVATE 2 -#define MAP_TYPE 0xf -#define MAP_FIXED 0x10 -#define MAP_ANONYMOUS 0x20 -#define MAP_ANON MAP_ANONYMOUS +#define MAP_FILE 0 +#define MAP_SHARED 1 +#define MAP_PRIVATE 2 +#define MAP_TYPE 0xf +#define MAP_FIXED 0x10 +#define MAP_ANONYMOUS 0x20 +#define MAP_ANON MAP_ANONYMOUS -#define MAP_FAILED ((void *)-1) +#define MAP_FAILED ((void *)-1) /* Flags for msync. */ -#define MS_ASYNC 1 -#define MS_SYNC 2 -#define MS_INVALIDATE 4 - -void _mmap(Nd4jLong* result, size_t length, const char *fileName); -void* mmap(void *addr, size_t len, int prot, int flags, int fildes, OffsetType off); -int munmap(void *addr, size_t len); -int _mprotect(void *addr, size_t len, int prot); -int msync(void *addr, size_t len, int flags); -int mlock(const void *addr, size_t len); -int munlock(const void *addr, size_t len); +#define MS_ASYNC 1 +#define MS_SYNC 2 +#define MS_INVALIDATE 4 + +void _mmap(Nd4jLong *result, size_t length, const char *fileName); +void *mmap(void *addr, size_t len, int prot, int flags, int fildes, + OffsetType off); +int munmap(void *addr, size_t len); +int _mprotect(void *addr, size_t len, int prot); +int msync(void *addr, size_t len, int flags); +int mlock(const void *addr, size_t len); +int munlock(const void *addr, size_t len); #ifdef __cplusplus } #endif -static int __map_mman_error(const DWORD err, const int deferr) -{ - if (err == 0) - return 0; - //TODO: implement - return err; +static int __map_mman_error(const DWORD err, const int deferr) { + if (err == 0) return 0; + // TODO: implement + return err; } -static DWORD __map_mmap_prot_page(const int prot) -{ - DWORD protect = 0; - - if (prot == PROT_NONE) - return protect; - - if ((prot & PROT_EXEC) != 0) - { - protect = ((prot & PROT_WRITE) != 0) ? - PAGE_EXECUTE_READWRITE : PAGE_EXECUTE_READ; - } - else - { - protect = ((prot & PROT_WRITE) != 0) ? - PAGE_READWRITE : PAGE_READONLY; - } - - return protect; +static DWORD __map_mmap_prot_page(const int prot) { + DWORD protect = 0; + + if (prot == PROT_NONE) return protect; + + if ((prot & PROT_EXEC) != 0) { + protect = + ((prot & PROT_WRITE) != 0) ? PAGE_EXECUTE_READWRITE : PAGE_EXECUTE_READ; + } else { + protect = ((prot & PROT_WRITE) != 0) ? PAGE_READWRITE : PAGE_READONLY; + } + + return protect; } -static DWORD __map_mmap_prot_file(const int prot) -{ - DWORD desiredAccess = 0; +static DWORD __map_mmap_prot_file(const int prot) { + DWORD desiredAccess = 0; - if (prot == PROT_NONE) - return desiredAccess; + if (prot == PROT_NONE) return desiredAccess; - if ((prot & PROT_READ) != 0) - desiredAccess |= FILE_MAP_READ; - if ((prot & PROT_WRITE) != 0) - desiredAccess |= FILE_MAP_WRITE; - if ((prot & PROT_EXEC) != 0) - desiredAccess |= FILE_MAP_EXECUTE; + if ((prot & PROT_READ) != 0) desiredAccess |= FILE_MAP_READ; + if ((prot & PROT_WRITE) != 0) desiredAccess |= FILE_MAP_WRITE; + if ((prot & PROT_EXEC) != 0) desiredAccess |= FILE_MAP_EXECUTE; - return desiredAccess; + return desiredAccess; } -void _mmap(Nd4jLong* result, size_t length, const char *fileName) { - HANDLE fm, h; +void _mmap(Nd4jLong *result, size_t length, const char *fileName) { + HANDLE fm, h; - void * map = MAP_FAILED; - OffsetType off = 0; - int prot = PROT_READ | PROT_WRITE; + void *map = MAP_FAILED; + OffsetType off = 0; + int prot = PROT_READ | PROT_WRITE; #ifdef _MSC_VER - #pragma warning(push) - #pragma warning(disable: 4293) +#pragma warning(push) +#pragma warning(disable : 4293) #endif - const DWORD dwFileOffsetLow = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)off : (DWORD)(off & 0xFFFFFFFFL); - const DWORD dwFileOffsetHigh = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)0 : (DWORD)((off >> 32) & 0xFFFFFFFFL); - const DWORD protect = __map_mmap_prot_page(prot); - const DWORD desiredAccess = __map_mmap_prot_file(prot); + const DWORD dwFileOffsetLow = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)off + : (DWORD)(off & 0xFFFFFFFFL); + const DWORD dwFileOffsetHigh = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)0 + : (DWORD)((off >> 32) & 0xFFFFFFFFL); + const DWORD protect = __map_mmap_prot_page(prot); + const DWORD desiredAccess = __map_mmap_prot_file(prot); - const OffsetType maxSize = off + (OffsetType) length; + const OffsetType maxSize = off + (OffsetType)length; - const DWORD dwMaxSizeLow = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)maxSize : (DWORD)(maxSize & 0xFFFFFFFFL); - const DWORD dwMaxSizeHigh = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)0 : (DWORD)((maxSize >> 32) & 0xFFFFFFFFL); + const DWORD dwMaxSizeLow = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)maxSize + : (DWORD)(maxSize & 0xFFFFFFFFL); + const DWORD dwMaxSizeHigh = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)0 + : (DWORD)((maxSize >> 32) & 0xFFFFFFFFL); #ifdef _MSC_VER - #pragma warning(pop) +#pragma warning(pop) #endif - h = CreateFileA(fileName, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_WRITE | FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr); + h = CreateFileA(fileName, GENERIC_READ | GENERIC_WRITE, + FILE_SHARE_WRITE | FILE_SHARE_READ, nullptr, OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, nullptr); - if (h == INVALID_HANDLE_VALUE) { - errno = __map_mman_error(GetLastError(), EPERM); - nd4j_printf("Error code: %i\n", (int) errno); - throw std::runtime_error("CreateFile failed"); - } + if (h == INVALID_HANDLE_VALUE) { + errno = __map_mman_error(GetLastError(), EPERM); + nd4j_printf("Error code: %i\n", (int)errno); + throw std::runtime_error("CreateFile failed"); + } - fm = CreateFileMapping(h, NULL, protect, dwMaxSizeHigh, dwMaxSizeLow, NULL); + fm = CreateFileMapping(h, NULL, protect, dwMaxSizeHigh, dwMaxSizeLow, NULL); - if (fm == NULL) - { - errno = __map_mman_error(GetLastError(), EPERM); - throw std::runtime_error("CreateFileMapping failed"); - } + if (fm == NULL) { + errno = __map_mman_error(GetLastError(), EPERM); + throw std::runtime_error("CreateFileMapping failed"); + } - map = MapViewOfFile(fm, desiredAccess, dwFileOffsetHigh, dwFileOffsetLow, length); + map = MapViewOfFile(fm, desiredAccess, dwFileOffsetHigh, dwFileOffsetLow, + length); - CloseHandle(fm); + CloseHandle(fm); - if (map == NULL) - { - errno = __map_mman_error(GetLastError(), EPERM); - throw std::runtime_error("MapViewOfFile failed"); - } + if (map == NULL) { + errno = __map_mman_error(GetLastError(), EPERM); + throw std::runtime_error("MapViewOfFile failed"); + } - result[0] = reinterpret_cast(map); - result[1] = reinterpret_cast(h); + result[0] = reinterpret_cast(map); + result[1] = reinterpret_cast(h); } -void* mmap(void *addr, size_t len, int prot, int flags, int files, OffsetType off) -{ - HANDLE fm, h; +void *mmap(void *addr, size_t len, int prot, int flags, int files, + OffsetType off) { + HANDLE fm, h; - void * map = MAP_FAILED; + void *map = MAP_FAILED; #ifdef _MSC_VER - #pragma warning(push) -#pragma warning(disable: 4293) +#pragma warning(push) +#pragma warning(disable : 4293) #endif - const DWORD dwFileOffsetLow = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)off : (DWORD)(off & 0xFFFFFFFFL); - const DWORD dwFileOffsetHigh = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)0 : (DWORD)((off >> 32) & 0xFFFFFFFFL); - const DWORD protect = __map_mmap_prot_page(prot); - const DWORD desiredAccess = __map_mmap_prot_file(prot); + const DWORD dwFileOffsetLow = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)off + : (DWORD)(off & 0xFFFFFFFFL); + const DWORD dwFileOffsetHigh = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)0 + : (DWORD)((off >> 32) & 0xFFFFFFFFL); + const DWORD protect = __map_mmap_prot_page(prot); + const DWORD desiredAccess = __map_mmap_prot_file(prot); - const OffsetType maxSize = off + (OffsetType)len; + const OffsetType maxSize = off + (OffsetType)len; - const DWORD dwMaxSizeLow = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)maxSize : (DWORD)(maxSize & 0xFFFFFFFFL); - const DWORD dwMaxSizeHigh = (sizeof(OffsetType) <= sizeof(DWORD)) ? - (DWORD)0 : (DWORD)((maxSize >> 32) & 0xFFFFFFFFL); + const DWORD dwMaxSizeLow = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)maxSize + : (DWORD)(maxSize & 0xFFFFFFFFL); + const DWORD dwMaxSizeHigh = (sizeof(OffsetType) <= sizeof(DWORD)) + ? (DWORD)0 + : (DWORD)((maxSize >> 32) & 0xFFFFFFFFL); #ifdef _MSC_VER #pragma warning(pop) #endif - errno = 0; + errno = 0; - if (len == 0 - /* Unsupported flag combinations */ - || (flags & MAP_FIXED) != 0 - /* Usupported protection combinations */ - || prot == PROT_EXEC) - { - errno = EINVAL; - return MAP_FAILED; - } + if (len == 0 + /* Unsupported flag combinations */ + || (flags & MAP_FIXED) != 0 + /* Usupported protection combinations */ + || prot == PROT_EXEC) { + errno = EINVAL; + return MAP_FAILED; + } - h = ((flags & MAP_ANONYMOUS) == 0) ? - (HANDLE)_get_osfhandle(files) : INVALID_HANDLE_VALUE; + h = ((flags & MAP_ANONYMOUS) == 0) ? (HANDLE)_get_osfhandle(files) + : INVALID_HANDLE_VALUE; + if ((flags & MAP_ANONYMOUS) == 0 && h == INVALID_HANDLE_VALUE) { + errno = EBADF; + return MAP_FAILED; + } - if ((flags & MAP_ANONYMOUS) == 0 && h == INVALID_HANDLE_VALUE) - { - errno = EBADF; - return MAP_FAILED; - } + fm = CreateFileMapping(h, NULL, protect, dwMaxSizeHigh, dwMaxSizeLow, NULL); - fm = CreateFileMapping(h, NULL, protect, dwMaxSizeHigh, dwMaxSizeLow, NULL); + if (fm == NULL) { + errno = __map_mman_error(GetLastError(), EPERM); + return MAP_FAILED; + } - if (fm == NULL) - { - errno = __map_mman_error(GetLastError(), EPERM); - return MAP_FAILED; - } + map = + MapViewOfFile(fm, desiredAccess, dwFileOffsetHigh, dwFileOffsetLow, len); - map = MapViewOfFile(fm, desiredAccess, dwFileOffsetHigh, dwFileOffsetLow, len); + CloseHandle(fm); - CloseHandle(fm); + if (map == NULL) { + errno = __map_mman_error(GetLastError(), EPERM); + return MAP_FAILED; + } - if (map == NULL) - { - errno = __map_mman_error(GetLastError(), EPERM); - return MAP_FAILED; - } - - return map; + return map; } -int munmap(void *addr, size_t len) -{ - if (UnmapViewOfFile(addr)) - return 0; +int munmap(void *addr, size_t len) { + if (UnmapViewOfFile(addr)) return 0; - errno = __map_mman_error(GetLastError(), EPERM); + errno = __map_mman_error(GetLastError(), EPERM); - return -1; + return -1; } -int _mprotect(void *addr, size_t len, int prot) -{ - DWORD newProtect = __map_mmap_prot_page(prot); - DWORD oldProtect = 0; +int _mprotect(void *addr, size_t len, int prot) { + DWORD newProtect = __map_mmap_prot_page(prot); + DWORD oldProtect = 0; - if (VirtualProtect(addr, len, newProtect, &oldProtect)) - return 0; + if (VirtualProtect(addr, len, newProtect, &oldProtect)) return 0; - errno = __map_mman_error(GetLastError(), EPERM); + errno = __map_mman_error(GetLastError(), EPERM); - return -1; + return -1; } int msync(void *addr, size_t len, int flags) { - if (FlushViewOfFile(addr, len)) - return 0; + if (FlushViewOfFile(addr, len)) return 0; - errno = __map_mman_error(GetLastError(), EPERM); + errno = __map_mman_error(GetLastError(), EPERM); - return -1; + return -1; } -int mlock(const void *addr, size_t len) -{ - if (VirtualLock((LPVOID)addr, len)) - return 0; +int mlock(const void *addr, size_t len) { + if (VirtualLock((LPVOID)addr, len)) return 0; - errno = __map_mman_error(GetLastError(), EPERM); + errno = __map_mman_error(GetLastError(), EPERM); - return -1; + return -1; } -int munlock(const void *addr, size_t len) -{ - if (VirtualUnlock((LPVOID)addr, len)) - return 0; +int munlock(const void *addr, size_t len) { + if (VirtualUnlock((LPVOID)addr, len)) return 0; - errno = __map_mman_error(GetLastError(), EPERM); + errno = __map_mman_error(GetLastError(), EPERM); - return -1; + return -1; } - -#endif //PROJECT_MMAN_H +#endif // PROJECT_MMAN_H diff --git a/libnd4j/include/helpers/shape.h b/libnd4j/include/helpers/shape.h index ebc70664eb06..0013314a787d 100644 --- a/libnd4j/include/helpers/shape.h +++ b/libnd4j/include/helpers/shape.h @@ -24,20 +24,22 @@ #ifndef SHAPE_H_ #define SHAPE_H_ -#include +#include + #include +#include + +#include "../cnpy/cnpy.h" +#include "../helpers/logger.h" +#include "math/templatemath.h" #include "system/dll.h" #include "system/nd4jmalloc.h" -#include "math/templatemath.h" -#include "../helpers/logger.h" #include "system/pointercast.h" -#include "../cnpy/cnpy.h" -#include #define MAX_DIMENSION 0x7fffffff -#define MAX_NUM_THREADS 1024 +#define MAX_NUM_THREADS 1024 #define MAX_RANK 32 -#define MAX_SHAPEINFOLENGTH 2*MAX_RANK+4 +#define MAX_SHAPEINFOLENGTH 2 * MAX_RANK + 4 #define MAX_COORD 3 #define PREALLOC_SIZE 33554432 #ifdef __CUDACC__ @@ -45,16 +47,16 @@ #include #endif - #ifdef __CUDACC__ #define INLINEDEF inline #else #define INLINEDEF inline #endif -#include "system/pairwise_util.h" -#include #include +#include + +#include "system/pairwise_util.h" typedef unsigned int uint; @@ -64,103 +66,134 @@ namespace shape { * Shape information approximating * the information on an ndarray */ - struct SD_EXPORT ShapeInformation { - _CUDA_HD ShapeInformation(Nd4jLong* shape_ = nullptr, Nd4jLong *stride_ = nullptr, char order_ = 0, int rank_ = 0, int offset_ = 0, int elementWiseStride_ = 0) - : shape(shape_), stride(stride_), order(order_), rank(rank_), offset(offset_), elementWiseStride(elementWiseStride_) - {} - - Nd4jLong *shape; - Nd4jLong *stride; - char order; - int rank; - int offset; - int elementWiseStride; - }; +struct SD_EXPORT ShapeInformation { + _CUDA_HD ShapeInformation(Nd4jLong *shape_ = nullptr, + Nd4jLong *stride_ = nullptr, char order_ = 0, + int rank_ = 0, int offset_ = 0, + int elementWiseStride_ = 0) + : shape(shape_), + stride(stride_), + order(order_), + rank(rank_), + offset(offset_), + elementWiseStride(elementWiseStride_) {} + + Nd4jLong *shape; + Nd4jLong *stride; + char order; + int rank; + int offset; + int elementWiseStride; +}; /** * Indexing information * for bounds checking */ - struct SD_EXPORT CurrentIndexing { - int numElementsPerThread; - int blockStartingIndex; - int startingThreadIndex; - int endingThreadIndex; - - }; +struct SD_EXPORT CurrentIndexing { + int numElementsPerThread; + int blockStartingIndex; + int startingThreadIndex; + int endingThreadIndex; +}; +SD_EXPORT _CUDA_HD bool shapeEquals(const int shape1Rank, + const Nd4jLong *shape1, + const int shape2Rank, + const Nd4jLong *shape2); +SD_EXPORT _CUDA_HD const Nd4jLong *detachShape(const Nd4jLong *originalShape); - SD_EXPORT _CUDA_HD bool shapeEquals(const int shape1Rank, const Nd4jLong *shape1, const int shape2Rank, const Nd4jLong *shape2); +SD_EXPORT _CUDA_HD Nd4jLong *copyShape(Nd4jLong const *originalShape); - SD_EXPORT _CUDA_HD const Nd4jLong* detachShape(const Nd4jLong *originalShape); +SD_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2); - SD_EXPORT _CUDA_HD Nd4jLong* copyShape(Nd4jLong const* originalShape); +SD_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2, + const Nd4jLong *shapeInfo3); - SD_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2); +SD_EXPORT _CUDA_HD bool strideEquals(int const shape1Rank, + Nd4jLong const *shape1, + int const shape2Rank, + Nd4jLong const *shape2); - SD_EXPORT _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3); +SD_EXPORT _CUDA_HD bool strideEquals(Nd4jLong const *shapeInfo1, + Nd4jLong const *shapeInfo2); - SD_EXPORT _CUDA_HD bool strideEquals(int const shape1Rank,Nd4jLong const* shape1,int const shape2Rank, Nd4jLong const* shape2); +SD_EXPORT _CUDA_HD bool strideEquals(Nd4jLong const *stride1, int const rank1, + Nd4jLong const *stride2, int const rank2); - SD_EXPORT _CUDA_HD bool strideEquals(Nd4jLong const* shapeInfo1, Nd4jLong const* shapeInfo2); +SD_EXPORT _CUDA_HD bool equalsSoft(const Nd4jLong *shapeA, + const Nd4jLong *shapeB); - SD_EXPORT _CUDA_HD bool strideEquals(Nd4jLong const* stride1,int const rank1, Nd4jLong const* stride2, int const rank2); +SD_EXPORT _CUDA_HD bool equalsTypesAndShapesSoft(const Nd4jLong *shapeA, + const Nd4jLong *shapeB); - SD_EXPORT _CUDA_HD bool equalsSoft(const Nd4jLong *shapeA, const Nd4jLong *shapeB); +SD_EXPORT _CUDA_HD bool equalsStrict(const Nd4jLong *shapeA, + const Nd4jLong *shapeB); - SD_EXPORT _CUDA_HD bool equalsTypesAndShapesSoft(const Nd4jLong *shapeA, const Nd4jLong *shapeB); +// returns true if ranks, shapes and strides are the same +SD_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2); +SD_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2, + const Nd4jLong *shapeInfo3); - SD_EXPORT _CUDA_HD bool equalsStrict(const Nd4jLong *shapeA, const Nd4jLong *shapeB); +SD_EXPORT _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim); +SD_EXPORT _CUDA_HD Nd4jLong strideAt(const Nd4jLong *shapeInfo, const int dim); - // returns true if ranks, shapes and strides are the same - SD_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2); - SD_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3); - - SD_EXPORT _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim); - SD_EXPORT _CUDA_HD Nd4jLong strideAt(const Nd4jLong *shapeInfo, const int dim); - - template - SD_EXPORT _CUDA_HD void fill(T* buffer, T value, Nd4jLong length); - - SD_EXPORT _CUDA_HD void traceNew(int id); +template +SD_EXPORT _CUDA_HD void fill(T *buffer, T value, Nd4jLong length); +SD_EXPORT _CUDA_HD void traceNew(int id); - SD_EXPORT _CUDA_HD int tadIndexForLinear(int linearIndex, int tadLength); +SD_EXPORT _CUDA_HD int tadIndexForLinear(int linearIndex, int tadLength); - SD_EXPORT _CUDA_HD Nd4jLong tadLength(const Nd4jLong *shapeInfo, int *dimension, int dimensionLength); +SD_EXPORT _CUDA_HD Nd4jLong tadLength(const Nd4jLong *shapeInfo, int *dimension, + int dimensionLength); - SD_EXPORT _CUDA_HD bool canReshape(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShape, bool isFOrder); +SD_EXPORT _CUDA_HD bool canReshape(const int oldRank, Nd4jLong *oldShape, + const int newRank, Nd4jLong *newShape, + bool isFOrder); - SD_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, const char newOrder, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo); - /** - * newShapeInfo contains rank, shape and order only, no strides/ews/type - */ - SD_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShapeInfo); +SD_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong *oldShapeInfo, + const char newOrder, const int newRank, + const Nd4jLong *newShape, + Nd4jLong *newShapeInfo); +/** + * newShapeInfo contains rank, shape and order only, no strides/ews/type + */ +SD_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong *oldShapeInfo, + Nd4jLong *newShapeInfo); - /** - * Get the shape info buffer - * for the given rank and shape. - */ - SD_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape); +/** + * Get the shape info buffer + * for the given rank and shape. + */ +SD_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, + Nd4jLong const *shape); - SD_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape, Nd4jLong *buffer); +SD_EXPORT _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, + Nd4jLong const *shape, + Nd4jLong *buffer); - /** - * Get the shape info buffer - * for the given rank and shape. - */ - SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const* shape); +/** + * Get the shape info buffer + * for the given rank and shape. + */ +SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, + Nd4jLong const *shape); - SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const* shape, Nd4jLong *output); +SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, + Nd4jLong const *shape, + Nd4jLong *output); #ifdef __CUDACC__ - __device__ SD_EXPORT Nd4jLong *cuMalloc(Nd4jLong *buffer, long size); +__device__ SD_EXPORT Nd4jLong *cuMalloc(Nd4jLong *buffer, long size); #endif - - /** * Computes the standard packed array strides for a given shape. * @@ -168,9 +201,11 @@ namespace shape { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - SD_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank); +SD_EXPORT _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, + int rank); - SD_EXPORT _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, Nd4jLong* ret); +SD_EXPORT _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, int rank, + Nd4jLong *ret); /** * Computes the standard packed array strides for a given shape. @@ -180,17 +215,19 @@ namespace shape { * @return the strides for a matrix of n dimensions */ - SD_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank); - - SD_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank, Nd4jLong* ret); +SD_EXPORT _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank); - SD_EXPORT _CUDA_HD void updateStrides(Nd4jLong *shape, const char order); - SD_EXPORT _CUDA_HD void updateStrides(const int rank, const Nd4jLong *shapeOnly, Nd4jLong *stridesOnly, const char order); +SD_EXPORT _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank, + Nd4jLong *ret); +SD_EXPORT _CUDA_HD void updateStrides(Nd4jLong *shape, const char order); +SD_EXPORT _CUDA_HD void updateStrides(const int rank, const Nd4jLong *shapeOnly, + Nd4jLong *stridesOnly, const char order); -// check whether input dimensions are permuted, not permuted dimensions order have to be 0,....,rank-1 - template - SD_EXPORT _CUDA_HD bool isDimPermuted(const T* dimensions, const int dimSize); +// check whether input dimensions are permuted, not permuted dimensions order +// have to be 0,....,rank-1 +template +SD_EXPORT _CUDA_HD bool isDimPermuted(const T *dimensions, const int dimSize); /** * Computes the standard packed array strides for a given shape. @@ -199,9 +236,11 @@ namespace shape { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - SD_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong const *shape, int rank, int startNum); +SD_EXPORT _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, int rank, + int startNum); - SD_EXPORT _CUDA_HD Nd4jLong* calcStridesFortran(Nd4jLong const *shape, int rank, int startNum, Nd4jLong* ret); +SD_EXPORT _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, int rank, + int startNum, Nd4jLong *ret); /** * Computes the standard packed array strides for a given shape. @@ -210,28 +249,28 @@ namespace shape { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - SD_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const* shape, int rank, int startNum); +SD_EXPORT _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank, + int startNum); - SD_EXPORT _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank, int startNum, Nd4jLong* ret); +SD_EXPORT _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank, + int startNum, Nd4jLong *ret); /** * @param toCopy the shape to copy * @return a copy of the original struct */ - SD_EXPORT _CUDA_HD ShapeInformation *shapeCopy( ShapeInformation *toCopy); - +SD_EXPORT _CUDA_HD ShapeInformation *shapeCopy(ShapeInformation *toCopy); - SD_EXPORT _CUDA_HD bool strideDescendingCAscendingF(const Nd4jLong *shapeBuffer); - - SD_EXPORT _CUDA_HD bool isContiguous(const Nd4jLong* shapeInfo); +SD_EXPORT _CUDA_HD bool strideDescendingCAscendingF( + const Nd4jLong *shapeBuffer); +SD_EXPORT _CUDA_HD bool isContiguous(const Nd4jLong *shapeInfo); /** * copy-past from java hasDefaultStridesForShape function * check whether array is not permuted and has contiguous elements in memory */ - SD_EXPORT _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo); - +SD_EXPORT _CUDA_HD bool areStridesDefault(const Nd4jLong *shapeInfo); /** * Compute the element wise stride @@ -244,7 +283,9 @@ namespace shape { * @return 0 if there is no element wise stride the * element wise stride of reshape(1,length) otherwise */ - SD_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder); +SD_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const *shape, + Nd4jLong const *stride, + int isFOrder); /** * Compute the element wise stride @@ -257,11 +298,19 @@ namespace shape { * @return 0 if there is no element wise stride the * element wise stride of reshape(1,length) otherwise */ - SD_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder, Nd4jLong const* dimension, int dimensionLength); - - SD_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(Nd4jLong const* shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride); - - SD_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride, Nd4jLong *buffer); +SD_EXPORT _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const *shape, + Nd4jLong const *stride, + int isFOrder, + Nd4jLong const *dimension, + int dimensionLength); + +SD_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride( + Nd4jLong const *shapeInfo, Nd4jLong *dimension, int dimensionLength, + bool reverseCopyStride); + +SD_EXPORT _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride( + const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength, + bool reverseCopyStride, Nd4jLong *buffer); /** * * @param length @@ -269,9 +318,8 @@ namespace shape { * @param rearrange * @return */ - SD_EXPORT _CUDA_HD Nd4jLong *doPermuteSwap(int length, Nd4jLong *shape, int* rearrange); - - +SD_EXPORT _CUDA_HD Nd4jLong *doPermuteSwap(int length, Nd4jLong *shape, + int *rearrange); /** * In place permute swap @@ -279,40 +327,48 @@ namespace shape { * @param shape * @param rearrange */ - SD_EXPORT _CUDA_HD void doPermuteSwap(int length, Nd4jLong **shape, int* rearrange); +SD_EXPORT _CUDA_HD void doPermuteSwap(int length, Nd4jLong **shape, + int *rearrange); - SD_EXPORT _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong const* shapeBuffer, int* rearrange); +SD_EXPORT _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong const *shapeBuffer, + int *rearrange); - SD_EXPORT _CUDA_HD void permuteShapeBufferInPlace(Nd4jLong *shapeBuffer, int* rearrange, Nd4jLong *out); +SD_EXPORT _CUDA_HD void permuteShapeBufferInPlace(Nd4jLong *shapeBuffer, + int *rearrange, + Nd4jLong *out); - SD_EXPORT _CUDA_HD void doPermuteShapeInfo(Nd4jLong *shapeBuffer, const int *rearrange, Nd4jLong len = -1); - - /** - * Rearrange the permute indexes - * according to which dimensions are specified. - * - * For example, dimension is implicitly: - * 0,1,2 - * - * If you want to do a reduce along dimensions 0 and 1, - * you need to permute the indexes to be: - * 2,0,1 - * - * which will give us the ability to ierate along an element - * wise stride. - */ +SD_EXPORT _CUDA_HD void doPermuteShapeInfo(Nd4jLong *shapeBuffer, + const int *rearrange, + Nd4jLong len = -1); - SD_EXPORT _CUDA_HD Nd4jLong* createPermuteIndexes(int originalRank, int *dimension,int dimensionLength); +/** + * Rearrange the permute indexes + * according to which dimensions are specified. + * + * For example, dimension is implicitly: + * 0,1,2 + * + * If you want to do a reduce along dimensions 0 and 1, + * you need to permute the indexes to be: + * 2,0,1 + * + * which will give us the ability to ierate along an element + * wise stride. + */ - SD_EXPORT _CUDA_HD Nd4jLong* computeResultShape(const Nd4jLong *originalShapeBuffer, int *dimension,int dimensionLength); +SD_EXPORT _CUDA_HD Nd4jLong *createPermuteIndexes(int originalRank, + int *dimension, + int dimensionLength); - /** - * This method does inplace transpose of given shapeBuffer - * - * @param shapeBuffer - */ - SD_EXPORT _CUDA_HD void transposeInplace(Nd4jLong *shapeBuffer); +SD_EXPORT _CUDA_HD Nd4jLong *computeResultShape( + const Nd4jLong *originalShapeBuffer, int *dimension, int dimensionLength); +/** + * This method does inplace transpose of given shapeBuffer + * + * @param shapeBuffer + */ +SD_EXPORT _CUDA_HD void transposeInplace(Nd4jLong *shapeBuffer); /** * Get the ordering for the device @@ -322,7 +378,8 @@ namespace shape { * @param elementStride * @return */ - SD_EXPORT _CUDA_HD char getOrder(int length, Nd4jLong *shape, Nd4jLong *stride, int elementStride); +SD_EXPORT _CUDA_HD char getOrder(int length, Nd4jLong *shape, Nd4jLong *stride, + int elementStride); /** * Ensure that every value in the re arrange @@ -333,8 +390,9 @@ namespace shape { * @param shapeLength * @return */ - template - SD_EXPORT _CUDA_HD int checkArrangeArray(T *arr, int arrLength, int shapeLength); +template +SD_EXPORT _CUDA_HD int checkArrangeArray(T *arr, int arrLength, + int shapeLength); /** * Permute the shape information @@ -342,7 +400,8 @@ namespace shape { * @param rearrange the order to re arrange * @param rank the rank of the rearrange array */ - SD_EXPORT _CUDA_HD void permute(ShapeInformation **info, int *rearrange, int rank); +SD_EXPORT _CUDA_HD void permute(ShapeInformation **info, int *rearrange, + int rank); /** * Returns whether the @@ -350,49 +409,51 @@ namespace shape { * @param shape the shape of the array * @param rank the rank of cthe shape */ - SD_EXPORT _CUDA_HD int isVector(Nd4jLong const* shape, int rank); +SD_EXPORT _CUDA_HD int isVector(Nd4jLong const *shape, int rank); +/** + * When 1 dimension is the whole length of the + * array + */ +SD_EXPORT _CUDA_HD int oneDimEqualToLength(Nd4jLong *shape, int rank); - /** - * When 1 dimension is the whole length of the - * array - */ - SD_EXPORT _CUDA_HD int oneDimEqualToLength(Nd4jLong *shape, int rank); - - SD_EXPORT _CUDA_HD int oneDimEqualToLength(Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD int oneDimEqualToLength(Nd4jLong *shapeInfo); - SD_EXPORT _CUDA_HD int isVector(const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD int isVector(const Nd4jLong *shapeInfo); - SD_EXPORT _CUDA_HD bool isLikeVector(Nd4jLong const* shapeInfo, int& posOfNonUnityDim); +SD_EXPORT _CUDA_HD bool isLikeVector(Nd4jLong const *shapeInfo, + int &posOfNonUnityDim); - SD_EXPORT _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, int& posOfNonUnityDim); +SD_EXPORT _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, + int &posOfNonUnityDim); - SD_EXPORT _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo); - SD_EXPORT _CUDA_HD bool isColumnVector(Nd4jLong const* shapeInfo); +SD_EXPORT _CUDA_HD bool isColumnVector(Nd4jLong const *shapeInfo); - /** - * shape - input inShape is shape only, not shapeInfo - * returns number of non-unity dimensions in inShape - */ - SD_EXPORT _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape); +/** + * shape - input inShape is shape only, not shapeInfo + * returns number of non-unity dimensions in inShape + */ +SD_EXPORT _CUDA_HD int numOfNonUnitDims(const int rank, + const Nd4jLong *inShape); - /** +/** * Returns whether the * given shape is a vector or not * @param shape the shape of the array * @param rank the rank of the shape */ - SD_EXPORT _CUDA_HD int isMatrix(Nd4jLong *shape, int rank); +SD_EXPORT _CUDA_HD int isMatrix(Nd4jLong *shape, int rank); - INLINEDEF _CUDA_HD int isMatrix(Nd4jLong *shapeInfo); +INLINEDEF _CUDA_HD int isMatrix(Nd4jLong *shapeInfo); /** * Returns the shape portion of an information * buffer */ - SD_EXPORT _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *shapeInfo); - SD_EXPORT _CUDA_HD Nd4jLong *shapeOf(const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD Nd4jLong *shapeOf(const Nd4jLong *shapeInfo); /** * Return a copy of a buffer. @@ -400,26 +461,27 @@ namespace shape { * that must be freed elsewhere. */ - template - SD_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T const* toCopy); +template +SD_EXPORT _CUDA_HD T *copyOf(Nd4jLong length, T const *toCopy); - template - SD_EXPORT _CUDA_HD T* copyOf(Nd4jLong length, T const* toCopy, T *ret); +template +SD_EXPORT _CUDA_HD T *copyOf(Nd4jLong length, T const *toCopy, T *ret); - /** +/** * Return a copy of a buffer. * This buffer allocates memory * that must be freed elsewhere. */ - template - SD_EXPORT _CUDA_HD void copyTo(Nd4jLong length, T const* from, T *to); - /** -* Return a copy of a buffer. -* This buffer allocates memory -* that must be freed elsewhere. -*/ - SD_EXPORT _CUDA_HD void copyTo(int length, Nd4jLong const* from, Nd4jLong *to, Nd4jLong *indexes); +template +SD_EXPORT _CUDA_HD void copyTo(Nd4jLong length, T const *from, T *to); +/** + * Return a copy of a buffer. + * This buffer allocates memory + * that must be freed elsewhere. + */ +SD_EXPORT _CUDA_HD void copyTo(int length, Nd4jLong const *from, Nd4jLong *to, + Nd4jLong *indexes); /** * Permute the given strides @@ -430,18 +492,20 @@ namespace shape { * and all must be filled in) * @return the rearranged array */ - //SD_EXPORT _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int shapeRank, Nd4jLong *rearrange); +// SD_EXPORT _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int +// shapeRank, Nd4jLong *rearrange); /** * Return the slice (shape + 1 in pointer arithmetic) * @param shape the shape to take the slice of * @return the shape array - the first entry */ - SD_EXPORT _CUDA_HD Nd4jLong *slice(Nd4jLong *shape); +SD_EXPORT _CUDA_HD Nd4jLong *slice(Nd4jLong *shape); - SD_EXPORT _CUDA_HD int slices(Nd4jLong *shapeBuffer); +SD_EXPORT _CUDA_HD int slices(Nd4jLong *shapeBuffer); - SD_EXPORT _CUDA_HD Nd4jLong *sliceOfShapeBuffer(Nd4jLong sliceIdx, Nd4jLong *shapeBuffer); +SD_EXPORT _CUDA_HD Nd4jLong *sliceOfShapeBuffer(Nd4jLong sliceIdx, + Nd4jLong *shapeBuffer); /** * Returns the length of the * shape information buffer: @@ -450,30 +514,30 @@ namespace shape { * info length for * @return rank * 2 + 4 */ - SD_EXPORT _CUDA_HD int shapeInfoLength(int rank); +SD_EXPORT _CUDA_HD int shapeInfoLength(int rank); - SD_EXPORT _CUDA_HD int shapeInfoLength(Nd4jLong* shapeInfo); +SD_EXPORT _CUDA_HD int shapeInfoLength(Nd4jLong *shapeInfo); - SD_EXPORT _CUDA_HD int shapeInfoLength(const Nd4jLong* shapeInfo); +SD_EXPORT _CUDA_HD int shapeInfoLength(const Nd4jLong *shapeInfo); - SD_EXPORT _CUDA_HD size_t shapeInfoByteLength(int rank); +SD_EXPORT _CUDA_HD size_t shapeInfoByteLength(int rank); - SD_EXPORT _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong* shapeInfo); +SD_EXPORT _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong *shapeInfo); - SD_EXPORT _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong* shapeInfo); +SD_EXPORT _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong *shapeInfo); /** * Returns the rank portion of * an information buffer */ - SD_EXPORT _CUDA_HD int rank(const Nd4jLong *shapeInfo); - SD_EXPORT _CUDA_HD int rank(const int *shapeInfo); - SD_EXPORT _CUDA_HD int rank(const unsigned int *shapeInfo); +SD_EXPORT _CUDA_HD int rank(const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD int rank(const int *shapeInfo); +SD_EXPORT _CUDA_HD int rank(const unsigned int *shapeInfo); - /** - * returns pointer on elementWiseStride - */ - SD_EXPORT _CUDA_HD Nd4jLong* ews(Nd4jLong* shapeInfo); +/** + * returns pointer on elementWiseStride + */ +SD_EXPORT _CUDA_HD Nd4jLong *ews(Nd4jLong *shapeInfo); /** * Converts a raw int buffer of the layout: @@ -485,65 +549,65 @@ namespace shape { * * where shape and stride are both straight int pointers */ - SD_EXPORT _CUDA_HD ShapeInformation *infoFromBuffer(Nd4jLong *buffer); +SD_EXPORT _CUDA_HD ShapeInformation *infoFromBuffer(Nd4jLong *buffer); /** * Returns the stride portion of an information * buffer */ - SD_EXPORT _CUDA_HD Nd4jLong *stride(Nd4jLong *buffer); +SD_EXPORT _CUDA_HD Nd4jLong *stride(Nd4jLong *buffer); - SD_EXPORT _CUDA_HD Nd4jLong *stride(const Nd4jLong *buffer); +SD_EXPORT _CUDA_HD Nd4jLong *stride(const Nd4jLong *buffer); /** * Compute the length of the given shape */ - SD_EXPORT _CUDA_HD bool isEmpty(const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD bool isEmpty(const Nd4jLong *shapeInfo); - SD_EXPORT _CUDA_HD Nd4jLong length(const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD Nd4jLong length(const Nd4jLong *shapeInfo); - SD_EXPORT _CUDA_HD Nd4jLong length(std::initializer_list& shape); +SD_EXPORT _CUDA_HD Nd4jLong length(std::initializer_list &shape); - SD_EXPORT _CUDA_HD Nd4jLong length(std::initializer_list& shape); +SD_EXPORT _CUDA_HD Nd4jLong length(std::initializer_list &shape); /*** * Returns the offset portion of an information buffer */ - SD_EXPORT _CUDA_HD Nd4jLong offset(Nd4jLong *buffer); +SD_EXPORT _CUDA_HD Nd4jLong offset(Nd4jLong *buffer); - SD_EXPORT _CUDA_HD Nd4jLong& extra(Nd4jLong *buffer); +SD_EXPORT _CUDA_HD Nd4jLong &extra(Nd4jLong *buffer); /** * Returns the ordering * for this shape information buffer */ - SD_EXPORT _CUDA_HD char order(const Nd4jLong *buffer); +SD_EXPORT _CUDA_HD char order(const Nd4jLong *buffer); /** * Returns the type */ - SD_EXPORT _CUDA_HD Nd4jLong type(const Nd4jLong* shapeInfo); +SD_EXPORT _CUDA_HD Nd4jLong type(const Nd4jLong *shapeInfo); /** * Returns the element wise stride for this information * buffer */ - SD_EXPORT _CUDA_HD Nd4jLong elementWiseStride(const Nd4jLong *buffer); +SD_EXPORT _CUDA_HD Nd4jLong elementWiseStride(const Nd4jLong *buffer); - - /** +/** * Returns the element wise stride for this information * buffer - * relative to a dimension and ordering for a reduction index + * relative to a dimension and ordering for a reduction index */ - SD_EXPORT _CUDA_HD Nd4jLong reductionIndexElementWiseStride(Nd4jLong *buffer, int *dimension, int dimensionLength); +SD_EXPORT _CUDA_HD Nd4jLong reductionIndexElementWiseStride( + Nd4jLong *buffer, int *dimension, int dimensionLength); /** * Returns whether * the given shape info buffer * represents a scalar shape */ - SD_EXPORT _CUDA_HD int isScalar(const Nd4jLong *info); +SD_EXPORT _CUDA_HD int isScalar(const Nd4jLong *info); /** * Returns whether @@ -551,7 +615,7 @@ namespace shape { * represents a scalar * shape or not */ - SD_EXPORT _CUDA_HD int isScalar(volatile ShapeInformation *info); +SD_EXPORT _CUDA_HD int isScalar(volatile ShapeInformation *info); /** * Return a copy of this array with the @@ -565,10 +629,12 @@ namespace shape { * * item */ - template - SD_EXPORT _CUDA_HD void removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength, T1 *out); +template +SD_EXPORT _CUDA_HD void removeIndex(T1 const *data, T2 const *indexes, + Nd4jLong dataLength, Nd4jLong indexesLength, + T1 *out); - /** +/** * Return a copy of this array with the * given index omitted * @@ -581,21 +647,24 @@ namespace shape { * item */ - template - SD_EXPORT _CUDA_HD T1* removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength); +template +SD_EXPORT _CUDA_HD T1 *removeIndex(T1 const *data, T2 const *indexes, + Nd4jLong dataLength, Nd4jLong indexesLength); - /** - * Iterate over a given set of indexes - * the begin and end indexes are 0 based. - * 1 padding is automatically assumed for the ending. - * - * For example if you want to iterate over 0 to 4 - * it will go to 4 rather than 3. - * - * indexes should be the indexes to exclude - * indexes length should be the length of indexes - */ - SD_EXPORT _CUDA_HD Nd4jLong* everyIndexBut(Nd4jLong const* indexes,int indexesLength,int begin,int end); +/** + * Iterate over a given set of indexes + * the begin and end indexes are 0 based. + * 1 padding is automatically assumed for the ending. + * + * For example if you want to iterate over 0 to 4 + * it will go to 4 rather than 3. + * + * indexes should be the indexes to exclude + * indexes length should be the length of indexes + */ +SD_EXPORT _CUDA_HD Nd4jLong *everyIndexBut(Nd4jLong const *indexes, + int indexesLength, int begin, + int end); /** * Computes the offset for accessing @@ -615,11 +684,11 @@ namespace shape { * for the shape to be returned as * @return the new shape */ - SD_EXPORT _CUDA_HD Nd4jLong* ensureVectorShape(Nd4jLong *shape); +SD_EXPORT _CUDA_HD Nd4jLong *ensureVectorShape(Nd4jLong *shape); - SD_EXPORT _CUDA_HD Nd4jLong* createScalarShapeInfo(); +SD_EXPORT _CUDA_HD Nd4jLong *createScalarShapeInfo(); - SD_EXPORT _CUDA_HD Nd4jLong* createScalarShapeInfo(Nd4jLong *ret); +SD_EXPORT _CUDA_HD Nd4jLong *createScalarShapeInfo(Nd4jLong *ret); /** * Generate an int buffer @@ -627,21 +696,22 @@ namespace shape { * at the specified increment * */ - template - SD_EXPORT _CUDA_HD T* range(int from, int to, int increment); +template +SD_EXPORT _CUDA_HD T *range(int from, int to, int increment); /** * Range between from and two with an * increment of 1 */ - template - SD_EXPORT _CUDA_HD T* range(int from, int to); +template +SD_EXPORT _CUDA_HD T *range(int from, int to); /** * Keep the given indexes * in the data */ - SD_EXPORT _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int const* index, int indexLength, int dataLength); +SD_EXPORT _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int const *index, + int indexLength, int dataLength); /** * Generate reverse copy of the data @@ -650,17 +720,18 @@ namespace shape { * @return */ - template - SD_EXPORT _CUDA_HD T* reverseCopy(T const* data, Nd4jLong length); +template +SD_EXPORT _CUDA_HD T *reverseCopy(T const *data, Nd4jLong length); - template - SD_EXPORT _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong length); +template +SD_EXPORT _CUDA_HD void reverseCopyTo(T const *from, T *to, Nd4jLong length); - template - SD_EXPORT _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong *indexes, Nd4jLong length); +template +SD_EXPORT _CUDA_HD void reverseCopyTo(T const *from, T *to, Nd4jLong *indexes, + Nd4jLong length); - template - SD_EXPORT _CUDA_H void convertT(T1 *from, T2 *to, Nd4jLong length); +template +SD_EXPORT _CUDA_H void convertT(T1 *from, T2 *to, Nd4jLong length); /** * * @param arr1 @@ -669,8 +740,9 @@ namespace shape { * @param arr2Length * @return */ - template - SD_EXPORT _CUDA_HD T* concat(T const* arr1, Nd4jLong const arr1Length, T const* arr2, Nd4jLong const arr2Length); +template +SD_EXPORT _CUDA_HD T *concat(T const *arr1, Nd4jLong const arr1Length, + T const *arr2, Nd4jLong const arr2Length); /** * @@ -680,8 +752,9 @@ namespace shape { * @param lengths * @return */ - template - SD_EXPORT _CUDA_HD T* concat(int const numArrays, int const numTotalElements, Nd4jLong const**arr, Nd4jLong const* lengths); +template +SD_EXPORT _CUDA_HD T *concat(int const numArrays, int const numTotalElements, + Nd4jLong const **arr, Nd4jLong const *lengths); /** * Get the length per slice of the @@ -695,7 +768,9 @@ namespace shape { * @return the length per slice of the given shape * along the given dimension */ - SD_EXPORT _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong const* shape, int const* dimension, int dimensionLength); +SD_EXPORT _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong const *shape, + int const *dimension, + int dimensionLength); /** * calculates the offset for a tensor @@ -704,13 +779,9 @@ namespace shape { * @param tensorShape * @return */ - SD_EXPORT _CUDA_HD Nd4jLong sliceOffsetForTensor(int rank, - int index, - Nd4jLong const* shape, - Nd4jLong const* tensorShape, - int tensorShapeLength, - int const *dimension, - int dimensionLength); +SD_EXPORT _CUDA_HD Nd4jLong sliceOffsetForTensor( + int rank, int index, Nd4jLong const *shape, Nd4jLong const *tensorShape, + int tensorShapeLength, int const *dimension, int dimensionLength); /** * calculates the offset for a tensor @@ -719,14 +790,16 @@ namespace shape { * @param tensorShape * @return */ - SD_EXPORT _CUDA_HD Nd4jLong sliceOffsetForTensor(int index,int tensorLength,int lengthPerSlice2); +SD_EXPORT _CUDA_HD Nd4jLong sliceOffsetForTensor(int index, int tensorLength, + int lengthPerSlice2); /** * Computes the tensor along dimension * offset * @param index the index to get the offset for the tad for * @param rank the rank of the shapes and strides * @param info the shape information to use for tad - * @param dimension the dimensions to use for computing the tensor along dimensions + * @param dimension the dimensions to use for computing the tensor along + * dimensions */ // SD_EXPORT _CUDA_HD int offset(int index, // int rank, @@ -734,26 +807,24 @@ namespace shape { // Nd4jLong *dimension, // int dimensionLength); - /** * Computes the number * of tensors along * a given dimension */ - SD_EXPORT _CUDA_HD Nd4jLong tensorsAlongDimension(int rank, - volatile int length, - volatile Nd4jLong *shape, - int *dimension, - int dimensionLength); +SD_EXPORT _CUDA_HD Nd4jLong tensorsAlongDimension(int rank, volatile int length, + volatile Nd4jLong *shape, + int *dimension, + int dimensionLength); /** * Computes the number * of tensors along * a given dimension */ - SD_EXPORT _CUDA_HD Nd4jLong tensorsAlongDimension(Nd4jLong *shapeInfo, int *dimension, int dimensionLength); - - +SD_EXPORT _CUDA_HD Nd4jLong tensorsAlongDimension(Nd4jLong *shapeInfo, + int *dimension, + int dimensionLength); /** * Returns the tensor along dimension @@ -763,24 +834,26 @@ namespace shape { * @param i * @return */ - SD_EXPORT _CUDA_HD int tadForBlockIndex(int blockSize, int blockIdx, int i); +SD_EXPORT _CUDA_HD int tadForBlockIndex(int blockSize, int blockIdx, int i); /** * Computes the number of tads per block * */ - SD_EXPORT _CUDA_HD int tadsPerBlock(int blockSize, int tads); +SD_EXPORT _CUDA_HD int tadsPerBlock(int blockSize, int tads); -// SD_EXPORT _CUDA_HD Nd4jLong *tadShapeInfo(int index, Nd4jLong *xShapeInfo, Nd4jLong *dimension, +// SD_EXPORT _CUDA_HD Nd4jLong *tadShapeInfo(int index, Nd4jLong *xShapeInfo, +// Nd4jLong *dimension, // int dimensionLength); /** * Returns a shape buffer * for the shape information metadata. */ - SD_EXPORT _CUDA_HD Nd4jLong *toShapeBuffer( ShapeInformation *info); +SD_EXPORT _CUDA_HD Nd4jLong *toShapeBuffer(ShapeInformation *info); - SD_EXPORT _CUDA_HD Nd4jLong *toShapeBuffer( ShapeInformation *info, Nd4jLong* ret); +SD_EXPORT _CUDA_HD Nd4jLong *toShapeBuffer(ShapeInformation *info, + Nd4jLong *ret); /** * Returns the number of elements per thread @@ -831,25 +904,29 @@ namespace shape { * @param numElementsPerTad the number of elements * per tad */ - SD_EXPORT _CUDA_HD int tadIndex(int i, int elementWiseStride, int numElementsPerTad); +SD_EXPORT _CUDA_HD int tadIndex(int i, int elementWiseStride, + int numElementsPerTad); /** * Map a tad to a * reduction index. * @param tadIndexForOriginal the original tad index for the * split up problem (eg: split is dimension 3 mapping to a 2,3 problem) - * @param tadsForReduced the number of tads for the shrunk down problem (eg: 2,3) + * @param tadsForReduced the number of tads for the shrunk down problem (eg: + * 2,3) * @param tadsForOriginal the number of tads for the smaller problem (eg: 3) */ - SD_EXPORT _CUDA_HD int reductionIndexForTad(int tadIndexForOriginal, int tadsForReduced, - int tadsForOriginal); +SD_EXPORT _CUDA_HD int reductionIndexForTad(int tadIndexForOriginal, + int tadsForReduced, + int tadsForOriginal); /** * Computes the number of tads * per reduce index for the * reduction tad. */ - SD_EXPORT _CUDA_HD int tadsPerReduceIndex(int tadsForReduce, int tadsForOriginal); +SD_EXPORT _CUDA_HD int tadsPerReduceIndex(int tadsForReduce, + int tadsForOriginal); /** * Maps a linear index to a reduction index @@ -859,259 +936,357 @@ namespace shape { * @param tadNum the number of tads for the shrunken problem * @param originalTadNum the tad number for the reduced version of the problem */ - SD_EXPORT _CUDA_HD int reductionIndexForLinear(int i, int elementWiseStride, int numElementsPerTad, - int tadNum, int originalTadNum); +SD_EXPORT _CUDA_HD int reductionIndexForLinear(int i, int elementWiseStride, + int numElementsPerTad, + int tadNum, int originalTadNum); /** * Returns the prod of the data * up to the given length */ - SD_EXPORT _CUDA_HD Nd4jLong prodLong(const Nd4jLong *data, int length); - - /** - * Returns the rear most left over item not present in - * the dimension array. This assumes that the dimension array is sorted. - * - * For example, given a dimension array of: - * 0,2 - * - * and - * - * 12,4,2,1 in data - * - * You end up with 1 (data[3]) - * since the first item won't match - * the last item of the dimension array - */ +SD_EXPORT _CUDA_HD Nd4jLong prodLong(const Nd4jLong *data, int length); -// SD_EXPORT _CUDA_HD int rearMostLeftOverItem(Nd4jLong *data,int length,Nd4jLong *dimension,int dimensionLength); - - /** -* Get an offset for retrieval -* from a data buffer -* based on the given -* shape stride and given indices -* @param baseOffset the offset to start from -* @param shape the shape of the array -* @param stride the stride of the array -* @param indices the indices to iterate over -* @return the double at the specified index -*/ - - SD_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *coords, Nd4jLong baseOffset = 0); - SD_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset = 0); - SD_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset = 0); - - SD_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank); - - SD_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank, Nd4jLong *buffer); - - /** - * Convert a linear index to the corresponding coordinates - * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1] - */ - SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords); - SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords); - SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords); - SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords); - SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, int *coords); - - SD_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, Nd4jLong *coords); - SD_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, int *coords); - - /** - * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! - */ - SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords, const int dimsSize, const int* tadDims); - - /** - * Convert coordinates to the corresponding linear index (sequence number in other words) - * for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned - */ - SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *coords); - SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords); - SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const uint *coords); - SD_EXPORT _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const int *coords); - /** - * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! - */ - SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords, const int dimsSize, const int* tadDims); - - /** - * increment n-dimensional array by one iteration by changing coord appropriately - * for example we have array with shape {2, 3}: - * - if input coord = {0,1}, then output coord = {0,2} - * - if input coord = {0,2}, then output coord = {1,0} - * so the aim is to produce following subsequence of coord: {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2} - */ +/** + * Returns the rear most left over item not present in + * the dimension array. This assumes that the dimension array is sorted. + * + * For example, given a dimension array of: + * 0,2 + * + * and + * + * 12,4,2,1 in data + * + * You end up with 1 (data[3]) + * since the first item won't match + * the last item of the dimension array + */ - /* calculates an array buffer offset for given "index" using following formula: offset = coord_0*stride_0 + coord_1*stride_1 + ... + coord_{rank-1}*stride_{rank-1} - */ - SD_EXPORT _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo); - SD_EXPORT _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo); - SD_EXPORT _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeInfo, const uint* uShapeInfo, const bool useUnsigned); +// SD_EXPORT _CUDA_HD int rearMostLeftOverItem(Nd4jLong *data,int +// length,Nd4jLong *dimension,int dimensionLength); + +/** + * Get an offset for retrieval + * from a data buffer + * based on the given + * shape stride and given indices + * @param baseOffset the offset to start from + * @param shape the shape of the array + * @param stride the stride of the array + * @param indices the indices to iterate over + * @return the double at the specified index + */ + +SD_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, + const Nd4jLong *coords, + Nd4jLong baseOffset = 0); +SD_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, + const int *coords, + Nd4jLong baseOffset = 0); +SD_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, + const uint *coords, + Nd4jLong baseOffset = 0); + +SD_EXPORT _CUDA_HD Nd4jLong *createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, + int rank); + +SD_EXPORT _CUDA_HD Nd4jLong *createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, + int rank, Nd4jLong *buffer); + +/** + * Convert a linear index to the corresponding coordinates + * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, + * 1] + */ +SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + Nd4jLong *coords); +SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + int *coords); +SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + uint *coords); +SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, + const Nd4jLong *shape, Nd4jLong *coords); +SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, + const Nd4jLong *shape, int *coords); + +SD_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong &startIndex, + const Nd4jLong &index, + const Nd4jLong *shapeInfo, + Nd4jLong *coords); +SD_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong &startIndex, + const Nd4jLong &index, + const Nd4jLong *shapeInfo, int *coords); + +/** + * take into account only dimensions stored in tadDims, tadDims must be sorted + * in increasing order! + */ +SD_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + int *coords, const int dimsSize, + const int *tadDims); + +/** + * Convert coordinates to the corresponding linear index (sequence number in + * other words) for example if shape is {2, 4} and coordinates [1, 1] then index + * 5 is returned + */ +SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const Nd4jLong *coords); +SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const int *coords); +SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const uint *coords); +SD_EXPORT _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, + const int *coords); +/** + * take into account only dimensions stored in tadDims, tadDims must be sorted + * in increasing order! + */ +SD_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const int *coords, const int dimsSize, + const int *tadDims); + +/** + * increment n-dimensional array by one iteration by changing coord + * appropriately for example we have array with shape {2, 3}: + * - if input coord = {0,1}, then output coord = {0,2} + * - if input coord = {0,2}, then output coord = {1,0} + * so the aim is to produce following subsequence of coord: {0,0}, {0,1}, {0,2}, + * {1,0}, {1,1}, {1,2} + */ - SD_EXPORT _CUDA_HD void printShapeInfo(Nd4jLong *shapeInfo); +/* calculates an array buffer offset for given "index" using following formula: + * offset = coord_0*stride_0 + coord_1*stride_1 + ... + + * coord_{rank-1}*stride_{rank-1} + */ +SD_EXPORT _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo); +SD_EXPORT _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, + const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, + const Nd4jLong *lShapeInfo, + const uint *uShapeInfo, + const bool useUnsigned); + +SD_EXPORT _CUDA_HD void printShapeInfo(Nd4jLong *shapeInfo); - SD_EXPORT _CUDA_HD void printShapeInfoLinear(const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD void printShapeInfoLinear(const Nd4jLong *shapeInfo); - SD_EXPORT _CUDA_HD void printShapeInfoLinear(const char *msg, const Nd4jLong *shapeInfo); +SD_EXPORT _CUDA_HD void printShapeInfoLinear(const char *msg, + const Nd4jLong *shapeInfo); - SD_EXPORT _CUDA_HD void printShapeInfoLinear(const char *msg, int rank, const Nd4jLong *shape, const Nd4jLong *strides); +SD_EXPORT _CUDA_HD void printShapeInfoLinear(const char *msg, int rank, + const Nd4jLong *shape, + const Nd4jLong *strides); - SD_EXPORT _CUDA_HD void printIntArray(const Nd4jLong *arr, const int length); - SD_EXPORT _CUDA_HD void printIntArray(const int *arr, const int length); +SD_EXPORT _CUDA_HD void printIntArray(const Nd4jLong *arr, const int length); +SD_EXPORT _CUDA_HD void printIntArray(const int *arr, const int length); - SD_EXPORT _CUDA_HD void printArray(float *arr,int length); +SD_EXPORT _CUDA_HD void printArray(float *arr, int length); - template - SD_EXPORT _CUDA_HD void printArray(T *arr,int length, const char *message); +template +SD_EXPORT _CUDA_HD void printArray(T *arr, int length, const char *message); - SD_EXPORT _CUDA_HD Nd4jLong* shapeBufferOfNpy(int rank, unsigned int *shape,bool fortranOrder); +SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpy(int rank, unsigned int *shape, + bool fortranOrder); - SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpy(cnpy::NpyArray arr); +SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpy(cnpy::NpyArray arr); // SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpyBuffer(char *buffer); +// this function checks the consistence of dimensions with array rank (negative +// dimensions, too large dimensions, too big number of dimensions) also sort +// input array of dimensions, this operation is also necessary for creating TAD +// object +SD_EXPORT _CUDA_H void checkDimensions(const int rank, + std::vector &dimensions); + +// function calculates linear index of array min, min is sub-array of max, index +// to be returned is min-array's index and corresponds to maxIdx of max array +// dimsToExclude - should be sorted in increasing order +SD_EXPORT _CUDA_HD Nd4jLong subArrayIndex(const Nd4jLong maxIdx, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude = nullptr, + const int dimsLen = -1); + +// function calculates absolute offset of min array, min is sub-array of max, +// offset to be returned corresponds to maxIdx of max array dimsToExclude - +// should be sorted in increasing order +SD_EXPORT _CUDA_HD Nd4jLong subArrayOffset(const Nd4jLong maxIdx, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude = nullptr, + const int dimsLen = -1); + +// max array is outer for min array, min array is sub-array of max array +// function calculates the coordinates of min array (and saves them into +// minIdxs) given coordinates of max array (already stored in maxIdxs) +// dimsToExclude - should be sorted in increasing order +// dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated +// as maxRank - minRank +SD_EXPORT _CUDA_HD void maxIndToMinInd(int *maxIdxs, int *minIdxs, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude = nullptr, + const int dimsLen = -1); + +// calculate indexes of max-array, these output indexes correspond to one minIdx +// index of min-array which is sub-array of max-array dimsToExclude - should be +// sorted in increasing order +SD_EXPORT _CUDA_HD int outerArrayIndexes(int *maxIdxs, const Nd4jLong minIdx, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude = nullptr); + +// calculate offsets of max-array, these offsets correspond to one minIdx index +// of min-array which is sub-array of max-array maxOffsets - will contain +// calculated offsets of max-array, buffer for maxOffsets should be allocated +// beforehand dimsToExclude - should be sorted in increasing order memBuff - +// auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments +// storing, should be allocated beforehand +SD_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong *maxOffsets, + const Nd4jLong minIdx, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + int *memBuff, + const int *dimsToExclude = nullptr); + +// calculates offsets for entities (elements or sub-arrays), shape in context of +// sub-array means dimensions excluded from outer array rank is equal to size of +// shape +SD_EXPORT void calcOffsets(const int rank, const Nd4jLong *shape, + const Nd4jLong *strides, Nd4jLong *offsets, + const char order = 'c'); +SD_EXPORT void calcOffsets(const Nd4jLong *shapeInfo, Nd4jLong *offsets, + const char order = 'c'); +// SD_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, +// const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order = 'c'); +// SD_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, +// const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, +// Nd4jLong*& zOffsets, const char order = 'c'); +SD_EXPORT _CUDA_HD void shapeOldScalar(sd::DataType dtype, + Nd4jLong *const buffer, + const char order); + +// deduce order and element-wise stride +// if array is scalar or unit length vector then ews = 1 and order is preserved +// if array is common vector then ews = stride of non-unity dimension and order +// is preserved if strides are normal/contiguous then ews = 1 and corresponding +// order is set, otherwise ews = 0 and order is preserved +SD_EXPORT _CUDA_HD void checkStridesEwsAndOrder( + Nd4jLong *shapeInfo, const char proposedOrder, const int numOfNonUnitDims, + const Nd4jLong *shapeNoUnities, const Nd4jLong *stridesNoUnities); +SD_EXPORT _CUDA_HD void checkStridesEwsAndOrder(Nd4jLong *shapeInfo); + +/** + * processes whole set of sub-arrays + * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) + * and their buffer offsets (each sub-array has its own unique offset from + * original this-buffer) arguments: wholeShapeInfo - original shapeInfo of whole + * array numOfSubArrs - number of sub-arrays, size of subArrOffsets is equal to + * numOfSubArrs dimsSize - size of dimsToExclude, if dimsSize = array rank or + * dimsSize = 0 it means sub-array is whole array, copy of wholeShapeInfo and + * one zero offset will be returned dimsToExclude - MUST BE SORTED, dimensions + * to evaluate sub-array along, i.e. when shape is [2,3,4,5] and + * dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] + * subArrShapeInfo - output argument, contains shapeInfo (same for all + * sub-arrays) subArrOffsets - output argument, contains successive + * sub-arrays offsets from original this-buffer keepUnitiesInShape - if false + * then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> + * {a,b} + */ +SD_EXPORT _CUDA_HD void calcSubArrsShapeInfoAndOffsets( + const Nd4jLong *wholeShapeInfo, const Nd4jLong numOfSubArrs, + const int dimsSize, const int *dimsToExclude, Nd4jLong *subArrShapeInfo, + Nd4jLong *subArrOffsets, bool keepUnitiesInShape = false); + +/** + * processes only one sub-array, evaluates shapeInfo of sub-array and its buffer + * offset from original array arguments: idx - input argument, intervals of + * indexes which define the sub-array to point on, when isStrided = false then + * idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * + * maxRank) when isStrided = true then idx has form + * {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and + * length (3 * maxRank) when (dimStart == dimEnd) then whole range will be used + * for current dimension maxShapeInfo - input argument, shapeInfo of original + * array minShapeInfo - output argument, shapeInfo of sub-array to be deduced + * minOffset - output argument, offset of sub-array buffer offsets from original + * buffer keepUnitiesInShape - input argument, if false then eliminate unities + * from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} isStrided - input + * argument, if true then idx has length (3 * this->rankOf()) and contains + * additional stride numbers which correspond to stride between dimStart and + * dimEnd, numOfUntiesInMinShape - input argument, number of occurrences in idx + * when (dimEnd - dimStart) = 1 + */ +SD_EXPORT void calcSubArrShapeInfoAndOffset( + const Nd4jLong *idx, const Nd4jLong *maxShapeInfo, Nd4jLong *minShapeInfo, + Nd4jLong &minOffset, const bool keepUnitiesInShape = false, + const bool isStrided = false, const int numOfUntiesInMinShape = 0); + +/** + * for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99} + * then output shapeNoUnities will contain {2,4, 4,1} - that is only shape and + * strides, no rank/type/ews/order stridesNoUnities will point on strides in + * shapeNoUnities that is on {4,1} returns number of non-unity dimensions in + * inShapeInfo if there is no unities in inShapeInfo, then no copy procedure + * will be performed and shapeNoUnities/stridesNoUnities will point on + * corresponding places in inShapeInfo + */ +SD_EXPORT _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong *inShapeInfo, + Nd4jLong *&shapeNoUnities, + Nd4jLong *&stridesNoUnities); + +/** + * for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, + * dimsToExclude = {1,3}, dimsSize = 2 then outShapeInfo will contain {3, 2,3,4, + * 12,4,1, 16384,1,99} + */ +INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong *inShapeInfo, + const int dimsSize, + const int *dimsToExclude, + Nd4jLong *outShapeInfo); - // this function checks the consistence of dimensions with array rank (negative dimensions, too large dimensions, too big number of dimensions) - // also sort input array of dimensions, this operation is also necessary for creating TAD object - SD_EXPORT _CUDA_H void checkDimensions(const int rank, std::vector& dimensions); - - // function calculates linear index of array min, min is sub-array of max, index to be returned is min-array's index and corresponds to maxIdx of max array - // dimsToExclude - should be sorted in increasing order - SD_EXPORT _CUDA_HD Nd4jLong subArrayIndex(const Nd4jLong maxIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr, const int dimsLen = -1); - - // function calculates absolute offset of min array, min is sub-array of max, offset to be returned corresponds to maxIdx of max array - // dimsToExclude - should be sorted in increasing order - SD_EXPORT _CUDA_HD Nd4jLong subArrayOffset(const Nd4jLong maxIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr, const int dimsLen = -1); - - // max array is outer for min array, min array is sub-array of max array - // function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) - // dimsToExclude - should be sorted in increasing order - // dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank - SD_EXPORT _CUDA_HD void maxIndToMinInd(int* maxIdxs, int* minIdxs, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr, const int dimsLen = -1); - - // calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array - // dimsToExclude - should be sorted in increasing order - SD_EXPORT _CUDA_HD int outerArrayIndexes(int* maxIdxs, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr); - - // calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array - // maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand - // dimsToExclude - should be sorted in increasing order - // memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand - SD_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, int* memBuff, const int* dimsToExclude = nullptr); - - // calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array - // rank is equal to size of shape - SD_EXPORT void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong* strides, Nd4jLong* offsets, const char order = 'c'); - SD_EXPORT void calcOffsets(const Nd4jLong* shapeInfo, Nd4jLong* offsets, const char order = 'c'); - // SD_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order = 'c'); - // SD_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order = 'c'); - SD_EXPORT _CUDA_HD void shapeOldScalar(sd::DataType dtype, Nd4jLong* const buffer, const char order); - - // deduce order and element-wise stride - // if array is scalar or unit length vector then ews = 1 and order is preserved - // if array is common vector then ews = stride of non-unity dimension and order is preserved - // if strides are normal/contiguous then ews = 1 and corresponding order is set, otherwise ews = 0 and order is preserved - SD_EXPORT _CUDA_HD void checkStridesEwsAndOrder(Nd4jLong* shapeInfo, const char proposedOrder, const int numOfNonUnitDims, const Nd4jLong* shapeNoUnities, const Nd4jLong* stridesNoUnities); - SD_EXPORT _CUDA_HD void checkStridesEwsAndOrder(Nd4jLong* shapeInfo); - - /** - * processes whole set of sub-arrays - * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) and their buffer offsets (each sub-array has its own unique offset from original this-buffer) - * arguments: - * wholeShapeInfo - original shapeInfo of whole array - * numOfSubArrs - number of sub-arrays, size of subArrOffsets is equal to numOfSubArrs - * dimsSize - size of dimsToExclude, if dimsSize = array rank or dimsSize = 0 it means sub-array is whole array, copy of wholeShapeInfo and one zero offset will be returned - * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] - * subArrShapeInfo - output argument, contains shapeInfo (same for all sub-arrays) - * subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer - * keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} - */ - SD_EXPORT _CUDA_HD void calcSubArrsShapeInfoAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false); - - /** - * processes only one sub-array, evaluates shapeInfo of sub-array and its buffer offset from original array - * arguments: - * idx - input argument, intervals of indexes which define the sub-array to point on, - * when isStrided = false then idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * maxRank) - * when isStrided = true then idx has form {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and length (3 * maxRank) - * when (dimStart == dimEnd) then whole range will be used for current dimension - * maxShapeInfo - input argument, shapeInfo of original array - * minShapeInfo - output argument, shapeInfo of sub-array to be deduced - * minOffset - output argument, offset of sub-array buffer offsets from original buffer - * keepUnitiesInShape - input argument, if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} - * isStrided - input argument, if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd, - * numOfUntiesInMinShape - input argument, number of occurrences in idx when (dimEnd - dimStart) = 1 - */ - SD_EXPORT void calcSubArrShapeInfoAndOffset(const Nd4jLong* idx, const Nd4jLong* maxShapeInfo, Nd4jLong* minShapeInfo, Nd4jLong& minOffset, const bool keepUnitiesInShape = false, const bool isStrided = false, const int numOfUntiesInMinShape = 0); - - /** - * for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99} - * then output shapeNoUnities will contain {2,4, 4,1} - that is only shape and strides, no rank/type/ews/order - * stridesNoUnities will point on strides in shapeNoUnities that is on {4,1} - * returns number of non-unity dimensions in inShapeInfo - * if there is no unities in inShapeInfo, then no copy procedure will be performed and shapeNoUnities/stridesNoUnities will point on corresponding places in inShapeInfo - */ - SD_EXPORT _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities); - - /** - * for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {1,3}, dimsSize = 2 - * then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99} - */ - INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, Nd4jLong* outShapeInfo); - - /** - * get stride over contiguous axis (contiguous axis must have stride = 1) - * for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1) - */ - // INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo); - - - - - - -//END HEADERS - - - //BEGIN IMPLEMENTATIONS +/** + * get stride over contiguous axis (contiguous axis must have stride = 1) + * for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then + * output is 5 (that is smallest stride in inShapeInfo except those equal to 1) + */ +// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const +// Nd4jLong* inShapeInfo); +// END HEADERS +// BEGIN IMPLEMENTATIONS #ifdef __CUDACC__ - /** -* BEWARE: THIS METHOD DOES NOT CHECKS ALLOCATION BOUNDARIES -*/ +/** + * BEWARE: THIS METHOD DOES NOT CHECKS ALLOCATION BOUNDARIES + */ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) { - Nd4jLong *ret = buffer; - ret += (threadIdx.x * size); - return ret; + Nd4jLong *ret = buffer; + ret += (threadIdx.x * size); + return ret; } #endif /** -* Length of a tad given -* the shape information -*/ - INLINEDEF _CUDA_HD Nd4jLong tadLength(const Nd4jLong *shapeInfo, int *dimension, int dimensionLength) { - if(dimensionLength == 1) { - return shape::shapeOf(shapeInfo)[dimension[0]]; - } - else { - Nd4jLong ret = 1; - for(int i = 0; i < shape::rank(shapeInfo); i++) { - for(int j = 0; j < dimensionLength; j++) { - if(i == dimension[j]) - ret *= shape::shapeOf(shapeInfo)[dimension[j]]; - } - } - return ret; - } + * Length of a tad given + * the shape information + */ +INLINEDEF _CUDA_HD Nd4jLong tadLength(const Nd4jLong *shapeInfo, int *dimension, + int dimensionLength) { + if (dimensionLength == 1) { + return shape::shapeOf(shapeInfo)[dimension[0]]; + } else { + Nd4jLong ret = 1; + for (int i = 0; i < shape::rank(shapeInfo); i++) { + for (int j = 0; j < dimensionLength; j++) { + if (i == dimension[j]) ret *= shape::shapeOf(shapeInfo)[dimension[j]]; + } } - - + return ret; + } +} /** * Tad element wise stride: @@ -1139,189 +1314,199 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) { * Again: this may not preserve ordering of the tad * but maybe used for reductions. */ - INLINEDEF _CUDA_HD int tadElementWiseStride(Nd4jLong *shapeInfo, int *dimension,int dimensionLength) { - return reductionIndexElementWiseStride(shapeInfo,dimension,dimensionLength); - } - - - INLINEDEF _CUDA_HD bool shapeEquals(const int shape1Rank, const Nd4jLong *shape1, const int shape2Rank, const Nd4jLong *shape2) { - if(shape1Rank != shape2Rank) - return false; - //rank not equals - for(int i = 0; i < shape1Rank; i++) { - if(shape1[i] != shape2[i]) - return false; - } - - return true; - } - - INLINEDEF _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2) { - return shape::shapeEquals(shape::rank(shapeInfo1), shape::shapeOf(const_cast(shapeInfo1)), shape::rank(shapeInfo2), shape::shapeOf(const_cast(shapeInfo2))); - } - - INLINEDEF _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3) { - - return shape::shapeEquals(shapeInfo1, shapeInfo2) && shape::shapeEquals(shapeInfo1, shapeInfo3); - - } - - INLINEDEF _CUDA_HD bool strideEquals(int const shape1Rank, Nd4jLong const* shape1,int const shape2Rank,Nd4jLong const* shape2) { - if(shape1Rank != shape2Rank) - return false; - //rank not equals - for(int i = 0; i < shape1Rank; i++) { - if(shape1[i] != shape2[i]) - return false; - } - - return true; - } - - INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong const* shapeInfo1,Nd4jLong const* shapeInfo2) { - return shape::strideEquals(shape::rank(shapeInfo1),shape::stride(shapeInfo1),shape::rank(shapeInfo2),shape::stride(shapeInfo2)); - - } - - INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong const* stride1,int const rank1 , Nd4jLong const* stride2, int const rank2) { - if(rank1 != rank2) - return false; - - for(int i = 0; i < rank1; i++) { - if(stride1[i] != stride2[i]) - return false; - } - - return true; - } - - INLINEDEF _CUDA_HD Nd4jLong *computeResultShape(Nd4jLong const* originalShapeBuffer, int * dimension,int dimensionLength) { - Nd4jLong *retShape; - int retShapeLength; - if(dimensionLength == 1 && dimension[0] == 2147483647) { - retShape = new Nd4jLong[2]; - retShape[0] = 1; - retShape[1] = 1; - retShapeLength = 2; - } - else { - retShape = shape::removeIndex(shape::shapeOf(originalShapeBuffer), dimension, shape::shapeInfoLength(shape::rank(originalShapeBuffer)), dimensionLength); - retShapeLength = shape::rank(originalShapeBuffer) - dimensionLength; - } - //ensure vector is proper shape - if (retShapeLength == 1) { - if (dimension[0] == 0) { - auto newRetShape = new Nd4jLong[2]{1, retShape[0]}; - delete[] retShape; - retShape = newRetShape; - retShapeLength = 2; - } - else { - auto newRetShape = new Nd4jLong[2]{retShape[0], 1}; - delete[] retShape; - retShape = newRetShape; - retShapeLength = 2; - } - } else if (retShapeLength == 0) { - auto newRetShape = new Nd4jLong[2]{1, 1}; - delete[] retShape; - retShape = newRetShape; - retShapeLength = 2; - } - - auto ret = shape::shapeBuffer(retShapeLength, sd::ArrayOptions::dataType(originalShapeBuffer), retShape); - delete[] retShape; - - return ret; +INLINEDEF _CUDA_HD int tadElementWiseStride(Nd4jLong *shapeInfo, int *dimension, + int dimensionLength) { + return reductionIndexElementWiseStride(shapeInfo, dimension, dimensionLength); +} - } +INLINEDEF _CUDA_HD bool shapeEquals(const int shape1Rank, + const Nd4jLong *shape1, + const int shape2Rank, + const Nd4jLong *shape2) { + if (shape1Rank != shape2Rank) return false; + // rank not equals + for (int i = 0; i < shape1Rank; i++) { + if (shape1[i] != shape2[i]) return false; + } + + return true; +} - INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride, Nd4jLong *buffer) { - Nd4jLong *theShape = shape::shapeOf(shapeInfo); - Nd4jLong *theStride = shape::stride(shapeInfo); - int rank = dimensionLength == 1 ? 2 : dimensionLength; - Nd4jLong *ret = buffer; - //set the rank - ret[0] = rank; - Nd4jLong *retShape = shape::shapeOf(ret); - Nd4jLong *retStride = shape::stride(ret); - int len = rank; - - if(dimensionLength == 1) { - if(shape::isMatrix(theShape,shape::rank(shapeInfo))) { - if(dimension[0] == 0) { - Nd4jLong newStride[2] = {theStride[dimension[0]],1}; - Nd4jLong newShape[2] = {theShape[dimension[0]],1}; - retShape[0] = newShape[0]; - retShape[1] = newShape[1]; - retStride[0] = newStride[0]; - retStride[1] = newStride[1]; - } - else { - Nd4jLong newStride[2] = {theStride[dimension[0]],1}; - Nd4jLong newShape[2] = {theShape[dimension[0]],1}; - retShape[0] = newShape[0]; - retShape[1] = newShape[1]; - retStride[0] = newStride[0]; - retStride[1] = newStride[1]; - } - } - else { - Nd4jLong newStride[2] = {1,theStride[dimension[0]]}; - Nd4jLong newShape[2] = {1,theShape[dimension[0]]}; - retShape[0] = newShape[0]; - retShape[1] = newShape[1]; - retStride[0] = newStride[0]; - retStride[1] = newStride[1]; - } +INLINEDEF _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2) { + return shape::shapeEquals(shape::rank(shapeInfo1), + shape::shapeOf(const_cast(shapeInfo1)), + shape::rank(shapeInfo2), + shape::shapeOf(const_cast(shapeInfo2))); +} +INLINEDEF _CUDA_HD bool shapeEquals(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2, + const Nd4jLong *shapeInfo3) { + return shape::shapeEquals(shapeInfo1, shapeInfo2) && + shape::shapeEquals(shapeInfo1, shapeInfo3); +} +INLINEDEF _CUDA_HD bool strideEquals(int const shape1Rank, + Nd4jLong const *shape1, + int const shape2Rank, + Nd4jLong const *shape2) { + if (shape1Rank != shape2Rank) return false; + // rank not equals + for (int i = 0; i < shape1Rank; i++) { + if (shape1[i] != shape2[i]) return false; + } + + return true; +} - } - else { - Nd4jLong *newIndexes = dimension; - if(reverseCopyStride) - shape::reverseCopyTo(theStride, retStride, newIndexes, len); - else - shape::copyTo(len, theStride, retStride, newIndexes); - shape::copyTo(len, theShape, retShape, newIndexes); +INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong const *shapeInfo1, + Nd4jLong const *shapeInfo2) { + return shape::strideEquals(shape::rank(shapeInfo1), shape::stride(shapeInfo1), + shape::rank(shapeInfo2), + shape::stride(shapeInfo2)); +} - } +INLINEDEF _CUDA_HD bool strideEquals(Nd4jLong const *stride1, int const rank1, + Nd4jLong const *stride2, int const rank2) { + if (rank1 != rank2) return false; + for (int i = 0; i < rank1; i++) { + if (stride1[i] != stride2[i]) return false; + } - ret[shape::shapeInfoLength(rank) - 1] = shape::order(shapeInfo); - return ret; - } + return true; +} - INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride(const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength,bool reverseCopyStride) { - int rank = dimensionLength == 1 ? 2 : dimensionLength; +INLINEDEF _CUDA_HD Nd4jLong *computeResultShape( + Nd4jLong const *originalShapeBuffer, int *dimension, int dimensionLength) { + Nd4jLong *retShape; + int retShapeLength; + if (dimensionLength == 1 && dimension[0] == 2147483647) { + retShape = new Nd4jLong[2]; + retShape[0] = 1; + retShape[1] = 1; + retShapeLength = 2; + } else { + retShape = shape::removeIndex( + shape::shapeOf(originalShapeBuffer), dimension, + shape::shapeInfoLength(shape::rank(originalShapeBuffer)), + dimensionLength); + retShapeLength = shape::rank(originalShapeBuffer) - dimensionLength; + } + // ensure vector is proper shape + if (retShapeLength == 1) { + if (dimension[0] == 0) { + auto newRetShape = new Nd4jLong[2]{1, retShape[0]}; + delete[] retShape; + retShape = newRetShape; + retShapeLength = 2; + } else { + auto newRetShape = new Nd4jLong[2]{retShape[0], 1}; + delete[] retShape; + retShape = newRetShape; + retShapeLength = 2; + } + } else if (retShapeLength == 0) { + auto newRetShape = new Nd4jLong[2]{1, 1}; + delete[] retShape; + retShape = newRetShape; + retShapeLength = 2; + } + + auto ret = shape::shapeBuffer(retShapeLength, + sd::ArrayOptions::dataType(originalShapeBuffer), + retShape); + delete[] retShape; + + return ret; +} - traceNew(4); +INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride( + const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength, + bool reverseCopyStride, Nd4jLong *buffer) { + Nd4jLong *theShape = shape::shapeOf(shapeInfo); + Nd4jLong *theStride = shape::stride(shapeInfo); + int rank = dimensionLength == 1 ? 2 : dimensionLength; + Nd4jLong *ret = buffer; + // set the rank + ret[0] = rank; + Nd4jLong *retShape = shape::shapeOf(ret); + Nd4jLong *retStride = shape::stride(ret); + int len = rank; + + if (dimensionLength == 1) { + if (shape::isMatrix(theShape, shape::rank(shapeInfo))) { + if (dimension[0] == 0) { + Nd4jLong newStride[2] = {theStride[dimension[0]], 1}; + Nd4jLong newShape[2] = {theShape[dimension[0]], 1}; + retShape[0] = newShape[0]; + retShape[1] = newShape[1]; + retStride[0] = newStride[0]; + retStride[1] = newStride[1]; + } else { + Nd4jLong newStride[2] = {theStride[dimension[0]], 1}; + Nd4jLong newShape[2] = {theShape[dimension[0]], 1}; + retShape[0] = newShape[0]; + retShape[1] = newShape[1]; + retStride[0] = newStride[0]; + retStride[1] = newStride[1]; + } + } else { + Nd4jLong newStride[2] = {1, theStride[dimension[0]]}; + Nd4jLong newShape[2] = {1, theShape[dimension[0]]}; + retShape[0] = newShape[0]; + retShape[1] = newShape[1]; + retStride[0] = newStride[0]; + retStride[1] = newStride[1]; + } + + } else { + Nd4jLong *newIndexes = dimension; + if (reverseCopyStride) + shape::reverseCopyTo(theStride, retStride, newIndexes, len); + else + shape::copyTo(len, theStride, retStride, newIndexes); + shape::copyTo(len, theShape, retShape, newIndexes); + } + + ret[shape::shapeInfoLength(rank) - 1] = shape::order(shapeInfo); + return ret; +} - Nd4jLong *ret = new Nd4jLong[shape::shapeInfoLength(rank)]; - return shapeInfoOnlyShapeAndStride(shapeInfo, dimension, dimensionLength, reverseCopyStride, ret); - } +INLINEDEF _CUDA_HD Nd4jLong *shapeInfoOnlyShapeAndStride( + const Nd4jLong *shapeInfo, Nd4jLong *dimension, int dimensionLength, + bool reverseCopyStride) { + int rank = dimensionLength == 1 ? 2 : dimensionLength; - INLINEDEF _CUDA_HD Nd4jLong * createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank) { + traceNew(4); - traceNew(5); + Nd4jLong *ret = new Nd4jLong[shape::shapeInfoLength(rank)]; + return shapeInfoOnlyShapeAndStride(shapeInfo, dimension, dimensionLength, + reverseCopyStride, ret); +} - Nd4jLong *ret = new Nd4jLong[shape::shapeInfoLength(rank)]; +INLINEDEF _CUDA_HD Nd4jLong *createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, + int rank) { + traceNew(5); - return createShapeInfo(shape, stride, rank, ret); - } + Nd4jLong *ret = new Nd4jLong[shape::shapeInfoLength(rank)]; - INLINEDEF _CUDA_HD Nd4jLong * createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank, Nd4jLong *buffer) { - buffer[0] = rank; - Nd4jLong *retShape = shape::shapeOf(buffer); - Nd4jLong *retStride = shape::stride(buffer); - for(int i = 0;i < rank; i++) { - retShape[i] = shape[i]; - retStride[i] = stride[i]; - } + return createShapeInfo(shape, stride, rank, ret); +} - return buffer; - } +INLINEDEF _CUDA_HD Nd4jLong *createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, + int rank, Nd4jLong *buffer) { + buffer[0] = rank; + Nd4jLong *retShape = shape::shapeOf(buffer); + Nd4jLong *retStride = shape::stride(buffer); + for (int i = 0; i < rank; i++) { + retShape[i] = shape[i]; + retStride[i] = stride[i]; + } + + return buffer; +} /** * Computes the standard packed array strides for a given shape. @@ -1330,50 +1515,47 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, int startNum) { - if (isVector(shape, rank)) { +INLINEDEF _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, int rank, + int startNum) { + if (isVector(shape, rank)) { + traceNew(5); - traceNew(5); - - Nd4jLong *ret = new Nd4jLong[2]; - for (int i = 0; i < 2; i++) - ret[i] = 1; - return ret; - - } - - int dimensions = rank; + Nd4jLong *ret = new Nd4jLong[2]; + for (int i = 0; i < 2; i++) ret[i] = 1; + return ret; + } - traceNew(6); + int dimensions = rank; - Nd4jLong *stride = new Nd4jLong[dimensions]; - Nd4jLong st = startNum; - for (int j = 0; j < rank; j++) { - stride[j] = st; - st *= shape[j]; - } + traceNew(6); - return stride; - } + Nd4jLong *stride = new Nd4jLong[dimensions]; + Nd4jLong st = startNum; + for (int j = 0; j < rank; j++) { + stride[j] = st; + st *= shape[j]; + } - INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, int startNum, Nd4jLong *ret) { - if (isVector(shape, rank)) { - for (int i = 0; i < rank; i++) - ret[i] = 1; - return ret; + return stride; +} - } +INLINEDEF _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, int rank, + int startNum, Nd4jLong *ret) { + if (isVector(shape, rank)) { + for (int i = 0; i < rank; i++) ret[i] = 1; + return ret; + } - //int dimensions = rank; + // int dimensions = rank; - Nd4jLong st = startNum; - for (int j = 0; j < rank; j++) { - ret[j] = st; - st *= shape[j]; - } + Nd4jLong st = startNum; + for (int j = 0; j < rank; j++) { + ret[j] = st; + st *= shape[j]; + } - return ret; - } + return ret; +} /** * Computes the standard packed array strides for a given shape. @@ -1382,55 +1564,55 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - INLINEDEF _CUDA_HD Nd4jLong * calcStrides(Nd4jLong const *shape, int rank, int startNum) { - - traceNew(7); - - Nd4jLong *stride = new Nd4jLong[rank]; +INLINEDEF _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank, + int startNum) { + traceNew(7); - if (rank == 1) { - stride[0] = 1; - return stride; - } + Nd4jLong *stride = new Nd4jLong[rank]; + if (rank == 1) { + stride[0] = 1; + return stride; + } - // if (shape::isVector(shape, rank)) { - // for (int i = 0; i < 2; i++) - // stride[i] = 1; - // return stride; + // if (shape::isVector(shape, rank)) { + // for (int i = 0; i < 2; i++) + // stride[i] = 1; + // return stride; - // } + // } - Nd4jLong st = startNum; - for (int j = rank - 1; j >= 0; j--) { - stride[j] = st; - st *= shape[j]; - } + Nd4jLong st = startNum; + for (int j = rank - 1; j >= 0; j--) { + stride[j] = st; + st *= shape[j]; + } - return stride; - } + return stride; +} - INLINEDEF _CUDA_HD Nd4jLong * calcStrides(Nd4jLong const* shape, int rank, int startNum, Nd4jLong* ret) { - if (rank == 1) { - ret[0] = 1; - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank, + int startNum, Nd4jLong *ret) { + if (rank == 1) { + ret[0] = 1; + return ret; + } - // if (shape::isVector(shape, rank)) { - // for (int i = 0; i < 2; i++) - // ret[i] = 1; - // return ret; + // if (shape::isVector(shape, rank)) { + // for (int i = 0; i < 2; i++) + // ret[i] = 1; + // return ret; - // } + // } - Nd4jLong st = startNum; - for (int j = rank - 1; j >= 0; j--) { - ret[j] = st; - st *= shape[j]; - } + Nd4jLong st = startNum; + for (int j = rank - 1; j >= 0; j--) { + ret[j] = st; + st *= shape[j]; + } - return ret; - } + return ret; +} /** * Computes the standard packed array strides for a given shape. @@ -1439,13 +1621,15 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank) { - return calcStridesFortran(shape, rank, 1); - } +INLINEDEF _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, + int rank) { + return calcStridesFortran(shape, rank, 1); +} - INLINEDEF _CUDA_HD Nd4jLong * calcStridesFortran(Nd4jLong const* shape, int rank, Nd4jLong* ret) { - return calcStridesFortran(shape, rank, 1, ret); - } +INLINEDEF _CUDA_HD Nd4jLong *calcStridesFortran(Nd4jLong const *shape, int rank, + Nd4jLong *ret) { + return calcStridesFortran(shape, rank, 1, ret); +} /** * Computes the standard packed array strides for a given shape. @@ -1454,423 +1638,436 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) { * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - INLINEDEF _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank) { - return calcStrides(shape, rank, 1); - } +INLINEDEF _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank) { + return calcStrides(shape, rank, 1); +} - INLINEDEF _CUDA_HD Nd4jLong* calcStrides(Nd4jLong const *shape, int rank, Nd4jLong* ret) { - return calcStrides(shape, rank, 1, ret); - } +INLINEDEF _CUDA_HD Nd4jLong *calcStrides(Nd4jLong const *shape, int rank, + Nd4jLong *ret) { + return calcStrides(shape, rank, 1, ret); +} ////////////////////////////////////////////////////////////////////// - INLINEDEF _CUDA_HD void updateStrides(Nd4jLong *shapeInfo, const char order) { - int rank = shapeInfo[0]; - int doubleRank = 2*rank; - - if (rank > 0) { - if (order == 'c') { - shapeInfo[doubleRank] = 1; // set unity as last stride for c order - for (int j = 1; j < rank; ++j) { - shapeInfo[doubleRank - j] = shapeInfo[doubleRank - j + 1] * shapeInfo[rank + 1 - j]; - } - } else { - shapeInfo[rank + 1] = 1; // set unity as first stride for f order - for (int j = rank + 1; j < doubleRank; ++j) { - shapeInfo[j + 1] = shapeInfo[j] * shapeInfo[j - rank]; - } - } - } - // set last 2 elements in shapeInfo - shapeInfo[doubleRank + 2] = 1; - shapeInfo[doubleRank + 3] = (int)order; - } +INLINEDEF _CUDA_HD void updateStrides(Nd4jLong *shapeInfo, const char order) { + int rank = shapeInfo[0]; + int doubleRank = 2 * rank; + + if (rank > 0) { + if (order == 'c') { + shapeInfo[doubleRank] = 1; // set unity as last stride for c order + for (int j = 1; j < rank; ++j) { + shapeInfo[doubleRank - j] = + shapeInfo[doubleRank - j + 1] * shapeInfo[rank + 1 - j]; + } + } else { + shapeInfo[rank + 1] = 1; // set unity as first stride for f order + for (int j = rank + 1; j < doubleRank; ++j) { + shapeInfo[j + 1] = shapeInfo[j] * shapeInfo[j - rank]; + } + } + } + // set last 2 elements in shapeInfo + shapeInfo[doubleRank + 2] = 1; + shapeInfo[doubleRank + 3] = (int)order; +} ////////////////////////////////////////////////////////////////////// - INLINEDEF _CUDA_HD void updateStrides(const int rank, const Nd4jLong *shapeOnly, Nd4jLong *stridesOnly, const char order) { - - if (rank > 0) { - if (order == 'c') { - stridesOnly[rank - 1] = 1; // set unity as last stride for c order - for (int j = 1; j < rank; ++j) - stridesOnly[rank - 1 - j] = stridesOnly[rank - j] * shapeOnly[rank - j]; - } - else { - stridesOnly[0] = 1; // set unity as first stride for f order - for (int j = 1; j < rank; ++j) { - stridesOnly[j] = stridesOnly[j - 1] * shapeOnly[j - 1]; - } - } - } - } - - -// check whether input dimensions are permuted, not permuted dimensions order have to be 0,....,rank-1 - template - INLINEDEF _CUDA_HD bool isDimPermuted(const T* dimensions, const Nd4jLong dimSize ) { - for(int i=0; i dimensions[i+1]) - return true; +INLINEDEF _CUDA_HD void updateStrides(const int rank, const Nd4jLong *shapeOnly, + Nd4jLong *stridesOnly, const char order) { + if (rank > 0) { + if (order == 'c') { + stridesOnly[rank - 1] = 1; // set unity as last stride for c order + for (int j = 1; j < rank; ++j) + stridesOnly[rank - 1 - j] = stridesOnly[rank - j] * shapeOnly[rank - j]; + } else { + stridesOnly[0] = 1; // set unity as first stride for f order + for (int j = 1; j < rank; ++j) { + stridesOnly[j] = stridesOnly[j - 1] * shapeOnly[j - 1]; + } + } + } +} - return false; - } +// check whether input dimensions are permuted, not permuted dimensions order +// have to be 0,....,rank-1 +template +INLINEDEF _CUDA_HD bool isDimPermuted(const T *dimensions, + const Nd4jLong dimSize) { + for (int i = 0; i < dimSize - 1; ++i) + if (dimensions[i] > dimensions[i + 1]) return true; + return false; +} /** * @param toCopy the shape to copy * @return a copy of the original struct */ - INLINEDEF _CUDA_HD ShapeInformation *shapeCopy( ShapeInformation *toCopy) { - auto copy = new ShapeInformation; +INLINEDEF _CUDA_HD ShapeInformation *shapeCopy(ShapeInformation *toCopy) { + auto copy = new ShapeInformation; - traceNew(8); + traceNew(8); - copy->shape = new Nd4jLong[toCopy->rank]; + copy->shape = new Nd4jLong[toCopy->rank]; - memcpy(copy->shape, toCopy->shape, toCopy->rank * sizeof(Nd4jLong)); + memcpy(copy->shape, toCopy->shape, toCopy->rank * sizeof(Nd4jLong)); - traceNew(9); + traceNew(9); - copy->stride = new Nd4jLong[toCopy->rank]; - for (int i = 0; i < toCopy->rank; i++) { - copy->stride[i] = toCopy->stride[i]; - } - copy->order = toCopy->order; - copy->rank = toCopy->rank; - copy->offset = toCopy->offset; - copy->elementWiseStride = toCopy->elementWiseStride; - return copy; - } + copy->stride = new Nd4jLong[toCopy->rank]; + for (int i = 0; i < toCopy->rank; i++) { + copy->stride[i] = toCopy->stride[i]; + } + copy->order = toCopy->order; + copy->rank = toCopy->rank; + copy->offset = toCopy->offset; + copy->elementWiseStride = toCopy->elementWiseStride; + return copy; +} - INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder) { - if (rank == 0) - return 1; - - if(shape::isVector(shape,rank)) { - return stride[rank - 1]; +INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const *shape, + Nd4jLong const *stride, + int isFOrder) { + if (rank == 0) return 1; + + if (shape::isVector(shape, rank)) { + return stride[rank - 1]; + } + + else { + int oldnd; + Nd4jLong *oldDims = shape::copyOf(rank, shape); + Nd4jLong *oldStrides = shape::copyOf(rank, stride); + Nd4jLong np, op, last_stride; + Nd4jLong oldStart, oldStop, ok, newStart, newStop, nk; + + traceNew(10); + + auto newStrides = new Nd4jLong[rank]; + oldnd = 0; + // set the shape to be 1 x length + int newShapeRank = 2; + auto newShape = new Nd4jLong[newShapeRank]; + newShape[0] = 1; + newShape[1] = shape::prodLong(shape, rank); + + /* + * Remove axes with dimension 1 from the old array. They have no effect + * but would need special cases since their strides do not matter. + */ + for (oldStart = 0; oldStart < rank; oldStart++) { + if (shape[oldStart] != 1) { + oldDims[oldnd] = shape[oldStart]; + oldStrides[oldnd] = stride[oldStart]; + oldnd++; + } + } + + np = 1; + for (newStart = 0; newStart < newShapeRank; newStart++) { + np *= newShape[newStart]; + } + op = 1; + for (oldStart = 0; oldStart < oldnd; oldStart++) { + op *= oldDims[oldStart]; + } + if (np != op) { + /* different total sizes; no hope */ + delete[] newStrides; + delete[] newShape; + delete[] oldStrides; + delete[] oldDims; + return 0; + } + + if (np == 0) { + /* the current code does not handle 0-sized arrays, so give up */ + delete[] newStrides; + delete[] newShape; + delete[] oldStrides; + delete[] oldDims; + return 0; + } + + /* oldStart to oldStop and newStart to newStop give the axis ranges + * currently worked with */ + oldStart = 0; + oldStop = 1; + newStart = 0; + newStop = 1; + while (newStart < newShapeRank && oldStart < oldnd) { + np = newShape[newStart]; + op = oldDims[oldStart]; + + while (np != op) { + if (np < op) { + /* Misses trailing 1s, these are handled later */ + np *= newShape[newStop++]; + } else { + op *= oldDims[oldStop++]; } + } - else { - int oldnd; - Nd4jLong *oldDims = shape::copyOf(rank, shape); - Nd4jLong *oldStrides = shape::copyOf(rank, stride); - Nd4jLong np, op, last_stride; - Nd4jLong oldStart, oldStop, ok, newStart, newStop, nk; - - traceNew(10); - - auto newStrides = new Nd4jLong[rank]; - oldnd = 0; - //set the shape to be 1 x length - int newShapeRank = 2; - auto newShape = new Nd4jLong[newShapeRank]; - newShape[0] = 1; - newShape[1] = shape::prodLong(shape, rank); - - /* - * Remove axes with dimension 1 from the old array. They have no effect - * but would need special cases since their strides do not matter. - */ - for (oldStart = 0; oldStart < rank; oldStart++) { - if (shape[oldStart] != 1) { - oldDims[oldnd] = shape[oldStart]; - oldStrides[oldnd] = stride[oldStart]; - oldnd++; - } - } - - np = 1; - for (newStart = 0; newStart < newShapeRank; newStart++) { - np *= newShape[newStart]; - } - op = 1; - for (oldStart = 0; oldStart < oldnd; oldStart++) { - op *= oldDims[oldStart]; - } - if (np != op) { -/* different total sizes; no hope */ - delete[] newStrides; - delete[] newShape; - delete[] oldStrides; - delete[] oldDims; - return 0; - } - - if (np == 0) { -/* the current code does not handle 0-sized arrays, so give up */ - delete[] newStrides; - delete[] newShape; - delete[] oldStrides; - delete[] oldDims; - return 0; - } - -/* oldStart to oldStop and newStart to newStop give the axis ranges currently worked with */ - oldStart = 0; - oldStop = 1; - newStart = 0; - newStop = 1; - while (newStart < newShapeRank && oldStart < oldnd) { - np = newShape[newStart]; - op = oldDims[oldStart]; - - while (np != op) { - if (np < op) { -/* Misses trailing 1s, these are handled later */ - np *= newShape[newStop++]; - } else { - op *= oldDims[oldStop++]; - } - } - -/* Check whether the original axes can be combined */ - for (ok = oldStart; ok < oldStop - 1; ok++) { - if (isFOrder) { - if (oldStrides[ok + 1] != oldDims[ok] * oldStrides[ok]) { -/* not contiguous enough */ - delete[] newStrides; - delete[] newShape; - delete[] oldStrides; - delete[] oldDims; - return 0; - } - } else { -/* C order */ - if (oldStrides[ok] != oldDims[ok + 1] * oldStrides[ok + 1]) { -/* not contiguous enough */ - delete[] newStrides; - delete[] newShape; - delete[] oldStrides; - delete[] oldDims; - return 0; - } - } - } - -/* Calculate new strides for all axes currently worked with */ - if (isFOrder) { - newStrides[newStart] = oldStrides[oldStart]; - for (nk = newStart + 1; nk < newStop; nk++) { - newStrides[nk] = newStrides[nk - 1] * newShape[nk - 1]; - } - } else { -/* C order */ - newStrides[newStop - 1] = oldStrides[oldStop - 1]; - for (nk = newStop - 1; nk > newStart; nk--) { - newStrides[nk - 1] = newStrides[nk] * newShape[nk]; - } - } - newStart = newStop++; - oldStart = oldStop++; - } - -/* - * Set strides corresponding to trailing 1s of the new shape. - */ - if (newStart >= 1) { - last_stride = newStrides[newStart - 1]; - } else { - last_stride = stride[rank - 1]; - } - if (isFOrder) { - if (newStart >= 1) - last_stride *= newShape[newStart - 1]; - } - for (nk = newStart; nk < newShapeRank; nk++) { - newStrides[nk] = last_stride; - } -//returns the last element of the new stride array - int ret = last_stride; + /* Check whether the original axes can be combined */ + for (ok = oldStart; ok < oldStop - 1; ok++) { + if (isFOrder) { + if (oldStrides[ok + 1] != oldDims[ok] * oldStrides[ok]) { + /* not contiguous enough */ + delete[] newStrides; + delete[] newShape; + delete[] oldStrides; + delete[] oldDims; + return 0; + } + } else { + /* C order */ + if (oldStrides[ok] != oldDims[ok + 1] * oldStrides[ok + 1]) { + /* not contiguous enough */ delete[] newStrides; delete[] newShape; delete[] oldStrides; delete[] oldDims; - return ret; + return 0; + } } + } - + /* Calculate new strides for all axes currently worked with */ + if (isFOrder) { + newStrides[newStart] = oldStrides[oldStart]; + for (nk = newStart + 1; nk < newStop; nk++) { + newStrides[nk] = newStrides[nk - 1] * newShape[nk - 1]; + } + } else { + /* C order */ + newStrides[newStop - 1] = oldStrides[oldStop - 1]; + for (nk = newStop - 1; nk > newStart; nk--) { + newStrides[nk - 1] = newStrides[nk] * newShape[nk]; + } + } + newStart = newStop++; + oldStart = oldStop++; } - INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const* shape, Nd4jLong const* stride, int isFOrder, - Nd4jLong const* dimension, int dimensionLength) { - if(dimensionLength == 1) { - return stride[dimension[0]]; - } - return 0; + /* + * Set strides corresponding to trailing 1s of the new shape. + */ + if (newStart >= 1) { + last_stride = newStrides[newStart - 1]; + } else { + last_stride = stride[rank - 1]; + } + if (isFOrder) { + if (newStart >= 1) last_stride *= newShape[newStart - 1]; + } + for (nk = newStart; nk < newShapeRank; nk++) { + newStrides[nk] = last_stride; + } + // returns the last element of the new stride array + int ret = last_stride; + delete[] newStrides; + delete[] newShape; + delete[] oldStrides; + delete[] oldDims; + return ret; + } +} - } +INLINEDEF _CUDA_HD int computeElementWiseStride(int rank, Nd4jLong const *shape, + Nd4jLong const *stride, + int isFOrder, + Nd4jLong const *dimension, + int dimensionLength) { + if (dimensionLength == 1) { + return stride[dimension[0]]; + } + return 0; +} /** * Get the shape info buffer * for the given rank and shape. */ - INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape) { - Nd4jLong *stride = shape::calcStrides(shape, rank); - - traceNew(11); - - auto shapeInfo = new shape::ShapeInformation(); - shapeInfo->shape = const_cast(shape); - shapeInfo->stride = stride; - shapeInfo->offset = 0; - shapeInfo->rank = rank; - int elementWiseStride = shape::computeElementWiseStride(rank, shape, stride, 0); - shapeInfo->order = 'c'; - shapeInfo->elementWiseStride = elementWiseStride; - auto shapeInfoBuffer = shape::toShapeBuffer(shapeInfo); - delete[] stride; - delete shapeInfo; - sd::ArrayOptions::setDataType(shapeInfoBuffer, dtype); - return shapeInfoBuffer; - } - - /** - * This is special method, it returns ONLY 2D shapebuffer. - * - * This method is used only for SoftMax - */ - INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, Nd4jLong const* shape, Nd4jLong *buffer) { - Nd4jLong stride[MAX_RANK]; - shape::calcStrides(shape,rank, stride); - - - shape::ShapeInformation shapeInfo; - shapeInfo.shape = const_cast(shape); - shapeInfo.stride = stride; - shapeInfo.offset = 0; - shapeInfo.rank = rank; - auto elementWiseStride = shape::computeElementWiseStride(rank, shape, stride, 0); - - shapeInfo.order = 'c'; - shapeInfo.elementWiseStride = elementWiseStride; - shape::toShapeBuffer(&shapeInfo, buffer); - sd::ArrayOptions::setDataType(buffer, dtype); - return buffer; - } +INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, + Nd4jLong const *shape) { + Nd4jLong *stride = shape::calcStrides(shape, rank); + + traceNew(11); + + auto shapeInfo = new shape::ShapeInformation(); + shapeInfo->shape = const_cast(shape); + shapeInfo->stride = stride; + shapeInfo->offset = 0; + shapeInfo->rank = rank; + int elementWiseStride = + shape::computeElementWiseStride(rank, shape, stride, 0); + shapeInfo->order = 'c'; + shapeInfo->elementWiseStride = elementWiseStride; + auto shapeInfoBuffer = shape::toShapeBuffer(shapeInfo); + delete[] stride; + delete shapeInfo; + sd::ArrayOptions::setDataType(shapeInfoBuffer, dtype); + return shapeInfoBuffer; +} /** -* Get the shape info buffer -* for the given rank and shape. -*/ - INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const* shape) { - auto stride = shape::calcStridesFortran(shape,rank); - - traceNew(12); - - auto shapeInfo = new shape::ShapeInformation(); - shapeInfo->shape = const_cast(shape); - shapeInfo->stride = stride; - shapeInfo->offset = 0; - shapeInfo->rank = rank; - int elementWiseStride = shape::computeElementWiseStride(rank, shape, stride, 0); - - shapeInfo->order = 'f'; - shapeInfo->elementWiseStride = elementWiseStride; - auto shapeInfoBuffer = shape::toShapeBuffer(shapeInfo); - delete[] stride; - delete shapeInfo; - sd::ArrayOptions::setDataType(shapeInfoBuffer, dtype); - return shapeInfoBuffer; - } - - INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, Nd4jLong const *shape, Nd4jLong *output) { - Nd4jLong stride[MAX_RANK]; - shape::calcStridesFortran(shape,rank, stride); - + * This is special method, it returns ONLY 2D shapebuffer. + * + * This method is used only for SoftMax + */ +INLINEDEF _CUDA_HD Nd4jLong *shapeBuffer(int rank, sd::DataType dtype, + Nd4jLong const *shape, + Nd4jLong *buffer) { + Nd4jLong stride[MAX_RANK]; + shape::calcStrides(shape, rank, stride); + + shape::ShapeInformation shapeInfo; + shapeInfo.shape = const_cast(shape); + shapeInfo.stride = stride; + shapeInfo.offset = 0; + shapeInfo.rank = rank; + auto elementWiseStride = + shape::computeElementWiseStride(rank, shape, stride, 0); + + shapeInfo.order = 'c'; + shapeInfo.elementWiseStride = elementWiseStride; + shape::toShapeBuffer(&shapeInfo, buffer); + sd::ArrayOptions::setDataType(buffer, dtype); + return buffer; +} - shape::ShapeInformation shapeInfo; - shapeInfo.shape = const_cast(shape); - shapeInfo.stride = stride; - shapeInfo.offset = 0; - shapeInfo.rank = rank; - auto elementWiseStride = shape::computeElementWiseStride(rank, shape, stride, 0); +/** + * Get the shape info buffer + * for the given rank and shape. + */ +INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, + Nd4jLong const *shape) { + auto stride = shape::calcStridesFortran(shape, rank); + + traceNew(12); + + auto shapeInfo = new shape::ShapeInformation(); + shapeInfo->shape = const_cast(shape); + shapeInfo->stride = stride; + shapeInfo->offset = 0; + shapeInfo->rank = rank; + int elementWiseStride = + shape::computeElementWiseStride(rank, shape, stride, 0); + + shapeInfo->order = 'f'; + shapeInfo->elementWiseStride = elementWiseStride; + auto shapeInfoBuffer = shape::toShapeBuffer(shapeInfo); + delete[] stride; + delete shapeInfo; + sd::ArrayOptions::setDataType(shapeInfoBuffer, dtype); + return shapeInfoBuffer; +} - shapeInfo.order = 'f'; - shapeInfo.elementWiseStride = elementWiseStride; - shape::toShapeBuffer(&shapeInfo, output); - sd::ArrayOptions::setDataType(output, dtype); - return output; - } +INLINEDEF _CUDA_HD Nd4jLong *shapeBufferFortran(int rank, sd::DataType dtype, + Nd4jLong const *shape, + Nd4jLong *output) { + Nd4jLong stride[MAX_RANK]; + shape::calcStridesFortran(shape, rank, stride); + + shape::ShapeInformation shapeInfo; + shapeInfo.shape = const_cast(shape); + shapeInfo.stride = stride; + shapeInfo.offset = 0; + shapeInfo.rank = rank; + auto elementWiseStride = + shape::computeElementWiseStride(rank, shape, stride, 0); + + shapeInfo.order = 'f'; + shapeInfo.elementWiseStride = elementWiseStride; + shape::toShapeBuffer(&shapeInfo, output); + sd::ArrayOptions::setDataType(output, dtype); + return output; +} ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *indices) { - - Nd4jLong index, shift = 1;; - - index = indices[shapeInfo[0] - 1]; - for(uint i = shapeInfo[0]; i > 1; --i) { - shift *= shapeInfo[i]; - index += shift * indices[i - 2]; - } - - return index; +INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const Nd4jLong *indices) { + Nd4jLong index, shift = 1; + ; + + index = indices[shapeInfo[0] - 1]; + for (uint i = shapeInfo[0]; i > 1; --i) { + shift *= shapeInfo[i]; + index += shift * indices[i - 2]; + } + + return index; } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords) { - - Nd4jLong index, shift = 1;; - - index = coords[shapeInfo[0] - 1]; - for(uint i = shapeInfo[0]; i > 1; --i) { - shift *= shapeInfo[i]; - index += shift * coords[i - 2]; - } - - return index; +INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const int *coords) { + Nd4jLong index, shift = 1; + ; + + index = coords[shapeInfo[0] - 1]; + for (uint i = shapeInfo[0]; i > 1; --i) { + shift *= shapeInfo[i]; + index += shift * coords[i - 2]; + } + + return index; } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const uint *coords) { - - Nd4jLong index, shift = 1;; - - index = coords[shapeInfo[0] - 1]; - for(uint i = shapeInfo[0]; i > 1; --i) { - shift *= shapeInfo[i]; - index += shift * coords[i - 2]; - } - - return index; +INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const uint *coords) { + Nd4jLong index, shift = 1; + ; + + index = coords[shapeInfo[0] - 1]; + for (uint i = shapeInfo[0]; i > 1; --i) { + shift *= shapeInfo[i]; + index += shift * coords[i - 2]; + } + + return index; } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const int *indices) { - - Nd4jLong index, shift = 1;; - - index = indices[rank - 1]; - for(uint i = rank - 1; i >= 1; --i) { - shift *= shape[i]; - index += shift * indices[i - 1]; - } - - return index; +INLINEDEF _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, + const int *indices) { + Nd4jLong index, shift = 1; + ; + + index = indices[rank - 1]; + for (uint i = rank - 1; i >= 1; --i) { + shift *= shape[i]; + index += shift * indices[i - 1]; + } + + return index; } -INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords, const int dimsSize, const int* tadDims) { +INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, + const int *coords, const int dimsSize, + const int *tadDims) { + Nd4jLong index, shift = 1; + ; - Nd4jLong index, shift = 1;; + index = coords[tadDims[dimsSize - 1]]; + for (uint i = dimsSize - 1; i >= 1; --i) { + shift *= shapeInfo[tadDims[i]]; + index += shift * coords[i - 1]; + } - index = coords[tadDims[dimsSize - 1]]; - for(uint i = dimsSize - 1; i >= 1; --i) { - shift *= shapeInfo[tadDims[i]]; - index += shift * coords[i - 1]; - } - - return index; + return index; } template - INLINEDEF _CUDA_HD void fill(T* buffer, T value, Nd4jLong length) { - - PRAGMA_OMP_SIMD - for (int e = 0; e < length; e++) - buffer[e] = value; - } - +INLINEDEF _CUDA_HD void fill(T *buffer, T value, Nd4jLong length) { + PRAGMA_OMP_SIMD + for (int e = 0; e < length; e++) buffer[e] = value; +} // ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) { +// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong +// *shapeInfo, Nd4jLong arrLen) { // const Nd4jLong ews = shapeInfo[shapeInfo[0] + shapeInfo[0] + 2]; @@ -1892,7 +2089,8 @@ template // return offset; // } -// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, uint arrLen) { +// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, +// uint arrLen) { // const uint rank = shapeInfo[0]; // const uint ews = shapeInfo[rank + rank + 2]; @@ -1916,61 +2114,58 @@ template // } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo) { - - if (shapeInfo[2 * shapeInfo[0] + 3] == 99) { - - const Nd4jLong ews = shapeInfo[2 * shapeInfo[0] + 2]; - if (ews == 1) - return index; - else if(ews > 1) - return ews * index; - } - - Nd4jLong offset = 0; - - for(uint i = shapeInfo[0]; i > 1; --i) { - offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; - index /= shapeInfo[i]; - } - - offset += index * shapeInfo[1 + shapeInfo[0]]; // last iteration - - return offset; +INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, + const Nd4jLong *shapeInfo) { + if (shapeInfo[2 * shapeInfo[0] + 3] == 99) { + const Nd4jLong ews = shapeInfo[2 * shapeInfo[0] + 2]; + if (ews == 1) + return index; + else if (ews > 1) + return ews * index; + } + + Nd4jLong offset = 0; + + for (uint i = shapeInfo[0]; i > 1; --i) { + offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; + index /= shapeInfo[i]; + } + + offset += index * shapeInfo[1 + shapeInfo[0]]; // last iteration + + return offset; } ////////////////////////////////////////////////////////////////////// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo) { + if (shapeInfo[2 * shapeInfo[0] + 3] == 99) { + const Nd4jLong ews = shapeInfo[2 * shapeInfo[0] + 2]; + if (ews == 1) + return index; + else if (ews > 1) + return ews * index; + } - if (shapeInfo[2 * shapeInfo[0] + 3] == 99) { - - const Nd4jLong ews = shapeInfo[2 * shapeInfo[0] + 2]; - if (ews == 1) - return index; - else if(ews > 1) - return ews * index; - } - - uint offset = 0; + uint offset = 0; - for(uint i = shapeInfo[0]; i > 1; --i) { - offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; - index /= shapeInfo[i]; - } + for (uint i = shapeInfo[0]; i > 1; --i) { + offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; + index /= shapeInfo[i]; + } - offset += index * shapeInfo[1 + shapeInfo[0]]; // last iteration + offset += index * shapeInfo[1 + shapeInfo[0]]; // last iteration - return offset; + return offset; } - ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeInfo, const uint* uShapeInfo, const bool useUnsigned) { +INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, + const Nd4jLong *lShapeInfo, + const uint *uShapeInfo, + const bool useUnsigned) { + if (useUnsigned) return getIndexOffset(static_cast(index), uShapeInfo); - if(useUnsigned) - return getIndexOffset(static_cast(index), uShapeInfo); - - return getIndexOffset(index, lShapeInfo); + return getIndexOffset(index, lShapeInfo); } /** @@ -1980,14 +2175,15 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn * @param rearrange * @return */ - INLINEDEF _CUDA_HD Nd4jLong *doPermuteSwap(int length, Nd4jLong *shape, int *rearrange) { - traceNew(16); - Nd4jLong *ret = new Nd4jLong[length]; - for (int i = 0; i < length; i++) { - ret[i] = shape[rearrange[i]]; - } - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong *doPermuteSwap(int length, Nd4jLong *shape, + int *rearrange) { + traceNew(16); + Nd4jLong *ret = new Nd4jLong[length]; + for (int i = 0; i < length; i++) { + ret[i] = shape[rearrange[i]]; + } + return ret; +} /** * @@ -1996,125 +2192,128 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn * @param rearrange * @return */ - INLINEDEF _CUDA_HD void doPermuteSwap(int length, Nd4jLong **shape, int *rearrange) { - if(length == 1) { - return; - } - else { - Nd4jLong *shapeDeref = *shape; - if(shape::prodLong(shapeDeref,length) < 2) { - return; - } - } +INLINEDEF _CUDA_HD void doPermuteSwap(int length, Nd4jLong **shape, + int *rearrange) { + if (length == 1) { + return; + } else { + Nd4jLong *shapeDeref = *shape; + if (shape::prodLong(shapeDeref, length) < 2) { + return; + } + } + + bool inOrder = true; + for (int i = 0; i < length - 1; i++) { + inOrder = inOrder && rearrange[i] + 1 == rearrange[i + 1]; + } + + // all in order, nothing to do + if (inOrder) return; + + Nd4jLong *shapeDeref = *shape; + // we know they are just reversed, dimension length of 2 + if (length == 2) { + auto shapeFirst = shapeDeref[0]; + auto shapeSecond = shapeDeref[1]; + shapeDeref[0] = shapeSecond; + shapeDeref[1] = shapeFirst; + return; + } else if (length == 1) { + // no permute + return; + } + + auto temp = new Nd4jLong[length]; + memcpy(temp, shapeDeref, sizeof(Nd4jLong) * length); + for (int i = 0; i < length; i++) { + shapeDeref[i] = temp[rearrange[i]]; + } + + delete[] temp; +} - bool inOrder = true; - for(int i = 0; i < length - 1; i++) { - inOrder = inOrder && rearrange[i] + 1 == rearrange[i + 1]; +INLINEDEF _CUDA_HD void permuteShapeBufferInPlace(Nd4jLong *shapeBuffer, + int *rearrange, + Nd4jLong *out) { + if (shapeBuffer != out) + memcpy(out, shapeBuffer, + sizeof(Nd4jLong) * shape::shapeInfoLength(shapeBuffer)); - } + shape::doPermuteShapeInfo(out, rearrange); +} - //all in order, nothing to do - if(inOrder) - return; +INLINEDEF _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong const *shapeBuffer, + int *rearrange) { + auto len = shape::shapeInfoLength(shape::rank(shapeBuffer)); + Nd4jLong *copy = shape::copyOf(len, shapeBuffer); + shape::doPermuteShapeInfo(copy, rearrange); + return copy; +} +INLINEDEF _CUDA_HD void doPermuteShapeInfo(Nd4jLong *shapeInfo, + const int *rearrange, Nd4jLong len) { + if (len == -1) // calculate array length if it is not given + len = shape::length(shapeInfo); - Nd4jLong *shapeDeref = *shape; - //we know they are just reversed, dimension length of 2 - if(length == 2) { - auto shapeFirst = shapeDeref[0]; - auto shapeSecond = shapeDeref[1]; - shapeDeref[0] = shapeSecond; - shapeDeref[1] = shapeFirst; - return; - } - else if(length == 1) { - //no permute - return; - } + // check whether shape is like {1} or {1,1} or {1,1,1,1,...} - in this case we + // don't need permute + if (len == 1) return; - auto temp = new Nd4jLong[length]; - memcpy(temp,shapeDeref,sizeof(Nd4jLong) * length); - for (int i = 0; i < length; i++) { - shapeDeref[i] = temp[rearrange[i]]; - } + const int rank = shape::rank(shapeInfo); - delete[] temp; + // check whether rearrange is like {0,1,2,3,...} - in this case we don't need + // permute as well + bool isPermutNecessary = false; + for (int i = 0; i < rank; ++i) + if (rearrange[i] != i) { + isPermutNecessary = true; + break; } + if (!isPermutNecessary) return; - INLINEDEF _CUDA_HD void permuteShapeBufferInPlace(Nd4jLong *shapeBuffer, int *rearrange, Nd4jLong *out) { - if(shapeBuffer != out) - memcpy(out,shapeBuffer,sizeof(Nd4jLong) * shape::shapeInfoLength(shapeBuffer)); - - shape::doPermuteShapeInfo(out, rearrange); - } - - INLINEDEF _CUDA_HD Nd4jLong *permuteShapeBuffer(Nd4jLong const* shapeBuffer, int* rearrange) { - auto len = shape::shapeInfoLength(shape::rank(shapeBuffer)); - Nd4jLong *copy = shape::copyOf(len, shapeBuffer); - shape::doPermuteShapeInfo(copy,rearrange); - return copy; + // check whether rearrange contains correct indexes + for (int i = 0; i < rank; ++i) + if (rearrange[i] >= rank || rearrange[i] < 0) { + printf( + "shape::doPermuteShapeInfo function failed: rearrange indexes are " + "incorrect !\n"); + return; } - INLINEDEF _CUDA_HD void doPermuteShapeInfo(Nd4jLong *shapeInfo, const int *rearrange, Nd4jLong len) { - - if(len == -1) // calculate array length if it is not given - len = shape::length(shapeInfo); - - //check whether shape is like {1} or {1,1} or {1,1,1,1,...} - in this case we don't need permute - if(len == 1) - return; - - const int rank = shape::rank(shapeInfo); - - // check whether rearrange is like {0,1,2,3,...} - in this case we don't need permute as well - bool isPermutNecessary = false; - for(int i = 0; i < rank; ++i) - if(rearrange[i] != i) { - isPermutNecessary = true; - break; - } - - if(!isPermutNecessary) - return; - - // check whether rearrange contains correct indexes - for(int i = 0; i < rank; ++i) - if(rearrange[i] >= rank || rearrange[i] < 0) { - printf("shape::doPermuteShapeInfo function failed: rearrange indexes are incorrect !\n"); - return; - } - - // if everything is ok then perform permute - auto temp = new Nd4jLong[shape::shapeInfoLength(rank) - 3]; - memcpy(temp, shapeInfo, sizeof(Nd4jLong) * (shape::shapeInfoLength(rank) - 3)); - for (int i = 0; i < rank; ++i) { - shapeInfo[i + 1] = temp[rearrange[i] + 1]; - shapeInfo[i + 1 + rank] = temp[rearrange[i] + 1 + rank]; - } + // if everything is ok then perform permute + auto temp = new Nd4jLong[shape::shapeInfoLength(rank) - 3]; + memcpy(temp, shapeInfo, + sizeof(Nd4jLong) * (shape::shapeInfoLength(rank) - 3)); + for (int i = 0; i < rank; ++i) { + shapeInfo[i + 1] = temp[rearrange[i] + 1]; + shapeInfo[i + 1 + rank] = temp[rearrange[i] + 1 + rank]; + } - shape::checkStridesEwsAndOrder(shapeInfo); + shape::checkStridesEwsAndOrder(shapeInfo); - delete[] temp; - } + delete[] temp; +} +INLINEDEF _CUDA_HD Nd4jLong *createPermuteIndexes(int originalRank, + int *dimension, + int dimensionLength) { + int delta = originalRank - dimensionLength; - INLINEDEF _CUDA_HD Nd4jLong *createPermuteIndexes(int originalRank, int *dimension,int dimensionLength) { - int delta = originalRank - dimensionLength; + traceNew(17); - traceNew(17); + Nd4jLong *ret = new Nd4jLong[originalRank]; + for (int i = 0; i < delta; i++) { + ret[i] = i + dimensionLength; + } - Nd4jLong *ret = new Nd4jLong[originalRank]; - for(int i = 0; i < delta; i++) { - ret[i] = i + dimensionLength; - } + for (int i = delta; i < originalRank; i++) { + ret[i] = i - delta; + } - for(int i = delta; i < originalRank; i++) { - ret[i] = i - delta; - } - - return ret; - } + return ret; +} /** * Get the ordering for the device @@ -2124,56 +2323,50 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn * @param elementStride * @return */ - INLINEDEF _CUDA_HD char getOrder(int length, Nd4jLong *shape, Nd4jLong *stride, int elementStride) { - Nd4jLong sd = 1; - int dim = -1; - int i = -1; - int cContiguous = 1; - int isFortran = 1; - - for (i = length - 1; i >= 0; --i) { - dim = shape[i]; - - if (stride[i] != sd) { - cContiguous = 0; - break; - } - /* contiguous, if it got this far */ - if (dim == 0) { - break; - } - sd *= dim; - - } - - /* check if fortran contiguous */ - sd = elementStride; - for (i = 0; i < length; ++i) { - dim = shape[i]; - if (stride[i] != sd) { - isFortran = 0; - } - if (dim == 0) { - break; - } - sd *= dim; - - } - - if (isFortran && cContiguous) - return 'a'; - else if (isFortran && !cContiguous) - return 'f'; - else if (!isFortran && !cContiguous) - return 'c'; - else - return 'c'; - - } - - - - +INLINEDEF _CUDA_HD char getOrder(int length, Nd4jLong *shape, Nd4jLong *stride, + int elementStride) { + Nd4jLong sd = 1; + int dim = -1; + int i = -1; + int cContiguous = 1; + int isFortran = 1; + + for (i = length - 1; i >= 0; --i) { + dim = shape[i]; + + if (stride[i] != sd) { + cContiguous = 0; + break; + } + /* contiguous, if it got this far */ + if (dim == 0) { + break; + } + sd *= dim; + } + + /* check if fortran contiguous */ + sd = elementStride; + for (i = 0; i < length; ++i) { + dim = shape[i]; + if (stride[i] != sd) { + isFortran = 0; + } + if (dim == 0) { + break; + } + sd *= dim; + } + + if (isFortran && cContiguous) + return 'a'; + else if (isFortran && !cContiguous) + return 'f'; + else if (!isFortran && !cContiguous) + return 'c'; + else + return 'c'; +} /** * Ensure that every value in the re arrange @@ -2185,33 +2378,30 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn * @return */ - template - INLINEDEF _CUDA_HD int checkArrangeArray(T *arr, int arrLength, int shapeLength) { - if (arrLength != shapeLength) - return -1; - for (int i = 0; i < arrLength; i++) { - if (arr[i] >= arrLength || arr[i] < 0) - return -1; - } - - for (int i = 0; i < arrLength; i++) { - for (int j = 0; j < arrLength; j++) { - if (i != j && arr[i] == arr[j]) - return -1; - } - } +template +INLINEDEF _CUDA_HD int checkArrangeArray(T *arr, int arrLength, + int shapeLength) { + if (arrLength != shapeLength) return -1; + for (int i = 0; i < arrLength; i++) { + if (arr[i] >= arrLength || arr[i] < 0) return -1; + } - return 1; + for (int i = 0; i < arrLength; i++) { + for (int j = 0; j < arrLength; j++) { + if (i != j && arr[i] == arr[j]) return -1; } + } + return 1; +} - INLINEDEF _CUDA_HD void traceNew(int id) { - //printf("new happened: [%i]\n", id); +INLINEDEF _CUDA_HD void traceNew(int id){ +// printf("new happened: [%i]\n", id); #ifndef __CUDACC__ - //fflush(stdout); +// fflush(stdout); #endif - } +} /** * Permute the shape information @@ -2219,18 +2409,16 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn * @param rearrange the order to re arrange * @param rank the rank of the rearrange array */ - INLINEDEF _CUDA_HD void permute(ShapeInformation **info, int *rearrange, int rank) { - ShapeInformation *infoDeref = *info; - checkArrangeArray(rearrange, rank, rank); - shape::doPermuteSwap(rank, &infoDeref->shape, rearrange); - shape::doPermuteSwap(rank, &infoDeref->stride, rearrange); - char order = getOrder(rank, - infoDeref->shape, - infoDeref->stride, - infoDeref->elementWiseStride); - infoDeref->order = order; - - } +INLINEDEF _CUDA_HD + void permute(ShapeInformation **info, int *rearrange, int rank) { + ShapeInformation *infoDeref = *info; + checkArrangeArray(rearrange, rank, rank); + shape::doPermuteSwap(rank, &infoDeref->shape, rearrange); + shape::doPermuteSwap(rank, &infoDeref->stride, rearrange); + char order = getOrder(rank, infoDeref->shape, infoDeref->stride, + infoDeref->elementWiseStride); + infoDeref->order = order; +} /** * Returns whether the @@ -2238,182 +2426,176 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn * @param shape the shape of the array * @param rank the rank of the shape */ - INLINEDEF _CUDA_HD int isVector(Nd4jLong const* shape, int rank) { - if (rank == 0) - return 0; +INLINEDEF _CUDA_HD int isVector(Nd4jLong const *shape, int rank) { + if (rank == 0) return 0; - if (rank == 1) - return 1; + if (rank == 1) return 1; - if (rank > 2) - return 0; - else if (rank <= 2) { - if (shape[0] == 1 || shape[1] == 1) - return 1; - } - return 0; - } - - INLINEDEF _CUDA_HD bool isLikeVector(Nd4jLong const* shapeInfo, int& posOfNonUnityDim) { - - int numOfNonUnity = 0; - for(int i = 1; i <= shapeInfo[0]; ++i) { - if(shapeInfo[i] != 1) { - ++numOfNonUnity; - posOfNonUnityDim = i-1; - } - } + if (rank > 2) + return 0; + else if (rank <= 2) { + if (shape[0] == 1 || shape[1] == 1) return 1; + } + return 0; +} - return numOfNonUnity == 1 && shapeInfo[0] > 2; +INLINEDEF _CUDA_HD bool isLikeVector(Nd4jLong const *shapeInfo, + int &posOfNonUnityDim) { + int numOfNonUnity = 0; + for (int i = 1; i <= shapeInfo[0]; ++i) { + if (shapeInfo[i] != 1) { + ++numOfNonUnity; + posOfNonUnityDim = i - 1; } + } - INLINEDEF _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, int& posOfNonUnityDim) { + return numOfNonUnity == 1 && shapeInfo[0] > 2; +} - if(rank(shapeInfo) > 0 && length(shapeInfo) == 1) { - posOfNonUnityDim = -1; - return true; - } +INLINEDEF _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, + int &posOfNonUnityDim) { + if (rank(shapeInfo) > 0 && length(shapeInfo) == 1) { + posOfNonUnityDim = -1; + return true; + } - int numOfNonUnity = 0; - for(int i = 1; i <= shapeInfo[0]; ++i) { - if(shapeInfo[i] != 1) { - ++numOfNonUnity; - posOfNonUnityDim = i-1; - } - } - return numOfNonUnity == 1; + int numOfNonUnity = 0; + for (int i = 1; i <= shapeInfo[0]; ++i) { + if (shapeInfo[i] != 1) { + ++numOfNonUnity; + posOfNonUnityDim = i - 1; } + } + return numOfNonUnity == 1; +} - INLINEDEF _CUDA_H Nd4jLong const* detachShape(Nd4jLong const* originalShape) { - Nd4jLong *newShape = new Nd4jLong[shape::shapeInfoLength(originalShape)]; - memcpy(newShape, originalShape, shape::shapeInfoByteLength(originalShape)); - - return newShape; - } +INLINEDEF _CUDA_H Nd4jLong const *detachShape(Nd4jLong const *originalShape) { + Nd4jLong *newShape = new Nd4jLong[shape::shapeInfoLength(originalShape)]; + memcpy(newShape, originalShape, shape::shapeInfoByteLength(originalShape)); + return newShape; +} - INLINEDEF _CUDA_H Nd4jLong* copyShape(Nd4jLong const* originalShape) { - Nd4jLong *newShape = new Nd4jLong[shape::shapeInfoLength(originalShape)]; - memcpy(newShape, originalShape, shape::shapeInfoByteLength(originalShape)); +INLINEDEF _CUDA_H Nd4jLong *copyShape(Nd4jLong const *originalShape) { + Nd4jLong *newShape = new Nd4jLong[shape::shapeInfoLength(originalShape)]; + memcpy(newShape, originalShape, shape::shapeInfoByteLength(originalShape)); - return newShape; - } + return newShape; +} - INLINEDEF _CUDA_HD int isVector(const Nd4jLong *shapeInfo) { - return isVector(shape::shapeOf(const_cast(shapeInfo)), shape::rank(shapeInfo)); - } +INLINEDEF _CUDA_HD int isVector(const Nd4jLong *shapeInfo) { + return isVector(shape::shapeOf(const_cast(shapeInfo)), + shape::rank(shapeInfo)); +} - INLINEDEF _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo) { - bool isVector = shape::isVector(shapeInfo) == 1; - bool shapeFirstOne = shape::shapeOf(const_cast(shapeInfo))[0] == 1; - return isVector && shapeFirstOne; - } +INLINEDEF _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo) { + bool isVector = shape::isVector(shapeInfo) == 1; + bool shapeFirstOne = + shape::shapeOf(const_cast(shapeInfo))[0] == 1; + return isVector && shapeFirstOne; +} - INLINEDEF _CUDA_HD bool isColumnVector(const Nd4jLong *shapeInfo) { - bool isVector = shape::isVector(shapeInfo) == 1; - bool shapeFirstOne = shape::shapeOf(shapeInfo)[0] == 1; - return isVector && !shapeFirstOne; - } +INLINEDEF _CUDA_HD bool isColumnVector(const Nd4jLong *shapeInfo) { + bool isVector = shape::isVector(shapeInfo) == 1; + bool shapeFirstOne = shape::shapeOf(shapeInfo)[0] == 1; + return isVector && !shapeFirstOne; +} ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) { +INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, + const Nd4jLong *inShape) { + int num = 0; - int num = 0; + for (uint i = 0; i < rank; ++i) + if (inShape[i] != 1) ++num; - for(uint i = 0; i < rank; ++i) - if(inShape[i] != 1) - ++num; - - return num; + return num; } - INLINEDEF _CUDA_HD int oneDimEqualToLength(Nd4jLong *shape, int rank) { - for(int i = 0; i < rank; i++) { - if(shape[i] == shape::prodLong(shape,rank)) - return 1; - } +INLINEDEF _CUDA_HD int oneDimEqualToLength(Nd4jLong *shape, int rank) { + for (int i = 0; i < rank; i++) { + if (shape[i] == shape::prodLong(shape, rank)) return 1; + } - return 0; - } + return 0; +} - INLINEDEF _CUDA_HD int oneDimEqualToLength(Nd4jLong *shapeInfo) { - return oneDimEqualToLength(shape::shapeOf(shapeInfo),shape::rank(shapeInfo)); - } +INLINEDEF _CUDA_HD int oneDimEqualToLength(Nd4jLong *shapeInfo) { + return oneDimEqualToLength(shape::shapeOf(shapeInfo), shape::rank(shapeInfo)); +} /** -* Returns whether the -* given shape is a vector or not -* @param shape the shape of the array -* @param rank the rank of the shape -*/ - INLINEDEF _CUDA_HD int isMatrix(Nd4jLong *shape, int rank) { - if (rank > 2) - return 0; - else if (rank <= 2) { - if (shape[0] == 1 || shape[1] == 1) - return 0; - } - - return 1; - } + * Returns whether the + * given shape is a vector or not + * @param shape the shape of the array + * @param rank the rank of the shape + */ +INLINEDEF _CUDA_HD int isMatrix(Nd4jLong *shape, int rank) { + if (rank > 2) + return 0; + else if (rank <= 2) { + if (shape[0] == 1 || shape[1] == 1) return 0; + } + + return 1; +} - INLINEDEF _CUDA_HD int isMatrix(Nd4jLong *shapeInfo) { - return isMatrix(shape::shapeOf(shapeInfo),shape::rank(shapeInfo)); - } +INLINEDEF _CUDA_HD int isMatrix(Nd4jLong *shapeInfo) { + return isMatrix(shape::shapeOf(shapeInfo), shape::rank(shapeInfo)); +} /** * Returns the shape portion of an information * buffer */ - INLINEDEF _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *shapeInfo) { - - return shapeInfo + 1; - } - - INLINEDEF _CUDA_HD Nd4jLong *shapeOf(const Nd4jLong *shapeInfo) { +INLINEDEF _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *shapeInfo) { + return shapeInfo + 1; +} - return shape::shapeOf(const_cast(shapeInfo)); - } +INLINEDEF _CUDA_HD Nd4jLong *shapeOf(const Nd4jLong *shapeInfo) { + return shape::shapeOf(const_cast(shapeInfo)); +} /** * Return a copy of a buffer. * This buffer allocates memory * that must be freed elsewhere. */ - template - INLINEDEF _CUDA_HD T *copyOf(Nd4jLong length, T const* toCopy) { - traceNew(18); +template +INLINEDEF _CUDA_HD T *copyOf(Nd4jLong length, T const *toCopy) { + traceNew(18); - T *ret = new T[length]; - return copyOf(length, toCopy, ret); - } + T *ret = new T[length]; + return copyOf(length, toCopy, ret); +} - template - INLINEDEF _CUDA_HD T* copyOf(Nd4jLong length, T const* toCopy, T *ret) { - memcpy(ret, toCopy, sizeof(T)*length); - return ret; - } +template +INLINEDEF _CUDA_HD T *copyOf(Nd4jLong length, T const *toCopy, T *ret) { + memcpy(ret, toCopy, sizeof(T) * length); + return ret; +} /** -* Return a copy of a buffer. -* This buffer allocates memory -* that must be freed elsewhere. -*/ - template - INLINEDEF _CUDA_HD void copyTo(Nd4jLong length, T const* from, T *to) { - memcpy(to, from, sizeof(T)*length); - } + * Return a copy of a buffer. + * This buffer allocates memory + * that must be freed elsewhere. + */ +template +INLINEDEF _CUDA_HD void copyTo(Nd4jLong length, T const *from, T *to) { + memcpy(to, from, sizeof(T) * length); +} /** -* Return a copy of a buffer. -* This buffer allocates memory -* that must be freed elsewhere. -*/ - INLINEDEF _CUDA_HD void copyTo(int length, Nd4jLong const* from, Nd4jLong *to, Nd4jLong *indexes) { - for(int i = 0; i < length; i++) { - to[i] = from[indexes[i]]; - } - } + * Return a copy of a buffer. + * This buffer allocates memory + * that must be freed elsewhere. + */ +INLINEDEF _CUDA_HD void copyTo(int length, Nd4jLong const *from, Nd4jLong *to, + Nd4jLong *indexes) { + for (int i = 0; i < length; i++) { + to[i] = from[indexes[i]]; + } +} /** * Permute the given strides @@ -2424,93 +2606,88 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * and all must be filled in) * @return the rearranged array */ - /* - INLINEDEF _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int shapeRank, int *rearrange) { - Nd4jLong *strideCopy = copyOf(shapeRank, toPermute); - checkArrangeArray(rearrange, shapeRank, shapeRank); - Nd4jLong *newStride = doPermuteSwap(shapeRank, strideCopy, rearrange); - delete[] strideCopy; - return newStride; - } - */ +/* + INLINEDEF _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int + shapeRank, int *rearrange) { Nd4jLong *strideCopy = copyOf(shapeRank, + toPermute); checkArrangeArray(rearrange, shapeRank, shapeRank); Nd4jLong + *newStride = doPermuteSwap(shapeRank, strideCopy, rearrange); delete[] + strideCopy; return newStride; + } + */ /** * Return the slice (shape + 1 in pointer arithmetic) * @param shape the shape to take the slice of * @return the shape array - the first entry */ - INLINEDEF _CUDA_HD Nd4jLong *slice(Nd4jLong *shape) { - return shape + 1; - } - - INLINEDEF _CUDA_HD int slices(Nd4jLong *shapeBuffer) { - return static_cast(shape::shapeOf(shapeBuffer)[0]); - } - - - INLINEDEF _CUDA_HD Nd4jLong *sliceOfShapeBuffer(Nd4jLong sliceIdx, Nd4jLong *shapeBuffer) { - int rank = shape::rank(shapeBuffer); - int newRank = rank - 1; - if(newRank < 2) - newRank = 2; - Nd4jLong *newShapeBuffer = new Nd4jLong[shape::shapeInfoLength(newRank)]; - newShapeBuffer[0] = newRank; - Nd4jLong *currShape = shape::shapeOf(shapeBuffer); - Nd4jLong *currStride = shape::stride(shapeBuffer); - //initialize new shape and stride by taking the shape and stride + 1 - //and adding to the shape information - //a slice is always just taking the existing shape and cutting the first index off - //of the shape and stride - Nd4jLong *newShape = shape::shapeOf(newShapeBuffer); - Nd4jLong *newStride = shape::stride(newShapeBuffer); - if(shape::isVector(shapeBuffer)) { - Nd4jLong *currShape = shape::shapeOf(shapeBuffer); - //row vector: slice index 0 is a valid index, just copy the whole thing - if(currShape[0] == 1) { - if(sliceIdx == 0) { - memcpy(newShapeBuffer,shapeBuffer,shape::shapeInfoByteLength(shape::rank(shapeBuffer))); - return newShapeBuffer; - } - } - //column vector: this will be a scalar - else { - delete[] newShapeBuffer; - Nd4jLong *scalar = shape::createScalarShapeInfo(); - int offset = shape::offset(shapeBuffer); - scalar[shape::shapeInfoLength(2) - 3] = offset + sliceIdx; - return scalar; - } - } - else if(shape::isMatrix(shapeBuffer)) { - newShape[0] = 1; - newShape[1] = currShape[1]; - newStride[0] = 1; - newStride[1] = currStride[1]; - } - else { - for(int i = 0; i < newRank; i++) { - newShape[i] = currShape[i + 1]; - newStride[i] = currStride[i + 1]; - } - } +INLINEDEF _CUDA_HD Nd4jLong *slice(Nd4jLong *shape) { return shape + 1; } - auto indices = new Nd4jLong[rank]; - memset((void *) indices,0,rank * sizeof(Nd4jLong)); - indices[0] = sliceIdx; - Nd4jLong offset = shape::getOffset(newShapeBuffer, indices); - newShapeBuffer[shape::shapeInfoLength(newRank) - 3] = offset; - - // set current order and ews - newShapeBuffer[2 * newRank + 2] = shape::elementWiseStride(shapeBuffer); - newShapeBuffer[2 * newRank + 3] = shape::order(shapeBuffer); - - // correct order and ews if necessary - shape::checkStridesEwsAndOrder(newShapeBuffer); - - delete[] indices; +INLINEDEF _CUDA_HD int slices(Nd4jLong *shapeBuffer) { + return static_cast(shape::shapeOf(shapeBuffer)[0]); +} +INLINEDEF _CUDA_HD Nd4jLong *sliceOfShapeBuffer(Nd4jLong sliceIdx, + Nd4jLong *shapeBuffer) { + int rank = shape::rank(shapeBuffer); + int newRank = rank - 1; + if (newRank < 2) newRank = 2; + Nd4jLong *newShapeBuffer = new Nd4jLong[shape::shapeInfoLength(newRank)]; + newShapeBuffer[0] = newRank; + Nd4jLong *currShape = shape::shapeOf(shapeBuffer); + Nd4jLong *currStride = shape::stride(shapeBuffer); + // initialize new shape and stride by taking the shape and stride + 1 + // and adding to the shape information + // a slice is always just taking the existing shape and cutting the first + // index off of the shape and stride + Nd4jLong *newShape = shape::shapeOf(newShapeBuffer); + Nd4jLong *newStride = shape::stride(newShapeBuffer); + if (shape::isVector(shapeBuffer)) { + Nd4jLong *currShape = shape::shapeOf(shapeBuffer); + // row vector: slice index 0 is a valid index, just copy the whole thing + if (currShape[0] == 1) { + if (sliceIdx == 0) { + memcpy(newShapeBuffer, shapeBuffer, + shape::shapeInfoByteLength(shape::rank(shapeBuffer))); return newShapeBuffer; + } } + // column vector: this will be a scalar + else { + delete[] newShapeBuffer; + Nd4jLong *scalar = shape::createScalarShapeInfo(); + int offset = shape::offset(shapeBuffer); + scalar[shape::shapeInfoLength(2) - 3] = offset + sliceIdx; + return scalar; + } + } else if (shape::isMatrix(shapeBuffer)) { + newShape[0] = 1; + newShape[1] = currShape[1]; + newStride[0] = 1; + newStride[1] = currStride[1]; + } else { + for (int i = 0; i < newRank; i++) { + newShape[i] = currShape[i + 1]; + newStride[i] = currStride[i + 1]; + } + } + + auto indices = new Nd4jLong[rank]; + memset((void *)indices, 0, rank * sizeof(Nd4jLong)); + indices[0] = sliceIdx; + Nd4jLong offset = shape::getOffset(newShapeBuffer, indices); + newShapeBuffer[shape::shapeInfoLength(newRank) - 3] = offset; + + // set current order and ews + newShapeBuffer[2 * newRank + 2] = shape::elementWiseStride(shapeBuffer); + newShapeBuffer[2 * newRank + 3] = shape::order(shapeBuffer); + + // correct order and ews if necessary + shape::checkStridesEwsAndOrder(newShapeBuffer); + + delete[] indices; + + return newShapeBuffer; +} /** * Returns the length of the @@ -2520,48 +2697,46 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * info length for * @return rank * 2 + 4 */ - INLINEDEF _CUDA_HD int shapeInfoLength(int rank) { - //FIXME magic numbers - return rank * 2 + 4; - } +INLINEDEF _CUDA_HD int shapeInfoLength(int rank) { + // FIXME magic numbers + return rank * 2 + 4; +} - INLINEDEF _CUDA_HD int shapeInfoLength(Nd4jLong* shape) { - return shapeInfoLength(static_cast(shape[0])); - } +INLINEDEF _CUDA_HD int shapeInfoLength(Nd4jLong *shape) { + return shapeInfoLength(static_cast(shape[0])); +} - INLINEDEF _CUDA_HD int shapeInfoLength(const Nd4jLong* shape) { - return shapeInfoLength(static_cast(shape[0])); - } +INLINEDEF _CUDA_HD int shapeInfoLength(const Nd4jLong *shape) { + return shapeInfoLength(static_cast(shape[0])); +} - INLINEDEF _CUDA_HD size_t shapeInfoByteLength(int rank) { - //FIXME magic numbers - return (rank * 2 + 4) * sizeof(Nd4jLong); - } +INLINEDEF _CUDA_HD size_t shapeInfoByteLength(int rank) { + // FIXME magic numbers + return (rank * 2 + 4) * sizeof(Nd4jLong); +} - INLINEDEF _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong* shapeInfo) { - //FIXME magic numbers - return shapeInfoByteLength((int) shapeInfo[0]); - } +INLINEDEF _CUDA_HD size_t shapeInfoByteLength(const Nd4jLong *shapeInfo) { + // FIXME magic numbers + return shapeInfoByteLength((int)shapeInfo[0]); +} /** * Returns the rank portion of * an information buffer */ - INLINEDEF _CUDA_HD int rank(const Nd4jLong *buffer) { - return static_cast(buffer[0]); - } +INLINEDEF _CUDA_HD int rank(const Nd4jLong *buffer) { + return static_cast(buffer[0]); +} - INLINEDEF _CUDA_HD int rank(const int *buffer) { - return buffer[0]; - } +INLINEDEF _CUDA_HD int rank(const int *buffer) { return buffer[0]; } - INLINEDEF _CUDA_HD int rank(const unsigned int *buffer) { - return static_cast(buffer[0]); - } +INLINEDEF _CUDA_HD int rank(const unsigned int *buffer) { + return static_cast(buffer[0]); +} - INLINEDEF _CUDA_HD Nd4jLong* ews(Nd4jLong* shapeInfo) { - return shapeInfo + 2 * shapeInfo[0] + 2; - } +INLINEDEF _CUDA_HD Nd4jLong *ews(Nd4jLong *shapeInfo) { + return shapeInfo + 2 * shapeInfo[0] + 2; +} /** * Converts a raw int buffer of the layout: @@ -2573,216 +2748,211 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * * where shape and stride are both straight int pointers */ - INLINEDEF _CUDA_HD ShapeInformation *infoFromBuffer(Nd4jLong *buffer) { - - traceNew(19); - - auto info = new ShapeInformation; - auto length = shapeInfoLength(rank(buffer)); - auto rank = buffer[0]; - - //start after rank - info->shape = buffer + 1; - info->stride = buffer + (1 + rank); - info->rank = rank; - info->offset = buffer[length - 3]; - info->elementWiseStride = buffer[length - 2]; - Nd4jLong *stride = buffer + 1 + rank; - info->stride = stride; - info->order = (char) buffer[length - 1]; - return info; - } +INLINEDEF _CUDA_HD ShapeInformation *infoFromBuffer(Nd4jLong *buffer) { + traceNew(19); + + auto info = new ShapeInformation; + auto length = shapeInfoLength(rank(buffer)); + auto rank = buffer[0]; + + // start after rank + info->shape = buffer + 1; + info->stride = buffer + (1 + rank); + info->rank = rank; + info->offset = buffer[length - 3]; + info->elementWiseStride = buffer[length - 2]; + Nd4jLong *stride = buffer + 1 + rank; + info->stride = stride; + info->order = (char)buffer[length - 1]; + return info; +} /** * Returns the stride portion of an information * buffer */ - INLINEDEF _CUDA_HD Nd4jLong *stride(Nd4jLong *buffer) { - return buffer + (1 + rank(buffer)); - } - - INLINEDEF _CUDA_HD Nd4jLong *stride(const Nd4jLong *buffer) { - return stride(const_cast(buffer)); - } +INLINEDEF _CUDA_HD Nd4jLong *stride(Nd4jLong *buffer) { + return buffer + (1 + rank(buffer)); +} - INLINEDEF _CUDA_HD bool isEmpty(const Nd4jLong *shapeInfo) { - return ((shape::extra(const_cast(shapeInfo)) & ARRAY_EMPTY) == ARRAY_EMPTY); - } +INLINEDEF _CUDA_HD Nd4jLong *stride(const Nd4jLong *buffer) { + return stride(const_cast(buffer)); +} +INLINEDEF _CUDA_HD bool isEmpty(const Nd4jLong *shapeInfo) { + return ((shape::extra(const_cast(shapeInfo)) & ARRAY_EMPTY) == + ARRAY_EMPTY); +} /** * Compute the length of the given shape */ - INLINEDEF _CUDA_HD Nd4jLong length(const Nd4jLong *shapeInfo) { - - const int rank = shape::rank(shapeInfo); - - if (rank == 0) { - if (isEmpty(shapeInfo)) - return 0L; - return 1L; - } - - if (rank == 1) - return shapeInfo[1]; - - // if(shape::elementWiseStride(shapeInfo) == 1) { // contiguous - // if(shape::order(shapeInfo) == 'c') - // return shapeInfo[1] * shapeInfo[rank + 1]; // first dim * first stride - // return shapeInfo[rank] * shapeInfo[2 * rank]; // last dim * last stride - // } - - return shape::prodLong(shape::shapeOf(const_cast(shapeInfo)), rank); - } +INLINEDEF _CUDA_HD Nd4jLong length(const Nd4jLong *shapeInfo) { + const int rank = shape::rank(shapeInfo); + + if (rank == 0) { + if (isEmpty(shapeInfo)) return 0L; + return 1L; + } + + if (rank == 1) return shapeInfo[1]; + + // if(shape::elementWiseStride(shapeInfo) == 1) { // contiguous + // if(shape::order(shapeInfo) == 'c') + // return shapeInfo[1] * shapeInfo[rank + 1]; // first dim * + // first stride + // return shapeInfo[rank] * shapeInfo[2 * rank]; // last dim * last + // stride + // } + + return shape::prodLong(shape::shapeOf(const_cast(shapeInfo)), + rank); +} - INLINEDEF _CUDA_HD Nd4jLong length(std::initializer_list& shape) { - Nd4jLong ret = 1; - for (auto v : shape) { - ret *= v; - } - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong length(std::initializer_list &shape) { + Nd4jLong ret = 1; + for (auto v : shape) { + ret *= v; + } + return ret; +} - INLINEDEF _CUDA_HD Nd4jLong length(std::initializer_list& shape) { - Nd4jLong ret = 1; - for (auto v : shape) { - ret *= v; - } - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong length(std::initializer_list &shape) { + Nd4jLong ret = 1; + for (auto v : shape) { + ret *= v; + } + return ret; +} /*** * Returns the offset * portion of an information buffer */ - INLINEDEF _CUDA_HD Nd4jLong offset(Nd4jLong *buffer) { - return buffer[shape::shapeInfoLength(shape::rank(buffer)) - 3]; - } - - INLINEDEF _CUDA_HD Nd4jLong& extra(Nd4jLong *buffer) { - return buffer[shape::shapeInfoLength(shape::rank(buffer)) - 3]; - } +INLINEDEF _CUDA_HD Nd4jLong offset(Nd4jLong *buffer) { + return buffer[shape::shapeInfoLength(shape::rank(buffer)) - 3]; +} +INLINEDEF _CUDA_HD Nd4jLong &extra(Nd4jLong *buffer) { + return buffer[shape::shapeInfoLength(shape::rank(buffer)) - 3]; +} /** * Returns the ordering * for this shape information buffer */ - INLINEDEF _CUDA_HD char order(const Nd4jLong *buffer) { - //FIXME magic numbers - return static_cast(buffer[buffer[0] * 2 + 3]); - } +INLINEDEF _CUDA_HD char order(const Nd4jLong *buffer) { + // FIXME magic numbers + return static_cast(buffer[buffer[0] * 2 + 3]); +} /** * Returns type */ - INLINEDEF _CUDA_HD Nd4jLong type(const Nd4jLong *shapeInfo) { - return shapeInfo[2 * shapeInfo[0] + 1]; - } +INLINEDEF _CUDA_HD Nd4jLong type(const Nd4jLong *shapeInfo) { + return shapeInfo[2 * shapeInfo[0] + 1]; +} /** * Returns the element wise stride for this information * buffer */ - INLINEDEF _CUDA_HD Nd4jLong elementWiseStride(const Nd4jLong *buffer) { - return buffer[shapeInfoLength(static_cast(buffer[0])) - 2]; - } +INLINEDEF _CUDA_HD Nd4jLong elementWiseStride(const Nd4jLong *buffer) { + return buffer[shapeInfoLength(static_cast(buffer[0])) - 2]; +} /** -* Returns the element wise stride for this information -* buffer relative to a dimension and reduction index -*/ - INLINEDEF _CUDA_HD Nd4jLong reductionIndexElementWiseStride(Nd4jLong* buffer, int* dimension, int dimensionLength) { - if(dimensionLength > 1) { - if(shape::order(buffer) == 'f') { - /** - * The element wise stride belongs to a reduction index. - * When used out of order, we can get rid of the data - * dependencies and rely on using the max dimension - * specified for stride instead. - * Say we take the sum(0,1) along arr - * we can use arr.stride(1) as a representation - * along which to iterate. - */ - if(shape::shapeOf(buffer)[dimension[dimensionLength - 1]] != 1) { - //int tadElementWiseStride = shape::stride(buffer)[dimension[dimensionLength - 1]]; - //return tadElementWiseStride; - auto tadElementWiseStride = shape::stride(buffer)[dimension[0]]; - return tadElementWiseStride; - } - - return 1; - - } - else { - /** - * The element wise stride belongs to a reduction index. - * When used out of order, we can get rid of the data - * dependencies and rely on using the max dimension - * specified for stride instead. - * Say we take the sum(0,1) along arr - * we can use arr.stride(1) as a representation - * along which to iterate. - */ - if(shape::shapeOf(buffer)[dimension[dimensionLength - 1]] != 1) { - auto tadElementWiseStride = shape::stride(buffer)[dimension[dimensionLength - 1]]; - return tadElementWiseStride; - } - - return 1; - } - } - else { - if(shape::order(buffer) == 'f') { - /** - * The element wise stride belongs to a reduction index. - * When used out of order, we can get rid of the data - * dependencies and rely on using the max dimension - * specified for stride instead. - * Say we take the sum(0,1) along arr - * we can use arr.stride(1) as a representation - * along which to iterate. - */ - auto tadElementWiseStride = shape::stride(buffer)[dimension[0]]; - return tadElementWiseStride; - } - else { - /** - * The element wise stride belongs to a reduction index. - * When used out of order, we can get rid of the data - * dependencies and rely on using the max dimension - * specified for stride instead. - * Say we take the sum(0,1) along arr - * we can use arr.stride(1) as a representation - * along which to iterate. - */ - auto tadElementWiseStride = shape::stride(buffer)[dimension[dimensionLength - 1]]; - return tadElementWiseStride; - } - } - - } + * Returns the element wise stride for this information + * buffer relative to a dimension and reduction index + */ +INLINEDEF _CUDA_HD Nd4jLong reductionIndexElementWiseStride( + Nd4jLong *buffer, int *dimension, int dimensionLength) { + if (dimensionLength > 1) { + if (shape::order(buffer) == 'f') { + /** + * The element wise stride belongs to a reduction index. + * When used out of order, we can get rid of the data + * dependencies and rely on using the max dimension + * specified for stride instead. + * Say we take the sum(0,1) along arr + * we can use arr.stride(1) as a representation + * along which to iterate. + */ + if (shape::shapeOf(buffer)[dimension[dimensionLength - 1]] != 1) { + // int tadElementWiseStride = + // shape::stride(buffer)[dimension[dimensionLength - 1]]; return + // tadElementWiseStride; + auto tadElementWiseStride = shape::stride(buffer)[dimension[0]]; + return tadElementWiseStride; + } + + return 1; + + } else { + /** + * The element wise stride belongs to a reduction index. + * When used out of order, we can get rid of the data + * dependencies and rely on using the max dimension + * specified for stride instead. + * Say we take the sum(0,1) along arr + * we can use arr.stride(1) as a representation + * along which to iterate. + */ + if (shape::shapeOf(buffer)[dimension[dimensionLength - 1]] != 1) { + auto tadElementWiseStride = + shape::stride(buffer)[dimension[dimensionLength - 1]]; + return tadElementWiseStride; + } + + return 1; + } + } else { + if (shape::order(buffer) == 'f') { + /** + * The element wise stride belongs to a reduction index. + * When used out of order, we can get rid of the data + * dependencies and rely on using the max dimension + * specified for stride instead. + * Say we take the sum(0,1) along arr + * we can use arr.stride(1) as a representation + * along which to iterate. + */ + auto tadElementWiseStride = shape::stride(buffer)[dimension[0]]; + return tadElementWiseStride; + } else { + /** + * The element wise stride belongs to a reduction index. + * When used out of order, we can get rid of the data + * dependencies and rely on using the max dimension + * specified for stride instead. + * Say we take the sum(0,1) along arr + * we can use arr.stride(1) as a representation + * along which to iterate. + */ + auto tadElementWiseStride = + shape::stride(buffer)[dimension[dimensionLength - 1]]; + return tadElementWiseStride; + } + } +} /** * Returns whether * the given shape info buffer * represents a scalar shape */ - INLINEDEF _CUDA_HD int isScalar(const Nd4jLong *info) { +INLINEDEF _CUDA_HD int isScalar(const Nd4jLong *info) { + const int rank = shape::rank(info); - const int rank = shape::rank(info); + if (rank > 2) return 0; + if (rank == 0) return 1; + if (rank == 1) return shape::shapeOf(const_cast(info))[0] == 1; + if (rank == 2) + return shape::shapeOf(const_cast(info))[0] == 1 && + shape::shapeOf(const_cast(info))[1] == 1; - if(rank > 2) - return 0; - if(rank == 0) - return 1; - if(rank == 1) - return shape::shapeOf(const_cast(info))[0] == 1; - if(rank == 2) - return shape::shapeOf(const_cast(info))[0] == 1 && shape::shapeOf(const_cast(info))[1] == 1; - - return 0; - } + return 0; +} /** * Returns whether @@ -2790,19 +2960,15 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * represents a scalar * shape or not */ - INLINEDEF _CUDA_HD int isScalar(volatile ShapeInformation *info) { +INLINEDEF _CUDA_HD int isScalar(volatile ShapeInformation *info) { + const int rank = info->rank; - const int rank = info->rank; - - if(rank > 2) - return 0; - if(rank == 1) - return info->shape[0] == 1; - if(rank == 2) - return info->shape[0] == 1 && info->shape[1] == 1; + if (rank > 2) return 0; + if (rank == 1) return info->shape[0] == 1; + if (rank == 2) return info->shape[0] == 1 && info->shape[1] == 1; - return 0; - } + return 0; +} /** * Return a copy of this array with the @@ -2816,28 +2982,29 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * * item */ - template - INLINEDEF _CUDA_HD void removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength, T1 *ret) { - - int count = 0; - int absLength = dataLength - indexesLength; - for (int i = 0; i < dataLength && count < absLength; i++) { - int contains = 0; - for (int j = 0; j < indexesLength; j++) { - if (i == indexes[j]) { - contains = 1; - break; - } - } - - if (!contains) { - ret[count] = data[i]; - count++; - } - } - } +template +INLINEDEF _CUDA_HD void removeIndex(T1 const *data, T2 const *indexes, + Nd4jLong dataLength, Nd4jLong indexesLength, + T1 *ret) { + int count = 0; + int absLength = dataLength - indexesLength; + for (int i = 0; i < dataLength && count < absLength; i++) { + int contains = 0; + for (int j = 0; j < indexesLength; j++) { + if (i == indexes[j]) { + contains = 1; + break; + } + } + + if (!contains) { + ret[count] = data[i]; + count++; + } + } +} - /** +/** * Return a copy of this array with the * given index omitted * @@ -2849,46 +3016,50 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * * item */ - template - INLINEDEF _CUDA_HD T1* removeIndex(T1 const* data, T2 const* indexes, Nd4jLong dataLength, Nd4jLong indexesLength) { - auto lengthOfArr = dataLength - indexesLength; - if(lengthOfArr < 0) { - printf("Remove index call created a <= 0 length array. This was likely not intended."); - } - - auto ret = new T1[lengthOfArr]; - memset(ret,0,sizeof(T1) * lengthOfArr); - removeIndex(data, indexes, dataLength, indexesLength, ret); - return ret; - } - - INLINEDEF _CUDA_HD Nd4jLong* everyIndexBut(const Nd4jLong *indexes,int indexesLength,int begin,int end) { - int len = end - indexesLength; - - traceNew(20); - - auto ret = new Nd4jLong[len]; - int retIdx = 0; - //not here that we do 0 based indexing for end - this assumes things like: - //0 to 4 are specified - for(int i = begin; i < end ; i++) { - bool found = false; - for(int j = 0; j < indexesLength; j++) { - if(indexes[j] == i) { - found = true; - break; - } - } +template +INLINEDEF _CUDA_HD T1 *removeIndex(T1 const *data, T2 const *indexes, + Nd4jLong dataLength, + Nd4jLong indexesLength) { + auto lengthOfArr = dataLength - indexesLength; + if (lengthOfArr < 0) { + printf( + "Remove index call created a <= 0 length array. This was likely not " + "intended."); + } + + auto ret = new T1[lengthOfArr]; + memset(ret, 0, sizeof(T1) * lengthOfArr); + removeIndex(data, indexes, dataLength, indexesLength, ret); + return ret; +} - if(!found) { - ret[retIdx++] = i; - } +INLINEDEF _CUDA_HD Nd4jLong *everyIndexBut(const Nd4jLong *indexes, + int indexesLength, int begin, + int end) { + int len = end - indexesLength; - } + traceNew(20); - return ret; + auto ret = new Nd4jLong[len]; + int retIdx = 0; + // not here that we do 0 based indexing for end - this assumes things like: + // 0 to 4 are specified + for (int i = begin; i < end; i++) { + bool found = false; + for (int j = 0; j < indexesLength; j++) { + if (indexes[j] == i) { + found = true; + break; + } + } + if (!found) { + ret[retIdx++] = i; } + } + + return ret; +} /** * Computes the offset for accessing @@ -2896,8 +3067,8 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * and the offset to be read. */ #ifdef __CUDACC__ - INLINEDEF __device__ int tadOffset(ShapeInformation *xInfo, int offset) { - return offset + threadIdx.x * xInfo->elementWiseStride; +INLINEDEF __device__ int tadOffset(ShapeInformation *xInfo, int offset) { + return offset + threadIdx.x * xInfo->elementWiseStride; } #endif @@ -2909,21 +3080,21 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * for the shape to be returned as * @return the new shape */ - INLINEDEF _CUDA_HD Nd4jLong *ensureVectorShape(Nd4jLong *shape, int dimension) { - traceNew(21); +INLINEDEF _CUDA_HD Nd4jLong *ensureVectorShape(Nd4jLong *shape, int dimension) { + traceNew(21); - Nd4jLong *ret = new Nd4jLong[2]; + Nd4jLong *ret = new Nd4jLong[2]; - if (dimension == 0) { - ret[0] = 1; - ret[1] = shape[0]; - } else { - ret[0] = shape[0]; - ret[1] = 1; - } + if (dimension == 0) { + ret[0] = 1; + ret[1] = shape[0]; + } else { + ret[0] = shape[0]; + ret[1] = 1; + } - return ret; - } + return ret; +} /** * Returns a shape @@ -2933,99 +3104,96 @@ INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) * for the shape to be returned as * @return the new shape */ - INLINEDEF _CUDA_HD Nd4jLong *ensureVectorShape(Nd4jLong *shape) { - return ensureVectorShape(shape, 0); - } +INLINEDEF _CUDA_HD Nd4jLong *ensureVectorShape(Nd4jLong *shape) { + return ensureVectorShape(shape, 0); +} - /** - * This method does STRICT comparison for two shape buffers - * - * @param shape - * @return - */ - INLINEDEF _CUDA_HD bool equalsStrict(const Nd4jLong *shapeA, const Nd4jLong *shapeB) { - if (shapeA[0] != shapeB[0]) - return false; +/** + * This method does STRICT comparison for two shape buffers + * + * @param shape + * @return + */ +INLINEDEF _CUDA_HD bool equalsStrict(const Nd4jLong *shapeA, + const Nd4jLong *shapeB) { + if (shapeA[0] != shapeB[0]) return false; - if (shapeA[0] == 0) - return true; + if (shapeA[0] == 0) return true; - // we do full comparison here - int length = shape::shapeInfoLength(shapeA[0]); + // we do full comparison here + int length = shape::shapeInfoLength(shapeA[0]); - for (int e = 1; e < length; e++) - if (shapeA[e] != shapeB[e]) - return false; + for (int e = 1; e < length; e++) + if (shapeA[e] != shapeB[e]) return false; - return true; - } + return true; +} ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2) { - - if (shapeInfo1[0] != shapeInfo2[0]) - return false; +INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2) { + if (shapeInfo1[0] != shapeInfo2[0]) return false; - if (shapeInfo1[0] == 0) - return true; + if (shapeInfo1[0] == 0) return true; - for (uint e = 0; e < static_cast(shape::rank(shapeInfo1)); ++e) - if (shape::shapeOf(shapeInfo1)[e] != shape::shapeOf(shapeInfo2)[e] || shape::stride(shapeInfo1)[e] != shape::stride(shapeInfo2)[e]) - return false; + for (uint e = 0; e < static_cast(shape::rank(shapeInfo1)); ++e) + if (shape::shapeOf(shapeInfo1)[e] != shape::shapeOf(shapeInfo2)[e] || + shape::stride(shapeInfo1)[e] != shape::stride(shapeInfo2)[e]) + return false; - return true; + return true; } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3) { - - return shape::haveSameShapeAndStrides(shapeInfo1, shapeInfo2) && shape::haveSameShapeAndStrides(shapeInfo1, shapeInfo3); +INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, + const Nd4jLong *shapeInfo2, + const Nd4jLong *shapeInfo3) { + return shape::haveSameShapeAndStrides(shapeInfo1, shapeInfo2) && + shape::haveSameShapeAndStrides(shapeInfo1, shapeInfo3); +} +INLINEDEF _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim) { + if (0 == rank(shapeInfo)) return 1; + if (dim >= 0) + return shapeInfo[1 + dim]; + else + return shapeInfo[1 + (rank(shapeInfo) + dim)]; } - INLINEDEF _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim) { - if (0 == rank(shapeInfo)) - return 1; - if (dim >= 0) - return shapeInfo[1+dim]; - else - return shapeInfo[1+(rank(shapeInfo) + dim)]; - } - - INLINEDEF _CUDA_HD Nd4jLong strideAt(const Nd4jLong *shapeInfo, const int dim) { - if (0 == rank(shapeInfo)) - return 1; - if (dim >= 0) - return shapeInfo[1 + rank(shapeInfo) + dim]; - else - return shapeInfo[1 + 2*rank(shapeInfo) + dim]; - } - /** - * This method does SOFT comparison for two shape buffers, we compare only rank & shapes - * - * @param shape - * @return - */ - INLINEDEF _CUDA_HD bool equalsSoft(const Nd4jLong *shapeA, const Nd4jLong *shapeB) { - if (shapeA[0] != shapeB[0]) - return false; +INLINEDEF _CUDA_HD Nd4jLong strideAt(const Nd4jLong *shapeInfo, const int dim) { + if (0 == rank(shapeInfo)) return 1; + if (dim >= 0) + return shapeInfo[1 + rank(shapeInfo) + dim]; + else + return shapeInfo[1 + 2 * rank(shapeInfo) + dim]; +} - if (shapeA[0] == 0) - return true; +/** + * This method does SOFT comparison for two shape buffers, we compare only rank + * & shapes + * + * @param shape + * @return + */ +INLINEDEF _CUDA_HD bool equalsSoft(const Nd4jLong *shapeA, + const Nd4jLong *shapeB) { + if (shapeA[0] != shapeB[0]) return false; - // we compare only shapes, and ignoring stride & ews - auto length = shapeA[0]; + if (shapeA[0] == 0) return true; - for (int e = 1; e <= length; e++) - if (shapeA[e] != shapeB[e]) - return false; + // we compare only shapes, and ignoring stride & ews + auto length = shapeA[0]; - return true; - } + for (int e = 1; e <= length; e++) + if (shapeA[e] != shapeB[e]) return false; - INLINEDEF _CUDA_HD bool equalsTypesAndShapesSoft(const Nd4jLong *shapeA, const Nd4jLong *shapeB) { + return true; +} - return equalsSoft(shapeA, shapeB) && shapeA[shapeInfoLength(shapeA) - 3] == shapeB[shapeInfoLength(shapeB) - 3]; - } +INLINEDEF _CUDA_HD bool equalsTypesAndShapesSoft(const Nd4jLong *shapeA, + const Nd4jLong *shapeB) { + return equalsSoft(shapeA, shapeB) && shapeA[shapeInfoLength(shapeA) - 3] == + shapeB[shapeInfoLength(shapeB) - 3]; +} /** * Generate an int buffer @@ -3033,36 +3201,34 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * at the specified increment * */ - template - INLINEDEF _CUDA_HD T* range(int from, int to, int increment) { - int diff = sd::math::nd4j_abs(from - to); - int retLength = diff / increment; - T *ret; - - traceNew(22); - - if(diff / increment < 1) - ret = new T[1]; - else - ret = new T[diff / increment]; - if (from < to) { - int count = 0; - for (int i = from; i < to; i += increment) { - if (count >= retLength) - break; - ret[count++] = i; - } - } else if (from > to) { - int count = 0; - for (int i = from - 1; i >= to; i -= increment) { - if (count >= retLength) - break; - ret[count++] = i; - } - } - - return ret; - } +template +INLINEDEF _CUDA_HD T *range(int from, int to, int increment) { + int diff = sd::math::nd4j_abs(from - to); + int retLength = diff / increment; + T *ret; + + traceNew(22); + + if (diff / increment < 1) + ret = new T[1]; + else + ret = new T[diff / increment]; + if (from < to) { + int count = 0; + for (int i = from; i < to; i += increment) { + if (count >= retLength) break; + ret[count++] = i; + } + } else if (from > to) { + int count = 0; + for (int i = from - 1; i >= to; i -= increment) { + if (count >= retLength) break; + ret[count++] = i; + } + } + + return ret; +} /** * Generate a range @@ -3073,10 +3239,10 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * @return the int array starting at from and ending at to */ - template - INLINEDEF _CUDA_HD T* range(int from, int to) { - return range(from, to, 1); - } +template +INLINEDEF _CUDA_HD T *range(int from, int to) { + return range(from, to, 1); +} /** * Keep the given indexes in the data @@ -3086,71 +3252,67 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * @param dataLength * @return */ - INLINEDEF _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int const* index, int indexLength, int dataLength) { - - traceNew(23); - - Nd4jLong *ret = new Nd4jLong[indexLength]; - int count = 0; - for (int i = 0; i < dataLength; i++) { - int contains = 0; - for (int j = 0; j < indexLength; j++) { - if (i == index[j]) { - contains = 1; - break; - } - } - - if (contains) - ret[count++] = data[i]; - } - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong *keep(volatile Nd4jLong *data, int const *index, + int indexLength, int dataLength) { + traceNew(23); + + Nd4jLong *ret = new Nd4jLong[indexLength]; + int count = 0; + for (int i = 0; i < dataLength; i++) { + int contains = 0; + for (int j = 0; j < indexLength; j++) { + if (i == index[j]) { + contains = 1; + break; + } + } + + if (contains) ret[count++] = data[i]; + } + return ret; +} /** * Generate a reverse * copy of the data */ - template - INLINEDEF _CUDA_HD T* reverseCopy(T const* data, Nd4jLong length) { - if (length < 1) - return nullptr; - - traceNew(24); - - T *copy = new T[length]; - for (Nd4jLong i = 0; i <= length / 2; i++) { - T temp = data[i]; - copy[i] = data[length - i - 1]; - copy[length - i - 1] = temp; - } - return copy; - } - - template - INLINEDEF _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong length) { - if (length < 1) - return; - for (Nd4jLong i = 0; i <= length / 2; i++) { - T temp = from[i]; - to[i] = from[length - i - 1]; - to[length - i - 1] = temp; - } - } - - template - INLINEDEF _CUDA_HD void reverseCopyTo(T const* from, T *to, Nd4jLong *indexes, Nd4jLong length) { - if (length < 1) - return; +template +INLINEDEF _CUDA_HD T *reverseCopy(T const *data, Nd4jLong length) { + if (length < 1) return nullptr; + + traceNew(24); + + T *copy = new T[length]; + for (Nd4jLong i = 0; i <= length / 2; i++) { + T temp = data[i]; + copy[i] = data[length - i - 1]; + copy[length - i - 1] = temp; + } + return copy; +} - for (Nd4jLong i = 0; i <= length / 2; i++) { - T temp = from[indexes[i]]; - to[i] = from[indexes[length - i - 1]]; - to[length - i - 1] = temp; - } +template +INLINEDEF _CUDA_HD void reverseCopyTo(T const *from, T *to, Nd4jLong length) { + if (length < 1) return; + for (Nd4jLong i = 0; i <= length / 2; i++) { + T temp = from[i]; + to[i] = from[length - i - 1]; + to[length - i - 1] = temp; + } +} - } +template +INLINEDEF _CUDA_HD void reverseCopyTo(T const *from, T *to, Nd4jLong *indexes, + Nd4jLong length) { + if (length < 1) return; + + for (Nd4jLong i = 0; i <= length / 2; i++) { + T temp = from[indexes[i]]; + to[i] = from[indexes[length - i - 1]]; + to[length - i - 1] = temp; + } +} /** * @@ -3160,16 +3322,16 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * @param arr2Length * @return */ - template - INLINEDEF _CUDA_HD T* concat(T const* arr1, Nd4jLong const arr1Length, T const* arr2, Nd4jLong const arr2Length) { - - traceNew(25); - - T *ret = new T[arr1Length + arr2Length]; - std::memcpy(ret, arr1, arr1Length * sizeof(T)); - std::memcpy(ret + arr1Length, arr2, arr2Length * sizeof(T)); - return ret; - } +template +INLINEDEF _CUDA_HD T *concat(T const *arr1, Nd4jLong const arr1Length, + T const *arr2, Nd4jLong const arr2Length) { + traceNew(25); + + T *ret = new T[arr1Length + arr2Length]; + std::memcpy(ret, arr1, arr1Length * sizeof(T)); + std::memcpy(ret + arr1Length, arr2, arr2Length * sizeof(T)); + return ret; +} /** * @@ -3179,20 +3341,21 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * @param lengths * @return */ - template - INLINEDEF _CUDA_HD T *concat(Nd4jLong const numArrays, Nd4jLong const numTotalElements, T const **arr, Nd4jLong const *lengths) { - - T* ret = new T[numTotalElements]; - Nd4jLong count = 0; - - for (Nd4jLong i = 0; i < numArrays; i++) { - for (Nd4jLong j = 0; j < lengths[i]; j++) { - ret[count++] = arr[i][j]; - } - } +template +INLINEDEF _CUDA_HD T *concat(Nd4jLong const numArrays, + Nd4jLong const numTotalElements, T const **arr, + Nd4jLong const *lengths) { + T *ret = new T[numTotalElements]; + Nd4jLong count = 0; - return ret; + for (Nd4jLong i = 0; i < numArrays; i++) { + for (Nd4jLong j = 0; j < lengths[i]; j++) { + ret[count++] = arr[i][j]; } + } + + return ret; +} /** * Get the length per slice of the @@ -3206,22 +3369,24 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * @return the length per slice of the given shape * along the given dimension */ - INLINEDEF _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong const* shape, int const* dimension, int dimensionLength) { - if(shape::isVector(shape,rank)) { - //return total length for row vectors - if(dimensionLength == 1 && shape[0] == 1) { - return shape::prodLong(shape,rank); - } - } - else if(rank == dimensionLength) - return shape::prodLong(shape,rank); - int absSelta = sd::math::nd4j_abs(rank - dimensionLength); - traceNew(27); - auto ret2 = shape::removeIndex(shape, dimension, rank, dimensionLength); - auto ret = prodLong(ret2, absSelta); - delete[] ret2; - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong lengthPerSlice(int rank, Nd4jLong const *shape, + int const *dimension, + int dimensionLength) { + if (shape::isVector(shape, rank)) { + // return total length for row vectors + if (dimensionLength == 1 && shape[0] == 1) { + return shape::prodLong(shape, rank); + } + } else if (rank == dimensionLength) + return shape::prodLong(shape, rank); + int absSelta = sd::math::nd4j_abs(rank - dimensionLength); + traceNew(27); + auto ret2 = + shape::removeIndex(shape, dimension, rank, dimensionLength); + auto ret = prodLong(ret2, absSelta); + delete[] ret2; + return ret; +} /** * calculates the offset for a tensor @@ -3230,18 +3395,21 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * @param tensorShape * @return */ - INLINEDEF _CUDA_HD Nd4jLong sliceOffsetForTensor(int rank, int index, Nd4jLong const* shape, Nd4jLong const* tensorShape, int tensorShapeLength, int const* dimension, int dimensionLength) { - auto tensorLength = prodLong(tensorShape, tensorShapeLength); - auto lengthPerSlice2 = lengthPerSlice(rank, shape, dimension, dimensionLength); - if (lengthPerSlice2 <= 0) { - return 0; - } - - Nd4jLong offset = index * tensorLength / lengthPerSlice2; - return offset; - } +INLINEDEF _CUDA_HD Nd4jLong sliceOffsetForTensor( + int rank, int index, Nd4jLong const *shape, Nd4jLong const *tensorShape, + int tensorShapeLength, int const *dimension, int dimensionLength) { + auto tensorLength = prodLong(tensorShape, tensorShapeLength); + auto lengthPerSlice2 = + lengthPerSlice(rank, shape, dimension, dimensionLength); + if (lengthPerSlice2 <= 0) { + return 0; + } + + Nd4jLong offset = index * tensorLength / lengthPerSlice2; + return offset; +} - /** +/** * calculates the offset for a tensor * @param index * @param arr @@ -3249,106 +3417,105 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons * @return */ - INLINEDEF _CUDA_HD Nd4jLong sliceOffsetForTensor(int index,int tensorLength,int lengthPerSlice2) { - Nd4jLong offset = index * tensorLength / lengthPerSlice2; - return offset; - } - +INLINEDEF _CUDA_HD Nd4jLong sliceOffsetForTensor(int index, int tensorLength, + int lengthPerSlice2) { + Nd4jLong offset = index * tensorLength / lengthPerSlice2; + return offset; +} #ifdef __CUDACC__ /** -* Computes the offset for accessing -* a global element given the shape information -* and the offset to be read. -*/ - INLINEDEF _CUDA_D int tadOffset(Nd4jLong *xInfo, int offset) { - return offset + threadIdx.x * elementWiseStride(xInfo); - + * Computes the offset for accessing + * a global element given the shape information + * and the offset to be read. + */ +INLINEDEF _CUDA_D int tadOffset(Nd4jLong *xInfo, int offset) { + return offset + threadIdx.x * elementWiseStride(xInfo); } #endif - - - - /** * Computes the number * of tensors along * a given dimension */ - INLINEDEF _CUDA_HD Nd4jLong tensorsAlongDimension(volatile int rank, volatile int length, - volatile Nd4jLong *shape, int *dimension, int dimensionLength) { - Nd4jLong *tensorShape = shape::keep(shape, dimension, dimensionLength, rank); - Nd4jLong ret = length / shape::prodLong(tensorShape, dimensionLength); - delete[] tensorShape; - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong tensorsAlongDimension(volatile int rank, + volatile int length, + volatile Nd4jLong *shape, + int *dimension, + int dimensionLength) { + Nd4jLong *tensorShape = shape::keep(shape, dimension, dimensionLength, rank); + Nd4jLong ret = length / shape::prodLong(tensorShape, dimensionLength); + delete[] tensorShape; + return ret; +} /** * Computes the number * of tensors along * a given dimension */ - INLINEDEF _CUDA_HD Nd4jLong tensorsAlongDimension(Nd4jLong *shapeInfo, int *dimension, int dimensionLength) { - Nd4jLong *keepShape = shape::shapeOf(shapeInfo); - Nd4jLong *tensorShape = shape::keep(keepShape, dimension, dimensionLength, rank(shapeInfo)); - Nd4jLong ret = shape::length(shapeInfo) / shape::prodLong(tensorShape, dimensionLength); - delete[] tensorShape; - return ret; - } - - - +INLINEDEF _CUDA_HD Nd4jLong tensorsAlongDimension(Nd4jLong *shapeInfo, + int *dimension, + int dimensionLength) { + Nd4jLong *keepShape = shape::shapeOf(shapeInfo); + Nd4jLong *tensorShape = + shape::keep(keepShape, dimension, dimensionLength, rank(shapeInfo)); + Nd4jLong ret = + shape::length(shapeInfo) / shape::prodLong(tensorShape, dimensionLength); + delete[] tensorShape; + return ret; +} /** -* Get an offset for retrieval -* from a data buffer -* based on the given -* shape stride and given indices -* @param baseOffset the offset to start from -* @param shape the shape of the array -* @param stride the stride of the array -* @param indices the indices to iterate over -* @return the double at the specified index -*/ + * Get an offset for retrieval + * from a data buffer + * based on the given + * shape stride and given indices + * @param baseOffset the offset to start from + * @param shape the shape of the array + * @param stride the stride of the array + * @param indices the indices to iterate over + * @return the double at the specified index + */ ////////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *indices, Nd4jLong baseOffset) { - - Nd4jLong offset = baseOffset; +INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, + const Nd4jLong *indices, + Nd4jLong baseOffset) { + Nd4jLong offset = baseOffset; - for(uint i = 1; i <= shapeInfo[0]; ++i) - if(shapeInfo[i] != 1) - offset += indices[i - 1] * shapeInfo[shapeInfo[0] + i]; + for (uint i = 1; i <= shapeInfo[0]; ++i) + if (shapeInfo[i] != 1) + offset += indices[i - 1] * shapeInfo[shapeInfo[0] + i]; - return offset; + return offset; } ////////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset) { +INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, + const int *coords, Nd4jLong baseOffset) { + Nd4jLong offset = baseOffset; - Nd4jLong offset = baseOffset; + for (uint i = 1; i <= shapeInfo[0]; ++i) + if (shapeInfo[i] != 1) + offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i]; - for(uint i = 1; i <= shapeInfo[0]; ++i) - if(shapeInfo[i] != 1) - offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i]; - - return offset; + return offset; } ////////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset) { - - Nd4jLong offset = baseOffset; +INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, + const uint *coords, Nd4jLong baseOffset) { + Nd4jLong offset = baseOffset; - for(uint i = 1; i <= shapeInfo[0]; ++i) - if(shapeInfo[i] != 1) - offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i]; + for (uint i = 1; i <= shapeInfo[0]; ++i) + if (shapeInfo[i] != 1) + offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i]; - return offset; + return offset; } - /** * Returns the tensor along dimension * for the given block index @@ -3357,194 +3524,196 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coo * @param i * @return */ - INLINEDEF _CUDA_HD int tadForBlockIndex(int blockSize, int blockIdx, int i) { - return blockIdx + i * blockSize; - } +INLINEDEF _CUDA_HD int tadForBlockIndex(int blockSize, int blockIdx, int i) { + return blockIdx + i * blockSize; +} /** * Computes the number of tads per block * */ - INLINEDEF _CUDA_HD int tadsPerBlock(int blockSize, int tads) { - return sd::math::nd4j_ceil(tads / (double) blockSize); - } +INLINEDEF _CUDA_HD int tadsPerBlock(int blockSize, int tads) { + return sd::math::nd4j_ceil(tads / (double)blockSize); +} /** * Returns a shape buffer * for the shape information metadata. */ - INLINEDEF _CUDA_HD Nd4jLong *toShapeBuffer( ShapeInformation *info) { +INLINEDEF _CUDA_HD Nd4jLong *toShapeBuffer(ShapeInformation *info) { + traceNew(29); - traceNew(29); + auto ret = new Nd4jLong[shapeInfoLength(info->rank)]; + int count = 1; + int rank = info->rank; - auto ret = new Nd4jLong[shapeInfoLength(info->rank)]; - int count = 1; - int rank = info->rank; + ret[0] = info->rank; - ret[0] = info->rank; + for (int i = 0; i < rank; i++) { + ret[count++] = info->shape[i]; + } - for (int i = 0; i < rank; i++) { - ret[count++] = info->shape[i]; - } + for (int i = 0; i < rank; i++) { + ret[count++] = info->stride[i]; + } - for (int i = 0; i < rank; i++) { - ret[count++] = info->stride[i]; - } + ret[count++] = info->offset; + ret[count++] = info->elementWiseStride; + ret[count] = info->order; - ret[count++] = info->offset; - ret[count++] = info->elementWiseStride; - ret[count] = info->order; + return ret; +} - return ret; - } +INLINEDEF _CUDA_HD Nd4jLong *toShapeBuffer(ShapeInformation *info, + Nd4jLong *ret) { + int count = 1; + int rank = info->rank; - INLINEDEF _CUDA_HD Nd4jLong *toShapeBuffer( ShapeInformation *info, Nd4jLong* ret) { + ret[0] = info->rank; - int count = 1; - int rank = info->rank; + if (ret[0] == 0) { + ret[1] = 0; + ret[2] = 1; + ret[3] = 99; + return ret; + } - ret[0] = info->rank; + for (int i = 0; i < rank; i++) { + ret[count++] = info->shape[i]; + } - if (ret[0] == 0) { - ret[1] = 0; - ret[2] = 1; - ret[3] = 99; - return ret; - } + for (int i = 0; i < rank; i++) { + ret[count++] = info->stride[i]; + } - for (int i = 0; i < rank; i++) { - ret[count++] = info->shape[i]; - } + ret[count++] = info->offset; + ret[count++] = info->elementWiseStride; + ret[count++] = info->order; - for (int i = 0; i < rank; i++) { - ret[count++] = info->stride[i]; - } + return ret; +} - ret[count++] = info->offset; - ret[count++] = info->elementWiseStride; - ret[count++] = info->order; +INLINEDEF _CUDA_HD void printIntArray(const Nd4jLong *arr, const int length) { + for (int i = 0; i < length; i++) { + printf(" %lld ", (long long)arr[i]); + } - return ret; - } + printf("\n"); +} - INLINEDEF _CUDA_HD void printIntArray(const Nd4jLong *arr, const int length) { - for(int i = 0; i < length; i++) { - printf(" %lld ", (long long) arr[i]); - } +INLINEDEF _CUDA_HD void printIntArray(const int *arr, const int length) { + for (int i = 0; i < length; i++) { + printf(" %i ", arr[i]); + } - printf("\n"); - } + printf("\n"); +} - INLINEDEF _CUDA_HD void printIntArray(const int *arr, const int length) { - for(int i = 0; i < length; i++) { - printf(" %i ", arr[i]); - } +INLINEDEF _CUDA_HD void printShapeInfo(Nd4jLong *shapeInfo) { + int rank = shape::rank(shapeInfo); + Nd4jLong *shape = shape::shapeOf(shapeInfo); + printf("Rank %d\n", rank); + printf("Shape:\n"); + for (int i = 0; i < rank; i++) { + printf(" %lld ", (long long)shape[i]); + } - printf("\n"); - } + printf("\n"); - INLINEDEF _CUDA_HD void printShapeInfo(Nd4jLong *shapeInfo) { - int rank = shape::rank(shapeInfo); - Nd4jLong *shape = shape::shapeOf(shapeInfo); - printf("Rank %d\n",rank); - printf("Shape:\n"); - for(int i = 0; i < rank; i++) { - printf(" %lld ",(long long) shape[i]); - } + Nd4jLong *stride = shape::stride(shapeInfo); + printf("Stride:\n"); + for (int i = 0; i < rank; i++) { + printf(" %lld ", (long long)stride[i]); + } - printf("\n"); + printf("\n"); - Nd4jLong *stride = shape::stride(shapeInfo); - printf("Stride:\n"); - for(int i = 0; i < rank; i++) { - printf(" %lld ", (long long) stride[i]); - } + printf("Order %c\n", shape::order(shapeInfo)); +} - printf("\n"); +INLINEDEF _CUDA_HD void printShapeInfoLinear(const Nd4jLong *shapeInfo) { + int rank = shape::rank(shapeInfo); + int lim = shape::shapeInfoLength(rank); + printf("ShapeInfo: ["); + for (int i = 0; i < lim; i++) { + printf("%lld", (long long)shapeInfo[i]); - printf("Order %c\n",shape::order(shapeInfo)); + if (i < lim - 1) { + printf(", "); } - - INLINEDEF _CUDA_HD void printShapeInfoLinear(const Nd4jLong *shapeInfo) { - int rank = shape::rank(shapeInfo); - int lim = shape::shapeInfoLength(rank); - printf("ShapeInfo: ["); - for (int i = 0; i < lim; i++) { - printf("%lld", (long long) shapeInfo[i]); - - if (i < lim - 1) { - printf(", "); - } - } - printf("]\n"); + } + printf("]\n"); #ifndef __CUDA_ARCH__ - fflush(stdout); + fflush(stdout); #endif - } +} - INLINEDEF _CUDA_HD void printShapeInfoLinear(const char *msg, int rank, const Nd4jLong *shape, const Nd4jLong *strides) { - printf("%s : [", msg); - for (int i = 0; i < rank; i++) { - printf("%lld, ", (long long) shape[i]); - } +INLINEDEF _CUDA_HD void printShapeInfoLinear(const char *msg, int rank, + const Nd4jLong *shape, + const Nd4jLong *strides) { + printf("%s : [", msg); + for (int i = 0; i < rank; i++) { + printf("%lld, ", (long long)shape[i]); + } - for (int i = 0; i < rank; i++) { - printf("%lld", (long long) strides[i]); + for (int i = 0; i < rank; i++) { + printf("%lld", (long long)strides[i]); - if (i < rank - 1) - printf(", "); - } - printf("]\n"); + if (i < rank - 1) printf(", "); + } + printf("]\n"); #ifndef __CUDA_ARCH__ - fflush(stdout); + fflush(stdout); #endif - } +} - INLINEDEF _CUDA_HD void printShapeInfoLinear(const char *msg, const Nd4jLong *shapeInfo) { - int rank = shape::rank(shapeInfo); - int lim = shape::shapeInfoLength(rank); - printf("%s : [", msg); - for (int i = 0; i < lim; i++) { - printf("%lld", (long long) shapeInfo[i]); +INLINEDEF _CUDA_HD void printShapeInfoLinear(const char *msg, + const Nd4jLong *shapeInfo) { + int rank = shape::rank(shapeInfo); + int lim = shape::shapeInfoLength(rank); + printf("%s : [", msg); + for (int i = 0; i < lim; i++) { + printf("%lld", (long long)shapeInfo[i]); - if (i < lim - 1) { - printf(", "); - } - } - printf("]\n"); + if (i < lim - 1) { + printf(", "); + } + } + printf("]\n"); #ifndef __CUDACC__ - fflush(stdout); + fflush(stdout); #endif - } - - template - INLINEDEF _CUDA_HD void printArray(void *varr,int length, const char * message) { - auto arr = reinterpret_cast(varr); - if (message != nullptr) - printf("%s: [", message); - else - printf("Array: ["); +} - for (int i = 0; i < length; i ++) { - printf("%f", (float) arr[i]); - if (i + 1 < length) printf(", "); - } - printf("]\n"); +template +INLINEDEF _CUDA_HD void printArray(void *varr, int length, + const char *message) { + auto arr = reinterpret_cast(varr); + if (message != nullptr) + printf("%s: [", message); + else + printf("Array: ["); + + for (int i = 0; i < length; i++) { + printf("%f", (float)arr[i]); + if (i + 1 < length) printf(", "); + } + printf("]\n"); #ifndef __CUDACC__ - fflush(stdout); + fflush(stdout); #endif - } +} - INLINEDEF _CUDA_HD void printArray(float *arr,int length) { - printf("Array: ["); - for (int i = 0; i < length; i ++) { - printf("%f", arr[i]); - if (i + 1 < length) printf(", "); - } - printf("]\n"); - } +INLINEDEF _CUDA_HD void printArray(float *arr, int length) { + printf("Array: ["); + for (int i = 0; i < length; i++) { + printf("%f", arr[i]); + if (i + 1 < length) printf(", "); + } + printf("]\n"); +} /** * Given an linear index, element wise stride * and the length of each tad @@ -3554,54 +3723,55 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coo * @param numElementsPerTad the number of elements * per tad */ - INLINEDEF _CUDA_HD int tadIndex(int i, int elementWiseStride, int numElementsPerTad) { - return i / (numElementsPerTad * elementWiseStride); - } +INLINEDEF _CUDA_HD int tadIndex(int i, int elementWiseStride, + int numElementsPerTad) { + return i / (numElementsPerTad * elementWiseStride); +} /** * Map a tad to a * reduction index. * @param tadIndexForOriginal the original tad index for the * split up problem (eg: split is dimension 3 mapping to a 2,3 problem) - * @param tadsForReduced the number of tads for the shrunk down problem (eg: 2,3) + * @param tadsForReduced the number of tads for the shrunk down problem (eg: + * 2,3) * @param tadsForOriginal the number of tads for the smaller problem (eg: 3) */ - INLINEDEF _CUDA_HD int reductionIndexForTad(int tadIndexForOriginal, int tadsForReduced, - int tadsForOriginal) { - if (tadIndexForOriginal == 0) - return 0; - return tadIndexForOriginal / (tadsForOriginal / tadsForReduced); - } - - - INLINEDEF _CUDA_HD void transposeInplace(Nd4jLong *shapeBuffer) { - int rank = shape::rank(shapeBuffer); - Nd4jLong *shape = shape::shapeOf(shapeBuffer); - Nd4jLong *strides = shape::stride(shapeBuffer); - - // swap shape - for (int e = 0; e < rank / 2; e++) { - int idx1 = rank - e - 1; - int idx2 = e; - int tmp = shape[idx2]; - shape[idx2] = shape[idx1]; - shape[idx1] = tmp; - } - - // swap strides - for (int e = 0; e < rank / 2; e++) { - int idx1 = rank - e - 1; - int idx2 = e; - int tmp = strides[idx2]; - strides[idx2] = strides[idx1]; - strides[idx1] = tmp; - } +INLINEDEF _CUDA_HD int reductionIndexForTad(int tadIndexForOriginal, + int tadsForReduced, + int tadsForOriginal) { + if (tadIndexForOriginal == 0) return 0; + return tadIndexForOriginal / (tadsForOriginal / tadsForReduced); +} - if (shape::order(shapeBuffer) == 'c') - shapeBuffer[shape::shapeInfoLength(shapeBuffer) - 1] = 102; - else - shapeBuffer[shape::shapeInfoLength(shapeBuffer) - 1] = 99; - } +INLINEDEF _CUDA_HD void transposeInplace(Nd4jLong *shapeBuffer) { + int rank = shape::rank(shapeBuffer); + Nd4jLong *shape = shape::shapeOf(shapeBuffer); + Nd4jLong *strides = shape::stride(shapeBuffer); + + // swap shape + for (int e = 0; e < rank / 2; e++) { + int idx1 = rank - e - 1; + int idx2 = e; + int tmp = shape[idx2]; + shape[idx2] = shape[idx1]; + shape[idx1] = tmp; + } + + // swap strides + for (int e = 0; e < rank / 2; e++) { + int idx1 = rank - e - 1; + int idx2 = e; + int tmp = strides[idx2]; + strides[idx2] = strides[idx1]; + strides[idx1] = tmp; + } + + if (shape::order(shapeBuffer) == 'c') + shapeBuffer[shape::shapeInfoLength(shapeBuffer) - 1] = 102; + else + shapeBuffer[shape::shapeInfoLength(shapeBuffer) - 1] = 99; +} /** * Tad index for linear @@ -3609,18 +3779,19 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coo * @param tadLength * @return */ - INLINEDEF _CUDA_HD int tadIndexForLinear(int linearIndex, int tadLength) { - return linearIndex % tadLength; - } +INLINEDEF _CUDA_HD int tadIndexForLinear(int linearIndex, int tadLength) { + return linearIndex % tadLength; +} /** * Computes the number of tads * per reduce index for the * reduction tad. */ - INLINEDEF _CUDA_HD int tadsPerReduceIndex(int tadsForReduce, int tadsForOriginal) { - return tadsForOriginal / tadsForReduce; - } +INLINEDEF _CUDA_HD int tadsPerReduceIndex(int tadsForReduce, + int tadsForOriginal) { + return tadsForOriginal / tadsForReduce; +} /** * Maps a linear index to a reduction index @@ -3630,139 +3801,131 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coo * @param tadNum the number of tads for the shrunken problem * @param originalTadNum the tad number for the reduced version of the problem */ - INLINEDEF _CUDA_HD int reductionIndexForLinear(int i, int elementWiseStride, int numElementsPerTad, - int tadNum, int originalTadNum) { - int tad = tadIndex(i, elementWiseStride, numElementsPerTad); - return reductionIndexForTad(tad, tadNum, originalTadNum); - } - - INLINEDEF _CUDA_HD Nd4jLong* createScalarShapeInfo() { - - traceNew(30); - - auto shape = new Nd4jLong[1]; - shape[0] = 1; - auto stride = new Nd4jLong[1]; - stride[0] = 1; - auto shapeInformation2 = new ShapeInformation(); - shapeInformation2->rank = 1; - shapeInformation2->offset = 0; - shapeInformation2->stride = stride; - shapeInformation2->shape = shape; - shapeInformation2->elementWiseStride = 1; - shapeInformation2->order = 99; - Nd4jLong *ret = shape::toShapeBuffer(shapeInformation2); - delete shapeInformation2; - delete[] shape; - delete[] stride; - return ret; - } - - INLINEDEF _CUDA_HD Nd4jLong* createScalarShapeInfo(Nd4jLong *ret) { - ret[0] = 2; - ret[1] = 1; - ret[2] = 1; - ret[3] = 1; - ret[4] = 1; - ret[5] = 0; - ret[6] = 1; - ret[7] = 99; - - return ret; - } - - -/** - * Returns the prod of the data - * up to the given length - */ - INLINEDEF _CUDA_HD Nd4jLong prodLong(const Nd4jLong *data, int length) { - Nd4jLong prod = 1; - for (int i = 0; i < length; i++) { - prod *= data[i]; - } - - return prod; - } - - INLINEDEF _CUDA_HD int rearMostLeftOverItem(Nd4jLong *data, Nd4jLong *dimension,int dimensionLength) { - Nd4jLong *stride = shape::stride(data); - //corner case: return the final item when its greater than the max, since its guaranteed to be left over - //note here that strides are interpreted in reverse for tad - //start from the front rather than the back - - int rank = shape::rank(data); - - - if(shape::order(data) == 'f') { - int dimIdx = dimensionLength - 1; - for(int i = rank - 1; i >= 0; i--) { - /** - * Needs to find an algorithm such that: - * looping backwards will find the highest dimension left - * that isn't included in the dimension index list. - * - * This can also be thought of as the last item of the first index - * of the difference between the full list of indices and - * the dimension indices. - * - * We should avoid excessive object creation by only looping backwards. - */ - if(dimension[dimIdx--] != i) { - int ret = stride[i]; - return ret; - } - } - } +INLINEDEF _CUDA_HD int reductionIndexForLinear(int i, int elementWiseStride, + int numElementsPerTad, + int tadNum, int originalTadNum) { + int tad = tadIndex(i, elementWiseStride, numElementsPerTad); + return reductionIndexForTad(tad, tadNum, originalTadNum); +} - else { - int dimIdx = dimensionLength - 1; - - for(int i = rank - 1; i >= 0; i--) { - /** - * Needs to find an algorithm such that: - * looping backwards will find the highest dimension left - * that isn't included in the dimension index list. - * - * This can also be thought of as the last item of the first index - * of the difference between the full list of indices and - * the dimension indices. - * - * We should avoid excessive object creation by only looping backwards. - */ - if(dimension[dimIdx--] != i) { - int ret = stride[i]; - return ret; - } - } - } +INLINEDEF _CUDA_HD Nd4jLong *createScalarShapeInfo() { + traceNew(30); + + auto shape = new Nd4jLong[1]; + shape[0] = 1; + auto stride = new Nd4jLong[1]; + stride[0] = 1; + auto shapeInformation2 = new ShapeInformation(); + shapeInformation2->rank = 1; + shapeInformation2->offset = 0; + shapeInformation2->stride = stride; + shapeInformation2->shape = shape; + shapeInformation2->elementWiseStride = 1; + shapeInformation2->order = 99; + Nd4jLong *ret = shape::toShapeBuffer(shapeInformation2); + delete shapeInformation2; + delete[] shape; + delete[] stride; + return ret; +} +INLINEDEF _CUDA_HD Nd4jLong *createScalarShapeInfo(Nd4jLong *ret) { + ret[0] = 2; + ret[1] = 1; + ret[2] = 1; + ret[3] = 1; + ret[4] = 1; + ret[5] = 0; + ret[6] = 1; + ret[7] = 99; + + return ret; +} +/** + * Returns the prod of the data + * up to the given length + */ +INLINEDEF _CUDA_HD Nd4jLong prodLong(const Nd4jLong *data, int length) { + Nd4jLong prod = 1; + for (int i = 0; i < length; i++) { + prod *= data[i]; + } + return prod; +} - int ret = stride[0]; +INLINEDEF _CUDA_HD int rearMostLeftOverItem(Nd4jLong *data, Nd4jLong *dimension, + int dimensionLength) { + Nd4jLong *stride = shape::stride(data); + // corner case: return the final item when its greater than the max, since its + // guaranteed to be left over note here that strides are interpreted in reverse + // for tad start from the front rather than the back + + int rank = shape::rank(data); + + if (shape::order(data) == 'f') { + int dimIdx = dimensionLength - 1; + for (int i = rank - 1; i >= 0; i--) { + /** + * Needs to find an algorithm such that: + * looping backwards will find the highest dimension left + * that isn't included in the dimension index list. + * + * This can also be thought of as the last item of the first index + * of the difference between the full list of indices and + * the dimension indices. + * + * We should avoid excessive object creation by only looping backwards. + */ + if (dimension[dimIdx--] != i) { + int ret = stride[i]; + return ret; + } + } + } + + else { + int dimIdx = dimensionLength - 1; + + for (int i = rank - 1; i >= 0; i--) { + /** + * Needs to find an algorithm such that: + * looping backwards will find the highest dimension left + * that isn't included in the dimension index list. + * + * This can also be thought of as the last item of the first index + * of the difference between the full list of indices and + * the dimension indices. + * + * We should avoid excessive object creation by only looping backwards. + */ + if (dimension[dimIdx--] != i) { + int ret = stride[i]; return ret; + } } + } + + int ret = stride[0]; + return ret; +} #ifdef __CUDACC__ - __device__ INLINEDEF void sweepShapeInfoBuffer(Nd4jLong *shapeInfoBuffer, Nd4jLong *targetBuffer) { - // we read first element, to find out length of our shapeInfoBuffer - int rank = shapeInfoBuffer[0]; - int len = shape::shapeInfoLength(rank); - for (int i = threadIdx.x; i < len; i += blockDim.x) - targetBuffer[i] = shapeInfoBuffer[i]; +__device__ INLINEDEF void sweepShapeInfoBuffer(Nd4jLong *shapeInfoBuffer, + Nd4jLong *targetBuffer) { + // we read first element, to find out length of our shapeInfoBuffer + int rank = shapeInfoBuffer[0]; + int len = shape::shapeInfoLength(rank); + for (int i = threadIdx.x; i < len; i += blockDim.x) + targetBuffer[i] = shapeInfoBuffer[i]; } #endif - - INLINEDEF _CUDA_HD Nd4jLong *shapeBufferOfNpy(cnpy::NpyArray arr) { - return shape::shapeBufferOfNpy(arr.shape.size(),(unsigned int*) arr.shape.data(),arr.fortranOrder); - } - - - - - +INLINEDEF _CUDA_HD Nd4jLong *shapeBufferOfNpy(cnpy::NpyArray arr) { + return shape::shapeBufferOfNpy( + arr.shape.size(), (unsigned int *)arr.shape.data(), arr.fortranOrder); +} // INLINEDEF _CUDA_HD Nd4jLong *shapeBufferOfNpyBuffer(char *buffer) { // unsigned Nd4jLong *shape; @@ -3774,90 +3937,84 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coo // return ret; // } +INLINEDEF _CUDA_HD Nd4jLong *shapeBufferOfNpy(int rank, unsigned int *shape, + bool fortranOrder) { + if (fortranOrder) { + Nd4jLong *shapeBufferRet = + shape::shapeBufferFortran(rank, sd::FLOAT32, (Nd4jLong *)shape); + return shapeBufferRet; + } else { + Nd4jLong *newShape = new Nd4jLong[rank]; + for (int i = 0; i < rank; i++) { + newShape[i] = shape[i]; + } + + Nd4jLong *shapeBufferRet = shape::shapeBuffer(rank, sd::FLOAT32, newShape); + delete[] newShape; + return shapeBufferRet; + } +} - INLINEDEF _CUDA_HD Nd4jLong *shapeBufferOfNpy(int rank, unsigned int* shape,bool fortranOrder) { - if(fortranOrder) { - Nd4jLong *shapeBufferRet = shape::shapeBufferFortran(rank, sd::FLOAT32,(Nd4jLong *) shape); - return shapeBufferRet; - } - else { - Nd4jLong *newShape = new Nd4jLong[rank]; - for(int i = 0; i < rank; i++) { - newShape[i] = shape[i]; - } - - Nd4jLong *shapeBufferRet = shape::shapeBuffer(rank, sd::FLOAT32, newShape); - delete[] newShape; - return shapeBufferRet; - - } - } +INLINEDEF _CUDA_HD bool strideDescendingCAscendingF( + const Nd4jLong *shapeBuffer) { + int rank = shape::rank(shapeBuffer); + Nd4jLong *strides = shape::stride(const_cast(shapeBuffer)); + char order = shape::order(shapeBuffer); - INLINEDEF _CUDA_HD bool strideDescendingCAscendingF(const Nd4jLong *shapeBuffer) { - int rank = shape::rank(shapeBuffer); - Nd4jLong *strides = shape::stride(const_cast(shapeBuffer)); - char order = shape::order(shapeBuffer); - - if (shape::isRowVector(shapeBuffer) && strides[0] == 1 && strides[1] == 1) - return true; - - if (order == 'c') { - for (int i = 1; i < rank; i++) - if (strides[i-1] <= strides[i]) - return false; - return true; - } else if (order == 'f') { - for (int i = 1; i < rank; i++) - if (strides[i-1] >= strides[i]) - return false; - return true; - } else { - printf("Unknown order for array!\n"); - return false; - } - } + if (shape::isRowVector(shapeBuffer) && strides[0] == 1 && strides[1] == 1) + return true; - INLINEDEF _CUDA_HD bool isContiguous(const Nd4jLong* shapeInfo) { + if (order == 'c') { + for (int i = 1; i < rank; i++) + if (strides[i - 1] <= strides[i]) return false; + return true; + } else if (order == 'f') { + for (int i = 1; i < rank; i++) + if (strides[i - 1] >= strides[i]) return false; + return true; + } else { + printf("Unknown order for array!\n"); + return false; + } +} - return (order(shapeInfo) == 'c') && (elementWiseStride(shapeInfo) > 0); - } +INLINEDEF _CUDA_HD bool isContiguous(const Nd4jLong *shapeInfo) { + return (order(shapeInfo) == 'c') && (elementWiseStride(shapeInfo) > 0); +} ////////////////////////////////////////////////////////////////////////// // copy-past from java hasDefaultStridesForShape function -INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) { - - const int rank = shape::rank(shapeInfo); +INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong *shapeInfo) { + const int rank = shape::rank(shapeInfo); - if(rank == 0) - return true; - if(!strideDescendingCAscendingF(shapeInfo)) - return false; + if (rank == 0) return true; + if (!strideDescendingCAscendingF(shapeInfo)) return false; - Nd4jLong defaultShapeInfo[MAX_SHAPEINFOLENGTH]; - memcpy(defaultShapeInfo, shapeInfo, shape::shapeInfoByteLength(shapeInfo)); - shape::updateStrides(defaultShapeInfo, shape::order(shapeInfo)); + Nd4jLong defaultShapeInfo[MAX_SHAPEINFOLENGTH]; + memcpy(defaultShapeInfo, shapeInfo, shape::shapeInfoByteLength(shapeInfo)); + shape::updateStrides(defaultShapeInfo, shape::order(shapeInfo)); - bool result = true; - for(int i = rank+1; i <= 2*rank; ++i) - if(defaultShapeInfo[i] != shapeInfo[i]) { - result = false; - break; - } + bool result = true; + for (int i = rank + 1; i <= 2 * rank; ++i) + if (defaultShapeInfo[i] != shapeInfo[i]) { + result = false; + break; + } - return result; + return result; } -// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShapeOf, bool isFOrder, Nd4jLong* target) { +// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, Nd4jLong* oldShape, const +// int newRank, Nd4jLong* newShapeOf, bool isFOrder, Nd4jLong* target) { // int oldnd; // Nd4jLong* olddims = shape::copyOf(oldRank, shape::shapeOf(oldShape)); -// Nd4jLong* oldstrides = shape::copyOf(oldRank, shape::stride(oldShape)); -// int np, op, last_stride; -// int oi, oj, ok, ni, nj, nk; -// Nd4jLong* newStrides = new Nd4jLong[newRank]; -// oldnd = 0; +// Nd4jLong* oldstrides = shape::copyOf(oldRank, +// shape::stride(oldShape)); int np, op, last_stride; int oi, oj, ok, +// ni, nj, nk; Nd4jLong* newStrides = new Nd4jLong[newRank]; oldnd = 0; // /* -// * Remove axes with dimension 1 from the old array. They have no effect +// * Remove axes with dimension 1 from the old array. They have no +// effect // * but would need special cases since their strides do not matter. // */ // for (oi = 0; oi < oldRank; oi++) { @@ -3894,11 +4051,8 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) { // return false; // } -// /* oi to oj and ni to nj give the axis ranges currently worked with */ -// oi = 0; -// oj = 1; -// ni = 0; -// nj = 1; +// /* oi to oj and ni to nj give the axis ranges currently worked with +// */ oi = 0; oj = 1; ni = 0; nj = 1; // while (ni < newRank && oi < oldnd) { // np = newShapeOf[ni]; @@ -3926,7 +4080,8 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) { // } // } else { // /* C order */ -// if (oldstrides[ok] != olddims[ok + 1] * oldstrides[ok + 1]) { +// if (oldstrides[ok] != olddims[ok + 1] * oldstrides[ok + +// 1]) { // /* not contiguous enough */ // delete[] olddims; // delete[] oldstrides; @@ -3977,7 +4132,8 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) { // target[shape::shapeInfoLength(newRank) - 3] = 0; // target[shape::shapeInfoLength(newRank) - 2] = 0; // target[shape::shapeInfoLength(newRank) - 1] = isFOrder ? 102 : 99; -// sd::ArrayOptions::setDataType(target, sd::ArrayOptions::dataType(oldShape)); +// sd::ArrayOptions::setDataType(target, +// sd::ArrayOptions::dataType(oldShape)); // delete[] olddims; // delete[] oldstrides; @@ -3987,18 +4143,26 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) { // } ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo) { +// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* +// oldShapeInfo, const int newRank, const Nd4jLong* newShape, Nd4jLong* +// newShapeInfo) { -// // PLEASE NOTE !: reshaping not-permuted (ews=1) array in f order (except insertion/elimination of unities) will definitely cause allocation of new buffer for array elements -// // also this function takes into account identical shapes automatically, namely in that case oldShapeInfo is completely copied to newShapeInfo +// // PLEASE NOTE !: reshaping not-permuted (ews=1) array in f order +// (except insertion/elimination of unities) will definitely cause +// allocation of new buffer for array elements +// // also this function takes into account identical shapes +// automatically, namely in that case oldShapeInfo is completely copied +// to newShapeInfo // newShapeInfo[0] = newRank; // memcpy(newShapeInfo + 1, newShape, newRank * sizeof(Nd4jLong)); // Nd4jLong* newStrides = shape::stride(newShapeInfo); -// const Nd4jLong* oldShape = shape::shapeOf(const_cast(oldShapeInfo)); -// const Nd4jLong* oldStrides = shape::stride(const_cast(oldShapeInfo)); -// Nd4jLong oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim; +// const Nd4jLong* oldShape = +// shape::shapeOf(const_cast(oldShapeInfo)); const Nd4jLong* +// oldStrides = shape::stride(const_cast(oldShapeInfo)); +// Nd4jLong oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, +// oldDim; // while (newStart < newRank && oldStart < oldRank) { @@ -4009,13 +4173,16 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) { // if (newDim < oldDim) newDim *= newShape[newStop++]; // else oldDim *= oldShape[oldStop++]; -// // ------ Check whether the original axes can be combined ------ // -// for (int step = 1, i = oldStart; i < oldStop - 1; ++i) { -// if(oldShape[i] == 1) // skip unity-dimension and its stride +// // ------ Check whether the original axes can be combined ------ +// // for (int step = 1, i = oldStart; i < oldStop - 1; ++i) { +// if(oldShape[i] == 1) // skip unity-dimension +// and its stride // continue; // while((i + step) < oldRank && oldShape[i + step] == 1) -// ++step; // skip following unity-dimensions and its strides if such are present -// if((i + step) < oldRank && oldStrides[i] != oldShape[i + step] * oldStrides[i + step]) +// ++step; // skip following +// unity-dimensions and its strides if such are present +// if((i + step) < oldRank && oldStrides[i] != oldShape[i + +// step] * oldStrides[i + step]) // return false; // not contiguous enough // } @@ -4027,902 +4194,956 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) { // oldStart = oldStop++; // } -// // rest of strides should be unities (if there is remainder in strides space, that is newStart < newRank) -// for (int i = newStart; i < newRank; ++i) +// // rest of strides should be unities (if there is remainder in +// strides space, that is newStart < newRank) for (int i = newStart; i < +// newRank; ++i) // newStrides[i] = 1; -// newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo); // order -// newShapeInfo[2 * newRank + 2] = shape::elementWiseStride(oldShapeInfo); // ews -// newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); // type +// newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo); // order +// newShapeInfo[2 * newRank + 2] = +// shape::elementWiseStride(oldShapeInfo); // ews newShapeInfo[2 * +// newRank + 1] = shape::type(oldShapeInfo); // type // return true; // } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, const char newOrder, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo) { - - // copy shape from newShape into newShapeInfo - newShapeInfo[0] = newRank; - memcpy(newShapeInfo + 1, newShape, newRank * sizeof(Nd4jLong)); - - // copy order - newShapeInfo[2 * newRank + 3] = newOrder; - - return shape::reshapeC(oldShapeInfo, newShapeInfo); +INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong *oldShapeInfo, + const char newOrder, const int newRank, + const Nd4jLong *newShape, + Nd4jLong *newShapeInfo) { + // copy shape from newShape into newShapeInfo + newShapeInfo[0] = newRank; + memcpy(newShapeInfo + 1, newShape, newRank * sizeof(Nd4jLong)); + + // copy order + newShapeInfo[2 * newRank + 3] = newOrder; + + return shape::reshapeC(oldShapeInfo, newShapeInfo); } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShapeInfo) { - - // newShapeInfo contains rank, shape and order; but no strides, type and ews - - const int newRank = shape::rank(newShapeInfo); - - // if oldShapeInfo is scalar or vector with length=1 - if(shape::length(oldShapeInfo) == 1) { - for (uint i = 0; i < newRank; ++i) - shape::stride(newShapeInfo)[i] = 1; - newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); - *shape::ews(newShapeInfo) = 1; - return true; - } - - const auto oldOrder = shape::order(oldShapeInfo); - const auto newOrder = shape::order(newShapeInfo); - const auto oldEws = shape::elementWiseStride(const_cast(oldShapeInfo)); - - if(oldEws > 0 && oldOrder != newOrder) - return false; - - // *** FIRST STAGE - exclude unity dimensions from oldShapeInfo and newShapeInfo (if such are present of course), since they don't affect on strides evaluation, however they complicate code - - // FIXME - indeed we don't need to allocate so large memory amount (4*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities) - Nd4jLong tempBuffer[4*MAX_RANK]; - Nd4jLong *oldShape = tempBuffer, *newShape = tempBuffer + 2*MAX_RANK, *oldStrides, *newStrides; - - // exclude unities from oldShapeInfo - const int oldNumOfNonUnities = shape::excludeUnitiesFromShapeInfo(oldShapeInfo, oldShape, oldStrides); - const int newNumOfNonUnities = shape::excludeUnitiesFromShapeInfo(newShapeInfo, newShape, newStrides); - - // *** SECOND STAGE - strides evaluation - - int oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim; - - while (newStart < newNumOfNonUnities && oldStart < oldNumOfNonUnities) { - - newDim = newShape[newStart]; - oldDim = oldShape[oldStart]; - - while (newDim != oldDim && newDim > 0 && oldDim > 0) { - - if (newDim < oldDim) - newDim *= newShape[newStop++]; - else - oldDim *= oldShape[oldStop++]; - } - - // check c-contiguous of old axes range - for(uint i = oldStart; i < oldStop - 1; ++i) // do not check value of last stride, it doesn't matter - if(oldStrides[i] != oldShape[i + 1] * oldStrides[i + 1]) - return false; // not contiguous - - // fill newStrides in c manner - newStrides[newStop - 1] = oldStrides[oldStop - 1]; // copy last stride - for (int i = newStop - 2; i >= newStart; --i) - newStrides[i] = newStrides[i + 1] * newShape[i + 1]; - - newStart = newStop++; - oldStart = oldStop++; - } - - // fill new calculated strides into newShapeInfo, take into account possible unities in shape - for (int j = 0, i = 0; i < newRank; ++i) - shape::stride(newShapeInfo)[i] = (shape::shapeOf(newShapeInfo)[i] == 1) ? 1 : newStrides[j++]; +INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong *oldShapeInfo, + Nd4jLong *newShapeInfo) { + // newShapeInfo contains rank, shape and order; but no strides, type and ews - // set ews - if(oldEws == 0) - shape::checkStridesEwsAndOrder(newShapeInfo, newOrder, newNumOfNonUnities, newShape, newStrides); // set ews and order - else { - newShapeInfo[2 * newRank + 3] = oldOrder; // order - *shape::ews(newShapeInfo) = oldEws; // ews - } - - sd::ArrayOptions::copyDataType(newShapeInfo, oldShapeInfo); // type + const int newRank = shape::rank(newShapeInfo); + // if oldShapeInfo is scalar or vector with length=1 + if (shape::length(oldShapeInfo) == 1) { + for (uint i = 0; i < newRank; ++i) shape::stride(newShapeInfo)[i] = 1; + newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); + *shape::ews(newShapeInfo) = 1; return true; + } + + const auto oldOrder = shape::order(oldShapeInfo); + const auto newOrder = shape::order(newShapeInfo); + const auto oldEws = + shape::elementWiseStride(const_cast(oldShapeInfo)); + + if (oldEws > 0 && oldOrder != newOrder) return false; + + // *** FIRST STAGE - exclude unity dimensions from oldShapeInfo and + // newShapeInfo (if such are present of course), since they don't affect on + // strides evaluation, however they complicate code + + // FIXME - indeed we don't need to allocate so large memory amount + // (4*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + + // 2*newNumOfNonUnities) + Nd4jLong tempBuffer[4 * MAX_RANK]; + Nd4jLong *oldShape = tempBuffer, *newShape = tempBuffer + 2 * MAX_RANK, + *oldStrides, *newStrides; + + // exclude unities from oldShapeInfo + const int oldNumOfNonUnities = + shape::excludeUnitiesFromShapeInfo(oldShapeInfo, oldShape, oldStrides); + const int newNumOfNonUnities = + shape::excludeUnitiesFromShapeInfo(newShapeInfo, newShape, newStrides); + + // *** SECOND STAGE - strides evaluation + + int oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim; + + while (newStart < newNumOfNonUnities && oldStart < oldNumOfNonUnities) { + newDim = newShape[newStart]; + oldDim = oldShape[oldStart]; + + while (newDim != oldDim && newDim > 0 && oldDim > 0) { + if (newDim < oldDim) + newDim *= newShape[newStop++]; + else + oldDim *= oldShape[oldStop++]; + } + + // check c-contiguous of old axes range + for (uint i = oldStart; i < oldStop - 1; + ++i) // do not check value of last stride, it doesn't matter + if (oldStrides[i] != oldShape[i + 1] * oldStrides[i + 1]) + return false; // not contiguous + + // fill newStrides in c manner + newStrides[newStop - 1] = oldStrides[oldStop - 1]; // copy last stride + for (int i = newStop - 2; i >= newStart; --i) + newStrides[i] = newStrides[i + 1] * newShape[i + 1]; + + newStart = newStop++; + oldStart = oldStop++; + } + + // fill new calculated strides into newShapeInfo, take into account possible + // unities in shape + for (int j = 0, i = 0; i < newRank; ++i) + shape::stride(newShapeInfo)[i] = + (shape::shapeOf(newShapeInfo)[i] == 1) ? 1 : newStrides[j++]; + + // set ews + if (oldEws == 0) + shape::checkStridesEwsAndOrder(newShapeInfo, newOrder, newNumOfNonUnities, + newShape, newStrides); // set ews and order + else { + newShapeInfo[2 * newRank + 3] = oldOrder; // order + *shape::ews(newShapeInfo) = oldEws; // ews + } + + sd::ArrayOptions::copyDataType(newShapeInfo, oldShapeInfo); // type + + return true; } - - INLINEDEF _CUDA_H bool canReshape(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShapeOf, bool isFOrder) { - Nd4jLong oldnd; - Nd4jLong* oldDims = shape::copyOf(oldRank, shape::shapeOf(oldShape)); - Nd4jLong* oldStrides = shape::copyOf(oldRank, shape::stride(oldShape)); - Nd4jLong np, op, last_stride; - Nd4jLong oldStart, oldStop, ok, newStart, newStop, nk; - auto newStrides = new Nd4jLong[newRank]; - oldnd = 0; - - /* - * Remove axes with dimension 1 from the old array. They have no effect - * but would need special cases since their strides do not matter. - */ - for (oldStart = 0; oldStart < oldRank; oldStart++) { - if (shape::shapeOf(oldShape)[oldStart] != 1) { - oldDims[oldnd] = shape::shapeOf(oldShape)[oldStart]; - oldStrides[oldnd] = shape::stride(oldShape)[oldStart]; - oldnd++; - } - } - - np = 1; - for (newStart = 0; newStart < newRank; newStart++) { - np *= newShapeOf[newStart]; - } - op = 1; - for (oldStart = 0; oldStart < oldnd; oldStart++) { - op *= oldDims[oldStart]; - } - if (np != op) { - /* different total sizes; no hope */ - delete[] oldDims; - delete[] oldStrides; - delete[] newStrides; - - return false; - } - - if (np == 0) { - /* the current code does not handle 0-sized arrays, so give up */ - delete[] oldDims; - delete[] oldStrides; - delete[] newStrides; - - return false; - } - - /* oldStart to oldStop and newStart to newStop give the axis ranges currently worked with */ - oldStart = 0; - oldStop = 1; - newStart = 0; - newStop = 1; - - while (newStart < newRank && oldStart < oldnd) { - np = newShapeOf[newStart]; - op = oldDims[oldStart]; - - while (np != op) { - if (np < op) { - /* Misses trailing 1s, these are handled later */ - np *= newShapeOf[newStop++]; - } else { - op *= oldDims[oldStop++]; - } - } - - /* Check whether the original axes can be combined */ - for (ok = oldStart; ok < oldStop - 1; ok++) { - if (isFOrder) { - if (oldStrides[ok + 1] != oldDims[ok] * oldStrides[ok]) { - /* not contiguous enough */ - delete[] oldDims; - delete[] oldStrides; - delete[] newStrides; - - return false; - } - } else { - /* C order */ - if (oldStrides[ok] != oldDims[ok + 1] * oldStrides[ok + 1]) { - /* not contiguous enough */ - delete[] oldDims; - delete[] oldStrides; - delete[] newStrides; - - return false; - } - } - } - - /* Calculate new strides for all axes currently worked with */ - if (isFOrder) { - newStrides[newStart] = oldStrides[oldStart]; - for (nk = newStart + 1; nk < newStop; nk++) { - newStrides[nk] = newStrides[nk - 1] * newShapeOf[nk - 1]; - } - } else { - /* C order */ - newStrides[newStop - 1] = oldStrides[oldStop - 1]; - for (nk = newStop - 1; nk > newStart; nk--) { - newStrides[nk - 1] = newStrides[nk] * newShapeOf[nk]; - } - } - newStart = newStop++; - oldStart = oldStop++; +INLINEDEF _CUDA_H bool canReshape(const int oldRank, Nd4jLong *oldShape, + const int newRank, Nd4jLong *newShapeOf, + bool isFOrder) { + Nd4jLong oldnd; + Nd4jLong *oldDims = shape::copyOf(oldRank, shape::shapeOf(oldShape)); + Nd4jLong *oldStrides = shape::copyOf(oldRank, shape::stride(oldShape)); + Nd4jLong np, op, last_stride; + Nd4jLong oldStart, oldStop, ok, newStart, newStop, nk; + auto newStrides = new Nd4jLong[newRank]; + oldnd = 0; + + /* + * Remove axes with dimension 1 from the old array. They have no effect + * but would need special cases since their strides do not matter. + */ + for (oldStart = 0; oldStart < oldRank; oldStart++) { + if (shape::shapeOf(oldShape)[oldStart] != 1) { + oldDims[oldnd] = shape::shapeOf(oldShape)[oldStart]; + oldStrides[oldnd] = shape::stride(oldShape)[oldStart]; + oldnd++; + } + } + + np = 1; + for (newStart = 0; newStart < newRank; newStart++) { + np *= newShapeOf[newStart]; + } + op = 1; + for (oldStart = 0; oldStart < oldnd; oldStart++) { + op *= oldDims[oldStart]; + } + if (np != op) { + /* different total sizes; no hope */ + delete[] oldDims; + delete[] oldStrides; + delete[] newStrides; + + return false; + } + + if (np == 0) { + /* the current code does not handle 0-sized arrays, so give up */ + delete[] oldDims; + delete[] oldStrides; + delete[] newStrides; + + return false; + } + + /* oldStart to oldStop and newStart to newStop give the axis ranges currently + * worked with */ + oldStart = 0; + oldStop = 1; + newStart = 0; + newStop = 1; + + while (newStart < newRank && oldStart < oldnd) { + np = newShapeOf[newStart]; + op = oldDims[oldStart]; + + while (np != op) { + if (np < op) { + /* Misses trailing 1s, these are handled later */ + np *= newShapeOf[newStop++]; + } else { + op *= oldDims[oldStop++]; + } + } + + /* Check whether the original axes can be combined */ + for (ok = oldStart; ok < oldStop - 1; ok++) { + if (isFOrder) { + if (oldStrides[ok + 1] != oldDims[ok] * oldStrides[ok]) { + /* not contiguous enough */ + delete[] oldDims; + delete[] oldStrides; + delete[] newStrides; + + return false; } - - delete[] oldDims; - delete[] oldStrides; - delete[] newStrides; - - return true; - } - - // this function checks the consistence of dimensions with array rank (negative dimensions, too large dimensions, too big number of dimensions) - // also it sorts input array of dimensions, this operation is also necessary for creating TAD object - INLINEDEF _CUDA_H void checkDimensions(const int rank, std::vector& dimensions) { - - int dimSize = dimensions.size(); - if(dimSize == 0) - throw std::runtime_error("shape::checkDimensions method: array of dimensions is empty!"); - // check presence of negative dimensions and if they are present transform them to positive ones -dim -> rank - |dim| - for(auto& dim : dimensions) - if(dim < 0) - dim += rank; - // sort input array of dimensions, this operation is also necessary for creating TAD object in external methods - if (dimSize > 1) { - std::sort(dimensions.begin(), dimensions.end()); - // remove duplicates if they are present - dimensions.erase(std::unique(dimensions.begin(), dimensions.end()), dimensions.end()); + } else { + /* C order */ + if (oldStrides[ok] != oldDims[ok + 1] * oldStrides[ok + 1]) { + /* not contiguous enough */ + delete[] oldDims; + delete[] oldStrides; + delete[] newStrides; + + return false; } - // check whether number of dimensions is to big (>rank) - dimSize = dimensions.size(); - if(dimSize > rank) - throw std::runtime_error("shape::checkDimensions method: number of input dimensions is too big ( > rank of array)!"); - // check if min dimension is still negative and whether max dimension is bigger then rank-1 - if(dimensions[0] < 0 || dimensions.back() > (rank-1)) - throw std::runtime_error("shape::checkDimensions method: the negative dimension is still present in input array after transform or the too big dimension is present ( > rank of array) !"); - } + } + } + + /* Calculate new strides for all axes currently worked with */ + if (isFOrder) { + newStrides[newStart] = oldStrides[oldStart]; + for (nk = newStart + 1; nk < newStop; nk++) { + newStrides[nk] = newStrides[nk - 1] * newShapeOf[nk - 1]; + } + } else { + /* C order */ + newStrides[newStop - 1] = oldStrides[oldStop - 1]; + for (nk = newStop - 1; nk > newStart; nk--) { + newStrides[nk - 1] = newStrides[nk] * newShapeOf[nk]; + } + } + newStart = newStop++; + oldStart = oldStop++; + } + + delete[] oldDims; + delete[] oldStrides; + delete[] newStrides; + + return true; +} +// this function checks the consistence of dimensions with array rank (negative +// dimensions, too large dimensions, too big number of dimensions) also it sorts +// input array of dimensions, this operation is also necessary for creating TAD +// object +INLINEDEF _CUDA_H void checkDimensions(const int rank, + std::vector &dimensions) { + int dimSize = dimensions.size(); + if (dimSize == 0) + throw std::runtime_error( + "shape::checkDimensions method: array of dimensions is empty!"); + // check presence of negative dimensions and if they are present transform + // them to positive ones -dim -> rank - |dim| + for (auto &dim : dimensions) + if (dim < 0) dim += rank; + // sort input array of dimensions, this operation is also necessary for + // creating TAD object in external methods + if (dimSize > 1) { + std::sort(dimensions.begin(), dimensions.end()); + // remove duplicates if they are present + dimensions.erase(std::unique(dimensions.begin(), dimensions.end()), + dimensions.end()); + } + // check whether number of dimensions is to big (>rank) + dimSize = dimensions.size(); + if (dimSize > rank) + throw std::runtime_error( + "shape::checkDimensions method: number of input dimensions is too big " + "( > rank of array)!"); + // check if min dimension is still negative and whether max dimension is + // bigger then rank-1 + if (dimensions[0] < 0 || dimensions.back() > (rank - 1)) + throw std::runtime_error( + "shape::checkDimensions method: the negative dimension is still " + "present in input array after transform or the too big dimension is " + "present ( > rank of array) !"); +} // max array is outer for min array, min array is sub-array of max array -// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) -INLINEDEF _CUDA_HD void maxIndToMinInd(int* maxIdxs, int* minIdxs, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude, int dimsLen) { - - const auto maxRank = shape::rank(maxShapeInfo); - const auto minRank = shape::rank(minShapeInfo); - - // if(minRank >= maxRank) - // throw std::runtime_error("shape::maxIndToMinInd method: rank of min array should be smaller then rank of max array!"); - - if(dimsLen == -1) - dimsLen = maxRank - minRank; // if size is not given (= -1) then it is equal to ranks difference - - if(maxRank == minRank) { - - if(dimsToExclude == nullptr) { // --> means dimsToExclude == {0,1,2,...,dimsLen-1} - - for (int i = 0; i < maxRank; ++i) { - - if(i < dimsLen) - minIdxs[i] = maxIdxs[i]; - else { - if(maxIdxs[i] > minShapeInfo[i + 1]) - minIdxs[i] = maxIdxs[i] % minShapeInfo[i + 1]; - else if(maxIdxs[i] == minShapeInfo[i + 1]) - minIdxs[i] = 0; - else - minIdxs[i] = maxIdxs[i]; - } - } - } +// function calculates the coordinates of min array (and saves them into +// minIdxs) given coordinates of max array (already stored in maxIdxs) +INLINEDEF _CUDA_HD void maxIndToMinInd(int *maxIdxs, int *minIdxs, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude, int dimsLen) { + const auto maxRank = shape::rank(maxShapeInfo); + const auto minRank = shape::rank(minShapeInfo); + + // if(minRank >= maxRank) + // throw std::runtime_error("shape::maxIndToMinInd method: rank of min + // array should be smaller then rank of max array!"); + + if (dimsLen == -1) + dimsLen = maxRank - minRank; // if size is not given (= -1) then it is + // equal to ranks difference + + if (maxRank == minRank) { + if (dimsToExclude == + nullptr) { // --> means dimsToExclude == {0,1,2,...,dimsLen-1} + + for (int i = 0; i < maxRank; ++i) { + if (i < dimsLen) + minIdxs[i] = maxIdxs[i]; else { - - for (int i = 0, dim = 0; i < maxRank; ++i) { - - if(dim < dimsLen && dimsToExclude[dim] == i) { - minIdxs[i] = maxIdxs[i]; - ++dim; - continue; - } - - if(maxIdxs[i] > minShapeInfo[i + 1]) - minIdxs[i] = maxIdxs[i] % minShapeInfo[i + 1]; - else if(maxIdxs[i] == minShapeInfo[i + 1]) - minIdxs[i] = 0; - else - minIdxs[i] = maxIdxs[i]; - } + if (maxIdxs[i] > minShapeInfo[i + 1]) + minIdxs[i] = maxIdxs[i] % minShapeInfo[i + 1]; + else if (maxIdxs[i] == minShapeInfo[i + 1]) + minIdxs[i] = 0; + else + minIdxs[i] = maxIdxs[i]; } - } - else { - - if(dimsToExclude == nullptr) { // --> means dimsToExclude == {0,1,2,...,dimsLen-1} - - for (int i = 0; i < minRank; ++i) { - - if(maxIdxs[i + dimsLen] > minShapeInfo[i + 1]) - minIdxs[i] = maxIdxs[i + dimsLen] % minShapeInfo[i + 1]; - else if(maxIdxs[i + dimsLen] == minShapeInfo[i + 1]) - minIdxs[i] = 0; - else - minIdxs[i] = maxIdxs[i + dimsLen]; - } + } + } else { + for (int i = 0, dim = 0; i < maxRank; ++i) { + if (dim < dimsLen && dimsToExclude[dim] == i) { + minIdxs[i] = maxIdxs[i]; + ++dim; + continue; } - else { - for (int minI = 0, maxI = 0, dim = 0; maxI < maxRank; ++maxI) { - - if(dim < dimsLen && dimsToExclude[dim] == maxI) { - ++dim; - continue; - } - - if(maxIdxs[maxI] == minShapeInfo[minI + 1]) - minIdxs[minI] = 0; - else if(maxIdxs[maxI] > minShapeInfo[minI + 1]) - minIdxs[minI] = maxIdxs[maxI] % minShapeInfo[minI + 1]; - else - minIdxs[minI] = maxIdxs[maxI]; - ++minI; - } + if (maxIdxs[i] > minShapeInfo[i + 1]) + minIdxs[i] = maxIdxs[i] % minShapeInfo[i + 1]; + else if (maxIdxs[i] == minShapeInfo[i + 1]) + minIdxs[i] = 0; + else + minIdxs[i] = maxIdxs[i]; + } + } + } else { + if (dimsToExclude == + nullptr) { // --> means dimsToExclude == {0,1,2,...,dimsLen-1} + + for (int i = 0; i < minRank; ++i) { + if (maxIdxs[i + dimsLen] > minShapeInfo[i + 1]) + minIdxs[i] = maxIdxs[i + dimsLen] % minShapeInfo[i + 1]; + else if (maxIdxs[i + dimsLen] == minShapeInfo[i + 1]) + minIdxs[i] = 0; + else + minIdxs[i] = maxIdxs[i + dimsLen]; + } + } else { + for (int minI = 0, maxI = 0, dim = 0; maxI < maxRank; ++maxI) { + if (dim < dimsLen && dimsToExclude[dim] == maxI) { + ++dim; + continue; } - } -} - - ////////////////////////////////////////////////////////////////////// - INLINEDEF _CUDA_HD Nd4jLong subArrayIndex(const Nd4jLong maxIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude, const int dimsLen) { - int maxIdxs[MAX_RANK]; - shape::index2coords(const_cast(maxIdx), maxShapeInfo, maxIdxs); - - int minIdxs[MAX_RANK]; - maxIndToMinInd(maxIdxs, minIdxs, maxShapeInfo, minShapeInfo, dimsToExclude, dimsLen); - - return shape::coords2index(minShapeInfo, minIdxs); - } - - ////////////////////////////////////////////////////////////////////// - INLINEDEF _CUDA_HD Nd4jLong subArrayOffset(const Nd4jLong maxIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude, const int dimsLen) { - - int maxIdxs[MAX_RANK]; - shape::index2coords(const_cast(maxIdx), maxShapeInfo, maxIdxs); - - int minIdxs[MAX_RANK]; - maxIndToMinInd(maxIdxs, minIdxs, maxShapeInfo, minShapeInfo, dimsToExclude, dimsLen); - - return getOffset(minShapeInfo, minIdxs); + if (maxIdxs[maxI] == minShapeInfo[minI + 1]) + minIdxs[minI] = 0; + else if (maxIdxs[maxI] > minShapeInfo[minI + 1]) + minIdxs[minI] = maxIdxs[maxI] % minShapeInfo[minI + 1]; + else + minIdxs[minI] = maxIdxs[maxI]; + ++minI; + } } + } +} - ////////////////////////////////////////////////////////////////////// - INLINEDEF _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, int* memBuff, const int* dimsToExclude) { - - const auto rankMin = shape::rank(minShapeInfo); - const auto rankMax = shape::rank(maxShapeInfo); - - // if(rankMin >= rankMax) - // throw std::runtime_error("shape::subArrayIndex method: rank of min array should be smaller then rank of max array!"); - - const auto diff = rankMax - rankMin; // the size of dimsToExclude is equal to diff - - int* indices = memBuff; - int* increment = memBuff + rankMax; - - int N, minI, maxI; - - // calculate min per-dim-indices which corresponds to absolute minIdx index - shape::index2coords(minIdx, minShapeInfo, indices); +////////////////////////////////////////////////////////////////////// +INLINEDEF _CUDA_HD Nd4jLong subArrayIndex(const Nd4jLong maxIdx, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude, + const int dimsLen) { + int maxIdxs[MAX_RANK]; + shape::index2coords(const_cast(maxIdx), maxShapeInfo, maxIdxs); + + int minIdxs[MAX_RANK]; + maxIndToMinInd(maxIdxs, minIdxs, maxShapeInfo, minShapeInfo, dimsToExclude, + dimsLen); + + return shape::coords2index(minShapeInfo, minIdxs); +} - // transform storage indices to contain per-dim max indices, purpose - memory saving - // fill increment array as well - if(dimsToExclude == nullptr) { // means dimsToExclude == {0,1,2,...,diff-1} - for(minI = rankMin - 1, maxI = rankMax-1; maxI >= diff; --maxI, --minI) { - increment[maxI] = (maxShapeInfo[maxI+1] == minShapeInfo[minI+1]) ? 0 : minShapeInfo[minI+1]; - indices[maxI] = indices[minI]; - } - for(maxI = 0; maxI < diff; ++maxI) { - increment[maxI] = 1; - indices[maxI] = 0; - } - } - else { - for(N = diff-1, minI = rankMin - 1, maxI = rankMax - 1; maxI >= 0; --maxI) { - if(N >= 0 && dimsToExclude[N] == maxI) { - increment[maxI] = 1; - indices[maxI] = 0; - --N; - } - else { - increment[maxI] = (maxShapeInfo[maxI+1] == minShapeInfo[minI+1]) ? 0 : minShapeInfo[minI+1]; - indices[maxI] = indices[minI--]; - } - } - } +////////////////////////////////////////////////////////////////////// +INLINEDEF _CUDA_HD Nd4jLong subArrayOffset(const Nd4jLong maxIdx, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude, + const int dimsLen) { + int maxIdxs[MAX_RANK]; + shape::index2coords(const_cast(maxIdx), maxShapeInfo, maxIdxs); + + int minIdxs[MAX_RANK]; + maxIndToMinInd(maxIdxs, minIdxs, maxShapeInfo, minShapeInfo, dimsToExclude, + dimsLen); + + return getOffset(minShapeInfo, minIdxs); +} - maxI = rankMax-1; - N = 0; - int step; +////////////////////////////////////////////////////////////////////// +INLINEDEF _CUDA_HD int outerArrayOffsets( + Nd4jLong *maxOffsets, const Nd4jLong minIdx, const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, int *memBuff, const int *dimsToExclude) { + const auto rankMin = shape::rank(minShapeInfo); + const auto rankMax = shape::rank(maxShapeInfo); + + // if(rankMin >= rankMax) + // throw std::runtime_error("shape::subArrayIndex method: rank of min + // array should be smaller then rank of max array!"); + + const auto diff = + rankMax - rankMin; // the size of dimsToExclude is equal to diff + + int *indices = memBuff; + int *increment = memBuff + rankMax; + + int N, minI, maxI; + + // calculate min per-dim-indices which corresponds to absolute minIdx index + shape::index2coords(minIdx, minShapeInfo, indices); + + // transform storage indices to contain per-dim max indices, purpose - memory + // saving fill increment array as well + if (dimsToExclude == nullptr) { // means dimsToExclude == {0,1,2,...,diff-1} + for (minI = rankMin - 1, maxI = rankMax - 1; maxI >= diff; --maxI, --minI) { + increment[maxI] = (maxShapeInfo[maxI + 1] == minShapeInfo[minI + 1]) + ? 0 + : minShapeInfo[minI + 1]; + indices[maxI] = indices[minI]; + } + for (maxI = 0; maxI < diff; ++maxI) { + increment[maxI] = 1; + indices[maxI] = 0; + } + } else { + for (N = diff - 1, minI = rankMin - 1, maxI = rankMax - 1; maxI >= 0; + --maxI) { + if (N >= 0 && dimsToExclude[N] == maxI) { + increment[maxI] = 1; + indices[maxI] = 0; + --N; + } else { + increment[maxI] = (maxShapeInfo[maxI + 1] == minShapeInfo[minI + 1]) + ? 0 + : minShapeInfo[minI + 1]; + indices[maxI] = indices[minI--]; + } + } + } + + maxI = rankMax - 1; + N = 0; + int step; + maxOffsets[N++] = shape::getOffset(maxShapeInfo, indices); + + // nested loops - producing of absolute indices for max array + while (maxI >= 0) { + if (increment[maxI] != 0) { + indices[maxI] += increment[maxI]; + if (indices[maxI] >= maxShapeInfo[maxI + 1]) { + indices[maxI] %= + increment[maxI]; // restore initial value of indices[maxI] + step = -1; + } else { maxOffsets[N++] = shape::getOffset(maxShapeInfo, indices); - - // nested loops - producing of absolute indices for max array - while(maxI >= 0) { - - if(increment[maxI] != 0) { - - indices[maxI] += increment[maxI]; - if(indices[maxI] >= maxShapeInfo[maxI+1]) { - indices[maxI] %= increment[maxI]; // restore initial value of indices[maxI] - step = -1; - } - else { - maxOffsets[N++] = shape::getOffset(maxShapeInfo, indices); - step = rankMax - 1 - maxI; - } - } - else if(maxI == rankMax - 1) - step = -1; - - maxI += step; - } - return N; - } - - ////////////////////////////////////////////////////////////////////// - INLINEDEF _CUDA_HD int outerArrayIndexes(int* maxIdxs, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude) { - - const auto rankMin = shape::rank(minShapeInfo); - const auto rankMax = shape::rank(maxShapeInfo); - - // if(rankMin >= rankMax) - // throw std::runtime_error("shape::subArrayIndex method: rank of min array should be smaller then rank of max array!"); - // if(rankMax > MAX_RANK/2) - // throw std::runtime_error("shape::subArrayIndex method: rank of max array should be <= MAX_RANK/2 !"); - - const auto diff = rankMax - rankMin; // the size of dimsToExclude is equal to diff - - int indices[MAX_RANK], increment[MAX_RANK]; - - int N, minI, maxI; - - // calculate min per-dim-indices which corresponds to absolute minIdx index - shape::index2coords(minIdx, minShapeInfo, indices); - - // transform storage indices to contain per-dim max indices, purpose - memory saving - // fill increment array as well - if(dimsToExclude == nullptr) { // means dimsToExclude == {0,1,2,...,diff-1} - for(minI = rankMin - 1, maxI = rankMax-1; maxI >= diff; --maxI, --minI) { - increment[maxI] = (maxShapeInfo[maxI+1] == minShapeInfo[minI+1]) ? 0 : minShapeInfo[minI+1]; - indices[maxI] = indices[minI]; - } - for(maxI = 0; maxI < diff; ++maxI) { - increment[maxI] = 1; - indices[maxI] = 0; - } - } - else { - for(N = diff-1, minI = rankMin - 1, maxI = rankMax - 1; maxI >= 0; --maxI) { - if(N >= 0 && dimsToExclude[N] == maxI) { - increment[maxI] = 1; - indices[maxI] = 0; - --N; - } - else { - increment[maxI] = (maxShapeInfo[maxI+1] == minShapeInfo[minI+1]) ? 0 : minShapeInfo[minI+1]; - indices[maxI] = indices[minI--]; - } - } - } - - maxI = rankMax-1; - N = 0; - int step; - maxIdxs[N++] = shape::coords2index(maxShapeInfo, indices); - - // nested loops - producing of absolute indices for max array - while(maxI >= 0) { - - if(increment[maxI] != 0) { - - indices[maxI] += increment[maxI]; - if(indices[maxI] >= maxShapeInfo[maxI+1]) { - indices[maxI] %= increment[maxI]; // restore initial value of indices[maxI] - step = -1; - } - else { - maxIdxs[N++] = shape::coords2index(maxShapeInfo, indices); - step = rankMax - 1 - maxI; - } - } - else if(maxI == rankMax - 1) - step = -1; - - maxI += step; - } - return N; - } - - INLINEDEF _CUDA_HD void shapeOldScalar(sd::DataType dataType, Nd4jLong* const buffer, const char order) { - - buffer[0] = 2; - buffer[1] = 1; - buffer[2] = 1; - buffer[3] = 1; - buffer[4] = 1; - buffer[6] = 1; - buffer[7] = (int)order; - - sd::ArrayOptions::setDataType(buffer, dataType); - } - - template - INLINEDEF _CUDA_H void convertT(T1 *from, T2 *to, Nd4jLong length) { - for (Nd4jLong e = 0; e < length; e++) - to[e] = (T2) from[e]; - }; + step = rankMax - 1 - maxI; + } + } else if (maxI == rankMax - 1) + step = -1; + + maxI += step; + } + return N; +} ////////////////////////////////////////////////////////////////////// -INLINEDEF void calcOffsets(const Nd4jLong* shapeInfo, Nd4jLong* offsets, const char order) { - - // firstly consider simple case when ews > 0 - const Nd4jLong ews = shape::elementWiseStride(shapeInfo); - - if(ews > 0) { - - // set offset for first sub-array, it is equal to zero always - offsets[0] = 0; - - Nd4jLong e = 0; - if(order != shape::order(shapeInfo)) - for(int i = 1; i <= shape::rank(shapeInfo); ++i) - if(shapeInfo[i] != 1) - ++e; //check whether input is CommonVector - - if(order == shape::order(shapeInfo) || e == 1) { // e==1 means common vector - e = 1; - Nd4jLong len = shape::length(shapeInfo); - while(e < len) { - offsets[e] = offsets[e - 1] + ews; - e++; - } - return; - } - } +INLINEDEF _CUDA_HD int outerArrayIndexes(int *maxIdxs, const Nd4jLong minIdx, + const Nd4jLong *maxShapeInfo, + const Nd4jLong *minShapeInfo, + const int *dimsToExclude) { + const auto rankMin = shape::rank(minShapeInfo); + const auto rankMax = shape::rank(maxShapeInfo); + + // if(rankMin >= rankMax) + // throw std::runtime_error("shape::subArrayIndex method: rank of min + // array should be smaller then rank of max array!"); + // if(rankMax > MAX_RANK/2) + // throw std::runtime_error("shape::subArrayIndex method: rank of max + // array should be <= MAX_RANK/2 !"); + + const auto diff = + rankMax - rankMin; // the size of dimsToExclude is equal to diff + + int indices[MAX_RANK], increment[MAX_RANK]; + + int N, minI, maxI; + + // calculate min per-dim-indices which corresponds to absolute minIdx index + shape::index2coords(minIdx, minShapeInfo, indices); + + // transform storage indices to contain per-dim max indices, purpose - memory + // saving fill increment array as well + if (dimsToExclude == nullptr) { // means dimsToExclude == {0,1,2,...,diff-1} + for (minI = rankMin - 1, maxI = rankMax - 1; maxI >= diff; --maxI, --minI) { + increment[maxI] = (maxShapeInfo[maxI + 1] == minShapeInfo[minI + 1]) + ? 0 + : minShapeInfo[minI + 1]; + indices[maxI] = indices[minI]; + } + for (maxI = 0; maxI < diff; ++maxI) { + increment[maxI] = 1; + indices[maxI] = 0; + } + } else { + for (N = diff - 1, minI = rankMin - 1, maxI = rankMax - 1; maxI >= 0; + --maxI) { + if (N >= 0 && dimsToExclude[N] == maxI) { + increment[maxI] = 1; + indices[maxI] = 0; + --N; + } else { + increment[maxI] = (maxShapeInfo[maxI + 1] == minShapeInfo[minI + 1]) + ? 0 + : minShapeInfo[minI + 1]; + indices[maxI] = indices[minI--]; + } + } + } + + maxI = rankMax - 1; + N = 0; + int step; + maxIdxs[N++] = shape::coords2index(maxShapeInfo, indices); + + // nested loops - producing of absolute indices for max array + while (maxI >= 0) { + if (increment[maxI] != 0) { + indices[maxI] += increment[maxI]; + if (indices[maxI] >= maxShapeInfo[maxI + 1]) { + indices[maxI] %= + increment[maxI]; // restore initial value of indices[maxI] + step = -1; + } else { + maxIdxs[N++] = shape::coords2index(maxShapeInfo, indices); + step = rankMax - 1 - maxI; + } + } else if (maxI == rankMax - 1) + step = -1; + + maxI += step; + } + return N; +} - shape::calcOffsets(shape::rank(shapeInfo), shape::shapeOf(const_cast(shapeInfo)), shape::stride(const_cast(shapeInfo)), offsets, order); +INLINEDEF _CUDA_HD void shapeOldScalar(sd::DataType dataType, + Nd4jLong *const buffer, + const char order) { + buffer[0] = 2; + buffer[1] = 1; + buffer[2] = 1; + buffer[3] = 1; + buffer[4] = 1; + buffer[6] = 1; + buffer[7] = (int)order; + + sd::ArrayOptions::setDataType(buffer, dataType); } +template +INLINEDEF _CUDA_H void convertT(T1 *from, T2 *to, Nd4jLong length) { + for (Nd4jLong e = 0; e < length; e++) to[e] = (T2)from[e]; +}; + ////////////////////////////////////////////////////////////////////// -INLINEDEF void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong* strides, Nd4jLong* offsets, const char order) { - - // if(false) { // tests showed that this code did calculation notably slower even for big N - // Nd4jLong indexes[MAX_RANK]; - // PRAGMA_OMP_PARALLEL_FOR_ARGS(private(indexes)) - // for (Nd4jLong i = 0; i < N; ++i) { - // shape::index2coords(rank, shape, i, indexes); - // subArrOffsets[i] = 0; - // for (int j = 0; j < rank; ++j) - // if(shape[j] != 1) - // subArrOffsets[i] += indexes[j] * strides[j]; - // } - // return; - // } +INLINEDEF void calcOffsets(const Nd4jLong *shapeInfo, Nd4jLong *offsets, + const char order) { + // firstly consider simple case when ews > 0 + const Nd4jLong ews = shape::elementWiseStride(shapeInfo); + if (ews > 0) { // set offset for first sub-array, it is equal to zero always offsets[0] = 0; - Nd4jLong * idx = new Nd4jLong[rank]; - Nd4jLong* offsetPerDim = new Nd4jLong[rank]; - memset(idx, 0, sizeof(Nd4jLong) * rank); - - PRAGMA_OMP_SIMD - for (int k = 0; k < rank; ++k) - offsetPerDim[k] = (shape[k] - 1) * strides[k]; - - Nd4jLong init = 0, i = 1; - // nested loops - calculation of sub-array offsets - if(order == 'c') { - - Nd4jLong rankMinusOne = rank - 1, j = rankMinusOne; - - while(j >= 0) { - - if(shape[j] == 1) { --j; continue; } // ignore dimensions equal to unity - - if(j == rankMinusOne) { // last dimension - for(int l = 1; l < shape[j]; ++l) { - offsets[i] = offsets[i - 1] + strides[j]; - i++; - } - --j; - } - else if(idx[j] < shape[j] - 1) { - init += strides[j]; - offsets[i++] = init; - ++idx[j]; - j = rankMinusOne; - } - else { - init -= offsetPerDim[j]; - idx[j--] = 0; - } - } - } - else { + Nd4jLong e = 0; + if (order != shape::order(shapeInfo)) + for (int i = 1; i <= shape::rank(shapeInfo); ++i) + if (shapeInfo[i] != 1) ++e; // check whether input is CommonVector + + if (order == shape::order(shapeInfo) || + e == 1) { // e==1 means common vector + e = 1; + Nd4jLong len = shape::length(shapeInfo); + while (e < len) { + offsets[e] = offsets[e - 1] + ews; + e++; + } + return; + } + } + + shape::calcOffsets( + shape::rank(shapeInfo), shape::shapeOf(const_cast(shapeInfo)), + shape::stride(const_cast(shapeInfo)), offsets, order); +} - Nd4jLong j = 0; - - while(j < rank) { - - if(shape[j] == 1) { ++j; continue; } // ignore dimensions equal to unity - - if(j == 0) { // last dimension - for(int l = 1; l < shape[j]; ++l) { - offsets[i] = offsets[i - 1] + strides[j]; - i++; - } - ++j; - } - else if(idx[j] < shape[j] - 1) { - init += strides[j]; - offsets[i++] = init; - ++idx[j]; - j = 0; - } - else { - init -= offsetPerDim[j]; - idx[j++] = 0; - } +////////////////////////////////////////////////////////////////////// +INLINEDEF void calcOffsets(const int rank, const Nd4jLong *shape, + const Nd4jLong *strides, Nd4jLong *offsets, + const char order) { + // if(false) { // tests showed that this code did + // calculation notably slower even for big N + // Nd4jLong indexes[MAX_RANK]; + // PRAGMA_OMP_PARALLEL_FOR_ARGS(private(indexes)) + // for (Nd4jLong i = 0; i < N; ++i) { + // shape::index2coords(rank, shape, i, indexes); + // subArrOffsets[i] = 0; + // for (int j = 0; j < rank; ++j) + // if(shape[j] != 1) + // subArrOffsets[i] += indexes[j] * strides[j]; + // } + // return; + // } + + // set offset for first sub-array, it is equal to zero always + offsets[0] = 0; + + Nd4jLong *idx = new Nd4jLong[rank]; + Nd4jLong *offsetPerDim = new Nd4jLong[rank]; + memset(idx, 0, sizeof(Nd4jLong) * rank); + + PRAGMA_OMP_SIMD + for (int k = 0; k < rank; ++k) offsetPerDim[k] = (shape[k] - 1) * strides[k]; + + Nd4jLong init = 0, i = 1; + // nested loops - calculation of sub-array offsets + if (order == 'c') { + Nd4jLong rankMinusOne = rank - 1, j = rankMinusOne; + + while (j >= 0) { + if (shape[j] == 1) { + --j; + continue; + } // ignore dimensions equal to unity + + if (j == rankMinusOne) { // last dimension + for (int l = 1; l < shape[j]; ++l) { + offsets[i] = offsets[i - 1] + strides[j]; + i++; } - } - - delete []idx; - delete []offsetPerDim; + --j; + } else if (idx[j] < shape[j] - 1) { + init += strides[j]; + offsets[i++] = init; + ++idx[j]; + j = rankMinusOne; + } else { + init -= offsetPerDim[j]; + idx[j--] = 0; + } + } + } else { + Nd4jLong j = 0; + + while (j < rank) { + if (shape[j] == 1) { + ++j; + continue; + } // ignore dimensions equal to unity + + if (j == 0) { // last dimension + for (int l = 1; l < shape[j]; ++l) { + offsets[i] = offsets[i - 1] + strides[j]; + i++; + } + ++j; + } else if (idx[j] < shape[j] - 1) { + init += strides[j]; + offsets[i++] = init; + ++idx[j]; + j = 0; + } else { + init -= offsetPerDim[j]; + idx[j++] = 0; + } + } + } + + delete[] idx; + delete[] offsetPerDim; } ////////////////////////////////////////////////////////////////////// -INLINEDEF void _CUDA_HD checkStridesEwsAndOrder(Nd4jLong* shapeInfo) { - - // FIXME - indeed we don't need to allocate so large memory amount (2*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities) - Nd4jLong tempBuffer[2*MAX_RANK]; - Nd4jLong *shape = tempBuffer, *strides; - - // exclude unities from shapeInfo - const int numOfNonUnities = shape::excludeUnitiesFromShapeInfo(shapeInfo, shape, strides); - - shape::checkStridesEwsAndOrder(shapeInfo, shape::order(shapeInfo), numOfNonUnities, shape, strides); +INLINEDEF void _CUDA_HD checkStridesEwsAndOrder(Nd4jLong *shapeInfo) { + // FIXME - indeed we don't need to allocate so large memory amount + // (2*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + + // 2*newNumOfNonUnities) + Nd4jLong tempBuffer[2 * MAX_RANK]; + Nd4jLong *shape = tempBuffer, *strides; + + // exclude unities from shapeInfo + const int numOfNonUnities = + shape::excludeUnitiesFromShapeInfo(shapeInfo, shape, strides); + + shape::checkStridesEwsAndOrder(shapeInfo, shape::order(shapeInfo), + numOfNonUnities, shape, strides); } ////////////////////////////////////////////////////////////////////// -INLINEDEF void _CUDA_HD checkStridesEwsAndOrder(Nd4jLong* shapeInfo, const char proposedOrder, const int numOfNonUnities, const Nd4jLong* shapeNoUnities, const Nd4jLong* stridesNoUnities) { +INLINEDEF void _CUDA_HD checkStridesEwsAndOrder( + Nd4jLong *shapeInfo, const char proposedOrder, const int numOfNonUnities, + const Nd4jLong *shapeNoUnities, const Nd4jLong *stridesNoUnities) { + const int rank = shape::rank(shapeInfo); - const int rank = shape::rank(shapeInfo); - - if(shape::length(shapeInfo) == 1) { - *shape::ews(shapeInfo) = 1; - shapeInfo[rank * 2 + 3] = (int)proposedOrder; - return; - } + if (shape::length(shapeInfo) == 1) { + *shape::ews(shapeInfo) = 1; + shapeInfo[rank * 2 + 3] = (int)proposedOrder; + return; + } - if(numOfNonUnities == 1) { // case of common vector - *shape::ews(shapeInfo) = *stridesNoUnities; - shapeInfo[rank * 2 + 3] = (int)proposedOrder; - return; - } + if (numOfNonUnities == 1) { // case of common vector + *shape::ews(shapeInfo) = *stridesNoUnities; + shapeInfo[rank * 2 + 3] = (int)proposedOrder; + return; + } - bool contiguous = true; + bool contiguous = true; - //*** check whether strides are in c contiguous order ***// - for (uint i = 0; i < numOfNonUnities - 1; ++i) { - if(stridesNoUnities[i] != shapeNoUnities[i + 1] * stridesNoUnities[i + 1]) { - contiguous = false; - break; - } + //*** check whether strides are in c contiguous order ***// + for (uint i = 0; i < numOfNonUnities - 1; ++i) { + if (stridesNoUnities[i] != + shapeNoUnities[i + 1] * stridesNoUnities[i + 1]) { + contiguous = false; + break; } + } - if(contiguous) { - - *shape::ews(shapeInfo) = stridesNoUnities[numOfNonUnities - 1]; - shapeInfo[rank * 2 + 3] = 99; - return; - } + if (contiguous) { + *shape::ews(shapeInfo) = stridesNoUnities[numOfNonUnities - 1]; + shapeInfo[rank * 2 + 3] = 99; + return; + } - contiguous = true; + contiguous = true; - //*** check whether strides are in f contiguous order ***// - for (uint i = 1; i < numOfNonUnities; ++i) { - if(stridesNoUnities[i] != shapeNoUnities[i - 1] * stridesNoUnities[i - 1]) { - contiguous = false; - break; - } + //*** check whether strides are in f contiguous order ***// + for (uint i = 1; i < numOfNonUnities; ++i) { + if (stridesNoUnities[i] != + shapeNoUnities[i - 1] * stridesNoUnities[i - 1]) { + contiguous = false; + break; } + } - if(contiguous) { - - *shape::ews(shapeInfo) = stridesNoUnities[0]; - shapeInfo[rank * 2 + 3] = 102; - return; - } + if (contiguous) { + *shape::ews(shapeInfo) = stridesNoUnities[0]; + shapeInfo[rank * 2 + 3] = 102; + return; + } - *shape::ews(shapeInfo) = 0; - shapeInfo[rank * 2 + 3] = (int)proposedOrder; + *shape::ews(shapeInfo) = 0; + shapeInfo[rank * 2 + 3] = (int)proposedOrder; } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD void calcSubArrsShapeInfoAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape) { - - const int rank = shape::rank(wholeShapeInfo); - - if(dimsSize == rank || dimsSize == 0) { // means there is one sub-array and it coincides with whole array, return copy of wholeShapeInfo and one zero offset in this case - memcpy(subArrShapeInfo, wholeShapeInfo, shape::shapeInfoLength(rank) * sizeof(Nd4jLong)); - *subArrOffsets = 0; - return; - } - - const int subArrRank = keepUnitiesInShape ? rank : rank - dimsSize; - - subArrShapeInfo[0] = subArrRank; // rank - sd::ArrayOptions::copyDataType(subArrShapeInfo, wholeShapeInfo); // type - subArrShapeInfo[2 * subArrRank + 3] = shape::order(wholeShapeInfo); // order - - Nd4jLong* shape = new Nd4jLong[dimsSize]; - Nd4jLong* strides = new Nd4jLong[dimsSize]; - - for(int k = subArrRank - 1, j = dimsSize - 1, i = rank - 1; i >= 0; --i) { - - if(j >= 0 && i == dimsToExclude[j]) { - - strides[j] = shape::stride(wholeShapeInfo)[i]; - shape[j--] = shape::shapeOf(wholeShapeInfo)[i]; - - if(keepUnitiesInShape) { - shape::shapeOf(subArrShapeInfo)[k] = 1; - shape::stride(subArrShapeInfo)[k--] = shape::stride(wholeShapeInfo)[i]; - } - } - else { - shape::shapeOf(subArrShapeInfo)[k] = shape::shapeOf(wholeShapeInfo)[i]; - shape::stride(subArrShapeInfo)[k--] = shape::stride(wholeShapeInfo)[i]; - } - - } - - // calculation of sub-array offsets (subArrOffsets) - shape::calcOffsets(dimsSize, shape, strides, subArrOffsets); - - // evaluate ews - shape::checkStridesEwsAndOrder(subArrShapeInfo); - - delete []strides; - delete []shape; +INLINEDEF _CUDA_HD void calcSubArrsShapeInfoAndOffsets( + const Nd4jLong *wholeShapeInfo, const Nd4jLong numOfSubArrs, + const int dimsSize, const int *dimsToExclude, Nd4jLong *subArrShapeInfo, + Nd4jLong *subArrOffsets, bool keepUnitiesInShape) { + const int rank = shape::rank(wholeShapeInfo); + + if (dimsSize == rank || + dimsSize == 0) { // means there is one sub-array and it coincides with + // whole array, return copy of wholeShapeInfo and one + // zero offset in this case + memcpy(subArrShapeInfo, wholeShapeInfo, + shape::shapeInfoLength(rank) * sizeof(Nd4jLong)); + *subArrOffsets = 0; + return; + } + + const int subArrRank = keepUnitiesInShape ? rank : rank - dimsSize; + + subArrShapeInfo[0] = subArrRank; // rank + sd::ArrayOptions::copyDataType(subArrShapeInfo, wholeShapeInfo); // type + subArrShapeInfo[2 * subArrRank + 3] = shape::order(wholeShapeInfo); // order + + Nd4jLong *shape = new Nd4jLong[dimsSize]; + Nd4jLong *strides = new Nd4jLong[dimsSize]; + + for (int k = subArrRank - 1, j = dimsSize - 1, i = rank - 1; i >= 0; --i) { + if (j >= 0 && i == dimsToExclude[j]) { + strides[j] = shape::stride(wholeShapeInfo)[i]; + shape[j--] = shape::shapeOf(wholeShapeInfo)[i]; + + if (keepUnitiesInShape) { + shape::shapeOf(subArrShapeInfo)[k] = 1; + shape::stride(subArrShapeInfo)[k--] = shape::stride(wholeShapeInfo)[i]; + } + } else { + shape::shapeOf(subArrShapeInfo)[k] = shape::shapeOf(wholeShapeInfo)[i]; + shape::stride(subArrShapeInfo)[k--] = shape::stride(wholeShapeInfo)[i]; + } + } + + // calculation of sub-array offsets (subArrOffsets) + shape::calcOffsets(dimsSize, shape, strides, subArrOffsets); + + // evaluate ews + shape::checkStridesEwsAndOrder(subArrShapeInfo); + + delete[] strides; + delete[] shape; } ////////////////////////////////////////////////////////////////////// -INLINEDEF void calcSubArrShapeInfoAndOffset(const Nd4jLong* idx, const Nd4jLong* maxShapeInfo, Nd4jLong* minShapeInfo, Nd4jLong& minOffset, const bool keepUnitiesInShape, const bool isStrided, const int numOfUntiesInMinShape) { - - const uint maxRank = shape::rank(maxShapeInfo); - minOffset = 0; - uint first, last, stride, n(isStrided ? 3 : 2); - - minShapeInfo[0] = keepUnitiesInShape ? maxRank : maxRank - numOfUntiesInMinShape; - - for (uint step = 0, j = 0, i = 0; i < maxRank; ++i, step += n) { - - if (idx[step] == idx[step + 1]) { // means whole dimension - shape::shapeOf(minShapeInfo)[j] = shape::shapeOf(maxShapeInfo)[i]; - shape::stride(minShapeInfo)[j++] = shape::stride(maxShapeInfo)[i]; - } - else { - - first = idx[step] >= 0 ? idx[step] : idx[step] + shape::sizeAt(maxShapeInfo, i) + 1; - last = idx[step + 1] >= 0 ? idx[step + 1] : idx[step + 1] + shape::sizeAt(maxShapeInfo, i) + 1; - - if(last < first) - throw("shape::calcSubArrShapeInfoAndOffset: negative range in input indexes is found!"); - - if(isStrided) { - stride = idx[step + 2]; - last /*resulting sub-array axis*/ = (last - first + stride - 1) / stride; // ceil (last - first) / stride; - } - else { - stride = 1; - last /*resulting sub-array axis*/ = last - first; - } - - minOffset += first * shape::stride(maxShapeInfo)[i]; - - if(!keepUnitiesInShape && last == 1) - continue; - - shape::shapeOf(minShapeInfo)[j] = last; - shape::stride(minShapeInfo)[j++] = last == 1 ? shape::stride(maxShapeInfo)[i] : shape::stride(maxShapeInfo)[i] * stride; - } - } - - minShapeInfo[2 * shape::rank(minShapeInfo) + 3] = shape::order(maxShapeInfo); // order - sd::ArrayOptions::copyDataType(minShapeInfo, maxShapeInfo); // type - - shape::checkStridesEwsAndOrder(minShapeInfo); +INLINEDEF void calcSubArrShapeInfoAndOffset( + const Nd4jLong *idx, const Nd4jLong *maxShapeInfo, Nd4jLong *minShapeInfo, + Nd4jLong &minOffset, const bool keepUnitiesInShape, const bool isStrided, + const int numOfUntiesInMinShape) { + const uint maxRank = shape::rank(maxShapeInfo); + minOffset = 0; + uint first, last, stride, n(isStrided ? 3 : 2); + + minShapeInfo[0] = + keepUnitiesInShape ? maxRank : maxRank - numOfUntiesInMinShape; + + for (uint step = 0, j = 0, i = 0; i < maxRank; ++i, step += n) { + if (idx[step] == idx[step + 1]) { // means whole dimension + shape::shapeOf(minShapeInfo)[j] = shape::shapeOf(maxShapeInfo)[i]; + shape::stride(minShapeInfo)[j++] = shape::stride(maxShapeInfo)[i]; + } else { + first = idx[step] >= 0 ? idx[step] + : idx[step] + shape::sizeAt(maxShapeInfo, i) + 1; + last = idx[step + 1] >= 0 + ? idx[step + 1] + : idx[step + 1] + shape::sizeAt(maxShapeInfo, i) + 1; + + if (last < first) + throw( + "shape::calcSubArrShapeInfoAndOffset: negative range in input " + "indexes is found!"); + + if (isStrided) { + stride = idx[step + 2]; + last /*resulting sub-array axis*/ = + (last - first + stride - 1) / + stride; // ceil (last - first) / stride; + } else { + stride = 1; + last /*resulting sub-array axis*/ = last - first; + } + + minOffset += first * shape::stride(maxShapeInfo)[i]; + + if (!keepUnitiesInShape && last == 1) continue; + + shape::shapeOf(minShapeInfo)[j] = last; + shape::stride(minShapeInfo)[j++] = + last == 1 ? shape::stride(maxShapeInfo)[i] + : shape::stride(maxShapeInfo)[i] * stride; + } + } + + minShapeInfo[2 * shape::rank(minShapeInfo) + 3] = + shape::order(maxShapeInfo); // order + sd::ArrayOptions::copyDataType(minShapeInfo, maxShapeInfo); // type + + shape::checkStridesEwsAndOrder(minShapeInfo); } ////////////////////////////////////////////////////////////////////// -INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords) { - - for(uint i = shapeInfo[0]; i > 1; --i) { - coords[i - 1] = index % shapeInfo[i]; - index /= shapeInfo[i]; - } - coords[0] = index; // last iteration +INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + Nd4jLong *coords) { + for (uint i = shapeInfo[0]; i > 1; --i) { + coords[i - 1] = index % shapeInfo[i]; + index /= shapeInfo[i]; + } + coords[0] = index; // last iteration } ////////////////////////////////////////////////////////////////////// -INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords) { - - for(uint i = shapeInfo[0]; i > 1; --i) { - coords[i - 1] = static_cast(index) % static_cast(shapeInfo[i]); - index /= static_cast(shapeInfo[i]); - } - coords[0] = static_cast(index); // last iteration +INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + int *coords) { + for (uint i = shapeInfo[0]; i > 1; --i) { + coords[i - 1] = static_cast(index) % static_cast(shapeInfo[i]); + index /= static_cast(shapeInfo[i]); + } + coords[0] = static_cast(index); // last iteration } ////////////////////////////////////////////////////////////////////// -INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords) { - - for(uint i = shapeInfo[0]; i > 1; --i) { - coords[i - 1] = static_cast(index) % static_cast(shapeInfo[i]); - index /= static_cast(shapeInfo[i]); - } - coords[0] = static_cast(index); // last iteration +INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + uint *coords) { + for (uint i = shapeInfo[0]; i > 1; --i) { + coords[i - 1] = static_cast(index) % static_cast(shapeInfo[i]); + index /= static_cast(shapeInfo[i]); + } + coords[0] = static_cast(index); // last iteration } ////////////////////////////////////////////////////////////////////// -INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords) { - - for(uint i = rank - 1; i > 0; --i) { - coords[i] = index % shape[i]; - index /= shape[i]; - } - coords[0] = index; // last iteration +INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, + const Nd4jLong *shape, Nd4jLong *coords) { + for (uint i = rank - 1; i > 0; --i) { + coords[i] = index % shape[i]; + index /= shape[i]; + } + coords[0] = index; // last iteration } ////////////////////////////////////////////////////////////////////// -INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, int *coords) { - - for(uint i = rank - 1; i > 0; --i) { - coords[i] = index % shape[i]; - index /= shape[i]; - } - coords[0] = index; // last iteration +INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, + const Nd4jLong *shape, int *coords) { + for (uint i = rank - 1; i > 0; --i) { + coords[i] = index % shape[i]; + index /= shape[i]; + } + coords[0] = index; // last iteration } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords, const int dimsSize, const int* tadDims) { - - for(uint i = dimsSize - 1; i > 0; --i) { - coords[tadDims[i]] = index % shapeInfo[1 + tadDims[i]]; - index /= shapeInfo[1 + tadDims[i]]; - } - coords[tadDims[0]] = index; // last iteration +INLINEDEF _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, + int *coords, const int dimsSize, + const int *tadDims) { + for (uint i = dimsSize - 1; i > 0; --i) { + coords[tadDims[i]] = index % shapeInfo[1 + tadDims[i]]; + index /= shapeInfo[1 + tadDims[i]]; + } + coords[tadDims[0]] = index; // last iteration } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, Nd4jLong *coords) { - - if(startIndex == index) { - shape::index2coords(index, shapeInfo, coords); - } - else { - int axis = shapeInfo[0] - 1; - while(coords[axis] == shape::sizeAt(shapeInfo, axis) - 1) - coords[axis--] = 0; - ++coords[axis]; - } +INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong &startIndex, + const Nd4jLong &index, + const Nd4jLong *shapeInfo, + Nd4jLong *coords) { + if (startIndex == index) { + shape::index2coords(index, shapeInfo, coords); + } else { + int axis = shapeInfo[0] - 1; + while (coords[axis] == shape::sizeAt(shapeInfo, axis) - 1) + coords[axis--] = 0; + ++coords[axis]; + } } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, int *coords) { - - if(startIndex == index) { - shape::index2coords(index, shapeInfo, coords); - } - else { - int axis = shapeInfo[0] - 1; - while(coords[axis] == shape::sizeAt(shapeInfo, axis) - 1) - coords[axis--] = 0; - ++coords[axis]; - } +INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong &startIndex, + const Nd4jLong &index, + const Nd4jLong *shapeInfo, + int *coords) { + if (startIndex == index) { + shape::index2coords(index, shapeInfo, coords); + } else { + int axis = shapeInfo[0] - 1; + while (coords[axis] == shape::sizeAt(shapeInfo, axis) - 1) + coords[axis--] = 0; + ++coords[axis]; + } } ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) { +// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& +// xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* +// zShapeInfo, Nd4jLong*& zOffsets, const char order) { // // we assume all array have same length // const Nd4jLong len = shape::length(xShapeInfo); @@ -4935,22 +5156,27 @@ INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLo // const char yOrder = shape::order(yShapeInfo); // const char zOrder = shape::order(zShapeInfo); -// const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo); +// const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, +// zShapeInfo); -// if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (xOrder == 'c' || shapesSame)) { +// if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == +// zOrder && (xOrder == 'c' || shapesSame)) { // xOffsets = yOffsets = zOffsets = nullptr; // } -// else if(xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || shape::shapeEquals(xShapeInfo, yShapeInfo))) { +// else if(xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || +// shape::shapeEquals(xShapeInfo, yShapeInfo))) { // xOffsets = yOffsets = nullptr; // zOffsets = new Nd4jLong[len]; // shape::calcOffsets(zShapeInfo, zOffsets, xOrder); // } -// else if(xEws == 1 && zEws == 1 && xOrder == zOrder && (xOrder == 'c' || shape::shapeEquals(xShapeInfo, zShapeInfo))) { +// else if(xEws == 1 && zEws == 1 && xOrder == zOrder && (xOrder == 'c' || +// shape::shapeEquals(xShapeInfo, zShapeInfo))) { // xOffsets = zOffsets = nullptr; // yOffsets = new Nd4jLong[len]; // shape::calcOffsets(yShapeInfo, yOffsets, xOrder); // } -// else if(yEws == 1 && zEws == 1 && yOrder == zOrder && (yOrder == 'c' || shape::shapeEquals(yShapeInfo, zShapeInfo))) { +// else if(yEws == 1 && zEws == 1 && yOrder == zOrder && (yOrder == 'c' || +// shape::shapeEquals(yShapeInfo, zShapeInfo))) { // yOffsets = zOffsets = nullptr; // xOffsets = new Nd4jLong[len]; // shape::calcOffsets(xShapeInfo, xOffsets, yOrder); @@ -5003,7 +5229,8 @@ INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLo // } // } // } -// else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo, zShapeInfo)) { +// else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo, +// zShapeInfo)) { // xOffsets = new Nd4jLong[len]; // shape::calcOffsets(xShapeInfo, xOffsets); // yOffsets = zOffsets = xOffsets; @@ -5063,7 +5290,9 @@ INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLo // } ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order) { +// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& +// xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order) +// { // // we assume all array have same length // const Nd4jLong len = shape::length(xShapeInfo); @@ -5076,7 +5305,8 @@ INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLo // const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo); -// if (xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || shapesSame)) { +// if (xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || +// shapesSame)) { // xOffsets = yOffsets = nullptr; // } // else if(xEws == 1) { @@ -5112,52 +5342,56 @@ INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLo // } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities) { - - const int rank = shape::rank(inShapeInfo); - const int numOfNonUnities = shape::numOfNonUnitDims(rank, shape::shapeOf(inShapeInfo)); - - if(numOfNonUnities == rank) { // no unities in shape, no copy procedure - shapeNoUnities = const_cast(inShapeInfo) + 1; - stridesNoUnities = const_cast(inShapeInfo) + 1 + rank; - return numOfNonUnities; - } +INLINEDEF _CUDA_HD int excludeUnitiesFromShapeInfo( + const Nd4jLong *inShapeInfo, Nd4jLong *&shapeNoUnities, + Nd4jLong *&stridesNoUnities) { + const int rank = shape::rank(inShapeInfo); + const int numOfNonUnities = + shape::numOfNonUnitDims(rank, shape::shapeOf(inShapeInfo)); + + if (numOfNonUnities == rank) { // no unities in shape, no copy procedure + shapeNoUnities = const_cast(inShapeInfo) + 1; + stridesNoUnities = const_cast(inShapeInfo) + 1 + rank; + return numOfNonUnities; + } - for(uint j = 0, i = 0; i < rank; ++i) { - if(shape::shapeOf(inShapeInfo)[i] != 1) { - shapeNoUnities[j] = shape::shapeOf(inShapeInfo)[i]; - shapeNoUnities[numOfNonUnities + j++] = shape::stride(inShapeInfo)[i]; - } + for (uint j = 0, i = 0; i < rank; ++i) { + if (shape::shapeOf(inShapeInfo)[i] != 1) { + shapeNoUnities[j] = shape::shapeOf(inShapeInfo)[i]; + shapeNoUnities[numOfNonUnities + j++] = shape::stride(inShapeInfo)[i]; } + } - stridesNoUnities = shapeNoUnities + numOfNonUnities; + stridesNoUnities = shapeNoUnities + numOfNonUnities; - return numOfNonUnities; + return numOfNonUnities; } ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, Nd4jLong* outShapeInfo) { +INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong *inShapeInfo, + const int dimsSize, + const int *dimsToExclude, + Nd4jLong *outShapeInfo) { + outShapeInfo[0] = inShapeInfo[0] - dimsSize; - outShapeInfo[0] = inShapeInfo[0] - dimsSize; - - for(uint j = 0, k = 0, i = 0; i < inShapeInfo[0]; ++i) { - if(j < dimsSize && i == dimsToExclude[j]) { - ++j; - continue; - } - - shape::shapeOf(outShapeInfo)[k] = shape::shapeOf(inShapeInfo)[i]; - shape::stride(outShapeInfo)[k++] = shape::stride(inShapeInfo)[i]; + for (uint j = 0, k = 0, i = 0; i < inShapeInfo[0]; ++i) { + if (j < dimsSize && i == dimsToExclude[j]) { + ++j; + continue; } - sd::ArrayOptions::copyDataType(outShapeInfo, inShapeInfo); // type - *shape::ews(outShapeInfo) = shape::elementWiseStride(inShapeInfo); // ews - outShapeInfo[2 * outShapeInfo[0] + 3] = shape::order(inShapeInfo); // order -} + shape::shapeOf(outShapeInfo)[k] = shape::shapeOf(inShapeInfo)[i]; + shape::stride(outShapeInfo)[k++] = shape::stride(inShapeInfo)[i]; + } + sd::ArrayOptions::copyDataType(outShapeInfo, inShapeInfo); // type + *shape::ews(outShapeInfo) = shape::elementWiseStride(inShapeInfo); // ews + outShapeInfo[2 * outShapeInfo[0] + 3] = shape::order(inShapeInfo); // order +} ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) { +// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const +// Nd4jLong* inShapeInfo) { // Nd4jLong result = 9223372036854775807LL; @@ -5175,8 +5409,6 @@ INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, // return result == 9223372036854775807LL ? 1 : result; // } - - -} +} // namespace shape #endif /* SHAPE_H_ */ diff --git a/libnd4j/include/helpers/svd.h b/libnd4j/include/helpers/svd.h index 58007bf37bf1..6729c0cb6db3 100644 --- a/libnd4j/include/helpers/svd.h +++ b/libnd4j/include/helpers/svd.h @@ -22,67 +22,77 @@ #define LIBND4J_SVD_H #include + #include "array/NDArray.h" -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { template class SVD { + public: + int _switchSize = 10; - public: - - int _switchSize = 10; + NDArray _m; + NDArray _s; + NDArray _u; + NDArray _v; - NDArray _m; - NDArray _s; - NDArray _u; - NDArray _v; - - int _diagSize; + int _diagSize; - bool _transp; - bool _calcU; - bool _calcV; - bool _fullUV; + bool _transp; + bool _calcU; + bool _calcV; + bool _fullUV; - /** - * constructor - */ - SVD(const NDArray& matrix, const int switchSize, const bool calcV, const bool calcU, const bool fullUV); + /** + * constructor + */ + SVD(const NDArray& matrix, const int switchSize, const bool calcV, + const bool calcU, const bool fullUV); - SVD(const NDArray& matrix, const int switchSize, const bool calcV, const bool calcU, const bool fullUV, const char t); + SVD(const NDArray& matrix, const int switchSize, const bool calcV, + const bool calcU, const bool fullUV, const char t); - void deflation1(int col1, int shift, int ind, int size); - - void deflation2(int col1U , int col1M, int row1W, int col1W, int ind1, int ind2, int size); - - void deflation(int col1, int col2, int ind, int row1W, int col1W, int shift); + void deflation1(int col1, int shift, int ind, int size); - // FIXME: proper T support required here - T secularEq(const T diff, const NDArray& col0, const NDArray& diag, const NDArray &permut, const NDArray& diagShifted, const T shift); + void deflation2(int col1U, int col1M, int row1W, int col1W, int ind1, + int ind2, int size); - void calcSingVals(const NDArray& col0, const NDArray& diag, const NDArray& permut, NDArray& singVals, NDArray& shifts, NDArray& mus); + void deflation(int col1, int col2, int ind, int row1W, int col1W, int shift); - void perturb(const NDArray& col0, const NDArray& diag, const NDArray& permut, const NDArray& singVals, const NDArray& shifts, const NDArray& mus, NDArray& zhat); + // FIXME: proper T support required here + T secularEq(const T diff, const NDArray& col0, const NDArray& diag, + const NDArray& permut, const NDArray& diagShifted, const T shift); - void calcSingVecs(const NDArray& zhat, const NDArray& diag, const NDArray& perm, const NDArray& singVals, const NDArray& shifts, const NDArray& mus, NDArray& U, NDArray& V); + void calcSingVals(const NDArray& col0, const NDArray& diag, + const NDArray& permut, NDArray& singVals, NDArray& shifts, + NDArray& mus); - void calcBlockSVD(int firstCol, int size, NDArray& U, NDArray& singVals, NDArray& V); + void perturb(const NDArray& col0, const NDArray& diag, const NDArray& permut, + const NDArray& singVals, const NDArray& shifts, + const NDArray& mus, NDArray& zhat); - void DivideAndConquer(int col1, int col2, int row1W, int col1W, int shift); + void calcSingVecs(const NDArray& zhat, const NDArray& diag, + const NDArray& perm, const NDArray& singVals, + const NDArray& shifts, const NDArray& mus, NDArray& U, + NDArray& V); - void exchangeUV(const HHsequence& hhU, const HHsequence& hhV, const NDArray& U, const NDArray& V); + void calcBlockSVD(int firstCol, int size, NDArray& U, NDArray& singVals, + NDArray& V); - void evalData(const NDArray& matrix); + void DivideAndConquer(int col1, int col2, int row1W, int col1W, int shift); - FORCEINLINE NDArray& getS(); - FORCEINLINE NDArray& getU(); - FORCEINLINE NDArray& getV(); + void exchangeUV(const HHsequence& hhU, const HHsequence& hhV, + const NDArray& U, const NDArray& V); -}; + void evalData(const NDArray& matrix); + FORCEINLINE NDArray& getS(); + FORCEINLINE NDArray& getU(); + FORCEINLINE NDArray& getV(); +}; ////////////////////////////////////////////////////////////////////////// template @@ -102,10 +112,8 @@ FORCEINLINE NDArray& SVD::getV() { return _v; } +} // namespace helpers +} // namespace ops +} // namespace sd - -} -} -} - -#endif //LIBND4J_SVD_H +#endif // LIBND4J_SVD_H diff --git a/libnd4j/include/helpers/threshold.h b/libnd4j/include/helpers/threshold.h index ba5304e55278..55c34e7384fe 100644 --- a/libnd4j/include/helpers/threshold.h +++ b/libnd4j/include/helpers/threshold.h @@ -23,5 +23,4 @@ #include - -#endif //LIBND4J_THRESHOLD_H +#endif // LIBND4J_THRESHOLD_H diff --git a/libnd4j/include/helpers/unicode.h b/libnd4j/include/helpers/unicode.h index 6db4841db2ba..fa5787298c47 100644 --- a/libnd4j/include/helpers/unicode.h +++ b/libnd4j/include/helpers/unicode.h @@ -26,164 +26,163 @@ namespace sd { namespace unicode { - /** - * This method calculate u16 offset based on utf8 - * @param const pointer to the utf8 string start point - * @param size of the string - * @return offset of utf16 - */ - Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end); - - /** - * This method calculate u8 offset based on utf16 - * @param const pointer to the utf16 string start point - * @param size of the string - * @return offset of utf8 - */ - Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end); - - /** - * This method calculate u32 offset based on utf16 - * @param const pointer to the utf16 string start point - * @param size of the string - * @return offset of utf32 - */ - Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end); - - /** - * This method calculate u32 offset based on utf8 - * @param const pointer to the utf16 string start point - * @param size of the string - * @return offset of utf8 - */ - Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end); - - /* - * This function check is valid charecter in u8 string - */ - bool isStringValidU8(const void* start, const void* stop); - - /* - * This function check is valid charecter in u16 string - */ - bool isStringValidU16(const void* start, const void* stop); - - /* - * This function check is valid u32 charecter in string - */ - bool isStringValidU32(const void* start, const void* stop); - - /** - * This method count offset for utf8 string in utf32 - * @param const pointer to the utf8 string start point - * @param size of the string - * @return offset - */ - Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize); - - /** - * This method count offset for utf8 string in utf32 - * @param const pointer to the utf8 string start point - * @param const end pointer to the utf8 string - * @return offset - */ - Nd4jLong offsetUtf8StringInUtf32(const void* input, const void* stop); - - /** - * This method count offset for utf32 based on utf16 string - * @param const pointer to the utf16 string start point - * @param size of the string - * @return offset - */ - Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize); - - /** - * This method calculate offset of u16 based on utf8 - * @param const pointer to the utf8 string start point - * @param size of the string - * @return offset of utf16 - */ - Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize); - - /** - * This method calculate offset of u8 based on utf16 - * @param const pointer to the utf16 string start point - * @param size of the string - * @return offset of utf8 - */ - Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize); - - /** - * This method calculate offset of u32 based on utf8 - * @param const pointer to the utf16 string start point - * @param size of the string - * @return offset of utf32 - */ - Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize); - - /** - * This method calculate offset of u32 based on utf16 - * @param const pointer to the utf16 string start point - * @param size of the string - * @return offset of utf32 - */ - Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize); - - /** - * This method convert utf8 string to utf16 string - * @param const pointer to the utf8 string start point - * @param reference to start point to utf16 - * @param size of input utf8 string - * @return status of convertion - */ - bool utf8to16(const void* input, void* output, uint32_t nInputSize); - - /** - * This method convert utf8 string to utf32 string - * @param const pointer to the utf8 string start point - * @param reference to start point to utf32 - * @param size of input utf8 string - * @return status of convertion - */ - bool utf8to32(const void* input, void* output, uint32_t nInputSize); - - /** - * This method convert utf16 string to utf32 string - * @param const pointer to the utf16 string start point - * @param reference to start point to utf32 - * @param size of input utf16 string - * @return status of convertion - */ - bool utf16to32(const void* input, void* output, uint32_t nInputSize); - - /** - * This method convert utf16 string to utf8 string - * @param const pointer to the utf16 string start point - * @param reference to start point to utf8 - * @param size of input utf16 string - * @return status of convertion - */ - bool utf16to8(const void* input, void* output, uint32_t nInputSize); - - /** - * This method convert utf32 string to utf16 string - * @param const pointer to the utf32 string start point - * @param reference to start point to utf16 - * @param size of input utf32 string - * @return status of convertion - */ - bool utf32to16(const void* input, void* output, uint32_t nInputSize); - - /** - * This method convert utf32 string to utf8 string - * @param const pointer to the utf32 string start point - * @param reference to start point to utf8 - * @param size of input utf32 string - * @return status of convertion - */ - bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize); -} -} - - -#endif //LIBND4J_UNICODE_H +/** + * This method calculate u16 offset based on utf8 + * @param const pointer to the utf8 string start point + * @param size of the string + * @return offset of utf16 + */ +Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end); + +/** + * This method calculate u8 offset based on utf16 + * @param const pointer to the utf16 string start point + * @param size of the string + * @return offset of utf8 + */ +Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end); + +/** + * This method calculate u32 offset based on utf16 + * @param const pointer to the utf16 string start point + * @param size of the string + * @return offset of utf32 + */ +Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end); + +/** + * This method calculate u32 offset based on utf8 + * @param const pointer to the utf16 string start point + * @param size of the string + * @return offset of utf8 + */ +Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end); + +/* + * This function check is valid charecter in u8 string + */ +bool isStringValidU8(const void* start, const void* stop); + +/* + * This function check is valid charecter in u16 string + */ +bool isStringValidU16(const void* start, const void* stop); + +/* + * This function check is valid u32 charecter in string + */ +bool isStringValidU32(const void* start, const void* stop); + +/** + * This method count offset for utf8 string in utf32 + * @param const pointer to the utf8 string start point + * @param size of the string + * @return offset + */ +Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize); + +/** + * This method count offset for utf8 string in utf32 + * @param const pointer to the utf8 string start point + * @param const end pointer to the utf8 string + * @return offset + */ +Nd4jLong offsetUtf8StringInUtf32(const void* input, const void* stop); + +/** + * This method count offset for utf32 based on utf16 string + * @param const pointer to the utf16 string start point + * @param size of the string + * @return offset + */ +Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize); + +/** + * This method calculate offset of u16 based on utf8 + * @param const pointer to the utf8 string start point + * @param size of the string + * @return offset of utf16 + */ +Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize); + +/** + * This method calculate offset of u8 based on utf16 + * @param const pointer to the utf16 string start point + * @param size of the string + * @return offset of utf8 + */ +Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize); + +/** + * This method calculate offset of u32 based on utf8 + * @param const pointer to the utf16 string start point + * @param size of the string + * @return offset of utf32 + */ +Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize); + +/** + * This method calculate offset of u32 based on utf16 + * @param const pointer to the utf16 string start point + * @param size of the string + * @return offset of utf32 + */ +Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize); + +/** + * This method convert utf8 string to utf16 string + * @param const pointer to the utf8 string start point + * @param reference to start point to utf16 + * @param size of input utf8 string + * @return status of convertion + */ +bool utf8to16(const void* input, void* output, uint32_t nInputSize); + +/** + * This method convert utf8 string to utf32 string + * @param const pointer to the utf8 string start point + * @param reference to start point to utf32 + * @param size of input utf8 string + * @return status of convertion + */ +bool utf8to32(const void* input, void* output, uint32_t nInputSize); + +/** + * This method convert utf16 string to utf32 string + * @param const pointer to the utf16 string start point + * @param reference to start point to utf32 + * @param size of input utf16 string + * @return status of convertion + */ +bool utf16to32(const void* input, void* output, uint32_t nInputSize); + +/** + * This method convert utf16 string to utf8 string + * @param const pointer to the utf16 string start point + * @param reference to start point to utf8 + * @param size of input utf16 string + * @return status of convertion + */ +bool utf16to8(const void* input, void* output, uint32_t nInputSize); + +/** + * This method convert utf32 string to utf16 string + * @param const pointer to the utf32 string start point + * @param reference to start point to utf16 + * @param size of input utf32 string + * @return status of convertion + */ +bool utf32to16(const void* input, void* output, uint32_t nInputSize); + +/** + * This method convert utf32 string to utf8 string + * @param const pointer to the utf32 string start point + * @param reference to start point to utf8 + * @param size of input utf32 string + * @return status of convertion + */ +bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize); +} // namespace unicode +} // namespace sd + +#endif // LIBND4J_UNICODE_H diff --git a/libnd4j/include/indexing/IndicesList.h b/libnd4j/include/indexing/IndicesList.h index f657a191ee60..0ecbae29cbaa 100644 --- a/libnd4j/include/indexing/IndicesList.h +++ b/libnd4j/include/indexing/IndicesList.h @@ -22,22 +22,24 @@ #define LIBND4J_INDICESLIST_H #include + #include "NDIndex.h" namespace sd { - class SD_EXPORT IndicesList { - protected: - std::vector _indices; - public: - explicit IndicesList() = default; - explicit IndicesList(std::initializer_list list); - - int size(); - NDIndex* at(int idx); - void push_back(NDIndex* idx); - bool isScalar(); - - ~IndicesList(); - }; -} -#endif //LIBND4J_INDICESLIST_H +class SD_EXPORT IndicesList { + protected: + std::vector _indices; + + public: + explicit IndicesList() = default; + explicit IndicesList(std::initializer_list list); + + int size(); + NDIndex* at(int idx); + void push_back(NDIndex* idx); + bool isScalar(); + + ~IndicesList(); +}; +} // namespace sd +#endif // LIBND4J_INDICESLIST_H diff --git a/libnd4j/include/indexing/NDIndex.h b/libnd4j/include/indexing/NDIndex.h index 32c831ce2f45..00189a1aa290 100644 --- a/libnd4j/include/indexing/NDIndex.h +++ b/libnd4j/include/indexing/NDIndex.h @@ -21,54 +21,53 @@ #ifndef LIBND4J_NDINDEX_H #define LIBND4J_NDINDEX_H +#include #include + #include -#include namespace sd { - class SD_EXPORT NDIndex { - protected: - std::vector _indices; - Nd4jLong _stride = 1; - public: - NDIndex() = default; - ~NDIndex() = default; - - bool isAll(); - bool isPoint(); - virtual bool isInterval(); - - std::vector& getIndices(); - Nd4jLong stride(); +class SD_EXPORT NDIndex { + protected: + std::vector _indices; + Nd4jLong _stride = 1; - static NDIndex* all(); - static NDIndex* point(Nd4jLong pt); - static NDIndex* interval(Nd4jLong start, Nd4jLong end, Nd4jLong stride = 1); - }; + public: + NDIndex() = default; + ~NDIndex() = default; - class SD_EXPORT NDIndexAll : public NDIndex { - public: - NDIndexAll(); - virtual bool isInterval(); - ~NDIndexAll() = default; - }; + bool isAll(); + bool isPoint(); + virtual bool isInterval(); + std::vector& getIndices(); + Nd4jLong stride(); - class SD_EXPORT NDIndexPoint : public NDIndex { - public: - NDIndexPoint(Nd4jLong point); - virtual bool isInterval(); - ~NDIndexPoint() = default; - }; + static NDIndex* all(); + static NDIndex* point(Nd4jLong pt); + static NDIndex* interval(Nd4jLong start, Nd4jLong end, Nd4jLong stride = 1); +}; - class SD_EXPORT NDIndexInterval : public NDIndex { - public: - NDIndexInterval(Nd4jLong start, Nd4jLong end, Nd4jLong stride = 1); - virtual bool isInterval(); - ~NDIndexInterval() = default; - }; -} +class SD_EXPORT NDIndexAll : public NDIndex { + public: + NDIndexAll(); + virtual bool isInterval(); + ~NDIndexAll() = default; +}; +class SD_EXPORT NDIndexPoint : public NDIndex { + public: + NDIndexPoint(Nd4jLong point); + virtual bool isInterval(); + ~NDIndexPoint() = default; +}; +class SD_EXPORT NDIndexInterval : public NDIndex { + public: + NDIndexInterval(Nd4jLong start, Nd4jLong end, Nd4jLong stride = 1); + virtual bool isInterval(); + ~NDIndexInterval() = default; +}; +} // namespace sd -#endif //LIBND4J_NDINDEX_H +#endif // LIBND4J_NDINDEX_H diff --git a/libnd4j/include/indexing/impl/IndicesList.cpp b/libnd4j/include/indexing/impl/IndicesList.cpp index 5acbf57d5cf9..d3467d12a184 100644 --- a/libnd4j/include/indexing/impl/IndicesList.cpp +++ b/libnd4j/include/indexing/impl/IndicesList.cpp @@ -22,32 +22,24 @@ using namespace sd; -sd::IndicesList::IndicesList(std::initializer_list list) { - for (auto v: list) - _indices.emplace_back(v); +sd::IndicesList::IndicesList(std::initializer_list list) { + for (auto v : list) _indices.emplace_back(v); } sd::IndicesList::~IndicesList() { - for(auto v: _indices) - delete v; + for (auto v : _indices) delete v; } -int sd::IndicesList::size() { - return (int) _indices.size(); -} +int sd::IndicesList::size() { return (int)_indices.size(); } bool sd::IndicesList::isScalar() { - if (_indices.size() == 1) { - return _indices.at(0)->isPoint(); - } + if (_indices.size() == 1) { + return _indices.at(0)->isPoint(); + } - return false; + return false; } -sd::NDIndex* sd::IndicesList::at(int idx) { - return _indices.at(idx); -} +sd::NDIndex* sd::IndicesList::at(int idx) { return _indices.at(idx); } -void sd::IndicesList::push_back(NDIndex* idx) { - _indices.emplace_back(idx); -} \ No newline at end of file +void sd::IndicesList::push_back(NDIndex* idx) { _indices.emplace_back(idx); } \ No newline at end of file diff --git a/libnd4j/include/indexing/impl/NDIndex.cpp b/libnd4j/include/indexing/impl/NDIndex.cpp index 43aaf09143fc..7111ce4cd8d5 100644 --- a/libnd4j/include/indexing/impl/NDIndex.cpp +++ b/libnd4j/include/indexing/impl/NDIndex.cpp @@ -22,64 +22,45 @@ namespace sd { - bool NDIndex::isInterval() { - return false; - } +bool NDIndex::isInterval() { return false; } - Nd4jLong NDIndex::stride() { - return _stride; - } +Nd4jLong NDIndex::stride() { return _stride; } - sd::NDIndexAll::NDIndexAll() : sd::NDIndex() { - _indices.push_back(-1); - } +sd::NDIndexAll::NDIndexAll() : sd::NDIndex() { _indices.push_back(-1); } - sd::NDIndexPoint::NDIndexPoint(Nd4jLong point) : sd::NDIndex() { - this->_indices.push_back(point); - } +sd::NDIndexPoint::NDIndexPoint(Nd4jLong point) : sd::NDIndex() { + this->_indices.push_back(point); +} - bool NDIndexAll::isInterval() { - return false; - } +bool NDIndexAll::isInterval() { return false; } - bool NDIndexPoint::isInterval() { - return false; - } +bool NDIndexPoint::isInterval() { return false; } - bool NDIndexInterval::isInterval() { - return true; - } +bool NDIndexInterval::isInterval() { return true; } +sd::NDIndexInterval::NDIndexInterval(Nd4jLong start, Nd4jLong end, + Nd4jLong stride) + : sd::NDIndex() { + this->_stride = stride; + for (int e = start; e < end; e += stride) this->_indices.push_back(e); +} +bool sd::NDIndex::isAll() { + return _indices.size() == 1 && _indices.at(0) == -1; +} - sd::NDIndexInterval::NDIndexInterval(Nd4jLong start, Nd4jLong end, Nd4jLong stride) : sd::NDIndex() { - this->_stride = stride; - for (int e = start; e < end; e+= stride) - this->_indices.push_back(e); - } +bool sd::NDIndex::isPoint() { + return _indices.size() == 1 && _indices.at(0) >= 0; +} - bool sd::NDIndex::isAll() { - return _indices.size() == 1 && _indices.at(0) == -1; - } +std::vector &sd::NDIndex::getIndices() { return _indices; } - bool sd::NDIndex::isPoint() { - return _indices.size() == 1 && _indices.at(0) >= 0; - } +sd::NDIndex *sd::NDIndex::all() { return new NDIndexAll(); } - std::vector &sd::NDIndex::getIndices() { - return _indices; - } +sd::NDIndex *sd::NDIndex::point(Nd4jLong pt) { return new NDIndexPoint(pt); } - - sd::NDIndex *sd::NDIndex::all() { - return new NDIndexAll(); - } - - sd::NDIndex *sd::NDIndex::point(Nd4jLong pt) { - return new NDIndexPoint(pt); - } - - sd::NDIndex *sd::NDIndex::interval(Nd4jLong start, Nd4jLong end, Nd4jLong stride) { - return new NDIndexInterval(start, end, stride); - } -} \ No newline at end of file +sd::NDIndex *sd::NDIndex::interval(Nd4jLong start, Nd4jLong end, + Nd4jLong stride) { + return new NDIndexInterval(start, end, stride); +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/legacy/NativeOpExecutioner.h b/libnd4j/include/legacy/NativeOpExecutioner.h index fbd770147775..9e1c8f2842b6 100644 --- a/libnd4j/include/legacy/NativeOpExecutioner.h +++ b/libnd4j/include/legacy/NativeOpExecutioner.h @@ -21,13 +21,12 @@ #ifndef NATIVEOPERATIONS_NATIVEOPEXCUTIONER_H #define NATIVEOPERATIONS_NATIVEOPEXCUTIONER_H - -#include -#include +#include +#include #include #include -#include -#include +#include +#include /** * Native op executioner: @@ -35,645 +34,574 @@ */ class SD_EXPORT NativeOpExecutioner { -public: - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - * @param result - * @param resultShapeInfo - */ - static void execIndexReduceScalar(sd::LaunchContext *lc, - int opNum, + public: + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + * @param result + * @param resultShapeInfo + */ + static void execIndexReduceScalar(sd::LaunchContext *lc, int opNum, const void *hX, const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParamsVals - * @param y - * @param yShapeInfo - * @param result - * @param resultShapeInfoBuffer - * @param dimension - * @param dimensionLength - */ - static void execReduce3Scalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParamsVals - * @param y - * @param yShapeInfo - * @param result - * @param resultShapeInfo - */ - static void execReduce3(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParamsVals - * @param y - * @param yShapeInfo - * @param result - * @param resultShapeInfoBuffer - * @param dimension - * @param dimensionLength - */ - static void execReduce3(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadOnlyShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); - - static void execReduce3All(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets); - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - * @param result - * @param resultShapeInfoBuffer - * @param dimension - * @param dimensionLength - */ - static void execIndexReduce(sd::LaunchContext *lc, - int opNum, + void *extraParams, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParamsVals + * @param y + * @param yShapeInfo + * @param result + * @param resultShapeInfoBuffer + * @param dimension + * @param dimensionLength + */ + static void execReduce3Scalar(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParamsVals, const void *hY, + const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParamsVals + * @param y + * @param yShapeInfo + * @param result + * @param resultShapeInfo + */ + static void execReduce3(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParamsVals, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParamsVals + * @param y + * @param yShapeInfo + * @param result + * @param resultShapeInfoBuffer + * @param dimension + * @param dimensionLength + */ + static void execReduce3( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParamsVals, const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadOnlyShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *yTadOnlyShapeInfo, + const Nd4jLong *yTadOffsets); + + static void execReduce3All(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParamsVals, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xOffsets, + const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yOffsets); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + * @param result + * @param resultShapeInfoBuffer + * @param dimension + * @param dimensionLength + */ + static void execIndexReduce(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParams, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + /** + * + * @param opNum + * @param x + * @param xStride + * @param result + * @param resultStride + * @param scalar + * @param extraParams + * @param n + */ + static void execScalar(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalar, + const Nd4jLong *hSscalarShapeInfo, const void *dScalar, + const Nd4jLong *dSscalarShapeInfo, void *extraParams, + bool allowParallelism = true); + + static void execScalarBool(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalar, + const Nd4jLong *hSscalarShapeInfo, + const void *dScalar, + const Nd4jLong *dSscalarShapeInfo, + void *extraParams, bool allowParallelism = true); + + static void execScalarInt(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalar, + const Nd4jLong *hSscalarShapeInfo, + const void *dScalar, + const Nd4jLong *dSscalarShapeInfo, + void *extraParams, bool allowParallelism = true); + + static void execScalar(sd::LaunchContext *lc, int opNum, void const *hX, + Nd4jLong const *hXShapeInfo, void const *dX, + Nd4jLong const *dXShapeInfo, void *extraParams, + void *hZ, Nd4jLong const *hZShapeInfo, void *dZ, + Nd4jLong const *dZShapeInfo, void const *hScalars, + Nd4jLong const *hScalarShapeInfo, void const *dScalars, + Nd4jLong const *dScalarShapeInfo, int *dimension, + int dimensionLength, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, + Nd4jLong const *tadShapeInfoZ, + Nd4jLong const *tadOffsetsZ); + + static void execScalarBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalars, + const Nd4jLong *hScalarShapeInfo, const void *dScalars, + const Nd4jLong *dScalarShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); + + static void execScalarInt( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalars, + const Nd4jLong *hScalarShapeInfo, const void *dScalars, + const Nd4jLong *dScalarShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param y + * @param yShapeInfo + * @param result + * @param resultShapeInfo + * @param dimension + * @param dimensionLength + */ + static void execBroadcast( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static void execBroadcast(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, const void *hY, + const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); + + static void execInverseBroadcast( + sd::LaunchContext *lc, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static void execBroadcastBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static void execBroadcastBool(sd::LaunchContext *lc, int opNum, const void *hX, const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - /** - * - * @param opNum - * @param x - * @param xStride - * @param result - * @param resultStride - * @param scalar - * @param extraParams - * @param n - */ - static void execScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalar, const Nd4jLong *hSscalarShapeInfo, - const void *dScalar, const Nd4jLong *dSscalarShapeInfo, - void *extraParams, - bool allowParallelism = true); - -static void execScalarBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalar, const Nd4jLong *hSscalarShapeInfo, - const void *dScalar, const Nd4jLong *dSscalarShapeInfo, - void *extraParams, - bool allowParallelism = true); - -static void execScalarInt(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalar, const Nd4jLong *hSscalarShapeInfo, - const void *dScalar, const Nd4jLong *dSscalarShapeInfo, - void *extraParams, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams); + + static void execInverseBroadcastBool( + sd::LaunchContext *lc, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, + void *extraParams, int *dimension, int dimensionLength, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); + + static void execBroadcastInt( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static void execBroadcastInt(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, const void *hY, + const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); + + static void execInverseBroadcastInt( + sd::LaunchContext *lc, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + /** + * + * @param opNum + * @param dx + * @param xStride + * @param y + * @param yStride + * @param result + * @param resultStride + * @param extraParams + * @param n + */ + static void execPairwiseTransform(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, + void *extraParams); + + static void execPairwiseBoolTransform( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams); + + static void execPairwiseIntTransform( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams); + + /** + * + * @param opNum + * @param dx + * @param xStride + * @param result + * @param resultStride + * @param extraParams + * @param n + */ + static void execTransformFloat(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, + void *extraParams, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static void execTransformAny(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, bool allowParallelism = true); - static void execScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void const* hScalars, Nd4jLong const* hScalarShapeInfo, - void const* dScalars, Nd4jLong const* dScalarShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ); - - static void execScalarBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalars, const Nd4jLong *hScalarShapeInfo, - const void *dScalars, const Nd4jLong *dScalarShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - static void execScalarInt(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalars, const Nd4jLong *hScalarShapeInfo, - const void *dScalars, const Nd4jLong *dScalarShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - -/** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param y - * @param yShapeInfo - * @param result - * @param resultShapeInfo - * @param dimension - * @param dimensionLength - */ - static void execBroadcast(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ); - - static void execBroadcast(sd::LaunchContext* lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - static void execInverseBroadcast(sd::LaunchContext *lc, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - - static void execBroadcastBool(sd::LaunchContext *lc, - int opNum, + static void execTransformStrict(sd::LaunchContext *lc, int opNum, const void *hX, const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ); + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); - static void execBroadcastBool(sd::LaunchContext* lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams); - - static void execInverseBroadcastBool(sd::LaunchContext *lc, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - static void execBroadcastInt(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ); + static void execTransformSame(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); - static void execBroadcastInt(sd::LaunchContext* lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - static void execInverseBroadcastInt(sd::LaunchContext *lc, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); + static void execTransformBool(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + * @param result + * @param resultShapeInfo + */ + static void execReduceFloat(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParams, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static void execReduceSame(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParams, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static void execReduceBool(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParams, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static void execReduceLong(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParams, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + * @return + */ + static void execReduceFloatScalar(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); -/** - * - * @param opNum - * @param dx - * @param xStride - * @param y - * @param yStride - * @param result - * @param resultStride - * @param extraParams - * @param n - */ - static void execPairwiseTransform(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams); - - static void execPairwiseBoolTransform(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams); - - static void execPairwiseIntTransform(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams); + static void execReduceBoolScalar(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); -/** - * - * @param opNum - * @param dx - * @param xStride - * @param result - * @param resultStride - * @param extraParams - * @param n - */ - static void execTransformFloat(sd::LaunchContext *lc, - int opNum, + static void execReduceSameScalar(sd::LaunchContext *lc, int opNum, const void *hX, const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - -static void execTransformAny(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - bool allowParallelism = true); - -static void execTransformStrict(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - -static void execTransformSame(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - -static void execTransformBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - * @param result - * @param resultShapeInfo - */ - static void execReduceFloat(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static void execReduceSame(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static void execReduceBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static void execReduceLong(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - * @return - */ - static void execReduceFloatScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - static void execReduceBoolScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - static void execReduceSameScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - static void execReduceLongScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo); - - static void execReduce3TAD(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffsets); - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - * @param result - * @param resultShapeInfoBuffer - * @param dimension - * @param dimensionLength - */ - static void execSummaryStats(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - bool biasCorrected); - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - * @param result - * @param resultShapeInfo - */ - static void execSummaryStats(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - bool biasCorrected); - - /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - * @param result - * @param resultShapeInfo - */ - static void execSummaryStatsScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - bool biasCorrected); - - - static void execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer state, - void *hZ, const Nd4jLong *hZShapeBuffer, - void *dZ, const Nd4jLong *dZShapeBuffer, - void *extraArguments); - - static void execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer state, - const void *hX, const Nd4jLong *hXShapeBuffer, - const void *dX, const Nd4jLong *dXShapeBuffer, - void *hZ, const Nd4jLong *hZShapeBuffer, - void *dZ, const Nd4jLong *dZShapeBuffer, - void *extraArguments); - - static void execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer state, - const void *hX, const Nd4jLong *hXShapeBuffer, - const void *dX, const Nd4jLong *dXShapeBuffer, - const void *hY, const Nd4jLong *hYShapeBuffer, - const void *dY, const Nd4jLong *dYShapeBuffer, - void *hZ, const Nd4jLong *hZShapeBuffer, - void *dZ, const Nd4jLong *dZShapeBuffer, - void *extraArguments); - - - - inline static void execSort(void *x, const Nd4jLong *xShapeInfo, bool descending) { - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortGeneric(x, xShapeInfo, descending), LIBND4J_TYPES); - } - - static void execSort(void *x, const Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, bool descending) { - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortTadGeneric(x, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES); - } - - inline static void execSortCooIndices(Nd4jLong *indices, void *values, Nd4jLong length, int rank) { - sd::sparse::SparseUtils::sortCooIndicesGeneric(indices, reinterpret_cast(values), length, rank); - } - - - inline static Nd4jLong encodeBitmap(void *dx, const Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold) { - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, return sd::SpecialMethods, ::encodeBitmapGeneric(dx, xShapeInfo, N, dz, threshold), FLOAT_TYPES); - } - - inline static void decodeBitmap(const void *dx, Nd4jLong N, void *dz, const Nd4jLong *zShapeInfo) { - auto zType = sd::ArrayOptions::dataType(zShapeInfo); - - BUILD_SINGLE_SELECTOR(zType, sd::SpecialMethods, ::decodeBitmapGeneric(dx, N, dz, zShapeInfo), FLOAT_TYPES); - } + void *extraParams, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); + static void execReduceLongScalar(sd::LaunchContext *lc, int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo); + + static void execReduce3TAD(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParamsVals, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, + const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yTadOffsets); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + * @param result + * @param resultShapeInfoBuffer + * @param dimension + * @param dimensionLength + */ + static void execSummaryStats(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParams, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, + int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, bool biasCorrected); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + * @param result + * @param resultShapeInfo + */ + static void execSummaryStats(sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, + const Nd4jLong *dXShapeInfo, void *extraParams, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, bool biasCorrected); + + /** + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + * @param result + * @param resultShapeInfo + */ + static void execSummaryStatsScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, bool biasCorrected); + + static void execRandom(sd::LaunchContext *lc, int opNum, Nd4jPointer state, + void *hZ, const Nd4jLong *hZShapeBuffer, void *dZ, + const Nd4jLong *dZShapeBuffer, void *extraArguments); + + static void execRandom(sd::LaunchContext *lc, int opNum, Nd4jPointer state, + const void *hX, const Nd4jLong *hXShapeBuffer, + const void *dX, const Nd4jLong *dXShapeBuffer, + void *hZ, const Nd4jLong *hZShapeBuffer, void *dZ, + const Nd4jLong *dZShapeBuffer, void *extraArguments); + + static void execRandom(sd::LaunchContext *lc, int opNum, Nd4jPointer state, + const void *hX, const Nd4jLong *hXShapeBuffer, + const void *dX, const Nd4jLong *dXShapeBuffer, + const void *hY, const Nd4jLong *hYShapeBuffer, + const void *dY, const Nd4jLong *dYShapeBuffer, + void *hZ, const Nd4jLong *hZShapeBuffer, void *dZ, + const Nd4jLong *dZShapeBuffer, void *extraArguments); + + inline static void execSort(void *x, const Nd4jLong *xShapeInfo, + bool descending) { + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + + BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, + ::sortGeneric(x, xShapeInfo, descending), + LIBND4J_TYPES); + } + + static void execSort(void *x, const Nd4jLong *xShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, bool descending) { + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + + BUILD_SINGLE_SELECTOR( + xType, sd::SpecialMethods, + ::sortTadGeneric(x, xShapeInfo, dimension, dimensionLength, + tadShapeInfo, tadOffsets, descending), + LIBND4J_TYPES); + } + + inline static void execSortCooIndices(Nd4jLong *indices, void *values, + Nd4jLong length, int rank) { + sd::sparse::SparseUtils::sortCooIndicesGeneric( + indices, reinterpret_cast(values), length, rank); + } + + inline static Nd4jLong encodeBitmap(void *dx, const Nd4jLong *xShapeInfo, + Nd4jLong N, int *dz, float threshold) { + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + + BUILD_SINGLE_SELECTOR( + xType, return sd::SpecialMethods, + ::encodeBitmapGeneric(dx, xShapeInfo, N, dz, threshold), FLOAT_TYPES); + } + + inline static void decodeBitmap(const void *dx, Nd4jLong N, void *dz, + const Nd4jLong *zShapeInfo) { + auto zType = sd::ArrayOptions::dataType(zShapeInfo); + + BUILD_SINGLE_SELECTOR(zType, sd::SpecialMethods, + ::decodeBitmapGeneric(dx, N, dz, zShapeInfo), + FLOAT_TYPES); + } }; - -#endif //NATIVEOPERATIONS_NATIVEOPEXCUTIONER_H +#endif // NATIVEOPERATIONS_NATIVEOPEXCUTIONER_H diff --git a/libnd4j/include/legacy/NativeOps.h b/libnd4j/include/legacy/NativeOps.h old mode 100755 new mode 100644 index aa64db11364b..ac459510760a --- a/libnd4j/include/legacy/NativeOps.h +++ b/libnd4j/include/legacy/NativeOps.h @@ -31,7 +31,7 @@ defined __DMC__ || \ defined __BORLANDC__ ) # define thread_local __declspec(thread) -// note that ICC (linux) and Clang are covered by __GNUC__ +// note that ICC (linux) and Clang are covered by __GNUC__ # elif defined __GNUC__ || \ defined __SUNPRO_C || \ defined __xlC__ @@ -42,14 +42,14 @@ #endif */ +#include #include #include -#include -//DO NOT REMOVE: THIS IS AN EDITOR SEMANTICS THING FOR CLION -//IT DEFINES THE EXPORT MACRO FOR THE EDITOR AND THEN -//RE ADDS THE DEFINITION VIA dll.h -#ifdef _WIN32 +// DO NOT REMOVE: THIS IS AN EDITOR SEMANTICS THING FOR CLION +// IT DEFINES THE EXPORT MACRO FOR THE EDITOR AND THEN +// RE ADDS THE DEFINITION VIA dll.h +#ifdef _WIN32 #define SD_EXPORT __declspec(dllexport) #else #define SD_EXPORT @@ -64,16 +64,16 @@ bool debug = false; bool verbose = false; */ -#include -#include -#include #include +#include #include -#include +#include #include +#include #include #include -#include +#include +#include #include #include @@ -113,17 +113,19 @@ SD_EXPORT void setElementThreshold(int num); SD_EXPORT void setTADThreshold(int num); /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - */ -SD_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + */ +SD_EXPORT void execIndexReduceScalar(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, + void* extraParams, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo); /** * @@ -136,12 +138,12 @@ SD_EXPORT void execIndexReduceScalar(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -SD_EXPORT void execIndexReduce(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); +SD_EXPORT void execIndexReduce( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); /** * @@ -155,23 +157,25 @@ SD_EXPORT void execIndexReduce(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -SD_EXPORT void execBroadcast( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); +SD_EXPORT void execBroadcast(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, OpaqueDataBuffer* dbY, + Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, + OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, + Nd4jLong const* dDimensionShape); - -SD_EXPORT void execBroadcastBool( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); +SD_EXPORT void execBroadcastBool( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbY, Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, void* extraParams, + OpaqueDataBuffer* dbDimension, Nd4jLong const* hDimensionShape, + Nd4jLong const* dDimensionShape); /** * @@ -186,20 +190,20 @@ SD_EXPORT void execBroadcastBool( * @param n */ SD_EXPORT void execPairwiseTransform( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams); + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbY, Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, + void* extraParams); SD_EXPORT void execPairwiseTransformBool( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams); + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbY, Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, + void* extraParams); /** * @@ -210,30 +214,37 @@ SD_EXPORT void execPairwiseTransformBool( * @param result * @param resultShapeInfo */ -SD_EXPORT void execReduceFloat(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); - -SD_EXPORT void execReduceSame(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); - -SD_EXPORT void execReduceBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); - - -SD_EXPORT void execReduceLong(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); +SD_EXPORT void execReduceFloat(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo); + +SD_EXPORT void execReduceSame(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo); + +SD_EXPORT void execReduceBool(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo); + +SD_EXPORT void execReduceLong(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo); /** * @@ -244,36 +255,33 @@ SD_EXPORT void execReduceLong(Nd4jPointer *extraPointers, * @param result * @param resultShapeInfo */ -SD_EXPORT void execReduceFloat2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); - - -SD_EXPORT void execReduceSame2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); +SD_EXPORT void execReduceFloat2( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); +SD_EXPORT void execReduceSame2( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); -SD_EXPORT void execReduceBool2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); +SD_EXPORT void execReduceBool2( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); - -SD_EXPORT void execReduceLong2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); +SD_EXPORT void execReduceLong2( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape); /** * @@ -286,12 +294,13 @@ SD_EXPORT void execReduceLong2(Nd4jPointer *extraPointers, * @param result * @param resultShapeInfo */ -SD_EXPORT void execReduce3(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParamsVals, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); +SD_EXPORT void execReduce3(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, void* extraParamsVals, + OpaqueDataBuffer* dbY, Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo); /** * @@ -302,12 +311,12 @@ SD_EXPORT void execReduce3(Nd4jPointer *extraPointers, * @param y * @param yShapeInfo */ -SD_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParamsVals, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); +SD_EXPORT void execReduce3Scalar( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + void* extraParamsVals, OpaqueDataBuffer* dbY, Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo); /** * * @param opNum @@ -321,26 +330,27 @@ SD_EXPORT void execReduce3Scalar(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -SD_EXPORT void execReduce3Tad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParamsVals, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets); - - -SD_EXPORT void execReduce3All(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParamsVals, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* xTadShapeInfo, Nd4jLong const* xOffsets, - Nd4jLong const* yTadShapeInfo, Nd4jLong const* yOffsets); +SD_EXPORT void execReduce3Tad( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + void* extraParamsVals, OpaqueDataBuffer* dbY, Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, + OpaqueDataBuffer* dbDimension, Nd4jLong const* hDimensionShape, + Nd4jLong const* dDimensionShape, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* yTadOnlyShapeInfo, + Nd4jLong const* yTadOffsets); + +SD_EXPORT void execReduce3All( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + void* extraParamsVals, OpaqueDataBuffer* dbY, Nd4jLong const* hYShapeInfo, + Nd4jLong const* dYShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, + OpaqueDataBuffer* dbDimension, Nd4jLong const* hDimensionShape, + Nd4jLong const* dDimensionShape, Nd4jLong const* xTadShapeInfo, + Nd4jLong const* xOffsets, Nd4jLong const* yTadShapeInfo, + Nd4jLong const* yOffsets); /** * @@ -353,19 +363,22 @@ SD_EXPORT void execReduce3All(Nd4jPointer *extraPointers, * @param extraParams * @param n */ -SD_EXPORT void execScalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalar, Nd4jLong const* hSscalarShapeInfo, Nd4jLong const* dSscalarShapeInfo, - void *extraParams); +SD_EXPORT void execScalar(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, + OpaqueDataBuffer* dbScalar, + Nd4jLong const* hSscalarShapeInfo, + Nd4jLong const* dSscalarShapeInfo, void* extraParams); -SD_EXPORT void execScalarBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalar, Nd4jLong const* hSscalarShapeInfo, Nd4jLong const* dSscalarShapeInfo, - void *extraParams); +SD_EXPORT void execScalarBool( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbScalar, + Nd4jLong const* hSscalarShapeInfo, Nd4jLong const* dSscalarShapeInfo, + void* extraParams); /** * @@ -374,12 +387,11 @@ SD_EXPORT void execScalarBool(Nd4jPointer *extraPointers, * @param xShapeInfo * @param extraParams */ -SD_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - bool biasCorrected); +SD_EXPORT void execSummaryStatsScalar( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, bool biasCorrected); /** * * @param opNum @@ -389,12 +401,11 @@ SD_EXPORT void execSummaryStatsScalar(Nd4jPointer *extraPointers, * @param result * @param resultShapeInfo */ -SD_EXPORT void execSummaryStats(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - bool biasCorrected); +SD_EXPORT void execSummaryStats( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, bool biasCorrected); /** * * @param opNum @@ -406,14 +417,14 @@ SD_EXPORT void execSummaryStats(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -SD_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - bool biasCorrected, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets); +SD_EXPORT void execSummaryStatsTad( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, void* extraParams, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, + bool biasCorrected, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets); /** * @@ -425,35 +436,37 @@ SD_EXPORT void execSummaryStatsTad(Nd4jPointer *extraPointers, * @param extraParams * @param n */ -SD_EXPORT void execTransformFloat(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams); +SD_EXPORT void execTransformFloat( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, void* extraParams); -SD_EXPORT void execTransformSame(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams); +SD_EXPORT void execTransformSame( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, void* extraParams); -SD_EXPORT void execTransformBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams); +SD_EXPORT void execTransformBool( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, void* extraParams); -SD_EXPORT void execTransformAny(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams); +SD_EXPORT void execTransformAny(Nd4jPointer* extraPointers, int opNum, + OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, + Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, void* extraParams); -SD_EXPORT void execTransformStrict(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams); +SD_EXPORT void execTransformStrict( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, void* extraParams); /** * @@ -468,36 +481,34 @@ SD_EXPORT void execTransformStrict(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -SD_EXPORT void execScalarTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ); - -SD_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ); - -SD_EXPORT void specialConcat ( - Nd4jPointer *extraPointers, - int dimension, - int numArrays, - Nd4jPointer *data, - Nd4jPointer *inputShapeInfo, - void *result, - Nd4jLong const* resultShapeInfo, - Nd4jPointer *tadPointers, - Nd4jPointer *offsetPointers); +SD_EXPORT void execScalarTad( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbScalars, + Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, + void* extraParams, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ); + +SD_EXPORT void execScalarBoolTad( + Nd4jPointer* extraPointers, int opNum, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* hZShapeInfo, + Nd4jLong const* dZShapeInfo, OpaqueDataBuffer* dbScalars, + Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, + void* extraParams, OpaqueDataBuffer* dbDimension, + Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ); + +SD_EXPORT void specialConcat(Nd4jPointer* extraPointers, int dimension, + int numArrays, Nd4jPointer* data, + Nd4jPointer* inputShapeInfo, void* result, + Nd4jLong const* resultShapeInfo, + Nd4jPointer* tadPointers, + Nd4jPointer* offsetPointers); /** * This method implementation exists only for cuda. @@ -505,7 +516,7 @@ SD_EXPORT void specialConcat ( */ SD_EXPORT void initializeDevicesAndFunctions(); -SD_EXPORT void initializeFunctions(Nd4jPointer *functions); +SD_EXPORT void initializeFunctions(Nd4jPointer* functions); /** * This method acquires memory chunk of requested size on host side @@ -521,10 +532,12 @@ SD_EXPORT Nd4jPointer mallocHost(Nd4jLong memorySize, int flags); * * @param pointer pointer that'll be used for allocation * @param memorySize memory size, in bytes - * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for OpenCL that's pointer to device_id, etc + * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for + * OpenCL that's pointer to device_id, etc * @param flags optional parameter */ -SD_EXPORT Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags); +SD_EXPORT Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, + int flags); /** * This method releases previously allocated host memory space @@ -565,7 +578,6 @@ SD_EXPORT void setOmpNumThreads(int threads); */ SD_EXPORT void setOmpMinThreads(int threads); - SD_EXPORT bool isBlasVersionMatches(int major, int minor, int build); /** @@ -674,7 +686,7 @@ SD_EXPORT int getDeviceMinor(int deviceId); * @param ptrToDeviceId * @return */ -SD_EXPORT const char * getDeviceName(int deviceId); +SD_EXPORT const char* getDeviceName(int deviceId); /** * @@ -685,11 +697,8 @@ SD_EXPORT const char * getDeviceName(int deviceId); * @param reserved * @return */ -SD_EXPORT int memcpySync(Nd4jPointer dst, - Nd4jPointer src, - Nd4jLong size, - int flags, - Nd4jPointer reserved); +SD_EXPORT int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, + int flags, Nd4jPointer reserved); /** * @@ -700,11 +709,8 @@ SD_EXPORT int memcpySync(Nd4jPointer dst, * @param reserved * @return */ -SD_EXPORT int memcpyAsync(Nd4jPointer dst, - Nd4jPointer src, - Nd4jLong size, - int flags, - Nd4jPointer reserved); +SD_EXPORT int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, + int flags, Nd4jPointer reserved); /** * @@ -715,11 +721,8 @@ SD_EXPORT int memcpyAsync(Nd4jPointer dst, * @param reserved * @return */ -SD_EXPORT int memsetSync(Nd4jPointer dst, - int value, - Nd4jLong size, - int flags, - Nd4jPointer reserved); +SD_EXPORT int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, + Nd4jPointer reserved); /** * @@ -730,11 +733,8 @@ SD_EXPORT int memsetSync(Nd4jPointer dst, * @param reserved * @return */ -SD_EXPORT int memsetAsync(Nd4jPointer dst, - int value, - Nd4jLong size, - int flags, - Nd4jPointer reserved); +SD_EXPORT int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, + Nd4jPointer reserved); /** * @@ -745,11 +745,8 @@ SD_EXPORT int memsetAsync(Nd4jPointer dst, * @param reserved * @return */ -SD_EXPORT int memcpyConstantAsync(Nd4jLong dst, - Nd4jPointer src, - Nd4jLong size, - int flags, - Nd4jPointer reserved); +SD_EXPORT int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, + int flags, Nd4jPointer reserved); /** * @@ -791,9 +788,8 @@ typedef sd::TadPack OpaqueTadPack; * @param targetBuffer * @param offsetsBuffer */ -SD_EXPORT OpaqueTadPack* tadOnlyShapeInfo(Nd4jLong const*xShapeInfo, - int *dimension, - int dimensionLength); +SD_EXPORT OpaqueTadPack* tadOnlyShapeInfo(Nd4jLong const* xShapeInfo, + int* dimension, int dimensionLength); SD_EXPORT Nd4jLong const* getPrimaryShapeInfo(OpaqueTadPack* pack); SD_EXPORT Nd4jLong const* getPrimaryOffsets(OpaqueTadPack* pack); @@ -822,15 +818,14 @@ SD_EXPORT void deleteTadPack(OpaqueTadPack* ptr); * @param zTadShapeInfo * @param zTadOffsets */ -SD_EXPORT void pullRows(Nd4jPointer *extraPointers, - OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dxShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* zShapeInfo, Nd4jLong const* dzShapeInfo, - Nd4jLong n, - Nd4jLong *indexes, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets, - Nd4jLong const* zTadShapeInfo, - Nd4jLong const* zTadOffsets); +SD_EXPORT void pullRows(Nd4jPointer* extraPointers, OpaqueDataBuffer* dbX, + Nd4jLong const* xShapeInfo, Nd4jLong const* dxShapeInfo, + OpaqueDataBuffer* dbZ, Nd4jLong const* zShapeInfo, + Nd4jLong const* dzShapeInfo, Nd4jLong n, + Nd4jLong* indexes, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, + Nd4jLong const* zTadShapeInfo, + Nd4jLong const* zTadOffsets); /** * @@ -841,24 +836,18 @@ SD_EXPORT void pullRows(Nd4jPointer *extraPointers, * @param length * @param propagate */ -SD_EXPORT void average(Nd4jPointer *extras, - Nd4jPointer *x, Nd4jLong const* xShapeInfo, - Nd4jPointer *dx, Nd4jLong const* dxShapeInfo, - void *z, Nd4jLong const* zShapeInfo, - void *dz, Nd4jLong const* dzShapeInfo, - int n, - Nd4jLong length, - bool propagate); - - -SD_EXPORT void accumulate(Nd4jPointer *extras, - Nd4jPointer *x, Nd4jLong const* xShapeInfo, - Nd4jPointer *dx, Nd4jLong const* dxShapeInfo, - void *z, Nd4jLong const* zShapeInfo, - void *dz, Nd4jLong const* dzShapeInfo, - int n, - Nd4jLong length); +SD_EXPORT void average(Nd4jPointer* extras, Nd4jPointer* x, + Nd4jLong const* xShapeInfo, Nd4jPointer* dx, + Nd4jLong const* dxShapeInfo, void* z, + Nd4jLong const* zShapeInfo, void* dz, + Nd4jLong const* dzShapeInfo, int n, Nd4jLong length, + bool propagate); +SD_EXPORT void accumulate(Nd4jPointer* extras, Nd4jPointer* x, + Nd4jLong const* xShapeInfo, Nd4jPointer* dx, + Nd4jLong const* dxShapeInfo, void* z, + Nd4jLong const* zShapeInfo, void* dz, + Nd4jLong const* dzShapeInfo, int n, Nd4jLong length); /** * P2P enabler @@ -896,16 +885,12 @@ SD_EXPORT bool isP2PAvailable(); * @param tadShapeInfo * @param tadOffsets */ -SD_EXPORT void shuffle(Nd4jPointer *extras, - Nd4jPointer *x, Nd4jPointer *xShapeInfo, - Nd4jPointer *dx, Nd4jPointer *dxShapeInfo, - Nd4jPointer *z, Nd4jPointer *zShapeInfo, - Nd4jPointer *dz, Nd4jPointer *dzShapeInfo, - int N, - int *shuffleMap, - Nd4jPointer *tadShapeInfo, - Nd4jPointer *tadOffsets); - +SD_EXPORT void shuffle(Nd4jPointer* extras, Nd4jPointer* x, + Nd4jPointer* xShapeInfo, Nd4jPointer* dx, + Nd4jPointer* dxShapeInfo, Nd4jPointer* z, + Nd4jPointer* zShapeInfo, Nd4jPointer* dz, + Nd4jPointer* dzShapeInfo, int N, int* shuffleMap, + Nd4jPointer* tadShapeInfo, Nd4jPointer* tadOffsets); /** * Type Conversions @@ -920,8 +905,8 @@ SD_EXPORT void shuffle(Nd4jPointer *extras, * @param dstType * @param z */ -SD_EXPORT void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer x, Nd4jLong N, int dstType, Nd4jPointer z); - +SD_EXPORT void convertTypes(Nd4jPointer* extras, int srcType, Nd4jPointer x, + Nd4jLong N, int dstType, Nd4jPointer z); /** * @@ -948,44 +933,25 @@ SD_EXPORT bool isExperimentalEnabled(); * @param realArguments * @param numRealArguments */ -SD_EXPORT void execAggregate(Nd4jPointer *extraPointers, - int opNum, - void **arguments, - int numArguments, - Nd4jLong **shapeArguments, - int numShapeArguments, - int *indexArguments, - int numIndexArguments, - int **intArrays, - int numIntArrays, - void *realArguments, - int numRealArguments, - sd::DataType dtype); - - -SD_EXPORT void batchExecutor(Nd4jPointer *extraPointers, - int numAggregates, - int opNum, - int maxArgs, - int maxShapes, - int maxIntArrays, - int maxIntArraySize, - int maxIdx, - int maxReals, - void *ptrToArguments, - sd::DataType dtype); - -SD_EXPORT void execAggregateBatch(Nd4jPointer *extraPointers, - int numAggregates, - int opNum, - int maxArgs, - int maxShapes, - int maxIntArrays, - int maxIntArraySize, - int maxIdx, - int maxReals, - void *ptrToArguments, - sd::DataType dtype); +SD_EXPORT void execAggregate(Nd4jPointer* extraPointers, int opNum, + void** arguments, int numArguments, + Nd4jLong** shapeArguments, int numShapeArguments, + int* indexArguments, int numIndexArguments, + int** intArrays, int numIntArrays, + void* realArguments, int numRealArguments, + sd::DataType dtype); + +SD_EXPORT void batchExecutor(Nd4jPointer* extraPointers, int numAggregates, + int opNum, int maxArgs, int maxShapes, + int maxIntArrays, int maxIntArraySize, int maxIdx, + int maxReals, void* ptrToArguments, + sd::DataType dtype); + +SD_EXPORT void execAggregateBatch(Nd4jPointer* extraPointers, int numAggregates, + int opNum, int maxArgs, int maxShapes, + int maxIntArrays, int maxIntArraySize, + int maxIdx, int maxReals, + void* ptrToArguments, sd::DataType dtype); /** * Random operations @@ -1000,11 +966,10 @@ SD_EXPORT void execAggregateBatch(Nd4jPointer *extraPointers, * @param zShapeBuffer * @param extraArguments */ -SD_EXPORT void execRandom(Nd4jPointer *extraPointers, - int opNum, - Nd4jPointer state, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeBuffer, Nd4jLong const* dZShapeBuffer, - void *extraArguments); +SD_EXPORT void execRandom(Nd4jPointer* extraPointers, int opNum, + Nd4jPointer state, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeBuffer, + Nd4jLong const* dZShapeBuffer, void* extraArguments); /** * @@ -1019,13 +984,14 @@ SD_EXPORT void execRandom(Nd4jPointer *extraPointers, * @param zShapeBuffer * @param extraArguments */ -SD_EXPORT void execRandom3(Nd4jPointer *extraPointers, - int opNum, - Nd4jPointer state, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeBuffer, Nd4jLong const* dXShapeBuffer, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeBuffer, Nd4jLong const* dYShapeBuffer, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeBuffer, Nd4jLong const* dZShapeBuffer, - void *extraArguments); +SD_EXPORT void execRandom3(Nd4jPointer* extraPointers, int opNum, + Nd4jPointer state, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeBuffer, + Nd4jLong const* dXShapeBuffer, OpaqueDataBuffer* dbY, + Nd4jLong const* hYShapeBuffer, + Nd4jLong const* dYShapeBuffer, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeBuffer, + Nd4jLong const* dZShapeBuffer, void* extraArguments); /** * @@ -1038,13 +1004,12 @@ SD_EXPORT void execRandom3(Nd4jPointer *extraPointers, * @param zShapeBuffer * @param extraArguments */ -SD_EXPORT void execRandom2(Nd4jPointer *extraPointers, - int opNum, - Nd4jPointer state, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeBuffer, Nd4jLong const* dXShapeBuffer, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeBuffer, Nd4jLong const* dZShapeBuffer, - void *extraArguments); - +SD_EXPORT void execRandom2(Nd4jPointer* extraPointers, int opNum, + Nd4jPointer state, OpaqueDataBuffer* dbX, + Nd4jLong const* hXShapeBuffer, + Nd4jLong const* dXShapeBuffer, OpaqueDataBuffer* dbZ, + Nd4jLong const* hZShapeBuffer, + Nd4jLong const* dZShapeBuffer, void* extraArguments); /** * @@ -1054,10 +1019,8 @@ SD_EXPORT void execRandom2(Nd4jPointer *extraPointers, * @param ptrToBuffer * @return */ -SD_EXPORT Nd4jPointer initRandom(Nd4jPointer *extraPointers, - long seed, - long bufferSize, - Nd4jPointer ptrToBuffer); +SD_EXPORT Nd4jPointer initRandom(Nd4jPointer* extraPointers, long seed, + long bufferSize, Nd4jPointer ptrToBuffer); /** * @@ -1065,9 +1028,8 @@ SD_EXPORT Nd4jPointer initRandom(Nd4jPointer *extraPointers, * @param seed * @param ptrRandom */ -SD_EXPORT void refreshBuffer(Nd4jPointer *extraPointers, - long seed, - Nd4jPointer ptrRandom); +SD_EXPORT void refreshBuffer(Nd4jPointer* extraPointers, long seed, + Nd4jPointer ptrRandom); /** * @@ -1075,329 +1037,337 @@ SD_EXPORT void refreshBuffer(Nd4jPointer *extraPointers, * @param seed * @param ptrRandom */ -SD_EXPORT void reSeedBuffer(Nd4jPointer *extraPointers, - long seed, - Nd4jPointer ptrRandom); +SD_EXPORT void reSeedBuffer(Nd4jPointer* extraPointers, long seed, + Nd4jPointer ptrRandom); /** * * @param ptrRandom */ SD_EXPORT void destroyRandom(Nd4jPointer ptrRandom); - } /** -* -* @param data -* @param shapeBuffer -* @param wordSize -* @param headerSize -* @return -*/ + * + * @param data + * @param shapeBuffer + * @param wordSize + * @param headerSize + * @return + */ template -static Nd4jPointer _numpyHeaderForNd4j(Nd4jPointer data,const Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong* headerSize) { - Nd4jLong const* shapeBufferCast = reinterpret_cast(shapeBuffer); - int rank = shape::rank(shapeBufferCast); - const Nd4jLong* shape = shape::shapeOf(shapeBufferCast); - unsigned int* npShape = new unsigned int[rank]; - for(int i = 0; i < rank; i++) { - npShape[i] = shape[i]; - } - - Nd4jLong length = shape::prodLong(shape,rank); - auto npHeader = cnpy::createNpyHeader(data,npShape,rank,wordSize); - char *ret = new char[npHeader.size() + 1]; - int count = 0; - for(int i = 0; i < npHeader.size(); i++) { - ret[count] = npHeader[i]; - count++; - } - - ret[count] = '\0'; +static Nd4jPointer _numpyHeaderForNd4j(Nd4jPointer data, + const Nd4jPointer shapeBuffer, + Nd4jLong wordSize, + Nd4jLong* headerSize) { + Nd4jLong const* shapeBufferCast = + reinterpret_cast(shapeBuffer); + int rank = shape::rank(shapeBufferCast); + const Nd4jLong* shape = shape::shapeOf(shapeBufferCast); + unsigned int* npShape = new unsigned int[rank]; + for (int i = 0; i < rank; i++) { + npShape[i] = shape[i]; + } + + Nd4jLong length = shape::prodLong(shape, rank); + auto npHeader = cnpy::createNpyHeader(data, npShape, rank, wordSize); + char* ret = new char[npHeader.size() + 1]; + int count = 0; + for (int i = 0; i < npHeader.size(); i++) { + ret[count] = npHeader[i]; count++; + } + + ret[count] = '\0'; + count++; - *headerSize = count; - return reinterpret_cast(ret); + *headerSize = count; + return reinterpret_cast(ret); } extern "C" { -static Nd4jPointer numpyHeaderForNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong* headerSize) { - auto shapeBufferCast = reinterpret_cast(shapeBuffer); - auto type = sd::ArrayOptions::dataType(shapeBufferCast); - BUILD_SINGLE_SELECTOR(type, return _numpyHeaderForNd4j, (data, shapeBuffer, wordSize, headerSize), LIBND4J_TYPES); +static Nd4jPointer numpyHeaderForNd4j(Nd4jPointer data, Nd4jPointer shapeBuffer, + Nd4jLong wordSize, Nd4jLong* headerSize) { + auto shapeBufferCast = reinterpret_cast(shapeBuffer); + auto type = sd::ArrayOptions::dataType(shapeBufferCast); + BUILD_SINGLE_SELECTOR(type, return _numpyHeaderForNd4j, + (data, shapeBuffer, wordSize, headerSize), + LIBND4J_TYPES); } /** -* Load numpy from a header -* based on the cnpy parse from header method. -* @param data the header data to parse -* @return a pointer to a numpy cnpy:NpyArray struct -*/ + * Load numpy from a header + * based on the cnpy parse from header method. + * @param data the header data to parse + * @return a pointer to a numpy cnpy:NpyArray struct + */ static Nd4jPointer loadNpyFromHeader(Nd4jPointer data) { - char *header = reinterpret_cast(data); - - cnpy::NpyArray arr = cnpy::loadNpyFromHeader(header); - cnpy::NpyArray *ret = new cnpy::NpyArray(); - int totalLengthOfShape = 1; - for(int i = 0; i < arr.shape.size(); i++) { - totalLengthOfShape *= arr.shape[i]; - } - - ret->data = arr.data; - ret->wordSize = arr.wordSize; - ret->shape = arr.shape; - return reinterpret_cast(ret); + char* header = reinterpret_cast(data); + + cnpy::NpyArray arr = cnpy::loadNpyFromHeader(header); + cnpy::NpyArray* ret = new cnpy::NpyArray(); + int totalLengthOfShape = 1; + for (int i = 0; i < arr.shape.size(); i++) { + totalLengthOfShape *= arr.shape[i]; + } + + ret->data = arr.data; + ret->wordSize = arr.wordSize; + ret->shape = arr.shape; + return reinterpret_cast(ret); } - } /** -* Create a numpy array from an nd4j -* array -* @param data a pointer to the data -* @param shapeBuffer the shapebuffer for the nd4j array -* @param wordSize the word size (4 for float, 8 for doubles) -* @return a pointer to a numpy array -*/ + * Create a numpy array from an nd4j + * array + * @param data a pointer to the data + * @param shapeBuffer the shapebuffer for the nd4j array + * @param wordSize the word size (4 for float, 8 for doubles) + * @return a pointer to a numpy array + */ template -static Nd4jPointer _numpyFromNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize) { - Nd4jLong *shapeBufferCast = reinterpret_cast(shapeBuffer); - int rank = shape::rank(shapeBufferCast); - Nd4jLong *shape = shape::shapeOf(shapeBufferCast); - unsigned int *npShape = new unsigned int[rank]; - for(int i = 0; i < rank; i++) { - npShape[i] = shape[i]; - } - - Nd4jLong length = shape::prodLong(shape,rank); - auto npHeader = cnpy::createNpyHeader(data,npShape,rank,wordSize); - char *dataChar = reinterpret_cast(data); - char *npHeaderData = npHeader.data(); - char *ret = new char[(wordSize * length) + npHeader.size()]; - char *cursorStart = ret; - std::memcpy(reinterpret_cast(ret), reinterpret_cast(npHeaderData), npHeader.size() * sizeof(Nd4jLong)); - //move to next - cursorStart += npHeader.size(); - std::memcpy(reinterpret_cast(ret), reinterpret_cast(dataChar), length * wordSize * sizeof(Nd4jLong)); - Nd4jPointer rettPointer = reinterpret_cast(ret); - return rettPointer; +static Nd4jPointer _numpyFromNd4j(Nd4jPointer data, Nd4jPointer shapeBuffer, + Nd4jLong wordSize) { + Nd4jLong* shapeBufferCast = reinterpret_cast(shapeBuffer); + int rank = shape::rank(shapeBufferCast); + Nd4jLong* shape = shape::shapeOf(shapeBufferCast); + unsigned int* npShape = new unsigned int[rank]; + for (int i = 0; i < rank; i++) { + npShape[i] = shape[i]; + } + + Nd4jLong length = shape::prodLong(shape, rank); + auto npHeader = cnpy::createNpyHeader(data, npShape, rank, wordSize); + char* dataChar = reinterpret_cast(data); + char* npHeaderData = npHeader.data(); + char* ret = new char[(wordSize * length) + npHeader.size()]; + char* cursorStart = ret; + std::memcpy(reinterpret_cast(ret), + reinterpret_cast(npHeaderData), + npHeader.size() * sizeof(Nd4jLong)); + // move to next + cursorStart += npHeader.size(); + std::memcpy(reinterpret_cast(ret), reinterpret_cast(dataChar), + length * wordSize * sizeof(Nd4jLong)); + Nd4jPointer rettPointer = reinterpret_cast(ret); + return rettPointer; } extern "C" { -static Nd4jPointer numpyFromNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize) { - auto shapeBufferCast = reinterpret_cast(shapeBuffer); - auto type = sd::ArrayOptions::dataType(shapeBufferCast); - BUILD_SINGLE_SELECTOR(type, return _numpyFromNd4j, (data, shapeBuffer, wordSize), LIBND4J_TYPES); +static Nd4jPointer numpyFromNd4j(Nd4jPointer data, Nd4jPointer shapeBuffer, + Nd4jLong wordSize) { + auto shapeBufferCast = reinterpret_cast(shapeBuffer); + auto type = sd::ArrayOptions::dataType(shapeBufferCast); + BUILD_SINGLE_SELECTOR(type, return _numpyFromNd4j, + (data, shapeBuffer, wordSize), LIBND4J_TYPES); } - /** -* -* @param npyArray -* @return -*/ + * + * @param npyArray + * @return + */ SD_EXPORT Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray); - /** -* Get the shape buffer from a -* numpy array. -* **Warning** this allocates memory -* @param npyArray -* @return -*/ + * Get the shape buffer from a + * numpy array. + * **Warning** this allocates memory + * @param npyArray + * @return + */ static Nd4jPointer shapeBufferForNumpyHeader(Nd4jPointer npyArray) { - cnpy::NpyArray arr = cnpy::loadNpyFromHeader(reinterpret_cast(npyArray)); - auto shape = new unsigned int[arr.shape.size()]; - for(unsigned int i = 0; i < arr.shape.size(); i++) { - shape[i] = arr.shape[i]; - } - - auto shapeBuffer = shape::shapeBufferOfNpy(arr.shape.size(), shape, arr.fortranOrder); - delete[] shape; - return reinterpret_cast(shapeBuffer); + cnpy::NpyArray arr = + cnpy::loadNpyFromHeader(reinterpret_cast(npyArray)); + auto shape = new unsigned int[arr.shape.size()]; + for (unsigned int i = 0; i < arr.shape.size(); i++) { + shape[i] = arr.shape[i]; + } + + auto shapeBuffer = + shape::shapeBufferOfNpy(arr.shape.size(), shape, arr.fortranOrder); + delete[] shape; + return reinterpret_cast(shapeBuffer); } - - /** -* -* @param npyArray -* @return -*/ + * + * @param npyArray + * @return + */ static Nd4jPointer dataPointForNumpyHeader(Nd4jPointer npyArray) { - cnpy::NpyArray arr = cnpy::loadNpyFromHeader(reinterpret_cast(npyArray)); - unsigned char *dataToPrint = reinterpret_cast(arr.data); - return dataToPrint; + cnpy::NpyArray arr = + cnpy::loadNpyFromHeader(reinterpret_cast(npyArray)); + unsigned char* dataToPrint = reinterpret_cast(arr.data); + return dataToPrint; } /** -* -* @param npyArray -* @return -*/ + * + * @param npyArray + * @return + */ static Nd4jPointer dataPointForNumpyStruct(Nd4jPointer npyArrayStruct) { - cnpy::NpyArray *arrPointer = reinterpret_cast(npyArrayStruct); - unsigned char *dataToPrint = reinterpret_cast(arrPointer->data); - return reinterpret_cast(dataToPrint); + cnpy::NpyArray* arrPointer = + reinterpret_cast(npyArrayStruct); + unsigned char* dataToPrint = + reinterpret_cast(arrPointer->data); + return reinterpret_cast(dataToPrint); } /** -* -* @param npyArray -* @param fromFile -* @return -*/ + * + * @param npyArray + * @param fromFile + * @return + */ static Nd4jPointer dataPointForNumpy(Nd4jPointer npyArray) { - char *npyArrayBuffer = reinterpret_cast< char *>(npyArray); - cnpy::NpyArray arr = cnpy::loadNpyFromPointer(npyArrayBuffer); - return dataPointForNumpyStruct(reinterpret_cast(&arr)); + char* npyArrayBuffer = reinterpret_cast(npyArray); + cnpy::NpyArray arr = cnpy::loadNpyFromPointer(npyArrayBuffer); + return dataPointForNumpyStruct(reinterpret_cast(&arr)); } /** -* Load a numpy array from a file -* and return it as an Nd4jPointer -* @param path -* @return -*/ + * Load a numpy array from a file + * and return it as an Nd4jPointer + * @param path + * @return + */ static Nd4jPointer numpyFromFile(std::string path) { - char *numpyBuffer = cnpy::loadFile(path.data()); - return reinterpret_cast(numpyBuffer); + char* numpyBuffer = cnpy::loadFile(path.data()); + return reinterpret_cast(numpyBuffer); } - ////// NPZ ////// -static void* mapFromNpzFile(std::string path){ - cnpy::npz_t* mapPtr = new cnpy::npz_t(); - cnpy::npz_t map = cnpy::npzLoad(path); - mapPtr->insert(map.begin(), map.end()); - return reinterpret_cast(mapPtr); +static void* mapFromNpzFile(std::string path) { + cnpy::npz_t* mapPtr = new cnpy::npz_t(); + cnpy::npz_t map = cnpy::npzLoad(path); + mapPtr->insert(map.begin(), map.end()); + return reinterpret_cast(mapPtr); } - -static int getNumNpyArraysInMap(void *map){ - cnpy::npz_t* arrays = reinterpret_cast(map); - int n = arrays->size(); - return n; +static int getNumNpyArraysInMap(void* map) { + cnpy::npz_t* arrays = reinterpret_cast(map); + int n = arrays->size(); + return n; } -static const char* getNpyArrayNameFromMap(void *map, int index){ - cnpy::npz_t* arrays = reinterpret_cast(map); - cnpy::npz_t::iterator it = arrays->begin(); - cnpy::npz_t::iterator end = arrays->end(); - int cnt = 0; - for(; it != end; ++it, ++cnt){ - if (cnt == index){ - // FIXME: @fariz, this is a leak! +static const char* getNpyArrayNameFromMap(void* map, int index) { + cnpy::npz_t* arrays = reinterpret_cast(map); + cnpy::npz_t::iterator it = arrays->begin(); + cnpy::npz_t::iterator end = arrays->end(); + int cnt = 0; + for (; it != end; ++it, ++cnt) { + if (cnt == index) { + // FIXME: @fariz, this is a leak! #ifdef _MSC_VER - return const_cast(_strdup(it->first.c_str())); + return const_cast(_strdup(it->first.c_str())); #else - return const_cast(strdup(it->first.c_str())); + return const_cast(strdup(it->first.c_str())); #endif - } } - throw std::runtime_error("No array at index."); + } + throw std::runtime_error("No array at index."); } -static void* getNpyArrayFromMap(void *map, int index){ - cnpy::npz_t* arrays = reinterpret_cast(map); - cnpy::npz_t::iterator it = arrays->begin(); - cnpy::npz_t::iterator end = arrays->end(); - cnpy::NpyArray *arr = new cnpy::NpyArray(); - int cnt = 0; - for(; it != end; ++it, ++cnt){ - if (cnt == index){ - *arr = it->second; - return arr; - } +static void* getNpyArrayFromMap(void* map, int index) { + cnpy::npz_t* arrays = reinterpret_cast(map); + cnpy::npz_t::iterator it = arrays->begin(); + cnpy::npz_t::iterator end = arrays->end(); + cnpy::NpyArray* arr = new cnpy::NpyArray(); + int cnt = 0; + for (; it != end; ++it, ++cnt) { + if (cnt == index) { + *arr = it->second; + return arr; } - throw std::runtime_error("No array at index."); + } + throw std::runtime_error("No array at index."); } -SD_EXPORT int dataTypeFromNpyHeader(void *header); +SD_EXPORT int dataTypeFromNpyHeader(void* header); -static void* getNpyArrayData(void *npArray){ - cnpy::NpyArray* npyArray2 = reinterpret_cast(npArray); - return reinterpret_cast(npyArray2->data); +static void* getNpyArrayData(void* npArray) { + cnpy::NpyArray* npyArray2 = reinterpret_cast(npArray); + return reinterpret_cast(npyArray2->data); } -static int getNpyArrayRank(void *npArray){ - cnpy::NpyArray* arr = reinterpret_cast(npArray); - int rank = arr->shape.size(); - return rank; +static int getNpyArrayRank(void* npArray) { + cnpy::NpyArray* arr = reinterpret_cast(npArray); + int rank = arr->shape.size(); + return rank; } -static Nd4jLong* getNpyArrayShape(void *npArray){ - cnpy::NpyArray* arr = reinterpret_cast(npArray); - int ndim = arr->shape.size(); - Nd4jLong* shape = new Nd4jLong[ndim]; - for (int i=0; ishape.at(i); - } - return shape; +static Nd4jLong* getNpyArrayShape(void* npArray) { + cnpy::NpyArray* arr = reinterpret_cast(npArray); + int ndim = arr->shape.size(); + Nd4jLong* shape = new Nd4jLong[ndim]; + for (int i = 0; i < ndim; i++) { + shape[i] = arr->shape.at(i); + } + return shape; } -static char getNpyArrayOrder(void *npArray){ - cnpy::NpyArray* arr = reinterpret_cast(npArray); - return (arr->fortranOrder)?'f':'c'; +static char getNpyArrayOrder(void* npArray) { + cnpy::NpyArray* arr = reinterpret_cast(npArray); + return (arr->fortranOrder) ? 'f' : 'c'; } -static int getNpyArrayElemSize(void *npArray){ - cnpy::NpyArray* arr = reinterpret_cast(npArray); - return arr->wordSize; +static int getNpyArrayElemSize(void* npArray) { + cnpy::NpyArray* arr = reinterpret_cast(npArray); + return arr->wordSize; } -static void deleteNPArrayStruct(void *npArray){ - cnpy::NpyArray* arr = reinterpret_cast(npArray); - delete arr; +static void deleteNPArrayStruct(void* npArray) { + cnpy::NpyArray* arr = reinterpret_cast(npArray); + delete arr; } -static void deleteNPArrayMap(void *map){ - cnpy::npz_t* arrays = reinterpret_cast(map); - delete arrays; +static void deleteNPArrayMap(void* map) { + cnpy::npz_t* arrays = reinterpret_cast(map); + delete arrays; } ////// /** -* Get the element size for a numpy array -* @param npyArray the numpy array's address -* to get the length for -* @return -*/ + * Get the element size for a numpy array + * @param npyArray the numpy array's address + * to get the length for + * @return + */ static int elementSizeForNpyArray(Nd4jPointer npyArray) { - cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); - cnpy::NpyArray *arrPointer = &arr; - int size = arrPointer->wordSize; - // arrPointer->destruct(); - return size; + cnpy::NpyArray arr = + cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); + cnpy::NpyArray* arrPointer = &arr; + int size = arrPointer->wordSize; + // arrPointer->destruct(); + return size; } - /** -* Get the element size for a numpy array -* @param npyArray the numpy array's address -* to get the length for -* @return -*/ + * Get the element size for a numpy array + * @param npyArray the numpy array's address + * to get the length for + * @return + */ static int elementSizeForNpyArrayHeader(Nd4jPointer npyArray) { - cnpy::NpyArray arr = cnpy::loadNpyFromHeader(reinterpret_cast(npyArray)); - cnpy::NpyArray *arrPointer = &arr; - int size = arrPointer->wordSize; - return size; + cnpy::NpyArray arr = + cnpy::loadNpyFromHeader(reinterpret_cast(npyArray)); + cnpy::NpyArray* arrPointer = &arr; + int size = arrPointer->wordSize; + return size; } - static void releaseNumpy(Nd4jPointer npyArray) { - free(reinterpret_cast(npyArray)); + free(reinterpret_cast(npyArray)); } - /** * Return the length of a shape buffer * based on the pointer @@ -1406,18 +1376,18 @@ static void releaseNumpy(Nd4jPointer npyArray) { */ SD_EXPORT int lengthForShapeBufferPointer(Nd4jPointer buffer); - - /** -* The pointer to get the address for -* -* @param address the address to get the pointer -* @return the pointer for the given address -*/ +/** + * The pointer to get the address for + * + * @param address the address to get the pointer + * @return the pointer for the given address + */ SD_EXPORT Nd4jPointer pointerForAddress(Nd4jLong address); /** - * This method takes single N-dimensional tensor, and copies its TADs to target arrays + * This method takes single N-dimensional tensor, and copies its TADs to target + * arrays * * @param x * @param xShapeInfo @@ -1425,71 +1395,62 @@ SD_EXPORT Nd4jPointer pointerForAddress(Nd4jLong address); * @param zShapeInfo * @return */ -SD_EXPORT void tear(Nd4jPointer *extraPointers, - OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dxShapeInfo, - Nd4jPointer *targets, Nd4jLong const* zShapeInfo, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets); - -SD_EXPORT void sort(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dx, Nd4jLong const* dxShapeInfo, - bool descending); - -SD_EXPORT void sortByKey(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dx, Nd4jLong const* dxShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - bool descending); - -SD_EXPORT void sortByValue(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dx, Nd4jLong const* dxShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - bool descending); - -SD_EXPORT void sortTad(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dx, Nd4jLong const* dxShapeInfo, - int *dimension, - int dimensionLength, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets, - bool descending); - -SD_EXPORT void sortTadByKey(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dx, Nd4jLong const* dxShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - int *dimension, - int dimensionLength, - bool descending); - -SD_EXPORT void sortTadByValue(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dx, Nd4jLong const* dxShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - int *dimension, - int dimensionLength, - bool descending); - +SD_EXPORT void tear(Nd4jPointer* extraPointers, OpaqueDataBuffer* dbX, + Nd4jLong const* xShapeInfo, Nd4jLong const* dxShapeInfo, + Nd4jPointer* targets, Nd4jLong const* zShapeInfo, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets); + +SD_EXPORT void sort(Nd4jPointer* extraPointers, void* x, + Nd4jLong const* xShapeInfo, void* dx, + Nd4jLong const* dxShapeInfo, bool descending); + +SD_EXPORT void sortByKey(Nd4jPointer* extraPointers, void* x, + Nd4jLong const* xShapeInfo, void* dx, + Nd4jLong const* dxShapeInfo, void* y, + Nd4jLong const* yShapeInfo, void* dy, + Nd4jLong const* dyShapeInfo, bool descending); + +SD_EXPORT void sortByValue(Nd4jPointer* extraPointers, void* x, + Nd4jLong const* xShapeInfo, void* dx, + Nd4jLong const* dxShapeInfo, void* y, + Nd4jLong const* yShapeInfo, void* dy, + Nd4jLong const* dyShapeInfo, bool descending); + +SD_EXPORT void sortTad(Nd4jPointer* extraPointers, void* x, + Nd4jLong const* xShapeInfo, void* dx, + Nd4jLong const* dxShapeInfo, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, bool descending); + +SD_EXPORT void sortTadByKey(Nd4jPointer* extraPointers, void* x, + Nd4jLong const* xShapeInfo, void* dx, + Nd4jLong const* dxShapeInfo, void* y, + Nd4jLong const* yShapeInfo, void* dy, + Nd4jLong const* dyShapeInfo, int* dimension, + int dimensionLength, bool descending); + +SD_EXPORT void sortTadByValue(Nd4jPointer* extraPointers, void* x, + Nd4jLong const* xShapeInfo, void* dx, + Nd4jLong const* dxShapeInfo, void* y, + Nd4jLong const* yShapeInfo, void* dy, + Nd4jLong const* dyShapeInfo, int* dimension, + int dimensionLength, bool descending); // special sort impl for sorting out COO indices and values -SD_EXPORT void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank); +SD_EXPORT void sortCooIndices(Nd4jPointer* extraPointers, Nd4jLong* indices, + void* values, Nd4jLong length, int rank); +SD_EXPORT Nd4jLong* mmapFile(Nd4jPointer* extraPointers, const char* fileName, + Nd4jLong length); -SD_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length); - -SD_EXPORT void munmapFile(Nd4jPointer *extraPointers, Nd4jLong* ptrMap, Nd4jLong length); +SD_EXPORT void munmapFile(Nd4jPointer* extraPointers, Nd4jLong* ptrMap, + Nd4jLong length); typedef sd::graph::ResultWrapper OpaqueResultWrapper; // flatbuffers execution -SD_EXPORT OpaqueResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer); +SD_EXPORT OpaqueResultWrapper* executeFlatGraph(Nd4jPointer* extraPointers, + Nd4jPointer flatBufferPointer); SD_EXPORT Nd4jLong getResultWrapperSize(OpaqueResultWrapper* ptr); SD_EXPORT Nd4jPointer getResultWrapperPointer(OpaqueResultWrapper* ptr); @@ -1499,25 +1460,44 @@ SD_EXPORT const char* getAllCustomOps(); SD_EXPORT const char* getAllOperations(); // customOp executioner -SD_EXPORT int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace); -SD_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext); +SD_EXPORT int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, + Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, + int numInputs, Nd4jPointer* outputBuffers, + Nd4jPointer* outputShapes, int numOutputs, + double* tArgs, int numTArgs, Nd4jLong* iArgs, + int numIArgs, bool* bArgs, int numBArgs, + bool isInplace); +SD_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, + Nd4jPointer opContext); typedef sd::ShapeList OpaqueShapeList; -SD_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs); -SD_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs); +SD_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, + Nd4jLong hash, + Nd4jPointer* inputShapes, + int numInputShapes, + double* tArgs, int numTArgs, + Nd4jLong* iArgs, int numIArgs); +SD_EXPORT OpaqueShapeList* calculateOutputShapes2( + Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, + Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, + Nd4jLong* iArgs, int numIArgs, bool* bArgs, int numBArgs, int* dArgs, + int numDArgs); SD_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list); SD_EXPORT Nd4jLong const* getShape(OpaqueShapeList* list, Nd4jLong i); SD_EXPORT void deleteShapeList(Nd4jPointer shapeList); -SD_EXPORT int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer); +SD_EXPORT int registerGraph(Nd4jPointer* extraPointers, Nd4jLong graphId, + Nd4jPointer flatBufferPointer); typedef sd::graph::VariablesSet OpaqueVariablesSet; typedef sd::graph::Variable OpaqueVariable; -SD_EXPORT OpaqueVariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs); +SD_EXPORT OpaqueVariablesSet* executeStoredGraph( + Nd4jPointer* extraPointers, Nd4jLong graphId, Nd4jPointer* inputBuffers, + Nd4jPointer* inputShapes, int* inputIndices, int numInputs); SD_EXPORT Nd4jLong getVariablesSetSize(OpaqueVariablesSet* set); SD_EXPORT Nd4jStatus getVariablesSetStatus(OpaqueVariablesSet* set); @@ -1528,7 +1508,7 @@ SD_EXPORT const char* getVariableName(OpaqueVariable* variable); SD_EXPORT Nd4jLong const* getVariableShape(OpaqueVariable* variable); SD_EXPORT void* getVariableBuffer(OpaqueVariable* variable); -SD_EXPORT int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId); +SD_EXPORT int unregisterGraph(Nd4jPointer* extraPointers, Nd4jLong graphId); SD_EXPORT void deleteCharArray(Nd4jPointer pointer); SD_EXPORT void deleteIntArray(Nd4jPointer pointer); @@ -1544,37 +1524,61 @@ SD_EXPORT void deleteGraphState(Nd4jPointer state); SD_EXPORT void deleteResultWrapper(Nd4jPointer ptr); -SD_EXPORT int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer x, Nd4jLong const* xShapeInfo, int N, float threshold); - -// this method executes op that requires scope to be present: if/while/cond/whatever -SD_EXPORT Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs); - -//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer); -SD_EXPORT Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int length); -SD_EXPORT Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr); -SD_EXPORT char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr); -SD_EXPORT void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr); - -SD_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, - void* hX, Nd4jLong const* hXShapeInfo, Nd4jLong const* hXOffsets, - void* dX, Nd4jLong const* dXShapeInfo, Nd4jLong const* dXOffsets, - void* hY, Nd4jLong const* hYShapeInfo, Nd4jLong const* hYOffsets, - void* dY, Nd4jLong const* dYShapeInfo, Nd4jLong const* dYOffsets, - void* hIindexes, Nd4jLong const* hIndicesShapeInfo, void* dIindexes, Nd4jLong const* dIndicesShapeInfo); - -SD_EXPORT void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo); - +SD_EXPORT int estimateThreshold(Nd4jPointer* extraPointers, Nd4jPointer x, + Nd4jLong const* xShapeInfo, int N, + float threshold); + +// this method executes op that requires scope to be present: +// if/while/cond/whatever +SD_EXPORT Nd4jStatus execCustomOpWithScope( + Nd4jPointer* extraPointers, Nd4jPointer state, Nd4jLong opHash, + Nd4jLong* scopes, int numScopes, Nd4jPointer* inputBuffers, + Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, + Nd4jPointer* outputShapes, int numOutputs); + +// void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int +// numStrings, Nd4jPointer buffer); +SD_EXPORT Nd4jPointer createUtf8String(Nd4jPointer* extraPointers, + const char* string, int length); +SD_EXPORT Nd4jLong getUtf8StringLength(Nd4jPointer* extraPointers, + Nd4jPointer ptr); +SD_EXPORT char* getUtf8StringBuffer(Nd4jPointer* extraPointers, + Nd4jPointer ptr); +SD_EXPORT void deleteUtf8String(Nd4jPointer* extraPointers, Nd4jPointer ptr); + +SD_EXPORT void scatterUpdate( + Nd4jPointer* extraPointers, int opCode, int numOfSubArrs, void* hX, + Nd4jLong const* hXShapeInfo, Nd4jLong const* hXOffsets, void* dX, + Nd4jLong const* dXShapeInfo, Nd4jLong const* dXOffsets, void* hY, + Nd4jLong const* hYShapeInfo, Nd4jLong const* hYOffsets, void* dY, + Nd4jLong const* dYShapeInfo, Nd4jLong const* dYOffsets, void* hIindexes, + Nd4jLong const* hIndicesShapeInfo, void* dIindexes, + Nd4jLong const* dIndicesShapeInfo); + +SD_EXPORT void inspectArray(Nd4jPointer* extraPointers, Nd4jPointer buffer, + Nd4jLong* shapeInfo, Nd4jPointer specialBuffer, + Nd4jLong* specialShapeInfo, Nd4jPointer debugInfo); typedef sd::ConstantDataBuffer OpaqueConstantDataBuffer; -SD_EXPORT OpaqueConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty); - -SD_EXPORT OpaqueConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong const* data, int length); -SD_EXPORT OpaqueConstantDataBuffer* constantBufferDouble(sd::DataType dtype, double *data, int length); -SD_EXPORT OpaqueConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::ConstantDescriptor *descriptor); - -SD_EXPORT Nd4jPointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer* dbf); -SD_EXPORT Nd4jPointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer* dbf); +SD_EXPORT OpaqueConstantDataBuffer* shapeBuffer(int rank, Nd4jLong* shape, + Nd4jLong* strides, + sd::DataType dtype, char order, + Nd4jLong ews, bool empty); + +SD_EXPORT OpaqueConstantDataBuffer* constantBufferLong(sd::DataType dtype, + Nd4jLong const* data, + int length); +SD_EXPORT OpaqueConstantDataBuffer* constantBufferDouble(sd::DataType dtype, + double* data, + int length); +SD_EXPORT OpaqueConstantDataBuffer* constantBuffer( + sd::DataType dtype, sd::ConstantDescriptor* descriptor); + +SD_EXPORT Nd4jPointer +getConstantDataBufferPrimary(OpaqueConstantDataBuffer* dbf); +SD_EXPORT Nd4jPointer +getConstantDataBufferSpecial(OpaqueConstantDataBuffer* dbf); SD_EXPORT Nd4jLong getConstantDataBufferLength(OpaqueConstantDataBuffer* dbf); SD_EXPORT Nd4jLong getConstantDataBufferSizeOf(OpaqueConstantDataBuffer* dbf); @@ -1584,29 +1588,55 @@ typedef sd::graph::Context OpaqueContext; typedef sd::graph::RandomGenerator OpaqueRandomGenerator; SD_EXPORT OpaqueContext* createGraphContext(int nodeId); -SD_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr); +SD_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator( + OpaqueContext* ptr); SD_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow); -SD_EXPORT void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride); +SD_EXPORT void ctxShapeFunctionOverride(OpaqueContext* ptr, + bool reallyOverride); SD_EXPORT void ctxSetExecutionMode(OpaqueContext* ptr, int execMode); SD_EXPORT void ctxPurge(OpaqueContext* ptr); SD_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace); -SD_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer); -SD_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); -SD_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); -SD_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); -SD_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo); -SD_EXPORT void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments); -SD_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments); -SD_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments); -SD_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments); +SD_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void* stream, + void* reductionPointer, + void* allocationPointer); +SD_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, + void* buffer, void* shapeInfo, + void* specialBuffer, + void* specialShapeInfo); +SD_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, + void* buffer, void* shapeInfo, + void* specialBuffer, + void* specialShapeInfo); +SD_EXPORT void setGraphContextInputBuffer(OpaqueContext* ptr, int index, + OpaqueDataBuffer* buffer, + void* shapeInfo, + void* specialShapeInfo); +SD_EXPORT void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, + OpaqueDataBuffer* buffer, + void* shapeInfo, + void* specialShapeInfo); +SD_EXPORT void setGraphContextDArguments(OpaqueContext* ptr, int* arguments, + int numberOfArguments); +SD_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double* arguments, + int numberOfArguments); +SD_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, + Nd4jLong* arguments, + int numberOfArguments); +SD_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool* arguments, + int numberOfArguments); SD_EXPORT void deleteGraphContext(OpaqueContext* ptr); -SD_EXPORT OpaqueRandomGenerator* createRandomGenerator(Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); +SD_EXPORT OpaqueRandomGenerator* createRandomGenerator(Nd4jLong rootSeed = 0, + Nd4jLong nodeSeed = 0); SD_EXPORT Nd4jLong getRandomGeneratorRootState(OpaqueRandomGenerator* ptr); SD_EXPORT Nd4jLong getRandomGeneratorNodeState(OpaqueRandomGenerator* ptr); -SD_EXPORT void setRandomGeneratorStates(OpaqueRandomGenerator* ptr, Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); -SD_EXPORT int getRandomGeneratorRelativeInt(OpaqueRandomGenerator* ptr, Nd4jLong index); -SD_EXPORT Nd4jLong getRandomGeneratorRelativeLong(OpaqueRandomGenerator* ptr, Nd4jLong index); +SD_EXPORT void setRandomGeneratorStates(OpaqueRandomGenerator* ptr, + Nd4jLong rootSeed = 0, + Nd4jLong nodeSeed = 0); +SD_EXPORT int getRandomGeneratorRelativeInt(OpaqueRandomGenerator* ptr, + Nd4jLong index); +SD_EXPORT Nd4jLong getRandomGeneratorRelativeLong(OpaqueRandomGenerator* ptr, + Nd4jLong index); SD_EXPORT void deleteRandomGenerator(OpaqueRandomGenerator* ptr); SD_EXPORT const char* runLightBenchmarkSuit(bool printOut); @@ -1623,37 +1653,44 @@ SD_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc); SD_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc); SD_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc); -SD_EXPORT OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth); -SD_EXPORT OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth); -SD_EXPORT OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, Nd4jPointer primary, Nd4jPointer special); -SD_EXPORT OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset); -SD_EXPORT Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer); -SD_EXPORT Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer); -SD_EXPORT void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements); -SD_EXPORT void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer); -SD_EXPORT void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer); -SD_EXPORT void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes); -SD_EXPORT void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes); -SD_EXPORT void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer); -SD_EXPORT void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer); -SD_EXPORT int dbLocality(OpaqueDataBuffer *dataBuffer); -SD_EXPORT int dbDeviceId(OpaqueDataBuffer *dataBuffer); -SD_EXPORT void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId); -SD_EXPORT void dbTickHostRead(OpaqueDataBuffer *dataBuffer); -SD_EXPORT void dbTickHostWrite(OpaqueDataBuffer *dataBuffer); -SD_EXPORT void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer); -SD_EXPORT void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer); -SD_EXPORT void dbClose(OpaqueDataBuffer *dataBuffer); -SD_EXPORT void deleteDataBuffer(OpaqueDataBuffer *dataBuffer); -SD_EXPORT void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements); - - -SD_EXPORT int binaryLevel(); +SD_EXPORT OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, + bool allocateBoth); +SD_EXPORT OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, + int dataType, + bool allocateBoth); +SD_EXPORT OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, + int dataType, + Nd4jPointer primary, + Nd4jPointer special); +SD_EXPORT OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer* dataBuffer, + Nd4jLong length, Nd4jLong offset); +SD_EXPORT Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer* dataBuffer); +SD_EXPORT Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbExpandBuffer(OpaqueDataBuffer* dataBuffer, Nd4jLong elements); +SD_EXPORT void dbAllocatePrimaryBuffer(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbAllocateSpecialBuffer(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbSetPrimaryBuffer(OpaqueDataBuffer* dataBuffer, + Nd4jPointer primaryBuffer, Nd4jLong numBytes); +SD_EXPORT void dbSetSpecialBuffer(OpaqueDataBuffer* dataBuffer, + Nd4jPointer specialBuffer, Nd4jLong numBytes); +SD_EXPORT void dbSyncToSpecial(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbSyncToPrimary(OpaqueDataBuffer* dataBuffer); +SD_EXPORT int dbLocality(OpaqueDataBuffer* dataBuffer); +SD_EXPORT int dbDeviceId(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbSetDeviceId(OpaqueDataBuffer* dataBuffer, int deviceId); +SD_EXPORT void dbTickHostRead(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbTickHostWrite(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbTickDeviceRead(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbTickDeviceWrite(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbClose(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void deleteDataBuffer(OpaqueDataBuffer* dataBuffer); +SD_EXPORT void dbExpand(OpaqueDataBuffer* dataBuffer, Nd4jLong elements); + +SD_EXPORT int binaryLevel(); SD_EXPORT int optimalLevel(); SD_EXPORT bool isMinimalRequirementsMet(); SD_EXPORT bool isOptimalRequirementsMet(); - } -#endif //NATIVEOPERATIONS_NATIVEOPS_H +#endif // NATIVEOPERATIONS_NATIVEOPS_H diff --git a/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp b/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp index ad75922e498e..a49547fd32a2 100644 --- a/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp +++ b/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp @@ -14,82 +14,69 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - -#include -#include #include "legacy/NativeOpExecutioner.h" -#include +#include +#include +#include #include - -#include +#include #include -#include - -#include #include -#include - -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include #include +#include +#include #include -#include +#include #include +#include +#include +#include +#include +#include +#include +#include #include +#include +#include +#include #include -#include -#include +#include #include -#include -#include -#include +#include +#include #ifdef _OPENMP -#include #include +#include #endif - - - //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param hXShapeInfo -* @param extraParams -* @param hZ -* @param hZShapeInfo -*/ -void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc, int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - auto hz = reinterpret_cast(hZ); - - BUILD_DOUBLE_SELECTOR(xType, zType, hz[0] = functions::indexreduce::IndexReduce, ::execScalar(opNum,hX,hXShapeInfo,extraParams), LIBND4J_TYPES, INDEXING_TYPES); + * + * @param opNum + * @param hX + * @param hXShapeInfo + * @param extraParams + * @param hZ + * @param hZShapeInfo + */ +void NativeOpExecutioner::execIndexReduceScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + auto hz = reinterpret_cast(hZ); + + BUILD_DOUBLE_SELECTOR(xType, zType, + hz[0] = functions::indexreduce::IndexReduce, + ::execScalar(opNum, hX, hXShapeInfo, extraParams), + LIBND4J_TYPES, INDEXING_TYPES); } //////////////////////////////////////////////////////////////////////// @@ -105,22 +92,21 @@ void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc, int opNu * @param dimensionLength */ -void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - auto hz = reinterpret_cast(hZ); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, extraParams, hz, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, INDEXING_TYPES); +void NativeOpExecutioner::execIndexReduce( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + auto hz = reinterpret_cast(hZ); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::indexreduce::IndexReduce, + ::exec(opNum, hX, hXShapeInfo, extraParams, hz, hZShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffsets), + LIBND4J_TYPES, INDEXING_TYPES); } //////////////////////////////////////////////////////////////////////// @@ -137,557 +123,597 @@ void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc, * @param dimensionLength */ -void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) { - - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); +void NativeOpExecutioner::execBroadcast( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::broadcast::Broadcast, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, LIBND4J_TYPES); #else - auto loopKind = sd::LoopKind::deduceKindOfLoopBroadcast(hXShapeInfo, hYShapeInfo, hZShapeInfo); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, loopKind, start, stop), LIBND4J_TYPES); - }; - - Nd4jLong numTads = 0; - - switch (loopKind) { - case sd::LoopKind::BROADCAST_SCALAR_X: { - numTads = shape::length(hXShapeInfo); - } - break; - case sd::LoopKind::BROADCAST_SCALAR_Y: { - numTads = shape::length(hYShapeInfo); - } - break; - case sd::LoopKind::BROADCAST_3D: { - numTads = shape::sizeAt(hZShapeInfo, 0); - } - break; - case sd::LoopKind::BROADCAST_4D: { - numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1); - } - break; - case sd::LoopKind::BROADCAST_5D: { - numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1); - } - break; - default: { - auto xLen = shape::length(hXShapeInfo); - auto yLen = shape::length(hYShapeInfo); - numTads = xLen / yLen; - } + auto loopKind = sd::LoopKind::deduceKindOfLoopBroadcast( + hXShapeInfo, hYShapeInfo, hZShapeInfo); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::broadcast::Broadcast, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ, loopKind, start, stop), + LIBND4J_TYPES); + }; + + Nd4jLong numTads = 0; + + switch (loopKind) { + case sd::LoopKind::BROADCAST_SCALAR_X: { + numTads = shape::length(hXShapeInfo); + } break; + case sd::LoopKind::BROADCAST_SCALAR_Y: { + numTads = shape::length(hYShapeInfo); + } break; + case sd::LoopKind::BROADCAST_3D: { + numTads = shape::sizeAt(hZShapeInfo, 0); + } break; + case sd::LoopKind::BROADCAST_4D: { + numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1); + } break; + case sd::LoopKind::BROADCAST_5D: { + numTads = shape::sizeAt(hZShapeInfo, 0) * shape::sizeAt(hZShapeInfo, 1); + } break; + default: { + auto xLen = shape::length(hXShapeInfo); + auto yLen = shape::length(hYShapeInfo); + numTads = xLen / yLen; } + } - samediff::Threads::parallel_tad(func, 0, numTads); + samediff::Threads::parallel_tad(func, 0, numTads); #endif } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcast(sd::LaunchContext* lc, const int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), LIBND4J_TYPES, LIBND4J_TYPES); - #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), LIBND4J_TYPES); - #endif -} - -void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!sd::Environment::getInstance()->isExperimentalBuild()) - if ((yType != xType && yType != sd::DataType::BOOL) || xType != zType) - throw sd::datatype_exception::build("NativeOps::execBroadcast both operands must have same data type", xType, yType); +void NativeOpExecutioner::execBroadcast( + sd::LaunchContext *lc, const int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo) { + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::broadcast::Broadcast, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), + LIBND4J_TYPES, LIBND4J_TYPES); #else - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES); - }; + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::broadcast::Broadcast, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), + LIBND4J_TYPES); +#endif +} - auto xLen = shape::length(hXShapeInfo); - auto yLen = shape::length(hYShapeInfo); - auto numTads = yLen / xLen; +void NativeOpExecutioner::execInverseBroadcast( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!sd::Environment::getInstance()->isExperimentalBuild()) + if ((yType != xType && yType != sd::DataType::BOOL) || xType != zType) + throw sd::datatype_exception::build( + "NativeOps::execBroadcast both operands must have same data type", + xType, yType); - samediff::Threads::parallel_tad(func, 0, numTads); +#ifdef __ND4J_EXPERIMENTAL__ + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::broadcast::Broadcast, + ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, LIBND4J_TYPES); +#else + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::broadcast::Broadcast, + ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), + LIBND4J_TYPES); + }; + + auto xLen = shape::length(hXShapeInfo); + auto yLen = shape::length(hYShapeInfo); + auto numTads = yLen / xLen; + + samediff::Threads::parallel_tad(func, 0, numTads); #endif - } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) { - - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); - }; - - auto xLen = shape::length(hXShapeInfo); - auto yLen = shape::length(hYShapeInfo); - auto numTads = xLen / yLen; - - samediff::Threads::parallel_tad(func, 0, numTads); +void NativeOpExecutioner::execBroadcastBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ) { + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::broadcast::BroadcastBool, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + extraParams, dimension, dimensionLength, tadOnlyShapeInfo, + tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), + LIBND4J_TYPES, BOOL_TYPES); + }; + + auto xLen = shape::length(hXShapeInfo); + auto yLen = shape::length(hYShapeInfo); + auto numTads = xLen / yLen; + + samediff::Threads::parallel_tad(func, 0, numTads); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext* lc, const int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams) { - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES); +void NativeOpExecutioner::execBroadcastBool( + sd::LaunchContext *lc, const int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams) { + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, + hZShapeInfo, extraParams), + LIBND4J_TYPES, BOOL_TYPES); } - -void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!sd::Environment::getInstance()->isExperimentalBuild()) - if (yType != xType || sd::DataType::BOOL != zType) - throw sd::datatype_exception::build("NativeOps::execInverseBroadcastBool both operands must have same data type", xType, yType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); - }; - - auto xLen = shape::length(hXShapeInfo); - auto yLen = shape::length(hYShapeInfo); - auto numTads = yLen / xLen; - - samediff::Threads::parallel_tad(func, 0, numTads); +void NativeOpExecutioner::execInverseBroadcastBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!sd::Environment::getInstance()->isExperimentalBuild()) + if (yType != xType || sd::DataType::BOOL != zType) + throw sd::datatype_exception::build( + "NativeOps::execInverseBroadcastBool both operands must have same " + "data type", + xType, yType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::broadcast::BroadcastBool, + ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + extraParams, dimension, dimensionLength, tadOnlyShapeInfo, + tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), + LIBND4J_TYPES, BOOL_TYPES); + }; + + auto xLen = shape::length(hXShapeInfo); + auto yLen = shape::length(hYShapeInfo); + auto numTads = yLen / xLen; + + samediff::Threads::parallel_tad(func, 0, numTads); } - - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType); - - if (!sd::DataTypeUtils::isZ(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), INTEGER_TYPES); - }; - - auto xLen = shape::length(hXShapeInfo); - auto yLen = shape::length(hYShapeInfo); - auto numTads = xLen / yLen; - - samediff::Threads::parallel_tad(func, 0, numTads); +void NativeOpExecutioner::execBroadcastInt( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (xType != yType || xType != zType) + throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", + zType, xType, yType); + + if (!sd::DataTypeUtils::isZ(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execBroadcastInt requires integer data type", + zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR( + xType, functions::broadcast::BroadcastInt, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), + INTEGER_TYPES); + }; + + auto xLen = shape::length(hXShapeInfo); + auto yLen = shape::length(hYShapeInfo); + auto numTads = xLen / yLen; + + samediff::Threads::parallel_tad(func, 0, numTads); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc, const int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType); - - if (!sd::DataTypeUtils::isZ(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", zType); - - BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), INTEGER_TYPES); - +void NativeOpExecutioner::execBroadcastInt( + sd::LaunchContext *lc, const int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (xType != yType || xType != zType) + throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", + zType, xType, yType); + + if (!sd::DataTypeUtils::isZ(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execBroadcastInt requires integer data type", + zType); + + BUILD_SINGLE_SELECTOR( + xType, functions::broadcast::BroadcastInt, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), + INTEGER_TYPES); } -void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ,const Nd4jLong *tadOffsetsZ) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType); - - if (!sd::DataTypeUtils::isZ(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt requires integer data type", zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt,::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), INTEGER_TYPES); - }; - - auto xLen = shape::length(hXShapeInfo); - auto yLen = shape::length(hYShapeInfo); - auto numTads = yLen / xLen; - - samediff::Threads::parallel_tad(func, 0, numTads); +void NativeOpExecutioner::execInverseBroadcastInt( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (xType != yType || xType != zType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType); + + if (!sd::DataTypeUtils::isZ(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execInverseBroadcastInt requires integer data " + "type", + zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR( + xType, functions::broadcast::BroadcastInt, + ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), + INTEGER_TYPES); + }; + + auto xLen = shape::length(hXShapeInfo); + auto yLen = shape::length(hYShapeInfo); + auto numTads = yLen / xLen; + + samediff::Threads::parallel_tad(func, 0, numTads); } //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param xStride -* @param hY -* @param yStride -* @param hZ -* @param resultStride -* @param extraParams -* @param n -*/ -void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; + * + * @param opNum + * @param hX + * @param xStride + * @param hY + * @param yStride + * @param hZ + * @param resultStride + * @param extraParams + * @param n + */ +void NativeOpExecutioner::execPairwiseTransform( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR(xType, yType, zType, + functions::pairwise_transforms::PairWiseTransform, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, + hZShapeInfo, extraParams), + LIBND4J_TYPES, LIBND4J_TYPES); #else - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::pairwise_transforms::PairWiseTransform, - ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, start, stop), - LIBND4J_TYPES); - }; - - auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max(1, sd::math::nd4j_min(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads()))); + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::pairwise_transforms::PairWiseTransform, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, + extraParams, start, stop), + LIBND4J_TYPES); + }; + + auto zLen = shape::length(hZShapeInfo); + samediff::Threads::parallel_for( + func, 0, zLen, 1, + sd::math::nd4j_max( + 1, sd::math::nd4j_min( + zLen / 1024, + sd::Environment::getInstance()->maxMasterThreads()))); #endif } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execPairwiseBoolTransform(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (xType != yType) - throw sd::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", xType, yType); - - if (zType != sd::DataType::BOOL) - throw sd::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", sd::DataType::BOOL, zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, start, stop), LIBND4J_TYPES, BOOL_TYPES); - }; - - auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max(1, sd::math::nd4j_min(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads()))); - +void NativeOpExecutioner::execPairwiseBoolTransform( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (xType != yType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execPairwiseBoolTransform", xType, yType); + + if (zType != sd::DataType::BOOL) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execPairwiseBoolTransform", sd::DataType::BOOL, + zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR(xType, zType, + functions::pairwise_transforms::PairWiseBoolTransform, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, + hZShapeInfo, extraParams, start, stop), + LIBND4J_TYPES, BOOL_TYPES); + }; + + auto zLen = shape::length(hZShapeInfo); + samediff::Threads::parallel_for( + func, 0, zLen, 1, + sd::math::nd4j_max( + 1, sd::math::nd4j_min( + zLen / 1024, + sd::Environment::getInstance()->maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execPairwiseIntTransform(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType); - - if (!sd::DataTypeUtils::isZ(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execSPairwiseInt requires integer data type", zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, start, stop), INTEGER_TYPES); - }; - - auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, sd::math::nd4j_max(1, sd::math::nd4j_min(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads()))); - +void NativeOpExecutioner::execPairwiseIntTransform( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraParams) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (xType != yType || xType != zType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType); + + if (!sd::DataTypeUtils::isZ(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execSPairwiseInt requires integer data type", + zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR(xType, + functions::pairwise_transforms::PairWiseIntTransform, + ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, + hZShapeInfo, extraParams, start, stop), + INTEGER_TYPES); + }; + + auto zLen = shape::length(hZShapeInfo); + samediff::Threads::parallel_for( + func, 0, zLen, 1, + sd::math::nd4j_max( + 1, sd::math::nd4j_min( + zLen / 1024, + sd::Environment::getInstance()->maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param hXShapeInfo -* @param extraParams -* @param hZ -* @param hZShapeInfo -*/ -void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - // nothing to do here if result is empty - if (shape::isEmpty(hZShapeInfo)) - return; - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES); - }; - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); - - samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads()); + * + * @param opNum + * @param hX + * @param hXShapeInfo + * @param extraParams + * @param hZ + * @param hZShapeInfo + */ +void NativeOpExecutioner::execReduceFloat( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + // nothing to do here if result is empty + if (shape::isEmpty(hZShapeInfo)) return; + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceFloatFunction, + ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffsets, start, stop), + LIBND4J_TYPES, FLOAT_TYPES); + }; + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ( + hXShapeInfo, hZShapeInfo, tadShapeInfo); + + samediff::Threads::parallel_tad( + func, 0, shape::length(hZShapeInfo), 1, + kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX + ? 1 + : sd::Environment::getInstance()->maxMasterThreads()); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - // nothing to do here if result is empty - if (shape::isEmpty(hZShapeInfo)) - return; - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES); - }; - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); - - samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads()); +void NativeOpExecutioner::execReduceSame( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + // nothing to do here if result is empty + if (shape::isEmpty(hZShapeInfo)) return; + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR( + xType, functions::reduce::ReduceSameFunction, + ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffsets, start, stop), + LIBND4J_TYPES); + }; + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ( + hXShapeInfo, hZShapeInfo, tadShapeInfo); + + samediff::Threads::parallel_tad( + func, 0, shape::length(hZShapeInfo), 1, + kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX + ? 1 + : sd::Environment::getInstance()->maxMasterThreads()); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - // nothing to do here if result is empty - if (shape::isEmpty(hZShapeInfo)) - return; - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, BOOL_TYPES); - }; - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); - - samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads()); +void NativeOpExecutioner::execReduceBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + // nothing to do here if result is empty + if (shape::isEmpty(hZShapeInfo)) return; + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceBoolFunction, + ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffsets, start, stop), + LIBND4J_TYPES, BOOL_TYPES); + }; + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ( + hXShapeInfo, hZShapeInfo, tadShapeInfo); + + samediff::Threads::parallel_tad( + func, 0, shape::length(hZShapeInfo), 1, + kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX + ? 1 + : sd::Environment::getInstance()->maxMasterThreads()); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - // nothing to do here if result is empty - if (shape::isEmpty(hZShapeInfo)) - return; - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, LONG_TYPES); - }; - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo); - - samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance()->maxMasterThreads()); +void NativeOpExecutioner::execReduceLong( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + // nothing to do here if result is empty + if (shape::isEmpty(hZShapeInfo)) return; + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceLongFunction, + ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffsets, start, stop), + LIBND4J_TYPES, LONG_TYPES); + }; + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ( + hXShapeInfo, hZShapeInfo, tadShapeInfo); + + samediff::Threads::parallel_tad( + func, 0, shape::length(hZShapeInfo), 1, + kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX + ? 1 + : sd::Environment::getInstance()->maxMasterThreads()); } //////////////////////////////////////////////////////////////////////// @@ -699,67 +725,64 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc, * @param extraParams * @return */ -void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), LIBND4J_TYPES, FLOAT_TYPES); +void NativeOpExecutioner::execReduceFloatScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceFloatFunction, + ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), + LIBND4J_TYPES, FLOAT_TYPES); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), LIBND4J_TYPES); +void NativeOpExecutioner::execReduceSameScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + + BUILD_SINGLE_SELECTOR( + xType, functions::reduce::ReduceSameFunction, + ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), + LIBND4J_TYPES); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), LIBND4J_TYPES, BOOL_TYPES); +void NativeOpExecutioner::execReduceBoolScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceBoolFunction, + ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), + LIBND4J_TYPES, BOOL_TYPES); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), LIBND4J_TYPES, LONG_TYPES); +void NativeOpExecutioner::execReduceLongScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceLongFunction, + ::execScalar(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), + LIBND4J_TYPES, LONG_TYPES); } - //////////////////////////////////////////////////////////////////////// /** * @@ -774,648 +797,705 @@ void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc, * @param dimension * @param dimensionLength */ -void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execScalar(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo), LIBND4J_TYPES, FLOAT_TYPES); +void NativeOpExecutioner::execReduce3Scalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParamsVals, const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, + ::execScalar(opNum, hX, hXShapeInfo, extraParamsVals, + hY, hYShapeInfo, hZ, hZShapeInfo), + LIBND4J_TYPES, FLOAT_TYPES); } //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param hXShapeInfo -* @param extraParamsVals -* @param hY -* @param hYShapeInfo -* @param hZ -* @param hZShapeInfo -*/ -void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - //BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, nullptr, 0), LIBND4J_TYPES, FLOAT_TYPES); - NativeOpExecutioner::execReduce3Scalar(lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + * + * @param opNum + * @param hX + * @param hXShapeInfo + * @param extraParamsVals + * @param hY + * @param hYShapeInfo + * @param hZ + * @param hZShapeInfo + */ +void NativeOpExecutioner::execReduce3( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParamsVals, const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + // BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, + // ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, + // hZShapeInfo, nullptr, 0), LIBND4J_TYPES, FLOAT_TYPES); + NativeOpExecutioner::execReduce3Scalar( + lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, + hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadOnlyShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - const auto xLen = shape::length(hXShapeInfo); - const auto yLen = shape::length(hYShapeInfo); - - sd::TadPack tadPack; - - if(xLen == yLen) { - tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); - } - else if(yLen > xLen) { - tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hYShapeInfo, dimension, dimensionLength); - } - else { - tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); - } - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, start, stop), LIBND4J_TYPES, FLOAT_TYPES); - }; - - samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads()); +void NativeOpExecutioner::execReduce3( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParamsVals, const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadOnlyShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *yTadOnlyShapeInfo, + const Nd4jLong *yTadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + const auto xLen = shape::length(hXShapeInfo); + const auto yLen = shape::length(hYShapeInfo); + + sd::TadPack tadPack; + + if (xLen == yLen) { + tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + } else if (yLen > xLen) { + tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hYShapeInfo, dimension, dimensionLength); + } else { + tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + } + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, + hZShapeInfo, dimension, dimensionLength, start, stop), + LIBND4J_TYPES, FLOAT_TYPES); + }; + + samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads()); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); - - // TODO: make it 2d - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES); - }; - - samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads()); +void NativeOpExecutioner::execReduce3All( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParamsVals, const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xOffsets, const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + + // TODO: make it 2d + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::execAll(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, + hZShapeInfo, dimension, dimensionLength, xTadShapeInfo, + xOffsets, yTadShapeInfo, yOffsets, start, stop), + LIBND4J_TYPES, FLOAT_TYPES); + }; + + samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads()); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffsets) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - const auto xLen = shape::length(hXShapeInfo); - const auto yLen = shape::length(hYShapeInfo); - - sd::TadPack tadPack; - - if(xLen == yLen) { - tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); - } - else if(yLen > xLen) { - tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hYShapeInfo, dimension, dimensionLength); - } - else { - tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); - } - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES); - }; - - samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads()); +void NativeOpExecutioner::execReduce3TAD( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParamsVals, const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, const Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yTadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + const auto xLen = shape::length(hXShapeInfo); + const auto yLen = shape::length(hYShapeInfo); + + sd::TadPack tadPack; + + if (xLen == yLen) { + tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + } else if (yLen > xLen) { + tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hYShapeInfo, dimension, dimensionLength); + } else { + tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + } + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, + hZShapeInfo, dimension, dimensionLength, tadShapeInfo, + tadOffsets, start, stop), + LIBND4J_TYPES, FLOAT_TYPES); + }; + + samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads()); } - //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param xStride -* @param hZ -* @param resultStride -* @param scalar -* @param extraParams -* @param n -*/ -void NativeOpExecutioner::execScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalar, const Nd4jLong *hScalarShapeInfo, - const void *dScalar, const Nd4jLong *dScalarShapeInfo, - void *extraParams, - bool allowParallelism) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; + * + * @param opNum + * @param hX + * @param xStride + * @param hZ + * @param resultStride + * @param scalar + * @param extraParams + * @param n + */ +void NativeOpExecutioner::execScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalar, + const Nd4jLong *hScalarShapeInfo, const void *dScalar, + const Nd4jLong *dScalarShapeInfo, void *extraParams, + bool allowParallelism) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR(xType, yType, zType, + functions::scalar::ScalarTransform, + ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + hScalar, extraParams), + LIBND4J_TYPES, LIBND4J_TYPES); #else - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalar", zType, xType, yType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform,::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams, start, stop), LIBND4J_TYPES); - }; - - auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max(1, sd::math::nd4j_min(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads()))); + if (xType != yType || xType != zType) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalar", + zType, xType, yType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::scalar::ScalarTransform, + ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, + extraParams, start, stop), + LIBND4J_TYPES); + }; + + auto zLen = shape::length(hZShapeInfo); + samediff::Threads::parallel_for( + func, 0, zLen, 1, + !allowParallelism + ? 1 + : sd::math::nd4j_max( + 1, sd::math::nd4j_min( + zLen / 1024, + sd::Environment::getInstance()->maxMasterThreads()))); #endif } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const*hXShapeInfo, - void const* dX, Nd4jLong const*dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const*hZShapeInfo, - void *dZ, Nd4jLong const*dZShapeInfo, - void const* hScalars, Nd4jLong const*hScalarShapeInfo, - void const* dScalars, Nd4jLong const*dScalarShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const*tadShapeInfo, Nd4jLong const*tadOffsets, - Nd4jLong const*tadShapeInfoZ, Nd4jLong const*tadOffsetsZ) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; +void NativeOpExecutioner::execScalar( + sd::LaunchContext *lc, int opNum, void const *hX, + Nd4jLong const *hXShapeInfo, void const *dX, Nd4jLong const *dXShapeInfo, + void *extraParams, void *hZ, Nd4jLong const *hZShapeInfo, void *dZ, + Nd4jLong const *dZShapeInfo, void const *hScalars, + Nd4jLong const *hScalarShapeInfo, void const *dScalars, + Nd4jLong const *dScalarShapeInfo, int *dimension, int dimensionLength, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *tadShapeInfoZ, Nd4jLong const *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::scalar::ScalarTransform, + ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, + hScalars, dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, LIBND4J_TYPES); #else - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalar", zType, xType, yType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES); - }; - - auto yLen = shape::length(hScalarShapeInfo); - samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min(yLen, sd::Environment::getInstance()->maxMasterThreads())); + if (xType != yType || xType != zType) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalar", + zType, xType, yType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::scalar::ScalarTransform, + ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, + hScalars, dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), + LIBND4J_TYPES); + }; + + auto yLen = shape::length(hScalarShapeInfo); + samediff::Threads::parallel_tad( + func, 0, yLen, 1, + sd::math::nd4j_min( + yLen, sd::Environment::getInstance()->maxMasterThreads())); #endif } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalar, const Nd4jLong *hSscalarShapeInfo, - const void *dScalar, const Nd4jLong *dSscalarShapeInfo, - void *extraParams, - bool allowParallelism) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hSscalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo)) - return; - - if (xType != yType) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType); - - if (zType != sd::DataType::BOOL) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", sd::DataType::BOOL, zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams, start, stop), LIBND4J_TYPES, BOOL_TYPES); - }; - - auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max(1, sd::math::nd4j_min(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads()))); - +void NativeOpExecutioner::execScalarBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalar, + const Nd4jLong *hSscalarShapeInfo, const void *dScalar, + const Nd4jLong *dSscalarShapeInfo, void *extraParams, + bool allowParallelism) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hSscalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo)) return; + + if (xType != yType) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", + xType, yType); + + if (zType != sd::DataType::BOOL) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", + sd::DataType::BOOL, zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, + ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + hScalar, extraParams, start, stop), + LIBND4J_TYPES, BOOL_TYPES); + }; + + auto zLen = shape::length(hZShapeInfo); + samediff::Threads::parallel_for( + func, 0, zLen, 1, + !allowParallelism + ? 1 + : sd::math::nd4j_max( + 1, sd::math::nd4j_min( + zLen / 1024, + sd::Environment::getInstance()->maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalars, const Nd4jLong *hScalarShapeInfo, - const void *dScalars, const Nd4jLong *dScalarShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; - - if (xType != yType) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType); - - if (zType != sd::DataType::BOOL) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", sd::DataType::BOOL, zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); - }; - - auto yLen = shape::length(hScalarShapeInfo); - samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min(yLen, sd::Environment::getInstance()->maxMasterThreads())); +void NativeOpExecutioner::execScalarBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalars, + const Nd4jLong *hScalarShapeInfo, const void *dScalars, + const Nd4jLong *dScalarShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; + + if (xType != yType) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", + xType, yType); + + if (zType != sd::DataType::BOOL) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalarBool", + sd::DataType::BOOL, zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::scalar::ScalarBoolTransform, + ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, + hScalars, dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), + LIBND4J_TYPES, BOOL_TYPES); + }; + + auto yLen = shape::length(hScalarShapeInfo); + samediff::Threads::parallel_tad( + func, 0, yLen, 1, + sd::math::nd4j_min( + yLen, sd::Environment::getInstance()->maxMasterThreads())); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalar, const Nd4jLong *hSscalarShapeInfo, - const void *dScalar, const Nd4jLong *dSscalarShapeInfo, - void *extraParams, - bool allowParallelism) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hSscalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo)) - return; - - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType); - - if (!sd::DataTypeUtils::isZ(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarInt", sd::DataType::INT32, zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams, start, stop), INTEGER_TYPES); - }; - - auto zLen = shape::length(hZShapeInfo); - samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : sd::math::nd4j_max(1, sd::math::nd4j_min(zLen / 1024, sd::Environment::getInstance()->maxMasterThreads()))); - +void NativeOpExecutioner::execScalarInt( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalar, + const Nd4jLong *hSscalarShapeInfo, const void *dScalar, + const Nd4jLong *dSscalarShapeInfo, void *extraParams, + bool allowParallelism) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hSscalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo)) return; + + if (xType != yType || xType != zType) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalarInt", + xType, yType); + + if (!sd::DataTypeUtils::isZ(zType)) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalarInt", + sd::DataType::INT32, zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, + ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + hScalar, extraParams, start, stop), + INTEGER_TYPES); + }; + + auto zLen = shape::length(hZShapeInfo); + samediff::Threads::parallel_for( + func, 0, zLen, 1, + !allowParallelism + ? 1 + : sd::math::nd4j_max( + 1, sd::math::nd4j_min( + zLen / 1024, + sd::Environment::getInstance()->maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - const void *hScalars, const Nd4jLong *hScalarShapeInfo, - const void *dScalars, const Nd4jLong *dScalarShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) { - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; - - if (xType != yType || xType != zType) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType); - - if (!sd::DataTypeUtils::isZ(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execScalarInt requires integer data type", zType); - - auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), INTEGER_TYPES); - }; - - auto yLen = shape::length(hScalarShapeInfo); - samediff::Threads::parallel_tad(func, 0, yLen, 1, sd::math::nd4j_min(yLen, sd::Environment::getInstance()->maxMasterThreads())); +void NativeOpExecutioner::execScalarInt( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, const void *hScalars, + const Nd4jLong *hScalarShapeInfo, const void *dScalars, + const Nd4jLong *dScalarShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; + + if (xType != yType || xType != zType) + throw sd::datatype_exception::build("NativeOpExecutioner::execScalarInt", + xType, yType); + + if (!sd::DataTypeUtils::isZ(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execScalarInt requires integer data type", zType); + + auto func = PRAGMA_THREADS_FOR { + BUILD_SINGLE_SELECTOR( + xType, functions::scalar::ScalarIntTransform, + ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, + hScalars, dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), + INTEGER_TYPES); + }; + + auto yLen = shape::length(hScalarShapeInfo); + samediff::Threads::parallel_tad( + func, 0, yLen, 1, + sd::math::nd4j_min( + yLen, sd::Environment::getInstance()->maxMasterThreads())); } //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param hXShapeInfo -* @param extraParams -* @param hZ -* @param hZShapeInfo -*/ -void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - bool biasCorrected) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::exec(opNum, biasCorrected, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, nullptr, 1), LIBND4J_TYPES, FLOAT_TYPES); + * + * @param opNum + * @param hX + * @param hXShapeInfo + * @param extraParams + * @param hZ + * @param hZShapeInfo + */ +void NativeOpExecutioner::execSummaryStats( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, bool biasCorrected) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR(xType, zType, + functions::summarystats::SummaryStatsReduce, + ::exec(opNum, biasCorrected, hX, hXShapeInfo, + extraParams, hZ, hZShapeInfo, nullptr, 1), + LIBND4J_TYPES, FLOAT_TYPES); } //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param hXShapeInfo -* @param extraParams -* @param hZ -* @param hZShapeInfo -*/ -void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - bool biasCorrected) { - - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::execScalar(opNum, biasCorrected, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo), LIBND4J_TYPES, FLOAT_TYPES); + * + * @param opNum + * @param hX + * @param hXShapeInfo + * @param extraParams + * @param hZ + * @param hZShapeInfo + */ +void NativeOpExecutioner::execSummaryStatsScalar( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, bool biasCorrected) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR(xType, zType, + functions::summarystats::SummaryStatsReduce, + ::execScalar(opNum, biasCorrected, hX, hXShapeInfo, + extraParams, hZ, hZShapeInfo), + LIBND4J_TYPES, FLOAT_TYPES); } //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param hXShapeInfo -* @param extraParams -* @param hZ -* @param hZShapeInfo -* @param dimension -* @param dimensionLength -*/ -void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *extraParams, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - bool biasCorrected) { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::exec(opNum, biasCorrected, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength), LIBND4J_TYPES, FLOAT_TYPES); + * + * @param opNum + * @param hX + * @param hXShapeInfo + * @param extraParams + * @param hZ + * @param hZShapeInfo + * @param dimension + * @param dimensionLength + */ +void NativeOpExecutioner::execSummaryStats( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *extraParams, void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, + bool biasCorrected) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::summarystats::SummaryStatsReduce, + ::exec(opNum, biasCorrected, hX, hXShapeInfo, extraParams, hZ, + hZShapeInfo, dimension, dimensionLength), + LIBND4J_TYPES, FLOAT_TYPES); } - //////////////////////////////////////////////////////////////////////// /** -* -* @param opNum -* @param hX -* @param xStride -* @param hZ -* @param resultStride -* @param extraParams -* @param n -*/ -void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) - return; - - auto func = PRAGMA_THREADS_DO { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES); - }; - - samediff::Threads::parallel_do(func, sd::math::nd4j_max(1, sd::math::nd4j_min(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads()))); -} - -//////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) - return; - - auto func = PRAGMA_THREADS_DO { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES); - }; - - samediff::Threads::parallel_do(func, sd::math::nd4j_max(1, sd::math::nd4j_min(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads()))); + * + * @param opNum + * @param hX + * @param xStride + * @param hZ + * @param resultStride + * @param extraParams + * @param n + */ +void NativeOpExecutioner::execTransformFloat( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) return; + + auto func = PRAGMA_THREADS_DO { + BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, + ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + extraParams, thread_id, numThreads), + LIBND4J_TYPES, FLOAT_TYPES); + }; + + samediff::Threads::parallel_do( + func, sd::math::nd4j_max( + 1, sd::math::nd4j_min( + shape::length(hZShapeInfo) / 1024, + sd::Environment::getInstance()->maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - bool allowParallelism) { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) - return; - - if (opNum == sd::transform::Assign && shape::order(hXShapeInfo) == shape::order(hZShapeInfo) && shape::order(hXShapeInfo) == 'c' && xType == zType && shape::elementWiseStride(hXShapeInfo) == 1 && shape::elementWiseStride(hZShapeInfo) == 1) { - - memcpy(hZ, hX, shape::length(hXShapeInfo) * sd::DataTypeUtils::sizeOfElement(xType)); - } - else { - auto func = PRAGMA_THREADS_DO { - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES); - }; - - samediff::Threads::parallel_do(func, sd::math::nd4j_max(1, sd::math::nd4j_min(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads()))); - } +void NativeOpExecutioner::execTransformBool( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) return; + + auto func = PRAGMA_THREADS_DO { + BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, + ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + extraParams, thread_id, numThreads), + LIBND4J_TYPES, BOOL_TYPES); + }; + + samediff::Threads::parallel_do( + func, sd::math::nd4j_max( + 1, sd::math::nd4j_min( + shape::length(hZShapeInfo) / 1024, + sd::Environment::getInstance()->maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) - return; - +void NativeOpExecutioner::execTransformAny( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, + bool allowParallelism) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) return; + + if (opNum == sd::transform::Assign && + shape::order(hXShapeInfo) == shape::order(hZShapeInfo) && + shape::order(hXShapeInfo) == 'c' && xType == zType && + shape::elementWiseStride(hXShapeInfo) == 1 && + shape::elementWiseStride(hZShapeInfo) == 1) { + memcpy( + hZ, hX, + shape::length(hXShapeInfo) * sd::DataTypeUtils::sizeOfElement(xType)); + } else { auto func = PRAGMA_THREADS_DO { - BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, + ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + extraParams, thread_id, numThreads), + LIBND4J_TYPES, LIBND4J_TYPES); }; - samediff::Threads::parallel_do(func, sd::math::nd4j_max(1, sd::math::nd4j_min(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads()))); + samediff::Threads::parallel_do( + func, sd::math::nd4j_max( + 1, sd::math::nd4j_min( + shape::length(hZShapeInfo) / 1024, + sd::Environment::getInstance()->maxMasterThreads()))); + } } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc, - int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) - return; - - auto func = PRAGMA_THREADS_DO { - BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES); - }; - - samediff::Threads::parallel_do(func, sd::math::nd4j_max(1, sd::math::nd4j_min(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads()))); +void NativeOpExecutioner::execTransformSame( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) return; + + auto func = PRAGMA_THREADS_DO { + BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, + ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + extraParams, thread_id, numThreads), + LIBND4J_TYPES); + }; + + samediff::Threads::parallel_do( + func, sd::math::nd4j_max( + 1, sd::math::nd4j_min( + shape::length(hZShapeInfo) / 1024, + sd::Environment::getInstance()->maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer state, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraArguments) { - - - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::execTransform(opNum, state, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES); - - auto rng = reinterpret_cast(state); - rng->rewindH(shape::length(hZShapeInfo)); +void NativeOpExecutioner::execTransformStrict( + sd::LaunchContext *lc, int opNum, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraParams, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) return; + + auto func = PRAGMA_THREADS_DO { + BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, + ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, + extraParams, thread_id, numThreads), + FLOAT_TYPES); + }; + + samediff::Threads::parallel_do( + func, sd::math::nd4j_max( + 1, sd::math::nd4j_min( + shape::length(hZShapeInfo) / 1024, + sd::Environment::getInstance()->maxMasterThreads()))); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer state, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, +void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, int opNum, + Nd4jPointer state, void *hZ, + const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraArguments) { + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::execTransform(opNum, state, hX, hXShapeInfo, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR( + zType, functions::random::RandomFunction, + ::execTransform(opNum, state, hZ, hZShapeInfo, extraArguments), + FLOAT_TYPES); - auto rng = reinterpret_cast(state); - rng->rewindH(shape::length(hZShapeInfo)); + auto rng = reinterpret_cast(state); + rng->rewindH(shape::length(hZShapeInfo)); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer state, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraArguments) { - - auto xType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::execTransform(opNum, state, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES); - - auto rng = reinterpret_cast(state); - rng->rewindH(shape::length(hZShapeInfo)); +void NativeOpExecutioner::execRandom( + sd::LaunchContext *lc, int opNum, Nd4jPointer state, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, void *dZ, + const Nd4jLong *dZShapeInfo, void *extraArguments) { + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, + ::execTransform(opNum, state, hX, hXShapeInfo, hZ, + hZShapeInfo, extraArguments), + FLOAT_TYPES); + + auto rng = reinterpret_cast(state); + rng->rewindH(shape::length(hZShapeInfo)); } - - - - - - +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execRandom( + sd::LaunchContext *lc, int opNum, Nd4jPointer state, const void *hX, + const Nd4jLong *hXShapeInfo, const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, const void *dY, + const Nd4jLong *dYShapeInfo, void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, void *extraArguments) { + auto xType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_SINGLE_SELECTOR( + xType, functions::random::RandomFunction, + ::execTransform(opNum, state, hX, hXShapeInfo, hY, hYShapeInfo, hZ, + hZShapeInfo, extraArguments), + FLOAT_TYPES); + + auto rng = reinterpret_cast(state); + rng->rewindH(shape::length(hZShapeInfo)); +} diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp index d5744e883175..e47f273ca38e 100644 --- a/libnd4j/include/legacy/cpu/NativeOps.cpp +++ b/libnd4j/include/legacy/cpu/NativeOps.cpp @@ -20,61 +20,57 @@ #define __STDC_CONSTANT_MACROS -#include -#include "legacy/NativeOpExecutioner.h" #include +#include +#include #include -#include -#include -#include +#include #include #include -#include -#include -#include +#include +#include +#include #include -#include -#include - - -#include #include #include +#include +#include +#include +#include + +#include "legacy/NativeOpExecutioner.h" #ifndef _WIN32 -#include #include +#include #else -#include #include +#include #endif -#include - -#include #include - +#include +#include char *name; bool nameSet = false; - #ifdef __ND4J_EXPERIMENTAL__ bool experimentalSupport = true; #else bool experimentalSupport = false; #endif -#include -#include -#include -#include +#include #include #include -#include #include +#include +#include +#include +#include #include #include #include -#include +#include #ifdef CPU_FEATURES #include @@ -83,13 +79,11 @@ bool experimentalSupport = false; using namespace sd; void setElementThreshold(int num) { - if (num > 0) - sd::Environment::getInstance()->setElementwiseThreshold(num); + if (num > 0) sd::Environment::getInstance()->setElementwiseThreshold(num); } void setTADThreshold(int num) { - if (num > 0) - sd::Environment::getInstance()->setTadThreshold(num); + if (num > 0) sd::Environment::getInstance()->setTadThreshold(num); } /** @@ -99,17 +93,21 @@ void setTADThreshold(int num) { * @param hXShapeInfo * @param extraParams */ -void execIndexReduceScalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) { - try { - NativeOpExecutioner::execIndexReduceScalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execIndexReduceScalar(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo) { + try { + NativeOpExecutioner::execIndexReduceScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -123,44 +121,36 @@ void execIndexReduceScalar(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -void execIndexReduce(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, - dimensionLength); - - auto hTADShapeInfo = tadPack.primaryShapeInfo(); - auto hTADOffsets = tadPack.primaryOffsets(); - - auto hz = reinterpret_cast(dbZ->primary()); - - NativeOpExecutioner::execIndexReduce(nullptr, opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - hz, - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dimension, - dimensionLength, - hTADShapeInfo, - hTADOffsets); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execIndexReduce(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPack.primaryShapeInfo(); + auto hTADOffsets = tadPack.primaryOffsets(); + + auto hz = reinterpret_cast(dbZ->primary()); + + NativeOpExecutioner::execIndexReduce( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, hz, hZShapeInfo, dbZ->special(), dZShapeInfo, + dimension, dimensionLength, hTADShapeInfo, hTADOffsets); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } - /** * * @param opNum @@ -173,83 +163,75 @@ void execIndexReduce(Nd4jPointer *extraPointers,int opNum, * @param dimension * @param dimensionLength */ -void execBroadcast(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - auto dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); - auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension, dimensionLength); - - auto hTADShapeInfo = tadPackX.primaryShapeInfo(); - auto hTADOffsets = tadPackX.primaryOffsets(); - auto hTADShapeInfoZ = tadPackZ.primaryShapeInfo(); - auto hTADOffsetsZ = tadPackZ.primaryOffsets(); - - NativeOpExecutioner::execBroadcast(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbY->primary(), - hYShapeInfo, - dbY->special(), - dYShapeInfo, - dbZ->primary(), hZShapeInfo, - dbZ->special(), dZShapeInfo, - dimension, - dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, hTADOffsetsZ); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execBroadcastBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - auto dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); - auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension, dimensionLength); - - auto hTADShapeInfo = tadPackX.primaryShapeInfo(); - auto hTADOffsets = tadPackX.primaryOffsets(); - auto hTADShapeInfoZ = tadPackZ.primaryShapeInfo(); - auto hTADOffsetsZ = tadPackZ.primaryOffsets(); - - NativeOpExecutioner::execBroadcastBool(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbY->primary(), - hYShapeInfo, - dbY->special(), - dYShapeInfo, - dbZ->primary(), hZShapeInfo, - dbZ->special(), dZShapeInfo, - extraParams, - dimension, - dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, - hTADOffsetsZ); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execBroadcast(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, + const Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + auto dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hZShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPackX.primaryShapeInfo(); + auto hTADOffsets = tadPackX.primaryOffsets(); + auto hTADShapeInfoZ = tadPackZ.primaryShapeInfo(); + auto hTADOffsetsZ = tadPackZ.primaryOffsets(); + + NativeOpExecutioner::execBroadcast( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, + dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, dimension, + dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, + hTADOffsetsZ); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execBroadcastBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbY, + const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, void *extraParams, + OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + auto dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hZShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPackX.primaryShapeInfo(); + auto hTADOffsets = tadPackX.primaryOffsets(); + auto hTADShapeInfoZ = tadPackZ.primaryShapeInfo(); + auto hTADOffsetsZ = tadPackZ.primaryOffsets(); + + NativeOpExecutioner::execBroadcastBool( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, + dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraParams, + dimension, dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, + hTADOffsetsZ); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -264,63 +246,42 @@ void execBroadcastBool(Nd4jPointer *extraPointers, * @param extraParams * @param n */ -void execPairwiseTransform( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execPairwiseTransform(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbY->primary(), - hYShapeInfo, - dbY->special(), - dYShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - extraParams); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execPairwiseTransform(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbY, + const Nd4jLong *hYShapeInfo, + const Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, void *extraParams) { + try { + NativeOpExecutioner::execPairwiseTransform( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, + dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraParams); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } void execPairwiseTransformBool( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams) { - - try { - NativeOpExecutioner::execPairwiseBoolTransform(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbY->primary(), - hYShapeInfo, - dbY->special(), - dYShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - extraParams); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, + const Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + void *extraParams) { + try { + NativeOpExecutioner::execPairwiseBoolTransform( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, + dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraParams); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -332,102 +293,72 @@ void execPairwiseTransformBool( * @param hZ * @param hZShapeInfo */ -void execReduceFloat( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) { - - try { - NativeOpExecutioner::execReduceFloatScalar(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execReduceSame( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) { - - try { - NativeOpExecutioner::execReduceSameScalar(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execReduceBool( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) { - try { - NativeOpExecutioner::execReduceBoolScalar(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execReduceLong( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) { - try { - NativeOpExecutioner::execReduceLongScalar(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceFloat(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo) { + try { + NativeOpExecutioner::execReduceFloatScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execReduceSame(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo) { + try { + NativeOpExecutioner::execReduceSameScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execReduceBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo) { + try { + NativeOpExecutioner::execReduceBoolScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execReduceLong(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo) { + try { + NativeOpExecutioner::execReduceLongScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -439,146 +370,117 @@ void execReduceLong( * @param hZ * @param hZShapeInfo */ -void execReduceFloat2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - auto dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); - - auto hTADShapeInfo = tadPackX.primaryShapeInfo(); - auto hTADOffsets = tadPackX.primaryOffsets(); - - NativeOpExecutioner::execReduceFloat(nullptr, opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dimension, - dimensionLength, - hTADShapeInfo, - hTADOffsets); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execReduceBool2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - auto dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, - dimensionLength); - - auto hTADShapeInfo = tadPack.primaryShapeInfo(); - auto hTADOffsets = tadPack.primaryOffsets(); - - NativeOpExecutioner::execReduceBool(nullptr, opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dimension, - dimensionLength, - hTADShapeInfo, - hTADOffsets); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execReduceSame2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, - dimensionLength); - - auto hTADShapeInfo = tadPack.primaryShapeInfo(); - auto hTADOffsets = tadPack.primaryOffsets(); - - NativeOpExecutioner::execReduceSame(nullptr, opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dimension, - dimensionLength, - hTADShapeInfo, - hTADOffsets); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execReduceLong2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); - - auto hTADShapeInfo = tadPack.primaryShapeInfo(); - auto hTADOffsets = tadPack.primaryOffsets(); - - NativeOpExecutioner::execReduceLong(nullptr, opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dimension, - dimensionLength, - hTADShapeInfo, - hTADOffsets); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceFloat2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + auto dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPackX.primaryShapeInfo(); + auto hTADOffsets = tadPackX.primaryOffsets(); + + NativeOpExecutioner::execReduceFloat( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, hTADOffsets); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execReduceBool2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + auto dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPack.primaryShapeInfo(); + auto hTADOffsets = tadPack.primaryOffsets(); + + NativeOpExecutioner::execReduceBool( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, hTADOffsets); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execReduceSame2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPack.primaryShapeInfo(); + auto hTADOffsets = tadPack.primaryOffsets(); + + NativeOpExecutioner::execReduceSame( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, hTADOffsets); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execReduceLong2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPack.primaryShapeInfo(); + auto hTADOffsets = tadPack.primaryOffsets(); + + NativeOpExecutioner::execReduceLong( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, hTADOffsets); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -592,19 +494,22 @@ void execReduceLong2(Nd4jPointer *extraPointers, * @param hZ * @param hZShapeInfo */ -void execReduce3(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) { - try { - NativeOpExecutioner::execReduce3(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, - dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduce3(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, + void *extraParams, OpaqueDataBuffer *dbY, + const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo) { + try { + NativeOpExecutioner::execReduce3( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, dbY->special(), + dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -616,18 +521,23 @@ void execReduce3(Nd4jPointer *extraPointers, * @param hY * @param hYShapeInfo */ -void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo) { - try { - NativeOpExecutioner::execReduce3Scalar(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbY->primary(), - hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduce3Scalar(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, + const Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo) { + try { + NativeOpExecutioner::execReduce3Scalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, dbY->special(), + dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** * @@ -642,46 +552,50 @@ void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, * @param dimension * @param dimensionLength */ -void execReduce3Tad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - auto dimensionLength = static_cast(shape::length(hDimensionShape)); - - if (extraPointers == nullptr || extraPointers[2] == 0) { - NativeOpExecutioner::execReduce3(LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, - extraParams, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), - dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, - yTadOnlyShapeInfo, yTadOffsets); - } else { - // going tad-way - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, - dimensionLength); - - auto hTADShapeInfo = tadPack.primaryShapeInfo(); - auto hTADOffsets = tadPack.primaryOffsets(); - - NativeOpExecutioner::execReduce3TAD(LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, dbX->special(), - dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), - hZShapeInfo, dbZ->special(), dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, - hTADOffsets, nullptr, nullptr); - } - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); +void execReduce3Tad( + Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, + const Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *yTadOnlyShapeInfo, + const Nd4jLong *yTadOffsets) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + auto dimensionLength = static_cast(shape::length(hDimensionShape)); + + if (extraPointers == nullptr || extraPointers[2] == 0) { + NativeOpExecutioner::execReduce3( + LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, + dbX->special(), dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, + dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); + } else { + // going tad-way + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + + auto hTADShapeInfo = tadPack.primaryShapeInfo(); + auto hTADOffsets = tadPack.primaryOffsets(); + + NativeOpExecutioner::execReduce3TAD( + LaunchContext::defaultContext(), opNum, dbX->primary(), hXShapeInfo, + dbX->special(), dXShapeInfo, extraParams, dbY->primary(), hYShapeInfo, + dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, + hTADShapeInfo, hTADOffsets, nullptr, nullptr); } + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -bool isBlasVersionMatches(int major, int minor, int build) { - return true; -} +bool isBlasVersionMatches(int major, int minor, int build) { return true; } /** * @@ -694,62 +608,43 @@ bool isBlasVersionMatches(int major, int minor, int build) { * @param extraParams * @param n */ -void execScalar( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbScalar, const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execScalar(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dbScalar->primary(), - hScalarShapeInfo, - dbScalar->special(), - dScalarShapeInfo, - extraParams); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execScalarBool( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbScalar, const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execScalarBool(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dbScalar->primary(), - hScalarShapeInfo, - dbScalar->special(), - dScalarShapeInfo, - extraParams); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execScalar(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbScalar, + const Nd4jLong *hScalarShapeInfo, + const Nd4jLong *dScalarShapeInfo, void *extraParams) { + try { + NativeOpExecutioner::execScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), + dScalarShapeInfo, extraParams); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execScalarBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbScalar, + const Nd4jLong *hScalarShapeInfo, + const Nd4jLong *dScalarShapeInfo, void *extraParams) { + try { + NativeOpExecutioner::execScalarBool( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), + dScalarShapeInfo, extraParams); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -759,29 +654,21 @@ void execScalarBool( * @param hXShapeInfo * @param extraParams */ -void execSummaryStatsScalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - bool biasCorrected) { - try { - NativeOpExecutioner::execSummaryStatsScalar(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - biasCorrected); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execSummaryStatsScalar(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, bool biasCorrected) { + try { + NativeOpExecutioner::execSummaryStatsScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, biasCorrected); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** * @@ -792,29 +679,21 @@ void execSummaryStatsScalar(Nd4jPointer *extraPointers, * @param hZ * @param hZShapeInfo */ -void execSummaryStats(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - bool biasCorrected) { - try { - NativeOpExecutioner::execSummaryStats(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - biasCorrected); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execSummaryStats(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, bool biasCorrected) { + try { + NativeOpExecutioner::execSummaryStats( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, biasCorrected); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** * @@ -827,39 +706,30 @@ void execSummaryStats(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -void execSummaryStatsTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape, - bool biasCorrected, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - - NativeOpExecutioner::execSummaryStats(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dimension, - dimensionLength, - tadShapeInfo, - tadOffsets, - biasCorrected); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execSummaryStatsTad(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape, bool biasCorrected, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + NativeOpExecutioner::execSummaryStats( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, + biasCorrected); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -872,220 +742,178 @@ void execSummaryStatsTad(Nd4jPointer *extraPointers, * @param extraParams * @param n */ -void execTransformFloat( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execTransformFloat(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - extraParams, - nullptr, - nullptr); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execTransformSame( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execTransformSame(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - extraParams, - nullptr, - nullptr); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execTransformBool( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execTransformBool(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - extraParams, - nullptr, - nullptr); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execTransformAny( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execTransformAny(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - extraParams, - nullptr, - nullptr); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execTransformStrict( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraParams) { - try { - NativeOpExecutioner::execTransformStrict(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - extraParams, - nullptr, - nullptr); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execReduce3All(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - void *extraParamsVals, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets) { - - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - auto dimensionLength = static_cast(shape::length(hDimensionShape)); - - - NativeOpExecutioner::execReduce3All(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParamsVals, dbY->primary(), - hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, dimension, - dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execTransformFloat(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, void *extraParams) { + try { + NativeOpExecutioner::execTransformFloat( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + extraParams, nullptr, nullptr); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execTransformSame(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + void *extraParams) { + try { + NativeOpExecutioner::execTransformSame( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + extraParams, nullptr, nullptr); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execTransformBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + void *extraParams) { + try { + NativeOpExecutioner::execTransformBool( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + extraParams, nullptr, nullptr); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execTransformAny(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + void *extraParams) { + try { + NativeOpExecutioner::execTransformAny( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + extraParams, nullptr, nullptr); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execTransformStrict(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, void *extraParams) { + try { + NativeOpExecutioner::execTransformStrict( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + extraParams, nullptr, nullptr); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execReduce3All(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, void *extraParamsVals, + OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, + const Nd4jLong *dYShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, + const Nd4jLong *dDimensionShape, + const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, + const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + auto dimensionLength = static_cast(shape::length(hDimensionShape)); + + NativeOpExecutioner::execReduce3All( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParamsVals, dbY->primary(), hYShapeInfo, + dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, + dbZ->special(), dZShapeInfo, dimension, dimensionLength, xTadShapeInfo, + xOffsets, yTadShapeInfo, yOffsets); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** - * Concatneate multi array of the same shape together - * along a particular dimension - */ -void specialConcat( - Nd4jPointer *extraPointers, - int dimension, - int numArrays, - Nd4jPointer *data, - Nd4jPointer *inputShapeInfo, - void *hZ, - Nd4jLong const* hZShapeInfo, - Nd4jPointer *tadPointers, - Nd4jPointer *offsetPointers) { - try { - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_SINGLE_SELECTOR(zType, sd::SpecialMethods,::concatCpuGeneric(dimension, numArrays, data, inputShapeInfo, hZ, hZShapeInfo), LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + * Concatneate multi array of the same shape together + * along a particular dimension + */ +void specialConcat(Nd4jPointer *extraPointers, int dimension, int numArrays, + Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *hZ, + Nd4jLong const *hZShapeInfo, Nd4jPointer *tadPointers, + Nd4jPointer *offsetPointers) { + try { + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_SINGLE_SELECTOR(zType, sd::SpecialMethods, + ::concatCpuGeneric(dimension, numArrays, data, + inputShapeInfo, hZ, hZShapeInfo), + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** * This is dummy method for JNI compatibility - * Since we'll use this from java, jni compiler would like to have method no matter what. + * Since we'll use this from java, jni compiler would like to have method no + * matter what. */ -void initializeDevicesAndFunctions() { - -} +void initializeDevicesAndFunctions() {} void initializeFunctions(Nd4jPointer *functions) { - sd::BlasHelper::getInstance()->initializeFunctions(functions); + sd::BlasHelper::getInstance()->initializeFunctions(functions); } /** - * This method acquires memory chunk of requested size on host side - * - * @param pointer pointer that'll be used for allocation - * @param memorySize memory size, in bytes - * @param flags optional parameter - */ + * This method acquires memory chunk of requested size on host side + * + * @param pointer pointer that'll be used for allocation + * @param memorySize memory size, in bytes + * @param flags optional parameter + */ Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) { - return reinterpret_cast(new int8_t[memorySize]); + return reinterpret_cast(new int8_t[memorySize]); } /** * This method acquires memory chunk of requested size on specified device * - * PLEASE NOTE: This method is NOT supported and has NO effect in CPU-based backend. + * PLEASE NOTE: This method is NOT supported and has NO effect in CPU-based + * backend. * * @param pointer pointer that'll be used for allocation * @param memorySize memory size, in bytes - * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for OpenCL that's pointer to device_id, etc + * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for + * OpenCL that's pointer to device_id, etc * @param flags optional parameter */ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) { - // not supported - return 0L; + // not supported + return 0L; } /** @@ -1094,1997 +922,2041 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) { * @param pointer pointer that'll be freed */ int freeHost(Nd4jPointer pointer) { - delete[] reinterpret_cast(pointer); - return 1L; + delete[] reinterpret_cast(pointer); + return 1L; } /** * This method releases previously allocated memory space on device * - * PLEASE NOTE: This method is NOT supported and has NO effect in CPU-based backend. + * PLEASE NOTE: This method is NOT supported and has NO effect in CPU-based + * backend. * * @param pointer pointer that'll be freed * @param ptrToDeviceId pointer to deviceId. */ int freeDevice(Nd4jPointer pointer, int deviceId) { - // not supported - return 0L; + // not supported + return 0L; } - /** * Returns the maximum number open mp threads */ -int ompGetMaxThreads() { - return omp_get_max_threads(); -} +int ompGetMaxThreads() { return omp_get_max_threads(); } /** * Returns the number open mp threads */ -int ompGetNumThreads() { - return omp_get_num_threads(); -} +int ompGetNumThreads() { return omp_get_num_threads(); } /** * Sets the number of openmp threads */ -void setOmpNumThreads(int threads) { - omp_set_num_threads(threads); +void setOmpNumThreads(int threads) { omp_set_num_threads(threads); } -} +Nd4jPointer createContext() { return 0L; } -Nd4jPointer createContext() { - return 0L; -} +Nd4jPointer createStream() { return 0L; } -Nd4jPointer createStream() { - return 0L; -} +Nd4jPointer createEvent() { return 0L; } -Nd4jPointer createEvent() { - return 0L; -} +int getDeviceMajor(int deviceId) { return 0; } -int getDeviceMajor(int deviceId ) { - return 0; -} +int getDeviceMinor(int deviceId) { return 0; } -int getDeviceMinor(int deviceId) { - return 0; -} +int registerEvent(Nd4jPointer event, Nd4jPointer stream) { return 0L; } -int registerEvent(Nd4jPointer event, Nd4jPointer stream) { - return 0L; -} +int setDevice(int deviceId) { return 0L; } -int setDevice(int deviceId) { - return 0L; -} +Nd4jLong getDeviceFreeMemory(int deviceId) { return 0L; } -Nd4jLong getDeviceFreeMemory(int deviceId) { - return 0L; -} +Nd4jLong getDeviceFreeMemoryDefault() { return 0L; } -Nd4jLong getDeviceFreeMemoryDefault() { - return 0L; -} +Nd4jLong getDeviceTotalMemory(int deviceId) { return 0L; } -Nd4jLong getDeviceTotalMemory(int deviceId) { - return 0L; +int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, + Nd4jPointer reserved) { + return 0L; } -int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { - return 0L; +int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, + Nd4jPointer reserved) { + return 0L; } -int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { - return 0L; +int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, + Nd4jPointer reserved) { + return 0L; } -int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) { - return 0L; +int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, + Nd4jPointer reserved) { + return 0L; } -int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) { - return 0L; -} +int destroyEvent(Nd4jPointer event) { return 0L; } -int destroyEvent(Nd4jPointer event) { - return 0L; -} +int streamSynchronize(Nd4jPointer stream) { return 0L; } -int streamSynchronize(Nd4jPointer stream) { - return 0L; -} +int eventSynchronize(Nd4jPointer event) { return 0L; } -int eventSynchronize(Nd4jPointer event) { - return 0L; -} - -int getAvailableDevices() { - return 0L; -} +int getAvailableDevices() { return 0L; } void enableDebugMode(bool reallyEnable) { - sd::Environment::getInstance()->setDebug(reallyEnable); + sd::Environment::getInstance()->setDebug(reallyEnable); } void enableVerboseMode(bool reallyEnable) { - sd::Environment::getInstance()->setVerbose(reallyEnable); + sd::Environment::getInstance()->setVerbose(reallyEnable); } void setGridLimit(int gridSize) { - // no-op + // no-op } -sd::TadPack* tadOnlyShapeInfo(Nd4jLong const* hXShapeInfo, int *dimension, int dimensionLength) { - auto pack = new TadPack(); - try { - *pack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +sd::TadPack *tadOnlyShapeInfo(Nd4jLong const *hXShapeInfo, int *dimension, + int dimensionLength) { + auto pack = new TadPack(); + try { + *pack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, dimensionLength); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } - return pack; + return pack; } -Nd4jLong const* getPrimaryShapeInfo(sd::TadPack* pack) { - return const_cast(pack->primaryShapeInfo()); +Nd4jLong const *getPrimaryShapeInfo(sd::TadPack *pack) { + return const_cast(pack->primaryShapeInfo()); } -Nd4jLong const* getPrimaryOffsets(sd::TadPack* pack) { - return const_cast(pack->primaryOffsets()); +Nd4jLong const *getPrimaryOffsets(sd::TadPack *pack) { + return const_cast(pack->primaryOffsets()); } -Nd4jLong const* getSpecialShapeInfo(sd::TadPack* pack) { - return const_cast(pack->specialShapeInfo()); +Nd4jLong const *getSpecialShapeInfo(sd::TadPack *pack) { + return const_cast(pack->specialShapeInfo()); } -Nd4jLong const* getSpecialOffsets(sd::TadPack* pack) { - return const_cast(pack->specialOffsets()); +Nd4jLong const *getSpecialOffsets(sd::TadPack *pack) { + return const_cast(pack->specialOffsets()); } -Nd4jLong getNumberOfTads(sd::TadPack* pack) { - return pack->numberOfTads(); -} +Nd4jLong getNumberOfTads(sd::TadPack *pack) { return pack->numberOfTads(); } -int getShapeInfoLength(sd::TadPack* pack) { - return pack->shapeInfoLength(); -} +int getShapeInfoLength(sd::TadPack *pack) { return pack->shapeInfoLength(); } -int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { - // no-op - return 0L; +int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, + Nd4jPointer reserved) { + // no-op + return 0L; } Nd4jPointer getConstantSpace() { - // no-op - return 0L; -} - -template -void pullRowsGeneric(void *vx, - Nd4jLong const* hXShapeInfo, - void *vz, - Nd4jLong const* hZShapeInfo, - const int n, - Nd4jLong const* indexes, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets, - Nd4jLong const* zTadShapeInfo, - Nd4jLong const* zTadOffsets) { - auto hX = reinterpret_cast(vx); - auto hZ = reinterpret_cast(vz); - - const auto xEWS = shape::elementWiseStride(tadShapeInfo); - const auto zEWS = shape::elementWiseStride(zTadShapeInfo); - const auto tadLength = shape::length(tadShapeInfo); - - int elementsPerThread = n / TAD_THRESHOLD; - int _threads = sd::math::nd4j_max(1, elementsPerThread); - _threads = sd::math::nd4j_min(_threads, sd::Environment::getInstance()->maxThreads()); - - auto func = PRAGMA_THREADS_FOR { - for (auto idx = start; idx < stop; idx++) { - auto xTadOffsetForBlock = tadOffsets[indexes[idx]]; - auto zTadOffsetForBlock = zTadOffsets[idx]; - - auto rX = hX + xTadOffsetForBlock; - auto rZ = hZ + zTadOffsetForBlock; - - if (xEWS == 1 && zEWS == 1) { - PRAGMA_OMP_SIMD - for (Nd4jLong i = 0; i < tadLength; i++) { - rZ[i] = rX[i]; - } - } else if (xEWS >= 1 && zEWS >= 1) { - PRAGMA_OMP_SIMD - for (Nd4jLong i = 0; i < tadLength; i++) { - rZ[i * zEWS] = rX[i * xEWS]; - } - } else { - for (Nd4jLong i = 0; i < tadLength; i++) { - auto xOffset = xTadOffsetForBlock + shape::getIndexOffset(i, tadShapeInfo); - auto zOffset = zTadOffsetForBlock + shape::getIndexOffset(i, zTadShapeInfo); - hZ[zOffset] = hX[xOffset]; - } - } - } - }; - - samediff::Threads::parallel_tad(func, 0, n, 1, _threads); -} - -void pullRows(Nd4jPointer *extraPointers, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - Nd4jLong n, - Nd4jLong* indexes, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets, - Nd4jLong const* zTadShapeInfo, - Nd4jLong const* zTadOffsets) { - try { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, pullRowsGeneric, (dbX->primary(), hXShapeInfo, dbZ->primary(), hZShapeInfo, n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + // no-op + return 0L; } -template -void tearGeneric(void *vx, - Nd4jLong const* hXShapeInfo, - Nd4jPointer *targets, - Nd4jLong const* hZShapeInfo, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets) { - - auto hX = reinterpret_cast(vx); - - const auto tadLength = shape::length(tadShapeInfo); - auto tadEWS = shape::elementWiseStride(tadShapeInfo); - auto zEWS = shape::elementWiseStride(hZShapeInfo); - auto numTads = shape::length(hXShapeInfo) / tadLength; - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto hZ = reinterpret_cast(targets[i]); - auto s = hX + tadOffsets[i]; - - if (zEWS == 1 && tadEWS == 1) { - PRAGMA_OMP_SIMD - for (Nd4jLong j = 0; j < tadLength; j++) { - hZ[j] = s[j]; - } - } else if (zEWS > 0 && tadEWS > 0) { - PRAGMA_OMP_SIMD - for (Nd4jLong j = 0; j < tadLength; j++) { - hZ[j * zEWS] = s[j * tadEWS]; - } - } else { - for (Nd4jLong j = 0; j < tadLength; j++) - hZ[shape::getIndexOffset(j, hZShapeInfo)] = s[shape::getIndexOffset(j, tadShapeInfo)]; - } +template +void pullRowsGeneric(void *vx, Nd4jLong const *hXShapeInfo, void *vz, + Nd4jLong const *hZShapeInfo, const int n, + Nd4jLong const *indexes, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, Nd4jLong const *zTadShapeInfo, + Nd4jLong const *zTadOffsets) { + auto hX = reinterpret_cast(vx); + auto hZ = reinterpret_cast(vz); + + const auto xEWS = shape::elementWiseStride(tadShapeInfo); + const auto zEWS = shape::elementWiseStride(zTadShapeInfo); + const auto tadLength = shape::length(tadShapeInfo); + + int elementsPerThread = n / TAD_THRESHOLD; + int _threads = sd::math::nd4j_max(1, elementsPerThread); + _threads = sd::math::nd4j_min( + _threads, sd::Environment::getInstance()->maxThreads()); + + auto func = PRAGMA_THREADS_FOR { + for (auto idx = start; idx < stop; idx++) { + auto xTadOffsetForBlock = tadOffsets[indexes[idx]]; + auto zTadOffsetForBlock = zTadOffsets[idx]; + + auto rX = hX + xTadOffsetForBlock; + auto rZ = hZ + zTadOffsetForBlock; + + if (xEWS == 1 && zEWS == 1) { + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < tadLength; i++) { + rZ[i] = rX[i]; + } + } else if (xEWS >= 1 && zEWS >= 1) { + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < tadLength; i++) { + rZ[i * zEWS] = rX[i * xEWS]; } - }; + } else { + for (Nd4jLong i = 0; i < tadLength; i++) { + auto xOffset = + xTadOffsetForBlock + shape::getIndexOffset(i, tadShapeInfo); + auto zOffset = + zTadOffsetForBlock + shape::getIndexOffset(i, zTadShapeInfo); + hZ[zOffset] = hX[xOffset]; + } + } + } + }; - samediff::Threads::parallel_tad(func,0, numTads); + samediff::Threads::parallel_tad(func, 0, n, 1, _threads); } -void tear(Nd4jPointer *extraPointers, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - Nd4jPointer *targets, - Nd4jLong const* hZShapeInfo, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets) { - try { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); +void pullRows(Nd4jPointer *extraPointers, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, Nd4jLong n, Nd4jLong *indexes, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *zTadShapeInfo, Nd4jLong const *zTadOffsets) { + try { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, tearGeneric, (dbX->primary(), hXShapeInfo, targets, hZShapeInfo, tadShapeInfo, tadOffsets), LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + BUILD_SINGLE_SELECTOR( + xType, pullRowsGeneric, + (dbX->primary(), hXShapeInfo, dbZ->primary(), hZShapeInfo, n, indexes, + tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } - -void average(Nd4jPointer *extras, - Nd4jPointer *hX, const Nd4jLong *hXShapeInfo, - Nd4jPointer *dX, const Nd4jLong *dXShapeInfo, - void *z, const Nd4jLong *hZShapeInfo, - void *dz, const Nd4jLong *dZShapeInfo, - int n, - Nd4jLong length, - bool propagate) { - try { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::averageGeneric(hX, z, hZShapeInfo, n, length, propagate), LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); +template +void tearGeneric(void *vx, Nd4jLong const *hXShapeInfo, Nd4jPointer *targets, + Nd4jLong const *hZShapeInfo, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets) { + auto hX = reinterpret_cast(vx); + + const auto tadLength = shape::length(tadShapeInfo); + auto tadEWS = shape::elementWiseStride(tadShapeInfo); + auto zEWS = shape::elementWiseStride(hZShapeInfo); + auto numTads = shape::length(hXShapeInfo) / tadLength; + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto hZ = reinterpret_cast(targets[i]); + auto s = hX + tadOffsets[i]; + + if (zEWS == 1 && tadEWS == 1) { + PRAGMA_OMP_SIMD + for (Nd4jLong j = 0; j < tadLength; j++) { + hZ[j] = s[j]; + } + } else if (zEWS > 0 && tadEWS > 0) { + PRAGMA_OMP_SIMD + for (Nd4jLong j = 0; j < tadLength; j++) { + hZ[j * zEWS] = s[j * tadEWS]; + } + } else { + for (Nd4jLong j = 0; j < tadLength; j++) + hZ[shape::getIndexOffset(j, hZShapeInfo)] = + s[shape::getIndexOffset(j, tadShapeInfo)]; + } } + }; + + samediff::Threads::parallel_tad(func, 0, numTads); } -void accumulate(Nd4jPointer *extras, - Nd4jPointer *hX, Nd4jLong const* hXShapeInfo, - Nd4jPointer *dX, Nd4jLong const* dXShapeInfo, - void *hz, Nd4jLong const* hZShapeInfo, - void *dz, Nd4jLong const* dZShapeInfo, - int n, - Nd4jLong length) { - try { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::accumulateGeneric(hX, hz, hZShapeInfo, n, length), LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void tear(Nd4jPointer *extraPointers, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + Nd4jPointer *targets, Nd4jLong const *hZShapeInfo, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets) { + try { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + + BUILD_SINGLE_SELECTOR(xType, tearGeneric, + (dbX->primary(), hXShapeInfo, targets, hZShapeInfo, + tadShapeInfo, tadOffsets), + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void average(Nd4jPointer *extras, Nd4jPointer *hX, const Nd4jLong *hXShapeInfo, + Nd4jPointer *dX, const Nd4jLong *dXShapeInfo, void *z, + const Nd4jLong *hZShapeInfo, void *dz, const Nd4jLong *dZShapeInfo, + int n, Nd4jLong length, bool propagate) { + try { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + + BUILD_SINGLE_SELECTOR( + xType, sd::SpecialMethods, + ::averageGeneric(hX, z, hZShapeInfo, n, length, propagate), + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void accumulate(Nd4jPointer *extras, Nd4jPointer *hX, + Nd4jLong const *hXShapeInfo, Nd4jPointer *dX, + Nd4jLong const *dXShapeInfo, void *hz, + Nd4jLong const *hZShapeInfo, void *dz, + Nd4jLong const *dZShapeInfo, int n, Nd4jLong length) { + try { + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + + BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, + ::accumulateGeneric(hX, hz, hZShapeInfo, n, length), + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } void enableP2P(bool enable) { - // no-op + // no-op } - - -void encodeThresholdP1(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, Nd4jLong N, int *dz, float threshold) { - // TODO: to be implemented +void encodeThresholdP1(Nd4jPointer *extraPointers, void *hX, + Nd4jLong const *hXShapeInfo, Nd4jLong N, int *dz, + float threshold) { + // TODO: to be implemented } - -void encodeThresholdP2Int(Nd4jPointer *extraPointers, int *hX, Nd4jLong N, int *dz) { - // TODO: to be implemented +void encodeThresholdP2Int(Nd4jPointer *extraPointers, int *hX, Nd4jLong N, + int *dz) { + // TODO: to be implemented } +void encodeThresholdP3(Nd4jPointer *extraPointers, void *hX, + Nd4jLong const *hXShapeInfo, int *offsets, Nd4jLong N, + int *dz) { + // offsets won't be used here -void encodeThresholdP3(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, int *offsets, Nd4jLong N, int *dz){ - // offsets won't be used here - - // TODO: to be implemented + // TODO: to be implemented } -void decodeThreshold(Nd4jPointer *extraPointers, void *hX, Nd4jLong N, void *dz, const Nd4jLong *hZShapeInfo){ - // TODO: to be implemented +void decodeThreshold(Nd4jPointer *extraPointers, void *hX, Nd4jLong N, void *dz, + const Nd4jLong *hZShapeInfo) { + // TODO: to be implemented } bool isP2PAvailable() { - // always TRUE for cpu backend - return true; + // always TRUE for cpu backend + return true; } void checkP2P() { - // no-op + // no-op } -void decodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong N, void *dz, Nd4jLong const* hZShapeInfo) { - NativeOpExecutioner::decodeBitmap(hX, N, dz, hZShapeInfo); +void decodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong N, void *dz, + Nd4jLong const *hZShapeInfo) { + NativeOpExecutioner::decodeBitmap(hX, N, dz, hZShapeInfo); } -template -void shuffleGeneric(void **hX, Nd4jLong * const*hXShapeInfo, void **dz, Nd4jLong * const* hZShapeInfo, int N, int *shuffleMap, Nd4jLong * const* tadOnlyShapeInfo, Nd4jLong * const* tadOffsets) { - - auto dX = reinterpret_cast(hX); - auto dZ = reinterpret_cast(dz); - - auto func = PRAGMA_THREADS_FOR { - for (auto f = start; f < stop; f++) { - auto hX = reinterpret_cast(dX[f]); - //auto hZ = reinterpret_cast(dZ[f]); - - auto xShapeInfo = hXShapeInfo[f]; - auto tadOffset = reinterpret_cast(tadOffsets[f]); - - - const auto tadLength = shape::length(tadOnlyShapeInfo[f]); - auto tadEWS = shape::elementWiseStride(tadOnlyShapeInfo[f]); - auto tadRank = shape::rank(tadOnlyShapeInfo[f]); - auto numTads = shape::length(hXShapeInfo[f]) / tadLength; - - auto tadShape = shape::shapeOf(tadOnlyShapeInfo[f]); - auto tadStride = shape::stride(tadOnlyShapeInfo[f]); - - if (shape::rank(xShapeInfo) == 1) { - auto xLength = shape::length(xShapeInfo); - auto ews = shape::elementWiseStride(xShapeInfo); - for (Nd4jLong r = 0; r < xLength; r++) { - auto swapIdx = shuffleMap[r]; - if (swapIdx < 0) - continue; - - sd::math::nd4j_swap(hX[r * ews], hX[swapIdx * ews]); - } - } else { - for (Nd4jLong r = 0; r < numTads; r++) { - if (shuffleMap[r] < 0) - continue; - - auto oldOffset = tadOffset[r]; - auto newOffset = tadOffset[shuffleMap[r]]; - - auto rX = hX + oldOffset; - auto rY = hX + newOffset; - - if (tadEWS == 1) { - for (Nd4jLong i = 0; i < tadLength; i++) { - sd::math::nd4j_swap(rX[i], rY[i]); - } - } else { - for (Nd4jLong i = 0; i < tadLength; i++) { - auto offset = shape::getIndexOffset(i, tadOnlyShapeInfo[f]); - sd::math::nd4j_swap(hX[offset + oldOffset], hX[offset + newOffset]); - } - } - } - } +template +void shuffleGeneric(void **hX, Nd4jLong *const *hXShapeInfo, void **dz, + Nd4jLong *const *hZShapeInfo, int N, int *shuffleMap, + Nd4jLong *const *tadOnlyShapeInfo, + Nd4jLong *const *tadOffsets) { + auto dX = reinterpret_cast(hX); + auto dZ = reinterpret_cast(dz); + + auto func = PRAGMA_THREADS_FOR { + for (auto f = start; f < stop; f++) { + auto hX = reinterpret_cast(dX[f]); + // auto hZ = reinterpret_cast(dZ[f]); + + auto xShapeInfo = hXShapeInfo[f]; + auto tadOffset = reinterpret_cast(tadOffsets[f]); + + const auto tadLength = shape::length(tadOnlyShapeInfo[f]); + auto tadEWS = shape::elementWiseStride(tadOnlyShapeInfo[f]); + auto tadRank = shape::rank(tadOnlyShapeInfo[f]); + auto numTads = shape::length(hXShapeInfo[f]) / tadLength; + + auto tadShape = shape::shapeOf(tadOnlyShapeInfo[f]); + auto tadStride = shape::stride(tadOnlyShapeInfo[f]); + + if (shape::rank(xShapeInfo) == 1) { + auto xLength = shape::length(xShapeInfo); + auto ews = shape::elementWiseStride(xShapeInfo); + for (Nd4jLong r = 0; r < xLength; r++) { + auto swapIdx = shuffleMap[r]; + if (swapIdx < 0) continue; + + sd::math::nd4j_swap(hX[r * ews], hX[swapIdx * ews]); } - }; - - samediff::Threads::parallel_tad(func, 0, N); -} - -void shuffle(Nd4jPointer *extras, - Nd4jPointer *hX, Nd4jPointer *hXShapeInfo, - Nd4jPointer *dX, Nd4jPointer *dXShapeInfo, - Nd4jPointer *hz, Nd4jPointer *hZShapeInfo, - Nd4jPointer *dz, Nd4jPointer *dZShapeInfo, - int N, - int *shuffleMap, - Nd4jPointer *tadShapeInfo, - Nd4jPointer *tadOffsets) { - try { - auto xShape = reinterpret_cast(hXShapeInfo); - auto zShape = reinterpret_cast(hZShapeInfo); - auto tadOnlyShapeInfo = reinterpret_cast(tadShapeInfo); - auto tadOffset = reinterpret_cast(tadOffsets); - - auto xType = sd::ArrayOptions::dataType(xShape[0]); - - BUILD_SINGLE_SELECTOR(xType, shuffleGeneric, - (hX, xShape, hz, zShape, N, shuffleMap, tadOnlyShapeInfo, tadOffset), LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - - -bool isExperimentalEnabled() { - return sd::Environment::getInstance()->isExperimentalBuild(); -} - - -void setOmpMinThreads(int threads) { - // TODO: to be implemented -} - -int getDevice() { - return 0; -} - -void execScalarTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const*dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const*dZShapeInfo, - OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const*tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const*tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - NativeOpExecutioner::execScalar(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dbScalars->primary(), - hScalarShapeInfo, - dbScalars->special(), - dScalarShapeInfo, - dimension, - shape::length(hDimensionShape), - tadShapeInfo, - tadOffsets, - tadShapeInfoZ, - tadOffsetsZ); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execScalarBoolTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - OpaqueDataBuffer *dbScalars, const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) { - try { - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - NativeOpExecutioner::execScalarBool(nullptr, - opNum, - dbX->primary(), - hXShapeInfo, - dbX->special(), - dXShapeInfo, - extraParams, - dbZ->primary(), - hZShapeInfo, - dbZ->special(), - dZShapeInfo, - dbScalars->primary(), - hScalarShapeInfo, - dbScalars->special(), - dScalarShapeInfo, - dimension, - dimensionLength, - tadShapeInfo, - tadOffsets, - tadShapeInfoZ, - tadOffsetsZ); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} + } else { + for (Nd4jLong r = 0; r < numTads; r++) { + if (shuffleMap[r] < 0) continue; -const char * getDeviceName(int deviceId) { - try { - if (!nameSet) { - name = reinterpret_cast(malloc(256 * sizeof(char))); + auto oldOffset = tadOffset[r]; + auto newOffset = tadOffset[shuffleMap[r]]; - CHECK_ALLOC(name, "Failed to allocate new string buffer", 256); + auto rX = hX + oldOffset; + auto rY = hX + newOffset; - std::memset(name, 0, 256 * sizeof(char)); - nameSet = true; - - // TODO: provide proper CPU model name here - sprintf(name, "x86-compatible CPU"); + if (tadEWS == 1) { + for (Nd4jLong i = 0; i < tadLength; i++) { + sd::math::nd4j_swap(rX[i], rY[i]); + } + } else { + for (Nd4jLong i = 0; i < tadLength; i++) { + auto offset = shape::getIndexOffset(i, tadOnlyShapeInfo[f]); + sd::math::nd4j_swap(hX[offset + oldOffset], + hX[offset + newOffset]); + } + } } - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } + }; - - return name; + samediff::Threads::parallel_tad(func, 0, N); } +void shuffle(Nd4jPointer *extras, Nd4jPointer *hX, Nd4jPointer *hXShapeInfo, + Nd4jPointer *dX, Nd4jPointer *dXShapeInfo, Nd4jPointer *hz, + Nd4jPointer *hZShapeInfo, Nd4jPointer *dz, + Nd4jPointer *dZShapeInfo, int N, int *shuffleMap, + Nd4jPointer *tadShapeInfo, Nd4jPointer *tadOffsets) { + try { + auto xShape = reinterpret_cast(hXShapeInfo); + auto zShape = reinterpret_cast(hZShapeInfo); + auto tadOnlyShapeInfo = reinterpret_cast(tadShapeInfo); + auto tadOffset = reinterpret_cast(tadOffsets); -void execAggregate(Nd4jPointer *extraPointers,int opNum, - void **arguments, - int numArguments, - Nd4jLong **shapeArguments, - int numShapeArguments, - int *indexArguments, - int numIndexArguments, - int **intArrays, - int numIntArrays, - void *realArguments, - int numRealArguments, - sd::DataType dtype) { + auto xType = sd::ArrayOptions::dataType(xShape[0]); + BUILD_SINGLE_SELECTOR( + xType, shuffleGeneric, + (hX, xShape, hz, zShape, N, shuffleMap, tadOnlyShapeInfo, tadOffset), + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -void batchExecutor(Nd4jPointer *extraPointers, - int numAggregates, - int opNum, - int maxArgs, - int maxShapes, - int maxIntArrays, - int maxIntArraySize, - int maxIdx, - int maxReals, - void *ptrToArguments, - sd::DataType dtype) { - +bool isExperimentalEnabled() { + return sd::Environment::getInstance()->isExperimentalBuild(); } -void execAggregateBatch(Nd4jPointer *extraPointers, - int numAggregates, - int opNum, - int maxArgs, - int maxShapes, - int maxIntArrays, - int maxIntArraySize, - int maxIdx, - int maxReals, - void *ptrToArguments, - sd::DataType dtype) { - +void setOmpMinThreads(int threads) { + // TODO: to be implemented +} + +int getDevice() { return 0; } + +void execScalarTad(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbScalars, + Nd4jLong const *hScalarShapeInfo, + Nd4jLong const *dScalarShapeInfo, void *extraParams, + OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *tadShapeInfoZ, Nd4jLong const *tadOffsetsZ) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + NativeOpExecutioner::execScalar( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, dbScalars->primary(), hScalarShapeInfo, + dbScalars->special(), dScalarShapeInfo, dimension, + shape::length(hDimensionShape), tadShapeInfo, tadOffsets, tadShapeInfoZ, + tadOffsetsZ); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execScalarBoolTad( + Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, OpaqueDataBuffer *dbScalars, + const Nd4jLong *hScalarShapeInfo, const Nd4jLong *dScalarShapeInfo, + void *extraParams, OpaqueDataBuffer *dbDimension, + const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ) { + try { + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + NativeOpExecutioner::execScalarBool( + nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, dbScalars->primary(), hScalarShapeInfo, + dbScalars->special(), dScalarShapeInfo, dimension, dimensionLength, + tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } +const char *getDeviceName(int deviceId) { + try { + if (!nameSet) { + name = reinterpret_cast(malloc(256 * sizeof(char))); -void execRandom(Nd4jPointer *extraPointers, - int opNum, - Nd4jPointer state, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraArguments) { - try { - NativeOpExecutioner::execRandom(nullptr, opNum, state, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} + CHECK_ALLOC(name, "Failed to allocate new string buffer", 256); -void execRandom3(Nd4jPointer *extraPointers, - int opNum, - Nd4jPointer state, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbY, const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraArguments) { - try { - NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} + std::memset(name, 0, 256 * sizeof(char)); + nameSet = true; -void execRandom2(Nd4jPointer *extraPointers, - int opNum, - Nd4jPointer state, - OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, const Nd4jLong *dXShapeInfo, - OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, - void *extraArguments) { - try { - NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, extraArguments); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + // TODO: provide proper CPU model name here + sprintf(name, "x86-compatible CPU"); } + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } + + return name; +} + +void execAggregate(Nd4jPointer *extraPointers, int opNum, void **arguments, + int numArguments, Nd4jLong **shapeArguments, + int numShapeArguments, int *indexArguments, + int numIndexArguments, int **intArrays, int numIntArrays, + void *realArguments, int numRealArguments, + sd::DataType dtype) {} + +void batchExecutor(Nd4jPointer *extraPointers, int numAggregates, int opNum, + int maxArgs, int maxShapes, int maxIntArrays, + int maxIntArraySize, int maxIdx, int maxReals, + void *ptrToArguments, sd::DataType dtype) {} + +void execAggregateBatch(Nd4jPointer *extraPointers, int numAggregates, + int opNum, int maxArgs, int maxShapes, int maxIntArrays, + int maxIntArraySize, int maxIdx, int maxReals, + void *ptrToArguments, sd::DataType dtype) {} + +void execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, void *extraArguments) { + try { + NativeOpExecutioner::execRandom(nullptr, opNum, state, dbZ->primary(), + hZShapeInfo, dbZ->special(), dZShapeInfo, + extraArguments); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbY, + const Nd4jLong *hYShapeInfo, const Nd4jLong *dYShapeInfo, + OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, + const Nd4jLong *dZShapeInfo, void *extraArguments) { + try { + NativeOpExecutioner::execRandom( + nullptr, opNum, state, dbX->primary(), hXShapeInfo, dbX->special(), + dXShapeInfo, dbY->primary(), hYShapeInfo, dbY->special(), dYShapeInfo, + dbZ->primary(), hZShapeInfo, dbZ->special(), dZShapeInfo, + extraArguments); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer state, + OpaqueDataBuffer *dbX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *dXShapeInfo, OpaqueDataBuffer *dbZ, + const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo, + void *extraArguments) { + try { + NativeOpExecutioner::execRandom(nullptr, opNum, state, dbX->primary(), + hXShapeInfo, dbX->special(), dXShapeInfo, + dbZ->primary(), hZShapeInfo, dbZ->special(), + dZShapeInfo, extraArguments); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -Nd4jPointer initRandom(Nd4jPointer *extraPointers, long seed, long bufferSize, Nd4jPointer ptrToBuffer) { - try { - auto generator = new graph::RandomGenerator(seed, seed); +Nd4jPointer initRandom(Nd4jPointer *extraPointers, long seed, long bufferSize, + Nd4jPointer ptrToBuffer) { + try { + auto generator = new graph::RandomGenerator(seed, seed); - return (Nd4jPointer) generator; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return (Nd4jPointer)generator; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); - return nullptr; - } + return nullptr; + } } -void refreshBuffer(Nd4jPointer *extraPointers, long seed, Nd4jPointer ptrRandom) { - auto generator = reinterpret_cast (ptrRandom); +void refreshBuffer(Nd4jPointer *extraPointers, long seed, + Nd4jPointer ptrRandom) { + auto generator = reinterpret_cast(ptrRandom); - generator->setStates(seed); + generator->setStates(seed); } -void reSeedBuffer(Nd4jPointer *extraPointers, long seed, Nd4jPointer ptrRandom) { - auto generator = reinterpret_cast (ptrRandom); +void reSeedBuffer(Nd4jPointer *extraPointers, long seed, + Nd4jPointer ptrRandom) { + auto generator = reinterpret_cast(ptrRandom); - generator->setStates(seed); + generator->setStates(seed); } - void destroyRandom(Nd4jPointer ptrBuffer) { - auto buffer = reinterpret_cast(ptrBuffer); - delete buffer; + auto buffer = reinterpret_cast(ptrBuffer); + delete buffer; } - - - /** - * Return the length of a shape buffer - * based on the pointer - * @param buffer the buffer pointer to check - * @return - */ + * Return the length of a shape buffer + * based on the pointer + * @param buffer the buffer pointer to check + * @return + */ int lengthForShapeBufferPointer(Nd4jPointer buffer) { - auto shapeBuffer = reinterpret_cast(buffer); - return shape::shapeInfoLength(shape::rank(shapeBuffer)); + auto shapeBuffer = reinterpret_cast(buffer); + return shape::shapeInfoLength(shape::rank(shapeBuffer)); } - /** - * The pointer to get the address for - * - * @param address the address to get the pointer - * @return the pointer for the given address - */ + * The pointer to get the address for + * + * @param address the address to get the pointer + * @return the pointer for the given address + */ Nd4jPointer pointerForAddress(Nd4jLong address) { - return reinterpret_cast(address); -} - -void sort(Nd4jPointer *extraPointers, - void *hX, const Nd4jLong *hXShapeInfo, - void *dX, const Nd4jLong *dXShapeInfo, - bool descending) { - try { - NativeOpExecutioner::execSort(hX, hXShapeInfo, descending); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + return reinterpret_cast(address); } -void sortTad(Nd4jPointer *extraPointers, - void *hX, const Nd4jLong *hXShapeInfo, - void *dX, const Nd4jLong *dXShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, - const Nd4jLong *tadOffsets, - bool descending) { - try { - NativeOpExecutioner::execSort(hX, hXShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void sort(Nd4jPointer *extraPointers, void *hX, const Nd4jLong *hXShapeInfo, + void *dX, const Nd4jLong *dXShapeInfo, bool descending) { + try { + NativeOpExecutioner::execSort(hX, hXShapeInfo, descending); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortTad(Nd4jPointer *extraPointers, void *hX, const Nd4jLong *hXShapeInfo, + void *dX, const Nd4jLong *dXShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, bool descending) { + try { + NativeOpExecutioner::execSort(hX, hXShapeInfo, dimension, dimensionLength, + tadShapeInfo, tadOffsets, descending); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -void sortCooIndices(Nd4jPointer *extraPointers, - Nd4jLong *indices, - void *values, - Nd4jLong length, - int rank) { - try { - NativeOpExecutioner::execSortCooIndices(indices, values, length, rank); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, + Nd4jLong length, int rank) { + try { + NativeOpExecutioner::execSortCooIndices(indices, values, length, rank); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, Nd4jLong N, int *dz, float threshold) { - return NativeOpExecutioner::encodeBitmap(hX, hXShapeInfo, N, dz, threshold); +Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *hX, + Nd4jLong const *hXShapeInfo, Nd4jLong N, int *dz, + float threshold) { + return NativeOpExecutioner::encodeBitmap(hX, hXShapeInfo, N, dz, threshold); } - - -Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) { - auto hZ = new Nd4jLong[2];errno = 0; -try { +Nd4jLong *mmapFile(Nd4jPointer *extraPointers, const char *fileName, + Nd4jLong length) { + auto hZ = new Nd4jLong[2]; + errno = 0; + try { #if defined(_WIN32) || defined(_WIN64) _mmap(hZ, static_cast(length), fileName); #else - int fd = open(fileName, O_RDWR, 0);// checking for failed fopen + int fd = open(fileName, O_RDWR, 0); // checking for failed fopen if (fd < 0) { - nd4j_printf("Errno: %i\n", errno); - throw std::runtime_error("Failed to open file for MMAP"); + nd4j_printf("Errno: %i\n", errno); + throw std::runtime_error("Failed to open file for MMAP"); } void *ptr = mmap(NULL, length, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); -// check for failed allocation - if (ptr == MAP_FAILED) - return nullptr; + // check for failed allocation + if (ptr == MAP_FAILED) return nullptr; - hZ[0] = (Nd4jLong) ptr; + hZ[0] = (Nd4jLong)ptr; hZ[1] = fd; #endif return hZ; -} catch (std::exception &e) { + } catch (std::exception &e) { sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); return nullptr; -} + } } void munmapFile(Nd4jPointer *extraPointers, Nd4jLong *ptrMap, Nd4jLong length) { - munmap((Nd4jPointer) ptrMap[0], length); + munmap((Nd4jPointer)ptrMap[0], length); #if defined(_WIN32) || defined(_WIN64) - CloseHandle(reinterpret_cast(ptrMap[1])); + CloseHandle(reinterpret_cast(ptrMap[1])); #else - close((int) ptrMap[1]); + close((int)ptrMap[1]); #endif - delete[] ptrMap; + delete[] ptrMap; } -sd::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer) { - return nullptr; +sd::graph::ResultWrapper *executeFlatGraph(Nd4jPointer *extraPointers, + Nd4jPointer flatBufferPointer) { + return nullptr; } -Nd4jLong getResultWrapperSize(sd::graph::ResultWrapper* ptr) { - return ptr->size(); +Nd4jLong getResultWrapperSize(sd::graph::ResultWrapper *ptr) { + return ptr->size(); } -Nd4jPointer getResultWrapperPointer(sd::graph::ResultWrapper* ptr) { - return ptr->pointer(); +Nd4jPointer getResultWrapperPointer(sd::graph::ResultWrapper *ptr) { + return ptr->pointer(); } -const char* getAllCustomOps() { - return sd::ops::OpRegistrator::getInstance()->getAllCustomOperations(); +const char *getAllCustomOps() { + return sd::ops::OpRegistrator::getInstance()->getAllCustomOperations(); } template -FORCEINLINE int estimateThresholdGeneric(Nd4jPointer *extraPointers, Nd4jPointer hX, int N, T threshold) { - auto buffer = reinterpret_cast(hX); - int span = (N / 6) + 8; +FORCEINLINE int estimateThresholdGeneric(Nd4jPointer *extraPointers, + Nd4jPointer hX, int N, T threshold) { + auto buffer = reinterpret_cast(hX); + int span = (N / 6) + 8; - auto func = PRAGMA_REDUCE_LONG { - int64_t cnt = 0; - PRAGMA_OMP_SIMD - for (auto e = start; e < stop; e++) { - auto v = sd::math::nd4j_abs(buffer[e]); - if (v >= threshold) - cnt++; - } + auto func = PRAGMA_REDUCE_LONG { + int64_t cnt = 0; + PRAGMA_OMP_SIMD + for (auto e = start; e < stop; e++) { + auto v = sd::math::nd4j_abs(buffer[e]); + if (v >= threshold) cnt++; + } - return cnt; - }; + return cnt; + }; - return samediff::Threads::parallel_long(func, LAMBDA_AL { return _old + _new; }, 0, N); + return samediff::Threads::parallel_long( + func, LAMBDA_AL { return _old + _new; }, 0, N); } - -int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer hX, Nd4jLong const* hXShapeInfo, int N, float threshold) { - try { - auto xType = ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, return estimateThresholdGeneric, (extraPointers, hX, N, threshold), FLOAT_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 0; - } +int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer hX, + Nd4jLong const *hXShapeInfo, int N, float threshold) { + try { + auto xType = ArrayOptions::dataType(hXShapeInfo); + BUILD_SINGLE_SELECTOR(xType, return estimateThresholdGeneric, + (extraPointers, hX, N, threshold), FLOAT_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 0; + } } -Nd4jLong getShapeListSize(sd::ShapeList* list) { - return list->size(); -} +Nd4jLong getShapeListSize(sd::ShapeList *list) { return list->size(); } -Nd4jLong const* getShape(sd::ShapeList* list, Nd4jLong i) { - return const_cast(list->at(i)); +Nd4jLong const *getShape(sd::ShapeList *list, Nd4jLong i) { + return const_cast(list->at(i)); } void deleteShapeList(Nd4jPointer shapeList) { - auto list = reinterpret_cast(shapeList); + auto list = reinterpret_cast(shapeList); - //list->destroy(); - delete list; + // list->destroy(); + delete list; } -sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { - sd::graph::VariableSpace varSpace; - Context block(2, &varSpace); - sd::ShapeList inShapes; +sd::ShapeList *_calculateOutputShapes( + Nd4jPointer *extraPointers, sd::ops::DeclarableOp *op, + Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputShapes, + double *tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, + int numBArgs, int *dArgs, int numDArgs) { + sd::graph::VariableSpace varSpace; + Context block(2, &varSpace); + sd::ShapeList inShapes; - for (int e = 0; e < numIArgs; e++) - block.appendI(iArgs[e]); + for (int e = 0; e < numIArgs; e++) block.appendI(iArgs[e]); - for (int e = 0; e < numTArgs; e++) - block.appendT(tArgs[e]); + for (int e = 0; e < numTArgs; e++) block.appendT(tArgs[e]); - for (int e = 0; e < numBArgs; e++) - block.appendB(bArgs[e]); + for (int e = 0; e < numBArgs; e++) block.appendB(bArgs[e]); - for (int e = 0; e < numDArgs; e++) - block.appendD((sd::DataType) dArgs[e]); + for (int e = 0; e < numDArgs; e++) block.appendD((sd::DataType)dArgs[e]); - for (int e = 0; e < numInputShapes; e++) { - auto shape_ = reinterpret_cast(inputShapes[e]); + for (int e = 0; e < numInputShapes; e++) { + auto shape_ = reinterpret_cast(inputShapes[e]); - // we shouldn't copy buffer if that's empty array - void *buffer_ = sd::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY ? nullptr : inputBuffers[e]; + // we shouldn't copy buffer if that's empty array + void *buffer_ = sd::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY + ? nullptr + : inputBuffers[e]; - auto array = std::make_shared(buffer_, shape_, LaunchContext::defaultContext(), false); + auto array = std::make_shared( + buffer_, shape_, LaunchContext::defaultContext(), false); - // block should contain references to proper variable - varSpace.putVariable(1, e, array); - block.pickInput(1, e); + // block should contain references to proper variable + varSpace.putVariable(1, e, array); + block.pickInput(1, e); - inShapes.push_back(shape_); - } + inShapes.push_back(shape_); + } - auto status = op->validateDataTypes(block); - if (status != Status::OK()) - throw std::runtime_error("Data types validation failed"); + auto status = op->validateDataTypes(block); + if (status != Status::OK()) + throw std::runtime_error("Data types validation failed"); - auto shapeList = op->calculateOutputShape(&inShapes, block); + auto shapeList = op->calculateOutputShape(&inShapes, block); - return shapeList; + return shapeList; } -sd::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { - try { - auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); +sd::ShapeList *calculateOutputShapes2(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, + int numInputShapes, double *tArgs, + int numTArgs, Nd4jLong *iArgs, + int numIArgs, bool *bArgs, int numBArgs, + int *dArgs, int numDArgs) { + try { + auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); - return _calculateOutputShapes(extraPointers, op.get(), inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } + return _calculateOutputShapes( + extraPointers, op.get(), inputBuffers, inputShapes, numInputShapes, + tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::DeclarableOp *op, Nd4jPointer* inputShapes, int numInputShapes, double *tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) { - Context block(1); - sd::ShapeList inShapes; +sd::ShapeList *_calculateOutputShapes(Nd4jPointer *extraPointers, + sd::ops::DeclarableOp *op, + Nd4jPointer *inputShapes, + int numInputShapes, double *tArgs, + int numTArgs, Nd4jLong *iArgs, + int numIArgs) { + Context block(1); + sd::ShapeList inShapes; - for (int e = 0; e < numIArgs; e++) - block.appendI(iArgs[e]); + for (int e = 0; e < numIArgs; e++) block.appendI(iArgs[e]); - for (int e = 0; e < numTArgs; e++) - block.appendT(tArgs[e]); + for (int e = 0; e < numTArgs; e++) block.appendT(tArgs[e]); - for (int e = 0; e < numInputShapes; e++) - inShapes.push_back(reinterpret_cast(inputShapes[e])); + for (int e = 0; e < numInputShapes; e++) + inShapes.push_back(reinterpret_cast(inputShapes[e])); - auto shapeList = op->calculateOutputShape(&inShapes, block); - shapeList->detach(); + auto shapeList = op->calculateOutputShape(&inShapes, block); + shapeList->detach(); - return shapeList; + return shapeList; } -sd::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) { - try { - auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); - - return _calculateOutputShapes(extraPointers, op.get(), inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } -} +sd::ShapeList *calculateOutputShapes(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer *inputShapes, + int numInputShapes, double *tArgs, + int numTArgs, Nd4jLong *iArgs, + int numIArgs) { + try { + auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); -int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext) { - try { - auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); - auto context = reinterpret_cast(opContext); - - return op->execute(context); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 20; - } + return _calculateOutputShapes(extraPointers, op.get(), inputShapes, + numInputShapes, tArgs, numTArgs, iArgs, + numIArgs); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -Nd4jStatus realExec(sd::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) { - if (op == nullptr) - nd4j_printf("Can't find requested operation: [%lld]\n", hash); - - // we're using the same fake nodeId everywhere here - - std::vector inputs(numInputs); - std::vector outputs(numOutputs); - std::vector ttArgs(numTArgs); - std::vector iiArgs(numIArgs); - std::vector biArgs(numBArgs); - - // filling block now with inputs - for (int e = 0; e < numInputs; e++) { - auto shape = reinterpret_cast(inputShapes[e]); - void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : inputBuffers[e]; - - inputs[e] = new sd::NDArray(buffer, shape); - } - - // if not inplace - transferring output arrays - - if (!isInplace) - for (int e = 0; e < numOutputs; e++) { - // we want to keep original output shape intact - auto shape = shape::copyShape(reinterpret_cast(outputShapes[e])); - void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : outputBuffers[e]; - - // FIXME: revisit this. - bool canNullify = true; - for (int i = 0; i < numInputs; i++) { - void *ibuffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : inputBuffers[i]; - if (ibuffer == buffer) { - canNullify = false; - break; - } - } - - if (canNullify) - memset((uint8_t *) buffer, '\0', shape::length(shape) * DataTypeUtils::sizeOfElement(ArrayOptions::dataType(shape))); - - auto array = new sd::NDArray(buffer, shape); - outputs[e] = array; +int execCustomOp2(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer opContext) { + try { + auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); + auto context = reinterpret_cast(opContext); - // and we want to release shape copy once we're done - delete []shape; + return op->execute(context); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 20; + } +} + +Nd4jStatus realExec(sd::ops::DeclarableOp *op, Nd4jPointer *extraPointers, + Nd4jLong hash, Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, int numInputs, + Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, + int numOutputs, double *tArgs, int numTArgs, + Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, + bool isInplace) { + if (op == nullptr) + nd4j_printf("Can't find requested operation: [%lld]\n", hash); + + // we're using the same fake nodeId everywhere here + + std::vector inputs(numInputs); + std::vector outputs(numOutputs); + std::vector ttArgs(numTArgs); + std::vector iiArgs(numIArgs); + std::vector biArgs(numBArgs); + + // filling block now with inputs + for (int e = 0; e < numInputs; e++) { + auto shape = reinterpret_cast(inputShapes[e]); + void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : inputBuffers[e]; + + inputs[e] = new sd::NDArray(buffer, shape); + } + + // if not inplace - transferring output arrays + + if (!isInplace) + for (int e = 0; e < numOutputs; e++) { + // we want to keep original output shape intact + auto shape = + shape::copyShape(reinterpret_cast(outputShapes[e])); + void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : outputBuffers[e]; + + // FIXME: revisit this. + bool canNullify = true; + for (int i = 0; i < numInputs; i++) { + void *ibuffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : inputBuffers[i]; + if (ibuffer == buffer) { + canNullify = false; + break; } + } - for (int e = 0; e < numIArgs; e++) - iiArgs[e] = iArgs[e]; + if (canNullify) + memset((uint8_t *)buffer, '\0', + shape::length(shape) * + DataTypeUtils::sizeOfElement(ArrayOptions::dataType(shape))); + auto array = new sd::NDArray(buffer, shape); + outputs[e] = array; - for (int e = 0; e < numTArgs; e++) - ttArgs[e] = tArgs[e]; + // and we want to release shape copy once we're done + delete[] shape; + } - for (int e = 0; e < numBArgs; e++) - biArgs[e] = bArgs[e]; + for (int e = 0; e < numIArgs; e++) iiArgs[e] = iArgs[e]; - // hypothetically at this point we have everything filled - auto hZ = op->execute(inputs, outputs, ttArgs, iiArgs, biArgs, std::vector(), isInplace); - //auto hZ = op->execute(inputs, ttArgs, iiArgs, isInplace); + for (int e = 0; e < numTArgs; e++) ttArgs[e] = tArgs[e]; + for (int e = 0; e < numBArgs; e++) biArgs[e] = bArgs[e]; + // hypothetically at this point we have everything filled + auto hZ = op->execute(inputs, outputs, ttArgs, iiArgs, biArgs, + std::vector(), isInplace); + // auto hZ = op->execute(inputs, ttArgs, iiArgs, isInplace); - if (!isInplace) - for (int e = 0; e < numOutputs; e++) { - //shape::printShapeInfoLinear("JVM output shape", (int *) outputShapes[e]); - //shape::printShapeInfoLinear("C++ output shape", (int *) outputs[e]->shapeInfo()); - //outputs[e]->printIndexedBuffer("C++ raw output"); - //outputs[e]->printBuffer("C++ indexed output"); + if (!isInplace) + for (int e = 0; e < numOutputs; e++) { + // shape::printShapeInfoLinear("JVM output shape", (int *) + // outputShapes[e]); shape::printShapeInfoLinear("C++ output shape", (int + // *) outputs[e]->shapeInfo()); outputs[e]->printIndexedBuffer("C++ raw + // output"); outputs[e]->printBuffer("C++ indexed output"); - if (outputs[e]->ordering() != shape::order(reinterpret_cast(outputShapes[e]))) - outputs[e]->streamline(shape::order(reinterpret_cast(outputShapes[e]))); - } + if (outputs[e]->ordering() != + shape::order(reinterpret_cast(outputShapes[e]))) + outputs[e]->streamline( + shape::order(reinterpret_cast(outputShapes[e]))); + } - for (auto v: inputs) - delete v; + for (auto v : inputs) delete v; - for (auto v: outputs) - delete v; + for (auto v : outputs) delete v; - return hZ; + return hZ; } - -int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) { - try { - auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); - return realExec(op.get(), extraPointers, hash, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, numOutputs, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; - } +int execCustomOp(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, + int numInputs, Nd4jPointer *outputBuffers, + Nd4jPointer *outputShapes, int numOutputs, double *tArgs, + int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, + int numBArgs, bool isInplace) { + try { + auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); + return realExec(op.get(), extraPointers, hash, inputBuffers, inputShapes, + numInputs, outputBuffers, outputShapes, numOutputs, tArgs, + numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 1; + } } -int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer) { - try { - auto graph = sd::graph::Graph::fromFlatPointer(flatBufferPointer); +int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, + Nd4jPointer flatBufferPointer) { + try { + auto graph = sd::graph::Graph::fromFlatPointer(flatBufferPointer); - //sd::graph::GraphHolder::getInstance()->registerGraph(graphId, graph); + // sd::graph::GraphHolder::getInstance()->registerGraph(graphId, graph); - return ND4J_STATUS_OK; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; - } + return ND4J_STATUS_OK; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 1; + } } -sd::graph::VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) { - return nullptr; +sd::graph::VariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, + Nd4jLong graphId, + Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, + int *inputIndices, int numInputs) { + return nullptr; } -Nd4jLong getVariablesSetSize(sd::graph::VariablesSet* set) { - return set->size(); +Nd4jLong getVariablesSetSize(sd::graph::VariablesSet *set) { + return set->size(); } -Nd4jStatus getVariablesSetStatus(sd::graph::VariablesSet* set) { - return set->status(); +Nd4jStatus getVariablesSetStatus(sd::graph::VariablesSet *set) { + return set->status(); } -sd::graph::Variable* getVariable(sd::graph::VariablesSet* set, Nd4jLong i) { - return set->at(i); +sd::graph::Variable *getVariable(sd::graph::VariablesSet *set, Nd4jLong i) { + return set->at(i); } -int getVariableId(sd::graph::Variable* variable) { - return variable->id(); -} +int getVariableId(sd::graph::Variable *variable) { return variable->id(); } -int getVariableIndex(sd::graph::Variable* variable) { - return variable->index(); +int getVariableIndex(sd::graph::Variable *variable) { + return variable->index(); } -const char* getVariableName(sd::graph::Variable* variable) { - return variable->getName().c_str(); +const char *getVariableName(sd::graph::Variable *variable) { + return variable->getName().c_str(); } -Nd4jLong const* getVariableShape(sd::graph::Variable* variable) { - return const_cast(variable->getNDArray()->shapeInfo()); +Nd4jLong const *getVariableShape(sd::graph::Variable *variable) { + return const_cast(variable->getNDArray()->shapeInfo()); } -void* getVariableBuffer(sd::graph::Variable* variable) { - return variable->getNDArray()->buffer(); +void *getVariableBuffer(sd::graph::Variable *variable) { + return variable->getNDArray()->buffer(); } int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) { - sd::graph::GraphHolder::getInstance()->forgetGraph(graphId); - return Status::OK(); + sd::graph::GraphHolder::getInstance()->forgetGraph(graphId); + return Status::OK(); } void deletePointerArray(Nd4jPointer pointer) { - auto ptr = reinterpret_cast(pointer); - delete[] ptr; + auto ptr = reinterpret_cast(pointer); + delete[] ptr; } void deleteCharArray(Nd4jPointer pointer) { - auto ptr = reinterpret_cast(pointer); - delete[] ptr; + auto ptr = reinterpret_cast(pointer); + delete[] ptr; } void deleteIntArray(Nd4jPointer pointer) { - auto ptr = reinterpret_cast(pointer); - delete[] ptr; + auto ptr = reinterpret_cast(pointer); + delete[] ptr; } void deleteLongArray(Nd4jPointer pointer) { - auto ptr = reinterpret_cast(pointer); - delete[] ptr; + auto ptr = reinterpret_cast(pointer); + delete[] ptr; } -void deleteVariablesSet(sd::graph::VariablesSet* pointer) { - delete pointer; -} +void deleteVariablesSet(sd::graph::VariablesSet *pointer) { delete pointer; } -const char* getAllOperations() { - return sd::OpTracker::getInstance()->exportOperations(); +const char *getAllOperations() { + return sd::OpTracker::getInstance()->exportOperations(); } -Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs) { - return 0; +Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, + Nd4jLong opHash, Nd4jLong *scopes, + int numScopes, Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, int numInputs, + Nd4jPointer *outputBuffers, + Nd4jPointer *outputShapes, int numOutputs) { + return 0; } void deleteResultWrapper(Nd4jPointer ptr) { - // just 0 room for compiler s@!t - auto p = reinterpret_cast(ptr); - delete p; + // just 0 room for compiler s@!t + auto p = reinterpret_cast(ptr); + delete p; } /* * TypeDef: - * void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer hX, long N, int dstType, Nd4jPointer hZ); + * void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer hX, long + * N, int dstType, Nd4jPointer hZ); */ -void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer hX, Nd4jLong N, int dstType, Nd4jPointer hZ) { - auto hx = reinterpret_cast(hX); - auto hz = reinterpret_cast(hZ); - - if (srcType == ND4J_FLOAT8) { - if (dstType == ND4J_FLOAT8) { - // convertGeneric(hx, N, hz); - } else if (dstType == ND4J_INT8) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT8) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT16) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT16) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT16) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT24) { - - } else if (dstType == ND4J_FLOAT32) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_DOUBLE) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else { - //nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_INT8) { - if (dstType == ND4J_FLOAT8) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT8) { - //convertGeneric(hx, N, hz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT16) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT24) { - // TODO: eventually we might want to add it - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_UINT8) { - if (dstType == ND4J_FLOAT8) { - // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT16) { - // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT24) { - // TODO: still might want to add - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_FLOAT16) { - if (dstType == ND4J_FLOAT8) { - // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT16) { -// sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT24) { - // TODO: .... ^^^ - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_THRESHOLD) { - sd::TypeCast::convertToThreshold(nullptr, hx, N, hz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_INT16) { - if (dstType == ND4J_FLOAT8) { - // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT16) { - //sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT16) { -// sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT24) { - // TODO... - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else { - printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_FLOAT24) { - - } else if (srcType == ND4J_FLOAT32) { - if (dstType == ND4J_FLOAT8) { - // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT16) { -// sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT24) { - - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_THRESHOLD) { - sd::TypeCast::convertToThreshold(nullptr, hx, N, hz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_DOUBLE) { - if (dstType == ND4J_FLOAT8) { - // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_UINT16) { -// sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT24) { - - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGeneric(nullptr, hx, N, hz); - } else if (dstType == ND4J_DOUBLE) { - // - } else if (dstType == ND4J_THRESHOLD) { - sd::TypeCast::convertToThreshold(nullptr, hx, N, hz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_THRESHOLD) { - if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertFromThreshold(nullptr, hx, N, hz); - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertFromThreshold(nullptr, hx, N, hz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertFromThreshold(nullptr, hx, N, hz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } +void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer hX, Nd4jLong N, + int dstType, Nd4jPointer hZ) { + auto hx = reinterpret_cast(hX); + auto hz = reinterpret_cast(hZ); + + if (srcType == ND4J_FLOAT8) { + if (dstType == ND4J_FLOAT8) { + // convertGeneric(hx, N, hz); + } else if (dstType == ND4J_INT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, + // hz); + } else if (dstType == ND4J_FLOAT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, + // hz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, + // hz); + } else if (dstType == ND4J_FLOAT24) { + } else if (dstType == ND4J_FLOAT32) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_DOUBLE) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else { + // nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + // dstType); + } + } else if (srcType == ND4J_INT8) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT8) { + // convertGeneric(hx, N, hz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT24) { + // TODO: eventually we might want to add it + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_UINT8) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, + // hz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, + // hz); + } else if (dstType == ND4J_FLOAT24) { + // TODO: still might want to add + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_FLOAT16) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, + // hz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, + // N, hz); + } else if (dstType == ND4J_FLOAT24) { + // TODO: .... ^^^ + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_THRESHOLD) { + sd::TypeCast::convertToThreshold(nullptr, hx, N, hz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_INT16) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, + // hz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, + // N, hz); + } else if (dstType == ND4J_FLOAT24) { + // TODO... + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else { + printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); + } + } else if (srcType == ND4J_FLOAT24) { + } else if (srcType == ND4J_FLOAT32) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, + // N, hz); + } else if (dstType == ND4J_FLOAT24) { + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_THRESHOLD) { + sd::TypeCast::convertToThreshold(nullptr, hx, N, hz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_DOUBLE) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGeneric(nullptr, hx, + // N, hz); + } else if (dstType == ND4J_FLOAT24) { + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGeneric(nullptr, hx, N, hz); + } else if (dstType == ND4J_DOUBLE) { + // + } else if (dstType == ND4J_THRESHOLD) { + sd::TypeCast::convertToThreshold(nullptr, hx, N, hz); } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_THRESHOLD) { + if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertFromThreshold(nullptr, hx, N, hz); + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertFromThreshold(nullptr, hx, N, hz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertFromThreshold(nullptr, hx, N, hz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); } + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } } /* -void fillUtf8String(Nd4jPointer *extraPointers, const char **strings, int numStrings, Nd4jPointer buffer) { - auto hZ = reinterpret_cast(buffer); - for (int e = 0; e < numStrings; e++) { - hZ[e] = reinterpret_cast(createUtf8String(extraPointers, strings[e])); +void fillUtf8String(Nd4jPointer *extraPointers, const char **strings, int +numStrings, Nd4jPointer buffer) { auto hZ = +reinterpret_cast(buffer); for (int e = 0; e < numStrings; e++) +{ hZ[e] = reinterpret_cast(createUtf8String(extraPointers, +strings[e])); } } */ -Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int length) { - auto u = new sd::utf8string(string, length); - return reinterpret_cast(u); +Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, + int length) { + auto u = new sd::utf8string(string, length); + return reinterpret_cast(u); } Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr) { - return reinterpret_cast(ptr)->_length; + return reinterpret_cast(ptr)->_length; } -char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) { - return reinterpret_cast(ptr)->_buffer; +char *getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) { + return reinterpret_cast(ptr)->_buffer; } void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) { - delete(reinterpret_cast(ptr)); + delete (reinterpret_cast(ptr)); } template -static void _scatterUpdate( - Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, - void* hX, const Nd4jLong* hXShapeInfo, const Nd4jLong* hXOffsets, - void* dX, const Nd4jLong* dXShapeInfo, const Nd4jLong* dXOffsets, - void* hY, const Nd4jLong* hYShapeInfo, const Nd4jLong* hYOffsets, - void* dY, const Nd4jLong* dYShapeInfo, const Nd4jLong* dYOffsets, - void* vIindexes, const Nd4jLong* hIndicesShapeInfo, void* dIindexes, const Nd4jLong* dIndicesShapeInfo) { - - auto hIindexes = reinterpret_cast(vIindexes); - auto func = PRAGMA_THREADS_DO { - for (int i = 0; i < numOfSubArrs; ++i) { - int threadIndex = thread_id; - const auto xIndex = hIindexes[i]; - const bool isOwner = xIndex < numThreads ? threadIndex == xIndex : threadIndex == xIndex % numThreads; - - if (!isOwner) - continue; - - NDArray inSubArr(reinterpret_cast(hX) + (hXOffsets[hIindexes[i]] * DataTypeUtils::sizeOf(hXShapeInfo)), hXShapeInfo); - NDArray updSubArr(reinterpret_cast(hY) + (hYOffsets[i] * DataTypeUtils::sizeOf(hXShapeInfo)), hYShapeInfo); - - if (inSubArr.lengthOf() != updSubArr.lengthOf()) { - continue; - } - - switch (opCode) { - case 0: - inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr); - break; - case 1: - inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, inSubArr); - break; - case 2: - inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, inSubArr); - break; - case 3: - inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, inSubArr); - break; - case 4: - inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, inSubArr); - break; - case 5: - inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, inSubArr); - break; - case 6: - inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, inSubArr); - break; - default: - continue; - } - } - }; - - samediff::Threads::parallel_do(func); +static void _scatterUpdate( + Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, void *hX, + const Nd4jLong *hXShapeInfo, const Nd4jLong *hXOffsets, void *dX, + const Nd4jLong *dXShapeInfo, const Nd4jLong *dXOffsets, void *hY, + const Nd4jLong *hYShapeInfo, const Nd4jLong *hYOffsets, void *dY, + const Nd4jLong *dYShapeInfo, const Nd4jLong *dYOffsets, void *vIindexes, + const Nd4jLong *hIndicesShapeInfo, void *dIindexes, + const Nd4jLong *dIndicesShapeInfo) { + auto hIindexes = reinterpret_cast(vIindexes); + auto func = PRAGMA_THREADS_DO { + for (int i = 0; i < numOfSubArrs; ++i) { + int threadIndex = thread_id; + const auto xIndex = hIindexes[i]; + const bool isOwner = xIndex < numThreads + ? threadIndex == xIndex + : threadIndex == xIndex % numThreads; + + if (!isOwner) continue; + + NDArray inSubArr( + reinterpret_cast(hX) + + (hXOffsets[hIindexes[i]] * DataTypeUtils::sizeOf(hXShapeInfo)), + hXShapeInfo); + NDArray updSubArr(reinterpret_cast(hY) + + (hYOffsets[i] * DataTypeUtils::sizeOf(hXShapeInfo)), + hYShapeInfo); + + if (inSubArr.lengthOf() != updSubArr.lengthOf()) { + continue; + } + + switch (opCode) { + case 0: + inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr); + break; + case 1: + inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, + inSubArr); + break; + case 2: + inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, + inSubArr); + break; + case 3: + inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, + inSubArr); + break; + case 4: + inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, + inSubArr); + break; + case 5: + inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, + inSubArr); + break; + case 6: + inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, + inSubArr); + break; + default: + continue; + } + } + }; + + samediff::Threads::parallel_do(func); } //////////////////////////////////////////////////////////////////////// void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, - void* hX, const Nd4jLong* hXShapeInfo, const Nd4jLong* hXOffsets, - void* dX, const Nd4jLong* dXShapeInfo, const Nd4jLong* dXOffsets, - void* hY, const Nd4jLong* hYShapeInfo, const Nd4jLong* hYOffsets, - void* dY, const Nd4jLong* dYShapeInfo, const Nd4jLong* dYOffsets, - void* hIindexes, const Nd4jLong* hIndicesShapeInfo, void* dIindexes, const Nd4jLong* dIndicesShapeInfo) { - auto iType = ArrayOptions::dataType(hIndicesShapeInfo); - - try { - BUILD_SINGLE_SELECTOR(iType, _scatterUpdate, (extraPointers, opCode, numOfSubArrs, hX, hXShapeInfo, hXOffsets, dX, dXShapeInfo, dXOffsets, hY, hYShapeInfo, hYOffsets, dY, dYShapeInfo, dYOffsets, hIindexes, hIndicesShapeInfo, dIindexes, dIndicesShapeInfo), INDEXING_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - - -void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo) { - try { - auto p = reinterpret_cast(debugInfo); - NDArray array(buffer, shapeInfo); - sd::DebugHelper::retrieveDebugStatistics(p, &array); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + void *hX, const Nd4jLong *hXShapeInfo, + const Nd4jLong *hXOffsets, void *dX, + const Nd4jLong *dXShapeInfo, const Nd4jLong *dXOffsets, + void *hY, const Nd4jLong *hYShapeInfo, + const Nd4jLong *hYOffsets, void *dY, + const Nd4jLong *dYShapeInfo, const Nd4jLong *dYOffsets, + void *hIindexes, const Nd4jLong *hIndicesShapeInfo, + void *dIindexes, const Nd4jLong *dIndicesShapeInfo) { + auto iType = ArrayOptions::dataType(hIndicesShapeInfo); + + try { + BUILD_SINGLE_SELECTOR( + iType, _scatterUpdate, + (extraPointers, opCode, numOfSubArrs, hX, hXShapeInfo, hXOffsets, dX, + dXShapeInfo, dXOffsets, hY, hYShapeInfo, hYOffsets, dY, dYShapeInfo, + dYOffsets, hIindexes, hIndicesShapeInfo, dIindexes, dIndicesShapeInfo), + INDEXING_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, + Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, + Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo) { + try { + auto p = reinterpret_cast(debugInfo); + NDArray array(buffer, shapeInfo); + sd::DebugHelper::retrieveDebugStatistics(p, &array); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len) { - try { - auto buf = reinterpret_cast(p); - int cnt = 0; - for (int i = 0; i < len; i++) - cnt += buf[cnt]; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + try { + auto buf = reinterpret_cast(p); + int cnt = 0; + for (int i = 0; i < len; i++) cnt += buf[cnt]; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +sd::ConstantDataBuffer *shapeBuffer(int rank, Nd4jLong *shape, + Nd4jLong *strides, sd::DataType dtype, + char order, Nd4jLong ews, bool empty) { + try { + auto buffer = new ConstantDataBuffer(); + *buffer = sd::ConstantShapeHelper::getInstance()->bufferForShapeInfo( + ShapeDescriptor(dtype, order, shape, strides, rank, ews, empty)); + return buffer; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -sd::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty) { - try { - auto buffer = new ConstantDataBuffer(); - *buffer = sd::ConstantShapeHelper::getInstance()->bufferForShapeInfo( - ShapeDescriptor(dtype, order, shape, strides, rank, ews, empty)); - return buffer; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } -} +void deleteShapeBuffer(sd::ConstantDataBuffer *ptr) { delete ptr; } -void deleteShapeBuffer(sd::ConstantDataBuffer* ptr) { - delete ptr; -} +void deleteTadPack(sd::TadPack *ptr) { delete ptr; } -void deleteTadPack(sd::TadPack* ptr) { - delete ptr; +sd::ConstantDataBuffer *constantBufferLong(sd::DataType dtype, + const Nd4jLong *data, int length) { + return nullptr; } -sd::ConstantDataBuffer* constantBufferLong(sd::DataType dtype, const Nd4jLong *data, int length) { - return nullptr; +sd::ConstantDataBuffer *constantBufferDouble(sd::DataType dtype, double *data, + int length) { + return nullptr; } -sd::ConstantDataBuffer* constantBufferDouble(sd::DataType dtype, double *data, int length) { +sd::ConstantDataBuffer *constantBuffer(sd::DataType dtype, + sd::ConstantDescriptor *descriptor) { + try { + return sd::ConstantHelper::getInstance()->constantBuffer(*descriptor, + dtype); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); return nullptr; + } } -sd::ConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::ConstantDescriptor *descriptor) { - try { - return sd::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } +Nd4jPointer getConstantDataBufferPrimary(sd::ConstantDataBuffer *dbf) { + return dbf->primary(); } - -Nd4jPointer getConstantDataBufferPrimary(sd::ConstantDataBuffer* dbf) { - return dbf->primary(); +Nd4jPointer getConstantDataBufferSpecial(sd::ConstantDataBuffer *dbf) { + return dbf->special(); } -Nd4jPointer getConstantDataBufferSpecial(sd::ConstantDataBuffer* dbf) { - return dbf->special(); +Nd4jLong getConstantDataBufferLength(sd::ConstantDataBuffer *dbf) { + return dbf->length(); } -Nd4jLong getConstantDataBufferLength(sd::ConstantDataBuffer* dbf) { - return dbf->length(); +Nd4jLong getConstantDataBufferSizeOf(sd::ConstantDataBuffer *dbf) { + return dbf->sizeOf(); } -Nd4jLong getConstantDataBufferSizeOf(sd::ConstantDataBuffer* dbf) { - return dbf->sizeOf(); -} - -sd::graph::Context* createGraphContext(int nodeId) { - try { - return new sd::graph::Context(nodeId); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } -} -sd::graph::RandomGenerator* getGraphContextRandomGenerator(sd::graph::Context* ptr) { - return &ptr->randomGenerator(); +sd::graph::Context *createGraphContext(int nodeId) { + try { + return new sd::graph::Context(nodeId); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -void markGraphContextInplace(sd::graph::Context* ptr, bool reallyInplace) { - ptr->markInplace(reallyInplace); +sd::graph::RandomGenerator *getGraphContextRandomGenerator( + sd::graph::Context *ptr) { + return &ptr->randomGenerator(); } -void setGraphContextCudaContext(sd::graph::Context* ptr, void *stream, void *reductionPointer, void *allocationPointer) { +void markGraphContextInplace(sd::graph::Context *ptr, bool reallyInplace) { + ptr->markInplace(reallyInplace); } -void setGraphContextInputArray(sd::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { - ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); +void setGraphContextCudaContext(sd::graph::Context *ptr, void *stream, + void *reductionPointer, + void *allocationPointer) {} +void setGraphContextInputArray(sd::graph::Context *ptr, int index, void *buffer, + void *shapeInfo, void *specialBuffer, + void *specialShapeInfo) { + ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); } -void setGraphContextOutputArray(sd::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { - ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); +void setGraphContextOutputArray(sd::graph::Context *ptr, int index, + void *buffer, void *shapeInfo, + void *specialBuffer, void *specialShapeInfo) { + ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, + specialShapeInfo); } -void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { - ptr->setInputArray(index, buffer, shapeInfo, specialShapeInfo); +void setGraphContextInputBuffer(OpaqueContext *ptr, int index, + OpaqueDataBuffer *buffer, void *shapeInfo, + void *specialShapeInfo) { + ptr->setInputArray(index, buffer, shapeInfo, specialShapeInfo); } -void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { - ptr->setOutputArray(index, buffer, shapeInfo, specialShapeInfo); +void setGraphContextOutputBuffer(OpaqueContext *ptr, int index, + OpaqueDataBuffer *buffer, void *shapeInfo, + void *specialShapeInfo) { + ptr->setOutputArray(index, buffer, shapeInfo, specialShapeInfo); } -void setGraphContextTArguments(sd::graph::Context* ptr, double *arguments, int numberOfArguments) { - ptr->setTArguments(arguments, numberOfArguments); +void setGraphContextTArguments(sd::graph::Context *ptr, double *arguments, + int numberOfArguments) { + ptr->setTArguments(arguments, numberOfArguments); } -void setGraphContextIArguments(sd::graph::Context* ptr, Nd4jLong *arguments, int numberOfArguments) { - ptr->setIArguments(arguments, numberOfArguments); +void setGraphContextIArguments(sd::graph::Context *ptr, Nd4jLong *arguments, + int numberOfArguments) { + ptr->setIArguments(arguments, numberOfArguments); } -void setGraphContextBArguments(sd::graph::Context* ptr, bool *arguments, int numberOfArguments) { - ptr->setBArguments(arguments, numberOfArguments); +void setGraphContextBArguments(sd::graph::Context *ptr, bool *arguments, + int numberOfArguments) { + ptr->setBArguments(arguments, numberOfArguments); } -void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments) { - std::vector dtypes(numberOfArguments); - for (int e = 0; e < numberOfArguments; e++) - dtypes[e] = (sd::DataType) arguments[e]; +void setGraphContextDArguments(OpaqueContext *ptr, int *arguments, + int numberOfArguments) { + std::vector dtypes(numberOfArguments); + for (int e = 0; e < numberOfArguments; e++) + dtypes[e] = (sd::DataType)arguments[e]; - ptr->setDArguments(dtypes); + ptr->setDArguments(dtypes); } -void deleteGraphContext(sd::graph::Context* ptr) { - delete ptr; -} +void deleteGraphContext(sd::graph::Context *ptr) { delete ptr; } -void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) { - ptr->allowHelpers(reallyAllow); +void ctxAllowHelpers(OpaqueContext *ptr, bool reallyAllow) { + ptr->allowHelpers(reallyAllow); } -void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) { - if (execMode < 0 || execMode > 2) - execMode = 0; +void ctxSetExecutionMode(OpaqueContext *ptr, int execMode) { + if (execMode < 0 || execMode > 2) execMode = 0; - ptr->setExecutionMode((samediff::ExecutionMode) execMode); + ptr->setExecutionMode((samediff::ExecutionMode)execMode); } -void ctxPurge(OpaqueContext* ptr) { - ptr->clearFastPath(); -} - -sd::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { - return new sd::graph::RandomGenerator(rootSeed, nodeSeed); -} +void ctxPurge(OpaqueContext *ptr) { ptr->clearFastPath(); } -Nd4jLong getRandomGeneratorRootState(sd::graph::RandomGenerator* ptr) { - return ptr->rootState(); +sd::graph::RandomGenerator *createRandomGenerator(Nd4jLong rootSeed, + Nd4jLong nodeSeed) { + return new sd::graph::RandomGenerator(rootSeed, nodeSeed); } -Nd4jLong getRandomGeneratorNodeState(sd::graph::RandomGenerator* ptr) { - return ptr->nodeState(); +Nd4jLong getRandomGeneratorRootState(sd::graph::RandomGenerator *ptr) { + return ptr->rootState(); } -void setRandomGeneratorStates(sd::graph::RandomGenerator* ptr, Nd4jLong rootSeed, Nd4jLong nodeSeed) { - ptr->setStates(rootSeed, nodeSeed); +Nd4jLong getRandomGeneratorNodeState(sd::graph::RandomGenerator *ptr) { + return ptr->nodeState(); } -int getRandomGeneratorRelativeInt(sd::graph::RandomGenerator* ptr, Nd4jLong index) { - return ptr->relativeInt(index); +void setRandomGeneratorStates(sd::graph::RandomGenerator *ptr, + Nd4jLong rootSeed, Nd4jLong nodeSeed) { + ptr->setStates(rootSeed, nodeSeed); } -Nd4jLong getRandomGeneratorRelativeLong(sd::graph::RandomGenerator* ptr, Nd4jLong index) { - return ptr->relativeLong(index); +int getRandomGeneratorRelativeInt(sd::graph::RandomGenerator *ptr, + Nd4jLong index) { + return ptr->relativeInt(index); } -void deleteRandomGenerator(sd::graph::RandomGenerator* ptr) { - delete ptr; +Nd4jLong getRandomGeneratorRelativeLong(sd::graph::RandomGenerator *ptr, + Nd4jLong index) { + return ptr->relativeLong(index); } +void deleteRandomGenerator(sd::graph::RandomGenerator *ptr) { delete ptr; } int dataTypeFromNpyHeader(void *header) { - return (int) cnpy::dataTypeFromHeader(reinterpret_cast(header)); + return (int)cnpy::dataTypeFromHeader(reinterpret_cast(header)); } Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) { - try { - cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); - unsigned int shapeSize = arr.shape.size(); - std::vector shape(shapeSize); - bool _empty = false; - for (unsigned int i = 0; i < shapeSize; i++) { - shape[i] = arr.shape[i]; - - if (arr.shape[i] == 0) - _empty = true; - } - - auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); - - Nd4jLong *shapeBuffer; - if (shape.size() == 1 && shape[0] == 0) { - // scalar case - shapeBuffer = sd::ShapeBuilders::createScalarShapeInfo(dtype); - } else if (_empty) { - if (shapeSize > 0) - shapeBuffer = sd::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); - else - shapeBuffer = sd::ShapeBuilders::emptyShapeInfo(dtype); - } else { - shapeBuffer = sd::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); - } - return const_cast(sd::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true)); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } -} - -void sortByKey(Nd4jPointer *extraPointers, - void *x, const Nd4jLong *xShapeInfo, - void *dx, const Nd4jLong *dxShapeInfo, - void *y, const Nd4jLong *yShapeInfo, - void *dy, const Nd4jLong *dyShapeInfo, - bool descending) { - try { - auto xType = ArrayOptions::dataType(xShapeInfo); - auto yType = ArrayOptions::dataType(yShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, yType, sd::DoubleMethods, ::sortByKey(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void sortByValue(Nd4jPointer *extraPointers, - void *x, const Nd4jLong *xShapeInfo, - void *dx, const Nd4jLong *dxShapeInfo, - void *y, const Nd4jLong *yShapeInfo, - void *dy, const Nd4jLong *dyShapeInfo, - bool descending) { - try { - auto xType = ArrayOptions::dataType(xShapeInfo); - auto yType = ArrayOptions::dataType(yShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, yType, sd::DoubleMethods, ::sortByValue(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + try { + cnpy::NpyArray arr = + cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); + unsigned int shapeSize = arr.shape.size(); + std::vector shape(shapeSize); + bool _empty = false; + for (unsigned int i = 0; i < shapeSize; i++) { + shape[i] = arr.shape[i]; + + if (arr.shape[i] == 0) _empty = true; + } + + auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); + + Nd4jLong *shapeBuffer; + if (shape.size() == 1 && shape[0] == 0) { + // scalar case + shapeBuffer = sd::ShapeBuilders::createScalarShapeInfo(dtype); + } else if (_empty) { + if (shapeSize > 0) + shapeBuffer = sd::ShapeBuilders::emptyShapeInfo( + dtype, arr.fortranOrder ? 'f' : 'c', shape); + else + shapeBuffer = sd::ShapeBuilders::emptyShapeInfo(dtype); + } else { + shapeBuffer = sd::ShapeBuilders::createShapeInfo( + dtype, arr.fortranOrder ? 'f' : 'c', shape); } + return const_cast( + sd::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, + true)); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -void sortTadByKey(Nd4jPointer *extraPointers, - void *x, const Nd4jLong *xShapeInfo, - void *dx, const Nd4jLong *dxShapeInfo, - void *y, const Nd4jLong *yShapeInfo, - void *dy, const Nd4jLong *dyShapeInfo, - int *dimension, int dimensionLength, - bool descending) { - try { - auto xType = ArrayOptions::dataType(xShapeInfo); - auto yType = ArrayOptions::dataType(yShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, yType, sd::DoubleMethods, ::sortTadByKey(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} +void sortByKey(Nd4jPointer *extraPointers, void *x, const Nd4jLong *xShapeInfo, + void *dx, const Nd4jLong *dxShapeInfo, void *y, + const Nd4jLong *yShapeInfo, void *dy, + const Nd4jLong *dyShapeInfo, bool descending) { + try { + auto xType = ArrayOptions::dataType(xShapeInfo); + auto yType = ArrayOptions::dataType(yShapeInfo); -void sortTadByValue(Nd4jPointer *extraPointers, - void *x, const Nd4jLong *xShapeInfo, - void *dx, const Nd4jLong *dxShapeInfo, - void *y, const Nd4jLong *yShapeInfo, - void *dy, const Nd4jLong *dyShapeInfo, - int *dimension, int dimensionLength, - bool descending) { - try { - auto xType = ArrayOptions::dataType(xShapeInfo); - auto yType = ArrayOptions::dataType(yShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, yType, sd::DoubleMethods, ::sortTadByValue(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + BUILD_DOUBLE_SELECTOR(xType, yType, sd::DoubleMethods, + ::sortByKey(x, xShapeInfo, y, yShapeInfo, descending), + LIBND4J_TYPES, LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortByValue(Nd4jPointer *extraPointers, void *x, + const Nd4jLong *xShapeInfo, void *dx, + const Nd4jLong *dxShapeInfo, void *y, + const Nd4jLong *yShapeInfo, void *dy, + const Nd4jLong *dyShapeInfo, bool descending) { + try { + auto xType = ArrayOptions::dataType(xShapeInfo); + auto yType = ArrayOptions::dataType(yShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, yType, sd::DoubleMethods, + ::sortByValue(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortTadByKey(Nd4jPointer *extraPointers, void *x, + const Nd4jLong *xShapeInfo, void *dx, + const Nd4jLong *dxShapeInfo, void *y, + const Nd4jLong *yShapeInfo, void *dy, + const Nd4jLong *dyShapeInfo, int *dimension, + int dimensionLength, bool descending) { + try { + auto xType = ArrayOptions::dataType(xShapeInfo); + auto yType = ArrayOptions::dataType(yShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, yType, sd::DoubleMethods, + ::sortTadByKey(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, + descending), + LIBND4J_TYPES, LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortTadByValue(Nd4jPointer *extraPointers, void *x, + const Nd4jLong *xShapeInfo, void *dx, + const Nd4jLong *dxShapeInfo, void *y, + const Nd4jLong *yShapeInfo, void *dy, + const Nd4jLong *dyShapeInfo, int *dimension, + int dimensionLength, bool descending) { + try { + auto xType = ArrayOptions::dataType(xShapeInfo); + auto yType = ArrayOptions::dataType(yShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, yType, sd::DoubleMethods, + ::sortTadByValue(x, xShapeInfo, y, yShapeInfo, dimension, + dimensionLength, descending), + LIBND4J_TYPES, LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -const char* runLightBenchmarkSuit(bool printOut) { - try { - sd::LightBenchmarkSuit suit; - auto result = suit.runSuit(); +const char *runLightBenchmarkSuit(bool printOut) { + try { + sd::LightBenchmarkSuit suit; + auto result = suit.runSuit(); - if (printOut) - nd4j_printf("%s\n", result.data()); + if (printOut) nd4j_printf("%s\n", result.data()); - auto chars = new char[result.length() + 1]; - std::memcpy(chars, result.data(), result.length()); - chars[result.length()] = (char) 0x0; + auto chars = new char[result.length() + 1]; + std::memcpy(chars, result.data(), result.length()); + chars[result.length()] = (char)0x0; - return chars; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } + return chars; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } Nd4jLong getCachedMemory(int deviceId) { - return sd::ConstantHelper::getInstance()->getCachedAmount(deviceId); + return sd::ConstantHelper::getInstance()->getCachedAmount(deviceId); } -const char* runFullBenchmarkSuit(bool printOut) { - try { - sd::FullBenchmarkSuit suit; - auto result = suit.runSuit(); +const char *runFullBenchmarkSuit(bool printOut) { + try { + sd::FullBenchmarkSuit suit; + auto result = suit.runSuit(); - if (printOut) - nd4j_printf("%s\n", result.data()); + if (printOut) nd4j_printf("%s\n", result.data()); - auto chars = new char[result.length() + 1]; - std::memcpy(chars, result.data(), result.length()); - chars[result.length()] = (char) 0x0; + auto chars = new char[result.length() + 1]; + std::memcpy(chars, result.data(), result.length()); + chars[result.length()] = (char)0x0; - return chars; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } + return chars; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -sd::LaunchContext* defaultLaunchContext() { - return LaunchContext::defaultContext(); +sd::LaunchContext *defaultLaunchContext() { + return LaunchContext::defaultContext(); } -Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc) { - return nullptr; -} +Nd4jPointer lcScalarPointer(OpaqueLaunchContext *lc) { return nullptr; } -Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc) { - return nullptr; -} +Nd4jPointer lcReductionPointer(OpaqueLaunchContext *lc) { return nullptr; } -Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc) { - return nullptr; -} +Nd4jPointer lcAllocationPointer(OpaqueLaunchContext *lc) { return nullptr; } -Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc) { - return nullptr; -} +Nd4jPointer lcExecutionStream(OpaqueLaunchContext *lc) { return nullptr; } -Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc) { - return nullptr; -} +Nd4jPointer lcCopyStream(OpaqueLaunchContext *lc) { return nullptr; } -Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) { - return nullptr; -} +Nd4jPointer lcBlasHandle(OpaqueLaunchContext *lc) { return nullptr; } -Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) { - return nullptr; -} +Nd4jPointer lcSolverHandle(OpaqueLaunchContext *lc) { return nullptr; } int lastErrorCode() { - return sd::LaunchContext::defaultContext()->errorReference()->errorCode(); + return sd::LaunchContext::defaultContext()->errorReference()->errorCode(); } -const char* lastErrorMessage() { - return sd::LaunchContext::defaultContext()->errorReference()->errorMessage(); +const char *lastErrorMessage() { + return sd::LaunchContext::defaultContext()->errorReference()->errorMessage(); } -void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride) { - ptr->setShapeFunctionOverride(reallyOverride); +void ctxShapeFunctionOverride(OpaqueContext *ptr, bool reallyOverride) { + ptr->setShapeFunctionOverride(reallyOverride); } -int binaryLevel() { +int binaryLevel() { #ifdef CPU_FEATURES #if defined(F_X64) - return 1; -#elif defined (F_AVX2) - return 2; -#elif defined (F_AVX512) - return 3; + return 1; +#elif defined(F_AVX2) + return 2; +#elif defined(F_AVX512) + return 3; #else - return 0; + return 0; #endif #else - return 0; + return 0; #endif } int optimalLevel() { #ifdef CPU_FEATURES - auto features = cpu_features::GetX86Info().features; + auto features = cpu_features::GetX86Info().features; - if (features.avx && features.avx2 && features.avx512f && features.avx512vl && features.avx512bw && features.avx512dq && features.avx512cd) - return 3; - else if (features.avx && features.avx2) - return 2; - else - return 1; + if (features.avx && features.avx2 && features.avx512f && features.avx512vl && + features.avx512bw && features.avx512dq && features.avx512cd) + return 3; + else if (features.avx && features.avx2) + return 2; + else + return 1; #else - return 0; + return 0; #endif } bool isMinimalRequirementsMet() { #ifdef CPU_FEATURES - auto features = cpu_features::GetX86Info().features; + auto features = cpu_features::GetX86Info().features; #if defined(F_X64) - return true; -#elif defined (F_AVX2) - return features.avx && features.avx2; -#elif defined (F_AVX512) - // we're optimizing for skylake-avx512 features, so we'll check those out - return features.avx && features.avx2 && features.avx512f && features.avx512vl && features.avx512bw && features.avx512dq && features.avx512cd; + return true; +#elif defined(F_AVX2) + return features.avx && features.avx2; +#elif defined(F_AVX512) + // we're optimizing for skylake-avx512 features, so we'll check those out + return features.avx && features.avx2 && features.avx512f && + features.avx512vl && features.avx512bw && features.avx512dq && + features.avx512cd; #else - return true; + return true; #endif #else - return true; + return true; #endif } bool isOptimalRequirementsMet() { #ifdef CPU_FEATURES - auto b = ::binaryLevel(); - auto o = ::optimalLevel(); + auto b = ::binaryLevel(); + auto o = ::optimalLevel(); - if (b == o) - return true; - else - return false; -#else + if (b == o) return true; + else + return false; +#else + return true; #endif } -OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { - return allocateDataBuffer(elements, dataType, allocateBoth); +OpaqueDataBuffer *dbAllocateDataBuffer(Nd4jLong elements, int dataType, + bool allocateBoth) { + return allocateDataBuffer(elements, dataType, allocateBoth); } -OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { - try { - auto dtype = DataTypeUtils::fromInt(dataType); - return new sd::InteropDataBuffer(elements * DataTypeUtils::sizeOf(dtype) , dtype, allocateBoth); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } +OpaqueDataBuffer *allocateDataBuffer(Nd4jLong elements, int dataType, + bool allocateBoth) { + try { + auto dtype = DataTypeUtils::fromInt(dataType); + return new sd::InteropDataBuffer(elements * DataTypeUtils::sizeOf(dtype), + dtype, allocateBoth); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer) { - return dataBuffer->primary(); + return dataBuffer->primary(); } Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer) { - return dataBuffer->special(); + return dataBuffer->special(); } -void deleteDataBuffer(OpaqueDataBuffer *dataBuffer) { - delete dataBuffer; -} +void deleteDataBuffer(OpaqueDataBuffer *dataBuffer) { delete dataBuffer; } -OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, Nd4jPointer primary, Nd4jPointer special) { - auto buffer = dbAllocateDataBuffer(0, dataType, false); +OpaqueDataBuffer *dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, + Nd4jPointer primary, + Nd4jPointer special) { + auto buffer = dbAllocateDataBuffer(0, dataType, false); - if (primary != nullptr) - buffer->setPrimary(primary, elements); + if (primary != nullptr) buffer->setPrimary(primary, elements); - if (special != nullptr) - buffer->setSpecial(special, elements); + if (special != nullptr) buffer->setSpecial(special, elements); - return buffer; + return buffer; } -void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes) { - dataBuffer->setPrimary(primaryBuffer, numBytes); +void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, + Nd4jLong numBytes) { + dataBuffer->setPrimary(primaryBuffer, numBytes); } -void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes) { - dataBuffer->setSpecial(specialBuffer, numBytes); +void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, + Nd4jLong numBytes) { + dataBuffer->setSpecial(specialBuffer, numBytes); } void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->allocatePrimary(); + dataBuffer->dataBuffer()->allocatePrimary(); } void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->allocateSpecial(); + dataBuffer->dataBuffer()->allocateSpecial(); } void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { - try { - dataBuffer->dataBuffer()->expand(elements * DataTypeUtils::sizeOf(dataBuffer->dataBuffer()->getDataType())); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + try { + dataBuffer->dataBuffer()->expand( + elements * + DataTypeUtils::sizeOf(dataBuffer->dataBuffer()->getDataType())); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset) { - return new InteropDataBuffer(*dataBuffer, length, offset); +OpaqueDataBuffer *dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, + Nd4jLong offset) { + return new InteropDataBuffer(*dataBuffer, length, offset); } void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->syncToSpecial(); + dataBuffer->dataBuffer()->syncToSpecial(); } void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->syncToPrimary(nullptr); + dataBuffer->dataBuffer()->syncToPrimary(nullptr); } void dbTickHostRead(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->readPrimary(); + dataBuffer->dataBuffer()->readPrimary(); } void dbTickHostWrite(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->writePrimary(); + dataBuffer->dataBuffer()->writePrimary(); } void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->readSpecial(); + dataBuffer->dataBuffer()->readSpecial(); } void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->writeSpecial(); + dataBuffer->dataBuffer()->writeSpecial(); } void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { - dataBuffer->expand(elements); + dataBuffer->expand(elements); } -int dbLocality(OpaqueDataBuffer *dataBuffer) { - return 0; -} +int dbLocality(OpaqueDataBuffer *dataBuffer) { return 0; } void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId) { - dataBuffer->setDeviceId(deviceId); + dataBuffer->setDeviceId(deviceId); } -int dbDeviceId(OpaqueDataBuffer *dataBuffer) { - return dataBuffer->deviceId(); -} +int dbDeviceId(OpaqueDataBuffer *dataBuffer) { return dataBuffer->deviceId(); } void dbClose(OpaqueDataBuffer *dataBuffer) { - dataBuffer->getDataBuffer()->close(); -} - -BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong const*, void*, Nd4jLong const*, const int, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*), LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong const* , Nd4jPointer*, Nd4jLong const*, Nd4jLong const*, Nd4jLong const*), LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template void shuffleGeneric, (void**, Nd4jLong* const*, void**, Nd4jLong* const*, int, int*, Nd4jLong* const*, Nd4jLong* const*), LIBND4J_TYPES); - - + dataBuffer->getDataBuffer()->close(); +} + +BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, + (void *, Nd4jLong const *, void *, Nd4jLong const *, + const int, Nd4jLong const *, Nd4jLong const *, + Nd4jLong const *, Nd4jLong const *, Nd4jLong const *), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void tearGeneric, + (void *, Nd4jLong const *, Nd4jPointer *, + Nd4jLong const *, Nd4jLong const *, Nd4jLong const *), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void shuffleGeneric, + (void **, Nd4jLong *const *, void **, Nd4jLong *const *, + int, int *, Nd4jLong *const *, Nd4jLong *const *), + LIBND4J_TYPES); diff --git a/libnd4j/include/legacy/cuda/BlasVersionHelper.cu b/libnd4j/include/legacy/cuda/BlasVersionHelper.cu index 04b0e78f1990..d156ac1b9ecd 100644 --- a/libnd4j/include/legacy/cuda/BlasVersionHelper.cu +++ b/libnd4j/include/legacy/cuda/BlasVersionHelper.cu @@ -21,9 +21,9 @@ #include namespace sd { - BlasVersionHelper::BlasVersionHelper() { - _blasMajorVersion = __CUDACC_VER_MAJOR__; - _blasMinorVersion = __CUDACC_VER_MINOR__; - _blasPatchVersion = __CUDACC_VER_BUILD__; - } -} \ No newline at end of file +BlasVersionHelper::BlasVersionHelper() { + _blasMajorVersion = __CUDACC_VER_MAJOR__; + _blasMinorVersion = __CUDACC_VER_MINOR__; + _blasPatchVersion = __CUDACC_VER_BUILD__; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu b/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu index f01daffd7798..22e42410f522 100644 --- a/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu +++ b/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu @@ -14,440 +14,482 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -#include -#include -#include -#include +#include #include -#include +#include +#include #include +#include +#include #include -#include +#include #include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include -#include #include #include -#include -#include -#include -#include -#include -#include -#include +#include #include -#include +#include +#include +#include +#include +#include #include #include #include +#include +#include +#include +#include +#include +#include +#include +#include using namespace sd; /** -* This is utility kernel, that updates given special buffer with proper values in device memory -*/ -extern "C" __global__ void prepareShapeBuffer(int *dimension, int *maxDimension, Nd4jLong *specialPointer, int rows, sd::DataType dataType) { - Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid > 0) - return; - - dimension[0] = 0; - maxDimension[0] = 1; - - specialPointer[0] = 2; - specialPointer[1] = rows; - specialPointer[2] = 1; - specialPointer[3] = 1; - specialPointer[4] = 1; - specialPointer[5] = 0; - specialPointer[6] = 1; - specialPointer[7] = 99; - - ArrayOptions::setDataType(specialPointer, dataType); - - //printf("special[0]: [%lld]\n", (long long) specialPointer[0]); - //shape::printShapeInfoLinear("prepareShapeBuffer", specialPointer); + * This is utility kernel, that updates given special buffer with proper values + * in device memory + */ +extern "C" __global__ void prepareShapeBuffer(int* dimension, int* maxDimension, + Nd4jLong* specialPointer, + int rows, sd::DataType dataType) { + Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid > 0) return; + + dimension[0] = 0; + maxDimension[0] = 1; + + specialPointer[0] = 2; + specialPointer[1] = rows; + specialPointer[2] = 1; + specialPointer[3] = 1; + specialPointer[4] = 1; + specialPointer[5] = 0; + specialPointer[6] = 1; + specialPointer[7] = 99; + + ArrayOptions::setDataType(specialPointer, dataType); + + // printf("special[0]: [%lld]\n", (long long) specialPointer[0]); + // shape::printShapeInfoLinear("prepareShapeBuffer", specialPointer); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execPairwiseTransform(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams) { - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (xType != zType && yType != zType) - throw std::runtime_error("NativeOpExecutioner::execPairwiseTransform requires Z operand to have either X or Y type"); - if (lc == nullptr) - throw std::runtime_error("NativeOpExecutioner::execPairwiseTransform: launch context cannot be nullptr !"); - if (stream == nullptr) - throw std::runtime_error("NativeOpExecutioner::execPairwiseTransform: CUDA stream cannot be nullptr !"); - - dim3 launchDims(256, 1024, 8192); +void NativeOpExecutioner::execPairwiseTransform( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, void* extraParams) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (xType != zType && yType != zType) + throw std::runtime_error( + "NativeOpExecutioner::execPairwiseTransform requires Z operand to have " + "either X or Y type"); + if (lc == nullptr) + throw std::runtime_error( + "NativeOpExecutioner::execPairwiseTransform: launch context cannot be " + "nullptr !"); + if (stream == nullptr) + throw std::runtime_error( + "NativeOpExecutioner::execPairwiseTransform: CUDA stream cannot be " + "nullptr !"); + + dim3 launchDims(256, 1024, 8192); #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), LIBND4J_TYPES, LIBND4J_TYPES) + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, extraParams), + LIBND4J_TYPES, LIBND4J_TYPES) #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::pairwise_transforms::PairWiseTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), LIBND4J_TYPES) + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::pairwise_transforms::PairWiseTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, extraParams), + LIBND4J_TYPES) #endif - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execPairwiseTransform failed", res); + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execPairwiseTransform failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execPairwiseBoolTransform( sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams) { - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!DataTypeUtils::isB(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform wrong Z operand data type", sd::DataType::BOOL, zType); - - if (yType != xType) - throw sd::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform both operands must have same data type", xType, yType); - - dim3 launchDims(256, 1024, 16384); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES) - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execPairwiseBoolTransform failed", res); +void NativeOpExecutioner::execPairwiseBoolTransform( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, void* extraParams) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!DataTypeUtils::isB(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execPairwiseBoolTransform wrong Z operand data " + "type", + sd::DataType::BOOL, zType); + + if (yType != xType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execPairwiseBoolTransform both operands must " + "have same data type", + xType, yType); + + dim3 launchDims(256, 1024, 16384); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, extraParams), + LIBND4J_TYPES, BOOL_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execPairwiseBoolTransform failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execPairwiseIntTransform( sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void * hZ, Nd4jLong const* hZShapeInfo, - void * dZ, Nd4jLong const* dZShapeInfo, - void *extraParams) { - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!DataTypeUtils::isZ(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform wrong Z operand data type", sd::DataType::BOOL, zType); - - if (yType != xType || zType != xType) - throw sd::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform both operands must have same data type", xType, yType); - - dim3 launchDims(256, 1024, 16384); - - BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), INTEGER_TYPES) - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execPairwiseIntTransform failed", res); +void NativeOpExecutioner::execPairwiseIntTransform( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, void* extraParams) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!DataTypeUtils::isZ(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execPairwiseIntTransform wrong Z operand data " + "type", + sd::DataType::BOOL, zType); + + if (yType != xType || zType != xType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execPairwiseIntTransform both operands must have " + "same data type", + xType, yType); + + dim3 launchDims(256, 1024, 16384); + + BUILD_SINGLE_SELECTOR( + xType, functions::pairwise_transforms::PairWiseIntTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, extraParams), + INTEGER_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execPairwiseIntTransform failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - bool biasCorrected) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - dim3 launchDims = dim3(256, 256, 32768); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::execSummaryStatsReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, nullptr, biasCorrected, reductionPointer), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execSummaryStatsScalar failed", res); +void NativeOpExecutioner::execSummaryStatsScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, bool biasCorrected) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + dim3 launchDims = dim3(256, 256, 32768); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::summarystats::SummaryStatsReduce, + ::execSummaryStatsReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, + hZShapeInfo, nullptr, nullptr, + biasCorrected, reductionPointer), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execSummaryStatsScalar failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!DataTypeUtils::isB(zType)) - throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type"); - - if (yType != xType) - throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires both X & Y operands to have same type"); - - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("F3B opNum:[%i]\n", opNum); - - dim3 launchDims(256, 256, 1024); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES) - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execBroadcastBool failed", res); +void NativeOpExecutioner::execBroadcastBool( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, void* extraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!DataTypeUtils::isB(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastBool requires Z operand to have " + "BOOL type"); + + if (yType != xType) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastBool requires both X & Y operands " + "to have same type"); + + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("F3B opNum:[%i]\n", opNum); + + dim3 launchDims(256, 256, 1024); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::broadcast::BroadcastBool, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, BOOL_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execBroadcastBool failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext* lc, const int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo, - void *extraParams) { - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - dim3 launchDims; - - launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock - launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / launchDims.y; // blocksPerGrid - launchDims.z = 1024; // shared memory - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execBroadcastBool failed", res); +void NativeOpExecutioner::execBroadcastBool( + sd::LaunchContext* lc, const int opNum, const void* hX, + const Nd4jLong* hXShapeInfo, const void* dX, const Nd4jLong* dXShapeInfo, + const void* hY, const Nd4jLong* hYShapeInfo, const void* dY, + const Nd4jLong* dYShapeInfo, void* hZ, const Nd4jLong* hZShapeInfo, + void* dZ, const Nd4jLong* dZShapeInfo, void* extraParams) { + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + dim3 launchDims; + + launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock + launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / + launchDims.y; // blocksPerGrid + launchDims.z = 1024; // shared memory + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::broadcast::BroadcastBool, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, extraParams), + LIBND4J_TYPES, BOOL_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execBroadcastBool failed", res); } - -void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void* hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!DataTypeUtils::isB(zType)) - throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type"); - - if (yType != xType) - throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires both X & Y operands to have same type"); - - dim3 launchDims(256, 256, 1024); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES) - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execInverseBroadcastBool failed", res); +void NativeOpExecutioner::execInverseBroadcastBool( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, void* extraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!DataTypeUtils::isB(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastBool requires Z operand to have " + "BOOL type"); + + if (yType != xType) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastBool requires both X & Y operands " + "to have same type"); + + dim3 launchDims(256, 256, 1024); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::broadcast::BroadcastBool, + ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, extraParams, + dimension, dimensionLength, tadOnlyShapeInfo, + tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, BOOL_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execInverseBroadcastBool failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) { - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!DataTypeUtils::isZ(zType)) - throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); - - if (yType != xType || zType != xType) - throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type"); - - dim3 launchDims(256, 256, 1024); - - BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES) - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execBroadcastBool failed", res); +void NativeOpExecutioner::execBroadcastInt( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!DataTypeUtils::isZ(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastInt requires Z operand to have INT " + "type"); + + if (yType != xType || zType != xType) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastInt requires both X & Y operands to " + "have same type"); + + dim3 launchDims(256, 256, 1024); + + BUILD_SINGLE_SELECTOR( + xType, functions::broadcast::BroadcastInt, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ), + INTEGER_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execBroadcastBool failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext* lc, const int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!DataTypeUtils::isZ(zType)) - throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); - - if (yType != xType || zType != xType) - throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type"); - - dim3 launchDims; - - launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock - launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / launchDims.y; // blocksPerGrid - launchDims.z = 1024; // shared memory - - BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo), INTEGER_TYPES) - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execBroadcastBool failed", res); +void NativeOpExecutioner::execBroadcastInt( + sd::LaunchContext* lc, const int opNum, const void* hX, + const Nd4jLong* hXShapeInfo, const void* dX, const Nd4jLong* dXShapeInfo, + const void* hY, const Nd4jLong* hYShapeInfo, const void* dY, + const Nd4jLong* dYShapeInfo, void* hZ, const Nd4jLong* hZShapeInfo, + void* dZ, const Nd4jLong* dZShapeInfo) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!DataTypeUtils::isZ(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastInt requires Z operand to have INT " + "type"); + + if (yType != xType || zType != xType) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastInt requires both X & Y operands to " + "have same type"); + + dim3 launchDims; + + launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock + launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / + launchDims.y; // blocksPerGrid + launchDims.z = 1024; // shared memory + + BUILD_SINGLE_SELECTOR( + xType, functions::broadcast::BroadcastInt, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo), + INTEGER_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execBroadcastBool failed", res); } -void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) { - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - if (!DataTypeUtils::isZ(zType)) - throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); - - if (yType != xType || zType != xType) - throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type"); - - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("F3BI opNum:[%i]\n", opNum); - - dim3 launchDims(256, 256, 1024); - - BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES) - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execInverseBroadcastInt failed", res); +void NativeOpExecutioner::execInverseBroadcastInt( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + + if (!DataTypeUtils::isZ(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastInt requires Z operand to have INT " + "type"); + + if (yType != xType || zType != xType) + throw std::runtime_error( + "NativeOpExecutioner::execBroadcastInt requires both X & Y operands to " + "have same type"); + + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("F3BI opNum:[%i]\n", opNum); + + dim3 launchDims(256, 256, 1024); + + BUILD_SINGLE_SELECTOR( + xType, functions::broadcast::BroadcastInt, + ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + INTEGER_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execInverseBroadcastInt failed", res); } //////////////////////////////////////////////////////////////////////// @@ -463,216 +505,242 @@ void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc, * @param dimension * @param dimensionLength */ -void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) { +void NativeOpExecutioner::execBroadcast( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); - auto stream = lc->getCudaStream(); + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; - - dim3 launchDims(256, 256, 1024); + dim3 launchDims(256, 256, 1024); #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::broadcast::Broadcast, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ), + LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::broadcast::Broadcast, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ), + LIBND4J_TYPES); #endif - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execBroadcast failed", res); + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execBroadcast failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc, const int opNum, - const void *hX, const Nd4jLong *hXShapeInfo, - const void *dX, const Nd4jLong *dXShapeInfo, - const void *hY, const Nd4jLong *hYShapeInfo, - const void *dY, const Nd4jLong *dYShapeInfo, - void *hZ, const Nd4jLong *hZShapeInfo, - void *dZ, const Nd4jLong *dZShapeInfo) { - - auto stream = lc->getCudaStream(); +void NativeOpExecutioner::execBroadcast( + sd::LaunchContext* lc, const int opNum, const void* hX, + const Nd4jLong* hXShapeInfo, const void* dX, const Nd4jLong* dXShapeInfo, + const void* hY, const Nd4jLong* hYShapeInfo, const void* dY, + const Nd4jLong* dYShapeInfo, void* hZ, const Nd4jLong* hZShapeInfo, + void* dZ, const Nd4jLong* dZShapeInfo) { + auto stream = lc->getCudaStream(); - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; - dim3 launchDims; + dim3 launchDims; - launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock - launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / launchDims.y; // blocksPerGrid - launchDims.z = 1024; // shared memory + launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock + launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / + launchDims.y; // blocksPerGrid + launchDims.z = 1024; // shared memory #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::broadcast::Broadcast, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo), + LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::broadcast::Broadcast, + ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo), + LIBND4J_TYPES); #endif - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execBroadcast failed", res); + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execBroadcast failed", res); } -void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadOnlyShapeInfoZ,Nd4jLong const* tadOffsetsZ) { - - auto stream = lc->getCudaStream(); +void NativeOpExecutioner::execInverseBroadcast( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) - return; + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; - dim3 launchDims(256, 256, 1024); + dim3 launchDims(256, 256, 1024); #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::broadcast::Broadcast, + ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::broadcast::Broadcast, + ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, + dYShapeInfo, dZ, dZShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES); #endif - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execInverseBroadcast failed", res); + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execInverseBroadcast failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("SF7 opNum:[%i]\n", opNum); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - auto xRank = shape::rank(hXShapeInfo); - - if (zType != xType) - throw datatype_exception::build("NativeOpExecutioner::execReduceSame requires both X & Z operands to have same type", xType, zType); - - auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 8192); - - BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceSame failed", res); +void NativeOpExecutioner::execReduceSame( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("SF7 opNum:[%i]\n", opNum); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + auto xRank = shape::rank(hXShapeInfo); + + if (zType != xType) + throw datatype_exception::build( + "NativeOpExecutioner::execReduceSame requires both X & Z operands to " + "have same type", + xType, zType); + + auto numBlocks = shape::length(hZShapeInfo); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 8192); + + BUILD_SINGLE_SELECTOR( + xType, functions::reduce::ReduceSameFunction, + ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + dimension, dimensionLength, reductionPointer, tadShapeInfo, + tadOffsets), + LIBND4J_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduceSame failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension,int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("LF7 opNum:[%i]\n", opNum); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (zType != sd::DataType::INT64) - throw datatype_exception::build("NativeOpExecutioner::execReduceLong wrong Z data type", sd::DataType::INT64, zType); - - auto xRank = shape::rank(hXShapeInfo); - auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, LONG_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceLong failed", res); - +void NativeOpExecutioner::execReduceLong( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("LF7 opNum:[%i]\n", opNum); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (zType != sd::DataType::INT64) + throw datatype_exception::build( + "NativeOpExecutioner::execReduceLong wrong Z data type", + sd::DataType::INT64, zType); + + auto xRank = shape::rank(hXShapeInfo); + auto numBlocks = shape::length(hZShapeInfo); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceLongFunction, + ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + dimension, dimensionLength, reductionPointer, tadShapeInfo, + tadOffsets), + LIBND4J_TYPES, LONG_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduceLong failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("BF7 opNum:[%i]\n", opNum); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (zType != sd::DataType::BOOL) - throw std::runtime_error("NativeOpExecutioner::execReduceBool requires Z operand to have BOOL type"); - - auto xRank = shape::rank(hXShapeInfo); - auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, BOOL_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceBool failed", res); +void NativeOpExecutioner::execReduceBool( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("BF7 opNum:[%i]\n", opNum); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (zType != sd::DataType::BOOL) + throw std::runtime_error( + "NativeOpExecutioner::execReduceBool requires Z operand to have BOOL " + "type"); + + auto xRank = shape::rank(hXShapeInfo); + auto numBlocks = shape::length(hZShapeInfo); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceBoolFunction, + ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + dimension, dimensionLength, reductionPointer, tadShapeInfo, + tadOffsets), + LIBND4J_TYPES, BOOL_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduceBool failed", res); } //////////////////////////////////////////////////////////////////////// @@ -687,39 +755,44 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc, * @param dimension * @param dimensionLength */ -void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - auto allocationPointer = lc->getAllocationPointer(); - - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("F2 opNum:[%i]\n", opNum); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); - - if (zType != sd::DataType::INT64 && zType != sd::DataType::INT32) - throw datatype_exception::build("NativeOpExecutioner::execIndexReduce requires Z operand to have INT32/INT64 type", zType); - - auto dz = reinterpret_cast(dZ); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::executeIndexReduce(launchDims, stream, opNum, dX, dXShapeInfo, shape::rank(hXShapeInfo), extraParams, dz, dZShapeInfo, shape::rank(hZShapeInfo), dimension, dimensionLength, 1, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, INDEXING_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execIndexReduce failed", res); +void NativeOpExecutioner::execIndexReduce( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + auto allocationPointer = lc->getAllocationPointer(); + + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("F2 opNum:[%i]\n", opNum); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + auto numBlocks = shape::length(hZShapeInfo); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); + + if (zType != sd::DataType::INT64 && zType != sd::DataType::INT32) + throw datatype_exception::build( + "NativeOpExecutioner::execIndexReduce requires Z operand to have " + "INT32/INT64 type", + zType); + + auto dz = reinterpret_cast(dZ); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::indexreduce::IndexReduce, + ::executeIndexReduce(launchDims, stream, opNum, dX, dXShapeInfo, + shape::rank(hXShapeInfo), extraParams, dz, + dZShapeInfo, shape::rank(hZShapeInfo), dimension, + dimensionLength, 1, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + LIBND4J_TYPES, INDEXING_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execIndexReduce failed", res); } //////////////////////////////////////////////////////////////////////// @@ -732,38 +805,38 @@ void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc, * @param dZ * @param dZShapeInfo */ -void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension,int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("F8 opNum:[%i]\n", opNum); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - auto xRank = shape::rank(hXShapeInfo); - auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceFloat failed", res); +void NativeOpExecutioner::execReduceFloat( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("F8 opNum:[%i]\n", opNum); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + auto xRank = shape::rank(hXShapeInfo); + auto numBlocks = shape::length(hZShapeInfo); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceFloatFunction, + ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + dimension, dimensionLength, reductionPointer, tadShapeInfo, + tadOffsets), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduceFloat failed", res); } - /** * * @param opNum @@ -772,947 +845,1030 @@ void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc, * @param extraParams */ //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo){ - - if (sd::Environment::getInstance()->isDebug()) - printf("F1 opNum:[%i]\n", opNum); - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - auto allocationPointer = lc->getAllocationPointer(); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); - - if (sd::Environment::getInstance()->isDebugAndVerbose() && launchDims.x == 1) - printf("AF1 opNum:[%i]\n", opNum); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - // FIXME: we want Z to be one of integer types - //if (!DataTypeUtils::isZ(zType)) - // throw sd::datatype_exception("NativeOpExecutioner::execIndexReduceScalar requires Z operand to have one of integer types") - if (zType != sd::DataType::INT64 && zType != sd::DataType::INT32) - throw sd::datatype_exception::build("NativeOpExecutioner::execIndexReduceScalar requires Z operand to have INT32/INT64 data type", zType); - - auto dz = reinterpret_cast(dZ); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::executeIndexReduceScalar(launchDims, stream, - opNum, - dX, dXShapeInfo, shape::rank(hXShapeInfo), - extraParams, - dz, dZShapeInfo, 0, - nullptr, 0, - 1, - allocationPointer, reductionPointer, - nullptr, nullptr), LIBND4J_TYPES, INDEXING_TYPES); - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execIndexReduceScalar failed", res); +void NativeOpExecutioner::execIndexReduceScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo) { + if (sd::Environment::getInstance()->isDebug()) + printf("F1 opNum:[%i]\n", opNum); + + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + auto allocationPointer = lc->getAllocationPointer(); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); + + if (sd::Environment::getInstance()->isDebugAndVerbose() && launchDims.x == 1) + printf("AF1 opNum:[%i]\n", opNum); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + // FIXME: we want Z to be one of integer types + // if (!DataTypeUtils::isZ(zType)) + // throw sd::datatype_exception("NativeOpExecutioner::execIndexReduceScalar + // requires Z operand to have one of integer types") + if (zType != sd::DataType::INT64 && zType != sd::DataType::INT32) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execIndexReduceScalar requires Z operand to have " + "INT32/INT64 data type", + zType); + + auto dz = reinterpret_cast(dZ); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::indexreduce::IndexReduce, + ::executeIndexReduceScalar( + launchDims, stream, opNum, dX, dXShapeInfo, shape::rank(hXShapeInfo), + extraParams, dz, dZShapeInfo, 0, nullptr, 0, 1, allocationPointer, + reductionPointer, nullptr, nullptr), + LIBND4J_TYPES, INDEXING_TYPES); + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execIndexReduceScalar failed", res); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceScalar(launchDims, stream, opNum, dX,dXShapeInfo, hXShapeInfo, extraParams, dZ,dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceFloatScalar failed", res); +void NativeOpExecutioner::execReduceFloatScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceFloatFunction, + ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + nullptr, 0, reductionPointer, nullptr), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execReduceFloatScalar failed", res); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (zType != sd::DataType::BOOL) - throw std::runtime_error("NativeOpExecutioner::execReduceBoolScalar requires Z operand to have BOOL type"); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, BOOL_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceBoolScalar failed", res); +void NativeOpExecutioner::execReduceBoolScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (zType != sd::DataType::BOOL) + throw std::runtime_error( + "NativeOpExecutioner::execReduceBoolScalar requires Z operand to have " + "BOOL type"); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceBoolFunction, + ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + nullptr, 0, reductionPointer, nullptr), + LIBND4J_TYPES, BOOL_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduceBoolScalar failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (zType != xType) - throw datatype_exception::build("NativeOpExecutioner::execReduceSameScalar requires both X & Z operands to have same type", xType, zType); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); - - BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceSameScalar failed", res); +void NativeOpExecutioner::execReduceSameScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (zType != xType) + throw datatype_exception::build( + "NativeOpExecutioner::execReduceSameScalar requires both X & Z " + "operands to have same type", + xType, zType); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); + + BUILD_SINGLE_SELECTOR( + xType, functions::reduce::ReduceSameFunction, + ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + nullptr, 0, reductionPointer, nullptr), + LIBND4J_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduceSameScalar failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (zType != sd::DataType::INT64) - throw datatype_exception::build("NativeOpExecutioner::execReduceLongScalar wrong Z data type", sd::DataType::INT64, zType); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, LONG_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduceLongScalar failed", res); +void NativeOpExecutioner::execReduceLongScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (zType != sd::DataType::INT64) + throw datatype_exception::build( + "NativeOpExecutioner::execReduceLongScalar wrong Z data type", + sd::DataType::INT64, zType); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceLongFunction, + ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, + nullptr, 0, reductionPointer, nullptr), + LIBND4J_TYPES, LONG_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduceLongScalar failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformSame(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - - auto xRank = shape::rank(hXShapeInfo); - auto zRank = shape::rank(hZShapeInfo); - auto xType = ArrayOptions::dataType(hXShapeInfo); - auto zType = ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) { - return; - } - - if (xType != zType) { - throw std::runtime_error("NativeOpExecutioner::execTransformSame requires X & Z to have same type"); - } - - dim3 launchDims(512, 512, 16384); - BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execTransformSame failed", res); +void NativeOpExecutioner::execTransformSame( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void* extraParams, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + + auto xRank = shape::rank(hXShapeInfo); + auto zRank = shape::rank(hZShapeInfo); + auto xType = ArrayOptions::dataType(hXShapeInfo); + auto zType = ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) { + return; + } + + if (xType != zType) { + throw std::runtime_error( + "NativeOpExecutioner::execTransformSame requires X & Z to have same " + "type"); + } + + dim3 launchDims(512, 512, 16384); + BUILD_SINGLE_SELECTOR( + xType, functions::transform::TransformSame, + ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, + xRank, extraParams, dZ, dZShapeInfo, zRank, + nullptr, nullptr, nullptr, nullptr), + LIBND4J_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execTransformSame failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformBool(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - - auto xRank = shape::rank(hXShapeInfo); - auto zRank = shape::rank(hZShapeInfo); - auto xType = ArrayOptions::dataType(hXShapeInfo); - auto zType = ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) { - return; - } - - if (!DataTypeUtils::isB(zType)) { - throw std::runtime_error("NativeOpExecutioner::execTransformBool requires Z to have same boolean type"); - } - - dim3 launchDims(512, 512, 16384); - BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, BOOL_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execTransformBool failed", res); +void NativeOpExecutioner::execTransformBool( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void* extraParams, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + + auto xRank = shape::rank(hXShapeInfo); + auto zRank = shape::rank(hZShapeInfo); + auto xType = ArrayOptions::dataType(hXShapeInfo); + auto zType = ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) { + return; + } + + if (!DataTypeUtils::isB(zType)) { + throw std::runtime_error( + "NativeOpExecutioner::execTransformBool requires Z to have same " + "boolean type"); + } + + dim3 launchDims(512, 512, 16384); + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::transform::TransformBool, + ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, + xRank, extraParams, dZ, dZShapeInfo, zRank, + nullptr, nullptr, nullptr, nullptr), + LIBND4J_TYPES, BOOL_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execTransformBool failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool allowParallelism) { - - auto stream = lc->getCudaStream(); - - auto xRank = shape::rank(hXShapeInfo); - auto zRank = shape::rank(hZShapeInfo); - auto xType = ArrayOptions::dataType(hXShapeInfo); - auto zType = ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) - return; - - if (opNum == sd::transform::Assign && shape::order(hXShapeInfo) == shape::order(hZShapeInfo) && shape::order(hXShapeInfo) == 'c' && xType == zType && shape::elementWiseStride(hXShapeInfo) == 1 && shape::elementWiseStride(hZShapeInfo) == 1) { - cudaMemcpyAsync(dZ, dX, shape::length(hXShapeInfo) * sd::DataTypeUtils::sizeOfElement(xType), cudaMemcpyDeviceToDevice, *stream); - } - else { - - dim3 launchDims(512, 512, 2048); - BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES); - } - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execTransformAny failed", res); +void NativeOpExecutioner::execTransformAny( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void* extraParams, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + bool allowParallelism) { + auto stream = lc->getCudaStream(); + + auto xRank = shape::rank(hXShapeInfo); + auto zRank = shape::rank(hZShapeInfo); + auto xType = ArrayOptions::dataType(hXShapeInfo); + auto zType = ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) return; + + if (opNum == sd::transform::Assign && + shape::order(hXShapeInfo) == shape::order(hZShapeInfo) && + shape::order(hXShapeInfo) == 'c' && xType == zType && + shape::elementWiseStride(hXShapeInfo) == 1 && + shape::elementWiseStride(hZShapeInfo) == 1) { + cudaMemcpyAsync( + dZ, dX, + shape::length(hXShapeInfo) * sd::DataTypeUtils::sizeOfElement(xType), + cudaMemcpyDeviceToDevice, *stream); + } else { + dim3 launchDims(512, 512, 2048); + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::transform::TransformAny, + ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, + xRank, extraParams, dZ, dZShapeInfo, zRank, + nullptr, nullptr, nullptr, nullptr), + LIBND4J_TYPES, LIBND4J_TYPES); + } + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execTransformAny failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformStrict(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - - auto xRank = shape::rank(hXShapeInfo); - auto zRank = shape::rank(hZShapeInfo); - auto xType = ArrayOptions::dataType(hXShapeInfo); - auto zType = ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) { - return; - } - - if (xType != zType || !DataTypeUtils::isR(xType)) { - throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType); - } - - dim3 launchDims(512, 512, 16384); - BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execTransformStrict failed", res); +void NativeOpExecutioner::execTransformStrict( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void* extraParams, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + + auto xRank = shape::rank(hXShapeInfo); + auto zRank = shape::rank(hZShapeInfo); + auto xType = ArrayOptions::dataType(hXShapeInfo); + auto zType = ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) { + return; + } + + if (xType != zType || !DataTypeUtils::isR(xType)) { + throw datatype_exception::build( + "NativeOpExecutioner::execTransformStrict requires X & Z to have same " + "floating point type", + xType, zType); + } + + dim3 launchDims(512, 512, 16384); + BUILD_SINGLE_SELECTOR( + xType, functions::transform::TransformStrict, + ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, + xRank, extraParams, dZ, dZShapeInfo, zRank, + nullptr, nullptr, nullptr, nullptr), + FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execTransformStrict failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execTransformFloat(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraParams, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - auto xRank = shape::rank(hXShapeInfo); - auto zRank = shape::rank(hZShapeInfo); - auto xType = ArrayOptions::dataType(hXShapeInfo); - auto zType = ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo)) - return; - - if (!DataTypeUtils::isR(zType)) - throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType); - - dim3 launchDims(512, 512, 2048); - BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execTransformFloat failed", res); +void NativeOpExecutioner::execTransformFloat( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void* extraParams, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + auto xRank = shape::rank(hXShapeInfo); + auto zRank = shape::rank(hZShapeInfo); + auto xType = ArrayOptions::dataType(hXShapeInfo); + auto zType = ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo)) return; + + if (!DataTypeUtils::isR(zType)) + throw datatype_exception::build( + "NativeOpExecutioner::execTransformFloat requires Z to have floating " + "point type", + zType); + + dim3 launchDims(512, 512, 2048); + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::transform::TransformFloat, + ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, + xRank, extraParams, dZ, dZShapeInfo, zRank, + nullptr, nullptr, nullptr, nullptr), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execTransformFloat failed", res); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - bool biasCorrected) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - dim3 launchDims = dim3(256, 256, 32768); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (!DataTypeUtils::isR(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execSummaryStats requires Z operand to have floating point data type", zType); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::execSummaryStatsReduce(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, nullptr, biasCorrected, reductionPointer), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execSummaryStats A failed", res); +void NativeOpExecutioner::execSummaryStats( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, bool biasCorrected) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + dim3 launchDims = dim3(256, 256, 32768); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (!DataTypeUtils::isR(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execSummaryStats requires Z operand to have " + "floating point data type", + zType); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::summarystats::SummaryStatsReduce, + ::execSummaryStatsReduce(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, extraParams, dZ, dZShapeInfo, + hZShapeInfo, nullptr, nullptr, biasCorrected, + reductionPointer), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execSummaryStats A failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - bool biasCorrected) { - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - - dim3 launchDims = dim3(256, 256, 32768); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (!DataTypeUtils::isR(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execSummaryStats requires Z operand to have floating point data type", zType); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::execSummaryStatsReduce(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, biasCorrected, reductionPointer), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execSummaryStats B failed", res); +void NativeOpExecutioner::execSummaryStats( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + bool biasCorrected) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + + dim3 launchDims = dim3(256, 256, 32768); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (!DataTypeUtils::isR(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execSummaryStats requires Z operand to have " + "floating point data type", + zType); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::summarystats::SummaryStatsReduce, + ::execSummaryStatsReduce( + launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, + dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, + tadShapeInfo, tadOffsets, biasCorrected, reductionPointer), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execSummaryStats B failed", res); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo) { - - auto stream = lc->getCudaStream(); - auto reductionPointer = lc->getReductionPointer(); - auto allocationPointer = lc->getAllocationPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(shape::length(hXShapeInfo), blockWidth); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); - - if (xType != yType) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3 requires Y operand to have X type", xType, yType); - - if (!DataTypeUtils::isR(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3 requires Z operand to have floating point data type", zType); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execScalar(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParams, dZ, dZShapeInfo, allocationPointer, reductionPointer, nullptr), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduce3 failed", res); +void NativeOpExecutioner::execReduce3( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void const* hY, Nd4jLong const* hYShapeInfo, + void const* dY, Nd4jLong const* dYShapeInfo, void* hZ, + Nd4jLong const* hZShapeInfo, void* dZ, Nd4jLong const* dZShapeInfo) { + auto stream = lc->getCudaStream(); + auto reductionPointer = lc->getReductionPointer(); + auto allocationPointer = lc->getAllocationPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks( + shape::length(hXShapeInfo), blockWidth); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); + + if (xType != yType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3 requires Y operand to have X type", + xType, yType); + + if (!DataTypeUtils::isR(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3 requires Z operand to have floating " + "point data type", + zType); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::execScalar(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, + extraParams, dZ, dZShapeInfo, allocationPointer, + reductionPointer, nullptr), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduce3 failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - if(shape::isScalar(hZShapeInfo)) { - NativeOpExecutioner::execReduce3(lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); - return; - } - - auto stream = lc->getCudaStream(); - auto allocationPointer = lc->getAllocationPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (xType != yType) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3 requires Y operand to have X type", xType, yType); - - if (!DataTypeUtils::isR(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3 requires Z operand to have floating point data type", zType); - - - auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(launchDims, stream, opNum, - dX, dXShapeInfo, - dY, dYShapeInfo, - extraParams, - dZ, dZShapeInfo, - dimension, dimensionLength, - 1, - allocationPointer, - tadOnlyShapeInfo, tadOffsets, - yTadOnlyShapeInfo, yTadOffsets), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduce3 B failed", res); +void NativeOpExecutioner::execReduce3( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void const* hY, Nd4jLong const* hYShapeInfo, + void const* dY, Nd4jLong const* dYShapeInfo, void* hZ, + Nd4jLong const* hZShapeInfo, void* dZ, Nd4jLong const* dZShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* yTadOnlyShapeInfo, + Nd4jLong const* yTadOffsets) { + if (shape::isScalar(hZShapeInfo)) { + NativeOpExecutioner::execReduce3( + lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, + hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + return; + } + + auto stream = lc->getCudaStream(); + auto allocationPointer = lc->getAllocationPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3 requires Y operand to have X type", + xType, yType); + + if (!DataTypeUtils::isR(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3 requires Z operand to have floating " + "point data type", + zType); + + auto numBlocks = shape::length(hZShapeInfo); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::exec(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, + extraParams, dZ, dZShapeInfo, dimension, dimensionLength, 1, + allocationPointer, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, + yTadOffsets), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduce3 B failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo) { - - - auto stream = lc->getCudaStream(); - auto allocationPointer = lc->getAllocationPointer(); - auto reductionPointer = lc->getReductionPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); - - if (xType != yType) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3Scalar requires Y operand to have X type", xType, yType); - - if (!DataTypeUtils::isR(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3Scalar requires Z operand to have floating point data type", zType); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execScalar(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParams, dZ, dZShapeInfo, allocationPointer, reductionPointer, nullptr), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduce3Scalar failed", res); +void NativeOpExecutioner::execReduce3Scalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void const* hY, Nd4jLong const* hYShapeInfo, + void const* dY, Nd4jLong const* dYShapeInfo, void* hZ, + Nd4jLong const* hZShapeInfo, void* dZ, Nd4jLong const* dZShapeInfo) { + auto stream = lc->getCudaStream(); + auto allocationPointer = lc->getAllocationPointer(); + auto reductionPointer = lc->getReductionPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768); + + if (xType != yType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3Scalar requires Y operand to have X " + "type", + xType, yType); + + if (!DataTypeUtils::isR(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3Scalar requires Z operand to have " + "floating point data type", + zType); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::execScalar(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, + extraParams, dZ, dZShapeInfo, allocationPointer, + reductionPointer, nullptr), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduce3Scalar failed", res); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void const* hScalar, Nd4jLong const* hScalarShapeInfo, - void const* dScalar, Nd4jLong const* dScalarShapeInfo, - void *extraParams, bool allowParallelism) { - - auto stream = lc->getCudaStream(); - - dim3 launchDims = dim3(256, 512, 8192); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; - - if (xType != yType ) - throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type"); - - if (!DataTypeUtils::isB(zType) ) - throw std::runtime_error("NativeOpExecutioner::execScalarBool requires Z operand to have BOOL type"); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalar, extraParams), LIBND4J_TYPES, BOOL_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execScalarBool failed", res); +void NativeOpExecutioner::execScalarBool( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void const* hScalar, + Nd4jLong const* hScalarShapeInfo, void const* dScalar, + Nd4jLong const* dScalarShapeInfo, void* extraParams, + bool allowParallelism) { + auto stream = lc->getCudaStream(); + + dim3 launchDims = dim3(256, 512, 8192); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; + + if (xType != yType) + throw std::runtime_error( + "NativeOpExecutioner::execScalarBool requires X & Y to have same type"); + + if (!DataTypeUtils::isB(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execScalarBool requires Z operand to have BOOL " + "type"); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::scalar::ScalarBoolTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dZ, + dZShapeInfo, dScalar, extraParams), + LIBND4J_TYPES, BOOL_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execScalarBool failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarBool(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void const* hScalars, Nd4jLong const* hScalarShapeInfo, - void const* dScalars, Nd4jLong const* dScalarShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - auto stream = lc->getCudaStream(); - - dim3 launchDims(256, 512, 8192); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; - - if (xType != yType ) - throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type"); - - if (!DataTypeUtils::isB(zType) ) - throw std::runtime_error("NativeOpExecutioner::execScalarBool requires Z operand to have BOOL type"); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execScalarBool B failed", res); +void NativeOpExecutioner::execScalarBool( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void const* hScalars, + Nd4jLong const* hScalarShapeInfo, void const* dScalars, + Nd4jLong const* dScalarShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + dim3 launchDims(256, 512, 8192); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; + + if (xType != yType) + throw std::runtime_error( + "NativeOpExecutioner::execScalarBool requires X & Y to have same type"); + + if (!DataTypeUtils::isB(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execScalarBool requires Z operand to have BOOL " + "type"); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::scalar::ScalarBoolTransform, + ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, + dZ, dZShapeInfo, dScalars, extraParams, + dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, BOOL_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execScalarBool B failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void const* hScalar, Nd4jLong const* hScalarShapeInfo, - void const* dScalar, Nd4jLong const* dScalarShapeInfo, - void *extraParams, bool allowParallelism) { - - auto stream = lc->getCudaStream(); - - dim3 launchDims = dim3(256, 512, 8192); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; - - if (xType != yType || zType != xType) - throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type"); - - if (!DataTypeUtils::isZ(zType) ) - throw std::runtime_error("NativeOpExecutioner::execScalarInt requires Z operand to have INT type"); - - BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalar, extraParams), INTEGER_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execScalarInt failed", res); +void NativeOpExecutioner::execScalarInt( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void const* hScalar, + Nd4jLong const* hScalarShapeInfo, void const* dScalar, + Nd4jLong const* dScalarShapeInfo, void* extraParams, + bool allowParallelism) { + auto stream = lc->getCudaStream(); + + dim3 launchDims = dim3(256, 512, 8192); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; + + if (xType != yType || zType != xType) + throw std::runtime_error( + "NativeOpExecutioner::execScalarInt requires X & Y to have same type"); + + if (!DataTypeUtils::isZ(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execScalarInt requires Z operand to have INT " + "type"); + + BUILD_SINGLE_SELECTOR( + xType, functions::scalar::ScalarIntTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dZ, + dZShapeInfo, dScalar, extraParams), + INTEGER_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execScalarInt failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalarInt(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void const* hScalars, Nd4jLong const* hScalarShapeInfo, - void const* dScalars, Nd4jLong const* dScalarShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - auto stream = lc->getCudaStream(); - - dim3 launchDims(256, 512, 8192); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; - - if (xType != yType || zType != xType) - throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type"); - - if (!DataTypeUtils::isZ(zType) ) - throw std::runtime_error("NativeOpExecutioner::execScalarInt requires Z operand to have INT type"); - - BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), INTEGER_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execScalarInt B failed", res); +void NativeOpExecutioner::execScalarInt( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void const* hScalars, + Nd4jLong const* hScalarShapeInfo, void const* dScalars, + Nd4jLong const* dScalarShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + dim3 launchDims(256, 512, 8192); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; + + if (xType != yType || zType != xType) + throw std::runtime_error( + "NativeOpExecutioner::execScalarInt requires X & Y to have same type"); + + if (!DataTypeUtils::isZ(zType)) + throw std::runtime_error( + "NativeOpExecutioner::execScalarInt requires Z operand to have INT " + "type"); + + BUILD_SINGLE_SELECTOR( + xType, functions::scalar::ScalarIntTransform, + ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, + dZ, dZShapeInfo, dScalars, extraParams, + dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ), + INTEGER_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execScalarInt B failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void* hZ, Nd4jLong const* hZShapeInfo, - void* dZ, Nd4jLong const* dZShapeInfo, - void const* hScalar, Nd4jLong const* hScalarShapeInfo, - void const* dScalar, Nd4jLong const* dScalarShapeInfo, - void *extraParams, bool allowParallelism) { - - auto stream = lc->getCudaStream(); - - dim3 launchDims(256, 512, 8192); +void NativeOpExecutioner::execScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void const* hScalar, + Nd4jLong const* hScalarShapeInfo, void const* dScalar, + Nd4jLong const* dScalarShapeInfo, void* extraParams, + bool allowParallelism) { + auto stream = lc->getCudaStream(); - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + dim3 launchDims(256, 512, 8192); - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::scalar::ScalarTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, + extraParams), + LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, extraParams), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::scalar::ScalarTransform, + ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, + hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, + extraParams), + LIBND4J_TYPES); #endif - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execScalar failed", res); + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execScalar failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execScalar(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void const* hScalars, Nd4jLong const* hScalarShapeInfo, - void const* dScalars, Nd4jLong const* dScalarShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - auto stream = lc->getCudaStream(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) - return; - - dim3 launchDims(256, 256, 16384); +void NativeOpExecutioner::execScalar( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void const* hScalars, + Nd4jLong const* hScalarShapeInfo, void const* dScalars, + Nd4jLong const* dScalarShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) return; + + dim3 launchDims(256, 256, 16384); #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::scalar::ScalarTransform, + ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, + dZ, dZShapeInfo, dScalars, extraParams, + dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::scalar::ScalarTransform, + ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, + dZ, dZShapeInfo, dScalars, extraParams, + dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES); #endif - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execScalar B failed", res); + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execScalar B failed", res); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer stateHost, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraArguments) { - - auto stream = lc->getCudaStream(); - auto sizeOf = sizeof(sd::graph::RandomGenerator); - Nd4jPointer stateDevice; - - cudaError_t res = cudaMalloc(reinterpret_cast(&stateDevice), sizeOf); - checkCudaErrors(cudaStreamSynchronize(*stream)); - checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, *stream)); - - dim3 launchDims = dim3(512, 512, 32768); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - auto rng = reinterpret_cast(stateHost); - - // functions::random::RandomFunction::executeCudaSingle(launchDims, extraPointers, opNum, stateHost, dZ, dZShapeInfo, extraArguments), - BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::executeCudaSingle(launchDims, stream, opNum, stateDevice, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES); - - res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execRandom X failed", res); - - cudaFree(stateDevice); - - rng->rewindH(shape::length(hZShapeInfo)); +void NativeOpExecutioner::execRandom(sd::LaunchContext* lc, int opNum, + Nd4jPointer stateHost, void* hZ, + Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, + void* extraArguments) { + auto stream = lc->getCudaStream(); + auto sizeOf = sizeof(sd::graph::RandomGenerator); + Nd4jPointer stateDevice; + + cudaError_t res = cudaMalloc(reinterpret_cast(&stateDevice), sizeOf); + checkCudaErrors(cudaStreamSynchronize(*stream)); + checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, + cudaMemcpyHostToDevice, *stream)); + + dim3 launchDims = dim3(512, 512, 32768); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + auto rng = reinterpret_cast(stateHost); + + // functions::random::RandomFunction::executeCudaSingle(launchDims, + // extraPointers, opNum, stateHost, dZ, dZShapeInfo, extraArguments), + BUILD_SINGLE_SELECTOR( + zType, functions::random::RandomFunction, + ::executeCudaSingle(launchDims, stream, opNum, stateDevice, dZ, + dZShapeInfo, extraArguments), + FLOAT_TYPES); + + res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execRandom X failed", res); + + cudaFree(stateDevice); + + rng->rewindH(shape::length(hZShapeInfo)); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer stateHost, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraArguments) { - - auto stream = lc->getCudaStream(); - - auto sizeOf = sizeof(sd::graph::RandomGenerator); - Nd4jPointer stateDevice; - - cudaError_t res = cudaMalloc(reinterpret_cast(&stateDevice), sizeOf); - checkCudaErrors(cudaStreamSynchronize(*stream)); - checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, *stream)); - - auto rng = reinterpret_cast(stateHost); - - dim3 launchDims = dim3(512, 512, 32768); - auto xType = sd::ArrayOptions::dataType(hZShapeInfo); - // functions::random::RandomFunction::executeCudaDouble(launchDims, extraPointers, opNum, stateHost, dX, dXShapeInfo, dZ, dZShapeInfo, extraArguments); - BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::executeCudaDouble(launchDims, stream, opNum, stateDevice, dX, dXShapeInfo, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES); - - res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execRandom XY failed", res); - - cudaFree(stateDevice); - - rng->rewindH(shape::length(hZShapeInfo)); +void NativeOpExecutioner::execRandom( + sd::LaunchContext* lc, int opNum, Nd4jPointer stateHost, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* hZ, Nd4jLong const* hZShapeInfo, void* dZ, + Nd4jLong const* dZShapeInfo, void* extraArguments) { + auto stream = lc->getCudaStream(); + + auto sizeOf = sizeof(sd::graph::RandomGenerator); + Nd4jPointer stateDevice; + + cudaError_t res = cudaMalloc(reinterpret_cast(&stateDevice), sizeOf); + checkCudaErrors(cudaStreamSynchronize(*stream)); + checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, + cudaMemcpyHostToDevice, *stream)); + + auto rng = reinterpret_cast(stateHost); + + dim3 launchDims = dim3(512, 512, 32768); + auto xType = sd::ArrayOptions::dataType(hZShapeInfo); + // functions::random::RandomFunction::executeCudaDouble(launchDims, + // extraPointers, opNum, stateHost, dX, dXShapeInfo, dZ, dZShapeInfo, + // extraArguments); + BUILD_SINGLE_SELECTOR( + xType, functions::random::RandomFunction, + ::executeCudaDouble(launchDims, stream, opNum, stateDevice, dX, + dXShapeInfo, dZ, dZShapeInfo, extraArguments), + FLOAT_TYPES); + + res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execRandom XY failed", res); + + cudaFree(stateDevice); + + rng->rewindH(shape::length(hZShapeInfo)); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execRandom(sd::LaunchContext *lc, - int opNum, - Nd4jPointer stateHost, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - void *extraArguments) { - - auto stream = lc->getCudaStream(); - auto sizeOf = sizeof(sd::graph::RandomGenerator); - Nd4jPointer stateDevice; - - cudaError_t res = cudaMalloc(reinterpret_cast(&stateDevice), sizeOf); - checkCudaErrors(cudaStreamSynchronize(*stream)); - checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, *stream)); - - auto rng = reinterpret_cast(stateHost); - - dim3 launchDims = dim3(512, 512, 32768); - auto xType = sd::ArrayOptions::dataType(hZShapeInfo); - // functions::random::RandomFunction::executeCudaTriple(launchDims, extraPointers, opNum, stateHost, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraArguments); - BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::executeCudaTriple(launchDims, stream, opNum, stateDevice, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES); - - res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execRandom XYZ failed", res); - - cudaFree(stateDevice); - - rng->rewindH(shape::length(hZShapeInfo)); +void NativeOpExecutioner::execRandom( + sd::LaunchContext* lc, int opNum, Nd4jPointer stateHost, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void const* hY, Nd4jLong const* hYShapeInfo, void const* dY, + Nd4jLong const* dYShapeInfo, void* hZ, Nd4jLong const* hZShapeInfo, + void* dZ, Nd4jLong const* dZShapeInfo, void* extraArguments) { + auto stream = lc->getCudaStream(); + auto sizeOf = sizeof(sd::graph::RandomGenerator); + Nd4jPointer stateDevice; + + cudaError_t res = cudaMalloc(reinterpret_cast(&stateDevice), sizeOf); + checkCudaErrors(cudaStreamSynchronize(*stream)); + checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, + cudaMemcpyHostToDevice, *stream)); + + auto rng = reinterpret_cast(stateHost); + + dim3 launchDims = dim3(512, 512, 32768); + auto xType = sd::ArrayOptions::dataType(hZShapeInfo); + // functions::random::RandomFunction::executeCudaTriple(launchDims, + // extraPointers, opNum, stateHost, dX, dXShapeInfo, dY, dYShapeInfo, dZ, + // dZShapeInfo, extraArguments); + BUILD_SINGLE_SELECTOR( + xType, functions::random::RandomFunction, + ::executeCudaTriple(launchDims, stream, opNum, stateDevice, dX, + dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, + extraArguments), + FLOAT_TYPES); + + res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execRandom XYZ failed", res); + + cudaFree(stateDevice); + + rng->rewindH(shape::length(hZShapeInfo)); } //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParamsVals, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* xTadShapeInfo, Nd4jLong const* xOffsets, - Nd4jLong const* yTadShapeInfo, Nd4jLong const* yOffsets) { - - auto stream = lc->getCudaStream(); - auto allocationPointer = lc->getAllocationPointer(); - auto reductionPointer = lc->getReductionPointer(); - - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("D119 opNum:[%i]\n", opNum); - - dim3 launchDims(shape::length(hZShapeInfo), 256, 32768); - - if (sd::Environment::getInstance()->isVerbose() && launchDims.x == 1) - printf("AD119 opNum:[%i]\n", opNum); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (yType != xType) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3All both operands must have same data type", xType, yType); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParamsVals, dZ, dZShapeInfo, dimension, dimensionLength, 1, allocationPointer, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduce3All failed", res); +void NativeOpExecutioner::execReduce3All( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParamsVals, void const* hY, Nd4jLong const* hYShapeInfo, + void const* dY, Nd4jLong const* dYShapeInfo, void* hZ, + Nd4jLong const* hZShapeInfo, void* dZ, Nd4jLong const* dZShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* xTadShapeInfo, + Nd4jLong const* xOffsets, Nd4jLong const* yTadShapeInfo, + Nd4jLong const* yOffsets) { + auto stream = lc->getCudaStream(); + auto allocationPointer = lc->getAllocationPointer(); + auto reductionPointer = lc->getReductionPointer(); + + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("D119 opNum:[%i]\n", opNum); + + dim3 launchDims(shape::length(hZShapeInfo), 256, 32768); + + if (sd::Environment::getInstance()->isVerbose() && launchDims.x == 1) + printf("AD119 opNum:[%i]\n", opNum); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (yType != xType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3All both operands must have same data " + "type", + xType, yType); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::execAll(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, + extraParamsVals, dZ, dZShapeInfo, dimension, dimensionLength, 1, + allocationPointer, xTadShapeInfo, xOffsets, yTadShapeInfo, + yOffsets), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduce3All failed", res); } - //////////////////////////////////////////////////////////////////////// -void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc, - int opNum, - void const* hX, Nd4jLong const* hXShapeInfo, - void const* dX, Nd4jLong const* dXShapeInfo, - void *extraParams, - void const* hY, Nd4jLong const* hYShapeInfo, - void const* dY, Nd4jLong const* dYShapeInfo, - void *hZ, Nd4jLong const* hZShapeInfo, - void *dZ, Nd4jLong const* dZShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadShapeInfo, Nd4jLong const* yTadOffsets) { - - if(shape::isScalar(hZShapeInfo)) { - NativeOpExecutioner::execReduce3(lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); - return; - } - - auto stream = lc->getCudaStream(); - auto allocationPointer = lc->getAllocationPointer(); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (xType != yType) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3TAD requires Y operand to have X type", xType, yType); - - if (!DataTypeUtils::isR(zType)) - throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3TAD requires Z operand to have floating point data type", zType); - - auto numBlocks = shape::length(hZShapeInfo); - dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParams, dZ, dZShapeInfo, dimension, dimensionLength, 1, allocationPointer, tadShapeInfo, tadOffsets, yTadShapeInfo, yTadOffsets), LIBND4J_TYPES, FLOAT_TYPES); - - // TODO: remove after the release - auto res = cudaStreamSynchronize(*stream); - if (res != 0) - throw cuda_exception::build("execReduce3TAD failed", res); +void NativeOpExecutioner::execReduce3TAD( + sd::LaunchContext* lc, int opNum, void const* hX, + Nd4jLong const* hXShapeInfo, void const* dX, Nd4jLong const* dXShapeInfo, + void* extraParams, void const* hY, Nd4jLong const* hYShapeInfo, + void const* dY, Nd4jLong const* dYShapeInfo, void* hZ, + Nd4jLong const* hZShapeInfo, void* dZ, Nd4jLong const* dZShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* yTadShapeInfo, + Nd4jLong const* yTadOffsets) { + if (shape::isScalar(hZShapeInfo)) { + NativeOpExecutioner::execReduce3( + lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, + hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + return; + } + + auto stream = lc->getCudaStream(); + auto allocationPointer = lc->getAllocationPointer(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3TAD requires Y operand to have X type", + xType, yType); + + if (!DataTypeUtils::isR(zType)) + throw sd::datatype_exception::build( + "NativeOpExecutioner::execReduce3TAD requires Z operand to have " + "floating point data type", + zType); + + auto numBlocks = shape::length(hZShapeInfo); + dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce3::Reduce3, + ::exec(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, + extraParams, dZ, dZShapeInfo, dimension, dimensionLength, 1, + allocationPointer, tadShapeInfo, tadOffsets, yTadShapeInfo, + yTadOffsets), + LIBND4J_TYPES, FLOAT_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) throw cuda_exception::build("execReduce3TAD failed", res); } - diff --git a/libnd4j/include/legacy/cuda/NativeOps.cu b/libnd4j/include/legacy/cuda/NativeOps.cu old mode 100755 new mode 100644 index 907dc4ccb177..81cc8f9fa16a --- a/libnd4j/include/legacy/cuda/NativeOps.cu +++ b/libnd4j/include/legacy/cuda/NativeOps.cu @@ -14,32 +14,25 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include -#include - -#include - - -#include #include #include #include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include +#include #include -#include - +#include +#include //#include @@ -69,268 +62,295 @@ int minThreads = 32; __constant__ char deviceConstantMemory[49152]; - // this method just does type conversion in fancy way int getDeviceId(Nd4jPointer ptrToDeviceId) { - return (int)(Nd4jLong)ptrToDeviceId; + return (int)(Nd4jLong)ptrToDeviceId; } /* * Basic CUDA constants here: number of blocks per MP */ int getDeviceBlockThreshold(int deviceId) { - int ccMinor = deviceProperties[deviceId].minor; - int ccMajor = deviceProperties[deviceId].major; + int ccMinor = deviceProperties[deviceId].minor; + int ccMajor = deviceProperties[deviceId].major; - int blockThreshold = 8; + int blockThreshold = 8; - if (ccMajor >= 5) - blockThreshold = 32; - else if (ccMajor == 3) - blockThreshold = 16; - else if (ccMajor < 3) - blockThreshold = 8; + if (ccMajor >= 5) + blockThreshold = 32; + else if (ccMajor == 3) + blockThreshold = 16; + else if (ccMajor < 3) + blockThreshold = 8; - return blockThreshold; + return blockThreshold; } - /* - * This message returns shared memory threshold value. default overflow ratio is 0.3 + * This message returns shared memory threshold value. default overflow ratio is + * 0.3 */ int getDeviceSharedThreshold(int deviceId) { - int ccMinor = deviceProperties[deviceId].minor; - int ccMajor = deviceProperties[deviceId].major; + int ccMinor = deviceProperties[deviceId].minor; + int ccMajor = deviceProperties[deviceId].major; - // please note threshold isn't multiple of 32, and that's NOT a mistake + // please note threshold isn't multiple of 32, and that's NOT a mistake - int shmemThreshold; - if (ccMajor == 6 && ccMinor == 0) - shmemThreshold = 65536; - else if (ccMajor == 6 && ccMinor == 1) - shmemThreshold = 49152; - else if (ccMajor == 5 && ccMinor == 2) - shmemThreshold = 98304; - else if (ccMajor == 5) - shmemThreshold = 65536; - else if (ccMajor == 3 && ccMinor == 7) - shmemThreshold = 114688; - else shmemThreshold = 49152; + int shmemThreshold; + if (ccMajor == 6 && ccMinor == 0) + shmemThreshold = 65536; + else if (ccMajor == 6 && ccMinor == 1) + shmemThreshold = 49152; + else if (ccMajor == 5 && ccMinor == 2) + shmemThreshold = 98304; + else if (ccMajor == 5) + shmemThreshold = 65536; + else if (ccMajor == 3 && ccMinor == 7) + shmemThreshold = 114688; + else + shmemThreshold = 49152; - return shmemThreshold / 0.3; + return shmemThreshold / 0.3; } - - -sd::buffer::Buffer * createScalarBuffer(cudaStream_t stream) { - auto scalarShapeInfo = shape::createScalarShapeInfo(); - auto buff = sd::buffer::createBuffer(scalarShapeInfo,shape::shapeInfoLength(2), stream); - sd::buffer::copyDataToGpu(&buff, stream); - return buff; +sd::buffer::Buffer *createScalarBuffer(cudaStream_t stream) { + auto scalarShapeInfo = shape::createScalarShapeInfo(); + auto buff = sd::buffer::createBuffer(scalarShapeInfo, + shape::shapeInfoLength(2), stream); + sd::buffer::copyDataToGpu(&buff, stream); + return buff; } - class ScalarShapeInformation { -private: - sd::buffer::Buffer *scalarDimension; - sd::buffer::Buffer *scalarShapeInfo; -// std::thread::id threadId; - -public: - ScalarShapeInformation(cudaStream_t stream) { - auto scalarDimensionBuff = reinterpret_cast(malloc(sizeof(Nd4jLong))); - - CHECK_ALLOC(scalarDimensionBuff, "Failed to allocate ShapeInfoBuffer", sizeof(Nd4jLong)); + private: + sd::buffer::Buffer *scalarDimension; + sd::buffer::Buffer *scalarShapeInfo; + // std::thread::id threadId; - scalarDimensionBuff[0] = MAX_DIMENSION; - scalarDimension = sd::buffer::createBuffer(scalarDimensionBuff,1, stream); - scalarShapeInfo = createScalarBuffer(stream); -// threadId = std::this_thread::get_id(); + public: + ScalarShapeInformation(cudaStream_t stream) { + auto scalarDimensionBuff = + reinterpret_cast(malloc(sizeof(Nd4jLong))); - } - ~ScalarShapeInformation() { - sd::buffer::freeBuffer(&scalarShapeInfo); - sd::buffer::freeBuffer(&scalarDimension); - } + CHECK_ALLOC(scalarDimensionBuff, "Failed to allocate ShapeInfoBuffer", + sizeof(Nd4jLong)); + scalarDimensionBuff[0] = MAX_DIMENSION; + scalarDimension = sd::buffer::createBuffer(scalarDimensionBuff, 1, stream); + scalarShapeInfo = createScalarBuffer(stream); + // threadId = std::this_thread::get_id(); + } + ~ScalarShapeInformation() { + sd::buffer::freeBuffer(&scalarShapeInfo); + sd::buffer::freeBuffer(&scalarDimension); + } - Nd4jLong *getShapeInfoHostPointer() { - return scalarShapeInfo->data; - } + Nd4jLong *getShapeInfoHostPointer() { return scalarShapeInfo->data; } - Nd4jLong * getShapeInfoGpuPointer() { - return scalarShapeInfo->gData; - } + Nd4jLong *getShapeInfoGpuPointer() { return scalarShapeInfo->gData; } - Nd4jLong * getDimensionHostPointer() { - return scalarDimension->data; - } - - Nd4jLong * getDimensionGpuPointer() { - return scalarDimension->gData; - } + Nd4jLong *getDimensionHostPointer() { return scalarDimension->data; } + Nd4jLong *getDimensionGpuPointer() { return scalarDimension->gData; } }; - - - - template class ScalarInfo { - sd::buffer::Buffer *scalarData; - ScalarShapeInformation *shapeInfo; - T finalResult; - cudaStream_t streamRef; -public: - ScalarInfo(cudaStream_t stream) { - T *scalarResult = reinterpret_cast(malloc(sizeof(T))); - - CHECK_ALLOC(scalarResult, "Failed to allocate new scalar buffer", sizeof(T)); - - shapeInfo = new ScalarShapeInformation(stream); - scalarData = sd::buffer::createBuffer(scalarResult,1, stream); - streamRef = stream; - sd::buffer::copyDataToGpu(&scalarData, stream); - } - - T getFinalResultFromDevice() { - sd::buffer::copyDataFromGpu(&scalarData, streamRef); - return scalarData->data[0]; - } - - /** - * Get the device shape information - * representing a scalar - */ - Nd4jLong *getDeviceShapeInfo() { - return shapeInfo->getShapeInfoGpuPointer(); - } - - /** - * Get the dZ pointers - */ - T *getDevicePointer() { - return scalarData->gData; - } - - /** - * Get the infinite dimension device pointer - */ - Nd4jLong *getDimensionDevicePointer() { - return shapeInfo->getDimensionGpuPointer(); - } - - ~ScalarInfo() { - sd::buffer::freeBuffer(&scalarData); - delete shapeInfo; - } + sd::buffer::Buffer *scalarData; + ScalarShapeInformation *shapeInfo; + T finalResult; + cudaStream_t streamRef; + + public: + ScalarInfo(cudaStream_t stream) { + T *scalarResult = reinterpret_cast(malloc(sizeof(T))); + + CHECK_ALLOC(scalarResult, "Failed to allocate new scalar buffer", + sizeof(T)); + + shapeInfo = new ScalarShapeInformation(stream); + scalarData = sd::buffer::createBuffer(scalarResult, 1, stream); + streamRef = stream; + sd::buffer::copyDataToGpu(&scalarData, stream); + } + + T getFinalResultFromDevice() { + sd::buffer::copyDataFromGpu(&scalarData, streamRef); + return scalarData->data[0]; + } + + /** + * Get the device shape information + * representing a scalar + */ + Nd4jLong *getDeviceShapeInfo() { return shapeInfo->getShapeInfoGpuPointer(); } + + /** + * Get the dZ pointers + */ + T *getDevicePointer() { return scalarData->gData; } + + /** + * Get the infinite dimension device pointer + */ + Nd4jLong *getDimensionDevicePointer() { + return shapeInfo->getDimensionGpuPointer(); + } + + ~ScalarInfo() { + sd::buffer::freeBuffer(&scalarData); + delete shapeInfo; + } }; -void execPairwiseTransform( Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execPairwiseTransform(&lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), extraParams); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execPairwiseTransform(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbY, + Nd4jLong const *hYShapeInfo, + Nd4jLong const *dYShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execPairwiseTransform( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hYShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execPairwiseTransformBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execPairwiseBoolTransform(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - extraParams); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execPairwiseTransformBool( + Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong const *hYShapeInfo, + Nd4jLong const *dYShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execPairwiseBoolTransform( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hYShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execSummaryStatsScalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - bool biasCorrected) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execSummaryStatsScalar(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - biasCorrected); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execSummaryStatsScalar(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, bool biasCorrected) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execSummaryStatsScalar( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + biasCorrected); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execBroadcastBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); - auto tadOnlyShapeInfo = reinterpret_cast(extraPointers[10]); - auto tadOffsets = reinterpret_cast(extraPointers[11]); - auto tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[12]); - auto tadOffsetsZ = reinterpret_cast(extraPointers[13]); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execBroadcastBool(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - extraParams, - dimension, dimensionLength, - tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execBroadcastBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbY, + Nd4jLong const *hYShapeInfo, Nd4jLong const *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, void *extraParams, + OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); + auto tadOnlyShapeInfo = reinterpret_cast(extraPointers[10]); + auto tadOffsets = reinterpret_cast(extraPointers[11]); + auto tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[12]); + auto tadOffsetsZ = reinterpret_cast(extraPointers[13]); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execBroadcastBool( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hYShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -345,47 +365,58 @@ void execBroadcastBool(Nd4jPointer *extraPointers, * @param dimension * @param dimensionLength */ -void execBroadcast( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - - auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); - auto tadOnlyShapeInfo = reinterpret_cast(extraPointers[10]); - auto tadOffsets = reinterpret_cast(extraPointers[11]); - auto tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[12]); - auto tadOffsetsZ = reinterpret_cast(extraPointers[13]); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execBroadcast(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - dimension, dimensionLength, - tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} +void execBroadcast(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbY, Nd4jLong const *hYShapeInfo, + Nd4jLong const *dYShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + + auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); + auto tadOnlyShapeInfo = reinterpret_cast(extraPointers[10]); + auto tadOffsets = reinterpret_cast(extraPointers[11]); + auto tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[12]); + auto tadOffsetsZ = reinterpret_cast(extraPointers[13]); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execBroadcast( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hYShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} /** * @@ -397,230 +428,283 @@ void execBroadcast( * @param dZShapeInfo */ //////////////////////////////////////////////////////////////////////// -void execReduceFloat(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceFloatScalar(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceFloat(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduceFloatScalar( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduceSame(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceSameScalar(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceSame(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduceSameScalar( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduceSame2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const*hDimensionShape, Nd4jLong const*dDimensionShape) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - dimension, - shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceSame(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - dimension, dimensionLength, - tadPack.specialShapeInfo(), tadPack.specialOffsets()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceSame2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduceSame( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + dimension, dimensionLength, tadPack.specialShapeInfo(), + tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduceLong2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const*hDimensionShape, Nd4jLong const*dDimensionShape) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - dimension, - shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceLong(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - dimension, dimensionLength, - tadPack.specialShapeInfo(), tadPack.specialOffsets()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceLong2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduceLong( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + dimension, dimensionLength, tadPack.specialShapeInfo(), + tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduceLong(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - auto stream = reinterpret_cast(extraPointers[1]); - auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); - auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); - - auto reductionPointer = reinterpret_cast(extraPointers[4]); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (zType != sd::DataType::INT64) - throw datatype_exception::build("execReduceLong wrong Z data type", sd::DataType::INT64, zType); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks, blockWidth, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, - ::execReduceScalar(launchDims, stream, opNum, - dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), hXShapeInfo, - extraParams, - dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), hXShapeInfo, - nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, LONG_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "execReduceLong(...) failed"); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceLong(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + auto stream = reinterpret_cast(extraPointers[1]); + auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); + auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); + + auto reductionPointer = reinterpret_cast(extraPointers[4]); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (zType != sd::DataType::INT64) + throw datatype_exception::build("execReduceLong wrong Z data type", + sd::DataType::INT64, zType); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks, blockWidth, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceLongFunction, + ::execReduceScalar(launchDims, stream, opNum, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + hXShapeInfo, extraParams, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + hXShapeInfo, nullptr, 0, reductionPointer, + dTADShapeInfo), + LIBND4J_TYPES, LONG_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "execReduceLong(...) failed"); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduceBool2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const*hXShapeInfo, Nd4jLong const*dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const*hZShapeInfo, Nd4jLong const*dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const*hDimensionShape, Nd4jLong const*dDimensionShape) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - dimension, - shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceBool(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - dimension, dimensionLength, - tadPack.specialShapeInfo(), tadPack.specialOffsets()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceBool2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduceBool( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + dimension, dimensionLength, tadPack.specialShapeInfo(), + tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduceBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - auto stream = reinterpret_cast(extraPointers[1]); - auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); - auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); - - auto reductionPointer = reinterpret_cast(extraPointers[4]); - - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - - if (zType != sd::DataType::BOOL) - throw std::runtime_error("execReduceBool requires Z operand to have BOOL type"); - - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks, blockWidth, 32768); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, - ::execReduceScalar(launchDims, stream, opNum, - dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), hXShapeInfo, - extraParams, - dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), hZShapeInfo, - nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, BOOL_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "execReduceBool(...) failed"); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + auto stream = reinterpret_cast(extraPointers[1]); + auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); + auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); + + auto reductionPointer = reinterpret_cast(extraPointers[4]); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (zType != sd::DataType::BOOL) + throw std::runtime_error( + "execReduceBool requires Z operand to have BOOL type"); + + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks, blockWidth, 32768); + + BUILD_DOUBLE_SELECTOR( + xType, zType, functions::reduce::ReduceBoolFunction, + ::execReduceScalar(launchDims, stream, opNum, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + hXShapeInfo, extraParams, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + hZShapeInfo, nullptr, 0, reductionPointer, + dTADShapeInfo), + LIBND4J_TYPES, BOOL_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "execReduceBool(...) failed"); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -635,36 +719,43 @@ void execReduceBool(Nd4jPointer *extraPointers, * @param dimensionLength */ //////////////////////////////////////////////////////////////////////// -void execIndexReduce(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - dimension, - shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execIndexReduce(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - (int *) dbDimension->special(), dimensionLength, - tadPack.specialShapeInfo(), tadPack.specialOffsets()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execIndexReduce(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execIndexReduce( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + (int *)dbDimension->special(), dimensionLength, + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -677,36 +768,44 @@ void execIndexReduce(Nd4jPointer *extraPointers, * @param dZShapeInfo */ //////////////////////////////////////////////////////////////////////// -void execReduceFloat2(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - dimension, - shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceFloat(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - dimension, dimensionLength, - tadPack.specialShapeInfo(), tadPack.specialOffsets()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduceFloat2(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduceFloat( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + dimension, dimensionLength, tadPack.specialShapeInfo(), + tadPack.specialOffsets()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } /** @@ -717,287 +816,331 @@ void execReduceFloat2(Nd4jPointer *extraPointers, * @param extraParams */ //////////////////////////////////////////////////////////////////////// -void execIndexReduceScalar( - Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo){ - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execIndexReduceScalar(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execIndexReduceScalar(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execIndexReduceScalar( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execTransformSame(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[0] : nullptr); - auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[1] : nullptr); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformSame(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - extraParams, - tadShapeInfo, tadOffsets); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execTransformSame(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + auto tadShapeInfo = reinterpret_cast( + extraPointers != nullptr ? extraPointers[0] : nullptr); + auto tadOffsets = reinterpret_cast( + extraPointers != nullptr ? extraPointers[1] : nullptr); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execTransformSame( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + extraParams, tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execTransformBool(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[0] : nullptr); - auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[1] : nullptr); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformBool(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - extraParams, - tadShapeInfo, tadOffsets); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execTransformBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + auto tadShapeInfo = reinterpret_cast( + extraPointers != nullptr ? extraPointers[0] : nullptr); + auto tadOffsets = reinterpret_cast( + extraPointers != nullptr ? extraPointers[1] : nullptr); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execTransformBool( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + extraParams, tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execTransformAny(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - auto stream = reinterpret_cast(extraPointers[1]); - auto streamSpecial = reinterpret_cast(extraPointers[4]); - LaunchContext lc(stream, streamSpecial, extraPointers[5], extraPointers[3], - reinterpret_cast(extraPointers[6])); - - NativeOpExecutioner::execTransformAny(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - extraParams, - nullptr, nullptr); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execTransformAny(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + auto stream = reinterpret_cast(extraPointers[1]); + auto streamSpecial = reinterpret_cast(extraPointers[4]); + LaunchContext lc(stream, streamSpecial, extraPointers[5], extraPointers[3], + reinterpret_cast(extraPointers[6])); + + NativeOpExecutioner::execTransformAny( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + extraParams, nullptr, nullptr); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execTransformStrict(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[10] : nullptr); - auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[11] : nullptr); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformStrict(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - extraParams, - tadShapeInfo, tadOffsets); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execTransformStrict(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + auto tadShapeInfo = reinterpret_cast( + extraPointers != nullptr ? extraPointers[10] : nullptr); + auto tadOffsets = reinterpret_cast( + extraPointers != nullptr ? extraPointers[11] : nullptr); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execTransformStrict( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + extraParams, tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execTransformFloat(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[10] : nullptr); - auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[11] : nullptr); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformFloat(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - extraParams, - tadShapeInfo, tadOffsets); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execTransformFloat(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + auto tadShapeInfo = reinterpret_cast( + extraPointers != nullptr ? extraPointers[10] : nullptr); + auto tadOffsets = reinterpret_cast( + extraPointers != nullptr ? extraPointers[11] : nullptr); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execTransformFloat( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + extraParams, tadShapeInfo, tadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } void checkP2P() { - int curDevice = 0; - - cudaGetDevice(&curDevice); + int curDevice = 0; - int devCnt = 0; - cudaGetDeviceCount(&devCnt); + cudaGetDevice(&curDevice); - if (curDevice < 0 && curDevice > devCnt) - curDevice = 0; + int devCnt = 0; + cudaGetDeviceCount(&devCnt); - bool tempSupport = true; + if (curDevice < 0 && curDevice > devCnt) curDevice = 0; - if (devCnt > 1) { - for (int dX = 0; dX < devCnt; dX++) { + bool tempSupport = true; - for (int dY = 0; dY < devCnt; dY++) { - if (dX == dY) - continue; + if (devCnt > 1) { + for (int dX = 0; dX < devCnt; dX++) { + for (int dY = 0; dY < devCnt; dY++) { + if (dX == dY) continue; - int canAccess = 0; - cudaSetDevice(dX); + int canAccess = 0; + cudaSetDevice(dX); - cudaDeviceCanAccessPeer(&canAccess, dX , dY); + cudaDeviceCanAccessPeer(&canAccess, dX, dY); - if (!canAccess) { - tempSupport = false; - break; - } - } - } + if (!canAccess) { + tempSupport = false; + break; + } + } + } - supportedP2P = tempSupport; + supportedP2P = tempSupport; - cudaSetDevice(curDevice); - } else { - // if we have only 1 device - we say that we support P2P, since all data will be on 1 device - supportedP2P = true; - } + cudaSetDevice(curDevice); + } else { + // if we have only 1 device - we say that we support P2P, since all data + // will be on 1 device + supportedP2P = true; + } } void enableP2P(bool enable) { - if (enable == allowedP2P) - return; - - int curDevice = 0; + if (enable == allowedP2P) return; - cudaGetDevice(&curDevice); + int curDevice = 0; - int devCnt = 0; - cudaGetDeviceCount(&devCnt); + cudaGetDevice(&curDevice); - if (curDevice < 0 && curDevice > devCnt) - curDevice = 0; + int devCnt = 0; + cudaGetDeviceCount(&devCnt); - if (devCnt > 1) { - for (int dX = 0; dX < devCnt; dX++) { + if (curDevice < 0 && curDevice > devCnt) curDevice = 0; - for (int dY = 0; dY < devCnt; dY++) { - if (dX == dY) - continue; + if (devCnt > 1) { + for (int dX = 0; dX < devCnt; dX++) { + for (int dY = 0; dY < devCnt; dY++) { + if (dX == dY) continue; - int canAccess = 0; - cudaSetDevice(dX); + int canAccess = 0; + cudaSetDevice(dX); - cudaDeviceCanAccessPeer(&canAccess, dX , dY); + cudaDeviceCanAccessPeer(&canAccess, dX, dY); - if (canAccess) { - if (enable) { - cudaDeviceEnablePeerAccess(dY, 0); - } else { - cudaDeviceDisablePeerAccess(dY); - } - } else { - if (sd::Environment::getInstance()->isVerbose()) printf("Peer access [%i] -> [%i] isn't possible\n", dX, dY); - } - } + if (canAccess) { + if (enable) { + cudaDeviceEnablePeerAccess(dY, 0); + } else { + cudaDeviceDisablePeerAccess(dY); + } + } else { + if (sd::Environment::getInstance()->isVerbose()) + printf("Peer access [%i] -> [%i] isn't possible\n", dX, dY); } - - cudaSetDevice(curDevice); + } } - allowedP2P = enable; - cudaSetDevice(curDevice); -} + } + + allowedP2P = enable; -bool isP2PAvailable() { - return supportedP2P; + cudaSetDevice(curDevice); } +bool isP2PAvailable() { return supportedP2P; } void initializeDevicesAndFunctions() { - try { - int devCnt = 0; - cudaGetDeviceCount(&devCnt); - deviceProperties = new cudaDeviceProp[devCnt]; - for (int i = 0; i < devCnt; i++) { - cudaSetDevice(i); - cudaGetDeviceProperties(&deviceProperties[i], i); - - cudaDeviceSetLimit(cudaLimitStackSize, 4096); - } + try { + int devCnt = 0; + cudaGetDeviceCount(&devCnt); + deviceProperties = new cudaDeviceProp[devCnt]; + for (int i = 0; i < devCnt; i++) { + cudaSetDevice(i); + cudaGetDeviceProperties(&deviceProperties[i], i); - cudaSetDevice(0); + cudaDeviceSetLimit(cudaLimitStackSize, 4096); + } - checkP2P(); + cudaSetDevice(0); - // enabling p2p gpu access if it's supported - if (supportedP2P && devCnt > 1) - enableP2P(allowedP2P); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + checkP2P(); + + // enabling p2p gpu access if it's supported + if (supportedP2P && devCnt > 1) enableP2P(allowedP2P); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } void initializeFunctions(Nd4jPointer *functions) { - sd::BlasHelper::getInstance()->initializeDeviceFunctions(functions); - /* - cublasSgemv = (CublasSgemv)functions[0]; - cublasDgemv = (CublasDgemv)functions[1]; - cublasHgemm = (CublasHgemm)functions[2]; - cublasSgemm = (CublasSgemm)functions[3]; - cublasDgemm = (CublasDgemm)functions[4]; - cublasSgemmEx = (CublasSgemmEx)functions[5]; - cublasHgemmBatched = (CublasHgemmBatched)functions[6]; - cublasSgemmBatched = (CublasSgemmBatched)functions[7]; - cublasDgemmBatched = (CublasDgemmBatched)functions[8]; - */ + sd::BlasHelper::getInstance()->initializeDeviceFunctions(functions); + /* + cublasSgemv = (CublasSgemv)functions[0]; +cublasDgemv = (CublasDgemv)functions[1]; +cublasHgemm = (CublasHgemm)functions[2]; +cublasSgemm = (CublasSgemm)functions[3]; +cublasDgemm = (CublasDgemm)functions[4]; +cublasSgemmEx = (CublasSgemmEx)functions[5]; +cublasHgemmBatched = (CublasHgemmBatched)functions[6]; +cublasSgemmBatched = (CublasSgemmBatched)functions[7]; +cublasDgemmBatched = (CublasDgemmBatched)functions[8]; + */ } - /** * This method acquires memory chunk of requested size on host side * @@ -1006,15 +1149,17 @@ void initializeFunctions(Nd4jPointer *functions) { * @param flags optional parameter */ Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) { - Nd4jPointer pointer; - // cudaHostAllocMapped |cudaHostAllocPortable - auto res = cudaHostAlloc(reinterpret_cast(&pointer), memorySize + 8, cudaHostAllocDefault); - if (res != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaHostAlloc failed"); - } + Nd4jPointer pointer; + // cudaHostAllocMapped |cudaHostAllocPortable + auto res = cudaHostAlloc(reinterpret_cast(&pointer), memorySize + 8, + cudaHostAllocDefault); + if (res != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaHostAlloc failed"); + } - return reinterpret_cast(pointer); + return reinterpret_cast(pointer); } /** @@ -1022,18 +1167,20 @@ Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) { * * @param pointer pointer that'll be used for allocation * @param memorySize memory size, in bytes - * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for OpenCL that's pointer to device_id, etc + * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for + * OpenCL that's pointer to device_id, etc * @param flags optional parameter */ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) { - Nd4jPointer pointer; - auto res = cudaMalloc(reinterpret_cast(&pointer), memorySize + 8); - if (res != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMalloc failed"); - } + Nd4jPointer pointer; + auto res = cudaMalloc(reinterpret_cast(&pointer), memorySize + 8); + if (res != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaMalloc failed"); + } - return reinterpret_cast(pointer); + return reinterpret_cast(pointer); } /** @@ -1042,13 +1189,14 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) { * @param pointer pointer that'll be freed */ int freeHost(Nd4jPointer pointer) { - auto res = cudaFreeHost(reinterpret_cast(pointer)); - if (res != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaFreeHost failed"); - } + auto res = cudaFreeHost(reinterpret_cast(pointer)); + if (res != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaFreeHost failed"); + } - return 1L; + return 1L; } /** @@ -1058,2746 +1206,3007 @@ int freeHost(Nd4jPointer pointer) { * @param ptrToDeviceId pointer to deviceId. */ int freeDevice(Nd4jPointer pointer, int deviceId) { - auto res = cudaFree(reinterpret_cast(pointer)); + auto res = cudaFree(reinterpret_cast(pointer)); - // we're intentionally skipping - if (res != 0 && res != 1) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaFree failed"); - } + // we're intentionally skipping + if (res != 0 && res != 1) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaFree failed"); + } - return res == 0 ? 1L : 0L; + return res == 0 ? 1L : 0L; } - -Nd4jPointer createContext() { - return 0L; -} +Nd4jPointer createContext() { return 0L; } Nd4jPointer createStream() { + auto stream = new cudaStream_t(); + auto dZ = cudaStreamCreate(stream); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaStreamCreate failed"); + } - auto stream = new cudaStream_t(); - auto dZ = cudaStreamCreate(stream); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaStreamCreate failed"); - } - - return stream; + return stream; } Nd4jPointer createEvent() { - Nd4jPointer nativeEvent= (Nd4jPointer) malloc(sizeof(cudaEvent_t)); + Nd4jPointer nativeEvent = (Nd4jPointer)malloc(sizeof(cudaEvent_t)); - CHECK_ALLOC(nativeEvent, "Failed to allocate new CUDA event buffer", sizeof(cudaEvent_t)); + CHECK_ALLOC(nativeEvent, "Failed to allocate new CUDA event buffer", + sizeof(cudaEvent_t)); - auto dZ = cudaEventCreateWithFlags(reinterpret_cast(&nativeEvent), cudaEventDisableTiming); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaEventCreateWithFlags failed"); - } + auto dZ = cudaEventCreateWithFlags( + reinterpret_cast(&nativeEvent), cudaEventDisableTiming); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaEventCreateWithFlags failed"); + } - return nativeEvent; + return nativeEvent; } int registerEvent(Nd4jPointer event, Nd4jPointer stream) { - auto pEvent = reinterpret_cast(&event); - auto pStream = reinterpret_cast(stream); + auto pEvent = reinterpret_cast(&event); + auto pStream = reinterpret_cast(stream); - auto dZ = cudaEventRecord(*pEvent, *pStream); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaEventRecord failed"); - } + auto dZ = cudaEventRecord(*pEvent, *pStream); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaEventRecord failed"); + } - return 1; + return 1; } int setDevice(int deviceId) { - AffinityManager::setCurrentDevice(deviceId); - return 1; + AffinityManager::setCurrentDevice(deviceId); + return 1; } Nd4jLong getDeviceFreeMemoryDefault() { - size_t memFree = 0; - size_t memTotal = 0; + size_t memFree = 0; + size_t memTotal = 0; - cudaMemGetInfo(&memFree, &memTotal); + cudaMemGetInfo(&memFree, &memTotal); - return (Nd4jLong) memFree; + return (Nd4jLong)memFree; } Nd4jLong getDeviceFreeMemory(int device) { - int orig = -1; + int orig = -1; - cudaGetDevice(&orig); + cudaGetDevice(&orig); - if (device >= 0 && device != orig) { - cudaSetDevice(device); - } + if (device >= 0 && device != orig) { + cudaSetDevice(device); + } - size_t memFree = 0; - size_t memTotal = 0; + size_t memFree = 0; + size_t memTotal = 0; - cudaMemGetInfo(&memFree, &memTotal); + cudaMemGetInfo(&memFree, &memTotal); - if (device >= 0 && device != orig) { - cudaSetDevice(orig); - } + if (device >= 0 && device != orig) { + cudaSetDevice(orig); + } - return (Nd4jLong) memFree; + return (Nd4jLong)memFree; } Nd4jLong getDeviceTotalMemory(int device) { - int orig = -1; - - cudaGetDevice(&orig); - - if (device >= 0 && device != orig) { - cudaSetDevice(device); - } - size_t memFree = 0; - size_t memTotal = 0; - - cudaMemGetInfo(&memFree, &memTotal); - - if (device >= 0 && device != orig) { - cudaSetDevice(orig); - } + int orig = -1; + + cudaGetDevice(&orig); + + if (device >= 0 && device != orig) { + cudaSetDevice(device); + } + size_t memFree = 0; + size_t memTotal = 0; + + cudaMemGetInfo(&memFree, &memTotal); + + if (device >= 0 && device != orig) { + cudaSetDevice(orig); + } + + return (Nd4jLong)memTotal; +} + +int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, + Nd4jPointer reserved) { + cudaMemcpyKind kind; + + switch (flags) { + case 0: { + kind = cudaMemcpyHostToHost; + } break; + case 1: { + kind = cudaMemcpyHostToDevice; + } break; + case 2: { + kind = cudaMemcpyDeviceToHost; + } break; + case 3: { + kind = cudaMemcpyDeviceToDevice; + } break; + default: { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "UNDEFNED MEMCPY"); + return 0; + } + } + + auto dZ = cudaMemcpy(reinterpret_cast(dst), + const_cast(reinterpret_cast(src)), + static_cast(size), kind); + if (dZ != 0) { + printf("Failed on [%p] -> [%p], size: [%i], direction: [%i], dZ: [%i]\n", + src, dst, size, flags, static_cast(dZ)); + fflush(stdout); + fflush(stderr); + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaMemcpy failed"); + return 0; + } + + return 1; +} + +int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, + Nd4jPointer reserved) { + auto pStream = reinterpret_cast(reserved); + + cudaMemcpyKind kind; + + // sd::DebugHelper::checkErrorCode(pStream, "Preliminary sync failed"); + + switch (flags) { + case 0: { + kind = cudaMemcpyHostToHost; + } break; + case 1: { + kind = cudaMemcpyHostToDevice; + } break; + case 2: { + kind = cudaMemcpyDeviceToHost; + } break; + case 3: { + kind = cudaMemcpyDeviceToDevice; + } break; + default: { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "UNDEFNED MEMCPY"); + return 0; + } + } + + auto dZ = + cudaMemcpyAsync(reinterpret_cast(dst), + const_cast(reinterpret_cast(src)), + static_cast(size), kind, *pStream); + // auto dZ = cudaMemcpy(reinterpret_cast(dst), const_cast(reinterpret_cast(src)), static_cast(size), kind); + if (dZ != 0) { + printf("Failed on [%p] -> [%p], size: [%i], direction: [%i], dZ: [%i]\n", + src, dst, size, flags, static_cast(dZ)); + fflush(stdout); + fflush(stderr); + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaMemcpyAsync failed"); + return 0; + } - return (Nd4jLong) memTotal; + return 1; } -int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { - cudaMemcpyKind kind; +int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, + Nd4jPointer reserved) { + auto dZ = cudaMemset(reinterpret_cast(dst), value, + static_cast(size)); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaMemset failed"); + } - switch (flags) { - case 0: { - kind = cudaMemcpyHostToHost; - } - break; - case 1: { - kind = cudaMemcpyHostToDevice; - } - break; - case 2: { - kind = cudaMemcpyDeviceToHost; - } - break; - case 3: { - kind = cudaMemcpyDeviceToDevice; - } - break; - default: { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("UNDEFNED MEMCPY"); - return 0; - } - } + return 1; +} - auto dZ = cudaMemcpy(reinterpret_cast(dst), const_cast(reinterpret_cast(src)), static_cast(size), kind); - if (dZ != 0) { - printf("Failed on [%p] -> [%p], size: [%i], direction: [%i], dZ: [%i]\n", src, dst, size, flags, static_cast(dZ)); - fflush(stdout); - fflush(stderr); - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemcpy failed"); - return 0; - } +int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, + Nd4jPointer reserved) { + auto pStream = reinterpret_cast(reserved); - return 1; -} + auto dZ = cudaMemsetAsync(reinterpret_cast(dst), value, + static_cast(size), *pStream); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaMemsetAsync failed"); + } -int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { - auto pStream = reinterpret_cast(reserved); - - cudaMemcpyKind kind; - - //sd::DebugHelper::checkErrorCode(pStream, "Preliminary sync failed"); - - switch (flags) { - case 0: { - kind = cudaMemcpyHostToHost; - } - break; - case 1: { - kind = cudaMemcpyHostToDevice; - } - break; - case 2: { - kind = cudaMemcpyDeviceToHost; - } - break; - case 3: { - kind = cudaMemcpyDeviceToDevice; - } - break; - default: { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("UNDEFNED MEMCPY"); - return 0; - } - } - - auto dZ = cudaMemcpyAsync(reinterpret_cast(dst), const_cast(reinterpret_cast(src)), static_cast(size), kind, *pStream); - //auto dZ = cudaMemcpy(reinterpret_cast(dst), const_cast(reinterpret_cast(src)), static_cast(size), kind); - if (dZ != 0) { - printf("Failed on [%p] -> [%p], size: [%i], direction: [%i], dZ: [%i]\n", src, dst, size, flags, static_cast(dZ)); - fflush(stdout); - fflush(stderr); - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemcpyAsync failed"); - return 0; - } - - return 1; -} - -int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) { - auto dZ = cudaMemset(reinterpret_cast(dst), value, static_cast(size)); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemset failed"); - } - - return 1; -} - -int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) { - auto pStream = reinterpret_cast(reserved); - - auto dZ = cudaMemsetAsync(reinterpret_cast(dst), value, static_cast(size), *pStream); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemsetAsync failed"); - } - - return 1; + return 1; } int destroyEvent(Nd4jPointer event) { - auto pEvent = reinterpret_cast(&event); - auto dZ = cudaEventDestroy(*pEvent); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaEventDestroy failed"); - } + auto pEvent = reinterpret_cast(&event); + auto dZ = cudaEventDestroy(*pEvent); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaEventDestroy failed"); + } - return 1; + return 1; } int streamSynchronize(Nd4jPointer stream) { - auto pStream = reinterpret_cast(stream); + auto pStream = reinterpret_cast(stream); - auto dZ = cudaStreamSynchronize(*pStream); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaStreamSynchronize failed"); - } + auto dZ = cudaStreamSynchronize(*pStream); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaStreamSynchronize failed"); + } - return 1L; + return 1L; } int eventSynchronize(Nd4jPointer event) { - auto pEvent = reinterpret_cast(&event); + auto pEvent = reinterpret_cast(&event); - auto dZ = cudaEventSynchronize(*pEvent); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaEventSynchronize failed"); - } + auto dZ = cudaEventSynchronize(*pEvent); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaEventSynchronize failed"); + } - return 1L; + return 1L; } int getAvailableDevices() { - int devCnt = 0; - cudaGetDeviceCount(&devCnt); - return devCnt; + int devCnt = 0; + cudaGetDeviceCount(&devCnt); + return devCnt; } void enableDebugMode(bool reallyEnable) { - sd::Environment::getInstance()->setDebug(reallyEnable); + sd::Environment::getInstance()->setDebug(reallyEnable); } void setGridLimit(int gridSize) { - if (gridSize > 8192) - gridSize = 8192; - if (gridSize < 1) - gridSize = 1; - blockLimit = gridSize; + if (gridSize > 8192) gridSize = 8192; + if (gridSize < 1) gridSize = 1; + blockLimit = gridSize; } -int ompGetMaxThreads() { - return maxThreads; -} +int ompGetMaxThreads() { return maxThreads; } -int ompGetNumThreads() { - return maxThreads; -} +int ompGetNumThreads() { return maxThreads; } void setOmpNumThreads(int threads) { - if (threads > 1024) - threads = 1024; - if (threads < 32) - threads = 32; - maxThreads = threads; + if (threads > 1024) threads = 1024; + if (threads < 32) threads = 32; + maxThreads = threads; } void enableVerboseMode(bool reallyEnable) { - sd::Environment::getInstance()->setVerbose(reallyEnable); + sd::Environment::getInstance()->setVerbose(reallyEnable); } -int getDeviceMajor(int device) { - return deviceProperties[device].major; -} +int getDeviceMajor(int device) { return deviceProperties[device].major; } -int getDeviceMinor(int device) { - return deviceProperties[device].minor; -} +int getDeviceMinor(int device) { return deviceProperties[device].minor; } +const char *getDeviceName(int device) { return deviceProperties[device].name; } -const char * getDeviceName(int device) { - return deviceProperties[device].name; +void specialConcat(Nd4jPointer *extraPointers, int dimension, int numArrays, + Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *dZ, + Nd4jLong const *dZShapeInfo, Nd4jPointer *tadPointers, + Nd4jPointer *offsetPointers) { + try { + BUILD_SINGLE_SELECTOR(ArrayOptions::dataType(dZShapeInfo), + sd::SpecialMethods, + ::concatCpuGeneric(dimension, numArrays, data, + inputShapeInfo, dZ, dZShapeInfo), + LIBND4J_TYPES); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -void specialConcat( - Nd4jPointer *extraPointers, - int dimension, - int numArrays, - Nd4jPointer *data, - Nd4jPointer *inputShapeInfo, - void *dZ, - Nd4jLong const* dZShapeInfo, Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers) { - try { - BUILD_SINGLE_SELECTOR(ArrayOptions::dataType(dZShapeInfo), sd::SpecialMethods, - ::concatCpuGeneric(dimension, numArrays, data, inputShapeInfo, dZ, dZShapeInfo), - LIBND4J_TYPES); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - - /** * This method saves */ -sd::TadPack* tadOnlyShapeInfo(Nd4jLong const* dXShapeInfo, int *dimension, int dimensionLength) { - try { - auto pack = new TadPack(); - *pack = sd::ConstantTadHelper::getInstance()->tadForDimensions(dXShapeInfo, dimension, dimensionLength); - return pack; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; +sd::TadPack *tadOnlyShapeInfo(Nd4jLong const *dXShapeInfo, int *dimension, + int dimensionLength) { + try { + auto pack = new TadPack(); + *pack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + dXShapeInfo, dimension, dimensionLength); + return pack; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } +} + +Nd4jLong const *getPrimaryShapeInfo(sd::TadPack *pack) { + return pack->primaryShapeInfo(); +} +Nd4jLong const *getPrimaryOffsets(sd::TadPack *pack) { + return pack->primaryOffsets(); +} +Nd4jLong const *getSpecialShapeInfo(sd::TadPack *pack) { + return pack->specialShapeInfo(); +} +Nd4jLong const *getSpecialOffsets(sd::TadPack *pack) { + return pack->specialOffsets(); +} +Nd4jLong getNumberOfTads(sd::TadPack *pack) { return pack->numberOfTads(); } +int getShapeInfoLength(sd::TadPack *pack) { return pack->shapeInfoLength(); } + +int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, + Nd4jPointer reserved) { + cudaStream_t *pStream = reinterpret_cast(reserved); + + cudaMemcpyKind kind; + + DEBUG_KERNEL(pStream, -1); + + switch (flags) { + case 0: { + kind = cudaMemcpyHostToHost; + } break; + case 1: { + kind = cudaMemcpyHostToDevice; + } break; + case 2: { + kind = cudaMemcpyDeviceToHost; } -} - -Nd4jLong const* getPrimaryShapeInfo(sd::TadPack* pack) { - return pack->primaryShapeInfo(); -} -Nd4jLong const* getPrimaryOffsets(sd::TadPack* pack) { - return pack->primaryOffsets(); -} -Nd4jLong const* getSpecialShapeInfo(sd::TadPack* pack) { - return pack->specialShapeInfo(); -} -Nd4jLong const* getSpecialOffsets(sd::TadPack* pack) { - return pack->specialOffsets(); -} -Nd4jLong getNumberOfTads(sd::TadPack* pack) { - return pack->numberOfTads(); -} -int getShapeInfoLength(sd::TadPack* pack) { - return pack->shapeInfoLength(); -} - -int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { - cudaStream_t *pStream = reinterpret_cast(reserved); - - cudaMemcpyKind kind; - - DEBUG_KERNEL(pStream, -1); - - switch (flags) { - case 0: { - kind = cudaMemcpyHostToHost; - } - break; - case 1: { - kind = cudaMemcpyHostToDevice; - } - break; - case 2: { - kind = cudaMemcpyDeviceToHost; - } - case 3: { - kind = cudaMemcpyDeviceToDevice; - } - break; - } - auto dZ = cudaMemcpyToSymbolAsync(deviceConstantMemory, const_cast(src), size, dst, kind, *pStream); - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemcpyToSymbolAsync failed"); - } - - return 1; + case 3: { + kind = cudaMemcpyDeviceToDevice; + } break; + } + auto dZ = cudaMemcpyToSymbolAsync(deviceConstantMemory, + const_cast(src), size, dst, + kind, *pStream); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaMemcpyToSymbolAsync failed"); + } + + return 1; } Nd4jPointer getConstantSpace() { - Nd4jPointer dConstAddr; - cudaError_t dZ = cudaGetSymbolAddress(reinterpret_cast(&dConstAddr), deviceConstantMemory); - - if (dZ != 0) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaGetSymbolAddress failed"); - } - - return dConstAddr; -} - -void pullRows(Nd4jPointer *extraPointers, - OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* zShapeInfo, Nd4jLong const* dZShapeInfo, - Nd4jLong n, - Nd4jLong *indexes, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets, - Nd4jLong const* zTadShapeInfo, - Nd4jLong const* zTadOffsets) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - dim3 launchDims(64, 256, 1024); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - BUILD_SINGLE_SELECTOR(xType, pullRowsKernelGeneric, - (launchDims, stream, dbX->special(), dbZ->special(), n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), - LIBND4J_TYPES); - - DEBUG_KERNEL(stream, -1); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - + Nd4jPointer dConstAddr; + cudaError_t dZ = cudaGetSymbolAddress(reinterpret_cast(&dConstAddr), + deviceConstantMemory); -void average(Nd4jPointer *extras, - Nd4jPointer *x, Nd4jLong const* xShapeInfo, - Nd4jPointer *dx, Nd4jLong const* dXShapeInfo, - void *z, Nd4jLong const* zShapeInfo, - void *dz, Nd4jLong const* dzShapeInfo, - int n, - Nd4jLong length, - bool propagate) { - try { - cudaStream_t *stream = reinterpret_cast(extras[1]); - int mode = getDeviceId(extras[3]); + if (dZ != 0) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "cudaGetSymbolAddress failed"); + } - auto dX = reinterpret_cast(dx); + return dConstAddr; +} - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("averageFloat called\n"); +void pullRows(Nd4jPointer *extraPointers, OpaqueDataBuffer *dbX, + Nd4jLong const *xShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *zShapeInfo, + Nd4jLong const *dZShapeInfo, Nd4jLong n, Nd4jLong *indexes, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *zTadShapeInfo, Nd4jLong const *zTadOffsets) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - // launching on gpu - if (mode == 0) { - dim3 launchDims(256, 256, 4096); - BUILD_SINGLE_SELECTOR(xType, averagingKernelGeneric, (launchDims, stream, dX, dz, n, length, propagate), - LIBND4J_TYPES); - sd::DebugHelper::checkErrorCode(stream, "AverageFloat(...) failed"); - } else { - // launching on host memory - BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::averageGeneric(x, z, zShapeInfo, n, length, propagate), - LIBND4J_TYPES); - } - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + dim3 launchDims(64, 256, 1024); + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + BUILD_SINGLE_SELECTOR( + xType, pullRowsKernelGeneric, + (launchDims, stream, dbX->special(), dbZ->special(), n, indexes, + tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), + LIBND4J_TYPES); + + DEBUG_KERNEL(stream, -1); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void average(Nd4jPointer *extras, Nd4jPointer *x, Nd4jLong const *xShapeInfo, + Nd4jPointer *dx, Nd4jLong const *dXShapeInfo, void *z, + Nd4jLong const *zShapeInfo, void *dz, Nd4jLong const *dzShapeInfo, + int n, Nd4jLong length, bool propagate) { + try { + cudaStream_t *stream = reinterpret_cast(extras[1]); + int mode = getDeviceId(extras[3]); + + auto dX = reinterpret_cast(dx); + + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("averageFloat called\n"); + + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + // launching on gpu + if (mode == 0) { + dim3 launchDims(256, 256, 4096); + BUILD_SINGLE_SELECTOR(xType, averagingKernelGeneric, + (launchDims, stream, dX, dz, n, length, propagate), + LIBND4J_TYPES); + sd::DebugHelper::checkErrorCode(stream, "AverageFloat(...) failed"); + } else { + // launching on host memory + BUILD_SINGLE_SELECTOR( + xType, sd::SpecialMethods, + ::averageGeneric(x, z, zShapeInfo, n, length, propagate), + LIBND4J_TYPES); } -} + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void accumulate(Nd4jPointer *extras, Nd4jPointer *x, Nd4jLong const *xShapeInfo, + Nd4jPointer *dx, Nd4jLong const *dXShapeInfo, void *z, + Nd4jLong const *zShapeInfo, void *dz, + Nd4jLong const *dzShapeInfo, int n, Nd4jLong length) { + try { + auto stream = reinterpret_cast(extras[1]); + int mode = getDeviceId(extras[3]); -void accumulate(Nd4jPointer *extras, - Nd4jPointer *x, Nd4jLong const* xShapeInfo, - Nd4jPointer *dx, Nd4jLong const* dXShapeInfo, - void *z, Nd4jLong const* zShapeInfo, - void *dz, Nd4jLong const* dzShapeInfo, - int n, - Nd4jLong length) { - try { - auto stream = reinterpret_cast(extras[1]); - int mode = getDeviceId(extras[3]); - - auto dX = reinterpret_cast(dx); - - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("accumulateFloat called\n"); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - - // launching on gpu - if (mode == 0) { - dim3 launchDims(n, 256, 16384); - BUILD_SINGLE_SELECTOR(xType, accumulateKernelGeneric, (launchDims, stream, dX, dz, n, length), - LIBND4J_TYPES); - sd::DebugHelper::checkErrorCode(stream, "AccumulateFloat(...) failed"); - } else { - // launching on host memory - BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::accumulateGeneric(x, z, zShapeInfo, n, length), - LIBND4J_TYPES); - } - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} + auto dX = reinterpret_cast(dx); + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("accumulateFloat called\n"); + auto xType = sd::ArrayOptions::dataType(xShapeInfo); -void shuffle(Nd4jPointer *extras, - Nd4jPointer *x, Nd4jPointer *xShapeInfo, - Nd4jPointer *dx, Nd4jPointer *dXShapeInfo, - Nd4jPointer *z, Nd4jPointer *zShapeInfo, - Nd4jPointer *dz, Nd4jPointer *dZShapeInfo, - int N, - int *shuffleMap, - Nd4jPointer *tadShapeInfo, - Nd4jPointer *tadOffsets) { - try { - cudaStream_t *stream = reinterpret_cast(extras[1]); - - auto dX = reinterpret_cast(dx); - auto dZ = reinterpret_cast(dz); - auto xShape = reinterpret_cast(xShapeInfo); - auto dxShape = reinterpret_cast(dXShapeInfo); - auto tadOnlyShapeInfo = reinterpret_cast(tadShapeInfo); - auto tadOffset = reinterpret_cast(tadOffsets); - - auto xType = sd::ArrayOptions::dataType(xShape[0]); - dim3 launchDims(256, 512, 8192); - BUILD_SINGLE_SELECTOR(xType, shuffleKernelGeneric, - (launchDims, stream, dX, dxShape, dZ, N, shuffleMap, tadOnlyShapeInfo, tadOffset), - LIBND4J_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "shuffle(...) failed"); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + // launching on gpu + if (mode == 0) { + dim3 launchDims(n, 256, 16384); + BUILD_SINGLE_SELECTOR(xType, accumulateKernelGeneric, + (launchDims, stream, dX, dz, n, length), + LIBND4J_TYPES); + sd::DebugHelper::checkErrorCode(stream, "AccumulateFloat(...) failed"); + } else { + // launching on host memory + BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, + ::accumulateGeneric(x, z, zShapeInfo, n, length), + LIBND4J_TYPES); } + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void shuffle(Nd4jPointer *extras, Nd4jPointer *x, Nd4jPointer *xShapeInfo, + Nd4jPointer *dx, Nd4jPointer *dXShapeInfo, Nd4jPointer *z, + Nd4jPointer *zShapeInfo, Nd4jPointer *dz, Nd4jPointer *dZShapeInfo, + int N, int *shuffleMap, Nd4jPointer *tadShapeInfo, + Nd4jPointer *tadOffsets) { + try { + cudaStream_t *stream = reinterpret_cast(extras[1]); + + auto dX = reinterpret_cast(dx); + auto dZ = reinterpret_cast(dz); + auto xShape = reinterpret_cast(xShapeInfo); + auto dxShape = reinterpret_cast(dXShapeInfo); + auto tadOnlyShapeInfo = reinterpret_cast(tadShapeInfo); + auto tadOffset = reinterpret_cast(tadOffsets); + + auto xType = sd::ArrayOptions::dataType(xShape[0]); + dim3 launchDims(256, 512, 8192); + BUILD_SINGLE_SELECTOR(xType, shuffleKernelGeneric, + (launchDims, stream, dX, dxShape, dZ, N, shuffleMap, + tadOnlyShapeInfo, tadOffset), + LIBND4J_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "shuffle(...) failed"); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } bool isExperimentalEnabled() { - return sd::Environment::getInstance()->isExperimentalBuild(); + return sd::Environment::getInstance()->isExperimentalBuild(); } void setOmpMinThreads(int threads) { - minThreads = sd::math::nd4j_max(32, threads); - minThreads = sd::math::nd4j_min(maxThreads, minThreads); + minThreads = sd::math::nd4j_max(32, threads); + minThreads = sd::math::nd4j_min(maxThreads, minThreads); } -int getDevice() { - return sd::AffinityManager::currentDeviceId(); -} +int getDevice() { return sd::AffinityManager::currentDeviceId(); } void setElementThreshold(int num) { - // this is no-op for CUDA + // this is no-op for CUDA } void setTADThreshold(int num) { - // this is no-op for CUDA + // this is no-op for CUDA } //////////////////////////////////////////////////////////////////////// -void execSummaryStats(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - bool biasCorrected) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execSummaryStats(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - biasCorrected); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execSummaryStats(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, bool biasCorrected) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execSummaryStats( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + biasCorrected); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execSummaryStatsTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - bool biasCorrected, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbDimension}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execSummaryStats(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - reinterpret_cast(dbDimension->special()), dimensionLength, - tadShapeInfo, tadOffsets, - biasCorrected); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbDimension}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execSummaryStatsTad(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape, bool biasCorrected, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbDimension}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execSummaryStats( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + reinterpret_cast(dbDimension->special()), dimensionLength, + tadShapeInfo, tadOffsets, biasCorrected); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbDimension}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduce3(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduce3(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduce3(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + void *extraParams, OpaqueDataBuffer *dbY, + Nd4jLong const *hYShapeInfo, Nd4jLong const *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduce3( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hYShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduce3Tad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, - dimension, - shape::length(hDimensionShape)); - auto tadLength = shape::length(tadPack.primaryShapeInfo()); - auto yLength = shape::length(hYShapeInfo); - auto xLength = shape::length(hXShapeInfo); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - - if (tadLength == yLength || tadLength == xLength) { - // nd4j_printf("== way\n",""); - NativeOpExecutioner::execReduce3(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - dimension, dimensionLength, - tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); - } else - NativeOpExecutioner::execReduce3TAD(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - dimension, dimensionLength, - tadOnlyShapeInfo, yTadOffsets, yTadOnlyShapeInfo, yTadOffsets); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduce3Tad( + Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbY, Nd4jLong const *hYShapeInfo, + Nd4jLong const *dYShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + OpaqueDataBuffer *dbDimension, Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape, Nd4jLong const *tadOnlyShapeInfo, + Nd4jLong const *tadOffsets, Nd4jLong const *yTadOnlyShapeInfo, + Nd4jLong const *yTadOffsets) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + hXShapeInfo, dimension, shape::length(hDimensionShape)); + auto tadLength = shape::length(tadPack.primaryShapeInfo()); + auto yLength = shape::length(hYShapeInfo); + auto xLength = shape::length(hXShapeInfo); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + + if (tadLength == yLength || tadLength == xLength) { + // nd4j_printf("== way\n",""); + NativeOpExecutioner::execReduce3( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hYShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + yTadOnlyShapeInfo, yTadOffsets); + } else + NativeOpExecutioner::execReduce3TAD( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hYShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + dimension, dimensionLength, tadOnlyShapeInfo, yTadOffsets, + yTadOnlyShapeInfo, yTadOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduce3Scalar(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT()); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execReduce3Scalar(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParams, + OpaqueDataBuffer *dbY, Nd4jLong const *hYShapeInfo, + Nd4jLong const *dYShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduce3Scalar( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hYShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT()); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execScalarBool(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalar, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execScalarBool(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hScalarShapeInfo).specialAsT(), - extraParams); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalar}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execScalarBool(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + OpaqueDataBuffer *dbScalar, + Nd4jLong const *hScalarShapeInfo, + Nd4jLong const *dScalarShapeInfo, void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execScalarBool( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hScalarShapeInfo) + .specialAsT(), + extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalar}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execScalarBoolTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execScalarBool(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParams, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - dbScalars->primary(), hScalarShapeInfo, dbScalars->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hScalarShapeInfo).specialAsT(), - dimension, dimensionLength, - tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalars}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execScalarBoolTad( + Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbScalars, + Nd4jLong const *hScalarShapeInfo, Nd4jLong const *dScalarShapeInfo, + void *extraParams, OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, Nd4jLong const *dDimensionShape, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *tadShapeInfoZ, Nd4jLong const *tadOffsetsZ) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execScalarBool( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParams, dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + dbScalars->primary(), hScalarShapeInfo, dbScalars->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hScalarShapeInfo) + .specialAsT(), + dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, + tadOffsetsZ); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalars}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execScalar(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalar, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, - void *extraParams) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execScalar(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hScalarShapeInfo).specialAsT(), - extraParams); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalar}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execScalar(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbScalar, + Nd4jLong const *hScalarShapeInfo, + Nd4jLong const *dScalarShapeInfo, void *extraParams) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalar}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execScalar( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + dbScalar->primary(), hScalarShapeInfo, dbScalar->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hScalarShapeInfo) + .specialAsT(), + extraParams); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalar}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// -void execScalarTad(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbScalars, Nd4jLong const* hScalarShapeInfo, Nd4jLong const* dScalarShapeInfo, - void *extraParams, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); +void execScalarTad(Nd4jPointer *extraPointers, int opNum, OpaqueDataBuffer *dbX, + Nd4jLong const *hXShapeInfo, Nd4jLong const *dXShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, OpaqueDataBuffer *dbScalars, + Nd4jLong const *hScalarShapeInfo, + Nd4jLong const *dScalarShapeInfo, void *extraParams, + OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *tadShapeInfoZ, Nd4jLong const *tadOffsetsZ) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbScalars}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - if (yType != xType && yType != sd::DataType::BOOL && !isExperimentalEnabled()) - throw sd::datatype_exception::build("execScalar both operands must have same data type", xType, yType); + if (yType != xType && yType != sd::DataType::BOOL && + !isExperimentalEnabled()) + throw sd::datatype_exception::build( + "execScalar both operands must have same data type", xType, yType); - dim3 launchDims(256, 256, 16384); + dim3 launchDims(256, 256, 16384); #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR( + xType, yType, zType, functions::scalar::ScalarTransform, + ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, + dZ, dZShapeInfo, dScalars, extraParams, + dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), dbScalars->special(), extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE( + xType, functions::scalar::ScalarTransform, + ::executeCudaAlongDimension(launchDims, stream, opNum, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + dbScalars->special(), extraParams, + dimension, dimensionLength, tadShapeInfo, + tadOffsets, tadShapeInfoZ, tadOffsetsZ), + LIBND4J_TYPES); #endif - DEBUG_KERNEL(stream, opNum); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalars}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void execAggregate(Nd4jPointer *extraPointers, - int opNum, - void **arguments, - int numArguments, - Nd4jLong **shapes, - int numShapes, - int *indexArguments, - int numIndexArguments, - int **intArrays, - int numIntArrays, - void *realArguments, - int numRealArguments, - sd::DataType dtype) { + DEBUG_KERNEL(stream, opNum); + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbScalars}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -void batchExecutor(Nd4jPointer *extraPointers, - int numAggregates, - int opNum, - int maxArgs, - int maxShapes, - int maxIntArrays, - int maxIntArraySize, - int maxIdx, - int maxReals, - void *ptrToArguments, - sd::DataType dtype) { -} +void execAggregate(Nd4jPointer *extraPointers, int opNum, void **arguments, + int numArguments, Nd4jLong **shapes, int numShapes, + int *indexArguments, int numIndexArguments, int **intArrays, + int numIntArrays, void *realArguments, int numRealArguments, + sd::DataType dtype) {} -void execAggregateBatch(Nd4jPointer *extraPointers, - int numAggregates, int opNum, - int maxArgs, int maxShapes, - int maxIntArrays, int maxIntArraySize, - int maxIdx, int maxReals, - void *ptrToArguments, sd::DataType dtype) { +void batchExecutor(Nd4jPointer *extraPointers, int numAggregates, int opNum, + int maxArgs, int maxShapes, int maxIntArrays, + int maxIntArraySize, int maxIdx, int maxReals, + void *ptrToArguments, sd::DataType dtype) {} -} +void execAggregateBatch(Nd4jPointer *extraPointers, int numAggregates, + int opNum, int maxArgs, int maxShapes, int maxIntArrays, + int maxIntArraySize, int maxIdx, int maxReals, + void *ptrToArguments, sd::DataType dtype) {} //////////////////////////////////////////////////////////////////////// -void execRandom(Nd4jPointer *extraPointers, - int opNum, - Nd4jPointer stateHost, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraArguments) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, stateHost, - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - extraArguments); - - InteropDataBuffer::registerSpecialUse({dbZ}, {}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } +void execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, void *extraArguments) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, dbZ->primary(), + hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + extraArguments); + + InteropDataBuffer::registerSpecialUse({dbZ}, {}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraArguments) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, stateHost, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - extraArguments); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + void *extraArguments) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, dbX->primary(), + hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + extraArguments); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } //////////////////////////////////////////////////////////////////////// void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - void *extraArguments) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, stateHost, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - extraArguments); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - - -Nd4jPointer initRandom(Nd4jPointer *extraPointers, long seed, long bufferSize, Nd4jPointer ptrToBuffer) { - - unsigned long long *ptrHost = reinterpret_cast(extraPointers[0]); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - - // we don't synchronize at random initialization, it's safe to go unsync here - // cudaStreamSynchronize(*stream); - - auto ptrDev = reinterpret_cast(ptrToBuffer); - auto buffer = new sd::random::RandomBuffer(seed, bufferSize, reinterpret_cast(ptrHost), reinterpret_cast(ptrDev)); - buffer->propagateToDevice(buffer, *stream); - - sd::DebugHelper::checkErrorCode(stream, "initRandom(...) failed A"); - - // we generate sequence in the host memory - sd::random::Xoroshiro128 generator(buffer); - generator.refreshBuffer(); - - // and copy it to gpu - cudaMemcpyAsync(ptrDev, ptrHost, bufferSize * 8, cudaMemcpyHostToDevice, *stream); - sd::DebugHelper::checkErrorCode(stream, "initRandom(...) failed B"); - - return buffer; + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, OpaqueDataBuffer *dbY, + Nd4jLong const *hYShapeInfo, Nd4jLong const *dYShapeInfo, + OpaqueDataBuffer *dbZ, Nd4jLong const *hZShapeInfo, + Nd4jLong const *dZShapeInfo, void *extraArguments) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY}); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, dbX->primary(), + hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hYShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + extraArguments); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +Nd4jPointer initRandom(Nd4jPointer *extraPointers, long seed, long bufferSize, + Nd4jPointer ptrToBuffer) { + unsigned long long *ptrHost = + reinterpret_cast(extraPointers[0]); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + + // we don't synchronize at random initialization, it's safe to go unsync here + // cudaStreamSynchronize(*stream); + + auto ptrDev = reinterpret_cast(ptrToBuffer); + auto buffer = new sd::random::RandomBuffer( + seed, bufferSize, reinterpret_cast(ptrHost), + reinterpret_cast(ptrDev)); + buffer->propagateToDevice(buffer, *stream); + + sd::DebugHelper::checkErrorCode(stream, "initRandom(...) failed A"); + + // we generate sequence in the host memory + sd::random::Xoroshiro128 generator(buffer); + generator.refreshBuffer(); + + // and copy it to gpu + cudaMemcpyAsync(ptrDev, ptrHost, bufferSize * 8, cudaMemcpyHostToDevice, + *stream); + sd::DebugHelper::checkErrorCode(stream, "initRandom(...) failed B"); + + return buffer; } - void destroyRandom(Nd4jPointer ptrBuffer) { + sd::random::RandomBuffer *buffer = + reinterpret_cast(ptrBuffer); - sd::random::RandomBuffer *buffer = reinterpret_cast (ptrBuffer); + // FIXME: it's bad thing, but we can't know in advance, which stream(s) where + // using this generator in practice + cudaDeviceSynchronize(); - // FIXME: it's bad thing, but we can't know in advance, which stream(s) where using this generator in practice - cudaDeviceSynchronize(); - - delete buffer; + delete buffer; } -void refreshBuffer(Nd4jPointer *extraPointers, long seed, Nd4jPointer ptrRandom) { - - sd::random::RandomBuffer *buffer = reinterpret_cast (ptrRandom); +void refreshBuffer(Nd4jPointer *extraPointers, long seed, + Nd4jPointer ptrRandom) { + sd::random::RandomBuffer *buffer = + reinterpret_cast(ptrRandom); - unsigned long long *ptrHost = reinterpret_cast(extraPointers[0]); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - cudaStreamSynchronize(*stream); + unsigned long long *ptrHost = + reinterpret_cast(extraPointers[0]); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + cudaStreamSynchronize(*stream); - uint64_t *ptrDev = buffer->getDeviceBuffer(); + uint64_t *ptrDev = buffer->getDeviceBuffer(); - // update rng state - buffer->setSeed(seed); - buffer->setOffset(0); - buffer->propagateToDevice(buffer, *stream); + // update rng state + buffer->setSeed(seed); + buffer->setOffset(0); + buffer->propagateToDevice(buffer, *stream); - // refresh buffer on host size - sd::random::Xoroshiro128 generator(buffer); - generator.refreshBuffer(); + // refresh buffer on host size + sd::random::Xoroshiro128 generator(buffer); + generator.refreshBuffer(); - // copy back to gpu - cudaMemcpyAsync(ptrDev, ptrHost, buffer->getSize() * 8, cudaMemcpyHostToDevice, *stream); + // copy back to gpu + cudaMemcpyAsync(ptrDev, ptrHost, buffer->getSize() * 8, + cudaMemcpyHostToDevice, *stream); } -void reSeedBuffer(Nd4jPointer *extraPointers, long seed, Nd4jPointer ptrRandom) { +void reSeedBuffer(Nd4jPointer *extraPointers, long seed, + Nd4jPointer ptrRandom) { + sd::random::RandomBuffer *buffer = + reinterpret_cast(ptrRandom); - sd::random::RandomBuffer *buffer = reinterpret_cast (ptrRandom); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + cudaStreamSynchronize(*stream); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - cudaStreamSynchronize(*stream); - - // update rng state - buffer->reSeed(seed); - buffer->setOffset(0); - buffer->propagateToDevice(buffer, *stream); + // update rng state + buffer->reSeed(seed); + buffer->setOffset(0); + buffer->propagateToDevice(buffer, *stream); } - - /** - * Return the length of a shape buffer - * based on the pointer - * @param buffer the buffer pointer to check - * @return - */ + * Return the length of a shape buffer + * based on the pointer + * @param buffer the buffer pointer to check + * @return + */ int lengthForShapeBufferPointer(Nd4jPointer buffer) { - auto shapeBuffer = reinterpret_cast(buffer); - return shape::shapeInfoLength(shape::rank(shapeBuffer)); + auto shapeBuffer = reinterpret_cast(buffer); + return shape::shapeInfoLength(shape::rank(shapeBuffer)); } - /** - * The pointer to get the address for - * - * @param address the address to get the pointer - * @return the pointer for the given address - */ + * The pointer to get the address for + * + * @param address the address to get the pointer + * @return the pointer for the given address + */ Nd4jPointer pointerForAddress(Nd4jLong address) { - return reinterpret_cast(address); -} - -void tear(Nd4jPointer *extras, - OpaqueDataBuffer *dbX, Nd4jLong const* xShapeInfo, Nd4jLong const* dXShapeInfo, - Nd4jPointer *targets, - Nd4jLong const* zShapeInfo, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets) { - try { - InteropDataBuffer::prepareSpecialUse({}, {dbX}); - - cudaStream_t *stream = reinterpret_cast(extras[1]); - dim3 launchDims(512, 512, 512); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - BUILD_SINGLE_SELECTOR(xType, tearKernelGeneric, - (launchDims, stream, dbX->special(), dXShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets), - LIBND4J_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "tearFloat(...) failed"); - - InteropDataBuffer::registerSpecialUse({}, {dbX}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - - -void prescanArrayRecursive(Nd4jPointer *extras, int *dZ, int *dX, int numElements, int level) { - - auto stream = reinterpret_cast(extras[1]); - auto g_scanBlockSums = reinterpret_cast(extras[2]); - - int blockSize = 512; // max size of the thread blocks - int numBlocks = sd::math::nd4j_max(1, static_cast(ceil(static_cast(numElements) / (2.f * blockSize)))); - int numThreads; - - if (numBlocks > 1) - numThreads = blockSize; - else if (sd::isPowerOfTwo(numElements)) - numThreads = numElements / 2; - else - numThreads = sd::floorPow2(numElements); - - int numEltsPerBlock = numThreads * 2; - - // if this is a non-power-of-2 array, the last block will be non-full - // compute the smallest power of 2 able to compute its scan. - int numEltsLastBlock = - numElements - (numBlocks-1) * numEltsPerBlock; - int numThreadsLastBlock = sd::math::nd4j_max(1, numEltsLastBlock / 2); - int np2LastBlock = 0; - int sharedMemLastBlock = 0; - - if (numEltsLastBlock != numEltsPerBlock) { - np2LastBlock = 1; - - if(!isPowerOfTwo(numEltsLastBlock)) - numThreadsLastBlock = floorPow2(numEltsLastBlock); - - unsigned int extraSpace = (2 * numThreadsLastBlock) / NUM_BANKS; - sharedMemLastBlock = sizeof(int) * (2 * numThreadsLastBlock + extraSpace); + return reinterpret_cast(address); +} + +void tear(Nd4jPointer *extras, OpaqueDataBuffer *dbX, + Nd4jLong const *xShapeInfo, Nd4jLong const *dXShapeInfo, + Nd4jPointer *targets, Nd4jLong const *zShapeInfo, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets) { + try { + InteropDataBuffer::prepareSpecialUse({}, {dbX}); + + cudaStream_t *stream = reinterpret_cast(extras[1]); + dim3 launchDims(512, 512, 512); + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + BUILD_SINGLE_SELECTOR(xType, tearKernelGeneric, + (launchDims, stream, dbX->special(), dXShapeInfo, + targets, zShapeInfo, tadShapeInfo, tadOffsets), + LIBND4J_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "tearFloat(...) failed"); + + InteropDataBuffer::registerSpecialUse({}, {dbX}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void prescanArrayRecursive(Nd4jPointer *extras, int *dZ, int *dX, + int numElements, int level) { + auto stream = reinterpret_cast(extras[1]); + auto g_scanBlockSums = reinterpret_cast(extras[2]); + + int blockSize = 512; // max size of the thread blocks + int numBlocks = sd::math::nd4j_max( + 1, static_cast( + ceil(static_cast(numElements) / (2.f * blockSize)))); + int numThreads; + + if (numBlocks > 1) + numThreads = blockSize; + else if (sd::isPowerOfTwo(numElements)) + numThreads = numElements / 2; + else + numThreads = sd::floorPow2(numElements); + + int numEltsPerBlock = numThreads * 2; + + // if this is a non-power-of-2 array, the last block will be non-full + // compute the smallest power of 2 able to compute its scan. + int numEltsLastBlock = numElements - (numBlocks - 1) * numEltsPerBlock; + int numThreadsLastBlock = sd::math::nd4j_max(1, numEltsLastBlock / 2); + int np2LastBlock = 0; + int sharedMemLastBlock = 0; + + if (numEltsLastBlock != numEltsPerBlock) { + np2LastBlock = 1; + + if (!isPowerOfTwo(numEltsLastBlock)) + numThreadsLastBlock = floorPow2(numEltsLastBlock); + + unsigned int extraSpace = (2 * numThreadsLastBlock) / NUM_BANKS; + sharedMemLastBlock = sizeof(int) * (2 * numThreadsLastBlock + extraSpace); + } + + // padding space is used to avoid shared memory bank conflicts + int extraSpace = numEltsPerBlock / NUM_BANKS; + int sharedMemSize = sizeof(int) * (numEltsPerBlock + extraSpace); + + // setup execution parameters + // if NP2, we process the last block separately + dim3 grid(max(1, numBlocks - np2LastBlock), 1, 1); + dim3 threads(numThreads, 1, 1); + dim3 gridOnes(1, 1, 1); + dim3 threadsOnes(numThreadsLastBlock, 1, 1); + + if (sharedMemSize < 2048) sharedMemSize = 2048; + + if (sharedMemLastBlock < 2048) sharedMemLastBlock = 2048; + + // execute the scan + if (numBlocks > 1) { + sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, + dX, g_scanBlockSums[level], numThreads * 2, + 0, 0); + if (np2LastBlock) { + sd::prescanLauncher(gridOnes, threadsOnes, sharedMemLastBlock, + stream, dZ, dX, g_scanBlockSums[level], + numEltsLastBlock, numBlocks - 1, + numElements - numEltsLastBlock); } - // padding space is used to avoid shared memory bank conflicts - int extraSpace = numEltsPerBlock / NUM_BANKS; - int sharedMemSize = sizeof(int) * (numEltsPerBlock + extraSpace); - - // setup execution parameters - // if NP2, we process the last block separately - dim3 grid(max(1, numBlocks - np2LastBlock), 1, 1); - dim3 threads(numThreads, 1, 1); - dim3 gridOnes(1, 1, 1); - dim3 threadsOnes(numThreadsLastBlock, 1, 1); - - if (sharedMemSize < 2048) - sharedMemSize = 2048; - - if (sharedMemLastBlock < 2048) - sharedMemLastBlock = 2048; - - // execute the scan - if (numBlocks > 1) { - sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, dX, g_scanBlockSums[level], numThreads * 2, 0, 0); - if (np2LastBlock) { - sd::prescanLauncher(gridOnes, threadsOnes, sharedMemLastBlock, stream, dZ, dX, g_scanBlockSums[level], numEltsLastBlock, numBlocks - 1, numElements - numEltsLastBlock); - } - - // After scanning all the sub-blocks, we are mostly done. But now we - // need to take all of the last values of the sub-blocks and scan those. - // This will give us a new value that must be sdded to each block to - // get the final results. - // recursive (CPU) call - prescanArrayRecursive(extras, g_scanBlockSums[level], g_scanBlockSums[level], numBlocks, level+1); - - sd::uniformAdd<<>>(dZ, g_scanBlockSums[level], numElements - numEltsLastBlock, 0, 0); - - if (np2LastBlock) { - sd::uniformAdd<<<1, numThreadsLastBlock, 1024, *stream>>>(dZ, g_scanBlockSums[level], numEltsLastBlock, numBlocks - 1, numElements - numEltsLastBlock); - } - } else if (isPowerOfTwo(numElements)) { - sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, dX, 0, numThreads * 2, 0, 0); - } else { - sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, dX, 0, numElements, 0, 0); + // After scanning all the sub-blocks, we are mostly done. But now we + // need to take all of the last values of the sub-blocks and scan those. + // This will give us a new value that must be sdded to each block to + // get the final results. + // recursive (CPU) call + prescanArrayRecursive(extras, g_scanBlockSums[level], + g_scanBlockSums[level], numBlocks, level + 1); + + sd::uniformAdd<<>>( + dZ, g_scanBlockSums[level], numElements - numEltsLastBlock, 0, 0); + + if (np2LastBlock) { + sd::uniformAdd<<<1, numThreadsLastBlock, 1024, *stream>>>( + dZ, g_scanBlockSums[level], numEltsLastBlock, numBlocks - 1, + numElements - numEltsLastBlock); } + } else if (isPowerOfTwo(numElements)) { + sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, + dX, 0, numThreads * 2, 0, 0); + } else { + sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, + dX, 0, numElements, 0, 0); + } - sd::DebugHelper::checkErrorCode(stream, "prescanArray(...) failed"); + sd::DebugHelper::checkErrorCode(stream, "prescanArray(...) failed"); } //////////////////////////////////////////////////////////////////////// -void execReduce3All(Nd4jPointer *extraPointers, - int opNum, - OpaqueDataBuffer *dbX, Nd4jLong const* hXShapeInfo, Nd4jLong const* dXShapeInfo, - void *extraParamsVals, - OpaqueDataBuffer *dbY, Nd4jLong const* hYShapeInfo, Nd4jLong const* dYShapeInfo, - OpaqueDataBuffer *dbZ, Nd4jLong const* hZShapeInfo, Nd4jLong const* dZShapeInfo, - OpaqueDataBuffer *dbDimension, Nd4jLong const* hDimensionShape, Nd4jLong const* dDimensionShape, - Nd4jLong const* xTadShapeInfo, Nd4jLong const* xOffsets, - Nd4jLong const* yTadShapeInfo, Nd4jLong const* yOffsets) { - try { - InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY, dbDimension}); - InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); - - auto dimension = reinterpret_cast(dbDimension->primary()); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduce3All(&lc, opNum, - dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hXShapeInfo).specialAsT(), - extraParamsVals, - dbY->primary(), hYShapeInfo, dbY->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hYShapeInfo).specialAsT(), - dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance()->bufferForShapeInfo(hZShapeInfo).specialAsT(), - reinterpret_cast(dbDimension->special()), dimensionLength, - xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); - - InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - - -void sort(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dX, Nd4jLong const* dXShapeInfo, - bool descending) { - try { - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - - auto xLength = shape::length(xShapeInfo); - auto xEWS = shape::elementWiseStride(xShapeInfo); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - - - // check if xLength is a power of 2, and use bitonic sort, if that's the case - if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) { - int numThreads = sd::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; - - dim3 launchDims(numBlocks, numThreads, 32768); +void execReduce3All(Nd4jPointer *extraPointers, int opNum, + OpaqueDataBuffer *dbX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *dXShapeInfo, void *extraParamsVals, + OpaqueDataBuffer *dbY, Nd4jLong const *hYShapeInfo, + Nd4jLong const *dYShapeInfo, OpaqueDataBuffer *dbZ, + Nd4jLong const *hZShapeInfo, Nd4jLong const *dZShapeInfo, + OpaqueDataBuffer *dbDimension, + Nd4jLong const *hDimensionShape, + Nd4jLong const *dDimensionShape, + Nd4jLong const *xTadShapeInfo, Nd4jLong const *xOffsets, + Nd4jLong const *yTadShapeInfo, Nd4jLong const *yOffsets) { + try { + InteropDataBuffer::prepareSpecialUse({dbZ}, {dbX, dbY, dbDimension}); + InteropDataBuffer::preparePrimaryUse({}, {dbDimension}); + + auto dimension = reinterpret_cast(dbDimension->primary()); + int dimensionLength = static_cast(shape::length(hDimensionShape)); + + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + NativeOpExecutioner::execReduce3All( + &lc, opNum, dbX->primary(), hXShapeInfo, dbX->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hXShapeInfo) + .specialAsT(), + extraParamsVals, dbY->primary(), hYShapeInfo, dbY->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hYShapeInfo) + .specialAsT(), + dbZ->primary(), hZShapeInfo, dbZ->special(), + ConstantShapeHelper::getInstance() + ->bufferForShapeInfo(hZShapeInfo) + .specialAsT(), + reinterpret_cast(dbDimension->special()), dimensionLength, + xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); + + InteropDataBuffer::registerSpecialUse({dbZ}, {dbX, dbY}); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sort(Nd4jPointer *extraPointers, void *x, Nd4jLong const *xShapeInfo, + void *dX, Nd4jLong const *dXShapeInfo, bool descending) { + try { + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - for (int k = 2; k <= xLength; k = 2 * k) { - for (int j = k >> 1; j > 0; j = j >> 1) { - BUILD_SINGLE_SELECTOR(xType, bitonicSortStepGeneric, - (launchDims, stream, dX, dXShapeInfo, j, k, xLength, descending), - LIBND4J_TYPES); - } - } - } else { - int numThreads = sd::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; - - numBlocks = sd::math::nd4j_min(512, numBlocks); - dim3 launchDims(numBlocks, numThreads, 32768); - - int max = 2, dg = 0; - while (max < xLength) { - max <<= 1; - dg++; - } - max <<= 1; - - for (int window = 2; window < max; window <<= 1) { - int n = window; - int rev = 0; - do { - int half = n >> 1; - BUILD_SINGLE_SELECTOR(xType, bitonicArbitraryStepGeneric, - (launchDims, stream, dX, dXShapeInfo, n, xLength, rev, descending), - LIBND4J_TYPES); - n >>= 1; - rev = 1; - } while (n > 1); - } + auto xLength = shape::length(xShapeInfo); + auto xEWS = shape::elementWiseStride(xShapeInfo); + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + + // check if xLength is a power of 2, and use bitonic sort, if that's the + // case + if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && + (xLength <= 1024 * 1024 * 10)) { + int numThreads = sd::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) numBlocks++; + + dim3 launchDims(numBlocks, numThreads, 32768); + + for (int k = 2; k <= xLength; k = 2 * k) { + for (int j = k >> 1; j > 0; j = j >> 1) { + BUILD_SINGLE_SELECTOR( + xType, bitonicSortStepGeneric, + (launchDims, stream, dX, dXShapeInfo, j, k, xLength, descending), + LIBND4J_TYPES); } - - sd::DebugHelper::checkErrorCode(stream, "sort(...) failed"); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } + } else { + int numThreads = sd::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) numBlocks++; + + numBlocks = sd::math::nd4j_min(512, numBlocks); + dim3 launchDims(numBlocks, numThreads, 32768); + + int max = 2, dg = 0; + while (max < xLength) { + max <<= 1; + dg++; + } + max <<= 1; + + for (int window = 2; window < max; window <<= 1) { + int n = window; + int rev = 0; + do { + int half = n >> 1; + BUILD_SINGLE_SELECTOR(xType, bitonicArbitraryStepGeneric, + (launchDims, stream, dX, dXShapeInfo, n, + xLength, rev, descending), + LIBND4J_TYPES); + n >>= 1; + rev = 1; + } while (n > 1); + } } -} - - -void sortByKey(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dX, Nd4jLong const* dXShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - bool descending) { - try { - auto stream = reinterpret_cast(extraPointers[1]); - - auto xLength = shape::length(xShapeInfo); - auto yLength = shape::length(yShapeInfo); - auto xEWS = shape::elementWiseStride(xShapeInfo); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - auto yType = sd::ArrayOptions::dataType(yShapeInfo); - - if (shape::isEmpty(xShapeInfo) || shape::isEmpty(yShapeInfo)) - return; - - if (xLength != yLength) - throw std::runtime_error("sortByKey: keys and values must have the same size"); - - // check if xLength is a power of 2, and use bitonic sort, if that's the case - if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) { - int numThreads = sd::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; - - dim3 launchDims(numBlocks, numThreads, 32768); - - for (int k = 2; k <= xLength; k = 2 * k) { - for (int j = k >> 1; j > 0; j = j >> 1) { - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, - (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), - LIBND4J_TYPES, LIBND4J_TYPES); - } - } - } else { - int numThreads = sd::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; - - numBlocks = sd::math::nd4j_min(512, numBlocks); - dim3 launchDims(numBlocks, numThreads, 32768); - - int max = 2, dg = 0; - while (max < xLength) { - max <<= 1; - dg++; - } - max <<= 1; - - for (int window = 2; window < max; window <<= 1) { - int n = window; - int rev = 0; - do { - int half = n >> 1; - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, - (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), - LIBND4J_TYPES, LIBND4J_TYPES); - n >>= 1; - rev = 1; - } while (n > 1); - } + sd::DebugHelper::checkErrorCode(stream, "sort(...) failed"); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortByKey(Nd4jPointer *extraPointers, void *x, Nd4jLong const *xShapeInfo, + void *dX, Nd4jLong const *dXShapeInfo, void *y, + Nd4jLong const *yShapeInfo, void *dy, + Nd4jLong const *dyShapeInfo, bool descending) { + try { + auto stream = reinterpret_cast(extraPointers[1]); + + auto xLength = shape::length(xShapeInfo); + auto yLength = shape::length(yShapeInfo); + auto xEWS = shape::elementWiseStride(xShapeInfo); + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + auto yType = sd::ArrayOptions::dataType(yShapeInfo); + + if (shape::isEmpty(xShapeInfo) || shape::isEmpty(yShapeInfo)) return; + + if (xLength != yLength) + throw std::runtime_error( + "sortByKey: keys and values must have the same size"); + + // check if xLength is a power of 2, and use bitonic sort, if that's the + // case + if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && + (xLength <= 1024 * 1024 * 10)) { + int numThreads = sd::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) numBlocks++; + + dim3 launchDims(numBlocks, numThreads, 32768); + + for (int k = 2; k <= xLength; k = 2 * k) { + for (int j = k >> 1; j > 0; j = j >> 1) { + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, + (launchDims, stream, dX, dXShapeInfo, dy, + dyShapeInfo, j, k, xLength, descending), + LIBND4J_TYPES, LIBND4J_TYPES); } - - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } + } else { + int numThreads = sd::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) numBlocks++; + + numBlocks = sd::math::nd4j_min(512, numBlocks); + dim3 launchDims(numBlocks, numThreads, 32768); + + int max = 2, dg = 0; + while (max < xLength) { + max <<= 1; + dg++; + } + max <<= 1; + + for (int window = 2; window < max; window <<= 1) { + int n = window; + int rev = 0; + do { + int half = n >> 1; + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, + (launchDims, stream, dX, dXShapeInfo, dy, + dyShapeInfo, n, xLength, rev, descending), + LIBND4J_TYPES, LIBND4J_TYPES); + n >>= 1; + rev = 1; + } while (n > 1); + } } -} -void sortByValue(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dX, Nd4jLong const* dXShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - bool descending) { - try { - auto stream = reinterpret_cast(extraPointers[1]); - - auto xLength = shape::length(xShapeInfo); - auto yLength = shape::length(yShapeInfo); - auto xEWS = shape::elementWiseStride(xShapeInfo); - auto xType = sd::ArrayOptions::dataType(yShapeInfo); - auto yType = sd::ArrayOptions::dataType(xShapeInfo); - - if (shape::isEmpty(xShapeInfo) || shape::isEmpty(yShapeInfo)) - return; - - if (xLength != yLength) - throw std::runtime_error("sortByValue: keys and values must have the same size"); - - - // check if xLength is a power of 2, and use bitonic sort, if that's the case - if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) { - int numThreads = sd::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; - - dim3 launchDims(numBlocks, numThreads, 32768); - - for (int k = 2; k <= xLength; k = 2 * k) { - for (int j = k >> 1; j > 0; j = j >> 1) { - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, - (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, j, k, xLength, descending), - LIBND4J_TYPES, LIBND4J_TYPES); - } - } - } else { - int numThreads = sd::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; - - numBlocks = sd::math::nd4j_min(512, numBlocks); - dim3 launchDims(numBlocks, numThreads, 32768); - - int max = 2, dg = 0; - while (max < xLength) { - max <<= 1; - dg++; - } - max <<= 1; - - for (int window = 2; window < max; window <<= 1) { - int n = window; - int rev = 0; - do { - int half = n >> 1; - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, - (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, n, xLength, rev, descending), - LIBND4J_TYPES, LIBND4J_TYPES); - n >>= 1; - rev = 1; - } while (n > 1); - } + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortByValue(Nd4jPointer *extraPointers, void *x, + Nd4jLong const *xShapeInfo, void *dX, + Nd4jLong const *dXShapeInfo, void *y, + Nd4jLong const *yShapeInfo, void *dy, + Nd4jLong const *dyShapeInfo, bool descending) { + try { + auto stream = reinterpret_cast(extraPointers[1]); + + auto xLength = shape::length(xShapeInfo); + auto yLength = shape::length(yShapeInfo); + auto xEWS = shape::elementWiseStride(xShapeInfo); + auto xType = sd::ArrayOptions::dataType(yShapeInfo); + auto yType = sd::ArrayOptions::dataType(xShapeInfo); + + if (shape::isEmpty(xShapeInfo) || shape::isEmpty(yShapeInfo)) return; + + if (xLength != yLength) + throw std::runtime_error( + "sortByValue: keys and values must have the same size"); + + // check if xLength is a power of 2, and use bitonic sort, if that's the + // case + if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && + (xLength <= 1024 * 1024 * 10)) { + int numThreads = sd::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) numBlocks++; + + dim3 launchDims(numBlocks, numThreads, 32768); + + for (int k = 2; k <= xLength; k = 2 * k) { + for (int j = k >> 1; j > 0; j = j >> 1) { + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, + (launchDims, stream, dy, dyShapeInfo, dX, + dXShapeInfo, j, k, xLength, descending), + LIBND4J_TYPES, LIBND4J_TYPES); } - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - - - -void sortTadByKey(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dX, Nd4jLong const* dXShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - int *dimension, - int dimensionLength, - bool descending) { - try { - auto stream = reinterpret_cast(extraPointers[1]); - auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext() - : reinterpret_cast(extraPointers[0]); - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - auto yType = sd::ArrayOptions::dataType(yShapeInfo); - BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, - (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), - LIBND4J_TYPES, LIBND4J_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "sortTadKey(...) failed"); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void sortTadByValue(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dX, Nd4jLong const* dXShapeInfo, - void *y, Nd4jLong const* yShapeInfo, - void *dy, Nd4jLong const* dyShapeInfo, - int *dimension, - int dimensionLength, - bool descending) { - try { - auto stream = reinterpret_cast(extraPointers[1]); - auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext() - : reinterpret_cast(extraPointers[0]); - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048); - auto xType = sd::ArrayOptions::dataType(yShapeInfo); - auto yType = sd::ArrayOptions::dataType(xShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, - (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), - LIBND4J_TYPES, LIBND4J_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "sortTadValue(...) failed"); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - - -void sortTad(Nd4jPointer *extraPointers, - void *x, Nd4jLong const* xShapeInfo, - void *dX, Nd4jLong const* dXShapeInfo, - int *dimension, - int dimensionLength, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets, - bool descending) { - try { - // to be implemented - auto stream = reinterpret_cast(extraPointers[1]); - auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext() - : reinterpret_cast(extraPointers[0]); - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - dim3 launchDims((int) tadPack.numberOfTads(), 512, 33768); - auto xType = sd::ArrayOptions::dataType(xShapeInfo); - BUILD_SINGLE_SELECTOR(xType, oesTadGeneric, - (launchDims, stream, dX, dXShapeInfo, nullptr, dimensionLength, tadShapeInfo, tadOffsets, descending), - LIBND4J_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "sortTad(...) failed"); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } + } else { + int numThreads = sd::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) numBlocks++; + + numBlocks = sd::math::nd4j_min(512, numBlocks); + dim3 launchDims(numBlocks, numThreads, 32768); + + int max = 2, dg = 0; + while (max < xLength) { + max <<= 1; + dg++; + } + max <<= 1; + + for (int window = 2; window < max; window <<= 1) { + int n = window; + int rev = 0; + do { + int half = n >> 1; + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, + (launchDims, stream, dy, dyShapeInfo, dX, + dXShapeInfo, n, xLength, rev, descending), + LIBND4J_TYPES, LIBND4J_TYPES); + n >>= 1; + rev = 1; + } while (n > 1); + } } -} + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortTadByKey(Nd4jPointer *extraPointers, void *x, + Nd4jLong const *xShapeInfo, void *dX, + Nd4jLong const *dXShapeInfo, void *y, + Nd4jLong const *yShapeInfo, void *dy, + Nd4jLong const *dyShapeInfo, int *dimension, + int dimensionLength, bool descending) { + try { + auto stream = reinterpret_cast(extraPointers[1]); + auto context = extraPointers[0] == 0 + ? LaunchContext::defaultContext() + : reinterpret_cast(extraPointers[0]); + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dimension, dimensionLength); + dim3 launchDims((int)tadPack.numberOfTads(), 256, 2048); + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + auto yType = sd::ArrayOptions::dataType(yShapeInfo); + BUILD_DOUBLE_SELECTOR( + xType, yType, oesTadGenericKey, + (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, nullptr, + dimensionLength, tadPack.platformShapeInfo(), + tadPack.platformOffsets(), descending), + LIBND4J_TYPES, LIBND4J_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "sortTadKey(...) failed"); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortTadByValue(Nd4jPointer *extraPointers, void *x, + Nd4jLong const *xShapeInfo, void *dX, + Nd4jLong const *dXShapeInfo, void *y, + Nd4jLong const *yShapeInfo, void *dy, + Nd4jLong const *dyShapeInfo, int *dimension, + int dimensionLength, bool descending) { + try { + auto stream = reinterpret_cast(extraPointers[1]); + auto context = extraPointers[0] == 0 + ? LaunchContext::defaultContext() + : reinterpret_cast(extraPointers[0]); + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dimension, dimensionLength); + dim3 launchDims((int)tadPack.numberOfTads(), 256, 2048); + auto xType = sd::ArrayOptions::dataType(yShapeInfo); + auto yType = sd::ArrayOptions::dataType(xShapeInfo); + + BUILD_DOUBLE_SELECTOR( + xType, yType, oesTadGenericKey, + (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, nullptr, + dimensionLength, tadPack.platformShapeInfo(), + tadPack.platformOffsets(), descending), + LIBND4J_TYPES, LIBND4J_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "sortTadValue(...) failed"); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortTad(Nd4jPointer *extraPointers, void *x, Nd4jLong const *xShapeInfo, + void *dX, Nd4jLong const *dXShapeInfo, int *dimension, + int dimensionLength, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending) { + try { + // to be implemented + auto stream = reinterpret_cast(extraPointers[1]); + auto context = extraPointers[0] == 0 + ? LaunchContext::defaultContext() + : reinterpret_cast(extraPointers[0]); + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dimension, dimensionLength); + dim3 launchDims((int)tadPack.numberOfTads(), 512, 33768); + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + BUILD_SINGLE_SELECTOR( + xType, oesTadGeneric, + (launchDims, stream, dX, dXShapeInfo, nullptr, dimensionLength, + tadShapeInfo, tadOffsets, descending), + LIBND4J_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "sortTad(...) failed"); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, + Nd4jLong length, int rank) { + throw std::runtime_error("sortCooIndices:: Not implemented yet"); +} + +Nd4jLong *mmapFile(Nd4jPointer *extraPointers, const char *fileName, + Nd4jLong length) { + return nullptr; +} + +void munmapFile(Nd4jPointer *extraPointers, Nd4jLong *ptrMap, Nd4jLong length) { + +} + +sd::graph::ResultWrapper *executeFlatGraph(Nd4jPointer *extraPointers, + Nd4jPointer flatBufferPointer) { + try { + return sd::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } +} + +Nd4jLong getResultWrapperSize(sd::graph::ResultWrapper *ptr) { + return ptr->size(); +} +Nd4jPointer getResultWrapperPointer(sd::graph::ResultWrapper *ptr) { + return ptr->pointer(); +} + +const char *getAllCustomOps() { + return sd::ops::OpRegistrator::getInstance()->getAllCustomOperations(); +} + +sd::ShapeList *_calculateOutputShapes( + Nd4jPointer *extraPointers, sd::ops::DeclarableOp *op, + Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputShapes, + double *tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, + int numBArgs, int *dArgs, int numDArgs) { + sd::graph::VariableSpace varSpace; + Context block(2, &varSpace); + sd::ShapeList inShapes; + + for (int e = 0; e < numIArgs; e++) block.getIArguments()->push_back(iArgs[e]); -void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank) { - throw std::runtime_error("sortCooIndices:: Not implemented yet"); -} + for (int e = 0; e < numTArgs; e++) block.getTArguments()->push_back(tArgs[e]); -Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) { - return nullptr; -} + for (int e = 0; e < numBArgs; e++) block.getBArguments()->push_back(bArgs[e]); -void munmapFile(Nd4jPointer *extraPointers, Nd4jLong* ptrMap, Nd4jLong length) { + for (int e = 0; e < numDArgs; e++) + block.getDArguments()->push_back((sd::DataType)dArgs[e]); -} + for (int e = 0; e < numInputShapes; e++) { + auto shape_ = reinterpret_cast(inputShapes[e]); + // we shouldn't copy buffer if that's empty array + void *buffer_ = sd::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY + ? nullptr + : inputBuffers[e]; + void *bufferD_ = sd::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY + ? nullptr + : inputBuffers[e + numInputShapes]; -sd::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer) { - try { - return sd::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } -} - -Nd4jLong getResultWrapperSize(sd::graph::ResultWrapper* ptr) { - return ptr->size(); -} -Nd4jPointer getResultWrapperPointer(sd::graph::ResultWrapper* ptr) { - return ptr->pointer(); -} - + auto array = new sd::NDArray(buffer_, bufferD_, shape_); -const char* getAllCustomOps() { - return sd::ops::OpRegistrator::getInstance()->getAllCustomOperations(); -} + // block should contain references to proper variable + varSpace.putVariable(1, e, array); + block.pickInput(1, e); + inShapes.push_back(shape_); + } -sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::DeclarableOp* op, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { - sd::graph::VariableSpace varSpace; - Context block(2, &varSpace); - sd::ShapeList inShapes; + auto shapeList = op->calculateOutputShape(&inShapes, block); - for (int e = 0; e < numIArgs; e++) - block.getIArguments()->push_back(iArgs[e]); + if (varSpace.launchContext()->getWorkspace() != nullptr) shapeList->detach(); - for (int e = 0; e < numTArgs; e++) - block.getTArguments()->push_back(tArgs[e]); - - for (int e = 0; e < numBArgs; e++) - block.getBArguments()->push_back(bArgs[e]); - - for (int e = 0; e < numDArgs; e++) - block.getDArguments()->push_back((sd::DataType) dArgs[e]); - - for (int e = 0; e < numInputShapes; e++) { - auto shape_ = reinterpret_cast(inputShapes[e]); - - // we shouldn't copy buffer if that's empty array - void *buffer_ = sd::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY ? nullptr : inputBuffers[e]; - void *bufferD_ = sd::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY ? nullptr : inputBuffers[e + numInputShapes]; - - auto array = new sd::NDArray(buffer_, bufferD_, shape_); - - // block should contain references to proper variable - varSpace.putVariable(1, e, array); - block.pickInput(1, e); - - inShapes.push_back(shape_); - } - - auto shapeList = op->calculateOutputShape(&inShapes, block); - - if (varSpace.launchContext()->getWorkspace() != nullptr) - shapeList->detach(); - - return shapeList; + return shapeList; } -sd::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { - try { - auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); +sd::ShapeList *calculateOutputShapes2(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, + int numInputShapes, double *tArgs, + int numTArgs, Nd4jLong *iArgs, + int numIArgs, bool *bArgs, int numBArgs, + int *dArgs, int numDArgs) { + try { + auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); - return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, - iArgs, numIArgs, bArgs, numBArgs, dArgs, numDArgs); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } + return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, + numInputShapes, tArgs, numTArgs, iArgs, + numIArgs, bArgs, numBArgs, dArgs, numDArgs); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -sd::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, sd::ops::DeclarableOp* op, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) { - Context block(1); - sd::ShapeList inShapes; +sd::ShapeList *_calculateOutputShapes(Nd4jPointer *extraPointers, + sd::ops::DeclarableOp *op, + Nd4jPointer *inputShapes, + int numInputShapes, double *tArgs, + int numTArgs, Nd4jLong *iArgs, + int numIArgs) { + Context block(1); + sd::ShapeList inShapes; - for (int e = 0; e < numIArgs; e++) - block.getIArguments()->push_back(iArgs[e]); + for (int e = 0; e < numIArgs; e++) block.getIArguments()->push_back(iArgs[e]); - for (int e = 0; e < numTArgs; e++) - block.getTArguments()->push_back(tArgs[e]); + for (int e = 0; e < numTArgs; e++) block.getTArguments()->push_back(tArgs[e]); - for (int e = 0; e < numInputShapes; e++) - inShapes.push_back(reinterpret_cast(inputShapes[e])); + for (int e = 0; e < numInputShapes; e++) + inShapes.push_back(reinterpret_cast(inputShapes[e])); - auto shapeList = op->calculateOutputShape(&inShapes, block); + auto shapeList = op->calculateOutputShape(&inShapes, block); - return shapeList; + return shapeList; } -sd::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) { - try { - auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); +sd::ShapeList *calculateOutputShapes(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer *inputShapes, + int numInputShapes, double *tArgs, + int numTArgs, Nd4jLong *iArgs, + int numIArgs) { + try { + auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); - return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } + return _calculateOutputShapes(extraPointers, op, inputShapes, + numInputShapes, tArgs, numTArgs, iArgs, + numIArgs); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -Nd4jLong getShapeListSize(sd::ShapeList* list) { - return list->size(); -} +Nd4jLong getShapeListSize(sd::ShapeList *list) { return list->size(); } -Nd4jLong const* getShape(sd::ShapeList* list, Nd4jLong i) { - return list->at(i); +Nd4jLong const *getShape(sd::ShapeList *list, Nd4jLong i) { + return list->at(i); } -static FORCEINLINE Nd4jStatus realExec(sd::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) { - if (op == nullptr) - nd4j_printf("Can't find requested operation: [%lld]\n", hash); - - // we're using the same fake nodeId everywhere here - - std::vector inputs(numInputs); - std::vector outputs(numOutputs); - std::vector ttArgs(numTArgs); - std::vector bbArgs(numBArgs); - std::vector iiArgs(numIArgs); - - // filling block now with inputs - for (int e = 0; e < numInputs; e++) { - auto shape = reinterpret_cast(inputShapes[e]); - void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : inputBuffers[e]; - void *bufferD = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : inputBuffers[e + numInputs]; - - inputs[e] = new sd::NDArray(buffer, bufferD, shape); - } - - // if not inplace - transferring output arrays - - if (!isInplace) - for (int e = 0; e < numOutputs; e++) { - // we want to keep original output shape intact - auto shape = shape::copyShape(reinterpret_cast(outputShapes[e])); - void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : outputBuffers[e]; - void *bufferD = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : outputBuffers[e + numOutputs]; +static FORCEINLINE Nd4jStatus +realExec(sd::ops::DeclarableOp *op, Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, + Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs, + double *tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, + bool *bArgs, int numBArgs, bool isInplace) { + if (op == nullptr) + nd4j_printf("Can't find requested operation: [%lld]\n", hash); - // FIXME: revisit this. - bool canNullify = true; - for (int i = 0; i < numInputs; i++) { - void *ibuffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY ? nullptr : inputBuffers[i]; - if (ibuffer == buffer) { - canNullify = false; - break; - } - } + // we're using the same fake nodeId everywhere here - if (canNullify && buffer != nullptr) - memset((uint8_t *) buffer, '\0', shape::length(shape) * DataTypeUtils::sizeOfElement(ArrayOptions::dataType(shape))); + std::vector inputs(numInputs); + std::vector outputs(numOutputs); + std::vector ttArgs(numTArgs); + std::vector bbArgs(numBArgs); + std::vector iiArgs(numIArgs); - auto array = new sd::NDArray(buffer, bufferD, shape); - outputs[e] = array; - } + // filling block now with inputs + for (int e = 0; e < numInputs; e++) { + auto shape = reinterpret_cast(inputShapes[e]); + void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : inputBuffers[e]; + void *bufferD = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : inputBuffers[e + numInputs]; - for (int e = 0; e < numIArgs; e++) - iiArgs[e] = iArgs[e]; + inputs[e] = new sd::NDArray(buffer, bufferD, shape); + } + + // if not inplace - transferring output arrays + + if (!isInplace) + for (int e = 0; e < numOutputs; e++) { + // we want to keep original output shape intact + auto shape = + shape::copyShape(reinterpret_cast(outputShapes[e])); + void *buffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : outputBuffers[e]; + void *bufferD = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : outputBuffers[e + numOutputs]; + + // FIXME: revisit this. + bool canNullify = true; + for (int i = 0; i < numInputs; i++) { + void *ibuffer = sd::ArrayOptions::arrayType(shape) == ArrayType::EMPTY + ? nullptr + : inputBuffers[i]; + if (ibuffer == buffer) { + canNullify = false; + break; + } + } - for (int e = 0; e < numTArgs; e++) - ttArgs[e] = tArgs[e]; + if (canNullify && buffer != nullptr) + memset((uint8_t *)buffer, '\0', + shape::length(shape) * + DataTypeUtils::sizeOfElement(ArrayOptions::dataType(shape))); - for (int e = 0; e < numBArgs; e++) - bbArgs[e] = bArgs[e]; + auto array = new sd::NDArray(buffer, bufferD, shape); + outputs[e] = array; + } + for (int e = 0; e < numIArgs; e++) iiArgs[e] = iArgs[e]; - // hypothetically at this point we have everything filled - auto dZ = op->execute(inputs, outputs, ttArgs, iiArgs, bbArgs, std::vector(), isInplace); - //auto dZ = op->execute(inputs, ttArgs, iiArgs, isInplace); + for (int e = 0; e < numTArgs; e++) ttArgs[e] = tArgs[e]; + for (int e = 0; e < numBArgs; e++) bbArgs[e] = bArgs[e]; - if (!isInplace) - for (int e = 0; e < numOutputs; e++) { - //shape::printShapeInfoLinear("JVM output shape", (int *) outputShapes[e]); - //shape::printShapeInfoLinear("C++ output shape", (int *) outputs[e]->shapeInfo()); - //outputs[e]->printIndexedBuffer("C++ raw output"); - //outputs[e]->printBuffer("C++ indexed output"); + // hypothetically at this point we have everything filled + auto dZ = op->execute(inputs, outputs, ttArgs, iiArgs, bbArgs, + std::vector(), isInplace); + // auto dZ = op->execute(inputs, ttArgs, iiArgs, isInplace); - if (outputs[e]->ordering() != shape::order(reinterpret_cast(outputShapes[e]))) - outputs[e]->streamline(shape::order(reinterpret_cast(outputShapes[e]))); - } + if (!isInplace) + for (int e = 0; e < numOutputs; e++) { + // shape::printShapeInfoLinear("JVM output shape", (int *) + // outputShapes[e]); shape::printShapeInfoLinear("C++ output shape", (int + // *) outputs[e]->shapeInfo()); outputs[e]->printIndexedBuffer("C++ raw + // output"); outputs[e]->printBuffer("C++ indexed output"); + + if (outputs[e]->ordering() != + shape::order(reinterpret_cast(outputShapes[e]))) + outputs[e]->streamline( + shape::order(reinterpret_cast(outputShapes[e]))); + } - for (auto v: inputs) - delete v; + for (auto v : inputs) delete v; - for (auto v: outputs) - delete v; + for (auto v : outputs) delete v; - return Status::OK(); + return Status::OK(); } +int execCustomOp(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, + int numInputs, Nd4jPointer *outputBuffers, + Nd4jPointer *outputShapes, int numOutputs, double *tArgs, + int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, + int numBArgs, bool isInplace) { + try { + auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); -int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) { - try { - auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); - - return realExec(op, extraPointers, hash, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, - numOutputs, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; - } + return realExec(op, extraPointers, hash, inputBuffers, inputShapes, + numInputs, outputBuffers, outputShapes, numOutputs, tArgs, + numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 1; + } } -int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext) { - try { - auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); - auto context = reinterpret_cast(opContext); - - auto result = op->execute(context); +int execCustomOp2(Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer opContext) { + try { + auto op = sd::ops::OpRegistrator::getInstance()->getOperation(hash); + auto context = reinterpret_cast(opContext); - auto res = cudaStreamSynchronize(*context->launchContext()->getCudaStream()); - if (res != 0) - throw sd::cuda_exception::build("customOp execution failed", res); + auto result = op->execute(context); - for (auto v:context->fastpath_in()) { - if (!v->isEmpty()) - v->syncToDevice(); - } + auto res = + cudaStreamSynchronize(*context->launchContext()->getCudaStream()); + if (res != 0) + throw sd::cuda_exception::build("customOp execution failed", res); - for (auto v:context->fastpath_out()) { - if (!v->isEmpty()) - v->syncToDevice(); - } + for (auto v : context->fastpath_in()) { + if (!v->isEmpty()) v->syncToDevice(); + } - return result; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; + for (auto v : context->fastpath_out()) { + if (!v->isEmpty()) v->syncToDevice(); } + + return result; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 1; + } } -int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer) { - try { - auto graph = sd::graph::Graph::fromFlatPointer(flatBufferPointer); +int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, + Nd4jPointer flatBufferPointer) { + try { + auto graph = sd::graph::Graph::fromFlatPointer(flatBufferPointer); - sd::graph::GraphHolder::getInstance()->registerGraph(graphId, graph); + sd::graph::GraphHolder::getInstance()->registerGraph(graphId, graph); - return ND4J_STATUS_OK; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; - } + return ND4J_STATUS_OK; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 1; + } } +static VariablesSet *executeStoredGraphT(Nd4jPointer *extraPointers, + Nd4jLong graphId, + Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, + int *inputIndices, int numInputs) { + auto graph = sd::graph::GraphHolder::getInstance()->pullGraph(graphId); + auto varSpace = graph->variableSpace()->clone(); -static VariablesSet* executeStoredGraphT(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) { - auto graph = sd::graph::GraphHolder::getInstance()->pullGraph(graphId); - auto varSpace = graph->variableSpace()->clone(); - - std::vector handles; + std::vector handles; - for (int e = 0; e < numInputs; e++) { - auto idx = inputIndices[e]; + for (int e = 0; e < numInputs; e++) { + auto idx = inputIndices[e]; - // we'll delete this array later, together with cloned VariableSpace - auto array = new sd::NDArray(inputBuffers[e], reinterpret_cast(inputShapes[e])); - handles.emplace_back(array); + // we'll delete this array later, together with cloned VariableSpace + auto array = new sd::NDArray(inputBuffers[e], + reinterpret_cast(inputShapes[e])); + handles.emplace_back(array); - if (varSpace->hasVariable(idx)) { - auto var = varSpace->getVariable(idx); - if (var->hasNDArray()) - delete var->getNDArray(); + if (varSpace->hasVariable(idx)) { + auto var = varSpace->getVariable(idx); + if (var->hasNDArray()) delete var->getNDArray(); - var->setNDArray(array); - } else - varSpace->putVariable(idx, array); - } + var->setNDArray(array); + } else + varSpace->putVariable(idx, array); + } - auto dZ = sd::graph::GraphExecutioner::execute(graph, varSpace); - auto varSet = new sd::graph::VariablesSet(dZ); + auto dZ = sd::graph::GraphExecutioner::execute(graph, varSpace); + auto varSet = new sd::graph::VariablesSet(dZ); - if (dZ == ND4J_STATUS_OK) { - // pull back results, and provide them - auto outputs = graph->fetchOutputs(); - for (int e = 0; e < outputs->size(); e++) { - // we're only getting variable ID/Index from original grap. values will be taken from cloned workspace - std::pair varId(outputs->at(e)->id(), outputs->at(e)->index()); + if (dZ == ND4J_STATUS_OK) { + // pull back results, and provide them + auto outputs = graph->fetchOutputs(); + for (int e = 0; e < outputs->size(); e++) { + // we're only getting variable ID/Index from original grap. values will be + // taken from cloned workspace + std::pair varId(outputs->at(e)->id(), outputs->at(e)->index()); - auto var = varSpace->getVariable(varId); + auto var = varSpace->getVariable(varId); - varSet->push_back(var->clone()); - } + varSet->push_back(var->clone()); + } - delete outputs; - } + delete outputs; + } - delete varSpace; + delete varSpace; - return varSet; + return varSet; } -VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) { - try { - return executeStoredGraphT(extraPointers, graphId, inputBuffers, inputShapes, inputIndices, numInputs); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } +VariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, + Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, int *inputIndices, + int numInputs) { + try { + return executeStoredGraphT(extraPointers, graphId, inputBuffers, + inputShapes, inputIndices, numInputs); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -Nd4jLong getVariablesSetSize(sd::graph::VariablesSet* set) { - return set->size(); +Nd4jLong getVariablesSetSize(sd::graph::VariablesSet *set) { + return set->size(); } -Nd4jStatus getVariablesSetStatus(sd::graph::VariablesSet* set) { - return set->status(); +Nd4jStatus getVariablesSetStatus(sd::graph::VariablesSet *set) { + return set->status(); } -sd::graph::Variable* getVariable(sd::graph::VariablesSet* set, Nd4jLong i) { - return set->at(i); +sd::graph::Variable *getVariable(sd::graph::VariablesSet *set, Nd4jLong i) { + return set->at(i); } -int getVariableId(sd::graph::Variable* variable) { - return variable->id(); -} +int getVariableId(sd::graph::Variable *variable) { return variable->id(); } -int getVariableIndex(sd::graph::Variable* variable) { - return variable->index(); +int getVariableIndex(sd::graph::Variable *variable) { + return variable->index(); } -const char* getVariableName(sd::graph::Variable* variable) { - return variable->getName()->c_str(); +const char *getVariableName(sd::graph::Variable *variable) { + return variable->getName()->c_str(); } -Nd4jLong const* getVariableShape(sd::graph::Variable* variable) { - return variable->getNDArray()->shapeInfo(); +Nd4jLong const *getVariableShape(sd::graph::Variable *variable) { + return variable->getNDArray()->shapeInfo(); } -void* getVariableBuffer(sd::graph::Variable* variable) { - return variable->getNDArray()->buffer(); +void *getVariableBuffer(sd::graph::Variable *variable) { + return variable->getNDArray()->buffer(); } int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) { - try { - sd::graph::GraphHolder::getInstance()->dropGraphAny(graphId); - - return ND4J_STATUS_OK; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; - } + try { + sd::graph::GraphHolder::getInstance()->dropGraphAny(graphId); + + return ND4J_STATUS_OK; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 1; + } } void deletePointerArray(Nd4jPointer pointer) { - Nd4jPointer *ptr = reinterpret_cast(pointer); - delete[] ptr; + Nd4jPointer *ptr = reinterpret_cast(pointer); + delete[] ptr; } void deleteCharArray(Nd4jPointer pointer) { - auto ptr = reinterpret_cast(pointer); - delete[] ptr; + auto ptr = reinterpret_cast(pointer); + delete[] ptr; } void deleteIntArray(Nd4jPointer pointer) { - auto ptr = reinterpret_cast(pointer); - delete[] ptr; + auto ptr = reinterpret_cast(pointer); + delete[] ptr; } void deleteLongArray(Nd4jPointer pointer) { - auto ptr = reinterpret_cast(pointer); - delete[] ptr; + auto ptr = reinterpret_cast(pointer); + delete[] ptr; } -void deleteVariablesSet(sd::graph::VariablesSet* pointer) { - delete pointer; -} +void deleteVariablesSet(sd::graph::VariablesSet *pointer) { delete pointer; } void deleteShapeList(Nd4jPointer shapeList) { - sd::ShapeList* list = reinterpret_cast(shapeList); + sd::ShapeList *list = reinterpret_cast(shapeList); - //list->destroy(); - delete list; + // list->destroy(); + delete list; } -const char* getAllOperations() { - return sd::OpTracker::getInstance()->exportOperations(); +const char *getAllOperations() { + return sd::OpTracker::getInstance()->exportOperations(); } Nd4jPointer getGraphState(Nd4jLong id) { - return (Nd4jPointer) new sd::graph::GraphState(id); + return (Nd4jPointer) new sd::graph::GraphState(id); } - void deleteGraphState(Nd4jPointer state) { - auto stateP = reinterpret_cast(state); - delete stateP; -} - - -Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, sd::graph::GraphState *state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs) { - /** - * That's basically exec, with VariableSpace provided in GraphState: - * depending on operation (i.e. while of if), different logic executors could be used - */ - - auto graph = state->graph(); - auto varSpace = state->variableSpace(); - - // Node is dynamically created, and has nothing beyond it: only inputs and outputs - // this node has id of 0, and inputs are - Node node(OpType_LOGIC, opHash, 0); - - // mapping inputs - for (int e = 0; e < numInputs; e++) { - auto buffer = inputBuffers[e]; - auto shapeInfo = reinterpret_cast(inputShapes[e]); - - auto array = new sd::NDArray(buffer, shapeInfo, varSpace->launchContext()); - - // now we just put array to VarSpace - varSpace->putVariable(0, e, array); - node.pickInput(0, e); - } - - // mapping scopes - for (int e = 0; e < numScopes; e++) { - // we should check scope existence in GraphState/Graph - int scopeId = (int) scopes[e]; - if (!state->hasScope(scopeId)) { - // nd4j_printf("execCustomOpWithScope: referenced scope [%i] doesn't exist\n", scopeId); - return Status::THROW(); - } - node.pickInput(scopeId, 0); - } - - auto dZ = LogicExecutor::processNode(graph, &node); - if (dZ != Status::OK()) - return dZ; - - // mapping outputs - - for (int e = 0; e < numOutputs; e++) { - auto buffer = outputBuffers[e]; - auto shapeInfo = reinterpret_cast(outputShapes[e]); - - NDArray array(buffer, shapeInfo, varSpace->launchContext()); - - // now we just put array to VarSpace to the same ID - //varSpace->putVariable(0, e, array); - - auto t = varSpace->getVariable(0, e)->getNDArray(); - array.assign(t); - } - - // removing input variables - for (int e = 0; e < numInputs; e++) { - varSpace->dropVariable(0, e); - } - - // after some bla-bla-bla we should have Graph and Node for current op - return Status::OK(); -} - - -Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs) { - try { - return execCustomOpWithScope(extraPointers, reinterpret_cast(state), opHash, scopes, - numScopes, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, - numOutputs); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return 1; + auto stateP = reinterpret_cast(state); + delete stateP; +} + +Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, + sd::graph::GraphState *state, Nd4jLong opHash, + Nd4jLong *scopes, int numScopes, + Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, int numInputs, + Nd4jPointer *outputBuffers, + Nd4jPointer *outputShapes, int numOutputs) { + /** + * That's basically exec, with VariableSpace provided in GraphState: + * depending on operation (i.e. while of if), different logic executors could + * be used + */ + + auto graph = state->graph(); + auto varSpace = state->variableSpace(); + + // Node is dynamically created, and has nothing beyond it: only inputs and + // outputs this node has id of 0, and inputs are + Node node(OpType_LOGIC, opHash, 0); + + // mapping inputs + for (int e = 0; e < numInputs; e++) { + auto buffer = inputBuffers[e]; + auto shapeInfo = reinterpret_cast(inputShapes[e]); + + auto array = new sd::NDArray(buffer, shapeInfo, varSpace->launchContext()); + + // now we just put array to VarSpace + varSpace->putVariable(0, e, array); + node.pickInput(0, e); + } + + // mapping scopes + for (int e = 0; e < numScopes; e++) { + // we should check scope existence in GraphState/Graph + int scopeId = (int)scopes[e]; + if (!state->hasScope(scopeId)) { + // nd4j_printf("execCustomOpWithScope: referenced scope [%i] doesn't + // exist\n", scopeId); + return Status::THROW(); } + node.pickInput(scopeId, 0); + } + + auto dZ = LogicExecutor::processNode(graph, &node); + if (dZ != Status::OK()) return dZ; + + // mapping outputs + + for (int e = 0; e < numOutputs; e++) { + auto buffer = outputBuffers[e]; + auto shapeInfo = reinterpret_cast(outputShapes[e]); + + NDArray array(buffer, shapeInfo, varSpace->launchContext()); + + // now we just put array to VarSpace to the same ID + // varSpace->putVariable(0, e, array); + + auto t = varSpace->getVariable(0, e)->getNDArray(); + array.assign(t); + } + + // removing input variables + for (int e = 0; e < numInputs; e++) { + varSpace->dropVariable(0, e); + } + + // after some bla-bla-bla we should have Graph and Node for current op + return Status::OK(); +} + +Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, + Nd4jLong opHash, Nd4jLong *scopes, + int numScopes, Nd4jPointer *inputBuffers, + Nd4jPointer *inputShapes, int numInputs, + Nd4jPointer *outputBuffers, + Nd4jPointer *outputShapes, int numOutputs) { + try { + return execCustomOpWithScope( + extraPointers, reinterpret_cast(state), opHash, + scopes, numScopes, inputBuffers, inputShapes, numInputs, outputBuffers, + outputShapes, numOutputs); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return 1; + } } void deleteResultWrapper(Nd4jPointer ptr) { - // just 0 room for compiler s@!t - auto p = reinterpret_cast(ptr); - delete p; + // just 0 room for compiler s@!t + auto p = reinterpret_cast(ptr); + delete p; } -int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer dX, Nd4jLong const* dXShapeInfo, int N, float threshold) { - throw std::runtime_error("estimateThreshold: Not implemented yet"); +int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer dX, + Nd4jLong const *dXShapeInfo, int N, float threshold) { + throw std::runtime_error("estimateThreshold: Not implemented yet"); } /* * TypeDef: - * void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer dX, long N, int dstType, Nd4jPointer dZ); + * void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer dX, long + * N, int dstType, Nd4jPointer dZ); */ -void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer dX, Nd4jLong N, int dstType, Nd4jPointer dZ) { - try { - auto dx = reinterpret_cast(dX); - auto dz = reinterpret_cast(dZ); - - if (srcType == ND4J_FLOAT8) { - if (dstType == ND4J_FLOAT8) { - // convertKernel(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - - } else if (dstType == ND4J_FLOAT32) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_INT8) { - if (dstType == ND4J_FLOAT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - //convertKernel(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - // TODO: eventually we might want to add it - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_UINT8) { - if (dstType == ND4J_FLOAT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - // TODO: still might want to add - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_FLOAT16) { - if (dstType == ND4J_FLOAT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - // TODO: .... ^^^ - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_THRESHOLD) { - //sd::convertToThreshold(nullptr, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_INT16) { - if (dstType == ND4J_FLOAT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - // TODO... - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else { - printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_FLOAT24) { - - } else if (srcType == ND4J_FLOAT32) { - if (dstType == ND4J_FLOAT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - - } else if (dstType == ND4J_DOUBLE) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_THRESHOLD) { - //sd::convertToThreshold(nullptr, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_DOUBLE) { - if (dstType == ND4J_FLOAT8) { - //sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - - } else if (dstType == ND4J_FLOAT32) { - sd::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - // - } else if (dstType == ND4J_THRESHOLD) { - //sd::convertToThreshold(nullptr, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_THRESHOLD) { - if (dstType == ND4J_FLOAT16) { - //sd::convertFromThreshold(nullptr, dx, N, dz); - } else if (dstType == ND4J_FLOAT32) { - //sd::convertFromThreshold(nullptr, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - //sd::convertFromThreshold(nullptr, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); +void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer dX, Nd4jLong N, + int dstType, Nd4jPointer dZ) { + try { + auto dx = reinterpret_cast(dX); + auto dz = reinterpret_cast(dZ); + + if (srcType == ND4J_FLOAT8) { + if (dstType == ND4J_FLOAT8) { + // convertKernel(extras, dx, N, dz); + } else if (dstType == ND4J_INT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_UINT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, + // N, dz); + } else if (dstType == ND4J_FLOAT16) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_INT16) { + // sd::TypeCast::convertGenericCuda(extras, dx, + // N, dz); + } else if (dstType == ND4J_UINT16) { + // sd::TypeCast::convertGenericCuda(extras, dx, + // N, dz); + } else if (dstType == ND4J_FLOAT24) { + } else if (dstType == ND4J_FLOAT32) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_DOUBLE) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_INT8) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_INT8) { + // convertKernel(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + // TODO: eventually we might want to add it + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_UINT8) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + // TODO: still might want to add + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_FLOAT16) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + // TODO: .... ^^^ + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_THRESHOLD) { + // sd::convertToThreshold(nullptr, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_INT16) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + // TODO... + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else { + printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_FLOAT24) { + } else if (srcType == ND4J_FLOAT32) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + } else if (dstType == ND4J_DOUBLE) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_THRESHOLD) { + // sd::convertToThreshold(nullptr, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_DOUBLE) { + if (dstType == ND4J_FLOAT8) { + // sd::TypeCast::convertGenericCuda(extras, dx, N, + // dz); + } else if (dstType == ND4J_INT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + } else if (dstType == ND4J_FLOAT32) { + sd::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + // + } else if (dstType == ND4J_THRESHOLD) { + // sd::convertToThreshold(nullptr, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else if (srcType == ND4J_THRESHOLD) { + if (dstType == ND4J_FLOAT16) { + // sd::convertFromThreshold(nullptr, dx, N, dz); + } else if (dstType == ND4J_FLOAT32) { + // sd::convertFromThreshold(nullptr, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + // sd::convertFromThreshold(nullptr, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); + } + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, + dstType); } + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int length) { - auto u = new sd::utf8string(string, length); - return reinterpret_cast(u); +Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, + int length) { + auto u = new sd::utf8string(string, length); + return reinterpret_cast(u); } Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr) { - return reinterpret_cast(ptr)->_length; + return reinterpret_cast(ptr)->_length; } -char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) { - return reinterpret_cast(ptr)->_buffer; +char *getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) { + return reinterpret_cast(ptr)->_buffer; } void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) { - delete(reinterpret_cast(ptr)); + delete (reinterpret_cast(ptr)); } /////////////////////////////////////////////////////////////////// -template -__global__ static void scatterUpdateCuda(const int opCode, const int numOfSubArrs, - void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong *xOffsets, - void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, - const void* vindexes) { - - __shared__ T *x, *y; - __shared__ Nd4jLong arrLenX, arrLenY; - auto indexes = reinterpret_cast(vindexes); - - for (int e = 0; e < numOfSubArrs; e++ ) { - - const auto xIndex = indexes[e]; - const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex : blockIdx.x == xIndex % gridDim.x; - - if (!isOwner) - continue; +template +__global__ static void scatterUpdateCuda(const int opCode, + const int numOfSubArrs, void *vx, + const Nd4jLong *xShapeInfo, + const Nd4jLong *xOffsets, void *vy, + const Nd4jLong *yShapeInfo, + const Nd4jLong *yOffsets, + const void *vindexes) { + __shared__ T *x, *y; + __shared__ Nd4jLong arrLenX, arrLenY; + auto indexes = reinterpret_cast(vindexes); + + for (int e = 0; e < numOfSubArrs; e++) { + const auto xIndex = indexes[e]; + const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex + : blockIdx.x == xIndex % gridDim.x; + + if (!isOwner) continue; + + if (threadIdx.x == 0) { + x = reinterpret_cast(vx) + xOffsets[xIndex]; + y = reinterpret_cast(vy) + yOffsets[e]; + arrLenX = shape::length(xShapeInfo); + arrLenY = shape::length(yShapeInfo); + } + __syncthreads(); - if (threadIdx.x == 0) { - x = reinterpret_cast(vx) + xOffsets[xIndex]; - y = reinterpret_cast(vy) + yOffsets[e]; - arrLenX = shape::length(xShapeInfo); - arrLenY = shape::length(yShapeInfo); - } - __syncthreads(); - - if (arrLenX != arrLenY) - return; - - for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { - - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const auto yOffset = shape::getIndexOffset(i, yShapeInfo); - - switch (opCode) { - case 0: - x[xOffset] += y[yOffset]; - break; - case 1: - x[xOffset] -= y[yOffset]; - break; - case 2: - x[xOffset] *= y[yOffset]; - break; - case 3: - x[xOffset] /= y[yOffset]; - break; - case 4: - x[xOffset] = y[yOffset] - x[xOffset]; - break; - case 5: - x[xOffset] = y[yOffset] / x[xOffset]; - break; - case 6: - x[xOffset] = y[yOffset]; - break; - default: - continue; - } - } - __syncthreads(); + if (arrLenX != arrLenY) return; + + for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { + const auto xOffset = shape::getIndexOffset(i, xShapeInfo); + const auto yOffset = shape::getIndexOffset(i, yShapeInfo); + + switch (opCode) { + case 0: + x[xOffset] += y[yOffset]; + break; + case 1: + x[xOffset] -= y[yOffset]; + break; + case 2: + x[xOffset] *= y[yOffset]; + break; + case 3: + x[xOffset] /= y[yOffset]; + break; + case 4: + x[xOffset] = y[yOffset] - x[xOffset]; + break; + case 5: + x[xOffset] = y[yOffset] / x[xOffset]; + break; + case 6: + x[xOffset] = y[yOffset]; + break; + default: + continue; + } } + __syncthreads(); + } } -template -__host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfSubArrs, void* vx, const Nd4jLong const* xShapeInfo, const Nd4jLong* xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const void* indexes) { - - scatterUpdateCuda<<<512, 256, MAX_NUM_THREADS, *stream>>>(opCode, numOfSubArrs, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, indexes); +template +__host__ static void scatterUpdateCudaLauncher( + const cudaStream_t *stream, const int opCode, const int numOfSubArrs, + void *vx, const Nd4jLong const *xShapeInfo, const Nd4jLong *xOffsets, + void *vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, + const void *indexes) { + scatterUpdateCuda<<<512, 256, MAX_NUM_THREADS, *stream>>>( + opCode, numOfSubArrs, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, + indexes); } - ////////////////////////////////////////////////////////////////////////// void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, - void* hX, Nd4jLong const* hXShapeInfo, Nd4jLong const* hXOffsets, - void* dX, Nd4jLong const* dXShapeInfo, Nd4jLong const* dXOffsets, - void* hY, Nd4jLong const* hYShapeInfo, Nd4jLong const* hYOffsets, - void* dY, Nd4jLong const* dYShapeInfo, Nd4jLong const* dYOffsets, - void* hIindexes, Nd4jLong const* hIndicesShapeInfo, void* dIindexes, Nd4jLong const* dIndicesShapeInfo) { - try { - auto stream = reinterpret_cast(extraPointers[1]); - - auto type = ArrayOptions::dataType(hXShapeInfo); - auto iType = ArrayOptions::dataType(hIndicesShapeInfo); - - BUILD_DOUBLE_SELECTOR(type, iType, scatterUpdateCudaLauncher, - (stream, opCode, numOfSubArrs, dX, dXShapeInfo, dXOffsets, dY, dYShapeInfo, dYOffsets, dIindexes), - LIBND4J_TYPES, INDEXING_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "scatterUpdate(...) failed"); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo) { - try { - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - auto p = reinterpret_cast(debugInfo); - NDArray array(buffer, specialBuffer, shapeInfo, &lc); - sd::DebugHelper::retrieveDebugStatistics(p, &array); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } -} - -void __global__ tryPointerKernel(void* p, int len) { - auto buf = reinterpret_cast(p); - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - __shared__ int b; - if (tid < len) - atomicAdd(&b, buf[tid]); - - __syncthreads(); - - if (threadIdx.x ==0 && blockIdx.x == 0) - printf("Pointer check complete: %i\n", b); + void *hX, Nd4jLong const *hXShapeInfo, + Nd4jLong const *hXOffsets, void *dX, + Nd4jLong const *dXShapeInfo, Nd4jLong const *dXOffsets, + void *hY, Nd4jLong const *hYShapeInfo, + Nd4jLong const *hYOffsets, void *dY, + Nd4jLong const *dYShapeInfo, Nd4jLong const *dYOffsets, + void *hIindexes, Nd4jLong const *hIndicesShapeInfo, + void *dIindexes, Nd4jLong const *dIndicesShapeInfo) { + try { + auto stream = reinterpret_cast(extraPointers[1]); + + auto type = ArrayOptions::dataType(hXShapeInfo); + auto iType = ArrayOptions::dataType(hIndicesShapeInfo); + + BUILD_DOUBLE_SELECTOR(type, iType, scatterUpdateCudaLauncher, + (stream, opCode, numOfSubArrs, dX, dXShapeInfo, + dXOffsets, dY, dYShapeInfo, dYOffsets, dIindexes), + LIBND4J_TYPES, INDEXING_TYPES); + + sd::DebugHelper::checkErrorCode(stream, "scatterUpdate(...) failed"); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, + Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, + Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo) { + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], + extraPointers[3]); + auto p = reinterpret_cast(debugInfo); + NDArray array(buffer, specialBuffer, shapeInfo, &lc); + sd::DebugHelper::retrieveDebugStatistics(p, &array); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } +} + +void __global__ tryPointerKernel(void *p, int len) { + auto buf = reinterpret_cast(p); + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + __shared__ int b; + if (tid < len) atomicAdd(&b, buf[tid]); + + __syncthreads(); + + if (threadIdx.x == 0 && blockIdx.x == 0) + printf("Pointer check complete: %i\n", b); } void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len) { - try { - cudaStream_t stream; - cudaStreamCreate(&stream); + try { + cudaStream_t stream; + cudaStreamCreate(&stream); - tryPointerKernel << < 256, 512, len + 64, stream >> > (p, len); - auto e = cudaStreamSynchronize(stream); + tryPointerKernel<<<256, 512, len + 64, stream>>>(p, len); + auto e = cudaStreamSynchronize(stream); - if (e != 0) - throw sd::cuda_exception::build("tryPointer failed", e); + if (e != 0) throw sd::cuda_exception::build("tryPointer failed", e); - cudaStreamDestroy(stream); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + cudaStreamDestroy(stream); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } int dataTypeFromNpyHeader(void *header) { - return (int) cnpy::dataTypeFromHeader(reinterpret_cast(header)); -} -sd::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, sd::DataType dtype, char order, Nd4jLong ews, bool empty) { - try { - auto buffer = new ConstantDataBuffer(); - *buffer = sd::ConstantShapeHelper::getInstance()->bufferForShapeInfo( - ShapeDescriptor(dtype, order, shape, strides, rank, ews, empty)); - return buffer; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } + return (int)cnpy::dataTypeFromHeader(reinterpret_cast(header)); +} +sd::ConstantDataBuffer *shapeBuffer(int rank, Nd4jLong *shape, + Nd4jLong *strides, sd::DataType dtype, + char order, Nd4jLong ews, bool empty) { + try { + auto buffer = new ConstantDataBuffer(); + *buffer = sd::ConstantShapeHelper::getInstance()->bufferForShapeInfo( + ShapeDescriptor(dtype, order, shape, strides, rank, ews, empty)); + return buffer; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -void deleteShapeBuffer(sd::ConstantDataBuffer* ptr) { - delete ptr; -} +void deleteShapeBuffer(sd::ConstantDataBuffer *ptr) { delete ptr; } -void deleteTadPack(sd::TadPack* ptr) { - delete ptr; -} +void deleteTadPack(sd::TadPack *ptr) { delete ptr; } bool isBlasVersionMatches(int major, int minor, int build) { - auto result = major == Environment::getInstance()->_blasMajorVersion && minor == Environment::getInstance()->_blasMinorVersion && build == Environment::getInstance()->_blasPatchVersion; + auto result = major == Environment::getInstance()->_blasMajorVersion && + minor == Environment::getInstance()->_blasMinorVersion && + build == Environment::getInstance()->_blasPatchVersion; - if (!result) { - nd4j_printf("CUDA/cuBLAS version mismatch. Expected: %i.%i.%i but got %i.%i.%i instead\n", Environment::getInstance()->_blasMajorVersion, Environment::getInstance()->_blasMinorVersion, Environment::getInstance()->_blasPatchVersion, major, minor, build); - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(152); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage("CUDA/cuBLAS version mismatch"); - } + if (!result) { + nd4j_printf( + "CUDA/cuBLAS version mismatch. Expected: %i.%i.%i but got %i.%i.%i " + "instead\n", + Environment::getInstance()->_blasMajorVersion, + Environment::getInstance()->_blasMinorVersion, + Environment::getInstance()->_blasPatchVersion, major, minor, build); + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(152); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + "CUDA/cuBLAS version mismatch"); + } - return result; + return result; } -sd::ConstantDataBuffer* constantBufferLong(sd::DataType dtype, Nd4jLong const* data, int length) { - return sd::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype); +sd::ConstantDataBuffer *constantBufferLong(sd::DataType dtype, + Nd4jLong const *data, int length) { + return sd::ConstantHelper::getInstance()->constantBuffer( + ConstantDescriptor(data, length), dtype); } -sd::ConstantDataBuffer* constantBufferDouble(sd::DataType dtype, double *data, int length) { - return sd::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype); +sd::ConstantDataBuffer *constantBufferDouble(sd::DataType dtype, double *data, + int length) { + return sd::ConstantHelper::getInstance()->constantBuffer( + ConstantDescriptor(data, length), dtype); } -sd::ConstantDataBuffer* constantBuffer(sd::DataType dtype, sd::ConstantDescriptor *descriptor) { - return sd::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype); +sd::ConstantDataBuffer *constantBuffer(sd::DataType dtype, + sd::ConstantDescriptor *descriptor) { + return sd::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype); } - -Nd4jPointer getConstantDataBufferPrimary(sd::ConstantDataBuffer* dbf) { - return dbf->primary(); +Nd4jPointer getConstantDataBufferPrimary(sd::ConstantDataBuffer *dbf) { + return dbf->primary(); } -Nd4jPointer getConstantDataBufferSpecial(sd::ConstantDataBuffer* dbf) { - return dbf->special(); +Nd4jPointer getConstantDataBufferSpecial(sd::ConstantDataBuffer *dbf) { + return dbf->special(); } -Nd4jLong getConstantDataBufferLength(sd::ConstantDataBuffer* dbf) { - return dbf->length(); +Nd4jLong getConstantDataBufferLength(sd::ConstantDataBuffer *dbf) { + return dbf->length(); } -Nd4jLong getConstantDataBufferSizeOf(sd::ConstantDataBuffer* dbf) { - return dbf->sizeOf(); +Nd4jLong getConstantDataBufferSizeOf(sd::ConstantDataBuffer *dbf) { + return dbf->sizeOf(); } - -sd::graph::Context* createGraphContext(int nodeId) { - return new sd::graph::Context(nodeId); +sd::graph::Context *createGraphContext(int nodeId) { + return new sd::graph::Context(nodeId); } -sd::graph::RandomGenerator* getGraphContextRandomGenerator(sd::graph::Context* ptr) { - return &ptr->randomGenerator(); +sd::graph::RandomGenerator *getGraphContextRandomGenerator( + sd::graph::Context *ptr) { + return &ptr->randomGenerator(); } -void markGraphContextInplace(sd::graph::Context* ptr, bool reallyInplace) { - ptr->markInplace(reallyInplace); +void markGraphContextInplace(sd::graph::Context *ptr, bool reallyInplace) { + ptr->markInplace(reallyInplace); } -void setGraphContextCudaContext(sd::graph::Context* ptr, void *stream, void *reductionPointer, void *allocationPointer) { - ptr->setCudaContext(stream, reductionPointer, allocationPointer); +void setGraphContextCudaContext(sd::graph::Context *ptr, void *stream, + void *reductionPointer, + void *allocationPointer) { + ptr->setCudaContext(stream, reductionPointer, allocationPointer); } -void setGraphContextInputArray(sd::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { - ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); +void setGraphContextInputArray(sd::graph::Context *ptr, int index, void *buffer, + void *shapeInfo, void *specialBuffer, + void *specialShapeInfo) { + ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); } -void setGraphContextOutputArray(sd::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { - ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); +void setGraphContextOutputArray(sd::graph::Context *ptr, int index, + void *buffer, void *shapeInfo, + void *specialBuffer, void *specialShapeInfo) { + ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, + specialShapeInfo); } -void setGraphContextInputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { - ptr->setInputArray(index, buffer, shapeInfo, specialShapeInfo); +void setGraphContextInputBuffer(OpaqueContext *ptr, int index, + OpaqueDataBuffer *buffer, void *shapeInfo, + void *specialShapeInfo) { + ptr->setInputArray(index, buffer, shapeInfo, specialShapeInfo); } -void setGraphContextOutputBuffer(OpaqueContext* ptr, int index, OpaqueDataBuffer *buffer, void *shapeInfo, void *specialShapeInfo) { - ptr->setOutputArray(index, buffer, shapeInfo, specialShapeInfo); +void setGraphContextOutputBuffer(OpaqueContext *ptr, int index, + OpaqueDataBuffer *buffer, void *shapeInfo, + void *specialShapeInfo) { + ptr->setOutputArray(index, buffer, shapeInfo, specialShapeInfo); } -void setGraphContextTArguments(sd::graph::Context* ptr, double *arguments, int numberOfArguments) { - ptr->setTArguments(arguments, numberOfArguments); +void setGraphContextTArguments(sd::graph::Context *ptr, double *arguments, + int numberOfArguments) { + ptr->setTArguments(arguments, numberOfArguments); } -void setGraphContextIArguments(sd::graph::Context* ptr, Nd4jLong *arguments, int numberOfArguments) { - ptr->setIArguments(arguments, numberOfArguments); +void setGraphContextIArguments(sd::graph::Context *ptr, Nd4jLong *arguments, + int numberOfArguments) { + ptr->setIArguments(arguments, numberOfArguments); } -void setGraphContextBArguments(sd::graph::Context* ptr, bool *arguments, int numberOfArguments) { - ptr->setBArguments(arguments, numberOfArguments); +void setGraphContextBArguments(sd::graph::Context *ptr, bool *arguments, + int numberOfArguments) { + ptr->setBArguments(arguments, numberOfArguments); } -void setGraphContextDArguments(OpaqueContext* ptr, int *arguments, int numberOfArguments) { - std::vector dtypes(numberOfArguments); - for (int e = 0; e < numberOfArguments; e++) - dtypes[e] = (sd::DataType) arguments[e]; - - ptr->setDArguments(dtypes); -} +void setGraphContextDArguments(OpaqueContext *ptr, int *arguments, + int numberOfArguments) { + std::vector dtypes(numberOfArguments); + for (int e = 0; e < numberOfArguments; e++) + dtypes[e] = (sd::DataType)arguments[e]; -void deleteGraphContext(sd::graph::Context* ptr) { - delete ptr; + ptr->setDArguments(dtypes); } +void deleteGraphContext(sd::graph::Context *ptr) { delete ptr; } -sd::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { - try { - return new sd::graph::RandomGenerator(rootSeed, nodeSeed); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } -} - -Nd4jLong getRandomGeneratorRootState(sd::graph::RandomGenerator* ptr) { - return ptr->rootState(); +sd::graph::RandomGenerator *createRandomGenerator(Nd4jLong rootSeed, + Nd4jLong nodeSeed) { + try { + return new sd::graph::RandomGenerator(rootSeed, nodeSeed); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } -Nd4jLong getRandomGeneratorNodeState(sd::graph::RandomGenerator* ptr) { - return ptr->nodeState(); +Nd4jLong getRandomGeneratorRootState(sd::graph::RandomGenerator *ptr) { + return ptr->rootState(); } -void setRandomGeneratorStates(sd::graph::RandomGenerator* ptr, Nd4jLong rootSeed, Nd4jLong nodeSeed) { - ptr->setStates(rootSeed, nodeSeed); +Nd4jLong getRandomGeneratorNodeState(sd::graph::RandomGenerator *ptr) { + return ptr->nodeState(); } -int getRandomGeneratorRelativeInt(sd::graph::RandomGenerator* ptr, Nd4jLong index) { - return ptr->relativeInt(index); +void setRandomGeneratorStates(sd::graph::RandomGenerator *ptr, + Nd4jLong rootSeed, Nd4jLong nodeSeed) { + ptr->setStates(rootSeed, nodeSeed); } -Nd4jLong getRandomGeneratorRelativeLong(sd::graph::RandomGenerator* ptr, Nd4jLong index) { - return ptr->relativeLong(index); +int getRandomGeneratorRelativeInt(sd::graph::RandomGenerator *ptr, + Nd4jLong index) { + return ptr->relativeInt(index); } -void deleteRandomGenerator(sd::graph::RandomGenerator* ptr) { - delete ptr; +Nd4jLong getRandomGeneratorRelativeLong(sd::graph::RandomGenerator *ptr, + Nd4jLong index) { + return ptr->relativeLong(index); } +void deleteRandomGenerator(sd::graph::RandomGenerator *ptr) { delete ptr; } Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) { - try { - cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); - unsigned int shapeSize = arr.shape.size(); - std::vector shape(shapeSize); - bool _empty = false; - for (unsigned int i = 0; i < shapeSize; i++) { - shape[i] = arr.shape[i]; - - if (arr.shape[i] == 0) - _empty = true; - } - - auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); - - Nd4jLong *shapeBuffer; - if (shape.size() == 1 && shape[0] == 0) { - // scalar case - shapeBuffer = sd::ShapeBuilders::createScalarShapeInfo(dtype); - } else if (_empty) { - if (shapeSize > 0) - shapeBuffer = sd::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); - else - shapeBuffer = sd::ShapeBuilders::emptyShapeInfo(dtype); - } else { - shapeBuffer = sd::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); - } - return (Nd4jPointer)(sd::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true)); // TO DO: this can lead to unpleasant crash sometimes - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } -} - -const char* runLightBenchmarkSuit(bool printOut) { - try { - sd::LightBenchmarkSuit suit; - auto result = suit.runSuit(); - - if (printOut) - nd4j_printf("%s\n", result.data()); - - auto chars = new char[result.length() + 1]; - std::memcpy(chars, result.data(), result.length()); - chars[result.length()] = (char) 0x0; - - return chars; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; + try { + cnpy::NpyArray arr = + cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); + unsigned int shapeSize = arr.shape.size(); + std::vector shape(shapeSize); + bool _empty = false; + for (unsigned int i = 0; i < shapeSize; i++) { + shape[i] = arr.shape[i]; + + if (arr.shape[i] == 0) _empty = true; } -} - -const char* runFullBenchmarkSuit(bool printOut) { - try { - sd::FullBenchmarkSuit suit; - auto result = suit.runSuit(); - if (printOut) - nd4j_printf("%s\n", result.data()); - - auto chars = new char[result.length() + 1]; - std::memcpy(chars, result.data(), result.length()); - chars[result.length()] = (char) 0x0; - - return chars; - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; + auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); + + Nd4jLong *shapeBuffer; + if (shape.size() == 1 && shape[0] == 0) { + // scalar case + shapeBuffer = sd::ShapeBuilders::createScalarShapeInfo(dtype); + } else if (_empty) { + if (shapeSize > 0) + shapeBuffer = sd::ShapeBuilders::emptyShapeInfo( + dtype, arr.fortranOrder ? 'f' : 'c', shape); + else + shapeBuffer = sd::ShapeBuilders::emptyShapeInfo(dtype); + } else { + shapeBuffer = sd::ShapeBuilders::createShapeInfo( + dtype, arr.fortranOrder ? 'f' : 'c', shape); } + return (Nd4jPointer)( + sd::ConstantShapeHelper::getInstance()->createFromExisting( + shapeBuffer, + true)); // TO DO: this can lead to unpleasant crash sometimes + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } +} + +const char *runLightBenchmarkSuit(bool printOut) { + try { + sd::LightBenchmarkSuit suit; + auto result = suit.runSuit(); + + if (printOut) nd4j_printf("%s\n", result.data()); + + auto chars = new char[result.length() + 1]; + std::memcpy(chars, result.data(), result.length()); + chars[result.length()] = (char)0x0; + + return chars; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } +} + +const char *runFullBenchmarkSuit(bool printOut) { + try { + sd::FullBenchmarkSuit suit; + auto result = suit.runSuit(); + + if (printOut) nd4j_printf("%s\n", result.data()); + + auto chars = new char[result.length() + 1]; + std::memcpy(chars, result.data(), result.length()); + chars[result.length()] = (char)0x0; + + return chars; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } Nd4jLong getCachedMemory(int deviceId) { - return sd::ConstantHelper::getInstance()->getCachedAmount(deviceId); + return sd::ConstantHelper::getInstance()->getCachedAmount(deviceId); } -sd::LaunchContext* defaultLaunchContext() { - return LaunchContext::defaultContext(); +sd::LaunchContext *defaultLaunchContext() { + return LaunchContext::defaultContext(); } -Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc) { - return lc->getScalarPointer(); +Nd4jPointer lcScalarPointer(OpaqueLaunchContext *lc) { + return lc->getScalarPointer(); } -Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc) { - return lc->getReductionPointer(); +Nd4jPointer lcReductionPointer(OpaqueLaunchContext *lc) { + return lc->getReductionPointer(); } -Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc) { - return lc->getAllocationPointer(); +Nd4jPointer lcAllocationPointer(OpaqueLaunchContext *lc) { + return lc->getAllocationPointer(); } -Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc) { - return lc->getCudaStream(); +Nd4jPointer lcExecutionStream(OpaqueLaunchContext *lc) { + return lc->getCudaStream(); } -Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc) { - return lc->getCudaSpecialStream(); +Nd4jPointer lcCopyStream(OpaqueLaunchContext *lc) { + return lc->getCudaSpecialStream(); } -Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) { - return lc->getCublasHandle(); +Nd4jPointer lcBlasHandle(OpaqueLaunchContext *lc) { + return lc->getCublasHandle(); } -Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) { - return lc->getCusolverHandle(); +Nd4jPointer lcSolverHandle(OpaqueLaunchContext *lc) { + return lc->getCusolverHandle(); } int lastErrorCode() { - return sd::LaunchContext::defaultContext()->errorReference()->errorCode(); + return sd::LaunchContext::defaultContext()->errorReference()->errorCode(); } -const char* lastErrorMessage() { - return sd::LaunchContext::defaultContext()->errorReference()->errorMessage(); +const char *lastErrorMessage() { + return sd::LaunchContext::defaultContext()->errorReference()->errorMessage(); } -void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride) { - ptr->setShapeFunctionOverride(reallyOverride); +void ctxShapeFunctionOverride(OpaqueContext *ptr, bool reallyOverride) { + ptr->setShapeFunctionOverride(reallyOverride); } -void ctxPurge(OpaqueContext* ptr) { - ptr->clearFastPath(); -} +void ctxPurge(OpaqueContext *ptr) { ptr->clearFastPath(); } -int binaryLevel() { - return 0; -} +int binaryLevel() { return 0; } -int optimalLevel() { - return 0; -} +int optimalLevel() { return 0; } -bool isMinimalRequirementsMet() { - return true; -} +bool isMinimalRequirementsMet() { return true; } -bool isOptimalRequirementsMet() { - return true; -} +bool isOptimalRequirementsMet() { return true; } -void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) { - ptr->allowHelpers(reallyAllow); +void ctxAllowHelpers(OpaqueContext *ptr, bool reallyAllow) { + ptr->allowHelpers(reallyAllow); } -void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) { - if (execMode < 0 || execMode > 2) - execMode = 0; +void ctxSetExecutionMode(OpaqueContext *ptr, int execMode) { + if (execMode < 0 || execMode > 2) execMode = 0; - ptr->setExecutionMode((samediff::ExecutionMode) execMode); + ptr->setExecutionMode((samediff::ExecutionMode)execMode); } -OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, Nd4jPointer primary, Nd4jPointer special) { - auto buffer = dbAllocateDataBuffer(0, dataType, false); +OpaqueDataBuffer *dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, + Nd4jPointer primary, + Nd4jPointer special) { + auto buffer = dbAllocateDataBuffer(0, dataType, false); - if (primary != nullptr) - buffer->setPrimary(primary, elements); + if (primary != nullptr) buffer->setPrimary(primary, elements); - if (special != nullptr) - buffer->setSpecial(special, elements); + if (special != nullptr) buffer->setSpecial(special, elements); - return buffer; + return buffer; } -OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { - return allocateDataBuffer(elements, dataType, allocateBoth); +OpaqueDataBuffer *dbAllocateDataBuffer(Nd4jLong elements, int dataType, + bool allocateBoth) { + return allocateDataBuffer(elements, dataType, allocateBoth); } -OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { - try { - auto dtype = DataTypeUtils::fromInt(dataType); - return new sd::InteropDataBuffer(elements * DataTypeUtils::sizeOf(dtype), dtype, allocateBoth); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - return nullptr; - } +OpaqueDataBuffer *allocateDataBuffer(Nd4jLong elements, int dataType, + bool allocateBoth) { + try { + auto dtype = DataTypeUtils::fromInt(dataType); + return new sd::InteropDataBuffer(elements * DataTypeUtils::sizeOf(dtype), + dtype, allocateBoth); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer) { - return dataBuffer->primary(); + return dataBuffer->primary(); } Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer) { - return dataBuffer->special(); + return dataBuffer->special(); } -void deleteDataBuffer(OpaqueDataBuffer *dataBuffer) { - delete dataBuffer; -} +void deleteDataBuffer(OpaqueDataBuffer *dataBuffer) { delete dataBuffer; } -void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes) { - dataBuffer->setPrimary(primaryBuffer, numBytes); +void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, + Nd4jLong numBytes) { + dataBuffer->setPrimary(primaryBuffer, numBytes); } -void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, Nd4jLong numBytes) { - dataBuffer->setSpecial(specialBuffer, numBytes); +void dbSetSpecialBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer specialBuffer, + Nd4jLong numBytes) { + dataBuffer->setSpecial(specialBuffer, numBytes); } void dbAllocatePrimaryBuffer(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->allocatePrimary(); + dataBuffer->dataBuffer()->allocatePrimary(); } void dbAllocateSpecialBuffer(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->allocateSpecial(); + dataBuffer->dataBuffer()->allocateSpecial(); } void dbExpandBuffer(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { - try { - dataBuffer->dataBuffer()->expand(elements * DataTypeUtils::sizeOf(dataBuffer->dataBuffer()->getDataType())); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); - } + try { + dataBuffer->dataBuffer()->expand( + elements * + DataTypeUtils::sizeOf(dataBuffer->dataBuffer()->getDataType())); + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + } } -OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset) { - return new InteropDataBuffer(*dataBuffer, length, offset); +OpaqueDataBuffer *dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, + Nd4jLong offset) { + return new InteropDataBuffer(*dataBuffer, length, offset); } void dbSyncToSpecial(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->syncToSpecial(); + dataBuffer->dataBuffer()->syncToSpecial(); } void dbSyncToPrimary(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->syncToPrimary(nullptr); + dataBuffer->dataBuffer()->syncToPrimary(nullptr); } void dbTickHostRead(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->readPrimary(); + dataBuffer->dataBuffer()->readPrimary(); } void dbTickHostWrite(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->writePrimary(); + dataBuffer->dataBuffer()->writePrimary(); } void dbTickDeviceRead(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->readSpecial(); + dataBuffer->dataBuffer()->readSpecial(); } void dbTickDeviceWrite(OpaqueDataBuffer *dataBuffer) { - dataBuffer->dataBuffer()->writeSpecial(); + dataBuffer->dataBuffer()->writeSpecial(); } void dbExpand(OpaqueDataBuffer *dataBuffer, Nd4jLong elements) { - dataBuffer->expand(elements); + dataBuffer->expand(elements); } void dbClose(OpaqueDataBuffer *dataBuffer) { - dataBuffer->getDataBuffer()->close(); + dataBuffer->getDataBuffer()->close(); } -int dbDeviceId(OpaqueDataBuffer *dataBuffer) { - return dataBuffer->deviceId(); -} +int dbDeviceId(OpaqueDataBuffer *dataBuffer) { return dataBuffer->deviceId(); } void dbSetDeviceId(OpaqueDataBuffer *dataBuffer, int deviceId) { - dataBuffer->setDeviceId(deviceId); + dataBuffer->setDeviceId(deviceId); } int dbLocality(OpaqueDataBuffer *dataBuffer) { - auto p = dataBuffer->dataBuffer()->isPrimaryActual(); - auto d = dataBuffer->dataBuffer()->isSpecialActual(); - - if (p && d) - return 0; - else if (p) - return -1; - else - return 1; + auto p = dataBuffer->dataBuffer()->isPrimaryActual(); + auto d = dataBuffer->dataBuffer()->isSpecialActual(); + + if (p && d) + return 0; + else if (p) + return -1; + else + return 1; } \ No newline at end of file diff --git a/libnd4j/include/legacy/impl/Environment.cpp b/libnd4j/include/legacy/impl/Environment.cpp index b19a7147bc61..6058a09e26cc 100644 --- a/libnd4j/include/legacy/impl/Environment.cpp +++ b/libnd4j/include/legacy/impl/Environment.cpp @@ -18,16 +18,18 @@ // Created by raver119 on 06.10.2017. // -#include -#include -#include -#include #include "system/Environment.h" + #include -#include #include #include +#include +#include +#include +#include +#include + #ifdef _OPENMP #include @@ -43,347 +45,318 @@ namespace sd { - sd::Environment::Environment() { - _tadThreshold.store(1); - _elementThreshold.store(1024); - _verbose.store(false); - _debug.store(false); - _profile.store(false); - _precBoost.store(false); - _leaks.store(false); - _dataType.store(sd::DataType::FLOAT32); - _maxThreads = std::thread::hardware_concurrency(); - _maxMasterThreads = _maxThreads.load(); +sd::Environment::Environment() { + _tadThreshold.store(1); + _elementThreshold.store(1024); + _verbose.store(false); + _debug.store(false); + _profile.store(false); + _precBoost.store(false); + _leaks.store(false); + _dataType.store(sd::DataType::FLOAT32); + _maxThreads = std::thread::hardware_concurrency(); + _maxMasterThreads = _maxThreads.load(); #ifndef ANDROID - const char* omp_threads = std::getenv("OMP_NUM_THREADS"); - if (omp_threads != nullptr) { - try { - std::string omp(omp_threads); - int val = std::stoi(omp); - _maxThreads.store(val); - _maxMasterThreads.store(val); - } catch (std::invalid_argument &e) { - // just do nothing - } catch (std::out_of_range &e) { - // still do nothing - } - } - - /** - * Defines size of thread pool used for parallelism - */ - const char* max_threads = std::getenv("SD_MAX_THREADS"); - if (max_threads != nullptr) { - try { - std::string t(max_threads); - int val = std::stoi(t); - _maxThreads.store(val); - } catch (std::invalid_argument &e) { - // just do nothing - } catch (std::out_of_range &e) { - // still do nothing - } - } - - /** - * Defines max number of threads usable at once - */ - const char* max_master_threads = std::getenv("SD_MASTER_THREADS"); - if (max_master_threads != nullptr) { - try { - std::string t(max_master_threads); - int val = std::stoi(t); - _maxMasterThreads.store(val); - } catch (std::invalid_argument &e) { - // just do nothing - } catch (std::out_of_range &e) { - // still do nothing - } - } - - if (_maxMasterThreads.load() > _maxThreads.load()) { - nd4j_printf("Warning! MAX_MASTER_THREADS > MAX_THREADS, tuning them down to match each other\n",""); - _maxMasterThreads.store(_maxThreads.load()); - } - - /** - * If this env var is defined - we'll disallow use of platform-specific helpers (mkldnn, cudnn, etc) - */ - const char* forbid_helpers = std::getenv("SD_FORBID_HELPERS"); - if (max_master_threads != nullptr) { - _allowHelpers = false; - } - - /** - * This var defines max amount of host memory library can allocate - */ - const char* max_primary_memory = std::getenv("SD_MAX_PRIMARY_BYTES"); - if (max_primary_memory != nullptr) { - try { - std::string t(max_primary_memory); - auto val = std::stol(t); - _maxTotalPrimaryMemory.store(val); - } catch (std::invalid_argument &e) { - // just do nothing - } catch (std::out_of_range &e) { - // still do nothing - } - } - - /** - * This var defines max amount of special (i.e. device) memory library can allocate on all devices combined - */ - const char* max_special_memory = std::getenv("SD_MAX_SPECIAL_BYTES"); - if (max_special_memory != nullptr) { - try { - std::string t(max_special_memory); - auto val = std::stol(t); - _maxTotalSpecialMemory.store(val); - } catch (std::invalid_argument &e) { - // just do nothing - } catch (std::out_of_range &e) { - // still do nothing - } - } - - /** - * This var defines max amount of special (i.e. device) memory library can allocate on all devices combined - */ - const char* max_device_memory = std::getenv("SD_MAX_DEVICE_BYTES"); - if (max_device_memory != nullptr) { - try { - std::string t(max_device_memory); - auto val = std::stol(t); - _maxDeviceMemory.store(val); - } catch (std::invalid_argument &e) { - // just do nothing - } catch (std::out_of_range &e) { - // still do nothing - } - } - - const char* blas_fallback = std::getenv("SD_BLAS_FALLBACK"); - if (blas_fallback != nullptr) { - _blasFallback = true; - } + const char *omp_threads = std::getenv("OMP_NUM_THREADS"); + if (omp_threads != nullptr) { + try { + std::string omp(omp_threads); + int val = std::stoi(omp); + _maxThreads.store(val); + _maxMasterThreads.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + /** + * Defines size of thread pool used for parallelism + */ + const char *max_threads = std::getenv("SD_MAX_THREADS"); + if (max_threads != nullptr) { + try { + std::string t(max_threads); + int val = std::stoi(t); + _maxThreads.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + /** + * Defines max number of threads usable at once + */ + const char *max_master_threads = std::getenv("SD_MASTER_THREADS"); + if (max_master_threads != nullptr) { + try { + std::string t(max_master_threads); + int val = std::stoi(t); + _maxMasterThreads.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + if (_maxMasterThreads.load() > _maxThreads.load()) { + nd4j_printf( + "Warning! MAX_MASTER_THREADS > MAX_THREADS, tuning them down to match " + "each other\n", + ""); + _maxMasterThreads.store(_maxThreads.load()); + } + + /** + * If this env var is defined - we'll disallow use of platform-specific + * helpers (mkldnn, cudnn, etc) + */ + const char *forbid_helpers = std::getenv("SD_FORBID_HELPERS"); + if (max_master_threads != nullptr) { + _allowHelpers = false; + } + + /** + * This var defines max amount of host memory library can allocate + */ + const char *max_primary_memory = std::getenv("SD_MAX_PRIMARY_BYTES"); + if (max_primary_memory != nullptr) { + try { + std::string t(max_primary_memory); + auto val = std::stol(t); + _maxTotalPrimaryMemory.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + /** + * This var defines max amount of special (i.e. device) memory library can + * allocate on all devices combined + */ + const char *max_special_memory = std::getenv("SD_MAX_SPECIAL_BYTES"); + if (max_special_memory != nullptr) { + try { + std::string t(max_special_memory); + auto val = std::stol(t); + _maxTotalSpecialMemory.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + /** + * This var defines max amount of special (i.e. device) memory library can + * allocate on all devices combined + */ + const char *max_device_memory = std::getenv("SD_MAX_DEVICE_BYTES"); + if (max_device_memory != nullptr) { + try { + std::string t(max_device_memory); + auto val = std::stol(t); + _maxDeviceMemory.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + const char *blas_fallback = std::getenv("SD_BLAS_FALLBACK"); + if (blas_fallback != nullptr) { + _blasFallback = true; + } #endif #ifdef __CUDABLAS__ - int devCnt = 0; - cudaGetDeviceCount(&devCnt); - auto devProperties = new cudaDeviceProp[devCnt]; - for (int i = 0; i < devCnt; i++) { - cudaSetDevice(i); - cudaGetDeviceProperties(&devProperties[i], i); - - //cudaDeviceSetLimit(cudaLimitStackSize, 4096); - Pair p(devProperties[i].major, devProperties[i].minor); - _capabilities.emplace_back(p); - } - - BlasVersionHelper ver; - _blasMajorVersion = ver._blasMajorVersion; - _blasMinorVersion = ver._blasMinorVersion; - _blasPatchVersion = ver._blasPatchVersion; - - cudaSetDevice(0); - delete[] devProperties; + int devCnt = 0; + cudaGetDeviceCount(&devCnt); + auto devProperties = new cudaDeviceProp[devCnt]; + for (int i = 0; i < devCnt; i++) { + cudaSetDevice(i); + cudaGetDeviceProperties(&devProperties[i], i); + + // cudaDeviceSetLimit(cudaLimitStackSize, 4096); + Pair p(devProperties[i].major, devProperties[i].minor); + _capabilities.emplace_back(p); + } + + BlasVersionHelper ver; + _blasMajorVersion = ver._blasMajorVersion; + _blasMinorVersion = ver._blasMinorVersion; + _blasPatchVersion = ver._blasPatchVersion; + + cudaSetDevice(0); + delete[] devProperties; #else #endif - } +} - bool sd::Environment::blasFallback() { - return _blasFallback; - } +bool sd::Environment::blasFallback() { return _blasFallback; } - sd::Environment::~Environment() { - // - } +sd::Environment::~Environment() { + // +} - void Environment::setMaxPrimaryMemory(uint64_t maxBytes) { - _maxTotalPrimaryMemory = maxBytes; - } +void Environment::setMaxPrimaryMemory(uint64_t maxBytes) { + _maxTotalPrimaryMemory = maxBytes; +} - void Environment::setMaxSpecialyMemory(uint64_t maxBytes) { - _maxTotalSpecialMemory = maxBytes; - } +void Environment::setMaxSpecialyMemory(uint64_t maxBytes) { + _maxTotalSpecialMemory = maxBytes; +} - void Environment::setMaxDeviceMemory(uint64_t maxBytes) { - _maxDeviceMemory = maxBytes; - } +void Environment::setMaxDeviceMemory(uint64_t maxBytes) { + _maxDeviceMemory = maxBytes; +} - Environment *Environment::getInstance() { - if (_instance == 0) - _instance = new Environment(); +Environment *Environment::getInstance() { + if (_instance == 0) _instance = new Environment(); - return _instance; - } + return _instance; +} - bool Environment::isVerbose() { - return _verbose.load(); - } +bool Environment::isVerbose() { return _verbose.load(); } - bool Environment::isExperimentalBuild() { - return _experimental; - } +bool Environment::isExperimentalBuild() { return _experimental; } - sd::DataType Environment::defaultFloatDataType() { - return _dataType.load(); - } +sd::DataType Environment::defaultFloatDataType() { return _dataType.load(); } - std::vector& Environment::capabilities() { - return _capabilities; - } +std::vector &Environment::capabilities() { return _capabilities; } - void Environment::setDefaultFloatDataType(sd::DataType dtype) { - if (dtype != sd::DataType::FLOAT32 && dtype != sd::DataType::DOUBLE && dtype != sd::DataType::FLOAT8 && dtype != sd::DataType::HALF) - throw std::runtime_error("Default Float data type must be one of [FLOAT8, FLOAT16, FLOAT32, DOUBLE]"); +void Environment::setDefaultFloatDataType(sd::DataType dtype) { + if (dtype != sd::DataType::FLOAT32 && dtype != sd::DataType::DOUBLE && + dtype != sd::DataType::FLOAT8 && dtype != sd::DataType::HALF) + throw std::runtime_error( + "Default Float data type must be one of [FLOAT8, FLOAT16, FLOAT32, " + "DOUBLE]"); - _dataType.store(dtype); - } + _dataType.store(dtype); +} - void Environment::setVerbose(bool reallyVerbose) { - _verbose = reallyVerbose; - } +void Environment::setVerbose(bool reallyVerbose) { _verbose = reallyVerbose; } - bool Environment::isDebug() { - return _debug.load(); - } +bool Environment::isDebug() { return _debug.load(); } - bool Environment::isProfiling() { - return _profile.load(); - } +bool Environment::isProfiling() { return _profile.load(); } - bool Environment::isDetectingLeaks() { - return _leaks.load(); - } +bool Environment::isDetectingLeaks() { return _leaks.load(); } - void Environment::setLeaksDetector(bool reallyDetect) { - _leaks.store(reallyDetect); - } +void Environment::setLeaksDetector(bool reallyDetect) { + _leaks.store(reallyDetect); +} - void Environment::setProfiling(bool reallyProfile) { - _profile.store(reallyProfile); - } +void Environment::setProfiling(bool reallyProfile) { + _profile.store(reallyProfile); +} - bool Environment::isDebugAndVerbose() { - return this->isDebug() && this->isVerbose(); - } +bool Environment::isDebugAndVerbose() { + return this->isDebug() && this->isVerbose(); +} - void Environment::setDebug(bool reallyDebug) { - _debug = reallyDebug; - } +void Environment::setDebug(bool reallyDebug) { _debug = reallyDebug; } - int Environment::tadThreshold() { - return _tadThreshold.load(); - } +int Environment::tadThreshold() { return _tadThreshold.load(); } - void Environment::setTadThreshold(int threshold) { - _tadThreshold = threshold; - } +void Environment::setTadThreshold(int threshold) { _tadThreshold = threshold; } - int Environment::elementwiseThreshold() { - return _elementThreshold.load(); - } +int Environment::elementwiseThreshold() { return _elementThreshold.load(); } - void Environment::setElementwiseThreshold(int threshold) { - _elementThreshold = threshold; - } +void Environment::setElementwiseThreshold(int threshold) { + _elementThreshold = threshold; +} - int Environment::maxThreads() { - return _maxThreads.load(); - } +int Environment::maxThreads() { return _maxThreads.load(); } - int Environment::maxMasterThreads() { - return _maxMasterThreads.load(); - } +int Environment::maxMasterThreads() { return _maxMasterThreads.load(); } - void Environment::setMaxThreads(int max) { - // FIXME: not possible at this moment, since maxThreads is limited by number of threads in pool. however we can allocate more threads if we want - //_maxThreads.store(max); - } +void Environment::setMaxThreads(int max) { + // FIXME: not possible at this moment, since maxThreads is limited by number + // of threads in pool. however we can allocate more threads if we want + //_maxThreads.store(max); +} - void Environment::setMaxMasterThreads(int max) { - if (max > maxThreads()) { - max = maxThreads(); - } +void Environment::setMaxMasterThreads(int max) { + if (max > maxThreads()) { + max = maxThreads(); + } - if (max < 1) - return; + if (max < 1) return; - _maxMasterThreads = max; - } + _maxMasterThreads = max; +} - bool Environment::precisionBoostAllowed() { - return _precBoost.load(); - } +bool Environment::precisionBoostAllowed() { return _precBoost.load(); } - void Environment::allowPrecisionBoost(bool reallyAllow) { - _precBoost.store(reallyAllow); - } +void Environment::allowPrecisionBoost(bool reallyAllow) { + _precBoost.store(reallyAllow); +} - bool Environment::isCPU() { +bool Environment::isCPU() { #ifdef __CUDABLAS__ - return false; + return false; #else - return true; + return true; #endif - } +} - int Environment::blasMajorVersion(){ - return _blasMajorVersion; - } +int Environment::blasMajorVersion() { return _blasMajorVersion; } - int Environment::blasMinorVersion(){ - return _blasMinorVersion; - } +int Environment::blasMinorVersion() { return _blasMinorVersion; } - int Environment::blasPatchVersion(){ - return _blasPatchVersion; - } +int Environment::blasPatchVersion() { return _blasPatchVersion; } - bool Environment::helpersAllowed() { - return _allowHelpers.load(); - } +bool Environment::helpersAllowed() { return _allowHelpers.load(); } - void Environment::allowHelpers(bool reallyAllow) { - _allowHelpers.store(reallyAllow); - } +void Environment::allowHelpers(bool reallyAllow) { + _allowHelpers.store(reallyAllow); +} - void Environment::setGroupLimit(int group, Nd4jLong numBytes) { - sd::memory::MemoryCounter::getInstance()->setGroupLimit((sd::memory::MemoryType) group, numBytes); - } +void Environment::setGroupLimit(int group, Nd4jLong numBytes) { + sd::memory::MemoryCounter::getInstance()->setGroupLimit( + (sd::memory::MemoryType)group, numBytes); +} - void Environment::setDeviceLimit(int deviceId, Nd4jLong numBytes) { - sd::memory::MemoryCounter::getInstance()->setDeviceLimit(deviceId, numBytes); - } +void Environment::setDeviceLimit(int deviceId, Nd4jLong numBytes) { + sd::memory::MemoryCounter::getInstance()->setDeviceLimit(deviceId, numBytes); +} - Nd4jLong Environment::getGroupLimit(int group) { - return sd::memory::MemoryCounter::getInstance()->groupLimit((sd::memory::MemoryType) group); - } +Nd4jLong Environment::getGroupLimit(int group) { + return sd::memory::MemoryCounter::getInstance()->groupLimit( + (sd::memory::MemoryType)group); +} - Nd4jLong Environment::getDeviceLimit(int deviceId) { - return sd::memory::MemoryCounter::getInstance()->deviceLimit(deviceId); - } +Nd4jLong Environment::getDeviceLimit(int deviceId) { + return sd::memory::MemoryCounter::getInstance()->deviceLimit(deviceId); +} - Nd4jLong Environment::getGroupCounter(int group) { - return sd::memory::MemoryCounter::getInstance()->allocatedGroup((sd::memory::MemoryType) group); - } +Nd4jLong Environment::getGroupCounter(int group) { + return sd::memory::MemoryCounter::getInstance()->allocatedGroup( + (sd::memory::MemoryType)group); +} - Nd4jLong Environment::getDeviceCounter(int deviceId) { - return sd::memory::MemoryCounter::getInstance()->allocatedDevice(deviceId); - } +Nd4jLong Environment::getDeviceCounter(int deviceId) { + return sd::memory::MemoryCounter::getInstance()->allocatedDevice(deviceId); +} - uint64_t Environment::maxPrimaryMemory() { - return _maxTotalPrimaryMemory.load(); - } +uint64_t Environment::maxPrimaryMemory() { + return _maxTotalPrimaryMemory.load(); +} - uint64_t Environment::maxSpecialMemory() { - return _maxTotalSpecialMemory.load(); - } +uint64_t Environment::maxSpecialMemory() { + return _maxTotalSpecialMemory.load(); +} - sd::Environment *sd::Environment::_instance = 0; +sd::Environment *sd::Environment::_instance = 0; -} +} // namespace sd diff --git a/libnd4j/include/legacy/impl/cnpy.cpp b/libnd4j/include/legacy/impl/cnpy.cpp index c24c016875e0..49595563ae79 100644 --- a/libnd4j/include/legacy/impl/cnpy.cpp +++ b/libnd4j/include/legacy/impl/cnpy.cpp @@ -22,25 +22,25 @@ * THE SOFTWARE. ******************************************************************************/ -//Copyright (C) 2011 Carl Rogers -//Released under MIT License -//license available in LICENSE file, or at http://www.opensource.org/licenses/mit-license.php +// Copyright (C) 2011 Carl Rogers +// Released under MIT License +// license available in LICENSE file, or at +// http://www.opensource.org/licenses/mit-license.php -#include -#include #include +#include #include - +#include /** * * @return */ char cnpy::BigEndianTest() { - unsigned char x[] = {1,0}; - short y = *(short*) x; - return y == 1 ? '<' : '>'; + unsigned char x[] = {1, 0}; + short y = *(short *)x; + return y == 1 ? '<' : '>'; } /** @@ -49,121 +49,145 @@ char cnpy::BigEndianTest() { * @return */ char cnpy::mapType(const std::type_info &t) { - if(t == typeid(float) ) return 'f'; - if(t == typeid(double) ) return 'f'; - if(t == typeid(long double) ) return 'f'; - - if(t == typeid(int) ) return 'i'; - if(t == typeid(char) ) return 'i'; - if(t == typeid(short) ) return 'i'; - if(t == typeid(long) ) return 'i'; - if(t == typeid(long long) ) return 'i'; - - if(t == typeid(unsigned char) ) return 'u'; - if(t == typeid(unsigned short) ) return 'u'; - if(t == typeid(unsigned long) ) return 'u'; - if(t == typeid(unsigned long long) ) return 'u'; - if(t == typeid(unsigned int) ) return 'u'; - - if(t == typeid(bool) ) return 'b'; - - if(t == typeid(std::complex) ) return 'c'; - if(t == typeid(std::complex) ) return 'c'; - if(t == typeid(std::complex) ) return 'c'; - - else return '?'; + if (t == typeid(float)) return 'f'; + if (t == typeid(double)) return 'f'; + if (t == typeid(long double)) return 'f'; + + if (t == typeid(int)) return 'i'; + if (t == typeid(char)) return 'i'; + if (t == typeid(short)) return 'i'; + if (t == typeid(long)) return 'i'; + if (t == typeid(long long)) return 'i'; + + if (t == typeid(unsigned char)) return 'u'; + if (t == typeid(unsigned short)) return 'u'; + if (t == typeid(unsigned long)) return 'u'; + if (t == typeid(unsigned long long)) return 'u'; + if (t == typeid(unsigned int)) return 'u'; + + if (t == typeid(bool)) return 'b'; + + if (t == typeid(std::complex)) return 'c'; + if (t == typeid(std::complex)) return 'c'; + if (t == typeid(std::complex)) + return 'c'; + + else + return '?'; } template char cnpy::mapType() { - if(std::is_same::value) return 'f'; - if(std::is_same::value) return 'f'; - if(std::is_same::value) return 'f'; - if(std::is_same::value) return 'f'; - - if(std::is_same::value) return 'i'; - if(std::is_same::value) return 'i'; - if(std::is_same::value) return 'i'; - if(std::is_same::value) return 'i'; - if(std::is_same::value) return 'i'; - if(std::is_same::value) return 'i'; - if(std::is_same::value) return 'i'; - - if(std::is_same::value) return 'u'; - if(std::is_same::value) return 'u'; - if(std::is_same::value) return 'u'; - if(std::is_same::value) return 'u'; - if(std::is_same::value) return 'u'; - - if(std::is_same::value) return 'b'; - - if(std::is_same, T>::value) return 'c'; - if(std::is_same, T>::value) return 'c'; - if(std::is_same, T>::value) return 'c'; - - else return '?'; + if (std::is_same::value) return 'f'; + if (std::is_same::value) return 'f'; + if (std::is_same::value) return 'f'; + if (std::is_same::value) return 'f'; + + if (std::is_same::value) return 'i'; + if (std::is_same::value) return 'i'; + if (std::is_same::value) return 'i'; + if (std::is_same::value) return 'i'; + if (std::is_same::value) return 'i'; + if (std::is_same::value) return 'i'; + if (std::is_same::value) return 'i'; + + if (std::is_same::value) return 'u'; + if (std::is_same::value) return 'u'; + if (std::is_same::value) return 'u'; + if (std::is_same::value) return 'u'; + if (std::is_same::value) return 'u'; + + if (std::is_same::value) return 'b'; + + if (std::is_same, T>::value) return 'c'; + if (std::is_same, T>::value) return 'c'; + if (std::is_same, T>::value) + return 'c'; + + else + return '?'; } sd::DataType cnpy::dataTypeFromHeader(char *data) { - - // indices for type & data size - const int st = 10; - const int ti = 22; - const int si = 23; - - // read first char to make sure it looks like a header - if (data == nullptr || data[st] != '{') - throw std::runtime_error("cnpy::dataTypeFromHeader() - provided pointer doesn't look like a pointer to numpy header"); - - const auto t = data[ti]; - const auto s = data[si]; - - switch (t) { - case 'b': - return sd::DataType::BOOL; - case 'i': - switch (s) { - case '1': return sd::DataType::INT8; - case '2': return sd::DataType::INT16; - case '4': return sd::DataType::INT32; - case '8': return sd::DataType::INT64; - default: - throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Integer data types import"); - } - case 'f': - switch (s) { - case '1': return sd::DataType::FLOAT8; - case '2': return sd::DataType::HALF; - case '4': return sd::DataType::FLOAT32; - case '8': return sd::DataType::DOUBLE; - default: - throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Float data types import"); - } - case 'u': - switch (s) { - case '1': return sd::DataType::UINT8; - case '2': return sd::DataType::UINT16; - case '4': return sd::DataType::UINT32; - case '8': return sd::DataType::UINT64; - default: - throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Unsigned data types import"); - } - case 'c': - throw std::runtime_error("Import of complex data types isn't supported yet"); + // indices for type & data size + const int st = 10; + const int ti = 22; + const int si = 23; + + // read first char to make sure it looks like a header + if (data == nullptr || data[st] != '{') + throw std::runtime_error( + "cnpy::dataTypeFromHeader() - provided pointer doesn't look like a " + "pointer to numpy header"); + + const auto t = data[ti]; + const auto s = data[si]; + + switch (t) { + case 'b': + return sd::DataType::BOOL; + case 'i': + switch (s) { + case '1': + return sd::DataType::INT8; + case '2': + return sd::DataType::INT16; + case '4': + return sd::DataType::INT32; + case '8': + return sd::DataType::INT64; default: - throw std::runtime_error("Unknown type marker"); - } + throw std::runtime_error( + "Only data sizes of [1, 2, 4, 8] are supported for Integer data " + "types import"); + } + case 'f': + switch (s) { + case '1': + return sd::DataType::FLOAT8; + case '2': + return sd::DataType::HALF; + case '4': + return sd::DataType::FLOAT32; + case '8': + return sd::DataType::DOUBLE; + default: + throw std::runtime_error( + "Only data sizes of [1, 2, 4, 8] are supported for Float data " + "types import"); + } + case 'u': + switch (s) { + case '1': + return sd::DataType::UINT8; + case '2': + return sd::DataType::UINT16; + case '4': + return sd::DataType::UINT32; + case '8': + return sd::DataType::UINT64; + default: + throw std::runtime_error( + "Only data sizes of [1, 2, 4, 8] are supported for Unsigned data " + "types import"); + } + case 'c': + throw std::runtime_error( + "Import of complex data types isn't supported yet"); + default: + throw std::runtime_error("Unknown type marker"); + } } template -std::vector& operator+=(std::vector& lhs, const T rhs) { - //write in little endian - for(char byte = 0; byte < sizeof(T); byte++) { - char val = *((char*)&rhs+byte); - lhs.push_back(val); - } - - return lhs; +std::vector &operator+=(std::vector &lhs, const T rhs) { + // write in little endian + for (char byte = 0; byte < sizeof(T); byte++) { + char val = *((char *)&rhs + byte); + lhs.push_back(val); + } + + return lhs; } /** @@ -172,10 +196,10 @@ std::vector& operator+=(std::vector& lhs, const T rhs) { * @param rhs * @return */ -template<> -std::vector& operator+=(std::vector& lhs, const std::string rhs) { - lhs.insert(lhs.end(),rhs.begin(),rhs.end()); - return lhs; +template <> +std::vector &operator+=(std::vector &lhs, const std::string rhs) { + lhs.insert(lhs.end(), rhs.begin(), rhs.end()); + return lhs; } /** @@ -184,15 +208,15 @@ std::vector& operator+=(std::vector& lhs, const std::string rhs) { * @param rhs * @return */ -template<> -std::vector& operator+=(std::vector& lhs, const char* rhs) { - //write in little endian - size_t len = strlen(rhs); - lhs.reserve(len); - for(size_t byte = 0; byte < len; byte++) { - lhs.push_back(rhs[byte]); - } - return lhs; +template <> +std::vector &operator+=(std::vector &lhs, const char *rhs) { + // write in little endian + size_t len = strlen(rhs); + lhs.reserve(len); + for (size_t byte = 0; byte < len; byte++) { + lhs.push_back(rhs[byte]); + } + return lhs; } /** @@ -200,87 +224,80 @@ std::vector& operator+=(std::vector& lhs, const char* rhs) { * @param path * @return */ -char* cnpy::loadFile(const char *path) { - char* buffer = 0; - long length; - FILE * f = fopen (path, "rb"); //was "rb" - - if (f) { - fseek (f, 0, SEEK_END); - length = ftell (f); - fseek (f, 0, SEEK_SET); - buffer = (char*) malloc ((length+ 1) * sizeof(char)); - - // just getting rid of compiler warning - Nd4jLong fps = 0; - - if (buffer) { - fps += fread (buffer, sizeof(char), length, f); - } - - fclose (f); +char *cnpy::loadFile(const char *path) { + char *buffer = 0; + long length; + FILE *f = fopen(path, "rb"); // was "rb" + + if (f) { + fseek(f, 0, SEEK_END); + length = ftell(f); + fseek(f, 0, SEEK_SET); + buffer = (char *)malloc((length + 1) * sizeof(char)); + + // just getting rid of compiler warning + Nd4jLong fps = 0; + + if (buffer) { + fps += fread(buffer, sizeof(char), length, f); } - buffer[length] = '\0'; - return buffer; -} + fclose(f); + } + buffer[length] = '\0'; + return buffer; +} /** -* Parse the numpy header from -* the given file -* based on the pointers passed in -* @param fp the file to parse from -* @param wordSize the size of -* the individual elements -* @param shape -* @param ndims -* @param fortranOrder -*/ -void cnpy::parseNpyHeaderStr(std::string header, - unsigned int &wordSize, - unsigned int *&shape, - unsigned int &ndims, + * Parse the numpy header from + * the given file + * based on the pointers passed in + * @param fp the file to parse from + * @param wordSize the size of + * the individual elements + * @param shape + * @param ndims + * @param fortranOrder + */ +void cnpy::parseNpyHeaderStr(std::string header, unsigned int &wordSize, + unsigned int *&shape, unsigned int &ndims, bool &fortranOrder) { - - - int loc1, loc2; - - - - //fortran order - loc1 = header.find("fortran_order") + 16; - fortranOrder = (header.substr(loc1,5) == "True" ? true : false); - //shape - loc1 = header.find("("); - loc2 = header.find(")"); - std::string str_shape = header.substr(loc1 + 1,loc2 - loc1 - 1); - if(str_shape[str_shape.size() - 1] == ',') ndims = 1; - else ndims = std::count(str_shape.begin(),str_shape.end(),',')+1; - - shape = new unsigned int[ndims]; - for(unsigned int i = 0; i < ndims; i++) { - loc1 = str_shape.find(","); - shape[i] = atoi(str_shape.substr(0,loc1).c_str()); - str_shape = str_shape.substr(loc1 + 1); - } - - - - //endian, word size, data type - //byte order code | stands for not applicable. - //not sure when this applies except for byte array - loc1 = header.find("descr") + 9; - bool littleEndian = (header[loc1] == '<' || header[loc1] == '|' ? true : false); - assert(littleEndian); - - //char type = header[loc1+1]; - //assert(type == map_type(T)); - - std::string str_ws = header.substr(loc1 + 2); - loc2 = str_ws.find("'"); - wordSize = atoi(str_ws.substr(0,loc2).c_str()); - + int loc1, loc2; + + // fortran order + loc1 = header.find("fortran_order") + 16; + fortranOrder = (header.substr(loc1, 5) == "True" ? true : false); + // shape + loc1 = header.find("("); + loc2 = header.find(")"); + std::string str_shape = header.substr(loc1 + 1, loc2 - loc1 - 1); + if (str_shape[str_shape.size() - 1] == ',') + ndims = 1; + else + ndims = std::count(str_shape.begin(), str_shape.end(), ',') + 1; + + shape = new unsigned int[ndims]; + for (unsigned int i = 0; i < ndims; i++) { + loc1 = str_shape.find(","); + shape[i] = atoi(str_shape.substr(0, loc1).c_str()); + str_shape = str_shape.substr(loc1 + 1); + } + + // endian, word size, data type + // byte order code | stands for not applicable. + // not sure when this applies except for byte array + loc1 = header.find("descr") + 9; + bool littleEndian = + (header[loc1] == '<' || header[loc1] == '|' ? true : false); + assert(littleEndian); + + // char type = header[loc1+1]; + // assert(type == map_type(T)); + + std::string str_ws = header.substr(loc1 + 2); + loc2 = str_ws.find("'"); + wordSize = atoi(str_ws.substr(0, loc2).c_str()); } /** @@ -294,26 +311,17 @@ void cnpy::parseNpyHeaderStr(std::string header, * @param ndims the number of dimensions for the array * @param fortranOrder */ -void cnpy::parseNpyHeader(FILE *fp, - unsigned int &wordSize, - unsigned int *&shape, - unsigned int &ndims, +void cnpy::parseNpyHeader(FILE *fp, unsigned int &wordSize, + unsigned int *&shape, unsigned int &ndims, bool &fortranOrder) { - char buffer[256]; - size_t res = fread(buffer,sizeof(char),11,fp); - if(res != 11) - throw std::runtime_error("parse_npy_header: failed fread"); - std::string header = fgets(buffer,256,fp); - assert(header[header.size() - 1] == '\n'); - cnpy::parseNpyHeaderStr(header, - wordSize, - shape, - ndims, - fortranOrder); + char buffer[256]; + size_t res = fread(buffer, sizeof(char), 11, fp); + if (res != 11) throw std::runtime_error("parse_npy_header: failed fread"); + std::string header = fgets(buffer, 256, fp); + assert(header[header.size() - 1] == '\n'); + cnpy::parseNpyHeaderStr(header, wordSize, shape, ndims, fortranOrder); } - - /** * * @param fp @@ -321,30 +329,27 @@ void cnpy::parseNpyHeader(FILE *fp, * @param global_header_size * @param global_header_offset */ -void cnpy::parseZipFooter(FILE* fp, - unsigned short& nrecs, - unsigned int& global_header_size, - unsigned int& global_header_offset) { - - std::vector footer(22); - fseek(fp, -22, SEEK_END); - size_t res = fread(&footer[0],sizeof(char),22,fp); - if(res != 22) - throw std::runtime_error("parse_zip_footer: failed fread"); - - unsigned short disk_no, disk_start, nrecs_on_disk, comment_len; - disk_no = *(unsigned short*) &footer[4]; - disk_start = *(unsigned short*) &footer[6]; - nrecs_on_disk = *(unsigned short*) &footer[8]; - nrecs = *(unsigned short*) &footer[10]; - global_header_size = *(unsigned int*) &footer[12]; - global_header_offset = *(unsigned int*) &footer[16]; - comment_len = *(unsigned short*) &footer[20]; - - assert(disk_no == 0); - assert(disk_start == 0); - assert(nrecs_on_disk == nrecs); - assert(comment_len == 0); +void cnpy::parseZipFooter(FILE *fp, unsigned short &nrecs, + unsigned int &global_header_size, + unsigned int &global_header_offset) { + std::vector footer(22); + fseek(fp, -22, SEEK_END); + size_t res = fread(&footer[0], sizeof(char), 22, fp); + if (res != 22) throw std::runtime_error("parse_zip_footer: failed fread"); + + unsigned short disk_no, disk_start, nrecs_on_disk, comment_len; + disk_no = *(unsigned short *)&footer[4]; + disk_start = *(unsigned short *)&footer[6]; + nrecs_on_disk = *(unsigned short *)&footer[8]; + nrecs = *(unsigned short *)&footer[10]; + global_header_size = *(unsigned int *)&footer[12]; + global_header_offset = *(unsigned int *)&footer[16]; + comment_len = *(unsigned short *)&footer[20]; + + assert(disk_no == 0); + assert(disk_start == 0); + assert(nrecs_on_disk == nrecs); + assert(comment_len == 0); } /** @@ -353,88 +358,85 @@ void cnpy::parseZipFooter(FILE* fp, * @return the loaded array */ cnpy::NpyArray cnpy::loadNpyFromFile(FILE *fp) { - unsigned int *shape; - unsigned int ndims, wordSize; - bool fortranOrder; - cnpy::parseNpyHeader(fp,wordSize,shape,ndims,fortranOrder); - unsigned long long size = 1; //long long so no overflow when multiplying by word_size - for(unsigned int i = 0;i < ndims;i++) size *= shape[i]; - - cnpy::NpyArray arr; - arr.wordSize = wordSize; - arr.shape = std::vector(shape,shape + ndims); - arr.data = new char[size * wordSize]; - arr.fortranOrder = fortranOrder; - size_t nread = fread(arr.data,wordSize,size,fp); - if(nread != size) - throw std::runtime_error("load_the_npy_file: failed fread"); - return arr; + unsigned int *shape; + unsigned int ndims, wordSize; + bool fortranOrder; + cnpy::parseNpyHeader(fp, wordSize, shape, ndims, fortranOrder); + unsigned long long size = + 1; // long long so no overflow when multiplying by word_size + for (unsigned int i = 0; i < ndims; i++) size *= shape[i]; + + cnpy::NpyArray arr; + arr.wordSize = wordSize; + arr.shape = std::vector(shape, shape + ndims); + arr.data = new char[size * wordSize]; + arr.fortranOrder = fortranOrder; + size_t nread = fread(arr.data, wordSize, size, fp); + if (nread != size) + throw std::runtime_error("load_the_npy_file: failed fread"); + return arr; } - - /** - * - * @param data - * @return - */ -cnpy::NpyArray cnpy::loadNpyFromPointer(char *data) { - //move the pointer forward by 11 imitating - //the seek in loading directly from a file - return cnpy::loadNpyFromHeader(data); + * + * @param data + * @return + */ +cnpy::NpyArray cnpy::loadNpyFromPointer(char *data) { + // move the pointer forward by 11 imitating + // the seek in loading directly from a file + return cnpy::loadNpyFromHeader(data); } /** -* -* @param data -* @return -*/ + * + * @param data + * @return + */ cnpy::NpyArray cnpy::loadNpyFromHeader(char *data) { - // check for magic header - if (data == nullptr) - throw std::runtime_error("NULL pointer doesn't look like a NumPy header"); - - if (data[0] == (char) 0x93) { - std::vector exp({(char) 0x93, 'N', 'U', 'M', 'P', 'Y', (char) 0x01}); - std::vector hdr(data, data+7); - if (hdr != exp) - throw std::runtime_error("Pointer doesn't look like a NumPy header"); - } else - throw std::runtime_error("Pointer doesn't look like a NumPy header"); - - //move passed magic - data += 11; - unsigned int *shape; - unsigned int ndims, wordSize; - bool fortranOrder; - cnpy::parseNpyHeaderStr(std::string(data), - wordSize, - shape, - ndims, - fortranOrder); - //the "real" data starts after the \n - char currChar = data[0]; - int count = 0; - while(currChar != '\n') { - data++; - currChar = data[0]; - count++; - } - - //move pass the \n + // check for magic header + if (data == nullptr) + throw std::runtime_error("NULL pointer doesn't look like a NumPy header"); + + if (data[0] == (char)0x93) { + std::vector exp({(char)0x93, 'N', 'U', 'M', 'P', 'Y', (char)0x01}); + std::vector hdr(data, data + 7); + if (hdr != exp) + throw std::runtime_error("Pointer doesn't look like a NumPy header"); + } else + throw std::runtime_error("Pointer doesn't look like a NumPy header"); + + // move passed magic + data += 11; + unsigned int *shape; + unsigned int ndims, wordSize; + bool fortranOrder; + cnpy::parseNpyHeaderStr(std::string(data), wordSize, shape, ndims, + fortranOrder); + // the "real" data starts after the \n + char currChar = data[0]; + int count = 0; + while (currChar != '\n') { data++; + currChar = data[0]; count++; - - unsigned long long size = 1; //long long so no overflow when multiplying by word_size - for(unsigned int i = 0; i < ndims; i++) size *= shape[i]; - char *cursor = data; - cnpy::NpyArray arr; - arr.wordSize = wordSize; - arr.shape = std::vector(shape,shape + ndims); - delete[] shape; - arr.data = cursor; - arr.fortranOrder = fortranOrder; - return arr; + } + + // move pass the \n + data++; + count++; + + unsigned long long size = + 1; // long long so no overflow when multiplying by word_size + for (unsigned int i = 0; i < ndims; i++) size *= shape[i]; + char *cursor = data; + cnpy::NpyArray arr; + arr.wordSize = wordSize; + arr.shape = std::vector(shape, shape + ndims); + delete[] shape; + arr.data = cursor; + arr.fortranOrder = fortranOrder; + return arr; } /** @@ -443,41 +445,39 @@ cnpy::NpyArray cnpy::loadNpyFromHeader(char *data) { * @return the arrays */ -cnpy::npz_t cnpy::npzLoad(FILE* fp){ - cnpy::npz_t arrays; - - while(1) { - std::vector local_header(30); - size_t headerres = fread(&local_header[0],sizeof(char),30,fp); - if(headerres != 30) - throw std::runtime_error("npz_load: failed fread"); - - //if we've reached the global header, stop reading - if(local_header[2] != 0x03 || local_header[3] != 0x04) break; - - //read in the variable name - unsigned short name_len = *(unsigned short*) &local_header[26]; - std::string varname(name_len,' '); - size_t vname_res = fread(&varname[0],sizeof(char),name_len,fp); - if(vname_res != name_len) - throw std::runtime_error("npz_load: failed fread"); - - //erase the lagging .npy - for (int e = 0; e < 4; e++) - varname.pop_back(); - - //read in the extra field - unsigned short extra_field_len = *(unsigned short*) &local_header[28]; - if(extra_field_len > 0) { - std::vector buff(extra_field_len); - size_t efield_res = fread(&buff[0],sizeof(char),extra_field_len,fp); - if(efield_res != extra_field_len) - throw std::runtime_error("npz_load: failed fread"); - } - - arrays[varname] = loadNpyFromFile(fp); +cnpy::npz_t cnpy::npzLoad(FILE *fp) { + cnpy::npz_t arrays; + + while (1) { + std::vector local_header(30); + size_t headerres = fread(&local_header[0], sizeof(char), 30, fp); + if (headerres != 30) throw std::runtime_error("npz_load: failed fread"); + + // if we've reached the global header, stop reading + if (local_header[2] != 0x03 || local_header[3] != 0x04) break; + + // read in the variable name + unsigned short name_len = *(unsigned short *)&local_header[26]; + std::string varname(name_len, ' '); + size_t vname_res = fread(&varname[0], sizeof(char), name_len, fp); + if (vname_res != name_len) + throw std::runtime_error("npz_load: failed fread"); + + // erase the lagging .npy + for (int e = 0; e < 4; e++) varname.pop_back(); + + // read in the extra field + unsigned short extra_field_len = *(unsigned short *)&local_header[28]; + if (extra_field_len > 0) { + std::vector buff(extra_field_len); + size_t efield_res = fread(&buff[0], sizeof(char), extra_field_len, fp); + if (efield_res != extra_field_len) + throw std::runtime_error("npz_load: failed fread"); } - return arrays; + + arrays[varname] = loadNpyFromFile(fp); + } + return arrays; } /** @@ -486,45 +486,42 @@ cnpy::npz_t cnpy::npzLoad(FILE* fp){ * @return the arrays */ cnpy::npz_t cnpy::npzLoad(std::string fname) { - FILE* fp = fopen(fname.c_str(),"rb"); - - if(!fp) printf("npz_load: Error! Unable to open file %s!\n",fname.c_str()); - assert(fp); - cnpy::npz_t arrays; - while(1) { - std::vector local_header(30); - size_t headerres = fread(&local_header[0],sizeof(char),30,fp); - if(headerres != 30) - throw std::runtime_error("npz_load: failed fread"); - - //if we've reached the global header, stop reading - if(local_header[2] != 0x03 || local_header[3] != 0x04) break; - - //read in the variable name - unsigned short name_len = *(unsigned short*) &local_header[26]; - std::string varname(name_len,' '); - size_t vname_res = fread(&varname[0],sizeof(char),name_len,fp); - if(vname_res != name_len) - throw std::runtime_error("npz_load: failed fread"); - - //erase the lagging .npy - for (int e = 0; e < 4; e++) - varname.pop_back(); - - //read in the extra field - unsigned short extra_field_len = *(unsigned short*) &local_header[28]; - if(extra_field_len > 0) { - std::vector buff(extra_field_len); - size_t efield_res = fread(&buff[0],sizeof(char),extra_field_len,fp); - if(efield_res != extra_field_len) - throw std::runtime_error("npz_load: failed fread"); - } - - arrays[varname] = loadNpyFromFile(fp); + FILE *fp = fopen(fname.c_str(), "rb"); + + if (!fp) printf("npz_load: Error! Unable to open file %s!\n", fname.c_str()); + assert(fp); + cnpy::npz_t arrays; + while (1) { + std::vector local_header(30); + size_t headerres = fread(&local_header[0], sizeof(char), 30, fp); + if (headerres != 30) throw std::runtime_error("npz_load: failed fread"); + + // if we've reached the global header, stop reading + if (local_header[2] != 0x03 || local_header[3] != 0x04) break; + + // read in the variable name + unsigned short name_len = *(unsigned short *)&local_header[26]; + std::string varname(name_len, ' '); + size_t vname_res = fread(&varname[0], sizeof(char), name_len, fp); + if (vname_res != name_len) + throw std::runtime_error("npz_load: failed fread"); + + // erase the lagging .npy + for (int e = 0; e < 4; e++) varname.pop_back(); + + // read in the extra field + unsigned short extra_field_len = *(unsigned short *)&local_header[28]; + if (extra_field_len > 0) { + std::vector buff(extra_field_len); + size_t efield_res = fread(&buff[0], sizeof(char), extra_field_len, fp); + if (efield_res != extra_field_len) + throw std::runtime_error("npz_load: failed fread"); } - fclose(fp); - return arrays; + arrays[varname] = loadNpyFromFile(fp); + } + fclose(fp); + return arrays; } /** @@ -534,149 +531,138 @@ cnpy::npz_t cnpy::npzLoad(std::string fname) { * @return */ cnpy::NpyArray cnpy::npzLoad(std::string fname, std::string varname) { - FILE *fp = fopen(fname.c_str(),"rb"); - - if(!fp) { - printf("npz_load: Error! Unable to open file %s!\n",fname.c_str()); + FILE *fp = fopen(fname.c_str(), "rb"); + + if (!fp) { + printf("npz_load: Error! Unable to open file %s!\n", fname.c_str()); + } + + while (1) { + std::vector local_header(30); + size_t header_res = fread(&local_header[0], sizeof(char), 30, fp); + if (header_res != 30) throw std::runtime_error("npz_load: failed fread"); + + // if we've reached the global header, stop reading + if (local_header[2] != 0x03 || local_header[3] != 0x04) break; + + // read in the variable name + unsigned short name_len = *(unsigned short *)&local_header[26]; + std::string vname(name_len, ' '); + size_t vname_res = fread(&vname[0], sizeof(char), name_len, fp); + if (vname_res != name_len) + throw std::runtime_error("npz_load: failed fread"); + + // erase the lagging .npy + for (int e = 0; e < 4; e++) varname.pop_back(); + + // read in the extra field + unsigned short extra_field_len = *(unsigned short *)&local_header[28]; + fseek(fp, extra_field_len, SEEK_CUR); // skip past the extra field + + if (vname == varname) { + NpyArray array = cnpy::loadNpyFromFile(fp); + fclose(fp); + return array; + } else { + // skip past the data + unsigned int size = *(unsigned int *)&local_header[22]; + fseek(fp, size, SEEK_CUR); } + } - while(1) { - std::vector local_header(30); - size_t header_res = fread(&local_header[0],sizeof(char),30,fp); - if(header_res != 30) - throw std::runtime_error("npz_load: failed fread"); - - //if we've reached the global header, stop reading - if(local_header[2] != 0x03 || local_header[3] != 0x04) break; - - //read in the variable name - unsigned short name_len = *(unsigned short*) &local_header[26]; - std::string vname(name_len,' '); - size_t vname_res = fread(&vname[0],sizeof(char),name_len,fp); - if(vname_res != name_len) - throw std::runtime_error("npz_load: failed fread"); - - //erase the lagging .npy - for (int e = 0; e < 4; e++) - varname.pop_back(); - - //read in the extra field - unsigned short extra_field_len = *(unsigned short*) &local_header[28]; - fseek(fp,extra_field_len,SEEK_CUR); //skip past the extra field - - if(vname == varname) { - NpyArray array = cnpy::loadNpyFromFile(fp); - fclose(fp); - return array; - } - else { - //skip past the data - unsigned int size = *(unsigned int*) &local_header[22]; - fseek(fp,size,SEEK_CUR); - } - } - - fclose(fp); - printf("npz_load: Error! Variable name %s not found in %s!\n",varname.c_str(),fname.c_str()); - throw std::runtime_error("Variable wasn't found in file"); + fclose(fp); + printf("npz_load: Error! Variable name %s not found in %s!\n", + varname.c_str(), fname.c_str()); + throw std::runtime_error("Variable wasn't found in file"); } - - - /** * Load a numpy array from the given file * @param fname the fully qualified path for the file * @return the NpArray for this file */ cnpy::NpyArray cnpy::npyLoad(std::string fname) { - FILE* fp = fopen(fname.c_str(), "rb"); + FILE *fp = fopen(fname.c_str(), "rb"); - if(!fp) { - printf("npy_load: Error! Unable to open file %s!\n",fname.c_str()); - } + if (!fp) { + printf("npy_load: Error! Unable to open file %s!\n", fname.c_str()); + } - NpyArray arr = cnpy::loadNpyFromFile(fp); + NpyArray arr = cnpy::loadNpyFromFile(fp); - fclose(fp); - return arr; + fclose(fp); + return arr; } - /** - * Save the numpy array - * @tparam T - * @param fname the file - * @param data the data for the ndarray - * @param shape the shape of the ndarray - * @param ndims the number of dimensions - * for the ndarray - * @param mode the mode for writing - */ -template -void cnpy::npy_save(std::string fname, - const T* data, - const unsigned int* shape, - const unsigned int ndims, - std::string mode) { - - FILE* fp = NULL; - - if(mode == "a") - fp = fopen(fname.c_str(),"r+b"); - - if(fp) { - //file exists. we need to append to it. read the header, modify the array size - unsigned int word_size, tmp_dims; - unsigned int* tmp_shape = 0; - bool fortran_order; - parseNpyHeader(fp, - word_size, - tmp_shape, - tmp_dims, - fortran_order); - - assert(!fortran_order); - - if(word_size != sizeof(T)) { - std::cout<<"libnpy error: " << fname<< " has word size " << word_size<<" but npy_save appending data sized " << sizeof(T) <<"\n"; - assert( word_size == sizeof(T) ); - } - - if(tmp_dims != ndims) { - std::cout<<"libnpy error: npy_save attempting to append misdimensioned data to "< header = createNpyHeader(data,tmp_shape,ndims); - fwrite(&header[0],sizeof(char),header.size(),fp); - fseek(fp,0,SEEK_END); - - delete[] tmp_shape; + * Save the numpy array + * @tparam T + * @param fname the file + * @param data the data for the ndarray + * @param shape the shape of the ndarray + * @param ndims the number of dimensions + * for the ndarray + * @param mode the mode for writing + */ +template +void cnpy::npy_save(std::string fname, const T *data, const unsigned int *shape, + const unsigned int ndims, std::string mode) { + FILE *fp = NULL; + + if (mode == "a") fp = fopen(fname.c_str(), "r+b"); + + if (fp) { + // file exists. we need to append to it. read the header, modify the array + // size + unsigned int word_size, tmp_dims; + unsigned int *tmp_shape = 0; + bool fortran_order; + parseNpyHeader(fp, word_size, tmp_shape, tmp_dims, fortran_order); + + assert(!fortran_order); + + if (word_size != sizeof(T)) { + std::cout << "libnpy error: " << fname << " has word size " << word_size + << " but npy_save appending data sized " << sizeof(T) << "\n"; + assert(word_size == sizeof(T)); } - else { - fp = fopen(fname.c_str(),"wb"); - std::vector header = createNpyHeader(data,shape,ndims); - fwrite(&header[0],sizeof(char),header.size(),fp); + + if (tmp_dims != ndims) { + std::cout << "libnpy error: npy_save attempting to append misdimensioned " + "data to " + << fname << "\n"; + assert(tmp_dims == ndims); } - unsigned long long nels = 1; - for(int i = 0;i < ndims;i++) nels *= shape[i]; + for (int i = 1; i < ndims; i++) { + if (shape[i] != tmp_shape[i]) { + std::cout + << "libnpy error: npy_save attempting to append misshaped data to " + << fname << "\n"; + assert(shape[i] == tmp_shape[i]); + } + } - fwrite(data,sizeof(T),nels,fp); - fclose(fp); -} + tmp_shape[0] += shape[0]; + fseek(fp, 0, SEEK_SET); + std::vector header = createNpyHeader(data, tmp_shape, ndims); + fwrite(&header[0], sizeof(char), header.size(), fp); + fseek(fp, 0, SEEK_END); + + delete[] tmp_shape; + } else { + fp = fopen(fname.c_str(), "wb"); + std::vector header = createNpyHeader(data, shape, ndims); + fwrite(&header[0], sizeof(char), header.size(), fp); + } + + unsigned long long nels = 1; + for (int i = 0; i < ndims; i++) nels *= shape[i]; + + fwrite(data, sizeof(T), nels, fp); + fclose(fp); +} /** * @@ -686,49 +672,58 @@ void cnpy::npy_save(std::string fname, * @param ndims * @return */ -template +template std::vector cnpy::createNpyHeader(const void *vdata, - const unsigned int *shape, - const unsigned int ndims, - unsigned int wordSize) { - - auto data = reinterpret_cast(vdata); - - std::vector dict; - dict += "{'descr': '"; - dict += sizeof(T) > 1 ? BigEndianTest() : '|'; - dict += mapType(); - dict += tostring(wordSize); - dict += "', 'fortran_order': False, 'shape': ("; - if (ndims > 0) { - dict += tostring(shape[0]); - for (int i = 1; i < ndims; i++) { - dict += ", "; - dict += tostring(shape[i]); - } - - if (ndims == 1) - dict += ","; + const unsigned int *shape, + const unsigned int ndims, + unsigned int wordSize) { + auto data = reinterpret_cast(vdata); + + std::vector dict; + dict += "{'descr': '"; + dict += sizeof(T) > 1 ? BigEndianTest() : '|'; + dict += mapType(); + dict += tostring(wordSize); + dict += "', 'fortran_order': False, 'shape': ("; + if (ndims > 0) { + dict += tostring(shape[0]); + for (int i = 1; i < ndims; i++) { + dict += ", "; + dict += tostring(shape[i]); } - // 0D case still requires close - dict += "), }"; - - //pad with spaces so that preamble+dict is modulo 16 bytes. preamble is 10 bytes. dict needs to end with \n - int remainder = 64 - (10 + dict.size()) % 64; - dict.insert(dict.end(),remainder,' '); - dict.back() = '\n'; - - std::vector header; - header += (char) 0x93; - header += "NUMPY"; - header += (char) 0x01; //major version of numpy format - header += (char) 0x00; //minor version of numpy format - header += (unsigned short) dict.size(); - header.insert(header.end(),dict.begin(),dict.end()); - - return header; + + if (ndims == 1) dict += ","; + } + // 0D case still requires close + dict += "), }"; + + // pad with spaces so that preamble+dict is modulo 16 bytes. preamble is 10 + // bytes. dict needs to end with \n + int remainder = 64 - (10 + dict.size()) % 64; + dict.insert(dict.end(), remainder, ' '); + dict.back() = '\n'; + + std::vector header; + header += (char)0x93; + header += "NUMPY"; + header += (char)0x01; // major version of numpy format + header += (char)0x00; // minor version of numpy format + header += (unsigned short)dict.size(); + header.insert(header.end(), dict.begin(), dict.end()); + + return header; } -BUILD_SINGLE_TEMPLATE(template SD_EXPORT std::vector cnpy::createNpyHeader, (const void *data, const unsigned int *shape, const unsigned int ndims, unsigned int wordSize), LIBND4J_TYPES); -//template SD_EXPORT std::vector cnpy::createNpyHeader(const void *data, const unsigned int *shape, const unsigned int ndims, unsigned int wordSize); -template SD_EXPORT void cnpy::npy_save(std::string fname, const float* data, const unsigned int* shape, const unsigned int ndims, std::string mode); +BUILD_SINGLE_TEMPLATE( + template SD_EXPORT std::vector cnpy::createNpyHeader, + (const void *data, const unsigned int *shape, const unsigned int ndims, + unsigned int wordSize), + LIBND4J_TYPES); +// template SD_EXPORT std::vector cnpy::createNpyHeader(const void +// *data, const unsigned int *shape, const unsigned int ndims, unsigned int +// wordSize); +template SD_EXPORT void cnpy::npy_save(std::string fname, + const float *data, + const unsigned int *shape, + const unsigned int ndims, + std::string mode); diff --git a/libnd4j/include/loops/BroadcastPairwiseConverter.h b/libnd4j/include/loops/BroadcastPairwiseConverter.h index c6160d953129..fed97b469572 100644 --- a/libnd4j/include/loops/BroadcastPairwiseConverter.h +++ b/libnd4j/include/loops/BroadcastPairwiseConverter.h @@ -23,75 +23,122 @@ #include #include + #include namespace sd { ////////////////////////////////////////////////////////////////////////// inline pairwise::Ops fromBroadcastToPairwise(broadcast::Ops op) { - switch (op) { - case broadcast::Add: return pairwise::Add; - case broadcast::Subtract: return pairwise::Subtract; - case broadcast::Multiply: return pairwise::Multiply; - case broadcast::Divide: return pairwise::Divide; - case broadcast::ReverseDivide: return pairwise::ReverseDivide; - case broadcast::ReverseSubtract: return pairwise::ReverseSubtract; - case broadcast::CopyPws: return pairwise::CopyPws; - case broadcast::Pow: return pairwise::Pow; - case broadcast::MinPairwise: return pairwise::MinPairwise; - case broadcast::MaxPairwise: return pairwise::MaxPairwise; - case broadcast::AMinPairwise: return pairwise::AMinPairwise; - case broadcast::AMaxPairwise: return pairwise::AMaxPairwise; - case broadcast::SquaredSubtract: return pairwise::SquaredSubtract; - case broadcast::FloorMod: return pairwise::FloorMod; - case broadcast::FloorDiv: return pairwise::FloorDiv; - case broadcast::ReverseMod: return pairwise::ReverseMod; - case broadcast::SafeDivide: return pairwise::SafeDivide; - case broadcast::Mod: return pairwise::Mod; - case broadcast::TruncateDiv: return pairwise::TruncateDiv; - case broadcast::Atan2: return pairwise::Atan2; - case broadcast::LogicalOr: return pairwise::LogicalOr; - case broadcast::LogicalXor: return pairwise::LogicalXor; - case broadcast::LogicalNot: return pairwise::LogicalNot; - case broadcast::LogicalAnd: return pairwise::LogicalAnd; - case broadcast::PowDerivative: return pairwise::PowDerivative; - default: - throw std::runtime_error("fromBroadcastToPairwise: Not convertible operation"); - } + switch (op) { + case broadcast::Add: + return pairwise::Add; + case broadcast::Subtract: + return pairwise::Subtract; + case broadcast::Multiply: + return pairwise::Multiply; + case broadcast::Divide: + return pairwise::Divide; + case broadcast::ReverseDivide: + return pairwise::ReverseDivide; + case broadcast::ReverseSubtract: + return pairwise::ReverseSubtract; + case broadcast::CopyPws: + return pairwise::CopyPws; + case broadcast::Pow: + return pairwise::Pow; + case broadcast::MinPairwise: + return pairwise::MinPairwise; + case broadcast::MaxPairwise: + return pairwise::MaxPairwise; + case broadcast::AMinPairwise: + return pairwise::AMinPairwise; + case broadcast::AMaxPairwise: + return pairwise::AMaxPairwise; + case broadcast::SquaredSubtract: + return pairwise::SquaredSubtract; + case broadcast::FloorMod: + return pairwise::FloorMod; + case broadcast::FloorDiv: + return pairwise::FloorDiv; + case broadcast::ReverseMod: + return pairwise::ReverseMod; + case broadcast::SafeDivide: + return pairwise::SafeDivide; + case broadcast::Mod: + return pairwise::Mod; + case broadcast::TruncateDiv: + return pairwise::TruncateDiv; + case broadcast::Atan2: + return pairwise::Atan2; + case broadcast::LogicalOr: + return pairwise::LogicalOr; + case broadcast::LogicalXor: + return pairwise::LogicalXor; + case broadcast::LogicalNot: + return pairwise::LogicalNot; + case broadcast::LogicalAnd: + return pairwise::LogicalAnd; + case broadcast::PowDerivative: + return pairwise::PowDerivative; + default: + throw std::runtime_error( + "fromBroadcastToPairwise: Not convertible operation"); + } } ////////////////////////////////////////////////////////////////////////// inline pairwise::BoolOps fromBroadcastToPairwiseBool(broadcast::BoolOps op) { - switch (op) { - case broadcast::EqualTo: return pairwise::EqualTo; - case broadcast::GreaterThan: return pairwise::GreaterThan; - case broadcast::LessThan: return pairwise::LessThan; - case broadcast::Epsilon: return pairwise::Epsilon; - case broadcast::GreaterThanOrEqual: return pairwise::GreaterThanOrEqual; - case broadcast::LessThanOrEqual: return pairwise::LessThanOrEqual; - case broadcast::NotEqualTo: return pairwise::NotEqualTo; - case broadcast::And: return pairwise::And; - case broadcast::Or: return pairwise::Or; - case broadcast::Xor: return pairwise::Xor; - case broadcast::Not: return pairwise::Not; - default: - throw std::runtime_error("fromBroadcastToPairwiseBool: Not convertible operation"); - } + switch (op) { + case broadcast::EqualTo: + return pairwise::EqualTo; + case broadcast::GreaterThan: + return pairwise::GreaterThan; + case broadcast::LessThan: + return pairwise::LessThan; + case broadcast::Epsilon: + return pairwise::Epsilon; + case broadcast::GreaterThanOrEqual: + return pairwise::GreaterThanOrEqual; + case broadcast::LessThanOrEqual: + return pairwise::LessThanOrEqual; + case broadcast::NotEqualTo: + return pairwise::NotEqualTo; + case broadcast::And: + return pairwise::And; + case broadcast::Or: + return pairwise::Or; + case broadcast::Xor: + return pairwise::Xor; + case broadcast::Not: + return pairwise::Not; + default: + throw std::runtime_error( + "fromBroadcastToPairwiseBool: Not convertible operation"); + } } - inline pairwise::IntOps fromBroadcastToPairwiseInt(broadcast::IntOps op) { - switch (op) { - case broadcast::IntOps::IntAnd: return pairwise::IntOps::IntAnd; - case broadcast::IntOps::IntOr: return pairwise::IntOps::IntOr; - case broadcast::IntOps::IntXor: return pairwise::IntOps::IntXor; - case broadcast::IntOps::ShiftLeft: return pairwise::IntOps::ShiftLeft; - case broadcast::IntOps::ShiftRight: return pairwise::IntOps::ShiftRight; - case broadcast::IntOps::CyclicShiftLeft: return pairwise::IntOps::CyclicShiftLeft; - case broadcast::IntOps::CyclicShiftRight: return pairwise::IntOps::CyclicShiftRight; - default: - throw std::runtime_error("fromBroadcastToPairwiseInt: Not convertible operation"); - } - } +inline pairwise::IntOps fromBroadcastToPairwiseInt(broadcast::IntOps op) { + switch (op) { + case broadcast::IntOps::IntAnd: + return pairwise::IntOps::IntAnd; + case broadcast::IntOps::IntOr: + return pairwise::IntOps::IntOr; + case broadcast::IntOps::IntXor: + return pairwise::IntOps::IntXor; + case broadcast::IntOps::ShiftLeft: + return pairwise::IntOps::ShiftLeft; + case broadcast::IntOps::ShiftRight: + return pairwise::IntOps::ShiftRight; + case broadcast::IntOps::CyclicShiftLeft: + return pairwise::IntOps::CyclicShiftLeft; + case broadcast::IntOps::CyclicShiftRight: + return pairwise::IntOps::CyclicShiftRight; + default: + throw std::runtime_error( + "fromBroadcastToPairwiseInt: Not convertible operation"); + } } +} // namespace sd -#endif //SD_BROADCASTPAIRWISECONVERTER_H \ No newline at end of file +#endif // SD_BROADCASTPAIRWISECONVERTER_H \ No newline at end of file diff --git a/libnd4j/include/loops/BroadcastScalarConverter.h b/libnd4j/include/loops/BroadcastScalarConverter.h index 3a23a615b53b..e7730f242672 100644 --- a/libnd4j/include/loops/BroadcastScalarConverter.h +++ b/libnd4j/include/loops/BroadcastScalarConverter.h @@ -22,37 +22,50 @@ #include #include + #include namespace sd { - inline bool isConvertibleToScalar(broadcast::Ops op) { - int opNum = (int) op; - - if (opNum <= 17) - return true; - - return false; - } - - inline scalar::Ops convertToScalar(broadcast::Ops op) { - switch (op) { - case broadcast::Add: return scalar::Add; - case broadcast::Subtract: return scalar::Subtract; - case broadcast::Multiply: return scalar::Multiply; - case broadcast::Divide: return scalar::Divide; - case broadcast::ReverseDivide: return scalar::ReverseDivide; - case broadcast::ReverseSubtract: return scalar::ReverseSubtract; - case broadcast::CopyPws: return scalar::CopyPws; - case broadcast::Pow: return scalar::Pow; - case broadcast::MinPairwise: return scalar::MinPairwise; - case broadcast::MaxPairwise: return scalar::MaxPairwise; - case broadcast::AMinPairwise: return scalar::AMinPairwise; - case broadcast::AMaxPairwise: return scalar::AMaxPairwise; - case broadcast::SquaredSubtract: return scalar::SquaredSubtract; - default: - throw std::runtime_error("Not convertible operation"); - } - } +inline bool isConvertibleToScalar(broadcast::Ops op) { + int opNum = (int)op; + + if (opNum <= 17) return true; + + return false; +} + +inline scalar::Ops convertToScalar(broadcast::Ops op) { + switch (op) { + case broadcast::Add: + return scalar::Add; + case broadcast::Subtract: + return scalar::Subtract; + case broadcast::Multiply: + return scalar::Multiply; + case broadcast::Divide: + return scalar::Divide; + case broadcast::ReverseDivide: + return scalar::ReverseDivide; + case broadcast::ReverseSubtract: + return scalar::ReverseSubtract; + case broadcast::CopyPws: + return scalar::CopyPws; + case broadcast::Pow: + return scalar::Pow; + case broadcast::MinPairwise: + return scalar::MinPairwise; + case broadcast::MaxPairwise: + return scalar::MaxPairwise; + case broadcast::AMinPairwise: + return scalar::AMinPairwise; + case broadcast::AMaxPairwise: + return scalar::AMaxPairwise; + case broadcast::SquaredSubtract: + return scalar::SquaredSubtract; + default: + throw std::runtime_error("Not convertible operation"); + } } +} // namespace sd -#endif //SD_BROADCASTSCALARCONVERTER_H +#endif // SD_BROADCASTSCALARCONVERTER_H diff --git a/libnd4j/include/loops/ReduceType.h b/libnd4j/include/loops/ReduceType.h index ae0bef08d496..3b90585cf60a 100644 --- a/libnd4j/include/loops/ReduceType.h +++ b/libnd4j/include/loops/ReduceType.h @@ -22,15 +22,7 @@ #define SD_REDUCETYPE_H namespace functions { - enum ReduceType { - SUM, - PRODUCT, - MAX, - MIN, - ASUM, - AMAX, - AMIN - }; +enum ReduceType { SUM, PRODUCT, MAX, MIN, ASUM, AMAX, AMIN }; } -#endif //SD_REDUCETYPE_H +#endif // SD_REDUCETYPE_H diff --git a/libnd4j/include/loops/broadcasting.h b/libnd4j/include/loops/broadcasting.h old mode 100755 new mode 100644 index 4f05f0c6e749..c6e3bc7cedc2 --- a/libnd4j/include/loops/broadcasting.h +++ b/libnd4j/include/loops/broadcasting.h @@ -23,13 +23,13 @@ #ifndef BROADCASTING_H_ #define BROADCASTING_H_ -#include +#include #include #include -#include #include +#include #include -#include +#include #ifdef __CUDACC__ #include @@ -39,160 +39,150 @@ #include #endif -#include #include +#include #include "legacy_ops.h" namespace functions { - namespace broadcast { +namespace broadcast { /** * Broadcast operation * for broadcasting a smaller tensor * along long a bigger one. */ - template - class Broadcast { - public: - +template +class Broadcast { + public: #ifdef __CUDABLAS__ - template - static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - template - static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - - template - static __device__ void transformInverseCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - + template + static __device__ void transformCuda( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + static __device__ void transformCuda(const void *x, + const Nd4jLong *xShapeInfo, + const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + template + static __host__ void intermediateBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + static __host__ void intermediateBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo); + + static __host__ void execBroadcast( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, + int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + template + static __device__ void transformInverseCuda( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + static __host__ void intermediateInverseBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static __host__ void execInverseBroadcast( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); #else - static void execInverse(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - sd::LoopKind::Kind loopKind, - uint64_t start, uint64_t stop); - - /** - * CPU execution - * @param x the input - * @param xShapeInfo the x shape information - * @param y the y data - * @param yShapeInfo the y shape information - * @param result the result - * @param resultShapeInfo the result shape information - * @param dimension the dimension to broadcast along long - * @param dimensionLength the length of the dimension buffer - */ - template - static void exec(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - sd::LoopKind::Kind loopKind, - uint64_t start, uint64_t stop); - - template - static void execInverse(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - template - static void exec(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); + static void execInverse( + int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, uint64_t stop); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, sd::LoopKind::Kind loopKind, + uint64_t start, uint64_t stop); + + /** + * CPU execution + * @param x the input + * @param xShapeInfo the x shape information + * @param y the y data + * @param yShapeInfo the y shape information + * @param result the result + * @param resultShapeInfo the result shape information + * @param dimension the dimension to broadcast along long + * @param dimensionLength the length of the dimension buffer + */ + template + static void exec(const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, sd::LoopKind::Kind loopKind, + uint64_t start, uint64_t stop); + + template + static void execInverse( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, uint64_t stop); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + template + static void exec(const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); #endif - }; - } -} +}; +} // namespace broadcast +} // namespace functions #endif /* BROADCASTING_H_ */ diff --git a/libnd4j/include/loops/broadcasting_bool.h b/libnd4j/include/loops/broadcasting_bool.h index 400269c02e3a..fc1a6663d074 100644 --- a/libnd4j/include/loops/broadcasting_bool.h +++ b/libnd4j/include/loops/broadcasting_bool.h @@ -23,13 +23,13 @@ #ifndef BROADCASTING_BOOL_H_ #define BROADCASTING_BOOL_H_ -#include +#include #include #include -#include #include +#include #include -#include +#include #ifdef __CUDACC__ #include @@ -44,148 +44,151 @@ #include "legacy_ops.h" namespace functions { - namespace broadcast { +namespace broadcast { /** * Broadcast operation * for broadcasting a smaller tensor * along long a bigger one. */ - template - class BroadcastBool { - public: - +template +class BroadcastBool { + public: #ifdef __CUDACC__ - template - static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); - - template - static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *result, Nd4jLong const* resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ); - - template - static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); - - static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *result, Nd4jLong const* resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ); - - static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); - - template - static __device__ void transformInverseCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); + template + static __device__ void transformCuda( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, int *dimension, int dimensionLength, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); + + template + static __device__ void transformCuda(const void *x, + const Nd4jLong *xShapeInfo, + const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, + void *extraParams); + + template + static __host__ void intermediateBroadcast( + dim3 launchDims, cudaStream_t *stream, void const *x, + Nd4jLong const *xShapeInfo, void const *y, Nd4jLong const *yShapeInfo, + void *result, Nd4jLong const *resultShapeInfo, void *extraParams, + int *dimension, int dimensionLength, Nd4jLong const *tadOnlyShapeInfo, + Nd4jLong const *tadOffsets, Nd4jLong const *tadOnlyShapeInfoZ, + Nd4jLong const *tadOffsetsZ); + + template + static __host__ void intermediateBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, void *extraParams); + + static __host__ void execBroadcast( + dim3 launchDims, cudaStream_t *stream, int opNum, void const *x, + Nd4jLong const *xShapeInfo, void const *y, Nd4jLong const *yShapeInfo, + void *result, Nd4jLong const *resultShapeInfo, void *extraParams, + int *dimension, int dimensionLength, Nd4jLong const *tadOnlyShapeInfo, + Nd4jLong const *tadOffsets, Nd4jLong const *tadOnlyShapeInfoZ, + Nd4jLong const *tadOffsetsZ); + + static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, + const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, + void *extraParams); + + template + static __device__ void transformInverseCuda( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, int *dimension, int dimensionLength, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); + + template + static __host__ void intermediateInverseBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, void *extraParams, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static __host__ void execInverseBroadcast( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, void *extraParams, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); #else - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); - - static void execInverse(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - /** - * CPU execution - * @param x the input - * @param xShapeInfo the x shape information - * @param y the y data - * @param yShapeInfo the y shape information - * @param result the result - * @param resultShapeInfo the result shape information - * @param dimension the dimension to broadcast along long - * @param dimensionLength the length of the dimension buffer - */ - template - static void exec(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - template - static void exec(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); - - template - static void execInverse(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, void *extraParams, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, + const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, + uint64_t start, uint64_t stop); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams); + + static void execInverse(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, + uint64_t stop); + + /** + * CPU execution + * @param x the input + * @param xShapeInfo the x shape information + * @param y the y data + * @param yShapeInfo the y shape information + * @param result the result + * @param resultShapeInfo the result shape information + * @param dimension the dimension to broadcast along long + * @param dimensionLength the length of the dimension buffer + */ + template + static void exec(const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, void *extraParams, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, + const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, + uint64_t start, uint64_t stop); + + template + static void exec(const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams); + + template + static void execInverse(const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, + uint64_t stop); #endif - }; - } -} +}; +} // namespace broadcast +} // namespace functions #endif /* BROADCASTING_H_ */ diff --git a/libnd4j/include/loops/broadcasting_int.h b/libnd4j/include/loops/broadcasting_int.h index 386fbd3f74ec..f8c377fce1df 100644 --- a/libnd4j/include/loops/broadcasting_int.h +++ b/libnd4j/include/loops/broadcasting_int.h @@ -23,13 +23,13 @@ #ifndef BROADCASTING_INT_H_ #define BROADCASTING_INT_H_ -#include +#include #include #include -#include #include +#include #include -#include +#include #ifdef __CUDACC__ #include @@ -44,148 +44,141 @@ #include "legacy_ops.h" namespace functions { - namespace broadcast { +namespace broadcast { /** * Broadcast operation * for broadcasting a smaller tensor * along long a bigger one. */ - template - class BroadcastInt { - public: - +template +class BroadcastInt { + public: #ifdef __CUDACC__ - template - static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - template - static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - template - static __device__ void transformInverseCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadOnlyShapeInfoZ, const Nd4jLong *tadOffsetsZ); + template + static __device__ void transformCuda( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + static __device__ void transformCuda(const void *x, + const Nd4jLong *xShapeInfo, + const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + template + static __host__ void intermediateBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + static __host__ void intermediateBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo); + + static __host__ void execBroadcast( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, + const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + template + static __device__ void transformInverseCuda( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + static __host__ void intermediateInverseBroadcast( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + static __host__ void execInverseBroadcast( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadOnlyShapeInfoZ, + const Nd4jLong *tadOffsetsZ); #else - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - static void execInverse(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - /** - * CPU execution - * @param x the input - * @param xShapeInfo the x shape information - * @param y the y data - * @param yShapeInfo the y shape information - * @param result the result - * @param resultShapeInfo the result shape information - * @param dimension the dimension to broadcast along long - * @param dimensionLength the length of the dimension buffer - */ - template - static void exec(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); - - template - static void exec(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - template - static void execInverse(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetZ, - uint64_t start, uint64_t stop); + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, uint64_t stop); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + static void execInverse( + int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, uint64_t stop); + + /** + * CPU execution + * @param x the input + * @param xShapeInfo the x shape information + * @param y the y data + * @param yShapeInfo the y shape information + * @param result the result + * @param resultShapeInfo the result shape information + * @param dimension the dimension to broadcast along long + * @param dimensionLength the length of the dimension buffer + */ + template + static void exec(const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, uint64_t stop); + + template + static void exec(const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + template + static void execInverse( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetZ, uint64_t start, uint64_t stop); #endif - }; - } -} +}; +} // namespace broadcast +} // namespace functions #endif /* BROADCASTING_H_ */ diff --git a/libnd4j/include/loops/cpu/broadcasting.hpp b/libnd4j/include/loops/cpu/broadcasting.hpp index c0f22313ba16..4fd77cee31ea 100644 --- a/libnd4j/include/loops/cpu/broadcasting.hpp +++ b/libnd4j/include/loops/cpu/broadcasting.hpp @@ -18,830 +18,890 @@ // @author raver119@gmail.com // -#include +#include +#include +#include +#include #include #include +#include #include -#include -#include -#include -#include using namespace simdOps; namespace functions { namespace broadcast { - template - void Broadcast::execInverse(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - DISPATCH_BY_OPNUM_TTT(execInverse, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - dimension, - dimensionLength, - xTadShapeInfo, - xTadOffset, - zTadShapeInfo, - zTadOffset, start, stop), BROADCAST_OPS); - } - - template - void Broadcast::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - sd::LoopKind::Kind loopKind, - uint64_t start, uint64_t stop) { - DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - dimension, - dimensionLength, - xTadShapeInfo, - xTadOffset, - zTadShapeInfo, - zTadOffset, loopKind, start, stop), BROADCAST_OPS); - } +template +void Broadcast::execInverse( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffset, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, uint64_t start, uint64_t stop) { + DISPATCH_BY_OPNUM_TTT( + execInverse, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, xTadShapeInfo, xTadOffset, zTadShapeInfo, + zTadOffset, start, stop), + BROADCAST_OPS); +} +template +void Broadcast::exec( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffset, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, sd::LoopKind::Kind loopKind, uint64_t start, + uint64_t stop) { + DISPATCH_BY_OPNUM_TTT( + exec, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, xTadShapeInfo, xTadOffset, zTadShapeInfo, + zTadOffset, loopKind, start, stop), + BROADCAST_OPS); +} - template - template - void Broadcast::exec(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - sd::LoopKind::Kind loopKind, - uint64_t start, uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - auto xTadShapeShapeInfo = xTadShapeInfo; - auto tadOffsets = xTadOffset; - - if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - - xTadShapeShapeInfo = tadPack.primaryShapeInfo(); - tadOffsets = tadPack.primaryOffsets(); - } - - //int *resultStride = shape::stride(xTadShapeShapeInfo); - unsigned int tadLength = shape::length(xTadShapeShapeInfo);//shape::length(xTadShapeShapeInfo); - unsigned int tads = shape::length(xShapeInfo) / tadLength; - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = xTadShapeShapeInfo; - zTadOffset = tadOffsets; - } - - auto lenZ = shape::length(zTadShapeInfo); - auto lenY = shape::length(yShapeInfo); - - auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - - const sd::LoopKind::Kind kindOfLoop = - (loopKind == sd::LoopKind::BROADCAST_SCALAR_X || - loopKind == sd::LoopKind::BROADCAST_SCALAR_Y || - loopKind == sd::LoopKind::BROADCAST_3D || - loopKind == sd::LoopKind::BROADCAST_4D || - loopKind == sd::LoopKind::BROADCAST_5D) - ? loopKind : sd::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto i = start; i < stop; i++) { - auto oX = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(oX[f], y[f]); - } - } - else if(kindOfLoop == sd::LoopKind::EWSNONZERO){ - for (auto i = start; i < stop; i++) { - auto oX = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); - } - } else if(kindOfLoop == sd::LoopKind::BROADCAST_SCALAR_X){ - // this loop effectively turns broadcast into series of scalar ops - auto loopLength = yShapeInfo[shape::rank(yShapeInfo)]; - - for (auto i = start; i < stop; i++) { - auto oY = y + (i * loopLength); - auto oZ = z + (i * loopLength); - - const auto oX = x[i]; - - PRAGMA_OMP_SIMD - for (Nd4jLong f = 0; f < loopLength; f++) - oZ[f] = OpType::op(oX, oY[f]); - } - } else if(kindOfLoop == sd::LoopKind::BROADCAST_SCALAR_Y){ - // this loop effectively turns broadcast into series of scalar ops - auto loopLength = xShapeInfo[shape::rank(xShapeInfo)]; - - for (auto i = start; i < stop; i++) { - auto oX = x + (i * loopLength); - auto oZ = z + (i * loopLength); - - const auto oY = y[i]; - - PRAGMA_OMP_SIMD - for (Nd4jLong f = 0; f < loopLength; f++) - oZ[f] = OpType::op(oX[f], oY); - } - } - else if (kindOfLoop == sd::LoopKind::BROADCAST_3D) { - - int xRank = shape::rank(xShapeInfo); - int yRank = shape::rank(yShapeInfo); - - auto xStrides = shape::stride(xShapeInfo); - auto zStrides = shape::stride(zShapeInfo); - - Nd4jLong yStrides[3] = { 0,0,0 }; - sd::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides); - - uint64_t nSize1 = shape::sizeAt(zShapeInfo, 1); - uint64_t nSize2 = shape::sizeAt(zShapeInfo, 2); - - for (auto index0 = start; index0 < stop; index0++) { - - PRAGMA_OMP_SIMD - for (uint64_t index1 = 0; index1 < nSize1; index1++) { - for (uint64_t index2 = 0; index2 < nSize2; index2++) { - auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2); - auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2); - auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2); - *rZ = OpType::op(*rX, *rY); - } - } - - } - - } - else if (kindOfLoop == sd::LoopKind::BROADCAST_4D) { - - int xRank = shape::rank(xShapeInfo); - int yRank = shape::rank(yShapeInfo); - - auto xStrides = shape::stride(xShapeInfo); - auto zStrides = shape::stride(zShapeInfo); - - Nd4jLong yStrides[4] = { 0,0,0,0 }; - sd::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides); - - uint64_t nSize1 = shape::sizeAt(zShapeInfo, 1); - uint64_t nSize2 = shape::sizeAt(zShapeInfo, 2); - uint64_t nSize3 = shape::sizeAt(zShapeInfo, 3); - - for (auto i = start; i < stop; i++) { - - uint64_t index0 = i / nSize1; - uint64_t index1 = i % nSize1; - - PRAGMA_OMP_SIMD - for (uint64_t index2 = 0; index2 < nSize2; index2++) { - for (uint64_t index3 = 0; index3 < nSize3; index3++) { - auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2 + xStrides[3] * index3); - auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2 + yStrides[3] * index3); - auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2 + zStrides[3] * index3); - *rZ = OpType::op(*rX, *rY); - } - } - } - - } - else if (kindOfLoop == sd::LoopKind::BROADCAST_5D) { - - int xRank = shape::rank(xShapeInfo); - int yRank = shape::rank(yShapeInfo); - - auto xStrides = shape::stride(xShapeInfo); - auto zStrides = shape::stride(zShapeInfo); - - Nd4jLong yStrides[5] = { 0,0,0,0,0 }; - sd::ShapeUtils::copyCertainStridesFromShapeInfo(yShapeInfo, xRank, dimensionLength, dimension, yStrides); - - uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1); - uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2); - uint32_t nSize3 = shape::sizeAt(zShapeInfo, 3); - uint32_t nSize4 = shape::sizeAt(zShapeInfo, 4); - - for (auto i = start; i < stop; i++) { - - uint32_t index0 = i / nSize1; - uint32_t index1 = i % nSize1; - - PRAGMA_OMP_SIMD - for (uint32_t index2 = 0; index2 < nSize2; index2++) { - for (uint32_t index3 = 0; index3 < nSize3; index3++) { - for (uint32_t index4 = 0; index4 < nSize4; index4++) { - auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + xStrides[2] * index2 + xStrides[3] * index3 + xStrides[4] * index4); - auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + yStrides[2] * index2 + yStrides[3] * index3 + yStrides[4] * index4); - auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + zStrides[2] * index2 + zStrides[3] * index3 + zStrides[4] * index4); - - *rZ = OpType::op(*rX, *rY); - } - } - } - } - - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i++) { - auto oX = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - oZ[offset] = OpType::op(oX[offset], y[offset]); - } - } - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[offset], y[offset]); - } - } - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[offset], y[yOffset]); - } - } - } - else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[xOffset], y[offset]); - } - } - } - else { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]); - } - } - } - } +template +template +void Broadcast::exec( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffset, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, sd::LoopKind::Kind loopKind, uint64_t start, + uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + auto xTadShapeShapeInfo = xTadShapeInfo; + auto tadOffsets = xTadOffset; + + if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dimension, dimensionLength); + + xTadShapeShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } + + // int *resultStride = shape::stride(xTadShapeShapeInfo); + unsigned int tadLength = + shape::length(xTadShapeShapeInfo); // shape::length(xTadShapeShapeInfo); + unsigned int tads = shape::length(xShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenY = shape::length(yShapeInfo); + + auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const sd::LoopKind::Kind kindOfLoop = + (loopKind == sd::LoopKind::BROADCAST_SCALAR_X || + loopKind == sd::LoopKind::BROADCAST_SCALAR_Y || + loopKind == sd::LoopKind::BROADCAST_3D || + loopKind == sd::LoopKind::BROADCAST_4D || + loopKind == sd::LoopKind::BROADCAST_5D) + ? loopKind + : sd::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, + zTadShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto i = start; i < stop; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], y[f]); + } + } else if (kindOfLoop == sd::LoopKind::EWSNONZERO) { + for (auto i = start; i < stop; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); + } + } else if (kindOfLoop == sd::LoopKind::BROADCAST_SCALAR_X) { + // this loop effectively turns broadcast into series of scalar ops + auto loopLength = yShapeInfo[shape::rank(yShapeInfo)]; + for (auto i = start; i < stop; i++) { + auto oY = y + (i * loopLength); + auto oZ = z + (i * loopLength); + const auto oX = x[i]; - template - template - void Broadcast::execInverse(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - auto yTadShapeShapeInfo = yTadShapeInfo; - auto tadOffsets = yTadOffset; - - if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength); - - yTadShapeShapeInfo = tadPack.primaryShapeInfo(); - tadOffsets = tadPack.primaryOffsets(); - } - - //int *resultStride = shape::stride(yTadShapeShapeInfo); - unsigned int tadLength = shape::length(yTadShapeShapeInfo); - unsigned int tads = shape::length(yShapeInfo) / tadLength; - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = yTadShapeShapeInfo; - zTadOffset = tadOffsets; - } - - auto lenZ = shape::length(zTadShapeInfo); - auto lenX = shape::length(xShapeInfo); - - int tadsPerThread = tads / TAD_THRESHOLD; - int threads = sd::math::nd4j_max(1, tadsPerThread); - threads = sd::math::nd4j_min(threads, sd::Environment::getInstance()->maxThreads()); - - auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); - - if(kindOfLoop == sd::LoopKind::EWS1) { - for (auto i = start; i < stop; i++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(x[f], oY[f]); - } - } - else if(kindOfLoop == sd::LoopKind::EWSNONZERO) { - for (auto i = start; i < stop; i++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]); - }; - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i++) { - auto oY = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - oZ[offset] = OpType::op(x[offset], oY[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[offset], oY[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto xOffset = shape::indexOffset(f, yShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[xOffset], oY[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[offset], oY[yOffset]); - } - }; - } - else { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]); - } - }; - } - } + PRAGMA_OMP_SIMD + for (Nd4jLong f = 0; f < loopLength; f++) oZ[f] = OpType::op(oX, oY[f]); + } + } else if (kindOfLoop == sd::LoopKind::BROADCAST_SCALAR_Y) { + // this loop effectively turns broadcast into series of scalar ops + auto loopLength = xShapeInfo[shape::rank(xShapeInfo)]; + for (auto i = start; i < stop; i++) { + auto oX = x + (i * loopLength); + auto oZ = z + (i * loopLength); -//////////////////////////////////////////////////////////////////////// -template - void Broadcast::exec(const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo) { + const auto oY = y[i]; - DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_OPS); -} + PRAGMA_OMP_SIMD + for (Nd4jLong f = 0; f < loopLength; f++) oZ[f] = OpType::op(oX[f], oY); + } + } else if (kindOfLoop == sd::LoopKind::BROADCAST_3D) { + int xRank = shape::rank(xShapeInfo); + int yRank = shape::rank(yShapeInfo); + + auto xStrides = shape::stride(xShapeInfo); + auto zStrides = shape::stride(zShapeInfo); + + Nd4jLong yStrides[3] = {0, 0, 0}; + sd::ShapeUtils::copyCertainStridesFromShapeInfo( + yShapeInfo, xRank, dimensionLength, dimension, yStrides); + + uint64_t nSize1 = shape::sizeAt(zShapeInfo, 1); + uint64_t nSize2 = shape::sizeAt(zShapeInfo, 2); + + for (auto index0 = start; index0 < stop; index0++) { + PRAGMA_OMP_SIMD + for (uint64_t index1 = 0; index1 < nSize1; index1++) { + for (uint64_t index2 = 0; index2 < nSize2; index2++) { + auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + + xStrides[2] * index2); + auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + + yStrides[2] * index2); + auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + + zStrides[2] * index2); + *rZ = OpType::op(*rX, *rY); + } + } + } -//////////////////////////////////////////////////////////////////////// -template -static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { + } else if (kindOfLoop == sd::LoopKind::BROADCAST_4D) { + int xRank = shape::rank(xShapeInfo); + int yRank = shape::rank(yShapeInfo); + + auto xStrides = shape::stride(xShapeInfo); + auto zStrides = shape::stride(zShapeInfo); + + Nd4jLong yStrides[4] = {0, 0, 0, 0}; + sd::ShapeUtils::copyCertainStridesFromShapeInfo( + yShapeInfo, xRank, dimensionLength, dimension, yStrides); + + uint64_t nSize1 = shape::sizeAt(zShapeInfo, 1); + uint64_t nSize2 = shape::sizeAt(zShapeInfo, 2); + uint64_t nSize3 = shape::sizeAt(zShapeInfo, 3); + + for (auto i = start; i < stop; i++) { + uint64_t index0 = i / nSize1; + uint64_t index1 = i % nSize1; + + PRAGMA_OMP_SIMD + for (uint64_t index2 = 0; index2 < nSize2; index2++) { + for (uint64_t index3 = 0; index3 < nSize3; index3++) { + auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + + xStrides[2] * index2 + xStrides[3] * index3); + auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + + yStrides[2] * index2 + yStrides[3] * index3); + auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + + zStrides[2] * index2 + zStrides[3] * index3); + *rZ = OpType::op(*rX, *rY); + } + } + } - uint zAxis0 = shape::sizeAt(zShapeInfo, 0); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0); + } else if (kindOfLoop == sd::LoopKind::BROADCAST_5D) { + int xRank = shape::rank(xShapeInfo); + int yRank = shape::rank(yShapeInfo); + + auto xStrides = shape::stride(xShapeInfo); + auto zStrides = shape::stride(zShapeInfo); + + Nd4jLong yStrides[5] = {0, 0, 0, 0, 0}; + sd::ShapeUtils::copyCertainStridesFromShapeInfo( + yShapeInfo, xRank, dimensionLength, dimension, yStrides); + + uint32_t nSize1 = shape::sizeAt(zShapeInfo, 1); + uint32_t nSize2 = shape::sizeAt(zShapeInfo, 2); + uint32_t nSize3 = shape::sizeAt(zShapeInfo, 3); + uint32_t nSize4 = shape::sizeAt(zShapeInfo, 4); + + for (auto i = start; i < stop; i++) { + uint32_t index0 = i / nSize1; + uint32_t index1 = i % nSize1; + + PRAGMA_OMP_SIMD + for (uint32_t index2 = 0; index2 < nSize2; index2++) { + for (uint32_t index3 = 0; index3 < nSize3; index3++) { + for (uint32_t index4 = 0; index4 < nSize4; index4++) { + auto rX = x + (xStrides[0] * index0 + xStrides[1] * index1 + + xStrides[2] * index2 + xStrides[3] * index3 + + xStrides[4] * index4); + auto rY = y + (yStrides[0] * index0 + yStrides[1] * index1 + + yStrides[2] * index2 + yStrides[3] * index3 + + yStrides[4] * index4); + auto rZ = z + (zStrides[0] * index0 + zStrides[1] * index1 + + zStrides[2] * index2 + zStrides[3] * index3 + + zStrides[4] * index4); + + *rZ = OpType::op(*rX, *rY); + } + } + } + } - auto func = PRAGMA_THREADS_FOR{ + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && + shape::haveSameShapeAndStrides(xTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + oZ[offset] = OpType::op(oX[offset], y[offset]); + } + } + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(oX[offset], y[offset]); + } + } + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + oZ[offset] = OpType::op(oX[offset], y[yOffset]); + } + } + } else if (shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto offset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + oZ[offset] = OpType::op(oX[xOffset], y[offset]); + } + } + } else { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]); + } + } + } +} - if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], *y); - } - else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(*x, y[i0]); - } - else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], y[i0]); - } - else { - for (auto i0 = start; i0 < stop; ++i0) - z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]); - } +template +template +void Broadcast::execInverse( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yTadOffset, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, uint64_t start, uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + auto yTadShapeShapeInfo = yTadShapeInfo; + auto tadOffsets = yTadOffset; + + if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + yShapeInfo, dimension, dimensionLength); + + yTadShapeShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } + + // int *resultStride = shape::stride(yTadShapeShapeInfo); + unsigned int tadLength = shape::length(yTadShapeShapeInfo); + unsigned int tads = shape::length(yShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = yTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenX = shape::length(xShapeInfo); + + int tadsPerThread = tads / TAD_THRESHOLD; + int threads = sd::math::nd4j_max(1, tadsPerThread); + threads = sd::math::nd4j_min( + threads, sd::Environment::getInstance()->maxThreads()); + + auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ( + yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(x[f], oY[f]); + } + } else if (kindOfLoop == sd::LoopKind::EWSNONZERO) { + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]); + }; + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && + shape::haveSameShapeAndStrides(yTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oY = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + oZ[offset] = OpType::op(x[offset], oY[offset]); + } }; - samediff::Threads::parallel_tad(func, 0, zAxis0); + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(x[offset], oY[offset]); + } + }; + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto xOffset = + shape::indexOffset(f, yShapeInfo, xShapeInfoCast, canCastX); + oZ[offset] = OpType::op(x[xOffset], oY[offset]); + } + }; + } else if (shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto offset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + oZ[offset] = OpType::op(x[offset], oY[yOffset]); + } + }; + } else { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]); + } + }; + } } //////////////////////////////////////////////////////////////////////// -template -static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - - auto func = PRAGMA_THREADS_FOR{ - - for (auto i0 = start; i0 < stop; ++i0) { - - auto x0 = x + i0 * xStrd0; - auto y0 = y + i0 * yStrd0; - auto z0 = z + i0 * zStrd0; - - if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], *y0); - else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(*x0, y0[i1]); - else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], y0[i1]); - else - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]); - } - }; +template +void Broadcast::exec(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_TTT( + exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_OPS); +} - samediff::Threads::parallel_tad(func, 0, zAxis0); +//////////////////////////////////////////////////////////////////////// +template +static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const Y *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = shape::sizeAt(zShapeInfo, 0); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0); + + auto func = PRAGMA_THREADS_FOR { + if (zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { + for (auto i0 = start; i0 < stop; ++i0) z[i0] = OpType::op(x[i0], *y); + } else if (zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) z[i0] = OpType::op(*x, y[i0]); + } else if (zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) z[i0] = OpType::op(x[i0], y[i0]); + } else { + for (auto i0 = start; i0 < stop; ++i0) + z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]); + } + }; + samediff::Threads::parallel_tad(func, 0, zAxis0); } //////////////////////////////////////////////////////////////////////// -template -static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - - uint zAxis1 = shape::sizeAt(zShapeInfo, 1); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1); - - uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - - auto func = PRAGMA_THREADS_FOR_2D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - - auto x1 = x + i0 * xStrd0 + i1 * xStrd1; - auto y1 = y + i0 * yStrd0 + i1 * yStrd1; - auto z1 = z + i0 * zStrd0 + i1 * zStrd1; - - if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], *y1); - else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(*x1, y1[i2]); - else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], y1[i2]); - else - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]); - } - } - }; +template +static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const Y *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + + auto func = PRAGMA_THREADS_FOR { + for (auto i0 = start; i0 < stop; ++i0) { + auto x0 = x + i0 * xStrd0; + auto y0 = y + i0 * yStrd0; + auto z0 = z + i0 * zStrd0; + + if (zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) + for (uint i1 = 0; i1 < zAxis1; ++i1) z0[i1] = OpType::op(x0[i1], *y0); + else if (zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) z0[i1] = OpType::op(*x0, y0[i1]); + else if (zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[i1], y0[i1]); + else + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]); + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1); + samediff::Threads::parallel_tad(func, 0, zAxis0); } //////////////////////////////////////////////////////////////////////// -template -static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - - uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - - uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - - auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; - auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; - auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; - - if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], *y2); - else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(*x2, y2[i3]); - else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], y2[i3]); - else - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]); - } - } - } - }; +template +static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const Y *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + + uint zAxis1 = shape::sizeAt(zShapeInfo, 1); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1); + + uint zAxis2 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong xStrd2 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong yStrd2 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong zStrd2 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + + auto func = PRAGMA_THREADS_FOR_2D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + auto x1 = x + i0 * xStrd0 + i1 * xStrd1; + auto y1 = y + i0 * yStrd0 + i1 * yStrd1; + auto z1 = z + i0 * zStrd0 + i1 * zStrd1; + + if (zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) + for (uint i2 = 0; i2 < zAxis2; ++i2) z1[i2] = OpType::op(x1[i2], *y1); + else if (zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) z1[i2] = OpType::op(*x1, y1[i2]); + else if (zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[i2], y1[i2]); + else + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]); + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1); } //////////////////////////////////////////////////////////////////////// -template -static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - - uint zAxis2 = shape::sizeAt(zShapeInfo, 2); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2); - - uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - - uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - for (uint i3 = 0; i3 < zAxis3; ++i3) { - - auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; - auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; - auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; - - if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], *y3); - else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(*x3, y3[i4]); - else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], y3[i4]); - else - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]); - } - } - } +template +static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const Y *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + + uint zAxis2 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong xStrd2 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong yStrd2 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong zStrd2 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + + uint zAxis3 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong xStrd3 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong yStrd3 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong zStrd3 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; + auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; + auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; + + if (zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], *y2); + else if (zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(*x2, y2[i3]); + else if (zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], y2[i3]); + else + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]); } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1, 0, zAxis2, + 1); } //////////////////////////////////////////////////////////////////////// -template -static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { - - const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); - - const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank - - auto func = PRAGMA_THREADS_FOR{ - - int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; - - for (auto i = start; i < stop; ++i) { - - shape::index2coordsCPU(start, i, zShapeInfo, zCoords); - - for (uint j = 0; j < rank; ++j) { - xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; - yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; - } - - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); - const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); +template +static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const Y *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + + uint zAxis2 = shape::sizeAt(zShapeInfo, 2); + Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2); + Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2); + Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2); + + uint zAxis3 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong xStrd3 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong yStrd3 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong zStrd3 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + + uint zAxis4 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong xStrd4 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong yStrd4 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong zStrd4 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + for (uint i3 = 0; i3 < zAxis3; ++i3) { + auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; + auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; + auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; + + if (zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], *y3); + else if (zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(*x3, y3[i4]); + else if (zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], y3[i4]); + else + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]); + } } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1, 0, zAxis2, + 1); } //////////////////////////////////////////////////////////////////////// -template -template -void Broadcast::exec(const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - const X* x = reinterpret_cast(vx); - const Y* y = reinterpret_cast(vy); - Z* z = reinterpret_cast(vz); - - const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank - - switch (rank) { - - case 1: - execRank1(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 2: - execRank2(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 3: - execRank3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 4: - execRank4(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 5: - execRank5(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - default: - execDefault(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); +template +static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const Y *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo) { + const bool xzSameOffsets = + shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + const bool yzSameOffsets = + shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + + auto func = PRAGMA_THREADS_FOR { + int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + + for (auto i = start; i < stop; ++i) { + shape::index2coordsCPU(start, i, zShapeInfo, zCoords); + + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; + } + + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto xOffset = + xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = + yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); } + }; + + samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); } +//////////////////////////////////////////////////////////////////////// +template +template +void Broadcast::exec(const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + const X *x = reinterpret_cast(vx); + const Y *y = reinterpret_cast(vy); + Z *z = reinterpret_cast(vz); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + + switch (rank) { + case 1: + execRank1(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 2: + execRank2(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 3: + execRank3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 4: + execRank4(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 5: + execRank5(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + default: + execDefault(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + } } -} \ No newline at end of file + +} // namespace broadcast +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/broadcasting_bool.hpp b/libnd4j/include/loops/cpu/broadcasting_bool.hpp index a7386a9162b4..d5770f7c6359 100644 --- a/libnd4j/include/loops/cpu/broadcasting_bool.hpp +++ b/libnd4j/include/loops/cpu/broadcasting_bool.hpp @@ -18,718 +18,787 @@ // @author raver119@gmail.com // -#include +#include +#include +#include #include #include +#include #include -#include -#include -#include using namespace simdOps; namespace functions { namespace broadcast { - template - void BroadcastBool::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - extraParams, - dimension, - dimensionLength, - xTadShapeInfo, - xTadOffset, - zTadShapeInfo, - zTadOffset, start, stop), BROADCAST_BOOL_OPS); - } +template +void BroadcastBool::exec( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, + void *extraParams, int *dimension, int dimensionLength, + const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, + const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, uint64_t start, + uint64_t stop) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, + dimension, dimensionLength, xTadShapeInfo, xTadOffset, + zTadShapeInfo, zTadOffset, start, stop), + BROADCAST_BOOL_OPS); +} - template - void BroadcastBool::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void* extraParams) { +template +void BroadcastBool::exec(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams) { + DISPATCH_BY_OPNUM_TT( + exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams), + BROADCAST_BOOL_OPS); +} - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams), BROADCAST_BOOL_OPS); - } +template +void BroadcastBool::execInverse( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, + void *extraParams, int *dimension, int dimensionLength, + const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, + const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, uint64_t start, + uint64_t stop) { + DISPATCH_BY_OPNUM_TT( + execInverse, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, + dimension, dimensionLength, xTadShapeInfo, xTadOffset, + zTadShapeInfo, zTadOffset, start, stop), + BROADCAST_BOOL_OPS); +} - template - void BroadcastBool::execInverse(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - DISPATCH_BY_OPNUM_TT(execInverse, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - extraParams, - dimension, - dimensionLength, - xTadShapeInfo, - xTadOffset, - zTadShapeInfo, - zTadOffset, start, stop), BROADCAST_BOOL_OPS); - } +template +template +void BroadcastBool::exec( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams, int *dimension, int dimensionLength, + const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, + const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, uint64_t start, + uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + auto xTadShapeShapeInfo = xTadShapeInfo; + auto tadOffsets = xTadOffset; + + if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dimension, dimensionLength); + + xTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); + tadOffsets = const_cast(tadPack.primaryOffsets()); + } + + // int *resultStride = shape::stride(xTadShapeShapeInfo); + unsigned int tadLength = + shape::length(xTadShapeShapeInfo); // shape::length(xTadShapeShapeInfo); + unsigned int tads = shape::length(xShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenY = shape::length(yShapeInfo); + + int tadsPerThread = tads / TAD_THRESHOLD; + int threads = sd::math::nd4j_max(1, tadsPerThread); + threads = sd::math::nd4j_min( + threads, sd::Environment::getInstance()->maxThreads()); + + auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ( + xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto i = start; i < stop; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], y[f], extraParams); + } + } else if (kindOfLoop == sd::LoopKind::EWSNONZERO) { + for (auto i = start; i < stop; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && + shape::haveSameShapeAndStrides(xTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + oZ[offset] = OpType::op(oX[offset], y[offset], extraParams); + } + }; + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(oX[offset], y[offset], extraParams); + } + }; + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + oZ[offset] = OpType::op(oX[offset], y[yOffset], extraParams); + } + }; - template - template - void BroadcastBool::exec(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - auto xTadShapeShapeInfo = xTadShapeInfo; - auto tadOffsets = xTadOffset; - - if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - - xTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); - tadOffsets = const_cast(tadPack.primaryOffsets()); - } - - //int *resultStride = shape::stride(xTadShapeShapeInfo); - unsigned int tadLength = shape::length(xTadShapeShapeInfo);//shape::length(xTadShapeShapeInfo); - unsigned int tads = shape::length(xShapeInfo) / tadLength; - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = xTadShapeShapeInfo; - zTadOffset = tadOffsets; - } - - auto lenZ = shape::length(zTadShapeInfo); - auto lenY = shape::length(yShapeInfo); - - int tadsPerThread = tads / TAD_THRESHOLD; - int threads = sd::math::nd4j_max(1, tadsPerThread); - threads = sd::math::nd4j_min(threads, sd::Environment::getInstance()->maxThreads()); - - auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto i = start; i < stop; i++) { - auto oX = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(oX[f], y[f], extraParams); - } - } - else if(kindOfLoop == sd::LoopKind::EWSNONZERO) { - for (auto i = start; i < stop; i ++) { - auto oX = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - oZ[offset] = OpType::op(oX[offset], y[offset], extraParams); - } - }; - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[offset], y[offset], extraParams); - } - }; - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[offset], y[yOffset], extraParams); - } - }; - - } - else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[xOffset], y[offset], extraParams); - } - }; - } - else { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset], extraParams); - } - }; - } - } + } else if (shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto offset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + oZ[offset] = OpType::op(oX[xOffset], y[offset], extraParams); + } + }; + } else { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset], extraParams); + } + }; + } +} - template - template - void BroadcastBool::execInverse(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams, - int *dimension, int dimensionLength, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - auto yTadShapeShapeInfo = yTadShapeInfo; - auto tadOffsets = yTadOffset; - - if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength); - - yTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); - tadOffsets = const_cast(tadPack.primaryOffsets()); - } - - //int *resultStride = shape::stride(yTadShapeShapeInfo); - unsigned int tadLength = shape::length(yTadShapeShapeInfo); - unsigned int tads = shape::length(yShapeInfo) / tadLength; - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = yTadShapeShapeInfo; - zTadOffset = tadOffsets; - } - - auto lenZ = shape::length(zTadShapeInfo); - auto lenX = shape::length(xShapeInfo); - - int tadsPerThread = tads / TAD_THRESHOLD; - int threads = sd::math::nd4j_max(1, tadsPerThread); - threads = sd::math::nd4j_min(threads, sd::Environment::getInstance()->maxThreads()); - - auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto i = start; i < stop; i ++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(x[f], oY[f], extraParams); - } - } - else if(kindOfLoop == sd::LoopKind::EWSNONZERO) { - for (auto i = start; i < stop; i ++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (uint f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws], extraParams); - } - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { - - uint tadShapeShapeInfoCast[MAX_RANK]; - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - oZ[offset] = OpType::op(x[offset], oY[offset], extraParams); - } - } - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { - - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[offset], oY[offset], extraParams); - } - } - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { - - uint tadShapeShapeInfoCast[MAX_RANK]; - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[xOffset], oY[offset], extraParams); - } - } - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { - - uint tadShapeShapeInfoCast[MAX_RANK]; - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[offset], oY[yOffset], extraParams); - } - } - } - else { - - uint xShapeInfoCast[MAX_RANK]; - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset], extraParams); - } - } - } - } +template +template +void BroadcastBool::execInverse( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams, int *dimension, int dimensionLength, + const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffset, + const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, uint64_t start, + uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + auto yTadShapeShapeInfo = yTadShapeInfo; + auto tadOffsets = yTadOffset; + + if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + yShapeInfo, dimension, dimensionLength); + + yTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); + tadOffsets = const_cast(tadPack.primaryOffsets()); + } + + // int *resultStride = shape::stride(yTadShapeShapeInfo); + unsigned int tadLength = shape::length(yTadShapeShapeInfo); + unsigned int tads = shape::length(yShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = yTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenX = shape::length(xShapeInfo); + + int tadsPerThread = tads / TAD_THRESHOLD; + int threads = sd::math::nd4j_max(1, tadsPerThread); + threads = sd::math::nd4j_min( + threads, sd::Environment::getInstance()->maxThreads()); + + auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ( + yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(x[f], oY[f], extraParams); + } + } else if (kindOfLoop == sd::LoopKind::EWSNONZERO) { + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (uint f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws], extraParams); + } + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && + shape::haveSameShapeAndStrides(yTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + oZ[offset] = OpType::op(x[offset], oY[offset], extraParams); + } + } + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(x[offset], oY[offset], extraParams); + } + } + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto xOffset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + oZ[offset] = OpType::op(x[xOffset], oY[offset], extraParams); + } + } + } else if (shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto offset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + oZ[offset] = OpType::op(x[offset], oY[yOffset], extraParams); + } + } + } else { + uint xShapeInfoCast[MAX_RANK]; + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset], extraParams); + } + } + } +} //////////////////////////////////////////////////////////////////////// template -static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, 0); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0); - - auto func = PRAGMA_THREADS_FOR{ - - if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], *y, extraParams); - } - else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(*x, y[i0], extraParams); - } - else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], y[i0], extraParams); - } - else { - for (auto i0 = start; i0 < stop; ++i0) - z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0], extraParams); - } - }; - samediff::Threads::parallel_tad(func, 0, zAxis0); +static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo, X *extraParams) { + uint zAxis0 = shape::sizeAt(zShapeInfo, 0); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0); + + auto func = PRAGMA_THREADS_FOR { + if (zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(x[i0], *y, extraParams); + } else if (zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(*x, y[i0], extraParams); + } else if (zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(x[i0], y[i0], extraParams); + } else { + for (auto i0 = start; i0 < stop; ++i0) + z[i0 * zStrd0] = + OpType::op(x[i0 * xStrd0], y[i0 * yStrd0], extraParams); + } + }; + samediff::Threads::parallel_tad(func, 0, zAxis0); } //////////////////////////////////////////////////////////////////////// template -static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - - auto func = PRAGMA_THREADS_FOR{ - - for (auto i0 = start; i0 < stop; ++i0) { - - auto x0 = x + i0 * xStrd0; - auto y0 = y + i0 * yStrd0; - auto z0 = z + i0 * zStrd0; - - if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], *y0, extraParams); - else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(*x0, y0[i1], extraParams); - else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], y0[i1], extraParams); - else - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1], extraParams); - } - }; +static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo, X *extraParams) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + + auto func = PRAGMA_THREADS_FOR { + for (auto i0 = start; i0 < stop; ++i0) { + auto x0 = x + i0 * xStrd0; + auto y0 = y + i0 * yStrd0; + auto z0 = z + i0 * zStrd0; + + if (zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[i1], *y0, extraParams); + else if (zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(*x0, y0[i1], extraParams); + else if (zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[i1], y0[i1], extraParams); + else + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1 * zStrd1] = + OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1], extraParams); + } + }; - samediff::Threads::parallel_tad(func, 0, zAxis0); + samediff::Threads::parallel_tad(func, 0, zAxis0); } //////////////////////////////////////////////////////////////////////// template -static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - - uint zAxis1 = shape::sizeAt(zShapeInfo, 1); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1); - - uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - - auto func = PRAGMA_THREADS_FOR_2D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - - auto x1 = x + i0 * xStrd0 + i1 * xStrd1; - auto y1 = y + i0 * yStrd0 + i1 * yStrd1; - auto z1 = z + i0 * zStrd0 + i1 * zStrd1; - - if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], *y1, extraParams); - else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(*x1, y1[i2], extraParams); - else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], y1[i2], extraParams); - else - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2], extraParams); - } - } - }; +static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo, X *extraParams) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + + uint zAxis1 = shape::sizeAt(zShapeInfo, 1); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1); + + uint zAxis2 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong xStrd2 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong yStrd2 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong zStrd2 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + + auto func = PRAGMA_THREADS_FOR_2D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + auto x1 = x + i0 * xStrd0 + i1 * xStrd1; + auto y1 = y + i0 * yStrd0 + i1 * yStrd1; + auto z1 = z + i0 * zStrd0 + i1 * zStrd1; + + if (zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[i2], *y1, extraParams); + else if (zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(*x1, y1[i2], extraParams); + else if (zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[i2], y1[i2], extraParams); + else + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2 * zStrd2] = + OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2], extraParams); + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1); } //////////////////////////////////////////////////////////////////////// template -static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - - uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - - uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - - auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; - auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; - auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; - - if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], *y2, extraParams); - else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(*x2, y2[i3], extraParams); - else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], y2[i3], extraParams); - else - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3], extraParams); - } - } +static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo, X *extraParams) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + + uint zAxis2 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong xStrd2 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong yStrd2 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong zStrd2 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + + uint zAxis3 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong xStrd3 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong yStrd3 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong zStrd3 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; + auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; + auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; + + if (zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], *y2, extraParams); + else if (zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(*x2, y2[i3], extraParams); + else if (zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], y2[i3], extraParams); + else + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3 * zStrd3] = + OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3], extraParams); } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1, 0, zAxis2, + 1); } //////////////////////////////////////////////////////////////////////// template -static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - - uint zAxis2 = shape::sizeAt(zShapeInfo, 2); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2); - - uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - - uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - for (uint i3 = 0; i3 < zAxis3; ++i3) { - - auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; - auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; - auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; - - if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], *y3, extraParams); - else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(*x3, y3[i4], extraParams); - else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], y3[i4], extraParams); - else - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4], extraParams); - } - } - } +static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo, X *extraParams) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + + uint zAxis2 = shape::sizeAt(zShapeInfo, 2); + Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2); + Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2); + Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2); + + uint zAxis3 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong xStrd3 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong yStrd3 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong zStrd3 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + + uint zAxis4 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong xStrd4 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong yStrd4 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong zStrd4 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + for (uint i3 = 0; i3 < zAxis3; ++i3) { + auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; + auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; + auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; + + if (zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], *y3, extraParams); + else if (zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(*x3, y3[i4], extraParams); + else if (zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], y3[i4], extraParams); + else + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4 * zStrd4] = + OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4], extraParams); + } } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1, 0, zAxis2, + 1); } //////////////////////////////////////////////////////////////////////// template -static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { - - const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); - - const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank - - auto func = PRAGMA_THREADS_FOR{ - - int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; - - for (auto i = start; i < stop; ++i) { - - shape::index2coordsCPU(start, i, zShapeInfo, zCoords); - - for (uint j = 0; j < rank; ++j) { - xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; - yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; - } - - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); - const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - } - }; +static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, Z *z, + const Nd4jLong *zShapeInfo, X *extraParams) { + const bool xzSameOffsets = + shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + const bool yzSameOffsets = + shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + + auto func = PRAGMA_THREADS_FOR { + int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + + for (auto i = start; i < stop; ++i) { + shape::index2coordsCPU(start, i, zShapeInfo, zCoords); + + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; + } + + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto xOffset = + xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = + yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + }; - samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); + samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); } //////////////////////////////////////////////////////////////////////// template -template +template void BroadcastBool::exec(const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams) { - - const X* x = reinterpret_cast(vx); - const X* y = reinterpret_cast(vy); - Z* z = reinterpret_cast(vz); - - X* extraParams = reinterpret_cast(vextraParams); - - const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank - - switch (rank) { - - case 1: - execRank1(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); - break; - case 2: - execRank2(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); - break; - case 3: - execRank3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); - break; - case 4: - execRank4(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); - break; - case 5: - execRank5(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); - break; - default: - execDefault(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); - } + void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams) { + const X *x = reinterpret_cast(vx); + const X *y = reinterpret_cast(vy); + Z *z = reinterpret_cast(vz); + + X *extraParams = reinterpret_cast(vextraParams); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + + switch (rank) { + case 1: + execRank1(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams); + break; + case 2: + execRank2(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams); + break; + case 3: + execRank3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams); + break; + case 4: + execRank4(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams); + break; + case 5: + execRank5(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams); + break; + default: + execDefault(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams); + } } - //BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES); +// BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , +// LIBND4J_TYPES, BOOL_TYPES); - -} -} \ No newline at end of file +} // namespace broadcast +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/broadcasting_int.hpp b/libnd4j/include/loops/cpu/broadcasting_int.hpp index 93d5c05e1889..ed9736c4d694 100644 --- a/libnd4j/include/loops/cpu/broadcasting_int.hpp +++ b/libnd4j/include/loops/cpu/broadcasting_int.hpp @@ -18,700 +18,759 @@ // @author raver119@gmail.com // -#include +#include +#include +#include #include #include +#include #include -#include -#include -#include using namespace simdOps; namespace functions { - namespace broadcast { - - template - void BroadcastInt::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - DISPATCH_BY_OPNUM_T(exec, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - dimension, - dimensionLength, - xTadShapeInfo, - xTadOffset, - zTadShapeInfo, - zTadOffset, start, stop), BROADCAST_INT_OPS); - } - - template - void BroadcastInt::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo) { +namespace broadcast { - DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_INT_OPS); - } +template +void BroadcastInt::exec( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffset, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, uint64_t start, uint64_t stop) { + DISPATCH_BY_OPNUM_T( + exec, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, xTadShapeInfo, xTadOffset, zTadShapeInfo, + zTadOffset, start, stop), + BROADCAST_INT_OPS); +} - template - void BroadcastInt::execInverse(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - DISPATCH_BY_OPNUM_T(execInverse, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - dimension, - dimensionLength, - xTadShapeInfo, - xTadOffset, - zTadShapeInfo, - zTadOffset, start, stop), BROADCAST_INT_OPS); - } +template +void BroadcastInt::exec(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), + BROADCAST_INT_OPS); +} - template - template - void BroadcastInt::exec(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - auto xTadShapeShapeInfo = xTadShapeInfo; - auto tadOffsets = xTadOffset; - - if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - - xTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); - tadOffsets = const_cast(tadPack.primaryOffsets()); - } - - //int *resultStride = shape::stride(xTadShapeShapeInfo); - unsigned int tadLength = shape::length(xTadShapeShapeInfo);//shape::length(xTadShapeShapeInfo); - unsigned int tads = shape::length(xShapeInfo) / tadLength; - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = xTadShapeShapeInfo; - zTadOffset = tadOffsets; - } - - auto lenZ = shape::length(zTadShapeInfo); - auto lenY = shape::length(yShapeInfo); - - int tadsPerThread = tads / TAD_THRESHOLD; - int threads = sd::math::nd4j_max(1, tadsPerThread); - threads = sd::math::nd4j_min(threads, sd::Environment::getInstance()->maxThreads()); - - auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto i = start; i < stop; i ++) { - auto oX = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(oX[f], y[f]); - }; - } - else if(kindOfLoop == sd::LoopKind::EWSNONZERO) { - for (auto i = start; i < stop; i ++) { - auto oX = x + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); - }; - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - oZ[offset] = OpType::op(oX[offset], y[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[offset], y[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[offset], y[yOffset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[xOffset], y[offset]); - } - }; - } - else { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oX = x + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]); - } - }; - } - } +template +void BroadcastInt::execInverse( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffset, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, uint64_t start, uint64_t stop) { + DISPATCH_BY_OPNUM_T( + execInverse, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, xTadShapeInfo, xTadOffset, zTadShapeInfo, + zTadOffset, start, stop), + BROADCAST_INT_OPS); +} +template +template +void BroadcastInt::exec(const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffset, + const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, uint64_t start, + uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + auto xTadShapeShapeInfo = xTadShapeInfo; + auto tadOffsets = xTadOffset; + + if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dimension, dimensionLength); + + xTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); + tadOffsets = const_cast(tadPack.primaryOffsets()); + } + + // int *resultStride = shape::stride(xTadShapeShapeInfo); + unsigned int tadLength = + shape::length(xTadShapeShapeInfo); // shape::length(xTadShapeShapeInfo); + unsigned int tads = shape::length(xShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenY = shape::length(yShapeInfo); + + int tadsPerThread = tads / TAD_THRESHOLD; + int threads = sd::math::nd4j_max(1, tadsPerThread); + threads = sd::math::nd4j_min( + threads, sd::Environment::getInstance()->maxThreads()); + + auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ( + xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto i = start; i < stop; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], y[f]); + }; + } else if (kindOfLoop == sd::LoopKind::EWSNONZERO) { + for (auto i = start; i < stop; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); + }; + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && + shape::haveSameShapeAndStrides(xTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + oZ[offset] = OpType::op(oX[offset], y[offset]); + } + }; + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(oX[offset], y[offset]); + } + }; + } else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + oZ[offset] = OpType::op(oX[offset], y[yOffset]); + } + }; + } else if (shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto offset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + oZ[offset] = OpType::op(oX[xOffset], y[offset]); + } + }; + } else { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]); + } + }; + } +} - template - template - void BroadcastInt::execInverse(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, const int dimensionLength, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yTadOffset, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffset, - uint64_t start, uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - auto yTadShapeShapeInfo = yTadShapeInfo; - auto tadOffsets = yTadOffset; - - if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength); - - yTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); - tadOffsets = const_cast(tadPack.primaryOffsets()); - } - - //int *resultStride = shape::stride(yTadShapeShapeInfo); - unsigned int tadLength = shape::length(yTadShapeShapeInfo); - unsigned int tads = shape::length(yShapeInfo) / tadLength; - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = yTadShapeShapeInfo; - zTadOffset = tadOffsets; - } - - auto lenZ = shape::length(zTadShapeInfo); - auto lenX = shape::length(xShapeInfo); - - int tadsPerThread = tads / TAD_THRESHOLD; - int threads = sd::math::nd4j_max(1, tadsPerThread); - threads = sd::math::nd4j_min(threads, sd::Environment::getInstance()->maxThreads()); - - auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto i = start; i < stop; i ++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(x[f], oY[f]); - }; - } - else if(kindOfLoop == sd::LoopKind::EWSNONZERO) { - for (auto i = start; i < stop; i ++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (uint f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]); - }; - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oY = y + tadOffsets[i]; - auto oZ = z + zTadOffset[i]; - - PRAGMA_OMP_SIMD - for (uint f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - oZ[offset] = OpType::op(x[offset], oY[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { - - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - for (uint f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[offset], oY[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (uint f = 0; f < tadLength; f++) { - auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[xOffset], oY[offset]); - } - }; - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { - uint tadShapeShapeInfoCast[MAX_RANK]; - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (uint f = 0; f < tadLength; f++) { - auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[offset], oY[yOffset]); - } - }; - } - else { - uint xShapeInfoCast[MAX_RANK]; - uint tadShapeShapeInfoCast[MAX_RANK]; - uint tadShapeInfoZCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); - - for (auto i = start; i < stop; i ++) { - auto oZ = z + zTadOffset[i]; - auto oY = y + tadOffsets[i]; - - PRAGMA_OMP_SIMD - for (uint f = 0; f < tadLength; f++) { - auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]); - } - }; - } - } +template +template +void BroadcastInt::execInverse( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + int *dimension, const int dimensionLength, const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yTadOffset, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffset, uint64_t start, uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + auto yTadShapeShapeInfo = yTadShapeInfo; + auto tadOffsets = yTadOffset; + + if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + yShapeInfo, dimension, dimensionLength); + + yTadShapeShapeInfo = const_cast(tadPack.primaryShapeInfo()); + tadOffsets = const_cast(tadPack.primaryOffsets()); + } + + // int *resultStride = shape::stride(yTadShapeShapeInfo); + unsigned int tadLength = shape::length(yTadShapeShapeInfo); + unsigned int tads = shape::length(yShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = yTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenX = shape::length(xShapeInfo); + + int tadsPerThread = tads / TAD_THRESHOLD; + int threads = sd::math::nd4j_max(1, tadsPerThread); + threads = sd::math::nd4j_min( + threads, sd::Environment::getInstance()->maxThreads()); + + auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ( + yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(x[f], oY[f]); + }; + } else if (kindOfLoop == sd::LoopKind::EWSNONZERO) { + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (uint f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]); + }; + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && + shape::haveSameShapeAndStrides(yTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (uint f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + oZ[offset] = OpType::op(x[offset], oY[offset]); + } + }; + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + for (uint f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(x[offset], oY[offset]); + } + }; + } else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, + zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (uint f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto xOffset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + oZ[offset] = OpType::op(x[xOffset], oY[offset]); + } + }; + } else if (shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { + uint tadShapeShapeInfoCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (uint f = 0; f < tadLength; f++) { + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto offset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + oZ[offset] = OpType::op(x[offset], oY[yOffset]); + } + }; + } else { + uint xShapeInfoCast[MAX_RANK]; + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = sd::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, + tadShapeShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + for (auto i = start; i < stop; i++) { + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (uint f = 0; f < tadLength; f++) { + auto xOffset = + shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, + tadShapeShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); + oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]); + } + }; + } +} //////////////////////////////////////////////////////////////////////// template -static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, 0); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0); - - auto func = PRAGMA_THREADS_FOR{ - - if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], *y); - } - else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(*x, y[i0]); - } - else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], y[i0]); - } - else { - for (auto i0 = start; i0 < stop; ++i0) - z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]); - } - }; - samediff::Threads::parallel_tad(func, 0, zAxis0); +static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, X *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = shape::sizeAt(zShapeInfo, 0); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0); + + auto func = PRAGMA_THREADS_FOR { + if (zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { + for (auto i0 = start; i0 < stop; ++i0) z[i0] = OpType::op(x[i0], *y); + } else if (zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) z[i0] = OpType::op(*x, y[i0]); + } else if (zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) z[i0] = OpType::op(x[i0], y[i0]); + } else { + for (auto i0 = start; i0 < stop; ++i0) + z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]); + } + }; + samediff::Threads::parallel_tad(func, 0, zAxis0); } //////////////////////////////////////////////////////////////////////// template -static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); - - auto func = PRAGMA_THREADS_FOR{ - - for (auto i0 = start; i0 < stop; ++i0) { - - auto x0 = x + i0 * xStrd0; - auto y0 = y + i0 * yStrd0; - auto z0 = z + i0 * zStrd0; - - if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], *y0); - else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(*x0, y0[i1]); - else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], y0[i1]); - else - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]); - } - }; +static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, X *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + + auto func = PRAGMA_THREADS_FOR { + for (auto i0 = start; i0 < stop; ++i0) { + auto x0 = x + i0 * xStrd0; + auto y0 = y + i0 * yStrd0; + auto z0 = z + i0 * zStrd0; + + if (zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) + for (uint i1 = 0; i1 < zAxis1; ++i1) z0[i1] = OpType::op(x0[i1], *y0); + else if (zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) z0[i1] = OpType::op(*x0, y0[i1]); + else if (zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[i1], y0[i1]); + else + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]); + } + }; - samediff::Threads::parallel_tad(func, 0, zAxis0); + samediff::Threads::parallel_tad(func, 0, zAxis0); } //////////////////////////////////////////////////////////////////////// template -static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); - - uint zAxis1 = shape::sizeAt(zShapeInfo, 1); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1); - - uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); - - auto func = PRAGMA_THREADS_FOR_2D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - - auto x1 = x + i0 * xStrd0 + i1 * xStrd1; - auto y1 = y + i0 * yStrd0 + i1 * yStrd1; - auto z1 = z + i0 * zStrd0 + i1 * zStrd1; - - if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], *y1); - else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(*x1, y1[i2]); - else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], y1[i2]); - else - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]); - } - } - }; +static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, X *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + + uint zAxis1 = shape::sizeAt(zShapeInfo, 1); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1); + + uint zAxis2 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong xStrd2 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong yStrd2 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong zStrd2 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + + auto func = PRAGMA_THREADS_FOR_2D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + auto x1 = x + i0 * xStrd0 + i1 * xStrd1; + auto y1 = y + i0 * yStrd0 + i1 * yStrd1; + auto z1 = z + i0 * zStrd0 + i1 * zStrd1; + + if (zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) + for (uint i2 = 0; i2 < zAxis2; ++i2) z1[i2] = OpType::op(x1[i2], *y1); + else if (zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) z1[i2] = OpType::op(*x1, y1[i2]); + else if (zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[i2], y1[i2]); + else + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]); + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1); } //////////////////////////////////////////////////////////////////////// template -static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); - - uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); - - uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - - auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; - auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; - auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; - - if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], *y2); - else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(*x2, y2[i3]); - else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], y2[i3]); - else - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]); - } - } +static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, X *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + + uint zAxis2 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong xStrd2 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong yStrd2 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong zStrd2 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + + uint zAxis3 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong xStrd3 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong yStrd3 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong zStrd3 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; + auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; + auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; + + if (zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], *y2); + else if (zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(*x2, y2[i3]); + else if (zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], y2[i3]); + else + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]); } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1, 0, zAxis2, + 1); } //////////////////////////////////////////////////////////////////////// template -static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { - - uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); - - uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); - - uint zAxis2 = shape::sizeAt(zShapeInfo, 2); - Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2); - Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2); - Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2); - - uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); - - uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - for (uint i3 = 0; i3 < zAxis3; ++i3) { - - auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; - auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; - auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; - - if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], *y3); - else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(*x3, y3[i4]); - else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], y3[i4]); - else - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]); - } - } - } +static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, X *z, + const Nd4jLong *zShapeInfo) { + uint zAxis0 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong xStrd0 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong yStrd0 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong zStrd0 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + + uint zAxis1 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong xStrd1 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong yStrd1 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong zStrd1 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + + uint zAxis2 = shape::sizeAt(zShapeInfo, 2); + Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2); + Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2); + Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2); + + uint zAxis3 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong xStrd3 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong yStrd3 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong zStrd3 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + + uint zAxis4 = + shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong xStrd4 = + shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong yStrd4 = + shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong zStrd4 = + shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + for (uint i3 = 0; i3 < zAxis3; ++i3) { + auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; + auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; + auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; + + if (zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], *y3); + else if (zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(*x3, y3[i4]); + else if (zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], y3[i4]); + else + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]); + } } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + samediff::Threads::parallel_for(func, 0, zAxis0, 1, 0, zAxis1, 1, 0, zAxis2, + 1); } //////////////////////////////////////////////////////////////////////// template -static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { - - const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); - - const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank - - auto func = PRAGMA_THREADS_FOR{ - - int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; - - for (auto i = start; i < stop; ++i) { - - shape::index2coordsCPU(start, i, zShapeInfo, zCoords); - - for (uint j = 0; j < rank; ++j) { - xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; - yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; - } - - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); - const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); - } - }; +static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, + const Nd4jLong *yShapeInfo, X *z, + const Nd4jLong *zShapeInfo) { + const bool xzSameOffsets = + shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + const bool yzSameOffsets = + shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + + auto func = PRAGMA_THREADS_FOR { + int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + + for (auto i = start; i < stop; ++i) { + shape::index2coordsCPU(start, i, zShapeInfo, zCoords); + + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; + } + + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto xOffset = + xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = + yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + } + }; - samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); + samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); } //////////////////////////////////////////////////////////////////////// template -template +template void BroadcastInt::exec(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const X* x = reinterpret_cast(vx); - const X* y = reinterpret_cast(vy); - X* z = reinterpret_cast(vz); - - const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank - - switch (rank) { - - case 1: - execRank1(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 2: - execRank2(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 3: - execRank3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 4: - execRank4(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - case 5: - execRank5(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - break; - default: - execDefault(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - } + const void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + const X *x = reinterpret_cast(vx); + const X *y = reinterpret_cast(vy); + X *z = reinterpret_cast(vz); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + + switch (rank) { + case 1: + execRank1(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 2: + execRank2(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 3: + execRank3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 4: + execRank4(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 5: + execRank5(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + default: + execDefault(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + } } -//BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES); -} -} \ No newline at end of file +// BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , +// INTEGER_TYPES); +} // namespace broadcast +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p0.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p0.cpp index a21ea1109d12..3612d84bf0cb 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p0.cpp @@ -21,7 +21,8 @@ #include "../broadcasting_bool.hpp" namespace functions { - namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_0, BOOL_TYPES); - } -} \ No newline at end of file +namespace broadcast { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_0, + BOOL_TYPES); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p1.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p1.cpp index 8cb7bc865fa4..87fdaf64160c 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p1.cpp @@ -21,7 +21,8 @@ #include "../broadcasting_bool.hpp" namespace functions { - namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_1, BOOL_TYPES); - } -} \ No newline at end of file +namespace broadcast { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_1, + BOOL_TYPES); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p2.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p2.cpp index b073e4603652..a1024c6d91d9 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p2.cpp @@ -21,7 +21,8 @@ #include "../broadcasting_bool.hpp" namespace functions { - namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_2, BOOL_TYPES); - } -} \ No newline at end of file +namespace broadcast { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_2, + BOOL_TYPES); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p3.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p3.cpp index 6d5032a88f04..064844116e8a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p3.cpp @@ -21,7 +21,8 @@ #include "../broadcasting_bool.hpp" namespace functions { - namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_3, BOOL_TYPES); - } -} \ No newline at end of file +namespace broadcast { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_3, + BOOL_TYPES); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p4.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p4.cpp index e312a564387a..08a2155d5e95 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p4.cpp @@ -21,7 +21,8 @@ #include "../broadcasting_bool.hpp" namespace functions { - namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_4, BOOL_TYPES); - } -} \ No newline at end of file +namespace broadcast { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_4, + BOOL_TYPES); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p5.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p5.cpp index 2f37d6505e1a..504ea0e50111 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p5.cpp @@ -21,7 +21,8 @@ #include "../broadcasting_bool.hpp" namespace functions { - namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_5, BOOL_TYPES); - } -} \ No newline at end of file +namespace broadcast { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_5, + BOOL_TYPES); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p6.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p6.cpp index e15adcd9fcba..b23d28589b25 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p6.cpp @@ -21,7 +21,8 @@ #include "../broadcasting_bool.hpp" namespace functions { - namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_6, BOOL_TYPES); - } -} \ No newline at end of file +namespace broadcast { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_6, + BOOL_TYPES); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p7.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p7.cpp index 4dfc22073780..5517d77963e6 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p7.cpp @@ -21,7 +21,8 @@ #include "../broadcasting_bool.hpp" namespace functions { - namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_7, BOOL_TYPES); - } -} \ No newline at end of file +namespace broadcast { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_7, + BOOL_TYPES); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p8.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p8.cpp index ab59a846d87a..5d9b947e6d55 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p8.cpp @@ -21,7 +21,8 @@ #include "../broadcasting_bool.hpp" namespace functions { - namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_8, BOOL_TYPES); - } -} \ No newline at end of file +namespace broadcast { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_8, + BOOL_TYPES); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p9.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p9.cpp index e43382ec032b..5de5f4388cac 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p9.cpp @@ -21,7 +21,8 @@ #include "../broadcasting_bool.hpp" namespace functions { - namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_9, BOOL_TYPES); - } -} \ No newline at end of file +namespace broadcast { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_9, + BOOL_TYPES); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p0.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p0.cpp index aec26e5ea83d..a2c2dd73a581 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p0.cpp @@ -21,7 +21,7 @@ #include "../broadcasting_int.hpp" namespace functions { - namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_0); - } -} \ No newline at end of file +namespace broadcast { +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p1.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p1.cpp index 50cd1972268c..203892ef02a3 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p1.cpp @@ -21,7 +21,7 @@ #include "../broadcasting_int.hpp" namespace functions { - namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_1); - } -} \ No newline at end of file +namespace broadcast { +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p2.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p2.cpp index 807a613bf832..2a6eaa6d755c 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p2.cpp @@ -21,7 +21,7 @@ #include "../broadcasting_int.hpp" namespace functions { - namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_2); - } -} \ No newline at end of file +namespace broadcast { +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p3.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p3.cpp index 26dfa1985cd2..190f3dc3a3ce 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p3.cpp @@ -21,7 +21,7 @@ #include "../broadcasting_int.hpp" namespace functions { - namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_3); - } -} \ No newline at end of file +namespace broadcast { +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p4.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p4.cpp index 652974b39d58..8658dd8c7be6 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p4.cpp @@ -21,7 +21,7 @@ #include "../broadcasting_int.hpp" namespace functions { - namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_4); - } -} \ No newline at end of file +namespace broadcast { +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_4); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p5.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p5.cpp index 4159e5d0b7be..273a36b86ca9 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p5.cpp @@ -21,7 +21,7 @@ #include "../broadcasting_int.hpp" namespace functions { - namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_5); - } -} \ No newline at end of file +namespace broadcast { +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_5); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p6.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p6.cpp index 1fb44733ab2b..c48a592afe07 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p6.cpp @@ -21,7 +21,7 @@ #include "../broadcasting_int.hpp" namespace functions { - namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_6); - } -} \ No newline at end of file +namespace broadcast { +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_6); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p7.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p7.cpp index 72d127f312f8..710580384f5a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p7.cpp @@ -21,7 +21,7 @@ #include "../broadcasting_int.hpp" namespace functions { - namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_7); - } -} \ No newline at end of file +namespace broadcast { +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_7); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p0.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p0.cpp index 059ef45b97d7..31756e67aaad 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p0.cpp @@ -21,7 +21,7 @@ #include "../broadcasting.hpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_0); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p1.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p1.cpp index b93b66aef95e..594c6f938985 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p1.cpp @@ -21,7 +21,7 @@ #include "../broadcasting.hpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_1); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p10.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p10.cpp index 09c6cbd50e4c..6730b7018414 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p10.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p10.cpp @@ -21,7 +21,8 @@ #include "../broadcasting.hpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_10); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , + PAIRWISE_TYPES_10); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p11.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p11.cpp index 33da9553fb3b..81c6a6981078 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p11.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p11.cpp @@ -21,7 +21,8 @@ #include "../broadcasting.hpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_11); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , + PAIRWISE_TYPES_11); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p12.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p12.cpp index f85e7ff9c592..3b257c36b18a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p12.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p12.cpp @@ -21,7 +21,8 @@ #include "../broadcasting.hpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_12); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , + PAIRWISE_TYPES_12); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p2.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p2.cpp index 77242a3145e8..3e6020c7160d 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p2.cpp @@ -21,7 +21,7 @@ #include "../broadcasting.hpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_2); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p3.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p3.cpp index 683629bf1b95..2498642cdb07 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p3.cpp @@ -21,7 +21,7 @@ #include "../broadcasting.hpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_3); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p4.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p4.cpp index fc720385475f..27cd6e19ec7a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p4.cpp @@ -21,7 +21,7 @@ #include "../broadcasting.hpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_4); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_4); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p5.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p5.cpp index f9ce462c993b..f603d1d03257 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p5.cpp @@ -21,7 +21,7 @@ #include "../broadcasting.hpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_5); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_5); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p6.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p6.cpp index dbf09ea2273e..bf15b432058c 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p6.cpp @@ -21,7 +21,7 @@ #include "../broadcasting.hpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_6); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_6); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p7.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p7.cpp index 700a5c771c59..34aef078cb11 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p7.cpp @@ -21,7 +21,7 @@ #include "../broadcasting.hpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_7); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_7); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p8.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p8.cpp index 60720b2da364..0e9126083828 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p8.cpp @@ -21,7 +21,7 @@ #include "../broadcasting.hpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_8); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_8); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p9.cpp b/libnd4j/include/loops/cpu/compilation_units/broadcast_p9.cpp index c0255f2d3d26..3d1b391cdcee 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p9.cpp @@ -21,7 +21,7 @@ #include "../broadcasting.hpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_9); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_9); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_0.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_0.cpp index 137258f77d62..33664b4cfbff 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_0.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_0, (sd::DataType::INT32, int32_t)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_0, + (sd::DataType::INT32, int32_t)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_1.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_1.cpp index 3aaf3fde7282..c8e9ed02fd59 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_1.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_1, (sd::DataType::INT32, int32_t)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_1, + (sd::DataType::INT32, int32_t)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_2.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_2.cpp index c4f87dfae7da..3b602c8e13a0 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_2.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_2, (sd::DataType::INT32, int32_t)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_2, + (sd::DataType::INT32, int32_t)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_3.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_3.cpp index 1a86d3eb4cc5..721e1f5299ca 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_3.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_3, (sd::DataType::INT32, int32_t)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_3, + (sd::DataType::INT32, int32_t)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_4.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_4.cpp index d263456400d9..d4b01543b621 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_4.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_4, (sd::DataType::INT32, int32_t)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_4, + (sd::DataType::INT32, int32_t)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_5.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_5.cpp index 4195c48a8dcf..d90d63a8d00c 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_5.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_5, (sd::DataType::INT32, int32_t)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_5, + (sd::DataType::INT32, int32_t)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_6.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_6.cpp index b6966425da82..9ea9a141653b 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_6.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_6, (sd::DataType::INT32, int32_t)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_6, + (sd::DataType::INT32, int32_t)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_7.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_7.cpp index 931d9a5ad91f..769d57f11d49 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_7.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_7, (sd::DataType::INT32, int32_t)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_7, + (sd::DataType::INT32, int32_t)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_8.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_8.cpp index 6b282d8fb102..b1013c7b589f 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_8.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_8, (sd::DataType::INT32, int32_t)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_8, + (sd::DataType::INT32, int32_t)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_9.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_9.cpp index 17d14a835381..db37cdf10f7a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32_9.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_9, (sd::DataType::INT32, int32_t)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_9, + (sd::DataType::INT32, int32_t)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_0.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_0.cpp index 63b5347a176c..b50ec738e6f1 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_0.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_0, (sd::DataType::INT64, Nd4jLong)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_0, + (sd::DataType::INT64, Nd4jLong)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_1.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_1.cpp index b7bab85cba95..dba4dfd4d419 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_1.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_1, (sd::DataType::INT64, Nd4jLong)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_1, + (sd::DataType::INT64, Nd4jLong)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_2.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_2.cpp index eb4217f66390..68a8d5273c0a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_2.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_2, (sd::DataType::INT64, Nd4jLong)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_2, + (sd::DataType::INT64, Nd4jLong)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_3.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_3.cpp index fceeb38298d0..7b8a2eb7d0c6 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_3.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_3, (sd::DataType::INT64, Nd4jLong)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_3, + (sd::DataType::INT64, Nd4jLong)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_4.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_4.cpp index 0bb478598c3f..34d113830f34 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_4.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_4, (sd::DataType::INT64, Nd4jLong)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_4, + (sd::DataType::INT64, Nd4jLong)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_5.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_5.cpp index 851fe1edb26d..5731bb5eaa5e 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_5.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_5, (sd::DataType::INT64, Nd4jLong)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_5, + (sd::DataType::INT64, Nd4jLong)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_6.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_6.cpp index b9268e519770..8f8d14a3195f 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_6.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_6, (sd::DataType::INT64, Nd4jLong)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_6, + (sd::DataType::INT64, Nd4jLong)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_7.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_7.cpp index c17d61930394..08fc244ef87f 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_7.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_7, (sd::DataType::INT64, Nd4jLong)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_7, + (sd::DataType::INT64, Nd4jLong)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_8.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_8.cpp index ddea061ac185..84cd5a984d47 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_8.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_8, (sd::DataType::INT64, Nd4jLong)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_8, + (sd::DataType::INT64, Nd4jLong)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_9.cpp b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_9.cpp index 79a6ddac2730..82269a85023e 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64_9.cpp @@ -22,7 +22,8 @@ #include "../indexreduce.hpp" namespace functions { - namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_9, (sd::DataType::INT64, Nd4jLong)); - } -} \ No newline at end of file +namespace indexreduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_9, + (sd::DataType::INT64, Nd4jLong)); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p0.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p0.cpp index 3dbc22427690..bf972aca2130 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p0.cpp @@ -21,7 +21,8 @@ #include "loops/cpu/pairwise.hpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_0); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p1.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p1.cpp index 607467b47884..1b5f210fdb1f 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p1.cpp @@ -21,8 +21,9 @@ #include "loops/cpu/pairwise.hpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_1); - } +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_1); +} -} \ No newline at end of file +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p10.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p10.cpp index 365ff223a933..3a788260dafd 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p10.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p10.cpp @@ -21,7 +21,8 @@ #include "loops/cpu/pairwise.hpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_10); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_10); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p11.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p11.cpp index 6222e487aeef..09d2bbebb778 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p11.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p11.cpp @@ -21,7 +21,8 @@ #include "loops/cpu/pairwise.hpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_11); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_11); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p12.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p12.cpp index 9a9909bca945..7cce666cbb5b 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p12.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p12.cpp @@ -21,7 +21,8 @@ #include "loops/cpu/pairwise.hpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_12); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_12); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p2.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p2.cpp index 83bee2bb3b95..14f3dccb88ae 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p2.cpp @@ -21,7 +21,8 @@ #include "loops/cpu/pairwise.hpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_2); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p3.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p3.cpp index 804a8887521b..1ecec73450ea 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p3.cpp @@ -21,7 +21,8 @@ #include "loops/cpu/pairwise.hpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_3); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p4.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p4.cpp index c244607b3c5e..820b46145cf5 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p4.cpp @@ -21,7 +21,8 @@ #include "loops/cpu/pairwise.hpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_4); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_4); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p5.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p5.cpp index 51043f20f9d6..8a19ad89bec5 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p5.cpp @@ -21,7 +21,8 @@ #include "loops/cpu/pairwise.hpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_5); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_5); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p6.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p6.cpp index 02ed81a9a27f..4c003af7b7c1 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p6.cpp @@ -21,7 +21,8 @@ #include "loops/cpu/pairwise.hpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_6); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_6); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p7.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p7.cpp index 9cd8ff32a9fc..c85a865e0a63 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p7.cpp @@ -21,7 +21,8 @@ #include "loops/cpu/pairwise.hpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_7); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_7); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p8.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p8.cpp index 0f57bc913d9a..97abb69da96a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p8.cpp @@ -21,7 +21,8 @@ #include "loops/cpu/pairwise.hpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_8); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_8); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p9.cpp b/libnd4j/include/loops/cpu/compilation_units/pairwise_p9.cpp index 6ef3f7e07bf0..3de06ba19fac 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p9.cpp @@ -21,7 +21,8 @@ #include "loops/cpu/pairwise.hpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_9); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_9); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/random_0.cpp b/libnd4j/include/loops/cpu/compilation_units/random_0.cpp index ef5c075533e0..4876d0d63365 100644 --- a/libnd4j/include/loops/cpu/compilation_units/random_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/random_0.cpp @@ -21,7 +21,7 @@ #include "../random.hpp" namespace functions { - namespace random { - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES_0); - } -} \ No newline at end of file +namespace random { +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/random_1.cpp b/libnd4j/include/loops/cpu/compilation_units/random_1.cpp index c4ec6bc1d855..1480cd7d592a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/random_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/random_1.cpp @@ -21,7 +21,7 @@ #include "../random.hpp" namespace functions { - namespace random { - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES_1); - } -} \ No newline at end of file +namespace random { +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/random_2.cpp b/libnd4j/include/loops/cpu/compilation_units/random_2.cpp index d766d5caf115..431cbe27ed97 100644 --- a/libnd4j/include/loops/cpu/compilation_units/random_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/random_2.cpp @@ -21,7 +21,7 @@ #include "../random.hpp" namespace functions { - namespace random { - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES_2); - } -} \ No newline at end of file +namespace random { +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/random_3.cpp b/libnd4j/include/loops/cpu/compilation_units/random_3.cpp index 08032c8b4bf9..01af5a99657d 100644 --- a/libnd4j/include/loops/cpu/compilation_units/random_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/random_3.cpp @@ -21,7 +21,7 @@ #include "../random.hpp" namespace functions { - namespace random { - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES_3); - } -} \ No newline at end of file +namespace random { +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_0.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_0.cpp index 9b4c36769bbc..c7c5317665e6 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_0.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_0, FLOAT_TYPES_3); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_0, + FLOAT_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_1.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_1.cpp index c4e77433ab7a..53502a411bbe 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_1.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_1, FLOAT_TYPES_3); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_1, + FLOAT_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_2.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_2.cpp index 327c7d47e73c..23d90a3d5972 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_2.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_2, FLOAT_TYPES_3); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_2, + FLOAT_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_3.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_3.cpp index d26a609040fd..5c23301dd054 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_3.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_3, FLOAT_TYPES_3); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_3, + FLOAT_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_4.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_4.cpp index 8dac72f07052..15e302c51277 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_4.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_4, FLOAT_TYPES_3); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_4, + FLOAT_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_5.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_5.cpp index 0b35b957e574..3758a437d490 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_5.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_5, FLOAT_TYPES_3); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_5, + FLOAT_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_6.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_6.cpp index b7a4f0e24242..353db181d5be 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_6.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_6, FLOAT_TYPES_3); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_6, + FLOAT_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_7.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_7.cpp index cc66ed99b8d0..8fa335c3105b 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_7.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_7, FLOAT_TYPES_3); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_7, + FLOAT_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_8.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_8.cpp index f8da905c3a7d..6de9aae7a370 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_8.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_8, FLOAT_TYPES_3); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_8, + FLOAT_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_9.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_9.cpp index ebbaef251a96..d80e70914a45 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16_9.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_9, FLOAT_TYPES_3); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_9, + FLOAT_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_0.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_0.cpp index f4d9c53e933f..f8cf39bb43de 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_0.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_0, FLOAT_TYPES_2); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_0, + FLOAT_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_1.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_1.cpp index ee6dd7ff5dd8..f4877b67948c 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_1.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_1, FLOAT_TYPES_2); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_1, + FLOAT_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_2.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_2.cpp index 8bc9ab053949..55161cbf675a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_2.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_2, FLOAT_TYPES_2); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_2, + FLOAT_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_3.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_3.cpp index 139f955b5b9e..a3d3d0f81773 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_3.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_3, FLOAT_TYPES_2); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_3, + FLOAT_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_4.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_4.cpp index 5146e675dac7..700d83f2f981 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_4.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_4, FLOAT_TYPES_2); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_4, + FLOAT_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_5.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_5.cpp index bee768cd8233..53a9899c51d3 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_5.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_5, FLOAT_TYPES_2); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_5, + FLOAT_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_6.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_6.cpp index cc9fe5cb6687..aa691b11ab98 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_6.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_6, FLOAT_TYPES_2); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_6, + FLOAT_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_7.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_7.cpp index fc46966d8ee8..c4d9835cf939 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_7.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_7, FLOAT_TYPES_2); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_7, + FLOAT_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_8.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_8.cpp index 59e293f9a678..ddb2beb9ee08 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_8.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_8, FLOAT_TYPES_2); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_8, + FLOAT_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_9.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_9.cpp index 58926d9e33d6..63d74c45e107 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double_9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double_9.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_9, FLOAT_TYPES_2); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_9, + FLOAT_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_0.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_0.cpp index 57069c7c8730..7bcbd3bd0c30 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_0.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_0, FLOAT_TYPES_0); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_0, + FLOAT_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_1.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_1.cpp index 7075e4fe1d40..e4192862ab87 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_1.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_1, FLOAT_TYPES_0); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_1, + FLOAT_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_2.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_2.cpp index b4fe17ed8e0b..e8666de18ea2 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_2.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_2, FLOAT_TYPES_0); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_2, + FLOAT_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_3.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_3.cpp index 109d46c0cc5f..f6cd761138b8 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_3.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_3, FLOAT_TYPES_0); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_3, + FLOAT_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_4.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_4.cpp index 9390c388e9c4..56bd743b8726 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_4.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_4, FLOAT_TYPES_0); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_4, + FLOAT_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_5.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_5.cpp index 44162fb56527..f1c1b97f5210 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_5.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_5, FLOAT_TYPES_0); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_5, + FLOAT_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_6.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_6.cpp index 078ae7968af6..bb9a8428cb41 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_6.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_6, FLOAT_TYPES_0); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_6, + FLOAT_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_7.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_7.cpp index 1027b86916f9..b3d2cd9a0911 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_7.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_7, FLOAT_TYPES_0); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_7, + FLOAT_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_8.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_8.cpp index 0addb0a3f44d..36c02e3bb40f 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_8.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_8, FLOAT_TYPES_0); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_8, + FLOAT_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_9.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_9.cpp index 1e2878851ff1..60f8d542a6d2 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16_9.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_9, FLOAT_TYPES_0); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_9, + FLOAT_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_0.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_0.cpp index 52e0648d8688..7c402c6db55e 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_0.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_0, FLOAT_TYPES_1); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_0, + FLOAT_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_1.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_1.cpp index 3fcf252de0e7..6722058cf747 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_1.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_1, FLOAT_TYPES_1); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_1, + FLOAT_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_2.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_2.cpp index bd21f708b73d..ef663a7f0855 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_2.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_2, FLOAT_TYPES_1); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_2, + FLOAT_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_3.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_3.cpp index e15aaa14c13c..4976a4adbd87 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_3.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_3, FLOAT_TYPES_1); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_3, + FLOAT_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_4.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_4.cpp index ac3a138106d7..4fc61a73880b 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_4.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_4, FLOAT_TYPES_1); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_4, + FLOAT_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_5.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_5.cpp index 7a9b85d9a578..289378c13181 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_5.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_5, FLOAT_TYPES_1); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_5, + FLOAT_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_6.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_6.cpp index ff6490bebdc7..5632a580276f 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_6.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_6, FLOAT_TYPES_1); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_6, + FLOAT_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_7.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_7.cpp index 36270a134959..f37b57021d89 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_7.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_7, FLOAT_TYPES_1); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_7, + FLOAT_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_8.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_8.cpp index ed9c9ff64a7c..22d216d2deb9 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_8.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_8, FLOAT_TYPES_1); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_8, + FLOAT_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_9.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_9.cpp index e619ccdeff32..e1288ed27b94 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float_9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float_9.cpp @@ -22,7 +22,8 @@ #include "../reduce3.hpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_9, FLOAT_TYPES_1); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_9, + FLOAT_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce_float_0.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce_float_0.cpp index 3ebf86606e73..da56116cad45 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce_float_0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce_float_0.cpp @@ -22,7 +22,8 @@ #include "../reduce/reduce_float.hpp" namespace functions { - namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_0); - } +namespace reduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , + LIBND4J_TYPES, FLOAT_TYPES_0); } +} // namespace functions diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce_float_1.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce_float_1.cpp index a0bc314e24cc..428fa5848ec2 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce_float_1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce_float_1.cpp @@ -22,7 +22,8 @@ #include "../reduce/reduce_float.hpp" namespace functions { - namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_1); - } +namespace reduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , + LIBND4J_TYPES, FLOAT_TYPES_1); } +} // namespace functions diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce_float_2.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce_float_2.cpp index 387516ed470a..b86846901fd7 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce_float_2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce_float_2.cpp @@ -22,7 +22,8 @@ #include "../reduce/reduce_float.hpp" namespace functions { - namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_2); - } +namespace reduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , + LIBND4J_TYPES, FLOAT_TYPES_2); } +} // namespace functions diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce_float_3.cpp b/libnd4j/include/loops/cpu/compilation_units/reduce_float_3.cpp index 6194cb7e45a4..aba596b08ac2 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce_float_3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/reduce_float_3.cpp @@ -22,7 +22,8 @@ #include "../reduce/reduce_float.hpp" namespace functions { - namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_3); - } +namespace reduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , + LIBND4J_TYPES, FLOAT_TYPES_3); } +} // namespace functions diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p0.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p0.cpp index 3e1008b57fac..fe20ad6f62bf 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p0.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p0.cpp @@ -21,7 +21,8 @@ #include "../scalar.hpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_0); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p1.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p1.cpp index 90e5d0a47398..f1929bff666a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p1.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p1.cpp @@ -21,7 +21,8 @@ #include "../scalar.hpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_1); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p10.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p10.cpp index 4433cef9f755..266f2acc31b5 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p10.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p10.cpp @@ -21,7 +21,8 @@ #include "../scalar.hpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_10); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_10); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p11.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p11.cpp index 4366a23489ba..b5f33cdaf898 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p11.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p11.cpp @@ -21,7 +21,8 @@ #include "../scalar.hpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_11); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_11); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p12.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p12.cpp index 14c9b3774dfa..f3927fb850fa 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p12.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p12.cpp @@ -21,7 +21,8 @@ #include "../scalar.hpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_12); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_12); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p2.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p2.cpp index 29616adc10e2..72253b0130d9 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p2.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p2.cpp @@ -21,7 +21,8 @@ #include "../scalar.hpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_2); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p3.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p3.cpp index 686a3f7f7852..0b9e184f8cf0 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p3.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p3.cpp @@ -21,7 +21,8 @@ #include "../scalar.hpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_3); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p4.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p4.cpp index 865c7bf5ab10..13564c01d9b3 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p4.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p4.cpp @@ -21,7 +21,8 @@ #include "../scalar.hpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_4); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_4); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p5.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p5.cpp index 4284efac914d..3ab71830331b 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p5.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p5.cpp @@ -21,7 +21,8 @@ #include "../scalar.hpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_5); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_5); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p6.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p6.cpp index 29a13300a8b6..a4de6a5984ab 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p6.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p6.cpp @@ -21,7 +21,8 @@ #include "../scalar.hpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_6); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_6); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p7.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p7.cpp index b542c60ab56d..afe2b77c017a 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p7.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p7.cpp @@ -21,7 +21,8 @@ #include "../scalar.hpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_7); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_7); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p8.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p8.cpp index 79983ab1d714..2a92f665e2b8 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p8.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p8.cpp @@ -21,7 +21,8 @@ #include "../scalar.hpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_8); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_8); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p9.cpp b/libnd4j/include/loops/cpu/compilation_units/scalar_p9.cpp index 41b39bb3f2d9..3dc629d29a0d 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p9.cpp +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p9.cpp @@ -21,7 +21,8 @@ #include "../scalar.hpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_9); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_9); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/indexreduce.hpp b/libnd4j/include/loops/cpu/indexreduce.hpp index 296fbcdefc79..01b7ce7af8d1 100644 --- a/libnd4j/include/loops/cpu/indexreduce.hpp +++ b/libnd4j/include/loops/cpu/indexreduce.hpp @@ -18,138 +18,148 @@ // Created by raver on 4/9/2018. // +#include +#include +#include #include +#include #include -#include #include -#include -#include -#include using namespace simdOps; -namespace functions { +namespace functions { namespace indexreduce { //////////////////////////////////////////////////////////////////////// template -Nd4jLong IndexReduce::execScalar( const int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams) { - RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), INDEX_REDUCE_OPS); +Nd4jLong IndexReduce::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams) { + RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), + INDEX_REDUCE_OPS); } //////////////////////////////////////////////////////////////////////// template -void IndexReduce::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); +void IndexReduce::exec(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset), + INDEX_REDUCE_OPS); } //////////////////////////////////////////////////////////////////////// template -template -Nd4jLong IndexReduce::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - - //T startingVal = OpType::startingValue(x); - auto startingIndex = OpType::startingIndexValue(x); - auto len = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - sd::OmpLaunchHelper info(len); - - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance()->maxThreads()); - IndexValue intermediatery[64]; - for (int e = 0; e < maxThreads; e++) - intermediatery[e].index = -1; +template +Nd4jLong IndexReduce::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + + // T startingVal = OpType::startingValue(x); + auto startingIndex = OpType::startingIndexValue(x); + auto len = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + sd::OmpLaunchHelper info(len); + + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + int maxThreads = + sd::math::nd4j_min(64, sd::Environment::getInstance()->maxThreads()); + IndexValue intermediatery[64]; + for (int e = 0; e < maxThreads; e++) intermediatery[e].index = -1; + + if (xEws == 1) { + auto func = PRAGMA_THREADS_FOR { + intermediatery[thread_id] = OpType::startingIndexValue(x); + + for (auto i = start; i < stop; i++) { + IndexValue curr(x[i], i); + intermediatery[thread_id] = + OpType::update(intermediatery[thread_id], curr, extraParams); + } + }; + + maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads); - if (xEws == 1) { - auto func = PRAGMA_THREADS_FOR { - intermediatery[thread_id] = OpType::startingIndexValue(x); - - for (auto i = start; i < stop; i++) { - IndexValue curr(x[i], i); - intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads); - - for (int e = 0; e < maxThreads; e++) - startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams); + for (int e = 0; e < maxThreads; e++) + startingIndex = + OpType::update(startingIndex, intermediatery[e], extraParams); - } else { - auto func = PRAGMA_THREADS_FOR { - intermediatery[thread_id] = OpType::startingIndexValue(x); + } else { + auto func = PRAGMA_THREADS_FOR { + intermediatery[thread_id] = OpType::startingIndexValue(x); - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - IndexValue curr(x[offset], i); - intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams); - } - }; + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + IndexValue curr(x[offset], i); + intermediatery[thread_id] = + OpType::update(intermediatery[thread_id], curr, extraParams); + } + }; - maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads); + maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads); - for (int e = 0; e < maxThreads; e++) - startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams); - } - return startingIndex.index; + for (int e = 0; e < maxThreads; e++) + startingIndex = + OpType::update(startingIndex, intermediatery[e], extraParams); + } + return startingIndex.index; } - //////////////////////////////////////////////////////////////////////// template -template +template void IndexReduce::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset) { + void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); + const Nd4jLong zLen = shape::length(zShapeInfo); - const Nd4jLong zLen = shape::length(zShapeInfo); + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto indexValue = OpType::startingIndexValue(x); - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto indexValue = OpType::startingIndexValue(x); + for (Nd4jLong i = 0; i < zLen; i++) z[i] = (Z)indexValue.index; - for (Nd4jLong i = 0; i < zLen; i++) - z[i] = (Z) indexValue.index; + return; + } - return; - } + if (shape::isScalar(zShapeInfo)) { + z[0] = (Z)execScalar(x, xShapeInfo, extraParams); + return; + } - if(shape::isScalar(zShapeInfo)) { - z[0] = (Z) execScalar(x,xShapeInfo,extraParams); - return; - } + auto tadOnlyShapeInfo = tadShapeInfo; + auto tadOffsets = tadOffset; - auto tadOnlyShapeInfo = tadShapeInfo; - auto tadOffsets = tadOffset; + if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { + if (dimensionLength < 1) return; - if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { - if (dimensionLength < 1) - return; + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dimension, dimensionLength); - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); + tadOnlyShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } - tadOnlyShapeInfo = tadPack.primaryShapeInfo(); - tadOffsets = tadPack.primaryOffsets(); - } - - sd::IndexReductionLoops::template loopIndexReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); + sd::IndexReductionLoops::template loopIndexReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); } -} -} \ No newline at end of file +} // namespace indexreduce +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/pairwise.hpp b/libnd4j/include/loops/cpu/pairwise.hpp index 45fe46e8f1c6..ceb3a196d099 100644 --- a/libnd4j/include/loops/cpu/pairwise.hpp +++ b/libnd4j/include/loops/cpu/pairwise.hpp @@ -18,210 +18,211 @@ // Created by remote on 2018-09-20. // -#include -#include -#include +#include #include -#include +#include #include +#include +#include +#include #include -#include -#include +#include using namespace simdOps; namespace functions { - namespace pairwise_transforms { - - template - void PairWiseTransform::exec(const int opNum, - const void *x, Nd4jLong xEws, - const void *y, Nd4jLong yEws, - void *z, Nd4jLong zEws, - void *extraParams, - Nd4jLong n, - const uint64_t start,const uint64_t stop) { - DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, - xEws, - y, - yEws, - z, - zEws, - extraParams, - n, start, stop), PAIRWISE_TRANSFORM_OPS); - }; - - - - template - template - void PairWiseTransform::exec(const void *vx, Nd4jLong xEws, - const void *vy, Nd4jLong yEws, - void *vz, Nd4jLong zEws, - void *vextraParams, - const Nd4jLong n, - const uint64_t start, - const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - if (xEws == 1 && yEws == 1 && zEws == 1) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i] = OpType::op(x[i], y[i], extraParams); - } - else { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i*zEws] = OpType::op(x[i*xEws], y[i*yEws], extraParams); - } - } - - template - void PairWiseTransform::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - extraParams, start, stop), - PAIRWISE_TRANSFORM_OPS); - }; - - - template - template - void PairWiseTransform::exec(const void *vx, const Nd4jLong* xShapeInfo, - const void *vy, const Nd4jLong* yShapeInfo, - void *vz, const Nd4jLong* zShapeInfo, - void *vextraParams, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - auto n = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - - if (shape::isScalar(yShapeInfo)) { - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - PRAGMA_OMP_SIMD - for(auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], y[0], extraParams); - }; - } - else { - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for(auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], y[0], extraParams); - }; - } - return; - } - - - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); - const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); - - if ((kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) && sameShapesXY) { - exec(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop); - } - else if ((kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape - exec(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo), start, stop); - } - else { - - if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], y[offset], extraParams); - } - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[offset], y[offset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpType::op(x[offset], y[yOffset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpType::op(x[xOffset], y[offset], extraParams); - }; - } - else { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - }; - } - } - } +namespace pairwise_transforms { + +template +void PairWiseTransform::exec(const int opNum, const void *x, + Nd4jLong xEws, const void *y, + Nd4jLong yEws, void *z, Nd4jLong zEws, + void *extraParams, Nd4jLong n, + const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_TTT( + exec, PARAMS(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop), + PAIRWISE_TRANSFORM_OPS); +}; + +template +template +void PairWiseTransform::exec(const void *vx, Nd4jLong xEws, + const void *vy, Nd4jLong yEws, void *vz, + Nd4jLong zEws, void *vextraParams, + const Nd4jLong n, const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + if (xEws == 1 && yEws == 1 && zEws == 1) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i] = OpType::op(x[i], y[i], extraParams); + } else { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); + } +} + +template +void PairWiseTransform::exec(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, + void *extraParams, const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_TTT(exec, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams, start, stop), + PAIRWISE_TRANSFORM_OPS); +}; + +template +template +void PairWiseTransform::exec( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams, const uint64_t start, const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + auto n = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + if (shape::isScalar(yShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], y[0], extraParams); + }; + } else { + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[0], extraParams); + }; + } + return; + } + + const sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); + const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); + + if ((kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) && + sameShapesXY) { + exec(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop); + } else if ((kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) && + !sameShapesXY) { // not same shape + exec(x, xEws, y, yEws, z, zEws, extraParams, + shape::length(yShapeInfo), start, stop); + } else { + if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && + shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], y[offset], extraParams); + } + } else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[offset], y[offset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = OpType::op(x[offset], y[yOffset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto offset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = OpType::op(x[xOffset], y[offset], extraParams); + }; + } else { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + }; } + } } +} // namespace pairwise_transforms +} // namespace functions diff --git a/libnd4j/include/loops/cpu/pairwise_bool.cpp b/libnd4j/include/loops/cpu/pairwise_bool.cpp index 19c5d5880125..b7c0cc2559a9 100644 --- a/libnd4j/include/loops/cpu/pairwise_bool.cpp +++ b/libnd4j/include/loops/cpu/pairwise_bool.cpp @@ -18,204 +18,208 @@ // Created by remote on 2018-09-20. // -#include -#include +#include #include #include -#include +#include +#include using namespace simdOps; namespace functions { - namespace pairwise_transforms { - - template - void PairWiseBoolTransform::exec(const int opNum, - const void *x, Nd4jLong xEws, - const void *y, Nd4jLong yEws, - void *z, Nd4jLong zEws, - void *extraParams, - Nd4jLong n, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, - xEws, - y, - yEws, - z, - zEws, - extraParams, - n, start, stop), PAIRWISE_BOOL_OPS); - }; - - - - template - template - void PairWiseBoolTransform::exec(const void *vx, Nd4jLong xEws, - const void *vy, Nd4jLong yEws, - void *vz, Nd4jLong zEws, - void *vextraParams, - const Nd4jLong n, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - if (xEws == 1 && yEws == 1 && zEws == 1) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i] = OpType::op(x[i], y[i], extraParams); - } - else { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i*zEws] = OpType::op(x[i*xEws], y[i*yEws], extraParams); - } - } - - template - void PairWiseBoolTransform::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - const uint64_t start,const uint64_t stop) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - extraParams, start, stop), - PAIRWISE_BOOL_OPS); - }; - - - template - template - void PairWiseBoolTransform::exec(const void *vx, const Nd4jLong* xShapeInfo, - const void *vy, const Nd4jLong* yShapeInfo, - void *vz, const Nd4jLong* zShapeInfo, - void *vextraParams, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - auto n = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - if (shape::isScalar(yShapeInfo)) { - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - - PRAGMA_OMP_SIMD - for(auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], y[0], extraParams); - }; - } - else { - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for(auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], y[0], extraParams); - }; - } - return; - } - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); - const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); - - if ((kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) && sameShapesXY) { - exec(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop); - } - else if ((kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape - exec(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo), start, stop); - } - else { - if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], y[offset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[offset], y[offset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpType::op(x[offset], y[yOffset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpType::op(x[xOffset], y[offset], extraParams); - }; - } - else { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - }; - } - } - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); +namespace pairwise_transforms { + +template +void PairWiseBoolTransform::exec(const int opNum, const void *x, + Nd4jLong xEws, const void *y, + Nd4jLong yEws, void *z, Nd4jLong zEws, + void *extraParams, Nd4jLong n, + const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_TT( + exec, PARAMS(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop), + PAIRWISE_BOOL_OPS); +}; + +template +template +void PairWiseBoolTransform::exec(const void *vx, Nd4jLong xEws, + const void *vy, Nd4jLong yEws, void *vz, + Nd4jLong zEws, void *vextraParams, + const Nd4jLong n, const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + if (xEws == 1 && yEws == 1 && zEws == 1) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i] = OpType::op(x[i], y[i], extraParams); + } else { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); + } +} + +template +void PairWiseBoolTransform::exec( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, + void *extraParams, const uint64_t start, const uint64_t stop) { + DISPATCH_BY_OPNUM_TT(exec, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams, start, stop), + PAIRWISE_BOOL_OPS); +}; + +template +template +void PairWiseBoolTransform::exec( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams, const uint64_t start, const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + auto n = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + if (shape::isScalar(yShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], y[0], extraParams); + }; + } else { + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[0], extraParams); + }; } + return; + } + + const sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); + const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); + + if ((kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) && + sameShapesXY) { + exec(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop); + } else if ((kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) && + !sameShapesXY) { // not same shape + exec(x, xEws, y, yEws, z, zEws, extraParams, + shape::length(yShapeInfo), start, stop); + } else { + if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && + shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], y[offset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[offset], y[offset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = OpType::op(x[offset], y[yOffset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto offset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = OpType::op(x[xOffset], y[offset], extraParams); + }; + } else { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + }; + } + } } + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT PairWiseBoolTransform, , + LIBND4J_TYPES, BOOL_TYPES); +} // namespace pairwise_transforms +} // namespace functions diff --git a/libnd4j/include/loops/cpu/pairwise_int.cpp b/libnd4j/include/loops/cpu/pairwise_int.cpp index 2494c2df7f71..e77592103a01 100644 --- a/libnd4j/include/loops/cpu/pairwise_int.cpp +++ b/libnd4j/include/loops/cpu/pairwise_int.cpp @@ -18,205 +18,210 @@ // @author raver119@gmail.com // -#include -#include +#include #include #include -#include +#include +#include using namespace simdOps; namespace functions { - namespace pairwise_transforms { - - template - void PairWiseIntTransform::exec(const int opNum, - const void *x, Nd4jLong xEws, - const void *y, Nd4jLong yEws, - void *z, Nd4jLong zEws, - void *extraParams, - Nd4jLong n, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_T(exec, PARAMS(x, - xEws, - y, - yEws, - z, - zEws, - extraParams, - n, start, stop), PAIRWISE_INT_OPS); - }; - - - - template - template - void PairWiseIntTransform::exec(const void *vx, Nd4jLong xEws, - const void *vy, Nd4jLong yEws, - void *vz, Nd4jLong zEws, - void *vextraParams, - const Nd4jLong n, - const uint64_t start, - const uint64_t stop) { - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - if (xEws == 1 && yEws == 1 && zEws == 1) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i] = OpType::op(x[i], y[i], extraParams); - } - else { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i*zEws] = OpType::op(x[i*xEws], y[i*yEws], extraParams); - } - } - - template - void PairWiseIntTransform::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_T(exec, PARAMS(x, - xShapeInfo, - y, - yShapeInfo, - z, - zShapeInfo, - extraParams, start, stop), - PAIRWISE_INT_OPS); - }; - - - template - template - void PairWiseIntTransform::exec(const void *vx, const Nd4jLong* xShapeInfo, - const void *vy, const Nd4jLong* yShapeInfo, - void *vz, const Nd4jLong* zShapeInfo, - void *vextraParams, - const uint64_t start, - const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - auto n = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - if (shape::isScalar(yShapeInfo)) { - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - PRAGMA_OMP_SIMD - for(auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], y[0], extraParams); - }; - } - else { - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for(auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], y[0], extraParams); - }; - } - return; - } - - const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); - const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); - - if ((kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) && sameShapesXY) { - exec(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop); - } - else if ((kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape - exec(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo), start, stop); - } - else { - - if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], y[offset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[offset], y[offset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpType::op(x[offset], y[yOffset], extraParams); - }; - } - else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpType::op(x[xOffset], y[offset], extraParams); - }; - } - else { - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - }; - } - } - } - - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT PairWiseIntTransform, , INTEGER_TYPES); +namespace pairwise_transforms { + +template +void PairWiseIntTransform::exec(const int opNum, const void *x, + Nd4jLong xEws, const void *y, Nd4jLong yEws, + void *z, Nd4jLong zEws, void *extraParams, + Nd4jLong n, const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_T( + exec, PARAMS(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop), + PAIRWISE_INT_OPS); +}; + +template +template +void PairWiseIntTransform::exec(const void *vx, Nd4jLong xEws, + const void *vy, Nd4jLong yEws, void *vz, + Nd4jLong zEws, void *vextraParams, + const Nd4jLong n, const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + if (xEws == 1 && yEws == 1 && zEws == 1) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i] = OpType::op(x[i], y[i], extraParams); + } else { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); + } +} + +template +void PairWiseIntTransform::exec(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, + void *extraParams, const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_T(exec, + PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams, start, stop), + PAIRWISE_INT_OPS); +}; + +template +template +void PairWiseIntTransform::exec(const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams, const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + auto n = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + if (shape::isScalar(yShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], y[0], extraParams); + }; + } else { + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[0], extraParams); + }; } + return; + } + + const sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); + const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); + + if ((kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) && + sameShapesXY) { + exec(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop); + } else if ((kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) && + !sameShapesXY) { // not same shape + exec(x, xEws, y, yEws, z, zEws, extraParams, + shape::length(yShapeInfo), start, stop); + } else { + if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && + shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], y[offset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[offset], y[offset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = OpType::op(x[offset], y[yOffset], extraParams); + }; + } else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto offset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = OpType::op(x[xOffset], y[offset], extraParams); + }; + } else { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + }; + } + } } + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT PairWiseIntTransform, , + INTEGER_TYPES); +} // namespace pairwise_transforms +} // namespace functions diff --git a/libnd4j/include/loops/cpu/random.hpp b/libnd4j/include/loops/cpu/random.hpp index ff92531cf8b3..6e80368f6f17 100644 --- a/libnd4j/include/loops/cpu/random.hpp +++ b/libnd4j/include/loops/cpu/random.hpp @@ -19,263 +19,298 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include #include +#include +#include +#include using namespace randomOps; namespace functions { - namespace random { - - - template - template - void RandomFunction::execTransform(Nd4jPointer state, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraArguments) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - if (OpClass::requiresSpecial) { - OpClass::specialOp(state, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraArguments); - return; - } - - auto length = shape::length(zShapeInfo); - - sd::graph::RandomGenerator* rng = reinterpret_cast(state); - - if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - - - if(shape::elementWiseStride(zShapeInfo) == 1 && shape::elementWiseStride(xShapeInfo) == 1 && shape::elementWiseStride(yShapeInfo) == 1 && - shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(yShapeInfo) ){ - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - z[i] = OpClass::op(x[i], y[i], i, length, rng, extraArguments); - } - }; - samediff::Threads::parallel_for(func, 0, length, 1); - } - else{ - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - } - else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { - - uint xShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpClass::op(x[offset], y[yOffset], i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { - - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - z[offset] = OpClass::op(x[xOffset], y[offset], i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - else { - - uint xShapeInfoCast[MAX_RANK]; - uint yShapeInfoCast[MAX_RANK]; - uint zShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpClass::op(x[xOffset], y[yOffset], i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - }; - - - - template - template - void RandomFunction::execTransform(Nd4jPointer state, - const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraArguments) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - auto length = shape::length(zShapeInfo); - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - sd::graph::RandomGenerator* rng = reinterpret_cast(state); - - if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - - if(shape::elementWiseStride(zShapeInfo) == 1 && shape::elementWiseStride(xShapeInfo) == 1 && shape::order(xShapeInfo) == shape::order(zShapeInfo)){ - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - z[i] = OpClass::op(x[i], i, length, rng, extraArguments); - } - }; - samediff::Threads::parallel_for(func, 0, length, 1); - } - else{ - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - } - else { - - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpClass::op(x[xOffset], i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - } - - - template - template - void RandomFunction::execTransform(Nd4jPointer state, void *vz, const Nd4jLong *zShapeInfo, void *vextraArguments) { - - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - auto length = shape::length(zShapeInfo); - - sd::graph::RandomGenerator* rng = reinterpret_cast(state); - - if(shape::elementWiseStride(zShapeInfo) == 1){ - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - z[i] = OpClass::op( i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } - else{ - sd::OmpLaunchHelper info(length); - - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[offset] = OpClass::op(i, length, rng, extraArguments); - } - }; - - samediff::Threads::parallel_for(func, 0, length, 1); - } +namespace random { + +template +template +void RandomFunction::execTransform(Nd4jPointer state, const void *vx, + const Nd4jLong *xShapeInfo, + const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + void *vextraArguments) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + if (OpClass::requiresSpecial) { + OpClass::specialOp(state, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraArguments); + return; + } + + auto length = shape::length(zShapeInfo); + + sd::graph::RandomGenerator *rng = + reinterpret_cast(state); + + if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && + shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + if (shape::elementWiseStride(zShapeInfo) == 1 && + shape::elementWiseStride(xShapeInfo) == 1 && + shape::elementWiseStride(yShapeInfo) == 1 && + shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(zShapeInfo) == shape::order(yShapeInfo)) { + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + z[i] = OpClass::op(x[i], y[i], i, length, rng, extraArguments); } - - template - void RandomFunction::execTransform(int opNum, Nd4jPointer state, const void *x, const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, void *extraArguments) { - DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, x, xShapeInfo, z, zShapeInfo, extraArguments), RANDOM_OPS) + }; + samediff::Threads::parallel_for(func, 0, length, 1); + } else { + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = + OpClass::op(x[offset], y[offset], i, length, rng, extraArguments); } + }; - template - void RandomFunction::execTransform(int opNum, Nd4jPointer state, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo, void *extraArguments) { - DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraArguments), RANDOM_OPS) + samediff::Threads::parallel_for(func, 0, length, 1); + } + } else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = + OpClass::op(x[offset], y[offset], i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = + OpClass::op(x[offset], y[yOffset], i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto offset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + z[offset] = + OpClass::op(x[xOffset], y[offset], i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } else { + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = + OpClass::op(x[xOffset], y[yOffset], i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } +}; + +template +template +void RandomFunction::execTransform(Nd4jPointer state, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + void *vextraArguments) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + auto length = shape::length(zShapeInfo); + + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + sd::graph::RandomGenerator *rng = + reinterpret_cast(state); + + if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + if (shape::elementWiseStride(zShapeInfo) == 1 && + shape::elementWiseStride(xShapeInfo) == 1 && + shape::order(xShapeInfo) == shape::order(zShapeInfo)) { + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + z[i] = OpClass::op(x[i], i, length, rng, extraArguments); } - - template - void RandomFunction::execTransform(int opNum, Nd4jPointer state, void *z, const Nd4jLong *zShapeInfo, void *extraArguments) { - DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, z, zShapeInfo, extraArguments), RANDOM_OPS) + }; + samediff::Threads::parallel_for(func, 0, length, 1); + } else { + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments); } + }; - - //BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, length, 1); } -} \ No newline at end of file + } else { + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpClass::op(x[xOffset], i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } +} + +template +template +void RandomFunction::execTransform(Nd4jPointer state, void *vz, + const Nd4jLong *zShapeInfo, + void *vextraArguments) { + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + auto length = shape::length(zShapeInfo); + + sd::graph::RandomGenerator *rng = + reinterpret_cast(state); + + if (shape::elementWiseStride(zShapeInfo) == 1) { + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + z[i] = OpClass::op(i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } else { + sd::OmpLaunchHelper info(length); + + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[offset] = OpClass::op(i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } +} + +template +void RandomFunction::execTransform(int opNum, Nd4jPointer state, + const void *x, const Nd4jLong *xShapeInfo, + void *z, const Nd4jLong *zShapeInfo, + void *extraArguments) { + DISPATCH_BY_OPNUM_T( + execTransform, + PARAMS(state, x, xShapeInfo, z, zShapeInfo, extraArguments), RANDOM_OPS) +} + +template +void RandomFunction::execTransform(int opNum, Nd4jPointer state, + const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, + void *extraArguments) { + DISPATCH_BY_OPNUM_T(execTransform, + PARAMS(state, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraArguments), + RANDOM_OPS) +} + +template +void RandomFunction::execTransform(int opNum, Nd4jPointer state, void *z, + const Nd4jLong *zShapeInfo, + void *extraArguments) { + DISPATCH_BY_OPNUM_T(execTransform, + PARAMS(state, z, zShapeInfo, extraArguments), RANDOM_OPS) +} + +// BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , +// FLOAT_TYPES); +} // namespace random +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp b/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp index ae214f12c80c..6807e1fda824 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp @@ -19,208 +19,232 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include +#include using namespace simdOps; namespace functions { - namespace reduce { - template - template - void _CUDA_H ReduceBoolFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - const Nd4jLong length = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - - if (shape::isEmpty(xShapeInfo)) { - z[0] = OpType::startingValue(x); - return; - } - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < length; i++) - z[i] = startingVal; - return; - } - - if (xEws >= 1) { - z[0] = execScalar(x, xEws, length, extraParams); - } - else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - for (Nd4jLong i = 0; i < length; i++) - startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - - z[0] = OpType::postProcess(startingValue, length, extraParams); - } - } - - - template - template - Z _CUDA_H ReduceBoolFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - - const Nd4jLong length = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - - if (xEws >= 1) { - return execScalar(x, xEws, length, extraParams); - } - else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - for (Nd4jLong i = 0; i < length; i++) - startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - - return OpType::postProcess(startingValue, length, extraParams); - } - } - - template - Y ReduceBoolFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams) { - RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_BOOL_OPS); - } - - template - void ReduceBoolFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo) { - DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_BOOL_OPS); - } - - template - void ReduceBoolFunction::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset, start, stop), REDUCE_BOOL_OPS); - } - - template - template - void _CUDA_H ReduceBoolFunction::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vresult, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, int64_t start, int64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vresult); - auto extraParams = reinterpret_cast(vextraParams); - - auto resultLength = shape::length(zShapeInfo); - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < resultLength; i++) - z[i] = startingVal; - return; - } - - //pre squeezed: this is for keeping the pointer to the original - //shape information for tad offset - //the squeezed information doesn't render the right strides for - //tad offset - // || tad.wholeThing - if (resultLength == 1 || dimension == nullptr || dimensionLength == shape::rank(xShapeInfo)) { - z[0] = execScalar(x, xShapeInfo, extraParams); - return; - } - - auto tadOnlyShapeInfo = tadShapeInfo; - auto tadOffsets = tadOffset; - - if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { - if (dimensionLength < 1) - return; - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - tadOnlyShapeInfo = tadPack.primaryShapeInfo(); - tadOffsets = tadPack.primaryOffsets(); - } +namespace reduce { +template +template +void _CUDA_H ReduceBoolFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + const Nd4jLong length = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + + if (shape::isEmpty(xShapeInfo)) { + z[0] = OpType::startingValue(x); + return; + } + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < length; i++) z[i] = startingVal; + return; + } + + if (xEws >= 1) { + z[0] = execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + for (Nd4jLong i = 0; i < length; i++) + startingValue = OpType::update( + startingValue, + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + + z[0] = OpType::postProcess(startingValue, length, extraParams); + } +} + +template +template +Z _CUDA_H ReduceBoolFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + + const Nd4jLong length = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + + if (xEws >= 1) { + return execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + for (Nd4jLong i = 0; i < length; i++) + startingValue = OpType::update( + startingValue, + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + + return OpType::postProcess(startingValue, length, extraParams); + } +} + +template +Y ReduceBoolFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams) { + RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), + REDUCE_BOOL_OPS); +} + +template +void ReduceBoolFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_TT(execScalar, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), + REDUCE_BOOL_OPS); +} + +template +void ReduceBoolFunction::exec( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset, start, stop), + REDUCE_BOOL_OPS); +} + +template +template +void _CUDA_H ReduceBoolFunction::exec( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, + void *vresult, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vresult); + auto extraParams = reinterpret_cast(vextraParams); + + auto resultLength = shape::length(zShapeInfo); + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < resultLength; i++) z[i] = startingVal; + return; + } + + // pre squeezed: this is for keeping the pointer to the original + // shape information for tad offset + // the squeezed information doesn't render the right strides for + // tad offset + // || tad.wholeThing + if (resultLength == 1 || dimension == nullptr || + dimensionLength == shape::rank(xShapeInfo)) { + z[0] = execScalar(x, xShapeInfo, extraParams); + return; + } + + auto tadOnlyShapeInfo = tadShapeInfo; + auto tadOffsets = tadOffset; + + if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { + if (dimensionLength < 1) return; + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dimension, dimensionLength); + tadOnlyShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } #ifdef INLINE_LOOPS - sd::ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #else - sd::ReductionBoolLoops::template innerloopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionBoolLoops::template innerloopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } - - - template - template - void _CUDA_H ReduceBoolFunction::exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vresult, const Nd4jLong *resultShapeInfo) { - auto z = reinterpret_cast(vresult); - z[0] = execScalar(x, xShapeInfo, extraParams); - } - - template - template - Z _CUDA_H ReduceBoolFunction::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance()->maxThreads()); - Z intermediate[64]; - - PRAGMA_OMP_SIMD - for (auto e = 0; e < maxThreads; e++) - intermediate[e] = OpType::startingValue(x); - - auto func = PRAGMA_THREADS_FOR { - if (xEws == 1) { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams); - } else { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - - // merge results - for (int e = 1; e < maxThreads; e++) - intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams); - - // return result - return OpType::postProcess(intermediate[0], length, extraParams); - } - - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceBoolFunction, , LIBND4J_TYPES, BOOL_TYPES); +} + +template +template +void _CUDA_H ReduceBoolFunction::exec(const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *vresult, + const Nd4jLong *resultShapeInfo) { + auto z = reinterpret_cast(vresult); + z[0] = execScalar(x, xShapeInfo, extraParams); +} + +template +template +Z _CUDA_H ReduceBoolFunction::execScalar(const void *vx, Nd4jLong xEws, + Nd4jLong length, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + int maxThreads = + sd::math::nd4j_min(64, sd::Environment::getInstance()->maxThreads()); + Z intermediate[64]; + + PRAGMA_OMP_SIMD + for (auto e = 0; e < maxThreads; e++) + intermediate[e] = OpType::startingValue(x); + + auto func = PRAGMA_THREADS_FOR { + if (xEws == 1) { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i], extraParams), extraParams); + } else { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i * xEws], extraParams), extraParams); } -} \ No newline at end of file + }; + + maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + + // merge results + for (int e = 1; e < maxThreads; e++) + intermediate[0] = + OpType::update(intermediate[0], intermediate[e], extraParams); + + // return result + return OpType::postProcess(intermediate[0], length, extraParams); +} + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceBoolFunction, , + LIBND4J_TYPES, BOOL_TYPES); +} // namespace reduce +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/reduce/reduce_float.hpp b/libnd4j/include/loops/cpu/reduce/reduce_float.hpp index 1795dbc3d8c3..1f8f0e225e48 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_float.hpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_float.hpp @@ -19,241 +19,261 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include +#include using namespace simdOps; namespace functions { - namespace reduce { - template - template - void _CUDA_H ReduceFloatFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - const Nd4jLong length = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - - if (shape::isEmpty(xShapeInfo)) { - if (std::is_same>::value) { - z[0] = sd::DataTypeUtils::nanOrZero(); - } else { - z[0] = OpType::startingValue(x); - } - return; - } - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < length; i++) - z[i] = startingVal; - - return; - } - - if (xEws > 0) { - z[0] = execScalar(x, xEws, length, extraParams); - } - else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance()->maxThreads()); - Z intermediate[64]; - - PRAGMA_OMP_SIMD - for (auto e = 0; e < maxThreads; e++) - intermediate[e] = OpType::startingValue(x); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - - // merge results - for (int e = 1; e < maxThreads; e++) - intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams); - - // write out results - z[0] = OpType::postProcess(intermediate[0], length, extraParams); - } - } - - - template - template - Z _CUDA_H ReduceFloatFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) { - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - - const Nd4jLong length = shape::length(xShapeInfo); - int xEws = shape::elementWiseStride(xShapeInfo); - - if (xEws > 0) { - return execScalar(x, xEws, length, extraParams); - } - else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - for (Nd4jLong i = 0; i < length; i++) - startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - - return OpType::postProcess(startingValue, length, extraParams); - } - } - - template - Y ReduceFloatFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams) { - RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_FLOAT_OPS); - } - - template - void ReduceFloatFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo) { - DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_FLOAT_OPS); - } - - template - void ReduceFloatFunction::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, - xShapeInfo, - extraParams, - z, - zShapeInfo, - dimension, - dimensionLength, - tadShapeInfo, - tadOffset, start, stop), - REDUCE_FLOAT_OPS); - } - - template - template - void _CUDA_H ReduceFloatFunction::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vresult, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vresult); - auto extraParams = reinterpret_cast(vextraParams); - - auto resultLength = shape::length(zShapeInfo); - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = std::is_same>::value ? sd::DataTypeUtils::nanOrZero() : static_cast(OpType::startingValue(x)); - - for (Nd4jLong i = 0; i < resultLength; i++) - z[i] = startingVal; - return; - } - - //pre squeezed: this is for keeping the pointer to the original - //shape information for tad offset - //the squeezed information doesn't render the right strides for - //tad offset - // || tad.wholeThing - if (resultLength == 1 || dimension == nullptr || dimensionLength == shape::rank(xShapeInfo)) { - z[0] = execScalar(x, xShapeInfo, extraParams); - return; - } - - if (OpType::requiresSpecialAccumulation) { - OpType::execSpecial(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset); - return; - } - - auto tadOnlyShapeInfo = tadShapeInfo; - auto tadOffsets = tadOffset; - - if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { - if (dimensionLength < 0) - return; - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - tadOnlyShapeInfo = tadPack.primaryShapeInfo(); - tadOffsets = tadPack.primaryOffsets(); - } +namespace reduce { +template +template +void _CUDA_H ReduceFloatFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + const Nd4jLong length = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + + if (shape::isEmpty(xShapeInfo)) { + if (std::is_same>::value) { + z[0] = sd::DataTypeUtils::nanOrZero(); + } else { + z[0] = OpType::startingValue(x); + } + return; + } + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < length; i++) z[i] = startingVal; + + return; + } + + if (xEws > 0) { + z[0] = execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + int maxThreads = sd::math::nd4j_min( + 64, sd::Environment::getInstance()->maxThreads()); + Z intermediate[64]; + + PRAGMA_OMP_SIMD + for (auto e = 0; e < maxThreads; e++) + intermediate[e] = OpType::startingValue(x); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = OpType::update( + intermediate[thread_id], + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + }; + + maxThreads = + samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + + // merge results + for (int e = 1; e < maxThreads; e++) + intermediate[0] = + OpType::update(intermediate[0], intermediate[e], extraParams); + + // write out results + z[0] = OpType::postProcess(intermediate[0], length, extraParams); + } +} + +template +template +Z _CUDA_H ReduceFloatFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + + const Nd4jLong length = shape::length(xShapeInfo); + int xEws = shape::elementWiseStride(xShapeInfo); + + if (xEws > 0) { + return execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + for (Nd4jLong i = 0; i < length; i++) + startingValue = OpType::update( + startingValue, + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + + return OpType::postProcess(startingValue, length, extraParams); + } +} + +template +Y ReduceFloatFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams) { + RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), + REDUCE_FLOAT_OPS); +} + +template +void ReduceFloatFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_TT(execScalar, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), + REDUCE_FLOAT_OPS); +} + +template +void ReduceFloatFunction::exec( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset, start, stop), + REDUCE_FLOAT_OPS); +} + +template +template +void _CUDA_H ReduceFloatFunction::exec( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, + void *vresult, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vresult); + auto extraParams = reinterpret_cast(vextraParams); + + auto resultLength = shape::length(zShapeInfo); + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = std::is_same>::value + ? sd::DataTypeUtils::nanOrZero() + : static_cast(OpType::startingValue(x)); + + for (Nd4jLong i = 0; i < resultLength; i++) z[i] = startingVal; + return; + } + + // pre squeezed: this is for keeping the pointer to the original + // shape information for tad offset + // the squeezed information doesn't render the right strides for + // tad offset + // || tad.wholeThing + if (resultLength == 1 || dimension == nullptr || + dimensionLength == shape::rank(xShapeInfo)) { + z[0] = execScalar(x, xShapeInfo, extraParams); + return; + } + + if (OpType::requiresSpecialAccumulation) { + OpType::execSpecial(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset); + return; + } + + auto tadOnlyShapeInfo = tadShapeInfo; + auto tadOffsets = tadOffset; + + if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { + if (dimensionLength < 0) return; + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dimension, dimensionLength); + tadOnlyShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } #ifdef INLINE_LOOPS - sd::ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #else - sd::ReductionFloatLoops::template innerloopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionFloatLoops::template innerloopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } - - - template - template - void _CUDA_H ReduceFloatFunction::exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vresult, const Nd4jLong *resultShapeInfo) { - // FIXME: wtf??? - auto z = reinterpret_cast(vresult); - z[0] = execScalar(x, xShapeInfo, extraParams); - } - - template - template - Z _CUDA_H ReduceFloatFunction::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance()->maxThreads()); - Z intermediate[64]; - - PRAGMA_OMP_SIMD - for (auto e = 0; e < maxThreads; e++) - intermediate[e] = OpType::startingValue(x); - - auto func = PRAGMA_THREADS_FOR { - if (xEws == 1) { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams); - } else { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - - // merge results - for (int e = 1; e < maxThreads; e++) - intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams); - - // return result - return OpType::postProcess(intermediate[0], length, extraParams); - } +} + +template +template +void _CUDA_H ReduceFloatFunction::exec(const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *vresult, + const Nd4jLong *resultShapeInfo) { + // FIXME: wtf??? + auto z = reinterpret_cast(vresult); + z[0] = execScalar(x, xShapeInfo, extraParams); +} + +template +template +Z _CUDA_H ReduceFloatFunction::execScalar(const void *vx, Nd4jLong xEws, + Nd4jLong length, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + int maxThreads = + sd::math::nd4j_min(64, sd::Environment::getInstance()->maxThreads()); + Z intermediate[64]; + + PRAGMA_OMP_SIMD + for (auto e = 0; e < maxThreads; e++) + intermediate[e] = OpType::startingValue(x); + + auto func = PRAGMA_THREADS_FOR { + if (xEws == 1) { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i], extraParams), extraParams); + } else { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i * xEws], extraParams), extraParams); } -} \ No newline at end of file + }; + + maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + + // merge results + for (int e = 1; e < maxThreads; e++) + intermediate[0] = + OpType::update(intermediate[0], intermediate[e], extraParams); + + // return result + return OpType::postProcess(intermediate[0], length, extraParams); +} +} // namespace reduce +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/reduce/reduce_long.cpp b/libnd4j/include/loops/cpu/reduce/reduce_long.cpp index 1193d94fdebc..d8b518742b61 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_long.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_long.cpp @@ -19,230 +19,256 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include +#include using namespace simdOps; namespace functions { - namespace reduce { - template - template - void _CUDA_H ReduceLongFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - const Nd4jLong length = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - - if (shape::isEmpty(xShapeInfo)) { - z[0] = OpType::startingValue(x); - return; - } - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < length; i++) - z[i] = startingVal; - return; - } - - if (xEws >= 1) { - z[0] = execScalar(x, xEws, length, extraParams); - } - else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance()->maxThreads()); - Z intermediate[64]; - - PRAGMA_OMP_SIMD - for (auto e = 0; e < maxThreads; e++) - intermediate[e] = OpType::startingValue(x); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - - // merge results - for (int e = 1; e < maxThreads; e++) - intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams); - - // write out results - z[0] = OpType::postProcess(intermediate[0], length, extraParams); - } - } - - - template - template - Z _CUDA_H ReduceLongFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) { - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - - const Nd4jLong length = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - - if (xEws >= 1) { - return execScalar(x, xEws, length, extraParams); - } - else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - for (Nd4jLong i = 0; i < length; i++) - startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - - return OpType::postProcess(startingValue, length, extraParams); - } - } - - - template - Y ReduceLongFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams) { - RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_LONG_OPS); - } - - template - void ReduceLongFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo) { - DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_LONG_OPS); - } - - template - void ReduceLongFunction::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset, start, stop), REDUCE_LONG_OPS); - } - - template - template - void _CUDA_H ReduceLongFunction::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vresult, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vresult); - auto extraParams = reinterpret_cast(vextraParams); - - auto resultLength = shape::length(zShapeInfo); - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < resultLength; i++) - z[i] = startingVal; - return; - } - - //pre squeezed: this is for keeping the pointer to the original - //shape information for tad offset - //the squeezed information doesn't render the right strides for - //tad offset - // || tad.wholeThing - if (resultLength == 1 || dimension == nullptr || dimensionLength == shape::rank(xShapeInfo)) { - z[0] = execScalar(x, xShapeInfo, extraParams); - return; - } - - if (OpType::requiresSpecialAccumulation) { - OpType::execSpecial(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset); - return; - } - - auto tadOnlyShapeInfo = tadShapeInfo; - auto tadOffsets = tadOffset; - - if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { - if (dimensionLength < 1) - return; - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - tadOnlyShapeInfo = tadPack.primaryShapeInfo(); - tadOffsets = tadPack.primaryOffsets(); - } +namespace reduce { +template +template +void _CUDA_H ReduceLongFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + const Nd4jLong length = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + + if (shape::isEmpty(xShapeInfo)) { + z[0] = OpType::startingValue(x); + return; + } + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < length; i++) z[i] = startingVal; + return; + } + + if (xEws >= 1) { + z[0] = execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + int maxThreads = sd::math::nd4j_min( + 64, sd::Environment::getInstance()->maxThreads()); + Z intermediate[64]; + + PRAGMA_OMP_SIMD + for (auto e = 0; e < maxThreads; e++) + intermediate[e] = OpType::startingValue(x); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = OpType::update( + intermediate[thread_id], + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + }; + + maxThreads = + samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + + // merge results + for (int e = 1; e < maxThreads; e++) + intermediate[0] = + OpType::update(intermediate[0], intermediate[e], extraParams); + + // write out results + z[0] = OpType::postProcess(intermediate[0], length, extraParams); + } +} + +template +template +Z _CUDA_H ReduceLongFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + + const Nd4jLong length = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + + if (xEws >= 1) { + return execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + for (Nd4jLong i = 0; i < length; i++) + startingValue = OpType::update( + startingValue, + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + + return OpType::postProcess(startingValue, length, extraParams); + } +} + +template +Y ReduceLongFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams) { + RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), + REDUCE_LONG_OPS); +} + +template +void ReduceLongFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_TT(execScalar, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), + REDUCE_LONG_OPS); +} + +template +void ReduceLongFunction::exec( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset, start, stop), + REDUCE_LONG_OPS); +} + +template +template +void _CUDA_H ReduceLongFunction::exec( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, + void *vresult, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vresult); + auto extraParams = reinterpret_cast(vextraParams); + + auto resultLength = shape::length(zShapeInfo); + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < resultLength; i++) z[i] = startingVal; + return; + } + + // pre squeezed: this is for keeping the pointer to the original + // shape information for tad offset + // the squeezed information doesn't render the right strides for + // tad offset + // || tad.wholeThing + if (resultLength == 1 || dimension == nullptr || + dimensionLength == shape::rank(xShapeInfo)) { + z[0] = execScalar(x, xShapeInfo, extraParams); + return; + } + + if (OpType::requiresSpecialAccumulation) { + OpType::execSpecial(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset); + return; + } + + auto tadOnlyShapeInfo = tadShapeInfo; + auto tadOffsets = tadOffset; + + if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { + if (dimensionLength < 1) return; + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dimension, dimensionLength); + tadOnlyShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } #ifdef INLINE_LOOPS - sd::ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #else - sd::ReductionLongLoops::template innerloopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionLongLoops::template innerloopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } - - - template - template - void _CUDA_H ReduceLongFunction::exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vresult, const Nd4jLong *resultShapeInfo) { - auto z = reinterpret_cast(vresult); - z[0] = execScalar(x, xShapeInfo, extraParams); - } - - template - template - Z _CUDA_H ReduceLongFunction::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance()->maxThreads()); - Z intermediate[64]; - - PRAGMA_OMP_SIMD - for (auto e = 0; e < maxThreads; e++) - intermediate[e] = OpType::startingValue(x); - - auto func = PRAGMA_THREADS_FOR { - if (xEws == 1) { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams); - } else { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); +} + +template +template +void _CUDA_H ReduceLongFunction::exec(const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *vresult, + const Nd4jLong *resultShapeInfo) { + auto z = reinterpret_cast(vresult); + z[0] = execScalar(x, xShapeInfo, extraParams); +} + +template +template +Z _CUDA_H ReduceLongFunction::execScalar(const void *vx, Nd4jLong xEws, + Nd4jLong length, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + int maxThreads = + sd::math::nd4j_min(64, sd::Environment::getInstance()->maxThreads()); + Z intermediate[64]; + + PRAGMA_OMP_SIMD + for (auto e = 0; e < maxThreads; e++) + intermediate[e] = OpType::startingValue(x); + + auto func = PRAGMA_THREADS_FOR { + if (xEws == 1) { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i], extraParams), extraParams); + } else { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i * xEws], extraParams), extraParams); + } + }; - // merge results - for (int e = 1; e < maxThreads; e++) - intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams); + maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - // return result - return OpType::postProcess(intermediate[0], length, extraParams); - } + // merge results + for (int e = 1; e < maxThreads; e++) + intermediate[0] = + OpType::update(intermediate[0], intermediate[e], extraParams); + // return result + return OpType::postProcess(intermediate[0], length, extraParams); +} - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceLongFunction, , LIBND4J_TYPES, LONG_TYPES); - } -} \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceLongFunction, , + LIBND4J_TYPES, LONG_TYPES); +} // namespace reduce +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/reduce/reduce_same.cpp b/libnd4j/include/loops/cpu/reduce/reduce_same.cpp index 059f987fc940..d2b0c65067a5 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_same.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_same.cpp @@ -19,239 +19,261 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include +#include +#include #include +#include +#include +#include +#include + #include -#include -#include using namespace simdOps; namespace functions { - namespace reduce { - template - template - void _CUDA_H ReduceSameFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - const auto length = shape::length(xShapeInfo); - const auto xEws = shape::elementWiseStride(xShapeInfo); - const int rank = shape::rank(xShapeInfo); - - if (shape::isEmpty(xShapeInfo)) { - z[0] = OpType::startingValue(x); - return; - } - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < length; i++) - z[i] = startingVal; - return; - } - - if (xEws >= 1) { - z[0] = execScalar(x, xEws, length, extraParams); - } - else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance()->maxThreads()); - X intermediate[64]; - - PRAGMA_OMP_SIMD - for (auto e = 0; e < maxThreads; e++) - intermediate[e] = OpType::startingValue(x); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - - // merge results - for (int e = 1; e < maxThreads; e++) - intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams); - - // write out results - z[0] = OpType::postProcess(intermediate[0], length, extraParams); - } - } - - - template - template - X _CUDA_H ReduceSameFunction::execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) { - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - - const Nd4jLong length = shape::length(xShapeInfo); - const auto xEws = shape::elementWiseStride(xShapeInfo); - - if (xEws >= 1) { - return execScalar(x, xEws, length, extraParams); - } else { - auto startingValue = OpType::startingValue(x); - uint xShapeInfoCast[MAX_RANK]; - bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - for (Nd4jLong i = 0; i < length; i++) - startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); - - return OpType::postProcess(startingValue, length, extraParams); - } - } - - template - X ReduceSameFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams) { - RETURNING_DISPATCH_BY_OPNUM_T(execScalar, PARAMS(x, xShapeInfo, extraParams), REDUCE_SAME_OPS); - } - - template - void ReduceSameFunction::execScalar(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo) { - DISPATCH_BY_OPNUM_T(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_SAME_OPS); - } - - template - void ReduceSameFunction::exec(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop) { - DISPATCH_BY_OPNUM_T(exec, PARAMS(x, - xShapeInfo, - extraParams, - z, - zShapeInfo, - dimension, - dimensionLength, - tadShapeInfo, - tadOffset, start, stop), - REDUCE_SAME_OPS); - } - - template - template - void _CUDA_H ReduceSameFunction::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - auto zLength = shape::length(zShapeInfo); - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < zLength; i++) - z[i] = startingVal; - return; - } - - //pre squeezed: this is for keeping the pointer to the original - //shape information for tad offset - //the squeezed information doesn't render the right strides for - //tad offset - // || tad.wholeThing - if (zLength == 1 || dimension == nullptr || dimensionLength == shape::rank(xShapeInfo)) { - z[0] = execScalar(x, xShapeInfo, extraParams); - return; - } - - if (OpType::requiresSpecialAccumulation) { - OpType::execSpecial(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset); - return; - } - - auto tadOnlyShapeInfo = tadShapeInfo; - auto tadOffsets = tadOffset; - - if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { - if (dimensionLength < 1) - return; - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - tadOnlyShapeInfo = tadPack.primaryShapeInfo(); - tadOffsets = tadPack.primaryOffsets(); - } +namespace reduce { +template +template +void _CUDA_H ReduceSameFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + const auto length = shape::length(xShapeInfo); + const auto xEws = shape::elementWiseStride(xShapeInfo); + const int rank = shape::rank(xShapeInfo); + + if (shape::isEmpty(xShapeInfo)) { + z[0] = OpType::startingValue(x); + return; + } + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < length; i++) z[i] = startingVal; + return; + } + + if (xEws >= 1) { + z[0] = execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + int maxThreads = sd::math::nd4j_min( + 64, sd::Environment::getInstance()->maxThreads()); + X intermediate[64]; + + PRAGMA_OMP_SIMD + for (auto e = 0; e < maxThreads; e++) + intermediate[e] = OpType::startingValue(x); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = OpType::update( + intermediate[thread_id], + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + }; + + maxThreads = + samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + + // merge results + for (int e = 1; e < maxThreads; e++) + intermediate[0] = + OpType::update(intermediate[0], intermediate[e], extraParams); + + // write out results + z[0] = OpType::postProcess(intermediate[0], length, extraParams); + } +} + +template +template +X _CUDA_H ReduceSameFunction::execScalar(const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + + const Nd4jLong length = shape::length(xShapeInfo); + const auto xEws = shape::elementWiseStride(xShapeInfo); + + if (xEws >= 1) { + return execScalar(x, xEws, length, extraParams); + } else { + auto startingValue = OpType::startingValue(x); + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + for (Nd4jLong i = 0; i < length; i++) + startingValue = OpType::update( + startingValue, + OpType::op( + x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], + extraParams), + extraParams); + + return OpType::postProcess(startingValue, length, extraParams); + } +} + +template +X ReduceSameFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams) { + RETURNING_DISPATCH_BY_OPNUM_T(execScalar, PARAMS(x, xShapeInfo, extraParams), + REDUCE_SAME_OPS); +} + +template +void ReduceSameFunction::execScalar(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_T(execScalar, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), + REDUCE_SAME_OPS); +} + +template +void ReduceSameFunction::exec(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, + int64_t stop) { + DISPATCH_BY_OPNUM_T( + exec, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset, start, stop), + REDUCE_SAME_OPS); +} + +template +template +void _CUDA_H ReduceSameFunction::exec( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, int64_t start, + int64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + auto zLength = shape::length(zShapeInfo); + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < zLength; i++) z[i] = startingVal; + return; + } + + // pre squeezed: this is for keeping the pointer to the original + // shape information for tad offset + // the squeezed information doesn't render the right strides for + // tad offset + // || tad.wholeThing + if (zLength == 1 || dimension == nullptr || + dimensionLength == shape::rank(xShapeInfo)) { + z[0] = execScalar(x, xShapeInfo, extraParams); + return; + } + + if (OpType::requiresSpecialAccumulation) { + OpType::execSpecial(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, tadShapeInfo, tadOffset); + return; + } + + auto tadOnlyShapeInfo = tadShapeInfo; + auto tadOffsets = tadOffset; + + if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) { + if (dimensionLength < 1) return; + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dimension, dimensionLength); + tadOnlyShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } #ifdef INLINE_LOOPS - sd::ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionLoops::template loopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #else - sd::ReductionSameLoops::template innerloopReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop); + sd::ReductionSameLoops::template innerloopReduce( + x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, + start, stop); #endif - } - - - template - template - void _CUDA_H ReduceSameFunction::exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo) { - auto z = reinterpret_cast(vz); - z[0] = execScalar(x, xShapeInfo, extraParams); - } - - template - template - X _CUDA_H ReduceSameFunction::execScalar(const void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance()->maxThreads()); - X intermediate[64]; - - PRAGMA_OMP_SIMD - for (auto e = 0; e < maxThreads; e++) - intermediate[e] = OpType::startingValue(x); - - auto func = PRAGMA_THREADS_FOR { - if (xEws == 1) { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams); - } else { - for (auto i = start; i < stop; i++) - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); +} + +template +template +void _CUDA_H ReduceSameFunction::exec(const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *zShapeInfo) { + auto z = reinterpret_cast(vz); + z[0] = execScalar(x, xShapeInfo, extraParams); +} + +template +template +X _CUDA_H ReduceSameFunction::execScalar(const void *vx, Nd4jLong xEws, + Nd4jLong length, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + int maxThreads = + sd::math::nd4j_min(64, sd::Environment::getInstance()->maxThreads()); + X intermediate[64]; + + PRAGMA_OMP_SIMD + for (auto e = 0; e < maxThreads; e++) + intermediate[e] = OpType::startingValue(x); + + auto func = PRAGMA_THREADS_FOR { + if (xEws == 1) { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i], extraParams), extraParams); + } else { + for (auto i = start; i < stop; i++) + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[i * xEws], extraParams), extraParams); + } + }; - // merge results - for (int e = 1; e < maxThreads; e++) - intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams); + maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - // return result - return OpType::postProcess(intermediate[0], length, extraParams); - } + // merge results + for (int e = 1; e < maxThreads; e++) + intermediate[0] = + OpType::update(intermediate[0], intermediate[e], extraParams); + // return result + return OpType::postProcess(intermediate[0], length, extraParams); +} - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ReduceSameFunction, , LIBND4J_TYPES); - } -} \ No newline at end of file +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ReduceSameFunction, , + LIBND4J_TYPES); +} // namespace reduce +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/reduce3.hpp b/libnd4j/include/loops/cpu/reduce3.hpp index 3a830377e9b4..fe2f981c7b8f 100644 --- a/libnd4j/include/loops/cpu/reduce3.hpp +++ b/libnd4j/include/loops/cpu/reduce3.hpp @@ -17,248 +17,276 @@ // @author raver119@gmail.com // @author Yurii Shyrma (iuriish@yahoo.com), created on 19.11.2018 - -#include -#include -#include -#include +#include #include #include -#include +#include +#include +#include +#include using namespace simdOps; namespace functions { -namespace reduce3 { +namespace reduce3 { ////////////////////////////////////////////////////////////////////////// template -template -void Reduce3::execScalar(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - auto length = shape::length(xShapeInfo); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto yEws = shape::elementWiseStride(yShapeInfo); - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY || sd::ArrayOptions::arrayType(yShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - const auto startingVal = OpType::startingValue(x); - - for (Nd4jLong i = 0; i < length; i++) - z[i] = startingVal; - - return; - } - - Z extraParamsVals[3] = {(Z) 0.0f, (Z) 0.0f, (Z) 0.0f}; - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - Z startingVal = OpType::startingValue(x); - int maxThreads = sd::math::nd4j_min(64, sd::Environment::getInstance()->maxThreads()); - Z intermediate[64]; - Z extraParamsLocal[3 * 64]; - +template +void Reduce3::execScalar(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + auto length = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY || + sd::ArrayOptions::arrayType(yShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + const auto startingVal = OpType::startingValue(x); + + for (Nd4jLong i = 0; i < length; i++) z[i] = startingVal; + + return; + } + + Z extraParamsVals[3] = {(Z)0.0f, (Z)0.0f, (Z)0.0f}; + + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + Z startingVal = OpType::startingValue(x); + int maxThreads = + sd::math::nd4j_min(64, sd::Environment::getInstance()->maxThreads()); + Z intermediate[64]; + Z extraParamsLocal[3 * 64]; + + PRAGMA_OMP_SIMD + for (int e = 0; e < maxThreads; e++) intermediate[e] = startingVal; + + memset(extraParamsLocal, 0, 3 * 64 * sizeof(Z)); + if (extraParams != nullptr) { PRAGMA_OMP_SIMD - for (int e = 0; e < maxThreads; e++) - intermediate[e] = startingVal; - - memset(extraParamsLocal, 0, 3 * 64 * sizeof(Z)); - if (extraParams != nullptr) { - PRAGMA_OMP_SIMD - // mostly for future reference - for (int e = 0; e < maxThreads; e++) { - extraParamsLocal[3 * e] = extraParams[0]; - extraParamsLocal[3 * e + 1] = extraParams[1]; - extraParamsLocal[3 * e + 2] = extraParams[2]; - } + // mostly for future reference + for (int e = 0; e < maxThreads; e++) { + extraParamsLocal[3 * e] = extraParams[0]; + extraParamsLocal[3 * e + 1] = extraParams[1]; + extraParamsLocal[3 * e + 2] = extraParams[2]; } - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, yShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1) { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], y[i], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - - } else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[offset], y[offset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - } else { - uint yShapeInfoCast[MAX_RANK]; - const bool canCastY = sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); - intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[xOffset], y[yOffset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); - } - - // merge step - for (int e = 0; e < maxThreads; e++) - OpType::aggregateExtraParams(extraParamsVals, extraParamsLocal + 3 * e); - - for (int e = 0; e < maxThreads; e++) - startingVal = OpType::update(startingVal, intermediate[e], extraParamsVals); - - // writing out result - z[0] = OpType::postProcess(startingVal, length, extraParamsVals); + } + + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, yShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + intermediate[thread_id] = OpType::update( + intermediate[thread_id], + OpType::op(x[i], y[i], extraParamsLocal + 3 * thread_id), + extraParamsLocal + 3 * thread_id); + } + }; + + maxThreads = + samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + + } else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + intermediate[thread_id] = OpType::update( + intermediate[thread_id], + OpType::op(x[offset], y[offset], extraParamsLocal + 3 * thread_id), + extraParamsLocal + 3 * thread_id); + } + }; + + maxThreads = + samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + } else { + uint yShapeInfoCast[MAX_RANK]; + const bool canCastY = + sd::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto yOffset = + shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY); + intermediate[thread_id] = + OpType::update(intermediate[thread_id], + OpType::op(x[xOffset], y[yOffset], + extraParamsLocal + 3 * thread_id), + extraParamsLocal + 3 * thread_id); + } + }; + + maxThreads = + samediff::Threads::parallel_for(func, 0, length, 1, maxThreads); + } + + // merge step + for (int e = 0; e < maxThreads; e++) + OpType::aggregateExtraParams(extraParamsVals, extraParamsLocal + 3 * e); + + for (int e = 0; e < maxThreads; e++) + startingVal = OpType::update(startingVal, intermediate[e], extraParamsVals); + + // writing out result + z[0] = OpType::postProcess(startingVal, length, extraParamsVals); } ////////////////////////////////////////////////////////////////////////// template -void Reduce3::execScalar(const int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo), REDUCE3_OPS); +void Reduce3::execScalar(const int opNum, const void *vx, + const Nd4jLong *xShapeInfo, + void *extraParamsVals, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_TT( + execScalar, + PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo), + REDUCE3_OPS); } - ////////////////////////////////////////////////////////////////////////// template -template -void Reduce3::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int64_t start, int64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - if(shape::isScalar(zShapeInfo)) { - execScalar(vx, xShapeInfo, vextraParams, vy, yShapeInfo, vz, zShapeInfo); - return; - } +template +void Reduce3::exec(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, int64_t start, int64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + if (shape::isScalar(zShapeInfo)) { + execScalar(vx, xShapeInfo, vextraParams, vy, yShapeInfo, vz, + zShapeInfo); + return; + } #ifdef INLINE_LOOPS - sd::Reduction3Loops::template loopReduce3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop); + sd::Reduction3Loops::template loopReduce3( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, + extraParams, start, stop); #else - sd::Reduction3Loops::template innerloopReduce3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop); + sd::Reduction3Loops::template innerloopReduce3( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, + extraParams, start, stop); #endif } ////////////////////////////////////////////////////////////////////////// template -template -void Reduce3::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - int64_t start, int64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); +template +void Reduce3::exec(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, int64_t start, + int64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); #ifdef INLINE_LOOPS - sd::Reduction3Loops::template loopReduce3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop); + sd::Reduction3Loops::template loopReduce3( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, + extraParams, start, stop); #else - sd::Reduction3Loops::template innerloopReduce3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop); + sd::Reduction3Loops::template innerloopReduce3( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, + extraParams, start, stop); #endif } - ////////////////////////////////////////////////////////////////////////// template -template -void Reduce3:: execAll(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - int64_t start, int64_t stop) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); +template +void Reduce3::execAll(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xOffsets, + const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yOffsets, int64_t start, + int64_t stop) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); #ifdef INLINE_LOOPS - sd::Reduction3Loops::template loopReduce3All(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, extraParams, start, stop); + sd::Reduction3Loops::template loopReduce3All( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, + yTadShapeInfo, yOffsets, extraParams, start, stop); #else - sd::Reduction3Loops::template innerloopReduce3All(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, extraParams, start, stop); + sd::Reduction3Loops::template innerloopReduce3All( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, + yTadShapeInfo, yOffsets, extraParams, start, stop); #endif } ////////////////////////////////////////////////////////////////////////// template -void Reduce3::exec(const int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int64_t start, int64_t stop) { - - DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, start, stop), REDUCE3_OPS); +void Reduce3::exec(const int opNum, const void *vx, + const Nd4jLong *xShapeInfo, void *extraParamsVals, + const void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, int64_t start, int64_t stop) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, + dimension, dimensionLength, start, stop), + REDUCE3_OPS); } - ////////////////////////////////////////////////////////////////////////// template -void Reduce3::exec(const int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - int64_t start, int64_t stop) { - - DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx,xShapeInfo,extraParamsVals,vy, yShapeInfo,vz,zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), REDUCE3_OPS); +void Reduce3::exec(const int opNum, const void *vx, + const Nd4jLong *xShapeInfo, void *extraParamsVals, + const void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, int64_t start, + int64_t stop) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, + dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), + REDUCE3_OPS); } - ////////////////////////////////////////////////////////////////////////// template -void Reduce3::execAll(const int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - int64_t start, int64_t stop) { - - DISPATCH_BY_OPNUM_TT(execAll, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), REDUCE3_OPS); +void Reduce3::execAll( + const int opNum, const void *vx, const Nd4jLong *xShapeInfo, + void *extraParamsVals, const void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, + const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, int64_t start, + int64_t stop) { + DISPATCH_BY_OPNUM_TT( + execAll, + PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, + dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, + yOffsets, start, stop), + REDUCE3_OPS); } -} -} \ No newline at end of file +} // namespace reduce3 +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/scalar.hpp b/libnd4j/include/loops/cpu/scalar.hpp index 236ba7e25ee2..9633de715ee5 100644 --- a/libnd4j/include/loops/cpu/scalar.hpp +++ b/libnd4j/include/loops/cpu/scalar.hpp @@ -18,193 +18,198 @@ // Created by raver119 on 08.10.2017. // -#include "../scalar.h" +#include +#include #include #include -#include -#include + #include "../legacy_ops.h" +#include "../scalar.h" using namespace simdOps; namespace functions { -namespace scalar { - +namespace scalar { //////////////////////////////////////////////////////////////////////// -template -template -void ScalarTransform::transform(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalars, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalars = reinterpret_cast(vscalars); - auto extraParams = reinterpret_cast(vextraParams); - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = xTadShapeInfo; - zTadOffsets = xTadOffsets; - } - - const int xTadEws = shape::elementWiseStride(xTadShapeInfo); - const int zTadEws = shape::elementWiseStride(zTadShapeInfo); - const int tadLength = shape::tadLength(xShapeInfo, dimension, dimensionLength); - const int numTads = shape::length(xShapeInfo) / tadLength; - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo); - - if (kindOfLoop != sd::LoopKind::EWS1 && kindOfLoop != sd::LoopKind::EWSNONZERO) { - printf("ScalarTransform::transform: super-bad loop visited. Shouldn't ever happen\n"); - return; - } - - int num_threads = sd::math::nd4j_min(numTads, sd::Environment::getInstance()->maxThreads()); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto r = start; r < stop; r++) { - auto oZ = z + zTadOffsets[r]; - auto oX = x + xTadOffsets[r]; - - PRAGMA_OMP_SIMD - for (int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(oX[f], scalars[r], extraParams); - }; - } - else { - for (auto r = start; r < stop; r++) { - auto oZ = z + zTadOffsets[r]; - auto oX = x + xTadOffsets[r]; - - PRAGMA_OMP_SIMD - for (int f = 0; f < tadLength; f++) - oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); - }; - } +template +template +void ScalarTransform::transform( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, const void *vscalars, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalars = reinterpret_cast(vscalars); + auto extraParams = reinterpret_cast(vextraParams); + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeInfo; + zTadOffsets = xTadOffsets; + } + + const int xTadEws = shape::elementWiseStride(xTadShapeInfo); + const int zTadEws = shape::elementWiseStride(zTadShapeInfo); + const int tadLength = + shape::tadLength(xShapeInfo, dimension, dimensionLength); + const int numTads = shape::length(xShapeInfo) / tadLength; + + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo); + + if (kindOfLoop != sd::LoopKind::EWS1 && + kindOfLoop != sd::LoopKind::EWSNONZERO) { + printf( + "ScalarTransform::transform: super-bad loop visited. Shouldn't " + "ever happen\n"); + return; + } + + int num_threads = sd::math::nd4j_min( + numTads, sd::Environment::getInstance()->maxThreads()); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto r = start; r < stop; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], scalars[r], extraParams); + }; + } else { + for (auto r = start; r < stop; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) + oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); + }; + } } //////////////////////////////////////////////////////////////////////// -template -void ScalarTransform::transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, - const uint64_t start, const uint64_t stop) { - - DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_OPS); +template +void ScalarTransform::transform( + int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) { + DISPATCH_BY_OPNUM_TTT( + transform, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, + zTadOffsets, start, stop), + SCALAR_OPS); } //////////////////////////////////////////////////////////////////////// -template -void ScalarTransform::transform(const int opNum, - const void *x, Nd4jLong xStride, - void *z, Nd4jLong zStride, - const void *scalar, - void *extraParams, - const uint64_t n, - const uint64_t start, const uint64_t stop) { - - DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xStride, z, zStride, scalar, extraParams, n, start, stop), SCALAR_OPS); +template +void ScalarTransform::transform(const int opNum, const void *x, + Nd4jLong xStride, void *z, + Nd4jLong zStride, const void *scalar, + void *extraParams, const uint64_t n, + const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_TTT( + transform, + PARAMS(x, xStride, z, zStride, scalar, extraParams, n, start, stop), + SCALAR_OPS); } //////////////////////////////////////////////////////////////////////// -template -void ScalarTransform::transform(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalar, - void *extraParams, - const uint64_t start, const uint64_t stop) { - - DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_OPS); +template +void ScalarTransform::transform(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, + const void *scalar, void *extraParams, + const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_TTT( + transform, + PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), + SCALAR_OPS); } //////////////////////////////////////////////////////////////////////// -template -template -void ScalarTransform::transform(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalar, - void *vextraParams, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalar = reinterpret_cast(vscalar)[0]; - auto extraParams = reinterpret_cast(vextraParams); - - const auto len = shape::length(xShapeInfo); - const auto xEws = shape::elementWiseStride(xShapeInfo); - const auto zEws = shape::elementWiseStride(zShapeInfo); - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) { - transform(x, xEws, z, zEws, vscalar, extraParams, len, start, stop); - } - else { - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], scalar, extraParams); - }; - } - else { - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], scalar, extraParams); - }; - } +template +template +void ScalarTransform::transform( + const void *vx, const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, const void *vscalar, void *vextraParams, + const uint64_t start, const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + const auto len = shape::length(xShapeInfo); + const auto xEws = shape::elementWiseStride(xShapeInfo); + const auto zEws = shape::elementWiseStride(zShapeInfo); + + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) { + transform(x, xEws, z, zEws, vscalar, extraParams, len, start, stop); + } else { + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], scalar, extraParams); + }; + } else { + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], scalar, extraParams); + }; } + } } //////////////////////////////////////////////////////////////////////// -template -template +template +template void ScalarTransform::transform(const void *vx, Nd4jLong xEws, void *vz, Nd4jLong zEws, const void *vscalar, - void *vextraParams, - const uint64_t len, const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalar = reinterpret_cast(vscalar)[0]; - auto extraParams = reinterpret_cast(vextraParams); - - if (xEws == 1 && zEws == 1) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i] = OpType::op(x[i], scalar, extraParams); - } - else { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams); - } + void *vextraParams, const uint64_t len, + const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + if (xEws == 1 && zEws == 1) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i] = OpType::op(x[i], scalar, extraParams); + } else { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams); + } } - - -} -} +} // namespace scalar +} // namespace functions diff --git a/libnd4j/include/loops/cpu/scalar_bool.cpp b/libnd4j/include/loops/cpu/scalar_bool.cpp index a05d7ad84eee..adcfa10d6450 100644 --- a/libnd4j/include/loops/cpu/scalar_bool.cpp +++ b/libnd4j/include/loops/cpu/scalar_bool.cpp @@ -19,187 +19,193 @@ // #include "../scalar_bool.h" + +#include +#include #include #include -#include -#include #include "../legacy_ops.h" using namespace simdOps; namespace functions { - namespace scalar { - - - template - template - void ScalarBoolTransform::transform(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalars, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalars = reinterpret_cast(vscalars); - auto extraParams = reinterpret_cast(vextraParams); - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = xTadShapeInfo; - zTadOffsets = xTadOffsets; - } - - // tad preparation - const int xTadEws = shape::elementWiseStride(xTadShapeInfo); - const int zTadEws = shape::elementWiseStride(zTadShapeInfo); - const int tadLength = shape::tadLength(xShapeInfo, dimension, dimensionLength); - const int numTads = shape::length(xShapeInfo) / tadLength; - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo); - - if (kindOfLoop != sd::LoopKind::EWS1 && kindOfLoop != sd::LoopKind::EWSNONZERO) { - printf("ScalarBoolTransform::transform: super-bad loop visited. Shouldn't ever happen\n"); - return; - } - - int num_threads = sd::math::nd4j_min(numTads, sd::Environment::getInstance()->maxThreads()); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto r = start; r < stop; r++) { - auto oZ = z + zTadOffsets[r]; - auto oX = x + xTadOffsets[r]; - - PRAGMA_OMP_SIMD - for (int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(oX[f], scalars[r], extraParams); - }; - } - else { - for (auto r = start; r < stop; r++) { - auto oZ = z + zTadOffsets[r]; - auto oX = x + xTadOffsets[r]; - - PRAGMA_OMP_SIMD - for (int f = 0; f < tadLength; f++) - oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); - }; - } - } - - template - void ScalarBoolTransform::transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_BOOL_OPS); - } - - - template - void ScalarBoolTransform::transform(const int opNum, - const void *x, Nd4jLong xEws, - void *z, Nd4jLong zEws, - const void *scalar, - void *extraParams, - const uint64_t n, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), SCALAR_BOOL_OPS); - } - - template - void ScalarBoolTransform::transform(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalar, - void *extraParams, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_BOOL_OPS); - } - - template - template - void ScalarBoolTransform::transform(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalar, - void *vextraParams, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalar = reinterpret_cast(vscalar)[0]; - auto extraParams = reinterpret_cast(vextraParams); - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - auto len = shape::length(xShapeInfo); - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) { - transform(x, xEws, z, zEws, vscalar, extraParams, len, start, stop); - return; - } - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], scalar, extraParams); - }; - } - else { - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], scalar, extraParams); - }; - } - } - - - template - template - void ScalarBoolTransform::transform(const void *vx, Nd4jLong xEws, - void *vz, Nd4jLong zEws, - const void *vscalar, - void *vextraParams, - const uint64_t len, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalar = reinterpret_cast(vscalar)[0]; - auto extraParams = reinterpret_cast(vextraParams); - - if (xEws == 1 && zEws == 1) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i] = OpType::op(x[i], scalar, extraParams); - } - else { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams); - } - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ScalarBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); +namespace scalar { + +template +template +void ScalarBoolTransform::transform( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, const void *vscalars, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalars = reinterpret_cast(vscalars); + auto extraParams = reinterpret_cast(vextraParams); + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeInfo; + zTadOffsets = xTadOffsets; + } + + // tad preparation + const int xTadEws = shape::elementWiseStride(xTadShapeInfo); + const int zTadEws = shape::elementWiseStride(zTadShapeInfo); + const int tadLength = + shape::tadLength(xShapeInfo, dimension, dimensionLength); + const int numTads = shape::length(xShapeInfo) / tadLength; + + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo); + + if (kindOfLoop != sd::LoopKind::EWS1 && + kindOfLoop != sd::LoopKind::EWSNONZERO) { + printf( + "ScalarBoolTransform::transform: super-bad loop visited. " + "Shouldn't ever happen\n"); + return; + } + + int num_threads = sd::math::nd4j_min( + numTads, sd::Environment::getInstance()->maxThreads()); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto r = start; r < stop; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], scalars[r], extraParams); + }; + } else { + for (auto r = start; r < stop; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) + oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); + }; + } +} +template +void ScalarBoolTransform::transform( + int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) { + DISPATCH_BY_OPNUM_TT( + transform, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, + zTadOffsets, start, stop), + SCALAR_BOOL_OPS); } + +template +void ScalarBoolTransform::transform(const int opNum, const void *x, + Nd4jLong xEws, void *z, Nd4jLong zEws, + const void *scalar, void *extraParams, + const uint64_t n, + const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_TT( + transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), + SCALAR_BOOL_OPS); } + +template +void ScalarBoolTransform::transform(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, + const void *scalar, void *extraParams, + const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_TT( + transform, + PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), + SCALAR_BOOL_OPS); +} + +template +template +void ScalarBoolTransform::transform( + const void *vx, const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, const void *vscalar, void *vextraParams, + const uint64_t start, const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + auto len = shape::length(xShapeInfo); + + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) { + transform(x, xEws, z, zEws, vscalar, extraParams, len, start, stop); + return; + } + + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], scalar, extraParams); + }; + } else { + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], scalar, extraParams); + }; + } +} + +template +template +void ScalarBoolTransform::transform( + const void *vx, Nd4jLong xEws, void *vz, Nd4jLong zEws, const void *vscalar, + void *vextraParams, const uint64_t len, const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + if (xEws == 1 && zEws == 1) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i] = OpType::op(x[i], scalar, extraParams); + } else { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) + z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams); + } +} + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ScalarBoolTransform, , + LIBND4J_TYPES, BOOL_TYPES); + +} // namespace scalar +} // namespace functions diff --git a/libnd4j/include/loops/cpu/scalar_int.cpp b/libnd4j/include/loops/cpu/scalar_int.cpp index 61f2e128cfb4..a3aebc05b314 100644 --- a/libnd4j/include/loops/cpu/scalar_int.cpp +++ b/libnd4j/include/loops/cpu/scalar_int.cpp @@ -19,185 +19,195 @@ // #include "../scalar_int.h" + +#include +#include #include #include -#include -#include #include "../legacy_ops.h" using namespace simdOps; namespace functions { - namespace scalar { - - - template - template - void ScalarIntTransform::transform(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalars, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalars = reinterpret_cast(vscalars); - auto extraParams = reinterpret_cast(vextraParams); - - if (zTadShapeInfo == nullptr) { - zTadShapeInfo = xTadShapeInfo; - zTadOffsets = xTadOffsets; - } - - // tad preparation - const int xTadEws = shape::elementWiseStride(xTadShapeInfo); - const int zTadEws = shape::elementWiseStride(zTadShapeInfo); - const int tadLength = shape::tadLength(xShapeInfo, dimension, dimensionLength); - const int numTads = shape::length(xShapeInfo) / tadLength; - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo); - - if (kindOfLoop != sd::LoopKind::EWS1 && kindOfLoop != sd::LoopKind::EWSNONZERO) { - printf("ScalarIntTransform::transform: super-bad loop visited. Shouldn't ever happen\n"); - return; - } - - int num_threads = sd::math::nd4j_min(numTads, sd::Environment::getInstance()->maxThreads()); - - if (kindOfLoop == sd::LoopKind::EWS1) { - for (auto r = start; r < stop; r++) { - auto oZ = z + zTadOffsets[r]; - auto oX = x + xTadOffsets[r]; - - PRAGMA_OMP_SIMD - for (int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(oX[f], scalars[r], extraParams); - }; - } - else { - for (auto r = start; r < stop; r++) { - auto oZ = z + zTadOffsets[r]; - auto oX = x + xTadOffsets[r]; - - PRAGMA_OMP_SIMD - for (int f = 0; f < tadLength; f++) - oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); - }; - } - } - - template - void ScalarIntTransform::transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, - const uint64_t start, const uint64_t stop) { - - DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_INT_OPS); - } - - - template - void ScalarIntTransform::transform(const int opNum, - const void *x, Nd4jLong xEws, - void *z, Nd4jLong zEws, - const void *scalar, - void *extraParams, - const uint64_t n, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), SCALAR_INT_OPS); - } - - template - void ScalarIntTransform::transform(const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalar, - void *extraParams, - const uint64_t start, const uint64_t stop) { - DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_INT_OPS); - } - - template - template - void ScalarIntTransform::transform(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalar, void *vextraParams, - const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalar = reinterpret_cast(vscalar)[0]; - auto extraParams = reinterpret_cast(vextraParams); - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - auto len = shape::length(xShapeInfo); - - sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); - - if (kindOfLoop == sd::LoopKind::EWS1 || kindOfLoop == sd::LoopKind::EWSNONZERO) { - transform(x, xEws, z, zEws, vscalar, extraParams, len, start, stop); - return; - } - - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpType::op(x[offset], scalar, extraParams); - }; - } - else { - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); - - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[zOffset] = OpType::op(x[xOffset], scalar, extraParams); - }; - } - } - - - template - template - void ScalarIntTransform::transform(const void *vx, Nd4jLong xEws, - void *vz, Nd4jLong zEws, - const void *vscalar, - void *vextraParams, - const uint64_t len, const uint64_t start, const uint64_t stop) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto scalar = reinterpret_cast(vscalar)[0]; - auto extraParams = reinterpret_cast(vextraParams); - - if (scalar < (sizeof(X) * 8)) { - if (xEws == 1 && zEws == 1) { - for (auto i = start; i < stop; i++) - z[i] = OpType::op(x[i], scalar, extraParams); - } else { - for (auto i = start; i < stop; i++) - z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams); - } - } - } - - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ScalarIntTransform, , INTEGER_TYPES); +namespace scalar { + +template +template +void ScalarIntTransform::transform( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, const void *vscalars, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalars = reinterpret_cast(vscalars); + auto extraParams = reinterpret_cast(vextraParams); + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeInfo; + zTadOffsets = xTadOffsets; + } + + // tad preparation + const int xTadEws = shape::elementWiseStride(xTadShapeInfo); + const int zTadEws = shape::elementWiseStride(zTadShapeInfo); + const int tadLength = + shape::tadLength(xShapeInfo, dimension, dimensionLength); + const int numTads = shape::length(xShapeInfo) / tadLength; + + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo); + + if (kindOfLoop != sd::LoopKind::EWS1 && + kindOfLoop != sd::LoopKind::EWSNONZERO) { + printf( + "ScalarIntTransform::transform: super-bad loop visited. Shouldn't " + "ever happen\n"); + return; + } + + int num_threads = sd::math::nd4j_min( + numTads, sd::Environment::getInstance()->maxThreads()); + + if (kindOfLoop == sd::LoopKind::EWS1) { + for (auto r = start; r < stop; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], scalars[r], extraParams); + }; + } else { + for (auto r = start; r < stop; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) + oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); + }; + } +} +template +void ScalarIntTransform::transform( + int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xTadOffsets, const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) { + DISPATCH_BY_OPNUM_T( + transform, + PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, + zTadOffsets, start, stop), + SCALAR_INT_OPS); } + +template +void ScalarIntTransform::transform(const int opNum, const void *x, + Nd4jLong xEws, void *z, Nd4jLong zEws, + const void *scalar, void *extraParams, + const uint64_t n, const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_T( + transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), + SCALAR_INT_OPS); } + +template +void ScalarIntTransform::transform(const int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, + const void *scalar, void *extraParams, + const uint64_t start, + const uint64_t stop) { + DISPATCH_BY_OPNUM_T( + transform, + PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), + SCALAR_INT_OPS); +} + +template +template +void ScalarIntTransform::transform(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + const void *vscalar, void *vextraParams, + const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + auto len = shape::length(xShapeInfo); + + sd::LoopKind::Kind kindOfLoop = + sd::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); + + if (kindOfLoop == sd::LoopKind::EWS1 || + kindOfLoop == sd::LoopKind::EWSNONZERO) { + transform(x, xEws, z, zEws, vscalar, extraParams, len, start, stop); + return; + } + + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpType::op(x[offset], scalar, extraParams); + }; + } else { + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = + sd::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + auto xOffset = + shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + auto zOffset = + shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[zOffset] = OpType::op(x[xOffset], scalar, extraParams); + }; + } +} + +template +template +void ScalarIntTransform::transform(const void *vx, Nd4jLong xEws, void *vz, + Nd4jLong zEws, const void *vscalar, + void *vextraParams, const uint64_t len, + const uint64_t start, + const uint64_t stop) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + if (scalar < (sizeof(X) * 8)) { + if (xEws == 1 && zEws == 1) { + for (auto i = start; i < stop; i++) + z[i] = OpType::op(x[i], scalar, extraParams); + } else { + for (auto i = start; i < stop; i++) + z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams); + } + } +} + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ScalarIntTransform, , + INTEGER_TYPES); + +} // namespace scalar +} // namespace functions diff --git a/libnd4j/include/loops/cpu/summarystatsreduce.cpp b/libnd4j/include/loops/cpu/summarystatsreduce.cpp index 65d44df8334d..aeab13b4d362 100644 --- a/libnd4j/include/loops/cpu/summarystatsreduce.cpp +++ b/libnd4j/include/loops/cpu/summarystatsreduce.cpp @@ -18,168 +18,180 @@ // Created by raver119 on 18.12.17. // -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include +#include using namespace simdOps; namespace functions { - namespace summarystats { - - - template - Y SummaryStatsReduce::execScalar(const int opNum, - const bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams) { - RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams), SUMMARY_STATS_OPS); +namespace summarystats { + +template +Y SummaryStatsReduce::execScalar(const int opNum, + const bool biasCorrected, const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams) { + RETURNING_DISPATCH_BY_OPNUM_TT( + execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams), + SUMMARY_STATS_OPS); +} + +template +void SummaryStatsReduce::execScalar(const int opNum, + const bool biasCorrected, + const void *x, + const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_TT( + execScalar, + PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo), + SUMMARY_STATS_OPS); +} + +template +void SummaryStatsReduce::exec(const int opNum, const bool biasCorrected, + const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength) { + DISPATCH_BY_OPNUM_TT(exec, + PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, + zShapeInfo, dimension, dimensionLength), + SUMMARY_STATS_OPS); +} + +template +template +void SummaryStatsReduce::execScalar(const bool biasCorrected, + const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo) { + auto z = reinterpret_cast(vz); + z[0] = execScalar(biasCorrected, vx, xShapeInfo, vextraParams); +} + +template +template +Z SummaryStatsReduce::execScalar(const bool biasCorrected, const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + + SummaryStatsData startingIndex; + startingIndex.initialize(); + auto length = shape::length(xShapeInfo); + + uint xShapeInfoCast[MAX_RANK]; + const bool canCast = + sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + for (Nd4jLong i = 0; i < length; i++) { + auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCast); + + SummaryStatsData curr; + curr.initWithValue(x[xOffset]); + startingIndex = update(startingIndex, curr, extraParams); + } + + return OpType::getValue(biasCorrected, startingIndex); +} + +template +template +void SummaryStatsReduce::exec(const bool biasCorrected, const void *vx, + const Nd4jLong *xShapeInfo, + void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto resultLength = shape::length(zShapeInfo); + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + SummaryStatsData comp; + comp.initWithValue(x[0]); + + for (Nd4jLong i = 0; i < resultLength; i++) + z[i] = OpType::getValue(biasCorrected, comp); + return; + } + + if (shape::isScalar(zShapeInfo)) { + z[0] = execScalar(biasCorrected, x, xShapeInfo, extraParams); + return; + } + + // no-op + if (dimensionLength < 1) return; + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dimension, dimensionLength); + + // pre squeezed: this is for keeping the pointer to the original + // shape information for tad offset + // the squeezed information doesn't render the right strides for + // tad offset + if (resultLength == 1 || dimensionLength == shape::rank(xShapeInfo) || + tadPack.numberOfTads() == 1) { + z[0] = execScalar(biasCorrected, x, xShapeInfo, extraParams); + return; + } + + auto tadShapeShapeInfo = tadPack.primaryShapeInfo(); + auto tadLength = shape::length(tadPack.primaryShapeInfo()); + auto tadEWS = shape::elementWiseStride(tadPack.primaryShapeInfo()); + auto tadOrder = shape::order(tadPack.primaryShapeInfo()); + + uint tadShapeShapeInfoCast[MAX_RANK]; + const bool canCast = tadEWS == 1 && tadOrder == 'c' + ? false + : sd::DataTypeUtils::castShapeInfo( + tadShapeShapeInfo, tadShapeShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) { + auto tadOffsetForBlock = tadPack.primaryOffsets()[r]; + auto tx = x + tadOffsetForBlock; + SummaryStatsData comp; + comp.initWithValue(tx[0]); + + if (tadEWS == 1 && tadOrder == 'c') { + for (Nd4jLong i = 1; i < tadLength; i++) { + SummaryStatsData indexVal2; + indexVal2.initWithValue(tx[i]); + + comp = update(comp, OpType::op(indexVal2, extraParams), extraParams); } + } else { + for (Nd4jLong i = 1; i < tadLength; i++) { + auto xOffset = shape::indexOffset(i, tadShapeShapeInfo, + tadShapeShapeInfoCast, canCast); - template - void SummaryStatsReduce::execScalar(const int opNum, - const bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo) { - DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo), SUMMARY_STATS_OPS); - } - - template - void SummaryStatsReduce::exec(const int opNum, - const bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength), SUMMARY_STATS_OPS); - } - - template - template - void SummaryStatsReduce::execScalar(const bool biasCorrected, - const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo) { - auto z = reinterpret_cast(vz); - z[0] = execScalar(biasCorrected, vx, xShapeInfo, vextraParams); - } - - template - template - Z SummaryStatsReduce::execScalar(const bool biasCorrected, const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); + SummaryStatsData indexVal2; + indexVal2.initWithValue(tx[xOffset]); - SummaryStatsData startingIndex; - startingIndex.initialize(); - auto length = shape::length(xShapeInfo); - - uint xShapeInfoCast[MAX_RANK]; - const bool canCast = sd::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - - for (Nd4jLong i = 0; i < length; i++) { - auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCast); - - SummaryStatsData curr; - curr.initWithValue(x[xOffset]); - startingIndex = update(startingIndex, curr, extraParams); - } - - return OpType::getValue(biasCorrected, startingIndex); + comp = update(comp, OpType::op(indexVal2, extraParams), extraParams); } + } - template - template - void SummaryStatsReduce::exec(const bool biasCorrected, - const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto resultLength = shape::length(zShapeInfo); - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - SummaryStatsData comp; - comp.initWithValue(x[0]); - - for (Nd4jLong i = 0; i < resultLength; i++) - z[i] = OpType::getValue(biasCorrected, comp); - return; - } - - if (shape::isScalar(zShapeInfo)) { - z[0] = execScalar(biasCorrected, x, xShapeInfo, extraParams); - return; - } - - //no-op - if (dimensionLength < 1) - return; - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - - //pre squeezed: this is for keeping the pointer to the original - //shape information for tad offset - //the squeezed information doesn't render the right strides for - //tad offset - if (resultLength == 1 || dimensionLength == shape::rank(xShapeInfo) || tadPack.numberOfTads() == 1) { - z[0] = execScalar(biasCorrected, x, xShapeInfo, extraParams); - return; - } - - auto tadShapeShapeInfo = tadPack.primaryShapeInfo(); - auto tadLength = shape::length(tadPack.primaryShapeInfo()); - auto tadEWS = shape::elementWiseStride(tadPack.primaryShapeInfo()); - auto tadOrder = shape::order(tadPack.primaryShapeInfo()); - - uint tadShapeShapeInfoCast[MAX_RANK]; - const bool canCast = tadEWS == 1 && tadOrder == 'c' ? false : sd::DataTypeUtils::castShapeInfo(tadShapeShapeInfo, tadShapeShapeInfoCast); - - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) { - - auto tadOffsetForBlock = tadPack.primaryOffsets()[r]; - auto tx = x + tadOffsetForBlock; - SummaryStatsData comp; - comp.initWithValue(tx[0]); - - if (tadEWS == 1 && tadOrder == 'c') { - for (Nd4jLong i = 1; i < tadLength; i++) { - SummaryStatsData indexVal2; - indexVal2.initWithValue(tx[i]); - - comp = update(comp, OpType::op(indexVal2, extraParams), extraParams); - } - } else { - for (Nd4jLong i = 1; i < tadLength; i++) { - auto xOffset = shape::indexOffset(i, tadShapeShapeInfo, tadShapeShapeInfoCast, canCast); - - SummaryStatsData indexVal2; - indexVal2.initWithValue(tx[xOffset]); - - comp = update(comp, OpType::op(indexVal2, extraParams), extraParams); - } - } - - z[r] = OpType::getValue(biasCorrected, comp); - } - }; - - samediff::Threads::parallel_tad(func, 0, resultLength, 1); - } + z[r] = OpType::getValue(biasCorrected, comp); + } + }; + samediff::Threads::parallel_tad(func, 0, resultLength, 1); +} - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT SummaryStatsReduce, , LIBND4J_TYPES, FLOAT_TYPES); - } -} \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT SummaryStatsReduce, , + LIBND4J_TYPES, FLOAT_TYPES); +} // namespace summarystats +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/transform/transform_any.cpp b/libnd4j/include/loops/cpu/transform/transform_any.cpp index f9fe819bdd81..2f2646d51637 100644 --- a/libnd4j/include/loops/cpu/transform/transform_any.cpp +++ b/libnd4j/include/loops/cpu/transform/transform_any.cpp @@ -18,43 +18,45 @@ // @author raver119@gmail.com // -#include #include -#include -#include #include +#include +#include +#include using namespace simdOps; namespace functions { - namespace transform { - - template - void TransformAny::exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_ANY_OPS); - } +namespace transform { + +template +void TransformAny::exec(int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams, + uint64_t threadId, uint64_t numThreads) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), + TRANSFORM_ANY_OPS); +} ///////////////////////////////////////////////////////////////////// template -template -void _CUDA_H TransformAny::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams, - uint64_t threadId, uint64_t numThreads) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - sd::TransformLoops::template loopTransform(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); +template +void _CUDA_H TransformAny::exec(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + void *vextraParams, uint64_t threadId, + uint64_t numThreads) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + sd::TransformLoops::template loopTransform( + x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); } - - -BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformAny, , LIBND4J_TYPES, LIBND4J_TYPES); -} -} \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformAny, , LIBND4J_TYPES, + LIBND4J_TYPES); +} // namespace transform +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/transform/transform_bool.cpp b/libnd4j/include/loops/cpu/transform/transform_bool.cpp index a9a132e5d93d..6d0170b3cbc7 100644 --- a/libnd4j/include/loops/cpu/transform/transform_bool.cpp +++ b/libnd4j/include/loops/cpu/transform/transform_bool.cpp @@ -18,40 +18,44 @@ // @author raver119@gmail.com // -#include #include -#include -#include #include +#include +#include +#include using namespace simdOps; namespace functions { - namespace transform { - - template - void TransformBool::exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_BOOL_OPS); - } - - template - template - void _CUDA_H TransformBool::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams, - uint64_t threadId, uint64_t numThreads) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - sd::TransformLoops::template loopTransform(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES); - } -} \ No newline at end of file +namespace transform { + +template +void TransformBool::exec(int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams, + uint64_t threadId, uint64_t numThreads) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), + TRANSFORM_BOOL_OPS); +} + +template +template +void _CUDA_H TransformBool::exec(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + void *vextraParams, uint64_t threadId, + uint64_t numThreads) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + sd::TransformLoops::template loopTransform( + x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); +} + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformBool, , LIBND4J_TYPES, + BOOL_TYPES); +} // namespace transform +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/transform/transform_float.cpp b/libnd4j/include/loops/cpu/transform/transform_float.cpp index bc93aca34ab1..bda579240de6 100644 --- a/libnd4j/include/loops/cpu/transform/transform_float.cpp +++ b/libnd4j/include/loops/cpu/transform/transform_float.cpp @@ -18,39 +18,43 @@ // @author raver119@gmail.com // -#include #include -#include -#include #include +#include +#include +#include using namespace simdOps; namespace functions { - namespace transform { - template - void TransformFloat::exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_FLOAT_OPS); - } - - template - template - void _CUDA_H TransformFloat::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams, - uint64_t threadId, uint64_t numThreads) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - sd::TransformLoops::template loopTransform(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES); - } -} \ No newline at end of file +namespace transform { +template +void TransformFloat::exec(int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams, + uint64_t threadId, uint64_t numThreads) { + DISPATCH_BY_OPNUM_TT( + exec, + PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), + TRANSFORM_FLOAT_OPS); +} + +template +template +void _CUDA_H TransformFloat::exec(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + void *vextraParams, uint64_t threadId, + uint64_t numThreads) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + sd::TransformLoops::template loopTransform( + x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); +} + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformFloat, , LIBND4J_TYPES, + FLOAT_TYPES); +} // namespace transform +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/transform/transform_same.cpp b/libnd4j/include/loops/cpu/transform/transform_same.cpp index d826e69e3ca0..4530d71b0c6f 100644 --- a/libnd4j/include/loops/cpu/transform/transform_same.cpp +++ b/libnd4j/include/loops/cpu/transform/transform_same.cpp @@ -18,41 +18,42 @@ // @author raver119@gmail.com // -#include #include -#include -#include #include +#include +#include +#include using namespace simdOps; namespace functions { - namespace transform { - - template - void TransformSame::exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads) { - DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_SAME_OPS); - } - - template - template - void _CUDA_H TransformSame::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams, - uint64_t threadId, uint64_t numThreads) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - - sd::TransformLoops::template loopTransform(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); - } - - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformSame, , LIBND4J_TYPES); - } -} \ No newline at end of file +namespace transform { + +template +void TransformSame::exec(int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams, + uint64_t threadId, uint64_t numThreads) { + DISPATCH_BY_OPNUM_T( + exec, + PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), + TRANSFORM_SAME_OPS); +} + +template +template +void _CUDA_H TransformSame::exec(const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams, uint64_t threadId, + uint64_t numThreads) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + sd::TransformLoops::template loopTransform( + x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); +} + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformSame, , LIBND4J_TYPES); +} // namespace transform +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/transform/transform_strict.cpp b/libnd4j/include/loops/cpu/transform/transform_strict.cpp index c59eb8da57c3..a5070f7a566e 100644 --- a/libnd4j/include/loops/cpu/transform/transform_strict.cpp +++ b/libnd4j/include/loops/cpu/transform/transform_strict.cpp @@ -18,41 +18,43 @@ // @author raver119@gmail.com // -#include #include -#include -#include #include +#include +#include +#include using namespace simdOps; namespace functions { - namespace transform { - - template - void TransformStrict::exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, +namespace transform { + +template +void TransformStrict::exec(int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams, + uint64_t threadId, uint64_t numThreads) { + DISPATCH_BY_OPNUM_T( + exec, + PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), + TRANSFORM_STRICT_OPS); +} + +template +template +void _CUDA_H TransformStrict::exec(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads) { - DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_STRICT_OPS); - } - - template - template - void _CUDA_H TransformStrict::exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams, - uint64_t threadId, uint64_t numThreads) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - sd::TransformLoops::template loopTransform(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); - } - - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformStrict, , FLOAT_TYPES); - } -} \ No newline at end of file + void *vextraParams, uint64_t threadId, + uint64_t numThreads) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + sd::TransformLoops::template loopTransform( + x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads); +} + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformStrict, , FLOAT_TYPES); +} // namespace transform +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/broadcasting.chpp b/libnd4j/include/loops/cuda/broadcasting.chpp index 0e9ba10614c2..146b7f930b8c 100644 --- a/libnd4j/include/loops/cuda/broadcasting.chpp +++ b/libnd4j/include/loops/cuda/broadcasting.chpp @@ -18,297 +18,336 @@ // @author raver119@gmail.com // -#include -#include -#include -#include -#include #include #include -#include -#include #include +#include +#include #include +#include +#include +#include + +#include +#include using namespace simdOps; -template +template static __global__ void broadcastSimple( - void const* x, - Nd4jLong const* xShapeInfo, - void const* y, - Nd4jLong const* yShapeInfo, - void *z, - Nd4jLong const* zShapeInfo, - int *dimension, - int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::broadcast::Broadcast::template transformCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); + void const* x, Nd4jLong const* xShapeInfo, void const* y, + Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + functions::broadcast::Broadcast::template transformCuda( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); } -template -static __global__ void broadcastSimple(const void const* x, const Nd4jLong const* xShapeInfo, - const void const* y, const Nd4jLong const* yShapeInfo, - void *z, const Nd4jLong const* zShapeInfo ) { +template +static __global__ void broadcastSimple(const void const* x, + const Nd4jLong const* xShapeInfo, + const void const* y, + const Nd4jLong const* yShapeInfo, + void* z, + const Nd4jLong const* zShapeInfo) { + functions::broadcast::Broadcast::template transformCuda( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); +} - functions::broadcast::Broadcast::template transformCuda(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); +template +static __global__ void broadcastInverseSimple( + void const* x, Nd4jLong const* xShapeInfo, void const* y, + Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + functions::broadcast::Broadcast::template transformInverseCuda< + OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); } +namespace functions { +namespace broadcast { -template -static __global__ void broadcastInverseSimple( - void const* x, - Nd4jLong const* xShapeInfo, - void const* y, - Nd4jLong const* yShapeInfo, - void *z, - Nd4jLong const* zShapeInfo, - int *dimension, - int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::broadcast::Broadcast::template transformInverseCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); +static Nd4jLong __device__ __noinline__ +getIndexOffset(Nd4jLong index, const Nd4jLong* shapeInfo) { + return shape::getIndexOffset(index, shapeInfo); +} + +static Nd4jLong __device__ __noinline__ length(const Nd4jLong* shapeInfo) { + return shape::length(shapeInfo); } +template +template +__host__ void Broadcast::intermediateBroadcast( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + broadcastSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); +} -namespace functions { - namespace broadcast { - - static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo) { - return shape::getIndexOffset(index, shapeInfo); - } - - static Nd4jLong __device__ __noinline__ length(const Nd4jLong *shapeInfo) { - return shape::length(shapeInfo); - } - - template - template - __host__ void Broadcast::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - broadcastSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); - } - - template - template - __host__ void Broadcast::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo) { - broadcastSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); - } - - template - __host__ void Broadcast::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_TTT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template - __host__ void Broadcast::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong const* zShapeInfo) { - DISPATCH_BY_OPNUM_TTT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), OPS_A(BROADCAST_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template - template - __host__ void Broadcast::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - broadcastInverseSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); - } - - template - __host__ void Broadcast::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_TTT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template - template - __device__ void Broadcast::transformInverseCuda( - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void* vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - if (tadOnlyShapeInfoZ == nullptr) { - tadOnlyShapeInfoZ = tadOnlyShapeInfo; - tadOffsetsZ = tadOffsets; - } - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - __shared__ Nd4jLong tadLength; - __shared__ Nd4jLong tadEWS; - __shared__ int numTads; - __shared__ Nd4jLong xEWS; - __shared__ Nd4jLong zEWS; - - - if (threadIdx.x == 0) { - tadLength = length(tadOnlyShapeInfo); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = length(yShapeInfo) / tadLength; - xEWS = shape::elementWiseStride(xShapeInfo); - zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); - } - __syncthreads(); - - auto xOrder = shape::order(xShapeInfo); - auto yOrder = shape::order(tadOnlyShapeInfo); - auto zOrder = shape::order(tadOnlyShapeInfoZ); - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - - auto rY = y + tadOffsets[r]; - auto rZ = z + tadOffsetsZ[r]; - - - if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1 && xOrder == yOrder && xOrder == zOrder) { - for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]); - } - else { - // it is expected that x and z tads and y array all have the same length - for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - auto xOffset = getIndexOffset(i, xShapeInfo); - auto yOffset = getIndexOffset(i, tadOnlyShapeInfo); - auto zOffset = getIndexOffset(i, tadOnlyShapeInfoZ); - rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]); - } - } - } - } - - - template - template - __device__ void Broadcast::transformCuda( - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - if (tadOnlyShapeInfoZ == nullptr) { - tadOnlyShapeInfoZ = tadOnlyShapeInfo; - tadOffsetsZ = tadOffsets; - } - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - __shared__ Nd4jLong tadLength; - __shared__ Nd4jLong tadEWS; - __shared__ int numTads; - __shared__ Nd4jLong yEWS; - __shared__ Nd4jLong zEWS; - - if (threadIdx.x == 0) { - tadLength = length(tadOnlyShapeInfo); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = length(xShapeInfo) / tadLength; - yEWS = shape::elementWiseStride(yShapeInfo); - zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); - } - __syncthreads(); - - auto xOrder = shape::order(tadOnlyShapeInfo); - auto yOrder = shape::order(yShapeInfo); - auto zOrder = shape::order(tadOnlyShapeInfoZ); - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - auto rX = x + tadOffsets[r]; - auto rZ = z + tadOffsetsZ[r]; - - - if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && xOrder == yOrder && xOrder == zOrder) { - for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]); - } - else { - // it is expected that x and z tads and y array all have the same length - for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - - auto xOffset = getIndexOffset(i, tadOnlyShapeInfo); - auto yOffset = getIndexOffset(i, yShapeInfo); - auto zOffset = getIndexOffset(i, tadOnlyShapeInfoZ); - rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]); - } - } - } - } +template +template +__host__ void Broadcast::intermediateBroadcast( + dim3 launchDims, cudaStream_t* stream, const void* x, + const Nd4jLong* xShapeInfo, const void* y, const Nd4jLong* yShapeInfo, + void* z, const Nd4jLong* zShapeInfo) { + broadcastSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); +} -//////////////////////////////////////////////////////////////////////// -template -template -__device__ void Broadcast::transformCuda( - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { +template +__host__ void Broadcast::execBroadcast( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_TTT( + intermediateBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + OPS_A(BROADCAST_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - const X* x = reinterpret_cast(vx); - const Y* y = reinterpret_cast(vy); - Z* z = reinterpret_cast(vz); +template +__host__ void Broadcast::execBroadcast( + dim3 launchDims, cudaStream_t* stream, const int opNum, const void* x, + const Nd4jLong* xShapeInfo, const void* y, const Nd4jLong* yShapeInfo, + void* z, const Nd4jLong const* zShapeInfo) { + DISPATCH_BY_OPNUM_TTT( + intermediateBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), + OPS_A(BROADCAST_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - __shared__ Nd4jLong zLen; - __shared__ int rank; - __shared__ bool xzSameOffsets, yzSameOffsets; +template +template +__host__ void Broadcast::intermediateInverseBroadcast( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + broadcastInverseSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); +} - if (threadIdx.x == 0) { +template +__host__ void Broadcast::execInverseBroadcast( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_TTT( + intermediateInverseBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + OPS_A(BROADCAST_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - zLen = shape::length(zShapeInfo); - rank = shape::rank(zShapeInfo); +template +template +__device__ void Broadcast::transformInverseCuda( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong xEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = length(tadOnlyShapeInfo); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = length(yShapeInfo) / tadLength; + xEWS = shape::elementWiseStride(xShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + auto xOrder = shape::order(xShapeInfo); + auto yOrder = shape::order(tadOnlyShapeInfo); + auto zOrder = shape::order(tadOnlyShapeInfoZ); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto rY = y + tadOffsets[r]; + auto rZ = z + tadOffsetsZ[r]; + + if (tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1 && + xOrder == yOrder && xOrder == zOrder) { + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) + rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]); + } else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = getIndexOffset(i, xShapeInfo); + auto yOffset = getIndexOffset(i, tadOnlyShapeInfo); + auto zOffset = getIndexOffset(i, tadOnlyShapeInfoZ); + rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]); + } + } + } +} - xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); +template +template +__device__ void Broadcast::transformCuda( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong yEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = length(tadOnlyShapeInfo); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = length(xShapeInfo) / tadLength; + yEWS = shape::elementWiseStride(yShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + auto xOrder = shape::order(tadOnlyShapeInfo); + auto yOrder = shape::order(yShapeInfo); + auto zOrder = shape::order(tadOnlyShapeInfoZ); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto rX = x + tadOffsets[r]; + auto rZ = z + tadOffsetsZ[r]; + + if (tadEWS > 0 && zEWS > 0 && yEWS > 0 && xOrder == yOrder && + xOrder == zOrder) { + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) + rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]); + } else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = getIndexOffset(i, tadOnlyShapeInfo); + auto yOffset = getIndexOffset(i, yShapeInfo); + auto zOffset = getIndexOffset(i, tadOnlyShapeInfoZ); + rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]); + } } - __syncthreads(); + } +} +//////////////////////////////////////////////////////////////////////// +template +template +__device__ void Broadcast::transformCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) { + const X* x = reinterpret_cast(vx); + const Y* y = reinterpret_cast(vy); + Z* z = reinterpret_cast(vz); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + __shared__ Nd4jLong zLen; + __shared__ int rank; + __shared__ bool xzSameOffsets, yzSameOffsets; - int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + if (threadIdx.x == 0) { + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); - for (int i = tid; i < zLen; i += blockDim.x * gridDim.x) { + xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + } + __syncthreads(); - shape::index2coords(i, zShapeInfo, zCoords); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (uint j = 0; j < rank; ++j) { - xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; - yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; - } + int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); - const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + for (int i = tid; i < zLen; i += blockDim.x * gridDim.x) { + shape::index2coords(i, zShapeInfo, zCoords); - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; } + + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto xOffset = + xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = + yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + } } /* - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_0); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_1); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_2); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_3); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_4); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_5); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_6); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_7); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_8); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_9); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , + PAIRWISE_TYPES_0); BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT + Broadcast, , PAIRWISE_TYPES_1); BUILD_PAIRWISE_TEMPLATE(template class + SD_EXPORT Broadcast, , PAIRWISE_TYPES_2); BUILD_PAIRWISE_TEMPLATE(template + class SD_EXPORT Broadcast, , PAIRWISE_TYPES_3); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , + PAIRWISE_TYPES_4); BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT + Broadcast, , PAIRWISE_TYPES_5); BUILD_PAIRWISE_TEMPLATE(template class + SD_EXPORT Broadcast, , PAIRWISE_TYPES_6); BUILD_PAIRWISE_TEMPLATE(template + class SD_EXPORT Broadcast, , PAIRWISE_TYPES_7); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , + PAIRWISE_TYPES_8); BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT + Broadcast, , PAIRWISE_TYPES_9); */ - } -} \ No newline at end of file +} // namespace broadcast +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/broadcasting.cu b/libnd4j/include/loops/cuda/broadcasting.cu index 55c882c3f81d..dc40f2b3b7b3 100644 --- a/libnd4j/include/loops/cuda/broadcasting.cu +++ b/libnd4j/include/loops/cuda/broadcasting.cu @@ -18,20 +18,19 @@ // @author raver119@gmail.com // -#include -#include -#include -#include -#include #include #include -#include -#include #include +#include +#include #include +#include +#include +#include -namespace functions { - namespace broadcast { +#include +#include - } -} \ No newline at end of file +namespace functions { +namespace broadcast {} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/broadcasting_bool.cu b/libnd4j/include/loops/cuda/broadcasting_bool.cu index f547b1b594df..69c1352686bb 100644 --- a/libnd4j/include/loops/cuda/broadcasting_bool.cu +++ b/libnd4j/include/loops/cuda/broadcasting_bool.cu @@ -18,303 +18,335 @@ // @author raver119@gmail.com // -#include +#include +#include +#include #include #include -#include #include -#include -#include -#include +#include +#include + #include -#include +#include using namespace simdOps; ////////////////////////////////////////////////////////////////////////// -template +template static __global__ void broadcastBoolSimple( - void const* x, - Nd4jLong const* xShapeInfo, - void const* y, - Nd4jLong const* yShapeInfo, - void *z, - Nd4jLong const* zShapeInfo, - void *extraParams, - int *dimension, - int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::broadcast::BroadcastBool::template transformCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo, extraParams, dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); + void const* x, Nd4jLong const* xShapeInfo, void const* y, + Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, + void* extraParams, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + functions::broadcast::BroadcastBool::template transformCuda( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); } ////////////////////////////////////////////////////////////////////////// -template -static __global__ void broadcastBoolSimple(const void const* x, const Nd4jLong const* xShapeInfo, - const void const* y, const Nd4jLong const* yShapeInfo, - void *z, const Nd4jLong const* zShapeInfo, - void *extraParams) { - - functions::broadcast::BroadcastBool::template transformCuda(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); +template +static __global__ void broadcastBoolSimple( + const void const* x, const Nd4jLong const* xShapeInfo, const void const* y, + const Nd4jLong const* yShapeInfo, void* z, const Nd4jLong const* zShapeInfo, + void* extraParams) { + functions::broadcast::BroadcastBool::template transformCuda( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); } ////////////////////////////////////////////////////////////////////////// -template +template static __global__ void broadcastBoolInverseSimple( - void const* x, - Nd4jLong const* xShapeInfo, - void const* y, - Nd4jLong const* yShapeInfo, - void *z, - Nd4jLong const* zShapeInfo, - void *extraParams, - int *dimension, - int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::broadcast::BroadcastBool::template transformInverseCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,extraParams,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); + void const* x, Nd4jLong const* xShapeInfo, void const* y, + Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, + void* extraParams, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + functions::broadcast::BroadcastBool::template transformInverseCuda< + OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ); } namespace functions { namespace broadcast { ////////////////////////////////////////////////////////////////////////// -template +template template -__host__ void BroadcastBool::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - broadcastBoolSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); - sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed"); +__host__ void BroadcastBool::intermediateBroadcast( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, void* extraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + broadcastBoolSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); + sd::DebugHelper::checkErrorCode(stream, + "intermediateBroadcastBool(...) failed"); } ////////////////////////////////////////////////////////////////////////// -template +template template -__host__ void BroadcastBool::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams) { - - broadcastBoolSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); - sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed"); +__host__ void BroadcastBool::intermediateBroadcast( + dim3 launchDims, cudaStream_t* stream, const void* x, + const Nd4jLong* xShapeInfo, const void* y, const Nd4jLong* yShapeInfo, + void* z, const Nd4jLong* zShapeInfo, void* extraParams) { + broadcastBoolSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); + sd::DebugHelper::checkErrorCode(stream, + "intermediateBroadcastBool(...) failed"); } ////////////////////////////////////////////////////////////////////////// -template -__host__ void BroadcastBool::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) - DEBUG_KERNEL(stream, opNum); +template +__host__ void BroadcastBool::execBroadcast( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, void* extraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_TT( + intermediateBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams, dimension, dimensionLength, tadOnlyShapeInfo, + tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), + OPS_A(BROADCAST_BOOL_OPS)) + DEBUG_KERNEL(stream, opNum); } ////////////////////////////////////////////////////////////////////////// -template -__host__ void BroadcastBool::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams) { - - DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams), OPS_A(BROADCAST_BOOL_OPS)) - DEBUG_KERNEL(stream, opNum); +template +__host__ void BroadcastBool::execBroadcast( + dim3 launchDims, cudaStream_t* stream, const int opNum, const void* x, + const Nd4jLong* xShapeInfo, const void* y, const Nd4jLong* yShapeInfo, + void* z, const Nd4jLong* zShapeInfo, void* extraParams) { + DISPATCH_BY_OPNUM_TT(intermediateBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, + z, zShapeInfo, extraParams), + OPS_A(BROADCAST_BOOL_OPS)) + DEBUG_KERNEL(stream, opNum); } ////////////////////////////////////////////////////////////////////////// - template - template - __host__ void BroadcastBool::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - broadcastBoolInverseSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); - sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed"); - } - -////////////////////////////////////////////////////////////////////////// - template - __host__ void BroadcastBool::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_TT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) - - DEBUG_KERNEL(stream, opNum); - } +template +template +__host__ void BroadcastBool::intermediateInverseBroadcast( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, void* extraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + broadcastBoolInverseSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); + sd::DebugHelper::checkErrorCode(stream, + "intermediateBroadcastBool(...) failed"); +} ////////////////////////////////////////////////////////////////////////// - template - template - __device__ void BroadcastBool::transformInverseCuda( - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - if (tadOnlyShapeInfoZ == nullptr) { - tadOnlyShapeInfoZ = tadOnlyShapeInfo; - tadOffsetsZ = tadOffsets; - } - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - __shared__ Nd4jLong tadLength; - __shared__ Nd4jLong tadEWS; - __shared__ int numTads; - __shared__ Nd4jLong xEWS; - __shared__ Nd4jLong zEWS; - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(yShapeInfo) / tadLength; - xEWS = shape::elementWiseStride(xShapeInfo); - zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); - } - __syncthreads(); - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - auto rZ = z + tadOffsetsZ[r]; - auto rY = y + tadOffsets[r]; - - if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) { - - for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS], extraParams); - } - else { - // it is expected that x and z tads and y array all have the same length - for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); - auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); - - rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset], extraParams); - } - } - } - } +template +__host__ void BroadcastBool::execInverseBroadcast( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, void* extraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_TT( + intermediateInverseBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + extraParams, dimension, dimensionLength, tadOnlyShapeInfo, + tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), + OPS_A(BROADCAST_BOOL_OPS)) + + DEBUG_KERNEL(stream, opNum); +} ////////////////////////////////////////////////////////////////////////// - template - template - __device__ void BroadcastBool::transformCuda( - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - if (tadOnlyShapeInfoZ == nullptr) { - tadOnlyShapeInfoZ = tadOnlyShapeInfo; - tadOffsetsZ = tadOffsets; - } - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - __shared__ Nd4jLong tadLength; - __shared__ Nd4jLong tadEWS; - __shared__ int numTads; - __shared__ Nd4jLong yEWS; - __shared__ Nd4jLong zEWS; - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; - yEWS = shape::elementWiseStride(yShapeInfo); - zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); - } - __syncthreads(); - - __shared__ Z *rZ; - __shared__ X const* rX; - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - if (threadIdx.x == 0) { - rZ = z + tadOffsetsZ[r]; - rX = x + tadOffsets[r]; - } - __syncthreads(); - - - if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) { - - for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS], extraParams); - } - else { - // it is expected that x and z tads and y array all have the same length - for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); - - rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset], extraParams); - } - } - } - } +template +template +__device__ void BroadcastBool::transformInverseCuda( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void* vextraParams, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong xEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(yShapeInfo) / tadLength; + xEWS = shape::elementWiseStride(xShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto rZ = z + tadOffsetsZ[r]; + auto rY = y + tadOffsets[r]; + + if (tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) { + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) + rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS], extraParams); + } else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); + auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); + + rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset], extraParams); + } + } + } +} ////////////////////////////////////////////////////////////////////////// -template +template template -__device__ void BroadcastBool::transformCuda(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams) { - - const X* x = reinterpret_cast(vx); - const X* y = reinterpret_cast(vy); - Z* z = reinterpret_cast(vz); - - auto extraParams = reinterpret_cast(vextraParams); - - __shared__ Nd4jLong zLen; - __shared__ int rank; - __shared__ bool xzSameOffsets, yzSameOffsets; - +__device__ void BroadcastBool::transformCuda( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void* vextraParams, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong yEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + yEWS = shape::elementWiseStride(yShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + __shared__ Z* rZ; + __shared__ X const* rX; + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { if (threadIdx.x == 0) { - - zLen = shape::length(zShapeInfo); - rank = shape::rank(zShapeInfo); - - xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + rZ = z + tadOffsetsZ[r]; + rX = x + tadOffsets[r]; } __syncthreads(); + if (tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) { + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) + rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS], extraParams); + } else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); + + rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset], extraParams); + } + } + } +} - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; +////////////////////////////////////////////////////////////////////////// +template +template +__device__ void BroadcastBool::transformCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vextraParams) { + const X* x = reinterpret_cast(vx); + const X* y = reinterpret_cast(vy); + Z* z = reinterpret_cast(vz); - int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + auto extraParams = reinterpret_cast(vextraParams); - for (int i = tid; i < zLen; i += blockDim.x * gridDim.x) { + __shared__ Nd4jLong zLen; + __shared__ int rank; + __shared__ bool xzSameOffsets, yzSameOffsets; - shape::index2coords(i, zShapeInfo, zCoords); + if (threadIdx.x == 0) { + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); - for (uint j = 0; j < rank; ++j) { - xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; - yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; - } + xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + } + __syncthreads(); - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); - const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + + for (int i = tid; i < zLen; i += blockDim.x * gridDim.x) { + shape::index2coords(i, zShapeInfo, zCoords); + + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; } -} + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto xOffset = + xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = + yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); -BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES); + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } } -} \ No newline at end of file + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES, + BOOL_TYPES); +} // namespace broadcast +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/broadcasting_int.cu b/libnd4j/include/loops/cuda/broadcasting_int.cu index 52e26bd43b55..452cdfc654ba 100644 --- a/libnd4j/include/loops/cuda/broadcasting_int.cu +++ b/libnd4j/include/loops/cuda/broadcasting_int.cu @@ -18,283 +18,317 @@ // @author raver119@gmail.com // -#include +#include +#include +#include #include #include -#include #include -#include -#include -#include +#include +#include + #include -#include +#include using namespace simdOps; ////////////////////////////////////////////////////////////////////////// -template +template static __global__ void broadcastIntSimple( - void const* x, - Nd4jLong const* xShapeInfo, - void const* y, - Nd4jLong const* yShapeInfo, - void *z, - Nd4jLong const* zShapeInfo, - int *dimension, - int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::broadcast::BroadcastInt::template transformCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); + void const* x, Nd4jLong const* xShapeInfo, void const* y, + Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + functions::broadcast::BroadcastInt::template transformCuda( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); } ////////////////////////////////////////////////////////////////////////// -template -static __global__ void broadcastIntSimple(const void *x, const Nd4jLong const* xShapeInfo, - const void *y, const Nd4jLong const* yShapeInfo, - void *z, const Nd4jLong const* zShapeInfo) { - - functions::broadcast::BroadcastInt::template transformCuda(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); +template +static __global__ void broadcastIntSimple(const void* x, + const Nd4jLong const* xShapeInfo, + const void* y, + const Nd4jLong const* yShapeInfo, + void* z, + const Nd4jLong const* zShapeInfo) { + functions::broadcast::BroadcastInt::template transformCuda( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); } ////////////////////////////////////////////////////////////////////////// -template +template static __global__ void broadcastBoolInverseSimple( - void const* x, - Nd4jLong const* xShapeInfo, - void const* y, - Nd4jLong const* yShapeInfo, - void *z, - Nd4jLong const* zShapeInfo, - int *dimension, - int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::broadcast::BroadcastInt::template transformInverseCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); + void const* x, Nd4jLong const* xShapeInfo, void const* y, + Nd4jLong const* yShapeInfo, void* z, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + functions::broadcast::BroadcastInt::template transformInverseCuda( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); } namespace functions { namespace broadcast { ////////////////////////////////////////////////////////////////////////// -template +template template -__host__ void BroadcastInt::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - broadcastIntSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); +__host__ void BroadcastInt::intermediateBroadcast( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + broadcastIntSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); } ////////////////////////////////////////////////////////////////////////// -template +template template -__host__ void BroadcastInt::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo) { - - broadcastIntSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); +__host__ void BroadcastInt::intermediateBroadcast( + dim3 launchDims, cudaStream_t* stream, const void* x, + const Nd4jLong* xShapeInfo, const void* y, const Nd4jLong* yShapeInfo, + void* z, const Nd4jLong* zShapeInfo) { + broadcastIntSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); } ////////////////////////////////////////////////////////////////////////// -template -__host__ void BroadcastInt::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS)) +template +__host__ void BroadcastInt::execBroadcast( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_T( + intermediateBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + OPS_A(BROADCAST_INT_OPS)) } ////////////////////////////////////////////////////////////////////////// -template -__host__ void BroadcastInt::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, - const void *x, const Nd4jLong const* xShapeInfo, - const void *y, const Nd4jLong const* yShapeInfo, - void *z, const Nd4jLong const* zShapeInfo) { - - DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), OPS_A(BROADCAST_INT_OPS)) +template +__host__ void BroadcastInt::execBroadcast( + dim3 launchDims, cudaStream_t* stream, const int opNum, const void* x, + const Nd4jLong const* xShapeInfo, const void* y, + const Nd4jLong const* yShapeInfo, void* z, + const Nd4jLong const* zShapeInfo) { + DISPATCH_BY_OPNUM_T( + intermediateBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), + OPS_A(BROADCAST_INT_OPS)) } ////////////////////////////////////////////////////////////////////////// - template - template - __host__ void BroadcastInt::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - broadcastBoolInverseSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); - } - -////////////////////////////////////////////////////////////////////////// - template - __host__ void BroadcastInt::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_T(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS)) - } +template +template +__host__ void BroadcastInt::intermediateInverseBroadcast( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + broadcastBoolInverseSimple + <<>>( + x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); +} ////////////////////////////////////////////////////////////////////////// - template - template - __device__ void BroadcastInt::transformInverseCuda( - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - if (tadOnlyShapeInfoZ == nullptr) { - tadOnlyShapeInfoZ = tadOnlyShapeInfo; - tadOffsetsZ = tadOffsets; - } - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - __shared__ Nd4jLong tadLength; - __shared__ Nd4jLong tadEWS; - __shared__ int numTads; - __shared__ Nd4jLong xEWS; - __shared__ Nd4jLong zEWS; - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(yShapeInfo) / tadLength; - xEWS = shape::elementWiseStride(xShapeInfo); - zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); - } - __syncthreads(); - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - auto rZ = z + tadOffsetsZ[r]; - auto rY = y + tadOffsets[r]; - - if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) { - - for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]); - } - else { - // it is expected that x and z tads and y array all have the same length - for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); - auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); - - rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]); - } - } - } - } +template +__host__ void BroadcastInt::execInverseBroadcast( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, void const* y, Nd4jLong const* yShapeInfo, + void* z, Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_T( + intermediateInverseBroadcast, + PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + tadOnlyShapeInfoZ, tadOffsetsZ), + OPS_A(BROADCAST_INT_OPS)) +} ////////////////////////////////////////////////////////////////////////// - template - template - __device__ void BroadcastInt::transformCuda( - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - if (tadOnlyShapeInfoZ == nullptr) { - tadOnlyShapeInfoZ = tadOnlyShapeInfo; - tadOffsetsZ = tadOffsets; - } - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the x shape info for setting up the tad problem - __shared__ Nd4jLong tadLength; - __shared__ Nd4jLong tadEWS; - __shared__ int numTads; - __shared__ Nd4jLong yEWS; - __shared__ Nd4jLong zEWS; - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; - yEWS = shape::elementWiseStride(yShapeInfo); - zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); - } - __syncthreads(); - - __shared__ X *rZ; - __shared__ X const* rX; - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - if (threadIdx.x == 0) { - rZ = z + tadOffsetsZ[r]; - rX = x + tadOffsets[r]; - } - __syncthreads(); - - - if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) { - - for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]); - } - else { - // it is expected that x and z tads and y array all have the same length - for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); - - rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]); - } - } - } - } +template +template +__device__ void BroadcastInt::transformInverseCuda( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong xEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(yShapeInfo) / tadLength; + xEWS = shape::elementWiseStride(xShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto rZ = z + tadOffsetsZ[r]; + auto rY = y + tadOffsets[r]; + + if (tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) { + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) + rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]); + } else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); + auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); + + rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]); + } + } + } +} ////////////////////////////////////////////////////////////////////////// -template +template template -__device__ void BroadcastInt::transformCuda(const void *vx, const Nd4jLong const* xShapeInfo, - const void *vy, const Nd4jLong const* yShapeInfo, - void *vz, const Nd4jLong const* zShapeInfo) { - - const X* x = reinterpret_cast(vx); - const X* y = reinterpret_cast(vy); - X* z = reinterpret_cast(vz); - - __shared__ Nd4jLong zLen; - __shared__ int rank; - __shared__ bool xzSameOffsets, yzSameOffsets; - +__device__ void BroadcastInt::transformCuda( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadOnlyShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong yEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + yEWS = shape::elementWiseStride(yShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + __shared__ X* rZ; + __shared__ X const* rX; + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { if (threadIdx.x == 0) { - - zLen = shape::length(zShapeInfo); - rank = shape::rank(zShapeInfo); - - xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + rZ = z + tadOffsetsZ[r]; + rX = x + tadOffsets[r]; } __syncthreads(); + if (tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) { + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) + rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]); + } else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); + + rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]); + } + } + } +} + +////////////////////////////////////////////////////////////////////////// +template +template +__device__ void BroadcastInt::transformCuda( + const void* vx, const Nd4jLong const* xShapeInfo, const void* vy, + const Nd4jLong const* yShapeInfo, void* vz, + const Nd4jLong const* zShapeInfo) { + const X* x = reinterpret_cast(vx); + const X* y = reinterpret_cast(vy); + X* z = reinterpret_cast(vz); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + __shared__ Nd4jLong zLen; + __shared__ int rank; + __shared__ bool xzSameOffsets, yzSameOffsets; - int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + if (threadIdx.x == 0) { + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); - for (int i = tid; i < zLen; i += blockDim.x * gridDim.x) { + xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + } + __syncthreads(); - shape::index2coords(i, zShapeInfo, zCoords); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (uint j = 0; j < rank; ++j) { - xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; - yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; - } + int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); - const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + for (int i = tid; i < zLen; i += blockDim.x * gridDim.x) { + shape::index2coords(i, zShapeInfo, zCoords); - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; } -} + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto xOffset = + xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = + yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); -BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES); + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + } } -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES); +} // namespace broadcast +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_0.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_0.cu index ffc274e1b7e6..bb69aba4f7d1 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_0.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_0.cu @@ -21,7 +21,7 @@ #include "../../broadcasting.chpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_0); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_1.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_1.cu index 37c4edefce73..ef150042dd8e 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_1.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_1.cu @@ -21,7 +21,7 @@ #include "../../broadcasting.chpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_1); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_10.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_10.cu index 1400b7289092..9aeddff516b3 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_10.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_10.cu @@ -21,7 +21,8 @@ #include "../../broadcasting.chpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_10); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , + PAIRWISE_TYPES_10); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_11.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_11.cu index 4cfd95934238..92a091db9185 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_11.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_11.cu @@ -21,7 +21,8 @@ #include "../../broadcasting.chpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_11); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , + PAIRWISE_TYPES_11); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_12.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_12.cu index 9600cd9f8452..cd4c47c6fe05 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_12.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_12.cu @@ -21,7 +21,8 @@ #include "../../broadcasting.chpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_12); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , + PAIRWISE_TYPES_12); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_2.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_2.cu index 92112eb5b30b..95a3653ae77a 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_2.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_2.cu @@ -21,7 +21,7 @@ #include "../../broadcasting.chpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_2); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_3.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_3.cu index b72bd706a4bf..6ab94f2b1154 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_3.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_3.cu @@ -21,7 +21,7 @@ #include "../../broadcasting.chpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_3); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_4.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_4.cu index b592b874e081..6f4d11337d85 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_4.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_4.cu @@ -21,7 +21,7 @@ #include "../../broadcasting.chpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_4); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_4); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_5.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_5.cu index c66438f5c70d..7542855457e7 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_5.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_5.cu @@ -21,7 +21,7 @@ #include "../../broadcasting.chpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_5); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_5); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_6.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_6.cu index 5381d45f439d..49d1b89262a0 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_6.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_6.cu @@ -21,7 +21,7 @@ #include "../../broadcasting.chpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_6); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_6); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_7.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_7.cu index d917b7c0fe6d..d7f55b4b4360 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_7.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_7.cu @@ -21,7 +21,7 @@ #include "../../broadcasting.chpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_7); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_7); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_8.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_8.cu index b24f16bc6746..2f0d832474c1 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_8.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_8.cu @@ -21,7 +21,7 @@ #include "../../broadcasting.chpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_8); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_8); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_9.cu b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_9.cu index 48bc66b120d0..974a21f469fc 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_9.cu +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting/broadcasting_9.cu @@ -21,7 +21,7 @@ #include "../../broadcasting.chpp" namespace functions { - namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_9); - } -} \ No newline at end of file +namespace broadcast { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_9); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_0.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_0.cu index 21c6550e630c..8579df2dd285 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_0.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_0.cu @@ -21,7 +21,8 @@ #include "../../pairwise.chpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_0); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_1.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_1.cu index 729b612d349b..5502f409b34c 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_1.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_1.cu @@ -21,7 +21,8 @@ #include "../../pairwise.chpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_1); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_10.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_10.cu index 01b197a007c9..bd2846307aa8 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_10.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_10.cu @@ -21,7 +21,8 @@ #include "../../pairwise.chpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_10); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_10); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_11.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_11.cu index e552367e74d2..323442d4b913 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_11.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_11.cu @@ -21,7 +21,8 @@ #include "../../pairwise.chpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_11); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_11); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_12.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_12.cu index 6c3176ee4993..cb11391f4bce 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_12.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_12.cu @@ -21,7 +21,8 @@ #include "../../pairwise.chpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_12); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_12); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_2.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_2.cu index f0e43a85ff62..c6a2ffb6e98e 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_2.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_2.cu @@ -21,7 +21,8 @@ #include "../../pairwise.chpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_2); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_3.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_3.cu index 62de38e6c3d9..d8a91c4123db 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_3.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_3.cu @@ -21,7 +21,8 @@ #include "../../pairwise.chpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_3); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_4.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_4.cu index e77c3934ffb1..26c23badc4bd 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_4.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_4.cu @@ -21,7 +21,8 @@ #include "../../pairwise.chpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_4); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_4); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_5.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_5.cu index 0e312afc2b92..b90bd1fcc70c 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_5.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_5.cu @@ -21,7 +21,8 @@ #include "../../pairwise.chpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_5); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_5); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_6.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_6.cu index ce15e77ca6c0..31ec8c274433 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_6.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_6.cu @@ -21,7 +21,8 @@ #include "../../pairwise.chpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_6); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_6); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_7.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_7.cu index 1d9572fe6c08..bc05163a4378 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_7.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_7.cu @@ -21,7 +21,8 @@ #include "../../pairwise.chpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_7); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_7); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_8.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_8.cu index 2df8be1d1819..0ee9d447ec0f 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_8.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_8.cu @@ -21,7 +21,8 @@ #include "../../pairwise.chpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_8); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_8); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_9.cu b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_9.cu index 235d4b6aae91..b505f571b2a6 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_9.cu +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise/pairwise_9.cu @@ -21,7 +21,8 @@ #include "../../pairwise.chpp" namespace functions { - namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_9); - } -} \ No newline at end of file +namespace pairwise_transforms { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_9); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_0.cu b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_0.cu index 5208ece7453c..3714e97f51a4 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_0.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_0.cu @@ -21,7 +21,8 @@ #include "../../reduce3.chpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_0); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, + FLOAT_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_1.cu b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_1.cu index f9f242ef584e..a8992a89f42a 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_1.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_1.cu @@ -21,7 +21,8 @@ #include "../../reduce3.chpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_1); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, + FLOAT_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_2.cu b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_2.cu index 4360574a31b1..51c9d0486988 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_2.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_2.cu @@ -21,7 +21,8 @@ #include "../../reduce3.chpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_2); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, + FLOAT_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_3.cu b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_3.cu index ac2f43d60a32..43497bd10bf2 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_3.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce3/reduce3_3.cu @@ -21,7 +21,8 @@ #include "../../reduce3.chpp" namespace functions { - namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_3); - } -} \ No newline at end of file +namespace reduce3 { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, + FLOAT_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_0.cu b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_0.cu index f32d9f16a2fe..85c7a3345f40 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_0.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_0.cu @@ -21,7 +21,8 @@ #include "../../reduce/reduce_float.chpp" namespace functions { - namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_0); - } -} \ No newline at end of file +namespace reduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , + LIBND4J_TYPES, FLOAT_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_1.cu b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_1.cu index c94c45fdc489..06427db6cb7f 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_1.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_1.cu @@ -21,7 +21,8 @@ #include "../../reduce/reduce_float.chpp" namespace functions { - namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_1); - } -} \ No newline at end of file +namespace reduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , + LIBND4J_TYPES, FLOAT_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_2.cu b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_2.cu index 8c3b2325a90a..f8514ec12ce3 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_2.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_2.cu @@ -21,7 +21,8 @@ #include "../../reduce/reduce_float.chpp" namespace functions { - namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_2); - } -} \ No newline at end of file +namespace reduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , + LIBND4J_TYPES, FLOAT_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_3.cu b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_3.cu index af745591afef..14aec33a59bf 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_3.cu +++ b/libnd4j/include/loops/cuda/compilation_units/reduce_float/reduce_float_3.cu @@ -21,7 +21,8 @@ #include "../../reduce/reduce_float.chpp" namespace functions { - namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_3); - } -} \ No newline at end of file +namespace reduce { +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , + LIBND4J_TYPES, FLOAT_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_0.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_0.cu index 149c85487836..fa3f62441571 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_0.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_0.cu @@ -21,7 +21,8 @@ #include "../../scalar.chpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_0); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_0); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_1.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_1.cu index a088f94a2bd0..af306e8b1596 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_1.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_1.cu @@ -21,7 +21,8 @@ #include "../../scalar.chpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_1); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_1); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_10.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_10.cu index fded63e8c46a..bfcd54d709d7 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_10.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_10.cu @@ -21,7 +21,8 @@ #include "../../scalar.chpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_10); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_10); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_11.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_11.cu index 5506d1708df4..d151fe014657 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_11.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_11.cu @@ -21,7 +21,8 @@ #include "../../scalar.chpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_11); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_11); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_12.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_12.cu index ce44f409086a..257f8571be48 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_12.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_12.cu @@ -21,7 +21,8 @@ #include "../../scalar.chpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_12); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_12); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_2.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_2.cu index 0711ce550d1e..81d3ad86cbc4 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_2.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_2.cu @@ -21,7 +21,8 @@ #include "../../scalar.chpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_2); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_2); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_3.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_3.cu index 64f803d48ede..00bc2fb8d46a 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_3.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_3.cu @@ -21,7 +21,8 @@ #include "../../scalar.chpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_3); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_3); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_4.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_4.cu index 8806668a0cce..05b2fcee8c51 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_4.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_4.cu @@ -21,7 +21,8 @@ #include "../../scalar.chpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_4); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_4); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_5.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_5.cu index 51b1e6b5140c..366efb82075e 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_5.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_5.cu @@ -21,7 +21,8 @@ #include "../../scalar.chpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_5); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_5); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_6.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_6.cu index 95d0b46489fd..875100fa9db6 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_6.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_6.cu @@ -21,7 +21,8 @@ #include "../../scalar.chpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_6); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_6); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_7.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_7.cu index 8a99df7d88c7..b6e7ec42ce1f 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_7.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_7.cu @@ -21,7 +21,8 @@ #include "../../scalar.chpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_7); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_7); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_8.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_8.cu index c1c233e9b843..b944fb9b554a 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_8.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_8.cu @@ -21,7 +21,8 @@ #include "../../scalar.chpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_8); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_8); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_9.cu b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_9.cu index 4afcc624e746..00c2bd0e98f2 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_9.cu +++ b/libnd4j/include/loops/cuda/compilation_units/scalar/scalar_9.cu @@ -21,7 +21,8 @@ #include "../../scalar.chpp" namespace functions { - namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_9); - } -} \ No newline at end of file +namespace scalar { +BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , + PAIRWISE_TYPES_9); +} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/indexreduce.cu b/libnd4j/include/loops/cuda/indexreduce.cu index 6cf47240779f..142d701125c8 100644 --- a/libnd4j/include/loops/cuda/indexreduce.cu +++ b/libnd4j/include/loops/cuda/indexreduce.cu @@ -18,361 +18,349 @@ // Created by raver on 4/9/2018. // +#include #include -#include "../indexreduce.h" #include -#include #include +#include "../indexreduce.h" #include "../legacy_ops.h" using namespace simdOps; - template -static __global__ void simpleIndexReduceGeneric(const int op, - void const* dx, - Nd4jLong const* xShapeInfo, int xRank, - void *extraParams, - void *result, - Nd4jLong const* zShapeInfo, int zRank, - int *dimension, - int dimensionLength, - int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - - functions::indexreduce::IndexReduce::transform(op,dx,xShapeInfo,extraParams,result,zShapeInfo,dimension,dimensionLength,postProcessOrNot,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets); +static __global__ void simpleIndexReduceGeneric( + const int op, void const *dx, Nd4jLong const *xShapeInfo, int xRank, + void *extraParams, void *result, Nd4jLong const *zShapeInfo, int zRank, + int *dimension, int dimensionLength, int postProcessOrNot, + int *allocationBuffer, void *reductionBuffer, + Nd4jLong const *tadOnlyShapeInfo, Nd4jLong const *tadOffsets) { + functions::indexreduce::IndexReduce::transform( + op, dx, xShapeInfo, extraParams, result, zShapeInfo, dimension, + dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, + tadOnlyShapeInfo, tadOffsets); } namespace functions { - namespace indexreduce { - - template - _CUDA_H void IndexReduce::executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream, - const int opNum, - void const* dx, Nd4jLong const* xShapeInfo, - int xRank, - void *extraParams, - void *result, Nd4jLong const* zShapeInfo, - int zRank, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, void *reductionBuffer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - - simpleIndexReduceGeneric<<>>(opNum, - dx, xShapeInfo, xRank, - extraParams, - result, zShapeInfo, 0, - nullptr, 0, - 1, - allocationBuffer, reductionBuffer, - tadOnlyShapeInfo, tadOffsets); - } +namespace indexreduce { - template - _CUDA_H void IndexReduce::executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int opNum, void const* dx, Nd4jLong const* xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong const* zShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - simpleIndexReduceGeneric<<>>( - opNum, - dx, - xShapeInfo, xRank, - extraParams, - result, - zShapeInfo, zRank, - dimension, - dimensionLength, - 1, allocationBuffer, reductionBuffer, tadOnlyShapeInfo, tadOffsets); - } +template +_CUDA_H void IndexReduce::executeIndexReduceScalar( + dim3 launchDims, cudaStream_t *stream, const int opNum, void const *dx, + Nd4jLong const *xShapeInfo, int xRank, void *extraParams, void *result, + Nd4jLong const *zShapeInfo, int zRank, int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, + Nd4jLong const *tadOnlyShapeInfo, Nd4jLong const *tadOffsets) { + simpleIndexReduceGeneric + <<>>( + opNum, dx, xShapeInfo, xRank, extraParams, result, zShapeInfo, 0, + nullptr, 0, 1, allocationBuffer, reductionBuffer, tadOnlyShapeInfo, + tadOffsets); +} - // This is the un-specialized struct. Note that we prevent instantiation of this - // struct by putting an undefined symbol in the function body so it won't compile. - template - struct SharedIndexValue { - // Ensure that we won't compile any un-specialized types - __device__ T * getPointer() { - extern __device__ void error(void); - error(); - return 0; - } - }; +template +_CUDA_H void IndexReduce::executeIndexReduce( + dim3 launchDims, cudaStream_t *stream, const int opNum, void const *dx, + Nd4jLong const *xShapeInfo, int xRank, void *extraParams, void *result, + Nd4jLong const *zShapeInfo, int zRank, int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, + Nd4jLong const *tadOnlyShapeInfo, Nd4jLong const *tadOffsets) { + simpleIndexReduceGeneric + <<>>( + opNum, dx, xShapeInfo, xRank, extraParams, result, zShapeInfo, zRank, + dimension, dimensionLength, 1, allocationBuffer, reductionBuffer, + tadOnlyShapeInfo, tadOffsets); +} + +// This is the un-specialized struct. Note that we prevent instantiation of +// this struct by putting an undefined symbol in the function body so it won't +// compile. +template +struct SharedIndexValue { + // Ensure that we won't compile any un-specialized types + __device__ T *getPointer() { + extern __device__ void error(void); + error(); + return 0; + } +}; // Following are the specializations for the following types. -// int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, and double -// One could also specialize it for user-defined types. - - template<> - struct SharedIndexValue { - __device__ IndexValue * getPointer() { - extern __shared__ IndexValue s_int2[]; - return s_int2; - } - }; +// int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, +// and double One could also specialize it for user-defined types. + +template <> +struct SharedIndexValue { + __device__ IndexValue *getPointer() { + extern __shared__ IndexValue s_int2[]; + return s_int2; + } +}; // Following are the specializations for the following types. -// int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, and double -// One could also specialize it for user-defined types. - - template<> - struct SharedIndexValue { - __device__ IndexValue * getPointer() { - extern __shared__ IndexValue s_int6[]; - return s_int6; - } - }; - - template - template - __device__ void IndexReduce::aggregatePartials(IndexValue **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, void *vextraParams) { - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. - auto extraParams = static_cast(vextraParams); - IndexValue *sPartials = *sPartialsRef; - Nd4jLong floorPow2 = blockDim.x; - - if (floorPow2 & (floorPow2 - 1)) { - while ( floorPow2 & (floorPow2 - 1) ) { - floorPow2 &= floorPow2 - 1; - } - - if (tid >= floorPow2) { - IndexValue prev = sPartials[tid - floorPow2]; - IndexValue curr = sPartials[tid]; - sPartials[tid - floorPow2] = OpType::update(prev,curr,extraParams); - } - __syncthreads(); - } - - for (int activeThreads = floorPow2 >> 1;activeThreads; activeThreads >>= 1) { - if (tid < activeThreads && tid + activeThreads < numElements) { - IndexValue curr = sPartials[tid]; - IndexValue next = sPartials[tid + activeThreads]; - sPartials[tid] = OpType::update(curr,next,extraParams); - } - __syncthreads(); - } - } +// int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, +// and double One could also specialize it for user-defined types. + +template <> +struct SharedIndexValue { + __device__ IndexValue *getPointer() { + extern __shared__ IndexValue s_int6[]; + return s_int6; + } +}; - template - __device__ void IndexReduce::transform( - const int opNum, - void const* x, - Nd4jLong const* xShapeInfo, - void *extraParams, - void *result, - Nd4jLong const* zShapeInfo, - int *dimension, - int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, - void *reductionBuffer, - Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffset) { - DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, result, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); +template +template +__device__ void IndexReduce::aggregatePartials( + IndexValue **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, + void *vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. + auto extraParams = static_cast(vextraParams); + IndexValue *sPartials = *sPartialsRef; + Nd4jLong floorPow2 = blockDim.x; + + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) { + floorPow2 &= floorPow2 - 1; + } + + if (tid >= floorPow2) { + IndexValue prev = sPartials[tid - floorPow2]; + IndexValue curr = sPartials[tid]; + sPartials[tid - floorPow2] = OpType::update(prev, curr, extraParams); + } + __syncthreads(); + } + + for (int activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { + if (tid < activeThreads && tid + activeThreads < numElements) { + IndexValue curr = sPartials[tid]; + IndexValue next = sPartials[tid + activeThreads]; + sPartials[tid] = OpType::update(curr, next, extraParams); + } + __syncthreads(); + } +} + +template +__device__ void IndexReduce::transform( + const int opNum, void const *x, Nd4jLong const *xShapeInfo, + void *extraParams, void *result, Nd4jLong const *zShapeInfo, int *dimension, + int dimensionLength, int postProcessOrNot, int *allocationBuffer, + void *reductionBuffer, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffset) { + DISPATCH_BY_OPNUM_TT( + transform, + PARAMS(x, xShapeInfo, extraParams, result, zShapeInfo, dimension, + dimensionLength, postProcessOrNot, allocationBuffer, + reductionBuffer, tadShapeInfo, tadOffset), + INDEX_REDUCE_OPS); +} + +template +template +__device__ void IndexReduce::transform( + void const *vdx, Nd4jLong const *xShapeInfo, void *vextraParams, void *vz, + Nd4jLong const *zShapeInfo, int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationBuffer, void *vreductionBuffer, + Nd4jLong const *tadOnlyShapeInfo, Nd4jLong const *tadOffsets) { + /**int + * Gpu information for the problem + */ + auto dx = reinterpret_cast(vdx); + auto z = reinterpret_cast(vz); + auto extraParams = static_cast(vextraParams); + auto reductionBuffer = static_cast(vreductionBuffer); + auto order = shape::order(xShapeInfo); + int tid = blockIdx.x * blockDim.x + threadIdx.x; + __shared__ volatile int resultScalar; + + // shared memory space for storing intermediate results + __shared__ IndexValue *sPartials; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast *>(shmem); + } + __syncthreads(); + + sPartials[threadIdx.x] = OpType::startingIndexValue(dx); + + // length for the tad + __shared__ volatile Nd4jLong xLength; + + __shared__ volatile Nd4jLong zLen; + + // only compute the tad indexes once + IndexValue reduction = OpType::startingIndexValue(dx); + + if (threadIdx.x == 0) { + if (zShapeInfo != nullptr) + zLen = shape::length(zShapeInfo); + else + zLen = 1; + + if (dimensionLength == 1) { + if (zLen == 1 && (dimension == nullptr || dimension[0] == MAX_DIMENSION)) + resultScalar = 1; + else + resultScalar = 0; + } else + resultScalar = 0; + + if (zLen == 1) resultScalar = 1; + + xLength = shape::length(xShapeInfo); + } + __syncthreads(); + + if (sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { + if (sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) return; + + for (uint i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; + i += gridDim.x * blockDim.x) + z[i] = (Z)reduction.index; + + return; + } + + if (!resultScalar) { + __shared__ Nd4jLong tadLength; + __shared__ int tadEWS; + __shared__ int numTads; + + if (threadIdx.x == 0) { + tadLength = shape::length(tadOnlyShapeInfo); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + } + __syncthreads(); + + if (dimensionLength > 1 || tadEWS < 1) { + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto tadOffsetForBlock = tadOffsets[r]; + sPartials[threadIdx.x] = OpType::startingIndexValue(dx); + + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = + tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); + IndexValue comp{dx[xOffset], i}; + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], comp, extraParams); } + __syncthreads(); + aggregatePartials( + &sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - template - template - __device__ void IndexReduce::transform(void const* vdx, Nd4jLong const* xShapeInfo, - void *vextraParams, - void* vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, void *vreductionBuffer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets){ - /**int - * Gpu information for the problem - */ - auto dx = reinterpret_cast(vdx); - auto z = reinterpret_cast(vz); - auto extraParams = static_cast(vextraParams); - auto reductionBuffer = static_cast(vreductionBuffer); - auto order = shape::order(xShapeInfo); - int tid = blockIdx.x * blockDim.x + threadIdx.x; - __shared__ volatile int resultScalar; - - //shared memory space for storing intermediate results - __shared__ IndexValue* sPartials; - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast*>(shmem); - } - __syncthreads(); - - sPartials[threadIdx.x] = OpType::startingIndexValue(dx); - - //length for the tad - __shared__ volatile Nd4jLong xLength; - - __shared__ volatile Nd4jLong zLen; - - - //only compute the tad indexes once - IndexValue reduction = OpType::startingIndexValue(dx); - - if (threadIdx.x == 0) { - if (zShapeInfo != nullptr) - zLen = shape::length(zShapeInfo); - else zLen = 1; - - if (dimensionLength == 1) { - if (zLen == 1 && (dimension == nullptr || dimension[0] == MAX_DIMENSION)) - resultScalar = 1; - else - resultScalar = 0; - } - else - resultScalar = 0; - - if (zLen == 1) - resultScalar = 1; - - xLength = shape::length(xShapeInfo); - } - __syncthreads(); - - if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) { - - if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY) - return; - - for (uint i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; i += gridDim.x * blockDim.x) - z[i] = (Z) reduction.index; - - return; - } - - if (!resultScalar) { - - __shared__ Nd4jLong tadLength; - __shared__ int tadEWS; - __shared__ int numTads; - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; - } - __syncthreads(); - - if (dimensionLength > 1 || tadEWS < 1) { - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - auto tadOffsetForBlock = tadOffsets[r]; - sPartials[threadIdx.x] = OpType::startingIndexValue(dx); - - for(int i = threadIdx.x;i < tadLength; i += blockDim.x) { - auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); - IndexValue comp {dx[xOffset], i}; - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], comp, extraParams); - } - - __syncthreads(); - aggregatePartials(&sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength),extraParams); - - __syncthreads(); - if (threadIdx.x == 0) { - z[r] = (Z) sPartials[threadIdx.x].index; - } - __syncthreads(); - } - } else { - - for(int i = blockIdx.x; i < numTads; i+= gridDim.x) { - Nd4jLong tadOffsetForBlock = tadOffsets[i]; - - sPartials[threadIdx.x] = OpType::startingIndexValue(dx); - - for (int x = threadIdx.x; x < tadLength; x+= blockDim.x) { - IndexValue comp {dx[tadOffsetForBlock + x * tadEWS], x}; - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], comp, extraParams); - } - - __syncthreads(); - aggregatePartials(&sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength),extraParams); - - __syncthreads(); - if (threadIdx.x == 0) { - z[i] = (Z) sPartials[threadIdx.x].index; //postProcess(sPartials[0],tadLength ,extraParams); - } - __syncthreads(); - } - } - } else { - auto n = shape::length(xShapeInfo); - auto xElementWiseStride = shape::elementWiseStride(xShapeInfo); - - if(xElementWiseStride >= 1 && order == 'c') { - for(Nd4jLong i = tid;i < n; i += (blockDim.x * gridDim.x)) { - IndexValue indexVal = {dx[i * xElementWiseStride], i}; - reduction = OpType::update(reduction, indexVal, extraParams); - } - } else { - - for(Nd4jLong i = tid;i < n; i += blockDim.x * gridDim.x) { - auto offset = shape::getIndexOffset(i, xShapeInfo); - IndexValue indexVal = {dx[offset], i}; - reduction = OpType::update(reduction, indexVal, extraParams); - } - } - - - sPartials[threadIdx.x] = reduction; - __syncthreads(); - - aggregatePartials(&sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, (int) n),extraParams); - __syncthreads(); - - if (gridDim.x > 1) { - __shared__ bool amLast; - unsigned int *tc = (unsigned int *) reductionBuffer; - tid = threadIdx.x; - if (threadIdx.x == 0) { - auto pBuffer = reinterpret_cast *>(reductionBuffer); - pBuffer[blockIdx.x] = {sPartials[0].value, sPartials[0].index}; - } - __threadfence(); - __syncthreads(); - - if (tid==0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x-1); - } - - __syncthreads(); - - if (amLast) { - tc[16384] = 0; - IndexValue *pBuffer = (IndexValue *) reductionBuffer; - - sPartials[threadIdx.x] = OpType::startingIndexValue(dx); - - for (Nd4jLong i = threadIdx.x; i < gridDim.x; i += blockDim.x) { - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], pBuffer[i], extraParams); - } - - __syncthreads(); - aggregatePartials(&sPartials, threadIdx.x, sd::math::nd4j_min(gridDim.x, blockDim.x),extraParams); - - __syncthreads(); - if (tid == 0) { - z[0] = (Z) sPartials[0].index; - } - } - } else { - if (tid == 0) { - auto tc = reinterpret_cast(reductionBuffer); - tc[16384] = 0; - z[0] = (Z) sPartials[0].index; - } - } - - } + __syncthreads(); + if (threadIdx.x == 0) { + z[r] = (Z)sPartials[threadIdx.x].index; } + __syncthreads(); + } + } else { + for (int i = blockIdx.x; i < numTads; i += gridDim.x) { + Nd4jLong tadOffsetForBlock = tadOffsets[i]; + + sPartials[threadIdx.x] = OpType::startingIndexValue(dx); + + for (int x = threadIdx.x; x < tadLength; x += blockDim.x) { + IndexValue comp{dx[tadOffsetForBlock + x * tadEWS], x}; + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], comp, extraParams); + } + + __syncthreads(); + aggregatePartials( + &sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES, INDEXING_TYPES); + __syncthreads(); + if (threadIdx.x == 0) { + z[i] = + (Z)sPartials[threadIdx.x] + .index; // postProcess(sPartials[0],tadLength ,extraParams); + } + __syncthreads(); + } } -} + } else { + auto n = shape::length(xShapeInfo); + auto xElementWiseStride = shape::elementWiseStride(xShapeInfo); + + if (xElementWiseStride >= 1 && order == 'c') { + for (Nd4jLong i = tid; i < n; i += (blockDim.x * gridDim.x)) { + IndexValue indexVal = {dx[i * xElementWiseStride], i}; + reduction = OpType::update(reduction, indexVal, extraParams); + } + } else { + for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) { + auto offset = shape::getIndexOffset(i, xShapeInfo); + IndexValue indexVal = {dx[offset], i}; + reduction = OpType::update(reduction, indexVal, extraParams); + } + } + + sPartials[threadIdx.x] = reduction; + __syncthreads(); + + aggregatePartials(&sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, (int)n), + extraParams); + __syncthreads(); + + if (gridDim.x > 1) { + __shared__ bool amLast; + unsigned int *tc = (unsigned int *)reductionBuffer; + tid = threadIdx.x; + if (threadIdx.x == 0) { + auto pBuffer = reinterpret_cast *>(reductionBuffer); + pBuffer[blockIdx.x] = {sPartials[0].value, sPartials[0].index}; + } + __threadfence(); + __syncthreads(); + + if (tid == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } + + __syncthreads(); + + if (amLast) { + tc[16384] = 0; + IndexValue *pBuffer = (IndexValue *)reductionBuffer; + + sPartials[threadIdx.x] = OpType::startingIndexValue(dx); + + for (Nd4jLong i = threadIdx.x; i < gridDim.x; i += blockDim.x) { + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], pBuffer[i], extraParams); + } + __syncthreads(); + aggregatePartials( + &sPartials, threadIdx.x, + sd::math::nd4j_min(gridDim.x, blockDim.x), extraParams); + __syncthreads(); + if (tid == 0) { + z[0] = (Z)sPartials[0].index; + } + } + } else { + if (tid == 0) { + auto tc = reinterpret_cast(reductionBuffer); + tc[16384] = 0; + z[0] = (Z)sPartials[0].index; + } + } + } +} +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES, + INDEXING_TYPES); +} // namespace indexreduce +} // namespace functions diff --git a/libnd4j/include/loops/cuda/inplace_loops/reduce_same_inplace.h b/libnd4j/include/loops/cuda/inplace_loops/reduce_same_inplace.h index 9e7bdf63409f..45350d761bf6 100644 --- a/libnd4j/include/loops/cuda/inplace_loops/reduce_same_inplace.h +++ b/libnd4j/include/loops/cuda/inplace_loops/reduce_same_inplace.h @@ -18,156 +18,171 @@ // @author raver119@gmail.com // - #ifndef SD_REDUCE_SAME_LOOPS_H #define SD_REDUCE_SAME_LOOPS_H +#include #include -#include #include -#include +#include using namespace simdOps; namespace functions { - namespace reduce { - template - class ReduceSameInplace { - public: - static FORCEINLINE void _CUDA_D execScalarCudaLegacy(int opNum, void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vz, Nd4jLong *zShapeInfo, void *vreductionBuffer, Nd4jLong *tadOnlyShapeInfo); - - template - static FORCEINLINE void _CUDA_D execScalarCuda(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vz, Nd4jLong *zShapeInfo, void *vreductionBuffer, Nd4jLong *tadOnlyShapeInfo); - - template - static FORCEINLINE void _CUDA_D aggregatePartials(void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams); - }; +namespace reduce { +template +class ReduceSameInplace { + public: + static FORCEINLINE void _CUDA_D execScalarCudaLegacy( + int opNum, void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + Nd4jLong *zShapeInfo, void *vreductionBuffer, Nd4jLong *tadOnlyShapeInfo); + + template + static FORCEINLINE void _CUDA_D execScalarCuda(void *vx, Nd4jLong *xShapeInfo, + void *vextraParams, void *vz, + Nd4jLong *zShapeInfo, + void *vreductionBuffer, + Nd4jLong *tadOnlyShapeInfo); + + template + static FORCEINLINE void _CUDA_D aggregatePartials(void *vsPartials, + Nd4jLong tid, + Nd4jLong numItems, + void *vextraParams); +}; + +template +template +__device__ void ReduceSameInplace::aggregatePartials(void *vsPartials, + Nd4jLong tid, + Nd4jLong numItems, + void *vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. + + auto sPartials = static_cast(vsPartials); + auto extraParams = static_cast(vextraParams); + + Nd4jLong floorPow2 = numItems; + + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) floorPow2 &= floorPow2 - 1; + + if (tid >= floorPow2) + sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], + sPartials[tid], extraParams); + + __syncthreads(); + } + + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (tid < activeThreads && tid + activeThreads < numItems) + sPartials[tid] = OpType::update( + sPartials[tid], sPartials[tid + activeThreads], extraParams); + + __syncthreads(); + } +} - template - template - __device__ void ReduceSameInplace::aggregatePartials(void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { +template +FORCEINLINE void _CUDA_D ReduceSameInplace::execScalarCudaLegacy( + int opNum, void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + Nd4jLong *zShapeInfo, void *vreductionBuffer, Nd4jLong *tadOnlyShapeInfo) { + DISPATCH_BY_OPNUM_T(execScalarCuda, + PARAMS(vx, xShapeInfo, vextraParams, vz, zShapeInfo, + vreductionBuffer, tadOnlyShapeInfo), + REDUCE_SAME_OPS); +} - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. +template +template +FORCEINLINE void _CUDA_D ReduceSameInplace::execScalarCuda( + void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + Nd4jLong *zShapeInfo, void *vreductionBuffer, Nd4jLong *tadOnlyShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + int xEws = shape::elementWiseStride(xShapeInfo); + auto len = shape::length(xShapeInfo); + auto tid = blockDim.x * blockIdx.x + threadIdx.x; + + // shared memory space for storing intermediate results + __shared__ X *sPartials; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + } + __syncthreads(); + sPartials[threadIdx.x] = OpType::startingValue(x); + + if (xEws > 0) + for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[i * xEws], extraParams), extraParams); + else + for (int i = tid; i < len; i += blockDim.x * gridDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), + extraParams); + + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, len), + extraParams); + __syncthreads(); + + if (gridDim.x > 1) { + unsigned int *tc = (unsigned int *)reductionBuffer; + __shared__ bool amLast; + + tid = threadIdx.x; + if (threadIdx.x == 0) + reductionBuffer[blockIdx.x] = + sPartials[0]; // this->postProcess(sPartials[0],len,extraParams); + + __threadfence(); + __syncthreads(); + + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } - auto sPartials = static_cast(vsPartials); - auto extraParams = static_cast(vextraParams); + __syncthreads(); - Nd4jLong floorPow2 = numItems; + if (amLast) { + tc[16384] = 0; + sPartials[threadIdx.x] = OpType::startingValue(x); - if (floorPow2 & (floorPow2 - 1)) { - - while (floorPow2 & (floorPow2 - 1)) - floorPow2 &= floorPow2 - 1; + for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], reductionBuffer[i], extraParams); - if (tid >= floorPow2) - sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], sPartials[tid], extraParams); - - __syncthreads(); - } - - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (tid < activeThreads && tid + activeThreads < numItems) - sPartials[tid] = OpType::update(sPartials[tid], sPartials[tid + activeThreads], extraParams); - - __syncthreads(); - } - } - - template - FORCEINLINE void _CUDA_D ReduceSameInplace::execScalarCudaLegacy(int opNum, void *vx, Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, Nd4jLong *zShapeInfo, - void *vreductionBuffer, - Nd4jLong *tadOnlyShapeInfo) { - DISPATCH_BY_OPNUM_T(execScalarCuda, PARAMS(vx, xShapeInfo, vextraParams, vz, zShapeInfo, vreductionBuffer, tadOnlyShapeInfo), REDUCE_SAME_OPS); - } - - template - template - FORCEINLINE void _CUDA_D ReduceSameInplace::execScalarCuda(void *vx, Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, Nd4jLong *zShapeInfo, - void *vreductionBuffer, - Nd4jLong *tadOnlyShapeInfo) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - int xEws = shape::elementWiseStride(xShapeInfo); - auto len = shape::length(xShapeInfo); - auto tid = blockDim.x * blockIdx.x + threadIdx.x; - - //shared memory space for storing intermediate results - __shared__ X* sPartials; - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - } - __syncthreads(); - sPartials[threadIdx.x] = OpType::startingValue(x); - - if (xEws > 0) - for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[i * xEws], extraParams), extraParams); - else - for (int i = tid; i < len; i += blockDim.x * gridDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), extraParams); - - __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, len), extraParams); - __syncthreads(); - - - if (gridDim.x > 1) { - - unsigned int *tc = (unsigned int *)reductionBuffer; - __shared__ bool amLast; - - tid = threadIdx.x; - if (threadIdx.x == 0) - reductionBuffer[blockIdx.x] = sPartials[0];//this->postProcess(sPartials[0],len,extraParams); - - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - - __syncthreads(); - - if (amLast) { - - tc[16384] = 0; - sPartials[threadIdx.x] = OpType::startingValue(x); - - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], reductionBuffer[i], extraParams); - - __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(gridDim.x, blockDim.x), extraParams); - __syncthreads(); - - if (threadIdx.x == 0) { - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } - } - } - else { - - if (threadIdx.x == 0) { - unsigned int *tc = (unsigned *)reductionBuffer; - tc[16384] = 0; - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } - } - } + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(gridDim.x, blockDim.x), + extraParams); + __syncthreads(); + + if (threadIdx.x == 0) { + z[0] = OpType::postProcess(sPartials[0], len, extraParams); + } + } + } else { + if (threadIdx.x == 0) { + unsigned int *tc = (unsigned *)reductionBuffer; + tc[16384] = 0; + z[0] = OpType::postProcess(sPartials[0], len, extraParams); } + } } +} // namespace reduce +} // namespace functions -#endif //SD_REDUCE_SAME_LOOPS_H +#endif // SD_REDUCE_SAME_LOOPS_H diff --git a/libnd4j/include/loops/cuda/inplace_loops/scalar_inplace.h b/libnd4j/include/loops/cuda/inplace_loops/scalar_inplace.h index 7049483e7852..2cafbbca6240 100644 --- a/libnd4j/include/loops/cuda/inplace_loops/scalar_inplace.h +++ b/libnd4j/include/loops/cuda/inplace_loops/scalar_inplace.h @@ -18,65 +18,66 @@ // @author raver119@gmail.com // - #ifndef SD_SCALAR_INPLACE_H #define SD_SCALAR_INPLACE_H +#include #include -#include #include -#include +#include using namespace simdOps; namespace functions { - namespace scalar { - template - class ScalarInplace { - public: - static FORCEINLINE _CUDA_D void transformCudaLegacy(int opNum, void* vscalar, void *vy, Nd4jLong *yShapeInfo, void *vparams, void *vz, Nd4jLong *zShapeInfo, int *allocationBuffer); - - template - static FORCEINLINE _CUDA_D void transformCuda(void* vscalar, void *vy, Nd4jLong *yShapeInfo, void *vparams, void *vz, Nd4jLong *zShapeInfo, int *allocationBuffer); - }; - - template - FORCEINLINE _CUDA_D void ScalarInplace::transformCudaLegacy(int opNum, void* vscalar, - void *vy, Nd4jLong *yShapeInfo, - void *vparams, - void *vz, Nd4jLong *zShapeInfo, - int *allocationBuffer) { - - DISPATCH_BY_OPNUM_TTT(transformCuda, PARAMS(vscalar, vy, yShapeInfo, vparams, vz, zShapeInfo, allocationBuffer), SCALAR_OPS); - } - - template - template - FORCEINLINE _CUDA_D void ScalarInplace::transformCuda(void* vscalar, - void *vy, Nd4jLong *yShapeInfo, - void *vparams, - void *vz, Nd4jLong *zShapeInfo, - int *allocationBuffer) { - - auto scalar = reinterpret_cast(vscalar)[0]; - auto y = reinterpret_cast(vy); - auto params = reinterpret_cast(vparams); - auto z = reinterpret_cast(vz); - - int totalThreads = gridDim.x * blockDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ Nd4jLong length; - if(threadIdx.x == 0) - length = shape::length(yShapeInfo); - __syncthreads(); - +namespace scalar { +template +class ScalarInplace { + public: + static FORCEINLINE _CUDA_D void transformCudaLegacy( + int opNum, void *vscalar, void *vy, Nd4jLong *yShapeInfo, void *vparams, + void *vz, Nd4jLong *zShapeInfo, int *allocationBuffer); + + template + static FORCEINLINE _CUDA_D void transformCuda(void *vscalar, void *vy, + Nd4jLong *yShapeInfo, + void *vparams, void *vz, + Nd4jLong *zShapeInfo, + int *allocationBuffer); +}; + +template +FORCEINLINE _CUDA_D void ScalarInplace::transformCudaLegacy( + int opNum, void *vscalar, void *vy, Nd4jLong *yShapeInfo, void *vparams, + void *vz, Nd4jLong *zShapeInfo, int *allocationBuffer) { + DISPATCH_BY_OPNUM_TTT(transformCuda, + PARAMS(vscalar, vy, yShapeInfo, vparams, vz, zShapeInfo, + allocationBuffer), + SCALAR_OPS); +} - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - z[shape::getIndexOffset(i, zShapeInfo)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo)], scalar, params); - } - } - } +template +template +FORCEINLINE _CUDA_D void ScalarInplace::transformCuda( + void *vscalar, void *vy, Nd4jLong *yShapeInfo, void *vparams, void *vz, + Nd4jLong *zShapeInfo, int *allocationBuffer) { + auto scalar = reinterpret_cast(vscalar)[0]; + auto y = reinterpret_cast(vy); + auto params = reinterpret_cast(vparams); + auto z = reinterpret_cast(vz); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ Nd4jLong length; + if (threadIdx.x == 0) length = shape::length(yShapeInfo); + __syncthreads(); + + for (Nd4jLong i = tid; i < length; i += totalThreads) { + z[shape::getIndexOffset(i, zShapeInfo)] = + OpType::op(y[shape::getIndexOffset(i, yShapeInfo)], scalar, params); + } } +} // namespace scalar +} // namespace functions -#endif //SD_SCALAR_INPLACE_H +#endif // SD_SCALAR_INPLACE_H diff --git a/libnd4j/include/loops/cuda/inplace_loops/transform_strict_inplace.h b/libnd4j/include/loops/cuda/inplace_loops/transform_strict_inplace.h index 903ed8c53291..78cbcb40d897 100644 --- a/libnd4j/include/loops/cuda/inplace_loops/transform_strict_inplace.h +++ b/libnd4j/include/loops/cuda/inplace_loops/transform_strict_inplace.h @@ -21,79 +21,74 @@ #ifndef SD_TRANSFORM_FLOAT_INPLACE_H #define SD_TRANSFORM_FLOAT_INPLACE_H +#include #include -#include #include -#include +#include using namespace simdOps; -#define LOCAL_TRANSFORM_STRICT_OPS \ - (23, Exp), \ - (24, Log) +#define LOCAL_TRANSFORM_STRICT_OPS (23, Exp), (24, Log) namespace functions { - namespace transform { - template - class TransformStrictInplace { - public: - static FORCEINLINE _CUDA_D void transformCudaLegacy(int opNum, void *dy, Nd4jLong *shapeInfo, void *params, void *result, Nd4jLong *zShapeInfo, int *allocationPointer, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); - - template - static FORCEINLINE _CUDA_D void transformCuda(void *vdy, Nd4jLong *shapeInfo, void *vparams, void *vresult, Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); - }; - - template - template - FORCEINLINE _CUDA_D void TransformStrictInplace::transformCuda( - void *vdy, - Nd4jLong *shapeInfo, - void *vparams, - void *vresult, - Nd4jLong *zShapeInfo, - int *allocationPointer, void *vreductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { - - auto dy = static_cast(vdy); - auto result = static_cast(vresult); - auto params = static_cast(vparams); - auto reductionPointer = static_cast(vreductionPointer); - - auto xOrder = shape::order(shapeInfo); - auto zOrder = shape::order(zShapeInfo); - - auto xEws = shape::elementWiseStride(shapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ Nd4jLong length; - if(threadIdx.x == 0) - length = shape::length(shapeInfo); - __syncthreads(); - - - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { - auto xOffset2 = shape::getIndexOffset(i, shapeInfo); - auto zOffset2 = shape::getIndexOffset(i, zShapeInfo); - result[zOffset2] = OpType::op(dy[xOffset2], params); - } - } +namespace transform { +template +class TransformStrictInplace { + public: + static FORCEINLINE _CUDA_D void transformCudaLegacy( + int opNum, void *dy, Nd4jLong *shapeInfo, void *params, void *result, + Nd4jLong *zShapeInfo, int *allocationPointer, void *reductionPointer, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); + + template + static FORCEINLINE _CUDA_D void transformCuda( + void *vdy, Nd4jLong *shapeInfo, void *vparams, void *vresult, + Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); +}; + +template +template +FORCEINLINE _CUDA_D void TransformStrictInplace::transformCuda( + void *vdy, Nd4jLong *shapeInfo, void *vparams, void *vresult, + Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + auto dy = static_cast(vdy); + auto result = static_cast(vresult); + auto params = static_cast(vparams); + auto reductionPointer = static_cast(vreductionPointer); + + auto xOrder = shape::order(shapeInfo); + auto zOrder = shape::order(zShapeInfo); + + auto xEws = shape::elementWiseStride(shapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ Nd4jLong length; + if (threadIdx.x == 0) length = shape::length(shapeInfo); + __syncthreads(); + + for (Nd4jLong i = tid; i < length; i += gridDim.x * blockDim.x) { + auto xOffset2 = shape::getIndexOffset(i, shapeInfo); + auto zOffset2 = shape::getIndexOffset(i, zShapeInfo); + result[zOffset2] = OpType::op(dy[xOffset2], params); + } +} - template - FORCEINLINE _CUDA_D void TransformStrictInplace::transformCudaLegacy( - int opNum, - void *dy, - Nd4jLong *shapeInfo, - void *params, - void *result, - Nd4jLong *zShapeInfo, - int *allocationPointer, - void *reductionPointer, - Nd4jLong *tadShapeInfo, - Nd4jLong *tadOffsets) { - DISPATCH_BY_OPNUM_T(transformCuda, PARAMS(dy, shapeInfo, params, result, zShapeInfo, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), LOCAL_TRANSFORM_STRICT_OPS); - } - } +template +FORCEINLINE _CUDA_D void TransformStrictInplace::transformCudaLegacy( + int opNum, void *dy, Nd4jLong *shapeInfo, void *params, void *result, + Nd4jLong *zShapeInfo, int *allocationPointer, void *reductionPointer, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_T( + transformCuda, + PARAMS(dy, shapeInfo, params, result, zShapeInfo, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + LOCAL_TRANSFORM_STRICT_OPS); } +} // namespace transform +} // namespace functions #undef LOCAL_TRANSFORM_STRICT_OPS -#endif //SD_TRANSFORM_FLOAT_INPLACE_H +#endif // SD_TRANSFORM_FLOAT_INPLACE_H diff --git a/libnd4j/include/loops/cuda/pairwise.chpp b/libnd4j/include/loops/cuda/pairwise.chpp index 8f3cf358a4aa..92c8f6607cc2 100644 --- a/libnd4j/include/loops/cuda/pairwise.chpp +++ b/libnd4j/include/loops/cuda/pairwise.chpp @@ -20,104 +20,109 @@ #ifndef PAIRWISE_CU #define PAIRWISE_CU - #include "../pairwise_transform.h" - using namespace simdOps; //////////////////////////////////////////////////////////////////////////////// template -__global__ static void pairwiseSimpleShaped(void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ int xEws; - __shared__ int yEws; - __shared__ int zEws; - __shared__ char xOrder; - __shared__ char yOrder; - __shared__ char zOrder; - __shared__ Nd4jLong len; - - if (threadIdx.x == 0) { - xEws = shape::elementWiseStride(xShapeInfo); - yEws = shape::elementWiseStride(yShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - yOrder = shape::order(yShapeInfo); - zOrder = shape::order(zShapeInfo); - len = shape::length(xShapeInfo); - } - __syncthreads(); - - - if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && xOrder == zOrder) { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); - } - } - else if (vx == vz) { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - - z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - } - } - else { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - } - } +__global__ static void pairwiseSimpleShaped( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void* vextraParams) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int xEws; + __shared__ int yEws; + __shared__ int zEws; + __shared__ char xOrder; + __shared__ char yOrder; + __shared__ char zOrder; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + yEws = shape::elementWiseStride(yShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + yOrder = shape::order(yShapeInfo); + zOrder = shape::order(zShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + + if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && + xOrder == zOrder) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); + } + } else if (vx == vz) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + + z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } else { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } } -namespace functions { +namespace functions { namespace pairwise_transforms { //////////////////////////////////////////////////////////////////////////////// -template -template -void __host__ PairWiseTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams){ - - pairwiseSimpleShaped<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); +template +template +void __host__ PairWiseTransform::intermediateShaped( + dim3& launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* vz, Nd4jLong const* zShapeInfo, void* vextraParams) { + pairwiseSimpleShaped + <<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); } //////////////////////////////////////////////////////////////////////////////// -template -void __host__ PairWiseTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void* vextraParams) { - DISPATCH_BY_OPNUM_TTT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_TRANSFORM_OPS); +template +void __host__ PairWiseTransform::executeCudaShaped( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* vz, Nd4jLong const* zShapeInfo, void* vextraParams) { + DISPATCH_BY_OPNUM_TTT(intermediateShaped, + PARAMS(launchDims, stream, vx, xShapeInfo, vy, + yShapeInfo, vz, zShapeInfo, vextraParams), + PAIRWISE_TRANSFORM_OPS); } /* - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_0); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_1); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_2); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_3); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_4); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_5); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_6); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_7); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_8); - BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_9); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_0); BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT + PairWiseTransform, , PAIRWISE_TYPES_1); BUILD_PAIRWISE_TEMPLATE(template + class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_2); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_3); BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT + PairWiseTransform, , PAIRWISE_TYPES_4); BUILD_PAIRWISE_TEMPLATE(template + class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_5); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_6); BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT + PairWiseTransform, , PAIRWISE_TYPES_7); BUILD_PAIRWISE_TEMPLATE(template + class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_8); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , + PAIRWISE_TYPES_9); */ -} -} +} // namespace pairwise_transforms +} // namespace functions -#endif // PAIRWISE_CU \ No newline at end of file +#endif // PAIRWISE_CU \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/pairwise.cu b/libnd4j/include/loops/cuda/pairwise.cu index 4833d32d0927..869200ba901e 100644 --- a/libnd4j/include/loops/cuda/pairwise.cu +++ b/libnd4j/include/loops/cuda/pairwise.cu @@ -21,7 +21,5 @@ #include "../pairwise_transform.h" namespace functions { - namespace pairwise_transforms { - - } -} \ No newline at end of file +namespace pairwise_transforms {} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/pairwise_bool.cu b/libnd4j/include/loops/cuda/pairwise_bool.cu index 578c6e7edd1e..c52ec95b611f 100644 --- a/libnd4j/include/loops/cuda/pairwise_bool.cu +++ b/libnd4j/include/loops/cuda/pairwise_bool.cu @@ -20,98 +20,98 @@ #ifndef PAIRWISE_BOOL_CU #define PAIRWISE_BOOL_CU - #include "../pairwise_bool.h" - using namespace simdOps; //////////////////////////////////////////////////////////////////////////////// template -__global__ static void pairwiseSimpleShaped(void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ int xEws; - __shared__ int yEws; - __shared__ int zEws; - __shared__ char xOrder; - __shared__ char yOrder; - __shared__ char zOrder; - __shared__ Nd4jLong len; - - if (threadIdx.x == 0) { - xEws = shape::elementWiseStride(xShapeInfo); - yEws = shape::elementWiseStride(yShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - yOrder = shape::order(yShapeInfo); - zOrder = shape::order(zShapeInfo); - len = shape::length(xShapeInfo); - } - __syncthreads(); - - - if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && xOrder == zOrder) { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); - } - } - else if (vx == vz) { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - - z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - } - } - else { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - } - } +__global__ static void pairwiseSimpleShaped( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void* vextraParams) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int xEws; + __shared__ int yEws; + __shared__ int zEws; + __shared__ char xOrder; + __shared__ char yOrder; + __shared__ char zOrder; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + yEws = shape::elementWiseStride(yShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + yOrder = shape::order(yShapeInfo); + zOrder = shape::order(zShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + + if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && + xOrder == zOrder) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); + } + } else if (vx == vz) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + + z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } else { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } } - -namespace functions { +namespace functions { namespace pairwise_transforms { //////////////////////////////////////////////////////////////////////////////// -template -template -void _CUDA_H PairWiseBoolTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams){ - - pairwiseSimpleShaped<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); +template +template +void _CUDA_H PairWiseBoolTransform::intermediateShaped( + dim3& launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* vz, Nd4jLong const* zShapeInfo, void* vextraParams) { + pairwiseSimpleShaped + <<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); } - //////////////////////////////////////////////////////////////////////////////// -template -void PairWiseBoolTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void *vextraParams) { - auto xType = sd::DataTypeUtils::fromT(); - auto yType = sd::DataTypeUtils::fromT(); - - DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_BOOL_OPS); +template +void PairWiseBoolTransform::executeCudaShaped( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* vz, Nd4jLong const* zShapeInfo, void* vextraParams) { + auto xType = sd::DataTypeUtils::fromT(); + auto yType = sd::DataTypeUtils::fromT(); + + DISPATCH_BY_OPNUM_TT(intermediateShaped, + PARAMS(launchDims, stream, vx, xShapeInfo, vy, + yShapeInfo, vz, zShapeInfo, vextraParams), + PAIRWISE_BOOL_OPS); } - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); -} -} +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT PairWiseBoolTransform, , + LIBND4J_TYPES, BOOL_TYPES); +} // namespace pairwise_transforms +} // namespace functions -#endif // PAIRWISE_BOOL_CU \ No newline at end of file +#endif // PAIRWISE_BOOL_CU \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/pairwise_int.cu b/libnd4j/include/loops/cuda/pairwise_int.cu index dbc676a56754..039a5e682f51 100644 --- a/libnd4j/include/loops/cuda/pairwise_int.cu +++ b/libnd4j/include/loops/cuda/pairwise_int.cu @@ -20,97 +20,97 @@ #ifndef PAIRWISE_INT_CU #define PAIRWISE_INT_CU - #include "../pairwise_int.h" - using namespace simdOps; //////////////////////////////////////////////////////////////////////////////// template -__global__ static void pairwiseSimpleShaped(void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ int xEws; - __shared__ int yEws; - __shared__ int zEws; - __shared__ char xOrder; - __shared__ char yOrder; - __shared__ char zOrder; - __shared__ Nd4jLong len; - - if (threadIdx.x == 0) { - xEws = shape::elementWiseStride(xShapeInfo); - yEws = shape::elementWiseStride(yShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - yOrder = shape::order(yShapeInfo); - zOrder = shape::order(zShapeInfo); - len = shape::length(xShapeInfo); - } - __syncthreads(); - - - if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && xOrder == zOrder) { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); - } - } - else if (vx == vz) { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - - z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - } - } - else { - for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - } - } +__global__ static void pairwiseSimpleShaped( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void* vextraParams) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int xEws; + __shared__ int yEws; + __shared__ int zEws; + __shared__ char xOrder; + __shared__ char yOrder; + __shared__ char zOrder; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + yEws = shape::elementWiseStride(yShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + yOrder = shape::order(yShapeInfo); + zOrder = shape::order(zShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + + if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && + xOrder == zOrder) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); + } + } else if (vx == vz) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + + z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } else { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } } - -namespace functions { +namespace functions { namespace pairwise_transforms { //////////////////////////////////////////////////////////////////////////////// -template -template -void _CUDA_H PairWiseIntTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void *vextraParams){ - - pairwiseSimpleShaped<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); +template +template +void _CUDA_H PairWiseIntTransform::intermediateShaped( + dim3& launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* vz, Nd4jLong const* zShapeInfo, void* vextraParams) { + pairwiseSimpleShaped + <<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); } - //////////////////////////////////////////////////////////////////////////////// -template -void PairWiseIntTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void *vextraParams) { - auto xType = sd::DataTypeUtils::fromT(); - - DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_INT_OPS); +template +void PairWiseIntTransform::executeCudaShaped( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* vz, Nd4jLong const* zShapeInfo, void* vextraParams) { + auto xType = sd::DataTypeUtils::fromT(); + + DISPATCH_BY_OPNUM_T(intermediateShaped, + PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, + vz, zShapeInfo, vextraParams), + PAIRWISE_INT_OPS); } - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT PairWiseIntTransform, , INTEGER_TYPES); -} -} +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT PairWiseIntTransform, , + INTEGER_TYPES); +} // namespace pairwise_transforms +} // namespace functions -#endif // PAIRWISE_INT_CU \ No newline at end of file +#endif // PAIRWISE_INT_CU \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/random.cu b/libnd4j/include/loops/cuda/random.cu index dd24d202dfb5..472c06510877 100644 --- a/libnd4j/include/loops/cuda/random.cu +++ b/libnd4j/include/loops/cuda/random.cu @@ -18,429 +18,524 @@ // @author raver119@gmail.com // -#include -#include -#include #include #include #include +#include #include +#include +#include using namespace randomOps; template -static inline __device__ void randomSingleGeneric( - Nd4jPointer state, - void *z, - Nd4jLong const* zShapeBuffer, - void *extraArguments) { - - - functions::random::RandomFunction::template execTransformCuda( - state, - z, - zShapeBuffer, - extraArguments); +static inline __device__ void randomSingleGeneric(Nd4jPointer state, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments) { + functions::random::RandomFunction::template execTransformCuda( + state, z, zShapeBuffer, extraArguments); } template static inline __device__ void randomDoubleGeneric( - Nd4jPointer state, - void const* x, - Nd4jLong const* xShapeBuffer, - void *z, - Nd4jLong const* zShapeBuffer, - void *extraArguments) { - - - functions::random::RandomFunction::template execTransformCuda( - state, - x, - xShapeBuffer, - z, - zShapeBuffer, - extraArguments); + Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void* z, + Nd4jLong const* zShapeBuffer, void* extraArguments) { + functions::random::RandomFunction::template execTransformCuda( + state, x, xShapeBuffer, z, zShapeBuffer, extraArguments); } - template static inline __device__ void randomTripleGeneric( - Nd4jPointer state, - void const* x, - Nd4jLong const* xShapeBuffer, - void const* y, - Nd4jLong const* yShapeBuffer, - void *z, - Nd4jLong const* zShapeBuffer, - void *extraArguments) { - - - functions::random::RandomFunction::template execTransformCuda( - state, - x, - xShapeBuffer, - y, - yShapeBuffer, - z, - zShapeBuffer, - extraArguments); + Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, + void const* y, Nd4jLong const* yShapeBuffer, void* z, + Nd4jLong const* zShapeBuffer, void* extraArguments) { + functions::random::RandomFunction::template execTransformCuda( + state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments); } - #ifndef __CLION_IDE__ // here we generate kernels for target operations -DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, float, INPUT(Nd4jPointer state, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, double, INPUT(Nd4jPointer state, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, float16, INPUT(Nd4jPointer state, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, bfloat16, INPUT(Nd4jPointer state, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - -DISPATCH_KERNEL_SIMPLE(randomDouble_, randomDoubleGeneric, float, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomDouble_, randomDoubleGeneric, double, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomDouble_, randomDoubleGeneric, float16, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomDouble_, randomDoubleGeneric, bfloat16, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - -DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, float, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void const* y, Nd4jLong const* yShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, double, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void const* y, Nd4jLong const* yShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, float16, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void const* y, Nd4jLong const* yShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) -DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, bfloat16, INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, void const* y, Nd4jLong const* yShapeBuffer, void *z, Nd4jLong const* zShapeBuffer, void *extraArguments), PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, float, + INPUT(Nd4jPointer state, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, double, + INPUT(Nd4jPointer state, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, float16, + INPUT(Nd4jPointer state, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE(randomSingle_, randomSingleGeneric, bfloat16, + INPUT(Nd4jPointer state, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + +DISPATCH_KERNEL_SIMPLE( + randomDouble_, randomDoubleGeneric, float, + INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, + void* z, Nd4jLong const* zShapeBuffer, void* extraArguments), + PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE( + randomDouble_, randomDoubleGeneric, double, + INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, + void* z, Nd4jLong const* zShapeBuffer, void* extraArguments), + PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE( + randomDouble_, randomDoubleGeneric, float16, + INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, + void* z, Nd4jLong const* zShapeBuffer, void* extraArguments), + PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE( + randomDouble_, randomDoubleGeneric, bfloat16, + INPUT(Nd4jPointer state, void const* x, Nd4jLong const* xShapeBuffer, + void* z, Nd4jLong const* zShapeBuffer, void* extraArguments), + PARAMS(state, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + +DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, float, + INPUT(Nd4jPointer state, void const* x, + Nd4jLong const* xShapeBuffer, void const* y, + Nd4jLong const* yShapeBuffer, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, double, + INPUT(Nd4jPointer state, void const* x, + Nd4jLong const* xShapeBuffer, void const* y, + Nd4jLong const* yShapeBuffer, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, float16, + INPUT(Nd4jPointer state, void const* x, + Nd4jLong const* xShapeBuffer, void const* y, + Nd4jLong const* yShapeBuffer, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) +DISPATCH_KERNEL_SIMPLE(randomTriple_, randomTripleGeneric, bfloat16, + INPUT(Nd4jPointer state, void const* x, + Nd4jLong const* xShapeBuffer, void const* y, + Nd4jLong const* yShapeBuffer, void* z, + Nd4jLong const* zShapeBuffer, + void* extraArguments), + PARAMS(state, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) #endif namespace functions { - namespace random { - template - template - void _CUDA_D RandomFunction::execTransformCuda(Nd4jPointer state, void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, Nd4jLong const* yShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - if (OpClass::requiresSpecial) { - OpClass::specialOpCuda(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments); - return; - } else { - - __shared__ Nd4jLong length; - __shared__ int xEWS; - __shared__ int yEWS; - __shared__ int zEWS; - __shared__ char xOrder; - __shared__ char yOrder; - __shared__ char zOrder; - - __shared__ sd::graph::RandomGenerator *buffer; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - sd::graph::RandomGenerator *devBuffer; - if (threadIdx.x == 0) { - length = shape::length(zShapeBuffer); - xEWS = shape::elementWiseStride(xShapeBuffer); - yEWS = shape::elementWiseStride(yShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - xOrder = shape::order(xShapeBuffer); - yOrder = shape::order(yShapeBuffer); - zOrder = shape::order(zShapeBuffer); - - extern __shared__ unsigned char shmem[]; - buffer = (sd::graph::RandomGenerator *) shmem; - cB = shmem; - devBuffer = reinterpret_cast (state); - dB = reinterpret_cast (state); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - if (xEWS >= 1 && yEWS >= 1 && zEWS >= 1 && xOrder == yOrder && xOrder == zOrder) { - for (Nd4jLong e = tid; e < length; e += blockDim.x * gridDim.x) { - z[e * zEWS] = OpClass::op(x[e * xEWS], y[e * yEWS], e, length, buffer, extraArguments); - } - } else { - for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) { - - auto xOffset2 = shape::getIndexOffset(i, xShapeBuffer); - auto yOffset2 = shape::getIndexOffset(i, yShapeBuffer); - auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); - - z[zOffset2] = OpClass::op(x[xOffset2], y[yOffset2], i, length, buffer, extraArguments); - } - } - } - }; - - - template - template - void _CUDA_D RandomFunction::execTransformCuda(Nd4jPointer state, void const* vx, Nd4jLong const* xShapeBuffer, void* vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - __shared__ Nd4jLong length; - __shared__ int xEWS; - __shared__ int zEWS; - __shared__ char xOrder; - __shared__ char zOrder; - - __shared__ sd::graph::RandomGenerator *buffer; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator *devBuffer; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - buffer = (sd::graph::RandomGenerator *) shmem; - cB = shmem; - devBuffer = reinterpret_cast (state); - dB = reinterpret_cast (state); - - length = shape::length(zShapeBuffer); - xEWS = shape::elementWiseStride(xShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - xOrder = shape::order(xShapeBuffer); - zOrder = shape::order(zShapeBuffer); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - - if (xEWS >= 1 && zEWS >= 1 && xOrder == zOrder) { - for (Nd4jLong e = blockIdx.x * blockDim.x + threadIdx.x; e < length; e += blockDim.x * gridDim.x) { - z[e * zEWS] = OpClass::op(x[e * xEWS], e, length, buffer, extraArguments); - } - } else { - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < length; i += blockDim.x * gridDim.x) { - - auto xOffset2 = shape::getIndexOffset(i, xShapeBuffer); - auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); - - z[zOffset2] = OpClass::op(x[xOffset2], i, length, buffer, extraArguments); - } - } - } - - - template - template - void _CUDA_D RandomFunction::execTransformCuda(Nd4jPointer state, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - __shared__ Nd4jLong length; - __shared__ Nd4jLong ews; - __shared__ sd::graph::RandomGenerator *buffer; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator *devBuffer; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - buffer = (sd::graph::RandomGenerator *) shmem; - cB = shmem; - devBuffer = reinterpret_cast (state); - dB = reinterpret_cast (state); - length = shape::length(zShapeBuffer); - ews = shape::elementWiseStride(zShapeBuffer); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - if (ews > 0) { - for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) { - z[i * ews] = OpClass::op(i, length, buffer, extraArguments); - } - } else { - - for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) { - auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); - z[zOffset2] = OpClass::op(i, length, buffer, extraArguments); - } - } - } - - template <> - _CUDA_H void RandomFunction::executeCudaSingle(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomSingle, float, PARAMS(stateHost, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template <> - _CUDA_H void RandomFunction::executeCudaSingle(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomSingle, float16, PARAMS(stateHost, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template <> - _CUDA_H void RandomFunction::executeCudaSingle(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomSingle, bfloat16, PARAMS(stateHost, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template <> - _CUDA_H void RandomFunction::executeCudaSingle(dim3& launchDims, cudaStream_t *stream, int opNum, Nd4jPointer stateHost, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomSingle, double, PARAMS(stateHost, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template <> - _CUDA_H void RandomFunction::executeCudaDouble(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomDouble, float, PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - - template <> - _CUDA_H void RandomFunction::executeCudaDouble(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomDouble, float16, PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template <> - _CUDA_H void RandomFunction::executeCudaDouble(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); - - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomDouble, bfloat16, PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) - - DEBUG_KERNEL(stream, opNum); - } - - template <> - _CUDA_H void RandomFunction::executeCudaDouble(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { +namespace random { +template +template +void _CUDA_D RandomFunction::execTransformCuda( + Nd4jPointer state, void const* vx, Nd4jLong const* xShapeBuffer, + void const* vy, Nd4jLong const* yShapeBuffer, void* vz, + Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + if (OpClass::requiresSpecial) { + OpClass::specialOpCuda(state, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments); + return; + } else { + __shared__ Nd4jLong length; + __shared__ int xEWS; + __shared__ int yEWS; + __shared__ int zEWS; + __shared__ char xOrder; + __shared__ char yOrder; + __shared__ char zOrder; + + __shared__ sd::graph::RandomGenerator* buffer; + __shared__ unsigned char* cB; + __shared__ unsigned char* dB; + sd::graph::RandomGenerator* devBuffer; + if (threadIdx.x == 0) { + length = shape::length(zShapeBuffer); + xEWS = shape::elementWiseStride(xShapeBuffer); + yEWS = shape::elementWiseStride(yShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + xOrder = shape::order(xShapeBuffer); + yOrder = shape::order(yShapeBuffer); + zOrder = shape::order(zShapeBuffer); + + extern __shared__ unsigned char shmem[]; + buffer = (sd::graph::RandomGenerator*)shmem; + cB = shmem; + devBuffer = reinterpret_cast(state); + dB = reinterpret_cast(state); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (xEWS >= 1 && yEWS >= 1 && zEWS >= 1 && xOrder == yOrder && + xOrder == zOrder) { + for (Nd4jLong e = tid; e < length; e += blockDim.x * gridDim.x) { + z[e * zEWS] = OpClass::op(x[e * xEWS], y[e * yEWS], e, length, buffer, + extraArguments); + } + } else { + for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) { + auto xOffset2 = shape::getIndexOffset(i, xShapeBuffer); + auto yOffset2 = shape::getIndexOffset(i, yShapeBuffer); + auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); + + z[zOffset2] = OpClass::op(x[xOffset2], y[yOffset2], i, length, buffer, + extraArguments); + } + } + } +}; + +template +template +void _CUDA_D RandomFunction::execTransformCuda( + Nd4jPointer state, void const* vx, Nd4jLong const* xShapeBuffer, void* vz, + Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + __shared__ Nd4jLong length; + __shared__ int xEWS; + __shared__ int zEWS; + __shared__ char xOrder; + __shared__ char zOrder; + + __shared__ sd::graph::RandomGenerator* buffer; + __shared__ unsigned char* cB; + __shared__ unsigned char* dB; + __shared__ sd::graph::RandomGenerator* devBuffer; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + buffer = (sd::graph::RandomGenerator*)shmem; + cB = shmem; + devBuffer = reinterpret_cast(state); + dB = reinterpret_cast(state); + + length = shape::length(zShapeBuffer); + xEWS = shape::elementWiseStride(xShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + xOrder = shape::order(xShapeBuffer); + zOrder = shape::order(zShapeBuffer); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + if (xEWS >= 1 && zEWS >= 1 && xOrder == zOrder) { + for (Nd4jLong e = blockIdx.x * blockDim.x + threadIdx.x; e < length; + e += blockDim.x * gridDim.x) { + z[e * zEWS] = OpClass::op(x[e * xEWS], e, length, buffer, extraArguments); + } + } else { + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < length; + i += blockDim.x * gridDim.x) { + auto xOffset2 = shape::getIndexOffset(i, xShapeBuffer); + auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); + z[zOffset2] = OpClass::op(x[xOffset2], i, length, buffer, extraArguments); + } + } +} - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomDouble, double, PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) +template +template +void _CUDA_D RandomFunction::execTransformCuda(Nd4jPointer state, void* vz, + Nd4jLong const* zShapeBuffer, + void* vextraArguments) { + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + __shared__ Nd4jLong length; + __shared__ Nd4jLong ews; + __shared__ sd::graph::RandomGenerator* buffer; + __shared__ unsigned char* cB; + __shared__ unsigned char* dB; + __shared__ sd::graph::RandomGenerator* devBuffer; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + buffer = (sd::graph::RandomGenerator*)shmem; + cB = shmem; + devBuffer = reinterpret_cast(state); + dB = reinterpret_cast(state); + length = shape::length(zShapeBuffer); + ews = shape::elementWiseStride(zShapeBuffer); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (ews > 0) { + for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) { + z[i * ews] = OpClass::op(i, length, buffer, extraArguments); + } + } else { + for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) { + auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); + z[zOffset2] = OpClass::op(i, length, buffer, extraArguments); + } + } +} - DEBUG_KERNEL(stream, opNum); - } +template <> +_CUDA_H void RandomFunction::executeCudaSingle( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void* vz, Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); - template <> - _CUDA_H void RandomFunction::executeCudaTriple(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, Nd4jLong const* yShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomSingle, float, + PARAMS(stateHost, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); + DEBUG_KERNEL(stream, opNum); +} - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomTriple, float, PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) +template <> +_CUDA_H void RandomFunction::executeCudaSingle( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void* vz, Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); - DEBUG_KERNEL(stream, opNum); - } + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomSingle, float16, + PARAMS(stateHost, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) - template <> - _CUDA_H void RandomFunction::executeCudaTriple(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, Nd4jLong const* yShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { + DEBUG_KERNEL(stream, opNum); +} - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); +template <> +_CUDA_H void RandomFunction::executeCudaSingle( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void* vz, Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomTriple, float16, PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomSingle, bfloat16, + PARAMS(stateHost, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) - DEBUG_KERNEL(stream, opNum); - } + DEBUG_KERNEL(stream, opNum); +} - template <> - _CUDA_H void RandomFunction::executeCudaTriple(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, Nd4jLong const* yShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { +template <> +_CUDA_H void RandomFunction::executeCudaSingle( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void* vz, Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomSingle, double, + PARAMS(stateHost, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomTriple, bfloat16, PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) + DEBUG_KERNEL(stream, opNum); +} - DEBUG_KERNEL(stream, opNum); - } +template <> +_CUDA_H void RandomFunction::executeCudaDouble( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void* vz, + Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE( + randomDouble, float, + PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); +} +template <> +_CUDA_H void RandomFunction::executeCudaDouble( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void* vz, + Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE( + randomDouble, float16, + PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); +} +template <> +_CUDA_H void RandomFunction::executeCudaDouble( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void* vz, + Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE( + randomDouble, bfloat16, + PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - template <> - _CUDA_H void RandomFunction::executeCudaTriple(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, Nd4jLong const* yShapeBuffer, void *vz, Nd4jLong const* zShapeBuffer, void *vextraArguments) { +template <> +_CUDA_H void RandomFunction::executeCudaDouble( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void* vz, + Nd4jLong const* zShapeBuffer, void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE( + randomDouble, double, + PARAMS(stateHost, x, xShapeBuffer, z, zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto extraArguments = reinterpret_cast(vextraArguments); +template <> +_CUDA_H void RandomFunction::executeCudaTriple( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, + Nd4jLong const* yShapeBuffer, void* vz, Nd4jLong const* zShapeBuffer, + void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomTriple, float, + PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - // this macro builds bunch of IF/ELSE selectors for kernel launch - DISPATCH_SIMPLE(randomTriple, double, PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments), OPS_A(RANDOM_OPS)) +template <> +_CUDA_H void RandomFunction::executeCudaTriple( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, + Nd4jLong const* yShapeBuffer, void* vz, Nd4jLong const* zShapeBuffer, + void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomTriple, float16, + PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - DEBUG_KERNEL(stream, opNum); - } +template <> +_CUDA_H void RandomFunction::executeCudaTriple( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, + Nd4jLong const* yShapeBuffer, void* vz, Nd4jLong const* zShapeBuffer, + void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomTriple, bfloat16, + PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); +} - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES); - } +template <> +_CUDA_H void RandomFunction::executeCudaTriple( + dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, + void const* vx, Nd4jLong const* xShapeBuffer, void const* vy, + Nd4jLong const* yShapeBuffer, void* vz, Nd4jLong const* zShapeBuffer, + void* vextraArguments) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraArguments = reinterpret_cast(vextraArguments); + + // this macro builds bunch of IF/ELSE selectors for kernel launch + DISPATCH_SIMPLE(randomTriple, double, + PARAMS(stateHost, x, xShapeBuffer, y, yShapeBuffer, z, + zShapeBuffer, extraArguments), + OPS_A(RANDOM_OPS)) + + DEBUG_KERNEL(stream, opNum); } + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES); +} // namespace random +} // namespace functions diff --git a/libnd4j/include/loops/cuda/reduce/reduce_bool.cu b/libnd4j/include/loops/cuda/reduce/reduce_bool.cu index c36b407c7b8c..30cf6ce6ad4f 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_bool.cu +++ b/libnd4j/include/loops/cuda/reduce/reduce_bool.cu @@ -19,330 +19,351 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include #include - +#include +#include using namespace simdOps; //////////////////////////////////////////////////////////////////////// template __global__ void simpleReduce(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) { - - functions::reduce::ReduceBoolFunction::template transformCudaXD(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo, tadOffsets); + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets) { + functions::reduce::ReduceBoolFunction::template transformCudaXD( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, + reductionBuffer, tadOnlyShapeInfo, tadOffsets); } //////////////////////////////////////////////////////////////////////// template __global__ void simpleScalar(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo) { - - functions::reduce::ReduceBoolFunction::template execScalarCuda(x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, tadOnlyShapeInfo); + functions::reduce::ReduceBoolFunction::template execScalarCuda( + x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, + tadOnlyShapeInfo); } - namespace functions { -namespace reduce { +namespace reduce { //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceBoolFunction::aggregatePartials(void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { - - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. - - auto sPartials = reinterpret_cast(vsPartials); - auto extraParams = reinterpret_cast(vextraParams); +__device__ void ReduceBoolFunction::aggregatePartials( + void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. - Nd4jLong floorPow2 = numItems; + auto sPartials = reinterpret_cast(vsPartials); + auto extraParams = reinterpret_cast(vextraParams); - if (floorPow2 & (floorPow2 - 1)) { + Nd4jLong floorPow2 = numItems; - while (floorPow2 & (floorPow2 - 1)) - floorPow2 &= floorPow2 - 1; + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) floorPow2 &= floorPow2 - 1; - if (tid >= floorPow2) - sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], sPartials[tid], extraParams); + if (tid >= floorPow2) + sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], + sPartials[tid], extraParams); - __syncthreads(); - } + __syncthreads(); + } - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (tid < activeThreads && tid + activeThreads < numItems) - sPartials[tid] = OpType::update(sPartials[tid], sPartials[tid + activeThreads], extraParams); + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (tid < activeThreads && tid + activeThreads < numItems) + sPartials[tid] = OpType::update( + sPartials[tid], sPartials[tid + activeThreads], extraParams); - __syncthreads(); - } + __syncthreads(); + } } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceBoolFunction::transformCudaXD(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *vreductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - //shared memory space for storing intermediate results - __shared__ Z* sPartials; - __shared__ int tadLength, numTads; - __shared__ bool isPlainOutput; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - - isPlainOutput = shape::order(zShapeInfo) == 'c' && shape::elementWiseStride(zShapeInfo) == 1; - - tadLength = shape::length(tadOnlyShapeInfo); //tadLength(xShapeInfo, dimension, dimensionLength); - numTads = shape::length(xShapeInfo) / tadLength; +__device__ void ReduceBoolFunction::transformCudaXD( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + void *vreductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + // shared memory space for storing intermediate results + __shared__ Z *sPartials; + __shared__ int tadLength, numTads; + __shared__ bool isPlainOutput; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + + isPlainOutput = shape::order(zShapeInfo) == 'c' && + shape::elementWiseStride(zShapeInfo) == 1; + + tadLength = shape::length(tadOnlyShapeInfo); // tadLength(xShapeInfo, + // dimension, dimensionLength); + numTads = shape::length(xShapeInfo) / tadLength; + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + Nd4jLong tadOffsetForBlock = tadOffsets[r]; + sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); + + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = + tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[xOffset], extraParams), extraParams); } __syncthreads(); - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - Nd4jLong tadOffsetForBlock = tadOffsets[r]; - sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); + // aggregate. do NOT reduce for elements > tadLength + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), + extraParams); - for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { - - auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[xOffset], extraParams), extraParams); - } - __syncthreads(); - - // aggregate. do NOT reduce for elements > tadLength - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - - __syncthreads(); + __syncthreads(); - if (threadIdx.x == 0) - z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); - } + if (threadIdx.x == 0) + z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = + OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); + } } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceBoolFunction::execScalarCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - void *vreductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - auto tid = blockDim.x * blockIdx.x + threadIdx.x; - - //shared memory space for storing intermediate results - __shared__ Z* sPartials; - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong len; - - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - xEws = shape::elementWiseStride(xShapeInfo); - len = shape::length(xShapeInfo); - } +__device__ void ReduceBoolFunction::execScalarCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, void *vreductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + auto tid = blockDim.x * blockIdx.x + threadIdx.x; + + // shared memory space for storing intermediate results + __shared__ Z *sPartials; + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + xEws = shape::elementWiseStride(xShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + + sPartials[threadIdx.x] = OpType::startingValue(x); + + if (xEws > 0) + for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[i * xEws], extraParams), extraParams); + else + for (int i = tid; i < len; i += blockDim.x * gridDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), + extraParams); + + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, len), + extraParams); + __syncthreads(); + + if (gridDim.x > 1) { + unsigned int *tc = (unsigned int *)reductionBuffer; + __shared__ bool amLast; + + tid = threadIdx.x; + if (threadIdx.x == 0) + reductionBuffer[blockIdx.x] = + sPartials[0]; // this->postProcess(sPartials[0],len,extraParams); + + __threadfence(); __syncthreads(); - sPartials[threadIdx.x] = OpType::startingValue(x); - - if (xEws > 0) - for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[i * xEws], extraParams), extraParams); - else - for (int i = tid; i < len; i += blockDim.x * gridDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), extraParams); + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, len), extraParams); - __syncthreads(); - - if (gridDim.x > 1) { - - unsigned int *tc = (unsigned int *)reductionBuffer; - __shared__ bool amLast; - - tid = threadIdx.x; - if (threadIdx.x == 0) - reductionBuffer[blockIdx.x] = sPartials[0];//this->postProcess(sPartials[0],len,extraParams); - - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - - __syncthreads(); - - if (amLast) { - tc[16384] = 0; - sPartials[threadIdx.x] = OpType::startingValue(x); + if (amLast) { + tc[16384] = 0; + sPartials[threadIdx.x] = OpType::startingValue(x); - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], reductionBuffer[i], extraParams); + for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], reductionBuffer[i], extraParams); - __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(gridDim.x, blockDim.x), extraParams); - __syncthreads(); + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(gridDim.x, blockDim.x), + extraParams); + __syncthreads(); - if (threadIdx.x == 0) { - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } - } + if (threadIdx.x == 0) { + z[0] = OpType::postProcess(sPartials[0], len, extraParams); + } } - else { - - if (threadIdx.x == 0) { - unsigned int *tc = (unsigned *)reductionBuffer; - tc[16384] = 0; - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } + } else { + if (threadIdx.x == 0) { + unsigned int *tc = (unsigned *)reductionBuffer; + tc[16384] = 0; + z[0] = OpType::postProcess(sPartials[0], len, extraParams); } + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceBoolFunction::intermediateXD(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - nd4j_printf("Step A%i\n", -1); - - if(shape::isEmpty(hXShapeInfo)) { - - if(shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = static_cast(OpType::startingValue(reinterpret_cast(x))); - - auto res = cudaMemcpyAsync(sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceBoolFunction::intermediateXD: failed to copy temporary scalar", res); - - auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); - - // scalar assign - functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hZShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr); - sd::DebugHelper::checkErrorCode(stream, "reduceBoolDim empty(...) failed"); - } - else { - simpleReduce<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); - sd::DebugHelper::checkErrorCode(stream, "reduceBoolDim(...) failed"); - } +template +__host__ void ReduceBoolFunction::intermediateXD( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + nd4j_printf("Step A%i\n", -1); + + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = + static_cast(OpType::startingValue(reinterpret_cast(x))); + + auto res = cudaMemcpyAsync( + sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, + sizeof(Z), cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceBoolFunction::intermediateXD: failed to copy temporary " + "scalar", + res); + + auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); + + // scalar assign + functions::scalar::ScalarTransform::executeCudaShaped( + launchDims, stream, 14, z, zShapeInfo, hZShapeInfo, z, zShapeInfo, + hZShapeInfo, ptr, nullptr); + sd::DebugHelper::checkErrorCode(stream, "reduceBoolDim empty(...) failed"); + } else { + simpleReduce + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); + sd::DebugHelper::checkErrorCode(stream, "reduceBoolDim(...) failed"); + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceBoolFunction::intermediateScalar(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - if (shape::isEmpty(hXShapeInfo)) { - - if (shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = static_cast(OpType::startingValue(reinterpret_cast(x))); - - auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceBoolFunction::intermediateScalar: failed to copy resulting scalar", res); - - sd::DebugHelper::checkErrorCode(stream, "reduceBoolScalar empty(...) failed"); - - } - else { - simpleScalar<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo); - sd::DebugHelper::checkErrorCode(stream, "reduceBoolScalar(...) failed"); - } +template +__host__ void ReduceBoolFunction::intermediateScalar( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = + static_cast(OpType::startingValue(reinterpret_cast(x))); + + auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), + cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceBoolFunction::intermediateScalar: failed to copy " + "resulting scalar", + res); + + sd::DebugHelper::checkErrorCode(stream, + "reduceBoolScalar empty(...) failed"); + + } else { + simpleScalar + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionBuffer, tadOnlyShapeInfo); + sd::DebugHelper::checkErrorCode(stream, "reduceBoolScalar(...) failed"); + } } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceBoolFunction::execReduceScalar(dim3 launchDims, cudaStream_t *stream, - const int opNum, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - DISPATCH_BY_OPNUM_TT(intermediateScalar, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo), OPS_A(REDUCE_BOOL_OPS)); - sd::DebugHelper::checkErrorCode(stream, "execReduceScalarFloat(...) failed"); +_CUDA_H void ReduceBoolFunction::execReduceScalar( + dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + DISPATCH_BY_OPNUM_TT( + intermediateScalar, + PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, + zShapeInfo, hZShapeInfo, dimension, dimensionLength, + reductionBuffer, tadOnlyShapeInfo), + OPS_A(REDUCE_BOOL_OPS)); + sd::DebugHelper::checkErrorCode(stream, "execReduceScalarFloat(...) failed"); } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceBoolFunction::execReduceXD(dim3 launchDims, cudaStream_t *stream, - const int opNum, - const int rank, const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - DISPATCH_BY_OPNUM_TT(intermediateXD, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(REDUCE_BOOL_OPS)); - DEBUG_KERNEL(stream, opNum); +_CUDA_H void ReduceBoolFunction::execReduceXD( + dim3 launchDims, cudaStream_t *stream, const int opNum, const int rank, + const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_TT( + intermediateXD, + PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, + zShapeInfo, hZShapeInfo, dimension, dimensionLength, + reductionPointer, tadShapeInfo, tadOffsets), + OPS_A(REDUCE_BOOL_OPS)); + DEBUG_KERNEL(stream, opNum); } //////////////////////////////////////////////////////////////////////// template __device__ void initializeShared(X *extraParams, X **sPartials, int sMemSize) { - int sPartialsLength = sMemSize / sizeof(X); - X *sPartialsDeref = (X *) *sPartials; - for (int i = 0; i < sPartialsLength; i++) - sPartialsDeref[i] = extraParams[0]; - + int sPartialsLength = sMemSize / sizeof(X); + X *sPartialsDeref = (X *)*sPartials; + for (int i = 0; i < sPartialsLength; i++) sPartialsDeref[i] = extraParams[0]; } +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceBoolFunction, , + LIBND4J_TYPES, BOOL_TYPES); -BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceBoolFunction, , LIBND4J_TYPES, BOOL_TYPES); - -} -} - +} // namespace reduce +} // namespace functions diff --git a/libnd4j/include/loops/cuda/reduce/reduce_float.chpp b/libnd4j/include/loops/cuda/reduce/reduce_float.chpp index ed13b8331807..5391c65623bb 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_float.chpp +++ b/libnd4j/include/loops/cuda/reduce/reduce_float.chpp @@ -18,321 +18,346 @@ // @author raver119@gmail.com // -#include +#include +#include #include -#include +#include +#include +#include #include #include -#include -#include -#include #include - -#include -#include +#include +#include using namespace simdOps; //////////////////////////////////////////////////////////////////////// template __global__ void simpleReduce(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) { - - functions::reduce::ReduceFloatFunction::template transformCudaXD(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo, tadOffsets); + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets) { + functions::reduce::ReduceFloatFunction::template transformCudaXD< + OpType>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionBuffer, tadOnlyShapeInfo, tadOffsets); } //////////////////////////////////////////////////////////////////////// template __global__ void simpleScalar(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo) { - - functions::reduce::ReduceFloatFunction::template execScalarCuda(x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, tadOnlyShapeInfo); + functions::reduce::ReduceFloatFunction::template execScalarCuda( + x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, + tadOnlyShapeInfo); } namespace functions { -namespace reduce { +namespace reduce { //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceFloatFunction::aggregatePartials(void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { +__device__ void ReduceFloatFunction::aggregatePartials( + void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. + auto sPartials = reinterpret_cast(vsPartials); + auto extraParams = reinterpret_cast(vextraParams); - auto sPartials = reinterpret_cast(vsPartials); - auto extraParams = reinterpret_cast(vextraParams); + Nd4jLong floorPow2 = numItems; - Nd4jLong floorPow2 = numItems; + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) floorPow2 &= floorPow2 - 1; - if (floorPow2 & (floorPow2 - 1)) { + if (tid >= floorPow2) + sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], + sPartials[tid], extraParams); - while (floorPow2 & (floorPow2 - 1)) - floorPow2 &= floorPow2 - 1; - - if (tid >= floorPow2) - sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], sPartials[tid], extraParams); - - __syncthreads(); - } + __syncthreads(); + } - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (tid < activeThreads && tid + activeThreads < numItems) - sPartials[tid] = OpType::update(sPartials[tid], sPartials[tid + activeThreads], extraParams); + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (tid < activeThreads && tid + activeThreads < numItems) + sPartials[tid] = OpType::update( + sPartials[tid], sPartials[tid + activeThreads], extraParams); - __syncthreads(); - } + __syncthreads(); + } } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceFloatFunction::transformCudaXD(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *vreductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - //shared memory space for storing intermediate results - __shared__ Z* sPartials; - __shared__ int tadLength, numTads; - __shared__ bool isPlainOutput; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - - isPlainOutput = shape::order(zShapeInfo) == 'c' && shape::elementWiseStride(zShapeInfo) == 1; - - tadLength = shape::length(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; +__device__ void ReduceFloatFunction::transformCudaXD( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + void *vreductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + // shared memory space for storing intermediate results + __shared__ Z *sPartials; + __shared__ int tadLength, numTads; + __shared__ bool isPlainOutput; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + + isPlainOutput = shape::order(zShapeInfo) == 'c' && + shape::elementWiseStride(zShapeInfo) == 1; + + tadLength = shape::length(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto tadOffsetForBlock = tadOffsets[r]; + sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); + + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = + tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[xOffset], extraParams), extraParams); } __syncthreads(); - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - auto tadOffsetForBlock = tadOffsets[r]; - sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); - - for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { - auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[xOffset], extraParams), extraParams); - } - __syncthreads(); - - // aggregate. do NOT reduce for elements > tadLength - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - __syncthreads(); + // aggregate. do NOT reduce for elements > tadLength + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), + extraParams); + __syncthreads(); - if (threadIdx.x == 0) - z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); - } + if (threadIdx.x == 0) + z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = + OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); + } } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceFloatFunction::execScalarCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - void *vreductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - auto tid = blockDim.x * blockIdx.x + threadIdx.x; - - //shared memory space for storing intermediate results - __shared__ Z* sPartials; - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong len; - - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - xEws = shape::elementWiseStride(xShapeInfo); - len = shape::length(xShapeInfo); - } +__device__ void ReduceFloatFunction::execScalarCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, void *vreductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + auto tid = blockDim.x * blockIdx.x + threadIdx.x; + + // shared memory space for storing intermediate results + __shared__ Z *sPartials; + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + xEws = shape::elementWiseStride(xShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + + sPartials[threadIdx.x] = OpType::startingValue(x); + + if (xEws > 0) + for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[i * xEws], extraParams), extraParams); + else + for (int i = tid; i < len; i += blockDim.x * gridDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), + extraParams); + + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, len), + extraParams); + __syncthreads(); + + if (gridDim.x > 1) { + unsigned int *tc = (unsigned int *)reductionBuffer; + __shared__ bool amLast; + + tid = threadIdx.x; + if (threadIdx.x == 0) + reductionBuffer[blockIdx.x] = + sPartials[0]; // this->postProcess(sPartials[0],len,extraParams); + + __threadfence(); __syncthreads(); - sPartials[threadIdx.x] = OpType::startingValue(x); - - if (xEws > 0) - for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[i * xEws], extraParams), extraParams); - else - for (int i = tid; i < len; i += blockDim.x * gridDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), extraParams); + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, len), extraParams); - __syncthreads(); - - if (gridDim.x > 1) { - - unsigned int *tc = (unsigned int *)reductionBuffer; - __shared__ bool amLast; - - tid = threadIdx.x; - if (threadIdx.x == 0) - reductionBuffer[blockIdx.x] = sPartials[0];//this->postProcess(sPartials[0],len,extraParams); - - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - __syncthreads(); + if (amLast) { + tc[16384] = 0; + sPartials[threadIdx.x] = OpType::startingValue(x); - if (amLast) { + for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], reductionBuffer[i], extraParams); - tc[16384] = 0; - sPartials[threadIdx.x] = OpType::startingValue(x); + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(gridDim.x, blockDim.x), + extraParams); + __syncthreads(); - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], reductionBuffer[i], extraParams); - - __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(gridDim.x, blockDim.x), extraParams); - __syncthreads(); - - if (threadIdx.x == 0) { - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } - } + if (threadIdx.x == 0) { + z[0] = OpType::postProcess(sPartials[0], len, extraParams); + } } - else { - - if (threadIdx.x == 0) { - unsigned int *tc = (unsigned *)reductionBuffer; - tc[16384] = 0; - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } + } else { + if (threadIdx.x == 0) { + unsigned int *tc = (unsigned *)reductionBuffer; + tc[16384] = 0; + z[0] = OpType::postProcess(sPartials[0], len, extraParams); } + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceFloatFunction::intermediateXD(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShape, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - if(shape::isEmpty(hXShapeInfo)) { - - if(shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = std::is_same>::value ? sd::DataTypeUtils::nanOrZero() : static_cast(OpType::startingValue(reinterpret_cast(x))); - auto res = cudaMemcpyAsync(sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceFloatFunction::intermediateXD: failed to copy temporary scalar", res); - - auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); - - // scalar assign - functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShape, hZShapeInfo, z, zShape, hZShapeInfo, ptr, nullptr); - } - else { - simpleReduce<<>>(x, xShape, extraParams, z, zShape, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); - } +template +__host__ void ReduceFloatFunction::intermediateXD( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShape, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShape, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = std::is_same>::value + ? sd::DataTypeUtils::nanOrZero() + : static_cast(OpType::startingValue( + reinterpret_cast(x))); + auto res = cudaMemcpyAsync( + sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, + sizeof(Z), cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceFloatFunction::intermediateXD: failed to copy temporary " + "scalar", + res); + + auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); + + // scalar assign + functions::scalar::ScalarTransform::executeCudaShaped( + launchDims, stream, 14, z, zShape, hZShapeInfo, z, zShape, hZShapeInfo, + ptr, nullptr); + } else { + simpleReduce + <<>>( + x, xShape, extraParams, z, zShape, dimension, dimensionLength, + reductionPointer, tadShapeInfo, tadOffsets); + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceFloatFunction::intermediateScalar(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - if (shape::isEmpty(hXShapeInfo)) { - - if (shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = std::is_same>::value ? sd::DataTypeUtils::nanOrZero() : static_cast(OpType::startingValue(reinterpret_cast(x))); - - auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceFloatFunction::intermediateScalar: failed to copy resulting scalar", res); - } - else { - simpleScalar <<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo); - } +template +__host__ void ReduceFloatFunction::intermediateScalar( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = std::is_same>::value + ? sd::DataTypeUtils::nanOrZero() + : static_cast(OpType::startingValue( + reinterpret_cast(x))); + + auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), + cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceFloatFunction::intermediateScalar: failed to copy " + "resulting scalar", + res); + } else { + simpleScalar + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionBuffer, tadOnlyShapeInfo); + } } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceFloatFunction::execReduceScalar(dim3 launchDims, cudaStream_t *stream, - const int opNum, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - DISPATCH_BY_OPNUM_TT(intermediateScalar, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo), OPS_A(REDUCE_FLOAT_OPS)); - sd::DebugHelper::checkErrorCode(stream, "execReduceScalarFloat(...) failed"); +_CUDA_H void ReduceFloatFunction::execReduceScalar( + dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + DISPATCH_BY_OPNUM_TT( + intermediateScalar, + PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, + zShapeInfo, hZShapeInfo, dimension, dimensionLength, + reductionBuffer, tadOnlyShapeInfo), + OPS_A(REDUCE_FLOAT_OPS)); + sd::DebugHelper::checkErrorCode(stream, "execReduceScalarFloat(...) failed"); } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceFloatFunction::execReduceXD(dim3 launchDims, cudaStream_t *stream, - const int opNum, - const int rank, const void *x, const Nd4jLong *xShape, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShape, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - DISPATCH_BY_OPNUM_TT(intermediateXD, PARAMS(launchDims, stream, x, xShape, hXShapeInfo, extraParams, z, zShape, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(REDUCE_FLOAT_OPS)); - DEBUG_KERNEL(stream, opNum); +_CUDA_H void ReduceFloatFunction::execReduceXD( + dim3 launchDims, cudaStream_t *stream, const int opNum, const int rank, + const void *x, const Nd4jLong *xShape, const Nd4jLong *hXShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShape, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_TT( + intermediateXD, + PARAMS(launchDims, stream, x, xShape, hXShapeInfo, extraParams, z, zShape, + hZShapeInfo, dimension, dimensionLength, reductionPointer, + tadShapeInfo, tadOffsets), + OPS_A(REDUCE_FLOAT_OPS)); + DEBUG_KERNEL(stream, opNum); } //////////////////////////////////////////////////////////////////////// template __device__ void initializeShared(X *extraParams, X **sPartials, int sMemSize) { - int sPartialsLength = sMemSize / sizeof(X); - X *sPartialsDeref = (X *) *sPartials; - for (int i = 0; i < sPartialsLength; i++) - sPartialsDeref[i] = extraParams[0]; - + int sPartialsLength = sMemSize / sizeof(X); + X *sPartialsDeref = (X *)*sPartials; + for (int i = 0; i < sPartialsLength; i++) sPartialsDeref[i] = extraParams[0]; } +// BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , +// LIBND4J_TYPES, FLOAT_TYPES); -//BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES); - -} -} - +} // namespace reduce +} // namespace functions diff --git a/libnd4j/include/loops/cuda/reduce/reduce_long.cu b/libnd4j/include/loops/cuda/reduce/reduce_long.cu index 0b0b6e973355..c13ddfd304b4 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_long.cu +++ b/libnd4j/include/loops/cuda/reduce/reduce_long.cu @@ -19,342 +19,365 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include #include - +#include +#include using namespace simdOps; //////////////////////////////////////////////////////////////////////// template -__device__ void reduceSimpleGeneric(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) { - - functions::reduce::ReduceLongFunction::template transformCudaXD(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo, tadOffsets); +__device__ void reduceSimpleGeneric(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets) { + functions::reduce::ReduceLongFunction::template transformCudaXD( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, + reductionBuffer, tadOnlyShapeInfo, tadOffsets); } //////////////////////////////////////////////////////////////////////// template __device__ void reduceScalarGeneric(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo) { - - functions::reduce::ReduceLongFunction::template execScalarCuda(x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, tadOnlyShapeInfo); + functions::reduce::ReduceLongFunction::template execScalarCuda( + x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, + tadOnlyShapeInfo); } //////////////////////////////////////////////////////////////////////// template __global__ void simpleReduce(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) { - - reduceSimpleGeneric(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo, tadOffsets); + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets) { + reduceSimpleGeneric(x, xShapeInfo, extraParams, z, zShapeInfo, + dimension, dimensionLength, reductionBuffer, + tadOnlyShapeInfo, tadOffsets); } //////////////////////////////////////////////////////////////////////// template __global__ void simpleScalar(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - reduceScalarGeneric(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo); + void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + reduceScalarGeneric(x, xShapeInfo, extraParams, z, zShapeInfo, + dimension, dimensionLength, reductionBuffer, + tadOnlyShapeInfo); } namespace functions { -namespace reduce { +namespace reduce { //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceLongFunction::aggregatePartials(void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { - - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. +__device__ void ReduceLongFunction::aggregatePartials( + void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. - auto sPartials = reinterpret_cast(vsPartials); - auto extraParams = reinterpret_cast(vextraParams); + auto sPartials = reinterpret_cast(vsPartials); + auto extraParams = reinterpret_cast(vextraParams); - Nd4jLong floorPow2 = numItems; + Nd4jLong floorPow2 = numItems; - if (floorPow2 & (floorPow2 - 1)) { + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) floorPow2 &= floorPow2 - 1; - while (floorPow2 & (floorPow2 - 1)) - floorPow2 &= floorPow2 - 1; + if (tid >= floorPow2) + sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], + sPartials[tid], extraParams); - if (tid >= floorPow2) - sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], sPartials[tid], extraParams); - - __syncthreads(); - } + __syncthreads(); + } - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (tid < activeThreads && tid + activeThreads < numItems) - sPartials[tid] = OpType::update(sPartials[tid], sPartials[tid + activeThreads], extraParams); + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (tid < activeThreads && tid + activeThreads < numItems) + sPartials[tid] = OpType::update( + sPartials[tid], sPartials[tid + activeThreads], extraParams); - __syncthreads(); - } + __syncthreads(); + } } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceLongFunction::transformCudaXD(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - void *vreductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - //shared memory space for storing intermediate results - __shared__ Z* sPartials; - __shared__ int tadLength, numTads; - __shared__ bool isPlainOutput; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - - isPlainOutput = shape::order(zShapeInfo) == 'c' && shape::elementWiseStride(zShapeInfo) == 1; - - tadLength = shape::length(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; +__device__ void ReduceLongFunction::transformCudaXD( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + void *vreductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + // shared memory space for storing intermediate results + __shared__ Z *sPartials; + __shared__ int tadLength, numTads; + __shared__ bool isPlainOutput; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + + isPlainOutput = shape::order(zShapeInfo) == 'c' && + shape::elementWiseStride(zShapeInfo) == 1; + + tadLength = shape::length(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + Nd4jLong tadOffsetForBlock = tadOffsets[r]; + sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); + + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = + tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[xOffset], extraParams), extraParams); } __syncthreads(); - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - Nd4jLong tadOffsetForBlock = tadOffsets[r]; - sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); - - for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { - auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[xOffset], extraParams), extraParams); - } - __syncthreads(); - - // aggregate. do NOT reduce for elements > tadLength - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - __syncthreads(); + // aggregate. do NOT reduce for elements > tadLength + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), + extraParams); + __syncthreads(); - if (threadIdx.x == 0) - z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); - } + if (threadIdx.x == 0) + z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = + OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); + } } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceLongFunction::execScalarCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - void *vreductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - auto tid = blockDim.x * blockIdx.x + threadIdx.x; - - //shared memory space for storing intermediate results - __shared__ Z* sPartials; - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong len; - - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - xEws = shape::elementWiseStride(xShapeInfo); - len = shape::length(xShapeInfo); - } +__device__ void ReduceLongFunction::execScalarCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, void *vreductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + auto tid = blockDim.x * blockIdx.x + threadIdx.x; + + // shared memory space for storing intermediate results + __shared__ Z *sPartials; + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + xEws = shape::elementWiseStride(xShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + + sPartials[threadIdx.x] = OpType::startingValue(x); + + if (xEws > 0) + for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[i * xEws], extraParams), extraParams); + else + for (int i = tid; i < len; i += blockDim.x * gridDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), + extraParams); + + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, len), + extraParams); + __syncthreads(); + + if (gridDim.x > 1) { + auto tc = reinterpret_cast(reductionBuffer); + __shared__ bool amLast; + + tid = threadIdx.x; + if (threadIdx.x == 0) + reductionBuffer[blockIdx.x] = + sPartials[0]; // this->postProcess(sPartials[0],len,extraParams); + + __threadfence(); __syncthreads(); - sPartials[threadIdx.x] = OpType::startingValue(x); - - if (xEws > 0) - for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[i * xEws], extraParams), extraParams); - else - for (int i = tid; i < len; i += blockDim.x * gridDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), extraParams); + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, len), extraParams); - __syncthreads(); - - if (gridDim.x > 1) { - - auto tc = reinterpret_cast(reductionBuffer); - __shared__ bool amLast; - - tid = threadIdx.x; - if (threadIdx.x == 0) - reductionBuffer[blockIdx.x] = sPartials[0];//this->postProcess(sPartials[0],len,extraParams); - - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - __syncthreads(); + if (amLast) { + tc[16384] = 0; + sPartials[threadIdx.x] = OpType::startingValue(x); - if (amLast) { - tc[16384] = 0; - sPartials[threadIdx.x] = OpType::startingValue(x); + for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], reductionBuffer[i], extraParams); - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], reductionBuffer[i], extraParams); + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(gridDim.x, blockDim.x), + extraParams); + __syncthreads(); - __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(gridDim.x, blockDim.x), extraParams); - __syncthreads(); - - if (threadIdx.x == 0) { - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } - } + if (threadIdx.x == 0) { + z[0] = OpType::postProcess(sPartials[0], len, extraParams); + } } - else { - - if (threadIdx.x == 0) { - auto tc = reinterpret_cast(reductionBuffer); - tc[16384] = 0; - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } + } else { + if (threadIdx.x == 0) { + auto tc = reinterpret_cast(reductionBuffer); + tc[16384] = 0; + z[0] = OpType::postProcess(sPartials[0], len, extraParams); } + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceLongFunction::intermediateXD(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - if(shape::isEmpty(hXShapeInfo)) { - - if(shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = static_cast(OpType::startingValue(reinterpret_cast(x))); - - auto res = cudaMemcpyAsync(sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceLongFunction::intermediateXD: failed to copy temporary scalar", res); - - auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); - - // scalar assign - functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr); - } - else { - simpleReduce<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); - } +template +__host__ void ReduceLongFunction::intermediateXD( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = + static_cast(OpType::startingValue(reinterpret_cast(x))); + + auto res = cudaMemcpyAsync( + sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, + sizeof(Z), cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceLongFunction::intermediateXD: failed to copy temporary " + "scalar", + res); + + auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); + + // scalar assign + functions::scalar::ScalarTransform::executeCudaShaped( + launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, + hZShapeInfo, ptr, nullptr); + } else { + simpleReduce + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceLongFunction::intermediateScalar(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - if (shape::isEmpty(hXShapeInfo)) { - - if (shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = static_cast(OpType::startingValue(reinterpret_cast(x))); - - auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceLongFunction::intermediateScalar: failed to copy resulting scalar", res); - } - else { - simpleScalar<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo); - } +template +__host__ void ReduceLongFunction::intermediateScalar( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = + static_cast(OpType::startingValue(reinterpret_cast(x))); + + auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), + cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceLongFunction::intermediateScalar: failed to copy " + "resulting scalar", + res); + } else { + simpleScalar + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionBuffer, tadOnlyShapeInfo); + } } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceLongFunction::execReduceScalar(dim3 launchDims, cudaStream_t *stream, - const int opNum, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong* hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong* hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo) { - - DISPATCH_BY_OPNUM_TT(intermediateScalar, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo), OPS_A(REDUCE_LONG_OPS)); - sd::DebugHelper::checkErrorCode(stream, "execReduceScalarFloat(...) failed"); +_CUDA_H void ReduceLongFunction::execReduceScalar( + dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, + void *z, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, + int *dimension, int dimensionLength, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo) { + DISPATCH_BY_OPNUM_TT( + intermediateScalar, + PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, + zShapeInfo, hZShapeInfo, dimension, dimensionLength, + reductionBuffer, tadOnlyShapeInfo), + OPS_A(REDUCE_LONG_OPS)); + sd::DebugHelper::checkErrorCode(stream, "execReduceScalarFloat(...) failed"); } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceLongFunction::execReduceXD(dim3 launchDims, cudaStream_t *stream, - const int opNum, - int rank, const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong* hXShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, const Nd4jLong* hZShapeInfo, - int *dimension, int dimensionLength, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - DISPATCH_BY_OPNUM_TT(intermediateXD, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(REDUCE_LONG_OPS)); - DEBUG_KERNEL(stream, opNum); +_CUDA_H void ReduceLongFunction::execReduceXD( + dim3 launchDims, cudaStream_t *stream, const int opNum, int rank, + const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_TT( + intermediateXD, + PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, + zShapeInfo, hZShapeInfo, dimension, dimensionLength, + reductionPointer, tadShapeInfo, tadOffsets), + OPS_A(REDUCE_LONG_OPS)); + DEBUG_KERNEL(stream, opNum); } //////////////////////////////////////////////////////////////////////// template __device__ void initializeShared(X *extraParams, X **sPartials, int sMemSize) { - int sPartialsLength = sMemSize / sizeof(X); - X *sPartialsDeref = (X *) *sPartials; - for (int i = 0; i < sPartialsLength; i++) - sPartialsDeref[i] = extraParams[0]; - + int sPartialsLength = sMemSize / sizeof(X); + X *sPartialsDeref = (X *)*sPartials; + for (int i = 0; i < sPartialsLength; i++) sPartialsDeref[i] = extraParams[0]; } +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceLongFunction, , + LIBND4J_TYPES, LONG_TYPES); -BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceLongFunction, , LIBND4J_TYPES, LONG_TYPES); - -} -} - +} // namespace reduce +} // namespace functions diff --git a/libnd4j/include/loops/cuda/reduce/reduce_same.cu b/libnd4j/include/loops/cuda/reduce/reduce_same.cu index 7866209cbb6c..cee846ee80c3 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_same.cu +++ b/libnd4j/include/loops/cuda/reduce/reduce_same.cu @@ -19,310 +19,364 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include #include - +#include +#include using namespace simdOps; - //////////////////////////////////////////////////////////////////////// template __global__ void simpleReduce(void const* x, Nd4jLong const* xShapeInfo, - void *extraParams, - void *z, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - - functions::reduce::ReduceSameFunction::template transformCudaXD(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo, tadOffsets); + void* extraParams, void* z, + Nd4jLong const* zShapeInfo, int* dimension, + int dimensionLength, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets) { + functions::reduce::ReduceSameFunction::template transformCudaXD( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, + reductionBuffer, tadOnlyShapeInfo, tadOffsets); } //////////////////////////////////////////////////////////////////////// template __global__ void simpleScalar(void const* x, Nd4jLong const* xShapeInfo, - void *extraParams, - void *z, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo) { - - functions::reduce::ReduceSameFunction::template execScalarCuda(x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, tadOnlyShapeInfo); + void* extraParams, void* z, + Nd4jLong const* zShapeInfo, int* dimension, + int dimensionLength, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + functions::reduce::ReduceSameFunction::template execScalarCuda( + x, xShapeInfo, extraParams, z, zShapeInfo, reductionBuffer, + tadOnlyShapeInfo); } - namespace functions { -namespace reduce { +namespace reduce { //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceSameFunction::aggregatePartials(void *vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { +__device__ void ReduceSameFunction::aggregatePartials(void* vsPartials, + Nd4jLong tid, + Nd4jLong numItems, + void* vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. + auto sPartials = static_cast(vsPartials); + auto extraParams = static_cast(vextraParams); - auto sPartials = static_cast(vsPartials); - auto extraParams = static_cast(vextraParams); + Nd4jLong floorPow2 = numItems; - Nd4jLong floorPow2 = numItems; + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) floorPow2 &= floorPow2 - 1; - if (floorPow2 & (floorPow2 - 1)) { + if (tid >= floorPow2) + sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], + sPartials[tid], extraParams); - while (floorPow2 & (floorPow2 - 1)) - floorPow2 &= floorPow2 - 1; - - if (tid >= floorPow2) - sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], sPartials[tid], extraParams); - - __syncthreads(); - } + __syncthreads(); + } - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (tid < activeThreads && tid + activeThreads < numItems) - sPartials[tid] = OpType::update(sPartials[tid], sPartials[tid + activeThreads], extraParams); + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (tid < activeThreads && tid + activeThreads < numItems) + sPartials[tid] = OpType::update( + sPartials[tid], sPartials[tid + activeThreads], extraParams); - __syncthreads(); - } + __syncthreads(); + } } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceSameFunction::transformCudaXD( void const* vx, Nd4jLong const* xShapeInfo, - void *vextraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - void *vreductionBuffer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - if (OpType::requiresSpecialAccumulation) { - OpType::execSpecialCuda(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo, tadOffsets); - return; - } - - //shared memory space for storing intermediate results - __shared__ X* sPartials; - - __shared__ int tadLength, tadRank, numTads; - __shared__ Nd4jLong *tadShape, *tadStride; - __shared__ bool isPlainOutput; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - - isPlainOutput = shape::order(zShapeInfo) == 'c' && shape::elementWiseStride(zShapeInfo) == 1; - - tadLength = shape::length(tadOnlyShapeInfo); - tadRank = shape::rank(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; - tadShape = shape::shapeOf(tadOnlyShapeInfo); - tadStride = shape::stride(tadOnlyShapeInfo); +__device__ void ReduceSameFunction::transformCudaXD( + void const* vx, Nd4jLong const* xShapeInfo, void* vextraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + void* vreductionBuffer, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + if (OpType::requiresSpecialAccumulation) { + OpType::execSpecialCuda(x, xShapeInfo, extraParams, z, zShapeInfo, + dimension, dimensionLength, reductionBuffer, + tadOnlyShapeInfo, tadOffsets); + return; + } + + // shared memory space for storing intermediate results + __shared__ X* sPartials; + + __shared__ int tadLength, tadRank, numTads; + __shared__ Nd4jLong *tadShape, *tadStride; + __shared__ bool isPlainOutput; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + + isPlainOutput = shape::order(zShapeInfo) == 'c' && + shape::elementWiseStride(zShapeInfo) == 1; + + tadLength = shape::length(tadOnlyShapeInfo); + tadRank = shape::rank(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + tadShape = shape::shapeOf(tadOnlyShapeInfo); + tadStride = shape::stride(tadOnlyShapeInfo); + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + Nd4jLong tadOffsetForBlock = tadOffsets[r]; + sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); + + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = + tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[xOffset], extraParams), extraParams); } __syncthreads(); - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - - Nd4jLong tadOffsetForBlock = tadOffsets[r]; - sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock); - - for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { - auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[xOffset], extraParams), extraParams); - } - __syncthreads(); - - // aggregate. do NOT reduce for elements > tadLength - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - __syncthreads(); + // aggregate. do NOT reduce for elements > tadLength + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), + extraParams); + __syncthreads(); - if (threadIdx.x == 0) - z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); - } + if (threadIdx.x == 0) + z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = + OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams); + } } //////////////////////////////////////////////////////////////////////// template -__device__ void ReduceSameFunction::execScalarCudaLegacy(int opNum, void const* vx, Nd4jLong const* xShapeInfo, - void *vextraParams, - void *vz, Nd4jLong const* zShapeInfo, - void *vreductionBuffer, - Nd4jLong const* tadOnlyShapeInfo) { - DISPATCH_BY_OPNUM_T(execScalarCuda, PARAMS(vx, xShapeInfo, vextraParams, vz, zShapeInfo, vreductionBuffer, tadOnlyShapeInfo), REDUCE_SAME_OPS); +__device__ void ReduceSameFunction::execScalarCudaLegacy( + int opNum, void const* vx, Nd4jLong const* xShapeInfo, void* vextraParams, + void* vz, Nd4jLong const* zShapeInfo, void* vreductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + DISPATCH_BY_OPNUM_T(execScalarCuda, + PARAMS(vx, xShapeInfo, vextraParams, vz, zShapeInfo, + vreductionBuffer, tadOnlyShapeInfo), + REDUCE_SAME_OPS); } //////////////////////////////////////////////////////////////////////// template template -__device__ void ReduceSameFunction::execScalarCuda(void const* vx, Nd4jLong const* xShapeInfo, - void *vextraParams, - void * vz, Nd4jLong const* zShapeInfo, - void *vreductionBuffer, - Nd4jLong const* tadOnlyShapeInfo) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - auto reductionBuffer = reinterpret_cast(vreductionBuffer); - - auto tid = blockDim.x * blockIdx.x + threadIdx.x; - - //shared memory space for storing intermediate results - __shared__ X* sPartials; - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong len; - - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - xEws = shape::elementWiseStride(xShapeInfo); - len = shape::length(xShapeInfo); - } +__device__ void ReduceSameFunction::execScalarCuda( + void const* vx, Nd4jLong const* xShapeInfo, void* vextraParams, void* vz, + Nd4jLong const* zShapeInfo, void* vreductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + auto reductionBuffer = reinterpret_cast(vreductionBuffer); + + auto tid = blockDim.x * blockIdx.x + threadIdx.x; + + // shared memory space for storing intermediate results + __shared__ X* sPartials; + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + xEws = shape::elementWiseStride(xShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + sPartials[threadIdx.x] = OpType::startingValue(x); + + if (xEws > 0) + for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::op(x[i * xEws], extraParams), extraParams); + else + for (int i = tid; i < len; i += blockDim.x * gridDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), + extraParams); + + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, len), + extraParams); + __syncthreads(); + + if (gridDim.x > 1) { + unsigned int* tc = (unsigned int*)reductionBuffer; + __shared__ bool amLast; + + tid = threadIdx.x; + if (threadIdx.x == 0) + reductionBuffer[blockIdx.x] = + sPartials[0]; // this->postProcess(sPartials[0],len,extraParams); + + __threadfence(); __syncthreads(); - sPartials[threadIdx.x] = OpType::startingValue(x); - if (xEws > 0) - for (int i = tid; i < len; i += (blockDim.x * gridDim.x)) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[i * xEws], extraParams), extraParams); - else - for (int i = tid; i < len; i += blockDim.x * gridDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), extraParams); + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, len), extraParams); - __syncthreads(); - - if (gridDim.x > 1) { - - unsigned int *tc = (unsigned int *)reductionBuffer; - __shared__ bool amLast; - - tid = threadIdx.x; - if (threadIdx.x == 0) - reductionBuffer[blockIdx.x] = sPartials[0];//this->postProcess(sPartials[0],len,extraParams); - - __threadfence(); - __syncthreads(); - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } + if (amLast) { + tc[16384] = 0; + sPartials[threadIdx.x] = OpType::startingValue(x); - __syncthreads(); + for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], reductionBuffer[i], extraParams); - if (amLast) { - tc[16384] = 0; - sPartials[threadIdx.x] = OpType::startingValue(x); + __syncthreads(); + aggregatePartials(sPartials, threadIdx.x, + sd::math::nd4j_min(gridDim.x, blockDim.x), + extraParams); + __syncthreads(); - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], reductionBuffer[i], extraParams); - - __syncthreads(); - aggregatePartials(sPartials, threadIdx.x, sd::math::nd4j_min(gridDim.x, blockDim.x), extraParams); - __syncthreads(); - - if (threadIdx.x == 0) { - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } - } + if (threadIdx.x == 0) { + z[0] = OpType::postProcess(sPartials[0], len, extraParams); + } } - else { - - if (threadIdx.x == 0) { - auto tc = reinterpret_cast(reductionBuffer); - tc[16384] = 0; - z[0] = OpType::postProcess(sPartials[0], len, extraParams); - } + } else { + if (threadIdx.x == 0) { + auto tc = reinterpret_cast(reductionBuffer); + tc[16384] = 0; + z[0] = OpType::postProcess(sPartials[0], len, extraParams); } + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceSameFunction::intermediateXD(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - if(shape::isEmpty(hXShapeInfo)) { - - if(shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = static_cast(OpType::startingValue(reinterpret_cast(x))); - - auto res = cudaMemcpyAsync(sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, sizeof(X), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceSameFunction::intermediateXD: failed to copy temporary scalar", res); - - auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); - - // scalar assign - functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr); - } - else { - simpleReduce<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); - } +template +__host__ void ReduceSameFunction::intermediateXD( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void* extraParams, + void* z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, + int* dimension, int dimensionLength, void* reductionPointer, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = + static_cast(OpType::startingValue(reinterpret_cast(x))); + + auto res = cudaMemcpyAsync( + sd::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, + sizeof(X), cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceSameFunction::intermediateXD: failed to copy temporary " + "scalar", + res); + + auto ptr = sd::LaunchContext::defaultContext()->getScalarPointer(); + + // scalar assign + functions::scalar::ScalarTransform::executeCudaShaped( + launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, + hZShapeInfo, ptr, nullptr); + } else { + simpleReduce + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); + } } //////////////////////////////////////////////////////////////////////// template -template -__host__ void ReduceSameFunction::intermediateScalar(dim3 launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo) { - - if (shape::isEmpty(hXShapeInfo)) { - - if (shape::isEmpty(hZShapeInfo)) - return; - - const auto startingVal = static_cast(OpType::startingValue(reinterpret_cast(x))); - - auto res = cudaMemcpyAsync(z, &startingVal, sizeof(X), cudaMemcpyHostToDevice, *stream); - if (res != 0) - throw sd::cuda_exception::build("ReduceSameFunction::intermediateScalar: failed to copy resulting scalar", res); - } - else { - simpleScalar<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo); - } +template +__host__ void ReduceSameFunction::intermediateScalar( + dim3 launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void* extraParams, + void* z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, + int* dimension, int dimensionLength, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + if (shape::isEmpty(hXShapeInfo)) { + if (shape::isEmpty(hZShapeInfo)) return; + + const auto startingVal = + static_cast(OpType::startingValue(reinterpret_cast(x))); + + auto res = cudaMemcpyAsync(z, &startingVal, sizeof(X), + cudaMemcpyHostToDevice, *stream); + if (res != 0) + throw sd::cuda_exception::build( + "ReduceSameFunction::intermediateScalar: failed to copy resulting " + "scalar", + res); + } else { + simpleScalar + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, reductionBuffer, tadOnlyShapeInfo); + } } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceSameFunction::execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo) { - - DISPATCH_BY_OPNUM_T(intermediateScalar, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo), REDUCE_SAME_OPS); - sd::DebugHelper::checkErrorCode(stream, "execReduceScalarSame(...) failed"); +_CUDA_H void ReduceSameFunction::execReduceScalar( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void* extraParams, + void* z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, + int* dimension, int dimensionLength, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + DISPATCH_BY_OPNUM_T( + intermediateScalar, + PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, + zShapeInfo, hZShapeInfo, dimension, dimensionLength, + reductionBuffer, tadOnlyShapeInfo), + REDUCE_SAME_OPS); + sd::DebugHelper::checkErrorCode(stream, "execReduceScalarSame(...) failed"); } //////////////////////////////////////////////////////////////////////// template -_CUDA_H void ReduceSameFunction::execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void const* x, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - DISPATCH_BY_OPNUM_T(intermediateXD, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), REDUCE_SAME_OPS); - DEBUG_KERNEL(stream, opNum); +_CUDA_H void ReduceSameFunction::execReduceXD( + dim3 launchDims, cudaStream_t* stream, int opNum, int rank, void const* x, + Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void* extraParams, + void* z, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, + int* dimension, int dimensionLength, void* reductionPointer, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { + DISPATCH_BY_OPNUM_T( + intermediateXD, + PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, + zShapeInfo, hZShapeInfo, dimension, dimensionLength, + reductionPointer, tadShapeInfo, tadOffsets), + REDUCE_SAME_OPS); + DEBUG_KERNEL(stream, opNum); } //////////////////////////////////////////////////////////////////////// template -__device__ void initializeShared(X *extraParams, X **sPartials, int sMemSize) { - int sPartialsLength = sMemSize / sizeof(X); - X *sPartialsDeref = (X *) *sPartials; - for (int i = 0; i < sPartialsLength; i++) - sPartialsDeref[i] = extraParams[0]; - +__device__ void initializeShared(X* extraParams, X** sPartials, int sMemSize) { + int sPartialsLength = sMemSize / sizeof(X); + X* sPartialsDeref = (X*)*sPartials; + for (int i = 0; i < sPartialsLength; i++) sPartialsDeref[i] = extraParams[0]; } +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ReduceSameFunction, , + LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ReduceSameFunction, , LIBND4J_TYPES); - -} -} \ No newline at end of file +} // namespace reduce +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/reduce3.chpp b/libnd4j/include/loops/cuda/reduce3.chpp index 31ebc7eebe42..1549a0aeefbe 100644 --- a/libnd4j/include/loops/cuda/reduce3.chpp +++ b/libnd4j/include/loops/cuda/reduce3.chpp @@ -17,546 +17,548 @@ // @author raver119@gmail.com // @author Yurii Shyrma (iuriish@yahoo.com), created on 19.11.2018 - -#include -#include #include -#include +#include #include +#include +#include using namespace simdOps; namespace functions { -namespace reduce3 { +namespace reduce3 { //////////////////////////////////////////////////////////////////////// template -__global__ void execScalarGeneric(const int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int* allocationPointer, - void *reductionBuffer, - Nd4jLong const* tadOnlyShapeInfo) { - - Reduce3::execScalarCuda(opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, allocationPointer, reductionBuffer, tadOnlyShapeInfo); +__global__ void execScalarGeneric(const int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, + void* vz, Nd4jLong const* zShapeInfo, + int* allocationPointer, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + Reduce3::execScalarCuda(opNum, vx, xShapeInfo, vy, yShapeInfo, + extraParams, vz, zShapeInfo, allocationPointer, + reductionBuffer, tadOnlyShapeInfo); } template -__global__ void execAllGeneric(const int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - Reduce3::execAllCuda(opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationPointer, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); +__global__ void execAllGeneric( + const int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationPointer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { + Reduce3::execAllCuda( + opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, + dimension, dimensionLength, postProcessOrNot, allocationPointer, + tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); } - //////////////////////////////////////////////////////////////////////// template -__global__ void execGeneric(const int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - Reduce3::execCuda(opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationPointer, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); +__global__ void execGeneric( + const int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationPointer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { + Reduce3::execCuda(opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, + vz, zShapeInfo, dimension, dimensionLength, + postProcessOrNot, allocationPointer, tadOnlyShapeInfo, + tadOffsets, yTadOnlyShapeInfo, yTadOffsets); } - ////////////////////////////////////////////////////////////////////////// template template -__device__ void Reduce3::aggregatePartials(void* vsPartials, Nd4jLong tid, Nd4jLong numItems, void *vextraParams) { +__device__ void Reduce3::aggregatePartials(void* vsPartials, Nd4jLong tid, + Nd4jLong numItems, + void* vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. + auto sPartials = reinterpret_cast(vsPartials); + auto extraParams = reinterpret_cast(vextraParams); + Nd4jLong floorPow2 = numItems; - auto sPartials = reinterpret_cast(vsPartials); - auto extraParams = reinterpret_cast(vextraParams); - Nd4jLong floorPow2 = numItems; + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) floorPow2 &= floorPow2 - 1; - if (floorPow2 & (floorPow2 - 1)) { + if (tid >= floorPow2) + sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], + sPartials[tid], extraParams); - while(floorPow2 & (floorPow2 - 1)) - floorPow2 &= floorPow2 - 1; - - if (tid >= floorPow2) - sPartials[tid - floorPow2] = OpType::update(sPartials[tid - floorPow2], sPartials[tid], extraParams); - - __syncthreads(); - } + __syncthreads(); + } - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (tid < activeThreads) - sPartials[tid] = OpType::update(sPartials[tid], sPartials[tid + activeThreads], extraParams); + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (tid < activeThreads) + sPartials[tid] = OpType::update( + sPartials[tid], sPartials[tid + activeThreads], extraParams); - __syncthreads(); - } + __syncthreads(); + } } ////////////////////////////////////////////////////////////////////////// template -template -__device__ void Reduce3::execScalarCuda( void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void* vz, Nd4jLong const* zShapeInfo, - int *allocationPointer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo) { - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ Z extraZ[3]; - __shared__ Z* sPartials; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - - extraZ[0] = (Z) 0.0f; - extraZ[1] = (Z) 0.0f; - - if (extraParams != nullptr) - extraZ[2] = static_cast(extraParams)[2]; - else - extraZ[2] = (Z) 0.0f; - } - __syncthreads(); - - sPartials[threadIdx.x] = OpType::startingValue(x); - Nd4jLong length = shape::length(xShapeInfo); - int xEws = shape::elementWiseStride(xShapeInfo); - int yEws = shape::elementWiseStride(yShapeInfo); - int tid = blockIdx.x * blockDim.x + threadIdx.x; - char xOrder = shape::order(xShapeInfo); - char yOrder = shape::order(yShapeInfo); - - if(xOrder == yOrder && (xEws > 0 && yEws > 0) && shape::strideDescendingCAscendingF(xShapeInfo) && shape::strideDescendingCAscendingF(yShapeInfo)) { - - if (xEws == 1 && yEws == 1) { - for(Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::opAtomic(x[i], y[i], extraZ), extraZ); - } - else { - for(Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::opAtomic(x[i * xEws], y[i * yEws], extraZ), extraZ); - } +template +__device__ void Reduce3::execScalarCuda( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* allocationPointer, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ Z extraZ[3]; + __shared__ Z* sPartials; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + + extraZ[0] = (Z)0.0f; + extraZ[1] = (Z)0.0f; + + if (extraParams != nullptr) + extraZ[2] = static_cast(extraParams)[2]; + else + extraZ[2] = (Z)0.0f; + } + __syncthreads(); + + sPartials[threadIdx.x] = OpType::startingValue(x); + Nd4jLong length = shape::length(xShapeInfo); + int xEws = shape::elementWiseStride(xShapeInfo); + int yEws = shape::elementWiseStride(yShapeInfo); + int tid = blockIdx.x * blockDim.x + threadIdx.x; + char xOrder = shape::order(xShapeInfo); + char yOrder = shape::order(yShapeInfo); + + if (xOrder == yOrder && (xEws > 0 && yEws > 0) && + shape::strideDescendingCAscendingF(xShapeInfo) && + shape::strideDescendingCAscendingF(yShapeInfo)) { + if (xEws == 1 && yEws == 1) { + for (Nd4jLong i = tid; i < length; i += gridDim.x * blockDim.x) + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + OpType::opAtomic(x[i], y[i], extraZ), extraZ); + } else { + for (Nd4jLong i = tid; i < length; i += gridDim.x * blockDim.x) + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::opAtomic(x[i * xEws], y[i * yEws], extraZ), extraZ); } - else { - sPartials[threadIdx.x] = OpType::startingValue(x); - auto threadCount = gridDim.x * blockDim.x; - for(Nd4jLong i = tid; i < length; i += threadCount) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto yOffset = shape::getIndexOffset(i, yShapeInfo); - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::opAtomic(x[xOffset], y[yOffset], extraZ), extraZ); - } + } else { + sPartials[threadIdx.x] = OpType::startingValue(x); + auto threadCount = gridDim.x * blockDim.x; + for (Nd4jLong i = tid; i < length; i += threadCount) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto yOffset = shape::getIndexOffset(i, yShapeInfo); + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::opAtomic(x[xOffset], y[yOffset], extraZ), extraZ); + } + } + + __syncthreads(); + aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, + sd::math::nd4j_min(blockDim.x, length), + extraZ); + __syncthreads(); + + if (gridDim.x > 1) { + auto tc = reinterpret_cast(reductionBuffer); + __shared__ bool amLast; + int rank = shape::rank(xShapeInfo); + tid = threadIdx.x; + Z* extraBuffer = (Z*)allocationPointer; + if (threadIdx.x == 0) { + reinterpret_cast(reductionBuffer)[blockIdx.x] = sPartials[0]; + extraBuffer[blockIdx.x] = extraZ[0]; + extraBuffer[gridDim.x + blockIdx.x] = extraZ[1]; } - __syncthreads(); - aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, sd::math::nd4j_min(blockDim.x, length), extraZ); + __threadfence(); __syncthreads(); - if (gridDim.x > 1) { - - auto tc = reinterpret_cast(reductionBuffer); - __shared__ bool amLast; - int rank = shape::rank(xShapeInfo); - tid = threadIdx.x; - Z *extraBuffer = (Z *) allocationPointer; - if (threadIdx.x == 0) { - reinterpret_cast(reductionBuffer)[blockIdx.x] = sPartials[0]; - extraBuffer[blockIdx.x] = extraZ[0]; - extraBuffer[gridDim.x + blockIdx.x] = extraZ[1]; - } + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } - __threadfence(); - __syncthreads(); + sPartials[tid] = OpType::startingValue(x); + __syncthreads(); - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); + if (amLast) { + tc[16384] = 0; + sPartials[threadIdx.x] = OpType::startingValue(x); + + // TODO: later probably replace this. Right now we need extraZ sync for + // CosineSimilarity ONLY + if (tid == 0 && extraZ[0] != static_cast(0) && + extraZ[1] != static_cast(0)) { + extraZ[0] = 0.0; + extraZ[1] = 0.0; + for (int i = 0; i < gridDim.x; i++) { + extraZ[0] += extraBuffer[i]; + extraZ[1] += extraBuffer[gridDim.x + i]; } + } - sPartials[tid] = OpType::startingValue(x); - __syncthreads(); - - if (amLast) { - - tc[16384] = 0; - sPartials[threadIdx.x] = OpType::startingValue(x); - - // TODO: later probably replace this. Right now we need extraZ sync for CosineSimilarity ONLY - if (tid == 0 && extraZ[0] != static_cast(0) && extraZ[1] != static_cast(0)) { - extraZ[0] = 0.0; - extraZ[1] = 0.0; - for (int i = 0; i < gridDim.x; i++) { - extraZ[0] += extraBuffer[i]; - extraZ[1] += extraBuffer[gridDim.x + i]; - } - } + for (Nd4jLong i = threadIdx.x; i < gridDim.x; i += blockDim.x) + sPartials[threadIdx.x] = + OpType::update(sPartials[threadIdx.x], + static_cast(reductionBuffer)[i], extraZ); - for (Nd4jLong i = threadIdx.x; i < gridDim.x; i += blockDim.x) - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], static_cast(reductionBuffer)[i], extraZ); + __syncthreads(); + aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, + sd::math::nd4j_min(gridDim.x, blockDim.x), + extraZ); + __syncthreads(); - __syncthreads(); - aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, sd::math::nd4j_min(gridDim.x, blockDim.x), extraZ); - __syncthreads(); - - if (threadIdx.x == 0) - z[0] = OpType::postProcess(sPartials[0], length, extraZ); - } + if (threadIdx.x == 0) + z[0] = OpType::postProcess(sPartials[0], length, extraZ); } - else { - - if (tid == 0) { - auto tc = reinterpret_cast(reductionBuffer); - tc[16384] = 0; - z[0] = OpType::postProcess(sPartials[0], length, extraZ); - //printf("Z: [%f]\n", (float) z[0]); - } + } else { + if (tid == 0) { + auto tc = reinterpret_cast(reductionBuffer); + tc[16384] = 0; + z[0] = OpType::postProcess(sPartials[0], length, extraZ); + // printf("Z: [%f]\n", (float) z[0]); } + } } ////////////////////////////////////////////////////////////////////////// template -template -__device__ void Reduce3::transformAll( void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* xTadShapeInfo, Nd4jLong const* xOffsets, - Nd4jLong const* yTadShapeInfo, Nd4jLong const* yOffsets) { - - auto dx = reinterpret_cast(vx); - auto dy = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - // initialize partials first - __shared__ Z* sPartials; - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); +template +__device__ void Reduce3::transformAll( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationPointer, Nd4jLong const* xTadShapeInfo, + Nd4jLong const* xOffsets, Nd4jLong const* yTadShapeInfo, + Nd4jLong const* yOffsets) { + auto dx = reinterpret_cast(vx); + auto dy = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + // initialize partials first + __shared__ Z* sPartials; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + } + __syncthreads(); + + Z startingVal = OpType::startingValue(dx); + sPartials[threadIdx.x] = startingVal; + X* tempX = reinterpret_cast(sPartials) + blockDim.x; + + const int maxBlock = blockDim.x; + + __shared__ Z extraZ[OpType::extraParamsLen > 0 ? OpType::extraParamsLen : 1]; + + __shared__ int xTadLength; + __shared__ int yTadLength; + + __shared__ int xTads; + __shared__ int yTads; + + // reading initial data + if (threadIdx.x == 0) { + xTadLength = shape::length(xTadShapeInfo); + yTadLength = shape::length(yTadShapeInfo); + + xTads = shape::length(xShapeInfo) / xTadLength; + yTads = shape::length(yShapeInfo) / yTadLength; + } + __syncthreads(); + + int limit = xTadLength / maxBlock; + if (xTadLength % maxBlock > 0) limit++; + + for (int r = blockIdx.x; r < xTads; r += blockDim.x * gridDim.x) { + auto x = dx + xOffsets[r]; + + if (threadIdx.x < xTadLength && threadIdx.x < maxBlock) { + auto x0 = shape::getIndexOffset(threadIdx.x, xTadShapeInfo); + tempX[threadIdx.x] = x[x0]; } __syncthreads(); - Z startingVal = OpType::startingValue(dx); - sPartials[threadIdx.x] = startingVal; - X *tempX = reinterpret_cast(sPartials) + blockDim.x; - - const int maxBlock = blockDim.x; - - __shared__ Z extraZ[OpType::extraParamsLen > 0 ? OpType::extraParamsLen : 1]; - - __shared__ int xTadLength; - __shared__ int yTadLength; - - __shared__ int xTads; - __shared__ int yTads; - - //reading initial data - if (threadIdx.x == 0) { - xTadLength = shape::length(xTadShapeInfo); - yTadLength = shape::length(yTadShapeInfo); - - xTads = shape::length(xShapeInfo) / xTadLength; - yTads = shape::length(yShapeInfo) / yTadLength; - } - __syncthreads(); - - int limit = xTadLength / maxBlock; - if (xTadLength % maxBlock > 0) - limit++; - - for (int r = blockIdx.x; r < xTads; r += blockDim.x * gridDim.x) { - - auto x = dx + xOffsets[r]; - - if (threadIdx.x < xTadLength && threadIdx.x < maxBlock) { - auto x0 = shape::getIndexOffset(threadIdx.x, xTadShapeInfo); + for (int g = 0; g < yTads; g++) { + auto y = dy + yOffsets[g]; + int ri = (r * yTads) + g; + + sPartials[threadIdx.x] = startingVal; + if (OpType::extraParamsLen > 0 && threadIdx.x < OpType::extraParamsLen) + extraZ[threadIdx.x] = startingVal; + __syncthreads(); + + // we might have data too large for single cache block, rendering cache + // useless though :( + for (int t = 0; t < limit; t++) { + // we reset tempX IF we have >1 tiles + if (t >= 1 || (limit > 1 && g > 0)) + if (threadIdx.x + (t * maxBlock) < xTadLength) { + auto x0 = shape::getIndexOffset(threadIdx.x + (t * maxBlock), + xTadShapeInfo); tempX[threadIdx.x] = x[x0]; + } + + for (int f = threadIdx.x + (t * maxBlock); + f < xTadLength && f < threadIdx.x + ((t + 1) * maxBlock); + f += blockDim.x * gridDim.x) { + auto y0 = shape::getIndexOffset(f, yTadShapeInfo); + sPartials[threadIdx.x] = OpType::update( + sPartials[threadIdx.x], + OpType::opAtomic(tempX[threadIdx.x], y[y0], extraZ), extraZ); } - __syncthreads(); - - for (int g = 0; g < yTads; g++) { - - auto y = dy + yOffsets[g]; - int ri = (r * yTads) + g; - - sPartials[threadIdx.x] = startingVal; - if (OpType::extraParamsLen > 0 && threadIdx.x < OpType::extraParamsLen) - extraZ[threadIdx.x] = startingVal; - __syncthreads(); - // we might have data too large for single cache block, rendering cache useless though :( - for (int t = 0; t < limit; t++) { - - // we reset tempX IF we have >1 tiles - if (t >= 1 || (limit > 1 && g > 0)) - if (threadIdx.x + (t * maxBlock) < xTadLength) { - auto x0 = shape::getIndexOffset(threadIdx.x + (t * maxBlock), xTadShapeInfo); - tempX[threadIdx.x] = x[x0]; - } - - for (int f = threadIdx.x + (t * maxBlock); f < xTadLength && f < threadIdx.x + ((t + 1) * maxBlock); f += blockDim.x * gridDim.x) { - auto y0 = shape::getIndexOffset(f, yTadShapeInfo); - sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::opAtomic(tempX[threadIdx.x], y[y0], extraZ), extraZ); - } - - // we MUST step through this block altogether - __syncthreads(); - } + // we MUST step through this block altogether + __syncthreads(); + } - aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, sd::math::nd4j_min(blockDim.x, xTadLength), extraZ); - __syncthreads(); + aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, + sd::math::nd4j_min(blockDim.x, xTadLength), + extraZ); + __syncthreads(); - if (threadIdx.x == 0) { - z[ri] = OpType::postProcess(sPartials[threadIdx.x], xTadLength, extraZ); - } + if (threadIdx.x == 0) { + z[ri] = OpType::postProcess(sPartials[threadIdx.x], xTadLength, extraZ); + } - __syncthreads(); - } - } + __syncthreads(); + } + } } ////////////////////////////////////////////////////////////////////////// template -template -__device__ void Reduce3::transform(void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - // FIXME - if(shape::isScalar(zShapeInfo)) - return; - - if (yTadOnlyShapeInfo == nullptr) { - yTadOnlyShapeInfo = yShapeInfo; // execReduce3TAD case - } - - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - Z startingVal = OpType::startingValue(x); - - __shared__ Z extraZ[OpType::extraParamsLen > 0 ? OpType::extraParamsLen : 1]; - - __shared__ Z* sPartials; - __shared__ int tadLen; - __shared__ Nd4jLong zLen; - __shared__ Nd4jLong xTadEws; - __shared__ Nd4jLong yTadEws; - __shared__ Nd4jLong yTadNum; - __shared__ char xTadOrder; - __shared__ char yTadOrder; - - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast(shmem); - - tadLen = shape::length(tadOnlyShapeInfo); - zLen = shape::length(zShapeInfo); - xTadEws = shape::elementWiseStride(tadOnlyShapeInfo); - yTadEws = shape::elementWiseStride(yTadOnlyShapeInfo); - yTadNum = shape::length(yShapeInfo) / tadLen; - xTadOrder = shape::order(tadOnlyShapeInfo); - yTadOrder = shape::order(yTadOnlyShapeInfo); - } - __syncthreads(); - - sPartials[threadIdx.x] = startingVal; - - if(xTadEws >= 1 && yTadEws >= 1 && xTadOrder == yTadOrder) { - - for(int i = blockIdx.x; i < zLen; i+= gridDim.x) { - - Nd4jLong xOffset = tadOffsets[i]; - Nd4jLong yOffset = yTadNum == 1 ? 0 : yTadOffsets[i]; - - if (OpType::extraParamsLen > 0 && threadIdx.x < OpType::extraParamsLen) - extraZ[threadIdx.x] = startingVal; - - __syncthreads(); - - for (int j = threadIdx.x; j < tadLen; j += blockDim.x) { - - Nd4jLong xOffset2 = xOffset + j*xTadEws; - Nd4jLong yOffset2 = yOffset + j*yTadEws; - sPartials[threadIdx.x] = j < blockDim.x ? OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ) : OpType::update(sPartials[threadIdx.x], OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ), extraZ); - } - - __syncthreads(); - aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLen), extraZ); - __syncthreads(); - - if (threadIdx.x == 0) - z[i] = OpType::postProcess(sPartials[threadIdx.x], tadLen, extraZ); - - __syncthreads(); - } +template +__device__ void Reduce3::transform( + void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationPointer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { + // FIXME + if (shape::isScalar(zShapeInfo)) return; + + if (yTadOnlyShapeInfo == nullptr) { + yTadOnlyShapeInfo = yShapeInfo; // execReduce3TAD case + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + Z startingVal = OpType::startingValue(x); + + __shared__ Z extraZ[OpType::extraParamsLen > 0 ? OpType::extraParamsLen : 1]; + + __shared__ Z* sPartials; + __shared__ int tadLen; + __shared__ Nd4jLong zLen; + __shared__ Nd4jLong xTadEws; + __shared__ Nd4jLong yTadEws; + __shared__ Nd4jLong yTadNum; + __shared__ char xTadOrder; + __shared__ char yTadOrder; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast(shmem); + + tadLen = shape::length(tadOnlyShapeInfo); + zLen = shape::length(zShapeInfo); + xTadEws = shape::elementWiseStride(tadOnlyShapeInfo); + yTadEws = shape::elementWiseStride(yTadOnlyShapeInfo); + yTadNum = shape::length(yShapeInfo) / tadLen; + xTadOrder = shape::order(tadOnlyShapeInfo); + yTadOrder = shape::order(yTadOnlyShapeInfo); + } + __syncthreads(); + + sPartials[threadIdx.x] = startingVal; + + if (xTadEws >= 1 && yTadEws >= 1 && xTadOrder == yTadOrder) { + for (int i = blockIdx.x; i < zLen; i += gridDim.x) { + Nd4jLong xOffset = tadOffsets[i]; + Nd4jLong yOffset = yTadNum == 1 ? 0 : yTadOffsets[i]; + + if (OpType::extraParamsLen > 0 && threadIdx.x < OpType::extraParamsLen) + extraZ[threadIdx.x] = startingVal; + + __syncthreads(); + + for (int j = threadIdx.x; j < tadLen; j += blockDim.x) { + Nd4jLong xOffset2 = xOffset + j * xTadEws; + Nd4jLong yOffset2 = yOffset + j * yTadEws; + sPartials[threadIdx.x] = + j < blockDim.x + ? OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ) + : OpType::update( + sPartials[threadIdx.x], + OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ), + extraZ); + } + + __syncthreads(); + aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLen), + extraZ); + __syncthreads(); + + if (threadIdx.x == 0) + z[i] = OpType::postProcess(sPartials[threadIdx.x], tadLen, extraZ); + + __syncthreads(); } - else { - - for(int i = blockIdx.x; i < zLen; i += gridDim.x) { - - Nd4jLong xOffset = tadOffsets[i]; - Nd4jLong yOffset = yTadNum == 1 ? 0 : yTadOffsets[i]; - - if (OpType::extraParamsLen > 0 && threadIdx.x < OpType::extraParamsLen) - extraZ[threadIdx.x] = startingVal; - - __syncthreads(); - - for (int j = threadIdx.x; j < tadLen; j += blockDim.x) { - - Nd4jLong xOffset2 = xOffset + shape::getIndexOffset(j, tadOnlyShapeInfo); - Nd4jLong yOffset2 = yOffset + shape::getIndexOffset(j, yTadOnlyShapeInfo); - sPartials[threadIdx.x] = j < blockDim.x ? OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ) : OpType::update(sPartials[threadIdx.x], OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ), extraZ); - - } - - __syncthreads(); - aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLen), extraZ); - __syncthreads(); - - if (threadIdx.x == 0) - z[i] = OpType::postProcess(sPartials[threadIdx.x], tadLen, extraZ); - - __syncthreads(); - } + } else { + for (int i = blockIdx.x; i < zLen; i += gridDim.x) { + Nd4jLong xOffset = tadOffsets[i]; + Nd4jLong yOffset = yTadNum == 1 ? 0 : yTadOffsets[i]; + + if (OpType::extraParamsLen > 0 && threadIdx.x < OpType::extraParamsLen) + extraZ[threadIdx.x] = startingVal; + + __syncthreads(); + + for (int j = threadIdx.x; j < tadLen; j += blockDim.x) { + Nd4jLong xOffset2 = + xOffset + shape::getIndexOffset(j, tadOnlyShapeInfo); + Nd4jLong yOffset2 = + yOffset + shape::getIndexOffset(j, yTadOnlyShapeInfo); + sPartials[threadIdx.x] = + j < blockDim.x + ? OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ) + : OpType::update( + sPartials[threadIdx.x], + OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ), + extraZ); + } + + __syncthreads(); + aggregatePartials(reinterpret_cast(sPartials), threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLen), + extraZ); + __syncthreads(); + + if (threadIdx.x == 0) + z[i] = OpType::postProcess(sPartials[threadIdx.x], tadLen, extraZ); + + __syncthreads(); } + } } ////////////////////////////////////////////////////////////////////////// template -__device__ void Reduce3::execCuda(const int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - DISPATCH_BY_OPNUM_TT(transform, PARAMS(vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationPointer, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets), REDUCE3_OPS); +__device__ void Reduce3::execCuda( + const int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationPointer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { + DISPATCH_BY_OPNUM_TT( + transform, + PARAMS(vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, + dimension, dimensionLength, postProcessOrNot, allocationPointer, + tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets), + REDUCE3_OPS); } - - ////////////////////////////////////////////////////////////////////////// template -__device__ void Reduce3::execAllCuda( const int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - DISPATCH_BY_OPNUM_TT(transformAll, PARAMS(vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationPointer, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets), REDUCE3_OPS); +__device__ void Reduce3::execAllCuda( + const int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationPointer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { + DISPATCH_BY_OPNUM_TT( + transformAll, + PARAMS(vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, + dimension, dimensionLength, postProcessOrNot, allocationPointer, + tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets), + REDUCE3_OPS); } - ////////////////////////////////////////////////////////////////////////// template -__device__ void Reduce3::execScalarCuda(const int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int * allocationPointer, void *reductionBuffer, - Nd4jLong const* tadOnlyShapeInfo) { - - DISPATCH_BY_OPNUM_TT(execScalarCuda, PARAMS(vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, allocationPointer, reductionBuffer, tadOnlyShapeInfo), REDUCE3_OPS); +__device__ void Reduce3::execScalarCuda( + const int opNum, void const* vx, Nd4jLong const* xShapeInfo, void const* vy, + Nd4jLong const* yShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* allocationPointer, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + DISPATCH_BY_OPNUM_TT( + execScalarCuda, + PARAMS(vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, + allocationPointer, reductionBuffer, tadOnlyShapeInfo), + REDUCE3_OPS); } - //////////////////////////////////////////////////////////////////////// template -__host__ void Reduce3::exec(dim3 launchDims, cudaStream_t *stream, - int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - execGeneric<<>>(opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationPointer, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); - sd::DebugHelper::checkErrorCode(stream, "reduce3exec(...) failed"); +__host__ void Reduce3::exec( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, int* dimension, + int dimensionLength, int postProcessOrNot, int* allocationPointer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { + execGeneric<<>>( + opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, + dimension, dimensionLength, postProcessOrNot, allocationPointer, + tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); + sd::DebugHelper::checkErrorCode(stream, "reduce3exec(...) failed"); } //////////////////////////////////////////////////////////////////////// - template - __host__ void Reduce3::execAll(dim3 launchDims, cudaStream_t *stream, - int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { - - execAllGeneric<<>>(opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationPointer, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); - sd::DebugHelper::checkErrorCode(stream, "execAllGeneric(...) failed"); - } +template +__host__ void Reduce3::execAll( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, int* dimension, + int dimensionLength, int postProcessOrNot, int* allocationPointer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* yTadOnlyShapeInfo, Nd4jLong const* yTadOffsets) { + execAllGeneric<<>>( + opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, + dimension, dimensionLength, postProcessOrNot, allocationPointer, + tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); + sd::DebugHelper::checkErrorCode(stream, "execAllGeneric(...) failed"); +} //////////////////////////////////////////////////////////////////////// template -__host__ void Reduce3::execScalar(dim3 launchDims, cudaStream_t *stream, - int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void const* vy, Nd4jLong const* yShapeInfo, - void *extraParams, - void *vz, Nd4jLong const* zShapeInfo, - int* allocationPointer, - void *reductionBuffer, - Nd4jLong const* tadOnlyShapeInfo) { - - execScalarGeneric<<>>(opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, allocationPointer, reductionBuffer, tadOnlyShapeInfo); - sd::DebugHelper::checkErrorCode(stream, "execScalarGeneric(...) failed"); +__host__ void Reduce3::execScalar( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void const* vy, Nd4jLong const* yShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + int* allocationPointer, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo) { + execScalarGeneric + <<>>( + opNum, vx, xShapeInfo, vy, yShapeInfo, extraParams, vz, zShapeInfo, + allocationPointer, reductionBuffer, tadOnlyShapeInfo); + sd::DebugHelper::checkErrorCode(stream, "execScalarGeneric(...) failed"); } +// BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, +// FLOAT_TYPES); - - - - //BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES); - -} -} \ No newline at end of file +} // namespace reduce3 +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/reduce3.cu b/libnd4j/include/loops/cuda/reduce3.cu index c1d63e8dd245..cf8f8c3481bd 100644 --- a/libnd4j/include/loops/cuda/reduce3.cu +++ b/libnd4j/include/loops/cuda/reduce3.cu @@ -18,16 +18,12 @@ // @author raver119@gmail.com // - -#include -#include #include -#include +#include #include +#include +#include namespace functions { - namespace reduce3 { - - - } -} \ No newline at end of file +namespace reduce3 {} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/scalar.chpp b/libnd4j/include/loops/cuda/scalar.chpp index b412e4957bf1..5e8cd0048a37 100644 --- a/libnd4j/include/loops/cuda/scalar.chpp +++ b/libnd4j/include/loops/cuda/scalar.chpp @@ -21,152 +21,184 @@ #ifndef SCALAR_CU #define SCALAR_CU -#include "loops/scalar.h" #include #include -#include #include +#include #include +#include "loops/scalar.h" + using namespace simdOps; //////////////////////////////////////////////////////////////////////////////// template -__global__ static void scalarSimpleShaped(void const* vx, void const* vscalar, Nd4jLong const* xShapeInfo, void *vparams, void *vz, Nd4jLong const* zShapeInfo, int *allocationBuffer) { - - auto scalar = reinterpret_cast(vscalar)[0]; - auto x = reinterpret_cast(vx); - auto params = reinterpret_cast(vparams); - auto z = reinterpret_cast(vz); - - int totalThreads = gridDim.x * blockDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ Nd4jLong length; - if(threadIdx.x == 0) { - length = shape::length(xShapeInfo); +__global__ static void scalarSimpleShaped(void const* vx, void const* vscalar, + Nd4jLong const* xShapeInfo, + void* vparams, void* vz, + Nd4jLong const* zShapeInfo, + int* allocationBuffer) { + auto scalar = reinterpret_cast(vscalar)[0]; + auto x = reinterpret_cast(vx); + auto params = reinterpret_cast(vparams); + auto z = reinterpret_cast(vz); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ Nd4jLong length; + if (threadIdx.x == 0) { + length = shape::length(xShapeInfo); + } + __syncthreads(); + + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + auto xOrder = shape::order(xShapeInfo); + auto zOrder = shape::order(zShapeInfo); + + if (xEws >= 1 && zEws >= 1 && xOrder == zOrder) { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + z[i * zEws] = OpType::op(x[i * xEws], scalar, params); } - __syncthreads(); - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - auto xOrder = shape::order(xShapeInfo); - auto zOrder = shape::order(zShapeInfo); - - - if (xEws >= 1 && zEws >= 1 && xOrder == zOrder) { - for (Nd4jLong i = tid; i < length; i += totalThreads) { - z[i * zEws] = OpType::op(x[i * xEws], scalar, params); - } - } else { - for (Nd4jLong i = tid; i < length; i += totalThreads) { - z[shape::getIndexOffset(i, zShapeInfo)] = OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], scalar, params); - } + } else { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + z[shape::getIndexOffset(i, zShapeInfo)] = + OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], scalar, params); } - + } } //////////////////////////////////////////////////////////////////////////////// template -__global__ static void scalarAlongDimension(void const* vx, Nd4jLong const* xShapeInfo, - void* vextraParams, - void* vz, Nd4jLong const* zShapeInfo, - void const* vscalars, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - auto x = reinterpret_cast(vx); - auto extraParams = reinterpret_cast(vextraParams); - auto z = reinterpret_cast(vz); - auto scalars = reinterpret_cast(vscalars); - - if (tadShapeInfoZ == nullptr) { - tadShapeInfoZ = tadShapeInfo; - tadOffsetsZ = tadOffsets; +__global__ static void scalarAlongDimension( + void const* vx, Nd4jLong const* xShapeInfo, void* vextraParams, void* vz, + Nd4jLong const* zShapeInfo, void const* vscalars, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + auto x = reinterpret_cast(vx); + auto extraParams = reinterpret_cast(vextraParams); + auto z = reinterpret_cast(vz); + auto scalars = reinterpret_cast(vscalars); + + if (tadShapeInfoZ == nullptr) { + tadShapeInfoZ = tadShapeInfo; + tadOffsetsZ = tadOffsets; + } + + // tad preparation + auto tadEws = shape::elementWiseStride(tadShapeInfo); + auto zEws = shape::elementWiseStride(tadShapeInfoZ); + auto tadLength = shape::length(tadShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + auto numTads = shape::length(xShapeInfo) / tadLength; + + if (tadEws > 0 && zEws > 0 && + shape::order(tadShapeInfo) == shape::order(zShapeInfo)) { + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + Z* oZ = z + tadOffsetsZ[r]; + auto oX = x + tadOffsets[r]; + + auto s = scalars[r]; + + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams); } + } else { + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + Z* oZ = z + tadOffsetsZ[r]; + auto oX = x + tadOffsets[r]; - // tad preparation - auto tadEws = shape::elementWiseStride(tadShapeInfo); - auto zEws = shape::elementWiseStride(tadShapeInfoZ); - auto tadLength = shape::length(tadShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - auto numTads =shape::length(xShapeInfo) / tadLength; - - if (tadEws > 0 && zEws > 0 && shape::order(tadShapeInfo) == shape::order(zShapeInfo)) { - - // main loop, rolling over tads - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - Z *oZ = z + tadOffsetsZ[r]; - auto oX = x + tadOffsets[r]; + auto s = scalars[r]; - auto s = scalars[r]; - - for (int f = threadIdx.x; f < tadLength; f += blockDim.x) - oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams); - } - } else { - // main loop, rolling over tads - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - Z *oZ = z + tadOffsetsZ[r]; - auto oX = x + tadOffsets[r]; - - auto s = scalars[r]; - - for (int f = threadIdx.x; f < tadLength; f += blockDim.x) - oZ[shape::getIndexOffset(f, tadShapeInfoZ)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo)], s, extraParams); - } + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[shape::getIndexOffset(f, tadShapeInfoZ)] = OpType::op( + oX[shape::getIndexOffset(f, tadShapeInfo)], s, extraParams); } + } } - namespace functions { -namespace scalar { +namespace scalar { //////////////////////////////////////////////////////////////////////////////// -template -template -void _CUDA_H ScalarTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, void const* vscalar, void *vextraParams, int *allocPointer){ - - auto xEws = shape::elementWiseStride(hxShapeInfo); - auto xOrder = shape::order(hxShapeInfo); - - auto zEws = shape::elementWiseStride(hzShapeInfo); - auto zOrder = shape::order(hzShapeInfo); - - auto length = shape::length(hxShapeInfo); - - scalarSimpleShaped<<>>(vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer); - sd::DebugHelper::checkErrorCode(stream, "scalarSimpleShapedA(...) failed"); +template +template +void _CUDA_H ScalarTransform::intermediateShaped( + dim3& launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void* vz, + Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, + void const* vscalar, void* vextraParams, int* allocPointer) { + auto xEws = shape::elementWiseStride(hxShapeInfo); + auto xOrder = shape::order(hxShapeInfo); + + auto zEws = shape::elementWiseStride(hzShapeInfo); + auto zOrder = shape::order(hzShapeInfo); + + auto length = shape::length(hxShapeInfo); + + scalarSimpleShaped + <<>>( + vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer); + sd::DebugHelper::checkErrorCode(stream, "scalarSimpleShapedA(...) failed"); } //////////////////////////////////////////////////////////////////////////////// -template -template -void _CUDA_H ScalarTransform::intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, void const* x, Nd4jLong const* xShapeInfo, void *z, Nd4jLong const* zShapeInfo, void const* scalars, void *extraParams, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - scalarAlongDimension<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); - sd::DebugHelper::checkErrorCode(stream, "scalarAlongDimA(...) failed"); +template +template +void _CUDA_H ScalarTransform::intermediateAlongDimension( + dim3& launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void* z, Nd4jLong const* zShapeInfo, + void const* scalars, void* extraParams, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + scalarAlongDimension + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, + tadOffsetsZ); + sd::DebugHelper::checkErrorCode(stream, "scalarAlongDimA(...) failed"); } //////////////////////////////////////////////////////////////////////////////// -template -void ScalarTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, void const* vscalar, void *vextraParams) { - - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("H14 opNum:[%i]\n", opNum); - - DISPATCH_BY_OPNUM_TTT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, hxShapeInfo, vz, zShapeInfo, hzShapeInfo, vscalar, vextraParams, nullptr), SCALAR_OPS); +template +void ScalarTransform::executeCudaShaped( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void* vz, + Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, + void const* vscalar, void* vextraParams) { + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("H14 opNum:[%i]\n", opNum); + + DISPATCH_BY_OPNUM_TTT( + intermediateShaped, + PARAMS(launchDims, stream, vx, xShapeInfo, hxShapeInfo, vz, zShapeInfo, + hzShapeInfo, vscalar, vextraParams, nullptr), + SCALAR_OPS); } //////////////////////////////////////////////////////////////////////////////// -template -void ScalarTransform::executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void const* vscalars, void *vextraParams, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_TTT(intermediateAlongDimension, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), SCALAR_OPS); +template +void ScalarTransform::executeCudaAlongDimension( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void const* vscalars, void* vextraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_TTT( + intermediateAlongDimension, + PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, + vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, + tadShapeInfoZ, tadOffsetsZ), + SCALAR_OPS); } -} -} - - +} // namespace scalar +} // namespace functions -#endif // SCALAR_CU \ No newline at end of file +#endif // SCALAR_CU \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/scalar.cu b/libnd4j/include/loops/cuda/scalar.cu index 26c3e5cb871b..c470fb46289c 100644 --- a/libnd4j/include/loops/cuda/scalar.cu +++ b/libnd4j/include/loops/cuda/scalar.cu @@ -18,15 +18,14 @@ // @author raver119@gmail.com // -#include "loops/scalar.h" #include #include -#include #include +#include #include -namespace functions { - namespace scalar { +#include "loops/scalar.h" - } -} \ No newline at end of file +namespace functions { +namespace scalar {} +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/scalar_bool.cu b/libnd4j/include/loops/cuda/scalar_bool.cu index 2c21c979befd..fd9639d433dc 100644 --- a/libnd4j/include/loops/cuda/scalar_bool.cu +++ b/libnd4j/include/loops/cuda/scalar_bool.cu @@ -19,218 +19,222 @@ // @author raver119@gmail.com // -#include "../scalar_bool.h" #include #include #include "../legacy_ops.h" +#include "../scalar_bool.h" using namespace simdOps; //////////////////////////////////////////////////////////////////////// template -__global__ void scalarAlongDimension(void const* x, Nd4jLong const* xShapeInfo, - void *extraParams, - void *z, Nd4jLong const* zShapeInfo, - void const* scalars, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::scalar::ScalarBoolTransform::template transformCuda(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); +__global__ void scalarAlongDimension( + void const* x, Nd4jLong const* xShapeInfo, void* extraParams, void* z, + Nd4jLong const* zShapeInfo, void const* scalars, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + functions::scalar::ScalarBoolTransform::template transformCuda( + x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); } - //////////////////////////////////////////////////////////////////////// template -__global__ void scalarSimpleShaped(void const* x, void const* y, Nd4jLong const* xShapeInfo, void *params, void *z, Nd4jLong const* zShapeInfo, int *allocationBuffer) { - - functions::scalar::ScalarBoolTransform::template transformCuda(y, x, xShapeInfo, params, z, zShapeInfo, allocationBuffer); +__global__ void scalarSimpleShaped(void const* x, void const* y, + Nd4jLong const* xShapeInfo, void* params, + void* z, Nd4jLong const* zShapeInfo, + int* allocationBuffer) { + functions::scalar::ScalarBoolTransform::template transformCuda( + y, x, xShapeInfo, params, z, zShapeInfo, allocationBuffer); } - - - - // *********************************************************************// // *********************************************************************// namespace functions { -namespace scalar { +namespace scalar { //////////////////////////////////////////////////////////////////////// -template -template -__device__ void ScalarBoolTransform::transformCuda(void const* vscalar, - void const* vy, Nd4jLong const* yShapeInfo, - void *vparams, - void *vz, Nd4jLong const* zShapeInfo, - int *allocationBuffer) { - auto scalar = reinterpret_cast(vscalar)[0]; - auto y = reinterpret_cast(vy); - auto params = reinterpret_cast(vparams); - auto z = reinterpret_cast(vz); - - auto yRank = shape::rank(yShapeInfo); - auto yEWS = shape::elementWiseStride(yShapeInfo); - auto yShape = shape::shapeOf(yShapeInfo); - auto yStride = shape::stride(yShapeInfo); - - auto zRank = shape::rank(zShapeInfo); - auto zEWS = shape::elementWiseStride(zShapeInfo); - auto zShape = shape::shapeOf(zShapeInfo); - auto zStride = shape::stride(zShapeInfo); - - int totalThreads = gridDim.x * blockDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ int len; - if(threadIdx.x == 0) - len = shape::length(yShapeInfo); - __syncthreads(); - - if(yEWS >= 1 && zEWS >= 1 && shape::order(yShapeInfo) == shape::order(zShapeInfo)) { - transformCuda(len, vscalar, vy, yEWS, vparams, vz, zEWS, allocationBuffer); - } - else { - for (Nd4jLong i = tid; i < len; i+= totalThreads) - z[shape::getIndexOffset(i, zShapeInfo)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo)], scalar, params); - } +template +template +__device__ void ScalarBoolTransform::transformCuda( + void const* vscalar, void const* vy, Nd4jLong const* yShapeInfo, + void* vparams, void* vz, Nd4jLong const* zShapeInfo, + int* allocationBuffer) { + auto scalar = reinterpret_cast(vscalar)[0]; + auto y = reinterpret_cast(vy); + auto params = reinterpret_cast(vparams); + auto z = reinterpret_cast(vz); + + auto yRank = shape::rank(yShapeInfo); + auto yEWS = shape::elementWiseStride(yShapeInfo); + auto yShape = shape::shapeOf(yShapeInfo); + auto yStride = shape::stride(yShapeInfo); + + auto zRank = shape::rank(zShapeInfo); + auto zEWS = shape::elementWiseStride(zShapeInfo); + auto zShape = shape::shapeOf(zShapeInfo); + auto zStride = shape::stride(zShapeInfo); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int len; + if (threadIdx.x == 0) len = shape::length(yShapeInfo); + __syncthreads(); + + if (yEWS >= 1 && zEWS >= 1 && + shape::order(yShapeInfo) == shape::order(zShapeInfo)) { + transformCuda(len, vscalar, vy, yEWS, vparams, vz, zEWS, + allocationBuffer); + } else { + for (Nd4jLong i = tid; i < len; i += totalThreads) + z[shape::getIndexOffset(i, zShapeInfo)] = + OpType::op(y[shape::getIndexOffset(i, yShapeInfo)], scalar, params); + } } //////////////////////////////////////////////////////////////////////// -template -template -__device__ void ScalarBoolTransform::transformCuda(Nd4jLong len, - void const* vx, - void const* vy, Nd4jLong yEWS, - void *vparams, - void *vz, Nd4jLong zEWS, - int *allocationBuffer) { - - auto x = reinterpret_cast(vx)[0]; - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto params = reinterpret_cast(vparams); - - int totalThreads = gridDim.x * blockDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - Nd4jLong i = tid; - if(yEWS == 1 && zEWS == 1) { - for (; i < len; i += totalThreads) - z[i] = OpType::op(y[i], x, params); - } - else { - for (; i < len; i += totalThreads) - z[i * zEWS] = OpType::op(y[i * yEWS], x, params); - } +template +template +__device__ void ScalarBoolTransform::transformCuda( + Nd4jLong len, void const* vx, void const* vy, Nd4jLong yEWS, void* vparams, + void* vz, Nd4jLong zEWS, int* allocationBuffer) { + auto x = reinterpret_cast(vx)[0]; + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto params = reinterpret_cast(vparams); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + Nd4jLong i = tid; + if (yEWS == 1 && zEWS == 1) { + for (; i < len; i += totalThreads) z[i] = OpType::op(y[i], x, params); + } else { + for (; i < len; i += totalThreads) + z[i * zEWS] = OpType::op(y[i * yEWS], x, params); + } } - //////////////////////////////////////////////////////////////////////// -template -template -__device__ void ScalarBoolTransform::transformCuda(void const* vx, Nd4jLong const* xShapeInfo, - void *vextraParams, - void *vz, Nd4jLong const* zShapeInfo, - void const* vscalars, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - auto x = reinterpret_cast(vx); - auto scalars = reinterpret_cast(vscalars); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - if (tadShapeInfoZ == nullptr) { - tadShapeInfoZ = tadShapeInfo; - tadOffsetsZ = tadOffsets; +template +template +__device__ void ScalarBoolTransform::transformCuda( + void const* vx, Nd4jLong const* xShapeInfo, void* vextraParams, void* vz, + Nd4jLong const* zShapeInfo, void const* vscalars, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + auto x = reinterpret_cast(vx); + auto scalars = reinterpret_cast(vscalars); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + if (tadShapeInfoZ == nullptr) { + tadShapeInfoZ = tadShapeInfo; + tadOffsetsZ = tadOffsets; + } + + // tad preparation + auto tadEws = shape::elementWiseStride(tadShapeInfo); + auto zEws = shape::elementWiseStride(tadShapeInfoZ); + auto tadLength = shape::length(tadShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + auto numTads = shape::length(xShapeInfo) / tadLength; + + if (tadEws > 0 && zEws > 0 && + shape::order(tadShapeInfo) == shape::order(zShapeInfo)) { + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + Z* oZ = z + tadOffsetsZ[r]; + auto oX = x + tadOffsets[r]; + + auto s = scalars[r]; + + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams); } + } else { + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + Z* oZ = z + tadOffsetsZ[r]; + auto oX = x + tadOffsets[r]; - // tad preparation - auto tadEws = shape::elementWiseStride(tadShapeInfo); - auto zEws = shape::elementWiseStride(tadShapeInfoZ); - auto tadLength = shape::length(tadShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - auto numTads =shape::length(xShapeInfo) / tadLength; - - if (tadEws > 0 && zEws > 0 && shape::order(tadShapeInfo) == shape::order(zShapeInfo)) { - - // main loop, rolling over tads - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - Z *oZ = z + tadOffsetsZ[r]; - auto oX = x + tadOffsets[r]; - - auto s = scalars[r]; - - for (int f = threadIdx.x; f < tadLength; f += blockDim.x) - oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams); - } - } else { - // main loop, rolling over tads - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - Z *oZ = z + tadOffsetsZ[r]; - auto oX = x + tadOffsets[r]; + auto s = scalars[r]; - auto s = scalars[r]; - - for (int f = threadIdx.x; f < tadLength; f += blockDim.x) - oZ[shape::getIndexOffset(f, tadShapeInfoZ)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo)], s, extraParams); - } + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[shape::getIndexOffset(f, tadShapeInfoZ)] = OpType::op( + oX[shape::getIndexOffset(f, tadShapeInfo)], s, extraParams); } + } } - //////////////////////////////////////////////////////////////////////// -template +template template -_CUDA_H void ScalarBoolTransform::intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, - void const* x, Nd4jLong const* xShapeInfo, - void *z, Nd4jLong const* zShapeInfo, - void const* scalars, - void *extraParams, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - scalarAlongDimension<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); - sd::DebugHelper::checkErrorCode(stream, "scalarAlongDim(...) failed"); +_CUDA_H void ScalarBoolTransform::intermediateAlongDimension( + dim3& launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void* z, Nd4jLong const* zShapeInfo, + void const* scalars, void* extraParams, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + scalarAlongDimension + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, + tadOffsetsZ); + sd::DebugHelper::checkErrorCode(stream, "scalarAlongDim(...) failed"); } //////////////////////////////////////////////////////////////////////// -template -template -void _CUDA_H ScalarBoolTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, - void const* vx, Nd4jLong const* xShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void const* vscalar, - void *vextraParams, int *allocPointer){ - - scalarSimpleShaped<<>>(vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer); - sd::DebugHelper::checkErrorCode(stream, "scalarSimpleShaped(...) failed"); +template +template +void _CUDA_H ScalarBoolTransform::intermediateShaped( + dim3& launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void const* vscalar, void* vextraParams, int* allocPointer) { + scalarSimpleShaped + <<>>( + vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer); + sd::DebugHelper::checkErrorCode(stream, "scalarSimpleShaped(...) failed"); } //////////////////////////////////////////////////////////////////////// -template -void ScalarBoolTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void const* vscalar, - void const* vextraParams) { - - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("H14 opNum:[%i]\n", opNum); - - DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalar, const_cast(vextraParams), nullptr), SCALAR_BOOL_OPS); +template +void ScalarBoolTransform::executeCudaShaped( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void const* vscalar, void const* vextraParams) { + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("H14 opNum:[%i]\n", opNum); + + DISPATCH_BY_OPNUM_TT( + intermediateShaped, + PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalar, + const_cast(vextraParams), nullptr), + SCALAR_BOOL_OPS); } //////////////////////////////////////////////////////////////////////// -template -void ScalarBoolTransform::executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void const* vscalars, void *vextraParams, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_TT(intermediateAlongDimension, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), SCALAR_BOOL_OPS); -} - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ScalarBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); -} +template +void ScalarBoolTransform::executeCudaAlongDimension( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void const* vscalars, void* vextraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_TT( + intermediateAlongDimension, + PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, + vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, + tadShapeInfoZ, tadOffsetsZ), + SCALAR_BOOL_OPS); } +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ScalarBoolTransform, , + LIBND4J_TYPES, BOOL_TYPES); +} // namespace scalar +} // namespace functions diff --git a/libnd4j/include/loops/cuda/scalar_int.cu b/libnd4j/include/loops/cuda/scalar_int.cu index 3d5b8982c00f..bf4a1a0b4c40 100644 --- a/libnd4j/include/loops/cuda/scalar_int.cu +++ b/libnd4j/include/loops/cuda/scalar_int.cu @@ -19,217 +19,222 @@ // @author raver119@gmail.com // -#include "../scalar_int.h" #include #include #include "../legacy_ops.h" +#include "../scalar_int.h" using namespace simdOps; //////////////////////////////////////////////////////////////////////// template -__global__ void scalarAlongDimension(void const* x, Nd4jLong const* xShapeInfo, - void *extraParams, - void *z, Nd4jLong const* zShapeInfo, - void const* scalars, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - functions::scalar::ScalarIntTransform::template transformCuda(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); +__global__ void scalarAlongDimension( + void const* x, Nd4jLong const* xShapeInfo, void* extraParams, void* z, + Nd4jLong const* zShapeInfo, void const* scalars, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + functions::scalar::ScalarIntTransform::template transformCuda( + x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); } - //////////////////////////////////////////////////////////////////////// template -__global__ void scalarSimpleShaped(void const* x, void const* y, Nd4jLong const* xShapeInfo, void *params, void *z, Nd4jLong const* zShapeInfo, int *allocationBuffer) { - - functions::scalar::ScalarIntTransform::template transformCuda(y, x, xShapeInfo, params, z, zShapeInfo, allocationBuffer); +__global__ void scalarSimpleShaped(void const* x, void const* y, + Nd4jLong const* xShapeInfo, void* params, + void* z, Nd4jLong const* zShapeInfo, + int* allocationBuffer) { + functions::scalar::ScalarIntTransform::template transformCuda( + y, x, xShapeInfo, params, z, zShapeInfo, allocationBuffer); } - - - - // *********************************************************************// // *********************************************************************// namespace functions { -namespace scalar { +namespace scalar { //////////////////////////////////////////////////////////////////////// -template -template -__device__ void ScalarIntTransform::transformCuda(void const* vscalar, - void const* vy, Nd4jLong const* yShapeInfo, - void *vparams, - void *vz, Nd4jLong const* zShapeInfo, - int *allocationBuffer) { - auto scalar = reinterpret_cast(vscalar)[0]; - auto y = reinterpret_cast(vy); - auto params = reinterpret_cast(vparams); - auto z = reinterpret_cast(vz); - - auto yRank = shape::rank(yShapeInfo); - auto yEWS = shape::elementWiseStride(yShapeInfo); - auto yShape = shape::shapeOf(yShapeInfo); - auto yStride = shape::stride(yShapeInfo); - - auto zRank = shape::rank(zShapeInfo); - auto zEWS = shape::elementWiseStride(zShapeInfo); - auto zShape = shape::shapeOf(zShapeInfo); - auto zStride = shape::stride(zShapeInfo); - - int totalThreads = gridDim.x * blockDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ int len; - if(threadIdx.x == 0) - len = shape::length(yShapeInfo); - __syncthreads(); - - if(yEWS >= 1 && zEWS >= 1 && shape::order(yShapeInfo) == shape::order(zShapeInfo)) { - transformCuda(len, vscalar, vy, yEWS, vparams, vz, zEWS, allocationBuffer); - } - else { - for (Nd4jLong i = tid; i < len; i+= totalThreads) - z[shape::getIndexOffset(i, zShapeInfo)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo)], scalar, params); - } +template +template +__device__ void ScalarIntTransform::transformCuda(void const* vscalar, + void const* vy, + Nd4jLong const* yShapeInfo, + void* vparams, void* vz, + Nd4jLong const* zShapeInfo, + int* allocationBuffer) { + auto scalar = reinterpret_cast(vscalar)[0]; + auto y = reinterpret_cast(vy); + auto params = reinterpret_cast(vparams); + auto z = reinterpret_cast(vz); + + auto yRank = shape::rank(yShapeInfo); + auto yEWS = shape::elementWiseStride(yShapeInfo); + auto yShape = shape::shapeOf(yShapeInfo); + auto yStride = shape::stride(yShapeInfo); + + auto zRank = shape::rank(zShapeInfo); + auto zEWS = shape::elementWiseStride(zShapeInfo); + auto zShape = shape::shapeOf(zShapeInfo); + auto zStride = shape::stride(zShapeInfo); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int len; + if (threadIdx.x == 0) len = shape::length(yShapeInfo); + __syncthreads(); + + if (yEWS >= 1 && zEWS >= 1 && + shape::order(yShapeInfo) == shape::order(zShapeInfo)) { + transformCuda(len, vscalar, vy, yEWS, vparams, vz, zEWS, + allocationBuffer); + } else { + for (Nd4jLong i = tid; i < len; i += totalThreads) + z[shape::getIndexOffset(i, zShapeInfo)] = + OpType::op(y[shape::getIndexOffset(i, yShapeInfo)], scalar, params); + } } //////////////////////////////////////////////////////////////////////// -template -template -__device__ void ScalarIntTransform::transformCuda(Nd4jLong len, - void const* vx, - void const* vy, Nd4jLong yEWS, - void *vparams, - void *vz, Nd4jLong zEWS, - int *allocationBuffer) { - - auto x = reinterpret_cast(vx)[0]; - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto params = reinterpret_cast(vparams); - - int totalThreads = gridDim.x * blockDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - Nd4jLong i = tid; - if(yEWS == 1 && zEWS == 1) { - for (; i < len; i += totalThreads) - z[i] = OpType::op(y[i], x, params); - } - else { - for (; i < len; i += totalThreads) - z[i * zEWS] = OpType::op(y[i * yEWS], x, params); - } +template +template +__device__ void ScalarIntTransform::transformCuda( + Nd4jLong len, void const* vx, void const* vy, Nd4jLong yEWS, void* vparams, + void* vz, Nd4jLong zEWS, int* allocationBuffer) { + auto x = reinterpret_cast(vx)[0]; + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto params = reinterpret_cast(vparams); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + Nd4jLong i = tid; + if (yEWS == 1 && zEWS == 1) { + for (; i < len; i += totalThreads) z[i] = OpType::op(y[i], x, params); + } else { + for (; i < len; i += totalThreads) + z[i * zEWS] = OpType::op(y[i * yEWS], x, params); + } } - //////////////////////////////////////////////////////////////////////// -template -template -__device__ void ScalarIntTransform::transformCuda(void const* vx, Nd4jLong const* xShapeInfo, - void *vextraParams, - void *vz, Nd4jLong const* zShapeInfo, - void const* vscalars, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - auto x = reinterpret_cast(vx); - auto scalars = reinterpret_cast(vscalars); - auto z = reinterpret_cast(vz); - auto extraParams = reinterpret_cast(vextraParams); - - if (tadShapeInfoZ == nullptr) { - tadShapeInfoZ = tadShapeInfo; - tadOffsetsZ = tadOffsets; +template +template +__device__ void ScalarIntTransform::transformCuda( + void const* vx, Nd4jLong const* xShapeInfo, void* vextraParams, void* vz, + Nd4jLong const* zShapeInfo, void const* vscalars, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + auto x = reinterpret_cast(vx); + auto scalars = reinterpret_cast(vscalars); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + if (tadShapeInfoZ == nullptr) { + tadShapeInfoZ = tadShapeInfo; + tadOffsetsZ = tadOffsets; + } + + // tad preparation + auto tadEws = shape::elementWiseStride(tadShapeInfo); + auto zEws = shape::elementWiseStride(tadShapeInfoZ); + auto tadLength = shape::length(tadShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + auto numTads = shape::length(xShapeInfo) / tadLength; + + if (tadEws > 0 && zEws > 0 && + shape::order(tadShapeInfo) == shape::order(zShapeInfo)) { + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + X* oZ = z + tadOffsetsZ[r]; + auto oX = x + tadOffsets[r]; + + auto s = scalars[r]; + + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams); } + } else { + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + X* oZ = z + tadOffsetsZ[r]; + auto oX = x + tadOffsets[r]; - // tad preparation - auto tadEws = shape::elementWiseStride(tadShapeInfo); - auto zEws = shape::elementWiseStride(tadShapeInfoZ); - auto tadLength = shape::length(tadShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - auto numTads =shape::length(xShapeInfo) / tadLength; - - if (tadEws > 0 && zEws > 0 && shape::order(tadShapeInfo) == shape::order(zShapeInfo)) { - - // main loop, rolling over tads - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - X *oZ = z + tadOffsetsZ[r]; - auto oX = x + tadOffsets[r]; - - auto s = scalars[r]; - - for (int f = threadIdx.x; f < tadLength; f += blockDim.x) - oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams); - } - } else { - // main loop, rolling over tads - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - X *oZ = z + tadOffsetsZ[r]; - auto oX = x + tadOffsets[r]; + auto s = scalars[r]; - auto s = scalars[r]; - - for (int f = threadIdx.x; f < tadLength; f += blockDim.x) - oZ[shape::getIndexOffset(f, tadShapeInfoZ)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo)], s, extraParams); - } + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[shape::getIndexOffset(f, tadShapeInfoZ)] = OpType::op( + oX[shape::getIndexOffset(f, tadShapeInfo)], s, extraParams); } + } } - //////////////////////////////////////////////////////////////////////// -template +template template -_CUDA_H void ScalarIntTransform::intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, - void const* x, Nd4jLong const* xShapeInfo, - void *z, Nd4jLong const* zShapeInfo, - void const* scalars, - void *extraParams, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - - scalarAlongDimension<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); +_CUDA_H void ScalarIntTransform::intermediateAlongDimension( + dim3& launchDims, cudaStream_t* stream, void const* x, + Nd4jLong const* xShapeInfo, void* z, Nd4jLong const* zShapeInfo, + void const* scalars, void* extraParams, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { + scalarAlongDimension + <<>>( + x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, + dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, + tadOffsetsZ); } //////////////////////////////////////////////////////////////////////// -template -template -void _CUDA_H ScalarIntTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, - void const* vx, Nd4jLong const* xShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void const* vscalar, - void *vextraParams, int *allocPointer){ - - scalarSimpleShaped<<>>(vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer); +template +template +void _CUDA_H ScalarIntTransform::intermediateShaped( + dim3& launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void const* vscalar, void* vextraParams, int* allocPointer) { + scalarSimpleShaped + <<>>( + vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer); } //////////////////////////////////////////////////////////////////////// -template -void ScalarIntTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - void const* vx, Nd4jLong const* xShapeInfo, - void *vz, Nd4jLong const* zShapeInfo, - void const* vscalar, - void* vextraParams) { - - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("H14 opNum:[%i]\n", opNum); - - DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalar, vextraParams, nullptr), SCALAR_INT_OPS); +template +void ScalarIntTransform::executeCudaShaped( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void const* vscalar, void* vextraParams) { + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("H14 opNum:[%i]\n", opNum); + + DISPATCH_BY_OPNUM_T(intermediateShaped, + PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, + vscalar, vextraParams, nullptr), + SCALAR_INT_OPS); } //////////////////////////////////////////////////////////////////////// -template -void ScalarIntTransform::executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, void *vz, Nd4jLong const* zShapeInfo, void const* vscalars, void *vextraParams, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, Nd4jLong const* tadOffsetsZ) { - DISPATCH_BY_OPNUM_T(intermediateAlongDimension, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), SCALAR_INT_OPS); +template +void ScalarIntTransform::executeCudaAlongDimension( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, + void const* vscalars, void* vextraParams, int* dimension, + int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, Nd4jLong const* tadShapeInfoZ, + Nd4jLong const* tadOffsetsZ) { + DISPATCH_BY_OPNUM_T( + intermediateAlongDimension, + PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, + vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, + tadShapeInfoZ, tadOffsetsZ), + SCALAR_INT_OPS); } - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ScalarIntTransform, , INTEGER_TYPES); - -} -} +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT ScalarIntTransform, , + INTEGER_TYPES); +} // namespace scalar +} // namespace functions diff --git a/libnd4j/include/loops/cuda/specials/accumulateKernel.cu b/libnd4j/include/loops/cuda/specials/accumulateKernel.cu index b07827f2638f..13851d358d2a 100644 --- a/libnd4j/include/loops/cuda/specials/accumulateKernel.cu +++ b/libnd4j/include/loops/cuda/specials/accumulateKernel.cu @@ -33,58 +33,62 @@ namespace sd { * @param n * @param length */ - template - __device__ void accumulateKernel(void **vx, void *vz, int n, const Nd4jLong length) { +template +__device__ void accumulateKernel(void **vx, void *vz, int n, + const Nd4jLong length) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); + __shared__ T *shmem; - __shared__ - T *shmem; + if (threadIdx.x == 0) { + extern __shared__ unsigned char sharedmem[]; + shmem = (T *)sharedmem; + } + __syncthreads(); - if (threadIdx.x == 0) { - extern __shared__ unsigned char sharedmem[]; - shmem = (T *) sharedmem; - } - __syncthreads(); + for (int r = blockDim.x * blockIdx.x; r < length; + r += blockDim.x * gridDim.x) { + shmem[threadIdx.x] = 0.0f; - for (int r = blockDim.x * blockIdx.x; r < length; r += blockDim.x * gridDim.x) { - shmem[threadIdx.x] = 0.0f; + Nd4jLong baseIdx = r; - Nd4jLong baseIdx = r; + // aggregation step, we roll over all arrays + for (int ar = 0; ar < n; ar++) { + T *cdata = (T *)x[ar]; + cdata += baseIdx; - // aggregation step, we roll over all arrays - for (int ar = 0; ar < n; ar++) { - T *cdata = (T *) x[ar]; - cdata += baseIdx; - - if (baseIdx + threadIdx.x < length) - shmem[threadIdx.x] += cdata[threadIdx.x]; - } - - T *wdata = z + baseIdx; - - // saving accumulated values - if (baseIdx + threadIdx.x < length) - wdata[threadIdx.x] = shmem[threadIdx.x]; - } + if (baseIdx + threadIdx.x < length) + shmem[threadIdx.x] += cdata[threadIdx.x]; } -/////////////////////////////////////////////////////////////////////// - template - __global__ void execAccumulateKernel(void **vx, void *vz, int n, const Nd4jLong length) { + T *wdata = z + baseIdx; - accumulateKernel(vx, vz, n, length); - } + // saving accumulated values + if (baseIdx + threadIdx.x < length) wdata[threadIdx.x] = shmem[threadIdx.x]; + } +} /////////////////////////////////////////////////////////////////////// - template - __host__ void - accumulateKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void **vx, void *vz, int n, const Nd4jLong length) { - - execAccumulateKernel<<< launchDims.x, launchDims.y, launchDims.z, *stream>>> (vx, vz, n, length); - sd::DebugHelper::checkErrorCode(stream, "accumulate(...) failed"); - } +template +__global__ void execAccumulateKernel(void **vx, void *vz, int n, + const Nd4jLong length) { + accumulateKernel(vx, vz, n, length); +} - BUILD_SINGLE_TEMPLATE(template void SD_EXPORT accumulateKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * *vx, void * vz, int n, const Nd4jLong length), LIBND4J_TYPES); -} \ No newline at end of file +/////////////////////////////////////////////////////////////////////// +template +__host__ void accumulateKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + void **vx, void *vz, int n, + const Nd4jLong length) { + execAccumulateKernel + <<>>(vx, vz, n, + length); + sd::DebugHelper::checkErrorCode(stream, "accumulate(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT accumulateKernelGeneric, + (dim3 & launchDims, cudaStream_t *stream, void **vx, + void *vz, int n, const Nd4jLong length), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/averagingKernel.cu b/libnd4j/include/loops/cuda/specials/averagingKernel.cu index fe8798d74880..3e0d1e6f155d 100644 --- a/libnd4j/include/loops/cuda/specials/averagingKernel.cu +++ b/libnd4j/include/loops/cuda/specials/averagingKernel.cu @@ -24,81 +24,79 @@ namespace sd { /////////////////////////////////////////////////////////////////////// - template - __device__ void averagingKernel(void **vdx, void *vdz, int n, Nd4jLong length, bool propagate) { - - auto dx = reinterpret_cast(vdx); - auto dz = reinterpret_cast(vdz); - - __shared__ - T *shmem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char sharedmem[]; - shmem = (T *) sharedmem; - } - __syncthreads(); - - - // each block cycles over it's own part of arrays - for (int r = blockDim.x * blockIdx.x; r < length; r += blockDim.x * gridDim.x) { - shmem[threadIdx.x] = (T) 0.0f; - - Nd4jLong baseIdx = r; - - // aggregation step, we roll over all arrays - for (int ar = 0; ar < n; ar++) { - T *cdata = (T *) dx[ar]; - cdata += baseIdx; - - if (baseIdx + threadIdx.x < length) - shmem[threadIdx.x] += cdata[threadIdx.x]; - } - - - // average data in shared memory - if (baseIdx + threadIdx.x < length) - shmem[threadIdx.x] /= n; - - // div step & write out step - if (dz != nullptr) { - T *wdata = dz + baseIdx; - - if (baseIdx + threadIdx.x < length) { - wdata[threadIdx.x] = shmem[threadIdx.x]; - } - } - - // propagate averaged data to all arrays - if (propagate) - for (int ar = 0; ar < n; ar++) { - T *cdata = (T *) dx[ar]; - cdata += baseIdx; - - if (baseIdx + threadIdx.x < length) - cdata[threadIdx.x] = shmem[threadIdx.x]; - } - } +template +__device__ void averagingKernel(void **vdx, void *vdz, int n, Nd4jLong length, + bool propagate) { + auto dx = reinterpret_cast(vdx); + auto dz = reinterpret_cast(vdz); + + __shared__ T *shmem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char sharedmem[]; + shmem = (T *)sharedmem; + } + __syncthreads(); + + // each block cycles over it's own part of arrays + for (int r = blockDim.x * blockIdx.x; r < length; + r += blockDim.x * gridDim.x) { + shmem[threadIdx.x] = (T)0.0f; + + Nd4jLong baseIdx = r; + + // aggregation step, we roll over all arrays + for (int ar = 0; ar < n; ar++) { + T *cdata = (T *)dx[ar]; + cdata += baseIdx; + + if (baseIdx + threadIdx.x < length) + shmem[threadIdx.x] += cdata[threadIdx.x]; } + // average data in shared memory + if (baseIdx + threadIdx.x < length) shmem[threadIdx.x] /= n; -/////////////////////////////////////////////////////////////////////// - template - __global__ void execAveragingKernel(void **vdx, void *vdz, int n, Nd4jLong length, bool propagate) { + // div step & write out step + if (dz != nullptr) { + T *wdata = dz + baseIdx; - averagingKernel(vdx, vdz, n, length, propagate); + if (baseIdx + threadIdx.x < length) { + wdata[threadIdx.x] = shmem[threadIdx.x]; + } } + // propagate averaged data to all arrays + if (propagate) + for (int ar = 0; ar < n; ar++) { + T *cdata = (T *)dx[ar]; + cdata += baseIdx; -/////////////////////////////////////////////////////////////////////// - template - __host__ void - averagingKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void **vdx, void *vdz, int n, Nd4jLong length, - bool propagate) { + if (baseIdx + threadIdx.x < length) + cdata[threadIdx.x] = shmem[threadIdx.x]; + } + } +} - execAveragingKernel<<< launchDims.x, launchDims.y, launchDims.z, *stream>>>(vdx, vdz, n, length, propagate); - sd::DebugHelper::checkErrorCode(stream, "averaging(...) failed"); - } +/////////////////////////////////////////////////////////////////////// +template +__global__ void execAveragingKernel(void **vdx, void *vdz, int n, + Nd4jLong length, bool propagate) { + averagingKernel(vdx, vdz, n, length, propagate); +} - BUILD_SINGLE_TEMPLATE(template void SD_EXPORT averagingKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * *vdx, void * vdz, int n, Nd4jLong length, bool propagate), LIBND4J_TYPES); -} \ No newline at end of file +/////////////////////////////////////////////////////////////////////// +template +__host__ void averagingKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + void **vdx, void *vdz, int n, + Nd4jLong length, bool propagate) { + execAveragingKernel<<>>( + vdx, vdz, n, length, propagate); + sd::DebugHelper::checkErrorCode(stream, "averaging(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT averagingKernelGeneric, + (dim3 & launchDims, cudaStream_t *stream, void **vdx, + void *vdz, int n, Nd4jLong length, bool propagate), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu b/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu index ac79e807c1f7..103c76aa6f99 100644 --- a/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu +++ b/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu @@ -23,168 +23,194 @@ ////////////////////////////////////////////////////////////////////////// template -__global__ void bitonicArbitraryStepKernelKey(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int window, int length, int reverse, bool descending) { - auto x = static_cast(vx); - auto y = static_cast(vy); - - int tid = threadIdx.x + blockDim.x * blockIdx.x; - int half = window>>1; - - __shared__ Nd4jLong xLength; - if (threadIdx.x == 0) { - xLength = shape::length(xShapeInfo); - } - __syncthreads(); - - //for (int i = 0; i < length; i+= window) - /* - if window == 4; - iterations will be: 0; 4; 8; 12; 16; 20 - if gridDim = 3; - on first iteration we'll have: 0; 4; 8; - on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20 - */ - int firstPosition; - int firstStep; - int secondPosition; - int secondStep; - - int WARP_SIZE = 32; - int numWarps = (gridDim.x * blockDim.x) / 32; - int warpId = tid / WARP_SIZE; - int warpIdx = tid % WARP_SIZE; - - if (half >= 128) { - firstPosition = blockIdx.x * window; - firstStep = gridDim.x * window; - - secondPosition = threadIdx.x; - secondStep = blockDim.x; - } else if (half >= 32) { - firstPosition = warpId * window; - firstStep = numWarps * window; - - secondPosition = warpIdx; - secondStep = WARP_SIZE; - } else { - firstPosition = tid * window; - firstStep = blockDim.x * gridDim.x * window; - - secondPosition = 0; - secondStep = 1; - } - - - for (int i = firstPosition; i < length; i += firstStep) { - for (int j = secondPosition; j < half; j += secondStep) { - int it = (reverse) ? i + j + half : i + window - j - 1; - int ij = i+j; - if (it < length && ij < length ) { - int posIT = shape::getIndexOffset(it, xShapeInfo); - int posIJ = shape::getIndexOffset(ij, xShapeInfo); - - X v0 = x[posIJ]; - X v1 = x[posIT]; - - if(!descending == (v0 > v1)) { - x[posIJ] = v1; - x[posIT] = v0; - - Y ytemp = y[posIJ]; - y[posIJ] = y[posIT]; - y[posIT] = ytemp; - } - } +__global__ void bitonicArbitraryStepKernelKey( + void *vx, Nd4jLong const *xShapeInfo, void *vy, Nd4jLong const *yShapeInfo, + int window, int length, int reverse, bool descending) { + auto x = static_cast(vx); + auto y = static_cast(vy); + + int tid = threadIdx.x + blockDim.x * blockIdx.x; + int half = window >> 1; + + __shared__ Nd4jLong xLength; + if (threadIdx.x == 0) { + xLength = shape::length(xShapeInfo); + } + __syncthreads(); + + // for (int i = 0; i < length; i+= window) + /* + if window == 4; + iterations will be: 0; 4; 8; 12; 16; 20 + if gridDim = 3; + on first iteration we'll have: 0; 4; 8; + on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + + (3 * 4) = 20 + */ + int firstPosition; + int firstStep; + int secondPosition; + int secondStep; + + int WARP_SIZE = 32; + int numWarps = (gridDim.x * blockDim.x) / 32; + int warpId = tid / WARP_SIZE; + int warpIdx = tid % WARP_SIZE; + + if (half >= 128) { + firstPosition = blockIdx.x * window; + firstStep = gridDim.x * window; + + secondPosition = threadIdx.x; + secondStep = blockDim.x; + } else if (half >= 32) { + firstPosition = warpId * window; + firstStep = numWarps * window; + + secondPosition = warpIdx; + secondStep = WARP_SIZE; + } else { + firstPosition = tid * window; + firstStep = blockDim.x * gridDim.x * window; + + secondPosition = 0; + secondStep = 1; + } + + for (int i = firstPosition; i < length; i += firstStep) { + for (int j = secondPosition; j < half; j += secondStep) { + int it = (reverse) ? i + j + half : i + window - j - 1; + int ij = i + j; + if (it < length && ij < length) { + int posIT = shape::getIndexOffset(it, xShapeInfo); + int posIJ = shape::getIndexOffset(ij, xShapeInfo); + + X v0 = x[posIJ]; + X v1 = x[posIT]; + + if (!descending == (v0 > v1)) { + x[posIJ] = v1; + x[posIT] = v0; + + Y ytemp = y[posIJ]; + y[posIJ] = y[posIT]; + y[posIT] = ytemp; } + } } + } } ////////////////////////////////////////////////////////////////////////// -template -__global__ void execBitonicArbitraryStepKernel(void *vx, Nd4jLong const* xShapeInfo, int window, int length, int reverse, bool descending) { - auto x = static_cast(vx); - - int tid = threadIdx.x + blockDim.x * blockIdx.x; - int half = window>>1; - - __shared__ T *shmem; - __shared__ Nd4jLong xLength; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shrd[]; - shmem = (T *) shrd; - xLength = shape::length(xShapeInfo); - } - __syncthreads(); - - //for (int i = 0; i < length; i+= window) - /* - if window == 4; - iterations will be: 0; 4; 8; 12; 16; 20 - if gridDim = 3; - on first iteration we'll have: 0; 4; 8; - on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20 - */ - int firstPosition; - int firstStep; - int secondPosition; - int secondStep; - - int WARP_SIZE = 32; - int numWarps = (gridDim.x * blockDim.x) / 32; - int warpId = tid / WARP_SIZE; - int warpIdx = tid % WARP_SIZE; - - if (half >= 128) { - firstPosition = blockIdx.x * window; - firstStep = gridDim.x * window; - - secondPosition = threadIdx.x; - secondStep = blockDim.x; - } else if (half >= 32) { - firstPosition = warpId * window; - firstStep = numWarps * window; - - secondPosition = warpIdx; - secondStep = WARP_SIZE; - } else { - firstPosition = tid * window; - firstStep = blockDim.x * gridDim.x * window; - - secondPosition = 0; - secondStep = 1; - } - - - for (int i = firstPosition; i < length; i += firstStep) { - for (int j = secondPosition; j < half; j += secondStep) { - int it = (reverse) ? i + j + half : i + window - j - 1; - int ij = i+j; - if (it < length && ij < length ) { - int posIT = shape::getIndexOffset(it, xShapeInfo); - int posIJ = shape::getIndexOffset(ij, xShapeInfo); - - shmem[threadIdx.x] = x[posIJ]; - shmem[threadIdx.x + blockDim.x] = x[posIT]; - - if(!descending == (shmem[threadIdx.x] > shmem[threadIdx.x + blockDim.x])) { - x[posIJ] = shmem[threadIdx.x + blockDim.x]; - x[posIT] = shmem[threadIdx.x]; - } - } +template +__global__ void execBitonicArbitraryStepKernel(void *vx, + Nd4jLong const *xShapeInfo, + int window, int length, + int reverse, bool descending) { + auto x = static_cast(vx); + + int tid = threadIdx.x + blockDim.x * blockIdx.x; + int half = window >> 1; + + __shared__ T *shmem; + __shared__ Nd4jLong xLength; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shrd[]; + shmem = (T *)shrd; + xLength = shape::length(xShapeInfo); + } + __syncthreads(); + + // for (int i = 0; i < length; i+= window) + /* + if window == 4; + iterations will be: 0; 4; 8; 12; 16; 20 + if gridDim = 3; + on first iteration we'll have: 0; 4; 8; + on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + + (3 * 4) = 20 + */ + int firstPosition; + int firstStep; + int secondPosition; + int secondStep; + + int WARP_SIZE = 32; + int numWarps = (gridDim.x * blockDim.x) / 32; + int warpId = tid / WARP_SIZE; + int warpIdx = tid % WARP_SIZE; + + if (half >= 128) { + firstPosition = blockIdx.x * window; + firstStep = gridDim.x * window; + + secondPosition = threadIdx.x; + secondStep = blockDim.x; + } else if (half >= 32) { + firstPosition = warpId * window; + firstStep = numWarps * window; + + secondPosition = warpIdx; + secondStep = WARP_SIZE; + } else { + firstPosition = tid * window; + firstStep = blockDim.x * gridDim.x * window; + + secondPosition = 0; + secondStep = 1; + } + + for (int i = firstPosition; i < length; i += firstStep) { + for (int j = secondPosition; j < half; j += secondStep) { + int it = (reverse) ? i + j + half : i + window - j - 1; + int ij = i + j; + if (it < length && ij < length) { + int posIT = shape::getIndexOffset(it, xShapeInfo); + int posIJ = shape::getIndexOffset(ij, xShapeInfo); + + shmem[threadIdx.x] = x[posIJ]; + shmem[threadIdx.x + blockDim.x] = x[posIT]; + + if (!descending == + (shmem[threadIdx.x] > shmem[threadIdx.x + blockDim.x])) { + x[posIJ] = shmem[threadIdx.x + blockDim.x]; + x[posIT] = shmem[threadIdx.x]; } + } } + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ void bitonicArbitraryStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int window, int length, int reverse, bool descending) { - execBitonicArbitraryStepKernel<<>>(vx, xShapeInfo, window, length, reverse, descending); +template +__host__ void bitonicArbitraryStepGeneric(dim3 &launchDims, + cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, + int window, int length, int reverse, + bool descending) { + execBitonicArbitraryStepKernel + <<>>( + vx, xShapeInfo, window, length, reverse, descending); } template -__host__ void bitonicArbitraryStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int window, int length, int reverse, bool descending) { - bitonicArbitraryStepKernelKey<<>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending); +__host__ void bitonicArbitraryStepGenericKey( + dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, Nd4jLong const *yShapeInfo, + int window, int length, int reverse, bool descending) { + bitonicArbitraryStepKernelKey + <<>>( + vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending); } -BUILD_SINGLE_TEMPLATE(template void SD_EXPORT bitonicArbitraryStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES); -BUILD_DOUBLE_TEMPLATE(template void SD_EXPORT bitonicArbitraryStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT bitonicArbitraryStepGeneric, + (dim3 & launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, int window, int length, + int reverse, bool descending), + LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template void SD_EXPORT bitonicArbitraryStepGenericKey, + (dim3 & launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int window, int length, + int reverse, bool descending), + LIBND4J_TYPES, LIBND4J_TYPES); diff --git a/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu b/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu index 2297f8f93d2e..e76a972d2076 100644 --- a/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu +++ b/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu @@ -21,120 +21,135 @@ #include - ////////////////////////////////////////////////////////////////////////// template -__global__ void bitonicSortStepKernelKey(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int j, int k, int length, bool descending) { - - auto x = static_cast(vx); - auto y = static_cast(vy); - - unsigned int i, ixj; /* Sorting partners: i and ixj */ - i = threadIdx.x + blockDim.x * blockIdx.x; - - __shared__ Nd4jLong xLength; - if (threadIdx.x == 0) - xLength = shape::length(xShapeInfo); - - __syncthreads(); - - - if (i >= length) - return; - - ixj = i^j; - - /* The threads with the lowest ids sort the array. */ - if ((ixj)>i) { - int posI = shape::getIndexOffset(i, xShapeInfo); - int posIXJ = shape::getIndexOffset(ixj, xShapeInfo); - - if ((i&k)==0) { - /* Sort ascending */ - if (!descending == (x[posI]>x[posIXJ])) { - /* exchange(i,ixj); */ - X temp = x[posI]; - x[posI] = x[posIXJ]; - x[posIXJ] = temp; - - Y ytemp = y[posI]; - y[posI] = y[posIXJ]; - y[posIXJ] = ytemp; - } - } else if ((i&k)!=0) { - /* Sort descending */ - if (!descending == (x[posI](vx); + auto y = static_cast(vy); + + unsigned int i, ixj; /* Sorting partners: i and ixj */ + i = threadIdx.x + blockDim.x * blockIdx.x; + + __shared__ Nd4jLong xLength; + if (threadIdx.x == 0) xLength = shape::length(xShapeInfo); + + __syncthreads(); + + if (i >= length) return; + + ixj = i ^ j; + + /* The threads with the lowest ids sort the array. */ + if ((ixj) > i) { + int posI = shape::getIndexOffset(i, xShapeInfo); + int posIXJ = shape::getIndexOffset(ixj, xShapeInfo); + + if ((i & k) == 0) { + /* Sort ascending */ + if (!descending == (x[posI] > x[posIXJ])) { + /* exchange(i,ixj); */ + X temp = x[posI]; + x[posI] = x[posIXJ]; + x[posIXJ] = temp; + + Y ytemp = y[posI]; + y[posI] = y[posIXJ]; + y[posIXJ] = ytemp; + } + } else if ((i & k) != 0) { + /* Sort descending */ + if (!descending == (x[posI] < x[posIXJ])) { + /* exchange(i,ixj); */ + X temp = x[posI]; + x[posI] = x[posIXJ]; + x[posIXJ] = temp; + + Y ytemp = y[posI]; + y[posI] = y[posIXJ]; + y[posIXJ] = ytemp; + } } + } } ////////////////////////////////////////////////////////////////////////// -template -__global__ void bitonicSortStepKernel(void *vx, Nd4jLong const* xShapeInfo, int j, int k, int length, bool descending) { - - auto x = static_cast(vx); - - unsigned int i, ixj; /* Sorting partners: i and ixj */ - i = threadIdx.x + blockDim.x * blockIdx.x; - - __shared__ Nd4jLong xLength; - if (threadIdx.x == 0) - xLength = shape::length(xShapeInfo); - - __syncthreads(); - - - if (i >= length) - return; - - ixj = i^j; - - /* The threads with the lowest ids sort the array. */ - if ((ixj)>i) { - int posI = shape::getIndexOffset(i, xShapeInfo); - int posIXJ = shape::getIndexOffset(ixj, xShapeInfo); - - if ((i&k)==0) { - /* Sort ascending */ - if (!descending == (x[posI]>x[posIXJ])) { - /* exchange(i,ixj); */ - T temp = x[posI]; - x[posI] = x[posIXJ]; - x[posIXJ] = temp; - } - } else if ((i&k)!=0) { - /* Sort descending */ - if (!descending == (x[posI] +__global__ void bitonicSortStepKernel(void *vx, Nd4jLong const *xShapeInfo, + int j, int k, int length, + bool descending) { + auto x = static_cast(vx); + + unsigned int i, ixj; /* Sorting partners: i and ixj */ + i = threadIdx.x + blockDim.x * blockIdx.x; + + __shared__ Nd4jLong xLength; + if (threadIdx.x == 0) xLength = shape::length(xShapeInfo); + + __syncthreads(); + + if (i >= length) return; + + ixj = i ^ j; + + /* The threads with the lowest ids sort the array. */ + if ((ixj) > i) { + int posI = shape::getIndexOffset(i, xShapeInfo); + int posIXJ = shape::getIndexOffset(ixj, xShapeInfo); + + if ((i & k) == 0) { + /* Sort ascending */ + if (!descending == (x[posI] > x[posIXJ])) { + /* exchange(i,ixj); */ + T temp = x[posI]; + x[posI] = x[posIXJ]; + x[posIXJ] = temp; + } + } else if ((i & k) != 0) { + /* Sort descending */ + if (!descending == (x[posI] < x[posIXJ])) { + /* exchange(i,ixj); */ + T temp = x[posI]; + x[posI] = x[posIXJ]; + x[posIXJ] = temp; + } } + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int j, int k, int length, bool descending) { - bitonicSortStepKernel<<>>(vx, xShapeInfo, j, k, length, descending); +template +__host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, + void *vx, Nd4jLong const *xShapeInfo, + int j, int k, int length, + bool descending) { + bitonicSortStepKernel + <<>>( + vx, xShapeInfo, j, k, length, descending); } ////////////////////////////////////////////////////////////////////////// template -__host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int j, int k, int length, bool descending) { - bitonicSortStepKernelKey<<>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending); +__host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream, + void *vx, Nd4jLong const *xShapeInfo, + void *vy, Nd4jLong const *yShapeInfo, + int j, int k, int length, + bool descending) { + bitonicSortStepKernelKey + <<>>( + vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending); } - -BUILD_SINGLE_TEMPLATE(template void SD_EXPORT bitonicSortStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES); -BUILD_DOUBLE_TEMPLATE(template void SD_EXPORT bitonicSortStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT bitonicSortStepGeneric, + (dim3 & launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, int j, int k, int length, + bool descending), + LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template void SD_EXPORT bitonicSortStepGenericKey, + (dim3 & launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int j, int k, int length, + bool descending), + LIBND4J_TYPES, LIBND4J_TYPES); diff --git a/libnd4j/include/loops/cuda/specials/concatKernel.cu b/libnd4j/include/loops/cuda/specials/concatKernel.cu index 0d1c8485458d..cbd9e258425e 100644 --- a/libnd4j/include/loops/cuda/specials/concatKernel.cu +++ b/libnd4j/include/loops/cuda/specials/concatKernel.cu @@ -23,248 +23,248 @@ namespace sd { /////////////////////////////////////////////////////////////////////// - template - __device__ void concatKernel(int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *resultShapeInfo, - Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers, - Nd4jLong *zTadShape, Nd4jLong *zOffsets) { - - int tid = threadIdx.x + blockIdx.x * blockDim.x; - - int zRank = shape::rank(resultShapeInfo); - - auto result = reinterpret_cast(vz); - auto dataT = reinterpret_cast(data); - auto shapeInfoPointers = reinterpret_cast(inputShapeInfos); - auto tadShapes = reinterpret_cast(tadPointers); - auto tadOffsets = reinterpret_cast(offsetPointers); - - //if (threadIdx.x == 0 && blockIdx.x == 0) { - // shape::printShapeInfoLinear("zTadShape", zTadShape); - //} - - //__shared__ int tDim[1]; - __shared__ int baseIdx; - - __shared__ int yLength; - __shared__ char yOrder; - __shared__ int yEWS; - - char zOrder = shape::order(resultShapeInfo); - - int zEWS = shape::elementWiseStride(resultShapeInfo); - int tadEWS = shape::elementWiseStride(zTadShape); - int zLength = shape::length(resultShapeInfo); - - __shared__ int arrOffset; - __shared__ int numTads; - - - if (shape::isVector(resultShapeInfo)) { - //if (threadIdx.x == 0 && blockIdx.x == 0) - // printf("Vector here\n"); - - if (zEWS >= 1) { - for (int r = blockIdx.x; r < numArrays; r += gridDim.x) { - if(shape::isVector(shapeInfoPointers[r]) || shape::order(shapeInfoPointers[r]) == shape::order(resultShapeInfo)) { - yLength = shape::length(shapeInfoPointers[r]); - yEWS = shape::elementWiseStride(shapeInfoPointers[r]); - // FIXME: this is bad - __shared__ int baseIdx; - if (threadIdx.x == 0) { - baseIdx = 0; - for (int f = 0; f < r; f++) { - baseIdx += shape::length(shapeInfoPointers[f]); - } - } - __syncthreads(); - for (int i = threadIdx.x; i < yLength && baseIdx + i < zLength; i += blockDim.x) { - result[baseIdx + i * zEWS] = dataT[r][i * yEWS]; - } - __syncthreads(); - } else { - if (tid == 0) - printf("Non-matched order for vector\n"); - } - } - } else { - if (tid == 0) - printf("Vector Non-1 zEWS\n"); +template +__device__ void concatKernel(int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *resultShapeInfo, + Nd4jPointer *tadPointers, + Nd4jPointer *offsetPointers, Nd4jLong *zTadShape, + Nd4jLong *zOffsets) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + + int zRank = shape::rank(resultShapeInfo); + + auto result = reinterpret_cast(vz); + auto dataT = reinterpret_cast(data); + auto shapeInfoPointers = reinterpret_cast(inputShapeInfos); + auto tadShapes = reinterpret_cast(tadPointers); + auto tadOffsets = reinterpret_cast(offsetPointers); + + // if (threadIdx.x == 0 && blockIdx.x == 0) { + // shape::printShapeInfoLinear("zTadShape", zTadShape); + //} + + //__shared__ int tDim[1]; + __shared__ int baseIdx; + + __shared__ int yLength; + __shared__ char yOrder; + __shared__ int yEWS; + + char zOrder = shape::order(resultShapeInfo); + + int zEWS = shape::elementWiseStride(resultShapeInfo); + int tadEWS = shape::elementWiseStride(zTadShape); + int zLength = shape::length(resultShapeInfo); + + __shared__ int arrOffset; + __shared__ int numTads; + + if (shape::isVector(resultShapeInfo)) { + // if (threadIdx.x == 0 && blockIdx.x == 0) + // printf("Vector here\n"); + + if (zEWS >= 1) { + for (int r = blockIdx.x; r < numArrays; r += gridDim.x) { + if (shape::isVector(shapeInfoPointers[r]) || + shape::order(shapeInfoPointers[r]) == + shape::order(resultShapeInfo)) { + yLength = shape::length(shapeInfoPointers[r]); + yEWS = shape::elementWiseStride(shapeInfoPointers[r]); + // FIXME: this is bad + __shared__ int baseIdx; + if (threadIdx.x == 0) { + baseIdx = 0; + for (int f = 0; f < r; f++) { + baseIdx += shape::length(shapeInfoPointers[f]); } - return; + } + __syncthreads(); + for (int i = threadIdx.x; i < yLength && baseIdx + i < zLength; + i += blockDim.x) { + result[baseIdx + i * zEWS] = dataT[r][i * yEWS]; + } + __syncthreads(); + } else { + if (tid == 0) printf("Non-matched order for vector\n"); } + } + } else { + if (tid == 0) printf("Vector Non-1 zEWS\n"); + } + return; + } + + bool _vec = shape::isVector(resultShapeInfo); + + // TODO: to be pulled into separate kernel. matrix concatenation + for (int r = 0; r < numArrays; r++) { + auto currentShape = shapeInfoPointers[r]; + auto currentData = dataT[r]; + auto currentTad = tadShapes[r]; + auto currentOffsets = tadOffsets[r]; + + if (threadIdx.x == 0) { + yLength = shape::length(currentTad); + yOrder = shape::order(currentTad); + yEWS = shape::elementWiseStride(currentTad); + numTads = shape::length(currentShape) / yLength; + + arrOffset = 0; + for (int f = 0; f < r; f++) { + arrOffset += shape::length(tadShapes[f]); + } + + // if (threadIdx.x == 0 && blockIdx.x == 0) { + // shape::printShapeInfoLinear("currentTad", currentTad); + //} + } + __syncthreads(); + if (yLength == 1 && _vec) { + // if (threadIdx.x == 0 && blockIdx.x == 0) + // printf("Branch 0\n"); - bool _vec = shape::isVector(resultShapeInfo); + // edge case, each thread will handle it's own tad then + for (int j = tid; j < numTads; j += blockDim.x * gridDim.x) { + Nd4jLong inputOffset = currentOffsets[j]; + Nd4jLong resultOffset = zOffsets[j]; + T *dataTAD = currentData + inputOffset; + T *resultTAD = result + resultOffset; - // TODO: to be pulled into separate kernel. matrix concatenation - for (int r = 0; r < numArrays; r ++) { + int sub[MAX_RANK]; - auto currentShape = shapeInfoPointers[r]; - auto currentData = dataT[r]; - auto currentTad = tadShapes[r]; - auto currentOffsets = tadOffsets[r]; + shape::index2coords(arrOffset, zTadShape, sub); + Nd4jLong baseOffset = shape::getOffset(zTadShape, sub); - if (threadIdx.x == 0) { - yLength = shape::length(currentTad); - yOrder = shape::order(currentTad); - yEWS = shape::elementWiseStride(currentTad); - numTads = shape::length(currentShape) / yLength; - - arrOffset = 0; - for (int f = 0; f < r; f++) { - arrOffset += shape::length(tadShapes[f]); - } - - //if (threadIdx.x == 0 && blockIdx.x == 0) { - // shape::printShapeInfoLinear("currentTad", currentTad); - //} - } - __syncthreads(); + resultTAD += baseOffset; + + auto yRank = shape::rank(currentTad); + auto tadRank = shape::rank(zTadShape); + + shape::index2coords(0, currentTad, sub); - if (yLength == 1 && _vec) { - //if (threadIdx.x == 0 && blockIdx.x == 0) - // printf("Branch 0\n"); + auto yOffset = shape::getOffset(currentTad, sub); + resultOffset = shape::getOffset(zTadShape, sub); - // edge case, each thread will handle it's own tad then - for (int j = tid; j < numTads; j += blockDim.x * gridDim.x) { - Nd4jLong inputOffset = currentOffsets[j]; - Nd4jLong resultOffset = zOffsets[j]; + resultTAD[resultOffset] = dataTAD[yOffset]; + } + } else { + // if (threadIdx.x == 0 && blockIdx.x == 0) + // printf("Branch 1\n"); - T *dataTAD = currentData + inputOffset; - T *resultTAD = result + resultOffset; + for (int j = blockIdx.x; j < numTads; j += gridDim.x) { + auto inputOffset = currentOffsets[j]; + auto resultOffset = zOffsets[j]; - int sub[MAX_RANK]; + auto dataTAD = currentData + inputOffset; + auto resultTAD = result + resultOffset; - shape::index2coords(arrOffset, zTadShape, sub); + int sub[MAX_RANK]; - Nd4jLong baseOffset = shape::getOffset(zTadShape, sub); + shape::index2coords(arrOffset, zTadShape, sub); + Nd4jLong baseOffset = shape::getOffset(zTadShape, sub); - resultTAD += baseOffset; + resultTAD += baseOffset; - auto yRank = shape::rank(currentTad); - auto tadRank = shape::rank(zTadShape); + if (zOrder == yOrder && yEWS > 0 && tadEWS > 0) { + // if (threadIdx.x == 0 && blockIdx.x == 0) + // printf("Branch A\n"); - shape::index2coords(0, currentTad, sub); + for (int i = threadIdx.x; i < yLength; i += blockDim.x) { + resultTAD[i * tadEWS] = dataTAD[i * yEWS]; + } + } else { + if (tadEWS > 0 && + shape::order(resultShapeInfo) == shape::order(currentTad)) { + // if (threadIdx.x == 0 && blockIdx.x == 0) + // printf("Branch B\n"); - auto yOffset = shape::getOffset(currentTad, sub); - resultOffset = shape::getOffset(zTadShape, sub); + if (threadIdx.x == 0) { + baseIdx = 0; + for (int f = 0; f < r; f++) { + baseIdx += shape::length(shapeInfoPointers[f]); + } + // printf("R: %i; baseIdx: %i;\n", baseIdx); + } + __syncthreads(); - resultTAD[resultOffset] = dataTAD[yOffset]; - } + if (numTads == 1) { + for (int k = threadIdx.x; k < yLength; k += blockDim.x) { + resultTAD[baseIdx + k * tadEWS] = dataTAD[k]; + } } else { - //if (threadIdx.x == 0 && blockIdx.x == 0) - // printf("Branch 1\n"); - - for (int j = blockIdx.x; j < numTads; j += gridDim.x) { - auto inputOffset = currentOffsets[j]; - auto resultOffset = zOffsets[j]; - - auto dataTAD = currentData + inputOffset; - auto resultTAD = result + resultOffset; - - int sub[MAX_RANK]; - - shape::index2coords(arrOffset, zTadShape, sub); - Nd4jLong baseOffset = shape::getOffset(zTadShape, sub); - - resultTAD += baseOffset; - - if (zOrder == yOrder && yEWS > 0 && tadEWS > 0) { - //if (threadIdx.x == 0 && blockIdx.x == 0) - // printf("Branch A\n"); - - for (int i = threadIdx.x; i < yLength; i += blockDim.x) { - resultTAD[i * tadEWS] = dataTAD[i * yEWS]; - } - } else { - if(tadEWS > 0 && shape::order(resultShapeInfo) == shape::order(currentTad)) { - //if (threadIdx.x == 0 && blockIdx.x == 0) - // printf("Branch B\n"); - - if (threadIdx.x == 0) { - baseIdx = 0; - for (int f = 0; f < r; f++) { - baseIdx += shape::length(shapeInfoPointers[f]); - } - //printf("R: %i; baseIdx: %i;\n", baseIdx); - } - __syncthreads(); - - if (numTads == 1) { - for(int k = threadIdx.x; k < yLength; k+= blockDim.x) { - resultTAD[baseIdx + k * tadEWS] = dataTAD[k]; - } - } else { - int yIdx[MAX_RANK]; - auto yRank = shape::rank(currentTad); - - for (int i = threadIdx.x; i < yLength; i+= blockDim.x) { - shape::index2coords(i, currentTad, yIdx); - auto yOffset = shape::getOffset(currentTad, yIdx); - - resultTAD[baseIdx + i * tadEWS] = dataTAD[yOffset]; - } - } - __syncthreads(); - } else { - //if (threadIdx.x == 0 && blockIdx.x == 0) - // printf("Branch C; yLength: %i;\n", yLength); - - int zIdx[MAX_RANK]; - int yIdx[MAX_RANK]; - auto yRank = shape::rank(currentTad); - auto tadRank = shape::rank(zTadShape); - - for (int i = threadIdx.x; i < yLength; i+= blockDim.x) { - shape::index2coords(i, currentTad, yIdx); - shape::index2coords(i, zTadShape, zIdx); - - auto yOffset = shape::getOffset(currentTad, yIdx); - auto resultOffset = shape::getOffset(zTadShape, zIdx); - - resultTAD[resultOffset] = dataTAD[yOffset]; - } - } - } - __syncthreads(); - } + int yIdx[MAX_RANK]; + auto yRank = shape::rank(currentTad); + + for (int i = threadIdx.x; i < yLength; i += blockDim.x) { + shape::index2coords(i, currentTad, yIdx); + auto yOffset = shape::getOffset(currentTad, yIdx); + + resultTAD[baseIdx + i * tadEWS] = dataTAD[yOffset]; + } } __syncthreads(); + } else { + // if (threadIdx.x == 0 && blockIdx.x == 0) + // printf("Branch C; yLength: %i;\n", yLength); + + int zIdx[MAX_RANK]; + int yIdx[MAX_RANK]; + auto yRank = shape::rank(currentTad); + auto tadRank = shape::rank(zTadShape); + + for (int i = threadIdx.x; i < yLength; i += blockDim.x) { + shape::index2coords(i, currentTad, yIdx); + shape::index2coords(i, zTadShape, zIdx); + + auto yOffset = shape::getOffset(currentTad, yIdx); + auto resultOffset = shape::getOffset(zTadShape, zIdx); + + resultTAD[resultOffset] = dataTAD[yOffset]; + } + } } + __syncthreads(); + } } + __syncthreads(); + } +} /////////////////////////////////////////////////////////////////////// - template - __global__ void execConcatKernel(int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo, - Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers, - Nd4jLong *zTadShape, - Nd4jLong *zOffsets) { - - concatKernel(numArrays, data, inputShapeInfos, vz, zShapeInfo, tadPointers, offsetPointers, zTadShape, - zOffsets); - } - +template +__global__ void execConcatKernel(int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo, Nd4jPointer *tadPointers, + Nd4jPointer *offsetPointers, + Nd4jLong *zTadShape, Nd4jLong *zOffsets) { + concatKernel(numArrays, data, inputShapeInfos, vz, zShapeInfo, tadPointers, + offsetPointers, zTadShape, zOffsets); +} /////////////////////////////////////////////////////////////////////// - template - __host__ void concatKernelGeneric(dim3 &launchDims, cudaStream_t *stream, - int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo, - Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers, - Nd4jLong *zTadShape, - Nd4jLong *zOffsets) { - - - execConcatKernel<<>>(numArrays, data, inputShapeInfos, vz, zShapeInfo, tadPointers, offsetPointers, zTadShape, zOffsets); - sd::DebugHelper::checkErrorCode(stream, "concatGenericLegacy(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, Nd4jPointer * inputShapeInfos, void * vz, Nd4jLong *zShapeInfo, Nd4jPointer * tadPointers, Nd4jPointer * offsetPointers, Nd4jLong * zTadShape, Nd4jLong * zOffsets), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void concatKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo, + Nd4jPointer *tadPointers, + Nd4jPointer *offsetPointers, + Nd4jLong *zTadShape, Nd4jLong *zOffsets) { + execConcatKernel<<>>( + numArrays, data, inputShapeInfos, vz, zShapeInfo, tadPointers, + offsetPointers, zTadShape, zOffsets); + sd::DebugHelper::checkErrorCode(stream, "concatGenericLegacy(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelGeneric, + (dim3 & launchDims, cudaStream_t *stream, int numArrays, + Nd4jPointer *data, Nd4jPointer *inputShapeInfos, + void *vz, Nd4jLong *zShapeInfo, Nd4jPointer *tadPointers, + Nd4jPointer *offsetPointers, Nd4jLong *zTadShape, + Nd4jLong *zOffsets), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/concatKernelHStack.cu b/libnd4j/include/loops/cuda/specials/concatKernelHStack.cu index 39a42e99ed93..8fcf292558f9 100644 --- a/libnd4j/include/loops/cuda/specials/concatKernelHStack.cu +++ b/libnd4j/include/loops/cuda/specials/concatKernelHStack.cu @@ -24,72 +24,74 @@ namespace sd { /////////////////////////////////////////////////////////////////////// - template - __device__ void concatKernelHStack(int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo) { - - // we expect all data coming in as vectors, and z as 2D matrix - // the only significant difference here is the fact that input lengths might be different - auto z = reinterpret_cast(vz); - auto inputShapes = (Nd4jLong **) inputShapeInfos; - T **input = (T **) data; - - __shared__ int inputEWS; - __shared__ int resultEWS; - __shared__ int inputLength; - - if (threadIdx.x == 0) { - resultEWS = shape::elementWiseStride(zShapeInfo); - } - __syncthreads(); - - for (int r = blockIdx.x; r < numArrays; r += gridDim.x) { - - __shared__ int baseIdx; - if (threadIdx.x == 0) { - baseIdx = 0; - for (int f = 0; f < r; f++) { - baseIdx += shape::length(inputShapes[f]); - } - } - __syncthreads(); - - - T *inputData = (T *) input[r]; +template +__device__ void concatKernelHStack(int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo) { + // we expect all data coming in as vectors, and z as 2D matrix + // the only significant difference here is the fact that input lengths might + // be different + auto z = reinterpret_cast(vz); + auto inputShapes = (Nd4jLong **)inputShapeInfos; + T **input = (T **)data; + + __shared__ int inputEWS; + __shared__ int resultEWS; + __shared__ int inputLength; + + if (threadIdx.x == 0) { + resultEWS = shape::elementWiseStride(zShapeInfo); + } + __syncthreads(); + + for (int r = blockIdx.x; r < numArrays; r += gridDim.x) { + __shared__ int baseIdx; + if (threadIdx.x == 0) { + baseIdx = 0; + for (int f = 0; f < r; f++) { + baseIdx += shape::length(inputShapes[f]); + } + } + __syncthreads(); - if (threadIdx.x == 0) { - inputEWS = shape::elementWiseStride(inputShapes[r]); - inputLength = shape::length(inputShapes[r]); - } - __syncthreads(); + T *inputData = (T *)input[r]; - for (int i = threadIdx.x; i < inputLength; i += blockDim.x) { - z[baseIdx + i * resultEWS] = inputData[i * inputEWS]; - } - __syncthreads(); - } + if (threadIdx.x == 0) { + inputEWS = shape::elementWiseStride(inputShapes[r]); + inputLength = shape::length(inputShapes[r]); } + __syncthreads(); -/////////////////////////////////////////////////////////////////////// - template - __global__ void execConcatKernelHStack(int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo) { - - concatKernelHStack(numArrays, data, inputShapeInfos, vz, zShapeInfo); + for (int i = threadIdx.x; i < inputLength; i += blockDim.x) { + z[baseIdx + i * resultEWS] = inputData[i * inputEWS]; } + __syncthreads(); + } +} /////////////////////////////////////////////////////////////////////// - template - __host__ void concatKernelHStackGeneric(dim3 &launchDims, cudaStream_t *stream, - int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo) { +template +__global__ void execConcatKernelHStack(int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo) { + concatKernelHStack(numArrays, data, inputShapeInfos, vz, zShapeInfo); +} - execConcatKernelHStack<<>>(numArrays, data, inputShapeInfos, vz, zShapeInfo); - sd::DebugHelper::checkErrorCode(stream, "concatHStack(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelHStackGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, Nd4jPointer * inputShapeInfos, void * vz, Nd4jLong * zShapeInfo), LIBND4J_TYPES); -} \ No newline at end of file +/////////////////////////////////////////////////////////////////////// +template +__host__ void concatKernelHStackGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo) { + execConcatKernelHStack + <<>>( + numArrays, data, inputShapeInfos, vz, zShapeInfo); + sd::DebugHelper::checkErrorCode(stream, "concatHStack(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelHStackGeneric, + (dim3 & launchDims, cudaStream_t *stream, int numArrays, + Nd4jPointer *data, Nd4jPointer *inputShapeInfos, + void *vz, Nd4jLong *zShapeInfo), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/concatKernelScalar.cu b/libnd4j/include/loops/cuda/specials/concatKernelScalar.cu index a949c12f4f2f..10f1663f3ce6 100644 --- a/libnd4j/include/loops/cuda/specials/concatKernelScalar.cu +++ b/libnd4j/include/loops/cuda/specials/concatKernelScalar.cu @@ -24,32 +24,36 @@ namespace sd { /////////////////////////////////////////////////////////////////////// - template - __device__ void concatKernelScalar(int numArrays, Nd4jPointer *data, void *vz) { +template +__device__ void concatKernelScalar(int numArrays, Nd4jPointer *data, void *vz) { + auto z = static_cast(vz); + Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; + auto input = reinterpret_cast(data); - auto z = static_cast(vz); - Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; - auto input = reinterpret_cast(data); - - for (int i = tid; i < numArrays; i += blockDim.x * gridDim.x) - z[i] = input[i][0]; - } + for (int i = tid; i < numArrays; i += blockDim.x * gridDim.x) + z[i] = input[i][0]; +} /////////////////////////////////////////////////////////////////////// - template - __global__ void execConcatKernelScalar(int numArrays, Nd4jPointer *data, void *vz) { - - concatKernelScalar(numArrays, data, vz); - } +template +__global__ void execConcatKernelScalar(int numArrays, Nd4jPointer *data, + void *vz) { + concatKernelScalar(numArrays, data, vz); +} /////////////////////////////////////////////////////////////////////// - template - __host__ void - concatKernelScalarGeneric(dim3 &launchDims, cudaStream_t *stream, int numArrays, Nd4jPointer *data, void *vz) { - - execConcatKernelScalar<<>>(numArrays, data, vz); - sd::DebugHelper::checkErrorCode(stream, "concatScalar(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelScalarGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, void * vz), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void concatKernelScalarGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + void *vz) { + execConcatKernelScalar + <<>>(numArrays, data, + vz); + sd::DebugHelper::checkErrorCode(stream, "concatScalar(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelScalarGeneric, + (dim3 & launchDims, cudaStream_t *stream, int numArrays, + Nd4jPointer *data, void *vz), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/concatKernelVStack.cu b/libnd4j/include/loops/cuda/specials/concatKernelVStack.cu index cd9b7ca80b3b..2d4cbee4ba03 100644 --- a/libnd4j/include/loops/cuda/specials/concatKernelVStack.cu +++ b/libnd4j/include/loops/cuda/specials/concatKernelVStack.cu @@ -24,62 +24,64 @@ namespace sd { /////////////////////////////////////////////////////////////////////// - template - __device__ void concatKernelVStack(int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo) { - - /* - this is special case for concat: we group bunch of vectors into 2D matrix - also: we expect each inputShapeInfo to have EWS, be a vector, and have equal size - */ - auto z = static_cast(vz); - - auto inputShapes = (Nd4jLong **) inputShapeInfos; - T **input = (T **) data; - - __shared__ int inputEWS; - __shared__ int resultEWS; - __shared__ int inputLength; - - if (threadIdx.x == 0) { - inputLength = shape::length(inputShapes[0]); - inputEWS = shape::elementWiseStride(inputShapes[0]); - resultEWS = shape::elementWiseStride(zShapeInfo); - } - __syncthreads(); - - for (int r = blockIdx.x; r < numArrays; r += gridDim.x) { - - int zOffset = r * inputLength * resultEWS; - T *inputData = (T *) input[r]; - - for (int i = threadIdx.x; i < inputLength; i += blockDim.x) { - z[zOffset + i * resultEWS] = inputData[i * inputEWS]; - } - } +template +__device__ void concatKernelVStack(int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo) { + /* + this is special case for concat: we group bunch of vectors into 2D matrix + also: we expect each inputShapeInfo to have EWS, be a vector, and have equal + size + */ + auto z = static_cast(vz); + + auto inputShapes = (Nd4jLong **)inputShapeInfos; + T **input = (T **)data; + + __shared__ int inputEWS; + __shared__ int resultEWS; + __shared__ int inputLength; + + if (threadIdx.x == 0) { + inputLength = shape::length(inputShapes[0]); + inputEWS = shape::elementWiseStride(inputShapes[0]); + resultEWS = shape::elementWiseStride(zShapeInfo); + } + __syncthreads(); + + for (int r = blockIdx.x; r < numArrays; r += gridDim.x) { + int zOffset = r * inputLength * resultEWS; + T *inputData = (T *)input[r]; + + for (int i = threadIdx.x; i < inputLength; i += blockDim.x) { + z[zOffset + i * resultEWS] = inputData[i * inputEWS]; } + } +} /////////////////////////////////////////////////////////////////////// - template - __global__ void execConcatKernelVStack(int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo) { - - concatKernelVStack(numArrays, data, inputShapeInfos, vz, zShapeInfo); - } - +template +__global__ void execConcatKernelVStack(int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo) { + concatKernelVStack(numArrays, data, inputShapeInfos, vz, zShapeInfo); +} /////////////////////////////////////////////////////////////////////// - template - __host__ void concatKernelVStackGeneric(dim3 &launchDims, cudaStream_t *stream, - int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfos, - void *vz, Nd4jLong *zShapeInfo) { - - execConcatKernelVStack<<>>(numArrays, data, inputShapeInfos, vz, zShapeInfo); - sd::DebugHelper::checkErrorCode(stream, "concatVStack(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelVStackGeneric, (dim3 & launchDims, cudaStream_t * stream, int numArrays, Nd4jPointer * data, Nd4jPointer * inputShapeInfos, void * vz, Nd4jLong *zShapeInfo), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void concatKernelVStackGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong *zShapeInfo) { + execConcatKernelVStack + <<>>( + numArrays, data, inputShapeInfos, vz, zShapeInfo); + sd::DebugHelper::checkErrorCode(stream, "concatVStack(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT concatKernelVStackGeneric, + (dim3 & launchDims, cudaStream_t *stream, int numArrays, + Nd4jPointer *data, Nd4jPointer *inputShapeInfos, + void *vz, Nd4jLong *zShapeInfo), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/convertHalfs.cu b/libnd4j/include/loops/cuda/specials/convertHalfs.cu index c63b78552fc9..4c659c062553 100644 --- a/libnd4j/include/loops/cuda/specials/convertHalfs.cu +++ b/libnd4j/include/loops/cuda/specials/convertHalfs.cu @@ -24,23 +24,26 @@ namespace sd { /////////////////////////////////////////////////////////////////////// - template - __global__ void execConvertHalfs(half *dx, Nd4jLong n, void *dz) { - auto z = reinterpret_cast(dz); - int tid = threadIdx.x + blockIdx.x * blockDim.x; - - for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) - z[i] = static_cast(__half2float(dx[i])); - } +template +__global__ void execConvertHalfs(half *dx, Nd4jLong n, void *dz) { + auto z = reinterpret_cast(dz); + int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) + z[i] = static_cast(__half2float(dx[i])); +} /////////////////////////////////////////////////////////////////////// - template - __host__ void convertHalfsToGeneric(dim3 &launchDims, cudaStream_t *stream, half *dx, Nd4jLong n, void *dz) { - - execConvertHalfs<<>>(dx, n, dz); - sd::DebugHelper::checkErrorCode(stream, "convertHalfsToGeneric(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void SD_EXPORT convertHalfsToGeneric, (dim3 & launchDims, cudaStream_t * stream, half * dx, Nd4jLong n, void * dz), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void convertHalfsToGeneric(dim3 &launchDims, cudaStream_t *stream, + half *dx, Nd4jLong n, void *dz) { + execConvertHalfs + <<>>(dx, n, dz); + sd::DebugHelper::checkErrorCode(stream, "convertHalfsToGeneric(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT convertHalfsToGeneric, + (dim3 & launchDims, cudaStream_t *stream, half *dx, + Nd4jLong n, void *dz), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/convertToHalf.cu b/libnd4j/include/loops/cuda/specials/convertToHalf.cu index ad82f5a8fb35..fc6b8801e15f 100644 --- a/libnd4j/include/loops/cuda/specials/convertToHalf.cu +++ b/libnd4j/include/loops/cuda/specials/convertToHalf.cu @@ -24,22 +24,27 @@ namespace sd { //////////////////////////////////////////////////////////////////////// - template - __global__ void execConvertToHalf(void *dx, Nd4jLong n, half *dz) { - auto x = reinterpret_cast(dx); - int tid = threadIdx.x + blockIdx.x * blockDim.x; +template +__global__ void execConvertToHalf(void *dx, Nd4jLong n, half *dz) { + auto x = reinterpret_cast(dx); + int tid = threadIdx.x + blockIdx.x * blockDim.x; - for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) - dz[i] = __float2half(static_cast(x[i])); - } + for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) + dz[i] = __float2half(static_cast(x[i])); +} //////////////////////////////////////////////////////////////////////// - template - __host__ void convertToHalfGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong n, half *dz) { - execConvertToHalf<<>>(dx, n, dz); - sd::DebugHelper::checkErrorCode(stream, "convertToHalfs(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void SD_EXPORT convertToHalfGeneric, (dim3 & launchDims, cudaStream_t * stream, void * dx, Nd4jLong n, half * dz), LIBND4J_TYPES); - -} \ No newline at end of file +template +__host__ void convertToHalfGeneric(dim3 &launchDims, cudaStream_t *stream, + void *dx, Nd4jLong n, half *dz) { + execConvertToHalf + <<>>(dx, n, dz); + sd::DebugHelper::checkErrorCode(stream, "convertToHalfs(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT convertToHalfGeneric, + (dim3 & launchDims, cudaStream_t *stream, void *dx, + Nd4jLong n, half *dz), + LIBND4J_TYPES); + +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu b/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu index 44daf5ba1501..4e2b31db53c2 100644 --- a/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu +++ b/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu @@ -23,72 +23,77 @@ namespace sd { +//////////////////////////////////////////////////////////////////////// +template +__device__ void fillDimensionalIsMax(const void *vdX, void *vdZ, + const Nd4jLong *zShapeInfo, + const Nd4jLong *tadOnlyShapeInfo, + int *dimension, int dimensionLength, + const Nd4jLong *tadOffsets) { + auto dX = reinterpret_cast(vdX); + auto dZ = reinterpret_cast(vdZ); + + __shared__ int tadLength; + __shared__ int tadEWS; + __shared__ int numTads; + + if (threadIdx.x == 0) { + tadLength = shape::length(tadOnlyShapeInfo); // shape::tadLength(zShapeInfo, + // dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(zShapeInfo) / tadLength; + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto tadOffsetForBlock = tadOffsets[r]; + auto highestElement = dX[r]; + + if (dimensionLength > 1 || tadEWS < 1) { + for (Nd4jLong e = threadIdx.x; e < tadLength; e += blockDim.x) { + auto xOffset = + tadOffsetForBlock + shape::getIndexOffset(e, tadOnlyShapeInfo); + dZ[xOffset] = (e == highestElement ? (T)1 : (T)0); + } + } else { + for (Nd4jLong e = threadIdx.x; e < tadLength; e += blockDim.x) { + // so, we just set dZ[e] for each TAD. Sure, e should be replaced with + auto idx = tadOffsetForBlock + (e * tadEWS); + dZ[idx] = (e == highestElement ? (T)1 : (T)0); + } + } + } +} //////////////////////////////////////////////////////////////////////// - template - __device__ void fillDimensionalIsMax(const void *vdX, - void *vdZ, const Nd4jLong *zShapeInfo, +template +__global__ void execfillDimensionalIsMax(const void *dX, void *dZ, + const Nd4jLong *zShapeInfo, const Nd4jLong *tadOnlyShapeInfo, int *dimension, int dimensionLength, const Nd4jLong *tadOffsets) { - - auto dX = reinterpret_cast(vdX); - auto dZ = reinterpret_cast(vdZ); - - __shared__ int tadLength; - __shared__ int tadEWS; - __shared__ int numTads; - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(zShapeInfo, dimension, dimensionLength); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(zShapeInfo) / tadLength; - } - __syncthreads(); - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - auto tadOffsetForBlock = tadOffsets[r]; - auto highestElement = dX[r]; - - if (dimensionLength > 1 || tadEWS < 1) { - - for (Nd4jLong e = threadIdx.x; e < tadLength; e += blockDim.x) { - auto xOffset = tadOffsetForBlock + shape::getIndexOffset(e, tadOnlyShapeInfo); - dZ[xOffset] = (e == highestElement ? (T) 1 : (T) 0); - } - } else { - for (Nd4jLong e = threadIdx.x; e < tadLength; e += blockDim.x) { - // so, we just set dZ[e] for each TAD. Sure, e should be replaced with - auto idx = tadOffsetForBlock + (e * tadEWS); - dZ[idx] = (e == highestElement ? (T) 1 : (T) 0); - } - } - } - } - - -//////////////////////////////////////////////////////////////////////// - template - __global__ void execfillDimensionalIsMax(const void *dX, - void *dZ, const Nd4jLong *zShapeInfo, - const Nd4jLong *tadOnlyShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOffsets) { - - fillDimensionalIsMax(dX, dZ, zShapeInfo, tadOnlyShapeInfo, dimension, dimensionLength, tadOffsets); - } + fillDimensionalIsMax(dX, dZ, zShapeInfo, tadOnlyShapeInfo, dimension, + dimensionLength, tadOffsets); +} //////////////////////////////////////////////////////////////////////// - template - __host__ void fillDimensionalIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, - const void *dX, - void *dZ, const Nd4jLong *zShapeInfo, - const Nd4jLong *tadOnlyShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadOffsets) { - - execfillDimensionalIsMax<<>>(dX, dZ, zShapeInfo, tadOnlyShapeInfo, dimension, dimensionLength, tadOffsets); - sd::DebugHelper::checkErrorCode(stream, "fillDimensionalIsMax(...) failed"); - } - BUILD_SINGLE_TEMPLATE(template void SD_EXPORT fillDimensionalIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, const void *dX, void *dZ, const Nd4jLong *zShapeInfo, const Nd4jLong *tadOnlyShapeInfo, int *dimension, int dimensionLength, const Nd4jLong *tadOffsets), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void fillDimensionalIsMaxGeneric(dim3 &launchDims, + cudaStream_t *stream, const void *dX, + void *dZ, const Nd4jLong *zShapeInfo, + const Nd4jLong *tadOnlyShapeInfo, + int *dimension, int dimensionLength, + const Nd4jLong *tadOffsets) { + execfillDimensionalIsMax + <<>>( + dX, dZ, zShapeInfo, tadOnlyShapeInfo, dimension, dimensionLength, + tadOffsets); + sd::DebugHelper::checkErrorCode(stream, "fillDimensionalIsMax(...) failed"); +} +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT fillDimensionalIsMaxGeneric, + (dim3 & launchDims, cudaStream_t *stream, const void *dX, + void *dZ, const Nd4jLong *zShapeInfo, + const Nd4jLong *tadOnlyShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadOffsets), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/fillIsMax.cu b/libnd4j/include/loops/cuda/specials/fillIsMax.cu index 175c575b8e65..909b5b6f87c5 100644 --- a/libnd4j/include/loops/cuda/specials/fillIsMax.cu +++ b/libnd4j/include/loops/cuda/specials/fillIsMax.cu @@ -24,22 +24,28 @@ namespace sd { //////////////////////////////////////////////////////////////////////// - template - __global__ void execFillIsMax(void *vdZ, const Nd4jLong *xShapeInfo, Nd4jLong length, long idx) { - auto dz = reinterpret_cast(vdZ); - int tid = blockIdx.x * blockDim.x + threadIdx.x; +template +__global__ void execFillIsMax(void *vdZ, const Nd4jLong *xShapeInfo, + Nd4jLong length, long idx) { + auto dz = reinterpret_cast(vdZ); + int tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) - dz[shape::getIndexOffset(i, xShapeInfo)] = (i == idx ? (T) 1 : (T) 0); - } + for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) + dz[shape::getIndexOffset(i, xShapeInfo)] = (i == idx ? (T)1 : (T)0); +} //////////////////////////////////////////////////////////////////////// - template - __host__ void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, const Nd4jLong *xShapeInfo, Nd4jLong length, long idx) { - execFillIsMax<<>>(dx, xShapeInfo, length, idx); - sd::DebugHelper::checkErrorCode(stream, "fillIsMax(...) failed"); - } - - - BUILD_SINGLE_TEMPLATE(template void SD_EXPORT fillIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, void* dz, const Nd4jLong *zShapeInfo, Nd4jLong length, long idx), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, + const Nd4jLong *xShapeInfo, Nd4jLong length, + long idx) { + execFillIsMax<<>>( + dx, xShapeInfo, length, idx); + sd::DebugHelper::checkErrorCode(stream, "fillIsMax(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT fillIsMaxGeneric, + (dim3 & launchDims, cudaStream_t *stream, void *dz, + const Nd4jLong *zShapeInfo, Nd4jLong length, long idx), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/flatten.cu b/libnd4j/include/loops/cuda/specials/flatten.cu index ef86ac596c11..0babe05b4d81 100644 --- a/libnd4j/include/loops/cuda/specials/flatten.cu +++ b/libnd4j/include/loops/cuda/specials/flatten.cu @@ -26,46 +26,44 @@ namespace sd { //////////////////////////////////////////////////////////////////////// template -__global__ void flattenKernel( - Nd4jPointer *extraPointers, - int dOffset, - char order, - void *vz, Nd4jLong *zShapeInfo, - void *vy, Nd4jLong *yShapeInfo) { +__global__ void flattenKernel(Nd4jPointer *extraPointers, int dOffset, + char order, void *vz, Nd4jLong *zShapeInfo, + void *vy, Nd4jLong *yShapeInfo) { + auto z = reinterpret_cast(vz); + auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - auto y = reinterpret_cast(vy); + __shared__ Nd4jLong lenY, yOrder, zEWS, yEWS; - __shared__ Nd4jLong lenY, yOrder, zEWS, yEWS; + if (threadIdx.x == 0) { + yEWS = shape::elementWiseStride(yShapeInfo); + zEWS = shape::elementWiseStride(zShapeInfo); + lenY = shape::length(yShapeInfo); + } + __syncthreads(); - if (threadIdx.x == 0) { + Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; - yEWS = shape::elementWiseStride(yShapeInfo); - zEWS = shape::elementWiseStride(zShapeInfo); - lenY = shape::length(yShapeInfo); - } - __syncthreads(); - - Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; - - for(auto i = tid; i < lenY; i += gridDim.x * blockDim.x) - z[i * zEWS + dOffset] = y[ops::helpers::getIndexOffsetOrdered(i, yShapeInfo, order)]; + for (auto i = tid; i < lenY; i += gridDim.x * blockDim.x) + z[i * zEWS + dOffset] = + y[ops::helpers::getIndexOffsetOrdered(i, yShapeInfo, order)]; } //////////////////////////////////////////////////////////////////////// template -__host__ void flattenKernelGeneric(dim3& launchDims, cudaStream_t *stream, - Nd4jPointer *extraPointers, - int dOffset, - char order, - void *vz, Nd4jLong *zShapeInfo, - void *vy, Nd4jLong *yShapeInfo) { - - flattenKernel<<>>(extraPointers, dOffset, order, vz, zShapeInfo, vy, yShapeInfo); - sd::DebugHelper::checkErrorCode(stream, "flattenGeneric(...) failed"); +__host__ void flattenKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + Nd4jPointer *extraPointers, int dOffset, + char order, void *vz, Nd4jLong *zShapeInfo, + void *vy, Nd4jLong *yShapeInfo) { + flattenKernel<<>>( + extraPointers, dOffset, order, vz, zShapeInfo, vy, yShapeInfo); + sd::DebugHelper::checkErrorCode(stream, "flattenGeneric(...) failed"); } -BUILD_SINGLE_TEMPLATE(template void SD_EXPORT flattenKernelGeneric, (dim3& launchDims, cudaStream_t *stream, Nd4jPointer *extraPointers, int dOffset, char order, void *vz, Nd4jLong *zShapeInfo, void *vy, Nd4jLong *yShapeInfo), LIBND4J_TYPES); - +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT flattenKernelGeneric, + (dim3 & launchDims, cudaStream_t *stream, + Nd4jPointer *extraPointers, int dOffset, char order, + void *vz, Nd4jLong *zShapeInfo, void *vy, + Nd4jLong *yShapeInfo), + LIBND4J_TYPES); -} \ No newline at end of file +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/oesTad.cu b/libnd4j/include/loops/cuda/specials/oesTad.cu index 4897ffb55a7d..38d0e915e99a 100644 --- a/libnd4j/include/loops/cuda/specials/oesTad.cu +++ b/libnd4j/include/loops/cuda/specials/oesTad.cu @@ -22,184 +22,195 @@ ////////////////////////////////////////////////////////////////////////// template -__global__ void execOesTadKernelKey(void *vx, Nd4jLong const* xShapeInfo, - void *vy, Nd4jLong const* yShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - bool descending) { - - auto x = static_cast(vx); - auto y = static_cast(vy); - - __shared__ int xLength; - __shared__ int xTadLength; - __shared__ int numTads; - if (threadIdx.x == 0) { - xLength = shape::length(xShapeInfo); - xTadLength = shape::length(tadShapeInfo); - numTads = xLength / xTadLength; - } - __syncthreads(); - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - auto dx = x + tadOffsets[r]; - auto dy = y + tadOffsets[r]; - - // this is general loop, we go uncached - int iterations = xTadLength; - - for (int i = 0; i < iterations; i++) { - - if (i % 2 == 0) { - for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { - auto top = 2 * tid + 1; - if (top < xTadLength) { - auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo); - auto t1 = shape::getIndexOffset(top, tadShapeInfo); - - if (!descending == (dx[t0] > dx[t1])) { - X dt0 = dx[t0]; - dx[t0] = dx[t1]; - dx[t1] = dt0; - - Y dy0 = dy[t0]; - dy[t0] = dy[t1]; - dy[t1] = dy0; - } - } - } - } else { - for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { - auto top = 2 * tid + 2; - if (top < xTadLength) { - auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo); - auto t1 = shape::getIndexOffset(top, tadShapeInfo); - - if (!descending == (dx[t0] > dx[t1])) { - X dt0 = dx[t0]; - dx[t0] = dx[t1]; - dx[t1] = dt0; - - Y dy0 = dy[t0]; - dy[t0] = dy[t1]; - dy[t1] = dy0; - } - } - } +__global__ void execOesTadKernelKey(void *vx, Nd4jLong const *xShapeInfo, + void *vy, Nd4jLong const *yShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, + bool descending) { + auto x = static_cast(vx); + auto y = static_cast(vy); + + __shared__ int xLength; + __shared__ int xTadLength; + __shared__ int numTads; + if (threadIdx.x == 0) { + xLength = shape::length(xShapeInfo); + xTadLength = shape::length(tadShapeInfo); + numTads = xLength / xTadLength; + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto dx = x + tadOffsets[r]; + auto dy = y + tadOffsets[r]; + + // this is general loop, we go uncached + int iterations = xTadLength; + + for (int i = 0; i < iterations; i++) { + if (i % 2 == 0) { + for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { + auto top = 2 * tid + 1; + if (top < xTadLength) { + auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo); + auto t1 = shape::getIndexOffset(top, tadShapeInfo); + + if (!descending == (dx[t0] > dx[t1])) { + X dt0 = dx[t0]; + dx[t0] = dx[t1]; + dx[t1] = dt0; + + Y dy0 = dy[t0]; + dy[t0] = dy[t1]; + dy[t1] = dy0; + } + } + } + } else { + for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { + auto top = 2 * tid + 2; + if (top < xTadLength) { + auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo); + auto t1 = shape::getIndexOffset(top, tadShapeInfo); + + if (!descending == (dx[t0] > dx[t1])) { + X dt0 = dx[t0]; + dx[t0] = dx[t1]; + dx[t1] = dt0; + + Y dy0 = dy[t0]; + dy[t0] = dy[t1]; + dy[t1] = dy0; } - __syncthreads(); + } } + } + __syncthreads(); } + } } ////////////////////////////////////////////////////////////////////////// -template -__global__ void execOesTadKernel(void *vx, Nd4jLong const* xShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - bool descending) { - - auto x = static_cast(vx); - const int sharedSize = 32768; - - __shared__ int xLength; - __shared__ int xTadLength; - __shared__ int numTads; - __shared__ T *shmem; - __shared__ bool cached; - if (threadIdx.x == 0) { - xLength = shape::length(xShapeInfo); - xTadLength = shape::length(tadShapeInfo); - numTads = xLength / xTadLength; - - extern __shared__ unsigned char shrd[]; - shmem = (T *) shrd; - - cached = xTadLength <= (sharedSize / sizeof(T)); +template +__global__ void execOesTadKernel(void *vx, Nd4jLong const *xShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending) { + auto x = static_cast(vx); + const int sharedSize = 32768; + + __shared__ int xLength; + __shared__ int xTadLength; + __shared__ int numTads; + __shared__ T *shmem; + __shared__ bool cached; + if (threadIdx.x == 0) { + xLength = shape::length(xShapeInfo); + xTadLength = shape::length(tadShapeInfo); + numTads = xLength / xTadLength; + + extern __shared__ unsigned char shrd[]; + shmem = (T *)shrd; + + cached = xTadLength <= (sharedSize / sizeof(T)); + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto dx = x + tadOffsets[r]; + + // this is general loop, we go uncached + int iterations = xTadLength; + if (cached) { + for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { + auto t0 = shape::getIndexOffset(tid, tadShapeInfo); + shmem[tid] = dx[t0]; + } + + __syncthreads(); + dx = shmem; } - __syncthreads(); - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - auto dx = x + tadOffsets[r]; - // this is general loop, we go uncached - int iterations = xTadLength; - if (cached) { - for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { - auto t0 = shape::getIndexOffset(tid, tadShapeInfo); - shmem[tid] = dx[t0]; + for (int i = 0; i < iterations; i++) { + if (i % 2 == 0) { + for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { + auto top = 2 * tid + 1; + if (top < xTadLength) { + auto t0 = + cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo); + auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo); + + if (!descending == (dx[t0] > dx[t1])) { + T dt0 = dx[t0]; + dx[t0] = dx[t1]; + dx[t1] = dt0; } - - __syncthreads(); - dx = shmem; + } } - - for (int i = 0; i < iterations; i++) { - - if (i % 2 == 0) { - for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { - auto top = 2 * tid + 1; - if (top < xTadLength) { - auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo); - auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo); - - if (!descending == (dx[t0] > dx[t1])) { - T dt0 = dx[t0]; - dx[t0] = dx[t1]; - dx[t1] = dt0; - } - } - } - } else { - for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { - auto top = 2 * tid + 2; - if (top < xTadLength) { - auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo); - auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo); - - if (!descending == (dx[t0] > dx[t1])) { - T dt0 = dx[t0]; - dx[t0] = dx[t1]; - dx[t1] = dt0; - } - } - } + } else { + for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { + auto top = 2 * tid + 2; + if (top < xTadLength) { + auto t0 = + cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo); + auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo); + + if (!descending == (dx[t0] > dx[t1])) { + T dt0 = dx[t0]; + dx[t0] = dx[t1]; + dx[t1] = dt0; } - __syncthreads(); + } } + } + __syncthreads(); + } - - if (cached) { - dx = x + tadOffsets[r]; - for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { - auto t0 = shape::getIndexOffset(tid, tadShapeInfo); - dx[t0] = shmem[tid]; - } - } + if (cached) { + dx = x + tadOffsets[r]; + for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { + auto t0 = shape::getIndexOffset(tid, tadShapeInfo); + dx[t0] = shmem[tid]; + } } + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream, - void *vx, Nd4jLong const* xShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - bool descending) { - - execOesTadKernel<<>>(vx, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending); +template +__host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, int *dimension, + int dimensionLength, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending) { + execOesTadKernel<<>>( + vx, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, + descending); } template -__host__ void oesTadGenericKey(dim3 &launchDims, cudaStream_t *stream, - void *vx, Nd4jLong const* xShapeInfo, - void *vy, Nd4jLong const* yShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - bool descending) { - - execOesTadKernelKey<<>>(vx, xShapeInfo, vy, yShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending); +__host__ void oesTadGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int *dimension, + int dimensionLength, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending) { + execOesTadKernelKey + <<>>( + vx, xShapeInfo, vy, yShapeInfo, dimension, dimensionLength, + tadShapeInfo, tadOffsets, descending); } -BUILD_SINGLE_TEMPLATE(template void SD_EXPORT oesTadGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool descending), LIBND4J_TYPES); -BUILD_DOUBLE_TEMPLATE(template void SD_EXPORT oesTadGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT oesTadGeneric, + (dim3 & launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, int *dimension, + int dimensionLength, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending), + LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template void SD_EXPORT oesTadGenericKey, + (dim3 & launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int *dimension, + int dimensionLength, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending), + LIBND4J_TYPES, LIBND4J_TYPES); diff --git a/libnd4j/include/loops/cuda/specials/pullRowsKernel.cu b/libnd4j/include/loops/cuda/specials/pullRowsKernel.cu index d4b63b231cbd..36e2431ca19e 100644 --- a/libnd4j/include/loops/cuda/specials/pullRowsKernel.cu +++ b/libnd4j/include/loops/cuda/specials/pullRowsKernel.cu @@ -24,69 +24,70 @@ namespace sd { /////////////////////////////////////////////////////////////////////// - template - __device__ void pullRowsKernel(void *vx, - void *vz, - Nd4jLong len, - Nd4jLong *indexes, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto xEWS = shape::elementWiseStride(tadShapeInfo); - auto zEWS = shape::elementWiseStride(zTadShapeInfo); - auto tadLength = shape::length(tadShapeInfo); - - if (xEWS >= 1 && zEWS >= 1) { - for (int idx = blockIdx.x; idx < len; idx += gridDim.x) { - T *rX = x + tadOffsets[indexes[idx]]; - T *rZ = z + zTadOffsets[idx]; +template +__device__ void pullRowsKernel(void *vx, void *vz, Nd4jLong len, + Nd4jLong *indexes, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, + Nd4jLong const *zTadShapeInfo, + Nd4jLong const *zTadOffsets) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto xEWS = shape::elementWiseStride(tadShapeInfo); + auto zEWS = shape::elementWiseStride(zTadShapeInfo); + auto tadLength = shape::length(tadShapeInfo); - for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { - rZ[i * zEWS] = rX[i * xEWS]; - } - } - } else { - for (int idx = blockIdx.x; idx < len; idx += gridDim.x) { - T *rX = x + tadOffsets[indexes[idx]]; - T *rZ = z + zTadOffsets[idx]; + if (xEWS >= 1 && zEWS >= 1) { + for (int idx = blockIdx.x; idx < len; idx += gridDim.x) { + T *rX = x + tadOffsets[indexes[idx]]; + T *rZ = z + zTadOffsets[idx]; - for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { - auto xOffset = shape::getIndexOffset(i, tadShapeInfo); - auto zOffset = shape::getIndexOffset(i, zTadShapeInfo); - rZ[zOffset] = rX[xOffset]; - } - } - } + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + rZ[i * zEWS] = rX[i * xEWS]; + } } + } else { + for (int idx = blockIdx.x; idx < len; idx += gridDim.x) { + T *rX = x + tadOffsets[indexes[idx]]; + T *rZ = z + zTadOffsets[idx]; -/////////////////////////////////////////////////////////////////////// - template - __global__ void execPullRowsKernel(void *vx, - void *vz, - Nd4jLong len, - Nd4jLong *indexes, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets) { - - pullRowsKernel(vx, vz, len, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets); + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = shape::getIndexOffset(i, tadShapeInfo); + auto zOffset = shape::getIndexOffset(i, zTadShapeInfo); + rZ[zOffset] = rX[xOffset]; + } } + } +} /////////////////////////////////////////////////////////////////////// - template - __host__ void pullRowsKernelGeneric(dim3 &launchDims, cudaStream_t *stream, - void *vx, - void *vz, - Nd4jLong len, - Nd4jLong *indexes, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, - Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets) { - - execPullRowsKernel<<>>(vx, vz, len, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets); - sd::DebugHelper::checkErrorCode(stream, "pullRows(...) failed"); - } +template +__global__ void execPullRowsKernel(void *vx, void *vz, Nd4jLong len, + Nd4jLong *indexes, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, + Nd4jLong const *zTadShapeInfo, + Nd4jLong const *zTadOffsets) { + pullRowsKernel(vx, vz, len, indexes, tadShapeInfo, tadOffsets, + zTadShapeInfo, zTadOffsets); +} - BUILD_SINGLE_TEMPLATE(template void SD_EXPORT pullRowsKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * vx, void * vz, Nd4jLong len, Nd4jLong * indexes, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets), LIBND4J_TYPES); +/////////////////////////////////////////////////////////////////////// +template +__host__ void pullRowsKernelGeneric( + dim3 &launchDims, cudaStream_t *stream, void *vx, void *vz, Nd4jLong len, + Nd4jLong *indexes, Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *zTadShapeInfo, Nd4jLong const *zTadOffsets) { + execPullRowsKernel<<>>( + vx, vz, len, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, + zTadOffsets); + sd::DebugHelper::checkErrorCode(stream, "pullRows(...) failed"); } +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT pullRowsKernelGeneric, + (dim3 & launchDims, cudaStream_t *stream, void *vx, + void *vz, Nd4jLong len, Nd4jLong *indexes, + Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *zTadShapeInfo, + Nd4jLong const *zTadOffsets), + LIBND4J_TYPES); +} // namespace sd diff --git a/libnd4j/include/loops/cuda/specials/setDiagonalKernel.cu b/libnd4j/include/loops/cuda/specials/setDiagonalKernel.cu index bb063180c45a..32bf9b1df988 100644 --- a/libnd4j/include/loops/cuda/specials/setDiagonalKernel.cu +++ b/libnd4j/include/loops/cuda/specials/setDiagonalKernel.cu @@ -18,8 +18,8 @@ // @author GS , created on 21.01.2019 // -#include #include +#include namespace sd { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -27,93 +27,151 @@ namespace sd { // buffer - input buffer // shape - input shape // value - given value -// diagonal - given upper diagonal (acceptable negative values also, 0 - the main diagonal) -// row, cols - height and width of given matrix (MxN, rows = M, cols = N) +// diagonal - given upper diagonal (acceptable negative values also, 0 - the +// main diagonal) row, cols - height and width of given matrix (MxN, rows = M, +// cols = N) // - template - static __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, T value, int diagonal, Nd4jLong rows, - Nd4jLong cols) { - - __shared__ Nd4jLong rank; - __shared__ T* array; +template +static __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + T value, int diagonal, + Nd4jLong rows, Nd4jLong cols) { + __shared__ Nd4jLong rank; + __shared__ T* array; - if (0 == threadIdx.x) { - rank = shape::rank(shape); - array = reinterpret_cast(buffer); - } - __syncthreads(); + if (0 == threadIdx.x) { + rank = shape::rank(shape); + array = reinterpret_cast(buffer); + } + __syncthreads(); - for (Nd4jLong i = blockIdx.x; i < rows; i += gridDim.x) { - for (int j = threadIdx.x; j < cols; j += blockDim.x) { - Nd4jLong coords[2] = {i, j}; - Nd4jLong xOffset = shape::getOffset(shape, coords); - if (i + diagonal <= j) - array[xOffset] = value; - } - } + for (Nd4jLong i = blockIdx.x; i < rows; i += gridDim.x) { + for (int j = threadIdx.x; j < cols; j += blockDim.x) { + Nd4jLong coords[2] = {i, j}; + Nd4jLong xOffset = shape::getOffset(shape, coords); + if (i + diagonal <= j) array[xOffset] = value; } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // set up given value to lower given diagonal // buffer - input buffer // shape - input shape // value - given value -// diagonal - given lower diagonal (acceptable negative values also, 0 - the main diagonal) -// row, cols - height and width of given matrix (MxN, rows = M, cols = N) +// diagonal - given lower diagonal (acceptable negative values also, 0 - the +// main diagonal) row, cols - height and width of given matrix (MxN, rows = M, +// cols = N) // - template - static __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, T value, int diagonal, Nd4jLong rows, Nd4jLong cols) { - Nd4jLong rank = shape::rank(shape); - int totalThreads = blockDim.x; - for (Nd4jLong i = blockIdx.x; i < rows; i += gridDim.x) { - for (int j = threadIdx.x; j < cols; j += totalThreads) { - Nd4jLong coords[2] = {i, j}; - auto xOffset = shape::getOffset(shape, coords); - if (i + diagonal >= j) - *(reinterpret_cast(buffer) + xOffset) = value; - } - } +template +static __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + T value, int diagonal, + Nd4jLong rows, Nd4jLong cols) { + Nd4jLong rank = shape::rank(shape); + int totalThreads = blockDim.x; + for (Nd4jLong i = blockIdx.x; i < rows; i += gridDim.x) { + for (int j = threadIdx.x; j < cols; j += totalThreads) { + Nd4jLong coords[2] = {i, j}; + auto xOffset = shape::getOffset(shape, coords); + if (i + diagonal >= j) *(reinterpret_cast(buffer) + xOffset) = value; } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, double value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, double value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, float value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, float value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, int value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, int value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, float16 value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, float16 value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, bfloat16 value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, bfloat16 value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, Nd4jLong value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, Nd4jLong value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, int16_t value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, int16_t value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, uint8_t value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, uint8_t value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, int8_t value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, int8_t value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, bool value, int diagonal, Nd4jLong rows, Nd4jLong cols); - template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, bool value, int diagonal, Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + double value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + double value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + float value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + float value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + int value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + int value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + float16 value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + float16 value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + bfloat16 value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + bfloat16 value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + Nd4jLong value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + Nd4jLong value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + int16_t value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + int16_t value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + uint8_t value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + uint8_t value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + int8_t value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + int8_t value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueLowerKernel(void* buffer, Nd4jLong* shape, + bool value, int diagonal, + Nd4jLong rows, Nd4jLong cols); +template __global__ void setDiagValueUpperKernel(void* buffer, Nd4jLong* shape, + bool value, int diagonal, + Nd4jLong rows, Nd4jLong cols); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static void setDiagonalValueUpper(void* buffer, Nd4jLong* shape, NDArray const& value, int diagonal, Nd4jLong rows, Nd4jLong cols, cudaStream_t& stream) { - dim3 launchDims(256, 512, 8192); - setDiagValueUpperKernel<<>>(buffer, shape, value.e(0), diagonal, rows, cols); - } +template +static void setDiagonalValueUpper(void* buffer, Nd4jLong* shape, + NDArray const& value, int diagonal, + Nd4jLong rows, Nd4jLong cols, + cudaStream_t& stream) { + dim3 launchDims(256, 512, 8192); + setDiagValueUpperKernel + <<>>( + buffer, shape, value.e(0), diagonal, rows, cols); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static void setDiagonalValueLower(void* buffer, Nd4jLong* shape, NDArray const& value, int diagonal, Nd4jLong rows, Nd4jLong cols, cudaStream_t& stream) { - dim3 launchDims(256, 512, 8192); - setDiagValueLowerKernel<<>>(buffer, shape, value.e(0), diagonal, rows, cols); - } +template +static void setDiagonalValueLower(void* buffer, Nd4jLong* shape, + NDArray const& value, int diagonal, + Nd4jLong rows, Nd4jLong cols, + cudaStream_t& stream) { + dim3 launchDims(256, 512, 8192); + setDiagValueLowerKernel + <<>>( + buffer, shape, value.e(0), diagonal, rows, cols); +} - BUILD_SINGLE_TEMPLATE(template void setDiagonalValueUpper, (void* buffer, Nd4jLong* shape, NDArray const& value, - int diagonal, Nd4jLong rows, Nd4jLong cols, cudaStream_t& stream), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void setDiagonalValueLower, (void* buffer, Nd4jLong* shape, NDArray const& value, - int diagonal, Nd4jLong rows, Nd4jLong cols, cudaStream_t& stream), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void setDiagonalValueUpper, + (void* buffer, Nd4jLong* shape, NDArray const& value, + int diagonal, Nd4jLong rows, Nd4jLong cols, + cudaStream_t& stream), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void setDiagonalValueLower, + (void* buffer, Nd4jLong* shape, NDArray const& value, + int diagonal, Nd4jLong rows, Nd4jLong cols, + cudaStream_t& stream), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -} \ No newline at end of file +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/shuffleKernel.cu b/libnd4j/include/loops/cuda/specials/shuffleKernel.cu index 643b0fd12a46..e30ce8210e9f 100644 --- a/libnd4j/include/loops/cuda/specials/shuffleKernel.cu +++ b/libnd4j/include/loops/cuda/specials/shuffleKernel.cu @@ -24,102 +24,103 @@ namespace sd { //////////////////////////////////////////////////////////////////////// - template - __global__ void execShuffleKernel(void **vdX, Nd4jLong **dxShapeInfo, - void **vdZ, - int N, - int *shuffleMap, - Nd4jLong **tadOnlyShapeInfo, Nd4jLong **tadOffsets) { - - // we assume that shuffle map for each X contains pair TAD Y - auto dX = reinterpret_cast(vdX); - auto dZ = reinterpret_cast(vdZ); - - __shared__ int tadLength; - __shared__ int xRank; - __shared__ int tadEWS; - __shared__ int numTads; - __shared__ Nd4jLong* xShapeInfo; - __shared__ Nd4jLong xLength; - - for (int f = 0; f < N; f++) { - auto x = reinterpret_cast(dX[f]); - auto z = reinterpret_cast(dZ[f]); - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo[f]); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo[f]); - xShapeInfo = dxShapeInfo[f]; - xRank = shape::rank(xShapeInfo); - xLength = shape::length(xShapeInfo); - numTads = xLength / tadLength; +template +__global__ void execShuffleKernel(void **vdX, Nd4jLong **dxShapeInfo, + void **vdZ, int N, int *shuffleMap, + Nd4jLong **tadOnlyShapeInfo, + Nd4jLong **tadOffsets) { + // we assume that shuffle map for each X contains pair TAD Y + auto dX = reinterpret_cast(vdX); + auto dZ = reinterpret_cast(vdZ); + + __shared__ int tadLength; + __shared__ int xRank; + __shared__ int tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong *xShapeInfo; + __shared__ Nd4jLong xLength; + + for (int f = 0; f < N; f++) { + auto x = reinterpret_cast(dX[f]); + auto z = reinterpret_cast(dZ[f]); + + if (threadIdx.x == 0) { + tadLength = shape::length(tadOnlyShapeInfo[f]); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo[f]); + xShapeInfo = dxShapeInfo[f]; + xRank = shape::rank(xShapeInfo); + xLength = shape::length(xShapeInfo); + numTads = xLength / tadLength; + } + __syncthreads(); + + if (xRank == 1) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (int r = tid; r < xLength; r += gridDim.x * blockDim.x) { + auto swapIndex = shuffleMap[r]; + if (swapIndex >= 0 && swapIndex < xLength) { + int idx = r * tadEWS; + int swap = swapIndex * tadEWS; + T oldX = x[idx]; + x[idx] = x[swap]; + x[swap] = oldX; + } + } + } else { + // we roll over the pairs of TADs, thus limit is numTads / 2 + for (uint r = blockIdx.x; r < numTads; r += gridDim.x) { + if (shuffleMap[r] >= 0) { + auto oldOffset = tadOffsets[f][r]; + auto newOffset = tadOffsets[f][shuffleMap[r]]; + + auto rX = x + oldOffset; + auto rY = x + newOffset; + + auto zX = z + oldOffset; + auto zY = z + newOffset; + + // so we're going to change TAD[oldOffset] with TAD[newOffset] + if (tadEWS == 1) { + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + T oldX = rX[i]; + rX[i] = rY[i]; + zY[i] = oldX; } - __syncthreads(); - - if (xRank == 1) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - for (int r = tid; r < xLength; r += gridDim.x * blockDim.x) { - auto swapIndex = shuffleMap[r]; - if (swapIndex >= 0 && swapIndex < xLength) { - int idx = r * tadEWS; - int swap = swapIndex * tadEWS; - T oldX = x[idx]; - x[idx] = x[swap]; - x[swap] = oldX; - } - } - } else { - // we roll over the pairs of TADs, thus limit is numTads / 2 - for (uint r = blockIdx.x; r < numTads; r += gridDim.x) { - if (shuffleMap[r] >= 0) { - auto oldOffset = tadOffsets[f][r]; - auto newOffset = tadOffsets[f][shuffleMap[r]]; - auto rX = x + oldOffset; - auto rY = x + newOffset; + } else { + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo[f]); + auto yOffset = newOffset + xOffset; + xOffset += oldOffset; - auto zX = z + oldOffset; - auto zY = z + newOffset; - - // so we're going to change TAD[oldOffset] with TAD[newOffset] - if (tadEWS == 1) { - for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { - T oldX = rX[i]; - rX[i] = rY[i]; - zY[i] = oldX; - } - - } else { - for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) { - - auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo[f]); - auto yOffset = newOffset + xOffset; - xOffset += oldOffset; - - T oldX = x[xOffset]; - z[xOffset] = x[yOffset]; - z[yOffset] = oldX; - } - } - } - } + T oldX = x[xOffset]; + z[xOffset] = x[yOffset]; + z[yOffset] = oldX; } - __syncthreads(); + } } + } } + __syncthreads(); + } +} //////////////////////////////////////////////////////////////////////// - template - __host__ void shuffleKernelGeneric(dim3 &launchDims, cudaStream_t *stream, - void **vdX, Nd4jLong **xShapeInfo, - void **vdZ, - int N, - int *shuffleMap, - Nd4jLong **tadOnlyShapeInfo, Nd4jLong **tadOffsets) { - - execShuffleKernel<<>>(vdX, xShapeInfo, vdZ, N, shuffleMap, tadOnlyShapeInfo, tadOffsets); - sd::DebugHelper::checkErrorCode(stream, "shuffleGeneric(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void SD_EXPORT shuffleKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * *vdX, Nd4jLong * *xShapeInfo, void **vdZ, int N, int * shuffleMap, Nd4jLong * *tadOnlyShapeInfo, Nd4jLong * *tadOffsets), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void shuffleKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + void **vdX, Nd4jLong **xShapeInfo, + void **vdZ, int N, int *shuffleMap, + Nd4jLong **tadOnlyShapeInfo, + Nd4jLong **tadOffsets) { + execShuffleKernel<<>>( + vdX, xShapeInfo, vdZ, N, shuffleMap, tadOnlyShapeInfo, tadOffsets); + sd::DebugHelper::checkErrorCode(stream, "shuffleGeneric(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT shuffleKernelGeneric, + (dim3 & launchDims, cudaStream_t *stream, void **vdX, + Nd4jLong **xShapeInfo, void **vdZ, int N, + int *shuffleMap, Nd4jLong **tadOnlyShapeInfo, + Nd4jLong **tadOffsets), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/swapUnsafeKernel.cu b/libnd4j/include/loops/cuda/specials/swapUnsafeKernel.cu index 334584fabf6e..338d0afcd298 100644 --- a/libnd4j/include/loops/cuda/specials/swapUnsafeKernel.cu +++ b/libnd4j/include/loops/cuda/specials/swapUnsafeKernel.cu @@ -22,41 +22,58 @@ namespace sd { - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // kernel to swap two NDArrays vals as linear sequences - // input - theSecondBuffer/Shape from input NDArray - // output - theFirstBuffer/Shape from input NDArray - template - static __global__ void swapUnsafeKernel(void* theFirstBuffer, Nd4jLong const* theFirstShape, void* theSecondBuffer, Nd4jLong const* theSecondShape) { - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int totalThreads = gridDim.x * blockDim.x; - - __shared__ Nd4jLong resultLength; - __shared__ T* input; - __shared__ T* output; - if (0 == threadIdx.x) { - resultLength = shape::length(theFirstShape); - input = reinterpret_cast(theSecondBuffer); - output = reinterpret_cast(theFirstBuffer); - } - __syncthreads(); - - for (int i = tid; i < resultLength; i += totalThreads) { - auto xEws = shape::order(theFirstShape) == 'c'? shape::elementWiseStride(theFirstShape) :1; - auto yEws = shape::order(theSecondShape) == 'c'? shape::elementWiseStride(theSecondShape):1; - - auto xOffset = shape::getIndexOffset(i * xEws, theFirstShape); - auto yOffset = shape::getIndexOffset(i * yEws, theSecondShape); - sd::math::nd4j_swap(output[xOffset], input[yOffset]); - } - } - - BUILD_SINGLE_TEMPLATE(template __global__ void swapUnsafeKernel, (void* theFirstBuffer, Nd4jLong const* theFirstShape, void* theSecondBuffer, Nd4jLong const* theSecondShape), LIBND4J_TYPES); - - template - void templatedSwapUnsafe(void* theFirstBuffer, Nd4jLong const* theFirstShape, void* theSecondBuffer, Nd4jLong const* theSecondShape, cudaStream_t* theStream) { - swapUnsafeKernel<<<256, 512, 8192, *theStream>>>(theFirstBuffer, theFirstShape, theSecondBuffer, theSecondShape); - } - BUILD_SINGLE_TEMPLATE(template void templatedSwapUnsafe, (void* theFirstBuffer, Nd4jLong const* theFirstShape, void* theSecondBuffer, Nd4jLong const* theSecondShape, cudaStream_t* theStream), LIBND4J_TYPES); - -} \ No newline at end of file +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// kernel to swap two NDArrays vals as linear sequences +// input - theSecondBuffer/Shape from input NDArray +// output - theFirstBuffer/Shape from input NDArray +template +static __global__ void swapUnsafeKernel(void* theFirstBuffer, + Nd4jLong const* theFirstShape, + void* theSecondBuffer, + Nd4jLong const* theSecondShape) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int totalThreads = gridDim.x * blockDim.x; + + __shared__ Nd4jLong resultLength; + __shared__ T* input; + __shared__ T* output; + if (0 == threadIdx.x) { + resultLength = shape::length(theFirstShape); + input = reinterpret_cast(theSecondBuffer); + output = reinterpret_cast(theFirstBuffer); + } + __syncthreads(); + + for (int i = tid; i < resultLength; i += totalThreads) { + auto xEws = shape::order(theFirstShape) == 'c' + ? shape::elementWiseStride(theFirstShape) + : 1; + auto yEws = shape::order(theSecondShape) == 'c' + ? shape::elementWiseStride(theSecondShape) + : 1; + + auto xOffset = shape::getIndexOffset(i * xEws, theFirstShape); + auto yOffset = shape::getIndexOffset(i * yEws, theSecondShape); + sd::math::nd4j_swap(output[xOffset], input[yOffset]); + } +} + +BUILD_SINGLE_TEMPLATE(template __global__ void swapUnsafeKernel, + (void* theFirstBuffer, Nd4jLong const* theFirstShape, + void* theSecondBuffer, Nd4jLong const* theSecondShape), + LIBND4J_TYPES); + +template +void templatedSwapUnsafe(void* theFirstBuffer, Nd4jLong const* theFirstShape, + void* theSecondBuffer, Nd4jLong const* theSecondShape, + cudaStream_t* theStream) { + swapUnsafeKernel<<<256, 512, 8192, *theStream>>>( + theFirstBuffer, theFirstShape, theSecondBuffer, theSecondShape); +} +BUILD_SINGLE_TEMPLATE(template void templatedSwapUnsafe, + (void* theFirstBuffer, Nd4jLong const* theFirstShape, + void* theSecondBuffer, Nd4jLong const* theSecondShape, + cudaStream_t* theStream), + LIBND4J_TYPES); + +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/tearKernel.cu b/libnd4j/include/loops/cuda/specials/tearKernel.cu index 243f2ac5d388..b85b86f97ec7 100644 --- a/libnd4j/include/loops/cuda/specials/tearKernel.cu +++ b/libnd4j/include/loops/cuda/specials/tearKernel.cu @@ -24,72 +24,75 @@ namespace sd { //////////////////////////////////////////////////////////////////////// - template - __device__ void - tearKernel(void *vx, Nd4jLong const* xShapeInfo, Nd4jPointer *targets, Nd4jLong const* zShapeInfo, Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets) { - - - - __shared__ Nd4jLong tadLength; - __shared__ int tadEWS; - __shared__ int zEWS; -// __shared__ int tadRank; - __shared__ Nd4jLong numTads; -// __shared__ int zRank; -// __shared__ Nd4jLong *tadShape; -// __shared__ Nd4jLong *tadStride; -// __shared__ Nd4jLong const* zShape; -// __shared__ Nd4jLong const* zStride; - __shared__ T* x; - if (threadIdx.x == 0) { - tadLength = shape::length(tadShapeInfo); - tadEWS = shape::elementWiseStride(tadShapeInfo); - zEWS = shape::elementWiseStride(zShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; - x = static_cast(vx); - } - __syncthreads(); - - for (Nd4jLong r = blockIdx.x; r < numTads; r += gridDim.x) { - T *z = (T *) targets[r]; - T *s = x + tadOffsets[r]; - - if (zEWS > 0 && tadEWS > 0) { - for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) - z[i * zEWS] = s[i * tadEWS]; - } else { - - for (Nd4jLong j = threadIdx.x; j < tadLength; j += blockDim.x) { - auto xOffset = shape::getIndexOffset(j, tadShapeInfo); - auto zOffset = shape::getIndexOffset(j, zShapeInfo); - - z[zOffset] = s[xOffset]; - } - } - } +template +__device__ void tearKernel(void* vx, Nd4jLong const* xShapeInfo, + Nd4jPointer* targets, Nd4jLong const* zShapeInfo, + Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets) { + __shared__ Nd4jLong tadLength; + __shared__ int tadEWS; + __shared__ int zEWS; + // __shared__ int tadRank; + __shared__ Nd4jLong numTads; + // __shared__ int zRank; + // __shared__ Nd4jLong *tadShape; + // __shared__ Nd4jLong *tadStride; + // __shared__ Nd4jLong const* zShape; + // __shared__ Nd4jLong const* zStride; + __shared__ T* x; + if (threadIdx.x == 0) { + tadLength = shape::length(tadShapeInfo); + tadEWS = shape::elementWiseStride(tadShapeInfo); + zEWS = shape::elementWiseStride(zShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + x = static_cast(vx); + } + __syncthreads(); + + for (Nd4jLong r = blockIdx.x; r < numTads; r += gridDim.x) { + T* z = (T*)targets[r]; + T* s = x + tadOffsets[r]; + + if (zEWS > 0 && tadEWS > 0) { + for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) + z[i * zEWS] = s[i * tadEWS]; + } else { + for (Nd4jLong j = threadIdx.x; j < tadLength; j += blockDim.x) { + auto xOffset = shape::getIndexOffset(j, tadShapeInfo); + auto zOffset = shape::getIndexOffset(j, zShapeInfo); + + z[zOffset] = s[xOffset]; + } } - + } +} //////////////////////////////////////////////////////////////////////// - template - __global__ void - execTearKernel(void *vx, Nd4jLong const* xShapeInfo, Nd4jPointer *targets, Nd4jLong const* zShapeInfo, Nd4jLong const* tadShapeInfo, - Nd4jLong const* tadOffsets) { - - tearKernel(vx, xShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets); - } +template +__global__ void execTearKernel(void* vx, Nd4jLong const* xShapeInfo, + Nd4jPointer* targets, Nd4jLong const* zShapeInfo, + Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets) { + tearKernel(vx, xShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets); +} //////////////////////////////////////////////////////////////////////// - template - __host__ void tearKernelGeneric(dim3 &launchDims, cudaStream_t *stream, - void *vx, Nd4jLong const* xShapeInfo, - Nd4jPointer *targets, Nd4jLong const* zShapeInfo, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) { - - execTearKernel<<>>(vx, xShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets); - sd::DebugHelper::checkErrorCode(stream, "tear(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template void SD_EXPORT tearKernelGeneric, (dim3 & launchDims, cudaStream_t * stream, void * vx, Nd4jLong const* xShapeInfo, Nd4jPointer *targets, Nd4jLong const* zShapeInfo, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets), LIBND4J_TYPES); -} \ No newline at end of file +template +__host__ void tearKernelGeneric(dim3& launchDims, cudaStream_t* stream, + void* vx, Nd4jLong const* xShapeInfo, + Nd4jPointer* targets, + Nd4jLong const* zShapeInfo, + Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets) { + execTearKernel<<>>( + vx, xShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets); + sd::DebugHelper::checkErrorCode(stream, "tear(...) failed"); +} + +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT tearKernelGeneric, + (dim3 & launchDims, cudaStream_t* stream, void* vx, + Nd4jLong const* xShapeInfo, Nd4jPointer* targets, + Nd4jLong const* zShapeInfo, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/tileKernel.cu b/libnd4j/include/loops/cuda/specials/tileKernel.cu index 3a2684579b59..abed61912405 100644 --- a/libnd4j/include/loops/cuda/specials/tileKernel.cu +++ b/libnd4j/include/loops/cuda/specials/tileKernel.cu @@ -21,91 +21,126 @@ #include namespace sd { - static Nd4jLong __device__ __noinline__ getIndexOffset_(Nd4jLong index, Nd4jLong const* shapeInfo) { - return shape::getIndexOffset(index, shapeInfo); - } - - static Nd4jLong __device__ __noinline__ subArrayOffset(Nd4jLong index, Nd4jLong const* shapeInfoA, Nd4jLong const* shapeInfoB) { - return shape::subArrayOffset(index, shapeInfoA, shapeInfoB); - } +static Nd4jLong __device__ __noinline__ +getIndexOffset_(Nd4jLong index, Nd4jLong const* shapeInfo) { + return shape::getIndexOffset(index, shapeInfo); +} +static Nd4jLong __device__ __noinline__ subArrayOffset( + Nd4jLong index, Nd4jLong const* shapeInfoA, Nd4jLong const* shapeInfoB) { + return shape::subArrayOffset(index, shapeInfoA, shapeInfoB); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // tileKernel: // input: (inputBuffer and inputShape) - NDArray buffer and shape to tile // output: (outputBuffer and outputShape) - NDArray to tile input // resultLength - length for output array - template - static __global__ void - tileKernel(void const *inputBuffer, Nd4jLong const* inputShape, void *outputBuffer, Nd4jLong const* outputShape, - Nd4jLong resultLength) { -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// Original code to transform in cuda-based - auto tid = blockIdx.x * blockDim.x + threadIdx.x; // copy linear sequence of elements, so one-level threading - int totalThreads = gridDim.x * blockDim.x; - if (shape::order(outputShape) == 'c') { // ews == 1 always here - for (int i = tid; i < resultLength; i += totalThreads) { - auto yOffset = subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + i) = *(reinterpret_cast(inputBuffer) + yOffset); - } - } else { - for (int i = tid; i < resultLength; i += totalThreads) { - auto xOffset = getIndexOffset_(i, outputShape); - auto yOffset = subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + xOffset) = *(reinterpret_cast(inputBuffer) + yOffset); - } - } - +template +static __global__ void tileKernel(void const* inputBuffer, + Nd4jLong const* inputShape, + void* outputBuffer, + Nd4jLong const* outputShape, + Nd4jLong resultLength) { + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Original code to transform in cuda-based + auto tid = + blockIdx.x * blockDim.x + + threadIdx.x; // copy linear sequence of elements, so one-level threading + int totalThreads = gridDim.x * blockDim.x; + if (shape::order(outputShape) == 'c') { // ews == 1 always here + for (int i = tid; i < resultLength; i += totalThreads) { + auto yOffset = subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + i) = + *(reinterpret_cast(inputBuffer) + yOffset); } - - BUILD_SINGLE_TEMPLATE(template __global__ void tileKernel,(void const* inputBuffer, Nd4jLong const* inputShape, void* outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength), LIBND4J_TYPES); - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - void tileKernelH(void const *inputBuffer, Nd4jLong const* inputShape, void *outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, cudaStream_t *stream) { - dim3 launchDims(256, 512, 8192); - tileKernel << < launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuffer, inputShape, outputBuffer, outputShape, resultLength); + } else { + for (int i = tid; i < resultLength; i += totalThreads) { + auto xOffset = getIndexOffset_(i, outputShape); + auto yOffset = subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + xOffset) = + *(reinterpret_cast(inputBuffer) + yOffset); } + } +} - BUILD_SINGLE_TEMPLATE(template void tileKernelH, (void const* inputBuffer, Nd4jLong const* inputShape, void* outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, cudaStream_t *stream), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template __global__ void tileKernel, + (void const* inputBuffer, Nd4jLong const* inputShape, + void* outputBuffer, Nd4jLong const* outputShape, + Nd4jLong resultLength), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// enhancement for tileKernel to different input and output data types: X - output type, Y - input type - template - static __global__ void - tileKernelDouble(void const *inputBuffer, Nd4jLong const* inputShape, void *outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, Nd4jLong ews) { - char ordering = shape::order(outputShape); - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int totalThreads = gridDim.x * blockDim.x; +template +void tileKernelH(void const* inputBuffer, Nd4jLong const* inputShape, + void* outputBuffer, Nd4jLong const* outputShape, + Nd4jLong resultLength, cudaStream_t* stream) { + dim3 launchDims(256, 512, 8192); + tileKernel<<>>( + inputBuffer, inputShape, outputBuffer, outputShape, resultLength); +} + +BUILD_SINGLE_TEMPLATE(template void tileKernelH, + (void const* inputBuffer, Nd4jLong const* inputShape, + void* outputBuffer, Nd4jLong const* outputShape, + Nd4jLong resultLength, cudaStream_t* stream), + LIBND4J_TYPES); - if (ordering == 'c' && ews == 1) { // ews == 1 always here - for (int i = tid; i < resultLength; i += totalThreads) { - auto yOffset = subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + i) = static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); - } - } else if (ordering == 'c' && ews > 1) { - for (int i = tid; i < resultLength; i += totalThreads) { - auto yOffset = subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + i * ews) = static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); - } - } else { - - for (int i = tid; i < resultLength; i += totalThreads) { - - auto xOffset = getIndexOffset_(i, outputShape); - auto yOffset = subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + xOffset) = static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); - } - } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// enhancement for tileKernel to different input and output data types: X - +// output type, Y - input type +template +static __global__ void tileKernelDouble(void const* inputBuffer, + Nd4jLong const* inputShape, + void* outputBuffer, + Nd4jLong const* outputShape, + Nd4jLong resultLength, Nd4jLong ews) { + char ordering = shape::order(outputShape); + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int totalThreads = gridDim.x * blockDim.x; + + if (ordering == 'c' && ews == 1) { // ews == 1 always here + for (int i = tid; i < resultLength; i += totalThreads) { + auto yOffset = subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + i) = + static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); } - - BUILD_SINGLE_TEMPLATE_TWICE(template __global__ void tileKernelDouble, (void const* inputBuffer, Nd4jLong const* inputShape, void* outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, Nd4jLong ews), LIBND4J_TYPES); - - template - void tileKernelHH(void const *inputBuffer, Nd4jLong const* inputShape, void *outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream) { - dim3 launchDims(256, 512, 8192); - tileKernelDouble<<>>(inputBuffer, inputShape, outputBuffer, outputShape, resultLength, ews); + } else if (ordering == 'c' && ews > 1) { + for (int i = tid; i < resultLength; i += totalThreads) { + auto yOffset = subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + i * ews) = + static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); } - - BUILD_SINGLE_TEMPLATE_TWICE(template void tileKernelHH, (void const* inputBuffer, Nd4jLong const* inputShape, void* outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream),LIBND4J_TYPES); -} \ No newline at end of file + } else { + for (int i = tid; i < resultLength; i += totalThreads) { + auto xOffset = getIndexOffset_(i, outputShape); + auto yOffset = subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + xOffset) = + static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); + } + } +} + +BUILD_SINGLE_TEMPLATE_TWICE(template __global__ void tileKernelDouble, + (void const* inputBuffer, + Nd4jLong const* inputShape, void* outputBuffer, + Nd4jLong const* outputShape, Nd4jLong resultLength, + Nd4jLong ews), + LIBND4J_TYPES); + +template +void tileKernelHH(void const* inputBuffer, Nd4jLong const* inputShape, + void* outputBuffer, Nd4jLong const* outputShape, + Nd4jLong resultLength, Nd4jLong ews, cudaStream_t* stream) { + dim3 launchDims(256, 512, 8192); + tileKernelDouble<<>>( + inputBuffer, inputShape, outputBuffer, outputShape, resultLength, ews); +} + +BUILD_SINGLE_TEMPLATE_TWICE(template void tileKernelHH, + (void const* inputBuffer, + Nd4jLong const* inputShape, void* outputBuffer, + Nd4jLong const* outputShape, Nd4jLong resultLength, + Nd4jLong ews, cudaStream_t* stream), + LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/summarystatsreduce.cu b/libnd4j/include/loops/cuda/summarystatsreduce.cu index 89f869a2a213..d7acde30e8d7 100644 --- a/libnd4j/include/loops/cuda/summarystatsreduce.cu +++ b/libnd4j/include/loops/cuda/summarystatsreduce.cu @@ -18,402 +18,414 @@ // @author raver119@gmail.com // - -#include -#include -#include -#include -#include -#include -#include -#include -#include #include #include #include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include using namespace simdOps; namespace functions { - namespace summarystats { +namespace summarystats { template -void _CUDA_G summaryStatsReduceT(int op, void const* dx, Nd4jLong const* xShapeInfo, int xRank, void *extraParams, void *z, Nd4jLong const* zShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot,bool biasCorrected,int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - - functions::summarystats::SummaryStatsReduce::transform(op,dx,xShapeInfo,extraParams,z,zShapeInfo,dimension,dimensionLength,biasCorrected,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets); +void _CUDA_G summaryStatsReduceT( + int op, void const* dx, Nd4jLong const* xShapeInfo, int xRank, + void* extraParams, void* z, Nd4jLong const* zShapeInfo, int zRank, + int* dimension, int dimensionLength, int postProcessOrNot, + bool biasCorrected, int* allocationBuffer, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { + functions::summarystats::SummaryStatsReduce::transform( + op, dx, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, biasCorrected, allocationBuffer, reductionBuffer, + tadOnlyShapeInfo, tadOffsets); } - /** - * - * @param sPartialsRef - * @param tid - * @param extraParams - */ - template - template - _CUDA_D void SummaryStatsReduce::aggregatePartials(SummaryStatsData **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, void *vextraParams) { - // start the shared memory loop on the next power of 2 less - // than the block size. If block size is not a power of 2, - // accumulate the intermediate sums in the remainder range. - auto extraParams = static_cast(vextraParams); - SummaryStatsData *sPartials = *sPartialsRef; - Nd4jLong floorPow2 = blockDim.x; - - if (floorPow2 & (floorPow2 - 1)) { - while (floorPow2 & (floorPow2 - 1)) { - floorPow2 &= floorPow2 - 1; - } - - if (tid >= floorPow2) { - SummaryStatsData prev = sPartials[tid - floorPow2]; - SummaryStatsData curr = sPartials[tid]; - sPartials[tid - floorPow2] = update(prev, curr, extraParams); - } - __syncthreads(); - } - - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (tid < activeThreads && tid + activeThreads < numElements) { - SummaryStatsData curr = sPartials[tid]; - SummaryStatsData next = sPartials[tid + activeThreads]; - sPartials[tid] = update(curr, next, extraParams); - } - __syncthreads(); - } - }; - - /** - * @param n n is the number of - * elements to loop through - * @param dx the data to operate on - * @param xVectorInfo the meta data for the vector: - * 0 is the offset - * 1 is the increment/stride - * 2 is the real length of the buffer (n and dx.length won't always be the same) - * 3 is the element wise stride for the buffer - * 4 is the number of elements it takes to get to the next row/column/tensor - * @param gpuInformation - * 0 is the block size - * 1 is the grid size - * 2 is the shared memory size - * @param problemDefinition - * 0 is the number of elements per vector - * 1 is the number of vectors - */ - template - template - _CUDA_D void SummaryStatsReduce::transform(void const* vx, Nd4jLong const* xShapeInfo, - void *vextraParams, - void *vz, Nd4jLong const* zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, void *vreductionBuffer, - Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - - auto dx = static_cast(vx); - auto z = static_cast(vz); - auto extraParams = static_cast(vextraParams); - auto reductionBuffer = static_cast(vreductionBuffer); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - __shared__ volatile int resultScalar; - - __shared__ int xElementWiseStride; - - int numElements = blockDim.x; - //shared memory space for storing intermediate results - __shared__ SummaryStatsData *sPartials; - if(threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast*>(shmem); - } - __syncthreads(); - - Z startingVal = startingValue(dx); - - SummaryStatsData val; - val.initWithValue(startingVal); - val.n = 0; - sPartials[threadIdx.x] = val; - - - //length for the tad - __shared__ volatile int xLength; - - __shared__ volatile int resultLength; - - - SummaryStatsData reduction; - reduction.initWithValue(0.0); - reduction.n = 0; - if (threadIdx.x == 0) { - if (zShapeInfo != nullptr) - resultLength = shape::length(zShapeInfo); - else resultLength = 1; - - - if (dimensionLength == 1) { - if (resultLength == 1 && (dimension == nullptr || dimension[0] == MAX_DIMENSION)) - resultScalar = 1; - else - resultScalar = 0; - } - else - resultScalar = 0; - - if (resultLength == 1) - resultScalar = 1; - - auto xStride = shape::stride(xShapeInfo); - auto xOrder = shape::order(xShapeInfo); - - if (dimension != nullptr && (dimension[0] != MAX_DIMENSION && dimensionLength == 1)) { - xElementWiseStride = xStride[dimension[0]]; - } - else { - xElementWiseStride = shape::elementWiseStride(xShapeInfo); - } - - - xLength = shape::length(xShapeInfo); - - - } - __syncthreads(); - if (!resultScalar) { - - __shared__ int tadLength; - __shared__ int tadEWS; - __shared__ int numTads; - - if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); - tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; - } - __syncthreads(); - - if (tadEWS == 0) { - - for (int r = blockIdx.x; r < numTads; r += gridDim.x) { - auto tadOffsetForBlock = tadOffsets[r]; - - val.initWithValue(startingVal); - val.n = 0; - sPartials[threadIdx.x] = val; - - for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { - auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); - SummaryStatsData indexVal2; - indexVal2.initWithValue(dx[xOffset]); - - sPartials[threadIdx.x] = update(sPartials[threadIdx.x], OpType::op(indexVal2, extraParams), extraParams); - } - __syncthreads(); - aggregatePartials(&sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - - __syncthreads(); - if (threadIdx.x == 0) { - z[r] = OpType::getValue(postProcessOrNot, sPartials[threadIdx.x]); - } - __syncthreads(); - } - } - else { - - for (int i = blockIdx.x; i < numTads; i += gridDim.x) { - auto tadOffsetForBlock = tadOffsets[i]; - - val.initWithValue(startingVal); - val.n = 0; - sPartials[threadIdx.x] = val; - - for (int x = threadIdx.x; x < tadLength; x += blockDim.x) { - auto indexX = tadOffsetForBlock + x * tadEWS; - SummaryStatsData indexVal2; - indexVal2.initWithValue(dx[indexX]); - sPartials[threadIdx.x] = update(sPartials[threadIdx.x], OpType::op(indexVal2, extraParams), extraParams); - } - - __syncthreads(); - aggregatePartials(&sPartials, threadIdx.x, sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - - __syncthreads(); - if (threadIdx.x == 0) { - z[i] = OpType::getValue(postProcessOrNot, sPartials[threadIdx.x]); //postProcess(sPartials[0],tadLength ,extraParams); - } - } - } - } - else if (resultScalar) { - __shared__ int n; - if (threadIdx.x == 0) { - xElementWiseStride = shape::elementWiseStride(xShapeInfo); - n = shape::length(xShapeInfo); - } - __syncthreads(); - - if (xElementWiseStride >= 1) { - for (Nd4jLong i = tid; i < n; i += (blockDim.x * gridDim.x)) { - SummaryStatsData indexVal2; - indexVal2.initWithValue(dx[i * xElementWiseStride]); - reduction = update(reduction, indexVal2, extraParams); - } - } - else { - - for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) { - - auto offset = shape::getIndexOffset(i, xShapeInfo); - SummaryStatsData indexVal2; - indexVal2.initWithValue(dx[offset]); - reduction = update(reduction, indexVal2, extraParams); - } - } - sPartials[threadIdx.x] = reduction; - - __syncthreads(); - aggregatePartials(&sPartials, threadIdx.x, blockDim.x, extraParams); - __syncthreads(); - - if (gridDim.x > 1) { - __shared__ bool amLast; - unsigned int *tc = (unsigned int *)reductionBuffer; - tid = threadIdx.x; - if (threadIdx.x == 0) { - SummaryStatsData *pBuffer = (SummaryStatsData*) reductionBuffer; - pBuffer[blockIdx.x] = sPartials[0]; - } - __threadfence(); - __syncthreads(); - - if (tid == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - - __syncthreads(); - - if (amLast) { - tc[16384] = 0; - SummaryStatsData* pBuffer = (SummaryStatsData*) reductionBuffer; - - Z startingVal = startingValue(dx); - - SummaryStatsData val; - val.initWithValue(startingVal); - val.n = 0; - sPartials[threadIdx.x] = val; - - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) { - sPartials[threadIdx.x] = update(sPartials[threadIdx.x], pBuffer[i], extraParams); - } - - __syncthreads(); - aggregatePartials(&sPartials, threadIdx.x, gridDim.x, extraParams); - __syncthreads(); - - if (tid == 0) { - z[0] = OpType::getValue(postProcessOrNot, sPartials[0]); - } - } - } - else { - if (tid == 0) { - unsigned int *tc = (unsigned *)reductionBuffer; - tc[16384] = 0; - z[0] = z[0] = OpType::getValue(postProcessOrNot, sPartials[0]); - } - } - } - }; - - - template - _CUDA_D void SummaryStatsReduce::transform(const int opNum, void const* dx, Nd4jLong const* xShapeInfo, void *extraParams, void *z, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { - DISPATCH_BY_OPNUM_TT(transform, PARAMS(dx, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadOnlyShapeInfo, tadOffsets), SUMMARY_STATS_OPS); - }; - - - template - _CUDA_H void SummaryStatsReduce::execSummaryStatsReduceScalar(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *vextraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool biasCorrected, void *reductionBuffer) { - - auto x = static_cast(vx); - auto extraParams = static_cast(vextraParams); - auto z = reinterpret_cast(vz); - auto reductionPointerA = reinterpret_cast(reductionBuffer); - - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("D16 opNum:[%i]\n", opNum); - - summaryStatsReduceT<<>>( - opNum, - x, - xShapeInfo, shape::rank(hxShapeInfo), - extraParams, - z, - zShapeInfo, shape::rank(hzShapeInfo), - nullptr, - 1, - 1,biasCorrected, nullptr, reductionPointerA, tadShapeInfo, tadOffsets); - - // this is blocking method since method should return scalar - sd::DebugHelper::checkErrorCode(stream, "execSSReduceScalar(...) failed"); - } +/** + * + * @param sPartialsRef + * @param tid + * @param extraParams + */ +template +template +_CUDA_D void SummaryStatsReduce::aggregatePartials( + SummaryStatsData** sPartialsRef, Nd4jLong tid, Nd4jLong numElements, + void* vextraParams) { + // start the shared memory loop on the next power of 2 less + // than the block size. If block size is not a power of 2, + // accumulate the intermediate sums in the remainder range. + auto extraParams = static_cast(vextraParams); + SummaryStatsData* sPartials = *sPartialsRef; + Nd4jLong floorPow2 = blockDim.x; + + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) { + floorPow2 &= floorPow2 - 1; + } - template - _CUDA_H void SummaryStatsReduce::execSummaryStatsReduce(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *vextraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool biasCorrected, void *reductionBuffer) { + if (tid >= floorPow2) { + SummaryStatsData prev = sPartials[tid - floorPow2]; + SummaryStatsData curr = sPartials[tid]; + sPartials[tid - floorPow2] = update(prev, curr, extraParams); + } + __syncthreads(); + } + + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (tid < activeThreads && tid + activeThreads < numElements) { + SummaryStatsData curr = sPartials[tid]; + SummaryStatsData next = sPartials[tid + activeThreads]; + sPartials[tid] = update(curr, next, extraParams); + } + __syncthreads(); + } +}; + +/** + * @param n n is the number of + * elements to loop through + * @param dx the data to operate on + * @param xVectorInfo the meta data for the vector: + * 0 is the offset + * 1 is the increment/stride + * 2 is the real length of the buffer (n and + * dx.length won't always be the same) 3 is the element wise stride for the + * buffer 4 is the number of elements it takes to get to the next + * row/column/tensor + * @param gpuInformation + * 0 is the block size + * 1 is the grid size + * 2 is the shared memory size + * @param problemDefinition + * 0 is the number of elements per vector + * 1 is the number of vectors + */ +template +template +_CUDA_D void SummaryStatsReduce::transform( + void const* vx, Nd4jLong const* xShapeInfo, void* vextraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationBuffer, void* vreductionBuffer, + Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets) { + auto dx = static_cast(vx); + auto z = static_cast(vz); + auto extraParams = static_cast(vextraParams); + auto reductionBuffer = static_cast(vreductionBuffer); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + __shared__ volatile int resultScalar; + + __shared__ int xElementWiseStride; + + int numElements = blockDim.x; + // shared memory space for storing intermediate results + __shared__ SummaryStatsData* sPartials; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sPartials = reinterpret_cast*>(shmem); + } + __syncthreads(); + + Z startingVal = startingValue(dx); + + SummaryStatsData val; + val.initWithValue(startingVal); + val.n = 0; + sPartials[threadIdx.x] = val; + + // length for the tad + __shared__ volatile int xLength; + + __shared__ volatile int resultLength; + + SummaryStatsData reduction; + reduction.initWithValue(0.0); + reduction.n = 0; + if (threadIdx.x == 0) { + if (zShapeInfo != nullptr) + resultLength = shape::length(zShapeInfo); + else + resultLength = 1; + + if (dimensionLength == 1) { + if (resultLength == 1 && + (dimension == nullptr || dimension[0] == MAX_DIMENSION)) + resultScalar = 1; + else + resultScalar = 0; + } else + resultScalar = 0; + + if (resultLength == 1) resultScalar = 1; + + auto xStride = shape::stride(xShapeInfo); + auto xOrder = shape::order(xShapeInfo); + + if (dimension != nullptr && + (dimension[0] != MAX_DIMENSION && dimensionLength == 1)) { + xElementWiseStride = xStride[dimension[0]]; + } else { + xElementWiseStride = shape::elementWiseStride(xShapeInfo); + } - auto x = static_cast(vx); - auto z = static_cast(vz); - auto extraParams = static_cast(vextraParams); + xLength = shape::length(xShapeInfo); + } + __syncthreads(); + if (!resultScalar) { + __shared__ int tadLength; + __shared__ int tadEWS; + __shared__ int numTads; + + if (threadIdx.x == 0) { + tadLength = + shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + } + __syncthreads(); - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("F17 opNum:[%i]\n", opNum); + if (tadEWS == 0) { + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto tadOffsetForBlock = tadOffsets[r]; - auto reductionPointerA = reinterpret_cast(reductionBuffer); + val.initWithValue(startingVal); + val.n = 0; + sPartials[threadIdx.x] = val; - summaryStatsReduceT<<>>( - opNum, - x, - xShapeInfo, shape::rank(hxShapeInfo), - extraParams, - z, - zShapeInfo, shape::rank(hzShapeInfo), - nullptr, - 1, - 1,biasCorrected, nullptr, reductionPointerA, tadShapeInfo, tadOffsets); + for (int i = threadIdx.x; i < tadLength; i += blockDim.x) { + auto xOffset = + tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo); + SummaryStatsData indexVal2; + indexVal2.initWithValue(dx[xOffset]); - DEBUG_KERNEL(stream, opNum); + sPartials[threadIdx.x] = + update(sPartials[threadIdx.x], OpType::op(indexVal2, extraParams), + extraParams); + } + __syncthreads(); + aggregatePartials( + &sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), extraParams); + + __syncthreads(); + if (threadIdx.x == 0) { + z[r] = OpType::getValue(postProcessOrNot, sPartials[threadIdx.x]); + } + __syncthreads(); + } + } else { + for (int i = blockIdx.x; i < numTads; i += gridDim.x) { + auto tadOffsetForBlock = tadOffsets[i]; + + val.initWithValue(startingVal); + val.n = 0; + sPartials[threadIdx.x] = val; + + for (int x = threadIdx.x; x < tadLength; x += blockDim.x) { + auto indexX = tadOffsetForBlock + x * tadEWS; + SummaryStatsData indexVal2; + indexVal2.initWithValue(dx[indexX]); + sPartials[threadIdx.x] = + update(sPartials[threadIdx.x], OpType::op(indexVal2, extraParams), + extraParams); } + __syncthreads(); + aggregatePartials( + &sPartials, threadIdx.x, + sd::math::nd4j_min(blockDim.x, tadLength), extraParams); - template - _CUDA_H void SummaryStatsReduce::execSummaryStatsReduce(dim3& launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *vextraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool biasCorrected, void *reductionBuffer) { + __syncthreads(); + if (threadIdx.x == 0) { + z[i] = OpType::getValue(postProcessOrNot, + sPartials[threadIdx.x]); // postProcess(sPartials[0],tadLength + // ,extraParams); + } + } + } + } else if (resultScalar) { + __shared__ int n; + if (threadIdx.x == 0) { + xElementWiseStride = shape::elementWiseStride(xShapeInfo); + n = shape::length(xShapeInfo); + } + __syncthreads(); + + if (xElementWiseStride >= 1) { + for (Nd4jLong i = tid; i < n; i += (blockDim.x * gridDim.x)) { + SummaryStatsData indexVal2; + indexVal2.initWithValue(dx[i * xElementWiseStride]); + reduction = update(reduction, indexVal2, extraParams); + } + } else { + for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) { + auto offset = shape::getIndexOffset(i, xShapeInfo); + SummaryStatsData indexVal2; + indexVal2.initWithValue(dx[offset]); + reduction = update(reduction, indexVal2, extraParams); + } + } + sPartials[threadIdx.x] = reduction; + + __syncthreads(); + aggregatePartials(&sPartials, threadIdx.x, blockDim.x, extraParams); + __syncthreads(); + + if (gridDim.x > 1) { + __shared__ bool amLast; + unsigned int* tc = (unsigned int*)reductionBuffer; + tid = threadIdx.x; + if (threadIdx.x == 0) { + SummaryStatsData* pBuffer = (SummaryStatsData*)reductionBuffer; + pBuffer[blockIdx.x] = sPartials[0]; + } + __threadfence(); + __syncthreads(); + + if (tid == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } + + __syncthreads(); + + if (amLast) { + tc[16384] = 0; + SummaryStatsData* pBuffer = (SummaryStatsData*)reductionBuffer; + + Z startingVal = startingValue(dx); + + SummaryStatsData val; + val.initWithValue(startingVal); + val.n = 0; + sPartials[threadIdx.x] = val; + + for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) { + sPartials[threadIdx.x] = + update(sPartials[threadIdx.x], pBuffer[i], extraParams); + } - auto x = static_cast(vx); - auto z = static_cast(vz); - auto extraParams = static_cast(vextraParams); + __syncthreads(); + aggregatePartials(&sPartials, threadIdx.x, gridDim.x, + extraParams); + __syncthreads(); - if (sd::Environment::getInstance()->isDebugAndVerbose()) - printf("D18 opNum:[%i]\n", opNum); + if (tid == 0) { + z[0] = OpType::getValue(postProcessOrNot, sPartials[0]); + } + } + } else { + if (tid == 0) { + unsigned int* tc = (unsigned*)reductionBuffer; + tc[16384] = 0; + z[0] = z[0] = OpType::getValue(postProcessOrNot, sPartials[0]); + } + } + } +}; + +template +_CUDA_D void SummaryStatsReduce::transform( + const int opNum, void const* dx, Nd4jLong const* xShapeInfo, + void* extraParams, void* z, Nd4jLong const* zShapeInfo, int* dimension, + int dimensionLength, int postProcessOrNot, int* allocationBuffer, + void* reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets) { + DISPATCH_BY_OPNUM_TT( + transform, + PARAMS(dx, xShapeInfo, extraParams, z, zShapeInfo, dimension, + dimensionLength, postProcessOrNot, allocationBuffer, + reductionBuffer, tadOnlyShapeInfo, tadOffsets), + SUMMARY_STATS_OPS); +}; - summaryStatsReduceT<<>>( - opNum, - x, - xShapeInfo, shape::rank(hxShapeInfo), - extraParams, - z, - zShapeInfo, shape::rank(hzShapeInfo), - dimension, - dimensionLength, - 1, biasCorrected, nullptr, reinterpret_cast(reductionBuffer), tadShapeInfo, tadOffsets); +template +_CUDA_H void SummaryStatsReduce::execSummaryStatsReduceScalar( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void* vextraParams, + void* vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + bool biasCorrected, void* reductionBuffer) { + auto x = static_cast(vx); + auto extraParams = static_cast(vextraParams); + auto z = reinterpret_cast(vz); + auto reductionPointerA = reinterpret_cast(reductionBuffer); + + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("D16 opNum:[%i]\n", opNum); + + summaryStatsReduceT + <<>>( + opNum, x, xShapeInfo, shape::rank(hxShapeInfo), extraParams, z, + zShapeInfo, shape::rank(hzShapeInfo), nullptr, 1, 1, biasCorrected, + nullptr, reductionPointerA, tadShapeInfo, tadOffsets); + + // this is blocking method since method should return scalar + sd::DebugHelper::checkErrorCode(stream, "execSSReduceScalar(...) failed"); +} - DEBUG_KERNEL(stream, opNum); - } +template +_CUDA_H void SummaryStatsReduce::execSummaryStatsReduce( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void* vextraParams, + void* vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + bool biasCorrected, void* reductionBuffer) { + auto x = static_cast(vx); + auto z = static_cast(vz); + auto extraParams = static_cast(vextraParams); + + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("F17 opNum:[%i]\n", opNum); + + auto reductionPointerA = reinterpret_cast(reductionBuffer); + + summaryStatsReduceT + <<>>( + opNum, x, xShapeInfo, shape::rank(hxShapeInfo), extraParams, z, + zShapeInfo, shape::rank(hzShapeInfo), nullptr, 1, 1, biasCorrected, + nullptr, reductionPointerA, tadShapeInfo, tadOffsets); + + DEBUG_KERNEL(stream, opNum); +} +template +_CUDA_H void SummaryStatsReduce::execSummaryStatsReduce( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void* vextraParams, + void* vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, + int* dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, bool biasCorrected, void* reductionBuffer) { + auto x = static_cast(vx); + auto z = static_cast(vz); + auto extraParams = static_cast(vextraParams); + + if (sd::Environment::getInstance()->isDebugAndVerbose()) + printf("D18 opNum:[%i]\n", opNum); + + summaryStatsReduceT + <<>>( + opNum, x, xShapeInfo, shape::rank(hxShapeInfo), extraParams, z, + zShapeInfo, shape::rank(hzShapeInfo), dimension, dimensionLength, 1, + biasCorrected, nullptr, reinterpret_cast(reductionBuffer), + tadShapeInfo, tadOffsets); + + DEBUG_KERNEL(stream, opNum); +} - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT SummaryStatsReduce, , LIBND4J_TYPES, FLOAT_TYPES); - } -} \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT SummaryStatsReduce, , + LIBND4J_TYPES, FLOAT_TYPES); +} // namespace summarystats +} // namespace functions \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/transform/transform_any.cu b/libnd4j/include/loops/cuda/transform/transform_any.cu index 912c693bc0bf..7fcb39b04b04 100644 --- a/libnd4j/include/loops/cuda/transform/transform_any.cu +++ b/libnd4j/include/loops/cuda/transform/transform_any.cu @@ -18,118 +18,112 @@ // @author raver119@gmail.com // -#include +#include +#include #include -#include +#include #include - -#include -#include +#include using namespace simdOps; - template -__global__ void transformAnySimple( - const void *x, const Nd4jLong *xShapeInfo, int xRank, - void *params, - void *z, const Nd4jLong *zShapeInfo, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - functions::transform::TransformAny::template transformCuda(x,xShapeInfo,params,z,zShapeInfo,allocationPointer,reductionPointer,tadShapeInfo, tadOffsets); +__global__ void transformAnySimple(const void *x, const Nd4jLong *xShapeInfo, + int xRank, void *params, void *z, + const Nd4jLong *zShapeInfo, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + functions::transform::TransformAny::template transformCuda( + x, xShapeInfo, params, z, zShapeInfo, allocationPointer, reductionPointer, + tadShapeInfo, tadOffsets); } - namespace functions { - namespace transform { - - template - _CUDA_H void TransformAny::executeTransformShaped( - dim3 launchDims, cudaStream_t *stream, - const int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), TRANSFORM_ANY_OPS); - - DEBUG_KERNEL(stream, opNum); - } - - - template - template - __device__ void TransformAny::transformCuda( - const void *vx, const Nd4jLong *xShapeInfo, - void *vparams, - void *vz, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *vreductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto params = reinterpret_cast(vparams); - auto reductionPointer = reinterpret_cast(vreductionPointer); - - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong zEws; - __shared__ char xOrder; - __shared__ char zOrder; - __shared__ Nd4jLong length; - - if (threadIdx.x == 0) { - - xEws = shape::elementWiseStride(xShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - zOrder = shape::order(zShapeInfo); - length = shape::length(xShapeInfo); - } - __syncthreads(); - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int totalThreads = gridDim.x * blockDim.x; - - if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { - - for (int i = tid; i < length; i += totalThreads) - z[i * zEws] = OpType::op(x[i * xEws], params); - } - else { - if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - z[xOffset] = OpType::op(x[xOffset], params); - } - } - else { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - z[zOffset] = OpType::op(x[xOffset], params); - } - } - } - }; - - - template - template - _CUDA_H void TransformAny::intermediateShaped( - dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - transformAnySimple<<>>(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - - sd::DebugHelper::checkErrorCode(stream, "transformAny(...) failed"); - } +namespace transform { + +template +_CUDA_H void TransformAny::executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_TT(intermediateShaped, + PARAMS(launchDims, stream, x, xShape, xRank, extraParams, + z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + TRANSFORM_ANY_OPS); + + DEBUG_KERNEL(stream, opNum); +} - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformAny, , LIBND4J_TYPES, LIBND4J_TYPES); +template +template +__device__ void TransformAny::transformCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vparams, void *vz, + const Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto params = reinterpret_cast(vparams); + auto reductionPointer = reinterpret_cast(vreductionPointer); + + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong zEws; + __shared__ char xOrder; + __shared__ char zOrder; + __shared__ Nd4jLong length; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + zOrder = shape::order(zShapeInfo); + length = shape::length(xShapeInfo); + } + __syncthreads(); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int totalThreads = gridDim.x * blockDim.x; + + if (xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { + for (int i = tid; i < length; i += totalThreads) + z[i * zEws] = OpType::op(x[i * xEws], params); + } else { + if (vx == vz) { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + z[xOffset] = OpType::op(x[xOffset], params); + } + } else { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + z[zOffset] = OpType::op(x[xOffset], params); + } } + } +}; + +template +template +_CUDA_H void TransformAny::intermediateShaped( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + transformAnySimple + <<>>( + x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets); + + sd::DebugHelper::checkErrorCode(stream, "transformAny(...) failed"); } + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformAny, , LIBND4J_TYPES, + LIBND4J_TYPES); +} // namespace transform +} // namespace functions diff --git a/libnd4j/include/loops/cuda/transform/transform_bool.cu b/libnd4j/include/loops/cuda/transform/transform_bool.cu index 3f8674b698b5..791a48097a6b 100644 --- a/libnd4j/include/loops/cuda/transform/transform_bool.cu +++ b/libnd4j/include/loops/cuda/transform/transform_bool.cu @@ -18,123 +18,118 @@ // @author raver119@gmail.com // -#include +#include +#include #include -#include +#include #include - -#include -#include +#include using namespace simdOps; - template -__global__ void transformBoolSimple( - const void *x, const Nd4jLong *xShapeInfo, int xRank, - void *params, - void *z, const Nd4jLong *zShapeInfo, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - functions::transform::TransformBool::template transformCuda(x,xShapeInfo,params,z,zShapeInfo,allocationPointer,reductionPointer,tadShapeInfo, tadOffsets); +__global__ void transformBoolSimple(const void *x, const Nd4jLong *xShapeInfo, + int xRank, void *params, void *z, + const Nd4jLong *zShapeInfo, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + functions::transform::TransformBool::template transformCuda( + x, xShapeInfo, params, z, zShapeInfo, allocationPointer, reductionPointer, + tadShapeInfo, tadOffsets); } - namespace functions { - namespace transform { - - template - _CUDA_H void TransformBool::executeTransformShaped( - dim3 launchDims, cudaStream_t *stream, - const int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), TRANSFORM_BOOL_OPS); +namespace transform { + +template +_CUDA_H void TransformBool::executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_TT(intermediateShaped, + PARAMS(launchDims, stream, x, xShape, xRank, extraParams, + z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + TRANSFORM_BOOL_OPS); + + DEBUG_KERNEL(stream, opNum); +} - DEBUG_KERNEL(stream, opNum); +template +template +__device__ void TransformBool::transformCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vparams, void *vz, + const Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto x = static_cast(vx); + auto z = static_cast(vz); + auto params = static_cast(vparams); + auto reductionPointer = static_cast(vreductionPointer); + + if (OpType::requiresSpecial) { + OpType::execSpecialCuda(x, xShapeInfo, z, zShapeInfo, params, + allocationPointer, reductionPointer, tadShapeInfo, + tadOffsets); + return; + } else { + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong zEws; + __shared__ char xOrder; + __shared__ char zOrder; + __shared__ Nd4jLong length; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + zOrder = shape::order(zShapeInfo); + length = shape::length(xShapeInfo); + } + __syncthreads(); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int totalThreads = gridDim.x * blockDim.x; + + if (xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { + for (int i = tid; i < length; i += totalThreads) + z[i * zEws] = OpType::op(x[i * xEws], params); + } else { + if (vx == vz) { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + z[xOffset] = OpType::op(x[xOffset], params); } - - - template - template - __device__ void TransformBool::transformCuda( - const void *vx, const Nd4jLong *xShapeInfo, - void *vparams, - void *vz, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *vreductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = static_cast(vx); - auto z = static_cast(vz); - auto params = static_cast(vparams); - auto reductionPointer = static_cast(vreductionPointer); - - if(OpType::requiresSpecial) { - OpType::execSpecialCuda(x,xShapeInfo,z,zShapeInfo,params, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - return; - } - else { - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong zEws; - __shared__ char xOrder; - __shared__ char zOrder; - __shared__ Nd4jLong length; - - if (threadIdx.x == 0) { - - xEws = shape::elementWiseStride(xShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - zOrder = shape::order(zShapeInfo); - length = shape::length(xShapeInfo); - } - __syncthreads(); - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int totalThreads = gridDim.x * blockDim.x; - - if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { - - for (int i = tid; i < length; i += totalThreads) - z[i * zEws] = OpType::op(x[i * xEws], params); - } - else { - if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - z[xOffset] = OpType::op(x[xOffset], params); - } - } - else { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - z[zOffset] = OpType::op(x[xOffset], params); - } - } - } - } - }; - - - template - template - _CUDA_H void TransformBool::intermediateShaped( - dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - transformBoolSimple<<>>(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - sd::DebugHelper::checkErrorCode(stream, "transformBool(...) failed"); - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES); + } else { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + z[zOffset] = OpType::op(x[xOffset], params); + } + } } + } +}; + +template +template +_CUDA_H void TransformBool::intermediateShaped( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + transformBoolSimple + <<>>( + x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets); + sd::DebugHelper::checkErrorCode(stream, "transformBool(...) failed"); } + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformBool, , LIBND4J_TYPES, + BOOL_TYPES); +} // namespace transform +} // namespace functions diff --git a/libnd4j/include/loops/cuda/transform/transform_float.cu b/libnd4j/include/loops/cuda/transform/transform_float.cu index cb0afe95b2fd..85e564798b4f 100644 --- a/libnd4j/include/loops/cuda/transform/transform_float.cu +++ b/libnd4j/include/loops/cuda/transform/transform_float.cu @@ -18,130 +18,132 @@ // @author raver119@gmail.com // -#include +#include +#include #include -#include +#include #include - -#include -#include +#include using namespace simdOps; template -__global__ void transformFloatSimple(const void *x, const Nd4jLong *xShapeInfo, int xRank, - void *params, - void *z, const Nd4jLong *zShapeInfo, int zRank, - int *allocationPointer, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - functions::transform::TransformFloat::template transformCuda( - x, xShapeInfo, - params, - z, zShapeInfo, - allocationPointer, reductionPointer, - tadShapeInfo, tadOffsets); +__global__ void transformFloatSimple(const void *x, const Nd4jLong *xShapeInfo, + int xRank, void *params, void *z, + const Nd4jLong *zShapeInfo, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + functions::transform::TransformFloat::template transformCuda( + x, xShapeInfo, params, z, zShapeInfo, allocationPointer, reductionPointer, + tadShapeInfo, tadOffsets); } - namespace functions { - namespace transform { - - template - _CUDA_H void TransformFloat::executeTransformShaped(dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, const Nd4jLong *xShape, int xRank, void *extraParams, void *z, const Nd4jLong *zShape, int zRank, int *allocationPointer, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), TRANSFORM_FLOAT_OPS); +namespace transform { + +template +_CUDA_H void TransformFloat::executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_TT(intermediateShaped, + PARAMS(launchDims, stream, x, xShape, xRank, extraParams, + z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + TRANSFORM_FLOAT_OPS); + + DEBUG_KERNEL(stream, opNum); +} - DEBUG_KERNEL(stream, opNum); +template +template +__device__ void TransformFloat::transformCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vparams, void *vz, + const Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto params = reinterpret_cast(vparams); + auto reductionPointer = reinterpret_cast(vreductionPointer); + + if (OpType::requiresSpecial) { + OpType::execSpecialCuda(x, xShapeInfo, z, zShapeInfo, params, + allocationPointer, reductionPointer, tadShapeInfo, + tadOffsets); + return; + } else { + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong zEws; + __shared__ char xOrder; + __shared__ char zOrder; + __shared__ Nd4jLong length; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + zOrder = shape::order(zShapeInfo); + length = shape::length(xShapeInfo); + } + __syncthreads(); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int totalThreads = gridDim.x * blockDim.x; + + if (xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { + for (Nd4jLong i = tid; i < length; i += totalThreads) + z[i * zEws] = OpType::op(x[i * xEws], params); + } else { + if (vx == vz) { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + z[xOffset] = OpType::op(x[xOffset], params); } - - - template - template - __device__ void TransformFloat::transformCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vparams, - void *vz, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *vreductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto params = reinterpret_cast(vparams); - auto reductionPointer = reinterpret_cast(vreductionPointer); - - if(OpType::requiresSpecial) { - OpType::execSpecialCuda(x,xShapeInfo,z,zShapeInfo,params, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - return; - } - else { - - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong zEws; - __shared__ char xOrder; - __shared__ char zOrder; - __shared__ Nd4jLong length; - - if (threadIdx.x == 0) { - - xEws = shape::elementWiseStride(xShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - zOrder = shape::order(zShapeInfo); - length = shape::length(xShapeInfo); - } - __syncthreads(); - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int totalThreads = gridDim.x * blockDim.x; - - if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { - - for (Nd4jLong i = tid; i < length; i += totalThreads) - z[i * zEws] = OpType::op(x[i * xEws], params); - } - else { - if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - z[xOffset] = OpType::op(x[xOffset], params); - } - } - else { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - z[zOffset] = OpType::op(x[xOffset], params); - } - } - } - } - }; - - template - __device__ void TransformFloat::transformCudaLegacy( - const int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *params, - void *z, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - DISPATCH_BY_OPNUM_TT(transformCuda, PARAMS(x, xShapeInfo, params, z, zShapeInfo, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), TRANSFORM_FLOAT_OPS); + } else { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + z[zOffset] = OpType::op(x[xOffset], params); } - - template - template - _CUDA_H void TransformFloat::intermediateShaped( - dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - transformFloatSimple<<>>(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - - sd::DebugHelper::checkErrorCode(stream, "transformFloat(...) failed"); - } - - BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES); + } } + } +}; + +template +__device__ void TransformFloat::transformCudaLegacy( + const int opNum, const void *x, const Nd4jLong *xShapeInfo, void *params, + void *z, const Nd4jLong *zShapeInfo, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_TT( + transformCuda, + PARAMS(x, xShapeInfo, params, z, zShapeInfo, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + TRANSFORM_FLOAT_OPS); +} + +template +template +_CUDA_H void TransformFloat::intermediateShaped( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + transformFloatSimple + <<>>( + x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets); + + sd::DebugHelper::checkErrorCode(stream, "transformFloat(...) failed"); } + +BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT TransformFloat, , LIBND4J_TYPES, + FLOAT_TYPES); +} // namespace transform +} // namespace functions diff --git a/libnd4j/include/loops/cuda/transform/transform_same.cu b/libnd4j/include/loops/cuda/transform/transform_same.cu index f79e5d0917a4..91036d08e680 100644 --- a/libnd4j/include/loops/cuda/transform/transform_same.cu +++ b/libnd4j/include/loops/cuda/transform/transform_same.cu @@ -18,112 +18,117 @@ // @author raver119@gmail.com // -#include +#include +#include #include -#include +#include #include - -#include -#include +#include using namespace simdOps; template -__global__ void transformSameSimple(const void *x, const Nd4jLong *xShapeInfo, int xRank, - void *params, - void *z, const Nd4jLong *zShapeInfo, int zRank, - int *allocationPointer, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - functions::transform::TransformSame::template transformCuda(x,xShapeInfo,params,z,zShapeInfo,allocationPointer,reductionPointer, tadShapeInfo, tadOffsets); +__global__ void transformSameSimple(const void *x, const Nd4jLong *xShapeInfo, + int xRank, void *params, void *z, + const Nd4jLong *zShapeInfo, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + functions::transform::TransformSame::template transformCuda( + x, xShapeInfo, params, z, zShapeInfo, allocationPointer, reductionPointer, + tadShapeInfo, tadOffsets); } - namespace functions { - namespace transform { - - template - _CUDA_H void TransformSame::executeTransformShaped(dim3 launchDims, cudaStream_t *stream, - const int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), TRANSFORM_SAME_OPS); +namespace transform { + +template +_CUDA_H void TransformSame::executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_T(intermediateShaped, + PARAMS(launchDims, stream, x, xShape, xRank, extraParams, + z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + TRANSFORM_SAME_OPS); + + DEBUG_KERNEL(stream, opNum); +} - DEBUG_KERNEL(stream, opNum); +template +template +__device__ void TransformSame::transformCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vparams, void *vz, + const Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto x = static_cast(vx); + auto z = static_cast(vz); + auto params = static_cast(vparams); + auto reductionPointer = static_cast(vreductionPointer); + + if (OpType::requiresSpecial) { + OpType::execSpecialCuda(x, xShapeInfo, z, zShapeInfo, params, + allocationPointer, reductionPointer, tadShapeInfo, + tadOffsets); + return; + } else { + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong zEws; + __shared__ char xOrder; + __shared__ char zOrder; + __shared__ Nd4jLong length; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + zOrder = shape::order(zShapeInfo); + length = shape::length(xShapeInfo); + } + __syncthreads(); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int totalThreads = gridDim.x * blockDim.x; + + if (xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { + for (int i = tid; i < length; i += totalThreads) + z[i * zEws] = OpType::op(x[i * xEws], params); + } else { + if (vx == vz) { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + z[xOffset] = OpType::op(x[xOffset], params); } - - - template - template - __device__ void TransformSame::transformCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vparams, - void *vz, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *vreductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = static_cast(vx); - auto z = static_cast(vz); - auto params = static_cast(vparams); - auto reductionPointer = static_cast(vreductionPointer); - - if(OpType::requiresSpecial) { - OpType::execSpecialCuda(x,xShapeInfo,z,zShapeInfo,params, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - return; - } else { - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong zEws; - __shared__ char xOrder; - __shared__ char zOrder; - __shared__ Nd4jLong length; - - if (threadIdx.x == 0) { - - xEws = shape::elementWiseStride(xShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - zOrder = shape::order(zShapeInfo); - length = shape::length(xShapeInfo); - } - __syncthreads(); - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int totalThreads = gridDim.x * blockDim.x; - - if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { - - for (int i = tid; i < length; i += totalThreads) - z[i * zEws] = OpType::op(x[i * xEws], params); - } - else { - if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - z[xOffset] = OpType::op(x[xOffset], params); - } - } - else { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - z[zOffset] = OpType::op(x[xOffset], params); - } - } - } - } - }; - - - template - template - _CUDA_H void TransformSame::intermediateShaped(dim3 launchDims, cudaStream_t *stream, const void *x, const Nd4jLong *xShape, int xRank, void *extraParams, void *z, const Nd4jLong *zShape, int zRank, int *allocationPointer, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - transformSameSimple<<>>(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - sd::DebugHelper::checkErrorCode(stream, "transformSame(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformSame, , LIBND4J_TYPES); + } else { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + z[zOffset] = OpType::op(x[xOffset], params); + } + } } + } +}; + +template +template +_CUDA_H void TransformSame::intermediateShaped( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + transformSameSimple + <<>>( + x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets); + sd::DebugHelper::checkErrorCode(stream, "transformSame(...) failed"); } + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformSame, , LIBND4J_TYPES); +} // namespace transform +} // namespace functions diff --git a/libnd4j/include/loops/cuda/transform/transform_strict.cu b/libnd4j/include/loops/cuda/transform/transform_strict.cu index 00966562eeb0..f3f714ec3391 100644 --- a/libnd4j/include/loops/cuda/transform/transform_strict.cu +++ b/libnd4j/include/loops/cuda/transform/transform_strict.cu @@ -18,119 +18,117 @@ // @author raver119@gmail.com // -#include +#include +#include #include -#include +#include #include - -#include -#include +#include using namespace simdOps; template -__global__ void transformStrictSimple(const void *x, const Nd4jLong *xShapeInfo, int xRank, - void *params, - void *z, const Nd4jLong *zShapeInfo, int zRank, - int *allocationPointer, - void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - functions::transform::TransformStrict::template transformCuda(x,xShapeInfo,params,z,zShapeInfo,allocationPointer,reductionPointer,tadShapeInfo, tadOffsets); +__global__ void transformStrictSimple(const void *x, const Nd4jLong *xShapeInfo, + int xRank, void *params, void *z, + const Nd4jLong *zShapeInfo, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + functions::transform::TransformStrict::template transformCuda( + x, xShapeInfo, params, z, zShapeInfo, allocationPointer, reductionPointer, + tadShapeInfo, tadOffsets); } - namespace functions { - namespace transform { - - template - _CUDA_H void TransformStrict::executeTransformShaped(dim3 launchDims, cudaStream_t *stream, - const int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), TRANSFORM_STRICT_OPS); +namespace transform { + +template +_CUDA_H void TransformStrict::executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + DISPATCH_BY_OPNUM_T(intermediateShaped, + PARAMS(launchDims, stream, x, xShape, xRank, extraParams, + z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets), + TRANSFORM_STRICT_OPS); + + DEBUG_KERNEL(stream, opNum); +} - DEBUG_KERNEL(stream, opNum); +template +template +__device__ void TransformStrict::transformCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vparams, void *vz, + const Nd4jLong *zShapeInfo, int *allocationPointer, void *vreductionPointer, + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { + auto x = static_cast(vx); + auto z = static_cast(vz); + auto params = static_cast(vparams); + auto reductionPointer = static_cast(vreductionPointer); + + if (OpType::requiresSpecial) { + OpType::execSpecialCuda(x, xShapeInfo, z, zShapeInfo, params, + allocationPointer, reductionPointer, tadShapeInfo, + tadOffsets); + return; + } else { + __shared__ Nd4jLong xEws; + __shared__ Nd4jLong zEws; + __shared__ char xOrder; + __shared__ char zOrder; + __shared__ Nd4jLong length; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + zOrder = shape::order(zShapeInfo); + length = shape::length(xShapeInfo); + } + __syncthreads(); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int totalThreads = gridDim.x * blockDim.x; + + if (xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { + for (int i = tid; i < length; i += totalThreads) + z[i * zEws] = OpType::op(x[i * xEws], params); + } else { + if (vx == vz) { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + z[xOffset] = OpType::op(x[xOffset], params); } - - - template - template - __device__ void TransformStrict::transformCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vparams, - void *vz, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *vreductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - auto x = static_cast(vx); - auto z = static_cast(vz); - auto params = static_cast(vparams); - auto reductionPointer = static_cast(vreductionPointer); - - - if(OpType::requiresSpecial) { - OpType::execSpecialCuda(x,xShapeInfo,z,zShapeInfo,params, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - return; - } - else { - __shared__ Nd4jLong xEws; - __shared__ Nd4jLong zEws; - __shared__ char xOrder; - __shared__ char zOrder; - __shared__ Nd4jLong length; - - if (threadIdx.x == 0) { - - xEws = shape::elementWiseStride(xShapeInfo); - zEws = shape::elementWiseStride(zShapeInfo); - xOrder = shape::order(xShapeInfo); - zOrder = shape::order(zShapeInfo); - length = shape::length(xShapeInfo); - } - __syncthreads(); - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int totalThreads = gridDim.x * blockDim.x; - - if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { - - for (int i = tid; i < length; i += totalThreads) - z[i * zEws] = OpType::op(x[i * xEws], params); - } - else { - if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - z[xOffset] = OpType::op(x[xOffset], params); - } - } - else { - for (Nd4jLong i = tid; i < length; i+= totalThreads) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); - z[zOffset] = OpType::op(x[xOffset], params); - } - } - } - } - }; - - template - template - _CUDA_H void TransformStrict::intermediateShaped(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) { - - transformStrictSimple<<>>(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets); - sd::DebugHelper::checkErrorCode(stream, "transformStrict(...) failed"); - } - - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformStrict, , FLOAT_TYPES); + } else { + for (Nd4jLong i = tid; i < length; i += totalThreads) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); + z[zOffset] = OpType::op(x[xOffset], params); + } + } } + } +}; + +template +template +_CUDA_H void TransformStrict::intermediateShaped( + dim3 launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets) { + transformStrictSimple + <<>>( + x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, + reductionPointer, tadShapeInfo, tadOffsets); + sd::DebugHelper::checkErrorCode(stream, "transformStrict(...) failed"); } + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT TransformStrict, , FLOAT_TYPES); +} // namespace transform +} // namespace functions diff --git a/libnd4j/include/loops/cuda/type_conversions.cu b/libnd4j/include/loops/cuda/type_conversions.cu index 328c67080b99..406aaeaf20cb 100644 --- a/libnd4j/include/loops/cuda/type_conversions.cu +++ b/libnd4j/include/loops/cuda/type_conversions.cu @@ -18,522 +18,577 @@ // // +#include #include #include -#include namespace sd { - template - void TypeCast::convertGenericCuda(Nd4jPointer *extras, void *dx, Nd4jLong N, void *dz) { - auto stream = reinterpret_cast(&extras[1]); - - sd::convertKernel<<<256, 1024, 1024, *stream>>>(dx, N, dz); - sd::DebugHelper::checkErrorCode(stream, "convertGeneric(...) failed"); - }; - - - template - __device__ void convertKernelGeneric(S *x, Nd4jLong N, T *z) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < N; i+= blockDim.x * gridDim.x) { - // despite it's stupid, it simplifies conversion to bottom dtypes - // FIXME: get rid of through-float though - z[i] = static_cast(static_cast(x[i])); - } - }; - - -// Define this to more rigorously avoid bank conflicts, even at the lower (root) levels of the tree -//#define ZERO_BANK_CONFLICTS +template +void TypeCast::convertGenericCuda(Nd4jPointer *extras, void *dx, Nd4jLong N, + void *dz) { + auto stream = reinterpret_cast(&extras[1]); + + sd::convertKernel<<<256, 1024, 1024, *stream>>>(dx, N, dz); + sd::DebugHelper::checkErrorCode(stream, "convertGeneric(...) failed"); +}; + +template +__device__ void convertKernelGeneric(S *x, Nd4jLong N, T *z) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < N; i += blockDim.x * gridDim.x) { + // despite it's stupid, it simplifies conversion to bottom dtypes + // FIXME: get rid of through-float though + z[i] = static_cast(static_cast(x[i])); + } +}; + + // Define this to more rigorously avoid bank conflicts, even at the lower + // (root) levels of the tree + //#define ZERO_BANK_CONFLICTS #ifdef ZERO_BANK_CONFLICTS -#define CONFLICT_FREE_OFFSET(index) ((index) >> LOG_NUM_BANKS + (index) >> (2 * LOG_NUM_BANKS)) +#define CONFLICT_FREE_OFFSET(index) \ + ((index) >> LOG_NUM_BANKS + (index) >> (2 * LOG_NUM_BANKS)) #else #define CONFLICT_FREE_OFFSET(index) ((index) >> LOG_NUM_BANKS) #endif #ifdef CHECK_BANK_CONFLICTS -#define TEMP(index) CUT_BANK_CHECKER(temp, index) +#define TEMP(index) CUT_BANK_CHECKER(temp, index) #else -#define TEMP(index) temp[index] +#define TEMP(index) temp[index] #endif +template +__device__ void loadSharedChunkFromMem(int *s_data, const int *g_idata, int n, + int baseIndex, int &ai, int &bi, + int &mem_ai, int &mem_bi, + int &bankOffsetA, int &bankOffsetB) { + int thid = threadIdx.x; + mem_ai = baseIndex + threadIdx.x; + mem_bi = mem_ai + blockDim.x; + + ai = thid; + bi = thid + blockDim.x; + + // compute spacing to avoid bank conflicts + bankOffsetA = CONFLICT_FREE_OFFSET(ai); + bankOffsetB = CONFLICT_FREE_OFFSET(bi); + + // Cache the computational window in shared memory + // pad values beyond n with zeros + s_data[ai + bankOffsetA] = g_idata[mem_ai]; + + if (isNP2) { // compile-time decision + s_data[bi + bankOffsetB] = (bi < n) ? g_idata[mem_bi] : 0; + } else { + s_data[bi + bankOffsetB] = g_idata[mem_bi]; + } +} +template +__device__ void storeSharedChunkToMem(int *g_odata, int *s_data, int n, int ai, + int bi, int mem_ai, int mem_bi, + int bankOffsetA, int bankOffsetB) { + __syncthreads(); + + // write results to global memory + g_odata[mem_ai] = s_data[ai + bankOffsetA]; + if (isNP2) { // compile-time decision + if (bi < n) g_odata[mem_bi] = s_data[bi + bankOffsetB]; + } else { + g_odata[mem_bi] = s_data[bi + bankOffsetB]; + } +} - - - template - __device__ void loadSharedChunkFromMem(int *s_data, const int *g_idata, int n, int baseIndex, int& ai, int& bi, int& mem_ai, int& mem_bi, int& bankOffsetA, int& bankOffsetB) { - int thid = threadIdx.x; - mem_ai = baseIndex + threadIdx.x; - mem_bi = mem_ai + blockDim.x; - - ai = thid; - bi = thid + blockDim.x; - - // compute spacing to avoid bank conflicts - bankOffsetA = CONFLICT_FREE_OFFSET(ai); - bankOffsetB = CONFLICT_FREE_OFFSET(bi); - - // Cache the computational window in shared memory - // pad values beyond n with zeros - s_data[ai + bankOffsetA] = g_idata[mem_ai]; - - if (isNP2) { // compile-time decision - s_data[bi + bankOffsetB] = (bi < n) ? g_idata[mem_bi] : 0; - } else { - s_data[bi + bankOffsetB] = g_idata[mem_bi]; - } - } - - template - __device__ void storeSharedChunkToMem(int* g_odata, int* s_data, int n, int ai, int bi, int mem_ai, int mem_bi, int bankOffsetA, int bankOffsetB) { - __syncthreads(); - - // write results to global memory - g_odata[mem_ai] = s_data[ai + bankOffsetA]; - if (isNP2) { // compile-time decision - if (bi < n) - g_odata[mem_bi] = s_data[bi + bankOffsetB]; - } else { - g_odata[mem_bi] = s_data[bi + bankOffsetB]; - } +template +__device__ void clearLastElement(int *s_data, int *g_blockSums, + int blockIndex) { + if (threadIdx.x == 0) { + int index = (blockDim.x << 1) - 1; + index += CONFLICT_FREE_OFFSET(index); + + if (storeSum) { // compile-time decision + // write this block's total sum to the corresponding index in the + // blockSums array + g_blockSums[blockIndex] = s_data[index]; } - template - __device__ void clearLastElement(int* s_data, int *g_blockSums, int blockIndex) { - if (threadIdx.x == 0) - { - int index = (blockDim.x << 1) - 1; - index += CONFLICT_FREE_OFFSET(index); - - if (storeSum) { // compile-time decision - // write this block's total sum to the corresponding index in the blockSums array - g_blockSums[blockIndex] = s_data[index]; - } - - // zero the last element in the scan so it will propagate back to the front - s_data[index] = 0; - } - } - - - - __device__ unsigned int buildSum(int *s_data) { - unsigned int thid = threadIdx.x; - unsigned int stride = 1; - - // build the sum in place up the tree - for (int d = blockDim.x; d > 0; d >>= 1) { - __syncthreads(); + // zero the last element in the scan so it will propagate back to the front + s_data[index] = 0; + } +} - if (thid < d) { - int i = __mul24(__mul24(2, stride), thid); - int ai = i + stride - 1; - int bi = ai + stride; +__device__ unsigned int buildSum(int *s_data) { + unsigned int thid = threadIdx.x; + unsigned int stride = 1; - ai += CONFLICT_FREE_OFFSET(ai); - bi += CONFLICT_FREE_OFFSET(bi); + // build the sum in place up the tree + for (int d = blockDim.x; d > 0; d >>= 1) { + __syncthreads(); - s_data[bi] += s_data[ai]; - } + if (thid < d) { + int i = __mul24(__mul24(2, stride), thid); + int ai = i + stride - 1; + int bi = ai + stride; - stride *= 2; - } + ai += CONFLICT_FREE_OFFSET(ai); + bi += CONFLICT_FREE_OFFSET(bi); - return stride; + s_data[bi] += s_data[ai]; } - __device__ void scanRootToLeaves(int *s_data, unsigned int stride) { - unsigned int thid = threadIdx.x; + stride *= 2; + } - // traverse down the tree building the scan in place - for (int d = 1; d <= blockDim.x; d *= 2) { - stride >>= 1; + return stride; +} - __syncthreads(); +__device__ void scanRootToLeaves(int *s_data, unsigned int stride) { + unsigned int thid = threadIdx.x; - if (thid < d) { - int i = __mul24(__mul24(2, stride), thid); - int ai = i + stride - 1; - int bi = ai + stride; + // traverse down the tree building the scan in place + for (int d = 1; d <= blockDim.x; d *= 2) { + stride >>= 1; - ai += CONFLICT_FREE_OFFSET(ai); - bi += CONFLICT_FREE_OFFSET(bi); + __syncthreads(); - float t = s_data[ai]; - s_data[ai] = s_data[bi]; - s_data[bi] += t; - } - } - } + if (thid < d) { + int i = __mul24(__mul24(2, stride), thid); + int ai = i + stride - 1; + int bi = ai + stride; - template - __device__ void prescanBlock(int *data, int blockIndex, int *blockSums) { - int stride = buildSum(data); // build the sum in place up the tree - clearLastElement(data, blockSums, - (blockIndex == 0) ? blockIdx.x : blockIndex); - scanRootToLeaves(data, stride); // traverse down tree to build the scan - } + ai += CONFLICT_FREE_OFFSET(ai); + bi += CONFLICT_FREE_OFFSET(bi); + float t = s_data[ai]; + s_data[ai] = s_data[bi]; + s_data[bi] += t; + } + } +} - template - __global__ void prescan(int *g_odata, const int *g_idata, int *g_blockSums, int n, int blockIndex, int baseIndex) { - int ai, bi, mem_ai, mem_bi, bankOffsetA, bankOffsetB; - extern __shared__ int s_data[]; +template +__device__ void prescanBlock(int *data, int blockIndex, int *blockSums) { + int stride = buildSum(data); // build the sum in place up the tree + clearLastElement(data, blockSums, + (blockIndex == 0) ? blockIdx.x : blockIndex); + scanRootToLeaves(data, stride); // traverse down tree to build the scan +} - // load data into shared memory - loadSharedChunkFromMem(reinterpret_cast(s_data), g_idata, n, (baseIndex == 0) ? __mul24(blockIdx.x, (blockDim.x << 1)):baseIndex, ai, bi, mem_ai, mem_bi, bankOffsetA, bankOffsetB); +template +__global__ void prescan(int *g_odata, const int *g_idata, int *g_blockSums, + int n, int blockIndex, int baseIndex) { + int ai, bi, mem_ai, mem_bi, bankOffsetA, bankOffsetB; + extern __shared__ int s_data[]; - // scan the data in each block - prescanBlock(s_data, blockIndex, g_blockSums); + // load data into shared memory + loadSharedChunkFromMem( + reinterpret_cast(s_data), g_idata, n, + (baseIndex == 0) ? __mul24(blockIdx.x, (blockDim.x << 1)) : baseIndex, ai, + bi, mem_ai, mem_bi, bankOffsetA, bankOffsetB); - // write results to device memory - storeSharedChunkToMem(g_odata, s_data, n, ai, bi, mem_ai, mem_bi, bankOffsetA, bankOffsetB); - } + // scan the data in each block + prescanBlock(s_data, blockIndex, g_blockSums); + // write results to device memory + storeSharedChunkToMem(g_odata, s_data, n, ai, bi, mem_ai, mem_bi, + bankOffsetA, bankOffsetB); +} - __global__ void uniformAdd(int *g_data, int *uniforms, int n, int blockOffset, int baseIndex) { - __shared__ float uni; - if (threadIdx.x == 0) - uni = uniforms[blockIdx.x + blockOffset]; +__global__ void uniformAdd(int *g_data, int *uniforms, int n, int blockOffset, + int baseIndex) { + __shared__ float uni; + if (threadIdx.x == 0) uni = uniforms[blockIdx.x + blockOffset]; - unsigned int address = __mul24(blockIdx.x, (blockDim.x << 1)) + baseIndex + threadIdx.x; + unsigned int address = + __mul24(blockIdx.x, (blockDim.x << 1)) + baseIndex + threadIdx.x; - __syncthreads(); + __syncthreads(); - // note two adds per thread - g_data[address] += uni; - g_data[address + blockDim.x] += (threadIdx.x + blockDim.x < n) * uni; - } + // note two adds per thread + g_data[address] += uni; + g_data[address + blockDim.x] += (threadIdx.x + blockDim.x < n) * uni; +} /* * This kernel does prefix sum in parallel, to calculate offsets for each block */ - template - __device__ inline void encoderKernelP2Generic(void *dx, Nd4jLong n, void *dz) { - // TODO: to be remove - } +template +__device__ inline void encoderKernelP2Generic(void *dx, Nd4jLong n, void *dz) { + // TODO: to be remove +} ////////////////////////////////////////////////////////////////////////// /* - * PLEASE NOTE: This kernel doesn't allow loop for data. Basically: grid will be huge. + * PLEASE NOTE: This kernel doesn't allow loop for data. Basically: grid will be + * huge. */ -template -__global__ static void execEncoderKernelP1(const void *dx, Nd4jLong N, void *dz, float threshold) { - auto x = reinterpret_cast (dx); - auto z = reinterpret_cast (dz); - - //basically, for phase One we want do calculation: how many eligible values we have, and which blocks will be holding data - Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; - - int pass = tid < N && sd::math::nd4j_abs(x[tid]) >= static_cast(threshold) ? 1 : 0; - int bp=__syncthreads_count(pass); - - if (threadIdx.x == 0) { - // saving out per-block passes - z[blockIdx.x+1] = bp; - - // saving out sum - atomicAdd(&z[0], bp); - } +template +__global__ static void execEncoderKernelP1(const void *dx, Nd4jLong N, void *dz, + float threshold) { + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); + + // basically, for phase One we want do calculation: how many eligible values + // we have, and which blocks will be holding data + Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x; + + int pass = + tid < N && sd::math::nd4j_abs(x[tid]) >= static_cast(threshold) ? 1 + : 0; + int bp = __syncthreads_count(pass); + + if (threadIdx.x == 0) { + // saving out per-block passes + z[blockIdx.x + 1] = bp; + + // saving out sum + atomicAdd(&z[0], bp); + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ void encoderKernelP1Generic(dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *dz, float threshold) { - - execEncoderKernelP1<<>>(dx, N, dz, threshold); - sd::DebugHelper::checkErrorCode(stream, "encoderP1(...) failed"); +template +__host__ void encoderKernelP1Generic(dim3 &launchDims, cudaStream_t *stream, + const void *dx, Nd4jLong N, void *dz, + float threshold) { + execEncoderKernelP1<<>>( + dx, N, dz, threshold); + sd::DebugHelper::checkErrorCode(stream, "encoderP1(...) failed"); } -BUILD_SINGLE_TEMPLATE(template void SD_EXPORT encoderKernelP1Generic, (dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *dz, float threshold), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT encoderKernelP1Generic, + (dim3 & launchDims, cudaStream_t *stream, const void *dx, + Nd4jLong N, void *dz, float threshold), + FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// /* - * PLEASE NOTE: This kernel doesn't allow loop for data. Basically: grid will be huge. + * PLEASE NOTE: This kernel doesn't allow loop for data. Basically: grid will be + * huge. * - * Based on: https://github.com/knotman90/cuStreamComp <-- efficient CUDA stream compaction algorithm + * Based on: https://github.com/knotman90/cuStreamComp <-- efficient CUDA stream + * compaction algorithm */ -template -__global__ static void execEncoderKernelP3(void *dx, int *offsets, Nd4jLong N, void *dz) { - auto x = reinterpret_cast (dx); - auto z = reinterpret_cast (dz); - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - extern __shared__ int warpTotals[]; - - // fetch block offset only once - __shared__ float threshold; - __shared__ FloatBits fb; - __shared__ int bo; - __shared__ int limit; - if (threadIdx.x == 0) { - limit = z[0]; - fb.i_ = z[2]; - threshold = fb.f_; - bo = offsets[blockIdx.x]; - } - __syncthreads(); - - // out-of-limit threads do not play here - auto value = tid < N ? x[tid] : (T) 0.f; - - // out-of-limit threads just declare they have no changes - auto pred = tid >= N ? 0 : sd::math::nd4j_abs(value) >= static_cast(threshold) ? 1 : 0; - auto w_i = threadIdx.x / warpSize; // warp index (or, warp number) - index of the Warp within TOTAL_WARPS - auto t_i = threadIdx.x % warpSize; // thread index within a warp - unsigned int t_m = INT_MAX >> (warpSize - t_i - 1); //thread mask (ERROR IN THE PAPER minus one is required) - - int b = __ballot_sync(t_m, pred); // balres = number whose ith bit isone if the ith's thread pred is true masked up to the current index in warp - auto t_u = __popc(b); // popc count the number of bit one. simply count the number predicated true BEFORE MY INDEX - - if (t_i == warpSize - 1) - warpTotals[w_i] = t_u + pred; - - __syncthreads(); - - - int w_i_u = 0; - for (int j = 0; j <= 5; j++) { - unsigned int b_j = __ballot_sync(t_m, warpTotals[t_i] & pow2i(j)); //# of the ones in the j'th digit of the warp offsets - w_i_u += (__popc(b_j) << j); - } - - // we just ignore all results coming from non-0 threads - if (w_i == 0 && t_i < blockDim.x / warpSize) - warpTotals[t_i] = w_i_u; - - __syncthreads(); - - - // pred is always false if we're out-of-limits - if (pred) { - int idx = t_u + warpTotals[w_i] + bo + 4; - if (idx < limit + 4) { - z[idx] = value > static_cast(0.0f) ? tid + 1 : -(tid + 1); - x[tid] = value > static_cast(0.0f) ? x[tid] - threshold : x[tid] + threshold; - } - } +template +__global__ static void execEncoderKernelP3(void *dx, int *offsets, Nd4jLong N, + void *dz) { + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + extern __shared__ int warpTotals[]; + + // fetch block offset only once + __shared__ float threshold; + __shared__ FloatBits fb; + __shared__ int bo; + __shared__ int limit; + if (threadIdx.x == 0) { + limit = z[0]; + fb.i_ = z[2]; + threshold = fb.f_; + bo = offsets[blockIdx.x]; + } + __syncthreads(); + + // out-of-limit threads do not play here + auto value = tid < N ? x[tid] : (T)0.f; + + // out-of-limit threads just declare they have no changes + auto pred = + tid >= N + ? 0 + : sd::math::nd4j_abs(value) >= static_cast(threshold) ? 1 : 0; + auto w_i = threadIdx.x / warpSize; // warp index (or, warp number) - index of + // the Warp within TOTAL_WARPS + auto t_i = threadIdx.x % warpSize; // thread index within a warp + unsigned int t_m = + INT_MAX >> (warpSize - t_i - + 1); // thread mask (ERROR IN THE PAPER minus one is required) + + int b = __ballot_sync( + t_m, pred); // balres = number whose ith bit isone if the ith's thread + // pred is true masked up to the current index in warp + auto t_u = __popc(b); // popc count the number of bit one. simply count the + // number predicated true BEFORE MY INDEX + + if (t_i == warpSize - 1) warpTotals[w_i] = t_u + pred; + + __syncthreads(); + + int w_i_u = 0; + for (int j = 0; j <= 5; j++) { + unsigned int b_j = __ballot_sync( + t_m, + warpTotals[t_i] & + pow2i(j)); //# of the ones in the j'th digit of the warp offsets + w_i_u += (__popc(b_j) << j); + } + + // we just ignore all results coming from non-0 threads + if (w_i == 0 && t_i < blockDim.x / warpSize) warpTotals[t_i] = w_i_u; + + __syncthreads(); + + // pred is always false if we're out-of-limits + if (pred) { + int idx = t_u + warpTotals[w_i] + bo + 4; + if (idx < limit + 4) { + z[idx] = value > static_cast(0.0f) ? tid + 1 : -(tid + 1); + x[tid] = value > static_cast(0.0f) ? x[tid] - threshold + : x[tid] + threshold; + } + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ void encoderKernelP3Generic(dim3 &launchDims, cudaStream_t *stream, void *dx, int *offsets, Nd4jLong N, void *dz) { - execEncoderKernelP3<<>>(dx, offsets, N, dz); - sd::DebugHelper::checkErrorCode(stream, "encoderP3(...) failed"); +template +__host__ void encoderKernelP3Generic(dim3 &launchDims, cudaStream_t *stream, + void *dx, int *offsets, Nd4jLong N, + void *dz) { + execEncoderKernelP3<<>>( + dx, offsets, N, dz); + sd::DebugHelper::checkErrorCode(stream, "encoderP3(...) failed"); } -BUILD_SINGLE_TEMPLATE(template void SD_EXPORT encoderKernelP3Generic, (dim3 &launchDims, cudaStream_t *stream, void *dx, int *offsets, Nd4jLong N, void *dz), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT encoderKernelP3Generic, + (dim3 & launchDims, cudaStream_t *stream, void *dx, + int *offsets, Nd4jLong N, void *dz), + FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// /* * This kernel handles decode from sparse threshold array, to dense array * * PLEASE NOTE: Z is expected to be memset to 0 -*/ -template + */ +template __global__ static void execDecoderKernel(const void *dx, Nd4jLong N, void *dz) { - auto x = reinterpret_cast (dx); - auto z = reinterpret_cast (dz); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - __shared__ float threshold; - __shared__ int limit; - - __shared__ FloatBits fb; - if (threadIdx.x == 0) { - limit = x[0]; - fb.i_ = x[2]; - threshold = fb.f_; - } - __syncthreads(); - - for (int e = tid; e < limit; e += blockDim.x * gridDim.x) { - int el = x[e+4]; - int ael = sd::math::nd4j_abs(el) - 1; + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + __shared__ float threshold; + __shared__ int limit; + + __shared__ FloatBits fb; + if (threadIdx.x == 0) { + limit = x[0]; + fb.i_ = x[2]; + threshold = fb.f_; + } + __syncthreads(); + + for (int e = tid; e < limit; e += blockDim.x * gridDim.x) { + int el = x[e + 4]; + int ael = sd::math::nd4j_abs(el) - 1; + + // TODO: investigate, if += would work better here, as in "decoded + // accumulation" + z[ael] += el > 0 ? threshold : -threshold; + } +} - // TODO: investigate, if += would work better here, as in "decoded accumulation" - z[ael] += el > 0 ? threshold : -threshold; - } +////////////////////////////////////////////////////////////////////////// +template +__host__ void decoderKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + const void *dx, Nd4jLong N, void *dz) { + execDecoderKernel + <<>>(dx, N, dz); + sd::DebugHelper::checkErrorCode(stream, "execDecoder(...) failed"); } +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT decoderKernelGeneric, + (dim3 & launchDims, cudaStream_t *stream, const void *dx, + Nd4jLong N, void *dz), + FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// -template -__host__ void decoderKernelGeneric(dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *dz) { +template +__global__ static void execCudaEncodeBitmapKernel(void *vdx, Nd4jLong N, + int *dz, int *scalar, + int *reductionBuffer, + float threshold) { + auto dx = reinterpret_cast(vdx); + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + T off(0.0f); + __shared__ int counter; + __shared__ int *shmem; + __shared__ T *vals; + if (threadIdx.x == 0) { + extern __shared__ char mem[]; + shmem = reinterpret_cast(mem); + vals = reinterpret_cast(shmem + blockDim.x); + counter = 0; + } + __syncthreads(); + + Nd4jLong loopRemainder = N % (blockDim.x * gridDim.x); + Nd4jLong loopLimit = N + (blockDim.x * gridDim.x - loopRemainder); + + for (Nd4jLong i = tid; i < loopLimit; i += blockDim.x * gridDim.x) { + // all threads in block reading stuff + T val = i < N ? dx[i] : off; + T abs = sd::math::nd4j_abs(val); + + int byteId = i / 16 + 4; + int bitId = i % 16; + + shmem[threadIdx.x] = 0; + vals[threadIdx.x] = val; + + if (abs >= static_cast(threshold) && i < N) { + shmem[threadIdx.x] = 1 << (bitId); + atomicAdd(&counter, 1); + if (val < static_cast(0.0f)) { + shmem[threadIdx.x] |= 1 << (bitId + 16); + vals[threadIdx.x] += static_cast(threshold); + } else { + vals[threadIdx.x] -= static_cast(threshold); + } + } else if (abs >= static_cast(threshold) / static_cast(2.0f) && + val < static_cast(0.0f) && i < N) { + atomicAdd(&counter, 1); + shmem[threadIdx.x] = 1 << (bitId + 16); + + vals[threadIdx.x] += static_cast(threshold) / static_cast(2.0f); + } + __syncthreads(); - execDecoderKernel<<>>(dx, N, dz); - sd::DebugHelper::checkErrorCode(stream, "execDecoder(...) failed"); -} -BUILD_SINGLE_TEMPLATE(template void SD_EXPORT decoderKernelGeneric, (dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *dz), FLOAT_TYPES); + if (threadIdx.x % 16 == 0 && i < N) { + int byte = 0; + for (int e = 0; e < 16; e++) { + if (i + e >= N) continue; + byte |= shmem[threadIdx.x + e]; + } + dz[byteId] = byte; + } + __syncthreads(); -////////////////////////////////////////////////////////////////////////// -template -__global__ static void execCudaEncodeBitmapKernel(void *vdx, Nd4jLong N, int *dz, int *scalar, int *reductionBuffer, float threshold) { - auto dx = reinterpret_cast(vdx); - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - T off(0.0f); - __shared__ int counter; - __shared__ int *shmem; - __shared__ T *vals; - if (threadIdx.x == 0){ - extern __shared__ char mem[]; - shmem = reinterpret_cast(mem); - vals = reinterpret_cast(shmem + blockDim.x); - counter = 0; - } - __syncthreads(); - - Nd4jLong loopRemainder = N % (blockDim.x * gridDim.x); - Nd4jLong loopLimit = N + (blockDim.x * gridDim.x - loopRemainder); - - for (Nd4jLong i = tid; i < loopLimit; i += blockDim.x * gridDim.x) { - // all threads in block reading stuff - T val = i < N ? dx[i] : off; - T abs = sd::math::nd4j_abs(val); - - int byteId = i / 16 + 4; - int bitId = i % 16; - - shmem[threadIdx.x] = 0; - vals[threadIdx.x] = val; - - if (abs >= static_cast(threshold) && i < N) { - shmem[threadIdx.x] = 1 << (bitId); - atomicAdd(&counter, 1); - if (val < static_cast(0.0f)) { - shmem[threadIdx.x] |= 1 << (bitId + 16); - vals[threadIdx.x] += static_cast(threshold); - } else { - vals[threadIdx.x] -= static_cast(threshold); - } - } else if (abs >= static_cast(threshold) / static_cast(2.0f) && val < static_cast(0.0f) && i < N) { - atomicAdd(&counter, 1); - shmem[threadIdx.x] = 1 << (bitId + 16); - - vals[threadIdx.x] += static_cast(threshold) / static_cast(2.0f); - } - __syncthreads(); - - if (threadIdx.x % 16 == 0 && i < N) { - int byte = 0; - for (int e = 0; e < 16; e++) { - if (i + e >= N) - continue; - - byte |= shmem[threadIdx.x + e]; - } - dz[byteId] = byte; - } - __syncthreads(); - - if (i < N) - dx[i] = vals[threadIdx.x]; - } - __syncthreads(); + if (i < N) dx[i] = vals[threadIdx.x]; + } + __syncthreads(); - if (threadIdx.x == 0) { - atomicAdd(scalar, counter); - } + if (threadIdx.x == 0) { + atomicAdd(scalar, counter); + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ void cudaEncodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, void *vdx, Nd4jLong N, int *dz, int *scalar, int *reductionBuffer, float threshold) { - - execCudaEncodeBitmapKernel<<>>(vdx, N, dz, scalar, reductionBuffer, threshold); - sd::DebugHelper::checkErrorCode(stream, "encodeBitmap(...) failed"); +template +__host__ void cudaEncodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, + void *vdx, Nd4jLong N, int *dz, + int *scalar, int *reductionBuffer, + float threshold) { + execCudaEncodeBitmapKernel + <<>>( + vdx, N, dz, scalar, reductionBuffer, threshold); + sd::DebugHelper::checkErrorCode(stream, "encodeBitmap(...) failed"); } -BUILD_SINGLE_TEMPLATE(template void SD_EXPORT cudaEncodeBitmapGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vdx, Nd4jLong N, int *dz, int *scalar, int *reductionBuffer, float threshold), FLOAT_TYPES); - +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT cudaEncodeBitmapGeneric, + (dim3 & launchDims, cudaStream_t *stream, void *vdx, + Nd4jLong N, int *dz, int *scalar, int *reductionBuffer, + float threshold), + FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// -template -__global__ static void execCudaDecodeBitmapKernel(const void *dx, Nd4jLong N, void *vdz) { - auto dz = static_cast(vdz); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - __shared__ T *shmem; - __shared__ FloatBits fb; - __shared__ float threshold; - __shared__ const int *x; - if (threadIdx.x == 0){ - extern __shared__ char mem[]; - shmem = reinterpret_cast(mem); - x = reinterpret_cast(dx); - fb.i_ = x[2]; - threshold = fb.f_; +template +__global__ static void execCudaDecodeBitmapKernel(const void *dx, Nd4jLong N, + void *vdz) { + auto dz = static_cast(vdz); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + __shared__ T *shmem; + __shared__ FloatBits fb; + __shared__ float threshold; + __shared__ const int *x; + if (threadIdx.x == 0) { + extern __shared__ char mem[]; + shmem = reinterpret_cast(mem); + x = reinterpret_cast(dx); + fb.i_ = x[2]; + threshold = fb.f_; + } + __syncthreads(); + + int lim = N / 16 + 5; + for (int i = tid; i < N; i += blockDim.x * gridDim.x) { + int byteId = i / 16 + 4; + // printf("I: [%i]; byteId: [%i]\n", i, byteId); + + shmem[threadIdx.x] = dz[i]; + __syncthreads(); + + if (threadIdx.x % 16 == 0) { + int byte = x[byteId]; + + for (int e = 0; e < 16; e++) { + if (i + e >= N) continue; + + int bitId = (i + e) % 16; + + bool hasBit = (byte & 1 << (bitId)) != 0; + bool hasSign = (byte & 1 << (bitId + 16)) != 0; + + if (hasBit) { + if (hasSign) + shmem[threadIdx.x + bitId] -= threshold; + else + shmem[threadIdx.x + bitId] += threshold; + } else if (hasSign) { + shmem[threadIdx.x + bitId] -= threshold / 2; } - __syncthreads(); - - int lim = N / 16 + 5; - for (int i = tid; i < N; i += blockDim.x * gridDim.x) { - int byteId = i / 16 + 4; -// printf("I: [%i]; byteId: [%i]\n", i, byteId); - - shmem[threadIdx.x] = dz[i]; - __syncthreads(); - - if (threadIdx.x % 16 == 0) { - int byte = x[byteId]; - - for (int e = 0; e < 16; e++) { - if (i + e >= N) - continue; - - int bitId = (i + e) % 16; - - bool hasBit = (byte & 1 << (bitId) ) != 0; - bool hasSign = (byte & 1 << (bitId + 16) ) != 0; - - if (hasBit) { - if (hasSign) - shmem[threadIdx.x + bitId] -= threshold; - else - shmem[threadIdx.x + bitId] += threshold; - } else if (hasSign) { - shmem[threadIdx.x + bitId] -= threshold / 2; - } - } - } - __syncthreads(); + } + } + __syncthreads(); - dz[i] = shmem[threadIdx.x]; - } + dz[i] = shmem[threadIdx.x]; + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ void cudaDecodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *vdz) { - - execCudaDecodeBitmapKernel<<>>(dx, N, vdz); - sd::DebugHelper::checkErrorCode(stream, "cudeDecodeBitmap(...) failed"); +template +__host__ void cudaDecodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, + const void *dx, Nd4jLong N, void *vdz) { + execCudaDecodeBitmapKernel + <<>>(dx, N, vdz); + sd::DebugHelper::checkErrorCode(stream, "cudeDecodeBitmap(...) failed"); +} +BUILD_SINGLE_TEMPLATE(template void SD_EXPORT cudaDecodeBitmapGeneric, + (dim3 & launchDims, cudaStream_t *stream, const void *dx, + Nd4jLong N, void *vdz), + FLOAT_TYPES); + +template +__host__ void prescanLauncher(dim3 &blocks, dim3 &threads, int shmem, + cudaStream_t *stream, int *g_odata, + const int *g_idata, int *g_blockSums, int n, + int blockIndex, int baseIndex) { + prescan<<>>( + g_odata, g_idata, g_blockSums, n, blockIndex, baseIndex); + sd::DebugHelper::checkErrorCode(stream, "prescan(...) failed"); +}; + +template +__global__ void convertKernel(void *dx, Nd4jLong N, void *dz) { + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); + + sd::convertKernelGeneric(x, N, z); } -BUILD_SINGLE_TEMPLATE(template void SD_EXPORT cudaDecodeBitmapGeneric, (dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *vdz), FLOAT_TYPES); - - - template - __host__ void prescanLauncher(dim3 &blocks, dim3 &threads, int shmem, cudaStream_t *stream, int *g_odata, const int *g_idata, int *g_blockSums, int n, int blockIndex, int baseIndex) { - prescan<<>>(g_odata, g_idata, g_blockSums, n, blockIndex, baseIndex); - sd::DebugHelper::checkErrorCode(stream, "prescan(...) failed"); - }; - - template - __global__ void convertKernel(void *dx, Nd4jLong N, void *dz) { - auto x = reinterpret_cast(dx); - auto z = reinterpret_cast(dz); - - sd::convertKernelGeneric(x, N, z); - } - -#define LIBND4J_BOOLS_LOCAL \ - (randomName0, 0), \ - (randomName1, 1) +#define LIBND4J_BOOLS_LOCAL (randomName0, 0), (randomName1, 1) - BUILD_DOUBLE_TEMPLATE(template void TypeCast::convertGenericCuda, (Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz), LIBND4J_TYPES_EXTENDED, LIBND4J_TYPES_EXTENDED); - BUILD_DOUBLE_TEMPLATE(template void prescanLauncher, (dim3 &blocks, dim3 &threads, int shmem, cudaStream_t *stream, int *g_odata, const int *g_idata, int *g_blockSums, int n, int blockIndex, int baseIndex), LIBND4J_BOOLS_LOCAL, LIBND4J_BOOLS_LOCAL); +BUILD_DOUBLE_TEMPLATE(template void TypeCast::convertGenericCuda, + (Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz), + LIBND4J_TYPES_EXTENDED, LIBND4J_TYPES_EXTENDED); +BUILD_DOUBLE_TEMPLATE(template void prescanLauncher, + (dim3 & blocks, dim3 &threads, int shmem, + cudaStream_t *stream, int *g_odata, const int *g_idata, + int *g_blockSums, int n, int blockIndex, int baseIndex), + LIBND4J_BOOLS_LOCAL, LIBND4J_BOOLS_LOCAL); #undef LIBND4J_BOOLS_LOCAL -} \ No newline at end of file +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/loops/impl/type_conversions.cpp b/libnd4j/include/loops/impl/type_conversions.cpp index a2f302d25194..b503e3045dfc 100644 --- a/libnd4j/include/loops/impl/type_conversions.cpp +++ b/libnd4j/include/loops/impl/type_conversions.cpp @@ -18,222 +18,251 @@ // Created by raver on 6/12/2018. // -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - template - _CUDA_H void TypeCast::convertFromQuantized(Nd4jPointer *extras, void *dx, Nd4jLong N, void *dz) { - // - auto z = reinterpret_cast(dz); - - auto fx = reinterpret_cast(dx); - auto amin = sd::math::nd4j_abs(fx[0]); - auto amax = sd::math::nd4j_abs(fx[1]); - +template +_CUDA_H void TypeCast::convertFromQuantized(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz) { + // + auto z = reinterpret_cast(dz); - auto x = reinterpret_cast(dx) + 8; + auto fx = reinterpret_cast(dx); + auto amin = sd::math::nd4j_abs(fx[0]); + auto amax = sd::math::nd4j_abs(fx[1]); + auto x = reinterpret_cast(dx) + 8; - for (Nd4jLong e = 0; e < N; e++) { - z[e] = static_cast(static_cast(x[e]) / static_cast(DataTypeUtils::max()) * sd::math::nd4j_max(amin, amax)); - } - } - - template - _CUDA_H void TypeCast::convertToQuantized(Nd4jPointer *extras, void *dx, Nd4jLong N, void *dz) { - // find min/max first + for (Nd4jLong e = 0; e < N; e++) { + z[e] = static_cast(static_cast(x[e]) / + static_cast(DataTypeUtils::max()) * + sd::math::nd4j_max(amin, amax)); + } +} - auto x = reinterpret_cast(dx); - auto z = reinterpret_cast(dz); +template +_CUDA_H void TypeCast::convertToQuantized(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz) { + // find min/max first - T mn = DataTypeUtils::max(); - T mx = -DataTypeUtils::max(); + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); - for (Nd4jLong e = 0; e < N; e++) { - T v = x[e]; - if (v < mn) - mn = v; + T mn = DataTypeUtils::max(); + T mx = -DataTypeUtils::max(); - if (v > mx) - mx = v; - } + for (Nd4jLong e = 0; e < N; e++) { + T v = x[e]; + if (v < mn) mn = v; - // we shift by 2 fp32 elements - auto rz = z + 8; + if (v > mx) mx = v; + } - // - auto fz = reinterpret_cast(z); + // we shift by 2 fp32 elements + auto rz = z + 8; - float max = static_cast(mx); - float min = static_cast(mn); + // + auto fz = reinterpret_cast(z); - int max_byte = static_cast(DataTypeUtils::max()); - fz[0] = min; - fz[1] = max; + float max = static_cast(mx); + float min = static_cast(mn); - auto amax = sd::math::nd4j_abs(max); - auto amin = sd::math::nd4j_abs(min); + int max_byte = static_cast(DataTypeUtils::max()); + fz[0] = min; + fz[1] = max; - // now we actually apply quantization - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - rz[e] = static_cast(sd::math::nd4j_round( 1.0f * static_cast(x[e]) / sd::math::nd4j_max(amax, amin) * max_byte)); - } - }; + auto amax = sd::math::nd4j_abs(max); + auto amin = sd::math::nd4j_abs(min); - samediff::Threads::parallel_for(func, 0, N); + // now we actually apply quantization + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + rz[e] = static_cast(sd::math::nd4j_round( + 1.0f * static_cast(x[e]) / + sd::math::nd4j_max(amax, amin) * max_byte)); } + }; + + samediff::Threads::parallel_for(func, 0, N); +} - template - void TypeCast::convertToThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz) { - // we suppose that first 4 bytes are integer, second 4 bytes are float - // integer: enc length - // integer: dec length - // float: threshold - FloatBits fb; - auto x = reinterpret_cast(dx); - auto z = reinterpret_cast(dz); - int limit = z[0]; - fb.i_ = z[2]; - float threshold = fb.f_; - - // TODO: int limit is sad thing here, 2B elements limitation - auto l = static_cast(N); - z[1] = l; +template +void TypeCast::convertToThreshold(Nd4jPointer *extras, void *dx, Nd4jLong N, + void *dz) { + // we suppose that first 4 bytes are integer, second 4 bytes are float + // integer: enc length + // integer: dec length + // float: threshold + FloatBits fb; + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); + int limit = z[0]; + fb.i_ = z[2]; + float threshold = fb.f_; + + // TODO: int limit is sad thing here, 2B elements limitation + auto l = static_cast(N); + z[1] = l; #ifdef _OPENMP - int threads = OmpLaunchHelper::betterThreads(N); - auto span = OmpLaunchHelper::betterSpan(N, threads); + int threads = OmpLaunchHelper::betterThreads(N); + auto span = OmpLaunchHelper::betterSpan(N, threads); #else - int threads = 1; - auto span = N; + int threads = 1; + auto span = N; #endif + T tt = static_cast(threshold); + T mtt = -tt; + + // we use 3 as offset, since first 12 bytes are occupied with header + int flimit = limit + 4; + volatile int cnt = 4; + volatile bool flag = false; + PRAGMA_OMP_PARALLEL_THREADS(threads) { + int tid = omp_get_thread_num(); + int start = span * tid; + int stop = span * (tid + 1); + if (stop > l) stop = l; + + for (int e = start; e < stop; e++) { + bool flag_load; + PRAGMA_OMP_ATOMIC_ARGS(read) + flag_load = flag; + if (flag_load) break; + + T cUpd = x[e]; + if (cUpd >= tt) { + int idx; + PRAGMA_OMP_ATOMIC_ARGS(capture) + idx = cnt++; + + if (idx >= flimit) { + PRAGMA_OMP_ATOMIC_ARGS(write) + flag = true; + break; + } - T tt = static_cast(threshold); - T mtt = -tt; - - // we use 3 as offset, since first 12 bytes are occupied with header - int flimit = limit + 4; - volatile int cnt = 4; - volatile bool flag = false; - PRAGMA_OMP_PARALLEL_THREADS(threads) - { - int tid = omp_get_thread_num(); - int start = span * tid; - int stop = span * (tid + 1); - if (stop > l) - stop = l; - - for (int e = start; e < stop; e++) { - bool flag_load; -PRAGMA_OMP_ATOMIC_ARGS(read) - flag_load = flag; - if (flag_load) - break; - - T cUpd = x[e]; - if (cUpd >= tt) { - int idx; -PRAGMA_OMP_ATOMIC_ARGS(capture) - idx = cnt++; - - if (idx >= flimit) { -PRAGMA_OMP_ATOMIC_ARGS(write) - flag = true; - break; - } - - z[idx] = e + 1; - x[e] -= tt; - } else if (cUpd <= mtt) { - int idx; -PRAGMA_OMP_ATOMIC_ARGS(capture) - idx = cnt++; - - if (idx >= flimit) { -PRAGMA_OMP_ATOMIC_ARGS(write) - flag = true; - break; - } - - - z[idx] = -e - 1; - x[e] += tt; - } - } + z[idx] = e + 1; + x[e] -= tt; + } else if (cUpd <= mtt) { + int idx; + PRAGMA_OMP_ATOMIC_ARGS(capture) + idx = cnt++; + + if (idx >= flimit) { + PRAGMA_OMP_ATOMIC_ARGS(write) + flag = true; + break; } + + z[idx] = -e - 1; + x[e] += tt; + } } + } +} - template - void TypeCast::convertFromThreshold(Nd4jPointer * extras, const void *dx, Nd4jLong N, void *dz) { - FloatBits fb; - auto z = reinterpret_cast(dz); - auto x = reinterpret_cast(dx); - int limit = x[0]; - fb.i_ = x[2]; - float threshold = fb.f_; - - // we use 3 as offset, since first 12 bytes are occupied with header - int flimit = limit + 4; - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - int el = x[e]; - int ael = sd::math::nd4j_abs(el) - 1; - z[ael] += el > 0 ? static_cast(threshold) : static_cast(-threshold); - } - }; - - samediff::Threads::parallel_for(func, 4, flimit); +template +void TypeCast::convertFromThreshold(Nd4jPointer *extras, const void *dx, + Nd4jLong N, void *dz) { + FloatBits fb; + auto z = reinterpret_cast(dz); + auto x = reinterpret_cast(dx); + int limit = x[0]; + fb.i_ = x[2]; + float threshold = fb.f_; + + // we use 3 as offset, since first 12 bytes are occupied with header + int flimit = limit + 4; + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + int el = x[e]; + int ael = sd::math::nd4j_abs(el) - 1; + z[ael] += el > 0 ? static_cast(threshold) : static_cast(-threshold); } + }; + + samediff::Threads::parallel_for(func, 4, flimit); +} - /** - * This is cpu version, so leave it here as inline, to avoid templates instantiation - * - * @tparam S - * @tparam T - * @param dx - * @param N - * @param dz - */ - template - void TypeCast::convertGeneric(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz) { - auto x = reinterpret_cast(dx); - auto z = reinterpret_cast(dz); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - z[i] = static_cast(static_cast(x[i])); - } - }; - samediff::Threads::parallel_for(func, 0, N); - }; - - template void TypeCast::convertFromThreshold(Nd4jPointer * extras, const void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertFromThreshold(Nd4jPointer * extras, const void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertFromThreshold(Nd4jPointer * extras, const void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertFromThreshold(Nd4jPointer * extras, const void *dx, Nd4jLong N, void *dz); - - template void TypeCast::convertToThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertToThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertToThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertToThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - - template void TypeCast::convertFromQuantized(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertFromQuantized(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertFromQuantized(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - - template void TypeCast::convertToQuantized(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertToQuantized(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - template void TypeCast::convertToQuantized(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); +/** + * This is cpu version, so leave it here as inline, to avoid templates + * instantiation + * + * @tparam S + * @tparam T + * @param dx + * @param N + * @param dz + */ +template +void TypeCast::convertGeneric(Nd4jPointer *extras, void *dx, Nd4jLong N, + void *dz) { + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + z[i] = static_cast(static_cast(x[i])); + } + }; + samediff::Threads::parallel_for(func, 0, N); +}; + +template void TypeCast::convertFromThreshold(Nd4jPointer *extras, + const void *dx, Nd4jLong N, + void *dz); +template void TypeCast::convertFromThreshold(Nd4jPointer *extras, + const void *dx, Nd4jLong N, + void *dz); +template void TypeCast::convertFromThreshold(Nd4jPointer *extras, + const void *dx, + Nd4jLong N, void *dz); +template void TypeCast::convertFromThreshold(Nd4jPointer *extras, + const void *dx, + Nd4jLong N, void *dz); + +template void TypeCast::convertToThreshold(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); +template void TypeCast::convertToThreshold(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz); +template void TypeCast::convertToThreshold(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); +template void TypeCast::convertToThreshold(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); + +template void TypeCast::convertFromQuantized(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); +template void TypeCast::convertFromQuantized(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); +template void TypeCast::convertFromQuantized(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); + +template void TypeCast::convertToQuantized(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); +template void TypeCast::convertToQuantized(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz); +template void TypeCast::convertToQuantized(Nd4jPointer *extras, + void *dx, Nd4jLong N, + void *dz); #ifndef __CLION_IDE__ - BUILD_DOUBLE_TEMPLATE(template void TypeCast::convertGeneric, (Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz), LIBND4J_TYPES, LIBND4J_TYPES) +BUILD_DOUBLE_TEMPLATE(template void TypeCast::convertGeneric, + (Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz), + LIBND4J_TYPES, LIBND4J_TYPES) #endif -} +} // namespace sd diff --git a/libnd4j/include/loops/indexreduce.h b/libnd4j/include/loops/indexreduce.h old mode 100755 new mode 100644 index 2e8bc33d20af..e7a14460b2fa --- a/libnd4j/include/loops/indexreduce.h +++ b/libnd4j/include/loops/indexreduce.h @@ -26,103 +26,92 @@ #ifdef _OPENMP #include #endif -#include +#include +#include #include +#include #include -#include -#include #ifdef __CUDACC__ #include #include #endif - #include - -#include "system/pairwise_util.h" - #include "legacy_ops.h" +#include "system/pairwise_util.h" namespace functions { - namespace indexreduce { +namespace indexreduce { - template - class IndexReduce { - public: +template +class IndexReduce { + public: #ifdef __CUDABLAS__ - static __device__ void transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension,int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, void *reductionBuffer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset); - - template - static __device__ void aggregatePartials(IndexValue **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, void *extraParams); - - - template - static __device__ void transform(const void *dx, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets); - - - static _CUDA_H void executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream, - int op, - const void *dx, const Nd4jLong *xShapeInfo, - int xRank, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfo, - int zRank, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets); - - static _CUDA_H void executeIndexReduce(dim3 launchDims, cudaStream_t *stream, - int op, - const void *dx, const Nd4jLong *xShapeInfo, - int xRank, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfo, - int zRank, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationBuffer, void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets); + static __device__ void transform( + int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams, + void *result, const Nd4jLong *resultShapeInfo, int *dimension, + int dimensionLength, int postProcessOrNot, int *allocationBuffer, + void *reductionBuffer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset); + + template + static __device__ void aggregatePartials(IndexValue **sPartialsRef, + Nd4jLong tid, Nd4jLong numElements, + void *extraParams); + + template + static __device__ void transform(const void *dx, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfo, + int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationBuffer, + void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets); + + static _CUDA_H void executeIndexReduceScalar( + dim3 launchDims, cudaStream_t *stream, int op, const void *dx, + const Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, + const Nd4jLong *resultShapeInfo, int zRank, int *dimension, + int dimensionLength, int postProcessOrNot, int *allocationBuffer, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets); + + static _CUDA_H void executeIndexReduce( + dim3 launchDims, cudaStream_t *stream, int op, const void *dx, + const Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, + const Nd4jLong *resultShapeInfo, int zRank, int *dimension, + int dimensionLength, int postProcessOrNot, int *allocationBuffer, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets); #else - static Nd4jLong execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset); - - template - static _CUDA_H Nd4jLong execScalar(const void *x, const Nd4jLong *xShapeInfo, void *extraParams); - - template - static _CUDA_H void exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset); + static Nd4jLong execScalar(int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset); + + template + static _CUDA_H Nd4jLong execScalar(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams); + + template + static _CUDA_H void exec(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset); #endif - }; - } -} +}; +} // namespace indexreduce +} // namespace functions #endif /* INDEXREDUCE_H_ */ - diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index 001f8806c6db..c54d53bd63b7 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -21,392 +21,145 @@ #ifndef PROJECT_LEGACY_OPS_H #define PROJECT_LEGACY_OPS_H -#define AGGREGATE_OPS \ - (0, aggregateOps::HierarchicSoftmax) ,\ - (1, aggregateOps::Dot) ,\ - (2, aggregateOps::Axpy) ,\ - (3, aggregateOps::SkipGram) ,\ - (4, aggregateOps::CBOW) ,\ - (5, aggregateOps::GEMM) - -#define BROADCAST_INT_OPS \ - (0, ShiftLeft), \ - (1, ShiftRight), \ - (2, CyclicShiftLeft), \ - (3, CyclicShiftRight), \ - (4, IntAnd), \ - (5, IntOr), \ - (6, IntXor) - - -#define BROADCAST_BOOL_OPS \ - (0, EqualTo),\ - (1, GreaterThan),\ - (2, LessThan),\ - (3, Epsilon),\ - (4, GreaterThanOrEqual),\ - (5, MatchCondition) ,\ - (6, NotEqualTo),\ - (7, And),\ - (8, Or),\ - (9, Xor) ,\ - (10, Not) ,\ - (11, LessThanOrEqual) - -#define BROADCAST_OPS \ - (0, Add), \ - (1, Subtract), \ - (2, Multiply), \ - (3, Divide), \ - (4, ReverseDivide), \ - (5, ReverseSubtract), \ - (6, CopyPws), \ - (7, Pow), \ - (13, MinPairwise) ,\ - (14, MaxPairwise) ,\ - (15, AMinPairwise) ,\ - (16, AMaxPairwise) ,\ - (17, SquaredSubtract),\ - (18, FloorMod),\ - (19, FloorDiv),\ - (20, ReverseMod),\ - (21, SafeDivide),\ - (22, Mod) ,\ - (23, TruncateDiv), \ - (26, Atan2) ,\ - (27, LogicalOr) ,\ - (28, LogicalXor) ,\ - (29, LogicalNot) ,\ - (30, LogicalAnd), \ - (31, DivideNoNan), \ - (32, IGamma), \ - (33, IGammac),\ - (34, PowDerivative) +#define AGGREGATE_OPS \ + (0, aggregateOps::HierarchicSoftmax), (1, aggregateOps::Dot), \ + (2, aggregateOps::Axpy), (3, aggregateOps::SkipGram), \ + (4, aggregateOps::CBOW), (5, aggregateOps::GEMM) + +#define BROADCAST_INT_OPS \ + (0, ShiftLeft), (1, ShiftRight), (2, CyclicShiftLeft), \ + (3, CyclicShiftRight), (4, IntAnd), (5, IntOr), (6, IntXor) + +#define BROADCAST_BOOL_OPS \ + (0, EqualTo), (1, GreaterThan), (2, LessThan), (3, Epsilon), \ + (4, GreaterThanOrEqual), (5, MatchCondition), (6, NotEqualTo), (7, And), \ + (8, Or), (9, Xor), (10, Not), (11, LessThanOrEqual) + +#define BROADCAST_OPS \ + (0, Add), (1, Subtract), (2, Multiply), (3, Divide), (4, ReverseDivide), \ + (5, ReverseSubtract), (6, CopyPws), (7, Pow), (13, MinPairwise), \ + (14, MaxPairwise), (15, AMinPairwise), (16, AMaxPairwise), \ + (17, SquaredSubtract), (18, FloorMod), (19, FloorDiv), (20, ReverseMod), \ + (21, SafeDivide), (22, Mod), (23, TruncateDiv), (26, Atan2), \ + (27, LogicalOr), (28, LogicalXor), (29, LogicalNot), (30, LogicalAnd), \ + (31, DivideNoNan), (32, IGamma), (33, IGammac), (34, PowDerivative) // these ops return same data type as input -#define TRANSFORM_SAME_OPS \ - (0, Abs), \ - (1, Sign), \ - (2, Ones), \ - (3, Neg), \ - (4, Round), \ - (5, TimesOneMinus), \ - (6, Cube), \ - (7, OneMinus), \ - (11, Reciprocal), \ - (12, Square), \ - (13, CompareAndSetTransform) ,\ - (15, Identity), \ - (17, Ceiling), \ - (18, Floor), \ - (19, ClipByValue) ,\ - (21, Copy) +#define TRANSFORM_SAME_OPS \ + (0, Abs), (1, Sign), (2, Ones), (3, Neg), (4, Round), (5, TimesOneMinus), \ + (6, Cube), (7, OneMinus), (11, Reciprocal), (12, Square), \ + (13, CompareAndSetTransform), (15, Identity), (17, Ceiling), \ + (18, Floor), (19, ClipByValue), (21, Copy) -#define TRANSFORM_ANY_OPS \ - (0, Assign) +#define TRANSFORM_ANY_OPS (0, Assign) // these ops return bool -#define TRANSFORM_BOOL_OPS \ - (1, IsInf), \ - (2, IsNan), \ - (3, IsFinite), \ - (4, IsInfOrNan), \ - (5, MatchConditionBool), \ - (6, IsPositive) , \ - (7, Not), \ - (8, IsNegative) - - -#define TRANSFORM_STRICT_OPS \ - (2, ScaledTanh), \ - (3, Affine), \ - (4, TanhDerivative), \ - (5, HardTanhDerivative), \ - (6, SigmoidDerivative), \ - (7, SoftSignDerivative), \ - (8, TanDerivative) ,\ - (9, SELUDerivative) ,\ - (10, HardSigmoidDerivative) ,\ - (11, RationalTanhDerivative) ,\ - (12, RectifiedTanhDerivative) ,\ - (13, SwishDerivative) ,\ - (14, ACoshDerivative) ,\ - (15, ASinhDerivative) ,\ - (16, SinhDerivative), \ - (17, LogSigmoidDerivative) ,\ - (18, SpecialDerivative), \ - (19, Stabilize), \ - (20, StabilizeFP16) ,\ - (21, CubeDerivative) ,\ - (22, Cosine), \ - (23, Exp), \ - (24, Log), \ - (25, SetRange), \ - (26, Sigmoid), \ - (27, Sin), \ - (28, SoftPlus), \ - (29, Tanh), \ - (30, ACos), \ - (31, ASin), \ - (32, ATan), \ - (33, HardTanh), \ - (34, SoftSign), \ - (36, HardSigmoid), \ - (37, RationalTanh) ,\ - (38, RectifiedTanh) ,\ - (39, Sinh) ,\ - (40, Cosh) ,\ - (41, Tan) ,\ - (42, SELU) ,\ - (43, Swish) ,\ - (44, Log1p), \ - (45, Erf), \ - (46, ACosh), \ - (47, ASinh), \ - (48, Rint), \ - (49, LogSigmoid), \ - (50, Erfc) ,\ - (51, Expm1), \ - (52, ATanh) ,\ - (53, GELU) ,\ - (54, GELUDerivative), \ - (55, PreciseGELU) ,\ - (56, PreciseGELUDerivative), \ - (57, Mish),\ - (58, MishDerivative) +#define TRANSFORM_BOOL_OPS \ + (1, IsInf), (2, IsNan), (3, IsFinite), (4, IsInfOrNan), \ + (5, MatchConditionBool), (6, IsPositive), (7, Not), (8, IsNegative) + +#define TRANSFORM_STRICT_OPS \ + (2, ScaledTanh), (3, Affine), (4, TanhDerivative), (5, HardTanhDerivative), \ + (6, SigmoidDerivative), (7, SoftSignDerivative), (8, TanDerivative), \ + (9, SELUDerivative), (10, HardSigmoidDerivative), \ + (11, RationalTanhDerivative), (12, RectifiedTanhDerivative), \ + (13, SwishDerivative), (14, ACoshDerivative), (15, ASinhDerivative), \ + (16, SinhDerivative), (17, LogSigmoidDerivative), \ + (18, SpecialDerivative), (19, Stabilize), (20, StabilizeFP16), \ + (21, CubeDerivative), (22, Cosine), (23, Exp), (24, Log), \ + (25, SetRange), (26, Sigmoid), (27, Sin), (28, SoftPlus), (29, Tanh), \ + (30, ACos), (31, ASin), (32, ATan), (33, HardTanh), (34, SoftSign), \ + (36, HardSigmoid), (37, RationalTanh), (38, RectifiedTanh), (39, Sinh), \ + (40, Cosh), (41, Tan), (42, SELU), (43, Swish), (44, Log1p), (45, Erf), \ + (46, ACosh), (47, ASinh), (48, Rint), (49, LogSigmoid), (50, Erfc), \ + (51, Expm1), (52, ATanh), (53, GELU), (54, GELUDerivative), \ + (55, PreciseGELU), (56, PreciseGELUDerivative), (57, Mish), \ + (58, MishDerivative) // these ops return one of FLOAT data types -#define TRANSFORM_FLOAT_OPS \ - (1, Sqrt), \ - (3, RSqrt) - +#define TRANSFORM_FLOAT_OPS (1, Sqrt), (3, RSqrt) #define SUMMARY_STATS_OPS \ - (0, SummaryStatsVariance), \ - (1, SummaryStatsStandardDeviation) - -#define SCALAR_INT_OPS \ - (0, ShiftLeft) ,\ - (1, ShiftRight), \ - (2, CyclicShiftLeft), \ - (3, CyclicShiftRight), \ - (4, IntAnd), \ - (5, IntOr), \ - (6, IntXor) - -#define SCALAR_BOOL_OPS \ - (0, EqualTo),\ - (1, GreaterThan),\ - (2, LessThan),\ - (3, Epsilon),\ - (4, GreaterThanOrEqual),\ - (5, MatchCondition) ,\ - (6, NotEqualTo),\ - (7, And),\ - (8, Or),\ - (9, Xor) ,\ - (10, Not) ,\ - (11, LessThanOrEqual) - -#define SCALAR_OPS \ - (0, Add),\ - (1, Subtract),\ - (2, Multiply),\ - (3, Divide),\ - (4, ReverseDivide),\ - (5, ReverseSubtract),\ - (6, MaxPairwise),\ - (7, ELU), \ - (8, ELUDerivative), \ - (13, MinPairwise),\ - (14, CopyPws),\ - (15, Mod),\ - (16, ReverseMod),\ - (17, Remainder),\ - (18, FMod) ,\ - (19, TruncateDiv) ,\ - (20, FloorDiv) ,\ - (21, FloorMod), \ - (22, SquaredSubtract),\ - (23, SafeDivide), \ - (24, AMaxPairwise), \ - (25, AMinPairwise), \ - (26, Atan2) ,\ - (27, LogicalOr) ,\ - (28, LogicalXor) ,\ - (29, LogicalNot) ,\ - (30, LogicalAnd) ,\ - (31, Pow) ,\ - (32, PowDerivative) ,\ - (33, CompareAndSet) ,\ - (34, SXELogitsSmoother), \ - (35, LeakyRELU), \ - (36, LeakyRELUDerivative), \ - (37, ReplaceNans) ,\ - (38, LogX) ,\ - (39, RELU), \ - (40, RELU6), \ - (41, Step), \ - (42, LstmClip), \ - (43, TruncateMod) ,\ - (44, SquaredReverseSubtract) ,\ - (45, ReversePow), \ - (46, DivideNoNan), \ - (47, IGamma), \ - (48, IGammac), \ - (49, RELUDerivative) - - - - - -#define REDUCE3_OPS \ - (0, ManhattanDistance), \ - (1, EuclideanDistance), \ - (2, CosineSimilarity), \ - (3, Dot), \ - (4, EqualsWithEps) ,\ - (5, CosineDistance) ,\ - (6, JaccardDistance) ,\ - (7, SimpleHammingDistance) - -#define REDUCE_LONG_OPS \ - (0, CountNonZero), \ - (1, CountZero), \ - (2, MatchCondition) - -#define REDUCE_BOOL_OPS \ - (0, Any) ,\ - (1, All), \ - (2, IsFinite), \ - (3, IsInfOrNan), \ - (4, IsNan), \ - (5, IsInf), \ - (6, IsPositive), \ - (7, IsNegative) - -#define REDUCE_SAME_OPS \ - (0, Sum), \ - (1, Max), \ - (2, Min), \ - (3, Prod), \ - (4, ASum), \ - (5, AMax) ,\ - (6, AMin) ,\ - (7, ReduceSameBenchmarkOp) - -#define REDUCE_FLOAT_OPS \ - (0, Mean), \ - (1, AMean) ,\ - (2, Norm1), \ - (3, Norm2), \ - (4, NormMax), \ - (5, NormFrobenius), \ - (6, NormP), \ - (7, SquaredNorm) ,\ - (8, Entropy) ,\ - (9, LogEntropy) ,\ - (10, ShannonEntropy) ,\ - (12, ReduceFloatBenchmarkOp) - - - - -#define RANDOM_OPS \ - (0, UniformDistribution) ,\ - (1, DropOut) ,\ - (2, DropOutInverted) ,\ - (3, ProbablisticMerge) ,\ - (4, Linspace) ,\ - (5, Choice) ,\ - (6, GaussianDistribution) ,\ - (7, BernoulliDistribution) ,\ - (8, BinomialDistribution),\ - (9, BinomialDistributionEx),\ - (10, LogNormalDistribution) ,\ - (11, TruncatedNormalDistribution) ,\ - (12, AlphaDropOut),\ - (13, ExponentialDistribution),\ - (14, ExponentialDistributionInv), \ - (15, PoissonDistribution), \ - (16, GammaDistribution) - -#define PAIRWISE_INT_OPS \ - (0, ShiftLeft), \ - (1, ShiftRight), \ - (2, CyclicShiftLeft), \ - (3, CyclicShiftRight), \ - (4, IntAnd), \ - (5, IntOr), \ - (6, IntXor) - -#define PAIRWISE_BOOL_OPS \ - (0, EqualTo),\ - (1, GreaterThan),\ - (2, LessThan),\ - (3, Epsilon),\ - (4, GreaterThanOrEqual),\ - (5, MatchCondition) ,\ - (6, NotEqualTo),\ - (7, And),\ - (8, Or),\ - (9, Xor) ,\ - (10, Not) ,\ - (11, LessThanOrEqual) - -#define PAIRWISE_TRANSFORM_OPS \ - (0, Add),\ - (1, CopyPws),\ - (2, Divide),\ - (3, Multiply),\ - (4, Pow),\ - (5, ReverseSubtract),\ - (6, Subtract),\ - (7, MaxPairwise),\ - (8, MinPairwise),\ - (9, Copy2) ,\ - (10, Axpy),\ - (11, ReverseDivide),\ - (12, CompareAndSet),\ - (13, CompareAndReplace),\ - (14, Remainder),\ - (15, FMod),\ - (16, Atan2) ,\ - (17, TruncateDiv),\ - (18, FloorDiv), \ - (19, FloorMod) ,\ - (20, SquaredSubtract) ,\ - (21, ReverseMod),\ - (22, SafeDivide), \ - (23, Mod) ,\ - (24, RelativeError) ,\ - (25, BinaryRelativeError) ,\ - (26, BinaryMinimumAbsoluteRelativeError) ,\ - (27, LogicalOr) ,\ - (28, LogicalXor) ,\ - (29, LogicalNot) ,\ - (30, LogicalAnd) ,\ - (31, PowDerivative), \ - (32, LogPoissonLoss), \ - (33, LogPoissonLossFull) , \ - (34, AMaxPairwise), \ - (35, AMinPairwise) ,\ - (36, TruncateMod), \ - (37, ReplaceNans), \ - (38, DivideNoNan), \ - (39, IGamma), \ - (40, IGammac) - - - -#define INDEX_REDUCE_OPS \ - (0, IndexMax), \ - (1, IndexMin), \ - (2, IndexAbsoluteMax), \ - (3, IndexAbsoluteMin) , \ - (4, FirstIndex) , \ - (5, LastIndex) - - - -#endif //PROJECT_LEGACY_OPS_H + (0, SummaryStatsVariance), (1, SummaryStatsStandardDeviation) + +#define SCALAR_INT_OPS \ + (0, ShiftLeft), (1, ShiftRight), (2, CyclicShiftLeft), \ + (3, CyclicShiftRight), (4, IntAnd), (5, IntOr), (6, IntXor) + +#define SCALAR_BOOL_OPS \ + (0, EqualTo), (1, GreaterThan), (2, LessThan), (3, Epsilon), \ + (4, GreaterThanOrEqual), (5, MatchCondition), (6, NotEqualTo), (7, And), \ + (8, Or), (9, Xor), (10, Not), (11, LessThanOrEqual) + +#define SCALAR_OPS \ + (0, Add), (1, Subtract), (2, Multiply), (3, Divide), (4, ReverseDivide), \ + (5, ReverseSubtract), (6, MaxPairwise), (7, ELU), (8, ELUDerivative), \ + (13, MinPairwise), (14, CopyPws), (15, Mod), (16, ReverseMod), \ + (17, Remainder), (18, FMod), (19, TruncateDiv), (20, FloorDiv), \ + (21, FloorMod), (22, SquaredSubtract), (23, SafeDivide), \ + (24, AMaxPairwise), (25, AMinPairwise), (26, Atan2), (27, LogicalOr), \ + (28, LogicalXor), (29, LogicalNot), (30, LogicalAnd), (31, Pow), \ + (32, PowDerivative), (33, CompareAndSet), (34, SXELogitsSmoother), \ + (35, LeakyRELU), (36, LeakyRELUDerivative), (37, ReplaceNans), \ + (38, LogX), (39, RELU), (40, RELU6), (41, Step), (42, LstmClip), \ + (43, TruncateMod), (44, SquaredReverseSubtract), (45, ReversePow), \ + (46, DivideNoNan), (47, IGamma), (48, IGammac), (49, RELUDerivative) + +#define REDUCE3_OPS \ + (0, ManhattanDistance), (1, EuclideanDistance), (2, CosineSimilarity), \ + (3, Dot), (4, EqualsWithEps), (5, CosineDistance), (6, JaccardDistance), \ + (7, SimpleHammingDistance) + +#define REDUCE_LONG_OPS (0, CountNonZero), (1, CountZero), (2, MatchCondition) + +#define REDUCE_BOOL_OPS \ + (0, Any), (1, All), (2, IsFinite), (3, IsInfOrNan), (4, IsNan), (5, IsInf), \ + (6, IsPositive), (7, IsNegative) + +#define REDUCE_SAME_OPS \ + (0, Sum), (1, Max), (2, Min), (3, Prod), (4, ASum), (5, AMax), (6, AMin), \ + (7, ReduceSameBenchmarkOp) + +#define REDUCE_FLOAT_OPS \ + (0, Mean), (1, AMean), (2, Norm1), (3, Norm2), (4, NormMax), \ + (5, NormFrobenius), (6, NormP), (7, SquaredNorm), (8, Entropy), \ + (9, LogEntropy), (10, ShannonEntropy), (12, ReduceFloatBenchmarkOp) + +#define RANDOM_OPS \ + (0, UniformDistribution), (1, DropOut), (2, DropOutInverted), \ + (3, ProbablisticMerge), (4, Linspace), (5, Choice), \ + (6, GaussianDistribution), (7, BernoulliDistribution), \ + (8, BinomialDistribution), (9, BinomialDistributionEx), \ + (10, LogNormalDistribution), (11, TruncatedNormalDistribution), \ + (12, AlphaDropOut), (13, ExponentialDistribution), \ + (14, ExponentialDistributionInv), (15, PoissonDistribution), \ + (16, GammaDistribution) + +#define PAIRWISE_INT_OPS \ + (0, ShiftLeft), (1, ShiftRight), (2, CyclicShiftLeft), \ + (3, CyclicShiftRight), (4, IntAnd), (5, IntOr), (6, IntXor) + +#define PAIRWISE_BOOL_OPS \ + (0, EqualTo), (1, GreaterThan), (2, LessThan), (3, Epsilon), \ + (4, GreaterThanOrEqual), (5, MatchCondition), (6, NotEqualTo), (7, And), \ + (8, Or), (9, Xor), (10, Not), (11, LessThanOrEqual) + +#define PAIRWISE_TRANSFORM_OPS \ + (0, Add), (1, CopyPws), (2, Divide), (3, Multiply), (4, Pow), \ + (5, ReverseSubtract), (6, Subtract), (7, MaxPairwise), (8, MinPairwise), \ + (9, Copy2), (10, Axpy), (11, ReverseDivide), (12, CompareAndSet), \ + (13, CompareAndReplace), (14, Remainder), (15, FMod), (16, Atan2), \ + (17, TruncateDiv), (18, FloorDiv), (19, FloorMod), \ + (20, SquaredSubtract), (21, ReverseMod), (22, SafeDivide), (23, Mod), \ + (24, RelativeError), (25, BinaryRelativeError), \ + (26, BinaryMinimumAbsoluteRelativeError), (27, LogicalOr), \ + (28, LogicalXor), (29, LogicalNot), (30, LogicalAnd), \ + (31, PowDerivative), (32, LogPoissonLoss), (33, LogPoissonLossFull), \ + (34, AMaxPairwise), (35, AMinPairwise), (36, TruncateMod), \ + (37, ReplaceNans), (38, DivideNoNan), (39, IGamma), (40, IGammac) + +#define INDEX_REDUCE_OPS \ + (0, IndexMax), (1, IndexMin), (2, IndexAbsoluteMax), (3, IndexAbsoluteMin), \ + (4, FirstIndex), (5, LastIndex) + +#endif // PROJECT_LEGACY_OPS_H diff --git a/libnd4j/include/loops/pairwise_bool.h b/libnd4j/include/loops/pairwise_bool.h index 9cc8f220cbc8..1c87a28ed923 100644 --- a/libnd4j/include/loops/pairwise_bool.h +++ b/libnd4j/include/loops/pairwise_bool.h @@ -26,87 +26,72 @@ #ifdef _OPENMP #include #endif -#include +#include #include -#include -#include -#include +#include #include +#include +#include #include -#include +#include #ifdef __CUDACC__ #include #include #endif - #include "legacy_ops.h" using namespace simdOps; namespace functions { - namespace pairwise_transforms { +namespace pairwise_transforms { /** * Transforms involving 2 arrays */ - template - class PairWiseBoolTransform { - public: - +template +class PairWiseBoolTransform { + public: #ifdef __CUDACC__ - template - static __host__ void intermediateShaped(dim3& launchDims, cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams); - - static __host__ void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); + template + static __host__ void intermediateShaped( + dim3 &launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, void *vextraParams); + static __host__ void executeCudaShaped( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, void *extraParams); #else - static void exec(int opNum, - const void *dx, const Nd4jLong *xShapeBuffer, - const void *y, const Nd4jLong *yShapeBuffer, - void *result, const Nd4jLong *resultShapeBuffer, - void *extraParams, - uint64_t start, uint64_t stop); - - static void exec(int opNum, - const void *dx, Nd4jLong xStride, - const void *y, Nd4jLong yStride, - void *result, Nd4jLong resultStride, - void *extraParams, - Nd4jLong n, - uint64_t start, uint64_t stop); - - - template - static void exec(const void *vx, const Nd4jLong* xShapeBuffer, - const void *vy, const Nd4jLong* yShapeBuffer, - void *vresult, const Nd4jLong* resultShapeBuffer, - void *vextraParams, - uint64_t start, uint64_t stop); - - template - static void exec(const void *vx, Nd4jLong xStride, - const void *vy, Nd4jLong yStride, - void *vresult, Nd4jLong resultStride, - void *vextraParams, - Nd4jLong n, - uint64_t start, uint64_t stop); + static void exec(int opNum, const void *dx, const Nd4jLong *xShapeBuffer, + const void *y, const Nd4jLong *yShapeBuffer, void *result, + const Nd4jLong *resultShapeBuffer, void *extraParams, + uint64_t start, uint64_t stop); + + static void exec(int opNum, const void *dx, Nd4jLong xStride, const void *y, + Nd4jLong yStride, void *result, Nd4jLong resultStride, + void *extraParams, Nd4jLong n, uint64_t start, + uint64_t stop); + + template + static void exec(const void *vx, const Nd4jLong *xShapeBuffer, const void *vy, + const Nd4jLong *yShapeBuffer, void *vresult, + const Nd4jLong *resultShapeBuffer, void *vextraParams, + uint64_t start, uint64_t stop); + + template + static void exec(const void *vx, Nd4jLong xStride, const void *vy, + Nd4jLong yStride, void *vresult, Nd4jLong resultStride, + void *vextraParams, Nd4jLong n, uint64_t start, + uint64_t stop); #endif - }; - } -} +}; +} // namespace pairwise_transforms +} // namespace functions #endif /* PAIRWISE_TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/pairwise_int.h b/libnd4j/include/loops/pairwise_int.h index 64deebc04723..d80c32989dcb 100644 --- a/libnd4j/include/loops/pairwise_int.h +++ b/libnd4j/include/loops/pairwise_int.h @@ -26,88 +26,72 @@ #ifdef _OPENMP #include #endif -#include +#include #include -#include -#include -#include +#include #include +#include +#include #include -#include +#include #ifdef __CUDACC__ #include #include #endif - - #include "legacy_ops.h" using namespace simdOps; namespace functions { - namespace pairwise_transforms { +namespace pairwise_transforms { /** * Transforms involving 2 arrays */ - template - class PairWiseIntTransform { - public: - +template +class PairWiseIntTransform { + public: #ifdef __CUDACC__ - template - static __host__ void intermediateShaped(dim3& launchDims, cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams); - - static __host__ void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); + template + static __host__ void intermediateShaped( + dim3 &launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, void *vextraParams); + static __host__ void executeCudaShaped( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, void *extraParams); #else - static void exec(int opNum, - const void *dx, const Nd4jLong *xShapeBuffer, - const void *y, const Nd4jLong *yShapeBuffer, - void *result, const Nd4jLong *resultShapeBuffer, - void *extraParams, - uint64_t start, uint64_t stop); - - static void exec(int opNum, - const void *dx, Nd4jLong xStride, - const void *y, Nd4jLong yStride, - void *result, Nd4jLong resultStride, - void *extraParams, - Nd4jLong n, - uint64_t start, uint64_t stop); - - - template - static void exec(const void *vx, const Nd4jLong* xShapeBuffer, - const void *vy, const Nd4jLong* yShapeBuffer, - void *vresult, const Nd4jLong* resultShapeBuffer, - void *vextraParams, - uint64_t start,uint64_t stop); - - template - static void exec(const void *vx, Nd4jLong xStride, - const void *vy, Nd4jLong yStride, - void *vresult, Nd4jLong resultStride, - void *vextraParams, - Nd4jLong n, - uint64_t start, uint64_t stop); + static void exec(int opNum, const void *dx, const Nd4jLong *xShapeBuffer, + const void *y, const Nd4jLong *yShapeBuffer, void *result, + const Nd4jLong *resultShapeBuffer, void *extraParams, + uint64_t start, uint64_t stop); + + static void exec(int opNum, const void *dx, Nd4jLong xStride, const void *y, + Nd4jLong yStride, void *result, Nd4jLong resultStride, + void *extraParams, Nd4jLong n, uint64_t start, + uint64_t stop); + + template + static void exec(const void *vx, const Nd4jLong *xShapeBuffer, const void *vy, + const Nd4jLong *yShapeBuffer, void *vresult, + const Nd4jLong *resultShapeBuffer, void *vextraParams, + uint64_t start, uint64_t stop); + + template + static void exec(const void *vx, Nd4jLong xStride, const void *vy, + Nd4jLong yStride, void *vresult, Nd4jLong resultStride, + void *vextraParams, Nd4jLong n, uint64_t start, + uint64_t stop); #endif - }; - } -} +}; +} // namespace pairwise_transforms +} // namespace functions #endif /* PAIRWISE_TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/pairwise_transform.h b/libnd4j/include/loops/pairwise_transform.h old mode 100755 new mode 100644 index b3b514df6235..008cb553ac0b --- a/libnd4j/include/loops/pairwise_transform.h +++ b/libnd4j/include/loops/pairwise_transform.h @@ -27,13 +27,14 @@ #include #endif -#include -#include +#include #include +#include +#include #include #include + #include "legacy_ops.h" -#include #ifdef __CUDACC__ #include @@ -41,68 +42,53 @@ #include #endif - namespace functions { - namespace pairwise_transforms { +namespace pairwise_transforms { /** * Transforms involving 2 arrays */ - template - class PairWiseTransform { - public: - +template +class PairWiseTransform { + public: #ifdef __CUDABLAS__ - template - static __host__ void intermediateShaped(dim3& launchDims, cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *vextraParams); + template + static __host__ void intermediateShaped( + dim3 &launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, void *vextraParams); - static __host__ void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams); + static __host__ void executeCudaShaped( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, void *extraParams); #endif - public: - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t start, uint64_t stop); - - static void exec(int opNum, - const void *x, Nd4jLong xStride, - const void *y, Nd4jLong yStride, - void *z, Nd4jLong resultStride, - void *extraParams, - Nd4jLong len, - uint64_t start, uint64_t stop); - - - template - static void exec(const void *vx, const Nd4jLong* xShapeInfo, - const void *vy, const Nd4jLong* yShapeInfo, - void *vresult, const Nd4jLong* zShapeInfo, - void *vextraParams, - uint64_t start, uint64_t stop); - - template - static void exec(const void *vx, Nd4jLong xStride, - const void *vy, Nd4jLong yStride, - void *vresult, Nd4jLong resultStride, - void *vextraParams, - Nd4jLong len, - uint64_t start, uint64_t stop); - }; - } -} + public: + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo, void *extraParams, + uint64_t start, uint64_t stop); + + static void exec(int opNum, const void *x, Nd4jLong xStride, const void *y, + Nd4jLong yStride, void *z, Nd4jLong resultStride, + void *extraParams, Nd4jLong len, uint64_t start, + uint64_t stop); + + template + static void exec(const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vresult, + const Nd4jLong *zShapeInfo, void *vextraParams, + uint64_t start, uint64_t stop); + + template + static void exec(const void *vx, Nd4jLong xStride, const void *vy, + Nd4jLong yStride, void *vresult, Nd4jLong resultStride, + void *vextraParams, Nd4jLong len, uint64_t start, + uint64_t stop); +}; +} // namespace pairwise_transforms +} // namespace functions #endif /* PAIRWISE_TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/random.h b/libnd4j/include/loops/random.h index 9b35f472fd25..6c46fcd18def 100644 --- a/libnd4j/include/loops/random.h +++ b/libnd4j/include/loops/random.h @@ -21,81 +21,84 @@ #ifndef LIBND4J_RANDOM_H #define LIBND4J_RANDOM_H - - -#include #include +#include +#include #include #include -#include - - namespace functions { - namespace random { - - template - class RandomFunction { - public: +namespace random { +template +class RandomFunction { + public: #ifdef __CUDABLAS__ - template - static _CUDA_D void execTransformCuda(Nd4jPointer state, - const void *x, const Nd4jLong *xShapeBuffer, - const void *y, const Nd4jLong *yShapeBuffer, - void *z, const Nd4jLong *zShapeBuffer, - void *extraArguments); - - template - static _CUDA_D void execTransformCuda(Nd4jPointer state, - const void *x, const Nd4jLong *xShapeBuffer, - void *z, const Nd4jLong *zShapeBuffer, - void *extraArguments); - - template - static _CUDA_D void execTransformCuda(Nd4jPointer state, void *z, const Nd4jLong *zShapeBuffer, void *extraArguments); - - - static _CUDA_H void executeCudaSingle(dim3& launchDims, cudaStream_t* stream, - int opNum, - Nd4jPointer stateHost, - void *z, const Nd4jLong *zShapeBuffer, - void *extraArguments); - - - static _CUDA_H void executeCudaDouble(dim3& launchDims, cudaStream_t* stream, - int opNum, - Nd4jPointer stateHost, - const void *x, const Nd4jLong *xShapeBuffer, - void *z, const Nd4jLong *zShapeBuffer, - void *extraArguments); - - - static _CUDA_H void executeCudaTriple(dim3& launchDims, cudaStream_t* stream, - int opNum, - Nd4jPointer stateHost, - const void *x, const Nd4jLong *xShapeBuffer, - const void *y, const Nd4jLong *yShapeBuffer, - void *z, const Nd4jLong* zShapeBuffer, - void *extraArguments); + template + static _CUDA_D void execTransformCuda(Nd4jPointer state, const void *x, + const Nd4jLong *xShapeBuffer, + const void *y, + const Nd4jLong *yShapeBuffer, void *z, + const Nd4jLong *zShapeBuffer, + void *extraArguments); + + template + static _CUDA_D void execTransformCuda(Nd4jPointer state, const void *x, + const Nd4jLong *xShapeBuffer, void *z, + const Nd4jLong *zShapeBuffer, + void *extraArguments); + + template + static _CUDA_D void execTransformCuda(Nd4jPointer state, void *z, + const Nd4jLong *zShapeBuffer, + void *extraArguments); + + static _CUDA_H void executeCudaSingle(dim3 &launchDims, cudaStream_t *stream, + int opNum, Nd4jPointer stateHost, + void *z, const Nd4jLong *zShapeBuffer, + void *extraArguments); + + static _CUDA_H void executeCudaDouble(dim3 &launchDims, cudaStream_t *stream, + int opNum, Nd4jPointer stateHost, + const void *x, + const Nd4jLong *xShapeBuffer, void *z, + const Nd4jLong *zShapeBuffer, + void *extraArguments); + + static _CUDA_H void executeCudaTriple( + dim3 &launchDims, cudaStream_t *stream, int opNum, Nd4jPointer stateHost, + const void *x, const Nd4jLong *xShapeBuffer, const void *y, + const Nd4jLong *yShapeBuffer, void *z, const Nd4jLong *zShapeBuffer, + void *extraArguments); #else - template - static void execTransform(Nd4jPointer state, const void *x, const Nd4jLong *xShapeBuffer, const void *y, const Nd4jLong *yShapeBuffer, void *z, const Nd4jLong *zShapeBuffer, void *extraArguments); - - template - static void execTransform(Nd4jPointer state, const void *x, const Nd4jLong *xShapeBuffer, void *z, const Nd4jLong *zShapeBuffer, void *extraArguments); - - template - static void execTransform(Nd4jPointer state, void *z, const Nd4jLong *zShapeBuffer, void *extraArguments); - - static void execTransform(int opNum, Nd4jPointer state, const void *x, const Nd4jLong *xShapeBuffer, void *z, const Nd4jLong *zShapeBuffer, void *extraArguments); - static void execTransform(int opNum, Nd4jPointer state, const void *x, const Nd4jLong *xShapeBuffer, const void *y, const Nd4jLong *yShapeBuffer, void *z, const Nd4jLong *zShapeBuffer, void *extraArguments); - static void execTransform(int opNum, Nd4jPointer state, void *z, const Nd4jLong *zShapeBuffer, void *extraArguments); + template + static void execTransform(Nd4jPointer state, const void *x, + const Nd4jLong *xShapeBuffer, const void *y, + const Nd4jLong *yShapeBuffer, void *z, + const Nd4jLong *zShapeBuffer, void *extraArguments); + + template + static void execTransform(Nd4jPointer state, const void *x, + const Nd4jLong *xShapeBuffer, void *z, + const Nd4jLong *zShapeBuffer, void *extraArguments); + + template + static void execTransform(Nd4jPointer state, void *z, + const Nd4jLong *zShapeBuffer, void *extraArguments); + + static void execTransform(int opNum, Nd4jPointer state, const void *x, + const Nd4jLong *xShapeBuffer, void *z, + const Nd4jLong *zShapeBuffer, void *extraArguments); + static void execTransform(int opNum, Nd4jPointer state, const void *x, + const Nd4jLong *xShapeBuffer, const void *y, + const Nd4jLong *yShapeBuffer, void *z, + const Nd4jLong *zShapeBuffer, void *extraArguments); + static void execTransform(int opNum, Nd4jPointer state, void *z, + const Nd4jLong *zShapeBuffer, void *extraArguments); #endif - }; - } -} - +}; +} // namespace random +} // namespace functions -#endif //LIBND4J_RANDOM_H +#endif // LIBND4J_RANDOM_H diff --git a/libnd4j/include/loops/reduce3.h b/libnd4j/include/loops/reduce3.h old mode 100755 new mode 100644 index f2496f1fe443..1a5608846782 --- a/libnd4j/include/loops/reduce3.h +++ b/libnd4j/include/loops/reduce3.h @@ -30,242 +30,193 @@ #ifdef _OPENMP #include #endif -#include -#include -#include +#include +#include #include +#include #include +#include #include -#include -#include +#include #ifdef __CUDACC__ #include #include #endif - #include "legacy_ops.h" using namespace simdOps; namespace functions { -namespace reduce3 { +namespace reduce3 { /** * Reduce involving * 2 arrays */ -template +template class Reduce3 { - - public: - + public: #ifdef __CUDACC__ - virtual __device__ - inline Y opAtomic(X d1, X d2, Y *extraParamsRef) = 0; - - /** - * Aggregate shared memory - * @param sPartialsRef - * @param tid - * @param extraParams - */ - template - static __device__ void aggregatePartials(void* sPartials, Nd4jLong tid, Nd4jLong numItems, void *extraParams); - - template - static __device__ void execScalarCuda(const void *x, const Nd4jLong *xShapeInfo, - const void *y, const Nd4jLong *yShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo); - - template - static __device__ void transformAll(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets); - - /** - Perform a reduction - @param n the number of elements - @param xOffset the starting offset - @param dx the data to perform the reduction on - @param incx the increment on which to perform the reduction - @param extraParams extra parameters used for calculations - @param result where to store the result of the reduction - */ - template - static __device__ void transform(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); - - - static __device__ void execCuda(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); - - - static __device__ void execAllCuda(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); - - - static __device__ void execScalarCuda(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo, - int * allocationPointer, void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo); - - - static __host__ void exec(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, + virtual __device__ inline Y opAtomic(X d1, X d2, Y *extraParamsRef) = 0; + + /** + * Aggregate shared memory + * @param sPartialsRef + * @param tid + * @param extraParams + */ + template + static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, + Nd4jLong numItems, + void *extraParams); + + template + static __device__ void execScalarCuda( + const void *x, const Nd4jLong *xShapeInfo, const void *y, + const Nd4jLong *yShapeInfo, void *extraParams, void *z, + const Nd4jLong *zShapeInfo, int *allocationPointer, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo); + + template + static __device__ void transformAll( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationPointer, + const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, + const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets); + + /** +Perform a reduction +@param n the number of elements +@param xOffset the starting offset +@param dx the data to perform the reduction on +@param incx the increment on which to perform the reduction +@param extraParams extra parameters used for calculations +@param result where to store the result of the reduction +*/ + template + static __device__ void transform( + const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationPointer, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); + + static __device__ void execCuda( + int opNum, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationPointer, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); + + static __device__ void execAllCuda( + int opNum, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + int postProcessOrNot, int *allocationPointer, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); + + static __device__ void execScalarCuda( + int opNum, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *allocationPointer, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo); + + static __host__ void exec( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, int postProcessOrNot, int *allocationPointer, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); + + static __host__ void execAll( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, int postProcessOrNot, int *allocationPointer, + const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, + const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); + + static __host__ void execScalar(dim3 launchDims, cudaStream_t *stream, + int opNum, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); - - static __host__ void execAll(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int postProcessOrNot, - int *allocationPointer, - const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *yTadOnlyShapeInfo, const Nd4jLong *yTadOffsets); - - static __host__ void execScalar(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo, - int* allocationPointer, void *reductionBuffer, - const Nd4jLong *tadOnlyShapeInfo); + int *allocationPointer, void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo); #else - template - static void execScalar(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo); - - - static void execScalar(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *y, const Nd4jLong *yShapeInfo, - void *z, const Nd4jLong *zShapeInfo); - - - template - static void exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int64_t start, int64_t stop); - - - template - static void exec(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - int64_t start, int64_t stop); - - - template - static void execAll(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - int64_t start, int64_t stop); - - - static void exec(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - int64_t start, int64_t stop); - - - static void exec(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - int64_t start, int64_t stop); - - - static void execAll(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParamsVals, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - int64_t start, int64_t stop); + template + static void execScalar(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo); + + static void execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParamsVals, const void *y, + const Nd4jLong *yShapeInfo, void *z, + const Nd4jLong *zShapeInfo); + + template + static void exec(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, int64_t start, int64_t stop); + + template + static void exec(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, int64_t start, int64_t stop); + + template + static void execAll(const void *vx, const Nd4jLong *xShapeInfo, + void *vextraParams, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xOffsets, const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yOffsets, int64_t start, int64_t stop); + + static void exec(int opNum, const void *vx, const Nd4jLong *xShapeInfo, + void *extraParamsVals, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, int64_t start, int64_t stop); + + static void exec(int opNum, const void *vx, const Nd4jLong *xShapeInfo, + void *extraParamsVals, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, int64_t start, int64_t stop); + + static void execAll(int opNum, const void *vx, const Nd4jLong *xShapeInfo, + void *extraParamsVals, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, + int dimensionLength, const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xOffsets, const Nd4jLong *yTadShapeInfo, + const Nd4jLong *yOffsets, int64_t start, int64_t stop); #endif }; - - -} -} +} // namespace reduce3 +} // namespace functions #ifdef __CUDACC__ #endif - - #endif /* REDUCE3_H_ */ diff --git a/libnd4j/include/loops/reduce_bool.h b/libnd4j/include/loops/reduce_bool.h index a74d53033f5f..9faa41ab93dd 100644 --- a/libnd4j/include/loops/reduce_bool.h +++ b/libnd4j/include/loops/reduce_bool.h @@ -14,21 +14,20 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - #ifndef REDUCE_BOOL_H #define REDUCE_BOOL_H #include //#include -#include #include +#include #ifdef _OPENMP #include #endif #include -#include -#include #include +#include #include +#include #pragma once #ifdef __CUDACC__ @@ -36,12 +35,11 @@ #include #endif - #include "legacy_ops.h" -//an op for the kernel +// an op for the kernel namespace functions { - namespace reduce { +namespace reduce { /** * A reduce function @@ -50,128 +48,156 @@ namespace functions { * via aggregating member * elements. */ - template - class ReduceBoolFunction { - public: +template +class ReduceBoolFunction { + public: #ifdef __CUDACC__ - template - static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, Nd4jLong numItems, void *extraParams); - - template - static __device__ void execScalarCuda(const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - template - static __device__ void transformCudaXD(const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets); - - template - static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - template - static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); + template + static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, + Nd4jLong numItems, + void *extraParams); + + template + static __device__ void execScalarCuda(const void *vx, + const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, + void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo); + + template + static __device__ void transformCudaXD( + const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets); + + template + static __host__ void intermediateScalar( + dim3 launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); + + template + static __host__ void intermediateXD( + dim3 launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static __host__ void execReduceScalar( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); + + static __host__ void execReduceXD( + dim3 launchDims, cudaStream_t *stream, int opNum, int rank, + const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else - /** - * Reduce down to 1 number - * @param x the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static _CUDA_H Z execScalar(const void *x, const Nd4jLong *xShapeInfo, void *extraParams); - - template - static _CUDA_H void execScalar(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo); - - - static Z execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, void *extraParams); - - static void execScalar(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * Execute on the cpu - * @param x the input data - * @param xShapeInfo the shape information for x - * @param extraParams the extra parameters - * @param result the result buffer - * @param resultShapeInfoBuffer the shape information - * @param dimension the dimension to perform - * the reduce along long - * @param dimensionLength the length of the dimension buffer - */ - - - template - static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * CPU implementation - * @param x the input data - * @param xShapeInfo the shape information for - * the input data - * @param extraParams the extra parameters for the problem - * @param result the result buffer - * @param resultShapeInfo the shape information - */ - template - static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfo); - - - - /** - * Reduce down to 1 number - * @param x the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static Z _CUDA_H execScalar(const void *x, Nd4jLong xElementWiseStride, Nd4jLong length, void *extraParams); + /** + * Reduce down to 1 number + * @param x the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static _CUDA_H Z execScalar(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams); + + template + static _CUDA_H void execScalar(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo); + + static Z execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams); + + static void execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop); + + /** + * Execute on the cpu + * @param x the input data + * @param xShapeInfo the shape information for x + * @param extraParams the extra parameters + * @param result the result buffer + * @param resultShapeInfoBuffer the shape information + * @param dimension the dimension to perform + * the reduce along long + * @param dimensionLength the length of the dimension buffer + */ + + template + static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, + int64_t stop); + + /** + * CPU implementation + * @param x the input data + * @param xShapeInfo the shape information for + * the input data + * @param extraParams the extra parameters for the problem + * @param result the result buffer + * @param resultShapeInfo the shape information + */ + template + static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfo); + + /** + * Reduce down to 1 number + * @param x the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static Z _CUDA_H execScalar(const void *x, Nd4jLong xElementWiseStride, + Nd4jLong length, void *extraParams); #endif - }; - +}; #ifdef __CUDACC__ - /** - * - * @param extraParams - * @param sPartials - * @param sMemSize - */ - template - __device__ void initializeShared(T *extraParams, T **sPartials, int sMemSize); +/** + * + * @param extraParams + * @param sPartials + * @param sMemSize + */ +template +__device__ void initializeShared(T *extraParams, T **sPartials, int sMemSize); #endif - } +} // namespace reduce -} +} // namespace functions #endif - diff --git a/libnd4j/include/loops/reduce_float.h b/libnd4j/include/loops/reduce_float.h index c78082f8ed57..4437650db788 100644 --- a/libnd4j/include/loops/reduce_float.h +++ b/libnd4j/include/loops/reduce_float.h @@ -14,21 +14,20 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - #ifndef REDUCE_FLOAT_H #define REDUCE_FLOAT_H #include //#include -#include #include +#include #ifdef _OPENMP #include #endif #include -#include -#include #include +#include #include +#include #pragma once #ifdef __CUDACC__ @@ -36,12 +35,11 @@ #include #endif - #include "legacy_ops.h" -//an op for the kernel +// an op for the kernel namespace functions { - namespace reduce { +namespace reduce { /** * A reduce function @@ -50,132 +48,156 @@ namespace functions { * via aggregating member * elements. */ - template - class ReduceFloatFunction { - - public: - -#ifdef __CUDACC__ - template - static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, Nd4jLong numItems, void *extraParams); - - template - static __device__ void execScalarCuda(const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - template - static __device__ void transformCudaXD(const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets); - - template - static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - template - static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShape, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShape, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); +template +class ReduceFloatFunction { + public: +#ifdef __CUDACC__ + template + static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, + Nd4jLong numItems, + void *extraParams); + + template + static __device__ void execScalarCuda(const void *vx, + const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, + void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo); + + template + static __device__ void transformCudaXD( + const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets); + + template + static __host__ void intermediateScalar( + dim3 launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); + + template + static __host__ void intermediateXD( + dim3 launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShape, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static __host__ void execReduceScalar( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); + + static __host__ void execReduceXD( + dim3 launchDims, cudaStream_t *stream, int opNum, int rank, + const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShape, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else - /** - * Reduce down to 1 number - * @param vx the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static _CUDA_H Z execScalar(const void *vx, const Nd4jLong *xShapeInfo, void *extraParams); - - template - static _CUDA_H void execScalar(const void *vx, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo); - - - static Z execScalar(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParams); - - static void execScalar(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo); - - static void exec(int opNum, - const void *vx, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * Execute on the cpu - * @param vx the input data - * @param xShapeInfo the shape information for vx - * @param extraParams the extra parameters - * @param vz the vz buffer - * @param resultShapeInfoBuffer the shape information - * @param dimension the dimension to perform - * the reduce along long - * @param dimensionLength the length of the dimension buffer - */ - - - template - static void _CUDA_H exec(const void *vx, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * CPU implementation - * @param vx the input data - * @param xShapeInfo the shape information for - * the input data - * @param extraParams the extra parameters for the problem - * @param vz the vz buffer - * @param zShapeInfo the shape information - */ - template - static void _CUDA_H exec(const void *vx, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *zShapeInfo); - - - - /** - * Reduce down to 1 number - * @param vx the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static Z _CUDA_H execScalar(const void *vx, Nd4jLong xElementWiseStride, Nd4jLong length, void *extraParams); + /** + * Reduce down to 1 number + * @param vx the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static _CUDA_H Z execScalar(const void *vx, const Nd4jLong *xShapeInfo, + void *extraParams); + + template + static _CUDA_H void execScalar(const void *vx, const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *zShapeInfo); + + static Z execScalar(int opNum, const void *vx, const Nd4jLong *xShapeInfo, + void *extraParams); + + static void execScalar(int opNum, const void *vx, const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *zShapeInfo); + + static void exec(int opNum, const void *vx, const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *resultShapeInfoBuffer, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop); + + /** + * Execute on the cpu + * @param vx the input data + * @param xShapeInfo the shape information for vx + * @param extraParams the extra parameters + * @param vz the vz buffer + * @param resultShapeInfoBuffer the shape information + * @param dimension the dimension to perform + * the reduce along long + * @param dimensionLength the length of the dimension buffer + */ + + template + static void _CUDA_H exec(const void *vx, const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *resultShapeInfoBuffer, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, + int64_t stop); + + /** + * CPU implementation + * @param vx the input data + * @param xShapeInfo the shape information for + * the input data + * @param extraParams the extra parameters for the problem + * @param vz the vz buffer + * @param zShapeInfo the shape information + */ + template + static void _CUDA_H exec(const void *vx, const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *zShapeInfo); + + /** + * Reduce down to 1 number + * @param vx the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static Z _CUDA_H execScalar(const void *vx, Nd4jLong xElementWiseStride, + Nd4jLong length, void *extraParams); #endif - }; - +}; #ifdef __CUDACC__ - /** - * - * @param extraParams - * @param sPartials - * @param sMemSize - */ - template - __device__ void initializeShared(T *extraParams, T **sPartials, int sMemSize); +/** + * + * @param extraParams + * @param sPartials + * @param sMemSize + */ +template +__device__ void initializeShared(T *extraParams, T **sPartials, int sMemSize); #endif - } +} // namespace reduce -} +} // namespace functions #endif - diff --git a/libnd4j/include/loops/reduce_long.h b/libnd4j/include/loops/reduce_long.h index 45ede29850dc..193a6df0ebaa 100644 --- a/libnd4j/include/loops/reduce_long.h +++ b/libnd4j/include/loops/reduce_long.h @@ -14,21 +14,20 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - #ifndef REDUCE_LONG_H #define REDUCE_LONG_H #include //#include -#include #include +#include #ifdef _OPENMP #include #endif #include -#include -#include #include +#include #include +#include #pragma once #ifdef __CUDACC__ @@ -38,9 +37,9 @@ #include "legacy_ops.h" -//an op for the kernel +// an op for the kernel namespace functions { - namespace reduce { +namespace reduce { /** * A reduce function @@ -49,132 +48,157 @@ namespace functions { * via aggregating member * elements. */ - template - class ReduceLongFunction { - public: +template +class ReduceLongFunction { + public: #ifdef __CUDACC__ - template - static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, Nd4jLong numItems, void *extraParams); - - template - static __device__ void execScalarCuda(const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - template - static __device__ void transformCudaXD(const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets); - - template - static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - template - static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); - - static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* hXShapeInfo, void *extraParams, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); + template + static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, + Nd4jLong numItems, + void *extraParams); + + template + static __device__ void execScalarCuda(const void *vx, + const Nd4jLong *xShapeInfo, + void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, + void *reductionBuffer, + const Nd4jLong *tadOnlyShapeInfo); + + template + static __device__ void transformCudaXD( + const void *vx, const Nd4jLong *xShapeInfo, void *extraParams, void *vz, + const Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, + const Nd4jLong *tadOffsets); + + template + static __host__ void intermediateScalar( + dim3 launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); + + template + static __host__ void intermediateXD( + dim3 launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static __host__ void execReduceScalar( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo); + + static __host__ void execReduceXD( + dim3 launchDims, cudaStream_t *stream, int opNum, int rank, + const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hXShapeInfo, + void *extraParams, void *vz, const Nd4jLong *zShapeInfo, + const Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else - /** - * Reduce down to 1 number - * @param x the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static _CUDA_H Z execScalar(const void *x, const Nd4jLong *xShapeInfo, void *extraParams); - - template - static _CUDA_H void execScalar(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo); - - - static Z execScalar(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams); - - static void execScalar(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * Execute on the cpu - * @param x the input data - * @param xShapeInfo the shape information for x - * @param extraParams the extra parameters - * @param result the result buffer - * @param resultShapeInfoBuffer the shape information - * @param dimension the dimension to perform - * the reduce along long - * @param dimensionLength the length of the dimension buffer - */ - - - template - static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * CPU implementation - * @param x the input data - * @param xShapeInfo the shape information for - * the input data - * @param extraParams the extra parameters for the problem - * @param result the result buffer - * @param resultShapeInfo the shape information - */ - template - static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfo); - - - - /** - * Reduce down to 1 number - * @param x the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static Z _CUDA_H execScalar(const void *x, Nd4jLong xElementWiseStride, - Nd4jLong length, - void *extraParams); + /** + * Reduce down to 1 number + * @param x the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static _CUDA_H Z execScalar(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams); + + template + static _CUDA_H void execScalar(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo); + + static Z execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams); + + static void execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop); + + /** + * Execute on the cpu + * @param x the input data + * @param xShapeInfo the shape information for x + * @param extraParams the extra parameters + * @param result the result buffer + * @param resultShapeInfoBuffer the shape information + * @param dimension the dimension to perform + * the reduce along long + * @param dimensionLength the length of the dimension buffer + */ + + template + static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, + int64_t stop); + + /** + * CPU implementation + * @param x the input data + * @param xShapeInfo the shape information for + * the input data + * @param extraParams the extra parameters for the problem + * @param result the result buffer + * @param resultShapeInfo the shape information + */ + template + static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfo); + + /** + * Reduce down to 1 number + * @param x the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static Z _CUDA_H execScalar(const void *x, Nd4jLong xElementWiseStride, + Nd4jLong length, void *extraParams); #endif - }; +}; #ifdef __CUDACC__ - /** - * - * @param extraParams - * @param sPartials - * @param sMemSize - */ - template - __device__ void initializeShared(T *extraParams, T **sPartials, int sMemSize); +/** + * + * @param extraParams + * @param sPartials + * @param sMemSize + */ +template +__device__ void initializeShared(T *extraParams, T **sPartials, int sMemSize); #endif - } +} // namespace reduce -} +} // namespace functions #endif - diff --git a/libnd4j/include/loops/reduce_same.h b/libnd4j/include/loops/reduce_same.h index 5f3622f39b60..a1dd79280661 100644 --- a/libnd4j/include/loops/reduce_same.h +++ b/libnd4j/include/loops/reduce_same.h @@ -14,21 +14,20 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - #ifndef REDUCE_SAME_H #define REDUCE_SAME_H #include //#include -#include #include +#include #ifdef _OPENMP #include #endif #include -#include -#include #include +#include #include +#include #pragma once #ifdef __CUDACC__ @@ -38,9 +37,9 @@ #include "legacy_ops.h" -//an op for the kernel +// an op for the kernel namespace functions { - namespace reduce { +namespace reduce { /** * A reduce function @@ -49,136 +48,165 @@ namespace functions { * via aggregating member * elements. */ - template - class ReduceSameFunction { - public: +template +class ReduceSameFunction { + public: #ifdef __CUDACC__ - template - static __device__ void aggregatePartials(void *sPartials, Nd4jLong tid, Nd4jLong numItems, void *extraParams); - - template - static __device__ void execScalarCuda( void const* vx, Nd4jLong const *xShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo); - - static __device__ void execScalarCudaLegacy(int opNum, void const* vx, Nd4jLong const* xShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo); - - template - static __device__ void transformCudaXD( void const* vx, Nd4jLong const* xShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets); - - template - static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo); - - template - static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets); - - static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo); - - static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets); + template + static __device__ void aggregatePartials(void* sPartials, Nd4jLong tid, + Nd4jLong numItems, + void* extraParams); + + template + static __device__ void execScalarCuda(void const* vx, + Nd4jLong const* xShapeInfo, + void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, + void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo); + + static __device__ void execScalarCudaLegacy(int opNum, void const* vx, + Nd4jLong const* xShapeInfo, + void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, + void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo); + + template + static __device__ void transformCudaXD( + void const* vx, Nd4jLong const* xShapeInfo, void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, int dimensionLength, + void* reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets); + + template + static __host__ void intermediateScalar( + dim3 launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong const* hZShapeInfo, int* dimension, int dimensionLength, + void* reductionBuffer, Nd4jLong const* tadOnlyShapeInfo); + + template + static __host__ void intermediateXD( + dim3 launchDims, cudaStream_t* stream, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong const* hZShapeInfo, int* dimension, int dimensionLength, + void* reductionPointer, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets); + + static __host__ void execReduceScalar( + dim3 launchDims, cudaStream_t* stream, int opNum, void const* vx, + Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong const* hZShapeInfo, int* dimension, int dimensionLength, + void* reductionBuffer, Nd4jLong const* tadOnlyShapeInfo); + + static __host__ void execReduceXD( + dim3 launchDims, cudaStream_t* stream, int opNum, int rank, + void const* vx, Nd4jLong const* xShapeInfo, Nd4jLong const* hXShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong const* hZShapeInfo, int* dimension, int dimensionLength, + void* reductionPointer, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets); #else - /** - * Reduce down to 1 number - * @param x the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static _CUDA_H X execScalar(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams); - - template - static _CUDA_H void execScalar(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo); - - - static X execScalar(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams); - - static void execScalar(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo); - - static void exec(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * Execute on the cpu - * @param x the input data - * @param xShapeInfo the shape information for x - * @param extraParams the extra parameters - * @param result the result buffer - * @param resultShapeInfoBuffer the shape information - * @param dimension the dimension to perform - * the reduce along long - * @param dimensionLength the length of the dimension buffer - */ - - - template - static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, - int64_t start, int64_t stop); - - /** - * CPU implementation - * @param x the input data - * @param xShapeInfo the shape information for - * the input data - * @param extraParams the extra parameters for the problem - * @param result the result buffer - * @param resultShapeInfo the shape information - */ - template - static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *result, const Nd4jLong *resultShapeInfo); - - - - /** - * Reduce down to 1 number - * @param x the input - * @param xShapeInfo the shape information - * for the input - * @param extraParams the extra params - * @return - */ - template - static X _CUDA_H execScalar(const void *x, Nd4jLong xElementWiseStride, - Nd4jLong length, - void *extraParams); + /** + * Reduce down to 1 number + * @param x the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static _CUDA_H X execScalar(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams); + + template + static _CUDA_H void execScalar(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo); + + static X execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams); + + static void execScalar(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, + const Nd4jLong *zShapeInfo); + + static void exec(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, int64_t stop); + + /** + * Execute on the cpu + * @param x the input data + * @param xShapeInfo the shape information for x + * @param extraParams the extra parameters + * @param result the result buffer + * @param resultShapeInfoBuffer the shape information + * @param dimension the dimension to perform + * the reduce along long + * @param dimensionLength the length of the dimension buffer + */ + + template + static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfoBuffer, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffset, int64_t start, + int64_t stop); + + /** + * CPU implementation + * @param x the input data + * @param xShapeInfo the shape information for + * the input data + * @param extraParams the extra parameters for the problem + * @param result the result buffer + * @param resultShapeInfo the shape information + */ + template + static void _CUDA_H exec(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *result, + const Nd4jLong *resultShapeInfo); + + /** + * Reduce down to 1 number + * @param x the input + * @param xShapeInfo the shape information + * for the input + * @param extraParams the extra params + * @return + */ + template + static X _CUDA_H execScalar(const void *x, Nd4jLong xElementWiseStride, + Nd4jLong length, void *extraParams); #endif - }; +}; #ifdef __CUDACC__ - /** - * - * @param extraParams - * @param sPartials - * @param sMemSize - */ - template - __device__ void initializeShared(T *extraParams, T **sPartials, int sMemSize); +/** + * + * @param extraParams + * @param sPartials + * @param sMemSize + */ +template +__device__ void initializeShared(T* extraParams, T** sPartials, int sMemSize); #endif - } +} // namespace reduce -} +} // namespace functions #endif - diff --git a/libnd4j/include/loops/scalar.h b/libnd4j/include/loops/scalar.h old mode 100755 new mode 100644 index f7333d57de6f..c50b8c4764f2 --- a/libnd4j/include/loops/scalar.h +++ b/libnd4j/include/loops/scalar.h @@ -23,9 +23,9 @@ #ifndef SCALAR_H_ #define SCALAR_H_ +#include #include #include -#include #ifdef __JNI__ #include @@ -33,6 +33,7 @@ #include #include #include + #include "helpers/logger.h" #ifdef __CUDACC__ @@ -44,142 +45,117 @@ #include "legacy_ops.h" namespace functions { - namespace scalar { +namespace scalar { /** * Apply a scalar * operation to an array */ - template - class ScalarTransform { - - public: - +template +class ScalarTransform { + public: #ifdef __CUDACC__ - template - __host__ - static void intermediateShaped(dim3& launchDims, cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *hxShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *hzShapeInfo, - const void* vscalar, - void *vextraParams, - int *allocPointer); - - template - __host__ - static void intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - __host__ - static void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, const Nd4jLong *hxShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, const Nd4jLong *hzShapeInfo, - const void* scalar, - void *extraParams); - - __host__ - static void executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); + template + __host__ static void intermediateShaped( + dim3 &launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const Nd4jLong *hxShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, const Nd4jLong *hzShapeInfo, + const void *vscalar, void *vextraParams, int *allocPointer); + + template + __host__ static void intermediateAlongDimension( + dim3 &launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + __host__ static void executeCudaShaped( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, const Nd4jLong *hxShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, const Nd4jLong *hzShapeInfo, + const void *scalar, void *extraParams); + + __host__ static void executeCudaAlongDimension( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); #else - template - static void transform(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void *scalar, - void *extraParams, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, Nd4jLong xStride, - void *result, Nd4jLong resultStride, - const void *scalar, - void *extraParams, - uint64_t len, uint64_t start, uint64_t stop); - - - - - /* - * ScalarOp along dimension - */ - - - /** - * CPU implementation of scalar operation - * @param x the input - * @param xStride the stride for the input - * @param result the result buffer - * @param resultStride the stride for the result - * @param scalar the scalar to apply - * @param extraParams the extra parameters where - * neccssary - * @param len the number of elements to loop over - */ - - template - static void transform(const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void *scalar, - void *extraParams, - uint64_t start, uint64_t stop); - - - /** - * CPU implementation of scalar operation - * @param x the input - * @param xStride the stride for the input - * @param result the result buffer - * @param resultStride the stride for the result - * @param scalar the scalar to apply - * @param extraParams the extra parameters where - * neccssary - * @param len the number of elements to loop over - */ - - template - static void transform(const void *x, Nd4jLong xStride, - void *result, Nd4jLong resultStride, - const void *scalar, - void *extraParams, - uint64_t len, uint64_t start, uint64_t stop); + template + static void transform(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + const void *scalar, void *extraParams, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, Nd4jLong xStride, + void *result, Nd4jLong resultStride, const void *scalar, + void *extraParams, uint64_t len, uint64_t start, + uint64_t stop); + + /* + * ScalarOp along dimension + */ + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param len the number of elements to loop over + */ + + template + static void transform(const void *x, const Nd4jLong *xShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, const void *scalar, + void *extraParams, uint64_t start, uint64_t stop); + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param len the number of elements to loop over + */ + + template + static void transform(const void *x, Nd4jLong xStride, void *result, + Nd4jLong resultStride, const void *scalar, + void *extraParams, uint64_t len, uint64_t start, + uint64_t stop); #endif - }; - } -} - +}; +} // namespace scalar +} // namespace functions #endif /* SCALAR_H_ */ diff --git a/libnd4j/include/loops/scalar_bool.h b/libnd4j/include/loops/scalar_bool.h index 4992df5a16b7..e1e20973081c 100644 --- a/libnd4j/include/loops/scalar_bool.h +++ b/libnd4j/include/loops/scalar_bool.h @@ -28,12 +28,13 @@ #ifdef __JNI__ #include #endif +#include +#include #include #include #include + #include "helpers/logger.h" -#include -#include #ifdef __CUDACC__ #include @@ -44,170 +45,139 @@ #include "legacy_ops.h" namespace functions { - namespace scalar { +namespace scalar { /** * Apply a scalar * operation to an array */ - template - class ScalarBoolTransform { - - public: - +template +class ScalarBoolTransform { + public: #ifdef __CUDACC__ - - template - __device__ - static void transformCuda(const void* scalar, - const void *vy, const Nd4jLong *shapeInfo, - void *vparams, - void *vresult, const Nd4jLong *resultShapeInfo, - int *allocationBuffer); - - template - __device__ - static void transformCuda(Nd4jLong n, - const void* vx, const void *vy, Nd4jLong yEWS, - void *vparams, - void *vz, Nd4jLong zEWS, - int *allocationBuffer); - - template - __device__ - static void transformCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - __host__ - static void intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - __host__ - static void intermediateShaped(dim3& launchDims, cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const void* vscalar, - void *vextraParams, - int *allocPointer); - - __host__ - static void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void* scalar, - const void *extraParams); - - __host__ - static void executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); + + template + __device__ static void transformCuda(const void *scalar, const void *vy, + const Nd4jLong *shapeInfo, void *vparams, + void *vresult, + const Nd4jLong *resultShapeInfo, + int *allocationBuffer); + + template + __device__ static void transformCuda(Nd4jLong n, const void *vx, + const void *vy, Nd4jLong yEWS, + void *vparams, void *vz, Nd4jLong zEWS, + int *allocationBuffer); + + template + __device__ static void transformCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, const void *vscalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + __host__ static void intermediateAlongDimension( + dim3 &launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + __host__ static void intermediateShaped( + dim3 &launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + const void *vscalar, void *vextraParams, int *allocPointer); + + __host__ static void executeCudaShaped( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *result, const Nd4jLong *resultShapeInfo, + const void *scalar, const void *extraParams); + + __host__ static void executeCudaAlongDimension( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); /* #include "cuda/scalar_temp.cu" */ #else - template - static void transform(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void *scalar, - void *extraParams, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, Nd4jLong xStride, - void *result, Nd4jLong resultStride, - const void *scalar, - void *extraParams, - uint64_t n, uint64_t start, uint64_t stop); - - - - - /* - * ScalarOp along dimension - */ - - - /** - * CPU implementation of scalar operation - * @param x the input - * @param xStride the stride for the input - * @param result the result buffer - * @param resultStride the stride for the result - * @param scalar the scalar to apply - * @param extraParams the extra parameters where - * neccssary - * @param n the number of elements to loop over - */ - - template - static void transform(const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void *scalar, - void *extraParams, - uint64_t start, uint64_t stop); - - - /** - * CPU implementation of scalar operation - * @param x the input - * @param xStride the stride for the input - * @param result the result buffer - * @param resultStride the stride for the result - * @param scalar the scalar to apply - * @param extraParams the extra parameters where - * neccssary - * @param n the number of elements to loop over - */ - - template - static void transform(const void *x, Nd4jLong xStride, - void *result, Nd4jLong resultStride, - const void *scalar, void *extraParams, - uint64_t n, uint64_t start, uint64_t stop); + template + static void transform(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + const void *scalar, void *extraParams, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, Nd4jLong xStride, + void *result, Nd4jLong resultStride, const void *scalar, + void *extraParams, uint64_t n, uint64_t start, + uint64_t stop); + + /* + * ScalarOp along dimension + */ + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param n the number of elements to loop over + */ + + template + static void transform(const void *x, const Nd4jLong *xShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, const void *scalar, + void *extraParams, uint64_t start, uint64_t stop); + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param n the number of elements to loop over + */ + + template + static void transform(const void *x, Nd4jLong xStride, void *result, + Nd4jLong resultStride, const void *scalar, + void *extraParams, uint64_t n, uint64_t start, + uint64_t stop); #endif - }; - } -} - +}; +} // namespace scalar +} // namespace functions #endif /* SCALAR_H_ */ diff --git a/libnd4j/include/loops/scalar_int.h b/libnd4j/include/loops/scalar_int.h index c3a53199efcb..c7a39da3be96 100644 --- a/libnd4j/include/loops/scalar_int.h +++ b/libnd4j/include/loops/scalar_int.h @@ -28,12 +28,13 @@ #ifdef __JNI__ #include #endif +#include +#include #include #include #include + #include "helpers/logger.h" -#include -#include #ifdef __CUDACC__ #include @@ -44,169 +45,138 @@ #include "legacy_ops.h" namespace functions { - namespace scalar { +namespace scalar { /** * Apply a scalar * operation to an array */ - template - class ScalarIntTransform { - - public: - +template +class ScalarIntTransform { + public: #ifdef __CUDACC__ - - template - __device__ - static void transformCuda(const void* scalar, - const void *vy, const Nd4jLong *shapeInfo, - void *vparams, - void *vresult, const Nd4jLong *resultShapeInfo, - int *allocationBuffer); - - template - __device__ - static void transformCuda(Nd4jLong n, - const void* vx, const void *vy, Nd4jLong yEWS, - void *vparams, - void *vz, Nd4jLong zEWS, - int *allocationBuffer); - - template - __device__ - static void transformCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vextraParams, - void *vz, const Nd4jLong *zShapeInfo, - const void *vscalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - __host__ - static void intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); - - template - __host__ - static void intermediateShaped(dim3& launchDims, cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const void* vscalar, - void *vextraParams, - int *allocPointer); - - __host__ - static void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void* scalar, - void *extraParams); - - __host__ - static void executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - void *extraParams, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ); + + template + __device__ static void transformCuda(const void *scalar, const void *vy, + const Nd4jLong *shapeInfo, void *vparams, + void *vresult, + const Nd4jLong *resultShapeInfo, + int *allocationBuffer); + + template + __device__ static void transformCuda(Nd4jLong n, const void *vx, + const void *vy, Nd4jLong yEWS, + void *vparams, void *vz, Nd4jLong zEWS, + int *allocationBuffer); + + template + __device__ static void transformCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, + const Nd4jLong *zShapeInfo, const void *vscalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + __host__ static void intermediateAlongDimension( + dim3 &launchDims, cudaStream_t *stream, const void *x, + const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); + + template + __host__ static void intermediateShaped( + dim3 &launchDims, cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, + const void *vscalar, void *vextraParams, int *allocPointer); + + __host__ static void executeCudaShaped(dim3 &launchDims, cudaStream_t *stream, + int opNum, const void *x, + const Nd4jLong *xShapeInfo, + void *result, + const Nd4jLong *resultShapeInfo, + const void *scalar, void *extraParams); + + __host__ static void executeCudaAlongDimension( + dim3 &launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShapeInfo, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, void *extraParams, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ); #else - template - static void transform(const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *z, const Nd4jLong *zShapeInfo, - const void *scalars, - int *dimension, int dimensionLength, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, - const Nd4jLong *tadShapeInfoZ, const Nd4jLong *tadOffsetsZ, - uint64_t start, uint64_t stop); - - static void transform(int opNum, - const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void *scalar, - void *extraParams, - uint64_t start, - uint64_t stop); - - static void transform(int opNum, - const void *x, Nd4jLong xStride, - void *result, Nd4jLong resultStride, - const void *scalar, - void *extraParams, - uint64_t n, uint64_t start, uint64_t stop); - - - - - /* - * ScalarOp along dimension - */ - - - /** - * CPU implementation of scalar operation - * @param x the input - * @param xStride the stride for the input - * @param result the result buffer - * @param resultStride the stride for the result - * @param scalar the scalar to apply - * @param extraParams the extra parameters where - * neccssary - * @param n the number of elements to loop over - */ - - template - static void transform(const void *x, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - const void *scalar, - void *extraParams, - uint64_t start, uint64_t stop); - - - /** - * CPU implementation of scalar operation - * @param x the input - * @param xStride the stride for the input - * @param result the result buffer - * @param resultStride the stride for the result - * @param scalar the scalar to apply - * @param extraParams the extra parameters where - * neccssary - * @param n the number of elements to loop over - */ - - template - static void transform(const void *x, Nd4jLong xStride, - void *result, Nd4jLong resultStride, - const void *scalar, - void *extraParams, - uint64_t n, uint64_t start, uint64_t stop); + template + static void transform(const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *extraParams, void *z, const Nd4jLong *zShapeInfo, + const void *scalars, int *dimension, + int dimensionLength, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, + const Nd4jLong *tadShapeInfoZ, + const Nd4jLong *tadOffsetsZ, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + const void *scalar, void *extraParams, uint64_t start, + uint64_t stop); + + static void transform(int opNum, const void *x, Nd4jLong xStride, + void *result, Nd4jLong resultStride, const void *scalar, + void *extraParams, uint64_t n, uint64_t start, + uint64_t stop); + + /* + * ScalarOp along dimension + */ + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param n the number of elements to loop over + */ + + template + static void transform(const void *x, const Nd4jLong *xShapeInfo, void *result, + const Nd4jLong *resultShapeInfo, const void *scalar, + void *extraParams, uint64_t start, uint64_t stop); + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param n the number of elements to loop over + */ + + template + static void transform(const void *x, Nd4jLong xStride, void *result, + Nd4jLong resultStride, const void *scalar, + void *extraParams, uint64_t n, uint64_t start, + uint64_t stop); #endif - }; - } -} - +}; +} // namespace scalar +} // namespace functions #endif /* SCALAR_H_ */ diff --git a/libnd4j/include/loops/special_kernels.h b/libnd4j/include/loops/special_kernels.h index 209d35120781..b84366cfe4b0 100644 --- a/libnd4j/include/loops/special_kernels.h +++ b/libnd4j/include/loops/special_kernels.h @@ -21,84 +21,134 @@ #ifndef LIBND4J_SPECIAL_KERNELS_H #define LIBND4J_SPECIAL_KERNELS_H -#include - -#include -#include -#include #include -#include -#include #include +#include +#include #include #include +#include +#include +#include +#include namespace sd { - template - _CUDA_H void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, const Nd4jLong *xShapeInfo, Nd4jLong length, long idx); - - template - _CUDA_H void fillDimensionalIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, const void *dX, void *dZ, const Nd4jLong *zShapeInfo, const Nd4jLong *tadOnlyShapeInfo, int *dimension, int dimensionLength, const Nd4jLong *tadOffsets); - - template - _CUDA_H void convertToHalfGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong n, half *dz); - - template - _CUDA_H void tearKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, Nd4jPointer *targets, - Nd4jLong const* zShapeInfo, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets); - - template - _CUDA_H void shuffleKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void **vdX, Nd4jLong **xShapeInfo, void **vdZ, int N, - int *shuffleMap, Nd4jLong** tadOnlyShapeInfo, Nd4jLong** tadOffsets); - - template - _CUDA_H void convertHalfsToGeneric(dim3 &launchDims, cudaStream_t *stream, half *dx, Nd4jLong n, void *dz); - - template - _CUDA_H void concatKernelVStackGeneric(dim3 &launchDims, cudaStream_t *stream, int numArrays, Nd4jPointer *data, - Nd4jPointer *inputShapeInfos, void *vz, Nd4jLong const* zShapeInfo); - - template - _CUDA_H void concatKernelScalarGeneric(dim3 &launchDims, cudaStream_t *stream, int numArrays, Nd4jPointer *data, void *vresult); - - template - _CUDA_H void concatKernelHStackGeneric(dim3 &launchDims, cudaStream_t *stream, int numArrays, Nd4jPointer *data, - Nd4jPointer *inputShapeInfos, void *vresult, Nd4jLong const* resultShapeInfo); - - template - _CUDA_H void concatKernelGeneric(dim3 &launchDims, cudaStream_t *stream, int numArrays, Nd4jPointer *data, - Nd4jPointer *inputShapeInfos, void *vresult, Nd4jLong const* resultShapeInfo, - Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers, Nd4jLong const* zTadShape, Nd4jLong const* zOffsets); - - template - _CUDA_H void pullRowsKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, void *vz, Nd4jLong n, Nd4jLong *indexes, - Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets); - - template - _CUDA_H void averagingKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void **vdx, void *vdz, int n, Nd4jLong length, bool propagate); - - template - _CUDA_H void accumulateKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void **vx, void *vz, int n, const Nd4jLong length); - - template - _CUDA_H void flattenKernelGeneric(dim3& launchDims, cudaStream_t *stream, Nd4jPointer *extraPointers, int dOffset, char order, void *vz, Nd4jLong *zShapeInfo, void *vy, Nd4jLong *yShapeInfo); - - template - _CUDA_H void tileKernelH(void const* inputBuffer, Nd4jLong const* inputShape, void* outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, cudaStream_t *stream); - template - _CUDA_H void tileKernelHH(void const* inputBuffer, Nd4jLong const* inputShape, void* outputBuffer, Nd4jLong const* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream); - - class NDArray; - template - _CUDA_H void setDiagonalValueUpper(void* buffer, Nd4jLong const* shape, NDArray const& value, int diagonal, Nd4jLong rows, Nd4jLong cols, cudaStream_t& stream); - - template - _CUDA_H void setDiagonalValueLower(void* buffer, Nd4jLong const* shape, NDArray const& value, int diagonal, Nd4jLong rows, Nd4jLong cols, cudaStream_t& stream); - - template - _CUDA_H void templatedSwapUnsafe(void* theFirstBuffer, Nd4jLong const* theFirstShape, void* theSecondBuffer, Nd4jLong const* theSecondShape, cudaStream_t* theStream); - -} +template +_CUDA_H void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, + const Nd4jLong *xShapeInfo, Nd4jLong length, + long idx); + +template +_CUDA_H void fillDimensionalIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, + const void *dX, void *dZ, + const Nd4jLong *zShapeInfo, + const Nd4jLong *tadOnlyShapeInfo, + int *dimension, int dimensionLength, + const Nd4jLong *tadOffsets); + +template +_CUDA_H void convertToHalfGeneric(dim3 &launchDims, cudaStream_t *stream, + void *dx, Nd4jLong n, half *dz); + +template +_CUDA_H void tearKernelGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, Nd4jPointer *targets, + Nd4jLong const *zShapeInfo, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets); + +template +_CUDA_H void shuffleKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + void **vdX, Nd4jLong **xShapeInfo, void **vdZ, + int N, int *shuffleMap, + Nd4jLong **tadOnlyShapeInfo, + Nd4jLong **tadOffsets); + +template +_CUDA_H void convertHalfsToGeneric(dim3 &launchDims, cudaStream_t *stream, + half *dx, Nd4jLong n, void *dz); + +template +_CUDA_H void concatKernelVStackGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vz, + Nd4jLong const *zShapeInfo); + +template +_CUDA_H void concatKernelScalarGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + void *vresult); + +template +_CUDA_H void concatKernelHStackGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, + void *vresult, + Nd4jLong const *resultShapeInfo); + +template +_CUDA_H void concatKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfos, void *vresult, + Nd4jLong const *resultShapeInfo, + Nd4jPointer *tadPointers, + Nd4jPointer *offsetPointers, + Nd4jLong const *zTadShape, + Nd4jLong const *zOffsets); + +template +_CUDA_H void pullRowsKernelGeneric( + dim3 &launchDims, cudaStream_t *stream, void *vx, void *vz, Nd4jLong n, + Nd4jLong *indexes, Nd4jLong const *tadShapeInfo, Nd4jLong const *tadOffsets, + Nd4jLong const *zTadShapeInfo, Nd4jLong const *zTadOffsets); + +template +_CUDA_H void averagingKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + void **vdx, void *vdz, int n, + Nd4jLong length, bool propagate); + +template +_CUDA_H void accumulateKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + void **vx, void *vz, int n, + const Nd4jLong length); + +template +_CUDA_H void flattenKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + Nd4jPointer *extraPointers, int dOffset, + char order, void *vz, Nd4jLong *zShapeInfo, + void *vy, Nd4jLong *yShapeInfo); + +template +_CUDA_H void tileKernelH(void const *inputBuffer, Nd4jLong const *inputShape, + void *outputBuffer, Nd4jLong const *outputShape, + Nd4jLong resultLength, cudaStream_t *stream); +template +_CUDA_H void tileKernelHH(void const *inputBuffer, Nd4jLong const *inputShape, + void *outputBuffer, Nd4jLong const *outputShape, + Nd4jLong resultLength, Nd4jLong ews, + cudaStream_t *stream); + +class NDArray; +template +_CUDA_H void setDiagonalValueUpper(void *buffer, Nd4jLong const *shape, + NDArray const &value, int diagonal, + Nd4jLong rows, Nd4jLong cols, + cudaStream_t &stream); + +template +_CUDA_H void setDiagonalValueLower(void *buffer, Nd4jLong const *shape, + NDArray const &value, int diagonal, + Nd4jLong rows, Nd4jLong cols, + cudaStream_t &stream); + +template +_CUDA_H void templatedSwapUnsafe(void *theFirstBuffer, + Nd4jLong const *theFirstShape, + void *theSecondBuffer, + Nd4jLong const *theSecondShape, + cudaStream_t *theStream); + +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/loops/summarystatsreduce.h b/libnd4j/include/loops/summarystatsreduce.h old mode 100755 new mode 100644 index 1ab06a11bd89..c2601df33283 --- a/libnd4j/include/loops/summarystatsreduce.h +++ b/libnd4j/include/loops/summarystatsreduce.h @@ -23,15 +23,14 @@ #ifndef SUMMARYSTATSREDUCE_H_ #define SUMMARYSTATSREDUCE_H_ +#include #include #include - -#include #ifdef __CUDACC__ #include #include -#define host_and_device inline __host__ __device__ +#define host_and_device inline __host__ __device__ #else #define host_and_device inline #endif @@ -46,288 +45,274 @@ #include "legacy_ops.h" namespace functions { - namespace summarystats { - - // This example computes several statistical properties of a data - // series in a single reduction. The algorithm is described in detail here: - // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm - // - // Thanks to Joseph Rhoads for contributing this example - - - // structure used to accumulate the moments and other - // statistical properties encountered so far. - template - class SummaryStatsData { - - public: - double n; - double min; - double max; - double mean; - double M2; - double M3; - double M4; - double bias; - - _CUDA_HD SummaryStatsData() { - initialize(); - } - - // initialize to the identity element - - _CUDA_HD void initialize() { - n = mean = M2 = M3 = M4 = bias = 0; - } - - _CUDA_HD void initWithValue(X val) { - n = 1; - min = val; - max = val; - mean = val; - M2 = 0; - M3 = 0; - M4 = 0; - bias = 0; - } - - _CUDA_HD void setValues(SummaryStatsData *target) { - n = target->n; - min = target->min; - max = target->max; - mean = target->mean; - M2 = target->M2; - M3 = target->M3; - M4 = target->M4; - bias = target->bias; - } - - _CUDA_HD double variance() { - if (n <= 1.0) - return 0.0; - return M2 / (n); - } - - _CUDA_HD double varianceBiasCorrected() { - if (this->n <= 1.0) { - return 0.0; - } - - return M2 / (n - 1.0); - } - - - _CUDA_HD double variance_n() { - if (n <= 1.0) - return 0.0; - return M2 / n; - } - - _CUDA_HD double skewness() { return M2 > 0.0 ? sd::math::nd4j_sqrt(n) * M3 / sd::math::nd4j_pow(M2, 1.5) : 0.0; } - - _CUDA_HD double kurtosis() { return M2 > 0.0 ? n * M4 / (M2 * M2) : 0; } - - _CUDA_HD double getM2() { - return M2; - } - - _CUDA_HD void setM2(X m2) { - M2 = m2; - } - - _CUDA_HD double getM3() { - return M3; - } - - _CUDA_HD void setM3(X m3) { - M3 = m3; - } - - _CUDA_HD double getM4() { - return M4; - } - - _CUDA_HD void setM4(X m4) { - M4 = m4; - } - - _CUDA_HD double getMax() { - return max; - } - - _CUDA_HD void setMax(X max) { - this->max = max; - } - - _CUDA_HD double getMean() { - return mean; - } - - _CUDA_HD void setMean(X mean) { - this->mean = mean; - } - - _CUDA_HD double getMin() { - return min; - } - - _CUDA_HD void setMin(X min) { - this->min = min; - } - - _CUDA_HD double getN() { - return n; - } - - _CUDA_HD void setN(X n) { - this->n = n; - } - }; +namespace summarystats { + +// This example computes several statistical properties of a data +// series in a single reduction. The algorithm is described in detail here: +// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm +// +// Thanks to Joseph Rhoads for contributing this example + +// structure used to accumulate the moments and other +// statistical properties encountered so far. +template +class SummaryStatsData { + public: + double n; + double min; + double max; + double mean; + double M2; + double M3; + double M4; + double bias; + + _CUDA_HD SummaryStatsData() { initialize(); } + + // initialize to the identity element + + _CUDA_HD void initialize() { n = mean = M2 = M3 = M4 = bias = 0; } + + _CUDA_HD void initWithValue(X val) { + n = 1; + min = val; + max = val; + mean = val; + M2 = 0; + M3 = 0; + M4 = 0; + bias = 0; + } + + _CUDA_HD void setValues(SummaryStatsData* target) { + n = target->n; + min = target->min; + max = target->max; + mean = target->mean; + M2 = target->M2; + M3 = target->M3; + M4 = target->M4; + bias = target->bias; + } + + _CUDA_HD double variance() { + if (n <= 1.0) return 0.0; + return M2 / (n); + } + + _CUDA_HD double varianceBiasCorrected() { + if (this->n <= 1.0) { + return 0.0; + } -#ifdef __CUDACC__ - // This is the un-specialized struct. Note that we prevent instantiation of this -// struct by putting an undefined symbol in the function body so it won't compile. - template - struct SharedSummaryStatsData { - // Ensure that we won't compile any un-specialized types - __device__ T * getPointer() { - extern __device__ void error(void); - error(); - return 0; - } - }; - - // Following are the specializations for the following types. - // int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, and double - // One could also specialize it for user-defined types. - - template<> - struct SharedSummaryStatsData { - __device__ SummaryStatsData * getPointer() { - extern __shared__ SummaryStatsData s_int2[]; - return s_int2; - } - }; - // Following are the specializations for the following types. - // int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, and double - // One could also specialize it for user-defined types. - - template<> - struct SharedSummaryStatsData { - __device__ SummaryStatsData * getPointer() { - extern __shared__ SummaryStatsData s_int6[]; - return s_int6; - } - }; -#endif + return M2 / (n - 1.0); + } - /** - * Standard deviation or variance 1 pass - */ - template - class SummaryStatsReduce { - public: - //calculate an update of the reduce operation - _CUDA_HD static SummaryStatsData update(SummaryStatsData x, SummaryStatsData y, - void* extraParams) { - if ((long) x.n == 0 && (long) y.n > 0) - return y; - else if ((long) x.n > 0 && (long) y.n == 0) - return x; - SummaryStatsData vz; - double n = x.n + y.n; - double n2 = n * n; - double n3 = n2 * n; - - - double delta = y.mean - x.mean; - double delta2 = delta * delta; - double delta3 = delta2 * delta; - double delta4 = delta3 * delta; - - //Basic number of samples (n), min, and max - vz.n = n; - vz.min = sd::math::nd4j_min(x.min, y.min); - vz.max = sd::math::nd4j_max(x.max, y.max); - double meanD = x.mean + delta * y.n / n; - vz.mean = meanD; - double M2D = x.M2 + y.M2; - M2D += delta2 * x.n * y.n / n; - vz.M2 = M2D; - vz.M3 = x.M3 + y.M3; - vz.M3 += delta3 * x.n * y.n * (x.n - y.n) / n2; - vz.M3 += 3.0 * delta * (x.n * y.M2 - y.n * x.M2) / n; - - vz.M4 = x.M4 + y.M4; - vz.M4 += delta4 * x.n * y.n * (x.n * x.n - x.n * y.n + y.n * y.n) / n3; - vz.M4 += 6.0 * delta2 * (x.n * x.n * y.M2 + y.n * y.n * x.M2) / n2; - vz.M4 += 4.0 * delta * (x.n * y.M3 - y.n * x.M3) / n; - - return vz; - } + _CUDA_HD double variance_n() { + if (n <= 1.0) return 0.0; + return M2 / n; + } + _CUDA_HD double skewness() { + return M2 > 0.0 ? sd::math::nd4j_sqrt(n) * M3 / + sd::math::nd4j_pow(M2, 1.5) + : 0.0; + } + _CUDA_HD double kurtosis() { return M2 > 0.0 ? n * M4 / (M2 * M2) : 0; } -#ifdef __CUDACC__ + _CUDA_HD double getM2() { return M2; } - static inline _CUDA_D Z startingValue(X const* input) { - return static_cast(0); - } + _CUDA_HD void setM2(X m2) { M2 = m2; } - template - static _CUDA_D void aggregatePartials(SummaryStatsData **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, void *extraParams); + _CUDA_HD double getM3() { return M3; } + _CUDA_HD void setM3(X m3) { M3 = m3; } - template - static _CUDA_D void transform(void const* dx, Nd4jLong const* xShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets); + _CUDA_HD double getM4() { return M4; } - static _CUDA_D void transform(const int opNum, void const* dx, Nd4jLong const* xShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong const* tadOnlyShapeInfo, Nd4jLong const* tadOffsets); + _CUDA_HD void setM4(X m4) { M4 = m4; } - static _CUDA_H void execSummaryStatsReduceScalar(dim3& launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool biasCorrected, void *reductionBuffer); - static _CUDA_H void execSummaryStatsReduce(dim3& launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool biasCorrected, void *reductionBuffer); - static _CUDA_H void execSummaryStatsReduce(dim3& launchDims, cudaStream_t *stream, int opNum, void const* x, Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, void *extraParams, void *vz, Nd4jLong const* zShapeInfo, Nd4jLong const* hzShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool biasCorrected, void *reductionBuffer); -#else + _CUDA_HD double getMax() { return max; } + + _CUDA_HD void setMax(X max) { this->max = max; } + + _CUDA_HD double getMean() { return mean; } + + _CUDA_HD void setMean(X mean) { this->mean = mean; } + + _CUDA_HD double getMin() { return min; } + + _CUDA_HD void setMin(X min) { this->min = min; } + + _CUDA_HD double getN() { return n; } + + _CUDA_HD void setN(X n) { this->n = n; } +}; - static Z execScalar(int opNum, - bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams); - - static void execScalar(int opNum, - bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *resultShapeInfoBuffer); - - static void exec(int opNum, - bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength); - - template - static Z execScalar(bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams); - - template - static void execScalar(bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *resultShapeInfoBuffer); - - - template - static void exec(bool biasCorrected, - const void *x, const Nd4jLong *xShapeInfo, - void *extraParams, - void *vz, const Nd4jLong *resultShapeInfoBuffer, - int *dimension, int dimensionLength); +#ifdef __CUDACC__ +// This is the un-specialized struct. Note that we prevent instantiation of +// this +// struct by putting an undefined symbol in the function body so it won't +// compile. +template +struct SharedSummaryStatsData { + // Ensure that we won't compile any un-specialized types + __device__ T* getPointer() { + extern __device__ void error(void); + error(); + return 0; + } +}; + +// Following are the specializations for the following types. +// int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, +// and double One could also specialize it for user-defined types. + +template <> +struct SharedSummaryStatsData { + __device__ SummaryStatsData* getPointer() { + extern __shared__ SummaryStatsData s_int2[]; + return s_int2; + } +}; +// Following are the specializations for the following types. +// int, uint, char, uchar, short, ushort, long long, ulong long, bool, float, +// and double One could also specialize it for user-defined types. + +template <> +struct SharedSummaryStatsData { + __device__ SummaryStatsData* getPointer() { + extern __shared__ SummaryStatsData s_int6[]; + return s_int6; + } +}; #endif - }; - } -} +/** + * Standard deviation or variance 1 pass + */ +template +class SummaryStatsReduce { + public: + // calculate an update of the reduce operation + _CUDA_HD static SummaryStatsData update(SummaryStatsData x, + SummaryStatsData y, + void* extraParams) { + if ((long)x.n == 0 && (long)y.n > 0) + return y; + else if ((long)x.n > 0 && (long)y.n == 0) + return x; + SummaryStatsData vz; + double n = x.n + y.n; + double n2 = n * n; + double n3 = n2 * n; + + double delta = y.mean - x.mean; + double delta2 = delta * delta; + double delta3 = delta2 * delta; + double delta4 = delta3 * delta; + + // Basic number of samples (n), min, and max + vz.n = n; + vz.min = sd::math::nd4j_min(x.min, y.min); + vz.max = sd::math::nd4j_max(x.max, y.max); + double meanD = x.mean + delta * y.n / n; + vz.mean = meanD; + double M2D = x.M2 + y.M2; + M2D += delta2 * x.n * y.n / n; + vz.M2 = M2D; + vz.M3 = x.M3 + y.M3; + vz.M3 += delta3 * x.n * y.n * (x.n - y.n) / n2; + vz.M3 += 3.0 * delta * (x.n * y.M2 - y.n * x.M2) / n; + + vz.M4 = x.M4 + y.M4; + vz.M4 += delta4 * x.n * y.n * (x.n * x.n - x.n * y.n + y.n * y.n) / n3; + vz.M4 += 6.0 * delta2 * (x.n * x.n * y.M2 + y.n * y.n * x.M2) / n2; + vz.M4 += 4.0 * delta * (x.n * y.M3 - y.n * x.M3) / n; + + return vz; + } + +#ifdef __CUDACC__ + + static inline _CUDA_D Z startingValue(X const* input) { + return static_cast(0); + } + + template + static _CUDA_D void aggregatePartials(SummaryStatsData** sPartialsRef, + Nd4jLong tid, Nd4jLong numElements, + void* extraParams); + + template + static _CUDA_D void transform(void const* dx, Nd4jLong const* xShapeInfo, + void* extraParams, void* vz, + Nd4jLong const* zShapeInfo, int* dimension, + int dimensionLength, int postProcessOrNot, + int* allocationBuffer, void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets); + + static _CUDA_D void transform(const int opNum, void const* dx, + Nd4jLong const* xShapeInfo, void* extraParams, + void* vz, Nd4jLong const* zShapeInfo, + int* dimension, int dimensionLength, + int postProcessOrNot, int* allocationBuffer, + void* reductionBuffer, + Nd4jLong const* tadOnlyShapeInfo, + Nd4jLong const* tadOffsets); + + static _CUDA_H void execSummaryStatsReduceScalar( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong const* hzShapeInfo, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, bool biasCorrected, void* reductionBuffer); + static _CUDA_H void execSummaryStatsReduce( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong const* hzShapeInfo, Nd4jLong const* tadShapeInfo, + Nd4jLong const* tadOffsets, bool biasCorrected, void* reductionBuffer); + static _CUDA_H void execSummaryStatsReduce( + dim3& launchDims, cudaStream_t* stream, int opNum, void const* x, + Nd4jLong const* xShapeInfo, Nd4jLong const* hxShapeInfo, + void* extraParams, void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong const* hzShapeInfo, int* dimension, int dimensionLength, + Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, + bool biasCorrected, void* reductionBuffer); +#else + + static Z execScalar(int opNum, bool biasCorrected, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams); + + static void execScalar(int opNum, bool biasCorrected, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams, + void *vz, const Nd4jLong *resultShapeInfoBuffer); + + static void exec(int opNum, bool biasCorrected, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams, void *vz, + const Nd4jLong *resultShapeInfoBuffer, int *dimension, + int dimensionLength); + + template + static Z execScalar(bool biasCorrected, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams); + + template + static void execScalar(bool biasCorrected, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams, + void *vz, const Nd4jLong *resultShapeInfoBuffer); + + template + static void exec(bool biasCorrected, const void *x, + const Nd4jLong *xShapeInfo, void *extraParams, void *vz, + const Nd4jLong *resultShapeInfoBuffer, int *dimension, + int dimensionLength); +#endif +}; +} // namespace summarystats +} // namespace functions #endif /* SUMMARYSTATSREDUCE_H_ */ diff --git a/libnd4j/include/loops/transform_any.h b/libnd4j/include/loops/transform_any.h index 44c0120f4f38..635c956a1038 100644 --- a/libnd4j/include/loops/transform_any.h +++ b/libnd4j/include/loops/transform_any.h @@ -24,15 +24,16 @@ #ifndef TRANSFORM_ANY_H_ #define TRANSFORM_ANY_H_ -#include #include #include +#include + #ifdef _OPENMP #include #endif -#include #include +#include //#include //#include @@ -46,56 +47,54 @@ #include "legacy_ops.h" - namespace functions { namespace transform { -template +template class TransformAny { - public: - + public: #ifdef __CUDACC__ - template - static __device__ void transformCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *params, - void *vz, const Nd4jLong *zShapeInfo, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - template - static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static _CUDA_H void executeTransformShaped(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); + template + static __device__ void transformCuda(const void *vx, + const Nd4jLong *xShapeInfo, void *params, + void *vz, const Nd4jLong *zShapeInfo, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + template + static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, + const void *x, const Nd4jLong *xShape, + int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static _CUDA_H void executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else - static void exec(int opNum, - const void *dx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); - - template - static SD_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); + static void exec(int opNum, const void *dx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, void *extraParams, + uint64_t threadId, uint64_t numThreads); + + template + static SD_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + void *extraParams, uint64_t threadId, + uint64_t numThreads); #endif }; -} -} - +} // namespace transform +} // namespace functions #endif /* TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/transform_bool.h b/libnd4j/include/loops/transform_bool.h index 03ea5cd6816f..421e1695af70 100644 --- a/libnd4j/include/loops/transform_bool.h +++ b/libnd4j/include/loops/transform_bool.h @@ -24,15 +24,16 @@ #ifndef TRANSFORM_BOOL_H_ #define TRANSFORM_BOOL_H_ -#include #include #include +#include + #ifdef _OPENMP #include #endif -#include #include +#include //#include //#include @@ -46,55 +47,51 @@ #include "legacy_ops.h" - namespace functions { - namespace transform { - - template - class TransformBool { - public: +namespace transform { +template +class TransformBool { + public: #ifdef __CUDACC__ - template - static __device__ void transformCuda(const void *dy, const Nd4jLong *shapeInfo, - void *params, - void *result, const Nd4jLong *resultShapeInfo, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - template - static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static _CUDA_H void executeTransformShaped(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); + template + static __device__ void transformCuda( + const void *dy, const Nd4jLong *shapeInfo, void *params, void *result, + const Nd4jLong *resultShapeInfo, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + template + static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, + const void *x, const Nd4jLong *xShape, + int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static _CUDA_H void executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else - static void exec(int opNum, - const void *dx, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); + static void exec(int opNum, const void *dx, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, uint64_t threadId, uint64_t numThreads); - template - static SD_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); + template + static SD_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, uint64_t threadId, + uint64_t numThreads); #endif - }; - } -} - +}; +} // namespace transform +} // namespace functions #endif /* TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/transform_float.h b/libnd4j/include/loops/transform_float.h index 242936c08c89..b31ded716dcf 100644 --- a/libnd4j/include/loops/transform_float.h +++ b/libnd4j/include/loops/transform_float.h @@ -24,15 +24,16 @@ #ifndef TRANSFORM_FLOAT_H_ #define TRANSFORM_FLOAT_H_ -#include #include #include +#include + #ifdef _OPENMP #include #endif -#include #include +#include //#include //#include @@ -46,70 +47,64 @@ #include "legacy_ops.h" - namespace functions { - namespace transform { - - template - class TransformFloat { - public: +namespace transform { +template +class TransformFloat { + public: #ifdef __CUDACC__ - template - static __device__ void transformCuda(const void *dy, const Nd4jLong *shapeInfo, - void *params, - void *result, const Nd4jLong *resultShapeInfo, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static __device__ void transformCudaLegacy(int opNum, - const void *dy, const Nd4jLong *shapeInfo, - void *params, - void *result, const Nd4jLong *resultShapeInfo, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - template - static __device__ void transformCuda(Nd4jLong n, - const void *dy, Nd4jLong incy, - void *params, - void *result, Nd4jLong resultStride, - int *allocationPointer, void *reductionPointer); - - - template - static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static _CUDA_H void executeTransformShaped(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); + template + static __device__ void transformCuda( + const void *dy, const Nd4jLong *shapeInfo, void *params, void *result, + const Nd4jLong *resultShapeInfo, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static __device__ void transformCudaLegacy( + int opNum, const void *dy, const Nd4jLong *shapeInfo, void *params, + void *result, const Nd4jLong *resultShapeInfo, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + template + static __device__ void transformCuda(Nd4jLong n, const void *dy, + Nd4jLong incy, void *params, + void *result, Nd4jLong resultStride, + int *allocationPointer, + void *reductionPointer); + + template + static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, + const void *x, const Nd4jLong *xShape, + int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static _CUDA_H void executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else - static void exec(int opNum, - const void *dx, const Nd4jLong *xShapeInfo, + static void exec(int opNum, const void *dx, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, uint64_t threadId, uint64_t numThreads); + + template + static SD_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); - - template - static SD_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); + void *extraParams, uint64_t threadId, + uint64_t numThreads); #endif - }; - } -} - +}; +} // namespace transform +} // namespace functions #endif /* TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/transform_same.h b/libnd4j/include/loops/transform_same.h index 1d669c9d4920..e9c73293331f 100644 --- a/libnd4j/include/loops/transform_same.h +++ b/libnd4j/include/loops/transform_same.h @@ -24,10 +24,11 @@ #ifndef TRANSFORM_SAME_H_ #define TRANSFORM_SAME_H_ -#include #include #include +#include + #ifdef _OPENMP #include #endif @@ -46,57 +47,51 @@ #include "legacy_ops.h" - namespace functions { - namespace transform { - - template - class TransformSame { - public: +namespace transform { +template +class TransformSame { + public: #ifdef __CUDACC__ - template - static __device__ void transformCuda(const void *dy, const Nd4jLong *shapeInfo, - void *params, - void *result, const Nd4jLong *resultShapeInfo, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - - template - static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static _CUDA_H void executeTransformShaped(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - + template + static __device__ void transformCuda( + const void *dy, const Nd4jLong *shapeInfo, void *params, void *result, + const Nd4jLong *resultShapeInfo, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + template + static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, + const void *x, const Nd4jLong *xShape, + int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static _CUDA_H void executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else - static void exec(int opNum, - const void *dx, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); + static void exec(int opNum, const void *dx, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, uint64_t threadId, uint64_t numThreads); - template - static SD_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); + template + static SD_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, uint64_t threadId, + uint64_t numThreads); #endif - }; - } -} - +}; +} // namespace transform +} // namespace functions #endif /* TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/transform_strict.h b/libnd4j/include/loops/transform_strict.h index b920d021583b..b5da49ff839e 100644 --- a/libnd4j/include/loops/transform_strict.h +++ b/libnd4j/include/loops/transform_strict.h @@ -24,15 +24,16 @@ #ifndef TRANSFORM_STRICT_H_ #define TRANSFORM_STRICT_H_ -#include #include #include +#include + #ifdef _OPENMP #include #endif -#include #include +#include //#include //#include @@ -46,60 +47,53 @@ #include "legacy_ops.h" - namespace functions { - namespace transform { - - template - class TransformStrict { - public: +namespace transform { +template +class TransformStrict { + public: #ifdef __CUDACC__ - template - static __device__ void transformCuda(const void *dy, const Nd4jLong *shapeInfo, - void *params, - void *result, const Nd4jLong *resultShapeInfo, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - - template - static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); - - static _CUDA_H void executeTransformShaped(dim3 launchDims, cudaStream_t *stream, - int opNum, - const void *x, const Nd4jLong *xShape, int xRank, - void *extraParams, - void *z, const Nd4jLong *zShape, int zRank, - int *allocationPointer, void *reductionPointer, - const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets); + template + static __device__ void transformCuda( + const void *dy, const Nd4jLong *shapeInfo, void *params, void *result, + const Nd4jLong *resultShapeInfo, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + template + static _CUDA_H void intermediateShaped(dim3 launchDims, cudaStream_t *stream, + const void *x, const Nd4jLong *xShape, + int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, + int *allocationPointer, + void *reductionPointer, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); + + static _CUDA_H void executeTransformShaped( + dim3 launchDims, cudaStream_t *stream, int opNum, const void *x, + const Nd4jLong *xShape, int xRank, void *extraParams, void *z, + const Nd4jLong *zShape, int zRank, int *allocationPointer, + void *reductionPointer, const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets); #else + static void exec(int opNum, const void *dx, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, uint64_t threadId, uint64_t numThreads); - - static void exec(int opNum, - const void *dx, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); - - template - static SD_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, - void *result, const Nd4jLong *resultShapeInfo, - void *extraParams, - uint64_t threadId, uint64_t numThreads); + template + static SD_EXPORT void exec(const void *dx, const Nd4jLong *xShapeInfo, + void *result, const Nd4jLong *resultShapeInfo, + void *extraParams, uint64_t threadId, + uint64_t numThreads); #endif - }; - } -} - +}; +} // namespace transform +} // namespace functions #endif /* TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/type_conversions.h b/libnd4j/include/loops/type_conversions.h index b56921435b16..300accd81f3c 100644 --- a/libnd4j/include/loops/type_conversions.h +++ b/libnd4j/include/loops/type_conversions.h @@ -15,8 +15,8 @@ ******************************************************************************/ /* - * This set of methods provides dataType conversions in all possible directions supported: - * FP8, FP16, FLOAT, DOUBLE, INT8, UINT8, UINT16, + * This set of methods provides dataType conversions in all possible directions + * supported: FP8, FP16, FLOAT, DOUBLE, INT8, UINT8, UINT16, * * @author raver119@gmail.com */ @@ -33,118 +33,126 @@ #define ND4J_FLOAT32 6 #define ND4J_DOUBLE 7 #define ND4J_THRESHOLD 8 -#define ND4J_FLOAT24 119 // not supported after all. might want to add support later. +#define ND4J_FLOAT24 \ + 119 // not supported after all. might want to add support later. -#include #include +#include +#include #include #include -#include -#include #include +#include #include -#include +#include #define NUM_BANKS 32 #define LOG_NUM_BANKS 4 - namespace sd { - typedef union { - float f_; - int i_; - } FloatBits; - - - class TypeCast { - - public: - template - static _CUDA_H void convertGeneric(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); +typedef union { + float f_; + int i_; +} FloatBits; - template - static _CUDA_H void convertToThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); +class TypeCast { + public: + template + static _CUDA_H void convertGeneric(Nd4jPointer *extras, void *dx, Nd4jLong N, + void *dz); - template - static _CUDA_H void convertFromThreshold(Nd4jPointer * extras, const void *dx, Nd4jLong N, void *dz); + template + static _CUDA_H void convertToThreshold(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz); - FORCEINLINE static _CUDA_H Nd4jLong estimateQuantizedSize(Nd4jLong rawSize) { - if (rawSize <= 0) - throw std::runtime_error("Input size for quantization can't be <= 0"); + template + static _CUDA_H void convertFromThreshold(Nd4jPointer *extras, const void *dx, + Nd4jLong N, void *dz); - // 2 fp32 values for max/min, and rawSize number of BYTES - return 8 + rawSize; - } + FORCEINLINE static _CUDA_H Nd4jLong estimateQuantizedSize(Nd4jLong rawSize) { + if (rawSize <= 0) + throw std::runtime_error("Input size for quantization can't be <= 0"); + // 2 fp32 values for max/min, and rawSize number of BYTES + return 8 + rawSize; + } - template - static _CUDA_H void convertToQuantized(Nd4jPointer *extras, void *dx, Nd4jLong N, void *dz); + template + static _CUDA_H void convertToQuantized(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz); - template - static _CUDA_H void convertFromQuantized(Nd4jPointer *extras, void *dx, Nd4jLong N, void *dz); - - #ifdef __CUDACC__ - template - static _CUDA_H void convertGenericCuda(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - #endif - }; + template + static _CUDA_H void convertFromQuantized(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz); +#ifdef __CUDACC__ + template + static _CUDA_H void convertGenericCuda(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz); +#endif +}; - FORCEINLINE _CUDA_HD bool isPowerOfTwo(int n) { - return ((n&(n-1))==0) ; - } +FORCEINLINE _CUDA_HD bool isPowerOfTwo(int n) { return ((n & (n - 1)) == 0); } - FORCEINLINE _CUDA_HD int floorPow2(int n) { +FORCEINLINE _CUDA_HD int floorPow2(int n) { #ifdef WIN32 - // method 2 - return 1 << static_cast(logb(static_cast(n))); + // method 2 + return 1 << static_cast(logb(static_cast(n))); #else - // method 1 - // float nf = (float)n; - // return 1 << (((*(int*)&nf) >> 23) - 127); - int exp; - frexp(static_cast(n), &exp); - return 1 << (exp - 1); + // method 1 + // float nf = (float)n; + // return 1 << (((*(int*)&nf) >> 23) - 127); + int exp; + frexp(static_cast(n), &exp); + return 1 << (exp - 1); #endif - } +} #ifdef __CUDACC__ - __device__ __inline__ int pow2i (int e){ - return 1< - __host__ void encoderKernelP1Generic(dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *dz, float threshold); - - - template - __host__ void encoderKernelP3Generic(dim3 &launchDims, cudaStream_t *stream, void *dx, int *offsets, Nd4jLong N, void *dz); - - - template - __host__ void decoderKernelGeneric(dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *dz); - - template - __host__ void cudaEncodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, void *vdx, Nd4jLong N, int *dz, int *scalar, int *reductionBuffer, float threshold); - - - template - __host__ void cudaDecodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, const void *dx, Nd4jLong N, void *vdz); - - __global__ void uniformAdd(int *g_data, int *uniforms, int n, int blockOffset, int baseIndex); - - template - __global__ void prescan(int *g_odata, const int *g_idata, int *g_blockSums, int n, int blockIndex, int baseIndex); - - - template - __host__ void prescanLauncher(dim3 &blocks, dim3 &threads, int shmem, cudaStream_t *stream, int *g_odata, const int *g_idata, int *g_blockSums, int n, int blockIndex, int baseIndex); - - template - __global__ void convertKernel(void *dx, Nd4jLong N, void *dz); +__device__ __inline__ int pow2i(int e) { return 1 << e; } + +template +__host__ void encoderKernelP1Generic(dim3 &launchDims, cudaStream_t *stream, + const void *dx, Nd4jLong N, void *dz, + float threshold); + +template +__host__ void encoderKernelP3Generic(dim3 &launchDims, cudaStream_t *stream, + void *dx, int *offsets, Nd4jLong N, + void *dz); + +template +__host__ void decoderKernelGeneric(dim3 &launchDims, cudaStream_t *stream, + const void *dx, Nd4jLong N, void *dz); + +template +__host__ void cudaEncodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, + void *vdx, Nd4jLong N, int *dz, + int *scalar, int *reductionBuffer, + float threshold); + +template +__host__ void cudaDecodeBitmapGeneric(dim3 &launchDims, cudaStream_t *stream, + const void *dx, Nd4jLong N, void *vdz); + +__global__ void uniformAdd(int *g_data, int *uniforms, int n, int blockOffset, + int baseIndex); + +template +__global__ void prescan(int *g_odata, const int *g_idata, int *g_blockSums, + int n, int blockIndex, int baseIndex); + +template +__host__ void prescanLauncher(dim3 &blocks, dim3 &threads, int shmem, + cudaStream_t *stream, int *g_odata, + const int *g_idata, int *g_blockSums, int n, + int blockIndex, int baseIndex); + +template +__global__ void convertKernel(void *dx, Nd4jLong N, void *dz); #endif -} +} // namespace sd -#endif //LIBND4J_TYPE_CONVERSIONS_H +#endif // LIBND4J_TYPE_CONVERSIONS_H diff --git a/libnd4j/include/math/platformmath.h b/libnd4j/include/math/platformmath.h index e68e75fc46b3..f12ca2295106 100644 --- a/libnd4j/include/math/platformmath.h +++ b/libnd4j/include/math/platformmath.h @@ -22,60 +22,54 @@ #define LIBND4J_PLATFORM_MATH_H #include -#include #include #include +#include + #ifdef __CUDACC__ -#include #include +#include union BPAIR { - struct { - bfloat16 H; - bfloat16 L; - } B; - int W; + struct { + bfloat16 H; + bfloat16 L; + } B; + int W; - __host__ __device__ - BPAIR() {}; + __host__ __device__ BPAIR(){}; - __host__ __device__ - ~BPAIR() {}; + __host__ __device__ ~BPAIR(){}; }; #define math_def __host__ __device__ #ifdef CUDA_8 typedef union { - struct { - half H; - half L; - } B; - int W; + struct { + half H; + half L; + } B; + int W; } PAIR; #else -struct HALFS{ - half H; - half L; +struct HALFS { + half H; + half L; - __host__ __device__ - HALFS() {}; + __host__ __device__ HALFS(){}; - __host__ __device__ - ~HALFS() {}; - }; + __host__ __device__ ~HALFS(){}; +}; union PAIR { - HALFS B; - int W; - - __host__ __device__ - PAIR() {}; + HALFS B; + int W; - __host__ __device__ - ~PAIR(){} + __host__ __device__ PAIR(){}; + __host__ __device__ ~PAIR() {} }; -#endif // cuda_9 +#endif // cuda_9 #else #define math_def @@ -83,792 +77,799 @@ union PAIR { #endif - namespace sd { - namespace math { - template - math_def FORCEINLINE T p_exp(T value); +namespace math { +template +math_def FORCEINLINE T p_exp(T value); - template - math_def FORCEINLINE T p_log(T value); +template +math_def FORCEINLINE T p_log(T value); - template - math_def FORCEINLINE T p_floor(T value); +template +math_def FORCEINLINE T p_floor(T value); - template - math_def FORCEINLINE T p_ceil(T value); +template +math_def FORCEINLINE T p_ceil(T value); - template - math_def FORCEINLINE T p_round(T value); +template +math_def FORCEINLINE T p_round(T value); - template - math_def FORCEINLINE T p_cos(T value); +template +math_def FORCEINLINE T p_cos(T value); - template - math_def FORCEINLINE T p_cosh(T value); +template +math_def FORCEINLINE T p_cosh(T value); - template - math_def FORCEINLINE T p_acos(T value); +template +math_def FORCEINLINE T p_acos(T value); - template - math_def FORCEINLINE T p_acosh(T value); +template +math_def FORCEINLINE T p_acosh(T value); - template - math_def FORCEINLINE T p_sin(T value); +template +math_def FORCEINLINE T p_sin(T value); - template - math_def FORCEINLINE T p_sinh(T value); +template +math_def FORCEINLINE T p_sinh(T value); - template - math_def FORCEINLINE T p_asin(T value); +template +math_def FORCEINLINE T p_asin(T value); - template - math_def FORCEINLINE T p_sqrt(T value); +template +math_def FORCEINLINE T p_sqrt(T value); - template - math_def FORCEINLINE T p_tanh(T value); +template +math_def FORCEINLINE T p_tanh(T value); - template - math_def FORCEINLINE T p_erf(T value); +template +math_def FORCEINLINE T p_erf(T value); - template - math_def FORCEINLINE T p_erfc(T value); +template +math_def FORCEINLINE T p_erfc(T value); - template - math_def FORCEINLINE T p_atan(T value); +template +math_def FORCEINLINE T p_atan(T value); - template - math_def FORCEINLINE T p_tan(T value); +template +math_def FORCEINLINE T p_tan(T value); - template - math_def FORCEINLINE T p_atanh(T value); +template +math_def FORCEINLINE T p_atanh(T value); - template - math_def FORCEINLINE T p_rint(T value); +template +math_def FORCEINLINE T p_rint(T value); - template - math_def FORCEINLINE T p_rotl(T value, T shift); +template +math_def FORCEINLINE T p_rotl(T value, T shift); - template - math_def FORCEINLINE T p_rotr(T value, T shift); +template +math_def FORCEINLINE T p_rotr(T value, T shift); - template - math_def FORCEINLINE T p_remainder(T val1, T val2); +template +math_def FORCEINLINE T p_remainder(T val1, T val2); - template - math_def FORCEINLINE T p_fmod(T val1, T val2); +template +math_def FORCEINLINE T p_fmod(T val1, T val2); - template - math_def FORCEINLINE T p_pow(T value, T power); +template +math_def FORCEINLINE T p_pow(T value, T power); - template - math_def FORCEINLINE T p_atan2(T val1, T val2); +template +math_def FORCEINLINE T p_atan2(T val1, T val2); ////// - template <> - math_def FORCEINLINE float p_exp(float value) { - return expf(value); - } +template <> +math_def FORCEINLINE float p_exp(float value) { + return expf(value); +} - template <> - math_def FORCEINLINE float16 p_exp(float16 val) { +template <> +math_def FORCEINLINE float16 p_exp(float16 val) { #ifdef NATIVE_HALFS - return hexp(val.data); + return hexp(val.data); #else - return static_cast(expf((float) val)); + return static_cast(expf((float)val)); #endif - } +} - template <> - math_def FORCEINLINE bfloat16 p_exp(bfloat16 val) { - return static_cast(expf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_exp(bfloat16 val) { + return static_cast(expf((float)val)); +} - template <> - math_def FORCEINLINE double p_exp(double value) { - return exp(value); - } +template <> +math_def FORCEINLINE double p_exp(double value) { + return exp(value); +} - template - math_def FORCEINLINE T p_exp(T value) { - return static_cast(expf(static_cast(value))); - } +template +math_def FORCEINLINE T p_exp(T value) { + return static_cast(expf(static_cast(value))); +} ///////// - template <> - math_def FORCEINLINE float16 p_pow(float16 value, float16 power) { - return static_cast(powf(static_cast(value), static_cast(power))); - } - - template <> - math_def FORCEINLINE bfloat16 p_pow(bfloat16 value, bfloat16 power) { - return static_cast(powf(static_cast(value), static_cast(power))); - } - - template <> - math_def FORCEINLINE float p_pow(float value, float power) { - return powf(value, power); - } - - template <> - math_def FORCEINLINE double p_pow(double value, double power) { - return pow(value, power); - } - - template - math_def FORCEINLINE T p_pow(T value, T power) { - return static_cast(powf(static_cast(value), static_cast(power))); - } +template <> +math_def FORCEINLINE float16 p_pow(float16 value, float16 power) { + return static_cast( + powf(static_cast(value), static_cast(power))); +} + +template <> +math_def FORCEINLINE bfloat16 p_pow(bfloat16 value, bfloat16 power) { + return static_cast( + powf(static_cast(value), static_cast(power))); +} + +template <> +math_def FORCEINLINE float p_pow(float value, float power) { + return powf(value, power); +} + +template <> +math_def FORCEINLINE double p_pow(double value, double power) { + return pow(value, power); +} + +template +math_def FORCEINLINE T p_pow(T value, T power) { + return static_cast( + powf(static_cast(value), static_cast(power))); +} ///////// - template <> - math_def FORCEINLINE float16 p_fmod(float16 value, float16 power) { - return static_cast(fmodf(static_cast(value), static_cast(power))); - } +template <> +math_def FORCEINLINE float16 p_fmod(float16 value, float16 power) { + return static_cast( + fmodf(static_cast(value), static_cast(power))); +} - template <> - math_def FORCEINLINE bfloat16 p_fmod(bfloat16 value, bfloat16 power) { - return static_cast(fmodf(static_cast(value), static_cast(power))); - } +template <> +math_def FORCEINLINE bfloat16 p_fmod(bfloat16 value, bfloat16 power) { + return static_cast( + fmodf(static_cast(value), static_cast(power))); +} - template <> - math_def FORCEINLINE float p_fmod(float value, float power) { - return fmodf(value, power); - } +template <> +math_def FORCEINLINE float p_fmod(float value, float power) { + return fmodf(value, power); +} - template <> - math_def FORCEINLINE double p_fmod(double value, double power) { - return fmod(value, power); - } +template <> +math_def FORCEINLINE double p_fmod(double value, double power) { + return fmod(value, power); +} - template - math_def FORCEINLINE T p_fmod(T value, T power) { - return static_cast(fmodf(static_cast(value), static_cast(power))); - } +template +math_def FORCEINLINE T p_fmod(T value, T power) { + return static_cast( + fmodf(static_cast(value), static_cast(power))); +} ///////// - template <> - math_def FORCEINLINE float16 p_atan2(float16 value, float16 power) { - return static_cast(atan2f(static_cast(value), static_cast(power))); - } +template <> +math_def FORCEINLINE float16 p_atan2(float16 value, float16 power) { + return static_cast( + atan2f(static_cast(value), static_cast(power))); +} - template <> - math_def FORCEINLINE float p_atan2(float value, float power) { - return atan2f(value, power); - } +template <> +math_def FORCEINLINE float p_atan2(float value, float power) { + return atan2f(value, power); +} - template <> - math_def FORCEINLINE double p_atan2(double value, double power) { - return atan2(value, power); - } +template <> +math_def FORCEINLINE double p_atan2(double value, double power) { + return atan2(value, power); +} - template - math_def FORCEINLINE T p_atan2(T value, T power) { - return static_cast(atan2f(static_cast(value), static_cast(power))); - } +template +math_def FORCEINLINE T p_atan2(T value, T power) { + return static_cast( + atan2f(static_cast(value), static_cast(power))); +} ///////// - template <> - math_def FORCEINLINE float16 p_remainder(float16 value, float16 power) { - return static_cast(remainderf(static_cast(value), static_cast(power))); - } - - template <> - math_def FORCEINLINE float p_remainder(float value, float power) { - return remainderf(value, power); - } - - template <> - math_def FORCEINLINE double p_remainder(double value, double power) { - return remainder(value, power); - } - - template - math_def FORCEINLINE T p_remainder(T value, T power) { - return static_cast(remainderf(static_cast(value), static_cast(power))); - } +template <> +math_def FORCEINLINE float16 p_remainder(float16 value, float16 power) { + return static_cast( + remainderf(static_cast(value), static_cast(power))); +} + +template <> +math_def FORCEINLINE float p_remainder(float value, float power) { + return remainderf(value, power); +} + +template <> +math_def FORCEINLINE double p_remainder(double value, double power) { + return remainder(value, power); +} + +template +math_def FORCEINLINE T p_remainder(T value, T power) { + return static_cast( + remainderf(static_cast(value), static_cast(power))); +} ///////// - template <> - math_def FORCEINLINE float p_log(float value) { - return logf(value); - } +template <> +math_def FORCEINLINE float p_log(float value) { + return logf(value); +} - template <> - math_def FORCEINLINE float16 p_log(float16 val) { +template <> +math_def FORCEINLINE float16 p_log(float16 val) { #ifdef NATIVE_HALFS - return hlog(val.data); + return hlog(val.data); #else - return static_cast(logf((float) val)); + return static_cast(logf((float)val)); #endif - } +} - template <> - math_def FORCEINLINE double p_log(double value) { - return log(value); - } +template <> +math_def FORCEINLINE double p_log(double value) { + return log(value); +} - template - math_def FORCEINLINE T p_log(T value) { - return static_cast(logf(static_cast(value))); - } +template +math_def FORCEINLINE T p_log(T value) { + return static_cast(logf(static_cast(value))); +} ///////// - template <> - math_def FORCEINLINE float p_floor(float value) { - return floorf(value); - } +template <> +math_def FORCEINLINE float p_floor(float value) { + return floorf(value); +} - template <> - math_def FORCEINLINE float16 p_floor(float16 val) { +template <> +math_def FORCEINLINE float16 p_floor(float16 val) { #ifdef NATIVE_HALFS - return hfloor(val.data); + return hfloor(val.data); #else - return static_cast(floorf((float) val)); + return static_cast(floorf((float)val)); #endif - } +} - template <> - math_def FORCEINLINE bfloat16 p_floor(bfloat16 value) { - return static_cast(floorf((float)value)); - } +template <> +math_def FORCEINLINE bfloat16 p_floor(bfloat16 value) { + return static_cast(floorf((float)value)); +} - template <> - math_def FORCEINLINE double p_floor(double value) { - return floor(value); - } +template <> +math_def FORCEINLINE double p_floor(double value) { + return floor(value); +} - template - math_def FORCEINLINE T p_floor(T value) { - return value; - } +template +math_def FORCEINLINE T p_floor(T value) { + return value; +} ///////// - template <> - math_def FORCEINLINE float p_ceil(float value) { - return ceilf(value); - } +template <> +math_def FORCEINLINE float p_ceil(float value) { + return ceilf(value); +} - template <> - math_def FORCEINLINE float16 p_ceil(float16 val) { +template <> +math_def FORCEINLINE float16 p_ceil(float16 val) { #ifdef NATIVE_HALFS - return hceil(val.data); + return hceil(val.data); #else - return static_cast(ceilf((float) val)); + return static_cast(ceilf((float)val)); #endif - } +} - template <> - math_def FORCEINLINE bfloat16 p_ceil(bfloat16 value) { - return static_cast(ceilf((float)value)); - } +template <> +math_def FORCEINLINE bfloat16 p_ceil(bfloat16 value) { + return static_cast(ceilf((float)value)); +} - template <> - math_def FORCEINLINE double p_ceil(double value) { - return ceil(value); - } +template <> +math_def FORCEINLINE double p_ceil(double value) { + return ceil(value); +} - template - math_def FORCEINLINE T p_ceil(T value) { - return value; - } +template +math_def FORCEINLINE T p_ceil(T value) { + return value; +} ///////// - template <> - math_def FORCEINLINE float p_round(float value) { - return roundf(value); - } - - template <> - math_def FORCEINLINE float16 p_round(float16 val) { - return static_cast(roundf((float) val)); - } +template <> +math_def FORCEINLINE float p_round(float value) { + return roundf(value); +} - template <> - math_def FORCEINLINE bfloat16 p_round(bfloat16 value) { - return static_cast(roundf((float)value)); - } +template <> +math_def FORCEINLINE float16 p_round(float16 val) { + return static_cast(roundf((float)val)); +} +template <> +math_def FORCEINLINE bfloat16 p_round(bfloat16 value) { + return static_cast(roundf((float)value)); +} - template <> - math_def FORCEINLINE double p_round(double value) { - return round(value); - } +template <> +math_def FORCEINLINE double p_round(double value) { + return round(value); +} - template - math_def FORCEINLINE T p_round(T value) { - return value; - } +template +math_def FORCEINLINE T p_round(T value) { + return value; +} ///////// - template <> - math_def FORCEINLINE float p_rint(float value) { - return rintf(value); - } +template <> +math_def FORCEINLINE float p_rint(float value) { + return rintf(value); +} - template <> - math_def FORCEINLINE float16 p_rint(float16 val) { +template <> +math_def FORCEINLINE float16 p_rint(float16 val) { #ifdef NATIVE_HALFS - return hrint(val.data); + return hrint(val.data); #else - return static_cast(rintf((float) val)); + return static_cast(rintf((float)val)); #endif - } +} - template <> - math_def FORCEINLINE bfloat16 p_rint(bfloat16 val) { - return static_cast(rintf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_rint(bfloat16 val) { + return static_cast(rintf((float)val)); +} - template <> - math_def FORCEINLINE double p_rint(double value) { - return rint(value); - } +template <> +math_def FORCEINLINE double p_rint(double value) { + return rint(value); +} - template - math_def FORCEINLINE T p_rint(T value) { - return value; - } +template +math_def FORCEINLINE T p_rint(T value) { + return value; +} ///////// - template <> - math_def FORCEINLINE float p_cos(float value) { - return cosf(value); - } +template <> +math_def FORCEINLINE float p_cos(float value) { + return cosf(value); +} - template <> - math_def FORCEINLINE float16 p_cos(float16 val) { +template <> +math_def FORCEINLINE float16 p_cos(float16 val) { #ifdef NATIVE_HALFS - return hcos(val.data); + return hcos(val.data); #else - return static_cast(cosf((float) val)); + return static_cast(cosf((float)val)); #endif - } +} - template <> - math_def FORCEINLINE bfloat16 p_cos(bfloat16 val) { - return static_cast(cosf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_cos(bfloat16 val) { + return static_cast(cosf((float)val)); +} - template <> - math_def FORCEINLINE double p_cos(double value) { - return cos(value); - } +template <> +math_def FORCEINLINE double p_cos(double value) { + return cos(value); +} ///////// - template <> - math_def FORCEINLINE float p_sin(float value) { - return sinf(value); - } +template <> +math_def FORCEINLINE float p_sin(float value) { + return sinf(value); +} - template <> - math_def FORCEINLINE float16 p_sin(float16 val) { +template <> +math_def FORCEINLINE float16 p_sin(float16 val) { #ifdef NATIVE_HALFS - return hsin(val.data); + return hsin(val.data); #else - return static_cast(sinf((float) val)); + return static_cast(sinf((float)val)); #endif - } +} - template <> - math_def FORCEINLINE bfloat16 p_sin(bfloat16 val) { - return static_cast(sinf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_sin(bfloat16 val) { + return static_cast(sinf((float)val)); +} - template <> - math_def FORCEINLINE double p_sin(double value) { - return sin(value); - } +template <> +math_def FORCEINLINE double p_sin(double value) { + return sin(value); +} ///////// - template <> - math_def FORCEINLINE float p_sqrt(float value) { - return sqrtf(value); - } +template <> +math_def FORCEINLINE float p_sqrt(float value) { + return sqrtf(value); +} - template <> - math_def FORCEINLINE float16 p_sqrt(float16 val) { +template <> +math_def FORCEINLINE float16 p_sqrt(float16 val) { #ifdef NATIVE_HALFS - return hsqrt(val.data); + return hsqrt(val.data); #else - return static_cast(sqrtf((float) val)); + return static_cast(sqrtf((float)val)); #endif - } - template <> - math_def FORCEINLINE bfloat16 p_sqrt(bfloat16 val) { - return static_cast(sqrtf((float) val)); - } +} +template <> +math_def FORCEINLINE bfloat16 p_sqrt(bfloat16 val) { + return static_cast(sqrtf((float)val)); +} - template <> - math_def FORCEINLINE double p_sqrt(double value) { - return sqrt(value); - } +template <> +math_def FORCEINLINE double p_sqrt(double value) { + return sqrt(value); +} ///////// - template <> - math_def FORCEINLINE float p_tanh(float value) { - return tanhf(value); - } +template <> +math_def FORCEINLINE float p_tanh(float value) { + return tanhf(value); +} - template <> - math_def FORCEINLINE float16 p_tanh(float16 val) { - return static_cast(tanhf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_tanh(float16 val) { + return static_cast(tanhf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_tanh(bfloat16 val) { - return static_cast(tanhf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_tanh(bfloat16 val) { + return static_cast(tanhf((float)val)); +} - template <> - math_def FORCEINLINE double p_tanh(double value) { - return tanh(value); - } +template <> +math_def FORCEINLINE double p_tanh(double value) { + return tanh(value); +} ///////// - template <> - math_def FORCEINLINE float p_erf(float value) { - return erff(value); - } +template <> +math_def FORCEINLINE float p_erf(float value) { + return erff(value); +} - template <> - math_def FORCEINLINE float16 p_erf(float16 val) { - return static_cast(erff((float) val)); - } +template <> +math_def FORCEINLINE float16 p_erf(float16 val) { + return static_cast(erff((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_erf(bfloat16 val) { - return static_cast(erff((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_erf(bfloat16 val) { + return static_cast(erff((float)val)); +} - template <> - math_def FORCEINLINE double p_erf(double value) { - return erf(value); - } +template <> +math_def FORCEINLINE double p_erf(double value) { + return erf(value); +} ///////// - template <> - math_def FORCEINLINE float p_erfc(float value) { - return erfcf(value); - } +template <> +math_def FORCEINLINE float p_erfc(float value) { + return erfcf(value); +} - template <> - math_def FORCEINLINE float16 p_erfc(float16 val) { - return static_cast(erfcf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_erfc(float16 val) { + return static_cast(erfcf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_erfc(bfloat16 val) { - return static_cast(erfcf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_erfc(bfloat16 val) { + return static_cast(erfcf((float)val)); +} - template <> - math_def FORCEINLINE double p_erfc(double value) { - return erfc(value); - } +template <> +math_def FORCEINLINE double p_erfc(double value) { + return erfc(value); +} ///////// - template <> - math_def FORCEINLINE float p_acos(float value) { - return acosf(value); - } +template <> +math_def FORCEINLINE float p_acos(float value) { + return acosf(value); +} - template <> - math_def FORCEINLINE float16 p_acos(float16 val) { - return static_cast(acosf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_acos(float16 val) { + return static_cast(acosf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_acos(bfloat16 val) { - return static_cast(acosf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_acos(bfloat16 val) { + return static_cast(acosf((float)val)); +} - template <> - math_def FORCEINLINE double p_acos(double value) { - return acos(value); - } +template <> +math_def FORCEINLINE double p_acos(double value) { + return acos(value); +} ///////// - template <> - math_def FORCEINLINE float p_sinh(float value) { - return sinhf(value); - } +template <> +math_def FORCEINLINE float p_sinh(float value) { + return sinhf(value); +} - template <> - math_def FORCEINLINE float16 p_sinh(float16 val) { - return static_cast(sinhf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_sinh(float16 val) { + return static_cast(sinhf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_sinh(bfloat16 val) { - return static_cast(sinhf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_sinh(bfloat16 val) { + return static_cast(sinhf((float)val)); +} - template <> - math_def FORCEINLINE double p_sinh(double value) { - return sinh(value); - } +template <> +math_def FORCEINLINE double p_sinh(double value) { + return sinh(value); +} ///////// - template <> - math_def FORCEINLINE float p_acosh(float value) { - return acoshf(value); - } +template <> +math_def FORCEINLINE float p_acosh(float value) { + return acoshf(value); +} - template <> - math_def FORCEINLINE float16 p_acosh(float16 val) { - return static_cast(acoshf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_acosh(float16 val) { + return static_cast(acoshf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_acosh(bfloat16 val) { - return static_cast(acoshf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_acosh(bfloat16 val) { + return static_cast(acoshf((float)val)); +} - template <> - math_def FORCEINLINE double p_acosh(double value) { - return acosh(value); - } +template <> +math_def FORCEINLINE double p_acosh(double value) { + return acosh(value); +} ///////// - template <> - math_def FORCEINLINE float p_cosh(float value) { - return coshf(value); - } - - template <> - math_def FORCEINLINE float16 p_cosh(float16 val) { - return static_cast(coshf((float) val)); - } +template <> +math_def FORCEINLINE float p_cosh(float value) { + return coshf(value); +} - template <> - math_def FORCEINLINE bfloat16 p_cosh(bfloat16 val) { - return static_cast(coshf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_cosh(float16 val) { + return static_cast(coshf((float)val)); +} - template <> - math_def FORCEINLINE double p_cosh(double value) { - return cosh(value); - } +template <> +math_def FORCEINLINE bfloat16 p_cosh(bfloat16 val) { + return static_cast(coshf((float)val)); +} +template <> +math_def FORCEINLINE double p_cosh(double value) { + return cosh(value); +} ///////// - template <> - math_def FORCEINLINE float p_asin(float value) { - return asinf(value); - } +template <> +math_def FORCEINLINE float p_asin(float value) { + return asinf(value); +} - template <> - math_def FORCEINLINE float16 p_asin(float16 val) { - return static_cast(asinf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_asin(float16 val) { + return static_cast(asinf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_asin(bfloat16 val) { - return static_cast(asinf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_asin(bfloat16 val) { + return static_cast(asinf((float)val)); +} - template <> - math_def FORCEINLINE double p_asin(double value) { - return asin(value); - } +template <> +math_def FORCEINLINE double p_asin(double value) { + return asin(value); +} ///////// - template <> - math_def FORCEINLINE float p_atan(float value) { - return atanf(value); - } - - template <> - math_def FORCEINLINE float16 p_atan(float16 val) { - return static_cast(atanf((float) val)); - } +template <> +math_def FORCEINLINE float p_atan(float value) { + return atanf(value); +} - template <> - math_def FORCEINLINE bfloat16 p_atan(bfloat16 val) { - return static_cast(atanf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_atan(float16 val) { + return static_cast(atanf((float)val)); +} - template <> - math_def FORCEINLINE double p_atan(double value) { - return atan(value); - } +template <> +math_def FORCEINLINE bfloat16 p_atan(bfloat16 val) { + return static_cast(atanf((float)val)); +} +template <> +math_def FORCEINLINE double p_atan(double value) { + return atan(value); +} ///////// - template <> - math_def FORCEINLINE float p_tan(float value) { - return tanf(value); - } +template <> +math_def FORCEINLINE float p_tan(float value) { + return tanf(value); +} - template <> - math_def FORCEINLINE float16 p_tan(float16 val) { - return static_cast(tanf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_tan(float16 val) { + return static_cast(tanf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_tan(bfloat16 val) { - return static_cast(tanf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_tan(bfloat16 val) { + return static_cast(tanf((float)val)); +} - template <> - math_def FORCEINLINE double p_tan(double value) { - return tan(value); - } +template <> +math_def FORCEINLINE double p_tan(double value) { + return tan(value); +} ///////// - template <> - math_def FORCEINLINE float p_atanh(float value) { - return atanhf(value); - } +template <> +math_def FORCEINLINE float p_atanh(float value) { + return atanhf(value); +} - template <> - math_def FORCEINLINE float16 p_atanh(float16 val) { - return static_cast(atanhf((float) val)); - } +template <> +math_def FORCEINLINE float16 p_atanh(float16 val) { + return static_cast(atanhf((float)val)); +} - template <> - math_def FORCEINLINE bfloat16 p_atanh(bfloat16 val) { - return static_cast(atanhf((float) val)); - } +template <> +math_def FORCEINLINE bfloat16 p_atanh(bfloat16 val) { + return static_cast(atanhf((float)val)); +} - template <> - math_def FORCEINLINE double p_atanh(double value) { - return atanh(value); - } +template <> +math_def FORCEINLINE double p_atanh(double value) { + return atanh(value); +} ///////// - template - math_def FORCEINLINE T _rotate_left(T value, T shift); - - template - math_def FORCEINLINE T _rotate_right(T value, T shift); - - template <> - math_def FORCEINLINE int8_t _rotate_left(int8_t value, int8_t shift) { - return value << shift | value >> (8 - shift); - } - - template <> - math_def FORCEINLINE int8_t _rotate_right(int8_t value, int8_t shift) { - return value >> shift | value << (8 - shift); - } - - template <> - math_def FORCEINLINE uint8_t _rotate_left(uint8_t value, uint8_t shift) { - return value << shift | value >> (8 - shift); - } - - template <> - math_def FORCEINLINE uint8_t _rotate_right(uint8_t value, uint8_t shift) { - return value >> shift | value << (8 - shift); - } - - template <> - math_def FORCEINLINE int16_t _rotate_left(int16_t value, int16_t shift) { - return value << shift | value >> (16 - shift); - } - - template <> - math_def FORCEINLINE int16_t _rotate_right(int16_t value, int16_t shift) { - return value >> shift | value << (16 - shift); - } - - template <> - math_def FORCEINLINE uint16_t _rotate_left(uint16_t value, uint16_t shift) { - return value << shift | value >> (16 - shift); - } - - template <> - math_def FORCEINLINE uint16_t _rotate_right(uint16_t value, uint16_t shift) { - return value >> shift | value << (16 - shift); - } - - template <> - math_def FORCEINLINE int _rotate_left(int value, int shift) { - return value << shift | value >> (32 - shift); - } - - template <> - math_def FORCEINLINE int _rotate_right(int value, int shift) { - return value >> shift | value << (32 - shift); - } - - template <> - math_def FORCEINLINE uint32_t _rotate_left(uint32_t value, uint32_t shift) { - return value << shift | value >> (32 - shift); - } - - template <> - math_def FORCEINLINE uint32_t _rotate_right(uint32_t value, uint32_t shift) { - return value >> shift | value << (32 - shift); - } - - template <> - math_def FORCEINLINE Nd4jLong _rotate_left(Nd4jLong value, Nd4jLong shift) { - return value << shift | value >> (64 - shift); - } - - template <> - math_def FORCEINLINE Nd4jLong _rotate_right(Nd4jLong value, Nd4jLong shift) { - return value >> shift | value << (64 - shift); - } - - template <> - math_def FORCEINLINE uint64_t _rotate_left(uint64_t value, uint64_t shift) { +template +math_def FORCEINLINE T _rotate_left(T value, T shift); + +template +math_def FORCEINLINE T _rotate_right(T value, T shift); + +template <> +math_def FORCEINLINE int8_t _rotate_left(int8_t value, int8_t shift) { + return value << shift | value >> (8 - shift); +} + +template <> +math_def FORCEINLINE int8_t _rotate_right(int8_t value, int8_t shift) { + return value >> shift | value << (8 - shift); +} + +template <> +math_def FORCEINLINE uint8_t _rotate_left(uint8_t value, uint8_t shift) { + return value << shift | value >> (8 - shift); +} + +template <> +math_def FORCEINLINE uint8_t _rotate_right(uint8_t value, uint8_t shift) { + return value >> shift | value << (8 - shift); +} + +template <> +math_def FORCEINLINE int16_t _rotate_left(int16_t value, int16_t shift) { + return value << shift | value >> (16 - shift); +} + +template <> +math_def FORCEINLINE int16_t _rotate_right(int16_t value, int16_t shift) { + return value >> shift | value << (16 - shift); +} + +template <> +math_def FORCEINLINE uint16_t _rotate_left(uint16_t value, uint16_t shift) { + return value << shift | value >> (16 - shift); +} + +template <> +math_def FORCEINLINE uint16_t _rotate_right(uint16_t value, uint16_t shift) { + return value >> shift | value << (16 - shift); +} + +template <> +math_def FORCEINLINE int _rotate_left(int value, int shift) { + return value << shift | value >> (32 - shift); +} + +template <> +math_def FORCEINLINE int _rotate_right(int value, int shift) { + return value >> shift | value << (32 - shift); +} + +template <> +math_def FORCEINLINE uint32_t _rotate_left(uint32_t value, uint32_t shift) { + return value << shift | value >> (32 - shift); +} + +template <> +math_def FORCEINLINE uint32_t _rotate_right(uint32_t value, uint32_t shift) { + return value >> shift | value << (32 - shift); +} + +template <> +math_def FORCEINLINE Nd4jLong _rotate_left(Nd4jLong value, Nd4jLong shift) { + return value << shift | value >> (64 - shift); +} + +template <> +math_def FORCEINLINE Nd4jLong _rotate_right(Nd4jLong value, Nd4jLong shift) { + return value >> shift | value << (64 - shift); +} + +template <> +math_def FORCEINLINE uint64_t _rotate_left(uint64_t value, uint64_t shift) { #ifdef SD_ARM_BUILD - // TODO: eventually remove this once gcc fixes the bug - Nd4jLong val = _rotate_left(*reinterpret_cast(&value), *reinterpret_cast(&shift)); - return *reinterpret_cast(&val); + // TODO: eventually remove this once gcc fixes the bug + Nd4jLong val = _rotate_left(*reinterpret_cast(&value), + *reinterpret_cast(&shift)); + return *reinterpret_cast(&val); #else - return value << shift | value >> (64 - shift); + return value << shift | value >> (64 - shift); #endif - } +} - template <> - math_def FORCEINLINE uint64_t _rotate_right(uint64_t value, uint64_t shift) { +template <> +math_def FORCEINLINE uint64_t _rotate_right(uint64_t value, uint64_t shift) { #ifdef SD_ARM_BUILD - // TODO: eventually remove this once gcc fixes the bug - Nd4jLong val = _rotate_right(*reinterpret_cast(&value), *reinterpret_cast(&shift)); - return *reinterpret_cast(&val); + // TODO: eventually remove this once gcc fixes the bug + Nd4jLong val = _rotate_right(*reinterpret_cast(&value), + *reinterpret_cast(&shift)); + return *reinterpret_cast(&val); #else - return value >> shift | value << (64 - shift); + return value >> shift | value << (64 - shift); #endif - } - +} - template - math_def FORCEINLINE T p_rotl(T value, T shift) { - return _rotate_left(value, shift); - } +template +math_def FORCEINLINE T p_rotl(T value, T shift) { + return _rotate_left(value, shift); +} - template - math_def FORCEINLINE T p_rotr(T value, T shift) { - return _rotate_right(value, shift); - } - } +template +math_def FORCEINLINE T p_rotr(T value, T shift) { + return _rotate_right(value, shift); } +} // namespace math +} // namespace sd -#endif //SD_PLATFORM_MATH_H +#endif // SD_PLATFORM_MATH_H diff --git a/libnd4j/include/math/templatemath.h b/libnd4j/include/math/templatemath.h index c220231d8627..f676b62d95e7 100644 --- a/libnd4j/include/math/templatemath.h +++ b/libnd4j/include/math/templatemath.h @@ -25,10 +25,10 @@ #ifndef TEMPLATEMATH_H_ #define TEMPLATEMATH_H_ +#include +#include #include #include -#include -#include #define BFLOAT16_MAX_VALUE 32737. #define HALF_MAX_VALUE 65504. @@ -49,932 +49,910 @@ namespace sd { #endif - namespace math { - template - math_def inline T nd4j_abs(T value); +namespace math { +template +math_def inline T nd4j_abs(T value); - template - math_def inline void nd4j_swap(T &val1, T &val2); +template +math_def inline void nd4j_swap(T& val1, T& val2); - template - math_def inline T nd4j_max(T val1, T val2); +template +math_def inline T nd4j_max(T val1, T val2); - template - math_def inline T nd4j_min(T val1, T val2); +template +math_def inline T nd4j_min(T val1, T val2); - template - math_def inline bool nd4j_eq(T val1, T val2, double eps); +template +math_def inline bool nd4j_eq(T val1, T val2, double eps); - template - math_def inline Z nd4j_re(T val1, T val2); +template +math_def inline Z nd4j_re(T val1, T val2); - template - math_def inline Z nd4j_rint(T val1); +template +math_def inline Z nd4j_rint(T val1); - template - math_def inline Z nd4j_copysign(T val1, T val2); +template +math_def inline Z nd4j_copysign(T val1, T val2); - template - math_def inline Z nd4j_softplus(T val); +template +math_def inline Z nd4j_softplus(T val); - template - math_def inline T nd4j_rotl(T val, T shift); +template +math_def inline T nd4j_rotl(T val, T shift); - template - math_def inline T nd4j_rotr(T val, T shift); +template +math_def inline T nd4j_rotr(T val, T shift); //#ifndef __CUDACC__ - template - math_def inline Z nd4j_dot(X *x, Y *y, int length); +template +math_def inline Z nd4j_dot(X* x, Y* y, int length); //#endif - template - math_def inline Z nd4j_ceil(T val1); +template +math_def inline Z nd4j_ceil(T val1); - template - math_def inline bool nd4j_isnan(T val1); +template +math_def inline bool nd4j_isnan(T val1); - template - math_def inline bool nd4j_isinf(T val1); +template +math_def inline bool nd4j_isinf(T val1); - template - math_def inline bool nd4j_isfin(T val1); +template +math_def inline bool nd4j_isfin(T val1); - template - math_def inline Z nd4j_cos(T val); +template +math_def inline Z nd4j_cos(T val); - template - math_def inline Z nd4j_cosh(T val); +template +math_def inline Z nd4j_cosh(T val); - template - math_def inline Z nd4j_exp(X val); +template +math_def inline Z nd4j_exp(X val); - template - math_def inline Z nd4j_floor(T val); +template +math_def inline Z nd4j_floor(T val); - template - math_def inline Z nd4j_log(X val); +template +math_def inline Z nd4j_log(X val); - template - math_def inline Z nd4j_pow(X val, Y val2); +template +math_def inline Z nd4j_pow(X val, Y val2); - template - math_def inline Z nd4j_round(T val); +template +math_def inline Z nd4j_round(T val); - template - math_def inline Z nd4j_remainder(X num, Y denom); +template +math_def inline Z nd4j_remainder(X num, Y denom); - template - math_def inline Z nd4j_fmod(X num, Y denom); +template +math_def inline Z nd4j_fmod(X num, Y denom); - template - math_def inline Z nd4j_erf(T num); +template +math_def inline Z nd4j_erf(T num); - template - math_def inline Z nd4j_erfc(T num); +template +math_def inline Z nd4j_erfc(T num); - math_def inline int32_t floatToRawIntBits(float d) { - union { - float f; - int32_t i; - } tmp; - tmp.f = d; - return tmp.i; - } +math_def inline int32_t floatToRawIntBits(float d) { + union { + float f; + int32_t i; + } tmp; + tmp.f = d; + return tmp.i; +} - math_def inline float intBitsToFloat(int32_t i) { - union { - float f; - int32_t i; - } tmp; - tmp.i = i; - return tmp.f; - } +math_def inline float intBitsToFloat(int32_t i) { + union { + float f; + int32_t i; + } tmp; + tmp.i = i; + return tmp.f; +} - math_def inline float mulsignf(float x, float y) { - return intBitsToFloat(floatToRawIntBits(x) ^ (floatToRawIntBits(y) & (1 << 31))); - } +math_def inline float mulsignf(float x, float y) { + return intBitsToFloat(floatToRawIntBits(x) ^ + (floatToRawIntBits(y) & (1 << 31))); +} - math_def inline float copysignfk(float x, float y) { - return intBitsToFloat((floatToRawIntBits(x) & ~(1 << 31)) ^ (floatToRawIntBits(y) & (1 << 31))); - } +math_def inline float copysignfk(float x, float y) { + return intBitsToFloat((floatToRawIntBits(x) & ~(1 << 31)) ^ + (floatToRawIntBits(y) & (1 << 31))); +} - template - math_def inline Z nd4j_sigmoid(T val) { - return (Z) 1.0f / ((Z) 1.0f + nd4j_exp(-val)); - } - - template - math_def inline Z nd4j_elu(T val, T alpha) { - if (val >= (T) 0.f) - return val; - return static_cast(alpha) * (nd4j_exp(val) - static_cast(1.0f)); - } - - template - math_def inline Z nd4j_leakyrelu(T val,T alpha) { - if (val < (T) 0.0f) - return alpha * val; - else - return val; - } - - template - math_def inline Z nd4j_eluderivative(T val, T alpha) { - if (val >= static_cast(0.0f)) - return static_cast(1.0f); - return static_cast(alpha) * nd4j_exp(val); - //return val >= 0.0 ? 1.0 : nd4j_exp(val); - } - - template - math_def inline Z nd4j_sin(T val); - - template - math_def inline Z nd4j_sinh(T val); - - template - math_def inline Z nd4j_softplus(T val) { - return nd4j_log((Z) 1.0f + nd4j_exp(val)); - } - - template - math_def inline Z nd4j_softsign(T val) { - return val / ((T) 1.0f + sd::math::nd4j_abs(val)); - } - - template - math_def inline Z nd4j_sqrt(X val); - - template - math_def inline Z nd4j_tanh(X val); - - template - math_def inline Z nd4j_tan(T val); - - template - math_def inline Z nd4j_atan2(X val1, X val2); - - template - math_def inline Z nd4j_atan2(X val1, X val2) { - return p_atan2(static_cast(val1), static_cast(val2)); - } - - - template - math_def inline Z nd4j_tan(T tval) { - return p_tan(static_cast(tval)); - } +template +math_def inline Z nd4j_sigmoid(T val) { + return (Z)1.0f / ((Z)1.0f + nd4j_exp(-val)); +} - template - math_def inline Z nd4j_tanhderivative(T val) { - Z tanh = nd4j_tanh(val); - return (Z) 1.0f - tanh * tanh; - } - template - math_def inline T nd4j_sigmoidderivative(T val) { - Z sigmoid = nd4j_sigmoid(val); - return sigmoid * ((Z) 1.0f - sigmoid); - } - - template - math_def inline T nd4j_softsignderivative(T val) { - T y = (T) 1.0f + nd4j_abs(val); - return (Z) 1.0f / (y * y); - } - - template - math_def inline T nd4j_sgn(T val) { - return val < (T) 0.0f ? (Z) -1.0f : val > (T) 0.0f ? (Z) 1.0f : (Z) 0.0f; - } +template +math_def inline Z nd4j_elu(T val, T alpha) { + if (val >= (T)0.f) return val; + return static_cast(alpha) * (nd4j_exp(val) - static_cast(1.0f)); +} - template - math_def inline Z nd4j_sign(T val) { - return nd4j_sgn(val); - } +template +math_def inline Z nd4j_leakyrelu(T val, T alpha) { + if (val < (T)0.0f) + return alpha * val; + else + return val; +} - template - math_def inline Z nd4j_signum(T val) { - return nd4j_sgn(val); - } +template +math_def inline Z nd4j_eluderivative(T val, T alpha) { + if (val >= static_cast(0.0f)) return static_cast(1.0f); + return static_cast(alpha) * nd4j_exp(val); + // return val >= 0.0 ? 1.0 : nd4j_exp(val); +} + +template +math_def inline Z nd4j_sin(T val); + +template +math_def inline Z nd4j_sinh(T val); + +template +math_def inline Z nd4j_softplus(T val) { + return nd4j_log((Z)1.0f + nd4j_exp(val)); +} + +template +math_def inline Z nd4j_softsign(T val) { + return val / ((T)1.0f + sd::math::nd4j_abs(val)); +} + +template +math_def inline Z nd4j_sqrt(X val); + +template +math_def inline Z nd4j_tanh(X val); + +template +math_def inline Z nd4j_tan(T val); + +template +math_def inline Z nd4j_atan2(X val1, X val2); + +template +math_def inline Z nd4j_atan2(X val1, X val2) { + return p_atan2(static_cast(val1), static_cast(val2)); +} - template - math_def inline Z nd4j_gamma(X a); +template +math_def inline Z nd4j_tan(T tval) { + return p_tan(static_cast(tval)); +} + +template +math_def inline Z nd4j_tanhderivative(T val) { + Z tanh = nd4j_tanh(val); + return (Z)1.0f - tanh * tanh; +} +template +math_def inline T nd4j_sigmoidderivative(T val) { + Z sigmoid = nd4j_sigmoid(val); + return sigmoid * ((Z)1.0f - sigmoid); +} + +template +math_def inline T nd4j_softsignderivative(T val) { + T y = (T)1.0f + nd4j_abs(val); + return (Z)1.0f / (y * y); +} - template - math_def inline Z nd4j_lgamma(X x); +template +math_def inline T nd4j_sgn(T val) { + return val < (T)0.0f ? (Z)-1.0f : val > (T)0.0f ? (Z)1.0f : (Z)0.0f; +} + +template +math_def inline Z nd4j_sign(T val) { + return nd4j_sgn(val); +} + +template +math_def inline Z nd4j_signum(T val) { + return nd4j_sgn(val); +} + +template +math_def inline Z nd4j_gamma(X a); + +template +math_def inline Z nd4j_lgamma(X x); //#ifndef __CUDACC__ /* template<> - math_def inline float16 nd4j_dot(float16 *x, float16 *y, int length) { - float16 dot = (float16) 0.0f; + math_def inline float16 nd4j_dot(float16 *x, float16 *y, int + length) { float16 dot = (float16) 0.0f; - // TODO: since we can't use simd on unions, we might use something else here. - for(int e = 0; e < length; e++) { - dot += x[e] * y[e]; + // TODO: since we can't use simd on unions, we might use something + else here. for(int e = 0; e < length; e++) { dot += x[e] * y[e]; } return dot; } */ - template - math_def inline Z nd4j_dot(X *x, Y *y, int length) { - Z dot = (Z)0.0f; +template +math_def inline Z nd4j_dot(X* x, Y* y, int length) { + Z dot = (Z)0.0f; - for(int e = 0; e < length; e++) { - dot += static_cast(x[e]) * static_cast(y[e]); - } + for (int e = 0; e < length; e++) { + dot += static_cast(x[e]) * static_cast(y[e]); + } - return dot; - } + return dot; +} //#endif - template - math_def inline Z nd4j_acos(T val); +template +math_def inline Z nd4j_acos(T val); - template - math_def inline Z nd4j_sech(T val); +template +math_def inline Z nd4j_sech(T val); - template - math_def inline Z nd4j_acosh(T val); +template +math_def inline Z nd4j_acosh(T val); - template - math_def inline Z nd4j_asin(T val); +template +math_def inline Z nd4j_asin(T val); - template - math_def inline Z nd4j_asinh(T val); - - template - math_def inline Z nd4j_asinh(T val) { - //Math.log(Math.sqrt(Math.pow(x, 2) + 1) + x) - return nd4j_log(nd4j_sqrt(nd4j_pow(val, (T) 2) + (Z) 1.f) + (Z) val); - } +template +math_def inline Z nd4j_asinh(T val); - template - math_def inline Z nd4j_atan(T val); +template +math_def inline Z nd4j_asinh(T val) { + // Math.log(Math.sqrt(Math.pow(x, 2) + 1) + x) + return nd4j_log(nd4j_sqrt(nd4j_pow(val, (T)2) + (Z)1.f) + + (Z)val); +} - template - math_def inline Z nd4j_atanh(T val); +template +math_def inline Z nd4j_atan(T val); +template +math_def inline Z nd4j_atanh(T val); - template<> - math_def inline float16 nd4j_abs(float16 value) { +template <> +math_def inline float16 nd4j_abs(float16 value) { #ifdef NATIVE_HALFS - if (value < (float16) 0.f) { - return float16(__hneg(value.data)); - } else - return value; + if (value < (float16)0.f) { + return float16(__hneg(value.data)); + } else + return value; #else - return (float16) fabsf((float) value); + return (float16)fabsf((float)value); #endif - } - template<> - math_def inline bfloat16 nd4j_abs(bfloat16 value) { - return (bfloat16) fabsf((float) value); - } - template<> - math_def inline float nd4j_abs(float value) { - return fabsf(value); - } - - template<> - math_def inline double nd4j_abs(double value) { - return fabs(value); - } - - template<> - math_def inline int nd4j_abs(int value) { - return abs(value); - } - - template<> - math_def inline Nd4jLong nd4j_abs(Nd4jLong value) { - return llabs(value); - } - - template<> - math_def inline bool nd4j_abs(bool value) { - return value; - } - - template<> - math_def inline uint8_t nd4j_abs(uint8_t value) { - return value; - } - - template<> - math_def inline uint16_t nd4j_abs(uint16_t value) { - return value; - } - - template<> - math_def inline uint32_t nd4j_abs(uint32_t value) { - return value; - } - - template<> - math_def inline Nd4jULong nd4j_abs(Nd4jULong value) { - return value; - } - - template<> - math_def inline int8_t nd4j_abs(int8_t value) { - return value < 0 ? -value : value; - } - - template<> - math_def inline int16_t nd4j_abs(int16_t value) { - return value < 0 ? -value : value; - } - - - template<> - math_def inline bool nd4j_isnan(float16 value) { - return *(value.data.getXP()) == 0x7fffU; - } - - template<> - math_def inline bool nd4j_isnan(bfloat16 value) { - return value == bfloat16::nan(); //0x7fffU; - } - - template<> - math_def inline bool nd4j_isnan(float value) { - return value != value; - } - - template<> - math_def inline bool nd4j_isnan(double value) { - return value != value; - } - - template<> - math_def inline bool nd4j_isnan(int value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(uint32_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(uint16_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(uint8_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(int16_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(int8_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(bool value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(Nd4jLong value) { - return false; - } - - template<> - math_def inline bool nd4j_isnan(Nd4jULong value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(float16 value) { - return value < (float16) -HALF_MAX_VALUE || value > (float16) HALF_MAX_VALUE; - } - - template<> - math_def inline bool nd4j_isinf(bfloat16 value) { - return value < (bfloat16) -BFLOAT16_MAX_VALUE || value > (bfloat16) BFLOAT16_MAX_VALUE; - } - - template<> - math_def inline bool nd4j_isinf(float value) { +} +template <> +math_def inline bfloat16 nd4j_abs(bfloat16 value) { + return (bfloat16)fabsf((float)value); +} +template <> +math_def inline float nd4j_abs(float value) { + return fabsf(value); +} + +template <> +math_def inline double nd4j_abs(double value) { + return fabs(value); +} + +template <> +math_def inline int nd4j_abs(int value) { + return abs(value); +} + +template <> +math_def inline Nd4jLong nd4j_abs(Nd4jLong value) { + return llabs(value); +} + +template <> +math_def inline bool nd4j_abs(bool value) { + return value; +} + +template <> +math_def inline uint8_t nd4j_abs(uint8_t value) { + return value; +} + +template <> +math_def inline uint16_t nd4j_abs(uint16_t value) { + return value; +} + +template <> +math_def inline uint32_t nd4j_abs(uint32_t value) { + return value; +} + +template <> +math_def inline Nd4jULong nd4j_abs(Nd4jULong value) { + return value; +} + +template <> +math_def inline int8_t nd4j_abs(int8_t value) { + return value < 0 ? -value : value; +} + +template <> +math_def inline int16_t nd4j_abs(int16_t value) { + return value < 0 ? -value : value; +} + +template <> +math_def inline bool nd4j_isnan(float16 value) { + return *(value.data.getXP()) == 0x7fffU; +} + +template <> +math_def inline bool nd4j_isnan(bfloat16 value) { + return value == bfloat16::nan(); // 0x7fffU; +} + +template <> +math_def inline bool nd4j_isnan(float value) { + return value != value; +} + +template <> +math_def inline bool nd4j_isnan(double value) { + return value != value; +} + +template <> +math_def inline bool nd4j_isnan(int value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(uint32_t value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(uint16_t value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(uint8_t value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(int16_t value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(int8_t value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(bool value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(Nd4jLong value) { + return false; +} + +template <> +math_def inline bool nd4j_isnan(Nd4jULong value) { + return false; +} + +template <> +math_def inline bool nd4j_isinf(float16 value) { + return value < (float16)-HALF_MAX_VALUE || value > (float16)HALF_MAX_VALUE; +} + +template <> +math_def inline bool nd4j_isinf(bfloat16 value) { + return value < (bfloat16)-BFLOAT16_MAX_VALUE || + value > (bfloat16)BFLOAT16_MAX_VALUE; +} + +template <> +math_def inline bool nd4j_isinf(float value) { #ifdef __CUDACC__ - return isinf(value); + return isinf(value); #else - return std::isinf(value); + return std::isinf(value); #endif - //return value < -FLOAT_MAX_VALUE || value > FLOAT_MAX_VALUE; - } + // return value < -FLOAT_MAX_VALUE || value > FLOAT_MAX_VALUE; +} - template<> - math_def inline bool nd4j_isinf(double value) { +template <> +math_def inline bool nd4j_isinf(double value) { #ifdef __CUDACC__ - return isinf(value); + return isinf(value); #else - return std::isinf(value); + return std::isinf(value); #endif - //return value < -DOUBLE_MAX_VALUE || value > DOUBLE_MAX_VALUE; - } - - template<> - math_def inline bool nd4j_isinf(int value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(uint32_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(uint16_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(uint8_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(int16_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(int8_t value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(bool value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(Nd4jLong value) { - return false; - } - - template<> - math_def inline bool nd4j_isinf(Nd4jULong value) { - return false; - } - - template - math_def inline bool nd4j_isfin(T value) { - return !nd4j_isnan(value) && !nd4j_isinf(value); - } - - template<> - math_def inline float16 nd4j_copysign(float16 val1, float16 val2) { - return (float16) copysignf((float) val1, (float) val2); - } - - template<> - math_def inline float nd4j_copysign(float val1, float val2) { - return copysignf(val1, val2); - } - - template<> - math_def inline double nd4j_copysign(double val1, double val2) { - return copysign(val1, val2); - } - - template<> - math_def inline int nd4j_copysign(int val1, int val2) { - if (val2 < 0) return -(nd4j_abs(val1)); - else return nd4j_abs(val1); - } - - template<> - math_def inline Nd4jLong nd4j_copysign(Nd4jLong val1, Nd4jLong val2) { - if (val2 < 0) return -(nd4j_abs(val1)); - else return nd4j_abs(val1); - } - - template<> - math_def inline bool nd4j_max(bool val1, bool val2) { - return (val1 || val2) ? true : false; - } - - template - math_def inline T nd4j_max(T val1, T val2) { - return val1 > val2 ? val1 : val2; - } - - template<> - math_def inline bool nd4j_min(bool val1, bool val2) { - return (val1 && val2) ? true : false; - } - - template - math_def inline T nd4j_min(T val1, T val2) { - return val1 < val2 ? val1 : val2; - } - - template - math_def inline bool nd4j_eq(T d1, T d2, double eps) { - if (sd::math::nd4j_isinf(d1) && sd::math::nd4j_isinf(d2)) { - if (d1 > 0 && d2 > 0) - return true; - else if (d1 < 0 && d2 < 0) - return true; - else - return false; - } - - auto diff = static_cast(sd::math::nd4j_abs(d1 - d2)); - - - // works well except in the range of very large numbers - if (diff <= eps) - return true; - - // Knuth approach - // works well except in the range of very small numbers - if (diff <= sd::math::nd4j_max(sd::math::nd4j_abs(static_cast(d1)), sd::math::nd4j_abs(static_cast(d2))) * eps) - return true; - - return false; - } - - template - math_def inline Z nd4j_ceil(X val) { - return static_cast(p_ceil(val)); - } - - template - math_def inline Z nd4j_round(X val) { - return static_cast(p_round(val)); - } + // return value < -DOUBLE_MAX_VALUE || value > DOUBLE_MAX_VALUE; +} - template - math_def inline Z nd4j_asin(X val) { - return p_asin(static_cast(val)); - } +template <> +math_def inline bool nd4j_isinf(int value) { + return false; +} - template - math_def inline Z nd4j_atan(X val) { - return p_atan(static_cast(val)); - } +template <> +math_def inline bool nd4j_isinf(uint32_t value) { + return false; +} - template - math_def inline Z nd4j_atanh(X val) { - return p_atanh(static_cast(val)); - } +template <> +math_def inline bool nd4j_isinf(uint16_t value) { + return false; +} - template - math_def inline Z nd4j_cosh(X val) { - return p_cosh(static_cast(val)); - } +template <> +math_def inline bool nd4j_isinf(uint8_t value) { + return false; +} - template - math_def inline Z nd4j_rint(X val) { - return p_rint(val); - } +template <> +math_def inline bool nd4j_isinf(int16_t value) { + return false; +} - template - math_def inline Z nd4j_sinh(X val) { - return p_sinh(static_cast(val)); - } +template <> +math_def inline bool nd4j_isinf(int8_t value) { + return false; +} - template - math_def inline Z nd4j_acos(X val) { - return p_acos(static_cast(val)); - } +template <> +math_def inline bool nd4j_isinf(bool value) { + return false; +} - template - math_def inline Z nd4j_sech(X val) { - return static_cast(1) / nd4j_cosh(val); - } +template <> +math_def inline bool nd4j_isinf(Nd4jLong value) { + return false; +} - template - math_def inline Z nd4j_acosh(X val) { - return p_acosh(static_cast(val)); - } +template <> +math_def inline bool nd4j_isinf(Nd4jULong value) { + return false; +} - template - math_def inline Z nd4j_cos(X val) { - return p_cos(static_cast(val)); - } +template +math_def inline bool nd4j_isfin(T value) { + return !nd4j_isnan(value) && !nd4j_isinf(value); +} - template - math_def inline Z nd4j_exp(X val) { - return p_exp(val); - } +template <> +math_def inline float16 nd4j_copysign(float16 val1, float16 val2) { + return (float16)copysignf((float)val1, (float)val2); +} - template - math_def inline Z nd4j_floor(X val) { - return static_cast(p_floor(val)); - } - - template - math_def inline Z nd4j_log(X val) { - return static_cast(p_log(val)); - } - - /** - * This func is special case - it must return floating point value, and optionally Y arg can be floating point argument - * @tparam X - * @tparam Y - * @tparam Z - * @param val - * @param val2 - * @return - */ - template <> - math_def inline float nd4j_pow(float val, float val2) { - return p_pow(val, val2); - } +template <> +math_def inline float nd4j_copysign(float val1, float val2) { + return copysignf(val1, val2); +} - template - math_def inline Z nd4j_pow(X val, Y val2) { - return p_pow(static_cast(val), static_cast(val2)); - } - - /** - * LogGamma(a) - float point extension of ln(n!) - **/ - template - math_def inline Z nd4j_lgamma(X x) { -// if (x <= X(0.0)) -// { -// std::stringstream os; -// os << "Logarithm of Gamma has sence only for positive values, but " << x << " was given."; -// throw std::invalid_argument( os.str() ); -// } - - if (x < X(12.0)) { - return nd4j_log(nd4j_gamma(x)); - } +template <> +math_def inline double nd4j_copysign(double val1, double val2) { + return copysign(val1, val2); +} - // Abramowitz and Stegun 6.1.41 - // Asymptotic series should be good to at least 11 or 12 figures - // For error analysis, see Whittiker and Watson - // A Course in Modern Analysis (1927), page 252 - - static const double c[8] = { - 1.0/12.0, - -1.0/360.0, - 1.0/1260.0, - -1.0/1680.0, - 1.0/1188.0, - -691.0/360360.0, - 1.0/156.0, - -3617.0/122400.0 - }; - - double z = Z(1.0 / Z(x * x)); - double sum = c[7]; - - for (int i = 6; i >= 0; i--) { - sum *= z; - sum += c[i]; - } +template <> +math_def inline int nd4j_copysign(int val1, int val2) { + if (val2 < 0) + return -(nd4j_abs(val1)); + else + return nd4j_abs(val1); +} - double series = sum / Z(x); +template <> +math_def inline Nd4jLong nd4j_copysign(Nd4jLong val1, Nd4jLong val2) { + if (val2 < 0) + return -(nd4j_abs(val1)); + else + return nd4j_abs(val1); +} - static const double halfLogTwoPi = 0.91893853320467274178032973640562; +template <> +math_def inline bool nd4j_max(bool val1, bool val2) { + return (val1 || val2) ? true : false; +} - return Z((double(x) - 0.5) * nd4j_log(x) - double(x) + halfLogTwoPi + series); - } +template +math_def inline T nd4j_max(T val1, T val2) { + return val1 > val2 ? val1 : val2; +} +template <> +math_def inline bool nd4j_min(bool val1, bool val2) { + return (val1 && val2) ? true : false; +} +template +math_def inline T nd4j_min(T val1, T val2) { + return val1 < val2 ? val1 : val2; +} - template - math_def inline T nd4j_re(T val1, T val2) { - if (val1 == (T) 0.0f && val2 == (T) 0.0f) - return (T) 0.0f; +template +math_def inline bool nd4j_eq(T d1, T d2, double eps) { + if (sd::math::nd4j_isinf(d1) && sd::math::nd4j_isinf(d2)) { + if (d1 > 0 && d2 > 0) + return true; + else if (d1 < 0 && d2 < 0) + return true; + else + return false; + } + + auto diff = static_cast(sd::math::nd4j_abs(d1 - d2)); + + // works well except in the range of very large numbers + if (diff <= eps) return true; + + // Knuth approach + // works well except in the range of very small numbers + if (diff <= sd::math::nd4j_max( + sd::math::nd4j_abs(static_cast(d1)), + sd::math::nd4j_abs(static_cast(d2))) * + eps) + return true; + + return false; +} - return nd4j_abs(val1 - val2) / (nd4j_abs(val1) + nd4j_abs(val2)); - } +template +math_def inline Z nd4j_ceil(X val) { + return static_cast(p_ceil(val)); +} +template +math_def inline Z nd4j_round(X val) { + return static_cast(p_round(val)); +} - template - math_def inline Z nd4j_remainder(X val, Y val2) { - return p_remainder(static_cast(val), static_cast(val2)); - } +template +math_def inline Z nd4j_asin(X val) { + return p_asin(static_cast(val)); +} + +template +math_def inline Z nd4j_atan(X val) { + return p_atan(static_cast(val)); +} - template - math_def inline Z nd4j_fmod(X val, Y val2) { - return p_fmod(static_cast(val), static_cast(val2)); - } +template +math_def inline Z nd4j_atanh(X val) { + return p_atanh(static_cast(val)); +} +template +math_def inline Z nd4j_cosh(X val) { + return p_cosh(static_cast(val)); +} - template - math_def inline Z nd4j_sin(X val) { - return p_sin(static_cast(val)); - } +template +math_def inline Z nd4j_rint(X val) { + return p_rint(val); +} +template +math_def inline Z nd4j_sinh(X val) { + return p_sinh(static_cast(val)); +} - template - math_def inline Z nd4j_sqrt(X val) { - return p_sqrt(static_cast(val)); - } +template +math_def inline Z nd4j_acos(X val) { + return p_acos(static_cast(val)); +} +template +math_def inline Z nd4j_sech(X val) { + return static_cast(1) / nd4j_cosh(val); +} - template - math_def inline X neg_tanh(X val) { - X o = static_cast(1.0f); - X t = static_cast(2.0f); - X e = static_cast(M_E); +template +math_def inline Z nd4j_acosh(X val) { + return p_acosh(static_cast(val)); +} - auto p = sd::math::nd4j_pow(e, val * t); - return (p - o)/ (p + o); - } +template +math_def inline Z nd4j_cos(X val) { + return p_cos(static_cast(val)); +} - template - math_def inline X pos_tanh(X val) { - X o = static_cast(1.0f); - X t = static_cast(-2.0f); - X e = static_cast(M_E); +template +math_def inline Z nd4j_exp(X val) { + return p_exp(val); +} - auto p = sd::math::nd4j_pow(e, val * t); - return (o - p) / (o + p); - } +template +math_def inline Z nd4j_floor(X val) { + return static_cast(p_floor(val)); +} +template +math_def inline Z nd4j_log(X val) { + return static_cast(p_log(val)); +} - math_def inline float neu_tanh(float val, float sign) { - float e(M_E); - float av = sign * val; - auto p = sd::math::nd4j_pow(e, -av * 2.f); - return (1 - p) / (1 + p); - } +/** + * This func is special case - it must return floating point value, and + * optionally Y arg can be floating point argument + * @tparam X + * @tparam Y + * @tparam Z + * @param val + * @param val2 + * @return + */ +template <> +math_def inline float nd4j_pow(float val, float val2) { + return p_pow(val, val2); +} - template <> - math_def inline float nd4j_tanh(float val) { - float sign = copysignfk(1.0f, val); - return sign * neu_tanh(val, sign); - } +template +math_def inline Z nd4j_pow(X val, Y val2) { + return p_pow(static_cast(val), static_cast(val2)); +} +/** + * LogGamma(a) - float point extension of ln(n!) + **/ +template +math_def inline Z nd4j_lgamma(X x) { + // if (x <= X(0.0)) + // { + // std::stringstream os; + // os << "Logarithm of Gamma has sence only for positive + // values, but " << x << " was given."; throw + // std::invalid_argument( os.str() ); + // } + + if (x < X(12.0)) { + return nd4j_log(nd4j_gamma(x)); + } + + // Abramowitz and Stegun 6.1.41 + // Asymptotic series should be good to at least 11 or 12 figures + // For error analysis, see Whittiker and Watson + // A Course in Modern Analysis (1927), page 252 + + static const double c[8] = { + 1.0 / 12.0, -1.0 / 360.0, 1.0 / 1260.0, -1.0 / 1680.0, + 1.0 / 1188.0, -691.0 / 360360.0, 1.0 / 156.0, -3617.0 / 122400.0}; + + double z = Z(1.0 / Z(x * x)); + double sum = c[7]; + + for (int i = 6; i >= 0; i--) { + sum *= z; + sum += c[i]; + } + + double series = sum / Z(x); + + static const double halfLogTwoPi = 0.91893853320467274178032973640562; + + return Z((double(x) - 0.5) * nd4j_log(x) - double(x) + + halfLogTwoPi + series); +} - template - math_def inline Z nd4j_tanh(X val) { - return val <= 0 ? neg_tanh(val) : pos_tanh(val); - } +template +math_def inline T nd4j_re(T val1, T val2) { + if (val1 == (T)0.0f && val2 == (T)0.0f) return (T)0.0f; - template - math_def inline T nd4j_rotl(T val, T shift) { - return p_rotl(val, shift); - } + return nd4j_abs(val1 - val2) / (nd4j_abs(val1) + nd4j_abs(val2)); +} - template - math_def inline T nd4j_rotr(T val, T shift) { - return p_rotr(val, shift); - } +template +math_def inline Z nd4j_remainder(X val, Y val2) { + return p_remainder(static_cast(val), static_cast(val2)); +} - template - math_def inline Z nd4j_erf(X val) { - return p_erf(static_cast(val)); - } +template +math_def inline Z nd4j_fmod(X val, Y val2) { + return p_fmod(static_cast(val), static_cast(val2)); +} +template +math_def inline Z nd4j_sin(X val) { + return p_sin(static_cast(val)); +} - template - math_def inline Z nd4j_erfc(X val) { - return p_erfc(static_cast(val)); - } +template +math_def inline Z nd4j_sqrt(X val) { + return p_sqrt(static_cast(val)); +} - template - math_def inline void nd4j_swap(T &val1, T &val2) { - T temp = val1; val1=val2; val2=temp; - }; - - template - math_def inline Z nd4j_gamma(X a) { -// nd4j_lgamma(a); -// return (Z)std::tgamma(a); - // Split the function domain into three intervals: - // (0, 0.001), [0.001, 12), and (12, infinity) - - /////////////////////////////////////////////////////////////////////////// - // First interval: (0, 0.001) - // - // For small a, 1/Gamma(a) has power series a + gamma a^2 - ... - // So in this range, 1/Gamma(a) = a + gamma a^2 with error on the order of a^3. - // The relative error over this interval is less than 6e-7. - - const double eulerGamma = 0.577215664901532860606512090; // Euler's gamma constant - - if (a < X(0.001)) - return Z(1.0 / ((double)a * (1.0 + eulerGamma * (double)a))); - - /////////////////////////////////////////////////////////////////////////// - // Second interval: [0.001, 12) - - if (a < X(12.0)) { - // The algorithm directly approximates gamma over (1,2) and uses - // reduction identities to reduce other arguments to this interval. - - double y = (double)a; - int n = 0; - bool argWasLessThanOne = y < 1.0; - - // Add or subtract integers as necessary to bring y into (1,2) - // Will correct for this below - if (argWasLessThanOne) { - y += 1.0; - } - else { - n = static_cast(floor(y)) - 1; // will use n later - y -= n; - } - - // numerator coefficients for approximation over the interval (1,2) - static const double p[] = { - -1.71618513886549492533811E+0, - 2.47656508055759199108314E+1, - -3.79804256470945635097577E+2, - 6.29331155312818442661052E+2, - 8.66966202790413211295064E+2, - -3.14512729688483675254357E+4, - -3.61444134186911729807069E+4, - 6.64561438202405440627855E+4 - }; - - // denominator coefficients for approximation over the interval (1,2) - static const double q[] = { - -3.08402300119738975254353E+1, - 3.15350626979604161529144E+2, - -1.01515636749021914166146E+3, - -3.10777167157231109440444E+3, - 2.25381184209801510330112E+4, - 4.75584627752788110767815E+3, - -1.34659959864969306392456E+5, - -1.15132259675553483497211E+5 - }; - - double num = 0.0; - double den = 1.0; - - - double z = y - 1; - for (auto i = 0; i < 8; i++) { - num = (num + p[i]) * z; - den = den * z + q[i]; - } - double result = num / den + 1.0; - - // Apply correction if argument was not initially in (1,2) - if (argWasLessThanOne) { - // Use identity gamma(z) = gamma(z+1)/z - // The variable "result" now holds gamma of the original y + 1 - // Thus we use y-1 to get back the orginal y. - result /= (y - 1.0); - } - else { - // Use the identity gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z) - for (auto i = 0; i < n; i++) - result *= y++; - } - - return Z(result); - } +template +math_def inline X neg_tanh(X val) { + X o = static_cast(1.0f); + X t = static_cast(2.0f); + X e = static_cast(M_E); - /////////////////////////////////////////////////////////////////////////// - // Third interval: [12, infinity) + auto p = sd::math::nd4j_pow(e, val * t); + return (p - o) / (p + o); +} - if (a > 171.624) { - // Correct answer too large to display. Force +infinity. - return Z(DOUBLE_MAX_VALUE); -// return DataTypeUtils::infOrMax(); - } +template +math_def inline X pos_tanh(X val) { + X o = static_cast(1.0f); + X t = static_cast(-2.0f); + X e = static_cast(M_E); - return sd::math::nd4j_exp(sd::math::nd4j_lgamma(a)); - } + auto p = sd::math::nd4j_pow(e, val * t); + return (o - p) / (o + p); +} - template - math_def inline Z nd4j_igamma(X a, Y x) { - Z aim = nd4j_pow(x, a) / (nd4j_exp(x) * nd4j_gamma(a)); - auto sum = Z(0.); - auto denom = Z(1.); - if (a <= X(0.000001)) - //throw std::runtime_error("Cannot calculate gamma for a zero val."); - return Z(0); - - for (int i = 0; Z(1./denom) > Z(1.0e-12); i++) { - denom *= (a + i); - sum += nd4j_pow(x, i) / denom; - } - return aim * sum; - } +math_def inline float neu_tanh(float val, float sign) { + float e(M_E); + float av = sign * val; + auto p = sd::math::nd4j_pow(e, -av * 2.f); + return (1 - p) / (1 + p); +} - template - math_def inline Z nd4j_igammac(X a, Y x) { - return Z(1.) - nd4j_igamma(a, x); - } +template <> +math_def inline float nd4j_tanh(float val) { + float sign = copysignfk(1.0f, val); + return sign * neu_tanh(val, sign); +} + +template +math_def inline Z nd4j_tanh(X val) { + return val <= 0 ? neg_tanh(val) : pos_tanh(val); +} + +template +math_def inline T nd4j_rotl(T val, T shift) { + return p_rotl(val, shift); +} + +template +math_def inline T nd4j_rotr(T val, T shift) { + return p_rotr(val, shift); +} + +template +math_def inline Z nd4j_erf(X val) { + return p_erf(static_cast(val)); +} + +template +math_def inline Z nd4j_erfc(X val) { + return p_erfc(static_cast(val)); +} + +template +math_def inline void nd4j_swap(T& val1, T& val2) { + T temp = val1; + val1 = val2; + val2 = temp; +}; + +template +math_def inline Z nd4j_gamma(X a) { + // nd4j_lgamma(a); + // return (Z)std::tgamma(a); + // Split the function domain into three intervals: + // (0, 0.001), [0.001, 12), and (12, infinity) + + /////////////////////////////////////////////////////////////////////////// + // First interval: (0, 0.001) + // + // For small a, 1/Gamma(a) has power series a + gamma a^2 - ... + // So in this range, 1/Gamma(a) = a + gamma a^2 with error on the order of + // a^3. The relative error over this interval is less than 6e-7. + + const double eulerGamma = + 0.577215664901532860606512090; // Euler's gamma constant + + if (a < X(0.001)) + return Z(1.0 / ((double)a * (1.0 + eulerGamma * (double)a))); + + /////////////////////////////////////////////////////////////////////////// + // Second interval: [0.001, 12) + + if (a < X(12.0)) { + // The algorithm directly approximates gamma over (1,2) and uses + // reduction identities to reduce other arguments to this interval. + + double y = (double)a; + int n = 0; + bool argWasLessThanOne = y < 1.0; + + // Add or subtract integers as necessary to bring y into (1,2) + // Will correct for this below + if (argWasLessThanOne) { + y += 1.0; + } else { + n = static_cast(floor(y)) - 1; // will use n later + y -= n; + } + + // numerator coefficients for approximation over the interval (1,2) + static const double p[] = { + -1.71618513886549492533811E+0, 2.47656508055759199108314E+1, + -3.79804256470945635097577E+2, 6.29331155312818442661052E+2, + 8.66966202790413211295064E+2, -3.14512729688483675254357E+4, + -3.61444134186911729807069E+4, 6.64561438202405440627855E+4}; + + // denominator coefficients for approximation over the interval (1,2) + static const double q[] = { + -3.08402300119738975254353E+1, 3.15350626979604161529144E+2, + -1.01515636749021914166146E+3, -3.10777167157231109440444E+3, + 2.25381184209801510330112E+4, 4.75584627752788110767815E+3, + -1.34659959864969306392456E+5, -1.15132259675553483497211E+5}; + + double num = 0.0; + double den = 1.0; + + double z = y - 1; + for (auto i = 0; i < 8; i++) { + num = (num + p[i]) * z; + den = den * z + q[i]; + } + double result = num / den + 1.0; + + // Apply correction if argument was not initially in (1,2) + if (argWasLessThanOne) { + // Use identity gamma(z) = gamma(z+1)/z + // The variable "result" now holds gamma of the original y + 1 + // Thus we use y-1 to get back the orginal y. + result /= (y - 1.0); + } else { + // Use the identity gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z) + for (auto i = 0; i < n; i++) result *= y++; + } + + return Z(result); + } + + /////////////////////////////////////////////////////////////////////////// + // Third interval: [12, infinity) + + if (a > 171.624) { + // Correct answer too large to display. Force +infinity. + return Z(DOUBLE_MAX_VALUE); + // return DataTypeUtils::infOrMax(); + } + + return sd::math::nd4j_exp(sd::math::nd4j_lgamma(a)); +} + +template +math_def inline Z nd4j_igamma(X a, Y x) { + Z aim = nd4j_pow(x, a) / (nd4j_exp(x) * nd4j_gamma(a)); + auto sum = Z(0.); + auto denom = Z(1.); + if (a <= X(0.000001)) + // throw std::runtime_error("Cannot calculate gamma for a zero val."); + return Z(0); + + for (int i = 0; Z(1. / denom) > Z(1.0e-12); i++) { + denom *= (a + i); + sum += nd4j_pow(x, i) / denom; + } + return aim * sum; +} + +template +math_def inline Z nd4j_igammac(X a, Y x) { + return Z(1.) - nd4j_igamma(a, x); +} #ifdef __CUDACC__ - namespace atomics { +namespace atomics { template inline __device__ T nd4j_atomicAdd(T* address, T val); @@ -991,712 +969,748 @@ template inline __device__ T nd4j_atomicMax(T* address, T val); template <> -inline __device__ int32_t nd4j_atomicMin(int32_t* address, int32_t val) { - return atomicMin(address, val); +inline __device__ int32_t nd4j_atomicMin(int32_t* address, + int32_t val) { + return atomicMin(address, val); } template <> -inline __device__ uint32_t nd4j_atomicMin(uint32_t* address, uint32_t val) { - return atomicMin(address, val); +inline __device__ uint32_t nd4j_atomicMin(uint32_t* address, + uint32_t val) { + return atomicMin(address, val); } template <> -inline __device__ float nd4j_atomicMin(float* address, float val) { - int* address_as_ull = (int*)address; - int old = __float_as_int(val), assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, __float_as_int(math::nd4j_min(val, __int_as_float(assumed)))); - } while (assumed != old); - return __int_as_float(old); +inline __device__ float nd4j_atomicMin(float* address, float val) { + int* address_as_ull = (int*)address; + int old = __float_as_int(val), assumed; + do { + assumed = old; + old = + atomicCAS(address_as_ull, assumed, + __float_as_int(math::nd4j_min(val, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); } template <> -inline __device__ double nd4j_atomicMin(double* address, double val) { - unsigned long long int* address_as_ull = (unsigned long long int*)address; - unsigned long long int old = __double_as_longlong(val), assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, __double_as_longlong(math::nd4j_min(val, __longlong_as_double(assumed)))); - } while (assumed != old); - return __longlong_as_double(old); +inline __device__ double nd4j_atomicMin(double* address, double val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = __double_as_longlong(val), assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong( + math::nd4j_min(val, __longlong_as_double(assumed)))); + } while (assumed != old); + return __longlong_as_double(old); } template <> -inline __device__ uint64_t nd4j_atomicMin(uint64_t* address, uint64_t val) { +inline __device__ uint64_t nd4j_atomicMin(uint64_t* address, + uint64_t val) { #if __CUDA_ARCH__ >= 350 - return atomicMin((unsigned long long*)address, (unsigned long long)val); + return atomicMin((unsigned long long*)address, (unsigned long long)val); #else - unsigned long long int* address_as_ull = (unsigned long long int*)address; - unsigned long long int old = __double_as_longlong(val), assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, math::nd4j_min((unsigned long long)val, assumed)); - } while (assumed != old); - return old; + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = __double_as_longlong(val), assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + math::nd4j_min((unsigned long long)val, assumed)); + } while (assumed != old); + return old; #endif } template <> -inline __device__ Nd4jLong nd4j_atomicMin(Nd4jLong* address, Nd4jLong val) { - - #if __CUDA_ARCH__ >= 350 - return atomicMin((unsigned long long*)address, (unsigned long long)val); - #else - unsigned long long int* address_as_ull = (unsigned long long int*)address; - unsigned long long int old = (unsigned long long)val, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, math::nd4j_min(val, (Nd4jLong)assumed)); - } while (assumed != old); - return old; +inline __device__ Nd4jLong nd4j_atomicMin(Nd4jLong* address, + Nd4jLong val) { +#if __CUDA_ARCH__ >= 350 + return atomicMin((unsigned long long*)address, (unsigned long long)val); +#else + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = (unsigned long long)val, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + math::nd4j_min(val, (Nd4jLong)assumed)); + } while (assumed != old); + return old; #endif - } template <> -inline __device__ int16_t nd4j_atomicMin(int16_t* address, int16_t val) { - int32_t temp = *address; - *address = atomicMin(&temp, (int)val); - return *address; +inline __device__ int16_t nd4j_atomicMin(int16_t* address, + int16_t val) { + int32_t temp = *address; + *address = atomicMin(&temp, (int)val); + return *address; } template <> -inline __device__ bfloat16 nd4j_atomicMin(bfloat16* address, bfloat16 val) { - return bfloat16(nd4j_atomicMin(&address->_data, val._data)); +inline __device__ bfloat16 nd4j_atomicMin(bfloat16* address, + bfloat16 val) { + return bfloat16(nd4j_atomicMin(&address->_data, val._data)); } template <> -inline __device__ float16 nd4j_atomicMin(float16* address, float16 val) { - return float16(nd4j_atomicMin(reinterpret_cast(&address->data), (int16_t)val.data)); +inline __device__ float16 nd4j_atomicMin(float16* address, + float16 val) { + return float16(nd4j_atomicMin( + reinterpret_cast(&address->data), (int16_t)val.data)); } template <> -inline __device__ int32_t nd4j_atomicMax(int32_t* address, int32_t val) { - return atomicMax(address, val); +inline __device__ int32_t nd4j_atomicMax(int32_t* address, + int32_t val) { + return atomicMax(address, val); } template <> -inline __device__ uint32_t nd4j_atomicMax(uint32_t* address, uint32_t val) { - return atomicMax(address, val); +inline __device__ uint32_t nd4j_atomicMax(uint32_t* address, + uint32_t val) { + return atomicMax(address, val); } template <> -inline __device__ double nd4j_atomicMax(double* address, double val) { - unsigned long long int* address_as_ull = (unsigned long long int*)address; - unsigned long long int old = __double_as_longlong(val), assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, __double_as_longlong(math::nd4j_max(val, __longlong_as_double(assumed)))); - } while (assumed != old); - return __longlong_as_double(old); +inline __device__ double nd4j_atomicMax(double* address, double val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = __double_as_longlong(val), assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong( + math::nd4j_max(val, __longlong_as_double(assumed)))); + } while (assumed != old); + return __longlong_as_double(old); } template <> -inline __device__ float nd4j_atomicMax(float* address, float val) { - int* address_as_ull = (int*)address; - int old = __float_as_int(val), assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, __float_as_int(math::nd4j_max(val, __int_as_float(assumed)))); - } while (assumed != old); - return __int_as_float(old); +inline __device__ float nd4j_atomicMax(float* address, float val) { + int* address_as_ull = (int*)address; + int old = __float_as_int(val), assumed; + do { + assumed = old; + old = + atomicCAS(address_as_ull, assumed, + __float_as_int(math::nd4j_max(val, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); } template <> -inline __device__ uint8_t nd4j_atomicMin(uint8_t* address, uint8_t val) { - uint32_t temp = *address; - *address = atomicMin(&temp, (uint32_t)val); - return *address; +inline __device__ uint8_t nd4j_atomicMin(uint8_t* address, + uint8_t val) { + uint32_t temp = *address; + *address = atomicMin(&temp, (uint32_t)val); + return *address; } template <> -inline __device__ int8_t nd4j_atomicMin(int8_t* address, int8_t val) { - int32_t temp = *address; - *address = atomicMin(&temp, (int)val); - return *address; +inline __device__ int8_t nd4j_atomicMin(int8_t* address, int8_t val) { + int32_t temp = *address; + *address = atomicMin(&temp, (int)val); + return *address; } template <> -inline __device__ uint16_t nd4j_atomicMin(uint16_t* address, uint16_t val) { - uint32_t temp = *address; - *address = atomicMin(&temp, (uint32_t)val); - return *address; +inline __device__ uint16_t nd4j_atomicMin(uint16_t* address, + uint16_t val) { + uint32_t temp = *address; + *address = atomicMin(&temp, (uint32_t)val); + return *address; } template <> -inline __device__ uint8_t nd4j_atomicMax(uint8_t* address, uint8_t val) { - uint32_t temp = *address; - *address = atomicMax(&temp, (uint32_t)val); - return *address; +inline __device__ uint8_t nd4j_atomicMax(uint8_t* address, + uint8_t val) { + uint32_t temp = *address; + *address = atomicMax(&temp, (uint32_t)val); + return *address; } template <> -inline __device__ int8_t nd4j_atomicMax(int8_t* address, int8_t val) { - int32_t temp = *address; - *address = atomicMax(&temp, (int)val); - return *address; +inline __device__ int8_t nd4j_atomicMax(int8_t* address, int8_t val) { + int32_t temp = *address; + *address = atomicMax(&temp, (int)val); + return *address; } template <> -inline __device__ uint16_t nd4j_atomicMax(uint16_t* address, uint16_t val) { - uint32_t temp = *address; - *address = atomicMax(&temp, (uint32_t)val); - return *address; +inline __device__ uint16_t nd4j_atomicMax(uint16_t* address, + uint16_t val) { + uint32_t temp = *address; + *address = atomicMax(&temp, (uint32_t)val); + return *address; } template <> -inline __device__ int16_t nd4j_atomicMax(int16_t* address, int16_t val) { - int32_t temp = *address; - *address = atomicMax(&temp, (int32_t)val); - return *address; +inline __device__ int16_t nd4j_atomicMax(int16_t* address, + int16_t val) { + int32_t temp = *address; + *address = atomicMax(&temp, (int32_t)val); + return *address; } template <> -inline __device__ float16 nd4j_atomicMax(float16* address, float16 val) { - auto address_as_ull = (int*) address; - - long addr = (long) address; - bool misaligned = addr & 0x3; +inline __device__ float16 nd4j_atomicMax(float16* address, + float16 val) { + auto address_as_ull = (int*)address; - if (misaligned) - address_as_ull = (int *) (address - 1); + long addr = (long)address; + bool misaligned = addr & 0x3; - PAIR old, assumed, fresh; + if (misaligned) address_as_ull = (int*)(address - 1); - old.W = *address_as_ull; - do { + PAIR old, assumed, fresh; - if (!misaligned) { - float16 res = nd4j_max((float16) old.B.H, val); - fresh.B.H = res.data; - fresh.B.L = old.B.L; - } else { - float16 res = nd4j_max((float16) old.B.L, val); - fresh.B.L = res.data; - fresh.B.H = old.B.H; - } + old.W = *address_as_ull; + do { + if (!misaligned) { + float16 res = nd4j_max((float16)old.B.H, val); + fresh.B.H = res.data; + fresh.B.L = old.B.L; + } else { + float16 res = nd4j_max((float16)old.B.L, val); + fresh.B.L = res.data; + fresh.B.H = old.B.H; + } - assumed.W = old.W; - old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); - } while (assumed.W != old.W); + assumed.W = old.W; + old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); + } while (assumed.W != old.W); - if (!misaligned) return old.B.H; - else return old.B.L; + if (!misaligned) + return old.B.H; + else + return old.B.L; } template <> -inline __device__ bfloat16 nd4j_atomicMax(bfloat16* address, bfloat16 val) { - auto address_as_ull = (int*) address; +inline __device__ bfloat16 nd4j_atomicMax(bfloat16* address, + bfloat16 val) { + auto address_as_ull = (int*)address; - long addr = (long)(address); - bool misaligned = addr & 0x3; + long addr = (long)(address); + bool misaligned = addr & 0x3; - if (misaligned) - address_as_ull = (int *) (address - 1); + if (misaligned) address_as_ull = (int*)(address - 1); - BPAIR old, assumed, fresh; + BPAIR old, assumed, fresh; - old.W = *address_as_ull; - do { - - if (!misaligned) { - bfloat16 res = nd4j_max(old.B.H, val); - fresh.B.H = res; - fresh.B.L = old.B.L; - } else { - bfloat16 res = nd4j_max(old.B.L, val); - fresh.B.L = res; - fresh.B.H = old.B.H; - } + old.W = *address_as_ull; + do { + if (!misaligned) { + bfloat16 res = nd4j_max(old.B.H, val); + fresh.B.H = res; + fresh.B.L = old.B.L; + } else { + bfloat16 res = nd4j_max(old.B.L, val); + fresh.B.L = res; + fresh.B.H = old.B.H; + } - assumed.W = old.W; - old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); - } while (assumed.W != old.W); + assumed.W = old.W; + old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); + } while (assumed.W != old.W); - if (!misaligned) return old.B.H; - else return old.B.L; + if (!misaligned) + return old.B.H; + else + return old.B.L; } template <> -inline __device__ uint64_t nd4j_atomicMax(uint64_t* address, uint64_t val) { +inline __device__ uint64_t nd4j_atomicMax(uint64_t* address, + uint64_t val) { #if __CUDA_ARCH__ >= 350 - return atomicMax((unsigned long long*)address, (unsigned long long)val); + return atomicMax((unsigned long long*)address, (unsigned long long)val); #else - unsigned long long int* address_as_ull = (unsigned long long int*)address; - unsigned long long int old = __double_as_longlong(val), assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, math::nd4j_max((unsigned long long)val, assumed)); - } while (assumed != old); - return old; + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = __double_as_longlong(val), assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + math::nd4j_max((unsigned long long)val, assumed)); + } while (assumed != old); + return old; #endif } template <> -inline __device__ Nd4jLong nd4j_atomicMax(Nd4jLong* address, Nd4jLong val) { - unsigned long long int* address_as_ull = (unsigned long long int *) address; - - //return (Nd4jLong) atomicAdd(address_as_ull, (unsigned long long int) val); - unsigned long long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, (unsigned long long)nd4j_max(val, (Nd4jLong)assumed)); - } while (assumed != old); - return old; +inline __device__ Nd4jLong nd4j_atomicMax(Nd4jLong* address, + Nd4jLong val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + + // return (Nd4jLong) atomicAdd(address_as_ull, (unsigned long long int) val); + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + (unsigned long long)nd4j_max(val, (Nd4jLong)assumed)); + } while (assumed != old); + return old; } - template <> -inline __device__ double nd4j_atomicAdd(double* address, double val) { - unsigned long long int* address_as_ull = - (unsigned long long int *) address; - unsigned long long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed,__double_as_longlong(val + - __longlong_as_double(assumed))); - } while (assumed != old); - return __longlong_as_double(old); +inline __device__ double nd4j_atomicAdd(double* address, double val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val + __longlong_as_double(assumed))); + } while (assumed != old); + return __longlong_as_double(old); } template <> -inline __device__ Nd4jLong nd4j_atomicAdd(Nd4jLong* address, Nd4jLong val) { - unsigned long long int* address_as_ull = (unsigned long long int *) address; - - //return (Nd4jLong) atomicAdd(address_as_ull, (unsigned long long int) val); - unsigned long long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, val + assumed); - } while (assumed != old); - return old; +inline __device__ Nd4jLong nd4j_atomicAdd(Nd4jLong* address, + Nd4jLong val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + + // return (Nd4jLong) atomicAdd(address_as_ull, (unsigned long long int) val); + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, val + assumed); + } while (assumed != old); + return old; } template <> -inline __device__ long nd4j_atomicAdd(long* address, long val) { - unsigned long long* address_as_ull = (unsigned long long int *) address; - -// return atomicAdd(address, val); - unsigned long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, val + assumed); - } while (assumed != old); - return old; +inline __device__ long nd4j_atomicAdd(long* address, long val) { + unsigned long long* address_as_ull = (unsigned long long int*)address; + + // return atomicAdd(address, val); + unsigned long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, val + assumed); + } while (assumed != old); + return old; } template <> -inline __device__ uint32_t nd4j_atomicAdd(uint32_t* address, uint32_t val) { - return atomicAdd(address, val); +inline __device__ uint32_t nd4j_atomicAdd(uint32_t* address, + uint32_t val) { + return atomicAdd(address, val); } template <> -inline __device__ uint64_t nd4j_atomicAdd(uint64_t* address, uint64_t val) { -// unsigned long long* address_as_ull = (unsigned long long int *) address; -// -//// return atomicAdd(address, val); -// unsigned long int old = *address_as_ull, assumed; -// do { -// assumed = old; -// old = atomicCAS(address_as_ull, assumed, val + assumed); -// } while (assumed != old); -// return old; - return (uint64_t)atomicAdd((unsigned long long*)address, (unsigned long long)val); +inline __device__ uint64_t nd4j_atomicAdd(uint64_t* address, + uint64_t val) { + // unsigned long long* address_as_ull = (unsigned long long int *) address; + // + //// return atomicAdd(address, val); + // unsigned long int old = *address_as_ull, assumed; + // do { + // assumed = old; + // old = atomicCAS(address_as_ull, assumed, val + assumed); + // } while (assumed != old); + // return old; + return (uint64_t)atomicAdd((unsigned long long*)address, + (unsigned long long)val); } template <> -inline __device__ float16 nd4j_atomicAdd(float16* address, float16 val) { +inline __device__ float16 nd4j_atomicAdd(float16* address, + float16 val) { #if __CUDA_ARCH__ >= 700 && defined(CUDA_10) - atomicAdd(reinterpret_cast<__half*>(address), val.data); + atomicAdd(reinterpret_cast<__half*>(address), val.data); #else - auto address_as_ull = (int*) address; + auto address_as_ull = (int*)address; - long addr = (long) address; - bool misaligned = addr & 0x3; + long addr = (long)address; + bool misaligned = addr & 0x3; - if (misaligned) - address_as_ull = (int *) (address - 1); + if (misaligned) address_as_ull = (int*)(address - 1); - PAIR old, assumed, fresh; + PAIR old, assumed, fresh; - old.W = *address_as_ull; - do { - - if (!misaligned) { - float16 res = ((float16) old.B.H) + val; - fresh.B.H = res.data; - fresh.B.L = old.B.L; - } else { - float16 res = ((float16) old.B.L) + val; - fresh.B.L = res.data; - fresh.B.H = old.B.H; - } + old.W = *address_as_ull; + do { + if (!misaligned) { + float16 res = ((float16)old.B.H) + val; + fresh.B.H = res.data; + fresh.B.L = old.B.L; + } else { + float16 res = ((float16)old.B.L) + val; + fresh.B.L = res.data; + fresh.B.H = old.B.H; + } - assumed.W = old.W; - old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); - } while (assumed.W != old.W); + assumed.W = old.W; + old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); + } while (assumed.W != old.W); - if (!misaligned) return old.B.H; - else return old.B.L; + if (!misaligned) + return old.B.H; + else + return old.B.L; #endif } template <> -inline __device__ bfloat16 nd4j_atomicAdd(bfloat16* address, bfloat16 val) { - auto address_as_ull = (int*) address; - - auto addr = (long)(address); - bool misaligned = addr & 0x3; +inline __device__ bfloat16 nd4j_atomicAdd(bfloat16* address, + bfloat16 val) { + auto address_as_ull = (int*)address; - if (misaligned) - address_as_ull = (int *) (address - 1); + auto addr = (long)(address); + bool misaligned = addr & 0x3; - BPAIR old, assumed, fresh; + if (misaligned) address_as_ull = (int*)(address - 1); - old.W = *address_as_ull; - do { + BPAIR old, assumed, fresh; - if (!misaligned) { - bfloat16 res = old.B.H + val; - fresh.B.H = res; - fresh.B.L = old.B.L; - } else { - bfloat16 res = old.B.L + val; - fresh.B.L = res; - fresh.B.H = old.B.H; - } + old.W = *address_as_ull; + do { + if (!misaligned) { + bfloat16 res = old.B.H + val; + fresh.B.H = res; + fresh.B.L = old.B.L; + } else { + bfloat16 res = old.B.L + val; + fresh.B.L = res; + fresh.B.H = old.B.H; + } - assumed.W = old.W; - old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); - } while (assumed.W != old.W); + assumed.W = old.W; + old.W = atomicCAS(address_as_ull, assumed.W, fresh.W); + } while (assumed.W != old.W); - if (!misaligned) return old.B.H; - else return old.B.L; + if (!misaligned) + return old.B.H; + else + return old.B.L; } template static inline __device__ T internal_16bit_atomicAdd(T* address, T val) { - size_t shift = ((size_t)address & 2); - int *base_address = (int *)((char*)address - shift); - - union I16PAIR { - struct { - T H; - T L; - } B; - int W; + size_t shift = ((size_t)address & 2); + int* base_address = (int*)((char*)address - shift); - __host__ __device__ - I16PAIR() {}; + union I16PAIR { + struct { + T H; + T L; + } B; + int W; - __host__ __device__ - ~I16PAIR() {}; - }; + __host__ __device__ I16PAIR(){}; - I16PAIR pairNew, pairOld, pairAssumed; + __host__ __device__ ~I16PAIR(){}; + }; - if (reinterpret_cast(address) == base_address) { - pairOld.B.L = val; - do { + I16PAIR pairNew, pairOld, pairAssumed; - pairNew.B.L = pairOld.B.L; - pairNew.B.H = pairOld.B.H + val; - pairAssumed.W = pairOld.W; - - pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); - } while (pairAssumed.W != pairOld.W); + if (reinterpret_cast(address) == base_address) { + pairOld.B.L = val; + do { + pairNew.B.L = pairOld.B.L; + pairNew.B.H = pairOld.B.H + val; + pairAssumed.W = pairOld.W; - return (T) pairOld.B.H; - } else { - pairOld.B.H = val; - do { + pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); + } while (pairAssumed.W != pairOld.W); - pairNew.B.H = pairOld.B.H; - pairNew.B.L = pairOld.B.L + val; - pairAssumed.W = pairOld.W; - pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); + return (T)pairOld.B.H; + } else { + pairOld.B.H = val; + do { + pairNew.B.H = pairOld.B.H; + pairNew.B.L = pairOld.B.L + val; + pairAssumed.W = pairOld.W; + pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); - } while (pairAssumed.W != pairOld.W); - - return (T) pairOld.B.L; - } + } while (pairAssumed.W != pairOld.W); + return (T)pairOld.B.L; + } } template <> -inline __device__ int16_t nd4j_atomicAdd(int16_t* address, int16_t val) { - return internal_16bit_atomicAdd(address, val); +inline __device__ int16_t nd4j_atomicAdd(int16_t* address, + int16_t val) { + return internal_16bit_atomicAdd(address, val); } template <> -inline __device__ uint16_t nd4j_atomicAdd(uint16_t* address, uint16_t val) { - return internal_16bit_atomicAdd(address, val); +inline __device__ uint16_t nd4j_atomicAdd(uint16_t* address, + uint16_t val) { + return internal_16bit_atomicAdd(address, val); } template <> -inline __device__ int8_t nd4j_atomicAdd(int8_t* address, int8_t val) { - int res = *address; - atomicAdd(&res, (int)val); - *address = res; - return *address; +inline __device__ int8_t nd4j_atomicAdd(int8_t* address, int8_t val) { + int res = *address; + atomicAdd(&res, (int)val); + *address = res; + return *address; } template <> -inline __device__ uint8_t nd4j_atomicAdd(uint8_t* address, uint8_t val) { - int res = *address; - atomicAdd(&res, (int)val); - *address = res; - return *address; +inline __device__ uint8_t nd4j_atomicAdd(uint8_t* address, + uint8_t val) { + int res = *address; + atomicAdd(&res, (int)val); + *address = res; + return *address; } template <> -inline __device__ bool nd4j_atomicAdd(bool* address, bool val) { - *address += (val); - return *address; +inline __device__ bool nd4j_atomicAdd(bool* address, bool val) { + *address += (val); + return *address; } template <> -inline __device__ double nd4j_atomicSub(double* address, double val) { - return nd4j_atomicAdd(address, -val); +inline __device__ double nd4j_atomicSub(double* address, double val) { + return nd4j_atomicAdd(address, -val); } template <> -inline __device__ double nd4j_atomicMul(double* address, double val) { - unsigned long long int* address_as_ull = - (unsigned long long int*) address; - unsigned long long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed,__double_as_longlong(val * - __longlong_as_double(assumed))); - } while (assumed != old); - return __longlong_as_double(old); +inline __device__ double nd4j_atomicMul(double* address, double val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val * __longlong_as_double(assumed))); + } while (assumed != old); + return __longlong_as_double(old); } template <> -inline __device__ double nd4j_atomicDiv(double* address, double val) { - return nd4j_atomicMul(address, 1./val); +inline __device__ double nd4j_atomicDiv(double* address, double val) { + return nd4j_atomicMul(address, 1. / val); } template <> -inline __device__ float nd4j_atomicAdd(float* address, float val) { - return atomicAdd(address,val); +inline __device__ float nd4j_atomicAdd(float* address, float val) { + return atomicAdd(address, val); } -//template <> -//inline __device__ int nd4j_atomicAdd(int* address, int val) { +// template <> +// inline __device__ int nd4j_atomicAdd(int* address, int val) { // return atomicAdd(address, val); //} template <> -inline __device__ int32_t nd4j_atomicAdd(int32_t* address, int32_t val) { - return (int32_t)atomicAdd((int*)address, (int)val); +inline __device__ int32_t nd4j_atomicAdd(int32_t* address, + int32_t val) { + return (int32_t)atomicAdd((int*)address, (int)val); } - template <> inline __device__ float nd4j_atomicSub(float* address, float val) { - return nd4j_atomicAdd(address, -val); + return nd4j_atomicAdd(address, -val); } template <> -inline __device__ float16 nd4j_atomicSub(float16* address, float16 val) { - return nd4j_atomicAdd(address, -val); +inline __device__ float16 nd4j_atomicSub(float16* address, + float16 val) { + return nd4j_atomicAdd(address, -val); } template <> -inline __device__ bfloat16 nd4j_atomicSub(bfloat16* address, bfloat16 val) { - return nd4j_atomicAdd(address, -val); +inline __device__ bfloat16 nd4j_atomicSub(bfloat16* address, + bfloat16 val) { + return nd4j_atomicAdd(address, -val); } template <> inline __device__ float nd4j_atomicMul(float* address, float val) { - int* address_as_ull = - ( int*)address; - int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, __float_as_int(val * - __int_as_float(assumed))); - } while (assumed != old); - return __int_as_float(old); + int* address_as_ull = (int*)address; + int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __float_as_int(val * __int_as_float(assumed))); + } while (assumed != old); + return __int_as_float(old); } template <> inline __device__ int8_t nd4j_atomicMul(int8_t* address, int8_t val) { - unsigned int *base_address = (unsigned int *)((size_t)address & ~3); - unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210}; - unsigned int sel = selectors[(size_t)address & 3]; - unsigned int old, assumed, mul, new_; + unsigned int* base_address = (unsigned int*)((size_t)address & ~3); + unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210}; + unsigned int sel = selectors[(size_t)address & 3]; + unsigned int old, assumed, mul, new_; - old = *base_address; + old = *base_address; - do { + do { + assumed = old; + mul = val * (int8_t)__byte_perm(old, 0, ((size_t)address & 3) | 0x4440); + new_ = __byte_perm(old, mul, sel); - assumed = old; - mul = val * (int8_t)__byte_perm(old, 0, ((size_t)address & 3) | 0x4440); - new_ = __byte_perm(old, mul, sel); + if (new_ == old) break; - if (new_ == old) - break; - - old = atomicCAS(base_address, assumed, new_); - } while (assumed != old); - return (int8_t)old; + old = atomicCAS(base_address, assumed, new_); + } while (assumed != old); + return (int8_t)old; } template <> -inline __device__ unsigned char nd4j_atomicMul(unsigned char* address, unsigned char val) { - unsigned int *base_address = (unsigned int *)((size_t)address & ~3); - unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210}; - unsigned int sel = selectors[(size_t)address & 3]; - unsigned int old, assumed, mul, new_; - - old = *base_address; +inline __device__ unsigned char nd4j_atomicMul( + unsigned char* address, unsigned char val) { + unsigned int* base_address = (unsigned int*)((size_t)address & ~3); + unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210}; + unsigned int sel = selectors[(size_t)address & 3]; + unsigned int old, assumed, mul, new_; - do { + old = *base_address; - assumed = old; - mul = val * (uint8_t)__byte_perm(old, 0, ((size_t)address & 3) | 0x4440); - new_ = __byte_perm(old, mul, sel); + do { + assumed = old; + mul = val * (uint8_t)__byte_perm(old, 0, ((size_t)address & 3) | 0x4440); + new_ = __byte_perm(old, mul, sel); - if (new_ == old) - break; + if (new_ == old) break; - old = atomicCAS(base_address, assumed, new_); - } while (assumed != old); - return (uint8_t)old; + old = atomicCAS(base_address, assumed, new_); + } while (assumed != old); + return (uint8_t)old; } template static inline __device__ T internal_16bit_atomicMul(T* address, T val) { - size_t shift = ((size_t)address & 2); - int *base_address = (int *)((char*)address - shift); + size_t shift = ((size_t)address & 2); + int* base_address = (int*)((char*)address - shift); - union I16PAIR { - struct { - T H; - T L; - } B; - int W; + union I16PAIR { + struct { + T H; + T L; + } B; + int W; - __host__ __device__ - I16PAIR() {}; + __host__ __device__ I16PAIR(){}; - __host__ __device__ - ~I16PAIR() {}; - }; + __host__ __device__ ~I16PAIR(){}; + }; - I16PAIR pairNew, pairOld, pairAssumed; + I16PAIR pairNew, pairOld, pairAssumed; - if (reinterpret_cast(address) == base_address) { - pairOld.B.L = val; - do { - pairNew.B.L = pairOld.B.L; - pairNew.B.H = pairOld.B.H * val; - pairAssumed.W = pairOld.W; + if (reinterpret_cast(address) == base_address) { + pairOld.B.L = val; + do { + pairNew.B.L = pairOld.B.L; + pairNew.B.H = pairOld.B.H * val; + pairAssumed.W = pairOld.W; - pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); - } while (pairAssumed.W != pairOld.W); + pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); + } while (pairAssumed.W != pairOld.W); - return (T) pairOld.B.H; - } else { - pairOld.B.H = val; - do { - pairNew.B.H = pairOld.B.H; - pairNew.B.L = pairOld.B.L * val; - pairAssumed.W = pairOld.W; - pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); + return (T)pairOld.B.H; + } else { + pairOld.B.H = val; + do { + pairNew.B.H = pairOld.B.H; + pairNew.B.L = pairOld.B.L * val; + pairAssumed.W = pairOld.W; + pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W); - } while (pairAssumed.W != pairOld.W); + } while (pairAssumed.W != pairOld.W); - return (T) pairOld.B.L; - } + return (T)pairOld.B.L; + } } template <> -inline __device__ int16_t nd4j_atomicMul(int16_t* address, int16_t val) { - return internal_16bit_atomicMul(address, val); +inline __device__ int16_t nd4j_atomicMul(int16_t* address, + int16_t val) { + return internal_16bit_atomicMul(address, val); } template <> -inline __device__ uint16_t nd4j_atomicMul(uint16_t* address, uint16_t val) { - return internal_16bit_atomicMul(address, val); +inline __device__ uint16_t nd4j_atomicMul(uint16_t* address, + uint16_t val) { + return internal_16bit_atomicMul(address, val); } template <> inline __device__ int nd4j_atomicMul(int* address, int val) { - int* res_address = address; - int old = *res_address, assumed; - do { - assumed = old; - old = atomicCAS(res_address, assumed, val * assumed); - } while (assumed != old); - return old; + int* res_address = address; + int old = *res_address, assumed; + do { + assumed = old; + old = atomicCAS(res_address, assumed, val * assumed); + } while (assumed != old); + return old; } template <> -inline __device__ unsigned int nd4j_atomicMul(unsigned int* address, unsigned int val) { - unsigned int* res_address = address; - unsigned int old = *res_address, assumed; - do { - assumed = old; - old = atomicCAS(res_address, assumed, val * assumed); - } while (assumed != old); - return old; +inline __device__ unsigned int nd4j_atomicMul( + unsigned int* address, unsigned int val) { + unsigned int* res_address = address; + unsigned int old = *res_address, assumed; + do { + assumed = old; + old = atomicCAS(res_address, assumed, val * assumed); + } while (assumed != old); + return old; } template <> -inline __device__ int64_t nd4j_atomicMul(int64_t* address, int64_t val) { - unsigned long long int* res_address = (unsigned long long int*)address; - unsigned long long int old = *res_address, assumed; - do { - assumed = old; - old = atomicCAS(res_address, assumed, val * assumed); - } while (assumed != old); - return (int64_t)old; +inline __device__ int64_t nd4j_atomicMul(int64_t* address, + int64_t val) { + unsigned long long int* res_address = (unsigned long long int*)address; + unsigned long long int old = *res_address, assumed; + do { + assumed = old; + old = atomicCAS(res_address, assumed, val * assumed); + } while (assumed != old); + return (int64_t)old; } template <> -inline __device__ uint64_t nd4j_atomicMul(uint64_t* address, uint64_t val) { - unsigned long long int* res_address = (unsigned long long int*)address; - unsigned long long int old = *res_address, assumed; - do { - assumed = old; - old = atomicCAS(res_address, assumed, val * assumed); - } while (assumed != old); - return (uint64_t)old; +inline __device__ uint64_t nd4j_atomicMul(uint64_t* address, + uint64_t val) { + unsigned long long int* res_address = (unsigned long long int*)address; + unsigned long long int old = *res_address, assumed; + do { + assumed = old; + old = atomicCAS(res_address, assumed, val * assumed); + } while (assumed != old); + return (uint64_t)old; } #if !defined(_WIN32) && !defined(_WIN64) template <> -inline __device__ Nd4jLong nd4j_atomicMul(Nd4jLong* address, Nd4jLong val) { - unsigned long long int* res_address = (unsigned long long*)address; - unsigned long long int old = *res_address, assumed; - do { - assumed = old; - old = atomicCAS(res_address, assumed, val * assumed); - } while (assumed != old); - return (Nd4jLong)old; +inline __device__ Nd4jLong nd4j_atomicMul(Nd4jLong* address, + Nd4jLong val) { + unsigned long long int* res_address = (unsigned long long*)address; + unsigned long long int old = *res_address, assumed; + do { + assumed = old; + old = atomicCAS(res_address, assumed, val * assumed); + } while (assumed != old); + return (Nd4jLong)old; } #endif template <> -inline __device__ bfloat16 nd4j_atomicMul(bfloat16* address, bfloat16 val) { - return internal_16bit_atomicMul(address, val); +inline __device__ bfloat16 nd4j_atomicMul(bfloat16* address, + bfloat16 val) { + return internal_16bit_atomicMul(address, val); } template <> -inline __device__ float16 nd4j_atomicMul(float16* address, float16 val) { - return internal_16bit_atomicMul(address, val); +inline __device__ float16 nd4j_atomicMul(float16* address, + float16 val) { + return internal_16bit_atomicMul(address, val); } template <> inline __device__ float nd4j_atomicDiv(float* address, float val) { - return nd4j_atomicMul(address, 1.f / val); + return nd4j_atomicMul(address, 1.f / val); } template <> -inline __device__ float16 nd4j_atomicDiv(float16* address, float16 val) { - return internal_16bit_atomicMul(address, (float16) 1.f / val); +inline __device__ float16 nd4j_atomicDiv(float16* address, + float16 val) { + return internal_16bit_atomicMul(address, (float16)1.f / val); } template <> -inline __device__ bfloat16 nd4j_atomicDiv(bfloat16* address, bfloat16 val) { - return internal_16bit_atomicMul(address, (bfloat16) 1 / val); -} +inline __device__ bfloat16 nd4j_atomicDiv(bfloat16* address, + bfloat16 val) { + return internal_16bit_atomicMul(address, (bfloat16)1 / val); } +} // namespace atomics #endif - } -} +} // namespace math +} // namespace sd #ifdef _OPENMP @@ -1704,39 +1718,60 @@ inline __device__ bfloat16 nd4j_atomicDiv(bfloat16* address, bfloat16 #define MAX_FLOAT 1e37 #endif -#pragma omp declare reduction(maxTF : float,double,float16,bfloat16 : \ - omp_out = sd::math::nd4j_max(omp_in, omp_out) )\ - initializer (omp_priv=-MAX_FLOAT) - -#pragma omp declare reduction(minTF : float,double,float16,bfloat16 : \ - omp_out = sd::math::nd4j_min(omp_in, omp_out) )\ - initializer (omp_priv=MAX_FLOAT) - -#pragma omp declare reduction(maxT : float,double,float16,bfloat16,int,Nd4jLong,Nd4jULong,int8_t,uint8_t,bool,int16_t,uint16_t,uint32_t : \ - omp_out = sd::math::nd4j_max(omp_in, omp_out) )\ - initializer (omp_priv=0) - -#pragma omp declare reduction(minT : float,double,float16,bfloat16,int,Nd4jLong,Nd4jULong,int8_t,uint8_t,bool,int16_t,uint16_t,uint32_t : \ - omp_out = sd::math::nd4j_min(omp_in, omp_out) )\ - initializer (omp_priv=0) - -#pragma omp declare reduction(amaxT : float,double,float16,bfloat16,int,Nd4jLong,Nd4jULong,int8_t,uint8_t,bool,int16_t,uint16_t,uint32_t : \ - omp_out = sd::math::nd4j_max(sd::math::nd4j_abs(omp_in), sd::math::nd4j_abs(omp_out)) ) - -#pragma omp declare reduction(aminT : float,double,float16,bfloat16,int,Nd4jLong,Nd4jULong,int8_t,uint8_t,bool,int16_t,uint16_t,uint32_t : \ - omp_out = sd::math::nd4j_min(sd::math::nd4j_abs(omp_in), sd::math::nd4j_abs(omp_out)) ) - -#pragma omp declare reduction(asumT : float,double,float16,bfloat16,int,Nd4jLong,Nd4jULong,int8_t,uint8_t,bool,int16_t,uint16_t,uint32_t : \ - omp_out = sd::math::nd4j_abs(omp_in) + sd::math::nd4j_abs(omp_out))\ - initializer (omp_priv=0) - -#pragma omp declare reduction(sumT : float,double,float16,bfloat16,int,Nd4jLong,Nd4jULong,int8_t,uint8_t,bool,int16_t,uint16_t,uint32_t : \ - omp_out = omp_in + omp_out)\ - initializer (omp_priv=0) - -#pragma omp declare reduction(prodT : float,double,float16,bfloat16,int,Nd4jLong,Nd4jULong,int8_t,uint8_t,bool,int16_t,uint16_t,uint32_t : \ - omp_out = omp_in * omp_out)\ - initializer (omp_priv=1) +#pragma omp declare reduction(maxTF \ + : float, double, float16, bfloat16 \ + : omp_out = sd::math::nd4j_max(omp_in, omp_out)) \ + initializer(omp_priv = -MAX_FLOAT) + +#pragma omp declare reduction(minTF \ + : float, double, float16, bfloat16 \ + : omp_out = sd::math::nd4j_min(omp_in, omp_out)) \ + initializer(omp_priv = MAX_FLOAT) + +#pragma omp declare reduction( \ + maxT \ + : float, double, float16, bfloat16, int, Nd4jLong, Nd4jULong, int8_t, \ + uint8_t, bool, int16_t, uint16_t, uint32_t \ + : omp_out = sd::math::nd4j_max(omp_in, omp_out)) initializer(omp_priv = 0) + +#pragma omp declare reduction( \ + minT \ + : float, double, float16, bfloat16, int, Nd4jLong, Nd4jULong, int8_t, \ + uint8_t, bool, int16_t, uint16_t, uint32_t \ + : omp_out = sd::math::nd4j_min(omp_in, omp_out)) initializer(omp_priv = 0) + +#pragma omp declare reduction( \ + amaxT \ + : float, double, float16, bfloat16, int, Nd4jLong, Nd4jULong, int8_t, \ + uint8_t, bool, int16_t, uint16_t, uint32_t \ + : omp_out = sd::math::nd4j_max(sd::math::nd4j_abs(omp_in), \ + sd::math::nd4j_abs(omp_out))) + +#pragma omp declare reduction( \ + aminT \ + : float, double, float16, bfloat16, int, Nd4jLong, Nd4jULong, int8_t, \ + uint8_t, bool, int16_t, uint16_t, uint32_t \ + : omp_out = sd::math::nd4j_min(sd::math::nd4j_abs(omp_in), \ + sd::math::nd4j_abs(omp_out))) + +#pragma omp declare reduction( \ + asumT \ + : float, double, float16, bfloat16, int, Nd4jLong, Nd4jULong, int8_t, \ + uint8_t, bool, int16_t, uint16_t, uint32_t \ + : omp_out = sd::math::nd4j_abs(omp_in) + sd::math::nd4j_abs(omp_out)) \ + initializer(omp_priv = 0) + +#pragma omp declare reduction( \ + sumT \ + : float, double, float16, bfloat16, int, Nd4jLong, Nd4jULong, int8_t, \ + uint8_t, bool, int16_t, uint16_t, uint32_t \ + : omp_out = omp_in + omp_out) initializer(omp_priv = 0) + +#pragma omp declare reduction( \ + prodT \ + : float, double, float16, bfloat16, int, Nd4jLong, Nd4jULong, int8_t, \ + uint8_t, bool, int16_t, uint16_t, uint32_t \ + : omp_out = omp_in * omp_out) initializer(omp_priv = 1) #endif diff --git a/libnd4j/include/memory/AllocationEntry.h b/libnd4j/include/memory/AllocationEntry.h index b2d9839af4e8..13ec15e96c0e 100644 --- a/libnd4j/include/memory/AllocationEntry.h +++ b/libnd4j/include/memory/AllocationEntry.h @@ -21,30 +21,31 @@ #ifndef SD_ALLOCATIONENTRY_H #define SD_ALLOCATIONENTRY_H +#include #include + #include -#include namespace sd { - namespace memory { - class AllocationEntry { - private: - MemoryType _memoryType; - Nd4jLong _pointer; - Nd4jLong _numBytes; - std::string _stack; - public: - AllocationEntry() = default; - AllocationEntry(MemoryType type, Nd4jLong ptr, Nd4jLong numBytes, std::string &stack); - ~AllocationEntry() = default; - - - Nd4jLong numBytes(); - std::string stackTrace(); - MemoryType memoryType(); - }; - } -} - - -#endif //SD_ALLOCATIONENTRY_H +namespace memory { +class AllocationEntry { + private: + MemoryType _memoryType; + Nd4jLong _pointer; + Nd4jLong _numBytes; + std::string _stack; + + public: + AllocationEntry() = default; + AllocationEntry(MemoryType type, Nd4jLong ptr, Nd4jLong numBytes, + std::string &stack); + ~AllocationEntry() = default; + + Nd4jLong numBytes(); + std::string stackTrace(); + MemoryType memoryType(); +}; +} // namespace memory +} // namespace sd + +#endif // SD_ALLOCATIONENTRY_H diff --git a/libnd4j/include/memory/ColdZoneManager.h b/libnd4j/include/memory/ColdZoneManager.h index 448b5e97e61f..cb4b3f1ab936 100644 --- a/libnd4j/include/memory/ColdZoneManager.h +++ b/libnd4j/include/memory/ColdZoneManager.h @@ -18,37 +18,36 @@ // @author raver119@gmail.com // - #ifndef SD_COLDZONEMANAGER_H #define SD_COLDZONEMANAGER_H #include namespace sd { - namespace memory { - class ColdZoneManager : public ZoneManager { - public: - /** - * This constructor is used to initialize ZoneManager with existing FlatBuffers file - * @param filename - full path to existing file (i.e. FlatBuffers file) - */ - explicit ColdZoneManager(const char *filename); - - ColdZoneManager() = default; - ~ColdZoneManager() = default; +namespace memory { +class ColdZoneManager : public ZoneManager { + public: + /** + * This constructor is used to initialize ZoneManager with existing + * FlatBuffers file + * @param filename - full path to existing file (i.e. FlatBuffers file) + */ + explicit ColdZoneManager(const char *filename); - MemoryZone zone() const override; + ColdZoneManager() = default; + ~ColdZoneManager() = default; - uint64_t available() const override; + MemoryZone zone() const override; - uint64_t used() const override; + uint64_t available() const override; - MemoryDescriptor allocate(uint64_t numBytes) override; + uint64_t used() const override; - void release(MemoryDescriptor &descriptor) override; - }; - } -} + MemoryDescriptor allocate(uint64_t numBytes) override; + void release(MemoryDescriptor &descriptor) override; +}; +} // namespace memory +} // namespace sd -#endif //SD_COLDZONEMANAGER_H +#endif // SD_COLDZONEMANAGER_H diff --git a/libnd4j/include/memory/ExternalWorkspace.h b/libnd4j/include/memory/ExternalWorkspace.h index f557f1c484b9..25fee4075ca0 100644 --- a/libnd4j/include/memory/ExternalWorkspace.h +++ b/libnd4j/include/memory/ExternalWorkspace.h @@ -21,31 +21,33 @@ #ifndef LIBND4J_EXTERNALWORKSPACE_H #define LIBND4J_EXTERNALWORKSPACE_H -#include #include +#include namespace sd { - namespace memory { - class SD_EXPORT ExternalWorkspace { - private: - void *_ptrH = nullptr; - void *_ptrD = nullptr; - - Nd4jLong _sizeH = 0L; - Nd4jLong _sizeD = 0L; - public: - ExternalWorkspace() = default; - ~ExternalWorkspace() = default; - - ExternalWorkspace(Nd4jPointer ptrH, Nd4jLong sizeH, Nd4jPointer ptrD, Nd4jLong sizeD); - - void *pointerHost(); - void *pointerDevice(); - - Nd4jLong sizeHost(); - Nd4jLong sizeDevice(); - }; - } -} +namespace memory { +class SD_EXPORT ExternalWorkspace { + private: + void *_ptrH = nullptr; + void *_ptrD = nullptr; + + Nd4jLong _sizeH = 0L; + Nd4jLong _sizeD = 0L; + + public: + ExternalWorkspace() = default; + ~ExternalWorkspace() = default; + + ExternalWorkspace(Nd4jPointer ptrH, Nd4jLong sizeH, Nd4jPointer ptrD, + Nd4jLong sizeD); + + void *pointerHost(); + void *pointerDevice(); + + Nd4jLong sizeHost(); + Nd4jLong sizeDevice(); +}; +} // namespace memory +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/memory/GraphMemoryManager.h b/libnd4j/include/memory/GraphMemoryManager.h index edcdf9d3cfe0..4245a91f275e 100644 --- a/libnd4j/include/memory/GraphMemoryManager.h +++ b/libnd4j/include/memory/GraphMemoryManager.h @@ -21,39 +21,40 @@ #ifndef SD_GRAPHMEMORYMANAGER_H #define SD_GRAPHMEMORYMANAGER_H +#include #include #include -#include + #include using namespace sd::memory; namespace sd { - namespace graph { - class GraphMemoryManager { - protected: - std::map _zones; - - public: - GraphMemoryManager(); - ~GraphMemoryManager(); - - /** - * This method does allocation (probably) and returns structure that describes it - * @param numBytes - number of bytes to be allocated - * @param zone - memory zone for allocation - * @return - */ - virtual MemoryDescriptor allocate(size_t numBytes, MemoryZone zone); - - /** - * This method releases (probably) memory chunk described by given descriptor - * @param descriptor - */ - virtual void release(MemoryDescriptor &descriptor); - }; - } -} - - -#endif //SD_GRAPHMEMORYMANAGER_H +namespace graph { +class GraphMemoryManager { + protected: + std::map _zones; + + public: + GraphMemoryManager(); + ~GraphMemoryManager(); + + /** + * This method does allocation (probably) and returns structure that describes + * it + * @param numBytes - number of bytes to be allocated + * @param zone - memory zone for allocation + * @return + */ + virtual MemoryDescriptor allocate(size_t numBytes, MemoryZone zone); + + /** + * This method releases (probably) memory chunk described by given descriptor + * @param descriptor + */ + virtual void release(MemoryDescriptor &descriptor); +}; +} // namespace graph +} // namespace sd + +#endif // SD_GRAPHMEMORYMANAGER_H diff --git a/libnd4j/include/memory/HotRamZoneManager.h b/libnd4j/include/memory/HotRamZoneManager.h index 39e4362ab53e..f733e0f4a8ec 100644 --- a/libnd4j/include/memory/HotRamZoneManager.h +++ b/libnd4j/include/memory/HotRamZoneManager.h @@ -24,18 +24,17 @@ #include namespace sd { - namespace memory { - class HotRamZoneManager : public HotZoneManager { - public: - HotRamZoneManager() = default; - ~HotRamZoneManager() = default; +namespace memory { +class HotRamZoneManager : public HotZoneManager { + public: + HotRamZoneManager() = default; + ~HotRamZoneManager() = default; - MemoryDescriptor allocate(uint64_t numBytes) override; + MemoryDescriptor allocate(uint64_t numBytes) override; - void release(MemoryDescriptor &descriptor) override; - }; - } -} + void release(MemoryDescriptor &descriptor) override; +}; +} // namespace memory +} // namespace sd - -#endif //SD_HOTRAMZONEMANAGER_H +#endif // SD_HOTRAMZONEMANAGER_H diff --git a/libnd4j/include/memory/HotZoneManager.h b/libnd4j/include/memory/HotZoneManager.h index bdeb05bcad07..1c8e197ee9b6 100644 --- a/libnd4j/include/memory/HotZoneManager.h +++ b/libnd4j/include/memory/HotZoneManager.h @@ -22,31 +22,31 @@ #define SD_HOTZONEMANAGER_H #include + #include namespace sd { - namespace memory { - class SD_EXPORT HotZoneManager : public ZoneManager { - protected: - std::atomic _used = {0}; - std::atomic _available = {0}; - - public: - HotZoneManager() = default; - ~HotZoneManager() = default; +namespace memory { +class SD_EXPORT HotZoneManager : public ZoneManager { + protected: + std::atomic _used = {0}; + std::atomic _available = {0}; - MemoryZone zone() const override; + public: + HotZoneManager() = default; + ~HotZoneManager() = default; - uint64_t available() const override; + MemoryZone zone() const override; - uint64_t used() const override; + uint64_t available() const override; - virtual MemoryDescriptor allocate(uint64_t numBytes) = 0; + uint64_t used() const override; - virtual void release(MemoryDescriptor &descriptor) = 0; - }; - } -} + virtual MemoryDescriptor allocate(uint64_t numBytes) = 0; + virtual void release(MemoryDescriptor &descriptor) = 0; +}; +} // namespace memory +} // namespace sd -#endif //SD_HOTZONEMANAGER_H +#endif // SD_HOTZONEMANAGER_H diff --git a/libnd4j/include/memory/MemoryCounter.h b/libnd4j/include/memory/MemoryCounter.h index 909eb7819c9d..ae8474fe39fa 100644 --- a/libnd4j/include/memory/MemoryCounter.h +++ b/libnd4j/include/memory/MemoryCounter.h @@ -21,126 +21,130 @@ #ifndef SD_MEMORYCOUNTER_H #define SD_MEMORYCOUNTER_H -#include +#include #include +#include + #include -#include #include namespace sd { - namespace memory { - /** - * This class provides simple per-device counter - */ - class SD_EXPORT MemoryCounter { - private: - static MemoryCounter* _INSTANCE; - - // used for synchronization - std::mutex _locker; - - // per-device counters - std::map _deviceCounters; - - // TODO: change this wrt heterogenous stuff on next iteration - // per-group counters - std::map _groupCounters; - - // per-device limits - std::map _deviceLimits; - - // per-group limits - std::map _groupLimits; - - MemoryCounter(); - ~MemoryCounter() = default; - - public: - static MemoryCounter *getInstance(); - - /** - * This method checks if allocation of numBytes won't break through per-group or per-device limit - * @param numBytes - * @return TRUE if allocated ammount will keep us below limit, FALSE otherwise - */ - bool validate(Nd4jLong numBytes); - - /** - * This method checks if allocation of numBytes won't break through per-device limit - * @param deviceId - * @param numBytes - * @return TRUE if allocated ammount will keep us below limit, FALSE otherwise - */ - bool validateDevice(int deviceId, Nd4jLong numBytes); - - /** - * This method checks if allocation of numBytes won't break through per-group limit - * @param deviceId - * @param numBytes - * @return TRUE if allocated ammount will keep us below limit, FALSE otherwise - */ - bool validateGroup(sd::memory::MemoryType group, Nd4jLong numBytes); - - /** - * This method adds specified number of bytes to specified counter - * @param deviceId - * @param numBytes - */ - void countIn(int deviceId, Nd4jLong numBytes); - void countIn(sd::memory::MemoryType group, Nd4jLong numBytes); - - /** - * This method subtracts specified number of bytes from specified counter - * @param deviceId - * @param numBytes - */ - void countOut(int deviceId, Nd4jLong numBytes); - void countOut(sd::memory::MemoryType group, Nd4jLong numBytes); - - /** - * This method returns amount of memory allocated on specified device - * @param deviceId - * @return - */ - Nd4jLong allocatedDevice(int deviceId); - - /** - * This method returns amount of memory allocated in specified group of devices - * @param group - * @return - */ - Nd4jLong allocatedGroup(sd::memory::MemoryType group); - - /** - * This method allows to set per-device memory limits - * @param deviceId - * @param numBytes - */ - void setDeviceLimit(int deviceId, Nd4jLong numBytes); - - /** - * This method returns current device limit in bytes - * @param deviceId - * @return - */ - Nd4jLong deviceLimit(int deviceId); - - /** - * This method allows to set per-group memory limits - * @param group - * @param numBytes - */ - void setGroupLimit(sd::memory::MemoryType group, Nd4jLong numBytes); - - /** - * This method returns current group limit in bytes - * @param group - * @return - */ - Nd4jLong groupLimit(sd::memory::MemoryType group); - }; - } -} - - -#endif //SD_MEMORYCOUNTER_H +namespace memory { +/** + * This class provides simple per-device counter + */ +class SD_EXPORT MemoryCounter { + private: + static MemoryCounter* _INSTANCE; + + // used for synchronization + std::mutex _locker; + + // per-device counters + std::map _deviceCounters; + + // TODO: change this wrt heterogenous stuff on next iteration + // per-group counters + std::map _groupCounters; + + // per-device limits + std::map _deviceLimits; + + // per-group limits + std::map _groupLimits; + + MemoryCounter(); + ~MemoryCounter() = default; + + public: + static MemoryCounter* getInstance(); + + /** + * This method checks if allocation of numBytes won't break through per-group + * or per-device limit + * @param numBytes + * @return TRUE if allocated ammount will keep us below limit, FALSE otherwise + */ + bool validate(Nd4jLong numBytes); + + /** + * This method checks if allocation of numBytes won't break through per-device + * limit + * @param deviceId + * @param numBytes + * @return TRUE if allocated ammount will keep us below limit, FALSE otherwise + */ + bool validateDevice(int deviceId, Nd4jLong numBytes); + + /** + * This method checks if allocation of numBytes won't break through per-group + * limit + * @param deviceId + * @param numBytes + * @return TRUE if allocated ammount will keep us below limit, FALSE otherwise + */ + bool validateGroup(sd::memory::MemoryType group, Nd4jLong numBytes); + + /** + * This method adds specified number of bytes to specified counter + * @param deviceId + * @param numBytes + */ + void countIn(int deviceId, Nd4jLong numBytes); + void countIn(sd::memory::MemoryType group, Nd4jLong numBytes); + + /** + * This method subtracts specified number of bytes from specified counter + * @param deviceId + * @param numBytes + */ + void countOut(int deviceId, Nd4jLong numBytes); + void countOut(sd::memory::MemoryType group, Nd4jLong numBytes); + + /** + * This method returns amount of memory allocated on specified device + * @param deviceId + * @return + */ + Nd4jLong allocatedDevice(int deviceId); + + /** + * This method returns amount of memory allocated in specified group of + * devices + * @param group + * @return + */ + Nd4jLong allocatedGroup(sd::memory::MemoryType group); + + /** + * This method allows to set per-device memory limits + * @param deviceId + * @param numBytes + */ + void setDeviceLimit(int deviceId, Nd4jLong numBytes); + + /** + * This method returns current device limit in bytes + * @param deviceId + * @return + */ + Nd4jLong deviceLimit(int deviceId); + + /** + * This method allows to set per-group memory limits + * @param group + * @param numBytes + */ + void setGroupLimit(sd::memory::MemoryType group, Nd4jLong numBytes); + + /** + * This method returns current group limit in bytes + * @param group + * @return + */ + Nd4jLong groupLimit(sd::memory::MemoryType group); +}; +} // namespace memory +} // namespace sd + +#endif // SD_MEMORYCOUNTER_H diff --git a/libnd4j/include/memory/MemoryDescriptor.h b/libnd4j/include/memory/MemoryDescriptor.h index fb4ced9a3408..f0aad28ef183 100644 --- a/libnd4j/include/memory/MemoryDescriptor.h +++ b/libnd4j/include/memory/MemoryDescriptor.h @@ -23,35 +23,36 @@ #include #include + #include namespace sd { - namespace memory { - class SD_EXPORT MemoryDescriptor { - private: - void* _ptr; - MemoryZone _zone; - uint64_t _bytes; - public: - MemoryDescriptor(void *ptr, MemoryZone zone, uint64_t bytes); - ~MemoryDescriptor() = default; +namespace memory { +class SD_EXPORT MemoryDescriptor { + private: + void* _ptr; + MemoryZone _zone; + uint64_t _bytes; - MemoryDescriptor(const MemoryDescriptor& other) noexcept; + public: + MemoryDescriptor(void* ptr, MemoryZone zone, uint64_t bytes); + ~MemoryDescriptor() = default; - MemoryDescriptor& operator=(const MemoryDescriptor& other) noexcept; + MemoryDescriptor(const MemoryDescriptor& other) noexcept; - // move constructor - MemoryDescriptor(MemoryDescriptor&& other) noexcept; + MemoryDescriptor& operator=(const MemoryDescriptor& other) noexcept; - // move assignment operator - MemoryDescriptor& operator=(MemoryDescriptor&& other) noexcept; + // move constructor + MemoryDescriptor(MemoryDescriptor&& other) noexcept; - void* address() const; - MemoryZone zone() const; - uint64_t bytes() const; - }; - } -} + // move assignment operator + MemoryDescriptor& operator=(MemoryDescriptor&& other) noexcept; + void* address() const; + MemoryZone zone() const; + uint64_t bytes() const; +}; +} // namespace memory +} // namespace sd -#endif //SD_MEMORYDESCRIPTOR_H +#endif // SD_MEMORYDESCRIPTOR_H diff --git a/libnd4j/include/memory/MemoryRegistrator.h b/libnd4j/include/memory/MemoryRegistrator.h index a12d960b29ba..aaf1b429b181 100644 --- a/libnd4j/include/memory/MemoryRegistrator.h +++ b/libnd4j/include/memory/MemoryRegistrator.h @@ -21,48 +21,51 @@ #ifndef LIBND4J_MEMORYREGISTRATOR_H #define LIBND4J_MEMORYREGISTRATOR_H -#include "Workspace.h" +#include #include -#include + #include #include -#include +#include + +#include "Workspace.h" namespace sd { - namespace memory { - class SD_EXPORT MemoryRegistrator { - protected: - static MemoryRegistrator* _INSTANCE; - Workspace* _workspace; - MAP_IMPL _footprint; - std::mutex _lock; +namespace memory { +class SD_EXPORT MemoryRegistrator { + protected: + static MemoryRegistrator* _INSTANCE; + Workspace* _workspace; + MAP_IMPL _footprint; + std::mutex _lock; + + MemoryRegistrator(); + ~MemoryRegistrator() = default; - MemoryRegistrator(); - ~MemoryRegistrator() = default; - public: - static MemoryRegistrator* getInstance(); - bool hasWorkspaceAttached(); - Workspace* getWorkspace(); - void attachWorkspace(Workspace* workspace); - void forgetWorkspace(); + public: + static MemoryRegistrator* getInstance(); + bool hasWorkspaceAttached(); + Workspace* getWorkspace(); + void attachWorkspace(Workspace* workspace); + void forgetWorkspace(); - /** - * This method allows you to set memory requirements for given graph - */ - void setGraphMemoryFootprint(Nd4jLong hash, Nd4jLong bytes); + /** + * This method allows you to set memory requirements for given graph + */ + void setGraphMemoryFootprint(Nd4jLong hash, Nd4jLong bytes); - /** - * This method allows you to set memory requirements for given graph, ONLY if - * new amount of bytes is greater then current one - */ - void setGraphMemoryFootprintIfGreater(Nd4jLong hash, Nd4jLong bytes); + /** + * This method allows you to set memory requirements for given graph, ONLY if + * new amount of bytes is greater then current one + */ + void setGraphMemoryFootprintIfGreater(Nd4jLong hash, Nd4jLong bytes); - /** - * This method returns memory requirements for given graph - */ - Nd4jLong getGraphMemoryFootprint(Nd4jLong hash); - }; - } -} + /** + * This method returns memory requirements for given graph + */ + Nd4jLong getGraphMemoryFootprint(Nd4jLong hash); +}; +} // namespace memory +} // namespace sd -#endif //LIBND4J_MEMORYREGISTRATOR_H +#endif // LIBND4J_MEMORYREGISTRATOR_H diff --git a/libnd4j/include/memory/MemoryReport.h b/libnd4j/include/memory/MemoryReport.h index 40c87e188346..aca8bf7efd8f 100644 --- a/libnd4j/include/memory/MemoryReport.h +++ b/libnd4j/include/memory/MemoryReport.h @@ -21,36 +21,34 @@ #ifndef LIBND4J_MEMORYREPORT_H #define LIBND4J_MEMORYREPORT_H -#include #include +#include namespace sd { - namespace memory { - class SD_EXPORT MemoryReport { - private: - Nd4jLong _vm = 0; - Nd4jLong _rss = 0; - - public: - MemoryReport() = default; - ~MemoryReport() = default; - - bool operator < (const MemoryReport& other) const; - bool operator <= (const MemoryReport& other) const; - bool operator > (const MemoryReport& other) const; - bool operator >= (const MemoryReport& other) const; - bool operator == (const MemoryReport& other) const; - bool operator != (const MemoryReport& other) const; - - Nd4jLong getVM() const; - void setVM(Nd4jLong vm); - - Nd4jLong getRSS() const; - void setRSS(Nd4jLong rss); - }; - } -} - - - -#endif //LIBND4J_MEMORYREPORT_H +namespace memory { +class SD_EXPORT MemoryReport { + private: + Nd4jLong _vm = 0; + Nd4jLong _rss = 0; + + public: + MemoryReport() = default; + ~MemoryReport() = default; + + bool operator<(const MemoryReport& other) const; + bool operator<=(const MemoryReport& other) const; + bool operator>(const MemoryReport& other) const; + bool operator>=(const MemoryReport& other) const; + bool operator==(const MemoryReport& other) const; + bool operator!=(const MemoryReport& other) const; + + Nd4jLong getVM() const; + void setVM(Nd4jLong vm); + + Nd4jLong getRSS() const; + void setRSS(Nd4jLong rss); +}; +} // namespace memory +} // namespace sd + +#endif // LIBND4J_MEMORYREPORT_H diff --git a/libnd4j/include/memory/MemoryTracker.h b/libnd4j/include/memory/MemoryTracker.h index 36a54f0c5208..824335442ad2 100644 --- a/libnd4j/include/memory/MemoryTracker.h +++ b/libnd4j/include/memory/MemoryTracker.h @@ -21,38 +21,41 @@ #ifndef SD_MEMORYTRACKER_H #define SD_MEMORYTRACKER_H -#include -#include +#include #include + +#include #include +#include + #include "AllocationEntry.h" -#include namespace sd { - namespace memory { - /** - * This class is used for tracking memory allocation wrt their allocation points in code - */ - class SD_EXPORT MemoryTracker { - private: - static MemoryTracker* _INSTANCE; - std::map _allocations; - std::map _released; - std::mutex _locker; - - MemoryTracker(); - ~MemoryTracker() = default; - public: - static MemoryTracker* getInstance(); - - void countIn(MemoryType type, Nd4jPointer ptr, Nd4jLong numBytes); - void countOut(Nd4jPointer ptr); - - void summarize(); - void reset(); - }; - } -} - - -#endif //SD_MEMORYTRACKER_H +namespace memory { +/** + * This class is used for tracking memory allocation wrt their allocation points + * in code + */ +class SD_EXPORT MemoryTracker { + private: + static MemoryTracker* _INSTANCE; + std::map _allocations; + std::map _released; + std::mutex _locker; + + MemoryTracker(); + ~MemoryTracker() = default; + + public: + static MemoryTracker* getInstance(); + + void countIn(MemoryType type, Nd4jPointer ptr, Nd4jLong numBytes); + void countOut(Nd4jPointer ptr); + + void summarize(); + void reset(); +}; +} // namespace memory +} // namespace sd + +#endif // SD_MEMORYTRACKER_H diff --git a/libnd4j/include/memory/MemoryType.h b/libnd4j/include/memory/MemoryType.h index 4f9ced600e26..6cf1091cdee4 100644 --- a/libnd4j/include/memory/MemoryType.h +++ b/libnd4j/include/memory/MemoryType.h @@ -6,12 +6,12 @@ #define SD_MEMORYTYPE_H namespace sd { - namespace memory { - enum MemoryType { - HOST = 0, - DEVICE = 10, - }; - } +namespace memory { +enum MemoryType { + HOST = 0, + DEVICE = 10, +}; } +} // namespace sd -#endif //SD_MEMORYTYPE_H +#endif // SD_MEMORYTYPE_H diff --git a/libnd4j/include/memory/MemoryUtils.h b/libnd4j/include/memory/MemoryUtils.h index c53ffa76752d..a3e749144219 100644 --- a/libnd4j/include/memory/MemoryUtils.h +++ b/libnd4j/include/memory/MemoryUtils.h @@ -21,18 +21,17 @@ #ifndef LIBND4J_MEMORYUTILS_H #define LIBND4J_MEMORYUTILS_H -#include "MemoryReport.h" #include -namespace sd { - namespace memory { - class SD_EXPORT MemoryUtils { - public: - static bool retrieveMemoryStatistics(MemoryReport& report); - }; - } -} - +#include "MemoryReport.h" +namespace sd { +namespace memory { +class SD_EXPORT MemoryUtils { + public: + static bool retrieveMemoryStatistics(MemoryReport& report); +}; +} // namespace memory +} // namespace sd -#endif //LIBND4J_MEMORYUTILS_H +#endif // LIBND4J_MEMORYUTILS_H diff --git a/libnd4j/include/memory/MemoryZone.h b/libnd4j/include/memory/MemoryZone.h index 16921e9b5d79..666f6c9e6a51 100644 --- a/libnd4j/include/memory/MemoryZone.h +++ b/libnd4j/include/memory/MemoryZone.h @@ -22,13 +22,13 @@ #define SD_MEMORYZONE_H namespace sd { - namespace memory { - enum MemoryZone { - COLD = 0, - WARM = 10, - HOT = 20, - }; - } +namespace memory { +enum MemoryZone { + COLD = 0, + WARM = 10, + HOT = 20, +}; } +} // namespace sd -#endif //SD_MEMORYZONE_H +#endif // SD_MEMORYZONE_H diff --git a/libnd4j/include/memory/WarmZoneManager.h b/libnd4j/include/memory/WarmZoneManager.h index 7d830736abfa..05700805b169 100644 --- a/libnd4j/include/memory/WarmZoneManager.h +++ b/libnd4j/include/memory/WarmZoneManager.h @@ -24,15 +24,14 @@ #include namespace sd { - namespace memory { - class SD_EXPORT WarmZoneManager : public ZoneManager { - protected: - public: - WarmZoneManager() = default; - ~WarmZoneManager() = default; - }; - } -} +namespace memory { +class SD_EXPORT WarmZoneManager : public ZoneManager { + protected: + public: + WarmZoneManager() = default; + ~WarmZoneManager() = default; +}; +} // namespace memory +} // namespace sd - -#endif //SD_WARMZONEMANAGER_H +#endif // SD_WARMZONEMANAGER_H diff --git a/libnd4j/include/memory/Workspace.h b/libnd4j/include/memory/Workspace.h index 0f2bbfd94c8b..098856ba8876 100644 --- a/libnd4j/include/memory/Workspace.h +++ b/libnd4j/include/memory/Workspace.h @@ -24,85 +24,88 @@ #ifndef LIBND4J_WORKSPACE_H #define LIBND4J_WORKSPACE_H -#include -#include -#include +#include +#include #include #include #include -#include -#include + +#include +#include +#include namespace sd { - namespace memory { +namespace memory { + +class SD_EXPORT Workspace { + protected: + char* _ptrHost = nullptr; + char* _ptrDevice = nullptr; - class SD_EXPORT Workspace { - protected: - char* _ptrHost = nullptr; - char* _ptrDevice = nullptr; + bool _allocatedHost = false; + bool _allocatedDevice = false; - bool _allocatedHost = false; - bool _allocatedDevice = false; + std::atomic _offset; + std::atomic _offsetSecondary; - std::atomic _offset; - std::atomic _offsetSecondary; + Nd4jLong _initialSize = 0L; + Nd4jLong _initialSizeSecondary = 0L; - Nd4jLong _initialSize = 0L; - Nd4jLong _initialSizeSecondary = 0L; + Nd4jLong _currentSize = 0L; + Nd4jLong _currentSizeSecondary = 0L; - Nd4jLong _currentSize = 0L; - Nd4jLong _currentSizeSecondary = 0L; + std::mutex _mutexAllocation; + std::mutex _mutexSpills; - std::mutex _mutexAllocation; - std::mutex _mutexSpills; + bool _externalized = false; - bool _externalized = false; + std::vector _spills; + std::vector _spillsSecondary; - std::vector _spills; - std::vector _spillsSecondary; + std::atomic _spillsSize; + std::atomic _cycleAllocations; - std::atomic _spillsSize; - std::atomic _cycleAllocations; + std::atomic _spillsSizeSecondary; + std::atomic _cycleAllocationsSecondary; - std::atomic _spillsSizeSecondary; - std::atomic _cycleAllocationsSecondary; + void init(Nd4jLong primaryBytes, Nd4jLong secondaryBytes = 0L); + void freeSpills(); - void init(Nd4jLong primaryBytes, Nd4jLong secondaryBytes = 0L); - void freeSpills(); - public: - explicit Workspace(ExternalWorkspace *external); - Workspace(Nd4jLong initialSize = 0L, Nd4jLong secondaryBytes = 0L); - ~Workspace(); + public: + explicit Workspace(ExternalWorkspace* external); + Workspace(Nd4jLong initialSize = 0L, Nd4jLong secondaryBytes = 0L); + ~Workspace(); - Nd4jLong getAllocatedSize(); - Nd4jLong getCurrentSize(); - Nd4jLong getCurrentOffset(); - Nd4jLong getSpilledSize(); - Nd4jLong getUsedSize(); + Nd4jLong getAllocatedSize(); + Nd4jLong getCurrentSize(); + Nd4jLong getCurrentOffset(); + Nd4jLong getSpilledSize(); + Nd4jLong getUsedSize(); - Nd4jLong getAllocatedSecondarySize(); - Nd4jLong getCurrentSecondarySize(); - Nd4jLong getCurrentSecondaryOffset(); - Nd4jLong getSpilledSecondarySize(); - Nd4jLong getUsedSecondarySize(); + Nd4jLong getAllocatedSecondarySize(); + Nd4jLong getCurrentSecondarySize(); + Nd4jLong getCurrentSecondaryOffset(); + Nd4jLong getSpilledSecondarySize(); + Nd4jLong getUsedSecondarySize(); - void expandBy(Nd4jLong primaryBytes, Nd4jLong secondaryBytes = 0L); - void expandTo(Nd4jLong primaryBytes, Nd4jLong secondaryBytes = 0L); + void expandBy(Nd4jLong primaryBytes, Nd4jLong secondaryBytes = 0L); + void expandTo(Nd4jLong primaryBytes, Nd4jLong secondaryBytes = 0L); -// bool resizeSupported(); + // bool resizeSupported(); - void* allocateBytes(Nd4jLong numBytes); - void* allocateBytes(MemoryType type, Nd4jLong numBytes); + void* allocateBytes(Nd4jLong numBytes); + void* allocateBytes(MemoryType type, Nd4jLong numBytes); - void scopeIn(); - void scopeOut(); + void scopeIn(); + void scopeOut(); - /* - * This method creates NEW workspace of the same memory size and returns pointer to it - */ - Workspace* clone(); - }; - } -} + /* + * This method creates NEW workspace of the same memory size and returns + * pointer to it + */ + Workspace* clone(); +}; +} // namespace memory +} // namespace sd -#endif //LIBND4J_WORKSPACE_H +#endif // LIBND4J_WORKSPACE_H diff --git a/libnd4j/include/memory/ZoneManager.h b/libnd4j/include/memory/ZoneManager.h index e097e8e5952d..67d79fa3488c 100644 --- a/libnd4j/include/memory/ZoneManager.h +++ b/libnd4j/include/memory/ZoneManager.h @@ -21,61 +21,60 @@ #ifndef SD_ZONEMANAGER_H #define SD_ZONEMANAGER_H -#include -#include #include +#include +#include + #include #include namespace sd { - namespace memory { - /** - * Abstract class that defines common methods for zone managers - */ - class SD_EXPORT ZoneManager { - protected: - std::mutex _lock; - - public: - ZoneManager() = default; - - virtual ~ZoneManager() = default; - - /** - * This method returns id of the current zone served by this manager instance - * @return MemoryZone enum - */ - virtual MemoryZone zone() const = 0; +namespace memory { +/** + * Abstract class that defines common methods for zone managers + */ +class SD_EXPORT ZoneManager { + protected: + std::mutex _lock; - /** - * This method returns amount of memory available in this zone - * @return number of bytes - */ - virtual uint64_t available() const = 0; + public: + ZoneManager() = default; - /** - * This method returns amount of memory currently used in this zone - * @return number of bytes - */ - virtual uint64_t used() const = 0; + virtual ~ZoneManager() = default; - /** - * This method allocates (probably) some memory chunk, and returns you pointer to it. - * @param numBytes - * @return - */ - virtual MemoryDescriptor allocate(uint64_t numBytes) = 0; + /** + * This method returns id of the current zone served by this manager instance + * @return MemoryZone enum + */ + virtual MemoryZone zone() const = 0; - /** - * This method releases (probably) memory described by given MemoryDescriptor - * @param descriptor - */ - virtual void release(MemoryDescriptor &descriptor) = 0; + /** + * This method returns amount of memory available in this zone + * @return number of bytes + */ + virtual uint64_t available() const = 0; + /** + * This method returns amount of memory currently used in this zone + * @return number of bytes + */ + virtual uint64_t used() const = 0; - }; - } -} + /** + * This method allocates (probably) some memory chunk, and returns you pointer + * to it. + * @param numBytes + * @return + */ + virtual MemoryDescriptor allocate(uint64_t numBytes) = 0; + /** + * This method releases (probably) memory described by given MemoryDescriptor + * @param descriptor + */ + virtual void release(MemoryDescriptor &descriptor) = 0; +}; +} // namespace memory +} // namespace sd -#endif //SD_ZONEMANAGER_H +#endif // SD_ZONEMANAGER_H diff --git a/libnd4j/include/memory/cpu/ColdZoneManager.cpp b/libnd4j/include/memory/cpu/ColdZoneManager.cpp index a9599c229d0c..25c00ef27036 100644 --- a/libnd4j/include/memory/cpu/ColdZoneManager.cpp +++ b/libnd4j/include/memory/cpu/ColdZoneManager.cpp @@ -21,29 +21,23 @@ #include namespace sd { - namespace memory { - ColdZoneManager::ColdZoneManager(const char *filename) { - // - } - - MemoryZone ColdZoneManager::zone() const { - return COLD; - } - - uint64_t ColdZoneManager::available() const { - return 0; - } - - uint64_t ColdZoneManager::used() const { - return 0; - } - - MemoryDescriptor ColdZoneManager::allocate(uint64_t numBytes) { - return MemoryDescriptor(nullptr, COLD, numBytes); - } - - void ColdZoneManager::release(MemoryDescriptor &descriptor) { - // - } - } +namespace memory { +ColdZoneManager::ColdZoneManager(const char *filename) { + // } + +MemoryZone ColdZoneManager::zone() const { return COLD; } + +uint64_t ColdZoneManager::available() const { return 0; } + +uint64_t ColdZoneManager::used() const { return 0; } + +MemoryDescriptor ColdZoneManager::allocate(uint64_t numBytes) { + return MemoryDescriptor(nullptr, COLD, numBytes); +} + +void ColdZoneManager::release(MemoryDescriptor &descriptor) { + // +} +} // namespace memory +} // namespace sd diff --git a/libnd4j/include/memory/cpu/GraphMemoryManager.cpp b/libnd4j/include/memory/cpu/GraphMemoryManager.cpp index 47c1653eaed4..afca4b9dfe0d 100644 --- a/libnd4j/include/memory/cpu/GraphMemoryManager.cpp +++ b/libnd4j/include/memory/cpu/GraphMemoryManager.cpp @@ -18,34 +18,34 @@ // @author raver119@gmail.com // +#include #include #include -#include namespace sd { - namespace graph { - GraphMemoryManager::GraphMemoryManager() { - // first of all we initialize all memory managers - // CPU backend only has two: HOT and COLD - - _zones[MemoryZone::HOT] = new memory::HotRamZoneManager(); - _zones[MemoryZone::COLD] = new memory::ColdZoneManager(); - } - - GraphMemoryManager::~GraphMemoryManager() { - delete _zones[MemoryZone::HOT]; - delete _zones[MemoryZone::COLD]; - } - - MemoryDescriptor GraphMemoryManager::allocate(size_t numBytes, MemoryZone zone) { - if (zone == MemoryZone::WARM) - zone = MemoryZone::HOT; - - return _zones[zone]->allocate(numBytes); - } - - void GraphMemoryManager::release(MemoryDescriptor &descriptor) { - _zones[descriptor.zone()]->release(descriptor); - } - } +namespace graph { +GraphMemoryManager::GraphMemoryManager() { + // first of all we initialize all memory managers + // CPU backend only has two: HOT and COLD + + _zones[MemoryZone::HOT] = new memory::HotRamZoneManager(); + _zones[MemoryZone::COLD] = new memory::ColdZoneManager(); +} + +GraphMemoryManager::~GraphMemoryManager() { + delete _zones[MemoryZone::HOT]; + delete _zones[MemoryZone::COLD]; +} + +MemoryDescriptor GraphMemoryManager::allocate(size_t numBytes, + MemoryZone zone) { + if (zone == MemoryZone::WARM) zone = MemoryZone::HOT; + + return _zones[zone]->allocate(numBytes); +} + +void GraphMemoryManager::release(MemoryDescriptor &descriptor) { + _zones[descriptor.zone()]->release(descriptor); } +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/memory/cpu/HotZoneManager.cpp b/libnd4j/include/memory/cpu/HotZoneManager.cpp index 229029ad153e..73db24a94eb3 100644 --- a/libnd4j/include/memory/cpu/HotZoneManager.cpp +++ b/libnd4j/include/memory/cpu/HotZoneManager.cpp @@ -20,19 +20,12 @@ #include - namespace sd { - namespace memory { - MemoryZone HotZoneManager::zone() const { - return HOT; - } +namespace memory { +MemoryZone HotZoneManager::zone() const { return HOT; } - uint64_t HotZoneManager::available() const { - return _available; - } +uint64_t HotZoneManager::available() const { return _available; } - uint64_t HotZoneManager::used() const { - return _used; - } - } -} \ No newline at end of file +uint64_t HotZoneManager::used() const { return _used; } +} // namespace memory +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/memory/cpu/Workspace.cpp b/libnd4j/include/memory/cpu/Workspace.cpp index ae60f1eeac30..a25195c44f16 100644 --- a/libnd4j/include/memory/cpu/Workspace.cpp +++ b/libnd4j/include/memory/cpu/Workspace.cpp @@ -20,199 +20,181 @@ // @author raver119@gmail.com // - -#include -#include -#include -#include #include "../Workspace.h" + #include #include -#include +#include +#include +#include +#include +#include namespace sd { - namespace memory { - Workspace::Workspace(ExternalWorkspace *external) { - if (external->sizeHost() > 0) { - _ptrHost = (char *) external->pointerHost(); - _ptrDevice = (char *) external->pointerDevice(); - - _initialSize = external->sizeHost(); - _currentSize = external->sizeHost(); - _offset = 0L; - _offsetSecondary = 0L; - this->_cycleAllocations = 0; - this->_spillsSize = 0; - - _externalized = true; - } - }; - - Workspace::Workspace(Nd4jLong initialSize, Nd4jLong secondaryBytes) { - if (initialSize > 0) { - this->_ptrHost = (char *) malloc(initialSize); - - CHECK_ALLOC(this->_ptrHost, "Failed to allocate new workspace", initialSize); - - memset(this->_ptrHost, 0, initialSize); - this->_allocatedHost = true; - } else - this->_allocatedHost = false; - - this->_initialSize = initialSize; - this->_currentSize = initialSize; - this->_currentSizeSecondary = 0; - this->_spillsSizeSecondary = 0; - this->_offset = 0; - this->_offsetSecondary = 0; - this->_cycleAllocations = 0; - this->_spillsSize = 0; - } - - void Workspace::init(Nd4jLong bytes, Nd4jLong secondaryBytes) { - if (this->_currentSize < bytes) { - if (this->_allocatedHost && !_externalized) - free((void *)this->_ptrHost); - - this->_ptrHost =(char *) malloc(bytes); +namespace memory { +Workspace::Workspace(ExternalWorkspace *external) { + if (external->sizeHost() > 0) { + _ptrHost = (char *)external->pointerHost(); + _ptrDevice = (char *)external->pointerDevice(); + + _initialSize = external->sizeHost(); + _currentSize = external->sizeHost(); + _offset = 0L; + _offsetSecondary = 0L; + this->_cycleAllocations = 0; + this->_spillsSize = 0; + + _externalized = true; + } +}; + +Workspace::Workspace(Nd4jLong initialSize, Nd4jLong secondaryBytes) { + if (initialSize > 0) { + this->_ptrHost = (char *)malloc(initialSize); + + CHECK_ALLOC(this->_ptrHost, "Failed to allocate new workspace", + initialSize); + + memset(this->_ptrHost, 0, initialSize); + this->_allocatedHost = true; + } else + this->_allocatedHost = false; + + this->_initialSize = initialSize; + this->_currentSize = initialSize; + this->_currentSizeSecondary = 0; + this->_spillsSizeSecondary = 0; + this->_offset = 0; + this->_offsetSecondary = 0; + this->_cycleAllocations = 0; + this->_spillsSize = 0; +} - CHECK_ALLOC(this->_ptrHost, "Failed to allocate new workspace", bytes); +void Workspace::init(Nd4jLong bytes, Nd4jLong secondaryBytes) { + if (this->_currentSize < bytes) { + if (this->_allocatedHost && !_externalized) free((void *)this->_ptrHost); - memset(this->_ptrHost, 0, bytes); - this->_currentSize = bytes; - this->_allocatedHost = true; - } - } + this->_ptrHost = (char *)malloc(bytes); - void Workspace::expandBy(Nd4jLong numBytes, Nd4jLong secondaryBytes) { - this->init(_currentSize + numBytes, _currentSizeSecondary + secondaryBytes); - } + CHECK_ALLOC(this->_ptrHost, "Failed to allocate new workspace", bytes); - void Workspace::expandTo(Nd4jLong numBytes, Nd4jLong secondaryBytes) { - this->init(numBytes, secondaryBytes); - } + memset(this->_ptrHost, 0, bytes); + this->_currentSize = bytes; + this->_allocatedHost = true; + } +} - void Workspace::freeSpills() { - _spillsSize = 0; +void Workspace::expandBy(Nd4jLong numBytes, Nd4jLong secondaryBytes) { + this->init(_currentSize + numBytes, _currentSizeSecondary + secondaryBytes); +} - if (_spills.size() < 1) - return; +void Workspace::expandTo(Nd4jLong numBytes, Nd4jLong secondaryBytes) { + this->init(numBytes, secondaryBytes); +} - for (auto v:_spills) - free(v); +void Workspace::freeSpills() { + _spillsSize = 0; - _spills.clear(); - } + if (_spills.size() < 1) return; - Workspace::~Workspace() { - if (this->_allocatedHost && !_externalized) - free((void *)this->_ptrHost); + for (auto v : _spills) free(v); - freeSpills(); - } + _spills.clear(); +} - Nd4jLong Workspace::getUsedSize() { - return getCurrentOffset(); - } +Workspace::~Workspace() { + if (this->_allocatedHost && !_externalized) free((void *)this->_ptrHost); - Nd4jLong Workspace::getCurrentSize() { - return _currentSize; - } + freeSpills(); +} - Nd4jLong Workspace::getCurrentOffset() { - return _offset.load(); - } +Nd4jLong Workspace::getUsedSize() { return getCurrentOffset(); } +Nd4jLong Workspace::getCurrentSize() { return _currentSize; } - void* Workspace::allocateBytes(Nd4jLong numBytes) { - if (numBytes < 1) - throw allocation_exception::build("Number of bytes for allocation should be positive", numBytes); +Nd4jLong Workspace::getCurrentOffset() { return _offset.load(); } +void *Workspace::allocateBytes(Nd4jLong numBytes) { + if (numBytes < 1) + throw allocation_exception::build( + "Number of bytes for allocation should be positive", numBytes); - //numBytes += 32; - void* result = nullptr; - this->_cycleAllocations += numBytes; - this->_mutexAllocation.lock(); + // numBytes += 32; + void *result = nullptr; + this->_cycleAllocations += numBytes; + this->_mutexAllocation.lock(); - if (_offset.load() + numBytes > _currentSize) { - nd4j_debug("Allocating %lld bytes in spills\n", numBytes); - this->_mutexAllocation.unlock(); + if (_offset.load() + numBytes > _currentSize) { + nd4j_debug("Allocating %lld bytes in spills\n", numBytes); + this->_mutexAllocation.unlock(); - void *p = malloc(numBytes); + void *p = malloc(numBytes); - CHECK_ALLOC(p, "Failed to allocate new workspace", numBytes); + CHECK_ALLOC(p, "Failed to allocate new workspace", numBytes); - _mutexSpills.lock(); - _spills.push_back(p); - _mutexSpills.unlock(); + _mutexSpills.lock(); + _spills.push_back(p); + _mutexSpills.unlock(); - _spillsSize += numBytes; + _spillsSize += numBytes; - return p; - } + return p; + } - result = (void *)(_ptrHost + _offset.load()); - _offset += numBytes; - //memset(result, 0, (int) numBytes); + result = (void *)(_ptrHost + _offset.load()); + _offset += numBytes; + // memset(result, 0, (int) numBytes); - nd4j_debug("Allocating %lld bytes from workspace; Current PTR: %p; Current offset: %lld\n", numBytes, result, _offset.load()); + nd4j_debug( + "Allocating %lld bytes from workspace; Current PTR: %p; Current offset: " + "%lld\n", + numBytes, result, _offset.load()); - this->_mutexAllocation.unlock(); + this->_mutexAllocation.unlock(); - return result; - } + return result; +} - Nd4jLong Workspace::getAllocatedSize() { - return getCurrentSize() + getSpilledSize(); - } +Nd4jLong Workspace::getAllocatedSize() { + return getCurrentSize() + getSpilledSize(); +} - void Workspace::scopeIn() { - freeSpills(); - init(_cycleAllocations.load()); - _cycleAllocations = 0; - } +void Workspace::scopeIn() { + freeSpills(); + init(_cycleAllocations.load()); + _cycleAllocations = 0; +} - void Workspace::scopeOut() { - _offset = 0; - _offsetSecondary = 0; - } +void Workspace::scopeOut() { + _offset = 0; + _offsetSecondary = 0; +} - Nd4jLong Workspace::getSpilledSize() { - return _spillsSize.load(); - } +Nd4jLong Workspace::getSpilledSize() { return _spillsSize.load(); } - void* Workspace::allocateBytes(sd::memory::MemoryType type, Nd4jLong numBytes) { - if (type == DEVICE) - throw std::runtime_error("CPU backend doesn't have device memory"); +void *Workspace::allocateBytes(sd::memory::MemoryType type, Nd4jLong numBytes) { + if (type == DEVICE) + throw std::runtime_error("CPU backend doesn't have device memory"); - return this->allocateBytes(numBytes); - } + return this->allocateBytes(numBytes); +} - Nd4jLong Workspace::getAllocatedSecondarySize() { - return 0L; - } +Nd4jLong Workspace::getAllocatedSecondarySize() { return 0L; } - Nd4jLong Workspace::getCurrentSecondarySize() { - return 0L; - } +Nd4jLong Workspace::getCurrentSecondarySize() { return 0L; } - Nd4jLong Workspace::getCurrentSecondaryOffset() { - return 0L; - } +Nd4jLong Workspace::getCurrentSecondaryOffset() { return 0L; } - Nd4jLong Workspace::getSpilledSecondarySize() { - return 0L; - } +Nd4jLong Workspace::getSpilledSecondarySize() { return 0L; } - Nd4jLong Workspace::getUsedSecondarySize() { - return 0L; - } +Nd4jLong Workspace::getUsedSecondarySize() { return 0L; } - Workspace* Workspace::clone() { - // for clone we take whatever is higher: current allocated size, or allocated size of current loop - return new Workspace(sd::math::nd4j_max(this->getCurrentSize(), this->_cycleAllocations.load())); - } - } +Workspace *Workspace::clone() { + // for clone we take whatever is higher: current allocated size, or allocated + // size of current loop + return new Workspace(sd::math::nd4j_max( + this->getCurrentSize(), this->_cycleAllocations.load())); } - +} // namespace memory +} // namespace sd diff --git a/libnd4j/include/memory/cuda/Workspace.cu b/libnd4j/include/memory/cuda/Workspace.cu index 9d228615689f..2d57b43c5ce3 100644 --- a/libnd4j/include/memory/cuda/Workspace.cu +++ b/libnd4j/include/memory/cuda/Workspace.cu @@ -20,278 +20,273 @@ // @author raver119@gmail.com // -#include -#include -#include -#include -#include "../Workspace.h" +#include +#include +#include #include #include +#include +#include +#include + +#include #include -#include -#include -#include +#include "../Workspace.h" namespace sd { - namespace memory { - Workspace::Workspace(ExternalWorkspace *external) { - if (external->sizeHost() > 0) { - _ptrHost = (char *) external->pointerHost(); - _ptrDevice = (char *) external->pointerDevice(); - - _initialSize = external->sizeDevice(); - _currentSize = external->sizeDevice(); - _initialSizeSecondary = external->sizeHost(); - _currentSizeSecondary = external->sizeHost(); - _offset = 0L; - _offsetSecondary = 0L; - this->_cycleAllocations = 0; - this->_cycleAllocationsSecondary = 0; - this->_spillsSize = 0; - this->_spillsSizeSecondary = 0; - - _externalized = true; - } - } - - Workspace::Workspace(Nd4jLong primarySize, Nd4jLong secondarySize) { - if (secondarySize > 0) { - auto res = cudaHostAlloc(reinterpret_cast(&_ptrHost), secondarySize, cudaHostAllocDefault); - if (res != 0) - throw cuda_exception::build("Can't allocate [HOST] memory", res); - - cudaMemset(this->_ptrHost, 0, secondarySize); - this->_allocatedHost = true; - } else - this->_allocatedHost = false; - - if (primarySize > 0) { - auto res = cudaMalloc(reinterpret_cast(&_ptrDevice), primarySize); - if (res != 0) - throw cuda_exception::build("Can't allocate [DEVICE] memory", res); - - cudaMemset(this->_ptrDevice, 0, primarySize); - this->_allocatedDevice = true; - } else - this->_allocatedDevice = false; - - this->_initialSize = primarySize; - this->_initialSizeSecondary = secondarySize; - this->_currentSize = primarySize; - this->_currentSizeSecondary = secondarySize; - this->_offset = 0; - this->_offsetSecondary = 0; - this->_cycleAllocations = 0; - this->_spillsSize = 0; - this->_spillsSizeSecondary = 0; - } - - void Workspace::init(Nd4jLong primaryBytes, Nd4jLong secondaryBytes) { - if (this->_currentSize < primaryBytes) { - if (this->_allocatedDevice && !_externalized) - cudaFree((void *)this->_ptrDevice); - - auto res = cudaMalloc(reinterpret_cast(&_ptrDevice), secondaryBytes); - if (res != 0) - throw cuda_exception::build("Can't allocate [DEVICE] memory", res); - - cudaMemset(this->_ptrDevice, 0, primaryBytes); - this->_currentSize = primaryBytes; - this->_allocatedDevice = true; - } - - if (this->_currentSizeSecondary < secondaryBytes) { - if (this->_allocatedHost && !_externalized) - cudaFreeHost((void *)this->_ptrHost); - - auto res = cudaHostAlloc(reinterpret_cast(&_ptrHost), secondaryBytes, cudaHostAllocDefault); - if (res != 0) - throw cuda_exception::build("Can't allocate [HOST] memory", res); - - - cudaMemset(this->_ptrHost, 0, secondaryBytes); - this->_currentSizeSecondary = secondaryBytes; - this->_allocatedHost = true; - } - } - - void Workspace::expandBy(Nd4jLong numBytes, Nd4jLong secondaryBytes) { - this->init(_currentSize + numBytes, _currentSizeSecondary + secondaryBytes); - } - - void Workspace::expandTo(Nd4jLong numBytes, Nd4jLong secondaryBytes) { - this->init(numBytes, secondaryBytes); - } - - void Workspace::freeSpills() { - _spillsSize = 0; - _spillsSizeSecondary = 0; - - for (auto v:_spills) - cudaFree(v); - - for (auto v:_spillsSecondary) - cudaFreeHost(v); - - _spills.clear(); - _spillsSecondary.clear(); - } - - Workspace::~Workspace() { - if (this->_allocatedHost && !_externalized) - cudaFreeHost((void *)this->_ptrHost); - - if (this->_allocatedDevice && !_externalized) - cudaFree((void *)this->_ptrDevice); - - freeSpills(); - } - - Nd4jLong Workspace::getUsedSize() { - return getCurrentOffset(); - } +namespace memory { +Workspace::Workspace(ExternalWorkspace *external) { + if (external->sizeHost() > 0) { + _ptrHost = (char *)external->pointerHost(); + _ptrDevice = (char *)external->pointerDevice(); + + _initialSize = external->sizeDevice(); + _currentSize = external->sizeDevice(); + _initialSizeSecondary = external->sizeHost(); + _currentSizeSecondary = external->sizeHost(); + _offset = 0L; + _offsetSecondary = 0L; + this->_cycleAllocations = 0; + this->_cycleAllocationsSecondary = 0; + this->_spillsSize = 0; + this->_spillsSizeSecondary = 0; + + _externalized = true; + } +} + +Workspace::Workspace(Nd4jLong primarySize, Nd4jLong secondarySize) { + if (secondarySize > 0) { + auto res = cudaHostAlloc(reinterpret_cast(&_ptrHost), + secondarySize, cudaHostAllocDefault); + if (res != 0) + throw cuda_exception::build("Can't allocate [HOST] memory", res); + + cudaMemset(this->_ptrHost, 0, secondarySize); + this->_allocatedHost = true; + } else + this->_allocatedHost = false; + + if (primarySize > 0) { + auto res = cudaMalloc(reinterpret_cast(&_ptrDevice), primarySize); + if (res != 0) + throw cuda_exception::build("Can't allocate [DEVICE] memory", res); + + cudaMemset(this->_ptrDevice, 0, primarySize); + this->_allocatedDevice = true; + } else + this->_allocatedDevice = false; + + this->_initialSize = primarySize; + this->_initialSizeSecondary = secondarySize; + this->_currentSize = primarySize; + this->_currentSizeSecondary = secondarySize; + this->_offset = 0; + this->_offsetSecondary = 0; + this->_cycleAllocations = 0; + this->_spillsSize = 0; + this->_spillsSizeSecondary = 0; +} - Nd4jLong Workspace::getCurrentSize() { - return _currentSize; - } +void Workspace::init(Nd4jLong primaryBytes, Nd4jLong secondaryBytes) { + if (this->_currentSize < primaryBytes) { + if (this->_allocatedDevice && !_externalized) + cudaFree((void *)this->_ptrDevice); + + auto res = + cudaMalloc(reinterpret_cast(&_ptrDevice), secondaryBytes); + if (res != 0) + throw cuda_exception::build("Can't allocate [DEVICE] memory", res); + + cudaMemset(this->_ptrDevice, 0, primaryBytes); + this->_currentSize = primaryBytes; + this->_allocatedDevice = true; + } + + if (this->_currentSizeSecondary < secondaryBytes) { + if (this->_allocatedHost && !_externalized) + cudaFreeHost((void *)this->_ptrHost); + + auto res = cudaHostAlloc(reinterpret_cast(&_ptrHost), + secondaryBytes, cudaHostAllocDefault); + if (res != 0) + throw cuda_exception::build("Can't allocate [HOST] memory", res); + + cudaMemset(this->_ptrHost, 0, secondaryBytes); + this->_currentSizeSecondary = secondaryBytes; + this->_allocatedHost = true; + } +} + +void Workspace::expandBy(Nd4jLong numBytes, Nd4jLong secondaryBytes) { + this->init(_currentSize + numBytes, _currentSizeSecondary + secondaryBytes); +} + +void Workspace::expandTo(Nd4jLong numBytes, Nd4jLong secondaryBytes) { + this->init(numBytes, secondaryBytes); +} + +void Workspace::freeSpills() { + _spillsSize = 0; + _spillsSizeSecondary = 0; + + for (auto v : _spills) cudaFree(v); + + for (auto v : _spillsSecondary) cudaFreeHost(v); + + _spills.clear(); + _spillsSecondary.clear(); +} - Nd4jLong Workspace::getCurrentOffset() { - return _offset.load(); - } +Workspace::~Workspace() { + if (this->_allocatedHost && !_externalized) + cudaFreeHost((void *)this->_ptrHost); + if (this->_allocatedDevice && !_externalized) + cudaFree((void *)this->_ptrDevice); - void* Workspace::allocateBytes(Nd4jLong numBytes) { - return allocateBytes(sd::memory::MemoryType::HOST, numBytes); - } + freeSpills(); +} - Nd4jLong Workspace::getAllocatedSize() { - return getCurrentSize() + getSpilledSize(); - } +Nd4jLong Workspace::getUsedSize() { return getCurrentOffset(); } - void Workspace::scopeIn() { - freeSpills(); - init(_cycleAllocations.load()); - _cycleAllocations = 0; - } +Nd4jLong Workspace::getCurrentSize() { return _currentSize; } - void Workspace::scopeOut() { - _offset = 0; - } +Nd4jLong Workspace::getCurrentOffset() { return _offset.load(); } - Nd4jLong Workspace::getSpilledSize() { - return _spillsSize.load(); - } +void *Workspace::allocateBytes(Nd4jLong numBytes) { + return allocateBytes(sd::memory::MemoryType::HOST, numBytes); +} - void* Workspace::allocateBytes(sd::memory::MemoryType type, Nd4jLong numBytes) { - switch (type) { - case HOST: { - if (numBytes < 1) - throw allocation_exception::build("Number of [HOST] bytes for allocation should be positive", numBytes); +Nd4jLong Workspace::getAllocatedSize() { + return getCurrentSize() + getSpilledSize(); +} + +void Workspace::scopeIn() { + freeSpills(); + init(_cycleAllocations.load()); + _cycleAllocations = 0; +} +void Workspace::scopeOut() { _offset = 0; } - //numBytes += 32; - void* result = nullptr; - this->_cycleAllocationsSecondary += numBytes; - this->_mutexAllocation.lock(); +Nd4jLong Workspace::getSpilledSize() { return _spillsSize.load(); } - if (_offsetSecondary.load() + numBytes > _currentSizeSecondary) { - nd4j_debug("Allocating %lld [HOST] bytes in spills\n", numBytes); - this->_mutexAllocation.unlock(); +void *Workspace::allocateBytes(sd::memory::MemoryType type, Nd4jLong numBytes) { + switch (type) { + case HOST: { + if (numBytes < 1) + throw allocation_exception::build( + "Number of [HOST] bytes for allocation should be positive", + numBytes); - Nd4jPointer p; - auto res = cudaHostAlloc(reinterpret_cast(&p), numBytes, cudaHostAllocDefault); - if (res != 0) - throw cuda_exception::build("Can't allocate [HOST] memory", res); + // numBytes += 32; + void *result = nullptr; + this->_cycleAllocationsSecondary += numBytes; + this->_mutexAllocation.lock(); - _mutexSpills.lock(); - _spillsSecondary.push_back(p); - _mutexSpills.unlock(); + if (_offsetSecondary.load() + numBytes > _currentSizeSecondary) { + nd4j_debug("Allocating %lld [HOST] bytes in spills\n", numBytes); + this->_mutexAllocation.unlock(); - _spillsSizeSecondary += numBytes; + Nd4jPointer p; + auto res = cudaHostAlloc(reinterpret_cast(&p), numBytes, + cudaHostAllocDefault); + if (res != 0) + throw cuda_exception::build("Can't allocate [HOST] memory", res); - return p; - } + _mutexSpills.lock(); + _spillsSecondary.push_back(p); + _mutexSpills.unlock(); - result = (void *)(_ptrHost + _offsetSecondary.load()); - _offsetSecondary += numBytes; - //memset(result, 0, (int) numBytes); + _spillsSizeSecondary += numBytes; - nd4j_debug("Allocating %lld bytes from [HOST] workspace; Current PTR: %p; Current offset: %lld\n", numBytes, result, _offset.load()); + return p; + } - this->_mutexAllocation.unlock(); + result = (void *)(_ptrHost + _offsetSecondary.load()); + _offsetSecondary += numBytes; + // memset(result, 0, (int) numBytes); - return result; - } - break; - case DEVICE: { - if (numBytes < 1) - throw allocation_exception::build("Number of [DEVICE] bytes for allocation should be positive", numBytes); + nd4j_debug( + "Allocating %lld bytes from [HOST] workspace; Current PTR: %p; " + "Current offset: %lld\n", + numBytes, result, _offset.load()); + this->_mutexAllocation.unlock(); - //numBytes += 32; - void* result = nullptr; - this->_cycleAllocations += numBytes; - this->_mutexAllocation.lock(); + return result; + } break; + case DEVICE: { + if (numBytes < 1) + throw allocation_exception::build( + "Number of [DEVICE] bytes for allocation should be positive", + numBytes); - if (_offset.load() + numBytes > _currentSize) { - nd4j_debug("Allocating %lld [DEVICE] bytes in spills\n", numBytes); - this->_mutexAllocation.unlock(); + // numBytes += 32; + void *result = nullptr; + this->_cycleAllocations += numBytes; + this->_mutexAllocation.lock(); - Nd4jPointer p; - auto res = cudaMalloc(reinterpret_cast(&p), numBytes); - if (res != 0) - throw cuda_exception::build("Can't allocate [DEVICE] memory", res); + if (_offset.load() + numBytes > _currentSize) { + nd4j_debug("Allocating %lld [DEVICE] bytes in spills\n", numBytes); + this->_mutexAllocation.unlock(); - _mutexSpills.lock(); - _spills.push_back(p); - _mutexSpills.unlock(); + Nd4jPointer p; + auto res = cudaMalloc(reinterpret_cast(&p), numBytes); + if (res != 0) + throw cuda_exception::build("Can't allocate [DEVICE] memory", res); - _spillsSize += numBytes; + _mutexSpills.lock(); + _spills.push_back(p); + _mutexSpills.unlock(); - return p; - } + _spillsSize += numBytes; - result = (void *)(_ptrDevice + _offset.load()); - _offset += numBytes; - //memset(result, 0, (int) numBytes); + return p; + } - nd4j_debug("Allocating %lld bytes from [DEVICE] workspace; Current PTR: %p; Current offset: %lld\n", numBytes, result, _offset.load()); + result = (void *)(_ptrDevice + _offset.load()); + _offset += numBytes; + // memset(result, 0, (int) numBytes); - this->_mutexAllocation.unlock(); + nd4j_debug( + "Allocating %lld bytes from [DEVICE] workspace; Current PTR: %p; " + "Current offset: %lld\n", + numBytes, result, _offset.load()); - return result; - } - break; - default: - throw std::runtime_error("Unknown MemoryType was passed in"); - } - } + this->_mutexAllocation.unlock(); - Workspace* Workspace::clone() { - // for clone we take whatever is higher: current allocated size, or allocated size of current loop - return new Workspace(sd::math::nd4j_max(this->getCurrentSize(), this->_cycleAllocations.load())); - } + return result; + } break; + default: + throw std::runtime_error("Unknown MemoryType was passed in"); + } +} - Nd4jLong Workspace::getAllocatedSecondarySize() { - return getCurrentSecondarySize() + getSpilledSecondarySize(); - } +Workspace *Workspace::clone() { + // for clone we take whatever is higher: current allocated size, or allocated + // size of current loop + return new Workspace(sd::math::nd4j_max( + this->getCurrentSize(), this->_cycleAllocations.load())); +} - Nd4jLong Workspace::getCurrentSecondarySize() { - return _currentSizeSecondary; - } +Nd4jLong Workspace::getAllocatedSecondarySize() { + return getCurrentSecondarySize() + getSpilledSecondarySize(); +} - Nd4jLong Workspace::getCurrentSecondaryOffset() { - return _offsetSecondary.load(); - } +Nd4jLong Workspace::getCurrentSecondarySize() { return _currentSizeSecondary; } - Nd4jLong Workspace::getSpilledSecondarySize() { - return _spillsSizeSecondary; - } +Nd4jLong Workspace::getCurrentSecondaryOffset() { + return _offsetSecondary.load(); +} - Nd4jLong Workspace::getUsedSecondarySize() { - return getCurrentSecondaryOffset(); - } +Nd4jLong Workspace::getSpilledSecondarySize() { return _spillsSizeSecondary; } - } +Nd4jLong Workspace::getUsedSecondarySize() { + return getCurrentSecondaryOffset(); } + +} // namespace memory +} // namespace sd diff --git a/libnd4j/include/memory/impl/AllocationEntry.cpp b/libnd4j/include/memory/impl/AllocationEntry.cpp index 6b4d85bb1040..0a1598144b97 100644 --- a/libnd4j/include/memory/impl/AllocationEntry.cpp +++ b/libnd4j/include/memory/impl/AllocationEntry.cpp @@ -21,24 +21,19 @@ #include namespace sd { - namespace memory { - AllocationEntry::AllocationEntry(MemoryType type, Nd4jLong ptr, Nd4jLong numBytes, std::string &stack) { - _pointer = ptr; - _numBytes = numBytes; - _stack = stack; - _memoryType = type; - } +namespace memory { +AllocationEntry::AllocationEntry(MemoryType type, Nd4jLong ptr, + Nd4jLong numBytes, std::string &stack) { + _pointer = ptr; + _numBytes = numBytes; + _stack = stack; + _memoryType = type; +} - std::string AllocationEntry::stackTrace() { - return _stack; - } +std::string AllocationEntry::stackTrace() { return _stack; } - Nd4jLong AllocationEntry::numBytes() { - return _numBytes; - } +Nd4jLong AllocationEntry::numBytes() { return _numBytes; } - MemoryType AllocationEntry::memoryType() { - return _memoryType; - } - } -} \ No newline at end of file +MemoryType AllocationEntry::memoryType() { return _memoryType; } +} // namespace memory +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/memory/impl/ExternalWorkspace.cpp b/libnd4j/include/memory/impl/ExternalWorkspace.cpp index c4feb181dddb..e2a81b812b2d 100644 --- a/libnd4j/include/memory/impl/ExternalWorkspace.cpp +++ b/libnd4j/include/memory/impl/ExternalWorkspace.cpp @@ -21,29 +21,22 @@ #include namespace sd { - namespace memory { - ExternalWorkspace::ExternalWorkspace(Nd4jPointer ptrH, Nd4jLong sizeH, Nd4jPointer ptrD, Nd4jLong sizeD) { - _ptrH = ptrH; - _sizeH = sizeH; - - _ptrD = ptrD; - _sizeD = sizeD; - }; - - void* ExternalWorkspace::pointerHost() { - return _ptrH; - } - - void* ExternalWorkspace::pointerDevice() { - return _ptrD; - } - - Nd4jLong ExternalWorkspace::sizeHost() { - return _sizeH; - } - - Nd4jLong ExternalWorkspace::sizeDevice() { - return _sizeD; - } - } -} \ No newline at end of file +namespace memory { +ExternalWorkspace::ExternalWorkspace(Nd4jPointer ptrH, Nd4jLong sizeH, + Nd4jPointer ptrD, Nd4jLong sizeD) { + _ptrH = ptrH; + _sizeH = sizeH; + + _ptrD = ptrD; + _sizeD = sizeD; +}; + +void* ExternalWorkspace::pointerHost() { return _ptrH; } + +void* ExternalWorkspace::pointerDevice() { return _ptrD; } + +Nd4jLong ExternalWorkspace::sizeHost() { return _sizeH; } + +Nd4jLong ExternalWorkspace::sizeDevice() { return _sizeD; } +} // namespace memory +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/memory/impl/HotRamZoneManager.cpp b/libnd4j/include/memory/impl/HotRamZoneManager.cpp index 651eb501a30a..15066bd09bb1 100644 --- a/libnd4j/include/memory/impl/HotRamZoneManager.cpp +++ b/libnd4j/include/memory/impl/HotRamZoneManager.cpp @@ -21,18 +21,18 @@ #include namespace sd { - namespace memory { - MemoryDescriptor HotRamZoneManager::allocate(uint64_t numBytes) { - _used += numBytes; +namespace memory { +MemoryDescriptor HotRamZoneManager::allocate(uint64_t numBytes) { + _used += numBytes; - auto ptr = new int8_t[numBytes]; - return MemoryDescriptor(ptr, zone(), numBytes); - } + auto ptr = new int8_t[numBytes]; + return MemoryDescriptor(ptr, zone(), numBytes); +} - void HotRamZoneManager::release(MemoryDescriptor &descriptor) { - _used -= descriptor.bytes(); +void HotRamZoneManager::release(MemoryDescriptor &descriptor) { + _used -= descriptor.bytes(); - delete[](reinterpret_cast(descriptor.address())); - } - } -} \ No newline at end of file + delete[](reinterpret_cast(descriptor.address())); +} +} // namespace memory +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/memory/impl/MemoryCounter.cpp b/libnd4j/include/memory/impl/MemoryCounter.cpp index 96be346816f5..16ca82133d5a 100644 --- a/libnd4j/include/memory/impl/MemoryCounter.cpp +++ b/libnd4j/include/memory/impl/MemoryCounter.cpp @@ -19,115 +19,117 @@ // #include "../MemoryCounter.h" + #include -#include #include +#include namespace sd { - namespace memory { - - MemoryCounter::MemoryCounter() { - auto numDevices = sd::AffinityManager::numberOfDevices(); - - // setting default 0s - for (int e = 0; e < numDevices; e++) { - _deviceLimits[e] = 0; - _deviceCounters[e] = 0; - } - - // setting initial values for limits - _groupLimits[sd::memory::MemoryType::HOST] = sd::Environment::getInstance()->maxPrimaryMemory(); - _groupLimits[sd::memory::MemoryType::DEVICE] = sd::Environment::getInstance()->maxSpecialMemory(); - - // setting initial counter values - _groupCounters[sd::memory::MemoryType::HOST] = 0; - _groupCounters[sd::memory::MemoryType::DEVICE] = 0; - } - - MemoryCounter* MemoryCounter::getInstance() { - if (_INSTANCE == 0) - _INSTANCE = new MemoryCounter(); - - return _INSTANCE; - } - - void MemoryCounter::countIn(int deviceId, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - _deviceCounters[deviceId] += numBytes; - } - - void MemoryCounter::countIn(sd::memory::MemoryType group, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - _groupCounters[group] += numBytes; - } - - void MemoryCounter::countOut(int deviceId, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - _deviceCounters[deviceId] -= numBytes; - } - - void MemoryCounter::countOut(sd::memory::MemoryType group, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - _groupCounters[group] -= numBytes; - } - - bool MemoryCounter::validate(Nd4jLong numBytes) { - auto deviceId = sd::AffinityManager::currentDeviceId(); - return validateDevice(deviceId, numBytes); - } - - bool MemoryCounter::validateDevice(int deviceId, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - auto dLimit = _deviceLimits[deviceId]; - if (dLimit <= 0) - return true; - - auto dAlloc = _deviceCounters[deviceId]; - - return numBytes + dAlloc <= dLimit; - } - - bool MemoryCounter::validateGroup(sd::memory::MemoryType group, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - auto gLimit = _groupLimits[group]; - if (gLimit <= 0) - return true; - - auto gAlloc = _groupCounters[group]; - - return numBytes + gAlloc <= gLimit; - } - - Nd4jLong MemoryCounter::allocatedDevice(int deviceId) { - std::lock_guard lock(_locker); - return _deviceCounters[deviceId]; - } - - Nd4jLong MemoryCounter::allocatedGroup(sd::memory::MemoryType group) { - std::lock_guard lock(_locker); - return _groupCounters[group]; - } - - void MemoryCounter::setDeviceLimit(int deviceId, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - _deviceLimits[deviceId] = numBytes; - } - - void MemoryCounter::setGroupLimit(sd::memory::MemoryType group, Nd4jLong numBytes) { - std::lock_guard lock(_locker); - _groupLimits[group] = numBytes; - } - - Nd4jLong MemoryCounter::deviceLimit(int deviceId) { - std::lock_guard lock(_locker); - return _deviceLimits[deviceId]; - } - - Nd4jLong MemoryCounter::groupLimit(sd::memory::MemoryType group) { - std::lock_guard lock(_locker); - return _groupLimits[group]; - } - - MemoryCounter* MemoryCounter::_INSTANCE = 0; - } -} \ No newline at end of file +namespace memory { + +MemoryCounter::MemoryCounter() { + auto numDevices = sd::AffinityManager::numberOfDevices(); + + // setting default 0s + for (int e = 0; e < numDevices; e++) { + _deviceLimits[e] = 0; + _deviceCounters[e] = 0; + } + + // setting initial values for limits + _groupLimits[sd::memory::MemoryType::HOST] = + sd::Environment::getInstance()->maxPrimaryMemory(); + _groupLimits[sd::memory::MemoryType::DEVICE] = + sd::Environment::getInstance()->maxSpecialMemory(); + + // setting initial counter values + _groupCounters[sd::memory::MemoryType::HOST] = 0; + _groupCounters[sd::memory::MemoryType::DEVICE] = 0; +} + +MemoryCounter* MemoryCounter::getInstance() { + if (_INSTANCE == 0) _INSTANCE = new MemoryCounter(); + + return _INSTANCE; +} + +void MemoryCounter::countIn(int deviceId, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _deviceCounters[deviceId] += numBytes; +} + +void MemoryCounter::countIn(sd::memory::MemoryType group, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _groupCounters[group] += numBytes; +} + +void MemoryCounter::countOut(int deviceId, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _deviceCounters[deviceId] -= numBytes; +} + +void MemoryCounter::countOut(sd::memory::MemoryType group, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _groupCounters[group] -= numBytes; +} + +bool MemoryCounter::validate(Nd4jLong numBytes) { + auto deviceId = sd::AffinityManager::currentDeviceId(); + return validateDevice(deviceId, numBytes); +} + +bool MemoryCounter::validateDevice(int deviceId, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + auto dLimit = _deviceLimits[deviceId]; + if (dLimit <= 0) return true; + + auto dAlloc = _deviceCounters[deviceId]; + + return numBytes + dAlloc <= dLimit; +} + +bool MemoryCounter::validateGroup(sd::memory::MemoryType group, + Nd4jLong numBytes) { + std::lock_guard lock(_locker); + auto gLimit = _groupLimits[group]; + if (gLimit <= 0) return true; + + auto gAlloc = _groupCounters[group]; + + return numBytes + gAlloc <= gLimit; +} + +Nd4jLong MemoryCounter::allocatedDevice(int deviceId) { + std::lock_guard lock(_locker); + return _deviceCounters[deviceId]; +} + +Nd4jLong MemoryCounter::allocatedGroup(sd::memory::MemoryType group) { + std::lock_guard lock(_locker); + return _groupCounters[group]; +} + +void MemoryCounter::setDeviceLimit(int deviceId, Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _deviceLimits[deviceId] = numBytes; +} + +void MemoryCounter::setGroupLimit(sd::memory::MemoryType group, + Nd4jLong numBytes) { + std::lock_guard lock(_locker); + _groupLimits[group] = numBytes; +} + +Nd4jLong MemoryCounter::deviceLimit(int deviceId) { + std::lock_guard lock(_locker); + return _deviceLimits[deviceId]; +} + +Nd4jLong MemoryCounter::groupLimit(sd::memory::MemoryType group) { + std::lock_guard lock(_locker); + return _groupLimits[group]; +} + +MemoryCounter* MemoryCounter::_INSTANCE = 0; +} // namespace memory +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/memory/impl/MemoryDescriptor.cpp b/libnd4j/include/memory/impl/MemoryDescriptor.cpp index 7d3dfeb753f2..ea0c0e0b3c2a 100644 --- a/libnd4j/include/memory/impl/MemoryDescriptor.cpp +++ b/libnd4j/include/memory/impl/MemoryDescriptor.cpp @@ -21,51 +21,48 @@ #include namespace sd { - namespace memory { - MemoryDescriptor::MemoryDescriptor(void *ptr, MemoryZone zone, uint64_t bytes) : _ptr(ptr), _zone(zone), _bytes(bytes) { - // - } +namespace memory { +MemoryDescriptor::MemoryDescriptor(void *ptr, MemoryZone zone, uint64_t bytes) + : _ptr(ptr), _zone(zone), _bytes(bytes) { + // +} - MemoryDescriptor::MemoryDescriptor(const MemoryDescriptor &other) noexcept : _ptr(other._ptr), _zone(other._zone), _bytes(other._bytes) { - // - } +MemoryDescriptor::MemoryDescriptor(const MemoryDescriptor &other) noexcept + : _ptr(other._ptr), _zone(other._zone), _bytes(other._bytes) { + // +} - MemoryDescriptor &MemoryDescriptor::operator=(const MemoryDescriptor &other) noexcept { - if (this == &other) - return *this; +MemoryDescriptor &MemoryDescriptor::operator=( + const MemoryDescriptor &other) noexcept { + if (this == &other) return *this; - _ptr = other._ptr; - _zone = other._zone; - _bytes = other._bytes; + _ptr = other._ptr; + _zone = other._zone; + _bytes = other._bytes; - return *this; - } + return *this; +} - MemoryDescriptor::MemoryDescriptor(MemoryDescriptor &&other) noexcept : _ptr(other._ptr), _zone(other._zone), _bytes(other._bytes) { - // - } +MemoryDescriptor::MemoryDescriptor(MemoryDescriptor &&other) noexcept + : _ptr(other._ptr), _zone(other._zone), _bytes(other._bytes) { + // +} - MemoryDescriptor &MemoryDescriptor::operator=(MemoryDescriptor &&other) noexcept { - if (this == &other) - return *this; +MemoryDescriptor &MemoryDescriptor::operator=( + MemoryDescriptor &&other) noexcept { + if (this == &other) return *this; - _ptr = other._ptr; - _zone = other._zone; - _bytes = other._bytes; + _ptr = other._ptr; + _zone = other._zone; + _bytes = other._bytes; - return *this; - } + return *this; +} - void *MemoryDescriptor::address() const { - return _ptr; - } +void *MemoryDescriptor::address() const { return _ptr; } - MemoryZone MemoryDescriptor::zone() const { - return _zone; - } +MemoryZone MemoryDescriptor::zone() const { return _zone; } - uint64_t MemoryDescriptor::bytes() const { - return _bytes; - } - } -} +uint64_t MemoryDescriptor::bytes() const { return _bytes; } +} // namespace memory +} // namespace sd diff --git a/libnd4j/include/memory/impl/MemoryRegistrator.cpp b/libnd4j/include/memory/impl/MemoryRegistrator.cpp index 31b4b0eaee80..acf936039fdb 100644 --- a/libnd4j/include/memory/impl/MemoryRegistrator.cpp +++ b/libnd4j/include/memory/impl/MemoryRegistrator.cpp @@ -21,70 +21,60 @@ #include namespace sd { - namespace memory { +namespace memory { - MemoryRegistrator::MemoryRegistrator() { - _workspace = nullptr; - }; +MemoryRegistrator::MemoryRegistrator() { _workspace = nullptr; }; - MemoryRegistrator* MemoryRegistrator::getInstance() { - if (_INSTANCE == 0) - _INSTANCE = new MemoryRegistrator(); +MemoryRegistrator* MemoryRegistrator::getInstance() { + if (_INSTANCE == 0) _INSTANCE = new MemoryRegistrator(); - return _INSTANCE; - } + return _INSTANCE; +} - bool MemoryRegistrator::hasWorkspaceAttached() { - return _workspace != nullptr; - } +bool MemoryRegistrator::hasWorkspaceAttached() { return _workspace != nullptr; } - Workspace* MemoryRegistrator::getWorkspace() { - return _workspace; - } +Workspace* MemoryRegistrator::getWorkspace() { return _workspace; } - void MemoryRegistrator::attachWorkspace(Workspace* workspace) { - _workspace = workspace; - } +void MemoryRegistrator::attachWorkspace(Workspace* workspace) { + _workspace = workspace; +} - void MemoryRegistrator::forgetWorkspace() { - _workspace = nullptr; - } +void MemoryRegistrator::forgetWorkspace() { _workspace = nullptr; } - void MemoryRegistrator::setGraphMemoryFootprint(Nd4jLong hash, Nd4jLong bytes) { - _lock.lock(); - - _footprint[hash] = bytes; +void MemoryRegistrator::setGraphMemoryFootprint(Nd4jLong hash, Nd4jLong bytes) { + _lock.lock(); - _lock.unlock(); - } + _footprint[hash] = bytes; - void MemoryRegistrator::setGraphMemoryFootprintIfGreater(Nd4jLong hash, Nd4jLong bytes) { - _lock.lock(); + _lock.unlock(); +} - if (_footprint.count(hash) == 0) - _footprint[hash] = bytes; - else { - Nd4jLong cv = _footprint[hash]; - if (bytes > cv) - _footprint[hash] = bytes; - } +void MemoryRegistrator::setGraphMemoryFootprintIfGreater(Nd4jLong hash, + Nd4jLong bytes) { + _lock.lock(); - _lock.unlock(); - } + if (_footprint.count(hash) == 0) + _footprint[hash] = bytes; + else { + Nd4jLong cv = _footprint[hash]; + if (bytes > cv) _footprint[hash] = bytes; + } - Nd4jLong MemoryRegistrator::getGraphMemoryFootprint(Nd4jLong hash) { - _lock.lock(); - - Nd4jLong result = 0L; - if (_footprint.count(hash) > 0) - result = _footprint[hash]; - - _lock.unlock(); + _lock.unlock(); +} - return result; - } +Nd4jLong MemoryRegistrator::getGraphMemoryFootprint(Nd4jLong hash) { + _lock.lock(); - MemoryRegistrator* MemoryRegistrator::_INSTANCE = 0; + Nd4jLong result = 0L; + if (_footprint.count(hash) > 0) result = _footprint[hash]; - } -} \ No newline at end of file + _lock.unlock(); + + return result; +} + +MemoryRegistrator* MemoryRegistrator::_INSTANCE = 0; + +} // namespace memory +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/memory/impl/MemoryReport.cpp b/libnd4j/include/memory/impl/MemoryReport.cpp index 0c623b0cea4f..a56c9e1d1362 100644 --- a/libnd4j/include/memory/impl/MemoryReport.cpp +++ b/libnd4j/include/memory/impl/MemoryReport.cpp @@ -20,42 +20,42 @@ #include "memory/MemoryReport.h" -bool sd::memory::MemoryReport::operator<(const sd::memory::MemoryReport &other) const { - return this->_rss < other._rss; +bool sd::memory::MemoryReport::operator<( + const sd::memory::MemoryReport &other) const { + return this->_rss < other._rss; } -bool sd::memory::MemoryReport::operator>(const sd::memory::MemoryReport &other) const { - return this->_rss > other._rss; +bool sd::memory::MemoryReport::operator>( + const sd::memory::MemoryReport &other) const { + return this->_rss > other._rss; } -bool sd::memory::MemoryReport::operator==(const sd::memory::MemoryReport &other) const { - return this->_rss == other._rss; +bool sd::memory::MemoryReport::operator==( + const sd::memory::MemoryReport &other) const { + return this->_rss == other._rss; } -bool sd::memory::MemoryReport::operator!=(const sd::memory::MemoryReport &other) const { - return this->_rss != other._rss; +bool sd::memory::MemoryReport::operator!=( + const sd::memory::MemoryReport &other) const { + return this->_rss != other._rss; } -bool sd::memory::MemoryReport::operator<=(const sd::memory::MemoryReport &other) const { - return this->_rss <= other._rss; +bool sd::memory::MemoryReport::operator<=( + const sd::memory::MemoryReport &other) const { + return this->_rss <= other._rss; } -bool sd::memory::MemoryReport::operator>=(const sd::memory::MemoryReport &other) const { - return this->_rss >= other._rss; +bool sd::memory::MemoryReport::operator>=( + const sd::memory::MemoryReport &other) const { + return this->_rss >= other._rss; } -Nd4jLong sd::memory::MemoryReport::getVM() const { - return _vm; -} +Nd4jLong sd::memory::MemoryReport::getVM() const { return _vm; } -void sd::memory::MemoryReport::setVM(Nd4jLong _vm) { - MemoryReport::_vm = _vm; -} +void sd::memory::MemoryReport::setVM(Nd4jLong _vm) { MemoryReport::_vm = _vm; } -Nd4jLong sd::memory::MemoryReport::getRSS() const { - return _rss; -} +Nd4jLong sd::memory::MemoryReport::getRSS() const { return _rss; } void sd::memory::MemoryReport::setRSS(Nd4jLong _rss) { - MemoryReport::_rss = _rss; + MemoryReport::_rss = _rss; } diff --git a/libnd4j/include/memory/impl/MemoryTracker.cpp b/libnd4j/include/memory/impl/MemoryTracker.cpp index 5ebb4fd16445..da70d61ff462 100644 --- a/libnd4j/include/memory/impl/MemoryTracker.cpp +++ b/libnd4j/include/memory/impl/MemoryTracker.cpp @@ -18,161 +18,165 @@ // Created by raver119 on 07.05.19. // -#include -#include #include - - +#include #include -#if defined(__GNUC__) && !defined(__MINGW64__) && !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && !defined(SD_APPLE_BUILD) +#include + +#if defined(__GNUC__) && !defined(__MINGW64__) && \ + !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && \ + !defined(SD_APPLE_BUILD) -#include -#include #include +#include +#include #endif namespace sd { - namespace memory { - - MemoryTracker::MemoryTracker() { - // - } - - MemoryTracker* MemoryTracker::getInstance() { - if (_INSTANCE == 0) - _INSTANCE = new MemoryTracker(); - - return _INSTANCE; - } - -#if defined(__GNUC__) && !defined(__MINGW64__) && !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && !defined(SD_APPLE_BUILD) - std::string demangle(char *message) { - char *mangled_name = 0, *offset_begin = 0, *offset_end = 0; - - // find parantheses and +address offset surrounding mangled name - for (char *p = message; *p; ++p) - { - if (*p == '(') - { - mangled_name = p; - } - else if (*p == '+') - { - offset_begin = p; - } - else if (*p == ')') - { - offset_end = p; - break; - } - } - - // if the line could be processed, attempt to demangle the symbol - if (mangled_name && offset_begin && offset_end && mangled_name < offset_begin) { - *mangled_name++ = '\0'; - *offset_begin++ = '\0'; - *offset_end++ = '\0'; - - int status; - char * real_name = abi::__cxa_demangle(mangled_name, 0, 0, &status); - - // if demangling is successful, output the demangled function name - if (status == 0) { - std::string result(real_name); - free(real_name); - return result; - } else { - // otherwise, output the mangled function name - std::string result (message); - free(real_name); - return result; - } - } - - // safe return - return std::string(""); - } +namespace memory { -#endif +MemoryTracker::MemoryTracker() { + // +} + +MemoryTracker *MemoryTracker::getInstance() { + if (_INSTANCE == 0) _INSTANCE = new MemoryTracker(); + + return _INSTANCE; +} - void MemoryTracker::countIn(MemoryType type, Nd4jPointer ptr, Nd4jLong numBytes) { -#if defined(__GNUC__) && !defined(__MINGW64__) && !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && !defined(SD_APPLE_BUILD) - if (Environment::getInstance()->isDetectingLeaks()) { - auto lptr = reinterpret_cast(ptr); +#if defined(__GNUC__) && !defined(__MINGW64__) && \ + !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && \ + !defined(SD_APPLE_BUILD) +std::string demangle(char *message) { + char *mangled_name = 0, *offset_begin = 0, *offset_end = 0; + + // find parantheses and +address offset surrounding mangled name + for (char *p = message; *p; ++p) { + if (*p == '(') { + mangled_name = p; + } else if (*p == '+') { + offset_begin = p; + } else if (*p == ')') { + offset_end = p; + break; + } + } + + // if the line could be processed, attempt to demangle the symbol + if (mangled_name && offset_begin && offset_end && + mangled_name < offset_begin) { + *mangled_name++ = '\0'; + *offset_begin++ = '\0'; + *offset_end++ = '\0'; + + int status; + char *real_name = abi::__cxa_demangle(mangled_name, 0, 0, &status); + + // if demangling is successful, output the demangled function name + if (status == 0) { + std::string result(real_name); + free(real_name); + return result; + } else { + // otherwise, output the mangled function name + std::string result(message); + free(real_name); + return result; + } + } - _locker.lock(); + // safe return + return std::string(""); +} - void *array[50]; - size_t size; - char **messages; - size = backtrace(array, 50); +#endif - std::string stack(""); - messages = backtrace_symbols(array, size); - for (int i = 1; i < size && messages != NULL; ++i) { - stack += demangle(messages[i]) + "\n"; - } +void MemoryTracker::countIn(MemoryType type, Nd4jPointer ptr, + Nd4jLong numBytes) { +#if defined(__GNUC__) && !defined(__MINGW64__) && \ + !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && \ + !defined(SD_APPLE_BUILD) + if (Environment::getInstance()->isDetectingLeaks()) { + auto lptr = reinterpret_cast(ptr); + + _locker.lock(); + + void *array[50]; + size_t size; + char **messages; + size = backtrace(array, 50); + + std::string stack(""); + messages = backtrace_symbols(array, size); + for (int i = 1; i < size && messages != NULL; ++i) { + stack += demangle(messages[i]) + "\n"; + } - free(messages); + free(messages); - if (stack.find("ConstantTad") != std::string::npos || - stack.find("ConstantShape") != std::string::npos) { - _locker.unlock(); - return; - } + if (stack.find("ConstantTad") != std::string::npos || + stack.find("ConstantShape") != std::string::npos) { + _locker.unlock(); + return; + } - std::pair pair(lptr, AllocationEntry(type, lptr, numBytes, stack)); - _allocations.insert(pair); + std::pair pair( + lptr, AllocationEntry(type, lptr, numBytes, stack)); + _allocations.insert(pair); - _locker.unlock(); - } + _locker.unlock(); + } #endif - } - - void MemoryTracker::countOut(Nd4jPointer ptr) { -#if defined(__GNUC__) && !defined(__MINGW64__) && !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && !defined(SD_APPLE_BUILD) - if (Environment::getInstance()->isDetectingLeaks()) { - auto lptr = reinterpret_cast(ptr); +} - _locker.lock(); - if (_released.count(lptr) > 0) { - //throw std::runtime_error("Double free!"); - } +void MemoryTracker::countOut(Nd4jPointer ptr) { +#if defined(__GNUC__) && !defined(__MINGW64__) && \ + !defined(SD_ANDROID_BUILD) && !defined(SD_IOS_BUILD) && \ + !defined(SD_APPLE_BUILD) + if (Environment::getInstance()->isDetectingLeaks()) { + auto lptr = reinterpret_cast(ptr); - if (_allocations.count(lptr) > 0) { - //auto entry = _allocations[lptr]; - //std::string stack("new stack"); - //std::pair pair(lptr, entry); - //_released.insert(pair); + _locker.lock(); + if (_released.count(lptr) > 0) { + // throw std::runtime_error("Double free!"); + } + if (_allocations.count(lptr) > 0) { + // auto entry = _allocations[lptr]; + // std::string stack("new stack"); + // std::pair pair(lptr, entry); + //_released.insert(pair); - _allocations.erase(lptr); - } + _allocations.erase(lptr); + } - _locker.unlock(); - } + _locker.unlock(); + } #endif - } - - void MemoryTracker::summarize() { - if (!_allocations.empty()) { - nd4j_printf("\n%i leaked allocations\n", (int) _allocations.size()); +} - for (auto &v: _allocations) { - nd4j_printf("Leak of %i [%s] bytes\n%s\n\n", (int) v.second.numBytes(), v.second.memoryType() == MemoryType::HOST ? "HOST" : "DEVICE", v.second.stackTrace().c_str()); - } +void MemoryTracker::summarize() { + if (!_allocations.empty()) { + nd4j_printf("\n%i leaked allocations\n", (int)_allocations.size()); - throw std::runtime_error("Non-released allocations found"); - } - } + for (auto &v : _allocations) { + nd4j_printf("Leak of %i [%s] bytes\n%s\n\n", (int)v.second.numBytes(), + v.second.memoryType() == MemoryType::HOST ? "HOST" : "DEVICE", + v.second.stackTrace().c_str()); + } - void MemoryTracker::reset() { - _allocations.clear(); - _released.clear(); - } + throw std::runtime_error("Non-released allocations found"); + } +} - MemoryTracker* MemoryTracker::_INSTANCE = 0; - } +void MemoryTracker::reset() { + _allocations.clear(); + _released.clear(); } + +MemoryTracker *MemoryTracker::_INSTANCE = 0; +} // namespace memory +} // namespace sd diff --git a/libnd4j/include/memory/impl/MemoryUtils.cpp b/libnd4j/include/memory/impl/MemoryUtils.cpp index 8500a044e341..0c576ba81057 100644 --- a/libnd4j/include/memory/impl/MemoryUtils.cpp +++ b/libnd4j/include/memory/impl/MemoryUtils.cpp @@ -19,78 +19,81 @@ // #include "memory/MemoryUtils.h" + #include #if defined(__APPLE__) -#include +#include #include #elif defined(_WIN32) || defined(__WIN32__) || defined(WIN32) #else // linux -#include #include +#include #include + #include #endif - -bool sd::memory::MemoryUtils::retrieveMemoryStatistics(sd::memory::MemoryReport &report) { +bool sd::memory::MemoryUtils::retrieveMemoryStatistics( + sd::memory::MemoryReport &report) { #if defined(__APPLE__) - nd4j_debug("APPLE route\n", ""); -/* - struct task_basic_info t_info; - mach_msg_type_number_t t_info_count = TASK_BASIC_INFO_COUNT; + nd4j_debug("APPLE route\n", ""); + /* + struct task_basic_info t_info; + mach_msg_type_number_t t_info_count = TASK_BASIC_INFO_COUNT; - if (KERN_SUCCESS != task_info(mach_task_self(), TASK_BASIC_INFO, (task_info_t)&t_info, &t_info_count)) - return false; + if (KERN_SUCCESS != task_info(mach_task_self(), TASK_BASIC_INFO, + (task_info_t)&t_info, &t_info_count)) return false; - report.setVM(t_info.resident_size); - report.setRSS(t_info.resident_size); + report.setVM(t_info.resident_size); + report.setRSS(t_info.resident_size); - nd4j_debug("RSS: %lld; VM: %lld;\n", report.getRSS(), report.getVM()); -*/ - struct rusage _usage; + nd4j_debug("RSS: %lld; VM: %lld;\n", report.getRSS(), report.getVM()); + */ + struct rusage _usage; - auto res = getrusage(RUSAGE_SELF, &_usage); + auto res = getrusage(RUSAGE_SELF, &_usage); - report.setRSS(_usage.ru_maxrss); + report.setRSS(_usage.ru_maxrss); - nd4j_debug("Usage: %lld; %lld; %lld; %lld;\n", _usage.ru_ixrss, _usage.ru_idrss, _usage.ru_isrss, _usage.ru_maxrss); + nd4j_debug("Usage: %lld; %lld; %lld; %lld;\n", _usage.ru_ixrss, + _usage.ru_idrss, _usage.ru_isrss, _usage.ru_maxrss); - return true; + return true; #elif defined(_WIN32) || defined(__WIN32__) || defined(WIN32) - nd4j_debug("WIN32 route\n", ""); - + nd4j_debug("WIN32 route\n", ""); #else - nd4j_debug("LINUX route\n", ""); - int fd = open("/proc/self/statm", O_RDONLY, 0); - if (fd >= 0) { - char line[256]; - char* s; - int n; - lseek(fd, 0, SEEK_SET); - if ((n = read(fd, line, sizeof(line))) > 0 && (s = (char*)memchr(line, ' ', n)) != NULL) { - report.setRSS((Nd4jLong)(atoll(s + 1) * getpagesize())); - } - close(fd); + nd4j_debug("LINUX route\n", ""); + int fd = open("/proc/self/statm", O_RDONLY, 0); + if (fd >= 0) { + char line[256]; + char* s; + int n; + lseek(fd, 0, SEEK_SET); + if ((n = read(fd, line, sizeof(line))) > 0 && + (s = (char*)memchr(line, ' ', n)) != NULL) { + report.setRSS((Nd4jLong)(atoll(s + 1) * getpagesize())); } + close(fd); + } - /* - struct rusage _usage; - - auto res = getrusage(RUSAGE_SELF, &_usage); + /* + struct rusage _usage; - report.setRSS(_usage.ru_maxrss); + auto res = getrusage(RUSAGE_SELF, &_usage); - //nd4j_printf("Usage: %lld; %lld; %lld; %lld;\n", _usage.ru_ixrss, _usage.ru_idrss, _usage.ru_isrss, _usage.ru_maxrss); - */ + report.setRSS(_usage.ru_maxrss); + //nd4j_printf("Usage: %lld; %lld; %lld; %lld;\n", _usage.ru_ixrss, + _usage.ru_idrss, _usage.ru_isrss, _usage.ru_maxrss); + */ - return true; + return true; #endif - return false; + return false; } diff --git a/libnd4j/include/ops/BroadcastBoolOpsTuple.h b/libnd4j/include/ops/BroadcastBoolOpsTuple.h index d663a6c6cd4f..2a10326a84c2 100644 --- a/libnd4j/include/ops/BroadcastBoolOpsTuple.h +++ b/libnd4j/include/ops/BroadcastBoolOpsTuple.h @@ -21,30 +21,32 @@ #ifndef SD_BROADCASTBOOLOPSTUPLE_H #define SD_BROADCASTBOOLOPSTUPLE_H -#include #include +#include namespace sd { - class SD_EXPORT BroadcastBoolOpsTuple { - private: - - public: - sd::scalar::BoolOps s; - sd::pairwise::BoolOps p; - sd::broadcast::BoolOps b; - - BroadcastBoolOpsTuple() = default; - ~BroadcastBoolOpsTuple() = default; - - BroadcastBoolOpsTuple(sd::scalar::BoolOps scalar, sd::pairwise::BoolOps pairwise, sd::broadcast::BoolOps broadcast) { - s = scalar; - p = pairwise; - b = broadcast; - } - - static BroadcastBoolOpsTuple custom(sd::scalar::BoolOps scalar, sd::pairwise::BoolOps pairwise, sd::broadcast::BoolOps broadcast); - }; -} - - -#endif //SD_BROADCASTOPSTUPLE_H +class SD_EXPORT BroadcastBoolOpsTuple { + private: + public: + sd::scalar::BoolOps s; + sd::pairwise::BoolOps p; + sd::broadcast::BoolOps b; + + BroadcastBoolOpsTuple() = default; + ~BroadcastBoolOpsTuple() = default; + + BroadcastBoolOpsTuple(sd::scalar::BoolOps scalar, + sd::pairwise::BoolOps pairwise, + sd::broadcast::BoolOps broadcast) { + s = scalar; + p = pairwise; + b = broadcast; + } + + static BroadcastBoolOpsTuple custom(sd::scalar::BoolOps scalar, + sd::pairwise::BoolOps pairwise, + sd::broadcast::BoolOps broadcast); +}; +} // namespace sd + +#endif // SD_BROADCASTOPSTUPLE_H diff --git a/libnd4j/include/ops/BroadcastIntOpsTuple.h b/libnd4j/include/ops/BroadcastIntOpsTuple.h index 571e870adb9b..9d0fbdf81cce 100644 --- a/libnd4j/include/ops/BroadcastIntOpsTuple.h +++ b/libnd4j/include/ops/BroadcastIntOpsTuple.h @@ -21,30 +21,31 @@ #ifndef SD_BROADCASTINTOPSTUPLE_H #define SD_BROADCASTINTOPSTUPLE_H -#include #include +#include namespace sd { - class SD_EXPORT BroadcastIntOpsTuple { - private: - - public: - sd::scalar::IntOps s; - sd::pairwise::IntOps p; - sd::broadcast::IntOps b; - - BroadcastIntOpsTuple() = default; - ~BroadcastIntOpsTuple() = default; - - BroadcastIntOpsTuple(sd::scalar::IntOps scalar, sd::pairwise::IntOps pairwise, sd::broadcast::IntOps broadcast) { - s = scalar; - p = pairwise; - b = broadcast; - } - - static BroadcastIntOpsTuple custom(sd::scalar::IntOps scalar, sd::pairwise::IntOps pairwise, sd::broadcast::IntOps broadcast); - }; -} - - -#endif //SD_BROADCASTOPSTUPLE_H +class SD_EXPORT BroadcastIntOpsTuple { + private: + public: + sd::scalar::IntOps s; + sd::pairwise::IntOps p; + sd::broadcast::IntOps b; + + BroadcastIntOpsTuple() = default; + ~BroadcastIntOpsTuple() = default; + + BroadcastIntOpsTuple(sd::scalar::IntOps scalar, sd::pairwise::IntOps pairwise, + sd::broadcast::IntOps broadcast) { + s = scalar; + p = pairwise; + b = broadcast; + } + + static BroadcastIntOpsTuple custom(sd::scalar::IntOps scalar, + sd::pairwise::IntOps pairwise, + sd::broadcast::IntOps broadcast); +}; +} // namespace sd + +#endif // SD_BROADCASTOPSTUPLE_H diff --git a/libnd4j/include/ops/BroadcastOpsTuple.h b/libnd4j/include/ops/BroadcastOpsTuple.h index 2b55535198f3..81b470076c28 100644 --- a/libnd4j/include/ops/BroadcastOpsTuple.h +++ b/libnd4j/include/ops/BroadcastOpsTuple.h @@ -21,42 +21,43 @@ #ifndef SD_BROADCASTOPSTUPLE_H #define SD_BROADCASTOPSTUPLE_H -#include #include +#include namespace sd { - class SD_EXPORT BroadcastOpsTuple { - private: - - public: - sd::scalar::Ops s; - sd::pairwise::Ops p; - sd::broadcast::Ops b; - - BroadcastOpsTuple() = default; - ~BroadcastOpsTuple() = default; - - BroadcastOpsTuple(sd::scalar::Ops scalar, sd::pairwise::Ops pairwise, sd::broadcast::Ops broadcast) { - s = scalar; - p = pairwise; - b = broadcast; - } - - static BroadcastOpsTuple custom(sd::scalar::Ops scalar, sd::pairwise::Ops pairwise, sd::broadcast::Ops broadcast); - - static BroadcastOpsTuple Add(); - static BroadcastOpsTuple Assign(); - static BroadcastOpsTuple Divide(); - static BroadcastOpsTuple DivideNoNan(); - static BroadcastOpsTuple Multiply(); - static BroadcastOpsTuple Subtract(); - static BroadcastOpsTuple IGamma(); - static BroadcastOpsTuple IGammac(); - - static BroadcastOpsTuple Pow(); - static BroadcastOpsTuple PowDerivative(); - }; -} - - -#endif //SD_BROADCASTOPSTUPLE_H +class SD_EXPORT BroadcastOpsTuple { + private: + public: + sd::scalar::Ops s; + sd::pairwise::Ops p; + sd::broadcast::Ops b; + + BroadcastOpsTuple() = default; + ~BroadcastOpsTuple() = default; + + BroadcastOpsTuple(sd::scalar::Ops scalar, sd::pairwise::Ops pairwise, + sd::broadcast::Ops broadcast) { + s = scalar; + p = pairwise; + b = broadcast; + } + + static BroadcastOpsTuple custom(sd::scalar::Ops scalar, + sd::pairwise::Ops pairwise, + sd::broadcast::Ops broadcast); + + static BroadcastOpsTuple Add(); + static BroadcastOpsTuple Assign(); + static BroadcastOpsTuple Divide(); + static BroadcastOpsTuple DivideNoNan(); + static BroadcastOpsTuple Multiply(); + static BroadcastOpsTuple Subtract(); + static BroadcastOpsTuple IGamma(); + static BroadcastOpsTuple IGammac(); + + static BroadcastOpsTuple Pow(); + static BroadcastOpsTuple PowDerivative(); +}; +} // namespace sd + +#endif // SD_BROADCASTOPSTUPLE_H diff --git a/libnd4j/include/ops/InputType.h b/libnd4j/include/ops/InputType.h index 4deff4900807..94e62fb432de 100644 --- a/libnd4j/include/ops/InputType.h +++ b/libnd4j/include/ops/InputType.h @@ -22,15 +22,15 @@ #define ND4J_INPUTTYPE_H namespace sd { - namespace ops { - enum InputType { - InputType_BOOLEAN = 0, - InputType_NUMERIC = 1, - InputType_STRINGULAR = 2, - InputType_NUMERIC_SET = 3, - InputType_STRINGULAR_SET = 4, - }; - } +namespace ops { +enum InputType { + InputType_BOOLEAN = 0, + InputType_NUMERIC = 1, + InputType_STRINGULAR = 2, + InputType_NUMERIC_SET = 3, + InputType_STRINGULAR_SET = 4, +}; } +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/BooleanOp.h b/libnd4j/include/ops/declarable/BooleanOp.h index 4e22ce2d4403..f0a9f11ab9fe 100644 --- a/libnd4j/include/ops/declarable/BooleanOp.h +++ b/libnd4j/include/ops/declarable/BooleanOp.h @@ -22,30 +22,31 @@ #define LIBND4J_BOOLEANOP_H #include -#include "OpDescriptor.h" + #include "DeclarableOp.h" +#include "OpDescriptor.h" namespace sd { - namespace ops { - class SD_EXPORT BooleanOp : public DeclarableOp { - protected: - OpDescriptor * _descriptor; - - bool prepareOutputs(Context& block); - Nd4jStatus validateAndExecute(Context& block) override = 0; - public: - BooleanOp(const char *name, int numInputs, bool scalar); +namespace ops { +class SD_EXPORT BooleanOp : public DeclarableOp { + protected: + OpDescriptor* _descriptor; - bool verify(const std::vector& args); - bool verify(sd::graph::Context& block); + bool prepareOutputs(Context& block); + Nd4jStatus validateAndExecute(Context& block) override = 0; - Nd4jStatus execute(Context* block) override; + public: + BooleanOp(const char* name, int numInputs, bool scalar); - ShapeList *calculateOutputShape(ShapeList *inputShape, sd::graph::Context& block) override; - }; - } -} + bool verify(const std::vector& args); + bool verify(sd::graph::Context& block); + Nd4jStatus execute(Context* block) override; + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_BOOLEANOP_H \ No newline at end of file +#endif // LIBND4J_BOOLEANOP_H \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/BroadcastableBoolOp.h b/libnd4j/include/ops/declarable/BroadcastableBoolOp.h index 64e61069f741..2455667f7414 100644 --- a/libnd4j/include/ops/declarable/BroadcastableBoolOp.h +++ b/libnd4j/include/ops/declarable/BroadcastableBoolOp.h @@ -22,22 +22,24 @@ #define SD_BROADCASTABLEBOOLOP_H #include -#include "OpDescriptor.h" -#include "DeclarableOp.h" + #include "DeclarableCustomOp.h" +#include "DeclarableOp.h" +#include "OpDescriptor.h" namespace sd { - namespace ops { - class SD_EXPORT BroadcastableBoolOp : public DeclarableCustomOp{ - protected: - Nd4jStatus validateAndExecute(Context& block) override = 0; - public: - BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs); +namespace ops { +class SD_EXPORT BroadcastableBoolOp : public DeclarableCustomOp { + protected: + Nd4jStatus validateAndExecute(Context &block) override = 0; - ShapeList *calculateOutputShape(ShapeList *inputShape, sd::graph::Context& block) override; - }; - } -} + public: + BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs); + ShapeList *calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) override; +}; +} // namespace ops +} // namespace sd -#endif //SD_BROADCASTABLEBOOLOP_H +#endif // SD_BROADCASTABLEBOOLOP_H diff --git a/libnd4j/include/ops/declarable/BroadcastableOp.h b/libnd4j/include/ops/declarable/BroadcastableOp.h index 3bb1d55f4aa1..9b517e160449 100644 --- a/libnd4j/include/ops/declarable/BroadcastableOp.h +++ b/libnd4j/include/ops/declarable/BroadcastableOp.h @@ -22,22 +22,24 @@ #define LIBND4J_BROADCASTABLEOP_H #include -#include "OpDescriptor.h" -#include "DeclarableOp.h" + #include "DeclarableCustomOp.h" +#include "DeclarableOp.h" +#include "OpDescriptor.h" namespace sd { - namespace ops { - class SD_EXPORT BroadcastableOp : public DeclarableCustomOp{ - protected: - Nd4jStatus validateAndExecute(Context& block) override = 0; - public: - BroadcastableOp(const char *name, int numTArgs, int numIArgs); +namespace ops { +class SD_EXPORT BroadcastableOp : public DeclarableCustomOp { + protected: + Nd4jStatus validateAndExecute(Context &block) override = 0; - ShapeList *calculateOutputShape(ShapeList *inputShape, sd::graph::Context& block) override; - }; - } -} + public: + BroadcastableOp(const char *name, int numTArgs, int numIArgs); + ShapeList *calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_BROADCASTABLEOP_H +#endif // LIBND4J_BROADCASTABLEOP_H diff --git a/libnd4j/include/ops/declarable/CustomOperations.h b/libnd4j/include/ops/declarable/CustomOperations.h index 1b4d1f51f48c..dccb3e96d02c 100644 --- a/libnd4j/include/ops/declarable/CustomOperations.h +++ b/libnd4j/include/ops/declarable/CustomOperations.h @@ -21,67 +21,66 @@ #ifndef LIBND4J_CUSTOMOPERATIONS_H #define LIBND4J_CUSTOMOPERATIONS_H +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include +#include +#include #include #include -#include +#include #include +#include +#include +#include +#include #include -#include -#include -#include -#include -#include +#include #include #include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include #include -#include -#include -#include -#include +#include +#include +#include #include +#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include - namespace sd { - struct SD_EXPORT _loader { - _loader(); - }; - - namespace ops { +struct SD_EXPORT _loader { + _loader(); +}; - // logic ops - DECLARE_DIVERGENT_OP(Switch, 2, 2, true); - DECLARE_LOGIC_OP(While); - DECLARE_LOGIC_OP(Scope); - DECLARE_LOGIC_OP(Conditional); - DECLARE_LOGIC_OP(Return); +namespace ops { +// logic ops +DECLARE_DIVERGENT_OP(Switch, 2, 2, true); +DECLARE_LOGIC_OP(While); +DECLARE_LOGIC_OP(Scope); +DECLARE_LOGIC_OP(Conditional); +DECLARE_LOGIC_OP(Return); - /** - * This operations exposes given arguments as it's own outputs, but does it only once. - * Subsequent calls will be served directly by this op. - * - * PLEASE NOTE: This operation is internal graph operation, and shouldn't be used directly usually. - */ - DECLARE_CUSTOM_OP(expose, -1, -1, true, 0, 0); - } -} +/** + * This operations exposes given arguments as it's own outputs, but does it only + * once. Subsequent calls will be served directly by this op. + * + * PLEASE NOTE: This operation is internal graph operation, and shouldn't be + * used directly usually. + */ +DECLARE_CUSTOM_OP(expose, -1, -1, true, 0, 0); +} // namespace ops +} // namespace sd -#endif //LIBND4J_CUSTOMOPERATIONS_H +#endif // LIBND4J_CUSTOMOPERATIONS_H diff --git a/libnd4j/include/ops/declarable/DeclarableCustomOp.h b/libnd4j/include/ops/declarable/DeclarableCustomOp.h index 080ed6648637..163c519bad8c 100644 --- a/libnd4j/include/ops/declarable/DeclarableCustomOp.h +++ b/libnd4j/include/ops/declarable/DeclarableCustomOp.h @@ -24,19 +24,22 @@ #include namespace sd { - namespace ops { - class SD_EXPORT DeclarableCustomOp : public sd::ops::DeclarableOp { - protected: - /** - * This method executes this Op - */ - Nd4jStatus validateAndExecute(Context& block) override = 0; - public: - DeclarableCustomOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs); +namespace ops { +class SD_EXPORT DeclarableCustomOp : public sd::ops::DeclarableOp { + protected: + /** + * This method executes this Op + */ + Nd4jStatus validateAndExecute(Context& block) override = 0; - ShapeList* calculateOutputShape(ShapeList* inputShapes, sd::graph::Context& block) override = 0; - }; - } -} + public: + DeclarableCustomOp(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, int tArgs, int iArgs); -#endif //LIBND4J_DECLARABLECUSTOMOP_H + ShapeList* calculateOutputShape(ShapeList* inputShapes, + sd::graph::Context& block) override = 0; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_DECLARABLECUSTOMOP_H diff --git a/libnd4j/include/ops/declarable/DeclarableListOp.h b/libnd4j/include/ops/declarable/DeclarableListOp.h index 89819864f17d..aab3613e381d 100644 --- a/libnd4j/include/ops/declarable/DeclarableListOp.h +++ b/libnd4j/include/ops/declarable/DeclarableListOp.h @@ -23,31 +23,36 @@ #include #include -#include #include +#include using namespace sd::graph; namespace sd { - namespace ops { - class SD_EXPORT DeclarableListOp : public sd::ops::DeclarableOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override = 0; - - sd::NDArray* getZ(Context& block, int inputId) ; - void setupResult(const NDArray &array, Context& block); - void setupResultList(const NDArrayList &arrayList, Context& block); - - public: - DeclarableListOp(int numInputs, int numOutputs, const char* opName, int tArgs, int iArgs); - - Nd4jStatus execute(Context* block) override; - - ResultSet execute(const NDArrayList &list, const std::vector& inputs, const std::vector& tArgs = {}, const std::vector& iArgs = {}); - - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - }; - } -} +namespace ops { +class SD_EXPORT DeclarableListOp : public sd::ops::DeclarableOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override = 0; + + sd::NDArray* getZ(Context& block, int inputId); + void setupResult(const NDArray& array, Context& block); + void setupResultList(const NDArrayList& arrayList, Context& block); + + public: + DeclarableListOp(int numInputs, int numOutputs, const char* opName, int tArgs, + int iArgs); + + Nd4jStatus execute(Context* block) override; + + ResultSet execute(const NDArrayList& list, + const std::vector& inputs, + const std::vector& tArgs = {}, + const std::vector& iArgs = {}); + + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; +}; +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/DeclarableOp.h b/libnd4j/include/ops/declarable/DeclarableOp.h index 06c2e36ecc6a..e78b211fdcc8 100644 --- a/libnd4j/include/ops/declarable/DeclarableOp.h +++ b/libnd4j/include/ops/declarable/DeclarableOp.h @@ -21,18 +21,20 @@ #ifndef LIBND4J_DECLARABLE_OPS_H #define LIBND4J_DECLARABLE_OPS_H -#include -#include -#include #include -#include -#include "OpDescriptor.h" -#include -#include #include +#include +#include #include -#include +#include #include +#include +#include +#include + +#include + +#include "OpDescriptor.h" //#include #include @@ -42,181 +44,215 @@ using namespace sd::graph; namespace sd { - namespace ops { - - Nd4jStatus SD_EXPORT conditionHelper(const char *file, int line, int condition, int argNumber, const char *format, ...); - - - template - Nd4jStatus resultHelper(T status, const char *func, const char *file, int line) { - if (status) { - // TODO: fill out error codes here - fprintf(stderr, "Validation error at %s:%d code=%d(%s) \"%s\" \n", file, line, - static_cast(status), "", func); - - return ND4J_STATUS_BAD_INPUT; - } - - return ND4J_STATUS_OK; - } - - /** - * This class is the basic building block of Graph Operations. Any CustomOp out there is built on top of this "abstract" class. - * - */ - class SD_EXPORT DeclarableOp { - private: - std::mutex _registrator; - bool _registered = false; - std::string _name; - protected: - OpDescriptor *_descriptor; - NDArray *_scalar = nullptr; - - virtual void registerTypes(); - - /** - * This method executes this Op, and defined for most of individual ops separately - */ - virtual Nd4jStatus validateAndExecute(Context& block) = 0; - - /** - * This method ensures that target variable has enough space for op execution - * - * TODO: we want workspaces support right here - */ - bool allocateResult(Context& block, std::initializer_list& shape, char order = 'c'); - bool allocateResult(Context& block, Nd4jLong* shape); - - /** - * This method overwrites existen NDArray or NDArrayList in VariableSpace - * - * PLEASE NOTE: This method is dangerous. - * - * @param block - * @param numOutput - * @param array - */ - void overwriteResult(Context& block, int outputIdx, NDArray* array); - void overwriteResult(Context& block, int outputIdx, NDArrayList* list); - - /* - * This method attaches array to specific Variable, identified by node ID and outputNumber (which is output index for multi-output operations) - */ - void storeResult(Context &block, int outputNumber, NDArray& array); - void storeResult(Context &block, int outputNumber, NDArray* array); - sd::NDArray* getZ(Context& block, int inputId = 0); - sd::NDArray* getNullifiedZ(Context& block, int inputId = 0); - - virtual samediff::EmptyHandling emptyHandling(); - public: - // for special cases, like BooleanOps - DeclarableOp(); - DeclarableOp(const char *name, int numInputs, bool scalar); - - // regular constructors - DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace); - DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent); - DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs); - - // for LogicalOps - DeclarableOp(const char *name, bool isLogical); - - // default testructor - virtual ~DeclarableOp(); - - // this method returns OpDescriptor, describing this Op instance - OpDescriptor *getOpDescriptor(); - - virtual Nd4jStatus validateDataTypes(Context& block); - - /** - * This method should be available in each implemented Op, and should return Op output shape(s), for a given input shape(s) - */ - virtual ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) = 0; - - /** - * Returns opName - * - * @return - */ - const std::string& getOpName() const; - - /** - * Returns opHash - */ - Nd4jLong getOpHash() const; - - /** - * This method sets arguments for op - */ -// void setArguments(); - - /** - * This method returns pointer to results - */ -// void getResults(); - - /** - * This method executes given Op - * - * @param block - * @return 0 if OK, error code otherwise - */ - virtual Nd4jStatus execute(Context* block); - - Nd4jStatus execute(const std::vector &inputs, const std::vector &outputs); - - template ::value>> - Nd4jStatus execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list tArgs); - - Nd4jStatus execute(const std::vector &inputs, const std::vector &outputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs = std::vector(), const std::vector &dArgs = std::vector(), bool isInplace = false); - - sd::ResultSet evaluate(const std::vector &inputs); - - template ::value>> - sd::ResultSet evaluate(const std::vector &inputs, std::initializer_list args); - - sd::ResultSet evaluate(const std::vector &inputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs = std::vector(), const std::vector &dArgs = std::vector(), bool isInplace = false); - - Nd4jStatus execute(sd::graph::RandomGenerator& rng, const std::vector& inputs, const std::vector& outputs, const std::vector& tArgs, const std::vector& iArgs, const std::vector& bArgs, const std::vector &dArgs = std::vector(), bool isInplace = false, sd::DataType type = sd::DataType::FLOAT32); - - sd::ResultSet execute(const sd::OpArgsHolder& holder, bool isInplace = false); - +namespace ops { - // There methods provide various validation options - Nd4jStatus validateNonEmptyInput(Context& block); +Nd4jStatus SD_EXPORT conditionHelper(const char* file, int line, int condition, + int argNumber, const char* format, ...); - // this method checks if all input arrays have equal lengths - Nd4jStatus validateInputLengthMatch(Context& block); +template +Nd4jStatus resultHelper(T status, const char* func, const char* file, + int line) { + if (status) { + // TODO: fill out error codes here + fprintf(stderr, "Validation error at %s:%d code=%d(%s) \"%s\" \n", file, + line, static_cast(status), "", func); - // this method checks if all input arrays have the same shapes (orders/strides are NOT checked) - Nd4jStatus validateInputDimensionsMatch(Context& block); + return ND4J_STATUS_BAD_INPUT; + } - // this method check if all input arrays have the same orders - Nd4jStatus validateOrdersMatch(Context& block); - - // this method checks if all input arrays are 2D - Nd4jStatus validateInput2D(Context& block); - - // this method checks if all input arrays are 3D - Nd4jStatus validateInput3D(Context& block); - - // this method checks if all input arrays are 4D - Nd4jStatus validateInput4D(Context& block); - - // this method checks if all input arrays are ND - Nd4jStatus validateInputDimensions(Context& block, int rank); - - // this method checks if number of available arguments matches op expectations - Nd4jStatus validateArguments(Context& block); - - /** - * This method pre-allocates NDArrays for Op output, in case they are not available at op execution time - */ - int prepareOutputs(Context& block); - }; - } + return ND4J_STATUS_OK; } -#endif //LIBND4J_DECLARABLE_OPS_H +/** + * This class is the basic building block of Graph Operations. Any CustomOp out + * there is built on top of this "abstract" class. + * + */ +class SD_EXPORT DeclarableOp { + private: + std::mutex _registrator; + bool _registered = false; + std::string _name; + + protected: + OpDescriptor* _descriptor; + NDArray* _scalar = nullptr; + + virtual void registerTypes(); + + /** + * This method executes this Op, and defined for most of individual ops + * separately + */ + virtual Nd4jStatus validateAndExecute(Context& block) = 0; + + /** + * This method ensures that target variable has enough space for op execution + * + * TODO: we want workspaces support right here + */ + bool allocateResult(Context& block, std::initializer_list& shape, + char order = 'c'); + bool allocateResult(Context& block, Nd4jLong* shape); + + /** + * This method overwrites existen NDArray or NDArrayList in VariableSpace + * + * PLEASE NOTE: This method is dangerous. + * + * @param block + * @param numOutput + * @param array + */ + void overwriteResult(Context& block, int outputIdx, NDArray* array); + void overwriteResult(Context& block, int outputIdx, NDArrayList* list); + + /* + * This method attaches array to specific Variable, identified by node ID and + * outputNumber (which is output index for multi-output operations) + */ + void storeResult(Context& block, int outputNumber, NDArray& array); + void storeResult(Context& block, int outputNumber, NDArray* array); + sd::NDArray* getZ(Context& block, int inputId = 0); + sd::NDArray* getNullifiedZ(Context& block, int inputId = 0); + + virtual samediff::EmptyHandling emptyHandling(); + + public: + // for special cases, like BooleanOps + DeclarableOp(); + DeclarableOp(const char* name, int numInputs, bool scalar); + + // regular constructors + DeclarableOp(int numInputs, int numOutputs, const char* opName, + bool allowsInplace); + DeclarableOp(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, bool divergent); + DeclarableOp(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, int tArgs, int iArgs); + + // for LogicalOps + DeclarableOp(const char* name, bool isLogical); + + // default testructor + virtual ~DeclarableOp(); + + // this method returns OpDescriptor, describing this Op instance + OpDescriptor* getOpDescriptor(); + + virtual Nd4jStatus validateDataTypes(Context& block); + + /** + * This method should be available in each implemented Op, and should return + * Op output shape(s), for a given input shape(s) + */ + virtual ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) = 0; + + /** + * Returns opName + * + * @return + */ + const std::string& getOpName() const; + + /** + * Returns opHash + */ + Nd4jLong getOpHash() const; + + /** + * This method sets arguments for op + */ + // void setArguments(); + + /** + * This method returns pointer to results + */ + // void getResults(); + + /** + * This method executes given Op + * + * @param block + * @return 0 if OK, error code otherwise + */ + virtual Nd4jStatus execute(Context* block); + + Nd4jStatus execute(const std::vector& inputs, + const std::vector& outputs); + + template ::value>> + Nd4jStatus execute(const std::vector& inputs, + const std::vector& outputs, + std::initializer_list tArgs); + + Nd4jStatus execute( + const std::vector& inputs, const std::vector& outputs, + const std::vector& tArgs, const std::vector& iArgs, + const std::vector& bArgs = std::vector(), + const std::vector& dArgs = std::vector(), + bool isInplace = false); + + sd::ResultSet evaluate(const std::vector& inputs); + + template ::value>> + sd::ResultSet evaluate(const std::vector& inputs, + std::initializer_list args); + + sd::ResultSet evaluate( + const std::vector& inputs, const std::vector& tArgs, + const std::vector& iArgs, + const std::vector& bArgs = std::vector(), + const std::vector& dArgs = std::vector(), + bool isInplace = false); + + Nd4jStatus execute( + sd::graph::RandomGenerator& rng, const std::vector& inputs, + const std::vector& outputs, const std::vector& tArgs, + const std::vector& iArgs, const std::vector& bArgs, + const std::vector& dArgs = std::vector(), + bool isInplace = false, sd::DataType type = sd::DataType::FLOAT32); + + sd::ResultSet execute(const sd::OpArgsHolder& holder, bool isInplace = false); + + // There methods provide various validation options + Nd4jStatus validateNonEmptyInput(Context& block); + + // this method checks if all input arrays have equal lengths + Nd4jStatus validateInputLengthMatch(Context& block); + + // this method checks if all input arrays have the same shapes (orders/strides + // are NOT checked) + Nd4jStatus validateInputDimensionsMatch(Context& block); + + // this method check if all input arrays have the same orders + Nd4jStatus validateOrdersMatch(Context& block); + + // this method checks if all input arrays are 2D + Nd4jStatus validateInput2D(Context& block); + + // this method checks if all input arrays are 3D + Nd4jStatus validateInput3D(Context& block); + + // this method checks if all input arrays are 4D + Nd4jStatus validateInput4D(Context& block); + + // this method checks if all input arrays are ND + Nd4jStatus validateInputDimensions(Context& block, int rank); + + // this method checks if number of available arguments matches op expectations + Nd4jStatus validateArguments(Context& block); + + /** + * This method pre-allocates NDArrays for Op output, in case they are not + * available at op execution time + */ + int prepareOutputs(Context& block); +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_DECLARABLE_OPS_H diff --git a/libnd4j/include/ops/declarable/DeclarableReductionOp.h b/libnd4j/include/ops/declarable/DeclarableReductionOp.h index 8af574a2ef62..bd93d94848c3 100644 --- a/libnd4j/include/ops/declarable/DeclarableReductionOp.h +++ b/libnd4j/include/ops/declarable/DeclarableReductionOp.h @@ -24,19 +24,22 @@ #include namespace sd { - namespace ops { - class SD_EXPORT DeclarableReductionOp : public sd::ops::DeclarableOp { - protected: - /** - * This method executes this Op - */ - Nd4jStatus validateAndExecute(Context& block) override = 0; - public: - DeclarableReductionOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs); +namespace ops { +class SD_EXPORT DeclarableReductionOp : public sd::ops::DeclarableOp { + protected: + /** + * This method executes this Op + */ + Nd4jStatus validateAndExecute(Context& block) override = 0; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - }; - } -} + public: + DeclarableReductionOp(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, int tArgs, int iArgs); -#endif //LIBND4J_DECLARABLE_REDUCTION_OP_H + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_DECLARABLE_REDUCTION_OP_H diff --git a/libnd4j/include/ops/declarable/EmptyHandling.h b/libnd4j/include/ops/declarable/EmptyHandling.h index c25fea498201..810faa279e3c 100644 --- a/libnd4j/include/ops/declarable/EmptyHandling.h +++ b/libnd4j/include/ops/declarable/EmptyHandling.h @@ -22,11 +22,7 @@ #define SAMEDIFF_EMPTYHANDLING_H namespace samediff { - enum EmptyHandling { - EMPTY_SKIP = 1, - EMPTY_EXCEPTION = 2, - EMPTY_EXECUTE = 3 - }; +enum EmptyHandling { EMPTY_SKIP = 1, EMPTY_EXCEPTION = 2, EMPTY_EXECUTE = 3 }; } -#endif //SAMEDIFF_EMPTYHANDLING_H +#endif // SAMEDIFF_EMPTYHANDLING_H diff --git a/libnd4j/include/ops/declarable/LegacyBroadcastBoolOp.h b/libnd4j/include/ops/declarable/LegacyBroadcastBoolOp.h index 3e952bd4b6fc..54c737c1295f 100644 --- a/libnd4j/include/ops/declarable/LegacyBroadcastBoolOp.h +++ b/libnd4j/include/ops/declarable/LegacyBroadcastBoolOp.h @@ -24,22 +24,23 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for broadcast operations. - */ - class SD_EXPORT LegacyBroadcastBoolOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override ; - public: - LegacyBroadcastBoolOp(); - LegacyBroadcastBoolOp(int opNum); +namespace ops { +/** + * This class provides wrapper for broadcast operations. + */ +class SD_EXPORT LegacyBroadcastBoolOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyBroadcastBoolOp(); + LegacyBroadcastBoolOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYBROADCASTOP_H +#endif // LIBND4J_LEGACYBROADCASTOP_H diff --git a/libnd4j/include/ops/declarable/LegacyBroadcastOp.h b/libnd4j/include/ops/declarable/LegacyBroadcastOp.h index 44518798ed30..718ff2dffe8b 100644 --- a/libnd4j/include/ops/declarable/LegacyBroadcastOp.h +++ b/libnd4j/include/ops/declarable/LegacyBroadcastOp.h @@ -24,22 +24,23 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for broadcast operations. - */ - class SD_EXPORT LegacyBroadcastOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyBroadcastOp(); - LegacyBroadcastOp(int opNum); +namespace ops { +/** + * This class provides wrapper for broadcast operations. + */ +class SD_EXPORT LegacyBroadcastOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyBroadcastOp(); + LegacyBroadcastOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYBROADCASTOP_H +#endif // LIBND4J_LEGACYBROADCASTOP_H diff --git a/libnd4j/include/ops/declarable/LegacyIndexReduceOp.h b/libnd4j/include/ops/declarable/LegacyIndexReduceOp.h index bfe33874c431..9076a2acd6f4 100644 --- a/libnd4j/include/ops/declarable/LegacyIndexReduceOp.h +++ b/libnd4j/include/ops/declarable/LegacyIndexReduceOp.h @@ -24,23 +24,26 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for IndexAccumulation operations. i.e. IndexMax or IndexAbsoluteMin etc - * - * TODO: eventually we want this op class to return long long instead of T - */ - class SD_EXPORT LegacyIndexReduceOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyIndexReduceOp(); - LegacyIndexReduceOp(int opNum); +namespace ops { +/** + * This class provides wrapper for IndexAccumulation operations. i.e. IndexMax + * or IndexAbsoluteMin etc + * + * TODO: eventually we want this op class to return long long instead of T + */ +class SD_EXPORT LegacyIndexReduceOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; + + public: + LegacyIndexReduceOp(); + LegacyIndexReduceOp(int opNum); - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYINDEXREDUCEOP_H +#endif // LIBND4J_LEGACYINDEXREDUCEOP_H diff --git a/libnd4j/include/ops/declarable/LegacyOp.h b/libnd4j/include/ops/declarable/LegacyOp.h index f199f362120b..e8cf41e3d847 100644 --- a/libnd4j/include/ops/declarable/LegacyOp.h +++ b/libnd4j/include/ops/declarable/LegacyOp.h @@ -21,48 +21,50 @@ #ifndef LIBND4J_LEGACYOP_H #define LIBND4J_LEGACYOP_H -#include #include +#include namespace sd { - namespace ops { +namespace ops { - /** - * This class is root abstraction for legacy XYZ ops wrappers. - * All wrappers for specific op groups (i.e. LegacyTransformOp for Transform ops) are inheriting this class. - * - * - */ - class SD_EXPORT LegacyOp : public DeclarableOp { - protected: - // this field is mainly for debugging - // it defines, which legacy op should be invoked on a given data - int _opNum = -1; - int _numInputs = 0; +/** + * This class is root abstraction for legacy XYZ ops wrappers. + * All wrappers for specific op groups (i.e. LegacyTransformOp for Transform + * ops) are inheriting this class. + * + * + */ +class SD_EXPORT LegacyOp : public DeclarableOp { + protected: + // this field is mainly for debugging + // it defines, which legacy op should be invoked on a given data + int _opNum = -1; + int _numInputs = 0; - // All Op classes provide own specific implementation for this method - Nd4jStatus validateAndExecute(Context& block) override = 0; - public: - LegacyOp(int numInputs); - LegacyOp(int numInputs, int opNum); - ~LegacyOp() = default; + // All Op classes provide own specific implementation for this method + Nd4jStatus validateAndExecute(Context& block) override = 0; - LegacyOp(const LegacyOp& other) noexcept; + public: + LegacyOp(int numInputs); + LegacyOp(int numInputs, int opNum); + ~LegacyOp() = default; - LegacyOp& operator=(const LegacyOp& other) noexcept; + LegacyOp(const LegacyOp& other) noexcept; - // move constructor - LegacyOp(LegacyOp&& other) noexcept; + LegacyOp& operator=(const LegacyOp& other) noexcept; - // move assignment operator - LegacyOp& operator=(LegacyOp&& other) noexcept; + // move constructor + LegacyOp(LegacyOp&& other) noexcept; - // All Op classes provide own specific implementation for this method - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override = 0; - virtual LegacyOp* clone() = 0; - }; - } -} + // move assignment operator + LegacyOp& operator=(LegacyOp&& other) noexcept; + // All Op classes provide own specific implementation for this method + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override = 0; + virtual LegacyOp* clone() = 0; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYOP_H +#endif // LIBND4J_LEGACYOP_H diff --git a/libnd4j/include/ops/declarable/LegacyPairwiseTransformBoolOp.h b/libnd4j/include/ops/declarable/LegacyPairwiseTransformBoolOp.h index 8e76226df161..74194c82a79b 100644 --- a/libnd4j/include/ops/declarable/LegacyPairwiseTransformBoolOp.h +++ b/libnd4j/include/ops/declarable/LegacyPairwiseTransformBoolOp.h @@ -24,22 +24,23 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Pairwise transform operations - */ - class SD_EXPORT LegacyPairwiseTransformBoolOp: public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyPairwiseTransformBoolOp(); - LegacyPairwiseTransformBoolOp(int opNum); +namespace ops { +/** + * This class provides wrapper for Pairwise transform operations + */ +class SD_EXPORT LegacyPairwiseTransformBoolOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyPairwiseTransformBoolOp(); + LegacyPairwiseTransformBoolOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYPAIRWISETRANSFORMOP_H +#endif // LIBND4J_LEGACYPAIRWISETRANSFORMOP_H diff --git a/libnd4j/include/ops/declarable/LegacyPairwiseTransformOp.h b/libnd4j/include/ops/declarable/LegacyPairwiseTransformOp.h index 0418db506fe3..f82dca1f1454 100644 --- a/libnd4j/include/ops/declarable/LegacyPairwiseTransformOp.h +++ b/libnd4j/include/ops/declarable/LegacyPairwiseTransformOp.h @@ -24,22 +24,23 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Pairwise transform operations - */ - class SD_EXPORT LegacyPairwiseTransformOp: public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyPairwiseTransformOp(); - LegacyPairwiseTransformOp(int opNum); +namespace ops { +/** + * This class provides wrapper for Pairwise transform operations + */ +class SD_EXPORT LegacyPairwiseTransformOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyPairwiseTransformOp(); + LegacyPairwiseTransformOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYPAIRWISETRANSFORMOP_H +#endif // LIBND4J_LEGACYPAIRWISETRANSFORMOP_H diff --git a/libnd4j/include/ops/declarable/LegacyRandomOp.h b/libnd4j/include/ops/declarable/LegacyRandomOp.h index 6dc7d8d72ee1..f7b6d8d51c8b 100644 --- a/libnd4j/include/ops/declarable/LegacyRandomOp.h +++ b/libnd4j/include/ops/declarable/LegacyRandomOp.h @@ -21,36 +21,42 @@ #ifndef LIBND4J_LEGACYRANDOMOP_H #define LIBND4J_LEGACYRANDOMOP_H - #include #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Random operations (i.e. linspace or Uniform) - */ - class SD_EXPORT LegacyRandomOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyRandomOp(); - LegacyRandomOp(int opNum); - ~LegacyRandomOp() = default; - - template - Nd4jStatus validateAndExecute_(Context &block); - - sd::ResultSet execute(sd::graph::RandomGenerator& rng, const std::vector& inputs, const std::vector& tArgs = {}, const std::vector& iArgs = {}, const std::vector& dArgs = {}, bool isInplace = false); - - Nd4jStatus execute(Context* block) override; - - Nd4jStatus validateDataTypes(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} - - -#endif //LIBND4J_LEGACYTRANSFORMOP_H +namespace ops { +/** + * This class provides wrapper for Random operations (i.e. linspace or + * Uniform) + */ +class SD_EXPORT LegacyRandomOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; + + public: + LegacyRandomOp(); + LegacyRandomOp(int opNum); + ~LegacyRandomOp() = default; + + template + Nd4jStatus validateAndExecute_(Context& block); + + sd::ResultSet execute(sd::graph::RandomGenerator& rng, + const std::vector& inputs, + const std::vector& tArgs = {}, + const std::vector& iArgs = {}, + const std::vector& dArgs = {}, + bool isInplace = false); + + Nd4jStatus execute(Context* block) override; + + Nd4jStatus validateDataTypes(Context& block) override; + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_LEGACYTRANSFORMOP_H diff --git a/libnd4j/include/ops/declarable/LegacyReduce3Op.h b/libnd4j/include/ops/declarable/LegacyReduce3Op.h index bfcd666a6a78..565e31f3b7ab 100644 --- a/libnd4j/include/ops/declarable/LegacyReduce3Op.h +++ b/libnd4j/include/ops/declarable/LegacyReduce3Op.h @@ -24,22 +24,24 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Reduce3 operations (i.e. dot, cosineDistance etc) - */ - class SD_EXPORT LegacyReduce3Op : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyReduce3Op(); - LegacyReduce3Op(int opNum); +namespace ops { +/** + * This class provides wrapper for Reduce3 operations (i.e. dot, + * cosineDistance etc) + */ +class SD_EXPORT LegacyReduce3Op : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyReduce3Op(); + LegacyReduce3Op(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYREDUCE3OP_H +#endif // LIBND4J_LEGACYREDUCE3OP_H diff --git a/libnd4j/include/ops/declarable/LegacyReduceBoolOp.h b/libnd4j/include/ops/declarable/LegacyReduceBoolOp.h index 647c88a6a322..64b7bf626559 100644 --- a/libnd4j/include/ops/declarable/LegacyReduceBoolOp.h +++ b/libnd4j/include/ops/declarable/LegacyReduceBoolOp.h @@ -24,19 +24,20 @@ #include namespace sd { - namespace ops { - class SD_EXPORT LegacyReduceBoolOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyReduceBoolOp(); - LegacyReduceBoolOp(int opNum); +namespace ops { +class SD_EXPORT LegacyReduceBoolOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyReduceBoolOp(); + LegacyReduceBoolOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYREDUCEOP_H +#endif // LIBND4J_LEGACYREDUCEOP_H diff --git a/libnd4j/include/ops/declarable/LegacyReduceFloatOp.h b/libnd4j/include/ops/declarable/LegacyReduceFloatOp.h index 59f3ec8d2717..cafc3b685625 100644 --- a/libnd4j/include/ops/declarable/LegacyReduceFloatOp.h +++ b/libnd4j/include/ops/declarable/LegacyReduceFloatOp.h @@ -24,19 +24,20 @@ #include namespace sd { - namespace ops { - class SD_EXPORT LegacyReduceFloatOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyReduceFloatOp(); - LegacyReduceFloatOp(int opNum); +namespace ops { +class SD_EXPORT LegacyReduceFloatOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyReduceFloatOp(); + LegacyReduceFloatOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYREDUCEOP_H +#endif // LIBND4J_LEGACYREDUCEOP_H diff --git a/libnd4j/include/ops/declarable/LegacyReduceLongOp.h b/libnd4j/include/ops/declarable/LegacyReduceLongOp.h index 1f2b8339b848..d91f0ca99114 100644 --- a/libnd4j/include/ops/declarable/LegacyReduceLongOp.h +++ b/libnd4j/include/ops/declarable/LegacyReduceLongOp.h @@ -24,19 +24,20 @@ #include namespace sd { - namespace ops { - class SD_EXPORT LegacyReduceLongOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyReduceLongOp(); - LegacyReduceLongOp(int opNum); +namespace ops { +class SD_EXPORT LegacyReduceLongOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyReduceLongOp(); + LegacyReduceLongOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYREDUCEOP_H +#endif // LIBND4J_LEGACYREDUCEOP_H diff --git a/libnd4j/include/ops/declarable/LegacyReduceOp.h b/libnd4j/include/ops/declarable/LegacyReduceOp.h index 1ce9b5609b35..39d27fa4fcee 100644 --- a/libnd4j/include/ops/declarable/LegacyReduceOp.h +++ b/libnd4j/include/ops/declarable/LegacyReduceOp.h @@ -32,16 +32,15 @@ namespace sd { LegacyReduceOp(); LegacyReduceOp(int opNum); - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block); - virtual LegacyOp* clone(); + ShapeList* calculateOutputShape(ShapeList* inputShape, +sd::graph::Context& block); virtual LegacyOp* clone(); }; } } */ -#include -#include #include +#include #include +#include - -#endif //LIBND4J_LEGACYREDUCEOP_H +#endif // LIBND4J_LEGACYREDUCEOP_H diff --git a/libnd4j/include/ops/declarable/LegacyReduceSameOp.h b/libnd4j/include/ops/declarable/LegacyReduceSameOp.h index 63472ec67e9c..eb09a3a39ab4 100644 --- a/libnd4j/include/ops/declarable/LegacyReduceSameOp.h +++ b/libnd4j/include/ops/declarable/LegacyReduceSameOp.h @@ -24,19 +24,20 @@ #include namespace sd { - namespace ops { - class SD_EXPORT LegacyReduceSameOp: public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - public: - LegacyReduceSameOp(); - LegacyReduceSameOp(int opNum); +namespace ops { +class SD_EXPORT LegacyReduceSameOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyReduceSameOp(); + LegacyReduceSameOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYREDUCEOP_H +#endif // LIBND4J_LEGACYREDUCEOP_H diff --git a/libnd4j/include/ops/declarable/LegacyScalarBoolOp.h b/libnd4j/include/ops/declarable/LegacyScalarBoolOp.h index 5da57ad18810..8fe33aeac7a9 100644 --- a/libnd4j/include/ops/declarable/LegacyScalarBoolOp.h +++ b/libnd4j/include/ops/declarable/LegacyScalarBoolOp.h @@ -24,24 +24,25 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for scalar transform operations, i.e. a + b = c, where either a or b is scalar primitive and other operand is NDArray - */ - class SD_EXPORT LegacyScalarBoolOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - - public: - LegacyScalarBoolOp(); - LegacyScalarBoolOp(int opNum); - LegacyScalarBoolOp(int opNum, NDArray &scalar); - - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} - - -#endif //LIBND4J_LEGACYSCALAROP_H +namespace ops { +/** + * This class provides wrapper for scalar transform operations, i.e. a + b = + * c, where either a or b is scalar primitive and other operand is NDArray + */ +class SD_EXPORT LegacyScalarBoolOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; + + public: + LegacyScalarBoolOp(); + LegacyScalarBoolOp(int opNum); + LegacyScalarBoolOp(int opNum, NDArray& scalar); + + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_LEGACYSCALAROP_H diff --git a/libnd4j/include/ops/declarable/LegacyScalarOp.h b/libnd4j/include/ops/declarable/LegacyScalarOp.h index 4fbfa5cc28cf..c1b5d2d01f5a 100644 --- a/libnd4j/include/ops/declarable/LegacyScalarOp.h +++ b/libnd4j/include/ops/declarable/LegacyScalarOp.h @@ -24,24 +24,25 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for scalar transform operations, i.e. a + b = c, where either a or b is scalar primitive and other operand is NDArray - */ - class SD_EXPORT LegacyScalarOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context& block) override; - - public: - LegacyScalarOp(); - LegacyScalarOp(int opNum); - LegacyScalarOp(int opNum, NDArray &scalar); - - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) override; - LegacyOp* clone() override; - }; - } -} - - -#endif //LIBND4J_LEGACYSCALAROP_H +namespace ops { +/** + * This class provides wrapper for scalar transform operations, i.e. a + b = + * c, where either a or b is scalar primitive and other operand is NDArray + */ +class SD_EXPORT LegacyScalarOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; + + public: + LegacyScalarOp(); + LegacyScalarOp(int opNum); + LegacyScalarOp(int opNum, NDArray& scalar); + + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_LEGACYSCALAROP_H diff --git a/libnd4j/include/ops/declarable/LegacyStatsOp.h b/libnd4j/include/ops/declarable/LegacyStatsOp.h index eb74a803f507..53b71414c82f 100644 --- a/libnd4j/include/ops/declarable/LegacyStatsOp.h +++ b/libnd4j/include/ops/declarable/LegacyStatsOp.h @@ -24,22 +24,24 @@ #include namespace sd { - namespace ops { - /** - * This class provides wrapper for SummaryStats operations: Variance and Standard Deviation - */ - class SD_EXPORT LegacyStatsOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context &block) override; - public: - LegacyStatsOp(); - LegacyStatsOp(int opNum); +namespace ops { +/** + * This class provides wrapper for SummaryStats operations: Variance and + * Standard Deviation + */ +class SD_EXPORT LegacyStatsOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyStatsOp(); + LegacyStatsOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_LEGACYSTATSOP_H +#endif // LIBND4J_LEGACYSTATSOP_H diff --git a/libnd4j/include/ops/declarable/LegacyTransformAnyOp.h b/libnd4j/include/ops/declarable/LegacyTransformAnyOp.h index 5b4e82b41587..b5ac326acf2f 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformAnyOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformAnyOp.h @@ -21,26 +21,26 @@ #ifndef LIBND4J__LEGACY_TRANSFORM_ANY_OP__H #define LIBND4J__LEGACY_TRANSFORM_ANY_OP__H - #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) - */ - class SD_EXPORT LegacyTransformAnyOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context &block) override; - public: - LegacyTransformAnyOp(); - LegacyTransformAnyOp(int opNum); - - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block) override; - LegacyOp* clone() override; - }; - } -} - - -#endif //LIBND4J__LEGACY_TRANSFORM_FLOAT_OP__H +namespace ops { +/** + * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) + */ +class SD_EXPORT LegacyTransformAnyOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; + + public: + LegacyTransformAnyOp(); + LegacyTransformAnyOp(int opNum); + + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J__LEGACY_TRANSFORM_FLOAT_OP__H diff --git a/libnd4j/include/ops/declarable/LegacyTransformBoolOp.h b/libnd4j/include/ops/declarable/LegacyTransformBoolOp.h index b9e81d7c3574..56c01ae32aac 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformBoolOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformBoolOp.h @@ -22,26 +22,26 @@ #ifndef LIBND4J__LEGACY_TRANSFORM_BOOL_OP__H #define LIBND4J__LEGACY_TRANSFORM_BOOL_OP__H - #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) - */ - class SD_EXPORT LegacyTransformBoolOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context &block) override; - public: - LegacyTransformBoolOp(); - LegacyTransformBoolOp(int opNum); +namespace ops { +/** + * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) + */ +class SD_EXPORT LegacyTransformBoolOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyTransformBoolOp(); + LegacyTransformBoolOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J__LEGACY_TRANSFORM_SAME_OP__H +#endif // LIBND4J__LEGACY_TRANSFORM_SAME_OP__H diff --git a/libnd4j/include/ops/declarable/LegacyTransformFloatOp.h b/libnd4j/include/ops/declarable/LegacyTransformFloatOp.h index 6ce3f2649655..4f01611bca7e 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformFloatOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformFloatOp.h @@ -21,26 +21,26 @@ #ifndef LIBND4J__LEGACY_TRANSFORM_FLOAT_OP__H #define LIBND4J__LEGACY_TRANSFORM_FLOAT_OP__H - #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) - */ - class SD_EXPORT LegacyTransformFloatOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context &block) override; - public: - LegacyTransformFloatOp(); - LegacyTransformFloatOp(int opNum); - - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block) override; - LegacyOp* clone() override; - }; - } -} - - -#endif //LIBND4J__LEGACY_TRANSFORM_FLOAT_OP__H +namespace ops { +/** + * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) + */ +class SD_EXPORT LegacyTransformFloatOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; + + public: + LegacyTransformFloatOp(); + LegacyTransformFloatOp(int opNum); + + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J__LEGACY_TRANSFORM_FLOAT_OP__H diff --git a/libnd4j/include/ops/declarable/LegacyTransformOp.h b/libnd4j/include/ops/declarable/LegacyTransformOp.h index 82848e714441..1ba819aea0e3 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformOp.h @@ -21,32 +21,32 @@ #ifndef LIBND4J__LEGACY_TRANSFORM_OP__H #define LIBND4J__LEGACY_TRANSFORM_OP__H - //#include #ifdef ONLY_SAME_TRANSFORM namespace sd { - namespace ops { - /** - * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) - */ - class SD_EXPORT LegacyTransformOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context &block); - public: - LegacyTransformOp(); - LegacyTransformOp(int opNum); - - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block); - virtual LegacyOp* clone(); - }; - } -} +namespace ops { +/** + * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) + */ +class SD_EXPORT LegacyTransformOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block); + + public: + LegacyTransformOp(); + LegacyTransformOp(int opNum); + + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block); + virtual LegacyOp* clone(); +}; +} // namespace ops +} // namespace sd #endif +#include #include #include -#include #include - -#endif //LIBND4J__LEGACY_TRANSFORM_OP__H +#endif // LIBND4J__LEGACY_TRANSFORM_OP__H diff --git a/libnd4j/include/ops/declarable/LegacyTransformSameOp.h b/libnd4j/include/ops/declarable/LegacyTransformSameOp.h index f2fb71dff08f..e4494b205be1 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformSameOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformSameOp.h @@ -22,26 +22,26 @@ #ifndef LIBND4J__LEGACY_TRANSFORM_SAME_OP__H #define LIBND4J__LEGACY_TRANSFORM_SAME_OP__H - #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) - */ - class SD_EXPORT LegacyTransformSameOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context &block) override; - public: - LegacyTransformSameOp(); - LegacyTransformSameOp(int opNum); +namespace ops { +/** + * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) + */ +class SD_EXPORT LegacyTransformSameOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyTransformSameOp(); + LegacyTransformSameOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J__LEGACY_TRANSFORM_SAME_OP__H +#endif // LIBND4J__LEGACY_TRANSFORM_SAME_OP__H diff --git a/libnd4j/include/ops/declarable/LegacyTransformStrictOp.h b/libnd4j/include/ops/declarable/LegacyTransformStrictOp.h index 1a936f8c6ce3..f66839f80313 100644 --- a/libnd4j/include/ops/declarable/LegacyTransformStrictOp.h +++ b/libnd4j/include/ops/declarable/LegacyTransformStrictOp.h @@ -22,26 +22,26 @@ #ifndef LIBND4J__LEGACY_TRANSFORM_STRICT_OP__H #define LIBND4J__LEGACY_TRANSFORM_STRICT_OP__H - #include namespace sd { - namespace ops { - /** - * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) - */ - class SD_EXPORT LegacyTransformStrictOp : public LegacyOp { - protected: - Nd4jStatus validateAndExecute(Context &block) override; - public: - LegacyTransformStrictOp(); - LegacyTransformStrictOp(int opNum); +namespace ops { +/** + * This class provides wrapper for Transform operations (i.e. Pow or OneMinus) + */ +class SD_EXPORT LegacyTransformStrictOp : public LegacyOp { + protected: + Nd4jStatus validateAndExecute(Context& block) override; - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block) override; - LegacyOp* clone() override; - }; - } -} + public: + LegacyTransformStrictOp(); + LegacyTransformStrictOp(int opNum); + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; + LegacyOp* clone() override; +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J__LEGACY_TRANSFORM_SAME_OP__H +#endif // LIBND4J__LEGACY_TRANSFORM_SAME_OP__H diff --git a/libnd4j/include/ops/declarable/LogicOp.h b/libnd4j/include/ops/declarable/LogicOp.h index 68092f346ecc..03f071a95f1d 100644 --- a/libnd4j/include/ops/declarable/LogicOp.h +++ b/libnd4j/include/ops/declarable/LogicOp.h @@ -24,24 +24,27 @@ #include "DeclarableOp.h" namespace sd { - namespace ops { - - /** - * Logic ops are unique snowflakes in any Graph. They dramatically change Graph Execution process, by introducing loops, conditions, etc. - * - * Their code is the part of GraphExecutioner logic. But we still want them to be expressed via Graph - * @tparam T - */ - class SD_EXPORT LogicOp : public DeclarableOp { - protected: - Nd4jStatus validateAndExecute(sd::graph::Context& block) override; - public: - LogicOp(const char *name); - - ShapeList* calculateOutputShape(ShapeList* inputShape, sd::graph::Context &block) override; - }; - } -} - - -#endif //LIBND4J_LOGICOP_H +namespace ops { + +/** + * Logic ops are unique snowflakes in any Graph. They dramatically change Graph + * Execution process, by introducing loops, conditions, etc. + * + * Their code is the part of GraphExecutioner logic. But we still want them to + * be expressed via Graph + * @tparam T + */ +class SD_EXPORT LogicOp : public DeclarableOp { + protected: + Nd4jStatus validateAndExecute(sd::graph::Context& block) override; + + public: + LogicOp(const char* name); + + ShapeList* calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) override; +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_LOGICOP_H diff --git a/libnd4j/include/ops/declarable/OpDescriptor.h b/libnd4j/include/ops/declarable/OpDescriptor.h index ceb10db38f9d..9c99518750e9 100644 --- a/libnd4j/include/ops/declarable/OpDescriptor.h +++ b/libnd4j/include/ops/declarable/OpDescriptor.h @@ -21,169 +21,179 @@ #ifndef LIBND4J_OPDESCRIPTOR_H #define LIBND4J_OPDESCRIPTOR_H -#include -#include -#include +#include +#include #include #include -#include -#include -namespace sd { - namespace ops { - - /** - * This class is very basic info holder for ops. bean/pojo pretty much. - * - */ - class SD_EXPORT OpDescriptor { - protected: - // opNum for legacy XYZ ops - int _opNum = 0; - - // opName for CustomOp - std::string _opName; +#include +#include +#include - // hash is used for ops lookup in OpRegistrator - Nd4jLong _hash = -1; +namespace sd { +namespace ops { - // minimal required/expected number of inputs/outpus for this given op - int _numInputs = 1; - int _numOutputs = 1; +/** + * This class is very basic info holder for ops. bean/pojo pretty much. + * + */ +class SD_EXPORT OpDescriptor { + protected: + // opNum for legacy XYZ ops + int _opNum = 0; - // enum for ops. deprecated. will be removed - sd::graph::OpClass _opClass; + // opName for CustomOp + std::string _opName; - // special flag for divergent ops - ops that CAN and WILL modify graph behavior. Literally: IF, CASE. - bool _divergent = false; + // hash is used for ops lookup in OpRegistrator + Nd4jLong _hash = -1; - // flag, if this given op allows in-place execution - bool _allowsInplace = true; + // minimal required/expected number of inputs/outpus for this given op + int _numInputs = 1; + int _numOutputs = 1; - // minimal required number of T-type arguments. - // -1 as value means: not limited, variable number of arguments - int _tArgs = 0; + // enum for ops. deprecated. will be removed + sd::graph::OpClass _opClass; - // minimal required number of Integer-type arguments. - // -1 as value means: not limited, variable number of arguments - int _iArgs = 0; + // special flag for divergent ops - ops that CAN and WILL modify graph + // behavior. Literally: IF, CASE. + bool _divergent = false; - // field for BooleanOps - bool _scalar = false; + // flag, if this given op allows in-place execution + bool _allowsInplace = true; - // field for LogicOps - bool _logic = false; + // minimal required number of T-type arguments. + // -1 as value means: not limited, variable number of arguments + int _tArgs = 0; - // default InputType is numeric - InputType _inputType = InputType_NUMERIC; + // minimal required number of Integer-type arguments. + // -1 as value means: not limited, variable number of arguments + int _iArgs = 0; + // field for BooleanOps + bool _scalar = false; - bool _sameMode = false; - std::vector _allowedIns; - std::vector _allowedOuts; + // field for LogicOps + bool _logic = false; - // optional per-input configuration - MAP_IMPL> _outputTypes; - MAP_IMPL> _inputTypes; + // default InputType is numeric + InputType _inputType = InputType_NUMERIC; + bool _sameMode = false; + std::vector _allowedIns; + std::vector _allowedOuts; - // field for ops that allow data type override at runtime - bool _dtypeOverride = false; + // optional per-input configuration + MAP_IMPL> _outputTypes; + MAP_IMPL> _inputTypes; - bool checkDataTypesMatch(sd::DataType needle, std::vector &haystack) const; - public: - // default constructor - OpDescriptor(int numInputs, int numOutputs, std::string opName, bool allowsInplace); + // field for ops that allow data type override at runtime + bool _dtypeOverride = false; - // constructor for boolean ops - OpDescriptor(int numInputs, std::string opName, bool isScalar); - OpDescriptor(int numInputs, const char* opName, bool isScalar); + bool checkDataTypesMatch(sd::DataType needle, + std::vector& haystack) const; - // default constructor - OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace); + public: + // default constructor + OpDescriptor(int numInputs, int numOutputs, std::string opName, + bool allowsInplace); - // constructor for configurable op - OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs); + // constructor for boolean ops + OpDescriptor(int numInputs, std::string opName, bool isScalar); + OpDescriptor(int numInputs, const char* opName, bool isScalar); - // constructor for non-configurable divergent op - OpDescriptor(int numInputs, int numOutputs, std::string opName, bool allowsInplace, bool divergent); + // default constructor + OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace); - // constructor for non-configurable divergent op - OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent); + // constructor for configurable op + OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, int tArgs, int iArgs); - // constructor for configurable divergent op - OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent, int tArgs, int iArgs); + // constructor for non-configurable divergent op + OpDescriptor(int numInputs, int numOutputs, std::string opName, + bool allowsInplace, bool divergent); - // constructor for logical ops (while, scope, etc) - OpDescriptor(const char * opName, bool isLogic); + // constructor for non-configurable divergent op + OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, bool divergent); - bool operator==(const OpDescriptor& other) const; + // constructor for configurable divergent op + OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, bool divergent, int tArgs, int iArgs); - // default destructor - ~OpDescriptor(); + // constructor for logical ops (while, scope, etc) + OpDescriptor(const char* opName, bool isLogic); - // this method returns minimal expected number of T arguments - int getNumberOfTArgs(); + bool operator==(const OpDescriptor& other) const; - // this method returns minimal expected number of Integer arguments - int getNumberOfIArgs(); + // default destructor + ~OpDescriptor(); - // this method returns minimal expected number of inputs - int getNumberOfInputs(); + // this method returns minimal expected number of T arguments + int getNumberOfTArgs(); - // this method returns hash code for this operation - Nd4jLong getHash(); + // this method returns minimal expected number of Integer arguments + int getNumberOfIArgs(); - // this method returns minimal expected number of outputs - int getNumberOfOutputs(); + // this method returns minimal expected number of inputs + int getNumberOfInputs(); - // this method returns opName (can be empty) - std::string *getOpName(); + // this method returns hash code for this operation + Nd4jLong getHash(); - // returns TRUE if this op is divergent. FALSE otherwise - bool isDivergent(); + // this method returns minimal expected number of outputs + int getNumberOfOutputs(); - // returns TRUE if this op allows in-place execution - bool allowsInplace(); + // this method returns opName (can be empty) + std::string* getOpName(); - // this method allows you to enable/disable inplace call for a given op - void allowInplace(bool reallyAllow); + // returns TRUE if this op is divergent. FALSE otherwise + bool isDivergent(); - // this method returns opNum (applicable for legacy XYZ ops only) - int getOpNum(); + // returns TRUE if this op allows in-place execution + bool allowsInplace(); - // this method allows to set specifc opNum - void setOpNum(int opNum); + // this method allows you to enable/disable inplace call for a given op + void allowInplace(bool reallyAllow); - void setHash(Nd4jLong hash); + // this method returns opNum (applicable for legacy XYZ ops only) + int getOpNum(); - InputType inputType(); + // this method allows to set specifc opNum + void setOpNum(int opNum); + void setHash(Nd4jLong hash); + InputType inputType(); - OpDescriptor* setInputType(InputType type); - OpDescriptor* setAllowedInputTypes(const std::initializer_list &dtype); - OpDescriptor* setAllowedOutputTypes(const std::initializer_list &dtype); - OpDescriptor* setAllowedInputTypes(int index, const std::vector &dtype); - OpDescriptor* setAllowedOutputTypes(int index, const std::vector &dtype); - OpDescriptor* setAllowedInputTypes(int index, sd::DataType dtype); - OpDescriptor* setAllowedOutputTypes(int index, sd::DataType dtype); - OpDescriptor* setAllowedInputTypes(sd::DataType dtype); - OpDescriptor* setAllowedOutputTypes(sd::DataType dtype); - OpDescriptor* allowOverride(bool reallyAllow); - OpDescriptor* setSameMode(bool reallySame); - OpDescriptor* setInputType(int idx, sd::DataType dtype); - OpDescriptor* setOutputType(int idx, sd::DataType dtype); + OpDescriptor* setInputType(InputType type); + OpDescriptor* setAllowedInputTypes( + const std::initializer_list& dtype); + OpDescriptor* setAllowedOutputTypes( + const std::initializer_list& dtype); + OpDescriptor* setAllowedInputTypes(int index, + const std::vector& dtype); + OpDescriptor* setAllowedOutputTypes(int index, + const std::vector& dtype); + OpDescriptor* setAllowedInputTypes(int index, sd::DataType dtype); + OpDescriptor* setAllowedOutputTypes(int index, sd::DataType dtype); + OpDescriptor* setAllowedInputTypes(sd::DataType dtype); + OpDescriptor* setAllowedOutputTypes(sd::DataType dtype); + OpDescriptor* allowOverride(bool reallyAllow); + OpDescriptor* setSameMode(bool reallySame); + OpDescriptor* setInputType(int idx, sd::DataType dtype); + OpDescriptor* setOutputType(int idx, sd::DataType dtype); - std::vector getOutputTypesForOutput(int index); + std::vector getOutputTypesForOutput(int index); - bool checkInputMatch(int index, sd::DataType dataType); - bool checkOutputMatch(int index, sd::DataType dataType); - bool isSameMode(); + bool checkInputMatch(int index, sd::DataType dataType); + bool checkOutputMatch(int index, sd::DataType dataType); + bool isSameMode(); - bool isInherit(int index); - }; - } -} + bool isInherit(int index); +}; +} // namespace ops +} // namespace sd -#endif //LIBND4J_OPDESCRIPTOR_H +#endif // LIBND4J_OPDESCRIPTOR_H diff --git a/libnd4j/include/ops/declarable/OpRegistrator.h b/libnd4j/include/ops/declarable/OpRegistrator.h index a0954efe3c69..6573f3681a85 100644 --- a/libnd4j/include/ops/declarable/OpRegistrator.h +++ b/libnd4j/include/ops/declarable/OpRegistrator.h @@ -21,135 +21,142 @@ #ifndef LIBND4J_OPREGISTRATOR_H #define LIBND4J_OPREGISTRATOR_H -#include -#include -#include -#include +#include #include #include -#include +#include + +#include +#include +#include // handlers part -#include #include +#include #ifndef __JAVACPP_HACK__ namespace std { - template <> - class hash> { - public: - size_t operator()(const std::pair& k) const; - }; - - template <> - class hash> { - public: - size_t operator()(const std::pair& k) const; - }; +template <> +class hash> { + public: + size_t operator()(const std::pair& k) const; }; -#endif +template <> +class hash> { + public: + size_t operator()(const std::pair& k) const; +}; +}; // namespace std +#endif namespace sd { - namespace ops { - /** - * This class provides runtime ops lookup, based on opName or opHash. - * To build lookup directory we use *_OP_IMPL macro, which puts static structs at compile time in .cpp files, - * so once binary is executed, static objects are initialized automatically, and we get list of all ops - * available at runtime via this singleton. - * - */ - class SD_EXPORT OpRegistrator { - private: - static OpRegistrator* _INSTANCE; - OpRegistrator() { - nd4j_debug("OpRegistrator started\n",""); - - /* +namespace ops { +/** + * This class provides runtime ops lookup, based on opName or opHash. + * To build lookup directory we use *_OP_IMPL macro, which puts static structs + * at compile time in .cpp files, so once binary is executed, static objects are + * initialized automatically, and we get list of all ops available at runtime + * via this singleton. + * + */ +class SD_EXPORT OpRegistrator { + private: + static OpRegistrator* _INSTANCE; + OpRegistrator() { + nd4j_debug("OpRegistrator started\n", ""); + + /* #ifndef _RELEASE - std::signal(SIGSEGV, &OpRegistrator::sigSegVHandler); - std::signal(SIGINT, &OpRegistrator::sigIntHandler); - std::signal(SIGABRT, &OpRegistrator::sigIntHandler); - std::signal(SIGFPE, &OpRegistrator::sigIntHandler); - std::signal(SIGILL, &OpRegistrator::sigIntHandler); - std::signal(SIGTERM, &OpRegistrator::sigIntHandler); - atexit(&OpRegistrator::exitHandler); + std::signal(SIGSEGV, &OpRegistrator::sigSegVHandler); + std::signal(SIGINT, &OpRegistrator::sigIntHandler); + std::signal(SIGABRT, &OpRegistrator::sigIntHandler); + std::signal(SIGFPE, &OpRegistrator::sigIntHandler); + std::signal(SIGILL, &OpRegistrator::sigIntHandler); + std::signal(SIGTERM, &OpRegistrator::sigIntHandler); + atexit(&OpRegistrator::exitHandler); #endif - */ - }; + */ + }; - MAP_IMPL _msvc; + MAP_IMPL _msvc; - // pointers to our operations - MAP_IMPL> _declarablesLD; - MAP_IMPL> _declarablesD; + // pointers to our operations + MAP_IMPL> _declarablesLD; + MAP_IMPL> _declarablesD; - // pointers to platform-specific helpers - MAP_IMPL, sd::ops::platforms::PlatformHelper*> _helpersLH; - MAP_IMPL, sd::ops::platforms::PlatformHelper*> _helpersH; - std::vector _uniqueH; + // pointers to platform-specific helpers + MAP_IMPL, + sd::ops::platforms::PlatformHelper*> + _helpersLH; + MAP_IMPL, + sd::ops::platforms::PlatformHelper*> + _helpersH; + std::vector _uniqueH; - std::mutex _locker; - std::string _opsList; - bool isInit = false; - public: - ~OpRegistrator(); + std::mutex _locker; + std::string _opsList; + bool isInit = false; - static OpRegistrator* getInstance(); + public: + ~OpRegistrator(); - static void exitHandler(); - static void sigIntHandler(int sig); - static void sigSegVHandler(int sig); + static OpRegistrator* getInstance(); - void updateMSVC(Nd4jLong newHash, std::string& oldName); + static void exitHandler(); + static void sigIntHandler(int sig); + static void sigSegVHandler(int sig); - template - std::string local_to_string(T value); - const char * getAllCustomOperations(); + void updateMSVC(Nd4jLong newHash, std::string& oldName); - /** - * This method registers operation in our registry, so we can use them later - * - * @param op - */ - bool registerOperation(const std::string &opName, std::shared_ptr op); - bool registerOperation(std::shared_ptr op); + template + std::string local_to_string(T value); + const char* getAllCustomOperations(); - void registerHelper(sd::ops::platforms::PlatformHelper* op); + /** + * This method registers operation in our registry, so we can use them later + * + * @param op + */ + bool registerOperation(const std::string& opName, + std::shared_ptr op); + bool registerOperation(std::shared_ptr op); - bool hasHelper(Nd4jLong hash, samediff::Engine engine); + void registerHelper(sd::ops::platforms::PlatformHelper* op); - std::shared_ptr getOperation(Nd4jLong hash); - std::shared_ptr getOperation(const std::string &name); + bool hasHelper(Nd4jLong hash, samediff::Engine engine); - bool hasOperation(const std::string &opName) const; - bool hasOperation(const Nd4jLong opName) const; + std::shared_ptr getOperation(Nd4jLong hash); + std::shared_ptr getOperation(const std::string& name); - sd::ops::platforms::PlatformHelper* getPlatformHelper(Nd4jLong hash, samediff::Engine engine); + bool hasOperation(const std::string& opName) const; + bool hasOperation(const Nd4jLong opName) const; - std::vector getAllHashes(); + sd::ops::platforms::PlatformHelper* getPlatformHelper( + Nd4jLong hash, samediff::Engine engine); - int numberOfOperations(); - }; + std::vector getAllHashes(); + int numberOfOperations(); +}; - /* - * These structs are used to "register" our ops in OpRegistrator. - */ - template - struct __registrator{ - __registrator(); - }; +/* + * These structs are used to "register" our ops in OpRegistrator. + */ +template +struct __registrator { + __registrator(); +}; - template - struct __registratorSynonym { - __registratorSynonym(const char *name, const char *oname); - }; +template +struct __registratorSynonym { + __registratorSynonym(const char* name, const char* oname); +}; - } -} +} // namespace ops +} // namespace sd -#endif //LIBND4J_OPREGISTRATOR_H +#endif // LIBND4J_OPREGISTRATOR_H diff --git a/libnd4j/include/ops/declarable/OpTuple.h b/libnd4j/include/ops/declarable/OpTuple.h index 39ce03d05432..537fe7d8b0a4 100644 --- a/libnd4j/include/ops/declarable/OpTuple.h +++ b/libnd4j/include/ops/declarable/OpTuple.h @@ -21,31 +21,33 @@ #ifndef LIBND4J_OPTUPLE_H #define LIBND4J_OPTUPLE_H -#include -#include #include +#include +#include + namespace sd { - namespace ops { - class SD_EXPORT OpTuple { - public: - std::string _opName; - std::vector _inputs; - std::vector _outputs; - std::vector _tArgs; - std::vector _iArgs; - - OpTuple(const char *opName); - OpTuple(const char *opName, std::initializer_list&& inputs, std::initializer_list&& tArgs, std::initializer_list&& iArgs); - ~OpTuple(); - - OpTuple* addInput(sd::NDArray *array); - OpTuple* addOutput(sd::NDArray *array); - OpTuple* setTArgs(std::initializer_list tArgs); - OpTuple* setIArgs(std::initializer_list iArgs); - }; - } -} - - -#endif //LIBND4J_OPTUPLE_H +namespace ops { +class SD_EXPORT OpTuple { + public: + std::string _opName; + std::vector _inputs; + std::vector _outputs; + std::vector _tArgs; + std::vector _iArgs; + + OpTuple(const char* opName); + OpTuple(const char* opName, std::initializer_list&& inputs, + std::initializer_list&& tArgs, + std::initializer_list&& iArgs); + ~OpTuple(); + + OpTuple* addInput(sd::NDArray* array); + OpTuple* addOutput(sd::NDArray* array); + OpTuple* setTArgs(std::initializer_list tArgs); + OpTuple* setIArgs(std::initializer_list iArgs); +}; +} // namespace ops +} // namespace sd + +#endif // LIBND4J_OPTUPLE_H diff --git a/libnd4j/include/ops/declarable/PlatformHelper.h b/libnd4j/include/ops/declarable/PlatformHelper.h index 9f456b1ca96d..3adad1a3b1fa 100644 --- a/libnd4j/include/ops/declarable/PlatformHelper.h +++ b/libnd4j/include/ops/declarable/PlatformHelper.h @@ -21,75 +21,79 @@ #ifndef SD_PLATFORMHELPER_H #define SD_PLATFORMHELPER_H -#include #include #include -#include -#include +#include #include +#include + +#include namespace sd { - namespace ops { - namespace platforms { - /** - * This abstract class defines methods used by platform-specific helpers implementations - */ - class SD_EXPORT PlatformHelper { - protected: - // target engine for this impl - samediff::Engine _engine; - - // name of the operation this helper is built for - std::string _name; - - // hash of the operation this helper is built for - Nd4jLong _hash; - public: - PlatformHelper(const char *name, samediff::Engine engine); - - ~PlatformHelper() = default; - - std::string name(); - - samediff::Engine engine(); - - Nd4jLong hash(); - - /** - * This method checks, if given helper can be used with given input/output/configuration options - * - * @param context - * @return - */ - virtual bool isUsable(graph::Context &context) = 0; - - /** - * This method invokes helper. Typically this method replaces actual op execution - * - * @param context - * @return - */ - virtual Nd4jStatus invokeHelper(graph::Context &context) = 0; - - /** - * Helper method, needed for compatibility with DeclarableOp macros - * @param ctx - * @param inputId - * @return - */ - sd::NDArray* getZ(graph::Context &ctx, int inputId); - - /** - * Helper method, needed for compatibility with DeclarableOp macros - * @param ctx - * @param inputId - * @return - */ - sd::NDArray* getNullifiedZ(graph::Context &ctx, int inputId); - }; - } - } -} - - -#endif //SD_PLATFORMHELPER_H +namespace ops { +namespace platforms { +/** + * This abstract class defines methods used by platform-specific helpers + * implementations + */ +class SD_EXPORT PlatformHelper { + protected: + // target engine for this impl + samediff::Engine _engine; + + // name of the operation this helper is built for + std::string _name; + + // hash of the operation this helper is built for + Nd4jLong _hash; + + public: + PlatformHelper(const char *name, samediff::Engine engine); + + ~PlatformHelper() = default; + + std::string name(); + + samediff::Engine engine(); + + Nd4jLong hash(); + + /** + * This method checks, if given helper can be used with given + * input/output/configuration options + * + * @param context + * @return + */ + virtual bool isUsable(graph::Context &context) = 0; + + /** + * This method invokes helper. Typically this method replaces actual op + * execution + * + * @param context + * @return + */ + virtual Nd4jStatus invokeHelper(graph::Context &context) = 0; + + /** + * Helper method, needed for compatibility with DeclarableOp macros + * @param ctx + * @param inputId + * @return + */ + sd::NDArray *getZ(graph::Context &ctx, int inputId); + + /** + * Helper method, needed for compatibility with DeclarableOp macros + * @param ctx + * @param inputId + * @return + */ + sd::NDArray *getNullifiedZ(graph::Context &ctx, int inputId); +}; +} // namespace platforms +} // namespace ops +} // namespace sd + +#endif // SD_PLATFORMHELPER_H diff --git a/libnd4j/include/ops/declarable/generic/CustomOperations.cpp b/libnd4j/include/ops/declarable/generic/CustomOperations.cpp index c13430ce3abf..60c3e47ee296 100644 --- a/libnd4j/include/ops/declarable/generic/CustomOperations.cpp +++ b/libnd4j/include/ops/declarable/generic/CustomOperations.cpp @@ -15,37 +15,38 @@ ******************************************************************************/ // -// This is special snowflake. This file builds bindings for ops availability tests +// This is special snowflake. This file builds bindings for ops availability +// tests // // @author raver119@gmail.com // -#include #include +#include #include namespace sd { - _loader::_loader() { - // - OpTracker::getInstance(); +_loader::_loader() { + // + OpTracker::getInstance(); -//#ifndef __CLION_IDE__ - BUILD_TRACKER(OpType_TRANSFORM_SAME, TRANSFORM_FLOAT_OPS); - BUILD_TRACKER(OpType_TRANSFORM_SAME, TRANSFORM_SAME_OPS); - BUILD_TRACKER(OpType_TRANSFORM_SAME, TRANSFORM_BOOL_OPS); - BUILD_TRACKER(OpType_BROADCAST, BROADCAST_OPS); - BUILD_TRACKER(OpType_PAIRWISE, PAIRWISE_TRANSFORM_OPS); - BUILD_TRACKER(OpType_RANDOM, RANDOM_OPS); - BUILD_TRACKER(OpType_REDUCE_FLOAT, REDUCE_FLOAT_OPS); - BUILD_TRACKER(OpType_REDUCE_SAME, REDUCE_SAME_OPS); - BUILD_TRACKER(OpType_REDUCE_BOOL, REDUCE_BOOL_OPS); - BUILD_TRACKER(OpType_REDUCE_3, REDUCE3_OPS); - BUILD_TRACKER(OpType_INDEX_REDUCE, INDEX_REDUCE_OPS); - BUILD_TRACKER(OpType_SCALAR, SCALAR_OPS); - BUILD_TRACKER(OpType_SUMMARYSTATS, SUMMARY_STATS_OPS); -//#endif - }; + //#ifndef __CLION_IDE__ + BUILD_TRACKER(OpType_TRANSFORM_SAME, TRANSFORM_FLOAT_OPS); + BUILD_TRACKER(OpType_TRANSFORM_SAME, TRANSFORM_SAME_OPS); + BUILD_TRACKER(OpType_TRANSFORM_SAME, TRANSFORM_BOOL_OPS); + BUILD_TRACKER(OpType_BROADCAST, BROADCAST_OPS); + BUILD_TRACKER(OpType_PAIRWISE, PAIRWISE_TRANSFORM_OPS); + BUILD_TRACKER(OpType_RANDOM, RANDOM_OPS); + BUILD_TRACKER(OpType_REDUCE_FLOAT, REDUCE_FLOAT_OPS); + BUILD_TRACKER(OpType_REDUCE_SAME, REDUCE_SAME_OPS); + BUILD_TRACKER(OpType_REDUCE_BOOL, REDUCE_BOOL_OPS); + BUILD_TRACKER(OpType_REDUCE_3, REDUCE3_OPS); + BUILD_TRACKER(OpType_INDEX_REDUCE, INDEX_REDUCE_OPS); + BUILD_TRACKER(OpType_SCALAR, SCALAR_OPS); + BUILD_TRACKER(OpType_SUMMARYSTATS, SUMMARY_STATS_OPS); + //#endif +}; - static sd::_loader loader; -} \ No newline at end of file +static sd::_loader loader; +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp index 65f81b42822e..23770d42090a 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp @@ -22,35 +22,40 @@ #if NOT_EXCLUDED(OP_bits_hamming_distance) #include -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(bits_hamming_distance, 2, 1, true, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto output = OUTPUT_NULLIFIED(0); - - REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "bits_hamming_distance: both arguments must have the same length"); - REQUIRE_TRUE(x->dataType() == y->dataType(), 0, "bits_hamming_distance: both arguments must have the same data type"); - - helpers::hamming(block.launchContext(), *x, *y, *output); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(bits_hamming_distance) { - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT64)); - } - - DECLARE_TYPES(bits_hamming_distance) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes(0, {ALL_INDICES}); - } - } +namespace ops { +CUSTOM_OP_IMPL(bits_hamming_distance, 2, 1, true, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto output = OUTPUT_NULLIFIED(0); + + REQUIRE_TRUE( + x->lengthOf() == y->lengthOf(), 0, + "bits_hamming_distance: both arguments must have the same length"); + REQUIRE_TRUE( + x->dataType() == y->dataType(), 0, + "bits_hamming_distance: both arguments must have the same data type"); + + helpers::hamming(block.launchContext(), *x, *y, *output); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(bits_hamming_distance) { + return SHAPELIST( + ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT64)); +} + +DECLARE_TYPES(bits_hamming_distance) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes(0, {ALL_INDICES}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp index 1e951c1d9580..8fc387c848f0 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_and.cpp @@ -26,25 +26,26 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(bitwise_and, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(bitwise_and, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), *y, *z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, + pairwise::IntOps::IntAnd, + broadcast::IntOps::IntAnd), + *y, *z, false); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(bitwise_and) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); - } - } +DECLARE_TYPES(bitwise_and) { + getOpDescriptor()->setAllowedInputTypes({ALL_INTS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp index cd20a8434f9e..642cf4dcc2d7 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_or.cpp @@ -26,25 +26,26 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(bitwise_or, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(bitwise_or, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), *y, *z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, + pairwise::IntOps::IntOr, + broadcast::IntOps::IntOr), + *y, *z, false); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(bitwise_or) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); - } - } +DECLARE_TYPES(bitwise_or) { + getOpDescriptor()->setAllowedInputTypes({ALL_INTS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp index 0af9fe759163..241ae8c22459 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/bitwise_xor.cpp @@ -26,25 +26,26 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(bitwise_xor, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(bitwise_xor, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), *y, *z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, + pairwise::IntOps::IntXor, + broadcast::IntOps::IntXor), + *y, *z, false); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(bitwise_xor) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); - } - } +DECLARE_TYPES(bitwise_xor) { + getOpDescriptor()->setAllowedInputTypes({ALL_INTS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp index cc0c4827b23d..5ffa1f1989df 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp @@ -26,25 +26,27 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(cyclic_rshift_bits, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(cyclic_rshift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, pairwise::CyclicShiftRight, broadcast::CyclicShiftRight), *y, *z, false); + x->applyTrueBroadcast( + BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, + pairwise::CyclicShiftRight, + broadcast::CyclicShiftRight), + *y, *z, false); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(cyclic_rshift_bits) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); - } - } +DECLARE_TYPES(cyclic_rshift_bits) { + getOpDescriptor()->setAllowedInputTypes({ALL_INTS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp index f2b36a6d8346..0578b87557db 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp @@ -26,25 +26,26 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(cyclic_shift_bits, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(cyclic_shift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, broadcast::CyclicShiftLeft), *y, *z, false); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom( + scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, + broadcast::CyclicShiftLeft), + *y, *z, false); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(cyclic_shift_bits) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); - } - } +DECLARE_TYPES(cyclic_shift_bits) { + getOpDescriptor()->setAllowedInputTypes({ALL_INTS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp index 8b44d2a6f415..bde55383ba1e 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp @@ -26,25 +26,26 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(rshift_bits, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(rshift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, broadcast::ShiftRight), *y, *z, false); + x->applyTrueBroadcast( + BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, + broadcast::ShiftRight), + *y, *z, false); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(rshift_bits) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); - } - } +DECLARE_TYPES(rshift_bits) { + getOpDescriptor()->setAllowedInputTypes({ALL_INTS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp index 7d0647e1b7bf..47ff4ee5b294 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp @@ -26,25 +26,26 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(shift_bits, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(shift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, broadcast::ShiftLeft), *y, *z, false); + x->applyTrueBroadcast( + BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, + broadcast::ShiftLeft), + *y, *z, false); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(shift_bits) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); - } - } +DECLARE_TYPES(shift_bits) { + getOpDescriptor()->setAllowedInputTypes({ALL_INTS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp b/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp index 0ba6fbcc7d52..ffe81cafa457 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp @@ -26,28 +26,30 @@ #include namespace sd { - namespace ops { - OP_IMPL(toggle_bits, -1, -1, true) { - - for (int i = 0; i < block.width(); i++) { - auto x = INPUT_VARIABLE(i); - auto z = OUTPUT_VARIABLE(i); - - REQUIRE_TRUE(x->dataType() == z->dataType(), 0, "Toggle bits requires input and output to have same type"); - REQUIRE_TRUE(x->isZ(),0, "Toggle bits requires input and output to be integer type (int8, int16, int32, int64)"); - - helpers::__toggle_bits(block.launchContext(), *x, *z); - } - return Status::OK(); - } - - DECLARE_TYPES(toggle_bits) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setAllowedOutputTypes({ALL_INTS}) - ->setSameMode(false); - } - } +namespace ops { +OP_IMPL(toggle_bits, -1, -1, true) { + for (int i = 0; i < block.width(); i++) { + auto x = INPUT_VARIABLE(i); + auto z = OUTPUT_VARIABLE(i); + + REQUIRE_TRUE(x->dataType() == z->dataType(), 0, + "Toggle bits requires input and output to have same type"); + REQUIRE_TRUE(x->isZ(), 0, + "Toggle bits requires input and output to be integer type " + "(int8, int16, int32, int64)"); + + helpers::__toggle_bits(block.launchContext(), *x, *z); + } + return Status::OK(); } +DECLARE_TYPES(toggle_bits) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setAllowedOutputTypes({ALL_INTS}) + ->setSameMode(false); +} +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/blas/axpy.cpp b/libnd4j/include/ops/declarable/generic/blas/axpy.cpp index 3e27e8921af8..cfee30389f5b 100644 --- a/libnd4j/include/ops/declarable/generic/blas/axpy.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/axpy.cpp @@ -24,38 +24,41 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(axpy, 2, 1, false, -2, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(x->isSameShape(y),0, "Axpy: both arguments should have the same shape"); - REQUIRE_TRUE(x->dataType() == y->dataType() && x->dataType() == z->dataType(), 0, "Axpy: all arguments must have the same data type"); - - double a = 1.0; - - if (block.width() > 2) { - auto alpha = INPUT_VARIABLE(2); - REQUIRE_TRUE(alpha->isScalar(), 0, "Axpy: alpha argument should be scalar or TArg"); - } else if (block.numT() > 0) { - a = T_ARG(0); - } - - ExtraArguments arguments({a}); - - y->applyPairwiseTransform(pairwise::Axpy, *x, *z, &arguments); - - return Status::OK(); - } - - DECLARE_TYPES(axpy) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } - } +namespace ops { +CONFIGURABLE_OP_IMPL(axpy, 2, 1, false, -2, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(x->isSameShape(y), 0, + "Axpy: both arguments should have the same shape"); + REQUIRE_TRUE(x->dataType() == y->dataType() && x->dataType() == z->dataType(), + 0, "Axpy: all arguments must have the same data type"); + + double a = 1.0; + + if (block.width() > 2) { + auto alpha = INPUT_VARIABLE(2); + REQUIRE_TRUE(alpha->isScalar(), 0, + "Axpy: alpha argument should be scalar or TArg"); + } else if (block.numT() > 0) { + a = T_ARG(0); + } + + ExtraArguments arguments({a}); + + y->applyPairwiseTransform(pairwise::Axpy, *x, *z, &arguments); + + return Status::OK(); +} + +DECLARE_TYPES(axpy) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp b/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp index 194af35b8cd0..524ed10ce1c3 100644 --- a/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/batched_gemm.cpp @@ -25,115 +25,135 @@ #include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(batched_gemm, -1, -1, false, 0, 9) { - - int transA = INT_ARG(0); - int transB = INT_ARG(1); - int M = INT_ARG(2); - int N = INT_ARG(3); - int K = INT_ARG(4); - int ldA = INT_ARG(5); - int ldB = INT_ARG(6); - int ldC = INT_ARG(7); - int batchSize = INT_ARG(8); - - if (transA == 0) - transA = 111; - - if (transB == 0) - transB = 111; - - if (transA == 1) - transA = 112; - - if (transB == 1) - transB = 112; - - // basically A+B and 2 arrays of alpha and beta - int expectedWidth = batchSize * 2 + 2; - - REQUIRE_TRUE((transA == 111 || transA == 112) && (transB == 111 || transB == 112), 0, "BatchedGemm: valid values for transA and transB are: 0/1 or 111/112, for NoTrans/Trans respectively") - REQUIRE_TRUE(M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && batchSize > 0, 0, ""); - REQUIRE_TRUE(block.width() == expectedWidth, 0, "BatchedGemm: expected number of input arrays is %i, but got %i instead", expectedWidth, block.width()); - - auto alpha = INPUT_VARIABLE(0); - auto beta = INPUT_VARIABLE(1); - - std::vector vA(batchSize); - std::vector vB(batchSize); - std::vector vC(batchSize); - - auto firstType = INPUT_VARIABLE(0)->dataType(); - for(int e = 0; e < batchSize; e++) { - vA[e] = INPUT_VARIABLE(e+2); - vB[e] = INPUT_VARIABLE(e+2+batchSize); - vC[e] = OUTPUT_VARIABLE(e); - - - REQUIRE_TRUE(firstType == vC[e]->dataType(), 0, "BatchedGemm: all inputs and outputs must have same data type"); - - REQUIRE_TRUE(vA[e]->rankOf() == 2, 0, "BatchedGemm: batch %i, rank of A should be equal to 2", e); - REQUIRE_TRUE(vB[e]->rankOf() == 2, 0, "BatchedGemm: batch %i, rank of B should be equal to 2", e); - REQUIRE_TRUE(vC[e]->rankOf() == 2, 0, "BatchedGemm: batch %i, rank of C should be equal to 2", e); - - REQUIRE_TRUE(M == vA[e]->sizeAt(0), 0, "BatchedGemm: batch %i, number of A.rows() should be equal to M", e); - REQUIRE_TRUE(N == vB[e]->sizeAt(1), 0, "BatchedGemm: batch %i, number of B.columns() should be equal to N", e); - REQUIRE_TRUE(K == vA[e]->sizeAt(1) && K == vB[e]->sizeAt(0), 0, "BatchedGemm: batch %i, number of A.columns() and B.rows() should be equal to K", e); - }; - - REQUIRE_TRUE(vA.size() == vB.size() && vA.size() == vC.size() && vA.size() == batchSize, 0, "BatchedGemm: mismatched numbers of A, B, C for unknown reason"); - - sd::ops::helpers::bgemm(vA, vB, vC, alpha, beta, transA, transB, M, N, K, ldA, ldB, ldC); - - return Status::OK(); + int transA = INT_ARG(0); + int transB = INT_ARG(1); + int M = INT_ARG(2); + int N = INT_ARG(3); + int K = INT_ARG(4); + int ldA = INT_ARG(5); + int ldB = INT_ARG(6); + int ldC = INT_ARG(7); + int batchSize = INT_ARG(8); + + if (transA == 0) transA = 111; + + if (transB == 0) transB = 111; + + if (transA == 1) transA = 112; + + if (transB == 1) transB = 112; + + // basically A+B and 2 arrays of alpha and beta + int expectedWidth = batchSize * 2 + 2; + + REQUIRE_TRUE( + (transA == 111 || transA == 112) && (transB == 111 || transB == 112), 0, + "BatchedGemm: valid values for transA and transB are: 0/1 or 111/112, " + "for NoTrans/Trans respectively") + REQUIRE_TRUE( + M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && batchSize > 0, + 0, ""); + REQUIRE_TRUE( + block.width() == expectedWidth, 0, + "BatchedGemm: expected number of input arrays is %i, but got %i instead", + expectedWidth, block.width()); + + auto alpha = INPUT_VARIABLE(0); + auto beta = INPUT_VARIABLE(1); + + std::vector vA(batchSize); + std::vector vB(batchSize); + std::vector vC(batchSize); + + auto firstType = INPUT_VARIABLE(0)->dataType(); + for (int e = 0; e < batchSize; e++) { + vA[e] = INPUT_VARIABLE(e + 2); + vB[e] = INPUT_VARIABLE(e + 2 + batchSize); + vC[e] = OUTPUT_VARIABLE(e); + + REQUIRE_TRUE( + firstType == vC[e]->dataType(), 0, + "BatchedGemm: all inputs and outputs must have same data type"); + + REQUIRE_TRUE(vA[e]->rankOf() == 2, 0, + "BatchedGemm: batch %i, rank of A should be equal to 2", e); + REQUIRE_TRUE(vB[e]->rankOf() == 2, 0, + "BatchedGemm: batch %i, rank of B should be equal to 2", e); + REQUIRE_TRUE(vC[e]->rankOf() == 2, 0, + "BatchedGemm: batch %i, rank of C should be equal to 2", e); + + REQUIRE_TRUE( + M == vA[e]->sizeAt(0), 0, + "BatchedGemm: batch %i, number of A.rows() should be equal to M", e); + REQUIRE_TRUE( + N == vB[e]->sizeAt(1), 0, + "BatchedGemm: batch %i, number of B.columns() should be equal to N", e); + REQUIRE_TRUE(K == vA[e]->sizeAt(1) && K == vB[e]->sizeAt(0), 0, + "BatchedGemm: batch %i, number of A.columns() and B.rows() " + "should be equal to K", + e); + }; + + REQUIRE_TRUE(vA.size() == vB.size() && vA.size() == vC.size() && + vA.size() == batchSize, + 0, + "BatchedGemm: mismatched numbers of A, B, C for unknown reason"); + + sd::ops::helpers::bgemm(vA, vB, vC, alpha, beta, transA, transB, M, N, K, ldA, + ldB, ldC); + + return Status::OK(); }; - DECLARE_SHAPE_FN(batched_gemm) { - int transA = INT_ARG(0); - int transB = INT_ARG(1); - int M = INT_ARG(2); - int N = INT_ARG(3); - int K = INT_ARG(4); - int ldA = INT_ARG(5); - int ldB = INT_ARG(6); - int ldC = INT_ARG(7); - int batchSize = INT_ARG(8); - - auto firstType = ArrayOptions::dataType(inputShape->at(0)); - for (int e = 1; e < block.width(); e++) { - REQUIRE_TRUE(firstType == ArrayOptions::dataType(inputShape->at(1)), 0, "BatchedGemm: all inputs must have same data type"); - } - - auto shapeList = SHAPELIST(); - - if (!(M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && batchSize > 0)) { - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(0)), 'c', {1, 1})); - return shapeList; - } - + int transA = INT_ARG(0); + int transB = INT_ARG(1); + int M = INT_ARG(2); + int N = INT_ARG(3); + int K = INT_ARG(4); + int ldA = INT_ARG(5); + int ldB = INT_ARG(6); + int ldC = INT_ARG(7); + int batchSize = INT_ARG(8); + + auto firstType = ArrayOptions::dataType(inputShape->at(0)); + for (int e = 1; e < block.width(); e++) { + REQUIRE_TRUE(firstType == ArrayOptions::dataType(inputShape->at(1)), 0, + "BatchedGemm: all inputs must have same data type"); + } + + auto shapeList = SHAPELIST(); + + if (!(M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && + batchSize > 0)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inputShape->at(0)), 'c', {1, 1})); + return shapeList; + } - std::vector shape({M, N}); + std::vector shape({M, N}); - for (int e = 0; e < batchSize; e++) { - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(0)), 'f', shape); - shapeList->push_back(newShape); - } + for (int e = 0; e < batchSize; e++) { + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inputShape->at(0)), 'f', shape); + shapeList->push_back(newShape); + } - return shapeList; + return shapeList; } DECLARE_TYPES(batched_gemm) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) -// ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + // ->setAllowedInputTypes(1, {DataType::FLOAT32, + // DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes({ALL_FLOATS}); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/blas/matmul.cpp b/libnd4j/include/ops/declarable/generic/blas/matmul.cpp index c7fad3b3affd..100d37d7b8c7 100644 --- a/libnd4j/include/ops/declarable/generic/blas/matmul.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/matmul.cpp @@ -23,67 +23,100 @@ #include #if NOT_EXCLUDED(OP_matmul) -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(matmul, 2, 1, false, 0, -2) { - - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - int iSize = (int) block.numI(); - int transX = iSize > 0 ? INT_ARG(0) : 0; - int transY = iSize > 1 ? INT_ARG(1) : 0; - const int transZ = iSize > 2 ? INT_ARG(2) : 0; - // optional use alpha nad beta - iSize = (int)block.numT(); - double alpha = iSize > 0 ? T_ARG(0) : 1.0; - double beta = iSize > 1 ? T_ARG(1) : 0.0; - - const int xRank = x->rankOf(); - const int yRank = y->rankOf(); - const int zRank = z->rankOf(); - - if (transZ) { - x = INPUT_VARIABLE(1); - y = INPUT_VARIABLE(0); - bool temp = transX; - transX = !transY; - transY = !temp; - } - - const int xLastDim = transX ? -2 : -1; - const int yLastDim = transY ? -2 : -1; - const int xLastButOneDim = transX ? -1 : -2; - const int yLastButOneDim = transY ? -1 : -2; - - // ******* input validation ******* // - REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", xRank, yRank); - - if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1) - REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "MATMUL OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", x->lengthOf(), y->lengthOf()); - } else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector - REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, "MATMUL OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - } else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector - REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, "MATMUL OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - } else { - REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, "MATMUL OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", xRank, yRank, zRank); - REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str()); - - if (xRank > 2) // outer dims must be the same - for (int i = 0; i < xRank - 2; ++i) - REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, "MATMUL OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str()); - } - // ******* end of input validation ******* // - - MmulHelper::matmul(x, y, z, transX, transY, alpha, beta); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + int iSize = (int)block.numI(); + int transX = iSize > 0 ? INT_ARG(0) : 0; + int transY = iSize > 1 ? INT_ARG(1) : 0; + const int transZ = iSize > 2 ? INT_ARG(2) : 0; + // optional use alpha nad beta + iSize = (int)block.numT(); + double alpha = iSize > 0 ? T_ARG(0) : 1.0; + double beta = iSize > 1 ? T_ARG(1) : 0.0; + + const int xRank = x->rankOf(); + const int yRank = y->rankOf(); + const int zRank = z->rankOf(); + + if (transZ) { + x = INPUT_VARIABLE(1); + y = INPUT_VARIABLE(0); + bool temp = transX; + transX = !transY; + transY = !temp; + } + + const int xLastDim = transX ? -2 : -1; + const int yLastDim = transY ? -2 : -1; + const int xLastButOneDim = transX ? -1 : -2; + const int yLastButOneDim = transY ? -1 : -2; + + // ******* input validation ******* // + REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, + "MATMUL OP: input arrays must have rank bigger than 0 (should " + "not be scalars), but got instead: x rank = %i, y rank = %i !", + xRank, yRank); + + if (xRank == 1 && + yRank == 1) { // dot case, output is scalar (or vector with length = 1) + REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, + "MATMUL OP: since input arrays are vectors they must have the " + "same length, but got x length = %i, y length = %i !", + x->lengthOf(), y->lengthOf()); + } else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = + // [5], output is vector + REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, + "MATMUL OP: input arrays have inconsistent shapes for " + "vector-matrix product: x %s, y %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + } else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] + // = [4], output is vector + REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, + "MATMUL OP: input arrays have inconsistent shapes for " + "matrix-vector product: x %s, y %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + } else { + REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, + "MATMUL OP: input and output arrays must have the same rank, " + "but got instead: x rank = %i, y rank = %i, z rank = %i !", + xRank, yRank, zRank); + REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && + x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && + y->sizeAt(yLastDim) == z->sizeAt(-1), + 0, + "MATMUL OP: input/output arrays have inconsistent shapes for " + "matrix product: x %s, y %s, z %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str(), + ShapeUtils::shapeAsString(z).c_str()); + + if (xRank > 2) // outer dims must be the same + for (int i = 0; i < xRank - 2; ++i) + REQUIRE_TRUE( + x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, + "MATMUL OP: input/output arrays have inconsistent shapes for " + "matrix product: x %s, y %s, z %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str(), + ShapeUtils::shapeAsString(z).c_str()); + } + // ******* end of input validation ******* // + + MmulHelper::matmul(x, y, z, transX, transY, alpha, beta); + + return Status::OK(); } DECLARE_SYN(mMul, matmul); @@ -98,111 +131,113 @@ DECLARE_SYN(dot, matmul); ////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(matmul) { - - auto xShapeInfo = inputShape->at(0); - auto yShapeInfo = inputShape->at(1); - - const int iSize = (int) block.numI(); - int transX = iSize > 0 ? INT_ARG(0) : 0; - int transY = iSize > 1 ? INT_ARG(1) : 0; - const int transZ = iSize > 2 ? INT_ARG(2) : 0; - - REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0, - "MATMUL OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", - xShapeInfo[0], yShapeInfo[0]); - - if (transZ) { - xShapeInfo = inputShape->at(1); - yShapeInfo = inputShape->at(0); - bool temp = transX; - transX = !transY; - transY = !temp; - } - - auto zShapeOnly = ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY); - - auto dtypeX = ArrayOptions::dataType(xShapeInfo); - auto dtypeY = ArrayOptions::dataType(yShapeInfo); - - auto xOrder = shape::order(xShapeInfo); - auto yOrder = shape::order(yShapeInfo); - auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f'; - - // we just pick the higher data type out of X and Y - auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY; - - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtypeZ, zOrder, zShapeOnly); - return SHAPELIST(newShape); + auto xShapeInfo = inputShape->at(0); + auto yShapeInfo = inputShape->at(1); + + const int iSize = (int)block.numI(); + int transX = iSize > 0 ? INT_ARG(0) : 0; + int transY = iSize > 1 ? INT_ARG(1) : 0; + const int transZ = iSize > 2 ? INT_ARG(2) : 0; + + REQUIRE_TRUE(xShapeInfo[0] > 0 && yShapeInfo[0] > 0, 0, + "MATMUL OP: input arrays must have rank bigger than 0 (should " + "not be scalars), but got instead: x rank = %i, y rank = %i !", + xShapeInfo[0], yShapeInfo[0]); + + if (transZ) { + xShapeInfo = inputShape->at(1); + yShapeInfo = inputShape->at(0); + bool temp = transX; + transX = !transY; + transY = !temp; + } + + auto zShapeOnly = + ShapeUtils::evalShapeForMatmul(xShapeInfo, yShapeInfo, transX, transY); + + auto dtypeX = ArrayOptions::dataType(xShapeInfo); + auto dtypeY = ArrayOptions::dataType(yShapeInfo); + + auto xOrder = shape::order(xShapeInfo); + auto yOrder = shape::order(yShapeInfo); + auto zOrder = xOrder == 'c' && yOrder == 'c' ? 'c' : 'f'; + + // we just pick the higher data type out of X and Y + auto dtypeZ = dtypeX > dtypeY ? dtypeX : dtypeY; + + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + dtypeZ, zOrder, zShapeOnly); + return SHAPELIST(newShape); } ////////////////////////////////////////////////////////////////////// DECLARE_TYPES(matmul) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}); + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}); } ////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(matmul_bp, 3, 2, false, 0, -2) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto eps = INPUT_VARIABLE(2); - auto dldx = OUTPUT_VARIABLE(0); - auto dldy = OUTPUT_VARIABLE(1); - - int iSize = (int) block.numI(); - int transX = iSize > 0 ? INT_ARG(0) : 0; - int transY = iSize > 1 ? INT_ARG(1) : 0; - const int transZ = iSize > 2 ? INT_ARG(2) : 0; - - // optional use alpha nad beta - iSize = (int) block.numT(); - - double alpha = iSize > 0 ? T_ARG(0) : 1.0; - double beta = iSize > 1 ? T_ARG(1) : 0.0; - -/* -In: x=[a,b], y=[b,c] -tX tY tZ x y z dz dLdx dLdy -F F F [a,b] [b,c] [a,c] [a,c] [a,c]*[b,c]T = [a,b] x*yT [a,b]T*[a,c] = [b,c] xT*y -T F F [b,a] [b,c] [a,c] [a,c] ([a,c]*[b,c]T)T = [b,a] (x*yT)T [b,a]*[a,c] = [b,c] x*y -F T F [a,b] [c,b] [a,c] [a,c] ([a,c]*[c,b]) = [a,b] x*y [a,b]T*[a,c] = [b,c] ->T xT*y -T T F [b,a] [c,b] [a,c] [a,c] ([a,c]*[c,b])T = [b,a] (x*y)T [b,a]*[a,c] = [b,c] ->T x*y -F F T [a,b] [b,c] [c,a] [c,a] -*/ - - - sd::ops::matmul op; - op.execute({eps, y}, {dldx}, {alpha, beta}, {transZ, !transY, transX}, {}); - op.execute({x, eps}, {dldy}, {alpha, beta}, {!transX, transZ, transY}, {}); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto eps = INPUT_VARIABLE(2); + auto dldx = OUTPUT_VARIABLE(0); + auto dldy = OUTPUT_VARIABLE(1); + + int iSize = (int)block.numI(); + int transX = iSize > 0 ? INT_ARG(0) : 0; + int transY = iSize > 1 ? INT_ARG(1) : 0; + const int transZ = iSize > 2 ? INT_ARG(2) : 0; + + // optional use alpha nad beta + iSize = (int)block.numT(); + + double alpha = iSize > 0 ? T_ARG(0) : 1.0; + double beta = iSize > 1 ? T_ARG(1) : 0.0; + + /* + In: x=[a,b], y=[b,c] + tX tY tZ x y z dz dLdx dLdy F F F [a,b] + [b,c] [a,c] [a,c] [a,c]*[b,c]T = [a,b] x*yT [a,b]T*[a,c] = + [b,c] xT*y T F F [b,a] [b,c] [a,c] [a,c] ([a,c]*[b,c]T)T = + [b,a] (x*yT)T [b,a]*[a,c] = [b,c] x*y F T F [a,b] [c,b] + [a,c] [a,c] ([a,c]*[c,b]) = [a,b] x*y [a,b]T*[a,c] = + [b,c] ->T xT*y T T F [b,a] [c,b] [a,c] [a,c] ([a,c]*[c,b])T = + [b,a] (x*y)T [b,a]*[a,c] = [b,c] ->T x*y F F T [a,b] [b,c] + [c,a] [c,a] + */ + + sd::ops::matmul op; + op.execute({eps, y}, {dldx}, {alpha, beta}, {transZ, !transY, transX}, {}); + op.execute({x, eps}, {dldy}, {alpha, beta}, {!transX, transZ, transY}, {}); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(matmul_bp) { - Nd4jLong *xShapeInfo; - Nd4jLong *yShapeInfo; + Nd4jLong *xShapeInfo; + Nd4jLong *yShapeInfo; - COPY_SHAPE(inputShape->at(0), xShapeInfo); - COPY_SHAPE(inputShape->at(1), yShapeInfo); + COPY_SHAPE(inputShape->at(0), xShapeInfo); + COPY_SHAPE(inputShape->at(1), yShapeInfo); - return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo)); + return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(yShapeInfo)); } ////////////////////////////////////////////////////////////////////// DECLARE_TYPES(matmul_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_FLOATS}); -} - -} + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/blas/svd.cpp b/libnd4j/include/ops/declarable/generic/blas/svd.cpp index 1c316cf4a91a..17e3b8dcf39d 100644 --- a/libnd4j/include/ops/declarable/generic/blas/svd.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/svd.cpp @@ -25,95 +25,106 @@ #include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(svd, 1, 1, false, 0, 3) { - auto x = INPUT_VARIABLE(0); + auto x = INPUT_VARIABLE(0); - const int rank = x->rankOf(); - REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank); + const int rank = x->rankOf(); + REQUIRE_TRUE( + rank >= 2, 0, + "SVD OP: the rank of input array must be >=2, but got %i instead!", rank); - bool fullUV = (bool)INT_ARG(0); - const bool calcUV = (bool)INT_ARG(1); + bool fullUV = (bool)INT_ARG(0); + const bool calcUV = (bool)INT_ARG(1); - if(calcUV == false) - fullUV = false; + if (calcUV == false) fullUV = false; - const int switchNum = INT_ARG(2); + const int switchNum = INT_ARG(2); - // #ifndef __CUDABLAS__ - helpers::svd(block.launchContext(), x, {OUTPUT_VARIABLE(0), calcUV ? OUTPUT_VARIABLE(1) : nullptr, calcUV ? OUTPUT_VARIABLE(2) : nullptr}, fullUV, calcUV, switchNum); - // #endif + // #ifndef __CUDABLAS__ + helpers::svd(block.launchContext(), x, + {OUTPUT_VARIABLE(0), calcUV ? OUTPUT_VARIABLE(1) : nullptr, + calcUV ? OUTPUT_VARIABLE(2) : nullptr}, + fullUV, calcUV, switchNum); + // #endif - return Status::OK();; + return Status::OK(); + ; } - - DECLARE_TYPES(svd) { - getOpDescriptor() - ->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setSameMode(true); - } +DECLARE_TYPES(svd) { + getOpDescriptor() + ->setAllowedInputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setSameMode(true); +} DECLARE_SHAPE_FN(svd) { - - auto inShapeInfo = inputShape->at(0); - bool fullUV = (bool)INT_ARG(0); - bool calcUV = (bool)INT_ARG(1); - - const int rank = inShapeInfo[0]; - REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank); - - const int diagSize = inShapeInfo[rank] < inShapeInfo[rank-1] ? inShapeInfo[rank] : inShapeInfo[rank-1]; - - Nd4jLong* sShapeInfo(nullptr); - if(rank == 2) { - ALLOCATE(sShapeInfo, block.workspace(), shape::shapeInfoLength(1), Nd4jLong); - sShapeInfo[0] = 1; - sShapeInfo[1] = diagSize; - } - else { - ALLOCATE(sShapeInfo, block.workspace(), shape::shapeInfoLength(rank-1), Nd4jLong); - sShapeInfo[0] = rank - 1; - for(int i=1; i <= rank-2; ++i) - sShapeInfo[i] = inShapeInfo[i]; - sShapeInfo[rank-1] = diagSize; + auto inShapeInfo = inputShape->at(0); + bool fullUV = (bool)INT_ARG(0); + bool calcUV = (bool)INT_ARG(1); + + const int rank = inShapeInfo[0]; + REQUIRE_TRUE( + rank >= 2, 0, + "SVD OP: the rank of input array must be >=2, but got %i instead!", rank); + + const int diagSize = inShapeInfo[rank] < inShapeInfo[rank - 1] + ? inShapeInfo[rank] + : inShapeInfo[rank - 1]; + + Nd4jLong* sShapeInfo(nullptr); + if (rank == 2) { + ALLOCATE(sShapeInfo, block.workspace(), shape::shapeInfoLength(1), + Nd4jLong); + sShapeInfo[0] = 1; + sShapeInfo[1] = diagSize; + } else { + ALLOCATE(sShapeInfo, block.workspace(), shape::shapeInfoLength(rank - 1), + Nd4jLong); + sShapeInfo[0] = rank - 1; + for (int i = 1; i <= rank - 2; ++i) sShapeInfo[i] = inShapeInfo[i]; + sShapeInfo[rank - 1] = diagSize; + } + + ShapeUtils::updateStridesAndType(sShapeInfo, inShapeInfo, + shape::order(inShapeInfo)); + + if (calcUV) { + Nd4jLong *uShapeInfo(nullptr), *vShapeInfo(nullptr); + COPY_SHAPE(inShapeInfo, uShapeInfo); + COPY_SHAPE(inShapeInfo, vShapeInfo); + + if (fullUV) { + uShapeInfo[rank] = uShapeInfo[rank - 1]; + vShapeInfo[rank - 1] = vShapeInfo[rank]; + } else { + uShapeInfo[rank] = diagSize; + vShapeInfo[rank - 1] = vShapeInfo[rank]; + vShapeInfo[rank] = diagSize; } - ShapeUtils::updateStridesAndType(sShapeInfo, inShapeInfo, shape::order(inShapeInfo)); - - if(calcUV){ - - Nd4jLong *uShapeInfo(nullptr), *vShapeInfo(nullptr); - COPY_SHAPE(inShapeInfo, uShapeInfo); - COPY_SHAPE(inShapeInfo, vShapeInfo); - - if(fullUV) { - uShapeInfo[rank] = uShapeInfo[rank-1]; - vShapeInfo[rank-1] = vShapeInfo[rank]; - } - else { - uShapeInfo[rank] = diagSize; - vShapeInfo[rank-1] = vShapeInfo[rank]; - vShapeInfo[rank] = diagSize; - } - - shape::updateStrides(uShapeInfo, shape::order(inShapeInfo)); - shape::updateStrides(vShapeInfo, shape::order(inShapeInfo)); - - auto result = SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(sShapeInfo)), ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(uShapeInfo)), ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(vShapeInfo))); - RELEASE(sShapeInfo, block.workspace()); - RELEASE(uShapeInfo, block.workspace()); - RELEASE(vShapeInfo, block.workspace()); - return result; - } - - return SHAPELIST(ConstantShapeHelper::getInstance()->createFromExisting(sShapeInfo, block.workspace())); + shape::updateStrides(uShapeInfo, shape::order(inShapeInfo)); + shape::updateStrides(vShapeInfo, shape::order(inShapeInfo)); + + auto result = SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(sShapeInfo)), + ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(uShapeInfo)), + ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(vShapeInfo))); + RELEASE(sShapeInfo, block.workspace()); + RELEASE(uShapeInfo, block.workspace()); + RELEASE(vShapeInfo, block.workspace()); + return result; + } + + return SHAPELIST(ConstantShapeHelper::getInstance()->createFromExisting( + sShapeInfo, block.workspace())); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp b/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp index 889bd495732b..945dc4758bfd 100644 --- a/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp @@ -21,174 +21,188 @@ #include #if NOT_EXCLUDED(OP_tensormmul) -#include +#include #include #include -#include +#include namespace sd { -namespace ops { +namespace ops { //////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) { + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - - auto c = OUTPUT_VARIABLE(0); + auto c = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same"); + REQUIRE_TRUE(a->dataType() == b->dataType(), 0, + "tensormmul: A, B and C data types must be the same"); - // building axes - int axe0_size = INT_ARG(0); - int axe1_size = INT_ARG(axe0_size+1); - std::vector axes_0(axe0_size), axes_1(axe1_size); - for (int e = 0; e < axe0_size; e++) - axes_0[e] = (int)INT_ARG(e + 1); + // building axes + int axe0_size = INT_ARG(0); + int axe1_size = INT_ARG(axe0_size + 1); + std::vector axes_0(axe0_size), axes_1(axe1_size); + for (int e = 0; e < axe0_size; e++) axes_0[e] = (int)INT_ARG(e + 1); - for (int e = 0; e < axe1_size; e++) - axes_1[e] = (int)INT_ARG(e + axe0_size + 2); + for (int e = 0; e < axe1_size; e++) + axes_1[e] = (int)INT_ARG(e + axe0_size + 2); - nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size()); + nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size()); - MmulHelper::tensorDot(a, b, c, axes_0, axes_1); - return Status::OK(); + MmulHelper::tensorDot(a, b, c, axes_0, axes_1); + return Status::OK(); } DECLARE_SYN(tensordot, tensormmul); //////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(tensormmul) { - - auto aShapeInfo = inputShape->at(0); - auto bShapeInfo = inputShape->at(1); - - REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same"); - - // building axes - int axe0_size = INT_ARG(0); - int axe1_size = INT_ARG(axe0_size+1); - std::vector axes_0(axe0_size), axes_1(axe1_size); - for (int e = 0; e < axe0_size; e++) - axes_0[e] = (int) INT_ARG(e+1); - - for (int e = 0; e < axe1_size; e++) - axes_1[e] = (int) INT_ARG(e + axe0_size + 2); - - // evaluate shapes - std::vector permutAt, permutBt; - std::vector shapeAt, shapeBt; - auto outShape = sd::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt); - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape))); + auto aShapeInfo = inputShape->at(0); + auto bShapeInfo = inputShape->at(1); + + REQUIRE_TRUE( + ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), + 0, "tensormmul: A and B data types must be the same"); + + // building axes + int axe0_size = INT_ARG(0); + int axe1_size = INT_ARG(axe0_size + 1); + std::vector axes_0(axe0_size), axes_1(axe1_size); + for (int e = 0; e < axe0_size; e++) axes_0[e] = (int)INT_ARG(e + 1); + + for (int e = 0; e < axe1_size; e++) + axes_1[e] = (int)INT_ARG(e + axe0_size + 2); + + // evaluate shapes + std::vector permutAt, permutBt; + std::vector shapeAt, shapeBt; + auto outShape = sd::ShapeUtils::evalShapeForTensorDot( + aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, + shapeBt); + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape))); } //////////////////////////////////////////////////////////////////////// DECLARE_TYPES(tensormmul) { - getOpDescriptor() - ->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); + getOpDescriptor() + ->setAllowedInputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } //////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) { + auto A = INPUT_VARIABLE(0); + auto B = INPUT_VARIABLE(1); - auto A = INPUT_VARIABLE(0); - auto B = INPUT_VARIABLE(1); - - auto dLdC = INPUT_VARIABLE(2); + auto dLdC = INPUT_VARIABLE(2); - auto dLdA = OUTPUT_VARIABLE(0); - auto dLdB = OUTPUT_VARIABLE(1); + auto dLdA = OUTPUT_VARIABLE(0); + auto dLdB = OUTPUT_VARIABLE(1); - REQUIRE_TRUE( (A->dataType() == B->dataType() && (dLdC->dataType() == A->dataType())), 0, "tensormmul_bp: A, B and dLdC data types must be the same"); + REQUIRE_TRUE( + (A->dataType() == B->dataType() && (dLdC->dataType() == A->dataType())), + 0, "tensormmul_bp: A, B and dLdC data types must be the same"); - int axe0Size = INT_ARG(0); - int axe1Size = INT_ARG(axe0Size + 1); + int axe0Size = INT_ARG(0); + int axe1Size = INT_ARG(axe0Size + 1); - auto Arank = A->rankOf(); - auto Brank = B->rankOf(); - auto dLdCrank = dLdC->rankOf(); + auto Arank = A->rankOf(); + auto Brank = B->rankOf(); + auto dLdCrank = dLdC->rankOf(); - REQUIRE_TRUE((Arank >= axe0Size), 0, "tensormmul_bp: A rank must be the higher or same as input axes 0"); + REQUIRE_TRUE( + (Arank >= axe0Size), 0, + "tensormmul_bp: A rank must be the higher or same as input axes 0"); - REQUIRE_TRUE((Brank >= axe1Size), 0, "tensormmul_bp: B rank must be the higher or same as input axes 1"); + REQUIRE_TRUE( + (Brank >= axe1Size), 0, + "tensormmul_bp: B rank must be the higher or same as input axes 1"); - // building axes - std::vector axes0(axe0Size), axes1(axe1Size); - for (uint e = 0; e < axe0Size; e++) - axes0[e] = (int)INT_ARG(e + 1); - for (uint e = 0; e < axe1Size; e++) - axes1[e] = (int)INT_ARG(e + axe0Size + 2); + // building axes + std::vector axes0(axe0Size), axes1(axe1Size); + for (uint e = 0; e < axe0Size; e++) axes0[e] = (int)INT_ARG(e + 1); + for (uint e = 0; e < axe1Size; e++) axes1[e] = (int)INT_ARG(e + axe0Size + 2); - std::vector permutAt, permutBt; - std::vector shapeAt, shapeBt; + std::vector permutAt, permutBt; + std::vector shapeAt, shapeBt; - ShapeUtils::evalShapeForTensorDot(A, B, axes0, axes1, permutAt, permutBt, shapeAt, shapeBt); + ShapeUtils::evalShapeForTensorDot(A, B, axes0, axes1, permutAt, permutBt, + shapeAt, shapeBt); - // special case for scalar value - if (dLdC->isScalar()) { + // special case for scalar value + if (dLdC->isScalar()) { + dLdA->assign((*dLdC) * *B); + dLdB->assign((*dLdC) * *A); - dLdA->assign((*dLdC) * *B); - dLdB->assign((*dLdC) * *A); - - return Status::OK(); - } + return Status::OK(); + } - std::vector axesA = ShapeUtils::evalDimsToExclude(Arank, axes0); - std::vector axesB = ShapeUtils::evalDimsToExclude(Brank, axes1); + std::vector axesA = ShapeUtils::evalDimsToExclude(Arank, axes0); + std::vector axesB = ShapeUtils::evalDimsToExclude(Brank, axes1); - // rank always have to be divided by 2 - std::vector axesAdLdC, axesBdLdC; - if (dLdCrank > 1) { - axesAdLdC.resize(dLdCrank / 2); - std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0); - axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC); - } - else { - axesAdLdC.push_back(0); - axesBdLdC.push_back(0); - } + // rank always have to be divided by 2 + std::vector axesAdLdC, axesBdLdC; + if (dLdCrank > 1) { + axesAdLdC.resize(dLdCrank / 2); + std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0); + axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC); + } else { + axesAdLdC.push_back(0); + axesBdLdC.push_back(0); + } - // calculate dLdA - MmulHelper::tensorDot(dLdC, B, dLdA, axesBdLdC, axesB, permutAt); + // calculate dLdA + MmulHelper::tensorDot(dLdC, B, dLdA, axesBdLdC, axesB, permutAt); - // calculate dLdB - MmulHelper::tensorDot(A, dLdC, dLdB, axesA, axesAdLdC, permutBt); + // calculate dLdB + MmulHelper::tensorDot(A, dLdC, dLdB, axesA, axesAdLdC, permutBt); - return Status::OK(); + return Status::OK(); } //////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(tensormmul_bp) { + auto aShapeInfo = inputShape->at(0); + auto bShapeInfo = inputShape->at(1); + auto dLShapeInfo = inputShape->at(2); - auto aShapeInfo = inputShape->at(0); - auto bShapeInfo = inputShape->at(1); - auto dLShapeInfo = inputShape->at(2); - - REQUIRE_TRUE((ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo) && - (ArrayOptions::dataType(dLShapeInfo) == ArrayOptions::dataType(aShapeInfo))), 0, "tensormmul_bp: A, B and dLdC data types must be the same"); + REQUIRE_TRUE((ArrayOptions::dataType(aShapeInfo) == + ArrayOptions::dataType(bShapeInfo) && + (ArrayOptions::dataType(dLShapeInfo) == + ArrayOptions::dataType(aShapeInfo))), + 0, "tensormmul_bp: A, B and dLdC data types must be the same"); - Nd4jLong* dLdAShapeInfo = nullptr; - Nd4jLong* dLdBShapeInfo = nullptr; + Nd4jLong* dLdAShapeInfo = nullptr; + Nd4jLong* dLdBShapeInfo = nullptr; - COPY_SHAPE(aShapeInfo, dLdAShapeInfo); - COPY_SHAPE(bShapeInfo, dLdBShapeInfo); + COPY_SHAPE(aShapeInfo, dLdAShapeInfo); + COPY_SHAPE(bShapeInfo, dLdBShapeInfo); - return SHAPELIST(CONSTANT(dLdAShapeInfo), CONSTANT(dLdBShapeInfo)); + return SHAPELIST(CONSTANT(dLdAShapeInfo), CONSTANT(dLdBShapeInfo)); } //////////////////////////////////////////////////////////////////////// DECLARE_TYPES(tensormmul_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) // maybe better ALL_FLOATS - ->setAllowedInputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) - ->setAllowedInputTypes(2, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) - ->setAllowedOutputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) - ->setAllowedOutputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }); -} -} + getOpDescriptor() + ->setAllowedInputTypes(0, {DataType::FLOAT32, DataType::DOUBLE, + DataType::HALF}) // maybe better ALL_FLOATS + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType::DOUBLE, DataType::HALF}) + ->setAllowedInputTypes( + 2, {DataType::FLOAT32, DataType::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 1, {DataType::FLOAT32, DataType::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp b/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp index 56047f16c9ec..5cefb3eddecc 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/boolean_not.cpp @@ -24,22 +24,22 @@ #include namespace sd { - namespace ops { - OP_IMPL(boolean_not, 1, 1,true) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - x->applyTransform(transform::Not, *z); - - return Status::OK(); - } - - DECLARE_TYPES(boolean_not) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::BOOL) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +namespace ops { +OP_IMPL(boolean_not, 1, 1, true) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + x->applyTransform(transform::Not, *z); + + return Status::OK(); +} + +DECLARE_TYPES(boolean_not) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::BOOL) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/choose.cpp b/libnd4j/include/ops/declarable/generic/boolean/choose.cpp index e5d67baf1861..26ef481bcaee 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/choose.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/choose.cpp @@ -26,74 +26,75 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(choose, -1, 2, false, -2, -1) { - - int mode = INT_ARG(0); - auto result = OUTPUT_VARIABLE(0); - auto numResults = OUTPUT_VARIABLE(1); - - if (block.width() > 1) { - auto arg = INPUT_VARIABLE(0); - auto comp = INPUT_VARIABLE(1); - - helpers::chooseFunctorArray(block.launchContext(), arg, comp, mode, result, numResults); - - }//scalar case - else { - double scalar = T_ARG(0); - auto arg = INPUT_VARIABLE(0); - helpers::chooseFunctorScalar(block.launchContext(), arg, scalar, mode, result, numResults); - } - - - return Status::OK(); - } - - DECLARE_TYPES(choose) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}); - } - - DECLARE_SHAPE_FN(choose) { - Nd4jLong const* shape; - int rank; - int mode = INT_ARG(0); - auto numResults = NDArrayFactory::create(0L); - if(block.width() > 1) { - auto first = INPUT_VARIABLE(0); - auto second = INPUT_VARIABLE(1); - if(first->lengthOf() > second->lengthOf()) { - shape = first->shapeInfo(); - rank = first->rankOf(); - } - else { - shape = second->shapeInfo(); - rank = second->rankOf(); - } - - helpers::chooseFunctorArray(block.launchContext(), first, second, mode, nullptr, &numResults); - } - else { - auto first = INPUT_VARIABLE(0); - shape = first->shapeInfo(); - rank = first->rankOf(); - double scalar = T_ARG(0); - - helpers::chooseFunctorScalar(block.launchContext(), first, scalar, mode, nullptr, &numResults); - } - - auto newShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(numResults.e(0), ArrayOptions::dataType(inputShape->at(0))); - - auto shapeScalar = ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT64); - return SHAPELIST(newShape, shapeScalar); - } +namespace ops { +CUSTOM_OP_IMPL(choose, -1, 2, false, -2, -1) { + int mode = INT_ARG(0); + auto result = OUTPUT_VARIABLE(0); + auto numResults = OUTPUT_VARIABLE(1); + + if (block.width() > 1) { + auto arg = INPUT_VARIABLE(0); + auto comp = INPUT_VARIABLE(1); + + helpers::chooseFunctorArray(block.launchContext(), arg, comp, mode, result, + numResults); + + } // scalar case + else { + double scalar = T_ARG(0); + auto arg = INPUT_VARIABLE(0); + helpers::chooseFunctorScalar(block.launchContext(), arg, scalar, mode, + result, numResults); + } + + return Status::OK(); +} +DECLARE_TYPES(choose) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}); +} +DECLARE_SHAPE_FN(choose) { + Nd4jLong const* shape; + int rank; + int mode = INT_ARG(0); + auto numResults = NDArrayFactory::create(0L); + if (block.width() > 1) { + auto first = INPUT_VARIABLE(0); + auto second = INPUT_VARIABLE(1); + if (first->lengthOf() > second->lengthOf()) { + shape = first->shapeInfo(); + rank = first->rankOf(); + } else { + shape = second->shapeInfo(); + rank = second->rankOf(); } + + helpers::chooseFunctorArray(block.launchContext(), first, second, mode, + nullptr, &numResults); + } else { + auto first = INPUT_VARIABLE(0); + shape = first->shapeInfo(); + rank = first->rankOf(); + double scalar = T_ARG(0); + + helpers::chooseFunctorScalar(block.launchContext(), first, scalar, mode, + nullptr, &numResults); + } + + auto newShape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + numResults.e(0), ArrayOptions::dataType(inputShape->at(0))); + + auto shapeScalar = + ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT64); + return SHAPELIST(newShape, shapeScalar); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/eq_scalar.cpp b/libnd4j/include/ops/declarable/generic/boolean/eq_scalar.cpp index c0623ebb70a0..4c733cd8a00d 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/eq_scalar.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/eq_scalar.cpp @@ -24,26 +24,26 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(eq_scalar, 2, true) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); +namespace ops { +BOOLEAN_OP_IMPL(eq_scalar, 2, true) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); - if (x->e(0) == y->e(0)) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - DECLARE_SYN(Equals, eq_scalar); - //DECLARE_SYN(equals, eq_scalar); + if (x->e(0) == y->e(0)) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} +DECLARE_SYN(Equals, eq_scalar); +// DECLARE_SYN(equals, eq_scalar); - DECLARE_TYPES(eq_scalar) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +DECLARE_TYPES(eq_scalar) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/gt_scalar.cpp b/libnd4j/include/ops/declarable/generic/boolean/gt_scalar.cpp index d40c501d4f93..8d97d33acf75 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/gt_scalar.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/gt_scalar.cpp @@ -24,26 +24,26 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(gt_scalar, 2, true) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); +namespace ops { +BOOLEAN_OP_IMPL(gt_scalar, 2, true) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); - if (x->e(0) > y->e(0)) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - //DECLARE_SYN(Greater, gt_scalar); - //DECLARE_SYN(greater, gt_scalar); + if (x->e(0) > y->e(0)) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} +// DECLARE_SYN(Greater, gt_scalar); +// DECLARE_SYN(greater, gt_scalar); - DECLARE_TYPES(gt_scalar) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +DECLARE_TYPES(gt_scalar) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/gte_scalar.cpp b/libnd4j/include/ops/declarable/generic/boolean/gte_scalar.cpp index d555f5d24437..ce854e4bbc42 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/gte_scalar.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/gte_scalar.cpp @@ -24,26 +24,26 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(gte_scalar, 2, true) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); +namespace ops { +BOOLEAN_OP_IMPL(gte_scalar, 2, true) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); - if (x->e(0) >= y->e(0)) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - DECLARE_SYN(GreaterOrEquals, gte_scalar); - DECLARE_SYN(greaterOrEquals, gte_scalar); + if (x->e(0) >= y->e(0)) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} +DECLARE_SYN(GreaterOrEquals, gte_scalar); +DECLARE_SYN(greaterOrEquals, gte_scalar); - DECLARE_TYPES(gte_scalar) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +DECLARE_TYPES(gte_scalar) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/is_non_decreasing.cpp b/libnd4j/include/ops/declarable/generic/boolean/is_non_decreasing.cpp index 4dd8ca605b12..69dcc3b44b17 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/is_non_decreasing.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/is_non_decreasing.cpp @@ -25,30 +25,30 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(is_non_decreasing, 1, true) { - auto input = INPUT_VARIABLE(0); - - // in case of empty input there's nothing to do - if (input->isEmpty()) - return ND4J_STATUS_TRUE; - - bool isNonDecreasing = true; - - sd::ops::helpers::compare_elem(block.launchContext(), input, false, isNonDecreasing); - - if (isNonDecreasing) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - - DECLARE_TYPES(is_non_decreasing) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +namespace ops { +BOOLEAN_OP_IMPL(is_non_decreasing, 1, true) { + auto input = INPUT_VARIABLE(0); + + // in case of empty input there's nothing to do + if (input->isEmpty()) return ND4J_STATUS_TRUE; + + bool isNonDecreasing = true; + + sd::ops::helpers::compare_elem(block.launchContext(), input, false, + isNonDecreasing); + + if (isNonDecreasing) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} + +DECLARE_TYPES(is_non_decreasing) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/is_numeric_tensor.cpp b/libnd4j/include/ops/declarable/generic/boolean/is_numeric_tensor.cpp index 184b7b0a6da9..585bd6cb7ffe 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/is_numeric_tensor.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/is_numeric_tensor.cpp @@ -25,20 +25,19 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(is_numeric_tensor, 1, true) { +namespace ops { +BOOLEAN_OP_IMPL(is_numeric_tensor, 1, true) { + auto input = INPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - - return input->isR() || input->isZ() ? ND4J_STATUS_TRUE : ND4J_STATUS_FALSE; - } + return input->isR() || input->isZ() ? ND4J_STATUS_TRUE : ND4J_STATUS_FALSE; +} - DECLARE_TYPES(is_numeric_tensor) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +DECLARE_TYPES(is_numeric_tensor) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/is_strictly_increasing.cpp b/libnd4j/include/ops/declarable/generic/boolean/is_strictly_increasing.cpp index 0c434cf571bd..580175ef2ba2 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/is_strictly_increasing.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/is_strictly_increasing.cpp @@ -25,30 +25,30 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(is_strictly_increasing, 1, true) { - auto input = INPUT_VARIABLE(0); - - // in case of empty input there's nothing to do - if (input->isEmpty()) - return ND4J_STATUS_TRUE; - - bool isStrictlyIncreasing = true; - - sd::ops::helpers::compare_elem(block.launchContext(), input, true, isStrictlyIncreasing); - - if (isStrictlyIncreasing) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - - DECLARE_TYPES(is_strictly_increasing) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +namespace ops { +BOOLEAN_OP_IMPL(is_strictly_increasing, 1, true) { + auto input = INPUT_VARIABLE(0); + + // in case of empty input there's nothing to do + if (input->isEmpty()) return ND4J_STATUS_TRUE; + + bool isStrictlyIncreasing = true; + + sd::ops::helpers::compare_elem(block.launchContext(), input, true, + isStrictlyIncreasing); + + if (isStrictlyIncreasing) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} + +DECLARE_TYPES(is_strictly_increasing) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/lt_scalar.cpp b/libnd4j/include/ops/declarable/generic/boolean/lt_scalar.cpp index 1c4f7ab2759b..a4dc91797fde 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/lt_scalar.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/lt_scalar.cpp @@ -24,26 +24,26 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(lt_scalar, 2, true) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); +namespace ops { +BOOLEAN_OP_IMPL(lt_scalar, 2, true) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); - if (x->e(0) < y->e(0)) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - //DECLARE_SYN(Less, lt_scalar); - //DECLARE_SYN(less, lt_scalar); + if (x->e(0) < y->e(0)) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} +// DECLARE_SYN(Less, lt_scalar); +// DECLARE_SYN(less, lt_scalar); - DECLARE_TYPES(lt_scalar) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +DECLARE_TYPES(lt_scalar) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/lte_scalar.cpp b/libnd4j/include/ops/declarable/generic/boolean/lte_scalar.cpp index 07a72cfedfdc..7fe63df45aa1 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/lte_scalar.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/lte_scalar.cpp @@ -24,26 +24,26 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(lte_scalar, 2, true) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); +namespace ops { +BOOLEAN_OP_IMPL(lte_scalar, 2, true) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); - if (x->e(0) <= y->e(0)) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - DECLARE_SYN(LessOrEquals, lte_scalar); - DECLARE_SYN(lessorequals, lte_scalar); + if (x->e(0) <= y->e(0)) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} +DECLARE_SYN(LessOrEquals, lte_scalar); +DECLARE_SYN(lessorequals, lte_scalar); - DECLARE_TYPES(lte_scalar) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +DECLARE_TYPES(lte_scalar) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/neq_scalar.cpp b/libnd4j/include/ops/declarable/generic/boolean/neq_scalar.cpp index 1c05b9fc13e4..6e9d8056a7f0 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/neq_scalar.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/neq_scalar.cpp @@ -24,26 +24,26 @@ #include namespace sd { - namespace ops { - BOOLEAN_OP_IMPL(neq_scalar, 2, true) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); +namespace ops { +BOOLEAN_OP_IMPL(neq_scalar, 2, true) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); - if (x->e(0) != y->e(0)) - return ND4J_STATUS_TRUE; - else - return ND4J_STATUS_FALSE; - } - DECLARE_SYN(NotEquals, neq_scalar); - DECLARE_SYN(notequals, neq_scalar); + if (x->e(0) != y->e(0)) + return ND4J_STATUS_TRUE; + else + return ND4J_STATUS_FALSE; +} +DECLARE_SYN(NotEquals, neq_scalar); +DECLARE_SYN(notequals, neq_scalar); - DECLARE_TYPES(neq_scalar) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } +DECLARE_TYPES(neq_scalar) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/select.cpp b/libnd4j/include/ops/declarable/generic/boolean/select.cpp index 1e3c9e1cc9ad..9088278d719c 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/select.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/select.cpp @@ -25,81 +25,88 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(select, 3, 1, false, 0, 0) { - auto cond = INPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(1); - auto y = INPUT_VARIABLE(2); - - REQUIRE_TRUE(x->isSameShape(y), 0, "Select: X and Y shape should be equal"); - if (x->isScalar()) { - REQUIRE_TRUE(cond->isScalar(), 0, - "Select: Condition should gave either equal shape to X/Y first dimension or to be scalar"); - - auto z = OUTPUT_VARIABLE(0); - - if (y->isR()) { - auto v = !cond->e(0)? y->e(0) : x->e(0); - z->p(0, v); - } else { - auto v = !cond->e(0)? y->e(0) : x->e(0); - z->p(0, v); - } - } else { - bool same = cond->isSameShape(x); - REQUIRE_TRUE(cond->isScalar() || cond->lengthOf() == x->sizeAt(0) || same, 0, "Select: Condition should gave either equal shape to X/Y first dimension or to be scalar"); - if (same) { - auto z = OUTPUT_VARIABLE(0); - - for (int e = 0; e < cond->lengthOf(); e++) { - if (y->isR()) { - auto r = !cond->e(e) ? y->e(e) : x->e(e); - z->p(e, r); - } else { - auto r = !cond->e(e) ? y->e(e) : x->e(e); - z->p(e, r); - } - } - } else { - REQUIRE_TRUE(cond->lengthOf() == x->sizeAt(0), 0, "Condition length should be equal to the dim0 of x/y to act as TAD-mask, but got %d instead", cond->lengthOf()); - - auto z = OUTPUT_VARIABLE(0); - - auto dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {0}); - auto tadsX = x->allTensorsAlongDimension(dims); - auto tadsY = y->allTensorsAlongDimension(dims); - auto tadsZ = z->allTensorsAlongDimension(dims); - - for (int e = 0; e < tadsX.size(); e++) { - if (!cond->e(e)) { - tadsZ.at(e).assign(tadsY.at(e)); - } else { - tadsZ.at(e).assign(tadsX.at(e)); - } - } - } - } - - return Status::OK(); +namespace ops { +CUSTOM_OP_IMPL(select, 3, 1, false, 0, 0) { + auto cond = INPUT_VARIABLE(0); + auto x = INPUT_VARIABLE(1); + auto y = INPUT_VARIABLE(2); + + REQUIRE_TRUE(x->isSameShape(y), 0, "Select: X and Y shape should be equal"); + if (x->isScalar()) { + REQUIRE_TRUE(cond->isScalar(), 0, + "Select: Condition should gave either equal shape to X/Y " + "first dimension or to be scalar"); + + auto z = OUTPUT_VARIABLE(0); + + if (y->isR()) { + auto v = !cond->e(0) ? y->e(0) : x->e(0); + z->p(0, v); + } else { + auto v = !cond->e(0) ? y->e(0) : x->e(0); + z->p(0, v); + } + } else { + bool same = cond->isSameShape(x); + REQUIRE_TRUE(cond->isScalar() || cond->lengthOf() == x->sizeAt(0) || same, + 0, + "Select: Condition should gave either equal shape to X/Y " + "first dimension or to be scalar"); + if (same) { + auto z = OUTPUT_VARIABLE(0); + + for (int e = 0; e < cond->lengthOf(); e++) { + if (y->isR()) { + auto r = !cond->e(e) ? y->e(e) : x->e(e); + z->p(e, r); + } else { + auto r = !cond->e(e) ? y->e(e) : x->e(e); + z->p(e, r); + } + } + } else { + REQUIRE_TRUE(cond->lengthOf() == x->sizeAt(0), 0, + "Condition length should be equal to the dim0 of x/y to act " + "as TAD-mask, but got %d instead", + cond->lengthOf()); + + auto z = OUTPUT_VARIABLE(0); + + auto dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {0}); + auto tadsX = x->allTensorsAlongDimension(dims); + auto tadsY = y->allTensorsAlongDimension(dims); + auto tadsZ = z->allTensorsAlongDimension(dims); + + for (int e = 0; e < tadsX.size(); e++) { + if (!cond->e(e)) { + tadsZ.at(e).assign(tadsY.at(e)); + } else { + tadsZ.at(e).assign(tadsX.at(e)); } + } + } + } - DECLARE_SHAPE_FN(select) { - auto inShape = inputShape->at(1); + return Status::OK(); +} - Nd4jLong *newshape; - COPY_SHAPE(inShape, newshape); +DECLARE_SHAPE_FN(select) { + auto inShape = inputShape->at(1); - return SHAPELIST(CONSTANT(newshape)); - } + Nd4jLong *newshape; + COPY_SHAPE(inShape, newshape); - DECLARE_TYPES(select) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::BOOL) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedInputTypes(2, DataType::ANY) - ->setAllowedOutputTypes(1, DataType::INHERIT); - } - } + return SHAPELIST(CONSTANT(newshape)); +} + +DECLARE_TYPES(select) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::BOOL) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedInputTypes(2, DataType::ANY) + ->setAllowedOutputTypes(1, DataType::INHERIT); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/where.cpp b/libnd4j/include/ops/declarable/generic/boolean/where.cpp index 3bc765b9aea5..8acfc44610c8 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/where.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/where.cpp @@ -26,112 +26,122 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(Where, 1, 1, false, 0, 0) { - auto condition = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - if (z->isEmpty()) - return Status::OK(); - - if (block.width() == 3) { - auto x = INPUT_VARIABLE(1); - auto y = INPUT_VARIABLE(2); - - REQUIRE_TRUE(x->isSameShape(y), 0, "X and Y must have equal shapes"); - - // if cond matches x/y shape - we have per-element mask - if (condition->isSameShape(x)) { - // FIXME: for perf it might be better to issue memcpy here, and fill only mismatched values from either X or Y - for (int e = 0; e < condition->lengthOf(); e++) { - if (y->isR()) { - auto r = !condition->e(e) ? y->e(e) : x->e(e); - z->p(e, r); - } else { - auto r = !condition->e(e) ? y->e(e) : x->e(e); - z->p(e, r); - } - } - } else { - REQUIRE_TRUE(condition->lengthOf() == x->sizeAt(0), 0, "Condition length should be equal to the dim0 of x/y to act as TAD-mask, but got %d instead", condition->lengthOf()); - - auto dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {0}); - auto tadsX = x->allTensorsAlongDimension(dims); - auto tadsY = y->allTensorsAlongDimension(dims); - auto tadsZ = z->allTensorsAlongDimension(dims); - - for (int e = 0; e < tadsX.size(); e++) { - if (!condition->e(e)) { - tadsZ.at(e).assign(tadsY.at(e)); - } else { - tadsZ.at(e).assign(tadsX.at(e)); - } - } - } - } else { - // in this case we return 2D matrix, which basically contains coordinates fo true - REQUIRE_TRUE(block.width() == 1, 0, "Where op takes either 1 or 3 operands, But got %d operands instead", block.width()); - auto output = OUTPUT_VARIABLE(0); - - int width = condition->rankOf(); - if (z->isEmpty()) - return ND4J_STATUS_OK; - - std::vector dims = ShapeUtils::evalDimsToExclude(width, {0}); - - helpers::_where(block.launchContext(), *condition, *output, block.workspace()); - } - return Status::OK(); +namespace ops { +CUSTOM_OP_IMPL(Where, 1, 1, false, 0, 0) { + auto condition = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + if (z->isEmpty()) return Status::OK(); + + if (block.width() == 3) { + auto x = INPUT_VARIABLE(1); + auto y = INPUT_VARIABLE(2); + + REQUIRE_TRUE(x->isSameShape(y), 0, "X and Y must have equal shapes"); + + // if cond matches x/y shape - we have per-element mask + if (condition->isSameShape(x)) { + // FIXME: for perf it might be better to issue memcpy here, and fill only + // mismatched values from either X or Y + for (int e = 0; e < condition->lengthOf(); e++) { + if (y->isR()) { + auto r = !condition->e(e) ? y->e(e) : x->e(e); + z->p(e, r); + } else { + auto r = + !condition->e(e) ? y->e(e) : x->e(e); + z->p(e, r); } - - DECLARE_SHAPE_FN(Where) { - if (block.width() == 3) { - auto inShape = inputShape->at(1); - Nd4jLong *newshape; - COPY_SHAPE(inShape, newshape); - - return SHAPELIST(CONSTANT(newshape)); - } else { - // FIXME: we can't estimate result here in this case - // output shape is the 2D tensor num_true x rankOf (inShape) - auto condition = INPUT_VARIABLE(0); - auto inShape = inputShape->at(0); - Nd4jLong numOfTrue = 0; //condition->reduceNumber(reduce::CountNonZero, nullptr).e(0); - for (Nd4jLong i = 0; i < condition->lengthOf(); i++) - if (condition->e(i)) numOfTrue++; - - Nd4jLong const* theNewShape; - if (numOfTrue > 0) { - Nd4jLong* newShape; - ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); - - newShape[0] = 2; - newShape[1] = numOfTrue; - newShape[2] = shape::rank(inShape); - newShape[3] = 1; - newShape[4] = 1; - newShape[5] = 0; - newShape[6] = 1; - newShape[7] = 99; - ShapeUtils::updateStridesAndType(newShape, sd::DataType::INT64, 'c'); - - theNewShape = CONSTANT(newShape); - } - else { - theNewShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(sd::DataType::INT64); - } - - return SHAPELIST(theNewShape); - } + } + } else { + REQUIRE_TRUE(condition->lengthOf() == x->sizeAt(0), 0, + "Condition length should be equal to the dim0 of x/y to act " + "as TAD-mask, but got %d instead", + condition->lengthOf()); + + auto dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {0}); + auto tadsX = x->allTensorsAlongDimension(dims); + auto tadsY = y->allTensorsAlongDimension(dims); + auto tadsZ = z->allTensorsAlongDimension(dims); + + for (int e = 0; e < tadsX.size(); e++) { + if (!condition->e(e)) { + tadsZ.at(e).assign(tadsY.at(e)); + } else { + tadsZ.at(e).assign(tadsX.at(e)); } + } + } + } else { + // in this case we return 2D matrix, which basically contains coordinates fo + // true + REQUIRE_TRUE( + block.width() == 1, 0, + "Where op takes either 1 or 3 operands, But got %d operands instead", + block.width()); + auto output = OUTPUT_VARIABLE(0); + + int width = condition->rankOf(); + if (z->isEmpty()) return ND4J_STATUS_OK; + + std::vector dims = ShapeUtils::evalDimsToExclude(width, {0}); + + helpers::_where(block.launchContext(), *condition, *output, + block.workspace()); + } + return Status::OK(); +} - DECLARE_TYPES(Where) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) // bool - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedInputTypes(2, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}); - } +DECLARE_SHAPE_FN(Where) { + if (block.width() == 3) { + auto inShape = inputShape->at(1); + Nd4jLong* newshape; + COPY_SHAPE(inShape, newshape); + + return SHAPELIST(CONSTANT(newshape)); + } else { + // FIXME: we can't estimate result here in this case + // output shape is the 2D tensor num_true x rankOf (inShape) + auto condition = INPUT_VARIABLE(0); + auto inShape = inputShape->at(0); + Nd4jLong numOfTrue = 0; // condition->reduceNumber(reduce::CountNonZero, + // nullptr).e(0); + for (Nd4jLong i = 0; i < condition->lengthOf(); i++) + if (condition->e(i)) numOfTrue++; + + Nd4jLong const* theNewShape; + if (numOfTrue > 0) { + Nd4jLong* newShape; + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), + Nd4jLong); + + newShape[0] = 2; + newShape[1] = numOfTrue; + newShape[2] = shape::rank(inShape); + newShape[3] = 1; + newShape[4] = 1; + newShape[5] = 0; + newShape[6] = 1; + newShape[7] = 99; + ShapeUtils::updateStridesAndType(newShape, sd::DataType::INT64, 'c'); + + theNewShape = CONSTANT(newShape); + } else { + theNewShape = ConstantShapeHelper::getInstance()->emptyShapeInfo( + sd::DataType::INT64); } + + return SHAPELIST(theNewShape); + } +} + +DECLARE_TYPES(Where) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) // bool + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedInputTypes(2, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp b/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp index 0d90c11a3c10..ad10a344c352 100644 --- a/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp +++ b/libnd4j/include/ops/declarable/generic/boolean/where_np.cpp @@ -18,8 +18,8 @@ // @author Adam Gibson // -#include #include +#include #if NOT_EXCLUDED(OP_where_np) @@ -27,136 +27,142 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(where_np, -1, 1, false, 0, 0) { - auto condition = INPUT_VARIABLE(0); - - if (block.width() == 3) { - auto x = INPUT_VARIABLE(1); - auto y = INPUT_VARIABLE(2); - - auto z = OUTPUT_VARIABLE(0); - int numMatches = 0; - // if cond matches x/y shape - we have per-element mask - if (condition->isSameShape(x)) { - // FIXME: for perf it might be better to issue memcpy here, and fill only mismatched values from either X or Y - if(y->isScalar()) { - if (y->isR()) { - for (int e = 0; e < condition->lengthOf(); e++) { - auto r = condition->e(e) ? y->e(0) - : x->e(e); - z->p(e, r); - } - } else { - for (int e = 0; e < condition->lengthOf(); e++) { - auto r = condition->e(e) ? y->e(0) - : x->e(e); - z->p(e, r); - } - } - } - else { - if (y->isR()) { - for (int e = 0; e < condition->lengthOf(); e++) { - if (condition->e(e)) { - auto r = y->e(numMatches); - z->p(e, r); - numMatches++; - } else { - auto r = x->e(e); - z->p(e, r); - } - } - } else { - for (int e = 0; e < condition->lengthOf(); e++) { - if (condition->e(e)) { - auto r = y->e(numMatches); - z->p(e, r); - numMatches++; - } else { - auto r = x->e(e); - z->p(e, r); - } - } - } - } - } - else { - REQUIRE_TRUE(condition->lengthOf() == x->sizeAt(0), 0, "Condition length should be equal to the dim0 of x/y to act as TAD-mask, but got %d instead", condition->lengthOf()); - - auto dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {0}); - auto tadsX = x->allTensorsAlongDimension(dims); - auto tadsY = y->allTensorsAlongDimension(dims); - auto tadsZ = z->allTensorsAlongDimension(dims); - - for (int e = 0; e < tadsX.size(); e++) { - if (!condition->e(e)) - tadsZ.at(e).assign(tadsY.at(e)); - else - tadsZ.at(e).assign(tadsX.at(e)); - } - } +namespace ops { +CUSTOM_OP_IMPL(where_np, -1, 1, false, 0, 0) { + auto condition = INPUT_VARIABLE(0); + + if (block.width() == 3) { + auto x = INPUT_VARIABLE(1); + auto y = INPUT_VARIABLE(2); + + auto z = OUTPUT_VARIABLE(0); + int numMatches = 0; + // if cond matches x/y shape - we have per-element mask + if (condition->isSameShape(x)) { + // FIXME: for perf it might be better to issue memcpy here, and fill only + // mismatched values from either X or Y + if (y->isScalar()) { + if (y->isR()) { + for (int e = 0; e < condition->lengthOf(); e++) { + auto r = condition->e(e) ? y->e(0) : x->e(e); + z->p(e, r); + } + } else { + for (int e = 0; e < condition->lengthOf(); e++) { + auto r = + condition->e(e) ? y->e(0) : x->e(e); + z->p(e, r); + } + } + } else { + if (y->isR()) { + for (int e = 0; e < condition->lengthOf(); e++) { + if (condition->e(e)) { + auto r = y->e(numMatches); + z->p(e, r); + numMatches++; } else { - // in this case we return 2D matrix, which basically contains coordinates fo true - - REQUIRE_TRUE(block.width() == 1, 0, "Where op takes either 1 or 3 operands, But got %d operands instead", block.width()); -// if (output->isEmpty()) - Nd4jLong width = condition->rankOf(); - - sd::ops::Where op; - auto res(op.evaluate({condition})); - REQUIRE_OK(res.status()); - auto& whereTrue = res.at(0); - - if (whereTrue.isEmpty()) - return ND4J_STATUS_OK; - for (Nd4jLong outNext = 0; outNext < width; ++outNext) { - auto output = OUTPUT_VARIABLE(outNext); - for (Nd4jLong e = 0; e < output->lengthOf(); ++e) { - output->p(e, whereTrue.e(e, outNext)); - } - } + auto r = x->e(e); + z->p(e, r); } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(where_np) { - auto shapes = SHAPELIST(); - Nd4jLong *newShape; - if (block.width() == 3) { - auto inShape = inputShape->at(1); - COPY_SHAPE(inShape, newShape); - - shapes->push_back(CONSTANT(newShape)); + } + } else { + for (int e = 0; e < condition->lengthOf(); e++) { + if (condition->e(e)) { + auto r = y->e(numMatches); + z->p(e, r); + numMatches++; } else { - auto condition = INPUT_VARIABLE(0); - - Nd4jLong numOfTrue = 0LL; //condition->reduceNumber(reduce::CountNonZero).e(0); - for (Nd4jLong i = 0; i < condition->lengthOf(); ++i) - if (condition->e(i)) numOfTrue++; - - // output shape - a tuple of rank(inShape) 1D tensors with numOfTrue len - if (numOfTrue) { - for (Nd4jLong e = 0; e < condition->rankOf(); ++e) { - shapes->push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(numOfTrue, sd::DataType::INT64)); - } - } - else { - shapes->push_back(ConstantShapeHelper::getInstance()->emptyShapeInfo(sd::DataType::INT64)); - } + auto r = x->e(e); + z->p(e, r); } - return shapes; + } } + } + } else { + REQUIRE_TRUE(condition->lengthOf() == x->sizeAt(0), 0, + "Condition length should be equal to the dim0 of x/y to act " + "as TAD-mask, but got %d instead", + condition->lengthOf()); + + auto dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {0}); + auto tadsX = x->allTensorsAlongDimension(dims); + auto tadsY = y->allTensorsAlongDimension(dims); + auto tadsZ = z->allTensorsAlongDimension(dims); + + for (int e = 0; e < tadsX.size(); e++) { + if (!condition->e(e)) + tadsZ.at(e).assign(tadsY.at(e)); + else + tadsZ.at(e).assign(tadsX.at(e)); + } + } + } else { + // in this case we return 2D matrix, which basically contains coordinates fo + // true + + REQUIRE_TRUE( + block.width() == 1, 0, + "Where op takes either 1 or 3 operands, But got %d operands instead", + block.width()); + // if (output->isEmpty()) + Nd4jLong width = condition->rankOf(); + + sd::ops::Where op; + auto res(op.evaluate({condition})); + REQUIRE_OK(res.status()); + auto& whereTrue = res.at(0); + + if (whereTrue.isEmpty()) return ND4J_STATUS_OK; + for (Nd4jLong outNext = 0; outNext < width; ++outNext) { + auto output = OUTPUT_VARIABLE(outNext); + for (Nd4jLong e = 0; e < output->lengthOf(); ++e) { + output->p(e, whereTrue.e(e, outNext)); + } + } + } - DECLARE_TYPES(where_np) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::BOOL) - ->setAllowedInputTypes(1, sd::DataType::ANY) - ->setAllowedInputTypes(2, sd::DataType::ANY) - ->setAllowedOutputTypes( {ALL_FLOATS, ALL_INTS}); - } + return Status::OK(); +} + +DECLARE_SHAPE_FN(where_np) { + auto shapes = SHAPELIST(); + Nd4jLong* newShape; + if (block.width() == 3) { + auto inShape = inputShape->at(1); + COPY_SHAPE(inShape, newShape); + + shapes->push_back(CONSTANT(newShape)); + } else { + auto condition = INPUT_VARIABLE(0); + + Nd4jLong numOfTrue = + 0LL; // condition->reduceNumber(reduce::CountNonZero).e(0); + for (Nd4jLong i = 0; i < condition->lengthOf(); ++i) + if (condition->e(i)) numOfTrue++; + + // output shape - a tuple of rank(inShape) 1D tensors with numOfTrue len + if (numOfTrue) { + for (Nd4jLong e = 0; e < condition->rankOf(); ++e) { + shapes->push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo( + numOfTrue, sd::DataType::INT64)); + } + } else { + shapes->push_back(ConstantShapeHelper::getInstance()->emptyShapeInfo( + sd::DataType::INT64)); } + } + return shapes; +} + +DECLARE_TYPES(where_np) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::BOOL) + ->setAllowedInputTypes(1, sd::DataType::ANY) + ->setAllowedInputTypes(2, sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp index 936addea5644..35588f743300 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/add.cpp @@ -21,97 +21,98 @@ #include #if NOT_EXCLUDED(OP_add) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(add, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(sd::BroadcastOpsTuple::Add(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) - throw std::runtime_error("add: result was replaced"); - - - return Status::OK(); - } - - DECLARE_TYPES(add) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(DataType::ANY); - } - - DECLARE_TYPES(add_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - - CUSTOM_OP_IMPL(add_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - if (x->isSameShape(y)) { - // PWT case case - gradY->assign(epsNext); - gradX->assign(epsNext); - } else if (y->isScalar()) { - // scalar case - auto tmp = epsNext->reduceNumber(sd::reduce::Sum); - gradY->assign(tmp); - gradX->assign(epsNext); - } else { - // broadcast case - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = epsNext->reduceAlongDimension(sd::reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(epsNext); - - if (axisY.size() > 0) { - auto sum = epsNext->reduceAlongDimension(sd::reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(epsNext); - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(add_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(add, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = + BroadcastHelper::broadcastApply(sd::BroadcastOpsTuple::Add(), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) + throw std::runtime_error("add: result was replaced"); + + return Status::OK(); +} + +DECLARE_TYPES(add) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(DataType::ANY); +} + +DECLARE_TYPES(add_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(add_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + if (x->isSameShape(y)) { + // PWT case case + gradY->assign(epsNext); + gradX->assign(epsNext); + } else if (y->isScalar()) { + // scalar case + auto tmp = epsNext->reduceNumber(sd::reduce::Sum); + gradY->assign(tmp); + gradX->assign(epsNext); + } else { + // broadcast case + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = epsNext->reduceAlongDimension(sd::reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(epsNext); + + if (axisY.size() > 0) { + auto sum = epsNext->reduceAlongDimension(sd::reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(epsNext); + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(add_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp index aeaa5d128535..7d0d5d87b3f1 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/assign.cpp @@ -21,91 +21,93 @@ #include #if NOT_EXCLUDED(OP_assign) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(assign, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(sd::BroadcastOpsTuple::Assign(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return ND4J_STATUS_OK; - } - DECLARE_SYN(set, assign); - DECLARE_SYN(copy, assign); - - DECLARE_TYPES(assign) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(assign_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(assign_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - gradX->assign(0.0f); - - if (x->isSameShape(y)) { - gradY->assign(epsNext); - } else if (y->isScalar()) { - auto sum = epsNext->reduceNumber(sd::reduce::Sum); - gradY->assign(sum); - } else { - // broadcastable - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisY.size() > 0) { - auto sum = epsNext->reduceAlongDimension(sd::reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(epsNext); - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(assign_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - - return shapeList; - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(assign, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = + BroadcastHelper::broadcastApply(sd::BroadcastOpsTuple::Assign(), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return ND4J_STATUS_OK; +} +DECLARE_SYN(set, assign); +DECLARE_SYN(copy, assign); + +DECLARE_TYPES(assign) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(assign_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(assign_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + gradX->assign(0.0f); + + if (x->isSameShape(y)) { + gradY->assign(epsNext); + } else if (y->isScalar()) { + auto sum = epsNext->reduceNumber(sd::reduce::Sum); + gradY->assign(sum); + } else { + // broadcastable + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisY.size() > 0) { + auto sum = epsNext->reduceAlongDimension(sd::reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(epsNext); + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(assign_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); + + return shapeList; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp index ed60f59250d7..cf20c04e01c9 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/atan2.cpp @@ -28,33 +28,35 @@ namespace sd { namespace ops { BROADCASTABLE_OP_IMPL(tf_atan2, 0, 0) { + auto y = INPUT_VARIABLE(0); + auto x = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + BROADCAST_CHECK_EMPTY(x, y, z); - BROADCAST_CHECK_EMPTY(x,y,z); + // auto tZ = BroadcastHelper::template broadcastApply>(y, + // x, z); + x->applyTrueBroadcast(sd::BroadcastOpsTuple::custom( + scalar::Atan2, pairwise::Atan2, broadcast::Atan2), + *y, *z, true); - // auto tZ = BroadcastHelper::template broadcastApply>(y, x, z); - x->applyTrueBroadcast(sd::BroadcastOpsTuple::custom(scalar::Atan2, pairwise::Atan2, broadcast::Atan2), *y, *z, true); + // if (tZ == nullptr) + // return ND4J_STATUS_KERNEL_FAILURE; + // else if (tZ != z) { + // OVERWRITE_RESULT(tZ); + // } - // if (tZ == nullptr) - // return ND4J_STATUS_KERNEL_FAILURE; - // else if (tZ != z) { - // OVERWRITE_RESULT(tZ); - // } - - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(tf_atan2) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - -} +DECLARE_TYPES(tf_atan2) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/boolean_and.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/boolean_and.cpp index 32593ecf62f8..bfdd6818f202 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/boolean_and.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/boolean_and.cpp @@ -24,30 +24,33 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(boolean_and, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::custom(scalar::LogicalAnd, pairwise::LogicalAnd, broadcast::LogicalAnd), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) - throw std::runtime_error("boolean_and: result was overwritten"); - - return Status::OK(); - } - - DECLARE_TYPES(boolean_and) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(boolean_and, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply( + BroadcastOpsTuple::custom(scalar::LogicalAnd, pairwise::LogicalAnd, + broadcast::LogicalAnd), + x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) + throw std::runtime_error("boolean_and: result was overwritten"); + + return Status::OK(); } +DECLARE_TYPES(boolean_and) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/boolean_or.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/boolean_or.cpp index 1dbb69f30846..f34f0c60aa27 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/boolean_or.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/boolean_or.cpp @@ -24,30 +24,33 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(boolean_or, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::custom(scalar::LogicalOr, pairwise::LogicalOr, broadcast::LogicalOr), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) - throw std::runtime_error("boolean_and: result was overwritten"); - - return Status::OK(); - } - - DECLARE_TYPES(boolean_or) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(boolean_or, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply( + BroadcastOpsTuple::custom(scalar::LogicalOr, pairwise::LogicalOr, + broadcast::LogicalOr), + x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) + throw std::runtime_error("boolean_and: result was overwritten"); + + return Status::OK(); } +DECLARE_TYPES(boolean_or) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/boolean_xor.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/boolean_xor.cpp index 8f242fbda84d..dc49a4c7eabd 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/boolean_xor.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/boolean_xor.cpp @@ -24,30 +24,33 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(boolean_xor, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::custom(scalar::LogicalXor, pairwise::LogicalXor, broadcast::LogicalXor), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) - throw std::runtime_error("boolean_xor: result was overwritten"); - - return Status::OK(); - } - - DECLARE_TYPES(boolean_xor) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(boolean_xor, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply( + BroadcastOpsTuple::custom(scalar::LogicalXor, pairwise::LogicalXor, + broadcast::LogicalXor), + x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) + throw std::runtime_error("boolean_xor: result was overwritten"); + + return Status::OK(); } +DECLARE_TYPES(boolean_xor) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp index cd907de366e0..0b192967f58e 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/divide.cpp @@ -21,125 +21,128 @@ #include #if NOT_EXCLUDED(OP_divide) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(divide, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - REQUIRE_TRUE(!y->isB(), 0, "DIVIDE OP: you can't divide by bool array!"); - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::Divide(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - DECLARE_SYN(Div, divide); - - DECLARE_TYPES(divide) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(divide_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(divide_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - -/* - auto lambdaY = LAMBDA_TTT(_e, _x, _y) { - return _e * -_x / (_y * _y); - }; -*/ - - if (x->isSameShape(y)) { - // PWT case case - - // X gradient - //epsNext->applyPairwiseLambda(y, lambdaX, gradX); - gradX->assign((*epsNext) / (*y)); - // Y gradient - //epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); - gradY->assign((*epsNext) * (*x) / ((*y) * (*y))); - gradY->applyTransform(transform::Neg, *gradY); - - } else if (y->isScalar()) { - // scalar case - - auto tmp = epsNext->reduceNumber(reduce::Sum); - auto tmpX = x->reduceNumber(reduce::Sum); - //tmpX.printBuffer("SumX"); - //tmp.printBuffer("Sum Eps"); - gradY->assign(tmp * tmpX / ((*y) * (*y))); - gradY->applyTransform(transform::Neg, *gradY); - - //epsNext->applyLambda(lambdaS, *gradX); - epsNext->applyScalarArr(scalar::Divide, *y, *gradX); - } else { - // broadcast case - - auto preX = *epsNext / *y; - - NDArray negX(*x); - x->applyTransform(transform::Neg, negX); - auto preY = *epsNext * negX / ((*y) * (*y)); - - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); - - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(divide_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(divide, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + REQUIRE_TRUE(!y->isB(), 0, "DIVIDE OP: you can't divide by bool array!"); + auto tZ = + BroadcastHelper::broadcastApply(BroadcastOpsTuple::Divide(), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} +DECLARE_SYN(Div, divide); + +DECLARE_TYPES(divide) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(divide_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(divide_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + /* + auto lambdaY = LAMBDA_TTT(_e, _x, _y) { + return _e * -_x / (_y * _y); + }; + */ + + if (x->isSameShape(y)) { + // PWT case case + + // X gradient + // epsNext->applyPairwiseLambda(y, lambdaX, gradX); + gradX->assign((*epsNext) / (*y)); + // Y gradient + // epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); + gradY->assign((*epsNext) * (*x) / ((*y) * (*y))); + gradY->applyTransform(transform::Neg, *gradY); + + } else if (y->isScalar()) { + // scalar case + + auto tmp = epsNext->reduceNumber(reduce::Sum); + auto tmpX = x->reduceNumber(reduce::Sum); + // tmpX.printBuffer("SumX"); + // tmp.printBuffer("Sum Eps"); + gradY->assign(tmp * tmpX / ((*y) * (*y))); + gradY->applyTransform(transform::Neg, *gradY); + + // epsNext->applyLambda(lambdaS, *gradX); + epsNext->applyScalarArr(scalar::Divide, *y, *gradX); + } else { + // broadcast case + + auto preX = *epsNext / *y; + + NDArray negX(*x); + x->applyTransform(transform::Neg, negX); + auto preY = *epsNext * negX / ((*y) * (*y)); + + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); + + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(divide_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/divide_no_nan.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/divide_no_nan.cpp index 9ef300e1c4fa..f20a7b43061c 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/divide_no_nan.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/divide_no_nan.cpp @@ -21,37 +21,39 @@ #include #if NOT_EXCLUDED(OP_divide_no_nan) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(divide_no_nan, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - REQUIRE_TRUE(!y->isB(), 0, "DIVIDE_NO_NAN OP: you can't divide by bool array!"); - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::DivideNoNan(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - DECLARE_SYN(Div, divide); - - DECLARE_TYPES(divide_no_nan) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(divide_no_nan, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + REQUIRE_TRUE(!y->isB(), 0, + "DIVIDE_NO_NAN OP: you can't divide by bool array!"); + auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::DivideNoNan(), x, + y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} +DECLARE_SYN(Div, divide); + +DECLARE_TYPES(divide_no_nan) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/equals.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/equals.cpp index 5d4aaef5e44a..2ad612b04ac6 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/equals.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/equals.cpp @@ -18,35 +18,38 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_BOOL_OP_IMPL(equals, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, broadcast::EqualTo), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_SYN(equal, equals); - - DECLARE_TYPES(equals) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } -} \ No newline at end of file +namespace ops { +BROADCASTABLE_BOOL_OP_IMPL(equals, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply( + BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, + broadcast::EqualTo), + x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_SYN(equal, equals); + +DECLARE_TYPES(equals) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/floordiv.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/floordiv.cpp index d0a59bcc196d..b0384f9c5054 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/floordiv.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/floordiv.cpp @@ -21,75 +21,77 @@ #include #if NOT_EXCLUDED(OP_floordiv) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(floordiv, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - REQUIRE_TRUE(!y->isB(), 0, "FLOORDIV OP: you can't divide by bool array!"); - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::custom(scalar::FloorDiv, pairwise::FloorDiv, broadcast::FloorDiv), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - - DECLARE_TYPES(floordiv) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(floordiv_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(floordiv_bp, 3, 2, false, 0, 0) { - // PLEASE NOTE: we're just passing eps down the line here - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - gradY->assign(0.0f); - gradX->assign(0.0f); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(floordiv_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(floordiv, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + REQUIRE_TRUE(!y->isB(), 0, "FLOORDIV OP: you can't divide by bool array!"); + auto tZ = BroadcastHelper::broadcastApply( + BroadcastOpsTuple::custom(scalar::FloorDiv, pairwise::FloorDiv, + broadcast::FloorDiv), + x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_TYPES(floordiv) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(floordiv_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(floordiv_bp, 3, 2, false, 0, 0) { + // PLEASE NOTE: we're just passing eps down the line here + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + gradY->assign(0.0f); + gradX->assign(0.0f); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(floordiv_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp index 1319ccfd0a45..744b44ee7db6 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/floormod.cpp @@ -14,93 +14,93 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author raver119@gmail.com - // modified by sgazeos@gmail.com with backprop implementation. - // +// +// @author raver119@gmail.com +// modified by sgazeos@gmail.com with backprop implementation. +// #include #if NOT_EXCLUDED(OP_floormod) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(floormod, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x, y, z); - - REQUIRE_TRUE(!y->isB(), 0, "FLOORMOD OP: you can't divide by bool array!"); - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_TYPES(floormod) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(floormod_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ ALL_FLOATS }); - } - - CUSTOM_OP_IMPL(floormod_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - gradX->assign(epsNext); - - auto temp = epsNext->dup(); - BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, &temp); - - if (gradY->rankOf() == gradX->rankOf()) - epsNext->applyPairwiseTransform(pairwise::Multiply, temp, *gradY); - else // epsNext is greater than gradY - { - std::vector dims(epsNext->rankOf() * 2); - Nd4jLong gap = epsNext->rankOf() - gradY->rankOf(); - for (Nd4jLong d = 0; d < gap; d++) { - dims[d * 2 + 1] = 1; - } - auto tempIn((temp)(dims)); - (*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY); - } - return Status::OK(); - } - - DECLARE_SHAPE_FN(floormod_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong* shapeE; - Nd4jLong* shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } +namespace ops { +BROADCASTABLE_OP_IMPL(floormod, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + REQUIRE_TRUE(!y->isB(), 0, "FLOORMOD OP: you can't divide by bool array!"); + auto tZ = BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_TYPES(floormod) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(floormod_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(floormod_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + gradX->assign(epsNext); + + auto temp = epsNext->dup(); + BroadcastHelper::broadcastApply(BROADCAST(FloorMod), x, y, &temp); + + if (gradY->rankOf() == gradX->rankOf()) + epsNext->applyPairwiseTransform(pairwise::Multiply, temp, *gradY); + else // epsNext is greater than gradY + { + std::vector dims(epsNext->rankOf() * 2); + Nd4jLong gap = epsNext->rankOf() - gradY->rankOf(); + for (Nd4jLong d = 0; d < gap; d++) { + dims[d * 2 + 1] = 1; } + auto tempIn((temp)(dims)); + (*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY); + } + return Status::OK(); +} + +DECLARE_SHAPE_FN(floormod_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong* shapeE; + Nd4jLong* shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/greater.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/greater.cpp index 084453dc8dea..56b5cbaf7f6d 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/greater.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/greater.cpp @@ -22,29 +22,30 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_BOOL_OP_IMPL(greater, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_BOOL_OP_IMPL(greater, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } + auto tZ = + BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(greater) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } -} \ No newline at end of file +DECLARE_TYPES(greater) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/greater_equal.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/greater_equal.cpp index 5f448585eb4f..571017000603 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/greater_equal.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/greater_equal.cpp @@ -21,29 +21,30 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_BOOL_OP_IMPL(greater_equal, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_BOOL_OP_IMPL(greater_equal, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThanOrEqual), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } + auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThanOrEqual), + x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(greater_equal) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } -} \ No newline at end of file +DECLARE_TYPES(greater_equal) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp index 9fa07424c04a..3f93ddb4d292 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/igamma.cpp @@ -21,39 +21,41 @@ #include #if NOT_EXCLUDED(OP_igamma) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(igamma, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - //REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); - -// auto tZ = BroadcastHelper::broadcastApply({scalar::IGamma, pairwise::IGamma, broadcast::IGamma}, x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGamma(), x, y, z); - - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_TYPES(igamma) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(igamma, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + // REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); + + // auto tZ = BroadcastHelper::broadcastApply({scalar::IGamma, + // pairwise::IGamma, broadcast::IGamma}, x, y, z); + auto tZ = + BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGamma(), x, y, z); + + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_TYPES(igamma) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp index deeacd4ef59c..e4d9682d0a5b 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/igammac.cpp @@ -21,38 +21,40 @@ #include #if NOT_EXCLUDED(OP_igammac) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(igammac, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - //REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); - -// auto tZ = BroadcastHelper::broadcastApply({scalar::IGammac, pairwise::IGammac, broadcast::IGammac}, x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGammac(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_TYPES(igammac) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(igammac, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + // REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); + + // auto tZ = BroadcastHelper::broadcastApply({scalar::IGammac, + // pairwise::IGammac, broadcast::IGammac}, x, y, z); + auto tZ = + BroadcastHelper::broadcastApply(BroadcastOpsTuple::IGammac(), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_TYPES(igammac) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/less.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/less.cpp index 5d9c73f1b2d2..2f0aec719cfe 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/less.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/less.cpp @@ -21,29 +21,29 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_BOOL_OP_IMPL(less, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_BOOL_OP_IMPL(less, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(LessThan), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } + auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(LessThan), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(less) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } -} \ No newline at end of file +DECLARE_TYPES(less) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/less_equal.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/less_equal.cpp index a0f0a0366029..1a5779a6f1f5 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/less_equal.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/less_equal.cpp @@ -21,29 +21,30 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_BOOL_OP_IMPL(less_equal, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_BOOL_OP_IMPL(less_equal, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(LessThanOrEqual), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } + auto tZ = + BroadcastHelper::broadcastApply(BROADCAST_BOOL(LessThanOrEqual), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(less_equal) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } -} \ No newline at end of file +DECLARE_TYPES(less_equal) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/maximum.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/maximum.cpp index dfb6b3d66954..d5634035ed93 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/maximum.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/maximum.cpp @@ -21,70 +21,70 @@ #include #if NOT_EXCLUDED(OP_maximum) -#include #include +#include #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(maximum, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(MaxPairwise), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_TYPES(maximum) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(maximum_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(maximum_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - helpers::maximumBPFunctor(block.launchContext(), x, y, epsNext, gradX, gradY); - return Status::OK(); - } - - DECLARE_SHAPE_FN(maximum_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(maximum, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply(BROADCAST(MaxPairwise), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_TYPES(maximum) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(maximum_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(maximum_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + helpers::maximumBPFunctor(block.launchContext(), x, y, epsNext, gradX, gradY); + return Status::OK(); +} + +DECLARE_SHAPE_FN(maximum_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp index 51c0d876a283..c3e5ca77d36e 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/meshgrid.cpp @@ -22,72 +22,72 @@ #if NOT_EXCLUDED(OP_meshgrid) #include -#include +#include + #include namespace sd { -namespace ops { +namespace ops { -CUSTOM_OP_IMPL(meshgrid, -1, -1, false, 0, 0) { - - int rank = block.width(); +CUSTOM_OP_IMPL(meshgrid, -1, -1, false, 0, 0) { + int rank = block.width(); - if(rank == 1) { - OUTPUT_VARIABLE(0)->assign(INPUT_VARIABLE(0)); - return Status::OK(); - } + if (rank == 1) { + OUTPUT_VARIABLE(0)->assign(INPUT_VARIABLE(0)); + return Status::OK(); + } - bool swapFirst2Dims = block.numI() > 0 ? (bool)INT_ARG(0) : true; + bool swapFirst2Dims = block.numI() > 0 ? (bool)INT_ARG(0) : true; - std::vector inArrs(rank); - std::vector outArrs(rank); + std::vector inArrs(rank); + std::vector outArrs(rank); - for(int i = 0; i < rank; ++i) { - inArrs[i] = INPUT_VARIABLE(i); - outArrs[i] = OUTPUT_VARIABLE(i); - } + for (int i = 0; i < rank; ++i) { + inArrs[i] = INPUT_VARIABLE(i); + outArrs[i] = OUTPUT_VARIABLE(i); + } - helpers::meshgrid(block.launchContext(), inArrs, outArrs, swapFirst2Dims); + helpers::meshgrid(block.launchContext(), inArrs, outArrs, swapFirst2Dims); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(meshgrid) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes(DataType::INHERIT) - ->setSameMode(true); - } - +DECLARE_TYPES(meshgrid) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes(DataType::INHERIT) + ->setSameMode(true); +} DECLARE_SHAPE_FN(meshgrid) { - bool swapFirst2Dims = block.numI() > 0 ? (bool)INT_ARG(0) : true; - - int rank = block.width(); - Nd4jLong* outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); - outShapeInfo[0] = rank; - for(int i = 1; i <= rank; ++i) - outShapeInfo[i] = (Nd4jLong)shape::length(inputShape->at(i - 1)); - - if(swapFirst2Dims && rank > 1) - math::nd4j_swap(outShapeInfo[1], outShapeInfo[2]); - - auto in = inputShape->at(0); - ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); - - auto shapes = SHAPELIST(); - auto resultShape = CONSTANT(outShapeInfo); + bool swapFirst2Dims = block.numI() > 0 ? (bool)INT_ARG(0) : true; + + int rank = block.width(); + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + outShapeInfo[0] = rank; + for (int i = 1; i <= rank; ++i) + outShapeInfo[i] = (Nd4jLong)shape::length(inputShape->at(i - 1)); + + if (swapFirst2Dims && rank > 1) + math::nd4j_swap(outShapeInfo[1], outShapeInfo[2]); + + auto in = inputShape->at(0); + ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); + + auto shapes = SHAPELIST(); + auto resultShape = CONSTANT(outShapeInfo); + shapes->push_back(resultShape); + + for (int i = 2; i <= rank; ++i) { shapes->push_back(resultShape); + } - for(int i = 2; i <= rank; ++i) { - shapes->push_back(resultShape); - } - - return shapes; + return shapes; } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/minimum.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/minimum.cpp index ef8645d1dbb2..854922fc1c22 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/minimum.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/minimum.cpp @@ -26,66 +26,65 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(minimum, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(MinPairwise),x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return ND4J_STATUS_OK; - } - - DECLARE_TYPES(minimum) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(minimum_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(minimum_bp, 3, 2, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - helpers::minimumBPFunctor(block.launchContext(), x, y, epsNext, gradX, gradY); - return Status::OK(); - } - - DECLARE_SHAPE_FN(minimum_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(minimum, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply(BROADCAST(MinPairwise), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return ND4J_STATUS_OK; } +DECLARE_TYPES(minimum) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(minimum_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(minimum_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + helpers::minimumBPFunctor(block.launchContext(), x, y, epsNext, gradX, gradY); + return Status::OK(); +} + +DECLARE_SHAPE_FN(minimum_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); +} +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp index 95c710d170b6..dc88f27e4032 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp @@ -21,75 +21,75 @@ #include #if NOT_EXCLUDED(OP_mod) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(mod, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(Mod), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_TYPES(mod) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(mod_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(mod_bp, 3, 2, false, 0, 0) { - // PLEASE NOTE: we're just passing eps down the line here - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - gradY->assign(0.0f); - gradX->assign(0.0f); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(mod_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - - return shapeList; - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(mod, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply(BROADCAST(Mod), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_TYPES(mod) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(mod_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(mod_bp, 3, 2, false, 0, 0) { + // PLEASE NOTE: we're just passing eps down the line here + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + gradY->assign(0.0f); + gradX->assign(0.0f); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(mod_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); + + return shapeList; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp index 84035d90dc2f..c44adc830a38 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/multiply.cpp @@ -27,123 +27,132 @@ namespace sd { namespace ops { - BROADCASTABLE_OP_IMPL(multiply, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - const Nd4jLong* zShapeInfo = nullptr; - const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->shapeInfo(), y->shapeInfo(), true, zShapeInfo, block.workspace()); - REQUIRE_TRUE(areShapesBroadcastable, 0, "MULTIPLY OP: the shapes of x %s and y %s are not suitable for broadcast !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - - auto tZ = BroadcastHelper::broadcastApply(sd::BroadcastOpsTuple::Multiply(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) - throw std::runtime_error("multiply: result was replaced"); - - return Status::OK(); - } - DECLARE_SYN(Mul, multiply); - - DECLARE_TYPES(multiply) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(multiply_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +BROADCASTABLE_OP_IMPL(multiply, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + const Nd4jLong* zShapeInfo = nullptr; + const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo( + x->shapeInfo(), y->shapeInfo(), true, zShapeInfo, block.workspace()); + REQUIRE_TRUE(areShapesBroadcastable, 0, + "MULTIPLY OP: the shapes of x %s and y %s are not suitable for " + "broadcast !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + + auto tZ = BroadcastHelper::broadcastApply(sd::BroadcastOpsTuple::Multiply(), + x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) + throw std::runtime_error("multiply: result was replaced"); + + return Status::OK(); +} +DECLARE_SYN(Mul, multiply); + +DECLARE_TYPES(multiply) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(multiply_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} /////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto dLdz = INPUT_VARIABLE(2); - - auto dLdx = OUTPUT_VARIABLE(0); - auto dLdy = OUTPUT_VARIABLE(1); - - const Nd4jLong* dLdzShapeInfo = nullptr; - const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->shapeInfo(), y->shapeInfo(), true, dLdzShapeInfo, block.workspace()); - REQUIRE_TRUE(areShapesBroadcastable, 0, "MULTIPLY_BP OP: the shapes of x %s and y %s are not suitable for broadcast !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - REQUIRE_TRUE(shape::equalsSoft(dLdz->shapeInfo(), dLdzShapeInfo), 0, "MULTIPLY_BP OP: wrong shape of next epsilon array (dLdOut), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(dLdzShapeInfo).c_str(), ShapeUtils::shapeAsString(dLdz).c_str()); - - const Nd4jLong xLen = x->lengthOf(); - const Nd4jLong yLen = y->lengthOf(); - - if(x->isScalar() && y->isScalar()) { // both are scalars - y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); - x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); - //dLdx->assign((*y) * (*dLdz)); - //dLdy->assign((*x) * (*dLdz)); - - } - else if(x->isScalar()) { // x is scalar and y is not - dLdx->assign((*y * *dLdz).reduceNumber(reduce::Sum)); - dLdz->applyScalarArr(scalar::Multiply, *x, *dLdy); - //dLdz->applyTrueBroadcast(broadcast::Multiply, x, dLdy, true); - } - else if(y->isScalar()) { // y is scalar and x is not - dLdy->assign((*x * *dLdz).reduceNumber(reduce::Sum)); - dLdz->applyScalarArr(scalar::Multiply, *y, *dLdx); - } - else if(x->isSameShape(y)) { - x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); - y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); - } - else if (x->isSameShape(dLdz)) { - - auto yTiled = NDArray(dLdz, false, block.launchContext()); - y->tile(yTiled); - std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), dLdz->shapeInfo()); - - dLdy->assign( (*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY) ); - yTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); - } - else if (y->isSameShape(dLdz)) { - - auto xTiled = NDArray(dLdz, false, block.launchContext()); - x->tile(xTiled); - std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), dLdz->shapeInfo()); - - dLdx->assign( (*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX) ); - xTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); - } - else { - - auto xTiled = NDArray(dLdz, false, block.launchContext()); - auto yTiled = NDArray(dLdz, false, block.launchContext()); - x->tile(xTiled); - y->tile(yTiled); - std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), dLdz->shapeInfo()); - std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), dLdz->shapeInfo()); - - dLdx->assign( (*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX) ); - dLdy->assign( (*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY) ); - } - - return Status::OK(); + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto dLdz = INPUT_VARIABLE(2); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdy = OUTPUT_VARIABLE(1); + + const Nd4jLong* dLdzShapeInfo = nullptr; + const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo( + x->shapeInfo(), y->shapeInfo(), true, dLdzShapeInfo, block.workspace()); + REQUIRE_TRUE(areShapesBroadcastable, 0, + "MULTIPLY_BP OP: the shapes of x %s and y %s are not suitable " + "for broadcast !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + REQUIRE_TRUE(shape::equalsSoft(dLdz->shapeInfo(), dLdzShapeInfo), 0, + "MULTIPLY_BP OP: wrong shape of next epsilon array (dLdOut), " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(dLdzShapeInfo).c_str(), + ShapeUtils::shapeAsString(dLdz).c_str()); + + const Nd4jLong xLen = x->lengthOf(); + const Nd4jLong yLen = y->lengthOf(); + + if (x->isScalar() && y->isScalar()) { // both are scalars + y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); + x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); + // dLdx->assign((*y) * (*dLdz)); + // dLdy->assign((*x) * (*dLdz)); + + } else if (x->isScalar()) { // x is scalar and y is not + dLdx->assign((*y * *dLdz).reduceNumber(reduce::Sum)); + dLdz->applyScalarArr(scalar::Multiply, *x, *dLdy); + // dLdz->applyTrueBroadcast(broadcast::Multiply, x, dLdy, true); + } else if (y->isScalar()) { // y is scalar and x is not + dLdy->assign((*x * *dLdz).reduceNumber(reduce::Sum)); + dLdz->applyScalarArr(scalar::Multiply, *y, *dLdx); + } else if (x->isSameShape(y)) { + x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); + y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); + } else if (x->isSameShape(dLdz)) { + auto yTiled = NDArray(dLdz, false, block.launchContext()); + y->tile(yTiled); + std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis( + y->shapeInfo(), dLdz->shapeInfo()); + + dLdy->assign((*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY)); + yTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx); + } else if (y->isSameShape(dLdz)) { + auto xTiled = NDArray(dLdz, false, block.launchContext()); + x->tile(xTiled); + std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis( + x->shapeInfo(), dLdz->shapeInfo()); + + dLdx->assign((*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX)); + xTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy); + } else { + auto xTiled = NDArray(dLdz, false, block.launchContext()); + auto yTiled = NDArray(dLdz, false, block.launchContext()); + x->tile(xTiled); + y->tile(yTiled); + std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis( + x->shapeInfo(), dLdz->shapeInfo()); + std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis( + y->shapeInfo(), dLdz->shapeInfo()); + + dLdx->assign((*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX)); + dLdy->assign((*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY)); + } + + return Status::OK(); } DECLARE_SHAPE_FN(multiply_bp) { + auto xShapeInfo = inputShape->at(0); + auto yShapeInfo = inputShape->at(1); - auto xShapeInfo = inputShape->at(0); - auto yShapeInfo = inputShape->at(1); - - Nd4jLong *dLdxShapeInfo = nullptr; - Nd4jLong *dLdyShapeInfo = nullptr; + Nd4jLong* dLdxShapeInfo = nullptr; + Nd4jLong* dLdyShapeInfo = nullptr; - COPY_SHAPE(xShapeInfo, dLdxShapeInfo); - COPY_SHAPE(yShapeInfo, dLdyShapeInfo); + COPY_SHAPE(xShapeInfo, dLdxShapeInfo); + COPY_SHAPE(yShapeInfo, dLdyShapeInfo); - return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdyShapeInfo)); + return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdyShapeInfo)); } /* CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) { @@ -194,20 +203,20 @@ DECLARE_SHAPE_FN(multiply_bp) { preX->tileToShape(targetShape); preY->tileToShape(targetShape); - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); + auto axisX = + ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); + auto axisY = + ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); if (axisX.size() > 0) { - auto sum = preX->template reduceAlongDimension>(axisX); - gradX->assign(sum); - delete sum; + auto sum = preX->template + reduceAlongDimension>(axisX); gradX->assign(sum); delete sum; } else gradX->assign(preX); if (axisY.size() > 0) { - auto sum = preY->template reduceAlongDimension>(axisY); - gradY->assign(sum); - delete sum; + auto sum = preY->template + reduceAlongDimension>(axisY); gradY->assign(sum); delete sum; } else gradY->assign(preY); @@ -220,7 +229,7 @@ DECLARE_SHAPE_FN(multiply_bp) { } */ -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/not_equals.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/not_equals.cpp index 9e2609f9d9c1..befc126347e6 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/not_equals.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/not_equals.cpp @@ -21,29 +21,30 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_BOOL_OP_IMPL(not_equals, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_BOOL_OP_IMPL(not_equals, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(NotEqualTo), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } + auto tZ = + BroadcastHelper::broadcastApply(BROADCAST_BOOL(NotEqualTo), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(not_equals) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); - } - } -} \ No newline at end of file +DECLARE_TYPES(not_equals) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::BOOL); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp index a5c019db1f00..6cb552d1be26 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/percentile.cpp @@ -24,68 +24,98 @@ #include #include - namespace sd { -namespace ops { - -CUSTOM_OP_IMPL(percentile, 1, 1, false, 1, -2) { - auto input = INPUT_VARIABLE(0); // tensor with rank > 0 - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - - const auto q = T_ARG(0); // percentile - const int interpolation = block.numT() > 1 ? T_ARG(1) : 2.; // 0-"lower", 1-"higher", 2-"nearest"(default) - const int keepDims = block.numT() > 2 ? T_ARG(2) : 0.; // false is default - - const int axisArrRank = block.numI(); - const int inputArrRank = input->rankOf(); - - REQUIRE_TRUE(inputArrRank > 0, 0, "PERCENTILE OP: rank of input array must be positive (>0), but got %i instead !", inputArrRank); - REQUIRE_TRUE(0.f <= q && q <= 100.f, 0, "PERCENTILE OP: percentile parameter must be within [0, 100] range, but got %f instead !", q); - REQUIRE_TRUE(interpolation == 0 || interpolation == 1 || interpolation == 2, 0, "PERCENTILE OP: the correct values for interpolation parameter are 0, 1, 2, but got %i instead !", interpolation); - REQUIRE_TRUE(axisArrRank <= inputArrRank, 0, "PERCENTILE OP: the rank of axis array must be <= rank of input array, but got %i and %i correspondingly !", axisArrRank, inputArrRank); - - for(int i = 0; i < axisArrRank; ++i) { - int dim = INT_ARG(i) >= 0 ? INT_ARG(i) : INT_ARG(i) + inputArrRank; - REQUIRE_TRUE(dim < inputArrRank, 0, "PERCENTILE OP: element (dimension) of axis array at position %i is >= rank of input array (%i >= %i), which is unacceptable !", i, dim, inputArrRank); - } - - auto axises = block.getIArguments(); - helpers::percentile(block.launchContext(), *input, *output, axises, q, interpolation); - - return Status::OK(); +namespace ops { + +CUSTOM_OP_IMPL(percentile, 1, 1, false, 1, -2) { + auto input = INPUT_VARIABLE(0); // tensor with rank > 0 + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + const auto q = T_ARG(0); // percentile + const int interpolation = + block.numT() > 1 ? T_ARG(1) + : 2.; // 0-"lower", 1-"higher", 2-"nearest"(default) + const int keepDims = block.numT() > 2 ? T_ARG(2) : 0.; // false is default + + const int axisArrRank = block.numI(); + const int inputArrRank = input->rankOf(); + + REQUIRE_TRUE(inputArrRank > 0, 0, + "PERCENTILE OP: rank of input array must be positive (>0), but " + "got %i instead !", + inputArrRank); + REQUIRE_TRUE(0.f <= q && q <= 100.f, 0, + "PERCENTILE OP: percentile parameter must be within [0, 100] " + "range, but got %f instead !", + q); + REQUIRE_TRUE(interpolation == 0 || interpolation == 1 || interpolation == 2, + 0, + "PERCENTILE OP: the correct values for interpolation parameter " + "are 0, 1, 2, but got %i instead !", + interpolation); + REQUIRE_TRUE(axisArrRank <= inputArrRank, 0, + "PERCENTILE OP: the rank of axis array must be <= rank of input " + "array, but got %i and %i correspondingly !", + axisArrRank, inputArrRank); + + for (int i = 0; i < axisArrRank; ++i) { + int dim = INT_ARG(i) >= 0 ? INT_ARG(i) : INT_ARG(i) + inputArrRank; + REQUIRE_TRUE( + dim < inputArrRank, 0, + "PERCENTILE OP: element (dimension) of axis array at position %i is >= " + "rank of input array (%i >= %i), which is unacceptable !", + i, dim, inputArrRank); + } + + auto axises = block.getIArguments(); + helpers::percentile(block.launchContext(), *input, *output, axises, q, + interpolation); + + return Status::OK(); } - DECLARE_TYPES(percentile) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setSameMode(true); - } - +DECLARE_TYPES(percentile) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT) + ->setSameMode(true); +} DECLARE_SHAPE_FN(percentile) { - auto inputShapeInfo = inputShape->at(0); - const int keepDims = block.numT() > 2 ? T_ARG(2) : 0.; // false is default - - const int axisArrRank = block.numI(); - const int inputArrRank = inputShapeInfo[0]; - - REQUIRE_TRUE(inputArrRank > 0, 0, "PERCENTILE OP: rank of input array must be positive (>0), but got %i instead !", inputArrRank); - REQUIRE_TRUE(axisArrRank <= inputArrRank, 0, "PERCENTILE OP: the rank of axis array must be <= rank of input array, but got %i and %i correspondingly !", axisArrRank, inputArrRank); - - for(int i = 0; i < axisArrRank; ++i) { - int dim = INT_ARG(i) >= 0 ? INT_ARG(i) : INT_ARG(i) + inputArrRank; - REQUIRE_TRUE(dim < inputArrRank, 0, "PERCENTILE OP: element (dimension) of axis array at position %i is >= rank of input array (%i >= %i), which is unacceptable !", i, dim, inputArrRank); - } - - auto axises = block.getIArguments(); - auto outputShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShapeInfo), axises, inputShapeInfo, keepDims, false, block.workspace()); - - return SHAPELIST(outputShapeInfo); + auto inputShapeInfo = inputShape->at(0); + const int keepDims = block.numT() > 2 ? T_ARG(2) : 0.; // false is default + + const int axisArrRank = block.numI(); + const int inputArrRank = inputShapeInfo[0]; + + REQUIRE_TRUE(inputArrRank > 0, 0, + "PERCENTILE OP: rank of input array must be positive (>0), but " + "got %i instead !", + inputArrRank); + REQUIRE_TRUE(axisArrRank <= inputArrRank, 0, + "PERCENTILE OP: the rank of axis array must be <= rank of input " + "array, but got %i and %i correspondingly !", + axisArrRank, inputArrRank); + + for (int i = 0; i < axisArrRank; ++i) { + int dim = INT_ARG(i) >= 0 ? INT_ARG(i) : INT_ARG(i) + inputArrRank; + REQUIRE_TRUE( + dim < inputArrRank, 0, + "PERCENTILE OP: element (dimension) of axis array at position %i is >= " + "rank of input array (%i >= %i), which is unacceptable !", + i, dim, inputArrRank); + } + + auto axises = block.getIArguments(); + auto outputShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(inputShapeInfo), axises, inputShapeInfo, keepDims, false, + block.workspace()); + + return SHAPELIST(outputShapeInfo); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp index ebbf625011ba..85283fac93a7 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/pow.cpp @@ -22,106 +22,112 @@ #include #if NOT_EXCLUDED(OP_Pow) -#include #include +#include namespace sd { namespace ops { - BROADCASTABLE_OP_IMPL(Pow, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - //REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); - - auto tZ = BroadcastHelper::broadcastApply({scalar::Pow, pairwise::Pow, broadcast::Pow}, x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_TYPES(Pow) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}); - } - - CUSTOM_OP_IMPL(Pow_bp, 3, 2, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto dLdz = INPUT_VARIABLE(2); - - auto dLdx = OUTPUT_VARIABLE(0); - auto dLdy = OUTPUT_VARIABLE(1); - - const Nd4jLong* dLdzShapeInfo = nullptr; - const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->shapeInfo(), y->shapeInfo(), true, dLdzShapeInfo, block.workspace()); - REQUIRE_TRUE(areShapesBroadcastable, 0, "POW_BP OP: the shapes of x %s" - " and y %s are not suitable for broadcast !", - ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - REQUIRE_TRUE(shape::equalsSoft(dLdz->shapeInfo(), dLdzShapeInfo), 0, +BROADCASTABLE_OP_IMPL(Pow, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + // REQUIRE_TRUE(!y->isB(), 0, "Pairwise OP: you can't divide by bool array!"); + + auto tZ = BroadcastHelper::broadcastApply( + {scalar::Pow, pairwise::Pow, broadcast::Pow}, x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} + +DECLARE_TYPES(Pow) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}); +} + +CUSTOM_OP_IMPL(Pow_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto dLdz = INPUT_VARIABLE(2); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdy = OUTPUT_VARIABLE(1); + + const Nd4jLong* dLdzShapeInfo = nullptr; + const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo( + x->shapeInfo(), y->shapeInfo(), true, dLdzShapeInfo, block.workspace()); + REQUIRE_TRUE(areShapesBroadcastable, 0, + "POW_BP OP: the shapes of x %s" + " and y %s are not suitable for broadcast !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + REQUIRE_TRUE(shape::equalsSoft(dLdz->shapeInfo(), dLdzShapeInfo), 0, "POW_BP OP: wrong shape of next epsilon array (dLdOut)," - " expected is %s, but got %s instead !", - ShapeUtils::shapeAsString(dLdzShapeInfo).c_str(), ShapeUtils::shapeAsString(dLdz).c_str()); - - // dL/dy = x^y * log(x) * dL/dz - auto temp = x->applyTrueBroadcast(BroadcastOpsTuple::Pow(), *y); // a = x^y - x->applyTransform(transform::Log, *dLdx); // b = log(x) - dLdx->applyScalar(sd::scalar::ReplaceNans, 0, *dLdx); - temp *= *dLdx; // c = b*a - temp *= *dLdz; // dL/dy = c * dL/dz - if (dLdy->isSameShape(*dLdz)) { - dLdy->assign(temp); - } - else { - std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), dLdz->shapeInfo()); - dLdy->assign(temp.reduceAlongDimension(reduce::Sum, axesForY)); // dL/dy = sum(c * dL/dz) - } - - // dL/dx = y*x^(y-1) * dL/dz - x->applyTrueBroadcast(BroadcastOpsTuple::PowDerivative(), *y, temp); // a = y*x^(y-1) - temp *= *dLdz; // dLdx = a*dL/dz - - if (dLdx->isSameShape(*dLdz)) { - dLdx->assign(temp); // dLdx = a*dL/dz - } - else { - std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), dLdz->shapeInfo()); - dLdx->assign(temp.reduceAlongDimension(reduce::Sum, axesForX)); // dLdx = a*dL/dz - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(Pow_bp) { - - auto xShapeInfo = inputShape->at(0); - auto yShapeInfo = inputShape->at(1); - - Nd4jLong* dLdxShapeInfo = nullptr; - Nd4jLong* dLdyShapeInfo = nullptr; - - COPY_SHAPE(xShapeInfo, dLdxShapeInfo); - COPY_SHAPE(yShapeInfo, dLdyShapeInfo); - - return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdyShapeInfo)); - } - - DECLARE_TYPES(Pow_bp) { - getOpDescriptor() - ->setAllowedInputTypes({ ALL_FLOATS, ALL_INTS }) - ->setAllowedOutputTypes({ ALL_FLOATS }); - } + " expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(dLdzShapeInfo).c_str(), + ShapeUtils::shapeAsString(dLdz).c_str()); + + // dL/dy = x^y * log(x) * dL/dz + auto temp = x->applyTrueBroadcast(BroadcastOpsTuple::Pow(), *y); // a = x^y + x->applyTransform(transform::Log, *dLdx); // b = log(x) + dLdx->applyScalar(sd::scalar::ReplaceNans, 0, *dLdx); + temp *= *dLdx; // c = b*a + temp *= *dLdz; // dL/dy = c * dL/dz + if (dLdy->isSameShape(*dLdz)) { + dLdy->assign(temp); + } else { + std::vector axesForY = ShapeUtils::evalBroadcastBackwardAxis( + y->shapeInfo(), dLdz->shapeInfo()); + dLdy->assign(temp.reduceAlongDimension( + reduce::Sum, axesForY)); // dL/dy = sum(c * dL/dz) + } + + // dL/dx = y*x^(y-1) * dL/dz + x->applyTrueBroadcast(BroadcastOpsTuple::PowDerivative(), *y, + temp); // a = y*x^(y-1) + temp *= *dLdz; // dLdx = a*dL/dz + if (dLdx->isSameShape(*dLdz)) { + dLdx->assign(temp); // dLdx = a*dL/dz + } else { + std::vector axesForX = ShapeUtils::evalBroadcastBackwardAxis( + x->shapeInfo(), dLdz->shapeInfo()); + dLdx->assign( + temp.reduceAlongDimension(reduce::Sum, axesForX)); // dLdx = a*dL/dz + } + + return Status::OK(); } + +DECLARE_SHAPE_FN(Pow_bp) { + auto xShapeInfo = inputShape->at(0); + auto yShapeInfo = inputShape->at(1); + + Nd4jLong* dLdxShapeInfo = nullptr; + Nd4jLong* dLdyShapeInfo = nullptr; + + COPY_SHAPE(xShapeInfo, dLdxShapeInfo); + COPY_SHAPE(yShapeInfo, dLdyShapeInfo); + + return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdyShapeInfo)); } +DECLARE_TYPES(Pow_bp) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp index 3691ffb55e6f..e8ca145ea3ed 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp @@ -21,119 +21,122 @@ #include #if NOT_EXCLUDED(OP_realdiv) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(realdiv, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::Divide(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - DECLARE_SYN(RealDiv, realdiv); - - DECLARE_TYPES(realdiv) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType::HALF, DataType::DOUBLE}); - } - - DECLARE_TYPES(realdiv_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +namespace ops { +BROADCASTABLE_OP_IMPL(realdiv, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = + BroadcastHelper::broadcastApply(BroadcastOpsTuple::Divide(), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} +DECLARE_SYN(RealDiv, realdiv); + +DECLARE_TYPES(realdiv) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType::HALF, DataType::DOUBLE}); +} - CUSTOM_OP_IMPL(realdiv_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); +DECLARE_TYPES(realdiv_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); +CUSTOM_OP_IMPL(realdiv_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); - if (x->isSameShape(y)) { - // PWT case case + if (x->isSameShape(y)) { + // PWT case case - // X gradient - //epsNext->applyPairwiseLambda(y, lambdaX, gradX); - epsNext->applyPairwiseTransform(pairwise::Divide, *y, *gradX); + // X gradient + // epsNext->applyPairwiseLambda(y, lambdaX, gradX); + epsNext->applyPairwiseTransform(pairwise::Divide, *y, *gradX); - // Y gradient - //epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); + // Y gradient + // epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); - gradY->assign((*epsNext) * -(*x) / ((*y) * (*y))); + gradY->assign((*epsNext) * -(*x) / ((*y) * (*y))); - } else if (y->isScalar()) { - // scalar case + } else if (y->isScalar()) { + // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - auto tmpX = x->reduceNumber(reduce::Sum); - gradY->assign(tmp * -tmpX / ((*y) * (*y))); + auto tmp = epsNext->reduceNumber(reduce::Sum); + auto tmpX = x->reduceNumber(reduce::Sum); + gradY->assign(tmp * -tmpX / ((*y) * (*y))); - //epsNext->applyLambda(lambdaS, gradX); - epsNext->applyScalarArr(scalar::Divide, *y, *gradX); - } else { - // broadcast case + // epsNext->applyLambda(lambdaS, gradX); + epsNext->applyScalarArr(scalar::Divide, *y, *gradX); + } else { + // broadcast case - auto preX = *epsNext / *y; + auto preX = *epsNext / *y; - NDArray negX(*x); - x->applyTransform(transform::Neg, negX); - auto preY = *epsNext * negX / ((*y) * (*y)); + NDArray negX(*x); + x->applyTransform(transform::Neg, negX); + auto preY = *epsNext * negX / ((*y) * (*y)); - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(realdiv_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); +DECLARE_SHAPE_FN(realdiv_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); - // eps always has shape of x - // grad always has shape of y + // eps always has shape of x + // grad always has shape of y - Nd4jLong *shapeE; - Nd4jLong *shapeG; + Nd4jLong *shapeE; + Nd4jLong *shapeG; - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); - auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - return shapeList; - } - } + return shapeList; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp index 0b6ea7d2a00a..6497a9568298 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_divide.cpp @@ -21,108 +21,111 @@ #include #if NOT_EXCLUDED(OP_reversedivide) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(reversedivide, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - REQUIRE_TRUE(!x->isB(), 0, "REVERSEDIVIDE OP: you can't divide by bool array!"); - x->applyTrueBroadcast(BROADCAST(ReverseDivide), *y, *z, true); - - return Status::OK(); - } - DECLARE_SYN(RDiv, reversedivide); - - DECLARE_TYPES(reversedivide) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - DECLARE_TYPES(reversedivide_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(reversedivide_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - if (x->isSameShape(y)) { - // PWT case case - - // X gradient - //epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); - gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); - gradX->applyTransform(transform::Neg, *gradX); - // Y gradient - //epsNext->applyPairwiseLambda(x, lambdaY, gradY); - gradY->assign((*epsNext) / (*x)); - } else if (y->isScalar()) { - // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - auto tmpX = x->reduceNumber(reduce::Sum); - gradY->assign(tmp / tmpX); - - gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); - gradX->applyTransform(transform::Neg, *gradX); - } else { - // broadcast case - - auto preY = (*epsNext) / (*x); - - auto preX = *epsNext * (*y) / ((*x) * (*x)); - preX.applyTransform(transform::Neg, preX); - - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); - - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(reversedivide_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } - } +namespace ops { +BROADCASTABLE_OP_IMPL(reversedivide, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + REQUIRE_TRUE(!x->isB(), 0, + "REVERSEDIVIDE OP: you can't divide by bool array!"); + x->applyTrueBroadcast(BROADCAST(ReverseDivide), *y, *z, true); + + return Status::OK(); +} +DECLARE_SYN(RDiv, reversedivide); + +DECLARE_TYPES(reversedivide) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +DECLARE_TYPES(reversedivide_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(reversedivide_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + if (x->isSameShape(y)) { + // PWT case case + + // X gradient + // epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); + gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); + gradX->applyTransform(transform::Neg, *gradX); + // Y gradient + // epsNext->applyPairwiseLambda(x, lambdaY, gradY); + gradY->assign((*epsNext) / (*x)); + } else if (y->isScalar()) { + // scalar case + auto tmp = epsNext->reduceNumber(reduce::Sum); + auto tmpX = x->reduceNumber(reduce::Sum); + gradY->assign(tmp / tmpX); + + gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); + gradX->applyTransform(transform::Neg, *gradX); + } else { + // broadcast case + + auto preY = (*epsNext) / (*x); + + auto preX = *epsNext * (*y) / ((*x) * (*x)); + preX.applyTransform(transform::Neg, preX); + + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); + + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(reversedivide_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp index bb25fada6cbf..4203ab9f0b1b 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp @@ -21,76 +21,75 @@ #include #if NOT_EXCLUDED(OP_reversemod) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(reversemod, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(ReverseMod), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - - DECLARE_TYPES(reversemod) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } +namespace ops { +BROADCASTABLE_OP_IMPL(reversemod, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = BroadcastHelper::broadcastApply(BROADCAST(ReverseMod), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} - DECLARE_TYPES(reversemod_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(reversemod) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} +DECLARE_TYPES(reversemod_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - CUSTOM_OP_IMPL(reversemod_bp, 3, 2, false, 0, 0) { - // PLEASE NOTE: we're just passing eps down the line here - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); +CUSTOM_OP_IMPL(reversemod_bp, 3, 2, false, 0, 0) { + // PLEASE NOTE: we're just passing eps down the line here + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); - gradY->assign(0.0f); - gradX->assign(0.0f); + gradY->assign(0.0f); + gradX->assign(0.0f); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(reversemod_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); +DECLARE_SHAPE_FN(reversemod_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); - // eps always has shape of x - // grad always has shape of y + // eps always has shape of x + // grad always has shape of y - Nd4jLong *shapeE; - Nd4jLong *shapeG; + Nd4jLong *shapeE; + Nd4jLong *shapeG; - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); - auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - return shapeList; - } - } + return shapeList; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp index 5d33c7cea09d..c3cbe232c95d 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp @@ -21,101 +21,104 @@ #include #if NOT_EXCLUDED(OP_reversesubtract) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(reversesubtract, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(ReverseSubtract), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - DECLARE_SYN(RSub, reversesubtract); - - DECLARE_TYPES(reversesubtract) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - CUSTOM_OP_IMPL(reversesubtract_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - if (x->isSameShape(y)) { - // PWT case case - epsNext->applyTransform(transform::Neg, *gradX); - gradY->assign(epsNext); - } else if (y->isScalar()) { - // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - gradY->assign(tmp); - epsNext->applyTransform(transform::Neg, *gradX); - } else { - // broadcastable - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX); - sum.applyTransform(transform::Neg, *gradX); - } else { - epsNext->applyTransform(transform::Neg, *gradX); - } - - if (axisY.size() > 0) { - auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else { - gradY->assign(epsNext); - } - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(reversesubtract_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - - return shapeList; - } - - DECLARE_TYPES(reversesubtract_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +namespace ops { +BROADCASTABLE_OP_IMPL(reversesubtract, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = + BroadcastHelper::broadcastApply(BROADCAST(ReverseSubtract), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} +DECLARE_SYN(RSub, reversesubtract); + +DECLARE_TYPES(reversesubtract) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +CUSTOM_OP_IMPL(reversesubtract_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + if (x->isSameShape(y)) { + // PWT case case + epsNext->applyTransform(transform::Neg, *gradX); + gradY->assign(epsNext); + } else if (y->isScalar()) { + // scalar case + auto tmp = epsNext->reduceNumber(reduce::Sum); + gradY->assign(tmp); + epsNext->applyTransform(transform::Neg, *gradX); + } else { + // broadcastable + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX); + sum.applyTransform(transform::Neg, *gradX); + } else { + epsNext->applyTransform(transform::Neg, *gradX); + } + + if (axisY.size() > 0) { + auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else { + gradY->assign(epsNext); } + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(reversesubtract_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); + + return shapeList; +} + +DECLARE_TYPES(reversesubtract_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp index 6f5482512227..0b619d707a35 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp @@ -21,137 +21,138 @@ #include #if NOT_EXCLUDED(OP_squaredsubtract) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(squaredsubtract, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(SquaredSubtract), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - DECLARE_SYN(squareddifference, squaredsubtract); - - DECLARE_TYPES(squaredsubtract) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - CUSTOM_OP_IMPL(squaredsubtract_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - /* - auto lambdaX = LAMBDA_TTT(_e, _x, _y) { - return _e * (T) 2.0 * (_x - _y) ; - }; - - auto lambdaY = LAMBDA_TTT(_e, _x, _y) { - return _e * (T) 2.0 * (_y - _x); - }; - */ - - auto ts = NDArrayFactory::create(x->dataType(), 2, block.launchContext()); - - - if (x->isSameShape(y)) { - // PWT case case - - // X gradient - //epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); - gradX->assign((*epsNext) * ts * ((*x) - (*y))); - - // Y gradient - //epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); - gradY->assign((*epsNext) * ts * ((*y) - (*x))); - - } else if (y->isScalar()) { - // scalar case - auto tmpX = x->reduceNumber(reduce::Sum); - gradY->assign(tmpX); - - //epsNext->applyPairwiseLambda(x, lambdaS, gradX); - gradX->assign((*epsNext) * ts * ((*x) - (*y))); - } else { - // broadcast case - - auto preX = x->dup(); - auto preY = y->dup(); - - auto targetShape = epsNext->getShapeAsVector(); - - preX.tileToShape(targetShape, preX); - preY.tileToShape(targetShape, preY); - - - //epsNext->applyTriplewiseLambda(x, y, lambdaX, preX); - //epsNext->applyTriplewiseLambda(x, y, lambdaY, preY); - auto resX = (*epsNext) * ts * ((*x) - (*y)); - preX.assign(resX); - auto resY = (*epsNext) * ts * ((*y) - (*x)); - preY.assign(resY); - - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); - - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } +namespace ops { +BROADCASTABLE_OP_IMPL(squaredsubtract, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = + BroadcastHelper::broadcastApply(BROADCAST(SquaredSubtract), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} +DECLARE_SYN(squareddifference, squaredsubtract); - return Status::OK(); - } +DECLARE_TYPES(squaredsubtract) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} - DECLARE_SHAPE_FN(squaredsubtract_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); +CUSTOM_OP_IMPL(squaredsubtract_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + /* + auto lambdaX = LAMBDA_TTT(_e, _x, _y) { + return _e * (T) 2.0 * (_x - _y) ; + }; + + auto lambdaY = LAMBDA_TTT(_e, _x, _y) { + return _e * (T) 2.0 * (_y - _x); + }; + */ + + auto ts = NDArrayFactory::create(x->dataType(), 2, block.launchContext()); + + if (x->isSameShape(y)) { + // PWT case case + + // X gradient + // epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); + gradX->assign((*epsNext) * ts * ((*x) - (*y))); + + // Y gradient + // epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); + gradY->assign((*epsNext) * ts * ((*y) - (*x))); + + } else if (y->isScalar()) { + // scalar case + auto tmpX = x->reduceNumber(reduce::Sum); + gradY->assign(tmpX); + + // epsNext->applyPairwiseLambda(x, lambdaS, gradX); + gradX->assign((*epsNext) * ts * ((*x) - (*y))); + } else { + // broadcast case + + auto preX = x->dup(); + auto preY = y->dup(); + + auto targetShape = epsNext->getShapeAsVector(); + + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); + + // epsNext->applyTriplewiseLambda(x, y, lambdaX, preX); + // epsNext->applyTriplewiseLambda(x, y, lambdaY, preY); + auto resX = (*epsNext) * ts * ((*x) - (*y)); + preX.assign(resX); + auto resY = (*epsNext) * ts * ((*y) - (*x)); + preY.assign(resY); + + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); + + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } + + return Status::OK(); +} - // eps always has shape of x - // grad always has shape of y +DECLARE_SHAPE_FN(squaredsubtract_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); - Nd4jLong *shapeE; - Nd4jLong *shapeG; + // eps always has shape of x + // grad always has shape of y - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); + Nd4jLong *shapeE; + Nd4jLong *shapeG; - return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - } + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); - DECLARE_TYPES(squaredsubtract_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } + return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); +} - } +DECLARE_TYPES(squaredsubtract_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp index fac1c5dfaed9..1b0d45440794 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp @@ -21,102 +21,104 @@ #include #if NOT_EXCLUDED(OP_subtract) -#include #include +#include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(subtract, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - BROADCAST_CHECK_EMPTY(x,y,z); - - auto tZ = BroadcastHelper::broadcastApply(BroadcastOpsTuple::Subtract(), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } - - return Status::OK(); - } - DECLARE_SYN(Sub, subtract); - DECLARE_SYN(sub, subtract); - - DECLARE_TYPES(subtract) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - - CUSTOM_OP_IMPL(subtract_bp, 3, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto epsNext = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - if (x->isSameShape(y)) { - // PWT case case - epsNext->applyTransform(transform::Neg, *gradY); - gradX->assign(epsNext); - } else if (y->isScalar()) { - // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - gradY->assign(-tmp); - gradX->assign(epsNext); - } else { - // broadcastable - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(epsNext); - - if (axisY.size() > 0) { - auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY); - sum.applyTransform(transform::Neg, *gradY); - } else { - epsNext->applyTransform(transform::Neg, *gradY); - } - } - - return Status::OK(); - } - - DECLARE_TYPES(subtract_bp) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - - DECLARE_SHAPE_FN(subtract_bp) { - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto e = inputShape->at(2); - - // eps always has shape of x - // grad always has shape of y - - Nd4jLong *shapeE; - Nd4jLong *shapeG; - - COPY_SHAPE(x, shapeE); - COPY_SHAPE(y, shapeG); - - auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); - - return shapeList; - } +namespace ops { +BROADCASTABLE_OP_IMPL(subtract, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + BROADCAST_CHECK_EMPTY(x, y, z); + + auto tZ = + BroadcastHelper::broadcastApply(BroadcastOpsTuple::Subtract(), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } + + return Status::OK(); +} +DECLARE_SYN(Sub, subtract); +DECLARE_SYN(sub, subtract); + +DECLARE_TYPES(subtract) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} + +CUSTOM_OP_IMPL(subtract_bp, 3, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto epsNext = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + if (x->isSameShape(y)) { + // PWT case case + epsNext->applyTransform(transform::Neg, *gradY); + gradX->assign(epsNext); + } else if (y->isScalar()) { + // scalar case + auto tmp = epsNext->reduceNumber(reduce::Sum); + gradY->assign(-tmp); + gradX->assign(epsNext); + } else { + // broadcastable + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(epsNext); + + if (axisY.size() > 0) { + auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY); + sum.applyTransform(transform::Neg, *gradY); + } else { + epsNext->applyTransform(transform::Neg, *gradY); } + } + + return Status::OK(); +} + +DECLARE_TYPES(subtract_bp) { + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +DECLARE_SHAPE_FN(subtract_bp) { + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto e = inputShape->at(2); + + // eps always has shape of x + // grad always has shape of y + + Nd4jLong *shapeE; + Nd4jLong *shapeG; + + COPY_SHAPE(x, shapeE); + COPY_SHAPE(y, shapeG); + + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); + + return shapeList; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/truncatediv.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/truncatediv.cpp index 60900a5d9644..dbd86e537210 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/truncatediv.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/truncatediv.cpp @@ -22,29 +22,29 @@ #include namespace sd { - namespace ops { - BROADCASTABLE_OP_IMPL(truncatediv, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +BROADCASTABLE_OP_IMPL(truncatediv, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - BROADCAST_CHECK_EMPTY(x,y,z); + BROADCAST_CHECK_EMPTY(x, y, z); - auto tZ = BroadcastHelper::broadcastApply(BROADCAST(TruncateDiv), x, y, z); - if (tZ == nullptr) - return ND4J_STATUS_KERNEL_FAILURE; - else if (tZ != z) { - OVERWRITE_RESULT(tZ); - } + auto tZ = BroadcastHelper::broadcastApply(BROADCAST(TruncateDiv), x, y, z); + if (tZ == nullptr) + return ND4J_STATUS_KERNEL_FAILURE; + else if (tZ != z) { + OVERWRITE_RESULT(tZ); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(truncatediv) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT); - } - } -} \ No newline at end of file +DECLARE_TYPES(truncatediv) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp b/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp index 95dbdfcea05d..f5c83975ef7b 100644 --- a/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp +++ b/libnd4j/include/ops/declarable/generic/compat/compat_sparse_to_dense.cpp @@ -25,49 +25,52 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(compat_sparse_to_dense, 4, 1, false, 0, 0) { - auto indices = INPUT_VARIABLE(0); - auto shape = INPUT_VARIABLE(1); - auto values = INPUT_VARIABLE(2); - NDArray *def = nullptr; +namespace ops { +CUSTOM_OP_IMPL(compat_sparse_to_dense, 4, 1, false, 0, 0) { + auto indices = INPUT_VARIABLE(0); + auto shape = INPUT_VARIABLE(1); + auto values = INPUT_VARIABLE(2); + NDArray *def = nullptr; - auto output = OUTPUT_NULLIFIED(0); + auto output = OUTPUT_NULLIFIED(0); - if (block.width() > 3) - def = INPUT_VARIABLE(3); + if (block.width() > 3) def = INPUT_VARIABLE(3); - sd::ops::helpers::compat_sparse_to_dense(*values, *indices, def, *output); + sd::ops::helpers::compat_sparse_to_dense(*values, *indices, def, *output); - return Status::OK(); - }; + return Status::OK(); +}; - DECLARE_SHAPE_FN(compat_sparse_to_dense) { - auto indices = INPUT_VARIABLE(0); - auto shape = INPUT_VARIABLE(1); - auto values = INPUT_VARIABLE(2); +DECLARE_SHAPE_FN(compat_sparse_to_dense) { + auto indices = INPUT_VARIABLE(0); + auto shape = INPUT_VARIABLE(1); + auto values = INPUT_VARIABLE(2); - if (block.width() > 3) { - auto def = INPUT_VARIABLE(3); + if (block.width() > 3) { + auto def = INPUT_VARIABLE(3); - REQUIRE_TRUE(def->dataType() == values->dataType() && def->isScalar(), 0, "compat_sparse_to_dense: default value must be a scalar of the same data type as actual values") - }; + REQUIRE_TRUE(def->dataType() == values->dataType() && def->isScalar(), 0, + "compat_sparse_to_dense: default value must be a scalar of " + "the same data type as actual values") + }; - auto dtype = values->dataType(); + auto dtype = values->dataType(); - // basically output shape is defined by the type of input, and desired shape input - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape->getBufferAsVector())); - } + // basically output shape is defined by the type of input, and desired shape + // input + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + dtype, 'c', shape->getBufferAsVector())); +} - DECLARE_TYPES(compat_sparse_to_dense) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) // indices - ->setAllowedInputTypes(1, {ALL_INTS}) // shape - ->setAllowedInputTypes(2,sd::DataType::ANY) // sparse values - ->setAllowedInputTypes(3,sd::DataType::ANY) // default value - ->setAllowedOutputTypes(sd::DataType::ANY); - } - } +DECLARE_TYPES(compat_sparse_to_dense) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) // indices + ->setAllowedInputTypes(1, {ALL_INTS}) // shape + ->setAllowedInputTypes(2, sd::DataType::ANY) // sparse values + ->setAllowedInputTypes(3, sd::DataType::ANY) // default value + ->setAllowedOutputTypes(sd::DataType::ANY); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp index 40e080a8fc63..c006429d9108 100644 --- a/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp +++ b/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp @@ -14,126 +14,134 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author raver119@gmail.com - // +// +// @author raver119@gmail.com +// #include #if NOT_EXCLUDED(OP_split_string) -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(compat_string_split, 2, 2, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto delim = INPUT_VARIABLE(1); +namespace ops { +CUSTOM_OP_IMPL(compat_string_split, 2, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + auto indices = OUTPUT_NULLIFIED(0); + auto values = OUTPUT_VARIABLE(1); + + auto d = delim->e(0); + + input->syncToHost(); + delim->syncToHost(); + + // output rank N+1 wrt input rank + std::vector icoords(input->rankOf()); + + // getting buffer lengths + // FIXME: it'll be bigger, since it'll include delimiters, + auto outputLength = StringUtils::byteLength(*input); - auto indices = OUTPUT_NULLIFIED(0); - auto values = OUTPUT_VARIABLE(1); - - auto d = delim->e(0); - - input->syncToHost(); - delim->syncToHost(); - - // output rank N+1 wrt input rank - std::vector icoords(input->rankOf()); - - // getting buffer lengths - // FIXME: it'll be bigger, since it'll include delimiters, - auto outputLength = StringUtils::byteLength(*input); - - uint64_t ss = 0L; - Nd4jLong ic = 0L; - // loop through each string within tensor - for (auto e = 0L; e < input->lengthOf(); e++) { - // now we should map substring to indices - auto s = input->e(e); - - // getting base index - shape::index2coordsCPU(0, e, input->shapeInfo(), icoords.data()); - - // getting number of substrings - auto cnt = StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1; - - // filling output indices - for (uint64_t f = 0; f < cnt; f++) { - for (auto v : icoords) - indices->p(ic++, v); - - // last index - indices->p(ic++, f); - } - - ss += cnt; - } - - // process strings now - std::vector strings; - for (auto e = 0L; e < input->lengthOf(); e++) { - auto split = StringUtils::split(input->e(e), d); - - for (const auto& s : split) - strings.emplace_back(s); - } - - // now once we have all strings in single vector time to fill - auto tmp = NDArrayFactory::string({ (Nd4jLong)strings.size() }, strings, input->dataType(), block.launchContext()); - auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size()); - - // for CUDA mostly - values->dataBuffer()->allocatePrimary(); - values->dataBuffer()->expand(blen); - memcpy(values->buffer(), tmp.buffer(), blen); - values->tickWriteHost(); - - // special case, for future use - indices->syncToDevice(); - values->syncToDevice(); - - // we have to tick buffers - values->dataBuffer()->writePrimary(); - values->dataBuffer()->readSpecial(); - - return Status::OK(); - }; - - DECLARE_SHAPE_FN(compat_string_split) { - auto input = INPUT_VARIABLE(0); - auto delim = INPUT_VARIABLE(1); - - auto d = delim->e(0); - - // count number of delimiter substrings in all strings within input tensor - uint64_t cnt = 0; - for (auto e = 0L; e < input->lengthOf(); e++) { - // FIXME: bad, not UTF-compatible - auto s = input->e(e); - - // each substring we see in haystack, splits string in two parts. so we should add 1 to the number of subarrays - cnt += StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1; - } - - // shape calculations - // virtual tensor rank will be N+1, for N rank input array, where data will be located at the biggest dimension - // values tensor is going to be vector always - // indices tensor is going to be vector with length equal to values.length * output rank - - auto valuesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(cnt, sd::DataType::UTF8); - auto indicesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(cnt * (input->rankOf() + 1), sd::DataType::INT64); - - return SHAPELIST(indicesShape, valuesShape); - } + uint64_t ss = 0L; + Nd4jLong ic = 0L; + // loop through each string within tensor + for (auto e = 0L; e < input->lengthOf(); e++) { + // now we should map substring to indices + auto s = input->e(e); - DECLARE_TYPES(compat_string_split) { - getOpDescriptor() - ->setAllowedInputTypes({ ALL_STRINGS }) - ->setAllowedOutputTypes(0, { ALL_INDICES }) - ->setAllowedOutputTypes(1, { ALL_STRINGS }); - } + // getting base index + shape::index2coordsCPU(0, e, input->shapeInfo(), icoords.data()); + + // getting number of substrings + auto cnt = StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), + d.length()) + + 1; + + // filling output indices + for (uint64_t f = 0; f < cnt; f++) { + for (auto v : icoords) indices->p(ic++, v); + + // last index + indices->p(ic++, f); } + + ss += cnt; + } + + // process strings now + std::vector strings; + for (auto e = 0L; e < input->lengthOf(); e++) { + auto split = StringUtils::split(input->e(e), d); + + for (const auto& s : split) strings.emplace_back(s); + } + + // now once we have all strings in single vector time to fill + auto tmp = NDArrayFactory::string({(Nd4jLong)strings.size()}, strings, + input->dataType(), block.launchContext()); + auto blen = StringUtils::byteLength(tmp) + + ShapeUtils::stringBufferHeaderRequirements(strings.size()); + + // for CUDA mostly + values->dataBuffer()->allocatePrimary(); + values->dataBuffer()->expand(blen); + memcpy(values->buffer(), tmp.buffer(), blen); + values->tickWriteHost(); + + // special case, for future use + indices->syncToDevice(); + values->syncToDevice(); + + // we have to tick buffers + values->dataBuffer()->writePrimary(); + values->dataBuffer()->readSpecial(); + + return Status::OK(); +}; + +DECLARE_SHAPE_FN(compat_string_split) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + auto d = delim->e(0); + + // count number of delimiter substrings in all strings within input tensor + uint64_t cnt = 0; + for (auto e = 0L; e < input->lengthOf(); e++) { + // FIXME: bad, not UTF-compatible + auto s = input->e(e); + + // each substring we see in haystack, splits string in two parts. so we + // should add 1 to the number of subarrays + cnt += StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), + d.length()) + + 1; + } + + // shape calculations + // virtual tensor rank will be N+1, for N rank input array, where data will be + // located at the biggest dimension values tensor is going to be vector always + // indices tensor is going to be vector with length equal to values.length * + // output rank + + auto valuesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + cnt, sd::DataType::UTF8); + auto indicesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + cnt * (input->rankOf() + 1), sd::DataType::INT64); + + return SHAPELIST(indicesShape, valuesShape); +} + +DECLARE_TYPES(compat_string_split) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_STRINGS}) + ->setAllowedOutputTypes(0, {ALL_INDICES}) + ->setAllowedOutputTypes(1, {ALL_STRINGS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/compression/bitmap.cpp b/libnd4j/include/ops/declarable/generic/compression/bitmap.cpp index 4b77e2a45b90..f50c25c9f39e 100644 --- a/libnd4j/include/ops/declarable/generic/compression/bitmap.cpp +++ b/libnd4j/include/ops/declarable/generic/compression/bitmap.cpp @@ -18,75 +18,78 @@ // @author George A. Shulinok // -#include #include #include +#include #if NOT_EXCLUDED(OP_decode_bitmap) namespace sd { - namespace ops { - CUSTOM_OP_IMPL(decode_bitmap, 2, 1, true, 0, 0) { - const auto encoded = INPUT_VARIABLE(1); - auto updates = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(decode_bitmap, 2, 1, true, 0, 0) { + const auto encoded = INPUT_VARIABLE(1); + auto updates = OUTPUT_VARIABLE(0); - helpers::decodeBitmap(block.launchContext(), encoded, updates); - return Status::OK(); - } + helpers::decodeBitmap(block.launchContext(), encoded, updates); + return Status::OK(); +} - DECLARE_SHAPE_FN(decode_bitmap) { - auto weights = INPUT_VARIABLE(0); +DECLARE_SHAPE_FN(decode_bitmap) { + auto weights = INPUT_VARIABLE(0); - return SHAPELIST(weights->shapeInfo()); - } + return SHAPELIST(weights->shapeInfo()); +} - DECLARE_TYPES(decode_bitmap) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, DataType::INT32) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(decode_bitmap) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, DataType::INT32) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif #if NOT_EXCLUDED(OP_encode_bitmap) namespace sd { - namespace ops { - CUSTOM_OP_IMPL(encode_bitmap, 1, 3, true, 1, 0) { - auto input = INPUT_VARIABLE(0); - auto encoded = OUTPUT_NULLIFIED(1); - auto counter = OUTPUT_NULLIFIED(2); +namespace ops { +CUSTOM_OP_IMPL(encode_bitmap, 1, 3, true, 1, 0) { + auto input = INPUT_VARIABLE(0); + auto encoded = OUTPUT_NULLIFIED(1); + auto counter = OUTPUT_NULLIFIED(2); - float threshold = T_ARG(0); + float threshold = T_ARG(0); - encoded->p(0, (int) input->lengthOf()); - encoded->p(1, (int) input->lengthOf()); - encoded->p(2, reinterpret_cast(&threshold)[0]); - encoded->p(3, 1); // flag for BITMAP_ENCODING + encoded->p(0, (int)input->lengthOf()); + encoded->p(1, (int)input->lengthOf()); + encoded->p(2, reinterpret_cast(&threshold)[0]); + encoded->p(3, 1); // flag for BITMAP_ENCODING - auto result = helpers::encodeBitmap(block.launchContext(), input, encoded, threshold); - counter->p(0, result); - counter->syncToDevice(); + auto result = + helpers::encodeBitmap(block.launchContext(), input, encoded, threshold); + counter->p(0, result); + counter->syncToDevice(); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(encode_bitmap) { - auto input = inputShape->at(0); +DECLARE_SHAPE_FN(encode_bitmap) { + auto input = inputShape->at(0); - auto outputLength = shape::length(input) / 16 + 5; - auto encodedShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(outputLength, DataType::INT32); - auto encodedCounter = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32); - return SHAPELIST(input, encodedShape, encodedCounter); - } + auto outputLength = shape::length(input) / 16 + 5; + auto encodedShape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + outputLength, DataType::INT32); + auto encodedCounter = + ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32); + return SHAPELIST(input, encodedShape, encodedCounter); +} - DECLARE_TYPES(encode_bitmap) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, DataType::INT32) - ->setAllowedInputTypes(2, DataType::INT32); - } - } +DECLARE_TYPES(encode_bitmap) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, DataType::INT32) + ->setAllowedInputTypes(2, DataType::INT32); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/compression/threshold.cpp b/libnd4j/include/ops/declarable/generic/compression/threshold.cpp index 9512621e817a..65b4b16d49f8 100644 --- a/libnd4j/include/ops/declarable/generic/compression/threshold.cpp +++ b/libnd4j/include/ops/declarable/generic/compression/threshold.cpp @@ -18,87 +18,99 @@ // @author raver119@gmail.com // -#include #include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(encode_threshold, 1, 2, true, 1, 0) { - auto x = INPUT_VARIABLE(0); - auto updated = OUTPUT_VARIABLE(0); - auto encoded = OUTPUT_NULLIFIED(1); - - float threshold = T_ARG(0); - - REQUIRE_TRUE(x->lengthOf() <= DataTypeUtils::max(), 0, "encode_threshold: gradients array must have length <= MAX_INT"); - REQUIRE_TRUE(encoded->lengthOf() >= 4, 0, "encode_threshold: array for encoded updates can't have less than 4 elements"); -// REQUIRE_TRUE(x->platformBuffer() == updated->platformBuffer(), 0, "encode_threshold: gradients array must be the same at input and output"); - - // filling header bytes - encoded->p(0, encoded->lengthOf() - 4); - encoded->p(1, (int) x->lengthOf()); - encoded->p(2, reinterpret_cast(&threshold)[0]); - encoded->p(3, 0); // flag for FLEXIBLE_ENCODING - - // if there's no updates to process - just skip execution - if (encoded->lengthOf() == 4) - return Status::OK(); - - helpers::thresholdEncode(*x, *encoded, threshold); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(encode_threshold) { - auto x = INPUT_VARIABLE(0); - // we have limit option here - int boundary = block.numI() > 0 ? I_ARG(0) : DataTypeUtils::max(); - float threshold = T_ARG(0); - - REQUIRE_TRUE(boundary >= 0, 0, "encode_threshold: boundary must be positive"); - REQUIRE_TRUE(x->lengthOf() <= DataTypeUtils::max(), 0, "encode_threshold: gradients array must have length <= MAX_INT"); - - // we must calculate number of elements that >= threshold - auto elements = sd::math::nd4j_min(helpers::thresholdEstimate(*x, threshold), boundary); - if (elements < 2) - elements = 0; - - // result array must have 4 additional int elements for header - return SHAPELIST(x->shapeInfo(), sd::ConstantShapeHelper::getInstance()->vectorShapeInfo(elements + 4, DataType::INT32)); - } - - DECLARE_TYPES(encode_threshold) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, DataType::INT32); - } - - CUSTOM_OP_IMPL(decode_threshold, 2, 1, true, 0, 0) { - auto weights = INPUT_VARIABLE(0); - auto encoded = INPUT_VARIABLE(1); - auto updates = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(encoded->lengthOf() >= 4, 0, "decode_threshold: encoded array can't have length < 4"); - REQUIRE_TRUE(updates->lengthOf() == encoded->e(1), 0, "decode_threshold: updates array must have length equal to [%i]", encoded->e(1)); - REQUIRE_TRUE(encoded->e(3) == 0, 0, "decode_threshold: encoded array doesn't look like threshold-encoded"); - - helpers::thresholdDecode(*encoded, *updates); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(decode_threshold) { - auto weights = inputShape->at(0); - return SHAPELIST(weights); - } - - DECLARE_TYPES(decode_threshold) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, DataType::INT32) - ->setAllowedOutputTypes(0,{ALL_FLOATS}); - } - } -} \ No newline at end of file +namespace ops { +CUSTOM_OP_IMPL(encode_threshold, 1, 2, true, 1, 0) { + auto x = INPUT_VARIABLE(0); + auto updated = OUTPUT_VARIABLE(0); + auto encoded = OUTPUT_NULLIFIED(1); + + float threshold = T_ARG(0); + + REQUIRE_TRUE(x->lengthOf() <= DataTypeUtils::max(), 0, + "encode_threshold: gradients array must have length <= MAX_INT"); + REQUIRE_TRUE(encoded->lengthOf() >= 4, 0, + "encode_threshold: array for encoded updates can't have less " + "than 4 elements"); + // REQUIRE_TRUE(x->platformBuffer() == updated->platformBuffer(), + // 0, "encode_threshold: gradients array must be the same at input + // and output"); + + // filling header bytes + encoded->p(0, encoded->lengthOf() - 4); + encoded->p(1, (int)x->lengthOf()); + encoded->p(2, reinterpret_cast(&threshold)[0]); + encoded->p(3, 0); // flag for FLEXIBLE_ENCODING + + // if there's no updates to process - just skip execution + if (encoded->lengthOf() == 4) return Status::OK(); + + helpers::thresholdEncode(*x, *encoded, threshold); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(encode_threshold) { + auto x = INPUT_VARIABLE(0); + // we have limit option here + int boundary = block.numI() > 0 ? I_ARG(0) : DataTypeUtils::max(); + float threshold = T_ARG(0); + + REQUIRE_TRUE(boundary >= 0, 0, "encode_threshold: boundary must be positive"); + REQUIRE_TRUE(x->lengthOf() <= DataTypeUtils::max(), 0, + "encode_threshold: gradients array must have length <= MAX_INT"); + + // we must calculate number of elements that >= threshold + auto elements = sd::math::nd4j_min( + helpers::thresholdEstimate(*x, threshold), boundary); + if (elements < 2) elements = 0; + + // result array must have 4 additional int elements for header + return SHAPELIST(x->shapeInfo(), + sd::ConstantShapeHelper::getInstance()->vectorShapeInfo( + elements + 4, DataType::INT32)); +} + +DECLARE_TYPES(encode_threshold) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, DataType::INT32); +} + +CUSTOM_OP_IMPL(decode_threshold, 2, 1, true, 0, 0) { + auto weights = INPUT_VARIABLE(0); + auto encoded = INPUT_VARIABLE(1); + auto updates = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(encoded->lengthOf() >= 4, 0, + "decode_threshold: encoded array can't have length < 4"); + REQUIRE_TRUE(updates->lengthOf() == encoded->e(1), 0, + "decode_threshold: updates array must have length equal to [%i]", + encoded->e(1)); + REQUIRE_TRUE( + encoded->e(3) == 0, 0, + "decode_threshold: encoded array doesn't look like threshold-encoded"); + + helpers::thresholdDecode(*encoded, *updates); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(decode_threshold) { + auto weights = inputShape->at(0); + return SHAPELIST(weights); +} + +DECLARE_TYPES(decode_threshold) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, DataType::INT32) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp index fe42d7057629..fd8e043950db 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp @@ -25,79 +25,90 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(bitcast, 1, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - // when empty - nothing to do - DataType newType = DataTypeUtils::fromInt(INT_ARG(0)); - DataType oldType = input->dataType(); - // correct output shape to conform with output data type - auto inputSize = DataTypeUtils::sizeOf(oldType); - auto outputSize = DataTypeUtils::sizeOf(newType); - auto lastSize = outputSize / inputSize; - if (inputSize < outputSize) { - REQUIRE_TRUE(input->sizeAt(-1) == lastSize, 0, - "BITCAST: %llu > %llu. So last dimension should be %i, but %i given.", inputSize, - outputSize, lastSize, input->sizeAt(-1)); - } - if(input->isEmpty()){ - REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty."); - return Status::OK(); - } +namespace ops { +CUSTOM_OP_IMPL(bitcast, 1, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + // when empty - nothing to do + DataType newType = DataTypeUtils::fromInt(INT_ARG(0)); + DataType oldType = input->dataType(); + // correct output shape to conform with output data type + auto inputSize = DataTypeUtils::sizeOf(oldType); + auto outputSize = DataTypeUtils::sizeOf(newType); + auto lastSize = outputSize / inputSize; + if (inputSize < outputSize) { + REQUIRE_TRUE( + input->sizeAt(-1) == lastSize, 0, + "BITCAST: %llu > %llu. So last dimension should be %i, but %i given.", + inputSize, outputSize, lastSize, input->sizeAt(-1)); + } + if (input->isEmpty()) { + REQUIRE_TRUE( + output->isEmpty(), 0, + "BITCAST: If input is empty, output array must also be empty."); + return Status::OK(); + } - // just memcpy data - DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); + // just memcpy data + DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); - return Status::OK(); - } - DECLARE_SYN(BitCast, bitcast); + return Status::OK(); +} +DECLARE_SYN(BitCast, bitcast); - DECLARE_SHAPE_FN(bitcast) { - auto inShape = inputShape->at(0); - auto inputRank = shape::rank(inShape); - auto it = INT_ARG(0); - DataType newType = DataTypeUtils::fromInt(it); - DataType oldType = ArrayOptions::dataType(inShape); - // correct output shape to conform with output data type - auto inputSize = DataTypeUtils::sizeOf(oldType); - auto outputSize = DataTypeUtils::sizeOf(newType); +DECLARE_SHAPE_FN(bitcast) { + auto inShape = inputShape->at(0); + auto inputRank = shape::rank(inShape); + auto it = INT_ARG(0); + DataType newType = DataTypeUtils::fromInt(it); + DataType oldType = ArrayOptions::dataType(inShape); + // correct output shape to conform with output data type + auto inputSize = DataTypeUtils::sizeOf(oldType); + auto outputSize = DataTypeUtils::sizeOf(newType); - if (shape::length(inShape) == 0) - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType))); + if (shape::length(inShape) == 0) + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(inShape, newType))); - if (inputSize == outputSize) { - // only type should be changed - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType))); - } - else if (inputSize > outputSize) { - // range of output increased by 1 with inputSize / outputSize as last dimension - std::vector shapeOf(inputRank + 1); - int i; - for (i = 0; i < inputRank; ++i) { - shapeOf[i] = inShape[i + 1]; - } - shapeOf[i] = inputSize / outputSize; - auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf); - return SHAPELIST(outputShape); - } - REQUIRE_TRUE(shape::sizeAt(inShape, -1) == outputSize / inputSize, 0, "BITCAST: %llu > %llu. So last dimension should be %i, but %i given.", inputSize, outputSize, outputSize / inputSize, shape::sizeAt(inShape, -1)); - std::vector shapeOf(inputRank - 1); + if (inputSize == outputSize) { + // only type should be changed + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(inShape, newType))); + } else if (inputSize > outputSize) { + // range of output increased by 1 with inputSize / outputSize as last + // dimension + std::vector shapeOf(inputRank + 1); + int i; + for (i = 0; i < inputRank; ++i) { + shapeOf[i] = inShape[i + 1]; + } + shapeOf[i] = inputSize / outputSize; + auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo( + newType, shape::order(inShape), shapeOf); + return SHAPELIST(outputShape); + } + REQUIRE_TRUE( + shape::sizeAt(inShape, -1) == outputSize / inputSize, 0, + "BITCAST: %llu > %llu. So last dimension should be %i, but %i given.", + inputSize, outputSize, outputSize / inputSize, + shape::sizeAt(inShape, -1)); + std::vector shapeOf(inputRank - 1); - for (auto i = 0; i < shapeOf.size(); ++i) { - shapeOf[i] = inShape[i + 1]; - } + for (auto i = 0; i < shapeOf.size(); ++i) { + shapeOf[i] = inShape[i + 1]; + } - auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf); - return SHAPELIST(outputShape); - } + auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo( + newType, shape::order(inShape), shapeOf); + return SHAPELIST(outputShape); +} - DECLARE_TYPES(bitcast) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - } +DECLARE_TYPES(bitcast) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/cast.cpp b/libnd4j/include/ops/declarable/generic/datatypes/cast.cpp index cf8729d2f7ad..e8c10cf7fdd3 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/cast.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/cast.cpp @@ -25,39 +25,40 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(cast, 1, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if(input->isEmpty()){ - REQUIRE_TRUE(output->isEmpty(), 0, "If input is empty, output array must also be empty"); - return Status::OK(); - } - - if (!block.isInplace()) - output->assign(input); - - STORE_RESULT(output); - return Status::OK(); - } - DECLARE_SYN(Cast, cast); - - DECLARE_SHAPE_FN(cast) { - auto inShape = inputShape->at(0); - - auto it = INT_ARG(0); - DataType newType = DataTypeUtils::fromInt(it); - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType))); - } - - DECLARE_TYPES(cast) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - } +namespace ops { +CUSTOM_OP_IMPL(cast, 1, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) { + REQUIRE_TRUE(output->isEmpty(), 0, + "If input is empty, output array must also be empty"); + return Status::OK(); + } + + if (!block.isInplace()) output->assign(input); + + STORE_RESULT(output); + return Status::OK(); +} +DECLARE_SYN(Cast, cast); + +DECLARE_SHAPE_FN(cast) { + auto inShape = inputShape->at(0); + + auto it = INT_ARG(0); + DataType newType = DataTypeUtils::fromInt(it); + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(inShape, newType))); +} + +DECLARE_TYPES(cast) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_double.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_double.cpp index 4eae77a5a285..42bbe029c7fe 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_double.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_double.cpp @@ -24,31 +24,31 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(to_double, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(to_double, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - if (!block.isInplace()) - output->assign(input); + if (!block.isInplace()) output->assign(input); - STORE_RESULT(output); + STORE_RESULT(output); - return Status::OK(); - } - - DECLARE_TYPES(to_double) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::DOUBLE); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(to_double) { - auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::DOUBLE, true, block.workspace()); - return SHAPELIST(CONSTANT(outShape)); - } +DECLARE_TYPES(to_double) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::DOUBLE); +} - } +DECLARE_SHAPE_FN(to_double) { + auto outShape = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), DataType::DOUBLE, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_float16.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_float16.cpp index aa8ceb045483..bdf81af3f456 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_float16.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_float16.cpp @@ -24,31 +24,31 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(to_float16, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(to_float16, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - if (!block.isInplace()) - output->assign(input); + if (!block.isInplace()) output->assign(input); - STORE_RESULT(output); + STORE_RESULT(output); - return Status::OK(); - } - - DECLARE_TYPES(to_float16) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::HALF); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(to_float16) { - auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::HALF, true, block.workspace()); - return SHAPELIST(CONSTANT(outShape)); - } +DECLARE_TYPES(to_float16) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::HALF); +} - } +DECLARE_SHAPE_FN(to_float16) { + auto outShape = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), DataType::HALF, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_float32.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_float32.cpp index 23a924f9cb85..2616cce0c1dc 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_float32.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_float32.cpp @@ -24,31 +24,31 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(to_float32, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(to_float32, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - if (!block.isInplace()) - output->assign(input); + if (!block.isInplace()) output->assign(input); - STORE_RESULT(output); + STORE_RESULT(output); - return Status::OK(); - } - - DECLARE_TYPES(to_float32) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::FLOAT32); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(to_float32) { - auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::FLOAT32, true, block.workspace()); - return SHAPELIST(CONSTANT(outShape)); - } +DECLARE_TYPES(to_float32) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::FLOAT32); +} - } +DECLARE_SHAPE_FN(to_float32) { + auto outShape = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), DataType::FLOAT32, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_int32.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_int32.cpp index c28fa6049f1a..565481bb10d7 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_int32.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_int32.cpp @@ -24,30 +24,30 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(to_int32, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - STORE_RESULT(output); - - return Status::OK(); - } - - DECLARE_TYPES(to_int32) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::INT32); - } - DECLARE_SHAPE_FN(to_int32) { - auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::INT32, true, block.workspace()); - return SHAPELIST(CONSTANT(outShape)); - } - - } +namespace ops { +CUSTOM_OP_IMPL(to_int32, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + STORE_RESULT(output); + + return Status::OK(); } +DECLARE_TYPES(to_int32) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::INT32); +} +DECLARE_SHAPE_FN(to_int32) { + auto outShape = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), DataType::INT32, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_int64.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_int64.cpp index cb994ccfe5e0..b785d32c4871 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_int64.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_int64.cpp @@ -24,30 +24,30 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(to_int64, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - STORE_RESULT(output); - - return Status::OK(); - } - - DECLARE_TYPES(to_int64) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::INT64); - } - DECLARE_SHAPE_FN(to_int64) { - auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::INT64, true, block.workspace()); - return SHAPELIST(CONSTANT(outShape)); - } - - } +namespace ops { +CUSTOM_OP_IMPL(to_int64, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + STORE_RESULT(output); + + return Status::OK(); } +DECLARE_TYPES(to_int64) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::INT64); +} +DECLARE_SHAPE_FN(to_int64) { + auto outShape = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), DataType::INT64, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_uint32.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_uint32.cpp index f62d9cd9b87b..bfff0bdb64dd 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_uint32.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_uint32.cpp @@ -24,30 +24,30 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(to_uint32, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - STORE_RESULT(output); - - return Status::OK(); - } - - DECLARE_TYPES(to_uint32) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::INT32); - } - DECLARE_SHAPE_FN(to_uint32) { - auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::UINT32, true, block.workspace()); - return SHAPELIST(CONSTANT(outShape)); - } - - } +namespace ops { +CUSTOM_OP_IMPL(to_uint32, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + STORE_RESULT(output); + + return Status::OK(); } +DECLARE_TYPES(to_uint32) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::INT32); +} +DECLARE_SHAPE_FN(to_uint32) { + auto outShape = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), DataType::UINT32, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_uint64.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_uint64.cpp index dc337ea1bb73..2e3c02a80613 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_uint64.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_uint64.cpp @@ -24,29 +24,29 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(to_uint64, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - STORE_RESULT(output); - - return Status::OK(); - } - - DECLARE_TYPES(to_uint64) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::INT8); - } - DECLARE_SHAPE_FN(to_uint64) { - auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::UINT64, true, block.workspace()); - return SHAPELIST(CONSTANT(outShape)); - } - } +namespace ops { +CUSTOM_OP_IMPL(to_uint64, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + STORE_RESULT(output); + + return Status::OK(); +} + +DECLARE_TYPES(to_uint64) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::INT8); +} +DECLARE_SHAPE_FN(to_uint64) { + auto outShape = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), DataType::UINT64, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp b/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp index ef42be652fd1..6b7ddd3f9cbb 100644 --- a/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp +++ b/libnd4j/include/ops/declarable/generic/flow/flow_control_ops.cpp @@ -18,71 +18,68 @@ // Created by raver119 on 13.10.2017. // - #include #include namespace sd { - namespace ops { - /** - * This operation is, basically IF statement - * - * arg_0 is our "signal" - * arg_1 is condition that will determine transition - */ - // TODO: make this op a placeholder too - DIVERGENT_OP_IMPL(Switch, 2, 2, true) { - auto input = INPUT_VARIABLE(0); - auto condition = INPUT_VARIABLE(1); - - // we'll store signal to both ends - //STORE_2_RESULTS(*input, *input); -/* - // but we'll ensure only one node is active, and other is disabled - if (condition->e(0) == 0) { - block.setBranch(0); - this->storeResult(block, 0, new NDArray(input->dup())); - } else { - block.setBranch(1); - this->storeResult(block, 1, new NDArray(input->dup())); - } - - return Status::OK(); - */ +namespace ops { +/** + * This operation is, basically IF statement + * + * arg_0 is our "signal" + * arg_1 is condition that will determine transition + */ +// TODO: make this op a placeholder too +DIVERGENT_OP_IMPL(Switch, 2, 2, true) { + auto input = INPUT_VARIABLE(0); + auto condition = INPUT_VARIABLE(1); - throw std::runtime_error("Switch - Not implemented yet"); - } - DECLARE_SYN(switch, Switch); - DECLARE_SYN(if, Switch); + // we'll store signal to both ends + // STORE_2_RESULTS(*input, *input); + /* + // but we'll ensure only one node is active, and other is disabled + if (condition->e(0) == 0) { + block.setBranch(0); + this->storeResult(block, 0, new NDArray(input->dup())); + } else { + block.setBranch(1); + this->storeResult(block, 1, new NDArray(input->dup())); + } + return Status::OK(); + */ - /** - * This op is a placeholder. - * Actual WHILE implementation is in GraphExecutioner - */ - LOGIC_OP_IMPL(While); - DECLARE_SYN(while, While); + throw std::runtime_error("Switch - Not implemented yet"); +} +DECLARE_SYN(switch, Switch); +DECLARE_SYN(if, Switch); - /** - * This op is a placeholder. - * Actual Scope implementation is in Graph and GraphExecutioner - */ - LOGIC_OP_IMPL(Scope); - DECLARE_SYN(scope, Scope); +/** + * This op is a placeholder. + * Actual WHILE implementation is in GraphExecutioner + */ +LOGIC_OP_IMPL(While); +DECLARE_SYN(while, While); - /** - * This op is a placeholder. - * Actual Conditional implementation is in Graph and GraphExecutioner - */ - LOGIC_OP_IMPL(Conditional); - DECLARE_SYN(cond, Conditional); +/** + * This op is a placeholder. + * Actual Scope implementation is in Graph and GraphExecutioner + */ +LOGIC_OP_IMPL(Scope); +DECLARE_SYN(scope, Scope); +/** + * This op is a placeholder. + * Actual Conditional implementation is in Graph and GraphExecutioner + */ +LOGIC_OP_IMPL(Conditional); +DECLARE_SYN(cond, Conditional); - /** - * This op is a placeholder - * Actual implementation is in LogicReturn class - */ - LOGIC_OP_IMPL(Return); - DECLARE_SYN(return, Return); - } -} \ No newline at end of file +/** + * This op is a placeholder + * Actual implementation is in LogicReturn class + */ +LOGIC_OP_IMPL(Return); +DECLARE_SYN(return, Return); +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/grad/broadcast_gradient_args.cpp b/libnd4j/include/ops/declarable/generic/grad/broadcast_gradient_args.cpp index e4dbcd6d549d..0b74bdeebe58 100644 --- a/libnd4j/include/ops/declarable/generic/grad/broadcast_gradient_args.cpp +++ b/libnd4j/include/ops/declarable/generic/grad/broadcast_gradient_args.cpp @@ -24,23 +24,21 @@ #include namespace sd { - namespace ops { - /** - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - OP_IMPL(broadcastgradientargs, 2, 2, true) { - - nd4j_printf("BroadcastGradientArgs: Not implemented yet\n", ""); +namespace ops { +/** + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +OP_IMPL(broadcastgradientargs, 2, 2, true) { + nd4j_printf("BroadcastGradientArgs: Not implemented yet\n", ""); - return ND4J_STATUS_KERNEL_FAILURE; - } - DECLARE_SYN(BroadcastGradientArgs, broadcastgradientargs); + return ND4J_STATUS_KERNEL_FAILURE; +} +DECLARE_SYN(BroadcastGradientArgs, broadcastgradientargs); - DECLARE_TYPES(broadcastgradientargs) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY); - } - } +DECLARE_TYPES(broadcastgradientargs) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h b/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h index 7df331c4d1e9..e5e5113b98b2 100644 --- a/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h +++ b/libnd4j/include/ops/declarable/generic/helpers/BroadcastHelper.h @@ -22,124 +22,140 @@ #define LIBND4J_BROADCAST_HELPER_H #include +#include #include -#include #include -#include +#include namespace sd { - namespace ops { - class BroadcastHelper { - public: - static FORCEINLINE NDArray* broadcastApply(sd::BroadcastOpsTuple op, NDArray* x, NDArray* y, NDArray* z, ExtraArguments *extraArgs = nullptr) { - - if(x->isEmpty() || y->isEmpty()) { - if(!z->isEmpty()) - throw std::runtime_error("BroadcastHelper::broadcastApply: when some of input arrays (or both) is empty, output array must be empty as well !"); - return z; - } - - std::unique_ptr ptr; - if (!Environment::getInstance()->isExperimentalBuild()) { - if (y->dataType() != x->dataType()) { - y = new NDArray(y->cast(x->dataType())); - std::unique_ptr ptr2(y); - ptr.swap(ptr2); - } - } - - if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { - x->applyPairwiseTransform(op.p, *y, *z); - } else if (!x->isScalar() && y->isScalar()) { - x->applyScalarArr(op.s, const_cast(*y), *z); - } else if (x->isScalar() && !y->isScalar()) { - if (z->isSameShape(y)) { - if (op.s == scalar::Add || op.s == scalar::Multiply ) { - y->applyScalarArr(op.s, *x, *z); - } else if (op.s == scalar::SquaredSubtract) { - y->applyScalarArr(scalar::SquaredReverseSubtract, *x, *z); - } else if (op.s == scalar::Subtract) { - y->applyScalarArr(scalar::ReverseSubtract, *x, *z); - } else if (op.s == scalar::Divide) { - y->applyScalarArr(scalar::ReverseDivide, *x, *z); - } else if (op.s == scalar::Pow) { - y->applyScalarArr(scalar::ReversePow, *x, *z); - } else if (op.s == scalar::ReverseSubtract) { - y->applyScalarArr(scalar::Subtract, *x, *z); - } else if (op.s == scalar::ReverseDivide) { - y->applyScalarArr(scalar::Divide, *x, *z); - } else if (op.s == scalar::MaxPairwise || op.s == scalar::MinPairwise || op.s == scalar::AMaxPairwise || op.s == scalar::AMinPairwise) { - y->applyScalarArr(op.s, *x, *z); - } else if (op.s == scalar::CopyPws) { - z->assign(y); - } else { - z->assign(x); - z->applyPairwiseTransform(op.p, *y, extraArgs); - } - return z; - } else { - auto v = y->getShapeAsVector(); - auto tZ = NDArrayFactory::valueOf(v, y, y->ordering()); - tZ->applyPairwiseTransform(op.p, *y, extraArgs); - return tZ; - } - } else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar() - x->applyScalarArr(op.s, const_cast(*y), *z); - } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { - x->applyTrueBroadcast(op, *y, *z, true, extraArgs); - return z; - } else { - auto sx = ShapeUtils::shapeAsString(x); - auto sy = ShapeUtils::shapeAsString(y); - nd4j_printf("Broadcast: shapes should be equal, or broadcastable. But got %s vs %s instead\n", sx.c_str(), sy.c_str()); - return nullptr; - } +namespace ops { +class BroadcastHelper { + public: + static FORCEINLINE NDArray* broadcastApply( + sd::BroadcastOpsTuple op, NDArray* x, NDArray* y, NDArray* z, + ExtraArguments* extraArgs = nullptr) { + if (x->isEmpty() || y->isEmpty()) { + if (!z->isEmpty()) + throw std::runtime_error( + "BroadcastHelper::broadcastApply: when some of input arrays (or " + "both) is empty, output array must be empty as well !"); + return z; + } - return z; - } + std::unique_ptr ptr; + if (!Environment::getInstance()->isExperimentalBuild()) { + if (y->dataType() != x->dataType()) { + y = new NDArray(y->cast(x->dataType())); + std::unique_ptr ptr2(y); + ptr.swap(ptr2); + } + } - static FORCEINLINE NDArray* broadcastApply(sd::BroadcastBoolOpsTuple op, NDArray* x, NDArray* y, NDArray* z, ExtraArguments *extraArgs = nullptr) { + if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { + x->applyPairwiseTransform(op.p, *y, *z); + } else if (!x->isScalar() && y->isScalar()) { + x->applyScalarArr(op.s, const_cast(*y), *z); + } else if (x->isScalar() && !y->isScalar()) { + if (z->isSameShape(y)) { + if (op.s == scalar::Add || op.s == scalar::Multiply) { + y->applyScalarArr(op.s, *x, *z); + } else if (op.s == scalar::SquaredSubtract) { + y->applyScalarArr(scalar::SquaredReverseSubtract, *x, *z); + } else if (op.s == scalar::Subtract) { + y->applyScalarArr(scalar::ReverseSubtract, *x, *z); + } else if (op.s == scalar::Divide) { + y->applyScalarArr(scalar::ReverseDivide, *x, *z); + } else if (op.s == scalar::Pow) { + y->applyScalarArr(scalar::ReversePow, *x, *z); + } else if (op.s == scalar::ReverseSubtract) { + y->applyScalarArr(scalar::Subtract, *x, *z); + } else if (op.s == scalar::ReverseDivide) { + y->applyScalarArr(scalar::Divide, *x, *z); + } else if (op.s == scalar::MaxPairwise || op.s == scalar::MinPairwise || + op.s == scalar::AMaxPairwise || + op.s == scalar::AMinPairwise) { + y->applyScalarArr(op.s, *x, *z); + } else if (op.s == scalar::CopyPws) { + z->assign(y); + } else { + z->assign(x); + z->applyPairwiseTransform(op.p, *y, extraArgs); + } + return z; + } else { + auto v = y->getShapeAsVector(); + auto tZ = NDArrayFactory::valueOf(v, y, y->ordering()); + tZ->applyPairwiseTransform(op.p, *y, extraArgs); + return tZ; + } + } else if (x->isScalar() && + y->isScalar()) { // x->isScalar() && y->isScalar() + x->applyScalarArr(op.s, const_cast(*y), *z); + } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { + x->applyTrueBroadcast(op, *y, *z, true, extraArgs); + return z; + } else { + auto sx = ShapeUtils::shapeAsString(x); + auto sy = ShapeUtils::shapeAsString(y); + nd4j_printf( + "Broadcast: shapes should be equal, or broadcastable. But got %s vs " + "%s instead\n", + sx.c_str(), sy.c_str()); + return nullptr; + } - if(x->isEmpty() || y->isEmpty()) { - if(!z->isEmpty()) - throw std::runtime_error("BroadcastHelper::broadcastApply: when some of input arrays (or both) is empty, output array must be empty as well !"); - return z; - } + return z; + } - if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { - x->applyPairwiseTransform(op.p, *y, *z); - } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { - x->applyTrueBroadcast(op, *y, *z, true, extraArgs); - return z; - } else if (!x->isScalar() && y->isScalar()) { - x->applyScalarArr(op.s, const_cast(*y), *z); - } else if (x->isScalar() && !y->isScalar()) { - if (z->isSameShape(y)) { - //z->assign(x); - x->applyPairwiseTransform(op.p, *y, *z, extraArgs); - return z; - } else { - auto v = y->getShapeAsVector(); - auto tZ = NDArrayFactory::valueOf(v, y, y->ordering()); - //tZ->applyPairwiseTransform(op.p, *y, extraArgs); - return tZ; - } - } else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar() - x->applyScalarArr(op.s, const_cast(*y), *z); - } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { - x->applyTrueBroadcast(op, *y, *z, true, extraArgs); - return z; - } else { - auto sx = ShapeUtils::shapeAsString(x); - auto sy = ShapeUtils::shapeAsString(y); - nd4j_printf("Broadcast: shapes should be equal, or broadcastable. But got %s vs %s instead\n", sx.c_str(), sy.c_str()); - return nullptr; - } + static FORCEINLINE NDArray* broadcastApply( + sd::BroadcastBoolOpsTuple op, NDArray* x, NDArray* y, NDArray* z, + ExtraArguments* extraArgs = nullptr) { + if (x->isEmpty() || y->isEmpty()) { + if (!z->isEmpty()) + throw std::runtime_error( + "BroadcastHelper::broadcastApply: when some of input arrays (or " + "both) is empty, output array must be empty as well !"); + return z; + } - return z; - } - }; + if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { + x->applyPairwiseTransform(op.p, *y, *z); + } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { + x->applyTrueBroadcast(op, *y, *z, true, extraArgs); + return z; + } else if (!x->isScalar() && y->isScalar()) { + x->applyScalarArr(op.s, const_cast(*y), *z); + } else if (x->isScalar() && !y->isScalar()) { + if (z->isSameShape(y)) { + // z->assign(x); + x->applyPairwiseTransform(op.p, *y, *z, extraArgs); + return z; + } else { + auto v = y->getShapeAsVector(); + auto tZ = NDArrayFactory::valueOf(v, y, y->ordering()); + // tZ->applyPairwiseTransform(op.p, *y, extraArgs); + return tZ; + } + } else if (x->isScalar() && + y->isScalar()) { // x->isScalar() && y->isScalar() + x->applyScalarArr(op.s, const_cast(*y), *z); + } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { + x->applyTrueBroadcast(op, *y, *z, true, extraArgs); + return z; + } else { + auto sx = ShapeUtils::shapeAsString(x); + auto sy = ShapeUtils::shapeAsString(y); + nd4j_printf( + "Broadcast: shapes should be equal, or broadcastable. But got %s vs " + "%s instead\n", + sx.c_str(), sy.c_str()); + return nullptr; } -} + + return z; + } +}; +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/helpers/ScatterHelper.h b/libnd4j/include/ops/declarable/generic/helpers/ScatterHelper.h index 4d464a7456ce..525d8334931c 100644 --- a/libnd4j/include/ops/declarable/generic/helpers/ScatterHelper.h +++ b/libnd4j/include/ops/declarable/generic/helpers/ScatterHelper.h @@ -22,18 +22,15 @@ #ifndef LIBND4J_SCATTERHELPER_H #define LIBND4J_SCATTERHELPER_H -#include -#include #include -#include #include +#include +#include +#include namespace sd { -namespace ops { - - -} -} +namespace ops {} +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/adjust_contrast.cpp b/libnd4j/include/ops/declarable/generic/images/adjust_contrast.cpp index 796dbb80ba36..d79f721e95e4 100644 --- a/libnd4j/include/ops/declarable/generic/images/adjust_contrast.cpp +++ b/libnd4j/include/ops/declarable/generic/images/adjust_contrast.cpp @@ -21,110 +21,117 @@ #include #if NOT_EXCLUDED(OP_adjust_contrast) -#include #include +#include namespace sd { namespace ops { //////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // just skip op if input is empty - if (input->isEmpty()) - return Status::OK(); + // just skip op if input is empty + if (input->isEmpty()) return Status::OK(); - REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST: Scale factor required"); - REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); -// REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, + "ADJUST_CONTRAST: Scale factor required"); + REQUIRE_TRUE(input->rankOf() > 2, 0, + "ADJUST_CONTRAST: op expects rank of input array to be >= 3, " + "but got %i instead", + input->rankOf()); + // REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation + // expects image with 3 channels (R, G, B), but got %i instead", + // input->sizeAt(-1)); - NDArray* factor = nullptr; + NDArray* factor = nullptr; - if(block.width() > 1) - factor = INPUT_VARIABLE(1); - else { - factor = new NDArray(output->dataType(), block.launchContext()); - factor->p(0, T_ARG(0)); - } + if (block.width() > 1) + factor = INPUT_VARIABLE(1); + else { + factor = new NDArray(output->dataType(), block.launchContext()); + factor->p(0, T_ARG(0)); + } - // fill up axes vector first - std::vector axes(input->rankOf() - 1); - for (auto i = 0; i < axes.size(); ++i) - axes[i] = i; + // fill up axes vector first + std::vector axes(input->rankOf() - 1); + for (auto i = 0; i < axes.size(); ++i) axes[i] = i; - // mean as reduction for last dimension set - auto mean = input->reduceAlongDimension(reduce::Mean, axes); + // mean as reduction for last dimension set + auto mean = input->reduceAlongDimension(reduce::Mean, axes); - // this is contrast calculation - output->assign((*input - mean) * (*factor) + mean); + // this is contrast calculation + output->assign((*input - mean) * (*factor) + mean); - if(block.width() == 1) - delete factor; + if (block.width() == 1) delete factor; - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(adjust_contrast) { - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(true); } //////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // just skip op if input is empty - if (input->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); -// REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); - REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required"); - - NDArray* factor = nullptr; - auto size = input->sizeAt(-2) * input->sizeAt(-3); - auto channels = input->sizeAt(-1); - auto batch = input->lengthOf() / (size * channels); - auto input3D = input->reshape(input->ordering(), {batch, size, channels}); - auto output3D = input->reshape(input->ordering(), {batch, size, channels}); - - if(block.width() > 1) - factor = INPUT_VARIABLE(1); - else { - factor = new NDArray(output->dataType(), block.launchContext()); - factor->p(0, T_ARG(0)); - } - - std::vector axes({1}); // dim 1 of pseudoresult - -// mean as reduction for last dimension set over size (dim 1) of result3D - auto mean = input3D.reduceAlongDimension(reduce::Mean, axes); - - // result as (x - mean) * factor + mean - auto temp = input3D.ulike(); - input3D.applyBroadcast(broadcast::Subtract, {0, 2}, mean, temp); - temp.applyScalarArr(scalar::Multiply, *factor, temp); - temp.applyBroadcast(broadcast::Add, {0, 2}, mean, output3D); - output->assign(output3D); - if(block.width() == 1) - delete factor; - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // just skip op if input is empty + if (input->isEmpty()) return Status::OK(); + + REQUIRE_TRUE(input->rankOf() > 2, 0, + "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, " + "but got %i instead", + input->rankOf()); + // REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation + // expects image with 3 channels (R, G, B), but got %i instead", + // input->sizeAt(-1)); + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, + "ADJUST_CONTRAST_V2: Scale factor required"); + + NDArray* factor = nullptr; + auto size = input->sizeAt(-2) * input->sizeAt(-3); + auto channels = input->sizeAt(-1); + auto batch = input->lengthOf() / (size * channels); + auto input3D = input->reshape(input->ordering(), {batch, size, channels}); + auto output3D = input->reshape(input->ordering(), {batch, size, channels}); + + if (block.width() > 1) + factor = INPUT_VARIABLE(1); + else { + factor = new NDArray(output->dataType(), block.launchContext()); + factor->p(0, T_ARG(0)); + } + + std::vector axes({1}); // dim 1 of pseudoresult + + // mean as reduction for last dimension set over size (dim 1) of result3D + auto mean = input3D.reduceAlongDimension(reduce::Mean, axes); + + // result as (x - mean) * factor + mean + auto temp = input3D.ulike(); + input3D.applyBroadcast(broadcast::Subtract, {0, 2}, mean, temp); + temp.applyScalarArr(scalar::Multiply, *factor, temp); + temp.applyBroadcast(broadcast::Add, {0, 2}, mean, output3D); + output->assign(output3D); + if (block.width() == 1) delete factor; + + return Status::OK(); } DECLARE_TYPES(adjust_contrast_v2) { - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(true); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/adjust_hue.cpp b/libnd4j/include/ops/declarable/generic/images/adjust_hue.cpp index ff564dd80b8b..8bf5d94ba052 100644 --- a/libnd4j/include/ops/declarable/generic/images/adjust_hue.cpp +++ b/libnd4j/include/ops/declarable/generic/images/adjust_hue.cpp @@ -28,55 +28,62 @@ namespace sd { namespace ops { - CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // just skip op if input is empty - if (input->isEmpty()) - return Status::OK(); - - const int rank = input->rankOf(); - const int arg_size = block.numI(); - const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_HUE: delta factor is required !"); - REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank); - if (arg_size > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); - - NDArray* delta = nullptr; - - if(block.width() > 1) - delta = INPUT_VARIABLE(1); - else { - delta = new NDArray(output->dataType(), block.launchContext()); - delta->p(0, T_ARG(0)); - } - - REQUIRE_TRUE(-1. <= delta->e(0) && delta->e(0) <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta); - - helpers::adjustHue(block.launchContext(), input, delta, output, dimC); - - if(block.width() == 1) - delete delta; - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // just skip op if input is empty + if (input->isEmpty()) return Status::OK(); + + const int rank = input->rankOf(); + const int arg_size = block.numI(); + const int dimC = arg_size > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; + + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, + "ADJUST_HUE: delta factor is required !"); + REQUIRE_TRUE(rank >= 3, 0, + "ADJUST_HUE: op expects rank of input array to be >= 3, but got " + "%i instead", + rank); + if (arg_size > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, + "ADJUST_HUE: operation expects image with 3 channels (R, G, B), " + "but got %i instead", + input->sizeAt(dimC)); + + NDArray* delta = nullptr; + + if (block.width() > 1) + delta = INPUT_VARIABLE(1); + else { + delta = new NDArray(output->dataType(), block.launchContext()); + delta->p(0, T_ARG(0)); + } + + REQUIRE_TRUE(-1. <= delta->e(0) && delta->e(0) <= 1., 0, + "ADJUST_HUE: parameter delta must be within [-1, 1] interval, " + "but got %f instead", + delta); + + helpers::adjustHue(block.launchContext(), input, delta, output, dimC); + + if (block.width() == 1) delete delta; + + return Status::OK(); } DECLARE_TYPES(adjust_hue) { - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } - - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/adjust_saturation.cpp b/libnd4j/include/ops/declarable/generic/images/adjust_saturation.cpp index 40243f2d68bc..0079d5970675 100644 --- a/libnd4j/include/ops/declarable/generic/images/adjust_saturation.cpp +++ b/libnd4j/include/ops/declarable/generic/images/adjust_saturation.cpp @@ -21,58 +21,64 @@ #include #if NOT_EXCLUDED(OP_adjust_saturation) +#include #include #include -#include namespace sd { namespace ops { CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // just skip op if input is empty - if (input->isEmpty()) - return Status::OK(); - - const int rank = input->rankOf(); - const int arg_size = block.numI(); - const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank); - if (arg_size > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); - REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_SATURATION: scale factor is required !"); - - NDArray* factor = nullptr; - - if(block.width() > 1) - factor = INPUT_VARIABLE(1); - else { - factor = new NDArray(output->dataType(), block.launchContext()); - factor->p(0, T_ARG(0)); - } - - helpers::adjustSaturation(block.launchContext(), input, factor, output, dimC); - - if(block.width() == 1) - delete factor; - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // just skip op if input is empty + if (input->isEmpty()) return Status::OK(); + + const int rank = input->rankOf(); + const int arg_size = block.numI(); + const int dimC = arg_size > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; + + REQUIRE_TRUE(rank >= 3, 0, + "ADJUST_SATURATION: op expects rank of input array to be >= 3, " + "but got %i instead", + rank); + if (arg_size > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, + "ADJUST_SATURATION: operation expects image with 3 channels (R, " + "G, B), but got %i instead", + input->sizeAt(dimC)); + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, + "ADJUST_SATURATION: scale factor is required !"); + + NDArray* factor = nullptr; + + if (block.width() > 1) + factor = INPUT_VARIABLE(1); + else { + factor = new NDArray(output->dataType(), block.launchContext()); + factor->p(0, T_ARG(0)); + } + + helpers::adjustSaturation(block.launchContext(), input, factor, output, dimC); + + if (block.width() == 1) delete factor; + + return Status::OK(); } DECLARE_TYPES(adjust_saturation) { - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/crop_and_resize.cpp b/libnd4j/include/ops/declarable/generic/images/crop_and_resize.cpp index b8ce12d64077..1d7f1c46a54f 100644 --- a/libnd4j/include/ops/declarable/generic/images/crop_and_resize.cpp +++ b/libnd4j/include/ops/declarable/generic/images/crop_and_resize.cpp @@ -26,70 +26,79 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(crop_and_resize, 4, 1, false, 0, 0) { +namespace ops { +CUSTOM_OP_IMPL(crop_and_resize, 4, 1, false, 0, 0) { + auto image = INPUT_VARIABLE(0); + auto boxes = INPUT_VARIABLE(1); + auto boxIndexes = INPUT_VARIABLE(2); - auto image = INPUT_VARIABLE(0); - auto boxes = INPUT_VARIABLE(1); - auto boxIndexes = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - int width; - int height; - int method = 0; // bilinear - double extrapolationVal = 0.; + auto output = OUTPUT_VARIABLE(0); + int width; + int height; + int method = 0; // bilinear + double extrapolationVal = 0.; - auto newImageSize = INPUT_VARIABLE(3); - REQUIRE_TRUE(output->dataType() == image->dataType(), 0, "crop_and_resize: Source images and output should have the same data type."); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "crop_and_resize: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - //REQUIRE_TRUE(block.numI() <= 1, 0, "crop_and_resize: Resize params already given by the second param. Int params are expensive."); - //width = int(newImageSize->getScalar(0)); - //height = int(newImageSize->getScalar(1)); - if (block.numI() == 1) { - method = INT_ARG(0); - } + auto newImageSize = INPUT_VARIABLE(3); + REQUIRE_TRUE(output->dataType() == image->dataType(), 0, + "crop_and_resize: Source images and output should have the same " + "data type."); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, + "crop_and_resize: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + // REQUIRE_TRUE(block.numI() <= 1, 0, "crop_and_resize: Resize params already + // given by the second param. Int params are expensive."); width = + // int(newImageSize->getScalar(0)); height = int(newImageSize->getScalar(1)); + if (block.numI() == 1) { + method = INT_ARG(0); + } - if (block.numT() == 1) { - extrapolationVal = T_ARG(0); - } + if (block.numT() == 1) { + extrapolationVal = T_ARG(0); + } - helpers::cropAndResizeFunctor(block.launchContext(), image, boxes, boxIndexes, newImageSize, method, extrapolationVal, output); - return ND4J_STATUS_OK; - } + helpers::cropAndResizeFunctor(block.launchContext(), image, boxes, boxIndexes, + newImageSize, method, extrapolationVal, output); + return ND4J_STATUS_OK; +} - DECLARE_SHAPE_FN(crop_and_resize) { - auto in = inputShape->at(0); - auto boxShape = inputShape->at(1); +DECLARE_SHAPE_FN(crop_and_resize) { + auto in = inputShape->at(0); + auto boxShape = inputShape->at(1); - Nd4jLong outputShape[4]; + Nd4jLong outputShape[4]; - int width; - int height; - auto newImageSize = INPUT_VARIABLE(3); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "crop_and_resize: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - //REQUIRE_TRUE(block.numI() <= 1, 0, "crop_and_resize: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(0); - height = newImageSize->e(1); + int width; + int height; + auto newImageSize = INPUT_VARIABLE(3); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, + "crop_and_resize: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + // REQUIRE_TRUE(block.numI() <= 1, 0, "crop_and_resize: Resize params already + // given by the second param. Int params are expensive."); + width = newImageSize->e(0); + height = newImageSize->e(1); - outputShape[0] = boxShape[1]; - outputShape[1] = width; - outputShape[2] = height; - outputShape[3] = in[4]; + outputShape[0] = boxShape[1]; + outputShape[1] = width; + outputShape[2] = height; + outputShape[3] = in[4]; - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(in), shape::order(in), outputShape, 4))); - } + return SHAPELIST( + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + ArrayOptions::dataType(in), shape::order(in), outputShape, 4))); +} - DECLARE_TYPES(crop_and_resize) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) -// ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_INTS}) - ->setAllowedInputTypes(3, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); // as TF -// ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(crop_and_resize) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + // ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setAllowedInputTypes(3, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); // as TF + // ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/images/draw_bounding_boxes.cpp b/libnd4j/include/ops/declarable/generic/images/draw_bounding_boxes.cpp index d143bdcf81d9..44ea0676f29e 100644 --- a/libnd4j/include/ops/declarable/generic/images/draw_bounding_boxes.cpp +++ b/libnd4j/include/ops/declarable/generic/images/draw_bounding_boxes.cpp @@ -24,48 +24,59 @@ #include #include namespace sd { - namespace ops { - OP_IMPL(draw_bounding_boxes, 3, 1, true) { +namespace ops { +OP_IMPL(draw_bounding_boxes, 3, 1, true) { + auto images = INPUT_VARIABLE(0); + auto boxes = INPUT_VARIABLE(1); - auto images = INPUT_VARIABLE(0); - auto boxes = INPUT_VARIABLE(1); + auto colors = (NDArray*)nullptr; + if (block.width() > + 2) // TF v.1.x ommits color set for boxes, and use color 1.0 for fill up + colors = INPUT_VARIABLE(2); // but v.2.y require color set - auto colors = (NDArray*) nullptr; - if (block.width() > 2) // TF v.1.x ommits color set for boxes, and use color 1.0 for fill up - colors = INPUT_VARIABLE(2); // but v.2.y require color set - - auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(images->dataType() == output->dataType(), 0, "draw_bounding_boxes: Input and Output types " - "should be equals, but %d and %d occured.", - (int)images->dataType(), (int)output->dataType()); - REQUIRE_TRUE(images->rankOf() == 4, 0, "draw_bounding_boxes: Images input should be 4D tensor, but %i occured.", - images->rankOf()); - REQUIRE_TRUE(boxes->rankOf() == 3, 0, "draw_bounding_boxes: Boxes should be 3D tensor, but %i occured.", - boxes->rankOf()); - if (colors) { - - REQUIRE_TRUE(colors->rankOf() == 2, 0, "draw_bounding_boxes: Color set should be 2D matrix, but %i occured.", - colors->rankOf()); - REQUIRE_TRUE(colors->sizeAt(1) >= images->sizeAt(3), 0, "draw_bounding_boxes: Color set last dim " - "should be not less than images depth, but " - "%lld and %lld occured.", - colors->sizeAt(1), images->sizeAt(3)); - } - REQUIRE_TRUE(boxes->sizeAt(0) == images->sizeAt(0), 0, "draw_bounding_boxes: Batches for images and boxes " - "should be the same, but %lld and %lld occured.", - images->sizeAt(0), boxes->sizeAt(0)); - helpers::drawBoundingBoxesFunctor(block.launchContext(), images, boxes, colors, output); - return ND4J_STATUS_OK; - } + auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(images->dataType() == output->dataType(), 0, + "draw_bounding_boxes: Input and Output types " + "should be equals, but %d and %d occured.", + (int)images->dataType(), (int)output->dataType()); + REQUIRE_TRUE( + images->rankOf() == 4, 0, + "draw_bounding_boxes: Images input should be 4D tensor, but %i occured.", + images->rankOf()); + REQUIRE_TRUE( + boxes->rankOf() == 3, 0, + "draw_bounding_boxes: Boxes should be 3D tensor, but %i occured.", + boxes->rankOf()); + if (colors) { + REQUIRE_TRUE( + colors->rankOf() == 2, 0, + "draw_bounding_boxes: Color set should be 2D matrix, but %i occured.", + colors->rankOf()); + REQUIRE_TRUE(colors->sizeAt(1) >= images->sizeAt(3), 0, + "draw_bounding_boxes: Color set last dim " + "should be not less than images depth, but " + "%lld and %lld occured.", + colors->sizeAt(1), images->sizeAt(3)); + } + REQUIRE_TRUE(boxes->sizeAt(0) == images->sizeAt(0), 0, + "draw_bounding_boxes: Batches for images and boxes " + "should be the same, but %lld and %lld occured.", + images->sizeAt(0), boxes->sizeAt(0)); + helpers::drawBoundingBoxesFunctor(block.launchContext(), images, boxes, + colors, output); + return ND4J_STATUS_OK; +} - DECLARE_TYPES(draw_bounding_boxes) { - getOpDescriptor() - ->setAllowedInputTypes(0, {HALF, FLOAT32})// TF allows HALF and FLOAT32 only - ->setAllowedInputTypes(1, {FLOAT32}) // as TF - ->setAllowedInputTypes(2, {FLOAT32}) // as TF - ->setAllowedOutputTypes({HALF, FLOAT32}); // TF allows HALF and FLOAT32 only - } - } +DECLARE_TYPES(draw_bounding_boxes) { + getOpDescriptor() + ->setAllowedInputTypes( + 0, {HALF, FLOAT32}) // TF allows HALF and FLOAT32 only + ->setAllowedInputTypes(1, {FLOAT32}) // as TF + ->setAllowedInputTypes(2, {FLOAT32}) // as TF + ->setAllowedOutputTypes( + {HALF, FLOAT32}); // TF allows HALF and FLOAT32 only } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/images/extract_image_patches.cpp b/libnd4j/include/ops/declarable/generic/images/extract_image_patches.cpp index 9bb24cf536a0..209fbcb946cf 100644 --- a/libnd4j/include/ops/declarable/generic/images/extract_image_patches.cpp +++ b/libnd4j/include/ops/declarable/generic/images/extract_image_patches.cpp @@ -22,76 +22,78 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(extract_image_patches, 1, 1, false, 0, 7) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - int ksizeRows = INT_ARG(0); - int ksizeCols = INT_ARG(1); - int kstrideRows = INT_ARG(2); - int kstrideCols = INT_ARG(3); - int krateRows = INT_ARG(4); - int krateCols = INT_ARG(5); - bool isSame = INT_ARG(6) != 0; - - REQUIRE_TRUE(input->rankOf() == 4, 0, "extract_image_patches: The rank of input array should be 4, but %i is given", input->rankOf()); - // - if (output->isSameShape(input)) - output->assign(input); - else { - output->nullify(); - helpers::extractPatches(block.launchContext(), input, output, ksizeRows, ksizeCols, kstrideRows, kstrideCols, krateRows, krateCols, isSame); - } - return Status::OK(); - } - - DECLARE_TYPES(extract_image_patches) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - - DECLARE_SHAPE_FN(extract_image_patches) { - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong *outputShape = nullptr; - - int ksizeRowsEffective = INT_ARG(0) + (INT_ARG(0) - 1) * (INT_ARG(4) - 1); - int ksizeColsEffective = INT_ARG(1) + (INT_ARG(1) - 1) * (INT_ARG(5) - 1); - - auto batchSizeDim = shape::sizeAt(in, 0); - auto inputRowsDim = shape::sizeAt(in, 1); - auto inputColsDim = shape::sizeAt(in, 2); - auto outputDepthDim = shape::sizeAt(in, 3) * INT_ARG(0) * INT_ARG(1); // last dim * ksizeRows * ksizeCols - - auto inputRowSize = inputRowsDim; //shape::sizeAt(in, inputRowsDim); - auto inputColSize = inputColsDim; //shape::sizeAt(in, inputColsDim); - Nd4jLong outRowSize; - Nd4jLong outColSize; - if (INT_ARG(6) == 0) { - // Padding is "VALID": - outRowSize = (inputRowSize - ksizeRowsEffective + INT_ARG(2)) / INT_ARG(2); - outColSize = (inputColSize - ksizeColsEffective + INT_ARG(3)) / INT_ARG(3); - } else { - // Padding is "SAME": - outRowSize = (inputRowSize + INT_ARG(2) - 1) / INT_ARG(2); - outColSize = (inputColSize + INT_ARG(3) - 1) / INT_ARG(3); - } - - - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = batchSizeDim; - outputShape[2] = outRowSize; - outputShape[3] = outColSize; - outputShape[4] = outputDepthDim; - - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - } -} \ No newline at end of file +namespace ops { +CUSTOM_OP_IMPL(extract_image_patches, 1, 1, false, 0, 7) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + int ksizeRows = INT_ARG(0); + int ksizeCols = INT_ARG(1); + int kstrideRows = INT_ARG(2); + int kstrideCols = INT_ARG(3); + int krateRows = INT_ARG(4); + int krateCols = INT_ARG(5); + bool isSame = INT_ARG(6) != 0; + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "extract_image_patches: The rank of input array should be 4, " + "but %i is given", + input->rankOf()); + // + if (output->isSameShape(input)) + output->assign(input); + else { + output->nullify(); + helpers::extractPatches(block.launchContext(), input, output, ksizeRows, + ksizeCols, kstrideRows, kstrideCols, krateRows, + krateCols, isSame); + } + return Status::OK(); +} + +DECLARE_TYPES(extract_image_patches) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} + +DECLARE_SHAPE_FN(extract_image_patches) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong *outputShape = nullptr; + + int ksizeRowsEffective = INT_ARG(0) + (INT_ARG(0) - 1) * (INT_ARG(4) - 1); + int ksizeColsEffective = INT_ARG(1) + (INT_ARG(1) - 1) * (INT_ARG(5) - 1); + + auto batchSizeDim = shape::sizeAt(in, 0); + auto inputRowsDim = shape::sizeAt(in, 1); + auto inputColsDim = shape::sizeAt(in, 2); + auto outputDepthDim = shape::sizeAt(in, 3) * INT_ARG(0) * + INT_ARG(1); // last dim * ksizeRows * ksizeCols + + auto inputRowSize = inputRowsDim; // shape::sizeAt(in, inputRowsDim); + auto inputColSize = inputColsDim; // shape::sizeAt(in, inputColsDim); + Nd4jLong outRowSize; + Nd4jLong outColSize; + if (INT_ARG(6) == 0) { + // Padding is "VALID": + outRowSize = (inputRowSize - ksizeRowsEffective + INT_ARG(2)) / INT_ARG(2); + outColSize = (inputColSize - ksizeColsEffective + INT_ARG(3)) / INT_ARG(3); + } else { + // Padding is "SAME": + outRowSize = (inputRowSize + INT_ARG(2) - 1) / INT_ARG(2); + outColSize = (inputColSize + INT_ARG(3) - 1) / INT_ARG(3); + } + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = batchSizeDim; + outputShape[2] = outRowSize; + outputShape[3] = outColSize; + outputShape[4] = outputDepthDim; + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp b/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp index 43cae28c50a1..be4cc399dbd2 100644 --- a/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp +++ b/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp @@ -16,44 +16,49 @@ // // @author AbdelRauf (rauf@konduit.ai) -// +// -#include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { CONFIGURABLE_OP_IMPL(hsv_to_rgb, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + if (input->isEmpty()) return Status::OK(); - if (input->isEmpty()) - return Status::OK(); + const int rank = input->rankOf(); + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; - const int rank = input->rankOf(); - const int argSize = block.numI(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + REQUIRE_TRUE(rank >= 1, 0, + "HSVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); + if (argSize > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE( + input->sizeAt(dimC) == 3, 0, + "HSVtoRGB: operation expects 3 channels (H, S, V), but got %i instead", + input->sizeAt(dimC)); - REQUIRE_TRUE(rank >= 1, 0, "HSVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); - if (argSize > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "HSVtoRGB: operation expects 3 channels (H, S, V), but got %i instead", input->sizeAt(dimC)); + helpers::transformHsvRgb(block.launchContext(), input, output, dimC); - helpers::transformHsvRgb(block.launchContext(), input, output, dimC); - - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(hsv_to_rgb) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/images/image_resize.cpp b/libnd4j/include/ops/declarable/generic/images/image_resize.cpp index 4787fc6897a5..73f849186d31 100644 --- a/libnd4j/include/ops/declarable/generic/images/image_resize.cpp +++ b/libnd4j/include/ops/declarable/generic/images/image_resize.cpp @@ -25,67 +25,72 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(image_resize, 2, 1, false, 0, 0) { +namespace ops { +CUSTOM_OP_IMPL(image_resize, 2, 1, false, 0, 0) { + auto image = INPUT_VARIABLE(0); + auto size = INPUT_VARIABLE(1); - auto image = INPUT_VARIABLE(0); - auto size = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + int width; + int height; + bool preserveAspectRatio = false; // - default value + bool antialias = false; + REQUIRE_TRUE(size->lengthOf() == 2, 0, + "resize_bilinear: Resize params is a pair of values, not %lld.", + size->lengthOf()); + width = size->e(0); + height = size->e(1); + if (block.numB()) { + preserveAspectRatio = B_ARG(0); + if (block.numB() > 1) antialias = B_ARG(1); + } - auto output = OUTPUT_VARIABLE(0); - int width; - int height; - bool preserveAspectRatio = false; // - default value - bool antialias = false; - REQUIRE_TRUE(size->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %lld.", size->lengthOf()); - width = size->e(0); - height = size->e(1); - if (block.numB()) { - preserveAspectRatio = B_ARG(0); - if (block.numB() > 1) - antialias = B_ARG(1); - } + auto method = helpers::ImageResizeMethods::kResizeBilinear; + if (block.numI() == 1) { + method = (helpers::ImageResizeMethods)INT_ARG(0); + } - auto method = helpers::ImageResizeMethods::kResizeBilinear; - if (block.numI() == 1) { - method = (helpers::ImageResizeMethods)INT_ARG(0); - } - - return helpers::resizeFunctor(block.launchContext(), image, width, height, method, preserveAspectRatio, antialias, output); - } + return helpers::resizeFunctor(block.launchContext(), image, width, height, + method, preserveAspectRatio, antialias, output); +} - DECLARE_SHAPE_FN(image_resize) { - auto shapeList = SHAPELIST(); - auto in = inputShape->at(0); +DECLARE_SHAPE_FN(image_resize) { + auto shapeList = SHAPELIST(); + auto in = inputShape->at(0); - Nd4jLong* outputShape; + Nd4jLong* outputShape; - int width; - int height; - auto newImageSize = INPUT_VARIABLE(1); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(0); - height = newImageSize->e(1); - - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(4), Nd4jLong); - outputShape[0] = 4; - outputShape[1] = in[1]; - outputShape[2] = width; - outputShape[3] = height; - outputShape[4] = in[4]; - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + int width; + int height; + auto newImageSize = INPUT_VARIABLE(1); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, + "resize_bilinear: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_bilinear: Resize params already given by the second " + "param. Int params are expensive."); + width = newImageSize->e(0); + height = newImageSize->e(1); - shapeList->push_back(CONSTANT(outputShape)); - return shapeList; - } - DECLARE_TYPES(image_resize) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(4), Nd4jLong); + outputShape[0] = 4; + outputShape[1] = in[1]; + outputShape[2] = width; + outputShape[3] = height; + outputShape[4] = in[4]; + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - } + shapeList->push_back(CONSTANT(outputShape)); + return shapeList; } +DECLARE_TYPES(image_resize) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/resize_area.cpp b/libnd4j/include/ops/declarable/generic/images/resize_area.cpp index 8ed5c4954abf..f0289ac4ed56 100644 --- a/libnd4j/include/ops/declarable/generic/images/resize_area.cpp +++ b/libnd4j/include/ops/declarable/generic/images/resize_area.cpp @@ -25,98 +25,134 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(resize_area, 1, 1, false, 0, -2) { - - auto image = INPUT_VARIABLE(0); - int width; - int height; - - if (block.width() == 2) { - auto size = INPUT_VARIABLE(1); // integer vector with shape {2} and content (new_height, new_width) - REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, "resize_area: Resize params is a pair of values, not %i.", size->lengthOf()); - size->syncToHost(); - width = size->e(1); - height = size->e(0); - } - else { - REQUIRE_TRUE(block.numI() == 2, 0, "resize_area: Resize params already given by the second param. Int params are expensive."); - width = INT_ARG(1); - height = INT_ARG(0); - } - - auto output = OUTPUT_VARIABLE(0); - if (output->isEmpty()) return Status::OK(); - auto inRank = image->rankOf(); - - REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "resize_area: Source tensor should have rank 4, but %i given.", inRank); - REQUIRE_TRUE(output->rankOf() == inRank, 0, "resize_area: Source tensor and output should have the same rank, but %i and %i given.", inRank, output->rankOf()); - REQUIRE_TRUE(width > 0 , 0, "resize_area: picture width should be positive 32 bit integer, but %i given", width); - REQUIRE_TRUE(height > 0 , 0, "resize_area: picture height should be positive 32 bit integer, but %i given", height); - REQUIRE_TRUE(image->lengthOf() > 0, 0, "resize_area: Only non-zero images allowed to processing."); - - auto alignCorners = false; - if (block.numB() > 0) { - alignCorners = B_ARG(0); - } - - auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); - auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false); - - return helpers::resizeAreaFunctor(block.launchContext(), &source, width, height, alignCorners, &target); - } - - DECLARE_SHAPE_FN(resize_area) { - auto shapeList = SHAPELIST(); - auto in = inputShape->at(0); - - Nd4jLong* outputShape; - auto inRank = shape::rank(in); - int width; - int height; - if (block.width() == 2) { - auto newImageSize = INPUT_VARIABLE(1); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, - "resize_area: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, - "resize_area: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(1); - height = newImageSize->e(0); - } - else { - REQUIRE_TRUE(block.numI() == 2, 0, "resize_area: Resize params ommited as pair ints nor int tensor."); - width = INT_ARG(1); - height = INT_ARG(0); - } - - REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_area: Source tensor should have rank 4, but %i given.", inRank); - - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); - outputShape[0] = inRank; - if (inRank == 4) { - outputShape[1] = in[1]; - outputShape[2] = height; - outputShape[3] = width; - outputShape[4] = in[4]; - } - else { - outputShape[1] = height; - outputShape[2] = width; - outputShape[3] = in[3]; - } - ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in)); - - shapeList->push_back(CONSTANT(outputShape)); - return shapeList; - } - DECLARE_TYPES(resize_area) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, DataType::INT32) - ->setAllowedOutputTypes({DataType::FLOAT32}); - } - - } +namespace ops { +CUSTOM_OP_IMPL(resize_area, 1, 1, false, 0, -2) { + auto image = INPUT_VARIABLE(0); + int width; + int height; + + if (block.width() == 2) { + auto size = INPUT_VARIABLE(1); // integer vector with shape {2} and content + // (new_height, new_width) + REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, + "resize_area: Resize params is a pair of values, not %i.", + size->lengthOf()); + size->syncToHost(); + width = size->e(1); + height = size->e(0); + } else { + REQUIRE_TRUE(block.numI() == 2, 0, + "resize_area: Resize params already given by the second " + "param. Int params are expensive."); + width = INT_ARG(1); + height = INT_ARG(0); + } + + auto output = OUTPUT_VARIABLE(0); + if (output->isEmpty()) return Status::OK(); + auto inRank = image->rankOf(); + + REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, + "resize_area: Source tensor should have rank 4, but %i given.", + inRank); + REQUIRE_TRUE(output->rankOf() == inRank, 0, + "resize_area: Source tensor and output should have the same " + "rank, but %i and %i given.", + inRank, output->rankOf()); + REQUIRE_TRUE(width > 0, 0, + "resize_area: picture width should be positive 32 bit integer, " + "but %i given", + width); + REQUIRE_TRUE(height > 0, 0, + "resize_area: picture height should be positive 32 bit integer, " + "but %i given", + height); + REQUIRE_TRUE(image->lengthOf() > 0, 0, + "resize_area: Only non-zero images allowed to processing."); + + auto alignCorners = false; + if (block.numB() > 0) { + alignCorners = B_ARG(0); + } + + auto source = inRank == 4 + ? image->reshape(image->ordering(), + {image->sizeAt(0), image->sizeAt(1), + image->sizeAt(2), image->sizeAt(3)}) + : image->reshape(image->ordering(), + {1, image->sizeAt(0), image->sizeAt(1), + image->sizeAt(2)}); + auto target = inRank == 4 + ? output->reshape(output->ordering(), + {output->sizeAt(0), output->sizeAt(1), + output->sizeAt(2), output->sizeAt(3)}, + false) + : output->reshape(output->ordering(), + {1, output->sizeAt(0), output->sizeAt(1), + output->sizeAt(2)}, + false); + + return helpers::resizeAreaFunctor(block.launchContext(), &source, width, + height, alignCorners, &target); } +DECLARE_SHAPE_FN(resize_area) { + auto shapeList = SHAPELIST(); + auto in = inputShape->at(0); + + Nd4jLong* outputShape; + auto inRank = shape::rank(in); + int width; + int height; + if (block.width() == 2) { + auto newImageSize = INPUT_VARIABLE(1); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, + "resize_area: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_area: Resize params already given by the second " + "param. Int params are expensive."); + width = newImageSize->e(1); + height = newImageSize->e(0); + } else { + REQUIRE_TRUE( + block.numI() == 2, 0, + "resize_area: Resize params ommited as pair ints nor int tensor."); + width = INT_ARG(1); + height = INT_ARG(0); + } + + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, + "resize_area: Source tensor should have rank 4, but %i given.", + inRank); + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + outputShape[0] = inRank; + if (inRank == 4) { + outputShape[1] = in[1]; + outputShape[2] = height; + outputShape[3] = width; + outputShape[4] = in[4]; + } else { + outputShape[1] = height; + outputShape[2] = width; + outputShape[3] = in[3]; + } + ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, + shape::order(in)); + + shapeList->push_back(CONSTANT(outputShape)); + return shapeList; +} +DECLARE_TYPES(resize_area) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, DataType::INT32) + ->setAllowedOutputTypes({DataType::FLOAT32}); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp b/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp index 28bfaad97705..e24f8d20ed4c 100644 --- a/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp +++ b/libnd4j/include/ops/declarable/generic/images/resize_bicubic.cpp @@ -25,90 +25,136 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(resize_bicubic, 2, 1, false, 0, 0) { +namespace ops { +CUSTOM_OP_IMPL(resize_bicubic, 2, 1, false, 0, 0) { + auto image = INPUT_VARIABLE(0); + auto size = INPUT_VARIABLE( + 1); // integer vector with shape {2} and content (new_height, new_width) + size->syncToHost(); + auto output = OUTPUT_VARIABLE(0); + int width; + int height; + auto inRank = image->rankOf(); + if (output->isEmpty()) return Status::OK(); - auto image = INPUT_VARIABLE(0); - auto size = INPUT_VARIABLE(1); // integer vector with shape {2} and content (new_height, new_width) - size->syncToHost(); - auto output = OUTPUT_VARIABLE(0); - int width; - int height; - auto inRank = image->rankOf(); - if (output->isEmpty()) return Status::OK(); + REQUIRE_TRUE( + inRank == 3 || inRank == 4, 0, + "resize_bicubic: Source tensor should have rank 4, but %i given.", + inRank); + REQUIRE_TRUE(output->rankOf() == inRank, 0, + "resize_bicubic: Source tensor and output should have the same " + "rank, but %i and %i given.", + inRank, output->rankOf()); + REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, + "resize_bicubic: Resize params is a pair of values, not %i.", + size->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_bicubic: Resize params already given by the second " + "param. Int params are expensive."); + width = size->e(1); + height = size->e(0); + REQUIRE_TRUE(width > 0, 0, + "resize_bicubic: picture width should be positive 32 bit " + "integer, but %i given", + width); + REQUIRE_TRUE(height > 0, 0, + "resize_bicubic: picture height should be positive 32 bit " + "integer, but %i given", + height); + // REQUIRE_TRUE(image->sizeAt(1) > 3 && image->sizeAt(2) > 3, 0, + // "resize_cubic: To use bicubic algorithm need at least 16 pixels as + // source."); + REQUIRE_TRUE(width > 3 && height > 3, 0, + "resize_bicubic: To use bicubic algorithm need at least 16 " + "pixels as target."); + REQUIRE_TRUE(image->lengthOf() > 0, 0, + "resize_bicubic: Only non-zero images allowed to processing."); + // auto method = 1; //kResizeBilinear; + // if (block.numI() == 1) { + // method = INT_ARG(0); + // } + auto alignCorners = false; + auto halfPixelAlign = false; + if (block.numB() > 0) { + alignCorners = block.getBArguments().at(0); + if (block.numB() > 1) halfPixelAlign = block.getBArguments().at(1); + } + REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, + "resize_bicubic: `half_pixel_centers' should be false or true " + "only when `align_corners' is false"); - REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank); - REQUIRE_TRUE(output->rankOf() == inRank, 0, "resize_bicubic: Source tensor and output should have the same rank, but %i and %i given.", inRank, output->rankOf()); - REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", size->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bicubic: Resize params already given by the second param. Int params are expensive."); - width = size->e(1); - height = size->e(0); - REQUIRE_TRUE(width > 0 , 0, "resize_bicubic: picture width should be positive 32 bit integer, but %i given", width); - REQUIRE_TRUE(height > 0 , 0, "resize_bicubic: picture height should be positive 32 bit integer, but %i given", height); - //REQUIRE_TRUE(image->sizeAt(1) > 3 && image->sizeAt(2) > 3, 0, "resize_cubic: To use bicubic algorithm need at least 16 pixels as source."); - REQUIRE_TRUE(width > 3 && height > 3, 0, "resize_bicubic: To use bicubic algorithm need at least 16 pixels as target."); - REQUIRE_TRUE(image->lengthOf() > 0, 0, "resize_bicubic: Only non-zero images allowed to processing."); -// auto method = 1; //kResizeBilinear; -// if (block.numI() == 1) { -// method = INT_ARG(0); -// } - auto alignCorners = false; - auto halfPixelAlign = false; - if (block.numB() > 0) { - alignCorners = block.getBArguments().at(0); - if (block.numB()> 1) - halfPixelAlign = block.getBArguments().at(1); - } - REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false"); + auto source = inRank == 4 + ? image->reshape(image->ordering(), + {image->sizeAt(0), image->sizeAt(1), + image->sizeAt(2), image->sizeAt(3)}) + : image->reshape(image->ordering(), + {1, image->sizeAt(0), image->sizeAt(1), + image->sizeAt(2)}); + auto target = inRank == 4 + ? output->reshape(output->ordering(), + {output->sizeAt(0), output->sizeAt(1), + output->sizeAt(2), output->sizeAt(3)}, + false) + : output->reshape(output->ordering(), + {1, output->sizeAt(0), output->sizeAt(1), + output->sizeAt(2)}, + false); - auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); - auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false); - - return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, height, alignCorners, halfPixelAlign, &target); - } - - DECLARE_SHAPE_FN(resize_bicubic) { - auto shapeList = SHAPELIST(); - auto in = inputShape->at(0); + return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, + height, alignCorners, halfPixelAlign, + &target); +} - Nd4jLong* outputShape; - auto inRank = shape::rank(in); - int width; - int height; - auto newImageSize = INPUT_VARIABLE(1); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bicubic: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(0); - height = newImageSize->e(1); +DECLARE_SHAPE_FN(resize_bicubic) { + auto shapeList = SHAPELIST(); + auto in = inputShape->at(0); - REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank); + Nd4jLong* outputShape; + auto inRank = shape::rank(in); + int width; + int height; + auto newImageSize = INPUT_VARIABLE(1); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, + "resize_bicubic: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_bicubic: Resize params already given by the second " + "param. Int params are expensive."); + width = newImageSize->e(0); + height = newImageSize->e(1); - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); - outputShape[0] = inRank; - if (inRank == 4) { - outputShape[1] = in[1]; - outputShape[2] = width; - outputShape[3] = height; - outputShape[4] = in[4]; - } - else { - outputShape[1] = width; - outputShape[2] = height; - outputShape[3] = in[3]; - } - ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in)); + REQUIRE_TRUE( + inRank == 4 || inRank == 3, 0, + "resize_bicubic: Source tensor should have rank 4, but %i given.", + inRank); - shapeList->push_back(CONSTANT(outputShape)); - return shapeList; - } - DECLARE_TYPES(resize_bicubic) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, DataType::INT32) - ->setAllowedOutputTypes({DataType::FLOAT32}); - } + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + outputShape[0] = inRank; + if (inRank == 4) { + outputShape[1] = in[1]; + outputShape[2] = width; + outputShape[3] = height; + outputShape[4] = in[4]; + } else { + outputShape[1] = width; + outputShape[2] = height; + outputShape[3] = in[3]; + } + ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, + shape::order(in)); - } + shapeList->push_back(CONSTANT(outputShape)); + return shapeList; } +DECLARE_TYPES(resize_bicubic) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, DataType::INT32) + ->setAllowedOutputTypes({DataType::FLOAT32}); +} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/resize_linear.cpp b/libnd4j/include/ops/declarable/generic/images/resize_linear.cpp index f5f89fe23f0b..fb1a40a4baf9 100644 --- a/libnd4j/include/ops/declarable/generic/images/resize_linear.cpp +++ b/libnd4j/include/ops/declarable/generic/images/resize_linear.cpp @@ -26,105 +26,135 @@ #include #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(resize_bilinear, 1, 1, false, 0, -2) { - - NDArray* image = INPUT_VARIABLE(0); - NDArray* output = OUTPUT_VARIABLE(0); - int width; - int height; - bool alignCorners = false; // - default value - auto inRank = image->rankOf(); - if (output->isEmpty()) return Status::OK(); - - REQUIRE_TRUE( inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D " - "tensor, but input has rank %i", - image->rankOf()); - REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_bilinear: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf()); - - auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); - auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false); - - if (block.width() > 1) { - auto newImageSize = INPUT_VARIABLE(1); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive."); - height = newImageSize->e(0); - width = newImageSize->e(1); - } - else { - REQUIRE_TRUE(block.numI() > 1, 0, "resize_bilinear: Neither resize width nor height are provided."); - height = INT_ARG(0); - width = INT_ARG(1); - } - - if (block.numB() > 0) - alignCorners = B_ARG(0); - bool halfPixelCenter = false; - - if (block.numB() > 1) - halfPixelCenter = B_ARG(1); - - REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, "resize_bilinear: `half_pixel_centers' should be false or true only when `align_corners' is false"); - - return helpers::resizeBilinearFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target); - } - - DECLARE_SHAPE_FN(resize_bilinear) { - auto shapeList = SHAPELIST(); - auto in = inputShape->at(0); - - Nd4jLong* outputShape; - auto inRank = shape::rank(in); - REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D " - "tensor, but input has rank %i", - inRank); - - int width; - int height; - if (block.width() > 1) { - auto newImageSize = INPUT_VARIABLE(1); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(0); - height = newImageSize->e(1); - } - else { - REQUIRE_TRUE(block.numI() == 2, 0, "resize_bilinear: Neither resize width nor height are provided."); - width = INT_ARG(0); - height = INT_ARG(1); - } - - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); - outputShape[0] = inRank; - if (inRank == 4) { - outputShape[1] = in[1]; - outputShape[2] = width; - outputShape[3] = height; - outputShape[4] = in[4]; - } - else { // input shape is 3D, so result also should be 3D - outputShape[1] = width; - outputShape[2] = height; - outputShape[3] = in[3]; - } - if (DataTypeUtils::isR(ArrayOptions::dataType(in))) { - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - } - else { - ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in)); - } - - shapeList->push_back(CONSTANT(outputShape)); - return shapeList; - } - DECLARE_TYPES(resize_bilinear) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - } +namespace ops { +CUSTOM_OP_IMPL(resize_bilinear, 1, 1, false, 0, -2) { + NDArray* image = INPUT_VARIABLE(0); + NDArray* output = OUTPUT_VARIABLE(0); + int width; + int height; + bool alignCorners = false; // - default value + auto inRank = image->rankOf(); + if (output->isEmpty()) return Status::OK(); + + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, + "resize_bilinear: input image should be 4D " + "tensor, but input has rank %i", + image->rankOf()); + REQUIRE_TRUE(inRank == output->rankOf(), 0, + "resize_bilinear: Input and output ranks should be equals, but " + "%i and %i occured.", + inRank, output->rankOf()); + + auto source = inRank == 4 + ? image->reshape(image->ordering(), + {image->sizeAt(0), image->sizeAt(1), + image->sizeAt(2), image->sizeAt(3)}) + : image->reshape(image->ordering(), + {1, image->sizeAt(0), image->sizeAt(1), + image->sizeAt(2)}); + auto target = inRank == 4 + ? output->reshape(output->ordering(), + {output->sizeAt(0), output->sizeAt(1), + output->sizeAt(2), output->sizeAt(3)}, + false) + : output->reshape(output->ordering(), + {1, output->sizeAt(0), output->sizeAt(1), + output->sizeAt(2)}, + false); + + if (block.width() > 1) { + auto newImageSize = INPUT_VARIABLE(1); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, + "resize_bilinear: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_bilinear: Resize params already given by the second " + "param. Int params are expensive."); + height = newImageSize->e(0); + width = newImageSize->e(1); + } else { + REQUIRE_TRUE( + block.numI() > 1, 0, + "resize_bilinear: Neither resize width nor height are provided."); + height = INT_ARG(0); + width = INT_ARG(1); + } + + if (block.numB() > 0) alignCorners = B_ARG(0); + bool halfPixelCenter = false; + + if (block.numB() > 1) halfPixelCenter = B_ARG(1); + + REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, + "resize_bilinear: `half_pixel_centers' should be false or true " + "only when `align_corners' is false"); + + return helpers::resizeBilinearFunctor( + block.launchContext(), inRank == 4 ? image : &source, width, height, + alignCorners, halfPixelCenter, inRank == 4 ? output : &target); } +DECLARE_SHAPE_FN(resize_bilinear) { + auto shapeList = SHAPELIST(); + auto in = inputShape->at(0); + + Nd4jLong* outputShape; + auto inRank = shape::rank(in); + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, + "resize_bilinear: input image should be 4D " + "tensor, but input has rank %i", + inRank); + + int width; + int height; + if (block.width() > 1) { + auto newImageSize = INPUT_VARIABLE(1); + REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, + "resize_bilinear: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_bilinear: Resize params already given by the second " + "param. Int params are expensive."); + width = newImageSize->e(0); + height = newImageSize->e(1); + } else { + REQUIRE_TRUE( + block.numI() == 2, 0, + "resize_bilinear: Neither resize width nor height are provided."); + width = INT_ARG(0); + height = INT_ARG(1); + } + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + outputShape[0] = inRank; + if (inRank == 4) { + outputShape[1] = in[1]; + outputShape[2] = width; + outputShape[3] = height; + outputShape[4] = in[4]; + } else { // input shape is 3D, so result also should be 3D + outputShape[1] = width; + outputShape[2] = height; + outputShape[3] = in[3]; + } + if (DataTypeUtils::isR(ArrayOptions::dataType(in))) { + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + } else { + ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, + shape::order(in)); + } + + shapeList->push_back(CONSTANT(outputShape)); + return shapeList; +} +DECLARE_TYPES(resize_bilinear) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/resize_neighbor.cpp b/libnd4j/include/ops/declarable/generic/images/resize_neighbor.cpp index 46a8949b6e99..80e912a2509e 100644 --- a/libnd4j/include/ops/declarable/generic/images/resize_neighbor.cpp +++ b/libnd4j/include/ops/declarable/generic/images/resize_neighbor.cpp @@ -27,97 +27,137 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(resize_nearest_neighbor, 1, 1, false, 0, -2) { +namespace ops { +CUSTOM_OP_IMPL(resize_nearest_neighbor, 1, 1, false, 0, -2) { + auto image = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + auto inRank = image->rankOf(); + int width; + int height; + bool alignCorners = false; // - default value + if (output->isEmpty()) return Status::OK(); + if (block.width() > 1) { + auto newImageSize = INPUT_VARIABLE(1); + REQUIRE_TRUE( + newImageSize->lengthOf() == 2, 0, + "resize_nearest_neighbor: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_nearest_neighbor: Resize params already given by the " + "second param. Int params are expensive."); + height = newImageSize->e(0); + width = newImageSize->e(1); + } else { + REQUIRE_TRUE(block.numI() == 2, 0, + "resize_nearest_neighbor: Neither resize width nor height are " + "provided."); + height = INT_ARG(0); + width = INT_ARG(1); + } + if (block.numB() > 0) alignCorners = B_ARG(0); + bool halfPixelCenter = false; - auto image = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - auto inRank = image->rankOf(); - int width; - int height; - bool alignCorners = false; // - default value - if (output->isEmpty()) return Status::OK(); - if (block.width() > 1) { - auto newImageSize = INPUT_VARIABLE(1); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_nearest_neighbor: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, "resize_nearest_neighbor: Resize params already given by the second param. Int params are expensive."); - height = newImageSize->e(0); - width = newImageSize->e(1); - } - else { - REQUIRE_TRUE(block.numI() == 2, 0, "resize_nearest_neighbor: Neither resize width nor height are provided."); - height = INT_ARG(0); - width = INT_ARG(1); - } - if (block.numB() > 0) - alignCorners = B_ARG(0); - bool halfPixelCenter = false; + if (block.numB() > 1) halfPixelCenter = B_ARG(1); + REQUIRE_TRUE( + width <= (1 << 24) || height <= (1 << 24), 0, + "resize_nearest_neighbor: the image resize should be limited to 2^24 " + "pixels both for height and width, but %d and %d were given.", + height, width); + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, + "resize_nearest_neighbor: Input should be 4D tensor, but rank " + "%i occured"); + REQUIRE_TRUE(inRank == output->rankOf(), 0, + "resize_nearest_neighbor: Input and output ranks should be " + "equals, but %i and %i occured.", + inRank, output->rankOf()); + REQUIRE_TRUE(image->dataType() == output->dataType(), 0, + "resize_nearest_neighbor: Input and output types should be the " + "same, but `%s' occured instead.", + DataTypeUtils::asString(output->dataType()).c_str()); + REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, + "resize_nearest_neighbor: `half_pixel_centers' should be false " + "or true only when `align_corners' is false"); + REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && + ((alignCorners && width > 1) || (width > 0)), + 0, + "resize_nearest_neighbor: Wrong input or output size to resize " + "(width = %d, height = %d)", + width, height); - if (block.numB() > 1) - halfPixelCenter = B_ARG(1); - REQUIRE_TRUE(width <= (1 << 24) || height <= (1 << 24), 0, "resize_nearest_neighbor: the image resize should be limited to 2^24 pixels both for height and width, but %d and %d were given.", height, width); - REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured"); - REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf()); - REQUIRE_TRUE(image->dataType() == output->dataType(), 0, "resize_nearest_neighbor: Input and output types should be the same, but `%s' occured instead.", DataTypeUtils::asString(output->dataType()).c_str()); - REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, "resize_nearest_neighbor: `half_pixel_centers' should be false or true only when `align_corners' is false"); - REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height); + auto source = inRank == 4 + ? *image + : image->reshape(image->ordering(), + {1, image->sizeAt(0), image->sizeAt(1), + image->sizeAt(2)}); + auto target = inRank == 4 + ? *output + : output->reshape(output->ordering(), + {1, output->sizeAt(0), output->sizeAt(1), + output->sizeAt(2)}, + false); - auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); - auto target = inRank == 4 ? *output : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false); - - return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target); - } - - DECLARE_SHAPE_FN(resize_nearest_neighbor) { - auto shapeList = SHAPELIST(); - auto in = inputShape->at(0); - auto inRank = shape::rank(in); - Nd4jLong* outputShape; + return helpers::resizeNeighborFunctor( + block.launchContext(), inRank == 4 ? image : &source, width, height, + alignCorners, halfPixelCenter, inRank == 4 ? output : &target); +} - REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: input image should be 4D " - "tensor, but input has rank %i", - inRank); +DECLARE_SHAPE_FN(resize_nearest_neighbor) { + auto shapeList = SHAPELIST(); + auto in = inputShape->at(0); + auto inRank = shape::rank(in); + Nd4jLong* outputShape; - int width; - int height; - if (block.width() > 1) { - auto newImageSize = INPUT_VARIABLE(1); - REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_nearest_neighbor: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); - REQUIRE_TRUE(block.numI() <= 1, 0, "resize_nearest_neighbor: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(0); - height = newImageSize->e(1); - } - else { - REQUIRE_TRUE(block.numI() <= 3, 0, "resize_nearest_neighbor: Neither resize width nor height are provided."); - width = INT_ARG(0); - height = INT_ARG(1); - } + REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, + "resize_nearest_neighbor: input image should be 4D " + "tensor, but input has rank %i", + inRank); - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); - outputShape[0] = inRank; - if (inRank == 4) { - outputShape[1] = in[1]; - outputShape[2] = width; - outputShape[3] = height; - outputShape[4] = in[4]; - } - else { // input shape is 3D, so result also should be 3D - outputShape[1] = width; - outputShape[2] = height; - outputShape[3] = in[3]; - } - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + int width; + int height; + if (block.width() > 1) { + auto newImageSize = INPUT_VARIABLE(1); + REQUIRE_TRUE( + newImageSize->lengthOf() == 2, 0, + "resize_nearest_neighbor: Resize params is a pair of values, not %i.", + newImageSize->lengthOf()); + REQUIRE_TRUE(block.numI() <= 1, 0, + "resize_nearest_neighbor: Resize params already given by the " + "second param. Int params are expensive."); + width = newImageSize->e(0); + height = newImageSize->e(1); + } else { + REQUIRE_TRUE(block.numI() <= 3, 0, + "resize_nearest_neighbor: Neither resize width nor height are " + "provided."); + width = INT_ARG(0); + height = INT_ARG(1); + } - shapeList->push_back(CONSTANT(outputShape)); - return shapeList; - } - DECLARE_TYPES(resize_nearest_neighbor) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + outputShape[0] = inRank; + if (inRank == 4) { + outputShape[1] = in[1]; + outputShape[2] = width; + outputShape[3] = height; + outputShape[4] = in[4]; + } else { // input shape is 3D, so result also should be 3D + outputShape[1] = width; + outputShape[2] = height; + outputShape[3] = in[3]; + } + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - } + shapeList->push_back(CONSTANT(outputShape)); + return shapeList; } +DECLARE_TYPES(resize_nearest_neighbor) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); +} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp index 348cce162353..0d3e110f02cb 100644 --- a/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp +++ b/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp @@ -18,57 +18,77 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { CUSTOM_OP_IMPL(rgb_to_grs, 1, 1, false, 0, 0) { + const auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - const auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - const int inRank = input->rankOf(); - const int argSize = block.numI(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + inRank) : inRank - 1; + const int inRank = input->rankOf(); + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + inRank) + : inRank - 1; - REQUIRE_TRUE(inRank >= 1, 0, "RGBtoGrayScale: Fails to meet the inRank requirement: %i >= 1 ", inRank); - if (argSize > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < inRank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -inRank, inRank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBGrayScale: operation expects 3 channels (R, G, B) in last dimention, but received %i instead", input->sizeAt(dimC)); + REQUIRE_TRUE(inRank >= 1, 0, + "RGBtoGrayScale: Fails to meet the inRank requirement: %i >= 1 ", + inRank); + if (argSize > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < inRank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -inRank, inRank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, + "RGBGrayScale: operation expects 3 channels (R, G, B) in last " + "dimention, but received %i instead", + input->sizeAt(dimC)); - helpers::transformRgbGrs(block.launchContext(), *input, *output, dimC); - return Status::OK(); + helpers::transformRgbGrs(block.launchContext(), *input, *output, dimC); + return Status::OK(); } DECLARE_TYPES(rgb_to_grs) { - getOpDescriptor()->setAllowedInputTypes( {ALL_INTS, ALL_FLOATS} ) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setSameMode(true); } DECLARE_SHAPE_FN(rgb_to_grs) { + const auto input = INPUT_VARIABLE(0); + const int inRank = input->rankOf(); - const auto input = INPUT_VARIABLE(0); - const int inRank = input->rankOf(); - - const int argSize = block.numI(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + inRank) : inRank - 1; + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + inRank) + : inRank - 1; - REQUIRE_TRUE(inRank >= 1, 0, "RGBtoGrayScale: Fails to meet the inRank requirement: %i >= 1 ", inRank); - if (argSize > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < inRank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -inRank, inRank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoGrayScale: operation expects 3 channels (R, B, G) in last dimention, but received %i", dimC); + REQUIRE_TRUE(inRank >= 1, 0, + "RGBtoGrayScale: Fails to meet the inRank requirement: %i >= 1 ", + inRank); + if (argSize > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < inRank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -inRank, inRank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, + "RGBtoGrayScale: operation expects 3 channels (R, B, G) in last " + "dimention, but received %i", + dimC); - auto nShape = input->getShapeAsVector(); - nShape[dimC] = 1; + auto nShape = input->getShapeAsVector(); + nShape[dimC] = 1; - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(input->dataType(), input->ordering(), nShape)); + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + input->dataType(), input->ordering(), nShape)); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp index 026c93749a95..ae5931cff989 100644 --- a/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp +++ b/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp @@ -18,45 +18,47 @@ // @author AbdelRauf (rauf@konduit.ai) // - - -#include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, true, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (input->isEmpty()) - return Status::OK(); - - const int rank = input->rankOf(); - const int argSize = block.numI(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - REQUIRE_TRUE(rank >= 1, 0, "RGBtoHSV: Fails to meet the rank requirement: %i >= 1 ", rank); - if (argSize > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoHSV: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); - - helpers::transformRgbHsv(block.launchContext(), input, output, dimC); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) return Status::OK(); + + const int rank = input->rankOf(); + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; + + REQUIRE_TRUE(rank >= 1, 0, + "RGBtoHSV: Fails to meet the rank requirement: %i >= 1 ", rank); + if (argSize > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE( + input->sizeAt(dimC) == 3, 0, + "RGBtoHSV: operation expects 3 channels (R, G, B), but got %i instead", + input->sizeAt(dimC)); + + helpers::transformRgbHsv(block.launchContext(), input, output, dimC); + + return Status::OK(); } - DECLARE_TYPES(rgb_to_hsv) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToYiq.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToYiq.cpp index bf7d4f32236f..aaccd9720772 100644 --- a/libnd4j/include/ops/declarable/generic/images/rgbToYiq.cpp +++ b/libnd4j/include/ops/declarable/generic/images/rgbToYiq.cpp @@ -14,47 +14,50 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author AbdelRauf (rauf@konduit.ai) - // +// +// @author AbdelRauf (rauf@konduit.ai) +// -#include -#include -#include #include +#include +#include +#include namespace sd { - namespace ops { - - - - CONFIGURABLE_OP_IMPL(rgb_to_yiq, 1, 1, true, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (input->isEmpty()) - return Status::OK(); - - const int rank = input->rankOf(); - const int arg_size = block.numI(); - const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - REQUIRE_TRUE(rank >= 1, 0, "RGBtoYIQ: Fails to meet the rank requirement: %i >= 1 ", rank); - if (arg_size > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoYIQ: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); - - helpers::transformRgbYiq(block.launchContext(), input, output, dimC); - - return Status::OK(); - } - +namespace ops { + +CONFIGURABLE_OP_IMPL(rgb_to_yiq, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) return Status::OK(); + + const int rank = input->rankOf(); + const int arg_size = block.numI(); + const int dimC = arg_size > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; + + REQUIRE_TRUE(rank >= 1, 0, + "RGBtoYIQ: Fails to meet the rank requirement: %i >= 1 ", rank); + if (arg_size > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE( + input->sizeAt(dimC) == 3, 0, + "RGBtoYIQ: operation expects 3 channels (R, G, B), but got %i instead", + input->sizeAt(dimC)); + + helpers::transformRgbYiq(block.launchContext(), input, output, dimC); + + return Status::OK(); +} - DECLARE_TYPES(rgb_to_yiq) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } - } +DECLARE_TYPES(rgb_to_yiq) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToYuv.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToYuv.cpp index 9bf40a2c1256..cb9418d6d0fa 100644 --- a/libnd4j/include/ops/declarable/generic/images/rgbToYuv.cpp +++ b/libnd4j/include/ops/declarable/generic/images/rgbToYuv.cpp @@ -18,44 +18,48 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // - - -#include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { CONFIGURABLE_OP_IMPL(rgb_to_yuv, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // just skip op if input is empty - if (input->isEmpty()) - return Status::OK(); + // just skip op if input is empty + if (input->isEmpty()) return Status::OK(); - const int rank = input->rankOf(); - const int argSize = block.numI(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + const int rank = input->rankOf(); + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; - REQUIRE_TRUE(rank >= 1, 0, "RGBtoYUV: Fails to meet the rank requirement: %i >= 1 ", rank); - if (argSize > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoYUV: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); + REQUIRE_TRUE(rank >= 1, 0, + "RGBtoYUV: Fails to meet the rank requirement: %i >= 1 ", rank); + if (argSize > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE( + input->sizeAt(dimC) == 3, 0, + "RGBtoYUV: operation expects 3 channels (R, G, B), but got %i instead", + input->sizeAt(dimC)); - helpers::transformRgbYuv(block.launchContext(), *input, *output, dimC); + helpers::transformRgbYuv(block.launchContext(), *input, *output, dimC); - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(rgb_to_yuv) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/images/yiqToRgb.cpp b/libnd4j/include/ops/declarable/generic/images/yiqToRgb.cpp index 08f4be2b77a6..14841f089460 100644 --- a/libnd4j/include/ops/declarable/generic/images/yiqToRgb.cpp +++ b/libnd4j/include/ops/declarable/generic/images/yiqToRgb.cpp @@ -16,46 +16,49 @@ // // @author AbdelRauf (rauf@konduit.ai) -// +// -#include -#include -#include #include +#include +#include +#include namespace sd { - namespace ops { +namespace ops { +CONFIGURABLE_OP_IMPL(yiq_to_rgb, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + if (input->isEmpty()) return Status::OK(); - CONFIGURABLE_OP_IMPL(yiq_to_rgb, 1, 1, true, 0, 0) { + const int rank = input->rankOf(); + const int arg_size = block.numI(); + const int dimC = arg_size > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(rank >= 1, 0, + "YIQtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); + if (arg_size > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE( + input->sizeAt(dimC) == 3, 0, + "YIQtoRGB: operation expects 3 channels (Y, I, Q), but got %i instead", + input->sizeAt(dimC)); - if (input->isEmpty()) - return Status::OK(); + helpers::transformYiqRgb(block.launchContext(), input, output, dimC); - const int rank = input->rankOf(); - const int arg_size = block.numI(); - const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - REQUIRE_TRUE(rank >= 1, 0, "YIQtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); - if (arg_size > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "YIQtoRGB: operation expects 3 channels (Y, I, Q), but got %i instead", input->sizeAt(dimC)); - - helpers::transformYiqRgb(block.launchContext(), input, output, dimC); - - return Status::OK(); - } - + return Status::OK(); +} - DECLARE_TYPES(yiq_to_rgb) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } - - } +DECLARE_TYPES(yiq_to_rgb) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/images/yuvToRgb.cpp b/libnd4j/include/ops/declarable/generic/images/yuvToRgb.cpp index 1665ea131939..a1e3b62d3f36 100644 --- a/libnd4j/include/ops/declarable/generic/images/yuvToRgb.cpp +++ b/libnd4j/include/ops/declarable/generic/images/yuvToRgb.cpp @@ -16,45 +16,50 @@ // // @author Oleh Semeniv (oleg.semeniv@gmail.com) -// +// -#include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { CONFIGURABLE_OP_IMPL(yuv_to_rgb, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + // just skip op if input is empty + if (input->isEmpty()) return Status::OK(); - // just skip op if input is empty - if (input->isEmpty()) - return Status::OK(); + const int rank = input->rankOf(); + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; - const int rank = input->rankOf(); - const int argSize = block.numI(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + REQUIRE_TRUE(rank >= 1, 0, + "YUVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); + if (argSize > 0) { + REQUIRE_TRUE( + dimC >= 0 && dimC < rank, 0, + "Index of the Channel dimension out of range: %i not in [%i,%i) ", + INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE( + input->sizeAt(dimC) == 3, 0, + "YUVtoRGB: operation expects 3 channels (Y, U, V), but got %i instead", + input->sizeAt(dimC)); - REQUIRE_TRUE(rank >= 1, 0, "YUVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); - if (argSize > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "YUVtoRGB: operation expects 3 channels (Y, U, V), but got %i instead", input->sizeAt(dimC)); + helpers::transformYuvRgb(block.launchContext(), *input, *output, dimC); - helpers::transformYuvRgb(block.launchContext(), *input, *output, dimC); - - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(yuv_to_rgb) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/kernels/knn_mindistance.cpp b/libnd4j/include/ops/declarable/generic/kernels/knn_mindistance.cpp index 8ef699aa208b..af22d83f59d7 100644 --- a/libnd4j/include/ops/declarable/generic/kernels/knn_mindistance.cpp +++ b/libnd4j/include/ops/declarable/generic/kernels/knn_mindistance.cpp @@ -25,35 +25,41 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(knn_mindistance, 3, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto lowest = INPUT_VARIABLE(1); - auto highest = INPUT_VARIABLE(2); +namespace ops { +CUSTOM_OP_IMPL(knn_mindistance, 3, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto lowest = INPUT_VARIABLE(1); + auto highest = INPUT_VARIABLE(2); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(input->lengthOf() == lowest->lengthOf() && input->lengthOf() == highest->lengthOf(), 0, "knn_mindistance: all input arrays must have same length"); - REQUIRE_TRUE(input->dataType() == lowest->dataType() && input->dataType() == highest->dataType() && input->dataType() == output->dataType(), 0, "knn_mindistance: all inputs must have the same data type"); + REQUIRE_TRUE(input->lengthOf() == lowest->lengthOf() && + input->lengthOf() == highest->lengthOf(), + 0, "knn_mindistance: all input arrays must have same length"); + REQUIRE_TRUE(input->dataType() == lowest->dataType() && + input->dataType() == highest->dataType() && + input->dataType() == output->dataType(), + 0, "knn_mindistance: all inputs must have the same data type"); - helpers::knn_mindistance(*input, *lowest, *highest, *output); + helpers::knn_mindistance(*input, *lowest, *highest, *output); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(knn_mindistance) { - auto input = inputShape->at(0); +DECLARE_SHAPE_FN(knn_mindistance) { + auto input = inputShape->at(0); - // always return scalar here - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(input))); - } + // always return scalar here + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo( + ArrayOptions::dataType(input))); +} - DECLARE_TYPES(knn_mindistance) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(knn_mindistance) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/betaInc.cpp b/libnd4j/include/ops/declarable/generic/linalg/betaInc.cpp index 1850f10a1651..92af4e719673 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/betaInc.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/betaInc.cpp @@ -25,48 +25,55 @@ #include namespace sd { -namespace ops { +namespace ops { - DECLARE_TYPES(betainc) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setSameMode(true); - } +DECLARE_TYPES(betainc) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); +} CONFIGURABLE_OP_IMPL(betainc, 3, 1, false, 0, 0) { - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - auto x = INPUT_VARIABLE(2); - - // just skip op if input is empty - if (x->isEmpty()) { - *x = DataTypeUtils::nanOrZero(); - return Status::OK(); - } - - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(a->isSameShape(b) && a->isSameShape(x), 0, "CONFIGURABLE_OP betainc: all three input arrays must have the same shapes, bit got a=%s, b=%s and x=%s instead !", ShapeUtils::shapeAsString(a).c_str(), ShapeUtils::shapeAsString(b).c_str(), ShapeUtils::shapeAsString(x).c_str()); - - Nd4jLong arrLen = a->lengthOf(); - - // FIXME: this stuff should be single op call. No sense rolling over couple of arrays twice - for(Nd4jLong i = 0; i < arrLen; ++i ) { - REQUIRE_TRUE(a->e(i) > 0.f, 0, "BETAINC op: arrays a array must contain only elements > 0 !"); - REQUIRE_TRUE(b->e(i) > 0.f, 0, "BETAINC op: arrays b array must contain only elements > 0 !"); - REQUIRE_TRUE(0.f <= x->e(i) && x->e(i) <= 1.f, 0, "BETAINC op: all elements of x array must be within [0, 1] range!"); - } - - helpers::betaInc(block.launchContext(), *a, *b, *x, *output); + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); + auto x = INPUT_VARIABLE(2); + // just skip op if input is empty + if (x->isEmpty()) { + *x = DataTypeUtils::nanOrZero(); return Status::OK(); + } + + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(a->isSameShape(b) && a->isSameShape(x), 0, + "CONFIGURABLE_OP betainc: all three input arrays must have the " + "same shapes, bit got a=%s, b=%s and x=%s instead !", + ShapeUtils::shapeAsString(a).c_str(), + ShapeUtils::shapeAsString(b).c_str(), + ShapeUtils::shapeAsString(x).c_str()); + + Nd4jLong arrLen = a->lengthOf(); + + // FIXME: this stuff should be single op call. No sense rolling over couple of + // arrays twice + for (Nd4jLong i = 0; i < arrLen; ++i) { + REQUIRE_TRUE(a->e(i) > 0.f, 0, + "BETAINC op: arrays a array must contain only elements > 0 !"); + REQUIRE_TRUE(b->e(i) > 0.f, 0, + "BETAINC op: arrays b array must contain only elements > 0 !"); + REQUIRE_TRUE( + 0.f <= x->e(i) && x->e(i) <= 1.f, 0, + "BETAINC op: all elements of x array must be within [0, 1] range!"); + } + + helpers::betaInc(block.launchContext(), *a, *b, *x, *output); + + return Status::OK(); } DECLARE_SYN(BetaInc, betainc); DECLARE_SYN(betaInc, betainc); - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/cholesky.cpp b/libnd4j/include/ops/declarable/generic/linalg/cholesky.cpp index dfc3830ca6ec..12b24e23a71e 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/cholesky.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/cholesky.cpp @@ -24,23 +24,32 @@ #include #include namespace sd { - namespace ops { - OP_IMPL(cholesky, 1, 1, true) { - NDArray* input = INPUT_VARIABLE(0); - NDArray* output = OUTPUT_VARIABLE(0); +namespace ops { +OP_IMPL(cholesky, 1, 1, true) { + NDArray* input = INPUT_VARIABLE(0); + NDArray* output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(input->rankOf() >=2, 0, "cholesky: The rank of input array should not less than 2, but %i is given", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "cholesky: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); - REQUIRE_TRUE(helpers::checkCholeskyInput(block.launchContext(), input), 0, "cholesky: The input tensor should be positive-defined and symmetric."); - return helpers::cholesky(block.launchContext(), input, output, block.isInplace()); - } - DECLARE_TYPES(cholesky) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - } + REQUIRE_TRUE(input->rankOf() >= 2, 0, + "cholesky: The rank of input array should not less than 2, but " + "%i is given", + input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, + "cholesky: The last two dimmensions should be equal, but %i and " + "%i are given", + input->sizeAt(-1), input->sizeAt(-2)); + REQUIRE_TRUE( + helpers::checkCholeskyInput(block.launchContext(), input), 0, + "cholesky: The input tensor should be positive-defined and symmetric."); + return helpers::cholesky(block.launchContext(), input, output, + block.isInplace()); +} +DECLARE_TYPES(cholesky) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/cross.cpp b/libnd4j/include/ops/declarable/generic/linalg/cross.cpp index 0d701cf7191e..daf2718f43c8 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/cross.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/cross.cpp @@ -26,35 +26,38 @@ namespace sd { namespace ops { - DECLARE_TYPES(cross) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}) - ->setSameMode(true); - } - - OP_IMPL(cross, 2, 1, false) { - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - - REQUIRE_TRUE(a->lengthOf() == b->lengthOf(), 0, "Cross: A and B lengths should match"); - REQUIRE_TRUE(a->rankOf() >= 1 && b->rankOf() >= 1, 0, "Cross: A and B should have rank >= 1"); - - // TODO: we might want to lift this restriction - REQUIRE_TRUE(a->isSameShape(b),0, "Cross: A and B should have equal shape"); - REQUIRE_TRUE(a->sizeAt(-1) == 3, 0, "Cross: outer dimension of A and B should be equal to 3"); - - auto o = OUTPUT_VARIABLE(0); - - if (a->lengthOf() == 3) { - helpers::cross(block.launchContext(), a, b, o); - } else { - helpers::crossBatched(block.launchContext(), a, b, o); - } - - return Status::OK(); - } +DECLARE_TYPES(cross) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}) + ->setSameMode(true); } + +OP_IMPL(cross, 2, 1, false) { + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); + + REQUIRE_TRUE(a->lengthOf() == b->lengthOf(), 0, + "Cross: A and B lengths should match"); + REQUIRE_TRUE(a->rankOf() >= 1 && b->rankOf() >= 1, 0, + "Cross: A and B should have rank >= 1"); + + // TODO: we might want to lift this restriction + REQUIRE_TRUE(a->isSameShape(b), 0, "Cross: A and B should have equal shape"); + REQUIRE_TRUE(a->sizeAt(-1) == 3, 0, + "Cross: outer dimension of A and B should be equal to 3"); + + auto o = OUTPUT_VARIABLE(0); + + if (a->lengthOf() == 3) { + helpers::cross(block.launchContext(), a, b, o); + } else { + helpers::crossBatched(block.launchContext(), a, b, o); + } + + return Status::OK(); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/diag.cpp b/libnd4j/include/ops/declarable/generic/linalg/diag.cpp index d67ca057b2c5..35b5ff1c972b 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/diag.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/diag.cpp @@ -25,42 +25,42 @@ #include namespace sd { -namespace ops { +namespace ops { -////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(diag, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // input validation - REQUIRE_TRUE(input->rankOf() <= 3, 0, "CUSTOM_OP diag: rank of input array must be <= 3 !, but got %i instead", input->rankOf()); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - // TODO: still not sure if we really want this - output->assign(0); + // input validation + REQUIRE_TRUE( + input->rankOf() <= 3, 0, + "CUSTOM_OP diag: rank of input array must be <= 3 !, but got %i instead", + input->rankOf()); - helpers::diagFunctor(block.launchContext(), input, output); - - return Status::OK(); + // TODO: still not sure if we really want this + output->assign(0); + + helpers::diagFunctor(block.launchContext(), input, output); + + return Status::OK(); } DECLARE_SYN(MatrixDiag, diag); +DECLARE_TYPES(diag) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} - DECLARE_TYPES(diag) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - -////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(diag) { - const Nd4jLong* inputShapeInfo = inputShape->at(0); + const Nd4jLong* inputShapeInfo = inputShape->at(0); - return SHAPELIST(ShapeUtils::evalDiagShapeInfo(inputShapeInfo, block.workspace())); + return SHAPELIST( + ShapeUtils::evalDiagShapeInfo(inputShapeInfo, block.workspace())); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/diagPart.cpp b/libnd4j/include/ops/declarable/generic/linalg/diagPart.cpp index 12d08b7b43fe..4d005bc60bc8 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/diagPart.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/diagPart.cpp @@ -25,58 +25,69 @@ #include namespace sd { -namespace ops { - - CUSTOM_OP_IMPL(diag_part, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { - const int inRank = input->rankOf(); - - // input validation - REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 6, 0, "DIAG_PART op: input array must have rank among following three possible values: 2, 4, 6, but got %i instead !", inRank); - for(int i = 0; i < inRank-1; ++i) - REQUIRE_TRUE(input->sizeAt(i) == input->sizeAt(i+1), 0, "DIAG_PART op: wrong shape of input array %s ! All dimensions must be equal !", ShapeUtils::shapeAsString(input).c_str()); +CUSTOM_OP_IMPL(diag_part, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - helpers::diagPartFunctor(block.launchContext(), input, output); + const int inRank = input->rankOf(); - return Status::OK(); - } - DECLARE_SYN(DiagPart, diag_part); + // input validation + REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 6, 0, + "DIAG_PART op: input array must have rank among following three " + "possible values: 2, 4, 6, but got %i instead !", + inRank); + for (int i = 0; i < inRank - 1; ++i) + REQUIRE_TRUE(input->sizeAt(i) == input->sizeAt(i + 1), 0, + "DIAG_PART op: wrong shape of input array %s ! All dimensions " + "must be equal !", + ShapeUtils::shapeAsString(input).c_str()); - DECLARE_TYPES(diag_part) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } + helpers::diagPartFunctor(block.launchContext(), input, output); - DECLARE_SHAPE_FN(diag_part) { - auto inputShapeInfo = inputShape->at(0); + return Status::OK(); +} +DECLARE_SYN(DiagPart, diag_part); - const int inRank = inputShapeInfo[0]; +DECLARE_TYPES(diag_part) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} - // input validation - REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 6, 0, "DIAG_PART op: input array must have rank among following three possible values: 2, 4, 6, but got %i instead !", inRank); - for(int i = 1; i < inRank; ++i) - REQUIRE_TRUE(inputShapeInfo[i] == inputShapeInfo[i+1], 0, "DIAG_PART op: wrong shape of input array %s ! All dimensions must be equal !", ShapeUtils::shapeAsString(inputShapeInfo).c_str()); +DECLARE_SHAPE_FN(diag_part) { + auto inputShapeInfo = inputShape->at(0); - Nd4jLong* outShapeInfo = nullptr; + const int inRank = inputShapeInfo[0]; - int outRank = inRank/2; + // input validation + REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 6, 0, + "DIAG_PART op: input array must have rank among following three " + "possible values: 2, 4, 6, but got %i instead !", + inRank); + for (int i = 1; i < inRank; ++i) + REQUIRE_TRUE(inputShapeInfo[i] == inputShapeInfo[i + 1], 0, + "DIAG_PART op: wrong shape of input array %s ! All dimensions " + "must be equal !", + ShapeUtils::shapeAsString(inputShapeInfo).c_str()); - ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outShapeInfo[0] = outRank; - for(int i = 1; i <= outRank; ++i) - outShapeInfo[i] = inputShapeInfo[i]; + Nd4jLong* outShapeInfo = nullptr; - ShapeUtils::updateStridesAndType(outShapeInfo, inputShapeInfo, shape::order(inputShapeInfo)); + int outRank = inRank / 2; - return SHAPELIST(ConstantShapeHelper::getInstance()->createFromExisting(outShapeInfo, block.workspace())); - } + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + outShapeInfo[0] = outRank; + for (int i = 1; i <= outRank; ++i) outShapeInfo[i] = inputShapeInfo[i]; + ShapeUtils::updateStridesAndType(outShapeInfo, inputShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(ConstantShapeHelper::getInstance()->createFromExisting( + outShapeInfo, block.workspace())); } -} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/digamma.cpp b/libnd4j/include/ops/declarable/generic/linalg/digamma.cpp index 17afcc10b170..5f627a9cdfdd 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/digamma.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/digamma.cpp @@ -26,25 +26,24 @@ #include namespace sd { -namespace ops { +namespace ops { CONFIGURABLE_OP_IMPL(digamma, 1, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + helpers::diGamma(block.launchContext(), *x, *z); - helpers::diGamma(block.launchContext(), *x, *z); - - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(digamma) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS}) + ->setSameMode(true); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/eye.cpp b/libnd4j/include/ops/declarable/generic/linalg/eye.cpp index 5561d9cb6a8d..5b8c4f2c6df8 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/eye.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/eye.cpp @@ -21,92 +21,91 @@ #if NOT_EXCLUDED(OP_eye) #include -#include +#include namespace sd { namespace ops { +CUSTOM_OP_IMPL(eye, -2, 1, false, -2, -2) { + helpers::eye(block.launchContext(), *OUTPUT_VARIABLE(0)); - CUSTOM_OP_IMPL(eye, -2, 1, false, -2, -2) { - - helpers::eye(block.launchContext(), *OUTPUT_VARIABLE(0)); - - return Status::OK(); - } - - DECLARE_TYPES(eye) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS}); - getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); - getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(eye) { - - std::vector params; - - sd::DataType dtype = block.getTArguments().empty() ? sd::DataType::FLOAT32 : sd::DataTypeUtils::fromInt(T_ARG(0)); - - if(block.width() == 0) { - params = block.getIArguments(); - } - else { - for (int i = 0; i < block.width(); i++) { - auto input = INPUT_VARIABLE(i); - REQUIRE_TRUE(input->rankOf() == 1, 0, "Inputs to eye should be 1D"); - - for (int e = 0; e < input->lengthOf(); e++) - params.emplace_back(input->e(e)); - } - } - - - REQUIRE_TRUE(params.size() > 0, 0, "Size is not provided for eye op."); - - const bool ordered = (params[0] == -99 || params[0] == -102); // -99 :'c', -102 : 'f' - if (!ordered) - params.insert(params.begin(), -99); - - REQUIRE_TRUE(params.size() > 1, 0, "Size is not provided for eye op."); - - Nd4jLong* outShapeInfo(nullptr); - - const int size = params.size(); + return Status::OK(); +} - switch(size) { +DECLARE_TYPES(eye) { + getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS}); + getOpDescriptor()->setAllowedInputTypes(1, + {DataType::INT32, DataType::INT64}); + getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - case 2: - ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); - outShapeInfo[0] = 2; - outShapeInfo[1] = params[1]; - outShapeInfo[2] = params[1]; - break; +DECLARE_SHAPE_FN(eye) { + std::vector params; - case 3: - ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); - outShapeInfo[0] = 2; - outShapeInfo[1] = params[1]; - outShapeInfo[2] = params[2]; - break; + sd::DataType dtype = block.getTArguments().empty() + ? sd::DataType::FLOAT32 + : sd::DataTypeUtils::fromInt(T_ARG(0)); - default: - int rank = size-1; - ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); - outShapeInfo[0] = rank; - outShapeInfo[rank-1] = params[1]; - outShapeInfo[rank] = params[2]; - for(int i = 1; i < rank-1; ++i) - outShapeInfo[i] = params[i+2]; - break; - } + if (block.width() == 0) { + params = block.getIArguments(); + } else { + for (int i = 0; i < block.width(); i++) { + auto input = INPUT_VARIABLE(i); + REQUIRE_TRUE(input->rankOf() == 1, 0, "Inputs to eye should be 1D"); - shape::updateStrides(outShapeInfo, static_cast(-params[0])); - auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo, dtype)); - RELEASE(outShapeInfo, block.workspace()); - return SHAPELIST(result); + for (int e = 0; e < input->lengthOf(); e++) + params.emplace_back(input->e(e)); } - - -} + } + + REQUIRE_TRUE(params.size() > 0, 0, "Size is not provided for eye op."); + + const bool ordered = + (params[0] == -99 || params[0] == -102); // -99 :'c', -102 : 'f' + if (!ordered) params.insert(params.begin(), -99); + + REQUIRE_TRUE(params.size() > 1, 0, "Size is not provided for eye op."); + + Nd4jLong* outShapeInfo(nullptr); + + const int size = params.size(); + + switch (size) { + case 2: + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), + Nd4jLong); + outShapeInfo[0] = 2; + outShapeInfo[1] = params[1]; + outShapeInfo[2] = params[1]; + break; + + case 3: + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), + Nd4jLong); + outShapeInfo[0] = 2; + outShapeInfo[1] = params[1]; + outShapeInfo[2] = params[2]; + break; + + default: + int rank = size - 1; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + outShapeInfo[0] = rank; + outShapeInfo[rank - 1] = params[1]; + outShapeInfo[rank] = params[2]; + for (int i = 1; i < rank - 1; ++i) outShapeInfo[i] = params[i + 2]; + break; + } + + shape::updateStrides(outShapeInfo, static_cast(-params[0])); + auto result = ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(outShapeInfo, dtype)); + RELEASE(outShapeInfo, block.workspace()); + return SHAPELIST(result); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/lgamma.cpp b/libnd4j/include/ops/declarable/generic/linalg/lgamma.cpp index c39f8b55da3e..b01ecba3947a 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/lgamma.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/lgamma.cpp @@ -26,25 +26,24 @@ #include namespace sd { -namespace ops { +namespace ops { OP_IMPL(lgamma, 1, 1, true) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + helpers::lgamma(block.launchContext(), *x, *z); - helpers::lgamma(block.launchContext(), *x, *z); - - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(lgamma) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) // as TF says - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) // as TF says + ->setSameMode(true); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/log1p.cpp b/libnd4j/include/ops/declarable/generic/linalg/log1p.cpp index 797ca8b2a487..1e99fc6377bb 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/log1p.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/log1p.cpp @@ -24,25 +24,25 @@ #include namespace sd { - namespace ops { - OP_IMPL(Log1p, 1, 1, true) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +OP_IMPL(Log1p, 1, 1, true) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - x->applyTransform(transform::Log1p, *z); + x->applyTransform(transform::Log1p, *z); - STORE_RESULT(z); + STORE_RESULT(z); - return Status::OK(); - } - DECLARE_SYN(log1p, Log1p); - } + return Status::OK(); +} +DECLARE_SYN(log1p, Log1p); +} // namespace ops - DECLARE_TYPES(Log1p) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(Log1p) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/lstsq.cpp b/libnd4j/include/ops/declarable/generic/linalg/lstsq.cpp index 81831e3fc226..eb9c6c796748 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/lstsq.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/lstsq.cpp @@ -25,112 +25,144 @@ #include namespace sd { - namespace ops { - - CUSTOM_OP_IMPL(lstsq, 2, 1, false, 0, 0) { - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - auto z = OUTPUT_NULLIFIED(0); - bool fastFlag = true; - double l2_factor = 0.; - if (block.numB() > 0) { - fastFlag = B_ARG(0); - } - if (block.numT() > 0) { - l2_factor = T_ARG(0); - } - REQUIRE_TRUE(a->rankOf() >=2, 0, "lstsq: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf()); - REQUIRE_TRUE(b->rankOf() >=2, 0, "lstsq: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf()); - -// REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "lstsq: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2)); - REQUIRE_TRUE(a->sizeAt(-2) == b->sizeAt(-2), 0, "lstsq: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2)); - //REQUIRE_TRUE(l2_factor == 0., 0, "lstsq: Implementation of operation is not finished for factor difference from 0."); - if (a->isEmpty() || b->isEmpty() || z->isEmpty()) - return Status::OK(); - - auto res = helpers::leastSquaresSolveFunctor(block.launchContext(), a, b, l2_factor, fastFlag, z); - - return res; - } - - CUSTOM_OP_IMPL(solve_ls, 2, 1, false, 0, 0) { - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - auto z = OUTPUT_NULLIFIED(0); - bool fastFlag = true; - double l2_factor = 0.; - if (block.numB() > 0) { - fastFlag = B_ARG(0); - } - if (block.numT() > 0) { - l2_factor = T_ARG(0); - } - REQUIRE_TRUE(a->rankOf() >=2, 0, "lstsq: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf()); - REQUIRE_TRUE(b->rankOf() >=2, 0, "lstsq: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf()); - -// REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "lstsq: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2)); - REQUIRE_TRUE(a->sizeAt(-2) == b->sizeAt(-2), 0, "lstsq: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2)); - //REQUIRE_TRUE(l2_factor == 0., 0, "lstsq: Implementation of operation is not finished for factor difference from 0."); - auto res = Status::OK(); - if (a->isEmpty() || b->isEmpty() || z->isEmpty()) - return res; - - res = helpers::leastSquaresSolveFunctor(block.launchContext(), a, b, l2_factor, fastFlag, z); - - return res; - } - - DECLARE_SYN(MatrixSolveLs, lstsq); - - DECLARE_SHAPE_FN(lstsq) { - auto in0 = inputShape->at(0); - auto in1 = inputShape->at(1); - auto shapeOf = ShapeUtils::shapeAsVector(in1); - auto rank = shapeOf.size(); - shapeOf[rank - 2] = shape::sizeAt(in0, -1); - - if (shape::isEmpty(in0) || shape::isEmpty(in1)) { - shapeOf[rank - 1] = 0; // set output shape to empty - } - auto resShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in0), shape::order(in1), shapeOf);//ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace()); - if (shapeOf[rank - 1] == 0) { -// ArrayOptions::setPropertyBit(resShape, ARRAY_EMPTY); - resShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(in0)); - } - return SHAPELIST(resShape); - } - - DECLARE_TYPES(lstsq) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(false); - } - DECLARE_SHAPE_FN(solve_ls) { - auto in0 = inputShape->at(0); - auto in1 = inputShape->at(1); - auto shapeOf = ShapeUtils::shapeAsVector(in1); - auto rank = shapeOf.size(); - shapeOf[rank - 2] = shape::sizeAt(in0, -1); - - if (shape::isEmpty(in0) || shape::isEmpty(in1)) { - shapeOf[rank - 1] = 0; // set output shape to empty - } - auto resShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in0), shape::order(in1), shapeOf);//ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace()); - if (shapeOf[rank - 1] == 0) { - resShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(in1)); -// ArrayOptions::setPropertyBit(resShape, ARRAY_EMPTY); - } - return SHAPELIST(resShape); - } - - DECLARE_TYPES(solve_ls) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(false); - } - } +namespace ops { + +CUSTOM_OP_IMPL(lstsq, 2, 1, false, 0, 0) { + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); + auto z = OUTPUT_NULLIFIED(0); + bool fastFlag = true; + double l2_factor = 0.; + if (block.numB() > 0) { + fastFlag = B_ARG(0); + } + if (block.numT() > 0) { + l2_factor = T_ARG(0); + } + REQUIRE_TRUE(a->rankOf() >= 2, 0, + "lstsq: The rank of input left tensor should not be less than " + "2, but %i is given", + a->rankOf()); + REQUIRE_TRUE(b->rankOf() >= 2, 0, + "lstsq: The rank of input right tensor should not be less than " + "2, but %i is given", + b->rankOf()); + + // REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "lstsq: The last + // two dimmensions should be equal, but %i and %i are given", + // a->sizeAt(-1), a->sizeAt(-2)); + REQUIRE_TRUE(a->sizeAt(-2) == b->sizeAt(-2), 0, + "lstsq: The last dimmension of left part should be equal to " + "prelast of right part, but %i and %i are given", + a->sizeAt(-1), b->sizeAt(-2)); + // REQUIRE_TRUE(l2_factor == 0., 0, "lstsq: Implementation of operation is not + // finished for factor difference from 0."); + if (a->isEmpty() || b->isEmpty() || z->isEmpty()) return Status::OK(); + + auto res = helpers::leastSquaresSolveFunctor(block.launchContext(), a, b, + l2_factor, fastFlag, z); + + return res; } +CUSTOM_OP_IMPL(solve_ls, 2, 1, false, 0, 0) { + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); + auto z = OUTPUT_NULLIFIED(0); + bool fastFlag = true; + double l2_factor = 0.; + if (block.numB() > 0) { + fastFlag = B_ARG(0); + } + if (block.numT() > 0) { + l2_factor = T_ARG(0); + } + REQUIRE_TRUE(a->rankOf() >= 2, 0, + "lstsq: The rank of input left tensor should not be less than " + "2, but %i is given", + a->rankOf()); + REQUIRE_TRUE(b->rankOf() >= 2, 0, + "lstsq: The rank of input right tensor should not be less than " + "2, but %i is given", + b->rankOf()); + + // REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "lstsq: The last + // two dimmensions should be equal, but %i and %i are given", + // a->sizeAt(-1), a->sizeAt(-2)); + REQUIRE_TRUE(a->sizeAt(-2) == b->sizeAt(-2), 0, + "lstsq: The last dimmension of left part should be equal to " + "prelast of right part, but %i and %i are given", + a->sizeAt(-1), b->sizeAt(-2)); + // REQUIRE_TRUE(l2_factor == 0., 0, "lstsq: Implementation of operation is not + // finished for factor difference from 0."); + auto res = Status::OK(); + if (a->isEmpty() || b->isEmpty() || z->isEmpty()) return res; + + res = helpers::leastSquaresSolveFunctor(block.launchContext(), a, b, + l2_factor, fastFlag, z); + + return res; +} + +DECLARE_SYN(MatrixSolveLs, lstsq); + +DECLARE_SHAPE_FN(lstsq) { + auto in0 = inputShape->at(0); + auto in1 = inputShape->at(1); + auto shapeOf = ShapeUtils::shapeAsVector(in1); + auto rank = shapeOf.size(); + shapeOf[rank - 2] = shape::sizeAt(in0, -1); + + if (shape::isEmpty(in0) || shape::isEmpty(in1)) { + shapeOf[rank - 1] = 0; // set output shape to empty + } + auto resShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(in0), shape::order(in1), + shapeOf); // ShapeBuilders::copyShapeInfoAndType(in1, in0, true, + // block.workspace()); + if (shapeOf[rank - 1] == 0) { + // ArrayOptions::setPropertyBit(resShape, ARRAY_EMPTY); + resShape = ConstantShapeHelper::getInstance()->emptyShapeInfo( + ArrayOptions::dataType(in0)); + } + return SHAPELIST(resShape); +} + +DECLARE_TYPES(lstsq) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(false); +} +DECLARE_SHAPE_FN(solve_ls) { + auto in0 = inputShape->at(0); + auto in1 = inputShape->at(1); + auto shapeOf = ShapeUtils::shapeAsVector(in1); + auto rank = shapeOf.size(); + shapeOf[rank - 2] = shape::sizeAt(in0, -1); + + if (shape::isEmpty(in0) || shape::isEmpty(in1)) { + shapeOf[rank - 1] = 0; // set output shape to empty + } + auto resShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(in0), shape::order(in1), + shapeOf); // ShapeBuilders::copyShapeInfoAndType(in1, in0, true, + // block.workspace()); + if (shapeOf[rank - 1] == 0) { + resShape = ConstantShapeHelper::getInstance()->emptyShapeInfo( + ArrayOptions::dataType(in1)); + // ArrayOptions::setPropertyBit(resShape, ARRAY_EMPTY); + } + return SHAPELIST(resShape); +} + +DECLARE_TYPES(solve_ls) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(false); +} +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/lup.cpp b/libnd4j/include/ops/declarable/generic/linalg/lup.cpp index 1607a08ea458..21123947c81c 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/lup.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/lup.cpp @@ -24,45 +24,62 @@ #include #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(lu, 1, 2, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(lu, 1, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - auto p = OUTPUT_VARIABLE(1); - if (block.numI()) { - DataType dtype = (DataType)INT_ARG(0); - REQUIRE_TRUE(dtype == sd::DataType::INT32 || dtype == sd::DataType::INT64, 0, "lu: Permutation data type should be 32bit or 64bit int only, but '%s' given.", DataTypeUtils::asString(dtype).c_str()); } + auto p = OUTPUT_VARIABLE(1); + if (block.numI()) { + DataType dtype = (DataType)INT_ARG(0); + REQUIRE_TRUE(dtype == sd::DataType::INT32 || dtype == sd::DataType::INT64, + 0, + "lu: Permutation data type should be 32bit or 64bit int only, " + "but '%s' given.", + DataTypeUtils::asString(dtype).c_str()); + } - REQUIRE_TRUE(input->rankOf() >=2, 0, "lu: The rank of input array should not less than 2, but %i is given", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "lu: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); + REQUIRE_TRUE( + input->rankOf() >= 2, 0, + "lu: The rank of input array should not less than 2, but %i is given", + input->rankOf()); + REQUIRE_TRUE( + input->sizeAt(-1) == input->sizeAt(-2), 0, + "lu: The last two dimmensions should be equal, but %i and %i are given", + input->sizeAt(-1), input->sizeAt(-2)); - helpers::lu(block.launchContext(), input, z, p); - return Status::OK(); - } - - DECLARE_SHAPE_FN(lu) { - auto in = inputShape->at(0); - auto shapeVector = ShapeUtils::shapeAsVector(in); - auto luShape = ShapeBuilders::copyShapeInfoAndType(in, in, true, block.workspace()); - auto dtype = sd::DataType::INT32; - if (block.numI()) { - dtype = (DataType)INT_ARG(0); - REQUIRE_TRUE(dtype == sd::DataType::INT32 || dtype == sd::DataType::INT64, 0, "lu: Permutation data type should be 32bit or 64bit int only, but '%s' given.", DataTypeUtils::asString(dtype).c_str()); - } - auto luP = ShapeBuilders::createShapeInfo(dtype, shape::order(in), shapeVector.size() - 1, - shapeVector.data(), block.workspace()); - return SHAPELIST(CONSTANT(luShape), CONSTANT(luP)); - } + helpers::lu(block.launchContext(), input, z, p); + return Status::OK(); +} + +DECLARE_SHAPE_FN(lu) { + auto in = inputShape->at(0); + auto shapeVector = ShapeUtils::shapeAsVector(in); + auto luShape = + ShapeBuilders::copyShapeInfoAndType(in, in, true, block.workspace()); + auto dtype = sd::DataType::INT32; + if (block.numI()) { + dtype = (DataType)INT_ARG(0); + REQUIRE_TRUE(dtype == sd::DataType::INT32 || dtype == sd::DataType::INT64, + 0, + "lu: Permutation data type should be 32bit or 64bit int only, " + "but '%s' given.", + DataTypeUtils::asString(dtype).c_str()); + } + auto luP = ShapeBuilders::createShapeInfo( + dtype, shape::order(in), shapeVector.size() - 1, shapeVector.data(), + block.workspace()); + return SHAPELIST(CONSTANT(luShape), CONSTANT(luP)); +} - DECLARE_TYPES(lu) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {sd::DataType::INT32, sd::DataType::INT64}) - ->setSameMode(false); - } - } +DECLARE_TYPES(lu) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {sd::DataType::INT32, sd::DataType::INT64}) + ->setSameMode(false); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/matrixDiagPart.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrixDiagPart.cpp index c97bfcb2a2ad..3ffd49df389c 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/matrixDiagPart.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/matrixDiagPart.cpp @@ -21,52 +21,56 @@ #include #include - namespace sd { - namespace ops { - CUSTOM_OP_IMPL(matrix_diag_part, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - const int inRank = input->rankOf(); - - REQUIRE_TRUE(inRank >= 2, 0, "CUSTOM_OP matrix_diag_part: input array must have rank >= 2, but %i given!", inRank); +namespace ops { +CUSTOM_OP_IMPL(matrix_diag_part, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + const int inRank = input->rankOf(); - output->nullify(); - return helpers::matrixDiagPart(block.launchContext(), input, output); - } + REQUIRE_TRUE(inRank >= 2, 0, + "CUSTOM_OP matrix_diag_part: input array must have rank >= 2, " + "but %i given!", + inRank); - DECLARE_SHAPE_FN(matrix_diag_part) { - Nd4jLong const* outShapeInfo = nullptr; - auto in = inputShape->at(0); - int inRank = shape::rank(in); + output->nullify(); + return helpers::matrixDiagPart(block.launchContext(), input, output); +} - REQUIRE_TRUE(inRank >= 2, 0, "CUSTOM_OP matrix_diag_part: input array must have rank >= 2, but %i given!", inRank); +DECLARE_SHAPE_FN(matrix_diag_part) { + Nd4jLong const* outShapeInfo = nullptr; + auto in = inputShape->at(0); + int inRank = shape::rank(in); - int outRank = inRank - 1; - int lastDimension = sd::math::nd4j_min(shape::sizeAt(in, -1), shape::sizeAt(in, -2)); - if(outRank == 1) { - //output shape is a vector with size min(sizeAt(0), sizeAt(1)) - outShapeInfo = ConstantShapeHelper::getInstance()->vectorShapeInfo(lastDimension, ArrayOptions::dataType(in)); - } - else { - Nd4jLong* anShapeInfo; - ALLOCATE(anShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - anShapeInfo[0] = outRank; - for(int i = 0; i < outRank - 1; ++i) - anShapeInfo[i + 1] = shape::sizeAt(in, i); - anShapeInfo[outRank] = lastDimension; + REQUIRE_TRUE(inRank >= 2, 0, + "CUSTOM_OP matrix_diag_part: input array must have rank >= 2, " + "but %i given!", + inRank); - ShapeUtils::updateStridesAndType(anShapeInfo, in, shape::order(in)); - outShapeInfo = CONSTANT(anShapeInfo); - } - return SHAPELIST(outShapeInfo); - } + int outRank = inRank - 1; + int lastDimension = + sd::math::nd4j_min(shape::sizeAt(in, -1), shape::sizeAt(in, -2)); + if (outRank == 1) { + // output shape is a vector with size min(sizeAt(0), sizeAt(1)) + outShapeInfo = ConstantShapeHelper::getInstance()->vectorShapeInfo( + lastDimension, ArrayOptions::dataType(in)); + } else { + Nd4jLong* anShapeInfo; + ALLOCATE(anShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + anShapeInfo[0] = outRank; + for (int i = 0; i < outRank - 1; ++i) + anShapeInfo[i + 1] = shape::sizeAt(in, i); + anShapeInfo[outRank] = lastDimension; - DECLARE_TYPES(matrix_diag_part) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } -} + ShapeUtils::updateStridesAndType(anShapeInfo, in, shape::order(in)); + outShapeInfo = CONSTANT(anShapeInfo); + } + return SHAPELIST(outShapeInfo); } +DECLARE_TYPES(matrix_diag_part) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/linalg/matrixSetDiag.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrixSetDiag.cpp index 222074c81c8d..730491e80599 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/matrixSetDiag.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/matrixSetDiag.cpp @@ -25,35 +25,50 @@ #include namespace sd { -namespace ops { +namespace ops { CONFIGURABLE_OP_IMPL(matrix_set_diag, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto diagonal = INPUT_VARIABLE(1); + auto input = INPUT_VARIABLE(0); + auto diagonal = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(diagonal->rankOf() == input->rankOf()-1, 0, "MATRIX_SET_DIAG op: rank of diagonal array must be smaller by one compared to rank of input array, but got %i and %i correspondingly !", diagonal->rankOf(), input->rankOf()); + REQUIRE_TRUE( + diagonal->rankOf() == input->rankOf() - 1, 0, + "MATRIX_SET_DIAG op: rank of diagonal array must be smaller by one " + "compared to rank of input array, but got %i and %i correspondingly !", + diagonal->rankOf(), input->rankOf()); - for(int i = 0; i < diagonal->rankOf() - 1; ++i) - REQUIRE_TRUE(diagonal->sizeAt(i) == input->sizeAt(i), 0, "MATRIX_SET_DIAG op: the shapes of diagonal and input arrays must be equal till last diagonal dimension but one, however got diagonal=%s and input=%s instead !", ShapeUtils::shapeAsString(diagonal).c_str(), ShapeUtils::shapeAsString(input).c_str()); + for (int i = 0; i < diagonal->rankOf() - 1; ++i) + REQUIRE_TRUE(diagonal->sizeAt(i) == input->sizeAt(i), 0, + "MATRIX_SET_DIAG op: the shapes of diagonal and input arrays " + "must be equal till last diagonal dimension but one, however " + "got diagonal=%s and input=%s instead !", + ShapeUtils::shapeAsString(diagonal).c_str(), + ShapeUtils::shapeAsString(input).c_str()); - REQUIRE_TRUE(diagonal->sizeAt(-1) == (int)sd::math::nd4j_min(input->sizeAt(-1), input->sizeAt(-2)), 0, "MATRIX_SET_DIAG op: the value of last dimension of diagonal array must be equal to min(input_last_shape=%i, input_last_but_one_shape=%i), but got %i instead !", input->sizeAt(-1), input->sizeAt(-2), diagonal->sizeAt(-1)); + REQUIRE_TRUE( + diagonal->sizeAt(-1) == (int)sd::math::nd4j_min( + input->sizeAt(-1), input->sizeAt(-2)), + 0, + "MATRIX_SET_DIAG op: the value of last dimension of diagonal array must " + "be equal to min(input_last_shape=%i, input_last_but_one_shape=%i), but " + "got %i instead !", + input->sizeAt(-1), input->sizeAt(-2), diagonal->sizeAt(-1)); - helpers::matrixSetDiag(block.launchContext(), *input, *diagonal, *output, false); + helpers::matrixSetDiag(block.launchContext(), *input, *diagonal, *output, + false); - return Status::OK(); + return Status::OK(); } DECLARE_SYN(MatrixSetDiag, matrix_set_diag); - DECLARE_TYPES(matrix_set_diag) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - -} +DECLARE_TYPES(matrix_set_diag) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/matrix_band_part.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrix_band_part.cpp index 08e059c374c0..22c9c83e3843 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/matrix_band_part.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/matrix_band_part.cpp @@ -20,39 +20,44 @@ #include #if NOT_EXCLUDED(OP_matrix_band_part) -#include #include +#include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(matrix_band_part, 1, 1, true, 0, 2) { - - auto input = INPUT_VARIABLE(0); - - auto output = OUTPUT_VARIABLE(0); - Nd4jLong minLower = INT_ARG(0); - Nd4jLong maxUpper = INT_ARG(1); - - REQUIRE_TRUE(input->rankOf() >= 2, 0, "matrix_band_part: Input rank should be 2 or greater."); - Nd4jLong N = input->sizeAt(-2); - Nd4jLong M = input->sizeAt(-1); - REQUIRE_TRUE(minLower > -N && minLower < N, 0, "matrix_band_part: lower diagonal count %i should be less than %i.", - minLower, N); - REQUIRE_TRUE(maxUpper > -M && maxUpper < M, 0, "matrix_band_part: upper diagonal count %i should be less than %i.", - maxUpper, M); - - helpers::matrixBandPart(block.launchContext(), input, output, minLower, maxUpper); - return ND4J_STATUS_OK; - } - DECLARE_SYN(band_part, matrix_band_part); - } - - DECLARE_TYPES(matrix_band_part) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setSameMode(true); - } +namespace ops { +CONFIGURABLE_OP_IMPL(matrix_band_part, 1, 1, true, 0, 2) { + auto input = INPUT_VARIABLE(0); + + auto output = OUTPUT_VARIABLE(0); + Nd4jLong minLower = INT_ARG(0); + Nd4jLong maxUpper = INT_ARG(1); + + REQUIRE_TRUE(input->rankOf() >= 2, 0, + "matrix_band_part: Input rank should be 2 or greater."); + Nd4jLong N = input->sizeAt(-2); + Nd4jLong M = input->sizeAt(-1); + REQUIRE_TRUE( + minLower > -N && minLower < N, 0, + "matrix_band_part: lower diagonal count %i should be less than %i.", + minLower, N); + REQUIRE_TRUE( + maxUpper > -M && maxUpper < M, 0, + "matrix_band_part: upper diagonal count %i should be less than %i.", + maxUpper, M); + + helpers::matrixBandPart(block.launchContext(), input, output, minLower, + maxUpper); + return ND4J_STATUS_OK; +} +DECLARE_SYN(band_part, matrix_band_part); +} // namespace ops + +DECLARE_TYPES(matrix_band_part) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setSameMode(true); } +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/matrix_determinant.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrix_determinant.cpp index edd10e6ea600..21e6eaff975f 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/matrix_determinant.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/matrix_determinant.cpp @@ -18,128 +18,157 @@ // Created by GS at 2/26/2018 // -#include #include #include +#include #if NOT_EXCLUDED(OP_matrix_determinant) namespace sd { - namespace ops { - CUSTOM_OP_IMPL(matrix_determinant, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() >=2, 0, "matrix_determinant: The rank of input array should not less than 2, but %i is given", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "matrix_determinant: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); - - return helpers::determinant(block.launchContext(), input, output); - } - - DECLARE_SHAPE_FN(matrix_determinant) { - auto inShape = inputShape->at(0); - - Nd4jLong const* determinantShape; - int targetRank = shape::rank(inShape) - 2; // last two dimensions will be reduced to scalar - - if (targetRank == 0) { // scalar only - determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape)); - } - else if (targetRank == 1) { // vector - determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape)); - } - else { // only two last dimensions are excluded - determinantShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, shape::shapeOf(inShape)); - } - return SHAPELIST(determinantShape); - } - - DECLARE_TYPES(matrix_determinant) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +namespace ops { +CUSTOM_OP_IMPL(matrix_determinant, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() >= 2, 0, + "matrix_determinant: The rank of input array should not less " + "than 2, but %i is given", + input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, + "matrix_determinant: The last two dimmensions should be equal, " + "but %i and %i are given", + input->sizeAt(-1), input->sizeAt(-2)); + + return helpers::determinant(block.launchContext(), input, output); +} + +DECLARE_SHAPE_FN(matrix_determinant) { + auto inShape = inputShape->at(0); + + Nd4jLong const* determinantShape; + int targetRank = shape::rank(inShape) - + 2; // last two dimensions will be reduced to scalar + + if (targetRank == 0) { // scalar only + determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo( + ArrayOptions::dataType(inShape)); + } else if (targetRank == 1) { // vector + determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape)); + } else { // only two last dimensions are excluded + determinantShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, + shape::shapeOf(inShape)); + } + return SHAPELIST(determinantShape); } +DECLARE_TYPES(matrix_determinant) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} +} // namespace ops +} // namespace sd + #endif #if NOT_EXCLUDED(OP_log_matrix_determinant) namespace sd { - namespace ops { - DECLARE_TYPES(log_matrix_determinant) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(log_matrix_determinant, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() >=2, 0, "log_matrix_determinant: The rank of input array should not less than 2, but %i is given", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "log_matrix_determinant: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); - - return helpers::logAbsDeterminant(block.launchContext(), input, output); - } - - DECLARE_SHAPE_FN(log_matrix_determinant) { - auto inShape = inputShape->at(0); - - Nd4jLong const* determinantShape; - int targetRank = shape::rank(inShape) - 2; // last two dimensions will be reduced to scalar - - if (targetRank == 0) { // scalar only - determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape)); - } - else if (targetRank == 1) { // vector - determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape)); - } - else { // only two last dimensions are excluded - determinantShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, shape::shapeOf(inShape)); - } - return SHAPELIST(determinantShape); - } - } +namespace ops { +DECLARE_TYPES(log_matrix_determinant) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(log_matrix_determinant, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() >= 2, 0, + "log_matrix_determinant: The rank of input array should not " + "less than 2, but %i is given", + input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, + "log_matrix_determinant: The last two dimmensions should be " + "equal, but %i and %i are given", + input->sizeAt(-1), input->sizeAt(-2)); + + return helpers::logAbsDeterminant(block.launchContext(), input, output); +} + +DECLARE_SHAPE_FN(log_matrix_determinant) { + auto inShape = inputShape->at(0); + + Nd4jLong const* determinantShape; + int targetRank = shape::rank(inShape) - + 2; // last two dimensions will be reduced to scalar + + if (targetRank == 0) { // scalar only + determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo( + ArrayOptions::dataType(inShape)); + } else if (targetRank == 1) { // vector + determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape)); + } else { // only two last dimensions are excluded + determinantShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, + shape::shapeOf(inShape)); + } + return SHAPELIST(determinantShape); } +} // namespace ops +} // namespace sd #endif #if NOT_EXCLUDED(OP_logdet) namespace sd { - namespace ops { - DECLARE_TYPES(logdet) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(logdet, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_NULLIFIED(0); - - REQUIRE_TRUE(input->rankOf() >=2, 0, "logdet: The rank of input array should not less than 2, but %i is given", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "logdet: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); - REQUIRE_TRUE(helpers::checkCholeskyInput(block.launchContext(), input), 0, "logdet: The input tensor should be positive-defined hermitian."); - - return helpers::logdetFunctor(block.launchContext(), input, output); - } - - DECLARE_SHAPE_FN(logdet) { - auto inShape = inputShape->at(0); - - Nd4jLong const* determinantShape; - int targetRank = shape::rank(inShape) - 2; // last two dimensions will be reduced to scalar - - if (targetRank == 0) { // scalar only - determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape)); - } - else if (targetRank == 1) { // vector - determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape)); - } - else { // only two last dimensions are excluded - determinantShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, shape::shapeOf(inShape)); - } - return SHAPELIST(determinantShape); - } - } +namespace ops { +DECLARE_TYPES(logdet) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CUSTOM_OP_IMPL(logdet, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); + + REQUIRE_TRUE( + input->rankOf() >= 2, 0, + "logdet: The rank of input array should not less than 2, but %i is given", + input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, + "logdet: The last two dimmensions should be equal, but %i and " + "%i are given", + input->sizeAt(-1), input->sizeAt(-2)); + REQUIRE_TRUE( + helpers::checkCholeskyInput(block.launchContext(), input), 0, + "logdet: The input tensor should be positive-defined hermitian."); + + return helpers::logdetFunctor(block.launchContext(), input, output); +} + +DECLARE_SHAPE_FN(logdet) { + auto inShape = inputShape->at(0); + + Nd4jLong const* determinantShape; + int targetRank = shape::rank(inShape) - + 2; // last two dimensions will be reduced to scalar + + if (targetRank == 0) { // scalar only + determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo( + ArrayOptions::dataType(inShape)); + } else if (targetRank == 1) { // vector + determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape)); + } else { // only two last dimensions are excluded + determinantShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, + shape::shapeOf(inShape)); + } + return SHAPELIST(determinantShape); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/linalg/matrix_diag.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrix_diag.cpp index 79061ca5e31c..c07f8e36bbae 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/matrix_diag.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/matrix_diag.cpp @@ -23,46 +23,45 @@ #include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(matrix_diag, 1, 1, false, 0, 0) { + auto diagonal = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto diagonal = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(!diagonal->isScalar(), 0, + "CUSTOM_OP matrix_diag: input diagonal array must be at list a " + "vector, but scalar was given!"); - REQUIRE_TRUE(!diagonal->isScalar(), 0, "CUSTOM_OP matrix_diag: input diagonal array must be at list a vector, but scalar was given!"); + helpers::matrixSetDiag(block.launchContext(), *output, *diagonal, *output, + true); - helpers::matrixSetDiag(block.launchContext(), *output, *diagonal, *output, true); - - return Status::OK(); + return Status::OK(); } DECLARE_SHAPE_FN(matrix_diag) { + Nd4jLong* outShapeInfo = nullptr; + auto in = inputShape->at(0); + int inRank = shape::rank(in); - Nd4jLong* outShapeInfo = nullptr; - auto in = inputShape->at(0); - int inRank = shape::rank(in); - - // if for example diagonal array has shape [A,B,C] then output array has shape [A,B,C,C] + // if for example diagonal array has shape [A,B,C] then output array has + // shape [A,B,C,C] - int outRank = inRank + 1; + int outRank = inRank + 1; - ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - outShapeInfo[0] = outRank; - for(int i = 0; i < inRank; ++i) - outShapeInfo[i + 1] = shape::sizeAt(in, i); - outShapeInfo[outRank] = shape::sizeAt(in, -1); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + outShapeInfo[0] = outRank; + for (int i = 0; i < inRank; ++i) outShapeInfo[i + 1] = shape::sizeAt(in, i); + outShapeInfo[outRank] = shape::sizeAt(in, -1); - ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); + ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); - return SHAPELIST(CONSTANT(outShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(matrix_diag) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); -} + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } -} - +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/linalg/matrix_inverse.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrix_inverse.cpp index 6a595a92b300..b5c5d7bba107 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/matrix_inverse.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/matrix_inverse.cpp @@ -24,23 +24,27 @@ #include #include namespace sd { - namespace ops { - OP_IMPL(matrix_inverse, 1, 1, true) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +OP_IMPL(matrix_inverse, 1, 1, true) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(input->rankOf() >=2, 0, "matrix_inverse: The rank of input array should not less than 2, but %i is given", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "matrix_inverse: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); + REQUIRE_TRUE(input->rankOf() >= 2, 0, + "matrix_inverse: The rank of input array should not less than " + "2, but %i is given", + input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, + "matrix_inverse: The last two dimmensions should be equal, but " + "%i and %i are given", + input->sizeAt(-1), input->sizeAt(-2)); - return helpers::inverse(block.launchContext(), input, output); - } + return helpers::inverse(block.launchContext(), input, output); +} - DECLARE_TYPES(matrix_inverse) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +DECLARE_TYPES(matrix_inverse) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/moments.cpp b/libnd4j/include/ops/declarable/generic/linalg/moments.cpp index d9e9b00c39a0..5b38c1efb6a1 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/moments.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/moments.cpp @@ -25,68 +25,69 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(moments, 1, 2, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto means = OUTPUT_VARIABLE(0); - auto variances = OUTPUT_VARIABLE(1); - - auto axis = block.getIArguments(); - const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; - - // axis might be dynamic (i.e. tf mode) - if (block.width() > 1 && axis.size() == 0) { - auto axisVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axisVector, axis); -// for (int e = 0; e < axisVector->lengthOf(); e++) { -// int ca = (int) axisVector->e(e); -// if (ca < 0) -// ca += input->rankOf(); -// -// axis.emplace_back(ca); -// } - - } - - std::vector& dims = axis; - input->varianceAlongDimension(variance::SummaryStatsVariance, *variances, false, axis); - input->reduceAlongDimension(reduce::Mean, *means, axis, keepDims); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(moments) { - auto axis = block.getIArguments(); - auto input = INPUT_VARIABLE(0); - - // axis might be dynamic (i.e. tf mode) - if (block.width() > 1 && axis.size() == 0) { - auto axisVector = INPUT_VARIABLE(1); - - for (int e = 0; e < axisVector->lengthOf(); e++) { - int ca = axisVector->e(e); - if (ca < 0) - ca += input->rankOf(); +namespace ops { +CUSTOM_OP_IMPL(moments, 1, 2, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto means = OUTPUT_VARIABLE(0); + auto variances = OUTPUT_VARIABLE(1); + + auto axis = block.getIArguments(); + const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; + + // axis might be dynamic (i.e. tf mode) + if (block.width() > 1 && axis.size() == 0) { + auto axisVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axisVector, axis); + // for (int e = 0; e < axisVector->lengthOf(); e++) { + // int ca = (int) axisVector->e(e); + // if (ca < 0) + // ca += input->rankOf(); + // + // axis.emplace_back(ca); + // } + } + + std::vector& dims = axis; + input->varianceAlongDimension(variance::SummaryStatsVariance, *variances, + false, axis); + input->reduceAlongDimension(reduce::Mean, *means, axis, keepDims); + + return Status::OK(); +} - axis.emplace_back(ca); - } +DECLARE_SHAPE_FN(moments) { + auto axis = block.getIArguments(); + auto input = INPUT_VARIABLE(0); - } - //std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); - const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; + // axis might be dynamic (i.e. tf mode) + if (block.width() > 1 && axis.size() == 0) { + auto axisVector = INPUT_VARIABLE(1); - auto meanShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, keepDims, false, block.workspace()); - auto varianceShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, keepDims, false, block.workspace()); - return SHAPELIST(meanShape, varianceShape); - } + for (int e = 0; e < axisVector->lengthOf(); e++) { + int ca = axisVector->e(e); + if (ca < 0) ca += input->rankOf(); - DECLARE_TYPES(moments) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } + axis.emplace_back(ca); } + } + // std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), + // {axis}); + const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; + + auto meanShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, keepDims, + false, block.workspace()); + auto varianceShape = ShapeUtils::evalReduceShapeInfo( + 'c', axis, *input, keepDims, false, block.workspace()); + return SHAPELIST(meanShape, varianceShape); +} +DECLARE_TYPES(moments) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops + +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/polygamma.cpp b/libnd4j/include/ops/declarable/generic/linalg/polygamma.cpp index 35ffdcbc6e82..eccf5eafd99a 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/polygamma.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/polygamma.cpp @@ -25,38 +25,42 @@ #include namespace sd { -namespace ops { +namespace ops { CONFIGURABLE_OP_IMPL(polygamma, 2, 1, false, 0, 0) { - auto n = INPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(1); + auto n = INPUT_VARIABLE(0); + auto x = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(n->isSameShape(x), 0, "POLYGAMMA op: two input arrays n and x must have the same shapes, but got n=%s and x=%s instead !", ShapeUtils::shapeAsString(n).c_str(), ShapeUtils::shapeAsString(x).c_str()); + REQUIRE_TRUE(n->isSameShape(x), 0, + "POLYGAMMA op: two input arrays n and x must have the same " + "shapes, but got n=%s and x=%s instead !", + ShapeUtils::shapeAsString(n).c_str(), + ShapeUtils::shapeAsString(x).c_str()); - Nd4jLong arrLen = n->lengthOf(); - // FIXME: this shit should be single op call, not a loop! - auto nNegative = n->reduceNumber(sd::reduce::IsNegative, nullptr); - auto xPositive = x->reduceNumber(sd::reduce::IsPositive, nullptr); - bool nPositiveFlag = !nNegative.e(0); // require all n >= 0 - bool xPositiveFlag = xPositive.e(0); // require all x > 0 - REQUIRE_TRUE(nPositiveFlag, 0, "POLYGAMMA op: all elements of n array must be >= 0 !"); - REQUIRE_TRUE(xPositiveFlag, 0, "POLYGAMMA op: all elements of x array must be > 0 !"); + Nd4jLong arrLen = n->lengthOf(); + // FIXME: this shit should be single op call, not a loop! + auto nNegative = n->reduceNumber(sd::reduce::IsNegative, nullptr); + auto xPositive = x->reduceNumber(sd::reduce::IsPositive, nullptr); + bool nPositiveFlag = !nNegative.e(0); // require all n >= 0 + bool xPositiveFlag = xPositive.e(0); // require all x > 0 + REQUIRE_TRUE(nPositiveFlag, 0, + "POLYGAMMA op: all elements of n array must be >= 0 !"); + REQUIRE_TRUE(xPositiveFlag, 0, + "POLYGAMMA op: all elements of x array must be > 0 !"); - helpers::polyGamma(block.launchContext(), *n, *x, *output); - return Status::OK(); + helpers::polyGamma(block.launchContext(), *n, *x, *output); + return Status::OK(); } DECLARE_SYN(polyGamma, polygamma); DECLARE_SYN(PolyGamma, polygamma); - DECLARE_TYPES(polygamma) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setSameMode(true); - } -} +DECLARE_TYPES(polygamma) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/qr.cpp b/libnd4j/include/ops/declarable/generic/linalg/qr.cpp index ac8f6dbac652..d933c83ebf66 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/qr.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/qr.cpp @@ -18,73 +18,84 @@ // Created by GS at 12/20/2019 // -#include #include #include +#include #if NOT_EXCLUDED(OP_qr) namespace sd { - namespace ops { - CUSTOM_OP_IMPL(qr, 1, 2, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto outputQ = OUTPUT_VARIABLE(0); - auto outputR = OUTPUT_VARIABLE(1); - auto fullMatricies = false; - if (block.numB()) - fullMatricies = B_ARG(0); - - REQUIRE_TRUE(input->rankOf() >=2, 0, "qr: The rank of input array should not be less than 2, but %i is given", input->rankOf()); - REQUIRE_TRUE((fullMatricies && outputQ->sizeAt(-1) == input->sizeAt(-2)) || (!fullMatricies && outputQ->isSameShape(input)), 0, "qr: The last dimmensions should be equal to result Q, but %i and %i are given", outputQ->sizeAt(-1), input->sizeAt(-2)); - REQUIRE_TRUE((fullMatricies && outputR->sizeAt(-1) == input->sizeAt(-1)) || (!fullMatricies && outputR->sizeAt(-1) == outputR->sizeAt(-2)), 0, "qr: The last dimmensions should be equal to result R, but %i and %i are given", outputR->sizeAt(-1), input->sizeAt(-1)); - if (!input->isEmpty() && !outputQ->isEmpty() && !outputR->isEmpty()) - helpers::qr(block.launchContext(), input, outputQ, outputR, fullMatricies); +namespace ops { +CUSTOM_OP_IMPL(qr, 1, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto outputQ = OUTPUT_VARIABLE(0); + auto outputR = OUTPUT_VARIABLE(1); + auto fullMatricies = false; + if (block.numB()) fullMatricies = B_ARG(0); - return Status::OK(); - } + REQUIRE_TRUE( + input->rankOf() >= 2, 0, + "qr: The rank of input array should not be less than 2, but %i is given", + input->rankOf()); + REQUIRE_TRUE((fullMatricies && outputQ->sizeAt(-1) == input->sizeAt(-2)) || + (!fullMatricies && outputQ->isSameShape(input)), + 0, + "qr: The last dimmensions should be equal to result Q, but %i " + "and %i are given", + outputQ->sizeAt(-1), input->sizeAt(-2)); + REQUIRE_TRUE( + (fullMatricies && outputR->sizeAt(-1) == input->sizeAt(-1)) || + (!fullMatricies && outputR->sizeAt(-1) == outputR->sizeAt(-2)), + 0, + "qr: The last dimmensions should be equal to result R, but %i and %i are " + "given", + outputR->sizeAt(-1), input->sizeAt(-1)); + if (!input->isEmpty() && !outputQ->isEmpty() && !outputR->isEmpty()) + helpers::qr(block.launchContext(), input, outputQ, outputR, fullMatricies); - DECLARE_SHAPE_FN(qr) { - auto inShape = inputShape->at(0); + return Status::OK(); +} - Nd4jLong const* shapeQ; - Nd4jLong const* shapeR; - int targetRank = shape::rank(inShape); // last two dimensions will be reduced to scalar +DECLARE_SHAPE_FN(qr) { + auto inShape = inputShape->at(0); - auto fullMatricies = false; - if (block.numB()) - fullMatricies = B_ARG(0); + Nd4jLong const* shapeQ; + Nd4jLong const* shapeR; + int targetRank = + shape::rank(inShape); // last two dimensions will be reduced to scalar - auto shape = ShapeUtils::shapeAsVector(inShape); + auto fullMatricies = false; + if (block.numB()) fullMatricies = B_ARG(0); - if (!fullMatricies) { // outputs are: Q is MxN and R is NxN - shape[targetRank - 1] = shape::sizeAt(inShape, -1); - shape[targetRank - 2] = shape[targetRank - 1]; - shapeQ = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), - shape::order(inShape), targetRank, - shape::shapeOf(inShape)); - shapeR = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), - shape::order(inShape), shape); + auto shape = ShapeUtils::shapeAsVector(inShape); - } - else {// otherwise outputs are Q is MxM and R is MxN with zero filled rows - shape[targetRank - 1] = shape::sizeAt(inShape, -2); - shape[targetRank - 2] = shape[targetRank - 1]; - shapeR = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), - shape::order(inShape), targetRank, - shape::shapeOf(inShape)); - shapeQ = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), - shape::order(inShape), shape); - } + if (!fullMatricies) { // outputs are: Q is MxN and R is NxN + shape[targetRank - 1] = shape::sizeAt(inShape, -1); + shape[targetRank - 2] = shape[targetRank - 1]; + shapeQ = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, + shape::shapeOf(inShape)); + shapeR = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), shape); - return SHAPELIST(shapeQ, shapeR); + } else { // otherwise outputs are Q is MxM and R is MxN with zero filled rows + shape[targetRank - 1] = shape::sizeAt(inShape, -2); + shape[targetRank - 2] = shape[targetRank - 1]; + shapeR = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, + shape::shapeOf(inShape)); + shapeQ = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), shape); + } - } + return SHAPELIST(shapeQ, shapeR); +} - DECLARE_TYPES(qr) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(qr) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/linalg/solve.cpp b/libnd4j/include/ops/declarable/generic/linalg/solve.cpp index 154001684b14..90e77e621219 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/solve.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/solve.cpp @@ -24,55 +24,69 @@ #include #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(solve, 2, 1, false, 0, 0) { - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(solve, 2, 1, false, 0, 0) { + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - bool useAdjoint = false; + bool useAdjoint = false; - if (block.numB() > 0) { - useAdjoint = B_ARG(0); - } - REQUIRE_TRUE(shape::shapeEquals(a->rankOf() - 2, a->shapeInfo(), b->rankOf() - 2, b->shapeInfo()), 0, "solve: Input shapes should be alike."); - REQUIRE_TRUE(a->rankOf() >=2, 0, "solve: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf()); - REQUIRE_TRUE(b->rankOf() >=2, 0, "solve: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf()); + if (block.numB() > 0) { + useAdjoint = B_ARG(0); + } + REQUIRE_TRUE(shape::shapeEquals(a->rankOf() - 2, a->shapeInfo(), + b->rankOf() - 2, b->shapeInfo()), + 0, "solve: Input shapes should be alike."); + REQUIRE_TRUE(a->rankOf() >= 2, 0, + "solve: The rank of input left tensor should not be less than " + "2, but %i is given", + a->rankOf()); + REQUIRE_TRUE(b->rankOf() >= 2, 0, + "solve: The rank of input right tensor should not be less than " + "2, but %i is given", + b->rankOf()); - REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "solve: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2)); - REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, "solve: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2)); - if (a->isEmpty() || b->isEmpty() || z->isEmpty()) - return Status::OK(); + REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, + "solve: The last two dimmensions should be equal, but %i and %i " + "are given", + a->sizeAt(-1), a->sizeAt(-2)); + REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, + "solve: The last dimmension of left part should be equal to " + "prelast of right part, but %i and %i are given", + a->sizeAt(-1), b->sizeAt(-2)); + if (a->isEmpty() || b->isEmpty() || z->isEmpty()) return Status::OK(); - auto input = a; - if (useAdjoint) { - auto adjointA = a->ulike(); - helpers::adjointMatrix(block.launchContext(), a, &adjointA); - input = new NDArray(adjointA); //.detach(); - }; + auto input = a; + if (useAdjoint) { + auto adjointA = a->ulike(); + helpers::adjointMatrix(block.launchContext(), a, &adjointA); + input = new NDArray(adjointA); //.detach(); + }; - auto res = helpers::solveFunctor(block.launchContext(), input, b, useAdjoint, z); - if (input != a) - delete input; + auto res = + helpers::solveFunctor(block.launchContext(), input, b, useAdjoint, z); + if (input != a) delete input; - return Status::OK(); - } - - DECLARE_SHAPE_FN(solve) { - auto in0 = inputShape->at(1); - auto in1 = inputShape->at(1); - auto luShape = ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace()); + return Status::OK(); +} + +DECLARE_SHAPE_FN(solve) { + auto in0 = inputShape->at(1); + auto in1 = inputShape->at(1); + auto luShape = + ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace()); - return SHAPELIST(CONSTANT(luShape)); - } + return SHAPELIST(CONSTANT(luShape)); +} - DECLARE_TYPES(solve) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(false); - } - } +DECLARE_TYPES(solve) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(false); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/sufficient_statistics.cpp b/libnd4j/include/ops/declarable/generic/linalg/sufficient_statistics.cpp index 9a9fb730b66f..4983730234dd 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/sufficient_statistics.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/sufficient_statistics.cpp @@ -24,66 +24,68 @@ #include #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(sufficient_statistics, 2, 3, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto axisVector = INPUT_VARIABLE(1); - auto dataCount = OUTPUT_VARIABLE(0); - - auto sum = OUTPUT_VARIABLE(1); - auto squares = OUTPUT_VARIABLE(2); - - std::vector axis(axisVector->lengthOf());//*block.getIArguments(); - - // axis might be dynamic (i.e. tf mode) - helpers::adjustAxis(input->rankOf(), axisVector, axis); - - input->reduceAlongDimension(reduce::SquaredNorm, *squares, axis); - input->reduceAlongDimension(reduce::Sum, *sum, axis); - auto count = NDArrayFactory::create(input->dataType(), input->lengthOf() / sum->lengthOf()); - dataCount->assign(count); - if (block.numT() > 0) { - auto shift = OUTPUT_VARIABLE(3); - shift->assign(T_ARG(0)); - } - - return Status::OK(); - } - - DECLARE_TYPES(sufficient_statistics) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}); - getOpDescriptor() - ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); - getOpDescriptor() - ->setAllowedOutputTypes(0, DataType::INHERIT); - getOpDescriptor() - ->setAllowedOutputTypes(1, DataType::INHERIT); - getOpDescriptor() - ->setAllowedOutputTypes(2, DataType::INHERIT); - } - - DECLARE_SHAPE_FN(sufficient_statistics) { - auto axisVector = INPUT_VARIABLE(1); - std::vector axis(axisVector->lengthOf()); - - auto input = INPUT_VARIABLE(0); - helpers::adjustAxis(input->rankOf(), axisVector, axis); - - //std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); - auto scalarShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0))); - auto sumShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, false, false, block.workspace()); - - auto squareShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, false, false, block.workspace()); - - auto shapeList = SHAPELIST(scalarShape, sumShape, squareShape); - if (block.numT() > 0) - shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0)))); - - return shapeList; - } - } +namespace ops { +CUSTOM_OP_IMPL(sufficient_statistics, 2, 3, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto axisVector = INPUT_VARIABLE(1); + auto dataCount = OUTPUT_VARIABLE(0); + + auto sum = OUTPUT_VARIABLE(1); + auto squares = OUTPUT_VARIABLE(2); + + std::vector axis(axisVector->lengthOf()); //*block.getIArguments(); + + // axis might be dynamic (i.e. tf mode) + helpers::adjustAxis(input->rankOf(), axisVector, axis); + + input->reduceAlongDimension(reduce::SquaredNorm, *squares, axis); + input->reduceAlongDimension(reduce::Sum, *sum, axis); + auto count = NDArrayFactory::create(input->dataType(), + input->lengthOf() / sum->lengthOf()); + dataCount->assign(count); + if (block.numT() > 0) { + auto shift = OUTPUT_VARIABLE(3); + shift->assign(T_ARG(0)); + } + + return Status::OK(); +} +DECLARE_TYPES(sufficient_statistics) { + getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}); + getOpDescriptor()->setAllowedInputTypes(1, + {DataType::INT32, DataType::INT64}); + getOpDescriptor()->setAllowedOutputTypes(0, DataType::INHERIT); + getOpDescriptor()->setAllowedOutputTypes(1, DataType::INHERIT); + getOpDescriptor()->setAllowedOutputTypes(2, DataType::INHERIT); } +DECLARE_SHAPE_FN(sufficient_statistics) { + auto axisVector = INPUT_VARIABLE(1); + std::vector axis(axisVector->lengthOf()); + + auto input = INPUT_VARIABLE(0); + helpers::adjustAxis(input->rankOf(), axisVector, axis); + + // std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), + // {axis}); + auto scalarShape = ConstantShapeHelper::getInstance()->scalarShapeInfo( + ArrayOptions::dataType(inputShape->at(0))); + auto sumShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, false, + false, block.workspace()); + + auto squareShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, false, + false, block.workspace()); + + auto shapeList = SHAPELIST(scalarShape, sumShape, squareShape); + if (block.numT() > 0) + shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo( + ArrayOptions::dataType(inputShape->at(0)))); + + return shapeList; +} +} // namespace ops + +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/trace.cpp b/libnd4j/include/ops/declarable/generic/linalg/trace.cpp index b707605d2639..14dddc1dcc74 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/trace.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/trace.cpp @@ -22,48 +22,55 @@ #if NOT_EXCLUDED(OP_trace) #include -#include +#include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(trace, 1, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() >= 2, 0, "TRACE op: the rank of input array must be >=2, but got %i instead!", input->rankOf()); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - helpers::trace(block.launchContext(), *input, *output); + REQUIRE_TRUE( + input->rankOf() >= 2, 0, + "TRACE op: the rank of input array must be >=2, but got %i instead!", + input->rankOf()); - return Status::OK(); + helpers::trace(block.launchContext(), *input, *output); + + return Status::OK(); } - DECLARE_TYPES(trace) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(trace) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} DECLARE_SHAPE_FN(trace) { - auto inShapeInfo = inputShape->at(0); + auto inShapeInfo = inputShape->at(0); - REQUIRE_TRUE(inShapeInfo[0] >= 2, 0, "TRACE op: the rank of input array must be >=2, but got %i instead!", inShapeInfo[0]); - const int rank = inShapeInfo[0] - 2; + REQUIRE_TRUE( + inShapeInfo[0] >= 2, 0, + "TRACE op: the rank of input array must be >=2, but got %i instead!", + inShapeInfo[0]); + const int rank = inShapeInfo[0] - 2; - Nd4jLong* outShapeInfo(nullptr); - ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); + Nd4jLong* outShapeInfo(nullptr); + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); - outShapeInfo[0] = rank; - for(int i=1; i <= rank; ++i) - outShapeInfo[i] = inShapeInfo[i]; + outShapeInfo[0] = rank; + for (int i = 1; i <= rank; ++i) outShapeInfo[i] = inShapeInfo[i]; - shape::updateStrides(outShapeInfo, shape::order(inShapeInfo)); - auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo, ArrayOptions::dataType(inShapeInfo))); - RELEASE(outShapeInfo, block.workspace()); - return SHAPELIST(result); + shape::updateStrides(outShapeInfo, shape::order(inShapeInfo)); + auto result = ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(outShapeInfo, ArrayOptions::dataType(inShapeInfo))); + RELEASE(outShapeInfo, block.workspace()); + return SHAPELIST(result); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/tri.cpp b/libnd4j/include/ops/declarable/generic/linalg/tri.cpp index c7e1a125bda8..b952109c2266 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/tri.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/tri.cpp @@ -20,44 +20,42 @@ #include - namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(tri, -2, 1, false, 0, 1) { + auto output = OUTPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - const int diag = block.numI() > 2 ? INT_ARG(2) : 0; + const int diag = block.numI() > 2 ? INT_ARG(2) : 0; - BUILD_SINGLE_SELECTOR(output->dataType(), output->fillAsTriangular, (1., diag + 1, 0, *output, 'l'), LIBND4J_TYPES); // fill with unities lower triangular block of matrix - BUILD_SINGLE_SELECTOR(output->dataType(), output->fillAsTriangular, (0., 0, diag, *output, 'u'), LIBND4J_TYPES); // fill with zeros upper triangular block of matrix + BUILD_SINGLE_SELECTOR( + output->dataType(), output->fillAsTriangular, + (1., diag + 1, 0, *output, 'l'), + LIBND4J_TYPES); // fill with unities lower triangular block of matrix + BUILD_SINGLE_SELECTOR( + output->dataType(), output->fillAsTriangular, (0., 0, diag, *output, 'u'), + LIBND4J_TYPES); // fill with zeros upper triangular block of matrix - // output->setValueInDiagMatrix(1., diag, 'l'); - // output->setValueInDiagMatrix(0., diag+1, 'u'); + // output->setValueInDiagMatrix(1., diag, 'l'); + // output->setValueInDiagMatrix(0., diag+1, 'u'); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(tri) { - getOpDescriptor() - ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}); - } - +DECLARE_TYPES(tri) { + getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}); +} DECLARE_SHAPE_FN(tri) { - const int rows = INT_ARG(0); - const int cols = block.numI() > 1 ? INT_ARG(1) : rows; + const int rows = INT_ARG(0); + const int cols = block.numI() > 1 ? INT_ARG(1) : rows; - auto dtype = block.numD() ? D_ARG(0) : DataType::FLOAT32; + auto dtype = block.numD() ? D_ARG(0) : DataType::FLOAT32; - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', {rows, cols})); + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + dtype, 'c', {rows, cols})); } - - - -} -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/triangular_solve.cpp b/libnd4j/include/ops/declarable/generic/linalg/triangular_solve.cpp index c9d23753cf26..19d9aca89a4b 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/triangular_solve.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/triangular_solve.cpp @@ -24,59 +24,71 @@ #include #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(triangular_solve, 2, 1, false, 0, 0) { - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - bool isLower = true; - bool useAdjoint = false; +namespace ops { +CUSTOM_OP_IMPL(triangular_solve, 2, 1, false, 0, 0) { + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + bool isLower = true; + bool useAdjoint = false; - if (block.numB() > 0) { - if (block.numB() > 1) { - isLower = B_ARG(0); - useAdjoint = B_ARG(1); - } - else { - isLower = B_ARG(0); - } - } + if (block.numB() > 0) { + if (block.numB() > 1) { + isLower = B_ARG(0); + useAdjoint = B_ARG(1); + } else { + isLower = B_ARG(0); + } + } - REQUIRE_TRUE(a->rankOf() >=2, 0, "triangular_solve: The rank of input left tensor should not be less than 2, but %i is given", a->rankOf()); - REQUIRE_TRUE(b->rankOf() >=2, 0, "triangular_solve: The rank of input right tensor should not be less than 2, but %i is given", b->rankOf()); + REQUIRE_TRUE(a->rankOf() >= 2, 0, + "triangular_solve: The rank of input left tensor should not be " + "less than 2, but %i is given", + a->rankOf()); + REQUIRE_TRUE(b->rankOf() >= 2, 0, + "triangular_solve: The rank of input right tensor should not be " + "less than 2, but %i is given", + b->rankOf()); - REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, "triangular_solve: The last two dimmensions should be equal, but %i and %i are given", a->sizeAt(-1), a->sizeAt(-2)); - REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, "triangular_solve: The last dimmension of left part should be equal to prelast of right part, but %i and %i are given", a->sizeAt(-1), b->sizeAt(-2)); - auto input = a; - if (useAdjoint) { - auto adjointA = a->ulike(); - helpers::adjointMatrix(block.launchContext(), a, isLower, &adjointA); - input = new NDArray(adjointA); //.detach(); - isLower = !isLower; - }; + REQUIRE_TRUE(a->sizeAt(-1) == a->sizeAt(-2), 0, + "triangular_solve: The last two dimmensions should be equal, " + "but %i and %i are given", + a->sizeAt(-1), a->sizeAt(-2)); + REQUIRE_TRUE(a->sizeAt(-1) == b->sizeAt(-2), 0, + "triangular_solve: The last dimmension of left part should be " + "equal to prelast of right part, but %i and %i are given", + a->sizeAt(-1), b->sizeAt(-2)); + auto input = a; + if (useAdjoint) { + auto adjointA = a->ulike(); + helpers::adjointMatrix(block.launchContext(), a, isLower, &adjointA); + input = new NDArray(adjointA); //.detach(); + isLower = !isLower; + }; - auto res = helpers::triangularSolveFunctor(block.launchContext(), input, b, isLower, useAdjoint, z); - if (input != a) - delete input; + auto res = helpers::triangularSolveFunctor(block.launchContext(), input, b, + isLower, useAdjoint, z); + if (input != a) delete input; - return Status::OK(); - } - - DECLARE_SHAPE_FN(triangular_solve) { - auto in0 = inputShape->at(1); - auto in1 = inputShape->at(1); - auto luShape = ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace()); + return Status::OK(); +} - return SHAPELIST(CONSTANT(luShape)); - } +DECLARE_SHAPE_FN(triangular_solve) { + auto in0 = inputShape->at(1); + auto in1 = inputShape->at(1); + auto luShape = + ShapeBuilders::copyShapeInfoAndType(in1, in0, true, block.workspace()); - DECLARE_TYPES(triangular_solve) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(false); - } - } + return SHAPELIST(CONSTANT(luShape)); +} + +DECLARE_TYPES(triangular_solve) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(false); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/triu.cpp b/libnd4j/include/ops/declarable/generic/linalg/triu.cpp index 18c5ac8ebde8..0347ef8b8c52 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/triu.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/triu.cpp @@ -19,96 +19,106 @@ // #include -#include - +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(triu, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() > 0, 0, "TRIU OP: the rank of input array must be > 0, but got %i instead !", input->rankOf()); + REQUIRE_TRUE( + input->rankOf() > 0, 0, + "TRIU OP: the rank of input array must be > 0, but got %i instead !", + input->rankOf()); - const int diag = block.numI() > 0 ? INT_ARG(0) : 0; + const int diag = block.numI() > 0 ? INT_ARG(0) : 0; - BUILD_SINGLE_SELECTOR(input->dataType(), input->fillAsTriangular, (0, diag, 0, *output, 'l' ), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), input->fillAsTriangular, + (0, diag, 0, *output, 'l'), LIBND4J_TYPES); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(triu) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(triu) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} DECLARE_SHAPE_FN(triu) { + auto inShapeInfo = inputShape->at(0); - auto inShapeInfo = inputShape->at(0); - - REQUIRE_TRUE(inShapeInfo[0] > 0, 0, "TRIU OP: the rank of input array must be > 0, but got %i instead !", inShapeInfo[0]); + REQUIRE_TRUE( + inShapeInfo[0] > 0, 0, + "TRIU OP: the rank of input array must be > 0, but got %i instead !", + inShapeInfo[0]); - int rank = (inShapeInfo[0] == 1) ? 2 : inShapeInfo[0]; + int rank = (inShapeInfo[0] == 1) ? 2 : inShapeInfo[0]; - Nd4jLong *outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); - memcpy(outShapeInfo, inShapeInfo, (1 + rank) * sizeof(Nd4jLong)); // copy rank and dimensions values only + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + memcpy( + outShapeInfo, inShapeInfo, + (1 + rank) * sizeof(Nd4jLong)); // copy rank and dimensions values only - if(inShapeInfo[0] == 1) { - outShapeInfo[0] = rank; - outShapeInfo[1] = inShapeInfo[1]; - outShapeInfo[2] = inShapeInfo[1]; - } + if (inShapeInfo[0] == 1) { + outShapeInfo[0] = rank; + outShapeInfo[1] = inShapeInfo[1]; + outShapeInfo[2] = inShapeInfo[1]; + } - ShapeUtils::updateStridesAndType(outShapeInfo, inShapeInfo, shape::order(inShapeInfo)); + ShapeUtils::updateStridesAndType(outShapeInfo, inShapeInfo, + shape::order(inShapeInfo)); - return SHAPELIST(CONSTANT(outShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(triu_bp, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); // dLoss/dO - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); // dLoss/dO - - auto gradI = OUTPUT_VARIABLE(0); // dLoss/dI + auto gradI = OUTPUT_VARIABLE(0); // dLoss/dI - REQUIRE_TRUE(input->rankOf() > 0, 0, "TRIU_BP OP: the rank of input array must be > 0, but got %i instead !", input->rankOf()); + REQUIRE_TRUE( + input->rankOf() > 0, 0, + "TRIU_BP OP: the rank of input array must be > 0, but got %i instead !", + input->rankOf()); - const int diag = block.numI() > 0 ? INT_ARG(0) : 0; + const int diag = block.numI() > 0 ? INT_ARG(0) : 0; - helpers::triuBP(block.launchContext(), *input, *gradO, *gradI, diag); + helpers::triuBP(block.launchContext(), *input, *gradO, *gradI, diag); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(triu_bp) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(triu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} DECLARE_SHAPE_FN(triu_bp) { + auto gradOShapeInfo = inputShape->at(1); + int rank = gradOShapeInfo[0]; - auto gradOShapeInfo = inputShape->at(1); - int rank = gradOShapeInfo[0]; + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + memcpy( + outShapeInfo, gradOShapeInfo, + (1 + rank) * sizeof(Nd4jLong)); // copy rank and dimensions values only - Nd4jLong* outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); - memcpy(outShapeInfo, gradOShapeInfo, (1 + rank) * sizeof(Nd4jLong)); // copy rank and dimensions values only + auto in = inputShape->at(0); + ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); - auto in = inputShape->at(0); - ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } - -} -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/zeta.cpp b/libnd4j/include/ops/declarable/generic/linalg/zeta.cpp index 6aba1fc5f4d1..4a53faeb48d7 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/zeta.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/zeta.cpp @@ -25,35 +25,41 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(zeta, 2, 1, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto q = INPUT_VARIABLE(1); +namespace ops { +CONFIGURABLE_OP_IMPL(zeta, 2, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto q = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(x->isSameShape(q), 0, "ZETA op: two input arrays must have the same shapes, bot got x=%s and q=%s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(q).c_str()); + REQUIRE_TRUE(x->isSameShape(q), 0, + "ZETA op: two input arrays must have the same shapes, bot got " + "x=%s and q=%s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(q).c_str()); - Nd4jLong arrLen = x->lengthOf(); + Nd4jLong arrLen = x->lengthOf(); - // FIXME: this should NOT be loop. - for(Nd4jLong i = 0; i < arrLen; ++i ) { - REQUIRE_TRUE(x->e(i) > 1.f, 0, "ZETA op: all elements of x array must be > 1 !"); - REQUIRE_TRUE(q->e(i) > 0.f, 0, "ZETA op: all elements of q array must be > 0 !"); - } + // FIXME: this should NOT be loop. + for (Nd4jLong i = 0; i < arrLen; ++i) { + REQUIRE_TRUE(x->e(i) > 1.f, 0, + "ZETA op: all elements of x array must be > 1 !"); + REQUIRE_TRUE(q->e(i) > 0.f, 0, + "ZETA op: all elements of q array must be > 0 !"); + } - helpers::zeta(block.launchContext(), *x, *q, *output); + helpers::zeta(block.launchContext(), *x, *q, *output); - return Status::OK(); - } - DECLARE_SYN(Zeta, zeta); + return Status::OK(); +} +DECLARE_SYN(Zeta, zeta); - DECLARE_TYPES(zeta) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(zeta) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/clone_list.cpp b/libnd4j/include/ops/declarable/generic/list/clone_list.cpp index d100153ec421..468a4566533b 100644 --- a/libnd4j/include/ops/declarable/generic/list/clone_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/clone_list.cpp @@ -24,19 +24,19 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(clone_list, 1, 1, 0, 0) { - auto list = INPUT_LIST(0); +namespace ops { +LIST_OP_IMPL(clone_list, 1, 1, 0, 0) { + auto list = INPUT_LIST(0); - auto newList = list->clone(); + auto newList = list->clone(); - //OVERWRITE_RESULT(newList); - setupResultList(newList, block); - return ND4J_STATUS_OK; - } - DECLARE_SYN(TensorArrayIdentityV3, clone_list); - DECLARE_SYN(tensorarrayidentityv3, clone_list); - } + // OVERWRITE_RESULT(newList); + setupResultList(newList, block); + return ND4J_STATUS_OK; } +DECLARE_SYN(TensorArrayIdentityV3, clone_list); +DECLARE_SYN(tensorarrayidentityv3, clone_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/create_list.cpp b/libnd4j/include/ops/declarable/generic/list/create_list.cpp index fc2d22df8821..d362df25495e 100644 --- a/libnd4j/include/ops/declarable/generic/list/create_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/create_list.cpp @@ -24,40 +24,40 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(create_list, 1, 2, 0, -2) { - int height = 0; - bool expandable = false; - if (block.numI() == 2) { - height = INT_ARG(0); - expandable = (bool) INT_ARG(1); - } else if (block.numI() == 1) { - height = INT_ARG(0); - } else if (block.width() == 1) { - height = INPUT_VARIABLE(0)->e(0); - expandable = true; - } else { - height = 0; - expandable = true; - } - - NDArrayList list(height, expandable); - - // we recieve input array for graph integrity purposes only - auto input = INPUT_VARIABLE(0); - setupResultList(list, block); - - auto scalar = NDArrayFactory::create(list.counter()); - block.pushNDArrayToVariableSpace(block.nodeId(), 1, scalar); - - return Status::OK(); - } - - DECLARE_SYN(TensorArrayV3, create_list); - DECLARE_SYN(tensorarrayv3, create_list); - DECLARE_SYN(TensorArrayCreateV3, create_list); - DECLARE_SYN(tensorarraycreatev3, create_list); - } +namespace ops { +LIST_OP_IMPL(create_list, 1, 2, 0, -2) { + int height = 0; + bool expandable = false; + if (block.numI() == 2) { + height = INT_ARG(0); + expandable = (bool)INT_ARG(1); + } else if (block.numI() == 1) { + height = INT_ARG(0); + } else if (block.width() == 1) { + height = INPUT_VARIABLE(0)->e(0); + expandable = true; + } else { + height = 0; + expandable = true; + } + + NDArrayList list(height, expandable); + + // we recieve input array for graph integrity purposes only + auto input = INPUT_VARIABLE(0); + setupResultList(list, block); + + auto scalar = NDArrayFactory::create(list.counter()); + block.pushNDArrayToVariableSpace(block.nodeId(), 1, scalar); + + return Status::OK(); } +DECLARE_SYN(TensorArrayV3, create_list); +DECLARE_SYN(tensorarrayv3, create_list); +DECLARE_SYN(TensorArrayCreateV3, create_list); +DECLARE_SYN(tensorarraycreatev3, create_list); +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/gather_list.cpp b/libnd4j/include/ops/declarable/generic/list/gather_list.cpp index 341f6347e01a..2d1837ed8a88 100644 --- a/libnd4j/include/ops/declarable/generic/list/gather_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/gather_list.cpp @@ -24,53 +24,59 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(gather_list, 2, 1, 0, -2) { - auto list = INPUT_LIST(0); - auto indices = INPUT_VARIABLE(1); +namespace ops { +LIST_OP_IMPL(gather_list, 2, 1, 0, -2) { + auto list = INPUT_LIST(0); + auto indices = INPUT_VARIABLE(1); - indices->printShapeInfo("indices shape"); - indices->printIndexedBuffer("indices"); + indices->printShapeInfo("indices shape"); + indices->printIndexedBuffer("indices"); - REQUIRE_TRUE(indices->isVector() || indices->rankOf() == 1, 0, "Indices for Gather operation should be a vector"); - REQUIRE_TRUE(list->height() > 0, 0, "Number of elements in list should be positive prior to Gather call"); - REQUIRE_TRUE(list->height() == indices->lengthOf(), 1, "Number of indicies should be equal to number of elements in list, but got [%i] indices instead", indices->lengthOf()); + REQUIRE_TRUE(indices->isVector() || indices->rankOf() == 1, 0, + "Indices for Gather operation should be a vector"); + REQUIRE_TRUE( + list->height() > 0, 0, + "Number of elements in list should be positive prior to Gather call"); + REQUIRE_TRUE(list->height() == indices->lengthOf(), 1, + "Number of indicies should be equal to number of elements in " + "list, but got [%i] indices instead", + indices->lengthOf()); - // first of all we need to get shapes - std::vector shape({0}); - shape[0] = indices->lengthOf(); - for (int e = 0; e < list->height(); e++) { - auto array = list->readRaw(e); + // first of all we need to get shapes + std::vector shape({0}); + shape[0] = indices->lengthOf(); + for (int e = 0; e < list->height(); e++) { + auto array = list->readRaw(e); - // now we should fill other dimensions - if (e == 0) { - for (int d = 0; d < array.rankOf(); d++) - shape.emplace_back(array.sizeAt(d)); - } - } + // now we should fill other dimensions + if (e == 0) { + for (int d = 0; d < array.rankOf(); d++) + shape.emplace_back(array.sizeAt(d)); + } + } - auto result = NDArrayFactory::create('c', shape, list->dataType()); - std::vector indicesList((list->readRaw(0).rankOf() + 1) * 2, 0); - int skipPosition = 0; - for (int e = 0; e < indices->lengthOf(); e++) { - auto idx = indices->e(e); - auto array = list->readRaw(idx); - - // first dimension - indicesList[0] = skipPosition; - indicesList[1] = skipPosition++ + 1; + auto result = NDArrayFactory::create('c', shape, list->dataType()); + std::vector indicesList((list->readRaw(0).rankOf() + 1) * 2, 0); + int skipPosition = 0; + for (int e = 0; e < indices->lengthOf(); e++) { + auto idx = indices->e(e); + auto array = list->readRaw(idx); - auto subarray = (result)(indicesList, true); - subarray.assign(array); - } + // first dimension + indicesList[0] = skipPosition; + indicesList[1] = skipPosition++ + 1; - //OVERWRITE_RESULT(result); - setupResult(result, block); - return Status::OK(); - } - DECLARE_SYN(TensorArrayGatherV3, gather_list); - DECLARE_SYN(tensorarraygatherv3, gather_list); - } + auto subarray = (result)(indicesList, true); + subarray.assign(array); + } + + // OVERWRITE_RESULT(result); + setupResult(result, block); + return Status::OK(); } +DECLARE_SYN(TensorArrayGatherV3, gather_list); +DECLARE_SYN(tensorarraygatherv3, gather_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/pick_list.cpp b/libnd4j/include/ops/declarable/generic/list/pick_list.cpp index 8fde9fdc2495..01eb6eb753d8 100644 --- a/libnd4j/include/ops/declarable/generic/list/pick_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/pick_list.cpp @@ -24,33 +24,36 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(pick_list, 1, 1, 0, -2) { - auto list = INPUT_LIST(0); - - std::vector indices; - if (block.width() > 1 && block.getVariable(1)->getNDArray()->isVector()) { - auto ia = INPUT_VARIABLE(1); - for (int e = 0; e < ia->lengthOf(); e++) - indices.emplace_back(ia->e(e)); - } else if (block.numI() > 0) { - indices = block.getIArguments(); - } else return ND4J_STATUS_BAD_ARGUMENTS; - - for (auto& v: indices) { - if (v >= list->height()) { - nd4j_printf("Requested index [%i] is higher (or equal) then ArrayList height: [%i]", v, - list->height()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } - auto result = list->pick(indices); - -// OVERWRITE_RESULT(result); - setupResult(result, block); - return Status::OK(); - } +namespace ops { +LIST_OP_IMPL(pick_list, 1, 1, 0, -2) { + auto list = INPUT_LIST(0); + + std::vector indices; + if (block.width() > 1 && block.getVariable(1)->getNDArray()->isVector()) { + auto ia = INPUT_VARIABLE(1); + for (int e = 0; e < ia->lengthOf(); e++) + indices.emplace_back(ia->e(e)); + } else if (block.numI() > 0) { + indices = block.getIArguments(); + } else + return ND4J_STATUS_BAD_ARGUMENTS; + + for (auto& v : indices) { + if (v >= list->height()) { + nd4j_printf( + "Requested index [%i] is higher (or equal) then ArrayList height: " + "[%i]", + v, list->height()); + return ND4J_STATUS_BAD_ARGUMENTS; } + } + auto result = list->pick(indices); + + // OVERWRITE_RESULT(result); + setupResult(result, block); + return Status::OK(); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/read_list.cpp b/libnd4j/include/ops/declarable/generic/list/read_list.cpp index 0300bf42b913..5ab63a124043 100644 --- a/libnd4j/include/ops/declarable/generic/list/read_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/read_list.cpp @@ -24,40 +24,47 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(read_list, 1, 1, 0, 0) { - auto list = INPUT_LIST(0); - NDArray result; +namespace ops { +LIST_OP_IMPL(read_list, 1, 1, 0, 0) { + auto list = INPUT_LIST(0); + NDArray result; - REQUIRE_TRUE(list->height() > 0, 0, "ReadList: number of elements in list should be positive prior to Read call"); + REQUIRE_TRUE(list->height() > 0, 0, + "ReadList: number of elements in list should be positive prior " + "to Read call"); - if (block.numI() > 0) { - auto index = INT_ARG(0); + if (block.numI() > 0) { + auto index = INT_ARG(0); - REQUIRE_TRUE(list->isWritten(index), 0, "ReadList: requested index [%i] wasn't written yet", index); + REQUIRE_TRUE(list->isWritten(index), 0, + "ReadList: requested index [%i] wasn't written yet", index); - result = list->read(index); - } else if (block.width() > 0) { - auto vec = INPUT_VARIABLE(1); + result = list->read(index); + } else if (block.width() > 0) { + auto vec = INPUT_VARIABLE(1); - REQUIRE_TRUE(vec->isScalar(), 0, "ReadList: index operand should be a scalar"); - - auto index = vec->e(0); + REQUIRE_TRUE(vec->isScalar(), 0, + "ReadList: index operand should be a scalar"); - REQUIRE_TRUE(list->isWritten(index), 0, "ReadList: requested index [%i] wasn't written yet", index); + auto index = vec->e(0); - result = list->read(index); - } else { - REQUIRE_TRUE(false, 0, "ReadList: index value should be set either via IntArgs or via second operand"); - } + REQUIRE_TRUE(list->isWritten(index), 0, + "ReadList: requested index [%i] wasn't written yet", index); -// OVERWRITE_RESULT(result); - setupResult(result, block); - return Status::OK(); - } - DECLARE_SYN(TensorArrayReadV3, read_list); - DECLARE_SYN(tensorarrayreadv3, read_list); - } + result = list->read(index); + } else { + REQUIRE_TRUE(false, 0, + "ReadList: index value should be set either via IntArgs or " + "via second operand"); + } + + // OVERWRITE_RESULT(result); + setupResult(result, block); + return Status::OK(); } +DECLARE_SYN(TensorArrayReadV3, read_list); +DECLARE_SYN(tensorarrayreadv3, read_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp b/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp index 3b2126b4785f..c9c387d8fee7 100644 --- a/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/scatter_list.cpp @@ -21,57 +21,58 @@ #include #if NOT_EXCLUDED(OP_scatter_list) -#include #include +#include namespace sd { - namespace ops { - LIST_OP_IMPL(scatter_list, 1, 1, 0, -2) { - NDArrayList *list = nullptr; - NDArray *array = nullptr; - NDArray *indices = nullptr; +namespace ops { +LIST_OP_IMPL(scatter_list, 1, 1, 0, -2) { + NDArrayList *list = nullptr; + NDArray *array = nullptr; + NDArray *indices = nullptr; - bool hasList = false; - auto w = block.width(); + bool hasList = false; + auto w = block.width(); - if (w >= 3){ - list = INPUT_LIST(0); - indices = INPUT_VARIABLE(1); - array = INPUT_VARIABLE(2); - hasList = true; - } else { - array = INPUT_VARIABLE(1); - indices = INPUT_VARIABLE(2); - list = new NDArrayList(indices->lengthOf(), false); + if (w >= 3) { + list = INPUT_LIST(0); + indices = INPUT_VARIABLE(1); + array = INPUT_VARIABLE(2); + hasList = true; + } else { + array = INPUT_VARIABLE(1); + indices = INPUT_VARIABLE(2); + list = new NDArrayList(indices->lengthOf(), false); - throw std::runtime_error("scatter_list - Not implemented yet"); - } + throw std::runtime_error("scatter_list - Not implemented yet"); + } - REQUIRE_TRUE(indices->isVector() || indices->rankOf() == 1, 0, "ScatterList: Indices for Scatter should be a vector") - REQUIRE_TRUE(indices->lengthOf() == array->sizeAt(0), 0, "ScatterList: Indices length should be equal number of TADs along dim0, but got %i instead", indices->lengthOf()); + REQUIRE_TRUE(indices->isVector() || indices->rankOf() == 1, 0, + "ScatterList: Indices for Scatter should be a vector") + REQUIRE_TRUE(indices->lengthOf() == array->sizeAt(0), 0, + "ScatterList: Indices length should be equal number of TADs " + "along dim0, but got %i instead", + indices->lengthOf()); - std::vector axis = ShapeUtils::evalDimsToExclude(array->rankOf(), {0}); - auto tads = array->allTensorsAlongDimension( axis); - for (int e = 0; e < tads.size(); e++) { - auto idx = indices->e(e); - if (idx >= tads.size()) - return ND4J_STATUS_BAD_ARGUMENTS; + std::vector axis = ShapeUtils::evalDimsToExclude(array->rankOf(), {0}); + auto tads = array->allTensorsAlongDimension(axis); + for (int e = 0; e < tads.size(); e++) { + auto idx = indices->e(e); + if (idx >= tads.size()) return ND4J_STATUS_BAD_ARGUMENTS; - auto arr = tads.at(e).dup(array->ordering()); - auto res = list->write(idx, arr); + auto arr = tads.at(e).dup(array->ordering()); + auto res = list->write(idx, arr); - if (res != Status::OK()) - return res; - } + if (res != Status::OK()) return res; + } - if (!hasList) - setupResultList(*list, block); + if (!hasList) setupResultList(*list, block); - return Status::OK(); - } - DECLARE_SYN(TensorArrayScatterV3, scatter_list); - DECLARE_SYN(tensorarrayscatterv3, scatter_list); - } + return Status::OK(); } +DECLARE_SYN(TensorArrayScatterV3, scatter_list); +DECLARE_SYN(tensorarrayscatterv3, scatter_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/size_list.cpp b/libnd4j/include/ops/declarable/generic/list/size_list.cpp index 7a4ab7f870d7..a1dd7739f5fd 100644 --- a/libnd4j/include/ops/declarable/generic/list/size_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/size_list.cpp @@ -24,25 +24,26 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(size_list, 1, 1, 0, 0) { - auto list = INPUT_LIST(0); +namespace ops { +LIST_OP_IMPL(size_list, 1, 1, 0, 0) { + auto list = INPUT_LIST(0); - auto result = NDArrayFactory::create(list->height(), block.launchContext()); + auto result = + NDArrayFactory::create(list->height(), block.launchContext()); - //nd4j_printf("List size: [%i]\n", list->height()); - result.printIndexedBuffer("actual height"); + // nd4j_printf("List size: [%i]\n", list->height()); + result.printIndexedBuffer("actual height"); - //nd4j_printf("List size: [%i]\n", list->height()); - result.printIndexedBuffer("actual height"); + // nd4j_printf("List size: [%i]\n", list->height()); + result.printIndexedBuffer("actual height"); - //OVERWRITE_RESULT(result); - setupResult(result, block); - return Status::OK(); - } - DECLARE_SYN(TensorArraySizeV3, size_list); - DECLARE_SYN(tensorarraysizev3, size_list); - } + // OVERWRITE_RESULT(result); + setupResult(result, block); + return Status::OK(); } +DECLARE_SYN(TensorArraySizeV3, size_list); +DECLARE_SYN(tensorarraysizev3, size_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/split_list.cpp b/libnd4j/include/ops/declarable/generic/list/split_list.cpp index 40699443f010..5cd69324e019 100644 --- a/libnd4j/include/ops/declarable/generic/list/split_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/split_list.cpp @@ -21,69 +21,75 @@ #include #if NOT_EXCLUDED(OP_split_list) -#include #include +#include namespace sd { - namespace ops { - LIST_OP_IMPL(split_list, 2, 1, 0, -2) { - NDArrayList *list = nullptr; - NDArray *array = nullptr; - NDArray *sizes = nullptr; - - bool hasList = false; - - if (block.width() >= 3){ - list = INPUT_LIST(0); - array = INPUT_VARIABLE(1); - sizes = INPUT_VARIABLE(2); - hasList = true; - } else { - array = INPUT_VARIABLE(0); - sizes = INPUT_VARIABLE(1); - list = new NDArrayList(sizes->lengthOf(), false); - //block.trackList(list); - - throw std::runtime_error("split_list - Not implemented yet"); - } - - REQUIRE_TRUE(sizes->isZ(), 0, "split_list: sizes array must have one of integer types"); - REQUIRE_TRUE(sizes->rankOf() == 1, 0, "split_list: sizes array must be 1D") - - list->setShape(array->getShapeAsVector()); - - // now let's build subarrays - int cnt = 0; - std::vector indices(2 * array->rankOf(), 0); - for (Nd4jLong e = 0; e < sizes->lengthOf(); e++) { - int c_size = sizes->e(e); - - REQUIRE_TRUE(c_size > 0, 0, "Slice size should have postive value, but got %i instead", c_size); - REQUIRE_TRUE(cnt < array->sizeAt(0) && cnt + c_size <= array->sizeAt(0), 0, "Slices size should NOT be higher then number of TADs of source array. Source size: [%i]; Slice start: [%i]; Slice size: [%i]", array->sizeAt(0), cnt, c_size); - - // we're adding our interval along zeroth dimension - indices[0] = cnt; - indices[1] = cnt + c_size; - cnt += c_size; - - auto subarray = (*array)(indices); - - auto status = list->write(e, subarray.dup(array->ordering())); - - if (status != ND4J_STATUS_OK) - return status; - } - - if (!hasList) { - //OVERWRITE_RESULT(list); - setupResultList(*list, block); - } - - return Status::OK(); - } - DECLARE_SYN(TensorArraySplitV3, split_list); - DECLARE_SYN(tensorarraysplitv3, split_list); - } +namespace ops { +LIST_OP_IMPL(split_list, 2, 1, 0, -2) { + NDArrayList *list = nullptr; + NDArray *array = nullptr; + NDArray *sizes = nullptr; + + bool hasList = false; + + if (block.width() >= 3) { + list = INPUT_LIST(0); + array = INPUT_VARIABLE(1); + sizes = INPUT_VARIABLE(2); + hasList = true; + } else { + array = INPUT_VARIABLE(0); + sizes = INPUT_VARIABLE(1); + list = new NDArrayList(sizes->lengthOf(), false); + // block.trackList(list); + + throw std::runtime_error("split_list - Not implemented yet"); + } + + REQUIRE_TRUE(sizes->isZ(), 0, + "split_list: sizes array must have one of integer types"); + REQUIRE_TRUE(sizes->rankOf() == 1, 0, "split_list: sizes array must be 1D") + + list->setShape(array->getShapeAsVector()); + + // now let's build subarrays + int cnt = 0; + std::vector indices(2 * array->rankOf(), 0); + for (Nd4jLong e = 0; e < sizes->lengthOf(); e++) { + int c_size = sizes->e(e); + + REQUIRE_TRUE(c_size > 0, 0, + "Slice size should have postive value, but got %i instead", + c_size); + REQUIRE_TRUE( + cnt < array->sizeAt(0) && cnt + c_size <= array->sizeAt(0), 0, + "Slices size should NOT be higher then number of TADs of source array. " + "Source size: [%i]; Slice start: [%i]; Slice size: [%i]", + array->sizeAt(0), cnt, c_size); + + // we're adding our interval along zeroth dimension + indices[0] = cnt; + indices[1] = cnt + c_size; + cnt += c_size; + + auto subarray = (*array)(indices); + + auto status = list->write(e, subarray.dup(array->ordering())); + + if (status != ND4J_STATUS_OK) return status; + } + + if (!hasList) { + // OVERWRITE_RESULT(list); + setupResultList(*list, block); + } + + return Status::OK(); } +DECLARE_SYN(TensorArraySplitV3, split_list); +DECLARE_SYN(tensorarraysplitv3, split_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/stack_list.cpp b/libnd4j/include/ops/declarable/generic/list/stack_list.cpp index a0f0f422054d..85d6f99004cd 100644 --- a/libnd4j/include/ops/declarable/generic/list/stack_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/stack_list.cpp @@ -25,21 +25,21 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(stack_list, 1, 1, 0, 0) { - auto list = INPUT_LIST(0); - //auto z = OUTPUT_VARIABLE(0); +namespace ops { +LIST_OP_IMPL(stack_list, 1, 1, 0, 0) { + auto list = INPUT_LIST(0); + // auto z = OUTPUT_VARIABLE(0); - // FIXME: this is obviously bad - auto result = list->stack(); + // FIXME: this is obviously bad + auto result = list->stack(); - //OVERWRITE_RESULT(result); - setupResult(result, block); - return Status::OK(); - } - DECLARE_SYN(TensorArrayConcatV3, stack_list); - DECLARE_SYN(tensorarrayconcatv3, stack_list); - } + // OVERWRITE_RESULT(result); + setupResult(result, block); + return Status::OK(); } +DECLARE_SYN(TensorArrayConcatV3, stack_list); +DECLARE_SYN(tensorarrayconcatv3, stack_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp b/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp index 9a12bb416e5f..b89a16bd3e9c 100644 --- a/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp @@ -25,20 +25,20 @@ namespace sd { namespace ops { - LIST_OP_IMPL(unstack_list, 1, 1, 0, 0) { - auto outputList = INPUT_LIST(0); - auto input = INPUT_VARIABLE(int(outputList != nullptr) ); +LIST_OP_IMPL(unstack_list, 1, 1, 0, 0) { + auto outputList = INPUT_LIST(0); + auto input = INPUT_VARIABLE(int(outputList != nullptr)); - if (outputList == nullptr) { - outputList = new NDArrayList(0, true); - //block.trackList(outputList); - setupResultList(*outputList, block); - } - outputList->unstack(*input, INT_ARG(0)); + if (outputList == nullptr) { + outputList = new NDArrayList(0, true); + // block.trackList(outputList); + setupResultList(*outputList, block); + } + outputList->unstack(*input, INT_ARG(0)); - return Status::OK(); - } -} + return Status::OK(); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/list/write_list.cpp b/libnd4j/include/ops/declarable/generic/list/write_list.cpp index 8952e5453201..e9a72b052d73 100644 --- a/libnd4j/include/ops/declarable/generic/list/write_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/write_list.cpp @@ -24,47 +24,47 @@ #include namespace sd { - namespace ops { - LIST_OP_IMPL(write_list, 2, 1, 0, -2) { - auto list = INPUT_LIST(0); - auto output = OUTPUT_VARIABLE(0); - //output->printShapeInfo("Write_list default output shape"); - // nd4j mode - if (block.width() >= 3) { - auto input = INPUT_VARIABLE(block.width() - 2); - auto idx = INPUT_VARIABLE(block.width() - 1); +namespace ops { +LIST_OP_IMPL(write_list, 2, 1, 0, -2) { + auto list = INPUT_LIST(0); + auto output = OUTPUT_VARIABLE(0); + // output->printShapeInfo("Write_list default output shape"); + // nd4j mode + if (block.width() >= 3) { + auto input = INPUT_VARIABLE(block.width() - 2); + auto idx = INPUT_VARIABLE(block.width() - 1); - REQUIRE_TRUE(idx->isScalar(), 0, "Index should be Scalar"); + REQUIRE_TRUE(idx->isScalar(), 0, "Index should be Scalar"); - //nd4j_printf("Writing [%i]:\n", idx->e(0)); - //input->printShapeInfo("input shape"); - //input->printIndexedBuffer("input buffer"); - Nd4jStatus result = list->write(idx->e(0), input->dup()); + // nd4j_printf("Writing [%i]:\n", idx->e(0)); + // input->printShapeInfo("input shape"); + // input->printIndexedBuffer("input buffer"); + Nd4jStatus result = list->write(idx->e(0), input->dup()); - auto res = NDArrayFactory::create(list->counter(), block.launchContext()); - //res->printShapeInfo("Write_list 2 output shape"); + auto res = NDArrayFactory::create(list->counter(), block.launchContext()); + // res->printShapeInfo("Write_list 2 output shape"); - setupResult(res, block); -// OVERWRITE_RESULT(res); + setupResult(res, block); + // OVERWRITE_RESULT(res); - return result; - } else if (block.numI() == 1) { - auto input = INPUT_VARIABLE(1); - auto idx = INT_ARG(0); + return result; + } else if (block.numI() == 1) { + auto input = INPUT_VARIABLE(1); + auto idx = INT_ARG(0); - Nd4jStatus result = list->write(idx, input->dup()); + Nd4jStatus result = list->write(idx, input->dup()); - auto res = NDArrayFactory::create(list->counter(), block.launchContext()); - //res->printShapeInfo("Write_list 1 output shape"); - //OVERWRITE_RESULT(res); - setupResult(res, block); - return result; - } else - return ND4J_STATUS_BAD_INPUT; - } - DECLARE_SYN(TensorArrayWriteV3, write_list); - DECLARE_SYN(tensorarraywritev3, write_list); - } + auto res = NDArrayFactory::create(list->counter(), block.launchContext()); + // res->printShapeInfo("Write_list 1 output shape"); + // OVERWRITE_RESULT(res); + setupResult(res, block); + return result; + } else + return ND4J_STATUS_BAD_INPUT; } +DECLARE_SYN(TensorArrayWriteV3, write_list); +DECLARE_SYN(tensorarraywritev3, write_list); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp b/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp index d6a6662a19a4..906ad009a594 100644 --- a/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/absoluteDifference.cpp @@ -23,272 +23,369 @@ #include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(absolute_difference_loss, 3, 1, false, 0, 1) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - - // input validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "ABSOLUTE_DIFFERENCE_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - NDArray E = (*predictions - *labels).transform(sd::transform::Abs); - E *= *weightsBroad; - - switch (reductionMode) { - case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(E); - break; - - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + + // input validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should " + "be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + NDArray E = (*predictions - *labels).transform(sd::transform::Abs); + E *= *weightsBroad; + + switch (reductionMode) { + case 0: // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(E); + break; + + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); } DECLARE_TYPES(absolute_difference_loss) { - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } DECLARE_SHAPE_FN(absolute_difference_loss) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; - - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); - else // in this case output has the same shape as labels and predictions - outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(outShapeInfo); + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should " + "be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + Nd4jLong const* outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); + else // in this case output has the same shape as labels and predictions + outShapeInfo = + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(outShapeInfo); } - - - - - - - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "ABSOLUTE_DIFFERENCE_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - NDArray E = *predictions - *labels; - - // dE_i/dp_i = sign(p_i - y_i) - E.applyTransform(sd::transform::Sign, *dLdp); // dE/dp - // dE_i/dy_i = -sign(p_i - y_i) - - E.applyTransform(sd::transform::Abs, E); - - switch (reductionMode) { - - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - - NDArray sum; - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdw = 0.; - } - else { - - *dLdp *= *weightsBroad / sum; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeightsScalar); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - } - break; - } - } - - dLdl->assign(-*dLdp); - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should " + "be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + NDArray E = *predictions - *labels; + + // dE_i/dp_i = sign(p_i - y_i) + E.applyTransform(sd::transform::Sign, *dLdp); // dE/dp + // dE_i/dy_i = -sign(p_i - y_i) + + E.applyTransform(sd::transform::Abs, E); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + + NDArray sum; + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdw = 0.; + } else { + *dLdp *= *weightsBroad / sum; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeightsScalar); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + } + break; + } + } + + dLdl->assign(-*dLdp); + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); } DECLARE_TYPES(absolute_difference_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } DECLARE_SHAPE_FN(absolute_difference_loss_grad) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); - - return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "ABSOLUTE_DIFFERENCE_LOSS OP: shapes of weights and labels arrays should " + "be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), + CONSTANT(dLdlShapeInfo)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp b/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp index ac8214f149ba..514d1bbba434 100644 --- a/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/cosineDistance.cpp @@ -21,328 +21,438 @@ #include #if NOT_EXCLUDED(OP_cosine_distance_loss) -#include #include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - int dim = INT_ARG(1); // axis along which sum will be made - if(dim < 0) - dim += labels->rankOf(); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "COSINE_DISTANCE_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // regard 4 possible reduction modes below - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "COSINE_DISTANCE_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - // input dimension can't be larger than labels/predictions/weights rank - REQUIRE_TRUE(dim < labels->rankOf(), 0, "COSINE_DISTANCE_LOSS OP: input reduction dimension (got %i) must be < labels rank %i!", dim, labels->rankOf()); - - if(!output->isScalar()) { - // weights array can be single scalar or has the same shape as output, and must be broadcastable to output shape - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == output->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", weights->rankOf(), output->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *output), 0, "COSINE_DISTANCE_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + int dim = INT_ARG(1); // axis along which sum will be made + if (dim < 0) dim += labels->rankOf(); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "COSINE_DISTANCE_LOSS OP: labels and predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // regard 4 possible reduction modes below + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "COSINE_DISTANCE_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + // input dimension can't be larger than labels/predictions/weights rank + REQUIRE_TRUE(dim < labels->rankOf(), 0, + "COSINE_DISTANCE_LOSS OP: input reduction dimension (got %i) " + "must be < labels rank %i!", + dim, labels->rankOf()); + + if (!output->isScalar()) { + // weights array can be single scalar or has the same shape as output, and + // must be broadcastable to output shape + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == output->rankOf(), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have " + "the same rank as output array, but got %i and %i correspondingly!", + weights->rankOf(), output->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *output), + 0, + "COSINE_DISTANCE_LOSS OP: shapes of weights and output arrays should " + "be broadcastable, but got weights = %s and output = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + } + + NDArray E = + 1. - + (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true); + + // perform weights broadcasting/tile to E if it is necessary + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(&E)) + weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); + + // multiply E on weights + E *= (*weightsBroad); + + switch (reductionMode) { + case 0: // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(&E); + break; + + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + output->assign(E.reduceNumber(reduce::Sum)); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + if (weights->isScalar()) + sum = *weights * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + E.reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + + break; + } + } + + STORE_RESULT(*output); + + if (weightsBroad != weights) delete weightsBroad; - NDArray E = 1. - (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true); - - // perform weights broadcasting/tile to E if it is necessary - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(&E)) - weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); - - // multiply E on weights - E *= (*weightsBroad); - - switch (reductionMode) { - case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(&E); - break; - - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - output->assign(E.reduceNumber(reduce::Sum)); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - if (weights->isScalar()) - sum = *weights * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = E.reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - - break; - } - } - - - STORE_RESULT(*output); - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(cosine_distance_loss) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(cosine_distance_loss) { + // labels and predictions must have the same shapes + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + int dim = INT_ARG(1); + if (dim < 0) dim += labelsShapeInfo[0]; + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "COSINE_DISTANCE_LOSS OP: labels and predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // input dimension can't be larger than labels/predictions/weights rank + REQUIRE_TRUE(dim < labelsShapeInfo[0], 0, + "COSINE_DISTANCE_LOSS OP: input reduction dimension (got %i) " + "must be < labels rank %i!", + dim, labelsShapeInfo[0]); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + // evaluate output shapeInfo + Nd4jLong const* outShapeInfo = nullptr; + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); + else { // in this case output has the same shape as labels reduced by dim + // axis - // labels and predictions must have the same shapes - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - int dim = INT_ARG(1); - if(dim < 0) - dim += labelsShapeInfo[0]; - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "COSINE_DISTANCE_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // input dimension can't be larger than labels/predictions/weights rank - REQUIRE_TRUE(dim < labelsShapeInfo[0], 0, "COSINE_DISTANCE_LOSS OP: input reduction dimension (got %i) must be < labels rank %i!", dim, labelsShapeInfo[0]); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - // evaluate output shapeInfo - Nd4jLong const* outShapeInfo = nullptr; - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); - else { // in this case output has the same shape as labels reduced by dim axis - - std::vector dimensions = {dim}; - outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, outType, true, false, block.workspace()); - - // weights array can be single scalar or has the same rank as output, and must be broadcastable to output - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), 0, "COSINE_DISTANCE_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(outShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, outShapeInfo), 0, "COSINE_DISTANCE_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(outShapeInfo).c_str()); - } - - return SHAPELIST(outShapeInfo); + std::vector dimensions = {dim}; + outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, + outType, true, false, block.workspace()); + + // weights array can be single scalar or has the same rank as output, and + // must be broadcastable to output + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), + 0, + "COSINE_DISTANCE_LOSS OP: weights array should be scalar or have the " + "same rank as output array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(outShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, outShapeInfo), + 0, + "COSINE_DISTANCE_LOSS OP: shapes of weights and output arrays should " + "be broadcastable, but got weights = %s and output = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(outShapeInfo).c_str()); + } + + return SHAPELIST(outShapeInfo); } - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) { + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + int dim = INT_ARG(1); // axis along which sum will be made + if (dim < 0) dim += labels->rankOf(); + + std::vector dimensions = {dim}; + + // input validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "COSINE_DISTANCE_LOSS_GRAD OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "COSINE_DISTANCE_LOSS_GRAD OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo( + predictions->ordering(), dimensions, predictions->shapeInfo(), true, + false, block.workspace()); + // weights array can be single scalar or has the same shape as loss, and must + // be broadcastable to loss shape + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == shape::rank(lossShapeInfo), 0, + "COSINE_DISTANCE_LOSS_GRAD OP: weights array should be scalar or have " + "the same rank as loss array, but got %i and %i correspondingly!", + weights->rankOf(), shape::rank(lossShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || ShapeUtils::areShapesBroadcastable( + weights->shapeInfo(), lossShapeInfo), + 0, + "COSINE_DISTANCE_LOSS_GRAD OP: shapes of weights and loss arrays should " + "be broadcastable, but got weights = %s and loss = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(lossShapeInfo).c_str()); + // input dimension can't be larger than labels/predictions/weights rank + REQUIRE_TRUE(dim < labels->rankOf(), 0, + "COSINE_DISTANCE_LOSS_GRAD OP: input reduction dimension (got " + "%i) must be < labels rank %i!", + dim, labels->rankOf()); + + NDArray E = + 1. - + (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true); + + // perform weights broadcasting/tile to E if it is necessary + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(&E)) + weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); + + dLdp->assign(-*labels); + dLdl->assign(-*predictions); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + *dLdl *= *weightsBroad; + + if (weights->isScalar() || weights->lengthOf() == 1) { + dLdw->assign(E.reduceNumber(reduce::Sum)); + } else { + if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + } + + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + NDArray temp = *weightsBroad / sum; + *dLdp *= temp; + *dLdl *= temp; + + if (weights->isScalar() || weights->lengthOf() == 1) { + *dLdw = 0.; + } else { + if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis( + weights->shapeInfo(), weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, + true, false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + NDArray temp = *weightsBroad / numOfNonZeroWeights; + *dLdp *= temp; + *dLdl *= temp; + + if (weights->isScalar() || weights->lengthOf() == 1) { + dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); + } else { + if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis( + weights->shapeInfo(), weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeights; + } else + dLdw->assign(E / numOfNonZeroWeights); + } + } + break; + } + } - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - int dim = INT_ARG(1); // axis along which sum will be made - if(dim < 0) - dim += labels->rankOf(); - - std::vector dimensions = {dim}; + if (weightsBroad != weights) delete weightsBroad; - // input validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "COSINE_DISTANCE_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "COSINE_DISTANCE_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(predictions->ordering(), dimensions, predictions->shapeInfo(), true, false, block.workspace()); - // weights array can be single scalar or has the same shape as loss, and must be broadcastable to loss shape - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == shape::rank(lossShapeInfo), 0, "COSINE_DISTANCE_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", weights->rankOf(), shape::rank(lossShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(weights->shapeInfo(), lossShapeInfo), 0, "COSINE_DISTANCE_LOSS_GRAD OP: shapes of weights and loss arrays should be broadcastable, but got weights = %s and loss = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(lossShapeInfo).c_str()); - // input dimension can't be larger than labels/predictions/weights rank - REQUIRE_TRUE(dim < labels->rankOf(), 0, "COSINE_DISTANCE_LOSS_GRAD OP: input reduction dimension (got %i) must be < labels rank %i!", dim, labels->rankOf()); - - NDArray E = 1. - (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true); - - // perform weights broadcasting/tile to E if it is necessary - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(&E)) - weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); - - dLdp->assign(-*labels); - dLdl->assign(-*predictions); - - switch (reductionMode) { - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - *dLdl *= *weightsBroad; - - if(weights->isScalar() || weights->lengthOf() == 1) { - dLdw->assign(E.reduceNumber(reduce::Sum)); - } - else { - if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - } - - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - NDArray temp = *weightsBroad / sum; - *dLdp *= temp; - *dLdl *= temp; - - if(weights->isScalar() || weights->lengthOf() == 1) { - *dLdw = 0.; - } - else { - - if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - NDArray temp = *weightsBroad / numOfNonZeroWeights; - *dLdp *= temp; - *dLdl *= temp; - - if(weights->isScalar() || weights->lengthOf() == 1) { - dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); - } - else { - - if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeights; - } - else - dLdw->assign(E / numOfNonZeroWeights); - } - } - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(cosine_distance_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(cosine_distance_loss_grad) { - - /// labels and predictions must have the same shapes - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - int dim = INT_ARG(1); - if(dim < 0) - dim += labelsShapeInfo[0]; - - std::vector dimensions = {dim}; - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "COSINE_DISTANCE_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, true, false, block.workspace()); - // weights array can be single scalar or has the same rank as loss, and must be broadcastable to loss - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(lossShapeInfo), 0, "COSINE_DISTANCE_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(lossShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, lossShapeInfo), 0, "COSINE_DISTANCE_LOSS_GRAD OP: shapes of weights and loss arrays should be broadcastable, but got weights = %s and loss = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(lossShapeInfo).c_str()); - // input dimension can't be larger than labels/predictions/weights rank - REQUIRE_TRUE(dim < labelsShapeInfo[0], 0, "COSINE_DISTANCE_LOSS_GRAD OP: input reduction dimension (got %i) must be < labels rank %i!", dim, labelsShapeInfo[0]); - - auto outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); - - return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); + /// labels and predictions must have the same shapes + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + int dim = INT_ARG(1); + if (dim < 0) dim += labelsShapeInfo[0]; + + std::vector dimensions = {dim}; + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "COSINE_DISTANCE_LOSS_GRAD OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, + true, false, block.workspace()); + // weights array can be single scalar or has the same rank as loss, and must + // be broadcastable to loss + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(lossShapeInfo), + 0, + "COSINE_DISTANCE_LOSS_GRAD OP: weights array should be scalar or have " + "the same rank as loss array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(lossShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, lossShapeInfo), + 0, + "COSINE_DISTANCE_LOSS_GRAD OP: shapes of weights and loss arrays should " + "be broadcastable, but got weights = %s and loss = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(lossShapeInfo).c_str()); + // input dimension can't be larger than labels/predictions/weights rank + REQUIRE_TRUE(dim < labelsShapeInfo[0], 0, + "COSINE_DISTANCE_LOSS_GRAD OP: input reduction dimension (got " + "%i) must be < labels rank %i!", + dim, labelsShapeInfo[0]); + + auto outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), + CONSTANT(dLdlShapeInfo)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp index de1635384c71..81bd37c5400d 100644 --- a/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/hingeLoss.cpp @@ -24,281 +24,378 @@ #include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(hinge_loss, 3, 1, false, 0, 1) { - auto logits = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - - // input validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "HINGE_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "HINGE_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "HINGE_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "HINGE_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to logits if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(logits)) - weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); - - // We first need to convert binary labels to -1/1 labels (as floats) - NDArray E = 1.f - (*labels * 2.f - 1.f) * (*logits); - E.applyScalar(scalar::RELU, 0.0f, E); - - // multiply E on weights - E *= *weightsBroad; - - switch (reductionMode) { - case 0: { // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(E); - break; - } - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = *weights * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); - } +CUSTOM_OP_IMPL(hinge_loss, 3, 1, false, 0, 1) { + auto logits = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + + // input validation + REQUIRE_TRUE(labels->isSameShape(logits), 0, + "HINGE_LOSS OP: labels and logits arrays must have the same " + "shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "HINGE_LOSS OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE(weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "HINGE_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "HINGE_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to logits if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(logits)) + weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); + + // We first need to convert binary labels to -1/1 labels (as floats) + NDArray E = 1.f - (*labels * 2.f - 1.f) * (*logits); + E.applyScalar(scalar::RELU, 0.0f, E); + + // multiply E on weights + E *= *weightsBroad; + + switch (reductionMode) { + case 0: { // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(E); + break; + } + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = *weights * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); + break; + } + } -////////////////////////////////////////////////////////////////////////// - DECLARE_TYPES(hinge_loss) { + if (weightsBroad != weights) delete weightsBroad; - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); - } + return Status::OK(); +} ////////////////////////////////////////////////////////////////////////// - DECLARE_SHAPE_FN(hinge_loss) { - - auto logitsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "HINGE_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "HINGE_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "HINGE_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; - - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); - else // in this case output has the same shape as labels and predictions - outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(outShapeInfo); +DECLARE_TYPES(hinge_loss) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - } +////////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(hinge_loss) { + auto logitsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, + "HINGE_LOSS OP: labels and predictions arrays must have the " + "same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "HINGE_LOSS OP: weights array should be scalar or have the same rank as " + "labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "HINGE_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + Nd4jLong const *outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); + else // in this case output has the same shape as labels and predictions + outShapeInfo = + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(outShapeInfo); +} +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(hinge_loss_grad, 3, 3, false, 0, 1) { + auto logits = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(logits), 0, + "HINGE_LOSS_GRAD OP: labels and logits arrays must have the " + "same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "HINGE_LOSS_GRAD OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "HINGE_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "HINGE_LOSS_GRAD OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(logits)) + weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); + + // We first need to convert binary labels to -1/1 labels (as floats) + NDArray z = (*labels * 2.f - 1.f); + + NDArray E = 1.f - z * (*logits); + E.applyScalar(scalar::RELU, 0.0f, E); + // turn E into gradient mask + + NDArray gradientMask(E.shapeInfo(), block.workspace()); + E.applyTransform(sd::transform::Sign, gradientMask); + + dLdp->assign(-z * gradientMask); + dLdl->assign(-2.f * (*logits) * gradientMask); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + *dLdl *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + *dLdp *= *weightsBroad / sum; + *dLdl *= *weightsBroad / sum; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeightsScalar); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + *dLdl *= temp; + } + break; + } + } + if (weightsBroad != weights) delete weightsBroad; + return Status::OK(); +} -////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(hinge_loss_grad, 3, 3, false, 0, 1) { - - auto logits = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "HINGE_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "HINGE_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "HINGE_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "HINGE_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(logits)) - weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); - - // We first need to convert binary labels to -1/1 labels (as floats) - NDArray z = (*labels * 2.f - 1.f); - - NDArray E = 1.f - z * (*logits); - E.applyScalar(scalar::RELU, 0.0f, E); - // turn E into gradient mask - - NDArray gradientMask(E.shapeInfo(), block.workspace()); - E.applyTransform(sd::transform::Sign, gradientMask); - - dLdp->assign(-z * gradientMask); - dLdl->assign(-2.f * (*logits) * gradientMask); - - switch (reductionMode) { - - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - *dLdl *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - *dLdp *= *weightsBroad / sum; - *dLdl *= *weightsBroad / sum; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeightsScalar); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - *dLdl *= temp; - } - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); - } - - DECLARE_TYPES(hinge_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(hinge_loss_grad) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "HINGE_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "HINGE_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "HINGE_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); - Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); - Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); - - return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); - } +DECLARE_TYPES(hinge_loss_grad) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - } +DECLARE_SHAPE_FN(hinge_loss_grad) { + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "HINGE_LOSS_GRAD OP: labels and predictions arrays must have " + "the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "HINGE_LOSS_GRAD OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "HINGE_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp index f78fec1c217b..3fc91395888d 100644 --- a/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/huberLoss.cpp @@ -24,297 +24,395 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) { - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // FIXME: double? - double delta = T_ARG(0); - - // input validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "HUBER_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "HUBER_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "HUBER_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "HUBER_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to predictions if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - auto error = *predictions - *labels; - error.applyTransform(transform::Abs, error); - NDArray quadratic(error.shapeInfo(), block.workspace()); - error.applyScalar(scalar::MinPairwise, delta, quadratic); - - NDArray E = quadratic * quadratic * 0.5f + (error - quadratic)*delta; - - // multiply E on weights - E *= *weightsBroad; - - switch (reductionMode) { - case 0: { // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(E); - break; - } - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = *weights * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // FIXME: double? + double delta = T_ARG(0); + + // input validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "HUBER_LOSS OP: labels and predictions arrays must have the " + "same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "HUBER_LOSS OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE(weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "HUBER_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "HUBER_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to predictions if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + auto error = *predictions - *labels; + error.applyTransform(transform::Abs, error); + NDArray quadratic(error.shapeInfo(), block.workspace()); + error.applyScalar(scalar::MinPairwise, delta, quadratic); + + NDArray E = quadratic * quadratic * 0.5f + (error - quadratic) * delta; + + // multiply E on weights + E *= *weightsBroad; + + switch (reductionMode) { + case 0: { // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(E); + break; + } + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = *weights * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(huber_loss) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(huber_loss) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "HUBER_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "HUBER_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "HUBER_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; - - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); - else // in this case output has the same shape as labels and predictions - outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(outShapeInfo); + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "HUBER_LOSS OP: labels and predictions arrays must have the " + "same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "HUBER_LOSS OP: weights array should be scalar or have the same rank as " + "labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "HUBER_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + Nd4jLong const* outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); + else // in this case output has the same shape as labels and predictions + outShapeInfo = + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(outShapeInfo); } ////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(huber_loss_grad, 3, 3, false, 1, 1) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - auto delta = T_ARG(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "HUBER_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "HUBER_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "HUBER_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "HUBER_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - NDArray diff = *predictions - *labels; - NDArray absDiff(diff); - absDiff.applyTransform(transform::Abs, absDiff); - NDArray quadratic(absDiff); - absDiff.applyScalar(scalar::MinPairwise, delta, quadratic); - - NDArray E = quadratic * quadratic * 0.5f + (absDiff - quadratic)*delta; - - NDArray lteMask(diff.shapeInfo(), BOOL, true, block.launchContext()); - absDiff.applyScalar(scalar::LessThanOrEqual, delta, lteMask); - - NDArray gtMask(diff.shapeInfo(), BOOL, true, block.launchContext()); - absDiff.applyScalar(scalar::GreaterThan, delta, gtMask); - - NDArray signDiff(diff); - diff.applyTransform(transform::Sign, signDiff); - - - auto gtMaskFloat = gtMask.cast(diff.dataType()); - auto lteMaskFloat = lteMask.cast(diff.dataType()); - - - dLdp->assign( lteMaskFloat * diff + gtMaskFloat * delta * signDiff); - dLdl->assign(-lteMaskFloat * diff - gtMaskFloat * delta * signDiff); - - switch (reductionMode) { - - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - *dLdl *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - *dLdp *= *weightsBroad / sum; - *dLdl *= *weightsBroad / sum; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeightsScalar); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - *dLdl *= temp; - } - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); - } - - DECLARE_TYPES(huber_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(huber_loss_grad) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "HUBER_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "HUBER_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "HUBER_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); - - return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); - } - +CUSTOM_OP_IMPL(huber_loss_grad, 3, 3, false, 1, 1) { + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + auto delta = T_ARG(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "HUBER_LOSS_GRAD OP: labels and predictions arrays must have " + "the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "HUBER_LOSS_GRAD OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "HUBER_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "HUBER_LOSS_GRAD OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + NDArray diff = *predictions - *labels; + NDArray absDiff(diff); + absDiff.applyTransform(transform::Abs, absDiff); + NDArray quadratic(absDiff); + absDiff.applyScalar(scalar::MinPairwise, delta, quadratic); + + NDArray E = quadratic * quadratic * 0.5f + (absDiff - quadratic) * delta; + + NDArray lteMask(diff.shapeInfo(), BOOL, true, block.launchContext()); + absDiff.applyScalar(scalar::LessThanOrEqual, delta, lteMask); + + NDArray gtMask(diff.shapeInfo(), BOOL, true, block.launchContext()); + absDiff.applyScalar(scalar::GreaterThan, delta, gtMask); + + NDArray signDiff(diff); + diff.applyTransform(transform::Sign, signDiff); + + auto gtMaskFloat = gtMask.cast(diff.dataType()); + auto lteMaskFloat = lteMask.cast(diff.dataType()); + + dLdp->assign(lteMaskFloat * diff + gtMaskFloat * delta * signDiff); + dLdl->assign(-lteMaskFloat * diff - gtMaskFloat * delta * signDiff); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + *dLdl *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + *dLdp *= *weightsBroad / sum; + *dLdl *= *weightsBroad / sum; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeightsScalar); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + *dLdl *= temp; + } + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); +} +DECLARE_TYPES(huber_loss_grad) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } + +DECLARE_SHAPE_FN(huber_loss_grad) { + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "HUBER_LOSS_GRAD OP: labels and predictions arrays must have " + "the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "HUBER_LOSS_GRAD OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "HUBER_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/l2_loss.cpp b/libnd4j/include/ops/declarable/generic/loss/l2_loss.cpp index 3afeea2bacd0..be24d2b56e08 100644 --- a/libnd4j/include/ops/declarable/generic/loss/l2_loss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/l2_loss.cpp @@ -24,29 +24,30 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(l2_loss, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); - - // FIXME: output should be used directly here, to avoid sum - input->reduceNumber(reduce::SquaredNorm, *output); - (*output) /= 2.; - - return Status::OK(); - } - DECLARE_SHAPE_FN(l2_loss) { - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0)))); - } - - DECLARE_TYPES(l2_loss) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +namespace ops { +CUSTOM_OP_IMPL(l2_loss, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); + + // FIXME: output should be used directly here, to avoid sum + input->reduceNumber(reduce::SquaredNorm, *output); + (*output) /= 2.; + + return Status::OK(); +} +DECLARE_SHAPE_FN(l2_loss) { + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo( + ArrayOptions::dataType(inputShape->at(0)))); +} + +DECLARE_TYPES(l2_loss) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp b/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp index a9d77a8dad04..ec7e63124e13 100644 --- a/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/logLoss.cpp @@ -24,290 +24,390 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // FIXME: double? - double epsilon = T_ARG(0); - - // input validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "LOG_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "LOG_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "LOG_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "LOG_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to predictions if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - NDArray E = -(*labels)*((*predictions + epsilon).transform(transform::Log)) - (1. - *labels)*(((1. + epsilon) - *predictions).transform(transform::Log)); - - // multiply E on weights - E *= *weightsBroad; - - switch (reductionMode) { - case 0: { // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(E); - break; - } - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = *weights * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // FIXME: double? + double epsilon = T_ARG(0); + + // input validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "LOG_LOSS OP: labels and predictions arrays must have the same " + "shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "LOG_LOSS OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE(weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "LOG_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "LOG_LOSS OP: reduction mode value is not acceptable, possible " + "values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to predictions if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + NDArray E = + -(*labels) * ((*predictions + epsilon).transform(transform::Log)) - + (1. - *labels) * + (((1. + epsilon) - *predictions).transform(transform::Log)); + + // multiply E on weights + E *= *weightsBroad; + + switch (reductionMode) { + case 0: { // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(E); + break; + } + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = *weights * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(log_loss) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(log_loss) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "LOG_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "LOG_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; - - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); - else // in this case output has the same shape as labels and predictions - outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(outShapeInfo); + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "LOG_LOSS OP: labels and predictions arrays must have the same " + "shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "LOG_LOSS OP: weights array should be scalar or have the same rank as " + "labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "LOG_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + Nd4jLong const* outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); + else // in this case output has the same shape as labels and predictions + outShapeInfo = + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(outShapeInfo); } - - - - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - // FIXME: double? - double epsilon = T_ARG(0); - - // input validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "LOG_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "LOG_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "LOG_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "LOG_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - NDArray predictPlusEps = *predictions + epsilon; - NDArray oneMinusLabels = 1. - *labels; - NDArray onePlusEpsMinusPredict = (1. + epsilon) - *predictions; - - // dE_i/dp_i = (1-y_i)/(1-p_i+eps) - y_i/(p_i+eps) - dLdp->assign(oneMinusLabels / onePlusEpsMinusPredict - *labels / predictPlusEps); // dE/dp - // dE_i/dy_i = log((1+2eps)/(p_i+eps) - 1) - ((1. + 2. * epsilon) / predictPlusEps - 1.).applyTransform(transform::Log, *dLdl); // dE/dy - - NDArray E = -(*labels) * predictPlusEps.transform(transform::Log) - oneMinusLabels * onePlusEpsMinusPredict.transform(transform::Log); - - // process 3 possible reduction modes below - switch (reductionMode) { - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - *dLdl *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - NDArray temp = *weightsBroad / sum; - *dLdp *= temp; - *dLdl *= temp; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeightsScalar); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - *dLdl *= temp; - } - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + // FIXME: double? + double epsilon = T_ARG(0); + + // input validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "LOG_LOSS_GRAD OP: labels and predictions arrays must have the " + "same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "LOG_LOSS_GRAD OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "LOG_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "LOG_LOSS_GRAD OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + NDArray predictPlusEps = *predictions + epsilon; + NDArray oneMinusLabels = 1. - *labels; + NDArray onePlusEpsMinusPredict = (1. + epsilon) - *predictions; + + // dE_i/dp_i = (1-y_i)/(1-p_i+eps) - y_i/(p_i+eps) + dLdp->assign(oneMinusLabels / onePlusEpsMinusPredict - + *labels / predictPlusEps); // dE/dp + // dE_i/dy_i = log((1+2eps)/(p_i+eps) - 1) + ((1. + 2. * epsilon) / predictPlusEps - 1.) + .applyTransform(transform::Log, *dLdl); // dE/dy + + NDArray E = -(*labels) * predictPlusEps.transform(transform::Log) - + oneMinusLabels * onePlusEpsMinusPredict.transform(transform::Log); + + // process 3 possible reduction modes below + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + *dLdl *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + NDArray temp = *weightsBroad / sum; + *dLdp *= temp; + *dLdl *= temp; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeightsScalar); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + *dLdl *= temp; + } + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(log_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(log_loss_grad) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "LOG_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "LOG_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); - - return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "LOG_LOSS_GRAD OP: labels and predictions arrays must have the " + "same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "LOG_LOSS_GRAD OP: weights array should be scalar or have the same rank " + "as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "LOG_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), + CONSTANT(dLdlShapeInfo)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp b/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp index 7508030f88f5..13b321280d3d 100644 --- a/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/log_poisson_loss.cpp @@ -25,287 +25,394 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(log_poisson_loss, 3, 1, true, 0, 1) { - auto log_predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - - bool computeFullLoss = false; - if (block.numI() > 1) - computeFullLoss = INT_ARG(1) != 0; - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(log_predictions), 0, "LOG_POISSON_LOSS OP: labels and log_predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(log_predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "LOG_POISSON_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "LOG_POISSON_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "LOG_POISSON_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(log_predictions)) - weightsBroad = new NDArray(weights->tileToShape(log_predictions->shapeInfo())); - - - NDArray E(labels->shapeInfo(), block.workspace()); - if (computeFullLoss) - labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, *log_predictions, E); - else - labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, E); - - - // multiply E on weights - E *= *weightsBroad; - - switch (reductionMode) { - case 0: { // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(E); - break; - } - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = *weights * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); +CUSTOM_OP_IMPL(log_poisson_loss, 3, 1, true, 0, 1) { + auto log_predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + + bool computeFullLoss = false; + if (block.numI() > 1) computeFullLoss = INT_ARG(1) != 0; + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(log_predictions), 0, + "LOG_POISSON_LOSS OP: labels and log_predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(log_predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "LOG_POISSON_LOSS OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "LOG_POISSON_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "LOG_POISSON_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(log_predictions)) + weightsBroad = + new NDArray(weights->tileToShape(log_predictions->shapeInfo())); + + NDArray E(labels->shapeInfo(), block.workspace()); + if (computeFullLoss) + labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, + *log_predictions, E); + else + labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, + E); + + // multiply E on weights + E *= *weightsBroad; + + switch (reductionMode) { + case 0: { // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(E); + break; } - - ////////////////////////////////////////////////////////////////////////// - DECLARE_TYPES(log_poisson_loss) { - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = *weights * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + break; + } + } - ////////////////////////////////////////////////////////////////////////// - DECLARE_SHAPE_FN(log_poisson_loss) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); + if (weightsBroad != weights) delete weightsBroad; - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_POISSON_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "LOG_POISSON_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "LOG_POISSON_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + return Status::OK(); +} - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; +////////////////////////////////////////////////////////////////////////// +DECLARE_TYPES(log_poisson_loss) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); - else // in this case output has the same shape as labels and predictions - outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(labelsShapeInfo, outType)); +////////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(log_poisson_loss) { + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "LOG_POISSON_LOSS OP: labels and predictions arrays must have " + "the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "LOG_POISSON_LOSS OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "LOG_POISSON_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + Nd4jLong const* outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); + else // in this case output has the same shape as labels and predictions + outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(labelsShapeInfo, outType)); + + return SHAPELIST(outShapeInfo); +} - return SHAPELIST(outShapeInfo); +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(log_poisson_loss_grad, 3, 3, false, 0, 1) { + auto log_predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + bool computeFullLoss = false; + if (block.numI() > 1) computeFullLoss = INT_ARG(1) != 0; + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(log_predictions), 0, + "LOG_POISSON_LOSS_GRAD OP: labels and log_predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(log_predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "LOG_POISSON_LOSS_GRAD OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "LOG_POISSON_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "LOG_POISSON_LOSS_GRAD OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(log_predictions)) + weightsBroad = + new NDArray(weights->tileToShape(log_predictions->shapeInfo())); + + NDArray E(labels->shapeInfo(), block.workspace()); + if (computeFullLoss) { + labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, + *log_predictions, E); + + NDArray rDiv(labels->shapeInfo(), block.workspace()); + labels->applyScalar(scalar::ReverseDivide, 0.5f, rDiv); + dLdl->assign(rDiv + labels->transform(transform::Log) + + -(*log_predictions)); + } else { + labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, + E); + + dLdl->assign(-(*log_predictions)); + } + + dLdp->assign(log_predictions->transform(transform::Exp) - (*labels)); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + *dLdl *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + break; } - - ////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(log_poisson_loss_grad, 3, 3, false, 0, 1) { - - auto log_predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - bool computeFullLoss = false; - if (block.numI() > 1) - computeFullLoss = INT_ARG(1) != 0; - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(log_predictions), 0, "LOG_POISSON_LOSS_GRAD OP: labels and log_predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(log_predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "LOG_POISSON_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "LOG_POISSON_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "LOG_POISSON_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(log_predictions)) - weightsBroad = new NDArray(weights->tileToShape(log_predictions->shapeInfo())); - - - NDArray E(labels->shapeInfo(), block.workspace()); - if (computeFullLoss) { - labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, *log_predictions, E); - - NDArray rDiv(labels->shapeInfo(), block.workspace()); - labels->applyScalar(scalar::ReverseDivide, 0.5f, rDiv); - dLdl->assign(rDiv + labels->transform(transform::Log) + -(*log_predictions)); - } else { - labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, E); - - dLdl->assign(-(*log_predictions)); - } - - dLdp->assign(log_predictions->transform(transform::Exp) - (*labels)); - - switch (reductionMode) { - - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - *dLdl *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - *dLdp *= *weightsBroad / sum; - *dLdl *= *weightsBroad / sum; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeightsScalar); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - *dLdl *= temp; - } - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + *dLdp *= *weightsBroad / sum; + *dLdl *= *weightsBroad / sum; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; } - - DECLARE_TYPES(log_poisson_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeightsScalar); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + *dLdl *= temp; + } + break; } + } - DECLARE_SHAPE_FN(log_poisson_loss_grad) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_POISSON_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "LOG_POISSON_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "LOG_POISSON_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + if (weightsBroad != weights) delete weightsBroad; - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); + return Status::OK(); +} - return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); - } +DECLARE_TYPES(log_poisson_loss_grad) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } + +DECLARE_SHAPE_FN(log_poisson_loss_grad) { + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "LOG_POISSON_LOSS_GRAD OP: labels and predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "LOG_POISSON_LOSS_GRAD OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "LOG_POISSON_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp b/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp index 94592a3652a4..97c23501482e 100644 --- a/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/meanPairWsSqErr.cpp @@ -26,357 +26,471 @@ #if NOT_EXCLUDED(OP_mean_pairwssqerr_loss) #include -#include + #include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(mean_pairwssqerr_loss, 3, 1, false, 0, 1) { - /* - * Implementation of mean pairwise squared error loss - * - * For context on where this loss function may be useful see: - * - * Wei, Z., Zhang, J., Shen, X., Lin, Z., Mech, R., Hoai, M. and Samaras, D., 2018. - * Good view hunting: learning photo composition from dense view pairs. In Proceedings of the IEEE Conference on - * Computer Vision and Pattern Recognition (pp. 5437-5446). - * - * The paper defines the loss function as: - * - * L(y,q) = 1/((n*(n-1))/2) * (sum_(i,j=1..n,i!=j)((y_i - y_j) - (q_i - q_j))^2) - * - * with y: predictions, q: labels, n: length of y and q - * - * As creating those pairs is computationally expensive, we implement a mathematically equivalent function: - * - * L(y,q) = 4/(n*(n-1)) * (n * sum (y_i - q_i)^2 - (sum y_i - q_i)^2) - * - * This equivalency can be derived as: - * - * sum_(i,j=1..n,i!=j)((y_i - y_j) - (q_i - q_j))^2 = sum_(i,j=1..n,i!=j)((y_i - q_i) - (y_j - q_j))^2 - * - * To simplify the following equations we use - * - * sum_(i,j=1..n,i!=j)(d_i - d_j)^2 = sum_(i,j=1..n,i!=j)(d_i^2 + d_j^2 - 2*d_i*d_j) - * - * Due to the pairings each element will appear as both d_i and d_j exactly n-1 times. This allows us to split the sum: - * - * sum_(i,j=1..n,i!=j)(d_i^2 + d_j^2 - 2*d_i*d_j) = 2*(n-1)*sum d_i^2 - 2 * sum_(i,j=1..n,i!=j) d_i * d_j - * = 2*((n-1) * sum d_i^2 - sum_(i,j=1..n,i!=j) d_i * d_j) - * - * Now we use the following equivalency: - * - * (sum d_i)^2 = sum d_i^2 + sum_(i,j=1..n,i!=j) d_i * d_j - * - * This allows us to now use sum d_i^2 and (sum d_i)^2 as a quick way to calculate the sum: - * - * (n-1) * sum d_i^2 - sum_(i,j=1..n,i!=j) d_i * d_j = n * sum d_i^2 - (sum d_i)^2 - * - * And by substituting it into the original definition we get: - * - * 1/((n*(n-1))/2) * 2*(n * sum d_i^2 - (sum d_i)^2) - * - * Which can be again simplified to - * - * 4/(n*(n-1)) * (n * sum d_i^2 - (sum d_i)^2) - * - * After substituting d_i back to (y_i - q_i) this results in the function that we actually implement. - * - */ - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - - - // input validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, - "MEAN_PAIRWSSQERR_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", - ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "MEAN_PAIRWSSQERR_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - if (labels->rankOf() == 1) { // If labels and predictions are of rank 1, it means that all data entries are 0-tensor (scalar) so that the result of becomes always zero. - *output = 0.; - return Status::OK(); - } - - std::vector reductionIdx = ShapeUtils::evalDimsToExclude(labels->rankOf(), {0}); - - auto n = double(labels->sizeAt(1)); - auto diffs = *predictions - *labels; - - auto sumOfSquares = (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true); - - auto squareOfSum = diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true); - squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum); - - - auto E = ((sumOfSquares * n) - squareOfSum) * (4/(n*(n-1))); - - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == E.rankOf(), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as results array, but got %i and %i correspondingly!", weights->rankOf(), E.rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, E), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and results = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(&E).c_str()); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(E)) - weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); - - E *= *weightsBroad; - - switch (reductionMode) { - case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(E); - break; - - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - break; - } - } - - if (weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); - } +CUSTOM_OP_IMPL(mean_pairwssqerr_loss, 3, 1, false, 0, 1) { + /* + * Implementation of mean pairwise squared error loss + * + * For context on where this loss function may be useful see: + * + * Wei, Z., Zhang, J., Shen, X., Lin, Z., Mech, R., Hoai, M. and Samaras, D., + * 2018. Good view hunting: learning photo composition from dense view pairs. + * In Proceedings of the IEEE Conference on Computer Vision and Pattern + * Recognition (pp. 5437-5446). + * + * The paper defines the loss function as: + * + * L(y,q) = 1/((n*(n-1))/2) * (sum_(i,j=1..n,i!=j)((y_i - y_j) - (q_i - + * q_j))^2) + * + * with y: predictions, q: labels, n: length of y and q + * + * As creating those pairs is computationally expensive, we implement a + * mathematically equivalent function: + * + * L(y,q) = 4/(n*(n-1)) * (n * sum (y_i - q_i)^2 - (sum y_i - q_i)^2) + * + * This equivalency can be derived as: + * + * sum_(i,j=1..n,i!=j)((y_i - y_j) - (q_i - q_j))^2 = sum_(i,j=1..n,i!=j)((y_i + * - q_i) - (y_j - q_j))^2 + * + * To simplify the following equations we use + * + * sum_(i,j=1..n,i!=j)(d_i - d_j)^2 = sum_(i,j=1..n,i!=j)(d_i^2 + d_j^2 - + * 2*d_i*d_j) + * + * Due to the pairings each element will appear as both d_i and d_j exactly + * n-1 times. This allows us to split the sum: + * + * sum_(i,j=1..n,i!=j)(d_i^2 + d_j^2 - 2*d_i*d_j) = 2*(n-1)*sum d_i^2 - 2 * + * sum_(i,j=1..n,i!=j) d_i * d_j = 2*((n-1) * sum d_i^2 - sum_(i,j=1..n,i!=j) + * d_i * d_j) + * + * Now we use the following equivalency: + * + * (sum d_i)^2 = sum d_i^2 + sum_(i,j=1..n,i!=j) d_i * d_j + * + * This allows us to now use sum d_i^2 and (sum d_i)^2 as a quick way to + * calculate the sum: + * + * (n-1) * sum d_i^2 - sum_(i,j=1..n,i!=j) d_i * d_j = n * sum d_i^2 - (sum + * d_i)^2 + * + * And by substituting it into the original definition we get: + * + * 1/((n*(n-1))/2) * 2*(n * sum d_i^2 - (sum d_i)^2) + * + * Which can be again simplified to + * + * 4/(n*(n-1)) * (n * sum d_i^2 - (sum d_i)^2) + * + * After substituting d_i back to (y_i - q_i) this results in the function + * that we actually implement. + * + */ + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + + // input validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "MEAN_PAIRWSSQERR_LOSS OP: labels and predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "MEAN_PAIRWSSQERR_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + if (labels->rankOf() == + 1) { // If labels and predictions are of rank 1, it means that all data + // entries are 0-tensor (scalar) so that the result of becomes + // always zero. + *output = 0.; + return Status::OK(); + } + + std::vector reductionIdx = + ShapeUtils::evalDimsToExclude(labels->rankOf(), {0}); + + auto n = double(labels->sizeAt(1)); + auto diffs = *predictions - *labels; + + auto sumOfSquares = + (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true); + + auto squareOfSum = + diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true); + squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum); + + auto E = ((sumOfSquares * n) - squareOfSum) * (4 / (n * (n - 1))); + + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == E.rankOf(), 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: weights array should be scalar or have " + "the same rank as results array, but got %i and %i correspondingly!", + weights->rankOf(), E.rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, E), 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: shapes of weights and labels arrays " + "should be broadcastable, but got weights = %s and results = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(&E).c_str()); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(E)) + weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); + + E *= *weightsBroad; + + switch (reductionMode) { + case 0: // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(E); + break; + + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); +} ////////////////////////////////////////////////////////////////////////// - DECLARE_TYPES(mean_pairwssqerr_loss) { +DECLARE_TYPES(mean_pairwssqerr_loss) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); - } +////////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(mean_pairwssqerr_loss) { + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "MEAN_PAIRWSSQERR_LOSS OP: labels and predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + Nd4jLong const *outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); + else { // in this case output has the shape as labels and logits minus last + // dimension + std::vector dimensions = {-1}; + outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, + false, true, block.workspace()); + + // weights array can be single scalar or has the same rank as output, and + // must be broadcastable to output + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), + 0, + "MEAN_PAIRWSSQERR_LOSS OP: weights array should be scalar or have the " + "same rank as output array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(outShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, outShapeInfo), + 0, + "MEAN_PAIRWSSQERR_LOSS OP: shapes of weights and output arrays should " + "be broadcastable, but got weights = %s and output = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(outShapeInfo).c_str()); + } + + return SHAPELIST(outShapeInfo); +} ////////////////////////////////////////////////////////////////////////// - DECLARE_SHAPE_FN(mean_pairwssqerr_loss) { +CUSTOM_OP_IMPL(mean_pairwssqerr_loss_grad, 3, 3, false, 0, 1) { + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + auto n = double(labels->sizeAt(1)); + auto diffs = *predictions - *labels; + + std::vector reductionIdx = + ShapeUtils::evalDimsToExclude(labels->rankOf(), {0}); + auto sumOfSquares = + (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true); + + auto squareOfSum = + diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true); + squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum); + + auto E = ((sumOfSquares * n) - squareOfSum) * (4 / (n * (n - 1))); + + auto sumPred = + predictions->reduceAlongDimension(reduce::Sum, reductionIdx, true); + auto sumLabel = labels->reduceAlongDimension(reduce::Sum, reductionIdx, true); + + dLdp->assign(((diffs * n) - sumPred + sumLabel) * (8 / (n * (n - 1)))); + + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == E.rankOf(), 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: weights array should be scalar or have " + "the same rank as results array, but got %i and %i correspondingly!", + weights->rankOf(), E.rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, E), 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: shapes of weights and labels arrays " + "should be broadcastable, but got weights = %s and results = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(&E).c_str()); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(E)) + weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdw = 0.; + } else { + *dLdp *= *weightsBroad / sum; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeightsScalar); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + } + break; + } + } - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); + dLdl->assign(-*dLdp); - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, - "MEAN_PAIRWSSQERR_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", - ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), - ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; + if (weightsBroad != weights) delete weightsBroad; - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); - else { // in this case output has the shape as labels and logits minus last dimension - std::vector dimensions = {-1}; - outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(predictionsShapeInfo), dimensions, predictionsShapeInfo, false, true, block.workspace()); + return Status::OK(); +} - // weights array can be single scalar or has the same rank as output, and must be broadcastable to output - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), 0, "MEAN_PAIRWSSQERR_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(outShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, outShapeInfo), 0, "MEAN_PAIRWSSQERR_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(outShapeInfo).c_str()); - } - - return SHAPELIST(outShapeInfo); - } - - - ////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(mean_pairwssqerr_loss_grad, 3, 3, false, 0, 1) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - auto n = double(labels->sizeAt(1)); - auto diffs = *predictions - *labels; - - std::vector reductionIdx = ShapeUtils::evalDimsToExclude(labels->rankOf(), {0}); - auto sumOfSquares = (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true); - - auto squareOfSum = diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true); - squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum); - - auto E = ((sumOfSquares * n) - squareOfSum) * (4/(n*(n-1))); - - auto sumPred = predictions->reduceAlongDimension(reduce::Sum, reductionIdx, true); - auto sumLabel = labels->reduceAlongDimension(reduce::Sum, reductionIdx, true); - - dLdp->assign(((diffs * n) - sumPred + sumLabel)*(8/(n*(n-1)))); - - - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == E.rankOf(), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as results array, but got %i and %i correspondingly!", weights->rankOf(), E.rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, E), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and results = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(&E).c_str()); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(E)) - weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); - - switch (reductionMode) { - - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdw = 0.; - } - else { - - *dLdp *= *weightsBroad / sum; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeightsScalar); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - } - break; - } - } - - dLdl->assign(-*dLdp); - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); - } - - DECLARE_TYPES(mean_pairwssqerr_loss_grad) { - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(mean_pairwssqerr_loss_grad) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "MEAN_PAIRWSSQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); - Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); - Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); - - return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); - } - } +DECLARE_TYPES(mean_pairwssqerr_loss_grad) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +DECLARE_SHAPE_FN(mean_pairwssqerr_loss_grad) { + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: labels and predictions arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: weights array should be scalar or have " + "the same rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "MEAN_PAIRWSSQERR_LOSS_GRAD OP: shapes of weights and labels arrays " + "should be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + Nd4jLong *dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); } +} // namespace ops +} // namespace sd #endif #pragma clang diagnostic pop diff --git a/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp b/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp index 1d74e2eef5b3..fa6466e04908 100644 --- a/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/meanSqErr.cpp @@ -24,280 +24,374 @@ #include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) { - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "MEAN_SQERR_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "MEAN_SQERR_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "MEAN_SQERR_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "MEAN_SQERR_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - NDArray E(labels->shapeInfo(), false, block.launchContext()); - predictions->applyPairwiseTransform(pairwise::SquaredSubtract, *labels, E); - - // multiply E on weights - E *= (*weightsBroad); - - switch (reductionMode) { - case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(&E); - break; - - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - break; - } - } - - - STORE_RESULT(*output); - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "MEAN_SQERR_LOSS OP: labels and predictions arrays must have " + "the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "MEAN_SQERR_LOSS OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "MEAN_SQERR_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE(reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "MEAN_SQERR_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + NDArray E(labels->shapeInfo(), false, block.launchContext()); + predictions->applyPairwiseTransform(pairwise::SquaredSubtract, *labels, E); + + // multiply E on weights + E *= (*weightsBroad); + + switch (reductionMode) { + case 0: // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(&E); + break; + + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + break; + } + } + + STORE_RESULT(*output); + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); } DECLARE_TYPES(mean_sqerr_loss) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(mean_sqerr_loss) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "MEAN_SQERR_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "MEAN_SQERR_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "MEAN_SQERR_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; - - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); - else // in this case output has the same shape as labels and predictions - outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(outShapeInfo); - + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "MEAN_SQERR_LOSS OP: labels and predictions arrays must have " + "the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "MEAN_SQERR_LOSS OP: weights array should be scalar or have the same " + "rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "MEAN_SQERR_LOSS OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + Nd4jLong const* outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); + else // in this case output has the same shape as labels and predictions + outShapeInfo = + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(outShapeInfo); } - - - - - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) { - - auto predictions = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - // inputs validation - REQUIRE_TRUE(labels->isSameShape(predictions), 0, "MEAN_SQERR_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "MEAN_SQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "MEAN_SQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "MEAN_SQERR_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(predictions)) - weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); - - NDArray diff = *predictions - *labels; - - // dE_i/dp_i = 2 * (p_i - y_i) - dLdp->assign(2. * diff); // dE/dp - // dE_i/dy_i = -2 * (p_i - y_i) - // dLdl->assign(-(*dLdp)); // dE/dl - - NDArray E = diff * diff; - - switch (reductionMode) { - - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdw = 0.; - } - else { - - *dLdp *= *weightsBroad / sum; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeights); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - } - break; - } - } - - dLdl->assign(-(*dLdp)); - - if(weightsBroad != weights) - delete weightsBroad; - - return Status::OK(); + auto predictions = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + // inputs validation + REQUIRE_TRUE(labels->isSameShape(predictions), 0, + "MEAN_SQERR_LOSS_GRAD OP: labels and predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(predictions).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "MEAN_SQERR_LOSS_GRAD OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "MEAN_SQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "MEAN_SQERR_LOSS_GRAD OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(predictions)) + weightsBroad = new NDArray(weights->tileToShape(predictions->shapeInfo())); + + NDArray diff = *predictions - *labels; + + // dE_i/dp_i = 2 * (p_i - y_i) + dLdp->assign(2. * diff); // dE/dp + // dE_i/dy_i = -2 * (p_i - y_i) + // dLdl->assign(-(*dLdp)); // dE/dl + + NDArray E = diff * diff; + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdw = 0.; + } else { + *dLdp *= *weightsBroad / sum; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeights); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + } + break; + } + } + + dLdl->assign(-(*dLdp)); + + if (weightsBroad != weights) delete weightsBroad; + + return Status::OK(); } DECLARE_TYPES(mean_sqerr_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } DECLARE_SHAPE_FN(mean_sqerr_loss_grad) { - - auto predictionsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and predictions must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "MEAN_SQERR_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "MEAN_SQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "MEAN_SQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo)); - - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.workspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); - - return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); + auto predictionsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and predictions must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, + "MEAN_SQERR_LOSS_GRAD OP: labels and predictions arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(predictionsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "MEAN_SQERR_LOSS_GRAD OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "MEAN_SQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be " + "broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = DataTypeUtils::pickFloatingType( + ArrayOptions::dataType(predictionsShapeInfo)); + + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + predictionsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), + CONSTANT(dLdlShapeInfo)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp b/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp index 7aba1579bcfe..54236ee08a09 100644 --- a/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/sigmCrossEntropy.cpp @@ -25,303 +25,408 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) { - auto logits = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - auto labelsSmoothing = T_ARG(0); - - // input validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "SIGM_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "SIGM_CROSS_ENTROPY_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SIGM_CROSS_ENTROPY_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(logits)) - weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); - - // If labelsSmoothing is nonzero, smooth the labels towards 1/2: - auto newLabels = labels; - if(labelsSmoothing != 0.) { - newLabels = new NDArray(*labels); - newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing, *newLabels); - } - - NDArray E(labels, false, block.launchContext()); - - // logits - labels * logits + log(1 + exp(-logits)) -> take into account numerical stability at large logits - helpers::sigmCrossEntropy(block.launchContext(), logits, newLabels, &E); - - // multiply E on weights - E *= *weightsBroad; - - switch (reductionMode) { - case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(E); - break; - - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - (*output) = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - if(newLabels != labels) - delete newLabels; - - return Status::OK(); + auto logits = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + auto labelsSmoothing = T_ARG(0); + + // input validation + REQUIRE_TRUE(labels->isSameShape(logits), 0, + "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have " + "the same shapes, but got %s and %s correspondingly!", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "SIGM_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "SIGM_CROSS_ENTROPY_LOSS OP: shapes of weights and labels arrays should " + "be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "SIGM_CROSS_ENTROPY_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(logits)) + weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); + + // If labelsSmoothing is nonzero, smooth the labels towards 1/2: + auto newLabels = labels; + if (labelsSmoothing != 0.) { + newLabels = new NDArray(*labels); + newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing, + *newLabels); + } + + NDArray E(labels, false, block.launchContext()); + + // logits - labels * logits + log(1 + exp(-logits)) -> take into account + // numerical stability at large logits + helpers::sigmCrossEntropy(block.launchContext(), logits, newLabels, &E); + + // multiply E on weights + E *= *weightsBroad; + + switch (reductionMode) { + case 0: // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(E); + break; + + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + (*output) = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + if (newLabels != labels) delete newLabels; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(sigm_cross_entropy_loss) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(sigm_cross_entropy_loss) { - - auto logitsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; - - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); - else // in this case output has the same shape as labels and logits - outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(outShapeInfo); + auto logitsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and logits must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, + "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have " + "the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "SIGM_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the " + "same rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "SIGM_CROSS_ENTROPY_LOSS OP: shapes of weights and labels arrays should " + "be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + Nd4jLong const* outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); + else // in this case output has the same shape as labels and logits + outShapeInfo = + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(outShapeInfo); } - - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) { - - auto logits = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - - NDArray labelsSmoothing = NDArrayFactory::create(logits->dataType(), T_ARG(0), block.launchContext()); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; - - // input validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - - // perform weights broadcasting/tile to labels if needed - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(logits)) - weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); - - // If labelsSmoothing is nonzero, smooth the labels towards 1/2: - auto newLabels = labels; - if(labelsSmoothing.e(0) != 0.f) { - newLabels = new NDArray(*labels); - newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing.e(0), *newLabels); - } - - NDArray E(labels, false, block.launchContext()); - - // logits - labels * logits + log(1 + exp(-logits)) -> take into account numerical stability at large logits - helpers::sigmCrossEntropy(block.launchContext(), logits, newLabels, &E); - - // dLdp = 1 - labels - 1 / (1 + exp(logits)) - helpers::sigmCrossEntropyGrad(block.launchContext(), logits, newLabels, dLdp); - - // dLdl = -logits - labelsSmoothing -= 1.f; - dLdl->assign(*logits * labelsSmoothing); - - switch (reductionMode) { - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - *dLdp *= *weightsBroad; - *dLdl *= *weightsBroad; - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum)); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - sum.setContext(block.launchContext()); - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - NDArray temp = *weightsBroad / sum; - *dLdp *= temp; - *dLdl *= temp; - - if(weights->isScalar()) - *dLdw = 0.; - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum * sum)); - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - auto numOfNonZeroWeightsScalar = NDArrayFactory::create(dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); - - if(weights->isScalar()) - dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeightsScalar); - else if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeightsScalar; - } - else - dLdw->assign(E / numOfNonZeroWeightsScalar); - - NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; - *dLdp *= temp; - *dLdl *= temp; - } - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - if(newLabels != labels) - delete newLabels; - - return Status::OK(); + auto logits = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + NDArray labelsSmoothing = NDArrayFactory::create(logits->dataType(), T_ARG(0), + block.launchContext()); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + // input validation + REQUIRE_TRUE(labels->isSameShape(logits), 0, + "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must " + "have the same shapes, but got %s and %s correspondingly!", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, + "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have " + "the same rank as labels array, but got %i and %i correspondingly!", + weights->rankOf(), labels->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *labels), + 0, + "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and labels arrays " + "should be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: reduction mode value is not " + "acceptable, possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + + // perform weights broadcasting/tile to labels if needed + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(logits)) + weightsBroad = new NDArray(weights->tileToShape(logits->shapeInfo())); + + // If labelsSmoothing is nonzero, smooth the labels towards 1/2: + auto newLabels = labels; + if (labelsSmoothing.e(0) != 0.f) { + newLabels = new NDArray(*labels); + newLabels->applyScalar(scalar::SXELogitsSmoother, + labelsSmoothing.e(0), *newLabels); + } + + NDArray E(labels, false, block.launchContext()); + + // logits - labels * logits + log(1 + exp(-logits)) -> take into account + // numerical stability at large logits + helpers::sigmCrossEntropy(block.launchContext(), logits, newLabels, &E); + + // dLdp = 1 - labels - 1 / (1 + exp(logits)) + helpers::sigmCrossEntropyGrad(block.launchContext(), logits, newLabels, dLdp); + + // dLdl = -logits + labelsSmoothing -= 1.f; + dLdl->assign(*logits * labelsSmoothing); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + *dLdp *= *weightsBroad; + *dLdl *= *weightsBroad; + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum)); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + sum.setContext(block.launchContext()); + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + NDArray temp = *weightsBroad / sum; + *dLdp *= temp; + *dLdl *= temp; + + if (weights->isScalar()) + *dLdw = 0.; + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + auto numOfNonZeroWeightsScalar = NDArrayFactory::create( + dLdw->dataType(), numOfNonZeroWeights, block.launchContext()); + + if (weights->isScalar()) + dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeightsScalar); + else if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeightsScalar; + } else + dLdw->assign(E / numOfNonZeroWeightsScalar); + + NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar; + *dLdp *= temp; + *dLdl *= temp; + } + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; + if (newLabels != labels) delete newLabels; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(sigm_cross_entropy_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(sigm_cross_entropy_loss_grad) { - - auto logitsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - - auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.workspace()); - auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.workspace()); - auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.workspace()); - - return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo)); + auto logitsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and logits must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, + "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + // weights array can be single scalar or has the same rank as labels, and must + // be broadcastable to labels + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), + 0, + "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have " + "the same rank as labels array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), + 0, + "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and labels arrays " + "should be broadcastable, but got weights = %s and labels = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(labelsShapeInfo).c_str()); + + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + + auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + logitsShapeInfo, outType, false, block.workspace()); + auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, outType, false, block.workspace()); + auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), + CONSTANT(dLdlShapeInfo)); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp index 8bbdf7381480..53339461358e 100644 --- a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropy.cpp @@ -24,375 +24,503 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) { - - auto logits = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - auto output = OUTPUT_VARIABLE(0); - - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - double labelsSmoothing = T_ARG(0); - - // input validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - // smoothing is possible for rank of logits/labels > 1 - REQUIRE_TRUE(labels->rankOf() > 1 || (labels->rankOf() == 1 && labelsSmoothing == 0.), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: smoothing is not possible when rank of labels/ logits = 1 !"); - - if(!output->isScalar()) { - // weights array can be single scalar or has the same shape as output, and must be broadcastable to output shape - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == output->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", weights->rankOf(), output->rankOf()); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *output), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); + auto logits = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + auto output = OUTPUT_VARIABLE(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + double labelsSmoothing = T_ARG(0); + + // input validation + REQUIRE_TRUE(labels->isSameShape(logits), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: labels and logits arrays must " + "have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: reduction mode value is not acceptable, " + "possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + // smoothing is possible for rank of logits/labels > 1 + REQUIRE_TRUE( + labels->rankOf() > 1 || (labels->rankOf() == 1 && labelsSmoothing == 0.), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: smoothing is not possible when rank of " + "labels/ logits = 1 !"); + + if (!output->isScalar()) { + // weights array can be single scalar or has the same shape as output, and + // must be broadcastable to output shape + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == output->rankOf(), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have " + "the same rank as output array, but got %i and %i correspondingly!", + weights->rankOf(), output->rankOf()); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE(weights->isScalar() || + ShapeUtils::areShapesBroadcastable(*weights, *output), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: shapes of weights and output " + "arrays should be broadcastable, but got weights = %s and " + "output = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(labels).c_str()); + } + + // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: + // new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing + // / num_classes num_classes = labels->sizeAt(1) + NDArray* cLabels = new NDArray(labels->cast(weights->dataType())); + NDArray* newLabels = cLabels; + if (labelsSmoothing != 0.) { + newLabels = new NDArray(cLabels); + newLabels->assign((1.f - labelsSmoothing) * *cLabels + + labelsSmoothing / cLabels->sizeAt(1)); + } + + // main formula: result = - sum_i(lables_i * log(softmax_i)) - sum over last + // dimension softmax_i = exp(logits_i) / sum_j(exp(logits_j)) so result = + // sum_i( lables_i * (log(sum_j(exp(logits_j))) - logits_i) ) for numerical + // stability we use shifted logits (one can approve this using simple math): + // softmax_i = exp(logits_i - maxLogit) / sum_j(exp(logits_j - maxLogit)) + // maxLogit is max among logits_i + + std::vector dimensions = {-1}; + NDArray shiftedLogits = + *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true); + NDArray logSumExp = shiftedLogits.transform(transform::Exp) + .reduceAlongDimension(reduce::Sum, dimensions, true) + .transform(transform::Log); + NDArray E = (*newLabels * (logSumExp - shiftedLogits)) + .reduceAlongDimension(reduce::Sum, dimensions); + + // perform weights broadcasting/tile to E if it is necessary + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(&E)) { + if (E.rankOf() == 1 && weights->isVector() && weights->rankOf() > 1) + weightsBroad = new NDArray( + weights->reshape(weights->ordering(), {weights->lengthOf()})); + else + weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); + } + + // multiply E on weights + E *= *weightsBroad; + + switch (reductionMode) { + case 0: // 0 - "none", un-reduced weighted losses with the same shape as + // labels. + output->assign(&E); + break; + + case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all + // elements of E array + E.reduceNumber(reduce::Sum, *output); + break; } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + double sum; + if (weights->isScalar()) + sum = weights->e(0) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum).e(0); + + if (sum == 0.) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / sum); + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else { + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + } + + if (numOfNonZeroWeights == 0) + *output = 0.; + else + output->assign(E.reduceNumber(reduce::Sum) / + double(numOfNonZeroWeights)); + + break; + } + } + + if (weightsBroad != weights) delete weightsBroad; - // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes - // num_classes = labels->sizeAt(1) - NDArray* cLabels = new NDArray(labels->cast(weights->dataType())); - NDArray* newLabels = cLabels; - if(labelsSmoothing != 0.) { - newLabels = new NDArray(cLabels); - newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1)); - } - - // main formula: result = - sum_i(lables_i * log(softmax_i)) - sum over last dimension - // softmax_i = exp(logits_i) / sum_j(exp(logits_j)) - // so result = sum_i( lables_i * (log(sum_j(exp(logits_j))) - logits_i) ) - // for numerical stability we use shifted logits (one can approve this using simple math): - // softmax_i = exp(logits_i - maxLogit) / sum_j(exp(logits_j - maxLogit)) - // maxLogit is max among logits_i - - - std::vector dimensions = {-1}; - NDArray shiftedLogits = *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true); - NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDimension(reduce::Sum, dimensions, true).transform(transform::Log); - NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDimension(reduce::Sum, dimensions); - - // perform weights broadcasting/tile to E if it is necessary - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(&E)) { - if(E.rankOf() == 1 && weights->isVector() && weights->rankOf() > 1) - weightsBroad = new NDArray(weights->reshape(weights->ordering(), {weights->lengthOf()})); - else - weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); - } - - // multiply E on weights - E *= *weightsBroad; - - switch (reductionMode) { - case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels. - output->assign(&E); - break; - - case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array - E.reduceNumber(reduce::Sum, *output); - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - double sum; - if (weights->isScalar()) - sum = weights->e(0) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum).e(0); - - if (sum == 0.) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / sum); - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else { - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - } - - if (numOfNonZeroWeights == 0) - *output = 0.; - else - output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); - - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - if(newLabels != cLabels) - delete newLabels; - - delete cLabels; - - return Status::OK(); + if (newLabels != cLabels) delete newLabels; + + delete cLabels; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(softmax_cross_entropy_loss) { - - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(softmax_cross_entropy_loss) { - - auto logitsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - Nd4jLong const* outShapeInfo = nullptr; - - if(INT_ARG(0) != 0) // in this case output is scalar - outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); - else { // in this case output has the shape as labels and logits minus last dimension - std::vector dimensions = {-1}; - outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, true, block.workspace()); - - // weights array can be single scalar or has the same rank as output, and must be broadcastable to output - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(outShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, outShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(outShapeInfo).c_str()); - } - - return SHAPELIST(outShapeInfo); + auto logitsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + // labels and logits must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: labels and logits arrays must " + "have the same shapes, but got %s and %s correspondingly!", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + Nd4jLong const* outShapeInfo = nullptr; + + if (INT_ARG(0) != 0) // in this case output is scalar + outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType); + else { // in this case output has the shape as labels and logits minus last + // dimension + std::vector dimensions = {-1}; + outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, true, + block.workspace()); + + // weights array can be single scalar or has the same rank as output, and + // must be broadcastable to output + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have " + "the same rank as output array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(outShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, outShapeInfo), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS OP: shapes of weights and output arrays " + "should be broadcastable, but got weights = %s and output = %s " + "instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(outShapeInfo).c_str()); + } + + return SHAPELIST(outShapeInfo); } - - - - - - - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) { + auto logits = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto labels = INPUT_VARIABLE(2); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits + auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights + auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels + + auto labelsSmoothing = T_ARG(0); + + int reductionMode = + INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - + // "weighted_sum_by_nonzero_weights" + // take into account Alex's proposition to treat "none" the same as + // "weighted_sum" mode when calculating gradients + if (reductionMode == 0) reductionMode = 1; + + std::vector dimensions = {-1}; + + // input validation + REQUIRE_TRUE(labels->isSameShape(logits), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + // only 4 possible reduction modes exist + REQUIRE_TRUE( + reductionMode == 0 || reductionMode == 1 || reductionMode == 2 || + reductionMode == 3, + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: reduction mode value is not " + "acceptable, possible values are 0, 1, 2, 3, but got %i instead!", + reductionMode); + auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo( + logits->ordering(), dimensions, logits->shapeInfo(), false, false, + block.workspace()); + // weights array can be single scalar or has the same shape as loss, and must + // be broadcastable to loss shape + REQUIRE_TRUE( + weights->isScalar() || weights->rankOf() == shape::rank(lossShapeInfo), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or " + "have the same rank as loss array, but got %i and %i correspondingly!", + weights->rankOf(), shape::rank(lossShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + weights->isScalar() || ShapeUtils::areShapesBroadcastable( + weights->shapeInfo(), lossShapeInfo), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and loss arrays " + "should be broadcastable, but got weights = %s and loss = %s instead!", + ShapeUtils::shapeAsString(weights).c_str(), + ShapeUtils::shapeAsString(lossShapeInfo).c_str()); + // smoothing is possible for rank of logits/labels > 1 + REQUIRE_TRUE( + labels->rankOf() > 1 || (labels->rankOf() == 1 && labelsSmoothing == 0.), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: smoothing is not possible when rank " + "of labels/ logits = 1 !"); + + // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: + // new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing + // / num_classes num_classes = labels->sizeAt(1) + NDArray* cLabels = new NDArray(labels->cast(weights->dataType())); + NDArray* newLabels = cLabels; + if (labelsSmoothing != 0.) { + newLabels = new NDArray(labels->shapeInfo(), dLdl->dataType(), false, + block.launchContext()); + newLabels->assign((1.f - labelsSmoothing) * *cLabels + + labelsSmoothing / cLabels->sizeAt(1)); + } + + NDArray softmax = + (*logits - logits->reduceAlongDimension(reduce::Max, dimensions, true)) + .transform(transform::Exp); + softmax /= softmax.reduceAlongDimension(reduce::Sum, dimensions, true); + + // dEdp = softmax * sum_i(lables_i) - labels + dLdp->assign( + softmax * newLabels->reduceAlongDimension(reduce::Sum, dimensions, true) - + *newLabels); + + // dEdl = -log(softmax) + dLdl->assign(-softmax.transform(transform::Log) * (1.f - labelsSmoothing)); + + NDArray shiftedLogits = + *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true); + NDArray logSumExp = shiftedLogits.transform(transform::Exp) + .reduceAlongDimension(reduce::Sum, dimensions, true) + .transform(transform::Log); + NDArray E = (*newLabels * (logSumExp - shiftedLogits)) + .reduceAlongDimension(reduce::Sum, dimensions); + + // perform weights broadcasting/tile to E if it is necessary + auto weightsBroad = weights; + if (!weights->isScalar() && !weights->isSameShape(&E)) + weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); + + dimensions = ShapeUtils::evalDimsToExclude(dLdp->rankOf(), dimensions); + + switch (reductionMode) { + case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to + // sum of all elements of E array + + if (weights->isScalar() || weights->lengthOf() == 1) { + dLdw->assign(E.reduceNumber(reduce::Sum)); + *dLdp *= *weights; + *dLdl *= *weights; + } else { + dLdp->applyBroadcast(sd::broadcast::Multiply, dimensions, *weightsBroad, + *dLdp); + dLdl->applyBroadcast(sd::broadcast::Multiply, dimensions, *weightsBroad, + *dLdl); + + if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), + weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + } else + dLdw->assign(E); + } + + break; + } + case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all + // elements of E array divided by sum of all elements of + // weightsBroad array + NDArray sum; + if (weights->isScalar()) + sum = (*weights) * E.lengthOf(); + else + sum = weightsBroad->reduceNumber(reduce::Sum); + + if (sum.e(0) == 0.) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + if (weights->isScalar() || weights->lengthOf() == 1) { + NDArray temp = *weights / sum; + *dLdp *= temp; + *dLdl *= temp; + *dLdw = 0.; + } else { + NDArray temp = *weightsBroad / sum; + dLdp->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, + *dLdp); + dLdl->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, + *dLdl); + + if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis( + weights->shapeInfo(), weightsBroad->shapeInfo()); + ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)) + .reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, + true, false, false); + } else + dLdw->assign( + (E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / + (sum * sum)); + } + } + break; + } + case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and + // equal to scalar sum of all elements of E array divided by + // number of non-zero weights + Nd4jLong numOfNonZeroWeights = 0; + if (weights->isScalar()) { + if (weights->e(0) != 0.) numOfNonZeroWeights = E.lengthOf(); + } else + numOfNonZeroWeights = + weightsBroad->reduceNumber(reduce::CountNonZero).e(0); + + if (numOfNonZeroWeights == 0) { + *dLdp = 0.; + *dLdl = 0.; + *dLdw = 0.; + } else { + if (weights->isScalar() || weights->lengthOf() == 1) { + NDArray temp = *weights / numOfNonZeroWeights; + *dLdp *= temp; + *dLdl *= temp; + dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); + } else { + NDArray temp = *weightsBroad / numOfNonZeroWeights; + dLdp->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, + *dLdp); + dLdl->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, + *dLdl); + + if (weights != weightsBroad) { + std::vector axesToReduceAlong = + ShapeUtils::evalBroadcastBackwardAxis( + weights->shapeInfo(), weightsBroad->shapeInfo()); + E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, + false, false); + *dLdw /= numOfNonZeroWeights; + } else + dLdw->assign(E / numOfNonZeroWeights); + } + } + break; + } + } - auto logits = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto labels = INPUT_VARIABLE(2); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits - auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights - auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels - - auto labelsSmoothing = T_ARG(0); + if (weightsBroad != weights) delete weightsBroad; - int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights" - // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients - if(reductionMode == 0) - reductionMode = 1; + if (newLabels != cLabels) delete newLabels; - std::vector dimensions = {-1}; + delete cLabels; - // input validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - // only 4 possible reduction modes exist - REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode); - auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(logits->ordering(), dimensions, logits->shapeInfo(), false, false, block.workspace()); - // weights array can be single scalar or has the same shape as loss, and must be broadcastable to loss shape - REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == shape::rank(lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", weights->rankOf(), shape::rank(lossShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(weights->shapeInfo(), lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and loss arrays should be broadcastable, but got weights = %s and loss = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(lossShapeInfo).c_str()); - // smoothing is possible for rank of logits/labels > 1 - REQUIRE_TRUE(labels->rankOf() > 1 || (labels->rankOf() == 1 && labelsSmoothing == 0.), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: smoothing is not possible when rank of labels/ logits = 1 !"); - - // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes - // num_classes = labels->sizeAt(1) - NDArray* cLabels = new NDArray(labels->cast(weights->dataType())); - NDArray* newLabels = cLabels; - if(labelsSmoothing != 0.) { - newLabels = new NDArray(labels->shapeInfo(), dLdl->dataType(), false, block.launchContext()); - newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1)); - } - - NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimensions, true)).transform(transform::Exp); - softmax /= softmax.reduceAlongDimension(reduce::Sum, dimensions, true); - - // dEdp = softmax * sum_i(lables_i) - labels - dLdp->assign(softmax * newLabels->reduceAlongDimension(reduce::Sum, dimensions, true) - *newLabels); - - // dEdl = -log(softmax) - dLdl->assign(-softmax.transform(transform::Log)* (1.f - labelsSmoothing)); - - NDArray shiftedLogits = *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true); - NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDimension(reduce::Sum, dimensions, true).transform(transform::Log); - NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDimension(reduce::Sum, dimensions); - - // perform weights broadcasting/tile to E if it is necessary - auto weightsBroad = weights; - if(!weights->isScalar() && !weights->isSameShape(&E)) - weightsBroad = new NDArray(weights->tileToShape(E.shapeInfo())); - - dimensions = ShapeUtils::evalDimsToExclude(dLdp->rankOf(), dimensions); - - switch (reductionMode) { - case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array - - if(weights->isScalar() || weights->lengthOf() == 1) { - dLdw->assign(E.reduceNumber(reduce::Sum)); - *dLdp *= *weights; - *dLdl *= *weights; - } - else { - dLdp->applyBroadcast(sd::broadcast::Multiply, dimensions, *weightsBroad, *dLdp); - dLdl->applyBroadcast(sd::broadcast::Multiply, dimensions, *weightsBroad, *dLdl); - - if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign(E); - } - - break; - } - case 2: { // 2 - "weighted_mean", output is scalar and equal to sum of all elements of E array divided by sum of all elements of weightsBroad array - NDArray sum; - if (weights->isScalar()) - sum = (*weights) * E.lengthOf(); - else - sum = weightsBroad->reduceNumber(reduce::Sum); - - if (sum.e(0) == 0.) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - if(weights->isScalar() || weights->lengthOf() == 1) { - NDArray temp = *weights / sum; - *dLdp *= temp; - *dLdl *= temp; - *dLdw = 0.; - } - else { - - NDArray temp = *weightsBroad / sum; - dLdp->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, *dLdp); - dLdl->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, *dLdl); - - if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - } - else - dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); - } - } - break; - } - case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights - Nd4jLong numOfNonZeroWeights = 0; - if(weights->isScalar()) { - if(weights->e(0) != 0.) - numOfNonZeroWeights = E.lengthOf(); - } - else - numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e(0); - - if (numOfNonZeroWeights == 0) { - *dLdp = 0.; - *dLdl = 0.; - *dLdw = 0.; - } - else { - - if(weights->isScalar() || weights->lengthOf() == 1) { - NDArray temp = *weights / numOfNonZeroWeights; - *dLdp *= temp; - *dLdl *= temp; - dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); - } - else { - NDArray temp = *weightsBroad / numOfNonZeroWeights; - dLdp->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, *dLdp); - dLdl->applyBroadcast(sd::broadcast::Multiply, dimensions, temp, *dLdl); - - if(weights != weightsBroad) { - std::vector axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->shapeInfo(), weightsBroad->shapeInfo()); - E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false); - *dLdw /= numOfNonZeroWeights; - } - else - dLdw->assign(E / numOfNonZeroWeights); - } - } - break; - } - } - - if(weightsBroad != weights) - delete weightsBroad; - - if(newLabels != cLabels) - delete newLabels; - - delete cLabels; - - return Status::OK(); + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(softmax_cross_entropy_loss_grad) { - - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedInputTypes(4, {ALL_FLOATS}) - ->setAllowedInputTypes(5, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedInputTypes(4, {ALL_FLOATS}) + ->setAllowedInputTypes(5, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(softmax_cross_entropy_loss_grad) { - - auto logitsShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto labelsShapeInfo = inputShape->at(2); - - std::vector dimensions = {-1}; - - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, false, block.workspace()); - // weights array can be single scalar or has the same rank as loss, and must be broadcastable to loss - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(lossShapeInfo)); - // check whether broadcast operation is possible for weights array - REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and loss arrays should be broadcastable, but got weights = %s and loss = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(lossShapeInfo).c_str()); - - auto outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - - auto dLdpShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(logitsShapeInfo), shape::shapeOf(logitsShapeInfo), shape::rank(logitsShapeInfo))); - auto dLdwShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(weightsShapeInfo), shape::shapeOf(weightsShapeInfo), shape::rank(weightsShapeInfo))); - auto dLdlShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); + auto logitsShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto labelsShapeInfo = inputShape->at(2); + + std::vector dimensions = {-1}; + + // labels and logits must have the same shapes + REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays " + "must have the same shapes, but got %s and %s correspondingly!", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, false, + block.workspace()); + // weights array can be single scalar or has the same rank as loss, and must + // be broadcastable to loss + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + shape::rank(weightsShapeInfo) == shape::rank(lossShapeInfo), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or " + "have the same rank as loss array, but got %i and %i correspondingly!", + shape::rank(weightsShapeInfo), shape::rank(lossShapeInfo)); + // check whether broadcast operation is possible for weights array + REQUIRE_TRUE( + shape::isScalar(weightsShapeInfo) || + ShapeUtils::areShapesBroadcastable(weightsShapeInfo, lossShapeInfo), + 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and loss arrays " + "should be broadcastable, but got weights = %s and loss = %s instead!", + ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), + ShapeUtils::shapeAsString(lossShapeInfo).c_str()); + + auto outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + + auto dLdpShapeInfo = + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + outType, shape::order(logitsShapeInfo), + shape::shapeOf(logitsShapeInfo), shape::rank(logitsShapeInfo))); + auto dLdwShapeInfo = + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + outType, shape::order(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo), shape::rank(weightsShapeInfo))); + auto dLdlShapeInfo = + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp index 7e4785ce14e5..51e8d4ae2345 100644 --- a/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/softmaxCrossEntropyWithLogits.cpp @@ -24,118 +24,154 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0) { - auto logits = INPUT_VARIABLE(0); - auto labels = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - const int classesDim = block.numI() > 0 ? INT_ARG(0) : logits->rankOf()-1; - - // input validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - REQUIRE_TRUE(classesDim < logits->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: class dimension must be smaller than rank of logits, but got %i and %i correspondingly !", classesDim, logits->rankOf()); - - std::vector dimension = {classesDim}; - - auto maxAlongDim = logits->reduceAlongDimension(reduce::Max, {classesDim}, true); - auto logExp = (*logits - maxAlongDim).transform(transform::Exp); - auto logSoftMax = ( logExp / logExp.reduceAlongDimension(reduce::Sum, {classesDim}, true) ).transform(transform::Log); - - (-(*labels) * logSoftMax).reduceAlongDimension(reduce::Sum, *output, dimension); - - return Status::OK(); + auto logits = INPUT_VARIABLE(0); + auto labels = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + const int classesDim = block.numI() > 0 ? INT_ARG(0) : logits->rankOf() - 1; + + // input validation + REQUIRE_TRUE( + labels->isSameShape(logits), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays " + "must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + REQUIRE_TRUE( + classesDim < logits->rankOf(), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: class dimension must be " + "smaller than rank of logits, but got %i and %i correspondingly !", + classesDim, logits->rankOf()); + + std::vector dimension = {classesDim}; + + auto maxAlongDim = + logits->reduceAlongDimension(reduce::Max, {classesDim}, true); + auto logExp = (*logits - maxAlongDim).transform(transform::Exp); + auto logSoftMax = + (logExp / logExp.reduceAlongDimension(reduce::Sum, {classesDim}, true)) + .transform(transform::Log); + + (-(*labels) * logSoftMax) + .reduceAlongDimension(reduce::Sum, *output, dimension); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(softmax_cross_entropy_loss_with_logits) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(softmax_cross_entropy_loss_with_logits) { - - auto logitsShapeInfo = inputShape->at(0); - auto labelsShapeInfo = inputShape->at(1); - - const int classesDim = block.numI() > 0 ? INT_ARG(0) : -1; - std::vector dimensions = {classesDim}; - - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - - auto outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - auto reducedShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(labelsShapeInfo), dimensions, labelsShapeInfo, outType, false, false, block.workspace()); - - return SHAPELIST(reducedShapeInfo); - + auto logitsShapeInfo = inputShape->at(0); + auto labelsShapeInfo = inputShape->at(1); + + const int classesDim = block.numI() > 0 ? INT_ARG(0) : -1; + std::vector dimensions = {classesDim}; + + // labels and logits must have the same shapes + REQUIRE_TRUE( + shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays " + "must have the same shapes, but got %s and %s correspondingly!", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + + auto outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + auto reducedShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(labelsShapeInfo), dimensions, labelsShapeInfo, outType, + false, false, block.workspace()); + + return SHAPELIST(reducedShapeInfo); } - - - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(softmax_cross_entropy_loss_with_logits_grad, 2, 2, false, 0, 0) { - - auto logits = INPUT_VARIABLE(0); - auto labels = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits - auto dLdl = OUTPUT_VARIABLE(1); // dL/dlabels - - const int classesDim = block.numI() > 0 ? INT_ARG(0) : logits->rankOf()-1; - - // input validation - REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str()); - REQUIRE_TRUE(classesDim < logits->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: class dimension must be smaller than rank of logits, but got %i and %i correspondingly !", classesDim, logits->rankOf()); - - std::vector dimension = {classesDim}; - - NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)).transform(transform::Exp); - softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true); - - // dEdp = softmax * sum_i(labels_i) - labels - dLdp->assign(softmax * labels->reduceAlongDimension(reduce::Sum, dimension, true) - *labels); - - // dEdl = -log(softmax) - (-softmax).applyTransform(transform::Log, *dLdl); - - return Status::OK(); + auto logits = INPUT_VARIABLE(0); + auto labels = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits + auto dLdl = OUTPUT_VARIABLE(1); // dL/dlabels + + const int classesDim = block.numI() > 0 ? INT_ARG(0) : logits->rankOf() - 1; + + // input validation + REQUIRE_TRUE( + labels->isSameShape(logits), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: labels and logits " + "arrays must have the same shapes, but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(labels).c_str(), + ShapeUtils::shapeAsString(logits).c_str()); + REQUIRE_TRUE( + classesDim < logits->rankOf(), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: class dimension must be " + "smaller than rank of logits, but got %i and %i correspondingly !", + classesDim, logits->rankOf()); + + std::vector dimension = {classesDim}; + + NDArray softmax = + (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)) + .transform(transform::Exp); + softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true); + + // dEdp = softmax * sum_i(labels_i) - labels + dLdp->assign(softmax * + labels->reduceAlongDimension(reduce::Sum, dimension, true) - + *labels); + + // dEdl = -log(softmax) + (-softmax).applyTransform(transform::Log, *dLdl); + + return Status::OK(); } - ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(softmax_cross_entropy_loss_with_logits_grad) { - - getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(softmax_cross_entropy_loss_with_logits_grad) { - - auto logitsShapeInfo = inputShape->at(0); - auto labelsShapeInfo = inputShape->at(1); - - // labels and logits must have the same shapes - REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - - auto dLdpShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(logitsShapeInfo), shape::shapeOf(logitsShapeInfo), shape::rank(logitsShapeInfo))); - auto dLdlShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); - - return SHAPELIST(dLdpShapeInfo, dLdlShapeInfo); + auto logitsShapeInfo = inputShape->at(0); + auto labelsShapeInfo = inputShape->at(1); + + // labels and logits must have the same shapes + REQUIRE_TRUE( + shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, + "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: labels and logits " + "arrays must have the same shapes, but got %s and %s correspondingly!", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + + auto dLdpShapeInfo = + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + outType, shape::order(logitsShapeInfo), + shape::shapeOf(logitsShapeInfo), shape::rank(logitsShapeInfo))); + auto dLdlShapeInfo = + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + outType, shape::order(labelsShapeInfo), + shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo))); + + return SHAPELIST(dLdpShapeInfo, dLdlShapeInfo); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp b/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp index 6f671e148002..397feeb35f4f 100644 --- a/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp +++ b/libnd4j/include/ops/declarable/generic/loss/sparseSoftmaxCrossEntropyWithLogits.cpp @@ -25,144 +25,189 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// -CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0) { - auto labels = INPUT_VARIABLE(0); - auto logits = INPUT_VARIABLE(1); - - auto output = OUTPUT_VARIABLE(0); - - const int labelsRank = labels->rankOf(); - const int logitsRank = logits->rankOf(); - - // input validation - REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsRank, logitsRank); - - std::vector labelsShape = labels->getShapeAsVector(); // this is correct - std::vector logitsShape = logits->getShapeAsVector(); - logitsShape.pop_back(); - bool equalSoft = logitsShape == labelsShape; - - REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShape).c_str(), ShapeUtils::shapeAsString(logitsShape).c_str()); - - std::vector dimension = {-1}; - - auto maxAlongDim = logits->reduceAlongDimension(reduce::Max, dimension, true); - auto logitsExp = (*logits - maxAlongDim).transform(transform::Exp, nullptr); - auto logSoftMax = -(( logitsExp / logitsExp.reduceAlongDimension(reduce::Sum, dimension, true) ).transform(transform::Log)); - - helpers::scatterForLoss(block.launchContext(), *labels, logSoftMax, *output, false); - - return Status::OK(); +CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, + 0) { + auto labels = INPUT_VARIABLE(0); + auto logits = INPUT_VARIABLE(1); + + auto output = OUTPUT_VARIABLE(0); + + const int labelsRank = labels->rankOf(); + const int logitsRank = logits->rankOf(); + + // input validation + REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: input arrays " + "should satisfy relation (labels_rank = logits_rank - 1), but " + "got labels_rank = %i and logits_rank = %i instead !", + labelsRank, logitsRank); + + std::vector labelsShape = + labels->getShapeAsVector(); // this is correct + std::vector logitsShape = logits->getShapeAsVector(); + logitsShape.pop_back(); + bool equalSoft = logitsShape == labelsShape; + + REQUIRE_TRUE( + equalSoft, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels " + "array, its shape should be the same as logits shape with last dimension " + "excluded, however got labels_shape = %s and logits_shape = %s instead !", + ShapeUtils::shapeAsString(labelsShape).c_str(), + ShapeUtils::shapeAsString(logitsShape).c_str()); + + std::vector dimension = {-1}; + + auto maxAlongDim = logits->reduceAlongDimension(reduce::Max, dimension, true); + auto logitsExp = (*logits - maxAlongDim).transform(transform::Exp, nullptr); + auto logSoftMax = -( + (logitsExp / logitsExp.reduceAlongDimension(reduce::Sum, dimension, true)) + .transform(transform::Log)); + + helpers::scatterForLoss(block.launchContext(), *labels, logSoftMax, *output, + false); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(sparse_softmax_cross_entropy_loss_with_logits) { - - getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS})->setAllowedInputTypes(1, {ALL_FLOATS})->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } - ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(sparse_softmax_cross_entropy_loss_with_logits) { - - auto labelsShapeInfo = inputShape->at(0); - auto logitsShapeInfo = inputShape->at(1); - - REQUIRE_TRUE(labelsShapeInfo[0] == logitsShapeInfo[0] - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsShapeInfo[0], logitsShapeInfo[0]); - - bool equalSoft = true; - for (int i = 1; i < labelsShapeInfo[0]; ++i) - if (labelsShapeInfo[i] != logitsShapeInfo[i]) { - equalSoft = false; - break; - } - - REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - - auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, logitsShapeInfo, false, block.workspace()); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto labelsShapeInfo = inputShape->at(0); + auto logitsShapeInfo = inputShape->at(1); + + REQUIRE_TRUE(labelsShapeInfo[0] == logitsShapeInfo[0] - 1, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: input arrays " + "should satisfy relation (labels_rank = logits_rank - 1), but " + "got labels_rank = %i and logits_rank = %i instead !", + labelsShapeInfo[0], logitsShapeInfo[0]); + + bool equalSoft = true; + for (int i = 1; i < labelsShapeInfo[0]; ++i) + if (labelsShapeInfo[i] != logitsShapeInfo[i]) { + equalSoft = false; + break; + } + + REQUIRE_TRUE( + equalSoft, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels " + "array, its shape should be the same as logits shape with last dimension " + "excluded, however got labels_shape = %s and logits_shape = %s instead !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + + auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType( + labelsShapeInfo, logitsShapeInfo, false, block.workspace()); + + return SHAPELIST(CONSTANT(outShapeInfo)); } - - - - - - ////////////////////////////////////////////////////////////////////////// -CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, 0, 0) { - - auto labels = INPUT_VARIABLE(0); - auto logits = INPUT_VARIABLE(1); - - auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits - - const int labelsRank = labels->rankOf(); - const int logitsRank = logits->rankOf(); - - // input validation - REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsRank, logitsRank); - - std::vector labelsShape = labels->getShapeAsVector(); // this is correct - std::vector logitsShape = logits->getShapeAsVector(); - logitsShape.pop_back(); - bool equalSoft = logitsShape == labelsShape; - - REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShape).c_str(), ShapeUtils::shapeAsString(logitsShape).c_str()); - - std::vector dimension = {-1}; - - NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)).transform(transform::Exp); - softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true); - - // dEdp = softmax - 1 (or 0) - dLdp->assign(softmax); - - // subtract unities at appropriate indexes of dLdp array - helpers::scatterForLoss(block.launchContext(), *labels, *dLdp, *labels /*actually third array is unnecessary for gradient calculation*/, true); - - return Status::OK(); +CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, + 0, 0) { + auto labels = INPUT_VARIABLE(0); + auto logits = INPUT_VARIABLE(1); + + auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits + + const int labelsRank = labels->rankOf(); + const int logitsRank = logits->rankOf(); + + // input validation + REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: input " + "arrays should satisfy relation (labels_rank = logits_rank - " + "1), but got labels_rank = %i and logits_rank = %i instead !", + labelsRank, logitsRank); + + std::vector labelsShape = + labels->getShapeAsVector(); // this is correct + std::vector logitsShape = logits->getShapeAsVector(); + logitsShape.pop_back(); + bool equalSoft = logitsShape == labelsShape; + + REQUIRE_TRUE(equalSoft, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong " + "shape of labels array, its shape should be the same as logits " + "shape with last dimension excluded, however got labels_shape = " + "%s and logits_shape = %s instead !", + ShapeUtils::shapeAsString(labelsShape).c_str(), + ShapeUtils::shapeAsString(logitsShape).c_str()); + + std::vector dimension = {-1}; + + NDArray softmax = + (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)) + .transform(transform::Exp); + softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true); + + // dEdp = softmax - 1 (or 0) + dLdp->assign(softmax); + + // subtract unities at appropriate indexes of dLdp array + helpers::scatterForLoss( + block.launchContext(), *labels, *dLdp, + *labels /*actually third array is unnecessary for gradient calculation*/, + true); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(sparse_softmax_cross_entropy_loss_with_logits_grad) { - - getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS})->setAllowedInputTypes(1, {ALL_FLOATS})->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } - ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(sparse_softmax_cross_entropy_loss_with_logits_grad) { - - auto labelsShapeInfo = inputShape->at(0); - auto logitsShapeInfo = inputShape->at(1); - - REQUIRE_TRUE(labelsShapeInfo[0] == logitsShapeInfo[0] - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsShapeInfo[0], logitsShapeInfo[0]); - - bool equalSoft = true; - for (int i = 1; i < labelsShapeInfo[0]; ++i) - if (labelsShapeInfo[i] != logitsShapeInfo[i]) { - equalSoft = false; - break; - } - - REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); - - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); - - Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.workspace()); - - return SHAPELIST(CONSTANT(dLdpShapeInfo)); + auto labelsShapeInfo = inputShape->at(0); + auto logitsShapeInfo = inputShape->at(1); + + REQUIRE_TRUE(labelsShapeInfo[0] == logitsShapeInfo[0] - 1, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: input " + "arrays should satisfy relation (labels_rank = logits_rank - " + "1), but got labels_rank = %i and logits_rank = %i instead !", + labelsShapeInfo[0], logitsShapeInfo[0]); + + bool equalSoft = true; + for (int i = 1; i < labelsShapeInfo[0]; ++i) + if (labelsShapeInfo[i] != logitsShapeInfo[i]) { + equalSoft = false; + break; + } + + REQUIRE_TRUE(equalSoft, 0, + "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong " + "shape of labels array, its shape should be the same as logits " + "shape with last dimension excluded, however got labels_shape = " + "%s and logits_shape = %s instead !", + ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), + ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); + + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); + + Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType( + logitsShapeInfo, outType, false, block.workspace()); + + return SHAPELIST(CONSTANT(dLdpShapeInfo)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nlp/cbow.cpp b/libnd4j/include/ops/declarable/generic/nlp/cbow.cpp index 9b5ed1918875..69ba6bb893b9 100644 --- a/libnd4j/include/ops/declarable/generic/nlp/cbow.cpp +++ b/libnd4j/include/ops/declarable/generic/nlp/cbow.cpp @@ -25,69 +25,74 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(cbow, 15, 15, true, 0, 0) { - auto target = INPUT_VARIABLE(0); - auto ngStarter = INPUT_VARIABLE(1); +namespace ops { +CONFIGURABLE_OP_IMPL(cbow, 15, 15, true, 0, 0) { + auto target = INPUT_VARIABLE(0); + auto ngStarter = INPUT_VARIABLE(1); - // required part - auto context = INPUT_VARIABLE(2); - auto indices = INPUT_VARIABLE(3); - auto codes = INPUT_VARIABLE(4); + // required part + auto context = INPUT_VARIABLE(2); + auto indices = INPUT_VARIABLE(3); + auto codes = INPUT_VARIABLE(4); - auto syn0 = INPUT_VARIABLE(5); - auto syn1 = INPUT_VARIABLE(6); - auto syn1neg = INPUT_VARIABLE(7); + auto syn0 = INPUT_VARIABLE(5); + auto syn1 = INPUT_VARIABLE(6); + auto syn1neg = INPUT_VARIABLE(7); - auto expTable = INPUT_VARIABLE(8); - auto negTable = INPUT_VARIABLE(9); + auto expTable = INPUT_VARIABLE(8); + auto negTable = INPUT_VARIABLE(9); - auto alpha = INPUT_VARIABLE(10); - auto randomValue = INPUT_VARIABLE(11); - auto numLabels = INPUT_VARIABLE(12); + auto alpha = INPUT_VARIABLE(10); + auto randomValue = INPUT_VARIABLE(11); + auto numLabels = INPUT_VARIABLE(12); - auto lockedWords = INPUT_VARIABLE(13); + auto lockedWords = INPUT_VARIABLE(13); - auto inferenceVector = INPUT_VARIABLE(14); + auto inferenceVector = INPUT_VARIABLE(14); - auto numWorkers = block.numI() > 0 ? INT_ARG(0) : omp_get_max_threads(); - auto nsRounds = block.numI() > 1 ? INT_ARG(1) : 0; + auto numWorkers = block.numI() > 0 ? INT_ARG(0) : omp_get_max_threads(); + auto nsRounds = block.numI() > 1 ? INT_ARG(1) : 0; - auto trainWords = block.numB() > 0 ? B_ARG(0) : true; - auto isInference = block.numB() > 1 ? B_ARG(1) : false; + auto trainWords = block.numB() > 0 ? B_ARG(0) : true; + auto isInference = block.numB() > 1 ? B_ARG(1) : false; - REQUIRE_TRUE(block.isInplace(), 0, "CBOW: this operation requires inplace execution only"); + REQUIRE_TRUE(block.isInplace(), 0, + "CBOW: this operation requires inplace execution only"); - REQUIRE_TRUE(syn0->dataType() == syn1->dataType() && syn0->dataType() == syn1neg->dataType(), 0, "CBOW: all syn tables must have the same data type"); - REQUIRE_TRUE(syn0->dataType() == expTable->dataType(), 0, "CBOW: expTable must have the same data type as syn0 table"); + REQUIRE_TRUE(syn0->dataType() == syn1->dataType() && + syn0->dataType() == syn1neg->dataType(), + 0, "CBOW: all syn tables must have the same data type"); + REQUIRE_TRUE(syn0->dataType() == expTable->dataType(), 0, + "CBOW: expTable must have the same data type as syn0 table"); + sd::ops::helpers::cbow(*syn0, *syn1, *syn1neg, *expTable, *negTable, *target, + *ngStarter, nsRounds, *context, *lockedWords, *indices, + *codes, *alpha, *randomValue, *numLabels, + *inferenceVector, trainWords, numWorkers); - sd::ops::helpers::cbow(*syn0, *syn1, *syn1neg, *expTable, *negTable, *target, *ngStarter, nsRounds, *context, *lockedWords, *indices, *codes, *alpha, *randomValue, *numLabels, *inferenceVector, trainWords, numWorkers); - - - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(cbow) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::INT32) - ->setAllowedInputTypes(1, sd::DataType::INT32) - ->setAllowedInputTypes(2, sd::DataType::INT32) - ->setAllowedInputTypes(3, sd::DataType::INT32) - ->setAllowedInputTypes(4, sd::DataType::INT8) - ->setAllowedInputTypes(5, {ALL_FLOATS}) - ->setAllowedInputTypes(6, {ALL_FLOATS}) - ->setAllowedInputTypes(7, {ALL_FLOATS}) - ->setAllowedInputTypes(8, {ALL_FLOATS}) - ->setAllowedInputTypes(9, {ALL_FLOATS}) - ->setAllowedInputTypes(10, {ALL_FLOATS}) - ->setAllowedInputTypes(11, sd::DataType::INT64) - ->setAllowedInputTypes(12, sd::DataType::INT32) - ->setAllowedInputTypes(13, sd::DataType::INT32) - ->setAllowedInputTypes(14, {ALL_FLOATS}) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - } +DECLARE_TYPES(cbow) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::INT32) + ->setAllowedInputTypes(1, sd::DataType::INT32) + ->setAllowedInputTypes(2, sd::DataType::INT32) + ->setAllowedInputTypes(3, sd::DataType::INT32) + ->setAllowedInputTypes(4, sd::DataType::INT8) + ->setAllowedInputTypes(5, {ALL_FLOATS}) + ->setAllowedInputTypes(6, {ALL_FLOATS}) + ->setAllowedInputTypes(7, {ALL_FLOATS}) + ->setAllowedInputTypes(8, {ALL_FLOATS}) + ->setAllowedInputTypes(9, {ALL_FLOATS}) + ->setAllowedInputTypes(10, {ALL_FLOATS}) + ->setAllowedInputTypes(11, sd::DataType::INT64) + ->setAllowedInputTypes(12, sd::DataType::INT32) + ->setAllowedInputTypes(13, sd::DataType::INT32) + ->setAllowedInputTypes(14, {ALL_FLOATS}) + ->setAllowedOutputTypes(sd::DataType::ANY); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp b/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp index c96ddb503df6..d9212086284a 100644 --- a/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp +++ b/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp @@ -25,69 +25,76 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(skipgram, 12, 12, true, 0, 0) { - auto target = INPUT_VARIABLE(0); - auto ngStarter = INPUT_VARIABLE(1); +namespace ops { +CONFIGURABLE_OP_IMPL(skipgram, 12, 12, true, 0, 0) { + auto target = INPUT_VARIABLE(0); + auto ngStarter = INPUT_VARIABLE(1); - // required part - auto indices = INPUT_VARIABLE(2); - auto codes = INPUT_VARIABLE(3); + // required part + auto indices = INPUT_VARIABLE(2); + auto codes = INPUT_VARIABLE(3); - auto syn0 = INPUT_VARIABLE(4); - auto syn1 = INPUT_VARIABLE(5); - auto syn1neg = INPUT_VARIABLE(6); + auto syn0 = INPUT_VARIABLE(4); + auto syn1 = INPUT_VARIABLE(5); + auto syn1neg = INPUT_VARIABLE(6); - auto expTable = INPUT_VARIABLE(7); - auto negTable = INPUT_VARIABLE(8); + auto expTable = INPUT_VARIABLE(7); + auto negTable = INPUT_VARIABLE(8); - auto alpha = INPUT_VARIABLE(9); - auto randomValue = INPUT_VARIABLE(10); + auto alpha = INPUT_VARIABLE(9); + auto randomValue = INPUT_VARIABLE(10); - auto inferenceVector = INPUT_VARIABLE(11); + auto inferenceVector = INPUT_VARIABLE(11); - //auto neu1e = INPUT_VARIABLE(12); + // auto neu1e = INPUT_VARIABLE(12); - auto numWorkers = block.numI() > 0 ? INT_ARG(0) : omp_get_max_threads(); - auto nsRounds = block.numI() > 1 ? INT_ARG(1) : 0; + auto numWorkers = block.numI() > 0 ? INT_ARG(0) : omp_get_max_threads(); + auto nsRounds = block.numI() > 1 ? INT_ARG(1) : 0; - auto isInference = block.numB() > 0 ? B_ARG(0) : false; - auto isPreciseMode = block.numB() > 1 ? B_ARG(1) : false; + auto isInference = block.numB() > 0 ? B_ARG(0) : false; + auto isPreciseMode = block.numB() > 1 ? B_ARG(1) : false; - REQUIRE_TRUE(block.isInplace(), 0, "SkipGram: this operation requires inplace execution only"); + REQUIRE_TRUE(block.isInplace(), 0, + "SkipGram: this operation requires inplace execution only"); - REQUIRE_TRUE(syn0->dataType() == syn1->dataType() && syn0->dataType() == syn1neg->dataType(), 0, "SkipGram: all syn tables must have the same data type"); - REQUIRE_TRUE(syn0->dataType() == expTable->dataType(), 0, "SkipGram: expTable must have the same data type as syn0 table"); + REQUIRE_TRUE(syn0->dataType() == syn1->dataType() && + syn0->dataType() == syn1neg->dataType(), + 0, "SkipGram: all syn tables must have the same data type"); + REQUIRE_TRUE(syn0->dataType() == expTable->dataType(), 0, + "SkipGram: expTable must have the same data type as syn0 table"); + sd::ops::helpers::skipgram(*syn0, *syn1, *syn1neg, *expTable, *negTable, + *target, *ngStarter, nsRounds, *indices, *codes, + *alpha, *randomValue, *inferenceVector, + isPreciseMode, numWorkers); - sd::ops::helpers::skipgram(*syn0, *syn1, *syn1neg, *expTable, *negTable, *target, *ngStarter, nsRounds, *indices, *codes, *alpha, *randomValue, *inferenceVector, isPreciseMode, numWorkers); - - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(skipgram) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::INT32) - ->setAllowedInputTypes(1, sd::DataType::INT32) - ->setAllowedInputTypes(2, sd::DataType::INT32) - ->setAllowedInputTypes(3, sd::DataType::INT8) - ->setAllowedInputTypes(4, {ALL_FLOATS}) - ->setAllowedInputTypes(5, {ALL_FLOATS}) - ->setAllowedInputTypes(6, {ALL_FLOATS}) - ->setAllowedInputTypes(7, {ALL_FLOATS}) - ->setAllowedInputTypes(8, {ALL_FLOATS}) - ->setAllowedInputTypes(9, {ALL_FLOATS}) - ->setAllowedInputTypes(10, sd::DataType::INT64) - ->setAllowedInputTypes(11, {ALL_FLOATS}) - ->setAllowedOutputTypes(sd::DataType::ANY); - } +DECLARE_TYPES(skipgram) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::INT32) + ->setAllowedInputTypes(1, sd::DataType::INT32) + ->setAllowedInputTypes(2, sd::DataType::INT32) + ->setAllowedInputTypes(3, sd::DataType::INT8) + ->setAllowedInputTypes(4, {ALL_FLOATS}) + ->setAllowedInputTypes(5, {ALL_FLOATS}) + ->setAllowedInputTypes(6, {ALL_FLOATS}) + ->setAllowedInputTypes(7, {ALL_FLOATS}) + ->setAllowedInputTypes(8, {ALL_FLOATS}) + ->setAllowedInputTypes(9, {ALL_FLOATS}) + ->setAllowedInputTypes(10, sd::DataType::INT64) + ->setAllowedInputTypes(11, {ALL_FLOATS}) + ->setAllowedOutputTypes(sd::DataType::ANY); +} - /* - DECLARE_SHAPE_FN(skipgram) { - return SHAPELIST(ShapeBuilders::createScalarShapeInfo(DataType::INT8, block.workspace())); - } - */ - } +/* +DECLARE_SHAPE_FN(skipgram) { + return SHAPELIST(ShapeBuilders::createScalarShapeInfo(DataType::INT8, +block.workspace())); } +*/ +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/crelu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/crelu.cpp index 8ef400ed4482..809ae122c753 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/crelu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/crelu.cpp @@ -22,93 +22,93 @@ #if NOT_EXCLUDED(OP_crelu) #include -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(crelu, 1, 1, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - - REQUIRE_TRUE(x->isR(), 0, "CRELU: input must be real type"); - - auto tmp = x->dup(); - tmp.applyTransform(sd::transform::Neg, tmp); - - auto z = OUTPUT_VARIABLE(0); - - helpers::concat(block.launchContext(), {x, &tmp}, *z, x->rankOf()-1); - // NDArrayFactory::concat({x, tmp}, -1, z); - - // TODO: make this configurable? - double threshold = 0.0; - z->applyScalar(sd::scalar::RELU, threshold, *z); - - STORE_RESULT(z); - - return Status::OK(); - } - - DECLARE_TYPES(crelu) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setSameMode(true); - } - - DECLARE_SHAPE_FN(crelu) { - auto inShape = inputShape->at(0); - std::vector shape; - for (int e = 0; e < shape::rank(inShape); e++) - shape.emplace_back(shape::shapeOf(inShape)[e]); - - shape[shape.size()-1] *= 2; - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), shape); - - return SHAPELIST(newShape); - } - - CUSTOM_OP_IMPL(crelu_bp, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilonNext = INPUT_VARIABLE(1); - auto epsilon = OUTPUT_VARIABLE(0); - - // at first step we build fwd activation - sd::ops::crelu op; - auto tmpResult = op.evaluate({input}); - if (tmpResult.status() != ND4J_STATUS_OK) - return tmpResult.status(); - - auto actv = tmpResult.at(0); - - // now we do RELU backward pass - //actv->applyPairwiseTransform(pairwise::RELUDerivativeE, *epsilon, nullptr); - helpers::reluDerivative(block.launchContext(), &actv, epsilonNext); - // now we split updated array into 2 chunks along last dimension - sd::ops::concat_bp opc; - auto dec = opc.evaluate({input, input, &actv}, {-1}); - if (dec.status() != ND4J_STATUS_OK) - return dec.status(); - - // and now we subtract two parts of epsilons and pass result out - auto pos = dec.at(0); - auto neg = dec.at(1); - - pos.applyPairwiseTransform(sd::pairwise::Subtract, neg, *epsilon); - - return ND4J_STATUS_OK; - } - - DECLARE_TYPES(crelu_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - - DECLARE_SHAPE_FN(crelu_bp) { - auto inShape = inputShape->at(0); - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape))); - } - } +namespace ops { +CUSTOM_OP_IMPL(crelu, 1, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + + REQUIRE_TRUE(x->isR(), 0, "CRELU: input must be real type"); + + auto tmp = x->dup(); + tmp.applyTransform(sd::transform::Neg, tmp); + + auto z = OUTPUT_VARIABLE(0); + + helpers::concat(block.launchContext(), {x, &tmp}, *z, x->rankOf() - 1); + // NDArrayFactory::concat({x, tmp}, -1, z); + + // TODO: make this configurable? + double threshold = 0.0; + z->applyScalar(sd::scalar::RELU, threshold, *z); + + STORE_RESULT(z); + + return Status::OK(); +} + +DECLARE_TYPES(crelu) { + getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY)->setSameMode(true); +} + +DECLARE_SHAPE_FN(crelu) { + auto inShape = inputShape->at(0); + std::vector shape; + for (int e = 0; e < shape::rank(inShape); e++) + shape.emplace_back(shape::shapeOf(inShape)[e]); + + shape[shape.size() - 1] *= 2; + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), shape); + + return SHAPELIST(newShape); +} + +CUSTOM_OP_IMPL(crelu_bp, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilonNext = INPUT_VARIABLE(1); + auto epsilon = OUTPUT_VARIABLE(0); + + // at first step we build fwd activation + sd::ops::crelu op; + auto tmpResult = op.evaluate({input}); + if (tmpResult.status() != ND4J_STATUS_OK) return tmpResult.status(); + + auto actv = tmpResult.at(0); + + // now we do RELU backward pass + // actv->applyPairwiseTransform(pairwise::RELUDerivativeE, *epsilon, nullptr); + helpers::reluDerivative(block.launchContext(), &actv, epsilonNext); + // now we split updated array into 2 chunks along last dimension + sd::ops::concat_bp opc; + auto dec = opc.evaluate({input, input, &actv}, {-1}); + if (dec.status() != ND4J_STATUS_OK) return dec.status(); + + // and now we subtract two parts of epsilons and pass result out + auto pos = dec.at(0); + auto neg = dec.at(1); + + pos.applyPairwiseTransform(sd::pairwise::Subtract, neg, *epsilon); + + return ND4J_STATUS_OK; +} + +DECLARE_TYPES(crelu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); +} + +DECLARE_SHAPE_FN(crelu_bp) { + auto inShape = inputShape->at(0); + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(inShape))); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/cube.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/cube.cpp index d71906d2908f..a8446b3419ff 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/cube.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/cube.cpp @@ -18,7 +18,6 @@ // @author raver119@gmail.com // - #include #if NOT_EXCLUDED(OP_cube) @@ -26,41 +25,42 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(cube, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(cube, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - input->applyTransform(sd::transform::Cube, *output); - STORE_RESULT(output); + input->applyTransform(sd::transform::Cube, *output); + STORE_RESULT(output); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(cube) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setSameMode(true); - } +DECLARE_TYPES(cube) { + getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY)->setSameMode(true); +} - CONFIGURABLE_OP_IMPL(cube_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(cube_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::CUBEDerivativeE, epsilon, z, nullptr); - helpers::cubeDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::CUBEDerivativeE, epsilon, z, + // nullptr); + helpers::cubeDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} - DECLARE_TYPES(cube_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(cube_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/elu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/elu.cpp index f89f0d2c7c17..accd6f2d2559 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/elu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/elu.cpp @@ -24,47 +24,47 @@ #include #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(elu, 1, 1, true, -2, 0) { +namespace ops { +CONFIGURABLE_OP_IMPL(elu, 1, 1, true, -2, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f; - const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f; + input->applyScalar(sd::scalar::ELU, alpha, *output); - input->applyScalar(sd::scalar::ELU, alpha, *output); - - return Status::OK(); - } - - DECLARE_TYPES(elu) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } + return Status::OK(); +} - CONFIGURABLE_OP_IMPL(elu_bp, 2, 1, true, -2, 0) { +DECLARE_TYPES(elu) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(elu_bp, 2, 1, true, -2, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f; + const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f; - // input->applyPairwiseTransform(pairwise::ELUDerivativeE, epsilon, output); - helpers::eluDerivative(block.launchContext(), input, epsilon, output, alpha); + // input->applyPairwiseTransform(pairwise::ELUDerivativeE, epsilon, output); + helpers::eluDerivative(block.launchContext(), input, epsilon, output, alpha); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(elu_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(elu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/hardsigmoid.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/hardsigmoid.cpp index ba498fea98db..bb7261dff267 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/hardsigmoid.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/hardsigmoid.cpp @@ -25,41 +25,44 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(hardsigmoid, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(hardsigmoid, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - input->applyTransform(sd::transform::HardSigmoid, *output); - STORE_RESULT(output); + input->applyTransform(sd::transform::HardSigmoid, *output); + STORE_RESULT(output); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(hardsigmoid) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(hardsigmoid) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(hardsigmoid_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(hardsigmoid_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::HardSigmoidDerivativeE, epsilon, z, nullptr); - helpers::hardSigmoidDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::HardSigmoidDerivativeE, epsilon, z, + // nullptr); + helpers::hardSigmoidDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} - DECLARE_TYPES(hardsigmoid_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(hardsigmoid_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/hardtanh.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/hardtanh.cpp index 0a245e6a086e..66fadedeab3c 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/hardtanh.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/hardtanh.cpp @@ -25,41 +25,44 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(hardtanh, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(hardtanh, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - input->applyTransform(sd::transform::HardTanh, *output); - STORE_RESULT(output); + input->applyTransform(sd::transform::HardTanh, *output); + STORE_RESULT(output); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(hardtanh) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(hardtanh) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(hardtanh_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(hardtanh_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::HardTanhDerivativeE, epsilon, z, nullptr); - helpers::hardTanhDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::HardTanhDerivativeE, epsilon, z, + // nullptr); + helpers::hardTanhDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} - DECLARE_TYPES(hardtanh_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(hardtanh_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/identity.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/identity.cpp index 38e4a3ae88d6..799a895b34b6 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/identity.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/identity.cpp @@ -24,47 +24,43 @@ #include namespace sd { - namespace ops { - OP_IMPL(identity, 1, 1, true) { - auto z = OUTPUT_VARIABLE(0); +namespace ops { +OP_IMPL(identity, 1, 1, true) { + auto z = OUTPUT_VARIABLE(0); - if (!block.isInplace()) { - auto first = INPUT_VARIABLE(0); + if (!block.isInplace()) { + auto first = INPUT_VARIABLE(0); - // we hope for memcpy here - z->assign(first); - } - - return Status::OK(); - } - DECLARE_SYN(linear, identity); - - DECLARE_TYPES(identity) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setSameMode(true); - } + // we hope for memcpy here + z->assign(first); + } + return Status::OK(); +} +DECLARE_SYN(linear, identity); - OP_IMPL(identity_bp, 2, 1, true) { - auto first = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +DECLARE_TYPES(identity) { + getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY)->setSameMode(true); +} - z->assign(epsilon); +OP_IMPL(identity_bp, 2, 1, true) { + auto first = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - return Status::OK(); - } - DECLARE_SYN(LinearGrad, identity_bp); + z->assign(epsilon); + return Status::OK(); +} +DECLARE_SYN(LinearGrad, identity_bp); - DECLARE_TYPES(identity_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } - } +DECLARE_TYPES(identity_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/identity_n.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/identity_n.cpp index 38b1ade26ab3..b96628165ce4 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/identity_n.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/identity_n.cpp @@ -24,39 +24,38 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(identity_n, 1, 1, true, 0, 0) { - - // just for lulz - if (!block.isInplace()) { - for (Nd4jLong i = 0; i < block.width(); ++i) { - auto x = INPUT_VARIABLE(i); - auto z = OUTPUT_VARIABLE(i); - - x->applyTransform(transform::Identity, *z); - } - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(identity_n) { - auto shapes = SHAPELIST(); - for (size_t i = 0; i < inputShape->size(); ++i) { - Nd4jLong* shape; - COPY_SHAPE_EX(inputShape->at(i), shape, block.workspace()); - shapes->push_back(CONSTANT(shape)); - } - return shapes; - } - - DECLARE_TYPES(identity_n) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - +namespace ops { +CUSTOM_OP_IMPL(identity_n, 1, 1, true, 0, 0) { + // just for lulz + if (!block.isInplace()) { + for (Nd4jLong i = 0; i < block.width(); ++i) { + auto x = INPUT_VARIABLE(i); + auto z = OUTPUT_VARIABLE(i); + + x->applyTransform(transform::Identity, *z); } + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(identity_n) { + auto shapes = SHAPELIST(); + for (size_t i = 0; i < inputShape->size(); ++i) { + Nd4jLong* shape; + COPY_SHAPE_EX(inputShape->at(i), shape, block.workspace()); + shapes->push_back(CONSTANT(shape)); + } + return shapes; } +DECLARE_TYPES(identity_n) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/lrelu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/lrelu.cpp index 2f4c2dc041a3..df6d11dcd314 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/lrelu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/lrelu.cpp @@ -24,45 +24,48 @@ #include #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(lrelu, 1, 1, true, -2, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(lrelu, 1, 1, true, -2, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f; + float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f; - input->applyScalar(sd::scalar::LeakyRELU, alpha, *output); - STORE_RESULT(output); + input->applyScalar(sd::scalar::LeakyRELU, alpha, *output); + STORE_RESULT(output); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(lrelu) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(lrelu) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(lrelu_bp, 2, 1, true, -2, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(lrelu_bp, 2, 1, true, -2, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f; + float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f; - //input->applyPairwiseTransform(pairwise::LRELUDerivativeE, epsilon, z, nullptr); - helpers::leakyReluDerivative(block.launchContext(), input, epsilon, z, alpha); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::LRELUDerivativeE, epsilon, z, + // nullptr); + helpers::leakyReluDerivative(block.launchContext(), input, epsilon, z, alpha); + return Status::OK(); +} - DECLARE_TYPES(lrelu_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(lrelu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/prelu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/prelu.cpp index c382939ccba8..b4d29921527b 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/prelu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/prelu.cpp @@ -18,135 +18,163 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 24.07.2018 // - #include #if NOT_EXCLUDED(OP_prelu) #include #include + #include namespace sd { -namespace ops { - +namespace ops { //////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(prelu, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto alpha = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - std::vector sharedAxes = block.getIArguments(); - - const int inputRank = input->rankOf(); - const int numSharedAxes = sharedAxes.size(); // can be zero as well - const Nd4jLong inputLen = input->lengthOf(); - const Nd4jLong alphaLen = alpha->lengthOf(); - const std::vector inputShape = input->getShapeAsVector(); - const std::vector alphaShape = alpha->getShapeAsVector(); - - //***** input validation *****// - std::vector expectedAlphaShape(&inputShape[1], &inputShape[inputRank]); - - REQUIRE_TRUE(inputRank > 1, 0, "PRELU OP: wrong rank of input array, expected rank should be > 1, but got %i instead !", inputRank); - - for(int i = 0; i < numSharedAxes; ++i) { - if(sharedAxes[i] <= 0) - sharedAxes[i] += inputRank - 1; - REQUIRE_TRUE(1 <= sharedAxes[i] && sharedAxes[i] <= inputRank - 1, 0, "PRELU OP: wrong axis value %i in sharedAxes at position %i, axis value must be within range [1, input_rank-1] !", sharedAxes[i], i); - expectedAlphaShape[sharedAxes[i] - 1] = 1; - } - - Nd4jLong product = 1; - for(const auto& item : expectedAlphaShape) - product *= item; - - REQUIRE_TRUE(product == alphaLen, 0, "PRELU OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str()); - // ***** end of validation ***** // - - helpers::prelu(block.launchContext(), *input, alphaShape != expectedAlphaShape ? alpha->reshape(alpha->ordering(), expectedAlphaShape) : *alpha, *output); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto alpha = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + std::vector sharedAxes = block.getIArguments(); + + const int inputRank = input->rankOf(); + const int numSharedAxes = sharedAxes.size(); // can be zero as well + const Nd4jLong inputLen = input->lengthOf(); + const Nd4jLong alphaLen = alpha->lengthOf(); + const std::vector inputShape = input->getShapeAsVector(); + const std::vector alphaShape = alpha->getShapeAsVector(); + + //***** input validation *****// + std::vector expectedAlphaShape(&inputShape[1], + &inputShape[inputRank]); + + REQUIRE_TRUE(inputRank > 1, 0, + "PRELU OP: wrong rank of input array, expected rank should be > " + "1, but got %i instead !", + inputRank); + + for (int i = 0; i < numSharedAxes; ++i) { + if (sharedAxes[i] <= 0) sharedAxes[i] += inputRank - 1; + REQUIRE_TRUE(1 <= sharedAxes[i] && sharedAxes[i] <= inputRank - 1, 0, + "PRELU OP: wrong axis value %i in sharedAxes at position %i, " + "axis value must be within range [1, input_rank-1] !", + sharedAxes[i], i); + expectedAlphaShape[sharedAxes[i] - 1] = 1; + } + + Nd4jLong product = 1; + for (const auto& item : expectedAlphaShape) product *= item; + + REQUIRE_TRUE(product == alphaLen, 0, + "PRELU OP: wrong shape of alpha array, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), + ShapeUtils::shapeAsString(alphaShape).c_str()); + // ***** end of validation ***** // + + helpers::prelu(block.launchContext(), *input, + alphaShape != expectedAlphaShape + ? alpha->reshape(alpha->ordering(), expectedAlphaShape) + : *alpha, + *output); + + return Status::OK(); } - - DECLARE_TYPES(prelu) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } - +DECLARE_TYPES(prelu) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} //////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(prelu_bp, 3, 2, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto alpha = INPUT_VARIABLE(1); - auto dLdO = INPUT_VARIABLE(2); - - auto dLdI = OUTPUT_VARIABLE(0); - auto dLdA = OUTPUT_VARIABLE(1); - - std::vector sharedAxes = block.getIArguments(); - - const int inputRank = input->rankOf(); - const int numSharedAxes = sharedAxes.size(); // can be zero as well - const Nd4jLong inputLen = input->lengthOf(); - const Nd4jLong alphaLen = alpha->lengthOf(); - const std::vector inputShape = input->getShapeAsVector(); - const std::vector alphaShape = alpha->getShapeAsVector(); - - //***** input validation *****// - - // temporary limitation imposed by Yurii - REQUIRE_TRUE(inputRank <= MAX_RANK/2, 0, "rank of input array should be <= MAX_RANK/2, but got %i instead!", inputRank); - REQUIRE_TRUE(input->lengthOf() / alpha->lengthOf() <= MAX_RANK*2, 0, "the length of input array should be no more than MAX_RANK*2 times the alpha array length, but got %lld and %lld correspondingly!", input->lengthOf(), alpha->lengthOf()); - - std::vector expectedAlphaShape(&inputShape[1], &inputShape[inputRank]); - - REQUIRE_TRUE(inputRank > 1, 0, "PRELU_BP OP: wrong rank of input array, expected rank should be > 1, but got %i instead !", inputRank); - - for(int i = 0; i < numSharedAxes; ++i) { - if(sharedAxes[i] <= 0) - sharedAxes[i] += inputRank - 1; - REQUIRE_TRUE(1 <= sharedAxes[i] && sharedAxes[i] <= inputRank - 1, 0, "PRELU_BP OP: wrong axis value %i in sharedAxes at position %i, axis value must be within range [1, input_rank-1] !", sharedAxes[i], i); - expectedAlphaShape[sharedAxes[i] - 1] = 1; - } - - Nd4jLong product = 1; - for(const auto& item : expectedAlphaShape) - product *= item; - - REQUIRE_TRUE(product == alphaLen, 0, "PRELU_BP OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str()); - // ***** end of validation ***** // - - - if(alphaShape != expectedAlphaShape) { - alpha = new NDArray(alpha->reshape(alpha->ordering(), expectedAlphaShape)); - dLdA = new NDArray(dLdA->reshape(dLdA->ordering(), expectedAlphaShape)); - } - - helpers::preluBP(block.launchContext(), *input, *alpha, *dLdO, *dLdI, *dLdA); - - if(alphaShape != expectedAlphaShape) { - delete alpha; - delete dLdA; - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto alpha = INPUT_VARIABLE(1); + auto dLdO = INPUT_VARIABLE(2); + + auto dLdI = OUTPUT_VARIABLE(0); + auto dLdA = OUTPUT_VARIABLE(1); + + std::vector sharedAxes = block.getIArguments(); + + const int inputRank = input->rankOf(); + const int numSharedAxes = sharedAxes.size(); // can be zero as well + const Nd4jLong inputLen = input->lengthOf(); + const Nd4jLong alphaLen = alpha->lengthOf(); + const std::vector inputShape = input->getShapeAsVector(); + const std::vector alphaShape = alpha->getShapeAsVector(); + + //***** input validation *****// + + // temporary limitation imposed by Yurii + REQUIRE_TRUE( + inputRank <= MAX_RANK / 2, 0, + "rank of input array should be <= MAX_RANK/2, but got %i instead!", + inputRank); + REQUIRE_TRUE( + input->lengthOf() / alpha->lengthOf() <= MAX_RANK * 2, 0, + "the length of input array should be no more than MAX_RANK*2 times the " + "alpha array length, but got %lld and %lld correspondingly!", + input->lengthOf(), alpha->lengthOf()); + + std::vector expectedAlphaShape(&inputShape[1], + &inputShape[inputRank]); + + REQUIRE_TRUE(inputRank > 1, 0, + "PRELU_BP OP: wrong rank of input array, expected rank should " + "be > 1, but got %i instead !", + inputRank); + + for (int i = 0; i < numSharedAxes; ++i) { + if (sharedAxes[i] <= 0) sharedAxes[i] += inputRank - 1; + REQUIRE_TRUE(1 <= sharedAxes[i] && sharedAxes[i] <= inputRank - 1, 0, + "PRELU_BP OP: wrong axis value %i in sharedAxes at position " + "%i, axis value must be within range [1, input_rank-1] !", + sharedAxes[i], i); + expectedAlphaShape[sharedAxes[i] - 1] = 1; + } + + Nd4jLong product = 1; + for (const auto& item : expectedAlphaShape) product *= item; + + REQUIRE_TRUE(product == alphaLen, 0, + "PRELU_BP OP: wrong shape of alpha array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), + ShapeUtils::shapeAsString(alphaShape).c_str()); + // ***** end of validation ***** // + + if (alphaShape != expectedAlphaShape) { + alpha = new NDArray(alpha->reshape(alpha->ordering(), expectedAlphaShape)); + dLdA = new NDArray(dLdA->reshape(dLdA->ordering(), expectedAlphaShape)); + } + + helpers::preluBP(block.launchContext(), *input, *alpha, *dLdO, *dLdI, *dLdA); + + if (alphaShape != expectedAlphaShape) { + delete alpha; + delete dLdA; + } + + return Status::OK(); } - DECLARE_TYPES(prelu_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedInputTypes(2, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - - -} +DECLARE_TYPES(prelu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedInputTypes( + 2, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/rationaltanh.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/rationaltanh.cpp index 3386d1578289..f871bfde2622 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/rationaltanh.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/rationaltanh.cpp @@ -25,41 +25,44 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(rationaltanh, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(rationaltanh, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - input->applyTransform(sd::transform::RationalTanh, *output); - STORE_RESULT(output); + input->applyTransform(sd::transform::RationalTanh, *output); + STORE_RESULT(output); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(rationaltanh) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(rationaltanh) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(rationaltanh_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(rationaltanh_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::RationalTanhDerivativeE, epsilon, z, nullptr); - helpers::rationalTanhDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::RationalTanhDerivativeE, epsilon, + // z, nullptr); + helpers::rationalTanhDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} - DECLARE_TYPES(rationaltanh_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(rationaltanh_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/rectifiedtanh.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/rectifiedtanh.cpp index 641ee0d0e9d1..478012372dbe 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/rectifiedtanh.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/rectifiedtanh.cpp @@ -25,41 +25,44 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(rectifiedtanh, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(rectifiedtanh, 1, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - input->applyTransform(sd::transform::RectifiedTanh, *output); - STORE_RESULT(output); + input->applyTransform(sd::transform::RectifiedTanh, *output); + STORE_RESULT(output); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(rectifiedtanh) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(rectifiedtanh) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(rectifiedtanh_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(rectifiedtanh_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::RectifiedTanhDerivativeE, epsilon, z, nullptr); - helpers::rectifiedTanhDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::RectifiedTanhDerivativeE, epsilon, + // z, nullptr); + helpers::rectifiedTanhDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} - DECLARE_TYPES(rectifiedtanh_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(rectifiedtanh_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/relu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/relu.cpp index 91c599126aae..3818dd205233 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/relu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/relu.cpp @@ -25,45 +25,46 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(relu, 1, 1, true, 1, 0) { - auto first = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(relu, 1, 1, true, 1, 0) { + auto first = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - auto scalar = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; + auto scalar = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; - first->applyScalar(sd::scalar::RELU, scalar, *z); + first->applyScalar(sd::scalar::RELU, scalar, *z); - STORE_RESULT(*z); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(relu) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setSameMode(true); - } +DECLARE_TYPES(relu) { + getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY)->setSameMode(true); +} - CONFIGURABLE_OP_IMPL(relu_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(relu_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::RELUDerivativeE, epsilon, z, nullptr); - helpers::reluDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } - DECLARE_SYN(ReluGrad, relu_bp); + // input->applyPairwiseTransform(pairwise::RELUDerivativeE, epsilon, z, + // nullptr); + helpers::reluDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} +DECLARE_SYN(ReluGrad, relu_bp); - DECLARE_TYPES(relu_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(relu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/relu6.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/relu6.cpp index 129c09480ac6..65dd0b74b539 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/relu6.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/relu6.cpp @@ -25,46 +25,44 @@ #include namespace sd { -namespace ops { - +namespace ops { //////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(relu6, 1, 1, true, 1, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - input->applyScalar(sd::scalar::RELU6, T_ARG(0), *output); + input->applyScalar(sd::scalar::RELU6, T_ARG(0), *output); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(relu6) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setSameMode(true); - } +DECLARE_TYPES(relu6) { + getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY)->setSameMode(true); +} //////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(relu6_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - //input->applyPairwiseTransform(pairwise::RELU6DerivativeE, gradO, gradI, nullptr); - helpers::relu6Derivative(block.launchContext(), input, gradO, gradI); - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + + // input->applyPairwiseTransform(pairwise::RELU6DerivativeE, gradO, gradI, + // nullptr); + helpers::relu6Derivative(block.launchContext(), input, gradO, gradI); + return Status::OK(); } - DECLARE_TYPES(relu6_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - - - -} +DECLARE_TYPES(relu6_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/selu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/selu.cpp index 7fc6aa11ac63..ebbddb9a73be 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/selu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/selu.cpp @@ -25,42 +25,45 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(selu, 1, 1, true, 0, 0) { - auto first = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(selu, 1, 1, true, 0, 0) { + auto first = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - first->applyTransform(sd::transform::SELU, *z); + first->applyTransform(sd::transform::SELU, *z); - STORE_RESULT(*z); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(selu) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(selu) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(selu_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(selu_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::SELUDerivativeE, epsilon, z, nullptr); - helpers::seluDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::SELUDerivativeE, epsilon, z, + // nullptr); + helpers::seluDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} - DECLARE_TYPES(selu_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(selu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/sigmoid.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/sigmoid.cpp index 047d973e6f23..381fc7136a80 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/sigmoid.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/sigmoid.cpp @@ -24,42 +24,45 @@ #include #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(sigmoid, 1, 1, true, 0, 0) { - auto first = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(sigmoid, 1, 1, true, 0, 0) { + auto first = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - first->applyTransform(sd::transform::Sigmoid, *z); + first->applyTransform(sd::transform::Sigmoid, *z); - STORE_RESULT(*z); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(sigmoid) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(sigmoid) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(sigmoid_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(sigmoid_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::SigmoidDerivativeE, epsilon, z, nullptr); - helpers::sigmoidDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } + // input->applyPairwiseTransform(pairwise::SigmoidDerivativeE, epsilon, z, + // nullptr); + helpers::sigmoidDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} - DECLARE_TYPES(sigmoid_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(sigmoid_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/softplus.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/softplus.cpp index 5cd17e752e93..1340eacf4170 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/softplus.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/softplus.cpp @@ -25,43 +25,46 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(softplus, 1, 1, true, 0, 0) { - auto first = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(softplus, 1, 1, true, 0, 0) { + auto first = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - first->applyTransform(sd::transform::SoftPlus, *z); + first->applyTransform(sd::transform::SoftPlus, *z); - STORE_RESULT(*z); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(softplus) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(softplus) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(softplus_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(softplus_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::SoftplusDerivativeE, epsilon, z, nullptr); - helpers::softPlusDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } - DECLARE_SYN(SoftplusGrad, softplus_bp); + // input->applyPairwiseTransform(pairwise::SoftplusDerivativeE, epsilon, z, + // nullptr); + helpers::softPlusDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} +DECLARE_SYN(SoftplusGrad, softplus_bp); - DECLARE_TYPES(softplus_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(softplus_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/softsign.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/softsign.cpp index c7fb15fddf29..9b07c512beae 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/softsign.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/softsign.cpp @@ -25,44 +25,47 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(softsign, 1, 1, true, 0, 0) { - auto first = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(softsign, 1, 1, true, 0, 0) { + auto first = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - first->applyTransform(sd::transform::SoftSign, *z); + first->applyTransform(sd::transform::SoftSign, *z); - STORE_RESULT(*z); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(softsign) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(softsign) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(softsign_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(softsign_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::SoftsignDerivativeE, epsilon, z, nullptr); - helpers::softSignDerivative(block.launchContext(), input, epsilon, z); + // input->applyPairwiseTransform(pairwise::SoftsignDerivativeE, epsilon, z, + // nullptr); + helpers::softSignDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } - DECLARE_SYN(SoftsignGrad, softsign_bp); + return Status::OK(); +} +DECLARE_SYN(SoftsignGrad, softsign_bp); - DECLARE_TYPES(softsign_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(softsign_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/tanh.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/tanh.cpp index a42552f75327..f577083ec21a 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/tanh.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/tanh.cpp @@ -25,43 +25,46 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(tanh, 1, 1, true, 0, 0) { - auto first = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CONFIGURABLE_OP_IMPL(tanh, 1, 1, true, 0, 0) { + auto first = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - first->applyTransform(sd::transform::Tanh, *z); + first->applyTransform(sd::transform::Tanh, *z); - STORE_RESULT(*z); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(tanh) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(tanh) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} - CONFIGURABLE_OP_IMPL(tanh_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto epsilon = INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(tanh_bp, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //input->applyPairwiseTransform(pairwise::TanhDerivativeE, epsilon, z, nullptr); - helpers::tanhDerivative(block.launchContext(), input, epsilon, z); - return Status::OK(); - } - DECLARE_SYN(TanhGrad, tanh_bp); + // input->applyPairwiseTransform(pairwise::TanhDerivativeE, epsilon, z, + // nullptr); + helpers::tanhDerivative(block.launchContext(), input, epsilon, z); + return Status::OK(); +} +DECLARE_SYN(TanhGrad, tanh_bp); - DECLARE_TYPES(tanh_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } - } +DECLARE_TYPES(tanh_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/activations/thresholdedrelu.cpp b/libnd4j/include/ops/declarable/generic/nn/activations/thresholdedrelu.cpp index 519f09d6d95f..fe5c1c70a693 100644 --- a/libnd4j/include/ops/declarable/generic/nn/activations/thresholdedrelu.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/activations/thresholdedrelu.cpp @@ -18,57 +18,55 @@ // @author Yurii Shyrma, created on 24.07.2018 // - #include #if NOT_EXCLUDED(OP_thresholdedrelu) #include -#include #include +#include namespace sd { -namespace ops { - +namespace ops { //////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(thresholdedrelu, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto scalar = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; + auto scalar = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; - helpers::thresholdRelu(block.launchContext(), *input, scalar, *output); + helpers::thresholdRelu(block.launchContext(), *input, scalar, *output); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(thresholdedrelu) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setSameMode(true); - } +DECLARE_TYPES(thresholdedrelu) { + getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY)->setSameMode(true); +} //////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(thresholdedrelu_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto dLdO = INPUT_VARIABLE(1); - - auto dLdI = OUTPUT_VARIABLE(0); - auto threshold = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; + auto input = INPUT_VARIABLE(0); + auto dLdO = INPUT_VARIABLE(1); - helpers::thresholdReluDerivative(block.launchContext(), input, threshold, dLdO, dLdI); - - return Status::OK(); -} - - DECLARE_TYPES(thresholdedrelu_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) - ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); - } + auto dLdI = OUTPUT_VARIABLE(0); + auto threshold = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; + helpers::thresholdReluDerivative(block.launchContext(), input, threshold, + dLdO, dLdI); + return Status::OK(); } + +DECLARE_TYPES(thresholdedrelu_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) + ->setAllowedOutputTypes( + 0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/apply_sgd.cpp b/libnd4j/include/ops/declarable/generic/nn/apply_sgd.cpp index e287c5035ccc..39fb71b2e483 100644 --- a/libnd4j/include/ops/declarable/generic/nn/apply_sgd.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/apply_sgd.cpp @@ -25,38 +25,45 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(apply_sgd, 2, 1, true, -2, 0) { - auto parameters = INPUT_VARIABLE(0); - auto gradients = INPUT_VARIABLE(1); - - double lr = 0.0; - - REQUIRE_TRUE(parameters->isSameShape(gradients), 0, "ApplySGD: parameters and gradients should have the same shape, but got parameters = %s and gradients = %s !", ShapeUtils::shapeAsString(parameters).c_str(), ShapeUtils::shapeAsString(gradients).c_str()); - - if (block.width() == 3) { - auto tarr = INPUT_VARIABLE(2); - lr = tarr->e(0); - } else if (block.numT() == 1) { - lr = T_ARG(0); - } else { - REQUIRE_TRUE(false, 0, "ApplyGradients op should have LR announced either es T argument or additional NDArray!"); - } - - auto Z = OUTPUT_VARIABLE(0); - - helpers::applyGradientDescent(block.launchContext(), parameters, gradients, lr, Z); - - return Status::OK(); - } - DECLARE_SYN(ApplyGradientDescent, apply_sgd); - } - - DECLARE_TYPES(apply_sgd) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +namespace ops { +CONFIGURABLE_OP_IMPL(apply_sgd, 2, 1, true, -2, 0) { + auto parameters = INPUT_VARIABLE(0); + auto gradients = INPUT_VARIABLE(1); + + double lr = 0.0; + + REQUIRE_TRUE(parameters->isSameShape(gradients), 0, + "ApplySGD: parameters and gradients should have the same shape, " + "but got parameters = %s and gradients = %s !", + ShapeUtils::shapeAsString(parameters).c_str(), + ShapeUtils::shapeAsString(gradients).c_str()); + + if (block.width() == 3) { + auto tarr = INPUT_VARIABLE(2); + lr = tarr->e(0); + } else if (block.numT() == 1) { + lr = T_ARG(0); + } else { + REQUIRE_TRUE(false, 0, + "ApplyGradients op should have LR announced either es T " + "argument or additional NDArray!"); + } + + auto Z = OUTPUT_VARIABLE(0); + + helpers::applyGradientDescent(block.launchContext(), parameters, gradients, + lr, Z); + + return Status::OK(); +} +DECLARE_SYN(ApplyGradientDescent, apply_sgd); +} // namespace ops + +DECLARE_TYPES(apply_sgd) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp index 5f0806152798..c185ef13bac1 100644 --- a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp @@ -24,336 +24,398 @@ #if NOT_EXCLUDED(OP_batchnorm) #include -#include +#include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) { - - auto input = INPUT_VARIABLE(0); - auto mean = INPUT_VARIABLE(1); - auto variance = INPUT_VARIABLE(2); - NDArray* gamma = nullptr; - NDArray* beta = nullptr; - - auto output = OUTPUT_VARIABLE(0); - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - const double epsilon = T_ARG(0); - - if(applyScale) - gamma = INPUT_VARIABLE(3); - if(applyOffset) - beta = INPUT_VARIABLE(3 + (int)applyScale); - - const int numOfIntArgs = block.numI(); - const int inRank = input->rankOf(); - - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(inRank-1); // default dimension to reduce along is last dimension - - const uint numOfAxes = axes.size(); - REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); - - // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes - // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5} - std::vector expShape; - if(numOfAxes == 1) - expShape.push_back(input->sizeAt(axes[0])); - else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3} - expShape = std::vector(inRank, 1); - for(uint i = 0; i < numOfAxes; ++i) - expShape[axes[i]] = input->sizeAt(axes[i]); - } - - REQUIRE_TRUE(mean->isSameShape(expShape) , 0, "BATCHNORM op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); - if(gamma) - REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); - if(beta) - REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); - - // types of all input arrays should be the same - for(unsigned long i = 1; i < block.width(); ++i) - REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM op: types of all input arrays should be the same !"); - - nd4j_debug("MKL-DNN is not used for batchnorm!\n", 0); - - // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta - // auto v = input->varianceAlongDimension(variance::SummaryStatsVariance, false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); - // auto m = input->reduceAlongDimension(sd::reduce::Mean, ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); - - helpers::batchnorm(input, mean, variance, gamma, beta, output, axes, epsilon); - - // NDArray stdInv = *v + epsilon; - // stdInv.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon) - // stdInv.applyTransform(transform::Sqrt); // 1 / (variance + epsilon)^0.5 - // if(applyScale) - // stdInv *= *gamma; - - // // empty array with same shape as input - // input->applyBroadcast(sd::broadcast::Subtract, axes, m, output); - // output->applyBroadcast(sd::broadcast::Multiply, axes, &stdInv); - - // if(applyOffset) - // output->applyBroadcast(sd::broadcast::Add, axes, beta); - - // delete v; - // delete m; - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto mean = INPUT_VARIABLE(1); + auto variance = INPUT_VARIABLE(2); + NDArray* gamma = nullptr; + NDArray* beta = nullptr; + + auto output = OUTPUT_VARIABLE(0); + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const double epsilon = T_ARG(0); + + if (applyScale) gamma = INPUT_VARIABLE(3); + if (applyOffset) beta = INPUT_VARIABLE(3 + (int)applyScale); + + const int numOfIntArgs = block.numI(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank - + 1); // default dimension to reduce along is last dimension + + const uint numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes <= inRank, 0, + "BATCHNORM op: too big number of input axes to normalize over, " + "expected number should be less or equal to rank of input " + "array, but got %i and %i correspondingly !", + numOfAxes, inRank); + + // evaluate expected shape for mean, variance and gamma. These 3 arrays should + // have identical shapes for example if input shape is {2,3,4,5,6} and axes = + // {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then + // expected shape would be {5} + std::vector expShape; + if (numOfAxes == 1) + expShape.push_back(input->sizeAt(axes[0])); + else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if + // axes = {1, 3} + expShape = std::vector(inRank, 1); + for (uint i = 0; i < numOfAxes; ++i) + expShape[axes[i]] = input->sizeAt(axes[i]); + } + + REQUIRE_TRUE(mean->isSameShape(expShape), 0, + "BATCHNORM op: wrong shape of mean array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->isSameShape(expShape), 0, + "BATCHNORM op: wrong shape of variance array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(variance).c_str()); + if (gamma) + REQUIRE_TRUE(gamma->isSameShape(expShape), 0, + "BATCHNORM op: wrong shape of gamma array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(gamma).c_str()); + if (beta) + REQUIRE_TRUE(beta->isSameShape(expShape), 0, + "BATCHNORM op: wrong shape of beta array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(beta).c_str()); + + // types of all input arrays should be the same + for (unsigned long i = 1; i < block.width(); ++i) + REQUIRE_TRUE( + INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, + "BATCHNORM op: types of all input arrays should be the same !"); + + nd4j_debug("MKL-DNN is not used for batchnorm!\n", 0); + + // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + + // beta auto v = input->varianceAlongDimension(variance::SummaryStatsVariance, + // false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); auto m = + // input->reduceAlongDimension(sd::reduce::Mean, + // ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); + + helpers::batchnorm(input, mean, variance, gamma, beta, output, axes, epsilon); + + // NDArray stdInv = *v + epsilon; + // stdInv.applyTransform(transform::Reciprocal); // 1 / + // (variance + epsilon) stdInv.applyTransform(transform::Sqrt); // 1 / + // (variance + epsilon)^0.5 if(applyScale) + // stdInv *= *gamma; + + // // empty array with same shape as input + // input->applyBroadcast(sd::broadcast::Subtract, axes, m, output); + // output->applyBroadcast(sd::broadcast::Multiply, axes, &stdInv); + + // if(applyOffset) + // output->applyBroadcast(sd::broadcast::Add, axes, beta); + + // delete v; + // delete m; + + return Status::OK(); } DECLARE_TYPES(batchnorm) { - getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } DECLARE_SHAPE_FN(batchnorm) { + auto inShapeInfo = inputShape->at(0); + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo)); - auto inShapeInfo = inputShape->at(0); - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo)); + auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType( + inShapeInfo, outType, false, + block.workspace()); // output shape is identical to input shape - auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(inShapeInfo, outType, false, block.workspace()); // output shape is identical to input shape - - return SHAPELIST(CONSTANT(outShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) { - - NDArray* input = INPUT_VARIABLE(0); - NDArray* mean = INPUT_VARIABLE(1); - NDArray* variance = INPUT_VARIABLE(2); - NDArray* gamma = nullptr; - NDArray* beta = nullptr; - NDArray* dLdO = INPUT_VARIABLE(block.width() - 1); // next epsilon - - NDArray* dLdI = OUTPUT_VARIABLE(0); - NDArray* dLdM = OUTPUT_VARIABLE(1); - NDArray* dLdV = OUTPUT_VARIABLE(2); - NDArray* dLdG = nullptr; - NDArray* dLdB = nullptr; - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - const float epsilon = T_ARG(0); - - if(applyScale) { - gamma = INPUT_VARIABLE(3); - dLdG = OUTPUT_VARIABLE(3); - } - if(applyOffset) { - beta = INPUT_VARIABLE(3 + (int)applyScale); - dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); - } - - const int numOfIntArgs = block.numI(); - const int inRank = input->rankOf(); - - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(inRank-1); // default dimension to reduce along is last dimension - - const uint numOfAxes = axes.size(); - REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_BP op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); - - // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes - // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5} - std::vector expShape; - if(numOfAxes == 1) - expShape.push_back(input->sizeAt(axes[0])); - else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3} - expShape = std::vector(inRank, 1); - for(uint i = 0; i < numOfAxes; ++i) - expShape[axes[i]] = input->sizeAt(axes[i]); - } - - REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); - if(gamma) - REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); - if(beta) - REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); - - REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str()); - - // types of all input arrays should be the same (except dLdO) - for(unsigned long i = 1; i < block.width() - 2; ++i) - REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP op: types of arrays (input, mean, variance, gamma, beta) should be the same !"); - - // ***** calculations ***** // - - // notations: - // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output, g = dLdO - // stdInv = 1 / (v + eps)^0.5 - // N - batch size (product of spatial dimensions) - - // derivatives: - // dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx) - - // dfdx = gamma*stdInv*g; - // dfdm = -gamma*stdInv*g_sum; - // dmdx = 1/N; - // dvdx = 2 * (x - m) / N - // dvdm = -2 * [(x - m)]_sum / N - // dfdv = -0.5 * [g*(x - m)]_sum * stdInv^3, drop gamma here for calc convenience - - // finally: - // dLdI = gamma * ( stdInv * (g - g_sum/N) + (2/N) * dfdv * (dvdm/2 + (x - m)) ) - - // dLdG = (g * (x - m))_sum * stdInv - // dLdB = g_sum - - // variance = input->varianceAlongDimension(variance::SummaryStatsVariance, false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); - // mean = input->reduceAlongDimension(sd::reduce::Mean, ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); - - const auto excludedAxes = ShapeUtils::evalDimsToExclude(inRank, axes); - const bool keepUnitiesInShape = inRank == mean->rankOf(); - - // inverse batch size 1/N - const float Ninv = 1.f * shape::tadLength(input->shapeInfo(), axes.data(), axes.size()) / input->lengthOf(); - - // input - mean - NDArray xMinusMean(input); // empty array with same shape as input - input->applyBroadcast(sd::broadcast::Subtract, axes, *mean, xMinusMean); - - // stdInv - NDArray stdInv = *variance + epsilon; - stdInv.applyTransform(transform::Reciprocal, stdInv); // 1 / (variance + epsilon) - stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5 - - // dvdm (use dLdM as storage for dvdm) - xMinusMean.reduceAlongDimension(sd::reduce::Sum, *dLdM, excludedAxes, keepUnitiesInShape); - *dLdM *= -Ninv; - - // g_sum - auto gSum = dLdO->reduceAlongDimension(sd::reduce::Sum, excludedAxes, keepUnitiesInShape); - - // dLdB - if(applyOffset) - dLdB->assign(gSum); - - // stdInv * (g - g_sum/N) (use dLdI as storage for this expression) - gSum *= Ninv; - dLdO->applyBroadcast(sd::broadcast::Subtract, axes, gSum, *dLdI); - dLdI->applyBroadcast(sd::broadcast::Multiply, axes, stdInv, *dLdI); - - // dLdV <- [g*(x - m)]_sum - (xMinusMean * *dLdO).reduceAlongDimension(sd::reduce::Sum, *dLdV, excludedAxes, keepUnitiesInShape); - - // dLdG - *dLdV *= stdInv; - if(applyScale) - dLdG->assign(dLdV); - - // (2 / N) * dfdv (use dLdV as storage for dfdv) - *dLdV *= stdInv*stdInv; // dLdV*stdInv * stdInv^2 - *dLdV *= -Ninv; // -0.5f * (2 / N); - - // dfdv * (dvdm + (x - m)) (use xMinusMean as storage for this expression) - xMinusMean.applyBroadcast(sd::broadcast::Add, axes, *dLdM, xMinusMean); - xMinusMean.applyBroadcast(sd::broadcast::Multiply, axes, *dLdV, xMinusMean); - - // dLdI - *dLdI += xMinusMean; - if(applyScale) - dLdI->applyBroadcast(sd::broadcast::Multiply, axes, *gamma, *dLdI); - - *dLdM = 0; // put zeros so far - *dLdV = 0; // put zeros so far - - // java code - // NDArray std = *variance + epsilon; - // std.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon) - // std.applyTransform(transform::Sqrt); // 1 / (variance + epsilon)^0.5 - // NDArray xMu(input); - // input->applyBroadcast(sd::broadcast::Subtract, axes, mean, &xMu); - // NDArray xHat(input); - // xMu.applyBroadcast(sd::broadcast::Multiply, axes, &std, &xHat); - // NDArray dxhat(input); - // dLdO->applyBroadcast(sd::broadcast::Multiply, axes, gamma, &dxhat); - // NDArray temp = dxhat*xMu; - // temp.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape); - // *dLdV *= -0.5f * std*std*std; - // NDArray* dxmu1 = dxhat.reduceAlongDimension(reduce::Sum, excludedAxes, keepUnitiesInShape); - // *dxmu1 *= -std; - // NDArray* dxmu2 = xMu.reduceAlongDimension(reduce::Sum, excludedAxes, keepUnitiesInShape); - // *dxmu2 *= *dLdV * (-2.f/N); - // NDArray dLdmu = *dxmu1 + *dxmu2; - // dLdmu *= (1.f /N); - // *dLdV *= (2.f/N); - // dxhat.applyBroadcast(sd::broadcast::Multiply, axes, &std); - // xMu.applyBroadcast(sd::broadcast::Multiply, axes, dLdV); - // dxhat += xMu; - // dxhat.applyBroadcast(sd::broadcast::Add, axes, &dLdmu, dLdI); - // delete dxmu1; - // delete dxmu2; - // xHat *= *dLdO; - // xHat.reduceAlongDimension(reduce::Sum, dLdG, excludedAxes, keepUnitiesInShape); - - return Status::OK(); + NDArray* input = INPUT_VARIABLE(0); + NDArray* mean = INPUT_VARIABLE(1); + NDArray* variance = INPUT_VARIABLE(2); + NDArray* gamma = nullptr; + NDArray* beta = nullptr; + NDArray* dLdO = INPUT_VARIABLE(block.width() - 1); // next epsilon + + NDArray* dLdI = OUTPUT_VARIABLE(0); + NDArray* dLdM = OUTPUT_VARIABLE(1); + NDArray* dLdV = OUTPUT_VARIABLE(2); + NDArray* dLdG = nullptr; + NDArray* dLdB = nullptr; + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const float epsilon = T_ARG(0); + + if (applyScale) { + gamma = INPUT_VARIABLE(3); + dLdG = OUTPUT_VARIABLE(3); + } + if (applyOffset) { + beta = INPUT_VARIABLE(3 + (int)applyScale); + dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); + } + + const int numOfIntArgs = block.numI(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank - + 1); // default dimension to reduce along is last dimension + + const uint numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes <= inRank, 0, + "BATCHNORM_BP op: too big number of input axes to normalize " + "over, expected number should be less or equal to rank of input " + "array, but got %i and %i correspondingly !", + numOfAxes, inRank); + + // evaluate expected shape for mean, variance and gamma. These 3 arrays should + // have identical shapes for example if input shape is {2,3,4,5,6} and axes = + // {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then + // expected shape would be {5} + std::vector expShape; + if (numOfAxes == 1) + expShape.push_back(input->sizeAt(axes[0])); + else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if + // axes = {1, 3} + expShape = std::vector(inRank, 1); + for (uint i = 0; i < numOfAxes; ++i) + expShape[axes[i]] = input->sizeAt(axes[i]); + } + + REQUIRE_TRUE(mean->isSameShape(expShape), 0, + "BATCHNORM_BP op: wrong shape of mean array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->isSameShape(expShape), 0, + "BATCHNORM_BP op: wrong shape of variance array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(variance).c_str()); + if (gamma) + REQUIRE_TRUE(gamma->isSameShape(expShape), 0, + "BATCHNORM_BP op: wrong shape of gamma array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(gamma).c_str()); + if (beta) + REQUIRE_TRUE(beta->isSameShape(expShape), 0, + "BATCHNORM_BP op: wrong shape of beta array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(beta).c_str()); + + REQUIRE_TRUE(input->isSameShape(dLdO), 0, + "BATCHNORM_BP op: wrong shape of output gradients array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(input).c_str(), + ShapeUtils::shapeAsString(dLdO).c_str()); + + // types of all input arrays should be the same (except dLdO) + for (unsigned long i = 1; i < block.width() - 2; ++i) + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), + 0, + "BATCHNORM_BP op: types of arrays (input, mean, variance, " + "gamma, beta) should be the same !"); + + // ***** calculations ***** // + + // notations: + // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * + // ff_output, g = dLdO stdInv = 1 / (v + eps)^0.5 N - batch size (product of + // spatial dimensions) + + // derivatives: + // dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx) + + // dfdx = gamma*stdInv*g; + // dfdm = -gamma*stdInv*g_sum; + // dmdx = 1/N; + // dvdx = 2 * (x - m) / N + // dvdm = -2 * [(x - m)]_sum / N + // dfdv = -0.5 * [g*(x - m)]_sum * stdInv^3, drop gamma here for calc + // convenience + + // finally: + // dLdI = gamma * ( stdInv * (g - g_sum/N) + (2/N) * dfdv * (dvdm/2 + (x - + // m)) ) + + // dLdG = (g * (x - m))_sum * stdInv + // dLdB = g_sum + + // variance = input->varianceAlongDimension(variance::SummaryStatsVariance, + // false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); mean = + // input->reduceAlongDimension(sd::reduce::Mean, + // ShapeUtils::evalDimsToExclude(input->rankOf(), axes)); + + const auto excludedAxes = ShapeUtils::evalDimsToExclude(inRank, axes); + const bool keepUnitiesInShape = inRank == mean->rankOf(); + + // inverse batch size 1/N + const float Ninv = + 1.f * shape::tadLength(input->shapeInfo(), axes.data(), axes.size()) / + input->lengthOf(); + + // input - mean + NDArray xMinusMean(input); // empty array with same shape as input + input->applyBroadcast(sd::broadcast::Subtract, axes, *mean, xMinusMean); + + // stdInv + NDArray stdInv = *variance + epsilon; + stdInv.applyTransform(transform::Reciprocal, + stdInv); // 1 / (variance + epsilon) + stdInv.applyTransform(transform::Sqrt, + stdInv); // 1 / (variance + epsilon)^0.5 + + // dvdm (use dLdM as storage for dvdm) + xMinusMean.reduceAlongDimension(sd::reduce::Sum, *dLdM, excludedAxes, + keepUnitiesInShape); + *dLdM *= -Ninv; + + // g_sum + auto gSum = dLdO->reduceAlongDimension(sd::reduce::Sum, excludedAxes, + keepUnitiesInShape); + + // dLdB + if (applyOffset) dLdB->assign(gSum); + + // stdInv * (g - g_sum/N) (use dLdI as storage for this expression) + gSum *= Ninv; + dLdO->applyBroadcast(sd::broadcast::Subtract, axes, gSum, *dLdI); + dLdI->applyBroadcast(sd::broadcast::Multiply, axes, stdInv, *dLdI); + + // dLdV <- [g*(x - m)]_sum + (xMinusMean * *dLdO) + .reduceAlongDimension(sd::reduce::Sum, *dLdV, excludedAxes, + keepUnitiesInShape); + + // dLdG + *dLdV *= stdInv; + if (applyScale) dLdG->assign(dLdV); + + // (2 / N) * dfdv (use dLdV as storage for dfdv) + *dLdV *= stdInv * stdInv; // dLdV*stdInv * stdInv^2 + *dLdV *= -Ninv; // -0.5f * (2 / N); + + // dfdv * (dvdm + (x - m)) (use xMinusMean as storage for this expression) + xMinusMean.applyBroadcast(sd::broadcast::Add, axes, *dLdM, xMinusMean); + xMinusMean.applyBroadcast(sd::broadcast::Multiply, axes, *dLdV, xMinusMean); + + // dLdI + *dLdI += xMinusMean; + if (applyScale) + dLdI->applyBroadcast(sd::broadcast::Multiply, axes, *gamma, *dLdI); + + *dLdM = 0; // put zeros so far + *dLdV = 0; // put zeros so far + + // java code + // NDArray std = *variance + epsilon; + // std.applyTransform(transform::Reciprocal); // 1 / + // (variance + epsilon) std.applyTransform(transform::Sqrt); // 1 / (variance + // + epsilon)^0.5 NDArray xMu(input); + // input->applyBroadcast(sd::broadcast::Subtract, axes, mean, &xMu); + // NDArray xHat(input); + // xMu.applyBroadcast(sd::broadcast::Multiply, axes, &std, &xHat); + // NDArray dxhat(input); + // dLdO->applyBroadcast(sd::broadcast::Multiply, axes, gamma, &dxhat); + // NDArray temp = dxhat*xMu; + // temp.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, + // keepUnitiesInShape); *dLdV *= -0.5f * std*std*std; NDArray* dxmu1 = + // dxhat.reduceAlongDimension(reduce::Sum, excludedAxes, keepUnitiesInShape); + // *dxmu1 *= -std; + // NDArray* dxmu2 = xMu.reduceAlongDimension(reduce::Sum, excludedAxes, + // keepUnitiesInShape); *dxmu2 *= *dLdV * (-2.f/N); NDArray dLdmu = *dxmu1 + + // *dxmu2; dLdmu *= (1.f /N); *dLdV *= (2.f/N); + // dxhat.applyBroadcast(sd::broadcast::Multiply, axes, &std); + // xMu.applyBroadcast(sd::broadcast::Multiply, axes, dLdV); + // dxhat += xMu; + // dxhat.applyBroadcast(sd::broadcast::Add, axes, &dLdmu, dLdI); + // delete dxmu1; + // delete dxmu2; + // xHat *= *dLdO; + // xHat.reduceAlongDimension(reduce::Sum, dLdG, excludedAxes, + // keepUnitiesInShape); + + return Status::OK(); } DECLARE_TYPES(batchnorm_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, sd::DataType::ANY) - ->setAllowedInputTypes(2, sd::DataType::ANY) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedInputTypes(4, sd::DataType::ANY) - ->setAllowedInputTypes(5, sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, sd::DataType::ANY) + ->setAllowedInputTypes(2, sd::DataType::ANY) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedInputTypes(4, sd::DataType::ANY) + ->setAllowedInputTypes(5, sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(batchnorm_bp) { + Nd4jLong const* inShapeInfo = inputShape->at(0); + Nd4jLong const* meanShapeInfo = inputShape->at(1); - Nd4jLong const* inShapeInfo = inputShape->at(0); - Nd4jLong const* meanShapeInfo = inputShape->at(1); + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); + DataType outType = + DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo)); - DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo)); + auto shapes = SHAPELIST(); - auto shapes = SHAPELIST(); + // dLdI shapeInfo + shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + outType, inShapeInfo)); - // dLdI shapeInfo - shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(outType, inShapeInfo)); + // dLdM shapeInfo + shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + outType, meanShapeInfo)); - // dLdM shapeInfo - shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(outType, meanShapeInfo)); + // dLdV shapeInfo (same as dLdM) + shapes->push_back(shapes->at(shapes->size() - 1)); - // dLdV shapeInfo (same as dLdM) - shapes->push_back(shapes->at(shapes->size()-1)); + // dLdG shapeInfo (same as dLdM) + if (applyScale) shapes->push_back(shapes->at(shapes->size() - 1)); - // dLdG shapeInfo (same as dLdM) - if(applyScale) - shapes->push_back(shapes->at(shapes->size()-1)); + // dLdB shapeInfo (same as dLdM) + if (applyOffset) shapes->push_back(shapes->at(shapes->size() - 1)); - // dLdB shapeInfo (same as dLdM) - if(applyOffset) - shapes->push_back(shapes->at(shapes->size()-1)); - - return shapes; + return shapes; } - -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/bias_add.cpp b/libnd4j/include/ops/declarable/generic/nn/bias_add.cpp index 9d45797b470e..ef1d0f8ef435 100644 --- a/libnd4j/include/ops/declarable/generic/nn/bias_add.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/bias_add.cpp @@ -23,93 +23,105 @@ #if NOT_EXCLUDED(OP_biasadd) #include -#include +#include namespace sd { namespace ops { //////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(biasadd, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto bias = INPUT_VARIABLE(1); - auto input = INPUT_VARIABLE(0); - auto bias = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + const bool isNCHW = !block.getBArguments().empty() ? B_ARG(0) : false; + const int channelDim = isNCHW ? 1 : input->rankOf() - 1; // second or last - const bool isNCHW = !block.getBArguments().empty() ? B_ARG(0) : false; - const int channelDim = isNCHW ? 1 : input->rankOf() - 1; // second or last + REQUIRE_TRUE(bias->rankOf() == 1, 0, + "BIASADD CUSTOM_OP: bias array should have rank = 1, but got %i " + "instead !", + bias->rankOf()); - REQUIRE_TRUE(bias->rankOf() == 1, 0, "BIASADD CUSTOM_OP: bias array should have rank = 1, but got %i instead !", bias->rankOf()); + REQUIRE_TRUE( + bias->sizeAt(0) == input->sizeAt(channelDim), 0, + "BIASADD CUSTOM_OP: shapes of bias %s and input %s arrays are not " + "suitable for broadcast operation along channel dimension %i !", + ShapeUtils::shapeAsString(bias).c_str(), + ShapeUtils::shapeAsString(input).c_str(), channelDim); - REQUIRE_TRUE(bias->sizeAt(0) == input->sizeAt(channelDim), 0, "BIASADD CUSTOM_OP: shapes of bias %s and input %s arrays are not suitable for broadcast operation along channel dimension %i !", ShapeUtils::shapeAsString(bias).c_str(), ShapeUtils::shapeAsString(input).c_str(), channelDim); + REQUIRE_TRUE(output->isSameShape(input), 0, + "BIASADD CUSTOM_OP: wrong shape of output array, expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(input).c_str(), + ShapeUtils::shapeAsString(output).c_str()); - REQUIRE_TRUE(output->isSameShape(input), 0, "BIASADD CUSTOM_OP: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(output).c_str()); + helpers::addBias(block, *input, *bias, *output, isNCHW); + // input->applyBroadcast(sd::broadcast::Add, {channelDim}, bias, output); - helpers::addBias(block, *input, *bias, *output, isNCHW); - // input->applyBroadcast(sd::broadcast::Add, {channelDim}, bias, output); - - return Status::OK(); + return Status::OK(); } DECLARE_SYN(bias_add, biasadd); //////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(biasadd) { - auto xShape = inputShape->at(0); - auto yShape = inputShape->at(1); + auto xShape = inputShape->at(0); + auto yShape = inputShape->at(1); - auto dtype = ArrayOptions::dataType(yShape); - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(xShape, dtype))); + auto dtype = ArrayOptions::dataType(yShape); + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(xShape, dtype))); } DECLARE_TYPES(biasadd) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } //////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(biasadd_bp, 3, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto bias = INPUT_VARIABLE(1); + auto gradO = INPUT_VARIABLE(2); - auto input = INPUT_VARIABLE(0); - auto bias = INPUT_VARIABLE(1); - auto gradO = INPUT_VARIABLE(2); - - auto gradI = OUTPUT_VARIABLE(0); - auto gradB = OUTPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + auto gradB = OUTPUT_VARIABLE(1); - const bool isNCHW = !block.getBArguments().empty() ? B_ARG(0) : false; - const int channelDim = isNCHW ? 1 : input->rankOf() - 1; // second or last + const bool isNCHW = !block.getBArguments().empty() ? B_ARG(0) : false; + const int channelDim = isNCHW ? 1 : input->rankOf() - 1; // second or last - gradI->assign(gradO); + gradI->assign(gradO); - gradO->reduceAlongDimension(sd::reduce::Sum, *gradB, ShapeUtils::evalDimsToExclude(gradO->rankOf(), {channelDim})); + gradO->reduceAlongDimension( + sd::reduce::Sum, *gradB, + ShapeUtils::evalDimsToExclude(gradO->rankOf(), {channelDim})); - return ND4J_STATUS_OK; + return ND4J_STATUS_OK; } DECLARE_SYN(BiasAddGrad, biasadd_bp); //////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(biasadd_bp) { - auto input = inputShape->at(0); - auto bias = inputShape->at(1); + auto input = inputShape->at(0); + auto bias = inputShape->at(1); - Nd4jLong* epsShape; - Nd4jLong* gradShape; + Nd4jLong* epsShape; + Nd4jLong* gradShape; - COPY_SHAPE(input, epsShape); - COPY_SHAPE(bias, gradShape); + COPY_SHAPE(input, epsShape); + COPY_SHAPE(bias, gradShape); - return SHAPELIST(CONSTANT(epsShape), CONSTANT(gradShape)); + return SHAPELIST(CONSTANT(epsShape), CONSTANT(gradShape)); } DECLARE_TYPES(biasadd_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp index 07c0d29c90ea..eb27aa59424b 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/col2im.cpp @@ -25,68 +25,71 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(col2im, 1, 1, false, 0, 9) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_NULLIFIED(0); - - REQUIRE_TRUE(x->rankOf() == 6, 0, "col2im input should be 6D, but got %i instead", x->rankOf()); - REQUIRE_TRUE(z->rankOf() == 4, 0, "col2im output should be 4D, but got %i instead", z->rankOf()); - - int strideY = INT_ARG(0); - int strideX = INT_ARG(1); - int padHeight = INT_ARG(2); - int padWidth = INT_ARG(3); - int imgHeight = INT_ARG(4); - int imgWidth = INT_ARG(5); - int dY = INT_ARG(6); //Dilation in height/y dimension - int dX = INT_ARG(7); //Dilation in width/x dimension - - LaunchContext* ctx = block.launchContext(); - helpers::col2im(*ctx, *x, *z, strideY, strideX, padHeight, padWidth, imgHeight, imgWidth, dY, dX); - - return ND4J_STATUS_OK; - } - DECLARE_SHAPE_FN(col2im) { - auto inShape = inputShape->at(0); - - int bS = shape::shapeOf(inShape)[0]; - int iD = shape::shapeOf(inShape)[1]; - - int sY = INT_ARG(0); - int sX = INT_ARG(1); - int pY = INT_ARG(2); - int pX = INT_ARG(3); - int inY = INT_ARG(4); - int inX = INT_ARG(5); - int dY = INT_ARG(6); //Dilation, height/y dimension - int dX = INT_ARG(7); //Dilation, width/x dimension - bool isSameMode = INT_ARG(8) > 0; - - Nd4jLong* zShape; - ALLOCATE(zShape, block.workspace(), shape::shapeInfoLength(4), Nd4jLong); - - zShape[0] = 4; - zShape[1] = bS; - zShape[2] = iD; - zShape[3] = inY; - zShape[4] = inX; - - zShape[shape::shapeInfoLength(zShape) - 2] = 1; - zShape[shape::shapeInfoLength(zShape) - 1] = 99; - - ShapeUtils::updateStridesAndType(zShape, inShape, 'c'); - - return SHAPELIST(CONSTANT(zShape)); - } - - DECLARE_TYPES(col2im) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setSameMode(true); - } - } +namespace ops { +CUSTOM_OP_IMPL(col2im, 1, 1, false, 0, 9) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_NULLIFIED(0); + + REQUIRE_TRUE(x->rankOf() == 6, 0, + "col2im input should be 6D, but got %i instead", x->rankOf()); + REQUIRE_TRUE(z->rankOf() == 4, 0, + "col2im output should be 4D, but got %i instead", z->rankOf()); + + int strideY = INT_ARG(0); + int strideX = INT_ARG(1); + int padHeight = INT_ARG(2); + int padWidth = INT_ARG(3); + int imgHeight = INT_ARG(4); + int imgWidth = INT_ARG(5); + int dY = INT_ARG(6); // Dilation in height/y dimension + int dX = INT_ARG(7); // Dilation in width/x dimension + + LaunchContext* ctx = block.launchContext(); + helpers::col2im(*ctx, *x, *z, strideY, strideX, padHeight, padWidth, + imgHeight, imgWidth, dY, dX); + + return ND4J_STATUS_OK; } +DECLARE_SHAPE_FN(col2im) { + auto inShape = inputShape->at(0); + + int bS = shape::shapeOf(inShape)[0]; + int iD = shape::shapeOf(inShape)[1]; + + int sY = INT_ARG(0); + int sX = INT_ARG(1); + int pY = INT_ARG(2); + int pX = INT_ARG(3); + int inY = INT_ARG(4); + int inX = INT_ARG(5); + int dY = INT_ARG(6); // Dilation, height/y dimension + int dX = INT_ARG(7); // Dilation, width/x dimension + bool isSameMode = INT_ARG(8) > 0; + + Nd4jLong* zShape; + ALLOCATE(zShape, block.workspace(), shape::shapeInfoLength(4), Nd4jLong); + + zShape[0] = 4; + zShape[1] = bS; + zShape[2] = iD; + zShape[3] = inY; + zShape[4] = inX; + + zShape[shape::shapeInfoLength(zShape) - 2] = 1; + zShape[shape::shapeInfoLength(zShape) - 1] = 99; + + ShapeUtils::updateStridesAndType(zShape, inShape, 'c'); + + return SHAPELIST(CONSTANT(zShape)); +} + +DECLARE_TYPES(col2im) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT) + ->setSameMode(true); +} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp index 1d8bf947f71f..b3abc3ac683f 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp @@ -18,291 +18,436 @@ // @author raver119@gmail.com // @author Yurii Shyrma - #include #if NOT_EXCLUDED(OP_conv1d) -#include #include +#include #include namespace sd { -namespace ops { - - +namespace ops { CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { - - auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) - auto weights = INPUT_VARIABLE(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_NULLIFIED(0); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW) - - int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) width - int sW = INT_ARG(1); // strides width - int pW = INT_ARG(2); // paddings width - int dW = INT_ARG(3); // dilations width - int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL - int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 0-NCW, 1-NWC - int wFormat = block.numI() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] - - const int rank = 3; - REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf()); - - int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); - if(!isNCW) { - indIOioC = 2; indIiW = 1; - } - else { - indIOioC = 1; indIiW = 2; - } - - int bS = input->sizeAt(0); // batch size - int iW = input->sizeAt(indIiW); // input width - int iC = input->sizeAt(indIOioC); // input channels - int oC = weights->sizeAt(indWoC); // output channels - - std::vector expectedWeightsShape = 0 == wFormat ? std::vector({kW, iC, oC}) : (1 == wFormat ? std::vector({oC, iC, kW}) : std::vector({oC, kW, iC})); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - std::vector reshapeForInput, reshapeForOutput; - if(!isNCW) { - reshapeForInput = {input->sizeAt(0), 1, input->sizeAt(1), input->sizeAt(2)}; // [bS, iW, iC] -> [bS, 1, iW, iC] - reshapeForOutput = {output->sizeAt(0), 1, output->sizeAt(1), output->sizeAt(2)}; // [bS, oW, oC] -> [bS, 1, oW, oC] - } - else { - reshapeForInput = {input->sizeAt(0), input->sizeAt(1), 1, input->sizeAt(2)}; // [bS, iC, iW] -> [bS, iC, 1, iW] - reshapeForOutput = {output->sizeAt(0), output->sizeAt(1), 1, output->sizeAt(2)}; // [bS, oC, oW] -> [bS, oC, 1, oW] - } - - auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput); - auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput, false); - auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] - - sd::ops::conv2d conv2d; - const Nd4jStatus status = conv2d.execute({&inputReshaped, &weightsReshaped, bias}, {&outputReshaped}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW, wFormat}, {}); - if (status != ND4J_STATUS_OK) - return status; - - // ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) + auto weights = INPUT_VARIABLE(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = + OUTPUT_NULLIFIED(0); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW) + + int kW = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) width + int sW = INT_ARG(1); // strides width + int pW = INT_ARG(2); // paddings width + int dW = INT_ARG(3); // dilations width + int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL + int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 0-NCW, 1-NWC + int wFormat = + block.numI() > 6 + ? INT_ARG(6) + : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] + + const int rank = 3; + REQUIRE_TRUE(input->rankOf() == rank, 0, + "CUSTOM CONV1D OP: rank of input array must be equal to %i, but " + "got %i instead !", + rank, input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == rank, 0, + "CUSTOM CONV1D OP: rank of weights array must be equal to %i, " + "but got %i instead !", + rank, weights->rankOf()); + + int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); + if (!isNCW) { + indIOioC = 2; + indIiW = 1; + } else { + indIOioC = 1; + indIiW = 2; + } + + int bS = input->sizeAt(0); // batch size + int iW = input->sizeAt(indIiW); // input width + int iC = input->sizeAt(indIOioC); // input channels + int oC = weights->sizeAt(indWoC); // output channels + + std::vector expectedWeightsShape = + 0 == wFormat ? std::vector({kW, iC, oC}) + : (1 == wFormat ? std::vector({oC, iC, kW}) + : std::vector({oC, kW, iC})); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV1D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV1D OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + std::vector reshapeForInput, reshapeForOutput; + if (!isNCW) { + reshapeForInput = {input->sizeAt(0), 1, input->sizeAt(1), + input->sizeAt(2)}; // [bS, iW, iC] -> [bS, 1, iW, iC] + reshapeForOutput = {output->sizeAt(0), 1, output->sizeAt(1), + output->sizeAt(2)}; // [bS, oW, oC] -> [bS, 1, oW, oC] + } else { + reshapeForInput = {input->sizeAt(0), input->sizeAt(1), 1, + input->sizeAt(2)}; // [bS, iC, iW] -> [bS, iC, 1, iW] + reshapeForOutput = {output->sizeAt(0), output->sizeAt(1), 1, + output->sizeAt(2)}; // [bS, oC, oW] -> [bS, oC, 1, oW] + } + + auto inputReshaped = input->reshape(input->ordering(), reshapeForInput); + auto outputReshaped = + output->reshape(output->ordering(), reshapeForOutput, false); + auto weightsReshaped = weights->reshape( + weights->ordering(), + {1, weights->sizeAt(0), weights->sizeAt(1), + weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] + + sd::ops::conv2d conv2d; + const Nd4jStatus status = conv2d.execute( + {&inputReshaped, &weightsReshaped, bias}, {&outputReshaped}, {}, + {1, kW, 1, sW, 0, pW, 1, dW, paddingMode, !isNCW, wFormat}, {}); + if (status != ND4J_STATUS_OK) return status; + + // ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, + // &outputReshaped, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW, wFormat); + + return Status::OK(); } - DECLARE_SHAPE_FN(conv1d) { - - auto inputShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - Nd4jLong const* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; - - int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0)); // filter(kernel) width - int sW = INT_ARG(1); // strides width - int pW = INT_ARG(2); // paddings width - int dW = INT_ARG(3); // dilations width - int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME - int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW - int wFormat = block.numI() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] - - int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); - if(!isNCW) { - indIOioC = 2; indIiW = 1; - } - else { - indIOioC = 1; indIiW = 2; - } - - const int rank = 3; - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV1D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV1D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); - - int bS = inputShapeInfo[1]; // batch size - int iW = inputShapeInfo[indIiW+1]; // input width - int iC = inputShapeInfo[indIOioC+1]; // input channels - int oC = weightsShapeInfo[indWoC+1]; // output channels - - std::vector expectedWeightsShape = 0 == wFormat ? std::vector({kW, iC, oC}) : (1 == wFormat ? std::vector({oC, iC, kW}) : std::vector({oC, kW, iC})); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - int oH, oW; // output height, width - ConvolutionUtils::calcOutSizePool2D(oH,oW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); - - Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); - - outputShapeInfo[0] = 3; - outputShapeInfo[1] = bS; - - if (isNCW) { - outputShapeInfo[2] = oC; - outputShapeInfo[3] = oW; - } else { - outputShapeInfo[2] = oW; - outputShapeInfo[3] = oC; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(weightsShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + Nd4jLong const* biasShapeInfo = + block.width() > 2 ? inputShape->at(2) : nullptr; + + int kW = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) width + int sW = INT_ARG(1); // strides width + int pW = INT_ARG(2); // paddings width + int dW = INT_ARG(3); // dilations width + int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME + int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW + int wFormat = + block.numI() > 6 + ? INT_ARG(6) + : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] + + int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); + if (!isNCW) { + indIOioC = 2; + indIiW = 1; + } else { + indIOioC = 1; + indIiW = 2; + } + + const int rank = 3; + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "CUSTOM CONV1D OP: rank of input array must be equal to %i, but " + "got %i instead !", + rank, inputShapeInfo); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM CONV1D OP: rank of weights array must be equal to %i, " + "but got %i instead !", + rank, weightsShapeInfo); + + int bS = inputShapeInfo[1]; // batch size + int iW = inputShapeInfo[indIiW + 1]; // input width + int iC = inputShapeInfo[indIOioC + 1]; // input channels + int oC = weightsShapeInfo[indWoC + 1]; // output channels + + std::vector expectedWeightsShape = + 0 == wFormat ? std::vector({kW, iC, oC}) + : (1 == wFormat ? std::vector({oC, iC, kW}) + : std::vector({oC, kW, iC})); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM CONV1D OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + int oH, oW; // output height, width + ConvolutionUtils::calcOutSizePool2D(oH, oW, 1, kW, 1, sW, 0, pW, 1, dW, 1, iW, + paddingMode); + + Nd4jLong* outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + + outputShapeInfo[0] = 3; + outputShapeInfo[1] = bS; + + if (isNCW) { + outputShapeInfo[2] = oC; + outputShapeInfo[3] = oW; + } else { + outputShapeInfo[2] = oW; + outputShapeInfo[3] = oC; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, + shape::order(weightsShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } DECLARE_TYPES(conv1d) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS, DataType::QINT8, DataType::QINT16}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes( + 0, {ALL_FLOATS, ALL_INTS, DataType::QINT8, DataType::QINT16}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { - - auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) - auto weights = INPUT_VARIABLE(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] - - int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) width - int sW = INT_ARG(1); // strides width - int pW = INT_ARG(2); // paddings width - int dW = INT_ARG(3); // dilations width - int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL - int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW - int wFormat = block.numI() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] - - const int rank = 3; - REQUIRE_TRUE(input->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == rank, 0, "CUSTOM CONV1D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradO->rankOf()); - - int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); - if(!isNCW) { - indIOioC = 2; indIiW = 1; - } - else { - indIOioC = 1; indIiW = 2; - } - - const int bS = input->sizeAt(0); // batch size - const int iW = input->sizeAt(indIiW); // input width - const int iC = input->sizeAt(indIOioC); // input channels - const int oC = weights->sizeAt(indWoC); // output channels - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}); - std::vector expectedWeightsShape = 0 == wFormat ? std::vector({kW, iC, oC}) : (1 == wFormat ? std::vector({oC, iC, kW}) : std::vector({oC, kW, iC})); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV1D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - std::vector reshapeForInput, reshapeForGradO; - if(!isNCW) { - reshapeForInput = {input->sizeAt(0), 1, input->sizeAt(1), input->sizeAt(2)}; // [bS, iW, iC] -> [bS, 1, iW, iC] - reshapeForGradO = {gradO->sizeAt(0), 1, gradO->sizeAt(1), gradO->sizeAt(2)}; // [bS, oW, oC] -> [bS, 1, oW, oC] - } - else { - reshapeForInput = {input->sizeAt(0), input->sizeAt(1), 1, input->sizeAt(2)}; // [bS, iC, iW] -> [bS, iC, 1, iW] - reshapeForGradO = {gradO->sizeAt(0), gradO->sizeAt(1), 1, gradO->sizeAt(2)}; // [bS, oC, oW] -> [bS, oC, 1, oW] - } - - auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput); - auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput, false); - auto gradOReshaped = gradO ->reshape(gradO->ordering(), reshapeForGradO); - auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] - auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false);// [kW, iC, oC] -> [1, kW, iC, oC] - - sd::ops::conv2d_bp conv2dBP; - auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW, wFormat}, {}); - if (status != ND4J_STATUS_OK) - return status; - - // ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,dW, paddingMode, isNCW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) + auto weights = INPUT_VARIABLE(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE( + 2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next + + auto gradI = + OUTPUT_NULLIFIED(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon + auto gradW = OUTPUT_NULLIFIED(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] + + int kW = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) width + int sW = INT_ARG(1); // strides width + int pW = INT_ARG(2); // paddings width + int dW = INT_ARG(3); // dilations width + int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME, 2-CAUSAL + int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW + int wFormat = + block.numI() > 6 + ? INT_ARG(6) + : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] + + const int rank = 3; + REQUIRE_TRUE(input->rankOf() == rank, 0, + "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, " + "but got %i instead !", + rank, input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == rank, 0, + "CUSTOM CONV1D_BP OP: rank of weights array must be equal to " + "%i, but got %i instead !", + rank, weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == rank, 0, + "CUSTOM CONV1D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to %i, but got %i instead !", + rank, gradO->rankOf()); + + int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); + if (!isNCW) { + indIOioC = 2; + indIiW = 1; + } else { + indIOioC = 1; + indIiW = 2; + } + + const int bS = input->sizeAt(0); // batch size + const int iW = input->sizeAt(indIiW); // input width + const int iC = input->sizeAt(indIOioC); // input channels + const int oC = weights->sizeAt(indWoC); // output channels + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, 1, kW, 1, sW, 0, pW, 1, + dW, 1, iW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoW, 0, indIOioC, indIiW}); + std::vector expectedWeightsShape = + 0 == wFormat ? std::vector({kW, iC, oC}) + : (1 == wFormat ? std::vector({oC, iC, kW}) + : std::vector({oC, kW, iC})); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV1D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + std::vector reshapeForInput, reshapeForGradO; + if (!isNCW) { + reshapeForInput = {input->sizeAt(0), 1, input->sizeAt(1), + input->sizeAt(2)}; // [bS, iW, iC] -> [bS, 1, iW, iC] + reshapeForGradO = {gradO->sizeAt(0), 1, gradO->sizeAt(1), + gradO->sizeAt(2)}; // [bS, oW, oC] -> [bS, 1, oW, oC] + } else { + reshapeForInput = {input->sizeAt(0), input->sizeAt(1), 1, + input->sizeAt(2)}; // [bS, iC, iW] -> [bS, iC, 1, iW] + reshapeForGradO = {gradO->sizeAt(0), gradO->sizeAt(1), 1, + gradO->sizeAt(2)}; // [bS, oC, oW] -> [bS, oC, 1, oW] + } + + auto inputReshaped = input->reshape(input->ordering(), reshapeForInput); + auto gradIReshaped = + gradI->reshape(gradI->ordering(), reshapeForInput, false); + auto gradOReshaped = gradO->reshape(gradO->ordering(), reshapeForGradO); + auto weightsReshaped = weights->reshape( + weights->ordering(), + {1, weights->sizeAt(0), weights->sizeAt(1), + weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] + auto gradWReshaped = gradW->reshape( + gradW->ordering(), + {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, + false); // [kW, iC, oC] -> [1, kW, iC, oC] + + sd::ops::conv2d_bp conv2dBP; + auto status = conv2dBP.execute( + {&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, + {&gradIReshaped, &gradWReshaped, gradB}, {}, + {1, kW, 1, sW, 0, pW, 1, dW, paddingMode, !isNCW, wFormat}, {}); + if (status != ND4J_STATUS_OK) return status; + + // ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, + // &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, + // 1,dW, paddingMode, isNCW, wFormat); + + return Status::OK(); } - DECLARE_SHAPE_FN(conv1d_bp) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) - auto weightsShapeInfo = inputShape->at(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] - Nd4jLong const* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] - Nd4jLong const* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next - - const int rank = 3; - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV1D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); - REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV1D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]); - - int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) width - int sW = INT_ARG(1); // strides width - int pW = INT_ARG(2); // paddings width - int dW = INT_ARG(3); // dilations width - int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME - int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW - int wFormat = block.numI() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] - - int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); - if(!isNCW) { - indIOioC = 2; indIiW = 1; - } - else { - indIOioC = 1; indIiW = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iW = inputShapeInfo[indIiW+1]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}); - std::vector expectedWeightsShape = 0 == wFormat ? std::vector({kW, iC, oC}) : (1 == wFormat ? std::vector({oC, iC, kW}) : std::vector({oC, kW, iC})); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if(biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.workspace()); - auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.workspace()); - - if(biasShapeInfo) { - auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo)); - } - - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); + auto inputShapeInfo = + inputShape->at(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW) + auto weightsShapeInfo = + inputShape->at(1); // [kW, iC, oC], [oC, iC, kW], [oC, kW, iC] + Nd4jLong const* biasShapeInfo = + block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] + Nd4jLong const* gradOShapeInfo = + block.width() > 3 + ? inputShape->at(3) + : inputShape->at( + 2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next + + const int rank = 3; + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "CUSTOM CONV1D_BP OP: rank of input array must be equal to %i, " + "but got %i instead !", + rank, inputShapeInfo[0]); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM CONV1D_BP OP: rank of weights array must be equal to " + "%i, but got %i instead !", + rank, weightsShapeInfo[0]); + REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, + "CUSTOM CONV1D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to %i, but got %i instead !", + rank, gradOShapeInfo[0]); + + int kW = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) width + int sW = INT_ARG(1); // strides width + int pW = INT_ARG(2); // paddings width + int dW = INT_ARG(3); // dilations width + int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME + int isNCW = block.numI() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW + int wFormat = + block.numI() > 6 + ? INT_ARG(6) + : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC] + + int indIOioC, indIiW, indWoC(0 == wFormat ? 2 : 0); + if (!isNCW) { + indIOioC = 2; + indIiW = 1; + } else { + indIOioC = 1; + indIiW = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iW = inputShapeInfo[indIiW + 1]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, 1, kW, 1, sW, 0, pW, 1, + dW, 1, iW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoW, 0, indIOioC, indIiW}); + std::vector expectedWeightsShape = + 0 == wFormat ? std::vector({kW, iC, oC}) + : (1 == wFormat ? std::vector({oC, iC, kW}) + : std::vector({oC, kW, iC})); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), + 0, + "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM CONV1D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, gradOShapeInfo, false, block.workspace()); + + if (biasShapeInfo) { + auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), + CONSTANT(gradBshapeInfo)); + } + + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); } DECLARE_TYPES(conv1d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS, DataType::QINT8, DataType::QINT16}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes( + 0, {ALL_FLOATS, ALL_INTS, DataType::QINT8, DataType::QINT16}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_FLOATS}); } - - -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp index 6024ef9d68c5..bc553a07f73d 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv2d.cpp @@ -24,378 +24,567 @@ #include #if NOT_EXCLUDED(OP_conv2d) -#include -#include #include #include #include +#include -namespace sd { -namespace ops { +#include +namespace sd { +namespace ops { CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = OUTPUT_NULLIFIED( + 0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV2D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV2D OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::conv2d(block, input, weights, bias, output, kH, kW, sH, sW, + pH, pW, dH, dW, isSameMode, isNCHW, wFormat); + + return Status::OK(); } - - DECLARE_SHAPE_FN(conv2d) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] - - //output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0)); // filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1)); // filter(kernel) width - - const int rank = 4; // 4 - - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); - - int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; - } - else { - indIOioC = 1; indIiH = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iH = inputShapeInfo[indIiH+1]; // input height - const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); - - int oH, oW; // output height, width - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - outputShapeInfo[0] = rank; - outputShapeInfo[1] = bS; - - if (isNCHW) { - outputShapeInfo[2] = oC; - outputShapeInfo[3] = oH; - outputShapeInfo[4] = oW; - } else { - outputShapeInfo[2] = oH; - outputShapeInfo[3] = oW; - outputShapeInfo[4] = oC; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] + + // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + int kH = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) width + + const int rank = 4; // 4 + + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "CUSTOM CONV2D OP: rank of input array must be equal to %i, but " + "got %i instead !", + rank, inputShapeInfo[0]); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM CONV2D OP: rank of weights array must be equal to %i, " + "but got %i instead !", + rank, weightsShapeInfo[0]); + + int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + } else { + indIOioC = 1; + indIiH = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iH = inputShapeInfo[indIiH + 1]; // input height + const int iW = inputShapeInfo[indIiH + 2]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM CONV2D OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + Nd4jLong* outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + + int oH, oW; // output height, width + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + outputShapeInfo[0] = rank; + outputShapeInfo[1] = bS; + + if (isNCHW) { + outputShapeInfo[2] = oC; + outputShapeInfo[3] = oH; + outputShapeInfo[4] = oW; + } else { + outputShapeInfo[2] = oH; + outputShapeInfo[3] = oW; + outputShapeInfo[4] = oC; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - DECLARE_TYPES(conv2d) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_TYPES(conv2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(conv2d) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} +DECLARE_TYPES(conv2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vectorexpectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vectorexpectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCHW), epsilon_next + + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradW = OUTPUT_NULLIFIED( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, " + "but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) " + "array must be equal to 4, but got %i instead !", + gradO->rankOf()); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV2D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, + gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, + isNCHW, wFormat); + + return Status::OK(); } - - DECLARE_SHAPE_FN(conv2d_bp) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] - auto gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - const int rank = 4; - - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV2D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); - REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]); - - const int kH = INT_ARG(0); // filter(kernel) height - const int kW = INT_ARG(1); // filter(kernel) width - const int sH = INT_ARG(2); // strides height - const int sW = INT_ARG(3); // strides width - const int pH = INT_ARG(4); // paddings height - const int pW = INT_ARG(5); // paddings width - const int dH = INT_ARG(6); // dilations height - const int dW = INT_ARG(7); // dilations width - const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - const int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - const int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 3 : 0); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; indOoH = 1; - } - else { - indIOioC = 1; indIiH = 2; indOoH = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iH = inputShapeInfo[indIiH+1]; // input height - const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if(biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.workspace()); - auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.workspace()); - - if(biasShapeInfo) { - auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo)); - } - - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] + auto gradOShapeInfo = + block.width() > 3 + ? inputShape->at(3) + : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] + // (NCHW), epsilon_next + + const int rank = 4; + + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "CUSTOM CONV2D_BP OP: rank of input array must be equal to %i, " + "but got %i instead !", + rank, inputShapeInfo[0]); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM CONV2D_BP OP: rank of weights array must be equal to " + "%i, but got %i instead !", + rank, weightsShapeInfo[0]); + REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, + "CUSTOM CONV2D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to %i, but got %i instead !", + rank, gradOShapeInfo[0]); + + const int kH = INT_ARG(0); // filter(kernel) height + const int kW = INT_ARG(1); // filter(kernel) width + const int sH = INT_ARG(2); // strides height + const int sW = INT_ARG(3); // strides width + const int pH = INT_ARG(4); // paddings height + const int pW = INT_ARG(5); // paddings width + const int dH = INT_ARG(6); // dilations height + const int dW = INT_ARG(7); // dilations width + const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + const int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + const int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 3 : 0); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + indOoH = 1; + } else { + indIOioC = 1; + indIiH = 2; + indOoH = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iH = inputShapeInfo[indIiH + 1]; // input height + const int iW = inputShapeInfo[indIiH + 2]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), + 0, + "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM CONV2D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, gradOShapeInfo, false, block.workspace()); + + if (biasShapeInfo) { + auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), + CONSTANT(gradBshapeInfo)); + } + + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); } ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { - - auto gradIShape = INPUT_VARIABLE(0); // [4] - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - const int rank = gradO->rankOf(); - - REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf()); - REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf()); - - // create empty conv2d input array - std::vector gradIShapeAsVector(rank); - for(int i = 0; i < rank; ++i) - gradIShapeAsVector[i] = gradIShape->e(i); - NDArray input(gradO->ordering(), gradIShapeAsVector, gradO->dataType(), block.launchContext()); - - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - - ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); - - return Status::OK(); + auto gradIShape = INPUT_VARIABLE(0); // [4] + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradO = INPUT_VARIABLE( + 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + const int rank = gradO->rankOf(); + + REQUIRE_TRUE(weights->rankOf() == rank, 0, + "CUSTOM CONV2D_INPUT_BP OP: rank of weights array must be equal " + "to 4, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, + "CUSTOM CONV2D_INPUT_BP OP: rank of array with output shape " + "must be equal to 1, but got %i instead !", + gradIShape->rankOf()); + REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, + "CUSTOM CONV2D_INPUT_BP OP: length of array with output shape " + "must be equal to 4, but got %i instead !", + gradIShape->lengthOf()); + + // create empty conv2d input array + std::vector gradIShapeAsVector(rank); + for (int i = 0; i < rank; ++i) + gradIShapeAsVector[i] = gradIShape->e(i); + NDArray input(gradO->ordering(), gradIShapeAsVector, gradO->dataType(), + block.launchContext()); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + + ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, + nullptr, nullptr, kH, kW, sH, sW, pH, pW, dH, dW, + isSameMode, isNCHW, wFormat); + + return Status::OK(); } - DECLARE_TYPES(conv2d_input_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - +DECLARE_TYPES(conv2d_input_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(conv2d_input_bp) { - - auto gradIShapeShapeInfo = inputShape->at(0); // [4] - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradOShapeInfo = inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - const int rank = 4; - - REQUIRE_TRUE(gradIShapeShapeInfo[0] == 1, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of array with output shape must be equal to %i, but got %i instead !", 1, gradIShapeShapeInfo[0]); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); - REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV2D_INPUT_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]); - - const int kH = INT_ARG(0); // filter(kernel) height - const int kW = INT_ARG(1); // filter(kernel) width - const int sH = INT_ARG(2); // strides height - const int sW = INT_ARG(3); // strides width - const int pH = INT_ARG(4); // paddings height - const int pW = INT_ARG(5); // paddings width - const int dH = INT_ARG(6); // dilations height - const int dW = INT_ARG(7); // dilations width - const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - const int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - const int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH; - if(!isNCHW) { - indIOioC = 3; indIiH = 1; indOoH = 1; - } - else { - indIOioC = 1; indIiH = 2; indOoH = 2; - } - - std::vector gradIShape = INPUT_VARIABLE(0)->template asVectorT(); - - const int bS = gradIShape[0]; // batch size - const int iH = gradIShape[indIiH]; // input height - const int iW = gradIShape[indIiH+1]; // input width - const int iC = gradIShape[indIOioC]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - - Nd4jLong* gradIshapeInfo(nullptr); - ALLOCATE(gradIshapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); - - gradIshapeInfo[0] = rank; - gradIshapeInfo[1] = bS; - - if (isNCHW) { - gradIshapeInfo[2] = iC; - gradIshapeInfo[3] = iH; - gradIshapeInfo[4] = iW; - } else { - gradIshapeInfo[2] = iH; - gradIshapeInfo[3] = iW; - gradIshapeInfo[4] = iC; - } - - ShapeUtils::updateStridesAndType(gradIshapeInfo, gradOShapeInfo, shape::order(gradOShapeInfo)); - - return SHAPELIST(CONSTANT(gradIshapeInfo)); + auto gradIShapeShapeInfo = inputShape->at(0); // [4] + auto weightsShapeInfo = inputShape->at( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradOShapeInfo = inputShape->at( + 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + + const int rank = 4; + + REQUIRE_TRUE(gradIShapeShapeInfo[0] == 1, 0, + "CUSTOM CONV2D_INPUT_BP OP: rank of array with output shape " + "must be equal to %i, but got %i instead !", + 1, gradIShapeShapeInfo[0]); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM CONV2D_INPUT_BP OP: rank of weights array must be equal " + "to %i, but got %i instead !", + rank, weightsShapeInfo[0]); + REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, + "CUSTOM CONV2D_INPUT_BP OP: rank of output gradients (next " + "epsilon) array must be equal to %i, but got %i instead !", + rank, gradOShapeInfo[0]); + + const int kH = INT_ARG(0); // filter(kernel) height + const int kW = INT_ARG(1); // filter(kernel) width + const int sH = INT_ARG(2); // strides height + const int sW = INT_ARG(3); // strides width + const int pH = INT_ARG(4); // paddings height + const int pW = INT_ARG(5); // paddings width + const int dH = INT_ARG(6); // dilations height + const int dW = INT_ARG(7); // dilations width + const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + const int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + const int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH; + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + indOoH = 1; + } else { + indIOioC = 1; + indIiH = 2; + indOoH = 2; + } + + std::vector gradIShape = + INPUT_VARIABLE(0)->template asVectorT(); + + const int bS = gradIShape[0]; // batch size + const int iH = gradIShape[indIiH]; // input height + const int iW = gradIShape[indIiH + 1]; // input width + const int iC = gradIShape[indIOioC]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), + 0, + "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + + Nd4jLong* gradIshapeInfo(nullptr); + ALLOCATE(gradIshapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + + gradIshapeInfo[0] = rank; + gradIshapeInfo[1] = bS; + + if (isNCHW) { + gradIshapeInfo[2] = iC; + gradIshapeInfo[3] = iH; + gradIshapeInfo[4] = iW; + } else { + gradIshapeInfo[2] = iH; + gradIshapeInfo[3] = iW; + gradIshapeInfo[4] = iC; + } + + ShapeUtils::updateStridesAndType(gradIshapeInfo, gradOShapeInfo, + shape::order(gradOShapeInfo)); + + return SHAPELIST(CONSTANT(gradIshapeInfo)); } - -} -} +} // namespace ops +} // namespace sd #endif -#endif //LIBND4J_CONVO_OPS_H +#endif // LIBND4J_CONVO_OPS_H diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp index c8d2d9b5831d..baef60fc7ab9 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv3d.cpp @@ -14,7 +14,6 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // @author Yurii Shyrma, created on 05.02.2018 // @@ -22,347 +21,526 @@ #include #if NOT_EXCLUDED(OP_conv3dnew) +#include #include -#include #include -#include +#include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) { - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); - - nd4j_debug("MKL-DNN is not used for conv3dnew!\n", 0); - - std::vector permutForOutput; - - if (isNCDHW) - permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC] - else - input = new NDArray(input->permute({0,4,1,2,3})); - - std::vector wAxes; - if(0 == wFormat) - wAxes = {3,0,1,2}; - else if(1 == wFormat) - wAxes = {1,2,3,4}; - else - wAxes = {4,1,2,3}; - - NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); - ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] - // [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC] - // [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, iC, kD, kH, kW] = [bS, oD, oH, oW, oC] - // [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, kD, kH, kW, iC] = [bS, oD, oH, oW, oC] - MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, wAxes, permutForOutput); - - if(bias) - // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCDHW); - - if(!isNCDHW) - delete input; - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM CONV3D OP: rank of input array must be equal to 5, but " + "got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM CONV3D OP: rank of weights array must be equal to 5, " + "but got %i instead !", + weights->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) + : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, + // kD, kH, kW], 2-[oC, kD, kH, kW, iC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + REQUIRE_TRUE(paddingMode < 2, 0, + "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not " + "allowed for this operation !"); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV3D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV3D OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW, paddingMode); + + nd4j_debug("MKL-DNN is not used for conv3dnew!\n", 0); + + std::vector permutForOutput; + + if (isNCDHW) + permutForOutput = {0, 2, 3, 4, + 1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC] + else + input = new NDArray(input->permute({0, 4, 1, 2, 3})); + + std::vector wAxes; + if (0 == wFormat) + wAxes = {3, 0, 1, 2}; + else if (1 == wFormat) + wAxes = {1, 2, 3, 4}; + else + wAxes = {4, 1, 2, 3}; + + NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, + input->dataType(), block.launchContext()); + ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, + dH, dW); // [bS, iC, iD, iH, iW] is convoluted to + // [bS, iC, kD, kH, kW, oD, oH, oW] + // [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, + // oC] [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, iC, kD, kH, kW] = [bS, oD, oH, + // oW, oC] [bS, iC, kD, kH, kW, oD, oH, oW] x [oC, kD, kH, kW, iC] = [bS, oD, + // oH, oW, oC] + MmulHelper::tensorDot(&columns, weights, output, {1, 2, 3, 4}, wAxes, + permutForOutput); + + if (bias) + // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCDHW); + + if (!isNCDHW) delete input; + + return Status::OK(); } - DECLARE_TYPES(conv3dnew) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(conv3dnew) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(conv3dnew) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(shape::sizeAt(weightsShapeInfo, 2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID; - int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - const int rank = 5; - REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); - - int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0); - if(!isNCDHW) { - indIOioC = 4; indIiD = 1; - } - else { - indIOioC = 1; indIiD = 2; - } - - int bS = inputShapeInfo[1]; // batch size - int iD = inputShapeInfo[indIiD+1]; // input depth - int iH = inputShapeInfo[indIiD+2]; // input height - int iW = inputShapeInfo[indIiD+3]; // input width - int iC = inputShapeInfo[indIOioC+1]; // input channels - int oC = weightsShapeInfo[indWoC+1]; // output channels - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - int oD, oH, oW; // output depth, height, width - ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); - - Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); - - outputShapeInfo[0] = rank; - outputShapeInfo[1] = bS; - if (isNCDHW) { - outputShapeInfo[2] = oC; - outputShapeInfo[3] = oD; - outputShapeInfo[4] = oH; - outputShapeInfo[5] = oW; - } else { - outputShapeInfo[2] = oD; - outputShapeInfo[3] = oH; - outputShapeInfo[4] = oW; - outputShapeInfo[5] = oC; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = inputShape->at( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] + + int kD = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 ? INT_ARG(2) + : static_cast(shape::sizeAt( + weightsShapeInfo, 2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID; + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) + : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, + // kD, kH, kW], 2-[oC, kD, kH, kW, iC] + + const int rank = 5; + REQUIRE_TRUE(paddingMode < 2, 0, + "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not " + "allowed for this operation !"); + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "CUSTOM CONV3D OP: rank of input array must be equal to %i, but " + "got %i instead !", + rank, inputShapeInfo); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM CONV3D OP: rank of weights array must be equal to %i, " + "but got %i instead !", + rank, weightsShapeInfo); + + int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0); + if (!isNCDHW) { + indIOioC = 4; + indIiD = 1; + } else { + indIOioC = 1; + indIiD = 2; + } + + int bS = inputShapeInfo[1]; // batch size + int iD = inputShapeInfo[indIiD + 1]; // input depth + int iH = inputShapeInfo[indIiD + 2]; // input height + int iW = inputShapeInfo[indIiD + 3]; // input width + int iC = inputShapeInfo[indIOioC + 1]; // input channels + int oC = weightsShapeInfo[indWoC + 1]; // output channels + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM CONV3D OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + int oD, oH, oW; // output depth, height, width + ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, + pH, pW, dD, dH, dW, iD, iH, iW, + paddingMode); + + Nd4jLong* outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(inputShapeInfo), Nd4jLong); + + outputShapeInfo[0] = rank; + outputShapeInfo[1] = bS; + if (isNCDHW) { + outputShapeInfo[2] = oC; + outputShapeInfo[3] = oD; + outputShapeInfo[4] = oH; + outputShapeInfo[5] = oW; + } else { + outputShapeInfo[2] = oD; + outputShapeInfo[3] = oH; + outputShapeInfo[4] = oW; + outputShapeInfo[5] = oC; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D_BP OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - int trueoD, trueoH, trueoW; // true output depth/height/width - ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); - - REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D_BP OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); - - nd4j_debug("MKL-DNN is not used for conv3dnew_bp!\n", 0); - - std::vector gradOaxesForDot; - - if(!isNCDHW) { - gradOaxesForDot = {0,1,2,3}; // bS, oD, oH, oW - input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradI = new NDArray(gradI->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - } - else { - gradOaxesForDot = {0,2,3,4}; // bS, oD, oH, oW - } - - std::vector wPermut, colPermut; - - if(0 == wFormat) { - wPermut = {3,0,1,2,4}; - colPermut = {2,3,4,1,0,5,6,7}; - } - else if(1 == wFormat) { - wPermut = {1,2,3,4,0}; - colPermut = {1,2,3,4,0,5,6,7}; - } - else { - wPermut = {4,1,2,3,0}; - colPermut = {2,3,4,1,0,5,6,7}; - } - - // ----- calculation of gradW and gradB ----- // - NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); - ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] - MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, wPermut); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC] - - //----- calculation of gradO -----// - if(gradB) { - if(gradB->rankOf() == 2) - gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); - gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot); // sum over bS oD oH oW - if(gradB != OUTPUT_VARIABLE(2)) - delete gradB; - } - - //----- calculation of gradI -----// - // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] - // [oC, iC, kD, kH, kW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] - // [oC, kD, kH, kW, iC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] - MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); - ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] - - if(!isNCDHW) { - delete input; - delete gradI; - } - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM CONV3D_BP OP: rank of weights array must be equal to 5, " + "but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 5, 0, + "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to 5, but got %i instead !", + gradO->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) + : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, + // kD, kH, kW], 2-[oC, kD, kH, kW, iC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + int trueoD, trueoH, trueoW; // true output depth/height/width + ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, + iW, paddingMode); + + REQUIRE_TRUE(paddingMode < 2, 0, + "CUSTOM CONV3D_BP OP: causal padding mode (paddingMode = 2) is " + "not allowed for this operation !"); + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, + 0, indIOioC, indIOioD, + indIOioD + 1, indIOioD + 2}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV3D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW, paddingMode); + + nd4j_debug("MKL-DNN is not used for conv3dnew_bp!\n", 0); + + std::vector gradOaxesForDot; + + if (!isNCDHW) { + gradOaxesForDot = {0, 1, 2, 3}; // bS, oD, oH, oW + input = new NDArray(input->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + gradI = new NDArray(gradI->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + } else { + gradOaxesForDot = {0, 2, 3, 4}; // bS, oD, oH, oW + } + + std::vector wPermut, colPermut; + + if (0 == wFormat) { + wPermut = {3, 0, 1, 2, 4}; + colPermut = {2, 3, 4, 1, 0, 5, 6, 7}; + } else if (1 == wFormat) { + wPermut = {1, 2, 3, 4, 0}; + colPermut = {1, 2, 3, 4, 0, 5, 6, 7}; + } else { + wPermut = {4, 1, 2, 3, 0}; + colPermut = {2, 3, 4, 1, 0, 5, 6, 7}; + } + + // ----- calculation of gradW and gradB ----- // + NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, + input->dataType(), block.launchContext()); + ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, + dH, dW); // [bS, iC, iD, iH, iW] is convoluted to + // [bS, iC, kD, kH, kW, oD, oH, oW] + MmulHelper::tensorDot( + &columns, gradO, gradW, {0, 5, 6, 7}, gradOaxesForDot, + wPermut); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, + // oC, oD, oH, oW] = [iC, kD, kH, kW, oC] + + //----- calculation of gradO -----// + if (gradB) { + if (gradB->rankOf() == 2) + gradB = new NDArray( + gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); + gradO->reduceAlongDimension(reduce::Sum, *gradB, + gradOaxesForDot); // sum over bS oD oH oW + if (gradB != OUTPUT_VARIABLE(2)) delete gradB; + } + + //----- calculation of gradI -----// + // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, + // kW, iC, bS, oD, oH, oW] [oC, iC, kD, kH, kW] x [bS, oD, oH, oW, oC]/[bS, + // oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] [oC, kD, kH, kW, iC] x + // [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, + // oW] + MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, + colPermut); + ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, + dH, + dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is + // de-convoluted to [bS, iC, iD, iH, iW] + + if (!isNCDHW) { + delete input; + delete gradI; + } + + return Status::OK(); } - DECLARE_TYPES(conv3dnew_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - +DECLARE_TYPES(conv3dnew_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(conv3dnew_bp) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - Nd4jLong const* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] - Nd4jLong const* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(shape::sizeAt(weightsShapeInfo, 2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - const int rank = 5; - REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo); - REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo); - - int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0); - if(!isNCDHW) { - indIOioC = 4; indIiD = 1; - } - else { - indIOioC = 1; indIiD = 2; - } - - int bS = inputShapeInfo[1]; // batch size - int iD = inputShapeInfo[indIiD+1]; // input depth - int iH = inputShapeInfo[indIiD+2]; // input height - int iW = inputShapeInfo[indIiD+3]; // input width - int iC = inputShapeInfo[indIOioC+1]; // input channels - int oC = weightsShapeInfo[indWoC+1]; // output channels - - int trueoD, trueoH, trueoW; // true output depth/height/width - ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if(biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.workspace()); - auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.workspace()); - - if(biasShapeInfo) { - auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo)); - } - - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); -} -} + auto inputShapeInfo = inputShape->at( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + Nd4jLong const* biasShapeInfo = + block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] + Nd4jLong const* gradOShapeInfo = + block.width() > 3 + ? inputShape->at(3) + : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + int kD = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 ? INT_ARG(2) + : static_cast(shape::sizeAt( + weightsShapeInfo, 2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 ? INT_ARG(14) + : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, + // kD, kH, kW], 2-[oC, kD, kH, kW, iC] + + const int rank = 5; + REQUIRE_TRUE(paddingMode < 2, 0, + "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not " + "allowed for this operation !"); + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "CUSTOM CONV3D_BP OP: rank of input array must be equal to %i, " + "but got %i instead !", + rank, inputShapeInfo); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM CONV3D_BP OP: rank of weights array must be equal to " + "%i, but got %i instead !", + rank, weightsShapeInfo); + REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, + "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to %i, but got %i instead !", + rank, gradOShapeInfo); + + int indIOioC, indIiD, indWoC(0 == wFormat ? 4 : 0); + if (!isNCDHW) { + indIOioC = 4; + indIiD = 1; + } else { + indIOioC = 1; + indIiD = 2; + } + + int bS = inputShapeInfo[1]; // batch size + int iD = inputShapeInfo[indIiD + 1]; // input depth + int iH = inputShapeInfo[indIiD + 2]; // input height + int iW = inputShapeInfo[indIiD + 3]; // input width + int iC = inputShapeInfo[indIOioC + 1]; // input channels + int oC = weightsShapeInfo[indWoC + 1]; // output channels + + int trueoD, trueoH, trueoW; // true output depth/height/width + ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, + iW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, + 0, indIOioC, indIiD, indIiD + 1, + indIiD + 2}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), + 0, + "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM CONV3D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, gradOShapeInfo, false, block.workspace()); + + if (biasShapeInfo) { + auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), + CONSTANT(gradBshapeInfo)); + } + + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp index 4db9e2548482..1be15405cecd 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d.cpp @@ -22,308 +22,473 @@ #include #if NOT_EXCLUDED(OP_deconv2d) -#include #include +#include +#include +#include #include #include -#include -#include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(!isNCHW) - output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - - std::vector colPermut; - if(1 == wFormat) - colPermut = {1, 2, 3, 0, 4, 5}; - else - colPermut = {2, 3, 1, 0, 4, 5}; - - if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext()); - - //----- calculation of output -----// - // NHWC: [kH, kW, oC, iC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW] - // NHWC: [iC, oC, kH, kW] x [bS, iH, iW, iC] = [oC, kH, kW, bS, iH, iW] - // NHWC: [iC, kH, kW, oC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW] - sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, colPermut); - LaunchContext* ctx = block.launchContext(); - helpers::col2im(*ctx, columns, *output, sH, sW, pH, pW, oH, oW, dH, dW); // [bS, oC, kH, kW, iH, iW] is de-convoluted to [bS, oC, oH, oW] - - //----- add biases if required -----// - if(bias) - // output->applyBroadcast(broadcast::Add, {1}, bias); - helpers::addBias(block, *output, *bias, *output, true); - - if(!isNCHW) - delete output; - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = OUTPUT_NULLIFIED( + 0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM DECONV2D OP: rank of input array must be equal to 4, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM DECONV2D OP: rank of weights array must be equal to 4, " + "but got %i instead !", + weights->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, + // oC, kH, kW], 2 - [iC, kH, kW, oC] + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWoC, indWiC, indWkH, indOoH); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV2D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV2D OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (!isNCHW) + output = new NDArray( + output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] + + std::vector colPermut; + if (1 == wFormat) + colPermut = {1, 2, 3, 0, 4, 5}; + else + colPermut = {2, 3, 1, 0, 4, 5}; + + if (isSameMode) // Note: we're intentionally swapping iH and oH, to + // calculated the padding for a"normal" conv (not deconv) + // forward pass + ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, + dW); + + NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, + input->dataType(), block.launchContext()); + + //----- calculation of output -----// + // NHWC: [kH, kW, oC, iC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW] + // NHWC: [iC, oC, kH, kW] x [bS, iH, iW, iC] = [oC, kH, kW, bS, iH, iW] + // NHWC: [iC, kH, kW, oC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW] + sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, + colPermut); + LaunchContext* ctx = block.launchContext(); + helpers::col2im( + *ctx, columns, *output, sH, sW, pH, pW, oH, oW, dH, + dW); // [bS, oC, kH, kW, iH, iW] is de-convoluted to [bS, oC, oH, oW] + + //----- add biases if required -----// + if (bias) + // output->applyBroadcast(broadcast::Add, {1}, bias); + helpers::addBias(block, *output, *bias, *output, true); + + if (!isNCHW) delete output; + + return Status::OK(); +} +DECLARE_TYPES(deconv2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_TYPES(deconv2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } DECLARE_SHAPE_FN(deconv2d) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] - - const int rank = 4; - REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo)); - REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo)); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - int indIOioC, indIiH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3)); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; - } - else { - indIOioC = 1; indIiH = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iH = inputShapeInfo[indIiH+1]; // input height - const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); - REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) - REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - int oH, oW; // output height, width - ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - Nd4jLong outputShape[4]; - - outputShape[0] = bS; - if (isNCHW) { - outputShape[1] = oC; - outputShape[2] = oH; - outputShape[3] = oW; - } else { - outputShape[1] = oH; - outputShape[2] = oW; - outputShape[3] = oC; - } - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(weightsShapeInfo), shape::order(inputShapeInfo), outputShape, 4))); + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] + + const int rank = 4; + REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, + "CUSTOM DECONV2D OP: rank of input array must be equal to %i, " + "but got %i instead !", + rank, shape::rank(inputShapeInfo)); + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, + "CUSTOM DECONV2D OP: rank of weights array must be equal to %i, " + "but got %i instead !", + rank, shape::rank(weightsShapeInfo)); + + int kH = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, + // oC, kH, kW], 2 - [iC, kH, kW, oC] + + int indIOioC, indIiH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3)); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + } else { + indIOioC = 1; + indIiH = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iH = inputShapeInfo[indIiH + 1]; // input height + const int iW = inputShapeInfo[indIiH + 2]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); + REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), + shape::rank(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo)), + 0, + "CUSTOM DECONV2D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE( + shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), + 0, + "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, " + "length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + int oH, oW; // output height, width + ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + Nd4jLong outputShape[4]; + + outputShape[0] = bS; + if (isNCHW) { + outputShape[1] = oC; + outputShape[2] = oH; + outputShape[3] = oW; + } else { + outputShape[1] = oH; + outputShape[2] = oW; + outputShape[3] = oC; + } + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(weightsShapeInfo), + shape::order(inputShapeInfo), outputShape, 4))); } - DECLARE_TYPES(deconv2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(deconv2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D_BP OP: rank of weights array must be equal to 4 , but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM DECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(isSameMode){ // SAME - //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); - } - - // ----- calculation of gradI -> pass it through conv2d_ff ----- // - sd::ops::conv2d conv2d; - const Nd4jStatus status = conv2d.execute({gradO, weights}, {gradI}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, !isNCHW, wFormat}, {}); - if (status != ND4J_STATUS_OK) - return status; - - // -----prepare permutation arrays and axes for dot product ----- // - std::vector inputAxes; - - if(!isNCHW) { - gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - inputAxes = {0, 1, 2}; // bS, iH, iW - } - else - inputAxes = {0, 2, 3}; // bS, iH, iW - - std::vector gradWAxes; // empty for wFormat = 1 - if(0 == wFormat) - gradWAxes = {3, 2, 0, 1}; - else if(2 == wFormat) - gradWAxes = {0, 3, 1, 2}; - - // ----- calculation of gradW ----- // - NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext()); - - LaunchContext* ctx = block.launchContext(); - helpers::im2col(*ctx, *gradO, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, oC, oH, oW] is convoluted to [bS, oC, kH, kW, iH, iW] - MmulHelper::tensorDot(input, &columns, gradW, inputAxes, {0, 4, 5}, gradWAxes); // [bS, iC, iH, iW]/[bS, iH, iW, iC] x [bS, oC, kH, kW, iH, iW] = [iC, oC, kH, kW] - - // ----- calculation of gradB ----- // - if(gradB) { - if(gradB->rankOf() == 2) - gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}, false)); - gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3}); // sum over bS, oH, oW - if(gradB != OUTPUT_VARIABLE(2)) - delete gradB; - } - - if(!isNCHW) - delete gradO; - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI + auto gradW = OUTPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM DECONV2D_BP OP: rank of input array must be equal to 4, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM DECONV2D_BP OP: rank of weights array must be equal to " + "4 , but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "CUSTOM DECONV2D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to 4, but got %i instead !", + gradO->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, + // oC, kH, kW], 2 - [iC, kH, kW, oC] + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWoC, indWiC, indWkH, indOoH); + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DECONV2D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV2D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (isSameMode) { // SAME + // Note: we're intentionally swapping iH and oH, to calculated the padding + // for a"normal" conv (not deconv) forward pass + ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, + dW); + } + + // ----- calculation of gradI -> pass it through conv2d_ff ----- // + sd::ops::conv2d conv2d; + const Nd4jStatus status = conv2d.execute( + {gradO, weights}, {gradI}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, !isNCHW, wFormat}, {}); + if (status != ND4J_STATUS_OK) return status; + + // -----prepare permutation arrays and axes for dot product ----- // + std::vector inputAxes; + + if (!isNCHW) { + gradO = new NDArray( + gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] + inputAxes = {0, 1, 2}; // bS, iH, iW + } else + inputAxes = {0, 2, 3}; // bS, iH, iW + + std::vector gradWAxes; // empty for wFormat = 1 + if (0 == wFormat) + gradWAxes = {3, 2, 0, 1}; + else if (2 == wFormat) + gradWAxes = {0, 3, 1, 2}; + + // ----- calculation of gradW ----- // + NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, + input->dataType(), block.launchContext()); + + LaunchContext* ctx = block.launchContext(); + helpers::im2col( + *ctx, *gradO, columns, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, oC, oH, oW] is convoluted to [bS, + // oC, kH, kW, iH, iW] + MmulHelper::tensorDot(input, &columns, gradW, inputAxes, {0, 4, 5}, + gradWAxes); // [bS, iC, iH, iW]/[bS, iH, iW, iC] x [bS, + // oC, kH, kW, iH, iW] = [iC, oC, kH, kW] + + // ----- calculation of gradB ----- // + if (gradB) { + if (gradB->rankOf() == 2) + gradB = new NDArray( + gradB->reshape(gradB->ordering(), {gradB->lengthOf()}, false)); + gradO->reduceAlongDimension(reduce::Sum, *gradB, + {0, 2, 3}); // sum over bS, oH, oW + if (gradB != OUTPUT_VARIABLE(2)) delete gradB; + } + + if (!isNCHW) delete gradO; + + return Status::OK(); } DECLARE_SHAPE_FN(deconv2d_bp) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - Nd4jLong const* biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] - auto gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - const int rank = 4; - REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo)); - REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV2D_BP OP: rank of weights array must be equal to %i , but got %i instead !", rank, shape::rank(weightsShapeInfo)); - REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo)); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - int indIOioC, indIiH, indOoH, indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3)); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; indOoH = 1; - } - else { - indIOioC = 1; indIiH = 2; indOoH = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iH = inputShapeInfo[indIiH+1]; // input height - const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); - REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if(biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.workspace()); - auto gradWShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.workspace()); - - auto shapes = SHAPELIST(CONSTANT(gradIShapeInfo), CONSTANT(gradWShapeInfo)); - - if (biasShapeInfo != nullptr) { - auto gradBShapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); - shapes->push_back(CONSTANT(gradBShapeInfo)); - } - - return shapes; + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + Nd4jLong const* biasShapeInfo = + block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] + auto gradOShapeInfo = + block.width() > 3 + ? inputShape->at(3) + : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] + // (NCDHW), epsilon_next + + const int rank = 4; + REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, + "CUSTOM DECONV2D_BP OP: rank of input array must be equal to " + "%i, but got %i instead !", + rank, shape::rank(inputShapeInfo)); + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, + "CUSTOM DECONV2D_BP OP: rank of weights array must be equal to " + "%i , but got %i instead !", + rank, shape::rank(weightsShapeInfo)); + REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, + "CUSTOM DECONV2D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to %i, but got %i instead !", + rank, shape::rank(gradOShapeInfo)); + + int kH = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, + // oC, kH, kW], 2 - [iC, kH, kW, oC] + + int indIOioC, indIiH, indOoH, + indWoC(0 == wFormat ? 2 : (1 == wFormat ? 1 : 3)); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + indOoH = 1; + } else { + indIOioC = 1; + indIiH = 2; + indOoH = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iH = inputShapeInfo[indIiH + 1]; // input height + const int iW = inputShapeInfo[indIiH + 2]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); + REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), + shape::rank(gradOShapeInfo), + shape::shapeOf(gradOShapeInfo)), + 0, + "CUSTOM DECONV2D_BP OP: wrong shape of output gradients next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), + shape::rank(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo)), + 0, + "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM DECONV2D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, gradOShapeInfo, false, block.workspace()); + + auto shapes = SHAPELIST(CONSTANT(gradIShapeInfo), CONSTANT(gradWShapeInfo)); + + if (biasShapeInfo != nullptr) { + auto gradBShapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + shapes->push_back(CONSTANT(gradBShapeInfo)); + } + + return shapes; } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp index 8eb361118c47..85042e84e4a2 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv2d_tf.cpp @@ -26,128 +26,207 @@ #include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) { - - auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI) - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - const int rank = gradO->rankOf(); - - REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM DECONV2D_TF OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf()); - REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf()); - - // create empty conv2d input array - NDArray input(gradO->ordering(), gradIShape->asVectorT(), gradO->dataType(), block.launchContext()); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - - ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); - - return Status::OK(); + auto gradO = INPUT_VARIABLE( + 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradIShape = INPUT_VARIABLE( + 0); // [4] - shape of input of conv2d (that is shape of gradI) + + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + const int rank = gradO->rankOf(); + + REQUIRE_TRUE(weights->rankOf() == rank, 0, + "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to " + "4, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, + "CUSTOM DECONV2D_TF OP: rank of array with output shape must be " + "equal to 1, but got %i instead !", + gradIShape->rankOf()); + REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, + "CUSTOM DECONV2D_TF OP: length of array with output shape must " + "be equal to 4, but got %i instead !", + gradIShape->lengthOf()); + + // create empty conv2d input array + NDArray input(gradO->ordering(), gradIShape->asVectorT(), + gradO->dataType(), block.launchContext()); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on " + "array with output shape expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + + ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, + nullptr, nullptr, kH, kW, sH, sW, pH, pW, dH, dW, + isSameMode, isNCHW, wFormat); + + return Status::OK(); } - DECLARE_TYPES(deconv2d_tf) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(deconv2d_tf) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(deconv2d_tf) { - - auto gradOShapeInfo = inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradIShapeShapeInfo = inputShape->at(0); // [4] - - const int rank = 4; - - REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo)); - REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DECONV2D_TF OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo)); - REQUIRE_TRUE(shape::rank(gradIShapeShapeInfo) == 1, 0, "CUSTOM DECONV2D_TF OP: rank of array with output shape must be equal to %i, but got %i instead !", 1, shape::rank(gradIShapeShapeInfo)); - - const int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height - const int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width - const int sH = INT_ARG(2); // strides height - const int sW = INT_ARG(3); // strides width - const int pH = INT_ARG(4); // paddings height - const int pW = INT_ARG(5); // paddings width - const int dH = INT_ARG(6); // dilations height - const int dW = INT_ARG(7); // dilations width - const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - const int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - const int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH; - if(!isNCHW) { - indIOioC = 3; indIiH = 1; indOoH = 1; - } - else { - indIOioC = 1; indIiH = 2; indOoH = 2; - } - - std::vector gradIShape = INPUT_VARIABLE(0)->template asVectorT(); - - const int bS = gradIShape[0]; // batch size - const int iH = gradIShape[indIiH]; // input height - const int iW = gradIShape[indIiH+1]; // input width - const int iC = gradIShape[indIOioC]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - const int oH = gradOShapeInfo[indOoH+1]; // input height - const int oW = gradOShapeInfo[indOoH+2]; // input width - - int trueiH, trueiW; // output height, width - ConvolutionUtils::calcOutSizeDeconv2D(trueiH, trueiW, kH, kW, sH, sW, pH, pW, dH, dW, oH, oW, isSameMode); - - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,trueiH,trueiW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(expectedGradIShape == gradIShape, 0, "CUSTOM DECONV2D_TF OP: wrong shape of array with output shape, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradIShape).c_str()); - REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - - Nd4jLong shape[4]; - shape[0] = bS; - - if (isNCHW) { - shape[1] = iC; - shape[2] = iH; - shape[3] = iW; - } else { - shape[1] = iH; - shape[2] = iW; - shape[3] = iC; - } - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(weightsShapeInfo), shape::order(gradOShapeInfo), 4, shape)); + auto gradOShapeInfo = inputShape->at( + 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto weightsShapeInfo = inputShape->at( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradIShapeShapeInfo = inputShape->at(0); // [4] + + const int rank = 4; + + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, + "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to " + "%i, but got %i instead !", + rank, shape::rank(weightsShapeInfo)); + REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, + "CUSTOM DECONV2D_TF OP: rank of input array must be equal to " + "%i, but got %i instead !", + rank, shape::rank(gradOShapeInfo)); + REQUIRE_TRUE(shape::rank(gradIShapeShapeInfo) == 1, 0, + "CUSTOM DECONV2D_TF OP: rank of array with output shape must be " + "equal to %i, but got %i instead !", + 1, shape::rank(gradIShapeShapeInfo)); + + const int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) height + const int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) width + const int sH = INT_ARG(2); // strides height + const int sW = INT_ARG(3); // strides width + const int pH = INT_ARG(4); // paddings height + const int pW = INT_ARG(5); // paddings width + const int dH = INT_ARG(6); // dilations height + const int dW = INT_ARG(7); // dilations width + const int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + const int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + const int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + int indIOioC, indIiH, indWoC(0 == wFormat ? 3 : 0), indOoH; + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + indOoH = 1; + } else { + indIOioC = 1; + indIiH = 2; + indOoH = 2; + } + + std::vector gradIShape = + INPUT_VARIABLE(0)->template asVectorT(); + + const int bS = gradIShape[0]; // batch size + const int iH = gradIShape[indIiH]; // input height + const int iW = gradIShape[indIiH + 1]; // input width + const int iC = gradIShape[indIOioC]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + const int oH = gradOShapeInfo[indOoH + 1]; // input height + const int oW = gradOShapeInfo[indOoH + 2]; // input width + + int trueiH, trueiW; // output height, width + ConvolutionUtils::calcOutSizeDeconv2D(trueiH, trueiW, kH, kW, sH, sW, pH, pW, + dH, dW, oH, oW, isSameMode); + + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, trueiH, trueiW, 0, indIOioC, indIiH, indIiH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(expectedGradIShape == gradIShape, 0, + "CUSTOM DECONV2D_TF OP: wrong shape of array with output shape, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradIShape).c_str()); + REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), + shape::rank(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo)), + 0, + "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + + Nd4jLong shape[4]; + shape[0] = bS; + + if (isNCHW) { + shape[1] = iC; + shape[2] = iH; + shape[3] = iW; + } else { + shape[1] = iH; + shape[2] = iW; + shape[3] = iC; + } + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(weightsShapeInfo), shape::order(gradOShapeInfo), 4, + shape)); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp index 6d5a483b892d..648f088d6eb5 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/deconv3d.cpp @@ -21,334 +21,514 @@ #include #if NOT_EXCLUDED(OP_deconv3d) +#include #include -#include #include -#include +#include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2)); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(!isNCDHW) - output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] - - std::vector colPermut; - if(1 == wFormat) - colPermut = {1,2,3,4,0,5,6,7}; - else - colPermut = {2,3,4,1,0,5,6,7}; - - if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - NDArray columns(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext()); - - //----- calculation of output -----// - // [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW] - // [iC, oC, kD, kH, kW] x [bS, iD, iH, iW, iC] = [oC, kD, kH, kW, bS, iD, iH, iW] - // [iC, kD, kH, kW, oC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW] - sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, colPermut); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW] - ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW] - - //----- add biases if required -----// - if(bias) - // output->applyBroadcast(broadcast::Add,{1}, bias); - helpers::addBias(block, *output, *bias, *output, true); - - if(!isNCDHW) - delete output; - - return Status::OK(); - + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM DECONV3D OP: rank of input array must be equal to 5, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM DECONV3D OP: rank of weights array must be equal to 5, " + "but got %i instead !", + weights->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], + // 2 - [iC, kD, kH, kW, oC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWoC, indWiC, indWkD); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV3D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV3D OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (!isNCDHW) + output = new NDArray(output->permute( + {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] + + std::vector colPermut; + if (1 == wFormat) + colPermut = {1, 2, 3, 4, 0, 5, 6, 7}; + else + colPermut = {2, 3, 4, 1, 0, 5, 6, 7}; + + if (isSameMode) // Note: we're intentionally swapping iH and oH, to + // calculated the padding for a"normal" conv (not deconv) + // forward pass + ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + NDArray columns(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, + input->dataType(), block.launchContext()); + + //----- calculation of output -----// + // [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, + // iW] [iC, oC, kD, kH, kW] x [bS, iD, iH, iW, iC] = [oC, kD, kH, kW, bS, iD, + // iH, iW] [iC, kD, kH, kW, oC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, + // iD, iH, iW] + sd::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, + colPermut); // [bS, oC, kD, kH, kW, iD, iH, iW] -> + // [kD, kH, kW, oC, bS, iD, iH, iW] + ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, + dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is + // de-convoluted to [bS, oC, oD, oH, oW] + + //----- add biases if required -----// + if (bias) + // output->applyBroadcast(broadcast::Add,{1}, bias); + helpers::addBias(block, *output, *bias, *output, true); + + if (!isNCDHW) delete output; + + return Status::OK(); } - DECLARE_TYPES(deconv3d) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(deconv3d) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(deconv3d) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NDCHW) - auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] - - const int rank = 5; - REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo)); - REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo)); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(shape::sizeAt(weightsShapeInfo, 2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4)); - if(!isNCDHW) { - indIOioC = 4; indIiD = 1; - } - else { - indIOioC = 1; indIiD = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iD = inputShapeInfo[indIiD+1]; // input depth - const int iH = inputShapeInfo[indIiD+2]; // input height - const int iW = inputShapeInfo[indIiD+3]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); - REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) - REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); - - int oD, oH, oW; // output depth, height, width - ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); - - Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); - - outputShapeInfo[0] = rank; - outputShapeInfo[1] = bS; - - if (isNCDHW) { - outputShapeInfo[2] = oC; - outputShapeInfo[3] = oD; - outputShapeInfo[4] = oH; - outputShapeInfo[5] = oW; - } else { - outputShapeInfo[2] = oD; - outputShapeInfo[3] = oH; - outputShapeInfo[4] = oW; - outputShapeInfo[5] = oC; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = inputShape->at( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NDCHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] + + const int rank = 5; + REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, + "CUSTOM DECONV3D OP: rank of input array must be equal to %i, " + "but got %i instead !", + rank, shape::rank(inputShapeInfo)); + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, + "CUSTOM DECONV3D OP: rank of weights array must be equal to %i, " + "but got %i instead !", + rank, shape::rank(weightsShapeInfo)); + + int kD = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 ? INT_ARG(2) + : static_cast(shape::sizeAt( + weightsShapeInfo, 2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], + // 2 - [iC, kD, kH, kW, oC] + + int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4)); + if (!isNCDHW) { + indIOioC = 4; + indIiD = 1; + } else { + indIOioC = 1; + indIiD = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iD = inputShapeInfo[indIiD + 1]; // input depth + const int iH = inputShapeInfo[indIiD + 2]; // input height + const int iW = inputShapeInfo[indIiD + 3]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); + REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), + shape::rank(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo)), + 0, + "CUSTOM DECONV3D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE( + shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), + 0, + "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, " + "length: <=2, %i, but got %i, %i instead !", + oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); + + int oD, oH, oW; // output depth, height, width + ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, + pH, pW, dD, dH, dW, iD, iH, iW, + isSameMode); + + Nd4jLong* outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(inputShapeInfo), Nd4jLong); + + outputShapeInfo[0] = rank; + outputShapeInfo[1] = bS; + + if (isNCDHW) { + outputShapeInfo[2] = oC; + outputShapeInfo[3] = oD; + outputShapeInfo[4] = oH; + outputShapeInfo[5] = oW; + } else { + outputShapeInfo[2] = oD; + outputShapeInfo[3] = oH; + outputShapeInfo[4] = oW; + outputShapeInfo[5] = oC; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D_BP OP: rank of weights array must be equal to 5 , but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM DECONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf()); - - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - - int trueoD, trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - // ----- calculation of gradI -> pass it through conv3d_ff ----- // - sd::ops::conv3dnew conv3d; - const Nd4jStatus status = conv3d.execute({gradO, weights}, {gradI}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, isSameMode, !isNCDHW, wFormat}, {}); - if (status != ND4J_STATUS_OK) - return status; - - // -----prepare permutation arrays and axes for dot product ----- // - std::vector inputAxesForDot; - - if(!isNCDHW) { - gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] - inputAxesForDot = {0, 1, 2, 3}; // bS, iD, iH, iW - } - else - inputAxesForDot = {0, 2, 3, 4}; // bS, iD, iH, iW - - std::vector gradWAxes; // empty for wFormat = 1 - if(0 == wFormat) - gradWAxes = {4,3,0,1,2}; - else if(2 == wFormat) - gradWAxes = {0,4,1,2,3}; - - // ----- calculation of gradW ----- // - auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext()); - ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW] - MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, gradWAxes); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW] - - // ----- calculation of gradB ----- // - if(gradB) { - if(gradB->rankOf() == 2) - gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); - gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW - if(gradB != OUTPUT_VARIABLE(2)) - delete gradB; - } - - if(!isNCDHW) - delete gradO; - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), gradI + auto gradW = OUTPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM DECONV3D_BP OP: rank of input array must be equal to 5, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM DECONV3D_BP OP: rank of weights array must be equal to " + "5 , but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 5, 0, + "CUSTOM DECONV3D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to 5, but got %i instead !", + gradO->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], + // 2 - [iC, kD, kH, kW, oC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWoC, indWiC, indWkD); + + int trueoD, trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, + iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, + 0, indIOioC, indIOioD, + indIOioD + 1, indIOioD + 2}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (isSameMode) // Note: we're intentionally swapping iH and oH, to + // calculated the padding for a"normal" conv (not deconv) + // forward pass + ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + // ----- calculation of gradI -> pass it through conv3d_ff ----- // + sd::ops::conv3dnew conv3d; + const Nd4jStatus status = + conv3d.execute({gradO, weights}, {gradI}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + isSameMode, !isNCDHW, wFormat}, + {}); + if (status != ND4J_STATUS_OK) return status; + + // -----prepare permutation arrays and axes for dot product ----- // + std::vector inputAxesForDot; + + if (!isNCDHW) { + gradO = new NDArray(gradO->permute( + {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] + inputAxesForDot = {0, 1, 2, 3}; // bS, iD, iH, iW + } else + inputAxesForDot = {0, 2, 3, 4}; // bS, iD, iH, iW + + std::vector gradWAxes; // empty for wFormat = 1 + if (0 == wFormat) + gradWAxes = {4, 3, 0, 1, 2}; + else if (2 == wFormat) + gradWAxes = {0, 4, 1, 2, 3}; + + // ----- calculation of gradW ----- // + auto columns = NDArrayFactory::create( + input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), + block.launchContext()); + ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, + dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to + // [bS, oC, kD, kH, kW, iD, iH, iW] + MmulHelper::tensorDot( + input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, + gradWAxes); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, + // kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW] + + // ----- calculation of gradB ----- // + if (gradB) { + if (gradB->rankOf() == 2) + gradB = new NDArray( + gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); + gradO->reduceAlongDimension(reduce::Sum, *gradB, + {0, 2, 3, 4}); // sum over bS, oD, oH, oW + if (gradB != OUTPUT_VARIABLE(2)) delete gradB; + } + + if (!isNCDHW) delete gradO; + + return Status::OK(); } - DECLARE_TYPES(deconv3d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(deconv3d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(deconv3d_bp) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weightsShapeInfo = inputShape->at(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] - auto gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - const int rank = 5; - REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV3D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo)); - REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV3D_BP OP: rank of weights array must be equal to %i , but got %i instead !", rank, shape::rank(weightsShapeInfo)); - REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DECONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo)); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(shape::sizeAt(weightsShapeInfo, 2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4)); - if(!isNCDHW) { - indIOioC = 4; indIiD = 1; - } - else { - indIOioC = 1; indIiD = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iD = inputShapeInfo[indIiD+1]; // input depth - const int iH = inputShapeInfo[indIiD+2]; // input height - const int iW = inputShapeInfo[indIiD+3]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - int trueoD, trueoH, trueoW; // true output depth, height, width - ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); - REQUIRE_TRUE(shape::shapeEquals(5, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if(biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.workspace()); - auto gradWShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.workspace()); - - auto shapes = SHAPELIST(CONSTANT(gradIShapeInfo), CONSTANT(gradWShapeInfo)); - - if (biasShapeInfo != nullptr) { - auto gradBShapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); - shapes->push_back(CONSTANT(gradBShapeInfo)); - } - - return shapes; + auto inputShapeInfo = inputShape->at( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; // [oC] + auto gradOShapeInfo = + block.width() > 3 + ? inputShape->at(3) + : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + const int rank = 5; + REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, + "CUSTOM DECONV3D_BP OP: rank of input array must be equal to " + "%i, but got %i instead !", + rank, shape::rank(inputShapeInfo)); + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, + "CUSTOM DECONV3D_BP OP: rank of weights array must be equal to " + "%i , but got %i instead !", + rank, shape::rank(weightsShapeInfo)); + REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, + "CUSTOM DECONV3D_BP OP: rank of output gradients (next epsilon) " + "array must be equal to %i, but got %i instead !", + rank, shape::rank(gradOShapeInfo)); + + int kD = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 ? INT_ARG(2) + : static_cast(shape::sizeAt( + weightsShapeInfo, 2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], + // 2 - [iC, kD, kH, kW, oC] + + int indIOioC, indIiD, indWoC(0 == wFormat ? 3 : (1 == wFormat ? 1 : 4)); + if (!isNCDHW) { + indIOioC = 4; + indIiD = 1; + } else { + indIOioC = 1; + indIiD = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iD = inputShapeInfo[indIiD + 1]; // input depth + const int iH = inputShapeInfo[indIiD + 2]; // input height + const int iW = inputShapeInfo[indIiD + 3]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + int trueoD, trueoH, trueoW; // true output depth, height, width + ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, + iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, + 0, indIOioC, indIiD, indIiD + 1, + indIiD + 2}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); + REQUIRE_TRUE(shape::shapeEquals(5, expectedGradOShape.data(), + shape::rank(gradOShapeInfo), + shape::shapeOf(gradOShapeInfo)), + 0, + "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), + shape::rank(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo)), + 0, + "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWShapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, gradOShapeInfo, false, block.workspace()); + + auto shapes = SHAPELIST(CONSTANT(gradIShapeInfo), CONSTANT(gradWShapeInfo)); + + if (biasShapeInfo != nullptr) { + auto gradBShapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + shapes->push_back(CONSTANT(gradBShapeInfo)); + } + + return shapes; } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp index b28601c678dc..133f222e5bef 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/depthwiseConv2d.cpp @@ -21,240 +21,395 @@ #include #if NOT_EXCLUDED(OP_depthwise_conv2d) -#include #include #include - +#include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC - - auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.numI()> 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW,wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC + + auto output = OUTPUT_NULLIFIED( + 0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal " + "to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal " + "to 4, but got %i instead !", + weights->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DEPTHWISECONV2D OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + REQUIRE_TRUE(output->sizeAt(indIOioC) == iC * mC, 0, + "CUSTOM DEPTHWISECONV2D OP: the output_channels must be equal " + "to input_channels * channels_multiplier = %i !", + iC * mC); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH, kW, + sH, sW, pH, pW, dH, dW, isSameMode, isNCHW, + wFormat); + + return Status::OK(); } - DECLARE_TYPES(depthwise_conv2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(depthwise_conv2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(depthwise_conv2d) { - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsShapeInfo = inputShape->at(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] = iC*mC - - const int rank = 4; - REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; - } - else { - indIOioC = 1; indIiH = 2; - } - - const int bS = shape::sizeAt(inputShapeInfo, 0); // batch size - const int iH = shape::sizeAt(inputShapeInfo, indIiH); // input height - const int iW = shape::sizeAt(inputShapeInfo, indIiH+1); // input width - const int iC = shape::sizeAt(inputShapeInfo, indIOioC); // input channels - const int mC = shape::sizeAt(weightsShapeInfo, indWmC); // channels multiplier(oC = iC*mC) - const int oC = iC*mC; // output channels - - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "DEPTHWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) - REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); - - int oH, oW; // output height, width - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); - - outputShapeInfo[0] = rank; - outputShapeInfo[1] = bS; - - if (isNCHW) { - outputShapeInfo[2] = oC; - outputShapeInfo[3] = oH; - outputShapeInfo[4] = oW; - } else { - outputShapeInfo[2] = oH; - outputShapeInfo[3] = oW; - outputShapeInfo[4] = oC; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weightsShapeInfo = inputShape->at( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto biasShapeInfo = + block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] = iC*mC + + const int rank = 4; + REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, + "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal " + "to %i, but got %i instead !", + rank, inputShapeInfo[0]); + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, + "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal " + "to %i, but got %i instead !", + rank, weightsShapeInfo[0]); + + int kH = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + } else { + indIOioC = 1; + indIiH = 2; + } + + const int bS = shape::sizeAt(inputShapeInfo, 0); // batch size + const int iH = shape::sizeAt(inputShapeInfo, indIiH); // input height + const int iW = shape::sizeAt(inputShapeInfo, indIiH + 1); // input width + const int iC = shape::sizeAt(inputShapeInfo, indIOioC); // input channels + const int mC = shape::sizeAt(weightsShapeInfo, + indWmC); // channels multiplier(oC = iC*mC) + const int oC = iC * mC; // output channels + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), + shape::rank(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo)), + 0, + "DEPTHWISECONV2D OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE( + shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), + 0, + "DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, " + "length: <=2, %i, but got %i, %i instead !", + oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); + + int oH, oW; // output height, width + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + Nd4jLong* outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(inputShapeInfo), Nd4jLong); + + outputShapeInfo[0] = rank; + outputShapeInfo[1] = bS; + + if (isNCHW) { + outputShapeInfo[2] = oC; + outputShapeInfo[3] = oH; + outputShapeInfo[4] = oW; + } else { + outputShapeInfo[2] = oH; + outputShapeInfo[3] = oW; + outputShapeInfo[4] = oC; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - DECLARE_TYPES(depthwise_conv2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(depthwise_conv2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - int trueoH, trueoW; // correct output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = + block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_NULLIFIED( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be " + "equal to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be " + "equal to 4, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D_BP OP: rank of output gradients (next " + "epsilon) array must be equal to 4, but got %i instead !", + gradO->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + int trueoH, trueoW; // correct output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE( + bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, + gradW, gradB, kH, kW, sH, sW, pH, pW, dH, + dW, isSameMode, isNCHW, wFormat); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(depthwise_conv2d_bp) { - auto inputShapeInfo = inputShape->at(0); - auto weightsShapeInfo = inputShape->at(1); - auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; - auto gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); - - const int rank = 4; - REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo)); - REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo)); - REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo)); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; - } - else { - indIOioC = 1; indIiH = 2; - } - - const int bS = shape::sizeAt(inputShapeInfo, 0); // batch size - const int iH = shape::sizeAt(inputShapeInfo, indIiH); // input height - const int iW = shape::sizeAt(inputShapeInfo, indIiH+1); // input width - const int iC = shape::sizeAt(inputShapeInfo, indIOioC); // input channels - const int mC = shape::sizeAt(weightsShapeInfo, indWmC); // channels multiplier(oC = iC*mC) - const int oC = iC*mC; // output channels - - int trueoH, trueoW; // correct output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if(biasShapeInfo) - REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); - - auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.workspace()); - auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.workspace()); - - if(biasShapeInfo) { - Nd4jLong* gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo)); - } - - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); + auto inputShapeInfo = inputShape->at(0); + auto weightsShapeInfo = inputShape->at(1); + auto biasShapeInfo = block.width() > 3 ? inputShape->at(2) : nullptr; + auto gradOShapeInfo = + block.width() > 3 ? inputShape->at(3) : inputShape->at(2); + + const int rank = 4; + REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, + "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be " + "equal to %i, but got %i instead !", + rank, shape::rank(inputShapeInfo)); + REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, + "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be " + "equal to %i, but got %i instead !", + rank, shape::rank(weightsShapeInfo)); + REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, + "CUSTOM DEPTHWISECONV2D_BP OP: rank of output gradients (next " + "epsilon) array must be equal to %i, but got %i instead !", + rank, shape::rank(gradOShapeInfo)); + + int kH = INT_ARG(0) > 0 ? INT_ARG(0) + : static_cast(shape::sizeAt( + weightsShapeInfo, 0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 ? INT_ARG(1) + : static_cast(shape::sizeAt( + weightsShapeInfo, 1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + } else { + indIOioC = 1; + indIiH = 2; + } + + const int bS = shape::sizeAt(inputShapeInfo, 0); // batch size + const int iH = shape::sizeAt(inputShapeInfo, indIiH); // input height + const int iW = shape::sizeAt(inputShapeInfo, indIiH + 1); // input width + const int iC = shape::sizeAt(inputShapeInfo, indIOioC); // input channels + const int mC = shape::sizeAt(weightsShapeInfo, + indWmC); // channels multiplier(oC = iC*mC) + const int oC = iC * mC; // output channels + + int trueoH, trueoW; // correct output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indIiH, indIiH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), + shape::rank(gradOShapeInfo), + shape::shapeOf(gradOShapeInfo)), + 0, + "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), + shape::rank(weightsShapeInfo), + shape::shapeOf(weightsShapeInfo)), + 0, + "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE( + shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), + 0, + "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo)); + + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWshapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsShapeInfo, gradOShapeInfo, false, block.workspace()); + + if (biasShapeInfo) { + Nd4jLong* gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), + CONSTANT(gradBshapeInfo)); + } + + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/dilation2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/dilation2d.cpp index 0baae6d7f149..20f6024b9744 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/dilation2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/dilation2d.cpp @@ -26,111 +26,116 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(dilation2d, 2, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); +CUSTOM_OP_IMPL(dilation2d, 2, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(input->rankOf() == 4, 0, "Dilation2D: input should be 4D"); - REQUIRE_TRUE(weights->rankOf() == 3, 0, "Dilation2D: weights should be 3D"); + REQUIRE_TRUE(input->rankOf() == 4, 0, "Dilation2D: input should be 4D"); + REQUIRE_TRUE(weights->rankOf() == 3, 0, "Dilation2D: weights should be 3D"); - const int bS = input->sizeAt(0); - const int iC = input->sizeAt(3); - const bool isSameShape = INT_ARG(0) == 1; + const int bS = input->sizeAt(0); + const int iC = input->sizeAt(3); + const bool isSameShape = INT_ARG(0) == 1; - REQUIRE_TRUE(input->sizeAt(3) == weights->sizeAt(2), 0, "Dilation2D: number of input channels doesn't match number of channels in weights: %i vs %i", input->sizeAt(3), weights->sizeAt(2)); + REQUIRE_TRUE(input->sizeAt(3) == weights->sizeAt(2), 0, + "Dilation2D: number of input channels doesn't match number of " + "channels in weights: %i vs %i", + input->sizeAt(3), weights->sizeAt(2)); - std::vector strides(4); - std::vector rates(4); + std::vector strides(4); + std::vector rates(4); - if (block.width() > 2) { - REQUIRE_TRUE(block.width() >= 4, 0, "Dilation2D: number of input arrays should be 4 at least"); + if (block.width() > 2) { + REQUIRE_TRUE(block.width() >= 4, 0, + "Dilation2D: number of input arrays should be 4 at least"); + auto r = INPUT_VARIABLE(2); + auto s = INPUT_VARIABLE(3); - auto r = INPUT_VARIABLE(2); - auto s = INPUT_VARIABLE(3); + strides = s->template asVectorT(); + rates = r->template asVectorT(); + } else { + REQUIRE_TRUE(block.numI() >= 9, 0, + "Dilation2D: number of Int arguments should be 9 at least"); - strides = s->template asVectorT(); - rates = r->template asVectorT(); - } else { - REQUIRE_TRUE(block.numI() >= 9, 0, "Dilation2D: number of Int arguments should be 9 at least"); + int e = 1; + for (int cnt = 0; cnt < 4; cnt++) rates[cnt] = INT_ARG(e++); - int e = 1; - for (int cnt = 0;cnt < 4; cnt++) - rates[cnt] = INT_ARG(e++); + for (int cnt = 0; cnt < 4; cnt++) strides[cnt] = INT_ARG(e++); + } + int sH = 0, sW = 0; + int dH = 0, dW = 0; + int pH = 0, pW = 0; + int oH = 0, oW = 0; - for (int cnt = 0; cnt < 4; cnt++) - strides[cnt] = INT_ARG(e++); - } + helpers::dilation_hw(block.launchContext(), input->shapeInfo(), + weights->shapeInfo(), strides, rates, isSameShape, &sH, + &sW, &pH, &pW, &dH, &dW, &oH, &oW); + REQUIRE_TRUE(oH > 0 && oW > 0, 0, + "Dilation2D: outY and outX should have positive values, but got " + "[%i, %i] instead", + oH, oW); - int sH = 0, sW = 0; - int dH = 0, dW = 0; - int pH = 0, pW = 0; - int oH = 0, oW = 0; + helpers::dilation2d(block.launchContext(), input, weights, output, sH, sW, pH, + pW, dH, dW); - helpers::dilation_hw(block.launchContext(), input->shapeInfo(), weights->shapeInfo(), strides, rates, isSameShape, &sH, &sW, &pH, &pW, &dH, &dW, &oH, &oW); - - - REQUIRE_TRUE(oH > 0 && oW > 0, 0, "Dilation2D: outY and outX should have positive values, but got [%i, %i] instead", oH, oW); - - helpers::dilation2d(block.launchContext(), input, weights, output, sH, sW, pH, pW, dH, dW); - - return Status::OK(); - } - - DECLARE_TYPES(dilation2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(dilation2d) { - auto input = inputShape->at(0); - auto weights = inputShape->at(1); +DECLARE_TYPES(dilation2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - const int bS = shape::sizeAt(input, 0); - const int iC = shape::sizeAt(input, 3); - const bool isSameShape = INT_ARG(0) == 1; +DECLARE_SHAPE_FN(dilation2d) { + auto input = inputShape->at(0); + auto weights = inputShape->at(1); - std::vector strides(4); - std::vector rates(4); + const int bS = shape::sizeAt(input, 0); + const int iC = shape::sizeAt(input, 3); + const bool isSameShape = INT_ARG(0) == 1; - if (block.width() > 2) { - auto r = INPUT_VARIABLE(2); - auto s = INPUT_VARIABLE(3); + std::vector strides(4); + std::vector rates(4); + if (block.width() > 2) { + auto r = INPUT_VARIABLE(2); + auto s = INPUT_VARIABLE(3); - strides = s->template asVectorT(); - rates = r->template asVectorT(); - } else { - if (block.numI() < 9) { - auto newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(input)); - return SHAPELIST(newShape); - } + strides = s->template asVectorT(); + rates = r->template asVectorT(); + } else { + if (block.numI() < 9) { + auto newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo( + ArrayOptions::dataType(input)); + return SHAPELIST(newShape); + } - int e = 1; - for (int cnt = 0;cnt < 4; cnt++) - rates[cnt] = INT_ARG(e++); + int e = 1; + for (int cnt = 0; cnt < 4; cnt++) rates[cnt] = INT_ARG(e++); - for (int cnt = 0; cnt < 4; cnt++) - strides[cnt] = INT_ARG(e++); - } + for (int cnt = 0; cnt < 4; cnt++) strides[cnt] = INT_ARG(e++); + } - int sH = 0, sW = 0; - int dH = 0, dW = 0; - int pH = 0, pW = 0; - int oH = 0, oW = 0; + int sH = 0, sW = 0; + int dH = 0, dW = 0; + int pH = 0, pW = 0; + int oH = 0, oW = 0; - helpers::dilation_hw(block.launchContext(), input, weights, strides, rates, isSameShape, &sH, &sW, &pH, &pW, &dH, &dW, &oH, &oW); + helpers::dilation_hw(block.launchContext(), input, weights, strides, rates, + isSameShape, &sH, &sW, &pH, &pW, &dH, &dW, &oH, &oW); - std::array shape = {{bS, oH, oW, iC}}; - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(weights), 'c', 4, shape.data()); - return SHAPELIST(newShape); - } -} + std::array shape = {{bS, oH, oW, iC}}; + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(weights), 'c', 4, shape.data()); + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp index b7df3748251b..8d84a172bb89 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/im2col.cpp @@ -22,139 +22,153 @@ #if NOT_EXCLUDED(OP_im2col) #include +#include #include #include -#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(im2col, 1, 1, false, 0, 9) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_NULLIFIED(0); - - REQUIRE_TRUE(x->rankOf() == 4, 0, "im2col input should be 4D, but got %i instead", x->rankOf()); - REQUIRE_TRUE(z->rankOf() == 6, 0, "im2col output should be 6D, but got %i instead", z->rankOf()); - - int kernelHeight = INT_ARG(0); - int kernelWidth = INT_ARG(1); - int strideY = INT_ARG(2); - int strideX = INT_ARG(3); - int padHeight = INT_ARG(4); - int padWidth = INT_ARG(5); - int dY = INT_ARG(6); //Dilation, height/y dimension - int dX = INT_ARG(7); //Dilation, width/x dimension - bool isSameMode = INT_ARG(8) > 0; - double zeroPadVal = 0.0; - if (block.numT() > 0) - zeroPadVal = T_ARG(0); - - // FIXME: zeropad value is void - LaunchContext* ctx = block.launchContext(); - sd::ops::helpers::im2col(*ctx, *x, *z, kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, dY, dX, NDArrayFactory::create(zeroPadVal, block.launchContext())); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(im2col) { - auto inShape = inputShape->at(0); - - int bS = shape::shapeOf(inShape)[0]; - int iD = shape::shapeOf(inShape)[1]; - int inY = shape::shapeOf(inShape)[2]; - int inX = shape::shapeOf(inShape)[3]; - - int kY = INT_ARG(0); - int kX = INT_ARG(1); - int sY = INT_ARG(2); - int sX = INT_ARG(3); - int pY = INT_ARG(4); - int pX = INT_ARG(5); - int dY = INT_ARG(6); //Dilation, height/y dimension - int dX = INT_ARG(7); //Dilation, width/x dimension - bool isSameMode = INT_ARG(8) > 0; - - // output is always 6d for im2col - Nd4jLong* zShape; - ALLOCATE(zShape, block.workspace(), shape::shapeInfoLength(6), Nd4jLong); - - int oY = 0; - int oX = 0; - - ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, inY, inX, isSameMode); - - if (isSameMode) - ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX); - - zShape[0] = 6; - zShape[1] = bS; - zShape[2] = iD; - zShape[3] = kY; - zShape[4] = kX; - zShape[5] = oY; - zShape[6] = oX; - - zShape[shape::shapeInfoLength(zShape) - 2] = 1; - zShape[shape::shapeInfoLength(zShape) - 1] = 99; - - ShapeUtils::updateStridesAndType(zShape, inShape, 'c'); - - return SHAPELIST(CONSTANT(zShape)); - } - - CUSTOM_OP_IMPL(im2col_bp, 2, 1, false, 0, 9) { - auto input = INPUT_VARIABLE(0); - auto gradAtOutput = INPUT_VARIABLE(1); - auto z = OUTPUT_NULLIFIED(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "im2col_bp input should be 4D, but got %i instead", input->rankOf()); - REQUIRE_TRUE(gradAtOutput->rankOf() == 6, 0, "im2col_bp gradient at output (input idx 1) should be 6D, but got %i instead", gradAtOutput->rankOf()); - REQUIRE_TRUE(z->rankOf() == 4, 0, "im2col_bp output (grad at input) should be 4D, but got %i instead", z->rankOf()); - - int kernelHeight = INT_ARG(0); - int kernelWidth = INT_ARG(1); - int strideY = INT_ARG(2); - int strideX = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); - int dY = INT_ARG(6); //Dilation, height/y dimension - int dX = INT_ARG(7); //Dilation, width/x dimension - bool isSameMode = INT_ARG(8) > 0; - double zeroPadVal = 0.0; - if (block.numT() > 0) - zeroPadVal = T_ARG(0); - - //Assuming NCHW format here - int imgH = input->sizeAt(2); - int imgW = input->sizeAt(3); - - LaunchContext* ctx = block.launchContext(); - // FIXME:: all helpers should accept NDArray - ops::helpers::col2im(*ctx, *gradAtOutput, *z, strideY, strideX, pH, pW, imgH, imgW, dY, dX); - - return Status::OK(); - } - - DECLARE_TYPES(im2col) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setSameMode(true); - } - - DECLARE_TYPES(im2col_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setSameMode(true); - } - - DECLARE_SHAPE_FN(im2col_bp) { - Nd4jLong *inShape; - COPY_SHAPE(inputShape->at(0), inShape); - - return SHAPELIST(CONSTANT(inShape)); - } - } +namespace ops { +CUSTOM_OP_IMPL(im2col, 1, 1, false, 0, 9) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_NULLIFIED(0); + + REQUIRE_TRUE(x->rankOf() == 4, 0, + "im2col input should be 4D, but got %i instead", x->rankOf()); + REQUIRE_TRUE(z->rankOf() == 6, 0, + "im2col output should be 6D, but got %i instead", z->rankOf()); + + int kernelHeight = INT_ARG(0); + int kernelWidth = INT_ARG(1); + int strideY = INT_ARG(2); + int strideX = INT_ARG(3); + int padHeight = INT_ARG(4); + int padWidth = INT_ARG(5); + int dY = INT_ARG(6); // Dilation, height/y dimension + int dX = INT_ARG(7); // Dilation, width/x dimension + bool isSameMode = INT_ARG(8) > 0; + double zeroPadVal = 0.0; + if (block.numT() > 0) zeroPadVal = T_ARG(0); + + // FIXME: zeropad value is void + LaunchContext* ctx = block.launchContext(); + sd::ops::helpers::im2col( + *ctx, *x, *z, kernelHeight, kernelWidth, strideY, strideX, padHeight, + padWidth, dY, dX, + NDArrayFactory::create(zeroPadVal, block.launchContext())); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(im2col) { + auto inShape = inputShape->at(0); + + int bS = shape::shapeOf(inShape)[0]; + int iD = shape::shapeOf(inShape)[1]; + int inY = shape::shapeOf(inShape)[2]; + int inX = shape::shapeOf(inShape)[3]; + + int kY = INT_ARG(0); + int kX = INT_ARG(1); + int sY = INT_ARG(2); + int sX = INT_ARG(3); + int pY = INT_ARG(4); + int pX = INT_ARG(5); + int dY = INT_ARG(6); // Dilation, height/y dimension + int dX = INT_ARG(7); // Dilation, width/x dimension + bool isSameMode = INT_ARG(8) > 0; + + // output is always 6d for im2col + Nd4jLong* zShape; + ALLOCATE(zShape, block.workspace(), shape::shapeInfoLength(6), Nd4jLong); + + int oY = 0; + int oX = 0; + + ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, + inY, inX, isSameMode); + + if (isSameMode) + ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, + dY, dX); + + zShape[0] = 6; + zShape[1] = bS; + zShape[2] = iD; + zShape[3] = kY; + zShape[4] = kX; + zShape[5] = oY; + zShape[6] = oX; + + zShape[shape::shapeInfoLength(zShape) - 2] = 1; + zShape[shape::shapeInfoLength(zShape) - 1] = 99; + + ShapeUtils::updateStridesAndType(zShape, inShape, 'c'); + + return SHAPELIST(CONSTANT(zShape)); +} + +CUSTOM_OP_IMPL(im2col_bp, 2, 1, false, 0, 9) { + auto input = INPUT_VARIABLE(0); + auto gradAtOutput = INPUT_VARIABLE(1); + auto z = OUTPUT_NULLIFIED(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "im2col_bp input should be 4D, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(gradAtOutput->rankOf() == 6, 0, + "im2col_bp gradient at output (input idx 1) should be 6D, but " + "got %i instead", + gradAtOutput->rankOf()); + REQUIRE_TRUE( + z->rankOf() == 4, 0, + "im2col_bp output (grad at input) should be 4D, but got %i instead", + z->rankOf()); + + int kernelHeight = INT_ARG(0); + int kernelWidth = INT_ARG(1); + int strideY = INT_ARG(2); + int strideX = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + int dY = INT_ARG(6); // Dilation, height/y dimension + int dX = INT_ARG(7); // Dilation, width/x dimension + bool isSameMode = INT_ARG(8) > 0; + double zeroPadVal = 0.0; + if (block.numT() > 0) zeroPadVal = T_ARG(0); + + // Assuming NCHW format here + int imgH = input->sizeAt(2); + int imgW = input->sizeAt(3); + + LaunchContext* ctx = block.launchContext(); + // FIXME:: all helpers should accept NDArray + ops::helpers::col2im(*ctx, *gradAtOutput, *z, strideY, strideX, pH, pW, imgH, + imgW, dY, dX); + + return Status::OK(); +} + +DECLARE_TYPES(im2col) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT) + ->setSameMode(true); +} + +DECLARE_TYPES(im2col_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT) + ->setSameMode(true); +} + +DECLARE_SHAPE_FN(im2col_bp) { + Nd4jLong* inShape; + COPY_SHAPE(inputShape->at(0), inShape); + + return SHAPELIST(CONSTANT(inShape)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/ismax.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/ismax.cpp index 24488dd5de31..247de31c7792 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/ismax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/ismax.cpp @@ -26,30 +26,28 @@ #include namespace sd { -namespace ops { +namespace ops { CONFIGURABLE_OP_IMPL(ismax, 1, 1, true, 0, -2) { - - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - auto dimensions = block.getIArguments(); // argI - if (x->isScalar()) - z->assign(1); - else - helpers::ismax(block.launchContext(), x, z, dimensions); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + auto dimensions = block.getIArguments(); // argI + if (x->isScalar()) + z->assign(1); + else + helpers::ismax(block.launchContext(), x, z, dimensions); + + return Status::OK(); } DECLARE_SYN(IsMax, ismax); - DECLARE_TYPES(ismax) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::ANY); - - } - -} +DECLARE_TYPES(ismax) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::ANY); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp index 47243d9b9f96..288b276dbeb3 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/pointwiseConv2d.cpp @@ -22,92 +22,139 @@ #include namespace sd { -namespace ops { - - +namespace ops { CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_VARIABLE(0); // [bS, iH, iW, oC] (NHWC) or [bS, oC, iH, iW] (NCHW) - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM POINTWISECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2, 0, "CUSTOM POINTWISECONV2D OP: rank of biases array must be equal <= 2, but got %i instead !", bias->rankOf()); - - int kH = 1; // filter(kernel) height - int kW = 1; // filter(kernel) width - int sH = 1; // strides height - int sW = 1; // strides width - int pH = 0; // paddings height - int pW = 0; // paddings width - int dH = 1; // dilations height - int dW = 1; // dilations width - int isNCHW = block.numI() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC - int wFormat = block.numI() > 1 ? INT_ARG(1) : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC] - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW, wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = + INPUT_VARIABLE(1); // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = + OUTPUT_VARIABLE(0); // [bS, iH, iW, oC] (NHWC) or [bS, oC, iH, iW] (NCHW) + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM POINTWISECONV2D OP: rank of input array must be equal " + "to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal " + "to 4, but got %i instead !", + weights->rankOf()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2, 0, + "CUSTOM POINTWISECONV2D OP: rank of biases array must be " + "equal <= 2, but got %i instead !", + bias->rankOf()); + + int kH = 1; // filter(kernel) height + int kW = 1; // filter(kernel) width + int sH = 1; // strides height + int sW = 1; // strides width + int pH = 0; // paddings height + int pW = 0; // paddings width + int dH = 1; // dilations height + int dW = 1; // dilations width + int isNCHW = + block.numI() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 1 + ? INT_ARG(1) + : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC] + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM POINTWISECONV2D OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::conv2d(block, input, weights, bias, output, kH, kW, sH, sW, + pH, pW, dH, dW, 1 /*isSameMode*/, isNCHW, wFormat); + + return Status::OK(); } - DECLARE_TYPES(pointwise_conv2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - +DECLARE_TYPES(pointwise_conv2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(pointwise_conv2d) { - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsShapeInfo = inputShape->at(1); // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC] - auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] - - const int rank = 4; - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM POINTWISECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]); - - int isNCHW = block.numI() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC - int wFormat = block.numI() > 1 ? INT_ARG(1) : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC] - - int indIOioC, indWoC(0 == wFormat ? 3 : 0); - if(!isNCHW) - indIOioC = 3; - else - indIOioC = 1; - - const int bS = inputShapeInfo[1]; // batch size - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int oC = weightsShapeInfo[indWoC+1]; // output channels - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); - if (biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - auto outputShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, weightsShapeInfo, true, block.workspace()); - - // do not forget to put oC instead of iC in outputShapeInfo - outputShapeInfo[indIOioC + 1] = oC; - - shape::updateStrides(outputShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weightsShapeInfo = + inputShape->at(1); // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC] + auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC] + + const int rank = 4; + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "CUSTOM POINTWISECONV2D OP: rank of input array must be equal " + "to %i, but got %i instead !", + rank, inputShapeInfo[0]); + REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, + "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal " + "to %i, but got %i instead !", + rank, weightsShapeInfo[0]); + + int isNCHW = + block.numI() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 1 + ? INT_ARG(1) + : 0; // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC] + + int indIOioC, indWoC(0 == wFormat ? 3 : 0); + if (!isNCHW) + indIOioC = 3; + else + indIOioC = 1; + + const int bS = inputShapeInfo[1]; // batch size + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int oC = weightsShapeInfo[indWoC + 1]; // output channels + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, + "POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weightsShapeInfo).c_str()); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "POINTWISECONV2D OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + auto outputShapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, weightsShapeInfo, true, block.workspace()); + + // do not forget to put oC instead of iC in outputShapeInfo + outputShapeInfo[indIOioC + 1] = oC; + + shape::updateStrides(outputShapeInfo, shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp index cd2bb3cf61cf..78c05e01e07f 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/sconv2d.cpp @@ -24,379 +24,584 @@ #include #include + #include namespace sd { -namespace ops { - +namespace ops { CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) { + NDArray *input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + NDArray *weightsDepth = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + NDArray *weightsPoint = + nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + NDArray *bias = nullptr; // [oC], if weightsPoint=nullptr then oC = iC*mC + + NDArray *output = OUTPUT_NULLIFIED( + 0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + if (block.width() == 3) { + if ((INPUT_VARIABLE(2))->rankOf() == 4) + weightsPoint = INPUT_VARIABLE(2); + else + bias = INPUT_VARIABLE(2); + } else if (block.width() == 4) { + weightsPoint = INPUT_VARIABLE(2); + bias = INPUT_VARIABLE(3); + } + + REQUIRE_TRUE(input->rankOf() == 4, 0, + " SCONV2D OP: rank of input array must be equal to 4, but got " + "%i instead !", + input->rankOf()); + REQUIRE_TRUE(weightsDepth->rankOf() == 4, 0, + " SCONV2D OP: rank of weightsDepth array must be equal to 4, " + "but got %i instead !", + weightsDepth->rankOf()); + if (weightsPoint) + REQUIRE_TRUE(weightsPoint->rankOf() == 4, 0, + " SCONV2D OP: rank of weightsPoint array must be equal to 4, " + "but got %i instead !", + weightsPoint->rankOf()); + if (bias) + REQUIRE_TRUE(bias->rankOf() == 1 || bias->rankOf() == 2, 0, + " SCONV2D OP: rank of biases array must be equal to 1 or 2, " + "but got %i instead !", + bias->rankOf()); + ; + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier, output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weightsDepth->sizeAt(indWmC); // channels multiplier + + std::vector expectedWeightsDShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, + " SCONV2D OP: wrong shape of weightsDepth array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), + ShapeUtils::shapeAsString(weightsDepth).c_str()); + if (weightsPoint) { + std::vector expectedWeightsPShape = + ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC * mC, oC); + REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, + " SCONV2D OP: wrong shape of weightsPoint array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), + ShapeUtils::shapeAsString(weightsPoint).c_str()); + } + if (bias) + REQUIRE_TRUE(oC == bias->lengthOf(), 0, + " SCONV2D OP: length of bias array must be equal to " + "outChannels, but got %i instead", + bias->lengthOf()); + + if (iC == 1) { + nd4j_debug( + "SCONV2D OP: for input_channels = 1 this op is equivalent to standard " + "conv2d\n", + ""); + ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH, kW, + sH, sW, pH, pW, dH, dW, isSameMode, isNCHW, + wFormat); + return Status::OK(); + } - NDArray *input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - NDArray *weightsDepth = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - NDArray *bias = nullptr; // [oC], if weightsPoint=nullptr then oC = iC*mC - - NDArray *output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - if(block.width() == 3) { - if((INPUT_VARIABLE(2))->rankOf() == 4) - weightsPoint = INPUT_VARIABLE(2); - else - bias = INPUT_VARIABLE(2); - } - else if(block.width() == 4) { - weightsPoint = INPUT_VARIABLE(2); - bias = INPUT_VARIABLE(3); - } - - REQUIRE_TRUE(input->rankOf() == 4, 0, " SCONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weightsDepth->rankOf() == 4, 0, " SCONV2D OP: rank of weightsDepth array must be equal to 4, but got %i instead !", weightsDepth->rankOf()); - if(weightsPoint) - REQUIRE_TRUE(weightsPoint->rankOf() == 4, 0, " SCONV2D OP: rank of weightsPoint array must be equal to 4, but got %i instead !", weightsPoint->rankOf()); - if(bias) - REQUIRE_TRUE(bias->rankOf() == 1 || bias->rankOf() == 2, 0, " SCONV2D OP: rank of biases array must be equal to 1 or 2, but got %i instead !", bias->rankOf());; - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weightsDepth->sizeAt(indWmC); // channels multiplier - - std::vector expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str()); - if(weightsPoint) { - std::vector expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); - REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str()); - } - if (bias) - REQUIRE_TRUE(oC == bias->lengthOf(), 0, " SCONV2D OP: length of bias array must be equal to outChannels, but got %i instead", bias->lengthOf()); - - if (iC == 1) { - nd4j_debug("SCONV2D OP: for input_channels = 1 this op is equivalent to standard conv2d\n",""); - ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); - return Status::OK(); - } - - ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); + ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, + output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, + isNCHW, wFormat); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(sconv2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - +DECLARE_TYPES(sconv2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(sconv2d) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weightsDShapeInfo = inputShape->at(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - Nd4jLong const* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - Nd4jLong const* biasShapeInfo = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr - - if(block.width() == 3) - if(inputShape->at(2)[0] == 4) - weightsPShapeInfo = inputShape->at(2); - else - biasShapeInfo = inputShape->at(2); - else if(block.width() == 4) { - weightsPShapeInfo = inputShape->at(2); - biasShapeInfo = inputShape->at(3); - } - - const int rank = 4; - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "SCONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(weightsDShapeInfo[0] == rank, 0, "SCONV2D OP: rank of weightsDepth array must be equal to %i, but got %i instead !", rank, weightsDShapeInfo[0]); - if(weightsPShapeInfo) - REQUIRE_TRUE(weightsPShapeInfo[0] == rank, 0, "SCONV2D OP: rank of weightsPoint array must be equal to %i, but got %i instead !", rank, weightsPShapeInfo[0]); - if(biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2, 0, "SCONV2D OP: rank of biases array must be <= 2, but got %i instead !", biasShapeInfo[0]);; - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; - } - else { - indIOioC = 1; indIiH = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iH = inputShapeInfo[indIiH+1]; // input height - const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int mC = weightsDShapeInfo[indWmC+1]; // channel multiplier - const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC+1] : iC*mC; // output channels (oC or iC*mC) - - std::vector expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); - if(weightsPShapeInfo) { - std::vector expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); - } - if (biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "SCONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - int oH, oW; // output height, width - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - Nd4jLong* outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong); - - outputShapeInfo[0] = 4; - outputShapeInfo[1] = bS; - - if (isNCHW) { - outputShapeInfo[2] = oC; - outputShapeInfo[3] = oH; - outputShapeInfo[4] = oW; - } else { - outputShapeInfo[2] = oH; - outputShapeInfo[3] = oW; - outputShapeInfo[4] = oC; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, weightsDShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weightsDShapeInfo = inputShape->at( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + Nd4jLong const *weightsPShapeInfo = + nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + Nd4jLong const *biasShapeInfo = + nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr + + if (block.width() == 3) + if (inputShape->at(2)[0] == 4) + weightsPShapeInfo = inputShape->at(2); + else + biasShapeInfo = inputShape->at(2); + else if (block.width() == 4) { + weightsPShapeInfo = inputShape->at(2); + biasShapeInfo = inputShape->at(3); + } + + const int rank = 4; + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + "SCONV2D OP: rank of input array must be equal to %i, but got " + "%i instead !", + rank, inputShapeInfo[0]); + REQUIRE_TRUE(weightsDShapeInfo[0] == rank, 0, + "SCONV2D OP: rank of weightsDepth array must be equal to %i, " + "but got %i instead !", + rank, weightsDShapeInfo[0]); + if (weightsPShapeInfo) + REQUIRE_TRUE(weightsPShapeInfo[0] == rank, 0, + "SCONV2D OP: rank of weightsPoint array must be equal to %i, " + "but got %i instead !", + rank, weightsPShapeInfo[0]); + if (biasShapeInfo) + REQUIRE_TRUE( + biasShapeInfo[0] <= 2, 0, + "SCONV2D OP: rank of biases array must be <= 2, but got %i instead !", + biasShapeInfo[0]); + ; + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + } else { + indIOioC = 1; + indIiH = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iH = inputShapeInfo[indIiH + 1]; // input height + const int iW = inputShapeInfo[indIiH + 2]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int mC = weightsDShapeInfo[indWmC + 1]; // channel multiplier + const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC + 1] + : iC * mC; // output channels (oC or iC*mC) + + std::vector expectedWeightsDShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, + "SCONV2D OP: wrong shape of depth weights array, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), + ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); + if (weightsPShapeInfo) { + std::vector expectedWeightsPShape = + ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC * mC, oC); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, + "SCONV2D OP: wrong shape of point array, expected is %s, but got %s " + "instead !", + ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), + ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); + } + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, + "SCONV2D OP: wrong shape of array with biases, expected rank, " + "length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + int oH, oW; // output height, width + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + Nd4jLong *outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(inputShapeInfo), Nd4jLong); + + outputShapeInfo[0] = 4; + outputShapeInfo[1] = bS; + + if (isNCHW) { + outputShapeInfo[2] = oC; + outputShapeInfo[3] = oH; + outputShapeInfo[4] = oW; + } else { + outputShapeInfo[2] = oH; + outputShapeInfo[3] = oW; + outputShapeInfo[4] = oC; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, weightsDShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - DECLARE_TYPES(sconv2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(sconv2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} //////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) { - - NDArray *input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - NDArray *gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - NDArray *weightsDepth = INPUT_VARIABLE(2); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - NDArray *bias = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr - - NDArray *gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - NDArray *gradWD = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - NDArray *gradWP = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - NDArray *gradB = nullptr; // [oC] - - if(block.width() == 4) { - if((INPUT_VARIABLE(3))->rankOf() == 4) { - weightsPoint = INPUT_VARIABLE(3); - gradWP = OUTPUT_NULLIFIED(2); - } - else { - bias = INPUT_VARIABLE(3); - gradB = OUTPUT_NULLIFIED(2); - } - } - else if(block.width() == 5) { - weightsPoint = INPUT_VARIABLE(3); - bias = INPUT_VARIABLE(4); - gradWP = OUTPUT_NULLIFIED(2); - gradB = OUTPUT_NULLIFIED(3); - } - - - REQUIRE_TRUE(input->rankOf() == 4, 0, " SCONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, " SCONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - REQUIRE_TRUE(weightsDepth->rankOf() == 4, 0, " SCONV2D_BP OP: rank of weightsDepth array must be equal to 4 !, but got %i instead !", weightsDepth->rankOf()); - if(weightsPoint) { - REQUIRE_TRUE(weightsPoint->rankOf() == 4, 0, " SCONV2D_BP OP: rank of weightsPoint array must be equal to 4, but got %i instead !", weightsPoint->rankOf()); - REQUIRE_TRUE(gradWP->rankOf() == 4, 0, " SCONV2D_BP OP: rank of weightsPoint gradients array must be equal to 4, but got %i instead !", gradWP->rankOf()); - } - if(bias) { - REQUIRE_TRUE(bias->rankOf() == 1 || bias->rankOf() == 2, 0, " SCONV2D_BP OP: rank of biases array must be equal to 1 or 2, but got %i instead !", bias->rankOf()); - REQUIRE_TRUE(gradB->rankOf() == 1 || gradB->rankOf() == 2, 0, " SCONV2D_BP OP: rank of biases gradientsarray must be equal to 1 or 2, but got %i instead !", gradB->rankOf()); - } - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weightsDepth->sizeAt(indWmC); // channels multiplier - - std::vector expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str()); - REQUIRE_TRUE(gradWD->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of gradWD array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(gradWD).c_str()); - if(weightsPoint) { - std::vector expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); - REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str()); - REQUIRE_TRUE(gradWP->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of gradWP array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(gradWP).c_str()); - } - if (bias) { - REQUIRE_TRUE(oC == bias->lengthOf(), 0, " SCONV2D_BP OP: length of bias array must be equal to outChannels, but got %i instead", bias->lengthOf()); - REQUIRE_TRUE(oC == gradB->lengthOf(), 0, " SCONV2D_BP OP: length of biases gradients array must be equal to outChannels, but got %i instead", gradB->lengthOf()); - } - - // if (iC == 1) { - // nd4j_debug(" SCONV2D_BP OP: for input_channels=1 this op is equivalent to standard conv2d_bp \n",""); - // sd::ops::conv2d_bp op; - // return op.execute(&block); - // } - - // ----- if weightsPoint is present, perform pointwise backprop first and calculate gradWP at this step ----- // - if (weightsPoint){ - - auto resultFFShape = isNCHW ? std::vector({bS, mC*iC, oH, oW}) : std::vector({bS, oH, oW, mC*iC}); - auto resultFF = NDArrayFactory::create_(input->ordering(), resultFFShape, input->dataType(), block.launchContext()); - ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); - - auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC*mC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); - auto gradIDepth = NDArrayFactory::create_(resultFF->ordering(), gradIDepthShape, resultFF->dataType(), block.launchContext()); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - - ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW, wFormat); // in this case oH=iH and oW=iW - - gradO = gradIDepth; - bias = gradB = nullptr; // if pointwise backprop was done then don't calculate gradB at depthwise_conv2d_bp step - - delete resultFF; + NDArray *input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + NDArray *gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + NDArray *weightsDepth = INPUT_VARIABLE( + 2); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + NDArray *weightsPoint = + nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + NDArray *bias = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr + + NDArray *gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + NDArray *gradWD = OUTPUT_NULLIFIED( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + NDArray *gradWP = + nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + NDArray *gradB = nullptr; // [oC] + + if (block.width() == 4) { + if ((INPUT_VARIABLE(3))->rankOf() == 4) { + weightsPoint = INPUT_VARIABLE(3); + gradWP = OUTPUT_NULLIFIED(2); + } else { + bias = INPUT_VARIABLE(3); + gradB = OUTPUT_NULLIFIED(2); } - - // ----- apply depthwise_conv2d_bp ----- // - ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW, wFormat); - - if(weightsPoint) - delete gradO; - - return Status::OK(); + } else if (block.width() == 5) { + weightsPoint = INPUT_VARIABLE(3); + bias = INPUT_VARIABLE(4); + gradWP = OUTPUT_NULLIFIED(2); + gradB = OUTPUT_NULLIFIED(3); + } + + REQUIRE_TRUE(input->rankOf() == 4, 0, + " SCONV2D_BP OP: rank of input array must be equal to 4, but " + "got %i instead !", + input->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + " SCONV2D_BP OP: rank of output gradients (next epsilon) array " + "must be equal to 4, but got %i instead !", + gradO->rankOf()); + REQUIRE_TRUE(weightsDepth->rankOf() == 4, 0, + " SCONV2D_BP OP: rank of weightsDepth array must be equal to 4 " + "!, but got %i instead !", + weightsDepth->rankOf()); + if (weightsPoint) { + REQUIRE_TRUE(weightsPoint->rankOf() == 4, 0, + " SCONV2D_BP OP: rank of weightsPoint array must be equal to " + "4, but got %i instead !", + weightsPoint->rankOf()); + REQUIRE_TRUE(gradWP->rankOf() == 4, 0, + " SCONV2D_BP OP: rank of weightsPoint gradients array must be " + "equal to 4, but got %i instead !", + gradWP->rankOf()); + } + if (bias) { + REQUIRE_TRUE(bias->rankOf() == 1 || bias->rankOf() == 2, 0, + " SCONV2D_BP OP: rank of biases array must be equal to 1 or " + "2, but got %i instead !", + bias->rankOf()); + REQUIRE_TRUE(gradB->rankOf() == 1 || gradB->rankOf() == 2, 0, + " SCONV2D_BP OP: rank of biases gradientsarray must be equal " + "to 1 or 2, but got %i instead !", + gradB->rankOf()); + } + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier, output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weightsDepth->sizeAt(indWmC); // channels multiplier + + std::vector expectedWeightsDShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, + " SCONV2D_BP OP: wrong shape of weightsDepth array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), + ShapeUtils::shapeAsString(weightsDepth).c_str()); + REQUIRE_TRUE(gradWD->isSameShape(expectedWeightsDShape), 0, + " SCONV2D_BP OP: wrong shape of gradWD array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), + ShapeUtils::shapeAsString(gradWD).c_str()); + if (weightsPoint) { + std::vector expectedWeightsPShape = + ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC * mC, oC); + REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, + " SCONV2D_BP OP: wrong shape of weightsPoint array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), + ShapeUtils::shapeAsString(weightsPoint).c_str()); + REQUIRE_TRUE(gradWP->isSameShape(expectedWeightsPShape), 0, + " SCONV2D_BP OP: wrong shape of gradWP array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), + ShapeUtils::shapeAsString(gradWP).c_str()); + } + if (bias) { + REQUIRE_TRUE(oC == bias->lengthOf(), 0, + " SCONV2D_BP OP: length of bias array must be equal to " + "outChannels, but got %i instead", + bias->lengthOf()); + REQUIRE_TRUE(oC == gradB->lengthOf(), 0, + " SCONV2D_BP OP: length of biases gradients array must be " + "equal to outChannels, but got %i instead", + gradB->lengthOf()); + } + + // if (iC == 1) { + // nd4j_debug(" SCONV2D_BP OP: for input_channels=1 this op is equivalent + // to standard conv2d_bp \n",""); sd::ops::conv2d_bp op; return + // op.execute(&block); + // } + + // ----- if weightsPoint is present, perform pointwise backprop first and + // calculate gradWP at this step ----- // + if (weightsPoint) { + auto resultFFShape = isNCHW ? std::vector({bS, mC * iC, oH, oW}) + : std::vector({bS, oH, oW, mC * iC}); + auto resultFF = + NDArrayFactory::create_(input->ordering(), resultFFShape, + input->dataType(), block.launchContext()); + ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, + resultFF, kH, kW, sH, sW, pH, pW, dH, dW, + isSameMode, isNCHW, wFormat); + + auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC * mC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + auto gradIDepth = NDArrayFactory::create_( + resultFF->ordering(), gradIDepthShape, resultFF->dataType(), + block.launchContext()); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, + // oH, oW] (NCHW) + + ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, + gradIDepth, gradWP, gradB, 1, 1, 1, 1, 0, 0, 1, + 1, isSameMode, isNCHW, + wFormat); // in this case oH=iH and oW=iW + + gradO = gradIDepth; + bias = gradB = nullptr; // if pointwise backprop was done then don't + // calculate gradB at depthwise_conv2d_bp step + + delete resultFF; + } + + // ----- apply depthwise_conv2d_bp ----- // + ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, + gradI, gradWD, gradB, kH, kW, sH, sW, pH, + pW, dH, dW, isSameMode, isNCHW, wFormat); + + if (weightsPoint) delete gradO; + + return Status::OK(); } - DECLARE_SHAPE_FN(sconv2d_bp) { - - auto inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradOShapeInfo = inputShape->at(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto weightsDShapeInfo = inputShape->at(2); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - Nd4jLong const* weightsPShapeInfo = nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - Nd4jLong const* biasShapeInfo = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr - - if(block.width() == 4) { - if(inputShape->at(3)[0] == 4) - weightsPShapeInfo = inputShape->at(3); - else - biasShapeInfo = inputShape->at(3); - } - else if(block.width() == 5) { - weightsPShapeInfo = inputShape->at(3); - biasShapeInfo = inputShape->at(4); - } - - const int rank = 4; - REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, " SCONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]); - REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, " SCONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]); - REQUIRE_TRUE(weightsDShapeInfo[0] == rank, 0, " SCONV2D_BP OP: rank of weightsDepth array must be equal to %i, but got %i instead !", rank, weightsDShapeInfo[0]); - if(weightsPShapeInfo) - REQUIRE_TRUE(weightsPShapeInfo[0] == rank, 0, " SCONV2D_BP OP: rank of weightsPoint array must be equal to %i, but got %i instead !", rank, weightsPShapeInfo[0]); - if(biasShapeInfo) - REQUIRE_TRUE(biasShapeInfo[0] ==1 || biasShapeInfo[0] == 2, 0, " SCONV2D_BP OP: rank of biases array must be 1 or 2, but got %i instead !", biasShapeInfo[0]);; - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); - if(!isNCHW) { - indIOioC = 3; indIiH = 1; - } - else { - indIOioC = 1; indIiH = 2; - } - - const int bS = inputShapeInfo[1]; // batch size - const int iH = inputShapeInfo[indIiH+1]; // input height - const int iW = inputShapeInfo[indIiH+2]; // input width - const int iC = inputShapeInfo[indIOioC+1]; // input channels - const int mC = weightsDShapeInfo[indWmC+1]; // channel multiplier - const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC+1] : iC*mC; // output channels (oC or iC*mC) - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShapeInfo = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1}); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShapeInfo), 0, "SCONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShapeInfo).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); - std::vector expectedWeightsDShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D_BP OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); - if(weightsPShapeInfo) { - std::vector expectedWeightsPShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC*mC, oC); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D_BP OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); - } - if (biasShapeInfo) - REQUIRE_TRUE((biasShapeInfo[0] == 1 || biasShapeInfo[0] == 2) && oC == shape::length(biasShapeInfo), 0, "SCONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo)); - - auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.workspace()); - auto gradWDshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsDShapeInfo, gradOShapeInfo, false, block.workspace()); - - Nd4jLong* gradWPshapeInfo(nullptr), *gradBshapeInfo(nullptr); - - if(weightsPShapeInfo && biasShapeInfo) { - gradWPshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsPShapeInfo, gradOShapeInfo, false, block.workspace()); - gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo), CONSTANT(gradWPshapeInfo), CONSTANT(gradBshapeInfo)); - } - - if(weightsPShapeInfo && !biasShapeInfo) { - gradWPshapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsPShapeInfo, gradOShapeInfo, false, block.workspace()); - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo), CONSTANT(gradWPshapeInfo)); - } - - if(!weightsPShapeInfo && biasShapeInfo) { - gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.workspace()); - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo), CONSTANT(gradBshapeInfo)); - } - - return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo)); + auto inputShapeInfo = + inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradOShapeInfo = inputShape->at( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto weightsDShapeInfo = inputShape->at( + 2); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + Nd4jLong const *weightsPShapeInfo = + nullptr; // [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + Nd4jLong const *biasShapeInfo = + nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr + + if (block.width() == 4) { + if (inputShape->at(3)[0] == 4) + weightsPShapeInfo = inputShape->at(3); + else + biasShapeInfo = inputShape->at(3); + } else if (block.width() == 5) { + weightsPShapeInfo = inputShape->at(3); + biasShapeInfo = inputShape->at(4); + } + + const int rank = 4; + REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, + " SCONV2D_BP OP: rank of input array must be equal to %i, but " + "got %i instead !", + rank, inputShapeInfo[0]); + REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, + " SCONV2D_BP OP: rank of output gradients (next epsilon) array " + "must be equal to %i, but got %i instead !", + rank, gradOShapeInfo[0]); + REQUIRE_TRUE(weightsDShapeInfo[0] == rank, 0, + " SCONV2D_BP OP: rank of weightsDepth array must be equal to " + "%i, but got %i instead !", + rank, weightsDShapeInfo[0]); + if (weightsPShapeInfo) + REQUIRE_TRUE(weightsPShapeInfo[0] == rank, 0, + " SCONV2D_BP OP: rank of weightsPoint array must be equal to " + "%i, but got %i instead !", + rank, weightsPShapeInfo[0]); + if (biasShapeInfo) + REQUIRE_TRUE(biasShapeInfo[0] == 1 || biasShapeInfo[0] == 2, 0, + " SCONV2D_BP OP: rank of biases array must be 1 or 2, but got " + "%i instead !", + biasShapeInfo[0]); + ; + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int indIOioC, indIiH, indWmC(0 == wFormat ? 3 : 0); + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + } else { + indIOioC = 1; + indIiH = 2; + } + + const int bS = inputShapeInfo[1]; // batch size + const int iH = inputShapeInfo[indIiH + 1]; // input height + const int iW = inputShapeInfo[indIiH + 2]; // input width + const int iC = inputShapeInfo[indIOioC + 1]; // input channels + const int mC = weightsDShapeInfo[indWmC + 1]; // channel multiplier + const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC + 1] + : iC * mC; // output channels (oC or iC*mC) + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShapeInfo = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShapeInfo), 0, + "SCONV2D_BP OP: wrong shape of output gradients (next epsilon) array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShapeInfo).c_str(), + ShapeUtils::shapeAsString(gradOShapeInfo).c_str()); + std::vector expectedWeightsDShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, + "SCONV2D_BP OP: wrong shape of depth weights array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), + ShapeUtils::shapeAsString(weightsDShapeInfo).c_str()); + if (weightsPShapeInfo) { + std::vector expectedWeightsPShape = + ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC * mC, oC); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, + "SCONV2D_BP OP: wrong shape of point array, expected is %s, but got %s " + "instead !", + ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), + ShapeUtils::shapeAsString(weightsPShapeInfo).c_str()); + } + if (biasShapeInfo) + REQUIRE_TRUE((biasShapeInfo[0] == 1 || biasShapeInfo[0] == 2) && + oC == shape::length(biasShapeInfo), + 0, + "SCONV2D_BP OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, biasShapeInfo[0], shape::length(biasShapeInfo)); + + auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShapeInfo, gradOShapeInfo, false, block.workspace()); + auto gradWDshapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsDShapeInfo, gradOShapeInfo, false, block.workspace()); + + Nd4jLong *gradWPshapeInfo(nullptr), *gradBshapeInfo(nullptr); + + if (weightsPShapeInfo && biasShapeInfo) { + gradWPshapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsPShapeInfo, gradOShapeInfo, false, block.workspace()); + gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo), + CONSTANT(gradWPshapeInfo), CONSTANT(gradBshapeInfo)); + } + + if (weightsPShapeInfo && !biasShapeInfo) { + gradWPshapeInfo = ShapeBuilders::copyShapeInfoAndType( + weightsPShapeInfo, gradOShapeInfo, false, block.workspace()); + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo), + CONSTANT(gradWPshapeInfo)); + } + + if (!weightsPShapeInfo && biasShapeInfo) { + gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType( + biasShapeInfo, gradOShapeInfo, false, block.workspace()); + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo), + CONSTANT(gradBshapeInfo)); + } + + return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWDshapeInfo)); } - - -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp index dc8dcf770e08..d8eadf08a959 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling2d.cpp @@ -26,103 +26,128 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(upsampling2d, 1, 1, false, 0, 2) { - auto input = INPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - auto output = OUTPUT_NULLIFIED(0); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - - const int factorH = INT_ARG(0); - const int factorW = INT_ARG(1); - const int isNCHW = block.numI() > 2 ? INT_ARG(2) : 0; // INT_ARG(2): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "UPSAMPLING2D op: input should be 4D, but got %i instead!", input->rankOf()); - REQUIRE_TRUE(output->rankOf() == 4, 0, "UPSAMPLING2D op: output should be 4D, but got %i instead!", output->rankOf()); - - ConvolutionUtils::upsampling2d(block, *input, *output, factorH, factorW, (bool)isNCHW); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) + auto output = + OUTPUT_NULLIFIED(0); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, + // factorH*iH, factorW*iW, iC] (NHWC) + + const int factorH = INT_ARG(0); + const int factorW = INT_ARG(1); + const int isNCHW = + block.numI() > 2 ? INT_ARG(2) : 0; // INT_ARG(2): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "UPSAMPLING2D op: input should be 4D, but got %i instead!", + input->rankOf()); + REQUIRE_TRUE(output->rankOf() == 4, 0, + "UPSAMPLING2D op: output should be 4D, but got %i instead!", + output->rankOf()); + + ConvolutionUtils::upsampling2d(block, *input, *output, factorH, factorW, + (bool)isNCHW); + + return Status::OK(); } DECLARE_SYN(upsampling, upsampling2d); - DECLARE_TYPES(upsampling2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(upsampling2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(upsampling2d) { - - auto inputShapeInfo = inputShape->at(0); - - REQUIRE_TRUE(inputShapeInfo[0] == 4, 0, "UPSAMPLING2D op: input should be 4D, but got %i instead!", inputShapeInfo[0]); - - const int factorH = INT_ARG(0); - const int factorW = INT_ARG(1); - const int isNCHW = block.numI() > 2 ? INT_ARG(2) : 0; // INT_ARG(2): 0-NCHW, 1-NHWC - - Nd4jLong *outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(inputShapeInfo[0]), Nd4jLong); - - outputShapeInfo[0] = inputShapeInfo[0]; - outputShapeInfo[1] = inputShapeInfo[1]; - - if(isNCHW) { - outputShapeInfo[2] = inputShapeInfo[2]; - outputShapeInfo[3] = inputShapeInfo[3] * factorH; - outputShapeInfo[4] = inputShapeInfo[4] * factorW; - } - else { - outputShapeInfo[2] = inputShapeInfo[2] * factorH; - outputShapeInfo[3] = inputShapeInfo[3] * factorW; - outputShapeInfo[4] = inputShapeInfo[4]; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, inputShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = inputShape->at(0); + + REQUIRE_TRUE(inputShapeInfo[0] == 4, 0, + "UPSAMPLING2D op: input should be 4D, but got %i instead!", + inputShapeInfo[0]); + + const int factorH = INT_ARG(0); + const int factorW = INT_ARG(1); + const int isNCHW = + block.numI() > 2 ? INT_ARG(2) : 0; // INT_ARG(2): 0-NCHW, 1-NHWC + + Nd4jLong *outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(inputShapeInfo[0]), Nd4jLong); + + outputShapeInfo[0] = inputShapeInfo[0]; + outputShapeInfo[1] = inputShapeInfo[1]; + + if (isNCHW) { + outputShapeInfo[2] = inputShapeInfo[2]; + outputShapeInfo[3] = inputShapeInfo[3] * factorH; + outputShapeInfo[4] = inputShapeInfo[4] * factorW; + } else { + outputShapeInfo[2] = inputShapeInfo[2] * factorH; + outputShapeInfo[3] = inputShapeInfo[3] * factorW; + outputShapeInfo[4] = inputShapeInfo[4]; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, inputShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - DECLARE_TYPES(upsampling2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - +DECLARE_TYPES(upsampling2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(upsampling2d_bp, 2, 1, false, 0, 0) { - - // NDArray* input = INPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - - const int isNCHW = block.numI() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC - - // REQUIRE_TRUE(input->rankOf() == 4, 0, "UPSAMPLING2D_BP op: input array must be 4D, but got %i instead!", input->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "UPSAMPLING2D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf()); - REQUIRE_TRUE(gradI->rankOf() == 4, 0, "UPSAMPLING2D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf()); - - ConvolutionUtils::upsampling2dBP(block, *gradO, *gradI, (bool)isNCHW); - - return Status::OK(); + // NDArray* input = INPUT_VARIABLE(0); // [bS, iC, iH, iW] + // (NCHW) or [bS, iH, iW, iC] (NHWC) + auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) + // or [bS, factorH*iH, factorW*iW, iC] (NHWC) + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) + + const int isNCHW = + block.numI() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC + + // REQUIRE_TRUE(input->rankOf() == 4, 0, "UPSAMPLING2D_BP op: input array must + // be 4D, but got %i instead!", input->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "UPSAMPLING2D_BP op: output's gradient array must be 4D, but " + "got %i instead!", + gradO->rankOf()); + REQUIRE_TRUE(gradI->rankOf() == 4, 0, + "UPSAMPLING2D_BP op: input's gradient array must be 4D, but got " + "%i instead!", + gradI->rankOf()); + + ConvolutionUtils::upsampling2dBP(block, *gradO, *gradI, (bool)isNCHW); + + return Status::OK(); } DECLARE_SYN(upsampling_bp, upsampling2d_bp); - DECLARE_SHAPE_FN(upsampling2d_bp) { - - REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, "UPSAMPLING2D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]); - REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, "UPSAMPLING2D_BP op: output's gradient array must be 4D, but got %i instead!", inputShape->at(1)[0]); - - auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), inputShape->at(1), false, block.workspace()); - - return SHAPELIST(CONSTANT(gradIShapeInfo)); + REQUIRE_TRUE( + inputShape->at(0)[0] == 4, 0, + "UPSAMPLING2D_BP op: input array must be 4D, but got %i instead!", + inputShape->at(0)[0]); + REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, + "UPSAMPLING2D_BP op: output's gradient array must be 4D, but " + "got %i instead!", + inputShape->at(1)[0]); + + auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), inputShape->at(1), false, block.workspace()); + + return SHAPELIST(CONSTANT(gradIShapeInfo)); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp index dce61910eba1..4609ec4d6f1d 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/upsampling3d.cpp @@ -25,103 +25,131 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(upsampling3d, 1, 1, false, 0, 3) { - auto input = INPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - auto output = OUTPUT_NULLIFIED(0); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - - const int factorD = INT_ARG(0); - const int factorH = INT_ARG(1); - const int factorW = INT_ARG(2); - const int isNCDHW = block.numI() > 3 ? INT_ARG(3) : 0; // INT_ARG(3): 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "UPSAMPLING3D op: input should be 5D, but got %i instead!", input->rankOf()); - REQUIRE_TRUE(output->rankOf() == 5, 0, "UPSAMPLING3D op: output should be 5D, but got %i instead!", output->rankOf()); - - ConvolutionUtils::upsampling3d(block, *input, *output, factorD, factorH, factorW, (bool)isNCDHW); + auto input = INPUT_VARIABLE( + 0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) + auto output = OUTPUT_NULLIFIED( + 0); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, + // factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) + + const int factorD = INT_ARG(0); + const int factorH = INT_ARG(1); + const int factorW = INT_ARG(2); + const int isNCDHW = + block.numI() > 3 ? INT_ARG(3) : 0; // INT_ARG(3): 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "UPSAMPLING3D op: input should be 5D, but got %i instead!", + input->rankOf()); + REQUIRE_TRUE(output->rankOf() == 5, 0, + "UPSAMPLING3D op: output should be 5D, but got %i instead!", + output->rankOf()); + + ConvolutionUtils::upsampling3d(block, *input, *output, factorD, factorH, + factorW, (bool)isNCDHW); + + return Status::OK(); +} - return Status::OK(); +DECLARE_TYPES(upsampling3d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_TYPES(upsampling3d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - DECLARE_SHAPE_FN(upsampling3d) { - - auto inputShapeInfo = inputShape->at(0); - - REQUIRE_TRUE(inputShapeInfo[0] == 5, 0, "UPSAMPLING2D op: input should be 5D, but got %i instead!", inputShapeInfo[0]); - - const int factorD = INT_ARG(0); - const int factorH = INT_ARG(1); - const int factorW = INT_ARG(2); - const int isNCDHW = block.numI() > 3 ? INT_ARG(3) : 0; // INT_ARG(3): 0-NCHW, 1-NHWC - - Nd4jLong *outputShapeInfo = nullptr; - ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(inputShapeInfo[0]), Nd4jLong); - - outputShapeInfo[0] = inputShapeInfo[0]; - outputShapeInfo[1] = inputShapeInfo[1]; - - if(isNCDHW) { - outputShapeInfo[2] = inputShapeInfo[2]; - outputShapeInfo[3] = inputShapeInfo[3] * factorD; - outputShapeInfo[4] = inputShapeInfo[4] * factorH; - outputShapeInfo[5] = inputShapeInfo[5] * factorW; - } - else { - outputShapeInfo[2] = inputShapeInfo[2] * factorD; - outputShapeInfo[3] = inputShapeInfo[3] * factorH; - outputShapeInfo[4] = inputShapeInfo[4] * factorW; - outputShapeInfo[5] = inputShapeInfo[5]; - } - - ShapeUtils::updateStridesAndType(outputShapeInfo, inputShapeInfo, shape::order(inputShapeInfo)); - - return SHAPELIST(CONSTANT(outputShapeInfo)); + auto inputShapeInfo = inputShape->at(0); + + REQUIRE_TRUE(inputShapeInfo[0] == 5, 0, + "UPSAMPLING2D op: input should be 5D, but got %i instead!", + inputShapeInfo[0]); + + const int factorD = INT_ARG(0); + const int factorH = INT_ARG(1); + const int factorW = INT_ARG(2); + const int isNCDHW = + block.numI() > 3 ? INT_ARG(3) : 0; // INT_ARG(3): 0-NCHW, 1-NHWC + + Nd4jLong *outputShapeInfo = nullptr; + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(inputShapeInfo[0]), Nd4jLong); + + outputShapeInfo[0] = inputShapeInfo[0]; + outputShapeInfo[1] = inputShapeInfo[1]; + + if (isNCDHW) { + outputShapeInfo[2] = inputShapeInfo[2]; + outputShapeInfo[3] = inputShapeInfo[3] * factorD; + outputShapeInfo[4] = inputShapeInfo[4] * factorH; + outputShapeInfo[5] = inputShapeInfo[5] * factorW; + } else { + outputShapeInfo[2] = inputShapeInfo[2] * factorD; + outputShapeInfo[3] = inputShapeInfo[3] * factorH; + outputShapeInfo[4] = inputShapeInfo[4] * factorW; + outputShapeInfo[5] = inputShapeInfo[5]; + } + + ShapeUtils::updateStridesAndType(outputShapeInfo, inputShapeInfo, + shape::order(inputShapeInfo)); + + return SHAPELIST(CONSTANT(outputShapeInfo)); } - DECLARE_TYPES(upsampling3d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(upsampling3d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(upsampling3d_bp, 2, 1, false, 0, 0) { - // NDArray* input = INPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - - const int isNCDHW = block.numI() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC - - // REQUIRE_TRUE(input->rankOf() == 5, 0, "UPSAMPLING3D_BP op: input array must be 4D, but got %i instead!", input->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 5, 0, "UPSAMPLING3D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf()); - REQUIRE_TRUE(gradI->rankOf() == 5, 0, "UPSAMPLING3D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf()); - - ConvolutionUtils::upsampling3dBP(block, *gradO, *gradI, (bool)isNCDHW); - - return Status::OK(); + // NDArray* input = INPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] + // (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) + auto gradO = INPUT_VARIABLE( + 1); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, + // factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) + + const int isNCDHW = + block.numI() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC + + // REQUIRE_TRUE(input->rankOf() == 5, 0, "UPSAMPLING3D_BP op: input array must + // be 4D, but got %i instead!", input->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 5, 0, + "UPSAMPLING3D_BP op: output's gradient array must be 4D, but " + "got %i instead!", + gradO->rankOf()); + REQUIRE_TRUE(gradI->rankOf() == 5, 0, + "UPSAMPLING3D_BP op: input's gradient array must be 4D, but got " + "%i instead!", + gradI->rankOf()); + + ConvolutionUtils::upsampling3dBP(block, *gradO, *gradI, (bool)isNCDHW); + + return Status::OK(); } - DECLARE_SHAPE_FN(upsampling3d_bp) { - - REQUIRE_TRUE(inputShape->at(0)[0] == 5, 0, "UPSAMPLING3D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]); - REQUIRE_TRUE(inputShape->at(1)[0] == 5, 0, "UPSAMPLING3D_BP op: output's gradient array must be 4D, but got %i instead!", inputShape->at(1)[0]); - - auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), inputShape->at(1), false, block.workspace()); - - return SHAPELIST(CONSTANT(gradIShapeInfo)); + REQUIRE_TRUE( + inputShape->at(0)[0] == 5, 0, + "UPSAMPLING3D_BP op: input array must be 4D, but got %i instead!", + inputShape->at(0)[0]); + REQUIRE_TRUE(inputShape->at(1)[0] == 5, 0, + "UPSAMPLING3D_BP op: output's gradient array must be 4D, but " + "got %i instead!", + inputShape->at(1)[0]); + + auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(0), inputShape->at(1), false, block.workspace()); + + return SHAPELIST(CONSTANT(gradIShapeInfo)); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp b/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp index 117e4d909675..171dbc33f58b 100644 --- a/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/dot_product_attention.cpp @@ -24,202 +24,241 @@ #include #include - namespace sd { -namespace ops { - - CUSTOM_OP_IMPL(dot_product_attention, 3, -1, false, 0, 2) { - auto queries = INPUT_VARIABLE(0); - auto keys = INPUT_VARIABLE(1); - auto values = INPUT_VARIABLE(2); - auto mask = block.width() > 3 ? INPUT_VARIABLE(3) : nullptr; - - auto output = OUTPUT_VARIABLE(0); - NDArray* weights; - bool outputWeights = INT_ARG(1); - if(outputWeights){ - weights = OUTPUT_VARIABLE(1); - }else{ - auto weightShape = ShapeUtils::evalShapeForMatmul(keys->shapeInfo(), queries->shapeInfo(), true, false); - weights = new NDArray('c', weightShape, values->dataType(), block.launchContext()); - } - - int normalization = INT_ARG(0); - - REQUIRE_TRUE(queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), 0, - "dot_product_attention: Queries, Keys and Values must have same rank. " - "But got queries = %s, keys = %s, values = %s", ShapeUtils::shapeAsString(queries).c_str(), - ShapeUtils::shapeAsString(keys).c_str(), ShapeUtils::shapeAsString(values).c_str()); - - REQUIRE_TRUE(queries->rankOf() == 3 || queries->rankOf() == 4, 0, - "dot_product_attention: Queries, Keys and Values must be rank 3 arrays for single headed attention " - "or rank 4 arrays for multi headed attention. But got rank = %i", queries->rankOf()); - - REQUIRE_TRUE(queries->sizeAt(0) == keys->sizeAt(0) && keys->sizeAt(0) == values->sizeAt(0), 0, - "dot_product_attention: Queries, Keys and Values must have the same mini batch size. " - "But got queries = %i, keys = %i, values = %i", queries->sizeAt(0), keys->sizeAt(0), values->sizeAt(0)); - - REQUIRE_TRUE(queries->sizeAt(-2) == keys->sizeAt(-2), 0, - "dot_product_attention: Queries and Keys must have the same feature size. " - "But got queries = %i, keys = %i", queries->sizeAt(-2), keys->sizeAt(-2)); - - REQUIRE_TRUE(keys->sizeAt(-1) == values->sizeAt(-1), 0, - "dot_product_attention: Keys and Values must have the same timestep length. " - "But got keys = %i, values = %i", keys->sizeAt(-1), values->sizeAt(-1)); - - sd::ops::matmul mmul; - mmul.execute({keys, queries}, {weights}, {}, {1}, {}); - if(normalization) { - *weights /= sqrt((double)keys->sizeAt(-2)); - } - - if(mask != nullptr && mask->defined()){ - NDArray reshapedMask; - if(weights->rankOf() == 4){ - reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1}); - }else{ - reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), mask->sizeAt(1), 1}); - } - - // the mask is 0 for positions we want to skip, and 1 for positions we want to keep. By subtracting 1 from - // it we get -1 for those we want to skip and 0 for those we want to keep. Multiplying it by 1e9 then - // turns all of those we want to skip into very large negative values. By adding this to the weights - // before going through the softmax, we effectively push all masked positions to zero after softmax. - // - // we are using 1e9 to mean effectively infinity - *weights += (reshapedMask - 1) * 1e9; - } - - sd::ops::softmax softmax; - softmax.execute({weights}, std::vector{weights}, {}, {-2}, {}, {}, true); - - mmul.execute({values, weights}, {output}, {}, {}, {}); - - if(!outputWeights){ - delete weights; - } - - return Status::OK(); - } - - - DECLARE_TYPES(dot_product_attention) { - getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(dot_product_attention) { - auto query_shape = inputShape->at(0); - auto keys_shape = inputShape->at(1); - auto values_shape = inputShape->at(2); - - auto weights_shape = ConstantShapeHelper::getInstance()->createShapeInfo(sd::ArrayOptions::dataType(values_shape), 'c', ShapeUtils::evalShapeForMatmul(keys_shape, query_shape, true, false)); - auto output_shape = ConstantShapeHelper::getInstance()->createShapeInfo(sd::ArrayOptions::dataType(values_shape), 'c', ShapeUtils::evalShapeForMatmul(values_shape, weights_shape, false, false)); - - if(INT_ARG(1)){ - return SHAPELIST(output_shape, weights_shape); - }else{ - return SHAPELIST(output_shape); - } - +namespace ops { + +CUSTOM_OP_IMPL(dot_product_attention, 3, -1, false, 0, 2) { + auto queries = INPUT_VARIABLE(0); + auto keys = INPUT_VARIABLE(1); + auto values = INPUT_VARIABLE(2); + auto mask = block.width() > 3 ? INPUT_VARIABLE(3) : nullptr; + + auto output = OUTPUT_VARIABLE(0); + NDArray *weights; + bool outputWeights = INT_ARG(1); + if (outputWeights) { + weights = OUTPUT_VARIABLE(1); + } else { + auto weightShape = ShapeUtils::evalShapeForMatmul( + keys->shapeInfo(), queries->shapeInfo(), true, false); + weights = new NDArray('c', weightShape, values->dataType(), + block.launchContext()); + } + + int normalization = INT_ARG(0); + + REQUIRE_TRUE( + queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), + 0, + "dot_product_attention: Queries, Keys and Values must have same rank. " + "But got queries = %s, keys = %s, values = %s", + ShapeUtils::shapeAsString(queries).c_str(), + ShapeUtils::shapeAsString(keys).c_str(), + ShapeUtils::shapeAsString(values).c_str()); + + REQUIRE_TRUE(queries->rankOf() == 3 || queries->rankOf() == 4, 0, + "dot_product_attention: Queries, Keys and Values must be rank 3 " + "arrays for single headed attention " + "or rank 4 arrays for multi headed attention. But got rank = %i", + queries->rankOf()); + + REQUIRE_TRUE(queries->sizeAt(0) == keys->sizeAt(0) && + keys->sizeAt(0) == values->sizeAt(0), + 0, + "dot_product_attention: Queries, Keys and Values must have the " + "same mini batch size. " + "But got queries = %i, keys = %i, values = %i", + queries->sizeAt(0), keys->sizeAt(0), values->sizeAt(0)); + + REQUIRE_TRUE(queries->sizeAt(-2) == keys->sizeAt(-2), 0, + "dot_product_attention: Queries and Keys must have the same " + "feature size. " + "But got queries = %i, keys = %i", + queries->sizeAt(-2), keys->sizeAt(-2)); + + REQUIRE_TRUE(keys->sizeAt(-1) == values->sizeAt(-1), 0, + "dot_product_attention: Keys and Values must have the same " + "timestep length. " + "But got keys = %i, values = %i", + keys->sizeAt(-1), values->sizeAt(-1)); + + sd::ops::matmul mmul; + mmul.execute({keys, queries}, {weights}, {}, {1}, {}); + if (normalization) { + *weights /= sqrt((double)keys->sizeAt(-2)); + } + + if (mask != nullptr && mask->defined()) { + NDArray reshapedMask; + if (weights->rankOf() == 4) { + reshapedMask = mask->reshape(mask->ordering(), + {mask->sizeAt(0), 1, mask->sizeAt(1), 1}); + } else { + reshapedMask = mask->reshape(mask->ordering(), + {mask->sizeAt(0), mask->sizeAt(1), 1}); } - CUSTOM_OP_IMPL(dot_product_attention_bp, 4, 3, false, 0, 1) { - auto queries = INPUT_VARIABLE(0); - auto keys = INPUT_VARIABLE(1); - auto values = INPUT_VARIABLE(2); - auto eps = INPUT_VARIABLE(3); - auto mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; - - auto dLdq = OUTPUT_VARIABLE(0); - auto dLdk = OUTPUT_VARIABLE(1); - auto dLdv = OUTPUT_VARIABLE(2); - - int normalization = INT_ARG(0); - - - REQUIRE_TRUE(queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), 0, - "dot_product_attention: Queries, Keys and Values must have same rank. " - "But got queries = %s, keys = %s, values = %s", ShapeUtils::shapeAsString(queries).c_str(), - ShapeUtils::shapeAsString(keys).c_str(), ShapeUtils::shapeAsString(values).c_str()); - - REQUIRE_TRUE(queries->rankOf() == 3 || queries->rankOf() == 4, 0, - "dot_product_attention: Queries, Keys and Values must be rank 3 arrays for single headed attention " - "or rank 4 arrays for multi headed attention. But got rank = %i", queries->rankOf()); + // the mask is 0 for positions we want to skip, and 1 for positions we want + // to keep. By subtracting 1 from it we get -1 for those we want to skip and + // 0 for those we want to keep. Multiplying it by 1e9 then turns all of + // those we want to skip into very large negative values. By adding this to + // the weights before going through the softmax, we effectively push all + // masked positions to zero after softmax. + // + // we are using 1e9 to mean effectively infinity + *weights += (reshapedMask - 1) * 1e9; + } - REQUIRE_TRUE(queries->sizeAt(0) == keys->sizeAt(0) && keys->sizeAt(0) == values->sizeAt(0), 0, - "dot_product_attention: Queries, Keys and Values must have the same mini batch size. " - "But got queries = %i, keys = %i, values = %i", queries->sizeAt(0), keys->sizeAt(0), values->sizeAt(0)); + sd::ops::softmax softmax; + softmax.execute({weights}, std::vector{weights}, {}, {-2}, {}, {}, + true); - REQUIRE_TRUE(queries->sizeAt(-2) == keys->sizeAt(-2), 0, - "dot_product_attention: Queries and Keys must have the same feature size. " - "But got queries = %i, keys = %i", queries->sizeAt(-2), keys->sizeAt(-2)); + mmul.execute({values, weights}, {output}, {}, {}, {}); - REQUIRE_TRUE(keys->sizeAt(-1) == values->sizeAt(-1), 0, - "dot_product_attention: Keys and Values must have the same timestep length. " - "But got keys = %i, values = %i", keys->sizeAt(-1), values->sizeAt(-1)); + if (!outputWeights) { + delete weights; + } + return Status::OK(); +} - double factor; - if(normalization) - factor = sqrt((double)keys->sizeAt(-2)); - - auto weightShape = ShapeUtils::evalShapeForMatmul(keys->shapeInfo(), queries->shapeInfo(), true, false); - - sd::ops::matmul mmul; - NDArray preSoftmax('c', weightShape, values->dataType(), block.launchContext()); - mmul.execute({keys, queries}, {&preSoftmax},{}, {1}, {}); - - if(normalization) - preSoftmax /= factor; +DECLARE_TYPES(dot_product_attention) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); + getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); +} - if(mask != nullptr){ - NDArray reshapedMask; - if(preSoftmax.rankOf() == 4){ - reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1}); - }else{ - reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), mask->sizeAt(1), 1}); - } - preSoftmax += (reshapedMask - 1) * 1e9; - } +DECLARE_SHAPE_FN(dot_product_attention) { + auto query_shape = inputShape->at(0); + auto keys_shape = inputShape->at(1); + auto values_shape = inputShape->at(2); + + auto weights_shape = ConstantShapeHelper::getInstance()->createShapeInfo( + sd::ArrayOptions::dataType(values_shape), 'c', + ShapeUtils::evalShapeForMatmul(keys_shape, query_shape, true, false)); + auto output_shape = ConstantShapeHelper::getInstance()->createShapeInfo( + sd::ArrayOptions::dataType(values_shape), 'c', + ShapeUtils::evalShapeForMatmul(values_shape, weights_shape, false, + false)); + + if (INT_ARG(1)) { + return SHAPELIST(output_shape, weights_shape); + } else { + return SHAPELIST(output_shape); + } +} - NDArray weights('c', weightShape, values->dataType(), block.launchContext()); - sd::ops::softmax softmax; - softmax.execute({&preSoftmax}, {&weights},{}, {-2}, {}); +CUSTOM_OP_IMPL(dot_product_attention_bp, 4, 3, false, 0, 1) { + auto queries = INPUT_VARIABLE(0); + auto keys = INPUT_VARIABLE(1); + auto values = INPUT_VARIABLE(2); + auto eps = INPUT_VARIABLE(3); + auto mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; + + auto dLdq = OUTPUT_VARIABLE(0); + auto dLdk = OUTPUT_VARIABLE(1); + auto dLdv = OUTPUT_VARIABLE(2); + + int normalization = INT_ARG(0); + + REQUIRE_TRUE( + queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), + 0, + "dot_product_attention: Queries, Keys and Values must have same rank. " + "But got queries = %s, keys = %s, values = %s", + ShapeUtils::shapeAsString(queries).c_str(), + ShapeUtils::shapeAsString(keys).c_str(), + ShapeUtils::shapeAsString(values).c_str()); + + REQUIRE_TRUE(queries->rankOf() == 3 || queries->rankOf() == 4, 0, + "dot_product_attention: Queries, Keys and Values must be rank 3 " + "arrays for single headed attention " + "or rank 4 arrays for multi headed attention. But got rank = %i", + queries->rankOf()); + + REQUIRE_TRUE(queries->sizeAt(0) == keys->sizeAt(0) && + keys->sizeAt(0) == values->sizeAt(0), + 0, + "dot_product_attention: Queries, Keys and Values must have the " + "same mini batch size. " + "But got queries = %i, keys = %i, values = %i", + queries->sizeAt(0), keys->sizeAt(0), values->sizeAt(0)); + + REQUIRE_TRUE(queries->sizeAt(-2) == keys->sizeAt(-2), 0, + "dot_product_attention: Queries and Keys must have the same " + "feature size. " + "But got queries = %i, keys = %i", + queries->sizeAt(-2), keys->sizeAt(-2)); + + REQUIRE_TRUE(keys->sizeAt(-1) == values->sizeAt(-1), 0, + "dot_product_attention: Keys and Values must have the same " + "timestep length. " + "But got keys = %i, values = %i", + keys->sizeAt(-1), values->sizeAt(-1)); + + double factor; + if (normalization) factor = sqrt((double)keys->sizeAt(-2)); + + auto weightShape = ShapeUtils::evalShapeForMatmul( + keys->shapeInfo(), queries->shapeInfo(), true, false); + + sd::ops::matmul mmul; + NDArray preSoftmax('c', weightShape, values->dataType(), + block.launchContext()); + mmul.execute({keys, queries}, {&preSoftmax}, {}, {1}, {}); + + if (normalization) preSoftmax /= factor; + + if (mask != nullptr) { + NDArray reshapedMask; + if (preSoftmax.rankOf() == 4) { + reshapedMask = mask->reshape(mask->ordering(), + {mask->sizeAt(0), 1, mask->sizeAt(1), 1}); + } else { + reshapedMask = mask->reshape(mask->ordering(), + {mask->sizeAt(0), mask->sizeAt(1), 1}); + } + preSoftmax += (reshapedMask - 1) * 1e9; + } - sd::ops::matmul_bp mmul_bp; - NDArray dLdw(weights.shapeInfo(), block.workspace()); - mmul_bp.execute({values, &weights, eps}, std::vector{dLdv, &dLdw}, {}, {}, {}); + NDArray weights('c', weightShape, values->dataType(), block.launchContext()); + sd::ops::softmax softmax; + softmax.execute({&preSoftmax}, {&weights}, {}, {-2}, {}); - NDArray dLds(preSoftmax.shapeInfo(), block.workspace()); - sd::ops::softmax_bp softmax_bp; - softmax_bp.execute({&preSoftmax, &dLdw}, {&dLds}, {}, {-2}, {}); + sd::ops::matmul_bp mmul_bp; + NDArray dLdw(weights.shapeInfo(), block.workspace()); + mmul_bp.execute({values, &weights, eps}, std::vector{dLdv, &dLdw}, + {}, {}, {}); - if(normalization) - dLds /= factor; + NDArray dLds(preSoftmax.shapeInfo(), block.workspace()); + sd::ops::softmax_bp softmax_bp; + softmax_bp.execute({&preSoftmax, &dLdw}, {&dLds}, {}, {-2}, {}); - mmul_bp.execute({keys, queries, &dLds}, std::vector{dLdk, dLdq}, {}, {1}, {}); + if (normalization) dLds /= factor; - return Status::OK(); - } + mmul_bp.execute({keys, queries, &dLds}, std::vector{dLdk, dLdq}, + {}, {1}, {}); - DECLARE_TYPES(dot_product_attention_bp) { - getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(dot_product_attention_bp) { - Nd4jLong *dLdq_shape; - COPY_SHAPE(inputShape->at(0), dLdq_shape); - Nd4jLong *dLdk_shape; - COPY_SHAPE(inputShape->at(1), dLdk_shape); - Nd4jLong *dLdv_shape; - COPY_SHAPE(inputShape->at(2), dLdv_shape); +DECLARE_TYPES(dot_product_attention_bp) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); + getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); +} - return SHAPELIST(CONSTANT(dLdq_shape), CONSTANT(dLdk_shape), CONSTANT(dLdv_shape)); - } +DECLARE_SHAPE_FN(dot_product_attention_bp) { + Nd4jLong *dLdq_shape; + COPY_SHAPE(inputShape->at(0), dLdq_shape); + Nd4jLong *dLdk_shape; + COPY_SHAPE(inputShape->at(1), dLdk_shape); + Nd4jLong *dLdv_shape; + COPY_SHAPE(inputShape->at(2), dLdv_shape); + return SHAPELIST(CONSTANT(dLdq_shape), CONSTANT(dLdk_shape), + CONSTANT(dLdv_shape)); } -} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/embedding_lookup.cpp b/libnd4j/include/ops/declarable/generic/nn/embedding_lookup.cpp index e015af2dab91..963f7622ae87 100644 --- a/libnd4j/include/ops/declarable/generic/nn/embedding_lookup.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/embedding_lookup.cpp @@ -21,99 +21,102 @@ #include #if NOT_EXCLUDED(OP_embedding_lookup) -#include #include -#include -#include +#include +#include +#include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); // lookup param - auto indeces = INPUT_VARIABLE(1); // indeces, as is - auto output = OUTPUT_VARIABLE(0); // - - if (block.width() > 2) { // multiple input - indeces = INPUT_VARIABLE(block.width() - 1); - std::vector dims(input->rankOf()); - int i = output->rankOf() - input->rankOf(); - for (auto& v: dims){ - v = i++; - } - - ResultSet outputView = output->allTensorsAlongDimension(dims); - REQUIRE_TRUE(block.width() > output->sizeAt(0), 0, "embedding_lookup: input list should be greater then %i, but %i given.", - output->sizeAt(0), block.width() - ); - for (Nd4jLong e = 0; e < indeces->lengthOf(); ++e) { - Nd4jLong thisIndex = (*indeces).e(e); - input = INPUT_VARIABLE(thisIndex); // lookup param - - outputView.at(e).assign(input); - } + auto input = INPUT_VARIABLE(0); // lookup param + auto indeces = INPUT_VARIABLE(1); // indeces, as is + auto output = OUTPUT_VARIABLE(0); // + + if (block.width() > 2) { // multiple input + indeces = INPUT_VARIABLE(block.width() - 1); + std::vector dims(input->rankOf()); + int i = output->rankOf() - input->rankOf(); + for (auto& v : dims) { + v = i++; } - else { - int indexRank = indeces->rankOf(); - REQUIRE_TRUE(indexRank > 0, 0, "embeded_lookup: input array of indexes can't be single scalar, the requirement is: rank > 0 !"); - - int inputRank = input->rankOf(); - int lastIndDim = indeces->lengthOf(); - int partition_mode = INT_ARG(0); // partition_mode == 0 - i.e. 'mod' , 1 - 'div' - sd::ops::gather op; + ResultSet outputView = output->allTensorsAlongDimension(dims); + REQUIRE_TRUE( + block.width() > output->sizeAt(0), 0, + "embedding_lookup: input list should be greater then %i, but %i given.", + output->sizeAt(0), block.width()); + for (Nd4jLong e = 0; e < indeces->lengthOf(); ++e) { + Nd4jLong thisIndex = (*indeces).e(e); + input = INPUT_VARIABLE(thisIndex); // lookup param - auto result(op.evaluate({input, indeces}, {0})); - REQUIRE_TRUE(result.status() == Status::OK(), 0, "embedding_lookup: cannot retrieve results from gather op."); - REQUIRE_TRUE(result.at(0).isSameShape(output), 0, "embedding_lookup: wrong shape of return from gather op."); - output->assign(result.at(0)); + outputView.at(e).assign(input); } - return Status::OK(); + } else { + int indexRank = indeces->rankOf(); + REQUIRE_TRUE(indexRank > 0, 0, + "embeded_lookup: input array of indexes can't be single " + "scalar, the requirement is: rank > 0 !"); + + int inputRank = input->rankOf(); + int lastIndDim = indeces->lengthOf(); + int partition_mode = + INT_ARG(0); // partition_mode == 0 - i.e. 'mod' , 1 - 'div' + + sd::ops::gather op; + + auto result(op.evaluate({input, indeces}, {0})); + REQUIRE_TRUE(result.status() == Status::OK(), 0, + "embedding_lookup: cannot retrieve results from gather op."); + REQUIRE_TRUE(result.at(0).isSameShape(output), 0, + "embedding_lookup: wrong shape of return from gather op."); + output->assign(result.at(0)); + } + return Status::OK(); } DECLARE_TYPES(embedding_lookup) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); } DECLARE_SHAPE_FN(embedding_lookup) { + auto inShapeInfo = inputShape->at(0); + auto indecesShapeInfo = inputShape->at(1); + int inRank = shape::rank(inShapeInfo); + if (inputShape->size() == 2u) { + int outRank = inRank; - auto inShapeInfo = inputShape->at(0); - auto indecesShapeInfo = inputShape->at(1); - int inRank = shape::rank(inShapeInfo); - if (inputShape->size() == 2u) { - int outRank = inRank; - - std::vector shapeInfo(outRank); - - shapeInfo[0] = indecesShapeInfo[1]; // vector - how many elements - for (int e = 1; e < outRank; e++) - shapeInfo[e] = shape::sizeAt(inShapeInfo, e); - - auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), shapeInfo); - return SHAPELIST(outShapeInfo); - } - - - int outRank = inRank + 1; std::vector shapeInfo(outRank); - auto indeces = INPUT_VARIABLE(block.width() - 1); - shapeInfo[0] = indeces->lengthOf(); // vector - how many elements + + shapeInfo[0] = indecesShapeInfo[1]; // vector - how many elements for (int e = 1; e < outRank; e++) - shapeInfo[e] = shape::sizeAt(inShapeInfo, e); + shapeInfo[e] = shape::sizeAt(inShapeInfo, e); - auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), shapeInfo); + auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), + shapeInfo); return SHAPELIST(outShapeInfo); + } + + int outRank = inRank + 1; + std::vector shapeInfo(outRank); + auto indeces = INPUT_VARIABLE(block.width() - 1); + shapeInfo[0] = indeces->lengthOf(); // vector - how many elements + for (int e = 1; e < outRank; e++) + shapeInfo[e] = shape::sizeAt(inShapeInfo, e); + + auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), + shapeInfo); + return SHAPELIST(outShapeInfo); } - - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp index 9fc71572dd43..a5ff3038e7d7 100644 --- a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp @@ -26,130 +26,147 @@ namespace sd { namespace ops { - DECLARE_TYPES(fused_batch_norm) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(fused_batch_norm) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) { - auto x = INPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW) - auto scale = INPUT_VARIABLE(1); // [iD] - auto offset = INPUT_VARIABLE(2); // [iD] - - auto y = OUTPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW) - auto batchMean = OUTPUT_VARIABLE(1); // [iD] - auto batchVar = OUTPUT_VARIABLE(2); // [iD] - - const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW - const bool isTraining = (bool)INT_ARG(1); - - REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf()); - - int bS = x->sizeAt(0); // batch size - int iH, iW, iD; // input height, input width, input depth(number of channels) - if(dataFormat) { - iD = x->sizeAt(1); - iH = x->sizeAt(2); - iW = x->sizeAt(3); - } - else { - iD = x->sizeAt(3); - iH = x->sizeAt(1); - iW = x->sizeAt(2); - } - - REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str()); - REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str()); - - NDArray *mean(nullptr), *variance(nullptr); - if(!isTraining){ - mean = INPUT_VARIABLE(3); - variance = INPUT_VARIABLE(4); - REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input mean array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str()); - } - else { - //REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width()); - std::vector shape = {iD}; - mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()); - variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()); - } - - // FIXME: double? - double epsilon; - if(block.numT() > 0) - epsilon = T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5; - else - epsilon = 0.001; - - const int restSize = x->lengthOf() / iD; - auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext()); - xAffected.assign(x); - - const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1; - // FIXME: float? - const double restSizeInv = 1.0 / restSize; - const double restSizeAdjust = (double)restSize / restSizeMinusOne; - - if(isTraining) { - auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0}); - sum *= restSizeInv; - mean->assign(sum); - *batchMean = *mean; - //delete sum; - } - else - *batchMean = 0.; - - xAffected -= *mean; - - if(isTraining) { - int power = 2; - xAffected.applyScalar(scalar::Pow, power, xAffected); - auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0}); - sum *= restSizeInv; - variance->assign(sum); - *batchVar = (*variance) * restSizeAdjust; - //delete sum; - } - else - *batchVar = 0.; - xAffected *= (*variance + epsilon).transform(transform::RSqrt) * (*scale) + (*offset); - y->assign( xAffected ); - - if(isTraining) { - delete mean; - delete variance; - } - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW) + auto scale = INPUT_VARIABLE(1); // [iD] + auto offset = INPUT_VARIABLE(2); // [iD] + + auto y = OUTPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW) + auto batchMean = OUTPUT_VARIABLE(1); // [iD] + auto batchVar = OUTPUT_VARIABLE(2); // [iD] + + const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW + const bool isTraining = (bool)INT_ARG(1); + + REQUIRE_TRUE(x->rankOf() == 4, 0, + "CUSTOM_OP fused_batch_norm: the rank of input x array must be " + "equal to 4, but got %i instead !", + x->rankOf()); + + int bS = x->sizeAt(0); // batch size + int iH, iW, iD; // input height, input width, input depth(number of channels) + if (dataFormat) { + iD = x->sizeAt(1); + iH = x->sizeAt(2); + iW = x->sizeAt(3); + } else { + iD = x->sizeAt(3); + iH = x->sizeAt(1); + iW = x->sizeAt(2); + } + + REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, + "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, " + "expected is [%i], but got %s instead", + iD, ShapeUtils::shapeAsString(scale).c_str()); + REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, + "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, " + "expected is [%i], but got %s instead", + iD, ShapeUtils::shapeAsString(offset).c_str()); + + NDArray *mean(nullptr), *variance(nullptr); + if (!isTraining) { + mean = INPUT_VARIABLE(3); + variance = INPUT_VARIABLE(4); + REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == iD, 0, + "CUSTOM_OP fused_batch_norm: wrong shape of input mean array, " + "expected is [%i], but got %s instead", + iD, ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, + "CUSTOM_OP fused_batch_norm: wrong shape of input variance " + "array, expected is [%i], but got %s instead", + iD, ShapeUtils::shapeAsString(variance).c_str()); + } else { + // REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when + // isTraining=true then number of input arrays must be equal to 3, but got %i + // instead !", block.width()); + std::vector shape = {iD}; + mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), + block.launchContext()); + variance = NDArrayFactory::create_( + scale->ordering(), shape, scale->dataType(), block.launchContext()); + } + + // FIXME: double? + double epsilon; + if (block.numT() > 0) + epsilon = T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5; + else + epsilon = 0.001; + + const int restSize = x->lengthOf() / iD; + auto xAffected = NDArrayFactory::create( + x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext()); + xAffected.assign(x); + + const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1; + // FIXME: float? + const double restSizeInv = 1.0 / restSize; + const double restSizeAdjust = (double)restSize / restSizeMinusOne; + + if (isTraining) { + auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0}); + sum *= restSizeInv; + mean->assign(sum); + *batchMean = *mean; + // delete sum; + } else + *batchMean = 0.; + + xAffected -= *mean; + + if (isTraining) { + int power = 2; + xAffected.applyScalar(scalar::Pow, power, xAffected); + auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0}); + sum *= restSizeInv; + variance->assign(sum); + *batchVar = (*variance) * restSizeAdjust; + // delete sum; + } else + *batchVar = 0.; + xAffected *= + (*variance + epsilon).transform(transform::RSqrt) * (*scale) + (*offset); + y->assign(xAffected); + + if (isTraining) { + delete mean; + delete variance; + } + + return Status::OK(); } - - DECLARE_SHAPE_FN(fused_batch_norm) { - auto xShapeInfo = inputShape->at(0); - auto scaleShapeInfo = inputShape->at(1); + auto xShapeInfo = inputShape->at(0); + auto scaleShapeInfo = inputShape->at(1); - const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW - const int iD = dataFormat ? xShapeInfo[2] : xShapeInfo[4]; + const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW + const int iD = dataFormat ? xShapeInfo[2] : xShapeInfo[4]; - REQUIRE_TRUE(scaleShapeInfo[0] == 1 && scaleShapeInfo[1] == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scaleShapeInfo).c_str()); + REQUIRE_TRUE(scaleShapeInfo[0] == 1 && scaleShapeInfo[1] == iD, 0, + "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, " + "expected is [%i], but got %s instead", + iD, ShapeUtils::shapeAsString(scaleShapeInfo).c_str()); - Nd4jLong* outShapeInfo(nullptr), *batchMeanShapeInfo(nullptr), *batchVarShapeInfo(nullptr); + Nd4jLong *outShapeInfo(nullptr), *batchMeanShapeInfo(nullptr), + *batchVarShapeInfo(nullptr); - COPY_SHAPE(xShapeInfo, outShapeInfo); - COPY_SHAPE(scaleShapeInfo, batchMeanShapeInfo); - COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo); + COPY_SHAPE(xShapeInfo, outShapeInfo); + COPY_SHAPE(scaleShapeInfo, batchMeanShapeInfo); + COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo); - return SHAPELIST(CONSTANT(outShapeInfo), CONSTANT(batchMeanShapeInfo), CONSTANT(batchVarShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo), CONSTANT(batchMeanShapeInfo), + CONSTANT(batchVarShapeInfo)); } - - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/layer_norm.cpp b/libnd4j/include/ops/declarable/generic/nn/layer_norm.cpp index 1371d6b68b43..f21242937cab 100644 --- a/libnd4j/include/ops/declarable/generic/nn/layer_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/layer_norm.cpp @@ -22,124 +22,147 @@ #if NOT_EXCLUDED(OP_layer_norm) #include -#include #include +#include namespace sd { -namespace ops { +namespace ops { + +CONFIGURABLE_OP_IMPL(layer_norm, 2, 1, false, 0, -1) { + auto input = INPUT_VARIABLE(0); + auto gain = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + auto axis = block.getIArguments(); + + const bool isNCHW = + block.numB() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC + const int dimC = isNCHW ? 1 : input->rankOf() - 1; + + REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, + "LAYER_NORM OP: wrong shape of gain array, expected is {%i}, " + "but got %s instead !", + input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str()); + + NDArray *bias = nullptr; + if (block.width() > 2) { + bias = INPUT_VARIABLE(2); + REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), + 0, + "LAYER_NORM OP: wrong shape of bias array, expected is {%i}, " + "but got %s instead !", + input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str()); + } + + std::vector longAxis = ArrayUtils::toLongVector(axis); + + sd::ops::standardize standardizeOp; + std::vector inputs = {input}; + std::vector outputs = {output}; + std::vector targs = {}; + std::vector bargs = {}; + standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); + + // output->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), gain, + // output); + output->applyBroadcast(sd::broadcast::Multiply, {dimC}, *gain, *output); + if (bias != nullptr) { + // output->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), bias, output); + // output->applyBroadcast(sd::broadcast::Add, {dimC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCHW); + } + + return Status::OK(); +} - CONFIGURABLE_OP_IMPL(layer_norm, 2, 1, false, 0, -1) { - auto input = INPUT_VARIABLE(0); - auto gain = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); +DECLARE_TYPES(layer_norm) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); + getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); +} - auto axis = block.getIArguments(); - - const bool isNCHW = block.numB() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC - const int dimC = isNCHW ? 1 : input->rankOf() - 1; - - REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str()); - - NDArray* bias = nullptr; - if (block.width() > 2) { - bias = INPUT_VARIABLE(2); - REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str()); - } - - std::vector longAxis = ArrayUtils::toLongVector(axis); - - sd::ops::standardize standardizeOp; - std::vector inputs = {input}; - std::vector outputs = {output}; - std::vector targs = {}; - std::vector bargs = {}; - standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); - - // output->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), gain, output); - output->applyBroadcast(sd::broadcast::Multiply, {dimC}, *gain, *output); - if(bias != nullptr) { - // output->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), bias, output); - // output->applyBroadcast(sd::broadcast::Add, {dimC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCHW); - } - - return Status::OK(); - } - - - DECLARE_TYPES(layer_norm) { - getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(layer_norm_bp, 3, -1, false, 0, -1) { - auto input = INPUT_VARIABLE(0); - auto gain = INPUT_VARIABLE(1); - auto bias = block.width() == 4 ? INPUT_VARIABLE(2) : nullptr; - auto eps = block.width() == 4 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); - - auto dLdx = OUTPUT_VARIABLE(0); - auto dLdg = OUTPUT_VARIABLE(1); - auto dLdb = block.width() == 4 ? OUTPUT_VARIABLE(2) : nullptr; - - const bool isNCHW = block.numB() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC - const int dimC = isNCHW ? 1 : input->rankOf() - 1; - - REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str()); - - auto axis = block.getIArguments(); - - std::vector longAxis = ArrayUtils::toLongVector(axis); - - if(bias != nullptr) { - REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str()); - // eps->reduceAlongDimension(sd::reduce::Sum, *dLdb, {0}, true); - eps->reduceAlongDimension(sd::reduce::Sum, *dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); - } - - NDArray standardized(input->shapeInfo(), false, block.launchContext()); - - sd::ops::standardize standardizeOp; - std::vector inputs = {input}; - std::vector outputs = {&standardized}; - std::vector targs = {}; - std::vector bargs = {}; - - standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); - standardized.applyPairwiseTransform(sd::pairwise::Multiply, *eps, standardized); - standardized.reduceAlongDimension(sd::reduce::Sum, *dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); - - sd::ops::standardize_bp standardizeBp; - // eps->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), gain, dLdx); - eps->applyBroadcast(sd::broadcast::Multiply, {dimC}, *gain, *dLdx); - - auto dLdx_tmp = dLdx->dup(); - std::vector standardizeBpArgs = {input, &dLdx_tmp}; - std::vector standardizeBpOut = {dLdx}; - standardizeBp.execute(standardizeBpArgs, standardizeBpOut, targs, longAxis, bargs); - - return Status::OK(); - } - - DECLARE_TYPES(layer_norm_bp) { - getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(layer_norm_bp) { - Nd4jLong *dLdx_shape; - COPY_SHAPE(inputShape->at(0), dLdx_shape); - Nd4jLong *dLdg_shape; - COPY_SHAPE(inputShape->at(1), dLdg_shape); - if(inputShape->size() > 3){ - Nd4jLong *dLdb_shape; - COPY_SHAPE(inputShape->at(2), dLdb_shape); - return SHAPELIST(CONSTANT(dLdx_shape), CONSTANT(dLdg_shape), CONSTANT(dLdb_shape)); - } - return SHAPELIST(CONSTANT(dLdx_shape), CONSTANT(dLdg_shape)); - } +CUSTOM_OP_IMPL(layer_norm_bp, 3, -1, false, 0, -1) { + auto input = INPUT_VARIABLE(0); + auto gain = INPUT_VARIABLE(1); + auto bias = block.width() == 4 ? INPUT_VARIABLE(2) : nullptr; + auto eps = block.width() == 4 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdg = OUTPUT_VARIABLE(1); + auto dLdb = block.width() == 4 ? OUTPUT_VARIABLE(2) : nullptr; + + const bool isNCHW = + block.numB() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC + const int dimC = isNCHW ? 1 : input->rankOf() - 1; + + REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, + "LAYER_NORM_BP OP: wrong shape of gain array, expected is {%i}, " + "but got %s instead !", + input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str()); + + auto axis = block.getIArguments(); + + std::vector longAxis = ArrayUtils::toLongVector(axis); + + if (bias != nullptr) { + REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), + 0, + "LAYER_NORM_BP OP: wrong shape of bias array, expected is " + "{%i}, but got %s instead !", + input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str()); + // eps->reduceAlongDimension(sd::reduce::Sum, *dLdb, {0}, true); + eps->reduceAlongDimension( + sd::reduce::Sum, *dLdb, + ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); + } + + NDArray standardized(input->shapeInfo(), false, block.launchContext()); + + sd::ops::standardize standardizeOp; + std::vector inputs = {input}; + std::vector outputs = {&standardized}; + std::vector targs = {}; + std::vector bargs = {}; + + standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); + standardized.applyPairwiseTransform(sd::pairwise::Multiply, *eps, + standardized); + standardized.reduceAlongDimension( + sd::reduce::Sum, *dLdg, + ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); + + sd::ops::standardize_bp standardizeBp; + // eps->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), gain, dLdx); + eps->applyBroadcast(sd::broadcast::Multiply, {dimC}, *gain, *dLdx); + + auto dLdx_tmp = dLdx->dup(); + std::vector standardizeBpArgs = {input, &dLdx_tmp}; + std::vector standardizeBpOut = {dLdx}; + standardizeBp.execute(standardizeBpArgs, standardizeBpOut, targs, longAxis, + bargs); + + return Status::OK(); +} +DECLARE_TYPES(layer_norm_bp) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); + getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); } + +DECLARE_SHAPE_FN(layer_norm_bp) { + Nd4jLong *dLdx_shape; + COPY_SHAPE(inputShape->at(0), dLdx_shape); + Nd4jLong *dLdg_shape; + COPY_SHAPE(inputShape->at(1), dLdg_shape); + if (inputShape->size() > 3) { + Nd4jLong *dLdb_shape; + COPY_SHAPE(inputShape->at(2), dLdb_shape); + return SHAPELIST(CONSTANT(dLdx_shape), CONSTANT(dLdg_shape), + CONSTANT(dLdb_shape)); + } + return SHAPELIST(CONSTANT(dLdx_shape), CONSTANT(dLdg_shape)); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp b/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp index be9653a9d3ec..12e71befd578 100644 --- a/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/logSoftmax.cpp @@ -27,55 +27,59 @@ namespace sd { namespace ops { - - DECLARE_TYPES(log_softmax) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setSameMode(true); - } +DECLARE_TYPES(log_softmax) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); +} CONFIGURABLE_OP_IMPL(log_softmax, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - const int rank = input->rankOf(); - const int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; + const int rank = input->rankOf(); + const int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; - REQUIRE_TRUE(dim < rank, 0, "LOG_SOFTMAX OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); + REQUIRE_TRUE( + dim < rank, 0, + "LOG_SOFTMAX OP: the value of input integer parameter (dimension) must " + "be less than input array rank %i, but got dimension = %i instead !", + rank, dim); - helpers::logSoftmax(block.launchContext(), *input, *output, dim); + helpers::logSoftmax(block.launchContext(), *input, *output, dim); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(log_softmax_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - +DECLARE_TYPES(log_softmax_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} CONFIGURABLE_OP_IMPL(log_softmax_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); - const int rank = input->rankOf(); - const int dim = block.numI()> 0 ? INT_ARG(0) : rank - 1; + const int rank = input->rankOf(); + const int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; - REQUIRE_TRUE(dim < rank, 0, "LOG_SOFTMAX_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); + REQUIRE_TRUE( + dim < rank, 0, + "LOG_SOFTMAX_BP OP: the value of input integer parameter (dimension) " + "must be less than input array rank %i, but got dimension = %i instead !", + rank, dim); - helpers::softmax(block.launchContext(), *input, *gradI, dim); + helpers::softmax(block.launchContext(), *input, *gradI, dim); - gradI->assign( *gradO - (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true) ); + gradI->assign( + *gradO - + (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true)); - return Status::OK(); + return Status::OK(); } - - -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/lrn.cpp b/libnd4j/include/ops/declarable/generic/nn/lrn.cpp index e9546d1db711..44653807e8ce 100644 --- a/libnd4j/include/ops/declarable/generic/nn/lrn.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/lrn.cpp @@ -23,59 +23,67 @@ #include #if NOT_EXCLUDED(OP_lrn) -#include #include +#include namespace sd { - namespace ops { - - DECLARE_TYPES(lrn) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CONFIGURABLE_OP_IMPL(lrn, 1, 1, true, 3, 1) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "lrn: Input rank of 4 expected, but got %i instead", input->rankOf()); - - double alpha = T_ARG(1); - double beta = T_ARG(2); - double bias = T_ARG(0); - int depth = INT_ARG(0); - - return helpers::lrnFunctor(block, input, output, depth, bias, alpha, beta); - } - - DECLARE_TYPES(lrn_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - CONFIGURABLE_OP_IMPL(lrn_bp, 2, 1, true, 3, 1) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "lrn_bp: Input rank of 4 expected, but got %i instead", input->rankOf()); - REQUIRE_TRUE(input->isSameShape(gradO), 0, "lrn_bp: Both input and grad_output should have the same shape, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - - // FIXME: double/float? - float bias = T_ARG(0); - float alpha = T_ARG(1); - float beta = T_ARG(2); - int depth = INT_ARG(0); - - helpers::lrnBP(block, *input, *gradO, *gradI, depth, bias, alpha, beta); - - return Status::OK(); - } - DECLARE_SYN(local_response_normalization, lrn); - - } +namespace ops { + +DECLARE_TYPES(lrn) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +CONFIGURABLE_OP_IMPL(lrn, 1, 1, true, 3, 1) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "lrn: Input rank of 4 expected, but got %i instead", + input->rankOf()); + + double alpha = T_ARG(1); + double beta = T_ARG(2); + double bias = T_ARG(0); + int depth = INT_ARG(0); + + return helpers::lrnFunctor(block, input, output, depth, bias, alpha, beta); +} + +DECLARE_TYPES(lrn_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +CONFIGURABLE_OP_IMPL(lrn_bp, 2, 1, true, 3, 1) { + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "lrn_bp: Input rank of 4 expected, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(input->isSameShape(gradO), 0, + "lrn_bp: Both input and grad_output should have the same shape, " + "but got %s and %s correspondingly !", + ShapeUtils::shapeAsString(input).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + + // FIXME: double/float? + float bias = T_ARG(0); + float alpha = T_ARG(1); + float beta = T_ARG(2); + int depth = INT_ARG(0); + + helpers::lrnBP(block, *input, *gradO, *gradI, depth, bias, alpha, beta); + + return Status::OK(); +} +DECLARE_SYN(local_response_normalization, lrn); + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp b/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp index f9b7284f1e74..eee3c108a049 100644 --- a/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp @@ -21,268 +21,357 @@ #include #if NOT_EXCLUDED(OP_multi_head_dot_product_attention) -#include #include +#include namespace sd { -namespace ops { - - CUSTOM_OP_IMPL(multi_head_dot_product_attention, 7, -1, false, 0, 2) { - auto queries = INPUT_VARIABLE(0); //[batch, nIn, timeSteps] - auto keys = INPUT_VARIABLE(1); //[batch, nIn, timeSteps] - auto values = INPUT_VARIABLE(2); //[batch, nIn, timeSteps] - auto Wq = INPUT_VARIABLE(3); //[nHeads, headSize, nIn] - auto Wk = INPUT_VARIABLE(4); //[nHeads, headSize, nIn] - auto Wv = INPUT_VARIABLE(5); //[nHeads, headSize, nIn] - auto Wo = INPUT_VARIABLE(6); //[nHeads * headSize, nOut] - auto mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; - - - auto output = OUTPUT_VARIABLE(0); - int normalization = INT_ARG(0); - int weights = INT_ARG(1); - - auto numHeads = Wk->sizeAt(0); - auto miniBatchSize = queries->sizeAt(0); - auto queryCount = queries->sizeAt(2); - auto projectedValuesSize = Wv->sizeAt(1); - auto outSize = Wo->sizeAt(1); - - REQUIRE_TRUE(queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), 0, - "multi_head_dot_product_attention: Queries, Keys and Values must have same rank. " - "But got queries = %s, keys = %s, values = %s", ShapeUtils::shapeAsString(queries).c_str(), - ShapeUtils::shapeAsString(keys).c_str(), ShapeUtils::shapeAsString(values).c_str()); - - REQUIRE_TRUE(queries->rankOf() == 3, 0, - "multi_head_dot_product_attention: Queries, Keys and Values must be rank 3 arrays" - "But got rank = %i", queries->rankOf()); - - REQUIRE_TRUE(Wq->rankOf() == Wk->rankOf() && Wk->rankOf() == Wv->rankOf(), 0, - "multi_head_dot_product_attention: Input projections weights must have the same rank. " - "But got Wq = %s, Wk = %s, Wv = %s", ShapeUtils::shapeAsString(Wq).c_str(), - ShapeUtils::shapeAsString(Wk).c_str(), ShapeUtils::shapeAsString(Wv).c_str()); - - REQUIRE_TRUE(Wq->sizeAt(0) == Wk->sizeAt(0) && Wk->sizeAt(0) == Wv->sizeAt(0), 0, - "multi_head_dot_product_attention: Projections weights must have the same number of attention heads. " - "But got Wq = %s, Wk = %s, Wv = %s", ShapeUtils::shapeAsString(Wq).c_str(), - ShapeUtils::shapeAsString(Wk).c_str(), ShapeUtils::shapeAsString(Wv).c_str()); - - REQUIRE_TRUE(Wo->rankOf() == 2, 0, - "multi_head_dot_product_attention: Output projection weights must have rank 2. " - "But got Wo = %s", ShapeUtils::shapeAsString(Wo).c_str()); - - REQUIRE_TRUE(Wq->sizeAt(2) == queries->sizeAt(1), 0, - "multi_head_dot_product_attention: Query projection matrix Wq has incompatible size to queries matrix." - "Expected Wq[2] = queries[1] = %i, but got Wq = %s, queries = %s ", queries->sizeAt(1), - ShapeUtils::shapeAsString(Wq).c_str(), ShapeUtils::shapeAsString(queries).c_str()); - - REQUIRE_TRUE(Wk->sizeAt(2) == keys->sizeAt(1), 0, - "multi_head_dot_product_attention: Key projection matrix Wk has incompatible size to keys matrix." - "Expected Wk[2] = keys[1] = %i, but got Wk = %s, keys = %s ", keys->sizeAt(1), - ShapeUtils::shapeAsString(Wk).c_str(), ShapeUtils::shapeAsString(keys).c_str()); - - REQUIRE_TRUE(Wv->sizeAt(2) == values->sizeAt(1), 0, - "multi_head_dot_product_attention: Value projection matrix Wv has incompatible size to values matrix." - "Expected Wv[2] = values[1] = %i, but got Wv = %s, values = %s ", values->sizeAt(1), - ShapeUtils::shapeAsString(Wv).c_str(), ShapeUtils::shapeAsString(values).c_str()); - - REQUIRE_TRUE(Wo->sizeAt(0) == (Wv->sizeAt(1) * Wv->sizeAt(0)), 0, - "multi_head_dot_product_attention: Output projection matrix Wo has incompatible size to attention result." - "Expected Wo[0] = Wv[0] * Wv[1] = %i, but got Wo = %s, Wv = %", (Wv->sizeAt(1) * Wv->sizeAt(0)), - ShapeUtils::shapeAsString(Wo).c_str(), ShapeUtils::shapeAsString(Wv).c_str()); - - - // Project queries, keys, values - auto projectedQueries = AttentionHelper::multiHeadProject(queries, Wq, block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength] - auto projectedKeys = AttentionHelper::multiHeadProject(keys, Wk, block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength] - auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength] - - // Apply Attention - // attnResults = [minibatch, numHeads, projectedSize, seqLenth - NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext()); - sd::ops::dot_product_attention attention; - attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults, weights ? OUTPUT_VARIABLE(1) : nullptr}, {}, {normalization, weights}, {}); - - // Project attention results - attnResults.permutei({0, 3, 1, 2}); - attnResults.reshapei(attnResults.ordering(), {miniBatchSize * queryCount, numHeads * projectedValuesSize}); - - sd::ops::matmul mmul; - NDArray projRes('c', {attnResults.sizeAt(0), Wo->sizeAt(1)}, values->dataType(), block.launchContext()); - mmul.execute({&attnResults, Wo},{&projRes}, {}, {}, {}); - projRes.reshapei(projRes.ordering(), {miniBatchSize, queryCount, outSize}); - projRes.permutei({0, 2, 1}); - - // FIXME: bad for performance - output->assign(projRes); - - return Status::OK(); - } - - - DECLARE_TYPES(multi_head_dot_product_attention) { - getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(multi_head_dot_product_attention) { - auto queryShape = inputShape->at(0); - auto keysShape = inputShape->at(1); - auto valuesShape = inputShape->at(2); - auto WkShape = inputShape->at(3); - auto WoShape = inputShape->at(6); - - auto batchSize = shape::sizeAt(queryShape, 0); - auto outSize = shape::sizeAt(WoShape, 1); - auto queryCount = shape::sizeAt(queryShape, 2); - auto numHeads = shape::sizeAt(WkShape, 0); - auto timeSteps = shape::sizeAt(keysShape, 2); - - auto weightsShape = ConstantShapeHelper::getInstance()->createShapeInfo(sd::ArrayOptions::dataType(valuesShape), 'c', {batchSize, numHeads, timeSteps, queryCount}); - auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(sd::ArrayOptions::dataType(valuesShape), 'c', {batchSize, outSize, queryCount}); - - if(INT_ARG(1)){ - return SHAPELIST(outputShape, weightsShape); - }else{ - return SHAPELIST(outputShape); - } - - } - - CUSTOM_OP_IMPL(multi_head_dot_product_attention_bp, 8, 7, false, 0, 1) { - auto queries = INPUT_VARIABLE(0); - auto keys = INPUT_VARIABLE(1); - auto values = INPUT_VARIABLE(2); - auto Wq = INPUT_VARIABLE(3); - auto Wk = INPUT_VARIABLE(4); - auto Wv = INPUT_VARIABLE(5); - auto Wo = INPUT_VARIABLE(6); - auto eps = INPUT_VARIABLE(7); - auto mask = block.width() > 8 ? INPUT_VARIABLE(8) : nullptr; - - auto dLdq = OUTPUT_VARIABLE(0); - auto dLdk = OUTPUT_VARIABLE(1); - auto dLdv = OUTPUT_VARIABLE(2); - auto dLdWq = OUTPUT_VARIABLE(3); - auto dLdWk = OUTPUT_VARIABLE(4); - auto dLdWv = OUTPUT_VARIABLE(5); - auto dLdWo = OUTPUT_VARIABLE(6); - - int normalization = INT_ARG(0); - - auto numHeads = Wk->sizeAt(0); - auto miniBatchSize = queries->sizeAt(0); - auto queryCount = queries->sizeAt(2); - auto outSize = Wo->sizeAt(1); - auto projectedValuesSize = Wv->sizeAt(1); - - - REQUIRE_TRUE(queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), 0, - "multi_head_dot_product_attention: Queries, Keys and Values must have same rank. " - "But got queries = %s, keys = %s, values = %s", ShapeUtils::shapeAsString(queries).c_str(), - ShapeUtils::shapeAsString(keys).c_str(), ShapeUtils::shapeAsString(values).c_str()); - - REQUIRE_TRUE(queries->rankOf() == 3, 0, - "multi_head_dot_product_attention: Queries, Keys and Values must be rank 3 arrays" - "But got rank = %i", queries->rankOf()); - - REQUIRE_TRUE(Wq->rankOf() == Wk->rankOf() && Wk->rankOf() == Wv->rankOf(), 0, - "multi_head_dot_product_attention: Input projections weights must have the same rank. " - "But got Wq = %s, Wk = %s, Wv = %s", ShapeUtils::shapeAsString(Wq).c_str(), - ShapeUtils::shapeAsString(Wk).c_str(), ShapeUtils::shapeAsString(Wv).c_str()); - - REQUIRE_TRUE(Wq->sizeAt(0) == Wk->sizeAt(0) && Wk->sizeAt(0) == Wv->sizeAt(0), 0, - "multi_head_dot_product_attention: Projections weights must have the same number of attention heads. " - "But got Wq = %s, Wk = %s, Wv = %s", ShapeUtils::shapeAsString(Wq).c_str(), - ShapeUtils::shapeAsString(Wk).c_str(), ShapeUtils::shapeAsString(Wv).c_str()); - - REQUIRE_TRUE(Wo->rankOf() == 2, 0, - "multi_head_dot_product_attention: Output projection weights must have rank 2. " - "But got Wo = %s", ShapeUtils::shapeAsString(Wo).c_str()); - - REQUIRE_TRUE(Wq->sizeAt(2) == queries->sizeAt(1), 0, - "multi_head_dot_product_attention: Query projection matrix Wq has incompatible size to queries matrix." - "Expected Wq[2] = queries[1] = %i, but got Wq = %s, queries = %s ", queries->sizeAt(1), - ShapeUtils::shapeAsString(Wq).c_str(), ShapeUtils::shapeAsString(queries).c_str()); - - REQUIRE_TRUE(Wk->sizeAt(2) == keys->sizeAt(1), 0, - "multi_head_dot_product_attention: Key projection matrix Wk has incompatible size to keys matrix." - "Expected Wk[2] = keys[1] = %i, but got Wk = %s, keys = %s ", keys->sizeAt(1), - ShapeUtils::shapeAsString(Wk).c_str(), ShapeUtils::shapeAsString(keys).c_str()); - - REQUIRE_TRUE(Wv->sizeAt(2) == values->sizeAt(1), 0, - "multi_head_dot_product_attention: Value projection matrix Wv has incompatible size to values matrix." - "Expected Wv[2] = values[1] = %i, but got Wv = %s, values = %s ", values->sizeAt(1), - ShapeUtils::shapeAsString(Wv).c_str(), ShapeUtils::shapeAsString(values).c_str()); - - REQUIRE_TRUE(Wo->sizeAt(0) == (Wv->sizeAt(1) * Wv->sizeAt(0)), 0, - "multi_head_dot_product_attention: Output projection matrix Wo has incompatible size to attention result." - "Expected Wo[0] = Wv[0] * Wv[1] = %i, but got Wo = %s, Wv = %", (Wv->sizeAt(1) * Wv->sizeAt(0)), - ShapeUtils::shapeAsString(Wo).c_str(), ShapeUtils::shapeAsString(Wv).c_str()); - - // Project queries, keys, values - auto projectedQueries = AttentionHelper::multiHeadProject(queries, Wq, block.launchContext()); - auto projectedKeys = AttentionHelper::multiHeadProject(keys, Wk, block.launchContext()); - auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext()); - - // Apply Attention - NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext()); - sd::ops::dot_product_attention attention; - attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults}, {}, {normalization, 0}, {}); - - // Project attention results - attnResults.permutei({0, 3, 1, 2}); - attnResults.reshapei(attnResults.ordering(), {miniBatchSize * queryCount, numHeads * projectedValuesSize}); - - // dLdWo - auto epsPerm = eps->permute({0, 2, 1}); - auto epsPostReshape = epsPerm.reshape(eps->ordering(), {miniBatchSize * queryCount, outSize}); - sd::ops::matmul_bp matmulBp; - NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext()); - matmulBp.execute({&attnResults, Wo, &epsPostReshape}, std::vector{&dLdPreWo, dLdWo}, {}, {}, {}); - - // dLdAttn - dLdPreWo.reshapei({miniBatchSize, queryCount, numHeads, projectedValues.sizeAt(2)}); - dLdPreWo.permutei({0, 2, 3, 1}); - - sd::ops::dot_product_attention_bp attentionBp; - NDArray dLdProjectedQueries(projectedQueries.shapeInfo(), false, block.launchContext()); - NDArray dLdProjectedKeys(projectedKeys.shapeInfo(), false, block.launchContext()); - NDArray dLdProjectedValues(projectedValues.shapeInfo(), false, block.launchContext()); - attentionBp.execute({&projectedQueries, &projectedKeys, &projectedValues, &dLdPreWo, mask},{&dLdProjectedQueries, &dLdProjectedKeys, &dLdProjectedValues}, {}, {normalization}, {}); - - AttentionHelper::multiHeadProjectBp(queries, Wq, &dLdProjectedQueries, dLdq, dLdWq, block.launchContext()); - AttentionHelper::multiHeadProjectBp(keys, Wk, &dLdProjectedKeys, dLdk, dLdWk, block.launchContext()); - AttentionHelper::multiHeadProjectBp(values, Wv, &dLdProjectedValues, dLdv, dLdWv, block.launchContext()); - - return Status::OK(); - } - - DECLARE_TYPES(multi_head_dot_product_attention_bp) { - getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(multi_head_dot_product_attention_bp) { - Nd4jLong *dLdq_shape; - COPY_SHAPE(inputShape->at(0), dLdq_shape); - Nd4jLong *dLdk_shape; - COPY_SHAPE(inputShape->at(1), dLdk_shape); - Nd4jLong *dLdv_shape; - COPY_SHAPE(inputShape->at(2), dLdv_shape); - Nd4jLong *dLdWq_shape; - COPY_SHAPE(inputShape->at(3), dLdWq_shape); - Nd4jLong *dLdWk_shape; - COPY_SHAPE(inputShape->at(4), dLdWk_shape); - Nd4jLong *dLdWv_shape; - COPY_SHAPE(inputShape->at(5), dLdWv_shape); - Nd4jLong *dLdWo_shape; - COPY_SHAPE(inputShape->at(6), dLdWo_shape); - - return SHAPELIST(CONSTANT(dLdq_shape), CONSTANT(dLdk_shape), CONSTANT(dLdv_shape), CONSTANT(dLdWq_shape), CONSTANT(dLdWk_shape), CONSTANT(dLdWv_shape), CONSTANT(dLdWo_shape)); - } +namespace ops { + +CUSTOM_OP_IMPL(multi_head_dot_product_attention, 7, -1, false, 0, 2) { + auto queries = INPUT_VARIABLE(0); //[batch, nIn, timeSteps] + auto keys = INPUT_VARIABLE(1); //[batch, nIn, timeSteps] + auto values = INPUT_VARIABLE(2); //[batch, nIn, timeSteps] + auto Wq = INPUT_VARIABLE(3); //[nHeads, headSize, nIn] + auto Wk = INPUT_VARIABLE(4); //[nHeads, headSize, nIn] + auto Wv = INPUT_VARIABLE(5); //[nHeads, headSize, nIn] + auto Wo = INPUT_VARIABLE(6); //[nHeads * headSize, nOut] + auto mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; + + auto output = OUTPUT_VARIABLE(0); + int normalization = INT_ARG(0); + int weights = INT_ARG(1); + + auto numHeads = Wk->sizeAt(0); + auto miniBatchSize = queries->sizeAt(0); + auto queryCount = queries->sizeAt(2); + auto projectedValuesSize = Wv->sizeAt(1); + auto outSize = Wo->sizeAt(1); + + REQUIRE_TRUE( + queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), + 0, + "multi_head_dot_product_attention: Queries, Keys and Values must have " + "same rank. " + "But got queries = %s, keys = %s, values = %s", + ShapeUtils::shapeAsString(queries).c_str(), + ShapeUtils::shapeAsString(keys).c_str(), + ShapeUtils::shapeAsString(values).c_str()); + + REQUIRE_TRUE(queries->rankOf() == 3, 0, + "multi_head_dot_product_attention: Queries, Keys and Values " + "must be rank 3 arrays" + "But got rank = %i", + queries->rankOf()); + + REQUIRE_TRUE(Wq->rankOf() == Wk->rankOf() && Wk->rankOf() == Wv->rankOf(), 0, + "multi_head_dot_product_attention: Input projections weights " + "must have the same rank. " + "But got Wq = %s, Wk = %s, Wv = %s", + ShapeUtils::shapeAsString(Wq).c_str(), + ShapeUtils::shapeAsString(Wk).c_str(), + ShapeUtils::shapeAsString(Wv).c_str()); + + REQUIRE_TRUE(Wq->sizeAt(0) == Wk->sizeAt(0) && Wk->sizeAt(0) == Wv->sizeAt(0), + 0, + "multi_head_dot_product_attention: Projections weights must " + "have the same number of attention heads. " + "But got Wq = %s, Wk = %s, Wv = %s", + ShapeUtils::shapeAsString(Wq).c_str(), + ShapeUtils::shapeAsString(Wk).c_str(), + ShapeUtils::shapeAsString(Wv).c_str()); + + REQUIRE_TRUE(Wo->rankOf() == 2, 0, + "multi_head_dot_product_attention: Output projection weights " + "must have rank 2. " + "But got Wo = %s", + ShapeUtils::shapeAsString(Wo).c_str()); + + REQUIRE_TRUE( + Wq->sizeAt(2) == queries->sizeAt(1), 0, + "multi_head_dot_product_attention: Query projection matrix Wq has " + "incompatible size to queries matrix." + "Expected Wq[2] = queries[1] = %i, but got Wq = %s, queries = %s ", + queries->sizeAt(1), ShapeUtils::shapeAsString(Wq).c_str(), + ShapeUtils::shapeAsString(queries).c_str()); + + REQUIRE_TRUE(Wk->sizeAt(2) == keys->sizeAt(1), 0, + "multi_head_dot_product_attention: Key projection matrix Wk has " + "incompatible size to keys matrix." + "Expected Wk[2] = keys[1] = %i, but got Wk = %s, keys = %s ", + keys->sizeAt(1), ShapeUtils::shapeAsString(Wk).c_str(), + ShapeUtils::shapeAsString(keys).c_str()); + + REQUIRE_TRUE(Wv->sizeAt(2) == values->sizeAt(1), 0, + "multi_head_dot_product_attention: Value projection matrix Wv " + "has incompatible size to values matrix." + "Expected Wv[2] = values[1] = %i, but got Wv = %s, values = %s ", + values->sizeAt(1), ShapeUtils::shapeAsString(Wv).c_str(), + ShapeUtils::shapeAsString(values).c_str()); + + REQUIRE_TRUE(Wo->sizeAt(0) == (Wv->sizeAt(1) * Wv->sizeAt(0)), 0, + "multi_head_dot_product_attention: Output projection matrix Wo " + "has incompatible size to attention result." + "Expected Wo[0] = Wv[0] * Wv[1] = %i, but got Wo = %s, Wv = %", + (Wv->sizeAt(1) * Wv->sizeAt(0)), + ShapeUtils::shapeAsString(Wo).c_str(), + ShapeUtils::shapeAsString(Wv).c_str()); + + // Project queries, keys, values + auto projectedQueries = AttentionHelper::multiHeadProject( + queries, Wq, + block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength] + auto projectedKeys = AttentionHelper::multiHeadProject( + keys, Wk, + block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength] + auto projectedValues = AttentionHelper::multiHeadProject( + values, Wv, + block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength] + + // Apply Attention + // attnResults = [minibatch, numHeads, projectedSize, seqLenth + NDArray attnResults('c', + {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), + projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, + projectedValues.dataType(), block.launchContext()); + sd::ops::dot_product_attention attention; + attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, + {&attnResults, weights ? OUTPUT_VARIABLE(1) : nullptr}, {}, + {normalization, weights}, {}); + + // Project attention results + attnResults.permutei({0, 3, 1, 2}); + attnResults.reshapei( + attnResults.ordering(), + {miniBatchSize * queryCount, numHeads * projectedValuesSize}); + + sd::ops::matmul mmul; + NDArray projRes('c', {attnResults.sizeAt(0), Wo->sizeAt(1)}, + values->dataType(), block.launchContext()); + mmul.execute({&attnResults, Wo}, {&projRes}, {}, {}, {}); + projRes.reshapei(projRes.ordering(), {miniBatchSize, queryCount, outSize}); + projRes.permutei({0, 2, 1}); + + // FIXME: bad for performance + output->assign(projRes); + + return Status::OK(); +} +DECLARE_TYPES(multi_head_dot_product_attention) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); + getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); } + +DECLARE_SHAPE_FN(multi_head_dot_product_attention) { + auto queryShape = inputShape->at(0); + auto keysShape = inputShape->at(1); + auto valuesShape = inputShape->at(2); + auto WkShape = inputShape->at(3); + auto WoShape = inputShape->at(6); + + auto batchSize = shape::sizeAt(queryShape, 0); + auto outSize = shape::sizeAt(WoShape, 1); + auto queryCount = shape::sizeAt(queryShape, 2); + auto numHeads = shape::sizeAt(WkShape, 0); + auto timeSteps = shape::sizeAt(keysShape, 2); + + auto weightsShape = ConstantShapeHelper::getInstance()->createShapeInfo( + sd::ArrayOptions::dataType(valuesShape), 'c', + {batchSize, numHeads, timeSteps, queryCount}); + auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo( + sd::ArrayOptions::dataType(valuesShape), 'c', + {batchSize, outSize, queryCount}); + + if (INT_ARG(1)) { + return SHAPELIST(outputShape, weightsShape); + } else { + return SHAPELIST(outputShape); + } +} + +CUSTOM_OP_IMPL(multi_head_dot_product_attention_bp, 8, 7, false, 0, 1) { + auto queries = INPUT_VARIABLE(0); + auto keys = INPUT_VARIABLE(1); + auto values = INPUT_VARIABLE(2); + auto Wq = INPUT_VARIABLE(3); + auto Wk = INPUT_VARIABLE(4); + auto Wv = INPUT_VARIABLE(5); + auto Wo = INPUT_VARIABLE(6); + auto eps = INPUT_VARIABLE(7); + auto mask = block.width() > 8 ? INPUT_VARIABLE(8) : nullptr; + + auto dLdq = OUTPUT_VARIABLE(0); + auto dLdk = OUTPUT_VARIABLE(1); + auto dLdv = OUTPUT_VARIABLE(2); + auto dLdWq = OUTPUT_VARIABLE(3); + auto dLdWk = OUTPUT_VARIABLE(4); + auto dLdWv = OUTPUT_VARIABLE(5); + auto dLdWo = OUTPUT_VARIABLE(6); + + int normalization = INT_ARG(0); + + auto numHeads = Wk->sizeAt(0); + auto miniBatchSize = queries->sizeAt(0); + auto queryCount = queries->sizeAt(2); + auto outSize = Wo->sizeAt(1); + auto projectedValuesSize = Wv->sizeAt(1); + + REQUIRE_TRUE( + queries->rankOf() == keys->rankOf() && keys->rankOf() == values->rankOf(), + 0, + "multi_head_dot_product_attention: Queries, Keys and Values must have " + "same rank. " + "But got queries = %s, keys = %s, values = %s", + ShapeUtils::shapeAsString(queries).c_str(), + ShapeUtils::shapeAsString(keys).c_str(), + ShapeUtils::shapeAsString(values).c_str()); + + REQUIRE_TRUE(queries->rankOf() == 3, 0, + "multi_head_dot_product_attention: Queries, Keys and Values " + "must be rank 3 arrays" + "But got rank = %i", + queries->rankOf()); + + REQUIRE_TRUE(Wq->rankOf() == Wk->rankOf() && Wk->rankOf() == Wv->rankOf(), 0, + "multi_head_dot_product_attention: Input projections weights " + "must have the same rank. " + "But got Wq = %s, Wk = %s, Wv = %s", + ShapeUtils::shapeAsString(Wq).c_str(), + ShapeUtils::shapeAsString(Wk).c_str(), + ShapeUtils::shapeAsString(Wv).c_str()); + + REQUIRE_TRUE(Wq->sizeAt(0) == Wk->sizeAt(0) && Wk->sizeAt(0) == Wv->sizeAt(0), + 0, + "multi_head_dot_product_attention: Projections weights must " + "have the same number of attention heads. " + "But got Wq = %s, Wk = %s, Wv = %s", + ShapeUtils::shapeAsString(Wq).c_str(), + ShapeUtils::shapeAsString(Wk).c_str(), + ShapeUtils::shapeAsString(Wv).c_str()); + + REQUIRE_TRUE(Wo->rankOf() == 2, 0, + "multi_head_dot_product_attention: Output projection weights " + "must have rank 2. " + "But got Wo = %s", + ShapeUtils::shapeAsString(Wo).c_str()); + + REQUIRE_TRUE( + Wq->sizeAt(2) == queries->sizeAt(1), 0, + "multi_head_dot_product_attention: Query projection matrix Wq has " + "incompatible size to queries matrix." + "Expected Wq[2] = queries[1] = %i, but got Wq = %s, queries = %s ", + queries->sizeAt(1), ShapeUtils::shapeAsString(Wq).c_str(), + ShapeUtils::shapeAsString(queries).c_str()); + + REQUIRE_TRUE(Wk->sizeAt(2) == keys->sizeAt(1), 0, + "multi_head_dot_product_attention: Key projection matrix Wk has " + "incompatible size to keys matrix." + "Expected Wk[2] = keys[1] = %i, but got Wk = %s, keys = %s ", + keys->sizeAt(1), ShapeUtils::shapeAsString(Wk).c_str(), + ShapeUtils::shapeAsString(keys).c_str()); + + REQUIRE_TRUE(Wv->sizeAt(2) == values->sizeAt(1), 0, + "multi_head_dot_product_attention: Value projection matrix Wv " + "has incompatible size to values matrix." + "Expected Wv[2] = values[1] = %i, but got Wv = %s, values = %s ", + values->sizeAt(1), ShapeUtils::shapeAsString(Wv).c_str(), + ShapeUtils::shapeAsString(values).c_str()); + + REQUIRE_TRUE(Wo->sizeAt(0) == (Wv->sizeAt(1) * Wv->sizeAt(0)), 0, + "multi_head_dot_product_attention: Output projection matrix Wo " + "has incompatible size to attention result." + "Expected Wo[0] = Wv[0] * Wv[1] = %i, but got Wo = %s, Wv = %", + (Wv->sizeAt(1) * Wv->sizeAt(0)), + ShapeUtils::shapeAsString(Wo).c_str(), + ShapeUtils::shapeAsString(Wv).c_str()); + + // Project queries, keys, values + auto projectedQueries = + AttentionHelper::multiHeadProject(queries, Wq, block.launchContext()); + auto projectedKeys = + AttentionHelper::multiHeadProject(keys, Wk, block.launchContext()); + auto projectedValues = + AttentionHelper::multiHeadProject(values, Wv, block.launchContext()); + + // Apply Attention + NDArray attnResults('c', + {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), + projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, + projectedValues.dataType(), block.launchContext()); + sd::ops::dot_product_attention attention; + attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, + {&attnResults}, {}, {normalization, 0}, {}); + + // Project attention results + attnResults.permutei({0, 3, 1, 2}); + attnResults.reshapei( + attnResults.ordering(), + {miniBatchSize * queryCount, numHeads * projectedValuesSize}); + + // dLdWo + auto epsPerm = eps->permute({0, 2, 1}); + auto epsPostReshape = + epsPerm.reshape(eps->ordering(), {miniBatchSize * queryCount, outSize}); + sd::ops::matmul_bp matmulBp; + NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext()); + matmulBp.execute({&attnResults, Wo, &epsPostReshape}, + std::vector{&dLdPreWo, dLdWo}, {}, {}, {}); + + // dLdAttn + dLdPreWo.reshapei( + {miniBatchSize, queryCount, numHeads, projectedValues.sizeAt(2)}); + dLdPreWo.permutei({0, 2, 3, 1}); + + sd::ops::dot_product_attention_bp attentionBp; + NDArray dLdProjectedQueries(projectedQueries.shapeInfo(), false, + block.launchContext()); + NDArray dLdProjectedKeys(projectedKeys.shapeInfo(), false, + block.launchContext()); + NDArray dLdProjectedValues(projectedValues.shapeInfo(), false, + block.launchContext()); + attentionBp.execute( + {&projectedQueries, &projectedKeys, &projectedValues, &dLdPreWo, mask}, + {&dLdProjectedQueries, &dLdProjectedKeys, &dLdProjectedValues}, {}, + {normalization}, {}); + + AttentionHelper::multiHeadProjectBp(queries, Wq, &dLdProjectedQueries, dLdq, + dLdWq, block.launchContext()); + AttentionHelper::multiHeadProjectBp(keys, Wk, &dLdProjectedKeys, dLdk, dLdWk, + block.launchContext()); + AttentionHelper::multiHeadProjectBp(values, Wv, &dLdProjectedValues, dLdv, + dLdWv, block.launchContext()); + + return Status::OK(); } +DECLARE_TYPES(multi_head_dot_product_attention_bp) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}); + getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); +} + +DECLARE_SHAPE_FN(multi_head_dot_product_attention_bp) { + Nd4jLong *dLdq_shape; + COPY_SHAPE(inputShape->at(0), dLdq_shape); + Nd4jLong *dLdk_shape; + COPY_SHAPE(inputShape->at(1), dLdk_shape); + Nd4jLong *dLdv_shape; + COPY_SHAPE(inputShape->at(2), dLdv_shape); + Nd4jLong *dLdWq_shape; + COPY_SHAPE(inputShape->at(3), dLdWq_shape); + Nd4jLong *dLdWk_shape; + COPY_SHAPE(inputShape->at(4), dLdWk_shape); + Nd4jLong *dLdWv_shape; + COPY_SHAPE(inputShape->at(5), dLdWv_shape); + Nd4jLong *dLdWo_shape; + COPY_SHAPE(inputShape->at(6), dLdWo_shape); + + return SHAPELIST(CONSTANT(dLdq_shape), CONSTANT(dLdk_shape), + CONSTANT(dLdv_shape), CONSTANT(dLdWq_shape), + CONSTANT(dLdWk_shape), CONSTANT(dLdWv_shape), + CONSTANT(dLdWo_shape)); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp index 359c0ac3ca43..37f37411c508 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp @@ -26,195 +26,256 @@ #include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_NULLIFIED(0); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - const auto kH = INT_ARG(0); - const auto kW = INT_ARG(1); - const auto sH = INT_ARG(2); - const auto sW = INT_ARG(3); - auto pH = INT_ARG(4); - auto pW = INT_ARG(5); - const auto dH = INT_ARG(6); - const auto dW = INT_ARG(7); - const auto isSameMode = static_cast(INT_ARG(8)); - const auto extraParam0 = INT_ARG(9); - const int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int oH = 0; - int oW = 0; - - const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); - const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); - - if(!isNCHW) { - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - if (isSameMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0); - - if(!isNCHW) { - delete input; - delete output; - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto isSameMode = static_cast(INT_ARG(8)); + const auto extraParam0 = INT_ARG(9); + const int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "AVGPOOL2D op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, + dW); + + int oH = 0; + int oW = 0; + + const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); + const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); + + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + output = new NDArray( + output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + if (isSameMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; + ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, + dH, dW, PoolingType::AVG_POOL, extraParam0); + + if (!isNCHW) { + delete input; + delete output; + } + + return Status::OK(); } DECLARE_SYN(AvgPool2D, avgpool2d); DECLARE_SYN(AvgPool, avgpool2d); DECLARE_SYN(avgpool, avgpool2d); - DECLARE_TYPES(avgpool2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(avgpool2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(avgpool2d) { - - auto inShape = inputShape->at(0); - auto shapeOf = shape::shapeOf(inShape); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - auto argI = block.getIArguments(); - const int kH = INT_ARG(0); - const int kW = INT_ARG(1); - const int sH = INT_ARG(2); - const int sW = INT_ARG(3); - const int pH = INT_ARG(4); - const int pW = INT_ARG(5); - const int dH = INT_ARG(6); - const int dW = INT_ARG(7); - const int isSameMode = INT_ARG(8); - - const int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - const int bS = shapeOf[0]; - const int iD = isNCHW ? shapeOf[1] : shapeOf[3]; - const int iH = isNCHW ? shapeOf[2] : shapeOf[1]; - const int iW = isNCHW ? shapeOf[3] : shapeOf[2]; - - const char order = shape::order(inShape); // output order must be equal to input order - - // calculate output Height/Width - int oH, oW; - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - // allocate memory for new shape - Nd4jLong newShape[4]; - if (isNCHW) { - newShape[0] = bS; - newShape[1] = iD; - newShape[2] = oH; - newShape[3] = oW; - } else { - newShape[0] = bS; - newShape[1] = oH; - newShape[2] = oW; - newShape[3] = iD; - } - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), newShape, 4))); + auto inShape = inputShape->at(0); + auto shapeOf = shape::shapeOf(inShape); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + auto argI = block.getIArguments(); + const int kH = INT_ARG(0); + const int kW = INT_ARG(1); + const int sH = INT_ARG(2); + const int sW = INT_ARG(3); + const int pH = INT_ARG(4); + const int pW = INT_ARG(5); + const int dH = INT_ARG(6); + const int dW = INT_ARG(7); + const int isSameMode = INT_ARG(8); + + const int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, + dW); + + const int bS = shapeOf[0]; + const int iD = isNCHW ? shapeOf[1] : shapeOf[3]; + const int iH = isNCHW ? shapeOf[2] : shapeOf[1]; + const int iW = isNCHW ? shapeOf[3] : shapeOf[2]; + + const char order = + shape::order(inShape); // output order must be equal to input order + + // calculate output Height/Width + int oH, oW; + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + // allocate memory for new shape + Nd4jLong newShape[4]; + if (isNCHW) { + newShape[0] = bS; + newShape[1] = iD; + newShape[2] = oH; + newShape[3] = oW; + } else { + newShape[0] = bS; + newShape[1] = oH; + newShape[2] = oW; + newShape[3] = iD; + } + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), + newShape, 4))); } - DECLARE_TYPES(avgpool2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(avgpool2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int extraParam0 = INT_ARG(9); - int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(!isNCHW) { - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - // NDArray columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, input->getWorkspace()); - // NDArray* columns = columnsWrongShape.permute({0, 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] -> [bS, iC, kH, kW, oH, oW] - // NDArray* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1}); - // NDArray* columns2d = columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW}); - - // columns2d->addiColumnVector(gradOVector); - - // columns->template applyTransform>(gradI, std::vector({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data()); - - // *gradI /= kH*kW; - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 1, extraParam0); - - if(!isNCHW) { - delete input; - delete gradI; - delete gradO; - } - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int extraParam0 = INT_ARG(9); + int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "AVGPOOL2D_BP op: wrong shape of output's gradients array (next " + "epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "AVGPOOL2D_BP op: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray( + gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradO = new NDArray( + gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + // NDArray columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, + // input->getWorkspace()); NDArray* columns = columnsWrongShape.permute({0, + // 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] + // -> [bS, iC, kH, kW, oH, oW] NDArray* gradOVector = gradO->reshape('c', + // {(int) gradO->lengthOf(), 1}); NDArray* columns2d = + // columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW}); + + // columns2d->addiColumnVector(gradOVector); + + // columns->template applyTransform>(gradI, + // std::vector({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, + // (T)dW}).data()); + + // *gradI /= kH*kW; + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; + ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, + pH, pW, dH, dW, 1, extraParam0); + + if (!isNCHW) { + delete input; + delete gradI; + delete gradO; + } + + return Status::OK(); } DECLARE_SHAPE_FN(avgpool2d_bp) { - - REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, "AVGPOOL2D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]); - REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, "AVGPOOL2D_BP op: output's gradient array (next epsilon) must be 4D, but got %i instead!", inputShape->at(1)[0]); - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); + REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, + "AVGPOOL2D_BP op: input array must be 4D, but got %i instead!", + inputShape->at(0)[0]); + REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, + "AVGPOOL2D_BP op: output's gradient array (next epsilon) must " + "be 4D, but got %i instead!", + inputShape->at(1)[0]); + + return SHAPELIST( + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); } - -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp index 32efcfb826e6..29e253b9098c 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp @@ -25,189 +25,253 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_NULLIFIED(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - int extraParam0 = INT_ARG(13); - int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); - - if(!isNCDHW) { - input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - //T extraParams[] = {}; - ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0); - - if(!isNCDHW) { - delete input; - delete output; - } - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_NULLIFIED( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + int extraParam0 = INT_ARG(13); + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "AVGPOOL3DNEW OP: rank of input array must be equal to 5, but " + "got %i instead !", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "AVGPOOL3DNEW OP: dilation must not be zero, but got instead " + "{%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedOutputShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, + "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedOutputShape).c_str(), + ShapeUtils::shapeAsString(output).c_str()); + + if (!isNCDHW) { + input = new NDArray(input->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + output = new NDArray(output->permute( + {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + // T extraParams[] = {}; + ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, + pD, pH, pW, dD, dH, dW, 1, extraParam0); + + if (!isNCDHW) { + delete input; + delete output; + } + + return Status::OK(); } - DECLARE_TYPES(avgpool3dnew) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(avgpool3dnew) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(avgpool3dnew) { - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - auto inputShapeInfo = inputShape->at(0); - - int idxID, idxIC; - if(isNCDHW) { idxID = 2; idxIC = 1;} - else { idxID = 1; idxIC = 4;} - - int bS = inputShapeInfo[1]; // batch size - int iC = inputShapeInfo[idxIC+1]; // input channels - int iD = inputShapeInfo[idxID+1]; // input depth - int iH = inputShapeInfo[idxID+2]; // input height - int iW = inputShapeInfo[idxID+3]; // input width - - int oD, oH, oW; // output depth, height, width - ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); - - Nd4jLong outputShape[5]; - - outputShape[0] = bS; - - if (isNCDHW) { - outputShape[1] = iC; - outputShape[2] = oD; - outputShape[3] = oH; - outputShape[4] = oW; - } else { - outputShape[1] = oD; - outputShape[2] = oH; - outputShape[3] = oW; - outputShape[4] = iC; - } - // TF DOC: A Tensor. Has the same type as input. - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inputShapeInfo), shape::order(inputShapeInfo), outputShape, 5))); + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "AVGPOOL3DNEW op: dilation must not be zero, but got instead " + "{%i, %i, %i}", + dD, dH, dW); + + auto inputShapeInfo = inputShape->at(0); + + int idxID, idxIC; + if (isNCDHW) { + idxID = 2; + idxIC = 1; + } else { + idxID = 1; + idxIC = 4; + } + + int bS = inputShapeInfo[1]; // batch size + int iC = inputShapeInfo[idxIC + 1]; // input channels + int iD = inputShapeInfo[idxID + 1]; // input depth + int iH = inputShapeInfo[idxID + 2]; // input height + int iW = inputShapeInfo[idxID + 3]; // input width + + int oD, oH, oW; // output depth, height, width + ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, + pH, pW, dD, dH, dW, iD, iH, iW, + isSameMode); + + Nd4jLong outputShape[5]; + + outputShape[0] = bS; + + if (isNCDHW) { + outputShape[1] = iC; + outputShape[2] = oD; + outputShape[3] = oH; + outputShape[4] = oW; + } else { + outputShape[1] = oD; + outputShape[2] = oH; + outputShape[3] = oW; + outputShape[4] = iC; + } + // TF DOC: A Tensor. Has the same type as input. + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inputShapeInfo), + shape::order(inputShapeInfo), outputShape, 5))); } - DECLARE_TYPES(avgpool3dnew_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(avgpool3dnew_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const int kD = INT_ARG(0); // filter(kernel) depth - const int kH = INT_ARG(1); // filter(kernel) height - const int kW = INT_ARG(2); // filter(kernel) width - const int sD = INT_ARG(3); // strides depth - const int sH = INT_ARG(4); // strides height - const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - const int dD = INT_ARG(9); // dilations depth - const int dH = INT_ARG(10); // dilations height - const int dW = INT_ARG(11); // dilations width - const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging - const int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(!isNCDHW) { - input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0); - - if(!isNCDHW) { - delete input; - delete gradI; - delete gradO; - } - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + const int extraParam0 = + INT_ARG(13); // define what divisor to use while averaging + const int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE( + input->rankOf() == 5, 0, + "AVGPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "AVGPOOL3DNEW_BP op: dilation must not be zero, but got instead " + "{%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iD, iH, iW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "AVGPOOL3DNEW_BP op: wrong shape of output's gradients array " + "(next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "AVGPOOL3DNEW_BP op: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (!isNCDHW) { + input = new NDArray(input->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + gradI = new NDArray(gradI->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + gradO = new NDArray(gradO->permute( + {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; + ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0); + + if (!isNCDHW) { + delete input; + delete gradI; + delete gradO; + } + + return Status::OK(); } - DECLARE_SHAPE_FN(avgpool3dnew_bp) { - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); + return SHAPELIST( + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); } - - -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp index bddf3ea550c6..0da8cf00a809 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp @@ -26,200 +26,261 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// // maxpool2d corresponds to poolingMode=0 CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) { - - auto input = INPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D OP: input array should have rank of 4, but got %i instead", input->rankOf()); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - auto output = OUTPUT_NULLIFIED(0); - - const int kH = INT_ARG(0); - const int kW = INT_ARG(1); - const int sH = INT_ARG(2); - const int sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); - const int dH = INT_ARG(6); - const int dW = INT_ARG(7); - const bool isSameMode = INT_ARG(8); - - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int oH = 0; - int oW = 0; - - int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW - - const int iH = isNCHW ? input->sizeAt(2) : input->sizeAt(1); - const int iW = isNCHW ? input->sizeAt(3) : input->sizeAt(2); - - if(!isNCHW) { - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - if (isSameMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; poolingMode; 9 - divisor; - ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1); - - if(!isNCHW) { - delete input; - delete output; - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "MAXPOOL2D OP: input array should have rank of 4, but got %i instead", + input->rankOf()); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + auto output = OUTPUT_NULLIFIED(0); + + const int kH = INT_ARG(0); + const int kW = INT_ARG(1); + const int sH = INT_ARG(2); + const int sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + const int dH = INT_ARG(6); + const int dW = INT_ARG(7); + const bool isSameMode = INT_ARG(8); + + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "MAXPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, + dW); + + int oH = 0; + int oW = 0; + + int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW + + const int iH = isNCHW ? input->sizeAt(2) : input->sizeAt(1); + const int iW = isNCHW ? input->sizeAt(3) : input->sizeAt(2); + + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + output = new NDArray( + output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + if (isSameMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; poolingMode; 9 - divisor; + ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, + dH, dW, PoolingType::MAX_POOL, 1); + + if (!isNCHW) { + delete input; + delete output; + } + + return Status::OK(); } DECLARE_SYN(MaxPool2D, maxpool2d); DECLARE_SYN(MaxPool, maxpool2d); DECLARE_SYN(maxpool, maxpool2d); - DECLARE_TYPES(maxpool2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - +DECLARE_TYPES(maxpool2d) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} DECLARE_SHAPE_FN(maxpool2d) { - - //NDArray *x = block.getVariables().at(0)->getNDArray(); - auto inShape = inputShape->at(0); - auto shapeOf = shape::shapeOf(inShape); - // 0 - number of dimensions; 1,2 - kernel Height/Width; 3,4 - stride Height/Width; 5,6 - pad Height/Width; 7,8 - dilation Height/Width; 9,10 - input Height/Width; 11 - batch size; 12 - input depth; 13 - same mode; - int kH = INT_ARG(0); - int kW = INT_ARG(1); - int sH = INT_ARG(2); - int sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); - int dH = INT_ARG(6); - int dW = INT_ARG(7); - int isSameMode = INT_ARG(8); - int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW - - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS = shapeOf[0]; - int iC = isNCHW ? shapeOf[1] : shapeOf[3]; - int iH = isNCHW ? shapeOf[2] : shapeOf[1]; - int iW = isNCHW ? shapeOf[3] : shapeOf[2]; - - char order = shape::order(inShape); // output order must be equal to input order - - // calculate output Height/Width - int oH, oW; - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - // allocate memory for new shape - Nd4jLong newShape[4]; - - newShape[0] = bS; - if (isNCHW) { - newShape[1] = iC; - newShape[2] = oH; - newShape[3] = oW; - } else { - newShape[1] = oH; - newShape[2] = oW; - newShape[3] = iC; - } - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), order, newShape, 4))); + // NDArray *x = block.getVariables().at(0)->getNDArray(); + auto inShape = inputShape->at(0); + auto shapeOf = shape::shapeOf(inShape); + // 0 - number of dimensions; 1,2 - kernel Height/Width; 3,4 - stride + // Height/Width; 5,6 - pad Height/Width; 7,8 - dilation Height/Width; 9,10 - + // input Height/Width; 11 - batch size; 12 - input depth; 13 - same mode; + int kH = INT_ARG(0); + int kW = INT_ARG(1); + int sH = INT_ARG(2); + int sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + int dH = INT_ARG(6); + int dW = INT_ARG(7); + int isSameMode = INT_ARG(8); + int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW + + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "MAXPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, + dW); + + int bS = shapeOf[0]; + int iC = isNCHW ? shapeOf[1] : shapeOf[3]; + int iH = isNCHW ? shapeOf[2] : shapeOf[1]; + int iW = isNCHW ? shapeOf[3] : shapeOf[2]; + + char order = + shape::order(inShape); // output order must be equal to input order + + // calculate output Height/Width + int oH, oW; + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + + // allocate memory for new shape + Nd4jLong newShape[4]; + + newShape[0] = bS; + if (isNCHW) { + newShape[1] = iC; + newShape[2] = oH; + newShape[3] = oW; + } else { + newShape[1] = oH; + newShape[2] = oW; + newShape[3] = iC; + } + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShape), order, newShape, 4))); } - DECLARE_TYPES(maxpool2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(maxpool2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW - - REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(!isNCHW) { - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - // NDArray columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, input->getWorkspace()); - // NDArray* columns = columnsWrongShape.permute({0, 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] -> [bS, iC, kH, kW, oH, oW] - - // input->template applyTransform>(columns, std::vector({(T)kH, (T)kW, (T)sH, (T)sW, (T)pH, (T)pW, (T)dH, (T)dW, (T)0.f, (T)0.f}).data()); - - // NDArray* columns2d = columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW}); - // NDArray* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1}); - - // columns2d->template applyTransform>(std::vector({(T)1., (T)1.}).data()); - // columns2d->muliColumnVector(gradOVector); - - // columns->template applyTransform>(gradI, std::vector({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data()); - - ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 0., 1.); - - if(!isNCHW) { - delete input; - delete gradI; - delete gradO; - } - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "MAXPOOL2D_BP op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "MAXPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "MAXPOOL2D_BP op: wrong shape of output's gradients array (next " + "epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "MAXPOOL2D_BP op: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray( + gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradO = new NDArray( + gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + // NDArray columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, + // input->getWorkspace()); NDArray* columns = columnsWrongShape.permute({0, + // 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] + // -> [bS, iC, kH, kW, oH, oW] + + // input->template applyTransform>(columns, + // std::vector({(T)kH, (T)kW, (T)sH, (T)sW, (T)pH, (T)pW, (T)dH, (T)dW, + // (T)0.f, (T)0.f}).data()); + + // NDArray* columns2d = columnsWrongShape.reshape('c', {bS*iC*oH*oW, + // kH*kW}); NDArray* gradOVector = gradO->reshape('c', {(int) + // gradO->lengthOf(), 1}); + + // columns2d->template + // applyTransform>(std::vector({(T)1., (T)1.}).data()); + // columns2d->muliColumnVector(gradOVector); + + // columns->template applyTransform>(gradI, + // std::vector({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, + // (T)dW}).data()); + + ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, + pH, pW, dH, dW, 0., 1.); + + if (!isNCHW) { + delete input; + delete gradI; + delete gradO; + } + + return Status::OK(); } DECLARE_SYN(MaxPool2D_bp, maxpool2d_bp); DECLARE_SYN(MaxPool_bp, maxpool2d_bp); DECLARE_SHAPE_FN(maxpool2d_bp) { - - REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, "MAXPOOL2D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]); - REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, "MAXPOOL2D_BP op: output's gradient array (next epsilon) must be 4D, but got %i instead!", inputShape->at(1)[0]); - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); + REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, + "MAXPOOL2D_BP op: input array must be 4D, but got %i instead!", + inputShape->at(0)[0]); + REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, + "MAXPOOL2D_BP op: output's gradient array (next epsilon) must " + "be 4D, but got %i instead!", + inputShape->at(1)[0]); + + return SHAPELIST( + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp index 71ea9a5a5bcf..78d883f27c2b 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp @@ -25,204 +25,278 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_NULLIFIED(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - - REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); - // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW); - // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW); - - if(!isNCDHW) { - input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1); - - if(!isNCDHW) { - delete input; - delete output; - } - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_NULLIFIED( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); // + // unnecessary for max case, required only for avg and pnorm cases + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but " + "got %i instead !", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW op: dilation must not be zero, but got instead " + "{%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedOutputShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, + "MAXPOOL3D op: wrong shape of output array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(expectedOutputShape).c_str(), + ShapeUtils::shapeAsString(output).c_str()); + // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the + // input depth/height/width must be greater or equal to kernel(filter) + // depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly + // !", iD,iH,iW, kD,kH,kW); REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= + // pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half + // of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] + // correspondingly !", pD,pH,pW, kD,kH,kW); + + if (!isNCDHW) { + input = new NDArray(input->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + output = new NDArray(output->permute( + {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, + pD, pH, pW, dD, dH, dW, 0, 1); + + if (!isNCDHW) { + delete input; + delete output; + } + + return Status::OK(); } - DECLARE_TYPES(maxpool3dnew) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } +DECLARE_TYPES(maxpool3dnew) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} DECLARE_SHAPE_FN(maxpool3dnew) { - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); - int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - auto inputShapeInfo = inputShape->at(0); - - int idxID, idxIC; - if(isNCDHW) { idxID = 2; idxIC = 1;} - else { idxID = 1; idxIC = 4;} - - int bS = inputShapeInfo[1]; // batch size - int iC = inputShapeInfo[idxIC+1]; // input channels - int iD = inputShapeInfo[idxID+1]; // input depth - int iH = inputShapeInfo[idxID+2]; // input height - int iW = inputShapeInfo[idxID+3]; // input width - - int oD, oH, oW; // output depth, height, width - ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); - - Nd4jLong outputShape[5]; - - - outputShape[0] = bS; - if (isNCDHW) { - outputShape[1] = iC; - outputShape[2] = oD; - outputShape[3] = oH; - outputShape[4] = oW; - } else { - outputShape[1] = oD; - outputShape[2] = oH; - outputShape[3] = oW; - outputShape[4] = iC; - } - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inputShapeInfo), shape::order(inputShapeInfo), outputShape, 5))); + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW op: dilation must not be zero, but got instead " + "{%i, %i, %i}", + dD, dH, dW); + + auto inputShapeInfo = inputShape->at(0); + + int idxID, idxIC; + if (isNCDHW) { + idxID = 2; + idxIC = 1; + } else { + idxID = 1; + idxIC = 4; + } + + int bS = inputShapeInfo[1]; // batch size + int iC = inputShapeInfo[idxIC + 1]; // input channels + int iD = inputShapeInfo[idxID + 1]; // input depth + int iH = inputShapeInfo[idxID + 2]; // input height + int iW = inputShapeInfo[idxID + 3]; // input width + + int oD, oH, oW; // output depth, height, width + ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, + pH, pW, dD, dH, dW, iD, iH, iW, + isSameMode); + + Nd4jLong outputShape[5]; + + outputShape[0] = bS; + if (isNCDHW) { + outputShape[1] = iC; + outputShape[2] = oD; + outputShape[3] = oH; + outputShape[4] = oW; + } else { + outputShape[1] = oD; + outputShape[2] = oH; + outputShape[3] = oW; + outputShape[4] = iC; + } + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inputShapeInfo), + shape::order(inputShapeInfo), outputShape, 5))); } - DECLARE_TYPES(maxpool3dnew_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(maxpool3dnew_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const int kD = INT_ARG(0); // filter(kernel) depth - const int kH = INT_ARG(1); // filter(kernel) height - const int kW = INT_ARG(2); // filter(kernel) width - const int sD = INT_ARG(3); // strides depth - const int sH = INT_ARG(4); // strides height - const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - const int dD = INT_ARG(9); // dilations depth - const int dH = INT_ARG(10); // dilations height - const int dW = INT_ARG(11); // dilations width - const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - - REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(!isNCDHW) { - input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - // NDArray columnsWrongShape(input->ordering(), {bS, iC, oD, oH, oW, kD, kH, kW}, input->getWorkspace()); - // NDArray* columns = columnsWrongShape.permute({0, 1, 5, 6, 7, 2, 3, 4}); // [bS, iC, oD, oH, oW, kD, kH, kW] -> [bS, iC, kD, kH, kW, oD, oH, oW] - - // ConvolutionUtils::vol2col(*input, *columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] - - // NDArray* columns2d = columnsWrongShape.reshape('c', {bS*iC*oD*oH*oW, kD*kH*kW}); - // NDArray* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1}); - // T extraParams[] = {(T)1., (T)1.}; - // columns2d->template applyTransform>(extraParams); - // columns2d->muliColumnVector(gradOVector); - - // ConvolutionUtils::col2vol(*columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - unnecessary; - ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1); - - if(!isNCDHW) { - delete input; - delete gradI; - delete gradO; - } - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); // + // unnecessary for max case, required only for avg and pnorm cases + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + + REQUIRE_TRUE( + input->rankOf() == 5, 0, + "MAXPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW_BP op: dilation must not be zero, but got instead " + "{%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iD, iH, iW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "MAXPOOL3DNEW_BP op: wrong shape of output's gradients array " + "(next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "MAXPOOL3DNEW_BP op: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (!isNCDHW) { + input = new NDArray(input->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + gradI = new NDArray(gradI->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + gradO = new NDArray(gradO->permute( + {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + // NDArray columnsWrongShape(input->ordering(), {bS, iC, oD, oH, oW, kD, + // kH, kW}, input->getWorkspace()); NDArray* columns = + // columnsWrongShape.permute({0, 1, 5, 6, 7, 2, 3, 4}); // [bS, iC, oD, oH, + // oW, kD, kH, kW] -> [bS, iC, kD, kH, kW, oD, oH, oW] + + // ConvolutionUtils::vol2col(*input, *columns, sD, sH, sW, pD, pH, pW, dD, + // dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, + // kD, kH, kW, oD, oH, oW] + + // NDArray* columns2d = columnsWrongShape.reshape('c', {bS*iC*oD*oH*oW, + // kD*kH*kW}); NDArray* gradOVector = gradO->reshape('c', {(int) + // gradO->lengthOf(), 1}); T extraParams[] = {(T)1., (T)1.}; + // columns2d->template applyTransform>(extraParams); + // columns2d->muliColumnVector(gradOVector); + + // ConvolutionUtils::col2vol(*columns, *gradI, sD, sH, sW, pD, pH, pW, dD, + // dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is + // de-convoluted to [bS, iC, iD, iH, iW] + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - + // unnecessary; + ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, 0, 1); + + if (!isNCDHW) { + delete input; + delete gradI; + delete gradO; + } + + return Status::OK(); } - DECLARE_SHAPE_FN(maxpool3dnew_bp) { - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); + return SHAPELIST( + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp index 05aab696290b..cd2339c29b60 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp @@ -26,40 +26,41 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(max_pool_with_argmax, 1, 2, false, 0, 9) { +namespace ops { +CUSTOM_OP_IMPL(max_pool_with_argmax, 1, 2, false, 0, 9) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_NULLIFIED(0); + auto indices = OUTPUT_NULLIFIED(1); - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_NULLIFIED(0); - auto indices = OUTPUT_NULLIFIED(1); + REQUIRE_TRUE( + x->rankOf() == 4, 0, + "max_pool_with_argmax: Input should have rank of 4, but got %i instead", + x->rankOf()); - REQUIRE_TRUE(x->rankOf() == 4, 0, "max_pool_with_argmax: Input should have rank of 4, but got %i instead", x->rankOf()); + auto argI = block.getIArguments(); - auto argI = block.getIArguments(); + helpers::maxPoolingFunctor(block.launchContext(), block, x, z, argI, indices); - helpers::maxPoolingFunctor(block.launchContext(), block, x, z, argI, indices); - - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(max_pool_with_argmax) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setAllowedOutputTypes(1, {ALL_INTS}); +DECLARE_TYPES(max_pool_with_argmax) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, DataType::INHERIT) + ->setAllowedOutputTypes(1, {ALL_INTS}); +} - } +DECLARE_SHAPE_FN(max_pool_with_argmax) { + auto in = inputShape->at(0); + auto valuesShape = + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(in)); + auto indicesShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(in, DataType::INT64)); - DECLARE_SHAPE_FN(max_pool_with_argmax) { - - auto in = inputShape->at(0); - auto valuesShape = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(in)); - auto indicesShape = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(in, DataType::INT64)); - - return SHAPELIST(valuesShape, indicesShape); - } - } + return SHAPELIST(valuesShape, indicesShape); } +} // namespace ops +} // namespace sd #endif - diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp index 92a74a01ea40..082542c34dc8 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/pnormpool2d.cpp @@ -26,209 +26,270 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(pnormpool2d, 1, 1, false, 0, 10) { - - REQUIRE_OK(this->validateInputLengthMatch(block)); - REQUIRE_OK(this->validateInputDimensionsMatch(block)); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_NULLIFIED(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "PNORMPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf()); - - auto kY = INT_ARG(0); - auto kX = INT_ARG(1); - auto sY = INT_ARG(2); - auto sX = INT_ARG(3); - auto pY = INT_ARG(4); - auto pX = INT_ARG(5); - auto dY = INT_ARG(6); - auto dX = INT_ARG(7); - bool isSameMode = static_cast(INT_ARG(8)); - auto extraParam0 = INT_ARG(9); - - REQUIRE_TRUE(dY != 0 && dX != 0, 0, "PNORMPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dY, dX); - - int oY = 0; - int oX = 0; - - int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW - - if(!isNCHW) { - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - const auto inY = static_cast(input->sizeAt(2)); - const auto inX = static_cast(input->sizeAt(3)); - - ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, inY, inX, isSameMode); - - if (isSameMode) - ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling2d(block, *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0); - - if(!isNCHW) { - delete input; - delete output; - } - - return Status::OK(); - } - DECLARE_SYN(PnormPool2D, pnormpool2d); - DECLARE_SYN(PnormPool, pnormpool2d); - DECLARE_SYN(pnormpool, pnormpool2d); - - DECLARE_TYPES(pnormpool2d) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - - DECLARE_SHAPE_FN(pnormpool2d) { - auto inShape = inputShape->at(0); - auto shapeOf = shape::shapeOf(inShape); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - auto argI = block.getIArguments(); - int kH = INT_ARG(0); - int kW = INT_ARG(1); - int sH = INT_ARG(2); - int sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); - int dH = INT_ARG(6); - int dW = INT_ARG(7); - int isSameMode = INT_ARG(8); - int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW - - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "PNORMPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS = shapeOf[0]; - int iC = isNCHW ? shapeOf[1] : shapeOf[3]; - int iH = isNCHW ? shapeOf[2] : shapeOf[1]; - int iW = isNCHW ? shapeOf[3] : shapeOf[2]; - char order = shape::order(inShape); // output order must be equal to input order - - // calculate output Height/Width - int oH, oW; - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - // allocate memory for new shape - Nd4jLong newShape[4]; - - newShape[0] = bS; - if (isNCHW) { - newShape[1] = iC; - newShape[2] = oH; - newShape[3] = oW; - } else { - newShape[1] = oH; - newShape[2] = oW; - newShape[3] = iC; - } - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), order, newShape, 4))); - } - - - DECLARE_TYPES(pnormpool2d_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +namespace ops { +CUSTOM_OP_IMPL(pnormpool2d, 1, 1, false, 0, 10) { + REQUIRE_OK(this->validateInputLengthMatch(block)); + REQUIRE_OK(this->validateInputDimensionsMatch(block)); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "PNORMPOOL2D op: input should have rank of 4, but got %i instead", + input->rankOf()); + + auto kY = INT_ARG(0); + auto kX = INT_ARG(1); + auto sY = INT_ARG(2); + auto sX = INT_ARG(3); + auto pY = INT_ARG(4); + auto pX = INT_ARG(5); + auto dY = INT_ARG(6); + auto dX = INT_ARG(7); + bool isSameMode = static_cast(INT_ARG(8)); + auto extraParam0 = INT_ARG(9); + + REQUIRE_TRUE( + dY != 0 && dX != 0, 0, + "PNORMPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dY, + dX); + + int oY = 0; + int oX = 0; + + int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW + + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + output = new NDArray( + output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + const auto inY = static_cast(input->sizeAt(2)); + const auto inX = static_cast(input->sizeAt(3)); + + ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, + inY, inX, isSameMode); + + if (isSameMode) + ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, + dY, dX); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; + ConvolutionUtils::pooling2d(block, *input, *output, kY, kX, sY, sX, pY, pX, + dY, dX, PoolingType::PNORM_POOL, extraParam0); + + if (!isNCHW) { + delete input; + delete output; + } + + return Status::OK(); +} +DECLARE_SYN(PnormPool2D, pnormpool2d); +DECLARE_SYN(PnormPool, pnormpool2d); +DECLARE_SYN(pnormpool, pnormpool2d); + +DECLARE_TYPES(pnormpool2d) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +DECLARE_SHAPE_FN(pnormpool2d) { + auto inShape = inputShape->at(0); + auto shapeOf = shape::shapeOf(inShape); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + auto argI = block.getIArguments(); + int kH = INT_ARG(0); + int kW = INT_ARG(1); + int sH = INT_ARG(2); + int sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + int dH = INT_ARG(6); + int dW = INT_ARG(7); + int isSameMode = INT_ARG(8); + int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW + + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "PNORMPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, + dW); + + int bS = shapeOf[0]; + int iC = isNCHW ? shapeOf[1] : shapeOf[3]; + int iH = isNCHW ? shapeOf[2] : shapeOf[1]; + int iW = isNCHW ? shapeOf[3] : shapeOf[2]; + char order = + shape::order(inShape); // output order must be equal to input order + + // calculate output Height/Width + int oH, oW; + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, isSameMode); + // allocate memory for new shape + Nd4jLong newShape[4]; + + newShape[0] = bS; + if (isNCHW) { + newShape[1] = iC; + newShape[2] = oH; + newShape[3] = oW; + } else { + newShape[1] = oH; + newShape[2] = oW; + newShape[3] = iC; + } + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShape), order, newShape, 4))); +} + +DECLARE_TYPES(pnormpool2d_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int pnorm = INT_ARG(9); - int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW - - // FIXME: double? - double eps = T_ARG(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "PNORMPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "PNORMPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "PNORMPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "PNORMPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(!isNCHW) { - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - // if(isSameMode) // SAME - // ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - // NDArray columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, input->getWorkspace()); - // NDArray* columns = columnsWrongShape.permute({0, 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] -> [bS, iC, kH, kW, oH, oW] - // NDArray* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1}); - // NDArray* columns2d = columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW}); - // NDArray pNorm(columns2d->shapeInfo(), block.workspace()); - - // input->template applyTransform>(columns, std::vector({(T)kH, (T)kW, (T)sH, (T)sW, (T)pH, (T)pW, (T)dH, (T)dW, (T)0.f, (T)0.f}).data()); - - // columns2d->template applyTransform>(&pNorm); - // pNorm.template applyTransform>(&pNorm, std::vector({(T)pnorm}).data()); - - // NDArray* denomVec = pNorm.sum({1}); - // denomVec->template applyTransform>(std::vector({(T)1. - (T)1. / pnorm}).data()); - // denomVec->template applyScalar>(eps); // in case of 0 - // denomVec->template applyPairwiseTransform>(gradOVector, denomVec, nullptr); - - // if(pnorm != 2) { - // T extraParams[] = {(T)1. - (T)2. / pnorm}; - // pNorm.template applyTransform>(std::vector({(T)1. - (T)2. / pnorm}).data()); - // *columns2d *= pNorm; - // } - - // columns2d->muliColumnVector(denomVec); - - // columns->template applyTransform>(gradI, std::vector({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data()); - - ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 2, pnorm); - - if(!isNCHW) { - delete input; - delete gradI; - delete gradO; - } - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int pnorm = INT_ARG(9); + int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW + + // FIXME: double? + double eps = T_ARG(0); + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "PNORMPOOL2D_BP op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "PNORMPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "PNORMPOOL2D_BP op: wrong shape of output's gradients array " + "(next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "PNORMPOOL2D_BP op: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray( + gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradO = new NDArray( + gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + // if(isSameMode) // SAME + // ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, + // sW, dH, dW); + + // NDArray columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, + // input->getWorkspace()); NDArray* columns = columnsWrongShape.permute({0, + // 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] + // -> [bS, iC, kH, kW, oH, oW] NDArray* gradOVector = gradO->reshape('c', + // {(int) gradO->lengthOf(), 1}); NDArray* columns2d = + // columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW}); NDArray + // pNorm(columns2d->shapeInfo(), block.workspace()); + + // input->template applyTransform>(columns, + // std::vector({(T)kH, (T)kW, (T)sH, (T)sW, (T)pH, (T)pW, (T)dH, (T)dW, + // (T)0.f, (T)0.f}).data()); + + // columns2d->template applyTransform>(&pNorm); + // pNorm.template applyTransform>(&pNorm, + // std::vector({(T)pnorm}).data()); + + // NDArray* denomVec = pNorm.sum({1}); + // denomVec->template applyTransform>(std::vector({(T)1. - + // (T)1. / pnorm}).data()); denomVec->template + // applyScalar>(eps); // in case of 0 denomVec->template + // applyPairwiseTransform>(gradOVector, denomVec, + // nullptr); + + // if(pnorm != 2) { + // T extraParams[] = {(T)1. - (T)2. / pnorm}; + // pNorm.template applyTransform>(std::vector({(T)1. - + // (T)2. / pnorm}).data()); *columns2d *= pNorm; + // } + + // columns2d->muliColumnVector(denomVec); + + // columns->template applyTransform>(gradI, + // std::vector({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, + // (T)dW}).data()); + + ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, + pH, pW, dH, dW, 2, pnorm); + + if (!isNCHW) { + delete input; + delete gradI; + delete gradO; + } + + return Status::OK(); } DECLARE_SHAPE_FN(pnormpool2d_bp) { - - REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, "PNORMPOOL2D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]); - REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, "PNORMPOOL2D_BP op: output's gradient array (next epsilon) must be 4D, but got %i instead!", inputShape->at(1)[0]); - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); + REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, + "PNORMPOOL2D_BP op: input array must be 4D, but got %i instead!", + inputShape->at(0)[0]); + REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, + "PNORMPOOL2D_BP op: output's gradient array (next epsilon) must " + "be 4D, but got %i instead!", + inputShape->at(1)[0]); + + return SHAPELIST( + ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor( + inputShape->at(0), ArrayOptions::dataType(inputShape->at(1))))); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp index 5c489edb1e01..cabfcc733810 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicBidirectionalRNN.cpp @@ -21,208 +21,340 @@ #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) { - auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] or [bS x time x inSize], shape depends on timeMajor parameter - auto WxFW = INPUT_VARIABLE(1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - auto WhFW = INPUT_VARIABLE(2); // hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - auto bFW = INPUT_VARIABLE(3); // biases for forward RNN, [2*numUnitsFW] - auto WxBW = INPUT_VARIABLE(4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - auto WhBW = INPUT_VARIABLE(5); // hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - auto bBW = INPUT_VARIABLE(6); // biases for backward RNN, [2*v] - - NDArray* h0FW = nullptr; // initial cell output for forward RNN (at time step = 0) [bS x numUnitsFW] - NDArray* h0BW = nullptr; // initial cell output for backward RNN (at time step = 0) [bS x numUnitsBW] - NDArray* maxTimeStep = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - const int timeMajor = block.numI() > 0 ? INT_ARG(0) : 0; // if non zero then [time, bS, ...], else [bS, time, ...] - - switch(block.width()) { - case 8: - maxTimeStep = INPUT_VARIABLE(7); - break; - case 9: - h0FW = INPUT_VARIABLE(7); - h0BW = INPUT_VARIABLE(8); - break; - case 10: - h0FW = INPUT_VARIABLE(7); - h0BW = INPUT_VARIABLE(8); - maxTimeStep = INPUT_VARIABLE(9); - break; - } - - auto hFW = OUTPUT_VARIABLE(0); // cell outputs for forward RNN [time x bS x numUnitsFW] or [bS x time x numUnitsFW], shape depends on timeMajor parameter - auto hBW = OUTPUT_VARIABLE(1); // cell outputs for backward RNN [time x bS x numUnitsBW] or [bS x time x numUnitsBW], shape depends on timeMajor parameter - auto hFWFinal = OUTPUT_VARIABLE(2); // final cell out for forward RNN [bS x numUnitsFW] - auto hBWFinal = OUTPUT_VARIABLE(3); // final cell out for backward RNN [bS x numUnitsBF] - - REQUIRE_TRUE(x->rankOf() == 3, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !", x->rankOf()); - REQUIRE_TRUE(WxFW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !", WxFW->rankOf()); - REQUIRE_TRUE(WxBW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !", WxBW->rankOf()); - - const int inRank = x->rankOf(); - const int time = timeMajor ? x->sizeAt(0) : x->sizeAt(1); - const int bS = timeMajor ? x->sizeAt(1) : x->sizeAt(0); - const int numUnitsFW = WxFW->sizeAt(1); - const int numUnitsBW = WxBW->sizeAt(1); - - std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; - std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; - std::vector expectedbFWshape = {2*numUnitsFW}; - std::vector expectedbBWshape = {2*numUnitsBW}; - REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFW).c_str()); - REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape) , 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBW).c_str()); - REQUIRE_TRUE(bFW->isSameShape(expectedbFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFW).c_str()); - REQUIRE_TRUE(bBW->isSameShape(expectedbBWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBW).c_str()); - if(h0FW) { - std::vector expectedh0FWshape = {bS, numUnitsFW}; - REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FW).c_str()); - } - if(h0BW) { - std::vector expectedh0BWshape = {bS, numUnitsBW}; - REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BW).c_str()); - } - if(maxTimeStep) { - std::vector expectedmaxTimeStepshape = {bS}; - REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str()); - } - - - // forward steps - sd::ops::dynamic_rnn dynamicRnn; - auto resultsFW = dynamicRnn.evaluate({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {timeMajor}); - hFW->assign(resultsFW.at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW] - hFWFinal->assign(resultsFW.at(1)); - - auto seqLen = maxTimeStep; - if(seqLen == nullptr) { - // FIXME: which datatype should be used here? - seqLen = new NDArray(x->ordering(), {bS}, sd::DataType::INT64, block.launchContext()); - seqLen->assign(time); // set each element of seqLen to be equal to time - } - - // reverse x - sd::ops::reverse_sequence reverse; - auto resultsIn = timeMajor ? reverse.evaluate({x, seqLen}, {0, 1}) : reverse.evaluate({x, seqLen}, {1, 0}); - REQUIRE_TRUE (resultsIn.status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence."); - auto revInput = resultsIn.at(0); - - // backward steps - auto resultsBW = dynamicRnn.evaluate({&revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor}); - auto hBWtemp = resultsBW.at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW] - hBWFinal->assign(resultsBW.at(1)); - - // reverse hBWtemp - auto resultsOut = timeMajor ? reverse.evaluate({&hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({&hBWtemp, seqLen}, {1, 0}); - hBW->assign(resultsOut.at(0)); - - if(seqLen != maxTimeStep) - delete seqLen; - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] or [bS x time x + // inSize], shape depends on timeMajor parameter + auto WxFW = INPUT_VARIABLE( + 1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] + auto WhFW = INPUT_VARIABLE(2); // hidden-to-hidden weights for forward RNN, + // [numUnitsFW x numUnitsFW] + auto bFW = INPUT_VARIABLE(3); // biases for forward RNN, [2*numUnitsFW] + auto WxBW = INPUT_VARIABLE( + 4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] + auto WhBW = INPUT_VARIABLE(5); // hidden-to-hidden weights for backward RNN, + // [numUnitsBW x numUnitsBW] + auto bBW = INPUT_VARIABLE(6); // biases for backward RNN, [2*v] + + NDArray* h0FW = nullptr; // initial cell output for forward RNN (at time + // step = 0) [bS x numUnitsFW] + NDArray* h0BW = nullptr; // initial cell output for backward RNN (at time + // step = 0) [bS x numUnitsBW] + NDArray* maxTimeStep = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + const int timeMajor = + block.numI() > 0 + ? INT_ARG(0) + : 0; // if non zero then [time, bS, ...], else [bS, time, ...] + + switch (block.width()) { + case 8: + maxTimeStep = INPUT_VARIABLE(7); + break; + case 9: + h0FW = INPUT_VARIABLE(7); + h0BW = INPUT_VARIABLE(8); + break; + case 10: + h0FW = INPUT_VARIABLE(7); + h0BW = INPUT_VARIABLE(8); + maxTimeStep = INPUT_VARIABLE(9); + break; + } + + auto hFW = OUTPUT_VARIABLE( + 0); // cell outputs for forward RNN [time x bS x numUnitsFW] or [bS x + // time x numUnitsFW], shape depends on timeMajor parameter + auto hBW = OUTPUT_VARIABLE( + 1); // cell outputs for backward RNN [time x bS x numUnitsBW] or [bS x + // time x numUnitsBW], shape depends on timeMajor parameter + auto hFWFinal = + OUTPUT_VARIABLE(2); // final cell out for forward RNN [bS x numUnitsFW] + auto hBWFinal = + OUTPUT_VARIABLE(3); // final cell out for backward RNN [bS x numUnitsBF] + + REQUIRE_TRUE(x->rankOf() == 3, 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input array must " + "have rank = 3, but got %i instead !", + x->rankOf()); + REQUIRE_TRUE( + WxFW->rankOf() == 2, 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for forward RNN) must have rank = 2, but got %i instead !", + WxFW->rankOf()); + REQUIRE_TRUE( + WxBW->rankOf() == 2, 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for backward RNN) must have rank = 2, but got %i instead !", + WxBW->rankOf()); + + const int inRank = x->rankOf(); + const int time = timeMajor ? x->sizeAt(0) : x->sizeAt(1); + const int bS = timeMajor ? x->sizeAt(1) : x->sizeAt(0); + const int numUnitsFW = WxFW->sizeAt(1); + const int numUnitsBW = WxBW->sizeAt(1); + + std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; + std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; + std::vector expectedbFWshape = {2 * numUnitsFW}; + std::vector expectedbBWshape = {2 * numUnitsBW}; + REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for forward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), + ShapeUtils::shapeAsString(WhFW).c_str()); + REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for backward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), + ShapeUtils::shapeAsString(WhBW).c_str()); + REQUIRE_TRUE( + bFW->isSameShape(expectedbFWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for forward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbFWshape).c_str(), + ShapeUtils::shapeAsString(bFW).c_str()); + REQUIRE_TRUE( + bBW->isSameShape(expectedbBWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for backward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbBWshape).c_str(), + ShapeUtils::shapeAsString(bBW).c_str()); + if (h0FW) { + std::vector expectedh0FWshape = {bS, numUnitsFW}; + REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for forward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), + ShapeUtils::shapeAsString(h0FW).c_str()); + } + if (h0BW) { + std::vector expectedh0BWshape = {bS, numUnitsBW}; + REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for backward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), + ShapeUtils::shapeAsString(h0BW).c_str()); + } + if (maxTimeStep) { + std::vector expectedmaxTimeStepshape = {bS}; + REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "maxTimeStep array, expected is [%i], but got %s instead !", + bS, ShapeUtils::shapeAsString(maxTimeStep).c_str()); + } + + // forward steps + sd::ops::dynamic_rnn dynamicRnn; + auto resultsFW = + dynamicRnn.evaluate({x, WxFW, WhFW, bFW, h0FW, maxTimeStep}, {timeMajor}); + hFW->assign( + resultsFW.at(0)); // [time x bS x numUnitsFW] or [bS x time x numUnitsFW] + hFWFinal->assign(resultsFW.at(1)); + + auto seqLen = maxTimeStep; + if (seqLen == nullptr) { + // FIXME: which datatype should be used here? + seqLen = new NDArray(x->ordering(), {bS}, sd::DataType::INT64, + block.launchContext()); + seqLen->assign(time); // set each element of seqLen to be equal to time + } + + // reverse x + sd::ops::reverse_sequence reverse; + auto resultsIn = timeMajor ? reverse.evaluate({x, seqLen}, {0, 1}) + : reverse.evaluate({x, seqLen}, {1, 0}); + REQUIRE_TRUE(resultsIn.status() == ND4J_STATUS_OK, 0, + "dynamic_bidirectional_rnn: there is a problem with reverse on " + "the sequence."); + auto revInput = resultsIn.at(0); + + // backward steps + auto resultsBW = dynamicRnn.evaluate( + {&revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor}); + auto hBWtemp = + resultsBW.at(0); // [time x bS x numUnitsBW] or [ bS x time xnumUnitsBW] + hBWFinal->assign(resultsBW.at(1)); + + // reverse hBWtemp + auto resultsOut = timeMajor ? reverse.evaluate({&hBWtemp, seqLen}, {0, 1}) + : reverse.evaluate({&hBWtemp, seqLen}, {1, 0}); + hBW->assign(resultsOut.at(0)); + + if (seqLen != maxTimeStep) delete seqLen; + + return Status::OK(); } - DECLARE_TYPES(dynamic_bidirectional_rnn) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - -DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) { - - auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] or [bS x time x inSize], shape depends on timeMajor parameter - auto WxFW = INPUT_VARIABLE(1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - auto WhFW = INPUT_VARIABLE(2); // hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - auto bFW = INPUT_VARIABLE(3); // biases for forward RNN, [2*numUnitsFW] - auto WxBW = INPUT_VARIABLE(4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - auto WhBW = INPUT_VARIABLE(5); // hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - auto bBW = INPUT_VARIABLE(6); // biases for backward RNN, [2*numUnitsBW] - - NDArray* h0FW = nullptr; // initial cell output for forward RNN (at time step = 0) [bS x numUnitsFW] - NDArray* h0BW = nullptr; // initial cell output for backward RNN (at time step = 0) [bS x numUnitsBW] - NDArray* maxTimeStep = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - const int timeMajor = block.numI() > 0 ? INT_ARG(0) : 0; // if true then [time, bS, ...], else [bS, time, ...] - - switch(block.width()) { - case 8: - maxTimeStep = INPUT_VARIABLE(7); - break; - case 9: - h0FW = INPUT_VARIABLE(7); - h0BW = INPUT_VARIABLE(8); - break; - case 10: - h0FW = INPUT_VARIABLE(7); - h0BW = INPUT_VARIABLE(8); - maxTimeStep = INPUT_VARIABLE(9); - break; - } - - REQUIRE_TRUE(x->rankOf() == 3, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !", x->rankOf()); - REQUIRE_TRUE(WxFW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !", WxFW->rankOf()); - REQUIRE_TRUE(WxBW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !", WxBW->rankOf()); - - const int inRank = x->rankOf(); - const int time = timeMajor ? x->sizeAt(0) : x->sizeAt(1); - const int bS = timeMajor ? x->sizeAt(1) : x->sizeAt(0); - const int numUnitsFW = WxFW->sizeAt(1); - const int numUnitsBW = WxBW->sizeAt(1); - - - std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; - std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; - std::vector expectedbFWshape = {2*numUnitsFW}; - std::vector expectedbBWshape = {2*numUnitsBW}; - - REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFW).c_str()); - REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBW).c_str()); - REQUIRE_TRUE(bFW->isSameShape(expectedbFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFW).c_str()); - REQUIRE_TRUE(bBW->isSameShape(expectedbBWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBW).c_str()); - if(h0FW) { - std::vector expectedh0FWshape = {bS, numUnitsFW}; - REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FW).c_str()); - } - if(h0BW) { - std::vector expectedh0BWshape = {bS, numUnitsBW}; - REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BW).c_str()); - } - if(maxTimeStep) { - std::vector expectedmaxTimeStepshape = {bS}; - REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str()); - } - - // evaluate output shapeInfos - Nd4jLong *hFWShapeInfo(nullptr), *hBWShapeInfo(nullptr), *hFWFinalPrevShapeInfo(nullptr), *hBWFinalPrevShapeInfo(nullptr); - ALLOCATE(hFWShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); - ALLOCATE(hBWShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); - ALLOCATE(hFWFinalPrevShapeInfo, block.workspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); - ALLOCATE(hBWFinalPrevShapeInfo, block.workspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); - - hFWShapeInfo[0] = hBWShapeInfo[0] = inRank; - hFWShapeInfo[1] = hBWShapeInfo[1] = timeMajor ? time : bS; - hFWShapeInfo[2] = hBWShapeInfo[2] = timeMajor ? bS : time; - hFWShapeInfo[3] = numUnitsFW; - hBWShapeInfo[3] = numUnitsBW; - hFWFinalPrevShapeInfo[0] = hBWFinalPrevShapeInfo[0] = inRank-1; - hFWFinalPrevShapeInfo[1] = hBWFinalPrevShapeInfo[1] = bS; - hFWFinalPrevShapeInfo[2] = numUnitsFW; - hBWFinalPrevShapeInfo[2] = numUnitsBW; - - ShapeUtils::updateStridesAndType(hFWShapeInfo, x->shapeInfo(), x->ordering()); - ShapeUtils::updateStridesAndType(hBWShapeInfo, x->shapeInfo(), x->ordering()); - ShapeUtils::updateStridesAndType(hFWFinalPrevShapeInfo, x->shapeInfo(), x->ordering()); - ShapeUtils::updateStridesAndType(hBWFinalPrevShapeInfo, x->shapeInfo(), x->ordering()); - - return SHAPELIST(CONSTANT(hFWShapeInfo), CONSTANT(hBWShapeInfo), CONSTANT(hFWFinalPrevShapeInfo), CONSTANT(hBWFinalPrevShapeInfo)); +DECLARE_TYPES(dynamic_bidirectional_rnn) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - -} +DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) { + auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] or [bS x time x + // inSize], shape depends on timeMajor parameter + auto WxFW = INPUT_VARIABLE( + 1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] + auto WhFW = INPUT_VARIABLE(2); // hidden-to-hidden weights for forward RNN, + // [numUnitsFW x numUnitsFW] + auto bFW = INPUT_VARIABLE(3); // biases for forward RNN, [2*numUnitsFW] + auto WxBW = INPUT_VARIABLE( + 4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] + auto WhBW = INPUT_VARIABLE(5); // hidden-to-hidden weights for backward RNN, + // [numUnitsBW x numUnitsBW] + auto bBW = INPUT_VARIABLE(6); // biases for backward RNN, [2*numUnitsBW] + + NDArray* h0FW = nullptr; // initial cell output for forward RNN (at time + // step = 0) [bS x numUnitsFW] + NDArray* h0BW = nullptr; // initial cell output for backward RNN (at time + // step = 0) [bS x numUnitsBW] + NDArray* maxTimeStep = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + const int timeMajor = + block.numI() > 0 + ? INT_ARG(0) + : 0; // if true then [time, bS, ...], else [bS, time, ...] + + switch (block.width()) { + case 8: + maxTimeStep = INPUT_VARIABLE(7); + break; + case 9: + h0FW = INPUT_VARIABLE(7); + h0BW = INPUT_VARIABLE(8); + break; + case 10: + h0FW = INPUT_VARIABLE(7); + h0BW = INPUT_VARIABLE(8); + maxTimeStep = INPUT_VARIABLE(9); + break; + } + + REQUIRE_TRUE(x->rankOf() == 3, 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input array must " + "have rank = 3, but got %i instead !", + x->rankOf()); + REQUIRE_TRUE( + WxFW->rankOf() == 2, 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for forward RNN) must have rank = 2, but got %i instead !", + WxFW->rankOf()); + REQUIRE_TRUE( + WxBW->rankOf() == 2, 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for backward RNN) must have rank = 2, but got %i instead !", + WxBW->rankOf()); + + const int inRank = x->rankOf(); + const int time = timeMajor ? x->sizeAt(0) : x->sizeAt(1); + const int bS = timeMajor ? x->sizeAt(1) : x->sizeAt(0); + const int numUnitsFW = WxFW->sizeAt(1); + const int numUnitsBW = WxBW->sizeAt(1); + + std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; + std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; + std::vector expectedbFWshape = {2 * numUnitsFW}; + std::vector expectedbBWshape = {2 * numUnitsBW}; + + REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for forward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), + ShapeUtils::shapeAsString(WhFW).c_str()); + REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for backward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), + ShapeUtils::shapeAsString(WhBW).c_str()); + REQUIRE_TRUE( + bFW->isSameShape(expectedbFWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for forward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbFWshape).c_str(), + ShapeUtils::shapeAsString(bFW).c_str()); + REQUIRE_TRUE( + bBW->isSameShape(expectedbBWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for backward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbBWshape).c_str(), + ShapeUtils::shapeAsString(bBW).c_str()); + if (h0FW) { + std::vector expectedh0FWshape = {bS, numUnitsFW}; + REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for forward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), + ShapeUtils::shapeAsString(h0FW).c_str()); + } + if (h0BW) { + std::vector expectedh0BWshape = {bS, numUnitsBW}; + REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for backward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), + ShapeUtils::shapeAsString(h0BW).c_str()); + } + if (maxTimeStep) { + std::vector expectedmaxTimeStepshape = {bS}; + REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepshape), 0, + "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "maxTimeStep array, expected is [%i], but got %s instead !", + bS, ShapeUtils::shapeAsString(maxTimeStep).c_str()); + } + + // evaluate output shapeInfos + Nd4jLong *hFWShapeInfo(nullptr), *hBWShapeInfo(nullptr), + *hFWFinalPrevShapeInfo(nullptr), *hBWFinalPrevShapeInfo(nullptr); + ALLOCATE(hFWShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + ALLOCATE(hBWShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + ALLOCATE(hFWFinalPrevShapeInfo, block.workspace(), + shape::shapeInfoLength(inRank - 1), Nd4jLong); + ALLOCATE(hBWFinalPrevShapeInfo, block.workspace(), + shape::shapeInfoLength(inRank - 1), Nd4jLong); + + hFWShapeInfo[0] = hBWShapeInfo[0] = inRank; + hFWShapeInfo[1] = hBWShapeInfo[1] = timeMajor ? time : bS; + hFWShapeInfo[2] = hBWShapeInfo[2] = timeMajor ? bS : time; + hFWShapeInfo[3] = numUnitsFW; + hBWShapeInfo[3] = numUnitsBW; + hFWFinalPrevShapeInfo[0] = hBWFinalPrevShapeInfo[0] = inRank - 1; + hFWFinalPrevShapeInfo[1] = hBWFinalPrevShapeInfo[1] = bS; + hFWFinalPrevShapeInfo[2] = numUnitsFW; + hBWFinalPrevShapeInfo[2] = numUnitsBW; + + ShapeUtils::updateStridesAndType(hFWShapeInfo, x->shapeInfo(), x->ordering()); + ShapeUtils::updateStridesAndType(hBWShapeInfo, x->shapeInfo(), x->ordering()); + ShapeUtils::updateStridesAndType(hFWFinalPrevShapeInfo, x->shapeInfo(), + x->ordering()); + ShapeUtils::updateStridesAndType(hBWFinalPrevShapeInfo, x->shapeInfo(), + x->ordering()); + + return SHAPELIST(CONSTANT(hFWShapeInfo), CONSTANT(hBWShapeInfo), + CONSTANT(hFWFinalPrevShapeInfo), + CONSTANT(hBWFinalPrevShapeInfo)); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp index 73a26388b9de..0718267534fa 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/dynamicRNN.cpp @@ -19,155 +19,222 @@ // #include -#include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(dynamic_rnn, 4, 2, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] or [bS x time x inSize], depends on timeMajor parameter - auto Wx = INPUT_VARIABLE(1); // input-to-hidden weights, [inSize x numUnits] - auto Wh = INPUT_VARIABLE(2); // hidden-to-hidden weights, [numUnits x numUnits] - auto b = INPUT_VARIABLE(3); // biases for, [2*numUnits] - - NDArray* h0 = nullptr; // initial cell output (at time step = 0) [bS x numUnits] - NDArray* maxTimeStep = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - const int timeMajor = block.numI() > 0 ? INT_ARG(0) : 0; // if true then [time, bS, ...], else [bS, time, ...] - - if(block.width() == 5) { - if ((*INPUT_VARIABLE(4)).rankOf() == 2) - h0 = INPUT_VARIABLE(4); - else - maxTimeStep = INPUT_VARIABLE(4); - } - else if(block.width() == 6) { - h0 = INPUT_VARIABLE(4); - maxTimeStep = INPUT_VARIABLE(5); - } - - auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x numUnits] or [bS x time x numUnits], depends on timeMajor parameter - auto hFinal = OUTPUT_VARIABLE(1); // at the end it will store cell final non-zero output [bS x numUnits] - - REQUIRE_TRUE(x->rankOf() == 3, 0, "DYNAMIC_RNN custom operation: input array x must have rank = 3, but got %i instead !", x->rankOf()); - REQUIRE_TRUE(Wx->rankOf() == 2, 0, "DYNAMIC_RNN custom operation: input-to-hidden weights array must have rank = 2, but got %i instead !", Wx->rankOf()); - - const int inRank = x->rankOf(); - const int time = timeMajor ? x->sizeAt(0) : x->sizeAt(1); - const int bS = timeMajor ? x->sizeAt(1) : x->sizeAt(0); - const int numUnits = Wx->sizeAt(1); - - std::vector expectedWhShape = {numUnits, numUnits}; - std::vector expectedBShape = {2*numUnits}; - REQUIRE_TRUE(Wh->isSameShape(expectedWhShape), 0, "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(b->isSameShape(expectedBShape), 0, "DYNAMIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedBShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - if(h0) { - std::vector expectedh0Shape = {bS, numUnits}; - REQUIRE_TRUE(h0->isSameShape(expectedh0Shape), 0, "DYNAMIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0).c_str()); - } - if(maxTimeStep) { - std::vector expectedmaxTimeStepShape = {bS}; - REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepShape), 0, "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedmaxTimeStepShape).c_str(), ShapeUtils::shapeAsString(maxTimeStep).c_str()); - } - - if(timeMajor == false) { - x = new NDArray(x->permute({1, 0, 2})); // [bS x time x inSize] -> [time x bS x inSize] - h = new NDArray(h->permute({1, 0, 2})); // [bS x time x numUnits] -> [time x bS x numUnits] - } - - helpers::rnnTimeLoop(block.launchContext(), x, Wx, Wh, b, h0, maxTimeStep, h, hFinal); - - if(timeMajor == false) { - delete x; - delete h; - } - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] or [bS x time x + // inSize], depends on timeMajor parameter + auto Wx = + INPUT_VARIABLE(1); // input-to-hidden weights, [inSize x numUnits] + auto Wh = + INPUT_VARIABLE(2); // hidden-to-hidden weights, [numUnits x numUnits] + auto b = INPUT_VARIABLE(3); // biases for, [2*numUnits] + + NDArray* h0 = + nullptr; // initial cell output (at time step = 0) [bS x numUnits] + NDArray* maxTimeStep = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + const int timeMajor = + block.numI() > 0 + ? INT_ARG(0) + : 0; // if true then [time, bS, ...], else [bS, time, ...] + + if (block.width() == 5) { + if ((*INPUT_VARIABLE(4)).rankOf() == 2) + h0 = INPUT_VARIABLE(4); + else + maxTimeStep = INPUT_VARIABLE(4); + } else if (block.width() == 6) { + h0 = INPUT_VARIABLE(4); + maxTimeStep = INPUT_VARIABLE(5); + } + + auto h = + OUTPUT_VARIABLE(0); // cell outputs [time x bS x numUnits] or [bS x time + // x numUnits], depends on timeMajor parameter + auto hFinal = OUTPUT_VARIABLE(1); // at the end it will store cell final + // non-zero output [bS x numUnits] + + REQUIRE_TRUE(x->rankOf() == 3, 0, + "DYNAMIC_RNN custom operation: input array x must have rank = " + "3, but got %i instead !", + x->rankOf()); + REQUIRE_TRUE(Wx->rankOf() == 2, 0, + "DYNAMIC_RNN custom operation: input-to-hidden weights array " + "must have rank = 2, but got %i instead !", + Wx->rankOf()); + + const int inRank = x->rankOf(); + const int time = timeMajor ? x->sizeAt(0) : x->sizeAt(1); + const int bS = timeMajor ? x->sizeAt(1) : x->sizeAt(0); + const int numUnits = Wx->sizeAt(1); + + std::vector expectedWhShape = {numUnits, numUnits}; + std::vector expectedBShape = {2 * numUnits}; + REQUIRE_TRUE(Wh->isSameShape(expectedWhShape), 0, + "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden " + "weights array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWhShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(expectedBShape), 0, + "DYNAMIC_RNN custom operation: wrong shape of biases array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedBShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + if (h0) { + std::vector expectedh0Shape = {bS, numUnits}; + REQUIRE_TRUE(h0->isSameShape(expectedh0Shape), 0, + "DYNAMIC_RNN custom operation: wrong shape of initial cell " + "output array, expected is %s but got %s instead !", + ShapeUtils::shapeAsString(expectedh0Shape).c_str(), + ShapeUtils::shapeAsString(h0).c_str()); + } + if (maxTimeStep) { + std::vector expectedmaxTimeStepShape = {bS}; + REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepShape), 0, + "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedmaxTimeStepShape).c_str(), + ShapeUtils::shapeAsString(maxTimeStep).c_str()); + } + + if (timeMajor == false) { + x = new NDArray(x->permute( + {1, 0, 2})); // [bS x time x inSize] -> [time x bS x inSize] + h = new NDArray(h->permute( + {1, 0, 2})); // [bS x time x numUnits] -> [time x bS x numUnits] + } + + helpers::rnnTimeLoop(block.launchContext(), x, Wx, Wh, b, h0, maxTimeStep, h, + hFinal); + + if (timeMajor == false) { + delete x; + delete h; + } + + return Status::OK(); } - - DECLARE_TYPES(dynamic_rnn) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedInputTypes(4, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(5, {ALL_FLOATS, ALL_INTS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_FLOATS}); - } - +DECLARE_TYPES(dynamic_rnn) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedInputTypes(4, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(5, {ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_FLOATS}); +} DECLARE_SHAPE_FN(dynamic_rnn) { - - auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] or [bS x time x inSize], depends on timeMajor parameter - auto WxShapeInfo = inputShape->at(1); // input-to-hidden weights, [inSize x numUnits] - auto WhShapeInfo = inputShape->at(2); // hidden-to-hidden weights, [numUnits x numUnits] - auto bShapeInfo = inputShape->at(3); // biases for, [2*numUnits] - - Nd4jLong const* h0ShapeInfo = nullptr; // initial cell output (at time step = 0) [bS x numUnits] - Nd4jLong const* maxTimeStepShapeInfo = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - const int timeMajor = block.numI() > 0 ? INT_ARG(0) : 0; // if true then [time, bS, ...], else [bS, time, ...] - - if(block.width() == 5) { - if (inputShape->at(4)[0] == 2) - h0ShapeInfo = inputShape->at(4); - else - maxTimeStepShapeInfo = inputShape->at(4); - } - else if(block.width() == 6) { - h0ShapeInfo = inputShape->at(4); - maxTimeStepShapeInfo = inputShape->at(5); - } - - REQUIRE_TRUE(xShapeInfo[0] == 3, 0, "DYNAMIC_RNN custom operation: input array x must have rank = 3, but got %i instead !", xShapeInfo[0]); - REQUIRE_TRUE(WxShapeInfo[0] == 2, 0, "DYNAMIC_RNN custom operation: input-to-hidden weights array must have rank = 2, but got %i instead !", WxShapeInfo[0]); - - const int inRank = xShapeInfo[0]; - const int time = timeMajor ? xShapeInfo[1] : xShapeInfo[2]; - const int bS = timeMajor ? xShapeInfo[2] : xShapeInfo[1]; - const int numUnits = WxShapeInfo[2]; - - - std::vector expectedWhShape = {numUnits, numUnits}; - std::vector expectedBShape = {2*numUnits}; - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, expectedWhShape), 0, "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, expectedBShape), 0, "DYNAMIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - if(h0ShapeInfo) { - std::vector expectedh0Shape = {bS, numUnits}; - REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, expectedh0Shape), 0, "DYNAMIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); - } - if(maxTimeStepShapeInfo) { - std::vector expectedmaxTimeStepShape = {bS}; - REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, expectedmaxTimeStepShape), 0, "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedmaxTimeStepShape).c_str(), ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str()); - } - - // evaluate output shapeInfos - Nd4jLong *hShapeInfo(nullptr), *hPrevShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); - ALLOCATE(hPrevShapeInfo, block.workspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); - - hShapeInfo[0] = inRank; - hPrevShapeInfo[0] = inRank-1; - hShapeInfo[1] = timeMajor ? time : bS; - hShapeInfo[2] = timeMajor ? bS : time; - hPrevShapeInfo[1] = bS; - hShapeInfo[3] = hPrevShapeInfo[2] = numUnits; - - ShapeUtils::updateStridesAndType(hShapeInfo, WhShapeInfo, shape::order(xShapeInfo)); - ShapeUtils::updateStridesAndType(hPrevShapeInfo, WhShapeInfo, shape::order(xShapeInfo)); - - return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(hPrevShapeInfo)); + auto xShapeInfo = + inputShape->at(0); // input [time x bS x inSize] or [bS x time x inSize], + // depends on timeMajor parameter + auto WxShapeInfo = + inputShape->at(1); // input-to-hidden weights, [inSize x numUnits] + auto WhShapeInfo = + inputShape->at(2); // hidden-to-hidden weights, [numUnits x numUnits] + auto bShapeInfo = inputShape->at(3); // biases for, [2*numUnits] + + Nd4jLong const* h0ShapeInfo = + nullptr; // initial cell output (at time step = 0) [bS x numUnits] + Nd4jLong const* maxTimeStepShapeInfo = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + const int timeMajor = + block.numI() > 0 + ? INT_ARG(0) + : 0; // if true then [time, bS, ...], else [bS, time, ...] + + if (block.width() == 5) { + if (inputShape->at(4)[0] == 2) + h0ShapeInfo = inputShape->at(4); + else + maxTimeStepShapeInfo = inputShape->at(4); + } else if (block.width() == 6) { + h0ShapeInfo = inputShape->at(4); + maxTimeStepShapeInfo = inputShape->at(5); + } + + REQUIRE_TRUE(xShapeInfo[0] == 3, 0, + "DYNAMIC_RNN custom operation: input array x must have rank = " + "3, but got %i instead !", + xShapeInfo[0]); + REQUIRE_TRUE(WxShapeInfo[0] == 2, 0, + "DYNAMIC_RNN custom operation: input-to-hidden weights array " + "must have rank = 2, but got %i instead !", + WxShapeInfo[0]); + + const int inRank = xShapeInfo[0]; + const int time = timeMajor ? xShapeInfo[1] : xShapeInfo[2]; + const int bS = timeMajor ? xShapeInfo[2] : xShapeInfo[1]; + const int numUnits = WxShapeInfo[2]; + + std::vector expectedWhShape = {numUnits, numUnits}; + std::vector expectedBShape = {2 * numUnits}; + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, expectedWhShape), 0, + "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden " + "weights array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWhShape).c_str(), + ShapeUtils::shapeAsString(WhShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, expectedBShape), 0, + "DYNAMIC_RNN custom operation: wrong shape of biases array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedBShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + if (h0ShapeInfo) { + std::vector expectedh0Shape = {bS, numUnits}; + REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, expectedh0Shape), 0, + "DYNAMIC_RNN custom operation: wrong shape of initial cell " + "output array, expected is %s but got %s instead !", + ShapeUtils::shapeAsString(expectedh0Shape).c_str(), + ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); + } + if (maxTimeStepShapeInfo) { + std::vector expectedmaxTimeStepShape = {bS}; + REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, + expectedmaxTimeStepShape), + 0, + "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedmaxTimeStepShape).c_str(), + ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str()); + } + + // evaluate output shapeInfos + Nd4jLong *hShapeInfo(nullptr), *hPrevShapeInfo(nullptr); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + ALLOCATE(hPrevShapeInfo, block.workspace(), + shape::shapeInfoLength(inRank - 1), Nd4jLong); + + hShapeInfo[0] = inRank; + hPrevShapeInfo[0] = inRank - 1; + hShapeInfo[1] = timeMajor ? time : bS; + hShapeInfo[2] = timeMajor ? bS : time; + hPrevShapeInfo[1] = bS; + hShapeInfo[3] = hPrevShapeInfo[2] = numUnits; + + ShapeUtils::updateStridesAndType(hShapeInfo, WhShapeInfo, + shape::order(xShapeInfo)); + ShapeUtils::updateStridesAndType(hPrevShapeInfo, WhShapeInfo, + shape::order(xShapeInfo)); + + return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(hPrevShapeInfo)); } - - - - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp index a0b1e707b282..80f357a2842d 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp @@ -23,167 +23,249 @@ #if NOT_EXCLUDED(OP_gru) #include -#include +#include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(gru, 5, 1, false, 0, 0) { - auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] - auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] - - auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] - auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] - auto b = INPUT_VARIABLE(4); // biases, [3*nOut] - - auto h = OUTPUT_VARIABLE(0); // cell outputs [time, bS, nOut], that is per each time step - - const int bS = x->sizeAt(1); - const int nIn = x->sizeAt(2); - const int nOut = hI->sizeAt(1); - - const std::vector h0CorrectShape = {bS, nOut}; - const std::vector wxCorrectShape = {nIn, 3*nOut}; - const std::vector whCorrectShape = {nOut, 3*nOut}; - const std::vector bCorrectShape = {3*nOut}; - - REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - - helpers::gruTimeLoop(block.launchContext(), x, hI, Wx, Wh, b, h); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = + INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] + + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] + + auto h = OUTPUT_VARIABLE( + 0); // cell outputs [time, bS, nOut], that is per each time step + + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); + + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3 * nOut}; + const std::vector whCorrectShape = {nOut, 3 * nOut}; + const std::vector bCorrectShape = {3 * nOut}; + + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, + "GRU operation: wrong shape of previous cell output array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(h0CorrectShape).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, + "GRU operation: wrong shape of input-to-hidden weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(wxCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, + "GRU operation: wrong shape of hidden-to-hidden weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(whCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "GRU operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + + helpers::gruTimeLoop(block.launchContext(), x, hI, Wx, Wh, b, h); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(gru) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(gru) { - - auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] - auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] - - auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] - auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] - auto b = INPUT_VARIABLE(4); // biases, [3*nOut] - - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int nIn = x->sizeAt(2); - const int nOut = hI->sizeAt(1); - - const std::vector h0CorrectShape = {bS, nOut}; - const std::vector wxCorrectShape = {nIn, 3*nOut}; - const std::vector whCorrectShape = {nOut, 3*nOut}; - const std::vector bCorrectShape = {3*nOut}; - - REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - - auto hShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(hI->dataType(), hI->ordering(), {time, bS, nOut}); - - return SHAPELIST(hShapeInfo); + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = + INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] + + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); + + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3 * nOut}; + const std::vector whCorrectShape = {nOut, 3 * nOut}; + const std::vector bCorrectShape = {3 * nOut}; + + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, + "GRU operation: wrong shape of previous cell output array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(h0CorrectShape).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, + "GRU operation: wrong shape of input-to-hidden weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(wxCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, + "GRU operation: wrong shape of hidden-to-hidden weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(whCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "GRU operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + + auto hShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo( + hI->dataType(), hI->ordering(), {time, bS, nOut}); + + return SHAPELIST(hShapeInfo); } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(gru_bp, 6, 5, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] - auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] - - auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] - auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] - auto b = INPUT_VARIABLE(4); // biases, [3*nOut] - - auto dLdh = INPUT_VARIABLE(5); // gradient vs. ff output, [time, bS, nOut] - - auto dLdx = OUTPUT_VARIABLE(0); // gradient vs. ff input, [time, bS, nIn] - auto dLdhI = OUTPUT_NULLIFIED(1); // gradient vs. initial cell output, [bS, nOut] - auto dLdWx = OUTPUT_NULLIFIED(2); // gradient vs. input-to-hidden weights, [nIn, 3*nOut] - auto dLdWh = OUTPUT_NULLIFIED(3); // gradient vs. hidden-to-hidden weights, [nOut, 3*nOut] - auto dLdb = OUTPUT_NULLIFIED(4); // gradient vs. biases [3*nOut] - - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int nIn = x->sizeAt(2); - const int nOut = hI->sizeAt(1); - - const std::vector h0CorrectShape = {bS, nOut}; - const std::vector wxCorrectShape = {nIn, 3*nOut}; - const std::vector whCorrectShape = {nOut, 3*nOut}; - const std::vector bCorrectShape = {3*nOut}; - const std::vector hCorrectShape = {time, bS, nOut}; - - REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU_BP operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU_BP operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU_BP operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape),0, "GRU_BP operation: wrong shape of gradient vs. ff output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); - - helpers::gruTimeLoopBp(block.launchContext(), x, hI, Wx, Wh, b, dLdh, dLdx, dLdhI, dLdWx, dLdWh, dLdb); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = + INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] + + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] + + auto dLdh = INPUT_VARIABLE(5); // gradient vs. ff output, [time, bS, nOut] + + auto dLdx = OUTPUT_VARIABLE(0); // gradient vs. ff input, [time, bS, nIn] + auto dLdhI = + OUTPUT_NULLIFIED(1); // gradient vs. initial cell output, [bS, nOut] + auto dLdWx = OUTPUT_NULLIFIED( + 2); // gradient vs. input-to-hidden weights, [nIn, 3*nOut] + auto dLdWh = OUTPUT_NULLIFIED( + 3); // gradient vs. hidden-to-hidden weights, [nOut, 3*nOut] + auto dLdb = OUTPUT_NULLIFIED(4); // gradient vs. biases [3*nOut] + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); + + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3 * nOut}; + const std::vector whCorrectShape = {nOut, 3 * nOut}; + const std::vector bCorrectShape = {3 * nOut}; + const std::vector hCorrectShape = {time, bS, nOut}; + + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, + "GRU_BP operation: wrong shape of previous cell output array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(h0CorrectShape).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, + "GRU_BP operation: wrong shape of input-to-hidden weights " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(wxCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, + "GRU_BP operation: wrong shape of hidden-to-hidden weights " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(whCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "GRU_BP operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape), 0, + "GRU_BP operation: wrong shape of gradient vs. ff output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdh).c_str()); + + helpers::gruTimeLoopBp(block.launchContext(), x, hI, Wx, Wh, b, dLdh, dLdx, + dLdhI, dLdWx, dLdWh, dLdb); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(gru_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(gru_bp) { - - auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] - auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] - - auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] - auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] - auto b = INPUT_VARIABLE(4); // biases, [3*nOut] - - auto dLdh = INPUT_VARIABLE(5); // gradient vs. ff output, [time, bS, nOut] - - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int nIn = x->sizeAt(2); - const int nOut = hI->sizeAt(1); - - const std::vector h0CorrectShape = {bS, nOut}; - const std::vector wxCorrectShape = {nIn, 3*nOut}; - const std::vector whCorrectShape = {nOut, 3*nOut}; - const std::vector bCorrectShape = {3*nOut}; - const std::vector hCorrectShape = {time, bS, nOut}; - - REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU_BP operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU_BP operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU_BP operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape),0, "GRU_BP operation: wrong shape of gradient vs. ff output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); - - auto dLdxShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), x->shapeInfo()); - auto dLdhIShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), hI->shapeInfo()); - auto dLdWxShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), Wx->shapeInfo()); - auto dLdWhShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), Wh->shapeInfo()); - auto dLdbShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), b->shapeInfo()); - - return SHAPELIST(dLdxShapeInfo, dLdhIShapeInfo, dLdWxShapeInfo, dLdWhShapeInfo, dLdbShapeInfo); -} - -} + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = + INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] + + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] + + auto dLdh = INPUT_VARIABLE(5); // gradient vs. ff output, [time, bS, nOut] + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); + + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3 * nOut}; + const std::vector whCorrectShape = {nOut, 3 * nOut}; + const std::vector bCorrectShape = {3 * nOut}; + const std::vector hCorrectShape = {time, bS, nOut}; + + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, + "GRU_BP operation: wrong shape of previous cell output array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(h0CorrectShape).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, + "GRU_BP operation: wrong shape of input-to-hidden weights " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(wxCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, + "GRU_BP operation: wrong shape of hidden-to-hidden weights " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(whCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "GRU_BP operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape), 0, + "GRU_BP operation: wrong shape of gradient vs. ff output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdh).c_str()); + + auto dLdxShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo( + dLdh->dataType(), x->shapeInfo()); + auto dLdhIShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo( + dLdh->dataType(), hI->shapeInfo()); + auto dLdWxShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo( + dLdh->dataType(), Wx->shapeInfo()); + auto dLdWhShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo( + dLdh->dataType(), Wh->shapeInfo()); + auto dLdbShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo( + dLdh->dataType(), b->shapeInfo()); + + return SHAPELIST(dLdxShapeInfo, dLdhIShapeInfo, dLdWxShapeInfo, + dLdWhShapeInfo, dLdbShapeInfo); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp index 7039fce72761..f76efea276f9 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp @@ -23,225 +23,347 @@ #if NOT_EXCLUDED(OP_gruCell) #include -#include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(gruCell, 6, 4, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); // input [bS, nIn], nIn - input size - auto hLast = INPUT_VARIABLE(1); // previous cell output [bS, nU], that is at previous time step t-1, nU - number of units - auto Wru = INPUT_VARIABLE(2); // RU weights - [nIn+nU, 2*nU] - reset and update gates (input/recurrent weights) - auto Wc = INPUT_VARIABLE(3); // C weights - [nIn+nU, nU] - cell gate (input/recurrent weights) - auto bru = INPUT_VARIABLE(4); // reset and update biases, [2*nU] - reset and update gates - auto bc = INPUT_VARIABLE(5); // cell biases, [nU] - - auto r = OUTPUT_VARIABLE(0); // Reset gate output [bS, nU] - auto u = OUTPUT_VARIABLE(1); // Update gate output [bS, nU] - auto c = OUTPUT_VARIABLE(2); // Cell gate output [bS, nU] - auto h = OUTPUT_VARIABLE(3); // current cell output [bS, nU] - - REQUIRE_TRUE(x->rankOf()==2 && hLast->rankOf()==2, 0, "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got %i, %i", x->rankOf(), hLast->rankOf()); - - const int rank = x->rankOf(); - const auto bS = x->sizeAt(0); - const auto nIn = x->sizeAt(1); - const auto nU = hLast->sizeAt(1); - - REQUIRE_TRUE(x->sizeAt(0) == hLast->sizeAt(0), 0, "gruCell: Input minibatch sizes (dimension 0) must be same for x and hLast"); - REQUIRE_TRUE(Wru->rankOf()==2 && Wc->rankOf()==2, 0, "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", Wru->rankOf(), Wc->rankOf()); - REQUIRE_TRUE(Wru->sizeAt(0)==(nIn+nU) && Wc->sizeAt(0)==(nIn+nU), 0, "gruCell: Weights size(0) must be equal to nIn + nU, got %i", Wru->sizeAt(0)); - REQUIRE_TRUE(Wru->sizeAt(1)==(2*nU), 0, "gruCell: Weights (reset and update) size(1) must be equal to 2*nU, got %i", Wru->sizeAt(1)); - REQUIRE_TRUE(Wc->sizeAt(1)==nU, 0, "gruCell: Weights (cell) size(1) must be equal to nU, got %i", Wc->sizeAt(1)); - REQUIRE_TRUE(bru->rankOf()==1 && bru->sizeAt(0)==(2*nU), 0, "gruCell: reset/update biases must be rank 1, size 2*nU"); - REQUIRE_TRUE(bc->rankOf()==1 && bc->sizeAt(0)==nU, 0, "gruCell: cell biases must be rank 1, size nU"); - REQUIRE_TRUE(r->rankOf()==2 && u->rankOf()==2 && c->rankOf()==2 && h->rankOf()==2 && - r->sizeAt(0)==bS && u->sizeAt(0)==bS && c->sizeAt(0)==bS && h->sizeAt(0)==bS && - r->sizeAt(1)==nU && u->sizeAt(1)==nU && c->sizeAt(1)==nU && h->sizeAt(1)==nU, - 0, "gruCell: Output arrays must all be rank 2 with size(0) == batchSize and size(1) == nU"); - - helpers::gruCell(block.launchContext(), x, hLast, Wru, Wc, bru, bc, r, u, c, h); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [bS, nIn], nIn - input size + auto hLast = + INPUT_VARIABLE(1); // previous cell output [bS, nU], that is at previous + // time step t-1, nU - number of units + auto Wru = INPUT_VARIABLE(2); // RU weights - [nIn+nU, 2*nU] - reset and + // update gates (input/recurrent weights) + auto Wc = INPUT_VARIABLE( + 3); // C weights - [nIn+nU, nU] - cell gate (input/recurrent weights) + auto bru = INPUT_VARIABLE( + 4); // reset and update biases, [2*nU] - reset and update gates + auto bc = INPUT_VARIABLE(5); // cell biases, [nU] + + auto r = OUTPUT_VARIABLE(0); // Reset gate output [bS, nU] + auto u = OUTPUT_VARIABLE(1); // Update gate output [bS, nU] + auto c = OUTPUT_VARIABLE(2); // Cell gate output [bS, nU] + auto h = OUTPUT_VARIABLE(3); // current cell output [bS, nU] + + REQUIRE_TRUE(x->rankOf() == 2 && hLast->rankOf() == 2, 0, + "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - " + "got %i, %i", + x->rankOf(), hLast->rankOf()); + + const int rank = x->rankOf(); + const auto bS = x->sizeAt(0); + const auto nIn = x->sizeAt(1); + const auto nU = hLast->sizeAt(1); + + REQUIRE_TRUE(x->sizeAt(0) == hLast->sizeAt(0), 0, + "gruCell: Input minibatch sizes (dimension 0) must be same for " + "x and hLast"); + REQUIRE_TRUE( + Wru->rankOf() == 2 && Wc->rankOf() == 2, 0, + "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", + Wru->rankOf(), Wc->rankOf()); + REQUIRE_TRUE(Wru->sizeAt(0) == (nIn + nU) && Wc->sizeAt(0) == (nIn + nU), 0, + "gruCell: Weights size(0) must be equal to nIn + nU, got %i", + Wru->sizeAt(0)); + REQUIRE_TRUE(Wru->sizeAt(1) == (2 * nU), 0, + "gruCell: Weights (reset and update) size(1) must be equal to " + "2*nU, got %i", + Wru->sizeAt(1)); + REQUIRE_TRUE(Wc->sizeAt(1) == nU, 0, + "gruCell: Weights (cell) size(1) must be equal to nU, got %i", + Wc->sizeAt(1)); + REQUIRE_TRUE(bru->rankOf() == 1 && bru->sizeAt(0) == (2 * nU), 0, + "gruCell: reset/update biases must be rank 1, size 2*nU"); + REQUIRE_TRUE(bc->rankOf() == 1 && bc->sizeAt(0) == nU, 0, + "gruCell: cell biases must be rank 1, size nU"); + REQUIRE_TRUE( + r->rankOf() == 2 && u->rankOf() == 2 && c->rankOf() == 2 && + h->rankOf() == 2 && r->sizeAt(0) == bS && u->sizeAt(0) == bS && + c->sizeAt(0) == bS && h->sizeAt(0) == bS && r->sizeAt(1) == nU && + u->sizeAt(1) == nU && c->sizeAt(1) == nU && h->sizeAt(1) == nU, + 0, + "gruCell: Output arrays must all be rank 2 with size(0) == batchSize and " + "size(1) == nU"); + + helpers::gruCell(block.launchContext(), x, hLast, Wru, Wc, bru, bc, r, u, c, + h); + + return Status::OK(); } DECLARE_TYPES(gruCell) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedInputTypes(4, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedInputTypes(4, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(gruCell) { - - auto x = inputShape->at(0); // input [bS x nIn] - auto hLast = inputShape->at(1); // previous cell output [bS x nU], that is at previous time step t-1 - auto Wru = inputShape->at(2); // RU weights - [(nIn+nU), 2*nU] - reset and update gates (input/recurrent weights) - auto Wc = inputShape->at(3); // C weights - [(nIn+nU), nU] - cell gate (input/recurrent weights) - auto bru = inputShape->at(4); // reset and update biases, [2*nU] - reset and update gates - auto bc = inputShape->at(5); // cell biases, [nU] - - REQUIRE_TRUE(shape::rank(x)==2 && shape::rank(hLast)==2, 0, "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got %i, %i", shape::rank(x), shape::rank(hLast)); - - const int rank = x[0]; - const auto bS = x[1]; - const auto nIn = x[2]; - const auto nU = hLast[2]; - - REQUIRE_TRUE(x[1] == hLast[1], 0, "gruCell: Input minibatch sizes (dimension 0) must be same for x and hLast"); - REQUIRE_TRUE(shape::rank(Wru)==2 && shape::rank(Wc)==2, 0, "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", shape::rank(Wru), shape::rank(Wc)); - REQUIRE_TRUE(Wru[1]==(nIn+nU) && Wc[1]==(nIn+nU), 0, "gruCell: Weights size(0) must be equal to nIn + nU, got %i and %i", Wru[1], Wc[1]); - REQUIRE_TRUE(Wru[2]==(2*nU), 0, "gruCell: Weights (reset and update) size(1) must be equal to 2*nU, got %i", Wru[2]); - REQUIRE_TRUE(Wc[2]==nU, 0, "gruCell: Weights (cell) size(1) must be equal to nU, got %i", Wc[2]); - REQUIRE_TRUE(shape::rank(bru)==1 && bru[1]==(2*nU), 0, "gruCell: reset/update biases must be rank 1, size 2*nU"); - REQUIRE_TRUE(shape::rank(bc)==1 && bc[1]==nU, 0, "gruCell: cell biases must be rank 1, size nU"); - - Nd4jLong *s0(nullptr); - ALLOCATE(s0, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong);// [bS x nU] - - s0[0] = rank; - s0[1] = bS; - s0[2] = nU; - - ShapeUtils::updateStridesAndType(s0, x, shape::order(hLast)); - auto ts0 = ConstantShapeHelper::getInstance()->createFromExisting(s0, block.workspace()); - - //4 output shapes, all [bs, nU] - return SHAPELIST(ts0, ts0, ts0, ts0); + auto x = inputShape->at(0); // input [bS x nIn] + auto hLast = inputShape->at( + 1); // previous cell output [bS x nU], that is at previous time step t-1 + auto Wru = inputShape->at(2); // RU weights - [(nIn+nU), 2*nU] - reset and + // update gates (input/recurrent weights) + auto Wc = inputShape->at( + 3); // C weights - [(nIn+nU), nU] - cell gate (input/recurrent weights) + auto bru = inputShape->at( + 4); // reset and update biases, [2*nU] - reset and update gates + auto bc = inputShape->at(5); // cell biases, [nU] + + REQUIRE_TRUE(shape::rank(x) == 2 && shape::rank(hLast) == 2, 0, + "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - " + "got %i, %i", + shape::rank(x), shape::rank(hLast)); + + const int rank = x[0]; + const auto bS = x[1]; + const auto nIn = x[2]; + const auto nU = hLast[2]; + + REQUIRE_TRUE(x[1] == hLast[1], 0, + "gruCell: Input minibatch sizes (dimension 0) must be same for " + "x and hLast"); + REQUIRE_TRUE( + shape::rank(Wru) == 2 && shape::rank(Wc) == 2, 0, + "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", + shape::rank(Wru), shape::rank(Wc)); + REQUIRE_TRUE( + Wru[1] == (nIn + nU) && Wc[1] == (nIn + nU), 0, + "gruCell: Weights size(0) must be equal to nIn + nU, got %i and %i", + Wru[1], Wc[1]); + REQUIRE_TRUE(Wru[2] == (2 * nU), 0, + "gruCell: Weights (reset and update) size(1) must be equal to " + "2*nU, got %i", + Wru[2]); + REQUIRE_TRUE(Wc[2] == nU, 0, + "gruCell: Weights (cell) size(1) must be equal to nU, got %i", + Wc[2]); + REQUIRE_TRUE(shape::rank(bru) == 1 && bru[1] == (2 * nU), 0, + "gruCell: reset/update biases must be rank 1, size 2*nU"); + REQUIRE_TRUE(shape::rank(bc) == 1 && bc[1] == nU, 0, + "gruCell: cell biases must be rank 1, size nU"); + + Nd4jLong *s0(nullptr); + ALLOCATE(s0, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [bS x nU] + + s0[0] = rank; + s0[1] = bS; + s0[2] = nU; + + ShapeUtils::updateStridesAndType(s0, x, shape::order(hLast)); + auto ts0 = ConstantShapeHelper::getInstance()->createFromExisting( + s0, block.workspace()); + + // 4 output shapes, all [bs, nU] + return SHAPELIST(ts0, ts0, ts0, ts0); } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(gruCell_bp, 10, 6, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); // input [bS x iS] - auto hi = INPUT_VARIABLE(1); // previous cell output [bS x nU] - auto W = INPUT_VARIABLE(2); // weights, [iS+nU x 2*nU] - auto Wc = INPUT_VARIABLE(3); // c weights, [iS+nU x nU] - auto b = INPUT_VARIABLE(4); // biases, [2*nU] - auto bc = INPUT_VARIABLE(5); // biases, [nU] - auto dLdr = INPUT_VARIABLE(6); // gradient wrt reset gate, [bS, nU] - auto dLdu = INPUT_VARIABLE(7); // gradient wrt update gate, [bS, nU] - auto dLdc = INPUT_VARIABLE(8); // gradient wrt cell state, [bS, nU] - auto dLdh = INPUT_VARIABLE(9); // gradient wrt current cell output, [bS, nU] - - auto dLdx = OUTPUT_VARIABLE(0); // gradient wrt x, [bS, iS] - auto dLdhi = OUTPUT_VARIABLE(1); // gradient wrt hi, [bS, nU] - auto dLdW = OUTPUT_VARIABLE(2); // gradient wrt W, [iS+nU x 2*nU] - auto dLdWc = OUTPUT_VARIABLE(3); // gradient wrt Wc, [iS+nU x nU] - auto dLdb = OUTPUT_VARIABLE(4); // gradient wrt biases, [2*nU] - auto dLdbc = OUTPUT_VARIABLE(5); // gradient wrt c biases, [nU] - - const Nd4jLong bS = x->sizeAt(0); - const Nd4jLong iS = x->sizeAt(1); - const Nd4jLong nU = hi->sizeAt(1); - - REQUIRE_TRUE(x->rankOf() == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", x->rankOf()); - - const std::vector hiCorrectShape = {bS, nU}; - const std::vector wCorrectShape = {iS+nU, 2*nU}; - const std::vector wcCorrectShape = {iS+nU, nU}; - const std::vector bCorrectShape = {2*nU}; - const std::vector bcCorrectShape = {nU}; - - REQUIRE_TRUE(hi->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(hi).c_str()); - REQUIRE_TRUE(W->isSameShape(wCorrectShape), 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(W).c_str()); - REQUIRE_TRUE(Wc->isSameShape(wcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wcCorrectShape).c_str(), ShapeUtils::shapeAsString(Wc).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(bc->isSameShape(bcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bcCorrectShape).c_str(), ShapeUtils::shapeAsString(bc).c_str()); - REQUIRE_TRUE(dLdr->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdr).c_str()); - REQUIRE_TRUE(dLdu->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdu).c_str()); - REQUIRE_TRUE(dLdc->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdc).c_str()); - REQUIRE_TRUE(dLdh->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); - - helpers::gruCellBp(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [bS x iS] + auto hi = INPUT_VARIABLE(1); // previous cell output [bS x nU] + auto W = INPUT_VARIABLE(2); // weights, [iS+nU x 2*nU] + auto Wc = INPUT_VARIABLE(3); // c weights, [iS+nU x nU] + auto b = INPUT_VARIABLE(4); // biases, [2*nU] + auto bc = INPUT_VARIABLE(5); // biases, [nU] + auto dLdr = INPUT_VARIABLE(6); // gradient wrt reset gate, [bS, nU] + auto dLdu = INPUT_VARIABLE(7); // gradient wrt update gate, [bS, nU] + auto dLdc = INPUT_VARIABLE(8); // gradient wrt cell state, [bS, nU] + auto dLdh = INPUT_VARIABLE(9); // gradient wrt current cell output, [bS, nU] + + auto dLdx = OUTPUT_VARIABLE(0); // gradient wrt x, [bS, iS] + auto dLdhi = OUTPUT_VARIABLE(1); // gradient wrt hi, [bS, nU] + auto dLdW = OUTPUT_VARIABLE(2); // gradient wrt W, [iS+nU x 2*nU] + auto dLdWc = OUTPUT_VARIABLE(3); // gradient wrt Wc, [iS+nU x nU] + auto dLdb = OUTPUT_VARIABLE(4); // gradient wrt biases, [2*nU] + auto dLdbc = OUTPUT_VARIABLE(5); // gradient wrt c biases, [nU] + + const Nd4jLong bS = x->sizeAt(0); + const Nd4jLong iS = x->sizeAt(1); + const Nd4jLong nU = hi->sizeAt(1); + + REQUIRE_TRUE( + x->rankOf() == 2, 0, + "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", + x->rankOf()); + + const std::vector hiCorrectShape = {bS, nU}; + const std::vector wCorrectShape = {iS + nU, 2 * nU}; + const std::vector wcCorrectShape = {iS + nU, nU}; + const std::vector bCorrectShape = {2 * nU}; + const std::vector bcCorrectShape = {nU}; + + REQUIRE_TRUE(hi->isSameShape(hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of previous cell output array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(hi).c_str()); + REQUIRE_TRUE(W->isSameShape(wCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of weights array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(W).c_str()); + REQUIRE_TRUE(Wc->isSameShape(wcCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of c weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(wcCorrectShape).c_str(), + ShapeUtils::shapeAsString(Wc).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(bc->isSameShape(bcCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bcCorrectShape).c_str(), + ShapeUtils::shapeAsString(bc).c_str()); + REQUIRE_TRUE(dLdr->isSameShape(hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset " + "gate), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdr).c_str()); + REQUIRE_TRUE(dLdu->isSameShape(hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update " + "gate), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdu).c_str()); + REQUIRE_TRUE(dLdc->isSameShape(hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell " + "state), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdc).c_str()); + REQUIRE_TRUE(dLdh->isSameShape(hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt " + "current cell output), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdh).c_str()); + + helpers::gruCellBp(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, + dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc); + + return Status::OK(); } DECLARE_TYPES(gruCell_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedInputTypes(4, {ALL_FLOATS}) - ->setAllowedInputTypes(5, {ALL_FLOATS}) - ->setAllowedInputTypes(6, {ALL_FLOATS}) - ->setAllowedInputTypes(7, {ALL_FLOATS}) - ->setAllowedInputTypes(8, {ALL_FLOATS}) - ->setAllowedInputTypes(9, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedInputTypes(4, {ALL_FLOATS}) + ->setAllowedInputTypes(5, {ALL_FLOATS}) + ->setAllowedInputTypes(6, {ALL_FLOATS}) + ->setAllowedInputTypes(7, {ALL_FLOATS}) + ->setAllowedInputTypes(8, {ALL_FLOATS}) + ->setAllowedInputTypes(9, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } DECLARE_SHAPE_FN(gruCell_bp) { - - auto xShapeInfo = inputShape->at(0); // [bS x iS] - auto hiShapeInfo = inputShape->at(1); // [bS x nU] - auto wShapeInfo = inputShape->at(2); // [iS+nU x 2*nU] - auto wcShapeInfo = inputShape->at(3); // [iS+nU x nU] - auto bShapeInfo = inputShape->at(4); // [2*nU] - auto bcShapeInfo = inputShape->at(5); // [nU] - auto dLdrShapeInfo = inputShape->at(6); // [bS, nU] - auto dLduShapeInfo = inputShape->at(7); // [bS, nU] - auto dLdcShapeInfo = inputShape->at(8); // [bS, nU] - auto dLdhShapeInfo = inputShape->at(9); // [bS, nU] - - const int rank = xShapeInfo[0]; // = 2 - const Nd4jLong bS = xShapeInfo[1]; - const Nd4jLong iS = xShapeInfo[2]; - const Nd4jLong nU = hiShapeInfo[2]; - - REQUIRE_TRUE(xShapeInfo[0] == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", xShapeInfo[0]); - - const std::vector hiCorrectShape = {bS, nU}; - const std::vector wCorrectShape = {iS+nU, 2*nU}; - const std::vector wcCorrectShape = {iS+nU, nU}; - const std::vector bCorrectShape = {2*nU}; - const std::vector bcCorrectShape = {nU}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(hiShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(hiShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(wcShapeInfo, wcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wcCorrectShape).c_str(), ShapeUtils::shapeAsString(wcShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bcShapeInfo, bcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bcCorrectShape).c_str(), ShapeUtils::shapeAsString(bcShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdrShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdrShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLduShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLduShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdcShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdcShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdhShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdhShapeInfo).c_str()); - - Nd4jLong *dLdxShapeInfo = nullptr; - COPY_SHAPE(xShapeInfo, dLdxShapeInfo); - - Nd4jLong *dLdhiShapeInfo = nullptr; - COPY_SHAPE(hiShapeInfo, dLdhiShapeInfo); - - Nd4jLong *dLdWShapeInfo = nullptr; - COPY_SHAPE(wShapeInfo, dLdWShapeInfo); - - Nd4jLong *dLdWcShapeInfo = nullptr; - COPY_SHAPE(wcShapeInfo, dLdWcShapeInfo); - - Nd4jLong *dLdbShapeInfo = nullptr; - COPY_SHAPE(bShapeInfo, dLdbShapeInfo); - - Nd4jLong *dLdbcShapeInfo = nullptr; - COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo); - - return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdhiShapeInfo), CONSTANT(dLdWShapeInfo), CONSTANT(dLdWcShapeInfo), CONSTANT(dLdbShapeInfo), CONSTANT(dLdbcShapeInfo)); + auto xShapeInfo = inputShape->at(0); // [bS x iS] + auto hiShapeInfo = inputShape->at(1); // [bS x nU] + auto wShapeInfo = inputShape->at(2); // [iS+nU x 2*nU] + auto wcShapeInfo = inputShape->at(3); // [iS+nU x nU] + auto bShapeInfo = inputShape->at(4); // [2*nU] + auto bcShapeInfo = inputShape->at(5); // [nU] + auto dLdrShapeInfo = inputShape->at(6); // [bS, nU] + auto dLduShapeInfo = inputShape->at(7); // [bS, nU] + auto dLdcShapeInfo = inputShape->at(8); // [bS, nU] + auto dLdhShapeInfo = inputShape->at(9); // [bS, nU] + + const int rank = xShapeInfo[0]; // = 2 + const Nd4jLong bS = xShapeInfo[1]; + const Nd4jLong iS = xShapeInfo[2]; + const Nd4jLong nU = hiShapeInfo[2]; + + REQUIRE_TRUE( + xShapeInfo[0] == 2, 0, + "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", + xShapeInfo[0]); + + const std::vector hiCorrectShape = {bS, nU}; + const std::vector wCorrectShape = {iS + nU, 2 * nU}; + const std::vector wcCorrectShape = {iS + nU, nU}; + const std::vector bCorrectShape = {2 * nU}; + const std::vector bcCorrectShape = {nU}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(hiShapeInfo, hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of previous cell output array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(hiShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of weights array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(wShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(wcShapeInfo, wcCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of c weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(wcCorrectShape).c_str(), + ShapeUtils::shapeAsString(wcShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bcShapeInfo, bcCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bcCorrectShape).c_str(), + ShapeUtils::shapeAsString(bcShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdrShapeInfo, hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset " + "gate), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdrShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLduShapeInfo, hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update " + "gate), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLduShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdcShapeInfo, hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell " + "state), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdcShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdhShapeInfo, hiCorrectShape), 0, + "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt " + "current cell output), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(hiCorrectShape).c_str(), + ShapeUtils::shapeAsString(dLdhShapeInfo).c_str()); + + Nd4jLong *dLdxShapeInfo = nullptr; + COPY_SHAPE(xShapeInfo, dLdxShapeInfo); + + Nd4jLong *dLdhiShapeInfo = nullptr; + COPY_SHAPE(hiShapeInfo, dLdhiShapeInfo); + + Nd4jLong *dLdWShapeInfo = nullptr; + COPY_SHAPE(wShapeInfo, dLdWShapeInfo); + + Nd4jLong *dLdWcShapeInfo = nullptr; + COPY_SHAPE(wcShapeInfo, dLdWcShapeInfo); + + Nd4jLong *dLdbShapeInfo = nullptr; + COPY_SHAPE(bShapeInfo, dLdbShapeInfo); + + Nd4jLong *dLdbcShapeInfo = nullptr; + COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo); + + return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdhiShapeInfo), + CONSTANT(dLdWShapeInfo), CONSTANT(dLdWcShapeInfo), + CONSTANT(dLdbShapeInfo), CONSTANT(dLdbcShapeInfo)); } - - - -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstm.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstm.cpp index fc63449585f1..9733af262bcd 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstm.cpp @@ -22,131 +22,210 @@ #if NOT_EXCLUDED(OP_lstm) #include -#include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstm, 8, 2, false, 3, 2) { - auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] - auto h0 = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS x numProj], in case of projection=false -> numProj == numUnits !!! - auto c0 = INPUT_VARIABLE(2); // initial cell state (at time step = 0) [bS x numUnits], - - auto Wx = INPUT_VARIABLE(3); // input-to-hidden weights, [inSize x 4*numUnits] - auto Wh = INPUT_VARIABLE(4); // hidden-to-hidden weights, [numProj x 4*numUnits] - auto Wc = INPUT_VARIABLE(5); // diagonal weights for peephole connections [3*numUnits] - auto Wp = INPUT_VARIABLE(6); // projection weights [numUnits x numProj] - auto b = INPUT_VARIABLE(7); // biases, [4*numUnits] - - auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x numProj], that is per each time step - auto c = OUTPUT_VARIABLE(1); // cell states [time x bS x numUnits] that is per each time step - - const int peephole = INT_ARG(0); // if 1, provide peephole connections - const int projection = INT_ARG(1); // if 1, then projection is performed, if false then numProj==numUnits is mandatory!!!! - - // FIXME: double - const double clippingCellValue = T_ARG(0); // clipping value for ct, if it is not equal to zero, then cell state is clipped - const double clippingProjValue = T_ARG(1); // clipping value for projected ht, if it is not equal to zero, then projected cell output is clipped - const double forgetBias = T_ARG(2); - - const int rank = x->rankOf(); - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int inSize = x->sizeAt(2); - const int numProj = h0->sizeAt(1); - const int numUnits = c0->sizeAt(1); - - // input shapes validation - const std::vector correctH0Shape = {bS, numProj}; - const std::vector correctC0Shape = {bS, numUnits}; - const std::vector correctWxShape = {inSize, 4*numUnits}; - const std::vector correctWhShape = {numProj, 4*numUnits}; - const std::vector correctWcShape = {3*numUnits}; - const std::vector correctWpShape = {numUnits, numProj}; - const std::vector correctBShape = {4*numUnits}; - - REQUIRE_TRUE(h0->isSameShape(correctH0Shape), 0, "LSTM operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctH0Shape).c_str(), ShapeUtils::shapeAsString(h0).c_str()); - REQUIRE_TRUE(c0->isSameShape(correctC0Shape), 0, "LSTM operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctC0Shape).c_str(), ShapeUtils::shapeAsString(c0).c_str()); - REQUIRE_TRUE(Wx->isSameShape(correctWxShape), 0, "LSTM operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - REQUIRE_TRUE(Wh->isSameShape(correctWhShape), 0, "LSTM operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(Wc->isSameShape(correctWcShape), 0, "LSTM operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(Wc).c_str()); - REQUIRE_TRUE(Wp->isSameShape(correctWpShape), 0, "LSTM operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); - REQUIRE_TRUE(b->isSameShape(correctBShape), 0, "LSTM operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(!(!projection && numUnits != numProj), 0, "LSTM operation: projection option is switched of, and in this case output dimensionality for the projection matrices (numProj) must be equal to number of units in lstmCell !"); - - helpers::lstmTimeLoop(block.launchContext(), x, h0, c0, Wx, Wh, Wc, Wp, b, h, c, {(double)peephole, (double)projection, clippingCellValue, clippingProjValue, forgetBias}); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] + auto h0 = INPUT_VARIABLE( + 1); // initial cell output (at time step = 0) [bS x numProj], in case of + // projection=false -> numProj == numUnits !!! + auto c0 = INPUT_VARIABLE( + 2); // initial cell state (at time step = 0) [bS x numUnits], + + auto Wx = + INPUT_VARIABLE(3); // input-to-hidden weights, [inSize x 4*numUnits] + auto Wh = + INPUT_VARIABLE(4); // hidden-to-hidden weights, [numProj x 4*numUnits] + auto Wc = INPUT_VARIABLE( + 5); // diagonal weights for peephole connections [3*numUnits] + auto Wp = INPUT_VARIABLE(6); // projection weights [numUnits x numProj] + auto b = INPUT_VARIABLE(7); // biases, [4*numUnits] + + auto h = OUTPUT_VARIABLE( + 0); // cell outputs [time x bS x numProj], that is per each time step + auto c = OUTPUT_VARIABLE( + 1); // cell states [time x bS x numUnits] that is per each time step + + const int peephole = INT_ARG(0); // if 1, provide peephole connections + const int projection = + INT_ARG(1); // if 1, then projection is performed, if false then + // numProj==numUnits is mandatory!!!! + + // FIXME: double + const double clippingCellValue = + T_ARG(0); // clipping value for ct, if it is not equal to zero, then cell + // state is clipped + const double clippingProjValue = + T_ARG(1); // clipping value for projected ht, if it is not equal to zero, + // then projected cell output is clipped + const double forgetBias = T_ARG(2); + + const int rank = x->rankOf(); + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int inSize = x->sizeAt(2); + const int numProj = h0->sizeAt(1); + const int numUnits = c0->sizeAt(1); + + // input shapes validation + const std::vector correctH0Shape = {bS, numProj}; + const std::vector correctC0Shape = {bS, numUnits}; + const std::vector correctWxShape = {inSize, 4 * numUnits}; + const std::vector correctWhShape = {numProj, 4 * numUnits}; + const std::vector correctWcShape = {3 * numUnits}; + const std::vector correctWpShape = {numUnits, numProj}; + const std::vector correctBShape = {4 * numUnits}; + + REQUIRE_TRUE(h0->isSameShape(correctH0Shape), 0, + "LSTM operation: wrong shape of initial cell output, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctH0Shape).c_str(), + ShapeUtils::shapeAsString(h0).c_str()); + REQUIRE_TRUE(c0->isSameShape(correctC0Shape), 0, + "LSTM operation: wrong shape of initial cell state, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctC0Shape).c_str(), + ShapeUtils::shapeAsString(c0).c_str()); + REQUIRE_TRUE(Wx->isSameShape(correctWxShape), 0, + "LSTM operation: wrong shape of input-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWxShape).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(correctWhShape), 0, + "LSTM operation: wrong shape of hidden-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWhShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(Wc->isSameShape(correctWcShape), 0, + "LSTM operation: wrong shape of diagonal weights for peephole " + "connections, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWcShape).c_str(), + ShapeUtils::shapeAsString(Wc).c_str()); + REQUIRE_TRUE(Wp->isSameShape(correctWpShape), 0, + "LSTM operation: wrong shape of projection weights, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(correctWpShape).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + REQUIRE_TRUE(b->isSameShape(correctBShape), 0, + "LSTM operation: wrong shape of biases, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(correctBShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(!(!projection && numUnits != numProj), 0, + "LSTM operation: projection option is switched of, and in this " + "case output dimensionality for the projection matrices " + "(numProj) must be equal to number of units in lstmCell !"); + + helpers::lstmTimeLoop(block.launchContext(), x, h0, c0, Wx, Wh, Wc, Wp, b, h, + c, + {(double)peephole, (double)projection, + clippingCellValue, clippingProjValue, forgetBias}); + + return Status::OK(); } - DECLARE_TYPES(lstm) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - +DECLARE_TYPES(lstm) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(lstm) { - - auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] - auto h0ShapeInfo = inputShape->at(1); // initial cell output (at time step = 0) [bS x numProj], in case of projection=false -> numProj == numUnits !!! - auto c0ShapeInfo = inputShape->at(2); // initial cell state (at time step = 0) [bS x numUnits], - - auto WxShapeInfo = inputShape->at(3); // input-to-hidden weights, [inSize x 4*numUnits] - auto WhShapeInfo = inputShape->at(4); // hidden-to-hidden weights, [numProj x 4*numUnits] - auto WcShapeInfo = inputShape->at(5); // diagonal weights for peephole connections [3*numUnits] - auto WpShapeInfo = inputShape->at(6); // projection weights [numUnits x numProj] - auto bShapeInfo = inputShape->at(7); // biases, [4*numUnits] - - const int rank = xShapeInfo[0]; - const int time = xShapeInfo[1]; - const int bS = xShapeInfo[2]; - const int inSize = xShapeInfo[3]; - const int numProj = h0ShapeInfo[2]; - const int numUnits = c0ShapeInfo[2]; - - // input shapes validation - const std::vector correctH0Shape = {bS, numProj}; - const std::vector correctC0Shape = {bS, numUnits}; - const std::vector correctWxShape = {inSize, 4*numUnits}; - const std::vector correctWhShape = {numProj, 4*numUnits}; - const std::vector correctWcShape = {3*numUnits}; - const std::vector correctWpShape = {numUnits, numProj}; - const std::vector correctBShape = {4*numUnits}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, correctH0Shape), 0, "LSTM operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctH0Shape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, correctC0Shape), 0, "LSTM operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctC0Shape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, correctWxShape), 0, "LSTM operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, correctWhShape), 0, "LSTM operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WcShapeInfo, correctWcShape), 0, "LSTM operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(WcShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WpShapeInfo, correctWpShape), 0, "LSTM operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(WpShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, correctBShape), 0, "LSTM operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - - - // evaluate output shapeInfos - Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); // [time x bS x numProj] - ALLOCATE(cShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); // [time x bS x numUnits] - - hShapeInfo[0] = cShapeInfo[0] = rank; - hShapeInfo[1] = cShapeInfo[1] = time; - hShapeInfo[2] = cShapeInfo[2] = bS; - hShapeInfo[3] = numProj; - cShapeInfo[3] = numUnits; - - ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(h0ShapeInfo)); - ShapeUtils::updateStridesAndType(cShapeInfo, xShapeInfo, shape::order(c0ShapeInfo)); - - return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(cShapeInfo)); + auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] + auto h0ShapeInfo = inputShape->at( + 1); // initial cell output (at time step = 0) [bS x numProj], in case of + // projection=false -> numProj == numUnits !!! + auto c0ShapeInfo = inputShape->at( + 2); // initial cell state (at time step = 0) [bS x numUnits], + + auto WxShapeInfo = + inputShape->at(3); // input-to-hidden weights, [inSize x 4*numUnits] + auto WhShapeInfo = + inputShape->at(4); // hidden-to-hidden weights, [numProj x 4*numUnits] + auto WcShapeInfo = inputShape->at( + 5); // diagonal weights for peephole connections [3*numUnits] + auto WpShapeInfo = + inputShape->at(6); // projection weights [numUnits x numProj] + auto bShapeInfo = inputShape->at(7); // biases, [4*numUnits] + + const int rank = xShapeInfo[0]; + const int time = xShapeInfo[1]; + const int bS = xShapeInfo[2]; + const int inSize = xShapeInfo[3]; + const int numProj = h0ShapeInfo[2]; + const int numUnits = c0ShapeInfo[2]; + + // input shapes validation + const std::vector correctH0Shape = {bS, numProj}; + const std::vector correctC0Shape = {bS, numUnits}; + const std::vector correctWxShape = {inSize, 4 * numUnits}; + const std::vector correctWhShape = {numProj, 4 * numUnits}; + const std::vector correctWcShape = {3 * numUnits}; + const std::vector correctWpShape = {numUnits, numProj}; + const std::vector correctBShape = {4 * numUnits}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, correctH0Shape), 0, + "LSTM operation: wrong shape of initial cell output, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctH0Shape).c_str(), + ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, correctC0Shape), 0, + "LSTM operation: wrong shape of initial cell state, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctC0Shape).c_str(), + ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, correctWxShape), 0, + "LSTM operation: wrong shape of input-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWxShape).c_str(), + ShapeUtils::shapeAsString(WxShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, correctWhShape), 0, + "LSTM operation: wrong shape of hidden-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWhShape).c_str(), + ShapeUtils::shapeAsString(WhShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WcShapeInfo, correctWcShape), 0, + "LSTM operation: wrong shape of diagonal weights for peephole " + "connections, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWcShape).c_str(), + ShapeUtils::shapeAsString(WcShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WpShapeInfo, correctWpShape), 0, + "LSTM operation: wrong shape of projection weights, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(correctWpShape).c_str(), + ShapeUtils::shapeAsString(WpShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, correctBShape), 0, + "LSTM operation: wrong shape of biases, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(correctBShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + + // evaluate output shapeInfos + Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [time x bS x numProj] + ALLOCATE(cShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [time x bS x numUnits] + + hShapeInfo[0] = cShapeInfo[0] = rank; + hShapeInfo[1] = cShapeInfo[1] = time; + hShapeInfo[2] = cShapeInfo[2] = bS; + hShapeInfo[3] = numProj; + cShapeInfo[3] = numUnits; + + ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, + shape::order(h0ShapeInfo)); + ShapeUtils::updateStridesAndType(cShapeInfo, xShapeInfo, + shape::order(c0ShapeInfo)); + + return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(cShapeInfo)); } - - - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlock.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlock.cpp index 3b5ae8345486..56544d8f3144 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlock.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlock.cpp @@ -22,104 +22,135 @@ #if NOT_EXCLUDED(OP_lstmBlock) #include -#include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmBlock, 9, 7, false, 2, 2) { - auto maxTSLength = INPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(1); // input [seqLen, bS, nIn] at time t - auto cLast = INPUT_VARIABLE(2); // previous cell state [bS, nOut], time t-1 - auto yLast = INPUT_VARIABLE(3); // previous output [bS, nOut], time t-1 - - auto W = INPUT_VARIABLE(4); // Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(nIn+nOut), 4*nOut] - auto Wci = INPUT_VARIABLE(5); // weights - cell peephole (t-1) connections to input modulation gate, [nOut] - auto Wcf = INPUT_VARIABLE(6); // weights - cell peephole (t-1) connections to forget gate, [nOut] - auto Wco = INPUT_VARIABLE(7); // weights - cell peephole (t) connections to output gate, [nOut] - auto b = INPUT_VARIABLE(8); // biases, [4*nOut] - - auto i = OUTPUT_VARIABLE(0); // Output - input modulation gate activations [seqLen, bS, nOut] - auto c = OUTPUT_VARIABLE(1); // Activations, cell state (pre tanh) [seqLen, bs, nOut] - auto f = OUTPUT_VARIABLE(2); // Output - forget gate activations [seqLen, bs, nOut] - auto o = OUTPUT_VARIABLE(3); // Output - output gate activations [seqLen, bs, nOut] - auto z = OUTPUT_VARIABLE(4); // Output - input gate activations [seqLen, bs, nOut] - auto h = OUTPUT_VARIABLE(5); // Cell state, post tanh [seqLen, bs, nOut] - auto y = OUTPUT_VARIABLE(6); // current cell output [seqLen, bS, numProj], time t - - const int peephole = INT_ARG(0); // if 1, provide peephole connections - const int dataFormat = INT_ARG(1); // 0=TNS=[seqLen,bS,nIn]; 1=NST=[bS,nIn,seqLen]; 2=NTS=[bS,seqLen,nIn] - const double forgetBias = T_ARG(0); - const double clippingCellValue = T_ARG(1); // clipping value for ct, if it is not equal to zero, then cell state is clipped - - REQUIRE_TRUE(x->rankOf()==3, 0, "lstmBlock: Input array 1 (x) rank must be got input with rank %i", x->rankOf()); - REQUIRE_TRUE(cLast->rankOf()==2 && yLast->rankOf()==2, 0, "lstmBlock: Input ranks must be 2 for inputs 2/3 (cLast, yLast) - got %i, %i", cLast->rankOf(), yLast->rankOf()); - REQUIRE_TRUE(W->rankOf()==2, 0, "lstmBlock: Weights array rank must be 2"); - REQUIRE_TRUE(b->rankOf()==1, 0, "lstmBlock: Biases must be rank 1"); - REQUIRE_TRUE(i->rankOf()==3 && c->rankOf()==3 && f->rankOf()==3 && o->rankOf()==3 && z->rankOf()==3 && h->rankOf()==3 && y->rankOf()==3, - 0, "lstmBlock: Output arrays must all be rank 3"); - - helpers::lstmBlockTimeLoop(maxTSLength, x, cLast, yLast, W, Wci, Wcf, Wco, b, i, c, f, o, z, h, y, {(double)peephole, forgetBias, clippingCellValue}, dataFormat); - - return Status::OK(); + auto maxTSLength = INPUT_VARIABLE(0); + auto x = INPUT_VARIABLE(1); // input [seqLen, bS, nIn] at time t + auto cLast = INPUT_VARIABLE(2); // previous cell state [bS, nOut], time t-1 + auto yLast = INPUT_VARIABLE(3); // previous output [bS, nOut], time t-1 + + auto W = INPUT_VARIABLE( + 4); // Weights - concatenated (input-to-hidden, hidden-to-hidden weights) + // weights, [(nIn+nOut), 4*nOut] + auto Wci = INPUT_VARIABLE(5); // weights - cell peephole (t-1) connections to + // input modulation gate, [nOut] + auto Wcf = INPUT_VARIABLE( + 6); // weights - cell peephole (t-1) connections to forget gate, [nOut] + auto Wco = INPUT_VARIABLE( + 7); // weights - cell peephole (t) connections to output gate, [nOut] + auto b = INPUT_VARIABLE(8); // biases, [4*nOut] + + auto i = OUTPUT_VARIABLE( + 0); // Output - input modulation gate activations [seqLen, bS, nOut] + auto c = OUTPUT_VARIABLE( + 1); // Activations, cell state (pre tanh) [seqLen, bs, nOut] + auto f = OUTPUT_VARIABLE( + 2); // Output - forget gate activations [seqLen, bs, nOut] + auto o = OUTPUT_VARIABLE( + 3); // Output - output gate activations [seqLen, bs, nOut] + auto z = + OUTPUT_VARIABLE(4); // Output - input gate activations [seqLen, bs, nOut] + auto h = OUTPUT_VARIABLE(5); // Cell state, post tanh [seqLen, bs, nOut] + auto y = + OUTPUT_VARIABLE(6); // current cell output [seqLen, bS, numProj], time t + + const int peephole = INT_ARG(0); // if 1, provide peephole connections + const int dataFormat = + INT_ARG(1); // 0=TNS=[seqLen,bS,nIn]; 1=NST=[bS,nIn,seqLen]; + // 2=NTS=[bS,seqLen,nIn] + const double forgetBias = T_ARG(0); + const double clippingCellValue = + T_ARG(1); // clipping value for ct, if it is not equal to zero, then cell + // state is clipped + + REQUIRE_TRUE( + x->rankOf() == 3, 0, + "lstmBlock: Input array 1 (x) rank must be got input with rank %i", + x->rankOf()); + REQUIRE_TRUE(cLast->rankOf() == 2 && yLast->rankOf() == 2, 0, + "lstmBlock: Input ranks must be 2 for inputs 2/3 (cLast, yLast) " + "- got %i, %i", + cLast->rankOf(), yLast->rankOf()); + REQUIRE_TRUE(W->rankOf() == 2, 0, "lstmBlock: Weights array rank must be 2"); + REQUIRE_TRUE(b->rankOf() == 1, 0, "lstmBlock: Biases must be rank 1"); + REQUIRE_TRUE(i->rankOf() == 3 && c->rankOf() == 3 && f->rankOf() == 3 && + o->rankOf() == 3 && z->rankOf() == 3 && h->rankOf() == 3 && + y->rankOf() == 3, + 0, "lstmBlock: Output arrays must all be rank 3"); + + helpers::lstmBlockTimeLoop( + maxTSLength, x, cLast, yLast, W, Wci, Wcf, Wco, b, i, c, f, o, z, h, y, + {(double)peephole, forgetBias, clippingCellValue}, dataFormat); + + return Status::OK(); } DECLARE_TYPES(lstmBlock) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(lstmBlock) { - auto x = inputShape->at(1); - auto cLast = inputShape->at(2); - auto yLast = inputShape->at(3); - auto W = inputShape->at(4); - auto b = inputShape->at(8); - - REQUIRE_TRUE(shape::rank(x)==3, 0, "lstmBlock: Input array 1 (x) rank must be got input with rank %i", shape::rank(x)); - REQUIRE_TRUE(shape::rank(cLast)==2 && shape::rank(yLast)==2, 0, "lstmBlock: Input ranks must be 2 for inputs 2/3 (cLast, yLast) - got %i, %i", shape::rank(cLast), shape::rank(yLast)); - REQUIRE_TRUE(shape::rank(W)==2, 0, "lstmBlock: Weights array rank must be 2"); - REQUIRE_TRUE(shape::rank(b)==1, 0, "lstmBlock: Biases must be rank 1"); - - - const int dataFormat = INT_ARG(1); // 0=TNS=[seqLen,bS,size]; 1=NST=[bS,size,seqLen]; 2=NTS=[bS,seqLen,size] - int bs; - int t; - int nOut = cLast[2]; //rank, bs, nOut, ...] - - Nd4jLong *s(nullptr); - ALLOCATE(s, block.workspace(), shape::shapeInfoLength(3), Nd4jLong); // [time, bS, nOut] - s[0] = 3; - if(dataFormat == 0){ - //[rank, seqLen, bs, nIn, ...] - s[1] = x[1]; //seqLen - s[2] = x[2]; //bS - s[3] = nOut; - } else if(dataFormat==1){ - //[rank, bs, nIn, seqLen, ...] - s[1] = x[1]; //bS - s[2] = nOut; - s[3] = x[3]; //seqLen - } else { - //[rank, bs, seqLen, nIn, ...] - s[1] = x[1]; //bS - s[2] = x[2]; //seqLen - s[3] = nOut; - - } - ShapeUtils::updateStridesAndType(s, x, 'c'); - - auto s1 = CONSTANT(s); - - //7 outputs, all same shape/type - return SHAPELIST(s1, s1, s1, s1, s1, s1, s1); + auto x = inputShape->at(1); + auto cLast = inputShape->at(2); + auto yLast = inputShape->at(3); + auto W = inputShape->at(4); + auto b = inputShape->at(8); + + REQUIRE_TRUE( + shape::rank(x) == 3, 0, + "lstmBlock: Input array 1 (x) rank must be got input with rank %i", + shape::rank(x)); + REQUIRE_TRUE(shape::rank(cLast) == 2 && shape::rank(yLast) == 2, 0, + "lstmBlock: Input ranks must be 2 for inputs 2/3 (cLast, yLast) " + "- got %i, %i", + shape::rank(cLast), shape::rank(yLast)); + REQUIRE_TRUE(shape::rank(W) == 2, 0, + "lstmBlock: Weights array rank must be 2"); + REQUIRE_TRUE(shape::rank(b) == 1, 0, "lstmBlock: Biases must be rank 1"); + + const int dataFormat = + INT_ARG(1); // 0=TNS=[seqLen,bS,size]; 1=NST=[bS,size,seqLen]; + // 2=NTS=[bS,seqLen,size] + int bs; + int t; + int nOut = cLast[2]; // rank, bs, nOut, ...] + + Nd4jLong *s(nullptr); + ALLOCATE(s, block.workspace(), shape::shapeInfoLength(3), + Nd4jLong); // [time, bS, nOut] + s[0] = 3; + if (dataFormat == 0) { + //[rank, seqLen, bs, nIn, ...] + s[1] = x[1]; // seqLen + s[2] = x[2]; // bS + s[3] = nOut; + } else if (dataFormat == 1) { + //[rank, bs, nIn, seqLen, ...] + s[1] = x[1]; // bS + s[2] = nOut; + s[3] = x[3]; // seqLen + } else { + //[rank, bs, seqLen, nIn, ...] + s[1] = x[1]; // bS + s[2] = x[2]; // seqLen + s[3] = nOut; + } + ShapeUtils::updateStridesAndType(s, x, 'c'); + + auto s1 = CONSTANT(s); + + // 7 outputs, all same shape/type + return SHAPELIST(s1, s1, s1, s1, s1, s1, s1); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlockCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlockCell.cpp index d7826abc15c6..04f9f133d2c9 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlockCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmBlockCell.cpp @@ -22,106 +22,165 @@ #if NOT_EXCLUDED(OP_lstmBlockCell) #include -#include +#include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmBlockCell, 8, 7, false, 2, 1) { - //Notation: mostly following https://arxiv.org/pdf/1503.04069.pdf - auto xt = INPUT_VARIABLE(0); // input [bS, inSize] at time t - auto cLast = INPUT_VARIABLE(1); // previous cell state [bS, numUnits], time t-1 - auto yLast = INPUT_VARIABLE(2); // previous output [bS, numUnits], time t-1 - - auto W = INPUT_VARIABLE(3); // Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits] - auto Wci = INPUT_VARIABLE(4); // weights - cell peephole (t-1) connections to input modulation gate, [numUnits] - auto Wcf = INPUT_VARIABLE(5); // weights - cell peephole (t-1) connections to forget gate, [numUnits] - auto Wco = INPUT_VARIABLE(6); // weights - cell peephole (t) connections to output gate, [numUnits] - auto b = INPUT_VARIABLE(7); // biases, [4*numUnits] - - auto i = OUTPUT_VARIABLE(0); // Output - input modulation gate activations [bS, numUnits] - auto c = OUTPUT_VARIABLE(1); // Activations, cell state (pre tanh) [bs, numUnits] - auto f = OUTPUT_VARIABLE(2); // Output - forget gate activations [bs, numUnits] - auto o = OUTPUT_VARIABLE(3); // Output - output gate activations [bs, numUnits] - auto z = OUTPUT_VARIABLE(4); // Output - input gate activations [bs, numUnits] - auto h = OUTPUT_VARIABLE(5); // Cell state, post tanh [bs, numUnits] - auto y = OUTPUT_VARIABLE(6); // current cell output [bS, numProj], time t - - - const int peephole = INT_ARG(0); // if 1, provide peephole connections - - const double forgetBias = T_ARG(0); - const double clippingCellValue = T_ARG(1); // clipping value for ct, if it is not equal to zero, then cell state is clipped - - REQUIRE_TRUE(xt->rankOf()==2 && cLast->rankOf()==2 && yLast->rankOf()==2, 0, "lstmBlockCell: Input ranks must be 2 for inputs 0/1/2 (x, cLast, outLast) - got %i, %i, %i", xt->rankOf(), cLast->rankOf(), yLast->rankOf()); - const int rank = xt->rankOf(); - const int bS = xt->sizeAt(0); - const int inSize = xt->sizeAt(1); - const int numUnits = cLast->sizeAt(1); - - REQUIRE_TRUE(xt->sizeAt(0) == yLast->sizeAt(0) && xt->sizeAt(0) == cLast->sizeAt(0), 0, "lstmBlockCell: Input minibatch sizes (dimension 0) must be same for xt, cLast, yLast"); - REQUIRE_TRUE(W->rankOf()==2, 0, "lstmBlockCell: Weights array rank must be 2"); - REQUIRE_TRUE(W->sizeAt(0)==(inSize+numUnits), 0, "lstmBlockCell: Weights size(0) must be equal to inSize + numUnits, got %i", W->sizeAt(0)); - REQUIRE_TRUE(W->sizeAt(1)==(4*numUnits), 0, "lstmBlockCell: Weights size(1) must be equal to 4*numUnits, got %i", W->sizeAt(1)); - REQUIRE_TRUE(b->rankOf()==1 && b->sizeAt(0)==(4*numUnits), 0, "lstmBlockCell: Biases must be rank 1, size 4*numUnits"); - REQUIRE_TRUE(i->rankOf()==2 && c->rankOf()==2 && f->rankOf()==2 && o->rankOf()==2 && z->rankOf()==2 && h->rankOf()==2 && y->rankOf()==2 && - i->sizeAt(0)==bS && c->sizeAt(0)==bS && f->sizeAt(0)==bS && o->sizeAt(0)==bS && z->sizeAt(0)==bS && h->sizeAt(0)==bS && y->sizeAt(0)==bS && - i->sizeAt(1)==numUnits && c->sizeAt(1)==numUnits && f->sizeAt(1)==numUnits && o->sizeAt(1)==numUnits && z->sizeAt(1)==numUnits && h->sizeAt(1)==numUnits && y->sizeAt(1)==numUnits, - 0, "lstmBlockCell: Output arrays must all be rank 2 with size(0) == batchSize and size(1) == numUnits"); - - // calculations - helpers::lstmBlockCell(xt, cLast, yLast, W, Wci, Wcf, Wco, b, i, c, f, o, z, h, y, {(double)peephole, forgetBias, clippingCellValue}); - - return Status::OK(); + // Notation: mostly following https://arxiv.org/pdf/1503.04069.pdf + auto xt = INPUT_VARIABLE(0); // input [bS, inSize] at time t + auto cLast = + INPUT_VARIABLE(1); // previous cell state [bS, numUnits], time t-1 + auto yLast = INPUT_VARIABLE(2); // previous output [bS, numUnits], time t-1 + + auto W = INPUT_VARIABLE( + 3); // Weights - concatenated (input-to-hidden, hidden-to-hidden weights) + // weights, [(inSize+numUnits), 4*numUnits] + auto Wci = INPUT_VARIABLE(4); // weights - cell peephole (t-1) connections to + // input modulation gate, [numUnits] + auto Wcf = INPUT_VARIABLE(5); // weights - cell peephole (t-1) connections to + // forget gate, [numUnits] + auto Wco = INPUT_VARIABLE( + 6); // weights - cell peephole (t) connections to output gate, [numUnits] + auto b = INPUT_VARIABLE(7); // biases, [4*numUnits] + + auto i = OUTPUT_VARIABLE( + 0); // Output - input modulation gate activations [bS, numUnits] + auto c = + OUTPUT_VARIABLE(1); // Activations, cell state (pre tanh) [bs, numUnits] + auto f = + OUTPUT_VARIABLE(2); // Output - forget gate activations [bs, numUnits] + auto o = + OUTPUT_VARIABLE(3); // Output - output gate activations [bs, numUnits] + auto z = + OUTPUT_VARIABLE(4); // Output - input gate activations [bs, numUnits] + auto h = OUTPUT_VARIABLE(5); // Cell state, post tanh [bs, numUnits] + auto y = OUTPUT_VARIABLE(6); // current cell output [bS, numProj], time t + + const int peephole = INT_ARG(0); // if 1, provide peephole connections + + const double forgetBias = T_ARG(0); + const double clippingCellValue = + T_ARG(1); // clipping value for ct, if it is not equal to zero, then cell + // state is clipped + + REQUIRE_TRUE( + xt->rankOf() == 2 && cLast->rankOf() == 2 && yLast->rankOf() == 2, 0, + "lstmBlockCell: Input ranks must be 2 for inputs 0/1/2 (x, cLast, " + "outLast) - got %i, %i, %i", + xt->rankOf(), cLast->rankOf(), yLast->rankOf()); + const int rank = xt->rankOf(); + const int bS = xt->sizeAt(0); + const int inSize = xt->sizeAt(1); + const int numUnits = cLast->sizeAt(1); + + REQUIRE_TRUE( + xt->sizeAt(0) == yLast->sizeAt(0) && xt->sizeAt(0) == cLast->sizeAt(0), 0, + "lstmBlockCell: Input minibatch sizes (dimension 0) must be same for xt, " + "cLast, yLast"); + REQUIRE_TRUE(W->rankOf() == 2, 0, + "lstmBlockCell: Weights array rank must be 2"); + REQUIRE_TRUE(W->sizeAt(0) == (inSize + numUnits), 0, + "lstmBlockCell: Weights size(0) must be equal to inSize + " + "numUnits, got %i", + W->sizeAt(0)); + REQUIRE_TRUE( + W->sizeAt(1) == (4 * numUnits), 0, + "lstmBlockCell: Weights size(1) must be equal to 4*numUnits, got %i", + W->sizeAt(1)); + REQUIRE_TRUE(b->rankOf() == 1 && b->sizeAt(0) == (4 * numUnits), 0, + "lstmBlockCell: Biases must be rank 1, size 4*numUnits"); + REQUIRE_TRUE(i->rankOf() == 2 && c->rankOf() == 2 && f->rankOf() == 2 && + o->rankOf() == 2 && z->rankOf() == 2 && h->rankOf() == 2 && + y->rankOf() == 2 && i->sizeAt(0) == bS && + c->sizeAt(0) == bS && f->sizeAt(0) == bS && + o->sizeAt(0) == bS && z->sizeAt(0) == bS && + h->sizeAt(0) == bS && y->sizeAt(0) == bS && + i->sizeAt(1) == numUnits && c->sizeAt(1) == numUnits && + f->sizeAt(1) == numUnits && o->sizeAt(1) == numUnits && + z->sizeAt(1) == numUnits && h->sizeAt(1) == numUnits && + y->sizeAt(1) == numUnits, + 0, + "lstmBlockCell: Output arrays must all be rank 2 with size(0) " + "== batchSize and size(1) == numUnits"); + + // calculations + helpers::lstmBlockCell(xt, cLast, yLast, W, Wci, Wcf, Wco, b, i, c, f, o, z, + h, y, + {(double)peephole, forgetBias, clippingCellValue}); + + return Status::OK(); } DECLARE_TYPES(lstmBlockCell) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(lstmBlockCell) { - auto xt = inputShape->at(0); // input [bS, inSize] at time t - auto cLast = inputShape->at(1); // previous cell state [bS, numUnits], time t-1 - auto yLast = inputShape->at(2); // previous output [bS, numUnits], time t-1 - - auto W = inputShape->at(3); // Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits] - auto Wci = inputShape->at(4); // weights - cell peephole (t-1) connections to input modulation gate, [numUnits] - auto Wcf = inputShape->at(5); // weights - cell peephole (t-1) connections to forget gate, [numUnits] - auto Wco = inputShape->at(6); // weights - cell peephole (t) connections to output gate, [numUnits] - auto b = inputShape->at(7); // biases, [4*numUnits] - - REQUIRE_TRUE(shape::rank(xt)==2 && shape::rank(cLast)==2 && shape::rank(yLast)==2, 0, "lstmBlockCell: Input ranks must be 2 for inputs 0/1/2 (x, cLast, outLast) - got %i, %i, %i", shape::rank(xt), shape::rank(cLast), shape::rank(yLast)); - const int inSize = xt[2]; - const int numUnits = cLast[2]; //[rank, bS, nOut, ...] - REQUIRE_TRUE(xt[1] == yLast[1] && xt[1] == cLast[1], 0, "lstmBlockCell: Input minibatch sizes (dimension 0) must be same for xt, cLast, yLast"); - REQUIRE_TRUE(shape::rank(W)==2, 0, "lstmBlockCell: Weights array rank must be rank 2, got %i", shape::rank(W)); - REQUIRE_TRUE(W[1]==(inSize+numUnits), 0, "lstmBlockCell: Weights size(0) must be equal to inSize + numUnits, got %i", W[1]); - REQUIRE_TRUE(W[2]==(4*numUnits), 0, "lstmBlockCell: Weights size(1) must be equal to 4*numUnits, got %i", W[2]); - REQUIRE_TRUE(shape::rank(b)==1 && b[1]==(4*numUnits), 0, "lstmBlockCell: Biases must be rank 1, size 4*numUnits"); - - // evaluate output shapeInfos - const int bS = xt[1]; - Nd4jLong *s(nullptr); - ALLOCATE(s, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); // [bS, numUnits] - - s[0] = 2; - s[1] = bS; - s[2] = numUnits; - - ShapeUtils::updateStridesAndType(s, xt, 'c'); - - auto s1 = CONSTANT(s); - - //7 outputs, all same shape: z, i, f, o, h, c, y - return SHAPELIST(s1, s1, s1, s1, s1, s1, s1); + auto xt = inputShape->at(0); // input [bS, inSize] at time t + auto cLast = + inputShape->at(1); // previous cell state [bS, numUnits], time t-1 + auto yLast = inputShape->at(2); // previous output [bS, numUnits], time t-1 + + auto W = inputShape->at( + 3); // Weights - concatenated (input-to-hidden, hidden-to-hidden weights) + // weights, [(inSize+numUnits), 4*numUnits] + auto Wci = inputShape->at(4); // weights - cell peephole (t-1) connections to + // input modulation gate, [numUnits] + auto Wcf = inputShape->at(5); // weights - cell peephole (t-1) connections to + // forget gate, [numUnits] + auto Wco = inputShape->at( + 6); // weights - cell peephole (t) connections to output gate, [numUnits] + auto b = inputShape->at(7); // biases, [4*numUnits] + + REQUIRE_TRUE(shape::rank(xt) == 2 && shape::rank(cLast) == 2 && + shape::rank(yLast) == 2, + 0, + "lstmBlockCell: Input ranks must be 2 for inputs 0/1/2 (x, " + "cLast, outLast) - got %i, %i, %i", + shape::rank(xt), shape::rank(cLast), shape::rank(yLast)); + const int inSize = xt[2]; + const int numUnits = cLast[2]; //[rank, bS, nOut, ...] + REQUIRE_TRUE(xt[1] == yLast[1] && xt[1] == cLast[1], 0, + "lstmBlockCell: Input minibatch sizes (dimension 0) must be " + "same for xt, cLast, yLast"); + REQUIRE_TRUE(shape::rank(W) == 2, 0, + "lstmBlockCell: Weights array rank must be rank 2, got %i", + shape::rank(W)); + REQUIRE_TRUE(W[1] == (inSize + numUnits), 0, + "lstmBlockCell: Weights size(0) must be equal to inSize + " + "numUnits, got %i", + W[1]); + REQUIRE_TRUE( + W[2] == (4 * numUnits), 0, + "lstmBlockCell: Weights size(1) must be equal to 4*numUnits, got %i", + W[2]); + REQUIRE_TRUE(shape::rank(b) == 1 && b[1] == (4 * numUnits), 0, + "lstmBlockCell: Biases must be rank 1, size 4*numUnits"); + + // evaluate output shapeInfos + const int bS = xt[1]; + Nd4jLong *s(nullptr); + ALLOCATE(s, block.workspace(), shape::shapeInfoLength(2), + Nd4jLong); // [bS, numUnits] + + s[0] = 2; + s[1] = bS; + s[2] = numUnits; + + ShapeUtils::updateStridesAndType(s, xt, 'c'); + + auto s1 = CONSTANT(s); + + // 7 outputs, all same shape: z, i, f, o, h, c, y + return SHAPELIST(s1, s1, s1, s1, s1, s1, s1); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmCell.cpp index 765dd36cc863..dc1a51c81ad8 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmCell.cpp @@ -22,127 +22,214 @@ #if NOT_EXCLUDED(OP_lstmCell) #include -#include +#include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmCell, 8, 2, false, 3, 2) { - auto xt = INPUT_VARIABLE(0); // input [bS x inSize] - auto ht_1 = INPUT_VARIABLE(1); // previous cell output [bS x numProj], that is at previous time step t-1, in case of projection=false -> numProj=numUnits!!! - auto ct_1 = INPUT_VARIABLE(2); // previous cell state [bS x numUnits], that is at previous time step t-1 - - auto Wx = INPUT_VARIABLE(3); // input-to-hidden weights, [inSize x 4*numUnits] - auto Wh = INPUT_VARIABLE(4); // hidden-to-hidden weights, [numProj x 4*numUnits] - auto Wc = INPUT_VARIABLE(5); // diagonal weights for peephole connections [3*numUnits] - auto Wp = INPUT_VARIABLE(6); // projection weights [numUnits x numProj] - auto b = INPUT_VARIABLE(7); // biases, [4*numUnits] - - auto ht = OUTPUT_VARIABLE(0); // current cell output [bS x numProj], that is at current time step t - auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x numUnits], that is at current time step t - - const int peephole = INT_ARG(0); // if 1, provide peephole connections - const int projection = INT_ARG(1); // if 1, then projection is performed, if false then numProj==numUnits is mandatory!!!! - - // FIXME: double? - const double clippingCellValue = T_ARG(0); // clipping value for ct, if it is not equal to zero, then cell state is clipped - const double clippingProjValue = T_ARG(1); // clipping value for projected ht, if it is not equal to zero, then projected cell output is clipped - const double forgetBias = T_ARG(2); - - const int rank = xt->rankOf(); - const int bS = xt->sizeAt(0); - const int inSize = xt->sizeAt(1); - const int numProj = ht_1->sizeAt(1); - const int numUnits = ct_1->sizeAt(1); - - // input shapes validation - const std::vector correctHt_1Shape = {bS, numProj}; - const std::vector correctCt_1Shape = {bS, numUnits}; - const std::vector correctWxShape = {inSize, 4*numUnits}; - const std::vector correctWhShape = {numProj, 4*numUnits}; - const std::vector correctWcShape = {3*numUnits}; - const std::vector correctWpShape = {numUnits, numProj}; - const std::vector correctBShape = {4*numUnits}; - - REQUIRE_TRUE(ht_1->isSameShape(correctHt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctHt_1Shape).c_str(), ShapeUtils::shapeAsString(ht_1).c_str()); - REQUIRE_TRUE(ct_1->isSameShape(correctCt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1).c_str()); - REQUIRE_TRUE(Wx->isSameShape(correctWxShape), 0, "LSTMCELL operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - REQUIRE_TRUE(Wh->isSameShape(correctWhShape), 0, "LSTMCELL operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(Wc->isSameShape(correctWcShape), 0, "LSTMCELL operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(Wc).c_str()); - REQUIRE_TRUE(Wp->isSameShape(correctWpShape), 0, "LSTMCELL operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); - REQUIRE_TRUE(b->isSameShape(correctBShape), 0, "LSTMCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(!(!projection && numUnits != numProj), 0, "LSTMCELL operation: projection option is switched of, and in this case output dimensionality for the projection matrices (numProj) must be equal to number of units in lstmCell !"); - - // calculations - helpers::lstmCell(block.launchContext(), xt,ht_1,ct_1, Wx,Wh,Wc,Wp, b, ht,ct, {(double)peephole, (double)projection, clippingCellValue, clippingProjValue, forgetBias}); - - return Status::OK(); + auto xt = INPUT_VARIABLE(0); // input [bS x inSize] + auto ht_1 = INPUT_VARIABLE( + 1); // previous cell output [bS x numProj], that is at previous time + // step t-1, in case of projection=false -> numProj=numUnits!!! + auto ct_1 = INPUT_VARIABLE(2); // previous cell state [bS x numUnits], that + // is at previous time step t-1 + + auto Wx = + INPUT_VARIABLE(3); // input-to-hidden weights, [inSize x 4*numUnits] + auto Wh = + INPUT_VARIABLE(4); // hidden-to-hidden weights, [numProj x 4*numUnits] + auto Wc = INPUT_VARIABLE( + 5); // diagonal weights for peephole connections [3*numUnits] + auto Wp = INPUT_VARIABLE(6); // projection weights [numUnits x numProj] + auto b = INPUT_VARIABLE(7); // biases, [4*numUnits] + + auto ht = OUTPUT_VARIABLE( + 0); // current cell output [bS x numProj], that is at current time step t + auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x numUnits], that is + // at current time step t + + const int peephole = INT_ARG(0); // if 1, provide peephole connections + const int projection = + INT_ARG(1); // if 1, then projection is performed, if false then + // numProj==numUnits is mandatory!!!! + + // FIXME: double? + const double clippingCellValue = + T_ARG(0); // clipping value for ct, if it is not equal to zero, then cell + // state is clipped + const double clippingProjValue = + T_ARG(1); // clipping value for projected ht, if it is not equal to zero, + // then projected cell output is clipped + const double forgetBias = T_ARG(2); + + const int rank = xt->rankOf(); + const int bS = xt->sizeAt(0); + const int inSize = xt->sizeAt(1); + const int numProj = ht_1->sizeAt(1); + const int numUnits = ct_1->sizeAt(1); + + // input shapes validation + const std::vector correctHt_1Shape = {bS, numProj}; + const std::vector correctCt_1Shape = {bS, numUnits}; + const std::vector correctWxShape = {inSize, 4 * numUnits}; + const std::vector correctWhShape = {numProj, 4 * numUnits}; + const std::vector correctWcShape = {3 * numUnits}; + const std::vector correctWpShape = {numUnits, numProj}; + const std::vector correctBShape = {4 * numUnits}; + + REQUIRE_TRUE(ht_1->isSameShape(correctHt_1Shape), 0, + "LSTMCELL operation: wrong shape of initial cell output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctHt_1Shape).c_str(), + ShapeUtils::shapeAsString(ht_1).c_str()); + REQUIRE_TRUE(ct_1->isSameShape(correctCt_1Shape), 0, + "LSTMCELL operation: wrong shape of initial cell state, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), + ShapeUtils::shapeAsString(ct_1).c_str()); + REQUIRE_TRUE(Wx->isSameShape(correctWxShape), 0, + "LSTMCELL operation: wrong shape of input-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWxShape).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(correctWhShape), 0, + "LSTMCELL operation: wrong shape of hidden-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWhShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(Wc->isSameShape(correctWcShape), 0, + "LSTMCELL operation: wrong shape of diagonal weights for " + "peephole connections, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWcShape).c_str(), + ShapeUtils::shapeAsString(Wc).c_str()); + REQUIRE_TRUE(Wp->isSameShape(correctWpShape), 0, + "LSTMCELL operation: wrong shape of projection weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWpShape).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + REQUIRE_TRUE(b->isSameShape(correctBShape), 0, + "LSTMCELL operation: wrong shape of biases, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(correctBShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(!(!projection && numUnits != numProj), 0, + "LSTMCELL operation: projection option is switched of, and in " + "this case output dimensionality for the projection matrices " + "(numProj) must be equal to number of units in lstmCell !"); + + // calculations + helpers::lstmCell(block.launchContext(), xt, ht_1, ct_1, Wx, Wh, Wc, Wp, b, + ht, ct, + {(double)peephole, (double)projection, clippingCellValue, + clippingProjValue, forgetBias}); + + return Status::OK(); } - DECLARE_TYPES(lstmCell) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - +DECLARE_TYPES(lstmCell) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(lstmCell) { - - auto xtShapeInfo = inputShape->at(0); // input [bS x inSize] - auto ht_1ShapeInfo = inputShape->at(1); // previous cell output [bS x numProj], that is at previous time step t-1, in case of projection=false -> numProj=numUnits!!! - auto ct_1ShapeInfo = inputShape->at(2); // previous cell state [bS x numUnits], that is at previous time step t-1 - - auto WxShapeInfo = inputShape->at(3); // input-to-hidden weights, [inSize x 4*numUnits] - auto WhShapeInfo = inputShape->at(4); // hidden-to-hidden weights, [numProj x 4*numUnits] - auto WcShapeInfo = inputShape->at(5); // diagonal weights for peephole connections [3*numUnits] - auto WpShapeInfo = inputShape->at(6); // projection weights [numUnits x numProj] - auto bShapeInfo = inputShape->at(7); // biases, [4*numUnits] - - const int rank = shape::rank(xtShapeInfo); - const auto bS = xtShapeInfo[1]; - const auto inSize = xtShapeInfo[2]; - const auto numProj = ht_1ShapeInfo[2]; - const auto numUnits = ct_1ShapeInfo[2]; - - // input shapes validation - const std::vector correctHt_1Shape = {bS, numProj}; - const std::vector correctCt_1Shape = {bS, numUnits}; - const std::vector correctWxShape = {inSize, 4*numUnits}; - const std::vector correctWhShape = {numProj, 4*numUnits}; - const std::vector correctWcShape = {3*numUnits}; - const std::vector correctWpShape = {numUnits, numProj}; - const std::vector correctBShape = {4*numUnits}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(ht_1ShapeInfo, correctHt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctHt_1Shape).c_str(), ShapeUtils::shapeAsString(ht_1ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(ct_1ShapeInfo, correctCt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, correctWxShape), 0, "LSTMCELL operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, correctWhShape), 0, "LSTMCELL operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WcShapeInfo, correctWcShape), 0, "LSTMCELL operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(WcShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WpShapeInfo, correctWpShape), 0, "LSTMCELL operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(WpShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, correctBShape), 0, "LSTMCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - - // evaluate output shapeInfos - Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numProj] - ALLOCATE(cShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numUnits] - - hShapeInfo[0] = cShapeInfo[0] = rank; - hShapeInfo[1] = cShapeInfo[1] = bS; - hShapeInfo[2] = numProj; - cShapeInfo[2] = numUnits; - - ShapeUtils::updateStridesAndType(hShapeInfo, xtShapeInfo, shape::order(ht_1ShapeInfo)); - ShapeUtils::updateStridesAndType(cShapeInfo, xtShapeInfo, shape::order(ct_1ShapeInfo)); - - auto result = SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(hShapeInfo), ConstantShapeHelper::getInstance()->createShapeInfo(cShapeInfo)); - RELEASE(hShapeInfo, block.workspace()); - RELEASE(cShapeInfo, block.workspace()); - return result; + auto xtShapeInfo = inputShape->at(0); // input [bS x inSize] + auto ht_1ShapeInfo = inputShape->at( + 1); // previous cell output [bS x numProj], that is at previous time + // step t-1, in case of projection=false -> numProj=numUnits!!! + auto ct_1ShapeInfo = + inputShape->at(2); // previous cell state [bS x numUnits], that is at + // previous time step t-1 + + auto WxShapeInfo = + inputShape->at(3); // input-to-hidden weights, [inSize x 4*numUnits] + auto WhShapeInfo = + inputShape->at(4); // hidden-to-hidden weights, [numProj x 4*numUnits] + auto WcShapeInfo = inputShape->at( + 5); // diagonal weights for peephole connections [3*numUnits] + auto WpShapeInfo = + inputShape->at(6); // projection weights [numUnits x numProj] + auto bShapeInfo = inputShape->at(7); // biases, [4*numUnits] + + const int rank = shape::rank(xtShapeInfo); + const auto bS = xtShapeInfo[1]; + const auto inSize = xtShapeInfo[2]; + const auto numProj = ht_1ShapeInfo[2]; + const auto numUnits = ct_1ShapeInfo[2]; + + // input shapes validation + const std::vector correctHt_1Shape = {bS, numProj}; + const std::vector correctCt_1Shape = {bS, numUnits}; + const std::vector correctWxShape = {inSize, 4 * numUnits}; + const std::vector correctWhShape = {numProj, 4 * numUnits}; + const std::vector correctWcShape = {3 * numUnits}; + const std::vector correctWpShape = {numUnits, numProj}; + const std::vector correctBShape = {4 * numUnits}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(ht_1ShapeInfo, correctHt_1Shape), 0, + "LSTMCELL operation: wrong shape of initial cell output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctHt_1Shape).c_str(), + ShapeUtils::shapeAsString(ht_1ShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(ct_1ShapeInfo, correctCt_1Shape), 0, + "LSTMCELL operation: wrong shape of initial cell state, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), + ShapeUtils::shapeAsString(ct_1ShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, correctWxShape), 0, + "LSTMCELL operation: wrong shape of input-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWxShape).c_str(), + ShapeUtils::shapeAsString(WxShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, correctWhShape), 0, + "LSTMCELL operation: wrong shape of hidden-to-hidden weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWhShape).c_str(), + ShapeUtils::shapeAsString(WhShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WcShapeInfo, correctWcShape), 0, + "LSTMCELL operation: wrong shape of diagonal weights for " + "peephole connections, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWcShape).c_str(), + ShapeUtils::shapeAsString(WcShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WpShapeInfo, correctWpShape), 0, + "LSTMCELL operation: wrong shape of projection weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctWpShape).c_str(), + ShapeUtils::shapeAsString(WpShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, correctBShape), 0, + "LSTMCELL operation: wrong shape of biases, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(correctBShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + + // evaluate output shapeInfos + Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [bS x numProj] + ALLOCATE(cShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [bS x numUnits] + + hShapeInfo[0] = cShapeInfo[0] = rank; + hShapeInfo[1] = cShapeInfo[1] = bS; + hShapeInfo[2] = numProj; + cShapeInfo[2] = numUnits; + + ShapeUtils::updateStridesAndType(hShapeInfo, xtShapeInfo, + shape::order(ht_1ShapeInfo)); + ShapeUtils::updateStridesAndType(cShapeInfo, xtShapeInfo, + shape::order(ct_1ShapeInfo)); + + auto result = SHAPELIST( + ConstantShapeHelper::getInstance()->createShapeInfo(hShapeInfo), + ConstantShapeHelper::getInstance()->createShapeInfo(cShapeInfo)); + RELEASE(hShapeInfo, block.workspace()); + RELEASE(cShapeInfo, block.workspace()); + return result; } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp index a5c8b8d28c33..a5e8194ddd53 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp @@ -22,786 +22,1068 @@ #if NOT_EXCLUDED(OP_lstmLayer) #include -#include - +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { - - // equations (no peephole connections) - // it = σ(Wxi * xt + Wri * ht-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't - // ot = σ(Wxo * xt + Wro * ht-1 + bo) - // ht = ot ◦ tanh(ct) - - // equations (peephole connections are present) - // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = clip(ft ◦ ct-1 + it ◦ c't) - // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) - // ht = ot ◦ tanh(ct) - - // notations: - // bS - batch size - // sL - sequence length, number of time steps - // nIn - input size - // nOut - output size (hidden size) - - // INPUTS: - - // ******* - // input x: - // 1) [sL, bS, nIn] when dataFormat == 0 - // 2) [bS, sL, nIn] when dataFormat == 1 - // 3) [bS, nIn, sL] when dataFormat == 2 - - // ******* - // input weights Wx: - // 1) [nIn, 4*nOut] when directionMode < 2 - // 2) [2, nIn, 4*nOut] when directionMode >= 2 - - // ******* - // recurrent weights Wr: - // 1) [nOut, 4*nOut] when directionMode < 2 - // 2) [2, nOut, 4*nOut] when directionMode >= 2 - - // ******* - // peephole weights Wp, optional: - // 1) [3*nOut] when directionMode < 2 - // 2) [2, 3*nOut] when directionMode >= 2 - - // ******* - // biases b, optional: - // 1) [4*nOut] when directionMode < 2 - // 2) [2, 4*nOut] when directionMode >= 2 - - // ******* - // sequence length array seqLen, optional: - // 1) [bS] - - // ******* - // initial output hI, optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - // ******* - // initial cell state cI (same shape as in hI), optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - - // OUTPUTS: - - // ******* - // output h, optional: - // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 - // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 - // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 - // 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 - // 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1 - // 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2 - // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 - - // ******* - // output at last step hL, optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - // ******* - // cell state at last step cL (same shape as in hL), optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - // !!! dimension 4*nOut implies order it, ft, c't, ot - // !!! dimension 3*nOut implies order it, ft, ot - - const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX) - const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) - - // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus - const auto gateAct = INT_ARG(2); // activation for input (i), forget (f) and output (o) gates - const auto cellAct = INT_ARG(3); // activation for cell state (c) - const auto outAct = INT_ARG(4); // activation for output (h) - - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided - const auto hasInitH = B_ARG(2); // indicates whether initial output is provided - const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided - const auto hasPH = B_ARG(4); // indicates whether peephole connections are present - const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} - const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only - const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only - - const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; - const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; - const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; - const auto gateActHasBeta = gateAct == 3 || gateAct == 6; - const auto cellActHasBeta = cellAct == 3 || cellAct == 6; - const auto outActHasBeta = outAct == 3 || outAct == 6; - - uint count = 1; - const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping - const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; - const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; - const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; - const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; - const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; - const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; - - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - - count = 3; - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector - const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output - const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state - const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights - - REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", dataFormat, directionMode); - REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER operation: cell clipping value should be nonnegative (>=0) !"); - REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER operation: please specify what output arrays to produce !"); - - count = 0; - auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output - auto hL = retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step - auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step - - // evaluate dimensions - const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); - const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); - const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2); - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - // inputs validations - if(directionMode < 2) { // no bidirectional - - // Wx validation - if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - // initial output validation - if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - // initial cell validation - if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - // peephole weights validation - if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = clip(ft ◦ ct-1 + it ◦ c't) + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // sL - sequence length, number of time steps + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + + // ******* + // input x: + // 1) [sL, bS, nIn] when dataFormat == 0 + // 2) [bS, sL, nIn] when dataFormat == 1 + // 3) [bS, nIn, sL] when dataFormat == 2 + + // ******* + // input weights Wx: + // 1) [nIn, 4*nOut] when directionMode < 2 + // 2) [2, nIn, 4*nOut] when directionMode >= 2 + + // ******* + // recurrent weights Wr: + // 1) [nOut, 4*nOut] when directionMode < 2 + // 2) [2, nOut, 4*nOut] when directionMode >= 2 + + // ******* + // peephole weights Wp, optional: + // 1) [3*nOut] when directionMode < 2 + // 2) [2, 3*nOut] when directionMode >= 2 + + // ******* + // biases b, optional: + // 1) [4*nOut] when directionMode < 2 + // 2) [2, 4*nOut] when directionMode >= 2 + + // ******* + // sequence length array seqLen, optional: + // 1) [bS] + + // ******* + // initial output hI, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // initial cell state cI (same shape as in hI), optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // OUTPUTS: + + // ******* + // output h, optional: + // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 + // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 + // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 + // 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 + // 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1 + // 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2 + // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 + + // ******* + // output at last step hL, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // cell state at last step cL (same shape as in hL), optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + const auto dataFormat = + INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], + // 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && + // [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = + INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = + // bidirectional concat, 4 = bidirectional extra output dim + // (in conjunction with format dataFormat = 3) + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, + // 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, + // 8=ELU, 9=softsign, 10=softplus + const auto gateAct = + INT_ARG(2); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(3); // activation for cell state (c) + const auto outAct = INT_ARG(4); // activation for output (h) + + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = + B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = + B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = + B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = + B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = B_ARG(5); // indicates whether to return whole time + // sequence h {h_0, h_1, ... , h_sL-1} + const auto retLastH = + B_ARG(6); // indicates whether to return output at last time step only + const auto retLastC = B_ARG( + 7); // indicates whether to return cells state at last time step only + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || + gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || + cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = + outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = + T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + count = 3; + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = + hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = + hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = + hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = + hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + + REQUIRE_TRUE( + dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, + "LSTM_LAYER operation: if argument dataFormat = 3, then directionMode = " + "4, but got dataFormat = %i and directionMode = %i instead !", + dataFormat, directionMode); + REQUIRE_TRUE(cellClip >= 0, 0, + "LSTM_LAYER operation: cell clipping value should be " + "nonnegative (>=0) !"); + REQUIRE_TRUE( + retFullSeq || retLastH || retLastC, 0, + "LSTM_LAYER operation: please specify what output arrays to produce !"); + + count = 0; + auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output + auto hL = + retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step + auto cL = + retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step + + // evaluate dimensions + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const Nd4jLong bS = + dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + if (directionMode < 2) { // no bidirectional + + // Wx validation + if (Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of input weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nIn, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4 * nOut) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of recurrent weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nOut, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wr).c_str()); + // biases validation + if (b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of biases, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString({4 * nOut}).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + // initial output validation + if (hI != nullptr && + (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of initial output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS, nOut}).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + // initial cell validation + if (cI != nullptr && + (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of initial cell state, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS, nOut}).c_str(), + ShapeUtils::shapeAsString(cI).c_str()); + // peephole weights validation + if (Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong peephole weights, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString({3 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + } else { // bidirectional + // Wx validation + if (Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of input weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, nIn, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || + Wr->sizeAt(2) != 4 * nOut) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of recurrent weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, nOut, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wr).c_str()); + // biases validation + if (b != nullptr && + (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of biases, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString({2, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + // initial output validation + if (hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || + hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of initial output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + // initial cell validation + if (cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || + cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong shape of initial cell state, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), + ShapeUtils::shapeAsString(cI).c_str()); + // peephole weights validation + if (Wp != nullptr && + (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER operation: wrong peephole weights, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString({2, 3 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + } + + std::vector params = { + static_cast(dataFormat), static_cast(directionMode), + static_cast(cellClip), static_cast(gateAct), + static_cast(gateAlpha), static_cast(gateBeta), + static_cast(cellAct), static_cast(cellAlpha), + static_cast(cellBeta), static_cast(outAct), + static_cast(outAlpha), static_cast(outBeta)}; + + if (directionMode == 0) { // forward + + helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, true, + h, hL, cL); + } else if (directionMode == 1) { // backward + + helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, false, + h, hL, cL); + } else { // bidirectional + + NDArray WxFwd = (*Wx)({0, 1, 0, 0, 0, 0}); + NDArray WxBwd = (*Wx)({1, 2, 0, 0, 0, 0}); + NDArray WrFwd = (*Wr)({0, 1, 0, 0, 0, 0}); + NDArray WrBwd = (*Wr)({1, 2, 0, 0, 0, 0}); + + NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), + *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr), + *hLFwd(nullptr), *hLBwd(nullptr), *cLFwd(nullptr), *cLBwd(nullptr), + *hFwd(nullptr), *hBwd(nullptr); + + if (Wp) { + WpFwd = new NDArray((*Wp)({0, 1, 0, 0})); + WpBwd = new NDArray((*Wp)({1, 2, 0, 0})); } - else { // bidirectional - // Wx validation - if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - // initial output validation - if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - // initial cell validation - if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - // peephole weights validation - if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); + if (b) { + bFwd = new NDArray((*b)({0, 1, 0, 0})); + bBwd = new NDArray((*b)({1, 2, 0, 0})); } - - std::vector params = {static_cast(dataFormat), static_cast(directionMode), static_cast(cellClip), - static_cast(gateAct), static_cast(gateAlpha), static_cast(gateBeta), - static_cast(cellAct), static_cast(cellAlpha), static_cast(cellBeta), - static_cast(outAct), static_cast(outAlpha), static_cast(outBeta)}; - - if(directionMode == 0) { // forward - - helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, true, h, hL, cL); + if (hI) { + hIFwd = new NDArray((*hI)({0, 1, 0, 0, 0, 0})); + hIBwd = new NDArray((*hI)({1, 2, 0, 0, 0, 0})); } - else if(directionMode == 1) { // backward - - helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, false, h, hL, cL); + if (cI) { + cIFwd = new NDArray((*cI)({0, 1, 0, 0, 0, 0})); + cIBwd = new NDArray((*cI)({1, 2, 0, 0, 0, 0})); + } + if (hL) { + hLFwd = new NDArray((*hL)({0, 1, 0, 0, 0, 0})); + hLBwd = new NDArray((*hL)({1, 2, 0, 0, 0, 0})); } - else { // bidirectional - - NDArray WxFwd = (*Wx)({0,1, 0,0, 0,0}); - NDArray WxBwd = (*Wx)({1,2, 0,0, 0,0}); - NDArray WrFwd = (*Wr)({0,1, 0,0, 0,0}); - NDArray WrBwd = (*Wr)({1,2, 0,0, 0,0}); - - NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr), - *hLFwd(nullptr), *hLBwd(nullptr), *cLFwd(nullptr), *cLBwd(nullptr), *hFwd(nullptr), *hBwd(nullptr); - - if(Wp) { - WpFwd = new NDArray((*Wp)({0,1, 0,0})); - WpBwd = new NDArray((*Wp)({1,2, 0,0})); - } - if(b) { - bFwd = new NDArray((*b)({0,1, 0,0})); - bBwd = new NDArray((*b)({1,2, 0,0})); - } - if(hI) { - hIFwd = new NDArray((*hI)({0,1, 0,0, 0,0})); - hIBwd = new NDArray((*hI)({1,2, 0,0, 0,0})); - } - if(cI) { - cIFwd = new NDArray((*cI)({0,1, 0,0, 0,0})); - cIBwd = new NDArray((*cI)({1,2, 0,0, 0,0})); - } - if(hL) { - hLFwd = new NDArray((*hL)({0,1, 0,0, 0,0})); - hLBwd = new NDArray((*hL)({1,2, 0,0, 0,0})); - } - if(cL) { - cLFwd = new NDArray((*cL)({0,1, 0,0, 0,0})); - cLBwd = new NDArray((*cL)({1,2, 0,0, 0,0})); - } - - if(h) { - if(directionMode == 2) { // sum - hFwd = h; - hBwd = new NDArray(h, false, h->getContext()); - } - else if(directionMode == 3) { // concat - hFwd = new NDArray(dataFormat <= 1 ? (*h)({0,0, 0,0, 0,nOut}) : (*h)({0,0, 0,nOut, 0,0})); - hBwd = new NDArray(dataFormat <= 1 ? (*h)({0,0, 0,0, nOut,2*nOut}) : (*h)({0,0, nOut,2*nOut, 0,0})); - } - else { // directionMode == 4 - hFwd = new NDArray((*h)({0,0, 0,1, 0,0, 0,0})); - hBwd = new NDArray((*h)({0,0, 1,2, 0,0, 0,0})); - } - } - - // FIXME - following two calls are independent and may run in different streams - helpers::lstmLayerTimeLoop(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, params, true, hFwd, hLFwd, cLFwd); - helpers::lstmLayerTimeLoop(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, params, false, hBwd, hLBwd, cLBwd); - - if(h && directionMode == 2) - *h += *hBwd; - - delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; - delete cIBwd; delete hLFwd; delete hLBwd; delete cLFwd; delete cLBwd; delete hBwd; - if(hFwd != h) - delete hFwd; + if (cL) { + cLFwd = new NDArray((*cL)({0, 1, 0, 0, 0, 0})); + cLBwd = new NDArray((*cL)({1, 2, 0, 0, 0, 0})); + } + + if (h) { + if (directionMode == 2) { // sum + hFwd = h; + hBwd = new NDArray(h, false, h->getContext()); + } else if (directionMode == 3) { // concat + hFwd = new NDArray(dataFormat <= 1 ? (*h)({0, 0, 0, 0, 0, nOut}) + : (*h)({0, 0, 0, nOut, 0, 0})); + hBwd = + new NDArray(dataFormat <= 1 ? (*h)({0, 0, 0, 0, nOut, 2 * nOut}) + : (*h)({0, 0, nOut, 2 * nOut, 0, 0})); + } else { // directionMode == 4 + hFwd = new NDArray((*h)({0, 0, 0, 1, 0, 0, 0, 0})); + hBwd = new NDArray((*h)({0, 0, 1, 2, 0, 0, 0, 0})); + } } - return Status::OK(); + // FIXME - following two calls are independent and may run in different + // streams + helpers::lstmLayerTimeLoop(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, + WpFwd, params, true, hFwd, hLFwd, cLFwd); + helpers::lstmLayerTimeLoop(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, + WpBwd, params, false, hBwd, hLBwd, cLBwd); + + if (h && directionMode == 2) *h += *hBwd; + + delete WpFwd; + delete WpBwd; + delete bFwd; + delete bBwd; + delete hIFwd; + delete hIBwd; + delete cIFwd; + delete cIBwd; + delete hLFwd; + delete hLBwd; + delete cLFwd; + delete cLBwd; + delete hBwd; + if (hFwd != h) delete hFwd; + } + + return Status::OK(); } DECLARE_TYPES(lstmLayer) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(lstmLayer) { - - const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nIn] (for ONNX) - const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim - - const auto retFullSeq = B_ARG(5); // indicates whether to return whole h {h_0, h_1, ... , h_sL-1}, if true, format would be [sL,bS,nOut] (exact shape depends on dataFormat argument) - const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) - const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) - - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - - // evaluate dimensions - const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); - const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); - const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2); - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - DataType type; - if(x->isR()) - type = x->dataType(); - else - type = sd::DataType::FLOAT32; - - auto shapes = SHAPELIST(); - - // evaluate h shape (output) - if(retFullSeq) { - - std::vector hShape; - - if(directionMode <= 2) { // single direction or bidirectional with sum - if(dataFormat == 0) - hShape = {sL, bS, nOut}; - else if(dataFormat == 1) - hShape = {bS, sL, nOut}; - else if(dataFormat == 2) - hShape = {bS, nOut, sL}; - } - else if(directionMode == 3) { // bidirectional with concat - - if(dataFormat == 0) - hShape = {sL, bS, 2*nOut}; - else if(dataFormat == 1) - hShape = {bS, sL, 2*nOut}; - else if(dataFormat == 2) - hShape = {bS, 2*nOut, sL}; - } - else { // bidirectional with extra output dimension equal to 2 - hShape = {sL, 2, bS, nOut}; - } - - shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), hShape)); + const auto dataFormat = INT_ARG( + 0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, + // nIn, sL], for bidirectional: 3 = [sL, 2, bS, nIn] (for ONNX) + const auto directionMode = + INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = + // bidirectional concat, 4 = bidirectional extra output dim + + const auto retFullSeq = + B_ARG(5); // indicates whether to return whole h {h_0, h_1, ... , + // h_sL-1}, if true, format would be [sL,bS,nOut] (exact shape + // depends on dataFormat argument) + const auto retLastH = + B_ARG(6); // indicates whether to return output at last time step only, + // in this case shape would be [bS, nOut] (exact shape depends + // on dataFormat argument) + const auto retLastC = + B_ARG(7); // indicates whether to return cells state at last time step + // only, in this case shape would be [bS, nOut] (exact shape + // depends on dataFormat argument) + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + // evaluate dimensions + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const Nd4jLong bS = + dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + DataType type; + if (x->isR()) + type = x->dataType(); + else + type = sd::DataType::FLOAT32; + + auto shapes = SHAPELIST(); + + // evaluate h shape (output) + if (retFullSeq) { + std::vector hShape; + + if (directionMode <= 2) { // single direction or bidirectional with sum + if (dataFormat == 0) + hShape = {sL, bS, nOut}; + else if (dataFormat == 1) + hShape = {bS, sL, nOut}; + else if (dataFormat == 2) + hShape = {bS, nOut, sL}; + } else if (directionMode == 3) { // bidirectional with concat + + if (dataFormat == 0) + hShape = {sL, bS, 2 * nOut}; + else if (dataFormat == 1) + hShape = {bS, sL, 2 * nOut}; + else if (dataFormat == 2) + hShape = {bS, 2 * nOut, sL}; + } else { // bidirectional with extra output dimension equal to 2 + hShape = {sL, 2, bS, nOut}; } - // evaluate hL shape (output at last step) - if(retLastH) { - - std::vector hLShape; + shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + type, x->ordering(), hShape)); + } - if(directionMode < 2) - hLShape = {bS, nOut}; - else - hLShape = {2, bS, nOut}; + // evaluate hL shape (output at last step) + if (retLastH) { + std::vector hLShape; - shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), hLShape)); + if (directionMode < 2) + hLShape = {bS, nOut}; + else + hLShape = {2, bS, nOut}; - if(retLastC) // cL and hL have same shapes - shapes->push_back(shapes->at(shapes->size() - 1)); - } + shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + type, x->ordering(), hLShape)); - // evaluate cL shape (cell state at last step) - if(retLastC && !retLastH) { + if (retLastC) // cL and hL have same shapes + shapes->push_back(shapes->at(shapes->size() - 1)); + } - std::vector cLShape; + // evaluate cL shape (cell state at last step) + if (retLastC && !retLastH) { + std::vector cLShape; - if(directionMode < 2) - cLShape = {bS, nOut}; - else - cLShape = {2, bS, nOut}; + if (directionMode < 2) + cLShape = {bS, nOut}; + else + cLShape = {2, bS, nOut}; - shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), cLShape)); - } + shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + type, x->ordering(), cLShape)); + } - return shapes; + return shapes; } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) { - - // equations (no peephole connections) - // it = σ(Wxi * xt + Wri * ht-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't - // ot = σ(Wxo * xt + Wro * ht-1 + bo) - // ht = ot ◦ tanh(ct) - - // equations (peephole connections are present) - // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = clip(ft ◦ ct-1 + it ◦ c't) - // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) - // ht = ot ◦ tanh(ct) - - // notations: - // bS - batch size - // sL - sequence length, number of time steps - // nIn - input size - // nOut - output size (hidden size) - - // INPUTS: - - // ******* - // input x: - // 1) [sL, bS, nIn] when dataFormat == 0 - // 2) [bS, sL, nIn] when dataFormat == 1 - // 3) [bS, nIn, sL] when dataFormat == 2 - - // ******* - // input weights Wx: - // 1) [nIn, 4*nOut] when directionMode < 2 - // 2) [2, nIn, 4*nOut] when directionMode >= 2 - - // ******* - // recurrent weights Wr: - // 1) [nOut, 4*nOut] when directionMode < 2 - // 2) [2, nOut, 4*nOut] when directionMode >= 2 - - // ******* - // peephole weights Wp, optional: - // 1) [3*nOut] when directionMode < 2 - // 2) [2, 3*nOut] when directionMode >= 2 - - // ******* - // biases b, optional: - // 1) [4*nOut] when directionMode < 2 - // 2) [2, 4*nOut] when directionMode >= 2 - - // ******* - // sequence length array seqLen, optional: - // 1) [bS] - - // ******* - // initial output hI, optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - // ******* - // initial cell state cI (same shape as in hI), optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - // ******* - // gradient vs. output dLdh, optional: - // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 - // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 - // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 - // 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 - // 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1 - // 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2 - // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 - - // ******* - // gradient vs output at last time step dLdhL, optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - // ******* - // gradient vs cell state at last time step dLdcL(same shape as in dLdhL), optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - - // OUTPUTS: - - // ******* - // gradient vs. input dLdx: - // 1) [sL, bS, nIn] when dataFormat == 0 - // 2) [bS, sL, nIn] when dataFormat == 1 - // 3) [bS, nIn, sL] when dataFormat == 2 - - // ******* - // gradient vs. input weights dLdWx: - // 1) [nIn, 4*nOut] when directionMode < 2 - // 2) [2, nIn, 4*nOut] when directionMode >= 2 - - // ******* - // gradient vs. recurrent weights dLdWr: - // 1) [nOut, 4*nOut] when directionMode < 2 - // 2) [2, nOut, 4*nOut] when directionMode >= 2 - - // ******* - // gradient vs. peephole weights dLdWp, optional: - // 1) [3*nOut] when directionMode < 2 - // 2) [2, 3*nOut] when directionMode >= 2 - - // ******* - // gradient vs. biases dLdb, optional: - // 1) [4*nOut] when directionMode < 2 - // 2) [2, 4*nOut] when directionMode >= 2 - - // gradient vs. sequence length array dLdsL, optional (do not calculate it!!!): - // 1) [bS] always - - // ******* - // gradient vs. initial output dLdhI, optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - // ******* - // gradient vs. initial cell state dLdcI (same shape as in dLdhI), optional: - // 1) [bS, nOut] when directionMode < 2 - // 2) [2, bS, nOut] when directionMode >= 2 - - - // !!! dimension 4*nOut implies order it, ft, c't, ot - // !!! dimension 3*nOut implies order it, ft, ot - - const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX) - const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) - - // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus - const auto gateAct = INT_ARG(2); // activation for input (i), forget (f) and output (o) gates - const auto cellAct = INT_ARG(3); // activation for cell state (c) - const auto outAct = INT_ARG(4); // activation for output (h) - - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided - const auto hasInitH = B_ARG(2); // indicates whether initial output is provided - const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided - const auto hasPH = B_ARG(4); // indicates whether peephole connections are present - const auto retFullSeq = B_ARG(5); // indicates whether gradient vs. outputs is given for whole time sequence dLdh {dLdh_0, dLdh_1, ... , dLdh_sL-1} - const auto retLastH = B_ARG(6); // indicates whether gradient vs. output at last time step (dLdhL) is given - const auto retLastC = B_ARG(7); // indicates whether gradient vs. cell state at last time step (dLdcL) is given - - const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; - const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; - const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; - const auto gateActHasBeta = gateAct == 3 || gateAct == 6; - const auto cellActHasBeta = cellAct == 3 || cellAct == 6; - const auto outActHasBeta = outAct == 3 || outAct == 6; - - uint count = 1; - const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping - const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; - const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; - const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; - const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; - const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; - const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; - - REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER_BP operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", dataFormat, directionMode); - REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_BP operation: cell clipping value should be nonnegative (>=0) !"); - REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER_BP operation: please specify at least one of three input gradient arrays: dLdh, dLdhL or dLdcL !"); - - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - - count = 3; - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector - const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output - const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state - const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights - const auto dLdh = retFullSeq ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output - const auto dLdhL = retLastH ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output at last time step - const auto dLdcL = retLastC ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. cell state at last time step - - count = 3; - auto dLdx = OUTPUT_VARIABLE(0); // gradient vs. input - auto dLdWx = OUTPUT_NULLIFIED(1); // gradient vs. input weights - auto dLdWr = OUTPUT_NULLIFIED(2); // gradient vs. recurrent weights - auto dLdb = hasBiases ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. biases - auto dLdsL = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. seqLen vector, we don't calculate it !!! - auto dLdhI = hasInitH ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. initial output - auto dLdcI = hasInitC ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. initial cell state - auto dLdWp = hasPH ? OUTPUT_NULLIFIED(count) : nullptr; // gradient vs. peephole weights - - // evaluate dimensions - const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); - const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); - const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2); - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - // inputs validations - if(directionMode < 2) { // no bidirectional - - // Wx validation - if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - // initial output validation - if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - // initial cell validation - if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - // peephole weights validation - if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); - // gradient vs. output at last time step validation - if(dLdhL != nullptr && (dLdhL->rankOf() != 2 || dLdhL->sizeAt(0) != bS || dLdhL->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str()); - // gradient vs. cell state at last time step validation - if(dLdcL != nullptr && (dLdcL->rankOf() != 2 || dLdcL->sizeAt(0) != bS || dLdcL->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str()); + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = clip(ft ◦ ct-1 + it ◦ c't) + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // sL - sequence length, number of time steps + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + + // ******* + // input x: + // 1) [sL, bS, nIn] when dataFormat == 0 + // 2) [bS, sL, nIn] when dataFormat == 1 + // 3) [bS, nIn, sL] when dataFormat == 2 + + // ******* + // input weights Wx: + // 1) [nIn, 4*nOut] when directionMode < 2 + // 2) [2, nIn, 4*nOut] when directionMode >= 2 + + // ******* + // recurrent weights Wr: + // 1) [nOut, 4*nOut] when directionMode < 2 + // 2) [2, nOut, 4*nOut] when directionMode >= 2 + + // ******* + // peephole weights Wp, optional: + // 1) [3*nOut] when directionMode < 2 + // 2) [2, 3*nOut] when directionMode >= 2 + + // ******* + // biases b, optional: + // 1) [4*nOut] when directionMode < 2 + // 2) [2, 4*nOut] when directionMode >= 2 + + // ******* + // sequence length array seqLen, optional: + // 1) [bS] + + // ******* + // initial output hI, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // initial cell state cI (same shape as in hI), optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // gradient vs. output dLdh, optional: + // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 + // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 + // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 + // 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 + // 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1 + // 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2 + // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 + + // ******* + // gradient vs output at last time step dLdhL, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // gradient vs cell state at last time step dLdcL(same shape as in dLdhL), + // optional: 1) [bS, nOut] when directionMode < 2 2) [2, bS, nOut] when + // directionMode >= 2 + + // OUTPUTS: + + // ******* + // gradient vs. input dLdx: + // 1) [sL, bS, nIn] when dataFormat == 0 + // 2) [bS, sL, nIn] when dataFormat == 1 + // 3) [bS, nIn, sL] when dataFormat == 2 + + // ******* + // gradient vs. input weights dLdWx: + // 1) [nIn, 4*nOut] when directionMode < 2 + // 2) [2, nIn, 4*nOut] when directionMode >= 2 + + // ******* + // gradient vs. recurrent weights dLdWr: + // 1) [nOut, 4*nOut] when directionMode < 2 + // 2) [2, nOut, 4*nOut] when directionMode >= 2 + + // ******* + // gradient vs. peephole weights dLdWp, optional: + // 1) [3*nOut] when directionMode < 2 + // 2) [2, 3*nOut] when directionMode >= 2 + + // ******* + // gradient vs. biases dLdb, optional: + // 1) [4*nOut] when directionMode < 2 + // 2) [2, 4*nOut] when directionMode >= 2 + + // gradient vs. sequence length array dLdsL, optional (do not calculate + // it!!!): 1) [bS] always + + // ******* + // gradient vs. initial output dLdhI, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // gradient vs. initial cell state dLdcI (same shape as in dLdhI), optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + const auto dataFormat = + INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], + // 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && + // [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = + INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = + // bidirectional concat, 4 = bidirectional extra output dim + // (in conjunction with format dataFormat = 3) + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, + // 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, + // 8=ELU, 9=softsign, 10=softplus + const auto gateAct = + INT_ARG(2); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(3); // activation for cell state (c) + const auto outAct = INT_ARG(4); // activation for output (h) + + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = + B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = + B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = + B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = + B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = + B_ARG(5); // indicates whether gradient vs. outputs is given for whole + // time sequence dLdh {dLdh_0, dLdh_1, ... , dLdh_sL-1} + const auto retLastH = B_ARG(6); // indicates whether gradient vs. output at + // last time step (dLdhL) is given + const auto retLastC = B_ARG(7); // indicates whether gradient vs. cell state + // at last time step (dLdcL) is given + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || + gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || + cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = + outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = + T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; + + REQUIRE_TRUE( + dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, + "LSTM_LAYER_BP operation: if argument dataFormat = 3, then directionMode " + "= 4, but got dataFormat = %i and directionMode = %i instead !", + dataFormat, directionMode); + REQUIRE_TRUE(cellClip >= 0, 0, + "LSTM_LAYER_BP operation: cell clipping value should be " + "nonnegative (>=0) !"); + REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, + "LSTM_LAYER_BP operation: please specify at least one of three " + "input gradient arrays: dLdh, dLdhL or dLdcL !"); + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + count = 3; + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = + hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = + hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = + hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = + hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + const auto dLdh = + retFullSeq ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output + const auto dLdhL = retLastH + ? INPUT_VARIABLE(count++) + : nullptr; // gradient vs. output at last time step + const auto dLdcL = + retLastC ? INPUT_VARIABLE(count++) + : nullptr; // gradient vs. cell state at last time step + + count = 3; + auto dLdx = OUTPUT_VARIABLE(0); // gradient vs. input + auto dLdWx = OUTPUT_NULLIFIED(1); // gradient vs. input weights + auto dLdWr = OUTPUT_NULLIFIED(2); // gradient vs. recurrent weights + auto dLdb = + hasBiases ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. biases + auto dLdsL = + hasSeqLen + ? INPUT_VARIABLE(count++) + : nullptr; // gradient vs. seqLen vector, we don't calculate it !!! + auto dLdhI = hasInitH ? OUTPUT_NULLIFIED(count++) + : nullptr; // gradient vs. initial output + auto dLdcI = hasInitC ? OUTPUT_NULLIFIED(count++) + : nullptr; // gradient vs. initial cell state + auto dLdWp = hasPH ? OUTPUT_NULLIFIED(count) + : nullptr; // gradient vs. peephole weights + + // evaluate dimensions + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const Nd4jLong bS = + dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + if (directionMode < 2) { // no bidirectional + + // Wx validation + if (Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of input weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nIn, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4 * nOut) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of recurrent weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nOut, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wr).c_str()); + // biases validation + if (b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of biases, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString({4 * nOut}).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + // initial output validation + if (hI != nullptr && + (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of initial output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS, nOut}).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + // initial cell validation + if (cI != nullptr && + (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of initial cell " + "state, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS, nOut}).c_str(), + ShapeUtils::shapeAsString(cI).c_str()); + // peephole weights validation + if (Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong peephole weights, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString({3 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + // gradient vs. output at last time step validation + if (dLdhL != nullptr && (dLdhL->rankOf() != 2 || dLdhL->sizeAt(0) != bS || + dLdhL->sizeAt(1) != nOut)) + REQUIRE_TRUE( + false, 0, + "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last " + "time step, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS, nOut}).c_str(), + ShapeUtils::shapeAsString(dLdhL).c_str()); + // gradient vs. cell state at last time step validation + if (dLdcL != nullptr && (dLdcL->rankOf() != 2 || dLdcL->sizeAt(0) != bS || + dLdcL->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell " + "state at last time, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS, nOut}).c_str(), + ShapeUtils::shapeAsString(dLdcL).c_str()); + } else { // bidirectional + // Wx validation + if (Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of input weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, nIn, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || + Wr->sizeAt(2) != 4 * nOut) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of recurrent weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, nOut, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wr).c_str()); + // biases validation + if (b != nullptr && + (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of biases, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + // initial output validation + if (hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || + hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of initial output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + // initial cell validation + if (cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || + cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of initial cell " + "state, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), + ShapeUtils::shapeAsString(cI).c_str()); + // peephole weights validation + if (Wp != nullptr && + (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong peephole weights, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, 3 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + // gradient vs. output at last time step validation + if (dLdhL != nullptr && + (dLdhL->rankOf() != 3 || dLdhL->sizeAt(0) != 2 || + dLdhL->sizeAt(1) != bS || dLdhL->sizeAt(2) != nOut)) + REQUIRE_TRUE( + false, 0, + "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last " + "time step, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), + ShapeUtils::shapeAsString(dLdhL).c_str()); + // gradient vs. cell state at last time step validation + if (dLdcL != nullptr && + (dLdcL->rankOf() != 3 || dLdcL->sizeAt(0) != 2 || + dLdcL->sizeAt(1) != bS || dLdcL->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell " + "state at last time, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), + ShapeUtils::shapeAsString(dLdcL).c_str()); + } + + // gradient vs. output validation + if (dLdh) { + int factor = directionMode <= 2 ? 1 : 2; + std::vector expdLdhShape; + if (dataFormat == 0) + expdLdhShape = std::vector{sL, bS, factor * nOut}; + else if (dataFormat == 1) + expdLdhShape = std::vector{bS, sL, factor * nOut}; + else if (dataFormat == 2) + expdLdhShape = std::vector{bS, factor * nOut, sL}; + else + expdLdhShape = std::vector{sL, 2, bS, nOut}; + REQUIRE_TRUE(dLdh->isSameShape(expdLdhShape), 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of gradient vs. " + "output, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expdLdhShape).c_str(), + ShapeUtils::shapeAsString(dLdh).c_str()); + } + + std::vector params = { + static_cast(dataFormat), static_cast(directionMode), + static_cast(cellClip), static_cast(gateAct), + static_cast(gateAlpha), static_cast(gateBeta), + static_cast(cellAct), static_cast(cellAlpha), + static_cast(cellBeta), static_cast(outAct), + static_cast(outAlpha), static_cast(outBeta)}; + + if (directionMode == 0) { // forward + + helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, + dLdcL, params, true, dLdx, dLdWx, dLdWr, dLdb, + dLdhI, dLdcI, dLdWp); + } else if (directionMode == 1) { // backward + + helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, + dLdcL, params, false, dLdx, dLdWx, dLdWr, dLdb, + dLdhI, dLdcI, dLdWp); + } else { // bidirectional + + NDArray WxFwd = (*Wx)({0, 1, 0, 0, 0, 0}); + NDArray WxBwd = (*Wx)({1, 2, 0, 0, 0, 0}); + NDArray dLdWxFwd = (*dLdWx)({0, 1, 0, 0, 0, 0}); + NDArray dLdWxBwd = (*dLdWx)({1, 2, 0, 0, 0, 0}); + + NDArray WrFwd = (*Wr)({0, 1, 0, 0, 0, 0}); + NDArray WrBwd = (*Wr)({1, 2, 0, 0, 0, 0}); + NDArray dLdWrFwd = (*dLdWr)({0, 1, 0, 0, 0, 0}); + NDArray dLdWrBwd = (*dLdWr)({1, 2, 0, 0, 0, 0}); + + NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), + *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr), + *dLdhFwd(nullptr), *dLdhBwd(nullptr), *dLdhLFwd(nullptr), + *dLdhLBwd(nullptr), *dLdcLFwd(nullptr), *dLdcLBwd(nullptr), + *dLdWpFwd(nullptr), *dLdWpBwd(nullptr), *dLdbFwd(nullptr), + *dLdbBwd(nullptr), *dLdhIFwd(nullptr), *dLdhIBwd(nullptr), + *dLdcIFwd(nullptr), *dLdcIBwd(nullptr); + + if (Wp) { + WpFwd = new NDArray((*Wp)({0, 1, 0, 0})); + WpBwd = new NDArray((*Wp)({1, 2, 0, 0})); + dLdWpFwd = new NDArray((*dLdWp)({0, 1, 0, 0})); + dLdWpBwd = new NDArray((*dLdWp)({1, 2, 0, 0})); } - else { // bidirectional - // Wx validation - if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - // initial output validation - if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - // initial cell validation - if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - // peephole weights validation - if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); - // gradient vs. output at last time step validation - if(dLdhL != nullptr && (dLdhL->rankOf() != 3 || dLdhL->sizeAt(0) != 2 || dLdhL->sizeAt(1) != bS || dLdhL->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str()); - // gradient vs. cell state at last time step validation - if(dLdcL != nullptr && (dLdcL->rankOf() != 3 || dLdcL->sizeAt(0) != 2 || dLdcL->sizeAt(1) != bS || dLdcL->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str()); + if (b) { + bFwd = new NDArray((*b)({0, 1, 0, 0})); + bBwd = new NDArray((*b)({1, 2, 0, 0})); + dLdbFwd = new NDArray((*dLdb)({0, 1, 0, 0})); + dLdbBwd = new NDArray((*dLdb)({1, 2, 0, 0})); } - - // gradient vs. output validation - if(dLdh) { - int factor = directionMode <= 2 ? 1 : 2; - std::vector expdLdhShape; - if(dataFormat == 0) expdLdhShape = std::vector{sL, bS, factor*nOut}; - else if(dataFormat == 1) expdLdhShape = std::vector{bS, sL, factor*nOut}; - else if(dataFormat == 2) expdLdhShape = std::vector{bS, factor*nOut, sL}; - else expdLdhShape = std::vector{sL, 2, bS, nOut}; - REQUIRE_TRUE(dLdh->isSameShape(expdLdhShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of gradient vs. output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expdLdhShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); + if (hI) { + hIFwd = new NDArray((*hI)({0, 1, 0, 0, 0, 0})); + hIBwd = new NDArray((*hI)({1, 2, 0, 0, 0, 0})); + dLdhIFwd = new NDArray((*dLdhI)({0, 1, 0, 0, 0, 0})); + dLdhIBwd = new NDArray((*dLdhI)({1, 2, 0, 0, 0, 0})); } - - std::vector params = {static_cast(dataFormat), static_cast(directionMode), static_cast(cellClip), - static_cast(gateAct), static_cast(gateAlpha), static_cast(gateBeta), - static_cast(cellAct), static_cast(cellAlpha), static_cast(cellBeta), - static_cast(outAct), static_cast(outAlpha), static_cast(outBeta)}; - - if(directionMode == 0) { // forward - - helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, true, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp); + if (cI) { + cIFwd = new NDArray((*cI)({0, 1, 0, 0, 0, 0})); + cIBwd = new NDArray((*cI)({1, 2, 0, 0, 0, 0})); + dLdcIFwd = new NDArray((*dLdcI)({0, 1, 0, 0, 0, 0})); + dLdcIBwd = new NDArray((*dLdcI)({1, 2, 0, 0, 0, 0})); + } + if (dLdhL) { + dLdhLFwd = new NDArray((*dLdhL)({0, 1, 0, 0, 0, 0})); + dLdhLBwd = new NDArray((*dLdhL)({1, 2, 0, 0, 0, 0})); + } + if (dLdcL) { + dLdcLFwd = new NDArray((*dLdcL)({0, 1, 0, 0, 0, 0})); + dLdcLBwd = new NDArray((*dLdcL)({1, 2, 0, 0, 0, 0})); } - else if(directionMode == 1) { // backward - helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, false, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp); + if (dLdh) { + if (directionMode == 2) { // sum + dLdhFwd = dLdh; + dLdhBwd = dLdh; + } else if (directionMode == 3) { // concat + dLdhFwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0, 0, 0, 0, 0, nOut}) + : (*dLdh)({0, 0, 0, nOut, 0, 0})); + dLdhBwd = new NDArray(dataFormat <= 1 + ? (*dLdh)({0, 0, 0, 0, nOut, 2 * nOut}) + : (*dLdh)({0, 0, nOut, 2 * nOut, 0, 0})); + } else { // directionMode == 4 + dLdhFwd = new NDArray((*dLdh)({0, 0, 0, 1, 0, 0, 0, 0})); + dLdhBwd = new NDArray((*dLdh)({0, 0, 1, 2, 0, 0, 0, 0})); + } } - else { // bidirectional - - NDArray WxFwd = (*Wx)({0,1, 0,0, 0,0}); - NDArray WxBwd = (*Wx)({1,2, 0,0, 0,0}); - NDArray dLdWxFwd = (*dLdWx)({0,1, 0,0, 0,0}); - NDArray dLdWxBwd = (*dLdWx)({1,2, 0,0, 0,0}); - - NDArray WrFwd = (*Wr)({0,1, 0,0, 0,0}); - NDArray WrBwd = (*Wr)({1,2, 0,0, 0,0}); - NDArray dLdWrFwd = (*dLdWr)({0,1, 0,0, 0,0}); - NDArray dLdWrBwd = (*dLdWr)({1,2, 0,0, 0,0}); - - NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr), - *dLdhFwd(nullptr), *dLdhBwd(nullptr), *dLdhLFwd(nullptr), *dLdhLBwd(nullptr), *dLdcLFwd(nullptr), *dLdcLBwd(nullptr), - *dLdWpFwd(nullptr), *dLdWpBwd(nullptr), *dLdbFwd(nullptr), *dLdbBwd(nullptr), - *dLdhIFwd(nullptr), *dLdhIBwd(nullptr), *dLdcIFwd(nullptr), *dLdcIBwd(nullptr); - - if(Wp) { - WpFwd = new NDArray((*Wp)({0,1, 0,0})); - WpBwd = new NDArray((*Wp)({1,2, 0,0})); - dLdWpFwd = new NDArray((*dLdWp)({0,1, 0,0})); - dLdWpBwd = new NDArray((*dLdWp)({1,2, 0,0})); - } - if(b) { - bFwd = new NDArray((*b)({0,1, 0,0})); - bBwd = new NDArray((*b)({1,2, 0,0})); - dLdbFwd = new NDArray((*dLdb)({0,1, 0,0})); - dLdbBwd = new NDArray((*dLdb)({1,2, 0,0})); - } - if(hI) { - hIFwd = new NDArray((*hI)({0,1, 0,0, 0,0})); - hIBwd = new NDArray((*hI)({1,2, 0,0, 0,0})); - dLdhIFwd = new NDArray((*dLdhI)({0,1, 0,0, 0,0})); - dLdhIBwd = new NDArray((*dLdhI)({1,2, 0,0, 0,0})); - } - if(cI) { - cIFwd = new NDArray((*cI)({0,1, 0,0, 0,0})); - cIBwd = new NDArray((*cI)({1,2, 0,0, 0,0})); - dLdcIFwd = new NDArray((*dLdcI)({0,1, 0,0, 0,0})); - dLdcIBwd = new NDArray((*dLdcI)({1,2, 0,0, 0,0})); - } - if(dLdhL) { - dLdhLFwd = new NDArray((*dLdhL)({0,1, 0,0, 0,0})); - dLdhLBwd = new NDArray((*dLdhL)({1,2, 0,0, 0,0})); - } - if(dLdcL) { - dLdcLFwd = new NDArray((*dLdcL)({0,1, 0,0, 0,0})); - dLdcLBwd = new NDArray((*dLdcL)({1,2, 0,0, 0,0})); - } - - if(dLdh) { - if(directionMode == 2) { // sum - dLdhFwd = dLdh; - dLdhBwd = dLdh; - } - else if(directionMode == 3) { // concat - dLdhFwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, 0,nOut}) : (*dLdh)({0,0, 0,nOut, 0,0})); - dLdhBwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, nOut,2*nOut}) : (*dLdh)({0,0, nOut,2*nOut, 0,0})); - } - else { // directionMode == 4 - dLdhFwd = new NDArray((*dLdh)({0,0, 0,1, 0,0, 0,0})); - dLdhBwd = new NDArray((*dLdh)({0,0, 1,2, 0,0, 0,0})); - } - } - - NDArray dLdxBwd = dLdx->ulike(); - - // FIXME - following two calls are independent and may run in different streams - helpers::lstmLayerTimeLoopBp(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, dLdhFwd, dLdhLFwd, dLdcLFwd, params, true, dLdx, &dLdWxFwd, &dLdWrFwd, dLdbFwd, dLdhIFwd, dLdcIFwd, dLdWpFwd); - helpers::lstmLayerTimeLoopBp(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, dLdhBwd, dLdhLBwd, dLdcLBwd, params, false, &dLdxBwd, &dLdWxBwd, &dLdWrBwd, dLdbBwd, dLdhIBwd, dLdcIBwd, dLdWpBwd); - - *dLdx += dLdxBwd; - - delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; delete cIBwd; - delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd; - delete dLdWpFwd; delete dLdWpBwd; delete dLdbFwd; delete dLdbBwd; - delete dLdhIFwd; delete dLdhIBwd; delete dLdcIFwd; delete dLdcIBwd; - - if(!(dLdh && directionMode == 2)) { delete dLdhFwd; delete dLdhBwd; } + + NDArray dLdxBwd = dLdx->ulike(); + + // FIXME - following two calls are independent and may run in different + // streams + helpers::lstmLayerTimeLoopBp(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, + WpFwd, dLdhFwd, dLdhLFwd, dLdcLFwd, params, + true, dLdx, &dLdWxFwd, &dLdWrFwd, dLdbFwd, + dLdhIFwd, dLdcIFwd, dLdWpFwd); + helpers::lstmLayerTimeLoopBp(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, + WpBwd, dLdhBwd, dLdhLBwd, dLdcLBwd, params, + false, &dLdxBwd, &dLdWxBwd, &dLdWrBwd, dLdbBwd, + dLdhIBwd, dLdcIBwd, dLdWpBwd); + + *dLdx += dLdxBwd; + + delete WpFwd; + delete WpBwd; + delete bFwd; + delete bBwd; + delete hIFwd; + delete hIBwd; + delete cIFwd; + delete cIBwd; + delete dLdhLFwd; + delete dLdhLBwd; + delete dLdcLFwd; + delete dLdcLBwd; + delete dLdWpFwd; + delete dLdWpBwd; + delete dLdbFwd; + delete dLdbBwd; + delete dLdhIFwd; + delete dLdhIBwd; + delete dLdcIFwd; + delete dLdcIBwd; + + if (!(dLdh && directionMode == 2)) { + delete dLdhFwd; + delete dLdhBwd; } + } - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(lstmLayer_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } DECLARE_SHAPE_FN(lstmLayer_bp) { - - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided - const auto hasInitH = B_ARG(2); // indicates whether initial output is provided - const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided - const auto hasPH = B_ARG(4); // indicates whether peephole connections are present - - int count = 3; - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector - const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output - const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state - const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights - - auto outShapes = SHAPELIST(x->shapeInfo(), Wx->shapeInfo(), Wr->shapeInfo()); - - if(b != nullptr) - outShapes->push_back(b->shapeInfo()); - if(seqLen != nullptr) - outShapes->push_back(seqLen->shapeInfo()); - if(hI != nullptr) - outShapes->push_back(hI->shapeInfo()); - if(cI != nullptr) - outShapes->push_back(cI->shapeInfo()); - if(Wp != nullptr) - outShapes->push_back(Wp->shapeInfo()); - - return outShapes; + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = + B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = + B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = + B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = + B_ARG(4); // indicates whether peephole connections are present + + int count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = + hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = + hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = + hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = + hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + + auto outShapes = SHAPELIST(x->shapeInfo(), Wx->shapeInfo(), Wr->shapeInfo()); + + if (b != nullptr) outShapes->push_back(b->shapeInfo()); + if (seqLen != nullptr) outShapes->push_back(seqLen->shapeInfo()); + if (hI != nullptr) outShapes->push_back(hI->shapeInfo()); + if (cI != nullptr) outShapes->push_back(cI->shapeInfo()); + if (Wp != nullptr) outShapes->push_back(Wp->shapeInfo()); + + return outShapes; } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp index 645541d6b25a..9775b68cc34f 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp @@ -22,318 +22,408 @@ #if NOT_EXCLUDED(OP_lstmLayerCell) #include -#include +#include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmLayerCell, 5, 2, false, 1, 3) { - - // equations (no peephole connections) - // it = σ(Wxi * xt + Wri * ht-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't - // ot = σ(Wxo * xt + Wro * ht-1 + bo) - // ht = ot ◦ tanh(ct) - - // equations (peephole connections are present) - // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = clip(ft ◦ ct-1 + it ◦ c't) - // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) - // ht = ot ◦ tanh(ct) - - // notations: - // bS - batch size - // nIn - input size - // nOut - output size (hidden size) - - // INPUTS: - // input x: [bS, nIn] or [nIn] - // input weights Wx: [nIn, 4*nOut] - // recurrent weights Wr: [nOut, 4*nOut] - // initial (previous) output hI: [bS, nOut] or [nOut] - // initial (previous) cell state cI: [bS, nOut] or [nOut] - // biases b (optional): [4*nOut] - // peephole weights Wp (optional): [3*nOut] - - // OUTPUTS: - // current output h: [bS, nOut] or [nOut] - // current cell state c: [bS, nOut] or [nOut] - - // !!! dimension 4*nOut implies order it, ft, c't, ot - // !!! dimension 3*nOut implies order it, ft, ot - - // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus - const auto gateAct = INT_ARG(0); // activation for input (i), forget (f) and output (o) gates - const auto cellAct = INT_ARG(1); // activation for cell state (c) - const auto outAct = INT_ARG(2); // activation for output (h) - - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasPH = B_ARG(1); // indicates whether peephole connections are present - - const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; - const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; - const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; - const auto gateActHasBeta = gateAct == 3 || gateAct == 6; - const auto cellActHasBeta = cellAct == 3 || cellAct == 6; - const auto outActHasBeta = outAct == 3 || outAct == 6; - - uint count = 1; - const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping - const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; - const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; - const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; - const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; - const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; - const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; - - count = 3; - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto hI = INPUT_VARIABLE(count++); // initial output - const auto cI = INPUT_VARIABLE(count++); // initial cell state - const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr; // peephole weights - - REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL operation: cell clipping value should be nonnegative (>=0) !"); - - auto h = OUTPUT_VARIABLE(0); - auto c = OUTPUT_VARIABLE(1); - - // evaluate dimensions - const Nd4jLong bS = x->rankOf() == 1 ? 0 : x->sizeAt(0); - const Nd4jLong nIn = x->sizeAt(-1); - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - // inputs validations - // Wx validation - if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // initial output/cell validation - std::vector exphIcIShape = x->rankOf() == 1 ? std::vector{nOut} : std::vector{bS, nOut}; - REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - // peephole weights validation - if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); - - std::vector params = {static_cast(0)/*ignore*/, static_cast(0)/*ignore*/, static_cast(cellClip), - static_cast(gateAct), static_cast(gateAlpha), static_cast(gateBeta), - static_cast(cellAct), static_cast(cellAlpha), static_cast(cellBeta), - static_cast(outAct), static_cast(outAlpha), static_cast(outBeta)}; - - helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, h, c); - - return Status::OK(); + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = clip(ft ◦ ct-1 + it ◦ c't) + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + // input x: [bS, nIn] or [nIn] + // input weights Wx: [nIn, 4*nOut] + // recurrent weights Wr: [nOut, 4*nOut] + // initial (previous) output hI: [bS, nOut] or [nOut] + // initial (previous) cell state cI: [bS, nOut] or [nOut] + // biases b (optional): [4*nOut] + // peephole weights Wp (optional): [3*nOut] + + // OUTPUTS: + // current output h: [bS, nOut] or [nOut] + // current cell state c: [bS, nOut] or [nOut] + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, + // 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, + // 8=ELU, 9=softsign, 10=softplus + const auto gateAct = + INT_ARG(0); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(1); // activation for cell state (c) + const auto outAct = INT_ARG(2); // activation for output (h) + + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided + const auto hasPH = + B_ARG(1); // indicates whether peephole connections are present + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || + gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || + cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = + outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = + T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; + + count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count++); // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr; // peephole weights + + REQUIRE_TRUE(cellClip >= 0, 0, + "LSTM_LAYER_CELL operation: cell clipping value should be " + "nonnegative (>=0) !"); + + auto h = OUTPUT_VARIABLE(0); + auto c = OUTPUT_VARIABLE(1); + + // evaluate dimensions + const Nd4jLong bS = x->rankOf() == 1 ? 0 : x->sizeAt(0); + const Nd4jLong nIn = x->sizeAt(-1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + // Wx validation + if (Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL operation: wrong shape of input weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nIn, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4 * nOut) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL operation: wrong shape of recurrent weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nOut, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wr).c_str()); + // initial output/cell validation + std::vector exphIcIShape = x->rankOf() == 1 + ? std::vector{nOut} + : std::vector{bS, nOut}; + REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, + "LSTM_LAYER_CELL operation: wrong shape of initial output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(exphIcIShape).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, + "LSTM_LAYER_CELL operation: wrong shape of initial cell state, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(exphIcIShape).c_str(), + ShapeUtils::shapeAsString(cI).c_str()); + // biases validation + if (b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL operation: wrong shape of biases, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString({4 * nOut}).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + // peephole weights validation + if (Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL operation: wrong shape of peephole weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({3 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + + std::vector params = { + static_cast(0) /*ignore*/, static_cast(0) /*ignore*/, + static_cast(cellClip), static_cast(gateAct), + static_cast(gateAlpha), static_cast(gateBeta), + static_cast(cellAct), static_cast(cellAlpha), + static_cast(cellBeta), static_cast(outAct), + static_cast(outAlpha), static_cast(outBeta)}; + + helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, h, c); + + return Status::OK(); } DECLARE_TYPES(lstmLayerCell) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(lstmLayerCell) { + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + uint count = hasBiases ? 4 : 3; + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count); // initial cell state - uint count = hasBiases ? 4 : 3; - const auto hI = INPUT_VARIABLE(count++); // initial output - const auto cI = INPUT_VARIABLE(count); // initial cell state - - return new ShapeList({hI->shapeInfo(), cI->shapeInfo()}); + return new ShapeList({hI->shapeInfo(), cI->shapeInfo()}); } ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmLayerCellBp, 7, 5, false, 1, 3) { - - // equations (no peephole connections) - // it = σ(Wxi * xt + Wri * ht-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't - // ot = σ(Wxo * xt + Wro * ht-1 + bo) - // ht = ot ◦ tanh(ct) - - // equations (peephole connections are present) - // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = clip(ft ◦ ct-1 + it ◦ c't) - // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) - // ht = ot ◦ tanh(ct) - - // notations: - // bS - batch size - // nIn - input size - // nOut - output size (hidden size) - - // INPUTS: - // input x: [bS, nIn] or [nIn] - // input weights Wx: [nIn, 4*nOut] - // recurrent weights Wr: [nOut, 4*nOut] - // initial (previous) output hI: [bS, nOut] or [nOut] - // initial (previous) cell state cI: [bS, nOut] or [nOut] - // gradient wrt output dLdh: [bS, nOut] or [nOut] - // gradient wrt cell state dLdc: [bS, nOut] or [nOut] - // peephole weights Wp (optional): [3*nOut] - // biases b (optional): [4*nOut] - - // OUTPUTS: - // gradient wrt x dLdx: [bS, nIn] or [nIn] - // gradient wrt Wx dLdWx: [nIn, 4*nOut] - // gradient wrt Wr dLdWr: [nOut, 4*nOut] - // gradient wrt hI dLdhI: [bS, nOut] or [nOut] - // gradient wrt cI dLdcI: [bS, nOut] or [nOut] - // gradient wrt b dLdb (optional): [4*nOut] - // gradient wrt Wp dLdWp (optional): [3*nOut] - - - // !!! dimension 4*nOut implies order it, ft, c't, ot - // !!! dimension 3*nOut implies order it, ft, ot - - // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus - const auto gateAct = INT_ARG(0); // activation for input (i), forget (f) and output (o) gates - const auto cellAct = INT_ARG(1); // activation for cell state (c) - const auto outAct = INT_ARG(2); // activation for output (h) - - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasPH = B_ARG(1); // indicates whether peephole connections are present - - const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; - const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; - const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; - const auto gateActHasBeta = gateAct == 3 || gateAct == 6; - const auto cellActHasBeta = cellAct == 3 || cellAct == 6; - const auto outActHasBeta = outAct == 3 || outAct == 6; - - uint count = 1; - const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping - const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; - const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; - const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; - const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; - const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; - const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; - - count = 3; - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto hI = INPUT_VARIABLE(count++); // initial output - const auto cI = INPUT_VARIABLE(count++); // initial cell state - const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights - const auto dLdh = INPUT_VARIABLE(count); // gradient wrt output - - REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL_BP operation: cell clipping value should be nonnegative (>=0) !"); - - count = 3; - auto dLdx = OUTPUT_VARIABLE(0); - auto dLdWx = OUTPUT_VARIABLE(1); - auto dLdWr = OUTPUT_VARIABLE(2); - auto dLdb = hasBiases ? OUTPUT_VARIABLE(count++) : nullptr; - auto dLdhI = OUTPUT_VARIABLE(count++); - auto dLdcI = OUTPUT_VARIABLE(count++); - auto dLdWp = hasPH ? OUTPUT_VARIABLE(count) : nullptr; - - // evaluate dimensions - const Nd4jLong bS = x->rankOf() == 1 ? 0 : x->sizeAt(0); - const Nd4jLong nIn = x->sizeAt(-1); - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - // inputs validations - // Wx validation - if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // initial output/cell validation - std::vector exphIcIShape = x->rankOf() == 1 ? std::vector{nOut} : std::vector{bS, nOut}; - REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - REQUIRE_TRUE(dLdh->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdh gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - if(dLdb != nullptr && (dLdb->rankOf() != 1 || dLdb->sizeAt(0) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdb gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(dLdb).c_str()); - // peephole weights validation - if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); - if(dLdWp != nullptr && (dLdWp->rankOf() != 1 || dLdWp->sizeAt(0) != 3*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdWp gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(dLdWp).c_str()); - - - std::vector params = {static_cast(0)/*ignore*/, static_cast(0)/*ignore*/, static_cast(cellClip), - static_cast(gateAct), static_cast(gateAlpha), static_cast(gateBeta), - static_cast(cellAct), static_cast(cellAlpha), static_cast(cellBeta), - static_cast(outAct), static_cast(outAlpha), static_cast(outBeta)}; - - std::vector zShape = x->rankOf() == 1 ? std::vector({4*nOut}) : std::vector({bS, 4*nOut}); - - NDArray z(x->ordering(), zShape, x->dataType(), block.launchContext()); - NDArray a = z.ulike(); - NDArray h = cI->ulike(); - NDArray c = cI->ulike(); - - helpers::lstmLayerCell(x,Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c); - - helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp); - - return Status::OK(); + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = clip(ft ◦ ct-1 + it ◦ c't) + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + // input x: [bS, nIn] or [nIn] + // input weights Wx: [nIn, 4*nOut] + // recurrent weights Wr: [nOut, 4*nOut] + // initial (previous) output hI: [bS, nOut] or [nOut] + // initial (previous) cell state cI: [bS, nOut] or [nOut] + // gradient wrt output dLdh: [bS, nOut] or [nOut] + // gradient wrt cell state dLdc: [bS, nOut] or [nOut] + // peephole weights Wp (optional): [3*nOut] + // biases b (optional): [4*nOut] + + // OUTPUTS: + // gradient wrt x dLdx: [bS, nIn] or [nIn] + // gradient wrt Wx dLdWx: [nIn, 4*nOut] + // gradient wrt Wr dLdWr: [nOut, 4*nOut] + // gradient wrt hI dLdhI: [bS, nOut] or [nOut] + // gradient wrt cI dLdcI: [bS, nOut] or [nOut] + // gradient wrt b dLdb (optional): [4*nOut] + // gradient wrt Wp dLdWp (optional): [3*nOut] + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, + // 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, + // 8=ELU, 9=softsign, 10=softplus + const auto gateAct = + INT_ARG(0); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(1); // activation for cell state (c) + const auto outAct = INT_ARG(2); // activation for output (h) + + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided + const auto hasPH = + B_ARG(1); // indicates whether peephole connections are present + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || + gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || + cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = + outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = + T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; + + count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count++); // initial cell state + const auto Wp = + hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + const auto dLdh = INPUT_VARIABLE(count); // gradient wrt output + + REQUIRE_TRUE(cellClip >= 0, 0, + "LSTM_LAYER_CELL_BP operation: cell clipping value should be " + "nonnegative (>=0) !"); + + count = 3; + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdWx = OUTPUT_VARIABLE(1); + auto dLdWr = OUTPUT_VARIABLE(2); + auto dLdb = hasBiases ? OUTPUT_VARIABLE(count++) : nullptr; + auto dLdhI = OUTPUT_VARIABLE(count++); + auto dLdcI = OUTPUT_VARIABLE(count++); + auto dLdWp = hasPH ? OUTPUT_VARIABLE(count) : nullptr; + + // evaluate dimensions + const Nd4jLong bS = x->rankOf() == 1 ? 0 : x->sizeAt(0); + const Nd4jLong nIn = x->sizeAt(-1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + // Wx validation + if (Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of input weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nIn, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4 * nOut) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of recurrent " + "weights, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nOut, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wr).c_str()); + // initial output/cell validation + std::vector exphIcIShape = x->rankOf() == 1 + ? std::vector{nOut} + : std::vector{bS, nOut}; + REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of initial output, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(exphIcIShape).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of initial cell " + "state, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(exphIcIShape).c_str(), + ShapeUtils::shapeAsString(cI).c_str()); + REQUIRE_TRUE(dLdh->isSameShape(exphIcIShape), 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of dLdh gradient, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(exphIcIShape).c_str(), + ShapeUtils::shapeAsString(dLdh).c_str()); + // biases validation + if (b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of biases, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({4 * nOut}).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + if (dLdb != nullptr && (dLdb->rankOf() != 1 || dLdb->sizeAt(0) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of dLdb gradient, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({4 * nOut}).c_str(), + ShapeUtils::shapeAsString(dLdb).c_str()); + // peephole weights validation + if (Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of peephole " + "weights, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({3 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wp).c_str()); + if (dLdWp != nullptr && + (dLdWp->rankOf() != 1 || dLdWp->sizeAt(0) != 3 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_CELL_BP operation: wrong shape of dLdWp gradient, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({3 * nOut}).c_str(), + ShapeUtils::shapeAsString(dLdWp).c_str()); + + std::vector params = { + static_cast(0) /*ignore*/, static_cast(0) /*ignore*/, + static_cast(cellClip), static_cast(gateAct), + static_cast(gateAlpha), static_cast(gateBeta), + static_cast(cellAct), static_cast(cellAlpha), + static_cast(cellBeta), static_cast(outAct), + static_cast(outAlpha), static_cast(outBeta)}; + + std::vector zShape = x->rankOf() == 1 + ? std::vector({4 * nOut}) + : std::vector({bS, 4 * nOut}); + + NDArray z(x->ordering(), zShape, x->dataType(), block.launchContext()); + NDArray a = z.ulike(); + NDArray h = cI->ulike(); + NDArray c = cI->ulike(); + + helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c); + + helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, nullptr, &z, + &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, + dLdb, dLdWp); + + return Status::OK(); } DECLARE_TYPES(lstmLayerCellBp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(lstmLayerCellBp) { + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided + const auto hasPH = + B_ARG(1); // indicates whether peephole connections are present - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasPH = B_ARG(1); // indicates whether peephole connections are present - - uint count = 3; - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto hI = INPUT_VARIABLE(count++); // initial output - const auto cI = INPUT_VARIABLE(count++); // initial cell state - const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr; // peephole weights + uint count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count++); // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr; // peephole weights - auto shapes = SHAPELIST(x->shapeInfo(), Wx->shapeInfo(), Wr->shapeInfo()); + auto shapes = SHAPELIST(x->shapeInfo(), Wx->shapeInfo(), Wr->shapeInfo()); - if(b != nullptr) - shapes->push_back(b->shapeInfo()); + if (b != nullptr) shapes->push_back(b->shapeInfo()); - shapes->push_back(hI->shapeInfo()); - shapes->push_back(cI->shapeInfo()); + shapes->push_back(hI->shapeInfo()); + shapes->push_back(cI->shapeInfo()); - if(Wp != nullptr) - shapes->push_back(Wp->shapeInfo()); + if (Wp != nullptr) shapes->push_back(Wp->shapeInfo()); - return shapes; + return shapes; } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/sru.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/sru.cpp index 0e0da782d844..67cc78887579 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/sru.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/sru.cpp @@ -15,7 +15,8 @@ ******************************************************************************/ // -// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 [cs.CL] 12 Sep 2017 +// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 +// [cs.CL] 12 Sep 2017 // //@author Yurii Shyrma // @@ -23,563 +24,863 @@ #include #if NOT_EXCLUDED(OP_sru) -#include -#include #include #include +#include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sru, 5, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features - auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3*inSize x inSize] - auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [2*inSize] - auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0 - auto mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] - - auto h = OUTPUT_VARIABLE(0); // cell outputs, [bS x inSize x time] - auto c = OUTPUT_VARIABLE(1); // cell states, [bS x inSize x time] - - const int rank = x->rankOf(); // = 3 - const auto bS = x->sizeAt(0); - const auto inSize = x->sizeAt(1); - const auto time = x->sizeAt(2); - - // input shapes validation - REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf()); - REQUIRE_TRUE(b->rankOf() == 1, 0, "SRU operation: wrong rank of biases array, expected is %i, but got %i instead !", 1, b->rankOf()); - REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf()); - if(mask) - REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf()); - - const std::vector wCorrectShape = {3*inSize, inSize}; - const std::vector bCorrectShape = {2*inSize}; - const std::vector c0CorrectShape = {bS, inSize}; - - REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(w).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0).c_str()); - if(mask) - REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(mask).c_str()); - - // xm = x * mask - auto xm = x; - if(mask) { - xm = new NDArray(x->shapeInfo(), true, block.launchContext()); - x->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, *xm); - } - - // time loop - helpers::sruTimeLoop(block.launchContext(), xm, c0, w, b, h, c); - - if(mask) - delete xm; - - return Status::OK(); + auto x = INPUT_VARIABLE( + 0); // X, input 3d tensor [bS x inSize x time], time - number of time + // steps, bS - batch size, inSize - number of features + auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3*inSize x inSize] + auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [2*inSize] + auto c0 = INPUT_VARIABLE( + 3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0 + auto mask = + block.width() > 4 + ? INPUT_VARIABLE(4) + : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] + + auto h = OUTPUT_VARIABLE(0); // cell outputs, [bS x inSize x time] + auto c = OUTPUT_VARIABLE(1); // cell states, [bS x inSize x time] + + const int rank = x->rankOf(); // = 3 + const auto bS = x->sizeAt(0); + const auto inSize = x->sizeAt(1); + const auto time = x->sizeAt(2); + + // input shapes validation + REQUIRE_TRUE(w->rankOf() == rank - 1, 0, + "SRU operation: wrong rank of weights array, expected is %i, " + "but got %i instead !", + rank - 1, w->rankOf()); + REQUIRE_TRUE(b->rankOf() == 1, 0, + "SRU operation: wrong rank of biases array, expected is %i, " + "but got %i instead !", + 1, b->rankOf()); + REQUIRE_TRUE(c0->rankOf() == rank - 1, 0, + "SRU operation: wrong rank of initial state array, expected is " + "%i, but got %i instead !", + rank - 1, c0->rankOf()); + if (mask) + REQUIRE_TRUE(mask->rankOf() == rank - 1, 0, + "SRU operation: wrong rank of mask array, expected is %i, but " + "got %i instead !", + rank - 1, mask->rankOf()); + + const std::vector wCorrectShape = {3 * inSize, inSize}; + const std::vector bCorrectShape = {2 * inSize}; + const std::vector c0CorrectShape = {bS, inSize}; + + REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, + "SRU operation: wrong shape of weights array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(w).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "SRU operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, + "SRU operation: wrong shape of initial state array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(c0).c_str()); + if (mask) + REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, + "SRU operation: wrong shape of mask array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(mask).c_str()); + + // xm = x * mask + auto xm = x; + if (mask) { + xm = new NDArray(x->shapeInfo(), true, block.launchContext()); + x->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, *xm); + } + + // time loop + helpers::sruTimeLoop(block.launchContext(), xm, c0, w, b, h, c); + + if (mask) delete xm; + + return Status::OK(); } - DECLARE_TYPES(sru) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(sru) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(sru) { - - auto xShapeInfo = inputShape->at(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features - auto wShapeInfo = inputShape->at(1); // W, 2d tensor of weights [3*inSize x inSize] - auto bShapeInfo = inputShape->at(2); // B, row of biases with twice length [2*inSize] - auto c0ShapeInfo = inputShape->at(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0 - auto maskShapeInfo = block.width() > 4 ? inputShape->at(4) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] - - const int rank = xShapeInfo[0]; // = 3 - const int bS = xShapeInfo[1]; - const int inSize = xShapeInfo[2]; - const int time = xShapeInfo[3]; - - // input shapes validation - REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]); - REQUIRE_TRUE(bShapeInfo[0] == 1, 0, "SRU operation: wrong rank of biases array, expected is %i, but got %i instead !", 1, bShapeInfo[0]); - REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]); - if(maskShapeInfo) - REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]); - - const std::vector wCorrectShape = {3*inSize, inSize}; - const std::vector bCorrectShape = {2*inSize}; - const std::vector c0CorrectShape = {bS, inSize}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); - if(maskShapeInfo) - REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str()); - - Nd4jLong* newShapeInfo1 = nullptr; - ALLOCATE(newShapeInfo1, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x inSize x time] - - newShapeInfo1[0] = rank; - newShapeInfo1[1] = bS; - newShapeInfo1[2] = inSize; - newShapeInfo1[3] = time; - - ShapeUtils::updateStridesAndType(newShapeInfo1, xShapeInfo, shape::order(xShapeInfo)); - ShapeDescriptor descriptor(newShapeInfo1); - RELEASE(newShapeInfo1, block.workspace()); - auto result = ConstantShapeHelper::getInstance()->createShapeInfo(descriptor); - return SHAPELIST(result, result); + auto xShapeInfo = inputShape->at( + 0); // X, input 3d tensor [bS x inSize x time], time - number of time + // steps, bS - batch size, inSize - number of features + auto wShapeInfo = + inputShape->at(1); // W, 2d tensor of weights [3*inSize x inSize] + auto bShapeInfo = + inputShape->at(2); // B, row of biases with twice length [2*inSize] + auto c0ShapeInfo = inputShape->at( + 3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0 + auto maskShapeInfo = + block.width() > 4 + ? inputShape->at(4) + : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] + + const int rank = xShapeInfo[0]; // = 3 + const int bS = xShapeInfo[1]; + const int inSize = xShapeInfo[2]; + const int time = xShapeInfo[3]; + + // input shapes validation + REQUIRE_TRUE(wShapeInfo[0] == rank - 1, 0, + "SRU operation: wrong rank of weights array, expected is %i, " + "but got %i instead !", + rank - 1, wShapeInfo[0]); + REQUIRE_TRUE(bShapeInfo[0] == 1, 0, + "SRU operation: wrong rank of biases array, expected is %i, " + "but got %i instead !", + 1, bShapeInfo[0]); + REQUIRE_TRUE(c0ShapeInfo[0] == rank - 1, 0, + "SRU operation: wrong rank of initial state array, expected is " + "%i, but got %i instead !", + rank - 1, c0ShapeInfo[0]); + if (maskShapeInfo) + REQUIRE_TRUE(maskShapeInfo[0] == rank - 1, 0, + "SRU operation: wrong rank of mask array, expected is %i, but " + "got %i instead !", + rank - 1, maskShapeInfo[0]); + + const std::vector wCorrectShape = {3 * inSize, inSize}; + const std::vector bCorrectShape = {2 * inSize}; + const std::vector c0CorrectShape = {bS, inSize}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, + "SRU operation: wrong shape of weights array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(wShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, + "SRU operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, + "SRU operation: wrong shape of initial state array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); + if (maskShapeInfo) + REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, + "SRU operation: wrong shape of mask array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(maskShapeInfo).c_str()); + + Nd4jLong* newShapeInfo1 = nullptr; + ALLOCATE(newShapeInfo1, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [bS x inSize x time] + + newShapeInfo1[0] = rank; + newShapeInfo1[1] = bS; + newShapeInfo1[2] = inSize; + newShapeInfo1[3] = time; + + ShapeUtils::updateStridesAndType(newShapeInfo1, xShapeInfo, + shape::order(xShapeInfo)); + ShapeDescriptor descriptor(newShapeInfo1); + RELEASE(newShapeInfo1, block.workspace()); + auto result = ConstantShapeHelper::getInstance()->createShapeInfo(descriptor); + return SHAPELIST(result, result); } ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) { - auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features - auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K] - auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*K] - auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0 - auto c = INPUT_VARIABLE(4); // C, [bS x K x N] - auto inGradCt = INPUT_VARIABLE(5); // [bS x K] - auto inGradH = INPUT_VARIABLE(6); // [bS x K x N] - NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K] - - bool applyMask = false; - if (block.width() > 7) { - mask = INPUT_VARIABLE(7); - applyMask = true; - } - - auto gradX = OUTPUT_VARIABLE(0); // [bS x K x N] - auto gradW = OUTPUT_VARIABLE(1); // [bS x 3K x K] - auto gradB = OUTPUT_VARIABLE(2); // [1 x 2K] - auto gradInit = OUTPUT_VARIABLE(3); // [bS x K] - - const int bS = x->shapeOf()[0]; - const int K = x->shapeOf()[1]; - const int N = x->shapeOf()[2]; // N - number of time steps - - auto gradBias = NDArrayFactory::create_(x->ordering(), {bS, 2*K, N}, gradX->dataType(), block.launchContext()); - auto gradU = NDArrayFactory::create_(x->ordering(), {bS, 3*K, N}, gradX->dataType(), block.launchContext()); - auto gradHX = NDArrayFactory::create_(x->ordering(), {bS, K, N}, gradX->dataType(), block.launchContext()); - auto gct = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); - auto gradTanh = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); - auto gradCt = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); - auto ftMinus = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); - auto rtMinus = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); - auto temp1 = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); - auto temp2 = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); - - // x = x * mask - if(applyMask) - x->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, *x); // apply mask - // multiplication matrix wi = matmul(w,x), U = WX - auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.); // U [bS x 3K x N] - - auto wiZ = (*wi)({0,0, 0,K, 0,0}, true); // [bS x K x N] - auto wiF = (*wi)({0,0, K,2*K, 0,0}, true); // forget gate [bS x K x N] - auto wiR = (*wi)({0,0, 2*K,3*K, 0,0}, true); // reset gate [bS x K x N] - auto bF = (*b) ({0,0, 0,K }, true); // biases for forget gate [1 x K] - auto bR = (*b) ({0,0, K,2*K}, true); // biases for reset gate [1 x K] - auto gradBF = (*gradBias)({0,0, 0,K, 0,0}, true); // [bS x K x N] - auto gradBR = (*gradBias)({0,0, K,2*K, 0,0}, true); // [bS x K x N] - auto gradUZ = (*gradU) ({0,0, 0,K, 0,0}, true ); // [bS x K x N] - auto gradUF = (*gradU) ({0,0, K,2*K, 0,0}, true ); // [bS x K x N] - auto gradUR = (*gradU) ({0,0, 2*K,3*K, 0,0}, true ); // [bS x K x N] - - NDArray* ct_1 = nullptr; - - std::vector idx = {0,0, 0,0, 0,0}; - - for (int t = N-1; t >=0 ; --t) { - // initialization - idx[4] = t; - idx[5] = t + 1; - auto xt = (*x)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto zt = wiZ(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto ft = wiF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto rt = wiR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto ct = (*c)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto inGradHt = (*inGradH)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto gradBRt = gradBR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto gradBFt = gradBF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto gradHXt = (*gradHX)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto gradUZt = gradUZ(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto gradUFt = gradUF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - auto gradURt = gradUR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] - - if(t != 0) { - idx[4] = t - 1; - idx[5] = t; - ct_1 = new NDArray((*c)(idx)); // previous c_{t-1} - } - else - ct_1 = c0; - - ///////////////// forward - // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR) - ft.addRowVector(bF, ft); - rt.addRowVector(bR, rt); - ft.applyTransform(transform::Sigmoid, ft); - rt.applyTransform(transform::Sigmoid, rt); - - // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur ); - ct.applyTransform(transform::Tanh, *gct); - // ftMinus = 1-ft, rtMinus = 1-rt - ft.applyTransform(transform::OneMinus, *ftMinus); - rt.applyTransform(transform::OneMinus, *rtMinus); - - ///////////////// backward - // bR, *grad_brt_ptr = inGradHt * (g_ct - xt) * (1.0f - rt) * rt; - gct->applyPairwiseTransform(pairwise::Subtract, xt, *temp1); // temp1 = (g_ct - xt) - rtMinus->applyPairwiseTransform(pairwise::Multiply, rt, *temp2); // temp2 = (1.0f - rt) * rt; - temp1->applyPairwiseTransform(pairwise::Multiply, *temp2); // temp1 = (g_ct - xt) * (1.0f - rt) * rt; - inGradHt.applyPairwiseTransform(pairwise::Multiply, *temp1, gradBRt); // = inGradHt * (g_ct - xt) * (1.0f - rt) * rt; - - // bF, TODO - tanh - // gradTanh = (1.0f - g_ct * g_ct); - gct->applyPairwiseTransform(pairwise::Multiply, *gct, *gradTanh); // gradTanh = g_ct * g_ct - gradTanh->applyTransform(transform::OneMinus, *gradTanh); // gradTanh = (1.0f - g_ct * g_ct) - // gradCt = inGradHt * rt * gradTanh - rt.applyPairwiseTransform(pairwise::Multiply, *gradTanh, *gradCt); // gradCt = rt * gradTanh - inGradHt.applyPairwiseTransform(pairwise::Multiply, *gradCt, *gradCt); // gradCt = inGradHt * rt * gradTanh - // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft; - gradCt->applyPairwiseTransform(pairwise::Add, *inGradCt, *temp1); // temp1 = (gradCt + inGradCt) - ct_1->applyPairwiseTransform(pairwise::Subtract, zt, *temp2); // temp2 = (ct_1 - zt) - temp1->applyPairwiseTransform(pairwise::Multiply, *ftMinus, *temp1); // temp1 = (gradCt + inGradCt)*(1-ft) - temp1->applyPairwiseTransform(pairwise::Multiply, ft, *temp1); // temp1 = (gradCt + inGradCt)*(1-ft)*ft - temp1->applyPairwiseTransform(pairwise::Multiply, *temp2, gradBFt); // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft; - - // x_t (highway connection), gradHXt = inGradHt * (1.0f - rt); - inGradHt.applyPairwiseTransform(pairwise::Multiply, *rtMinus, gradHXt); - - // U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft); - rt.applyPairwiseTransform(pairwise::Multiply, *gradTanh, *temp1); // temp1 = rt * grad_tanh - inGradHt.applyPairwiseTransform(pairwise::Multiply, *temp1, *temp1); // temp1 = inGradHt * rt * grad_tanh - temp1->applyPairwiseTransform(pairwise::Add, *inGradCt, *temp1); // temp1 = inGradHt * rt * grad_tanh + inGradCt - temp1->applyPairwiseTransform(pairwise::Multiply, *ftMinus, gradUZt); // gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft); - gradUFt.assign(&gradBFt); - gradURt.assign(&gradBRt); - - // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft; - gradCt->applyPairwiseTransform(pairwise::Add, *inGradCt, *temp1); // temp1 = (gradCt + inGradCt) - temp1->applyPairwiseTransform(pairwise::Multiply, ft, *inGradCt); // inGradCt = (gradCt + inGradCt) * ft; - - if(t != 0) - delete ct_1; - } - - // gradInit - gradInit->assign(inGradCt); - - // gradX - auto weightsT = w->transpose(); // [K x 3K] - MmulHelper::mmul(&weightsT, gradU, gradX, 1., 0.); // [bS x K x N] - gradX->applyPairwiseTransform(pairwise::Add, *gradHX, *gradX); // + grad_highway_x - if(applyMask) - gradX->applyBroadcast(broadcast::Multiply, {0,1}, *mask, *gradX); // apply mask - - // gradB - auto temp3 = gradBias->reduceAlongDimension(reduce::Sum, {0,2}, false, true); // [1 x 2K] - gradB->assign(temp3); - - // gradW [bS x 3K x K] - x->permutei({0, 2, 1}); // [bS x N x K] - MmulHelper::mmul(gradU, x, gradW, 1., 0.); // [bS x 3K x K] - - delete gct; delete gradU; delete gradHX; - delete temp1; delete temp2; delete gradCt; delete wi; - delete gradTanh; delete ftMinus; delete rtMinus; delete gradBias; - - return Status::OK(); + auto x = + INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time + // steps, bS - batch size, K - number of features + auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K] + auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*K] + auto c0 = INPUT_VARIABLE( + 3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0 + auto c = INPUT_VARIABLE(4); // C, [bS x K x N] + auto inGradCt = INPUT_VARIABLE(5); // [bS x K] + auto inGradH = INPUT_VARIABLE(6); // [bS x K x N] + NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K] + + bool applyMask = false; + if (block.width() > 7) { + mask = INPUT_VARIABLE(7); + applyMask = true; + } + + auto gradX = OUTPUT_VARIABLE(0); // [bS x K x N] + auto gradW = OUTPUT_VARIABLE(1); // [bS x 3K x K] + auto gradB = OUTPUT_VARIABLE(2); // [1 x 2K] + auto gradInit = OUTPUT_VARIABLE(3); // [bS x K] + + const int bS = x->shapeOf()[0]; + const int K = x->shapeOf()[1]; + const int N = x->shapeOf()[2]; // N - number of time steps + + auto gradBias = NDArrayFactory::create_( + x->ordering(), {bS, 2 * K, N}, gradX->dataType(), block.launchContext()); + auto gradU = NDArrayFactory::create_( + x->ordering(), {bS, 3 * K, N}, gradX->dataType(), block.launchContext()); + auto gradHX = NDArrayFactory::create_( + x->ordering(), {bS, K, N}, gradX->dataType(), block.launchContext()); + auto gct = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), + block.launchContext()); + auto gradTanh = NDArrayFactory::create_( + c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); + auto gradCt = NDArrayFactory::create_( + c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); + auto ftMinus = NDArrayFactory::create_( + c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); + auto rtMinus = NDArrayFactory::create_( + c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); + auto temp1 = NDArrayFactory::create_( + c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); + auto temp2 = NDArrayFactory::create_( + c->ordering(), {bS, K}, gradX->dataType(), block.launchContext()); + + // x = x * mask + if (applyMask) + x->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, *x); // apply mask + // multiplication matrix wi = matmul(w,x), U = WX + auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.); // U [bS x 3K x N] + + auto wiZ = (*wi)({0, 0, 0, K, 0, 0}, true); // [bS x K x N] + auto wiF = (*wi)({0, 0, K, 2 * K, 0, 0}, true); // forget gate [bS x K x N] + auto wiR = + (*wi)({0, 0, 2 * K, 3 * K, 0, 0}, true); // reset gate [bS x K x N] + auto bF = (*b)({0, 0, 0, K}, true); // biases for forget gate [1 x K] + auto bR = (*b)({0, 0, K, 2 * K}, true); // biases for reset gate [1 x K] + auto gradBF = (*gradBias)({0, 0, 0, K, 0, 0}, true); // [bS x K x N] + auto gradBR = (*gradBias)({0, 0, K, 2 * K, 0, 0}, true); // [bS x K x N] + auto gradUZ = (*gradU)({0, 0, 0, K, 0, 0}, true); // [bS x K x N] + auto gradUF = (*gradU)({0, 0, K, 2 * K, 0, 0}, true); // [bS x K x N] + auto gradUR = (*gradU)({0, 0, 2 * K, 3 * K, 0, 0}, true); // [bS x K x N] + + NDArray* ct_1 = nullptr; + + std::vector idx = {0, 0, 0, 0, 0, 0}; + + for (int t = N - 1; t >= 0; --t) { + // initialization + idx[4] = t; + idx[5] = t + 1; + auto xt = (*x)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto zt = wiZ(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto ft = wiF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto rt = wiR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto ct = (*c)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto inGradHt = + (*inGradH)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto gradBRt = gradBR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto gradBFt = gradBF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto gradHXt = (*gradHX)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto gradUZt = gradUZ(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto gradUFt = gradUF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + auto gradURt = gradUR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K] + + if (t != 0) { + idx[4] = t - 1; + idx[5] = t; + ct_1 = new NDArray((*c)(idx)); // previous c_{t-1} + } else + ct_1 = c0; + + ///////////////// forward + // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR) + ft.addRowVector(bF, ft); + rt.addRowVector(bR, rt); + ft.applyTransform(transform::Sigmoid, ft); + rt.applyTransform(transform::Sigmoid, rt); + + // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) + // ? reluf(cur) : cur ); + ct.applyTransform(transform::Tanh, *gct); + // ftMinus = 1-ft, rtMinus = 1-rt + ft.applyTransform(transform::OneMinus, *ftMinus); + rt.applyTransform(transform::OneMinus, *rtMinus); + + ///////////////// backward + // bR, *grad_brt_ptr = inGradHt * (g_ct - xt) * (1.0f - rt) * rt; + gct->applyPairwiseTransform(pairwise::Subtract, xt, + *temp1); // temp1 = (g_ct - xt) + rtMinus->applyPairwiseTransform(pairwise::Multiply, rt, + *temp2); // temp2 = (1.0f - rt) * rt; + temp1->applyPairwiseTransform( + pairwise::Multiply, *temp2); // temp1 = (g_ct - xt) * (1.0f - rt) * rt; + inGradHt.applyPairwiseTransform( + pairwise::Multiply, *temp1, + gradBRt); // = inGradHt * (g_ct - xt) * (1.0f - rt) * rt; + + // bF, TODO - tanh + // gradTanh = (1.0f - g_ct * g_ct); + gct->applyPairwiseTransform(pairwise::Multiply, *gct, + *gradTanh); // gradTanh = g_ct * g_ct + gradTanh->applyTransform(transform::OneMinus, + *gradTanh); // gradTanh = (1.0f - g_ct * g_ct) + // gradCt = inGradHt * rt * gradTanh + rt.applyPairwiseTransform(pairwise::Multiply, *gradTanh, + *gradCt); // gradCt = rt * gradTanh + inGradHt.applyPairwiseTransform( + pairwise::Multiply, *gradCt, + *gradCt); // gradCt = inGradHt * rt * gradTanh + // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft; + gradCt->applyPairwiseTransform(pairwise::Add, *inGradCt, + *temp1); // temp1 = (gradCt + inGradCt) + ct_1->applyPairwiseTransform(pairwise::Subtract, zt, + *temp2); // temp2 = (ct_1 - zt) + temp1->applyPairwiseTransform( + pairwise::Multiply, *ftMinus, + *temp1); // temp1 = (gradCt + inGradCt)*(1-ft) + temp1->applyPairwiseTransform( + pairwise::Multiply, ft, + *temp1); // temp1 = (gradCt + inGradCt)*(1-ft)*ft + temp1->applyPairwiseTransform(pairwise::Multiply, *temp2, + gradBFt); // gradBFt = (gradCt + inGradCt) * + // (ct_1 - zt) * (1 - ft) * ft; + + // x_t (highway connection), gradHXt = inGradHt * (1.0f - rt); + inGradHt.applyPairwiseTransform(pairwise::Multiply, *rtMinus, gradHXt); + + // U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft); + rt.applyPairwiseTransform(pairwise::Multiply, *gradTanh, + *temp1); // temp1 = rt * grad_tanh + inGradHt.applyPairwiseTransform( + pairwise::Multiply, *temp1, + *temp1); // temp1 = inGradHt * rt * grad_tanh + temp1->applyPairwiseTransform( + pairwise::Add, *inGradCt, + *temp1); // temp1 = inGradHt * rt * grad_tanh + inGradCt + temp1->applyPairwiseTransform( + pairwise::Multiply, *ftMinus, + gradUZt); // gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - + // ft); + gradUFt.assign(&gradBFt); + gradURt.assign(&gradBRt); + + // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft; + gradCt->applyPairwiseTransform(pairwise::Add, *inGradCt, + *temp1); // temp1 = (gradCt + inGradCt) + temp1->applyPairwiseTransform( + pairwise::Multiply, ft, + *inGradCt); // inGradCt = (gradCt + inGradCt) * ft; + + if (t != 0) delete ct_1; + } + + // gradInit + gradInit->assign(inGradCt); + + // gradX + auto weightsT = w->transpose(); // [K x 3K] + MmulHelper::mmul(&weightsT, gradU, gradX, 1., 0.); // [bS x K x N] + gradX->applyPairwiseTransform(pairwise::Add, *gradHX, + *gradX); // + grad_highway_x + if (applyMask) + gradX->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, + *gradX); // apply mask + + // gradB + auto temp3 = gradBias->reduceAlongDimension(reduce::Sum, {0, 2}, false, + true); // [1 x 2K] + gradB->assign(temp3); + + // gradW [bS x 3K x K] + x->permutei({0, 2, 1}); // [bS x N x K] + MmulHelper::mmul(gradU, x, gradW, 1., 0.); // [bS x 3K x K] + + delete gct; + delete gradU; + delete gradHX; + delete temp1; + delete temp2; + delete gradCt; + delete wi; + delete gradTanh; + delete ftMinus; + delete rtMinus; + delete gradBias; + + return Status::OK(); } - DECLARE_TYPES(sru_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(sru_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(sru_bp) { - - auto inShape = inputShape->at(0); // [bS x inSize x time] - auto bS = inShape[1]; - auto inSize = inShape[2]; - auto time = inShape[3]; - char order = (char)(inShape[9]); - - ShapeDescriptor descriptor1(ArrayOptions::dataType(inShape), order, {bS, inSize, time}); - ShapeDescriptor descriptor2(ArrayOptions::dataType(inShape), order, {bS, 3 * inSize, inSize}); - ShapeDescriptor descriptor3(ArrayOptions::dataType(inShape), order, {1, 2 * inSize}); - ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize}); - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4)); + auto inShape = inputShape->at(0); // [bS x inSize x time] + auto bS = inShape[1]; + auto inSize = inShape[2]; + auto time = inShape[3]; + char order = (char)(inShape[9]); + + ShapeDescriptor descriptor1(ArrayOptions::dataType(inShape), order, + {bS, inSize, time}); + ShapeDescriptor descriptor2(ArrayOptions::dataType(inShape), order, + {bS, 3 * inSize, inSize}); + ShapeDescriptor descriptor3(ArrayOptions::dataType(inShape), order, + {1, 2 * inSize}); + ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, + {bS, inSize}); + + return SHAPELIST( + ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), + ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), + ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), + ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4)); } - - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) { - - auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features - auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize] - auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 4*inSize] - auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0 - NDArray* mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize] - - auto ht = OUTPUT_VARIABLE(0); // h_t, [time x bS x 2*inSize] - auto ct = OUTPUT_VARIABLE(1); // c_t, [time x bS x 2*inSize] - - // input shapes validation - const int rank = x->rankOf(); - const Nd4jLong bS = x->sizeAt(1); - const Nd4jLong inSize = x->sizeAt(2) / 2; - - REQUIRE_TRUE(x->rankOf() == rank, 0, "SRU_BI operation: wrong rank of input array, expected is %i, but got %i instead !", rank, x->rankOf()); - REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf()); - REQUIRE_TRUE(b->rankOf() == 1, 0, "SRU_BI operation: wrong rank of biases array, expected is 1, but got %i instead !", b->rankOf()); - REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf()); - if(mask) - REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf()); - - const std::vector wCorrectShape = {2*inSize, 6*inSize}; - const std::vector bCorrectShape = {4*inSize}; - const std::vector c0CorrectShape = {bS, 2*inSize}; - - REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(w).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0).c_str()); - if(mask) - REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(mask).c_str()); - - helpers::sruBI(block.launchContext(), x, w, b, c0, mask, ht, ct); - - return Status::OK(); + auto x = INPUT_VARIABLE( + 0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time + // steps, bS - batch size, inSize - number of features + auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize] + auto b = + INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 4*inSize] + auto c0 = INPUT_VARIABLE( + 3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0 + NDArray* mask = + block.width() > 4 + ? INPUT_VARIABLE(4) + : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize] + + auto ht = OUTPUT_VARIABLE(0); // h_t, [time x bS x 2*inSize] + auto ct = OUTPUT_VARIABLE(1); // c_t, [time x bS x 2*inSize] + + // input shapes validation + const int rank = x->rankOf(); + const Nd4jLong bS = x->sizeAt(1); + const Nd4jLong inSize = x->sizeAt(2) / 2; + + REQUIRE_TRUE(x->rankOf() == rank, 0, + "SRU_BI operation: wrong rank of input array, expected is %i, " + "but got %i instead !", + rank, x->rankOf()); + REQUIRE_TRUE(w->rankOf() == rank - 1, 0, + "SRU_BI operation: wrong rank of weights array, expected is %i, " + "but got %i instead !", + rank - 1, w->rankOf()); + REQUIRE_TRUE(b->rankOf() == 1, 0, + "SRU_BI operation: wrong rank of biases array, expected is 1, " + "but got %i instead !", + b->rankOf()); + REQUIRE_TRUE(c0->rankOf() == rank - 1, 0, + "SRU_BI operation: wrong rank of initial state array, expected " + "is %i, but got %i instead !", + rank - 1, c0->rankOf()); + if (mask) + REQUIRE_TRUE(mask->rankOf() == rank - 1, 0, + "SRU_BI operation: wrong rank of mask array, expected is %i, " + "but got %i instead !", + rank - 1, mask->rankOf()); + + const std::vector wCorrectShape = {2 * inSize, 6 * inSize}; + const std::vector bCorrectShape = {4 * inSize}; + const std::vector c0CorrectShape = {bS, 2 * inSize}; + + REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, + "SRU_BI operation: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(w).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "SRU_BI operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, + "SRU_BI operation: wrong shape of initial state array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(c0).c_str()); + if (mask) + REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, + "SRU_BI operation: wrong shape of mask array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(mask).c_str()); + + helpers::sruBI(block.launchContext(), x, w, b, c0, mask, ht, ct); + + return Status::OK(); } - DECLARE_TYPES(sru_bi) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(sru_bi) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(sru_bi) { - - auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ] - auto wShapeInfo = inputShape->at(1); - auto bShapeInfo = inputShape->at(2); - auto c0ShapeInfo = inputShape->at(3); - auto maskShapeInfo = block.width() > 4 ? inputShape->at(4) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] - - const int rank = xShapeInfo[0]; // = 3 - const Nd4jLong time = xShapeInfo[1]; - const Nd4jLong bS = xShapeInfo[2]; - const Nd4jLong inSize = xShapeInfo[3] / 2; - - - // input shapes validation - REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]); - REQUIRE_TRUE(bShapeInfo[0] == 1, 0, "SRU_BI operation: wrong rank of biases array, expected is 1, but got %i instead !", bShapeInfo[0]); - REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]); - if(maskShapeInfo) - REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]); - - const std::vector wCorrectShape = {2*inSize, 6*inSize}; - const std::vector bCorrectShape = {4*inSize}; - const std::vector c0CorrectShape = {bS, 2*inSize}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); - if(maskShapeInfo) - REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str()); - - - char order = shape::order(xShapeInfo); - - ShapeDescriptor descriptor(ArrayOptions::dataType(xShapeInfo), order, {time, bS, 2 * inSize}); - auto result = ConstantShapeHelper::getInstance()->createShapeInfo(descriptor); - return SHAPELIST(result, result); + auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ] + auto wShapeInfo = inputShape->at(1); + auto bShapeInfo = inputShape->at(2); + auto c0ShapeInfo = inputShape->at(3); + auto maskShapeInfo = + block.width() > 4 + ? inputShape->at(4) + : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] + + const int rank = xShapeInfo[0]; // = 3 + const Nd4jLong time = xShapeInfo[1]; + const Nd4jLong bS = xShapeInfo[2]; + const Nd4jLong inSize = xShapeInfo[3] / 2; + + // input shapes validation + REQUIRE_TRUE(wShapeInfo[0] == rank - 1, 0, + "SRU_BI operation: wrong rank of weights array, expected is %i, " + "but got %i instead !", + rank - 1, wShapeInfo[0]); + REQUIRE_TRUE(bShapeInfo[0] == 1, 0, + "SRU_BI operation: wrong rank of biases array, expected is 1, " + "but got %i instead !", + bShapeInfo[0]); + REQUIRE_TRUE(c0ShapeInfo[0] == rank - 1, 0, + "SRU_BI operation: wrong rank of initial state array, expected " + "is %i, but got %i instead !", + rank - 1, c0ShapeInfo[0]); + if (maskShapeInfo) + REQUIRE_TRUE(maskShapeInfo[0] == rank - 1, 0, + "SRU_BI operation: wrong rank of mask array, expected is %i, " + "but got %i instead !", + rank - 1, maskShapeInfo[0]); + + const std::vector wCorrectShape = {2 * inSize, 6 * inSize}; + const std::vector bCorrectShape = {4 * inSize}; + const std::vector c0CorrectShape = {bS, 2 * inSize}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, + "SRU_BI operation: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(wShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, + "SRU_BI operation: wrong shape of biases array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, + "SRU_BI operation: wrong shape of initial state array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); + if (maskShapeInfo) + REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, + "SRU_BI operation: wrong shape of mask array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(maskShapeInfo).c_str()); + + char order = shape::order(xShapeInfo); + + ShapeDescriptor descriptor(ArrayOptions::dataType(xShapeInfo), order, + {time, bS, 2 * inSize}); + auto result = ConstantShapeHelper::getInstance()->createShapeInfo(descriptor); + return SHAPELIST(result, result); } - - DECLARE_TYPES(sru_bi_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(sru_bi_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) { - - auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features - auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize] - auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [4*inSize] - auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0 - auto ct = INPUT_VARIABLE(4); // C, [time x bS x 2*inSize] - auto inGradC0 = INPUT_VARIABLE(5); // [bS x 2*inSize] - auto inGradHt = INPUT_VARIABLE(6); // [time x bS x 2*inSize] - NDArray* mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize] - - // input shapes validation - const int rank = x->rankOf(); - const Nd4jLong time = x->sizeAt(0); - const Nd4jLong bS = x->sizeAt(1); - const Nd4jLong inSize = x->sizeAt(2) / 2; - - REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf()); - REQUIRE_TRUE(b->rankOf() == 1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is 1, but got %i instead !", b->rankOf()); - REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf()); - REQUIRE_TRUE(ct->rankOf() == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ct->rankOf()); - REQUIRE_TRUE(inGradC0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0->rankOf()); - REQUIRE_TRUE(inGradHt->rankOf() == rank, 0, "SRU_BI_BP operation: wrong rank of gradient ht, expected is %i, but got %i instead !", rank, inGradHt->rankOf()); - if(mask) - REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf()); - - const std::vector wCorrectShape = {2*inSize, 6*inSize}; - const std::vector bCorrectShape = {4*inSize}; - const std::vector c0CorrectShape = {bS, 2*inSize}; - const std::vector ctCorrectShape = {time, bS, 2*inSize}; - - REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(w).c_str()); - REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0).c_str()); - REQUIRE_TRUE(ct->isSameShape(ctCorrectShape), 0, "SRU_BI operation: wrong shape of state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(ctCorrectShape).c_str(), ShapeUtils::shapeAsString(ct).c_str()); - if(mask) - REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(mask).c_str()); - - auto gradI = OUTPUT_VARIABLE(0); // [time x bS x 2*inSize] - auto gradW = OUTPUT_VARIABLE(1); // [time x 2*inSize x 6*inSize] - auto gradB = OUTPUT_VARIABLE(2); // [1 x 4*inSize] - auto gradC0 = OUTPUT_VARIABLE(3); // [bS x 2*inSize] - - helpers::sruBIBP(block.launchContext(), x, w, b, c0, ct, inGradC0, inGradHt, mask, gradI, gradW, gradB, gradC0); - - return Status::OK(); + auto x = INPUT_VARIABLE( + 0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time + // steps, bS - batch size, inSize - number of features + auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize] + auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [4*inSize] + auto c0 = INPUT_VARIABLE( + 3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0 + auto ct = INPUT_VARIABLE(4); // C, [time x bS x 2*inSize] + auto inGradC0 = INPUT_VARIABLE(5); // [bS x 2*inSize] + auto inGradHt = INPUT_VARIABLE(6); // [time x bS x 2*inSize] + NDArray* mask = + block.width() > 7 + ? INPUT_VARIABLE(7) + : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize] + + // input shapes validation + const int rank = x->rankOf(); + const Nd4jLong time = x->sizeAt(0); + const Nd4jLong bS = x->sizeAt(1); + const Nd4jLong inSize = x->sizeAt(2) / 2; + + REQUIRE_TRUE(w->rankOf() == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of weights array, expected is " + "%i, but got %i instead !", + rank - 1, w->rankOf()); + REQUIRE_TRUE(b->rankOf() == 1, 0, + "SRU_BI_BP operation: wrong rank of biases array, expected is " + "1, but got %i instead !", + b->rankOf()); + REQUIRE_TRUE(c0->rankOf() == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of initial state array, " + "expected is %i, but got %i instead !", + rank - 1, c0->rankOf()); + REQUIRE_TRUE(ct->rankOf() == rank, 0, + "SRU_BI_BP operation: wrong rank of state array, expected is " + "%i, but got %i instead !", + rank, ct->rankOf()); + REQUIRE_TRUE(inGradC0->rankOf() == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of gradient c0, expected is " + "%i, but got %i instead !", + rank - 1, inGradC0->rankOf()); + REQUIRE_TRUE(inGradHt->rankOf() == rank, 0, + "SRU_BI_BP operation: wrong rank of gradient ht, expected is " + "%i, but got %i instead !", + rank, inGradHt->rankOf()); + if (mask) + REQUIRE_TRUE(mask->rankOf() == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of mask array, expected is " + "%i, but got %i instead !", + rank - 1, mask->rankOf()); + + const std::vector wCorrectShape = {2 * inSize, 6 * inSize}; + const std::vector bCorrectShape = {4 * inSize}; + const std::vector c0CorrectShape = {bS, 2 * inSize}; + const std::vector ctCorrectShape = {time, bS, 2 * inSize}; + + REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, + "SRU_BI operation: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(w).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, + "SRU_BI operation: wrong shape of biases array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, + "SRU_BI operation: wrong shape of initial state array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(c0).c_str()); + REQUIRE_TRUE(ct->isSameShape(ctCorrectShape), 0, + "SRU_BI operation: wrong shape of state array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(ctCorrectShape).c_str(), + ShapeUtils::shapeAsString(ct).c_str()); + if (mask) + REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, + "SRU_BI operation: wrong shape of mask array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(mask).c_str()); + + auto gradI = OUTPUT_VARIABLE(0); // [time x bS x 2*inSize] + auto gradW = OUTPUT_VARIABLE(1); // [time x 2*inSize x 6*inSize] + auto gradB = OUTPUT_VARIABLE(2); // [1 x 4*inSize] + auto gradC0 = OUTPUT_VARIABLE(3); // [bS x 2*inSize] + + helpers::sruBIBP(block.launchContext(), x, w, b, c0, ct, inGradC0, inGradHt, + mask, gradI, gradW, gradB, gradC0); + + return Status::OK(); } DECLARE_SHAPE_FN(sru_bi_bp) { - - auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ] - auto wShapeInfo = inputShape->at(1); - auto bShapeInfo = inputShape->at(2); - auto c0ShapeInfo = inputShape->at(3); - auto ctShapeInfo = inputShape->at(4); - auto inGradC0ShapeInfo = inputShape->at(5); - auto inGradHtShapeInfo = inputShape->at(6); - auto maskShapeInfo = block.width() > 7 ? inputShape->at(7) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] - - // input shapes validation - const int rank = xShapeInfo[0]; - const Nd4jLong time = xShapeInfo[1]; - const Nd4jLong bS = xShapeInfo[2]; - const Nd4jLong inSize = xShapeInfo[3] / 2; - - REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]); - REQUIRE_TRUE(bShapeInfo[0] == 1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is 1, but got %i instead !", bShapeInfo[0]); - REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo); - REQUIRE_TRUE(ctShapeInfo[0] == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ctShapeInfo); - REQUIRE_TRUE(inGradC0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0ShapeInfo[0]); - REQUIRE_TRUE(inGradHtShapeInfo[0] == rank, 0, "SRU_BI_BP operation: wrong rank of gradient ht, expected is %i, but got %i instead !", rank, inGradHtShapeInfo[0]); - if(maskShapeInfo) - REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]); - - const std::vector wCorrectShape = {2*inSize, 6*inSize}; - const std::vector bCorrectShape = {4*inSize}; - const std::vector c0CorrectShape = {bS, 2*inSize}; - const std::vector ctCorrectShape = {time, bS, 2*inSize}; - const std::vector inGradC0CorrectShape = {bS, 2*inSize}; - const std::vector inGradHtCorrectShape = {time, bS, 2*inSize}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(ctShapeInfo, ctCorrectShape), 0, "SRU_BI operation: wrong shape of state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(ctCorrectShape).c_str(), ShapeUtils::shapeAsString(ctShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(inGradC0ShapeInfo, inGradC0CorrectShape), 0, "SRU_BI operation: wrong shape of gradient c0 array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(inGradC0CorrectShape).c_str(), ShapeUtils::shapeAsString(inGradC0ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(inGradHtShapeInfo, inGradHtCorrectShape), 0, "SRU_BI operation: wrong shape of gradient ht array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(inGradHtCorrectShape).c_str(), ShapeUtils::shapeAsString(inGradHtShapeInfo).c_str()); - if(maskShapeInfo) - REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str()); - - const char order = shape::order(xShapeInfo); - - ShapeDescriptor descriptor1(ArrayOptions::dataType(xShapeInfo), order, {time, bS, 2 * inSize}); - ShapeDescriptor descriptor2(ArrayOptions::dataType(xShapeInfo), order, {time, 2 * inSize, 6 * inSize}); - ShapeDescriptor descriptor3(ArrayOptions::dataType(xShapeInfo), order, {4 * inSize}); - ShapeDescriptor descriptor4(ArrayOptions::dataType(xShapeInfo), order, {bS, 2 * inSize}); - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4)); + auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ] + auto wShapeInfo = inputShape->at(1); + auto bShapeInfo = inputShape->at(2); + auto c0ShapeInfo = inputShape->at(3); + auto ctShapeInfo = inputShape->at(4); + auto inGradC0ShapeInfo = inputShape->at(5); + auto inGradHtShapeInfo = inputShape->at(6); + auto maskShapeInfo = + block.width() > 7 + ? inputShape->at(7) + : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] + + // input shapes validation + const int rank = xShapeInfo[0]; + const Nd4jLong time = xShapeInfo[1]; + const Nd4jLong bS = xShapeInfo[2]; + const Nd4jLong inSize = xShapeInfo[3] / 2; + + REQUIRE_TRUE(wShapeInfo[0] == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of weights array, expected is " + "%i, but got %i instead !", + rank - 1, wShapeInfo[0]); + REQUIRE_TRUE(bShapeInfo[0] == 1, 0, + "SRU_BI_BP operation: wrong rank of biases array, expected is " + "1, but got %i instead !", + bShapeInfo[0]); + REQUIRE_TRUE(c0ShapeInfo[0] == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of initial state array, " + "expected is %i, but got %i instead !", + rank - 1, c0ShapeInfo); + REQUIRE_TRUE(ctShapeInfo[0] == rank, 0, + "SRU_BI_BP operation: wrong rank of state array, expected is " + "%i, but got %i instead !", + rank, ctShapeInfo); + REQUIRE_TRUE(inGradC0ShapeInfo[0] == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of gradient c0, expected is " + "%i, but got %i instead !", + rank - 1, inGradC0ShapeInfo[0]); + REQUIRE_TRUE(inGradHtShapeInfo[0] == rank, 0, + "SRU_BI_BP operation: wrong rank of gradient ht, expected is " + "%i, but got %i instead !", + rank, inGradHtShapeInfo[0]); + if (maskShapeInfo) + REQUIRE_TRUE(maskShapeInfo[0] == rank - 1, 0, + "SRU_BI_BP operation: wrong rank of mask array, expected is " + "%i, but got %i instead !", + rank - 1, maskShapeInfo[0]); + + const std::vector wCorrectShape = {2 * inSize, 6 * inSize}; + const std::vector bCorrectShape = {4 * inSize}; + const std::vector c0CorrectShape = {bS, 2 * inSize}; + const std::vector ctCorrectShape = {time, bS, 2 * inSize}; + const std::vector inGradC0CorrectShape = {bS, 2 * inSize}; + const std::vector inGradHtCorrectShape = {time, bS, 2 * inSize}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, + "SRU_BI operation: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(wCorrectShape).c_str(), + ShapeUtils::shapeAsString(wShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, + "SRU_BI operation: wrong shape of biases array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(bCorrectShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, + "SRU_BI operation: wrong shape of initial state array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(c0ShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(ctShapeInfo, ctCorrectShape), 0, + "SRU_BI operation: wrong shape of state array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(ctCorrectShape).c_str(), + ShapeUtils::shapeAsString(ctShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(inGradC0ShapeInfo, inGradC0CorrectShape), 0, + "SRU_BI operation: wrong shape of gradient c0 array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(inGradC0CorrectShape).c_str(), + ShapeUtils::shapeAsString(inGradC0ShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(inGradHtShapeInfo, inGradHtCorrectShape), 0, + "SRU_BI operation: wrong shape of gradient ht array, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(inGradHtCorrectShape).c_str(), + ShapeUtils::shapeAsString(inGradHtShapeInfo).c_str()); + if (maskShapeInfo) + REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, + "SRU_BI operation: wrong shape of mask array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(c0CorrectShape).c_str(), + ShapeUtils::shapeAsString(maskShapeInfo).c_str()); + + const char order = shape::order(xShapeInfo); + + ShapeDescriptor descriptor1(ArrayOptions::dataType(xShapeInfo), order, + {time, bS, 2 * inSize}); + ShapeDescriptor descriptor2(ArrayOptions::dataType(xShapeInfo), order, + {time, 2 * inSize, 6 * inSize}); + ShapeDescriptor descriptor3(ArrayOptions::dataType(xShapeInfo), order, + {4 * inSize}); + ShapeDescriptor descriptor4(ArrayOptions::dataType(xShapeInfo), order, + {bS, 2 * inSize}); + + return SHAPELIST( + ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), + ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), + ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), + ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4)); } -} -} +} // namespace ops +} // namespace sd #endif ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operations for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [3K x K] - * 2: row of biases with twice length [1 x 2K] - * 3: 2d tensor of previous cell state [bS x K] - * 4: optional, 2d tensor of dropout mask [bS x K] - * - * Output arrays: - * 0: 3d tensor of cell output [bS x K x N] - * 1: 3d tensor of cell state [bS x K x N] - */ - // #if NOT_EXCLUDED(OP_sru) - // DECLARE_CUSTOM_OP(sru_old, 5, 2, false, 0, 0); - - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [3K x K] - * 2: row of biases with twice length [1 x 2K] - * 3: 2d tensor of previous cell state [bS x K] - * 4: optional, 2d tensor of dropout mask [bS x K] - * - * Output arrays: - * 0: 3d tensor of cell output [bS x K x N] - * 1: 3d tensor of cell state [bS x K x N] - */ - // #if NOT_EXCLUDED(OP_sru_logic) - // DECLARE_CUSTOM_OP(sru_logic, 5, 2, false, 0, 0); - // #endif +/** + * Implementation of operations for Simple Recurrent Unit: "Training RNNs as + * Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - + * batch size, K - number of features 1: 2d tensor of weights [3K x K] 2: row of + * biases with twice length [1 x 2K] 3: 2d tensor of previous cell state [bS x + * K] 4: optional, 2d tensor of dropout mask [bS x K] + * + * Output arrays: + * 0: 3d tensor of cell output [bS x K x N] + * 1: 3d tensor of cell state [bS x K x N] + */ +// #if NOT_EXCLUDED(OP_sru) +// DECLARE_CUSTOM_OP(sru_old, 5, 2, false, 0, 0); +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast + * as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - + * batch size, K - number of features 1: 2d tensor of weights [3K x K] 2: row of + * biases with twice length [1 x 2K] 3: 2d tensor of previous cell state [bS x + * K] 4: optional, 2d tensor of dropout mask [bS x K] + * + * Output arrays: + * 0: 3d tensor of cell output [bS x K x N] + * 1: 3d tensor of cell state [bS x K x N] + */ +// #if NOT_EXCLUDED(OP_sru_logic) +// DECLARE_CUSTOM_OP(sru_logic, 5, 2, false, 0, 0); +// #endif ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for back propagation in Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [3K x K] - * 2: row of biases with twice length [1 x 2K] - * 3: 2d tensor of previous cell state [bS x K] - * 4: 3d tensor of cell state [bS x K x N] - * 5: 2d tensor of cell state gradients [bS x K] - * 6: 3d tensor of state output gradients [bS x K x N] - * 7: optional, 2d tensor of dropout mask [bS x K] - * - * Output arrays: - * 0: 3d tensor of input gradients [bS x K x N] - * 1: 3d tensor of weights gradients [bS x 3K x K] - * 2: 2d, row of biases gradients [1 x 2K] - * 3: 2d, tensor of state gradients [bS x K] - */ - // #if NOT_EXCLUDED(OP_sru_logic) - // DECLARE_CUSTOM_OP(sru_bp_logic,8, 4, true, 0, 0); - // #endif +/** + * Implementation of operation for back propagation in Simple Recurrent Unit: + * "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - + * batch size, K - number of features 1: 2d tensor of weights [3K x K] 2: row of + * biases with twice length [1 x 2K] 3: 2d tensor of previous cell state [bS x + * K] 4: 3d tensor of cell state [bS x K x N] 5: 2d tensor of cell state + * gradients [bS x K] 6: 3d tensor of state output gradients [bS x K x N] 7: + * optional, 2d tensor of dropout mask [bS x K] + * + * Output arrays: + * 0: 3d tensor of input gradients [bS x K x N] + * 1: 3d tensor of weights gradients [bS x 3K x K] + * 2: 2d, row of biases gradients [1 x 2K] + * 3: 2d, tensor of state gradients [bS x K] + */ +// #if NOT_EXCLUDED(OP_sru_logic) +// DECLARE_CUSTOM_OP(sru_bp_logic,8, 4, true, 0, 0); +// #endif // return 2d array evaluated though last dimension interval t1-t2 // static NDArray* timestep(const NDArray* arr, const int t1, const int t2) { // NDArray* result = new NDArray((*arr)({0,0, 0,0, t1,t2}, true)); -// result->reshapei(result->ordering(), {arr->shapeOf()[0], arr->shapeOf()[1]} ); +// result->reshapei(result->ordering(), {arr->shapeOf()[0], +// arr->shapeOf()[1]} ); // return result; // } @@ -587,11 +888,14 @@ DECLARE_SHAPE_FN(sru_bi_bp) { ///////////////////////////////////////////////////////////////////////// // CUSTOM_OP_IMPL(sru_logic, 5, 2, false, 0, 0) { -// auto input = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features -// auto weights = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K] -// auto bias = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*K] -// auto init = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0 -// NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K] +// auto input = INPUT_VARIABLE(0); // X, input 3d tensor +// [bS x K x N], N - number of time steps, bS - batch size, K - number of +// features auto weights = INPUT_VARIABLE(1); // W, 2d tensor +// of weights [3K x K] auto bias = INPUT_VARIABLE(2); // +// B, row of biases with twice length [1 x 2*K] auto init = +// INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state +// [bS x K] at time t=0 NDArray* mask = nullptr; // optional, 2d tensor +// of dropout mask [bS x K] // bool applyMask = false; // if (block.width() > 4) { @@ -602,13 +906,15 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // auto output = OUTPUT_VARIABLE(0); // h_t, [bS x K x N] // auto state = OUTPUT_VARIABLE(1); // c_t, [bS x K x N] -// const int bS = input->shapeOf()[0]; // bS - batch size -// const int K = input->shapeOf()[1]; // K - number of features -// const int N = input->shapeOf()[2]; // N - number of time steps +// const int bS = input->shapeOf()[0]; // bS - batch +// size const int K = input->shapeOf()[1]; // K - +// number of features const int N = input->shapeOf()[2]; // N - number +// of time steps -// const auto wi = mmul(*weights, *input); // U [bS x 3K x N] -// const auto bF = (*bias)({0,0, 0, K}); // biases for forget gate [1 x K] -// const auto bR = (*bias)({0,0, K,2*K}); // biases for reset gate [1 x K] +// const auto wi = mmul(*weights, *input); // U [bS x 3K +// x N] const auto bF = (*bias)({0,0, 0, K}); // +// biases for forget gate [1 x K] const auto bR = (*bias)({0,0, K,2*K}); // +// biases for reset gate [1 x K] // NDArray xt(input->dataType(), block.launchContext()); // NDArray zt(input->dataType(), block.launchContext()); @@ -616,23 +922,28 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // NDArray rt(input->dataType(), block.launchContext()); // NDArray ht(input->dataType(), block.launchContext()); // NDArray ct = *init; -// NDArray gct(state->ordering(), {bS, K}, input->dataType(), block.launchContext()); -// NDArray xmt = *input; +// NDArray gct(state->ordering(), {bS, K}, input->dataType(), +// block.launchContext()); NDArray xmt = *input; // // input = input * mask // if(applyMask) // xmt.applyBroadcast(broadcast::Multiply, {0, 1}, mask, &xmt, nullptr); // for (int t = 0; t < N; ++t) { -// xt = xmt({0,0, 0,0, t,t+1}); xt.reshapei(xt.ordering(), {bS, K}); // [bS x K x N] -> [bS x K x 1] -> [bS x K] -// zt = wi({0,0, 0, K, t,t+1}); zt.reshapei(zt.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x K x 1] -> [bS x K] -// ft = wi({0,0, K, 2*K, t,t+1}); ft.reshapei(ft.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x K x 1] -> [bS x K] -// rt = wi({0,0, 2*K,3*K, t,t+1}); rt.reshapei(rt.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x K x 1] -> [bS x K] +// xt = xmt({0,0, 0,0, t,t+1}); xt.reshapei(xt.ordering(), {bS, K}); +// // [bS x K x N] -> [bS x K x 1] -> [bS x K] zt = wi({0,0, 0, K, +// t,t+1}); zt.reshapei(zt.ordering(), {bS, K}); // [bS x 3K x N] +// -> [bS x K x 1] -> [bS x K] ft = wi({0,0, K, 2*K, t,t+1}); +// ft.reshapei(ft.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x +// K x 1] -> [bS x K] rt = wi({0,0, 2*K,3*K, t,t+1}); +// rt.reshapei(rt.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x +// K x 1] -> [bS x K] // ft = sigmoid_(ft + bF); // rt = sigmoid_(rt + bR); // ct = ft * (ct - zt) + zt; -// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur ); +// // TODO T val = (activation_type == 1) ? tanh(cur) : +// ((activation_type == 2) ? reluf(cur) : cur ); // ct.applyTransform(transform::Tanh, &gct); // ht = rt * (gct - xt) + xt; @@ -672,14 +983,16 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // return SHAPELIST(result, result); // } - // ////////////////////////////////////////////////////////////////////////// // CUSTOM_OP_IMPL(sru_old, 5, 2, false, 0, 0) { -// auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features -// auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x inSize] -// auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*inSize] -// auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0 -// NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x inSize] +// auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x +// inSize x time], time - number of time steps, bS - batch size, inSize - +// number of features auto w = INPUT_VARIABLE(1); // W, 2d +// tensor of weights [3K x inSize] auto b = INPUT_VARIABLE(2); // B, row +// of biases with twice length [1 x 2*inSize] auto c0 = +// INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state +// [bS x inSize] at time t=0 NDArray* mask = nullptr; // optional, 2d +// tensor of dropout mask [bS x inSize] // bool applyMask = false; // if (block.width() > 4) { @@ -688,35 +1001,44 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // } // auto h = OUTPUT_VARIABLE(0); // h_t, [bS x inSize x time] -// auto state = OUTPUT_VARIABLE(1); // c_t, [bS x inSize x time] +// auto state = OUTPUT_VARIABLE(1); // c_t, [bS x inSize x +// time] -// const int bS = x->shapeOf()[0]; // bS - batch size -// const int inSize = x->shapeOf()[1]; // inSize - number of features -// const int time = x->shapeOf()[2]; // time - number of time steps +// const int bS = x->shapeOf()[0]; // bS - batch +// size const int inSize = x->shapeOf()[1]; // +// inSize - number of features const int time = x->shapeOf()[2]; // +// time - number of time steps // // multiplication matrix = matmul(w,x) -// auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.); // U [bS x 3K x time] -// auto wiZ = (*wi)({0,0, 0,inSize, 0,0}, true); // [bS x inSize x time] -// auto wiF = (*wi)({0,0, inSize,2*inSize, 0,0}, true); // forget gate [bS x inSize x time] -// auto wiR = (*wi)({0,0, 2*inSize,3*inSize, 0,0}, true); // reset gate [bS x inSize x time] -// auto bF = (*b) ({0,0, 0,inSize }, true); // biases for forget gate [1 x inSize] -// auto bR = (*b) ({0,0, inSize,2*inSize}, true); // biases for reset gate [1 x inSize] - -// NDArray* xt(nullptr), *zt(nullptr), *ft(nullptr), *rt(nullptr), *ct(nullptr), *ht(nullptr); -// auto ct_1 = c0->dup(c0->ordering()); -// auto gct = NDArrayFactory::create_(state->ordering(), {bS, inSize}, state->dataType(), state->getContext()); -// auto xmt = x->dup(x->ordering()); +// auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.); // U [bS x +// 3K x time] auto wiZ = (*wi)({0,0, 0,inSize, 0,0}, true); // [bS +// x inSize x time] auto wiF = (*wi)({0,0, inSize,2*inSize, 0,0}, true); +// // forget gate [bS x inSize x time] auto wiR = (*wi)({0,0, +// 2*inSize,3*inSize, 0,0}, true); // reset gate [bS x inSize x time] +// auto bF = (*b) ({0,0, 0,inSize }, true); // biases +// for forget gate [1 x inSize] auto bR = (*b) ({0,0, inSize,2*inSize}, +// true); // biases for reset gate [1 x inSize] + +// NDArray* xt(nullptr), *zt(nullptr), *ft(nullptr), *rt(nullptr), +// *ct(nullptr), *ht(nullptr); auto ct_1 = c0->dup(c0->ordering()); auto gct +// = NDArrayFactory::create_(state->ordering(), {bS, inSize}, +// state->dataType(), state->getContext()); auto xmt = +// x->dup(x->ordering()); // // x = x * mask // if(applyMask) -// xmt->applyBroadcast(broadcast::Multiply, {0, 1}, mask, xmt, nullptr); // apply mask +// xmt->applyBroadcast(broadcast::Multiply, {0, 1}, mask, xmt, nullptr); +// // apply mask // for (int t = 0; t < time; ++t) { -// xt = timestep(xmt, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] -// zt = timestep(&wiZ, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] -// ft = timestep(&wiF, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] -// rt = timestep(&wiR, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] -// ct = timestep(state, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] -// ht = timestep(h, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] +// xt = timestep(xmt, t, t+1); // [bS x inSize x time] -> [bS x +// inSize x 1] -> [bS x inSize] zt = timestep(&wiZ, t, t+1); // +// [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] ft = +// timestep(&wiF, t, t+1); // [bS x inSize x time] -> [bS x +// inSize x 1] -> [bS x inSize] rt = timestep(&wiR, t, t+1); // +// [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] ct = +// timestep(state, t, t+1); // [bS x inSize x time] -> [bS x +// inSize x 1] -> [bS x inSize] ht = timestep(h, t, t+1); // +// [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize] // // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR) // ft->addRowVector(&bF, ft); @@ -728,7 +1050,8 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // ft->applyTransform(transform::OneMinus, ft); // ft->applyPairwiseTransform(pairwise::Multiply, *zt, nullptr); // ct->applyPairwiseTransform(pairwise::Add, *ft, nullptr); -// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur ); +// // TODO T val = (activation_type == 1) ? tanh(cur) : +// ((activation_type == 2) ? reluf(cur) : cur ); // ct->applyTransform(transform::Tanh, gct); // // ht = rt * gct + (1 - rt) * xt @@ -771,7 +1094,8 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // ShapeUtils::updateStridesAndType(newShapeInfo1, inShape, order); -// auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newShapeInfo1)); +// auto result = +// ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newShapeInfo1)); // RELEASE(newShapeInfo1, block.workspace()); // return SHAPELIST(result, result); // } @@ -785,90 +1109,128 @@ DECLARE_SHAPE_FN(sru_bi_bp) { ////////////////////////////////////////////////////////////////////////// // CUSTOM_OP_IMPL(sru_bp_logic, 8, 4, true, 0, 0) { -// auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features -// auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3*inSize x inSize] -// auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*inSize] -// auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0 -// auto c = INPUT_VARIABLE(4); // C, [bS x inSize x time] -// auto inGradCt = INPUT_VARIABLE(5); // [bS x inSize] -// auto inGradH = INPUT_VARIABLE(6); // [bS x inSize x time] -// auto mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] +// auto x = INPUT_VARIABLE(0); // +// X, input 3d tensor [bS x inSize x time], time - number of time steps, bS +// - batch size, inSize - number of features auto w = +// INPUT_VARIABLE(1); // W, 2d tensor of +// weights [3*inSize x inSize] auto b = INPUT_VARIABLE(2); // B, row +// of biases with twice length [1 x 2*inSize] auto c0 = +// INPUT_VARIABLE(3); // C_{0}, 2d tensor +// of initial state [bS x inSize] at time t=0 auto c = +// INPUT_VARIABLE(4); // C, [bS x inSize x +// time] auto inGradCt = INPUT_VARIABLE(5); // [bS x inSize] auto inGradH = +// INPUT_VARIABLE(6); // [bS x inSize x +// time] auto mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // +// optional, 2d tensor of dropout mask [bS x inSize] // auto gradX = OUTPUT_VARIABLE(0); // [bS x inSize x time] -// auto gradW = OUTPUT_VARIABLE(1); // [bS x 3*inSize x inSize] -// auto gradB = OUTPUT_VARIABLE(2); // [2*inSize] +// auto gradW = OUTPUT_VARIABLE(1); // [bS x 3*inSize x +// inSize] auto gradB = OUTPUT_VARIABLE(2); // [2*inSize] // auto gradInit = OUTPUT_VARIABLE(3); // [bS x inSize] // // input shapes validation // const int rank = 3; -// REQUIRE_TRUE(x->rankOf() == rank, 0, "SRU_BP operation: wrong rank of input array, expected is %i, but got %i instead !", rank, x->rankOf()); -// REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf()); -// REQUIRE_TRUE(b->rankOf() <= 2, 0, "SRU_BP operation: wrong rank of biases array, expected is <=2, but got %i instead !", b->rankOf()); -// REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf()); -// REQUIRE_TRUE(c->rankOf() == rank, 0, "SRU_BP operation: wrong rank of cell states array, expected is %i, but got %i instead !", rank, c->rankOf()); -// REQUIRE_TRUE(inGradCt->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of array of cell state gradient, expected is %i, but got %i instead !", rank-1, inGradCt->rankOf()); -// REQUIRE_TRUE(inGradH->rankOf() == rank, 0, "SRU_BP operation: wrong rank of array of cell outputs gradients, expected is %i, but got %i instead !", rank, inGradH->rankOf()); -// if(mask) -// REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf()); +// REQUIRE_TRUE(x->rankOf() == rank, 0, "SRU_BP operation: wrong rank of +// input array, expected is %i, but got %i instead !", rank, x->rankOf()); +// REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of +// weights array, expected is %i, but got %i instead !", rank-1, +// w->rankOf()); REQUIRE_TRUE(b->rankOf() <= 2, 0, "SRU_BP operation: +// wrong rank of biases array, expected is <=2, but got %i instead !", +// b->rankOf()); REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BP operation: +// wrong rank of initial state array, expected is %i, but got %i instead !", +// rank-1, c0->rankOf()); REQUIRE_TRUE(c->rankOf() == rank, 0, "SRU_BP +// operation: wrong rank of cell states array, expected is %i, but got %i +// instead !", rank, c->rankOf()); REQUIRE_TRUE(inGradCt->rankOf() == +// rank-1, 0, "SRU_BP operation: wrong rank of array of cell state gradient, +// expected is %i, but got %i instead !", rank-1, inGradCt->rankOf()); +// REQUIRE_TRUE(inGradH->rankOf() == rank, 0, "SRU_BP operation: wrong +// rank of array of cell outputs gradients, expected is %i, but got %i +// instead !", rank, inGradH->rankOf()); if(mask) +// REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BP operation: wrong +// rank of mask array, expected is %i, but got %i instead !", rank-1, +// mask->rankOf()); // const int bS = x->shapeOf()[0]; // const int inSize = x->shapeOf()[1]; -// const int time = x->shapeOf()[2]; // time - number of time steps +// const int time = x->shapeOf()[2]; // time - number +// of time steps // const std::string wShape = ShapeUtils::shapeAsString(w); -// const std::string wCorrectShape = ShapeUtils::shapeAsString({3*inSize, inSize}); +// const std::string wCorrectShape = +// ShapeUtils::shapeAsString({3*inSize, inSize}); // // const std::string bShape = ShapeUtils::shapeAsString(b); -// // const std::string bCorrectShape = ShapeUtils::shapeAsString({2*inSize}); -// const std::string c0Shape = ShapeUtils::shapeAsString(c0); -// const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, inSize}); -// const std::string cShape = ShapeUtils::shapeAsString(c); -// const std::string cCorrectShape = ShapeUtils::shapeAsString({bS, inSize, time}); -// const std::string inGradCtShape = ShapeUtils::shapeAsString(inGradCt); -// const std::string inGradCtCorrectShape = ShapeUtils::shapeAsString({bS, inSize}); -// const std::string inGradHShape = ShapeUtils::shapeAsString(inGradH); -// const std::string inGradHCorrectShape = ShapeUtils::shapeAsString({bS, inSize, time}); - -// REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BP operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str()); -// // REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str()); -// REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BP operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str()); -// REQUIRE_TRUE(cShape == cCorrectShape, 0, "SRU_BP operation: wrong shape of cell states array, expected is %s, but got %s instead !", cCorrectShape.c_str(), cShape.c_str()); -// REQUIRE_TRUE(inGradCtShape == inGradCtCorrectShape, 0, "SRU_BP operation: wrong shape of array of cell state gradient, expected is %s, but got %s instead !", inGradCtCorrectShape.c_str(), inGradCtShape.c_str()); -// REQUIRE_TRUE(inGradHShape == inGradHCorrectShape, 0, "SRU_BP operation: wrong shape of array of cell outputs gradients, expected is %s, but got %s instead !", inGradHCorrectShape.c_str(), inGradHShape.c_str()); +// // const std::string bCorrectShape = +// ShapeUtils::shapeAsString({2*inSize}); const std::string c0Shape = +// ShapeUtils::shapeAsString(c0); const std::string c0CorrectShape = +// ShapeUtils::shapeAsString({bS, inSize}); const std::string cShape = +// ShapeUtils::shapeAsString(c); const std::string cCorrectShape = +// ShapeUtils::shapeAsString({bS, inSize, time}); const std::string +// inGradCtShape = ShapeUtils::shapeAsString(inGradCt); const +// std::string inGradCtCorrectShape = ShapeUtils::shapeAsString({bS, +// inSize}); const std::string inGradHShape = +// ShapeUtils::shapeAsString(inGradH); const std::string inGradHCorrectShape +// = ShapeUtils::shapeAsString({bS, inSize, time}); + +// REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BP operation: wrong shape +// of weights array, expected is %s, but got %s instead !", +// wCorrectShape.c_str(), wShape.c_str()); +// // REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BP operation: wrong +// shape of biases array, expected is %s, but got %s instead !", +// bCorrectShape.c_str(), bShape.c_str()); REQUIRE_TRUE(c0Shape == +// c0CorrectShape, 0, "SRU_BP operation: wrong shape of initial state array, +// expected is %s, but got %s instead !", c0CorrectShape.c_str(), +// c0Shape.c_str()); REQUIRE_TRUE(cShape == cCorrectShape, 0, "SRU_BP +// operation: wrong shape of cell states array, expected is %s, but got %s +// instead !", cCorrectShape.c_str(), cShape.c_str()); +// REQUIRE_TRUE(inGradCtShape == inGradCtCorrectShape, 0, "SRU_BP operation: +// wrong shape of array of cell state gradient, expected is %s, but got %s +// instead !", inGradCtCorrectShape.c_str(), inGradCtShape.c_str()); +// REQUIRE_TRUE(inGradHShape == inGradHCorrectShape, 0, "SRU_BP operation: +// wrong shape of array of cell outputs gradients, expected is %s, but got +// %s instead !", inGradHCorrectShape.c_str(), inGradHShape.c_str()); // if(mask) { // const std::string maskShape = ShapeUtils::shapeAsString(mask); -// REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BP operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str()); +// REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BP operation: wrong +// shape of mask array, expected is %s, but got %s instead !", +// c0CorrectShape.c_str(), maskShape.c_str()); // } - -// const auto bF = (*b)({0,0, 0, inSize}); // biases for forget gate [1 x inSize] -// const auto bR = (*b)({0,0, inSize,2*inSize}); // biases for reset gate [1 x inSize] -// NDArray gradBias(x->ordering(), {bS, 2*inSize, time}, x->dataType(), block.launchContext()); -// NDArray gradU (x->ordering(), {bS, 3*inSize, time}, x->dataType(), block.launchContext()); -// NDArray gradHX (x->ordering(), {bS, inSize, time}, x->dataType(), block.launchContext()); -// NDArray gct (c->ordering(), {bS, inSize}, x->dataType(), block.launchContext()); +// const auto bF = (*b)({0,0, 0, inSize}); // biases for forget gate +// [1 x inSize] const auto bR = (*b)({0,0, inSize,2*inSize}); // biases for +// reset gate [1 x inSize] NDArray gradBias(x->ordering(), {bS, 2*inSize, +// time}, x->dataType(), block.launchContext()); NDArray gradU +// (x->ordering(), {bS, 3*inSize, time}, x->dataType(), +// block.launchContext()); NDArray gradHX (x->ordering(), {bS, inSize, +// time}, x->dataType(), block.launchContext()); NDArray gct (c->ordering(), +// {bS, inSize}, x->dataType(), block.launchContext()); // // x = x * mask // if(mask) -// x->applyBroadcast(broadcast::Multiply, {0, 1}, mask, x, nullptr); // apply mask +// x->applyBroadcast(broadcast::Multiply, {0, 1}, mask, x, nullptr); // +// apply mask // // multiplication matrix wi = matmul(w,x), U = WX -// const auto wi = mmul(*w, *x); // U [bS x 3K x time] +// const auto wi = mmul(*w, *x); // U [bS x 3K x time] // for (int t = time-1; t >=0 ; --t) { // // initialization -// auto xt = (*x)({0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize] -// auto zt = wi({0,0, 0, inSize, t,t+1}); // [bS x 3K x time] -> [bS x inSize] -// auto ft = wi({0,0, inSize, 2*inSize, t,t+1}); // [bS x 3K x time] -> [bS x inSize] -// auto rt = wi({0,0, 2*inSize,3*inSize, t,t+1}); // [bS x 3K x time] -> [bS x inSize] -// auto ct = (*c)({0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize] -// auto inGradHt = (*inGradH)({ 0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize] - -// auto ct_1 = t ? (*c)({ 0,0, 0,0, t-1,t}) : *c0; // previous c_{t-1} +// auto xt = (*x)({0,0, 0,0, t,t+1}); // +// [bS x inSize x time] -> [bS x inSize] auto zt = wi({0,0, 0, inSize, +// t,t+1}); // [bS x 3K x time] -> [bS x inSize] auto ft = wi({0,0, +// inSize, 2*inSize, t,t+1}); // [bS x 3K x time] -> [bS x inSize] +// auto rt = wi({0,0, 2*inSize,3*inSize, t,t+1}); // +// [bS x 3K x time] -> [bS x inSize] auto ct = (*c)({0,0, 0,0, +// t,t+1}); // [bS x inSize x time] -> [bS x inSize] auto inGradHt = +// (*inGradH)({ 0,0, 0,0, t,t+1}); // [bS x inSize x +// time] -> [bS x inSize] + +// auto ct_1 = t ? (*c)({ 0,0, 0,0, t-1,t}) : *c0; // previous c_{t-1} // ///////////////// forward // // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR) // ft = sigmoid_(ft + bF); // rt = sigmoid_(rt + bR); -// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur ); +// // TODO T val = (activation_type == 1) ? tanh(cur) : +// ((activation_type == 2) ? reluf(cur) : cur ); // ct.applyTransform(transform::Tanh, &gct); // ///////////////// backward @@ -884,8 +1246,9 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // // x_t (highway connection), gradHXt = inGradHt * (1.0f - rt); // NDArray gradHXt = inGradHt * rtMinus; -// // U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft); -// NDArray gradUZt = (inGradHt * rt * gradTanh + *inGradCt) * ftMinus; +// // U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - +// ft); NDArray gradUZt = (inGradHt * rt * gradTanh + *inGradCt) * +// ftMinus; // // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft; // *inGradCt = (gradCt + *inGradCt) * ft; @@ -902,17 +1265,18 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // // gradInit // gradInit->assign(inGradCt); // // gradX -// w->transposei(); // [inSize x 3K] -// gradX->assign( mmul(*w, gradU) + gradHX); -// if(mask) -// gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask +// w->transposei(); // [inSize x 3K] gradX->assign( mmul(*w, gradU) + +// gradHX); if(mask) +// gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, +// nullptr); // apply mask // // gradB -// gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0,2}, false, true); // [1 x 2K] +// gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0,2}, false, true); +// // [1 x 2K] // // gradW [bS x 3K x inSize] -// x->permutei({0, 2, 1}); // [bS x time x inSize] -// gradW->assign( mmul(gradU, *x) ); +// x->permutei({0, 2, 1}); // +// [bS x time x inSize] gradW->assign( mmul(gradU, *x) ); // return Status::OK(); // } @@ -930,10 +1294,16 @@ DECLARE_SHAPE_FN(sru_bi_bp) { // auto time = inShape[3]; // char order = shape::order(inShape); -// ShapeDescriptor descriptor1(ArrayOptions::dataType(inShape), order, {bS, inSize, time}); -// ShapeDescriptor descriptor2(ArrayOptions::dataType(inShape), order, {bS, 3 * inSize, inSize}); -// ShapeDescriptor descriptor3(ArrayOptions::dataType(inShape), order, {1, 2 * inSize}); -// ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize}); - -// return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4)); +// ShapeDescriptor descriptor1(ArrayOptions::dataType(inShape), order, {bS, +// inSize, time}); ShapeDescriptor +// descriptor2(ArrayOptions::dataType(inShape), order, {bS, 3 * inSize, +// inSize}); ShapeDescriptor descriptor3(ArrayOptions::dataType(inShape), +// order, {1, 2 * inSize}); ShapeDescriptor +// descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize}); + +// return +// SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), +// ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), +// ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), +// ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4)); // } diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/sruCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/sruCell.cpp index 52253f74f817..8a720c61e63b 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/sruCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/sruCell.cpp @@ -22,88 +22,119 @@ #if NOT_EXCLUDED(OP_sruCell) #include -#include - +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(sruCell, 4, 2, false, 0, 0) { - auto xt = INPUT_VARIABLE(0); // input [bS x inSize], bS - batch size, inSize - number of features - auto ct_1 = INPUT_VARIABLE(1); // previous cell state ct [bS x inSize], that is at previous time step t-1 - auto w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize] - auto b = INPUT_VARIABLE(3); // biases [2*inSize] - - auto ht = OUTPUT_VARIABLE(0); // current cell output [bS x inSize], that is at current time step t - auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x inSize], that is at current time step t - - const int rank = xt->rankOf(); - const int bS = xt->sizeAt(0); - const int inSize = xt->sizeAt(1); // inSize - number of features - - // input shapes validation - const std::vector correctCt_1Shape = {bS, inSize}; - const std::vector correctWShape = {inSize, 3*inSize}; - const std::vector correctBShape = {2*inSize}; - - REQUIRE_TRUE(ct_1->isSameShape(correctCt_1Shape), 0, "SRUCELL operation: wrong shape of previous cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1).c_str()); - REQUIRE_TRUE(w->isSameShape(correctWShape), 0, "SRUCELL operation: wrong shape of weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWShape).c_str(), ShapeUtils::shapeAsString(w).c_str()); - REQUIRE_TRUE(b->isSameShape(correctBShape), 0, "SRUCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - - - // fixme: shitty initializer lists - helpers::sruCell(block.launchContext(), xt, ct_1, w, b, ht, ct); - - return Status::OK(); + auto xt = INPUT_VARIABLE( + 0); // input [bS x inSize], bS - batch size, inSize - number of features + auto ct_1 = INPUT_VARIABLE(1); // previous cell state ct [bS x inSize], that + // is at previous time step t-1 + auto w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize] + auto b = INPUT_VARIABLE(3); // biases [2*inSize] + + auto ht = OUTPUT_VARIABLE( + 0); // current cell output [bS x inSize], that is at current time step t + auto ct = OUTPUT_VARIABLE( + 1); // current cell state [bS x inSize], that is at current time step t + + const int rank = xt->rankOf(); + const int bS = xt->sizeAt(0); + const int inSize = xt->sizeAt(1); // inSize - number of features + + // input shapes validation + const std::vector correctCt_1Shape = {bS, inSize}; + const std::vector correctWShape = {inSize, 3 * inSize}; + const std::vector correctBShape = {2 * inSize}; + + REQUIRE_TRUE(ct_1->isSameShape(correctCt_1Shape), 0, + "SRUCELL operation: wrong shape of previous cell state, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), + ShapeUtils::shapeAsString(ct_1).c_str()); + REQUIRE_TRUE(w->isSameShape(correctWShape), 0, + "SRUCELL operation: wrong shape of weights, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(correctWShape).c_str(), + ShapeUtils::shapeAsString(w).c_str()); + REQUIRE_TRUE(b->isSameShape(correctBShape), 0, + "SRUCELL operation: wrong shape of biases, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(correctBShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + + // fixme: shitty initializer lists + helpers::sruCell(block.launchContext(), xt, ct_1, w, b, ht, ct); + + return Status::OK(); } - DECLARE_TYPES(sruCell) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(sruCell) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(sruCell) { - - auto xtShapeInfo = inputShape->at(0); // input [bS x inSize], bS - batch size, inSize - number of features - auto ct_1ShapeInfo = inputShape->at(1); // previous cell state ct [bS x inSize], that is at previous time step t-1 - auto wShapeInfo = inputShape->at(2); // weights [inSize x 3*inSize] - auto bShapeInfo = inputShape->at(3); // biases [2*inSize] - - const int rank = xtShapeInfo[0]; - const int bS = xtShapeInfo[1]; - const int inSize = xtShapeInfo[2]; // inSize - number of features - - // input shapes validation - const std::vector correctCt_1Shape = {bS, inSize}; - const std::vector correctWShape = {inSize, 3*inSize}; - const std::vector correctBShape = {2*inSize}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(ct_1ShapeInfo, correctCt_1Shape) , 0, "SRUCELL operation: wrong shape of previous cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo ,correctWShape), 0, "SRUCELL operation: wrong shape of weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo ,correctBShape), 0, "SRUCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - - // evaluate output shapeInfos - Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numProj] - ALLOCATE(cShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numUnits] - - hShapeInfo[0] = cShapeInfo[0] = rank; - hShapeInfo[1] = cShapeInfo[1] = bS; - hShapeInfo[2] = cShapeInfo[2] = inSize; - - ShapeUtils::updateStridesAndType(hShapeInfo, ct_1ShapeInfo, shape::order(ct_1ShapeInfo)); - ShapeUtils::updateStridesAndType(cShapeInfo, ct_1ShapeInfo, shape::order(ct_1ShapeInfo)); - - return SHAPELIST(ConstantShapeHelper::getInstance()->createFromExisting(hShapeInfo, block.workspace()), ConstantShapeHelper::getInstance()->createFromExisting(cShapeInfo, block.workspace())); + auto xtShapeInfo = inputShape->at( + 0); // input [bS x inSize], bS - batch size, inSize - number of features + auto ct_1ShapeInfo = + inputShape->at(1); // previous cell state ct [bS x inSize], that is at + // previous time step t-1 + auto wShapeInfo = inputShape->at(2); // weights [inSize x 3*inSize] + auto bShapeInfo = inputShape->at(3); // biases [2*inSize] + + const int rank = xtShapeInfo[0]; + const int bS = xtShapeInfo[1]; + const int inSize = xtShapeInfo[2]; // inSize - number of features + + // input shapes validation + const std::vector correctCt_1Shape = {bS, inSize}; + const std::vector correctWShape = {inSize, 3 * inSize}; + const std::vector correctBShape = {2 * inSize}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(ct_1ShapeInfo, correctCt_1Shape), 0, + "SRUCELL operation: wrong shape of previous cell state, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), + ShapeUtils::shapeAsString(ct_1ShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, correctWShape), 0, + "SRUCELL operation: wrong shape of weights, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(correctWShape).c_str(), + ShapeUtils::shapeAsString(wShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, correctBShape), 0, + "SRUCELL operation: wrong shape of biases, expected is %s, but " + "got %s instead !", + ShapeUtils::shapeAsString(correctBShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + + // evaluate output shapeInfos + Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [bS x numProj] + ALLOCATE(cShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); // [bS x numUnits] + + hShapeInfo[0] = cShapeInfo[0] = rank; + hShapeInfo[1] = cShapeInfo[1] = bS; + hShapeInfo[2] = cShapeInfo[2] = inSize; + + ShapeUtils::updateStridesAndType(hShapeInfo, ct_1ShapeInfo, + shape::order(ct_1ShapeInfo)); + ShapeUtils::updateStridesAndType(cShapeInfo, ct_1ShapeInfo, + shape::order(ct_1ShapeInfo)); + + return SHAPELIST(ConstantShapeHelper::getInstance()->createFromExisting( + hShapeInfo, block.workspace()), + ConstantShapeHelper::getInstance()->createFromExisting( + cShapeInfo, block.workspace())); } - - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/staticBidirectionalRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/staticBidirectionalRNN.cpp index a5c323b59718..a37cb4c3a992 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/staticBidirectionalRNN.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/staticBidirectionalRNN.cpp @@ -19,211 +19,331 @@ // #include -#include -#include -#include +#include +#include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(static_bidirectional_rnn, 7, 3, false, 0, 0) { - auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] - auto WxFW = INPUT_VARIABLE(1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - auto WhFW = INPUT_VARIABLE(2); // hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - auto bFW = INPUT_VARIABLE(3); // biases for forward RNN, [2*numUnitsFW] - auto WxBW = INPUT_VARIABLE(4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - auto WhBW = INPUT_VARIABLE(5); // hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - auto bBW = INPUT_VARIABLE(6); // biases for backward RNN, [2*v] - - NDArray* h0FW = nullptr; // initial cell output for forward RNN (at time step = 0) [bS x numUnitsFW] - NDArray* h0BW = nullptr; // initial cell output for backward RNN (at time step = 0) [bS x numUnitsBW] - NDArray* maxTimeStep = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - switch(block.width()) { - case 8: - maxTimeStep = INPUT_VARIABLE(7); - break; - case 9: - h0FW = INPUT_VARIABLE(7); - h0BW = INPUT_VARIABLE(8); - break; - case 10: - h0FW = INPUT_VARIABLE(7); - h0BW = INPUT_VARIABLE(8); - maxTimeStep = INPUT_VARIABLE(9); - break; - } - - auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x (numUnitsFW + numUnitsBW)], that is per each time step - auto hFWFinal = OUTPUT_VARIABLE(1); // final cell out for forward RNN [bS x numUnitsFW] - auto hBWFinal = OUTPUT_VARIABLE(2); // final cell out for backward RNN [bS x numUnitsBF] - - REQUIRE_TRUE(x->rankOf() == 3, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !", x->rankOf()); - REQUIRE_TRUE(WxFW->rankOf() == 2, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !", WxFW->rankOf()); - REQUIRE_TRUE(WxBW->rankOf() == 2, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !", WxBW->rankOf()); - - const Nd4jLong inRank = x->rankOf(); - const Nd4jLong time = x->sizeAt(0); - const Nd4jLong bS = x->sizeAt(1); - const Nd4jLong numUnitsFW = WxFW->sizeAt(1); - const Nd4jLong numUnitsBW = WxBW->sizeAt(1); - - const std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; - const std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; - const std::vector expectedbFWshape = {2 * numUnitsFW}; - const std::vector expectedbBWshape = {2 * numUnitsBW}; - - REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFW).c_str()); - REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBW).c_str()); - REQUIRE_TRUE(bFW->isSameShape(expectedbFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFW).c_str()); - REQUIRE_TRUE(bBW->isSameShape(expectedbBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBW).c_str()); - if(h0FW) { - const std::vector expectedh0FWshape = {bS, numUnitsFW}; - REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FW).c_str()); - } - if(h0BW) { - const std::vector expectedh0BWshape = {bS, numUnitsBW}; - REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BW).c_str()); - } - if(maxTimeStep) - REQUIRE_TRUE(maxTimeStep->isSameShape({bS}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str()); - - // forward steps - auto hFW = new NDArray(x->ordering(), {time, bS, numUnitsFW}, x->dataType(), block.launchContext()); - helpers::rnnTimeLoop(block.launchContext(), x, WxFW, WhFW, bFW, h0FW, maxTimeStep, hFW, hFWFinal); - - auto seqLen = maxTimeStep; - if(seqLen == nullptr) { -// seqLen = new NDArray(x->ordering(), {x->sizeAt(1)}, x->dataType(), block.launchContext()); // [bS] - seqLen = new NDArray(x->ordering(), {x->sizeAt(1)}, sd::DataType::INT64, block.launchContext()); // [bS] - *seqLen = x->sizeAt(0); // set each element of seqLen to be equal to time - } - - // reverse x - auto revOut = new NDArray(x, false, block.launchContext()); - helpers::reverseSequence(block.launchContext(), x, seqLen, revOut, 0, 1); - - // backward steps - auto hBW = new NDArray(x->ordering(), {time, bS, numUnitsBW}, x->dataType(), block.launchContext()); - - helpers::rnnTimeLoop(block.launchContext(), revOut, WxBW, WhBW, bBW, h0BW, maxTimeStep, hBW, hBWFinal); - - // reverse hBW - auto hBWcopy = new NDArray(*hBW); - helpers::reverseSequence(block.launchContext(), hBWcopy, seqLen, hBW, 0, 1); - - // concatenate hFW and hBW along last third dimension - // NDArrayFactory::concat({hFW, hBW}, 2, h); - helpers::concat(block.launchContext(), {hFW, hBW}, *h, 2); - - delete hBW; - delete hFW; - delete hBWcopy; - delete revOut; - - if(seqLen != maxTimeStep) - delete seqLen; - - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] + auto WxFW = INPUT_VARIABLE( + 1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] + auto WhFW = INPUT_VARIABLE(2); // hidden-to-hidden weights for forward RNN, + // [numUnitsFW x numUnitsFW] + auto bFW = INPUT_VARIABLE(3); // biases for forward RNN, [2*numUnitsFW] + auto WxBW = INPUT_VARIABLE( + 4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] + auto WhBW = INPUT_VARIABLE(5); // hidden-to-hidden weights for backward RNN, + // [numUnitsBW x numUnitsBW] + auto bBW = INPUT_VARIABLE(6); // biases for backward RNN, [2*v] + + NDArray* h0FW = nullptr; // initial cell output for forward RNN (at time + // step = 0) [bS x numUnitsFW] + NDArray* h0BW = nullptr; // initial cell output for backward RNN (at time + // step = 0) [bS x numUnitsBW] + NDArray* maxTimeStep = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + switch (block.width()) { + case 8: + maxTimeStep = INPUT_VARIABLE(7); + break; + case 9: + h0FW = INPUT_VARIABLE(7); + h0BW = INPUT_VARIABLE(8); + break; + case 10: + h0FW = INPUT_VARIABLE(7); + h0BW = INPUT_VARIABLE(8); + maxTimeStep = INPUT_VARIABLE(9); + break; + } + + auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x (numUnitsFW + + // numUnitsBW)], that is per each time step + auto hFWFinal = + OUTPUT_VARIABLE(1); // final cell out for forward RNN [bS x numUnitsFW] + auto hBWFinal = + OUTPUT_VARIABLE(2); // final cell out for backward RNN [bS x numUnitsBF] + + REQUIRE_TRUE(x->rankOf() == 3, 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: input array must " + "have rank = 3, but got %i instead !", + x->rankOf()); + REQUIRE_TRUE( + WxFW->rankOf() == 2, 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for forward RNN) must have rank = 2, but got %i instead !", + WxFW->rankOf()); + REQUIRE_TRUE( + WxBW->rankOf() == 2, 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for backward RNN) must have rank = 2, but got %i instead !", + WxBW->rankOf()); + + const Nd4jLong inRank = x->rankOf(); + const Nd4jLong time = x->sizeAt(0); + const Nd4jLong bS = x->sizeAt(1); + const Nd4jLong numUnitsFW = WxFW->sizeAt(1); + const Nd4jLong numUnitsBW = WxBW->sizeAt(1); + + const std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; + const std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; + const std::vector expectedbFWshape = {2 * numUnitsFW}; + const std::vector expectedbBWshape = {2 * numUnitsBW}; + + REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for forward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), + ShapeUtils::shapeAsString(WhFW).c_str()); + REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for backward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), + ShapeUtils::shapeAsString(WhBW).c_str()); + REQUIRE_TRUE( + bFW->isSameShape(expectedbFWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for forward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbFWshape).c_str(), + ShapeUtils::shapeAsString(bFW).c_str()); + REQUIRE_TRUE( + bBW->isSameShape(expectedbBWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for backward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbBWshape).c_str(), + ShapeUtils::shapeAsString(bBW).c_str()); + if (h0FW) { + const std::vector expectedh0FWshape = {bS, numUnitsFW}; + REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for forward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), + ShapeUtils::shapeAsString(h0FW).c_str()); + } + if (h0BW) { + const std::vector expectedh0BWshape = {bS, numUnitsBW}; + REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for backward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), + ShapeUtils::shapeAsString(h0BW).c_str()); + } + if (maxTimeStep) + REQUIRE_TRUE(maxTimeStep->isSameShape({bS}), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "maxTimeStep array, expected is [%i], but got %s instead !", + bS, ShapeUtils::shapeAsString(maxTimeStep).c_str()); + + // forward steps + auto hFW = new NDArray(x->ordering(), {time, bS, numUnitsFW}, x->dataType(), + block.launchContext()); + helpers::rnnTimeLoop(block.launchContext(), x, WxFW, WhFW, bFW, h0FW, + maxTimeStep, hFW, hFWFinal); + + auto seqLen = maxTimeStep; + if (seqLen == nullptr) { + // seqLen = new NDArray(x->ordering(), {x->sizeAt(1)}, x->dataType(), + // block.launchContext()); // [bS] + seqLen = new NDArray(x->ordering(), {x->sizeAt(1)}, sd::DataType::INT64, + block.launchContext()); // [bS] + *seqLen = x->sizeAt(0); // set each element of seqLen to be equal to time + } + + // reverse x + auto revOut = new NDArray(x, false, block.launchContext()); + helpers::reverseSequence(block.launchContext(), x, seqLen, revOut, 0, 1); + + // backward steps + auto hBW = new NDArray(x->ordering(), {time, bS, numUnitsBW}, x->dataType(), + block.launchContext()); + + helpers::rnnTimeLoop(block.launchContext(), revOut, WxBW, WhBW, bBW, h0BW, + maxTimeStep, hBW, hBWFinal); + + // reverse hBW + auto hBWcopy = new NDArray(*hBW); + helpers::reverseSequence(block.launchContext(), hBWcopy, seqLen, hBW, 0, 1); + + // concatenate hFW and hBW along last third dimension + // NDArrayFactory::concat({hFW, hBW}, 2, h); + helpers::concat(block.launchContext(), {hFW, hBW}, *h, 2); + + delete hBW; + delete hFW; + delete hBWcopy; + delete revOut; + + if (seqLen != maxTimeStep) delete seqLen; + + return Status::OK(); } - DECLARE_TYPES(static_bidirectional_rnn) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - -DECLARE_SHAPE_FN(static_bidirectional_rnn) { - - auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] - auto WxFWShapeInfo = inputShape->at(1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - auto WhFWShapeInfo = inputShape->at(2); // hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - auto bFWShapeInfo = inputShape->at(3); // biases for forward RNN, [2*numUnitsFW] - auto WxBWShapeInfo = inputShape->at(4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - auto WhBWShapeInfo = inputShape->at(5); // hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - auto bBWShapeInfo = inputShape->at(6); // biases for backward RNN, [2*numUnitsBW] - - Nd4jLong const* h0FWShapeInfo = nullptr; // initial cell output for forward RNN (at time step = 0) [bS x numUnitsFW] - Nd4jLong const* h0BWShapeInfo = nullptr; // initial cell output for backward RNN (at time step = 0) [bS x numUnitsBW] - Nd4jLong const* maxTimeStepShapeInfo = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - switch(block.width()) { - case 8: - maxTimeStepShapeInfo = inputShape->at(7); - break; - case 9: - h0FWShapeInfo = inputShape->at(7); - h0BWShapeInfo = inputShape->at(8); - break; - case 10: - h0FWShapeInfo = inputShape->at(7); - h0BWShapeInfo = inputShape->at(8); - maxTimeStepShapeInfo = inputShape->at(9); - break; - } - - REQUIRE_TRUE(xShapeInfo[0] == 3, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !", xShapeInfo[0]); - REQUIRE_TRUE(WxFWShapeInfo[0] == 2, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !", WxFWShapeInfo[0]); - REQUIRE_TRUE(WxBWShapeInfo[0] == 2, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !", WxBWShapeInfo[0]); - - const int inRank = xShapeInfo[0]; - const int time = xShapeInfo[1]; - const int bS = xShapeInfo[2]; - const int numUnitsFW = WxFWShapeInfo[2]; - const int numUnitsBW = WxBWShapeInfo[2]; - - const std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; - const std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; - const std::vector expectedbFWshape = {2 * numUnitsFW}; - const std::vector expectedbBWshape = {2 * numUnitsBW}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhFWShapeInfo, expectedWhFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFWShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhBWShapeInfo, expectedWhBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBWShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bFWShapeInfo, expectedbFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFWShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bBWShapeInfo, expectedbBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBWShapeInfo).c_str()); - if(h0FWShapeInfo) { - const std::vector expectedh0FWshape = {bS, numUnitsFW}; - REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0FWShapeInfo, expectedh0FWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FWShapeInfo).c_str()); - } - if(h0BWShapeInfo) { - const std::vector expectedh0BWshape = {bS, numUnitsBW}; - REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0BWShapeInfo, expectedh0BWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BWShapeInfo).c_str()); - } - if(maxTimeStepShapeInfo) - REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, {bS}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str()); - - // evaluate output shapeInfos - Nd4jLong *hShapeInfo(nullptr), *hFWFinalPrevShapeInfo(nullptr), *hBWFinalPrevShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); - ALLOCATE(hFWFinalPrevShapeInfo, block.workspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); - ALLOCATE(hBWFinalPrevShapeInfo, block.workspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); - - hShapeInfo[0] = inRank; - hFWFinalPrevShapeInfo[0] = hBWFinalPrevShapeInfo[0] = inRank-1; - hShapeInfo[1] = time; - hShapeInfo[2] = hFWFinalPrevShapeInfo[1] = hBWFinalPrevShapeInfo[1] = bS; - hShapeInfo[3] = numUnitsFW + numUnitsBW; - hFWFinalPrevShapeInfo[2] = numUnitsFW; - hBWFinalPrevShapeInfo[2] = numUnitsBW; - - ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(xShapeInfo)); - ShapeUtils::updateStridesAndType(hFWFinalPrevShapeInfo, xShapeInfo, shape::order(xShapeInfo)); - ShapeUtils::updateStridesAndType(hBWFinalPrevShapeInfo, xShapeInfo, shape::order(xShapeInfo)); - - return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(hFWFinalPrevShapeInfo), CONSTANT(hBWFinalPrevShapeInfo)); +DECLARE_TYPES(static_bidirectional_rnn) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - - - - - - - -} +DECLARE_SHAPE_FN(static_bidirectional_rnn) { + auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] + auto WxFWShapeInfo = inputShape->at( + 1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW] + auto WhFWShapeInfo = + inputShape->at(2); // hidden-to-hidden weights for forward RNN, + // [numUnitsFW x numUnitsFW] + auto bFWShapeInfo = + inputShape->at(3); // biases for forward RNN, [2*numUnitsFW] + auto WxBWShapeInfo = inputShape->at( + 4); // input-to-hidden weights for backward RNN, [inSize x numUnitsBW] + auto WhBWShapeInfo = + inputShape->at(5); // hidden-to-hidden weights for backward RNN, + // [numUnitsBW x numUnitsBW] + auto bBWShapeInfo = + inputShape->at(6); // biases for backward RNN, [2*numUnitsBW] + + Nd4jLong const* h0FWShapeInfo = + nullptr; // initial cell output for forward RNN (at time step = 0) [bS x + // numUnitsFW] + Nd4jLong const* h0BWShapeInfo = + nullptr; // initial cell output for backward RNN (at time step = 0) [bS x + // numUnitsBW] + Nd4jLong const* maxTimeStepShapeInfo = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + switch (block.width()) { + case 8: + maxTimeStepShapeInfo = inputShape->at(7); + break; + case 9: + h0FWShapeInfo = inputShape->at(7); + h0BWShapeInfo = inputShape->at(8); + break; + case 10: + h0FWShapeInfo = inputShape->at(7); + h0BWShapeInfo = inputShape->at(8); + maxTimeStepShapeInfo = inputShape->at(9); + break; + } + + REQUIRE_TRUE(xShapeInfo[0] == 3, 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: input array must " + "have rank = 3, but got %i instead !", + xShapeInfo[0]); + REQUIRE_TRUE( + WxFWShapeInfo[0] == 2, 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for forward RNN) must have rank = 2, but got %i instead !", + WxFWShapeInfo[0]); + REQUIRE_TRUE( + WxBWShapeInfo[0] == 2, 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights " + "array (for backward RNN) must have rank = 2, but got %i instead !", + WxBWShapeInfo[0]); + + const int inRank = xShapeInfo[0]; + const int time = xShapeInfo[1]; + const int bS = xShapeInfo[2]; + const int numUnitsFW = WxFWShapeInfo[2]; + const int numUnitsBW = WxBWShapeInfo[2]; + + const std::vector expectedWhFWshape = {numUnitsFW, numUnitsFW}; + const std::vector expectedWhBWshape = {numUnitsBW, numUnitsBW}; + const std::vector expectedbFWshape = {2 * numUnitsFW}; + const std::vector expectedbBWshape = {2 * numUnitsBW}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhFWShapeInfo, expectedWhFWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for forward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), + ShapeUtils::shapeAsString(WhFWShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhBWShapeInfo, expectedWhBWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "hidden-to-hidden weights array (for backward RNN), expected is " + "%s but got %s instead !", + ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), + ShapeUtils::shapeAsString(WhBWShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(bFWShapeInfo, expectedbFWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for forward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbFWshape).c_str(), + ShapeUtils::shapeAsString(bFWShapeInfo).c_str()); + REQUIRE_TRUE( + ShapeUtils::areShapesEqual(bBWShapeInfo, expectedbBWshape), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array " + "(for backward RNN), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbBWshape).c_str(), + ShapeUtils::shapeAsString(bBWShapeInfo).c_str()); + if (h0FWShapeInfo) { + const std::vector expectedh0FWshape = {bS, numUnitsFW}; + REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0FWShapeInfo, expectedh0FWshape), + 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for forward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), + ShapeUtils::shapeAsString(h0FWShapeInfo).c_str()); + } + if (h0BWShapeInfo) { + const std::vector expectedh0BWshape = {bS, numUnitsBW}; + REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0BWShapeInfo, expectedh0BWshape), + 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "initial cell output array (for backward RNN), expected is %s " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), + ShapeUtils::shapeAsString(h0BWShapeInfo).c_str()); + } + if (maxTimeStepShapeInfo) + REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, {bS}), 0, + "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of " + "maxTimeStep array, expected is [%i], but got %s instead !", + bS, ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str()); + + // evaluate output shapeInfos + Nd4jLong *hShapeInfo(nullptr), *hFWFinalPrevShapeInfo(nullptr), + *hBWFinalPrevShapeInfo(nullptr); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + ALLOCATE(hFWFinalPrevShapeInfo, block.workspace(), + shape::shapeInfoLength(inRank - 1), Nd4jLong); + ALLOCATE(hBWFinalPrevShapeInfo, block.workspace(), + shape::shapeInfoLength(inRank - 1), Nd4jLong); + + hShapeInfo[0] = inRank; + hFWFinalPrevShapeInfo[0] = hBWFinalPrevShapeInfo[0] = inRank - 1; + hShapeInfo[1] = time; + hShapeInfo[2] = hFWFinalPrevShapeInfo[1] = hBWFinalPrevShapeInfo[1] = bS; + hShapeInfo[3] = numUnitsFW + numUnitsBW; + hFWFinalPrevShapeInfo[2] = numUnitsFW; + hBWFinalPrevShapeInfo[2] = numUnitsBW; + + ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, + shape::order(xShapeInfo)); + ShapeUtils::updateStridesAndType(hFWFinalPrevShapeInfo, xShapeInfo, + shape::order(xShapeInfo)); + ShapeUtils::updateStridesAndType(hBWFinalPrevShapeInfo, xShapeInfo, + shape::order(xShapeInfo)); + + return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(hFWFinalPrevShapeInfo), + CONSTANT(hBWFinalPrevShapeInfo)); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/staticRNN.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/staticRNN.cpp index 1f614ba6ab43..384493fcddc5 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/staticRNN.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/staticRNN.cpp @@ -19,130 +19,184 @@ // #include -#include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(static_rnn, 4, 2, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] - auto Wx = INPUT_VARIABLE(1); // input-to-hidden weights, [inSize x numUnits] - auto Wh = INPUT_VARIABLE(2); // hidden-to-hidden weights, [numUnits x numUnits] - auto b = INPUT_VARIABLE(3); // biases for, [2*numUnits] - - NDArray* h0 = nullptr; // initial cell output (at time step = 0) [bS x numUnits] - NDArray* maxTimeStep = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - if(block.width() == 5) { - if ((*INPUT_VARIABLE(4)).rankOf() == 2) - h0 = INPUT_VARIABLE(4); - else - maxTimeStep = INPUT_VARIABLE(4); - } - else if(block.width() == 6) { - h0 = INPUT_VARIABLE(4); - maxTimeStep = INPUT_VARIABLE(5); - } - - auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x numUnits] - auto hFinal = OUTPUT_VARIABLE(1); // at the end it will store cell final non-zero output [bS x numUnits] - - REQUIRE_TRUE(x->rankOf() == 3, 0, "STATIC_RNN custom operation: input array x must have rank = 3, but got %i instead !", x->rankOf()); - REQUIRE_TRUE(Wx->rankOf() == 2, 0, "STATIC_RNN custom operation: input-to-hidden weights array must have rank = 2, but got %i instead !", Wx->rankOf()); - - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int inSize = x->sizeAt(2); - const int numUnits = Wx->sizeAt(1); - - const std::vector expectedWhShape = {numUnits, numUnits}; - const std::vector expectedbShape = {2 * numUnits}; - - REQUIRE_TRUE(Wh->isSameShape(expectedWhShape), 0, "STATIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); - REQUIRE_TRUE(b->isSameShape(expectedbShape), 0, "STATIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - if(h0) { - const std::vector expectedh0Shape = {bS, numUnits}; - REQUIRE_TRUE(h0->isSameShape(expectedh0Shape), 0, "STATIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0).c_str()); - } - if(maxTimeStep) - REQUIRE_TRUE(maxTimeStep->isSameShape({bS}), 0, "STATIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStep).c_str()); - - - helpers::rnnTimeLoop(block.launchContext(), x, Wx, Wh, b, h0, maxTimeStep, h, hFinal); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] + auto Wx = + INPUT_VARIABLE(1); // input-to-hidden weights, [inSize x numUnits] + auto Wh = + INPUT_VARIABLE(2); // hidden-to-hidden weights, [numUnits x numUnits] + auto b = INPUT_VARIABLE(3); // biases for, [2*numUnits] + + NDArray* h0 = + nullptr; // initial cell output (at time step = 0) [bS x numUnits] + NDArray* maxTimeStep = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + if (block.width() == 5) { + if ((*INPUT_VARIABLE(4)).rankOf() == 2) + h0 = INPUT_VARIABLE(4); + else + maxTimeStep = INPUT_VARIABLE(4); + } else if (block.width() == 6) { + h0 = INPUT_VARIABLE(4); + maxTimeStep = INPUT_VARIABLE(5); + } + + auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x numUnits] + auto hFinal = OUTPUT_VARIABLE(1); // at the end it will store cell final + // non-zero output [bS x numUnits] + + REQUIRE_TRUE(x->rankOf() == 3, 0, + "STATIC_RNN custom operation: input array x must have rank = 3, " + "but got %i instead !", + x->rankOf()); + REQUIRE_TRUE(Wx->rankOf() == 2, 0, + "STATIC_RNN custom operation: input-to-hidden weights array " + "must have rank = 2, but got %i instead !", + Wx->rankOf()); + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int inSize = x->sizeAt(2); + const int numUnits = Wx->sizeAt(1); + + const std::vector expectedWhShape = {numUnits, numUnits}; + const std::vector expectedbShape = {2 * numUnits}; + + REQUIRE_TRUE(Wh->isSameShape(expectedWhShape), 0, + "STATIC_RNN custom operation: wrong shape of hidden-to-hidden " + "weights array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWhShape).c_str(), + ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(expectedbShape), 0, + "STATIC_RNN custom operation: wrong shape of biases array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbShape).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + if (h0) { + const std::vector expectedh0Shape = {bS, numUnits}; + REQUIRE_TRUE(h0->isSameShape(expectedh0Shape), 0, + "STATIC_RNN custom operation: wrong shape of initial cell " + "output array, expected is %s but got %s instead !", + ShapeUtils::shapeAsString(expectedh0Shape).c_str(), + ShapeUtils::shapeAsString(h0).c_str()); + } + if (maxTimeStep) + REQUIRE_TRUE(maxTimeStep->isSameShape({bS}), 0, + "STATIC_RNN custom operation: wrong shape of maxTimeStep " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS}).c_str(), + ShapeUtils::shapeAsString(maxTimeStep).c_str()); + + helpers::rnnTimeLoop(block.launchContext(), x, Wx, Wh, b, h0, maxTimeStep, h, + hFinal); + + return Status::OK(); } DECLARE_TYPES(static_rnn) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - DECLARE_SHAPE_FN(static_rnn) { - - auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] - auto WxShapeInfo = inputShape->at(1); // input-to-hidden weights, [inSize x numUnits] - auto WhShapeInfo = inputShape->at(2); // hidden-to-hidden weights, [numUnits x numUnits] - auto bShapeInfo = inputShape->at(3); // biases for, [2*numUnits] - - const Nd4jLong* h0ShapeInfo = nullptr; // initial cell output (at time step = 0) [bS x numUnits] - const Nd4jLong* maxTimeStepShapeInfo = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - if(block.width() == 5) { - if (inputShape->at(4)[0] == 2) - h0ShapeInfo = inputShape->at(4); - else - maxTimeStepShapeInfo = inputShape->at(4); - } - else if(block.width() == 6) { - h0ShapeInfo = inputShape->at(4); - maxTimeStepShapeInfo = inputShape->at(5); - } - - REQUIRE_TRUE(xShapeInfo[0] == 3, 0, "STATIC_RNN custom operation: input array x must have rank = 3, but got %i instead !", xShapeInfo[0]); - REQUIRE_TRUE(WxShapeInfo[0] == 2, 0, "STATIC_RNN custom operation: input-to-hidden weights array must have rank = 2, but got %i instead !", WxShapeInfo[0]); - - const int inRank = xShapeInfo[0]; - const int time = xShapeInfo[1]; - const int bS = xShapeInfo[2]; - const int numUnits = WxShapeInfo[2]; - - const std::vector expectedWhShape = {numUnits, numUnits}; - const std::vector expectedbShape = {2 * numUnits}; - - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, expectedWhShape), 0, "STATIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, expectedbShape), 0, "STATIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); - if(h0ShapeInfo){ - const std::vector expectedh0Shape = {bS, numUnits}; - REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, expectedh0Shape), 0, "STATIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); - } - if(maxTimeStepShapeInfo) - REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, {bS}), 0, "STATIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str()); - - // evaluate output shapeInfos - Nd4jLong *hShapeInfo(nullptr), *hPrevShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), Nd4jLong); - ALLOCATE(hPrevShapeInfo, block.workspace(), shape::shapeInfoLength(inRank-1), Nd4jLong); - - hShapeInfo[0] = inRank; - hPrevShapeInfo[0] = inRank-1; - hShapeInfo[1] = time; - hShapeInfo[2] = hPrevShapeInfo[1] = bS; - hShapeInfo[3] = hPrevShapeInfo[2] = numUnits; - - ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(xShapeInfo)); - ShapeUtils::updateStridesAndType(hPrevShapeInfo, xShapeInfo, shape::order(xShapeInfo)); - - return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(hPrevShapeInfo)); + auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] + auto WxShapeInfo = + inputShape->at(1); // input-to-hidden weights, [inSize x numUnits] + auto WhShapeInfo = + inputShape->at(2); // hidden-to-hidden weights, [numUnits x numUnits] + auto bShapeInfo = inputShape->at(3); // biases for, [2*numUnits] + + const Nd4jLong* h0ShapeInfo = + nullptr; // initial cell output (at time step = 0) [bS x numUnits] + const Nd4jLong* maxTimeStepShapeInfo = + nullptr; // vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in + // batch, this means there are no calculations for time >= + // maxTimeStep + + if (block.width() == 5) { + if (inputShape->at(4)[0] == 2) + h0ShapeInfo = inputShape->at(4); + else + maxTimeStepShapeInfo = inputShape->at(4); + } else if (block.width() == 6) { + h0ShapeInfo = inputShape->at(4); + maxTimeStepShapeInfo = inputShape->at(5); + } + + REQUIRE_TRUE(xShapeInfo[0] == 3, 0, + "STATIC_RNN custom operation: input array x must have rank = 3, " + "but got %i instead !", + xShapeInfo[0]); + REQUIRE_TRUE(WxShapeInfo[0] == 2, 0, + "STATIC_RNN custom operation: input-to-hidden weights array " + "must have rank = 2, but got %i instead !", + WxShapeInfo[0]); + + const int inRank = xShapeInfo[0]; + const int time = xShapeInfo[1]; + const int bS = xShapeInfo[2]; + const int numUnits = WxShapeInfo[2]; + + const std::vector expectedWhShape = {numUnits, numUnits}; + const std::vector expectedbShape = {2 * numUnits}; + + REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, expectedWhShape), 0, + "STATIC_RNN custom operation: wrong shape of hidden-to-hidden " + "weights array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWhShape).c_str(), + ShapeUtils::shapeAsString(WhShapeInfo).c_str()); + REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, expectedbShape), 0, + "STATIC_RNN custom operation: wrong shape of biases array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedbShape).c_str(), + ShapeUtils::shapeAsString(bShapeInfo).c_str()); + if (h0ShapeInfo) { + const std::vector expectedh0Shape = {bS, numUnits}; + REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, expectedh0Shape), 0, + "STATIC_RNN custom operation: wrong shape of initial cell " + "output array, expected is %s but got %s instead !", + ShapeUtils::shapeAsString(expectedh0Shape).c_str(), + ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); + } + if (maxTimeStepShapeInfo) + REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, {bS}), 0, + "STATIC_RNN custom operation: wrong shape of maxTimeStep " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS}).c_str(), + ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str()); + + // evaluate output shapeInfos + Nd4jLong *hShapeInfo(nullptr), *hPrevShapeInfo(nullptr); + ALLOCATE(hShapeInfo, block.workspace(), shape::shapeInfoLength(inRank), + Nd4jLong); + ALLOCATE(hPrevShapeInfo, block.workspace(), + shape::shapeInfoLength(inRank - 1), Nd4jLong); + + hShapeInfo[0] = inRank; + hPrevShapeInfo[0] = inRank - 1; + hShapeInfo[1] = time; + hShapeInfo[2] = hPrevShapeInfo[1] = bS; + hShapeInfo[3] = hPrevShapeInfo[2] = numUnits; + + ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, + shape::order(xShapeInfo)); + ShapeUtils::updateStridesAndType(hPrevShapeInfo, xShapeInfo, + shape::order(xShapeInfo)); + + return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(hPrevShapeInfo)); } - - - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp b/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp index 4e00edcdd567..18d9124322dc 100644 --- a/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/relu_layer.cpp @@ -21,45 +21,62 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(relu_layer, 3, 1, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto w = INPUT_VARIABLE(1); - auto b = INPUT_VARIABLE(2); +namespace ops { +CUSTOM_OP_IMPL(relu_layer, 3, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); - REQUIRE_TRUE(x->isMatrix(), 0, "relu_layer: x argument should be a 2D tensor, but got rank %i instead!", x->rankOf()); - REQUIRE_TRUE(w->isMatrix(), 0, "relu_layer: weights argument should be a 2D tensor, but got rank %i instead!", w->rankOf()); - REQUIRE_TRUE(b->isVector(), 0, "relu_layer: biases argument should be a 1D tensor, but got rank %i instead!", b->rankOf()); - REQUIRE_TRUE(b->lengthOf() == w->sizeAt(1), 0, "relu_layer: biases array length should match to columns of weights matrix, however got length = %i and columns = %i!", b->lengthOf(), w->sizeAt(1)); - REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match to row number of weights matrix, but got x_columns = %i and weights_rows = %i!", x->sizeAt(1), w->sizeAt(0)); + REQUIRE_TRUE( + x->isMatrix(), 0, + "relu_layer: x argument should be a 2D tensor, but got rank %i instead!", + x->rankOf()); + REQUIRE_TRUE(w->isMatrix(), 0, + "relu_layer: weights argument should be a 2D tensor, but got " + "rank %i instead!", + w->rankOf()); + REQUIRE_TRUE(b->isVector(), 0, + "relu_layer: biases argument should be a 1D tensor, but got " + "rank %i instead!", + b->rankOf()); + REQUIRE_TRUE(b->lengthOf() == w->sizeAt(1), 0, + "relu_layer: biases array length should match to columns of " + "weights matrix, however got length = %i and columns = %i!", + b->lengthOf(), w->sizeAt(1)); + REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, + "relu_layer: number of x columns should match to row number of " + "weights matrix, but got x_columns = %i and weights_rows = %i!", + x->sizeAt(1), w->sizeAt(0)); - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - sd::ops::xw_plus_b op; - auto status = op.execute({x, w, b}, {output}); - REQUIRE_TRUE(Status::OK() == status, 0, "relu_layer: xw_plus_b op failed on input data."); + sd::ops::xw_plus_b op; + auto status = op.execute({x, w, b}, {output}); + REQUIRE_TRUE(Status::OK() == status, 0, + "relu_layer: xw_plus_b op failed on input data."); - auto scalar = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; + auto scalar = block.numT() > 0 ? block.getTArguments().at(0) : 0.0; - output->applyScalar(sd::scalar::RELU, scalar, *output); + output->applyScalar(sd::scalar::RELU, scalar, *output); - return Status::OK(); - } - - DECLARE_SHAPE_FN(relu_layer) { - auto inShape = inputShape->at(0); - auto weightsShape = inputShape->at(1); - auto outputShape = ShapeUtils::matrixProductShape(inShape, weightsShape, false, false, ArrayOptions::dataType(inShape), block.workspace()); + return Status::OK(); +} - return SHAPELIST(outputShape); - } +DECLARE_SHAPE_FN(relu_layer) { + auto inShape = inputShape->at(0); + auto weightsShape = inputShape->at(1); + auto outputShape = ShapeUtils::matrixProductShape( + inShape, weightsShape, false, false, ArrayOptions::dataType(inShape), + block.workspace()); - DECLARE_TYPES(relu_layer) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) -// ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } + return SHAPELIST(outputShape); } +DECLARE_TYPES(relu_layer) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + // ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/nn/softmax.cpp b/libnd4j/include/ops/declarable/generic/nn/softmax.cpp index bc292e09022a..fba7d33fe91a 100644 --- a/libnd4j/include/ops/declarable/generic/nn/softmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/softmax.cpp @@ -28,53 +28,61 @@ namespace sd { namespace ops { - DECLARE_TYPES(softmax) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(true); - } +DECLARE_TYPES(softmax) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(true); +} CONFIGURABLE_OP_IMPL(softmax, 1, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - const int rank = input->rankOf(); - const int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; + const int rank = input->rankOf(); + const int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; - REQUIRE_TRUE(dim < rank, 0, "SOFTMAX OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); + REQUIRE_TRUE( + dim < rank, 0, + "SOFTMAX OP: the value of input integer parameter (dimension) must be " + "less than input array rank %i, but got dimension = %i instead !", + rank, dim); - helpers::softmax(block.launchContext(), *input, *output, dim); + helpers::softmax(block.launchContext(), *input, *output, dim); - return Status::OK(); + return Status::OK(); } - CONFIGURABLE_OP_IMPL(softmax_bp, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); - const int rank = input->rankOf(); - const int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; + const int rank = input->rankOf(); + const int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; - REQUIRE_TRUE(dim < rank, 0, "SOFTMAX_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); + REQUIRE_TRUE( + dim < rank, 0, + "SOFTMAX_BP OP: the value of input integer parameter (dimension) must be " + "less than input array rank %i, but got dimension = %i instead !", + rank, dim); - helpers::softmax(block.launchContext(), *input, *gradI, dim); + helpers::softmax(block.launchContext(), *input, *gradI, dim); - auto sumAlongDim = (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true); - gradI->assign(*gradI * (*gradO - sumAlongDim)); + auto sumAlongDim = + (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true); + gradI->assign(*gradI * (*gradO - sumAlongDim)); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(softmax_bp) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - -} +DECLARE_TYPES(softmax_bp) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp b/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp index 34f612887b0b..229dd02b3658 100644 --- a/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/xw_plus_b.cpp @@ -14,129 +14,156 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // xw_plus_b op. Created by GS 31.01.2018 - // @author Oleg Semeniv - // - // +// +// xw_plus_b op. Created by GS 31.01.2018 +// @author Oleg Semeniv +// +// #include #if NOT_EXCLUDED(OP_xw_plus_b) +#include #include #include -#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(xw_plus_b, 3, 1, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); - - auto b = INPUT_VARIABLE(2); - auto z = OUTPUT_VARIABLE(0); - - if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty()) - return Status::OK(); - - const bool bTranspose = (block.numI() > 0 ? INT_ARG(0) == 1 : false); - - auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1); - - REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b: Input x array should have rank equal 2, but got instead %i!", x->rankOf()); - REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b: Input weights array should have rank equal 2, but got instead %i!", w->rankOf()); - REQUIRE_TRUE(z->rankOf() == 2, 0, "xw_plus_b: Output array should have rank equal 2, but got instead %i!", z->rankOf()); - - REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b: Input bias vector should be 1D and have proper dimension 1x%i." - " But got rank %i, and got length %i instead %i.", z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1)); - - // multiply x to y - MmulHelper::mmul(x, w, z, 1.0, 0.0); +namespace ops { +CUSTOM_OP_IMPL(xw_plus_b, 3, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); - // adding b vector - z->addiRowVector(*b); + auto b = INPUT_VARIABLE(2); + auto z = OUTPUT_VARIABLE(0); - if (bTranspose) - delete w; + if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty()) + return Status::OK(); - return Status::OK(); - } + const bool bTranspose = (block.numI() > 0 ? INT_ARG(0) == 1 : false); - DECLARE_SHAPE_FN(xw_plus_b) { + auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) + : INPUT_VARIABLE(1); - auto weights = INPUT_VARIABLE(1); + REQUIRE_TRUE( + x->rankOf() == 2, 0, + "xw_plus_b: Input x array should have rank equal 2, but got instead %i!", + x->rankOf()); + REQUIRE_TRUE(w->rankOf() == 2, 0, + "xw_plus_b: Input weights array should have rank equal 2, but " + "got instead %i!", + w->rankOf()); + REQUIRE_TRUE( + z->rankOf() == 2, 0, + "xw_plus_b: Output array should have rank equal 2, but got instead %i!", + z->rankOf()); - const int nWeightsFormat = block.numI() > 0 ? INT_ARG(0) : 0; + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, + "xw_plus_b: Input bias vector should be 1D and have proper " + "dimension 1x%i." + " But got rank %i, and got length %i instead %i.", + z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1)); - auto weightsShape = (1 == nWeightsFormat) ? ShapeUtils::evalTranspShapeInfo(*weights, block.workspace()) : inputShape->at(1); + // multiply x to y + MmulHelper::mmul(x, w, z, 1.0, 0.0); - auto outputShape = ShapeUtils::matrixProductShape(inputShape->at(0), weightsShape, false, false, - ArrayOptions::dataType(inputShape->at(0)), block.workspace()); + // adding b vector + z->addiRowVector(*b); - return SHAPELIST(outputShape); - } + if (bTranspose) delete w; - DECLARE_TYPES(xw_plus_b) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ ALL_FLOATS }); - } - - - CUSTOM_OP_IMPL(xw_plus_b_bp, 4, 3, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(2); - auto dLdz = INPUT_VARIABLE(3); - - auto dLdx = OUTPUT_VARIABLE(0); - auto dLdb = OUTPUT_VARIABLE(2); - - if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty() || dLdz->isEmpty()) - return Status::OK(); + return Status::OK(); +} - const bool bTranspose = (block.numI() > 0 ? INT_ARG(0) == 1 : false); +DECLARE_SHAPE_FN(xw_plus_b) { + auto weights = INPUT_VARIABLE(1); - auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) : INPUT_VARIABLE(1); + const int nWeightsFormat = block.numI() > 0 ? INT_ARG(0) : 0; - REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP: Input x array should have rank equal 2, but got instead %i!", x->rankOf()); - REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP: Input weights array should have rank equal 2, but got instead %i!", w->rankOf()); - REQUIRE_TRUE(dLdz->rankOf() == 2, 0, "xw_plus_b BP: Output array should have rank equal 2, but got instead %i!", dLdz->rankOf()); - REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(-1), 0, "xw_plus_b BP: Input bias vector should be 1D and have proper dimension 1x%i." - " But got rank %i, and got length %i instead %i.", dLdz->sizeAt(-1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(-1)); + auto weightsShape = (1 == nWeightsFormat) ? ShapeUtils::evalTranspShapeInfo( + *weights, block.workspace()) + : inputShape->at(1); - auto dLdw = (bTranspose) ? new NDArray(OUTPUT_VARIABLE(1)->transpose()) : OUTPUT_VARIABLE(1); + auto outputShape = ShapeUtils::matrixProductShape( + inputShape->at(0), weightsShape, false, false, + ArrayOptions::dataType(inputShape->at(0)), block.workspace()); - // dLdb - dLdb->assign(dLdz->reduceAlongDimension(reduce::Sum, { 0 })); + return SHAPELIST(outputShape); +} - matmul_bp mmul_bp; - mmul_bp.execute({ x, w, dLdz }, std::vector{dLdx, dLdw}, {}, {}, {}); +DECLARE_TYPES(xw_plus_b) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - if (bTranspose) { - delete w; - delete dLdw; - } - return Status::OK(); - } +CUSTOM_OP_IMPL(xw_plus_b_bp, 4, 3, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(2); + auto dLdz = INPUT_VARIABLE(3); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdb = OUTPUT_VARIABLE(2); + + if (x->isEmpty() || INPUT_VARIABLE(1)->isEmpty() || b->isEmpty() || + dLdz->isEmpty()) + return Status::OK(); + + const bool bTranspose = (block.numI() > 0 ? INT_ARG(0) == 1 : false); + + auto w = bTranspose ? new NDArray(INPUT_VARIABLE(1)->transpose()) + : INPUT_VARIABLE(1); + + REQUIRE_TRUE(x->rankOf() == 2, 0, + "xw_plus_b BP: Input x array should have rank equal 2, but got " + "instead %i!", + x->rankOf()); + REQUIRE_TRUE(w->rankOf() == 2, 0, + "xw_plus_b BP: Input weights array should have rank equal 2, " + "but got instead %i!", + w->rankOf()); + REQUIRE_TRUE(dLdz->rankOf() == 2, 0, + "xw_plus_b BP: Output array should have rank equal 2, but got " + "instead %i!", + dLdz->rankOf()); + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(-1), 0, + "xw_plus_b BP: Input bias vector should be 1D and have proper " + "dimension 1x%i." + " But got rank %i, and got length %i instead %i.", + dLdz->sizeAt(-1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(-1)); + + auto dLdw = (bTranspose) ? new NDArray(OUTPUT_VARIABLE(1)->transpose()) + : OUTPUT_VARIABLE(1); + + // dLdb + dLdb->assign(dLdz->reduceAlongDimension(reduce::Sum, {0})); + + matmul_bp mmul_bp; + mmul_bp.execute({x, w, dLdz}, std::vector{dLdx, dLdw}, {}, {}, {}); + + if (bTranspose) { + delete w; + delete dLdw; + } + return Status::OK(); +} - DECLARE_SHAPE_FN(xw_plus_b_bp) { - Nd4jLong* xShapeInfo; - Nd4jLong* wShapeInfo; - Nd4jLong* bShapeInfo; +DECLARE_SHAPE_FN(xw_plus_b_bp) { + Nd4jLong* xShapeInfo; + Nd4jLong* wShapeInfo; + Nd4jLong* bShapeInfo; - COPY_SHAPE(inputShape->at(0), xShapeInfo); - COPY_SHAPE(inputShape->at(1), wShapeInfo); - COPY_SHAPE(inputShape->at(2), bShapeInfo); - return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(wShapeInfo), CONSTANT(bShapeInfo)); - } + COPY_SHAPE(inputShape->at(0), xShapeInfo); + COPY_SHAPE(inputShape->at(1), wShapeInfo); + COPY_SHAPE(inputShape->at(2), bShapeInfo); + return SHAPELIST(CONSTANT(xShapeInfo), CONSTANT(wShapeInfo), + CONSTANT(bShapeInfo)); +} - DECLARE_TYPES(xw_plus_b_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ ALL_FLOATS }); - } - } +DECLARE_TYPES(xw_plus_b_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/parity_ops.cpp b/libnd4j/include/ops/declarable/generic/parity_ops.cpp index 3595512a258f..32846bf43c5b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops.cpp @@ -23,27 +23,25 @@ #ifndef LIBND4J_PARITY_OPS_H #define LIBND4J_PARITY_OPS_H -#include -#include -#include -#include -#include -#include #include +#include #include +#include +#include +#include +#include +#include #include #include -#include -#include -#include #include -#include - -namespace sd { - namespace ops { +#include +#include - } -} +#include +#include -#endif //LIBND4J_PARITY_OPS_H +namespace sd { +namespace ops {} +} // namespace sd +#endif // LIBND4J_PARITY_OPS_H diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/assert.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/assert.cpp index 362d51c83c35..b796e715bd23 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/assert.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/assert.cpp @@ -25,22 +25,21 @@ #include namespace sd { - namespace ops { - OP_IMPL(Assert, 1, 1, false) { - auto x = INPUT_VARIABLE(0); +namespace ops { +OP_IMPL(Assert, 1, 1, false) { + auto x = INPUT_VARIABLE(0); - if (!x->e(0)) { - REQUIRE_TRUE(false, 0, "Assertion failed for node [%i]\n", block.getNodeId()); - } + if (!x->e(0)) { + REQUIRE_TRUE(false, 0, "Assertion failed for node [%i]\n", + block.getNodeId()); + } - return Status::OK(); - } - DECLARE_TYPES(Assert) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setSameMode(true); - } - } + return Status::OK(); } +DECLARE_TYPES(Assert) { + getOpDescriptor()->setAllowedInputTypes(DataType::ANY)->setSameMode(true); +} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/bincount.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/bincount.cpp index 3b9fc3916fa3..af8de75b5a71 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/bincount.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/bincount.cpp @@ -26,98 +26,97 @@ #include namespace sd { - namespace ops { - DECLARE_TYPES(bincount) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::INT32) - ->setAllowedInputTypes(1, sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - - CUSTOM_OP_IMPL(bincount, 1, 1, false, 0, 0) { - auto values = INPUT_VARIABLE(0); - - NDArray *weights = nullptr; - - int maxLength = -1; - int minLength = 0; - int maxIndex = values->argMax(); - maxLength = values->e(maxIndex) + 1; - - if (block.numI() > 0) { - minLength = sd::math::nd4j_max(INT_ARG(0), 0); - if (block.numI() == 2) - maxLength = sd::math::nd4j_min(maxLength, INT_ARG(1)); - } - - if (block.width() == 2) { // the second argument is weights - weights = INPUT_VARIABLE(1); - REQUIRE_TRUE(values->isSameShape(weights), 0, "bincount: the input and weights shapes should be equals"); - } - else if (block.width() == 3) { // the second argument is min and the third is max - auto min= INPUT_VARIABLE(1); - auto max = INPUT_VARIABLE(2); - minLength = min->e(0); - maxLength = max->e(0); - } - else if (block.width() > 3) { - auto min= INPUT_VARIABLE(2); - auto max = INPUT_VARIABLE(3); - minLength = min->e(0); - maxLength = max->e(0); - weights = INPUT_VARIABLE(1); - REQUIRE_TRUE(values->isSameShape(weights), 0, "bincount: the input and weights shapes should be equals"); - - } - minLength = sd::math::nd4j_max(minLength, 0); - maxLength = sd::math::nd4j_min(maxLength, values->e(maxIndex) + 1); - - auto result = OUTPUT_VARIABLE(0); - result->assign(0.0f); - - helpers::adjustWeights(block.launchContext(), values, weights, result, minLength, maxLength); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(bincount) { - auto shapeList = SHAPELIST(); - auto in = INPUT_VARIABLE(0); - sd::DataType dtype = DataType::INT32; - if (block.width() > 1) - dtype = ArrayOptions::dataType(inputShape->at(1)); - else if (block.numI() > 2) - dtype = (sd::DataType)INT_ARG(2); - - int maxIndex = in->argMax(); - int maxLength = in->e(maxIndex) + 1; - int outLength = maxLength; - if (block.numI() > 0) - outLength = sd::math::nd4j_max(maxLength, INT_ARG(0)); - - if (block.numI() > 1) - outLength = sd::math::nd4j_min(outLength, INT_ARG(1)); - - if (block.width() == 3) { // the second argument is min and the third is max - auto min= INPUT_VARIABLE(1)->e(0); - auto max = INPUT_VARIABLE(2)->e(0); - outLength = sd::math::nd4j_max(maxLength, min); - outLength = sd::math::nd4j_min(outLength, max); - } - else if (block.width() > 3) { - auto min= INPUT_VARIABLE(2); - auto max = INPUT_VARIABLE(3); - outLength = sd::math::nd4j_max(maxLength, min->e(0)); - outLength = sd::math::nd4j_min(outLength, max->e(0)); - } - - auto newshape = ConstantShapeHelper::getInstance()->vectorShapeInfo(outLength, dtype); - - shapeList->push_back(newshape); - return shapeList; - } - - } +namespace ops { +DECLARE_TYPES(bincount) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::INT32) + ->setAllowedInputTypes(1, sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +CUSTOM_OP_IMPL(bincount, 1, 1, false, 0, 0) { + auto values = INPUT_VARIABLE(0); + + NDArray *weights = nullptr; + + int maxLength = -1; + int minLength = 0; + int maxIndex = values->argMax(); + maxLength = values->e(maxIndex) + 1; + + if (block.numI() > 0) { + minLength = sd::math::nd4j_max(INT_ARG(0), 0); + if (block.numI() == 2) + maxLength = sd::math::nd4j_min(maxLength, INT_ARG(1)); + } + + if (block.width() == 2) { // the second argument is weights + weights = INPUT_VARIABLE(1); + REQUIRE_TRUE(values->isSameShape(weights), 0, + "bincount: the input and weights shapes should be equals"); + } else if (block.width() == + 3) { // the second argument is min and the third is max + auto min = INPUT_VARIABLE(1); + auto max = INPUT_VARIABLE(2); + minLength = min->e(0); + maxLength = max->e(0); + } else if (block.width() > 3) { + auto min = INPUT_VARIABLE(2); + auto max = INPUT_VARIABLE(3); + minLength = min->e(0); + maxLength = max->e(0); + weights = INPUT_VARIABLE(1); + REQUIRE_TRUE(values->isSameShape(weights), 0, + "bincount: the input and weights shapes should be equals"); + } + minLength = sd::math::nd4j_max(minLength, 0); + maxLength = sd::math::nd4j_min(maxLength, values->e(maxIndex) + 1); + + auto result = OUTPUT_VARIABLE(0); + result->assign(0.0f); + + helpers::adjustWeights(block.launchContext(), values, weights, result, + minLength, maxLength); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(bincount) { + auto shapeList = SHAPELIST(); + auto in = INPUT_VARIABLE(0); + sd::DataType dtype = DataType::INT32; + if (block.width() > 1) + dtype = ArrayOptions::dataType(inputShape->at(1)); + else if (block.numI() > 2) + dtype = (sd::DataType)INT_ARG(2); + + int maxIndex = in->argMax(); + int maxLength = in->e(maxIndex) + 1; + int outLength = maxLength; + if (block.numI() > 0) outLength = sd::math::nd4j_max(maxLength, INT_ARG(0)); + + if (block.numI() > 1) outLength = sd::math::nd4j_min(outLength, INT_ARG(1)); + + if (block.width() == 3) { // the second argument is min and the third is max + auto min = INPUT_VARIABLE(1)->e(0); + auto max = INPUT_VARIABLE(2)->e(0); + outLength = sd::math::nd4j_max(maxLength, min); + outLength = sd::math::nd4j_min(outLength, max); + } else if (block.width() > 3) { + auto min = INPUT_VARIABLE(2); + auto max = INPUT_VARIABLE(3); + outLength = sd::math::nd4j_max(maxLength, min->e(0)); + outLength = sd::math::nd4j_min(outLength, max->e(0)); + } + + auto newshape = + ConstantShapeHelper::getInstance()->vectorShapeInfo(outLength, dtype); + + shapeList->push_back(newshape); + return shapeList; +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/broadcast_dynamic_shape.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/broadcast_dynamic_shape.cpp index 4fc31dd51adc..b791c8987aa1 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/broadcast_dynamic_shape.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/broadcast_dynamic_shape.cpp @@ -28,66 +28,81 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(broadcast_dynamic_shape, 2, 1, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - - auto z = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(x->rankOf() == 1, 0, "BROADCAST_DYNAMIC_SHAPE OP: the first input array must have rank = 1, but got %i instead!", x->rankOf()); - REQUIRE_TRUE(y->rankOf() == 1, 0, "BROADCAST_DYNAMIC_SHAPE OP: the second input array must have rank = 1, but got %i instead!", y->rankOf()); - REQUIRE_TRUE(x->dataType() == y->dataType(), 0, "BROADCAST_DYNAMIC_SHAPE OP: both input arrays must have the same integer type !"); - - // contract shapeInfos, neglect and don't fill strides, ews, order - // shapes are of interest only - std::vector xShapeInfo(shape::shapeInfoLength(x->lengthOf())); - std::vector yShapeInfo(shape::shapeInfoLength(y->lengthOf())); - - // fill rank and data type - xShapeInfo[0] = x->lengthOf(); - yShapeInfo[0] = y->lengthOf(); - ArrayOptions::setDataType(xShapeInfo.data(), sd::DataType::INT64); // fill with some data type, it doesn't matter what type exactly to choose - ArrayOptions::setDataType(yShapeInfo.data(), sd::DataType::INT64); - - for (Nd4jLong i = 0; i < x->lengthOf(); ++i) - xShapeInfo[i + 1] = x->e(i); - - for (Nd4jLong i = 0; i < y->lengthOf(); ++i) - yShapeInfo[i + 1] = y->e(i); - - const Nd4jLong* poinerOnOutShapeInfo = nullptr; - - const bool isBroadcastPossible = ShapeUtils::evalBroadcastShapeInfo(xShapeInfo.data(), yShapeInfo.data(), true, poinerOnOutShapeInfo, block.launchContext()->getWorkspace()); - - REQUIRE_TRUE(isBroadcastPossible, 0, "BROADCAST_DYNAMIC_SHAPE OP: the shapes of two input arrays %s and %s are not suitable for broadcast operation !", ShapeUtils::shapeAsString(xShapeInfo.data()).c_str(), ShapeUtils::shapeAsString(yShapeInfo.data()).c_str()); - - for (Nd4jLong i = 0; i < z->lengthOf(); ++i) - z->p(i, poinerOnOutShapeInfo[i + 1]); - - return Status::OK(); + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + + auto z = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(x->rankOf() == 1, 0, + "BROADCAST_DYNAMIC_SHAPE OP: the first input array must have " + "rank = 1, but got %i instead!", + x->rankOf()); + REQUIRE_TRUE(y->rankOf() == 1, 0, + "BROADCAST_DYNAMIC_SHAPE OP: the second input array must have " + "rank = 1, but got %i instead!", + y->rankOf()); + REQUIRE_TRUE(x->dataType() == y->dataType(), 0, + "BROADCAST_DYNAMIC_SHAPE OP: both input arrays must have the " + "same integer type !"); + + // contract shapeInfos, neglect and don't fill strides, ews, order + // shapes are of interest only + std::vector xShapeInfo(shape::shapeInfoLength(x->lengthOf())); + std::vector yShapeInfo(shape::shapeInfoLength(y->lengthOf())); + + // fill rank and data type + xShapeInfo[0] = x->lengthOf(); + yShapeInfo[0] = y->lengthOf(); + ArrayOptions::setDataType( + xShapeInfo.data(), + sd::DataType::INT64); // fill with some data type, it doesn't matter what + // type exactly to choose + ArrayOptions::setDataType(yShapeInfo.data(), sd::DataType::INT64); + + for (Nd4jLong i = 0; i < x->lengthOf(); ++i) + xShapeInfo[i + 1] = x->e(i); + + for (Nd4jLong i = 0; i < y->lengthOf(); ++i) + yShapeInfo[i + 1] = y->e(i); + + const Nd4jLong* poinerOnOutShapeInfo = nullptr; + + const bool isBroadcastPossible = ShapeUtils::evalBroadcastShapeInfo( + xShapeInfo.data(), yShapeInfo.data(), true, poinerOnOutShapeInfo, + block.launchContext()->getWorkspace()); + + REQUIRE_TRUE(isBroadcastPossible, 0, + "BROADCAST_DYNAMIC_SHAPE OP: the shapes of two input arrays %s " + "and %s are not suitable for broadcast operation !", + ShapeUtils::shapeAsString(xShapeInfo.data()).c_str(), + ShapeUtils::shapeAsString(yShapeInfo.data()).c_str()); + + for (Nd4jLong i = 0; i < z->lengthOf(); ++i) + z->p(i, poinerOnOutShapeInfo[i + 1]); + + return Status::OK(); } DECLARE_TYPES(broadcast_dynamic_shape) { - getOpDescriptor() - ->setAllowedOutputTypes({ALL_INTS}) - ->setAllowedInputTypes({ALL_INTS}); + getOpDescriptor() + ->setAllowedOutputTypes({ALL_INTS}) + ->setAllowedInputTypes({ALL_INTS}); } - ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(broadcast_dynamic_shape) { + const int xRank = INPUT_VARIABLE(0)->lengthOf(); + const int yRank = INPUT_VARIABLE(1)->lengthOf(); - const int xRank = INPUT_VARIABLE(0)->lengthOf(); - const int yRank = INPUT_VARIABLE(1)->lengthOf(); - - const int maxRank = xRank > yRank ? xRank : yRank; + const int maxRank = xRank > yRank ? xRank : yRank; - auto outputShapeInfo = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxRank, ArrayOptions::dataType(inputShape->at(0))); + auto outputShapeInfo = ConstantShapeHelper::getInstance()->vectorShapeInfo( + maxRank, ArrayOptions::dataType(inputShape->at(0))); - return SHAPELIST(outputShapeInfo); + return SHAPELIST(outputShapeInfo); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/check_numerics.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/check_numerics.cpp index 561c6bb5b17e..5dda98299363 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/check_numerics.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/check_numerics.cpp @@ -24,33 +24,34 @@ #include namespace sd { - namespace ops { +namespace ops { - CUSTOM_OP_IMPL(check_numerics, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto message = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); +CUSTOM_OP_IMPL(check_numerics, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto message = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); - auto allFinite = input->reduceNumber(reduce::BoolOps::IsFinite); - REQUIRE_TRUE(allFinite.e(0), 0, "CheckNumerics: %s", message->e(0).c_str()); + auto allFinite = input->reduceNumber(reduce::BoolOps::IsFinite); + REQUIRE_TRUE(allFinite.e(0), 0, "CheckNumerics: %s", + message->e(0).c_str()); - if (!block.isInplace()) - output->assign(input); + if (!block.isInplace()) output->assign(input); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(check_numerics) { - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inputShape->at(0)))); - } +DECLARE_SHAPE_FN(check_numerics) { + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(inputShape->at(0)))); +} - DECLARE_TYPES(check_numerics) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, sd::DataType::UTF8) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(check_numerics) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, sd::DataType::UTF8) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp index 1decc65f0b7e..c1162b13a249 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/compare_and_bitpack.cpp @@ -14,47 +14,50 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author sgazeos@gmail.com - // +// +// @author sgazeos@gmail.com +// +#include #include -#include #include -#include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(compare_and_bitpack, 2, 1, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - auto z0 = NDArrayFactory::create(x->ordering(), x->getShapeAsVector(), block.launchContext()); - BROADCAST_CHECK_EMPTY(x, y, (&z0)); - - auto tZ = BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0); - bitcast res; - auto status = res.execute({ tZ }, { z }, {}, { DataType::UINT8 }, {}, {}, false); - if (tZ != &z0) { - delete tZ; - } - - return status; - } - - DECLARE_TYPES(compare_and_bitpack) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::UINT8); - } - - DECLARE_SHAPE_FN(compare_and_bitpack) { - auto inShape = inputShape->at(0); - DataType newType = DataType::UINT8; - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, newType))); - } - - } -} \ No newline at end of file +namespace ops { +CUSTOM_OP_IMPL(compare_and_bitpack, 2, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + auto z0 = NDArrayFactory::create(x->ordering(), x->getShapeAsVector(), + block.launchContext()); + BROADCAST_CHECK_EMPTY(x, y, (&z0)); + + auto tZ = + BroadcastHelper::broadcastApply(BROADCAST_BOOL(GreaterThan), x, y, &z0); + bitcast res; + auto status = res.execute({tZ}, {z}, {}, {DataType::UINT8}, {}, {}, false); + if (tZ != &z0) { + delete tZ; + } + + return status; +} + +DECLARE_TYPES(compare_and_bitpack) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, DataType::ANY) + ->setAllowedOutputTypes(0, DataType::UINT8); +} + +DECLARE_SHAPE_FN(compare_and_bitpack) { + auto inShape = inputShape->at(0); + DataType newType = DataType::UINT8; + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(inShape, newType))); +} + +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp index 36e0cef133de..cbaf29654e9c 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/confusion_matrix.cpp @@ -21,66 +21,79 @@ #include #if NOT_EXCLUDED(OP_confusion_matrix) -#include -#include #include #include -#include +#include +#include #include +#include + namespace sd { - namespace ops { - DECLARE_TYPES(confusion_matrix) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); - } +namespace ops { +DECLARE_TYPES(confusion_matrix) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); +} - CUSTOM_OP_IMPL(confusion_matrix, 2, 1, false, 0, -2) { +CUSTOM_OP_IMPL(confusion_matrix, 2, 1, false, 0, -2) { + auto labels = INPUT_VARIABLE(0); + auto predictions = INPUT_VARIABLE(1); + NDArray *weights = nullptr; + if (block.width() > 2) { + weights = INPUT_VARIABLE(2); + REQUIRE_TRUE( + weights->isSameShape(predictions), 0, + "CONFUSION_MATRIX: Weights and predictions should have equal shape"); + } + auto output = OUTPUT_NULLIFIED(0); - auto labels = INPUT_VARIABLE(0); - auto predictions = INPUT_VARIABLE(1); - NDArray *weights = nullptr; - if(block.width() > 2){ - weights = INPUT_VARIABLE(2); - REQUIRE_TRUE(weights->isSameShape(predictions),0, "CONFUSION_MATRIX: Weights and predictions should have equal shape"); - } - auto output = OUTPUT_NULLIFIED(0); + int minPrediction = predictions->reduceNumber(reduce::Min).e(0); + int minLabel = labels->reduceNumber(reduce::Min).e(0); - int minPrediction = predictions->reduceNumber(reduce::Min).e(0); - int minLabel = labels->reduceNumber(reduce::Min).e(0); + REQUIRE_TRUE(minLabel >= 0, 0, + "CONFUSION_MATRIX: Labels contains negative values !"); + REQUIRE_TRUE(minPrediction >= 0, 0, + "CONFUSION_MATRIX: Predictions contains negative values !"); + REQUIRE_TRUE( + labels->isVector(), 0, + "CONFUSION_MATRIX: Labels input should be a Vector, but got %iD instead", + labels->rankOf()); + REQUIRE_TRUE(predictions->isVector(), 0, + "CONFUSION_MATRIX: Predictions input should be Vector, but got " + "%iD instead", + predictions->rankOf()); + REQUIRE_TRUE( + labels->isSameShape(predictions), 0, + "CONFUSION_MATRIX: Labels and predictions should have equal shape"); - REQUIRE_TRUE(minLabel >=0, 0, "CONFUSION_MATRIX: Labels contains negative values !"); - REQUIRE_TRUE(minPrediction >=0, 0, "CONFUSION_MATRIX: Predictions contains negative values !"); - REQUIRE_TRUE(labels->isVector(), 0, "CONFUSION_MATRIX: Labels input should be a Vector, but got %iD instead", labels->rankOf()); - REQUIRE_TRUE(predictions->isVector(), 0, "CONFUSION_MATRIX: Predictions input should be Vector, but got %iD instead", predictions->rankOf()); - REQUIRE_TRUE(labels->isSameShape(predictions),0, "CONFUSION_MATRIX: Labels and predictions should have equal shape"); + helpers::confusionFunctor(block.launchContext(), labels, predictions, weights, + output); - helpers::confusionFunctor(block.launchContext(), labels, predictions, weights, output); + return Status::OK(); +} - return Status::OK(); - } +DECLARE_SHAPE_FN(confusion_matrix) { + auto labels = INPUT_VARIABLE(0); + auto predictions = INPUT_VARIABLE(1); + auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64; + int numClasses = 0; - DECLARE_SHAPE_FN(confusion_matrix) { - auto labels = INPUT_VARIABLE(0); - auto predictions = INPUT_VARIABLE(1); - auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64; - int numClasses = 0; + if (block.numI() > 0) { + numClasses = INT_ARG(0); + } else { + int maxPrediction = predictions->reduceNumber(reduce::Max).e(0); + int maxLabel = labels->reduceNumber(reduce::Max).e(0); + numClasses = (maxPrediction >= maxLabel) ? maxPrediction + 1 : maxLabel + 1; + } - if (block.numI() > 0) { - numClasses = INT_ARG(0); - } - else { - int maxPrediction = predictions->reduceNumber(reduce::Max).e(0); - int maxLabel = labels->reduceNumber(reduce::Max).e(0); - numClasses = (maxPrediction >= maxLabel) ? maxPrediction+1 : maxLabel+1; - } - - std::array shape = {{numClasses,numClasses}}; - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', 2, shape.data()); - return SHAPELIST(newShape); - } - } + std::array shape = {{numClasses, numClasses}}; + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + dtype, 'c', 2, shape.data()); + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp index bb9cbe5fd222..68db7a4d9d76 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp @@ -21,52 +21,49 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(expose, -1, -1, true, 0, 0) { +namespace ops { +CUSTOM_OP_IMPL(expose, -1, -1, true, 0, 0) { + for (int e = 0; e < block.width(); e++) { + auto inVar = block.variable(e); + if (inVar->variableType() == VariableType::NDARRAY) { + auto in = INPUT_VARIABLE(e); + auto out = OUTPUT_VARIABLE(e); - for (int e = 0; e < block.width(); e++) { - auto inVar = block.variable(e); - if (inVar->variableType() == VariableType::NDARRAY) { - auto in = INPUT_VARIABLE(e); - auto out = OUTPUT_VARIABLE(e); - - out->assign(in); - } else if (inVar->variableType() == VariableType::ARRAY_LIST) { - auto var = block.ensureVariable(block.name(), block.nodeId(), e); - if (!var->hasNDArrayList()) { - auto list = inVar->getNDArrayList(); - - //block.pushNDArrayListToVariableSpace(block.nodeId(), e, list); - throw std::runtime_error("Expose - not implemented yet"); - } - } - } - - return ND4J_STATUS_OK; - } - DECLARE_SYN(Enter, expose); - DECLARE_SYN(enter, expose); + out->assign(in); + } else if (inVar->variableType() == VariableType::ARRAY_LIST) { + auto var = block.ensureVariable(block.name(), block.nodeId(), e); + if (!var->hasNDArrayList()) { + auto list = inVar->getNDArrayList(); + // block.pushNDArrayListToVariableSpace(block.nodeId(), e, list); + throw std::runtime_error("Expose - not implemented yet"); + } + } + } - DECLARE_TYPES(expose) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } + return ND4J_STATUS_OK; +} +DECLARE_SYN(Enter, expose); +DECLARE_SYN(enter, expose); - DECLARE_SHAPE_FN(expose) { - auto shapeList = SHAPELIST(); +DECLARE_TYPES(expose) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} - for (int e = 0; e < block.width(); e++) { - auto p = block.input(e); - auto var = block.getVariable(e); - if (var->variableType() == VariableType::NDARRAY) { - auto inShape = inputShape->at(e); - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape))); - } - } +DECLARE_SHAPE_FN(expose) { + auto shapeList = SHAPELIST(); - return shapeList; - } + for (int e = 0; e < block.width(); e++) { + auto p = block.input(e); + auto var = block.getVariable(e); + if (var->variableType() == VariableType::NDARRAY) { + auto inShape = inputShape->at(e); + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(inShape))); } -} \ No newline at end of file + } + + return shapeList; +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp index 2ce8643159e3..ebd22aa7b333 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp @@ -24,51 +24,55 @@ #include #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars, 1, 1, true, 0, 0) { +namespace ops { +CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars, 1, 1, true, 0, 0) { + auto x = INPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); + NDArray* min; + NDArray* max; - NDArray* min; - NDArray* max; + REQUIRE_TRUE(block.width() == 3 || block.numT() == 2, 0, + "fake_quant_with_min_max_vars: No minimum/maximum values " + "provided by either input arrays or TArgs"); - REQUIRE_TRUE(block.width() == 3 || block.numT() == 2, 0, "fake_quant_with_min_max_vars: No minimum/maximum values provided by either input arrays or TArgs"); + NDArray m; + NDArray m2; + if (block.width() == 3) { + min = INPUT_VARIABLE(1); + max = INPUT_VARIABLE(2); + } else if (block.numT() == 2) { + m = NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext()); + m2 = NDArrayFactory::create(x->dataType(), T_ARG(1), block.launchContext()); + min = &m; + max = &m2; + } + auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(x->dataType() == output->dataType(), 0, + "fake_quant_with_min_max_vars: input and output data types must " + "be the same"); - NDArray m; - NDArray m2; - if(block.width() == 3){ - min = INPUT_VARIABLE(1); - max = INPUT_VARIABLE(2); - } else if(block.numT() == 2){ - m = NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext()); - m2 = NDArrayFactory::create(x->dataType(), T_ARG(1), block.launchContext()); - min = &m; - max = &m2; - } - auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(x->dataType() == output->dataType(), 0, "fake_quant_with_min_max_vars: input and output data types must be the same"); - - int numBits = 8; - if (block.numI()) - numBits = INT_ARG(0); - bool narrowed = false; - if (block.numB()) { - narrowed = B_ARG(0); - } - REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars: Number of \ - bits for quantization should be in between 2 and 16, but %i was given.", numBits); - helpers::fakeQuantWithMinMaxVars(x, min, max, numBits, narrowed, output); - return ND4J_STATUS_OK; - } - - DECLARE_TYPES(fake_quant_with_min_max_vars) { - getOpDescriptor() - -> setAllowedOutputTypes({ALL_FLOATS}) - -> setAllowedInputTypes({ALL_INTS, ALL_FLOATS}); - } + int numBits = 8; + if (block.numI()) numBits = INT_ARG(0); + bool narrowed = false; + if (block.numB()) { + narrowed = B_ARG(0); + } + REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, + "fake_quant_with_min_max_vars: Number of \ + bits for quantization should be in between 2 and 16, but %i was given.", + numBits); + helpers::fakeQuantWithMinMaxVars(x, min, max, numBits, narrowed, output); + return ND4J_STATUS_OK; +} - DECLARE_SYN(fake_quant_with_min_max_args, fake_quant_with_min_max_vars); - } +DECLARE_TYPES(fake_quant_with_min_max_vars) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}); } +DECLARE_SYN(fake_quant_with_min_max_args, fake_quant_with_min_max_vars); +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp index c4f4b471e33c..5c4ae1754353 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars_per_channel.cpp @@ -24,49 +24,61 @@ #include #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 3, 1, true, 0, 0) { +namespace ops { +CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 3, 1, true, 0, + 0) { + auto x = INPUT_VARIABLE(0); + auto min = INPUT_VARIABLE(1); + auto max = INPUT_VARIABLE(2); - auto x = INPUT_VARIABLE(0); - auto min = INPUT_VARIABLE(1); - auto max = INPUT_VARIABLE(2); + auto depth = x->sizeAt(-1); + REQUIRE_TRUE(min->rankOf() == 1 && max->rankOf() == 1 && + min->lengthOf() == max->lengthOf(), + 0, + "fake_quant_with_min_max_vars_per_channel: Min and Max should " + "be 1D tensors with the same length"); + REQUIRE_TRUE(depth == min->lengthOf(), 0, + "fake_quant_with_min_max_vars_per_channel: Min length should be" + " %lld, but %lld occurs.", + depth, min->lengthOf()); - auto depth = x->sizeAt(-1); - REQUIRE_TRUE(min->rankOf() == 1 && max->rankOf() == 1 && min->lengthOf() == max->lengthOf(), 0, - "fake_quant_with_min_max_vars_per_channel: Min and Max should be 1D tensors with the same length"); - REQUIRE_TRUE(depth == min->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Min length should be" - " %lld, but %lld occurs.", depth, min->lengthOf()); + REQUIRE_TRUE(depth == max->lengthOf(), 0, + "fake_quant_with_min_max_vars_per_channel: Max length should be" + "%lld, but %lld occurs.", + depth, max->lengthOf()); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(depth == max->lengthOf(), 0, "fake_quant_with_min_max_vars_per_channel: Max length should be" - "%lld, but %lld occurs.", depth, max->lengthOf()); - auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(x->dataType() == output->dataType(), 0, + "fake_quant_with_min_max_vars_per_channel: input and output " + "data types must be the same"); - REQUIRE_TRUE(x->dataType() == output->dataType(), 0, "fake_quant_with_min_max_vars_per_channel: input and output data types must be the same"); + int numBits = 8; + if (block.numI()) numBits = INT_ARG(0); + bool narrowed = false; + // INT_ARG(1); + if (block.numB()) { + narrowed = B_ARG(0); + } - int numBits = 8; - if (block.numI()) - numBits = INT_ARG(0); - bool narrowed = false; - //INT_ARG(1); - if (block.numB()) { - narrowed = B_ARG(0); - } - - REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits" - " for quatization should be in between 2 and 16, but %i " - "was given.", numBits); - helpers::fakeQuantWithMinMaxVarsPerChannel(block.launchContext(), x, min, max, numBits, narrowed, output); - return ND4J_STATUS_OK; - } - - DECLARE_TYPES(fake_quant_with_min_max_vars_per_channel) { - getOpDescriptor() - -> setAllowedOutputTypes({ALL_FLOATS}) - -> setAllowedInputTypes({ALL_INTS, ALL_FLOATS}); - } + REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, + "fake_quant_with_min_max_vars_per_channel: Number of bits" + " for quatization should be in between 2 and 16, but %i " + "was given.", + numBits); + helpers::fakeQuantWithMinMaxVarsPerChannel(block.launchContext(), x, min, max, + numBits, narrowed, output); + return ND4J_STATUS_OK; +} - DECLARE_SYN(fake_quant_with_min_max_args_per_channel, fake_quant_with_min_max_vars_per_channel); - } +DECLARE_TYPES(fake_quant_with_min_max_vars_per_channel) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}); } +DECLARE_SYN(fake_quant_with_min_max_args_per_channel, + fake_quant_with_min_max_vars_per_channel); +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/in_top_k.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/in_top_k.cpp index a243842d276d..dd6008bd8e6f 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/in_top_k.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/in_top_k.cpp @@ -26,39 +26,50 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(in_top_k, 2, 1, true, 0, 1) { - auto predictions = INPUT_VARIABLE(0); - auto target = INPUT_VARIABLE(1); +namespace ops { +CUSTOM_OP_IMPL(in_top_k, 2, 1, true, 0, 1) { + auto predictions = INPUT_VARIABLE(0); + auto target = INPUT_VARIABLE(1); - auto result = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(block.numI() > 0, 0, "in_top_k: Parameter k is needed to be set"); - REQUIRE_TRUE(predictions->sizeAt(0) == target->lengthOf(), 0, "in_top_k: the number of predictions rows should be equal to target array length, but got %i and %i correspondingly !", predictions->sizeAt(0), target->lengthOf()); - REQUIRE_TRUE(predictions->rankOf() == 2, 0, "in_top_k: The predictions array should have rank 2, but %i given", predictions->rankOf()); - REQUIRE_TRUE(target->rankOf() == 1, 0, "in_top_k: The target should be a vector"); + auto result = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(block.numI() > 0, 0, + "in_top_k: Parameter k is needed to be set"); + REQUIRE_TRUE(predictions->sizeAt(0) == target->lengthOf(), 0, + "in_top_k: the number of predictions rows should be equal to " + "target array length, but got %i and %i correspondingly !", + predictions->sizeAt(0), target->lengthOf()); + REQUIRE_TRUE( + predictions->rankOf() == 2, 0, + "in_top_k: The predictions array should have rank 2, but %i given", + predictions->rankOf()); + REQUIRE_TRUE(target->rankOf() == 1, 0, + "in_top_k: The target should be a vector"); - int k = INT_ARG(0); - return helpers::inTopKFunctor(block.launchContext(), predictions, target, result, k); - } - - DECLARE_SHAPE_FN(in_top_k) { - auto shapeList = SHAPELIST(); - auto in = inputShape->at(1); - int shapeRank = shape::rank(in); + int k = INT_ARG(0); + return helpers::inTopKFunctor(block.launchContext(), predictions, target, + result, k); +} - auto aShape = ConstantShapeHelper::getInstance()->createShapeInfo(sd::DataType::BOOL, shape::order(in), shape::rank(in), shape::shapeOf(in)); - shapeList->push_back(aShape); - return shapeList; - } +DECLARE_SHAPE_FN(in_top_k) { + auto shapeList = SHAPELIST(); + auto in = inputShape->at(1); + int shapeRank = shape::rank(in); - DECLARE_TYPES(in_top_k) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes(DataType::BOOL); - } + auto aShape = ConstantShapeHelper::getInstance()->createShapeInfo( + sd::DataType::BOOL, shape::order(in), shape::rank(in), + shape::shapeOf(in)); + shapeList->push_back(aShape); + return shapeList; +} - } +DECLARE_TYPES(in_top_k) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes(DataType::BOOL); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/listdiff.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/listdiff.cpp index 49c7a29578d6..426bf8a2ce15 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/listdiff.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/listdiff.cpp @@ -26,46 +26,58 @@ // this op will probably never become GPU-compatible namespace sd { - namespace ops { - CUSTOM_OP_IMPL(listdiff, 2, 2, false, 0, 0) { - auto values = INPUT_VARIABLE(0); - auto keep = INPUT_VARIABLE(1); - auto output1 = OUTPUT_VARIABLE(0); - auto output2 = OUTPUT_VARIABLE(1); +namespace ops { +CUSTOM_OP_IMPL(listdiff, 2, 2, false, 0, 0) { + auto values = INPUT_VARIABLE(0); + auto keep = INPUT_VARIABLE(1); + auto output1 = OUTPUT_VARIABLE(0); + auto output2 = OUTPUT_VARIABLE(1); - REQUIRE_TRUE(values->rankOf() == 1, 0, "ListDiff: rank of values should be 1D, but got %iD instead", values->rankOf()); - REQUIRE_TRUE(keep->rankOf() == 1, 0, "ListDiff: rank of keep should be 1D, but got %iD instead", keep->rankOf()); - REQUIRE_TRUE(keep->dataType() == values->dataType(), 0, "ListDiff: both inputs must have same data type"); + REQUIRE_TRUE(values->rankOf() == 1, 0, + "ListDiff: rank of values should be 1D, but got %iD instead", + values->rankOf()); + REQUIRE_TRUE(keep->rankOf() == 1, 0, + "ListDiff: rank of keep should be 1D, but got %iD instead", + keep->rankOf()); + REQUIRE_TRUE(keep->dataType() == values->dataType(), 0, + "ListDiff: both inputs must have same data type"); - return helpers::listDiffFunctor(block.launchContext(), values, keep, output1, output2); - }; + return helpers::listDiffFunctor(block.launchContext(), values, keep, output1, + output2); +}; - DECLARE_SHAPE_FN(listdiff) { - auto values = INPUT_VARIABLE(0); - auto keep = INPUT_VARIABLE(1); +DECLARE_SHAPE_FN(listdiff) { + auto values = INPUT_VARIABLE(0); + auto keep = INPUT_VARIABLE(1); - REQUIRE_TRUE(values->rankOf() == 1, 0, "ListDiff: rank of values should be 1D, but got %iD instead", values->rankOf()); - REQUIRE_TRUE(keep->rankOf() == 1, 0, "ListDiff: rank of keep should be 1D, but got %iD instead", keep->rankOf()); - auto v = values->dataType(); - auto k = keep->dataType(); - REQUIRE_TRUE(k == v, 0, "ListDiff: both inputs must have same data type"); + REQUIRE_TRUE(values->rankOf() == 1, 0, + "ListDiff: rank of values should be 1D, but got %iD instead", + values->rankOf()); + REQUIRE_TRUE(keep->rankOf() == 1, 0, + "ListDiff: rank of keep should be 1D, but got %iD instead", + keep->rankOf()); + auto v = values->dataType(); + auto k = keep->dataType(); + REQUIRE_TRUE(k == v, 0, "ListDiff: both inputs must have same data type"); - auto saved = helpers::listDiffCount(block.launchContext(), values, keep); + auto saved = helpers::listDiffCount(block.launchContext(), values, keep); - REQUIRE_TRUE(saved > 0, 0, "ListDiff: no matches found"); + REQUIRE_TRUE(saved > 0, 0, "ListDiff: no matches found"); - auto shapeX = ConstantShapeHelper::getInstance()->vectorShapeInfo(saved, values->dataType()); - auto shapeY = ConstantShapeHelper::getInstance()->vectorShapeInfo(saved, DataType::INT64); - return SHAPELIST(shapeX, shapeY); - } + auto shapeX = ConstantShapeHelper::getInstance()->vectorShapeInfo( + saved, values->dataType()); + auto shapeY = ConstantShapeHelper::getInstance()->vectorShapeInfo( + saved, DataType::INT64); + return SHAPELIST(shapeX, shapeY); +} - DECLARE_TYPES(listdiff) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setAllowedOutputTypes(1, {ALL_INTS}); - } - } +DECLARE_TYPES(listdiff) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes(0, DataType::INHERIT) + ->setAllowedOutputTypes(1, {ALL_INTS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp index 67ecc23ec716..e12971f819c2 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp @@ -21,201 +21,226 @@ #include #include - namespace sd { - namespace ops { +namespace ops { #if NOT_EXCLUDED(OP_image_non_max_suppression) - CUSTOM_OP_IMPL(non_max_suppression, 2, 1, false, 0, 0) { - auto boxes = INPUT_VARIABLE(0); - auto scales = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - int maxOutputSize; // = INT_ARG(0); - if (block.width() > 2) - maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.numI() == 1) - maxOutputSize = INT_ARG(0); - else - REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); - - double overlayThreshold = 0.5; - double scoreThreshold = - DataTypeUtils::infOrMax(); - - if (block.width() > 3) { - overlayThreshold = INPUT_VARIABLE(3)->e(0); - } - else if (block.numT() > 0) { - overlayThreshold = T_ARG(0); - } - - if (block.width() > 4) { - scoreThreshold = INPUT_VARIABLE(4)->e(0); - } - else if (block.numT() > 1) { - scoreThreshold = T_ARG(1); - } - if (boxes->isEmpty() || scales->isEmpty()) - return Status::OK(); - - if (output->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, " - "but %i is given", boxes->rankOf()); - REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array " - "should be 4, but %i is given", boxes->sizeAt(1)); - REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, - "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf()); - REQUIRE_TRUE(overlayThreshold >= 0. && overlayThreshold <= 1., 0, "image.non_max_suppressio: The overlay " - "threashold should be in [0, 1], but " - "%lf is given.", overlayThreshold); - REQUIRE_TRUE(boxes->dataType() == scales->dataType(), 0, - "image.non_max_suppression: Boxes and scores inputs should have the same data type, but %s and %s " - "were given.", DataTypeUtils::asString(boxes->dataType()).c_str(), - DataTypeUtils::asString(scales->dataType()).c_str()); - helpers::nonMaxSuppression(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, - scoreThreshold, output); - return Status::OK(); - } +CUSTOM_OP_IMPL(non_max_suppression, 2, 1, false, 0, 0) { + auto boxes = INPUT_VARIABLE(0); + auto scales = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + int maxOutputSize; // = INT_ARG(0); + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.numI() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, + "image.non_max_suppression: Max output size argument cannot " + "be retrieved."); + + double overlayThreshold = 0.5; + double scoreThreshold = -DataTypeUtils::infOrMax(); + + if (block.width() > 3) { + overlayThreshold = INPUT_VARIABLE(3)->e(0); + } else if (block.numT() > 0) { + overlayThreshold = T_ARG(0); + } + + if (block.width() > 4) { + scoreThreshold = INPUT_VARIABLE(4)->e(0); + } else if (block.numT() > 1) { + scoreThreshold = T_ARG(1); + } + if (boxes->isEmpty() || scales->isEmpty()) return Status::OK(); + + if (output->isEmpty()) return Status::OK(); + + REQUIRE_TRUE( + boxes->rankOf() == 2, 0, + "image.non_max_suppression: The rank of boxes array should be 2, " + "but %i is given", + boxes->rankOf()); + REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, + "image.non_max_suppression: The last dimension of boxes array " + "should be 4, but %i is given", + boxes->sizeAt(1)); + REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), + 0, + "image.non_max_suppression: The rank of scales array should be " + "1, but %i is given", + boxes->rankOf()); + REQUIRE_TRUE(overlayThreshold >= 0. && overlayThreshold <= 1., 0, + "image.non_max_suppressio: The overlay " + "threashold should be in [0, 1], but " + "%lf is given.", + overlayThreshold); + REQUIRE_TRUE(boxes->dataType() == scales->dataType(), 0, + "image.non_max_suppression: Boxes and scores inputs should have " + "the same data type, but %s and %s " + "were given.", + DataTypeUtils::asString(boxes->dataType()).c_str(), + DataTypeUtils::asString(scales->dataType()).c_str()); + helpers::nonMaxSuppression(block.launchContext(), boxes, scales, + maxOutputSize, overlayThreshold, scoreThreshold, + output); + return Status::OK(); +} - DECLARE_SHAPE_FN(non_max_suppression) { - auto in = inputShape->at(0); - int outRank = shape::rank(in); - const Nd4jLong *outputShape = nullptr; - - int maxOutputSize; - if (block.width() > 2) - maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.numI() == 1) - maxOutputSize = INT_ARG(0); - else - REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); - - if (maxOutputSize > 0) { - auto actualIndicesCount = shape::sizeAt(in, 0); - if (block.numT() > 1 || block.width() > 4) { - auto scoreThreshold = - block.numT() > 1 ? T_ARG(1) : INPUT_VARIABLE(4)->e(0); - auto scales = INPUT_VARIABLE(1); - scales->syncToHost(); - for (auto e = 0; e < scales->lengthOf(); e++) { - if (scales->e(e) < (float) scoreThreshold) { - actualIndicesCount--; - } - } - } - if (actualIndicesCount < maxOutputSize) - maxOutputSize = actualIndicesCount; - } - outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxOutputSize, DataType::INT32); - - return SHAPELIST(outputShape); - } - DECLARE_TYPES(non_max_suppression) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INDICES}); +DECLARE_SHAPE_FN(non_max_suppression) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + const Nd4jLong *outputShape = nullptr; + + int maxOutputSize; + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.numI() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, + "image.non_max_suppression: Max output size argument cannot " + "be retrieved."); + + if (maxOutputSize > 0) { + auto actualIndicesCount = shape::sizeAt(in, 0); + if (block.numT() > 1 || block.width() > 4) { + auto scoreThreshold = + block.numT() > 1 ? T_ARG(1) : INPUT_VARIABLE(4)->e(0); + auto scales = INPUT_VARIABLE(1); + scales->syncToHost(); + for (auto e = 0; e < scales->lengthOf(); e++) { + if (scales->e(e) < (float)scoreThreshold) { + actualIndicesCount--; } + } + } + if (actualIndicesCount < maxOutputSize) maxOutputSize = actualIndicesCount; + } + outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + maxOutputSize, DataType::INT32); + + return SHAPELIST(outputShape); +} +DECLARE_TYPES(non_max_suppression) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INDICES}); +} #endif #if NOT_EXCLUDED(OP_image_non_max_suppression_v3) - DECLARE_TYPES(non_max_suppression_v3) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INDICES}); - } +DECLARE_TYPES(non_max_suppression_v3) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INDICES}); +} - CUSTOM_OP_IMPL(non_max_suppression_v3, 2, 1, false, 0, 0) { - auto boxes = INPUT_VARIABLE(0); - auto scales = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - int maxOutputSize; // = INT_ARG(0); - if (block.width() > 2) - maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.numI() == 1) - maxOutputSize = INT_ARG(0); - else - REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); - - double overlayThreshold = 0.5; - double scoreThreshold = - DataTypeUtils::infOrMax(); - - if (block.width() > 3) { - overlayThreshold = INPUT_VARIABLE(3)->e(0); - } - else if (block.numT() > 0) { - overlayThreshold = T_ARG(0); - } - - if (block.width() > 4) { - scoreThreshold = INPUT_VARIABLE(4)->e(0); - } - else if (block.numT() > 1) { - scoreThreshold = T_ARG(1); - } - if (boxes->isEmpty() || scales->isEmpty()) - return Status::OK(); - if (output->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, but " - "%i is given", boxes->rankOf()); - REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should " - "be 4, but %i is given", boxes->sizeAt(1)); - REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, - "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf()); - REQUIRE_TRUE(overlayThreshold >= 0. && overlayThreshold <= 1., 0, - "image.non_max_suppression_v3: The overlay threashold should be in [0, 1], but %lf given.", - overlayThreshold); - REQUIRE_TRUE(boxes->dataType() == scales->dataType(), 0, - "image.non_max_suppression_v3: Boxes and scores inputs should have the same data type, but %s and %s " - "were given.", DataTypeUtils::asString(boxes->dataType()).c_str(), - DataTypeUtils::asString(scales->dataType()).c_str()); - - helpers::nonMaxSuppressionV3(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, - scoreThreshold, output); - return Status::OK(); - } +CUSTOM_OP_IMPL(non_max_suppression_v3, 2, 1, false, 0, 0) { + auto boxes = INPUT_VARIABLE(0); + auto scales = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + int maxOutputSize; // = INT_ARG(0); + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.numI() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, + "image.non_max_suppression: Max output size argument cannot " + "be retrieved."); + + double overlayThreshold = 0.5; + double scoreThreshold = -DataTypeUtils::infOrMax(); + + if (block.width() > 3) { + overlayThreshold = INPUT_VARIABLE(3)->e(0); + } else if (block.numT() > 0) { + overlayThreshold = T_ARG(0); + } + + if (block.width() > 4) { + scoreThreshold = INPUT_VARIABLE(4)->e(0); + } else if (block.numT() > 1) { + scoreThreshold = T_ARG(1); + } + if (boxes->isEmpty() || scales->isEmpty()) return Status::OK(); + if (output->isEmpty()) return Status::OK(); + + REQUIRE_TRUE( + boxes->rankOf() == 2, 0, + "image.non_max_suppression: The rank of boxes array should be 2, but " + "%i is given", + boxes->rankOf()); + REQUIRE_TRUE( + boxes->sizeAt(1) == 4, 0, + "image.non_max_suppression: The last dimension of boxes array should " + "be 4, but %i is given", + boxes->sizeAt(1)); + REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), + 0, + "image.non_max_suppression: The rank of scales array should be " + "1, but %i is given", + boxes->rankOf()); + REQUIRE_TRUE(overlayThreshold >= 0. && overlayThreshold <= 1., 0, + "image.non_max_suppression_v3: The overlay threashold should be " + "in [0, 1], but %lf given.", + overlayThreshold); + REQUIRE_TRUE(boxes->dataType() == scales->dataType(), 0, + "image.non_max_suppression_v3: Boxes and scores inputs should " + "have the same data type, but %s and %s " + "were given.", + DataTypeUtils::asString(boxes->dataType()).c_str(), + DataTypeUtils::asString(scales->dataType()).c_str()); + + helpers::nonMaxSuppressionV3(block.launchContext(), boxes, scales, + maxOutputSize, overlayThreshold, scoreThreshold, + output); + return Status::OK(); +} - DECLARE_SHAPE_FN(non_max_suppression_v3) { - auto in = inputShape->at(0); - int outRank = shape::rank(in); - - - int maxOutputSize; - if (block.width() > 2) - maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.numI() == 1) - maxOutputSize = INT_ARG(0); - else - REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); - auto boxes = INPUT_VARIABLE(0); - auto scales = INPUT_VARIABLE(1); - - double overlayThreshold = 0.5; - double scoreThreshold = - DataTypeUtils::infOrMax(); - - if (block.width() > 3) { - overlayThreshold = INPUT_VARIABLE(3)->e(0); - } - else if (block.numT() > 0) { - overlayThreshold = T_ARG(0); - } - - if (block.width() > 4) { - scoreThreshold = INPUT_VARIABLE(4)->e(0); - } - else if (block.numT() > 1) { - scoreThreshold = T_ARG(1); - } - - auto len = maxOutputSize; - if (len > 0) - len = helpers::nonMaxSuppressionV3(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, scoreThreshold, nullptr); - - auto outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(len, DataType::INT32); - - return SHAPELIST(outputShape); - } +DECLARE_SHAPE_FN(non_max_suppression_v3) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + + int maxOutputSize; + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.numI() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, + "image.non_max_suppression: Max output size argument cannot " + "be retrieved."); + auto boxes = INPUT_VARIABLE(0); + auto scales = INPUT_VARIABLE(1); + + double overlayThreshold = 0.5; + double scoreThreshold = -DataTypeUtils::infOrMax(); + + if (block.width() > 3) { + overlayThreshold = INPUT_VARIABLE(3)->e(0); + } else if (block.numT() > 0) { + overlayThreshold = T_ARG(0); + } + + if (block.width() > 4) { + scoreThreshold = INPUT_VARIABLE(4)->e(0); + } else if (block.numT() > 1) { + scoreThreshold = T_ARG(1); + } + + auto len = maxOutputSize; + if (len > 0) + len = helpers::nonMaxSuppressionV3(block.launchContext(), boxes, scales, + maxOutputSize, overlayThreshold, + scoreThreshold, nullptr); + + auto outputShape = + ConstantShapeHelper::getInstance()->vectorShapeInfo(len, DataType::INT32); + + return SHAPELIST(outputShape); +} #endif - } -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp index d88289ffe31f..eab68b7881cd 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp @@ -24,69 +24,84 @@ #if NOT_EXCLUDED(OP_image_non_max_suppression_overlaps) namespace sd { - namespace ops { - CUSTOM_OP_IMPL(non_max_suppression_overlaps, 2, 1, false, 0, 0) { - auto boxes = INPUT_VARIABLE(0); - auto scales = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - int maxOutputSize; // = INT_ARG(0); - if (block.width() > 2) - maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.numI() == 1) - maxOutputSize = INT_ARG(0); - else - REQUIRE_TRUE(false, 0, "image.non_max_suppression_overlaps: Max output size argument cannot be retrieved."); - REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression_overlaps: The rank of boxes array should be 2, but %i is given", boxes->rankOf()); - REQUIRE_TRUE(boxes->sizeAt(0) == boxes->sizeAt(1), 0, "image.non_max_suppression_overlaps: The boxes array should be square, but {%lld, %lld} is given", boxes->sizeAt(0), boxes->sizeAt(1)); - REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression_overlaps: The rank of scales array should be 1, but %i is given", boxes->rankOf()); +namespace ops { +CUSTOM_OP_IMPL(non_max_suppression_overlaps, 2, 1, false, 0, 0) { + auto boxes = INPUT_VARIABLE(0); + auto scales = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + int maxOutputSize; // = INT_ARG(0); + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.numI() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, + "image.non_max_suppression_overlaps: Max output size argument " + "cannot be retrieved."); + REQUIRE_TRUE(boxes->rankOf() == 2, 0, + "image.non_max_suppression_overlaps: The rank of boxes array " + "should be 2, but %i is given", + boxes->rankOf()); + REQUIRE_TRUE(boxes->sizeAt(0) == boxes->sizeAt(1), 0, + "image.non_max_suppression_overlaps: The boxes array should be " + "square, but {%lld, %lld} is given", + boxes->sizeAt(0), boxes->sizeAt(1)); + REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), + 0, + "image.non_max_suppression_overlaps: The rank of scales array " + "should be 1, but %i is given", + boxes->rankOf()); -// if (scales->lengthOf() < maxOutputSize) -// maxOutputSize = scales->lengthOf(); - double overlapThreshold = 0.5; - double scoreThreshold = -DataTypeUtils::infOrMax(); - if (block.numT() > 0) - overlapThreshold = T_ARG(0); - if (block.numT() > 1) - scoreThreshold = T_ARG(1); + // if (scales->lengthOf() < maxOutputSize) + // maxOutputSize = scales->lengthOf(); + double overlapThreshold = 0.5; + double scoreThreshold = -DataTypeUtils::infOrMax(); + if (block.numT() > 0) overlapThreshold = T_ARG(0); + if (block.numT() > 1) scoreThreshold = T_ARG(1); - // TODO: refactor helpers to multithreaded facility - helpers::nonMaxSuppressionGeneric(block.launchContext(), boxes, scales, maxOutputSize, overlapThreshold, - scoreThreshold, output); - return Status::OK(); - } - - DECLARE_SHAPE_FN(non_max_suppression_overlaps) { - auto in = inputShape->at(0); - int outRank = shape::rank(in); + // TODO: refactor helpers to multithreaded facility + helpers::nonMaxSuppressionGeneric(block.launchContext(), boxes, scales, + maxOutputSize, overlapThreshold, + scoreThreshold, output); + return Status::OK(); +} - int maxOutputSize; - if (block.width() > 2) - maxOutputSize = INPUT_VARIABLE(2)->e(0); - else if (block.numI() == 1) - maxOutputSize = INT_ARG(0); - else - REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); +DECLARE_SHAPE_FN(non_max_suppression_overlaps) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); - double overlapThreshold = 0.5; - double scoreThreshold = 0.; + int maxOutputSize; + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.numI() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, + "image.non_max_suppression: Max output size argument cannot " + "be retrieved."); - Nd4jLong boxSize = helpers::nonMaxSuppressionGeneric(block.launchContext(), INPUT_VARIABLE(0), - INPUT_VARIABLE(1), maxOutputSize, overlapThreshold, scoreThreshold, nullptr); //shape::sizeAt(in, 0); - if (boxSize < maxOutputSize) - maxOutputSize = boxSize; + double overlapThreshold = 0.5; + double scoreThreshold = 0.; - auto outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxOutputSize, DataType::INT32); + Nd4jLong boxSize = helpers::nonMaxSuppressionGeneric( + block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), + maxOutputSize, overlapThreshold, scoreThreshold, + nullptr); // shape::sizeAt(in, 0); + if (boxSize < maxOutputSize) maxOutputSize = boxSize; - return SHAPELIST(outputShape); - } - DECLARE_TYPES(non_max_suppression_overlaps) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_INDICES}); - } + auto outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + maxOutputSize, DataType::INT32); - } + return SHAPELIST(outputShape); } +DECLARE_TYPES(non_max_suppression_overlaps) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_INDICES}); +} + +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp index 7262d7fa71a1..f8664f645abd 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/normalize_moments.cpp @@ -14,9 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // Created by george@skymind.io on 26.01.2018. - // +// +// Created by george@skymind.io on 26.01.2018. +// #include #if NOT_EXCLUDED(OP_normalize_moments) @@ -24,62 +24,63 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(normalize_moments, 3, 2, false, 1, 0) { - auto counts = INPUT_VARIABLE(0); - auto means = INPUT_VARIABLE(1); - auto variances = INPUT_VARIABLE(2); +namespace ops { +CUSTOM_OP_IMPL(normalize_moments, 3, 2, false, 1, 0) { + auto counts = INPUT_VARIABLE(0); + auto means = INPUT_VARIABLE(1); + auto variances = INPUT_VARIABLE(2); - auto resMeans = OUTPUT_VARIABLE(0); - auto resVariances = OUTPUT_VARIABLE(1); + auto resMeans = OUTPUT_VARIABLE(0); + auto resVariances = OUTPUT_VARIABLE(1); - // FIXME: double? - NDArray shift = NDArrayFactory::create(0., block.launchContext()); + // FIXME: double? + NDArray shift = NDArrayFactory::create(0., block.launchContext()); - if (block.numT() > 0) { - shift.assign(T_ARG(0)); - } + if (block.numT() > 0) { + shift.assign(T_ARG(0)); + } - means->applyScalarArr(scalar::Divide, *counts, *resMeans); + means->applyScalarArr(scalar::Divide, *counts, *resMeans); - NDArray squareMeans = resMeans->dup('c'); - NDArray tempVariances = resVariances->dup('c'); + NDArray squareMeans = resMeans->dup('c'); + NDArray tempVariances = resVariances->dup('c'); - squareMeans.applyTransform(transform::Square, squareMeans, nullptr); - variances->applyScalarArr(scalar::Divide, *counts, tempVariances); - // tempVariances.printIndexedBuffer("varianced divided by count"); - tempVariances.applyPairwiseTransform(pairwise::Subtract, squareMeans, *resVariances); + squareMeans.applyTransform(transform::Square, squareMeans, nullptr); + variances->applyScalarArr(scalar::Divide, *counts, tempVariances); + // tempVariances.printIndexedBuffer("varianced divided by count"); + tempVariances.applyPairwiseTransform(pairwise::Subtract, squareMeans, + *resVariances); - if (shift.e(0) != 0) { - resMeans->applyScalarArr(scalar::Add, shift, *resMeans); - } + if (shift.e(0) != 0) { + resMeans->applyScalarArr(scalar::Add, shift, *resMeans); + } - return Status::OK(); - } - - DECLARE_SHAPE_FN(normalize_moments) { - auto in = inputShape->at(1); + return Status::OK(); +} - Nd4jLong* meanShape = nullptr; - Nd4jLong* varianceShape = nullptr; +DECLARE_SHAPE_FN(normalize_moments) { + auto in = inputShape->at(1); - COPY_SHAPE_EX(in, meanShape, block.workspace()); - COPY_SHAPE_EX(in, varianceShape, block.workspace()); + Nd4jLong* meanShape = nullptr; + Nd4jLong* varianceShape = nullptr; - auto shapeList = SHAPELIST(); - shapeList->push_back(CONSTANT(meanShape)); - shapeList->push_back(CONSTANT(varianceShape)); + COPY_SHAPE_EX(in, meanShape, block.workspace()); + COPY_SHAPE_EX(in, varianceShape, block.workspace()); - return shapeList; - } + auto shapeList = SHAPELIST(); + shapeList->push_back(CONSTANT(meanShape)); + shapeList->push_back(CONSTANT(varianceShape)); - DECLARE_TYPES(normalize_moments) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ ALL_FLOATS }); - } - } + return shapeList; +} +DECLARE_TYPES(normalize_moments) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops + +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp index 3c0524eba164..8c7f2814a025 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/nth_element.cpp @@ -22,58 +22,63 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(nth_element, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto n = INPUT_VARIABLE(1); - bool reverse = false; - if (block.numI() > 0) - reverse = (bool)INT_ARG(0); +namespace ops { +CUSTOM_OP_IMPL(nth_element, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto n = INPUT_VARIABLE(1); + bool reverse = false; + if (block.numI() > 0) reverse = (bool)INT_ARG(0); - auto output = OUTPUT_VARIABLE(0); - Nd4jLong lastDim = input->sizeAt(-1); - int nVal = n->e(0); - REQUIRE_TRUE(nVal < lastDim && nVal >= 0, 0, "nth_element: n should be non-negative and less than last dimension size (%lld), but %i was given.", lastDim, n); - REQUIRE_TRUE(input->rankOf() > 0, 0, "nth_element: The rank of input array should be at least 1, but %i is given", input->rankOf()); // - if (output->lengthOf() == input->lengthOf()) - output->assign(input); - else { -// if (!input->isVector() && reverse) -// n->assign(lastDim - n->e(0) - 1); - helpers::nthElementFunctor(block.launchContext(), input, nVal, output, reverse); - } - return ND4J_STATUS_OK; - } + auto output = OUTPUT_VARIABLE(0); + Nd4jLong lastDim = input->sizeAt(-1); + int nVal = n->e(0); + REQUIRE_TRUE(nVal < lastDim && nVal >= 0, 0, + "nth_element: n should be non-negative and less than last " + "dimension size (%lld), but %i was given.", + lastDim, n); + REQUIRE_TRUE(input->rankOf() > 0, 0, + "nth_element: The rank of input array should be at least 1, but " + "%i is given", + input->rankOf()); // + if (output->lengthOf() == input->lengthOf()) + output->assign(input); + else { + // if (!input->isVector() && reverse) + // n->assign(lastDim - n->e(0) - 1); + helpers::nthElementFunctor(block.launchContext(), input, nVal, output, + reverse); + } + return ND4J_STATUS_OK; +} - DECLARE_SHAPE_FN(nth_element) { +DECLARE_SHAPE_FN(nth_element) { + auto in = inputShape->at(0); + int outRank = shape::rank(in) - 1; + Nd4jLong const* outShape = nullptr; + if (outRank > 1) { + Nd4jLong* outputShape = nullptr; + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + outputShape[0] = outRank; + for (Nd4jLong e = 0; e < outRank; e++) outputShape[e + 1] = in[e + 1]; - auto in = inputShape->at(0); - int outRank = shape::rank(in) - 1; - Nd4jLong const* outShape = nullptr; - if (outRank > 1) { - Nd4jLong *outputShape = nullptr; - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - outputShape[0] = outRank; - for (Nd4jLong e = 0; e < outRank; e++) - outputShape[e + 1] = in[e + 1]; + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + outShape = CONSTANT(outputShape); + } else if (outRank == 1) { + outShape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + shape::sizeAt(in, 0), ArrayOptions::dataType(in)); + } else { + // outputShape = shape::createScalarShapeInfo(); + outShape = ConstantShapeHelper::getInstance()->scalarShapeInfo( + ArrayOptions::dataType(in)); + } + return SHAPELIST(outShape); +} +DECLARE_TYPES(nth_element) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); +} - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - outShape = CONSTANT(outputShape); - } - else if (outRank == 1) { - outShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(in, 0), ArrayOptions::dataType(in)); - } - else { - //outputShape = shape::createScalarShapeInfo(); - outShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(in)); - } - return SHAPELIST(outShape); - } - DECLARE_TYPES(nth_element) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - - } -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp index 6349b84febd4..6d969c0cbc20 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/onehot.cpp @@ -21,96 +21,89 @@ #include #if NOT_EXCLUDED(OP_onehot) -#include #include +#include #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(onehot, 1, 1, false, -2, -2) { - auto input = INPUT_VARIABLE(0); - - // FIXME: double? - double on(1.0f); // T_ARG(0); - double off(0.0f); //T_ARG(1); +namespace ops { +CUSTOM_OP_IMPL(onehot, 1, 1, false, -2, -2) { + auto input = INPUT_VARIABLE(0); - auto depth = -1; //INT_ARG(0); - auto axis = -1; //INT_ARG(1); + // FIXME: double? + double on(1.0f); // T_ARG(0); + double off(0.0f); // T_ARG(1); - if (block.numI() > 0) - axis = INT_ARG(0); + auto depth = -1; // INT_ARG(0); + auto axis = -1; // INT_ARG(1); - if (block.numI() > 1) { - depth = INT_ARG(1); - } else if (block.width() > 1) { - depth = INPUT_VARIABLE(1)->e(0); - } + if (block.numI() > 0) axis = INT_ARG(0); - REQUIRE_TRUE(depth > 0, 0, "OneHot: depth must be positive value"); + if (block.numI() > 1) { + depth = INT_ARG(1); + } else if (block.width() > 1) { + depth = INPUT_VARIABLE(1)->e(0); + } + REQUIRE_TRUE(depth > 0, 0, "OneHot: depth must be positive value"); - if (block.width() > 2) { - on = INPUT_VARIABLE(2)->e(0); + if (block.width() > 2) { + on = INPUT_VARIABLE(2)->e(0); - if (block.width() > 3) - off = INPUT_VARIABLE(3)->e(0); - } else if (block.numT() > 0) { - on = T_ARG(0); + if (block.width() > 3) off = INPUT_VARIABLE(3)->e(0); + } else if (block.numT() > 0) { + on = T_ARG(0); - if (block.numT() > 1) - off = T_ARG(1); - } + if (block.numT() > 1) off = T_ARG(1); + } - auto output = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - if (axis < 0) - axis = output->rankOf() + axis; + if (axis < 0) axis = output->rankOf() + axis; - helpers::onehot(block.launchContext(), input, output, axis, depth, on, off); + helpers::onehot(block.launchContext(), input, output, axis, depth, on, off); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(onehot) { - auto inShape = inputShape->at(0); +DECLARE_SHAPE_FN(onehot) { + auto inShape = inputShape->at(0); - sd::DataType dtype = block.numD() > 0 ? D_ARG(0) : sd::DataType::FLOAT32; + sd::DataType dtype = block.numD() > 0 ? D_ARG(0) : sd::DataType::FLOAT32; - int depth = -1; - Nd4jLong axis = -1; + int depth = -1; + Nd4jLong axis = -1; - if (block.numI() > 0) - axis = INT_ARG(0); + if (block.numI() > 0) axis = INT_ARG(0); - if (block.numI() > 1) { - depth = INT_ARG(1); - } else if (block.width() > 1) { - depth = INPUT_VARIABLE(1)->e(0); - } + if (block.numI() > 1) { + depth = INT_ARG(1); + } else if (block.width() > 1) { + depth = INPUT_VARIABLE(1)->e(0); + } - REQUIRE_TRUE(depth > 0, 0, "OneHot: depth must be positive value"); + REQUIRE_TRUE(depth > 0, 0, "OneHot: depth must be positive value"); - int rank = shape::rank(inShape); + int rank = shape::rank(inShape); - if (axis < 0) - axis = rank + 1 + axis; + if (axis < 0) axis = rank + 1 + axis; - std::vector shape; - for (int e = 0; e < rank; e++) - shape.push_back(shape::shapeOf(inShape)[e]); + std::vector shape; + for (int e = 0; e < rank; e++) shape.push_back(shape::shapeOf(inShape)[e]); - shape.insert(shape.begin() + axis, depth); - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', rank + 1, shape.data()); + shape.insert(shape.begin() + axis, depth); + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + dtype, 'c', rank + 1, shape.data()); - return SHAPELIST(newShape); - } + return SHAPELIST(newShape); +} - DECLARE_TYPES(onehot) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); - } - } +DECLARE_TYPES(onehot) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp index 1e1058397f08..ed3ef117832d 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/rint.cpp @@ -24,22 +24,22 @@ #include namespace sd { - namespace ops { - OP_IMPL(rint, 1, 1, true) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - x->applyTransform(transform::Rint, *z); - - return Status::OK(); - } - } - - DECLARE_TYPES(rint) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +namespace ops { +OP_IMPL(rint, 1, 1, true) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + x->applyTransform(transform::Rint, *z); + + return Status::OK(); +} +} // namespace ops + +DECLARE_TYPES(rint) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp index 18e271409359..84ef5988c763 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp @@ -22,89 +22,98 @@ #if NOT_EXCLUDED(OP_roll) #include -#include #include +#include namespace sd { namespace ops { - CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 0) { - auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - int inputLen = input->lengthOf(); +CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 0) { + auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + int inputLen = input->lengthOf(); - bool shiftIsLinear = block.width() == 1; - std::vector axes; - std::vector shifts; - if (block.width() > 1) { - REQUIRE_TRUE(block.width() == 3, 0, "roll: 3 arguments required for roll - input, shifts and axes. But %i given.", block.width()); - auto axesI = INPUT_VARIABLE(2); - auto shiftsI = INPUT_VARIABLE(1); - REQUIRE_TRUE(axesI->rankOf() == shiftsI->rankOf(), 0, "roll: shifts and axes should be the same rank, but %i and %i given.", (int)shiftsI->rankOf(), (int)axesI->rankOf()); - REQUIRE_TRUE(axesI->lengthOf() == shiftsI->lengthOf(), 0, "roll: shifts and axes should be the same length, but %i and %i given.", (int)shiftsI->lengthOf(), (int)axesI->lengthOf()); - helpers::adjustAxis(axesI->lengthOf(), axesI, axes ); - shifts.resize(shiftsI->lengthOf()); - for (Nd4jLong i = 0; i < shiftsI->lengthOf(); i++) { - auto shift = shiftsI->e(i); - if (shift < 0) { - shift -= input->sizeAt(i) * (shift / inputLen - 1); - } - else { - shift %= input->sizeAt(i); - } - shifts[i] = shift; - } - - } - else { - int shift = INT_ARG(0); - if (shift < 0) { - // convert shift to positive value between 1 and inputLen - 1 - shift -= inputLen * (shift / inputLen - 1); - } - else - // cut shift to value between 1 and inputLen - 1 - shift %= inputLen; - axes.resize(block.numI() - 1); - if (axes.size()) - shifts.resize(axes.size());//emplace_back(shift); - else - shifts.push_back(shift); + bool shiftIsLinear = block.width() == 1; + std::vector axes; + std::vector shifts; + if (block.width() > 1) { + REQUIRE_TRUE(block.width() == 3, 0, + "roll: 3 arguments required for roll - input, shifts and " + "axes. But %i given.", + block.width()); + auto axesI = INPUT_VARIABLE(2); + auto shiftsI = INPUT_VARIABLE(1); + REQUIRE_TRUE( + axesI->rankOf() == shiftsI->rankOf(), 0, + "roll: shifts and axes should be the same rank, but %i and %i given.", + (int)shiftsI->rankOf(), (int)axesI->rankOf()); + REQUIRE_TRUE( + axesI->lengthOf() == shiftsI->lengthOf(), 0, + "roll: shifts and axes should be the same length, but %i and %i given.", + (int)shiftsI->lengthOf(), (int)axesI->lengthOf()); + helpers::adjustAxis(axesI->lengthOf(), axesI, axes); + shifts.resize(shiftsI->lengthOf()); + for (Nd4jLong i = 0; i < shiftsI->lengthOf(); i++) { + auto shift = shiftsI->e(i); + if (shift < 0) { + shift -= input->sizeAt(i) * (shift / inputLen - 1); + } else { + shift %= input->sizeAt(i); + } + shifts[i] = shift; + } - for (auto& s: shifts) - s = shift; + } else { + int shift = INT_ARG(0); + if (shift < 0) { + // convert shift to positive value between 1 and inputLen - 1 + shift -= inputLen * (shift / inputLen - 1); + } else + // cut shift to value between 1 and inputLen - 1 + shift %= inputLen; + axes.resize(block.numI() - 1); + if (axes.size()) + shifts.resize(axes.size()); // emplace_back(shift); + else + shifts.push_back(shift); - for (unsigned e = 0; e < axes.size(); ++e) { - int axis = INT_ARG(e + 1); - REQUIRE_TRUE(axis < input->rankOf() && axis >= -input->rankOf(), 0, "roll: axe value should be between -%i and %i, but %i was given.", - input->rankOf(), input->rankOf() - 1, axis); - axes[e] = (axis < 0? (input->rankOf() + axis) : axis); - } - } + for (auto& s : shifts) s = shift; - if (block.isInplace()) output = input; + for (unsigned e = 0; e < axes.size(); ++e) { + int axis = INT_ARG(e + 1); + REQUIRE_TRUE( + axis < input->rankOf() && axis >= -input->rankOf(), 0, + "roll: axe value should be between -%i and %i, but %i was given.", + input->rankOf(), input->rankOf() - 1, axis); + axes[e] = (axis < 0 ? (input->rankOf() + axis) : axis); + } + } - shiftIsLinear = (axes.size() == 0) || (input->rankOf() == 1); + if (block.isInplace()) output = input; - if (shiftIsLinear) { - helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace()); - } - else { - helpers::rollFunctorFull(block.launchContext(), input, output, shifts, axes, block.isInplace()); - } + shiftIsLinear = (axes.size() == 0) || (input->rankOf() == 1); - return Status::OK(); - } + if (shiftIsLinear) { + helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], + block.isInplace()); + } else { + helpers::rollFunctorFull(block.launchContext(), input, output, shifts, axes, + block.isInplace()); + } - DECLARE_TYPES(roll) { - getOpDescriptor() - ->setAllowedInputTypes(0,sd::DataType::ANY) - ->setAllowedInputTypes(1,sd::DataType::INT32) // TODO: all ints in future - ->setAllowedInputTypes(2,sd::DataType::INT32) - ->setAllowedOutputTypes(sd::DataType::ANY) - ->setSameMode(true); - } + return Status::OK(); } + +DECLARE_TYPES(roll) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, + sd::DataType::INT32) // TODO: all ints in future + ->setAllowedInputTypes(2, sd::DataType::INT32) + ->setAllowedOutputTypes(sd::DataType::ANY) + ->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp index 28d099f31f9b..4d65859fc49d 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_max.cpp @@ -22,84 +22,96 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(segment_max, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "segment_max: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "segment_max: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); - - - auto expected = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - auto wrong = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - - REQUIRE_TRUE(helpers::segmentIndicesValidate(block.launchContext(), idxSegments, expected, wrong), 0, "segment_max: segment indices should be arranged, but %2.1f > %2.1f", expected.e(0), wrong.e(0)); - - segmentedOutput->nullify(); - helpers::segmentMaxFunctor(block.launchContext(), input, idxSegments, segmentedOutput); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(segment_max) { - auto idxVector = INPUT_VARIABLE(1); - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - idxVector->syncToHost(); - int val = (*idxVector).e(idxVector->lengthOf() - 1); - - int numOfClasses = val + 1; - - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - - DECLARE_TYPES(segment_max) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) - ->setSameMode(false); - } - CUSTOM_OP_IMPL(segment_max_bp, 3, 2, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto gradOut = INPUT_VARIABLE(2); - auto output = OUTPUT_NULLIFIED(0); - auto outIndices = OUTPUT_NULLIFIED(1); - outIndices->assign(indices); - return helpers::segmentMaxFunctorBP(block.launchContext(), input, indices, gradOut, output); - } - DECLARE_SHAPE_FN(segment_max_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); - - } - DECLARE_TYPES(segment_max_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setSameMode(true); - } - - } +namespace ops { +CUSTOM_OP_IMPL(segment_max, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "segment_max: segment indexes array should be a vector, but it " + "rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "segment_max: segment indexes array length should be equal to " + "the input first dimension, but %i != %i.", + idxSegments->lengthOf(), input->sizeAt(0)); + + auto expected = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + auto wrong = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + + REQUIRE_TRUE( + helpers::segmentIndicesValidate(block.launchContext(), idxSegments, + expected, wrong), + 0, "segment_max: segment indices should be arranged, but %2.1f > %2.1f", + expected.e(0), wrong.e(0)); + + segmentedOutput->nullify(); + helpers::segmentMaxFunctor(block.launchContext(), input, idxSegments, + segmentedOutput); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(segment_max) { + auto idxVector = INPUT_VARIABLE(1); + + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + idxVector->syncToHost(); + int val = (*idxVector).e(idxVector->lengthOf() - 1); + + int numOfClasses = val + 1; + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + return SHAPELIST(CONSTANT(outputShape)); } + +DECLARE_TYPES(segment_max) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) + ->setSameMode(false); +} +CUSTOM_OP_IMPL(segment_max_bp, 3, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto gradOut = INPUT_VARIABLE(2); + auto output = OUTPUT_NULLIFIED(0); + auto outIndices = OUTPUT_NULLIFIED(1); + outIndices->assign(indices); + return helpers::segmentMaxFunctorBP(block.launchContext(), input, indices, + gradOut, output); +} +DECLARE_SHAPE_FN(segment_max_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); +} +DECLARE_TYPES(segment_max_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setSameMode(true); +} + +} // namespace ops + +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp index 3b2c16551337..2e7ccbad8fd6 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_mean.cpp @@ -22,82 +22,95 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(segment_mean, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "segment_mean: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "segment_mean: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); - - auto expected = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - auto wrong = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - - REQUIRE_TRUE(helpers::segmentIndicesValidate(block.launchContext(), idxSegments, expected, wrong), 0, "segment_mean: segment indices should be arranged, but %2.1f > %2.1f", expected.e(0), wrong.e(0)); - - segmentedOutput->nullify(); - helpers::segmentMeanFunctor(block.launchContext(), input, idxSegments, segmentedOutput); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(segment_mean) { - auto idxVector = INPUT_VARIABLE(1); - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - int val = (*idxVector).e(idxVector->lengthOf() - 1); - - int numOfClasses = val + 1; - - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - - DECLARE_TYPES(segment_mean) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(false); - } - - - CUSTOM_OP_IMPL(segment_mean_bp, 3, 2, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto gradOut = INPUT_VARIABLE(2); - auto output = OUTPUT_NULLIFIED(0); - auto outIndices = OUTPUT_NULLIFIED(1); - outIndices->assign(indices); - return helpers::segmentMeanFunctorBP(block.launchContext(), input, indices, gradOut, output); - } - DECLARE_SHAPE_FN(segment_mean_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); -// return SHAPELIST(in, inIdx); - } - DECLARE_TYPES(segment_mean_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setSameMode(false); - } - } +namespace ops { +CUSTOM_OP_IMPL(segment_mean, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "segment_mean: segment indexes array should be a vector, but it " + "rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "segment_mean: segment indexes array length should be equal to " + "the input first dimension, but %i != %i.", + idxSegments->lengthOf(), input->sizeAt(0)); + + auto expected = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + auto wrong = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + + REQUIRE_TRUE( + helpers::segmentIndicesValidate(block.launchContext(), idxSegments, + expected, wrong), + 0, "segment_mean: segment indices should be arranged, but %2.1f > %2.1f", + expected.e(0), wrong.e(0)); + + segmentedOutput->nullify(); + helpers::segmentMeanFunctor(block.launchContext(), input, idxSegments, + segmentedOutput); + + return Status::OK(); } + +DECLARE_SHAPE_FN(segment_mean) { + auto idxVector = INPUT_VARIABLE(1); + + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + int val = (*idxVector).e(idxVector->lengthOf() - 1); + + int numOfClasses = val + 1; + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} + +DECLARE_TYPES(segment_mean) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(false); +} + +CUSTOM_OP_IMPL(segment_mean_bp, 3, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto gradOut = INPUT_VARIABLE(2); + auto output = OUTPUT_NULLIFIED(0); + auto outIndices = OUTPUT_NULLIFIED(1); + outIndices->assign(indices); + return helpers::segmentMeanFunctorBP(block.launchContext(), input, indices, + gradOut, output); +} +DECLARE_SHAPE_FN(segment_mean_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); + // return SHAPELIST(in, inIdx); +} +DECLARE_TYPES(segment_mean_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setSameMode(false); +} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp index b0f93456d285..d78eedc7f209 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_min.cpp @@ -22,80 +22,94 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(segment_min, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "segment_min: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "segment_min: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); - - auto expected = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - auto wrong = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - - REQUIRE_TRUE(helpers::segmentIndicesValidate(block.launchContext(), idxSegments, expected, wrong), 0, "segment_min: segment indices should be arranged, but %2.1f > %2.1f", expected.e(0), wrong.e(0)); - - segmentedOutput->nullify(); - helpers::segmentMinFunctor(block.launchContext(), input, idxSegments, segmentedOutput); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(segment_min) { - auto idxVector = INPUT_VARIABLE(1); - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - int val = (*idxVector).e(idxVector->lengthOf() - 1); - - int numOfClasses = val + 1; - - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - CUSTOM_OP_IMPL(segment_min_bp, 3, 2, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto gradOut = INPUT_VARIABLE(2); - auto output = OUTPUT_NULLIFIED(0); - auto outIndices = OUTPUT_NULLIFIED(1); - outIndices->assign(indices); - return helpers::segmentMinFunctorBP(block.launchContext(), input, indices, gradOut, output); - } - DECLARE_SHAPE_FN(segment_min_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); -// return SHAPELIST(in, inIdx); - } - - DECLARE_TYPES(segment_min) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) - ->setSameMode(false); - } - DECLARE_TYPES(segment_min_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setSameMode(true); - } - } +namespace ops { +CUSTOM_OP_IMPL(segment_min, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "segment_min: segment indexes array should be a vector, but it " + "rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "segment_min: segment indexes array length should be equal to " + "the input first dimension, but %i != %i.", + idxSegments->lengthOf(), input->sizeAt(0)); + + auto expected = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + auto wrong = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + + REQUIRE_TRUE( + helpers::segmentIndicesValidate(block.launchContext(), idxSegments, + expected, wrong), + 0, "segment_min: segment indices should be arranged, but %2.1f > %2.1f", + expected.e(0), wrong.e(0)); + + segmentedOutput->nullify(); + helpers::segmentMinFunctor(block.launchContext(), input, idxSegments, + segmentedOutput); + + return Status::OK(); } + +DECLARE_SHAPE_FN(segment_min) { + auto idxVector = INPUT_VARIABLE(1); + + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + int val = (*idxVector).e(idxVector->lengthOf() - 1); + + int numOfClasses = val + 1; + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} +CUSTOM_OP_IMPL(segment_min_bp, 3, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto gradOut = INPUT_VARIABLE(2); + auto output = OUTPUT_NULLIFIED(0); + auto outIndices = OUTPUT_NULLIFIED(1); + outIndices->assign(indices); + return helpers::segmentMinFunctorBP(block.launchContext(), input, indices, + gradOut, output); +} +DECLARE_SHAPE_FN(segment_min_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); + // return SHAPELIST(in, inIdx); +} + +DECLARE_TYPES(segment_min) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) + ->setSameMode(false); +} +DECLARE_TYPES(segment_min_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setSameMode(true); +} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp index ad64664b289a..9cbfea381f14 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_prod.cpp @@ -22,85 +22,98 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(segment_prod, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "segment_prod: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "segment_prod: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); - - auto expected = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - auto wrong = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - - REQUIRE_TRUE(helpers::segmentIndicesValidate(block.launchContext(), idxSegments, expected, wrong), 0, "segment_prod: segment indices should be arranged, but %2.1f > %2.1f", expected.e(0), wrong.e(0)); - - segmentedOutput->nullify(); - helpers::segmentProdFunctor(block.launchContext(), input, idxSegments, segmentedOutput); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(segment_prod) { - auto idxVector = INPUT_VARIABLE(1); - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - int val = (*idxVector).e(idxVector->lengthOf() - 1); - - int numOfClasses = val + 1; - - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - - CUSTOM_OP_IMPL(segment_prod_bp, 3, 2, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto gradOut = INPUT_VARIABLE(2); - auto output = OUTPUT_NULLIFIED(0); - auto outIndices = OUTPUT_NULLIFIED(1); - outIndices->assign(indices); - helpers::segmentProdFunctorBP(block.launchContext(), input, indices, gradOut, output); - - return Status::OK(); - } - - DECLARE_TYPES(segment_prod) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) - ->setSameMode(false); - } - - - DECLARE_SHAPE_FN(segment_prod_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); - } - - DECLARE_TYPES(segment_prod_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setSameMode(false); - } - } +namespace ops { +CUSTOM_OP_IMPL(segment_prod, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "segment_prod: segment indexes array should be a vector, but it " + "rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "segment_prod: segment indexes array length should be equal to " + "the input first dimension, but %i != %i.", + idxSegments->lengthOf(), input->sizeAt(0)); + + auto expected = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + auto wrong = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + + REQUIRE_TRUE( + helpers::segmentIndicesValidate(block.launchContext(), idxSegments, + expected, wrong), + 0, "segment_prod: segment indices should be arranged, but %2.1f > %2.1f", + expected.e(0), wrong.e(0)); + + segmentedOutput->nullify(); + helpers::segmentProdFunctor(block.launchContext(), input, idxSegments, + segmentedOutput); + + return Status::OK(); } + +DECLARE_SHAPE_FN(segment_prod) { + auto idxVector = INPUT_VARIABLE(1); + + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + int val = (*idxVector).e(idxVector->lengthOf() - 1); + + int numOfClasses = val + 1; + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} + +CUSTOM_OP_IMPL(segment_prod_bp, 3, 2, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto gradOut = INPUT_VARIABLE(2); + auto output = OUTPUT_NULLIFIED(0); + auto outIndices = OUTPUT_NULLIFIED(1); + outIndices->assign(indices); + helpers::segmentProdFunctorBP(block.launchContext(), input, indices, gradOut, + output); + + return Status::OK(); +} + +DECLARE_TYPES(segment_prod) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) + ->setSameMode(false); +} + +DECLARE_SHAPE_FN(segment_prod_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); +} + +DECLARE_TYPES(segment_prod_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setSameMode(false); +} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp index 3289fa4cbb0f..21b80d4e6e6c 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/segment_sum.cpp @@ -22,75 +22,87 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(segment_sum, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "segment_sum: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "segment_sum: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); - - auto expected = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - auto wrong = NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); - - REQUIRE_TRUE(helpers::segmentIndicesValidate(block.launchContext(), idxSegments, expected, wrong), 0, "segment_sum: segment indices should be arranged, but %2.1f > %2.1f", expected.e(0), wrong.e(0)); - - segmentedOutput->nullify(); - helpers::segmentSumFunctor(block.launchContext(), input, idxSegments, segmentedOutput); - - return ND4J_STATUS_OK; - } - - DECLARE_SHAPE_FN(segment_sum) { - auto idxVector = INPUT_VARIABLE(1); - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - int val = (*idxVector).e(idxVector->lengthOf() - 1); - - int numOfClasses = static_cast(val) + 1; - - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - - CUSTOM_OP_IMPL(segment_sum_bp, 3, 2, false, 0, 0) { - - return helpers::segmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), OUTPUT_NULLIFIED(0)); - } - DECLARE_SHAPE_FN(segment_sum_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); -// return SHAPELIST(in, inIdx); - } - - DECLARE_TYPES(segment_sum) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - DECLARE_TYPES(segment_sum_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setSameMode(false); - } - } +namespace ops { +CUSTOM_OP_IMPL(segment_sum, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "segment_sum: segment indexes array should be a vector, but it " + "rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "segment_sum: segment indexes array length should be equal to " + "the input first dimension, but %i != %i.", + idxSegments->lengthOf(), input->sizeAt(0)); + + auto expected = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + auto wrong = + NDArrayFactory::create(input->dataType(), 0.f, block.launchContext()); + + REQUIRE_TRUE( + helpers::segmentIndicesValidate(block.launchContext(), idxSegments, + expected, wrong), + 0, "segment_sum: segment indices should be arranged, but %2.1f > %2.1f", + expected.e(0), wrong.e(0)); + + segmentedOutput->nullify(); + helpers::segmentSumFunctor(block.launchContext(), input, idxSegments, + segmentedOutput); + + return ND4J_STATUS_OK; +} + +DECLARE_SHAPE_FN(segment_sum) { + auto idxVector = INPUT_VARIABLE(1); + + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + int val = (*idxVector).e(idxVector->lengthOf() - 1); + + int numOfClasses = static_cast(val) + 1; + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + return SHAPELIST(CONSTANT(outputShape)); } + +CUSTOM_OP_IMPL(segment_sum_bp, 3, 2, false, 0, 0) { + return helpers::segmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), + INPUT_VARIABLE(1), INPUT_VARIABLE(2), + OUTPUT_NULLIFIED(0)); +} +DECLARE_SHAPE_FN(segment_sum_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); + // return SHAPELIST(in, inIdx); +} + +DECLARE_TYPES(segment_sum) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} +DECLARE_TYPES(segment_sum_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setSameMode(false); +} +} // namespace ops + +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp index 5ac400478905..0bce3fd0b06c 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp @@ -22,85 +22,78 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(sequence_mask, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_NULLIFIED(0); - const int inRank = input->rankOf(); +namespace ops { +CUSTOM_OP_IMPL(sequence_mask, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); + const int inRank = input->rankOf(); - //REQUIRE_TRUE(inRank >= 1, 0, "sequence_mask: input array must have rank >= 1, but %i given!", inRank); - Nd4jLong maxInd = input->argMax(); - float max = input->e(maxInd); - if (block.numI() > 0) { - maxInd = INT_ARG(0); - if (maxInd < max) - maxInd = static_cast(max); - } - else if (block.width() > 1) { - auto maxlen = INPUT_VARIABLE(1); - //REQUIRE_TRUE(maxlen->lengthOf() == 1, "sequence_mask: 2nd input (max length) should be a scalar array."); - float tmaxlen = maxlen->e(0); - if (tmaxlen > max) - maxInd = static_cast(tmaxlen); - } - else - maxInd = static_cast(max); + // REQUIRE_TRUE(inRank >= 1, 0, "sequence_mask: input array must have rank >= + // 1, but %i given!", inRank); + Nd4jLong maxInd = input->argMax(); + float max = input->e(maxInd); + if (block.numI() > 0) { + maxInd = INT_ARG(0); + if (maxInd < max) maxInd = static_cast(max); + } else if (block.width() > 1) { + auto maxlen = INPUT_VARIABLE(1); + // REQUIRE_TRUE(maxlen->lengthOf() == 1, "sequence_mask: 2nd input (max + // length) should be a scalar array."); + float tmaxlen = maxlen->e(0); + if (tmaxlen > max) maxInd = static_cast(tmaxlen); + } else + maxInd = static_cast(max); - helpers::sequenceMask(block.launchContext(), input, output, maxInd); + helpers::sequenceMask(block.launchContext(), input, output, maxInd); - return Status::OK(); - } - - DECLARE_SHAPE_FN(sequence_mask) { - - Nd4jLong* outShapeInfo = nullptr; - auto in = inputShape->at(0); - int outRank = shape::rank(in) + 1; - auto input = INPUT_VARIABLE(0); - auto dtype = DataType::BOOL; - auto argMaxInd = input->argMax(); - Nd4jLong max = input->e(argMaxInd); - Nd4jLong maxInd = max; + return Status::OK(); +} - if (block.numD() > 0) - dtype = D_ARG(0); +DECLARE_SHAPE_FN(sequence_mask) { + Nd4jLong* outShapeInfo = nullptr; + auto in = inputShape->at(0); + int outRank = shape::rank(in) + 1; + auto input = INPUT_VARIABLE(0); + auto dtype = DataType::BOOL; + auto argMaxInd = input->argMax(); + Nd4jLong max = input->e(argMaxInd); + Nd4jLong maxInd = max; - if (block.width() > 1) { - auto maxlen = INPUT_VARIABLE(1); - Nd4jLong tmaxlen = maxlen->e(0); - if (tmaxlen > max) - maxInd = static_cast(tmaxlen); - if (block.numI() > 0) { - dtype = (DataType) INT_ARG(0); - } - } - else { - if (block.numI() > 0) { - maxInd = INT_ARG(0); - } - if (maxInd < max) - maxInd = max; - if (block.numI() > 1) - dtype = (DataType)INT_ARG(1); // to work with legacy code - } + if (block.numD() > 0) dtype = D_ARG(0); - int lastDimension = maxInd; - ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - outShapeInfo[0] = outRank; - for(int i = 0; i < outRank - 1; ++i) - outShapeInfo[i + 1] = shape::sizeAt(in, i); - outShapeInfo[outRank] = lastDimension; + if (block.width() > 1) { + auto maxlen = INPUT_VARIABLE(1); + Nd4jLong tmaxlen = maxlen->e(0); + if (tmaxlen > max) maxInd = static_cast(tmaxlen); + if (block.numI() > 0) { + dtype = (DataType)INT_ARG(0); + } + } else { + if (block.numI() > 0) { + maxInd = INT_ARG(0); + } + if (maxInd < max) maxInd = max; + if (block.numI() > 1) + dtype = (DataType)INT_ARG(1); // to work with legacy code + } - ShapeUtils::updateStridesAndType(outShapeInfo, dtype, shape::order(in)); + int lastDimension = maxInd; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + outShapeInfo[0] = outRank; + for (int i = 0; i < outRank - 1; ++i) + outShapeInfo[i + 1] = shape::sizeAt(in, i); + outShapeInfo[outRank] = lastDimension; - return SHAPELIST(CONSTANT(outShapeInfo)); - } + ShapeUtils::updateStridesAndType(outShapeInfo, dtype, shape::order(in)); - DECLARE_TYPES(sequence_mask) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setAllowedOutputTypes(sd::DataType::ANY); - } -} + return SHAPELIST(CONSTANT(outShapeInfo)); } +DECLARE_TYPES(sequence_mask) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setAllowedOutputTypes(sd::DataType::ANY); +} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp index adc3a2cb59e7..f4dd1da572d2 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/square.cpp @@ -24,23 +24,21 @@ #include namespace sd { - namespace ops { - OP_IMPL(square, 1, 1, true) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - int extras = 2; - input->applyScalar(scalar::Pow, extras, *output); - - return Status::OK(); - } - - DECLARE_TYPES(square) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +namespace ops { +OP_IMPL(square, 1, 1, true) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + int extras = 2; + input->applyScalar(scalar::Pow, extras, *output); + + return Status::OK(); +} + +DECLARE_TYPES(square) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp index 621df4462e3f..810f995709e5 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp @@ -24,26 +24,24 @@ #include namespace sd { - namespace ops { - OP_IMPL(stop_gradient, 1, 1, true) { - auto out = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) { - auto x = INPUT_VARIABLE(0); - // we hope for memcpy here - out->assign(x); - } - - return Status::OK(); - } - DECLARE_SYN(StopGradient, stop_gradient); - - DECLARE_TYPES(stop_gradient) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +namespace ops { +OP_IMPL(stop_gradient, 1, 1, true) { + auto out = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) { + auto x = INPUT_VARIABLE(0); + // we hope for memcpy here + out->assign(x); + } + + return Status::OK(); +} +DECLARE_SYN(StopGradient, stop_gradient); + +DECLARE_TYPES(stop_gradient) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp index 360e7fdd437b..90a14c00fbc6 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp @@ -22,74 +22,77 @@ #if NOT_EXCLUDED(OP_top_k) //#include -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(top_k, 1, 2, false, 0, -1) { - auto x = INPUT_VARIABLE(0); - int k = 1;// from params - bool needSort = true; - - auto values = OUTPUT_VARIABLE(0); - auto indices = OUTPUT_VARIABLE(1); - - if (block.numB() == 1) - needSort = B_ARG(0); - - if (block.width() == 1) { - if (block.numI() > 0) { - k = INT_ARG(0); - } - } else { - k = INPUT_VARIABLE(1)->e(0); - } - - REQUIRE_TRUE(k <= x->sizeAt(-1), 0, "top_k: k should not be greater than last dimension"); - REQUIRE_TRUE(k > 0, 0, "top_k: k should be positive, but %i given.", k); - - int res = helpers::topKFunctor(block.launchContext(), x, values, indices, k, needSort); - return res; - } - - DECLARE_SHAPE_FN(top_k) { - auto shapeList = SHAPELIST(); - auto in = inputShape->at(0); - int shapeRank = shape::rank(in); - int k = 1; // default output shape is size 1 - - if (block.width() == 2) { - k = INPUT_VARIABLE(1)->e(0); - } else if (block.numI() > 0) { - k = INT_ARG(0); - } - - REQUIRE_TRUE(k > 0, 0, "top_k: k should be positive, but %i given.", k); - - for (int e = 0; e < 2; e++) { // 2 element tuple at output - Nd4jLong* aShape; - ALLOCATE(aShape, block.workspace(), shape::shapeInfoLength(shapeRank), Nd4jLong); - aShape[0] = shapeRank; - for (int i = 1 ; i < shapeRank; ++i) - aShape[i] = shape::sizeAt(in, i - 1); - aShape[shapeRank] = k; - - shape::updateStrides(aShape, shape::order(in)); - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(aShape, (e == 0?ArrayOptions::dataType(in):sd::DataType::INT64)))); - - RELEASE(aShape, block.workspace()); - } - return shapeList; - } - - DECLARE_TYPES(top_k) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, sd::DataType::ANY) - ->setAllowedOutputTypes(1, {ALL_INDICES}); - } +namespace ops { +CUSTOM_OP_IMPL(top_k, 1, 2, false, 0, -1) { + auto x = INPUT_VARIABLE(0); + int k = 1; // from params + bool needSort = true; + + auto values = OUTPUT_VARIABLE(0); + auto indices = OUTPUT_VARIABLE(1); + + if (block.numB() == 1) needSort = B_ARG(0); + + if (block.width() == 1) { + if (block.numI() > 0) { + k = INT_ARG(0); } + } else { + k = INPUT_VARIABLE(1)->e(0); + } + + REQUIRE_TRUE(k <= x->sizeAt(-1), 0, + "top_k: k should not be greater than last dimension"); + REQUIRE_TRUE(k > 0, 0, "top_k: k should be positive, but %i given.", k); + + int res = helpers::topKFunctor(block.launchContext(), x, values, indices, k, + needSort); + return res; +} + +DECLARE_SHAPE_FN(top_k) { + auto shapeList = SHAPELIST(); + auto in = inputShape->at(0); + int shapeRank = shape::rank(in); + int k = 1; // default output shape is size 1 + + if (block.width() == 2) { + k = INPUT_VARIABLE(1)->e(0); + } else if (block.numI() > 0) { + k = INT_ARG(0); + } + + REQUIRE_TRUE(k > 0, 0, "top_k: k should be positive, but %i given.", k); + + for (int e = 0; e < 2; e++) { // 2 element tuple at output + Nd4jLong* aShape; + ALLOCATE(aShape, block.workspace(), shape::shapeInfoLength(shapeRank), + Nd4jLong); + aShape[0] = shapeRank; + for (int i = 1; i < shapeRank; ++i) aShape[i] = shape::sizeAt(in, i - 1); + aShape[shapeRank] = k; + + shape::updateStrides(aShape, shape::order(in)); + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(aShape, (e == 0 ? ArrayOptions::dataType(in) + : sd::DataType::INT64)))); + + RELEASE(aShape, block.workspace()); + } + return shapeList; +} + +DECLARE_TYPES(top_k) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, sd::DataType::ANY) + ->setAllowedOutputTypes(1, {ALL_INDICES}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unique.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unique.cpp index 58fadb08fcdc..ff81fc0456d9 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unique.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unique.cpp @@ -25,85 +25,92 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(unique, 1, 2, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto values = OUTPUT_VARIABLE(0); - auto indices = OUTPUT_VARIABLE(1); - - REQUIRE_TRUE(x->dataType() == values->dataType(), 0, "Unique: input and output data types must be the same"); - - return helpers::uniqueFunctor(block.launchContext(), x, values, indices, (NDArray*)nullptr); - } - - DECLARE_SHAPE_FN(unique) { - auto in = inputShape->at(0); - auto source = INPUT_VARIABLE(0); -// auto shapeList = SHAPELIST(); - const Nd4jLong* valuesShape; - const Nd4jLong* indicesShape; - - int uniqueCount = helpers::uniqueCount(block.launchContext(), source); - - if (uniqueCount == 0) { // empty value Shape - valuesShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(source->dataType()); - } - else { - // all output shapes are 1D arrays (vectors) - valuesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(uniqueCount, ArrayOptions::dataType(in)); - } - // second output is always LONG - indicesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::length(in), sd::DataType::INT64); - - //COPY_SHAPE_EX(in, indicesShape, block.workspace()); - - return SHAPELIST(valuesShape, indicesShape); - - } - - CUSTOM_OP_IMPL(unique_with_counts, 1, 3, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto values = OUTPUT_VARIABLE(0); - auto indices = OUTPUT_VARIABLE(1); - auto counts = OUTPUT_VARIABLE(2); - - return helpers::uniqueFunctor(block.launchContext(), input, values, indices, counts); - } - - DECLARE_SHAPE_FN(unique_with_counts) { - auto in = inputShape->at(0); - auto source = INPUT_VARIABLE(0); - - int uniqueCount = helpers::uniqueCount(block.launchContext(), source); - // all output shapes are 1D arrays (vectors) - // all output shapes are 1D arrays (vectors) - auto valuesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(uniqueCount, source->dataType()); - - // second output is always LONG - auto indicesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(source->lengthOf(), sd::DataType::INT64); - - // third one as well - auto countsShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(uniqueCount, sd::DataType::INT64); - - return SHAPELIST(valuesShape, indicesShape, countsShape); - } - - DECLARE_TYPES(unique) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}); - } - - DECLARE_TYPES(unique_with_counts) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes(2, {ALL_INTS}); - } - - } +namespace ops { +CUSTOM_OP_IMPL(unique, 1, 2, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto values = OUTPUT_VARIABLE(0); + auto indices = OUTPUT_VARIABLE(1); + + REQUIRE_TRUE(x->dataType() == values->dataType(), 0, + "Unique: input and output data types must be the same"); + + return helpers::uniqueFunctor(block.launchContext(), x, values, indices, + (NDArray*)nullptr); +} + +DECLARE_SHAPE_FN(unique) { + auto in = inputShape->at(0); + auto source = INPUT_VARIABLE(0); + // auto shapeList = SHAPELIST(); + const Nd4jLong* valuesShape; + const Nd4jLong* indicesShape; + + int uniqueCount = helpers::uniqueCount(block.launchContext(), source); + + if (uniqueCount == 0) { // empty value Shape + valuesShape = + ConstantShapeHelper::getInstance()->emptyShapeInfo(source->dataType()); + } else { + // all output shapes are 1D arrays (vectors) + valuesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + uniqueCount, ArrayOptions::dataType(in)); + } + // second output is always LONG + indicesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + shape::length(in), sd::DataType::INT64); + + // COPY_SHAPE_EX(in, indicesShape, block.workspace()); + + return SHAPELIST(valuesShape, indicesShape); +} + +CUSTOM_OP_IMPL(unique_with_counts, 1, 3, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto values = OUTPUT_VARIABLE(0); + auto indices = OUTPUT_VARIABLE(1); + auto counts = OUTPUT_VARIABLE(2); + + return helpers::uniqueFunctor(block.launchContext(), input, values, indices, + counts); +} + +DECLARE_SHAPE_FN(unique_with_counts) { + auto in = inputShape->at(0); + auto source = INPUT_VARIABLE(0); + + int uniqueCount = helpers::uniqueCount(block.launchContext(), source); + // all output shapes are 1D arrays (vectors) + // all output shapes are 1D arrays (vectors) + auto valuesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + uniqueCount, source->dataType()); + + // second output is always LONG + auto indicesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + source->lengthOf(), sd::DataType::INT64); + + // third one as well + auto countsShape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + uniqueCount, sd::DataType::INT64); + + return SHAPELIST(valuesShape, indicesShape, countsShape); } +DECLARE_TYPES(unique) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}); +} + +DECLARE_TYPES(unique_with_counts) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes(2, {ALL_INTS}); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp index e010e5301b8a..71a2e07f8e27 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp @@ -22,73 +22,87 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(unsorted_segment_max, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_NULLIFIED(0); - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_max: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %ld != %ild.", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_max: segment indices should be in range [0, %ld), but %ld != %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentMaxFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); - - return ND4J_STATUS_OK; - } - DECLARE_TYPES(unsorted_segment_max) { - getOpDescriptor() - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(true); - } - DECLARE_SHAPE_FN(unsorted_segment_max) { - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - Nd4jLong* outputShape; +namespace ops { +CUSTOM_OP_IMPL(unsorted_segment_max, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_NULLIFIED(0); + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "unsorted_segment_max: segment indexes array should be a " + "vector, but it rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "unsorted_segment_max: segment indexes array length should be " + "equal to the input first dimension, but %ld != %ild.", + idxSegments->lengthOf(), input->sizeAt(0)); + + Nd4jLong wrong; + + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate( + block.launchContext(), idxSegments, numOfClasses, wrong), + 0, + "unsorted_segment_max: segment indices should be in range [0, " + "%ld), but %ld != %ld", + numOfClasses, wrong, numOfClasses); + + helpers::unsortedSegmentMaxFunctor(block.launchContext(), input, idxSegments, + numOfClasses, segmentedOutput); + + return ND4J_STATUS_OK; +} +DECLARE_TYPES(unsorted_segment_max) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(true); +} +DECLARE_SHAPE_FN(unsorted_segment_max) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + Nd4jLong* outputShape; - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - return SHAPELIST(CONSTANT(outputShape)); - } + return SHAPELIST(CONSTANT(outputShape)); +} - CUSTOM_OP_IMPL(unsorted_segment_max_bp, 3, 2, false, 0, 1) { - return helpers::unsortedSegmentMaxFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); - } +CUSTOM_OP_IMPL(unsorted_segment_max_bp, 3, 2, false, 0, 1) { + return helpers::unsortedSegmentMaxFunctorBP( + block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), + INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); +} - DECLARE_TYPES(unsorted_segment_max_bp) { - getOpDescriptor() - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setSameMode(false); - } +DECLARE_TYPES(unsorted_segment_max_bp) { + getOpDescriptor() + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setSameMode(false); +} - DECLARE_SHAPE_FN(unsorted_segment_max_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); +DECLARE_SHAPE_FN(unsorted_segment_max_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); - } - } + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp index ee58ef1aa59f..e711cea60e02 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp @@ -22,76 +22,90 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(unsorted_segment_mean, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_NULLIFIED(0); - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_mean: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_mean: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_mean: segment indices should be in range [0, %ld), but %ld != %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentMeanFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); - - return ND4J_STATUS_OK; - } - DECLARE_TYPES(unsorted_segment_mean) { - getOpDescriptor() - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(false); - } - - DECLARE_SHAPE_FN(unsorted_segment_mean) { - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - - CUSTOM_OP_IMPL(unsorted_segment_mean_bp, 3, 2, false, 0, 1) { - return helpers::unsortedSegmentMeanFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); - } - - DECLARE_TYPES(unsorted_segment_mean_bp) { - getOpDescriptor() - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setSameMode(false); - } - - DECLARE_SHAPE_FN(unsorted_segment_mean_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); -// return SHAPELIST(in, inIdx); - } - } +namespace ops { +CUSTOM_OP_IMPL(unsorted_segment_mean, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_NULLIFIED(0); + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + + REQUIRE_TRUE(idxSegments->isVector(), 0, + "unsorted_segment_mean: segment indexes array should be a " + "vector, but it rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "unsorted_segment_mean: segment indexes array length should be " + "equal to the input first dimension, but %ld != %ld.", + idxSegments->lengthOf(), input->sizeAt(0)); + + Nd4jLong wrong; + + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate( + block.launchContext(), idxSegments, numOfClasses, wrong), + 0, + "unsorted_segment_mean: segment indices should be in range [0, " + "%ld), but %ld != %ld", + numOfClasses, wrong, numOfClasses); + + helpers::unsortedSegmentMeanFunctor(block.launchContext(), input, idxSegments, + numOfClasses, segmentedOutput); + + return ND4J_STATUS_OK; } +DECLARE_TYPES(unsorted_segment_mean) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(false); +} + +DECLARE_SHAPE_FN(unsorted_segment_mean) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} + +CUSTOM_OP_IMPL(unsorted_segment_mean_bp, 3, 2, false, 0, 1) { + return helpers::unsortedSegmentMeanFunctorBP( + block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), + INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); +} + +DECLARE_TYPES(unsorted_segment_mean_bp) { + getOpDescriptor() + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setSameMode(false); +} + +DECLARE_SHAPE_FN(unsorted_segment_mean_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); + // return SHAPELIST(in, inIdx); +} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp index 750eff35bd2c..5c3e49243f4d 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp @@ -22,78 +22,91 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(unsorted_segment_min, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_NULLIFIED(0); - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_min: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_min: segment indices should be in range [0, %ld), but %ld > %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentMinFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); - - return ND4J_STATUS_OK; - } - - DECLARE_SHAPE_FN(unsorted_segment_min) { - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); +namespace ops { +CUSTOM_OP_IMPL(unsorted_segment_min, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_NULLIFIED(0); + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "unsorted_segment_min: segment indexes array should be a " + "vector, but it rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "unsorted_segment_min: segment indexes array length should be " + "equal to the input first dimension, but %ld != %ld.", + idxSegments->lengthOf(), input->sizeAt(0)); + + Nd4jLong wrong; + + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate( + block.launchContext(), idxSegments, numOfClasses, wrong), + 0, + "unsorted_segment_min: segment indices should be in range [0, " + "%ld), but %ld > %ld", + numOfClasses, wrong, numOfClasses); + + helpers::unsortedSegmentMinFunctor(block.launchContext(), input, idxSegments, + numOfClasses, segmentedOutput); + + return ND4J_STATUS_OK; +} - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); +DECLARE_SHAPE_FN(unsorted_segment_min) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); - return SHAPELIST(CONSTANT(outputShape)); - } + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - DECLARE_TYPES(unsorted_segment_min) { - getOpDescriptor() - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(true); - } + return SHAPELIST(CONSTANT(outputShape)); +} - CUSTOM_OP_IMPL(unsorted_segment_min_bp, 3, 2, false, 0, 1) { - return helpers::unsortedSegmentMinFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); - } +DECLARE_TYPES(unsorted_segment_min) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(true); +} - DECLARE_TYPES(unsorted_segment_min_bp) { - getOpDescriptor() - ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS}) - ->setSameMode(false); - } +CUSTOM_OP_IMPL(unsorted_segment_min_bp, 3, 2, false, 0, 1) { + return helpers::unsortedSegmentMinFunctorBP( + block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), + INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); +} - DECLARE_SHAPE_FN(unsorted_segment_min_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); +DECLARE_TYPES(unsorted_segment_min_bp) { + getOpDescriptor() + ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS}) + ->setSameMode(false); +} - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); +DECLARE_SHAPE_FN(unsorted_segment_min_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); - } + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); +} - } +} // namespace ops -} +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp index e8537c2326c0..552d1fd23a2e 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp @@ -22,91 +22,115 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(unsorted_segment_prod, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_NULLIFIED(0); - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_prod: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong = 0; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_prod: segment indices should be in range [0, %ld), but %ld != %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentProdFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); - - return ND4J_STATUS_OK; - } - - DECLARE_SHAPE_FN(unsorted_segment_prod) { - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - DECLARE_TYPES(unsorted_segment_prod) { - getOpDescriptor() - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INDICES}) - ->setSameMode(false); - } - - CUSTOM_OP_IMPL(unsorted_segment_prod_bp, 3, 2, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto eps = INPUT_VARIABLE(2); -// auto numOfClasses = INT_ARG(0); - auto output = OUTPUT_NULLIFIED(0); - - Nd4jLong numOfClasses = block.width() == 4 ? INPUT_VARIABLE(3)->e(0) : INT_ARG(0); - REQUIRE_TRUE(indices->isVector(), 0, "unsorted_segment_prod_bp: segment indexes array should be a vector, but it rank is %i.", indices->rankOf()); - REQUIRE_TRUE(indices->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod_bp: segment indexes array length should be equal to the input first dimension, but %lld != %lld.", indices->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong = numOfClasses; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), indices, numOfClasses, wrong), 0, "unsorted_segment_prod_bp: segment indices should be in range [0, %lld), but %lld > %lld", - numOfClasses, wrong, numOfClasses); - - return helpers::unsortedSegmentProdFunctorBP(block.launchContext(), input, indices, eps, numOfClasses, output); - } - DECLARE_TYPES(unsorted_segment_prod_bp) { - getOpDescriptor() - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INDICES}) - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INDICES}) - ->setAllowedInputTypes(2,{ALL_FLOATS, ALL_INTS}) - ->setSameMode(false); - } - - DECLARE_SHAPE_FN(unsorted_segment_prod_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); -// return SHAPELIST(in, inIdx); - } - - } +namespace ops { +CUSTOM_OP_IMPL(unsorted_segment_prod, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_NULLIFIED(0); + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "unsorted_segment_prod: segment indexes array should be a " + "vector, but it rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "unsorted_segment_prod: segment indexes array length should be " + "equal to the input first dimension, but %ld != %ld.", + idxSegments->lengthOf(), input->sizeAt(0)); + + Nd4jLong wrong = 0; + + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate( + block.launchContext(), idxSegments, numOfClasses, wrong), + 0, + "unsorted_segment_prod: segment indices should be in range [0, " + "%ld), but %ld != %ld", + numOfClasses, wrong, numOfClasses); + + helpers::unsortedSegmentProdFunctor(block.launchContext(), input, idxSegments, + numOfClasses, segmentedOutput); + + return ND4J_STATUS_OK; +} + +DECLARE_SHAPE_FN(unsorted_segment_prod) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} +DECLARE_TYPES(unsorted_segment_prod) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INDICES}) + ->setSameMode(false); +} +CUSTOM_OP_IMPL(unsorted_segment_prod_bp, 3, 2, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto eps = INPUT_VARIABLE(2); + // auto numOfClasses = INT_ARG(0); + auto output = OUTPUT_NULLIFIED(0); + + Nd4jLong numOfClasses = + block.width() == 4 ? INPUT_VARIABLE(3)->e(0) : INT_ARG(0); + REQUIRE_TRUE(indices->isVector(), 0, + "unsorted_segment_prod_bp: segment indexes array should be a " + "vector, but it rank is %i.", + indices->rankOf()); + REQUIRE_TRUE(indices->lengthOf() == input->sizeAt(0), 0, + "unsorted_segment_prod_bp: segment indexes array length should " + "be equal to the input first dimension, but %lld != %lld.", + indices->lengthOf(), input->sizeAt(0)); + + Nd4jLong wrong = numOfClasses; + + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate( + block.launchContext(), indices, numOfClasses, wrong), + 0, + "unsorted_segment_prod_bp: segment indices should be in range " + "[0, %lld), but %lld > %lld", + numOfClasses, wrong, numOfClasses); + + return helpers::unsortedSegmentProdFunctorBP( + block.launchContext(), input, indices, eps, numOfClasses, output); +} +DECLARE_TYPES(unsorted_segment_prod_bp) { + getOpDescriptor() + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INDICES}) + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INDICES}) + ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS}) + ->setSameMode(false); } + +DECLARE_SHAPE_FN(unsorted_segment_prod_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); + // return SHAPELIST(in, inIdx); +} + +} // namespace ops + +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp index 4574547f3d69..0fc14029c840 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp @@ -22,75 +22,89 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(unsorted_segment_sqrt_n, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_NULLIFIED(0); - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sqrt_n: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sqrt_n: segment indices should be in range [0, %ld), but %ld != %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentSqrtNFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); - - return ND4J_STATUS_OK; - } - - DECLARE_SHAPE_FN(unsorted_segment_sqrt_n) { - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - DECLARE_TYPES(unsorted_segment_sqrt_n) { - getOpDescriptor() - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(false); - } - - CUSTOM_OP_IMPL(unsorted_segment_sqrt_n_bp, 3, 2, false, 0, 1) { - return helpers::unsortedSegmentSqrtNFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); - } - DECLARE_TYPES(unsorted_segment_sqrt_n_bp) { - getOpDescriptor() - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setSameMode(false); - } - - DECLARE_SHAPE_FN(unsorted_segment_sqrt_n_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); -// return SHAPELIST(in, inIdx); - } - - } +namespace ops { +CUSTOM_OP_IMPL(unsorted_segment_sqrt_n, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_NULLIFIED(0); + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "unsorted_segment_sqrt_n: segment indexes array should be a " + "vector, but it rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "unsorted_segment_sqrt_n: segment indexes array length should " + "be equal to the input first dimension, but %ld != %ld.", + idxSegments->lengthOf(), input->sizeAt(0)); + + Nd4jLong wrong; + + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate( + block.launchContext(), idxSegments, numOfClasses, wrong), + 0, + "unsorted_segment_sqrt_n: segment indices should be in range " + "[0, %ld), but %ld != %ld", + numOfClasses, wrong, numOfClasses); + + helpers::unsortedSegmentSqrtNFunctor( + block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); + + return ND4J_STATUS_OK; } + +DECLARE_SHAPE_FN(unsorted_segment_sqrt_n) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} +DECLARE_TYPES(unsorted_segment_sqrt_n) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(false); +} + +CUSTOM_OP_IMPL(unsorted_segment_sqrt_n_bp, 3, 2, false, 0, 1) { + return helpers::unsortedSegmentSqrtNFunctorBP( + block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), + INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); +} +DECLARE_TYPES(unsorted_segment_sqrt_n_bp) { + getOpDescriptor() + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setSameMode(false); +} + +DECLARE_SHAPE_FN(unsorted_segment_sqrt_n_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); + // return SHAPELIST(in, inIdx); +} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp index 89c018bda325..32833510e610 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp @@ -22,73 +22,86 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(unsorted_segment_sum, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto idxSegments = INPUT_VARIABLE(1); - auto segmentedOutput = OUTPUT_NULLIFIED(0); - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sum: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sum: segment indexes array length should be equal to the input first dimension, but %ld != %ld", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sum: segment indices should be in range [0, %ld), but %ld > %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentSumFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); - - return ND4J_STATUS_OK; - } - DECLARE_TYPES(unsorted_segment_sum) { - getOpDescriptor() - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(false); - } - - DECLARE_SHAPE_FN(unsorted_segment_sum) { - - auto in = inputShape->at(0); - int outRank = shape::rank(in); - Nd4jLong* outputShape = nullptr; - Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - - ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); - - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); - - return SHAPELIST(CONSTANT(outputShape)); - } - CUSTOM_OP_IMPL(unsorted_segment_sum_bp, 3, 2, false, 0, 1) { - return helpers::unsortedSegmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); - } - - DECLARE_SHAPE_FN(unsorted_segment_sum_bp){ - auto in = inputShape->at(0); - auto inIdx = inputShape->at(1); - - Nd4jLong* outShape; - Nd4jLong* outIndex; - COPY_SHAPE(in, outShape); - COPY_SHAPE(inIdx, outIndex); - return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); - - } - DECLARE_TYPES(unsorted_segment_sum_bp) { - getOpDescriptor() - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(false); - } - - } +namespace ops { +CUSTOM_OP_IMPL(unsorted_segment_sum, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto idxSegments = INPUT_VARIABLE(1); + auto segmentedOutput = OUTPUT_NULLIFIED(0); + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, + "unsorted_segment_sum: segment indexes array should be a " + "vector, but it rank is %i.", + idxSegments->rankOf()); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, + "unsorted_segment_sum: segment indexes array length should be " + "equal to the input first dimension, but %ld != %ld", + idxSegments->lengthOf(), input->sizeAt(0)); + + Nd4jLong wrong; + + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate( + block.launchContext(), idxSegments, numOfClasses, wrong), + 0, + "unsorted_segment_sum: segment indices should be in range [0, " + "%ld), but %ld > %ld", + numOfClasses, wrong, numOfClasses); + + helpers::unsortedSegmentSumFunctor(block.launchContext(), input, idxSegments, + numOfClasses, segmentedOutput); + + return ND4J_STATUS_OK; +} +DECLARE_TYPES(unsorted_segment_sum) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(false); +} + +DECLARE_SHAPE_FN(unsorted_segment_sum) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; + Nd4jLong numOfClasses = + block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + + ALLOCATE(outputShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for (int i = 1; i < outRank; ++i) outputShape[i + 1] = shape::sizeAt(in, i); + + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + return SHAPELIST(CONSTANT(outputShape)); +} +CUSTOM_OP_IMPL(unsorted_segment_sum_bp, 3, 2, false, 0, 1) { + return helpers::unsortedSegmentSumFunctorBP( + block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), + INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0)); } + +DECLARE_SHAPE_FN(unsorted_segment_sum_bp) { + auto in = inputShape->at(0); + auto inIdx = inputShape->at(1); + + Nd4jLong* outShape; + Nd4jLong* outIndex; + COPY_SHAPE(in, outShape); + COPY_SHAPE(inIdx, outIndex); + return SHAPELIST(CONSTANT(outShape), CONSTANT(outIndex)); +} +DECLARE_TYPES(unsorted_segment_sum_bp) { + getOpDescriptor() + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(sd::DataType::ANY) + ->setSameMode(false); +} + +} // namespace ops + +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/weighted_cross_entropy_with_logits.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/weighted_cross_entropy_with_logits.cpp index 71d337c7cf40..24383776988e 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/weighted_cross_entropy_with_logits.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/weighted_cross_entropy_with_logits.cpp @@ -27,26 +27,37 @@ namespace sd { namespace ops { - OP_IMPL(weighted_cross_entropy_with_logits, 3, 1, true) { - auto targets = INPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(1); - auto weights = INPUT_VARIABLE(2); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(targets->isSameShape(input), 0, "WEIGHTED_CROSS_ENTROPY_WITH_LOGITS op: The shape of both input params should be equal, but got input_shape=%s and targets_shape=%s !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(targets).c_str()); - REQUIRE_TRUE(weights->isScalar() || targets->sizeAt(-1) == weights->lengthOf(), 0, "WEIGHTED_CROSS_ENTROPY_WITH_LOGITS op: The weights should be scalar or vector with length equal to size of last targets dimension, but got weights_shape=%s instead!", ShapeUtils::shapeAsString(weights).c_str()); - - helpers::weightedCrossEntropyWithLogitsFunctor(block.launchContext(), targets, input, weights, output); - - return Status::OK(); - } - - DECLARE_TYPES(weighted_cross_entropy_with_logits) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +OP_IMPL(weighted_cross_entropy_with_logits, 3, 1, true) { + auto targets = INPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(1); + auto weights = INPUT_VARIABLE(2); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE( + targets->isSameShape(input), 0, + "WEIGHTED_CROSS_ENTROPY_WITH_LOGITS op: The shape of both input params " + "should be equal, but got input_shape=%s and targets_shape=%s !", + ShapeUtils::shapeAsString(input).c_str(), + ShapeUtils::shapeAsString(targets).c_str()); + REQUIRE_TRUE( + weights->isScalar() || targets->sizeAt(-1) == weights->lengthOf(), 0, + "WEIGHTED_CROSS_ENTROPY_WITH_LOGITS op: The weights should be scalar or " + "vector with length equal to size of last targets dimension, but got " + "weights_shape=%s instead!", + ShapeUtils::shapeAsString(weights).c_str()); + + helpers::weightedCrossEntropyWithLogitsFunctor(block.launchContext(), targets, + input, weights, output); + + return Status::OK(); } + +DECLARE_TYPES(weighted_cross_entropy_with_logits) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/zero_fraction.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/zero_fraction.cpp index f70e92cf5a60..defc692fa753 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/zero_fraction.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/zero_fraction.cpp @@ -24,39 +24,42 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(zero_fraction, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); - - if(input->isEmpty()){ - output->p(0, std::numeric_limits::quiet_NaN()); - return Status::OK(); - } - - int numZeros = 0; -// for (int e = 0; e < input->lengthOf(); e++) -// if ((*input)(e) == T(0)) -// numZeros++; - auto countZero = input->reduceNumber(reduce::CountZero); - //nd4j_printf("Zero count is %f for %i elements.", countZero.e(0), input->lengthOf()); - //countZero /= double(input->lengthOf()); - output->p(0, countZero.e(0) / double(input->lengthOf())); //printIndexedBuffer("Zero count"); - - return Status::OK(); - } - DECLARE_SHAPE_FN(zero_fraction) { - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::DOUBLE)); - } - - DECLARE_TYPES(zero_fraction) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +namespace ops { +CUSTOM_OP_IMPL(zero_fraction, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); + + if (input->isEmpty()) { + output->p(0, std::numeric_limits::quiet_NaN()); + return Status::OK(); + } + + int numZeros = 0; + // for (int e = 0; e < input->lengthOf(); e++) + // if ((*input)(e) == T(0)) + // numZeros++; + auto countZero = input->reduceNumber(reduce::CountZero); + // nd4j_printf("Zero count is %f for %i elements.", countZero.e(0), + // input->lengthOf()); countZero /= double(input->lengthOf()); + output->p( + 0, countZero.e(0) / + double(input->lengthOf())); // printIndexedBuffer("Zero count"); + + return Status::OK(); +} +DECLARE_SHAPE_FN(zero_fraction) { + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo( + sd::DataType::DOUBLE)); +} + +DECLARE_TYPES(zero_fraction) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp b/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp index 112c3aeff56d..42d867ad16d2 100644 --- a/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp +++ b/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp @@ -21,49 +21,51 @@ #include #if NOT_EXCLUDED(OP_random_bernoulli) -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(random_bernoulli, 1, 1, true, 1, 0) { - auto rng = block.getRng(); - // FIXME: to be implemented -/* - if (rng == nullptr) - return Status::THROW("RNG is null, aborting..."); - - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(random_bernoulli, 1, 1, true, 1, 0) { + auto rng = block.getRng(); + // FIXME: to be implemented + /* + if (rng == nullptr) + return Status::THROW("RNG is null, aborting..."); - T f = T_ARG(0); + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - functions::random::RandomFunction::template execTransform>(block.getRNG(), z->buffer(), z->shapeInfo(), &f); -*/ + T f = T_ARG(0); - auto z = OUTPUT_VARIABLE(0); - auto f = T_ARG(0); + functions::random::RandomFunction::template + execTransform>(block.getRNG(), + z->buffer(), z->shapeInfo(), &f); + */ - RandomLauncher::fillBernoulli(block.launchContext(), rng, z, f); + auto z = OUTPUT_VARIABLE(0); + auto f = T_ARG(0); - return Status::OK(); - } + RandomLauncher::fillBernoulli(block.launchContext(), rng, z, f); - DECLARE_SHAPE_FN(random_bernoulli) { - auto in = INPUT_VARIABLE(0); - auto shape = in->template asVectorT(); + return Status::OK(); +} - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataType::FLOAT32, 'c', shape); - return SHAPELIST(newShape); - } +DECLARE_SHAPE_FN(random_bernoulli) { + auto in = INPUT_VARIABLE(0); + auto shape = in->template asVectorT(); + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + DataType::FLOAT32, 'c', shape); + return SHAPELIST(newShape); +} - DECLARE_TYPES(random_bernoulli) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(random_bernoulli) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/dropout.cpp b/libnd4j/include/ops/declarable/generic/random/dropout.cpp index b64fd49d5677..2deb1842f801 100644 --- a/libnd4j/include/ops/declarable/generic/random/dropout.cpp +++ b/libnd4j/include/ops/declarable/generic/random/dropout.cpp @@ -27,102 +27,105 @@ namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(dropout, 1, 1, true, 1, 1) { - auto input = INPUT_VARIABLE(0); // lookup param + auto input = INPUT_VARIABLE(0); // lookup param + + NDArray* reduceShape = nullptr; // this param is optional + auto output = OUTPUT_NULLIFIED(0); // - NDArray *reduceShape = nullptr; // this param is optional - auto output = OUTPUT_NULLIFIED(0); // - - int seed = INT_ARG(0); + int seed = INT_ARG(0); - // FIXME: float? - double probValue = T_ARG(0); - if (block.width() > 1) - reduceShape = INPUT_VARIABLE(1); + // FIXME: float? + double probValue = T_ARG(0); + if (block.width() > 1) reduceShape = INPUT_VARIABLE(1); - REQUIRE_TRUE(probValue > 0.f && probValue <= 1.f, 0, "dropout: Probability should be with range 0 to 1."); + REQUIRE_TRUE(probValue > 0.f && probValue <= 1.f, 0, + "dropout: Probability should be with range 0 to 1."); - if (probValue == 1.0) { - *output = *input; - return Status::OK(); - } + if (probValue == 1.0) { + *output = *input; + return Status::OK(); + } - return helpers::dropOutFunctor(block, input, output, reduceShape, seed, probValue); + return helpers::dropOutFunctor(block, input, output, reduceShape, seed, + probValue); } - DECLARE_TYPES(dropout) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(true); - } +DECLARE_TYPES(dropout) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(true); +} ////////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(dropout_bp, 2, 1, false, 1, 1) { - NDArray* input = INPUT_VARIABLE(0); // lookup param - NDArray* gradOut = INPUT_VARIABLE(1); // lookup param - - NDArray* reduceShape = nullptr; // this param is optional - NDArray* output = OUTPUT_NULLIFIED(0); // - - int seed = INT_ARG(0); - - double probValue = T_ARG(0); - if (block.width() > 2) - reduceShape = INPUT_VARIABLE(2); - - REQUIRE_TRUE((probValue > 0. && probValue <= 1.), 0, "dropout_bp: Probability should be with range 0 to 1."); - if (probValue == 1.0) { - output->assign(0.f); // fill up output with 0 - return ND4J_STATUS_OK; - } - - REQUIRE_TRUE(helpers::dropOutFunctorBP(block, input, gradOut, output, reduceShape, seed, probValue) == ND4J_STATUS_OK, 0, "dropout_bp: Cannot backprop dropout." ); + NDArray* input = INPUT_VARIABLE(0); // lookup param + NDArray* gradOut = INPUT_VARIABLE(1); // lookup param + + NDArray* reduceShape = nullptr; // this param is optional + NDArray* output = OUTPUT_NULLIFIED(0); // + + int seed = INT_ARG(0); + + double probValue = T_ARG(0); + if (block.width() > 2) reduceShape = INPUT_VARIABLE(2); + REQUIRE_TRUE((probValue > 0. && probValue <= 1.), 0, + "dropout_bp: Probability should be with range 0 to 1."); + if (probValue == 1.0) { + output->assign(0.f); // fill up output with 0 return ND4J_STATUS_OK; + } + + REQUIRE_TRUE( + helpers::dropOutFunctorBP(block, input, gradOut, output, reduceShape, + seed, probValue) == ND4J_STATUS_OK, + 0, "dropout_bp: Cannot backprop dropout."); + + return ND4J_STATUS_OK; } DECLARE_TYPES(dropout_bp) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(alpha_dropout_bp, 2, 1, false, 4, 1) { - NDArray* input = INPUT_VARIABLE(0); // lookup param - NDArray* gradOut = INPUT_VARIABLE(1); // lookup param + NDArray* input = INPUT_VARIABLE(0); // lookup param + NDArray* gradOut = INPUT_VARIABLE(1); // lookup param - NDArray* reduceShape = nullptr; // this param is optional - NDArray* output = OUTPUT_VARIABLE(0); // + NDArray* reduceShape = nullptr; // this param is optional + NDArray* output = OUTPUT_VARIABLE(0); // - if (block.width() > 2) - reduceShape = INPUT_VARIABLE(2); + if (block.width() > 2) reduceShape = INPUT_VARIABLE(2); - int seed = INT_ARG(0); - - double probValue = T_ARG(0); - double alphaValue = T_ARG(0); - double alpha1Value = T_ARG(2); - double betaValue = T_ARG(3); + int seed = INT_ARG(0); - REQUIRE_TRUE(probValue > 0. && probValue <= 1., 0, "dropout_bp: Probability should be with range 0 to 1."); - if (probValue == 1.0) { - output->assign(0.); // fill up output with 0 - return ND4J_STATUS_OK; - } + double probValue = T_ARG(0); + double alphaValue = T_ARG(0); + double alpha1Value = T_ARG(2); + double betaValue = T_ARG(3); - return helpers::alphaDropOutFunctorBP(block, input, gradOut, output, reduceShape, seed, probValue, alphaValue, alpha1Value, betaValue); -} - DECLARE_TYPES(alpha_dropout_bp) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setSameMode(true); - } + REQUIRE_TRUE(probValue > 0. && probValue <= 1., 0, + "dropout_bp: Probability should be with range 0 to 1."); + if (probValue == 1.0) { + output->assign(0.); // fill up output with 0 + return ND4J_STATUS_OK; + } + + return helpers::alphaDropOutFunctorBP(block, input, gradOut, output, + reduceShape, seed, probValue, + alphaValue, alpha1Value, betaValue); } +DECLARE_TYPES(alpha_dropout_bp) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/exponential.cpp b/libnd4j/include/ops/declarable/generic/random/exponential.cpp index 660ad1dcadd2..ed3746b3c50f 100644 --- a/libnd4j/include/ops/declarable/generic/random/exponential.cpp +++ b/libnd4j/include/ops/declarable/generic/random/exponential.cpp @@ -21,37 +21,37 @@ #include #if NOT_EXCLUDED(OP_random_exponential) -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(random_exponential, 1, 1, true, 1, 0) { - // random generator for distribution - auto rng = block.randomGenerator(); - auto z = OUTPUT_VARIABLE(0); - auto lambda = T_ARG(0); - - RandomLauncher::fillExponential(block.launchContext(), rng, z, lambda); +namespace ops { +CUSTOM_OP_IMPL(random_exponential, 1, 1, true, 1, 0) { + // random generator for distribution + auto rng = block.randomGenerator(); + auto z = OUTPUT_VARIABLE(0); + auto lambda = T_ARG(0); - return Status::OK(); - } + RandomLauncher::fillExponential(block.launchContext(), rng, z, lambda); + return Status::OK(); +} - DECLARE_SHAPE_FN(random_exponential) { - auto in = INPUT_VARIABLE(0); - auto shape = in->template asVectorT(); +DECLARE_SHAPE_FN(random_exponential) { + auto in = INPUT_VARIABLE(0); + auto shape = in->template asVectorT(); - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataType::FLOAT32, 'c', shape); - return SHAPELIST(newShape); - } + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + DataType::FLOAT32, 'c', shape); + return SHAPELIST(newShape); +} - DECLARE_TYPES(random_exponential) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(random_exponential) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/gamma.cpp b/libnd4j/include/ops/declarable/generic/random/gamma.cpp index 0650f33d5342..a3a8c004e6b6 100644 --- a/libnd4j/include/ops/declarable/generic/random/gamma.cpp +++ b/libnd4j/include/ops/declarable/generic/random/gamma.cpp @@ -25,61 +25,68 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(random_gamma, 2, 1, false, 0, 0) { - // gamma distribution - auto rng = block.randomGenerator(); - auto shape = INPUT_VARIABLE(0); - auto alpha = INPUT_VARIABLE(1); - NDArray* beta = nullptr; +namespace ops { +CUSTOM_OP_IMPL(random_gamma, 2, 1, false, 0, 0) { + // gamma distribution + auto rng = block.randomGenerator(); + auto shape = INPUT_VARIABLE(0); + auto alpha = INPUT_VARIABLE(1); + NDArray* beta = nullptr; - if (block.width() > 2) { - beta = INPUT_VARIABLE(2); - REQUIRE_TRUE(ShapeUtils::areShapesBroadcastable(*alpha, *beta), 0, "random_gamma: alpha and beta shapes should be broadcastable."); - } + if (block.width() > 2) { + beta = INPUT_VARIABLE(2); + REQUIRE_TRUE( + ShapeUtils::areShapesBroadcastable(*alpha, *beta), 0, + "random_gamma: alpha and beta shapes should be broadcastable."); + } - auto output = OUTPUT_VARIABLE(0); - auto seed = 0; + auto output = OUTPUT_VARIABLE(0); + auto seed = 0; - if (block.numI()) { - seed = INT_ARG(0); - } + if (block.numI()) { + seed = INT_ARG(0); + } - rng.setSeed(seed); + rng.setSeed(seed); - helpers::fillRandomGamma(block.launchContext(), rng, alpha, beta, output); + helpers::fillRandomGamma(block.launchContext(), rng, alpha, beta, output); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(random_gamma) { - auto in = INPUT_VARIABLE(0); - auto shape = in->template asVectorT(); - auto alphaShape = inputShape->at(1); - auto additionalShape = alphaShape; - if (inputShape->size() > 2) { - auto rest = inputShape->at(2); additionalShape = nullptr; - REQUIRE_TRUE(ShapeUtils::areShapesBroadcastable(alphaShape, rest), 0, "random_gamma: alpha and beta shapes should be broadcastable."); - const Nd4jLong* additionalShapeBroadcasted = nullptr; - ShapeUtils::evalBroadcastShapeInfo(alphaShape, rest, true, additionalShapeBroadcasted, block.workspace()); - additionalShape = additionalShapeBroadcasted; - } - auto lastDim = shape::sizeAt(alphaShape, 0); - auto dtype = ArrayOptions::dataType(alphaShape); - for (auto i = 0; i < shape::rank(additionalShape); i++) - shape.push_back(shape::sizeAt(additionalShape, i)); - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape); - return SHAPELIST(newShape); - } +DECLARE_SHAPE_FN(random_gamma) { + auto in = INPUT_VARIABLE(0); + auto shape = in->template asVectorT(); + auto alphaShape = inputShape->at(1); + auto additionalShape = alphaShape; + if (inputShape->size() > 2) { + auto rest = inputShape->at(2); + additionalShape = nullptr; + REQUIRE_TRUE( + ShapeUtils::areShapesBroadcastable(alphaShape, rest), 0, + "random_gamma: alpha and beta shapes should be broadcastable."); + const Nd4jLong* additionalShapeBroadcasted = nullptr; + ShapeUtils::evalBroadcastShapeInfo( + alphaShape, rest, true, additionalShapeBroadcasted, block.workspace()); + additionalShape = additionalShapeBroadcasted; + } + auto lastDim = shape::sizeAt(alphaShape, 0); + auto dtype = ArrayOptions::dataType(alphaShape); + for (auto i = 0; i < shape::rank(additionalShape); i++) + shape.push_back(shape::sizeAt(additionalShape, i)); + auto newShape = + ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape); + return SHAPELIST(newShape); +} - DECLARE_TYPES(random_gamma) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(random_gamma) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/random/get_seed.cpp b/libnd4j/include/ops/declarable/generic/random/get_seed.cpp index 7042ae6dd320..5088c4d0f175 100644 --- a/libnd4j/include/ops/declarable/generic/random/get_seed.cpp +++ b/libnd4j/include/ops/declarable/generic/random/get_seed.cpp @@ -24,28 +24,30 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(get_seed, -2, 1, false, 0, 0) { -// REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph"); - auto rng = block.getRng(); - auto z = OUTPUT_VARIABLE(0); - - z->p(Nd4jLong(0), rng.rootState()); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(get_seed) { - auto newshape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64); - return SHAPELIST(newshape); - } - - DECLARE_TYPES(get_seed) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(DataType::INT64); - } - } +namespace ops { +CUSTOM_OP_IMPL(get_seed, -2, 1, false, 0, 0) { + // REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be + // defined in Graph"); + auto rng = block.getRng(); + auto z = OUTPUT_VARIABLE(0); + + z->p(Nd4jLong(0), rng.rootState()); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(get_seed) { + auto newshape = + ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64); + return SHAPELIST(newshape); +} + +DECLARE_TYPES(get_seed) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(DataType::INT64); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/multinomial.cpp b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp index f3a5a20959de..c984dda664df 100644 --- a/libnd4j/include/ops/declarable/generic/random/multinomial.cpp +++ b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp @@ -22,93 +22,113 @@ #include #if NOT_EXCLUDED(OP_random_multinomial) -#include #include +#include #include namespace sd { - namespace ops { - /////////////////////// - /** - * multinomial (categorical) random generator - * takes 2D ndarray with logits with shape [batch_size (N), num_classes (K)] - * and array with one scalar value of samples number, number of independent samples to draw for each experiment 1,N. - * represents the unnormalized log-probabilities for all classes. - * Int arguments: 0 - optional argument, corresponds to dimension with batch_size - * Int arguments: 1 - optional argument, integer type to use for the output. Default int64. - */ - // used https://en.wikipedia.org/wiki/Categorical_distribution - // methods: gumbel trick + softmax + argmax - CUSTOM_OP_IMPL(random_multinomial, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_NULLIFIED(0); - auto inputSamples = INPUT_VARIABLE(1); - - - REQUIRE_TRUE(!input->isEmpty(), 0, "RANDOM_MULTINOMIAL OP: Have to be provided at least one logits. "); - - REQUIRE_TRUE(inputSamples->lengthOf() == 1, 0, "RANDOM_MULTINOMIAL OP: Have to be specified at least one sample," - " but got no argumets instead."); - - Nd4jLong numOfSamples = static_cast(inputSamples->e(0)); - // do nothing if number of samples = 0 - if (0 == numOfSamples) - return Status::OK(); - - REQUIRE_TRUE(numOfSamples > 0, 0, "RANDOM_MULTINOMIAL OP: Number of samples should be greater then 0, got %i. ", numOfSamples); - - const int rank = input->rankOf(); - REQUIRE_TRUE(rank == 2, 0, "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = 2, but got instead rank = %i.", rank); - - const int argSize = block.numI(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - auto dimA = (0 == dimC) ? 1 : 0; - if (1 == input->sizeAt(dimA)) { - *output = 0; - return Status::OK(); - } - - auto rng = block.randomGenerator(); - helpers::fillRandomMultiNomial(block.launchContext(), rng, *input, *output, numOfSamples, dimC); - return Status::OK(); - } - - - DECLARE_SHAPE_FN(random_multinomial) { - - auto input = INPUT_VARIABLE(0); - auto inputSamples = INPUT_VARIABLE(1); - - REQUIRE_TRUE(inputSamples->lengthOf() == 1, 0, "RANDOM_MULTINOMIAL OP: Have to be specified at least one sample," - " but got no argumets instead."); - - Nd4jLong numOfSamples = static_cast(inputSamples->e(0)); - - REQUIRE_TRUE(numOfSamples > 0, 0, "RANDOM_MULTINOMIAL OP: Number of samples should be greater then 0, got %i. ", numOfSamples); - - const int rank = input->rankOf(); - REQUIRE_TRUE(rank == 2, 0, "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = 2, but got instead rank = %i.", rank); - - const int argSize = block.numI(); - const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - auto nShape = input->getShapeAsVector(); - auto dimA = (0 == dimC) ? 1 : 0; - nShape[dimA] = numOfSamples; - - DataType nType = block.numD() ? D_ARG(0) : sd::DataType::INT64; - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(nType, input->ordering(), nShape)); - } - - DECLARE_TYPES(random_multinomial) { - getOpDescriptor() - ->setAllowedInputTypes(0, { ALL_FLOATS, ALL_INTS }) - ->setAllowedInputTypes(1, { sd::DataType::INT32 }) - ->setAllowedOutputTypes(0, { ALL_INDICES }); - } - } +namespace ops { +/////////////////////// +/** + * multinomial (categorical) random generator + * takes 2D ndarray with logits with shape [batch_size (N), num_classes (K)] + * and array with one scalar value of samples number, number of independent + * samples to draw for each experiment 1,N. represents the unnormalized + * log-probabilities for all classes. Int arguments: 0 - optional argument, + * corresponds to dimension with batch_size Int arguments: 1 - optional + * argument, integer type to use for the output. Default int64. + */ +// used https://en.wikipedia.org/wiki/Categorical_distribution +// methods: gumbel trick + softmax + argmax +CUSTOM_OP_IMPL(random_multinomial, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_NULLIFIED(0); + auto inputSamples = INPUT_VARIABLE(1); + + REQUIRE_TRUE( + !input->isEmpty(), 0, + "RANDOM_MULTINOMIAL OP: Have to be provided at least one logits. "); + + REQUIRE_TRUE( + inputSamples->lengthOf() == 1, 0, + "RANDOM_MULTINOMIAL OP: Have to be specified at least one sample," + " but got no argumets instead."); + + Nd4jLong numOfSamples = static_cast(inputSamples->e(0)); + // do nothing if number of samples = 0 + if (0 == numOfSamples) return Status::OK(); + + REQUIRE_TRUE(numOfSamples > 0, 0, + "RANDOM_MULTINOMIAL OP: Number of samples should be greater " + "then 0, got %i. ", + numOfSamples); + + const int rank = input->rankOf(); + REQUIRE_TRUE(rank == 2, 0, + "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = " + "2, but got instead rank = %i.", + rank); + + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; + + auto dimA = (0 == dimC) ? 1 : 0; + if (1 == input->sizeAt(dimA)) { + *output = 0; + return Status::OK(); + } + + auto rng = block.randomGenerator(); + helpers::fillRandomMultiNomial(block.launchContext(), rng, *input, *output, + numOfSamples, dimC); + return Status::OK(); +} + +DECLARE_SHAPE_FN(random_multinomial) { + auto input = INPUT_VARIABLE(0); + auto inputSamples = INPUT_VARIABLE(1); + + REQUIRE_TRUE( + inputSamples->lengthOf() == 1, 0, + "RANDOM_MULTINOMIAL OP: Have to be specified at least one sample," + " but got no argumets instead."); + + Nd4jLong numOfSamples = static_cast(inputSamples->e(0)); + + REQUIRE_TRUE(numOfSamples > 0, 0, + "RANDOM_MULTINOMIAL OP: Number of samples should be greater " + "then 0, got %i. ", + numOfSamples); + + const int rank = input->rankOf(); + REQUIRE_TRUE(rank == 2, 0, + "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = " + "2, but got instead rank = %i.", + rank); + + const int argSize = block.numI(); + const int dimC = argSize > 0 + ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) + : rank - 1; + + auto nShape = input->getShapeAsVector(); + auto dimA = (0 == dimC) ? 1 : 0; + nShape[dimA] = numOfSamples; + + DataType nType = block.numD() ? D_ARG(0) : sd::DataType::INT64; + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + nType, input->ordering(), nShape)); +} + +DECLARE_TYPES(random_multinomial) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {sd::DataType::INT32}) + ->setAllowedOutputTypes(0, {ALL_INDICES}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/normal.cpp b/libnd4j/include/ops/declarable/generic/random/normal.cpp index 204fbc7e76a0..00b5e09da62d 100644 --- a/libnd4j/include/ops/declarable/generic/random/normal.cpp +++ b/libnd4j/include/ops/declarable/generic/random/normal.cpp @@ -21,45 +21,51 @@ #include #if NOT_EXCLUDED(OP_random_normal) -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(random_normal, 1, 1, true, 2, 0) { - // normal distribution - auto rng = block.randomGenerator(); - // FIXME: to be implemented -/* - REQUIRE_TRUE(rng != nullptr, 0, "RNG isn't defined for this Graph instance"); +namespace ops { +CUSTOM_OP_IMPL(random_normal, 1, 1, true, 2, 0) { + // normal distribution + auto rng = block.randomGenerator(); + // FIXME: to be implemented + /* + REQUIRE_TRUE(rng != nullptr, 0, "RNG isn't defined for this Graph + instance"); - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - functions::random::RandomFunction::template execTransform>(block.getRNG(), z->buffer(), z->shapeInfo(), z->buffer(), z->shapeInfo(), z->buffer(), z->shapeInfo(), block.getTArguments()->data()); -*/ + functions::random::RandomFunction::template + execTransform>(block.getRNG(), + z->buffer(), z->shapeInfo(), z->buffer(), z->shapeInfo(), z->buffer(), + z->shapeInfo(), block.getTArguments()->data()); + */ - RandomLauncher::fillGaussian(block.launchContext(), rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1)); + RandomLauncher::fillGaussian(block.launchContext(), rng, OUTPUT_VARIABLE(0), + T_ARG(0), T_ARG(1)); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(random_normal) { - auto in = INPUT_VARIABLE(0); - auto shape = in->template asVectorT(); +DECLARE_SHAPE_FN(random_normal) { + auto in = INPUT_VARIABLE(0); + auto shape = in->template asVectorT(); + + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + DataType::FLOAT32, 'c', shape); + return SHAPELIST(newShape); +} - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataType::FLOAT32, 'c', shape); - return SHAPELIST(newShape); - } - - DECLARE_SYN(randomnormal, random_normal); +DECLARE_SYN(randomnormal, random_normal); - DECLARE_TYPES(random_normal) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(random_normal) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/poisson.cpp b/libnd4j/include/ops/declarable/generic/random/poisson.cpp index 30fc27ecf508..e36b4ab9a7d0 100644 --- a/libnd4j/include/ops/declarable/generic/random/poisson.cpp +++ b/libnd4j/include/ops/declarable/generic/random/poisson.cpp @@ -25,43 +25,43 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(random_poisson, 2, 1, false, 0, 0) { - // gamma distribution - auto rng = block.randomGenerator(); - auto shape = INPUT_VARIABLE(0); - auto lambda = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - auto seed = 0; - if (block.numI()) { - seed = INT_ARG(0); - } - rng.setSeed(seed); - helpers::fillRandomPoisson(block.launchContext(), rng, lambda, output); - - return Status::OK(); - } +namespace ops { +CUSTOM_OP_IMPL(random_poisson, 2, 1, false, 0, 0) { + // gamma distribution + auto rng = block.randomGenerator(); + auto shape = INPUT_VARIABLE(0); + auto lambda = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + auto seed = 0; + if (block.numI()) { + seed = INT_ARG(0); + } + rng.setSeed(seed); + helpers::fillRandomPoisson(block.launchContext(), rng, lambda, output); + return Status::OK(); +} - DECLARE_SHAPE_FN(random_poisson) { - auto in = INPUT_VARIABLE(0); - auto shape = in->template asVectorT(); - auto lambdaShape = inputShape->at(1); - auto dtype = ArrayOptions::dataType(lambdaShape); - for (auto d = 0; d < shape::rank(lambdaShape); ++d ) { - shape.emplace_back(shape::sizeAt(lambdaShape, d)); - } - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape); - return SHAPELIST(newShape); - } +DECLARE_SHAPE_FN(random_poisson) { + auto in = INPUT_VARIABLE(0); + auto shape = in->template asVectorT(); + auto lambdaShape = inputShape->at(1); + auto dtype = ArrayOptions::dataType(lambdaShape); + for (auto d = 0; d < shape::rank(lambdaShape); ++d) { + shape.emplace_back(shape::sizeAt(lambdaShape, d)); + } + auto newShape = + ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape); + return SHAPELIST(newShape); +} - DECLARE_TYPES(random_poisson) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(random_poisson) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/random_crop.cpp b/libnd4j/include/ops/declarable/generic/random/random_crop.cpp index a52f7b6548cf..814947589d00 100644 --- a/libnd4j/include/ops/declarable/generic/random/random_crop.cpp +++ b/libnd4j/include/ops/declarable/generic/random/random_crop.cpp @@ -18,55 +18,58 @@ // Created by GS // - #include #include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(random_crop, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); // values for crop - auto shape = INPUT_VARIABLE(1); // shape for result + auto input = INPUT_VARIABLE(0); // values for crop + auto shape = INPUT_VARIABLE(1); // shape for result + + NDArray* reduceShape = nullptr; // this param is optional + auto output = OUTPUT_VARIABLE(0); // - NDArray* reduceShape = nullptr; // this param is optional - auto output = OUTPUT_VARIABLE(0); // - - int seed = 0; + int seed = 0; - if (block.numI() > 0) - seed = INT_ARG(0); + if (block.numI() > 0) seed = INT_ARG(0); - REQUIRE_TRUE(shape->isVector(), 0, "random_crop: Shape tensor should be a vector."); - - REQUIRE_TRUE(input->rankOf() == shape->lengthOf(), 0, "random_crop: The length of the shape vector is not match input rank. %i and %i were given.", - input->rankOf(), shape->lengthOf()); + REQUIRE_TRUE(shape->isVector(), 0, + "random_crop: Shape tensor should be a vector."); - for (int e = 0; e < shape->lengthOf(); ++e) { - REQUIRE_TRUE((*shape).e(e) <= input->sizeAt(e), 0, "random_crop: Shape tensor should be less than proper input dimension (dim %i, %i > %i).", e, (*shape).e(e), input->sizeAt(e)); - } + REQUIRE_TRUE(input->rankOf() == shape->lengthOf(), 0, + "random_crop: The length of the shape vector is not match input " + "rank. %i and %i were given.", + input->rankOf(), shape->lengthOf()); - return helpers::randomCropFunctor(block, input, shape, output, seed); + for (int e = 0; e < shape->lengthOf(); ++e) { + REQUIRE_TRUE((*shape).e(e) <= input->sizeAt(e), 0, + "random_crop: Shape tensor should be less than proper input " + "dimension (dim %i, %i > %i).", + e, (*shape).e(e), input->sizeAt(e)); + } + + return helpers::randomCropFunctor(block, input, shape, output, seed); } DECLARE_SHAPE_FN(random_crop) { - auto in = INPUT_VARIABLE(1); - auto typeShape = inputShape->at(0); - std::vector shape(in->lengthOf()); - - for (int e = 0; e < shape.size(); e++) - shape[e] = (*in).e(e); - - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(typeShape), 'c', shape); - return SHAPELIST(newShape); + auto in = INPUT_VARIABLE(1); + auto typeShape = inputShape->at(0); + std::vector shape(in->lengthOf()); + + for (int e = 0; e < shape.size(); e++) shape[e] = (*in).e(e); + + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(typeShape), 'c', shape); + return SHAPELIST(newShape); } - DECLARE_TYPES(random_crop) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(random_crop) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/random_shuffle.cpp b/libnd4j/include/ops/declarable/generic/random/random_shuffle.cpp index 012d2e55ae43..4f8741207154 100644 --- a/libnd4j/include/ops/declarable/generic/random/random_shuffle.cpp +++ b/libnd4j/include/ops/declarable/generic/random/random_shuffle.cpp @@ -22,33 +22,31 @@ #if NOT_EXCLUDED(OP_random_shuffle) #include -#include +#include namespace sd { -namespace ops { +namespace ops { OP_IMPL(random_shuffle, 1, 1, true) { - - auto input = INPUT_VARIABLE(0); - const bool isInplace = block.isInplace(); - auto output = isInplace ? nullptr : OUTPUT_VARIABLE(0); - -// sd::random::RandomBuffer* rng = block.getRNG(); - sd::graph::RandomGenerator rng = block.randomGenerator(); -// REQUIRE_TRUE(rng != nullptr, 0, "RANDOM_SHUFFLE op: RNG should be defined in Graph !"); - - helpers::randomShuffle(block.launchContext(), *input, *output, rng, isInplace); - - return Status::OK(); -} + auto input = INPUT_VARIABLE(0); + const bool isInplace = block.isInplace(); + auto output = isInplace ? nullptr : OUTPUT_VARIABLE(0); + + // sd::random::RandomBuffer* rng = block.getRNG(); + sd::graph::RandomGenerator rng = block.randomGenerator(); + // REQUIRE_TRUE(rng != nullptr, 0, "RANDOM_SHUFFLE op: RNG should be + // defined in Graph !"); + helpers::randomShuffle(block.launchContext(), *input, *output, rng, + isInplace); - DECLARE_TYPES(random_shuffle) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } + return Status::OK(); } + +DECLARE_TYPES(random_shuffle) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/random/set_seed.cpp b/libnd4j/include/ops/declarable/generic/random/set_seed.cpp index dff6a4e6484d..d265807e674b 100644 --- a/libnd4j/include/ops/declarable/generic/random/set_seed.cpp +++ b/libnd4j/include/ops/declarable/generic/random/set_seed.cpp @@ -21,43 +21,48 @@ #include #if NOT_EXCLUDED(OP_set_seed) -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(set_seed, -2, 1, false, 0, -2) { -// REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph"); - auto rng = block.getRng(); //.getRNG(); - - Nd4jLong seed = 0; - if (block.numI() > 0) { - seed = INT_ARG(0); - } else if (block.width() > 0) { - auto input = INPUT_VARIABLE(0); - REQUIRE_TRUE(input->isScalar(),0 ,"SetSeed: Seed operand should be scalar"); - seed = input->e(0); - } else { - REQUIRE_TRUE(false, 0, "SetSeed: either IArg or scalr input should be provided"); - } - - // FIXME: this approach isn't really good for cuda, since it'll assume that CUDA might get nullptr instead of stream - //refreshBuffer(nullptr, seed, (Nd4jPointer) rng); - rng.setSeed((int)seed); - return Status::OK(); - } - - DECLARE_SHAPE_FN(set_seed) { - auto newshape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::FLOAT32); - return SHAPELIST(newshape); - } - - DECLARE_TYPES(set_seed) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +namespace ops { +CUSTOM_OP_IMPL(set_seed, -2, 1, false, 0, -2) { + // REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be + // defined in Graph"); + auto rng = block.getRng(); //.getRNG(); + + Nd4jLong seed = 0; + if (block.numI() > 0) { + seed = INT_ARG(0); + } else if (block.width() > 0) { + auto input = INPUT_VARIABLE(0); + REQUIRE_TRUE(input->isScalar(), 0, + "SetSeed: Seed operand should be scalar"); + seed = input->e(0); + } else { + REQUIRE_TRUE(false, 0, + "SetSeed: either IArg or scalr input should be provided"); + } + + // FIXME: this approach isn't really good for cuda, since it'll assume that + // CUDA might get nullptr instead of stream + // refreshBuffer(nullptr, seed, (Nd4jPointer) rng); + rng.setSeed((int)seed); + return Status::OK(); +} + +DECLARE_SHAPE_FN(set_seed) { + auto newshape = + ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::FLOAT32); + return SHAPELIST(newshape); +} + +DECLARE_TYPES(set_seed) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/random/uniform.cpp b/libnd4j/include/ops/declarable/generic/random/uniform.cpp index 9ba51027c21f..14cdf927729e 100644 --- a/libnd4j/include/ops/declarable/generic/random/uniform.cpp +++ b/libnd4j/include/ops/declarable/generic/random/uniform.cpp @@ -22,76 +22,78 @@ #include #if NOT_EXCLUDED(OP_randomuniform) -#include #include +#include #include namespace sd { - namespace ops { - /////////////////////// - /** - * uniform distribution - * takes 1 ndarray - * - * T argumens map: - * TArgs[0] - min for rng - * TArgs[1] - max for rng - */ - CUSTOM_OP_IMPL(randomuniform, 1, 1, true, 0, 0) { - // uniform distribution - auto rng = block.randomGenerator(); - auto dtype = DataType::FLOAT32; - if (block.numI()) - dtype = (DataType)INT_ARG(0); - - auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr; - auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*) nullptr; - bool disposable = false; +namespace ops { +/////////////////////// +/** + * uniform distribution + * takes 1 ndarray + * + * T argumens map: + * TArgs[0] - min for rng + * TArgs[1] - max for rng + */ +CUSTOM_OP_IMPL(randomuniform, 1, 1, true, 0, 0) { + // uniform distribution + auto rng = block.randomGenerator(); + auto dtype = DataType::FLOAT32; + if (block.numI()) dtype = (DataType)INT_ARG(0); - if (min == nullptr && max == nullptr && block.numT() >= 2) { - min = NDArrayFactory::create_(dtype, block.launchContext()); - max = NDArrayFactory::create_(dtype, block.launchContext()); - min->p(0, T_ARG(0)); - max->p(0, T_ARG(1)); - disposable = true; - } + auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*)nullptr; + auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*)nullptr; + bool disposable = false; - auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given."); + if (min == nullptr && max == nullptr && block.numT() >= 2) { + min = NDArrayFactory::create_(dtype, block.launchContext()); + max = NDArrayFactory::create_(dtype, block.launchContext()); + min->p(0, T_ARG(0)); + max->p(0, T_ARG(1)); + disposable = true; + } - helpers::fillRandomUniform(block.launchContext(), rng, min, max, output); + auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(output->dataType() == dtype, 0, + "RandomUniform: data type of output should be equals to given."); - if (disposable) { - delete min; - delete max; - } - return Status::OK(); - } + helpers::fillRandomUniform(block.launchContext(), rng, min, max, output); + if (disposable) { + delete min; + delete max; + } + return Status::OK(); +} - DECLARE_SHAPE_FN(randomuniform) { - auto in = INPUT_VARIABLE(0); - //auto min = INPUT_VARIABLE(1); - auto shape = in->template asVectorT(); - auto dtype = DataType::FLOAT32; //ArrayOptions::dataType(inputShape->at(1)); // output type is by given min +DECLARE_SHAPE_FN(randomuniform) { + auto in = INPUT_VARIABLE(0); + // auto min = INPUT_VARIABLE(1); + auto shape = in->template asVectorT(); + auto dtype = DataType::FLOAT32; // ArrayOptions::dataType(inputShape->at(1)); + // // output type is by given min - if (block.numI()) - dtype = (DataType)INT_ARG(0); - if (block.width() > 1) - REQUIRE_TRUE(dtype == INPUT_VARIABLE(1)->dataType(), 0, "RandomUniform: data type of output and min/max args should be the same"); + if (block.numI()) dtype = (DataType)INT_ARG(0); + if (block.width() > 1) + REQUIRE_TRUE(dtype == INPUT_VARIABLE(1)->dataType(), 0, + "RandomUniform: data type of output and min/max args should " + "be the same"); - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape); - return SHAPELIST(newShape); - } + auto newShape = + ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape); + return SHAPELIST(newShape); +} - DECLARE_TYPES(randomuniform) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); - } - } +DECLARE_TYPES(randomuniform) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp b/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp index f0c27da53d31..0a72202f1d68 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp @@ -21,70 +21,73 @@ #include #if NOT_EXCLUDED(OP_argmax) -#include -#include #include +#include +#include namespace sd { - namespace ops { - DECLARE_TYPES(argmax) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS}); - } - - CUSTOM_OP_IMPL(argmax, 1, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +DECLARE_TYPES(argmax) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS}); +} - auto axis = block.getIArguments(); +CUSTOM_OP_IMPL(argmax, 1, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - // axis might be dynamic (i.e. tf mode) - if (block.width() > 1 && axis.size() == 0) { - auto axisVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axisVector, axis); + auto axis = block.getIArguments(); - input->applyIndexReduce(indexreduce::IndexMax, *output, axis); - } else { - helpers::adjustAxis(input->rankOf(), axis); + // axis might be dynamic (i.e. tf mode) + if (block.width() > 1 && axis.size() == 0) { + auto axisVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axisVector, axis); - input->applyIndexReduce(indexreduce::IndexMax, *output, axis); - } + input->applyIndexReduce(indexreduce::IndexMax, *output, axis); + } else { + helpers::adjustAxis(input->rankOf(), axis); - STORE_RESULT(output); + input->applyIndexReduce(indexreduce::IndexMax, *output, axis); + } - return Status::OK(); - } + STORE_RESULT(output); - DECLARE_SHAPE_FN(argmax) { - std::vector dims; + return Status::OK(); +} - if (block.width() == 1) { - dims = block.getIArguments(); - } else { - auto y = INPUT_VARIABLE(1); - dims = y->template asVectorT(); - } +DECLARE_SHAPE_FN(argmax) { + std::vector dims; - // we're resolving negative axis here - helpers::adjustAxis(shape::rank(inputShape->at(0)), dims); + if (block.width() == 1) { + dims = block.getIArguments(); + } else { + auto y = INPUT_VARIABLE(1); + dims = y->template asVectorT(); + } - if (dims.size() > 1) - std::sort(dims.begin(), dims.end()); + // we're resolving negative axis here + helpers::adjustAxis(shape::rank(inputShape->at(0)), dims); + if (dims.size() > 1) std::sort(dims.begin(), dims.end()); - for (auto d:dims) { - REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMax: you can't reduce along axis with 0 in shape"); - } + for (auto d : dims) { + REQUIRE_TRUE(inputShape->at(0)[d + 1] != 0, 0, + "ArgMax: you can't reduce along axis with 0 in shape"); + } - // special case - output is scalar - if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT64)); - } + // special case - output is scalar + if (dims.size() == 0 || + (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo( + sd::DataType::INT64)); + } - return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), DataType::INT64, false, false, block.workspace())); - } - } + return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), + DataType::INT64, false, + false, block.workspace())); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp b/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp index a9c7951a8774..0c44a59aaba5 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp @@ -25,67 +25,70 @@ #include namespace sd { - namespace ops { +namespace ops { - DECLARE_TYPES(argmin) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS}); - } - - CUSTOM_OP_IMPL(argmin, 1, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto axis = block.getIArguments(); - - auto output = OUTPUT_VARIABLE(0); - - // axis might be dynamic (i.e. tf mode) - if (block.width() > 1 && axis.size() == 0) { - auto axisVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axisVector, axis); - - input->applyIndexReduce(indexreduce::IndexMin, *output, axis); - } else { - helpers::adjustAxis(input->rankOf(), axis); - - input->applyIndexReduce(indexreduce::IndexMin, *output, axis); - } - - STORE_RESULT(output); +DECLARE_TYPES(argmin) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS}); +} - return ND4J_STATUS_OK; - } +CUSTOM_OP_IMPL(argmin, 1, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto axis = block.getIArguments(); - DECLARE_SHAPE_FN(argmin) { - std::vector dims; - auto in = inputShape->at(0); - if (block.width() == 1) { - dims = block.getIArguments(); - } else { - auto y = INPUT_VARIABLE(1); - dims = y->template asVectorT(); - } + auto output = OUTPUT_VARIABLE(0); - // we're resolving negative axis here - helpers::adjustAxis(shape::rank(in), dims); + // axis might be dynamic (i.e. tf mode) + if (block.width() > 1 && axis.size() == 0) { + auto axisVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axisVector, axis); - if (dims.size() > 1) - std::sort(dims.begin(), dims.end()); + input->applyIndexReduce(indexreduce::IndexMin, *output, axis); + } else { + helpers::adjustAxis(input->rankOf(), axis); - for (auto d:dims) { - REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMin: you can't reduce along axis with 0 in shape"); - } + input->applyIndexReduce(indexreduce::IndexMin, *output, axis); + } - // special case - output is scalar - if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64)); - } + STORE_RESULT(output); - auto newShape = ShapeUtils::evalReduceShapeInfo('c', dims, in, DataType::INT64, false, false, block.workspace()); - return SHAPELIST(newShape); - } + return ND4J_STATUS_OK; +} - } +DECLARE_SHAPE_FN(argmin) { + std::vector dims; + auto in = inputShape->at(0); + if (block.width() == 1) { + dims = block.getIArguments(); + } else { + auto y = INPUT_VARIABLE(1); + dims = y->template asVectorT(); + } + + // we're resolving negative axis here + helpers::adjustAxis(shape::rank(in), dims); + + if (dims.size() > 1) std::sort(dims.begin(), dims.end()); + + for (auto d : dims) { + REQUIRE_TRUE(inputShape->at(0)[d + 1] != 0, 0, + "ArgMin: you can't reduce along axis with 0 in shape"); + } + + // special case - output is scalar + if (dims.size() == 0 || + (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { + return SHAPELIST( + ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64)); + } + + auto newShape = ShapeUtils::evalReduceShapeInfo( + 'c', dims, in, DataType::INT64, false, false, block.workspace()); + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/reduce/norm.cpp b/libnd4j/include/ops/declarable/generic/reduce/norm.cpp index 6a0b7880b591..6ae9ea962588 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/norm.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/norm.cpp @@ -25,80 +25,87 @@ #include namespace sd { - namespace ops { - REDUCTION_OP_IMPL(norm, 1, 1, false, 1, -2) { - auto input = INPUT_VARIABLE(0); - NDArray *output = OUTPUT_VARIABLE(0); +namespace ops { +REDUCTION_OP_IMPL(norm, 1, 1, false, 1, -2) { + auto input = INPUT_VARIABLE(0); + NDArray *output = OUTPUT_VARIABLE(0); - auto mode = (int) T_ARG(0); - std::vector dims = block.getIArguments(); - bool overwrite = false; + auto mode = (int)T_ARG(0); + std::vector dims = block.getIArguments(); + bool overwrite = false; - if (block.width() == 1) { - output = OUTPUT_VARIABLE(0); - } else { - auto axisVector = INPUT_VARIABLE(1); - dims.resize(axisVector->lengthOf()); - helpers::adjustAxis(input->rankOf(), axisVector, dims); - auto shape = ShapeUtils::evalReduceShapeInfo(input->ordering(), dims, *input, false, false); - if (!shape::equalsStrict(shape, output->shapeInfo())) { - output = new NDArray(shape, false, block.launchContext()); - overwrite = true; - } - } - switch(mode) { - case 0: { - REQUIRE_TRUE(dims.size() == 2 || (input->rankOf() == 2 && dims.size() == 0), 0, "Norm: Frobenius is defined for 2D matrices or TADS only"); - // fro - input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, output->rankOf() == 2); - } - break; - case 1: { - // euclidean - if ((input->rankOf() == 2 && dims.size() == 0) || dims.size() == 2) { - input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, output->rankOf() == 2); - } else { - input->reduceAlongDimension(reduce::Norm2, *output, dims, false, output->rankOf() == 2); - } - } - break; - case 2: { - // 1 - input->reduceAlongDimension(reduce::Norm1, *output, dims, false, output->rankOf() == 2); - } - break; - case 3: { - // 2 - input->reduceAlongDimension(reduce::Norm2, *output, dims, false, output->rankOf() == 2); - } - break; - case 4: { - // inf-norm - input->reduceAlongDimension(reduce::NormMax, *output, dims, false, output->rankOf() == 2); - } - break; - default: { - // p-norm - REQUIRE_TRUE(block.numI() > 1, 0, "P-Norm reductions requires 2 TArguments, but only 1 was provided"); - // FIXME: p is required here - //T p = T_ARG(1); - input->reduceAlongDimension(reduce::NormP, *output, dims, false, output->rankOf() == 2); - } - } + if (block.width() == 1) { + output = OUTPUT_VARIABLE(0); + } else { + auto axisVector = INPUT_VARIABLE(1); + dims.resize(axisVector->lengthOf()); + helpers::adjustAxis(input->rankOf(), axisVector, dims); + auto shape = ShapeUtils::evalReduceShapeInfo(input->ordering(), dims, + *input, false, false); + if (!shape::equalsStrict(shape, output->shapeInfo())) { + output = new NDArray(shape, false, block.launchContext()); + overwrite = true; + } + } + switch (mode) { + case 0: { + REQUIRE_TRUE( + dims.size() == 2 || (input->rankOf() == 2 && dims.size() == 0), 0, + "Norm: Frobenius is defined for 2D matrices or TADS only"); + // fro + input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, + output->rankOf() == 2); + } break; + case 1: { + // euclidean + if ((input->rankOf() == 2 && dims.size() == 0) || dims.size() == 2) { + input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, + output->rankOf() == 2); + } else { + input->reduceAlongDimension(reduce::Norm2, *output, dims, false, + output->rankOf() == 2); + } + } break; + case 2: { + // 1 + input->reduceAlongDimension(reduce::Norm1, *output, dims, false, + output->rankOf() == 2); + } break; + case 3: { + // 2 + input->reduceAlongDimension(reduce::Norm2, *output, dims, false, + output->rankOf() == 2); + } break; + case 4: { + // inf-norm + input->reduceAlongDimension(reduce::NormMax, *output, dims, false, + output->rankOf() == 2); + } break; + default: { + // p-norm + REQUIRE_TRUE( + block.numI() > 1, 0, + "P-Norm reductions requires 2 TArguments, but only 1 was provided"); + // FIXME: p is required here + // T p = T_ARG(1); + input->reduceAlongDimension(reduce::NormP, *output, dims, false, + output->rankOf() == 2); + } + } - if (overwrite) { - OVERWRITE_RESULT(output); - } + if (overwrite) { + OVERWRITE_RESULT(output); + } - return ND4J_STATUS_OK; - }; + return ND4J_STATUS_OK; +}; - DECLARE_TYPES(norm) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(norm) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp index aefea142485d..5681a7e49eeb 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduceMean.cpp @@ -18,143 +18,165 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 01.06.2018 // - #include #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_mean, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - auto dimensions = block.getIArguments(); - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - bool keepDims = false; - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_MEAN OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_MEAN OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - input->reduceAlongDimension(reduce::Mean, *output, dimensions, keepDims); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + auto dimensions = block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_MEAN OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_MEAN OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + input->reduceAlongDimension(reduce::Mean, *output, dimensions, keepDims); + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_mean) { - - auto dimensions = block.getIArguments(); - auto in = inputShape->at(0); - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(shape::rank(in), axesVector, dimensions); - } - - bool keepDims = false; - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= in[0], 0, "REDUCE_MEAN OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_MEAN OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(in), dimensions, in, keepDims, false, block.workspace()); - - return SHAPELIST(outShapeInfo); + auto dimensions = block.getIArguments(); + auto in = inputShape->at(0); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(shape::rank(in), axesVector, dimensions); + } + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= in[0], 0, + "REDUCE_MEAN OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_MEAN OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(in), dimensions, in, keepDims, false, block.workspace()); + + return SHAPELIST(outShapeInfo); } DECLARE_TYPES(reduce_mean) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_mean_bp, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - - auto gradI = OUTPUT_VARIABLE(0); - - auto dimensions = block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - bool keepDims = false; - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_MEAN_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_MEAN_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - if(gradO->lengthOf() == 1) { - gradI->assign(gradO->e(0) / input->lengthOf()); - } - else { - - gradI->assign((gradO->lengthOf() + 0.) / input->lengthOf()); - - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); - *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - } - else - *gradI *= *gradO; - - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + + auto gradI = OUTPUT_VARIABLE(0); + + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_MEAN_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_MEAN_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + if (gradO->lengthOf() == 1) { + gradI->assign(gradO->e(0) / input->lengthOf()); + } else { + gradI->assign((gradO->lengthOf() + 0.) / input->lengthOf()); + + if (!keepDims) { + auto gradOShapeKeepDims = + ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, + true, false, block.workspace()); + *gradI *= gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like + // [a,b] -> [1,a,1,b] + } else + *gradI *= *gradO; + } + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_mean_bp) { - auto in = inputShape->at(0); - auto dimensions = block.getIArguments(); - auto rank = shape::rank(in); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(rank, axesVector, dimensions); - } - REQUIRE_TRUE(dimensions.size() <= rank, 0, "REDUCE_MEAN_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -rank || item < rank, 0, "REDUCE_MEAN_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , rank, rank, item); - - Nd4jLong* gradIshapeInfo(nullptr); - COPY_SHAPE(inputShape->at(0), gradIshapeInfo); - - return SHAPELIST(CONSTANT(gradIshapeInfo)); + auto in = inputShape->at(0); + auto dimensions = block.getIArguments(); + auto rank = shape::rank(in); + + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(rank, axesVector, dimensions); + } + REQUIRE_TRUE(dimensions.size() <= rank, 0, + "REDUCE_MEAN_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -rank || item < rank, 0, + "REDUCE_MEAN_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + rank, rank, item); + + Nd4jLong* gradIshapeInfo(nullptr); + COPY_SHAPE(inputShape->at(0), gradIshapeInfo); + + return SHAPELIST(CONSTANT(gradIshapeInfo)); } - DECLARE_TYPES(reduce_mean_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp index 04144a04327a..cdc0698e35e2 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduceStDev.cpp @@ -18,162 +18,193 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 04.06.2018 // - #include #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_stdev, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - bool biasCorrected = false;//block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; - - auto dimensions = block.getIArguments(); - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - if (block.numB()) { - keepDims = B_ARG(0); - if (block.numB() > 1) - biasCorrected = B_ARG(1); - } - else if (block.numT()) { - keepDims = (bool)T_ARG(0); - if (block.numT() > 1) - biasCorrected = (bool)T_ARG(1); - } - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_STDEV OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_STDEV OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, *output, biasCorrected, dimensions); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + bool keepDims = + false; // block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + bool biasCorrected = + false; // block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; + + auto dimensions = block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + if (block.numB()) { + keepDims = B_ARG(0); + if (block.numB() > 1) biasCorrected = B_ARG(1); + } else if (block.numT()) { + keepDims = (bool)T_ARG(0); + if (block.numT() > 1) biasCorrected = (bool)T_ARG(1); + } + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_STDEV OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_STDEV OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, + *output, biasCorrected, dimensions); + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_stdev) { - auto in = inputShape->at(0); - auto rank = shape::rank(in); - bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - auto dimensions = block.getIArguments(); - - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(rank, axesVector, dimensions); - } - - if (block.numB()) { - keepDims = B_ARG(0); - } - else if (block.numT()) { - keepDims = (bool)T_ARG(0); - } - - REQUIRE_TRUE(dimensions.size() <= rank, 0, "REDUCE_STDEV OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_STDEV OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(in), dimensions, in, keepDims, false, block.workspace()); - - return SHAPELIST(outShapeInfo); + auto in = inputShape->at(0); + auto rank = shape::rank(in); + bool keepDims = + false; // block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + auto dimensions = block.getIArguments(); + + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(rank, axesVector, dimensions); + } + + if (block.numB()) { + keepDims = B_ARG(0); + } else if (block.numT()) { + keepDims = (bool)T_ARG(0); + } + + REQUIRE_TRUE(dimensions.size() <= rank, 0, + "REDUCE_STDEV OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_STDEV OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(in), dimensions, in, keepDims, false, block.workspace()); + + return SHAPELIST(outShapeInfo); } DECLARE_TYPES(reduce_stdev) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_stdev_bp, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - - auto gradI = OUTPUT_VARIABLE(0); - - bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - bool biasCorrected = false;//block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; - - auto dimensions = block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - if (block.numB()) { - keepDims = B_ARG(0); - if (block.numB() > 1) - biasCorrected = B_ARG(1); - } - else if (block.numT()) { - keepDims = (bool)T_ARG(0); - if (block.numT() > 1) - biasCorrected = (bool)T_ARG(1); - } - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_STDEV_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_STDEV_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - const Nd4jLong N = input->lengthOf() / gradO->lengthOf(); - const Nd4jLong NminusOne = biasCorrected ? N - 1 : N; - - auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true); - - NDArray variance(mean.shapeInfo(), true, block.launchContext()); // create empty array with shape matching shape of mean array - input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, variance, biasCorrected, dimensions); - - gradI->assign( (*input - mean) / (variance * NminusOne)); // automatic broadcasting happens here - - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); - *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - } - else - *gradI *= *gradO; // automatic broadcasting happens here - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + + auto gradI = OUTPUT_VARIABLE(0); + + bool keepDims = + false; // block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + bool biasCorrected = + false; // block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; + + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + if (block.numB()) { + keepDims = B_ARG(0); + if (block.numB() > 1) biasCorrected = B_ARG(1); + } else if (block.numT()) { + keepDims = (bool)T_ARG(0); + if (block.numT() > 1) biasCorrected = (bool)T_ARG(1); + } + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_STDEV_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_STDEV_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + const Nd4jLong N = input->lengthOf() / gradO->lengthOf(); + const Nd4jLong NminusOne = biasCorrected ? N - 1 : N; + + auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true); + + NDArray variance(mean.shapeInfo(), true, + block.launchContext()); // create empty array with shape + // matching shape of mean array + input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, + variance, biasCorrected, dimensions); + + gradI->assign((*input - mean) / + (variance * NminusOne)); // automatic broadcasting happens here + + if (!keepDims) { + auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo( + gradO->ordering(), dimensions, *input, true, false, block.workspace()); + *gradI *= gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like [a,b] + // -> [1,a,1,b] + } else + *gradI *= *gradO; // automatic broadcasting happens here + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_stdev_bp) { - auto in = inputShape->at(0); - auto rank = shape::rank(in); - auto dimensions = block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(rank, axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= rank, 0, "REDUCE_STDEV_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_STDEV_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - Nd4jLong* gradIshapeInfo(nullptr); - COPY_SHAPE(in, gradIshapeInfo); - - return SHAPELIST(CONSTANT(gradIshapeInfo)); -// return SHAPELIST(in); + auto in = inputShape->at(0); + auto rank = shape::rank(in); + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(rank, axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= rank, 0, + "REDUCE_STDEV_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_STDEV_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + Nd4jLong* gradIshapeInfo(nullptr); + COPY_SHAPE(in, gradIshapeInfo); + + return SHAPELIST(CONSTANT(gradIshapeInfo)); + // return SHAPELIST(in); } DECLARE_TYPES(reduce_stdev_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp index 1c60f53ab76e..29916c8f9950 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduceVariance.cpp @@ -18,157 +18,186 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 04.06.2018 // - #include #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_variance, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - bool biasCorrected = false;//block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; - - auto dimensions = block.getIArguments(); - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - if (block.numB()) { - keepDims = B_ARG(0); - if (block.numB() > 1) - biasCorrected = B_ARG(1); - } - else if (block.numT()) { - keepDims = (bool)T_ARG(0); - if (block.numT() > 1) - biasCorrected = (bool)T_ARG(1); - } - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_VARIANCE OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_VARIANCE OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - input->varianceAlongDimension(variance::SummaryStatsVariance, *output, biasCorrected, dimensions); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + bool keepDims = + false; // block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + bool biasCorrected = + false; // block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; + + auto dimensions = block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + if (block.numB()) { + keepDims = B_ARG(0); + if (block.numB() > 1) biasCorrected = B_ARG(1); + } else if (block.numT()) { + keepDims = (bool)T_ARG(0); + if (block.numT() > 1) biasCorrected = (bool)T_ARG(1); + } + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_VARIANCE OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_VARIANCE OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + input->varianceAlongDimension(variance::SummaryStatsVariance, *output, + biasCorrected, dimensions); + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_variance) { - - bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - auto dimensions = block.getIArguments(); - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - if (block.numB()) { - keepDims = B_ARG(0); - } - else if (block.numT()) { - keepDims = (bool)T_ARG(0); - } - - REQUIRE_TRUE(dimensions.size() <= INPUT_VARIABLE(0)->rankOf(), 0, "REDUCE_VARIANCE OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_VARIANCE OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace()); - - return SHAPELIST(outShapeInfo); + bool keepDims = + false; // block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + auto dimensions = block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + if (block.numB()) { + keepDims = B_ARG(0); + } else if (block.numT()) { + keepDims = (bool)T_ARG(0); + } + + REQUIRE_TRUE(dimensions.size() <= INPUT_VARIABLE(0)->rankOf(), 0, + "REDUCE_VARIANCE OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_VARIANCE OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace()); + + return SHAPELIST(outShapeInfo); } DECLARE_TYPES(reduce_variance) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_variance_bp, 2, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - - auto gradI = OUTPUT_VARIABLE(0); - - bool keepDims = false;//block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; - bool biasCorrected = false;//block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; - - auto dimensions = block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } -// else if (block.getIArguments()->size()) - if (block.numB()) { - keepDims = B_ARG(0); - if (block.numB() > 1) - biasCorrected = B_ARG(1); - } - else if (block.numT()) { - keepDims = (bool)T_ARG(0); - if (block.numT() > 1) - biasCorrected = (bool)T_ARG(1); - } - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_VARIANCE_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_VARIANCE_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - const Nd4jLong N = input->lengthOf() / gradO->lengthOf(); - const Nd4jLong NminusOne = biasCorrected ? N - 1 : N; - const double factor1 = 2.0 / NminusOne; - const double factor2 = 2.0 / (N * NminusOne); - - auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true); - - gradI->assign( (*input - mean) * (2.0f / NminusOne)); // automatic broadcasting happens here - - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); - *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - } else - *gradI *= *gradO; // automatic broadcasting happens here - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + + auto gradI = OUTPUT_VARIABLE(0); + + bool keepDims = + false; // block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false; + bool biasCorrected = + false; // block.getTArguments()->size() > 1 ? (bool)T_ARG(1) : false; + + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + // else if (block.getIArguments()->size()) + if (block.numB()) { + keepDims = B_ARG(0); + if (block.numB() > 1) biasCorrected = B_ARG(1); + } else if (block.numT()) { + keepDims = (bool)T_ARG(0); + if (block.numT() > 1) biasCorrected = (bool)T_ARG(1); + } + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_VARIANCE_BP OP: the number of dimensions to reduce " + "along must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_VARIANCE_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + const Nd4jLong N = input->lengthOf() / gradO->lengthOf(); + const Nd4jLong NminusOne = biasCorrected ? N - 1 : N; + const double factor1 = 2.0 / NminusOne; + const double factor2 = 2.0 / (N * NminusOne); + + auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true); + + gradI->assign((*input - mean) * + (2.0f / NminusOne)); // automatic broadcasting happens here + + if (!keepDims) { + auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo( + gradO->ordering(), dimensions, *input, true, false, block.workspace()); + *gradI *= gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like [a,b] + // -> [1,a,1,b] + } else + *gradI *= *gradO; // automatic broadcasting happens here + + return Status::OK(); } - DECLARE_SHAPE_FN(reduce_variance_bp) { - auto in = inputShape->at(0); - auto rank = shape::rank(in); - auto dimensions = block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(rank, axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= rank, 0, "REDUCE_VARIANCE_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_VARIANCE_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - Nd4jLong* gradIshapeInfo(nullptr); - COPY_SHAPE(in, gradIshapeInfo); - - return SHAPELIST(CONSTANT(gradIshapeInfo)); + auto in = inputShape->at(0); + auto rank = shape::rank(in); + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(rank, axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= rank, 0, + "REDUCE_VARIANCE_BP OP: the number of dimensions to reduce " + "along must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_VARIANCE_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + Nd4jLong* gradIshapeInfo(nullptr); + COPY_SHAPE(in, gradIshapeInfo); + + return SHAPELIST(CONSTANT(gradIshapeInfo)); } - DECLARE_TYPES(reduce_variance_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp index f64bb86444cc..86b61a254771 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_dot.cpp @@ -19,8 +19,8 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { namespace ops { @@ -28,97 +28,112 @@ namespace ops { //////////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_dot_bp, 3, 2, false, 0, 0) { - - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto gradO = INPUT_VARIABLE(2); - - auto gradX = OUTPUT_VARIABLE(0); - auto gradY = OUTPUT_VARIABLE(1); - - // L(x,y) = SUM(x_i * y_i) - // dL/dx_i = y_i - - REQUIRE_TRUE(x->isSameShape(y), 0, "REDUCE_DOT_BP OP: both input arrays x and y should have same shapes, but got %s and %s correspondingly", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - - if (gradO->lengthOf() == 1) { // scalar of reduced to scalar with keep dimensions - gradX->assign((*y) * (*gradO)); - gradY->assign((*x) * (*gradO)); + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto gradO = INPUT_VARIABLE(2); + + auto gradX = OUTPUT_VARIABLE(0); + auto gradY = OUTPUT_VARIABLE(1); + + // L(x,y) = SUM(x_i * y_i) + // dL/dx_i = y_i + + REQUIRE_TRUE(x->isSameShape(y), 0, + "REDUCE_DOT_BP OP: both input arrays x and y should have same " + "shapes, but got %s and %s correspondingly", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + + if (gradO->lengthOf() == + 1) { // scalar of reduced to scalar with keep dimensions + gradX->assign((*y) * (*gradO)); + gradY->assign((*x) * (*gradO)); + } else { + bool keepDims = false; + auto dimensions = block.getIArguments(); + + if (block.width() > 3) { + auto axesVector = INPUT_VARIABLE(3); + helpers::adjustAxis(x->rankOf(), axesVector, dimensions); } - else { - - bool keepDims = false; - auto dimensions = block.getIArguments(); - - if (block.width() > 3) { - auto axesVector = INPUT_VARIABLE(3); - helpers::adjustAxis(x->rankOf(), axesVector, dimensions); - } - - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= x->rankOf(), 0, "REDUCE_DOT_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -x->rankOf() && item < x->rankOf(), 0, "REDUCE_DOT_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , x->rankOf(), x->rankOf(), item); - - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *x, true, false, block.workspace()); - auto r = gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - - gradX->assign((*y) * r); - gradY->assign((*x) * r); - } - else { - gradX->assign((*y) * (*gradO)); - gradY->assign((*x) * (*gradO)); - } + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= x->rankOf(), 0, + "REDUCE_DOT_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -x->rankOf() && item < x->rankOf(), 0, + "REDUCE_DOT_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + x->rankOf(), x->rankOf(), item); + + if (!keepDims) { + auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo( + gradO->ordering(), dimensions, *x, true, false, block.workspace()); + auto r = gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like + // [a,b] -> [1,a,1,b] + + gradX->assign((*y) * r); + gradY->assign((*x) * r); + } else { + gradX->assign((*y) * (*gradO)); + gradY->assign((*x) * (*gradO)); } - return Status::OK(); + } + return Status::OK(); } - DECLARE_SHAPE_FN(reduce_dot_bp) { + if (shape::length(inputShape->at(2)) > 1) { + bool keepDims = false; + auto dimensions = block.getIArguments(); - if(shape::length(inputShape->at(2)) > 1) { - - bool keepDims = false; - auto dimensions = block.getIArguments(); - - if (block.width() > 3) { - auto axesVector = INPUT_VARIABLE(3); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_DOT_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_DOT_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); + if (block.width() > 3) { + auto axesVector = INPUT_VARIABLE(3); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); } - Nd4jLong *outShapeInfo1, *outShapeInfo2; - COPY_SHAPE(inputShape->at(0), outShapeInfo1); - COPY_SHAPE(inputShape->at(1), outShapeInfo2); - - return SHAPELIST(CONSTANT(outShapeInfo1), CONSTANT(outShapeInfo2)); + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_DOT_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_DOT_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + } + + Nd4jLong *outShapeInfo1, *outShapeInfo2; + COPY_SHAPE(inputShape->at(0), outShapeInfo1); + COPY_SHAPE(inputShape->at(1), outShapeInfo2); + + return SHAPELIST(CONSTANT(outShapeInfo1), CONSTANT(outShapeInfo2)); } DECLARE_TYPES(reduce_dot_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp index 4833403825a5..a51cca63d9c6 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_logsumexp.cpp @@ -24,56 +24,60 @@ namespace sd { namespace ops { #if NOT_EXCLUDED(OP_reduce_logsumexp) - CUSTOM_OP_IMPL(reduce_logsumexp, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - std::vector axes;// = *block.getIArguments(); - if (block.width() > 1) { - auto axisVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axisVector, axes ); - } - else if (block.numI() > 0) { - axes = block.getIArguments(); - } +CUSTOM_OP_IMPL(reduce_logsumexp, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + std::vector axes; // = *block.getIArguments(); + if (block.width() > 1) { + auto axisVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axisVector, axes); + } else if (block.numI() > 0) { + axes = block.getIArguments(); + } - for(const auto& item : axes) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item shapeInfo()[0], 0, "REDUCE_LOGSUMEXP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); + for (const auto& item : axes) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_LOGSUMEXP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); - const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; - Nd4jLong maxI = input->argMax(); - auto maxVals = input->e(maxI); - //void* whereMax = (void*)(); - auto internal = (*input); - internal -= maxVals; - internal.applyTransform(transform::Exp, internal); - internal.reduceAlongDimension(reduce::Sum, *output, axes, keepDims, false); //, (void*)&maxVals); - output->applyTransform(transform::Log, *output); - (*output) += maxVals; - return ND4J_STATUS_OK; - } - DECLARE_TYPES(reduce_logsumexp) { - getOpDescriptor() - -> setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - -> setAllowedOutputTypes({ALL_FLOATS}); - } - DECLARE_SHAPE_FN(reduce_logsumexp) { - - const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; - auto input = INPUT_VARIABLE(0); + const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; + Nd4jLong maxI = input->argMax(); + auto maxVals = input->e(maxI); + // void* whereMax = (void*)(); + auto internal = (*input); + internal -= maxVals; + internal.applyTransform(transform::Exp, internal); + internal.reduceAlongDimension(reduce::Sum, *output, axes, keepDims, + false); //, (void*)&maxVals); + output->applyTransform(transform::Log, *output); + (*output) += maxVals; + return ND4J_STATUS_OK; +} +DECLARE_TYPES(reduce_logsumexp) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} +DECLARE_SHAPE_FN(reduce_logsumexp) { + const bool keepDims = block.numT() > 0 ? (bool)T_ARG(0) : false; + auto input = INPUT_VARIABLE(0); - std::vector axes; // = *block.getIArguments(); - if (block.width() > 1) { - auto axisVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axisVector, axes ); - } - else if (block.numI() > 0) { - axes = block.getIArguments(); - } + std::vector axes; // = *block.getIArguments(); + if (block.width() > 1) { + auto axisVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axisVector, axes); + } else if (block.numI() > 0) { + axes = block.getIArguments(); + } - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), axes, inputShape->at(0), keepDims, false, block.workspace()); + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), axes, inputShape->at(0), keepDims, false, + block.workspace()); - return SHAPELIST(outShapeInfo); - } -#endif -} + return SHAPELIST(outShapeInfo); } +#endif +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp index b9dfe68bee18..f7416ff91f6d 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_max.cpp @@ -20,8 +20,8 @@ // #include -#include #include +#include namespace sd { namespace ops { @@ -30,62 +30,74 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_max, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + std::vector dimensions = block.getIArguments(); - std::vector dimensions = block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_MAX OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_MAX OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_MAX OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_MAX OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); + bool keepDims = false; //: false; + if (block.numB() > 0) + keepDims = B_ARG(0); + else if (block.numT() > 0) + keepDims = (bool)T_ARG(0); - bool keepDims = false;//: false; - if (block.numB() > 0) - keepDims = B_ARG(0); - else if (block.numT() > 0) - keepDims = (bool)T_ARG(0); + input->reduceAlongDimension(reduce::Max, *output, dimensions, keepDims); - input->reduceAlongDimension(reduce::Max, *output, dimensions, keepDims); - - return Status::OK(); + return Status::OK(); } DECLARE_SHAPE_FN(reduce_max) { - - bool keepDims = false;//: false; - - if (block.numB() > 0) - keepDims = B_ARG(0); - else if (block.numT() > 0) - keepDims = (bool)T_ARG(0); - - auto dimensions = block.getIArguments(); - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_MAX OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_MAX OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace()); - - return SHAPELIST(outShapeInfo); + bool keepDims = false; //: false; + + if (block.numB() > 0) + keepDims = B_ARG(0); + else if (block.numT() > 0) + keepDims = (bool)T_ARG(0); + + auto dimensions = block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_MAX OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_MAX OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace()); + + return SHAPELIST(outShapeInfo); } DECLARE_TYPES(reduce_max) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } #endif @@ -93,69 +105,81 @@ DECLARE_TYPES(reduce_max) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_max_bp, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - std::vector dimensions = block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_MAX_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_MAX_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - // *** calculations *** // - - *gradI = 0; - - if(gradO->lengthOf() == 1) { - - auto indOfMaxElem = input->indexReduceNumber(sd::indexreduce::IndexMax); - gradI->p(indOfMaxElem.t(0), gradO->e(0)); - } - else { - - auto indicesArr = input->applyIndexReduce(sd::indexreduce::IndexMax, dimensions); - helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + + std::vector dimensions = block.getIArguments(); + + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_MAX_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_MAX_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + *gradI = 0; + + if (gradO->lengthOf() == 1) { + auto indOfMaxElem = input->indexReduceNumber(sd::indexreduce::IndexMax); + gradI->p(indOfMaxElem.t(0), gradO->e(0)); + } else { + auto indicesArr = + input->applyIndexReduce(sd::indexreduce::IndexMax, dimensions); + helpers::scatterSimple( + block.launchContext(), 6, *gradI, *gradO, indicesArr, + ShapeUtils::evalDimsToExclude( + gradI->rankOf(), dimensions)); // 6 corresponds to copy operation + } + + return Status::OK(); } - DECLARE_SHAPE_FN(reduce_max_bp) { + std::vector dimensions = block.getIArguments(); - std::vector dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_MAX_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_MAX_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_MAX_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_MAX_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !", inputShape->at(0)[0], inputShape->at(0)[0], item); + Nd4jLong* outShapeInfo; + COPY_SHAPE(inputShape->at(0), outShapeInfo); - Nd4jLong* outShapeInfo; - COPY_SHAPE(inputShape->at(0), outShapeInfo); - - return SHAPELIST(CONSTANT(outShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(reduce_max_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp index f83bdac50e0b..fb9b17506309 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_min.cpp @@ -20,8 +20,8 @@ // #include -#include #include +#include namespace sd { namespace ops { @@ -30,135 +30,157 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_min, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + std::vector dimensions = block.getIArguments(); - std::vector dimensions = block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_MIN OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_MIN OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_MIN OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_MIN OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); + bool keepDims = false; //: false; + if (block.numB() > 0) + keepDims = B_ARG(0); + else if (block.numT() > 0) + keepDims = (bool)T_ARG(0); - bool keepDims = false;//: false; - if (block.numB() > 0) - keepDims = B_ARG(0); - else if (block.numT() > 0) - keepDims = (bool)T_ARG(0); + input->reduceAlongDimension(reduce::Min, *output, dimensions, keepDims); - input->reduceAlongDimension(reduce::Min, *output, dimensions, keepDims); - - return Status::OK(); + return Status::OK(); } DECLARE_SHAPE_FN(reduce_min) { - - bool keepDims = false;//: false; - - if (block.numB() > 0) - keepDims = B_ARG(0); - else if (block.numT() > 0) - keepDims = (bool)T_ARG(0); - - auto dimensions = block.getIArguments(); - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_MIN OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_MIN OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace()); - - return SHAPELIST(outShapeInfo); + bool keepDims = false; //: false; + + if (block.numB() > 0) + keepDims = B_ARG(0); + else if (block.numT() > 0) + keepDims = (bool)T_ARG(0); + + auto dimensions = block.getIArguments(); + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_MIN OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_MIN OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace()); + + return SHAPELIST(outShapeInfo); } DECLARE_TYPES(reduce_min) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } - #endif - #if NOT_EXCLUDED(OP_reduce_min_bp) ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_min_bp, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - std::vector dimensions = block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_MIN_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_MIN_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - // *** calculations *** // - - *gradI = 0; - - if(gradO->lengthOf() == 1) { - - auto indOfMaxElem = input->indexReduceNumber(sd::indexreduce::IndexMin); - gradI->p(indOfMaxElem.e(0), gradO->e(0)); - } - else { - - auto indicesArr = input->applyIndexReduce(sd::indexreduce::IndexMin, dimensions); - helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + + std::vector dimensions = block.getIArguments(); + + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_MIN_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_MIN_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + *gradI = 0; + + if (gradO->lengthOf() == 1) { + auto indOfMaxElem = input->indexReduceNumber(sd::indexreduce::IndexMin); + gradI->p(indOfMaxElem.e(0), gradO->e(0)); + } else { + auto indicesArr = + input->applyIndexReduce(sd::indexreduce::IndexMin, dimensions); + helpers::scatterSimple( + block.launchContext(), 6, *gradI, *gradO, indicesArr, + ShapeUtils::evalDimsToExclude( + gradI->rankOf(), dimensions)); // 6 corresponds to copy operation + } + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_min_bp) { + std::vector dimensions = block.getIArguments(); - std::vector dimensions = block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_MIN_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_MIN_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_MIN_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !", inputShape->at(0)[0], inputShape->at(0)[0], item); + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_MIN_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); - Nd4jLong* outShapeInfo; - COPY_SHAPE(inputShape->at(0), outShapeInfo); + Nd4jLong* outShapeInfo; + COPY_SHAPE(inputShape->at(0), outShapeInfo); - return SHAPELIST(CONSTANT(outShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(reduce_min_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp index 273497548f00..fb6269605440 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm1.cpp @@ -19,8 +19,8 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { namespace ops { @@ -28,142 +28,169 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_norm1, 1, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - else if (block.numI()) - dimensions = block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM1 OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_NORM1 OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - bool keepDims = false; - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - input->reduceAlongDimension(reduce::Norm1, *output, dimensions, keepDims); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_NORM1 OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_NORM1 OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + input->reduceAlongDimension(reduce::Norm1, *output, dimensions, keepDims); + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_norm1) { - - bool keepDims = false; - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - else if (block.numI()) - dimensions = block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM1 OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_NORM1 OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace())); + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_NORM1 OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_NORM1 OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace())); } DECLARE_TYPES(reduce_norm1) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif #if NOT_EXCLUDED(OP_reduce_norm1_bp) ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_norm1_bp, 2, 1, false, 0, 0) { - // L = Sum abs(x_i) for all i = 1 to N - // dL/dx_i = 1 if x_i >= 0 and -1 when x_i < 0 - // out_i = epsilon_i if x_i > 0 and -epsilon_i when x_i < 0 - // when gradO is non a scalar, using dimensions to split output onto gradO like parts - // and use LAMBDA with that formula for it. + // L = Sum abs(x_i) for all i = 1 to N + // dL/dx_i = 1 if x_i >= 0 and -1 when x_i < 0 + // out_i = epsilon_i if x_i > 0 and -epsilon_i when x_i < 0 + // when gradO is non a scalar, using dimensions to split output onto gradO + // like parts and use LAMBDA with that formula for it. - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); - input->applyTransform(sd::transform::Sign, *gradI); + input->applyTransform(sd::transform::Sign, *gradI); - if (gradO->lengthOf() == 1) { - *gradI *= *gradO; - } - else { - - bool keepDims = false; - auto dimensions = block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM1_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_NORM1_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - // *** calculations *** // + if (gradO->lengthOf() == 1) { + *gradI *= *gradO; + } else { + bool keepDims = false; + auto dimensions = block.getIArguments(); - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); - *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - } else - *gradI *= *gradO; + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - return Status::OK(); + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_NORM1_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_NORM1_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + if (!keepDims) { + auto gradOShapeKeepDims = + ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, + true, false, block.workspace()); + *gradI *= gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like + // [a,b] -> [1,a,1,b] + } else + *gradI *= *gradO; + } + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_norm1_bp) { - - auto dimensions = block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM1_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_NORM1_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !", inputShape->at(0)[0], inputShape->at(0)[0], item); - - Nd4jLong* outShapeInfo; - COPY_SHAPE(inputShape->at(0), outShapeInfo); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_NORM1_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_NORM1_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + Nd4jLong* outShapeInfo; + COPY_SHAPE(inputShape->at(0), outShapeInfo); + + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(reduce_norm1_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp index 329b567e544f..72188923a569 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm2.cpp @@ -28,62 +28,74 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_norm2, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - else if (block.numI()) - dimensions = block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM2 OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_NORM2 OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - bool keepDims = false; - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - input->reduceAlongDimension(reduce::Norm2, *output, dimensions, keepDims); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_NORM2 OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_NORM2 OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + input->reduceAlongDimension(reduce::Norm2, *output, dimensions, keepDims); + + return Status::OK(); } - DECLARE_SHAPE_FN(reduce_norm2) { - - bool keepDims = false; - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - else if (block.numI()) - dimensions = block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM2 OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_NORM2 OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace())); + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_NORM2 OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_NORM2 OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace())); } DECLARE_TYPES(reduce_norm2) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif @@ -91,77 +103,91 @@ DECLARE_TYPES(reduce_norm2) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_norm2_bp, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); + gradI->assign(input); - gradI->assign(input); + if (gradO->lengthOf() == 1) { + *gradI /= input->reduceNumber(reduce::Norm2); + *gradI *= *gradO; + } else { + bool keepDims = false; + auto dimensions = block.getIArguments(); - if (gradO->lengthOf() == 1) { - *gradI /= input->reduceNumber(reduce::Norm2); - *gradI *= *gradO; + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - else { - - bool keepDims = false; - auto dimensions = block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM2_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_NORM2_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - // *** calculations *** // - - *gradI /= input->reduceAlongDimension(reduce::Norm2, dimensions, true); - - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); - *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - } else - *gradI *= *gradO; - } - return Status::OK(); + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_NORM2_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_NORM2_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + *gradI /= input->reduceAlongDimension(reduce::Norm2, dimensions, true); + + if (!keepDims) { + auto gradOShapeKeepDims = + ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, + true, false, block.workspace()); + *gradI *= gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like + // [a,b] -> [1,a,1,b] + } else + *gradI *= *gradO; + } + return Status::OK(); } DECLARE_SHAPE_FN(reduce_norm2_bp) { - - auto dimensions = block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM2_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_NORM2_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !", inputShape->at(0)[0], inputShape->at(0)[0], item); - - Nd4jLong* outShapeInfo; - COPY_SHAPE(inputShape->at(0), outShapeInfo); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_NORM2_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_NORM2_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + Nd4jLong* outShapeInfo; + COPY_SHAPE(inputShape->at(0), outShapeInfo); + + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(reduce_norm2_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp index 89cc2a5530ee..10110d0ee4ed 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_norm_max.cpp @@ -20,8 +20,8 @@ // #include -#include #include +#include namespace sd { namespace ops { @@ -29,63 +29,74 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_norm_max, 1, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - else if (block.numI()) - dimensions = block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM_MAX OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_NORM_MAX OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - bool keepDims = false; - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - input->reduceAlongDimension(reduce::NormMax, *output, dimensions, keepDims); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_NORM_MAX OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_NORM_MAX OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + input->reduceAlongDimension(reduce::NormMax, *output, dimensions, keepDims); + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_norm_max) { - - auto in = inputShape->at(0); - bool keepDims = false; - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - else if (block.numI()) - dimensions = block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM_MAX OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_NORM_MAX OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(in), dimensions, in, keepDims, false, block.workspace())); + auto in = inputShape->at(0); + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_NORM_MAX OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_NORM_MAX OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + shape::order(in), dimensions, in, keepDims, false, block.workspace())); } DECLARE_TYPES(reduce_norm_max) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif @@ -93,73 +104,84 @@ DECLARE_TYPES(reduce_norm_max) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_norm_max_bp, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - std::vector dimensions = block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM_MAX_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_NORM_MAX_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - // *** calculations *** // - - *gradI = 0; - - if(gradO->lengthOf() == 1) { - - auto indOfAbsMaxElem = input->indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); - const Nd4jLong ind = indOfAbsMaxElem.t(0); - const int sign = input->e(ind) >= 0 ? 1 : -1; - gradI->p(ind, sign * gradO->e(0)); - } - else { - - auto indicesArr = input->applyIndexReduce(sd::indexreduce::IndexAbsoluteMax, dimensions); - helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation - *gradI *= input->transform(sd::transform::Sign); - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + + std::vector dimensions = block.getIArguments(); + + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_NORM_MAX_BP OP: the number of dimensions to reduce " + "along must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_NORM_MAX_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + *gradI = 0; + + if (gradO->lengthOf() == 1) { + auto indOfAbsMaxElem = + input->indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); + const Nd4jLong ind = indOfAbsMaxElem.t(0); + const int sign = input->e(ind) >= 0 ? 1 : -1; + gradI->p(ind, sign * gradO->e(0)); + } else { + auto indicesArr = + input->applyIndexReduce(sd::indexreduce::IndexAbsoluteMax, dimensions); + helpers::scatterSimple( + block.launchContext(), 6, *gradI, *gradO, indicesArr, + ShapeUtils::evalDimsToExclude( + gradI->rankOf(), dimensions)); // 6 corresponds to copy operation + *gradI *= input->transform(sd::transform::Sign); + } + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_norm_max_bp) { - - auto dimensions = block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_NORM_MAX_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_NORM_MAX_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !", inputShape->at(0)[0], inputShape->at(0)[0], item); - - Nd4jLong* outShapeInfo; - COPY_SHAPE(inputShape->at(0), outShapeInfo); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_NORM_MAX_BP OP: the number of dimensions to reduce " + "along must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_NORM_MAX_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + Nd4jLong* outShapeInfo; + COPY_SHAPE(inputShape->at(0), outShapeInfo); + + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(reduce_norm_max_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - - - #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp index e111597ea49b..68072e639cbf 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_prod.cpp @@ -19,8 +19,8 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { namespace ops { @@ -28,62 +28,74 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_prod, 1, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - else if (block.numI()) - dimensions = block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_PROD OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_PROD OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - bool keepDims = false; - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - input->reduceAlongDimension(reduce::Prod, *output, dimensions, keepDims); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_PROD OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_PROD OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + input->reduceAlongDimension(reduce::Prod, *output, dimensions, keepDims); + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_prod) { - - bool keepDims = false; - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - else if (block.numI()) - dimensions = block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_PROD OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_PROD OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace())); + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_PROD OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_PROD OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace())); } DECLARE_TYPES(reduce_prod) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif @@ -91,78 +103,94 @@ DECLARE_TYPES(reduce_prod) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_prod_bp, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + + if (gradO->lengthOf() == 1) { + gradI->assign(input->reduceNumber(sd::reduce::Prod)); + *gradI /= *input; + *gradI *= gradO->e(0); + } else { + bool keepDims = false; + auto dimensions = block.getIArguments(); - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - if (gradO->lengthOf() == 1) { - gradI->assign(input->reduceNumber(sd::reduce::Prod)); - *gradI /= *input; - *gradI *= gradO->e(0); - } - else { - - bool keepDims = false; - auto dimensions = block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_NORM1_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_NORM1_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - // *** calculations *** // - - auto products = input->reduceAlongDimension(reduce::Prod, dimensions, true); - gradI->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), products, *gradI); - *gradI /= *input; - - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); - *gradI *= gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - } else - *gradI *= *gradO; + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - return Status::OK(); + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_NORM1_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_NORM1_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + auto products = input->reduceAlongDimension(reduce::Prod, dimensions, true); + gradI->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), products, + *gradI); + *gradI /= *input; + + if (!keepDims) { + auto gradOShapeKeepDims = + ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, + true, false, block.workspace()); + *gradI *= gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like + // [a,b] -> [1,a,1,b] + } else + *gradI *= *gradO; + } + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_prod_bp) { - - auto dimensions = block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_PROD_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_PROD_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !", inputShape->at(0)[0], inputShape->at(0)[0], item); - - Nd4jLong* outShapeInfo; - COPY_SHAPE(inputShape->at(0), outShapeInfo); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_PROD_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_PROD_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + Nd4jLong* outShapeInfo; + COPY_SHAPE(inputShape->at(0), outShapeInfo); + + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(reduce_prod_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp index d9d2dc126f05..e08f2d04a3f6 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_sqnorm.cpp @@ -18,8 +18,8 @@ // Created by george@skymind.io on 6/4/2018. // -#include #include +#include namespace sd { namespace ops { @@ -27,63 +27,77 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_sqnorm, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto gradI = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto gradI = OUTPUT_VARIABLE(0); + bool keepDims = false; - bool keepDims = false; - - auto dimensions = block.getIArguments(); + auto dimensions = block.getIArguments(); - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_SQNORM OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_SQNORM OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_SQNORM OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_SQNORM OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); - input->reduceAlongDimension(reduce::SquaredNorm, *gradI, dimensions, keepDims); + input->reduceAlongDimension(reduce::SquaredNorm, *gradI, dimensions, + keepDims); - return Status::OK(); + return Status::OK(); } DECLARE_SHAPE_FN(reduce_sqnorm) { - - auto dimensions = block.getIArguments(); - bool keepDims = false; - - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_SQNORM OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_SQNORM OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - auto outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace()); - - return SHAPELIST(outShapeInfo); + auto dimensions = block.getIArguments(); + bool keepDims = false; + + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_SQNORM OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_SQNORM OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + auto outShapeInfo = ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace()); + + return SHAPELIST(outShapeInfo); } DECLARE_TYPES(reduce_sqnorm) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif @@ -92,74 +106,90 @@ DECLARE_TYPES(reduce_sqnorm) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_sqnorm_bp, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); + if (gradO->lengthOf() == 1) { + gradI->assign(2 * (*input) * gradO->e(0)); + } else { + bool keepDims = false; + auto dimensions = block.getIArguments(); - if (gradO->lengthOf() == 1) { - gradI->assign( 2 * (*input) * gradO->e(0)); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - else { - - bool keepDims = false; - auto dimensions = block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_SQNORM_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_SQNORM_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - // *** calculations *** // - - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); - gradI->assign(2. * (*input) *gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims))); // for example could be something like [a,b] -> [1,a,1,b] - } else - gradI->assign(2. * (*input) * *gradO); - } - return Status::OK(); + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_SQNORM_BP OP: the number of dimensions to reduce " + "along must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_SQNORM_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + if (!keepDims) { + auto gradOShapeKeepDims = + ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, + true, false, block.workspace()); + gradI->assign( + 2. * (*input) * + gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims))); // for example could be something like + // [a,b] -> [1,a,1,b] + } else + gradI->assign(2. * (*input) * *gradO); + } + return Status::OK(); } DECLARE_SHAPE_FN(reduce_sqnorm_bp) { + if (shape::length(inputShape->at(1)) > 1) { + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } - if(shape::length(inputShape->at(1)) > 1) { - - auto dimensions = block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_SQNORM_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_SQNORM_BP OP: the number of dimensions to reduce " + "along must be <= input array rank, but got %i instead", + dimensions.size()); - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_SQNORM_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - } + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_SQNORM_BP OP: the input dimension to reduce along " + "must be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + } - Nd4jLong* gradIshapeInfo(nullptr); - COPY_SHAPE(inputShape->at(0), gradIshapeInfo); + Nd4jLong* gradIshapeInfo(nullptr); + COPY_SHAPE(inputShape->at(0), gradIshapeInfo); - return SHAPELIST(CONSTANT(gradIshapeInfo)); + return SHAPELIST(CONSTANT(gradIshapeInfo)); } DECLARE_TYPES(reduce_sqnorm_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp b/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp index e23ef5502b94..5b561d8afbb4 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/reduce_sum.cpp @@ -28,136 +28,161 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_sum, 1, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - else if (block.numI()) - dimensions = block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_SUM OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], 0, "REDUCE_SUM OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - bool keepDims = false; - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - input->reduceAlongDimension(reduce::Sum, *output, dimensions, keepDims); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_SUM OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->shapeInfo()[0] && item < input->shapeInfo()[0], + 0, + "REDUCE_SUM OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + input->reduceAlongDimension(reduce::Sum, *output, dimensions, keepDims); + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_sum) { - - bool keepDims = false; - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - std::vector dimensions; - if (block.width() > 1) { - auto axesVector = INPUT_VARIABLE(1); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - else if (block.numI()) - dimensions = block.getIArguments(); - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_SUM OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_SUM OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item); - - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, false, block.workspace())); + bool keepDims = false; + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + std::vector dimensions; + if (block.width() > 1) { + auto axesVector = INPUT_VARIABLE(1); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } else if (block.numI()) + dimensions = block.getIArguments(); + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_SUM OP: the number of dimensions to reduce along must " + "be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_SUM OP: the input dimension to reduce along must be " + "in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + shape::order(inputShape->at(0)), dimensions, inputShape->at(0), keepDims, + false, block.workspace())); } DECLARE_TYPES(reduce_sum) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } #endif #if NOT_EXCLUDED(OP_reduce_sum_bp) ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(reduce_sum_bp, 2, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - if (gradO->lengthOf() == 1) { - gradI->assign(gradO->e(0)); - } - else { - - bool keepDims = false; - auto dimensions = block.getIArguments(); - - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(input->rankOf(), axesVector, dimensions); - } - - if (block.numB()) - keepDims = B_ARG(0); - else if (block.numT()) - keepDims = (bool)T_ARG(0); - - REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, "REDUCE_SUM_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_SUM_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item); - - // *** calculations *** // + if (gradO->lengthOf() == 1) { + gradI->assign(gradO->e(0)); + } else { + bool keepDims = false; + auto dimensions = block.getIArguments(); - if(!keepDims) { - auto gradOShapeKeepDims = ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, true, false, block.workspace()); - auto r = gradO->reshape(gradO->ordering(), ShapeUtils::pullShapeFromShapeInfo(gradOShapeKeepDims)); // for example could be something like [a,b] -> [1,a,1,b] - gradI->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), r, *gradI); - } else - gradI->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), *gradO, *gradI); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(input->rankOf(), axesVector, dimensions); } - return Status::OK(); + if (block.numB()) + keepDims = B_ARG(0); + else if (block.numT()) + keepDims = (bool)T_ARG(0); + + REQUIRE_TRUE(dimensions.size() <= input->rankOf(), 0, + "REDUCE_SUM_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, + "REDUCE_SUM_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + input->rankOf(), input->rankOf(), item); + + // *** calculations *** // + + if (!keepDims) { + auto gradOShapeKeepDims = + ShapeUtils::evalReduceShapeInfo(gradO->ordering(), dimensions, *input, + true, false, block.workspace()); + auto r = gradO->reshape( + gradO->ordering(), + ShapeUtils::pullShapeFromShapeInfo( + gradOShapeKeepDims)); // for example could be something like + // [a,b] -> [1,a,1,b] + gradI->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), r, *gradI); + } else + gradI->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), *gradO, + *gradI); + } + + return Status::OK(); } DECLARE_SHAPE_FN(reduce_sum_bp) { - - auto dimensions = block.getIArguments(); - if (block.width() > 2) { - auto axesVector = INPUT_VARIABLE(2); - helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); - } - - REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, "REDUCE_SUM_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size()); - - for(const auto& item : dimensions) - REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_SUM_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !", inputShape->at(0)[0], inputShape->at(0)[0], item); - - Nd4jLong* outShapeInfo; - COPY_SHAPE(inputShape->at(0), outShapeInfo); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto dimensions = block.getIArguments(); + if (block.width() > 2) { + auto axesVector = INPUT_VARIABLE(2); + helpers::adjustAxis(INPUT_VARIABLE(0)->rankOf(), axesVector, dimensions); + } + + REQUIRE_TRUE(dimensions.size() <= inputShape->at(0)[0], 0, + "REDUCE_SUM_BP OP: the number of dimensions to reduce along " + "must be <= input array rank, but got %i instead", + dimensions.size()); + + for (const auto& item : dimensions) + REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], + 0, + "REDUCE_SUM_BP OP: the input dimension to reduce along must " + "be in range [-%i, %i), but got %i instead !", + inputShape->at(0)[0], inputShape->at(0)[0], item); + + Nd4jLong* outShapeInfo; + COPY_SHAPE(inputShape->at(0), outShapeInfo); + + return SHAPELIST(CONSTANT(outShapeInfo)); } DECLARE_TYPES(reduce_sum_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } - #endif -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/shape/broadcast_to.cpp b/libnd4j/include/ops/declarable/generic/shape/broadcast_to.cpp index 49961bfe2f5d..1efcbe19a169 100644 --- a/libnd4j/include/ops/declarable/generic/shape/broadcast_to.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/broadcast_to.cpp @@ -24,63 +24,85 @@ #include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(broadcast_to, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto shape = INPUT_VARIABLE(1); - - auto output = OUTPUT_VARIABLE(0); - - const int inputRank = input->rankOf(); - const int shapeRank = shape->rankOf(); - const Nd4jLong shapeLen = shape->lengthOf(); - - REQUIRE_TRUE(shapeRank <= 1, 0, "BROADCAST_TO op: rank of shape array should be <= 1, bot got %i instead !", shapeRank); - REQUIRE_TRUE(inputRank <= shapeLen, 0, "BROADCAST_TO op: rank of input shape array should be <= length of shape array, bot got %i and %i correspondingly !", inputRank, shapeLen); - - std::vector shapeBuff = shape->getBufferAsVector(); - std::vector outShape(shapeBuff.begin(), shapeBuff.end()); - - for(int i = 1; i <= inputRank; ++i) - REQUIRE_TRUE(input->sizeAt(inputRank-i) == outShape[shapeLen-i] || input->sizeAt(inputRank-i) == 1, 0, "BROADCAST_TO op: shape of input array %s can't be broadcasted to the shape %s !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(outShape).c_str()); - - input->tile(*output); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto shape = INPUT_VARIABLE(1); + + auto output = OUTPUT_VARIABLE(0); + + const int inputRank = input->rankOf(); + const int shapeRank = shape->rankOf(); + const Nd4jLong shapeLen = shape->lengthOf(); + + REQUIRE_TRUE(shapeRank <= 1, 0, + "BROADCAST_TO op: rank of shape array should be <= 1, bot got " + "%i instead !", + shapeRank); + REQUIRE_TRUE(inputRank <= shapeLen, 0, + "BROADCAST_TO op: rank of input shape array should be <= length " + "of shape array, bot got %i and %i correspondingly !", + inputRank, shapeLen); + + std::vector shapeBuff = shape->getBufferAsVector(); + std::vector outShape(shapeBuff.begin(), shapeBuff.end()); + + for (int i = 1; i <= inputRank; ++i) + REQUIRE_TRUE(input->sizeAt(inputRank - i) == outShape[shapeLen - i] || + input->sizeAt(inputRank - i) == 1, + 0, + "BROADCAST_TO op: shape of input array %s can't be " + "broadcasted to the shape %s !", + ShapeUtils::shapeAsString(input).c_str(), + ShapeUtils::shapeAsString(outShape).c_str()); + + input->tile(*output); + + return Status::OK(); } DECLARE_TYPES(broadcast_to) { - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(DataType::ANY)->setSameMode(true); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(broadcast_to) { - - auto inputShapeInfo = inputShape->at(0); - auto shape = INPUT_VARIABLE(1); - - const int inputRank = inputShapeInfo[0]; - const int shapeRank = shape->rankOf(); - const Nd4jLong shapeLen = shape->lengthOf(); - - REQUIRE_TRUE(shapeRank <= 1, 0, "BROADCAST_TO op: rank of input shape array should be <= 1, bit got %i instead !", shapeRank); - REQUIRE_TRUE(inputRank <= shapeLen, 0, "BROADCAST_TO op: rank of input shape array should be <= length of shape array, bot got %i and %i correspondingly !", inputRank, shapeLen); - - std::vector shapeBuff = shape->getBufferAsVector(); - std::vector outShape(shapeBuff.begin(), shapeBuff.end()); - - for(int i = 1; i <= inputRank; ++i) - REQUIRE_TRUE(inputShapeInfo[inputRank+1-i] == outShape[shapeLen-i] || inputShapeInfo[inputRank+1-i] == 1, 0, "BROADCAST_TO op: shape of input array %s can't be broadcasted to the shape %s !", ShapeUtils::shapeAsString(inputShapeInfo).c_str(), ShapeUtils::shapeAsString(outShape).c_str()); - - auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShapeInfo), shape::order(inputShapeInfo), outShape); - return SHAPELIST(outShapeInfo); + auto inputShapeInfo = inputShape->at(0); + auto shape = INPUT_VARIABLE(1); + + const int inputRank = inputShapeInfo[0]; + const int shapeRank = shape->rankOf(); + const Nd4jLong shapeLen = shape->lengthOf(); + + REQUIRE_TRUE(shapeRank <= 1, 0, + "BROADCAST_TO op: rank of input shape array should be <= 1, bit " + "got %i instead !", + shapeRank); + REQUIRE_TRUE(inputRank <= shapeLen, 0, + "BROADCAST_TO op: rank of input shape array should be <= length " + "of shape array, bot got %i and %i correspondingly !", + inputRank, shapeLen); + + std::vector shapeBuff = shape->getBufferAsVector(); + std::vector outShape(shapeBuff.begin(), shapeBuff.end()); + + for (int i = 1; i <= inputRank; ++i) + REQUIRE_TRUE(inputShapeInfo[inputRank + 1 - i] == outShape[shapeLen - i] || + inputShapeInfo[inputRank + 1 - i] == 1, + 0, + "BROADCAST_TO op: shape of input array %s can't be " + "broadcasted to the shape %s !", + ShapeUtils::shapeAsString(inputShapeInfo).c_str(), + ShapeUtils::shapeAsString(outShape).c_str()); + + auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inputShapeInfo), shape::order(inputShapeInfo), + outShape); + return SHAPELIST(outShapeInfo); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/evaluate_reduction_shape.cpp b/libnd4j/include/ops/declarable/generic/shape/evaluate_reduction_shape.cpp index 6a0ad187cb65..201903f10cb8 100644 --- a/libnd4j/include/ops/declarable/generic/shape/evaluate_reduction_shape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/evaluate_reduction_shape.cpp @@ -24,57 +24,63 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(evaluate_reduction_shape, 2, 1, false, 0, 0) { - auto inputShape = INPUT_VARIABLE(0); - auto axis = INPUT_VARIABLE(1)->asVectorT(); - auto keepDims = block.numB() > 0 ? B_ARG(0) : false; - auto oldFormat = block.numB() > 1 ? B_ARG(1) : false; - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(evaluate_reduction_shape, 2, 1, false, 0, 0) { + auto inputShape = INPUT_VARIABLE(0); + auto axis = INPUT_VARIABLE(1)->asVectorT(); + auto keepDims = block.numB() > 0 ? B_ARG(0) : false; + auto oldFormat = block.numB() > 1 ? B_ARG(1) : false; + auto output = OUTPUT_VARIABLE(0); - auto shape = inputShape->asVectorT(); + auto shape = inputShape->asVectorT(); - auto tempShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(sd::DataType::INT64, 'c', shape); - auto tempReductionShapeInfo = ShapeUtils::evalReduceShapeInfo('c', axis, tempShapeInfo, keepDims, oldFormat, block.workspace()); + auto tempShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo( + sd::DataType::INT64, 'c', shape); + auto tempReductionShapeInfo = ShapeUtils::evalReduceShapeInfo( + 'c', axis, tempShapeInfo, keepDims, oldFormat, block.workspace()); - REQUIRE_TRUE(output->lengthOf() == shape::rank(tempReductionShapeInfo), 0, "evaluate_reduction_shape: output length should be %i, but got %i instead", shape::rank(tempReductionShapeInfo), output->lengthOf()); + REQUIRE_TRUE(output->lengthOf() == shape::rank(tempReductionShapeInfo), 0, + "evaluate_reduction_shape: output length should be %i, but got " + "%i instead", + shape::rank(tempReductionShapeInfo), output->lengthOf()); - for (int e = 0; e < shape::rank(tempReductionShapeInfo); e++) - output->p(e, tempReductionShapeInfo[e+1]); + for (int e = 0; e < shape::rank(tempReductionShapeInfo); e++) + output->p(e, tempReductionShapeInfo[e + 1]); - return Status::OK(); - } - - DECLARE_TYPES(evaluate_reduction_shape) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes(0, sd::DataType::INT64); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(evaluate_reduction_shape) { - auto input = INPUT_VARIABLE(0); - auto axis = INPUT_VARIABLE(1)->asVectorT(); +DECLARE_TYPES(evaluate_reduction_shape) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes(0, sd::DataType::INT64); +} - auto keepDims = block.numB() > 0 ? B_ARG(0) : false; - auto oldFormat = block.numB() > 1 ? B_ARG(1) : false; +DECLARE_SHAPE_FN(evaluate_reduction_shape) { + auto input = INPUT_VARIABLE(0); + auto axis = INPUT_VARIABLE(1)->asVectorT(); - Nd4jLong length = input->lengthOf(); + auto keepDims = block.numB() > 0 ? B_ARG(0) : false; + auto oldFormat = block.numB() > 1 ? B_ARG(1) : false; - if (keepDims) { - if (oldFormat) { - // for oldFormat we can't go below rank 2 - length = sd::math::nd4j_max(2, length); - } - } else { - length -= axis.size(); - if (oldFormat) { - length = sd::math::nd4j_max(2, length); - } - } + Nd4jLong length = input->lengthOf(); - return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(length, sd::DataType::INT64)); - } + if (keepDims) { + if (oldFormat) { + // for oldFormat we can't go below rank 2 + length = sd::math::nd4j_max(2, length); } + } else { + length -= axis.size(); + if (oldFormat) { + length = sd::math::nd4j_max(2, length); + } + } + + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo( + length, sd::DataType::INT64)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/expand_dims.cpp b/libnd4j/include/ops/declarable/generic/shape/expand_dims.cpp index 86900c264124..5884b8179ebb 100644 --- a/libnd4j/include/ops/declarable/generic/shape/expand_dims.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/expand_dims.cpp @@ -24,80 +24,84 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(expand_dims, 1, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (input->isScalar()) { - output->assign(input); - return Status::OK(); - } - - Nd4jLong axis = block.numI() > 0 ? INT_ARG(0) : INPUT_VARIABLE(1)->e(0); - - if (axis < 0) - axis += input->rankOf() + 1; - - REQUIRE_TRUE(axis >= 0 && axis <= input->rankOf()+1, 0, "ExpandDims: axis should be in range of 0...%i in this case, but got %i instead", input->rankOf() + 1, axis); - - std::vector shape(input->rankOf()); - - for(int e = 0; e < input->rankOf(); e++) - shape[input->sizeAt(e)]; - - shape.insert(shape.begin() + axis, 1); - - if (input->ews() == 1 && output->ews() == 1 && input->ordering() == output->ordering()) { - output->dataBuffer()->copyBufferFrom(*input->dataBuffer().get(), output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), 0, input->bufferOffset()); - } else { - auto tmp = input->reshape(input->ordering(), shape); - output->assign(tmp); - } - return Status::OK(); - } - - DECLARE_TYPES(expand_dims) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } +namespace ops { +CUSTOM_OP_IMPL(expand_dims, 1, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isScalar()) { + output->assign(input); + return Status::OK(); + } + + Nd4jLong axis = block.numI() > 0 ? INT_ARG(0) : INPUT_VARIABLE(1)->e(0); + + if (axis < 0) axis += input->rankOf() + 1; + + REQUIRE_TRUE(axis >= 0 && axis <= input->rankOf() + 1, 0, + "ExpandDims: axis should be in range of 0...%i in this case, " + "but got %i instead", + input->rankOf() + 1, axis); + + std::vector shape(input->rankOf()); + + for (int e = 0; e < input->rankOf(); e++) shape[input->sizeAt(e)]; + + shape.insert(shape.begin() + axis, 1); + + if (input->ews() == 1 && output->ews() == 1 && + input->ordering() == output->ordering()) { + output->dataBuffer()->copyBufferFrom( + *input->dataBuffer().get(), + output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), + 0, input->bufferOffset()); + } else { + auto tmp = input->reshape(input->ordering(), shape); + output->assign(tmp); + } + return Status::OK(); +} - DECLARE_SHAPE_FN(expand_dims) { - auto inShape = inputShape->at(0); +DECLARE_TYPES(expand_dims) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} - // 0D scalar edge case - if (shape::rank(inShape) == 0) { +DECLARE_SHAPE_FN(expand_dims) { + auto inShape = inputShape->at(0); - Nd4jLong x = 1; - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', 1, &x); - return SHAPELIST(newShape); - } + // 0D scalar edge case + if (shape::rank(inShape) == 0) { + Nd4jLong x = 1; + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShape), 'c', 1, &x); + return SHAPELIST(newShape); + } - // FIXME: temp workaround for TF - if (shape::isScalar(inShape)) { - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', 2, shape::shapeOf(inShape)); - return SHAPELIST(newShape); - } + // FIXME: temp workaround for TF + if (shape::isScalar(inShape)) { + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShape), 'c', 2, shape::shapeOf(inShape)); + return SHAPELIST(newShape); + } - auto x_rank = shape::rank(inShape); - char order = shape::order(inShape); + auto x_rank = shape::rank(inShape); + char order = shape::order(inShape); - Nd4jLong axis = block.numI() > 0 ? INT_ARG(0) : INPUT_VARIABLE(1)->e(0); + Nd4jLong axis = block.numI() > 0 ? INT_ARG(0) : INPUT_VARIABLE(1)->e(0); - if (axis < 0) - axis += x_rank + 1; + if (axis < 0) axis += x_rank + 1; - std::vector shape; - for(int e = 0; e < x_rank; e++) - shape.emplace_back(shape::shapeOf(inShape)[e]); + std::vector shape; + for (int e = 0; e < x_rank; e++) + shape.emplace_back(shape::shapeOf(inShape)[e]); - shape.insert(shape.begin() + axis, 1); + shape.insert(shape.begin() + axis, 1); - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), order, shape); - return SHAPELIST(newShape); - } - } + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShape), order, shape); + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/flatten.cpp b/libnd4j/include/ops/declarable/generic/shape/flatten.cpp index 19cc4f4690bc..b5691373ba3c 100644 --- a/libnd4j/include/ops/declarable/generic/shape/flatten.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/flatten.cpp @@ -22,47 +22,54 @@ #if NOT_EXCLUDED(OP_) #include -#include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(flatten, -1, 1, false, 0, 1) { - auto output = OUTPUT_VARIABLE(0); - auto zType = output->dataType(); - auto xType = INPUT_VARIABLE(0)->dataType(); +namespace ops { +CUSTOM_OP_IMPL(flatten, -1, 1, false, 0, 1) { + auto output = OUTPUT_VARIABLE(0); + auto zType = output->dataType(); + auto xType = INPUT_VARIABLE(0)->dataType(); - REQUIRE_TRUE(xType == zType, 0, "Flatten: output array must have same data type as input arrays"); - std::vector arrays(block.width()); - for (int e = 0; e < block.width(); e++) { - auto input = INPUT_VARIABLE(e); + REQUIRE_TRUE( + xType == zType, 0, + "Flatten: output array must have same data type as input arrays"); + std::vector arrays(block.width()); + for (int e = 0; e < block.width(); e++) { + auto input = INPUT_VARIABLE(e); - REQUIRE_TRUE(xType == input->dataType(), 0, "Flatten: all input arrays must have the same data type"); + REQUIRE_TRUE(xType == input->dataType(), 0, + "Flatten: all input arrays must have the same data type"); - arrays[e] = input; - } + arrays[e] = input; + } - char order = (char) INT_ARG(0); - helpers::flatten(block.launchContext(), arrays, output, order); + char order = (char)INT_ARG(0); + helpers::flatten(block.launchContext(), arrays, output, order); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(flatten) { - getOpDescriptor()->setAllowedInputTypes({ALL_INTS, ALL_FLOATS, sd::DataType::BOOL}); - getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS, sd::DataType::BOOL}); - } +DECLARE_TYPES(flatten) { + getOpDescriptor()->setAllowedInputTypes( + {ALL_INTS, ALL_FLOATS, sd::DataType::BOOL}); + getOpDescriptor()->setAllowedOutputTypes( + 0, {ALL_FLOATS, ALL_INTS, sd::DataType::BOOL}); +} - DECLARE_SHAPE_FN(flatten) { - Nd4jLong length = 0; - sd::DataType dtype = ArrayOptions::dataType(inputShape->at(0)); - for (int e = 0; e < inputShape->size(); e++) { - length += shape::length(inputShape->at(e)); - REQUIRE_TRUE(dtype == ArrayOptions::dataType(inputShape->at(e)), 0, "Flatten: all input arrays must have the same datatype"); - } +DECLARE_SHAPE_FN(flatten) { + Nd4jLong length = 0; + sd::DataType dtype = ArrayOptions::dataType(inputShape->at(0)); + for (int e = 0; e < inputShape->size(); e++) { + length += shape::length(inputShape->at(e)); + REQUIRE_TRUE(dtype == ArrayOptions::dataType(inputShape->at(e)), 0, + "Flatten: all input arrays must have the same datatype"); + } - return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(length, dtype)); - } - } + return SHAPELIST( + ConstantShapeHelper::getInstance()->vectorShapeInfo(length, dtype)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/order.cpp b/libnd4j/include/ops/declarable/generic/shape/order.cpp index 5b978f48f488..5f2ccdf9f30b 100644 --- a/libnd4j/include/ops/declarable/generic/shape/order.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/order.cpp @@ -24,31 +24,33 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(order, 1, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(order, 1, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - output->assign(input); + output->assign(input); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(order) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS}); - } +DECLARE_TYPES(order) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS}); +} - DECLARE_SHAPE_FN(order) { - auto input = inputShape->at(0); +DECLARE_SHAPE_FN(order) { + auto input = inputShape->at(0); - auto isFOrder = INT_ARG(0) == 1; + auto isFOrder = INT_ARG(0) == 1; - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(input), isFOrder ? 'f' : 'c', shape::rank(input), shape::shapeOf(input)); - return SHAPELIST(newShape); - } - } + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(input), isFOrder ? 'f' : 'c', shape::rank(input), + shape::shapeOf(input)); + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/permute.cpp b/libnd4j/include/ops/declarable/generic/shape/permute.cpp index c3065455364e..76b6de592289 100644 --- a/libnd4j/include/ops/declarable/generic/shape/permute.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/permute.cpp @@ -22,60 +22,65 @@ #include #if NOT_EXCLUDED(OP_permute) -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// // here iArgs is int vector of ordered set of dimensions to be permuted CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + if (x->isEmpty()) { + REQUIRE_TRUE(z->isEmpty(), 0, + "PERMUTE OP: when input is empty, output must also be empty"); + return Status::OK(); // No op + } - if (x->isEmpty()) { - REQUIRE_TRUE(z->isEmpty(), 0, "PERMUTE OP: when input is empty, output must also be empty"); - return Status::OK(); //No op - } - - if (block.width() == 1 && block.numI() == 0) { - z->assign(x->transpose()); - return Status::OK(); - } + if (block.width() == 1 && block.numI() == 0) { + z->assign(x->transpose()); + return Status::OK(); + } - std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getIArguments(); + std::vector permutationVector = block.width() > 1 + ? INPUT_VARIABLE(1)->asVectorT() + : block.getIArguments(); - z->assign(x->permute(permutationVector)); + z->assign(x->permute(permutationVector)); - return Status::OK(); + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(permute) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(true); } ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(permute) { + auto x = INPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); - - if (block.width() == 1 && block.numI() == 0) - return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true)); + if (block.width() == 1 && block.numI() == 0) + return SHAPELIST( + ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true)); - std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getIArguments(); + std::vector permutationVector = block.width() > 1 + ? INPUT_VARIABLE(1)->asVectorT() + : block.getIArguments(); - auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true); + auto outputShapeInfo = ShapeUtils::evalPermShapeInfo( + permutationVector.data(), x->rankOf(), *x, block.workspace(), true); - return SHAPELIST(outputShapeInfo); + return SHAPELIST(outputShapeInfo); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/rank.cpp b/libnd4j/include/ops/declarable/generic/shape/rank.cpp index 8a617dc5972e..4183de35857b 100644 --- a/libnd4j/include/ops/declarable/generic/shape/rank.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/rank.cpp @@ -24,30 +24,30 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(rank, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); - - output->p(0, input->rankOf()); - output->syncToDevice(); - - return Status::OK(); - } - DECLARE_SHAPE_FN(rank) { - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT32)); - } - - - DECLARE_TYPES(rank) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}) - ->allowOverride(true); - } - } +namespace ops { +CUSTOM_OP_IMPL(rank, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); + + output->p(0, input->rankOf()); + output->syncToDevice(); + + return Status::OK(); +} +DECLARE_SHAPE_FN(rank) { + return SHAPELIST( + ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT32)); +} + +DECLARE_TYPES(rank) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}) + ->allowOverride(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp index 364b451eee9c..8ddfb03e4b16 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp @@ -24,147 +24,150 @@ #include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// // here iArgs is a vector with (optional) negative of order as first element: // ({-order, dim1, dim2, dim3, ...}) CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + // Special case: empty.reshape() -> return empty + if (x->isEmpty()) { + REQUIRE_TRUE(z->isEmpty(), 0, + "Reshape: when input is empty, output must also be empty"); + return Status::OK(); // No op + } - //Special case: empty.reshape() -> return empty - if (x->isEmpty()) { - REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); - return Status::OK(); //No op - } + REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, + "Reshape: lengths before and after reshape should match, but " + "got %i vs %i", + x->lengthOf(), z->lengthOf()); - REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf()); + if (Environment::getInstance()->isDebugAndVerbose()) + nd4j_printv("Reshape: new shape", z->getShapeAsVector()); - if (Environment::getInstance()->isDebugAndVerbose()) - nd4j_printv("Reshape: new shape", z->getShapeAsVector()); + z->assign(x->reshape(z->ordering(), z->getShapeAsVector())); - z->assign(x->reshape(z->ordering(), z->getShapeAsVector())); - - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(reshape) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(true); } DECLARE_SHAPE_FN(reshape) { - - const auto x = INPUT_VARIABLE(0); - - std::vector reshapeArgs; - std::vector shapeNew; - char orderNew = 'c'; - - if (block.width() == 1) { - reshapeArgs = block.getIArguments(); - if(!reshapeArgs.empty()) { - orderNew = (char) -reshapeArgs[0]; - if(orderNew == 'c' || orderNew == 'f') - reshapeArgs.erase(reshapeArgs.begin()); // remove first element being order in this case - } + const auto x = INPUT_VARIABLE(0); + + std::vector reshapeArgs; + std::vector shapeNew; + char orderNew = 'c'; + + if (block.width() == 1) { + reshapeArgs = block.getIArguments(); + if (!reshapeArgs.empty()) { + orderNew = (char)-reshapeArgs[0]; + if (orderNew == 'c' || orderNew == 'f') + reshapeArgs.erase( + reshapeArgs + .begin()); // remove first element being order in this case } - else { - reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector(); - orderNew = block.numI() > 0 ? (char) -INT_ARG(0) : 'c'; + } else { + reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector(); + orderNew = block.numI() > 0 ? (char)-INT_ARG(0) : 'c'; + } + + REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, + "Reshape buffer should have at least 1 dimension !"); + + // Nd4jLong xLen = x->lengthOf(); + // if(x->isEmpty()) { + // xLen = 1; + // for (uint i = 0; i < x->rankOf(); ++i) // + // take into account possible empty shapes + // if(x->sizeAt(i) != 0) + // xLen *= x->sizeAt(i); + // } + + // for (uint i = 0; i < reshapeArgs.size(); ++i) { + + // if (reshapeArgs[i] == -1) { + + // uint shapeLength = 1, numOfZeros = 0; + + // for(uint j = 0; j < i; ++j) + // if(reshapeArgs[j] != 0) + // shapeLength *= reshapeArgs[j]; + // else + // ++numOfZeros; + + // for(uint j = i + 1; j < reshapeArgs.size(); ++j) { + // REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one + // unknown dimension (-1) is allowed."); if(reshapeArgs[j] != 0) + // shapeLength *= reshapeArgs[j]; + // else + // ++numOfZeros; + // } + + // const auto dim = xLen / shapeLength; + + // if(x->isEmpty() && (1 == dim || 0 == numOfZeros)) + // shapeNew.push_back(0); + // else + // shapeNew.push_back(dim); + // } + // else + // shapeNew.push_back(reshapeArgs[i]); + // } + + Nd4jLong newShapeLen = 1; + int pos = -1; + bool newShapeEmpty = false; + + for (int i = 0; i < reshapeArgs.size(); ++i) { + const int dim = reshapeArgs[i]; + + if (dim == -1) { + REQUIRE_TRUE(pos == -1, 0, + "Reshape : Only one unknown dimension (-1) is allowed."); + pos = i; + shapeNew.push_back(1); + } else if (dim == 0) { + shapeNew.push_back(0); + newShapeEmpty = true; + } else { + shapeNew.push_back(dim); + newShapeLen *= dim; } + } - REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !"); - - // Nd4jLong xLen = x->lengthOf(); - // if(x->isEmpty()) { - // xLen = 1; - // for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes - // if(x->sizeAt(i) != 0) - // xLen *= x->sizeAt(i); - // } - - // for (uint i = 0; i < reshapeArgs.size(); ++i) { - - // if (reshapeArgs[i] == -1) { - - // uint shapeLength = 1, numOfZeros = 0; - - // for(uint j = 0; j < i; ++j) - // if(reshapeArgs[j] != 0) - // shapeLength *= reshapeArgs[j]; - // else - // ++numOfZeros; - - // for(uint j = i + 1; j < reshapeArgs.size(); ++j) { - // REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - // if(reshapeArgs[j] != 0) - // shapeLength *= reshapeArgs[j]; - // else - // ++numOfZeros; - // } - - // const auto dim = xLen / shapeLength; - - // if(x->isEmpty() && (1 == dim || 0 == numOfZeros)) - // shapeNew.push_back(0); - // else - // shapeNew.push_back(dim); - // } - // else - // shapeNew.push_back(reshapeArgs[i]); - // } - - Nd4jLong newShapeLen = 1; - int pos = -1; - bool newShapeEmpty = false; - - for (int i = 0; i < reshapeArgs.size(); ++i) { - - const int dim = reshapeArgs[i]; - - if (dim == -1) { - REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - pos = i; - shapeNew.push_back(1); - } - else if (dim == 0) { - shapeNew.push_back(0); - newShapeEmpty = true; - } - else { - shapeNew.push_back(dim); - newShapeLen *= dim; - } + if (pos != -1) { + Nd4jLong xLen = x->lengthOf(); + if (x->isEmpty()) { + xLen = 1; + for (uint i = 0; i < x->rankOf(); + ++i) // take into account possible empty shapes + if (x->sizeAt(i) > 0 || !newShapeEmpty) xLen *= x->sizeAt(i); } - if (pos != -1) { + shapeNew[pos] = xLen / newShapeLen; + } - Nd4jLong xLen = x->lengthOf(); - if(x->isEmpty()) { - xLen = 1; - for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes - if(x->sizeAt(i) > 0 || !newShapeEmpty) - xLen *= x->sizeAt(i); - } + auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); + REQUIRE_TRUE(x->lengthOf() == len, 0, + "Reshape: lengths before and after reshape should match, but " + "got %i vs %i", + x->lengthOf(), len); - shapeNew[pos] = xLen / newShapeLen; - } - - auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); - REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len); - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew)); + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + x->dataType(), orderNew, shapeNew)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp index c0008cb08b24..4c60649183a4 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp @@ -24,35 +24,32 @@ #include namespace sd { - namespace ops { +namespace ops { +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(reshapeas, 2, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); - ////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(reshapeas, 2, 1, false, 0, 0) { + auto z = OUTPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); + // FIXME: add validation here? + auto tmp = x->reshape(y->ordering(), y->getShapeAsVector()); + z->assign(tmp); - auto z = OUTPUT_VARIABLE(0); - - // FIXME: add validation here? - auto tmp = x->reshape(y->ordering(), y->getShapeAsVector()); - z->assign(tmp); - - return Status::OK(); - } - DECLARE_SYN(reshape_as, reshapeas); - - DECLARE_SHAPE_FN(reshapeas) { - return SHAPELIST(CONSTANT(ShapeBuilders::copyShapeInfo(INPUT_VARIABLE(1)->shapeInfo(), false, block.workspace()))); - } + return Status::OK(); +} +DECLARE_SYN(reshape_as, reshapeas); - DECLARE_TYPES(reshapeas) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } +DECLARE_SHAPE_FN(reshapeas) { + return SHAPELIST(CONSTANT(ShapeBuilders::copyShapeInfo( + INPUT_VARIABLE(1)->shapeInfo(), false, block.workspace()))); } + +DECLARE_TYPES(reshapeas) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/shape.cpp b/libnd4j/include/ops/declarable/generic/shape/shape.cpp index e2db3db3e814..fda0ddf2f60e 100644 --- a/libnd4j/include/ops/declarable/generic/shape/shape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/shape.cpp @@ -24,37 +24,36 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(shape_of, 1, 1, false, 0, 0) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(shape_of, 1, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - for (int e = 0; e < x->rankOf(); e++) - z->p(e, x->sizeAt(e)); + for (int e = 0; e < x->rankOf(); e++) z->p(e, x->sizeAt(e)); - STORE_RESULT(z); + STORE_RESULT(z); - return Status::OK(); - }; - DECLARE_SYN(shape, shape_of); + return Status::OK(); +}; +DECLARE_SYN(shape, shape_of); - DECLARE_SHAPE_FN(shape_of) { - auto inShape = inputShape->at(0); +DECLARE_SHAPE_FN(shape_of) { + auto inShape = inputShape->at(0); - // LONG by default - auto dtype = DataType::INT64; - if (block.numI() > 0) - dtype = DataTypeUtils::fromInt(INT_ARG(0)); + // LONG by default + auto dtype = DataType::INT64; + if (block.numI() > 0) dtype = DataTypeUtils::fromInt(INT_ARG(0)); - return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::rank(inShape), dtype)); - }; + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo( + shape::rank(inShape), dtype)); +}; - DECLARE_TYPES(shape_of) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS}); - } - } +DECLARE_TYPES(shape_of) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/shapes.cpp b/libnd4j/include/ops/declarable/generic/shape/shapes.cpp index 6481d1db3c78..3c60d40c513a 100644 --- a/libnd4j/include/ops/declarable/generic/shape/shapes.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/shapes.cpp @@ -24,37 +24,37 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(shapes_of, -1, -1, false, 0, 0) { - for (int e = 0; e < block.width(); e++) { - auto x = INPUT_VARIABLE(e); - auto z = OUTPUT_VARIABLE(e); - - for (int i = 0; i < x->rankOf(); i++) - z->p(i, x->sizeAt(i)); - } - - return Status::OK(); - }; - DECLARE_SYN(shape_n, shapes_of); - - DECLARE_SHAPE_FN(shapes_of) { - auto shapeList = SHAPELIST(); - - for (int e = 0; e < inputShape->size(); e++) { - auto inShape = inputShape->at(e); - shapeList->push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::rank(inShape), sd::DataType::INT64)); - } - - return shapeList; - }; - - DECLARE_TYPES(shapes_of) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS}); - } - } +namespace ops { +CUSTOM_OP_IMPL(shapes_of, -1, -1, false, 0, 0) { + for (int e = 0; e < block.width(); e++) { + auto x = INPUT_VARIABLE(e); + auto z = OUTPUT_VARIABLE(e); + + for (int i = 0; i < x->rankOf(); i++) z->p(i, x->sizeAt(i)); + } + + return Status::OK(); +}; +DECLARE_SYN(shape_n, shapes_of); + +DECLARE_SHAPE_FN(shapes_of) { + auto shapeList = SHAPELIST(); + + for (int e = 0; e < inputShape->size(); e++) { + auto inShape = inputShape->at(e); + shapeList->push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo( + shape::rank(inShape), sd::DataType::INT64)); + } + + return shapeList; +}; + +DECLARE_TYPES(shapes_of) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/size.cpp b/libnd4j/include/ops/declarable/generic/shape/size.cpp index d31e782c6cf7..28558f9a0660 100644 --- a/libnd4j/include/ops/declarable/generic/shape/size.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/size.cpp @@ -24,29 +24,30 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(size, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(output->isScalar(), 0, "Size output should be scalar"); - - output->p(0, input->lengthOf()); - output->syncToDevice(); - - return Status::OK(); - } - DECLARE_SHAPE_FN(size) { - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT64)); - } - - DECLARE_TYPES(size) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}) - ->allowOverride(true); - } - } +namespace ops { +CUSTOM_OP_IMPL(size, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(output->isScalar(), 0, "Size output should be scalar"); + + output->p(0, input->lengthOf()); + output->syncToDevice(); + + return Status::OK(); +} +DECLARE_SHAPE_FN(size) { + return SHAPELIST( + ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT64)); +} + +DECLARE_TYPES(size) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}) + ->allowOverride(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/size_at.cpp b/libnd4j/include/ops/declarable/generic/shape/size_at.cpp index 2c27b018a3aa..87c6e63dd644 100644 --- a/libnd4j/include/ops/declarable/generic/shape/size_at.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/size_at.cpp @@ -24,34 +24,35 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(size_at, 1, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - auto dim = INT_ARG(0); - if (dim < 0) - dim += input->rankOf(); - - REQUIRE_TRUE(dim < input->rankOf(), 0, "Size_At: Dim can't be higher then input rank") - - output->p(0, input->sizeAt(dim)); - output->syncToDevice(); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(size_at) { - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT64)); - } - - DECLARE_TYPES(size_at) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(DataType::INT64) - ->allowOverride(true); - } - } +namespace ops { +CUSTOM_OP_IMPL(size_at, 1, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + auto dim = INT_ARG(0); + if (dim < 0) dim += input->rankOf(); + + REQUIRE_TRUE(dim < input->rankOf(), 0, + "Size_At: Dim can't be higher then input rank") + + output->p(0, input->sizeAt(dim)); + output->syncToDevice(); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(size_at) { + return SHAPELIST( + ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT64)); +} + +DECLARE_TYPES(size_at) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(DataType::INT64) + ->allowOverride(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/squeeze.cpp b/libnd4j/include/ops/declarable/generic/shape/squeeze.cpp index 0b71dae52f6d..09fda7f65081 100644 --- a/libnd4j/include/ops/declarable/generic/shape/squeeze.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/squeeze.cpp @@ -24,135 +24,136 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(squeeze, 1, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - std::vector axis; - - if (block.numI() > 0) - for (int e = 0; e < block.numI(); e++) { - int _a = INT_ARG(e); - if (_a < 0) - _a += input->rankOf(); - - axis.emplace_back(_a); - } - else if (block.width() > 1) { - auto a = INPUT_VARIABLE(1); - for (Nd4jLong e = 0; e < a->lengthOf(); e++) { - int _a = a->e(e); - - if (_a < 0) - _a += input->rankOf(); - - axis.emplace_back(_a); - } - } - - if (input->rankOf() == 0 || (input->rankOf() == 1 && input->lengthOf() == 1)) { - output->assign(input); - return Status::OK(); - } - - std::vector shape; - if (axis.size() == 0) { - for (int d = 0; d < input->rankOf(); d++) - if (input->sizeAt(d) > 1) - shape.emplace_back(input->sizeAt(d)); - } else { - for (int d = 0; d < input->rankOf(); d++) { - if (input->sizeAt(d) == 1) { - if (std::find(axis.begin(), axis.end(), d) == axis.end()) - shape.emplace_back(input->sizeAt(d)); - } else shape.emplace_back(input->sizeAt(d)); - } - } - - if (block.isInplace()) { - output->reshapei(input->ordering(), shape, false); - } else { - if (input->ews() == 1 && output->ews() == 1 && input->ordering() == output->ordering()) { - output->dataBuffer()->copyBufferFrom(*input->dataBuffer().get(), output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), 0, input->bufferOffset()); - } else { - auto tmp = input->reshape(input->ordering(), shape); - output->assign(tmp); - } - } - - return Status::OK(); - } - - DECLARE_TYPES(squeeze) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - - DECLARE_SHAPE_FN(squeeze) { - auto shapeList = SHAPELIST(); - -// Nd4jLong* newShape; - auto in = inputShape->at(0); - auto rank = shape::rank(in); - auto length = shape::length(in); - - if (rank == 0 || (rank == 1 && length == 1)) { - shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(in))); - return shapeList; - } - - std::vector axis; - - if (block.numI() > 0) - for (int e = 0; e < block.numI(); e++) { - int _a = INT_ARG(e); - if (_a < 0) - _a += rank; - - axis.emplace_back(_a); - } - else if (block.width() > 1) { - auto a = INPUT_VARIABLE(1); - for (int e = 0; e < a->lengthOf(); e++) { - int _a = a->e(e); - - if (_a < 0) - _a += rank; - - axis.emplace_back(_a); - } - - } - - auto order = shape::order(in); - auto oldShape = shape::shapeOf(in); - - std::vector shape; - if (axis.size() == 0) { - for (int d = 0; d < rank; d++) - if (oldShape[d] > 1) - shape.emplace_back(oldShape[d]); - } else { - for (int d = 0; d < rank; d++) { - if (oldShape[d] == 1) { - if (std::find(axis.begin(), axis.end(), d) == axis.end()) - shape.emplace_back(oldShape[d]); - } else shape.emplace_back(oldShape[d]); - } - } - - if (shape.size() == 0) { - shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(in))); - return shapeList; - } - - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in), order, shape); - shapeList->push_back(newShape); - return shapeList; - } +namespace ops { +CUSTOM_OP_IMPL(squeeze, 1, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + std::vector axis; + + if (block.numI() > 0) + for (int e = 0; e < block.numI(); e++) { + int _a = INT_ARG(e); + if (_a < 0) _a += input->rankOf(); + + axis.emplace_back(_a); + } + else if (block.width() > 1) { + auto a = INPUT_VARIABLE(1); + for (Nd4jLong e = 0; e < a->lengthOf(); e++) { + int _a = a->e(e); + + if (_a < 0) _a += input->rankOf(); + + axis.emplace_back(_a); + } + } + + if (input->rankOf() == 0 || + (input->rankOf() == 1 && input->lengthOf() == 1)) { + output->assign(input); + return Status::OK(); + } + + std::vector shape; + if (axis.size() == 0) { + for (int d = 0; d < input->rankOf(); d++) + if (input->sizeAt(d) > 1) shape.emplace_back(input->sizeAt(d)); + } else { + for (int d = 0; d < input->rankOf(); d++) { + if (input->sizeAt(d) == 1) { + if (std::find(axis.begin(), axis.end(), d) == axis.end()) + shape.emplace_back(input->sizeAt(d)); + } else + shape.emplace_back(input->sizeAt(d)); + } + } + + if (block.isInplace()) { + output->reshapei(input->ordering(), shape, false); + } else { + if (input->ews() == 1 && output->ews() == 1 && + input->ordering() == output->ordering()) { + output->dataBuffer()->copyBufferFrom( + *input->dataBuffer().get(), + output->lengthOf() * DataTypeUtils::sizeOfElement(output->dataType()), + 0, input->bufferOffset()); + } else { + auto tmp = input->reshape(input->ordering(), shape); + output->assign(tmp); + } + } + + return Status::OK(); +} + +DECLARE_TYPES(squeeze) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} + +DECLARE_SHAPE_FN(squeeze) { + auto shapeList = SHAPELIST(); + + // Nd4jLong* newShape; + auto in = inputShape->at(0); + auto rank = shape::rank(in); + auto length = shape::length(in); + + if (rank == 0 || (rank == 1 && length == 1)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo( + ArrayOptions::dataType(in))); + return shapeList; + } + + std::vector axis; + + if (block.numI() > 0) + for (int e = 0; e < block.numI(); e++) { + int _a = INT_ARG(e); + if (_a < 0) _a += rank; + + axis.emplace_back(_a); + } + else if (block.width() > 1) { + auto a = INPUT_VARIABLE(1); + for (int e = 0; e < a->lengthOf(); e++) { + int _a = a->e(e); + + if (_a < 0) _a += rank; + + axis.emplace_back(_a); + } + } + + auto order = shape::order(in); + auto oldShape = shape::shapeOf(in); + + std::vector shape; + if (axis.size() == 0) { + for (int d = 0; d < rank; d++) + if (oldShape[d] > 1) shape.emplace_back(oldShape[d]); + } else { + for (int d = 0; d < rank; d++) { + if (oldShape[d] == 1) { + if (std::find(axis.begin(), axis.end(), d) == axis.end()) + shape.emplace_back(oldShape[d]); + } else + shape.emplace_back(oldShape[d]); } + } + + if (shape.size() == 0) { + shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo( + ArrayOptions::dataType(in))); + return shapeList; + } + + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(in), order, shape); + shapeList->push_back(newShape); + return shapeList; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp b/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp index 03ef8b9ab44a..e8d987b00540 100644 --- a/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/tile_to_shape.cpp @@ -25,75 +25,75 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(tile_to_shape, 1, 1, false, 0, -1) { +CUSTOM_OP_IMPL(tile_to_shape, 1, 1, false, 0, -1) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + std::vector outShape(block.getIArguments().begin(), + block.getIArguments().end()); - std::vector outShape(block.getIArguments().begin(), block.getIArguments().end()); + if (block.isInplace()) { + input->tileToShape(outShape, *input); + } else { + input->tileToShape(outShape, *output); + } - if (block.isInplace()) { - input->tileToShape(outShape, *input); - } else { - input->tileToShape(outShape, *output); - } - - return Status::OK(); - } - - DECLARE_SHAPE_FN(tile_to_shape) { - auto in = inputShape->at(0); + return Status::OK(); +} - // output shape always equals to arguments +DECLARE_SHAPE_FN(tile_to_shape) { + auto in = inputShape->at(0); - auto conv = ArrayUtils::toLongVector(block.getIArguments()); + // output shape always equals to arguments - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in), shape::order(in), conv); + auto conv = ArrayUtils::toLongVector(block.getIArguments()); - return SHAPELIST(newShape); - } + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(in), shape::order(in), conv); - DECLARE_TYPES(tile_to_shape) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } + return SHAPELIST(newShape); +} - DECLARE_TYPES(tile_to_shape_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(tile_to_shape) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} +DECLARE_TYPES(tile_to_shape_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - CUSTOM_OP_IMPL(tile_to_shape_bp, 2, 1, true, 0, -1) { - auto input = INPUT_VARIABLE(0); - auto epsNext = INPUT_VARIABLE(1); +CUSTOM_OP_IMPL(tile_to_shape_bp, 2, 1, true, 0, -1) { + auto input = INPUT_VARIABLE(0); + auto epsNext = INPUT_VARIABLE(1); - auto gradX = OUTPUT_VARIABLE(0); + auto gradX = OUTPUT_VARIABLE(0); - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(input->shapeInfo(), epsNext->shapeInfo()); - // FIX ME: reduceAlongDimension should have a signature with result pass to avoid assigning twice - if (!axisX.empty()) { - auto tempRes = epsNext->reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(tempRes); - } else - gradX->assign(epsNext); + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(input->shapeInfo(), + epsNext->shapeInfo()); + // FIX ME: reduceAlongDimension should have a signature with result pass to + // avoid assigning twice + if (!axisX.empty()) { + auto tempRes = epsNext->reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(tempRes); + } else + gradX->assign(epsNext); - STORE_RESULT(gradX); + STORE_RESULT(gradX); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(tile_to_shape_bp) { - auto in = inputShape->at(0); +DECLARE_SHAPE_FN(tile_to_shape_bp) { + auto in = inputShape->at(0); - Nd4jLong *newShape; - COPY_SHAPE(in, newShape); + Nd4jLong *newShape; + COPY_SHAPE(in, newShape); - return SHAPELIST(CONSTANT(newShape)); - } -} + return SHAPELIST(CONSTANT(newShape)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/shape/transpose.cpp b/libnd4j/include/ops/declarable/generic/shape/transpose.cpp index bf48544bbbd9..fdd406e62909 100644 --- a/libnd4j/include/ops/declarable/generic/shape/transpose.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/transpose.cpp @@ -22,57 +22,61 @@ #include #if NOT_EXCLUDED(OP_transpose) -#include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(transpose, 1, 1, false, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + // Special case: empty.reshape() -> return empty + if (x->isEmpty()) { + REQUIRE_TRUE( + z->isEmpty(), 0, + "TRANSPOSE OP: when input is empty, output must also be empty"); + return Status::OK(); // No op + } + + if (block.width() == 1 && block.numI() == 0) { + z->assign(x->transpose()); + return Status::OK(); + } - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - //Special case: empty.reshape() -> return empty - if (x->isEmpty()) { - REQUIRE_TRUE(z->isEmpty(), 0, "TRANSPOSE OP: when input is empty, output must also be empty"); - return Status::OK(); //No op - } - - if (block.width() == 1 && block.numI() == 0) { - z->assign(x->transpose()); - return Status::OK(); - } - - std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getIArguments(); + std::vector permutationVector = block.width() > 1 + ? INPUT_VARIABLE(1)->asVectorT() + : block.getIArguments(); - z->assign(x->permute(permutationVector)); + z->assign(x->permute(permutationVector)); - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(transpose) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } DECLARE_SHAPE_FN(transpose) { + auto x = INPUT_VARIABLE(0); - auto x = INPUT_VARIABLE(0); - - if (block.width() == 1 && block.numI() == 0) - return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true)); + if (block.width() == 1 && block.numI() == 0) + return SHAPELIST( + ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true)); - std::vector permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getIArguments(); + std::vector permutationVector = block.width() > 1 + ? INPUT_VARIABLE(1)->asVectorT() + : block.getIArguments(); - auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true); + auto outputShapeInfo = ShapeUtils::evalPermShapeInfo( + permutationVector.data(), x->rankOf(), *x, block.workspace(), true); - return SHAPELIST(outputShapeInfo); + return SHAPELIST(outputShapeInfo); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/strings/split_string.cpp b/libnd4j/include/ops/declarable/generic/strings/split_string.cpp index e42591b5c4ad..a9a6d4111775 100644 --- a/libnd4j/include/ops/declarable/generic/strings/split_string.cpp +++ b/libnd4j/include/ops/declarable/generic/strings/split_string.cpp @@ -24,27 +24,27 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(split_string, 2, 1, true, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto delim = INPUT_VARIABLE(1); - - return Status::OK(); - }; - - DECLARE_SHAPE_FN(split_string) { - auto input = INPUT_VARIABLE(0); - auto delim = INPUT_VARIABLE(1); - - return SHAPELIST(); - } - - DECLARE_TYPES(split_string) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_STRINGS}) - ->setAllowedOutputTypes({ALL_STRINGS}); - } - } +namespace ops { +CUSTOM_OP_IMPL(split_string, 2, 1, true, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + return Status::OK(); +}; + +DECLARE_SHAPE_FN(split_string) { + auto input = INPUT_VARIABLE(0); + auto delim = INPUT_VARIABLE(1); + + return SHAPELIST(); +} + +DECLARE_TYPES(split_string) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_STRINGS}) + ->setAllowedOutputTypes({ALL_STRINGS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tensor/create.cpp b/libnd4j/include/ops/declarable/generic/tensor/create.cpp index c79b55497be0..88431c1f4779 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/create.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/create.cpp @@ -24,35 +24,36 @@ #include namespace sd { - namespace ops { +namespace ops { - CUSTOM_OP_IMPL(create, 1, 1, false, 0, 1) { - auto init = block.numB() > 0 ? B_ARG(0) : true; +CUSTOM_OP_IMPL(create, 1, 1, false, 0, 1) { + auto init = block.numB() > 0 ? B_ARG(0) : true; - if (init) - OUTPUT_VARIABLE(0)->nullify(); + if (init) OUTPUT_VARIABLE(0)->nullify(); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(create) { - auto shapeInput = INPUT_VARIABLE(0); - auto order = (char) INT_ARG(0); - auto dtype = DataTypeUtils::fromInt(INT_ARG(1)); +DECLARE_SHAPE_FN(create) { + auto shapeInput = INPUT_VARIABLE(0); + auto order = (char)INT_ARG(0); + auto dtype = DataTypeUtils::fromInt(INT_ARG(1)); - REQUIRE_TRUE(order == 'c' || order == 'f', 0, "create: order must be either c or f"); + REQUIRE_TRUE(order == 'c' || order == 'f', 0, + "create: order must be either c or f"); - auto shape = shapeInput->getBufferAsVector(); + auto shape = shapeInput->getBufferAsVector(); - return SHAPELIST(sd::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, order, shape)); - } + return SHAPELIST(sd::ConstantShapeHelper::getInstance()->createShapeInfo( + dtype, order, shape)); +} - DECLARE_TYPES(create) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS}) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - } +DECLARE_TYPES(create) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setAllowedOutputTypes(sd::DataType::ANY); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tensor/fill.cpp b/libnd4j/include/ops/declarable/generic/tensor/fill.cpp index 5a35cee2cf84..df1ead4a5e2e 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/fill.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/fill.cpp @@ -24,75 +24,76 @@ #include namespace sd { - namespace ops { - - CUSTOM_OP_IMPL(fill, 1, 1, false, -2, 0) { - auto shapeArray = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - auto w = block.width(); - auto i = block.numI(); - auto t = block.numT(); - - REQUIRE_TRUE( w > 1 || t > 0 || i > 0, 0, "Fill: either additional variable should exist, or scalar value should be present"); - - if(output->isEmpty()){ - //Empty output array - no-op - return Status::OK(); - } - - if (w > 1) { - output->assign(INPUT_VARIABLE(1)); - } else { - if (t > 0) { - output->assign(T_ARG(0)); - } else if (i > 0) { - output->assign(INT_ARG(0)); - } - } - - STORE_RESULT(output); - - return Status::OK(); - }; - - DECLARE_TYPES(fill) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(fill) { - - auto shapeArray = INPUT_VARIABLE(0); - const int len = (int) shapeArray->lengthOf(); - Nd4jLong *newShape = nullptr; - ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(len), Nd4jLong); - - newShape[0] = len; - for (int e = 0; e < shapeArray->lengthOf(); e++){ - newShape[e+1] = shapeArray->e(e); - } - - sd::DataType dataType; - - if (block.width() > 1) { - dataType = INPUT_VARIABLE(1)->dataType(); - } else if (block.numT() > 0) { - dataType = Environment::getInstance()->defaultFloatDataType(); - } else if (block.numI() > 0) { - dataType = sd::DataType::INT32; - } else if (block.numB() > 0) { - dataType = sd::DataType::BOOL; - } else - throw std::runtime_error("Fill: missing value to fill output array with"); - - ShapeUtils::updateStridesAndType(newShape, dataType, 'c'); - - return SHAPELIST(CONSTANT(newShape)); - }; +namespace ops { + +CUSTOM_OP_IMPL(fill, 1, 1, false, -2, 0) { + auto shapeArray = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + auto w = block.width(); + auto i = block.numI(); + auto t = block.numT(); + + REQUIRE_TRUE(w > 1 || t > 0 || i > 0, 0, + "Fill: either additional variable should exist, or scalar value " + "should be present"); + + if (output->isEmpty()) { + // Empty output array - no-op + return Status::OK(); + } + + if (w > 1) { + output->assign(INPUT_VARIABLE(1)); + } else { + if (t > 0) { + output->assign(T_ARG(0)); + } else if (i > 0) { + output->assign(INT_ARG(0)); } + } + + STORE_RESULT(output); + + return Status::OK(); +}; + +DECLARE_TYPES(fill) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +DECLARE_SHAPE_FN(fill) { + auto shapeArray = INPUT_VARIABLE(0); + const int len = (int)shapeArray->lengthOf(); + Nd4jLong *newShape = nullptr; + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(len), Nd4jLong); + + newShape[0] = len; + for (int e = 0; e < shapeArray->lengthOf(); e++) { + newShape[e + 1] = shapeArray->e(e); + } + + sd::DataType dataType; + + if (block.width() > 1) { + dataType = INPUT_VARIABLE(1)->dataType(); + } else if (block.numT() > 0) { + dataType = Environment::getInstance()->defaultFloatDataType(); + } else if (block.numI() > 0) { + dataType = sd::DataType::INT32; + } else if (block.numB() > 0) { + dataType = sd::DataType::BOOL; + } else + throw std::runtime_error("Fill: missing value to fill output array with"); + + ShapeUtils::updateStridesAndType(newShape, dataType, 'c'); + + return SHAPELIST(CONSTANT(newShape)); +}; +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tensor/fill_as.cpp b/libnd4j/include/ops/declarable/generic/tensor/fill_as.cpp index f2d74572d2b2..746b4eaa3cb8 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/fill_as.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/fill_as.cpp @@ -24,33 +24,30 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(fill_as, 1, 1, true, 0, 0) { - auto output = OUTPUT_VARIABLE(0); - - if (block.width() > 1) { - auto s = INPUT_VARIABLE(1); - output->assign(s); - } else if (block.numT() > 0) { - output->assign(T_ARG(0)); - } else if (block.numI() > 0) { - output->assign(INT_ARG(0)); - } - - STORE_RESULT(output); - - return ND4J_STATUS_OK; - } - DECLARE_SYN(filllike, fill_as); - DECLARE_SYN(fill_like, fill_as); - - - DECLARE_TYPES(fill_as) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +namespace ops { +CONFIGURABLE_OP_IMPL(fill_as, 1, 1, true, 0, 0) { + auto output = OUTPUT_VARIABLE(0); + + if (block.width() > 1) { + auto s = INPUT_VARIABLE(1); + output->assign(s); + } else if (block.numT() > 0) { + output->assign(T_ARG(0)); + } else if (block.numI() > 0) { + output->assign(INT_ARG(0)); + } + + STORE_RESULT(output); + + return ND4J_STATUS_OK; } +DECLARE_SYN(filllike, fill_as); +DECLARE_SYN(fill_like, fill_as); + +DECLARE_TYPES(fill_as) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp b/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp index 374456be64f3..059ca9eeef81 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp @@ -26,50 +26,64 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(lin_space, 0, 1, false, 0, 0) { - - auto output = OUTPUT_VARIABLE(0); - - const int nInputs = block.width(); - bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0)); - - REQUIRE_TRUE(bInputs, 0, "lin_space OP: Have to be supplied correct inputs, input size or T_ARG size have to be equal 3, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT()); - - auto start = (nInputs > 0) ? INPUT_VARIABLE(0)->e(0) : static_cast(T_ARG(0)); - auto finish = (nInputs > 0) ? INPUT_VARIABLE(1)->e(0) : static_cast(T_ARG(1)); - auto numOfElements = (nInputs > 0) ? INPUT_VARIABLE(2)->e(0) : static_cast(I_ARG(0)); - - if (numOfElements == 1) { - output->assign(start); - return Status::OK(); - } - - output->linspace(start, (finish - start) / ( numOfElements - 1.0 )); - return Status::OK(); - } - - DECLARE_SHAPE_FN(lin_space) { - - const int nInputs = block.width(); - bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0)); - REQUIRE_TRUE(bInputs, 0, "lin_space OP: Have to be supplied correct inputs, input size or T_ARG size have to be equal 3, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT() ); - - - auto dataType = (nInputs > 0) ? ArrayOptions::dataType(inputShape->at(0)) : ( block.numD() > 0 ? static_cast(D_ARG(0)) : DataType::FLOAT32) ; - Nd4jLong steps = (nInputs > 0) ? INPUT_VARIABLE(2)->e(0) : static_cast(I_ARG(0)); - - return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(steps, dataType)); - } - +CUSTOM_OP_IMPL(lin_space, 0, 1, false, 0, 0) { + auto output = OUTPUT_VARIABLE(0); + + const int nInputs = block.width(); + bool bInputs = (3 == nInputs || 3 == block.numI() || + (2 == block.numT() && block.numI() > 0)); + + REQUIRE_TRUE( + bInputs, 0, + "lin_space OP: Have to be supplied correct inputs, input size or T_ARG " + "size have to be equal 3, but got inputs - %i, T_ARGS - %i!", + nInputs, block.numT()); + + auto start = (nInputs > 0) ? INPUT_VARIABLE(0)->e(0) + : static_cast(T_ARG(0)); + auto finish = (nInputs > 0) ? INPUT_VARIABLE(1)->e(0) + : static_cast(T_ARG(1)); + auto numOfElements = (nInputs > 0) ? INPUT_VARIABLE(2)->e(0) + : static_cast(I_ARG(0)); + + if (numOfElements == 1) { + output->assign(start); + return Status::OK(); + } + + output->linspace(start, (finish - start) / (numOfElements - 1.0)); + return Status::OK(); +} - DECLARE_TYPES(lin_space) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); - } +DECLARE_SHAPE_FN(lin_space) { + const int nInputs = block.width(); + bool bInputs = (3 == nInputs || 3 == block.numI() || + (2 == block.numT() && block.numI() > 0)); + REQUIRE_TRUE( + bInputs, 0, + "lin_space OP: Have to be supplied correct inputs, input size or T_ARG " + "size have to be equal 3, but got inputs - %i, T_ARGS - %i!", + nInputs, block.numT()); + + auto dataType = (nInputs > 0) + ? ArrayOptions::dataType(inputShape->at(0)) + : (block.numD() > 0 ? static_cast(D_ARG(0)) + : DataType::FLOAT32); + Nd4jLong steps = (nInputs > 0) ? INPUT_VARIABLE(2)->e(0) + : static_cast(I_ARG(0)); + + return SHAPELIST( + ConstantShapeHelper::getInstance()->vectorShapeInfo(steps, dataType)); } + +DECLARE_TYPES(lin_space) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tensor/ones_as.cpp b/libnd4j/include/ops/declarable/generic/tensor/ones_as.cpp index 32ce543006e4..2519c394a984 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/ones_as.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/ones_as.cpp @@ -24,32 +24,34 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(ones_as, 1, 1, false, 0, 0) { - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(ones_as, 1, 1, false, 0, 0) { + auto output = OUTPUT_VARIABLE(0); - output->assign(1); + output->assign(1); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(ones_as) { - auto in = inputShape->at(0); - auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in); - auto shape = sd::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in); +DECLARE_SHAPE_FN(ones_as) { + auto in = inputShape->at(0); + auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in); + auto shape = + sd::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in); - //nd4j_printf("numD: %i; dtype: %s\n", block.numD(), DataTypeUtils::asString(dtype).c_str()); + // nd4j_printf("numD: %i; dtype: %s\n", block.numD(), + // DataTypeUtils::asString(dtype).c_str()); - return SHAPELIST(shape); - } + return SHAPELIST(shape); +} - DECLARE_TYPES(ones_as) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY) - ->setSameMode(false); - } - } +DECLARE_TYPES(ones_as) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY) + ->setSameMode(false); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tensor/range.cpp b/libnd4j/include/ops/declarable/generic/tensor/range.cpp index 987b7424d235..f879be04e2a3 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/range.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/range.cpp @@ -29,257 +29,261 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(range, -2, 1, false, -2, -2) { - auto output = OUTPUT_VARIABLE(0); - - const int numInArrs = block.width(); - const int numTArgs = block.numT(); - const int numIArgs = block.numI(); - - NDArray *s = nullptr; - NDArray *d = nullptr; - - bool localS = false; - bool localD = false; - // FIXME: this op should be fully moved to helpers - - if (output->isEmpty()) - return Status::OK(); - - if (numInArrs > 0) { - if(numInArrs == 1) { - //limit = (*INPUT_VARIABLE(0))(0.); - if (output->isR()) { - s = NDArrayFactory::create_(0.0f, block.launchContext()); - d = NDArrayFactory::create_(1.0f, block.launchContext()); - } else { - s = NDArrayFactory::create_(0, block.launchContext()); - d = NDArrayFactory::create_(1, block.launchContext()); - } - localS = true; - localD = true; - } else if(numInArrs == 2) { - s = INPUT_VARIABLE(0); - //limit = (*INPUT_VARIABLE(1))(0.); - if (output->isR()) { - d = NDArrayFactory::create_(1.0f, block.launchContext()); - } else { - d = NDArrayFactory::create_(1, block.launchContext()); - } - localD = true; - } else { - s = INPUT_VARIABLE(0); - //limit = (*INPUT_VARIABLE(1))(0.); - d = INPUT_VARIABLE(2); - } - } else if (numIArgs > 0) { - - if(numIArgs == 1) { - // limit = INT_ARG(0); - } else if(numIArgs == 2) { - s = NDArrayFactory::create_(INT_ARG(0), block.launchContext()); - //limit = INT_ARG(1); - d = NDArrayFactory::create_(1, block.launchContext()); - } - else { - s = NDArrayFactory::create_(INT_ARG(0), block.launchContext()); - //limit = INT_ARG(1); - d = NDArrayFactory::create_(INT_ARG(2), block.launchContext()); - } - - localS = true; - localD = true; + auto output = OUTPUT_VARIABLE(0); + + const int numInArrs = block.width(); + const int numTArgs = block.numT(); + const int numIArgs = block.numI(); + + NDArray *s = nullptr; + NDArray *d = nullptr; + + bool localS = false; + bool localD = false; + // FIXME: this op should be fully moved to helpers + + if (output->isEmpty()) return Status::OK(); + + if (numInArrs > 0) { + if (numInArrs == 1) { + // limit = (*INPUT_VARIABLE(0))(0.); + if (output->isR()) { + s = NDArrayFactory::create_(0.0f, block.launchContext()); + d = NDArrayFactory::create_(1.0f, block.launchContext()); + } else { + s = NDArrayFactory::create_(0, block.launchContext()); + d = NDArrayFactory::create_(1, block.launchContext()); + } + localS = true; + localD = true; + } else if (numInArrs == 2) { + s = INPUT_VARIABLE(0); + // limit = (*INPUT_VARIABLE(1))(0.); + if (output->isR()) { + d = NDArrayFactory::create_(1.0f, block.launchContext()); + } else { + d = NDArrayFactory::create_(1, block.launchContext()); + } + localD = true; + } else { + s = INPUT_VARIABLE(0); + // limit = (*INPUT_VARIABLE(1))(0.); + d = INPUT_VARIABLE(2); } - else if (numTArgs > 0) { - - if(numTArgs == 1) { - //limit = T_ARG(0); - s = NDArrayFactory::create_(0.0f, block.launchContext()); - d = NDArrayFactory::create_(1.0f, block.launchContext()); - } else if(numTArgs == 2) { - s = NDArrayFactory::create_(T_ARG(0), block.launchContext()); - //limit = T_ARG(1); - d = NDArrayFactory::create_(1.0f, block.launchContext()); - } - else { - s = NDArrayFactory::create_(T_ARG(0), block.launchContext()); - //limit = T_ARG(1); - d = NDArrayFactory::create_(T_ARG(2), block.launchContext()); - } - - localS = true; - localD = true; + } else if (numIArgs > 0) { + if (numIArgs == 1) { + // limit = INT_ARG(0); + } else if (numIArgs == 2) { + s = NDArrayFactory::create_(INT_ARG(0), block.launchContext()); + // limit = INT_ARG(1); + d = NDArrayFactory::create_(1, block.launchContext()); } else { - REQUIRE_TRUE(false, 0, "CUSTOM RANGE OP: op should have inputs defined in any possible way: T_args, INT_args, or INPUT variables!"); + s = NDArrayFactory::create_(INT_ARG(0), block.launchContext()); + // limit = INT_ARG(1); + d = NDArrayFactory::create_(INT_ARG(2), block.launchContext()); } - helpers::range(block.launchContext(), *s, *d, *output); + localS = true; + localD = true; + } else if (numTArgs > 0) { + if (numTArgs == 1) { + // limit = T_ARG(0); + s = NDArrayFactory::create_(0.0f, block.launchContext()); + d = NDArrayFactory::create_(1.0f, block.launchContext()); + } else if (numTArgs == 2) { + s = NDArrayFactory::create_(T_ARG(0), block.launchContext()); + // limit = T_ARG(1); + d = NDArrayFactory::create_(1.0f, block.launchContext()); + } else { + s = NDArrayFactory::create_(T_ARG(0), block.launchContext()); + // limit = T_ARG(1); + d = NDArrayFactory::create_(T_ARG(2), block.launchContext()); + } + + localS = true; + localD = true; + } else { + REQUIRE_TRUE(false, 0, + "CUSTOM RANGE OP: op should have inputs defined in any " + "possible way: T_args, INT_args, or INPUT variables!"); + } + + helpers::range(block.launchContext(), *s, *d, *output); - if (localS) - delete s; + if (localS) delete s; - if (localD) - delete d; + if (localD) delete d; - return Status::OK(); + return Status::OK(); } DECLARE_SHAPE_FN(range) { - - const int numInArrs = block.width(); - const int numTArgs = block.numT(); - const int numIArgs = block.numI(); - - Nd4jLong steps = 0; - sd::DataType dataType = block.numD() ? D_ARG(0) : sd::DataType::INHERIT; - - if (numInArrs > 0) { - auto isR = INPUT_VARIABLE(0)->isR(); - auto isZ = INPUT_VARIABLE(0)->isZ(); - auto dtype = INPUT_VARIABLE(0)->dataType(); - - if (isR) { - double start(0), limit, delta(1); - - if (numInArrs == 1) - limit = INPUT_VARIABLE(0)->e(0); - else if (numInArrs == 2) { - start = INPUT_VARIABLE(0)->e(0); - limit = INPUT_VARIABLE(1)->e(0); - } else { - start = INPUT_VARIABLE(0)->e(0); - limit = INPUT_VARIABLE(1)->e(0); - delta = INPUT_VARIABLE(2)->e(0); - } - - if (limit == start){ - //Return [0] to match TF - return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, dtype)); - } - - REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); - - steps = static_cast((limit - start) / delta); - - if (!block.numD()) - dataType = INPUT_VARIABLE(0)->dataType(); - - if(math::nd4j_abs(start + steps * delta) < math::nd4j_abs(limit)) - ++steps; - } else if (isZ) { - Nd4jLong start(0), limit, delta(1); - - if (numInArrs == 1) - limit = INPUT_VARIABLE(0)->e(0); - else if (numInArrs == 2) { - start = INPUT_VARIABLE(0)->e(0); - limit = INPUT_VARIABLE(1)->e(0); - } else { - start = INPUT_VARIABLE(0)->e(0); - limit = INPUT_VARIABLE(1)->e(0); - delta = INPUT_VARIABLE(2)->e(0); - } - - //nd4j_printf("Start: [%lld]; Limit: [%lld]; Delta: [%lld];\n", start, limit, delta) - - if (limit == start){ - //Return [0] to match TF - return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, dtype)); - } - - REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); - - steps = static_cast((limit - start) / delta); - - if (!block.numD()) - dataType = INPUT_VARIABLE(0)->dataType(); - - if(math::nd4j_abs(start + steps * delta) < math::nd4j_abs(limit)) - ++steps; - } - } else if (numIArgs > 0) { - Nd4jLong start(0), limit, delta(1); - - if(numIArgs == 1) - limit = INT_ARG(0); - else if(numIArgs == 2) { - start = INT_ARG(0); - limit = INT_ARG(1); - } - else { - start = INT_ARG(0); - limit = INT_ARG(1); - delta = INT_ARG(2); - } - - if (limit == start){ - //Return [0] to match TF - return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, sd::DataType::INT32)); - } - - REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); - - if (!block.numD()) { - if (limit > DataTypeUtils::max()) - dataType = sd::DataType::INT64; - else - dataType = sd::DataType::INT32; - } - - steps = (limit - start) / delta; - - if(math::nd4j_abs(start + steps * delta) < math::nd4j_abs(limit)) - ++steps; - } - else if (numTArgs > 0) { - double start(0), limit, delta(1); - - if(numTArgs == 1) - limit = T_ARG(0); - else if(numTArgs == 2) { - start = T_ARG(0); - limit = T_ARG(1); - } - else { - start = T_ARG(0); - limit = T_ARG(1); - delta = T_ARG(2); - } - - if (limit == start){ - //Return [0] to match TF - return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, Environment::getInstance()->defaultFloatDataType())); - } - - - REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); - - steps = static_cast((limit - start) / delta); - - if (!block.numD()) { - if (Environment::getInstance()->precisionBoostAllowed()) - dataType = sd::DataType::DOUBLE; - else - dataType = Environment::getInstance()->defaultFloatDataType(); - } - - if(math::nd4j_abs(start + steps * delta) < math::nd4j_abs(limit)) - ++steps; + const int numInArrs = block.width(); + const int numTArgs = block.numT(); + const int numIArgs = block.numI(); + + Nd4jLong steps = 0; + sd::DataType dataType = block.numD() ? D_ARG(0) : sd::DataType::INHERIT; + + if (numInArrs > 0) { + auto isR = INPUT_VARIABLE(0)->isR(); + auto isZ = INPUT_VARIABLE(0)->isZ(); + auto dtype = INPUT_VARIABLE(0)->dataType(); + + if (isR) { + double start(0), limit, delta(1); + + if (numInArrs == 1) + limit = INPUT_VARIABLE(0)->e(0); + else if (numInArrs == 2) { + start = INPUT_VARIABLE(0)->e(0); + limit = INPUT_VARIABLE(1)->e(0); + } else { + start = INPUT_VARIABLE(0)->e(0); + limit = INPUT_VARIABLE(1)->e(0); + delta = INPUT_VARIABLE(2)->e(0); + } + + if (limit == start) { + // Return [0] to match TF + return SHAPELIST( + ConstantShapeHelper::getInstance()->vectorShapeInfo(0, dtype)); + } + + REQUIRE_TRUE(delta != 0, 0, + "CUSTOM RANGE OP: delta should not be equal to zero !"); + + steps = static_cast((limit - start) / delta); + + if (!block.numD()) dataType = INPUT_VARIABLE(0)->dataType(); + + if (math::nd4j_abs(start + steps * delta) < + math::nd4j_abs(limit)) + ++steps; + } else if (isZ) { + Nd4jLong start(0), limit, delta(1); + + if (numInArrs == 1) + limit = INPUT_VARIABLE(0)->e(0); + else if (numInArrs == 2) { + start = INPUT_VARIABLE(0)->e(0); + limit = INPUT_VARIABLE(1)->e(0); + } else { + start = INPUT_VARIABLE(0)->e(0); + limit = INPUT_VARIABLE(1)->e(0); + delta = INPUT_VARIABLE(2)->e(0); + } + + // nd4j_printf("Start: [%lld]; Limit: [%lld]; Delta: [%lld];\n", start, + // limit, delta) + + if (limit == start) { + // Return [0] to match TF + return SHAPELIST( + ConstantShapeHelper::getInstance()->vectorShapeInfo(0, dtype)); + } + + REQUIRE_TRUE(delta != 0, 0, + "CUSTOM RANGE OP: delta should not be equal to zero !"); + + steps = static_cast((limit - start) / delta); + + if (!block.numD()) dataType = INPUT_VARIABLE(0)->dataType(); + + if (math::nd4j_abs(start + steps * delta) < + math::nd4j_abs(limit)) + ++steps; + } + } else if (numIArgs > 0) { + Nd4jLong start(0), limit, delta(1); + + if (numIArgs == 1) + limit = INT_ARG(0); + else if (numIArgs == 2) { + start = INT_ARG(0); + limit = INT_ARG(1); } else { - REQUIRE_TRUE(false, 0, "CUSTOM RANGE OP: op should have inputs defined in any possible way: T_args, INT_args, or INPUT variables!"); + start = INT_ARG(0); + limit = INT_ARG(1); + delta = INT_ARG(2); } - REQUIRE_TRUE(steps > 0, 0, "CUSTOM RANGE OP: value of (limit-start)/delta should be positive !"); + if (limit == start) { + // Return [0] to match TF + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo( + 0, sd::DataType::INT32)); + } - return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(steps, dataType)); -} + REQUIRE_TRUE(delta != 0, 0, + "CUSTOM RANGE OP: delta should not be equal to zero !"); + + if (!block.numD()) { + if (limit > DataTypeUtils::max()) + dataType = sd::DataType::INT64; + else + dataType = sd::DataType::INT32; + } + steps = (limit - start) / delta; - DECLARE_TYPES(range) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); + if (math::nd4j_abs(start + steps * delta) < + math::nd4j_abs(limit)) + ++steps; + } else if (numTArgs > 0) { + double start(0), limit, delta(1); + + if (numTArgs == 1) + limit = T_ARG(0); + else if (numTArgs == 2) { + start = T_ARG(0); + limit = T_ARG(1); + } else { + start = T_ARG(0); + limit = T_ARG(1); + delta = T_ARG(2); + } + + if (limit == start) { + // Return [0] to match TF + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo( + 0, Environment::getInstance()->defaultFloatDataType())); + } + + REQUIRE_TRUE(delta != 0, 0, + "CUSTOM RANGE OP: delta should not be equal to zero !"); + + steps = static_cast((limit - start) / delta); + + if (!block.numD()) { + if (Environment::getInstance()->precisionBoostAllowed()) + dataType = sd::DataType::DOUBLE; + else + dataType = Environment::getInstance()->defaultFloatDataType(); } + + if (math::nd4j_abs(start + steps * delta) < + math::nd4j_abs(limit)) + ++steps; + } else { + REQUIRE_TRUE(false, 0, + "CUSTOM RANGE OP: op should have inputs defined in any " + "possible way: T_args, INT_args, or INPUT variables!"); + } + + REQUIRE_TRUE( + steps > 0, 0, + "CUSTOM RANGE OP: value of (limit-start)/delta should be positive !"); + + return SHAPELIST( + ConstantShapeHelper::getInstance()->vectorShapeInfo(steps, dataType)); } + +DECLARE_TYPES(range) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp b/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp index 90e740efdf8c..d95682f6e06f 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/strided_slice.cpp @@ -16,127 +16,131 @@ limitations under the License. #include #if NOT_EXCLUDED(OP_strided_slice) -#include -#include -#include #include +#include +#include + +#include namespace sd { - namespace ops { - - constexpr int kShrinkAxis = -1, kNewAxis = -2; - - struct StridedSliceSparseSpec { - int dims; - int num_add_axis_after_ellipsis; - std::vector* begin_tensor; - const std::vector* end_tensor; - const std::vector* strides_tensor; - const int begin_mask, end_mask; - int ellipsis_mask; - const int new_axis_mask, shrink_axis_mask; - }; - - struct StridedSliceDenseSpec { - const int dims; - int begin_mask; - int end_mask; - bool begin_valid; - bool end_valid; - std::vector& begin; - std::vector& end; - std::vector& strides; - std::vector final_shape_gather_indices; - int shrink_axis_mask; - - public: - bool buildDenseSpec(StridedSliceSparseSpec& sparse_spec) { - if (this->begin.size() < dims) - this->begin.resize(dims); - - if (this->end.size() < dims) - this->end.resize(dims); - - if (this->strides.size() < dims) - this->strides.resize(dims); - this->begin_mask = 0; - this->end_mask = 0; - this->shrink_axis_mask = 0; - { - int full_index = 0; - - this->begin_valid = sparse_spec.begin_tensor != nullptr; - this->end_valid = sparse_spec.end_tensor != nullptr; - - for (int e = 0; e < sparse_spec.dims; e++) { - if ((1 << e) & sparse_spec.ellipsis_mask) { - int next_index = sd::math::nd4j_min(this->dims - (sparse_spec.dims - e) + 1 + sparse_spec.num_add_axis_after_ellipsis, this->dims); - - for (; full_index < next_index; full_index++) { - // new_axis' aren't real axis so you have to skip - this->begin[full_index] = this->end[full_index] = 0; - this->strides[full_index] = 1; - this->begin_mask |= (1 << full_index); - this->end_mask |= (1 << full_index); - this->final_shape_gather_indices.push_back(full_index); - } - } else if ((1 << e) & sparse_spec.new_axis_mask) { - this->final_shape_gather_indices.emplace_back(kNewAxis); - } else { - if (full_index == this->begin.size()) { - nd4j_printf("Index out of range: %i out of %i\n", full_index, this->dims); - return false; - } - - // Gather slicing spec into appropriate index - if (sparse_spec.begin_tensor != nullptr) - this->begin[full_index] = sparse_spec.begin_tensor->at(e); - - - if (sparse_spec.end_tensor != nullptr) - this->end[full_index] = sparse_spec.end_tensor->at(e); - - this->strides[full_index] = sparse_spec.strides_tensor->at(e); - - if (sparse_spec.begin_mask & (1 << e)) - this->begin_mask |= (1 << full_index); - - - if (sparse_spec.end_mask & (1 << e)) - this->end_mask |= (1 << full_index); - - - // If shrink, record where to get the dimensionality from (i.e. - // new_axis creates a fake 1 size dimension. Also remember shrink - // axis (now in dense form) so we can ignore dense->end below. - if (sparse_spec.shrink_axis_mask & (1 << e)) { - this->final_shape_gather_indices.push_back(kShrinkAxis); - this->shrink_axis_mask |= (1 << full_index); - } else { - this->final_shape_gather_indices.push_back(full_index); - } - full_index++; - } - } - } - return true; - } - }; - - void vectorize(std::vector& input_shape) { - if (input_shape.size() == 2 && input_shape[0] == 1) { - int v = input_shape[1]; - input_shape.clear(); - input_shape.emplace_back(v); - } +namespace ops { + +constexpr int kShrinkAxis = -1, kNewAxis = -2; + +struct StridedSliceSparseSpec { + int dims; + int num_add_axis_after_ellipsis; + std::vector* begin_tensor; + const std::vector* end_tensor; + const std::vector* strides_tensor; + const int begin_mask, end_mask; + int ellipsis_mask; + const int new_axis_mask, shrink_axis_mask; +}; + +struct StridedSliceDenseSpec { + const int dims; + int begin_mask; + int end_mask; + bool begin_valid; + bool end_valid; + std::vector& begin; + std::vector& end; + std::vector& strides; + std::vector final_shape_gather_indices; + int shrink_axis_mask; + + public: + bool buildDenseSpec(StridedSliceSparseSpec& sparse_spec) { + if (this->begin.size() < dims) this->begin.resize(dims); + + if (this->end.size() < dims) this->end.resize(dims); + + if (this->strides.size() < dims) this->strides.resize(dims); + this->begin_mask = 0; + this->end_mask = 0; + this->shrink_axis_mask = 0; + { + int full_index = 0; + + this->begin_valid = sparse_spec.begin_tensor != nullptr; + this->end_valid = sparse_spec.end_tensor != nullptr; + + for (int e = 0; e < sparse_spec.dims; e++) { + if ((1 << e) & sparse_spec.ellipsis_mask) { + int next_index = sd::math::nd4j_min( + this->dims - (sparse_spec.dims - e) + 1 + + sparse_spec.num_add_axis_after_ellipsis, + this->dims); + + for (; full_index < next_index; full_index++) { + // new_axis' aren't real axis so you have to skip + this->begin[full_index] = this->end[full_index] = 0; + this->strides[full_index] = 1; + this->begin_mask |= (1 << full_index); + this->end_mask |= (1 << full_index); + this->final_shape_gather_indices.push_back(full_index); + } + } else if ((1 << e) & sparse_spec.new_axis_mask) { + this->final_shape_gather_indices.emplace_back(kNewAxis); + } else { + if (full_index == this->begin.size()) { + nd4j_printf("Index out of range: %i out of %i\n", full_index, + this->dims); + return false; + } + + // Gather slicing spec into appropriate index + if (sparse_spec.begin_tensor != nullptr) + this->begin[full_index] = sparse_spec.begin_tensor->at(e); + + if (sparse_spec.end_tensor != nullptr) + this->end[full_index] = sparse_spec.end_tensor->at(e); + + this->strides[full_index] = sparse_spec.strides_tensor->at(e); + + if (sparse_spec.begin_mask & (1 << e)) + this->begin_mask |= (1 << full_index); + + if (sparse_spec.end_mask & (1 << e)) + this->end_mask |= (1 << full_index); + + // If shrink, record where to get the dimensionality from (i.e. + // new_axis creates a fake 1 size dimension. Also remember shrink + // axis (now in dense form) so we can ignore dense->end below. + if (sparse_spec.shrink_axis_mask & (1 << e)) { + this->final_shape_gather_indices.push_back(kShrinkAxis); + this->shrink_axis_mask |= (1 << full_index); + } else { + this->final_shape_gather_indices.push_back(full_index); + } + full_index++; } + } + } + return true; + } +}; + +void vectorize(std::vector& input_shape) { + if (input_shape.size() == 2 && input_shape[0] == 1) { + int v = input_shape[1]; + input_shape.clear(); + input_shape.emplace_back(v); + } +} - bool _preprocess_strided_slice(std::vector* indicesList, std::vector* final_shape, std::vector& input_shape, std::vector& begin, std::vector& end, std::vector& strides, int begin_mask, int ellipsis_mask, int end_mask, int new_axis_mask, int shrink_axis_mask, bool* is_identity, bool* is_simple_slice, bool* slice_dim0) { - std::vector preshape; +bool _preprocess_strided_slice( + std::vector* indicesList, std::vector* final_shape, + std::vector& input_shape, std::vector& begin, + std::vector& end, std::vector& strides, int begin_mask, + int ellipsis_mask, int end_mask, int new_axis_mask, int shrink_axis_mask, + bool* is_identity, bool* is_simple_slice, bool* slice_dim0) { + std::vector preshape; - bool ellipsis_seen = false; + bool ellipsis_seen = false; - StridedSliceSparseSpec sparse_spec = {(int) strides.size(), + StridedSliceSparseSpec sparse_spec = {(int)strides.size(), 0, &begin, &end, @@ -147,530 +151,601 @@ namespace sd { new_axis_mask, shrink_axis_mask}; - for (int i = 0; i < sparse_spec.dims; i++) { - if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) { - sparse_spec.num_add_axis_after_ellipsis++; - } - if ((1 << i) & ellipsis_mask) { - ellipsis_seen = true; - } - } - // If no ellipsis insert one at the end - if (!ellipsis_seen) { - sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims); - sparse_spec.dims++; // this effects loop iteration below - } - - StridedSliceDenseSpec dense_spec = {(int) input_shape.size(), 0, 0, false, false, begin, end, strides}; - if (!dense_spec.buildDenseSpec(sparse_spec)) - return false; - - //nd4j_printv("Input shape: ", input_shape); - - for (int e = 0; e < (int) input_shape.size(); e++) { - int begin_idx = begin[e]; - int end_idx = end[e]; - int stride_idx = strides[e]; - int size_idx = input_shape[e]; - - bool shrink_i = (dense_spec.shrink_axis_mask & (1 << e)); - - if (stride_idx == 0) { - nd4j_printf("Stride is 0 at index %i\n", e); - return false; - } - if (size_idx == -1) { - preshape.emplace_back(shrink_i ? 1 : -1); - continue; - } - - const std::array masks = {{dense_spec.begin_mask & (1 << e), dense_spec.end_mask & (1 << e)}}; - const std::array valid_range = {{stride_idx > 0 ? 0 : -1, stride_idx > 0 ? size_idx : size_idx - 1}}; - - auto canonical = [stride_idx, e, size_idx, masks, valid_range](int x, int c) { - if (masks[c]) { - return stride_idx > 0 ? valid_range[c] : valid_range[(c + 1) & 1]; - } else { - int x_fwd = x < 0 ? size_idx + x : x; // make negative indices positive - return x_fwd < valid_range[0] ? valid_range[0] : x_fwd > valid_range[1] ? valid_range[1] : x_fwd; - } - }; - - if (shrink_i && stride_idx <= 0) { - nd4j_printf("StridedSlice: only stride 1 allowed on non-range indexing\n", e); - return false; - } - - (*is_simple_slice) &= stride_idx == 1; - - const bool begin_and_end_masked = (begin_mask & (1 << e)) && (end_mask & (1 << e)); - - if (dense_spec.begin_valid && dense_spec.end_valid) { - if (shrink_i) { - int x_fwd = begin_idx < 0 ? size_idx + begin_idx : begin_idx; - begin_idx = x_fwd; - end_idx = begin_idx + 1; - if (x_fwd < 0 || x_fwd >= size_idx) { - nd4j_printf("slice index %i of dimension %i out of bounds.\n", begin_idx, e); - return false; - } - } else { - begin_idx = canonical(begin_idx, 0); - end_idx = canonical(end_idx, 1); - } - } else { - (*is_identity) &= stride_idx == 1 && begin_and_end_masked; - (*slice_dim0) &= (e == 0 && stride_idx == 1) || begin_and_end_masked; - } - - int interval_length = 1; - bool known_interval = false; - if (dense_spec.begin_valid && dense_spec.end_valid) { - interval_length = end_idx - begin_idx; - known_interval = true; - } else if (shrink_i) { - interval_length = 1; - known_interval = true; - } else if (begin_and_end_masked) { - if (size_idx > 0) { - if (stride_idx < 0) { - interval_length = -size_idx; - } else { - interval_length = size_idx; - } - - known_interval = true; - } - } - - if (known_interval) { - int size_i; - if (interval_length == 0 || ((interval_length < 0) != (stride_idx < 0))) { - size_i = input_shape.size() == 2 && input_shape[0] == 1? 1: 0; - } else { - size_i = interval_length / stride_idx + (interval_length % stride_idx != 0 ? 1 : 0); - } - - if (indicesList != nullptr) { - if (interval_length > 1) { - indicesList->push_back(begin_idx); - indicesList->push_back(end_idx); - indicesList->push_back(stride_idx); - // (*indicesList)[3*e] = begin_idx; - // (*indicesList)[3*e+1] = end_idx; - // (*indicesList)[3*e+2] = stride_idx; - } - else if (interval_length == 1) { - indicesList->push_back(begin_idx); - indicesList->push_back(begin_idx + 1); - indicesList->push_back(1); - // (*indicesList)[3*e] = begin_idx; - // (*indicesList)[3*e+1] = begin_idx + 1; - // (*indicesList)[3*e+2] = 1; - } - } - - preshape.emplace_back(size_i); - } else { - preshape.emplace_back(-1); - } - } - - - std::vector postshape; - //nd4j_printv("Preshape: ", preshape); - - final_shape->clear(); - for (auto gather_index : dense_spec.final_shape_gather_indices) { - if (gather_index >= 0) { - if (preshape.size() > gather_index) - final_shape->emplace_back(preshape.at(gather_index)); - else - final_shape->emplace_back(1); - } else if (gather_index == kNewAxis) { - final_shape->emplace_back(1); - } - } - - //nd4j_printv("Preshape: ", preshape); - //nd4j_printv("Postshape: ", *final_shape); - - return true; - } + for (int i = 0; i < sparse_spec.dims; i++) { + if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) { + sparse_spec.num_add_axis_after_ellipsis++; + } + if ((1 << i) & ellipsis_mask) { + ellipsis_seen = true; + } + } + // If no ellipsis insert one at the end + if (!ellipsis_seen) { + sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims); + sparse_spec.dims++; // this effects loop iteration below + } + + StridedSliceDenseSpec dense_spec = { + (int)input_shape.size(), 0, 0, false, false, begin, end, strides}; + if (!dense_spec.buildDenseSpec(sparse_spec)) return false; + + // nd4j_printv("Input shape: ", input_shape); + + for (int e = 0; e < (int)input_shape.size(); e++) { + int begin_idx = begin[e]; + int end_idx = end[e]; + int stride_idx = strides[e]; + int size_idx = input_shape[e]; + + bool shrink_i = (dense_spec.shrink_axis_mask & (1 << e)); + + if (stride_idx == 0) { + nd4j_printf("Stride is 0 at index %i\n", e); + return false; + } + if (size_idx == -1) { + preshape.emplace_back(shrink_i ? 1 : -1); + continue; + } + + const std::array masks = { + {dense_spec.begin_mask & (1 << e), dense_spec.end_mask & (1 << e)}}; + const std::array valid_range = { + {stride_idx > 0 ? 0 : -1, stride_idx > 0 ? size_idx : size_idx - 1}}; + + auto canonical = [stride_idx, e, size_idx, masks, valid_range](int x, + int c) { + if (masks[c]) { + return stride_idx > 0 ? valid_range[c] : valid_range[(c + 1) & 1]; + } else { + int x_fwd = x < 0 ? size_idx + x : x; // make negative indices positive + return x_fwd < valid_range[0] + ? valid_range[0] + : x_fwd > valid_range[1] ? valid_range[1] : x_fwd; + } + }; + + if (shrink_i && stride_idx <= 0) { + nd4j_printf("StridedSlice: only stride 1 allowed on non-range indexing\n", + e); + return false; + } + (*is_simple_slice) &= stride_idx == 1; - CUSTOM_OP_IMPL(strided_slice, 1, 1, false, 0, 5) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - if (z->isEmpty()) { - return ND4J_STATUS_OK; - } - - int begin_mask = INT_ARG(0); - int ellipsis_mask = INT_ARG(1); - int end_mask = INT_ARG(2); - int new_axis_mask = INT_ARG(3); - int shrink_axis_mask = INT_ARG(4); - - int dim_values = 0; //block.getIArguments()->size() - 5; - int delta = 0; //dim_values % 3; - int elements = 0; //dim_values / 3; - - std::vector begin; - std::vector end; - std::vector strides; - - bool isLive = false; - - std::vector args; - - // statically evaluated - if (block.numI() > 5) { - dim_values = block.numI() - 5; - delta = dim_values % 3; - elements = dim_values / 3; - - for (int e = 5; e < block.numI(); e++) - args.emplace_back(INT_ARG(e)); - - REQUIRE_TRUE(delta == 0, 0, "StridedSlice: Number of Integer arguments should be equal to input rank x 3 = %i, but got %i instead", (x->rankOf() * 3), dim_values); - - ShapeUtils::copyVectorPart(begin, args, elements, 0); - ShapeUtils::copyVectorPart(end, args, elements, elements); - ShapeUtils::copyVectorPart(strides, args, elements, elements * 2); - - } else if (block.width() > 1) { - isLive = true; - - auto v_begin = INPUT_VARIABLE(1); - auto v_end = INPUT_VARIABLE(2); - - elements = v_begin->lengthOf(); - - REQUIRE_TRUE(v_begin->lengthOf() == v_end->lengthOf(), 0, "StridedSlice: Length of begin/end should match, but got %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf()); - REQUIRE_TRUE((v_begin->rankOf() == 1 ) && (v_begin->rankOf() == v_end->rankOf()), 0, "StridedSlice: Rank of begin and ends should be 1, but %i given instead", (int)v_end->rankOf()); - - for (int e = 0; e < v_begin->lengthOf(); e++) - begin.emplace_back(v_begin->e(e)); - - for (int e = 0; e < v_end->lengthOf(); e++) - end.emplace_back(v_end->e(e)); - - if (block.width() > 3) { - auto v_stride = INPUT_VARIABLE(3); - - REQUIRE_TRUE(v_stride->lengthOf() == v_begin->lengthOf(), 0, "StridedSlice: Length of begin/end/stride should match, but got %i vs %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf(), (int) v_stride->lengthOf()); - REQUIRE_TRUE((v_begin->rankOf() == v_stride->rankOf()), 0, "StridedSlice: Rank of begin and ends should be %i, but %i given instead", (int) v_begin->rankOf(), v_stride->rankOf()); - - for (int e = 0; e < v_stride->lengthOf(); e++) - strides.emplace_back(v_stride->e(e)); - } else { - for (int e = 0; e < v_begin->lengthOf(); e++) - strides.emplace_back(1); - } - } else { - REQUIRE_TRUE(false, 0, "StridedSlice: Can't find begin/end/stride information neither in IArguments or in input arrays"); - } - - // validation of begin and start - std::vector ignoreBegin = BitwiseUtils::valueBits(begin_mask); - std::vector ignoreEnd = BitwiseUtils::valueBits(end_mask); - std::vector addAxes = BitwiseUtils::valueBits(new_axis_mask); - std::vector moveAxes = BitwiseUtils::valueBits(shrink_axis_mask); - if (shrink_axis_mask == 0) - for (int dim = 0, b = 0, e = 0; dim < x->rankOf(); ++dim) { - - if(moveAxes[dim]) - continue; - - if(b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) { - int first = strides[b] > 0 ? begin[b] : math::nd4j_abs(begin[b]) - 1; - REQUIRE_TRUE(first <= x->sizeAt(dim), 0, "StridedSlice: begin index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", begin[b], dim); - } - if(e < end.size() && !ignoreEnd[e] && !addAxes[dim]) { - int last = strides[e] > 0 ? end[e] : math::nd4j_abs(end[e]) - 1; - REQUIRE_TRUE(last <= x->sizeAt(dim), 0, "StridedSlice: end index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", end[e], dim); - } - ++b; - ++e; - } - - - std::vector indices; - auto input_shape = x->getShapeAsVector(); - std::vector final_shape; - bool is_identity; - bool is_simple_slice; - bool is_dim0; - - // FIXME: remove this method once we get 1D vectors supported - //vectorize(input_shape); - REQUIRE_TRUE(_preprocess_strided_slice(&indices, &final_shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), 0, "StridedSlice: shape calculation failed"); -// if(z->lengthOf() == 1 && !z->isEmpty() && (input_shape.size() == 2 && input_shape[0] == 1)) { //(indices.size() == 6) && (indices[2] - indices[0] == 1)) { -// z->assign(x->e(indices[0])); -// } -// else { - if (indices.size()) { - Nd4jLong* subArrShapeInfo = nullptr; - ALLOCATE(subArrShapeInfo, block.workspace(), shape::shapeInfoLength(x->rankOf()), Nd4jLong); - Nd4jLong offset; - - shape::calcSubArrShapeInfoAndOffset(indices.data(), x->shapeInfo(), subArrShapeInfo, offset, true, true); - auto subArrShapeInfoPack = ConstantShapeHelper::getInstance()->bufferForShapeInfo(subArrShapeInfo); - - NDArray::prepareSpecialUse({z}, {x}); - - NativeOpExecutioner::execTransformAny(block.launchContext(), sd::transform::Assign, - x->bufferWithOffset(offset), reinterpret_cast(subArrShapeInfoPack.primary()), - x->specialBufferWithOffset(offset), reinterpret_cast(subArrShapeInfoPack.special()), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - nullptr, nullptr, nullptr, true); - - NDArray::registerSpecialUse({z}, {x}); - - RELEASE(subArrShapeInfo, block.workspace()); - } - else if (!z->isEmpty()){ - z->assign(x->e(0)); - } - return Status::OK(); + const bool begin_and_end_masked = + (begin_mask & (1 << e)) && (end_mask & (1 << e)); + + if (dense_spec.begin_valid && dense_spec.end_valid) { + if (shrink_i) { + int x_fwd = begin_idx < 0 ? size_idx + begin_idx : begin_idx; + begin_idx = x_fwd; + end_idx = begin_idx + 1; + if (x_fwd < 0 || x_fwd >= size_idx) { + nd4j_printf("slice index %i of dimension %i out of bounds.\n", + begin_idx, e); + return false; } - DECLARE_SYN(stridedslice, strided_slice); - - DECLARE_SHAPE_FN(strided_slice) { - auto inShape = inputShape->at(0); - - int begin_mask = INT_ARG(0); - int ellipsis_mask = INT_ARG(1); - int end_mask = INT_ARG(2); - int new_axis_mask = INT_ARG(3); - int shrink_axis_mask = INT_ARG(4); - - int x_rank = shape::rank(inShape); - - int dim_values = block.numI() - 5; - int delta = dim_values % 3; - int elements = dim_values / 3; - - - std::vector begin; - std::vector end; - std::vector strides; - - // if that's live - shape will be resolved in runtime - if (block.width() > 1) { - begin = INPUT_VARIABLE(1)->template asVectorT(); - end = INPUT_VARIABLE(2)->template asVectorT(); - strides = INPUT_VARIABLE(3)->template asVectorT(); - } else if (dim_values > 0) { - int delta2 = dim_values / x_rank; - - std::vector args; - for (int e = 5; e < block.numI(); e++) - args.emplace_back(INT_ARG(e)); - - // FIXME: propably template required here - ShapeUtils::copyVectorPart(begin, args, elements, 0); - ShapeUtils::copyVectorPart(end, args, elements, elements); - ShapeUtils::copyVectorPart(strides, args, elements, elements * 2); - } - - REQUIRE_TRUE(begin.size() > 0 && end.size() > 0 && strides.size() > 0, 0, "Strided_Slice: empty arguments"); - - // validation of begin and start - std::vector ignoreBegin = BitwiseUtils::valueBits(begin_mask); - std::vector ignoreEnd = BitwiseUtils::valueBits(end_mask); - std::vector addAxes = BitwiseUtils::valueBits(new_axis_mask); - std::vector moveAxes = BitwiseUtils::valueBits(shrink_axis_mask); - - //if (0 == shrink_axis_mask) - if (false) - for (int dim = 0, b = 0, e = 0; dim < x_rank; ++dim) { - - if(moveAxes[dim]) - continue; - - if(b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) { - int first = strides[b] > 0 ? begin[b] : math::nd4j_abs(begin[b]) - 1; - REQUIRE_TRUE(first <= inShape[dim + 1], 0, "StridedSlice: begin index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", begin[b], dim); - } - if(e < end.size() && !ignoreEnd[e] && !addAxes[dim]) { - int last = strides[e] > 0 ? end[e] : math::nd4j_abs(end[e]) - 1; - REQUIRE_TRUE(last <= inShape[dim + 1], 0, "StridedSlice: end index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", end[e], dim); - } - ++b; - ++e; - } - - std::vector input_shape; //(shape::rank(inShape)); - auto inputLen = shape::length(inShape); - std::vector shape; - - auto rank = shape::rank(inShape); - auto shortShape = shape::shapeOf(inShape); - for (auto e = 0; e < rank; e++) - input_shape.emplace_back(shortShape[e]); - - bool is_identity; - bool is_simple_slice; - bool is_dim0; - - std::vector indices; - bool result = _preprocess_strided_slice(&indices, &shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0); - if (indices.size()) { - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', - shape); -// if (inputLen > 1) { -// newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', -// shape); -// } else { -// newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape)); -// } - return SHAPELIST(newShape); - } - - return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inShape))); + } else { + begin_idx = canonical(begin_idx, 0); + end_idx = canonical(end_idx, 1); + } + } else { + (*is_identity) &= stride_idx == 1 && begin_and_end_masked; + (*slice_dim0) &= (e == 0 && stride_idx == 1) || begin_and_end_masked; + } + + int interval_length = 1; + bool known_interval = false; + if (dense_spec.begin_valid && dense_spec.end_valid) { + interval_length = end_idx - begin_idx; + known_interval = true; + } else if (shrink_i) { + interval_length = 1; + known_interval = true; + } else if (begin_and_end_masked) { + if (size_idx > 0) { + if (stride_idx < 0) { + interval_length = -size_idx; + } else { + interval_length = size_idx; } + known_interval = true; + } + } - CUSTOM_OP_IMPL(strided_slice_bp, 2, 1, false, 0, 5) { - auto x = INPUT_VARIABLE(0); - auto epsNext = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - int begin_mask = INT_ARG(0); - int ellipsis_mask = INT_ARG(1); - int end_mask = INT_ARG(2); - int new_axis_mask = INT_ARG(3); - int shrink_axis_mask = INT_ARG(4); - - int dim_values = 0; //block.getIArguments()->size() - 5; - int delta = 0; //dim_values % 3; - int elements = 0; //dim_values / 3; - - std::vector begin; - std::vector end; - std::vector strides; - - bool isLive = false; - - std::vector args; - - // statically evaluated - if (block.numI() > 5) { - dim_values = block.numI() - 5; - delta = dim_values % 3; - elements = dim_values / 3; - - for (int e = 5; e < block.numI(); e++) - args.emplace_back(INT_ARG(e)); - - REQUIRE_TRUE(delta == 0, 0, "StridedSliceBP: Number of Integer arguments should be equal to input rank x 3 = %i, but got %i instead", (x->rankOf() * 3), dim_values); - - ShapeUtils::copyVectorPart(begin, args, elements, 0); - ShapeUtils::copyVectorPart(end, args, elements, elements); - ShapeUtils::copyVectorPart(strides, args, elements, elements * 2); - - } else if (block.width() >= 3) { - isLive = true; - - auto v_begin = INPUT_VARIABLE(2); - auto v_end = INPUT_VARIABLE(3); - - elements = v_begin->lengthOf(); - - REQUIRE_TRUE(v_begin->lengthOf() == v_end->lengthOf(), 0, "StridedSliceBP: Length of begin/end should match, but got %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf()); - - for (int e = 0; e < v_begin->lengthOf(); e++) - begin.emplace_back(v_begin->e(e)); - - for (int e = 0; e < v_end->lengthOf(); e++) - end.emplace_back(v_end->e(e)); - - if (block.width() >= 4) { - auto v_stride = INPUT_VARIABLE(4); - - REQUIRE_TRUE(v_stride->lengthOf() == v_begin->lengthOf(), 0, "StridedSliceBP: Length of begin/end/stride should match, but got %i vs %i vs %i instead", (int) v_begin->lengthOf(), (int) v_end->lengthOf(), (int) v_stride->lengthOf()); - - for (int e = 0; e < v_stride->lengthOf(); e++) - strides.emplace_back(v_stride->e(e)); - } else { - for (int e = 0; e < v_begin->lengthOf(); e++) - strides.emplace_back(1); - } - } else { - REQUIRE_TRUE(false, 0, "StridedSliceBP: Can't find begin/end/stride information neither in IArguments or in input arrays"); - } - - // validation of begin and start - std::vector ignoreBegin = BitwiseUtils::valueBits(begin_mask); - std::vector ignoreEnd = BitwiseUtils::valueBits(end_mask); - std::vector addAxes = BitwiseUtils::valueBits(new_axis_mask); - std::vector moveAxes = BitwiseUtils::valueBits(shrink_axis_mask); - - for (int dim = 0, b = 0, e = 0; dim < x->rankOf(); ++dim) { - - if(moveAxes[dim]) - continue; - - if(b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) { - int first = strides[b] > 0 ? begin[b] : math::nd4j_abs(begin[b]) - 1; - REQUIRE_TRUE(first <= x->sizeAt(dim), 0, "StridedSlice: begin index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", begin[b], dim); - } - if(e < end.size() && !ignoreEnd[e] && !addAxes[dim]) { - int last = strides[e] > 0 ? end[e] : math::nd4j_abs(end[e]) - 1; - REQUIRE_TRUE(last <= x->sizeAt(dim), 0, "StridedSlice: end index should be <= corresponding dimension of input array, but got end_index = %i for dimension %i!", end[e], dim); - } - ++b; - ++e; - } - - auto input_shape = x->getShapeAsVector(); - std::vector indices; - std::vector final_shape; - bool is_identity; - bool is_simple_slice; - bool is_dim0; - - // FIXME: remove this method once we get 1D vectors supported - vectorize(input_shape); - REQUIRE_TRUE(_preprocess_strided_slice(&indices, &final_shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), 0, "StridedSliceBP: shape calculation failed"); - //REQUIRE_TRUE(epsNext->isSameShape(final_shape), 0, "StridedSlice_bp: gradOut shape should be equals to output from strided_slice op."); - //Zero output array, so unused elements have 0 gradient - output->nullify(); - // - // the first case: only for scalar gradient step - if(epsNext->lengthOf() == 1 && (indices.size() == 3 && (indices[1] - indices[0]) == 1 || (indices[2] - indices[0] == 1))) { - output->p(indices[0], *epsNext); - } - else { // else for other cases - auto sub = (*output)(indices, true, true); - sub.assign(epsNext); - } - - return Status::OK(); + if (known_interval) { + int size_i; + if (interval_length == 0 || ((interval_length < 0) != (stride_idx < 0))) { + size_i = input_shape.size() == 2 && input_shape[0] == 1 ? 1 : 0; + } else { + size_i = interval_length / stride_idx + + (interval_length % stride_idx != 0 ? 1 : 0); + } + + if (indicesList != nullptr) { + if (interval_length > 1) { + indicesList->push_back(begin_idx); + indicesList->push_back(end_idx); + indicesList->push_back(stride_idx); + // (*indicesList)[3*e] = begin_idx; + // (*indicesList)[3*e+1] = end_idx; + // (*indicesList)[3*e+2] = stride_idx; + } else if (interval_length == 1) { + indicesList->push_back(begin_idx); + indicesList->push_back(begin_idx + 1); + indicesList->push_back(1); + // (*indicesList)[3*e] = begin_idx; + // (*indicesList)[3*e+1] = begin_idx + 1; + // (*indicesList)[3*e+2] = 1; } + } - DECLARE_SHAPE_FN(strided_slice_bp) { - auto inShape = inputShape->at(0); - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); + preshape.emplace_back(size_i); + } else { + preshape.emplace_back(-1); + } + } + + std::vector postshape; + // nd4j_printv("Preshape: ", preshape); + + final_shape->clear(); + for (auto gather_index : dense_spec.final_shape_gather_indices) { + if (gather_index >= 0) { + if (preshape.size() > gather_index) + final_shape->emplace_back(preshape.at(gather_index)); + else + final_shape->emplace_back(1); + } else if (gather_index == kNewAxis) { + final_shape->emplace_back(1); + } + } - return SHAPELIST(CONSTANT(newShape)); - } + // nd4j_printv("Preshape: ", preshape); + // nd4j_printv("Postshape: ", *final_shape); - DECLARE_TYPES(strided_slice) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } + return true; +} - DECLARE_TYPES(strided_slice_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +CUSTOM_OP_IMPL(strided_slice, 1, 1, false, 0, 5) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + if (z->isEmpty()) { + return ND4J_STATUS_OK; + } + + int begin_mask = INT_ARG(0); + int ellipsis_mask = INT_ARG(1); + int end_mask = INT_ARG(2); + int new_axis_mask = INT_ARG(3); + int shrink_axis_mask = INT_ARG(4); + + int dim_values = 0; // block.getIArguments()->size() - 5; + int delta = 0; // dim_values % 3; + int elements = 0; // dim_values / 3; + + std::vector begin; + std::vector end; + std::vector strides; + + bool isLive = false; + + std::vector args; + + // statically evaluated + if (block.numI() > 5) { + dim_values = block.numI() - 5; + delta = dim_values % 3; + elements = dim_values / 3; + + for (int e = 5; e < block.numI(); e++) args.emplace_back(INT_ARG(e)); + + REQUIRE_TRUE(delta == 0, 0, + "StridedSlice: Number of Integer arguments should be equal to " + "input rank x 3 = %i, but got %i instead", + (x->rankOf() * 3), dim_values); + + ShapeUtils::copyVectorPart(begin, args, elements, 0); + ShapeUtils::copyVectorPart(end, args, elements, elements); + ShapeUtils::copyVectorPart(strides, args, elements, elements * 2); + + } else if (block.width() > 1) { + isLive = true; + + auto v_begin = INPUT_VARIABLE(1); + auto v_end = INPUT_VARIABLE(2); + + elements = v_begin->lengthOf(); + + REQUIRE_TRUE(v_begin->lengthOf() == v_end->lengthOf(), 0, + "StridedSlice: Length of begin/end should match, but got %i " + "vs %i instead", + (int)v_begin->lengthOf(), (int)v_end->lengthOf()); + REQUIRE_TRUE( + (v_begin->rankOf() == 1) && (v_begin->rankOf() == v_end->rankOf()), 0, + "StridedSlice: Rank of begin and ends should be 1, but %i given " + "instead", + (int)v_end->rankOf()); + + for (int e = 0; e < v_begin->lengthOf(); e++) + begin.emplace_back(v_begin->e(e)); + + for (int e = 0; e < v_end->lengthOf(); e++) + end.emplace_back(v_end->e(e)); + + if (block.width() > 3) { + auto v_stride = INPUT_VARIABLE(3); + + REQUIRE_TRUE(v_stride->lengthOf() == v_begin->lengthOf(), 0, + "StridedSlice: Length of begin/end/stride should match, but " + "got %i vs %i vs %i instead", + (int)v_begin->lengthOf(), (int)v_end->lengthOf(), + (int)v_stride->lengthOf()); + REQUIRE_TRUE((v_begin->rankOf() == v_stride->rankOf()), 0, + "StridedSlice: Rank of begin and ends should be %i, but %i " + "given instead", + (int)v_begin->rankOf(), v_stride->rankOf()); + + for (int e = 0; e < v_stride->lengthOf(); e++) + strides.emplace_back(v_stride->e(e)); + } else { + for (int e = 0; e < v_begin->lengthOf(); e++) strides.emplace_back(1); } + } else { + REQUIRE_TRUE(false, 0, + "StridedSlice: Can't find begin/end/stride information " + "neither in IArguments or in input arrays"); + } + + // validation of begin and start + std::vector ignoreBegin = BitwiseUtils::valueBits(begin_mask); + std::vector ignoreEnd = BitwiseUtils::valueBits(end_mask); + std::vector addAxes = BitwiseUtils::valueBits(new_axis_mask); + std::vector moveAxes = BitwiseUtils::valueBits(shrink_axis_mask); + if (shrink_axis_mask == 0) + for (int dim = 0, b = 0, e = 0; dim < x->rankOf(); ++dim) { + if (moveAxes[dim]) continue; + + if (b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) { + int first = + strides[b] > 0 ? begin[b] : math::nd4j_abs(begin[b]) - 1; + REQUIRE_TRUE( + first <= x->sizeAt(dim), 0, + "StridedSlice: begin index should be <= corresponding dimension of " + "input array, but got end_index = %i for dimension %i!", + begin[b], dim); + } + if (e < end.size() && !ignoreEnd[e] && !addAxes[dim]) { + int last = strides[e] > 0 ? end[e] : math::nd4j_abs(end[e]) - 1; + REQUIRE_TRUE( + last <= x->sizeAt(dim), 0, + "StridedSlice: end index should be <= corresponding dimension of " + "input array, but got end_index = %i for dimension %i!", + end[e], dim); + } + ++b; + ++e; + } + + std::vector indices; + auto input_shape = x->getShapeAsVector(); + std::vector final_shape; + bool is_identity; + bool is_simple_slice; + bool is_dim0; + + // FIXME: remove this method once we get 1D vectors supported + // vectorize(input_shape); + REQUIRE_TRUE(_preprocess_strided_slice( + &indices, &final_shape, input_shape, begin, end, strides, + begin_mask, ellipsis_mask, end_mask, new_axis_mask, + shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), + 0, "StridedSlice: shape calculation failed"); + // if(z->lengthOf() == 1 && !z->isEmpty() && (input_shape.size() == + // 2 && input_shape[0] == 1)) { //(indices.size() == 6) && + // (indices[2] - indices[0] == 1)) { + // z->assign(x->e(indices[0])); + // } + // else { + if (indices.size()) { + Nd4jLong* subArrShapeInfo = nullptr; + ALLOCATE(subArrShapeInfo, block.workspace(), + shape::shapeInfoLength(x->rankOf()), Nd4jLong); + Nd4jLong offset; + + shape::calcSubArrShapeInfoAndOffset(indices.data(), x->shapeInfo(), + subArrShapeInfo, offset, true, true); + auto subArrShapeInfoPack = + ConstantShapeHelper::getInstance()->bufferForShapeInfo(subArrShapeInfo); + + NDArray::prepareSpecialUse({z}, {x}); + + NativeOpExecutioner::execTransformAny( + block.launchContext(), sd::transform::Assign, + x->bufferWithOffset(offset), + reinterpret_cast(subArrShapeInfoPack.primary()), + x->specialBufferWithOffset(offset), + reinterpret_cast(subArrShapeInfoPack.special()), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), nullptr, + nullptr, nullptr, true); + + NDArray::registerSpecialUse({z}, {x}); + + RELEASE(subArrShapeInfo, block.workspace()); + } else if (!z->isEmpty()) { + z->assign(x->e(0)); + } + return Status::OK(); +} +DECLARE_SYN(stridedslice, strided_slice); + +DECLARE_SHAPE_FN(strided_slice) { + auto inShape = inputShape->at(0); + + int begin_mask = INT_ARG(0); + int ellipsis_mask = INT_ARG(1); + int end_mask = INT_ARG(2); + int new_axis_mask = INT_ARG(3); + int shrink_axis_mask = INT_ARG(4); + + int x_rank = shape::rank(inShape); + + int dim_values = block.numI() - 5; + int delta = dim_values % 3; + int elements = dim_values / 3; + + std::vector begin; + std::vector end; + std::vector strides; + + // if that's live - shape will be resolved in runtime + if (block.width() > 1) { + begin = INPUT_VARIABLE(1)->template asVectorT(); + end = INPUT_VARIABLE(2)->template asVectorT(); + strides = INPUT_VARIABLE(3)->template asVectorT(); + } else if (dim_values > 0) { + int delta2 = dim_values / x_rank; + + std::vector args; + for (int e = 5; e < block.numI(); e++) args.emplace_back(INT_ARG(e)); + + // FIXME: propably template required here + ShapeUtils::copyVectorPart(begin, args, elements, 0); + ShapeUtils::copyVectorPart(end, args, elements, elements); + ShapeUtils::copyVectorPart(strides, args, elements, elements * 2); + } + + REQUIRE_TRUE(begin.size() > 0 && end.size() > 0 && strides.size() > 0, 0, + "Strided_Slice: empty arguments"); + + // validation of begin and start + std::vector ignoreBegin = BitwiseUtils::valueBits(begin_mask); + std::vector ignoreEnd = BitwiseUtils::valueBits(end_mask); + std::vector addAxes = BitwiseUtils::valueBits(new_axis_mask); + std::vector moveAxes = BitwiseUtils::valueBits(shrink_axis_mask); + + // if (0 == shrink_axis_mask) + if (false) + for (int dim = 0, b = 0, e = 0; dim < x_rank; ++dim) { + if (moveAxes[dim]) continue; + + if (b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) { + int first = + strides[b] > 0 ? begin[b] : math::nd4j_abs(begin[b]) - 1; + REQUIRE_TRUE( + first <= inShape[dim + 1], 0, + "StridedSlice: begin index should be <= corresponding dimension of " + "input array, but got end_index = %i for dimension %i!", + begin[b], dim); + } + if (e < end.size() && !ignoreEnd[e] && !addAxes[dim]) { + int last = strides[e] > 0 ? end[e] : math::nd4j_abs(end[e]) - 1; + REQUIRE_TRUE( + last <= inShape[dim + 1], 0, + "StridedSlice: end index should be <= corresponding dimension of " + "input array, but got end_index = %i for dimension %i!", + end[e], dim); + } + ++b; + ++e; + } + + std::vector input_shape; //(shape::rank(inShape)); + auto inputLen = shape::length(inShape); + std::vector shape; + + auto rank = shape::rank(inShape); + auto shortShape = shape::shapeOf(inShape); + for (auto e = 0; e < rank; e++) input_shape.emplace_back(shortShape[e]); + + bool is_identity; + bool is_simple_slice; + bool is_dim0; + + std::vector indices; + bool result = _preprocess_strided_slice( + &indices, &shape, input_shape, begin, end, strides, begin_mask, + ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, + &is_simple_slice, &is_dim0); + if (indices.size()) { + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShape), 'c', shape); + // if (inputLen > 1) { + // newShape = + // ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), + // 'c', + // shape); + // } else { + // newShape = + // ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape)); + // } + return SHAPELIST(newShape); + } + + return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo( + ArrayOptions::dataType(inShape))); +} + +CUSTOM_OP_IMPL(strided_slice_bp, 2, 1, false, 0, 5) { + auto x = INPUT_VARIABLE(0); + auto epsNext = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + int begin_mask = INT_ARG(0); + int ellipsis_mask = INT_ARG(1); + int end_mask = INT_ARG(2); + int new_axis_mask = INT_ARG(3); + int shrink_axis_mask = INT_ARG(4); + + int dim_values = 0; // block.getIArguments()->size() - 5; + int delta = 0; // dim_values % 3; + int elements = 0; // dim_values / 3; + + std::vector begin; + std::vector end; + std::vector strides; + + bool isLive = false; + + std::vector args; + + // statically evaluated + if (block.numI() > 5) { + dim_values = block.numI() - 5; + delta = dim_values % 3; + elements = dim_values / 3; + + for (int e = 5; e < block.numI(); e++) args.emplace_back(INT_ARG(e)); + + REQUIRE_TRUE(delta == 0, 0, + "StridedSliceBP: Number of Integer arguments should be equal " + "to input rank x 3 = %i, but got %i instead", + (x->rankOf() * 3), dim_values); + + ShapeUtils::copyVectorPart(begin, args, elements, 0); + ShapeUtils::copyVectorPart(end, args, elements, elements); + ShapeUtils::copyVectorPart(strides, args, elements, elements * 2); + + } else if (block.width() >= 3) { + isLive = true; + + auto v_begin = INPUT_VARIABLE(2); + auto v_end = INPUT_VARIABLE(3); + + elements = v_begin->lengthOf(); + + REQUIRE_TRUE(v_begin->lengthOf() == v_end->lengthOf(), 0, + "StridedSliceBP: Length of begin/end should match, but got %i " + "vs %i instead", + (int)v_begin->lengthOf(), (int)v_end->lengthOf()); + + for (int e = 0; e < v_begin->lengthOf(); e++) + begin.emplace_back(v_begin->e(e)); + + for (int e = 0; e < v_end->lengthOf(); e++) + end.emplace_back(v_end->e(e)); + + if (block.width() >= 4) { + auto v_stride = INPUT_VARIABLE(4); + + REQUIRE_TRUE(v_stride->lengthOf() == v_begin->lengthOf(), 0, + "StridedSliceBP: Length of begin/end/stride should match, " + "but got %i vs %i vs %i instead", + (int)v_begin->lengthOf(), (int)v_end->lengthOf(), + (int)v_stride->lengthOf()); + + for (int e = 0; e < v_stride->lengthOf(); e++) + strides.emplace_back(v_stride->e(e)); + } else { + for (int e = 0; e < v_begin->lengthOf(); e++) strides.emplace_back(1); + } + } else { + REQUIRE_TRUE(false, 0, + "StridedSliceBP: Can't find begin/end/stride information " + "neither in IArguments or in input arrays"); + } + + // validation of begin and start + std::vector ignoreBegin = BitwiseUtils::valueBits(begin_mask); + std::vector ignoreEnd = BitwiseUtils::valueBits(end_mask); + std::vector addAxes = BitwiseUtils::valueBits(new_axis_mask); + std::vector moveAxes = BitwiseUtils::valueBits(shrink_axis_mask); + + for (int dim = 0, b = 0, e = 0; dim < x->rankOf(); ++dim) { + if (moveAxes[dim]) continue; + + if (b < begin.size() && !ignoreBegin[b] && !addAxes[dim]) { + int first = strides[b] > 0 ? begin[b] : math::nd4j_abs(begin[b]) - 1; + REQUIRE_TRUE( + first <= x->sizeAt(dim), 0, + "StridedSlice: begin index should be <= corresponding dimension of " + "input array, but got end_index = %i for dimension %i!", + begin[b], dim); + } + if (e < end.size() && !ignoreEnd[e] && !addAxes[dim]) { + int last = strides[e] > 0 ? end[e] : math::nd4j_abs(end[e]) - 1; + REQUIRE_TRUE( + last <= x->sizeAt(dim), 0, + "StridedSlice: end index should be <= corresponding dimension of " + "input array, but got end_index = %i for dimension %i!", + end[e], dim); + } + ++b; + ++e; + } + + auto input_shape = x->getShapeAsVector(); + std::vector indices; + std::vector final_shape; + bool is_identity; + bool is_simple_slice; + bool is_dim0; + + // FIXME: remove this method once we get 1D vectors supported + vectorize(input_shape); + REQUIRE_TRUE(_preprocess_strided_slice( + &indices, &final_shape, input_shape, begin, end, strides, + begin_mask, ellipsis_mask, end_mask, new_axis_mask, + shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0), + 0, "StridedSliceBP: shape calculation failed"); + // REQUIRE_TRUE(epsNext->isSameShape(final_shape), 0, "StridedSlice_bp: + // gradOut shape should be equals to output from strided_slice op."); Zero + // output array, so unused elements have 0 gradient + output->nullify(); + // + // the first case: only for scalar gradient step + if (epsNext->lengthOf() == 1 && + (indices.size() == 3 && (indices[1] - indices[0]) == 1 || + (indices[2] - indices[0] == 1))) { + output->p(indices[0], *epsNext); + } else { // else for other cases + auto sub = (*output)(indices, true, true); + sub.assign(epsNext); + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(strided_slice_bp) { + auto inShape = inputShape->at(0); + Nd4jLong* newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} + +DECLARE_TYPES(strided_slice) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} + +DECLARE_TYPES(strided_slice_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/tensor/zeros_as.cpp b/libnd4j/include/ops/declarable/generic/tensor/zeros_as.cpp index 6d475af53ff5..79f67e4709a1 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/zeros_as.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/zeros_as.cpp @@ -24,33 +24,33 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(zeros_as, 1, 1, false, 0, 0) { - auto out = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(zeros_as, 1, 1, false, 0, 0) { + auto out = OUTPUT_VARIABLE(0); - out->assign(0); // output is filled by zero by default - - return Status::OK(); - } - DECLARE_SYN(zeroslike, zeros_as); - DECLARE_SYN(zeros_like, zeros_as); + out->assign(0); // output is filled by zero by default + return Status::OK(); +} +DECLARE_SYN(zeroslike, zeros_as); +DECLARE_SYN(zeros_like, zeros_as); - DECLARE_SHAPE_FN(zeros_as) { - auto in = inputShape->at(0); - auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in); - auto shape = sd::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in); +DECLARE_SHAPE_FN(zeros_as) { + auto in = inputShape->at(0); + auto dtype = block.numD() ? D_ARG(0) : ArrayOptions::dataType(in); + auto shape = + sd::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, in); - return SHAPELIST(shape); - } + return SHAPELIST(shape); +} - DECLARE_TYPES(zeros_as) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY) - ->setSameMode(false); - } - } +DECLARE_TYPES(zeros_as) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY) + ->setSameMode(false); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tests/noop.cpp b/libnd4j/include/ops/declarable/generic/tests/noop.cpp index 37980c7e6a44..7f69b7aa6e2d 100644 --- a/libnd4j/include/ops/declarable/generic/tests/noop.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/noop.cpp @@ -24,18 +24,16 @@ #include namespace sd { - namespace ops { - OP_IMPL(noop, -2, -2, true) { - // Fastest op ever. - return Status::OK(); - } +namespace ops { +OP_IMPL(noop, -2, -2, true) { + // Fastest op ever. + return Status::OK(); +} - DECLARE_TYPES(noop) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +DECLARE_TYPES(noop) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/tests/test_output_reshape.cpp b/libnd4j/include/ops/declarable/generic/tests/test_output_reshape.cpp index 7ded29e2019a..2323ee4eda0e 100644 --- a/libnd4j/include/ops/declarable/generic/tests/test_output_reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/test_output_reshape.cpp @@ -24,25 +24,22 @@ #include namespace sd { - namespace ops { - OP_IMPL(test_output_reshape, 1, 1, true) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +OP_IMPL(test_output_reshape, 1, 1, true) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - if (!block.isInplace()) - output->assign(input); + if (!block.isInplace()) output->assign(input); - output->reshapei({-1}); + output->reshapei({-1}); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(test_output_reshape) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +DECLARE_TYPES(test_output_reshape) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/tests/test_scalar.cpp b/libnd4j/include/ops/declarable/generic/tests/test_scalar.cpp index 76510720a30d..cdcfb7c5cbba 100644 --- a/libnd4j/include/ops/declarable/generic/tests/test_scalar.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/test_scalar.cpp @@ -24,43 +24,43 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(test_scalar, 1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(test_scalar, 1, 1, false, 0, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - double val = input->e(0) + 2.0; - output->p(0, val); + double val = input->e(0) + 2.0; + output->p(0, val); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(test_scalar) { - Nd4jLong *newShape; - ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); +DECLARE_SHAPE_FN(test_scalar) { + Nd4jLong *newShape; + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); - newShape[0] = 2; - newShape[1] = 1; - newShape[2] = 1; - newShape[3] = 1; - newShape[4] = 1; - newShape[5] = 0; - newShape[6] = 1; - newShape[7] = 99; + newShape[0] = 2; + newShape[1] = 1; + newShape[2] = 1; + newShape[3] = 1; + newShape[4] = 1; + newShape[5] = 0; + newShape[6] = 1; + newShape[7] = 99; - ArrayOptions::setDataType(newShape, ArrayOptions::dataType(inputShape->at(0))); + ArrayOptions::setDataType(newShape, + ArrayOptions::dataType(inputShape->at(0))); - auto shape = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newShape)); - RELEASE(newShape, block.workspace()); - return SHAPELIST(shape); - } + auto shape = ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(newShape)); + RELEASE(newShape, block.workspace()); + return SHAPELIST(shape); +} - DECLARE_TYPES(test_scalar) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +DECLARE_TYPES(test_scalar) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp b/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp index 86368d98404f..917c4cb9b300 100644 --- a/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/testcustom.cpp @@ -24,32 +24,32 @@ #include namespace sd { - namespace ops { - ////////////////////////////////////////////////////////////////////////// - CUSTOM_OP_IMPL(testcustom, 1, 1, false, 0, -1) { - auto z = this->getZ(block); - - STORE_RESULT(*z); - return Status::OK(); - } - DECLARE_SHAPE_FN(testcustom) { - // this test op will just return back original shape doubled - Nd4jLong *shapeOf; - ALLOCATE(shapeOf, block.workspace(), shape::rank(inputShape->at(0)), Nd4jLong); - for (int e = 0; e < shape::rank(inputShape->at(0)); e++) - shapeOf[e] = inputShape->at(0)[e+1] * 2; - - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataType::FLOAT32, 'c', shape::rank(inputShape->at(0)), shapeOf); - RELEASE(shapeOf, block.workspace()); - return SHAPELIST(newShape); - } - - DECLARE_TYPES(testcustom) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +namespace ops { +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(testcustom, 1, 1, false, 0, -1) { + auto z = this->getZ(block); + + STORE_RESULT(*z); + return Status::OK(); +} +DECLARE_SHAPE_FN(testcustom) { + // this test op will just return back original shape doubled + Nd4jLong *shapeOf; + ALLOCATE(shapeOf, block.workspace(), shape::rank(inputShape->at(0)), + Nd4jLong); + for (int e = 0; e < shape::rank(inputShape->at(0)); e++) + shapeOf[e] = inputShape->at(0)[e + 1] * 2; + + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + DataType::FLOAT32, 'c', shape::rank(inputShape->at(0)), shapeOf); + RELEASE(shapeOf, block.workspace()); + return SHAPELIST(newShape); +} + +DECLARE_TYPES(testcustom) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp b/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp index f4d4d3159d3a..9546be7c3874 100644 --- a/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/testop2i2o.cpp @@ -24,32 +24,30 @@ #include namespace sd { - namespace ops { - ////////////////////////////////////////////////////////////////////////// - // test op, non-divergent - OP_IMPL(testop2i2o, 2, 2, true) { - //nd4j_printf("CPU op used!\n",""); - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - - auto xO = OUTPUT_VARIABLE(0); - auto yO = OUTPUT_VARIABLE(1); - - x->applyScalar(scalar::Add, 1.0, *xO); - y->applyScalar(scalar::Add, 2.0, *yO); - - STORE_2_RESULTS(*xO, *yO); - - return Status::OK(); - } - DECLARE_SYN(TestOp2i2o, testop2i2o); - - DECLARE_TYPES(testop2i2o) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +namespace ops { +////////////////////////////////////////////////////////////////////////// +// test op, non-divergent +OP_IMPL(testop2i2o, 2, 2, true) { + // nd4j_printf("CPU op used!\n",""); + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + + auto xO = OUTPUT_VARIABLE(0); + auto yO = OUTPUT_VARIABLE(1); + + x->applyScalar(scalar::Add, 1.0, *xO); + y->applyScalar(scalar::Add, 2.0, *yO); + + STORE_2_RESULTS(*xO, *yO); + + return Status::OK(); +} +DECLARE_SYN(TestOp2i2o, testop2i2o); + +DECLARE_TYPES(testop2i2o) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/tests/testreduction.cpp b/libnd4j/include/ops/declarable/generic/tests/testreduction.cpp index a0749ed7f6e0..a85dbf906219 100644 --- a/libnd4j/include/ops/declarable/generic/tests/testreduction.cpp +++ b/libnd4j/include/ops/declarable/generic/tests/testreduction.cpp @@ -24,20 +24,18 @@ #include namespace sd { - namespace ops { - REDUCTION_OP_IMPL(testreduction, 1, 1, false, 0, -1) { - auto z = OUTPUT_VARIABLE(0); +namespace ops { +REDUCTION_OP_IMPL(testreduction, 1, 1, false, 0, -1) { + auto z = OUTPUT_VARIABLE(0); -// STORE_RESULT(*z); - return Status::OK(); - } + // STORE_RESULT(*z); + return Status::OK(); +} - DECLARE_TYPES(testreduction) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +DECLARE_TYPES(testreduction) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp b/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp index 2fe7ddef62b3..cc7f0c040434 100644 --- a/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp +++ b/libnd4j/include/ops/declarable/generic/thrid_party/firas_sparse.cpp @@ -26,85 +26,82 @@ #ifndef LIBND4J_THIRD_PARTY_H #define LIBND4J_THIRD_PARTY_H -#include -#include +#include #include #include -#include +#include #include #include -#include - -namespace sd { - namespace ops { - - - /** - * This op is special one, and suited only for ProjectionLayer by @firasdib - * - * - * - * @tparam T - */ - CUSTOM_OP_IMPL(firas_sparse, 1, 1, false, 0, -1) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_NULLIFIED(0); +#include - int batchSize = x->sizeAt(0); - int numColumns = x->sizeAt(1); +#include - std::vector indices(block.getIArguments()); - std::map sparse2dense; +namespace sd { +namespace ops { +/** + * This op is special one, and suited only for ProjectionLayer by @firasdib + * + * + * + * @tparam T + */ +CUSTOM_OP_IMPL(firas_sparse, 1, 1, false, 0, -1) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_NULLIFIED(0); - int cnt = 0; - for (auto v: indices) { - std::pair pair(v, cnt++); - sparse2dense.insert(pair); - } + int batchSize = x->sizeAt(0); + int numColumns = x->sizeAt(1); - ResultSet rows = x->allTensorsAlongDimension({1}); + std::vector indices(block.getIArguments()); + std::map sparse2dense; - //PRAGMA_OMP_PARALLEL_FOR - for (int r = 0; r < batchSize; r++) { - auto row = rows.at(r); + int cnt = 0; + for (auto v : indices) { + std::pair pair(v, cnt++); + sparse2dense.insert(pair); + } - for (int e = 0; e < numColumns; e += 2) { - int idx = row.e(e); - if (idx < 0) - break; + ResultSet rows = x->allTensorsAlongDimension({1}); - int denseIdx = sparse2dense.at(idx); + // PRAGMA_OMP_PARALLEL_FOR + for (int r = 0; r < batchSize; r++) { + auto row = rows.at(r); + for (int e = 0; e < numColumns; e += 2) { + int idx = row.e(e); + if (idx < 0) break; - float value = row.e(e); - float current = z->e(r, denseIdx); - z->p(r, denseIdx, value + current); - } - } + int denseIdx = sparse2dense.at(idx); + float value = row.e(e); + float current = z->e(r, denseIdx); + z->p(r, denseIdx, value + current); + } + } - //STORE_RESULT(*z); + // STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(firas_sparse) { - auto inP = inputShape->at(0); +DECLARE_SHAPE_FN(firas_sparse) { + auto inP = inputShape->at(0); - std::vector shape({shape::shapeOf(inP)[0], (Nd4jLong) block.numI()}); - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inP), 'c', shape); - return SHAPELIST(newShape); - } + std::vector shape({shape::shapeOf(inP)[0], (Nd4jLong)block.numI()}); + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inP), 'c', shape); + return SHAPELIST(newShape); +} - DECLARE_TYPES(firas_sparse) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +DECLARE_TYPES(firas_sparse) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif -#endif //LIBND4J_THIRD_PARTY_H +#endif // LIBND4J_THIRD_PARTY_H diff --git a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp index 607980f0d8f8..8b014c59411f 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp @@ -40,90 +40,128 @@ limitations under the License. #include namespace sd { -namespace ops { - +namespace ops { CUSTOM_OP_IMPL(batch_to_space, 2, 1, false, 0, 1) { - - // [bS*blockSize*blockSize, H/blockSize, W/blockSize, iC] is rearranged/permuted to [bS, oH, oW, iC] - // oH = H - cropTop - cropBottom - // oW = W - cropLeft - cropRight - - auto input = INPUT_VARIABLE(0); - auto crop = INPUT_VARIABLE(1); - - auto output = OUTPUT_VARIABLE(0); - - const uint blockSize = INT_ARG(0); - REQUIRE_TRUE(blockSize >= 2, 0, "BatchToSpace: integer parameter block_size must be >= 2, but got %i instead", blockSize); - - const int rank = input->rankOf(); - const int dim0 = input->sizeAt(0); - REQUIRE_TRUE(rank == 4, 0, "BatchToSpace: rank of input array must be equal 4, but got %i instead", rank); - REQUIRE_TRUE(dim0 % (blockSize * blockSize) == 0, 0, "BatchToSpace: first dimension of input array must be divisible by blockSize * blockSize (that is by %i), but got first dimension equal to %i", blockSize * blockSize, dim0); - - if(crop->sizeAt(0) != 2 || crop->sizeAt(1) != 2) - REQUIRE_TRUE(false, 0, "BatchToSpace: operation expects crop shape to be {2, 2}, but got %s instead", ShapeUtils::shapeAsString(crop).c_str()); - - const uint cropBottom = crop->e(0,0); - const uint cropTop = crop->e(0,1); - const uint cropLeft = crop->e(1,0); - const uint cropRight = crop->e(1,1); - - const int oH = input->sizeAt(1) * blockSize - cropBottom - cropTop; // top and bottom - const int oW = input->sizeAt(2) * blockSize - cropLeft - cropRight; // left and right - REQUIRE_TRUE(oH >= 0, 0, "BatchToSpace: crop top/bottom values are too big and cause negative output height dimension !"); - REQUIRE_TRUE(oW >= 0, 0, "BatchToSpace: crop left/right values are too big and cause negative output width dimension !"); - - if (shape::strideDescendingCAscendingF(input->shapeInfo())) - helpers::batchToSpace(block.launchContext(), *input, *output, cropBottom, cropTop, cropLeft, cropRight, blockSize); - else - helpers::batchToSpace(block.launchContext(), input->dup(), *output, cropBottom, cropTop, cropLeft, cropRight, blockSize); - - return Status::OK(); + // [bS*blockSize*blockSize, H/blockSize, W/blockSize, iC] is + // rearranged/permuted to [bS, oH, oW, iC] oH = H - cropTop - cropBottom oW = + // W - cropLeft - cropRight + + auto input = INPUT_VARIABLE(0); + auto crop = INPUT_VARIABLE(1); + + auto output = OUTPUT_VARIABLE(0); + + const uint blockSize = INT_ARG(0); + REQUIRE_TRUE(blockSize >= 2, 0, + "BatchToSpace: integer parameter block_size must be >= 2, but " + "got %i instead", + blockSize); + + const int rank = input->rankOf(); + const int dim0 = input->sizeAt(0); + REQUIRE_TRUE( + rank == 4, 0, + "BatchToSpace: rank of input array must be equal 4, but got %i instead", + rank); + REQUIRE_TRUE(dim0 % (blockSize * blockSize) == 0, 0, + "BatchToSpace: first dimension of input array must be divisible " + "by blockSize * blockSize (that is by %i), but got first " + "dimension equal to %i", + blockSize * blockSize, dim0); + + if (crop->sizeAt(0) != 2 || crop->sizeAt(1) != 2) + REQUIRE_TRUE(false, 0, + "BatchToSpace: operation expects crop shape to be {2, 2}, but " + "got %s instead", + ShapeUtils::shapeAsString(crop).c_str()); + + const uint cropBottom = crop->e(0, 0); + const uint cropTop = crop->e(0, 1); + const uint cropLeft = crop->e(1, 0); + const uint cropRight = crop->e(1, 1); + + const int oH = + input->sizeAt(1) * blockSize - cropBottom - cropTop; // top and bottom + const int oW = + input->sizeAt(2) * blockSize - cropLeft - cropRight; // left and right + REQUIRE_TRUE(oH >= 0, 0, + "BatchToSpace: crop top/bottom values are too big and cause " + "negative output height dimension !"); + REQUIRE_TRUE(oW >= 0, 0, + "BatchToSpace: crop left/right values are too big and cause " + "negative output width dimension !"); + + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::batchToSpace(block.launchContext(), *input, *output, cropBottom, + cropTop, cropLeft, cropRight, blockSize); + else + helpers::batchToSpace(block.launchContext(), input->dup(), *output, + cropBottom, cropTop, cropLeft, cropRight, blockSize); + + return Status::OK(); } //////////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(batch_to_space) { - - getOpDescriptor()->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(true); } //////////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(batch_to_space) { - - auto inputShapeInfo = inputShape->at(0); - auto cropShapeInfo = inputShape->at(1); - - const uint blockSize = INT_ARG(0); - REQUIRE_TRUE(blockSize >= 2, 0, "BatchToSpace: integer parameter block_size must be >= 2, but got %i instead", blockSize); - - const int rank = inputShapeInfo[0]; - const int dim0 = inputShapeInfo[1]; - REQUIRE_TRUE(rank == 4, 0, "BatchToSpace: rank of input array must be equal 4, but got %i instead", rank); - REQUIRE_TRUE(dim0 % (blockSize * blockSize) == 0, 0, "BatchToSpace: first dimension of input array must be divisible by blockSize * blockSize (that is by %i), but got first dimension equal to %i", blockSize * blockSize, dim0); - - if(cropShapeInfo[1] != 2 || cropShapeInfo[2] != 2) - REQUIRE_TRUE(false, 0, "BatchToSpace: operation expects crop shape to be {2, 2}, but got %s instead", ShapeUtils::shapeAsString(cropShapeInfo).c_str()); - - const uint cropBottom = INPUT_VARIABLE(1)->e(0,0); - const uint cropTop = INPUT_VARIABLE(1)->e(0,1); - const uint cropLeft = INPUT_VARIABLE(1)->e(1,0); - const uint cropRight = INPUT_VARIABLE(1)->e(1,1); - - const int oH = inputShapeInfo[2] * blockSize - cropTop - cropBottom; // top and bottom - const int oW = inputShapeInfo[3] * blockSize - cropLeft - cropRight; // left and right - REQUIRE_TRUE(oH >= 0, 0, "BatchToSpace: crop top/bottom values are too big and cause negative output height dimension !"); - REQUIRE_TRUE(oW >= 0, 0, "BatchToSpace: crop left/right values are too big and cause negative output width dimension !"); - - // we always give out C order here - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShapeInfo), 'c', {dim0 / (blockSize * blockSize), oH, oW, inputShapeInfo[4]})); + auto inputShapeInfo = inputShape->at(0); + auto cropShapeInfo = inputShape->at(1); + + const uint blockSize = INT_ARG(0); + REQUIRE_TRUE(blockSize >= 2, 0, + "BatchToSpace: integer parameter block_size must be >= 2, but " + "got %i instead", + blockSize); + + const int rank = inputShapeInfo[0]; + const int dim0 = inputShapeInfo[1]; + REQUIRE_TRUE( + rank == 4, 0, + "BatchToSpace: rank of input array must be equal 4, but got %i instead", + rank); + REQUIRE_TRUE(dim0 % (blockSize * blockSize) == 0, 0, + "BatchToSpace: first dimension of input array must be divisible " + "by blockSize * blockSize (that is by %i), but got first " + "dimension equal to %i", + blockSize * blockSize, dim0); + + if (cropShapeInfo[1] != 2 || cropShapeInfo[2] != 2) + REQUIRE_TRUE(false, 0, + "BatchToSpace: operation expects crop shape to be {2, 2}, but " + "got %s instead", + ShapeUtils::shapeAsString(cropShapeInfo).c_str()); + + const uint cropBottom = INPUT_VARIABLE(1)->e(0, 0); + const uint cropTop = INPUT_VARIABLE(1)->e(0, 1); + const uint cropLeft = INPUT_VARIABLE(1)->e(1, 0); + const uint cropRight = INPUT_VARIABLE(1)->e(1, 1); + + const int oH = + inputShapeInfo[2] * blockSize - cropTop - cropBottom; // top and bottom + const int oW = + inputShapeInfo[3] * blockSize - cropLeft - cropRight; // left and right + REQUIRE_TRUE(oH >= 0, 0, + "BatchToSpace: crop top/bottom values are too big and cause " + "negative output height dimension !"); + REQUIRE_TRUE(oW >= 0, 0, + "BatchToSpace: crop left/right values are too big and cause " + "negative output width dimension !"); + + // we always give out C order here + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inputShapeInfo), 'c', + {dim0 / (blockSize * blockSize), oH, oW, inputShapeInfo[4]})); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp index f62921cc2ef6..591e4afad040 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp @@ -40,89 +40,124 @@ limitations under the License. #include namespace sd { -namespace ops { - +namespace ops { CUSTOM_OP_IMPL(batch_to_space_nd, 3, 1, false, 0, 0) { - - // 4D example, numOfSpatialDims = 2 - two spatial dimensions - // [bS*blockShape[0]*blockShape[1], iH, iW, iC] is rearranged/permuted to [bS, iH*blockShape[0] - cropTop - cropBottom, iW*blockShape[1] - cropLeft - cropRight, iC] - - auto input = INPUT_VARIABLE(0); - auto blockShape = INPUT_VARIABLE(1); - auto crop = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(blockShape->rankOf() == 1, 0, "BatchToSpaceND: rank of blockShape array must be equal to one, but got %i instead !", blockShape->rankOf()); - - const uint numOfSpatialDims = blockShape->sizeAt(0); - - const auto product = blockShape->reduceNumber(sd::reduce::Prod).e(0); - REQUIRE_TRUE(input->sizeAt(0) % product == 0, 0, "BatchToSpaceND: first dimension of input array must be divisible by product of blockShape array elements (= %lld), but got first dimension equal to %i", product, input->sizeAt(0)); - - if(crop->sizeAt(0) != numOfSpatialDims || crop->sizeAt(1) != 2) { - const std::string expectedCropShape = "[" + std::to_string(numOfSpatialDims) + ", 2]"; // [numOfSpatialDims, 2] - REQUIRE_TRUE(false, 0, "BatchToSpaceND: operation expects padding shape to be %s, but got %s instead", expectedCropShape.c_str(), ShapeUtils::shapeAsString(crop).c_str()); - } - - // FIXME - should we use this time-consuming validation ? - for (uint i = 0; i < numOfSpatialDims; ++i) { - const auto cropLeft = crop->e(i,0); - const auto cropRight = crop->e(i,1); - const auto outSpatialDim = input->sizeAt(i + 1) * blockShape->e(i) - cropLeft - cropRight; - REQUIRE_TRUE(outSpatialDim >= 0, 0, "BatchToSpaceND: crop left/right values are too big and cause negative output spatial dimension/dimensions !"); - } - - if (shape::strideDescendingCAscendingF(input->shapeInfo())) - helpers::batchToSpaceND(block.launchContext(), *input, *blockShape, *crop, *output); - else - helpers::batchToSpaceND(block.launchContext(), input->dup(), *blockShape, *crop, *output); - - return Status::OK(); + // 4D example, numOfSpatialDims = 2 - two spatial dimensions + // [bS*blockShape[0]*blockShape[1], iH, iW, iC] is rearranged/permuted to [bS, + // iH*blockShape[0] - cropTop - cropBottom, iW*blockShape[1] - cropLeft - + // cropRight, iC] + + auto input = INPUT_VARIABLE(0); + auto blockShape = INPUT_VARIABLE(1); + auto crop = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(blockShape->rankOf() == 1, 0, + "BatchToSpaceND: rank of blockShape array must be equal to one, " + "but got %i instead !", + blockShape->rankOf()); + + const uint numOfSpatialDims = blockShape->sizeAt(0); + + const auto product = + blockShape->reduceNumber(sd::reduce::Prod).e(0); + REQUIRE_TRUE(input->sizeAt(0) % product == 0, 0, + "BatchToSpaceND: first dimension of input array must be " + "divisible by product of blockShape array elements (= %lld), " + "but got first dimension equal to %i", + product, input->sizeAt(0)); + + if (crop->sizeAt(0) != numOfSpatialDims || crop->sizeAt(1) != 2) { + const std::string expectedCropShape = "[" + + std::to_string(numOfSpatialDims) + + ", 2]"; // [numOfSpatialDims, 2] + REQUIRE_TRUE(false, 0, + "BatchToSpaceND: operation expects padding shape to be %s, " + "but got %s instead", + expectedCropShape.c_str(), + ShapeUtils::shapeAsString(crop).c_str()); + } + + // FIXME - should we use this time-consuming validation ? + for (uint i = 0; i < numOfSpatialDims; ++i) { + const auto cropLeft = crop->e(i, 0); + const auto cropRight = crop->e(i, 1); + const auto outSpatialDim = + input->sizeAt(i + 1) * blockShape->e(i) - cropLeft - + cropRight; + REQUIRE_TRUE(outSpatialDim >= 0, 0, + "BatchToSpaceND: crop left/right values are too big and cause " + "negative output spatial dimension/dimensions !"); + } + + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::batchToSpaceND(block.launchContext(), *input, *blockShape, *crop, + *output); + else + helpers::batchToSpaceND(block.launchContext(), input->dup(), *blockShape, + *crop, *output); + + return Status::OK(); } //////////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(batch_to_space_nd) { - - getOpDescriptor()->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setSameMode(true); } //////////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(batch_to_space_nd) { - - auto inputShapeInfo = inputShape->at(0); - auto blockShapeInfo = inputShape->at(1); - auto cropShapeInfo = inputShape->at(2); - - REQUIRE_TRUE(blockShapeInfo[0] == 1, 0, "BatchToSpaceND: rank of blockShape array must be equal to one, but got %i instead !", blockShapeInfo[0]); - - const auto product = INPUT_VARIABLE(1)->reduceNumber(sd::reduce::Prod).e(0); - REQUIRE_TRUE(inputShapeInfo[1] % product == 0, 0, "BatchToSpaceND: first dimension of input array must be divisible by product of blockShape array elements (= %lld), but got first dimension equal to %i", product, inputShapeInfo[1]); - - const auto numOfSpatialDims = blockShapeInfo[1]; - - if(cropShapeInfo[1] != numOfSpatialDims || cropShapeInfo[2] != 2) { - const std::string expectedCropShape = "[" + std::to_string(numOfSpatialDims) + ", 2]"; // [numOfSpatialDims, 2] - REQUIRE_TRUE(false, 0, "BatchToSpaceND: operation expects padding shape to be %s, but got %s instead", expectedCropShape.c_str(), ShapeUtils::shapeAsString(cropShapeInfo).c_str()); - } - - - std::vector outShape(inputShapeInfo + 1, inputShapeInfo + 1 + inputShapeInfo[0]); - - outShape[0] /= product; - - for (uint i = 0; i < numOfSpatialDims; ++i) - outShape[i + 1] = outShape[i + 1] * INPUT_VARIABLE(1)->e(i) - INPUT_VARIABLE(2)->e(i,0) - INPUT_VARIABLE(2)->e(i,1); - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShapeInfo), 'c', outShape)); + auto inputShapeInfo = inputShape->at(0); + auto blockShapeInfo = inputShape->at(1); + auto cropShapeInfo = inputShape->at(2); + + REQUIRE_TRUE(blockShapeInfo[0] == 1, 0, + "BatchToSpaceND: rank of blockShape array must be equal to one, " + "but got %i instead !", + blockShapeInfo[0]); + + const auto product = + INPUT_VARIABLE(1)->reduceNumber(sd::reduce::Prod).e(0); + REQUIRE_TRUE(inputShapeInfo[1] % product == 0, 0, + "BatchToSpaceND: first dimension of input array must be " + "divisible by product of blockShape array elements (= %lld), " + "but got first dimension equal to %i", + product, inputShapeInfo[1]); + + const auto numOfSpatialDims = blockShapeInfo[1]; + + if (cropShapeInfo[1] != numOfSpatialDims || cropShapeInfo[2] != 2) { + const std::string expectedCropShape = "[" + + std::to_string(numOfSpatialDims) + + ", 2]"; // [numOfSpatialDims, 2] + REQUIRE_TRUE(false, 0, + "BatchToSpaceND: operation expects padding shape to be %s, " + "but got %s instead", + expectedCropShape.c_str(), + ShapeUtils::shapeAsString(cropShapeInfo).c_str()); + } + + std::vector outShape(inputShapeInfo + 1, + inputShapeInfo + 1 + inputShapeInfo[0]); + + outShape[0] /= product; + + for (uint i = 0; i < numOfSpatialDims; ++i) + outShape[i + 1] = outShape[i + 1] * INPUT_VARIABLE(1)->e(i) - + INPUT_VARIABLE(2)->e(i, 0) - + INPUT_VARIABLE(2)->e(i, 1); + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inputShapeInfo), 'c', outShape)); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp b/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp index 2bfe06dd9f40..3d8c750c3bf7 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp @@ -22,31 +22,31 @@ #if NOT_EXCLUDED(OP_clipbyavgnorm) #include -#include +#include namespace sd { -namespace ops { +namespace ops { CONFIGURABLE_OP_IMPL(clipbyavgnorm, 1, 1, true, 1, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + const bool isInplace = block.isInplace(); + auto ts = NDArrayFactory::create(T_ARG(0), block.launchContext()); - const bool isInplace = block.isInplace(); - auto ts = NDArrayFactory::create(T_ARG(0), block.launchContext()); + helpers::clipByAveraged(block.launchContext(), *input, *output, + block.getIArguments(), ts, isInplace); - helpers::clipByAveraged(block.launchContext(), *input, *output, block.getIArguments(), ts, isInplace); - - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(clipbyavgnorm) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - -} +DECLARE_TYPES(clipbyavgnorm) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/clip_by_global_norm.cpp b/libnd4j/include/ops/declarable/generic/transforms/clip_by_global_norm.cpp index 99a01d39034d..e642d0c920d6 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/clip_by_global_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/clip_by_global_norm.cpp @@ -25,48 +25,47 @@ #include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(clip_by_global_norm, 1, 2, true, 1, 0) { + std::vector inputs(block.width()); + std::vector outputs(block.width() + 1); + for (size_t i = 0; i < inputs.size(); ++i) { + inputs[i] = INPUT_VARIABLE(i); + outputs[i] = OUTPUT_VARIABLE(i); + } + outputs[inputs.size()] = OUTPUT_VARIABLE(inputs.size()); + double clipNorm = T_ARG(0); + bool isInplace = block.isInplace(); + helpers::clipByGlobalNorm(block.launchContext(), inputs, clipNorm, + block.workspace(), outputs, isInplace); - std::vector inputs(block.width()); - std::vector outputs(block.width() + 1); - for (size_t i = 0; i < inputs.size(); ++i) { - inputs[i] = INPUT_VARIABLE(i); - outputs[i] = OUTPUT_VARIABLE(i); - } - outputs[inputs.size()] = OUTPUT_VARIABLE(inputs.size()); - double clipNorm = T_ARG(0); - bool isInplace = block.isInplace(); - helpers::clipByGlobalNorm(block.launchContext(), inputs, clipNorm, block.workspace(), outputs, isInplace); - - return Status::OK(); + return Status::OK(); } DECLARE_SHAPE_FN(clip_by_global_norm) { + auto shapeList = SHAPELIST(); - auto shapeList = SHAPELIST(); - - for (int e = 0; e < block.width(); e++) { - auto in = inputShape->at(e); - - Nd4jLong* newShape; - COPY_SHAPE(in, newShape); - shapeList->push_back(CONSTANT(newShape)); - } - - shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0)))); - return shapeList; -} - - DECLARE_TYPES(clip_by_global_norm) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } + for (int e = 0; e < block.width(); e++) { + auto in = inputShape->at(e); + Nd4jLong* newShape; + COPY_SHAPE(in, newShape); + shapeList->push_back(CONSTANT(newShape)); + } + shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo( + ArrayOptions::dataType(inputShape->at(0)))); + return shapeList; } + +DECLARE_TYPES(clip_by_global_norm) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp b/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp index d062d47105c1..a699ee3f8f65 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp @@ -22,58 +22,60 @@ #if NOT_EXCLUDED(OP_clipbynorm) #include -#include +#include namespace sd { -namespace ops { +namespace ops { - CONFIGURABLE_OP_IMPL(clipbynorm, 1, 1, true, 1, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +CONFIGURABLE_OP_IMPL(clipbynorm, 1, 1, true, 1, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - const auto clipNorm = NDArrayFactory::create(input->dataType(), T_ARG(0), block.launchContext()); - const bool isInplace = block.isInplace(); + const auto clipNorm = NDArrayFactory::create(input->dataType(), T_ARG(0), + block.launchContext()); + const bool isInplace = block.isInplace(); - helpers::clipByNorm(block.launchContext(), *input, *output, block.getIArguments(), clipNorm, isInplace); + helpers::clipByNorm(block.launchContext(), *input, *output, + block.getIArguments(), clipNorm, isInplace); - return Status::OK(); - } - - - CUSTOM_OP_IMPL(clipbynorm_bp, 2, 1, false, 1, 0) { - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); + return Status::OK(); +} - auto gradI = OUTPUT_VARIABLE(0); - const auto clipNorm = NDArrayFactory::create(T_ARG(0)); +CUSTOM_OP_IMPL(clipbynorm_bp, 2, 1, false, 1, 0) { + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); - helpers::clipByNormBP(block.launchContext(), *input, *gradO, *gradI, block.getIArguments(), clipNorm); + auto gradI = OUTPUT_VARIABLE(0); + const auto clipNorm = NDArrayFactory::create(T_ARG(0)); - return Status::OK(); - } + helpers::clipByNormBP(block.launchContext(), *input, *gradO, *gradI, + block.getIArguments(), clipNorm); - DECLARE_SHAPE_FN(clipbynorm_bp) { - auto inShapeInfo = inputShape->at(0); + return Status::OK(); +} - Nd4jLong *newShape = nullptr; - COPY_SHAPE(inShapeInfo, newShape); +DECLARE_SHAPE_FN(clipbynorm_bp) { + auto inShapeInfo = inputShape->at(0); - return SHAPELIST(CONSTANT(newShape)); - } + Nd4jLong *newShape = nullptr; + COPY_SHAPE(inShapeInfo, newShape); - DECLARE_TYPES(clipbynorm) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } + return SHAPELIST(CONSTANT(newShape)); +} - DECLARE_TYPES(clipbynorm_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedInputTypes(1, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(clipbynorm) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } + +DECLARE_TYPES(clipbynorm_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/clip_by_value.cpp b/libnd4j/include/ops/declarable/generic/transforms/clip_by_value.cpp index 4275e4837b15..4e7e18ea96d0 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/clip_by_value.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/clip_by_value.cpp @@ -25,30 +25,34 @@ #include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(clipbyvalue, 1, 1, true, 2, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // FIXME: extra args!!! - auto left = T_ARG(0); - auto right = T_ARG(1); - - REQUIRE_TRUE(left < right, 0, "clip_by_value: left bound should be lesser than right. But %f >= %f given.", left, right); - //input->applyTransform(transform::ClipByValue, output, block.getTArguments()->data()); - helpers::clipByValue(block.launchContext(), *input, left, right, *output); - //STORE_RESULT(*output); - - return Status::OK(); - } - DECLARE_SYN(ClipByValue, clipbyvalue); - - DECLARE_TYPES(clipbyvalue) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - } +namespace ops { +CONFIGURABLE_OP_IMPL(clipbyvalue, 1, 1, true, 2, 0) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // FIXME: extra args!!! + auto left = T_ARG(0); + auto right = T_ARG(1); + + REQUIRE_TRUE(left < right, 0, + "clip_by_value: left bound should be lesser than right. But %f " + ">= %f given.", + left, right); + // input->applyTransform(transform::ClipByValue, output, + // block.getTArguments()->data()); + helpers::clipByValue(block.launchContext(), *input, left, right, *output); + // STORE_RESULT(*output); + + return Status::OK(); } +DECLARE_SYN(ClipByValue, clipbyvalue); + +DECLARE_TYPES(clipbyvalue) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp index 1048f19935b0..6cb0a777dbed 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp @@ -19,420 +19,450 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include +#include +#include -namespace sd { -namespace ops { +#include +namespace sd { +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) { - - REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided"); - - const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); - - const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); - - // first of all take into account possible presence of empty arrays - // also if scalar is present -> copy its value to vector with length=1 - std::vector nonEmptyArrs; - std::vector arrsToDelete; - int index = 0; - bool allOfSameType = true; - auto rankOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0; - auto typeOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : DataType::FLOAT32; - - for(int i = 0; i < numOfInArrs; ++i) { - auto input = INPUT_VARIABLE(i); - auto currentRank = input->rankOf(); - -// TODO: follow two lines are in accordance to current tf.concat spec. Commented for compatibility with legacy -// REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must be greater 0, but is %lld instead.", i, currentRank); -// REQUIRE_TRUE(rankOfFirstArr == currentRank, 0, "Number of dimensions in concat should be equals, but for %i input variable %lld != %lld appears.", i, currentRank, rankOfFirstArr); - if(!input->isEmpty()) { - - allOfSameType &= (typeOfFirstArr == input->dataType()); - - if(input->rankOf() == 0) { - auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext()); - vec->assign(input); - nonEmptyArrs.push_back(vec); - arrsToDelete.push_back(index); - } - else{ - nonEmptyArrs.push_back(input); - } - ++index; - } - } - - const int numOfNonEmptyArrs = nonEmptyArrs.size(); - - if(numOfNonEmptyArrs == 0){ - //All inputs are empty arrays -> return empty, mainly for TF import compatibility (no op) - REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "CONCAT op: If all input variables are empty, output must be empty"); - return Status::OK(); - } - - const int rank = nonEmptyArrs[0]->rankOf(); // look up to first non-empty array - int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e(0) : INT_ARG(0); - if(axis < 0){ - axis += rank; - } - - // ******** input validation ******** // - REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !"); - REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis); - - for(int i = 1; i < numOfNonEmptyArrs; ++i) - REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, "CONCAT op: all input arrays must have the same rank !"); - - for(int i = 1; i < numOfNonEmptyArrs; ++i) { - for(int dim = 0; dim < rank; ++dim) - if(dim != axis) - REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !"); + REQUIRE_TRUE(block.width() > 0, 0, + "CONCAT op: No input arrays were provided"); + + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); + + const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); + + // first of all take into account possible presence of empty arrays + // also if scalar is present -> copy its value to vector with length=1 + std::vector nonEmptyArrs; + std::vector arrsToDelete; + int index = 0; + bool allOfSameType = true; + auto rankOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0; + auto typeOfFirstArr = + block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : DataType::FLOAT32; + + for (int i = 0; i < numOfInArrs; ++i) { + auto input = INPUT_VARIABLE(i); + auto currentRank = input->rankOf(); + + // TODO: follow two lines are in accordance to current tf.concat spec. + // Commented for compatibility with legacy + // REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must + // be greater 0, but is %lld instead.", i, currentRank); + // REQUIRE_TRUE(rankOfFirstArr == currentRank, 0, "Number of + // dimensions in concat should be equals, but for %i input variable + // %lld != %lld appears.", i, currentRank, rankOfFirstArr); + if (!input->isEmpty()) { + allOfSameType &= (typeOfFirstArr == input->dataType()); + + if (input->rankOf() == 0) { + auto vec = + new NDArray('c', {1}, input->dataType(), block.launchContext()); + vec->assign(input); + nonEmptyArrs.push_back(vec); + arrsToDelete.push_back(index); + } else { + nonEmptyArrs.push_back(input); + } + ++index; } - // ******** end of input validation ******** // - - auto output = OUTPUT_VARIABLE(0); - - if(numOfNonEmptyArrs == 1) - output->assign(nonEmptyArrs[0]); - else - helpers::concat(block.launchContext(), nonEmptyArrs, *output, axis); + } - // delete dynamically allocated vectors with length=1 - for(int index : arrsToDelete) - delete nonEmptyArrs[index]; + const int numOfNonEmptyArrs = nonEmptyArrs.size(); + if (numOfNonEmptyArrs == 0) { + // All inputs are empty arrays -> return empty, mainly for TF import + // compatibility (no op) + REQUIRE_TRUE( + OUTPUT_VARIABLE(0)->isEmpty(), 0, + "CONCAT op: If all input variables are empty, output must be empty"); return Status::OK(); + } + + const int rank = + nonEmptyArrs[0]->rankOf(); // look up to first non-empty array + int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e(0) + : INT_ARG(0); + if (axis < 0) { + axis += rank; + } + + // ******** input validation ******** // + REQUIRE_TRUE(allOfSameType, 0, + "CONCAT op: all of input arrays must have same type !"); + REQUIRE_TRUE( + 0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, + "CONCAT op: input axis must be in range [0, %i], but got %i instead!", + rank - 1, axis); + + for (int i = 1; i < numOfNonEmptyArrs; ++i) + REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, + "CONCAT op: all input arrays must have the same rank !"); + + for (int i = 1; i < numOfNonEmptyArrs; ++i) { + for (int dim = 0; dim < rank; ++dim) + if (dim != axis) + REQUIRE_TRUE( + nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, + "CONCAT op: all input arrays must have the same dimensions (except " + "those on input axis) !"); + } + // ******** end of input validation ******** // + + auto output = OUTPUT_VARIABLE(0); + + if (numOfNonEmptyArrs == 1) + output->assign(nonEmptyArrs[0]); + else + helpers::concat(block.launchContext(), nonEmptyArrs, *output, axis); + + // delete dynamically allocated vectors with length=1 + for (int index : arrsToDelete) delete nonEmptyArrs[index]; + + return Status::OK(); } - DECLARE_SYN(ParallelConcat, concat); - DECLARE_SYN(concat_v2, concat); - DECLARE_SYN(concatv2, concat); +DECLARE_SYN(ParallelConcat, concat); +DECLARE_SYN(concat_v2, concat); +DECLARE_SYN(concatv2, concat); - DECLARE_TYPES(concat) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY); - // ->setSameMode(true); - } +DECLARE_TYPES(concat) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY); + // ->setSameMode(true); +} ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(concat) { - - REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided"); - - const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); - - const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); - - // first of all take into account possible presence of empty arrays - // also if scalar is present -> use the shape of vector with length=1 instead - ShapeList arrShapes; - std::vector shapesToDelete; - int index = 0; - for(int i = 0; i < numOfInArrs; ++i) { - - if(inputShape->at(i)[0] == 0) { - if (shape::isEmpty(inputShape->at(i))) - arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, INPUT_VARIABLE(0)->dataType())); - else - arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType())); - } - else{ - arrShapes.push_back(inputShape->at(i)); - } - ++index; + REQUIRE_TRUE(block.width() > 0, 0, + "CONCAT op: No input arrays were provided"); + + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); + + const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); + + // first of all take into account possible presence of empty arrays + // also if scalar is present -> use the shape of vector with length=1 instead + ShapeList arrShapes; + std::vector shapesToDelete; + int index = 0; + for (int i = 0; i < numOfInArrs; ++i) { + if (inputShape->at(i)[0] == 0) { + if (shape::isEmpty(inputShape->at(i))) + arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo( + 0, INPUT_VARIABLE(0)->dataType())); + else + arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo( + 1, INPUT_VARIABLE(0)->dataType())); + } else { + arrShapes.push_back(inputShape->at(i)); } - - const int numOfNonEmptyArrs = arrShapes.size(); - - const int rank = shape::rank(arrShapes.at(0)); - - int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e(0) : INT_ARG(0); - if(axis < 0){ - axis += rank; - } - - // ******** input validation ******** // - REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis); - - for(int i = 1; i < numOfNonEmptyArrs; ++i) - REQUIRE_TRUE(shape::rank(arrShapes.at(i)) == rank, 0, "CONCAT op: all input arrays must have the same rank !"); - - for(int i = 1; i < numOfNonEmptyArrs; ++i) { - for(int dim = 0; dim < rank; ++dim) - if(dim != axis) - REQUIRE_TRUE(arrShapes.at(i)[dim+1] == arrShapes.at(0)[dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !"); - } - // ******** end of input validation ******** // - - - Nd4jLong* outShapeInfo(nullptr); - COPY_SHAPE(arrShapes.at(0), outShapeInfo); - - // case when we have only one input array - if(numOfNonEmptyArrs == 1) { - ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes.at(0), shape::order(arrShapes.at(0))); - return SHAPELIST(CONSTANT(outShapeInfo)); - } - - for(int i = 1; i < numOfNonEmptyArrs; ++i) - outShapeInfo[axis + 1] += arrShapes.at(i)[axis + 1]; - - ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes.at(0), shape::order(arrShapes.at(0))); - - // delete dynamically allocated vectors shapes with length=1 -// for(int index : shapesToDelete) -// RELEASE(arrShapes[index], block.workspace()); - - auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo)); - RELEASE(outShapeInfo, block.workspace()); - return SHAPELIST(result); + ++index; + } + + const int numOfNonEmptyArrs = arrShapes.size(); + + const int rank = shape::rank(arrShapes.at(0)); + + int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e(0) + : INT_ARG(0); + if (axis < 0) { + axis += rank; + } + + // ******** input validation ******** // + REQUIRE_TRUE( + 0 <= axis && axis < rank, 0, + "CONCAT op: input axis must be in range [0, %i], but got %i instead!", + rank - 1, axis); + + for (int i = 1; i < numOfNonEmptyArrs; ++i) + REQUIRE_TRUE(shape::rank(arrShapes.at(i)) == rank, 0, + "CONCAT op: all input arrays must have the same rank !"); + + for (int i = 1; i < numOfNonEmptyArrs; ++i) { + for (int dim = 0; dim < rank; ++dim) + if (dim != axis) + REQUIRE_TRUE(arrShapes.at(i)[dim + 1] == arrShapes.at(0)[dim + 1], 0, + "CONCAT op: all input arrays must have the same " + "dimensions (except those on input axis) !"); + } + // ******** end of input validation ******** // + + Nd4jLong* outShapeInfo(nullptr); + COPY_SHAPE(arrShapes.at(0), outShapeInfo); + + // case when we have only one input array + if (numOfNonEmptyArrs == 1) { + ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes.at(0), + shape::order(arrShapes.at(0))); + return SHAPELIST(CONSTANT(outShapeInfo)); + } + + for (int i = 1; i < numOfNonEmptyArrs; ++i) + outShapeInfo[axis + 1] += arrShapes.at(i)[axis + 1]; + + ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes.at(0), + shape::order(arrShapes.at(0))); + + // delete dynamically allocated vectors shapes with length=1 + // for(int index : shapesToDelete) + // RELEASE(arrShapes[index], block.workspace()); + + auto result = ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(outShapeInfo)); + RELEASE(outShapeInfo, block.workspace()); + return SHAPELIST(result); } +// ////////////////////////////////////////////////////////////////////////// +// CUSTOM_OP_IMPL(concat, -1, 1, false, 0, -2){ +// // do something here{ +// NDArray *last = INPUT_VARIABLE((int) block.width() - 1); - // ////////////////////////////////////////////////////////////////////////// - // CUSTOM_OP_IMPL(concat, -1, 1, false, 0, -2){ - // // do something here{ - // NDArray *last = INPUT_VARIABLE((int) block.width() - 1); - - // int _dimension = 0; - // if (block.numI() > 0) - // _dimension = INT_ARG(0); - // else { - // _dimension = (int) last->e(0); - // } +// int _dimension = 0; +// if (block.numI() > 0) +// _dimension = INT_ARG(0); +// else { +// _dimension = (int) last->e(0); +// } - // // we want to ensure that all - // NDArray *first = nullptr; - // auto output = OUTPUT_VARIABLE(0); +// // we want to ensure that all +// NDArray *first = nullptr; +// auto output = OUTPUT_VARIABLE(0); - // int elements = 0; +// int elements = 0; - // for (int e = 0; e < block.width(); e++) { - // auto arr = INPUT_VARIABLE(e); - // if (!arr->isEmpty()) - // elements++; +// for (int e = 0; e < block.width(); e++) { +// auto arr = INPUT_VARIABLE(e); +// if (!arr->isEmpty()) +// elements++; - // // we must find first non-empty element here - // if (!arr->isEmpty() && first == nullptr) - // first = arr; - // } +// // we must find first non-empty element here +// if (!arr->isEmpty() && first == nullptr) +// first = arr; +// } - // REQUIRE_TRUE(first != nullptr, 0, "Concat: at least 1 non-empty input required!"); +// REQUIRE_TRUE(first != nullptr, 0, "Concat: at least 1 non-empty input +// required!"); - // // it's possible to get into situation when your input has only 1 input. That's just assign - // if (elements == 1) { - // output->assign(first); - // return Status::OK(); - // } +// // it's possible to get into situation when your input has only 1 input. +// That's just assign if (elements == 1) { +// output->assign(first); +// return Status::OK(); +// } - // bool oldScalars = first->rankOf() == 2 && first->isScalar(); +// bool oldScalars = first->rankOf() == 2 && first->isScalar(); - // auto buffers = new Nd4jPointer[elements]; - // auto shapes = new Nd4jPointer[elements]; +// auto buffers = new Nd4jPointer[elements]; +// auto shapes = new Nd4jPointer[elements]; - // buffers[0] = (Nd4jPointer) first->buffer(); - // shapes[0] = (Nd4jPointer) first->shapeInfo(); +// buffers[0] = (Nd4jPointer) first->buffer(); +// shapes[0] = (Nd4jPointer) first->shapeInfo(); - // if (_dimension < 0) - // _dimension += first->rankOf(); +// if (_dimension < 0) +// _dimension += first->rankOf(); - // if (sd::Environment::getInstance()->isDebugAndVerbose()) { - // printf("Shape %i: ", 0); - // shape::printShapeInfoLinear((Nd4jLong *) shapes[0]); - // } +// if (sd::Environment::getInstance()->isDebugAndVerbose()) { +// printf("Shape %i: ", 0); +// shape::printShapeInfoLinear((Nd4jLong *) shapes[0]); +// } - // int er = 0; - // for (int e = 0; e < block.width(); e++) { - // Variable *var = block.variable(e); - // auto array = var->getNDArray(); +// int er = 0; +// for (int e = 0; e < block.width(); e++) { +// Variable *var = block.variable(e); +// auto array = var->getNDArray(); - // if (array->isEmpty()) - // continue; +// if (array->isEmpty()) +// continue; - // buffers[er] = reinterpret_cast(array->buffer()); - // shapes[er++] = reinterpret_cast(array->shapeInfo()); +// buffers[er] = reinterpret_cast(array->buffer()); +// shapes[er++] = reinterpret_cast(array->shapeInfo()); - // oldScalars &= array->rankOf() == 2 && array->isScalar(); +// oldScalars &= array->rankOf() == 2 && array->isScalar(); - // if (sd::Environment::getInstance()->isDebugAndVerbose()) { - // printf("Shape %i: ", e); - // shape::printShapeInfoLinear(array->shapeInfo()); - // } - // } - // if (sd::Environment::getInstance()->isDebugAndVerbose()) - // fflush(stdout); +// if (sd::Environment::getInstance()->isDebugAndVerbose()) { +// printf("Shape %i: ", e); +// shape::printShapeInfoLinear(array->shapeInfo()); +// } +// } +// if (sd::Environment::getInstance()->isDebugAndVerbose()) +// fflush(stdout); - // if (oldScalars) { - // nd4j_debug("OLD_SCALARS!\n",""); - // _dimension = 1; - // } +// if (oldScalars) { +// nd4j_debug("OLD_SCALARS!\n",""); +// _dimension = 1; +// } - // sd::SpecialMethods::concatCpuGeneric(_dimension, elements, buffers, shapes, output->buffer(), output->shapeInfo()); +// sd::SpecialMethods::concatCpuGeneric(_dimension, elements, buffers, +// shapes, output->buffer(), output->shapeInfo()); - // STORE_RESULT(*output); +// STORE_RESULT(*output); - // if (sd::Environment::getInstance()->isDebugAndVerbose()) - // output->printShapeInfo("Concat result shape"); +// if (sd::Environment::getInstance()->isDebugAndVerbose()) +// output->printShapeInfo("Concat result shape"); - // delete[] buffers; - // delete[] shapes; +// delete[] buffers; +// delete[] shapes; - // return ND4J_STATUS_OK; - // } +// return ND4J_STATUS_OK; +// } - // DECLARE_SYN(ParallelConcat, concat); - // DECLARE_SYN(concat_v2, concat); - // DECLARE_SYN(concatv2, concat); +// DECLARE_SYN(ParallelConcat, concat); +// DECLARE_SYN(concat_v2, concat); +// DECLARE_SYN(concatv2, concat); - // DECLARE_SHAPE_FN(concat) { - // auto inp = inputShape->at(0); - // int _dimension = INT_ARG(0); +// DECLARE_SHAPE_FN(concat) { +// auto inp = inputShape->at(0); +// int _dimension = INT_ARG(0); - // NDArray* first = nullptr; - // auto last = inputShape->at(inputShape->size() - 1); +// NDArray* first = nullptr; +// auto last = inputShape->at(inputShape->size() - 1); - // Nd4jLong elements = 0; - // Nd4jLong *newShape; +// Nd4jLong elements = 0; +// Nd4jLong *newShape; - // for (int e = 0; e < inputShape->size(); e++) { - // auto s = INPUT_VARIABLE(e); +// for (int e = 0; e < inputShape->size(); e++) { +// auto s = INPUT_VARIABLE(e); - // if (!s->isEmpty()) { - // elements++; +// if (!s->isEmpty()) { +// elements++; - // if (first == nullptr) - // first = s; - // } - // } +// if (first == nullptr) +// first = s; +// } +// } +// { // special cases for 0D concat +// bool allScalars = true; +// bool hasScalars = false; +// for (int e = 0; e < block.width(); e++) { +// auto c = INPUT_VARIABLE(e); - // { // special cases for 0D concat - // bool allScalars = true; - // bool hasScalars = false; - // for (int e = 0; e < block.width(); e++) { - // auto c = INPUT_VARIABLE(e); +// if (c->isEmpty()) +// continue; - // if (c->isEmpty()) - // continue; +// allScalars &= c->rankOf() == 0; +// hasScalars |= c->rankOf() == 0; +// } - // allScalars &= c->rankOf() == 0; - // hasScalars |= c->rankOf() == 0; - // } +// // all scalars +// if (allScalars) { +// ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(1), +// Nd4jLong); - // // all scalars - // if (allScalars) { - // ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(1), Nd4jLong); +// shape::shapeBuffer(1, &elements, newShape); +// return SHAPELIST(newShape); +// } - // shape::shapeBuffer(1, &elements, newShape); - // return SHAPELIST(newShape); - // } +// // any scalar +// if (hasScalars) { +// ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(1), +// Nd4jLong); Nd4jLong length = shape::length(inp); for (int i = 1; +// i < block.width(); i++) { +// auto c = INPUT_VARIABLE(i); +// if (c->isEmpty()) +// continue; - // // any scalar - // if (hasScalars) { - // ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(1), Nd4jLong); - // Nd4jLong length = shape::length(inp); - // for (int i = 1; i < block.width(); i++) { - // auto c = INPUT_VARIABLE(i); - // if (c->isEmpty()) - // continue; +// length += c->lengthOf(); +// } - // length += c->lengthOf(); - // } +// shape::shapeBuffer(1, &length, newShape); +// return SHAPELIST(newShape); +// } +// } - // shape::shapeBuffer(1, &length, newShape); - // return SHAPELIST(newShape); - // } - // } +// ALLOCATE(newShape, block.workspace(), +// shape::shapeInfoLength(first->shapeInfo()), Nd4jLong); +// if (_dimension < 0) +// _dimension += first->rankOf(); - // ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(first->shapeInfo()), Nd4jLong); +// std::memcpy(newShape, first->shapeInfo(), +// shape::shapeInfoByteLength(first->shapeInfo())); for (int i = 0; i < +// inputShape->size(); i++) { +// auto s = INPUT_VARIABLE(i); - // if (_dimension < 0) - // _dimension += first->rankOf(); +// // FIXME: s == first is bad, but fast. alternatively we can subtract +// first size out of result if (s->isEmpty() || s == first) +// continue; - // std::memcpy(newShape, first->shapeInfo(), shape::shapeInfoByteLength(first->shapeInfo())); - // for (int i = 0; i < inputShape->size(); i++) { - // auto s = INPUT_VARIABLE(i); +// newShape[_dimension + 1] += +// shape::shapeOf(inputShape->at(i))[_dimension]; +// } - // // FIXME: s == first is bad, but fast. alternatively we can subtract first size out of result - // if (s->isEmpty() || s == first) - // continue; +// shape::updateStrides(newShape, first->ordering()); - // newShape[_dimension + 1] += shape::shapeOf(inputShape->at(i))[_dimension]; - // } - - // shape::updateStrides(newShape, first->ordering()); - - // return SHAPELIST(newShape); - // } +// return SHAPELIST(newShape); +// } ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(concat_bp, -1, -1, false, 0, 0) { + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); - const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); - - const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); + const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); - auto epsilonNext = INPUT_VARIABLE(numOfInArrs - 1); + auto epsilonNext = INPUT_VARIABLE(numOfInArrs - 1); - auto first = INPUT_VARIABLE(0); + auto first = INPUT_VARIABLE(0); - const int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e(0) : (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + INPUT_VARIABLE(0)->rankOf()); + const int axis = + isAxisInLastArr + ? INPUT_VARIABLE(block.width() - 1)->e(0) + : (INT_ARG(0) >= 0 ? INT_ARG(0) + : INT_ARG(0) + INPUT_VARIABLE(0)->rankOf()); - int startPos = 0; + int startPos = 0; - for (int e = 0; e < numOfInArrs - 1; e++) { - auto originalChunk = INPUT_VARIABLE(e); - auto epsilonChunk = OUTPUT_VARIABLE(e); - std::vector indices(2 * epsilonNext->rankOf()); + for (int e = 0; e < numOfInArrs - 1; e++) { + auto originalChunk = INPUT_VARIABLE(e); + auto epsilonChunk = OUTPUT_VARIABLE(e); + std::vector indices(2 * epsilonNext->rankOf()); - int width = originalChunk->sizeAt(axis); + int width = originalChunk->sizeAt(axis); - for (int e = 0; e < epsilonNext->rankOf(); e++) { - if (e == axis) - indices[2*e + 1] = (indices[2*e] = startPos) + width; - else - indices[2*e + 1] = indices[2*e] = 0; - } + for (int e = 0; e < epsilonNext->rankOf(); e++) { + if (e == axis) + indices[2 * e + 1] = (indices[2 * e] = startPos) + width; + else + indices[2 * e + 1] = indices[2 * e] = 0; + } - auto subarray = (*epsilonNext)(indices, true); - epsilonChunk->assign(subarray); + auto subarray = (*epsilonNext)(indices, true); + epsilonChunk->assign(subarray); - startPos += width; - } + startPos += width; + } - return ND4J_STATUS_OK; + return ND4J_STATUS_OK; } DECLARE_TYPES(concat_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } DECLARE_SHAPE_FN(concat_bp) { + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); - const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); + const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); - const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); + auto shapeList = SHAPELIST(); - auto shapeList = SHAPELIST(); + for (int e = 0; e < numOfInArrs - 1; e++) { + auto inShape = inputShape->at(e); + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), + shape::shapeOf(inShape), shape::rank(inShape)))); + } - for (int e = 0; e < numOfInArrs - 1; e++) { - auto inShape = inputShape->at(e); - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape)))); - } - - return shapeList; + return shapeList; } - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp b/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp index 79644f51b8d4..66910591290f 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp @@ -21,138 +21,144 @@ #include #if NOT_EXCLUDED(OP_cumprod) -#include #include +#include namespace sd { - namespace ops { - CONFIGURABLE_OP_IMPL(cumprod, 1, 1, true, 0, 2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal"); - - if(input->isEmpty()){ - //No-op - return Status::OK(); - } - - const bool exclusive = INT_ARG(0) == 1; - const bool reverse = INT_ARG(1) == 1; - - if (block.numI() == 2 && block.width() == 1) { - // all at once case - sd::ops::helpers::prefix(block.launchContext(), scalar::Multiply, input, output, exclusive, reverse); - } else { - std::vector dims(block.numI() - 2); - - if (block.width() == 1) { - - for (int e = 0; e < block.numI() - 2; e++) - dims[e] = INT_ARG(e + 2); - } else { - auto ax = INPUT_VARIABLE(1); - dims = ax->template asVectorT(); - } - - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += input->rankOf(); - - sd::ops::helpers::prefix(block.launchContext(), scalar::Multiply, input, output, dims, exclusive, reverse); - } - - return Status::OK(); - } - - DECLARE_TYPES(cumprod) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(true); - } - - DECLARE_TYPES(cumprod_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) // there is a case when axes given as IArgs - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(true); - } - - CUSTOM_OP_IMPL(cumprod_bp, 2, 1, false, 0, 2) { - auto input = INPUT_VARIABLE(0); - auto axis = block.width() == 3 ? INPUT_VARIABLE(1) : nullptr; - auto gradOut = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - const bool exclusive = INT_ARG(0) == 1; - const bool reverse = INT_ARG(1) == 1; - - std::vector dims; - - if (block.width() > 2) { - dims = axis->template asVectorT(); - OUTPUT_VARIABLE(1)->assign(1.0f); - } else if (int newSize = (block.numI() - 2)) { - dims.resize(newSize); - - for (int e = 0; e < newSize; e++) - dims[e] = INT_ARG(e + 2); - } - - sd::ops::helpers::prefix(block.launchContext(), scalar::Multiply, input, output, dims, exclusive, reverse); - NDArray val = NDArray(output->dup()); - - gradOut->applyPairwiseTransform(pairwise::Multiply, *output, val); - val.applyPairwiseTransform(pairwise::Divide, *input, val); - if (!exclusive && !reverse) { - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, true, false); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, false, true); - - } - else if (!exclusive && reverse){ - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, false, false); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, false, false); - } - else if (exclusive && !reverse) { - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, true, true); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, true, true); - } - else { - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, dims, true, false); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, true, false); - } - - return Status::OK(); - } - - - DECLARE_SHAPE_FN(cumprod_bp) { - auto inp = inputShape->at(0); - Nd4jLong *newShapeX = nullptr; - COPY_SHAPE(inp, newShapeX); - - if (block.width() == 2) { - return SHAPELIST(CONSTANT(newShapeX)); - } else { - Nd4jLong *newShapeA = nullptr; - COPY_SHAPE(inputShape->at(1), newShapeA); - - return SHAPELIST(CONSTANT(newShapeX), CONSTANT(newShapeA)); - } - } +namespace ops { +CONFIGURABLE_OP_IMPL(cumprod, 1, 1, true, 0, 2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->dataType() == output->dataType(), 0, + "CumSum: input and output data types must be equal"); + + if (input->isEmpty()) { + // No-op + return Status::OK(); + } + + const bool exclusive = INT_ARG(0) == 1; + const bool reverse = INT_ARG(1) == 1; + + if (block.numI() == 2 && block.width() == 1) { + // all at once case + sd::ops::helpers::prefix(block.launchContext(), scalar::Multiply, input, + output, exclusive, reverse); + } else { + std::vector dims(block.numI() - 2); + + if (block.width() == 1) { + for (int e = 0; e < block.numI() - 2; e++) dims[e] = INT_ARG(e + 2); + } else { + auto ax = INPUT_VARIABLE(1); + dims = ax->template asVectorT(); } + + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += input->rankOf(); + + sd::ops::helpers::prefix(block.launchContext(), scalar::Multiply, input, + output, dims, exclusive, reverse); + } + + return Status::OK(); +} + +DECLARE_TYPES(cumprod) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(true); +} + +DECLARE_TYPES(cumprod_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes( + 1, + {ALL_INTS, ALL_FLOATS}) // there is a case when axes given as IArgs + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(true); +} + +CUSTOM_OP_IMPL(cumprod_bp, 2, 1, false, 0, 2) { + auto input = INPUT_VARIABLE(0); + auto axis = block.width() == 3 ? INPUT_VARIABLE(1) : nullptr; + auto gradOut = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + const bool exclusive = INT_ARG(0) == 1; + const bool reverse = INT_ARG(1) == 1; + + std::vector dims; + + if (block.width() > 2) { + dims = axis->template asVectorT(); + OUTPUT_VARIABLE(1)->assign(1.0f); + } else if (int newSize = (block.numI() - 2)) { + dims.resize(newSize); + + for (int e = 0; e < newSize; e++) dims[e] = INT_ARG(e + 2); + } + + sd::ops::helpers::prefix(block.launchContext(), scalar::Multiply, input, + output, dims, exclusive, reverse); + NDArray val = NDArray(output->dup()); + + gradOut->applyPairwiseTransform(pairwise::Multiply, *output, val); + val.applyPairwiseTransform(pairwise::Divide, *input, val); + if (!exclusive && !reverse) { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + dims, true, false); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + false, true); + + } else if (!exclusive && reverse) { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + dims, false, false); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + false, false); + } else if (exclusive && !reverse) { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + dims, true, true); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + true, true); + } else { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + dims, true, false); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, &val, output, + true, false); + } + + return Status::OK(); +} + +DECLARE_SHAPE_FN(cumprod_bp) { + auto inp = inputShape->at(0); + Nd4jLong *newShapeX = nullptr; + COPY_SHAPE(inp, newShapeX); + + if (block.width() == 2) { + return SHAPELIST(CONSTANT(newShapeX)); + } else { + Nd4jLong *newShapeA = nullptr; + COPY_SHAPE(inputShape->at(1), newShapeA); + + return SHAPELIST(CONSTANT(newShapeX), CONSTANT(newShapeA)); + } } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp b/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp index 712eaca70392..7879fe5a55d5 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp @@ -21,131 +21,133 @@ #include #if NOT_EXCLUDED(OP_cumsum) -#include #include +#include namespace sd { -namespace ops { +namespace ops { CONFIGURABLE_OP_IMPL(cumsum, 1, 1, true, 0, 2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - const bool exclusive = INT_ARG(0) == 1; - const bool reverse = INT_ARG(1) == 1; + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal"); + const bool exclusive = INT_ARG(0) == 1; + const bool reverse = INT_ARG(1) == 1; - if(input->isEmpty()){ - //No-op - return Status::OK(); - } + REQUIRE_TRUE(input->dataType() == output->dataType(), 0, + "CumSum: input and output data types must be equal"); - if (block.numI() == 2 && block.width() == 1) { - // all at once case - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, input, output, exclusive, reverse); + if (input->isEmpty()) { + // No-op + return Status::OK(); + } + + if (block.numI() == 2 && block.width() == 1) { + // all at once case + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, input, output, + exclusive, reverse); + } else { + std::vector dims(block.numI() - 2); + + if (block.width() == 1) { + for (int e = 0; e < block.numI() - 2; e++) dims[e] = INT_ARG(e + 2); + } else { + auto ax = INPUT_VARIABLE(1); + dims = ax->template asVectorT(); } - else { - std::vector dims(block.numI() - 2); - - if (block.width() == 1) { - for (int e = 0; e < block.numI() - 2; e++) - dims[e] = INT_ARG(e + 2); - } - else { - auto ax = INPUT_VARIABLE(1); - dims = ax->template asVectorT(); - } + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += input->rankOf(); - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += input->rankOf(); + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, input, output, + dims, exclusive, reverse); + } - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, input, output, dims, exclusive, reverse); - } - - return Status::OK(); + return Status::OK(); +} +DECLARE_TYPES(cumsum) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(false); } - DECLARE_TYPES(cumsum) { - - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(false); - } CUSTOM_OP_IMPL(cumsum_bp, 2, -1, true, 0, 2) { - auto input = INPUT_VARIABLE(0); - auto axis = block.width() == 3 ? INPUT_VARIABLE(1) : nullptr; - auto gradOut = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); -// output->assign(gradOut); - const bool exclusive = INT_ARG(0) == 1; - const bool reverse = INT_ARG(1) == 1; - - std::vector dims; - - if (block.width() > 2) { - dims = axis->template asVectorT(); - OUTPUT_VARIABLE(1)->assign(1.0f); - } else if (int newSize = (block.numI() - 2)) { - dims.resize(newSize); - - for (int e = 0; e < newSize; e++) - dims[e] = INT_ARG(e + 2); - } - if (!exclusive && !reverse) { - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, dims, false, true); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, false, true); - - } - else if (!exclusive && reverse){ - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, dims, false, false); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, false, false); - } - else if (exclusive && !reverse) { - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, dims, true, true); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, true, true); - } - else { - if (dims.size()) - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, dims, true, false); - else - sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, output, true, false); - } - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto axis = block.width() == 3 ? INPUT_VARIABLE(1) : nullptr; + auto gradOut = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + // output->assign(gradOut); + const bool exclusive = INT_ARG(0) == 1; + const bool reverse = INT_ARG(1) == 1; + + std::vector dims; + + if (block.width() > 2) { + dims = axis->template asVectorT(); + OUTPUT_VARIABLE(1)->assign(1.0f); + } else if (int newSize = (block.numI() - 2)) { + dims.resize(newSize); + + for (int e = 0; e < newSize; e++) dims[e] = INT_ARG(e + 2); + } + if (!exclusive && !reverse) { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, dims, false, true); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, false, true); + + } else if (!exclusive && reverse) { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, dims, false, false); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, false, false); + } else if (exclusive && !reverse) { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, dims, true, true); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, true, true); + } else { + if (dims.size()) + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, dims, true, false); + else + sd::ops::helpers::prefix(block.launchContext(), scalar::Add, gradOut, + output, true, false); + } + + return Status::OK(); +} +DECLARE_TYPES(cumsum_bp) { + getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}); + getOpDescriptor()->setAllowedInputTypes( + 1, {ALL_FLOATS, ALL_INTS}); // axes can be set as the second param + getOpDescriptor()->setAllowedInputTypes(2, {ALL_FLOATS}); + getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS}); } - DECLARE_TYPES(cumsum_bp) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}); - getOpDescriptor()->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS}); // axes can be set as the second param - getOpDescriptor()->setAllowedInputTypes(2, {ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS}); - } - DECLARE_SHAPE_FN(cumsum_bp) { - auto inp = inputShape->at(0); - Nd4jLong *newShapeX = nullptr; - COPY_SHAPE(inp, newShapeX); +DECLARE_SHAPE_FN(cumsum_bp) { + auto inp = inputShape->at(0); + Nd4jLong *newShapeX = nullptr; + COPY_SHAPE(inp, newShapeX); - if (block.width() == 2) { - return SHAPELIST(CONSTANT(newShapeX)); - } else { - Nd4jLong *newShapeA = nullptr; - COPY_SHAPE(inputShape->at(1), newShapeA); + if (block.width() == 2) { + return SHAPELIST(CONSTANT(newShapeX)); + } else { + Nd4jLong *newShapeA = nullptr; + COPY_SHAPE(inputShape->at(1), newShapeA); - return SHAPELIST(CONSTANT(newShapeX), CONSTANT(newShapeA)); - } - } -} + return SHAPELIST(CONSTANT(newShapeX), CONSTANT(newShapeA)); + } } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp b/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp index dcf827eb1f55..10187be4521f 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp @@ -23,69 +23,73 @@ #include #include + #include namespace sd { namespace ops { - CUSTOM_OP_IMPL(depth_to_space, 1, 1, false, 0, 2) { - int block_size = INT_ARG(0); - bool isNHWC = INT_ARG(1) == 1; - - auto input = INPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "DepthToSpace: input should be 4D array, but got %f instead", input->rankOf()); - - int bS = input->sizeAt(0); - int iD = isNHWC ? input->sizeAt(3) : input->sizeAt(1); - int iH = isNHWC ? input->sizeAt(1) : input->sizeAt(2); - int iW = isNHWC ? input->sizeAt(2) : input->sizeAt(3); - - REQUIRE_TRUE(iD % (block_size * block_size) == 0, 0, "DepthToSpace: input number of channels should be divisible by square(block_size)"); - - auto output = OUTPUT_VARIABLE(0); - - if (shape::strideDescendingCAscendingF(input->shapeInfo())) - helpers::_depthToSpace(block.launchContext(), *input, output, block_size, isNHWC); - else - helpers::_depthToSpace(block.launchContext(), input->dup(), output, block_size, isNHWC); - - STORE_RESULT(output); - - return ND4J_STATUS_OK; - } - - DECLARE_TYPES(depth_to_space) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - - - DECLARE_SHAPE_FN(depth_to_space) { - auto in = inputShape->at(0); - auto block_size = INT_ARG(0); - bool isNHWC = INT_ARG(1) == 1; - - int bS = shape::sizeAt(in, 0); - int iD = isNHWC ? shape::sizeAt(in, 3) : shape::sizeAt(in, 1); - int iH = isNHWC ? shape::sizeAt(in, 1) : shape::sizeAt(in, 2); - int iW = isNHWC ? shape::sizeAt(in, 2) : shape::sizeAt(in, 3); - - int oD = iD / (block_size * block_size); - int oH = iH * block_size; - int oW = iW * block_size; - - - std::array shape; - if (isNHWC) - shape = {{bS, oH, oW, oD }}; - else - shape = {{bS, oD, oH, oW }}; - - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in), 'c', 4, shape.data()); - return SHAPELIST(newShape); - } +CUSTOM_OP_IMPL(depth_to_space, 1, 1, false, 0, 2) { + int block_size = INT_ARG(0); + bool isNHWC = INT_ARG(1) == 1; + + auto input = INPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "DepthToSpace: input should be 4D array, but got %f instead", + input->rankOf()); + + int bS = input->sizeAt(0); + int iD = isNHWC ? input->sizeAt(3) : input->sizeAt(1); + int iH = isNHWC ? input->sizeAt(1) : input->sizeAt(2); + int iW = isNHWC ? input->sizeAt(2) : input->sizeAt(3); + + REQUIRE_TRUE(iD % (block_size * block_size) == 0, 0, + "DepthToSpace: input number of channels should be divisible by " + "square(block_size)"); + + auto output = OUTPUT_VARIABLE(0); + + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::_depthToSpace(block.launchContext(), *input, output, block_size, + isNHWC); + else + helpers::_depthToSpace(block.launchContext(), input->dup(), output, + block_size, isNHWC); + + STORE_RESULT(output); + + return ND4J_STATUS_OK; } + +DECLARE_TYPES(depth_to_space) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} + +DECLARE_SHAPE_FN(depth_to_space) { + auto in = inputShape->at(0); + auto block_size = INT_ARG(0); + bool isNHWC = INT_ARG(1) == 1; + + int bS = shape::sizeAt(in, 0); + int iD = isNHWC ? shape::sizeAt(in, 3) : shape::sizeAt(in, 1); + int iH = isNHWC ? shape::sizeAt(in, 1) : shape::sizeAt(in, 2); + int iW = isNHWC ? shape::sizeAt(in, 2) : shape::sizeAt(in, 3); + + int oD = iD / (block_size * block_size); + int oH = iH * block_size; + int oW = iW * block_size; + + std::array shape; + if (isNHWC) + shape = {{bS, oH, oW, oD}}; + else + shape = {{bS, oD, oH, oW}}; + + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(in), 'c', 4, shape.data()); + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp b/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp index 3100b9ad4168..08cea23d3ee6 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/dynamic_parititon.cpp @@ -22,131 +22,138 @@ #if NOT_EXCLUDED(OP_dynamic_partition) #include -#include #include +#include + namespace sd { namespace ops { - CUSTOM_OP_IMPL(dynamic_partition, 2, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - - // input->printShapeInfo("input"); - // indices->printShapeInfo("indices"); - - REQUIRE_TRUE(input->rankOf() >= indices->rankOf(), 0, - "dynamic_partition: data tensor rank should be non-lesser than indices\' tensor, but %i < %i given,", - input->rankOf(), indices->rankOf()); - for (int dim = 0; dim < indices->rankOf(); dim++) { - REQUIRE_TRUE(input->sizeAt(dim) == indices->sizeAt(dim), 0, - "dynamic_partition: dimensions should be equals for data and indices tensors, but at axis[%i] %i != %i given", - dim, input->sizeAt(dim), indices->sizeAt(dim)); - } - - auto numPartition = INT_ARG(0); - std::vector outputList(numPartition); - for (int o = 0; o < numPartition; ++o) { - outputList[o] = OUTPUT_VARIABLE(o); - } - helpers::dynamicPartitionFunctor(block.launchContext(), input, indices, outputList); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(dynamic_partition) { - auto numPartition = INT_ARG(0); - auto indices = INPUT_VARIABLE(1); - std::vector partitionSizes(numPartition, 0); - auto in = inputShape->at(0); - auto idx = inputShape->at(1); - for (int i = 0; i < numPartition; i++) { - for (int e = 0; e < indices->lengthOf(); ++e) - if (indices->e(e) == i) - partitionSizes[i]++; - } - - auto shapes = SHAPELIST(); - int outRank = shape::rank(in) - shape::rank(idx) + 1; - for (int e = 0; e < numPartition; e++) { - Nd4jLong *newShape; - ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - //shape::shapeVector(partitionSizes[e], newShape); - newShape[0] = outRank; - newShape[1] = partitionSizes[e]; - for (int i = 1; i < outRank; ++i) - newShape[i + 1] = shape::sizeAt(in, outRank + i - 1); - - shape::updateStrides(newShape, shape::order(in)); - ArrayOptions::setDataType(newShape, ArrayOptions::dataType(in)); - shapes->push_back(CONSTANT(newShape)); - } - - return shapes; - } - - DECLARE_TYPES(dynamic_partition) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); - } - - DECLARE_TYPES(dynamic_partition_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - - CUSTOM_OP_IMPL(dynamic_partition_bp, 3, 2, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - //auto gradOut = ; - auto numPartition = INT_ARG(0); - - std::vector outputList(2); // only for output - std::vector gradOutList(numPartition); - for (Nd4jLong e = 0; e < numPartition; e++) { - gradOutList[e] = INPUT_VARIABLE(e + 2); - } - outputList[0] = OUTPUT_VARIABLE(0); - outputList[1] = OUTPUT_VARIABLE(1); - auto originalIndices = indices->dup(); //->ordering(), indices->shapeInfo(), indices->dataType()); - originalIndices.linspace(0); - ops::dynamic_partition op; - auto res = op.evaluate({&originalIndices, indices}, {numPartition}); - REQUIRE_TRUE(res.status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning."); - ops::dynamic_stitch stichOp; - std::vector partitions(numPartition * 2); - for (size_t i = 0; i < res.size(); i++) { - partitions[i] = &res.at(i); - partitions[i + numPartition] = gradOutList[i]; - } - - auto result = stichOp.evaluate(partitions, {numPartition}); - REQUIRE_TRUE(result.status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning."); - result.at(0).reshapei(outputList[0]->getShapeAsVector()); - outputList[1]->assign(indices); - outputList[0]->assign(result.at(0)); - -// helpers::dynamicPartitionFunctorBP(block.launchContext(), input, indices, gradOutList, outputList); - return ND4J_STATUS_OK; - } - - DECLARE_SHAPE_FN(dynamic_partition_bp) { - auto numPartition = INT_ARG(0); - auto indices = INPUT_VARIABLE(1); - std::vector partitionSizes(numPartition, 0); - - auto shapes = SHAPELIST(); - // just copy shape info from input and indices to output - for (Nd4jLong i = 0; i < 2; i++) { - Nd4jLong *newShape; - COPY_SHAPE(inputShape->at(i), newShape); - shapes->push_back(CONSTANT(newShape)); - } - - return shapes; - } +CUSTOM_OP_IMPL(dynamic_partition, 2, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + + // input->printShapeInfo("input"); + // indices->printShapeInfo("indices"); + + REQUIRE_TRUE(input->rankOf() >= indices->rankOf(), 0, + "dynamic_partition: data tensor rank should be non-lesser than " + "indices\' tensor, but %i < %i given,", + input->rankOf(), indices->rankOf()); + for (int dim = 0; dim < indices->rankOf(); dim++) { + REQUIRE_TRUE(input->sizeAt(dim) == indices->sizeAt(dim), 0, + "dynamic_partition: dimensions should be equals for data and " + "indices tensors, but at axis[%i] %i != %i given", + dim, input->sizeAt(dim), indices->sizeAt(dim)); + } + + auto numPartition = INT_ARG(0); + std::vector outputList(numPartition); + for (int o = 0; o < numPartition; ++o) { + outputList[o] = OUTPUT_VARIABLE(o); + } + helpers::dynamicPartitionFunctor(block.launchContext(), input, indices, + outputList); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(dynamic_partition) { + auto numPartition = INT_ARG(0); + auto indices = INPUT_VARIABLE(1); + std::vector partitionSizes(numPartition, 0); + auto in = inputShape->at(0); + auto idx = inputShape->at(1); + for (int i = 0; i < numPartition; i++) { + for (int e = 0; e < indices->lengthOf(); ++e) + if (indices->e(e) == i) partitionSizes[i]++; + } + + auto shapes = SHAPELIST(); + int outRank = shape::rank(in) - shape::rank(idx) + 1; + for (int e = 0; e < numPartition; e++) { + Nd4jLong *newShape; + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + // shape::shapeVector(partitionSizes[e], newShape); + newShape[0] = outRank; + newShape[1] = partitionSizes[e]; + for (int i = 1; i < outRank; ++i) + newShape[i + 1] = shape::sizeAt(in, outRank + i - 1); + + shape::updateStrides(newShape, shape::order(in)); + ArrayOptions::setDataType(newShape, ArrayOptions::dataType(in)); + shapes->push_back(CONSTANT(newShape)); + } + + return shapes; +} + +DECLARE_TYPES(dynamic_partition) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); } + +DECLARE_TYPES(dynamic_partition_bp) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} + +CUSTOM_OP_IMPL(dynamic_partition_bp, 3, 2, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + // auto gradOut = ; + auto numPartition = INT_ARG(0); + + std::vector outputList(2); // only for output + std::vector gradOutList(numPartition); + for (Nd4jLong e = 0; e < numPartition; e++) { + gradOutList[e] = INPUT_VARIABLE(e + 2); + } + outputList[0] = OUTPUT_VARIABLE(0); + outputList[1] = OUTPUT_VARIABLE(1); + auto originalIndices = + indices + ->dup(); //->ordering(), indices->shapeInfo(), indices->dataType()); + originalIndices.linspace(0); + ops::dynamic_partition op; + auto res = op.evaluate({&originalIndices, indices}, {numPartition}); + REQUIRE_TRUE(res.status() == ND4J_STATUS_OK, 0, + "dynamic_partition_bp: Error with dynamic partitioning."); + ops::dynamic_stitch stichOp; + std::vector partitions(numPartition * 2); + for (size_t i = 0; i < res.size(); i++) { + partitions[i] = &res.at(i); + partitions[i + numPartition] = gradOutList[i]; + } + + auto result = stichOp.evaluate(partitions, {numPartition}); + REQUIRE_TRUE(result.status() == ND4J_STATUS_OK, 0, + "dynamic_partition_bp: Error with dynamic partitioning."); + result.at(0).reshapei(outputList[0]->getShapeAsVector()); + outputList[1]->assign(indices); + outputList[0]->assign(result.at(0)); + + // helpers::dynamicPartitionFunctorBP(block.launchContext(), input, + // indices, gradOutList, outputList); + return ND4J_STATUS_OK; +} + +DECLARE_SHAPE_FN(dynamic_partition_bp) { + auto numPartition = INT_ARG(0); + auto indices = INPUT_VARIABLE(1); + std::vector partitionSizes(numPartition, 0); + + auto shapes = SHAPELIST(); + // just copy shape info from input and indices to output + for (Nd4jLong i = 0; i < 2; i++) { + Nd4jLong *newShape; + COPY_SHAPE(inputShape->at(i), newShape); + shapes->push_back(CONSTANT(newShape)); + } + + return shapes; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/dynamic_stitch.cpp b/libnd4j/include/ops/declarable/generic/transforms/dynamic_stitch.cpp index ecf0e53246b5..2b7f0efaf41b 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/dynamic_stitch.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/dynamic_stitch.cpp @@ -26,62 +26,70 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(dynamic_stitch, 2, 1, false, 0, 0) { - int numOfData = block.width(); -// int k = 0; - // checking input data size - REQUIRE_TRUE(numOfData % 2 == 0, 0, - "dynamic_stitch: The input params should contains" - " both indeces and data lists with same length."); - // split input data list on two equal parts - numOfData /= 2; +CUSTOM_OP_IMPL(dynamic_stitch, 2, 1, false, 0, 0) { + int numOfData = block.width(); + // int k = 0; + // checking input data size + REQUIRE_TRUE(numOfData % 2 == 0, 0, + "dynamic_stitch: The input params should contains" + " both indeces and data lists with same length."); + // split input data list on two equal parts + numOfData /= 2; - // form input lists to use with helpers - both indices and float data inputs - auto output = OUTPUT_VARIABLE(0); - std::vector inputs(numOfData); - std::vector indices(numOfData); + // form input lists to use with helpers - both indices and float data inputs + auto output = OUTPUT_VARIABLE(0); + std::vector inputs(numOfData); + std::vector indices(numOfData); - for (int e = 0; e < numOfData; e++) { - auto data = INPUT_VARIABLE(numOfData + e); - auto index = INPUT_VARIABLE(e); + for (int e = 0; e < numOfData; e++) { + auto data = INPUT_VARIABLE(numOfData + e); + auto index = INPUT_VARIABLE(e); - inputs[e] = data; - indices[e] = index; - } - // run helper - return helpers::dynamicStitchFunctor(block.launchContext(), inputs, indices, output); - } + inputs[e] = data; + indices[e] = index; + } + // run helper + return helpers::dynamicStitchFunctor(block.launchContext(), inputs, indices, + output); +} - DECLARE_TYPES(dynamic_stitch) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } +DECLARE_TYPES(dynamic_stitch) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); +} - DECLARE_SHAPE_FN(dynamic_stitch) { - Nd4jLong maxValue = 0; - auto numOfData = block.width(); - numOfData /= 2; // only index part it's needed to review - auto restShape = inputShape->at(numOfData); - auto firstShape = inputShape->at(0); - // check up inputs to avoid non-int indices and calculate max value from indices to output shape length - for(int i = 0; i < numOfData; i++) { - auto input = INPUT_VARIABLE(i); - REQUIRE_TRUE(input->isZ(), 0, "dynamic_stitch: Indices should be integer, but %d type given.", (int)input->dataType() ); - auto maxV = input->reduceNumber(reduce::Max); - if (maxV.e(0) > maxValue) maxValue = maxV.e(0); - } - // calculate output rank - difference between indices shape and data shape - int outRank = shape::rank(restShape) - shape::rank(firstShape) + 1; // at least 1D tensor - std::vector outShape(outRank); - // fill up output shape template: the first to max index, and rests - to vals from the first data input - outShape[0] = maxValue + 1; - for(int i = 1; i < outRank; ++i) - outShape[i] = shape::sizeAt(restShape, i); +DECLARE_SHAPE_FN(dynamic_stitch) { + Nd4jLong maxValue = 0; + auto numOfData = block.width(); + numOfData /= 2; // only index part it's needed to review + auto restShape = inputShape->at(numOfData); + auto firstShape = inputShape->at(0); + // check up inputs to avoid non-int indices and calculate max value from + // indices to output shape length + for (int i = 0; i < numOfData; i++) { + auto input = INPUT_VARIABLE(i); + REQUIRE_TRUE( + input->isZ(), 0, + "dynamic_stitch: Indices should be integer, but %d type given.", + (int)input->dataType()); + auto maxV = input->reduceNumber(reduce::Max); + if (maxV.e(0) > maxValue) maxValue = maxV.e(0); + } + // calculate output rank - difference between indices shape and data shape + int outRank = shape::rank(restShape) - shape::rank(firstShape) + + 1; // at least 1D tensor + std::vector outShape(outRank); + // fill up output shape template: the first to max index, and rests - to vals + // from the first data input + outShape[0] = maxValue + 1; + for (int i = 1; i < outRank; ++i) outShape[i] = shape::sizeAt(restShape, i); - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(restShape), shape::order(firstShape), outShape))); - } -} + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(restShape), + shape::order(firstShape), outShape))); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/floor.cpp b/libnd4j/include/ops/declarable/generic/transforms/floor.cpp index 984708de5654..3156a9e7bf58 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/floor.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/floor.cpp @@ -24,25 +24,23 @@ #include namespace sd { - namespace ops { - OP_IMPL(Floor, 1, 1, true) { - auto first = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +namespace ops { +OP_IMPL(Floor, 1, 1, true) { + auto first = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - first->applyTransform(transform::Floor, *z); + first->applyTransform(transform::Floor, *z); - STORE_RESULT(*z); + STORE_RESULT(*z); - return Status::OK(); - } - DECLARE_SYN(floor, Floor); + return Status::OK(); +} +DECLARE_SYN(floor, Floor); - DECLARE_TYPES(Floor) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +DECLARE_TYPES(Floor) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/gather.cpp b/libnd4j/include/ops/declarable/generic/transforms/gather.cpp index e8db96837d4b..865c3042bc38 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/gather.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/gather.cpp @@ -25,150 +25,156 @@ #include #include - namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(gather, 1, 1, false, 0, -2) { - - auto input = INPUT_VARIABLE(0); - auto indices = block.width() > 1 ? INPUT_VARIABLE(1) : nullptr; - auto output = OUTPUT_VARIABLE(0); - - const bool checkIndices = block.getBArguments().empty() ? false : B_ARG(0); - - //Edge case: empty indices -> empty output - if(indices != nullptr && indices->isEmpty()){ - REQUIRE_TRUE(output->isEmpty(), 0, "Gather op: If indices are empty, output must also be empty"); - return Status::OK(); //No op - } - - const int numOfIntArgs = block.numI(); - - std::vector intArgs; - if (block.width() > 2) { - intArgs = INPUT_VARIABLE(2)->template asVectorT(); - } - else { - if (numOfIntArgs == 0) - intArgs.emplace_back(0); - else - for (int i = 0; i < numOfIntArgs; ++i) - intArgs.emplace_back(block.getIArguments().at(i)); - } - - const int inputRank = input->rankOf(); - if(intArgs[0] < 0) - intArgs[0] += inputRank; - - // input validation - REQUIRE_TRUE(intArgs[0] < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", intArgs[0], inputRank); - REQUIRE_TRUE(indices != nullptr || numOfIntArgs > 1, 0, "GATHER op: indices should be provided either as additional input array or as IntArguments !"); - - if(checkIndices) { - - NDArray* pIndices = indices; - if(indices == nullptr) - pIndices = new NDArray(input->ordering(), {static_cast(intArgs.size()) - 1}, std::vector(intArgs.begin() + 1, intArgs.end()), DataType::INT64, block.launchContext()); - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *pIndices, *input, intArgs[0]); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "GATHER OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - if(indices == nullptr) - delete pIndices; - } - - helpers::gather(block.launchContext(), input, indices, output, intArgs); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto indices = block.width() > 1 ? INPUT_VARIABLE(1) : nullptr; + auto output = OUTPUT_VARIABLE(0); + + const bool checkIndices = block.getBArguments().empty() ? false : B_ARG(0); + + // Edge case: empty indices -> empty output + if (indices != nullptr && indices->isEmpty()) { + REQUIRE_TRUE(output->isEmpty(), 0, + "Gather op: If indices are empty, output must also be empty"); + return Status::OK(); // No op + } + + const int numOfIntArgs = block.numI(); + + std::vector intArgs; + if (block.width() > 2) { + intArgs = INPUT_VARIABLE(2)->template asVectorT(); + } else { + if (numOfIntArgs == 0) + intArgs.emplace_back(0); + else + for (int i = 0; i < numOfIntArgs; ++i) + intArgs.emplace_back(block.getIArguments().at(i)); + } + + const int inputRank = input->rankOf(); + if (intArgs[0] < 0) intArgs[0] += inputRank; + + // input validation + REQUIRE_TRUE(intArgs[0] < inputRank, 0, + "GATHER op: input axis must be smaller than input array rank, " + "but got %i and %i correspondingly!", + intArgs[0], inputRank); + REQUIRE_TRUE(indices != nullptr || numOfIntArgs > 1, 0, + "GATHER op: indices should be provided either as additional " + "input array or as IntArguments !"); + + if (checkIndices) { + NDArray* pIndices = indices; + if (indices == nullptr) + pIndices = + new NDArray(input->ordering(), {static_cast(intArgs.size()) - 1}, + std::vector(intArgs.begin() + 1, intArgs.end()), + DataType::INT64, block.launchContext()); + const Nd4jLong numOfBadIndx = helpers::checkIndices( + block.launchContext(), *pIndices, *input, intArgs[0]); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "GATHER OP: please check elements of indices-array, total " + "number of wrong elements is %lld!", + numOfBadIndx); + if (indices == nullptr) delete pIndices; + } + + helpers::gather(block.launchContext(), input, indices, output, intArgs); + + return Status::OK(); } DECLARE_TYPES(gather) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}); - getOpDescriptor()->setAllowedInputTypes(1, {ALL_INTS}); - getOpDescriptor()->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}); + getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}); + getOpDescriptor()->setAllowedInputTypes(1, {ALL_INTS}); + getOpDescriptor()->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS}); } - DECLARE_SHAPE_FN(gather) { + // check shape of paddings + auto inputShapeInfo = inputShape->at(0); + Nd4jLong* outputShapeInfo = nullptr; - // check shape of paddings - auto inputShapeInfo = inputShape->at(0); - Nd4jLong* outputShapeInfo = nullptr; - - int axis = 0; + int axis = 0; - if (block.width() > 2) { - axis = INPUT_VARIABLE(2)->e(0); - } else - axis = block.numI() > 0 ? block.getIArguments().at(0) : 0; + if (block.width() > 2) { + axis = INPUT_VARIABLE(2)->e(0); + } else + axis = block.numI() > 0 ? block.getIArguments().at(0) : 0; - int inputRank = shape::rank(inputShapeInfo); - if(axis < 0) - axis += inputRank; + int inputRank = shape::rank(inputShapeInfo); + if (axis < 0) axis += inputRank; - REQUIRE_TRUE(axis < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", axis, inputRank); + REQUIRE_TRUE(axis < inputRank, 0, + "GATHER op: input axis must be smaller than input array rank, " + "but got %i and %i correspondingly!", + axis, inputRank); - bool isEmpty = false; + bool isEmpty = false; - if (block.width() > 1) { - auto indicesShapeInfo = inputShape->at(1); + if (block.width() > 1) { + auto indicesShapeInfo = inputShape->at(1); - int indicesRank = shape::rank(indicesShapeInfo); + int indicesRank = shape::rank(indicesShapeInfo); - int outputRank = inputRank + indicesRank - 1; + int outputRank = inputRank + indicesRank - 1; - ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(outputRank), Nd4jLong); + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(outputRank), Nd4jLong); - // fill output shapeInfo - outputShapeInfo[0] = outputRank; - int shapeIdx = 1; + // fill output shapeInfo + outputShapeInfo[0] = outputRank; + int shapeIdx = 1; - for(int i = 0; i < axis; ++i) - outputShapeInfo[shapeIdx++] = inputShapeInfo[i+1]; + for (int i = 0; i < axis; ++i) + outputShapeInfo[shapeIdx++] = inputShapeInfo[i + 1]; - for(int i = 0; i < indicesRank; ++i) - outputShapeInfo[shapeIdx++] = indicesShapeInfo[i+1]; + for (int i = 0; i < indicesRank; ++i) + outputShapeInfo[shapeIdx++] = indicesShapeInfo[i + 1]; - for(int i = axis+1; i < inputRank; ++i) - outputShapeInfo[shapeIdx++] = inputShapeInfo[i+1]; - } - else if (block.numI() > 1) { + for (int i = axis + 1; i < inputRank; ++i) + outputShapeInfo[shapeIdx++] = inputShapeInfo[i + 1]; + } else if (block.numI() > 1) { + int indicesRank = block.numI() == 2 ? 0 : 1; - int indicesRank = block.numI() == 2 ? 0 : 1; + int outputRank = inputRank + indicesRank - 1; + ALLOCATE(outputShapeInfo, block.workspace(), + shape::shapeInfoLength(outputRank), Nd4jLong); - int outputRank = inputRank + indicesRank - 1; - ALLOCATE(outputShapeInfo, block.workspace(), shape::shapeInfoLength(outputRank), Nd4jLong); + // building shape manually + outputShapeInfo[0] = outputRank; + int shapeIdx = 1; + for (int i = 0; i < axis; ++i) + outputShapeInfo[shapeIdx++] = inputShapeInfo[i + 1]; - // building shape manually - outputShapeInfo[0] = outputRank; - int shapeIdx = 1; - for(int i = 0; i < axis; ++i) - outputShapeInfo[shapeIdx++] = inputShapeInfo[i+1]; + if (block.numI() > 2) outputShapeInfo[shapeIdx++] = block.numI() - 1; - if (block.numI() > 2) - outputShapeInfo[shapeIdx++] = block.numI() - 1; + for (int i = axis + 1; i < inputRank; ++i) + outputShapeInfo[shapeIdx++] = inputShapeInfo[i + 1]; + } else + REQUIRE_TRUE(false, 0, + "GATHER op: indices should be provided either as additional " + "input array or as IntArguments !"); - for(int i = axis+1; i < inputRank; ++i) - outputShapeInfo[shapeIdx++] = inputShapeInfo[i+1]; - } - else - REQUIRE_TRUE(false, 0, "GATHER op: indices should be provided either as additional input array or as IntArguments !"); - - ShapeUtils::updateStridesAndType(outputShapeInfo, inputShapeInfo, shape::order(inputShapeInfo)); - - if(isEmpty){ - ArrayOptions::setPropertyBit(outputShapeInfo, ARRAY_EMPTY); - } + ShapeUtils::updateStridesAndType(outputShapeInfo, inputShapeInfo, + shape::order(inputShapeInfo)); - auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outputShapeInfo)); - RELEASE(outputShapeInfo, block.workspace()); - return SHAPELIST(result); + if (isEmpty) { + ArrayOptions::setPropertyBit(outputShapeInfo, ARRAY_EMPTY); + } + auto result = ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(outputShapeInfo)); + RELEASE(outputShapeInfo, block.workspace()); + return SHAPELIST(result); } -} -} - +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp b/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp index e5b895ec1962..eb6b042a7cc1 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp @@ -22,79 +22,90 @@ #if NOT_EXCLUDED(OP_gather_nd) #include -#include #include +#include namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(gather_nd, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - const bool checkIndices = block.getBArguments().empty() ? false : B_ARG(0); - - const int rankIn = input->rankOf(); - const int rankInd = indices->rankOf(); - - REQUIRE_TRUE(rankInd > 0, 0, "GATHER_ND op: array of indexes can't be single scalar, the requirement is: rank > 0, but got rank = %i instead!", rankInd); - int lastIndDim = indices->sizeAt(-1); - REQUIRE_TRUE(lastIndDim <= rankIn, 0, "GATHER_ND op: the last dimension of indices array must be <= rank of input array but got %i and %i correspondingly!", lastIndDim, rankIn); - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *input); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "GATHER_ND OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - helpers::gatherND(block.launchContext(), *input, *indices, *output); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + const bool checkIndices = block.getBArguments().empty() ? false : B_ARG(0); + + const int rankIn = input->rankOf(); + const int rankInd = indices->rankOf(); + + REQUIRE_TRUE(rankInd > 0, 0, + "GATHER_ND op: array of indexes can't be single scalar, the " + "requirement is: rank > 0, but got rank = %i instead!", + rankInd); + int lastIndDim = indices->sizeAt(-1); + REQUIRE_TRUE(lastIndDim <= rankIn, 0, + "GATHER_ND op: the last dimension of indices array must be <= " + "rank of input array but got %i and %i correspondingly!", + lastIndDim, rankIn); + + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *input); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "GATHER_ND OP: please check elements of indices-array, total " + "number of wrong elements is %lld!", + numOfBadIndx); + } + + helpers::gatherND(block.launchContext(), *input, *indices, *output); + + return Status::OK(); } DECLARE_TYPES(gather_nd) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } DECLARE_SHAPE_FN(gather_nd) { - - auto inShapeInfoIn = inputShape->at(0); - auto inShapeInfoInd = inputShape->at(1); - - const int rankIn = inShapeInfoIn[0]; - const int rankInd = inShapeInfoInd[0]; - REQUIRE_TRUE(rankInd > 0, 0, "GATHER_ND op: array of indexes can't be single scalar, the requirement is: rank > 0, but got rank = %i instead!", rankInd); - const int lastIndDim = inShapeInfoInd[rankInd]; - REQUIRE_TRUE(lastIndDim <= rankIn, 0, "GATHER_ND op: the last dimension of indices array must be <= rank of input array but got %i and %i correspondingly!", lastIndDim, rankIn); - - int outRank = (rankInd - 1) + (rankIn - lastIndDim); - - Nd4jLong* outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), Nd4jLong); - - outShapeInfo[0] = outRank; - - for(int i = 1; i <= rankInd-1; ++i) - outShapeInfo[i] = inShapeInfoInd[i]; - - for(int i = 0; i < rankIn-lastIndDim; ++i) - outShapeInfo[rankInd + i] = inShapeInfoIn[lastIndDim + i + 1]; - - ShapeUtils::updateStridesAndType(outShapeInfo, inShapeInfoIn, 'c'); - //ArrayOptions::setDataType(outShapeInfo, ArrayOptions::dataType(inShapeInfoIn)); - return SHAPELIST(CONSTANT(outShapeInfo)); + auto inShapeInfoIn = inputShape->at(0); + auto inShapeInfoInd = inputShape->at(1); + + const int rankIn = inShapeInfoIn[0]; + const int rankInd = inShapeInfoInd[0]; + REQUIRE_TRUE(rankInd > 0, 0, + "GATHER_ND op: array of indexes can't be single scalar, the " + "requirement is: rank > 0, but got rank = %i instead!", + rankInd); + const int lastIndDim = inShapeInfoInd[rankInd]; + REQUIRE_TRUE(lastIndDim <= rankIn, 0, + "GATHER_ND op: the last dimension of indices array must be <= " + "rank of input array but got %i and %i correspondingly!", + lastIndDim, rankIn); + + int outRank = (rankInd - 1) + (rankIn - lastIndDim); + + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(outRank), + Nd4jLong); + + outShapeInfo[0] = outRank; + + for (int i = 1; i <= rankInd - 1; ++i) outShapeInfo[i] = inShapeInfoInd[i]; + + for (int i = 0; i < rankIn - lastIndDim; ++i) + outShapeInfo[rankInd + i] = inShapeInfoIn[lastIndDim + i + 1]; + + ShapeUtils::updateStridesAndType(outShapeInfo, inShapeInfoIn, 'c'); + // ArrayOptions::setDataType(outShapeInfo, + // ArrayOptions::dataType(inShapeInfoIn)); + return SHAPELIST(CONSTANT(outShapeInfo)); } - - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/hashcode.cpp b/libnd4j/include/ops/declarable/generic/transforms/hashcode.cpp index 4196385c1c62..772acc752db6 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/hashcode.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/hashcode.cpp @@ -22,37 +22,38 @@ #if NOT_EXCLUDED(OP_hashcode) #include -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(hashcode, 1, 1, false, 0, 0) { - REQUIRE_TRUE(block.width() == 1, 0, "hashcode: this op can't be applied along dimension"); - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(hashcode, 1, 1, false, 0, 0) { + REQUIRE_TRUE(block.width() == 1, 0, + "hashcode: this op can't be applied along dimension"); - REQUIRE_TRUE(output->isScalar(), 0, "hashcode: this op requires scalar output"); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - helpers::hashCode(block.launchContext(), *input, *output); + REQUIRE_TRUE(output->isScalar(), 0, + "hashcode: this op requires scalar output"); - return Status::OK(); - }; + helpers::hashCode(block.launchContext(), *input, *output); - DECLARE_SHAPE_FN(hashcode) { - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT64)); - } + return Status::OK(); +}; - - DECLARE_TYPES(hashcode) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes({sd::DataType::INT64}); - }; - } +DECLARE_SHAPE_FN(hashcode) { + return SHAPELIST( + ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT64)); } -#endif +DECLARE_TYPES(hashcode) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes({sd::DataType::INT64}); +}; +} // namespace ops +} // namespace sd +#endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/histogram.cpp b/libnd4j/include/ops/declarable/generic/transforms/histogram.cpp index 415361894942..91e408823c91 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/histogram.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/histogram.cpp @@ -22,37 +22,38 @@ #if NOT_EXCLUDED(OP_histogram) #include -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(histogram, 1, 1, false, 0, 1) { - auto input = INPUT_VARIABLE(0); - auto numBins = INT_ARG(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(numBins == output->lengthOf(), 0, "Histogram: numBins must match output length") +namespace ops { +CUSTOM_OP_IMPL(histogram, 1, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto numBins = INT_ARG(0); + auto output = OUTPUT_VARIABLE(0); - output->nullify(); - helpers::histogramHelper(block.launchContext(), *input, *output); + REQUIRE_TRUE(numBins == output->lengthOf(), 0, + "Histogram: numBins must match output length") - return Status::OK(); - } + output->nullify(); + helpers::histogramHelper(block.launchContext(), *input, *output); - DECLARE_SHAPE_FN(histogram) { - auto numBins = INT_ARG(0); - - return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(numBins, sd::DataType::INT64)); - } + return Status::OK(); +} +DECLARE_SHAPE_FN(histogram) { + auto numBins = INT_ARG(0); - DECLARE_TYPES(histogram) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS}); - }; - } + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo( + numBins, sd::DataType::INT64)); } +DECLARE_TYPES(histogram) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS}); +}; +} // namespace ops +} // namespace sd + #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp b/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp index 7c8bb4c6c4c5..ff368751d9d8 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp @@ -25,44 +25,52 @@ #include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(histogram_fixed_width, 2, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto range = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - const int nbins = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : block.getIArguments().empty() ? 100 : INT_ARG(0); - - const double leftEdge = range->e(0); - const double rightEdge = range->e(1); - - REQUIRE_TRUE(leftEdge < rightEdge, 0, "HISTOGRAM_FIXED_WIDTH OP: wrong content of range input array, bottom_edge must be smaller than top_edge, but got %f and %f correspondingly !", leftEdge, rightEdge); - REQUIRE_TRUE(nbins >= 1, 0, "HISTOGRAM_FIXED_WIDTH OP: wrong nbins value, expected value should be >= 1, however got %i instead !", nbins); - - helpers::histogramFixedWidth(block.launchContext(), *input, *range, *output); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto range = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + const int nbins = block.width() == 3 + ? INPUT_VARIABLE(2)->e(0) + : block.getIArguments().empty() ? 100 : INT_ARG(0); + + const double leftEdge = range->e(0); + const double rightEdge = range->e(1); + + REQUIRE_TRUE(leftEdge < rightEdge, 0, + "HISTOGRAM_FIXED_WIDTH OP: wrong content of range input array, " + "bottom_edge must be smaller than top_edge, but got %f and %f " + "correspondingly !", + leftEdge, rightEdge); + REQUIRE_TRUE(nbins >= 1, 0, + "HISTOGRAM_FIXED_WIDTH OP: wrong nbins value, expected value " + "should be >= 1, however got %i instead !", + nbins); + + helpers::histogramFixedWidth(block.launchContext(), *input, *range, *output); + + return Status::OK(); } DECLARE_TYPES(histogram_fixed_width) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_INDICES}); + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_INDICES}); } - ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(histogram_fixed_width) { - - const int nbins = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : block.getIArguments().empty() ? 100 : INT_ARG(0); - auto outShapeInfo = ConstantShapeHelper::getInstance()->vectorShapeInfo(nbins, DataType::INT64); - return SHAPELIST(outShapeInfo); + const int nbins = block.width() == 3 + ? INPUT_VARIABLE(2)->e(0) + : block.getIArguments().empty() ? 100 : INT_ARG(0); + auto outShapeInfo = ConstantShapeHelper::getInstance()->vectorShapeInfo( + nbins, DataType::INT64); + return SHAPELIST(outShapeInfo); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/invertPermutation.cpp b/libnd4j/include/ops/declarable/generic/transforms/invertPermutation.cpp index 3814106cf7ba..2e7ad21d1d73 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/invertPermutation.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/invertPermutation.cpp @@ -22,32 +22,32 @@ #if NOT_EXCLUDED(OP_invert_permutation) #include -#include +#include namespace sd { -namespace ops { +namespace ops { //////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(invert_permutation, 1, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->isVector(), 0 , "INVERT_PERMUTATION op: input array must be vector, but got shape %s instead !", ShapeUtils::shapeAsString(input).c_str()); - - helpers::invertPermutation(block.launchContext(), *input, *output); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->isVector(), 0, + "INVERT_PERMUTATION op: input array must be vector, but got " + "shape %s instead !", + ShapeUtils::shapeAsString(input).c_str()); + + helpers::invertPermutation(block.launchContext(), *input, *output); + + return Status::OK(); } - + DECLARE_SYN(InvertPermutation, invert_permutation); - DECLARE_TYPES(invert_permutation) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } -} +DECLARE_TYPES(invert_permutation) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/merge_add.cpp b/libnd4j/include/ops/declarable/generic/transforms/merge_add.cpp index 64858001a43d..9dd1c8eec42d 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/merge_add.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/merge_add.cpp @@ -22,25 +22,23 @@ #if NOT_EXCLUDED(OP_mergeadd) #include -#include +#include namespace sd { -namespace ops { +namespace ops { OP_IMPL(mergeadd, -1, 1, false) { - - REQUIRE_OK(this->validateInputDimensionsMatch(block)); - - auto output = OUTPUT_VARIABLE(0); + REQUIRE_OK(this->validateInputDimensionsMatch(block)); - std::vector inArrs(block.width()); - - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); + auto output = OUTPUT_VARIABLE(0); - helpers::mergeAdd(block.launchContext(), inArrs, *output); + std::vector inArrs(block.width()); - return Status::OK(); + for (int i = 0; i < block.width(); ++i) inArrs[i] = INPUT_VARIABLE(i); + + helpers::mergeAdd(block.launchContext(), inArrs, *output); + + return Status::OK(); } DECLARE_SYN(mergesum, mergeadd); DECLARE_SYN(add_n, mergeadd); @@ -48,51 +46,50 @@ DECLARE_SYN(addn, mergeadd); DECLARE_SYN(accumulaten, mergeadd); DECLARE_SYN(accumulate_n, mergeadd); - DECLARE_TYPES(mergeadd) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - - - CUSTOM_OP_IMPL(mergeadd_bp, 2, 1, false, 0, 0) { - - auto inSize = block.width() - 1; +DECLARE_TYPES(mergeadd) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); +} - REQUIRE_OK(this->validateInputDimensionsMatch(block)); +CUSTOM_OP_IMPL(mergeadd_bp, 2, 1, false, 0, 0) { + auto inSize = block.width() - 1; - std::vector outArrs(inSize); - - const auto gradient = INPUT_VARIABLE(inSize); + REQUIRE_OK(this->validateInputDimensionsMatch(block)); - for (int i = 0; i < inSize; ++i) { - outArrs[i] = OUTPUT_VARIABLE(i); - } - helpers::mergeAddBp(block.launchContext(), *gradient, outArrs); + std::vector outArrs(inSize); - return Status::OK(); - } + const auto gradient = INPUT_VARIABLE(inSize); - DECLARE_TYPES(mergeadd_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - DECLARE_SHAPE_FN(mergeadd_bp) { + for (int i = 0; i < inSize; ++i) { + outArrs[i] = OUTPUT_VARIABLE(i); + } + helpers::mergeAddBp(block.launchContext(), *gradient, outArrs); - const int numOfInArrs = block.width() - 1; + return Status::OK(); +} - auto shapeList = SHAPELIST(); +DECLARE_TYPES(mergeadd_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); +} +DECLARE_SHAPE_FN(mergeadd_bp) { + const int numOfInArrs = block.width() - 1; - for (int e = 0; e < numOfInArrs; e++) { - auto inShape = inputShape->at(e); - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape)))); - } + auto shapeList = SHAPELIST(); - return shapeList; - } + for (int e = 0; e < numOfInArrs; e++) { + auto inShape = inputShape->at(e); + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), + shape::shapeOf(inShape), shape::rank(inShape)))); + } + return shapeList; } -} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/merge_avg.cpp b/libnd4j/include/ops/declarable/generic/transforms/merge_avg.cpp index 83a448170651..995ff69088f9 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/merge_avg.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/merge_avg.cpp @@ -22,71 +22,68 @@ #if NOT_EXCLUDED(OP_mergeavg) #include -#include +#include namespace sd { -namespace ops { +namespace ops { OP_IMPL(mergeavg, -1, 1, false) { - - REQUIRE_OK(this->validateInputDimensionsMatch(block)); - - auto output = OUTPUT_VARIABLE(0); + REQUIRE_OK(this->validateInputDimensionsMatch(block)); - std::vector inArrs(block.width()); - - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); + auto output = OUTPUT_VARIABLE(0); - helpers::mergeAvg(block.launchContext(), inArrs, *output); + std::vector inArrs(block.width()); - return Status::OK(); -} - - DECLARE_TYPES(mergeavg) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } + for (int i = 0; i < block.width(); ++i) inArrs[i] = INPUT_VARIABLE(i); + helpers::mergeAvg(block.launchContext(), inArrs, *output); - CUSTOM_OP_IMPL(mergeavg_bp, 2, 1, false, 0, 0) { + return Status::OK(); +} - auto inSize = block.width() - 1; +DECLARE_TYPES(mergeavg) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - REQUIRE_OK(this->validateInputDimensionsMatch(block)); +CUSTOM_OP_IMPL(mergeavg_bp, 2, 1, false, 0, 0) { + auto inSize = block.width() - 1; - std::vector outArrs(inSize); + REQUIRE_OK(this->validateInputDimensionsMatch(block)); - const auto gradient = INPUT_VARIABLE(inSize); - - for (int i = 0; i < inSize; ++i) { - outArrs[i] = OUTPUT_VARIABLE(i); - } - helpers::mergeAvgBp(block.launchContext(), *gradient, outArrs); - return Status::OK(); - } + std::vector outArrs(inSize); - DECLARE_TYPES(mergeavg_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - DECLARE_SHAPE_FN(mergeavg_bp) { + const auto gradient = INPUT_VARIABLE(inSize); - const int numOfInArrs = block.width() - 1; + for (int i = 0; i < inSize; ++i) { + outArrs[i] = OUTPUT_VARIABLE(i); + } + helpers::mergeAvgBp(block.launchContext(), *gradient, outArrs); + return Status::OK(); +} - auto shapeList = SHAPELIST(); +DECLARE_TYPES(mergeavg_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); +} +DECLARE_SHAPE_FN(mergeavg_bp) { + const int numOfInArrs = block.width() - 1; - for (int e = 0; e < numOfInArrs; e++) { - auto inShape = inputShape->at(e); - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape)))); - } + auto shapeList = SHAPELIST(); - return shapeList; - } + for (int e = 0; e < numOfInArrs; e++) { + auto inShape = inputShape->at(e); + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), + shape::shapeOf(inShape), shape::rank(inShape)))); + } + return shapeList; } -} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/merge_max.cpp b/libnd4j/include/ops/declarable/generic/transforms/merge_max.cpp index 49ab78f7c1e9..adbf72d1a9b1 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/merge_max.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/merge_max.cpp @@ -22,76 +22,72 @@ #if NOT_EXCLUDED(OP_mergemax) #include -#include +#include namespace sd { -namespace ops { - -OP_IMPL(mergemax, -1, 1, false) { - - REQUIRE_OK(this->validateInputDimensionsMatch(block)); - - auto output = OUTPUT_VARIABLE(0); +namespace ops { - std::vector inArrs(block.width()); - - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); +OP_IMPL(mergemax, -1, 1, false) { + REQUIRE_OK(this->validateInputDimensionsMatch(block)); - helpers::mergeMax(block.launchContext(), inArrs, *output); + auto output = OUTPUT_VARIABLE(0); - return Status::OK(); -} -DECLARE_SYN(MergeMax, mergemax); + std::vector inArrs(block.width()); - DECLARE_TYPES(mergemax) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } + for (int i = 0; i < block.width(); ++i) inArrs[i] = INPUT_VARIABLE(i); + helpers::mergeMax(block.launchContext(), inArrs, *output); - CUSTOM_OP_IMPL(mergemax_bp, 2, 1, false, 0, 0) { + return Status::OK(); +} +DECLARE_SYN(MergeMax, mergemax); - auto inSize = block.width(); +DECLARE_TYPES(mergemax) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); +} - REQUIRE_OK(this->validateInputDimensionsMatch(block)); +CUSTOM_OP_IMPL(mergemax_bp, 2, 1, false, 0, 0) { + auto inSize = block.width(); - std::vector inArrs(inSize); - std::vector outArrs(inSize - 1); + REQUIRE_OK(this->validateInputDimensionsMatch(block)); - for (int i = 0; i < inSize; ++i) - inArrs[i] = INPUT_VARIABLE(i); + std::vector inArrs(inSize); + std::vector outArrs(inSize - 1); - for (int i = 0; i < (inSize - 1); ++i) { - outArrs[i] = OUTPUT_NULLIFIED(i); - } + for (int i = 0; i < inSize; ++i) inArrs[i] = INPUT_VARIABLE(i); - helpers::mergeMaxBp(block.launchContext(), inArrs, outArrs); + for (int i = 0; i < (inSize - 1); ++i) { + outArrs[i] = OUTPUT_NULLIFIED(i); + } - return Status::OK(); - } + helpers::mergeMaxBp(block.launchContext(), inArrs, outArrs); - DECLARE_TYPES(mergemax_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - DECLARE_SHAPE_FN(mergemax_bp) { + return Status::OK(); +} - const int numOfInArrs = block.width() - 1; +DECLARE_TYPES(mergemax_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::ANY); +} +DECLARE_SHAPE_FN(mergemax_bp) { + const int numOfInArrs = block.width() - 1; - auto shapeList = SHAPELIST(); - - for (int e = 0; e < numOfInArrs; e++) { - auto inShape = inputShape->at(e); - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape)))); - } + auto shapeList = SHAPELIST(); - return shapeList; - } + for (int e = 0; e < numOfInArrs; e++) { + auto inShape = inputShape->at(e); + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), + shape::shapeOf(inShape), shape::rank(inShape)))); + } + return shapeList; } -} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp b/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp index 0a4e15ff44f9..a344dbeeee62 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp @@ -22,42 +22,39 @@ #if NOT_EXCLUDED(OP_mergemaxindex) #include -#include +#include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(mergemaxindex, -1, 1, false, 0, 0) { + REQUIRE_OK(this->validateInputDimensionsMatch(block)); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_OK(this->validateInputDimensionsMatch(block)); - auto output = OUTPUT_VARIABLE(0); + std::vector inArrs(block.width()); - std::vector inArrs(block.width()); - - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); + for (int i = 0; i < block.width(); ++i) inArrs[i] = INPUT_VARIABLE(i); - helpers::mergeMaxIndex(block.launchContext(), inArrs, *output); + helpers::mergeMaxIndex(block.launchContext(), inArrs, *output); - return Status::OK(); + return Status::OK(); } DECLARE_SYN(MergeMaxIndex, mergemaxindex); - DECLARE_TYPES(mergemaxindex) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}); - } +DECLARE_TYPES(mergemaxindex) { + getOpDescriptor()->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops DECLARE_SHAPE_FN(mergemaxindex) { - auto in = inputShape->at(0); - auto dtype = DataType::INT32; - if (block.numI() > 0) - dtype = (DataType)INT_ARG(0); + auto in = inputShape->at(0); + auto dtype = DataType::INT32; + if (block.numI() > 0) dtype = (DataType)INT_ARG(0); - auto resShape = ShapeBuilders::copyShapeInfoAndType(in, dtype, block.workspace()); - return SHAPELIST(CONSTANT(resShape)); -} + auto resShape = + ShapeBuilders::copyShapeInfoAndType(in, dtype, block.workspace()); + return SHAPELIST(CONSTANT(resShape)); } +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp b/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp index ed78dda2d1c6..427b1cbb7299 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp @@ -21,82 +21,132 @@ #if NOT_EXCLUDED(OP_mirror_pad) #include -#include +#include namespace sd { namespace ops { - CUSTOM_OP_IMPL(mirror_pad, 2, 1, false, 0, 1) { - - auto input = INPUT_VARIABLE(0); - auto paddings = INPUT_VARIABLE(1); - - auto output = OUTPUT_VARIABLE(0); - - const int mode = INT_ARG(0); // 0 - REFLECT, else - SYMMETRIC - const int includeBorder = mode ? 0 : 1; - - if(input->rankOf() <= 1) { // when input is scalar or vector; - REQUIRE_TRUE(paddings->lengthOf() == 2, 0, "MIRROR_PAD OP: the length of paddings array must be equal 2, when input array is vector or scalar, bot but got %i instead !", paddings->rankOf()); - REQUIRE_TRUE( (paddings->e(0) <= (input->lengthOf() - includeBorder)) && (paddings->e(1) <= (input->lengthOf() - includeBorder)), 0, "MIRROR_PAD OP: wrong content of paddings array, its elements must be no grater then length of input array (being vector or scalar) for symmetric mode (or length-1 for reflect mode) !"); - } - else { - REQUIRE_TRUE(paddings->rankOf() == 2, 0, "MIRROR_PAD OP: the rank of paddings array must be equal 2, but got %i instead !", paddings->rankOf()); - REQUIRE_TRUE(paddings->sizeAt(0) == input->rankOf(), 0, "MIRROR_PAD OP: zero dimension of paddings array must be equal to input array rank, but got %i and %i correspondingly !", paddings->sizeAt(0), input->rankOf()); - - for(int i = 0; i < input->rankOf(); ++i) - REQUIRE_TRUE( (paddings->e(i,0) <= (input->sizeAt(i) - includeBorder)) && (paddings->e(i,1) <= (input->sizeAt(i) - includeBorder)), 0, "MIRROR_PAD OP: wrong content of paddings array, its elements must be no grater then corresponding dimension of input array for symmetric mode (or dimension-1 for reflect mode) !"); - } - - helpers::mirrorPad(block.launchContext(), *input, *paddings, *output, mode); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto paddings = INPUT_VARIABLE(1); + + auto output = OUTPUT_VARIABLE(0); + + const int mode = INT_ARG(0); // 0 - REFLECT, else - SYMMETRIC + const int includeBorder = mode ? 0 : 1; + + if (input->rankOf() <= 1) { // when input is scalar or vector; + REQUIRE_TRUE( + paddings->lengthOf() == 2, 0, + "MIRROR_PAD OP: the length of paddings array must be equal 2, when " + "input array is vector or scalar, bot but got %i instead !", + paddings->rankOf()); + REQUIRE_TRUE( + (paddings->e(0) <= (input->lengthOf() - includeBorder)) && + (paddings->e(1) <= (input->lengthOf() - includeBorder)), + 0, + "MIRROR_PAD OP: wrong content of paddings array, its elements must be " + "no grater then length of input array (being vector or scalar) for " + "symmetric mode (or length-1 for reflect mode) !"); + } else { + REQUIRE_TRUE(paddings->rankOf() == 2, 0, + "MIRROR_PAD OP: the rank of paddings array must be equal 2, " + "but got %i instead !", + paddings->rankOf()); + REQUIRE_TRUE( + paddings->sizeAt(0) == input->rankOf(), 0, + "MIRROR_PAD OP: zero dimension of paddings array must be equal to " + "input array rank, but got %i and %i correspondingly !", + paddings->sizeAt(0), input->rankOf()); + + for (int i = 0; i < input->rankOf(); ++i) + REQUIRE_TRUE( + (paddings->e(i, 0) <= (input->sizeAt(i) - includeBorder)) && + (paddings->e(i, 1) <= + (input->sizeAt(i) - includeBorder)), + 0, + "MIRROR_PAD OP: wrong content of paddings array, its elements must " + "be no grater then corresponding dimension of input array for " + "symmetric mode (or dimension-1 for reflect mode) !"); + } + + helpers::mirrorPad(block.launchContext(), *input, *paddings, *output, mode); + + return Status::OK(); } - DECLARE_TYPES(mirror_pad) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}); - getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); // to conform with TF - getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS}); - } +DECLARE_TYPES(mirror_pad) { + getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}); + getOpDescriptor()->setAllowedInputTypes( + 1, {DataType::INT32, DataType::INT64}); // to conform with TF + getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS}); +} DECLARE_SHAPE_FN(mirror_pad) { - - auto input = INPUT_VARIABLE(0); - auto paddings = INPUT_VARIABLE(1); - - const int rank = input->rankOf() ? input->rankOf() : 1; // if scalar is input then vector is output - const int includeBorder = static_cast(INT_ARG(0)) ? 0 : 1; - - if(rank == 1) { // when input is scalar or vector; - REQUIRE_TRUE(paddings->lengthOf() == 2, 0, "MIRROR_PAD OP: the length of paddings array must be equal 2, when input array is vector or scalar, bot but got %i instead !", paddings->rankOf()); - REQUIRE_TRUE( (paddings->e(0) <= (input->lengthOf() - includeBorder)) && (paddings->e(1) <= (input->lengthOf() - includeBorder)), 0, "MIRROR_PAD OP: wrong content of paddings array, its elements must be no grater then length of input array (being vector or scalar) for symmetric mode (or length-1 for reflect mode) !"); - } - else { - REQUIRE_TRUE(paddings->rankOf() == 2, 0, "MIRROR_PAD OP: the rank of paddings array must be equal 2, but got %i instead !", paddings->rankOf()); - REQUIRE_TRUE(paddings->sizeAt(0) == input->rankOf(), 0, "MIRROR_PAD OP: zero dimension of paddings array must be equal to input array rank, but got %i and %i correspondingly !", paddings->sizeAt(0), input->rankOf()); - for(int i = 0; i < input->rankOf(); ++i) - REQUIRE_TRUE( (paddings->e(i,0) <= (input->sizeAt(i) - includeBorder)) && (paddings->e(i,1) <= (input->sizeAt(i) - includeBorder)), 0, "MIRROR_PAD OP: wrong content of paddings array, its elements must be no grater then corresponding dimension of input array for symmetric mode (or dimension-1 for reflect mode) !"); - } - - if(rank == 1) { - Nd4jLong len = input->lengthOf() + paddings->e(0) + paddings->e(1); - return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(len, input->dataType())); - } - - Nd4jLong* outShapeInfo(nullptr); - - ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); - outShapeInfo[0] = rank; - for(int i = 0; i < rank; ++i) - outShapeInfo[i+1] = input->sizeAt(i) + paddings->e(i,0) + paddings->e(i,1); - ShapeUtils::updateStridesAndType(outShapeInfo, input->shapeInfo(), input->ordering()); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto input = INPUT_VARIABLE(0); + auto paddings = INPUT_VARIABLE(1); + + const int rank = input->rankOf() + ? input->rankOf() + : 1; // if scalar is input then vector is output + const int includeBorder = static_cast(INT_ARG(0)) ? 0 : 1; + + if (rank == 1) { // when input is scalar or vector; + REQUIRE_TRUE( + paddings->lengthOf() == 2, 0, + "MIRROR_PAD OP: the length of paddings array must be equal 2, when " + "input array is vector or scalar, bot but got %i instead !", + paddings->rankOf()); + REQUIRE_TRUE( + (paddings->e(0) <= (input->lengthOf() - includeBorder)) && + (paddings->e(1) <= (input->lengthOf() - includeBorder)), + 0, + "MIRROR_PAD OP: wrong content of paddings array, its elements must be " + "no grater then length of input array (being vector or scalar) for " + "symmetric mode (or length-1 for reflect mode) !"); + } else { + REQUIRE_TRUE(paddings->rankOf() == 2, 0, + "MIRROR_PAD OP: the rank of paddings array must be equal 2, " + "but got %i instead !", + paddings->rankOf()); + REQUIRE_TRUE( + paddings->sizeAt(0) == input->rankOf(), 0, + "MIRROR_PAD OP: zero dimension of paddings array must be equal to " + "input array rank, but got %i and %i correspondingly !", + paddings->sizeAt(0), input->rankOf()); + for (int i = 0; i < input->rankOf(); ++i) + REQUIRE_TRUE( + (paddings->e(i, 0) <= (input->sizeAt(i) - includeBorder)) && + (paddings->e(i, 1) <= + (input->sizeAt(i) - includeBorder)), + 0, + "MIRROR_PAD OP: wrong content of paddings array, its elements must " + "be no grater then corresponding dimension of input array for " + "symmetric mode (or dimension-1 for reflect mode) !"); + } + + if (rank == 1) { + Nd4jLong len = + input->lengthOf() + paddings->e(0) + paddings->e(1); + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo( + len, input->dataType())); + } + + Nd4jLong* outShapeInfo(nullptr); + + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + outShapeInfo[0] = rank; + for (int i = 0; i < rank; ++i) + outShapeInfo[i + 1] = input->sizeAt(i) + paddings->e(i, 0) + + paddings->e(i, 1); + ShapeUtils::updateStridesAndType(outShapeInfo, input->shapeInfo(), + input->ordering()); + + return SHAPELIST(CONSTANT(outShapeInfo)); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp index 0c5dfc7f0bb2..8260fdc0b733 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp @@ -22,94 +22,119 @@ #if NOT_EXCLUDED(OP_pad) #include -#include +#include + #include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(pad, 2, 1, false, 0, 1) { - - auto input = INPUT_VARIABLE(0); - auto paddings = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - const int rank = input->rankOf(); - - // input validation - std::vector expectedPaddingsShape = {rank, 2}; - std::vector currentPaddingsShape = paddings->getShapeAsVector(); - REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, "PAD op: wrong shape of paddings array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedPaddingsShape).c_str(), ShapeUtils::shapeAsString(currentPaddingsShape).c_str()); - - NDArray padValue(input->dataType(), block.launchContext()); - - // in case of REFLECT and SYMMETRIC modes paddings must obey additional shape requirements - if (INT_ARG(0) == 0) { // CONSTANT mode - if(block.width() > 2) { - REQUIRE_TRUE(input->dataType() == INPUT_VARIABLE(2)->dataType(), 0, "PAD op: data types of input and padValue arrays should be the same but got %i and %i correspondingly !", input->dataType(), INPUT_VARIABLE(2)->dataType()); - padValue.assign(INPUT_VARIABLE(2)->e(0)); - } - else if (!block.getTArguments().empty()) - padValue = T_ARG(0); - } - else if(INT_ARG(0) == 1) { // REFLECT mode - for(int dim=0; dim < rank; ++dim) - REQUIRE_TRUE(paddings->e(dim,0) <= (input->shapeOf()[dim]-1) && paddings->e(dim,1) <= (input->shapeOf()[dim]-1), 0, "PAD op: wrong content of paddings array for REFLECT mode !"); - } - if(INT_ARG(0) == 2) { // SYMMETRIC mode - for(int dim=0; dim < rank; ++dim) - REQUIRE_TRUE(paddings->e(dim,0) <= input->shapeOf()[dim] && paddings->e(dim,1) <= input->shapeOf()[dim], 0, "PAD op: wrong content of paddings array for SYMMETRIC mode !"); - } - - // CONSTANT->0, REFLECT->1, SYMMETRIC->2 - REQUIRE_TRUE(INT_ARG(0) >= 0 && INT_ARG(0) <= 2, 0, "PAD op: unknown padding mode, there are only three possible legal values -> 0,1,2, but got %i instead !", INT_ARG(0)); - - // std::vector dimensions(input->rankOf()); - // std::iota(dimensions.begin(), dimensions.end(), 0); // fill with 0, 1, ... rank-1 - - // helpers::recursiveLoopForPad(INT_ARG(0), *input, *paddings, *output, dimensions, 0, 0, 0, padValue); - helpers::pad(block.launchContext(), INT_ARG(0), *input, *paddings, *output, padValue); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto paddings = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + const int rank = input->rankOf(); + + // input validation + std::vector expectedPaddingsShape = {rank, 2}; + std::vector currentPaddingsShape = paddings->getShapeAsVector(); + REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, + "PAD op: wrong shape of paddings array, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(expectedPaddingsShape).c_str(), + ShapeUtils::shapeAsString(currentPaddingsShape).c_str()); + + NDArray padValue(input->dataType(), block.launchContext()); + + // in case of REFLECT and SYMMETRIC modes paddings must obey additional shape + // requirements + if (INT_ARG(0) == 0) { // CONSTANT mode + if (block.width() > 2) { + REQUIRE_TRUE(input->dataType() == INPUT_VARIABLE(2)->dataType(), 0, + "PAD op: data types of input and padValue arrays should be " + "the same but got %i and %i correspondingly !", + input->dataType(), INPUT_VARIABLE(2)->dataType()); + padValue.assign(INPUT_VARIABLE(2)->e(0)); + } else if (!block.getTArguments().empty()) + padValue = T_ARG(0); + } else if (INT_ARG(0) == 1) { // REFLECT mode + for (int dim = 0; dim < rank; ++dim) + REQUIRE_TRUE( + paddings->e(dim, 0) <= (input->shapeOf()[dim] - 1) && + paddings->e(dim, 1) <= (input->shapeOf()[dim] - 1), + 0, "PAD op: wrong content of paddings array for REFLECT mode !"); + } + if (INT_ARG(0) == 2) { // SYMMETRIC mode + for (int dim = 0; dim < rank; ++dim) + REQUIRE_TRUE( + paddings->e(dim, 0) <= input->shapeOf()[dim] && + paddings->e(dim, 1) <= input->shapeOf()[dim], + 0, "PAD op: wrong content of paddings array for SYMMETRIC mode !"); + } + + // CONSTANT->0, REFLECT->1, SYMMETRIC->2 + REQUIRE_TRUE(INT_ARG(0) >= 0 && INT_ARG(0) <= 2, 0, + "PAD op: unknown padding mode, there are only three possible " + "legal values -> 0,1,2, but got %i instead !", + INT_ARG(0)); + + // std::vector dimensions(input->rankOf()); + // std::iota(dimensions.begin(), dimensions.end(), 0); + // // fill with 0, 1, ... rank-1 + + // helpers::recursiveLoopForPad(INT_ARG(0), *input, *paddings, *output, + // dimensions, 0, 0, 0, padValue); + helpers::pad(block.launchContext(), INT_ARG(0), *input, *paddings, *output, + padValue); + + return Status::OK(); } DECLARE_TYPES(pad) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF -// ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF, but used also INT64 due long shapes - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes( + 1, {DataType::INT32, DataType::INT64}) // INT32 with TF + // ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // + // INT32 with TF, but used also INT64 due long shapes + ->setSameMode(true); } DECLARE_SHAPE_FN(pad) { - - // check shape of paddings - auto inputShapeInfo = inputShape->at(0); - auto paddings = INPUT_VARIABLE(1); - const int rank = inputShapeInfo[0]; - - // paddings validation - const std::vector expectedPaddingsShape = {rank, 2}; - const std::vector currentPaddingsShape = paddings->getShapeAsVector(); - REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, "PAD op: wrong shape of paddings array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedPaddingsShape).c_str(), ShapeUtils::shapeAsString(currentPaddingsShape).c_str()); - - Nd4jLong* outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), Nd4jLong); - outShapeInfo[0] = rank; - for(int i=1; i <= rank; ++i) - outShapeInfo[i] = inputShapeInfo[i] + paddings->e(i-1,0) + paddings->e(i-1,1); - - ShapeUtils::updateStridesAndType(outShapeInfo, inputShapeInfo, shape::order(inputShapeInfo)); - ShapeDescriptor descriptor(outShapeInfo); - RELEASE(outShapeInfo, block.workspace()); - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor)); + // check shape of paddings + auto inputShapeInfo = inputShape->at(0); + auto paddings = INPUT_VARIABLE(1); + const int rank = inputShapeInfo[0]; + + // paddings validation + const std::vector expectedPaddingsShape = {rank, 2}; + const std::vector currentPaddingsShape = + paddings->getShapeAsVector(); + REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, + "PAD op: wrong shape of paddings array, expected is %s, but got " + "%s instead !", + ShapeUtils::shapeAsString(expectedPaddingsShape).c_str(), + ShapeUtils::shapeAsString(currentPaddingsShape).c_str()); + + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank), + Nd4jLong); + outShapeInfo[0] = rank; + for (int i = 1; i <= rank; ++i) + outShapeInfo[i] = inputShapeInfo[i] + paddings->e(i - 1, 0) + + paddings->e(i - 1, 1); + + ShapeUtils::updateStridesAndType(outShapeInfo, inputShapeInfo, + shape::order(inputShapeInfo)); + ShapeDescriptor descriptor(outShapeInfo); + RELEASE(outShapeInfo, block.workspace()); + return SHAPELIST( + ConstantShapeHelper::getInstance()->createShapeInfo(descriptor)); } - - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/parallelStack.cpp b/libnd4j/include/ops/declarable/generic/transforms/parallelStack.cpp index 12038c763b82..65817f6ad680 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/parallelStack.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/parallelStack.cpp @@ -22,56 +22,57 @@ #if NOT_EXCLUDED(OP_parallel_stack) #include -#include +#include namespace sd { -namespace ops { - +namespace ops { CUSTOM_OP_IMPL(parallel_stack, -1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - // check whether shapes of all input array are the same - for (int i = 0; i < (int) block.width() - 1; ++i) - REQUIRE_TRUE(shape::equalsSoft((INPUT_VARIABLE(i))->shapeInfo(), (INPUT_VARIABLE(i+1))->shapeInfo()), 0, "PARALLEL_STACK op: the shapes of all input arrays must be the same !"); + // check whether shapes of all input array are the same + for (int i = 0; i < (int)block.width() - 1; ++i) + REQUIRE_TRUE( + shape::equalsSoft((INPUT_VARIABLE(i))->shapeInfo(), + (INPUT_VARIABLE(i + 1))->shapeInfo()), + 0, + "PARALLEL_STACK op: the shapes of all input arrays must be the same !"); - std::vector inArrs(block.width()); - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); + std::vector inArrs(block.width()); + for (int i = 0; i < block.width(); ++i) inArrs[i] = INPUT_VARIABLE(i); - const int dim = 0; - helpers::stack(block.launchContext(), inArrs, *output, dim); + const int dim = 0; + helpers::stack(block.launchContext(), inArrs, *output, dim); - return Status::OK(); + return Status::OK(); } - DECLARE_TYPES(parallel_stack) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(parallel_stack) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(parallel_stack) { + auto inShapeInfo = inputShape->at(0); + int rank = inShapeInfo[0]; - auto inShapeInfo = inputShape->at(0); - int rank = inShapeInfo[0]; - - Nd4jLong* outShapeInfo = nullptr; - ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank+1), Nd4jLong); + Nd4jLong* outShapeInfo = nullptr; + ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(rank + 1), + Nd4jLong); - outShapeInfo[0] = rank + 1; - outShapeInfo[1] = block.width(); - for(int i = 1; i <= rank; ++i) - outShapeInfo[i+1] = inShapeInfo[i]; + outShapeInfo[0] = rank + 1; + outShapeInfo[1] = block.width(); + for (int i = 1; i <= rank; ++i) outShapeInfo[i + 1] = inShapeInfo[i]; - ShapeUtils::updateStridesAndType(outShapeInfo, inShapeInfo, shape::order(inShapeInfo)); + ShapeUtils::updateStridesAndType(outShapeInfo, inShapeInfo, + shape::order(inShapeInfo)); - return SHAPELIST(CONSTANT(outShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/repeat.cpp b/libnd4j/include/ops/declarable/generic/transforms/repeat.cpp index 4c121d115df8..07a6a5c8cac3 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/repeat.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/repeat.cpp @@ -24,52 +24,58 @@ #include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// -// here iArgs is int vector of repeats at the beginning and last element in iArgs is dimension +// here iArgs is int vector of repeats at the beginning and last element in +// iArgs is dimension CUSTOM_OP_IMPL(repeat, 1, 1, true, 0, -1) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + std::vector repeats = block.getIArguments(); - std::vector repeats = block.getIArguments(); + const int axis = + repeats.back() < 0 ? repeats.back() + input->rankOf() : repeats.back(); - const int axis = repeats.back() < 0 ? repeats.back() + input->rankOf() : repeats.back(); + repeats.pop_back(); - repeats.pop_back(); + REQUIRE_TRUE(0 <= axis && axis < input->rankOf(), 0, + "CUSTOM REPEAT OP: wrong axis argument it should be less then " + "input array rank %i, but got %i instead !", + input->rankOf(), axis); - REQUIRE_TRUE(0 <= axis && axis < input->rankOf(), 0, "CUSTOM REPEAT OP: wrong axis argument it should be less then input array rank %i, but got %i instead !", input->rankOf(), axis); + REQUIRE_TRUE(repeats.size() == 1 || repeats.size() == input->sizeAt(axis), 0, + "CUSTOM REPEAT OP: wrong axis argument, size of repeats vector " + "must be 1 or equal to dimension at given axis, but got " + "repeats.size = %i and axis = %i !", + repeats.size(), axis); - REQUIRE_TRUE(repeats.size() == 1 || repeats.size() == input->sizeAt(axis), 0, "CUSTOM REPEAT OP: wrong axis argument, size of repeats vector must be 1 or equal to dimension at given axis, but got repeats.size = %i and axis = %i !", repeats.size(), axis); + input->repeat(axis, repeats, *output); - input->repeat(axis, repeats, *output); - - return Status::OK(); + return Status::OK(); } DECLARE_TYPES(repeat) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } DECLARE_SHAPE_FN(repeat) { + auto input = INPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - - auto repeats = block.getIArguments(); - - const int axis = repeats.back() < 0 ? repeats.back() + input->rankOf() : repeats.back(); + auto repeats = block.getIArguments(); - repeats.pop_back(); + const int axis = + repeats.back() < 0 ? repeats.back() + input->rankOf() : repeats.back(); - auto outShape = ShapeUtils::evalRepeatShape(axis, repeats, *input); + repeats.pop_back(); - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(input->dataType(), input->ordering(), outShape))); + auto outShape = ShapeUtils::evalRepeatShape(axis, repeats, *input); + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(input->dataType(), input->ordering(), outShape))); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp index 13bda4161e73..c8d05dd76872 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp @@ -24,88 +24,84 @@ #include #include - namespace sd { -namespace ops { - - CONFIGURABLE_OP_IMPL(reverse, 1, 1, true, 0, -2) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if(output->isEmpty()){ - //No-op - return Status::OK(); - } - - std::vector axis; - - if (block.width() > 1) - axis = INPUT_VARIABLE(1)->template asVectorT(); - else if (block.numI() > 0) - axis = block.getIArguments(); - - if(axis.empty()) { // do not perform reversion - if (!block.isInplace()) - output->assign(input); - } - else { - // check the consistency of input dimensions to reverse along - shape::checkDimensions(input->rankOf(), axis); - helpers::reverse(block.launchContext(), input, output, &axis, false); - } - - return Status::OK(); - } - - DECLARE_SYN(reverse_v2, reverse); - - DECLARE_TYPES(reverse) { - getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY); - getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); - getOpDescriptor()->setAllowedOutputTypes(0, DataType::INHERIT); - } - - CUSTOM_OP_IMPL(reverse_bp, 2, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto eps = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); - - auto output = OUTPUT_VARIABLE(0); - std::vector axis; - - if (block.width() == 3) - axis = INPUT_VARIABLE(1)->template asVectorT(); - else if (block.numI() > 0) - axis = block.getIArguments(); - - if(axis.empty()) { // reversion is not performed in this case - output->assign(eps); - } - else { - // check the consistency of input dimensions to reverse along - shape::checkDimensions(input->rankOf(), axis); - // we just reverse back original array - helpers::reverse(block.launchContext(), eps, output, &axis, false); - } - - return Status::OK(); - } - - DECLARE_TYPES(reverse_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - DECLARE_SHAPE_FN(reverse_bp) { - auto in = inputShape->at(0); - Nd4jLong *out; - COPY_SHAPE(in, out); - - return SHAPELIST(CONSTANT(out)); - } +namespace ops { + +CONFIGURABLE_OP_IMPL(reverse, 1, 1, true, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (output->isEmpty()) { + // No-op + return Status::OK(); + } + + std::vector axis; + + if (block.width() > 1) + axis = INPUT_VARIABLE(1)->template asVectorT(); + else if (block.numI() > 0) + axis = block.getIArguments(); + + if (axis.empty()) { // do not perform reversion + if (!block.isInplace()) output->assign(input); + } else { + // check the consistency of input dimensions to reverse along + shape::checkDimensions(input->rankOf(), axis); + helpers::reverse(block.launchContext(), input, output, &axis, false); + } + + return Status::OK(); +} + +DECLARE_SYN(reverse_v2, reverse); + +DECLARE_TYPES(reverse) { + getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY); + getOpDescriptor()->setAllowedInputTypes(1, + {DataType::INT32, DataType::INT64}); + getOpDescriptor()->setAllowedOutputTypes(0, DataType::INHERIT); +} + +CUSTOM_OP_IMPL(reverse_bp, 2, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto eps = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); + + auto output = OUTPUT_VARIABLE(0); + std::vector axis; + + if (block.width() == 3) + axis = INPUT_VARIABLE(1)->template asVectorT(); + else if (block.numI() > 0) + axis = block.getIArguments(); + + if (axis.empty()) { // reversion is not performed in this case + output->assign(eps); + } else { + // check the consistency of input dimensions to reverse along + shape::checkDimensions(input->rankOf(), axis); + // we just reverse back original array + helpers::reverse(block.launchContext(), eps, output, &axis, false); + } + return Status::OK(); } + +DECLARE_TYPES(reverse_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); } +DECLARE_SHAPE_FN(reverse_bp) { + auto in = inputShape->at(0); + Nd4jLong *out; + COPY_SHAPE(in, out); + + return SHAPELIST(CONSTANT(out)); +} + +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/reverseSequence.cpp b/libnd4j/include/ops/declarable/generic/transforms/reverseSequence.cpp index c7dcc6e36dbf..058a5d631e4f 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/reverseSequence.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/reverseSequence.cpp @@ -28,57 +28,99 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(reverse_sequence, 2, 1, false, 0, 2) { - - auto input = INPUT_VARIABLE(0); - auto seqLengths = INPUT_VARIABLE(1); - auto output = OUTPUT_VARIABLE(0); - - int seqDim = INT_ARG(0); - int batchDim = block.numI() > 1 ? INT_ARG(1) : 0; - - REQUIRE_TRUE(input->rankOf() > 1, 0, "REVERSE_SEQUENSE operation: input array must have rank > 1, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(seqLengths->rankOf() == 1, 0, "REVERSE_SEQUENSE operation: input array seqLengths must be 1D vector, that is it must have rank == 1, but got %i instead !", seqLengths->rankOf()); - REQUIRE_TRUE(seqLengths->lengthOf() == input->sizeAt(batchDim), 0, "REVERSE_SEQUENSE custom operation: the length of array seqLengths must be equal to the value of batchDim dimension of input array, but got %i and %i correspondingly !", seqLengths->lengthOf(), input->sizeAt(batchDim)); - REQUIRE_TRUE(seqDim != batchDim, 0, "REVERSE_SEQUENSE operation: input integer parameters seqDim and batchDim must be different, but they both are equal to %i !", batchDim); - REQUIRE_TRUE(batchDim < input->rankOf(), 0, "REVERSE_SEQUENSE operation: input integer parameter batchDim must be smaller than input array rank, but got %i and %i correspondingly !", batchDim, input->rankOf()); - REQUIRE_TRUE(seqDim < input->rankOf(), 0, "REVERSE_SEQUENSE operation: input integer parameter seqDim must be smaller than input array rank, but got %i and %i correspondingly !", seqDim, input->rankOf()); - - auto maxElem = seqLengths->reduceNumber(reduce::Max); - REQUIRE_TRUE(maxElem.e(0) <= input->sizeAt(seqDim), 0, "REVERSE_SEQUENSE operation: max element in seqLengths array must be not greater than value of seqDim dimension of input array !"); - - helpers::reverseSequence(block.launchContext(), input, seqLengths, output, seqDim, batchDim); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto seqLengths = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + int seqDim = INT_ARG(0); + int batchDim = block.numI() > 1 ? INT_ARG(1) : 0; + + REQUIRE_TRUE(input->rankOf() > 1, 0, + "REVERSE_SEQUENSE operation: input array must have rank > 1, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(seqLengths->rankOf() == 1, 0, + "REVERSE_SEQUENSE operation: input array seqLengths must be 1D " + "vector, that is it must have rank == 1, but got %i instead !", + seqLengths->rankOf()); + REQUIRE_TRUE(seqLengths->lengthOf() == input->sizeAt(batchDim), 0, + "REVERSE_SEQUENSE custom operation: the length of array " + "seqLengths must be equal to the value of batchDim dimension of " + "input array, but got %i and %i correspondingly !", + seqLengths->lengthOf(), input->sizeAt(batchDim)); + REQUIRE_TRUE( + seqDim != batchDim, 0, + "REVERSE_SEQUENSE operation: input integer parameters seqDim and " + "batchDim must be different, but they both are equal to %i !", + batchDim); + REQUIRE_TRUE( + batchDim < input->rankOf(), 0, + "REVERSE_SEQUENSE operation: input integer parameter batchDim must be " + "smaller than input array rank, but got %i and %i correspondingly !", + batchDim, input->rankOf()); + REQUIRE_TRUE( + seqDim < input->rankOf(), 0, + "REVERSE_SEQUENSE operation: input integer parameter seqDim must be " + "smaller than input array rank, but got %i and %i correspondingly !", + seqDim, input->rankOf()); + + auto maxElem = seqLengths->reduceNumber(reduce::Max); + REQUIRE_TRUE( + maxElem.e(0) <= input->sizeAt(seqDim), 0, + "REVERSE_SEQUENSE operation: max element in seqLengths array must be not " + "greater than value of seqDim dimension of input array !"); + + helpers::reverseSequence(block.launchContext(), input, seqLengths, output, + seqDim, batchDim); + + return Status::OK(); } - DECLARE_TYPES(reverse_sequence) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}); - getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); - getOpDescriptor()->setAllowedOutputTypes(0, DataType::INHERIT); - } +DECLARE_TYPES(reverse_sequence) { + getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}); + getOpDescriptor()->setAllowedInputTypes(1, + {DataType::INT32, DataType::INT64}); + getOpDescriptor()->setAllowedOutputTypes(0, DataType::INHERIT); +} DECLARE_SHAPE_FN(reverse_sequence) { - - auto inShapeInfo = inputShape->at(0); - auto seqLenShapeInfo = inputShape->at(1); - - int seqDim = INT_ARG(0); - int batchDim = block.numI() > 1 ? INT_ARG(1) : 0; - - REQUIRE_TRUE(batchDim < inShapeInfo[0], 0, "REVERSE_SEQUENSE operation: input integer parameter batchDim must be smaller than input array rank, but got %i and %i correspondingly !", batchDim, inShapeInfo[0]); - REQUIRE_TRUE(seqDim < inShapeInfo[0], 0, "REVERSE_SEQUENSE operation: input integer parameter seqDim must be smaller than input array rank, but got %i and %i correspondingly !", seqDim, inShapeInfo[0]); - REQUIRE_TRUE(inShapeInfo[0] > 1, 0, "REVERSE_SEQUENSE operation: input array must have rank > 1, but got %i instead !", inShapeInfo[0]); - REQUIRE_TRUE(seqLenShapeInfo[0] == 1, 0, "REVERSE_SEQUENSE operation: input array seqLengths must be 1D vector, that is it must have rank == 1, but got %i instead !", seqLenShapeInfo[0]); - REQUIRE_TRUE(seqLenShapeInfo[1] == inShapeInfo[batchDim+1], 0, "REVERSE_SEQUENSE custom operation: the length of array seqLengths must be equal to the value of batchDim dimension of input array, but got %i and %i correspondingly !", seqLenShapeInfo[1], inShapeInfo[batchDim+1]); - - Nd4jLong* outShapeInfo = nullptr; - COPY_SHAPE(inShapeInfo, outShapeInfo); - - return SHAPELIST(CONSTANT(outShapeInfo)); + auto inShapeInfo = inputShape->at(0); + auto seqLenShapeInfo = inputShape->at(1); + + int seqDim = INT_ARG(0); + int batchDim = block.numI() > 1 ? INT_ARG(1) : 0; + + REQUIRE_TRUE( + batchDim < inShapeInfo[0], 0, + "REVERSE_SEQUENSE operation: input integer parameter batchDim must be " + "smaller than input array rank, but got %i and %i correspondingly !", + batchDim, inShapeInfo[0]); + REQUIRE_TRUE( + seqDim < inShapeInfo[0], 0, + "REVERSE_SEQUENSE operation: input integer parameter seqDim must be " + "smaller than input array rank, but got %i and %i correspondingly !", + seqDim, inShapeInfo[0]); + REQUIRE_TRUE(inShapeInfo[0] > 1, 0, + "REVERSE_SEQUENSE operation: input array must have rank > 1, " + "but got %i instead !", + inShapeInfo[0]); + REQUIRE_TRUE(seqLenShapeInfo[0] == 1, 0, + "REVERSE_SEQUENSE operation: input array seqLengths must be 1D " + "vector, that is it must have rank == 1, but got %i instead !", + seqLenShapeInfo[0]); + REQUIRE_TRUE(seqLenShapeInfo[1] == inShapeInfo[batchDim + 1], 0, + "REVERSE_SEQUENSE custom operation: the length of array " + "seqLengths must be equal to the value of batchDim dimension of " + "input array, but got %i and %i correspondingly !", + seqLenShapeInfo[1], inShapeInfo[batchDim + 1]); + + Nd4jLong* outShapeInfo = nullptr; + COPY_SHAPE(inShapeInfo, outShapeInfo); + + return SHAPELIST(CONSTANT(outShapeInfo)); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_add.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_add.cpp index 09a0efc5585e..e0a8f48241df 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_add.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_add.cpp @@ -29,73 +29,90 @@ namespace sd { namespace ops { OP_IMPL(scatter_add, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - const bool lock = block.getBArguments().empty() ? false : B_ARG(0); - const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - const Nd4jLong indLen = indices->lengthOf(); - - REQUIRE_TRUE(inRank > 0, 0, "SCATTER_ADD OP: input should not be scalar !"); - - if(inRank == 1) { - REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_ADD OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); - } - else if (inRank == updRank && indices->isVector()) { - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + const Nd4jLong indLen = indices->lengthOf(); + + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_ADD OP: input should not be scalar !"); + + if (inRank == 1) { + REQUIRE_TRUE(indices->isSameShape(updates), 0, + "SCATTER_ADD OP: when input array has rank = 1 then indices " + "and updates must have the same shapes, but got %s and %s " + "correspondingly !", + ShapeUtils::shapeAsString(indices).c_str(), + ShapeUtils::shapeAsString(updates).c_str()); + } else if (inRank == updRank && indices->isVector()) { + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = {indices->lengthOf()}; + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_ADD OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } else { + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, + "SCATTER_ADD OP: wrong rank of updates array, expected is %i, " + "but got %i instead !", + indRank + inRank - 1, updRank); + + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); + expectedUpdShape.insert(expectedUpdShape.end(), + inShape.begin() + Nd4jLong(1L), inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_ADD OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + if (!indices->isEmpty()) { + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_ADD OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); } - else { - REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_ADD OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); + helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, + *output, lock); + } - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + Nd4jLong(1L), inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - if (!indices->isEmpty()) { - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ADD OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock); - } - - return Status::OK(); + return Status::OK(); } DECLARE_SYN(ScatterAdd, scatter_add); DECLARE_TYPES(scatter_add) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_div.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_div.cpp index a4febc684483..71b8483cb8d2 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_div.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_div.cpp @@ -26,73 +26,89 @@ #include namespace sd { - namespace ops { - OP_IMPL(scatter_div, 3, 1, true) { - - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - const bool lock = block.getBArguments().empty() ? false : B_ARG(0); - const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - REQUIRE_TRUE(inRank > 0, 0, "SCATTER_DIV OP: input should not be scalar !"); - - if(inRank == 1) { - REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_DIV OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); - } - else if (inRank == updRank && indices->isVector()) { - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - else { - - REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_DIV OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - if (!indices->isEmpty()) { - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_DIV OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock); - } +namespace ops { +OP_IMPL(scatter_div, 3, 1, true) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_DIV OP: input should not be scalar !"); + + if (inRank == 1) { + REQUIRE_TRUE(indices->isSameShape(updates), 0, + "SCATTER_DIV OP: when input array has rank = 1 then indices " + "and updates must have the same shapes, but got %s and %s " + "correspondingly !", + ShapeUtils::shapeAsString(indices).c_str(), + ShapeUtils::shapeAsString(updates).c_str()); + } else if (inRank == updRank && indices->isVector()) { + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = {indices->lengthOf()}; + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_DIV OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } else { + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, + "SCATTER_DIV OP: wrong rank of updates array, expected is %i, " + "but got %i instead !", + indRank + inRank - 1, updRank); + + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_DIV OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + if (!indices->isEmpty()) { + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_DIV OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); + } - return Status::OK(); - } - DECLARE_SYN(ScatterDiv, scatter_div); + helpers::scatter(block.launchContext(), pairwise::Divide, *indices, + *updates, *output, lock); + } - DECLARE_TYPES(scatter_div) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - } + return Status::OK(); +} +DECLARE_SYN(ScatterDiv, scatter_div); + +DECLARE_TYPES(scatter_div) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_max.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_max.cpp index 220c68227b1e..370e95a5e1e0 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_max.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_max.cpp @@ -28,72 +28,88 @@ namespace sd { namespace ops { OP_IMPL(scatter_max, 3, 1, true) { - - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - const bool lock = block.getBArguments().empty() ? false : B_ARG(0); - const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MAX OP: input should not be scalar !"); - - if(inRank == 1) { - REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_MAX OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MAX OP: input should not be scalar !"); + + if (inRank == 1) { + REQUIRE_TRUE(indices->isSameShape(updates), 0, + "SCATTER_MAX OP: when input array has rank = 1 then indices " + "and updates must have the same shapes, but got %s and %s " + "correspondingly !", + ShapeUtils::shapeAsString(indices).c_str(), + ShapeUtils::shapeAsString(updates).c_str()); + } else if (inRank == updRank && indices->isVector()) { + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = {indices->lengthOf()}; + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_MAX OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } else { + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, + "SCATTER_MAX OP: wrong rank of updates array, expected is %i, " + "but got %i instead !", + indRank + inRank - 1, updRank); + + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_MAX OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + if (!indices->isEmpty()) { + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_MAX OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); } - else if (inRank == updRank && indices->isVector()) { - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - else { - - REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_MAX OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); + helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, + *updates, *output, lock); + } - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - if (!indices->isEmpty()) { - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_MAX OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, *updates, *output, lock); - } - - return Status::OK(); + return Status::OK(); } DECLARE_SYN(ScatterMax, scatter_max); - DECLARE_TYPES(scatter_max) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - -} +DECLARE_TYPES(scatter_max) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_min.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_min.cpp index 957541e5e74f..069a8f10d400 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_min.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_min.cpp @@ -28,72 +28,88 @@ namespace sd { namespace ops { OP_IMPL(scatter_min, 3, 1, true) { - - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - const bool lock = block.getBArguments().empty() ? false : B_ARG(0); - const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MIN OP: input should not be scalar !"); - - if(inRank == 1) { - REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_MIN OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MIN OP: input should not be scalar !"); + + if (inRank == 1) { + REQUIRE_TRUE(indices->isSameShape(updates), 0, + "SCATTER_MIN OP: when input array has rank = 1 then indices " + "and updates must have the same shapes, but got %s and %s " + "correspondingly !", + ShapeUtils::shapeAsString(indices).c_str(), + ShapeUtils::shapeAsString(updates).c_str()); + } else if (inRank == updRank && indices->isVector()) { + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = {indices->lengthOf()}; + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_MIN OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } else { + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, + "SCATTER_MIN OP: wrong rank of updates array, expected is %i, " + "but got %i instead !", + indRank + inRank - 1, updRank); + + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_MIN OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + if (!indices->isEmpty()) { + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_MIN OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); } - else if (inRank == updRank && indices->isVector()) { - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - else { - - REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_MIN OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); + helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, + *updates, *output, lock); + } - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - if (!indices->isEmpty()) { - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_MIN OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, *updates, *output, lock); - } - - return Status::OK(); + return Status::OK(); } DECLARE_SYN(ScatterMin, scatter_min); - DECLARE_TYPES(scatter_min) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - -} +DECLARE_TYPES(scatter_min) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_mul.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_mul.cpp index 24e329f6da31..29c65052706b 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_mul.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_mul.cpp @@ -26,74 +26,89 @@ #include namespace sd { - namespace ops { - OP_IMPL(scatter_mul, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - const bool lock = block.getBArguments().empty() ? false : B_ARG(0); - const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - if (!block.isInplace()) - output->assign(input); - - - REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MUL OP: input should not be scalar !"); - - if(inRank == 1) { - REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_MUL OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); - } - else if (inRank == updRank && indices->isVector()) { - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - else { - - REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_MUL OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - if (!indices->isEmpty()) { - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_MUL OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, *updates, *output, lock); - } - - return Status::OK(); - } - DECLARE_SYN(ScatterMul, scatter_mul); +namespace ops { +OP_IMPL(scatter_mul, 3, 1, true) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + if (!block.isInplace()) output->assign(input); + + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MUL OP: input should not be scalar !"); + + if (inRank == 1) { + REQUIRE_TRUE(indices->isSameShape(updates), 0, + "SCATTER_MUL OP: when input array has rank = 1 then indices " + "and updates must have the same shapes, but got %s and %s " + "correspondingly !", + ShapeUtils::shapeAsString(indices).c_str(), + ShapeUtils::shapeAsString(updates).c_str()); + } else if (inRank == updRank && indices->isVector()) { + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = {indices->lengthOf()}; + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_MUL OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } else { + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, + "SCATTER_MUL OP: wrong rank of updates array, expected is %i, " + "but got %i instead !", + indRank + inRank - 1, updRank); + + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_MUL OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + if (!indices->isEmpty()) { + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_MUL OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); + } - DECLARE_TYPES(scatter_mul) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); + helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, + *updates, *output, lock); + } - } - } + return Status::OK(); +} +DECLARE_SYN(ScatterMul, scatter_mul); + +DECLARE_TYPES(scatter_mul) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp index 1a907b1331a8..34eb2f1c0ca6 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd.cpp @@ -27,73 +27,99 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(scatter_nd, 3, 1, false, 0, 0) { - auto indices = INPUT_VARIABLE(0); - auto updates = INPUT_VARIABLE(1); - auto shape = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - const bool lock = block.getBArguments().empty() ? false : B_ARG(0); - const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); - - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - const int shapeRank = shape->rankOf(); - const Nd4jLong shapeLen = shape->lengthOf(); - - REQUIRE_TRUE(shapeRank == 1, 0, "SCATTER_ND OP: the rank of shape array must be 1, but got %i instead !", shapeRank); - REQUIRE_TRUE(indices->sizeAt(-1) <= shapeLen, 0, "SCATTER_ND OP: last dimension of indices array must be <= length of shape array, but got %i and %i correspondingly !", indices->sizeAt(-1), shapeLen); - // REQUIRE_TRUE(updRank == (indRank + shapeLen - 2), 0, "SCATTER_ND OP: the equality updates_rank = (indices_rank + shape_length - 2) must be true for input arrays, but got instead: updates_rank = %i, indices_rank = %i, shape_length = %i !", updRank, indRank, shapeLen); - REQUIRE_TRUE(updRank == (indRank - 1 + shapeLen - indices->sizeAt(-1)), 0, "SCATTER_ND OP: the equality updates_rank = (indices_rank - 1 + shape_length - last_indices_dimension) must be true for input arrays, but got instead: updates_rank = %i, shape_length = %i, last_indices_dimension = %i !", updRank, shapeLen, indices->sizeAt(-1)); - - std::vector outShape = shape->getBufferAsVector(); - std::vector updShape = updates->getShapeAsVector(); - std::vector indShape = indices->getShapeAsVector(); - std::vector expectedUpdShape(std::begin(indShape), std::end(indShape) - 1); - std::move(std::begin(outShape) + indices->sizeAt(-1), std::end(outShape), std::back_inserter(expectedUpdShape)); - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - // initial zeroing of output - *output = 0; - - helpers::scatterND(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock); - - return Status::OK(); - } - - DECLARE_TYPES(scatter_nd) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(2, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - -//////////////////////////////////////////////////////////////////////// - DECLARE_SHAPE_FN(scatter_nd) { +CUSTOM_OP_IMPL(scatter_nd, 3, 1, false, 0, 0) { + auto indices = INPUT_VARIABLE(0); + auto updates = INPUT_VARIABLE(1); + auto shape = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + const int shapeRank = shape->rankOf(); + const Nd4jLong shapeLen = shape->lengthOf(); + + REQUIRE_TRUE( + shapeRank == 1, 0, + "SCATTER_ND OP: the rank of shape array must be 1, but got %i instead !", + shapeRank); + REQUIRE_TRUE(indices->sizeAt(-1) <= shapeLen, 0, + "SCATTER_ND OP: last dimension of indices array must be <= " + "length of shape array, but got %i and %i correspondingly !", + indices->sizeAt(-1), shapeLen); + // REQUIRE_TRUE(updRank == (indRank + shapeLen - 2), 0, "SCATTER_ND OP: the + // equality updates_rank = (indices_rank + shape_length - 2) must be true for + // input arrays, but got instead: updates_rank = %i, indices_rank = %i, + // shape_length = %i !", updRank, indRank, shapeLen); + REQUIRE_TRUE(updRank == (indRank - 1 + shapeLen - indices->sizeAt(-1)), 0, + "SCATTER_ND OP: the equality updates_rank = (indices_rank - 1 + " + "shape_length - last_indices_dimension) must be true for input " + "arrays, but got instead: updates_rank = %i, shape_length = %i, " + "last_indices_dimension = %i !", + updRank, shapeLen, indices->sizeAt(-1)); + + std::vector outShape = shape->getBufferAsVector(); + std::vector updShape = updates->getShapeAsVector(); + std::vector indShape = indices->getShapeAsVector(); + std::vector expectedUpdShape(std::begin(indShape), + std::end(indShape) - 1); + std::move(std::begin(outShape) + indices->sizeAt(-1), std::end(outShape), + std::back_inserter(expectedUpdShape)); + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_ND OP: wrong shape of updates array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_ND OP: please check elements of indices-array, total " + "number of wrong elements is %lld!", + numOfBadIndx); + } + + // initial zeroing of output + *output = 0; + + helpers::scatterND(block.launchContext(), pairwise::Add, *indices, *updates, + *output, lock); + + return Status::OK(); +} - auto shape = INPUT_VARIABLE(2); - auto updShapeInfo = inputShape->at(1); +DECLARE_TYPES(scatter_nd) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); +} - Nd4jLong *outShapeInfo; - ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(shape->lengthOf()), Nd4jLong); +//////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(scatter_nd) { + auto shape = INPUT_VARIABLE(2); + auto updShapeInfo = inputShape->at(1); - outShapeInfo[0] = shape->lengthOf(); - for (int i = 0; i < outShapeInfo[0]; ++i) - outShapeInfo[i + 1] = shape->e(i); + Nd4jLong *outShapeInfo; + ALLOCATE(outShapeInfo, block.workspace(), + shape::shapeInfoLength(shape->lengthOf()), Nd4jLong); - ShapeUtils::updateStridesAndType(outShapeInfo, updShapeInfo, shape::order(updShapeInfo)); + outShapeInfo[0] = shape->lengthOf(); + for (int i = 0; i < outShapeInfo[0]; ++i) + outShapeInfo[i + 1] = shape->e(i); - return SHAPELIST(CONSTANT(outShapeInfo)); - } + ShapeUtils::updateStridesAndType(outShapeInfo, updShapeInfo, + shape::order(updShapeInfo)); + return SHAPELIST(CONSTANT(outShapeInfo)); } -} + +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_add.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_add.cpp index 691894a9ec72..946494955dbf 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_add.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_add.cpp @@ -25,57 +25,75 @@ #include namespace sd { -namespace ops { +namespace ops { OP_IMPL(scatter_nd_add, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - const bool lock = block.getBArguments().empty() ? false : B_ARG(0); - const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - const Nd4jLong indLastDim = indices->sizeAt(-1); - - REQUIRE_TRUE(indLastDim <= inRank, 0, "SCATTER_ND_ADD OP: the last dimension of indices array must be <= input_array_rank, but got %i instead !", indLastDim); - REQUIRE_TRUE(updRank == (indRank - 1 + inRank - indLastDim), 0, "SCATTER_ND_ADD OP: the equality updates_rank = (indices_rank - 1 + input_rank - last_indices_dimension) must be true for input arrays, but got instead: updates_rank = %i, indices_rank = %i, last_indices_dimension = %i !", updRank, indRank, indLastDim); - - std::vector inShape = input->getShapeAsVector(); - std::vector updShape = updates->getShapeAsVector(); - std::vector indShape = indices->getShapeAsVector(); - std::vector expectedUpdShape(std::begin(indShape), std::end(indShape) - 1); - if(inRank > indLastDim) - std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape)); - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND_ADD OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - if (!block.isInplace()) - output->assign(input); - - helpers::scatterND(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + const Nd4jLong indLastDim = indices->sizeAt(-1); + + REQUIRE_TRUE(indLastDim <= inRank, 0, + "SCATTER_ND_ADD OP: the last dimension of indices array must be " + "<= input_array_rank, but got %i instead !", + indLastDim); + REQUIRE_TRUE(updRank == (indRank - 1 + inRank - indLastDim), 0, + "SCATTER_ND_ADD OP: the equality updates_rank = (indices_rank - " + "1 + input_rank - last_indices_dimension) must be true for " + "input arrays, but got instead: updates_rank = %i, indices_rank " + "= %i, last_indices_dimension = %i !", + updRank, indRank, indLastDim); + + std::vector inShape = input->getShapeAsVector(); + std::vector updShape = updates->getShapeAsVector(); + std::vector indShape = indices->getShapeAsVector(); + std::vector expectedUpdShape(std::begin(indShape), + std::end(indShape) - 1); + if (inRank > indLastDim) + std::move(std::begin(inShape) + indLastDim, std::end(inShape), + std::back_inserter(expectedUpdShape)); + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_ND_ADD OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_ND_ADD OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); + } + + if (!block.isInplace()) output->assign(input); + + helpers::scatterND(block.launchContext(), pairwise::Add, *indices, *updates, + *output, lock); + + return Status::OK(); } - DECLARE_TYPES(scatter_nd_add) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - -} +DECLARE_TYPES(scatter_nd_add) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_sub.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_sub.cpp index d9db29bb132f..1ed2154109e3 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_sub.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_sub.cpp @@ -25,57 +25,75 @@ #include namespace sd { -namespace ops { +namespace ops { OP_IMPL(scatter_nd_sub, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - const bool lock = block.getBArguments().empty() ? false : B_ARG(0); - const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - const Nd4jLong indLastDim = indices->sizeAt(-1); - - REQUIRE_TRUE(indLastDim <= inRank, 0, "SCATTER_ND_SUB OP: the last dimension of indices array must be <= input_array_rank, but got %i instead !", indLastDim); - REQUIRE_TRUE(updRank == (indRank - 1 + inRank - indLastDim), 0, "SCATTER_ND_SUB OP: the equality updates_rank = (indices_rank - 1 + input_rank - last_indices_dimension) must be true for input arrays, but got instead: updates_rank = %i, indices_rank = %i, last_indices_dimension = %i !", updRank, indRank, indLastDim); - - std::vector inShape = input->getShapeAsVector(); - std::vector updShape = updates->getShapeAsVector(); - std::vector indShape = indices->getShapeAsVector(); - std::vector expectedUpdShape(std::begin(indShape), std::end(indShape) - 1); - if(inRank > indLastDim) - std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape)); - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND_SUB OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - if (!block.isInplace()) - output->assign(input); - - helpers::scatterND(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + const Nd4jLong indLastDim = indices->sizeAt(-1); + + REQUIRE_TRUE(indLastDim <= inRank, 0, + "SCATTER_ND_SUB OP: the last dimension of indices array must be " + "<= input_array_rank, but got %i instead !", + indLastDim); + REQUIRE_TRUE(updRank == (indRank - 1 + inRank - indLastDim), 0, + "SCATTER_ND_SUB OP: the equality updates_rank = (indices_rank - " + "1 + input_rank - last_indices_dimension) must be true for " + "input arrays, but got instead: updates_rank = %i, indices_rank " + "= %i, last_indices_dimension = %i !", + updRank, indRank, indLastDim); + + std::vector inShape = input->getShapeAsVector(); + std::vector updShape = updates->getShapeAsVector(); + std::vector indShape = indices->getShapeAsVector(); + std::vector expectedUpdShape(std::begin(indShape), + std::end(indShape) - 1); + if (inRank > indLastDim) + std::move(std::begin(inShape) + indLastDim, std::end(inShape), + std::back_inserter(expectedUpdShape)); + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_ND_SUB OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_ND_SUB OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); + } + + if (!block.isInplace()) output->assign(input); + + helpers::scatterND(block.launchContext(), pairwise::Subtract, *indices, + *updates, *output, lock); + + return Status::OK(); } - DECLARE_TYPES(scatter_nd_sub) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - -} +DECLARE_TYPES(scatter_nd_sub) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_update.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_update.cpp index 9f3f6be0849c..a7d0c8e0a181 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_update.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_nd_update.cpp @@ -25,58 +25,75 @@ #include namespace sd { -namespace ops { +namespace ops { OP_IMPL(scatter_nd_update, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - const bool lock = block.getBArguments().empty() ? true : B_ARG(0); - const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - const Nd4jLong indLastDim = indices->sizeAt(-1); - - REQUIRE_TRUE(indLastDim <= inRank, 0, "SCATTER_ND_UPDATE OP: the last dimension of indices array must be <= input_array_rank, but got %i instead !", indLastDim); - REQUIRE_TRUE(updRank == (indRank - 1 + inRank - indLastDim), 0, "SCATTER_ND_UPDATE OP: the equality updates_rank = (indices_rank - 1 + input_rank - last_indices_dimension) must be true for input arrays, but got instead: updates_rank = %i, indices_rank = %i, last_indices_dimension = %i !", updRank, indRank, indLastDim); - - std::vector inShape = input->getShapeAsVector(); - std::vector updShape = updates->getShapeAsVector(); - std::vector indShape = indices->getShapeAsVector(); - std::vector expectedUpdShape(std::begin(indShape), std::end(indShape) - 1); - if(inRank > indLastDim) - std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape)); - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_UPDATE OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND_UPDATE OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - if (!block.isInplace()) - output->assign(input); - - helpers::scatterND(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + const bool lock = block.getBArguments().empty() ? true : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + const Nd4jLong indLastDim = indices->sizeAt(-1); + + REQUIRE_TRUE(indLastDim <= inRank, 0, + "SCATTER_ND_UPDATE OP: the last dimension of indices array must " + "be <= input_array_rank, but got %i instead !", + indLastDim); + REQUIRE_TRUE(updRank == (indRank - 1 + inRank - indLastDim), 0, + "SCATTER_ND_UPDATE OP: the equality updates_rank = " + "(indices_rank - 1 + input_rank - last_indices_dimension) must " + "be true for input arrays, but got instead: updates_rank = %i, " + "indices_rank = %i, last_indices_dimension = %i !", + updRank, indRank, indLastDim); + + std::vector inShape = input->getShapeAsVector(); + std::vector updShape = updates->getShapeAsVector(); + std::vector indShape = indices->getShapeAsVector(); + std::vector expectedUpdShape(std::begin(indShape), + std::end(indShape) - 1); + if (inRank > indLastDim) + std::move(std::begin(inShape) + indLastDim, std::end(inShape), + std::back_inserter(expectedUpdShape)); + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_ND_UPDATE OP: wrong shape of updates array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_ND_UPDATE OP: please check elements of " + "indices-array, total number of wrong elements is %lld!", + numOfBadIndx); + } + + if (!block.isInplace()) output->assign(input); + + helpers::scatterND(block.launchContext(), pairwise::CopyPws, *indices, + *updates, *output, lock); + + return Status::OK(); } - DECLARE_TYPES(scatter_nd_update) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - - } - -} +DECLARE_TYPES(scatter_nd_update) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd + #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_sub.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_sub.cpp index 8ceb5de5fe88..7cf7e3612dca 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_sub.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_sub.cpp @@ -26,74 +26,93 @@ #include namespace sd { - namespace ops { - OP_IMPL(scatter_sub, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - const bool lock = block.getBArguments().empty() ? false : B_ARG(0); - const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - REQUIRE_TRUE(inRank > 0, 0, "SCATTER_SUB OP: input should not be scalar !"); - - if(inRank == 1) { - REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_SUB OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); - } - else if (inRank == updRank && indices->isVector()) { - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - else { - - REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_SUB OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - if (!indices->isEmpty()) { - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_SUB OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - // ScatterHelper::template scatterApply>(output, indices, updates); - helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock); - } +namespace ops { +OP_IMPL(scatter_sub, 3, 1, true) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + const bool lock = block.getBArguments().empty() ? false : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_SUB OP: input should not be scalar !"); + + if (inRank == 1) { + REQUIRE_TRUE(indices->isSameShape(updates), 0, + "SCATTER_SUB OP: when input array has rank = 1 then indices " + "and updates must have the same shapes, but got %s and %s " + "correspondingly !", + ShapeUtils::shapeAsString(indices).c_str(), + ShapeUtils::shapeAsString(updates).c_str()); + } else if (inRank == updRank && indices->isVector()) { + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = {indices->lengthOf()}; + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_SUB OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + else { + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, + "SCATTER_SUB OP: wrong rank of updates array, expected is %i, " + "but got %i instead !", + indRank + inRank - 1, updRank); + + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_SUB OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + if (!indices->isEmpty()) { + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_SUB OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); + } - return Status::OK(); - } - DECLARE_SYN(ScatterSub, scatter_sub); + // ScatterHelper::template scatterApply>(output, + // indices, updates); + helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, + *updates, *output, lock); + } - DECLARE_TYPES(scatter_sub) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - } + return Status::OK(); +} +DECLARE_SYN(ScatterSub, scatter_sub); + +DECLARE_TYPES(scatter_sub) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_upd.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_upd.cpp index d48a86fdc142..b4e318cdc043 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_upd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_upd.cpp @@ -25,74 +25,91 @@ #include namespace sd { - namespace ops { - OP_IMPL(scatter_upd, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto indices = INPUT_VARIABLE(1); - auto updates = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - if (!block.isInplace()) - output->assign(input); - - const bool lock = block.getBArguments().empty() ? true : B_ARG(0); - const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); - - const int inRank = input->rankOf(); - const int indRank = indices->rankOf(); - const int updRank = updates->rankOf(); - - REQUIRE_TRUE(inRank > 0, 0, "SCATTER_UPD OP: input should not be scalar !"); - - if(inRank == 1) { - REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_UPD OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); - } - else if (inRank == updRank && indices->isVector()) { - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - else { - - REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_UPD OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - - std::vector updShape = updates->getShapeAsVector(); - std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); - expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - - REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); - } - - if (!indices->isEmpty()) { - - if(checkIndices) { - const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); - REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_UPD OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); - } - - // ScatterHelper::template scatterApply>(output, indices, updates); - helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock); - } - - return Status::OK(); - } - DECLARE_SYN(ScatterUpdate, scatter_upd); +namespace ops { +OP_IMPL(scatter_upd, 3, 1, true) { + auto input = INPUT_VARIABLE(0); + auto indices = INPUT_VARIABLE(1); + auto updates = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + if (!block.isInplace()) output->assign(input); + + const bool lock = block.getBArguments().empty() ? true : B_ARG(0); + const bool checkIndices = block.numB() <= 1 ? false : B_ARG(1); + + const int inRank = input->rankOf(); + const int indRank = indices->rankOf(); + const int updRank = updates->rankOf(); + + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_UPD OP: input should not be scalar !"); + + if (inRank == 1) { + REQUIRE_TRUE(indices->isSameShape(updates), 0, + "SCATTER_UPD OP: when input array has rank = 1 then indices " + "and updates must have the same shapes, but got %s and %s " + "correspondingly !", + ShapeUtils::shapeAsString(indices).c_str(), + ShapeUtils::shapeAsString(updates).c_str()); + } else if (inRank == updRank && indices->isVector()) { + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = {indices->lengthOf()}; + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_UPD OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } else { + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, + "SCATTER_UPD OP: wrong rank of updates array, expected is %i, " + "but got %i instead !", + indRank + inRank - 1, updRank); + + std::vector updShape = updates->getShapeAsVector(); + std::vector inShape = input->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); + expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, + inShape.end()); + + REQUIRE_TRUE(expectedUpdShape == updShape, 0, + "SCATTER_UPD OP: wrong shape of updates array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedUpdShape).c_str(), + ShapeUtils::shapeAsString(updShape).c_str()); + } + + if (!indices->isEmpty()) { + if (checkIndices) { + const Nd4jLong numOfBadIndx = + helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, + "SCATTER_UPD OP: please check elements of indices-array, " + "total number of wrong elements is %lld!", + numOfBadIndx); + } + // ScatterHelper::template scatterApply>(output, + // indices, updates); + helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, + *updates, *output, lock); + } - DECLARE_TYPES(scatter_upd) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } - } + return Status::OK(); +} +DECLARE_SYN(ScatterUpdate, scatter_upd); + +DECLARE_TYPES(scatter_upd) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/scatter_update.cpp b/libnd4j/include/ops/declarable/generic/transforms/scatter_update.cpp index 72b087e7baa9..1d7b9ec3a110 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/scatter_update.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/scatter_update.cpp @@ -22,39 +22,36 @@ #if NOT_EXCLUDED(OP_scatter_update) #include -#include +#include namespace sd { - namespace ops { - /** - * scatter update operation - * - * IArgs map: - * IArgs[0] - update operation: 0 - add; 1 - sub; 2 - mul; 3 - div; 4 - rsub; 5 - rdiv; 6 - assign - * IArgs[1] - number of dimensions - * IArgs[...] - dimensions - * IArgs[...] - number of indices - * IArgs[...] - indices - * - * @tparam T - */ - CONFIGURABLE_OP_IMPL(scatter_update, 2, 1, true, 0, -1) { - - auto operand = INPUT_VARIABLE(0); - auto updates = INPUT_VARIABLE(1); - - helpers::scatterUpdate(block.launchContext(), *operand, *updates, &block.getIArguments()); - - return Status::OK(); - } - DECLARE_SYN(scatterupdate, scatter_update); - - DECLARE_TYPES(scatter_update) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +namespace ops { +/** + * scatter update operation + * + * IArgs map: + * IArgs[0] - update operation: 0 - add; 1 - sub; 2 - mul; 3 - div; 4 - rsub; 5 + * - rdiv; 6 - assign IArgs[1] - number of dimensions IArgs[...] - dimensions + * IArgs[...] - number of indices + * IArgs[...] - indices + * + * @tparam T + */ +CONFIGURABLE_OP_IMPL(scatter_update, 2, 1, true, 0, -1) { + auto operand = INPUT_VARIABLE(0); + auto updates = INPUT_VARIABLE(1); + + helpers::scatterUpdate(block.launchContext(), *operand, *updates, + &block.getIArguments()); + + return Status::OK(); +} +DECLARE_SYN(scatterupdate, scatter_update); + +DECLARE_TYPES(scatter_update) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/slice.cpp b/libnd4j/include/ops/declarable/generic/transforms/slice.cpp index 23db77b0493e..1c25a5c08d2a 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/slice.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/slice.cpp @@ -21,215 +21,258 @@ #include //#if NOT_EXCLUDED(OP_slice) -#include #include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(slice, 1, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - int x_rank = input->rankOf(); - - std::vector begin; - std::vector sz; - - if (block.width() == 3) { - auto b = INPUT_VARIABLE(1); - auto e = INPUT_VARIABLE(2); - - begin = b->template asVectorT(); - sz = e->template asVectorT(); - } else { - REQUIRE_TRUE(block.numI() >= x_rank * 2, 0, "Number of IArgs should be equal to [%i] but got [%i] instead", x_rank * 2, block.numI()); - - auto vec = block.getIArguments(); - - ShapeUtils::copyVectorPart(begin, vec, x_rank, 0); - ShapeUtils::copyVectorPart(sz, vec, x_rank, x_rank); - } - - REQUIRE_TRUE(begin.size() == x_rank, 0, "begin array should have length of [%i] but got [%i] instead", x_rank, begin.size()); - REQUIRE_TRUE(sz.size() == x_rank, 0, "size array should have length of [%i] but got [%i] instead", x_rank, sz.size()); - - std::vector indices(2 * x_rank); - auto empty = false; - for (int e = 0; e < x_rank; e++) { - int size = sz[e]; - int start = begin[e]; - - REQUIRE_TRUE(start >= 0, 0, "Slice: start index should not be negative"); - - REQUIRE_TRUE(start <= input->sizeAt(e), 0, "Index %i is invalid for dimension %i with size %i.", start, e, input->shapeInfo()[e + 1]); - if (size == -1){ - size = input->sizeAt(e) - start; - } - REQUIRE_TRUE(size >= 0, 0, "Slice: interval for dimension %i is less then 1"); - REQUIRE_TRUE(start + size <= input->sizeAt(e), 0, "Slice: interval [%i, %i] is out of bounds for dimension %i with size %i", start, start + size, e, input->sizeAt(e)); - - if(start == input->sizeAt(e) || size == 0 ){ - empty = true; - //Don't break to perform input validation on other dims - } - - indices[2*e] = start; - indices[2*e+1] = start + size; - } - - if(empty){ - REQUIRE_TRUE(output->isEmpty(), 0, "Slice: empty array indices requested, but output array is not empty"); - return Status::OK(); - } +namespace ops { +CUSTOM_OP_IMPL(slice, 1, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + int x_rank = input->rankOf(); + + std::vector begin; + std::vector sz; + + if (block.width() == 3) { + auto b = INPUT_VARIABLE(1); + auto e = INPUT_VARIABLE(2); + + begin = b->template asVectorT(); + sz = e->template asVectorT(); + } else { + REQUIRE_TRUE(block.numI() >= x_rank * 2, 0, + "Number of IArgs should be equal to [%i] but got [%i] instead", + x_rank * 2, block.numI()); + + auto vec = block.getIArguments(); + + ShapeUtils::copyVectorPart(begin, vec, x_rank, 0); + ShapeUtils::copyVectorPart(sz, vec, x_rank, x_rank); + } + + REQUIRE_TRUE(begin.size() == x_rank, 0, + "begin array should have length of [%i] but got [%i] instead", + x_rank, begin.size()); + REQUIRE_TRUE(sz.size() == x_rank, 0, + "size array should have length of [%i] but got [%i] instead", + x_rank, sz.size()); + + std::vector indices(2 * x_rank); + auto empty = false; + for (int e = 0; e < x_rank; e++) { + int size = sz[e]; + int start = begin[e]; + + REQUIRE_TRUE(start >= 0, 0, "Slice: start index should not be negative"); + + REQUIRE_TRUE(start <= input->sizeAt(e), 0, + "Index %i is invalid for dimension %i with size %i.", start, e, + input->shapeInfo()[e + 1]); + if (size == -1) { + size = input->sizeAt(e) - start; + } + REQUIRE_TRUE(size >= 0, 0, + "Slice: interval for dimension %i is less then 1"); + REQUIRE_TRUE(start + size <= input->sizeAt(e), 0, + "Slice: interval [%i, %i] is out of bounds for dimension %i " + "with size %i", + start, start + size, e, input->sizeAt(e)); + + if (start == input->sizeAt(e) || size == 0) { + empty = true; + // Don't break to perform input validation on other dims + } - Nd4jLong* subArrShapeInfo = nullptr; - ALLOCATE(subArrShapeInfo, block.workspace(), shape::shapeInfoLength(input->rankOf()), Nd4jLong); + indices[2 * e] = start; + indices[2 * e + 1] = start + size; + } - Nd4jLong offset; + if (empty) { + REQUIRE_TRUE( + output->isEmpty(), 0, + "Slice: empty array indices requested, but output array is not empty"); + return Status::OK(); + } - shape::calcSubArrShapeInfoAndOffset(indices.data(), input->shapeInfo(), subArrShapeInfo, offset, true); + Nd4jLong *subArrShapeInfo = nullptr; + ALLOCATE(subArrShapeInfo, block.workspace(), + shape::shapeInfoLength(input->rankOf()), Nd4jLong); - auto subArrShapeInfoPack = ConstantShapeHelper::getInstance()->bufferForShapeInfo(subArrShapeInfo); + Nd4jLong offset; - NDArray::prepareSpecialUse({output}, {input}); + shape::calcSubArrShapeInfoAndOffset(indices.data(), input->shapeInfo(), + subArrShapeInfo, offset, true); - NativeOpExecutioner::execTransformAny(block.launchContext(), sd::transform::Assign, - input->bufferWithOffset(offset), reinterpret_cast(subArrShapeInfoPack.primary()), - input->specialBufferWithOffset(offset), reinterpret_cast(subArrShapeInfoPack.special()), - output->buffer(), output->shapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - nullptr, nullptr, nullptr, true); + auto subArrShapeInfoPack = + ConstantShapeHelper::getInstance()->bufferForShapeInfo(subArrShapeInfo); - NDArray::registerSpecialUse({output}, {input}); + NDArray::prepareSpecialUse({output}, {input}); - RELEASE(subArrShapeInfo, block.workspace()); + NativeOpExecutioner::execTransformAny( + block.launchContext(), sd::transform::Assign, + input->bufferWithOffset(offset), + reinterpret_cast(subArrShapeInfoPack.primary()), + input->specialBufferWithOffset(offset), + reinterpret_cast(subArrShapeInfoPack.special()), + output->buffer(), output->shapeInfo(), output->specialBuffer(), + output->specialShapeInfo(), nullptr, nullptr, nullptr, true); - // auto sub = (*input)(indices, true); - // output->assign(sub); + NDArray::registerSpecialUse({output}, {input}); - STORE_RESULT(output); + RELEASE(subArrShapeInfo, block.workspace()); - return Status::OK(); - } + // auto sub = (*input)(indices, true); + // output->assign(sub); - DECLARE_TYPES(slice) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } + STORE_RESULT(output); - DECLARE_SHAPE_FN(slice) { - auto inShape = inputShape->at(0); - auto x_rank = shape::rank(inShape); + return Status::OK(); +} - std::vector begin; - std::vector sz; +DECLARE_TYPES(slice) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); +} - if (block.width() == 3) { - auto b = INPUT_VARIABLE(1); - auto e = INPUT_VARIABLE(2); +DECLARE_SHAPE_FN(slice) { + auto inShape = inputShape->at(0); + auto x_rank = shape::rank(inShape); + + std::vector begin; + std::vector sz; + + if (block.width() == 3) { + auto b = INPUT_VARIABLE(1); + auto e = INPUT_VARIABLE(2); + + begin = b->template asVectorT(); + sz = e->template asVectorT(); + } else { + REQUIRE_TRUE(block.numI() >= x_rank * 2, 0, + "Number of IArgs should be equal to [%i] but got [%i] instead", + x_rank * 2, block.numI()); + + auto vec = block.getIArguments(); + ShapeUtils::copyVectorPart(begin, vec, x_rank, 0); + ShapeUtils::copyVectorPart(sz, vec, x_rank, x_rank); + } + + REQUIRE_TRUE(begin.size() == x_rank, 0, + "Begin array should have length of [%i] but got [%i] instead", + x_rank, begin.size()); + REQUIRE_TRUE(sz.size() == x_rank, 0, + "Size array should have length of [%i] but got [%i] instead", + x_rank, sz.size()); + + std::vector shape; + auto empty = false; + for (int e = 0; e < x_rank; e++) { + auto size = sz[e]; + auto start = begin[e]; + + if (size == -1) { + size = inShape[e + 1] - start; + } - begin = b->template asVectorT(); - sz = e->template asVectorT(); - } else { - REQUIRE_TRUE(block.numI() >= x_rank * 2, 0, "Number of IArgs should be equal to [%i] but got [%i] instead", x_rank * 2, block.numI()); + // Bounds checking. Note that begin[i] == size[i] means empty array + REQUIRE_TRUE(start >= 0 && start <= inShape[e + 1], 0, + "Invalid begin[%i] value: Begin must satisfy 0 <= begin <= " + "size[i], got begin=%i for dimension size %i", + e, start, inShape[e + 1]); + REQUIRE_TRUE(size == -1 || size >= 0, 0, + "Invalid size[%i] value: must be positive (or -1 for 'all " + "remaining'), got %i", + e, size, inShape[e + 1]); + REQUIRE_TRUE(start >= 0 && start <= inShape[e + 1], 0, + "Invalid begin[%i] value: Begin must satisfy 0 <= begin <= " + "size[i], got begin=%i for dimension size %i", + e, start, inShape[e + 1]); + REQUIRE_TRUE(start + size <= inShape[e + 1], 0, + "Slice: interval [%i, %i] is out of bounds for dimension %i " + "with size %i", + start, start + size, e, inShape[e + 1]); + if (start == inShape[e + 1]) { + size = 0; + } - auto vec = block.getIArguments(); - ShapeUtils::copyVectorPart(begin, vec, x_rank, 0); - ShapeUtils::copyVectorPart(sz, vec, x_rank, x_rank); - } + shape.emplace_back(size); + } - REQUIRE_TRUE(begin.size() == x_rank, 0, "Begin array should have length of [%i] but got [%i] instead", x_rank, begin.size()); - REQUIRE_TRUE(sz.size() == x_rank, 0, "Size array should have length of [%i] but got [%i] instead", x_rank, sz.size()); + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShape), 'c', shape); + return SHAPELIST(newShape); +} - std::vector shape; - auto empty = false; - for (int e = 0; e < x_rank; e++) { - auto size = sz[e]; - auto start = begin[e]; +DECLARE_TYPES(slice_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - if(size == -1){ - size = inShape[e+1] - start; - } +CUSTOM_OP_IMPL(slice_bp, 2, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto epsNext = block.width() == 4 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(1); - //Bounds checking. Note that begin[i] == size[i] means empty array - REQUIRE_TRUE(start >= 0 && start <= inShape[e+1], 0, "Invalid begin[%i] value: Begin must satisfy 0 <= begin <= size[i], got begin=%i for dimension size %i", e, start, inShape[e+1]); - REQUIRE_TRUE(size == -1 || size >= 0, 0, "Invalid size[%i] value: must be positive (or -1 for 'all remaining'), got %i", e, size, inShape[e+1]); - REQUIRE_TRUE(start >= 0 && start <= inShape[e+1], 0, "Invalid begin[%i] value: Begin must satisfy 0 <= begin <= size[i], got begin=%i for dimension size %i", e, start, inShape[e+1]); - REQUIRE_TRUE(start + size <= inShape[e+1], 0, "Slice: interval [%i, %i] is out of bounds for dimension %i with size %i", start, start + size, e, inShape[e+1]); - if(start == inShape[e+1] ){ - size = 0; - } + auto output = OUTPUT_VARIABLE(0); + output->assign(0.); + int x_rank = input->rankOf(); - shape.emplace_back(size); - } + std::vector begin; + std::vector end; - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', shape); - return SHAPELIST(newShape); - } + if (block.width() == 4) { + auto b = INPUT_VARIABLE(1); + auto e = INPUT_VARIABLE(2); - DECLARE_TYPES(slice_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } + begin = b->template asVectorT(); + end = e->template asVectorT(); + } else { + REQUIRE_TRUE(block.numI() >= x_rank * 2, 0, + "Number of IArgs should be equal to [%i] but got [%i] instead", + x_rank * 2, block.numI()); + auto vec = block.getIArguments(); - CUSTOM_OP_IMPL(slice_bp, 2, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto epsNext = block.width() == 4 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(1); + ShapeUtils::copyVectorPart(begin, vec, x_rank, 0); + ShapeUtils::copyVectorPart(end, vec, x_rank, x_rank); + } - auto output = OUTPUT_VARIABLE(0); - output->assign(0.); - int x_rank = input->rankOf(); + REQUIRE_TRUE(begin.size() == x_rank, 0, + "begin array should have length of [%i] but got [%i] instead", + x_rank, begin.size()); + REQUIRE_TRUE(end.size() == x_rank, 0, + "end array should have length of [%i] but got [%i] instead", + x_rank, end.size()); - std::vector begin; - std::vector end; + std::vector indices(2 * x_rank); + for (int e = 0; e < x_rank; e++) { + int size = end[e]; + int start = begin[e]; - if (block.width() == 4) { - auto b = INPUT_VARIABLE(1); - auto e = INPUT_VARIABLE(2); + if (size == -1) { //-1 means all remaining values + size = input->sizeAt(e) - start; + } + REQUIRE_TRUE(size > 0, 0, "Slice: interval for dimension %i is less then 1", + e); - begin = b->template asVectorT(); - end = e->template asVectorT(); - } else { - REQUIRE_TRUE(block.numI() >= x_rank * 2, 0, "Number of IArgs should be equal to [%i] but got [%i] instead", x_rank * 2, block.numI()); - - auto vec = block.getIArguments(); + indices[2 * e] = start; + indices[2 * e + 1] = start + size; + } + auto sub = (*output)(indices, true); + sub.assign(epsNext); - ShapeUtils::copyVectorPart(begin, vec, x_rank, 0); - ShapeUtils::copyVectorPart(end, vec, x_rank, x_rank); - } + return Status::OK(); +} - REQUIRE_TRUE(begin.size() == x_rank, 0, "begin array should have length of [%i] but got [%i] instead", x_rank, begin.size()); - REQUIRE_TRUE(end.size() == x_rank, 0, "end array should have length of [%i] but got [%i] instead", x_rank, end.size()); +DECLARE_SHAPE_FN(slice_bp) { + auto inShape = inputShape->at(0); + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); - std::vector indices(2 * x_rank); - for (int e = 0; e < x_rank; e++) { - int size = end[e]; - int start = begin[e]; - - if (size == -1){ //-1 means all remaining values - size = input->sizeAt(e) - start; - } - REQUIRE_TRUE(size > 0, 0, "Slice: interval for dimension %i is less then 1", e); - - indices[2*e] = start; - indices[2*e + 1] = start + size; - } - auto sub = (*output)(indices, true); - sub.assign(epsNext); - - return Status::OK(); - } - - DECLARE_SHAPE_FN(slice_bp) { - auto inShape = inputShape->at(0); - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - } + return SHAPELIST(CONSTANT(newShape)); } +} // namespace ops +} // namespace sd //#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp index 9a16838185ef..64e92bfef321 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp @@ -26,75 +26,108 @@ limitations under the License. namespace sd { namespace ops { - CUSTOM_OP_IMPL(space_to_batch, 2, 1, false, 0, 1) { - - // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC] - - auto input = INPUT_VARIABLE(0); - auto padding = INPUT_VARIABLE(1); - - auto output = OUTPUT_VARIABLE(0); - - - const uint blockSize = INT_ARG(0); - REQUIRE_TRUE(blockSize >= 2, 0, "SpaceToBatch: integer parameter block_size must be >= 2, but got %i instead", blockSize); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "SpaceToBatch: rank of input array must be equal 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(output->rankOf() == 4, 0, "SpaceToBatch: rank of output array must be equal 4, but got %i instead", output->rankOf()); - - if(padding->sizeAt(0) != 2 || padding->sizeAt(1) != 2) - REQUIRE_TRUE(false, 0, "SpaceToBatch: operation expects padding shape to be {2, 2}, but got %s instead", ShapeUtils::shapeAsString(padding).c_str()); - - const uint padBottom = padding->e(0,0); - const uint padTop = padding->e(0,1); - const uint padLeft = padding->e(1,0); - const uint padRight = padding->e(1,1); - - REQUIRE_TRUE((input->sizeAt(1) + padBottom + padTop) % blockSize == 0 && (input->sizeAt(2) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatch: after padding, second and third dimensions of input array must be divisible by blockSize !"); - - if (shape::strideDescendingCAscendingF(input->shapeInfo())) - helpers::spaceToBatch(block.launchContext(), *input, *output, padBottom, padTop, padLeft, padRight, blockSize); - else - helpers::spaceToBatch(block.launchContext(), input->dup(), *output, padBottom, padTop, padLeft, padRight, blockSize); - - return Status::OK(); + // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + + // padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC] + + auto input = INPUT_VARIABLE(0); + auto padding = INPUT_VARIABLE(1); + + auto output = OUTPUT_VARIABLE(0); + + const uint blockSize = INT_ARG(0); + REQUIRE_TRUE(blockSize >= 2, 0, + "SpaceToBatch: integer parameter block_size must be >= 2, but " + "got %i instead", + blockSize); + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "SpaceToBatch: rank of input array must be equal 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE( + output->rankOf() == 4, 0, + "SpaceToBatch: rank of output array must be equal 4, but got %i instead", + output->rankOf()); + + if (padding->sizeAt(0) != 2 || padding->sizeAt(1) != 2) + REQUIRE_TRUE(false, 0, + "SpaceToBatch: operation expects padding shape to be {2, 2}, " + "but got %s instead", + ShapeUtils::shapeAsString(padding).c_str()); + + const uint padBottom = padding->e(0, 0); + const uint padTop = padding->e(0, 1); + const uint padLeft = padding->e(1, 0); + const uint padRight = padding->e(1, 1); + + REQUIRE_TRUE((input->sizeAt(1) + padBottom + padTop) % blockSize == 0 && + (input->sizeAt(2) + padLeft + padRight) % blockSize == 0, + 0, + "SpaceToBatch: after padding, second and third dimensions of " + "input array must be divisible by blockSize !"); + + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::spaceToBatch(block.launchContext(), *input, *output, padBottom, + padTop, padLeft, padRight, blockSize); + else + helpers::spaceToBatch(block.launchContext(), input->dup(), *output, + padBottom, padTop, padLeft, padRight, blockSize); + + return Status::OK(); } //////////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(space_to_batch) { - - getOpDescriptor()->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(true); } //////////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(space_to_batch) { - - auto inputShapeInfo = inputShape->at(0); - auto paddingShapeInfo = inputShape->at(1); - - const uint blockSize = INT_ARG(0); - REQUIRE_TRUE(blockSize >= 2, 0, "SpaceToBatch: integer parameter block_size must be >= 2, but got %i instead", blockSize); - - const int rank = inputShapeInfo[0]; - REQUIRE_TRUE(rank == 4, 0, "SpaceToBatch: rank of input array must be equal 4, but got %i instead", rank); - - if(paddingShapeInfo[1] != 2 || paddingShapeInfo[1] != 2) - REQUIRE_TRUE(false, 0, "SpaceToBatch: operation expects padding shape to be {2, 2}, but got %s instead", ShapeUtils::shapeAsString(paddingShapeInfo).c_str()); - - const uint padBottom = INPUT_VARIABLE(1)->e(0,0); - const uint padTop = INPUT_VARIABLE(1)->e(0,1); - const uint padLeft = INPUT_VARIABLE(1)->e(1,0); - const uint padRight = INPUT_VARIABLE(1)->e(1,1); - - REQUIRE_TRUE((inputShapeInfo[2] + padBottom + padTop) % blockSize == 0 && (inputShapeInfo[3] + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatch: after padding, second and third dimensions of input array must be divisible by blockSize !"); - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShapeInfo), 'c', {inputShapeInfo[1] * blockSize * blockSize, (inputShapeInfo[2] + padBottom + padTop) / blockSize, (inputShapeInfo[3] + padLeft + padRight) / blockSize, inputShapeInfo[4]})); + auto inputShapeInfo = inputShape->at(0); + auto paddingShapeInfo = inputShape->at(1); + + const uint blockSize = INT_ARG(0); + REQUIRE_TRUE(blockSize >= 2, 0, + "SpaceToBatch: integer parameter block_size must be >= 2, but " + "got %i instead", + blockSize); + + const int rank = inputShapeInfo[0]; + REQUIRE_TRUE( + rank == 4, 0, + "SpaceToBatch: rank of input array must be equal 4, but got %i instead", + rank); + + if (paddingShapeInfo[1] != 2 || paddingShapeInfo[1] != 2) + REQUIRE_TRUE(false, 0, + "SpaceToBatch: operation expects padding shape to be {2, 2}, " + "but got %s instead", + ShapeUtils::shapeAsString(paddingShapeInfo).c_str()); + + const uint padBottom = INPUT_VARIABLE(1)->e(0, 0); + const uint padTop = INPUT_VARIABLE(1)->e(0, 1); + const uint padLeft = INPUT_VARIABLE(1)->e(1, 0); + const uint padRight = INPUT_VARIABLE(1)->e(1, 1); + + REQUIRE_TRUE((inputShapeInfo[2] + padBottom + padTop) % blockSize == 0 && + (inputShapeInfo[3] + padLeft + padRight) % blockSize == 0, + 0, + "SpaceToBatch: after padding, second and third dimensions of " + "input array must be divisible by blockSize !"); + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inputShapeInfo), 'c', + {inputShapeInfo[1] * blockSize * blockSize, + (inputShapeInfo[2] + padBottom + padTop) / blockSize, + (inputShapeInfo[3] + padLeft + padRight) / blockSize, + inputShapeInfo[4]})); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp index 0b8c4152d32f..b3bf700ebc51 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp @@ -23,83 +23,113 @@ limitations under the License. #include namespace sd { -namespace ops { - +namespace ops { CUSTOM_OP_IMPL(space_to_batch_nd, 3, 1, false, 0, 0) { - - // 4D example, numOfSpatialDims = 2 - two spatial dimensions - // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockShape[0]*blockShape[1], (iH + padBottom + padTop)/blockSize[0], (iW + padLeft + padRight)/blockSize[1], iC] - - auto input = INPUT_VARIABLE(0); - auto blockShape = INPUT_VARIABLE(1); - auto padding = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(blockShape->rankOf() == 1, 0, "SpaceToBatchND: rank of blockShape array must be equal to one, but got %i instead !", blockShape->rankOf()); - - const uint numOfSpatialDims = blockShape->sizeAt(0); - - REQUIRE_TRUE(input->rankOf() == output->rankOf(), 0, "SpaceToBatchND: rank of input and output array must be the same, but got %i and %i correspondingly !", input->rankOf(), output->rankOf()); - - if(padding->sizeAt(0) != numOfSpatialDims || padding->sizeAt(1) != 2) { - const std::string expectedpaddingShape = "[" + std::to_string(numOfSpatialDims) + ", 2]"; // [numOfSpatialDims, 2] - REQUIRE_TRUE(false, 0, "SpaceToBatchND: operation expects padding shape to be %s, but got %s instead", expectedpaddingShape.c_str(), ShapeUtils::shapeAsString(padding).c_str()); - } - - // FIXME - should we use this time-consuming validation ? - for (uint i = 0; i < numOfSpatialDims; ++i) { - const uint padLeft = padding->e(i,0); - const uint padRight = padding->e(i,1); - const Nd4jLong blockSize = blockShape->e(i); - REQUIRE_TRUE((input->sizeAt(i + 1) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatchND: after padding, spatial dimensions of input array must be divisible by blockSize !"); - } - - if (shape::strideDescendingCAscendingF(input->shapeInfo())) - helpers::spaceToBatchND(block.launchContext(), *input, *blockShape, *padding, *output); - else - helpers::spaceToBatchND(block.launchContext(), input->dup(), *blockShape, *padding, *output); - - return Status::OK(); + // 4D example, numOfSpatialDims = 2 - two spatial dimensions + // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockShape[0]*blockShape[1], + // (iH + padBottom + padTop)/blockSize[0], (iW + padLeft + + // padRight)/blockSize[1], iC] + + auto input = INPUT_VARIABLE(0); + auto blockShape = INPUT_VARIABLE(1); + auto padding = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(blockShape->rankOf() == 1, 0, + "SpaceToBatchND: rank of blockShape array must be equal to one, " + "but got %i instead !", + blockShape->rankOf()); + + const uint numOfSpatialDims = blockShape->sizeAt(0); + + REQUIRE_TRUE(input->rankOf() == output->rankOf(), 0, + "SpaceToBatchND: rank of input and output array must be the " + "same, but got %i and %i correspondingly !", + input->rankOf(), output->rankOf()); + + if (padding->sizeAt(0) != numOfSpatialDims || padding->sizeAt(1) != 2) { + const std::string expectedpaddingShape = "[" + + std::to_string(numOfSpatialDims) + + ", 2]"; // [numOfSpatialDims, 2] + REQUIRE_TRUE(false, 0, + "SpaceToBatchND: operation expects padding shape to be %s, " + "but got %s instead", + expectedpaddingShape.c_str(), + ShapeUtils::shapeAsString(padding).c_str()); + } + + // FIXME - should we use this time-consuming validation ? + for (uint i = 0; i < numOfSpatialDims; ++i) { + const uint padLeft = padding->e(i, 0); + const uint padRight = padding->e(i, 1); + const Nd4jLong blockSize = blockShape->e(i); + REQUIRE_TRUE((input->sizeAt(i + 1) + padLeft + padRight) % blockSize == 0, + 0, + "SpaceToBatchND: after padding, spatial dimensions of input " + "array must be divisible by blockSize !"); + } + + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::spaceToBatchND(block.launchContext(), *input, *blockShape, + *padding, *output); + else + helpers::spaceToBatchND(block.launchContext(), input->dup(), *blockShape, + *padding, *output); + + return Status::OK(); } //////////////////////////////////////////////////////////////////////////////// DECLARE_TYPES(space_to_batch_nd) { - - getOpDescriptor()->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setSameMode(true); } //////////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(space_to_batch_nd) { - - auto inputShapeInfo = inputShape->at(0); - auto blockShapeInfo = inputShape->at(1); - auto paddingShapeInfo = inputShape->at(2); - - REQUIRE_TRUE(blockShapeInfo[0] == 1, 0, "SpaceToBatchND: rank of blockShape array must be equal to one, but got %i instead !", blockShapeInfo[0]); - - const uint numOfSpatialDims = blockShapeInfo[1]; - - if(paddingShapeInfo[1] != numOfSpatialDims || paddingShapeInfo[2] != 2) { - const std::string expectedpaddingShape = "[" + std::to_string(numOfSpatialDims) + ", 2]"; // [numOfSpatialDims, 2] - REQUIRE_TRUE(false, 0, "SpaceToBatchND: operation expects padding shape to be %s, but got %s instead", expectedpaddingShape.c_str(), ShapeUtils::shapeAsString(paddingShapeInfo).c_str()); - } - - std::vector outShape(inputShapeInfo + 1, inputShapeInfo + 1 + inputShapeInfo[0]); - - outShape[0] *= INPUT_VARIABLE(1)->reduceNumber(sd::reduce::Prod).e(0); - - for (uint i = 0; i < numOfSpatialDims; ++i) - outShape[i + 1] = (outShape[i + 1] + INPUT_VARIABLE(2)->e(i,0) + INPUT_VARIABLE(2)->e(i,1)) / INPUT_VARIABLE(1)->e(i); - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShapeInfo), 'c', outShape)); + auto inputShapeInfo = inputShape->at(0); + auto blockShapeInfo = inputShape->at(1); + auto paddingShapeInfo = inputShape->at(2); + + REQUIRE_TRUE(blockShapeInfo[0] == 1, 0, + "SpaceToBatchND: rank of blockShape array must be equal to one, " + "but got %i instead !", + blockShapeInfo[0]); + + const uint numOfSpatialDims = blockShapeInfo[1]; + + if (paddingShapeInfo[1] != numOfSpatialDims || paddingShapeInfo[2] != 2) { + const std::string expectedpaddingShape = "[" + + std::to_string(numOfSpatialDims) + + ", 2]"; // [numOfSpatialDims, 2] + REQUIRE_TRUE(false, 0, + "SpaceToBatchND: operation expects padding shape to be %s, " + "but got %s instead", + expectedpaddingShape.c_str(), + ShapeUtils::shapeAsString(paddingShapeInfo).c_str()); + } + + std::vector outShape(inputShapeInfo + 1, + inputShapeInfo + 1 + inputShapeInfo[0]); + + outShape[0] *= + INPUT_VARIABLE(1)->reduceNumber(sd::reduce::Prod).e(0); + + for (uint i = 0; i < numOfSpatialDims; ++i) + outShape[i + 1] = (outShape[i + 1] + INPUT_VARIABLE(2)->e(i, 0) + + INPUT_VARIABLE(2)->e(i, 1)) / + INPUT_VARIABLE(1)->e(i); + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inputShapeInfo), 'c', outShape)); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp b/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp index b831dce2fde9..3e83ce6e5dd1 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp @@ -22,68 +22,73 @@ #if NOT_EXCLUDED(OP_space_to_depth) #include -#include #include +#include + namespace sd { namespace ops { - DECLARE_TYPES(space_to_depth) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - - CUSTOM_OP_IMPL(space_to_depth, 1, 1, false, 0, 2) { - int block_size = INT_ARG(0); - bool isNHWC = INT_ARG(1) == 1; - - auto input = INPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "SpaceToDepth: input should be 4D array, but got %f instead", input->rankOf()); - - int bS = input->sizeAt(0); - int iD = isNHWC ? input->sizeAt(3) : input->sizeAt(1); - int iH = isNHWC ? input->sizeAt(1) : input->sizeAt(2); - int iW = isNHWC ? input->sizeAt(2) : input->sizeAt(3); - - REQUIRE_TRUE(iH % block_size == 0 && iW % block_size == 0, 0, "SpaceToDepth: input Height & Width should be divisible by block_size"); - - auto output = OUTPUT_VARIABLE(0); - - if (shape::strideDescendingCAscendingF(input->shapeInfo())) - helpers::_spaceTodepth(block.launchContext(), *input, output, block_size, isNHWC); - else - helpers::_spaceTodepth(block.launchContext(), input->dup(), output, block_size, isNHWC); - - return Status::OK(); - } - - - DECLARE_SHAPE_FN(space_to_depth) { - auto in = inputShape->at(0); - int block_size = INT_ARG(0); - bool isNHWC = INT_ARG(1) == 1; - - int bS = shape::sizeAt(in, 0); - int iD = isNHWC ? shape::sizeAt(in, 3) : shape::sizeAt(in, 1); - int iH = isNHWC ? shape::sizeAt(in, 1) : shape::sizeAt(in, 2); - int iW = isNHWC ? shape::sizeAt(in, 2) : shape::sizeAt(in, 3); - - int oD = iD * block_size * block_size; - int oH = iH / block_size; - int oW = iW / block_size; - - std::array shape; - if (isNHWC) - shape = {{bS, oH, oW, oD }}; - else - shape = {{bS, oD, oH, oW }}; - - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in), 'c', 4, shape.data()); - return SHAPELIST(newShape); - } +DECLARE_TYPES(space_to_depth) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } + +CUSTOM_OP_IMPL(space_to_depth, 1, 1, false, 0, 2) { + int block_size = INT_ARG(0); + bool isNHWC = INT_ARG(1) == 1; + + auto input = INPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "SpaceToDepth: input should be 4D array, but got %f instead", + input->rankOf()); + + int bS = input->sizeAt(0); + int iD = isNHWC ? input->sizeAt(3) : input->sizeAt(1); + int iH = isNHWC ? input->sizeAt(1) : input->sizeAt(2); + int iW = isNHWC ? input->sizeAt(2) : input->sizeAt(3); + + REQUIRE_TRUE( + iH % block_size == 0 && iW % block_size == 0, 0, + "SpaceToDepth: input Height & Width should be divisible by block_size"); + + auto output = OUTPUT_VARIABLE(0); + + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::_spaceTodepth(block.launchContext(), *input, output, block_size, + isNHWC); + else + helpers::_spaceTodepth(block.launchContext(), input->dup(), output, + block_size, isNHWC); + + return Status::OK(); +} + +DECLARE_SHAPE_FN(space_to_depth) { + auto in = inputShape->at(0); + int block_size = INT_ARG(0); + bool isNHWC = INT_ARG(1) == 1; + + int bS = shape::sizeAt(in, 0); + int iD = isNHWC ? shape::sizeAt(in, 3) : shape::sizeAt(in, 1); + int iH = isNHWC ? shape::sizeAt(in, 1) : shape::sizeAt(in, 2); + int iW = isNHWC ? shape::sizeAt(in, 2) : shape::sizeAt(in, 3); + + int oD = iD * block_size * block_size; + int oH = iH / block_size; + int oW = iW / block_size; + + std::array shape; + if (isNHWC) + shape = {{bS, oH, oW, oD}}; + else + shape = {{bS, oD, oH, oW}}; + + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(in), 'c', 4, shape.data()); + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/split.cpp b/libnd4j/include/ops/declarable/generic/transforms/split.cpp index 462f2c77e11a..600ad6093512 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/split.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/split.cpp @@ -22,127 +22,134 @@ #if NOT_EXCLUDED(OP_split) #include -#include +#include + #include namespace sd { namespace ops { - CUSTOM_OP_IMPL(split, 1, -1, false, 0, 1) { - NDArray *input = nullptr; - int num_splits = INT_ARG(0); - - // axis is 0 by default - int axis = 0; - - if (block.width() == 1) { - input = INPUT_VARIABLE(0); - } else { - auto a = INPUT_VARIABLE(0); - auto b = INPUT_VARIABLE(1); - - if (a->isScalar()) { - // axis goes first - axis = a->e(0); - input = b; - } else if (b->isScalar()) { - axis = b->e(0); - input = a; - } - } - - //Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays - if(input->isEmpty()){ - for( int i=0; i< num_splits; i++ ){ - REQUIRE_TRUE(OUTPUT_VARIABLE(i)->isEmpty(), 0, "Split: When input array is empty, all output arrays must be empty"); - } - //No op - return Status::OK(); - } - - if (block.numI() == 2) - axis = INT_ARG(1); - - if(axis < 0) axis += input->rankOf(); - - REQUIRE_TRUE(input->sizeAt(axis) % num_splits == 0, 0, "Split: num_splits has wrong value, remainder of division should be 0, but it's %i", input->sizeAt(axis) % num_splits); - - std::vector outArrs(num_splits); - for (int e = 0; e < num_splits; e++) { - outArrs[e] = OUTPUT_VARIABLE(e); - } - - helpers::split(block.launchContext(), *input, outArrs, axis); - - return Status::OK(); +CUSTOM_OP_IMPL(split, 1, -1, false, 0, 1) { + NDArray *input = nullptr; + int num_splits = INT_ARG(0); + + // axis is 0 by default + int axis = 0; + + if (block.width() == 1) { + input = INPUT_VARIABLE(0); + } else { + auto a = INPUT_VARIABLE(0); + auto b = INPUT_VARIABLE(1); + + if (a->isScalar()) { + // axis goes first + axis = a->e(0); + input = b; + } else if (b->isScalar()) { + axis = b->e(0); + input = a; } - - DECLARE_TYPES(split) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); + } + + // Edge case: splitting empty array (mainly for TF import compatibility) -> + // return N empty arrays + if (input->isEmpty()) { + for (int i = 0; i < num_splits; i++) { + REQUIRE_TRUE( + OUTPUT_VARIABLE(i)->isEmpty(), 0, + "Split: When input array is empty, all output arrays must be empty"); } + // No op + return Status::OK(); + } - DECLARE_SHAPE_FN(split) { - int num_splits = INT_ARG(0); - auto input = inputShape->at(0); - sd::DataType dataType = ArrayOptions::dataType(input); - - // axis is 0 by default - int axis = 0; - - int inputVar = 0; - if (inputShape->size() != 1) { - auto shape0 = inputShape->at(0); - auto shape1 = inputShape->at(1); - - if (shape::isScalar(shape0)) { - input = shape1; - auto _a = INPUT_VARIABLE(0); - axis = _a->e(0); - dataType = ArrayOptions::dataType(shape1); - inputVar = 1; - } else if (shape::isScalar(shape1)) { - input = shape0; - auto _a = INPUT_VARIABLE(1); - axis = _a->e(0); - dataType = ArrayOptions::dataType(shape0); - inputVar = 0; - } - } - - auto shapes = SHAPELIST(); - - //Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays - // if(INPUT_VARIABLE(inputVar)->isEmpty()){ - // for (int e = 0; e < num_splits; e++) { - // auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dataType); - // shapes->push_back(empty); - // } - // return shapes; - // } - - if (block.numI() == 2) - axis = INT_ARG(1); - - if (axis < 0) - axis += shape::rank(input); - - std::vector shape(shape::rank(input)); - - for (int e = 0; e < shape::rank(input); e++) - if (e == axis) - shape[e] = shape::sizeAt(input, e) / num_splits; - else - shape[e] = shape::sizeAt(input, e); - - for (int e = 0; e < num_splits; e++) { - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dataType, shape::order(input), shape); - shapes->push_back(newShape); - } - - return shapes; - } + if (block.numI() == 2) axis = INT_ARG(1); + + if (axis < 0) axis += input->rankOf(); + + REQUIRE_TRUE(input->sizeAt(axis) % num_splits == 0, 0, + "Split: num_splits has wrong value, remainder of division " + "should be 0, but it's %i", + input->sizeAt(axis) % num_splits); + + std::vector outArrs(num_splits); + for (int e = 0; e < num_splits; e++) { + outArrs[e] = OUTPUT_VARIABLE(e); + } + + helpers::split(block.launchContext(), *input, outArrs, axis); + + return Status::OK(); +} + +DECLARE_TYPES(split) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); } + +DECLARE_SHAPE_FN(split) { + int num_splits = INT_ARG(0); + auto input = inputShape->at(0); + sd::DataType dataType = ArrayOptions::dataType(input); + + // axis is 0 by default + int axis = 0; + + int inputVar = 0; + if (inputShape->size() != 1) { + auto shape0 = inputShape->at(0); + auto shape1 = inputShape->at(1); + + if (shape::isScalar(shape0)) { + input = shape1; + auto _a = INPUT_VARIABLE(0); + axis = _a->e(0); + dataType = ArrayOptions::dataType(shape1); + inputVar = 1; + } else if (shape::isScalar(shape1)) { + input = shape0; + auto _a = INPUT_VARIABLE(1); + axis = _a->e(0); + dataType = ArrayOptions::dataType(shape0); + inputVar = 0; + } + } + + auto shapes = SHAPELIST(); + + // Edge case: splitting empty array (mainly for TF import compatibility) -> + // return N empty arrays + // if(INPUT_VARIABLE(inputVar)->isEmpty()){ + // for (int e = 0; e < num_splits; e++) { + // auto empty = + // ConstantShapeHelper::getInstance()->emptyShapeInfo(dataType); + // shapes->push_back(empty); + // } + // return shapes; + // } + + if (block.numI() == 2) axis = INT_ARG(1); + + if (axis < 0) axis += shape::rank(input); + + std::vector shape(shape::rank(input)); + + for (int e = 0; e < shape::rank(input); e++) + if (e == axis) + shape[e] = shape::sizeAt(input, e) / num_splits; + else + shape[e] = shape::sizeAt(input, e); + + for (int e = 0; e < num_splits; e++) { + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + dataType, shape::order(input), shape); + shapes->push_back(newShape); + } + + return shapes; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/split_v.cpp b/libnd4j/include/ops/declarable/generic/transforms/split_v.cpp index a3499e93c7aa..63f4e530e662 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/split_v.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/split_v.cpp @@ -25,104 +25,104 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(split_v, 2, -1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto sizes = INPUT_VARIABLE(1); +CUSTOM_OP_IMPL(split_v, 2, -1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto sizes = INPUT_VARIABLE(1); - int axis = 0; + int axis = 0; - if (block.numI() > 0) { - axis = INT_ARG(0); - } else if (block.width() > 2){ - auto _a = INPUT_VARIABLE(2); - axis = _a->e(0); - } + if (block.numI() > 0) { + axis = INT_ARG(0); + } else if (block.width() > 2) { + auto _a = INPUT_VARIABLE(2); + axis = _a->e(0); + } - if (axis < 0) - axis += input->rankOf(); + if (axis < 0) axis += input->rankOf(); - std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); + std::vector dims = + ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); - int pos = 0; - std::vector indices(2 * input->rankOf()); - - for (Nd4jLong e = 0; e < sizes->lengthOf(); e++) { - int c_size = sizes->e(e); - - for (int d = 0; d < input->rankOf(); d++) { - if (d == axis) - indices[2*d + 1] = (indices[2*d] = pos) + c_size; - else - indices[2*d] = indices[2*d + 1] = 0; - } + int pos = 0; + std::vector indices(2 * input->rankOf()); - auto output = OUTPUT_VARIABLE(e); - REQUIRE_TRUE(output->dataType() == input->dataType(), 0, "SplitV: all outputs must have same data type as input"); + for (Nd4jLong e = 0; e < sizes->lengthOf(); e++) { + int c_size = sizes->e(e); - auto sub = (*input)(indices); + for (int d = 0; d < input->rankOf(); d++) { + if (d == axis) + indices[2 * d + 1] = (indices[2 * d] = pos) + c_size; + else + indices[2 * d] = indices[2 * d + 1] = 0; + } - output->assign(sub); + auto output = OUTPUT_VARIABLE(e); + REQUIRE_TRUE(output->dataType() == input->dataType(), 0, + "SplitV: all outputs must have same data type as input"); - pos += c_size; - } + auto sub = (*input)(indices); - //delete tads; - return Status::OK(); - } + output->assign(sub); - DECLARE_TYPES(split_v) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_INTS}) - ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); - } + pos += c_size; + } - DECLARE_SHAPE_FN(split_v) { - auto input = inputShape->at(0); - //auto sizes = inputShape->at(1); - - auto shapeList = SHAPELIST(); - int rank = shape::rank(input); - - // 0 is just default axis - int axis = 0; - - if (block.numI() > 0) - axis = INT_ARG(0); - else if (block.width() > 2) { - auto _a = INPUT_VARIABLE(2); - axis = _a->e(0); - } - - if (axis < 0) - axis += shape::rank(input); - - // this op assumes we have sizes defined - auto sizes = INPUT_VARIABLE(1); - - auto length = sizes->lengthOf(); - int pos = 0; - for (Nd4jLong e = 0; e < length; e++) { - int c_size = sizes->e(e); - - - std::vector shape(rank); - - for (int d = 0; d < rank; d++) { - if (d != axis) - shape[d] = shape::sizeAt(input, d); - else - shape[d] = c_size; - } - - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(input), shape::order(input), shape); - shapeList->push_back(newShape); - } - - return shapeList; - } + // delete tads; + return Status::OK(); } + +DECLARE_TYPES(split_v) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); +} + +DECLARE_SHAPE_FN(split_v) { + auto input = inputShape->at(0); + // auto sizes = inputShape->at(1); + + auto shapeList = SHAPELIST(); + int rank = shape::rank(input); + + // 0 is just default axis + int axis = 0; + + if (block.numI() > 0) + axis = INT_ARG(0); + else if (block.width() > 2) { + auto _a = INPUT_VARIABLE(2); + axis = _a->e(0); + } + + if (axis < 0) axis += shape::rank(input); + + // this op assumes we have sizes defined + auto sizes = INPUT_VARIABLE(1); + + auto length = sizes->lengthOf(); + int pos = 0; + for (Nd4jLong e = 0; e < length; e++) { + int c_size = sizes->e(e); + + std::vector shape(rank); + + for (int d = 0; d < rank; d++) { + if (d != axis) + shape[d] = shape::sizeAt(input, d); + else + shape[d] = c_size; + } + + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(input), shape::order(input), shape); + shapeList->push_back(newShape); + } + + return shapeList; } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/stack.cpp b/libnd4j/include/ops/declarable/generic/transforms/stack.cpp index bc07544d33a3..2c9bd9ddbc57 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/stack.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/stack.cpp @@ -22,88 +22,96 @@ #if NOT_EXCLUDED(OP_stack) #include -#include +#include namespace sd { namespace ops { CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - int dim = block.numI() > 0 ? INT_ARG(0) : 0; - if(dim < 0) - dim += input->rankOf() + 1; + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + int dim = block.numI() > 0 ? INT_ARG(0) : 0; + if (dim < 0) dim += input->rankOf() + 1; - // no-op in case of empty output array - if (output->isEmpty()) - return Status::OK(); + // no-op in case of empty output array + if (output->isEmpty()) return Status::OK(); - // input validation - // check whether shapes of all input array are the same - for (int i = 0; i < (int) block.width() - 1; ++i) - REQUIRE_TRUE(shape::equalsSoft((INPUT_VARIABLE(i))->shapeInfo(), (INPUT_VARIABLE(i+1))->shapeInfo()), 0, "STACK op: the shapes of all input arrays must be the same !"); + // input validation + // check whether shapes of all input array are the same + for (int i = 0; i < (int)block.width() - 1; ++i) + REQUIRE_TRUE(shape::equalsSoft((INPUT_VARIABLE(i))->shapeInfo(), + (INPUT_VARIABLE(i + 1))->shapeInfo()), + 0, + "STACK op: the shapes of all input arrays must be the same !"); - REQUIRE_TRUE(dim <= input->rankOf(), 0, "STACK op: the input dimension parameter must be <= rank of input arrays shapes (rank=%i), but got %i instead !", input->shapeOf(), dim); + REQUIRE_TRUE(dim <= input->rankOf(), 0, + "STACK op: the input dimension parameter must be <= rank of " + "input arrays shapes (rank=%i), but got %i instead !", + input->shapeOf(), dim); + std::vector inArrs(block.width()); + for (int i = 0; i < block.width(); ++i) inArrs[i] = INPUT_VARIABLE(i); - std::vector inArrs(block.width()); - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); + helpers::stack(block.launchContext(), inArrs, *output, dim); - helpers::stack(block.launchContext(), inArrs, *output, dim); - - return Status::OK(); + return Status::OK(); } DECLARE_SYN(pack, stack); DECLARE_SYN(Pack, stack); - DECLARE_TYPES(stack) { - //getOpDescriptor()->setSameMode(true); - getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) - ->setAllowedOutputTypes(DataType::ANY); - - } +DECLARE_TYPES(stack) { + // getOpDescriptor()->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes(DataType::ANY) + ->setAllowedOutputTypes(DataType::ANY); +} DECLARE_SHAPE_FN(stack) { - - // check whether input dimension is within rank range - auto inShapeInfo = inputShape->at(0); - int rank = shape::rank(inShapeInfo); - int dim = block.numI() > 0 ? INT_ARG(0) : 0; - if(dim < 0 ) - dim += rank + 1; - - REQUIRE_TRUE(dim <= inShapeInfo[0], 0, "STACK op: the input dimension parameter must be <= rank of input arrays shapes (rank=%i), but got %i instead !", inShapeInfo[0], dim); - - // empty input arrays require some special handling - if (shape::isEmpty(inShapeInfo)) { - switch (rank) { - case 0: { - // we're going to return rank 1 here - if (block.width() == 1) { - return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, ArrayOptions::dataType(inShapeInfo))); - } else { - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), 'c', {(Nd4jLong) block.width(), 0})); - } - } - } - } - - if(rank == 0) { - return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(block.width(), ArrayOptions::dataType(inShapeInfo))); - } - - //the rank of output ShapeInfo is larger by one compared to input ShapeInfo - std::vector outShape(inShapeInfo + 1, inShapeInfo + 1 + rank); - - // insert (int) block.width() at dim position of input shape to get output shape - outShape.insert(outShape.begin() + Nd4jLong(dim), (Nd4jLong) block.width()); - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), outShape))); + // check whether input dimension is within rank range + auto inShapeInfo = inputShape->at(0); + int rank = shape::rank(inShapeInfo); + int dim = block.numI() > 0 ? INT_ARG(0) : 0; + if (dim < 0) dim += rank + 1; + + REQUIRE_TRUE(dim <= inShapeInfo[0], 0, + "STACK op: the input dimension parameter must be <= rank of " + "input arrays shapes (rank=%i), but got %i instead !", + inShapeInfo[0], dim); + + // empty input arrays require some special handling + if (shape::isEmpty(inShapeInfo)) { + switch (rank) { + case 0: { + // we're going to return rank 1 here + if (block.width() == 1) { + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo( + 0, ArrayOptions::dataType(inShapeInfo))); + } else { + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShapeInfo), 'c', + {(Nd4jLong)block.width(), 0})); + } + } + } + } + + if (rank == 0) { + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo( + block.width(), ArrayOptions::dataType(inShapeInfo))); + } + + // the rank of output ShapeInfo is larger by one compared to input ShapeInfo + std::vector outShape(inShapeInfo + 1, inShapeInfo + 1 + rank); + + // insert (int) block.width() at dim position of input shape to get output + // shape + outShape.insert(outShape.begin() + Nd4jLong(dim), (Nd4jLong)block.width()); + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(ArrayOptions::dataType(inShapeInfo), + shape::order(inShapeInfo), outShape))); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp b/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp index 4df64d20f780..8de950c3c91c 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/standardize.cpp @@ -24,117 +24,121 @@ #include #include - namespace sd { -namespace ops { - - CONFIGURABLE_OP_IMPL(standardize, 1, 1, true, 0, -2) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - std::vector axis; - - if (block.width() > 1) - axis = INPUT_VARIABLE(1)->template asVectorT(); - else if (block.numI() > 0) - axis = block.getIArguments(); - - REQUIRE_TRUE(!axis.empty(), 0, "STANDARDIZE OP: axis has to be non-empty") - - shape::checkDimensions(input->rankOf(), axis); - - auto means = input->reduceAlongDimension(reduce::Mean, axis, true); - auto stdev = input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, false, axis); - stdev.reshapei(means.getShapeAsVector()); - - input->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), means, *output, false); - output->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), stdev, *output, false); - output->applyScalar(sd::scalar::ReplaceNans, 0, *output); - - return Status::OK(); - } - - - DECLARE_TYPES(standardize) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}); - getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); - getOpDescriptor()->setAllowedOutputTypes(0, DataType::INHERIT); - } +namespace ops { - CUSTOM_OP_IMPL(standardize_bp, 2, 1, false, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto eps = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); +CONFIGURABLE_OP_IMPL(standardize, 1, 1, true, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - std::vector axis; + std::vector axis; - if (block.width() == 3) - axis = INPUT_VARIABLE(1)->template asVectorT(); - else if (block.numI() > 0) - axis = block.getIArguments(); + if (block.width() > 1) + axis = INPUT_VARIABLE(1)->template asVectorT(); + else if (block.numI() > 0) + axis = block.getIArguments(); - REQUIRE_TRUE(!axis.empty(), 0, "STANDARDIZE OP: axis has to be non-empty") + REQUIRE_TRUE(!axis.empty(), 0, "STANDARDIZE OP: axis has to be non-empty") + shape::checkDimensions(input->rankOf(), axis); - shape::checkDimensions(input->rankOf(), axis); - auto longAxis = ArrayUtils::toLongVector(axis); + auto means = input->reduceAlongDimension(reduce::Mean, axis, true); + auto stdev = input->varianceAlongDimension( + variance::SummaryStatsStandardDeviation, false, axis); + stdev.reshapei(means.getShapeAsVector()); - auto means = input->reduceAlongDimension(reduce::Mean, axis, true); - auto stdev = input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, false, axis); - stdev.reshapei(means.getShapeAsVector()); + input->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), means, *output, + false); + output->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), stdev, *output, + false); + output->applyScalar(sd::scalar::ReplaceNans, 0, *output); - eps->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), stdev, *output, false); - - NDArray dldu_sum = -output->reduceAlongDimension(reduce::Sum, axis, true); - - NDArray dldx_u(input->shapeInfo(), false, block.launchContext()); - std::vector meanBpArgs = {input, &dldu_sum}; - std::vector meanBpOutput = {&dldx_u}; - std::vector meanBpTArgs = {}; - std::vector meanBpBArgs = {}; - - sd::ops::reduce_mean_bp meanBp; - meanBp.execute(meanBpArgs, meanBpOutput, meanBpTArgs, longAxis, meanBpBArgs); - *output += dldx_u; - - // (eps * (means - input) / (stdev * stdev)) - NDArray tmp(eps->shapeInfo(), false, block.launchContext()); - means.applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), *input, tmp, false); - tmp.applyPairwiseTransform(sd::pairwise::Multiply, *eps, tmp); - stdev.applyPairwiseTransform(sd::pairwise::Multiply, stdev, stdev); - tmp.applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), stdev, tmp, false); - - auto dlds_sum = tmp.reduceAlongDimension(reduce::Sum, axis, true); - NDArray dldx_s(input->shapeInfo(), false, block.launchContext()); - std::vector stdevBpArgs = {input, &dlds_sum}; - std::vector stdevBpOutput = {&dldx_s}; - std::vector stdevBpTArgs = {}; - std::vector stdevBpBArgs = {}; - sd::ops::reduce_stdev_bp stdevBp; - stdevBp.execute(stdevBpArgs, stdevBpOutput, stdevBpTArgs, longAxis, stdevBpBArgs); - *output += dldx_s; - - output->applyScalar(sd::scalar::ReplaceNans, 0, *output); + return Status::OK(); +} - return Status::OK(); - } +DECLARE_TYPES(standardize) { + getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}); + getOpDescriptor()->setAllowedInputTypes(1, + {DataType::INT32, DataType::INT64}); + getOpDescriptor()->setAllowedOutputTypes(0, DataType::INHERIT); +} - DECLARE_TYPES(standardize_bp) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +CUSTOM_OP_IMPL(standardize_bp, 2, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto eps = block.width() == 3 ? INPUT_VARIABLE(2) : INPUT_VARIABLE(1); + + auto output = OUTPUT_VARIABLE(0); + std::vector axis; + + if (block.width() == 3) + axis = INPUT_VARIABLE(1)->template asVectorT(); + else if (block.numI() > 0) + axis = block.getIArguments(); + + REQUIRE_TRUE(!axis.empty(), 0, "STANDARDIZE OP: axis has to be non-empty") + + shape::checkDimensions(input->rankOf(), axis); + auto longAxis = ArrayUtils::toLongVector(axis); + + auto means = input->reduceAlongDimension(reduce::Mean, axis, true); + auto stdev = input->varianceAlongDimension( + variance::SummaryStatsStandardDeviation, false, axis); + stdev.reshapei(means.getShapeAsVector()); + + eps->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), stdev, *output, + false); + + NDArray dldu_sum = -output->reduceAlongDimension(reduce::Sum, axis, true); + + NDArray dldx_u(input->shapeInfo(), false, block.launchContext()); + std::vector meanBpArgs = {input, &dldu_sum}; + std::vector meanBpOutput = {&dldx_u}; + std::vector meanBpTArgs = {}; + std::vector meanBpBArgs = {}; + + sd::ops::reduce_mean_bp meanBp; + meanBp.execute(meanBpArgs, meanBpOutput, meanBpTArgs, longAxis, meanBpBArgs); + *output += dldx_u; + + // (eps * (means - input) / (stdev * stdev)) + NDArray tmp(eps->shapeInfo(), false, block.launchContext()); + means.applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), *input, tmp, + false); + tmp.applyPairwiseTransform(sd::pairwise::Multiply, *eps, tmp); + stdev.applyPairwiseTransform(sd::pairwise::Multiply, stdev, stdev); + tmp.applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), stdev, tmp, false); + + auto dlds_sum = tmp.reduceAlongDimension(reduce::Sum, axis, true); + NDArray dldx_s(input->shapeInfo(), false, block.launchContext()); + std::vector stdevBpArgs = {input, &dlds_sum}; + std::vector stdevBpOutput = {&dldx_s}; + std::vector stdevBpTArgs = {}; + std::vector stdevBpBArgs = {}; + sd::ops::reduce_stdev_bp stdevBp; + stdevBp.execute(stdevBpArgs, stdevBpOutput, stdevBpTArgs, longAxis, + stdevBpBArgs); + *output += dldx_s; + + output->applyScalar(sd::scalar::ReplaceNans, 0, *output); + + return Status::OK(); +} - DECLARE_SHAPE_FN(standardize_bp) { - auto in = inputShape->at(0); - Nd4jLong *out; - COPY_SHAPE(in, out); +DECLARE_TYPES(standardize_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - return SHAPELIST(CONSTANT(out)); - } +DECLARE_SHAPE_FN(standardize_bp) { + auto in = inputShape->at(0); + Nd4jLong *out; + COPY_SHAPE(in, out); + return SHAPELIST(CONSTANT(out)); } -} + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/tear.cpp b/libnd4j/include/ops/declarable/generic/transforms/tear.cpp index c1def43bf3bc..056677f42586 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/tear.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/tear.cpp @@ -21,60 +21,65 @@ #include #if NOT_EXCLUDED(OP_tear) -#include -#include #include +#include +#include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(tear, 1, -1, false, 0, -1) { - auto input = INPUT_VARIABLE(0); +namespace ops { +CUSTOM_OP_IMPL(tear, 1, -1, false, 0, -1) { + auto input = INPUT_VARIABLE(0); - REQUIRE_TRUE(!block.getIArguments().empty(), 0, "At least 1 dimension should be specified for Tear"); + REQUIRE_TRUE(!block.getIArguments().empty(), 0, + "At least 1 dimension should be specified for Tear"); - std::vector dims(block.getIArguments()); + std::vector dims(block.getIArguments()); - for (auto &v: dims) - REQUIRE_TRUE(v >= 0 && v < input->rankOf(), 0, "Tear dimensions should be non-negative values, and lower then input rank. Got %i instead", v); + for (auto &v : dims) + REQUIRE_TRUE(v >= 0 && v < input->rankOf(), 0, + "Tear dimensions should be non-negative values, and lower " + "then input rank. Got %i instead", + v); - auto tads = input->allTensorsAlongDimension(dims); - for (Nd4jLong e = 0; e < tads.size(); e++) { - auto outE = OUTPUT_VARIABLE(e); - outE->assign(tads.at(e)); + auto tads = input->allTensorsAlongDimension(dims); + for (Nd4jLong e = 0; e < tads.size(); e++) { + auto outE = OUTPUT_VARIABLE(e); + outE->assign(tads.at(e)); - // just for debugging purposes - this->storeResult(block, e, *outE); - } + // just for debugging purposes + this->storeResult(block, e, *outE); + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_SHAPE_FN(tear) { - auto inShape = inputShape->at(0); +DECLARE_SHAPE_FN(tear) { + auto inShape = inputShape->at(0); - std::vector dims(block.getIArguments()); + std::vector dims(block.getIArguments()); - if (dims.size() > 1) - std::sort(dims.begin(), dims.end()); + if (dims.size() > 1) std::sort(dims.begin(), dims.end()); - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(inShape, dims); - auto numTads = tadPack.numberOfTads(); + auto tadPack = + sd::ConstantTadHelper::getInstance()->tadForDimensions(inShape, dims); + auto numTads = tadPack.numberOfTads(); - auto result = SHAPELIST(); - for (Nd4jLong e = 0; e < numTads; e++) { - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), shape::rank(tadPack.primaryShapeInfo()), shape::shapeOf(tadPack.primaryShapeInfo())); - result->push_back(newShape); - } + auto result = SHAPELIST(); + for (Nd4jLong e = 0; e < numTads; e++) { + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), + shape::rank(tadPack.primaryShapeInfo()), + shape::shapeOf(tadPack.primaryShapeInfo())); + result->push_back(newShape); + } - return result; - } + return result; +} - DECLARE_TYPES(tear) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } - } +DECLARE_TYPES(tear) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/tile.cpp b/libnd4j/include/ops/declarable/generic/transforms/tile.cpp index 86c1bd4ac7fa..bb8578c7d125 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/tile.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/tile.cpp @@ -23,162 +23,178 @@ #if NOT_EXCLUDED(OP_tile) #include -#include +#include namespace sd { namespace ops { CUSTOM_OP_IMPL(tile, 1, 1, false, 0, -2) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - const int inRank = input->rankOf(); - std::vector reps; - - if (block.numI() == inRank) { - - reps = ArrayUtils::toLongVector(block.getIArguments()); - } - else if (block.width() > 1) { - - auto reps_vector = INPUT_VARIABLE(1); - REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, "TILE op: repeats vector length should be equal to input rank, but got %i and %i correspondingly !", reps_vector->lengthOf(), inRank); - - reps = reps_vector->template asVectorT(); - } - else { - REQUIRE_TRUE(false, 0, "TILE op: this op requires repeats vector, either as IArgs or second array with length equal to rank of input array to be tiled !"); - } - - auto repProd = shape::prodLong(reps.data(), reps.size()); - REQUIRE_TRUE(repProd > 0, 0, "TILE op: reps can't contain 0s"); - - input->tile(reps, *output); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + const int inRank = input->rankOf(); + std::vector reps; + + if (block.numI() == inRank) { + reps = ArrayUtils::toLongVector(block.getIArguments()); + } else if (block.width() > 1) { + auto reps_vector = INPUT_VARIABLE(1); + REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, + "TILE op: repeats vector length should be equal to input " + "rank, but got %i and %i correspondingly !", + reps_vector->lengthOf(), inRank); + + reps = reps_vector->template asVectorT(); + } else { + REQUIRE_TRUE( + false, 0, + "TILE op: this op requires repeats vector, either as IArgs or second " + "array with length equal to rank of input array to be tiled !"); + } + + auto repProd = shape::prodLong(reps.data(), reps.size()); + REQUIRE_TRUE(repProd > 0, 0, "TILE op: reps can't contain 0s"); + + input->tile(reps, *output); + + return Status::OK(); } - DECLARE_TYPES(tile) { - getOpDescriptor()->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedOutputTypes(sd::DataType::ANY); - } - +DECLARE_TYPES(tile) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes(sd::DataType::ANY); +} DECLARE_SHAPE_FN(tile) { - - auto inShape = inputShape->at(0); - const int inRank = inShape[0]; - std::vector reps; - - if (block.numI() == inRank) { - - reps = ArrayUtils::toLongVector(block.getIArguments()); - } - else if (block.width() > 1) { - - auto reps_vector = INPUT_VARIABLE(1); - REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, "TILE op: repeats vector length should be equal to input rank, but got %i and %i correspondingly !", reps_vector->lengthOf(), inRank); - reps = reps_vector->template asVectorT(); - } - else { - REQUIRE_TRUE(false, 0, "TILE op: this op requires repeats vector, either as IArgs or second array with length equal to rank of input array to be tiled !"); - } - - auto repProd = shape::prodLong(reps.data(), reps.size()); - REQUIRE_TRUE(repProd > 0, 0, "TILE op: reps can't contain 0s"); - - std::vector shape(inRank); - for (int e = 0; e < shape::rank(inShape); e++) - shape[e] = shape::sizeAt(inShape, e) * reps[e]; - - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), shape); - return SHAPELIST(newShape); + auto inShape = inputShape->at(0); + const int inRank = inShape[0]; + std::vector reps; + + if (block.numI() == inRank) { + reps = ArrayUtils::toLongVector(block.getIArguments()); + } else if (block.width() > 1) { + auto reps_vector = INPUT_VARIABLE(1); + REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, + "TILE op: repeats vector length should be equal to input " + "rank, but got %i and %i correspondingly !", + reps_vector->lengthOf(), inRank); + reps = reps_vector->template asVectorT(); + } else { + REQUIRE_TRUE( + false, 0, + "TILE op: this op requires repeats vector, either as IArgs or second " + "array with length equal to rank of input array to be tiled !"); + } + + auto repProd = shape::prodLong(reps.data(), reps.size()); + REQUIRE_TRUE(repProd > 0, 0, "TILE op: reps can't contain 0s"); + + std::vector shape(inRank); + for (int e = 0; e < shape::rank(inShape); e++) + shape[e] = shape::sizeAt(inShape, e) * reps[e]; + + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShape), shape::order(inShape), shape); + return SHAPELIST(newShape); } - //////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(tile_bp, 2, 1, false, 0, -2) { - - auto input = INPUT_VARIABLE(0); - auto gradO = INPUT_VARIABLE(1); - auto gradI = OUTPUT_VARIABLE(0); - - const int inRank = input->rankOf(); - - std::vector reps; - - if (block.numI() == inRank) { - - reps = ArrayUtils::toLongVector(block.getIArguments()); - } - else if (block.width() > 2) { - - auto reps_vector = INPUT_VARIABLE(1); - REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, "TILE_BP op: repeats vector length should be equal to input rank, but got %i and %i correspondingly !", reps_vector->lengthOf(), inRank); - - reps = reps_vector->template asVectorT(); - gradO = INPUT_VARIABLE(2); - } - else { - REQUIRE_TRUE(false, 0, "TILE_BP op: this op requires repeats vector, either as IArgs or second array with length equal to rank of input array to be tiled !"); - } - - REQUIRE_TRUE(inRank == gradO->rankOf(), 0, "TILE_BP op: the ranks of input array and output's gradients array (next epsilon) must be equal, but got %i and %i correspondingly !", inRank, gradO->rankOf()); - - for (int i = 0; i < inRank; ++i) - REQUIRE_TRUE(gradO->sizeAt(i) == gradI->sizeAt(i) * reps[i], 0, "TILE_BP op: shapes of input array and output's gradients array (next epsilon) are inconsistent !"); - - helpers::tileBP(block.launchContext(), *gradO, *gradI, reps); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + auto gradI = OUTPUT_VARIABLE(0); + + const int inRank = input->rankOf(); + + std::vector reps; + + if (block.numI() == inRank) { + reps = ArrayUtils::toLongVector(block.getIArguments()); + } else if (block.width() > 2) { + auto reps_vector = INPUT_VARIABLE(1); + REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, + "TILE_BP op: repeats vector length should be equal to input " + "rank, but got %i and %i correspondingly !", + reps_vector->lengthOf(), inRank); + + reps = reps_vector->template asVectorT(); + gradO = INPUT_VARIABLE(2); + } else { + REQUIRE_TRUE( + false, 0, + "TILE_BP op: this op requires repeats vector, either as IArgs or " + "second array with length equal to rank of input array to be tiled !"); + } + + REQUIRE_TRUE( + inRank == gradO->rankOf(), 0, + "TILE_BP op: the ranks of input array and output's gradients array (next " + "epsilon) must be equal, but got %i and %i correspondingly !", + inRank, gradO->rankOf()); + + for (int i = 0; i < inRank; ++i) + REQUIRE_TRUE(gradO->sizeAt(i) == gradI->sizeAt(i) * reps[i], 0, + "TILE_BP op: shapes of input array and output's gradients " + "array (next epsilon) are inconsistent !"); + + helpers::tileBP(block.launchContext(), *gradO, *gradI, reps); + + return Status::OK(); } - DECLARE_TYPES(tile_bp) { - getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}); - getOpDescriptor()->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}); - getOpDescriptor()->setAllowedInputTypes(2, {ALL_FLOATS}); +DECLARE_TYPES(tile_bp) { + getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}); + getOpDescriptor()->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS}); + getOpDescriptor()->setAllowedInputTypes(2, {ALL_FLOATS}); - getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); - } + getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS}); +} DECLARE_SHAPE_FN(tile_bp) { - - auto inShape = inputShape->at(0); - auto gradOShape = inputShape->at(1); - const int inRank = inShape[0]; - - std::vector reps; - - if (block.numI() == inRank) { - - reps = ArrayUtils::toLongVector(block.getIArguments()); - } - else if (block.width() > 2) { - auto reps_vector = INPUT_VARIABLE(1); - REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, "TILE_BP op: repeats vector length should be equal to input rank, but got %i and %i correspondingly !", reps_vector->lengthOf(), inRank); - reps = reps_vector->template asVectorT(); - gradOShape = inputShape->at(2); - } - else { - REQUIRE_TRUE(false, 0, "TILE_BP op: this op requires repeats vector, either as IArgs or second array with length equal to rank of input array to be tiled !"); - } - - REQUIRE_TRUE(inRank == gradOShape[0], 0, "TILE_BP op: the ranks of input array and output's gradients array (next epsilon) must be equal, but got %i and %i correspondingly !", inRank, gradOShape[0]); - - for (int i = 0; i < inRank; ++i) - REQUIRE_TRUE(shape::sizeAt(gradOShape, i) == shape::sizeAt(inShape, i) * reps[i], 0, "TILE_BP op: shapes of input array and output's gradients array (next epsilon) are inconsistent !"); - - Nd4jLong *gradIShape; - COPY_SHAPE(inShape, gradIShape); - - return SHAPELIST(CONSTANT(gradIShape)); - + auto inShape = inputShape->at(0); + auto gradOShape = inputShape->at(1); + const int inRank = inShape[0]; + + std::vector reps; + + if (block.numI() == inRank) { + reps = ArrayUtils::toLongVector(block.getIArguments()); + } else if (block.width() > 2) { + auto reps_vector = INPUT_VARIABLE(1); + REQUIRE_TRUE(reps_vector->lengthOf() == inRank, 0, + "TILE_BP op: repeats vector length should be equal to input " + "rank, but got %i and %i correspondingly !", + reps_vector->lengthOf(), inRank); + reps = reps_vector->template asVectorT(); + gradOShape = inputShape->at(2); + } else { + REQUIRE_TRUE( + false, 0, + "TILE_BP op: this op requires repeats vector, either as IArgs or " + "second array with length equal to rank of input array to be tiled !"); + } + + REQUIRE_TRUE( + inRank == gradOShape[0], 0, + "TILE_BP op: the ranks of input array and output's gradients array (next " + "epsilon) must be equal, but got %i and %i correspondingly !", + inRank, gradOShape[0]); + + for (int i = 0; i < inRank; ++i) + REQUIRE_TRUE( + shape::sizeAt(gradOShape, i) == shape::sizeAt(inShape, i) * reps[i], 0, + "TILE_BP op: shapes of input array and output's gradients array (next " + "epsilon) are inconsistent !"); + + Nd4jLong *gradIShape; + COPY_SHAPE(inShape, gradIShape); + + return SHAPELIST(CONSTANT(gradIShape)); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/transforms/unstack.cpp b/libnd4j/include/ops/declarable/generic/transforms/unstack.cpp index beebcad86eec..19eb69a97306 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/unstack.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/unstack.cpp @@ -23,107 +23,111 @@ #if NOT_EXCLUDED(OP_unstack) #include -#include +#include namespace sd { -namespace ops { +namespace ops { CUSTOM_OP_IMPL(unstack, 1, -1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); + auto dim = INT_ARG(0); + if (dim < 0) dim += input->rankOf(); - auto dim = INT_ARG(0); - if (dim < 0) - dim += input->rankOf(); + REQUIRE_TRUE(dim < input->rankOf(), 0, + "Unstack dimension should be lower then rank of input %i, but " + "got dimension=%i !", + input->rankOf(), dim); + REQUIRE_TRUE(dim >= 0, 0, + "Unstack dimension should be non-negative value, but got %i !", + dim); + if (input->isEmpty()) return Status::OK(); - REQUIRE_TRUE(dim < input->rankOf(), 0, "Unstack dimension should be lower then rank of input %i, but got dimension=%i !", input->rankOf(), dim); - REQUIRE_TRUE(dim >= 0, 0, "Unstack dimension should be non-negative value, but got %i !", dim); + std::vector outArrs(input->sizeAt(dim)); + for (uint i = 0; i < outArrs.size(); ++i) outArrs[i] = OUTPUT_VARIABLE(i); - if(input->isEmpty()) - return Status::OK(); + helpers::unstack(block.launchContext(), *input, outArrs, dim); - std::vector outArrs(input->sizeAt(dim)); - for(uint i = 0; i < outArrs.size(); ++i) - outArrs[i] = OUTPUT_VARIABLE(i); - - helpers::unstack(block.launchContext(), *input, outArrs, dim); - - return Status::OK(); + return Status::OK(); } DECLARE_SYN(unpack, unstack); DECLARE_SHAPE_FN(unstack) { - auto inShapeInfo = inputShape->at(0); - - auto dim = INT_ARG(0); - if (dim < 0) - dim += shape::rank(inShapeInfo); - - REQUIRE_TRUE(dim < inShapeInfo[0], 0, "UNSTACK op: dimension should be lower then rank of input %i, but got dimension=%i !", inShapeInfo[0], dim); - REQUIRE_TRUE(dim >= 0, 0, "UNSTACK op: dimension should be non-negative value, but got %i !", dim); - - if(ArrayOptions::arrayType(inShapeInfo) == ArrayType::EMPTY) { - - if(shape::shapeOf(inShapeInfo)[dim] == 0) - return SHAPELIST(); - - const Nd4jLong numTads = shape::shapeOf(inShapeInfo)[dim]; - std::vector outShape; - for(uint i = 0; i < shape::rank(inShapeInfo); ++i) - if(i != dim) - outShape.push_back(shape::shapeOf(inShapeInfo)[i]); - - auto result = SHAPELIST(); - for(uint i = 0; i < numTads; ++i) - result->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), outShape)); + auto inShapeInfo = inputShape->at(0); - return result; - } + auto dim = INT_ARG(0); + if (dim < 0) dim += shape::rank(inShapeInfo); - std::vector dims = ShapeUtils::evalDimsToExclude(inShapeInfo[0], {dim}); + REQUIRE_TRUE(dim < inShapeInfo[0], 0, + "UNSTACK op: dimension should be lower then rank of input %i, " + "but got dimension=%i !", + inShapeInfo[0], dim); + REQUIRE_TRUE( + dim >= 0, 0, + "UNSTACK op: dimension should be non-negative value, but got %i !", dim); - if (dims.size() == 0 && shape::rank(inShapeInfo) == 1) { // split vector into lenthOf scalars + if (ArrayOptions::arrayType(inShapeInfo) == ArrayType::EMPTY) { + if (shape::shapeOf(inShapeInfo)[dim] == 0) return SHAPELIST(); - auto result = SHAPELIST(); - for (Nd4jLong e = 0; e < shape::length(inShapeInfo); e++) - result->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShapeInfo))); + const Nd4jLong numTads = shape::shapeOf(inShapeInfo)[dim]; + std::vector outShape; + for (uint i = 0; i < shape::rank(inShapeInfo); ++i) + if (i != dim) outShape.push_back(shape::shapeOf(inShapeInfo)[i]); - return result; - } - - std::vector subArrShape(shape::rank(inShapeInfo) - 1); + auto result = SHAPELIST(); + for (uint i = 0; i < numTads; ++i) + result->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), + outShape)); - for(uint j = 0, i = 0; i < shape::rank(inShapeInfo); i++) - if(i != dim) - subArrShape[j++] = shape::shapeOf(inShapeInfo)[i]; + return result; + } - // remove leading and trailing 1 - if (inShapeInfo[0] == 2 && subArrShape.size() == 2) { + std::vector dims = ShapeUtils::evalDimsToExclude(inShapeInfo[0], {dim}); - if (subArrShape[0] == 1) - subArrShape.erase(subArrShape.begin()); - else if (subArrShape[1] == 1) - subArrShape.erase(subArrShape.end()); - } + if (dims.size() == 0 && + shape::rank(inShapeInfo) == 1) { // split vector into lenthOf scalars auto result = SHAPELIST(); - for (int e = 0; e < shape::shapeOf(inShapeInfo)[dim]; e++) { - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), subArrShape); - result->push_back(newShape); - } + for (Nd4jLong e = 0; e < shape::length(inShapeInfo); e++) + result->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo( + ArrayOptions::dataType(inShapeInfo))); + return result; + } + + std::vector subArrShape(shape::rank(inShapeInfo) - 1); + + for (uint j = 0, i = 0; i < shape::rank(inShapeInfo); i++) + if (i != dim) subArrShape[j++] = shape::shapeOf(inShapeInfo)[i]; + + // remove leading and trailing 1 + if (inShapeInfo[0] == 2 && subArrShape.size() == 2) { + if (subArrShape[0] == 1) + subArrShape.erase(subArrShape.begin()); + else if (subArrShape[1] == 1) + subArrShape.erase(subArrShape.end()); + } + + auto result = SHAPELIST(); + for (int e = 0; e < shape::shapeOf(inShapeInfo)[dim]; e++) { + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), + subArrShape); + result->push_back(newShape); + } + return result; } DECLARE_TYPES(unstack) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS}) - ->setSameMode(true); + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS}) + ->setSameMode(true); } - -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tsne/cell_contains.cpp b/libnd4j/include/ops/declarable/generic/tsne/cell_contains.cpp index e176797b0dc3..26b103af4bea 100644 --- a/libnd4j/include/ops/declarable/generic/tsne/cell_contains.cpp +++ b/libnd4j/include/ops/declarable/generic/tsne/cell_contains.cpp @@ -25,30 +25,31 @@ #include namespace sd { - namespace ops { - - CUSTOM_OP_IMPL(cell_contains, 3, 1, false, 0, 1) { - auto corner = INPUT_VARIABLE(0); - auto width = INPUT_VARIABLE(1); - auto point = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - auto dimension = INT_ARG(0); - output->assign(helpers::cell_contains(corner, width, point, dimension)); - return Status::OK(); - } - - DECLARE_TYPES(cell_contains) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(sd::DataType::BOOL) - ->setSameMode(false); - } - - DECLARE_SHAPE_FN(cell_contains) { - return SHAPELIST(CONSTANT(ShapeBuilders::createScalarShapeInfo(sd::DataType::BOOL, block.workspace()))); - } - } +namespace ops { + +CUSTOM_OP_IMPL(cell_contains, 3, 1, false, 0, 1) { + auto corner = INPUT_VARIABLE(0); + auto width = INPUT_VARIABLE(1); + auto point = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + auto dimension = INT_ARG(0); + output->assign(helpers::cell_contains(corner, width, point, dimension)); + return Status::OK(); +} + +DECLARE_TYPES(cell_contains) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes(sd::DataType::BOOL) + ->setSameMode(false); +} + +DECLARE_SHAPE_FN(cell_contains) { + return SHAPELIST(CONSTANT(ShapeBuilders::createScalarShapeInfo( + sd::DataType::BOOL, block.workspace()))); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp b/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp index 11700be2c374..30da9ad68af8 100644 --- a/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp +++ b/libnd4j/include/ops/declarable/generic/tsne/edge_force.cpp @@ -25,45 +25,55 @@ #include namespace sd { -namespace ops { - - CUSTOM_OP_IMPL(barnes_edge_forces, 4, 1, false, 0, 1) { - auto rowP = INPUT_VARIABLE(0); - auto colP = INPUT_VARIABLE(1); - auto valP = INPUT_VARIABLE(2); - auto dataP = INPUT_VARIABLE(3); - auto N = INT_ARG(0); +namespace ops { - auto output = OUTPUT_NULLIFIED(0); +CUSTOM_OP_IMPL(barnes_edge_forces, 4, 1, false, 0, 1) { + auto rowP = INPUT_VARIABLE(0); + auto colP = INPUT_VARIABLE(1); + auto valP = INPUT_VARIABLE(2); + auto dataP = INPUT_VARIABLE(3); + auto N = INT_ARG(0); - REQUIRE_TRUE(rowP->isVector(), 0, "barnes_edge_force: row input must be a vector, but its rank is %i instead !", rowP->rankOf()); - REQUIRE_TRUE(colP->isVector(), 0, "barnes_edge_force: col input must be a vector, but its rank is %i instead !", colP->rankOf()); - REQUIRE_TRUE(dataP->dataType() == output->dataType() && dataP->dataType() == valP->dataType(), 0, "barnes_edge_force: data type of dataP, valP and output must be the same"); + auto output = OUTPUT_NULLIFIED(0); - helpers::barnes_edge_forces(rowP, colP, valP, N, output, *dataP); + REQUIRE_TRUE(rowP->isVector(), 0, + "barnes_edge_force: row input must be a vector, but its rank is " + "%i instead !", + rowP->rankOf()); + REQUIRE_TRUE(colP->isVector(), 0, + "barnes_edge_force: col input must be a vector, but its rank is " + "%i instead !", + colP->rankOf()); + REQUIRE_TRUE(dataP->dataType() == output->dataType() && + dataP->dataType() == valP->dataType(), + 0, + "barnes_edge_force: data type of dataP, valP and output must be " + "the same"); - return Status::OK(); - } - - DECLARE_TYPES(barnes_edge_forces) { - getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setAllowedInputTypes(2, {ALL_FLOATS}) - ->setAllowedInputTypes(3, {ALL_FLOATS}) - ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setSameMode(false); - } - - DECLARE_SHAPE_FN(barnes_edge_forces) { - Nd4jLong* bufShape; - Nd4jLong* outShapeInfo; - outShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShape->at(3), inputShape->at(3), false, block.workspace()); - return SHAPELIST(CONSTANT(outShapeInfo)); - } + helpers::barnes_edge_forces(rowP, colP, valP, N, output, *dataP); + return Status::OK(); +} +DECLARE_TYPES(barnes_edge_forces) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedInputTypes(2, {ALL_FLOATS}) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}) + ->setSameMode(false); } + +DECLARE_SHAPE_FN(barnes_edge_forces) { + Nd4jLong* bufShape; + Nd4jLong* outShapeInfo; + outShapeInfo = ShapeBuilders::copyShapeInfoAndType( + inputShape->at(3), inputShape->at(3), false, block.workspace()); + return SHAPELIST(CONSTANT(outShapeInfo)); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tsne/gains.cpp b/libnd4j/include/ops/declarable/generic/tsne/gains.cpp index 4fb943483b79..87cd08bd398e 100644 --- a/libnd4j/include/ops/declarable/generic/tsne/gains.cpp +++ b/libnd4j/include/ops/declarable/generic/tsne/gains.cpp @@ -25,25 +25,23 @@ #include namespace sd { -namespace ops { - - OP_IMPL(barnes_gains, 3, 1, true) { - auto input = INPUT_VARIABLE(0); - auto gradX = INPUT_VARIABLE(1); - auto epsilon = INPUT_VARIABLE(2); - - auto output = OUTPUT_VARIABLE(0); - - helpers::barnes_gains(input, gradX, epsilon, output); - return Status::OK(); - } - - DECLARE_TYPES(barnes_gains) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setSameMode(true); - } +namespace ops { + +OP_IMPL(barnes_gains, 3, 1, true) { + auto input = INPUT_VARIABLE(0); + auto gradX = INPUT_VARIABLE(1); + auto epsilon = INPUT_VARIABLE(2); + + auto output = OUTPUT_VARIABLE(0); + + helpers::barnes_gains(input, gradX, epsilon, output); + return Status::OK(); } + +DECLARE_TYPES(barnes_gains) { + getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)->setSameMode(true); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp b/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp index 3f02bdd79c2a..2b26c6dbbf90 100644 --- a/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp +++ b/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp @@ -14,9 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author George A. Shulinok , created on 4/18/2019. - // +// +// @author George A. Shulinok , created on 4/18/2019. +// #include #if NOT_EXCLUDED(OP_barnes_symmetrized) @@ -25,69 +25,80 @@ #include namespace sd { - namespace ops { - NDArray* rowCountsPtr = nullptr; +namespace ops { +NDArray* rowCountsPtr = nullptr; - CUSTOM_OP_IMPL(barnes_symmetrized, 3, 3, false, 0, -1) { - auto rowP = INPUT_VARIABLE(0); - auto colP = INPUT_VARIABLE(1); - auto valP = INPUT_VARIABLE(2); - auto N = rowP->lengthOf() - 1; - auto outputRows = OUTPUT_VARIABLE(0); - auto outputCols = OUTPUT_VARIABLE(1); - auto outputVals = OUTPUT_VARIABLE(2); +CUSTOM_OP_IMPL(barnes_symmetrized, 3, 3, false, 0, -1) { + auto rowP = INPUT_VARIABLE(0); + auto colP = INPUT_VARIABLE(1); + auto valP = INPUT_VARIABLE(2); + auto N = rowP->lengthOf() - 1; + auto outputRows = OUTPUT_VARIABLE(0); + auto outputCols = OUTPUT_VARIABLE(1); + auto outputVals = OUTPUT_VARIABLE(2); - if (block.numI() > 0) - N = INT_ARG(0); + if (block.numI() > 0) N = INT_ARG(0); - if (rowCountsPtr) { - helpers::barnes_symmetrize(rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCountsPtr); - delete rowCountsPtr; - return Status::OK(); - } - return Status::THROW("barnes_symmetrized: Cannot loop due wrong input data."); - } - - DECLARE_TYPES(barnes_symmetrized) { - getOpDescriptor() - ->setAllowedInputTypes(0, { DataType::INT32 }) - ->setAllowedInputTypes(1, { DataType::INT32 }) - ->setAllowedInputTypes(2, { ALL_INTS, ALL_FLOATS }) - ->setAllowedOutputTypes(1, { DataType::INT32 }) - ->setAllowedOutputTypes(1, { DataType::INT32 }) - ->setAllowedOutputTypes(2, { ALL_INTS, ALL_FLOATS }) - ->setSameMode(false); - } + if (rowCountsPtr) { + helpers::barnes_symmetrize(rowP, colP, valP, N, outputRows, outputCols, + outputVals, rowCountsPtr); + delete rowCountsPtr; + return Status::OK(); + } + return Status::THROW("barnes_symmetrized: Cannot loop due wrong input data."); +} - DECLARE_SHAPE_FN(barnes_symmetrized) { - auto valPShapeInfo = inputShape->at(2); - Nd4jLong* outShapeInfo; - auto rowP = INPUT_VARIABLE(0); - auto colP = INPUT_VARIABLE(1); - auto N = rowP->lengthOf() - 1; - if (block.numI() > 0) - N = INT_ARG(0); +DECLARE_TYPES(barnes_symmetrized) { + getOpDescriptor() + ->setAllowedInputTypes(0, {DataType::INT32}) + ->setAllowedInputTypes(1, {DataType::INT32}) + ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes(1, {DataType::INT32}) + ->setAllowedOutputTypes(1, {DataType::INT32}) + ->setAllowedOutputTypes(2, {ALL_INTS, ALL_FLOATS}) + ->setSameMode(false); +} - auto dataType = rowP->dataType(); //ArrayOptions::dataType(inputShape->at(0)); - NDArray* rowCounts = NDArrayFactory::create_('c', { N }, block.launchContext()); //rowP->dup(); - //srowCounts->assign(0); - Nd4jLong len = helpers::barnes_row_count(rowP, colP, N, *rowCounts); - rowCounts->syncToHost(); - // rowCounts->printBuffer("Row Counts"); - if (len <= 0) throw std::runtime_error("barnes_symmetrized: Cannot allocate shape due non-positive len."); - rowCountsPtr = rowCounts; - //ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); -// outShapeInfo[1] = 1; -// outShapeInfo[2] = len; - // ShapeUtils::updateStridesAndType(outShapeInfo, ArrayOptions::dataType(valPShapeInfo), 'c'); - //outShapeInfo = ShapeBuilders::createVectorShapeInfo(ArrayOptions::dataType(valPShapeInfo), len, block.workspace()); - outShapeInfo = sd::ShapeBuilders::createShapeInfo(ArrayOptions::dataType(valPShapeInfo), 'c', { 1, len }, block.workspace()); - auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, len }, block.workspace()); - auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo(dataType, 'c', { 1, N + 1 }, block.workspace()); - return SHAPELIST(CONSTANT(outRowsShapeInfo), CONSTANT(outColsShapeInfo), CONSTANT(outShapeInfo)); - } +DECLARE_SHAPE_FN(barnes_symmetrized) { + auto valPShapeInfo = inputShape->at(2); + Nd4jLong* outShapeInfo; + auto rowP = INPUT_VARIABLE(0); + auto colP = INPUT_VARIABLE(1); + auto N = rowP->lengthOf() - 1; + if (block.numI() > 0) N = INT_ARG(0); - } + auto dataType = + rowP->dataType(); // ArrayOptions::dataType(inputShape->at(0)); + NDArray* rowCounts = NDArrayFactory::create_( + 'c', {N}, block.launchContext()); // rowP->dup(); + // srowCounts->assign(0); + Nd4jLong len = helpers::barnes_row_count(rowP, colP, N, *rowCounts); + rowCounts->syncToHost(); + // rowCounts->printBuffer("Row Counts"); + if (len <= 0) + throw std::runtime_error( + "barnes_symmetrized: Cannot allocate shape due non-positive len."); + rowCountsPtr = rowCounts; + // ALLOCATE(outShapeInfo, block.workspace(), shape::shapeInfoLength(2), + // Nd4jLong); + // outShapeInfo[1] = 1; + // outShapeInfo[2] = len; + // ShapeUtils::updateStridesAndType(outShapeInfo, + // ArrayOptions::dataType(valPShapeInfo), 'c'); + // outShapeInfo = + // ShapeBuilders::createVectorShapeInfo(ArrayOptions::dataType(valPShapeInfo), + // len, block.workspace()); + outShapeInfo = sd::ShapeBuilders::createShapeInfo( + ArrayOptions::dataType(valPShapeInfo), 'c', {1, len}, block.workspace()); + auto outColsShapeInfo = sd::ShapeBuilders::createShapeInfo( + dataType, 'c', {1, len}, block.workspace()); + auto outRowsShapeInfo = sd::ShapeBuilders::createShapeInfo( + dataType, 'c', {1, N + 1}, block.workspace()); + return SHAPELIST(CONSTANT(outRowsShapeInfo), CONSTANT(outColsShapeInfo), + CONSTANT(outShapeInfo)); } +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp index acbb40774d66..f6fe9f0afac3 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp @@ -14,68 +14,81 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(ada_delta_updater, 3, 3, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initStateMsg = INPUT_VARIABLE(1); - const auto initStateMsdx = INPUT_VARIABLE(2); - - auto update = OUTPUT_VARIABLE(0); - auto stateMsg = OUTPUT_VARIABLE(1); - auto stateMsdx = OUTPUT_VARIABLE(2); - - if (gradient->isEmpty() || initStateMsg->isEmpty() || initStateMsdx->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initStateMsg), 0, "ADA_DELTA UPDATER OP: input state Msg must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateMsg->shapeInfo()).c_str()); - REQUIRE_TRUE(gradient->isSameShape(initStateMsdx), 0, "ADA_DELTA UPDATER OP: input state Msdx must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateMsdx->shapeInfo()).c_str()); - - bool bParamsSupply = 5 == block.width() || 2 == block.numT(); - - REQUIRE_TRUE(bParamsSupply, 0, "ADA_DELTA UPDATER OP: Rho and epsilon were not provided!"); - - double dRho, dEpsilon; - - if (block.width() > 3) { - const auto rho = INPUT_VARIABLE(3); - const auto epsilon = INPUT_VARIABLE(4); - - REQUIRE_TRUE(rho->isScalar(), 0, "ADA_DELTA UPDATER OP: Rho has to be a scalar, but instead got rank %i!", rho->rankOf()); - REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_DELTA UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); - - dRho = rho->e(0); - dEpsilon = epsilon->e(0); - } - else { - dRho = T_ARG(0); - dEpsilon = T_ARG(1); - } - - helpers::updaterAdaDelta(block.launchContext(), *gradient, *initStateMsg, *initStateMsdx, *update, *stateMsg, *stateMsdx, dRho, dEpsilon); - return Status::OK(); - } - - DECLARE_TYPES(ada_delta_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } +namespace ops { + +CONFIGURABLE_OP_IMPL(ada_delta_updater, 3, 3, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initStateMsg = INPUT_VARIABLE(1); + const auto initStateMsdx = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateMsg = OUTPUT_VARIABLE(1); + auto stateMsdx = OUTPUT_VARIABLE(2); + + if (gradient->isEmpty() || initStateMsg->isEmpty() || + initStateMsdx->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initStateMsg), 0, + "ADA_DELTA UPDATER OP: input state Msg must have the same shape " + "as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateMsg->shapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateMsdx), 0, + "ADA_DELTA UPDATER OP: input state Msdx must have the same " + "shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateMsdx->shapeInfo()).c_str()); + + bool bParamsSupply = 5 == block.width() || 2 == block.numT(); + + REQUIRE_TRUE(bParamsSupply, 0, + "ADA_DELTA UPDATER OP: Rho and epsilon were not provided!"); + + double dRho, dEpsilon; + + if (block.width() > 3) { + const auto rho = INPUT_VARIABLE(3); + const auto epsilon = INPUT_VARIABLE(4); + + REQUIRE_TRUE(rho->isScalar(), 0, + "ADA_DELTA UPDATER OP: Rho has to be a scalar, but instead " + "got rank %i!", + rho->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, + "ADA_DELTA UPDATER OP: Epsilon has to be a scalar, but " + "instead got rank %i!", + epsilon->rankOf()); + + dRho = rho->e(0); + dEpsilon = epsilon->e(0); + } else { + dRho = T_ARG(0); + dEpsilon = T_ARG(1); + } + + helpers::updaterAdaDelta(block.launchContext(), *gradient, *initStateMsg, + *initStateMsdx, *update, *stateMsg, *stateMsdx, dRho, + dEpsilon); + return Status::OK(); +} - } +DECLARE_TYPES(ada_delta_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp index 05b084b351a5..569381cc5b73 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp @@ -14,64 +14,71 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(ada_grad_updater, 2, 2, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initState = INPUT_VARIABLE(1); - - auto update = OUTPUT_VARIABLE(0); - auto stateH = OUTPUT_VARIABLE(1); - - if (gradient->isEmpty() || initState->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initState), 0, "ADA_GRAD UPDATER OP: input state must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initState->shapeInfo()).c_str()); - - - bool bParamsSupply = 4 == block.width() || 2 == block.numT(); - - REQUIRE_TRUE(bParamsSupply, 0, "ADA_GRAD UPDATER OP: learning rate and epsilon were not provided!"); - - double dLr, dEpsilon; - - if (block.width() > 2) { - const auto lr = INPUT_VARIABLE(2); - const auto epsilon = INPUT_VARIABLE(3); - - REQUIRE_TRUE(lr->isScalar(), 0, "ADA_GRAD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); - REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_GRAD UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); - - dLr = lr->e(0); - dEpsilon = epsilon->e(0); - } - else { - dLr = T_ARG(0); - dEpsilon = T_ARG(1); - } - - helpers::updaterAdaGrad(block.launchContext(), *gradient, *initState, *update, *stateH, dLr, dEpsilon); - return Status::OK(); - } - - DECLARE_TYPES(ada_grad_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } +namespace ops { + +CONFIGURABLE_OP_IMPL(ada_grad_updater, 2, 2, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initState = INPUT_VARIABLE(1); + + auto update = OUTPUT_VARIABLE(0); + auto stateH = OUTPUT_VARIABLE(1); + + if (gradient->isEmpty() || initState->isEmpty()) return Status::OK(); + + REQUIRE_TRUE( + gradient->isSameShape(initState), 0, + "ADA_GRAD UPDATER OP: input state must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initState->shapeInfo()).c_str()); + + bool bParamsSupply = 4 == block.width() || 2 == block.numT(); + + REQUIRE_TRUE( + bParamsSupply, 0, + "ADA_GRAD UPDATER OP: learning rate and epsilon were not provided!"); + + double dLr, dEpsilon; + + if (block.width() > 2) { + const auto lr = INPUT_VARIABLE(2); + const auto epsilon = INPUT_VARIABLE(3); + + REQUIRE_TRUE(lr->isScalar(), 0, + "ADA_GRAD UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, + "ADA_GRAD UPDATER OP: Epsilon has to be a scalar, but instead " + "got rank %i!", + epsilon->rankOf()); + + dLr = lr->e(0); + dEpsilon = epsilon->e(0); + } else { + dLr = T_ARG(0); + dEpsilon = T_ARG(1); + } + + helpers::updaterAdaGrad(block.launchContext(), *gradient, *initState, *update, + *stateH, dLr, dEpsilon); + return Status::OK(); +} - } +DECLARE_TYPES(ada_grad_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp index 9964f65ed955..8a9d98289112 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp @@ -14,80 +14,98 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(ada_max_updater, 3, 3, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initStateU = INPUT_VARIABLE(1); - const auto initStateM = INPUT_VARIABLE(2); - - auto update = OUTPUT_VARIABLE(0); - auto stateU = OUTPUT_VARIABLE(1); - auto stateM = OUTPUT_VARIABLE(2); - - // todo maybe we need an error like on Java side - if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADA_MAX UPDATER OP: input state V must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateU->shapeInfo()).c_str()); - REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADA_MAX UPDATER OP: input state M must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); - - - bool bParamsSupply = 7 == block.width() || 4 == block.numT(); - - int iteration = block.numI() > 0 ? INT_ARG(0) : 0; - - REQUIRE_TRUE(bParamsSupply, 0, "ADA_MAX UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); - - double dLr, dBeta1, dBeta2, dEpsilon; - - if (block.width() > 3) { - const auto lr = INPUT_VARIABLE(3); - const auto beta1 = INPUT_VARIABLE(4); - const auto beta2 = INPUT_VARIABLE(5); - const auto epsilon = INPUT_VARIABLE(6); - - REQUIRE_TRUE(lr->isScalar(), 0, "ADA_MAX UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); - REQUIRE_TRUE(beta1->isScalar(), 0, "ADA_MAX UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); - REQUIRE_TRUE(beta2->isScalar(), 0, "ADA_MAX UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); - REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_MAX UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); - - dLr = lr->e(0); - dBeta1 = beta1->e(0); - dBeta2 = beta2->e(0); - dEpsilon = epsilon->e(0); - } - else { - dLr = T_ARG(0); - dBeta1 = T_ARG(1); - dBeta2 = T_ARG(2); - dEpsilon = T_ARG(3); - } - - helpers::updaterAdaMax(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration); - return Status::OK(); - } - - DECLARE_TYPES(ada_max_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } +namespace ops { + +CONFIGURABLE_OP_IMPL(ada_max_updater, 3, 3, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initStateU = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateU = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE( + gradient->isSameShape(initStateU), 0, + "ADA_MAX UPDATER OP: input state V must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateU->shapeInfo()).c_str()); + REQUIRE_TRUE( + gradient->isSameShape(initStateM), 0, + "ADA_MAX UPDATER OP: input state M must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); + + bool bParamsSupply = 7 == block.width() || 4 == block.numT(); + + int iteration = block.numI() > 0 ? INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, + "ADA_MAX UPDATER OP: learning rate, beta 1, beta 2 and epsilon " + "were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 3) { + const auto lr = INPUT_VARIABLE(3); + const auto beta1 = INPUT_VARIABLE(4); + const auto beta2 = INPUT_VARIABLE(5); + const auto epsilon = INPUT_VARIABLE(6); + + REQUIRE_TRUE(lr->isScalar(), 0, + "ADA_MAX UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); + REQUIRE_TRUE(beta1->isScalar(), 0, + "ADA_MAX UPDATER OP: beta 1 has to be a scalar, but instead " + "got rank %i!", + beta1->rankOf()); + REQUIRE_TRUE(beta2->isScalar(), 0, + "ADA_MAX UPDATER OP: beta 2 has to be a scalar, but instead " + "got rank %i!", + beta2->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, + "ADA_MAX UPDATER OP: Epsilon has to be a scalar, but instead " + "got rank %i!", + epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterAdaMax(block.launchContext(), *gradient, *initStateU, + *initStateM, *update, *stateU, *stateM, dLr, dBeta1, + dBeta2, dEpsilon, iteration); + return Status::OK(); +} - } +DECLARE_TYPES(ada_max_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp index d508a8bd14aa..354d0d48a6f4 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp @@ -14,79 +14,98 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(adam_updater, 3, 3, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initStateU = INPUT_VARIABLE(1); - const auto initStateM = INPUT_VARIABLE(2); - - auto update = OUTPUT_VARIABLE(0); - auto stateU = OUTPUT_VARIABLE(1); - auto stateM = OUTPUT_VARIABLE(2); - - // todo maybe we need an error like on Java side - if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADAM UPDATER OP: input state V must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateU->shapeInfo()).c_str()); - REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADAM UPDATER OP: input state M must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); - - bool bParamsSupply = 7 == block.width() || 4 == block.numT(); - - auto iteration = block.numI() > 0 ? INT_ARG(0) : 0; - - REQUIRE_TRUE(bParamsSupply, 0, "ADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); - - double dLr, dBeta1, dBeta2, dEpsilon; - - if (block.width() > 3) { - const auto lr = INPUT_VARIABLE(3); - const auto beta1 = INPUT_VARIABLE(4); - const auto beta2 = INPUT_VARIABLE(5); - const auto epsilon = INPUT_VARIABLE(6); - - REQUIRE_TRUE(lr->isScalar(), 0, "ADAM UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); - REQUIRE_TRUE(beta1->isScalar(), 0, "ADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); - REQUIRE_TRUE(beta2->isScalar(), 0, "ADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); - REQUIRE_TRUE(epsilon->isScalar(), 0, "ADAM UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); - - dLr = lr->e(0); - dBeta1 = beta1->e(0); - dBeta2 = beta2->e(0); - dEpsilon = epsilon->e(0); - } - else { - dLr = T_ARG(0); - dBeta1 = T_ARG(1); - dBeta2 = T_ARG(2); - dEpsilon = T_ARG(3); - } - - helpers::updaterAdam(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration); - return Status::OK(); - } - - DECLARE_TYPES(adam_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } - - } +namespace ops { + +CONFIGURABLE_OP_IMPL(adam_updater, 3, 3, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initStateU = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateU = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE( + gradient->isSameShape(initStateU), 0, + "ADAM UPDATER OP: input state V must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateU->shapeInfo()).c_str()); + REQUIRE_TRUE( + gradient->isSameShape(initStateM), 0, + "ADAM UPDATER OP: input state M must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); + + bool bParamsSupply = 7 == block.width() || 4 == block.numT(); + + auto iteration = block.numI() > 0 ? INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, + "ADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon " + "were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 3) { + const auto lr = INPUT_VARIABLE(3); + const auto beta1 = INPUT_VARIABLE(4); + const auto beta2 = INPUT_VARIABLE(5); + const auto epsilon = INPUT_VARIABLE(6); + + REQUIRE_TRUE(lr->isScalar(), 0, + "ADAM UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); + REQUIRE_TRUE( + beta1->isScalar(), 0, + "ADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", + beta1->rankOf()); + REQUIRE_TRUE( + beta2->isScalar(), 0, + "ADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", + beta2->rankOf()); + REQUIRE_TRUE( + epsilon->isScalar(), 0, + "ADAM UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", + epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterAdam(block.launchContext(), *gradient, *initStateU, + *initStateM, *update, *stateU, *stateM, dLr, dBeta1, + dBeta2, dEpsilon, iteration); + return Status::OK(); } + +DECLARE_TYPES(adam_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); +} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp index 0e5e6052d5fa..5e71fb3710fe 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp @@ -14,85 +14,107 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(ams_grad_updater, 4, 4, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initStateV = INPUT_VARIABLE(1); - const auto initStateM = INPUT_VARIABLE(2); - const auto initStateH = INPUT_VARIABLE(3); - - auto update = OUTPUT_VARIABLE(0); - auto stateV = OUTPUT_VARIABLE(1); - auto stateM = OUTPUT_VARIABLE(2); - auto stateH = OUTPUT_VARIABLE(3); - - // todo maybe we need an error like on Java side - if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty() || initStateH->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initStateV), 0, "AMSGRAD UPDATER OP: input state Msg must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateV->shapeInfo()).c_str()); - REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "AMSGRAD UPDATER OP: input state Msdx must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); - REQUIRE_TRUE(gradient->isSameShape(initStateH), 0, "AMSGRAD UPDATER OP: input state Msdx must have the same shape as gradient!," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateH->shapeInfo()).c_str()); - - bool bParamsSupply = 8 == block.width() || 4 == block.numT(); - - auto iteration = block.numI() > 0 ? INT_ARG(0) : 0; - - REQUIRE_TRUE(bParamsSupply, 0, "AMSGRAD UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); - - double dLr, dBeta1, dBeta2, dEpsilon; - - if (block.width() > 4) { - const auto lr = INPUT_VARIABLE(4); - const auto beta1 = INPUT_VARIABLE(5); - const auto beta2 = INPUT_VARIABLE(6); - const auto epsilon = INPUT_VARIABLE(7); - - REQUIRE_TRUE(lr->isScalar(), 0, "AMSGRAD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); - REQUIRE_TRUE(beta1->isScalar(), 0, "AMSGRAD UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); - REQUIRE_TRUE(beta2->isScalar(), 0, "AMSGRAD UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); - REQUIRE_TRUE(epsilon->isScalar(), 0, "AMSGRAD UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); - - dLr = lr->e(0); - dBeta1 = beta1->e(0); - dBeta2 = beta2->e(0); - dEpsilon = epsilon->e(0); - } - else { - dLr = T_ARG(0); - dBeta1 = T_ARG(1); - dBeta2 = T_ARG(2); - dEpsilon = T_ARG(3); - } - - helpers::updaterAmsGrad(block.launchContext(), *gradient, *initStateV, *initStateM, *initStateH, - *update, *stateV, *stateM, *stateH, dLr, dBeta1, dBeta2, dEpsilon, iteration); - return Status::OK(); - } - - DECLARE_TYPES(ams_grad_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } - - } +namespace ops { + +CONFIGURABLE_OP_IMPL(ams_grad_updater, 4, 4, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initStateV = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + const auto initStateH = INPUT_VARIABLE(3); + + auto update = OUTPUT_VARIABLE(0); + auto stateV = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + auto stateH = OUTPUT_VARIABLE(3); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty() || + initStateH->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initStateV), 0, + "AMSGRAD UPDATER OP: input state Msg must have the same shape " + "as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateV->shapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, + "AMSGRAD UPDATER OP: input state Msdx must have the same shape " + "as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateH), 0, + "AMSGRAD UPDATER OP: input state Msdx must have the same shape " + "as gradient!," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateH->shapeInfo()).c_str()); + + bool bParamsSupply = 8 == block.width() || 4 == block.numT(); + + auto iteration = block.numI() > 0 ? INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, + "AMSGRAD UPDATER OP: learning rate, beta 1, beta 2 and epsilon " + "were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 4) { + const auto lr = INPUT_VARIABLE(4); + const auto beta1 = INPUT_VARIABLE(5); + const auto beta2 = INPUT_VARIABLE(6); + const auto epsilon = INPUT_VARIABLE(7); + + REQUIRE_TRUE(lr->isScalar(), 0, + "AMSGRAD UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); + REQUIRE_TRUE(beta1->isScalar(), 0, + "AMSGRAD UPDATER OP: beta 1 has to be a scalar, but instead " + "got rank %i!", + beta1->rankOf()); + REQUIRE_TRUE(beta2->isScalar(), 0, + "AMSGRAD UPDATER OP: beta 2 has to be a scalar, but instead " + "got rank %i!", + beta2->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, + "AMSGRAD UPDATER OP: Epsilon has to be a scalar, but instead " + "got rank %i!", + epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterAmsGrad(block.launchContext(), *gradient, *initStateV, + *initStateM, *initStateH, *update, *stateV, *stateM, + *stateH, dLr, dBeta1, dBeta2, dEpsilon, iteration); + return Status::OK(); } + +DECLARE_TYPES(ams_grad_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); +} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp index 7fa65bc723df..df3e95d6d777 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp @@ -14,79 +14,98 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(nadam_updater, 3, 3, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initStateV = INPUT_VARIABLE(1); - const auto initStateM = INPUT_VARIABLE(2); - - auto update = OUTPUT_VARIABLE(0); - auto stateV = OUTPUT_VARIABLE(1); - auto stateM = OUTPUT_VARIABLE(2); - - // todo maybe we need an error like on Java side - if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "NADAM UPDATER OP: input state M must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); - REQUIRE_TRUE(gradient->isSameShape(initStateV), 0, "NADAM UPDATER OP: input state V must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initStateV->shapeInfo()).c_str()); - - bool bParamsSupply = 7 == block.width() || 4 == block.numT(); - - auto nIteration = block.numI() > 0 ? INT_ARG(0) : 0; - - REQUIRE_TRUE(bParamsSupply, 0, "NADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); - - double dLr, dBeta1, dBeta2, dEpsilon; - - if (block.width() > 3) { - const auto lr = INPUT_VARIABLE(3); - const auto beta1 = INPUT_VARIABLE(4); - const auto beta2 = INPUT_VARIABLE(5); - const auto epsilon = INPUT_VARIABLE(6); - - REQUIRE_TRUE(lr->isScalar(), 0, "NADAM UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); - REQUIRE_TRUE(beta1->isScalar(), 0, "NADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); - REQUIRE_TRUE(beta2->isScalar(), 0, "NADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); - REQUIRE_TRUE(epsilon->isScalar(), 0, "NADAM UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); - - dLr = lr->e(0); - dBeta1 = beta1->e(0); - dBeta2 = beta2->e(0); - dEpsilon = epsilon->e(0); - } - else { - dLr = T_ARG(0); - dBeta1 = T_ARG(1); - dBeta2 = T_ARG(2); - dEpsilon = T_ARG(3); - } - - helpers::updaterNadam(block.launchContext(), *gradient, *initStateV, *initStateM, *update, *stateV, *stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration); - return Status::OK(); - } - - DECLARE_TYPES(nadam_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } - - } +namespace ops { + +CONFIGURABLE_OP_IMPL(nadam_updater, 3, 3, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initStateV = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateV = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE( + gradient->isSameShape(initStateM), 0, + "NADAM UPDATER OP: input state M must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str()); + REQUIRE_TRUE( + gradient->isSameShape(initStateV), 0, + "NADAM UPDATER OP: input state V must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateV->shapeInfo()).c_str()); + + bool bParamsSupply = 7 == block.width() || 4 == block.numT(); + + auto nIteration = block.numI() > 0 ? INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, + "NADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon " + "were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 3) { + const auto lr = INPUT_VARIABLE(3); + const auto beta1 = INPUT_VARIABLE(4); + const auto beta2 = INPUT_VARIABLE(5); + const auto epsilon = INPUT_VARIABLE(6); + + REQUIRE_TRUE(lr->isScalar(), 0, + "NADAM UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); + REQUIRE_TRUE( + beta1->isScalar(), 0, + "NADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", + beta1->rankOf()); + REQUIRE_TRUE( + beta2->isScalar(), 0, + "NADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", + beta2->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, + "NADAM UPDATER OP: Epsilon has to be a scalar, but instead " + "got rank %i!", + epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterNadam(block.launchContext(), *gradient, *initStateV, + *initStateM, *update, *stateV, *stateM, dLr, dBeta1, + dBeta2, dEpsilon, nIteration); + return Status::OK(); } + +DECLARE_TYPES(nadam_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); +} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp index 22b7f4042e5b..d685ecbdc904 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp @@ -14,62 +14,70 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(nesterovs_updater, 2, 2, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initState = INPUT_VARIABLE(1); - - auto update = OUTPUT_VARIABLE(0); - auto stateV = OUTPUT_VARIABLE(1); - - if (gradient->isEmpty() || initState->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initState), 0, "NESTEROVS UPDATER OP: input state Msg must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initState->shapeInfo()).c_str()); - - bool bParamsSupply = 4 == block.width() || 2 == block.numT(); - - REQUIRE_TRUE(bParamsSupply, 0, "NESTEROVS UPDATER OP: learning rate and momentum were not provided!"); - - double dLr, dMomentum; - - if (block.width() > 2) { - const auto lr = INPUT_VARIABLE(2); - const auto momentum = INPUT_VARIABLE(3); - - REQUIRE_TRUE(lr->isScalar(), 0, "NESTEROVS UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); - REQUIRE_TRUE(momentum->isScalar(), 0, "NESTEROVS UPDATER OP: Momentum has to be a scalar, but instead got rank %i!", momentum->rankOf()); - - dLr = lr->e(0); - dMomentum = momentum->e(0); - } - else { - dLr = T_ARG(0); - dMomentum = T_ARG(1); - } - helpers::updaterNesterovs(block.launchContext(), *gradient, *initState, *update, *stateV, dLr, dMomentum); - return Status::OK(); - } - - DECLARE_TYPES(nesterovs_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } +namespace ops { + +CONFIGURABLE_OP_IMPL(nesterovs_updater, 2, 2, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initState = INPUT_VARIABLE(1); + + auto update = OUTPUT_VARIABLE(0); + auto stateV = OUTPUT_VARIABLE(1); + + if (gradient->isEmpty() || initState->isEmpty()) return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initState), 0, + "NESTEROVS UPDATER OP: input state Msg must have the same shape " + "as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initState->shapeInfo()).c_str()); + + bool bParamsSupply = 4 == block.width() || 2 == block.numT(); + + REQUIRE_TRUE( + bParamsSupply, 0, + "NESTEROVS UPDATER OP: learning rate and momentum were not provided!"); + + double dLr, dMomentum; + + if (block.width() > 2) { + const auto lr = INPUT_VARIABLE(2); + const auto momentum = INPUT_VARIABLE(3); + + REQUIRE_TRUE(lr->isScalar(), 0, + "NESTEROVS UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); + REQUIRE_TRUE(momentum->isScalar(), 0, + "NESTEROVS UPDATER OP: Momentum has to be a scalar, but " + "instead got rank %i!", + momentum->rankOf()); + + dLr = lr->e(0); + dMomentum = momentum->e(0); + } else { + dLr = T_ARG(0); + dMomentum = T_ARG(1); + } + helpers::updaterNesterovs(block.launchContext(), *gradient, *initState, + *update, *stateV, dLr, dMomentum); + return Status::OK(); +} - } +DECLARE_TYPES(nesterovs_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp index 3aed0e197c22..84030f5df0f6 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp @@ -14,67 +14,78 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(rms_prop_updater, 2, 2, true, 0, 0) { - - const auto gradient = INPUT_VARIABLE(0); - const auto initState = INPUT_VARIABLE(1); - - auto update = OUTPUT_VARIABLE(0); - auto stateG = OUTPUT_VARIABLE(1); - - if (gradient->isEmpty() || initState->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(gradient->isSameShape(initState), 0, "RMS_PROB UPDATER OP: input state must have the same shape as gradient," - " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), - ShapeUtils::shapeAsString(initState->shapeInfo()).c_str()); - - bool bParamsSupply = 5 == block.width() || 3 == block.numT(); - - REQUIRE_TRUE(bParamsSupply, 0, "RSM_PROB UPDATER OP: learning rate, rsm decay and epsilon were not provided!"); - - double dLr, dRmsDecay, dEpsilon; - - if (block.width() > 2) { - const auto lr = INPUT_VARIABLE(2); - const auto rmsDecay = INPUT_VARIABLE(3); - const auto epsilon = INPUT_VARIABLE(4); - - REQUIRE_TRUE(lr->isScalar(), 0, "RSM_PROB UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); - REQUIRE_TRUE(rmsDecay->isScalar(), 0, "RSM_PROB UPDATER OP: Rms decay has to be a scalar, but instead got rank %i!", rmsDecay->rankOf()); - REQUIRE_TRUE(epsilon->isScalar(), 0, "RSM_PROB UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); - - dLr = lr->e(0); - dRmsDecay = rmsDecay->e(0); - dEpsilon = epsilon->e(0); - } - else { - dLr = T_ARG(0); - dRmsDecay = T_ARG(1); - dEpsilon = T_ARG(2); - } - - helpers::updaterRmsProp(block.launchContext(), *gradient, *initState, *update, *stateG, dLr, dRmsDecay, dEpsilon); - return Status::OK(); - } - - DECLARE_TYPES(rms_prop_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } +namespace ops { + +CONFIGURABLE_OP_IMPL(rms_prop_updater, 2, 2, true, 0, 0) { + const auto gradient = INPUT_VARIABLE(0); + const auto initState = INPUT_VARIABLE(1); + + auto update = OUTPUT_VARIABLE(0); + auto stateG = OUTPUT_VARIABLE(1); + + if (gradient->isEmpty() || initState->isEmpty()) return Status::OK(); + + REQUIRE_TRUE( + gradient->isSameShape(initState), 0, + "RMS_PROB UPDATER OP: input state must have the same shape as gradient," + " expected shape %s, but got %s!", + ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(), + ShapeUtils::shapeAsString(initState->shapeInfo()).c_str()); + + bool bParamsSupply = 5 == block.width() || 3 == block.numT(); + + REQUIRE_TRUE(bParamsSupply, 0, + "RSM_PROB UPDATER OP: learning rate, rsm decay and epsilon were " + "not provided!"); + + double dLr, dRmsDecay, dEpsilon; + + if (block.width() > 2) { + const auto lr = INPUT_VARIABLE(2); + const auto rmsDecay = INPUT_VARIABLE(3); + const auto epsilon = INPUT_VARIABLE(4); + + REQUIRE_TRUE(lr->isScalar(), 0, + "RSM_PROB UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); + REQUIRE_TRUE(rmsDecay->isScalar(), 0, + "RSM_PROB UPDATER OP: Rms decay has to be a scalar, but " + "instead got rank %i!", + rmsDecay->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, + "RSM_PROB UPDATER OP: Epsilon has to be a scalar, but instead " + "got rank %i!", + epsilon->rankOf()); + + dLr = lr->e(0); + dRmsDecay = rmsDecay->e(0); + dEpsilon = epsilon->e(0); + } else { + dLr = T_ARG(0); + dRmsDecay = T_ARG(1); + dEpsilon = T_ARG(2); + } + + helpers::updaterRmsProp(block.launchContext(), *gradient, *initState, *update, + *stateG, dLr, dRmsDecay, dEpsilon); + return Status::OK(); +} - } +DECLARE_TYPES(rms_prop_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp index 829db3dcee9c..07abb72d1833 100644 --- a/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp +++ b/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp @@ -14,48 +14,48 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - - CONFIGURABLE_OP_IMPL(sgd_updater, 1, 1, true, 0, 0) { +namespace ops { - const auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); +CONFIGURABLE_OP_IMPL(sgd_updater, 1, 1, true, 0, 0) { + const auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - if (input->isEmpty()) - return Status::OK(); + if (input->isEmpty()) return Status::OK(); - bool bLearningRate = 2 == block.width() || 1 == block.numT(); + bool bLearningRate = 2 == block.width() || 1 == block.numT(); - REQUIRE_TRUE(bLearningRate, 0, "SGD UPDATER OP: Learning rate was not provided!"); + REQUIRE_TRUE(bLearningRate, 0, + "SGD UPDATER OP: Learning rate was not provided!"); - if (block.width() > 1) { - const auto lr = INPUT_VARIABLE(1); - REQUIRE_TRUE(lr->isScalar(), 0, "SGD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + if (block.width() > 1) { + const auto lr = INPUT_VARIABLE(1); + REQUIRE_TRUE(lr->isScalar(), 0, + "SGD UPDATER OP: Learning rate has to be a scalar, but " + "instead got rank %i!", + lr->rankOf()); - input->applyScalarArr(scalar::Multiply, *lr, *output); - } - else { - input->applyScalar(scalar::Multiply, T_ARG(0), *output); - } + input->applyScalarArr(scalar::Multiply, *lr, *output); + } else { + input->applyScalar(scalar::Multiply, T_ARG(0), *output); + } - return Status::OK(); - } - - DECLARE_TYPES(sgd_updater) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } + return Status::OK(); +} - } +DECLARE_TYPES(sgd_updater) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/util/print_affinity.cpp b/libnd4j/include/ops/declarable/generic/util/print_affinity.cpp index 5518588e4659..d1d586ff6eba 100644 --- a/libnd4j/include/ops/declarable/generic/util/print_affinity.cpp +++ b/libnd4j/include/ops/declarable/generic/util/print_affinity.cpp @@ -25,28 +25,35 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(print_affinity, 1, 1, true, 0, 0) { - // TODO: make this op compatible with ArrayList etc - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - nd4j_printf(": Actuality: [HOST: %s; DEVICE: %s]; affinity: [%i]; Pointers: [HOST: %p; DEVICE: %p]; DataBuffer length: %lld\n", block.nodeId(), input->isActualOnHostSide() ? "true" : "false", input->isActualOnDeviceSide() ? "true" : "false", input->dataBuffer()->deviceId(), input->buffer(), input->specialBuffer(), input->dataBuffer()->getLenInBytes()); - - return Status::OK(); - } - - DECLARE_TYPES(print_affinity) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_STRINGS}) - ->setAllowedOutputTypes(0, sd::DataType::INT32); - } - - DECLARE_SHAPE_FN(print_affinity) { - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32)); - } - } +namespace ops { +CUSTOM_OP_IMPL(print_affinity, 1, 1, true, 0, 0) { + // TODO: make this op compatible with ArrayList etc + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + nd4j_printf( + ": Actuality: [HOST: %s; DEVICE: %s]; affinity: [%i]; Pointers: " + "[HOST: %p; DEVICE: %p]; DataBuffer length: %lld\n", + block.nodeId(), input->isActualOnHostSide() ? "true" : "false", + input->isActualOnDeviceSide() ? "true" : "false", + input->dataBuffer()->deviceId(), input->buffer(), input->specialBuffer(), + input->dataBuffer()->getLenInBytes()); + + return Status::OK(); } +DECLARE_TYPES(print_affinity) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_STRINGS}) + ->setAllowedOutputTypes(0, sd::DataType::INT32); +} + +DECLARE_SHAPE_FN(print_affinity) { + return SHAPELIST( + ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32)); +} +} // namespace ops +} // namespace sd + #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/util/print_variable.cpp b/libnd4j/include/ops/declarable/generic/util/print_variable.cpp index 9d3369627432..c645a970d292 100644 --- a/libnd4j/include/ops/declarable/generic/util/print_variable.cpp +++ b/libnd4j/include/ops/declarable/generic/util/print_variable.cpp @@ -24,54 +24,56 @@ #include namespace sd { - namespace ops { - CUSTOM_OP_IMPL(print_variable, 1, 1, true, 0, 0) { - // TODO: make this op compatible with ArrayList etc - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - std::string str; +namespace ops { +CUSTOM_OP_IMPL(print_variable, 1, 1, true, 0, 0) { + // TODO: make this op compatible with ArrayList etc + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + std::string str; - if (block.width() == 2) { - auto message = INPUT_VARIABLE(1); - REQUIRE_TRUE(message->isS(), 0, "print_variable: message variable must be a String"); + if (block.width() == 2) { + auto message = INPUT_VARIABLE(1); + REQUIRE_TRUE(message->isS(), 0, + "print_variable: message variable must be a String"); - str = message->e(0); - } + str = message->e(0); + } - bool printSpecial = false; - if (block.numB() > 0) - printSpecial = B_ARG(0); + bool printSpecial = false; + if (block.numB() > 0) printSpecial = B_ARG(0); - if (printSpecial && !sd::Environment::getInstance()->isCPU()) { - // only specific backends support special printout. for cpu-based backends it's the same as regular print + if (printSpecial && !sd::Environment::getInstance()->isCPU()) { + // only specific backends support special printout. for cpu-based backends + // it's the same as regular print - if (block.width() == 2) - helpers::print_special(*block.launchContext(), *input, str); - else - helpers::print_special(*block.launchContext(), *input); - } else { - // optionally add message to the print out - if (block.width() == 2) { - input->printIndexedBuffer(str.c_str()); - } else { - input->printIndexedBuffer(); - } - } + if (block.width() == 2) + helpers::print_special(*block.launchContext(), *input, str); + else + helpers::print_special(*block.launchContext(), *input); + } else { + // optionally add message to the print out + if (block.width() == 2) { + input->printIndexedBuffer(str.c_str()); + } else { + input->printIndexedBuffer(); + } + } - return Status::OK(); - } + return Status::OK(); +} - DECLARE_TYPES(print_variable) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_STRINGS}) - ->setAllowedOutputTypes(0, sd::DataType::INT32); - } +DECLARE_TYPES(print_variable) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_STRINGS}) + ->setAllowedOutputTypes(0, sd::DataType::INT32); +} - DECLARE_SHAPE_FN(print_variable) { - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32)); - } - } +DECLARE_SHAPE_FN(print_variable) { + return SHAPELIST( + ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT32)); } +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h b/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h index 3f0d86e19f11..858ac32c36b3 100644 --- a/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h +++ b/libnd4j/include/ops/declarable/headers/BarnesHutTsne.h @@ -24,78 +24,79 @@ #include namespace sd { - namespace ops { - /** - * This operation used as helper with BarnesHutTsne class - * to compute edge forces using barnes hut - * - * Expected input: - * 0: 1D row-vector (or with shape (1, m)) - * 1: 1D integer vector with slice nums - * 2: 1D float-point values vector with same shape as above - * 3: 2D float-point matrix with data to search - * - * Int args: - * 0: N - number of slices - * - * Output: - * 0: 2D matrix with the same shape and type as the 3th argument - */ - #if NOT_EXCLUDED(OP_barnes_edge_forces) - DECLARE_CUSTOM_OP(barnes_edge_forces, 4, 1, false, 0, 1); - #endif +namespace ops { +/** + * This operation used as helper with BarnesHutTsne class + * to compute edge forces using barnes hut + * + * Expected input: + * 0: 1D row-vector (or with shape (1, m)) + * 1: 1D integer vector with slice nums + * 2: 1D float-point values vector with same shape as above + * 3: 2D float-point matrix with data to search + * + * Int args: + * 0: N - number of slices + * + * Output: + * 0: 2D matrix with the same shape and type as the 3th argument + */ +#if NOT_EXCLUDED(OP_barnes_edge_forces) +DECLARE_CUSTOM_OP(barnes_edge_forces, 4, 1, false, 0, 1); +#endif - /** - * This operation used as helper with BarnesHutTsne class - * to Symmetrize the value matrix - * - * Expected input: - * 0: 1D int row-vector - * 1: 1D int col-vector - * 2: 1D float vector with values - * - * Output: - * 0: 1D int result row-vector - * 1: 1D int result col-vector - * 2: a float-point tensor with shape 1xN, with values from the last input vector - */ - #if NOT_EXCLUDED(OP_barnes_symmetrized) - DECLARE_CUSTOM_OP(barnes_symmetrized, 3, 3, false, 0, -1); - #endif +/** + * This operation used as helper with BarnesHutTsne class + * to Symmetrize the value matrix + * + * Expected input: + * 0: 1D int row-vector + * 1: 1D int col-vector + * 2: 1D float vector with values + * + * Output: + * 0: 1D int result row-vector + * 1: 1D int result col-vector + * 2: a float-point tensor with shape 1xN, with values from the last input + * vector + */ +#if NOT_EXCLUDED(OP_barnes_symmetrized) +DECLARE_CUSTOM_OP(barnes_symmetrized, 3, 3, false, 0, -1); +#endif - /** - * This operation used as helper with BranesHutTsne class - * to compute x = x + 2 * yGrads / abs(yGrads) != yIncs / abs(yIncs) - * - * Expected input: - * 0: input tensor - * 1: input gradient - * 2: gradient step tensor - * - * Output: - * 0: result of expression above - */ - #if NOT_EXCLUDED(OP_barnes_gains) - DECLARE_OP(barnes_gains, 3, 1, true); - #endif +/** + * This operation used as helper with BranesHutTsne class + * to compute x = x + 2 * yGrads / abs(yGrads) != yIncs / abs(yIncs) + * + * Expected input: + * 0: input tensor + * 1: input gradient + * 2: gradient step tensor + * + * Output: + * 0: result of expression above + */ +#if NOT_EXCLUDED(OP_barnes_gains) +DECLARE_OP(barnes_gains, 3, 1, true); +#endif - /** - * This operation used as helper with Cell class - * to check vals in given set - * - * Expected input: - * 0: 1D float row-vector (corners) - * 1: 1D float col-vector (widths) - * 2: 1D float vector (point) - * - * Output: - * 0: bool val - */ - #if NOT_EXCLUDED(OP_cell_contains) - DECLARE_CUSTOM_OP(cell_contains, 3, 1, false, 0, 1); - #endif +/** + * This operation used as helper with Cell class + * to check vals in given set + * + * Expected input: + * 0: 1D float row-vector (corners) + * 1: 1D float col-vector (widths) + * 2: 1D float vector (point) + * + * Output: + * 0: bool val + */ +#if NOT_EXCLUDED(OP_cell_contains) +DECLARE_CUSTOM_OP(cell_contains, 3, 1, false, 0, 1); +#endif - } -} +} // namespace ops +} // namespace sd -#endif // LIBND4J_HEADERS_BARNES_HUT_TSNE_H +#endif // LIBND4J_HEADERS_BARNES_HUT_TSNE_H diff --git a/libnd4j/include/ops/declarable/headers/activations.h b/libnd4j/include/ops/declarable/headers/activations.h index db9e8186a062..4eccc17c5e1e 100644 --- a/libnd4j/include/ops/declarable/headers/activations.h +++ b/libnd4j/include/ops/declarable/headers/activations.h @@ -21,180 +21,177 @@ #ifndef LIBND4J_HEADERS_ACTIVATIONS_H #define LIBND4J_HEADERS_ACTIVATIONS_H - #include namespace sd { - namespace ops { - /** - * This is Sigmoid activation function implementation - * Math is: 1 / 1 + exp(-x) - */ - #if NOT_EXCLUDED(OP_sigmoid) - DECLARE_CONFIGURABLE_OP(sigmoid, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(sigmoid_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is Softsign activation function implementation - * Math is: x / 1 + abs(x) - */ - #if NOT_EXCLUDED(OP_softsign) - DECLARE_CONFIGURABLE_OP(softsign, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(softsign_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is Tanh activation function implementation - */ - #if NOT_EXCLUDED(OP_tanh) - DECLARE_CONFIGURABLE_OP(tanh, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(tanh_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is Softplus activation function implementation - * Math is: log(1 + exp(x)) - */ - #if NOT_EXCLUDED(OP_softplus) - DECLARE_CONFIGURABLE_OP(softplus, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(softplus_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is RELU activation function implementation - */ - #if NOT_EXCLUDED(OP_relu) - DECLARE_CONFIGURABLE_OP(relu, 1, 1, true, 1, 0); - DECLARE_CONFIGURABLE_OP(relu_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is SELU activation function implementation - */ - #if NOT_EXCLUDED(OP_selu) - DECLARE_CONFIGURABLE_OP(selu, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(selu_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is Leaky RELU activation function. - * Math is: x < 0 ? alpha * x : x; - */ - #if NOT_EXCLUDED(OP_lrelu) - DECLARE_CONFIGURABLE_OP(lrelu, 1, 1, true, -2, 0); - DECLARE_CONFIGURABLE_OP(lrelu_bp, 2, 1, true, -2, 0); - #endif - - /** - * This op is ELU activation function. - * Math is: x >= 0 ? x : exp(x) - 1; - */ - #if NOT_EXCLUDED(OP_elu) - DECLARE_CONFIGURABLE_OP(elu, 1, 1, true, -2, 0); - DECLARE_CONFIGURABLE_OP(elu_bp, 2, 1, true, -2, 0); - #endif - - /** - * This is Cube activation function. - * Math is: x^3 - */ - #if NOT_EXCLUDED(OP_cube) - DECLARE_CONFIGURABLE_OP(cube, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(cube_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is RectifiedTanh activation function. - * Math is: max(0, tanh(x)) - */ - #if NOT_EXCLUDED(OP_rectifiedtanh) - DECLARE_CONFIGURABLE_OP(rectifiedtanh, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(rectifiedtanh_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is RationalTanh activation function. - */ - #if NOT_EXCLUDED(OP_rationaltanh) - DECLARE_CONFIGURABLE_OP(rationaltanh, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(rationaltanh_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is HardTanh activation function. - * Math is: x < -1.0 ? -1.0 : x > 1.0 ? 1.0 : x; - */ - #if NOT_EXCLUDED(OP_hardtanh) - DECLARE_CONFIGURABLE_OP(hardtanh, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(hardtanh_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is HardSigmoid activation function. - * Math is: min(1, max(0, 0.2 * x + 0.5)) - */ - #if NOT_EXCLUDED(OP_hardsigmoid) - DECLARE_CONFIGURABLE_OP(hardsigmoid, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(hardsigmoid_bp, 2, 1, true, 0, 0); - #endif - - /** - * This is Indentity operation. It passes signal umodified in both directions. - */ - #if NOT_EXCLUDED(OP_identity) - DECLARE_OP(identity, 1, 1, true); - DECLARE_OP(identity_bp, 2, 1, true); - #endif - - /** - * This is Indentity operation. It passes signal umodified in both directions. - */ - #if NOT_EXCLUDED(OP_identity_n) - DECLARE_CUSTOM_OP(identity_n, 1, 1, true, 0, 0); - #endif - - /** - * This is Concatenated RELU implementation. - * What happens inside: RELU(Concat((x, -x, {-1}))) - * - * PLEASE NOTE: Concatenation will double amount of features available in input - */ - #if NOT_EXCLUDED(OP_crelu) - DECLARE_CUSTOM_OP(crelu, 1, 1, false, 0, 0); - DECLARE_CUSTOM_OP(crelu_bp, 2, 1, false, 0, 0); - #endif - - /** - * This is RELU6 activation function implementation - */ - #if NOT_EXCLUDED(OP_relu6) - DECLARE_CONFIGURABLE_OP(relu6, 1, 1, true, 1, 0); - DECLARE_CONFIGURABLE_OP(relu6_bp, 2, 1, true, 0, 0); - #endif - - - /** - * Parametric Rectified Linear Unit - * f(x) = alpha * x for x < 0, f(x) = x for x >= 0 - */ - #if NOT_EXCLUDED(OP_prelu) - DECLARE_CONFIGURABLE_OP(prelu, 2, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(prelu_bp, 3, 2, true, 0, 0); - #endif - - /** - * Thresholded Rectified Linear Unit - * f(x) = x for x > theta, f(x) = 0 otherwise - * theta must be >= 0 - */ - #if NOT_EXCLUDED(OP_thresholdedrelu) - DECLARE_CONFIGURABLE_OP(thresholdedrelu, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(thresholdedrelu_bp, 2, 1, true, 0, 0); - #endif - - - } -} +namespace ops { +/** + * This is Sigmoid activation function implementation + * Math is: 1 / 1 + exp(-x) + */ +#if NOT_EXCLUDED(OP_sigmoid) +DECLARE_CONFIGURABLE_OP(sigmoid, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(sigmoid_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is Softsign activation function implementation + * Math is: x / 1 + abs(x) + */ +#if NOT_EXCLUDED(OP_softsign) +DECLARE_CONFIGURABLE_OP(softsign, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(softsign_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is Tanh activation function implementation + */ +#if NOT_EXCLUDED(OP_tanh) +DECLARE_CONFIGURABLE_OP(tanh, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(tanh_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is Softplus activation function implementation + * Math is: log(1 + exp(x)) + */ +#if NOT_EXCLUDED(OP_softplus) +DECLARE_CONFIGURABLE_OP(softplus, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(softplus_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is RELU activation function implementation + */ +#if NOT_EXCLUDED(OP_relu) +DECLARE_CONFIGURABLE_OP(relu, 1, 1, true, 1, 0); +DECLARE_CONFIGURABLE_OP(relu_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is SELU activation function implementation + */ +#if NOT_EXCLUDED(OP_selu) +DECLARE_CONFIGURABLE_OP(selu, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(selu_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is Leaky RELU activation function. + * Math is: x < 0 ? alpha * x : x; + */ +#if NOT_EXCLUDED(OP_lrelu) +DECLARE_CONFIGURABLE_OP(lrelu, 1, 1, true, -2, 0); +DECLARE_CONFIGURABLE_OP(lrelu_bp, 2, 1, true, -2, 0); +#endif + +/** + * This op is ELU activation function. + * Math is: x >= 0 ? x : exp(x) - 1; + */ +#if NOT_EXCLUDED(OP_elu) +DECLARE_CONFIGURABLE_OP(elu, 1, 1, true, -2, 0); +DECLARE_CONFIGURABLE_OP(elu_bp, 2, 1, true, -2, 0); +#endif + +/** + * This is Cube activation function. + * Math is: x^3 + */ +#if NOT_EXCLUDED(OP_cube) +DECLARE_CONFIGURABLE_OP(cube, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(cube_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is RectifiedTanh activation function. + * Math is: max(0, tanh(x)) + */ +#if NOT_EXCLUDED(OP_rectifiedtanh) +DECLARE_CONFIGURABLE_OP(rectifiedtanh, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(rectifiedtanh_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is RationalTanh activation function. + */ +#if NOT_EXCLUDED(OP_rationaltanh) +DECLARE_CONFIGURABLE_OP(rationaltanh, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(rationaltanh_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is HardTanh activation function. + * Math is: x < -1.0 ? -1.0 : x > 1.0 ? 1.0 : x; + */ +#if NOT_EXCLUDED(OP_hardtanh) +DECLARE_CONFIGURABLE_OP(hardtanh, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(hardtanh_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is HardSigmoid activation function. + * Math is: min(1, max(0, 0.2 * x + 0.5)) + */ +#if NOT_EXCLUDED(OP_hardsigmoid) +DECLARE_CONFIGURABLE_OP(hardsigmoid, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(hardsigmoid_bp, 2, 1, true, 0, 0); +#endif + +/** + * This is Indentity operation. It passes signal umodified in both directions. + */ +#if NOT_EXCLUDED(OP_identity) +DECLARE_OP(identity, 1, 1, true); +DECLARE_OP(identity_bp, 2, 1, true); +#endif + +/** + * This is Indentity operation. It passes signal umodified in both directions. + */ +#if NOT_EXCLUDED(OP_identity_n) +DECLARE_CUSTOM_OP(identity_n, 1, 1, true, 0, 0); +#endif + +/** + * This is Concatenated RELU implementation. + * What happens inside: RELU(Concat((x, -x, {-1}))) + * + * PLEASE NOTE: Concatenation will double amount of features available in input + */ +#if NOT_EXCLUDED(OP_crelu) +DECLARE_CUSTOM_OP(crelu, 1, 1, false, 0, 0); +DECLARE_CUSTOM_OP(crelu_bp, 2, 1, false, 0, 0); +#endif + +/** + * This is RELU6 activation function implementation + */ +#if NOT_EXCLUDED(OP_relu6) +DECLARE_CONFIGURABLE_OP(relu6, 1, 1, true, 1, 0); +DECLARE_CONFIGURABLE_OP(relu6_bp, 2, 1, true, 0, 0); +#endif + +/** + * Parametric Rectified Linear Unit + * f(x) = alpha * x for x < 0, f(x) = x for x >= 0 + */ +#if NOT_EXCLUDED(OP_prelu) +DECLARE_CONFIGURABLE_OP(prelu, 2, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(prelu_bp, 3, 2, true, 0, 0); +#endif + +/** + * Thresholded Rectified Linear Unit + * f(x) = x for x > theta, f(x) = 0 otherwise + * theta must be >= 0 + */ +#if NOT_EXCLUDED(OP_thresholdedrelu) +DECLARE_CONFIGURABLE_OP(thresholdedrelu, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(thresholdedrelu_bp, 2, 1, true, 0, 0); +#endif + +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/bitwise.h b/libnd4j/include/ops/declarable/headers/bitwise.h index b5f29896f779..db8188267dc2 100644 --- a/libnd4j/include/ops/declarable/headers/bitwise.h +++ b/libnd4j/include/ops/declarable/headers/bitwise.h @@ -24,107 +24,109 @@ #include namespace sd { - namespace ops { - /** - * This operation toggles individual bits of each element in array - * - * PLEASE NOTE: This operation is possible only on integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_toggle_bits) - DECLARE_OP(toggle_bits, -1, -1, true); - #endif - +namespace ops { +/** + * This operation toggles individual bits of each element in array + * + * PLEASE NOTE: This operation is possible only on integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_toggle_bits) +DECLARE_OP(toggle_bits, -1, -1, true); +#endif - /** - * This operation shift individual bits of each element in array to the left: << - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_shift_bits) - DECLARE_BROADCASTABLE_OP(shift_bits, 0, 0); - #endif +/** + * This operation shift individual bits of each element in array to the left: << + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_shift_bits) +DECLARE_BROADCASTABLE_OP(shift_bits, 0, 0); +#endif - /** - * This operation shift individual bits of each element in array to the right: >> - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_rshift_bits) - DECLARE_BROADCASTABLE_OP(rshift_bits, 0, 0); - #endif +/** + * This operation shift individual bits of each element in array to the right: + * >> + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_rshift_bits) +DECLARE_BROADCASTABLE_OP(rshift_bits, 0, 0); +#endif - /** - * This operation shift individual bits of each element in array, shifting to the left - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_cyclic_shift_bits) - DECLARE_BROADCASTABLE_OP(cyclic_shift_bits, 0, 0); - #endif +/** + * This operation shift individual bits of each element in array, shifting to + * the left + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_cyclic_shift_bits) +DECLARE_BROADCASTABLE_OP(cyclic_shift_bits, 0, 0); +#endif - /** - * This operation shift individual bits of each element in array, shifting to the right - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_cyclic_rshift_bits) - DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0); - #endif +/** + * This operation shift individual bits of each element in array, shifting to + * the right + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_cyclic_rshift_bits) +DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0); +#endif - /** - * This operation applies bitwise AND - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_bitwise_and) - DECLARE_BROADCASTABLE_OP(bitwise_and, 0, 0); - #endif +/** + * This operation applies bitwise AND + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_bitwise_and) +DECLARE_BROADCASTABLE_OP(bitwise_and, 0, 0); +#endif - /** - * This operation applies bitwise OR - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_bitwise_or) - DECLARE_BROADCASTABLE_OP(bitwise_or, 0, 0); - #endif +/** + * This operation applies bitwise OR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_bitwise_or) +DECLARE_BROADCASTABLE_OP(bitwise_or, 0, 0); +#endif - /** - * This operation applies bitwise XOR - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_bitwise_xor) - DECLARE_BROADCASTABLE_OP(bitwise_xor, 0, 0); - #endif +/** + * This operation applies bitwise XOR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_bitwise_xor) +DECLARE_BROADCASTABLE_OP(bitwise_xor, 0, 0); +#endif - /** - * This operation returns hamming distance based on bits - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_bits_hamming_distance) - DECLARE_CUSTOM_OP(bits_hamming_distance, 2, 1, true, 0, 0); - #endif - } -} +/** + * This operation returns hamming distance based on bits + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_bits_hamming_distance) +DECLARE_CUSTOM_OP(bits_hamming_distance, 2, 1, true, 0, 0); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/blas.h b/libnd4j/include/ops/declarable/headers/blas.h index 09215e1139f5..05f31a6ca738 100644 --- a/libnd4j/include/ops/declarable/headers/blas.h +++ b/libnd4j/include/ops/declarable/headers/blas.h @@ -23,92 +23,94 @@ #include namespace sd { - namespace ops { - - /** - * This op is general matmum implementation. Depending on inputs dimensionality output result might be different. - * matrix x matrix = BLAS gemm - * vector x matrix = BLAS gemm - * vector x vector = BLAS dot - * vector x scalar = element-wise mul - * scalar x vector = element-wise mul - * - * Optional T arguments: - * 0: alpha (where applicable) - * 1: beta (where applicable) - * - * Optional Integer arguments: - * 0: transA (where applicable) - * 1: transB (where applicable) - */ - #if NOT_EXCLUDED(OP_matmul) - DECLARE_CUSTOM_OP(matmul, 2, 1, false, 0, -2); - DECLARE_CUSTOM_OP(matmul_bp, 3, 2, false, 0, -2); - #endif +namespace ops { - /** - * tensorMmul/tensorDot operation - * takes 2 ndarrays, and 2 sets of axes - * - * Integer argumens map: - * IArgs[0] - number of axes along for first array - * IArgs[1]... axes values for first array - * IArgs[] - number of axes along for second array - * IArgs[1]... axes values for second array - */ - #if NOT_EXCLUDED(OP_tensormmul) - DECLARE_CUSTOM_OP(tensormmul, 2, 1, false, 0, -1); - DECLARE_CUSTOM_OP(tensormmul_bp, 3, 2, false, 0, -1); - #endif +/** + * This op is general matmum implementation. Depending on inputs dimensionality + * output result might be different. matrix x matrix = BLAS gemm vector x matrix + * = BLAS gemm vector x vector = BLAS dot vector x scalar = element-wise mul + * scalar x vector = element-wise mul + * + * Optional T arguments: + * 0: alpha (where applicable) + * 1: beta (where applicable) + * + * Optional Integer arguments: + * 0: transA (where applicable) + * 1: transB (where applicable) + */ +#if NOT_EXCLUDED(OP_matmul) +DECLARE_CUSTOM_OP(matmul, 2, 1, false, 0, -2); +DECLARE_CUSTOM_OP(matmul_bp, 3, 2, false, 0, -2); +#endif - /** - * This op is simple implementation of BLAS AXPY method. - * Math is: y += a * x; - */ - #if NOT_EXCLUDED(OP_axpy) - DECLARE_CONFIGURABLE_OP(axpy, 2, 1, false, -2, 0); - #endif +/** + * tensorMmul/tensorDot operation + * takes 2 ndarrays, and 2 sets of axes + * + * Integer argumens map: + * IArgs[0] - number of axes along for first array + * IArgs[1]... axes values for first array + * IArgs[] - number of axes along for second array + * IArgs[1]... axes values for second array + */ +#if NOT_EXCLUDED(OP_tensormmul) +DECLARE_CUSTOM_OP(tensormmul, 2, 1, false, 0, -1); +DECLARE_CUSTOM_OP(tensormmul_bp, 3, 2, false, 0, -1); +#endif - /** - * This operation implements batched matrix multiplication - * Expected arguments: - * alpha: vector of T - * beta: vector of T - * ...: A, B matrices sequentially. i.e: AAAAABBBBB - * - * Integer arguments: - * transA, transB, M, N, K, ldA, ldB, ldC - usual BLAS gemm arguments - * batchCount - number of operations in this batch - * - * PLEASE NOTE: M, N, K, ldA, ldB, ldC should be equal for all matrices within batch. - */ - #if NOT_EXCLUDED(OP_batched_gemm) - DECLARE_CUSTOM_OP(batched_gemm, -1, -1, false, 0, 9); - #endif +/** + * This op is simple implementation of BLAS AXPY method. + * Math is: y += a * x; + */ +#if NOT_EXCLUDED(OP_axpy) +DECLARE_CONFIGURABLE_OP(axpy, 2, 1, false, -2, 0); +#endif - /** - * performs singular value decomposition (SVD) of one or more matrices, evaluates the SVD of each inner-most 2D matrix in input array: - * x[..., :, :] = u[..., :, :] * s[...,:] * transpose(v[..., :, :]) - * - * Input array: - * x[..., Rows, Cols], the necessary condition is: rank of x >= 2 - * - * Outputs arrays: - * s[..., diagSize] - array with singular values which are stored in decreasing order, diagSize is smaller among Rows and Cols - * u[..., Rows, Rows] if IArgs[1] is true, else u[..., Rows, diagSize] - array with right singular vectors - * v[..., Cols, Cols] if IArgs[1] is true, else v[..., Cols, diagSize] - array with left singular vectors - * - * Integer arguments: - * IArgs[0] - bool, whether to calculate u and v, s is calculated in any case - * IArgs[1] - bool, whether to calculate full-sized u and v - * IArgs[2] - the number of cols or rows which determines what algorithm to use. More precisely: - * if diagSize < IArgs[2] then Jacobi algorithm is used, in opposite case the Divide-And-Conquer is applied - * Recommended value is 16. - */ - #if NOT_EXCLUDED(OP_svd) - DECLARE_CUSTOM_OP(svd, 1, 1, false, 0, 3); - #endif - } -} +/** + * This operation implements batched matrix multiplication + * Expected arguments: + * alpha: vector of T + * beta: vector of T + * ...: A, B matrices sequentially. i.e: AAAAABBBBB + * + * Integer arguments: + * transA, transB, M, N, K, ldA, ldB, ldC - usual BLAS gemm arguments + * batchCount - number of operations in this batch + * + * PLEASE NOTE: M, N, K, ldA, ldB, ldC should be equal for all matrices within + * batch. + */ +#if NOT_EXCLUDED(OP_batched_gemm) +DECLARE_CUSTOM_OP(batched_gemm, -1, -1, false, 0, 9); +#endif + +/** + * performs singular value decomposition (SVD) of one or more matrices, + * evaluates the SVD of each inner-most 2D matrix in input array: x[..., :, :] = + * u[..., :, :] * s[...,:] * transpose(v[..., :, :]) + * + * Input array: + * x[..., Rows, Cols], the necessary condition is: rank of x >= 2 + * + * Outputs arrays: + * s[..., diagSize] - array with singular values which are stored in decreasing + * order, diagSize is smaller among Rows and Cols u[..., Rows, Rows] if IArgs[1] + * is true, else u[..., Rows, diagSize] - array with right singular vectors + * v[..., Cols, Cols] if IArgs[1] is true, else v[..., Cols, diagSize] - array + * with left singular vectors + * + * Integer arguments: + * IArgs[0] - bool, whether to calculate u and v, s is calculated in any case + * IArgs[1] - bool, whether to calculate full-sized u and v + * IArgs[2] - the number of cols or rows which determines what algorithm to use. + * More precisely: if diagSize < IArgs[2] then Jacobi algorithm is used, in + * opposite case the Divide-And-Conquer is applied Recommended value is 16. + */ +#if NOT_EXCLUDED(OP_svd) +DECLARE_CUSTOM_OP(svd, 1, 1, false, 0, 3); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/boolean.h b/libnd4j/include/ops/declarable/headers/boolean.h index 75e95f630349..c61179c05e84 100644 --- a/libnd4j/include/ops/declarable/headers/boolean.h +++ b/libnd4j/include/ops/declarable/headers/boolean.h @@ -24,133 +24,137 @@ #include namespace sd { - namespace ops { - - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x < y - */ - #if NOT_EXCLUDED(OP_lt_scalar) - DECLARE_BOOLEAN_OP(lt_scalar, 2, true); - #endif - - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x > y - */ - #if NOT_EXCLUDED(OP_gt_scalar) - DECLARE_BOOLEAN_OP(gt_scalar, 2, true); - #endif - - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x <= y - */ - #if NOT_EXCLUDED(OP_lte_scalar) - DECLARE_BOOLEAN_OP(lte_scalar, 2, true); - #endif - - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x >= y - */ - #if NOT_EXCLUDED(OP_gte_scalar) - DECLARE_BOOLEAN_OP(gte_scalar, 2, true); - #endif - - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if both operands are equal. - */ - #if NOT_EXCLUDED(OP_eq_scalar) - DECLARE_BOOLEAN_OP(eq_scalar, 2, true); - #endif - - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x != y - */ - #if NOT_EXCLUDED(OP_neq_scalar) - DECLARE_BOOLEAN_OP(neq_scalar, 2, true); - #endif - - /** - * This op takes 2 n-dimensional arrays as input, and return - * array of the same shape, with elements, either from x or y, depending on the condition. - */ - #if NOT_EXCLUDED(OP_where) - DECLARE_CUSTOM_OP(Where, 1, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_where_np) - DECLARE_CUSTOM_OP(where_np, 1, 1, false, 0, 0); - #endif - - /** - * This op takes 2 n-dimensional arrays as input, and return - * array of the same shape, with elements, either from x or y, depending on the condition. - */ - #if NOT_EXCLUDED(OP_select) - DECLARE_CUSTOM_OP(select, 3, 1, false, 0, 0); - #endif - - /** - * This op takes either 1 argument and 1 scalar - * or 1 argument and another comparison array - * and runs a pre defined conditional op. - * - * The output of the op is dynamic in size and returns a flat vector of elements - * that return true on the given condition. - * In numpy parlance, most people might understand: - * a[a > 2] - * where a is a numpy array and the condition is true when an element is - * > 2. Libnd4j already implements a number of pre defined conditions. - * @tparam T - */ - #if NOT_EXCLUDED(OP_choose) - DECLARE_CUSTOM_OP(choose, -1, 1, false, -2, -1); - #endif - - /** - * This op takes 1 n-dimensional array as input, and returns true if for every adjacent pair we have x[i] <= x[i+1]. - */ - #if NOT_EXCLUDED(OP_is_non_decreasing) - DECLARE_BOOLEAN_OP(is_non_decreasing, 1, true); - #endif - - /** - * This op takes 1 n-dimensional array as input, and returns true if for every adjacent pair we have x[i] < x[i+1]. - */ - #if NOT_EXCLUDED(OP_is_strictly_increasing) - DECLARE_BOOLEAN_OP(is_strictly_increasing, 1, true); - #endif - - /** - * This op takes 1 n-dimensional array as input, and returns true if input is a numeric array. - */ - #if NOT_EXCLUDED(OP_is_numeric_tensor) - DECLARE_BOOLEAN_OP(is_numeric_tensor, 1, true); - #endif - - /** - * - */ - #if NOT_EXCLUDED(OP_boolean_not) - DECLARE_OP(boolean_not, 1, 1, true); - #endif - } -} +namespace ops { + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x < y + */ +#if NOT_EXCLUDED(OP_lt_scalar) +DECLARE_BOOLEAN_OP(lt_scalar, 2, true); +#endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x > y + */ +#if NOT_EXCLUDED(OP_gt_scalar) +DECLARE_BOOLEAN_OP(gt_scalar, 2, true); +#endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x <= y + */ +#if NOT_EXCLUDED(OP_lte_scalar) +DECLARE_BOOLEAN_OP(lte_scalar, 2, true); +#endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x >= y + */ +#if NOT_EXCLUDED(OP_gte_scalar) +DECLARE_BOOLEAN_OP(gte_scalar, 2, true); +#endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if both operands are equal. + */ +#if NOT_EXCLUDED(OP_eq_scalar) +DECLARE_BOOLEAN_OP(eq_scalar, 2, true); +#endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x != y + */ +#if NOT_EXCLUDED(OP_neq_scalar) +DECLARE_BOOLEAN_OP(neq_scalar, 2, true); +#endif + +/** + * This op takes 2 n-dimensional arrays as input, and return + * array of the same shape, with elements, either from x or y, depending on the + * condition. + */ +#if NOT_EXCLUDED(OP_where) +DECLARE_CUSTOM_OP(Where, 1, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_where_np) +DECLARE_CUSTOM_OP(where_np, 1, 1, false, 0, 0); +#endif + +/** + * This op takes 2 n-dimensional arrays as input, and return + * array of the same shape, with elements, either from x or y, depending on the + * condition. + */ +#if NOT_EXCLUDED(OP_select) +DECLARE_CUSTOM_OP(select, 3, 1, false, 0, 0); +#endif + +/** + * This op takes either 1 argument and 1 scalar + * or 1 argument and another comparison array + * and runs a pre defined conditional op. + * + * The output of the op is dynamic in size and returns a flat vector of + * elements that return true on the given condition. In numpy parlance, most + * people might understand: a[a > 2] where a is a numpy array and the condition + * is true when an element is > 2. Libnd4j already implements a number of pre + * defined conditions. + * @tparam T + */ +#if NOT_EXCLUDED(OP_choose) +DECLARE_CUSTOM_OP(choose, -1, 1, false, -2, -1); +#endif + +/** + * This op takes 1 n-dimensional array as input, and returns true if for every + * adjacent pair we have x[i] <= x[i+1]. + */ +#if NOT_EXCLUDED(OP_is_non_decreasing) +DECLARE_BOOLEAN_OP(is_non_decreasing, 1, true); +#endif + +/** + * This op takes 1 n-dimensional array as input, and returns true if for every + * adjacent pair we have x[i] < x[i+1]. + */ +#if NOT_EXCLUDED(OP_is_strictly_increasing) +DECLARE_BOOLEAN_OP(is_strictly_increasing, 1, true); +#endif + +/** + * This op takes 1 n-dimensional array as input, and returns true if input is a + * numeric array. + */ +#if NOT_EXCLUDED(OP_is_numeric_tensor) +DECLARE_BOOLEAN_OP(is_numeric_tensor, 1, true); +#endif + +/** + * + */ +#if NOT_EXCLUDED(OP_boolean_not) +DECLARE_OP(boolean_not, 1, 1, true); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/broadcastable.h b/libnd4j/include/ops/declarable/headers/broadcastable.h index 7380412a4156..fa89ab47c9ef 100644 --- a/libnd4j/include/ops/declarable/headers/broadcastable.h +++ b/libnd4j/include/ops/declarable/headers/broadcastable.h @@ -21,367 +21,397 @@ #ifndef LIBND4J_HEADERS_BROADCASTABLE_H #define LIBND4J_HEADERS_BROADCASTABLE_H -#include #include -#include +#include #include +#include namespace sd { - namespace ops { - // TODO: make broadcastables separate class - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Max(X, Y) - */ - #if NOT_EXCLUDED(OP_maximum) - DECLARE_BROADCASTABLE_OP(maximum, 0, 0); - DECLARE_CUSTOM_OP(maximum_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Min(X, Y) - */ - #if NOT_EXCLUDED(OP_minimum) - DECLARE_BROADCASTABLE_OP(minimum, 0, 0); - DECLARE_CUSTOM_OP(minimum_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Add(X, Y) - */ - #if NOT_EXCLUDED(OP_add) - DECLARE_BROADCASTABLE_OP(add, 0, 0); - DECLARE_CUSTOM_OP(add_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Subtract(X, Y) - */ - #if NOT_EXCLUDED(OP_subtract) - DECLARE_BROADCASTABLE_OP(subtract, 0, 0); - DECLARE_CUSTOM_OP(subtract_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Subtract(Y, X) - */ - #if NOT_EXCLUDED(OP_reversesubtract) - DECLARE_BROADCASTABLE_OP(reversesubtract, 0, 0); - DECLARE_CUSTOM_OP(reversesubtract_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = ReverseMod(X, Y) == Mod(Y, X) - */ - #if NOT_EXCLUDED(OP_reversemod) - DECLARE_BROADCASTABLE_OP(reversemod, 0, 0); - DECLARE_CUSTOM_OP(reversemod_bp, 3, 2, true, 0, 0); - #endif - - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Subtract(X, Y) * Subtract(X, Y) - */ - #if NOT_EXCLUDED(OP_squaredsubtract) - DECLARE_BROADCASTABLE_OP(squaredsubtract, 0, 0) - DECLARE_CUSTOM_OP(squaredsubtract_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Multiply(X, Y) - */ - #if NOT_EXCLUDED(OP_multiply) - DECLARE_BROADCASTABLE_OP(multiply, 0, 0); - DECLARE_CUSTOM_OP(multiply_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(X, Y) - */ - #if NOT_EXCLUDED(OP_divide) - DECLARE_BROADCASTABLE_OP(divide, 0, 0); - DECLARE_CUSTOM_OP(divide_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(X, Y) with exception, 0 if Y = 0 - */ - #if NOT_EXCLUDED(OP_divide_no_nan) - DECLARE_BROADCASTABLE_OP(divide_no_nan, 0, 0); - #endif - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(Y, x) - */ - #if NOT_EXCLUDED(OP_reversedivide) - DECLARE_BROADCASTABLE_OP(reversedivide, 0, 0); - DECLARE_CUSTOM_OP(reversedivide_bp, 3, 2, false, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = FloorMod(X, Y) - */ - #if NOT_EXCLUDED(OP_floormod) - DECLARE_BROADCASTABLE_OP(floormod, 0, 0); - DECLARE_CUSTOM_OP(floormod_bp, 3, 2, true, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_mod) - DECLARE_BROADCASTABLE_OP(mod, 0, 0); - DECLARE_CUSTOM_OP(mod_bp, 3, 2, true, 0, 0); - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = FloorDiv(X, Y) - */ - #if NOT_EXCLUDED(OP_floordiv) - DECLARE_BROADCASTABLE_OP(floordiv, 0, 0) - DECLARE_CUSTOM_OP(floordiv_bp, 2, 1, true, 0, 0) - #endif - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(X, Y) - */ - #if NOT_EXCLUDED(OP_realdiv) - DECLARE_BROADCASTABLE_OP(realdiv, 0, 0); - DECLARE_CUSTOM_OP(realdiv_bp, 3, 2, false, 0, 0); - #endif - - - /** - * - * - * @tparam T - */ - DECLARE_BROADCASTABLE_OP(truncatediv, 0, 0); - - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Assign(X, Y) - */ - #if NOT_EXCLUDED(OP_assign) - DECLARE_BROADCASTABLE_OP(assign, 0, 0); - DECLARE_CUSTOM_OP(assign_bp, 3, 2, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_meshgrid) - DECLARE_CUSTOM_OP(meshgrid, -1, -1, false, 0, 0); - #endif - - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x == _y ? (T) 1.0f : (T) 0.0f; - * - */ - #if NOT_EXCLUDED(OP_equals) - DECLARE_BROADCASTABLE_BOOL_OP(equals, 0, 0); - #endif - - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x != _y ? (T) 1.0f : (T) 0.0f; - */ - #if NOT_EXCLUDED(OP_not_equals) - DECLARE_BROADCASTABLE_BOOL_OP(not_equals, 0, 0); - #endif - - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x <= _y ? (T) 1.0f : (T) 0.0f; - */ - #if NOT_EXCLUDED(OP_less_equal) - DECLARE_BROADCASTABLE_BOOL_OP(less_equal, 0, 0); - #endif - - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x >= _y ? (T) 1.0f : (T) 0.0f; - */ - #if NOT_EXCLUDED(OP_greater_equal) - DECLARE_BROADCASTABLE_BOOL_OP(greater_equal, 0, 0); - #endif - - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x < _y ? (T) 1.0f : (T) 0.0f; - */ - #if NOT_EXCLUDED(OP_less) - DECLARE_BROADCASTABLE_BOOL_OP(less, 0, 0); - #endif - - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x > _y ? (T) 1.0f : (T) 0.0f; - */ - #if NOT_EXCLUDED(OP_greater) - DECLARE_BROADCASTABLE_BOOL_OP(greater, 0, 0); - #endif - - /** - * - */ - #if NOT_EXCLUDED(OP_boolean_and) - DECLARE_BROADCASTABLE_OP(boolean_and, 0, 0); - #endif - - /** - * - */ - #if NOT_EXCLUDED(OP_boolean_or) - DECLARE_BROADCASTABLE_OP(boolean_or, 0, 0); - #endif - - /** - * - */ - #if NOT_EXCLUDED(OP_boolean_xor) - DECLARE_BROADCASTABLE_OP(boolean_xor, 0, 0); - #endif - - /** - * This operation performs calculation of percentile of input array along given axises - * - * Input - tensor with rank N > 0 - * Output - tensor with rank (N - length(axis)) or scalar if number of Integer arguments is zero - * Float arguments: - * 0: percentile (scalar) in range [0,100] (inclusively) - * 1: interpolation (optional), possible values are 0-"lower", 1-"higher", 2-"nearest"(default) - * 2: keepDims (optional), if it is non zero, then unities are kept in reduced resulting shape of output array, default is 0 - * Integer arguments - axis - the sequence of axises to calculate percentile along, if sequence is empty then calculate percentile for whole input tensor and return result as scalar - * - */ - #if NOT_EXCLUDED(OP_percentile) - DECLARE_CUSTOM_OP(percentile, 1, 1, false, 1, -2); - #endif - - - /** - * Special atan2 op impl for TF's args order - * @tparam T - */ - #if NOT_EXCLUDED(OP_tf_atan2) - DECLARE_BROADCASTABLE_OP(tf_atan2, 0, 0); - #endif - - /** - * Broadcastable pow implementation - * @tparam T - */ - #if NOT_EXCLUDED(OP_Pow) - DECLARE_BROADCASTABLE_OP(Pow, 0, 0); - DECLARE_CUSTOM_OP(Pow_bp, 3, 2, false, 0, 0); - #endif - - /** - * Broadcastable igamma implementation - * - * igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x) - * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } - * gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt } - * @tparam T - */ - #if NOT_EXCLUDED(OP_igamma) - DECLARE_BROADCASTABLE_OP(igamma, 0, 0); - #endif - /** - * Broadcastable igammac implementation - * igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x) - * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } - * Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt } - * @tparam T - */ - #if NOT_EXCLUDED(OP_igammac) - DECLARE_BROADCASTABLE_OP(igammac, 0, 0); - #endif - } -} +namespace ops { +// TODO: make broadcastables separate class + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Max(X, Y) + */ +#if NOT_EXCLUDED(OP_maximum) +DECLARE_BROADCASTABLE_OP(maximum, 0, 0); +DECLARE_CUSTOM_OP(maximum_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Min(X, Y) + */ +#if NOT_EXCLUDED(OP_minimum) +DECLARE_BROADCASTABLE_OP(minimum, 0, 0); +DECLARE_CUSTOM_OP(minimum_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Add(X, Y) + */ +#if NOT_EXCLUDED(OP_add) +DECLARE_BROADCASTABLE_OP(add, 0, 0); +DECLARE_CUSTOM_OP(add_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Subtract(X, Y) + */ +#if NOT_EXCLUDED(OP_subtract) +DECLARE_BROADCASTABLE_OP(subtract, 0, 0); +DECLARE_CUSTOM_OP(subtract_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Subtract(Y, X) + */ +#if NOT_EXCLUDED(OP_reversesubtract) +DECLARE_BROADCASTABLE_OP(reversesubtract, 0, 0); +DECLARE_CUSTOM_OP(reversesubtract_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = ReverseMod(X, Y) == Mod(Y, X) + */ +#if NOT_EXCLUDED(OP_reversemod) +DECLARE_BROADCASTABLE_OP(reversemod, 0, 0); +DECLARE_CUSTOM_OP(reversemod_bp, 3, 2, true, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Subtract(X, Y) * Subtract(X, Y) + */ +#if NOT_EXCLUDED(OP_squaredsubtract) +DECLARE_BROADCASTABLE_OP(squaredsubtract, 0, 0) +DECLARE_CUSTOM_OP(squaredsubtract_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Multiply(X, Y) + */ +#if NOT_EXCLUDED(OP_multiply) +DECLARE_BROADCASTABLE_OP(multiply, 0, 0); +DECLARE_CUSTOM_OP(multiply_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(X, Y) + */ +#if NOT_EXCLUDED(OP_divide) +DECLARE_BROADCASTABLE_OP(divide, 0, 0); +DECLARE_CUSTOM_OP(divide_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(X, Y) with exception, 0 if Y = 0 + */ +#if NOT_EXCLUDED(OP_divide_no_nan) +DECLARE_BROADCASTABLE_OP(divide_no_nan, 0, 0); +#endif +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(Y, x) + */ +#if NOT_EXCLUDED(OP_reversedivide) +DECLARE_BROADCASTABLE_OP(reversedivide, 0, 0); +DECLARE_CUSTOM_OP(reversedivide_bp, 3, 2, false, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = FloorMod(X, Y) + */ +#if NOT_EXCLUDED(OP_floormod) +DECLARE_BROADCASTABLE_OP(floormod, 0, 0); +DECLARE_CUSTOM_OP(floormod_bp, 3, 2, true, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_mod) +DECLARE_BROADCASTABLE_OP(mod, 0, 0); +DECLARE_CUSTOM_OP(mod_bp, 3, 2, true, 0, 0); +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = FloorDiv(X, Y) + */ +#if NOT_EXCLUDED(OP_floordiv) +DECLARE_BROADCASTABLE_OP(floordiv, 0, 0) +DECLARE_CUSTOM_OP(floordiv_bp, 2, 1, true, 0, 0) +#endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(X, Y) + */ +#if NOT_EXCLUDED(OP_realdiv) +DECLARE_BROADCASTABLE_OP(realdiv, 0, 0); +DECLARE_CUSTOM_OP(realdiv_bp, 3, 2, false, 0, 0); +#endif + +/** + * + * + * @tparam T + */ +DECLARE_BROADCASTABLE_OP(truncatediv, 0, 0); + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Assign(X, Y) + */ +#if NOT_EXCLUDED(OP_assign) +DECLARE_BROADCASTABLE_OP(assign, 0, 0); +DECLARE_CUSTOM_OP(assign_bp, 3, 2, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_meshgrid) +DECLARE_CUSTOM_OP(meshgrid, -1, -1, false, 0, 0); +#endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x == _y ? (T) 1.0f : (T) 0.0f; + * + */ +#if NOT_EXCLUDED(OP_equals) +DECLARE_BROADCASTABLE_BOOL_OP(equals, 0, 0); +#endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x != _y ? (T) 1.0f : (T) 0.0f; + */ +#if NOT_EXCLUDED(OP_not_equals) +DECLARE_BROADCASTABLE_BOOL_OP(not_equals, 0, 0); +#endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x <= _y ? (T) 1.0f : (T) 0.0f; + */ +#if NOT_EXCLUDED(OP_less_equal) +DECLARE_BROADCASTABLE_BOOL_OP(less_equal, 0, 0); +#endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x >= _y ? (T) 1.0f : (T) 0.0f; + */ +#if NOT_EXCLUDED(OP_greater_equal) +DECLARE_BROADCASTABLE_BOOL_OP(greater_equal, 0, 0); +#endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x < _y ? (T) 1.0f : (T) 0.0f; + */ +#if NOT_EXCLUDED(OP_less) +DECLARE_BROADCASTABLE_BOOL_OP(less, 0, 0); +#endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x > _y ? (T) 1.0f : (T) 0.0f; + */ +#if NOT_EXCLUDED(OP_greater) +DECLARE_BROADCASTABLE_BOOL_OP(greater, 0, 0); +#endif + +/** + * + */ +#if NOT_EXCLUDED(OP_boolean_and) +DECLARE_BROADCASTABLE_OP(boolean_and, 0, 0); +#endif + +/** + * + */ +#if NOT_EXCLUDED(OP_boolean_or) +DECLARE_BROADCASTABLE_OP(boolean_or, 0, 0); +#endif + +/** + * + */ +#if NOT_EXCLUDED(OP_boolean_xor) +DECLARE_BROADCASTABLE_OP(boolean_xor, 0, 0); +#endif + +/** + * This operation performs calculation of percentile of input array along given + * axises + * + * Input - tensor with rank N > 0 + * Output - tensor with rank (N - length(axis)) or scalar if number of Integer + * arguments is zero Float arguments: 0: percentile (scalar) in range [0,100] + * (inclusively) 1: interpolation (optional), possible values are 0-"lower", + * 1-"higher", 2-"nearest"(default) 2: keepDims (optional), if it is non zero, + * then unities are kept in reduced resulting shape of output array, default is + * 0 Integer arguments - axis - the sequence of axises to calculate percentile + * along, if sequence is empty then calculate percentile for whole input tensor + * and return result as scalar + * + */ +#if NOT_EXCLUDED(OP_percentile) +DECLARE_CUSTOM_OP(percentile, 1, 1, false, 1, -2); +#endif + +/** + * Special atan2 op impl for TF's args order + * @tparam T + */ +#if NOT_EXCLUDED(OP_tf_atan2) +DECLARE_BROADCASTABLE_OP(tf_atan2, 0, 0); +#endif + +/** + * Broadcastable pow implementation + * @tparam T + */ +#if NOT_EXCLUDED(OP_Pow) +DECLARE_BROADCASTABLE_OP(Pow, 0, 0); +DECLARE_CUSTOM_OP(Pow_bp, 3, 2, false, 0, 0); +#endif + +/** + * Broadcastable igamma implementation + * + * igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x) + * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } + * gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt } + * @tparam T + */ +#if NOT_EXCLUDED(OP_igamma) +DECLARE_BROADCASTABLE_OP(igamma, 0, 0); +#endif +/** + * Broadcastable igammac implementation + * igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x) + * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } + * Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt } + * @tparam T + */ +#if NOT_EXCLUDED(OP_igammac) +DECLARE_BROADCASTABLE_OP(igammac, 0, 0); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/common.h b/libnd4j/include/ops/declarable/headers/common.h index d70e6beb8eef..5c0624a8f97b 100644 --- a/libnd4j/include/ops/declarable/headers/common.h +++ b/libnd4j/include/ops/declarable/headers/common.h @@ -21,22 +21,25 @@ #ifndef LIBND4J_OPS_DECLARABLE_COMMON_H #define LIBND4J_OPS_DECLARABLE_COMMON_H -#include -#include -#include #include +#include #include -#include +#include +#include +#include #include -#include -#include +#include +#include #include #include +#include +#include +#include #include -#include -#include -#include -#include +#include +#include + +#include #include #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/compat.h b/libnd4j/include/ops/declarable/headers/compat.h index 37894517a0b9..1ec0e332b258 100644 --- a/libnd4j/include/ops/declarable/headers/compat.h +++ b/libnd4j/include/ops/declarable/headers/compat.h @@ -24,31 +24,30 @@ #include namespace sd { - namespace ops { - /** - * This operation splits input string into pieces separated by delimiter - * PLEASE NOTE: This implementation is compatible with TF 1.x - * - * Input[0] - string to split - * Input[1] - delimiter - * - * Returns: - * Output[0] - indices tensor - * Output[1] - values tensor - */ - #if NOT_EXCLUDED(OP_compat_string_split) - DECLARE_CUSTOM_OP(compat_string_split, 2, 2, false, 0, 0); - #endif - - /** - * This operation converts TF sparse array representation to dense NDArray - */ - #if NOT_EXCLUDED(OP_compat_sparse_to_dense) - DECLARE_CUSTOM_OP(compat_sparse_to_dense, 4, 1, false, 0, 0); - #endif - - } -} - - -#endif //SAMEDIFF_COMPAT_H +namespace ops { +/** + * This operation splits input string into pieces separated by delimiter + * PLEASE NOTE: This implementation is compatible with TF 1.x + * + * Input[0] - string to split + * Input[1] - delimiter + * + * Returns: + * Output[0] - indices tensor + * Output[1] - values tensor + */ +#if NOT_EXCLUDED(OP_compat_string_split) +DECLARE_CUSTOM_OP(compat_string_split, 2, 2, false, 0, 0); +#endif + +/** + * This operation converts TF sparse array representation to dense NDArray + */ +#if NOT_EXCLUDED(OP_compat_sparse_to_dense) +DECLARE_CUSTOM_OP(compat_sparse_to_dense, 4, 1, false, 0, 0); +#endif + +} // namespace ops +} // namespace sd + +#endif // SAMEDIFF_COMPAT_H diff --git a/libnd4j/include/ops/declarable/headers/compression.h b/libnd4j/include/ops/declarable/headers/compression.h index 9c177f8a4af3..3ffca17c3a30 100644 --- a/libnd4j/include/ops/declarable/headers/compression.h +++ b/libnd4j/include/ops/declarable/headers/compression.h @@ -24,39 +24,40 @@ #include namespace sd { - namespace ops { - - /** - * encode_bitmap - reinterpret 3D float tensor into uint8_t vector with length N. - * - * Input: - * 0 - 3D float tensor with shape {height, width, channels} - * - * Output: - * 0 - 1D uint8_t tensor with shape {N} - */ - #if NOT_EXCLUDED(OP_encode_bitmap) - DECLARE_CUSTOM_OP(encode_bitmap, 1, 3, true, 1, 0); - #endif +namespace ops { - /** - * decode_bitmap - reinterpret uint8_t linear tensor as data to float tensor with shape - * - * Input: - * 0 - uint8_t vector with length N ( shape {N}) - * - * Output: - * 0 - 3D tensor with shape {height, width, channels} - * - */ - #if NOT_EXCLUDED(OP_decode_bitmap) - DECLARE_CUSTOM_OP(decode_bitmap, 2, 1, true, 0, 0); - #endif +/** + * encode_bitmap - reinterpret 3D float tensor into uint8_t vector with length + * N. + * + * Input: + * 0 - 3D float tensor with shape {height, width, channels} + * + * Output: + * 0 - 1D uint8_t tensor with shape {N} + */ +#if NOT_EXCLUDED(OP_encode_bitmap) +DECLARE_CUSTOM_OP(encode_bitmap, 1, 3, true, 1, 0); +#endif +/** + * decode_bitmap - reinterpret uint8_t linear tensor as data to float tensor + * with shape + * + * Input: + * 0 - uint8_t vector with length N ( shape {N}) + * + * Output: + * 0 - 3D tensor with shape {height, width, channels} + * + */ +#if NOT_EXCLUDED(OP_decode_bitmap) +DECLARE_CUSTOM_OP(decode_bitmap, 2, 1, true, 0, 0); +#endif - DECLARE_CUSTOM_OP(encode_threshold, 2, 1, true, 1, 0); - DECLARE_CUSTOM_OP(decode_threshold, 2, 1, true, 0, 0); - } -} +DECLARE_CUSTOM_OP(encode_threshold, 2, 1, true, 1, 0); +DECLARE_CUSTOM_OP(decode_threshold, 2, 1, true, 0, 0); +} // namespace ops +} // namespace sd -#endif // SD_HEADERS_COMPRESSION_H \ No newline at end of file +#endif // SD_HEADERS_COMPRESSION_H \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/convo.h b/libnd4j/include/ops/declarable/headers/convo.h index d00da07f23d4..b472209b7e28 100644 --- a/libnd4j/include/ops/declarable/headers/convo.h +++ b/libnd4j/include/ops/declarable/headers/convo.h @@ -24,307 +24,306 @@ #include namespace sd { - namespace ops { +namespace ops { - /** - * 1D temporal convolution implementation - * Expected input: - * x: 3D array - * weight: 3D Array - * bias: optional vector - * - * Int args: - * 0: kernel - * 1: stride - * 2: padding - */ - #if NOT_EXCLUDED(OP_conv1d) - DECLARE_CUSTOM_OP(conv1d, 2, 1, false, 0, 5); - DECLARE_CUSTOM_OP(conv1d_bp, 3, 2, false, 0, 5); - #endif - - /** - * 2D convolution implementation - * Expected input: - * x: 4D array - * weight: 4D Array - * bias: optional vector, length of outputChannels - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 1 true, 0 false - * 9: data format: 1 NHWC, 0 NCHW - */ - #if NOT_EXCLUDED(OP_conv2d) - DECLARE_CUSTOM_OP(conv2d, 2, 1, false, 0, 9); - DECLARE_CUSTOM_OP(conv2d_bp, 3, 2, false, 0, 9); - DECLARE_CUSTOM_OP(conv2d_input_bp, 3, 1, false, 0, 9); - #endif - - /** - * Depthwise convolution2d op: - * Expected inputs: - * x: 4D array, NCHW format - * weightsDepth: 4D array, - * weightsPointwise: optional, 4D array - * bias: optional, vector - */ - #if NOT_EXCLUDED(OP_sconv2d) - DECLARE_CUSTOM_OP(sconv2d, 2, 1, false, 0, 9); - DECLARE_CUSTOM_OP(sconv2d_bp, 3, 2, false, 0, 9); - #endif - - /** - * 2D deconvolution implementation - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - */ - #if NOT_EXCLUDED(OP_deconv2d) - DECLARE_CUSTOM_OP(deconv2d, 2, 1, false, 0, 9); - DECLARE_CUSTOM_OP(deconv2d_bp, 3, 2, false, 0, 9); - #endif +/** + * 1D temporal convolution implementation + * Expected input: + * x: 3D array + * weight: 3D Array + * bias: optional vector + * + * Int args: + * 0: kernel + * 1: stride + * 2: padding + */ +#if NOT_EXCLUDED(OP_conv1d) +DECLARE_CUSTOM_OP(conv1d, 2, 1, false, 0, 5); +DECLARE_CUSTOM_OP(conv1d_bp, 3, 2, false, 0, 5); +#endif - /** - * 3D deconvolution implementation - * - * IntArgs: - * 0: filter(kernel) depth - * 1: filter(kernel) height - * 2: filter(kernel) width - * 3: strides depth - * 4: strides height - * 5: strides width - * 6: paddings depth - * 7: paddings height - * 8: paddings width - * 9: dilations depth - * 10: dilations height - * 11: dilations width - * 12: same mode: 0 false, 1 true - * 13: data format (optional): 0-NDHWC, 1-NCDHW, default is 1 - */ +/** + * 2D convolution implementation + * Expected input: + * x: 4D array + * weight: 4D Array + * bias: optional vector, length of outputChannels + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 1 true, 0 false + * 9: data format: 1 NHWC, 0 NCHW + */ +#if NOT_EXCLUDED(OP_conv2d) +DECLARE_CUSTOM_OP(conv2d, 2, 1, false, 0, 9); +DECLARE_CUSTOM_OP(conv2d_bp, 3, 2, false, 0, 9); +DECLARE_CUSTOM_OP(conv2d_input_bp, 3, 1, false, 0, 9); +#endif - #if NOT_EXCLUDED(OP_deconv3d) - DECLARE_CUSTOM_OP(deconv3d, 2, 1, false, 0, 13); - DECLARE_CUSTOM_OP(deconv3d_bp, 3, 2, false, 0, 13); - #endif +/** + * Depthwise convolution2d op: + * Expected inputs: + * x: 4D array, NCHW format + * weightsDepth: 4D array, + * weightsPointwise: optional, 4D array + * bias: optional, vector + */ +#if NOT_EXCLUDED(OP_sconv2d) +DECLARE_CUSTOM_OP(sconv2d, 2, 1, false, 0, 9); +DECLARE_CUSTOM_OP(sconv2d_bp, 3, 2, false, 0, 9); +#endif +/** + * 2D deconvolution implementation + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + */ +#if NOT_EXCLUDED(OP_deconv2d) +DECLARE_CUSTOM_OP(deconv2d, 2, 1, false, 0, 9); +DECLARE_CUSTOM_OP(deconv2d_bp, 3, 2, false, 0, 9); +#endif - /** - * This op implements max pooling for convolution networks. - * Expected Input: 4D array, NCHW format. - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - */ - #if NOT_EXCLUDED(OP_maxpool2d) - DECLARE_CUSTOM_OP(maxpool2d, 1, 1, false, 0, 10); - DECLARE_CUSTOM_OP(maxpool2d_bp, 2, 1, false, 0, 10); - #endif +/** + * 3D deconvolution implementation + * + * IntArgs: + * 0: filter(kernel) depth + * 1: filter(kernel) height + * 2: filter(kernel) width + * 3: strides depth + * 4: strides height + * 5: strides width + * 6: paddings depth + * 7: paddings height + * 8: paddings width + * 9: dilations depth + * 10: dilations height + * 11: dilations width + * 12: same mode: 0 false, 1 true + * 13: data format (optional): 0-NDHWC, 1-NCDHW, default is 1 + */ - /** - * This op implements average pooling for convolution networks. - * Expected Input: 4D array, NCHW format. - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - */ - #if NOT_EXCLUDED(OP_avgpool2d) - DECLARE_CUSTOM_OP(avgpool2d, 1, 1, false, 0, 10); - DECLARE_CUSTOM_OP(avgpool2d_bp, 2, 1, false, 0, 10); - #endif +#if NOT_EXCLUDED(OP_deconv3d) +DECLARE_CUSTOM_OP(deconv3d, 2, 1, false, 0, 13); +DECLARE_CUSTOM_OP(deconv3d_bp, 3, 2, false, 0, 13); +#endif - /** - * This op implements pnorm pooling for convolution networks. - * Expected Input: 4D array, NCHW format. - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - * 9: p for p-norm - */ - #if NOT_EXCLUDED(OP_pnormpool2d) - DECLARE_CUSTOM_OP(pnormpool2d, 1, 1, false, 0, 10); - DECLARE_CUSTOM_OP(pnormpool2d_bp, 2, 1, false, 1, 10); - #endif +/** + * This op implements max pooling for convolution networks. + * Expected Input: 4D array, NCHW format. + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + */ +#if NOT_EXCLUDED(OP_maxpool2d) +DECLARE_CUSTOM_OP(maxpool2d, 1, 1, false, 0, 10); +DECLARE_CUSTOM_OP(maxpool2d_bp, 2, 1, false, 0, 10); +#endif - /** - * This op implements im2col algorithm, widely used in convolution neural networks - * Input: 4D input expected - * - * Int args: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: isSameMode - */ - #if NOT_EXCLUDED(OP_im2col) - DECLARE_CUSTOM_OP(im2col, 1, 1, false, 0, 9); - DECLARE_CUSTOM_OP(im2col_bp, 2, 1, false, 0, 9); - #endif +/** + * This op implements average pooling for convolution networks. + * Expected Input: 4D array, NCHW format. + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + */ +#if NOT_EXCLUDED(OP_avgpool2d) +DECLARE_CUSTOM_OP(avgpool2d, 1, 1, false, 0, 10); +DECLARE_CUSTOM_OP(avgpool2d_bp, 2, 1, false, 0, 10); +#endif - /** - * This op implements col2im algorithm, widely used in convolution neural networks - * Input: 6D input expected (like output of im2col op) - * - * Int args: - * 0: stride height - * 1: stride width - * 2: padding height - * 3: padding width - * 4: image height - * 5: image width - * 6: dilation height - * 7: dilation width - */ - #if NOT_EXCLUDED(OP_col2im) - DECLARE_CUSTOM_OP(col2im, 1, 1, false, 0, 9); - #endif +/** + * This op implements pnorm pooling for convolution networks. + * Expected Input: 4D array, NCHW format. + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + * 9: p for p-norm + */ +#if NOT_EXCLUDED(OP_pnormpool2d) +DECLARE_CUSTOM_OP(pnormpool2d, 1, 1, false, 0, 10); +DECLARE_CUSTOM_OP(pnormpool2d_bp, 2, 1, false, 1, 10); +#endif - /** - * Expected input: 4D array - * - * IntArgs: - * 0: scale factor for rows (height) - * 1: scale factor for columns (width) - * 2: data format: 0 NHWC (default), 1 NCHW - */ - #if NOT_EXCLUDED(OP_upsampling2d) - DECLARE_CUSTOM_OP(upsampling2d, 1, 1, false, 0, 2); - DECLARE_CUSTOM_OP(upsampling2d_bp, 2, 1, false, 0, 0); - #endif +/** + * This op implements im2col algorithm, widely used in convolution neural + * networks Input: 4D input expected + * + * Int args: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: isSameMode + */ +#if NOT_EXCLUDED(OP_im2col) +DECLARE_CUSTOM_OP(im2col, 1, 1, false, 0, 9); +DECLARE_CUSTOM_OP(im2col_bp, 2, 1, false, 0, 9); +#endif - /** - * Expected input: 4D array - * - * IntArgs: - * 0: scale factor for depth - * 1: scale factor for rows (height) - * 2: scale factor for columns (width) - * 3: data format: 0 NDHWC (default), 1 NCDHW - */ - #if NOT_EXCLUDED(OP_upsampling3d) - DECLARE_CUSTOM_OP(upsampling3d, 1, 1, false, 0, 3); - DECLARE_CUSTOM_OP(upsampling3d_bp, 2, 1, false, 0, 0); - #endif +/** + * This op implements col2im algorithm, widely used in convolution neural + * networks Input: 6D input expected (like output of im2col op) + * + * Int args: + * 0: stride height + * 1: stride width + * 2: padding height + * 3: padding width + * 4: image height + * 5: image width + * 6: dilation height + * 7: dilation width + */ +#if NOT_EXCLUDED(OP_col2im) +DECLARE_CUSTOM_OP(col2im, 1, 1, false, 0, 9); +#endif - /** - * This op produces binary matrix wrt to target dimension. - * Maximum value within each TAD is replaced with 1, other values are set to true. - * - * Int args: - * 0: axis - */ - #if NOT_EXCLUDED(OP_ismax) - DECLARE_CONFIGURABLE_OP(ismax, 1, 1, true, 0, -2); - #endif +/** + * Expected input: 4D array + * + * IntArgs: + * 0: scale factor for rows (height) + * 1: scale factor for columns (width) + * 2: data format: 0 NHWC (default), 1 NCHW + */ +#if NOT_EXCLUDED(OP_upsampling2d) +DECLARE_CUSTOM_OP(upsampling2d, 1, 1, false, 0, 2); +DECLARE_CUSTOM_OP(upsampling2d_bp, 2, 1, false, 0, 0); +#endif - /** - * Dilation2D op - * - * Int args: - * 0: isSameMode - */ - #if NOT_EXCLUDED(OP_dilation2d) - DECLARE_CUSTOM_OP(dilation2d, 2, 1, false, 0, 1); - #endif +/** + * Expected input: 4D array + * + * IntArgs: + * 0: scale factor for depth + * 1: scale factor for rows (height) + * 2: scale factor for columns (width) + * 3: data format: 0 NDHWC (default), 1 NCDHW + */ +#if NOT_EXCLUDED(OP_upsampling3d) +DECLARE_CUSTOM_OP(upsampling3d, 1, 1, false, 0, 3); +DECLARE_CUSTOM_OP(upsampling3d_bp, 2, 1, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_conv3dnew) - DECLARE_CUSTOM_OP(conv3dnew, 2, 1, false, 0, 13); - DECLARE_CUSTOM_OP(conv3dnew_bp, 3, 2, false, 0, 13); - #endif +/** + * This op produces binary matrix wrt to target dimension. + * Maximum value within each TAD is replaced with 1, other values are set to + * true. + * + * Int args: + * 0: axis + */ +#if NOT_EXCLUDED(OP_ismax) +DECLARE_CONFIGURABLE_OP(ismax, 1, 1, true, 0, -2); +#endif - #if NOT_EXCLUDED(OP_avgpool3dnew) - DECLARE_CUSTOM_OP(avgpool3dnew, 1, 1, false, 0, 10); - DECLARE_CUSTOM_OP(avgpool3dnew_bp, 2, 1, false, 0, 10); - #endif +/** + * Dilation2D op + * + * Int args: + * 0: isSameMode + */ +#if NOT_EXCLUDED(OP_dilation2d) +DECLARE_CUSTOM_OP(dilation2d, 2, 1, false, 0, 1); +#endif - #if NOT_EXCLUDED(OP_maxpool3dnew) - DECLARE_CUSTOM_OP(maxpool3dnew, 1, 1, false, 0, 10); - DECLARE_CUSTOM_OP(maxpool3dnew_bp, 2, 1, false, 0, 10); - #endif +#if NOT_EXCLUDED(OP_conv3dnew) +DECLARE_CUSTOM_OP(conv3dnew, 2, 1, false, 0, 13); +DECLARE_CUSTOM_OP(conv3dnew_bp, 3, 2, false, 0, 13); +#endif - /** - * This op same as maxpool2d with a variant to return a matrix of indexes for max values - * - * Input - 4D tensor - * Output: - * 0 - 4D tensor as input - * 1 - 4D tensor with max value indexes - * - * Int params: - * 9 int with 2x4 vectors and 1 bool value - */ - #if NOT_EXCLUDED(OP_max_pool_woth_argmax) - DECLARE_CUSTOM_OP(max_pool_with_argmax, 1, 2, false, 0, 9); - #endif +#if NOT_EXCLUDED(OP_avgpool3dnew) +DECLARE_CUSTOM_OP(avgpool3dnew, 1, 1, false, 0, 10); +DECLARE_CUSTOM_OP(avgpool3dnew_bp, 2, 1, false, 0, 10); +#endif +#if NOT_EXCLUDED(OP_maxpool3dnew) +DECLARE_CUSTOM_OP(maxpool3dnew, 1, 1, false, 0, 10); +DECLARE_CUSTOM_OP(maxpool3dnew_bp, 2, 1, false, 0, 10); +#endif - #if NOT_EXCLUDED(OP_depthwise_conv2d) - DECLARE_CUSTOM_OP(depthwise_conv2d, 2, 1, false, 0, 9); - DECLARE_CUSTOM_OP(depthwise_conv2d_bp, 3, 2, false, 0, 9); - #endif +/** + * This op same as maxpool2d with a variant to return a matrix of indexes for + * max values + * + * Input - 4D tensor + * Output: + * 0 - 4D tensor as input + * 1 - 4D tensor with max value indexes + * + * Int params: + * 9 int with 2x4 vectors and 1 bool value + */ +#if NOT_EXCLUDED(OP_max_pool_woth_argmax) +DECLARE_CUSTOM_OP(max_pool_with_argmax, 1, 2, false, 0, 9); +#endif - /** - * point-wise 2D convolution - * Expected input: - * x: 4D array - * weight: 4D Array [1, 1, iC, oC] (NHWC) or [oC, iC, 1, 1] (NCHW) - * bias: optional vector, length of oC - * - * IntArgs: - * 0: data format: 1 NHWC, 0 NCHW (optional, by default = NHWC) - */ - DECLARE_CUSTOM_OP(pointwise_conv2d, 2, 1, false, 0, 0); +#if NOT_EXCLUDED(OP_depthwise_conv2d) +DECLARE_CUSTOM_OP(depthwise_conv2d, 2, 1, false, 0, 9); +DECLARE_CUSTOM_OP(depthwise_conv2d_bp, 3, 2, false, 0, 9); +#endif - DECLARE_CUSTOM_OP(deconv2d_tf, 2, 1, false, 0, 0); +/** + * point-wise 2D convolution + * Expected input: + * x: 4D array + * weight: 4D Array [1, 1, iC, oC] (NHWC) or [oC, iC, 1, 1] (NCHW) + * bias: optional vector, length of oC + * + * IntArgs: + * 0: data format: 1 NHWC, 0 NCHW (optional, by default = NHWC) + */ +DECLARE_CUSTOM_OP(pointwise_conv2d, 2, 1, false, 0, 0); - } -} +DECLARE_CUSTOM_OP(deconv2d_tf, 2, 1, false, 0, 0); +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/datatypes.h b/libnd4j/include/ops/declarable/headers/datatypes.h index e467539199c7..643784ce5e13 100644 --- a/libnd4j/include/ops/declarable/headers/datatypes.h +++ b/libnd4j/include/ops/declarable/headers/datatypes.h @@ -23,91 +23,94 @@ #include namespace sd { - namespace ops { - /** - * This operation casts elements of input array to double data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - #if NOT_EXCLUDED(OP_to_double) - DECLARE_CUSTOM_OP(to_double, 1, 1, true, 0, 0); - #endif +namespace ops { +/** + * This operation casts elements of input array to double data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +#if NOT_EXCLUDED(OP_to_double) +DECLARE_CUSTOM_OP(to_double, 1, 1, true, 0, 0); +#endif - /** - * This operation casts elements of input array to float16 data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - #if NOT_EXCLUDED(OP_to_float16) - DECLARE_CUSTOM_OP(to_float16, 1, 1, true, 0, 0); - #endif +/** + * This operation casts elements of input array to float16 data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +#if NOT_EXCLUDED(OP_to_float16) +DECLARE_CUSTOM_OP(to_float16, 1, 1, true, 0, 0); +#endif - /** - * This operation casts elements of input array to float data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - #if NOT_EXCLUDED(OP_to_float32) - DECLARE_CUSTOM_OP(to_float32, 1, 1, true, 0, 0); - #endif +/** + * This operation casts elements of input array to float data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +#if NOT_EXCLUDED(OP_to_float32) +DECLARE_CUSTOM_OP(to_float32, 1, 1, true, 0, 0); +#endif - /** - * This operation casts elements of input array to int32 data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - #if NOT_EXCLUDED(OP_to_int32) - DECLARE_CUSTOM_OP(to_int32, 1, 1, true, 0, 0); - #endif +/** + * This operation casts elements of input array to int32 data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +#if NOT_EXCLUDED(OP_to_int32) +DECLARE_CUSTOM_OP(to_int32, 1, 1, true, 0, 0); +#endif - /** - * This operation casts elements of input array to int64 (aka long long) data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - #if NOT_EXCLUDED(OP_to_int64) - DECLARE_CUSTOM_OP(to_int64, 1, 1, true, 0, 0); - #endif +/** + * This operation casts elements of input array to int64 (aka long long) data + * type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +#if NOT_EXCLUDED(OP_to_int64) +DECLARE_CUSTOM_OP(to_int64, 1, 1, true, 0, 0); +#endif - /** - * This operation casts elements of input array to unsinged int32 data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - #if NOT_EXCLUDED(OP_to_uint32) - DECLARE_CUSTOM_OP(to_uint32, 1, 1, true, 0, 0); - #endif +/** + * This operation casts elements of input array to unsinged int32 data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +#if NOT_EXCLUDED(OP_to_uint32) +DECLARE_CUSTOM_OP(to_uint32, 1, 1, true, 0, 0); +#endif - /** - * This operation casts elements of input array to unsigned int64 (aka unsigned long long) data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ - #if NOT_EXCLUDED(OP_to_uint64) - DECLARE_CUSTOM_OP(to_uint64, 1, 1, true, 0, 0); - #endif +/** + * This operation casts elements of input array to unsigned int64 (aka unsigned + * long long) data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +#if NOT_EXCLUDED(OP_to_uint64) +DECLARE_CUSTOM_OP(to_uint64, 1, 1, true, 0, 0); +#endif - /** - * This operation casts elements of input array to specified data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - * - * - * Int args: - * 0: target DataType - */ - #if NOT_EXCLUDED(OP_cast) - DECLARE_CUSTOM_OP(cast, 1, 1, false, 0, 1); - #endif - /** - * This operation change type of input and modified shape of output to conform with given data type - * - * all as above op - * */ - #if NOT_EXCLUDED(OP_bitcast) - DECLARE_CUSTOM_OP(bitcast, 1, 1, false, 0, 1); - #endif - } -} +/** + * This operation casts elements of input array to specified data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + * + * + * Int args: + * 0: target DataType + */ +#if NOT_EXCLUDED(OP_cast) +DECLARE_CUSTOM_OP(cast, 1, 1, false, 0, 1); +#endif +/** + * This operation change type of input and modified shape of output to conform + * with given data type + * + * all as above op + * */ +#if NOT_EXCLUDED(OP_bitcast) +DECLARE_CUSTOM_OP(bitcast, 1, 1, false, 0, 1); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/images.h b/libnd4j/include/ops/declarable/headers/images.h index 41974901a361..7e69e5f71991 100644 --- a/libnd4j/include/ops/declarable/headers/images.h +++ b/libnd4j/include/ops/declarable/headers/images.h @@ -16,7 +16,7 @@ // // @author Oleh Semeniv (oleg.semeniv@gmail.com) -// +// // // @author AbdelRauf (rauf@konduit.ai) // @@ -24,90 +24,89 @@ #ifndef LIBND4J_HEADERS_IMAGES_H #define LIBND4J_HEADERS_IMAGES_H -#include -#include -#include #include +#include +#include +#include #include namespace sd { namespace ops { - /** * Rgb To Hsv * Input arrays: - * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels + * 0 - input array with rank >= 1, must have at least one dimension equal 3, + * that is dimension containing channels. Int arguments: 0 - optional argument, + * corresponds to dimension with 3 channels */ #if NOT_EXCLUDED(OP_rgb_to_hsv) - DECLARE_CONFIGURABLE_OP(rgb_to_hsv, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(rgb_to_hsv, 1, 1, true, 0, 0); #endif /** * Hsv To Rgb * Input arrays: - * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels + * 0 - input array with rank >= 1, must have at least one dimension equal 3, + * that is dimension containing channels. Int arguments: 0 - optional argument, + * corresponds to dimension with 3 channels */ #if NOT_EXCLUDED(OP_hsv_to_rgb) - DECLARE_CONFIGURABLE_OP(hsv_to_rgb, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(hsv_to_rgb, 1, 1, true, 0, 0); #endif /** -* Rgb To GrayScale -* Input arrays: -* 0 - input array with rank >= 1, the RGB tensor to convert. Last dimension must have size 3 and should contain RGB values. -*/ + * Rgb To GrayScale + * Input arrays: + * 0 - input array with rank >= 1, the RGB tensor to convert. Last dimension + * must have size 3 and should contain RGB values. + */ #if NOT_EXCLUDED(OP_rgb_to_grs) - DECLARE_CUSTOM_OP(rgb_to_grs, 1, 1, false, 0, 0); +DECLARE_CUSTOM_OP(rgb_to_grs, 1, 1, false, 0, 0); #endif - /** - * Rgb To Yuv - * Input arrays: - * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels - */ +/** + * Rgb To Yuv + * Input arrays: + * 0 - input array with rank >= 1, must have at least one dimension equal 3, + * that is dimension containing channels. Int arguments: 0 - optional argument, + * corresponds to dimension with 3 channels + */ #if NOT_EXCLUDED(OP_rgb_to_yuv) - DECLARE_CONFIGURABLE_OP(rgb_to_yuv, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(rgb_to_yuv, 1, 1, true, 0, 0); #endif - /** - * Yuv To Rgb - * Input arrays: - * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels - */ +/** + * Yuv To Rgb + * Input arrays: + * 0 - input array with rank >= 1, must have at least one dimension equal 3, + * that is dimension containing channels. Int arguments: 0 - optional argument, + * corresponds to dimension with 3 channels + */ #if NOT_EXCLUDED(OP_rgb_to_yuv) - DECLARE_CONFIGURABLE_OP(yuv_to_rgb, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(yuv_to_rgb, 1, 1, true, 0, 0); /** -* Rgb To Yiq -* Input arrays: -* 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. -* Int arguments: -* 0 - optional argument, corresponds to dimension with 3 channels -*/ + * Rgb To Yiq + * Input arrays: + * 0 - input array with rank >= 1, must have at least one dimension equal 3, + * that is dimension containing channels. Int arguments: 0 - optional argument, + * corresponds to dimension with 3 channels + */ #if NOT_EXCLUDED(OP_rgb_to_yiq) - DECLARE_CONFIGURABLE_OP(rgb_to_yiq, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(rgb_to_yiq, 1, 1, true, 0, 0); #endif /** -* Yiq To Rgb -* Input arrays: -* 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. -* Int arguments: -* 0 - optional argument, corresponds to dimension with 3 channels -*/ + * Yiq To Rgb + * Input arrays: + * 0 - input array with rank >= 1, must have at least one dimension equal 3, + * that is dimension containing channels. Int arguments: 0 - optional argument, + * corresponds to dimension with 3 channels + */ #if NOT_EXCLUDED(OP_yiq_to_rgb) - DECLARE_CONFIGURABLE_OP(yiq_to_rgb, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(yiq_to_rgb, 1, 1, true, 0, 0); #endif - } } diff --git a/libnd4j/include/ops/declarable/headers/kernels.h b/libnd4j/include/ops/declarable/headers/kernels.h index c4cc02cb59ae..dbcdf20f7113 100644 --- a/libnd4j/include/ops/declarable/headers/kernels.h +++ b/libnd4j/include/ops/declarable/headers/kernels.h @@ -24,11 +24,11 @@ #include namespace sd { - namespace ops { - #if NOT_EXCLUDED(OP_knn_mindistance) - DECLARE_CUSTOM_OP(knn_mindistance, 3, 1, false, 0, 0); - #endif - } -} +namespace ops { +#if NOT_EXCLUDED(OP_knn_mindistance) +DECLARE_CUSTOM_OP(knn_mindistance, 3, 1, false, 0, 0); +#endif +} // namespace ops +} // namespace sd -#endif //LIBND4J_KERNELS_H +#endif // LIBND4J_KERNELS_H diff --git a/libnd4j/include/ops/declarable/headers/list.h b/libnd4j/include/ops/declarable/headers/list.h index af4fb5706e95..08703f61023b 100644 --- a/libnd4j/include/ops/declarable/headers/list.h +++ b/libnd4j/include/ops/declarable/headers/list.h @@ -24,108 +24,105 @@ #include namespace sd { - namespace ops { - // list operations, basically all around NDArrayList - - /** - * This operations puts given NDArray into (optionally) given NDArrayList. - * If no NDArrayList was provided - new one will be created - */ - #if NOT_EXCLUDED(OP_write_list) - DECLARE_LIST_OP(write_list, 2, 1, 0, -2); - #endif - - /** - * This operation concatenates given NDArrayList, and returns NDArray as result - */ - #if NOT_EXCLUDED(OP_stack_list) - DECLARE_LIST_OP(stack_list, 1, 1, 0, 0); - #endif - - /** - * This operations selects specified index fron NDArrayList and returns it as NDArray - * Expected arguments: - * x: non-empty list - * indices: optional, scalar with index - * - * Int args: - * optional, index - */ - #if NOT_EXCLUDED(OP_read_list) - DECLARE_LIST_OP(read_list, 1, 1, 0, 0); - #endif - - /** - * This operations selects specified indices fron NDArrayList and returns them as NDArray - * Expected arguments: - * x: non-empty list - * indices: optional, vector with indices - * - * Int args: - * optional, indices - */ - #if NOT_EXCLUDED(OP_pick_list) - DECLARE_LIST_OP(pick_list, 1, 1, -2, -2); - #endif - - /** - * This operations returns scalar, with number of existing arrays within given NDArrayList - * Expected arguments: - * x: list - */ - #if NOT_EXCLUDED(OP_size_list) - DECLARE_LIST_OP(size_list, 1, 1, 0, 0); - #endif - - /** - * This operation creates new empty NDArrayList - */ - #if NOT_EXCLUDED(OP_create_list) - DECLARE_LIST_OP(create_list, 1, 2, 0, -2); - #endif - - /** - * This operation unpacks given NDArray into specified NDArrayList wrt specified indices - */ - #if NOT_EXCLUDED(OP_scatter_list) - DECLARE_LIST_OP(scatter_list, 1, 1, 0, -2); - #endif - - /** - * This operation splits given NDArray into chunks, and stores them into given NDArrayList wert sizes - * Expected arguments: - * list: optional, NDArrayList. if not available - new NDArrayList will be created - * array: array to be split - * sizes: vector with sizes for each chunk - */ - #if NOT_EXCLUDED(OP_split_list) - DECLARE_LIST_OP(split_list, 2, 1, 0, -2); - #endif - - /** - * This operation builds NDArray from NDArrayList using indices - * Expected arguments: - * x: non-empty list - * indices: vector with indices for gather operation - */ - #if NOT_EXCLUDED(OP_gather_list) - DECLARE_LIST_OP(gather_list, 2, 1, 0, -2); - #endif - - /** - * This operation clones given NDArrayList - */ - #if NOT_EXCLUDED(OP_clone_list) - DECLARE_LIST_OP(clone_list, 1, 1, 0, 0); - #endif - - /** - * This operation unstacks given NDArray into NDArrayList by the first dimension - */ - #if NOT_EXCLUDED(OP_unstack_list) - DECLARE_LIST_OP(unstack_list, 1, 1, 0, 0); - #endif - } -} +namespace ops { +// list operations, basically all around NDArrayList + +/** + * This operations puts given NDArray into (optionally) given NDArrayList. + * If no NDArrayList was provided - new one will be created + */ +#if NOT_EXCLUDED(OP_write_list) +DECLARE_LIST_OP(write_list, 2, 1, 0, -2); +#endif + +/** + * This operation concatenates given NDArrayList, and returns NDArray as result + */ +#if NOT_EXCLUDED(OP_stack_list) +DECLARE_LIST_OP(stack_list, 1, 1, 0, 0); +#endif + +/** + * This operations selects specified index fron NDArrayList and returns it as + * NDArray Expected arguments: x: non-empty list indices: optional, scalar with + * index + * + * Int args: + * optional, index + */ +#if NOT_EXCLUDED(OP_read_list) +DECLARE_LIST_OP(read_list, 1, 1, 0, 0); +#endif + +/** + * This operations selects specified indices fron NDArrayList and returns them + * as NDArray Expected arguments: x: non-empty list indices: optional, vector + * with indices + * + * Int args: + * optional, indices + */ +#if NOT_EXCLUDED(OP_pick_list) +DECLARE_LIST_OP(pick_list, 1, 1, -2, -2); +#endif + +/** + * This operations returns scalar, with number of existing arrays within given + * NDArrayList Expected arguments: x: list + */ +#if NOT_EXCLUDED(OP_size_list) +DECLARE_LIST_OP(size_list, 1, 1, 0, 0); +#endif + +/** + * This operation creates new empty NDArrayList + */ +#if NOT_EXCLUDED(OP_create_list) +DECLARE_LIST_OP(create_list, 1, 2, 0, -2); +#endif + +/** + * This operation unpacks given NDArray into specified NDArrayList wrt specified + * indices + */ +#if NOT_EXCLUDED(OP_scatter_list) +DECLARE_LIST_OP(scatter_list, 1, 1, 0, -2); +#endif + +/** + * This operation splits given NDArray into chunks, and stores them into given + * NDArrayList wert sizes Expected arguments: list: optional, NDArrayList. if + * not available - new NDArrayList will be created array: array to be split + * sizes: vector with sizes for each chunk + */ +#if NOT_EXCLUDED(OP_split_list) +DECLARE_LIST_OP(split_list, 2, 1, 0, -2); +#endif + +/** + * This operation builds NDArray from NDArrayList using indices + * Expected arguments: + * x: non-empty list + * indices: vector with indices for gather operation + */ +#if NOT_EXCLUDED(OP_gather_list) +DECLARE_LIST_OP(gather_list, 2, 1, 0, -2); +#endif + +/** + * This operation clones given NDArrayList + */ +#if NOT_EXCLUDED(OP_clone_list) +DECLARE_LIST_OP(clone_list, 1, 1, 0, 0); +#endif + +/** + * This operation unstacks given NDArray into NDArrayList by the first dimension + */ +#if NOT_EXCLUDED(OP_unstack_list) +DECLARE_LIST_OP(unstack_list, 1, 1, 0, 0); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/loss.h b/libnd4j/include/ops/declarable/headers/loss.h index 3f36250404f9..bc8fa22183f9 100644 --- a/libnd4j/include/ops/declarable/headers/loss.h +++ b/libnd4j/include/ops/declarable/headers/loss.h @@ -25,341 +25,383 @@ namespace sd { namespace ops { - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of hinge loss function max(0, 1 - labels*logits) - * - * Input arrays: - * 0: logits - logits, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as logits or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_hinge_loss) - DECLARE_CUSTOM_OP(hinge_loss, 3, 1, false, 0, 1); - DECLARE_CUSTOM_OP(hinge_loss_grad, 3, 3, false, 0, 1); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of hinge loss function max(0, 1 - labels*logits) + * + * Input arrays: + * 0: logits - logits, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as logits or just single scalar, + * depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_hinge_loss) +DECLARE_CUSTOM_OP(hinge_loss, 3, 1, false, 0, 1); +DECLARE_CUSTOM_OP(hinge_loss_grad, 3, 3, false, 0, 1); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of Huber loss function: - * 0.5 * (labels-predictions)^2 if |labels-predictions| <= delta - * 0.5 * delta^2 + delta * (|labels-predictions| - delta) if |labels-predictions| > delta - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: point where the huber loss function changes from a quadratic to linear. - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_huber_loss) - DECLARE_CUSTOM_OP(huber_loss, 3, 1, false, 1, 1); - DECLARE_CUSTOM_OP(huber_loss_grad, 3, 1, false, 1, 1); - #endif - - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of logarithmic loss function ( y_i * log(p_i) + (1 - y_i) * log(1 - p_i) ) - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: a small increment to add to avoid taking a log of zero. - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_log_loss) - DECLARE_CUSTOM_OP(log_loss, 3, 1, false, 1, 1); - DECLARE_CUSTOM_OP(log_loss_grad, 3, 3, false, 1, 1); - #endif - - /** - * l2_loss op. - * compute a l2 norm for given array. - * - * input param - an array (tensor) - * output value - a real number with given type (e.g. float or double) - */ - #if NOT_EXCLUDED(OP_l2_loss) - DECLARE_CUSTOM_OP(l2_loss, 1, 1, false, 0, 0); - #endif - - - /** - * This op calculates logarithmic loss of poisson distributed input. - * Input arrays: - * 0: log_predictions - must be already pre-transformed to log(x) - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * 1: optional - boolean value compute_full_loss: 0 (default) or 1 (compute) - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as log_predictions or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_log_poisson_loss) - DECLARE_CUSTOM_OP(log_poisson_loss, 3, 1, true, 0, 1); - DECLARE_CUSTOM_OP(log_poisson_loss_grad, 3, 3, true, 0, 1); - #endif - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of pairwise-errors-squared loss function - * - * Input arrays: - * 0: predictions - the predicted values, type float. - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Output array: - * 0: loss value, it is just single scalar, type float. - */ - #if NOT_EXCLUDED(OP_mean_pairwssqerr_loss) - DECLARE_CUSTOM_OP(mean_pairwssqerr_loss, 3, 1, false, 0, 0); - DECLARE_CUSTOM_OP(mean_pairwssqerr_loss_grad, 3, 3, false, 0, 0); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of Huber loss function: + * 0.5 * (labels-predictions)^2 if + * |labels-predictions| <= delta 0.5 * delta^2 + delta * (|labels-predictions| - + * delta) if |labels-predictions| > delta + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Input float arguments: + * 0: point where the huber loss function changes from a quadratic to linear. + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_huber_loss) +DECLARE_CUSTOM_OP(huber_loss, 3, 1, false, 1, 1); +DECLARE_CUSTOM_OP(huber_loss_grad, 3, 1, false, 1, 1); +#endif - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of Sum-of-Squares loss function 1/N * sum_{i}^{N}(predictions_i - labels_i)^2 - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_mean_sqerr_loss) - DECLARE_CUSTOM_OP(mean_sqerr_loss, 3, 1, false, 0, 1); - DECLARE_CUSTOM_OP(mean_sqerr_loss_grad, 3, 3, false, 0, 1); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of logarithmic loss function ( y_i * log(p_i) + (1 - y_i) * + * log(1 - p_i) ) + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Input float arguments: + * 0: a small increment to add to avoid taking a log of zero. + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_log_loss) +DECLARE_CUSTOM_OP(log_loss, 3, 1, false, 1, 1); +DECLARE_CUSTOM_OP(log_loss_grad, 3, 3, false, 1, 1); +#endif +/** + * l2_loss op. + * compute a l2 norm for given array. + * + * input param - an array (tensor) + * output value - a real number with given type (e.g. float or double) + */ +#if NOT_EXCLUDED(OP_l2_loss) +DECLARE_CUSTOM_OP(l2_loss, 1, 1, false, 0, 0); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of sigmoid cross-entropy loss function max(logits, 0.) - logits * labels + log(1. + exp(-abs(logits))); - * - * Input arrays: - * 0: logits - logits, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: smoothing value, if it is greater than 0 then apply smoothing to the labels (smooth the labels towards 1/2): new_labels = labels * (1 - labelsSmoothing)+ 0.5 * labelsSmoothing - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as logits or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_sigm_cross_entropy_loss) - DECLARE_CUSTOM_OP(sigm_cross_entropy_loss, 3, 1, false, 1, 1); - DECLARE_CUSTOM_OP(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1); - #endif - +/** + * This op calculates logarithmic loss of poisson distributed input. + * Input arrays: + * 0: log_predictions - must be already pre-transformed to log(x) + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights 1: optional - boolean value compute_full_loss: 0 + * (default) or 1 (compute) + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as log_predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_log_poisson_loss) +DECLARE_CUSTOM_OP(log_poisson_loss, 3, 1, true, 0, 1); +DECLARE_CUSTOM_OP(log_poisson_loss_grad, 3, 3, true, 0, 1); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of softmax cross-entropy loss function max(logits, 0.) - logits * labels + log(1. + exp(-abs(logits))); - * - * Input arrays: - * 0: logits - logits, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: smoothing value, if it is greater than 0 then apply smoothing to the labels (smooth the labels towards 1/numClasses): new_labels = labels * (1 - labelsSmoothing) + labelsSmoothing / numClasses - * - * Output array: - * 0: loss values, type float. - * Can be an array with shape as in logits except last dimension is equal to unity or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_softmax_cross_entropy_loss) - DECLARE_CUSTOM_OP(softmax_cross_entropy_loss, 3, 1, false, 1, 1); - DECLARE_CUSTOM_OP(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of pairwise-errors-squared loss function + * + * Input arrays: + * 0: predictions - the predicted values, type float. + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Output array: + * 0: loss value, it is just single scalar, type float. + */ +#if NOT_EXCLUDED(OP_mean_pairwssqerr_loss) +DECLARE_CUSTOM_OP(mean_pairwssqerr_loss, 3, 1, false, 0, 0); +DECLARE_CUSTOM_OP(mean_pairwssqerr_loss_grad, 3, 3, false, 0, 0); +#endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of Sum-of-Squares loss function 1/N * + * sum_{i}^{N}(predictions_i - labels_i)^2 + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_mean_sqerr_loss) +DECLARE_CUSTOM_OP(mean_sqerr_loss, 3, 1, false, 0, 1); +DECLARE_CUSTOM_OP(mean_sqerr_loss_grad, 3, 3, false, 0, 1); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of Absolute Difference loss function |predictions - labels| - * - * Input arrays: - * 0: predictions - the predicted values, type float. - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_absolute_difference_loss) - DECLARE_CUSTOM_OP(absolute_difference_loss, 3, 1, false, 0, 1); - DECLARE_CUSTOM_OP(absolute_difference_loss_grad, 3, 3, false, 0, 1); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of sigmoid cross-entropy loss function max(logits, 0.) - + * logits * labels + log(1. + exp(-abs(logits))); + * + * Input arrays: + * 0: logits - logits, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights + * + * Input float arguments: + * 0: smoothing value, if it is greater than 0 then apply smoothing to the + * labels (smooth the labels towards 1/2): new_labels = labels * (1 - + * labelsSmoothing)+ 0.5 * labelsSmoothing + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as logits or just single scalar, + * depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_sigm_cross_entropy_loss) +DECLARE_CUSTOM_OP(sigm_cross_entropy_loss, 3, 1, false, 1, 1); +DECLARE_CUSTOM_OP(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1); +#endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of softmax cross-entropy loss function max(logits, 0.) - + * logits * labels + log(1. + exp(-abs(logits))); + * + * Input arrays: + * 0: logits - logits, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights + * + * Input float arguments: + * 0: smoothing value, if it is greater than 0 then apply smoothing to the + * labels (smooth the labels towards 1/numClasses): new_labels = labels * (1 - + * labelsSmoothing) + labelsSmoothing / numClasses + * + * Output array: + * 0: loss values, type float. + * Can be an array with shape as in logits except last dimension is equal + * to unity or just single scalar, depending on reduction mode (see input + * integer argument) + */ +#if NOT_EXCLUDED(OP_softmax_cross_entropy_loss) +DECLARE_CUSTOM_OP(softmax_cross_entropy_loss, 3, 1, false, 1, 1); +DECLARE_CUSTOM_OP(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of cosine-distance loss function 1. - (predictions * labels).reduce_sum_along(dimension) - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * 1: dimension along which the cosine distance is computed. - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ - #if NOT_EXCLUDED(OP_cosine_distance_loss) - DECLARE_CUSTOM_OP(cosine_distance_loss, 3, 1, false, 0, 2); - DECLARE_CUSTOM_OP(cosine_distance_loss_grad, 3, 3, false, 0, 2); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of Absolute Difference loss function |predictions - labels| + * + * Input arrays: + * 0: predictions - the predicted values, type float. + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_absolute_difference_loss) +DECLARE_CUSTOM_OP(absolute_difference_loss, 3, 1, false, 0, 1); +DECLARE_CUSTOM_OP(absolute_difference_loss_grad, 3, 3, false, 0, 1); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of softmax cross-entropy loss function - * - * Input arrays: - * 0: logits - logits, type float - * 1: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: optional (default is last dimension) dimension with classes - * - * Output array: - * 0: loss values, type float. An array with shape resulting from reducing of logits shape along dimension with classes - */ - #if NOT_EXCLUDED(OP_softmax_cross_entropy_loss_with_logits) - DECLARE_CUSTOM_OP(softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0); - DECLARE_CUSTOM_OP(softmax_cross_entropy_loss_with_logits_grad, 2, 2, false, 0, 0); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of cosine-distance loss function 1. - (predictions * + * labels).reduce_sum_along(dimension) + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights 1: dimension along which the cosine + * distance is computed. + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +#if NOT_EXCLUDED(OP_cosine_distance_loss) +DECLARE_CUSTOM_OP(cosine_distance_loss, 3, 1, false, 0, 2); +DECLARE_CUSTOM_OP(cosine_distance_loss_grad, 3, 3, false, 0, 2); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of sparse softmax cross-entropy loss function - * - * Input arrays: - * 0: labels - ground truth vales, expected to be within range [0, num_classes), type float. - * Must have rank equal logits rank minus 1. - * 1: logits - logits, type float - * - * Output array: - * 0: loss values, type float. Has the same shape as labels - */ - #if NOT_EXCLUDED(OP_sparse_softmax_cross_entropy_loss_with_logits) - DECLARE_CUSTOM_OP(sparse_softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0); - DECLARE_CUSTOM_OP(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, 0, 0); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of softmax cross-entropy loss function + * + * Input arrays: + * 0: logits - logits, type float + * 1: labels - ground truth vales, expected to be 0. or 1., type float. + * Must have the same shape as logits. + * + * Input integer arguments: + * 0: optional (default is last dimension) dimension with classes + * + * Output array: + * 0: loss values, type float. An array with shape resulting from reducing of + * logits shape along dimension with classes + */ +#if NOT_EXCLUDED(OP_softmax_cross_entropy_loss_with_logits) +DECLARE_CUSTOM_OP(softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0); +DECLARE_CUSTOM_OP(softmax_cross_entropy_loss_with_logits_grad, 2, 2, false, 0, + 0); +#endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of sparse softmax cross-entropy loss function + * + * Input arrays: + * 0: labels - ground truth vales, expected to be within range [0, + * num_classes), type float. Must have rank equal logits rank minus 1. 1: logits + * - logits, type float + * + * Output array: + * 0: loss values, type float. Has the same shape as labels + */ +#if NOT_EXCLUDED(OP_sparse_softmax_cross_entropy_loss_with_logits) +DECLARE_CUSTOM_OP(sparse_softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, + 0); +DECLARE_CUSTOM_OP(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, + false, 0, 0); +#endif -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/nlp.h b/libnd4j/include/ops/declarable/headers/nlp.h index 1eb846365744..ab399fe8abce 100644 --- a/libnd4j/include/ops/declarable/headers/nlp.h +++ b/libnd4j/include/ops/declarable/headers/nlp.h @@ -23,16 +23,16 @@ #include namespace sd { - namespace ops { +namespace ops { - #if NOT_EXCLUDED(OP_skipgram) - DECLARE_CONFIGURABLE_OP(skipgram, 12, 12, true, 0, 0); - #endif +#if NOT_EXCLUDED(OP_skipgram) +DECLARE_CONFIGURABLE_OP(skipgram, 12, 12, true, 0, 0); +#endif - #if NOT_EXCLUDED(OP_cbow) - DECLARE_CONFIGURABLE_OP(cbow, 15, 15, true, 0, 0); - #endif - } -} +#if NOT_EXCLUDED(OP_cbow) +DECLARE_CONFIGURABLE_OP(cbow, 15, 15, true, 0, 0); +#endif +} // namespace ops +} // namespace sd -#endif //SD_NLP_H +#endif // SD_NLP_H diff --git a/libnd4j/include/ops/declarable/headers/nn.h b/libnd4j/include/ops/declarable/headers/nn.h index 26699aa324d4..6b4ad2e3212b 100644 --- a/libnd4j/include/ops/declarable/headers/nn.h +++ b/libnd4j/include/ops/declarable/headers/nn.h @@ -24,236 +24,244 @@ #include namespace sd { - namespace ops { +namespace ops { - #if NOT_EXCLUDED(OP_softmax) - DECLARE_CONFIGURABLE_OP(softmax, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(softmax_bp, 2, 1, true, 0, 0); - #endif +#if NOT_EXCLUDED(OP_softmax) +DECLARE_CONFIGURABLE_OP(softmax, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(softmax_bp, 2, 1, true, 0, 0); +#endif - /** - * Local response normalization implementation as TF. - * input: 4D array - * - * T args: - * - * 0: bias - * 1: alpha - * 2: beta - * - * Int arg: depth - optional local radius - * - * output - 4D array - */ - #if NOT_EXCLUDED(OP_lrn) - DECLARE_CONFIGURABLE_OP(lrn, 1, 1, true, 3, 0); - #endif - - /** - * Local response normalization - backprop variant. - * input: - * 0 - 4D array of data - * 1 - epsilon - 4D array of approximation - * - * T args: - * - * 0: bias - * 1: alpha - * 2: beta - * - * Int arg: depth - optional local radius - * - * output - next approximation as 4D array - */ - #if NOT_EXCLUDED(OP_lrn) - DECLARE_CONFIGURABLE_OP(lrn_bp, 2, 1, true, 3, 0); - #endif - - /** - * Batch normalization implementation. - * Reference: https://arxiv.org/abs/1502.03167v3 - * - * Expected arguments: - * input: input array (any number of dimensions) - * mean: - * variance: - * gamma: - * beta: - * - * Int args: - * 0: apply scale - * 1: apply offset - * - * - * T args: - * 0: epsilon - */ - #if NOT_EXCLUDED(OP_batchnorm) - DECLARE_CUSTOM_OP(batchnorm, 3, 1, false, 1, 2); - #endif - - /** - * back prop in batch normalization - * - * Expected arguments: - * input: input array (any number of dimensions) - * mean: - * variance: - * gamma: optional - * beta: optional - * dLdOut: next epsilon - * - * Int args: - * 0: apply scale - * 1: apply offset - * - * T args: - * 0: epsilon - * - * output arrays: - * dL/dInput - * dL/dMean - * dL/dVariance - * dL/dGamma, optional - * dL/dBeta, optional - */ - #if NOT_EXCLUDED(OP_batchnorm) - DECLARE_CUSTOM_OP(batchnorm_bp, 4, 3, false, 1, 2); - #endif +/** + * Local response normalization implementation as TF. + * input: 4D array + * + * T args: + * + * 0: bias + * 1: alpha + * 2: beta + * + * Int arg: depth - optional local radius + * + * output - 4D array + */ +#if NOT_EXCLUDED(OP_lrn) +DECLARE_CONFIGURABLE_OP(lrn, 1, 1, true, 3, 0); +#endif +/** + * Local response normalization - backprop variant. + * input: + * 0 - 4D array of data + * 1 - epsilon - 4D array of approximation + * + * T args: + * + * 0: bias + * 1: alpha + * 2: beta + * + * Int arg: depth - optional local radius + * + * output - next approximation as 4D array + */ +#if NOT_EXCLUDED(OP_lrn) +DECLARE_CONFIGURABLE_OP(lrn_bp, 2, 1, true, 3, 0); +#endif - /** - * This operation updates parameters with provided gradients, wrt learning rate - * Expected arguments: - * x: parameters, any shape - * y: gradients. same shape as x - * lr: optional, learning rate - * - * T args: - * 0: optional, learning rate - */ - #if NOT_EXCLUDED(OP_apply_sgd) - DECLARE_CONFIGURABLE_OP(apply_sgd, 2, 1, true, -2, 0); - #endif +/** + * Batch normalization implementation. + * Reference: https://arxiv.org/abs/1502.03167v3 + * + * Expected arguments: + * input: input array (any number of dimensions) + * mean: + * variance: + * gamma: + * beta: + * + * Int args: + * 0: apply scale + * 1: apply offset + * + * + * T args: + * 0: epsilon + */ +#if NOT_EXCLUDED(OP_batchnorm) +DECLARE_CUSTOM_OP(batchnorm, 3, 1, false, 1, 2); +#endif - /** - * This operation performs batch normalization of layer, it is based on following article https://arxiv.org/abs/1502.03167. - * Expected arguments: - * x: input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] (data format = NCHW), where - * bS - batch size - * iH - input height - * iW - input width - * iD - input depth (or number of channels) - * scale: 1D input array of scale factors, shape [iD] - * offset: 1D input array of offsets (shifts), shape [iD] - * mean: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false - * variance: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false - * - * T input arguments: - * 0: epsilon, it is optional argument, default value is 0.001, this is small number to be added to the variance of x - * - * integer input arguments: - * 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW - * 1: isTraining, may have two values: zero -> inference, unity -> training - */ - #if NOT_EXCLUDED(OP_fused_batch_norm) - DECLARE_CUSTOM_OP(fused_batch_norm, 3, 1, false, 0, 2); - #endif +/** + * back prop in batch normalization + * + * Expected arguments: + * input: input array (any number of dimensions) + * mean: + * variance: + * gamma: optional + * beta: optional + * dLdOut: next epsilon + * + * Int args: + * 0: apply scale + * 1: apply offset + * + * T args: + * 0: epsilon + * + * output arrays: + * dL/dInput + * dL/dMean + * dL/dVariance + * dL/dGamma, optional + * dL/dBeta, optional + */ +#if NOT_EXCLUDED(OP_batchnorm) +DECLARE_CUSTOM_OP(batchnorm_bp, 4, 3, false, 1, 2); +#endif - #if NOT_EXCLUDED(OP_log_softmax) - DECLARE_CONFIGURABLE_OP(log_softmax, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(log_softmax_bp, 2, 1, true, 0, 0); - #endif +/** + * This operation updates parameters with provided gradients, wrt learning rate + * Expected arguments: + * x: parameters, any shape + * y: gradients. same shape as x + * lr: optional, learning rate + * + * T args: + * 0: optional, learning rate + */ +#if NOT_EXCLUDED(OP_apply_sgd) +DECLARE_CONFIGURABLE_OP(apply_sgd, 2, 1, true, -2, 0); +#endif +/** + * This operation performs batch normalization of layer, it is based on + * following article https://arxiv.org/abs/1502.03167. Expected arguments: x: + * input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] + * (data format = NCHW), where bS - batch size iH - input height iW - input + * width iD - input depth (or number of channels) scale: 1D input array of + * scale factors, shape [iD] offset: 1D input array of offsets (shifts), shape + * [iD] mean: 1D input array of population mean used for inference, shape [iD], + * this array is required only if isTraining = false variance: 1D input array of + * population mean used for inference, shape [iD], this array is required only + * if isTraining = false + * + * T input arguments: + * 0: epsilon, it is optional argument, default value is 0.001, this is small + * number to be added to the variance of x + * + * integer input arguments: + * 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW + * 1: isTraining, may have two values: zero -> inference, unity -> training + */ +#if NOT_EXCLUDED(OP_fused_batch_norm) +DECLARE_CUSTOM_OP(fused_batch_norm, 3, 1, false, 0, 2); +#endif - /** - * relu_layer = relu(x*w + b) - */ - DECLARE_CUSTOM_OP(relu_layer, 3, 1, false, 0, 0); +#if NOT_EXCLUDED(OP_log_softmax) +DECLARE_CONFIGURABLE_OP(log_softmax, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(log_softmax_bp, 2, 1, true, 0, 0); +#endif - /** - * applies layer normalization to input - * y = g * standardize(x) + b - * - * see sd::ops::standardize - * - */ - #if NOT_EXCLUDED(OP_layer_norm) - DECLARE_CONFIGURABLE_OP(layer_norm, 3, 1, true, 0, -2); - DECLARE_CUSTOM_OP(layer_norm_bp, 4, 1, false, 0, -2); - #endif +/** + * relu_layer = relu(x*w + b) + */ +DECLARE_CUSTOM_OP(relu_layer, 3, 1, false, 0, 0); - /** - * This operation performs dot product attention on the given timeseries input with the given queries - * out = sum(similarity(k_i, q) * v_i) - * - * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q - * - * Optionally with normalization step: - * similarity(k, q) = softmax(k * q / sqrt(size(q)) - * - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1) - * - * Note: This supports multiple queries at once, if only one query is available the queries vector still has to - * be 3D but can have queryCount = 1 - * - * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for - * both. - * - * Expected arguments: - * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] - * k: input 3D array "keys" of shape [batchSize, featureKeys, timesteps] or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] - * v: input 3D array "values" of shape [batchSize, featureValues, timesteps] or 4D array of shape [batchSize, numHeads, featureValues, timesteps] - * mask: OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] - * - * integer input arguments: - * 0: normalization, may have two values: zero -> do not apply normalization, one -> apply normalization - * 1: withWeights, may have two values: zero -> do not return weights, one -> return weights - * - * Output Arrays: - * 0: Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount] - * 1: OPTIONAL; Attention weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] - */ - #if NOT_EXCLUDED(OP_dot_product_attention) - DECLARE_CUSTOM_OP(dot_product_attention, 3, -1, false, 0, 2); - DECLARE_CUSTOM_OP(dot_product_attention_bp, 4, 3, false, 0, 1); - #endif +/** + * applies layer normalization to input + * y = g * standardize(x) + b + * + * see sd::ops::standardize + * + */ +#if NOT_EXCLUDED(OP_layer_norm) +DECLARE_CONFIGURABLE_OP(layer_norm, 3, 1, true, 0, -2); +DECLARE_CUSTOM_OP(layer_norm_bp, 4, 1, false, 0, -2); +#endif +/** + * This operation performs dot product attention on the given timeseries input + * with the given queries out = sum(similarity(k_i, q) * v_i) + * + * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q + * + * Optionally with normalization step: + * similarity(k, q) = softmax(k * q / sqrt(size(q)) + * + * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, + * eq. 1) + * + * Note: This supports multiple queries at once, if only one query is available + * the queries vector still has to be 3D but can have queryCount = 1 + * + * Note: keys and values usually is the same array. If you want to use it as the + * same array, simply pass it for both. + * + * Expected arguments: + * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] or + * 4D array of shape [batchSize, numHeads, featureKeys, queryCount] k: input 3D + * array "keys" of shape [batchSize, featureKeys, timesteps] or 4D array of + * shape [batchSize, numHeads, featureKeys, timesteps] v: input 3D array + * "values" of shape [batchSize, featureValues, timesteps] or 4D array of shape + * [batchSize, numHeads, featureValues, timesteps] mask: OPTIONAL; array that + * defines which values should be skipped of shape [batchSize, timesteps] + * + * integer input arguments: + * 0: normalization, may have two values: zero -> do not apply normalization, + * one -> apply normalization 1: withWeights, may have two values: zero -> do + * not return weights, one -> return weights + * + * Output Arrays: + * 0: Attention result arrays of shape [batchSize, featureValues, queryCount] or + * [batchSize, numHeads, featureValues, queryCount] 1: OPTIONAL; Attention + * weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, + * timesteps, queryCount] + */ +#if NOT_EXCLUDED(OP_dot_product_attention) +DECLARE_CUSTOM_OP(dot_product_attention, 3, -1, false, 0, 2); +DECLARE_CUSTOM_OP(dot_product_attention_bp, 4, 3, false, 0, 1); +#endif - /** - * This performs multi-headed dot product attention on the given timeseries input - * out = concat(head_1, head_2, ..., head_n) * Wo - * head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v) - * - * Optionally with normalization when calculating the attention for each head. - * - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention") - * - * This makes use of dot_product_attention OP support for rank 4 inputs. - * - * Expected arguments: - * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] - * k: input 3D array "keys" of shape [batchSize, featureKeys, timesteps] - * v: input 3D array "values" of shape [batchSize, featureValues, timesteps] - * Wq: input query projection weights of shape [numHeads, projectedKeys, featureKeys] - * Wk: input key projection weights of shape [numHeads, projectedKeys, featureKeys] - * Wv: input value projection weights of shape [numHeads, projectedValues, featureValues] - * Wo: output projection weights of shape [numHeads * projectedValues, outSize] - * mask: OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] - * - * integer input arguments: - * 0: normalization, may have two values: zero -> do not apply normalization, one -> apply normalization - * 1: withWeights, may have two values: zero -> do not return weights, one -> return weights - * - * Output Arrays: - * 0: Attention result arrays of shape [batchSize, outSize, queryCount] - * 1: OPTIONAL; Attention weights of shape [batchSize, numHeads, timesteps, queryCount] - */ - #if NOT_EXCLUDED(OP_multi_head_dot_product_attention) - DECLARE_CUSTOM_OP(multi_head_dot_product_attention, 7, -1, false, 0, 2); - DECLARE_CUSTOM_OP(multi_head_dot_product_attention_bp, 8, 7, false, 0, 1); - #endif - } -} +/** + * This performs multi-headed dot product attention on the given timeseries + * input out = concat(head_1, head_2, ..., head_n) * Wo head_i = + * dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v) + * + * Optionally with normalization when calculating the attention for each head. + * + * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. + * 4,5, "3.2.2 Multi-Head Attention") + * + * This makes use of dot_product_attention OP support for rank 4 inputs. + * + * Expected arguments: + * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] + * k: input 3D array "keys" of shape [batchSize, featureKeys, timesteps] + * v: input 3D array "values" of shape [batchSize, featureValues, timesteps] + * Wq: input query projection weights of shape [numHeads, projectedKeys, + * featureKeys] Wk: input key projection weights of shape [numHeads, + * projectedKeys, featureKeys] Wv: input value projection weights of shape + * [numHeads, projectedValues, featureValues] Wo: output projection weights of + * shape [numHeads * projectedValues, outSize] mask: OPTIONAL; array that + * defines which values should be skipped of shape [batchSize, timesteps] + * + * integer input arguments: + * 0: normalization, may have two values: zero -> do not apply normalization, + * one -> apply normalization 1: withWeights, may have two values: zero -> do + * not return weights, one -> return weights + * + * Output Arrays: + * 0: Attention result arrays of shape [batchSize, outSize, queryCount] + * 1: OPTIONAL; Attention weights of shape [batchSize, numHeads, timesteps, + * queryCount] + */ +#if NOT_EXCLUDED(OP_multi_head_dot_product_attention) +DECLARE_CUSTOM_OP(multi_head_dot_product_attention, 7, -1, false, 0, 2); +DECLARE_CUSTOM_OP(multi_head_dot_product_attention_bp, 8, 7, false, 0, 1); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index 8fae1b63c092..6cdff8767b3e 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -25,1998 +25,2077 @@ #include namespace sd { - namespace ops { - /** - * This operation returns index of max element in a given NDArray (optionally: along given dimension(s)) - * Expected input: - * 0: N-dimensional array - * 1: optional axis vector - * - * Int args: - * 0: optional axis - */ - #if NOT_EXCLUDED(OP_argmax) - DECLARE_CUSTOM_OP(argmax, 1, 1, false, 0, -2); - #endif - - /** - * This operation returns index of min element in a given NDArray (optionally: along given dimension(s)) - * Expected input: - * 0: N-dimensional array - * 1: optional axis vector - * - * Int args: - * 0: optional axis - */ - #if NOT_EXCLUDED(OP_argmin) - DECLARE_CUSTOM_OP(argmin, 1, 1, false, 0, -2); - #endif - - /** - * This operation provides various normalization modes: - * 0: frobenius - * 1: euclidean (norm2) - * 2: norm1 - * 3: norm2 - * 4: inf-norm - * 5: p-norm - * - * Expected arguments: - * input: N-dimensional array - * - * - * Int args: - * 0...: axis - * - * T args: - * 0: norm mode - * 1: p for p-norm - */ - #if NOT_EXCLUDED(OP_norm) - DECLARE_REDUCTION_OP(norm, 1, 1, false, 1, -2); - #endif - - /** - * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array - * - * Input arrays: - * 0: input array, considered as batch of matrices - * 1: diagonal array containing elements to be inserted into input array, - * following rank condition should be satisfied: diagonal_rank = input_rank - 1, - * the shapes of diagonal and input arrays must be equal except last dimension of input array, - * for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C], - * also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions - * that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) - * - * Output array: - * 0: has the same shape as input, corresponding diagonal elements are substituted - */ - #if NOT_EXCLUDED(OP_matrix_set_diag) - DECLARE_CONFIGURABLE_OP(matrix_set_diag, 2, 1, false, 0, 0); - #endif - - /** - * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of output array, - * rest output elements are set to zeros - * - * Input array: - * diagonal: array containing elements to be inserted into output array, - * following rank condition is present: diagonal_rank = ouput_rank - 1 - * - * Output array: - * 0: is considered as batch of matrices, if for example diagonal array has shape [A,B,C] then output array has shape [A,B,C,C] - */ - DECLARE_CUSTOM_OP(matrix_diag, 1, 1, false, 0, 0); - - /** - * This op calculates regularized incomplete beta integral Ix(a, b). - * Implementation is based on two algorithms depending on input values of a and b: - * - when a and b are both > maxValue (3000.), then Gauss-Legendre quadrature method is applied - * - when a and b are both <= maxValue (3000.), then modified Lentz’s algorithm for continued fractions is applied - * - * Input arrays: - * a: defines power t^{a-1}, must be > 0, type float. - * b: defines power (1-t)^{b-1}, must be > 0, type float. - * x: defines upper limit of integration, must be within (0 <= x <= 1) range, type float. - * - * Output array: - * 0: values of regularized incomplete beta integral that corresponds to variable upper limit x, type float - * - * Three input and one output arrays must have the same shape - */ - #if NOT_EXCLUDED(OP_betainc) - DECLARE_CONFIGURABLE_OP(betainc, 3, 1, false, 0, 0); - #endif - - /** - * This operation is added for compatibility purposes mostly. - * PLEASE NOTE: Please consider using Add instead - * Expected arguments: - * 0: N-dimensional input - * 1: bias vector - */ - #if NOT_EXCLUDED(OP_biasadd) - DECLARE_CUSTOM_OP(biasadd, 2, 1, true, 0, 0); - DECLARE_CUSTOM_OP(biasadd_bp, 3, 2, false, 0, 0); - #endif - - /** - * Returns a diagonal tensor with a given diagonal values. Given a diagonal, this operation returns a tensor with the diagonal and everything else padded with zeros. - */ - #if NOT_EXCLUDED(OP_diag) - DECLARE_CUSTOM_OP(diag, 1, 1, false, 0, 0); - #endif - - /** - * Returns a diagonal tensor with a given diagonal values. Given a diagonal, this operation returns a tensor with the diagonal and everything else padded with zeros. - */ - #if NOT_EXCLUDED(OP_diag_part) - DECLARE_CUSTOM_OP(diag_part, 1, 1, false, 0, 0); - #endif - - /** - * Returns a diagonal vector for any submatricies with in a given tensor. - * It is an op inverse to matrix_set_giag. - * Using input tensor as batched 2D diagonals flat them to vector (1D) with diagonal values. - * - * Input : batched tensor with rank >=2 - * Output: tensor with rank lesser by 1 from input - */ - #if NOT_EXCLUDED(OP_matrix_diag_part) - DECLARE_CUSTOM_OP(matrix_diag_part, 1, 1, false, 0, 0); - #endif - - /** - * QR decomposition: A = QR, where Q is ortogonal (Q * QT = I) and R is upper triangular. - * For A (MxN) Q is M x M and R is (NxN). - * - * Input : - * 0 - float (or complex float) tensor with shape {.,..,...,M,N} - batch of float matricies - * - * Output: - * 0 - float tensor with shape {.,..,...,MxN} - batch of ortogonal matricies {Qs} - * 1 - float tensor with shape {.,..,...,NxN} - batch of upper triangular matricies {Rs} - */ - #if NOT_EXCLUDED(OP_qr) - DECLARE_CUSTOM_OP(qr, 1, 2, false, 0, 0); - #endif - - /** - * This operation takes 2 arrays: original values, and values to be excluded. And returns 2 arrays: values left after exclusion, and indices in original array for surivals. - * Expected arguments: - * 0: vector with original values - * 1: vector with values to exclude - */ - #if NOT_EXCLUDED(OP_listdiff) - DECLARE_CUSTOM_OP(listdiff, 2, 2, false, 0, 0); - #endif - - /** - * This operation applies Add operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_add) - DECLARE_OP(scatter_add, 3, 1, true); - #endif - - /** - * This operation applies Subtract operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_sub) - DECLARE_OP(scatter_sub, 3, 1, true); - #endif - - /** - * This operation applies Multiply operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_mul) - DECLARE_OP(scatter_mul, 3, 1, true); - #endif - - /** - * This operation applies Divide operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_div) - DECLARE_OP(scatter_div, 3, 1, true); - #endif - - /** - * This operation applies Assign operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_upd) - DECLARE_OP(scatter_upd, 3, 1, true); - #endif - - /** - * This operation applies Max operation to specific inputs through given indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_max) - DECLARE_OP(scatter_max, 3, 1, true); - #endif - - /** - * This operation applies Min operation to specific inputs through given indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_min) - DECLARE_OP(scatter_min, 3, 1, true); - #endif - - /** - * This operation scatter "updates" elements into new output array according to given "indices" - * Expected arguments: - * indices: array containing elements/slices indexes of output array to put "updates" elements into, the rest output elements will be zeros - * updates: array containing elements to be inserted into output array - * shape: contains shape of output array - */ - #if NOT_EXCLUDED(OP_scatter_nd) - DECLARE_CUSTOM_OP(scatter_nd, 3, 1, false, 0, 0); - #endif - - /** - * This operation scatter "updates" elements into input array along given "indices" - * Expected arguments: - * input: array to be updated - * indices: array containing elements/slices indexes of input array to put "updates" elements into - * updates: array containing elements to be inserted into input array - */ - #if NOT_EXCLUDED(OP_scatter_nd_update) - DECLARE_OP(scatter_nd_update, 3, 1, true); - #endif - - /** - * This operation adds "updates" elements to input array along given "indices" - * Expected arguments: - * input: array to be updated - * indices: array containing elements/slices indexes of input array to add "updates" elements to - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_add) - DECLARE_OP(scatter_nd_add, 3, 1, true); - #endif - - /** - * This operation subtract "updates" elements from input array along given "indices" - * Expected arguments: - * input: array to be updated - * indices: array containing elements/slices indexes of input array to subtract "updates" elements from - * updates: array containing elements to be interfered with input - */ - #if NOT_EXCLUDED(OP_scatter_sub) - DECLARE_OP(scatter_nd_sub, 3, 1, true); - #endif - - /** - * This operation takes input's shape, and returns new NDArray filled with specified value - * Expected arguments: - * input: N-dimensional array - * - * T args: - * 0: scalar value, used to fill NDArray - */ - #if NOT_EXCLUDED(OP_fill_as) - DECLARE_CONFIGURABLE_OP(fill_as, 1, 1, true, 1, 0); - #endif - - /** - * This operation applies element-wise rint (round to integral value) operation - */ - #if NOT_EXCLUDED(OP_rint) - DECLARE_OP(rint, 1, 1, true); - #endif - - /** - * This operation returns unique elements from input array as vector, and their original indices in input array - * Expected input: - * input: N-dimensional array - */ - #if NOT_EXCLUDED(OP_unique) - DECLARE_CUSTOM_OP(unique, 1, 2, false, 0, 0); - #endif - - /** - * This operation returns 3 1D arrays for given 1D array with unique element count and indexes - * input: - * 0 - 1D array - * - * output: - * 0 - 1D array with unique values - * 1 - 1D array with ids for values in array above - * 2 - 1D array with counts for values in array above - */ - #if NOT_EXCLUDED(OP_unique_with_counts) - DECLARE_CUSTOM_OP(unique_with_counts, 1, 3, false, 0, 0); - #endif - - /** - * This operation splits input NDArray into multiple TADs along given dimensions - * Expected arguments: - * input: N-dimensional array - * - * Int args: - * 0..: TAD axis - */ - #if NOT_EXCLUDED(OP_tear) - DECLARE_CUSTOM_OP(tear, 1, -1, false, 0, -1); - #endif - - /** - * This op does the same as tear, just uses different input format: - * @tparam T - */ - #if NOT_EXCLUDED(OP_unstack) - DECLARE_CUSTOM_OP(unstack, 1, -1, false, 0, 1); - #endif - - /** - * This operation extracts a strided (optionally) slice from a tensor, - */ - #if NOT_EXCLUDED(OP_strided_slice) - DECLARE_CUSTOM_OP(strided_slice, 1, 1, false, 0, 5); // TODO: new op type needed. that returns VIEW - DECLARE_CUSTOM_OP(strided_slice_bp, 2, 1, false, 0, 5); - #endif - - /** - * This operation extracts a slice from a tensor. - * - */ - #if NOT_EXCLUDED(OP_slice) - DECLARE_CUSTOM_OP(slice, 1, 1, false, 0, -2); - DECLARE_CUSTOM_OP(slice_bp, 2, 1, false, 0, -2); - #endif - - /** - * This operation generate sequences. Basically from......to, with step used as increment. - * Expected arguments: - * start: optional scalar with starting value - * stop: optional scalar with end value - * step: optional scalar witn step value - * - * Int args: (optional) - * 0: optional scalar with starting value - * 1: optional scalar with end value - * 1: optional scalar witn step value - * - * T args: (optional) - * 0: optional scalar with starting value - * 1: optional scalar with end value - * 1: optional scalar witn step value - */ - #if NOT_EXCLUDED(OP_range) - DECLARE_CUSTOM_OP(range, -2, 1, false, -2, -2); - #endif - - /** - * This operation return one-hot encoded n-dimensional array - * Expected arguments: - * input: N-dimensional array - * - * T args: - * 0: 'on' value - * 1: 'off' value - * - * Int args: - * 0: depth - * 1: axis - */ - #if NOT_EXCLUDED(OP_onehot) - DECLARE_CUSTOM_OP(onehot, 1, 1, false, -2, -2); - #endif - - - /** - * This operation calculate the confusion matrix for a - * pair of prediction and label 1-D arrays. - * Expected arguments: - * Input arrays: - * 0 - predictions: 1-D array - * 1 - labels: 1-D array - * 2 - weights : optional - * Int args: - * 0 - num_classes: optional - * - */ - #if NOT_EXCLUDED(OP_confusion_matrix) - DECLARE_CUSTOM_OP(confusion_matrix, 2, 1, false, 0, -2); - #endif - - /** - * This operation stacks a list of rank tensors into one rank-(R+1) tensor. - * Expected arguments: - * 0...: N-Dimensional arrays to stack - * - */ - #if NOT_EXCLUDED(OP_stack) - DECLARE_CUSTOM_OP(stack, -1, 1, false, 0, 0); - #endif - - /** - * This operation returns length of input array - * Expected arguments: - * input: N-dimensional array - * - * TODO: make this operation reduction, to allow TAD -> size - */ - #if NOT_EXCLUDED(OP_size) - DECLARE_CUSTOM_OP(size, 1, 1, false, 0, 0); // add DeclarableScalarOp? - #endif - - - /** - * This operation returns rank of input array as scalar value. - */ - #if NOT_EXCLUDED(OP_rank) - DECLARE_CUSTOM_OP(rank, 1, 1, false, 0, 0); // ^ - #endif - - - #if NOT_EXCLUDED(OP_broadcastgradientargs) - DECLARE_OP(broadcastgradientargs, 2, 2, true); - #endif - - /** - * This operation takes input's shape, and returns new NDArray filled with zeros - * Expected arguments: - * input: N-dimensional array - * - */ - #if NOT_EXCLUDED(OP_zeros_as) - DECLARE_CUSTOM_OP(zeros_as, 1, 1, false, 0, 0); - #endif - - /** - * This operation takes input's shape, and returns new NDArray filled with ones - * Expected arguments: - * input: N-dimensional array - * - */ - #if NOT_EXCLUDED(OP_ones_as) - DECLARE_CUSTOM_OP(ones_as, 1, 1, false, 0, 0); - #endif - - /** - * This operation applies element-wise pow(x, 2) to the given input - * Expected arguments: - * input: N-Dimensional array - */ - #if NOT_EXCLUDED(OP_square) - DECLARE_OP(square, 1, 1, true); - #endif - - /** - * This op calculates Hurwitz zeta function zeta(x, q) = sum_{n=0}^{inf} (q + n)^{-x} - * Implementation is based on Euler-Maclaurin summation formula - * - * Input arrays: - * x: define power {-x}, must be > 1, type float. - * q: define summand in denominator, must be > 0, type float. - * - * Output array: - * 0: corresponding values of Hurwitz zeta function - * - * Two input and one output arrays must have the same shape - */ - #if NOT_EXCLUDED(OP_zeta) - DECLARE_CONFIGURABLE_OP(zeta, 2, 1, false, 0, 0); - #endif - - /** - * This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in - * terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x). - * - * Input arrays: - * 0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer) - * 1: x - abscissa points where to evaluate the polygamma function, type float - * - * Output array: - * 0: values of polygamma function at corresponding x, type float - * - * Two input and one output arrays have the same shape - */ - #if NOT_EXCLUDED(OP_polygamma) - DECLARE_CONFIGURABLE_OP(polygamma, 2, 1, false, 0, 0); - #endif - - /** - * This op calculates lgamma function lgamma(x) = log(Gamma(x)) - * - * Input arrays: - * 0: x - input matrix - * - * Output array: - * 0: log of Gamma(x) - * - */ - #if NOT_EXCLUDED(OP_lgamma) - DECLARE_OP(lgamma, 1, 1, true); - #endif - - /** - * This op calculates digamma function psi(x) = derivative of log(Gamma(x)) - * - * Input arrays: - * 0: x - abscissa points where to evaluate the digamma function, type float - * - * Output array: - * 0: values of digamma function at corresponding x, type float - * - */ - #if NOT_EXCLUDED(OP_digamma) - DECLARE_CONFIGURABLE_OP(digamma, 1, 1, false, 0, 0); - #endif - - /** - * This operation takes shape as first argument, and returns new NDArray filled with specific scalar value. - * Input arrays: - * 0 - shape vector - * 1 - optional scalar NDArray - * - * T arguments: - * 0 - optional scalar value - * - */ - #if NOT_EXCLUDED(OP_fill) - DECLARE_CUSTOM_OP(fill, 1, 1, false, -2, 0); - #endif - - /** - * This operation splits given NDArray into chunks of specific size, along given dimension - * Input arrays: - * 0 - input array - * 1 - array of sizes - * 2 - optional axis - * - * Integer arguments: - * 0 - optional axis - * - */ - #if NOT_EXCLUDED(OP_split_v) - DECLARE_CUSTOM_OP(split_v, 2, -1, false, 0, -2); - #endif - - /** - * This operation splits given NDArray into chunks of specific size, along given dimension - * 0 - input array - * 1 - optional axis - * - * Integer arguments: - * 0 - number of splits - * 1 - optional axis - */ - #if NOT_EXCLUDED(OP_split) - DECLARE_CUSTOM_OP(split, 1, -1, false, 0, 1); - #endif - - - /** - * This operation adjusts image hue by delta - * Input arrays: - * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. - * 1 - optional argument, input scalar-array containing delta - * - * T arguments: - * 0 - optional argument, delta value - * - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels - */ - #if NOT_EXCLUDED(OP_adjust_hue) - DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 0, 0); - #endif - - /** - * This operation adjusts image saturation by delta - * Input arrays: - * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. - * 1 - optional argument, input scalar-array containing saturation factor - * - * T arguments: - * 0 - optional argument, saturation factor - * - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels - */ - #if NOT_EXCLUDED(OP_adjust_saturation) - DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 0, 0); - #endif - - /** - * This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean ) - * Input arrays: - * 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels. - * 1 - optional argument, input scalar-array containing saturation contrast factor - * - * T arguments: - * 0 - optional argument, contrast factor - * - */ - #if NOT_EXCLUDED(OP_adjust_contrast) - DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, 0, 0); - #endif - - - - - /** - * This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation - * of space_to_depth op. This op output is a copy of the input tensor where values from the depth dimension - * are moved in spatial blocks to the height and width dimensions. Int attr 0 indicates the input - * block size and how the data is moved. - * Input: - * 0 - 4D tensor on given type - * Output: - * 0 - 4D tensor of given type and proper shape - * - * Int arguments: - * 0 - block size - * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels } - * 1 ("NCHW"): shape{ batch, channels, height, width } - * 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 } - * optional (default 0) - */ - #if NOT_EXCLUDED(OP_depth_to_space) - DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, -1); - #endif - - /** - * This operation rearranges blocks of spatial data, into depth.This op output is a copy of the input tensor - * where values from the height and width dimensions are moved to the depth dimension. Int attr 0 indicates - * the input block size. - * - * Input: - * - 4D tensor of given type - * Output: - * - 4D tensor - * - * Int arguments: - * 0 - block size - * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels } - * 1 ("NCHW"): shape{ batch, channels, height, width } - * 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 } - * optional (default 0) - * - */ - #if NOT_EXCLUDED(OP_space_to_depth) - DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, -1); - #endif - - /** - * This op calculates cross-product between input arguments - * Input arguments - * 0 - vector or tensor A - * 1 - vector or tensor B - */ - #if NOT_EXCLUDED(OP_cross) - DECLARE_OP(cross, 2, 1, false); - #endif - - /** - * Zero-pads and then rearranges (permutes) blocks of spatial data into batch. More specifically, this op - * outputs a copy of the input tensor where values from the height and width dimensions are moved to the - * batch dimension. After the zero-padding, both height and width of the input must be divisible by the block - * size. - * - * Inputs: - * 0 - input tensor - * 1 - 2D paddings tensor (shape {M, 2}) - * - * Output: - * - result tensor - * - * Int args: - * 0 - block size (M) - * - */ - #if NOT_EXCLUDED(OP_space_to_batch) - DECLARE_CUSTOM_OP(space_to_batch, 2, 1, false, 0, 1); - #endif - - /* - * This operation divides "spatial" dimensions [1, ..., M] of the input into a grid of blocks of shape - * block_shape, and interleaves these blocks with the "batch" dimension (0) such that in the output, - * the spatial dimensions [1, ..., M] correspond to the position within the grid, and the batch dimension - * combines both the position within a spatial block and the original batch position. Prior to division into - * blocks, the spatial dimensions of the input are optionally zero padded according to paddings. - * - * Inputs: - * 0 - input (N-D tensor) - * 1 - block_shape - int 1D tensor with M length - * 2 - paddings - int 2D tensor with shape {M, 2} - * - * Output: - * - N-D tensor with the same type as input 0. - * - * */ - #if NOT_EXCLUDED(OP_space_to_batch_nd) - DECLARE_CUSTOM_OP(space_to_batch_nd, 3, 1, false, 0, 0); - #endif - - /** - * - * - */ - #if NOT_EXCLUDED(OP_batch_to_space) - DECLARE_CUSTOM_OP(batch_to_space, 2, 1, false, 0, 1); - #endif - #if NOT_EXCLUDED(OP_batch_to_space_nd) - DECLARE_CUSTOM_OP(batch_to_space_nd, 3, 1, false, 0, 0); - #endif - - /** - * top_k operation returns a vector of k top values for - * given NDArray as tensor with default boolean (true) - * as sort for result index array - * will be sorted by the values in descending order. - * The first parameter is a NDArray for working. - * The second is k (default 1) - optional - * The third is boolean value(default is true) (0 - as is, 1 - sorted by value) optional - */ - #if NOT_EXCLUDED(OP_top_k) - DECLARE_CUSTOM_OP(top_k, 1, 2, false, 0, -1); - #endif - - /** - * in_top_k operation returns a vector of k boolean values for - * given NDArray as 2D matrix of predicted in the NDArray k top values - * The first parameter is a NDArray of predicted values (2d array). - * The second is NDArray as vector of indeces k top values will be search. - * The third is k - */ - #if NOT_EXCLUDED(OP_in_top_k) - DECLARE_CUSTOM_OP(in_top_k, 2, 1, true, 1, 1); - #endif - - /** - * moments operation calculate a mean and variation for given NDArray - * with reduce a result according to axis array given. - * For full axis the result is both mean and variance of all members in array. - * Otherwise there are two NDArrays with means and variances for - * Axes can be put as the second NDArray or as int vector. - * - * the optional flag "keep_dims" can be set as T param - */ - #if NOT_EXCLUDED(OP_moments) - DECLARE_CUSTOM_OP(moments, 1, 2, false, 0, -2); - #endif - - /** - * embedding_lookup - search for submatrices in given matrix and retunts them - * accordingly to index array given. - */ - #if NOT_EXCLUDED(OP_embedding_lookup) - DECLARE_CUSTOM_OP(embedding_lookup, 2, 1, false, 0, 1); - #endif - - /** - * dynamic_partition - partition a input tensor onto num_partitions - * accordingly to index array given. - * - * the first param - NDArray to be partitioned. - * the second param - index array - * the third param (integer param) - num or partitions. - * - * returns a num of NDArrays as output - */ - #if NOT_EXCLUDED(OP_dynamic_partition) - DECLARE_CUSTOM_OP(dynamic_partition, 2, 1, false, 0, 1); - #endif - - #if NOT_EXCLUDED(OP_dynamic_partition_bp) - DECLARE_CUSTOM_OP(dynamic_partition_bp, 3, 2, false, 0, 1); - #endif - - /** - * dynamic_stitch - merge partitions from the second param a input tensor - * into a single tensor accordingly to index array given. - * - * the first param - index array - * the second params - tensors to be merged - * - * returns a num of NDArrays as output - * - * the operation is inversion od dynamic_partition - */ - #if NOT_EXCLUDED(OP_dynamic_stitch) - DECLARE_CUSTOM_OP(dynamic_stitch, 2, 1, false, 0, 0); - #endif - - /** - * zero_fraction op. - * compute a fraction of zeros in given array - * - * input param - an array (tensor) - * output value - a real number with given type (e.g. float or double) - */ - #if NOT_EXCLUDED(OP_zero_fraction) - DECLARE_CUSTOM_OP(zero_fraction, 1, 1, false, 0, 0); - #endif - - /** - * xw_plus_b op. - * multiply two first matrices and add third vector to each row of result - * - * input params: - * - 2D matrix NxM - * - 2D matrix MxN - * - 1D vector with N elements - * output value - 2D matrix NxN as multiply of matrixes and add vector - * Int args: - * 0 - optional switcher of weights format, if int arg == 1 - mkldnn, else mmul - */ - #if NOT_EXCLUDED(OP_xw_plus_b) - DECLARE_CUSTOM_OP(xw_plus_b, 3, 1, false, 0, 0); - DECLARE_CUSTOM_OP(xw_plus_b_bp, 4, 3, false, 0, 0); - #endif - - /** - * This operation is missed due it simplicy. - * Input and output params are the same after operation. - * Input - NDArray, output - NDArray with the same shape. - */ - #if NOT_EXCLUDED(OP_stop_gradient) - DECLARE_OP(stop_gradient, 1, 1, true); - #endif - - #if NOT_EXCLUDED(OP_parallel_stack) - DECLARE_CUSTOM_OP(parallel_stack, -1, 1, false, 0, 0); - #endif - - /** - * normalize_moments operation normalize already calculated mean and variation - * accordingly to shift and count. - * input params: - * - count of data - * - tensor with mean - * - tensor with variance (the same shape as before) - * - * - optional floating point param shift. - * - * returns a normalized pair mean and variance with the same shapes as input - */ - #if NOT_EXCLUDED(OP_normalize_moments) - DECLARE_CUSTOM_OP(normalize_moments, 3, 2, false, 1, 0); - #endif - - /** - * sufficient_statistics operation return calculated mean and variation with data count. - * this operation is invert for moments - * accordingly to shift and count. - * input params: - * - input tensor - * - axes vector - * - * - * - optional floating point param shift. - * - optional int (as bool) keep_dimension - * - * returns four tensors: - * - scalar tensor (data count) - * - sum elements of input (accross axises) - * - sum of squares of input (accross axises) - * - shift (if was given by input floating param) - */ - #if NOT_EXCLUDED(OP_sufficient_statistics) - DECLARE_CUSTOM_OP(sufficient_statistics, 2, 3, false, 0, 0); - #endif - - /** - * This op calculates weighted logarithmic loss of input - * Input arguments - * 0 - target - * 1 - input - * 2 - weights (scalar or vector with same as last dimension) - * - * return value - a tensor with the same shape as target or input - */ - #if NOT_EXCLUDED(OP_weighted_cross_entropy_with_logits) - DECLARE_OP(weighted_cross_entropy_with_logits, 3, 1, true); - #endif - - /** - * This op calculates dropout of input - * Input arguments - * 0 - input tensor - * 1 - noise_shape - (vector with shape to reduce) - optional - * - * int parameter - seed for random numbers - * T parameter - probability (should be between 0 and 1) - * return value - a tensor with the same shape as target or input - */ - #if NOT_EXCLUDED(OP_dropout) - DECLARE_CONFIGURABLE_OP(dropout, 1, 1, true, 1, 1); - #endif - #if NOT_EXCLUDED(OP_dropout_bp) - DECLARE_CONFIGURABLE_OP(dropout_bp, 2, 1, false, 1, 1); - #endif - - /* Calculates alpha weighted dropout - T params: - 0 - drop probability - 1 - alpha value - 2 - alpha' value - 3 - beta value - */ - #if NOT_EXCLUDED(OP_alpha_dropout_bp) - DECLARE_CONFIGURABLE_OP(alpha_dropout_bp, 2, 1, false, 4, 1); - #endif - - - /** - * bincount operation return a vector with element counted. - * - * input params: - * - input tensor - only int part are accepted - * - weights - the same shape tensor with integer weights for element (optional) - * default weight - 1,1,1..,1 for all values in the tensor - * - * optional ints: - * - min_length - zero or greater - * - max_length - between min_length and max(input) + 1 - * - * returns four tensors: - * - vector tensor with length to min(max_len, max(input) + 1) with count - * of values in indexed place - * - */ - #if NOT_EXCLUDED(OP_bincount) - DECLARE_CUSTOM_OP(bincount, 1, 1, false, 0, 0); - #endif - - /** - * broadcast_dynamic_shape op. - * - * input params: - * 0 - the first shape (vector with shape) - * 1 - the second shape (vector with shape) - * - * return value: - * vector with broadcasted shape - */ - #if NOT_EXCLUDED(OP_broadcast_dynamic_shape) - DECLARE_CUSTOM_OP(broadcast_dynamic_shape, 2, 1, false, 0, 0); - #endif - - /** - * matrix_determinant op. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: *) with determinant for all - * M x M matricies - */ - #if NOT_EXCLUDED(OP_matrix_determinant) - DECLARE_CUSTOM_OP(matrix_determinant, 1, 1, false, 0, 0); - #endif - - /** - * log_matrix_determinant op. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: *) with log determinant for all - * M x M matricies - */ - - #if NOT_EXCLUDED(OP_log_matrix_determinant) - DECLARE_CUSTOM_OP(log_matrix_determinant, 1, 1, false, 0, 0); - #endif - - /** - * logdet op. Logarithm of the determinant of hermitian positive matricies. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: *) with log determinant for all - * M x M matricies - */ - - #if NOT_EXCLUDED(OP_logdet) - DECLARE_CUSTOM_OP(logdet, 1, 1, false, 0, 0); - #endif - - /** - * matrix_solve_ls op (lstsq) - solves one or more linear least-squares problems. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * float args: - * 0 - l2_regularizer (default 0. and only for 0 implemented) - * - * boolean args: - * 0 - fast - default is true (optional) - use Cholesky decomposition instead QR decomposition of matricies. - * - * return value: - * tensor with dimension (x * y * z * ::: * N * K) with solutions - * - */ - #if NOT_EXCLUDED(OP_lstsq) - DECLARE_CUSTOM_OP(lstsq, 2, 1, false, 0, 0); - #endif - - /* solve_ls - analog of lstsq op with another solution approach - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * float args: - * 0 - l2_regularizer (default 0. and only for 0 implemented) - * - * boolean args: - * 0 - fast - default is true (optional) - use Cholesky decomposition instead QR decomposition of matricies. - * - * return value: - * tensor with dimension (x * y * z * ::: * N * K) with solutions - * - * Note: if fast is false - then l2_regularizer arg is ignored and used lstsq method due QR decomposition - * */ - #if NOT_EXCLUDED(OP_solve_ls) - DECLARE_CUSTOM_OP(solve_ls, 2, 1, false, 0, 0); - #endif - - /** - * matrix_inverse op. - make inverse for all 2D square matricies found in the input tensor - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: * M * M) with inverse M x M matricies in it - */ - #if NOT_EXCLUDED(OP_matrix_inverse) - DECLARE_OP(matrix_inverse, 1, 1, true); - #endif - - /** - * triangular_solve op. - reverse Gaussian method for solve systems of linear equations. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * boolean args: - * 0 - lower - default is true (optional) - left part is lower triangular matrix - * 1 - adjoint - default is false (optional) - indicate input matrix or its adjoint (hermitian addition) should be used - * - * return value: - * tensor with dimension (x * y * z * ::: * M * K) with solutions - * - */ - #if NOT_EXCLUDED(OP_triangular_solve) - DECLARE_CUSTOM_OP(triangular_solve, 2, 1, false, 0, 0); - #endif - - /** - * solve op. - solve systems of linear equations - general method. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * boolean args: - * 0 - adjoint - default is false (optional) - indicate input matrix or its adjoint (hermitian addition) should be used - * - * return value: - * tensor with dimension (x * y * z * ::: * M * K) with solutions - * - */ - #if NOT_EXCLUDED(OP_solve) - DECLARE_CUSTOM_OP(solve, 2, 1, true, 0, 0); - #endif - - /** - * lu op. - make LUP decomposition of given batch of 2D square matricies - * - * input params: - * 0 - float tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * 0 - float tensor with dimension (x * y * z * ::: * M * M) with LU M x M matricies in it - * 1 - int (32 or 64) batched vector of permutations with length M - shape (x * y * z * ::: * M) - * - * int argument: - * 0 - data type of output permutaion vector (int32 or int64), optional, default INT32 - */ - - #if NOT_EXCLUDED(OP_matrix_inverse) - DECLARE_CUSTOM_OP(lu, 1, 2, false, 0, 0); - #endif - - /** - * sequence_mask op. - make mask for given tensor filled by (j > x[i_1, i_2,...,i_n]) -> z[i_1, i_2,...,i_n,j] - * - * input params: - * 0 - the ND-tensor filled by integer-like values - * - * optional int param - maxlength (maxlength >= max(x)). By default maxlength = max(x). - * return value: - * (N+1)D tensor filled by 0 and 1 accordingly the mask - */ - #if NOT_EXCLUDED(OP_sequence_mask) - DECLARE_CUSTOM_OP(sequence_mask, 1, 1, false, 0, 0); - #endif - /** - * segment_max op. - make a tensor filled by max values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with max values according to indices sets. - */ - - #if NOT_EXCLUDED(OP_segment_max) - DECLARE_CUSTOM_OP(segment_max, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_segment_max_bp) - DECLARE_CUSTOM_OP(segment_max_bp, 3, 2, false, 0, 0); - #endif - - /** - * segment_min op. - make a tensor filled by min values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with min values according to indices sets. - */ - #if NOT_EXCLUDED(OP_segment_min) - DECLARE_CUSTOM_OP(segment_min, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_segment_min_bp) - DECLARE_CUSTOM_OP(segment_min_bp, 3, 2, false, 0, 0); - #endif - - /** - * segment_sum op. - make a tensor filled by sum of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with sum of values according to indices sets. - */ - #if NOT_EXCLUDED(OP_segment_sum) - DECLARE_CUSTOM_OP(segment_sum, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_segment_sum_bp) - DECLARE_CUSTOM_OP(segment_sum_bp, 3, 2, false, 0, 0); - #endif - - /** - * segment_prod op. - make a tensor filled by product of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with product of values according to indices sets. - */ - #if NOT_EXCLUDED(OP_segment_prod) - DECLARE_CUSTOM_OP(segment_prod, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_segment_prod_bp) - DECLARE_CUSTOM_OP(segment_prod_bp, 3, 2, false, 0, 0); - #endif - /** - * segment_mean op. - make a tensor filled by average of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with average of values according to indices sets. - */ - #if NOT_EXCLUDED(OP_segment_mean) - DECLARE_CUSTOM_OP(segment_mean, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_segment_mean_bp) - DECLARE_CUSTOM_OP(segment_mean_bp, 3, 2, false, 0, 0); - #endif - - /** - * unsorted_segment_max op. - make a tensor filled by max values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with max values according to indices sets. - */ - #if NOT_EXCLUDED(OP_unsorted_segment_max) - DECLARE_CUSTOM_OP(unsorted_segment_max, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_unsorted_segment_max_bp) - DECLARE_CUSTOM_OP(unsorted_segment_max_bp, 3, 2, false, 0, 1); - #endif - - /** - * unsorted_segment_min op. - make a tensor filled by min values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with min values according to indices sets. - */ - #if NOT_EXCLUDED(OP_unsorted_segment_min_bp) - DECLARE_CUSTOM_OP(unsorted_segment_min, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_unsorted_segment_min_bp) - DECLARE_CUSTOM_OP(unsorted_segment_min_bp, 3, 2, false, 0, 1); - #endif - - /** - * unsorted_segment_sum op. - make a tensor filled by sum of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with sum of values according to indices sets. - */ - #if NOT_EXCLUDED(OP_unsorted_segment_sum) - DECLARE_CUSTOM_OP(unsorted_segment_sum, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_unsorted_segment_sum_bp) - DECLARE_CUSTOM_OP(unsorted_segment_sum_bp, 3, 2, false, 0, 1); - #endif - - /** - * unsorted_segment_prod op. - make a tensor filled by product of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with product of values according to indices sets. - */ - #if NOT_EXCLUDED(OP_unsorted_segment_prod) - DECLARE_CUSTOM_OP(unsorted_segment_prod, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_unsorted_segment_prod_bp) - DECLARE_CUSTOM_OP(unsorted_segment_prod_bp, 3, 2, false, 0, 1); - #endif - - /** - * unsorted_segment_mean op. - make a tensor filled by average of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with average of values according to indices sets. - */ - #if NOT_EXCLUDED(OP_unsorted_segment_mean) - DECLARE_CUSTOM_OP(unsorted_segment_mean, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_unsorted_segment_mean_bp) - DECLARE_CUSTOM_OP(unsorted_segment_mean_bp, 3, 2, false, 0, 1); - #endif - - /** - * unsorted_segment_sqrt_n op. - computes the sum along segments of a tensor divided by the sqrt(N). - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with average of values according to indices sets. - */ - #if NOT_EXCLUDED(OP_unsorted_segment_sqrt) - DECLARE_CUSTOM_OP(unsorted_segment_sqrt_n, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_unsorted_segment_sqrt_n_bp) - DECLARE_CUSTOM_OP(unsorted_segment_sqrt_n_bp, 3, 2, false, 0, 1); - #endif - - /** - * extract_image_patches op - Extract patches from images and put them in the "depth" output dimension. - * - * input params: - * 0 - images tensor (4D) - * - * int params: - * 0 - ksize_rows - * 1 - ksize_cols - * 2 - strides_rows - * 3 - strides_cols - * 4 - rates_rows - * 5 - rates_cols - * 6 - padding_type - 0 - equiv 'VALID', 1 - 'SAME' - */ - #if NOT_EXCLUDED(OP_extract_image_patches) - DECLARE_CUSTOM_OP(extract_image_patches, 1, 1, false, 0, 7); - #endif - - /** - * draw_bounding_boxes op - modified input image with given colors exept given boxes. - * - * input params: - * 0 - images tensor (4D) with shape {batch, width, height, channels}, where channes is 1 (BW image), - * 3 (RGB) or 4 (RGBA) - * 1 - boxes tensor (3D) with shape {batch, number_of_boxes, 4} where last dimension encoded as - * (y_min, x_min, y_max, x_max), all values in between 0. and 1. - * 2 - colours tensor (2D) with shape {number_of_boxes, channels} -- bordering color set (palette) - * - * output: - * 0 - 4D tensor with same shape as images (input 0) - */ - #if NOT_EXCLUDED(OP_draw_bounding_boxes) - DECLARE_OP(draw_bounding_boxes, 3, 1, true); - #endif - - /** - * roll - op porting from numpy (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html) - * - * input params: - * 0 - NDArray - * - * int params: - * 0 - shift - * 1 - axe 1 - * 2 - axe 2 - * ... - * N - axe N - * - * All axes are optional and should be between 0 and input->rankOf(). Of course, all axes can be repeated. - * - * output: - * 0 - NDArray with the same shape as input. - */ - #if NOT_EXCLUDED(OP_roll) - DECLARE_CONFIGURABLE_OP(roll, 1, 1, true, 0, 1); - #endif - - /** - * lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space) - * - * optional input params: - * 0 - startVal - NDArray scalar (float point) - * 1 - finishVal - NDArray scalar (float point) - * 2 - numOfElements - NDArray scalar (integer) - * Optional: - * T args - * 0 - startVal - * 1 - finishVal] - * 2 - numOfElements - * output: - * 0 - 1D NDArray with the same type as input and length as given with numOfElements param. - */ - #if NOT_EXCLUDED(OP_lin_space) - DECLARE_CUSTOM_OP(lin_space, 0, 1, false, 0, 0); - #endif - - /** - * reduction_sum - tf.reduction_sum operation - * - * input params: - * 0 - NDArray - * - * T_ARG param (optional): - * 0 - keep_dims != 0. - * - * int params (optional): - * 0 - axe 1 - * 1 - axe 2 - * ... - * N-1 axe N - * - * All axes are optional and should be between 0 and input->rankOf() - 1 - * - * output: - * 0 - NDArray with reduces shape accordingly to axes (the scalar in default case). - */ - #if NOT_EXCLUDED(OP_reduce_sum) - DECLARE_CUSTOM_OP(reduce_sum, 1, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_reduce_sum_bp) - DECLARE_CUSTOM_OP(reduce_sum_bp, 2, 1, false, 0, 0); - #endif - - /** - * reduction_prod - tf.reduction_prod operation - * - * input params: - * 0 - NDArray - * - * T_ARG param (optional): - * 0 - keep_dims != 0. - * - * int params (optional): - * 0 - axe 1 - * 1 - axe 2 - * ... - * N-1 axe N - * - * All axes are optional and should be between 0 and input->rankOf() - 1 - * - * output: - * 0 - NDArray with reduces shape accordingly to axes (the scalar in default case). - */ - #if NOT_EXCLUDED(OP_reduce_prod) - DECLARE_CUSTOM_OP(reduce_prod, 1, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_reduce_prod_bp) - DECLARE_CUSTOM_OP(reduce_prod_bp, 2, 1, false, 0, 0); - #endif - - /** - * This op calculates min of elements along given dimensions - * - * input array: - * x: tensor to calculate mins for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate min along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated mins - */ - #if NOT_EXCLUDED(OP_reduce_min) - DECLARE_CUSTOM_OP(reduce_min, 1, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_reduce_min_bp) - DECLARE_CUSTOM_OP(reduce_min_bp, 2, 1, false, 0, 0); - #endif - - /** - * This op calculates max of elements along given dimensions - * - * input array: - * x: tensor to calculate maxes for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate max along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated maxes - */ - #if NOT_EXCLUDED(OP_reduce_max) - DECLARE_CUSTOM_OP(reduce_max, 1, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_reduce_max_bp) - DECLARE_CUSTOM_OP(reduce_max_bp, 2, 1, false, 0, 0); - #endif - - /** - * This op calculates norm1 of elements along given dimensions - * - * input array: - * x: tensor to calculate norm1 for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate norm1 along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm1 - */ - #if NOT_EXCLUDED(OP_reduce_norm1) - DECLARE_CUSTOM_OP(reduce_norm1, 1, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_reduce_norm1_bp) - DECLARE_CUSTOM_OP(reduce_norm1_bp, 2, 1, false, 0, 0); - #endif - - /** - * This op calculates norm2 of elements along given dimensions - * - * input array: - * x: tensor to calculate norm2 for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate norm2 along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm2 - */ - #if NOT_EXCLUDED(OP_reduce_norm2) - DECLARE_CUSTOM_OP(reduce_norm2, 1, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_reduce_norm2_bp) - DECLARE_CUSTOM_OP(reduce_norm2_bp, 2, 1, false, 0, 0); - #endif - - - /** - * This op calculates squared norm of elements along given dimensions - * - * input array: - * x: tensor to calculate squared norm for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate squared norm along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm - */ - #if NOT_EXCLUDED(OP_reduce_sqnorm) - DECLARE_CUSTOM_OP(reduce_sqnorm, 1, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_reduce_sqnorm_bp) - DECLARE_CUSTOM_OP(reduce_sqnorm_bp, 2, 1, false, 0, 0); - #endif - - /** - * This op calculates norm max of elements along given dimensions - * - * input array: - * x: tensor to calculate norm max for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate norm max along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm - */ - #if NOT_EXCLUDED(OP_reduce_norm_max) - DECLARE_CUSTOM_OP(reduce_norm_max, 1, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_reduce_norm_max_bp) - DECLARE_CUSTOM_OP(reduce_norm_max_bp, 2, 1, false, 0, 0); - #endif - - /** - * This op calculates mean of elements along given dimensions - * - * input array: - * x: tensor to calculate mean for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate mean along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated means - */ - #if NOT_EXCLUDED(OP_reduce_mean) - DECLARE_CUSTOM_OP(reduce_mean, 1, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_reduce_mean_bp) - DECLARE_CUSTOM_OP(reduce_mean_bp, 2, 1, false, 0, 0) - #endif - /** - * This op calculates sample variance of elements along given dimensions - * - * input array: - * x: tensor to calculate mean for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * biasCorrected - if non zero, then bias correction will be applied, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate mean along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated means - */ - DECLARE_CUSTOM_OP(reduce_variance, 1, 1, false, 0, 0); - DECLARE_CUSTOM_OP(reduce_variance_bp, 2, 1, false, 0, 0) - - /** - * This op calculates sample standard deviation of elements along given dimensions - * - * input array: - * x: tensor to calculate mean for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * biasCorrected - if non zero, then bias correction will be applied, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate mean along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated means - */ - DECLARE_CUSTOM_OP(reduce_stdev, 1, 1, false, 0, 0); - DECLARE_CUSTOM_OP(reduce_stdev_bp, 2, 1, false, 0, 0) - - /** - * This op calculates backprop dot for two tensors along given dimensions - * - * input array: - * x: tensor to calculate dot for - * y: tensor to calculate dot for - * z: tensor with gradient output of the FF dot for x and y - * - * int arguments: - * list of integers - dimensions to calculate dot along, - * default corresponds to empty list in which case calculation - * is performed for all dimensions and scalar is returned. - * - * output array: - * the tensor with calculated backproped dots - * - */ - - #if NOT_EXCLUDED(OP_reduce_dot_bp) - DECLARE_CUSTOM_OP(reduce_dot_bp, 3, 2, false, 0, 0); - #endif - /** - * reduce_logsumexp - tf.reduce_logsumexe operation - * - * input params: - * 0 - NDArray (input) - * 1 - 1D NDArray (axis) (optional) - integer array - * - * T_ARG param (optional): - * 0 - keep_dims != 0. - * - * int params (optional): - * 0 - axe 1 - * 1 - axe 2 - * ... - * N-1 axe N - * - * CAUTION: All axes are optional and should be between 0 and input->rankOf() - 1 - * and put either with second param or as integers but not both - * - * output: - * 0 - NDArray with reduces shape accordingly to axes (the scalar in default case). - */ - #if NOT_EXCLUDED(OP_reduce_logsumexp) - DECLARE_CUSTOM_OP(reduce_logsumexp, 1, 1, false, 0, 0); - #endif - - /** - * This op make bilinear or nearest neighbor interpolated resize for given tensor - * - * input array: - * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels) numeric type - * 1 - 2D-Tensor with shape (num_boxes, 4) float type - * 2 - 1D-Tensor with shape (num_boxes) int type - * 3 - 1D-Tensor with 2 values (newWidth, newHeight) (optional) int type - * - * float arguments (optional) - * 0 - exprapolation_value (optional) default 0.f - * - * int arguments: (optional) - * 0 - mode (default 0 - bilinear interpolation) - * - * output array: - * the 4D-Tensor with resized to crop_size images given - float type - */ - #if NOT_EXCLUDED(OP_crop_and_resize) - DECLARE_CUSTOM_OP(crop_and_resize, 4, 1, false, -1, -1); - #endif - - /** - * This op make bilinear interpolated resize for given tensor - * - * input array: - * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels) - * 1 - 1D-Tensor with 2 values (newWidth, newHeight) (optional) - * - * int arguments: (optional) - * 0 - new width - * 1 - new height - * - * output array: - * the 4D-Tensor with calculated backproped dots - * - * CAUTION: either size tensor or a pair of int params should be provided. - */ - - #if NOT_EXCLUDED(OP_resize_bilinear) - DECLARE_CUSTOM_OP(resize_bilinear, 1, 1, false, 0, -2); - #endif - - /** - * This op make nearest neighbor interpolated resize for given tensor - * - * input array: - * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels) - * 1 - 1D-Tensor with 2 values (newWidth, newHeight) (optional) - * - * int arguments: (optional) - * 0 - new width - * 1 - new height - * - * output array: - * the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels}) - * - * CAUTION: either size tensor or a pair of int params should be provided. - */ - - #if NOT_EXCLUDED(OP_resize_nearest_neighbor) - DECLARE_CUSTOM_OP(resize_nearest_neighbor, 1, 1, false, 0, -2); - #endif - - /** - * This op make bicubic interpolated resize for given tensor - * - * input array: - * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels) - * 1 - 1D-Tensor with 2 values (newWidth, newHeight) - * - * output array: - * the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels}) - * - */ - #if NOT_EXCLUDED(OP_resize_bicubic) - DECLARE_CUSTOM_OP(resize_bicubic, 1, 1, false, 0, -2); - #endif - - /** - * This op make area interpolated resize (as OpenCV INTER_AREA algorithm) for given tensor - * - * input array: - * 0 - images - 4D-Tensor with shape (batch, sizeX, sizeY, channels) - * 1 - size - 1D-Tensor with 2 values (newWidth, newHeight) (if missing a pair of integer args should be provided). - * - * int args: - proveded only when size tensor is missing - * 0 - new height - * 1 - new width - * boolean args: - * 0 - align_corners - optional (default is false) - * - * output array: - * the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels}) - * - */ - #if NOT_EXCLUDED(OP_resize_area) - DECLARE_CUSTOM_OP(resize_area, 1, 1, false, 0, -2); - #endif - - /** - * This op make interpolated resize for given tensor with given algorithm. - * Supported algorithms are bilinear, bicubic, nearest_neighbor. - * Need to implement to full compatibility with TF: lanczos5, gaussian, area and mitchellcubic - * - * input array: - * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels) - * 1 - 1D-Tensor with 2 values (newWidth, newHeight) - * - * optional int args: - * 0 - algorithm - bilinear by default - * optional bool args: - * 0 - preserve_aspect_ratio - default False - * 1 - antialias - default False - * - * output array: - * the 4D-Tensor with resized by given algorithm image (shape is {batch, newWidth, newHeight, channels}) - * - */ - - #if NOT_EXCLUDED(OP_image_resize) - DECLARE_CUSTOM_OP(image_resize, 2, 1, false, 0, 0); - #endif - - /** - * Copy a tensor setting everything outside a central band in each innermost matrix - * - * input array: - * x: given tensor with shape {..., M, N} - as vector (matrix) of matricies MxN - * - * int arguments: - * lower band - * upper band - * - * output array: - * matrix with given bands between lower and upper diagonals - * - */ - - #if NOT_EXCLUDED(OP_matrix_band_part) - DECLARE_CONFIGURABLE_OP(matrix_band_part, 1, 1, true, 0, 2); - #endif - - - #if NOT_EXCLUDED(OP_Assert) - DECLARE_OP(Assert, 1, 1, false); - #endif - - /** - * image.non_max_suppression ops. - * input: - * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type - * 1 - scales - 1D-tensor with shape (num_boxes) by float type - * 2 - output_size - 0D-tensor by int type (optional) - * float args: - * 0 - overlap_threshold - threshold value for overlap checks (optional, by default 0.5) - * 1 - score_threshold - the threshold for deciding when to remove boxes based on score (optional, by default -inf) - * int args: - * 0 - output_size - as arg 2 used for same target. Eigher this or arg 2 should be provided. - * - * output: - * - vector with size M, where M <= output_size by int type - * - * */ - #if NOT_EXCLUDED(OP_image_non_max_suppression) - DECLARE_CUSTOM_OP(non_max_suppression, 2, 1, false, 0, 0); - #endif - #if NOT_EXCLUDED(OP_image_non_max_suppression_v3) - DECLARE_CUSTOM_OP(non_max_suppression_v3, 2, 1, false, 0, 0); - #endif - - /* - * image.non_max_suppression_overlaps op. - * input: - * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type - * 1 - scales - 1D-tensor with shape (num_boxes) by float type - * 2 - output_size - 0D-tensor by int type (optional) - * float args: - * 0 - overlap_threshold - threshold value for overlap checks (optional, by default 0.5) - * 1 - score_threshold - the threshold for deciding when to remove boxes based on score (optional, by default -inf) - * int args: - * 0 - output_size - as arg 2 used for same target. Eigher this or arg 2 should be provided. - * - * output: - * 0 - 1D integer tensor with shape [M], epresenting the selected indices from the overlaps tensor, where M <= max_output_size - * */ - #if NOT_EXCLUDED(OP_image_non_max_suppression_overlaps) - DECLARE_CUSTOM_OP(non_max_suppression_overlaps, 2, 1, false, 0, 0); - #endif - - /* - * cholesky op - decomposite positive square symetric matrix (or matricies when rank > 2). - * input: - * 0 - matricies - tensor with shape (..., N, N) by float type - * - * output - lower triangular matrix (matricies when rank > 2) with the same shape as input. - * */ - #if NOT_EXCLUDED(OP_cholesky) - DECLARE_OP(cholesky, 1, 1, true); - #endif - /* - * nth_element - apply nth_element for last dimension of input tensor - * input array: - * 0 - input array - * 1 - scalar tensor with n for operation. n should be less than last dimension - * - * output: - * 0 - NDArray with the same shape as input - */ - #if NOT_EXCLUDED(OP_nth_element) - DECLARE_CUSTOM_OP(nth_element, 2, 1, false, 0, 0); - #endif - - /** - * This op checks for Inf/NaN values within input array, and throws exception if there's at least one - */ - #if NOT_EXCLUDED(OP_check_numerics) - DECLARE_CUSTOM_OP(check_numerics, 2, 1, true, 0, 0); - #endif -/** - * fake_quant_with_min_max_vals - tf.quantization.fake_quant_with_min_max_vars - * - * input params: - * 0 - NDArray (input) - * 1 - 0D Tensor - min value - * 2 - 0D Tensor - max value - * - * int params (optional): - * 0 - num_bits (allowed interval [2, 16], default 8) - * 1 - narrow_range (default False) - * - * output: - * 0 - NDArray with the same shape as input - */ - #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars) - DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars, 3, 1, true, 0, -2); - #endif - -/** - * fake_quant_with_min_max_vals_per_channel - tf.quantization.fake_quant_with_min_max_vars_per_channel - * - * input params: - * 0 - NDArray (input) - at least 2D. - * 1 - 1D Tensor - min values (min length equals to last dim of input) - * 2 - 1D Tensor - max value (length equals to min) - * - * int params (optional): - * 0 - num_bits (allowed interval [2, 16], default 8) - * 1 - narrow_range (default False) - * - * output: - * 0 - NDArray with the same shape as input - */ - #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel) - DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars_per_channel, 3, 1, true, 0, -2); - #endif - - /** - * compare_and_bitpack - compare with greater and pack result with uint8 - * - * input params: - * 0 - NDArray (input) - * 1 - 0D Tensor - threshold - * - * - * output: - * 0 - NDArray with the same shape as input and type uint8 - */ - #if NOT_EXCLUDED(OP_compare_and_bitpack) - DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0); - #endif - } -} +namespace ops { +/** + * This operation returns index of max element in a given NDArray (optionally: + * along given dimension(s)) Expected input: 0: N-dimensional array 1: optional + * axis vector + * + * Int args: + * 0: optional axis + */ +#if NOT_EXCLUDED(OP_argmax) +DECLARE_CUSTOM_OP(argmax, 1, 1, false, 0, -2); +#endif + +/** + * This operation returns index of min element in a given NDArray (optionally: + * along given dimension(s)) Expected input: 0: N-dimensional array 1: optional + * axis vector + * + * Int args: + * 0: optional axis + */ +#if NOT_EXCLUDED(OP_argmin) +DECLARE_CUSTOM_OP(argmin, 1, 1, false, 0, -2); +#endif + +/** + * This operation provides various normalization modes: + * 0: frobenius + * 1: euclidean (norm2) + * 2: norm1 + * 3: norm2 + * 4: inf-norm + * 5: p-norm + * + * Expected arguments: + * input: N-dimensional array + * + * + * Int args: + * 0...: axis + * + * T args: + * 0: norm mode + * 1: p for p-norm + */ +#if NOT_EXCLUDED(OP_norm) +DECLARE_REDUCTION_OP(norm, 1, 1, false, 1, -2); +#endif + +/** + * Inserts elements provided by diagonal array into the main diagonal of + * innermost matrices of input array + * + * Input arrays: + * 0: input array, considered as batch of matrices + * 1: diagonal array containing elements to be inserted into input array, + * following rank condition should be satisfied: diagonal_rank = input_rank + * - 1, the shapes of diagonal and input arrays must be equal except last + * dimension of input array, for example if input_shape = [A,B,C,D] then + * diagonal_shape = [A,B,C], also last dimension of diagonal array should be + * equal to smaller of last and last but one input dimensions that is: + * diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) + * + * Output array: + * 0: has the same shape as input, corresponding diagonal elements are + * substituted + */ +#if NOT_EXCLUDED(OP_matrix_set_diag) +DECLARE_CONFIGURABLE_OP(matrix_set_diag, 2, 1, false, 0, 0); +#endif + +/** + * Inserts elements provided by diagonal array into the main diagonal of + * innermost matrices of output array, rest output elements are set to zeros + * + * Input array: + * diagonal: array containing elements to be inserted into output array, + * following rank condition is present: diagonal_rank = ouput_rank + * - 1 + * + * Output array: + * 0: is considered as batch of matrices, if for example diagonal array has + * shape [A,B,C] then output array has shape [A,B,C,C] + */ +DECLARE_CUSTOM_OP(matrix_diag, 1, 1, false, 0, 0); + +/** + * This op calculates regularized incomplete beta integral Ix(a, b). + * Implementation is based on two algorithms depending on input values of a and + * b: + * - when a and b are both > maxValue (3000.), then Gauss-Legendre quadrature + * method is applied + * - when a and b are both <= maxValue (3000.), then modified Lentz’s algorithm + * for continued fractions is applied + * + * Input arrays: + * a: defines power t^{a-1}, must be > 0, type float. + * b: defines power (1-t)^{b-1}, must be > 0, type float. + * x: defines upper limit of integration, must be within (0 <= x <= 1) range, + * type float. + * + * Output array: + * 0: values of regularized incomplete beta integral that corresponds to + * variable upper limit x, type float + * + * Three input and one output arrays must have the same shape + */ +#if NOT_EXCLUDED(OP_betainc) +DECLARE_CONFIGURABLE_OP(betainc, 3, 1, false, 0, 0); +#endif + +/** + * This operation is added for compatibility purposes mostly. + * PLEASE NOTE: Please consider using Add instead + * Expected arguments: + * 0: N-dimensional input + * 1: bias vector + */ +#if NOT_EXCLUDED(OP_biasadd) +DECLARE_CUSTOM_OP(biasadd, 2, 1, true, 0, 0); +DECLARE_CUSTOM_OP(biasadd_bp, 3, 2, false, 0, 0); +#endif + +/** + * Returns a diagonal tensor with a given diagonal values. Given a diagonal, + * this operation returns a tensor with the diagonal and everything else padded + * with zeros. + */ +#if NOT_EXCLUDED(OP_diag) +DECLARE_CUSTOM_OP(diag, 1, 1, false, 0, 0); +#endif + +/** + * Returns a diagonal tensor with a given diagonal values. Given a diagonal, + * this operation returns a tensor with the diagonal and everything else padded + * with zeros. + */ +#if NOT_EXCLUDED(OP_diag_part) +DECLARE_CUSTOM_OP(diag_part, 1, 1, false, 0, 0); +#endif + +/** + * Returns a diagonal vector for any submatricies with in a given tensor. + * It is an op inverse to matrix_set_giag. + * Using input tensor as batched 2D diagonals flat them to vector (1D) with + * diagonal values. + * + * Input : batched tensor with rank >=2 + * Output: tensor with rank lesser by 1 from input + */ +#if NOT_EXCLUDED(OP_matrix_diag_part) +DECLARE_CUSTOM_OP(matrix_diag_part, 1, 1, false, 0, 0); +#endif + +/** + * QR decomposition: A = QR, where Q is ortogonal (Q * QT = I) and R is upper + * triangular. For A (MxN) Q is M x M and R is (NxN). + * + * Input : + * 0 - float (or complex float) tensor with shape {.,..,...,M,N} - batch of + * float matricies + * + * Output: + * 0 - float tensor with shape {.,..,...,MxN} - batch of ortogonal matricies + * {Qs} 1 - float tensor with shape {.,..,...,NxN} - batch of upper triangular + * matricies {Rs} + */ +#if NOT_EXCLUDED(OP_qr) +DECLARE_CUSTOM_OP(qr, 1, 2, false, 0, 0); +#endif + +/** + * This operation takes 2 arrays: original values, and values to be excluded. + * And returns 2 arrays: values left after exclusion, and indices in original + * array for surivals. Expected arguments: 0: vector with original values 1: + * vector with values to exclude + */ +#if NOT_EXCLUDED(OP_listdiff) +DECLARE_CUSTOM_OP(listdiff, 2, 2, false, 0, 0); +#endif + +/** + * This operation applies Add operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_add) +DECLARE_OP(scatter_add, 3, 1, true); +#endif + +/** + * This operation applies Subtract operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_sub) +DECLARE_OP(scatter_sub, 3, 1, true); +#endif + +/** + * This operation applies Multiply operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_mul) +DECLARE_OP(scatter_mul, 3, 1, true); +#endif + +/** + * This operation applies Divide operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_div) +DECLARE_OP(scatter_div, 3, 1, true); +#endif + +/** + * This operation applies Assign operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_upd) +DECLARE_OP(scatter_upd, 3, 1, true); +#endif + +/** + * This operation applies Max operation to specific inputs through given indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_max) +DECLARE_OP(scatter_max, 3, 1, true); +#endif + +/** + * This operation applies Min operation to specific inputs through given indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_min) +DECLARE_OP(scatter_min, 3, 1, true); +#endif + +/** + * This operation scatter "updates" elements into new output array according to + * given "indices" Expected arguments: indices: array containing elements/slices + * indexes of output array to put "updates" elements into, the rest output + * elements will be zeros updates: array containing elements to be inserted into + * output array shape: contains shape of output array + */ +#if NOT_EXCLUDED(OP_scatter_nd) +DECLARE_CUSTOM_OP(scatter_nd, 3, 1, false, 0, 0); +#endif + +/** + * This operation scatter "updates" elements into input array along given + * "indices" Expected arguments: input: array to be updated indices: array + * containing elements/slices indexes of input array to put "updates" elements + * into updates: array containing elements to be inserted into input array + */ +#if NOT_EXCLUDED(OP_scatter_nd_update) +DECLARE_OP(scatter_nd_update, 3, 1, true); +#endif + +/** + * This operation adds "updates" elements to input array along given "indices" + * Expected arguments: + * input: array to be updated + * indices: array containing elements/slices indexes of input array to add + * "updates" elements to updates: array containing elements to be interfered + * with input + */ +#if NOT_EXCLUDED(OP_scatter_add) +DECLARE_OP(scatter_nd_add, 3, 1, true); +#endif + +/** + * This operation subtract "updates" elements from input array along given + * "indices" Expected arguments: input: array to be updated indices: array + * containing elements/slices indexes of input array to subtract "updates" + * elements from updates: array containing elements to be interfered with input + */ +#if NOT_EXCLUDED(OP_scatter_sub) +DECLARE_OP(scatter_nd_sub, 3, 1, true); +#endif + +/** + * This operation takes input's shape, and returns new NDArray filled with + * specified value Expected arguments: input: N-dimensional array + * + * T args: + * 0: scalar value, used to fill NDArray + */ +#if NOT_EXCLUDED(OP_fill_as) +DECLARE_CONFIGURABLE_OP(fill_as, 1, 1, true, 1, 0); +#endif + +/** + * This operation applies element-wise rint (round to integral value) operation + */ +#if NOT_EXCLUDED(OP_rint) +DECLARE_OP(rint, 1, 1, true); +#endif + +/** + * This operation returns unique elements from input array as vector, and their + * original indices in input array Expected input: input: N-dimensional array + */ +#if NOT_EXCLUDED(OP_unique) +DECLARE_CUSTOM_OP(unique, 1, 2, false, 0, 0); +#endif + +/** + * This operation returns 3 1D arrays for given 1D array with unique element + * count and indexes input: 0 - 1D array + * + * output: + * 0 - 1D array with unique values + * 1 - 1D array with ids for values in array above + * 2 - 1D array with counts for values in array above + */ +#if NOT_EXCLUDED(OP_unique_with_counts) +DECLARE_CUSTOM_OP(unique_with_counts, 1, 3, false, 0, 0); +#endif + +/** + * This operation splits input NDArray into multiple TADs along given dimensions + * Expected arguments: + * input: N-dimensional array + * + * Int args: + * 0..: TAD axis + */ +#if NOT_EXCLUDED(OP_tear) +DECLARE_CUSTOM_OP(tear, 1, -1, false, 0, -1); +#endif + +/** + * This op does the same as tear, just uses different input format: + * @tparam T + */ +#if NOT_EXCLUDED(OP_unstack) +DECLARE_CUSTOM_OP(unstack, 1, -1, false, 0, 1); +#endif + +/** + * This operation extracts a strided (optionally) slice from a tensor, + */ +#if NOT_EXCLUDED(OP_strided_slice) +DECLARE_CUSTOM_OP(strided_slice, 1, 1, false, 0, + 5); // TODO: new op type needed. that returns VIEW +DECLARE_CUSTOM_OP(strided_slice_bp, 2, 1, false, 0, 5); +#endif + +/** + * This operation extracts a slice from a tensor. + * + */ +#if NOT_EXCLUDED(OP_slice) +DECLARE_CUSTOM_OP(slice, 1, 1, false, 0, -2); +DECLARE_CUSTOM_OP(slice_bp, 2, 1, false, 0, -2); +#endif + +/** + * This operation generate sequences. Basically from......to, with step used as + * increment. Expected arguments: start: optional scalar with starting value + * stop: optional scalar with end value + * step: optional scalar witn step value + * + * Int args: (optional) + * 0: optional scalar with starting value + * 1: optional scalar with end value + * 1: optional scalar witn step value + * + * T args: (optional) + * 0: optional scalar with starting value + * 1: optional scalar with end value + * 1: optional scalar witn step value + */ +#if NOT_EXCLUDED(OP_range) +DECLARE_CUSTOM_OP(range, -2, 1, false, -2, -2); +#endif + +/** + * This operation return one-hot encoded n-dimensional array + * Expected arguments: + * input: N-dimensional array + * + * T args: + * 0: 'on' value + * 1: 'off' value + * + * Int args: + * 0: depth + * 1: axis + */ +#if NOT_EXCLUDED(OP_onehot) +DECLARE_CUSTOM_OP(onehot, 1, 1, false, -2, -2); +#endif + +/** + * This operation calculate the confusion matrix for a + * pair of prediction and label 1-D arrays. + * Expected arguments: + * Input arrays: + * 0 - predictions: 1-D array + * 1 - labels: 1-D array + * 2 - weights : optional + * Int args: + * 0 - num_classes: optional + * + */ +#if NOT_EXCLUDED(OP_confusion_matrix) +DECLARE_CUSTOM_OP(confusion_matrix, 2, 1, false, 0, -2); +#endif + +/** + * This operation stacks a list of rank tensors into one rank-(R+1) tensor. + * Expected arguments: + * 0...: N-Dimensional arrays to stack + * + */ +#if NOT_EXCLUDED(OP_stack) +DECLARE_CUSTOM_OP(stack, -1, 1, false, 0, 0); +#endif + +/** + * This operation returns length of input array + * Expected arguments: + * input: N-dimensional array + * + * TODO: make this operation reduction, to allow TAD -> size + */ +#if NOT_EXCLUDED(OP_size) +DECLARE_CUSTOM_OP(size, 1, 1, false, 0, 0); // add DeclarableScalarOp? +#endif + +/** + * This operation returns rank of input array as scalar value. + */ +#if NOT_EXCLUDED(OP_rank) +DECLARE_CUSTOM_OP(rank, 1, 1, false, 0, 0); // ^ +#endif + +#if NOT_EXCLUDED(OP_broadcastgradientargs) +DECLARE_OP(broadcastgradientargs, 2, 2, true); +#endif + +/** + * This operation takes input's shape, and returns new NDArray filled with zeros + * Expected arguments: + * input: N-dimensional array + * + */ +#if NOT_EXCLUDED(OP_zeros_as) +DECLARE_CUSTOM_OP(zeros_as, 1, 1, false, 0, 0); +#endif + +/** + * This operation takes input's shape, and returns new NDArray filled with ones + * Expected arguments: + * input: N-dimensional array + * + */ +#if NOT_EXCLUDED(OP_ones_as) +DECLARE_CUSTOM_OP(ones_as, 1, 1, false, 0, 0); +#endif + +/** + * This operation applies element-wise pow(x, 2) to the given input + * Expected arguments: + * input: N-Dimensional array + */ +#if NOT_EXCLUDED(OP_square) +DECLARE_OP(square, 1, 1, true); +#endif + +/** + * This op calculates Hurwitz zeta function zeta(x, q) = sum_{n=0}^{inf} (q + + * n)^{-x} Implementation is based on Euler-Maclaurin summation formula + * + * Input arrays: + * x: define power {-x}, must be > 1, type float. + * q: define summand in denominator, must be > 0, type float. + * + * Output array: + * 0: corresponding values of Hurwitz zeta function + * + * Two input and one output arrays must have the same shape + */ +#if NOT_EXCLUDED(OP_zeta) +DECLARE_CONFIGURABLE_OP(zeta, 2, 1, false, 0, 0); +#endif + +/** + * This op calculates polygamma function psi^(n)(x). Implementation is based on + * serial representation written in terms of the Hurwitz zeta function: + * polygamma = (-1)^{n+1} * n! * zeta(n+1, x). + * + * Input arrays: + * 0: n - define derivative order (n+1), type integer (however currently is + * implemented as float casted to integer) 1: x - abscissa points where to + * evaluate the polygamma function, type float + * + * Output array: + * 0: values of polygamma function at corresponding x, type float + * + * Two input and one output arrays have the same shape + */ +#if NOT_EXCLUDED(OP_polygamma) +DECLARE_CONFIGURABLE_OP(polygamma, 2, 1, false, 0, 0); +#endif + +/** + * This op calculates lgamma function lgamma(x) = log(Gamma(x)) + * + * Input arrays: + * 0: x - input matrix + * + * Output array: + * 0: log of Gamma(x) + * + */ +#if NOT_EXCLUDED(OP_lgamma) +DECLARE_OP(lgamma, 1, 1, true); +#endif + +/** + * This op calculates digamma function psi(x) = derivative of log(Gamma(x)) + * + * Input arrays: + * 0: x - abscissa points where to evaluate the digamma function, type float + * + * Output array: + * 0: values of digamma function at corresponding x, type float + * + */ +#if NOT_EXCLUDED(OP_digamma) +DECLARE_CONFIGURABLE_OP(digamma, 1, 1, false, 0, 0); +#endif + +/** + * This operation takes shape as first argument, and returns new NDArray filled + * with specific scalar value. Input arrays: 0 - shape vector 1 - optional + * scalar NDArray + * + * T arguments: + * 0 - optional scalar value + * + */ +#if NOT_EXCLUDED(OP_fill) +DECLARE_CUSTOM_OP(fill, 1, 1, false, -2, 0); +#endif + +/** + * This operation splits given NDArray into chunks of specific size, along given + * dimension Input arrays: 0 - input array 1 - array of sizes 2 - optional axis + * + * Integer arguments: + * 0 - optional axis + * + */ +#if NOT_EXCLUDED(OP_split_v) +DECLARE_CUSTOM_OP(split_v, 2, -1, false, 0, -2); +#endif + +/** + * This operation splits given NDArray into chunks of specific size, along given + * dimension 0 - input array 1 - optional axis + * + * Integer arguments: + * 0 - number of splits + * 1 - optional axis + */ +#if NOT_EXCLUDED(OP_split) +DECLARE_CUSTOM_OP(split, 1, -1, false, 0, 1); +#endif + +/** + * This operation adjusts image hue by delta + * Input arrays: + * 0 - input array with rank >= 3, must have at least one dimension equal 3, + * that is dimension containing channels. 1 - optional argument, input + * scalar-array containing delta + * + * T arguments: + * 0 - optional argument, delta value + * + * Int arguments: + * 0 - optional argument, corresponds to dimension with 3 channels + */ +#if NOT_EXCLUDED(OP_adjust_hue) +DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 0, 0); +#endif + +/** + * This operation adjusts image saturation by delta + * Input arrays: + * 0 - input array with rank >= 3, must have at least one dimension equal 3, + * that is dimension containing channels. 1 - optional argument, input + * scalar-array containing saturation factor + * + * T arguments: + * 0 - optional argument, saturation factor + * + * Int arguments: + * 0 - optional argument, corresponds to dimension with 3 channels + */ +#if NOT_EXCLUDED(OP_adjust_saturation) +DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 0, 0); +#endif + +/** + * This operation adjusts image contrast by given factor ( z = (x - mean) * + * factor + mean ) Input arrays: 0 - input array with rank >= 3, must have last + * one dimension equal 3, that is dimension containing channels. 1 - optional + * argument, input scalar-array containing saturation contrast factor + * + * T arguments: + * 0 - optional argument, contrast factor + * + */ +#if NOT_EXCLUDED(OP_adjust_contrast) +DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, 0, 0); +#endif + +/** + * This operation rearranges data from depth into blocks of spatial data. This + * is the reverse transformation of space_to_depth op. This op output is a copy + * of the input tensor where values from the depth dimension are moved in + * spatial blocks to the height and width dimensions. Int attr 0 indicates the + * input block size and how the data is moved. Input: 0 - 4D tensor on given + * type Output: 0 - 4D tensor of given type and proper shape + * + * Int arguments: + * 0 - block size + * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels + * } 1 ("NCHW"): shape{ batch, channels, height, width } 2 ("NCHW_VECT_C"): int8 + * shape{ batch, channels / 4, height, width, 4 } optional (default 0) + */ +#if NOT_EXCLUDED(OP_depth_to_space) +DECLARE_CUSTOM_OP(depth_to_space, 1, 1, false, 0, -1); +#endif + +/** + * This operation rearranges blocks of spatial data, into depth.This op output + * is a copy of the input tensor where values from the height and width + * dimensions are moved to the depth dimension. Int attr 0 indicates the input + * block size. + * + * Input: + * - 4D tensor of given type + * Output: + * - 4D tensor + * + * Int arguments: + * 0 - block size + * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels + * } 1 ("NCHW"): shape{ batch, channels, height, width } 2 ("NCHW_VECT_C"): int8 + * shape{ batch, channels / 4, height, width, 4 } optional (default 0) + * + */ +#if NOT_EXCLUDED(OP_space_to_depth) +DECLARE_CUSTOM_OP(space_to_depth, 1, 1, false, 0, -1); +#endif + +/** + * This op calculates cross-product between input arguments + * Input arguments + * 0 - vector or tensor A + * 1 - vector or tensor B + */ +#if NOT_EXCLUDED(OP_cross) +DECLARE_OP(cross, 2, 1, false); +#endif + +/** + * Zero-pads and then rearranges (permutes) blocks of spatial data into batch. + * More specifically, this op outputs a copy of the input tensor where values + * from the height and width dimensions are moved to the batch dimension. After + * the zero-padding, both height and width of the input must be divisible by the + * block size. + * + * Inputs: + * 0 - input tensor + * 1 - 2D paddings tensor (shape {M, 2}) + * + * Output: + * - result tensor + * + * Int args: + * 0 - block size (M) + * + */ +#if NOT_EXCLUDED(OP_space_to_batch) +DECLARE_CUSTOM_OP(space_to_batch, 2, 1, false, 0, 1); +#endif + +/* + * This operation divides "spatial" dimensions [1, ..., M] of the input into a + * grid of blocks of shape block_shape, and interleaves these blocks with the + * "batch" dimension (0) such that in the output, the spatial dimensions [1, + * ..., M] correspond to the position within the grid, and the batch dimension + * combines both the position within a spatial block and the original batch + * position. Prior to division into blocks, the spatial dimensions of the input + * are optionally zero padded according to paddings. + * + * Inputs: + * 0 - input (N-D tensor) + * 1 - block_shape - int 1D tensor with M length + * 2 - paddings - int 2D tensor with shape {M, 2} + * + * Output: + * - N-D tensor with the same type as input 0. + * + * */ +#if NOT_EXCLUDED(OP_space_to_batch_nd) +DECLARE_CUSTOM_OP(space_to_batch_nd, 3, 1, false, 0, 0); +#endif + +/** + * + * + */ +#if NOT_EXCLUDED(OP_batch_to_space) +DECLARE_CUSTOM_OP(batch_to_space, 2, 1, false, 0, 1); +#endif +#if NOT_EXCLUDED(OP_batch_to_space_nd) +DECLARE_CUSTOM_OP(batch_to_space_nd, 3, 1, false, 0, 0); +#endif + +/** + * top_k operation returns a vector of k top values for + * given NDArray as tensor with default boolean (true) + * as sort for result index array + * will be sorted by the values in descending order. + * The first parameter is a NDArray for working. + * The second is k (default 1) - optional + * The third is boolean value(default is true) (0 - as is, 1 - sorted by value) + * optional + */ +#if NOT_EXCLUDED(OP_top_k) +DECLARE_CUSTOM_OP(top_k, 1, 2, false, 0, -1); +#endif + +/** + * in_top_k operation returns a vector of k boolean values for + * given NDArray as 2D matrix of predicted in the NDArray k top values + * The first parameter is a NDArray of predicted values (2d array). + * The second is NDArray as vector of indeces k top values will be search. + * The third is k + */ +#if NOT_EXCLUDED(OP_in_top_k) +DECLARE_CUSTOM_OP(in_top_k, 2, 1, true, 1, 1); +#endif + +/** + * moments operation calculate a mean and variation for given NDArray + * with reduce a result according to axis array given. + * For full axis the result is both mean and variance of all members in array. + * Otherwise there are two NDArrays with means and variances for + * Axes can be put as the second NDArray or as int vector. + * + * the optional flag "keep_dims" can be set as T param + */ +#if NOT_EXCLUDED(OP_moments) +DECLARE_CUSTOM_OP(moments, 1, 2, false, 0, -2); +#endif + +/** + * embedding_lookup - search for submatrices in given matrix and retunts them + * accordingly to index array given. + */ +#if NOT_EXCLUDED(OP_embedding_lookup) +DECLARE_CUSTOM_OP(embedding_lookup, 2, 1, false, 0, 1); +#endif + +/** + * dynamic_partition - partition a input tensor onto num_partitions + * accordingly to index array given. + * + * the first param - NDArray to be partitioned. + * the second param - index array + * the third param (integer param) - num or partitions. + * + * returns a num of NDArrays as output + */ +#if NOT_EXCLUDED(OP_dynamic_partition) +DECLARE_CUSTOM_OP(dynamic_partition, 2, 1, false, 0, 1); +#endif + +#if NOT_EXCLUDED(OP_dynamic_partition_bp) +DECLARE_CUSTOM_OP(dynamic_partition_bp, 3, 2, false, 0, 1); +#endif + +/** + * dynamic_stitch - merge partitions from the second param a input tensor + * into a single tensor accordingly to index array given. + * + * the first param - index array + * the second params - tensors to be merged + * + * returns a num of NDArrays as output + * + * the operation is inversion od dynamic_partition + */ +#if NOT_EXCLUDED(OP_dynamic_stitch) +DECLARE_CUSTOM_OP(dynamic_stitch, 2, 1, false, 0, 0); +#endif + +/** + * zero_fraction op. + * compute a fraction of zeros in given array + * + * input param - an array (tensor) + * output value - a real number with given type (e.g. float or double) + */ +#if NOT_EXCLUDED(OP_zero_fraction) +DECLARE_CUSTOM_OP(zero_fraction, 1, 1, false, 0, 0); +#endif + +/** + * xw_plus_b op. + * multiply two first matrices and add third vector to each row of result + * + * input params: + * - 2D matrix NxM + * - 2D matrix MxN + * - 1D vector with N elements + * output value - 2D matrix NxN as multiply of matrixes and add vector + * Int args: + * 0 - optional switcher of weights format, if int arg == 1 - mkldnn, else + * mmul + */ +#if NOT_EXCLUDED(OP_xw_plus_b) +DECLARE_CUSTOM_OP(xw_plus_b, 3, 1, false, 0, 0); +DECLARE_CUSTOM_OP(xw_plus_b_bp, 4, 3, false, 0, 0); +#endif + +/** + * This operation is missed due it simplicy. + * Input and output params are the same after operation. + * Input - NDArray, output - NDArray with the same shape. + */ +#if NOT_EXCLUDED(OP_stop_gradient) +DECLARE_OP(stop_gradient, 1, 1, true); +#endif + +#if NOT_EXCLUDED(OP_parallel_stack) +DECLARE_CUSTOM_OP(parallel_stack, -1, 1, false, 0, 0); +#endif + +/** + * normalize_moments operation normalize already calculated mean and variation + * accordingly to shift and count. + * input params: + * - count of data + * - tensor with mean + * - tensor with variance (the same shape as before) + * + * - optional floating point param shift. + * + * returns a normalized pair mean and variance with the same shapes as input + */ +#if NOT_EXCLUDED(OP_normalize_moments) +DECLARE_CUSTOM_OP(normalize_moments, 3, 2, false, 1, 0); +#endif + +/** + * sufficient_statistics operation return calculated mean and variation with + * data count. this operation is invert for moments accordingly to shift and + * count. input params: + * - input tensor + * - axes vector + * + * + * - optional floating point param shift. + * - optional int (as bool) keep_dimension + * + * returns four tensors: + * - scalar tensor (data count) + * - sum elements of input (accross axises) + * - sum of squares of input (accross axises) + * - shift (if was given by input floating param) + */ +#if NOT_EXCLUDED(OP_sufficient_statistics) +DECLARE_CUSTOM_OP(sufficient_statistics, 2, 3, false, 0, 0); +#endif + +/** + * This op calculates weighted logarithmic loss of input + * Input arguments + * 0 - target + * 1 - input + * 2 - weights (scalar or vector with same as last dimension) + * + * return value - a tensor with the same shape as target or input + */ +#if NOT_EXCLUDED(OP_weighted_cross_entropy_with_logits) +DECLARE_OP(weighted_cross_entropy_with_logits, 3, 1, true); +#endif + +/** + * This op calculates dropout of input + * Input arguments + * 0 - input tensor + * 1 - noise_shape - (vector with shape to reduce) - optional + * + * int parameter - seed for random numbers + * T parameter - probability (should be between 0 and 1) + * return value - a tensor with the same shape as target or input + */ +#if NOT_EXCLUDED(OP_dropout) +DECLARE_CONFIGURABLE_OP(dropout, 1, 1, true, 1, 1); +#endif +#if NOT_EXCLUDED(OP_dropout_bp) +DECLARE_CONFIGURABLE_OP(dropout_bp, 2, 1, false, 1, 1); +#endif + +/* Calculates alpha weighted dropout + T params: + 0 - drop probability + 1 - alpha value + 2 - alpha' value + 3 - beta value + */ +#if NOT_EXCLUDED(OP_alpha_dropout_bp) +DECLARE_CONFIGURABLE_OP(alpha_dropout_bp, 2, 1, false, 4, 1); +#endif + +/** + * bincount operation return a vector with element counted. + * + * input params: + * - input tensor - only int part are accepted + * - weights - the same shape tensor with integer weights for element + * (optional) default weight - 1,1,1..,1 for all values in the tensor + * + * optional ints: + * - min_length - zero or greater + * - max_length - between min_length and max(input) + 1 + * + * returns four tensors: + * - vector tensor with length to min(max_len, max(input) + 1) with count + * of values in indexed place + * + */ +#if NOT_EXCLUDED(OP_bincount) +DECLARE_CUSTOM_OP(bincount, 1, 1, false, 0, 0); +#endif + +/** + * broadcast_dynamic_shape op. + * + * input params: + * 0 - the first shape (vector with shape) + * 1 - the second shape (vector with shape) + * + * return value: + * vector with broadcasted shape + */ +#if NOT_EXCLUDED(OP_broadcast_dynamic_shape) +DECLARE_CUSTOM_OP(broadcast_dynamic_shape, 2, 1, false, 0, 0); +#endif + +/** + * matrix_determinant op. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: *) with determinant for all + * M x M matricies + */ +#if NOT_EXCLUDED(OP_matrix_determinant) +DECLARE_CUSTOM_OP(matrix_determinant, 1, 1, false, 0, 0); +#endif + +/** + * log_matrix_determinant op. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: *) with log determinant for all + * M x M matricies + */ + +#if NOT_EXCLUDED(OP_log_matrix_determinant) +DECLARE_CUSTOM_OP(log_matrix_determinant, 1, 1, false, 0, 0); +#endif + +/** + * logdet op. Logarithm of the determinant of hermitian positive matricies. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: *) with log determinant for all + * M x M matricies + */ + +#if NOT_EXCLUDED(OP_logdet) +DECLARE_CUSTOM_OP(logdet, 1, 1, false, 0, 0); +#endif + +/** + * matrix_solve_ls op (lstsq) - solves one or more linear least-squares + * problems. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * float args: + * 0 - l2_regularizer (default 0. and only for 0 implemented) + * + * boolean args: + * 0 - fast - default is true (optional) - use Cholesky decomposition instead + * QR decomposition of matricies. + * + * return value: + * tensor with dimension (x * y * z * ::: * N * K) with solutions + * + */ +#if NOT_EXCLUDED(OP_lstsq) +DECLARE_CUSTOM_OP(lstsq, 2, 1, false, 0, 0); +#endif + +/* solve_ls - analog of lstsq op with another solution approach + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * float args: + * 0 - l2_regularizer (default 0. and only for 0 implemented) + * + * boolean args: + * 0 - fast - default is true (optional) - use Cholesky decomposition instead + * QR decomposition of matricies. + * + * return value: + * tensor with dimension (x * y * z * ::: * N * K) with solutions + * + * Note: if fast is false - then l2_regularizer arg is ignored and used lstsq + * method due QR decomposition + * */ +#if NOT_EXCLUDED(OP_solve_ls) +DECLARE_CUSTOM_OP(solve_ls, 2, 1, false, 0, 0); +#endif + +/** + * matrix_inverse op. - make inverse for all 2D square matricies found in the + * input tensor + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: * M * M) with inverse M x M + * matricies in it + */ +#if NOT_EXCLUDED(OP_matrix_inverse) +DECLARE_OP(matrix_inverse, 1, 1, true); +#endif + +/** + * triangular_solve op. - reverse Gaussian method for solve systems of linear + * equations. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * boolean args: + * 0 - lower - default is true (optional) - left part is lower triangular + * matrix 1 - adjoint - default is false (optional) - indicate input matrix or + * its adjoint (hermitian addition) should be used + * + * return value: + * tensor with dimension (x * y * z * ::: * M * K) with solutions + * + */ +#if NOT_EXCLUDED(OP_triangular_solve) +DECLARE_CUSTOM_OP(triangular_solve, 2, 1, false, 0, 0); +#endif + +/** + * solve op. - solve systems of linear equations - general method. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * boolean args: + * 0 - adjoint - default is false (optional) - indicate input matrix or its + * adjoint (hermitian addition) should be used + * + * return value: + * tensor with dimension (x * y * z * ::: * M * K) with solutions + * + */ +#if NOT_EXCLUDED(OP_solve) +DECLARE_CUSTOM_OP(solve, 2, 1, true, 0, 0); +#endif + +/** + * lu op. - make LUP decomposition of given batch of 2D square matricies + * + * input params: + * 0 - float tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * 0 - float tensor with dimension (x * y * z * ::: * M * M) with LU M x M + * matricies in it 1 - int (32 or 64) batched vector of permutations with length + * M - shape (x * y * z * ::: * M) + * + * int argument: + * 0 - data type of output permutaion vector (int32 or int64), optional, + * default INT32 + */ + +#if NOT_EXCLUDED(OP_matrix_inverse) +DECLARE_CUSTOM_OP(lu, 1, 2, false, 0, 0); +#endif + +/** + * sequence_mask op. - make mask for given tensor filled by (j > x[i_1, + * i_2,...,i_n]) -> z[i_1, i_2,...,i_n,j] + * + * input params: + * 0 - the ND-tensor filled by integer-like values + * + * optional int param - maxlength (maxlength >= max(x)). By default maxlength = + * max(x). return value: (N+1)D tensor filled by 0 and 1 accordingly the mask + */ +#if NOT_EXCLUDED(OP_sequence_mask) +DECLARE_CUSTOM_OP(sequence_mask, 1, 1, false, 0, 0); +#endif +/** + * segment_max op. - make a tensor filled by max values according to index + * tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with max values according to indices sets. + */ + +#if NOT_EXCLUDED(OP_segment_max) +DECLARE_CUSTOM_OP(segment_max, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_segment_max_bp) +DECLARE_CUSTOM_OP(segment_max_bp, 3, 2, false, 0, 0); +#endif + +/** + * segment_min op. - make a tensor filled by min values according to index + * tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with min values according to indices sets. + */ +#if NOT_EXCLUDED(OP_segment_min) +DECLARE_CUSTOM_OP(segment_min, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_segment_min_bp) +DECLARE_CUSTOM_OP(segment_min_bp, 3, 2, false, 0, 0); +#endif + +/** + * segment_sum op. - make a tensor filled by sum of values according to index + * tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with sum of values according to indices sets. + */ +#if NOT_EXCLUDED(OP_segment_sum) +DECLARE_CUSTOM_OP(segment_sum, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_segment_sum_bp) +DECLARE_CUSTOM_OP(segment_sum_bp, 3, 2, false, 0, 0); +#endif + +/** + * segment_prod op. - make a tensor filled by product of values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with product of values according to indices sets. + */ +#if NOT_EXCLUDED(OP_segment_prod) +DECLARE_CUSTOM_OP(segment_prod, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_segment_prod_bp) +DECLARE_CUSTOM_OP(segment_prod_bp, 3, 2, false, 0, 0); +#endif +/** + * segment_mean op. - make a tensor filled by average of values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with average of values according to indices sets. + */ +#if NOT_EXCLUDED(OP_segment_mean) +DECLARE_CUSTOM_OP(segment_mean, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_segment_mean_bp) +DECLARE_CUSTOM_OP(segment_mean_bp, 3, 2, false, 0, 0); +#endif + +/** + * unsorted_segment_max op. - make a tensor filled by max values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with max values according to indices sets. + */ +#if NOT_EXCLUDED(OP_unsorted_segment_max) +DECLARE_CUSTOM_OP(unsorted_segment_max, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_unsorted_segment_max_bp) +DECLARE_CUSTOM_OP(unsorted_segment_max_bp, 3, 2, false, 0, 1); +#endif + +/** + * unsorted_segment_min op. - make a tensor filled by min values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with min values according to indices sets. + */ +#if NOT_EXCLUDED(OP_unsorted_segment_min_bp) +DECLARE_CUSTOM_OP(unsorted_segment_min, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_unsorted_segment_min_bp) +DECLARE_CUSTOM_OP(unsorted_segment_min_bp, 3, 2, false, 0, 1); +#endif + +/** + * unsorted_segment_sum op. - make a tensor filled by sum of values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with sum of values according to indices sets. + */ +#if NOT_EXCLUDED(OP_unsorted_segment_sum) +DECLARE_CUSTOM_OP(unsorted_segment_sum, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_unsorted_segment_sum_bp) +DECLARE_CUSTOM_OP(unsorted_segment_sum_bp, 3, 2, false, 0, 1); +#endif + +/** + * unsorted_segment_prod op. - make a tensor filled by product of values + * according to index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with product of values according to indices sets. + */ +#if NOT_EXCLUDED(OP_unsorted_segment_prod) +DECLARE_CUSTOM_OP(unsorted_segment_prod, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_unsorted_segment_prod_bp) +DECLARE_CUSTOM_OP(unsorted_segment_prod_bp, 3, 2, false, 0, 1); +#endif + +/** + * unsorted_segment_mean op. - make a tensor filled by average of values + * according to index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with average of values according to indices sets. + */ +#if NOT_EXCLUDED(OP_unsorted_segment_mean) +DECLARE_CUSTOM_OP(unsorted_segment_mean, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_unsorted_segment_mean_bp) +DECLARE_CUSTOM_OP(unsorted_segment_mean_bp, 3, 2, false, 0, 1); +#endif + +/** + * unsorted_segment_sqrt_n op. - computes the sum along segments of a tensor + * divided by the sqrt(N). + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with average of values according to indices sets. + */ +#if NOT_EXCLUDED(OP_unsorted_segment_sqrt) +DECLARE_CUSTOM_OP(unsorted_segment_sqrt_n, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_unsorted_segment_sqrt_n_bp) +DECLARE_CUSTOM_OP(unsorted_segment_sqrt_n_bp, 3, 2, false, 0, 1); +#endif + +/** + * extract_image_patches op - Extract patches from images and put them in the + * "depth" output dimension. + * + * input params: + * 0 - images tensor (4D) + * + * int params: + * 0 - ksize_rows + * 1 - ksize_cols + * 2 - strides_rows + * 3 - strides_cols + * 4 - rates_rows + * 5 - rates_cols + * 6 - padding_type - 0 - equiv 'VALID', 1 - 'SAME' + */ +#if NOT_EXCLUDED(OP_extract_image_patches) +DECLARE_CUSTOM_OP(extract_image_patches, 1, 1, false, 0, 7); +#endif + +/** + * draw_bounding_boxes op - modified input image with given colors exept given + * boxes. + * + * input params: + * 0 - images tensor (4D) with shape {batch, width, height, channels}, where + * channes is 1 (BW image), 3 (RGB) or 4 (RGBA) 1 - boxes tensor (3D) with shape + * {batch, number_of_boxes, 4} where last dimension encoded as (y_min, x_min, + * y_max, x_max), all values in between 0. and 1. 2 - colours tensor (2D) with + * shape {number_of_boxes, channels} -- bordering color set (palette) + * + * output: + * 0 - 4D tensor with same shape as images (input 0) + */ +#if NOT_EXCLUDED(OP_draw_bounding_boxes) +DECLARE_OP(draw_bounding_boxes, 3, 1, true); +#endif + +/** + * roll - op porting from numpy + * (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html) + * + * input params: + * 0 - NDArray + * + * int params: + * 0 - shift + * 1 - axe 1 + * 2 - axe 2 + * ... + * N - axe N + * + * All axes are optional and should be between 0 and input->rankOf(). Of + * course, all axes can be repeated. + * + * output: + * 0 - NDArray with the same shape as input. + */ +#if NOT_EXCLUDED(OP_roll) +DECLARE_CONFIGURABLE_OP(roll, 1, 1, true, 0, 1); +#endif + +/** + * lin_space - op porting from TF + * (https://www.tensorflow.org/api_docs/python/tf/lin_space) + * + * optional input params: + * 0 - startVal - NDArray scalar (float point) + * 1 - finishVal - NDArray scalar (float point) + * 2 - numOfElements - NDArray scalar (integer) + * Optional: + * T args + * 0 - startVal + * 1 - finishVal] + * 2 - numOfElements + * output: + * 0 - 1D NDArray with the same type as input and length as given with + * numOfElements param. + */ +#if NOT_EXCLUDED(OP_lin_space) +DECLARE_CUSTOM_OP(lin_space, 0, 1, false, 0, 0); +#endif + +/** + * reduction_sum - tf.reduction_sum operation + * + * input params: + * 0 - NDArray + * + * T_ARG param (optional): + * 0 - keep_dims != 0. + * + * int params (optional): + * 0 - axe 1 + * 1 - axe 2 + * ... + * N-1 axe N + * + * All axes are optional and should be between 0 and input->rankOf() - 1 + * + * output: + * 0 - NDArray with reduces shape accordingly to axes (the scalar in default + * case). + */ +#if NOT_EXCLUDED(OP_reduce_sum) +DECLARE_CUSTOM_OP(reduce_sum, 1, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_reduce_sum_bp) +DECLARE_CUSTOM_OP(reduce_sum_bp, 2, 1, false, 0, 0); +#endif + +/** + * reduction_prod - tf.reduction_prod operation + * + * input params: + * 0 - NDArray + * + * T_ARG param (optional): + * 0 - keep_dims != 0. + * + * int params (optional): + * 0 - axe 1 + * 1 - axe 2 + * ... + * N-1 axe N + * + * All axes are optional and should be between 0 and input->rankOf() - 1 + * + * output: + * 0 - NDArray with reduces shape accordingly to axes (the scalar in default + * case). + */ +#if NOT_EXCLUDED(OP_reduce_prod) +DECLARE_CUSTOM_OP(reduce_prod, 1, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_reduce_prod_bp) +DECLARE_CUSTOM_OP(reduce_prod_bp, 2, 1, false, 0, 0); +#endif + +/** + * This op calculates min of elements along given dimensions + * + * input array: + * x: tensor to calculate mins for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate min along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated mins + */ +#if NOT_EXCLUDED(OP_reduce_min) +DECLARE_CUSTOM_OP(reduce_min, 1, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_reduce_min_bp) +DECLARE_CUSTOM_OP(reduce_min_bp, 2, 1, false, 0, 0); +#endif + +/** + * This op calculates max of elements along given dimensions + * + * input array: + * x: tensor to calculate maxes for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate max along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated maxes + */ +#if NOT_EXCLUDED(OP_reduce_max) +DECLARE_CUSTOM_OP(reduce_max, 1, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_reduce_max_bp) +DECLARE_CUSTOM_OP(reduce_max_bp, 2, 1, false, 0, 0); +#endif + +/** + * This op calculates norm1 of elements along given dimensions + * + * input array: + * x: tensor to calculate norm1 for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate norm1 along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm1 + */ +#if NOT_EXCLUDED(OP_reduce_norm1) +DECLARE_CUSTOM_OP(reduce_norm1, 1, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_reduce_norm1_bp) +DECLARE_CUSTOM_OP(reduce_norm1_bp, 2, 1, false, 0, 0); +#endif + +/** + * This op calculates norm2 of elements along given dimensions + * + * input array: + * x: tensor to calculate norm2 for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate norm2 along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm2 + */ +#if NOT_EXCLUDED(OP_reduce_norm2) +DECLARE_CUSTOM_OP(reduce_norm2, 1, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_reduce_norm2_bp) +DECLARE_CUSTOM_OP(reduce_norm2_bp, 2, 1, false, 0, 0); +#endif + +/** + * This op calculates squared norm of elements along given dimensions + * + * input array: + * x: tensor to calculate squared norm for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate squared norm along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm + */ +#if NOT_EXCLUDED(OP_reduce_sqnorm) +DECLARE_CUSTOM_OP(reduce_sqnorm, 1, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_reduce_sqnorm_bp) +DECLARE_CUSTOM_OP(reduce_sqnorm_bp, 2, 1, false, 0, 0); +#endif + +/** + * This op calculates norm max of elements along given dimensions + * + * input array: + * x: tensor to calculate norm max for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate norm max along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm + */ +#if NOT_EXCLUDED(OP_reduce_norm_max) +DECLARE_CUSTOM_OP(reduce_norm_max, 1, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_reduce_norm_max_bp) +DECLARE_CUSTOM_OP(reduce_norm_max_bp, 2, 1, false, 0, 0); +#endif + +/** + * This op calculates mean of elements along given dimensions + * + * input array: + * x: tensor to calculate mean for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate mean along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated means + */ +#if NOT_EXCLUDED(OP_reduce_mean) +DECLARE_CUSTOM_OP(reduce_mean, 1, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_reduce_mean_bp) +DECLARE_CUSTOM_OP(reduce_mean_bp, 2, 1, false, 0, 0) +#endif +/** + * This op calculates sample variance of elements along given dimensions + * + * input array: + * x: tensor to calculate mean for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero biasCorrected - if non zero, then bias correction will + * be applied, default value is zero + * + * int arguments: + * list of integers - dimensions to calculate mean along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated means + */ +DECLARE_CUSTOM_OP(reduce_variance, 1, 1, false, 0, 0); +DECLARE_CUSTOM_OP(reduce_variance_bp, 2, 1, false, 0, 0) + +/** + * This op calculates sample standard deviation of elements along given + * dimensions + * + * input array: + * x: tensor to calculate mean for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero biasCorrected - if non zero, then bias correction will + * be applied, default value is zero + * + * int arguments: + * list of integers - dimensions to calculate mean along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated means + */ +DECLARE_CUSTOM_OP(reduce_stdev, 1, 1, false, 0, 0); +DECLARE_CUSTOM_OP(reduce_stdev_bp, 2, 1, false, 0, 0) + +/** + * This op calculates backprop dot for two tensors along given dimensions + * + * input array: + * x: tensor to calculate dot for + * y: tensor to calculate dot for + * z: tensor with gradient output of the FF dot for x and y + * + * int arguments: + * list of integers - dimensions to calculate dot along, + * default corresponds to empty list in which case calculation + * is performed for all dimensions and scalar is returned. + * + * output array: + * the tensor with calculated backproped dots + * + */ + +#if NOT_EXCLUDED(OP_reduce_dot_bp) +DECLARE_CUSTOM_OP(reduce_dot_bp, 3, 2, false, 0, 0); +#endif +/** + * reduce_logsumexp - tf.reduce_logsumexe operation + * + * input params: + * 0 - NDArray (input) + * 1 - 1D NDArray (axis) (optional) - integer array + * + * T_ARG param (optional): + * 0 - keep_dims != 0. + * + * int params (optional): + * 0 - axe 1 + * 1 - axe 2 + * ... + * N-1 axe N + * + * CAUTION: All axes are optional and should be between 0 and input->rankOf() - + * 1 and put either with second param or as integers but not both + * + * output: + * 0 - NDArray with reduces shape accordingly to axes (the scalar in default + * case). + */ +#if NOT_EXCLUDED(OP_reduce_logsumexp) +DECLARE_CUSTOM_OP(reduce_logsumexp, 1, 1, false, 0, 0); +#endif + +/** + * This op make bilinear or nearest neighbor interpolated resize for given + * tensor + * + * input array: + * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels) numeric type + * 1 - 2D-Tensor with shape (num_boxes, 4) float type + * 2 - 1D-Tensor with shape (num_boxes) int type + * 3 - 1D-Tensor with 2 values (newWidth, newHeight) (optional) int type + * + * float arguments (optional) + * 0 - exprapolation_value (optional) default 0.f + * + * int arguments: (optional) + * 0 - mode (default 0 - bilinear interpolation) + * + * output array: + * the 4D-Tensor with resized to crop_size images given - float type + */ +#if NOT_EXCLUDED(OP_crop_and_resize) +DECLARE_CUSTOM_OP(crop_and_resize, 4, 1, false, -1, -1); +#endif + +/** + * This op make bilinear interpolated resize for given tensor + * + * input array: + * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels) + * 1 - 1D-Tensor with 2 values (newWidth, newHeight) (optional) + * + * int arguments: (optional) + * 0 - new width + * 1 - new height + * + * output array: + * the 4D-Tensor with calculated backproped dots + * + * CAUTION: either size tensor or a pair of int params should be provided. + */ + +#if NOT_EXCLUDED(OP_resize_bilinear) +DECLARE_CUSTOM_OP(resize_bilinear, 1, 1, false, 0, -2); +#endif + +/** + * This op make nearest neighbor interpolated resize for given tensor + * + * input array: + * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels) + * 1 - 1D-Tensor with 2 values (newWidth, newHeight) (optional) + * + * int arguments: (optional) + * 0 - new width + * 1 - new height + * + * output array: + * the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, + * channels}) + * + * CAUTION: either size tensor or a pair of int params should be provided. + */ + +#if NOT_EXCLUDED(OP_resize_nearest_neighbor) +DECLARE_CUSTOM_OP(resize_nearest_neighbor, 1, 1, false, 0, -2); +#endif + +/** + * This op make bicubic interpolated resize for given tensor + * + * input array: + * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels) + * 1 - 1D-Tensor with 2 values (newWidth, newHeight) + * + * output array: + * the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, + * channels}) + * + */ +#if NOT_EXCLUDED(OP_resize_bicubic) +DECLARE_CUSTOM_OP(resize_bicubic, 1, 1, false, 0, -2); +#endif + +/** + * This op make area interpolated resize (as OpenCV INTER_AREA algorithm) for + * given tensor + * + * input array: + * 0 - images - 4D-Tensor with shape (batch, sizeX, sizeY, channels) + * 1 - size - 1D-Tensor with 2 values (newWidth, newHeight) (if missing a + * pair of integer args should be provided). + * + * int args: - proveded only when size tensor is missing + * 0 - new height + * 1 - new width + * boolean args: + * 0 - align_corners - optional (default is false) + * + * output array: + * the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, + * channels}) + * + */ +#if NOT_EXCLUDED(OP_resize_area) +DECLARE_CUSTOM_OP(resize_area, 1, 1, false, 0, -2); +#endif + +/** + * This op make interpolated resize for given tensor with given algorithm. + * Supported algorithms are bilinear, bicubic, nearest_neighbor. + * Need to implement to full compatibility with TF: lanczos5, gaussian, area and + * mitchellcubic + * + * input array: + * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels) + * 1 - 1D-Tensor with 2 values (newWidth, newHeight) + * + * optional int args: + * 0 - algorithm - bilinear by default + * optional bool args: + * 0 - preserve_aspect_ratio - default False + * 1 - antialias - default False + * + * output array: + * the 4D-Tensor with resized by given algorithm image (shape is {batch, + * newWidth, newHeight, channels}) + * + */ + +#if NOT_EXCLUDED(OP_image_resize) +DECLARE_CUSTOM_OP(image_resize, 2, 1, false, 0, 0); +#endif + +/** + * Copy a tensor setting everything outside a central band in each innermost + * matrix + * + * input array: + * x: given tensor with shape {..., M, N} - as vector (matrix) of matricies + * MxN + * + * int arguments: + * lower band + * upper band + * + * output array: + * matrix with given bands between lower and upper diagonals + * + */ + +#if NOT_EXCLUDED(OP_matrix_band_part) +DECLARE_CONFIGURABLE_OP(matrix_band_part, 1, 1, true, 0, 2); +#endif + +#if NOT_EXCLUDED(OP_Assert) +DECLARE_OP(Assert, 1, 1, false); +#endif + +/** + * image.non_max_suppression ops. + * input: + * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type + * 1 - scales - 1D-tensor with shape (num_boxes) by float type + * 2 - output_size - 0D-tensor by int type (optional) + * float args: + * 0 - overlap_threshold - threshold value for overlap checks (optional, by + * default 0.5) 1 - score_threshold - the threshold for deciding when to remove + * boxes based on score (optional, by default -inf) int args: 0 - output_size - + * as arg 2 used for same target. Eigher this or arg 2 should be provided. + * + * output: + * - vector with size M, where M <= output_size by int type + * + * */ +#if NOT_EXCLUDED(OP_image_non_max_suppression) +DECLARE_CUSTOM_OP(non_max_suppression, 2, 1, false, 0, 0); +#endif +#if NOT_EXCLUDED(OP_image_non_max_suppression_v3) +DECLARE_CUSTOM_OP(non_max_suppression_v3, 2, 1, false, 0, 0); +#endif + +/* + * image.non_max_suppression_overlaps op. + * input: + * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type + * 1 - scales - 1D-tensor with shape (num_boxes) by float type + * 2 - output_size - 0D-tensor by int type (optional) + * float args: + * 0 - overlap_threshold - threshold value for overlap checks (optional, by + * default 0.5) 1 - score_threshold - the threshold for deciding when to remove + * boxes based on score (optional, by default -inf) int args: 0 - output_size - + * as arg 2 used for same target. Eigher this or arg 2 should be provided. + * + * output: + * 0 - 1D integer tensor with shape [M], epresenting the selected indices + * from the overlaps tensor, where M <= max_output_size + * */ +#if NOT_EXCLUDED(OP_image_non_max_suppression_overlaps) +DECLARE_CUSTOM_OP(non_max_suppression_overlaps, 2, 1, false, 0, 0); +#endif + +/* + * cholesky op - decomposite positive square symetric matrix (or matricies when + * rank > 2). input: 0 - matricies - tensor with shape (..., N, N) by float type + * + * output - lower triangular matrix (matricies when rank > 2) with the same + * shape as input. + * */ +#if NOT_EXCLUDED(OP_cholesky) +DECLARE_OP(cholesky, 1, 1, true); +#endif +/* + * nth_element - apply nth_element for last dimension of input tensor + * input array: + * 0 - input array + * 1 - scalar tensor with n for operation. n should be less than last + * dimension + * + * output: + * 0 - NDArray with the same shape as input + */ +#if NOT_EXCLUDED(OP_nth_element) +DECLARE_CUSTOM_OP(nth_element, 2, 1, false, 0, 0); +#endif + +/** + * This op checks for Inf/NaN values within input array, and throws exception if + * there's at least one + */ +#if NOT_EXCLUDED(OP_check_numerics) +DECLARE_CUSTOM_OP(check_numerics, 2, 1, true, 0, 0); +#endif +/** + * fake_quant_with_min_max_vals - tf.quantization.fake_quant_with_min_max_vars + * + * input params: + * 0 - NDArray (input) + * 1 - 0D Tensor - min value + * 2 - 0D Tensor - max value + * + * int params (optional): + * 0 - num_bits (allowed interval [2, 16], default 8) + * 1 - narrow_range (default False) + * + * output: + * 0 - NDArray with the same shape as input + */ +#if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars) +DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars, 3, 1, true, 0, -2); +#endif + +/** + * fake_quant_with_min_max_vals_per_channel - + * tf.quantization.fake_quant_with_min_max_vars_per_channel + * + * input params: + * 0 - NDArray (input) - at least 2D. + * 1 - 1D Tensor - min values (min length equals to last dim of input) + * 2 - 1D Tensor - max value (length equals to min) + * + * int params (optional): + * 0 - num_bits (allowed interval [2, 16], default 8) + * 1 - narrow_range (default False) + * + * output: + * 0 - NDArray with the same shape as input + */ +#if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel) +DECLARE_CONFIGURABLE_OP(fake_quant_with_min_max_vars_per_channel, 3, 1, true, 0, + -2); +#endif + +/** + * compare_and_bitpack - compare with greater and pack result with uint8 + * + * input params: + * 0 - NDArray (input) + * 1 - 0D Tensor - threshold + * + * + * output: + * 0 - NDArray with the same shape as input and type uint8 + */ +#if NOT_EXCLUDED(OP_compare_and_bitpack) +DECLARE_CUSTOM_OP(compare_and_bitpack, 2, 1, false, 0, 0); +#endif +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/headers/random.h b/libnd4j/include/ops/declarable/headers/random.h index 367a41995e05..d478e2b0daba 100644 --- a/libnd4j/include/ops/declarable/headers/random.h +++ b/libnd4j/include/ops/declarable/headers/random.h @@ -25,78 +25,79 @@ #include namespace sd { - namespace ops { - #if NOT_EXCLUDED(OP_set_seed) - DECLARE_CUSTOM_OP(set_seed, -2, 1, false, 0, -2); - #endif +namespace ops { +#if NOT_EXCLUDED(OP_set_seed) +DECLARE_CUSTOM_OP(set_seed, -2, 1, false, 0, -2); +#endif - #if NOT_EXCLUDED(OP_get_seed) - DECLARE_CUSTOM_OP(get_seed, -2, 1, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_get_seed) +DECLARE_CUSTOM_OP(get_seed, -2, 1, false, 0, 0); +#endif - /* - * random_uniform distribution for types int32,int64, float16, float and double - * by default dtype is float32 - * - * input: - * 0 - shape of output (1D int tensor) - * 1 - min val (0D of output type) - optional (0 as default) - * 2 - max val (0D of output type) - optional (inf as default) - * - * output: - * 0 - uniformly distributed values of given type (between min and max) - */ - #if NOT_EXCLUDED(OP_randomuniform) - DECLARE_CUSTOM_OP(randomuniform, 1, 1, false, 0, 0); - #endif - /* - * multinomial (categorical) random generator draws samples from a multinomial distribution - * - * Input array: - * 0 - 2D ndarray with unnormalized log-probabilities with shape [batch_size (N), num_classes (K)] - * 1 - array with one int value of samples number, number of independent samples to draw for each experiment 1,N. - * Int arguments: - * 0 - optional argument, corresponds to dimension with batch_size - * 1 - optional argument, integer type to use for the output. Default int64. - * - * Output array: - * 0 - 2D ndarray with the drawn samples of shape [batch_size, num_samples] - */ - #if NOT_EXCLUDED(OP_random_multinomial) - DECLARE_CUSTOM_OP(random_multinomial, 2, 1, false, 0, 0); - #endif +/* + * random_uniform distribution for types int32,int64, float16, float and double + * by default dtype is float32 + * + * input: + * 0 - shape of output (1D int tensor) + * 1 - min val (0D of output type) - optional (0 as default) + * 2 - max val (0D of output type) - optional (inf as default) + * + * output: + * 0 - uniformly distributed values of given type (between min and max) + */ +#if NOT_EXCLUDED(OP_randomuniform) +DECLARE_CUSTOM_OP(randomuniform, 1, 1, false, 0, 0); +#endif +/* + * multinomial (categorical) random generator draws samples from a multinomial + * distribution + * + * Input array: + * 0 - 2D ndarray with unnormalized log-probabilities with shape [batch_size + * (N), num_classes (K)] 1 - array with one int value of samples number, number + * of independent samples to draw for each experiment 1,N. Int arguments: 0 - + * optional argument, corresponds to dimension with batch_size 1 - optional + * argument, integer type to use for the output. Default int64. + * + * Output array: + * 0 - 2D ndarray with the drawn samples of shape [batch_size, num_samples] + */ +#if NOT_EXCLUDED(OP_random_multinomial) +DECLARE_CUSTOM_OP(random_multinomial, 2, 1, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_random_normal) - DECLARE_CUSTOM_OP(random_normal, 1, 1, true, 2, 0); - #endif +#if NOT_EXCLUDED(OP_random_normal) +DECLARE_CUSTOM_OP(random_normal, 1, 1, true, 2, 0); +#endif - #if NOT_EXCLUDED(OP_random_bernoulli) - DECLARE_CUSTOM_OP(random_bernoulli, 1, 1, true, 0, 1); - #endif +#if NOT_EXCLUDED(OP_random_bernoulli) +DECLARE_CUSTOM_OP(random_bernoulli, 1, 1, true, 0, 1); +#endif - #if NOT_EXCLUDED(OP_random_exponential) - DECLARE_CUSTOM_OP(random_exponential, 1, 1, true, 1, 0); - #endif +#if NOT_EXCLUDED(OP_random_exponential) +DECLARE_CUSTOM_OP(random_exponential, 1, 1, true, 1, 0); +#endif - #if NOT_EXCLUDED(OP_random_crop) - DECLARE_CUSTOM_OP(random_crop, 2, 1, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_random_crop) +DECLARE_CUSTOM_OP(random_crop, 2, 1, false, 0, 0); +#endif - /** - * random_gamma op. - */ - #if NOT_EXCLUDED(OP_random_gamma) - DECLARE_CUSTOM_OP(random_gamma, 2, 1, false, 0, 0); - #endif +/** + * random_gamma op. + */ +#if NOT_EXCLUDED(OP_random_gamma) +DECLARE_CUSTOM_OP(random_gamma, 2, 1, false, 0, 0); +#endif - /** - * random_poisson op. - */ - #if NOT_EXCLUDED(OP_random_poisson) - DECLARE_CUSTOM_OP(random_poisson, 2, 1, false, 0, 0); - #endif +/** + * random_poisson op. + */ +#if NOT_EXCLUDED(OP_random_poisson) +DECLARE_CUSTOM_OP(random_poisson, 2, 1, false, 0, 0); +#endif - } -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/recurrent.h b/libnd4j/include/ops/declarable/headers/recurrent.h index aeeae24c42fe..870f0ceea2f2 100644 --- a/libnd4j/include/ops/declarable/headers/recurrent.h +++ b/libnd4j/include/ops/declarable/headers/recurrent.h @@ -24,420 +24,448 @@ #include namespace sd { -namespace ops { - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [3K x K] - * 2: row of biases with twice length [1 x 2K] - * 3: 2d tensor of previous cell state [bS x K] - * 4: optional, 2d tensor of dropout mask [bS x K] - * - * Output arrays: - * 0: 3d tensor of cell output [bS x K x N] - * 1: 3d tensor of cell state [bS x K x N] - */ - #if NOT_EXCLUDED(OP_sru) - DECLARE_CUSTOM_OP(sru, 5, 2, false, 0, 0); - #endif - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for Simple Recurrent Unit (bidirectional case): "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [2K x 6K] - * 2: row of biases with twice length [1 x 4K] - * 3: 2d tensor of previous cell state [bS x 2K] - * 4: optional, 2d tensor of dropout mask [bS x 2K] - * - * Output arrays: - * 0: 3d tensor of cell output [N x bS x 2K] - * 1: 3d tensor of cell state [N x bS x 2K] - */ - #if NOT_EXCLUDED(OP_sru_bi) - DECLARE_CUSTOM_OP(sru_bi, 5, 2, true, 0, 0); - #endif - - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for back propagation in Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [3K x K] - * 2: row of biases with twice length [1 x 2K] - * 3: 2d tensor of previous cell state [bS x K] - * 4: 3d tensor of cell state [bS x K x N] - * 5: 2d tensor of cell state gradients [bS x K] - * 6: 3d tensor of state output gradients [bS x K x N] - * 7: optional, 2d tensor of dropout mask [bS x K] - * - * Output arrays: - * 0: 3d tensor of input gradients [bS x K x N] - * 1: 3d tensor of weights gradients [bS x 3K x K] - * 2: 2d, row of biases gradients [1 x 2K] - * 3: 2d, tensor of state gradients [bS x K] - */ - #if NOT_EXCLUDED(OP_sru) - DECLARE_CUSTOM_OP(sru_bp, 8, 4, true, 0, 0); - #endif - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for back propagation in Simple Recurrent Unit (bidirectional case): "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [2K x 6K] - * 2: row of biases with twice length [1 x 4K] - * 3: 2d tensor of previous cell state [bS x 2K] - * 4: 3d tensor of cell state [N x bS x 2K] - * 5: 2d tensor of cell state gradients [bS x 2K] - * 6: 3d tensor of state output gradients [N x bS x 2K] - * 7: optional, 2d tensor of dropout mask [bS x 2K] - * - * Output arrays: - * 0: 3d tensor of input gradients [N x bS x 2K] - * 1: 3d tensor of weights gradients [N x 2K x 6K] - * 2: 2d, row of biases gradients [1 x 4K] - * 3: 2d, tensor of state gradients [bS x 2K] - */ - #if NOT_EXCLUDED(OP_sru_bi) - DECLARE_CUSTOM_OP(sru_bi_bp, 8, 4, true, 0, 0); - #endif +namespace ops { +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast + * as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - + * batch size, K - number of features 1: 2d tensor of weights [3K x K] 2: row of + * biases with twice length [1 x 2K] 3: 2d tensor of previous cell state [bS x + * K] 4: optional, 2d tensor of dropout mask [bS x K] + * + * Output arrays: + * 0: 3d tensor of cell output [bS x K x N] + * 1: 3d tensor of cell state [bS x K x N] + */ +#if NOT_EXCLUDED(OP_sru) +DECLARE_CUSTOM_OP(sru, 5, 2, false, 0, 0); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for LSTM cell with peep hole connections: - * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation - * and - * https://research.google.com/pubs/archive/43905.pdf - * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. - * - * Input arrays: - * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features - * 1: previous cell output [batchSize x numProj], that is at previous time step t-1, in case of projection=false -> numProj=numUnits!!! - * 2: previous cell state [batchSize x numUnits], that is at previous time step t-1 - * 3: input-to-hidden weights, [inSize x 4*numUnits] - * 4: hidden-to-hidden weights, [numProj x 4*numUnits] - * 5: diagonal weights for peephole connections [3*numUnits] - * 6: projection weights [numUnits x numProj] - * 7: biases, [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * 1: if not zero, then projection is performed, if zero then numProj==numUnits is mandatory! - * - * Input float arguments: - * 0: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * 1: clipping value for projected cell output, if it is not equal to zero, then projected cell output is clipped - * 2: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * - * Output arrays: - * 0: current cell output [batchSize x numProj], that is at current time step t - * 1: current cell state [batchSize x numUnits], that is at current time step t - */ - #if NOT_EXCLUDED(OP_lstmCell) - DECLARE_CUSTOM_OP(lstmCell, 8, 2, false, 3, 2); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for Simple Recurrent Unit (bidirectional case): + * "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS + * - batch size, K - number of features 1: 2d tensor of weights [2K x 6K] 2: row + * of biases with twice length [1 x 4K] 3: 2d tensor of previous cell state [bS + * x 2K] 4: optional, 2d tensor of dropout mask [bS x 2K] + * + * Output arrays: + * 0: 3d tensor of cell output [N x bS x 2K] + * 1: 3d tensor of cell state [N x bS x 2K] + */ +#if NOT_EXCLUDED(OP_sru_bi) +DECLARE_CUSTOM_OP(sru_bi, 5, 2, true, 0, 0); +#endif - #if NOT_EXCLUDED(OP_lstmLayerCell) - DECLARE_CUSTOM_OP(lstmLayerCell, 5, 2, false, 1, 3); - #endif - #if NOT_EXCLUDED(OP_lstmLayerCell) - DECLARE_CUSTOM_OP(lstmLayerCellBp, 7, 5, false, 1, 3); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for back propagation in Simple Recurrent Unit: + * "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - + * batch size, K - number of features 1: 2d tensor of weights [3K x K] 2: row of + * biases with twice length [1 x 2K] 3: 2d tensor of previous cell state [bS x + * K] 4: 3d tensor of cell state [bS x K x N] 5: 2d tensor of cell state + * gradients [bS x K] 6: 3d tensor of state output gradients [bS x K x N] 7: + * optional, 2d tensor of dropout mask [bS x K] + * + * Output arrays: + * 0: 3d tensor of input gradients [bS x K x N] + * 1: 3d tensor of weights gradients [bS x 3K x K] + * 2: 2d, row of biases gradients [1 x 2K] + * 3: 2d, tensor of state gradients [bS x K] + */ +#if NOT_EXCLUDED(OP_sru) +DECLARE_CUSTOM_OP(sru_bp, 8, 4, true, 0, 0); +#endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for back propagation in Simple Recurrent Unit + * (bidirectional case): "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav + * Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS + * - batch size, K - number of features 1: 2d tensor of weights [2K x 6K] 2: row + * of biases with twice length [1 x 4K] 3: 2d tensor of previous cell state [bS + * x 2K] 4: 3d tensor of cell state [N x bS x 2K] 5: 2d tensor of cell state + * gradients [bS x 2K] 6: 3d tensor of state output gradients [N x bS x 2K] 7: + * optional, 2d tensor of dropout mask [bS x 2K] + * + * Output arrays: + * 0: 3d tensor of input gradients [N x bS x 2K] + * 1: 3d tensor of weights gradients [N x 2K x 6K] + * 2: 2d, row of biases gradients [1 x 4K] + * 3: 2d, tensor of state gradients [bS x 2K] + */ +#if NOT_EXCLUDED(OP_sru_bi) +DECLARE_CUSTOM_OP(sru_bi_bp, 8, 4, true, 0, 0); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for LSTM cell with optional peep hole connections: - * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation - * and - * https://research.google.com/pubs/archive/43905.pdf - * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. - * See also: https://arxiv.org/pdf/1503.04069.pdf - * - * Input arrays: - * 0: input [bS, inSize] at time t - * 1: previous cell state [bS, numUnits], time t-1 - * 2: previous output [bS, numUnits], time t-1 - * 3: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits] - * 4: weights - cell peephole (t-1) connections to input modulation gate, [numUnits] - * 5: weights - cell peephole (t-1) connections to forget gate, [numUnits] - * 6: weights - cell peephole (t) connections to output gate, [numUnits] - * 7: biases, shape [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * - * Input float arguments: - * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * - * Output arrays: - * 0: i - Input modulation gate activations [bS, numUnits] - * 1: c (cs) - Cell state (pre tanh) [bs, numUnits] (cs) - * 2: f - Output - forget gate activations [bs, numUnits] - * 3: o - Output - output gate activations [bs, numUnits] - * 4: z (ci) - Output - block input [bs, numUnits] - * 5: h (co) - Cell state, post tanh [bs, numUnits] - * 6: y (h) - Current cell output [bS, numUnits], time t - */ - #if NOT_EXCLUDED(OP_lstmBlockCell) - DECLARE_CUSTOM_OP(lstmBlockCell, 8, 7, false, 2, 1); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for LSTM cell with peep hole connections: + * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural + * Computation and https://research.google.com/pubs/archive/43905.pdf Hasim Sak, + * Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent + * neural network architectures for large scale acoustic modeling." INTERSPEECH, + * 2014. + * + * Input arrays: + * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - + * number of features 1: previous cell output [batchSize x numProj], that is at + * previous time step t-1, in case of projection=false -> numProj=numUnits!!! 2: + * previous cell state [batchSize x numUnits], that is at previous time step + * t-1 3: input-to-hidden weights, [inSize x 4*numUnits] 4: hidden-to-hidden + * weights, [numProj x 4*numUnits] 5: diagonal weights for peephole connections + * [3*numUnits] 6: projection weights [numUnits x numProj] 7: biases, + * [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * 1: if not zero, then projection is performed, if zero then + * numProj==numUnits is mandatory! + * + * Input float arguments: + * 0: clipping value for cell state, if it is not equal to zero, then cell + * state is clipped 1: clipping value for projected cell output, if it is not + * equal to zero, then projected cell output is clipped 2: the bias added to + * forget gates in order to reduce the scale of forgetting in the beginning of + * the training + * + * Output arrays: + * 0: current cell output [batchSize x numProj], that is at current time step + * t 1: current cell state [batchSize x numUnits], that is at current time step + * t + */ +#if NOT_EXCLUDED(OP_lstmCell) +DECLARE_CUSTOM_OP(lstmCell, 8, 2, false, 3, 2); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for LSTM layer with optional peep hole connections. - * See lstmBlockCell for details. lstmBlockCell is used internally for computation. - * This method expects as input (and returns as output) sequences in one of 3 formats, depending on the data format arg: - * dataFormat = 0 -> TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major" - * dataFormat = 1 -> NST: shape [numExamples, inOutSize, timeLength] - * dataFormat = 2 -> NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout - * - * - * Input arrays: - * 0: max sequence length; long/int64 scalar - * 1: input [seqLength, bS, inSize] at time t - * 2: previous/initial cell state [bS, numUnits] - * 3: previous/initial output [bS, numUnits] - * 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits] - * 5: weights - cell peephole (t-1) connections to input modulation gate, [numUnits] - * 6: weights - cell peephole (t-1) connections to forget gate, [numUnits] - * 7: weights - cell peephole (t) connections to output gate, [numUnits] - * 8: biases, Shape [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; 2=NTS=[mb,seqLen,size] - * - * Input float arguments: - * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * - * Output arrays: - * 0: i - Input modulation gate activations, rank 3, shape as per dataFormat - * 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat - * 2: f - Output - forget gate activations, rank 3, shape as per dataFormat - * 3: o - Output - output gate activations, rank 3, shape as per dataFormat - * 4: z (ci) - Output - block input, rank 3, shape as per dataFormat - * 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat - * 6: y (h) - Current cell output, rank 3, shape as per dataFormat - */ - #if NOT_EXCLUDED(OP_lstmBlock) - DECLARE_CUSTOM_OP(lstmBlock, 9, 7, false, 2, 2); - #endif +#if NOT_EXCLUDED(OP_lstmLayerCell) +DECLARE_CUSTOM_OP(lstmLayerCell, 5, 2, false, 1, 3); +#endif +#if NOT_EXCLUDED(OP_lstmLayerCell) +DECLARE_CUSTOM_OP(lstmLayerCellBp, 7, 5, false, 1, 3); +#endif - ////////////////////////////////////////////////////////////////////////// - #if NOT_EXCLUDED(OP_lstmLayer) - DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for LSTM cell with optional peep hole + * connections: S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". + * Neural Computation and https://research.google.com/pubs/archive/43905.pdf + * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory + * recurrent neural network architectures for large scale acoustic modeling." + * INTERSPEECH, 2014. See also: https://arxiv.org/pdf/1503.04069.pdf + * + * Input arrays: + * 0: input [bS, inSize] at time t + * 1: previous cell state [bS, numUnits], time t-1 + * 2: previous output [bS, numUnits], time t-1 + * 3: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) + * weights, [(inSize+numUnits), 4*numUnits] 4: weights - cell peephole (t-1) + * connections to input modulation gate, [numUnits] 5: weights - cell peephole + * (t-1) connections to forget gate, [numUnits] 6: weights - cell peephole (t) + * connections to output gate, [numUnits] 7: biases, shape [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * + * Input float arguments: + * 0: the bias added to forget gates in order to reduce the scale of + * forgetting in the beginning of the training 1: clipping value for cell state, + * if it is not equal to zero, then cell state is clipped + * + * Output arrays: + * 0: i - Input modulation gate activations [bS, numUnits] + * 1: c (cs) - Cell state (pre tanh) [bs, numUnits] (cs) + * 2: f - Output - forget gate activations [bs, numUnits] + * 3: o - Output - output gate activations [bs, numUnits] + * 4: z (ci) - Output - block input [bs, numUnits] + * 5: h (co) - Cell state, post tanh [bs, numUnits] + * 6: y (h) - Current cell output [bS, numUnits], time t + */ +#if NOT_EXCLUDED(OP_lstmBlockCell) +DECLARE_CUSTOM_OP(lstmBlockCell, 8, 7, false, 2, 1); +#endif - ////////////////////////////////////////////////////////////////////////// - #if NOT_EXCLUDED(OP_lstmLayer) - DECLARE_CUSTOM_OP(lstmLayer_bp, 4, 1, false, 1, 5); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for LSTM layer with optional peep hole + * connections. See lstmBlockCell for details. lstmBlockCell is used internally + * for computation. This method expects as input (and returns as output) + * sequences in one of 3 formats, depending on the data format arg: dataFormat = + * 0 -> TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to + * as "time major" dataFormat = 1 -> NST: shape [numExamples, inOutSize, + * timeLength] dataFormat = 2 -> NTS: shape [numExamples, timeLength, inOutSize] + * - TF "time_major=false" layout + * + * + * Input arrays: + * 0: max sequence length; long/int64 scalar + * 1: input [seqLength, bS, inSize] at time t + * 2: previous/initial cell state [bS, numUnits] + * 3: previous/initial output [bS, numUnits] + * 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) + * weights, [(inSize+numUnits), 4*numUnits] 5: weights - cell peephole (t-1) + * connections to input modulation gate, [numUnits] 6: weights - cell peephole + * (t-1) connections to forget gate, [numUnits] 7: weights - cell peephole (t) + * connections to output gate, [numUnits] 8: biases, Shape [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; + * 2=NTS=[mb,seqLen,size] + * + * Input float arguments: + * 0: the bias added to forget gates in order to reduce the scale of + * forgetting in the beginning of the training 1: clipping value for cell state, + * if it is not equal to zero, then cell state is clipped + * + * Output arrays: + * 0: i - Input modulation gate activations, rank 3, shape as per + * dataFormat 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat + * 2: f - Output - forget gate activations, rank 3, shape as per + * dataFormat 3: o - Output - output gate activations, rank 3, shape as per + * dataFormat 4: z (ci) - Output - block input, rank 3, shape as per dataFormat + * 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat + * 6: y (h) - Current cell output, rank 3, shape as per dataFormat + */ +#if NOT_EXCLUDED(OP_lstmBlock) +DECLARE_CUSTOM_OP(lstmBlock, 9, 7, false, 2, 2); +#endif +////////////////////////////////////////////////////////////////////////// +#if NOT_EXCLUDED(OP_lstmLayer) +DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operations for Simple Recurrent Unit cell: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features - * 1: previous cell state [batchSize x inSize], that is at previous time step t-1 - * 2: weights [inSize x 3*inSize] - * 3: biases [1 x 2*inSize] - * - * Output arrays: - * 0: current cell output [batchSize x inSize], that is at current time step t - * 1: current cell state [batchSize x inSize], that is at current time step t - */ - #if NOT_EXCLUDED(OP_sruCell) - DECLARE_CUSTOM_OP(sruCell, 4, 2, false, 0, 0); - #endif +////////////////////////////////////////////////////////////////////////// +#if NOT_EXCLUDED(OP_lstmLayer) +DECLARE_CUSTOM_OP(lstmLayer_bp, 4, 1, false, 1, 5); +#endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operations for Simple Recurrent Unit cell: "Training RNNs + * as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - + * number of features 1: previous cell state [batchSize x inSize], that is at + * previous time step t-1 2: weights [inSize x 3*inSize] 3: biases [1 x + * 2*inSize] + * + * Output arrays: + * 0: current cell output [batchSize x inSize], that is at current time step + * t 1: current cell state [batchSize x inSize], that is at current time step t + */ +#if NOT_EXCLUDED(OP_sruCell) +DECLARE_CUSTOM_OP(sruCell, 4, 2, false, 0, 0); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of gated Recurrent Unit cell: - * Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio - * "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" - * - * Input arrays: - * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features - * 1: previous cell output [batchSize x numUnits], that is at previous time step t-1 - * 2: RU weights - [(inSize+numUnits), 2*numUnits] - reset and update gates (input/recurrent weights) - * 3: C weights - [(inSize+numUnits), numUnits] - cell gate (input/recurrent weights) - * 4: reset and update biases, [2*numUnits] - reset and update gates - * 5: cell biases, [numUnits] - * - * Output arrays: - * 0: Reset gate output [bS, numUnits] - * 1: Update gate output [bS, numUnits] - * 2: Cell gate output [bS, numUnits] - * 3: Current cell output [bS, numUnits] - */ - #if NOT_EXCLUDED(OP_gruCell) - DECLARE_CUSTOM_OP(gruCell, 6, 4, false, 0, 0); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of gated Recurrent Unit cell: + * Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, + * Fethi Bougares, Holger Schwenk, Yoshua Bengio "Learning Phrase + * Representations using RNN Encoder-Decoder for Statistical Machine + * Translation" + * + * Input arrays: + * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - + * number of features 1: previous cell output [batchSize x numUnits], that is + * at previous time step t-1 2: RU weights - [(inSize+numUnits), 2*numUnits] - + * reset and update gates (input/recurrent weights) 3: C weights - + * [(inSize+numUnits), numUnits] - cell gate (input/recurrent weights) 4: reset + * and update biases, [2*numUnits] - reset and update gates 5: cell biases, + * [numUnits] + * + * Output arrays: + * 0: Reset gate output [bS, numUnits] + * 1: Update gate output [bS, numUnits] + * 2: Cell gate output [bS, numUnits] + * 3: Current cell output [bS, numUnits] + */ +#if NOT_EXCLUDED(OP_gruCell) +DECLARE_CUSTOM_OP(gruCell, 6, 4, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_gruCell) - DECLARE_CUSTOM_OP(gruCell_bp, 10, 6, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_gruCell) +DECLARE_CUSTOM_OP(gruCell_bp, 10, 6, false, 0, 0); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "LSTM time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: initial cell output [batchSize x numProj], that is at time step = 0, in case of projection=false -> numProj=numUnits!!! - * 2: initial cell state [batchSize x numUnits], that is at time step = 0 - * 3: input-to-hidden weights, [inSize x 4*numUnits] - * 4: hidden-to-hidden weights, [numProj x 4*numUnits] - * 5: diagonal weights for peephole connections [3*numUnits] - * 6: projection weights [numUnits x numProj] - * 7: biases, [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * 1: if not zero, then projection is performed, if zero then numProj==numUnits is mandatory! - * - * Input float arguments: - * 0: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * 1: clipping value for projected cell output, if it is not equal to zero, then projected cell output is clipped - * 2: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * - * Output arrays: - * 0: cell outputs [time x batchSize x numProj], that is per each time step - * 1: cell states [time x batchSize x numUnits], that is per each time step - */ - #if NOT_EXCLUDED(OP_lstm) - DECLARE_CUSTOM_OP(lstm, 8, 2, false, 3, 2); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "LSTM time sequences" with peep hole connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: initial cell + * output [batchSize x numProj], that is at time step = 0, in case of + * projection=false -> numProj=numUnits!!! 2: initial cell state [batchSize x + * numUnits], that is at time step = 0 3: input-to-hidden weights, [inSize x + * 4*numUnits] 4: hidden-to-hidden weights, [numProj x 4*numUnits] 5: diagonal + * weights for peephole connections [3*numUnits] 6: projection weights [numUnits + * x numProj] 7: biases, [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * 1: if not zero, then projection is performed, if zero then + * numProj==numUnits is mandatory! + * + * Input float arguments: + * 0: clipping value for cell state, if it is not equal to zero, then cell + * state is clipped 1: clipping value for projected cell output, if it is not + * equal to zero, then projected cell output is clipped 2: the bias added to + * forget gates in order to reduce the scale of forgetting in the beginning of + * the training + * + * Output arrays: + * 0: cell outputs [time x batchSize x numProj], that is per each time step + * 1: cell states [time x batchSize x numUnits], that is per each time step + */ +#if NOT_EXCLUDED(OP_lstm) +DECLARE_CUSTOM_OP(lstm, 8, 2, false, 3, 2); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of gated Recurrent Unit: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: initial cell output [batchSize x numUnits], that is at time step = 0 - * 2: input-to-hidden weights, [inSize x 3*numUnits] - * 3: hidden-to-hidden weights, [numUnits x 3*numUnits] - * 4: biases, [3*numUnits] - * - * Output arrays: - * 0: cell outputs [time x batchSize x numUnits], that is per each time step - */ - #if NOT_EXCLUDED(OP_gru) - DECLARE_CUSTOM_OP(gru, 5, 1, false, 0, 0); - #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of gated Recurrent Unit: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: initial cell + * output [batchSize x numUnits], that is at time step = 0 2: input-to-hidden + * weights, [inSize x 3*numUnits] 3: hidden-to-hidden weights, [numUnits x + * 3*numUnits] 4: biases, [3*numUnits] + * + * Output arrays: + * 0: cell outputs [time x batchSize x numUnits], that is per each time step + */ +#if NOT_EXCLUDED(OP_gru) +DECLARE_CUSTOM_OP(gru, 5, 1, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_gru) - DECLARE_CUSTOM_OP(gru_bp, 6, 5, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_gru) +DECLARE_CUSTOM_OP(gru_bp, 6, 5, false, 0, 0); +#endif - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights, [inSize x numUnits] - * 2: hidden-to-hidden weights, [numUnits x numUnits] - * 3: biases, [2*numUnits] - * 4: (optional) initial cell output [batchSize x numUnits], that is at time step = 0 - * 5: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Output arrays: - * 0: cell outputs [time x batchSize x numUnits] - * 1: cell final non-zero output [batchSize x numUnits] - */ - DECLARE_CUSTOM_OP(static_rnn, 4, 2, false, 0, 0); +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: input-to-hidden + * weights, [inSize x numUnits] 2: hidden-to-hidden weights, [numUnits x + * numUnits] 3: biases, [2*numUnits] 4: (optional) initial cell output + * [batchSize x numUnits], that is at time step = 0 5: (optional) vector with + * shape [batchSize] containing integer values within [0,time), each element of + * this vector set max time step per each input in batch, this provides no + * calculations for time >= maxTimeStep + * + * Output arrays: + * 0: cell outputs [time x batchSize x numUnits] + * 1: cell final non-zero output [batchSize x numUnits] + */ +DECLARE_CUSTOM_OP(static_rnn, 4, 2, false, 0, 0); - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize] or [batchSize x time x numUnits], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights, [inSize x numUnits] - * 2: hidden-to-hidden weights, [numUnits x numUnits] - * 3: biases, [2*numUnits] - * 4: (optional) initial cell output [batchSize x numUnits], that is at time step = 0 - * 5: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Input integer arguments: - * 0: (optional) timeMajor - if non zero then input shape is [time, batchSize, ...], else [batchSize, time, ...] - * - * Output arrays: - * 0: cell outputs [time x batchSize x numUnits] or [batchSize x time x numUnits] - * 1: cell final non-zero output [batchSize x numUnits] - */ - DECLARE_CUSTOM_OP(dynamic_rnn, 4, 2, false, 0, 0); +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize] or [batchSize x time x + * numUnits], time - number of time steps, batchSize - batch size, inSize - + * number of features 1: input-to-hidden weights, [inSize x numUnits] 2: + * hidden-to-hidden weights, [numUnits x numUnits] 3: biases, [2*numUnits] 4: + * (optional) initial cell output [batchSize x numUnits], that is at time step = + * 0 5: (optional) vector with shape [batchSize] containing integer values + * within [0,time), each element of this vector set max time step per each input + * in batch, this provides no calculations for time >= maxTimeStep + * + * Input integer arguments: + * 0: (optional) timeMajor - if non zero then input shape is [time, + * batchSize, ...], else [batchSize, time, ...] + * + * Output arrays: + * 0: cell outputs [time x batchSize x numUnits] or [batchSize x time x + * numUnits] 1: cell final non-zero output [batchSize x numUnits] + */ +DECLARE_CUSTOM_OP(dynamic_rnn, 4, 2, false, 0, 0); - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - * 2: hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - * 3: biases for forward RNN, [2*numUnitsFW] - * 4: input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - * 5: hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - * 6: biases for backward RNN, [2*numUnitsBW] - * 7: (optional) initial cell output for forward RNN [batchSize x numUnitsFW], that is at time step = 0 - * 8: (optional) initial cell output for backward RNN [batchSize x numUnitsBW], that is at time step = 0 - * 9: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Output arrays: - * 0: cell outputs [time x batchSize x (numUnitsFW + numUnitsBW)] - * 1: cell final non-zero output for forward RNN [batchSize x numUnitsFW] - * 2: cell final non-zero output for backward RNN [batchSize x numUnitsBW] - */ - DECLARE_CUSTOM_OP(static_bidirectional_rnn, 7, 3, false, 0, 0); +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: input-to-hidden + * weights for forward RNN, [inSize x numUnitsFW] 2: hidden-to-hidden weights + * for forward RNN, [numUnitsFW x numUnitsFW] 3: biases for forward RNN, + * [2*numUnitsFW] 4: input-to-hidden weights for backward RNN, [inSize x + * numUnitsBW] 5: hidden-to-hidden weights for backward RNN, [numUnitsBW x + * numUnitsBW] 6: biases for backward RNN, [2*numUnitsBW] 7: (optional) initial + * cell output for forward RNN [batchSize x numUnitsFW], that is at time step = + * 0 8: (optional) initial cell output for backward RNN [batchSize x + * numUnitsBW], that is at time step = 0 9: (optional) vector with shape + * [batchSize] containing integer values within [0,time), each element of this + * vector set max time step per each input in batch, this provides no + * calculations for time >= maxTimeStep + * + * Output arrays: + * 0: cell outputs [time x batchSize x (numUnitsFW + numUnitsBW)] + * 1: cell final non-zero output for forward RNN [batchSize x numUnitsFW] + * 2: cell final non-zero output for backward RNN [batchSize x numUnitsBW] + */ +DECLARE_CUSTOM_OP(static_bidirectional_rnn, 7, 3, false, 0, 0); - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize] or [batchSize x time x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - * 2: hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - * 3: biases for forward RNN, [2*numUnitsFW] - * 4: input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - * 5: hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - * 6: biases for backward RNN, [2*numUnitsBW] - * 7: (optional) initial cell output for forward RNN [batchSize x numUnitsFW], that is at time step = 0 - * 8: (optional) initial cell output for backward RNN [batchSize x numUnitsBW], that is at time step = 0 - * 9: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Input integer arguments: - * 0: (optional) timeMajor - if non zero then input shape is [time, batchSize, ...], else [batchSize, time, ...] - * - * Output arrays: - * 0: cell outputs for forward RNN [time x batchSize x numUnitsFW] or [batchSize x time x numUnitsFW] - * 1: cell outputs for backward RNN [time x batchSize x numUnitsBW] or [batchSize x time x numUnitsBW] - * 2: cell final non-zero output for forward RNN [batchSize x numUnitsFW] - * 3: cell final non-zero output for backward RNN [batchSize x numUnitsBW] - */ - DECLARE_CUSTOM_OP(dynamic_bidirectional_rnn, 7, 4, false, 0, 0); +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize] or [batchSize x time x + * inSize], time - number of time steps, batchSize - batch size, inSize - number + * of features 1: input-to-hidden weights for forward RNN, [inSize x + * numUnitsFW] 2: hidden-to-hidden weights for forward RNN, [numUnitsFW x + * numUnitsFW] 3: biases for forward RNN, [2*numUnitsFW] 4: input-to-hidden + * weights for backward RNN, [inSize x numUnitsBW] 5: hidden-to-hidden weights + * for backward RNN, [numUnitsBW x numUnitsBW] 6: biases for backward RNN, + * [2*numUnitsBW] 7: (optional) initial cell output for forward RNN [batchSize x + * numUnitsFW], that is at time step = 0 8: (optional) initial cell output for + * backward RNN [batchSize x numUnitsBW], that is at time step = 0 9: (optional) + * vector with shape [batchSize] containing integer values within [0,time), each + * element of this vector set max time step per each input in batch, this + * provides no calculations for time >= maxTimeStep + * + * Input integer arguments: + * 0: (optional) timeMajor - if non zero then input shape is [time, + * batchSize, ...], else [batchSize, time, ...] + * + * Output arrays: + * 0: cell outputs for forward RNN [time x batchSize x numUnitsFW] or + * [batchSize x time x numUnitsFW] 1: cell outputs for backward RNN [time x + * batchSize x numUnitsBW] or [batchSize x time x numUnitsBW] 2: cell final + * non-zero output for forward RNN [batchSize x numUnitsFW] 3: cell final + * non-zero output for backward RNN [batchSize x numUnitsBW] + */ +DECLARE_CUSTOM_OP(dynamic_bidirectional_rnn, 7, 4, false, 0, 0); -} -} +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/shape.h b/libnd4j/include/ops/declarable/headers/shape.h index 7f93303429f2..66622bd1a12e 100644 --- a/libnd4j/include/ops/declarable/headers/shape.h +++ b/libnd4j/include/ops/declarable/headers/shape.h @@ -24,98 +24,97 @@ #include namespace sd { - namespace ops { - #if NOT_EXCLUDED(OP_permute) - DECLARE_CUSTOM_OP(permute, 1, 1, false, 0, -2); - #endif - - #if NOT_EXCLUDED(OP_reshapeas) - DECLARE_CUSTOM_OP(reshapeas, 2, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_transpose) - DECLARE_CUSTOM_OP(transpose, 1, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_shape_of) - DECLARE_CUSTOM_OP(shape_of, 1, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_shapes_of) - DECLARE_CUSTOM_OP(shapes_of, -1, -1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_squeeze) - DECLARE_CUSTOM_OP(squeeze, 1, 1, false, 0, -2); - #endif - - #if NOT_EXCLUDED(OP_expand_dims) - DECLARE_CUSTOM_OP(expand_dims, 1, 1, false, 0, -2); - #endif - - #if NOT_EXCLUDED(OP_reshape) - DECLARE_CUSTOM_OP(reshape, 1, 1, false, 0, -2); - #endif - - #if NOT_EXCLUDED(OP_size_at) - DECLARE_CUSTOM_OP(size_at, 1, 1, false, 0, 1); - #endif - - /** - * This op changes order of given array to specified order. - * In other words: C/F order switch - * - * Int args: - * 0 - isForder. set to 1 for F order output, or 0 for C order output - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_order) - DECLARE_CUSTOM_OP(order, 1, 1, false, 0, 1); - #endif - - /** - * This op boosts specified input up to specified shape - * - * @tparam T - */ - #if NOT_EXCLUDED(OP_tile_to_shape) - DECLARE_CUSTOM_OP(tile_to_shape, 1, 1, false, 0, -1); - DECLARE_CUSTOM_OP(tile_to_shape_bp, 2, 1, false, 0, -1); - #endif - - /** - * This op broadcast given input up to given shape - * - * inputs: - * input array - array to be broadcasted to given shape - * shape array - array containing shape be broadcasted to - */ - #if NOT_EXCLUDED(OP_broadcast_to) - DECLARE_CUSTOM_OP(broadcast_to, 2, 1, false, 0, 0); - #endif - - - #if NOT_EXCLUDED(OP_evaluate_reduction_shape) - DECLARE_CUSTOM_OP(evaluate_reduction_shape, 2, 1, false, 0, 0); - #endif - - /** - * This operation creates new array - * Input: - * array with shape values - * - * IArgs: - * order value - * data type value - * - * BArgs: - * initialization option - */ - #if NOT_EXCLUDED(OP_create) - DECLARE_CUSTOM_OP(create, 1, 1, false, 0, 1); - #endif - } -} +namespace ops { +#if NOT_EXCLUDED(OP_permute) +DECLARE_CUSTOM_OP(permute, 1, 1, false, 0, -2); +#endif + +#if NOT_EXCLUDED(OP_reshapeas) +DECLARE_CUSTOM_OP(reshapeas, 2, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_transpose) +DECLARE_CUSTOM_OP(transpose, 1, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_shape_of) +DECLARE_CUSTOM_OP(shape_of, 1, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_shapes_of) +DECLARE_CUSTOM_OP(shapes_of, -1, -1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_squeeze) +DECLARE_CUSTOM_OP(squeeze, 1, 1, false, 0, -2); +#endif + +#if NOT_EXCLUDED(OP_expand_dims) +DECLARE_CUSTOM_OP(expand_dims, 1, 1, false, 0, -2); +#endif + +#if NOT_EXCLUDED(OP_reshape) +DECLARE_CUSTOM_OP(reshape, 1, 1, false, 0, -2); +#endif + +#if NOT_EXCLUDED(OP_size_at) +DECLARE_CUSTOM_OP(size_at, 1, 1, false, 0, 1); +#endif + +/** + * This op changes order of given array to specified order. + * In other words: C/F order switch + * + * Int args: + * 0 - isForder. set to 1 for F order output, or 0 for C order output + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_order) +DECLARE_CUSTOM_OP(order, 1, 1, false, 0, 1); +#endif + +/** + * This op boosts specified input up to specified shape + * + * @tparam T + */ +#if NOT_EXCLUDED(OP_tile_to_shape) +DECLARE_CUSTOM_OP(tile_to_shape, 1, 1, false, 0, -1); +DECLARE_CUSTOM_OP(tile_to_shape_bp, 2, 1, false, 0, -1); +#endif + +/** + * This op broadcast given input up to given shape + * + * inputs: + * input array - array to be broadcasted to given shape + * shape array - array containing shape be broadcasted to + */ +#if NOT_EXCLUDED(OP_broadcast_to) +DECLARE_CUSTOM_OP(broadcast_to, 2, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_evaluate_reduction_shape) +DECLARE_CUSTOM_OP(evaluate_reduction_shape, 2, 1, false, 0, 0); +#endif + +/** + * This operation creates new array + * Input: + * array with shape values + * + * IArgs: + * order value + * data type value + * + * BArgs: + * initialization option + */ +#if NOT_EXCLUDED(OP_create) +DECLARE_CUSTOM_OP(create, 1, 1, false, 0, 1); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/strings.h b/libnd4j/include/ops/declarable/headers/strings.h index bd4b8b94996e..303a406a2b6d 100644 --- a/libnd4j/include/ops/declarable/headers/strings.h +++ b/libnd4j/include/ops/declarable/headers/strings.h @@ -24,19 +24,18 @@ #include namespace sd { - namespace ops { - /** - * This operation splits input string into pieces separated by delimiter - * - * Input[0] - string to split - * Input[1] - delimiter - */ - #if NOT_EXCLUDED(OP_split_string) - DECLARE_CUSTOM_OP(split_string, 2, 1, true, 0, 0); - #endif - - } -} +namespace ops { +/** + * This operation splits input string into pieces separated by delimiter + * + * Input[0] - string to split + * Input[1] - delimiter + */ +#if NOT_EXCLUDED(OP_split_string) +DECLARE_CUSTOM_OP(split_string, 2, 1, true, 0, 0); +#endif +} // namespace ops +} // namespace sd -#endif //SAMEDIFF_STRINGS_H +#endif // SAMEDIFF_STRINGS_H diff --git a/libnd4j/include/ops/declarable/headers/tests.h b/libnd4j/include/ops/declarable/headers/tests.h index cad12b3c85f9..c54d5ed90320 100644 --- a/libnd4j/include/ops/declarable/headers/tests.h +++ b/libnd4j/include/ops/declarable/headers/tests.h @@ -20,29 +20,29 @@ #include namespace sd { - namespace ops { - #if NOT_EXCLUDED(OP_test_output_reshape) - DECLARE_OP(test_output_reshape, 1, 1, true); - #endif +namespace ops { +#if NOT_EXCLUDED(OP_test_output_reshape) +DECLARE_OP(test_output_reshape, 1, 1, true); +#endif - #if NOT_EXCLUDED(OP_test_scalar) - DECLARE_CUSTOM_OP(test_scalar, 1, 1, false, 0, 0); - #endif +#if NOT_EXCLUDED(OP_test_scalar) +DECLARE_CUSTOM_OP(test_scalar, 1, 1, false, 0, 0); +#endif - #if NOT_EXCLUDED(OP_testreduction) - DECLARE_REDUCTION_OP(testreduction, 1, 1, false, 0, -1); - #endif +#if NOT_EXCLUDED(OP_testreduction) +DECLARE_REDUCTION_OP(testreduction, 1, 1, false, 0, -1); +#endif - #if NOT_EXCLUDED(OP_noop) - DECLARE_OP(noop, -2, -2, true); - #endif +#if NOT_EXCLUDED(OP_noop) +DECLARE_OP(noop, -2, -2, true); +#endif - #if NOT_EXCLUDED(OP_testop2i2o) - DECLARE_OP(testop2i2o, 2, 2, true); - #endif +#if NOT_EXCLUDED(OP_testop2i2o) +DECLARE_OP(testop2i2o, 2, 2, true); +#endif - #if NOT_EXCLUDED(OP_testcustom) - DECLARE_CUSTOM_OP(testcustom, 1, 1, false, 0, -1); - #endif - } -} \ No newline at end of file +#if NOT_EXCLUDED(OP_testcustom) +DECLARE_CUSTOM_OP(testcustom, 1, 1, false, 0, -1); +#endif +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/third_party.h b/libnd4j/include/ops/declarable/headers/third_party.h index 705a02903112..c4cad9eabb0d 100644 --- a/libnd4j/include/ops/declarable/headers/third_party.h +++ b/libnd4j/include/ops/declarable/headers/third_party.h @@ -24,11 +24,11 @@ #include namespace sd { - namespace ops { - #if NOT_EXCLUDED(OP_firas_sparse) - DECLARE_CUSTOM_OP(firas_sparse, 1, 1, false, 0, -1); - #endif - } -} +namespace ops { +#if NOT_EXCLUDED(OP_firas_sparse) +DECLARE_CUSTOM_OP(firas_sparse, 1, 1, false, 0, -1); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/transforms.h b/libnd4j/include/ops/declarable/headers/transforms.h index 29efc4a73c7f..a64c2317f9fb 100644 --- a/libnd4j/include/ops/declarable/headers/transforms.h +++ b/libnd4j/include/ops/declarable/headers/transforms.h @@ -24,208 +24,209 @@ #include namespace sd { - namespace ops { - #if NOT_EXCLUDED(OP_clipbyvalue) - DECLARE_CONFIGURABLE_OP(clipbyvalue, 1, 1, true, 2, 0); - #endif - - #if NOT_EXCLUDED(OP_clipbynorm) - DECLARE_CONFIGURABLE_OP(clipbynorm, 1, 1, true, 1, 0); - DECLARE_CUSTOM_OP(clipbynorm_bp, 2, 1, false, 1, 0); - #endif - - #if NOT_EXCLUDED(OP_clipbyavgnorm) - DECLARE_CONFIGURABLE_OP(clipbyavgnorm, 1, 1, true, 1, 0); - #endif - - #if NOT_EXCLUDED(OP_cumsum) - DECLARE_CONFIGURABLE_OP(cumsum, 1, 1, true, 0, 2); - #endif - - #if NOT_EXCLUDED(OP_cumprod) - DECLARE_CONFIGURABLE_OP(cumprod, 1, 1, true, 0, 2); - #endif - - #if NOT_EXCLUDED(OP_tile) - DECLARE_CUSTOM_OP(tile, 1, 1, false, 0, -2); - DECLARE_CUSTOM_OP(tile_bp, 2, 1, false, 0, -2); - #endif - - #if NOT_EXCLUDED(OP_repeat) - DECLARE_CUSTOM_OP(repeat, 1, 1, true, 0, -1); - #endif - - #if NOT_EXCLUDED(OP_invert_permutation) - DECLARE_CONFIGURABLE_OP(invert_permutation, 1, 1, false, 0, 0); - #endif - - DECLARE_CUSTOM_OP(concat, -1, 1, false, 0, 0); - DECLARE_CUSTOM_OP(concat_bp, -1, -1, false, 0, 0); - - #if NOT_EXCLUDED(OP_mergemax) - DECLARE_OP(mergemax, -1, 1, false); - DECLARE_CUSTOM_OP(mergemax_bp, 2, 1, false, 0, 0); - #endif - /* - * Complete tensor with max indices merged from all input tensors list - * - * INPUT: tensors with the same shape - * OUTPUT: integer tensor with the same shape - * INT_ARG: result type (one of int), INT32 by default - */ - #if NOT_EXCLUDED(OP_mergemaxindex) - DECLARE_CUSTOM_OP(mergemaxindex, -1, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_mergeadd) - DECLARE_OP(mergeadd, -1, 1, false); - DECLARE_CUSTOM_OP(mergeadd_bp, 2, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_mergeavg) - DECLARE_OP(mergeavg, -1, 1, false); - DECLARE_CUSTOM_OP(mergeavg_bp, 2, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_scatter_update) - DECLARE_CONFIGURABLE_OP(scatter_update, 2, 1, true, 0, -1); - #endif - - #if NOT_EXCLUDED(OP_Floor) - DECLARE_OP(Floor, 1, 1, true); - #endif - - #if NOT_EXCLUDED(OP_Log1p) - DECLARE_OP(Log1p, 2, 1, true); - #endif - - #if NOT_EXCLUDED(OP_reverse) - DECLARE_CONFIGURABLE_OP(reverse, 1, 1, true, 0, -2); - DECLARE_CUSTOM_OP(reverse_bp, 2, 1, false, 0, -2); - #endif - - #if NOT_EXCLUDED(OP_gather) - DECLARE_CUSTOM_OP(gather, 1, 1, false, 0, -2); - #endif - - #if NOT_EXCLUDED(OP_pad) - DECLARE_CUSTOM_OP(pad, 2, 1, false, 0, 1); - #endif - - /** - * creates identity 2D matrix or batch of identical 2D identity matrices - * - * Input array: - * provide some array - in any case operation simply neglects it - * - * Input float argument (if passed): - * TArgs[0] - type of elements of output array, default value is 5 (float) - * - * Input integer arguments: - * IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> 'f'-order - * IArgs[1] - the number of rows in output inner-most 2D identity matrix - * IArgs[2] - optional, the number of columns in output inner-most 2D identity matrix, if this argument is not provided then it is taken to be equal to number of rows - * IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape - */ - #if NOT_EXCLUDED(OP_eye) - DECLARE_CUSTOM_OP(eye, -2, 1, false, -2, 2); - #endif - - #if NOT_EXCLUDED(OP_gather_nd) - DECLARE_CUSTOM_OP(gather_nd, 2, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_reverse_sequence) - DECLARE_CUSTOM_OP(reverse_sequence, 2, 1, false, 0, 2); - #endif - - #if NOT_EXCLUDED(OP_trace) - DECLARE_CUSTOM_OP(trace, 1, 1, false, 0, 0); - #endif - - #if NOT_EXCLUDED(OP_random_shuffle) - DECLARE_OP(random_shuffle, 1, 1, true); - #endif - - /** - * clip a list of given tensors with given average norm when needed - * - * Input: - * a list of tensors (at least one) - * - * Input floating point argument: - * clip_norm - a value that used as threshold value and norm to be used - * - * return a list of clipped tensors - * and global_norm as scalar tensor at the end - */ - #if NOT_EXCLUDED(OP_clip_by_global_norm) - DECLARE_CUSTOM_OP(clip_by_global_norm, 1, 2, true, 1, 0); - #endif - - DECLARE_CUSTOM_OP(tri, -2, 1, false, 0, 1); - - DECLARE_CUSTOM_OP(triu, 1, 1, false, 0, 0); - - DECLARE_CUSTOM_OP(triu_bp, 2, 1, false, 0, 0); - - #if NOT_EXCLUDED(OP_mirror_pad) - DECLARE_CUSTOM_OP(mirror_pad, 2, 1, false, 0, 1); - #endif - - #if NOT_EXCLUDED(OP_cumsum) - DECLARE_CUSTOM_OP(cumsum_bp, 2, -1, false, 0, 2); - #endif - - #if NOT_EXCLUDED(OP_cumprod) - DECLARE_CUSTOM_OP(cumprod_bp, 2, -21, false, 0, 2); - #endif - - - #if NOT_EXCLUDED(OP_flatten) - DECLARE_CUSTOM_OP(flatten, -1, 1, false, 0, 1); - #endif - - /** - * returns histogram (as 1D array) with fixed bins width - * - * Input arrays: - * - input array with elements to be binned into output histogram - * - range array with first element being bottom limit and second element being top limit of histogram, - please note that input_value <= range[0] will be mapped to histogram[0], input_value >= range[1] will be mapped to histogram[-1] - * - * Input integer arguments: - * nbins (optional) - number of histogram bins, default value is 100 - */ - #if NOT_EXCLUDED(OP_histogram_fixed_width) - DECLARE_CUSTOM_OP(histogram_fixed_width, 2, 1, false, 0, 0); - #endif - - - /** - * standardizes input array to be zero mean unit variance along the given axis - * - * - */ - #if NOT_EXCLUDED(OP_standardize) - DECLARE_CONFIGURABLE_OP(standardize, 1, 1, true, 0, -2); - DECLARE_CUSTOM_OP(standardize_bp, 2, 1, false, 0, -2); - #endif - - /** - * This operation calculates hash code, optionally along dimension - */ - #if NOT_EXCLUDED(OP_hashcode) - DECLARE_CUSTOM_OP(hashcode, 1, 1, false, 0, 0); - #endif - - /** - * This operation calculates number of entries per bin - */ - #if NOT_EXCLUDED(OP_histogram) - DECLARE_CUSTOM_OP(histogram, 1, 1, false, 0, 1); - #endif - } -} +namespace ops { +#if NOT_EXCLUDED(OP_clipbyvalue) +DECLARE_CONFIGURABLE_OP(clipbyvalue, 1, 1, true, 2, 0); +#endif + +#if NOT_EXCLUDED(OP_clipbynorm) +DECLARE_CONFIGURABLE_OP(clipbynorm, 1, 1, true, 1, 0); +DECLARE_CUSTOM_OP(clipbynorm_bp, 2, 1, false, 1, 0); +#endif + +#if NOT_EXCLUDED(OP_clipbyavgnorm) +DECLARE_CONFIGURABLE_OP(clipbyavgnorm, 1, 1, true, 1, 0); +#endif + +#if NOT_EXCLUDED(OP_cumsum) +DECLARE_CONFIGURABLE_OP(cumsum, 1, 1, true, 0, 2); +#endif + +#if NOT_EXCLUDED(OP_cumprod) +DECLARE_CONFIGURABLE_OP(cumprod, 1, 1, true, 0, 2); +#endif + +#if NOT_EXCLUDED(OP_tile) +DECLARE_CUSTOM_OP(tile, 1, 1, false, 0, -2); +DECLARE_CUSTOM_OP(tile_bp, 2, 1, false, 0, -2); +#endif + +#if NOT_EXCLUDED(OP_repeat) +DECLARE_CUSTOM_OP(repeat, 1, 1, true, 0, -1); +#endif + +#if NOT_EXCLUDED(OP_invert_permutation) +DECLARE_CONFIGURABLE_OP(invert_permutation, 1, 1, false, 0, 0); +#endif + +DECLARE_CUSTOM_OP(concat, -1, 1, false, 0, 0); +DECLARE_CUSTOM_OP(concat_bp, -1, -1, false, 0, 0); + +#if NOT_EXCLUDED(OP_mergemax) +DECLARE_OP(mergemax, -1, 1, false); +DECLARE_CUSTOM_OP(mergemax_bp, 2, 1, false, 0, 0); +#endif +/* + * Complete tensor with max indices merged from all input tensors list + * + * INPUT: tensors with the same shape + * OUTPUT: integer tensor with the same shape + * INT_ARG: result type (one of int), INT32 by default + */ +#if NOT_EXCLUDED(OP_mergemaxindex) +DECLARE_CUSTOM_OP(mergemaxindex, -1, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_mergeadd) +DECLARE_OP(mergeadd, -1, 1, false); +DECLARE_CUSTOM_OP(mergeadd_bp, 2, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_mergeavg) +DECLARE_OP(mergeavg, -1, 1, false); +DECLARE_CUSTOM_OP(mergeavg_bp, 2, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_scatter_update) +DECLARE_CONFIGURABLE_OP(scatter_update, 2, 1, true, 0, -1); +#endif + +#if NOT_EXCLUDED(OP_Floor) +DECLARE_OP(Floor, 1, 1, true); +#endif + +#if NOT_EXCLUDED(OP_Log1p) +DECLARE_OP(Log1p, 2, 1, true); +#endif + +#if NOT_EXCLUDED(OP_reverse) +DECLARE_CONFIGURABLE_OP(reverse, 1, 1, true, 0, -2); +DECLARE_CUSTOM_OP(reverse_bp, 2, 1, false, 0, -2); +#endif + +#if NOT_EXCLUDED(OP_gather) +DECLARE_CUSTOM_OP(gather, 1, 1, false, 0, -2); +#endif + +#if NOT_EXCLUDED(OP_pad) +DECLARE_CUSTOM_OP(pad, 2, 1, false, 0, 1); +#endif + +/** + * creates identity 2D matrix or batch of identical 2D identity matrices + * + * Input array: + * provide some array - in any case operation simply neglects it + * + * Input float argument (if passed): + * TArgs[0] - type of elements of output array, default value is 5 (float) + * + * Input integer arguments: + * IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> + * 'f'-order IArgs[1] - the number of rows in output inner-most 2D + * identity matrix IArgs[2] - optional, the number of columns in output + * inner-most 2D identity matrix, if this argument is not provided then it is + * taken to be equal to number of rows IArgs[3,4,...] - optional, shape of + * batch, output matrix will have leading batch dimensions of this shape + */ +#if NOT_EXCLUDED(OP_eye) +DECLARE_CUSTOM_OP(eye, -2, 1, false, -2, 2); +#endif + +#if NOT_EXCLUDED(OP_gather_nd) +DECLARE_CUSTOM_OP(gather_nd, 2, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_reverse_sequence) +DECLARE_CUSTOM_OP(reverse_sequence, 2, 1, false, 0, 2); +#endif + +#if NOT_EXCLUDED(OP_trace) +DECLARE_CUSTOM_OP(trace, 1, 1, false, 0, 0); +#endif + +#if NOT_EXCLUDED(OP_random_shuffle) +DECLARE_OP(random_shuffle, 1, 1, true); +#endif + +/** + * clip a list of given tensors with given average norm when needed + * + * Input: + * a list of tensors (at least one) + * + * Input floating point argument: + * clip_norm - a value that used as threshold value and norm to be used + * + * return a list of clipped tensors + * and global_norm as scalar tensor at the end + */ +#if NOT_EXCLUDED(OP_clip_by_global_norm) +DECLARE_CUSTOM_OP(clip_by_global_norm, 1, 2, true, 1, 0); +#endif + +DECLARE_CUSTOM_OP(tri, -2, 1, false, 0, 1); + +DECLARE_CUSTOM_OP(triu, 1, 1, false, 0, 0); + +DECLARE_CUSTOM_OP(triu_bp, 2, 1, false, 0, 0); + +#if NOT_EXCLUDED(OP_mirror_pad) +DECLARE_CUSTOM_OP(mirror_pad, 2, 1, false, 0, 1); +#endif + +#if NOT_EXCLUDED(OP_cumsum) +DECLARE_CUSTOM_OP(cumsum_bp, 2, -1, false, 0, 2); +#endif + +#if NOT_EXCLUDED(OP_cumprod) +DECLARE_CUSTOM_OP(cumprod_bp, 2, -21, false, 0, 2); +#endif + +#if NOT_EXCLUDED(OP_flatten) +DECLARE_CUSTOM_OP(flatten, -1, 1, false, 0, 1); +#endif + +/** + * returns histogram (as 1D array) with fixed bins width + * + * Input arrays: + * - input array with elements to be binned into output histogram + * - range array with first element being bottom limit and second element being + top limit of histogram, please note that input_value <= range[0] will be mapped + to histogram[0], input_value >= range[1] will be mapped to histogram[-1] + * + * Input integer arguments: + * nbins (optional) - number of histogram bins, default value is 100 + */ +#if NOT_EXCLUDED(OP_histogram_fixed_width) +DECLARE_CUSTOM_OP(histogram_fixed_width, 2, 1, false, 0, 0); +#endif + +/** + * standardizes input array to be zero mean unit variance along the given axis + * + * + */ +#if NOT_EXCLUDED(OP_standardize) +DECLARE_CONFIGURABLE_OP(standardize, 1, 1, true, 0, -2); +DECLARE_CUSTOM_OP(standardize_bp, 2, 1, false, 0, -2); +#endif + +/** + * This operation calculates hash code, optionally along dimension + */ +#if NOT_EXCLUDED(OP_hashcode) +DECLARE_CUSTOM_OP(hashcode, 1, 1, false, 0, 0); +#endif + +/** + * This operation calculates number of entries per bin + */ +#if NOT_EXCLUDED(OP_histogram) +DECLARE_CUSTOM_OP(histogram, 1, 1, false, 0, 1); +#endif +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/updaters.h b/libnd4j/include/ops/declarable/headers/updaters.h index dc08ff1f2116..c7e6f177505c 100644 --- a/libnd4j/include/ops/declarable/headers/updaters.h +++ b/libnd4j/include/ops/declarable/headers/updaters.h @@ -14,197 +14,194 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // - +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// #ifndef LIBND4J_HEADERS_UPDATERS_H #define LIBND4J_HEADERS_UPDATERS_H -#include -#include -#include #include +#include +#include +#include #include - namespace sd { - namespace ops { - +namespace ops { - /** - * SGD updater - * Input arrays: - * 0 - input array with gradients. - * Optional: - * 1 - scalar learning rate value - * Optional: - * T args - * 0 - scalar learning rate value - */ +/** + * SGD updater + * Input arrays: + * 0 - input array with gradients. + * Optional: + * 1 - scalar learning rate value + * Optional: + * T args + * 0 - scalar learning rate value + */ #if NOT_EXCLUDED(OP_sgd_updater) - DECLARE_CONFIGURABLE_OP(sgd_updater, 1, 1, true, 0, 0); +DECLARE_CONFIGURABLE_OP(sgd_updater, 1, 1, true, 0, 0); #endif - /** - * RmsPropUpdater updater - * Input arrays: - * 0 - input array with gradients. - * 1 - Initial state - * Optional: - * 2 - scalar learning rate value - * 3 - scalar rms decay - * 4 - epsilon - * Optional: - * T args - * 0 - scalar learning rate value - * 1 - scalar rms decay - * 2 - epsilon - */ +/** + * RmsPropUpdater updater + * Input arrays: + * 0 - input array with gradients. + * 1 - Initial state + * Optional: + * 2 - scalar learning rate value + * 3 - scalar rms decay + * 4 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - scalar rms decay + * 2 - epsilon + */ #if NOT_EXCLUDED(OP_rms_prop_updater) - DECLARE_CONFIGURABLE_OP(rms_prop_updater, 2, 2, true, 0, 0); +DECLARE_CONFIGURABLE_OP(rms_prop_updater, 2, 2, true, 0, 0); #endif - // AdaGrad - /* Input arrays : - * 0 - input array with gradients. - * 1 - historical grad state - * Optional : - * 2 - scalar learning rate value - * 3 - epsilon - * Optional: - * T args - * 0 - scalar learning rate value - * 1 - epsilon - */ +// AdaGrad +/* Input arrays : + * 0 - input array with gradients. + * 1 - historical grad state + * Optional : + * 2 - scalar learning rate value + * 3 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - epsilon + */ #if NOT_EXCLUDED(OP_ada_grad_updater) - DECLARE_CONFIGURABLE_OP(ada_grad_updater, 2, 2, true, 0, 0); +DECLARE_CONFIGURABLE_OP(ada_grad_updater, 2, 2, true, 0, 0); #endif - // AdaMax - /* Input arrays : - * 0 - input array with gradients. - * 1 - gradient state V - * 2 - gradient state M - * Optional : - * 3 - scalar learning rate value - * 4 - beta 1 value - * 5 - beta 2 value - * 6 - epsilon - * Optional: - * T args - * 0 - scalar learning rate value - * 1 - beta 1 value - * 2 - beta 2 value - * 3 - epsilon - * Optional: - * I args - * 0 - iteration - */ +// AdaMax +/* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V + * 2 - gradient state M + * Optional : + * 3 - scalar learning rate value + * 4 - beta 1 value + * 5 - beta 2 value + * 6 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ #if NOT_EXCLUDED(OP_ada_max_updater) - DECLARE_CONFIGURABLE_OP(ada_max_updater, 3, 3, true, 0, 0); +DECLARE_CONFIGURABLE_OP(ada_max_updater, 3, 3, true, 0, 0); #endif - // Nesterov's momentum - /* Input arrays : - * 0 - input array with gradients. - * 1 - V grad state - * Optional : - * 2 - scalar learning rate value - * 3 - scalar momentum value - * Optional: - * T args - * 0 - learning rate value - * 1 - momentum value - */ +// Nesterov's momentum +/* Input arrays : + * 0 - input array with gradients. + * 1 - V grad state + * Optional : + * 2 - scalar learning rate value + * 3 - scalar momentum value + * Optional: + * T args + * 0 - learning rate value + * 1 - momentum value + */ #if NOT_EXCLUDED(OP_nesterovs_updater) - DECLARE_CONFIGURABLE_OP(nesterovs_updater, 2, 2, true, 0, 0); +DECLARE_CONFIGURABLE_OP(nesterovs_updater, 2, 2, true, 0, 0); #endif - // Adam - /* Input arrays : - * 0 - input array with gradients. - * 1 - gradient state V - * 2 - gradient state M - * Optional : - * 3 - scalar learning rate value - * 4 - beta 1 value - * 5 - beta 2 value - * 6 - epsilon - * Optional: - * T args - * 0 - scalar learning rate value - * 1 - beta 1 value - * 2 - beta 2 value - * 3 - epsilon - * Optional: - * I args - * 0 - iteration - */ +// Adam +/* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V + * 2 - gradient state M + * Optional : + * 3 - scalar learning rate value + * 4 - beta 1 value + * 5 - beta 2 value + * 6 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ #if NOT_EXCLUDED(OP_adam_updater) - DECLARE_CONFIGURABLE_OP(adam_updater, 3, 3, true, 0, 0); +DECLARE_CONFIGURABLE_OP(adam_updater, 3, 3, true, 0, 0); #endif - // AdaDelta - /* Input arrays : - * 0 - input array with gradients. - * 1 - gradient state V - * 2 - gradient state M - * Optional : - * 3 - rho value - * 6 - epsilon - * Optional: - * T args - * 0 - rho - * 1 - epsilon - */ +// AdaDelta +/* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V + * 2 - gradient state M + * Optional : + * 3 - rho value + * 6 - epsilon + * Optional: + * T args + * 0 - rho + * 1 - epsilon + */ #if NOT_EXCLUDED(OP_ada_delta_updater) - DECLARE_CONFIGURABLE_OP(ada_delta_updater, 3, 3, true, 0, 0); +DECLARE_CONFIGURABLE_OP(ada_delta_updater, 3, 3, true, 0, 0); #endif - // Nadam - /* Input arrays : - * 0 - input array with gradients. - * 1 - gradient state V - * 2 - gradient state M - * Optional : - * 3 - scalar learning rate value - * 4 - beta 1 value - * 5 - beta 2 value - * 6 - epsilon - * Optional: - * T args - * 0 - scalar learning rate value - * 1 - beta 1 value - * 2 - beta 2 value - * 3 - epsilon - * Optional: - * I args - * 0 - iteration - */ +// Nadam +/* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V + * 2 - gradient state M + * Optional : + * 3 - scalar learning rate value + * 4 - beta 1 value + * 5 - beta 2 value + * 6 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ #if NOT_EXCLUDED(OP_nadam_updater) - DECLARE_CONFIGURABLE_OP(nadam_updater, 3, 3, true, 0, 0); -#endif - // AmsGrad - /* Input arrays : - * 0 - input array with gradients. - * 1 - gradient state V - sqrd gradients - * 2 - gradient state M - moving avg - * 3 - gradient state H - max - * Optional : - * 4 - scalar learning rate value - * 5 - beta 1 value - * 6 - beta 2 value - * 7 - epsilon - * Optional: - * T args - * 0 - scalar learning rate value - * 1 - beta 1 value - * 2 - beta 2 value - * 3 - epsilon - * Optional: - * I args - * 0 - iteration - */ +DECLARE_CONFIGURABLE_OP(nadam_updater, 3, 3, true, 0, 0); +#endif +// AmsGrad +/* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V - sqrd gradients + * 2 - gradient state M - moving avg + * 3 - gradient state H - max + * Optional : + * 4 - scalar learning rate value + * 5 - beta 1 value + * 6 - beta 2 value + * 7 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ #if NOT_EXCLUDED(OP_ams_grad_updater) - DECLARE_CONFIGURABLE_OP(ams_grad_updater, 4, 4, true, 0, 0); -#endif -} -} +DECLARE_CONFIGURABLE_OP(ams_grad_updater, 4, 4, true, 0, 0); +#endif +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/headers/util.h b/libnd4j/include/ops/declarable/headers/util.h index 57b013f298b5..a32a700bcd09 100644 --- a/libnd4j/include/ops/declarable/headers/util.h +++ b/libnd4j/include/ops/declarable/headers/util.h @@ -24,21 +24,21 @@ #include namespace sd { - namespace ops { - /** - * This operation prints out NDArray content, either on host or device. - */ - #if NOT_EXCLUDED(OP_print_variable) - DECLARE_CUSTOM_OP(print_variable, 1, 1, true, 0, 0); - #endif +namespace ops { +/** + * This operation prints out NDArray content, either on host or device. + */ +#if NOT_EXCLUDED(OP_print_variable) +DECLARE_CUSTOM_OP(print_variable, 1, 1, true, 0, 0); +#endif - /** - * This operation prints out affinity & locality status of given NDArray - */ - #if NOT_EXCLUDED(OP_print_affinity) - DECLARE_CUSTOM_OP(print_affinity, 1, 1, true, 0, 0); - #endif - } -} +/** + * This operation prints out affinity & locality status of given NDArray + */ +#if NOT_EXCLUDED(OP_print_affinity) +DECLARE_CUSTOM_OP(print_affinity, 1, 1, true, 0, 0); +#endif +} // namespace ops +} // namespace sd -#endif //LIBND4J_UTILS_H +#endif // LIBND4J_UTILS_H diff --git a/libnd4j/include/ops/declarable/helpers/BarnesHutTsne.h b/libnd4j/include/ops/declarable/helpers/BarnesHutTsne.h index f52dd9ba44ac..bb03c96fa66b 100644 --- a/libnd4j/include/ops/declarable/helpers/BarnesHutTsne.h +++ b/libnd4j/include/ops/declarable/helpers/BarnesHutTsne.h @@ -27,15 +27,22 @@ namespace sd { namespace ops { namespace helpers { - Nd4jLong barnes_row_count(const NDArray* rowP, const NDArray* colP, Nd4jLong N, NDArray& rowCounts); - void barnes_symmetrize(const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts = nullptr); - void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray* output, NDArray const& data); - void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output); - bool cell_contains(NDArray* corner, NDArray* width, NDArray* point, Nd4jLong dimension); - -} -} -} - - -#endif //LIBND4J_ACTIVATIONS_H +Nd4jLong barnes_row_count(const NDArray* rowP, const NDArray* colP, Nd4jLong N, + NDArray& rowCounts); +void barnes_symmetrize(const NDArray* rowP, const NDArray* colP, + const NDArray* valP, Nd4jLong N, NDArray* outputRows, + NDArray* outputCols, NDArray* outputVals, + NDArray* rowCounts = nullptr); +void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, + NDArray const* valP, int N, NDArray* output, + NDArray const& data); +void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, + NDArray* output); +bool cell_contains(NDArray* corner, NDArray* width, NDArray* point, + Nd4jLong dimension); + +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // LIBND4J_ACTIVATIONS_H diff --git a/libnd4j/include/ops/declarable/helpers/activations.h b/libnd4j/include/ops/declarable/helpers/activations.h index b20eb8450ab3..9065464ebd62 100644 --- a/libnd4j/include/ops/declarable/helpers/activations.h +++ b/libnd4j/include/ops/declarable/helpers/activations.h @@ -27,26 +27,37 @@ namespace sd { namespace ops { namespace helpers { - SD_EXPORT void softMaxForVector(sd::LaunchContext * context, const NDArray &input, NDArray &output); +SD_EXPORT void softMaxForVector(sd::LaunchContext *context, + const NDArray &input, NDArray &output); - SD_EXPORT void logSoftMaxForVector(sd::LaunchContext * context, const NDArray &input, NDArray &output); +SD_EXPORT void logSoftMaxForVector(sd::LaunchContext *context, + const NDArray &input, NDArray &output); - SD_EXPORT void softmax(sd::LaunchContext * context, const NDArray &input, NDArray &output, const int dimension); +SD_EXPORT void softmax(sd::LaunchContext *context, const NDArray &input, + NDArray &output, const int dimension); - SD_EXPORT void logSoftmax(sd::LaunchContext * context, const NDArray &input, NDArray &output, const int dimension); +SD_EXPORT void logSoftmax(sd::LaunchContext *context, const NDArray &input, + NDArray &output, const int dimension); - SD_EXPORT void softmaxDerivative(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension); +SD_EXPORT void softmaxDerivative(sd::LaunchContext *context, + const NDArray &input, NDArray &output, + const int dimension); - SD_EXPORT void prelu(sd::LaunchContext * context, const NDArray &input, const NDArray &alpha, NDArray &output); +SD_EXPORT void prelu(sd::LaunchContext *context, const NDArray &input, + const NDArray &alpha, NDArray &output); - SD_EXPORT void preluBP(sd::LaunchContext * context, const NDArray &input, const NDArray &alpha, const NDArray &dLdO, NDArray &dLdI, NDArray &dLdA); +SD_EXPORT void preluBP(sd::LaunchContext *context, const NDArray &input, + const NDArray &alpha, const NDArray &dLdO, NDArray &dLdI, + NDArray &dLdA); - SD_EXPORT void thresholdRelu(sd::LaunchContext * context, const NDArray &input, double threshold, NDArray &output); +SD_EXPORT void thresholdRelu(sd::LaunchContext *context, const NDArray &input, + double threshold, NDArray &output); - SD_EXPORT void thresholdReluDerivative(sd::LaunchContext * context, NDArray *input, double threshold, NDArray* dLdO, NDArray *output); -} -} -} +SD_EXPORT void thresholdReluDerivative(sd::LaunchContext *context, + NDArray *input, double threshold, + NDArray *dLdO, NDArray *output); +} // namespace helpers +} // namespace ops +} // namespace sd - -#endif //LIBND4J_ACTIVATIONS_H +#endif // LIBND4J_ACTIVATIONS_H diff --git a/libnd4j/include/ops/declarable/helpers/addBias.h b/libnd4j/include/ops/declarable/helpers/addBias.h index 8eff731f7401..3760d90999d5 100644 --- a/libnd4j/include/ops/declarable/helpers/addBias.h +++ b/libnd4j/include/ops/declarable/helpers/addBias.h @@ -21,20 +21,18 @@ #ifndef LIBND4J_ADDBIAS_H #define LIBND4J_ADDBIAS_H -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - - void addBias(graph::Context& block, const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW); - +void addBias(graph::Context& block, const NDArray& input, const NDArray& bias, + NDArray& output, const bool isNCHW); } -} -} - +} // namespace ops +} // namespace sd -#endif // LIBND4J_ADDBIAS_H +#endif // LIBND4J_ADDBIAS_H diff --git a/libnd4j/include/ops/declarable/helpers/adjust_hue.h b/libnd4j/include/ops/declarable/helpers/adjust_hue.h index 2d0e2f08762d..3dda2bcd90d6 100644 --- a/libnd4j/include/ops/declarable/helpers/adjust_hue.h +++ b/libnd4j/include/ops/declarable/helpers/adjust_hue.h @@ -20,114 +20,108 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - - void adjustHue(sd::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC); - - +void adjustHue(sd::LaunchContext* context, const NDArray* input, + const NDArray* deltaScalarArr, NDArray* output, const int dimC); //////////////////////////////////////////////////////////////////////////////// template -FORCEINLINE _CUDA_HD void rgbToHsv(const T& r, const T& g, const T& b, T& h, T& s, T& v) { - - // h values are in range [0, 360) - // s and v values are in range [0, 1] - - const T max = sd::math::nd4j_max(r, sd::math::nd4j_max(g, b)); - const T min = sd::math::nd4j_min(r, sd::math::nd4j_min(g, b)); - const T c = max - min; - const T _p6 = (T)1 / (T)6; - // calculate h - if(c == 0) { - h = 0; - } - else if(max == r) { - h = _p6 * ((g - b) / c) + (g >= b ? (T)0 : (T)1); - } - else if(max == g) { - h = _p6 * ((b - r) / c + (T)2); - } - else { // max == b - h = _p6 * ((r - g) / c + (T)4); - } - - // calculate s - s = max == (T)0 ? (T)0 : c / max; - - // calculate v - v = max;// / 255.f; +FORCEINLINE _CUDA_HD void rgbToHsv(const T& r, const T& g, const T& b, T& h, + T& s, T& v) { + // h values are in range [0, 360) + // s and v values are in range [0, 1] + + const T max = sd::math::nd4j_max(r, sd::math::nd4j_max(g, b)); + const T min = sd::math::nd4j_min(r, sd::math::nd4j_min(g, b)); + const T c = max - min; + const T _p6 = (T)1 / (T)6; + // calculate h + if (c == 0) { + h = 0; + } else if (max == r) { + h = _p6 * ((g - b) / c) + (g >= b ? (T)0 : (T)1); + } else if (max == g) { + h = _p6 * ((b - r) / c + (T)2); + } else { // max == b + h = _p6 * ((r - g) / c + (T)4); + } + + // calculate s + s = max == (T)0 ? (T)0 : c / max; + + // calculate v + v = max; // / 255.f; } //////////////////////////////////////////////////////////////////////////////// template -FORCEINLINE _CUDA_HD void hsvToRgb(const T& h, const T& s, const T& v, T& r, T& g, T& b) { - - const float sector = h * 6.f; - const T c = v * s; - - if(0.f <= sector && sector < 1.f) { - r = v; - g = v - c * (1 - sector); - b = v - c; - } - else if(1.f <= sector && sector < 2.f) { - r = v - c * (sector - 1); - g = v; - b = v - c; - } - else if(2.f <= sector && sector < 3.f) { - r = v - c; - g = v; - b = v - c * (3 - sector); - } - else if(3.f <= sector && sector < 4.f) { - r = v - c; - g = v - c * (sector - 3); - b = v; - } - else if(4.f <= sector && sector < 5.f) { - r = v - c * (5 - sector); - g = v - c; - b = v; - } - else { // 5.f <= sector < 6.f - r = v; - g = v - c; - b = v - c * (sector - 5); - } - -// r *= 255; -// g *= 255; -// b *= 255; +FORCEINLINE _CUDA_HD void hsvToRgb(const T& h, const T& s, const T& v, T& r, + T& g, T& b) { + const float sector = h * 6.f; + const T c = v * s; + + if (0.f <= sector && sector < 1.f) { + r = v; + g = v - c * (1 - sector); + b = v - c; + } else if (1.f <= sector && sector < 2.f) { + r = v - c * (sector - 1); + g = v; + b = v - c; + } else if (2.f <= sector && sector < 3.f) { + r = v - c; + g = v; + b = v - c * (3 - sector); + } else if (3.f <= sector && sector < 4.f) { + r = v - c; + g = v - c * (sector - 3); + b = v; + } else if (4.f <= sector && sector < 5.f) { + r = v - c * (5 - sector); + g = v - c; + b = v; + } else { // 5.f <= sector < 6.f + r = v; + g = v - c; + b = v - c * (sector - 5); + } + + // r *= 255; + // g *= 255; + // b *= 255; } //////////////////////////////////////////////////////////////////////////////// template -FORCEINLINE _CUDA_HD void rgbYuv(const T& r, const T& g, const T& b, T& y, T& u, T& v) { - y = static_cast(0.299) * r + static_cast(0.587) *g + static_cast(0.114) * b; - u = -static_cast(0.14714119) * r - static_cast(0.2888691) * g + static_cast(0.43601035) * b; - v = static_cast(0.61497538) * r - static_cast(0.51496512) * g - static_cast(0.10001026) * b; +FORCEINLINE _CUDA_HD void rgbYuv(const T& r, const T& g, const T& b, T& y, T& u, + T& v) { + y = static_cast(0.299) * r + static_cast(0.587) * g + + static_cast(0.114) * b; + u = -static_cast(0.14714119) * r - static_cast(0.2888691) * g + + static_cast(0.43601035) * b; + v = static_cast(0.61497538) * r - static_cast(0.51496512) * g - + static_cast(0.10001026) * b; } //////////////////////////////////////////////////////////////////////////////// template -FORCEINLINE _CUDA_HD void yuvRgb(const T& y, const T& u, const T& v, T& r, T& g, T& b) { - r = y + static_cast(1.13988303) * v; - g = y - static_cast(0.394642334) * u - static_cast(0.58062185) * v; - b = y + static_cast(2.03206185) * u; +FORCEINLINE _CUDA_HD void yuvRgb(const T& y, const T& u, const T& v, T& r, T& g, + T& b) { + r = y + static_cast(1.13988303) * v; + g = y - static_cast(0.394642334) * u - static_cast(0.58062185) * v; + b = y + static_cast(2.03206185) * u; } /*//////////////////////////////////////////////////////////////////////////////// template -static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* v_max) { - T v_mid; - int h_category; +static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* +v_max) { T v_mid; int h_category; // According to the figures in: // https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma // For the conditions, we don't care about the case where two components are @@ -185,12 +179,9 @@ static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* v_m //////////////////////////////////////////////////////////////////////////////// template -static FORCEINLINE _CUDA_HD void hv_to_rgb(T h, T v_min, T v_max, T* r, T* g, T* b) { - int h_category = static_cast(h); - T ratio = h - (T)h_category; - bool increase = ((h_category & 0x1) == 0); - if (!increase) - ratio = 1 - ratio; +static FORCEINLINE _CUDA_HD void hv_to_rgb(T h, T v_min, T v_max, T* r, T* g, T* +b) { int h_category = static_cast(h); T ratio = h - (T)h_category; bool +increase = ((h_category & 0x1) == 0); if (!increase) ratio = 1 - ratio; T v_mid = v_min + ratio * (v_max - v_min); // According to the figures in: @@ -230,6 +221,6 @@ static FORCEINLINE _CUDA_HD void hv_to_rgb(T h, T v_min, T v_max, T* r, T* g, T* } */ -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/adjust_saturation.h b/libnd4j/include/ops/declarable/helpers/adjust_saturation.h index 25dc30f102ad..2f95d8643a24 100644 --- a/libnd4j/include/ops/declarable/helpers/adjust_saturation.h +++ b/libnd4j/include/ops/declarable/helpers/adjust_saturation.h @@ -19,25 +19,24 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - void adjustSaturation(sd::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC); +void adjustSaturation(sd::LaunchContext* context, const NDArray* input, + const NDArray* factorScalarArr, NDArray* output, + const int dimC); /* template - static FORCEINLINE _CUDA_HD void rgb_to_hsv(T r, T g, T b, T* h, T* s, T* v) { - T vv = sd::math::nd4j_max(r, sd::math::nd4j_max(g, b)); - T range = vv - sd::math::nd4j_min(r, sd::math::nd4j_min(g, b)); - if (vv > 0) { - *s = range / vv; - } else { - *s = 0; + static FORCEINLINE _CUDA_HD void rgb_to_hsv(T r, T g, T b, T* h, T* s, T* v) + { T vv = sd::math::nd4j_max(r, sd::math::nd4j_max(g, b)); T range = vv + - sd::math::nd4j_min(r, sd::math::nd4j_min(g, b)); if (vv > 0) { *s = + range / vv; } else { *s = 0; } T norm = 1.0f / (6.0f * range); T hh; @@ -59,15 +58,9 @@ namespace helpers { } template - static FORCEINLINE _CUDA_HD void hsv_to_rgb(T h, T s, T v, T* r, T* g, T* b) { - T c = s * v; - T m = v - c; - T dh = h * 6; - T rr, gg, bb; - int h_category = static_cast(dh); - T fmodu = dh; - while (fmodu <= (T) 0) - fmodu += (T) 2.0f; + static FORCEINLINE _CUDA_HD void hsv_to_rgb(T h, T s, T v, T* r, T* g, T* b) + { T c = s * v; T m = v - c; T dh = h * 6; T rr, gg, bb; int h_category = + static_cast(dh); T fmodu = dh; while (fmodu <= (T) 0) fmodu += (T) 2.0f; while (fmodu >= (T) 2.0f) fmodu -= (T) 2.0f; @@ -116,6 +109,6 @@ namespace helpers { } */ -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/axis.h b/libnd4j/include/ops/declarable/helpers/axis.h index 76c5a070ddc0..8379b4591b23 100644 --- a/libnd4j/include/ops/declarable/helpers/axis.h +++ b/libnd4j/include/ops/declarable/helpers/axis.h @@ -19,20 +19,20 @@ // #ifndef __AXIS_H_HELPERS__ #define __AXIS_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - /* - * adjustAxis routines: adjust data with output to non-negative values. - * */ - void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector& output); - void adjustAxis(Nd4jLong rank, std::vector& output); +/* + * adjustAxis routines: adjust data with output to non-negative values. + * */ +void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector& output); +void adjustAxis(Nd4jLong rank, std::vector& output); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/batched_gemm.h b/libnd4j/include/ops/declarable/helpers/batched_gemm.h index 26651cf3c63e..63e18a1df15a 100644 --- a/libnd4j/include/ops/declarable/helpers/batched_gemm.h +++ b/libnd4j/include/ops/declarable/helpers/batched_gemm.h @@ -18,15 +18,19 @@ // @author raver119@gmail.com // -#include #include +#include + namespace sd { namespace ops { namespace helpers { -void bgemm(const std::vector& vA, const std::vector& vB, std::vector& vC, const NDArray* alphas, const NDArray* betas, int transA, int transB, int M, int N, int K, const int lda, const int ldb, const int ldc); +void bgemm(const std::vector& vA, const std::vector& vB, + std::vector& vC, const NDArray* alphas, + const NDArray* betas, int transA, int transB, int M, int N, int K, + const int lda, const int ldb, const int ldc); } -} -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/batchnorm.h b/libnd4j/include/ops/declarable/helpers/batchnorm.h index 72bc69718e37..62fc8baefc64 100644 --- a/libnd4j/include/ops/declarable/helpers/batchnorm.h +++ b/libnd4j/include/ops/declarable/helpers/batchnorm.h @@ -23,17 +23,17 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { +void batchnorm(const NDArray* input, const NDArray* mean, + const NDArray* variance, const NDArray* gamma, + const NDArray* beta, NDArray* output, + const std::vector& axes, const double epsilon); - void batchnorm(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon); - - -} } -} - +} // namespace ops +} // namespace sd -#endif //LIBND4J_BATCHNORM_H +#endif // LIBND4J_BATCHNORM_H diff --git a/libnd4j/include/ops/declarable/helpers/betaInc.h b/libnd4j/include/ops/declarable/helpers/betaInc.h index 4c37d7c5dc3a..ab16f8c0083f 100644 --- a/libnd4j/include/ops/declarable/helpers/betaInc.h +++ b/libnd4j/include/ops/declarable/helpers/betaInc.h @@ -22,20 +22,23 @@ #define LIBND4J_BETAINC_H #include + #include "array/NDArray.h" namespace sd { namespace ops { namespace helpers { - const uint maxIter = MAX_NUM_THREADS /*articles propose 10000*/; // max number of loop iterations in function for continued fractions - - void betaInc(sd::LaunchContext* context, const NDArray& a, const NDArray& b, const NDArray& x, NDArray& output); - +const uint maxIter = + MAX_NUM_THREADS /*articles propose 10000*/; // max number of loop + // iterations in function for + // continued fractions -} -} -} +void betaInc(sd::LaunchContext* context, const NDArray& a, const NDArray& b, + const NDArray& x, NDArray& output); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_BETAINC_H \ No newline at end of file +#endif // LIBND4J_BETAINC_H \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/choose.h b/libnd4j/include/ops/declarable/helpers/choose.h index 233c6b1ff95c..4a1c5a389fa0 100644 --- a/libnd4j/include/ops/declarable/helpers/choose.h +++ b/libnd4j/include/ops/declarable/helpers/choose.h @@ -19,17 +19,20 @@ // #ifndef __CHOOSE_H_HELPERS__ #define __CHOOSE_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void chooseFunctorArray(sd::LaunchContext * context, NDArray* arg, NDArray* comp, int mode, NDArray* result, NDArray* numResults); - void chooseFunctorScalar(sd::LaunchContext * context, NDArray* arg, double scalar, int mode, NDArray* result, NDArray* numResults); +void chooseFunctorArray(sd::LaunchContext* context, NDArray* arg, NDArray* comp, + int mode, NDArray* result, NDArray* numResults); +void chooseFunctorScalar(sd::LaunchContext* context, NDArray* arg, + double scalar, int mode, NDArray* result, + NDArray* numResults); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/col2im.h b/libnd4j/include/ops/declarable/helpers/col2im.h index 7a1468d88223..d139c6606ccc 100644 --- a/libnd4j/include/ops/declarable/helpers/col2im.h +++ b/libnd4j/include/ops/declarable/helpers/col2im.h @@ -27,12 +27,13 @@ namespace sd { namespace ops { namespace helpers { - SD_EXPORT void col2im(sd::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW); +SD_EXPORT void col2im(sd::LaunchContext& context, const NDArray& input, + NDArray& output, const int sH, const int sW, const int pH, + const int pW, const int iH, const int iW, const int dH, + const int dW); - -} } -} - +} // namespace ops +} // namespace sd -#endif //LIBND4J_COL2IM_H +#endif // LIBND4J_COL2IM_H diff --git a/libnd4j/include/ops/declarable/helpers/compare_elem.h b/libnd4j/include/ops/declarable/helpers/compare_elem.h index 18faee3288b9..0d0f9f8eb8cc 100644 --- a/libnd4j/include/ops/declarable/helpers/compare_elem.h +++ b/libnd4j/include/ops/declarable/helpers/compare_elem.h @@ -14,21 +14,21 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - #ifndef LIBND4J_COMPARE_ELEM_H #define LIBND4J_COMPARE_ELEM_H #include + #include "array/NDArray.h" namespace sd { - namespace ops { - namespace helpers { +namespace ops { +namespace helpers { - void compare_elem(sd::LaunchContext * context, NDArray* input, bool isStrictlyIncreasing, bool& output); - } - } +void compare_elem(sd::LaunchContext* context, NDArray* input, + bool isStrictlyIncreasing, bool& output); } +} // namespace ops +} // namespace sd - -#endif //LIBND4J_COMPARE_ELEM_H \ No newline at end of file +#endif // LIBND4J_COMPARE_ELEM_H \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/compression.h b/libnd4j/include/ops/declarable/helpers/compression.h index b9c70a91b886..966d2209dcd3 100644 --- a/libnd4j/include/ops/declarable/helpers/compression.h +++ b/libnd4j/include/ops/declarable/helpers/compression.h @@ -19,16 +19,18 @@ // #ifndef __COMPRESSION_H_HELPERS__ #define __COMPRESSION_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void decodeBitmap(sd::LaunchContext* context, const NDArray* input, NDArray* output); - Nd4jLong encodeBitmap(sd::LaunchContext* context, NDArray* input, NDArray* output, float threshold); -} -} -} +void decodeBitmap(sd::LaunchContext* context, const NDArray* input, + NDArray* output); +Nd4jLong encodeBitmap(sd::LaunchContext* context, NDArray* input, + NDArray* output, float threshold); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/confusion.h b/libnd4j/include/ops/declarable/helpers/confusion.h index a4d27be4fbe6..e660c2b7bdd1 100644 --- a/libnd4j/include/ops/declarable/helpers/confusion.h +++ b/libnd4j/include/ops/declarable/helpers/confusion.h @@ -19,16 +19,17 @@ // #ifndef __CONFUSION_H_HELPERS__ #define __CONFUSION_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void confusionFunctor(sd::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output); +void confusionFunctor(sd::LaunchContext* context, NDArray* labels, + NDArray* predictions, NDArray* weights, NDArray* output); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/convolutions.h b/libnd4j/include/ops/declarable/helpers/convolutions.h index 211bd4d76b4f..90e0b2467cbd 100644 --- a/libnd4j/include/ops/declarable/helpers/convolutions.h +++ b/libnd4j/include/ops/declarable/helpers/convolutions.h @@ -22,312 +22,427 @@ #define LIBND4J_CONVOLUTIONS_H #include +#include #include #include -#include - namespace sd { - namespace ops { - - enum PoolingType { - MAX_POOL = 0, - AVG_POOL = 1, - PNORM_POOL = 2, - }; - - class SD_EXPORT ConvolutionUtils { - public: - static inline void calcOutSizePool2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int paddingMode) { - - if(paddingMode == 0) { // valid - // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; - // oW = (iW - (kW + (kW-1)*(dW-1)) + 2*pW)/sW + 1; - oH = (iH - ((kH - 1) * dH + 1) + 2 * pH) / sH + 1; - oW = (iW - ((kW - 1) * dW + 1) + 2 * pW) / sW + 1; - } - else if (paddingMode == 1) { // same - oH = (int) math::nd4j_ceil(iH * 1. / sH); - oW = (int) math::nd4j_ceil(iW * 1. / sW); - } - else { // causal - oH = (iH - 1) / sH + 1; // 2*pH = (kH-1)*dH - oW = (iW - 1) / sW + 1; - } - } - - static inline void calcOutSizePool3D(int& oD, int& oH, int& oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int iD, const int iH, const int iW, const int paddingMode) { - - if(paddingMode == 0) { // valid - oD = (iD - ((kD - 1) * dD + 1) + 2 * pD) / sD + 1; - oH = (iH - ((kH - 1) * dH + 1) + 2 * pH) / sH + 1; - oW = (iW - ((kW - 1) * dW + 1) + 2 * pW) / sW + 1; - } - else if(paddingMode == 1) { // same - oD = (int) sd::math::nd4j_ceil(iD * 1. / sD); - oH = (int) sd::math::nd4j_ceil(iH * 1. / sH); - oW = (int) sd::math::nd4j_ceil(iW * 1. / sW); - - } - else { // causal - oD = (iD - 1) / sD + 1; - oH = (iH - 1) / sH + 1; // 2*pH = (kH-1)*dH - oW = (iW - 1) / sW + 1; - } - } - - static inline void calcPadding2D(int& pH, int& pW, int oH, int oW, int iH, int iW, int kH, int kW, int sH, int sW, int dH, int dW, const int paddingMode = 1 /* default is same mode*/) { - - if(paddingMode == 0) // valid - return; - - if(paddingMode == 1) { // same - - const int eKH = (kH - 1) * dH + 1; - const int eKW = (kW - 1) * dW + 1; - - pH = ((oH - 1) * sH + eKH - iH) / 2; //Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2 - pW = ((oW - 1) * sW + eKW - iW) / 2; - } - else { // causal - pH = (kH - 1) * dH; - pW = (kW - 1) * dW; - } - } - - static inline void calcPadding3D(int& pD, int& pH, int& pW, const int oD, const int oH, const int oW, const int iD, const int iH, const int iW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int dD, const int dH, const int dW, const int paddingMode = 1 /* default is same mode*/) { - - if(paddingMode == 0) // valid - return; - - if(paddingMode == 1) { // same - - const int eKD = (kD - 1) * dD + 1; - const int eKH = (kH - 1) * dH + 1; - const int eKW = (kW - 1) * dW + 1; - - pD = ((oD - 1) * sD + eKD - iD) / 2; - pH = ((oH - 1) * sH + eKH - iH) / 2; //Note that padBottom is 1 bigger than this if bracketed term is not divisible by 2 - pW = ((oW - 1) * sW + eKW - iW) / 2; - } - else { // causal - pD = (kD - 1) * dD; - pH = (kH - 1) * dH; - pW = (kW - 1) * dW; - } - } - - // calculation of output height and width in 2D deconvolution procedure - static inline void calcOutSizeDeconv2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int paddingMode) { - - if (paddingMode) { - oH = sH * iH; - oW = sW * iW; - } - else { - const int ekH = (kH - 1) * dH + 1; - const int ekW = (kW - 1) * dW + 1; - - oH = sH * (iH - 1) + ekH - 2 * pH; - oW = sW * (iW - 1) + ekW - 2 * pW; - } - } - - // calculation of output height and width in 3D deconvolution procedure - static inline void calcOutSizeDeconv3D(int& oD, int& oH, int& oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int iD, const int iH, const int iW, const int paddingMode) { - - if (paddingMode) { - oD = sD * iD; - oH = sH * iH; - oW = sW * iW; - } - else { - - const int ekD = (kD - 1) * dD + 1; - const int ekH = (kH - 1) * dH + 1; - const int ekW = (kW - 1) * dW + 1; - - oD = sD * (iD - 1) + ekD - 2 * pD; - oH = sH * (iH - 1) + ekH - 2 * pH; - oW = sW * (iW - 1) + ekW - 2 * pW; - } - } - - // evaluates sizes values and indexes using input and output arrays depending on data format - static inline void getSizesAndIndexesConv2d(const bool isNCHW, const int wFormat, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) { - getSizesAndIndexesConv2d(isNCHW, wFormat, input.shapeInfo(), output.shapeInfo(), bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - } - - static inline void getSizesAndIndexesConv2d(const bool isNCHW, const int wFormat, const Nd4jLong* inShapeInfo, const Nd4jLong* outShapeInfo, int& bS, int& iC, int& iH, int& iW, int& oC, int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, int& indWkH, int& indOoH) { - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC] (wFormat = 0), [oC, iC, kH, kW] (wFormat = 1), [oC, kH, kW, iC] (wFormat = 2) - // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - if(0 == wFormat) { - indWkH = 0; indWiC = 2; indWoC = 3; - } - else if(1 == wFormat) { - indWkH = 2; indWiC = 1; indWoC = 0; - } - else { - indWkH = 1; indWiC = 3; indWoC = 0; - } - - if(!isNCHW) { - indIOioC = 3; indIiH = 1; indOoH = 1; - } - else { - indIOioC = 1; indIiH = 2; indOoH = 2; - } - - bS = inShapeInfo[1]; // batch size - iC = inShapeInfo[indIOioC+1]; // input channels - iH = inShapeInfo[indIiH+1]; // input height - iW = inShapeInfo[indIiH+2]; // input width - oC = outShapeInfo[indIOioC+1]; // output channels - oH = outShapeInfo[indOoH+1]; // output height - oW = outShapeInfo[indOoH+2]; // output width - } - - // evaluates sizes values and indexes using input and output arrays depending on data format - static inline void getSizesAndIndexesConv3d(const bool isNCDHW, const int wFormat, const NDArray& input, const NDArray& output, int& bS, int& iC, int& iD, int& iH, int& iW, int& oC, int& oD, int& oH, int& oW, int& indIOioC, int& indIOioD, int& indWiC, int& indWoC, int& indWkD) { - // input [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - // weights [kD, kH, kW, iC, oC] (wFormat = 0), [oC, iC, kD, kH, kW] (wFormat = 1), [oC, kD, kH, kW, iC] (wFormat = 2) - // output [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - if(0 == wFormat) { - indWkD = 0; indWiC = 3; indWoC = 4; - } - else if(1 == wFormat) { - indWkD = 2; indWiC = 1; indWoC = 0; - } - else { - indWkD = 1; indWiC = 4; indWoC = 0; - } - - if(!isNCDHW) { - indIOioC = 4; indIOioD = 1; - } - else { - indIOioC = 1; indIOioD = 2; - } - - bS = input.sizeAt(0); // batch size - iC = input.sizeAt(indIOioC); // input channels - iD = input.sizeAt(indIOioD); // input depth - iH = input.sizeAt(indIOioD+1); // input height - iW = input.sizeAt(indIOioD+2); // input width - oC = output.sizeAt(indIOioC); // output channels - oD = output.sizeAt(indIOioD); // output depth - oH = output.sizeAt(indIOioD+1); // output height - oW = output.sizeAt(indIOioD+2); // output width - } - - // static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int paddingMode, int& pH, int& pW, int& dH, int& dW) { - - // if(kH != 1) { - // if(paddingMode) { - // pH = (oH - 1) * sH - iH + kH - pH; - // dH = dH - 1; - // } - // else - // dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1); - // } - // if(kW != 1) { - // if(paddingMode) { - // pW = (oW - 1) * sW - iW + kW - pW; - // dW = dW - 1; - // } - // else - // dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1); - // } - // } - - // static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int paddingMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) { - - // if(kD != 1) { - // if(paddingMode) { - // pD = (oD - 1) * sD - iD + kD - pD; - // dD = dD - 1; - // } - // else - // dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1); - // } - // if(kH != 1) { - // if(paddingMode) { - // pH = (oH - 1) * sH - iH + kH - pH; - // dH = dH - 1; - // } - // else - // dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1); - // } - // if(kW != 1) { - // if(paddingMode) { - // pW = (oW - 1) * sW - iW + kW - pW; - // dW = dW - 1; - // } - // else - // dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1); - // } - // } - - static std::vector expectWeightsShape(const int wFormat, const int kH, const int kW, const int iC, const int oC) { - - if(0 == wFormat) - return std::vector({kH, kW, iC, oC}); - - if(1 == wFormat) - return std::vector({oC, iC, kH, kW}); - - return std::vector({oC, kH, kW, iC}); - } - - static std::vector expectWeightsShape(const int wFormat, const int kD, const int kH, const int kW, const int iC, const int oC) { - - if(0 == wFormat) - return std::vector({kD, kH, kW, iC, oC}); - - if(1 == wFormat) - return std::vector({oC, iC, kD, kH, kW}); - - return std::vector({oC, kD, kH, kW, iC}); - } - - static void conv2d(sd::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); - - // static void conv2d(sd::graph::Context & block, const std::vector& inArrs, NDArray* output, const std::vector& intArgs); - - // static void conv2dBP(sd::graph::Context & block, const std::vector& inArrs, const std::vector& outArrs, const std::vector& intArgs); - - static void conv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); - - static void depthwiseConv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); - - static void depthwiseConv2dBP(sd::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); - - static void sconv2d(sd::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat); - - static void vol2col(sd::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW); - - static void col2vol(sd::graph::Context & block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW); - - static void upsampling2d(sd::graph::Context & block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW); - - static void upsampling3d(sd::graph::Context & block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW); - - static void upsampling2dBP(sd::graph::Context & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW); - - static void upsampling3dBP(sd::graph::Context & block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW); - - static void pooling2d(sd::graph::Context & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0); - - static void pooling3d(sd::graph::Context & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0); - - static void pooling2dBP(sd::graph::Context & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0); - - static void pooling3dBP(sd::graph::Context & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0); - }; - -} -} -#endif //LIBND4J_CONVOLUTIONS_H +namespace ops { + +enum PoolingType { + MAX_POOL = 0, + AVG_POOL = 1, + PNORM_POOL = 2, +}; + +class SD_EXPORT ConvolutionUtils { + public: + static inline void calcOutSizePool2D(int& oH, int& oW, const int kH, + const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, + const int dW, const int iH, const int iW, + const int paddingMode) { + if (paddingMode == 0) { // valid + // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; + // oW = (iW - (kW + (kW-1)*(dW-1)) + 2*pW)/sW + 1; + oH = (iH - ((kH - 1) * dH + 1) + 2 * pH) / sH + 1; + oW = (iW - ((kW - 1) * dW + 1) + 2 * pW) / sW + 1; + } else if (paddingMode == 1) { // same + oH = (int)math::nd4j_ceil(iH * 1. / sH); + oW = (int)math::nd4j_ceil(iW * 1. / sW); + } else { // causal + oH = (iH - 1) / sH + 1; // 2*pH = (kH-1)*dH + oW = (iW - 1) / sW + 1; + } + } + + static inline void calcOutSizePool3D(int& oD, int& oH, int& oW, const int kD, + const int kH, const int kW, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW, const int iD, + const int iH, const int iW, + const int paddingMode) { + if (paddingMode == 0) { // valid + oD = (iD - ((kD - 1) * dD + 1) + 2 * pD) / sD + 1; + oH = (iH - ((kH - 1) * dH + 1) + 2 * pH) / sH + 1; + oW = (iW - ((kW - 1) * dW + 1) + 2 * pW) / sW + 1; + } else if (paddingMode == 1) { // same + oD = (int)sd::math::nd4j_ceil(iD * 1. / sD); + oH = (int)sd::math::nd4j_ceil(iH * 1. / sH); + oW = (int)sd::math::nd4j_ceil(iW * 1. / sW); + + } else { // causal + oD = (iD - 1) / sD + 1; + oH = (iH - 1) / sH + 1; // 2*pH = (kH-1)*dH + oW = (iW - 1) / sW + 1; + } + } + + static inline void calcPadding2D( + int& pH, int& pW, int oH, int oW, int iH, int iW, int kH, int kW, int sH, + int sW, int dH, int dW, + const int paddingMode = 1 /* default is same mode*/) { + if (paddingMode == 0) // valid + return; + + if (paddingMode == 1) { // same + + const int eKH = (kH - 1) * dH + 1; + const int eKW = (kW - 1) * dW + 1; + + pH = ((oH - 1) * sH + eKH - iH) / + 2; // Note that padBottom is 1 bigger than this if bracketed term is + // not divisible by 2 + pW = ((oW - 1) * sW + eKW - iW) / 2; + } else { // causal + pH = (kH - 1) * dH; + pW = (kW - 1) * dW; + } + } + + static inline void calcPadding3D( + int& pD, int& pH, int& pW, const int oD, const int oH, const int oW, + const int iD, const int iH, const int iW, const int kD, const int kH, + const int kW, const int sD, const int sH, const int sW, const int dD, + const int dH, const int dW, + const int paddingMode = 1 /* default is same mode*/) { + if (paddingMode == 0) // valid + return; + + if (paddingMode == 1) { // same + + const int eKD = (kD - 1) * dD + 1; + const int eKH = (kH - 1) * dH + 1; + const int eKW = (kW - 1) * dW + 1; + + pD = ((oD - 1) * sD + eKD - iD) / 2; + pH = ((oH - 1) * sH + eKH - iH) / + 2; // Note that padBottom is 1 bigger than this if bracketed term is + // not divisible by 2 + pW = ((oW - 1) * sW + eKW - iW) / 2; + } else { // causal + pD = (kD - 1) * dD; + pH = (kH - 1) * dH; + pW = (kW - 1) * dW; + } + } + + // calculation of output height and width in 2D deconvolution procedure + static inline void calcOutSizeDeconv2D(int& oH, int& oW, const int kH, + const int kW, const int sH, + const int sW, const int pH, + const int pW, const int dH, + const int dW, const int iH, + const int iW, const int paddingMode) { + if (paddingMode) { + oH = sH * iH; + oW = sW * iW; + } else { + const int ekH = (kH - 1) * dH + 1; + const int ekW = (kW - 1) * dW + 1; + + oH = sH * (iH - 1) + ekH - 2 * pH; + oW = sW * (iW - 1) + ekW - 2 * pW; + } + } + + // calculation of output height and width in 3D deconvolution procedure + static inline void calcOutSizeDeconv3D( + int& oD, int& oH, int& oW, const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, const int dW, const int iD, + const int iH, const int iW, const int paddingMode) { + if (paddingMode) { + oD = sD * iD; + oH = sH * iH; + oW = sW * iW; + } else { + const int ekD = (kD - 1) * dD + 1; + const int ekH = (kH - 1) * dH + 1; + const int ekW = (kW - 1) * dW + 1; + + oD = sD * (iD - 1) + ekD - 2 * pD; + oH = sH * (iH - 1) + ekH - 2 * pH; + oW = sW * (iW - 1) + ekW - 2 * pW; + } + } + + // evaluates sizes values and indexes using input and output arrays depending + // on data format + static inline void getSizesAndIndexesConv2d( + const bool isNCHW, const int wFormat, const NDArray& input, + const NDArray& output, int& bS, int& iC, int& iH, int& iW, int& oC, + int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, + int& indWkH, int& indOoH) { + getSizesAndIndexesConv2d(isNCHW, wFormat, input.shapeInfo(), + output.shapeInfo(), bS, iC, iH, iW, oC, oH, oW, + indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + } + + static inline void getSizesAndIndexesConv2d( + const bool isNCHW, const int wFormat, const Nd4jLong* inShapeInfo, + const Nd4jLong* outShapeInfo, int& bS, int& iC, int& iH, int& iW, int& oC, + int& oH, int& oW, int& indIOioC, int& indIiH, int& indWiC, int& indWoC, + int& indWkH, int& indOoH) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC] (wFormat = 0), [oC, iC, kH, kW] (wFormat = 1), + // [oC, kH, kW, iC] (wFormat = 2) output [bS, oH, oW, oC] (NHWC) or [bS, + // oC, oH, oW] (NCHW) + + if (0 == wFormat) { + indWkH = 0; + indWiC = 2; + indWoC = 3; + } else if (1 == wFormat) { + indWkH = 2; + indWiC = 1; + indWoC = 0; + } else { + indWkH = 1; + indWiC = 3; + indWoC = 0; + } + + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + indOoH = 1; + } else { + indIOioC = 1; + indIiH = 2; + indOoH = 2; + } + + bS = inShapeInfo[1]; // batch size + iC = inShapeInfo[indIOioC + 1]; // input channels + iH = inShapeInfo[indIiH + 1]; // input height + iW = inShapeInfo[indIiH + 2]; // input width + oC = outShapeInfo[indIOioC + 1]; // output channels + oH = outShapeInfo[indOoH + 1]; // output height + oW = outShapeInfo[indOoH + 2]; // output width + } + + // evaluates sizes values and indexes using input and output arrays depending + // on data format + static inline void getSizesAndIndexesConv3d( + const bool isNCDHW, const int wFormat, const NDArray& input, + const NDArray& output, int& bS, int& iC, int& iD, int& iH, int& iW, + int& oC, int& oD, int& oH, int& oW, int& indIOioC, int& indIOioD, + int& indWiC, int& indWoC, int& indWkD) { + // input [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + // weights [kD, kH, kW, iC, oC] (wFormat = 0), [oC, iC, kD, kH, kW] (wFormat + // = 1), [oC, kD, kH, kW, iC] (wFormat = 2) output [bS, oD, oH, oW, oC] + // (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + if (0 == wFormat) { + indWkD = 0; + indWiC = 3; + indWoC = 4; + } else if (1 == wFormat) { + indWkD = 2; + indWiC = 1; + indWoC = 0; + } else { + indWkD = 1; + indWiC = 4; + indWoC = 0; + } + + if (!isNCDHW) { + indIOioC = 4; + indIOioD = 1; + } else { + indIOioC = 1; + indIOioD = 2; + } + + bS = input.sizeAt(0); // batch size + iC = input.sizeAt(indIOioC); // input channels + iD = input.sizeAt(indIOioD); // input depth + iH = input.sizeAt(indIOioD + 1); // input height + iW = input.sizeAt(indIOioD + 2); // input width + oC = output.sizeAt(indIOioC); // output channels + oD = output.sizeAt(indIOioD); // output depth + oH = output.sizeAt(indIOioD + 1); // output height + oW = output.sizeAt(indIOioD + 2); // output width + } + + // static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const + // int iW, const int oH, const int oW, const int kH, const int kW, const int + // sH, const int sW, const int paddingMode, int& pH, int& pW, int& dH, int& + // dW) { + + // if(kH != 1) { + // if(paddingMode) { + // pH = (oH - 1) * sH - iH + kH - pH; + // dH = dH - 1; + // } + // else + // dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1); + // } + // if(kW != 1) { + // if(paddingMode) { + // pW = (oW - 1) * sW - iW + kW - pW; + // dW = dW - 1; + // } + // else + // dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1); + // } + // } + + // static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const + // int iH, const int iW, const int oD, const int oH, const int oW, const int + // kD, const int kH, const int kW, const int sD, const int sH, const int sW, + // const int paddingMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& + // dW) { + + // if(kD != 1) { + // if(paddingMode) { + // pD = (oD - 1) * sD - iD + kD - pD; + // dD = dD - 1; + // } + // else + // dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1); + // } + // if(kH != 1) { + // if(paddingMode) { + // pH = (oH - 1) * sH - iH + kH - pH; + // dH = dH - 1; + // } + // else + // dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1); + // } + // if(kW != 1) { + // if(paddingMode) { + // pW = (oW - 1) * sW - iW + kW - pW; + // dW = dW - 1; + // } + // else + // dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1); + // } + // } + + static std::vector expectWeightsShape(const int wFormat, + const int kH, const int kW, + const int iC, const int oC) { + if (0 == wFormat) return std::vector({kH, kW, iC, oC}); + + if (1 == wFormat) return std::vector({oC, iC, kH, kW}); + + return std::vector({oC, kH, kW, iC}); + } + + static std::vector expectWeightsShape(const int wFormat, + const int kD, const int kH, + const int kW, const int iC, + const int oC) { + if (0 == wFormat) return std::vector({kD, kH, kW, iC, oC}); + + if (1 == wFormat) return std::vector({oC, iC, kD, kH, kW}); + + return std::vector({oC, kD, kH, kW, iC}); + } + + static void conv2d(sd::graph::Context& context, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat); + + // static void conv2d(sd::graph::Context & block, const std::vector& + // inArrs, NDArray* output, const std::vector& intArgs); + + // static void conv2dBP(sd::graph::Context & block, const + // std::vector& inArrs, const std::vector& outArrs, const + // std::vector& intArgs); + + static void conv2dBP(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + const NDArray* gradO, NDArray* gradI, NDArray* gradW, + NDArray* gradB, const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat); + + static void depthwiseConv2d(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat); + + static void depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + const NDArray* gradO, NDArray* gradI, + NDArray* gradW, NDArray* gradB, const int kH, + const int kW, const int sH, const int sW, + int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat); + + static void sconv2d(sd::graph::Context& block, const NDArray* input, + const NDArray* weightsDepth, const NDArray* weightsPoint, + const NDArray* bias, NDArray* output, const int kH, + const int kW, const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat); + + static void vol2col(sd::graph::Context& block, const NDArray& vol, + NDArray& col, const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, const int dD, + const int dH, const int dW); + + static void col2vol(sd::graph::Context& block, const NDArray& col, + NDArray& vol, const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, const int dD, + const int dH, const int dW); + + static void upsampling2d(sd::graph::Context& block, const NDArray& input, + NDArray& output, const int factorH, + const int factorW, const bool isNCHW); + + static void upsampling3d(sd::graph::Context& block, const NDArray& input, + NDArray& output, const int factorD, + const int factorH, const int factorW, + const bool isNCDHW); + + static void upsampling2dBP(sd::graph::Context& block, const NDArray& gradO, + NDArray& gradI, const bool isNCHW); + + static void upsampling3dBP(sd::graph::Context& block, const NDArray& gradO, + NDArray& gradI, const bool isNCDHW); + + static void pooling2d(sd::graph::Context& block, const NDArray& input, + NDArray& output, const int kH, const int kW, + const int sH, const int sW, const int pH, const int pW, + const int dH, const int dW, + const PoolingType poolingMode, const int extraParam0); + + static void pooling3d(sd::graph::Context& block, const NDArray& input, + NDArray& output, const int kD, const int kH, + const int kW, const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, const int dD, + const int dH, const int dW, const int poolingMode, + const int extraParam0); + + static void pooling2dBP(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int kH, + const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, + const int dW, const int poolingMode, + const int extraParam0); + + static void pooling3dBP(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int kD, + const int kH, const int kW, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW, const int poolingMode, + const int extraParam0); +}; + +} // namespace ops +} // namespace sd +#endif // LIBND4J_CONVOLUTIONS_H diff --git a/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp b/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp index 97bdd5c8963d..a25d8957fd54 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp @@ -18,40 +18,40 @@ // @author George A. Shulinok , created on 4/18/2019 // -#include #include +#include namespace sd { namespace ops { namespace helpers { - Nd4jLong barnes_row_count(const NDArray* rowP, const NDArray* colP, Nd4jLong N, NDArray& rowCounts) { - - int* pRowCounts = reinterpret_cast(rowCounts.buffer()); - int const* pRows = reinterpret_cast(rowP->buffer()); - int const* pCols = reinterpret_cast(colP->buffer()); - for (Nd4jLong n = 0; n < N; n++) { - int begin = pRows[n];//->e(n); - int end = pRows[n + 1];//rowP->e(n + 1); - for (int i = begin; i < end; i++) { - bool present = false; - for (int m = pRows[pCols[i]]; m < pRows[pCols[i] + 1]; m++) - if (pCols[m] == n) { - present = true; - break; - } - - ++pRowCounts[n]; - - if (!present) - ++pRowCounts[pCols[i]]; - } +Nd4jLong barnes_row_count(const NDArray* rowP, const NDArray* colP, Nd4jLong N, + NDArray& rowCounts) { + int* pRowCounts = reinterpret_cast(rowCounts.buffer()); + int const* pRows = reinterpret_cast(rowP->buffer()); + int const* pCols = reinterpret_cast(colP->buffer()); + for (Nd4jLong n = 0; n < N; n++) { + int begin = pRows[n]; //->e(n); + int end = pRows[n + 1]; // rowP->e(n + 1); + for (int i = begin; i < end; i++) { + bool present = false; + for (int m = pRows[pCols[i]]; m < pRows[pCols[i] + 1]; m++) + if (pCols[m] == n) { + present = true; + break; } - NDArray numElementsArr = rowCounts.sumNumber(); //reduceAlongDimension(reduce::Sum, {}); - //rowCounts.printBuffer("Row counts"); - auto numElements = numElementsArr.e(0); - return numElements; + + ++pRowCounts[n]; + + if (!present) ++pRowCounts[pCols[i]]; } + } + NDArray numElementsArr = + rowCounts.sumNumber(); // reduceAlongDimension(reduce::Sum, {}); + // rowCounts.printBuffer("Row counts"); + auto numElements = numElementsArr.e(0); + return numElements; +} // static // void printVector(std::vector const& v) { // for (auto x: v) { @@ -61,179 +61,216 @@ namespace helpers { // fflush(stdout); // } - template - static void barnes_symmetrize_(const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts) { - //auto N = rowP->lengthOf() - 1; /// 2 + rowP->lengthOf() % 2; - //auto numElements = output->lengthOf(); - //std::vector symRowP = rowCounts->asVectorT();//NDArrayFactory::create('c', {numElements}); - //NDArray symValP = NDArrayFactory::create('c', {numElements}); - //symRowP.insert(symRowP.begin(),0); - //symRowP(1, {0}) = *rowCounts; - int const* pRows = reinterpret_cast(rowP->buffer()); - int* symRowP = reinterpret_cast(outputRows->buffer()); - symRowP[0] = 0; - for (Nd4jLong n = 0; n < N; n++) - symRowP[n + 1] = symRowP[n] + rowCounts->e(n); -// outputRows->printBuffer("output rows"); - - int* symColP = reinterpret_cast(outputCols->buffer()); -// symRowP.p(n + 1, symRowP.e(n) + rowCounts.e(n)) -// outputRows->printBuffer("SymRows are"); - int const* pCols = reinterpret_cast(colP->buffer()); - T const* pVals = reinterpret_cast(valP->buffer()); - T* pOutput = reinterpret_cast(outputVals->buffer()); - //std::vector rowCountsV = rowCounts->getBufferAsVector(); - std::vector offset(N);// = NDArrayFactory::create('c', {N}); - -//PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(schedule(guided) shared(offset)) - for (Nd4jLong n = 0; n < N; n++) { - int begin = pRows[n]; - int bound = pRows[n + 1]; - - for (int i = begin; i < bound; i++) { - bool present = false; - int colPI = pCols[i]; - int start = pRows[colPI]; - int end = pRows[colPI + 1]; - - //PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) firstprivate(offset)) - for (int m = start; m < end; m++) { - if (pCols[m] == n) { - present = true; - if (n <= colPI) { - symColP[symRowP[n] + offset[n]] = colPI; - symColP[symRowP[colPI] + offset[colPI]] = n; - pOutput[symRowP[n] + offset[n]] = pVals[i] + pVals[m]; - pOutput[symRowP[colPI] + offset[colPI]] = pVals[i] + pVals[m]; - } - } - } - - // If (colP[i], n) is not present, there is no addition involved - if (!present) { - //int colPI = pCols[i]; - //if (n <= colPI) { - symColP[symRowP[n] + offset[n]] = colPI; - symColP[symRowP[pCols[i]] + offset[colPI]] = n; - pOutput[symRowP[n] + offset[n]] = pVals[i]; - pOutput[symRowP[colPI] + offset[colPI]] = pVals[i]; - //} - - } - // Update offsets - if (!present || (present && n <= colPI)) { - ++offset[n]; - - if (colPI != n) - ++offset[colPI]; - } -// printVector(offset); - } +template +static void barnes_symmetrize_(const NDArray* rowP, const NDArray* colP, + const NDArray* valP, Nd4jLong N, + NDArray* outputRows, NDArray* outputCols, + NDArray* outputVals, NDArray* rowCounts) { + // auto N = rowP->lengthOf() - 1; /// 2 + rowP->lengthOf() % 2; + // auto numElements = output->lengthOf(); + // std::vector symRowP = + // rowCounts->asVectorT();//NDArrayFactory::create('c', + // {numElements}); NDArray symValP = NDArrayFactory::create('c', + // {numElements}); symRowP.insert(symRowP.begin(),0); symRowP(1, {0}) = + // *rowCounts; + int const* pRows = reinterpret_cast(rowP->buffer()); + int* symRowP = reinterpret_cast(outputRows->buffer()); + symRowP[0] = 0; + for (Nd4jLong n = 0; n < N; n++) + symRowP[n + 1] = symRowP[n] + rowCounts->e(n); + // outputRows->printBuffer("output rows"); + + int* symColP = reinterpret_cast(outputCols->buffer()); + // symRowP.p(n + 1, symRowP.e(n) + rowCounts.e(n)) + // outputRows->printBuffer("SymRows are"); + int const* pCols = reinterpret_cast(colP->buffer()); + T const* pVals = reinterpret_cast(valP->buffer()); + T* pOutput = reinterpret_cast(outputVals->buffer()); + // std::vector rowCountsV = rowCounts->getBufferAsVector(); + std::vector offset(N); // = NDArrayFactory::create('c', {N}); + + // PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(schedule(guided) shared(offset)) + for (Nd4jLong n = 0; n < N; n++) { + int begin = pRows[n]; + int bound = pRows[n + 1]; + + for (int i = begin; i < bound; i++) { + bool present = false; + int colPI = pCols[i]; + int start = pRows[colPI]; + int end = pRows[colPI + 1]; + + // PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) firstprivate(offset)) + for (int m = start; m < end; m++) { + if (pCols[m] == n) { + present = true; + if (n <= colPI) { + symColP[symRowP[n] + offset[n]] = colPI; + symColP[symRowP[colPI] + offset[colPI]] = n; + pOutput[symRowP[n] + offset[n]] = pVals[i] + pVals[m]; + pOutput[symRowP[colPI] + offset[colPI]] = pVals[i] + pVals[m]; + } } + } + + // If (colP[i], n) is not present, there is no addition involved + if (!present) { + // int colPI = pCols[i]; + // if (n <= colPI) { + symColP[symRowP[n] + offset[n]] = colPI; + symColP[symRowP[pCols[i]] + offset[colPI]] = n; + pOutput[symRowP[n] + offset[n]] = pVals[i]; + pOutput[symRowP[colPI] + offset[colPI]] = pVals[i]; + //} + } + // Update offsets + if (!present || (present && n <= colPI)) { + ++offset[n]; + + if (colPI != n) ++offset[colPI]; + } + // printVector(offset); } - void barnes_symmetrize(const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts) { - - // Divide the result by two - BUILD_SINGLE_SELECTOR(valP->dataType(), barnes_symmetrize_, (rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCounts), NUMERIC_TYPES); - - *outputVals /= 2.0; - //output->assign(symValP); - } - BUILD_SINGLE_TEMPLATE(template void barnes_symmetrize_, (const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts), NUMERIC_TYPES); - - template - static void barnes_edge_forces_(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray const* data, NDArray* output) { - T const* dataP = reinterpret_cast(data->buffer()); - T const* vals = reinterpret_cast(valP->buffer()); - T* outputP = reinterpret_cast(output->buffer()); - int colCount = data->columns(); - - -// auto shift = 0; - auto rowSize = sizeof(T) * colCount; - - auto func = PRAGMA_THREADS_FOR { - for (auto n = start; n < stop; n++) { - int s = rowP->e(n); - int end = rowP->e(n + 1); - int shift = n * colCount; - for (int i = s; i < end; i++) { - T const *thisSlice = dataP + colP->e(i) * colCount; - T res = 1; - - for (int k = 0; k < colCount; k++) { - auto tempVal = dataP[shift + k] - thisSlice[k];//thisSlice[k]; - res += tempVal * tempVal; - } - - res = vals[i] / res; - for (int k = 0; k < colCount; k++) - outputP[shift + k] += ((dataP[shift + k] - thisSlice[k]) * res); - } - //shift += colCount; - } - }; - - samediff::Threads::parallel_tad(func, 0, N); - } - - void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray* output, NDArray const& data) { - // Loop over all edges in the graph - BUILD_SINGLE_SELECTOR(output->dataType(), barnes_edge_forces_, (rowP, colP, valP, N, &data, output), FLOAT_TYPES); - } - BUILD_SINGLE_TEMPLATE(template void barnes_edge_forces_, (const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray const* data, NDArray* output), FLOAT_TYPES); - - template - static void barnes_gains_(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output) { - // gains = gains.add(.2).muli(sign(yGrads)).neq(sign(yIncs)).castTo(Nd4j.defaultFloatingPointType()) - // .addi(gains.mul(0.8).muli(sign(yGrads)).neq(sign(yIncs))); - auto gainsInternal = LAMBDA_TTT(x, grad, eps) { -// return T((x + 2.) * sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps)) + T(x * 0.8 * sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps)); - //return T((x + 2.) * sd::math::nd4j_sign(grad) == sd::math::nd4j_sign(eps)) + T(x * 0.8 * sd::math::nd4j_sign(grad) == sd::math::nd4j_sign(eps)); - T res = sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps) ? x + T(.2) : x * T(.8); - if(res < .01) res = .01; - return res; - }; - - input->applyTriplewiseLambda(*gradX, *epsilon, gainsInternal, *output); - } - - void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output) { - // gains = gains.add(.2).muli(sign(yGrads)).neq(sign(yIncs)).castTo(Nd4j.defaultFloatingPointType()) - // .addi(gains.mul(0.8).muli(sign(yGrads)).neq(sign(yIncs))); - BUILD_SINGLE_SELECTOR(input->dataType(), barnes_gains_, (input, gradX, epsilon, output), NUMERIC_TYPES); -// auto signGradX = *gradX; -// auto signEpsilon = *epsilon; -// gradX->applyTransform(transform::Sign, &signGradX, nullptr); -// epsilon->applyTransform(transform::Sign, &signEpsilon, nullptr); -// auto leftPart = (*input + 2.) * signGradX; -// auto leftPartBool = NDArrayFactory::create(leftPart.ordering(), leftPart.getShapeAsVector()); -// -// leftPart.applyPairwiseTransform(pairwise::NotEqualTo, &signEpsilon, &leftPartBool, nullptr); -// auto rightPart = *input * 0.8 * signGradX; -// auto rightPartBool = NDArrayFactory::create(rightPart.ordering(), rightPart.getShapeAsVector()); -// rightPart.applyPairwiseTransform(pairwise::NotEqualTo, &signEpsilon, &rightPartBool, nullptr); -// leftPart.assign(leftPartBool); -// rightPart.assign(rightPartBool); -// leftPart.applyPairwiseTransform(pairwise::Add, &rightPart, output, nullptr); - - } - BUILD_SINGLE_TEMPLATE(template void barnes_gains_, (NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output), NUMERIC_TYPES); - - bool cell_contains(NDArray* corner, NDArray* width, NDArray* point, Nd4jLong dimension) { - auto cornerMinusWidth = *corner - *width; - auto cornerPlusWidth = *corner + *width; - - for (Nd4jLong i = 0; i < dimension; i++) { - if (cornerMinusWidth.e(i) > point->e(i)) - return false; - if (cornerPlusWidth.e(i) < point->e(i)) - return false; + } +} +void barnes_symmetrize(const NDArray* rowP, const NDArray* colP, + const NDArray* valP, Nd4jLong N, NDArray* outputRows, + NDArray* outputCols, NDArray* outputVals, + NDArray* rowCounts) { + // Divide the result by two + BUILD_SINGLE_SELECTOR( + valP->dataType(), barnes_symmetrize_, + (rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCounts), + NUMERIC_TYPES); + + *outputVals /= 2.0; + // output->assign(symValP); +} +BUILD_SINGLE_TEMPLATE(template void barnes_symmetrize_, + (const NDArray* rowP, const NDArray* colP, + const NDArray* valP, Nd4jLong N, NDArray* outputRows, + NDArray* outputCols, NDArray* outputVals, + NDArray* rowCounts), + NUMERIC_TYPES); + +template +static void barnes_edge_forces_(const NDArray* rowP, NDArray const* colP, + NDArray const* valP, int N, NDArray const* data, + NDArray* output) { + T const* dataP = reinterpret_cast(data->buffer()); + T const* vals = reinterpret_cast(valP->buffer()); + T* outputP = reinterpret_cast(output->buffer()); + int colCount = data->columns(); + + // auto shift = 0; + auto rowSize = sizeof(T) * colCount; + + auto func = PRAGMA_THREADS_FOR { + for (auto n = start; n < stop; n++) { + int s = rowP->e(n); + int end = rowP->e(n + 1); + int shift = n * colCount; + for (int i = s; i < end; i++) { + T const* thisSlice = dataP + colP->e(i) * colCount; + T res = 1; + + for (int k = 0; k < colCount; k++) { + auto tempVal = dataP[shift + k] - thisSlice[k]; // thisSlice[k]; + res += tempVal * tempVal; } - return true; + res = vals[i] / res; + for (int k = 0; k < colCount; k++) + outputP[shift + k] += ((dataP[shift + k] - thisSlice[k]) * res); + } + // shift += colCount; } + }; + + samediff::Threads::parallel_tad(func, 0, N); } + +void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, + NDArray const* valP, int N, NDArray* output, + NDArray const& data) { + // Loop over all edges in the graph + BUILD_SINGLE_SELECTOR(output->dataType(), barnes_edge_forces_, + (rowP, colP, valP, N, &data, output), FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template void barnes_edge_forces_, + (const NDArray* rowP, NDArray const* colP, + NDArray const* valP, int N, NDArray const* data, + NDArray* output), + FLOAT_TYPES); + +template +static void barnes_gains_(NDArray* input, NDArray* gradX, NDArray* epsilon, + NDArray* output) { + // gains = + // gains.add(.2).muli(sign(yGrads)).neq(sign(yIncs)).castTo(Nd4j.defaultFloatingPointType()) + // .addi(gains.mul(0.8).muli(sign(yGrads)).neq(sign(yIncs))); + auto gainsInternal = LAMBDA_TTT(x, grad, eps) { + // return T((x + 2.) * sd::math::nd4j_sign(grad) != + // sd::math::nd4j_sign(eps)) + T(x * 0.8 * + // sd::math::nd4j_sign(grad) != + // sd::math::nd4j_sign(eps)); + // return T((x + 2.) * sd::math::nd4j_sign(grad) == + // sd::math::nd4j_sign(eps)) + T(x * 0.8 * + // sd::math::nd4j_sign(grad) == sd::math::nd4j_sign(eps)); + T res = sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps) + ? x + T(.2) + : x * T(.8); + if (res < .01) res = .01; + return res; + }; + + input->applyTriplewiseLambda(*gradX, *epsilon, gainsInternal, *output); } +void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, + NDArray* output) { + // gains = + // gains.add(.2).muli(sign(yGrads)).neq(sign(yIncs)).castTo(Nd4j.defaultFloatingPointType()) + // .addi(gains.mul(0.8).muli(sign(yGrads)).neq(sign(yIncs))); + BUILD_SINGLE_SELECTOR(input->dataType(), barnes_gains_, + (input, gradX, epsilon, output), NUMERIC_TYPES); + // auto signGradX = *gradX; + // auto signEpsilon = *epsilon; + // gradX->applyTransform(transform::Sign, &signGradX, nullptr); + // epsilon->applyTransform(transform::Sign, &signEpsilon, nullptr); + // auto leftPart = (*input + 2.) * signGradX; + // auto leftPartBool = + // NDArrayFactory::create(leftPart.ordering(), + // leftPart.getShapeAsVector()); + // + // leftPart.applyPairwiseTransform(pairwise::NotEqualTo, &signEpsilon, + // &leftPartBool, nullptr); auto rightPart = *input * 0.8 * signGradX; + // auto rightPartBool = + // NDArrayFactory::create(rightPart.ordering(), + // rightPart.getShapeAsVector()); + // rightPart.applyPairwiseTransform(pairwise::NotEqualTo, &signEpsilon, + // &rightPartBool, nullptr); leftPart.assign(leftPartBool); + // rightPart.assign(rightPartBool); + // leftPart.applyPairwiseTransform(pairwise::Add, &rightPart, output, + // nullptr); +} +BUILD_SINGLE_TEMPLATE(template void barnes_gains_, + (NDArray * input, NDArray* gradX, NDArray* epsilon, + NDArray* output), + NUMERIC_TYPES); + +bool cell_contains(NDArray* corner, NDArray* width, NDArray* point, + Nd4jLong dimension) { + auto cornerMinusWidth = *corner - *width; + auto cornerPlusWidth = *corner + *width; + + for (Nd4jLong i = 0; i < dimension; i++) { + if (cornerMinusWidth.e(i) > point->e(i)) return false; + if (cornerPlusWidth.e(i) < point->e(i)) return false; + } + + return true; +} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp b/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp index ccc4d676aeb4..3e6cf25f1f70 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp @@ -19,226 +19,263 @@ // @author raver119@gmail.com // -#include +#include +#include #include +#include + #include -#include -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// - template - void static _softMaxDerivForVector(sd::LaunchContext * context, const void *input, const Nd4jLong *inShapeInfo, void *output) { - - const T* inBuff = reinterpret_cast(input); - T* outBuff = reinterpret_cast(output); - - T max = -DataTypeUtils::max(); - T sum = 0.; - int length = shape::length(inShapeInfo); - - for (int i = 0; i < length; i++) { - const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo); - max = sd::math::nd4j_max(max, inBuff[offset]); - } - - for (int i = 0; i < length; i++) { - const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo); - outBuff[offset] = sd::math::nd4j_exp(inBuff[offset] - max); - sum += outBuff[offset]; - } - - for (int i = 0; i < length; i++) { - const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo); - outBuff[offset] /= sum; - outBuff[offset] *= (1.f - outBuff[offset]); // derivative - } - } +template +void static _softMaxDerivForVector(sd::LaunchContext* context, + const void* input, + const Nd4jLong* inShapeInfo, void* output) { + const T* inBuff = reinterpret_cast(input); + T* outBuff = reinterpret_cast(output); + + T max = -DataTypeUtils::max(); + T sum = 0.; + int length = shape::length(inShapeInfo); + + for (int i = 0; i < length; i++) { + const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo); + max = sd::math::nd4j_max(max, inBuff[offset]); + } + + for (int i = 0; i < length; i++) { + const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo); + outBuff[offset] = sd::math::nd4j_exp(inBuff[offset] - max); + sum += outBuff[offset]; + } + + for (int i = 0; i < length; i++) { + const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo); + outBuff[offset] /= sum; + outBuff[offset] *= (1.f - outBuff[offset]); // derivative + } +} /////////////////////////////////////////////////////////////////// - void softmaxDerivative(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { - - const int rank = input.rankOf(); - int temp; - - if(shape::isCommonVector(input.shapeInfo(), temp)) { - - BUILD_SINGLE_SELECTOR(input.dataType(), _softMaxDerivForVector, (context, input.buffer(), input.shapeInfo(), output.buffer()), FLOAT_TYPES); - } - else { - auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); - output /= sumAlongDim; - output *= (1.f - output); // derivative - } - } +void softmaxDerivative(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimension) { + const int rank = input.rankOf(); + int temp; + + if (shape::isCommonVector(input.shapeInfo(), temp)) { + BUILD_SINGLE_SELECTOR( + input.dataType(), _softMaxDerivForVector, + (context, input.buffer(), input.shapeInfo(), output.buffer()), + FLOAT_TYPES); + } else { + auto maxAlongDim = const_cast(input).reduceAlongDimension( + reduce::Max, {dimension}, true); + (input - maxAlongDim) + .applyTransform(transform::Exp, + output); // output contains exponents temporarily + auto sumAlongDim = + output.reduceAlongDimension(reduce::Sum, {dimension}, true); + output /= sumAlongDim; + output *= (1.f - output); // derivative + } +} /////////////////////////////////////////////////////////////////// - template - void logSoftMaxForVector_(void const* input, Nd4jLong const* inShapeInfo, void *output, Nd4jLong const* outShapeInfo) { - auto inBuff = reinterpret_cast(input); - auto outBuff = reinterpret_cast(output); - - T max = -DataTypeUtils::max(); - T sum = 0; - - auto inEWS = shape::elementWiseStride(inShapeInfo); - auto length = shape::length(inShapeInfo); - - if (inEWS == 1) { - for (Nd4jLong i = 0; i < length; i++) - max = sd::math::nd4j_max(max, inBuff[i]); - - PRAGMA_OMP_SIMD_SUM(sum) - for (Nd4jLong i = 0; i < length; i++) { - outBuff[i] = sd::math::nd4j_exp(inBuff[i] - max); - sum += outBuff[i]; - } - - PRAGMA_OMP_SIMD - for (Nd4jLong i = 0; i < length; i++) { - outBuff[i] /= sum; - outBuff[i] = sd::math::nd4j_log(outBuff[i]); - } - } - else if (inEWS > 1) { - - PRAGMA_OMP_SIMD_MAX(max) - for (Nd4jLong i = 0; i < length; i++) - max = sd::math::nd4j_max(max, inBuff[i * inEWS]); - - PRAGMA_OMP_SIMD_SUM(sum) - for (Nd4jLong i = 0; i < length; i++) { - outBuff[i * inEWS] = sd::math::nd4j_exp(inBuff[i * inEWS] - max); - sum += outBuff[i * inEWS]; - } - - PRAGMA_OMP_SIMD - for (Nd4jLong i = 0; i < length; i++) { - outBuff[i * inEWS] /= sum; - outBuff[i * inEWS] = sd::math::nd4j_log(outBuff[i * inEWS]); - } - } +template +void logSoftMaxForVector_(void const* input, Nd4jLong const* inShapeInfo, + void* output, Nd4jLong const* outShapeInfo) { + auto inBuff = reinterpret_cast(input); + auto outBuff = reinterpret_cast(output); + + T max = -DataTypeUtils::max(); + T sum = 0; + + auto inEWS = shape::elementWiseStride(inShapeInfo); + auto length = shape::length(inShapeInfo); + + if (inEWS == 1) { + for (Nd4jLong i = 0; i < length; i++) + max = sd::math::nd4j_max(max, inBuff[i]); + + PRAGMA_OMP_SIMD_SUM(sum) + for (Nd4jLong i = 0; i < length; i++) { + outBuff[i] = sd::math::nd4j_exp(inBuff[i] - max); + sum += outBuff[i]; } - /////////////////////////////////////////////////////////////////// - void logSoftMaxForVector(sd::LaunchContext* context, const NDArray& input, NDArray& output) { - - if(!input.isVector() || !output.isVector()) - throw std::runtime_error("ops::helpers::logSoftMaxForVector function input and output arrays must be vectors !"); - - auto xType = input.dataType(); - BUILD_SINGLE_SELECTOR(xType, logSoftMaxForVector_, (input.buffer(), input.shapeInfo(), output.buffer(), output.shapeInfo()), FLOAT_TYPES); + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < length; i++) { + outBuff[i] /= sum; + outBuff[i] = sd::math::nd4j_log(outBuff[i]); + } + } else if (inEWS > 1) { + PRAGMA_OMP_SIMD_MAX(max) + for (Nd4jLong i = 0; i < length; i++) + max = sd::math::nd4j_max(max, inBuff[i * inEWS]); + + PRAGMA_OMP_SIMD_SUM(sum) + for (Nd4jLong i = 0; i < length; i++) { + outBuff[i * inEWS] = sd::math::nd4j_exp(inBuff[i * inEWS] - max); + sum += outBuff[i * inEWS]; } + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < length; i++) { + outBuff[i * inEWS] /= sum; + outBuff[i * inEWS] = sd::math::nd4j_log(outBuff[i * inEWS]); + } + } +} -////////////////////////////////////////////////////////////////////////// -void prelu(sd::LaunchContext * context, const NDArray& input, const NDArray& alpha, NDArray& output) { - const Nd4jLong inputLen = input.lengthOf(); - const Nd4jLong* inputShapeInfo = input.shapeInfo(); - const Nd4jLong* alphaShapeInfo = alpha.shapeInfo(); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - // FIXME: double! - double x = input.e(i); - if (x < 0.0) { - // FIXME: double - output.p(i, (x * alpha.e(shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo)))); - } else - output.p(i, x); - } - }; - - samediff::Threads::parallel_for(func, 0, inputLen); +/////////////////////////////////////////////////////////////////// +void logSoftMaxForVector(sd::LaunchContext* context, const NDArray& input, + NDArray& output) { + if (!input.isVector() || !output.isVector()) + throw std::runtime_error( + "ops::helpers::logSoftMaxForVector function input and output arrays " + "must be vectors !"); + + auto xType = input.dataType(); + BUILD_SINGLE_SELECTOR( + xType, logSoftMaxForVector_, + (input.buffer(), input.shapeInfo(), output.buffer(), output.shapeInfo()), + FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// -void preluBP(sd::LaunchContext * context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) { - - const Nd4jLong inputLen = input.lengthOf(); - const Nd4jLong* inputShapeInfo = input.shapeInfo(); - const Nd4jLong* alphaShapeInfo = alpha.shapeInfo(); - - dLdA.assign(0.0f); - - for(Nd4jLong i = 0; i < inputLen; ++i) { +void prelu(sd::LaunchContext* context, const NDArray& input, + const NDArray& alpha, NDArray& output) { + const Nd4jLong inputLen = input.lengthOf(); + const Nd4jLong* inputShapeInfo = input.shapeInfo(); + const Nd4jLong* alphaShapeInfo = alpha.shapeInfo(); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + // FIXME: double! + double x = input.e(i); + if (x < 0.0) { // FIXME: double - double x = input.e(i); - double grO = dLdO.e(i); - if(x < 0.0) { - Nd4jLong alphaInd = shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo); - dLdI.p(i, grO * alpha.e(alphaInd)); - double prevVal = dLdA.e(alphaInd); - prevVal += (grO * x); - dLdA.p(alphaInd, prevVal); - } - else - dLdI.p(i, grO); - } -} - - - bool checkAlphaShapeLen(std::vector const& expectedShape, Nd4jLong shapeLen) { - Nd4jLong expectedAlphaLen = std::accumulate(expectedShape.cbegin(), expectedShape.cend(), 1, std::multiplies()); - return expectedAlphaLen == shapeLen; - } - template - static void thresholdRelu_(NDArray const& input, double threshold, NDArray& output) { - auto routine = LAMBDA_T(_x, threshold) { - return _x > (T)threshold? _x: (T)0.f; - }; - const_cast(input).applyLambda(routine, output); - } - - void thresholdRelu(sd::LaunchContext * context, NDArray const& input, double threshold, NDArray& output) { - BUILD_SINGLE_SELECTOR(input.dataType(), thresholdRelu_, (input, threshold, output), FLOAT_TYPES); - } - - template - static void thresholdReluDerivative_(sd::LaunchContext * context, NDArray* input, double theta, NDArray* dLdO, NDArray* output) { - auto derivative = LAMBDA_TT(_x, grO, theta) {if (_x > theta) return grO; else return static_cast(0); }; - - input->applyPairwiseLambda(*dLdO, derivative, *output); - + output.p(i, (x * alpha.e(shape::subArrayIndex( + i, inputShapeInfo, alphaShapeInfo)))); + } else + output.p(i, x); } + }; - void thresholdReluDerivative(sd::LaunchContext * context, NDArray* input, double threshold, NDArray* dLdO, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), thresholdReluDerivative_, (context, input, threshold, dLdO, output), FLOAT_TYPES); - } - - /////////////////////////////////////////////////////////////////// - void logSoftmax(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { - - const int rank = input.rankOf(); - - if(input.isVector()) { + samediff::Threads::parallel_for(func, 0, inputLen); +} - if(rank == 1 || input.sizeAt(dimension) != 1) { - BUILD_SINGLE_SELECTOR(input.dataType(), logSoftMaxForVector_, (input.buffer(), input.shapeInfo(), output.buffer(), output.shapeInfo()), FLOAT_TYPES); - } - else - output = 0.; - } - else { +////////////////////////////////////////////////////////////////////////// +void preluBP(sd::LaunchContext* context, const NDArray& input, + const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, + NDArray& dLdA) { + const Nd4jLong inputLen = input.lengthOf(); + const Nd4jLong* inputShapeInfo = input.shapeInfo(); + const Nd4jLong* alphaShapeInfo = alpha.shapeInfo(); + + dLdA.assign(0.0f); + + for (Nd4jLong i = 0; i < inputLen; ++i) { + // FIXME: double + double x = input.e(i); + double grO = dLdO.e(i); + if (x < 0.0) { + Nd4jLong alphaInd = + shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo); + dLdI.p(i, grO * alpha.e(alphaInd)); + double prevVal = dLdA.e(alphaInd); + prevVal += (grO * x); + dLdA.p(alphaInd, prevVal); + } else + dLdI.p(i, grO); + } +} - auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); - output /= sumAlongDim; - output.applyTransform(transform::Log, output); - } - } +bool checkAlphaShapeLen(std::vector const& expectedShape, + Nd4jLong shapeLen) { + Nd4jLong expectedAlphaLen = + std::accumulate(expectedShape.cbegin(), expectedShape.cend(), 1, + std::multiplies()); + return expectedAlphaLen == shapeLen; +} +template +static void thresholdRelu_(NDArray const& input, double threshold, + NDArray& output) { + auto routine = LAMBDA_T(_x, threshold) { + return _x > (T)threshold ? _x : (T)0.f; + }; + const_cast(input).applyLambda(routine, output); +} - BUILD_SINGLE_TEMPLATE(template void thresholdReluDerivative_, (sd::LaunchContext * context, NDArray* input, double threshold, NDArray* dLdO, NDArray* output), FLOAT_TYPES); - BUILD_SINGLE_TEMPLATE(template void logSoftMaxForVector_, (void const* input, Nd4jLong const* inShapeInfo, void *output, Nd4jLong const* outShapeInfo), FLOAT_TYPES); - BUILD_SINGLE_TEMPLATE(template void _softMaxDerivForVector, (sd::LaunchContext * context, const void *input, const Nd4jLong *inShapeInfo, void *output), FLOAT_TYPES); +void thresholdRelu(sd::LaunchContext* context, NDArray const& input, + double threshold, NDArray& output) { + BUILD_SINGLE_SELECTOR(input.dataType(), thresholdRelu_, + (input, threshold, output), FLOAT_TYPES); +} +template +static void thresholdReluDerivative_(sd::LaunchContext* context, NDArray* input, + double theta, NDArray* dLdO, + NDArray* output) { + auto derivative = LAMBDA_TT(_x, grO, theta) { + if (_x > theta) + return grO; + else + return static_cast(0); + }; + + input->applyPairwiseLambda(*dLdO, derivative, *output); } + +void thresholdReluDerivative(sd::LaunchContext* context, NDArray* input, + double threshold, NDArray* dLdO, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), thresholdReluDerivative_, + (context, input, threshold, dLdO, output), FLOAT_TYPES); } + +/////////////////////////////////////////////////////////////////// +void logSoftmax(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimension) { + const int rank = input.rankOf(); + + if (input.isVector()) { + if (rank == 1 || input.sizeAt(dimension) != 1) { + BUILD_SINGLE_SELECTOR(input.dataType(), logSoftMaxForVector_, + (input.buffer(), input.shapeInfo(), output.buffer(), + output.shapeInfo()), + FLOAT_TYPES); + } else + output = 0.; + } else { + auto maxAlongDim = const_cast(input).reduceAlongDimension( + reduce::Max, {dimension}, true); + (input - maxAlongDim) + .applyTransform(transform::Exp, + output); // output contains exponents temporarily + auto sumAlongDim = + output.reduceAlongDimension(reduce::Sum, {dimension}, true); + output /= sumAlongDim; + output.applyTransform(transform::Log, output); + } } +BUILD_SINGLE_TEMPLATE(template void thresholdReluDerivative_, + (sd::LaunchContext * context, NDArray* input, + double threshold, NDArray* dLdO, NDArray* output), + FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void logSoftMaxForVector_, + (void const* input, Nd4jLong const* inShapeInfo, + void* output, Nd4jLong const* outShapeInfo), + FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void _softMaxDerivForVector, + (sd::LaunchContext * context, const void* input, + const Nd4jLong* inShapeInfo, void* output), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp b/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp index a03b4504f3df..70ab24f2bc34 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/addBias.cpp @@ -15,634 +15,699 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Yurii Shyrma, created on 26.02.2018 - // - // - // @author AbdelRauf - // +// +// @author Yurii Shyrma, created on 26.02.2018 +// +// +// @author AbdelRauf +// -#include -#include -#include -#include -#include #include +#include #include #include -#if defined(__GNUC__) +#include +#include +#include +#include + +#if defined(__GNUC__) #define align32 __attribute__((aligned(32))) #elif defined(_MSC_VER) #define align32 __declspec(align(32)) #else -#define align32 -#endif +#define align32 +#endif namespace sd { - namespace ops { - namespace helpers { - - template - static FORCEINLINE void _add(const T* __restrict xx, const T* __restrict yy, T* __restrict zz, const size_t& N) { - PRAGMA_OMP_SIMD - for (size_t c = 0; c < N; c++) - zz[c] = xx[c] + yy[c]; - } - - template - static FORCEINLINE void _add_inplace(T* __restrict xx, const T* __restrict yy, const size_t& N) { - PRAGMA_OMP_SIMD - for (size_t c = 0; c < N; c++) - xx[c] = xx[c] + yy[c]; - } - - template - static FORCEINLINE void _add_broadcast_inplace(T* __restrict xx, const T yy, const size_t& N) { - PRAGMA_OMP_SIMD - for (size_t c = 0; c < N; c++) - xx[c] = xx[c] + yy; - } - - template - static FORCEINLINE void _add_broadcast(const T* __restrict xx, const T yy, T* __restrict zz, const size_t& N) { - PRAGMA_OMP_SIMD - for (size_t c = 0; c < N; c++) - zz[c] = xx[c] + yy; - } - - static constexpr size_t MIN_NN = 32; - static constexpr size_t MIN_NN_K = 2; - - template - static typename std::enable_if::value, const X*>::type - flattened_bias(const Y* b_real, X* b_stack, const size_t b_stack_size, std::unique_ptr& b_heap, const Nd4jLong num, Nd4jLong yStrideC) - { - //best results when buffer used much , may result bad perf if buffer is used once - X* b_new = nullptr; - if (yStrideC != 1) { - if (num > b_stack_size) { - b_heap.reset(new X[num]); - b_new = b_heap.get(); - } - else { - b_new = b_stack; - } - for (size_t i = 0; i < num; i++) { - b_new[i] = b_real[i * yStrideC]; - } - } - else { - //no need , just pass normal bias - return static_cast(b_real); - } - return const_cast(b_new); - } - - template - static typename std::enable_if::value, const X*>::type - flattened_bias(const Y* b_real, X* b_stack, const size_t b_stack_size, std::unique_ptr& b_heap, const Nd4jLong num, Nd4jLong yStrideC) - { - //best results when buffer used much , may result bad perf if buffer is used once - X* b_new = nullptr; - if (num > b_stack_size) { - b_heap.reset(new X[num]); - b_new = b_heap.get(); - } - else { - b_new = b_stack; - } - if (yStrideC != 1) { - for (size_t i = 0; i < num; i++) { - b_new[i] = static_cast(b_real[i * yStrideC]); - } - } - else { - for (size_t i = 0; i < num; i++) { - b_new[i] = static_cast(b_real[i]); - } - } - return const_cast(b_new); - } - - template - static void channel_atTheEnd_stride1_C(const Nd4jLong*& x_strides, const Nd4jLong*& bases, T* x, const T* b, T* z, const bool& inplace, const Nd4jLong& start, const Nd4jLong& stop, const Nd4jLong& inc) - { - size_t loop_count = (stop - start) / inc; - sd::CoordsState cst; - size_t offset = sd::init_coords(cst, start, bases, x_strides); - - if (!inplace) { - for (size_t i = 0; i < loop_count; i++) { - _add(&(x[offset]), b, &(z[offset]), inc); - offset = sd::inc_coords(cst, offset); - } - } - else { - for (size_t i = 0; i < loop_count; i++) { - _add_inplace(&(x[offset]), b, inc); - offset = sd::inc_coords(cst, offset); - } - } - } - - - template - static void channel_atTheEnd_generic_C(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, const bool& inplaceOp, const bool same_stride, const bool same_order, T* x, const T* b, T* z, Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { - - //just ensure that passed sameStride is correct, because when bases are equal orders matters - bool sameOrderStride = same_order && same_stride; - if (sameOrderStride && x_strides[constRank - 1] == 1) { - channel_atTheEnd_stride1_C(x_strides, bases, x, b, z, inplaceOp, start, stop, inc); - } - else { - size_t loop_count = (stop - start) / inc; - sd::ZipCoordsState cst; - sd::zip_size_t offset = sd::init_coords(cst, start, bases, x_strides, z_strides); - Nd4jLong x_stride = ZIP_STRIDE1(cst, constRank - 1); - Nd4jLong z_stride = ZIP_STRIDE2(cst, constRank - 1); - - if (same_order && x_stride == 1 && z_stride == 1) { - /* bases are equal with different strides , but the last one is 1. So we can still vectorize it */ - for (size_t i = 0; i < loop_count; i++) { - _add(&(x[offset.first]), b, &(z[offset.second]), inc); - offset = sd::inc_coords(cst, offset); - } - } - else { - for (size_t i = 0; i < loop_count; i++) { - T* xx = &(x[offset.first]); - T* zz = &(z[offset.second]); - for (size_t j = 0; j < inc; j++) - zz[j * z_stride] = xx[j * x_stride] + b[j]; - offset = sd::inc_coords(cst, offset); - } - } - } - - } - - /** - * this is our main optimization which benefits from everything for the continuous last_channel C order case - * as it is intended for full continous we do not need any rank info - */ - template - void channel_atTheEnd_continous_C(T* x, const T* b, T* z, bool inplaceOp, Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { - size_t nums = (stop - start); - size_t num_inc = nums - nums % inc; - if (inplaceOp) { - - size_t offset_p = start; - for (size_t i = 0; i < num_inc; i += inc) { - _add_inplace(&(x[offset_p]), b, inc); - offset_p += inc; - } - if (nums > num_inc) - _add_inplace(&(x[offset_p]), b, nums - num_inc); - } - else { - size_t offset_p = start; - for (size_t i = 0; i < num_inc; i += inc) { - _add(&(x[offset_p]), b, &(z[offset_p]), inc); - offset_p += inc; - } - if (nums > num_inc) - _add(&(x[offset_p]), b, &(z[offset_p]), nums - num_inc); - } - } - - template - static void channel_NC_stride1_C(const Nd4jLong*& x_strides, const Nd4jLong*& bases, T* x, const T2* b, T* z, const bool& inplace, const Nd4jLong yStrideC, const Nd4jLong& start, const Nd4jLong& stop, const Nd4jLong& inc) - { - size_t loop_count = (stop - start) / inc; - sd::CoordsState cst; - size_t offset = sd::init_coords(cst, start, bases, x_strides); - - if (!inplace) { - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); - _add_broadcast(&(x[offset]), yy, &(z[offset]), inc); - offset = sd::inc_coords(cst, offset); - } - } - else { - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); - _add_broadcast_inplace(&(x[offset]), yy, inc); - offset = sd::inc_coords(cst, offset); - } - } - } - - template - void channel_NC_generic_C(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, const bool& inplaceOp, const bool same_stride, const bool same_order, const Nd4jLong yStrideC, T* x, const T2* b, T* z, Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { - - //just ensure that passed sameStride is correct, because when bases are equal orders matters - - bool sameOrderStride = same_order && same_stride; - - if (sameOrderStride && x_strides[constRank - 1] == 1) { - channel_NC_stride1_C(x_strides, bases, x, b, z, inplaceOp, yStrideC, start, stop, inc); - } - else { - - // (stop-start) % inc == 0 because we handled inside partitioning using the channel size - size_t loop_count = (stop - start) / inc; - sd::ZipCoordsState cst; - sd::zip_size_t offset = sd::init_coords(cst, start, bases, x_strides, z_strides); - Nd4jLong x_stride = ZIP_STRIDE1(cst, constRank - 1); - Nd4jLong z_stride = ZIP_STRIDE2(cst, constRank - 1); - if (same_order && z_stride == 1 && x_stride == 1) { - /* bases are equal with different strides , but the last one is 1. So we can still vectorize it */ - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[ZIP_COORDS(cst, 1) * yStrideC]); - _add_broadcast(&(x[offset.first]), yy, &(z[offset.second]), inc); - offset = sd::inc_coords(cst, offset); - } - } - else { - for (size_t i = 0; i < loop_count; i++) { - T* xx = &(x[offset.first]); - T* zz = &(z[offset.second]); - T yy = static_cast(b[ZIP_COORDS(cst, 1) * yStrideC]); - for (size_t j = 0; j < inc; j++) - zz[j * z_stride] = xx[j * x_stride] + yy; - offset = sd::inc_coords(cst, offset); - } - } - } - } - - /// - template - void channel_NC_continous_numHW_C(Nd4jLong rank, const Nd4jLong* bases, const Nd4jLong* x_strides, T* x, const T2* b, T* z, bool inplaceOp, const Nd4jLong yStrideC, Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { - - // (stop-start) % inc == 0 because we handled inside partitioning using the channel size - size_t loop_count = (stop - start) / inc; - - sd::CoordsState<1> cst; - //note: we had to manually pass index - size_t offset_p = sd::init_coords<2>(cst, start / inc, bases, x_strides); - - //partitioning was done using numHW, so we can increment from rank 2 - if (inplaceOp) { - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); - _add_broadcast_inplace(&(x[offset_p]), yy, inc); - offset_p = sd::inc_coords<2>(cst, offset_p); - } - } - else { - if (yStrideC == 1) { - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[COORDS(cst, 1)]); - _add_broadcast(&(x[offset_p]), yy, &(z[offset_p]), inc); - offset_p = sd::inc_coords<2>(cst, offset_p); - } - } - else { - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); - _add_broadcast(&(x[offset_p]), yy, &(z[offset_p]), inc); - offset_p = sd::inc_coords<2>(cst, offset_p); - } - } - } - } - - // - template - static void channel_generic_stride_skip_F(const Nd4jLong*& x_strides, const Nd4jLong*& bases, T* x, const T2* b, T* z, const bool& inplace, const Nd4jLong yStrideC, const Nd4jLong& start, const Nd4jLong& stop, const Nd4jLong& inc) - { - // (stop-start) % inc == 0 because we handled inside partitioning using the channel size - size_t loop_count = (stop - start) / inc; - sd::CoordsState cst; - size_t offset_p = sd::init_coords(cst, start, bases, x_strides); - if (!inplace) { - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[COORDS(cst, b_index) * yStrideC]); - _add_broadcast(&(x[offset_p]), yy, &(z[offset_p]), inc); - offset_p = sd::inc_coords(cst, offset_p); - } - } - else { - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[COORDS(cst, b_index) * yStrideC]); - _add_broadcast_inplace(&(x[offset_p]), yy, inc); - offset_p = sd::inc_coords(cst, offset_p); - } - } - } - - /// - template - void channel_generic_F(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, const bool& inplaceOp, const bool same_stride, const bool same_order, const Nd4jLong yStrideC, T* x, const T2* b, T* z, Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { - //just ensure that passed sameStride is correct, because when bases are equal orders matters - bool sameOrderStride = same_order && same_stride; - if (sameOrderStride && x_strides[0] == 1) { - channel_generic_stride_skip_F(x_strides, bases, x, b, z, inplaceOp, yStrideC, start, stop, inc); - } - else { - // (stop-start) % inc == 0 because we handled inside partitioning using the channel size - - size_t loop_count = (stop - start) / inc; - sd::ZipCoordsState cst; - sd::zip_size_t offset = sd::init_coords(cst, start, bases, x_strides, z_strides); - Nd4jLong x_stride = ZIP_STRIDE1(cst, 0); - Nd4jLong z_stride = ZIP_STRIDE2(cst, 0); - if (same_order && z_stride == 1 && x_stride == 1) { - - for (size_t i = 0; i < loop_count; i++) { - T yy = static_cast(b[ZIP_COORDS(cst, b_index) * yStrideC]); - _add_broadcast(&(x[offset.first]), yy, &(z[offset.second]), inc); - offset = sd::inc_coords(cst, offset); - } - } - else { - for (size_t i = 0; i < loop_count; i++) { - T* xx = &(x[offset.first]); - T* zz = &(z[offset.second]); - T yy = static_cast(b[ZIP_COORDS(cst, b_index) * yStrideC]); - for (size_t j = 0; j < inc; j++) - zz[j * z_stride] = xx[j * x_stride] + yy; - offset = sd::inc_coords(cst, offset); - } - } - } - } - - - template - static void addBias_(const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW) { - /* - if (input.rankOf() == 2 && bias.rankOf() == 1 && input.sizeAt(1) == bias.sizeAt(0) && input.ordering() == 'c') { - int rows = input.sizeAt(0); - int biasLen = bias.lengthOf(); - - auto inB = input.bufferAsT(); - auto bB = bias.bufferAsT(); - auto outB = output.bufferAsT(); - - for (int e = 0; e < rows; e++) { - auto row = inB + (e * biasLen); - auto out = outB + (e * biasLen); - - for (int t = 0; t < biasLen; t++) { - out[t] = row[t] + bB[t]; - } - } - - return; - } - */ - - auto x_shapeInfo = input.shapeInfo(); - auto z_shapeInfo = output.shapeInfo(); - auto x = input.bufferAsT(); - auto z = output.bufferAsT(); - auto b = bias.bufferAsT(); - const Nd4jLong rank = x_shapeInfo[0]; - auto bases = &(x_shapeInfo[1]); - auto x_strides = &(x_shapeInfo[rank + 1]); - auto z_strides = &(z_shapeInfo[rank + 1]); - const bool inplaceOp = (x == z); - const bool same_order = inplaceOp || (input.ordering() == output.ordering()); - const bool channel_atTheEnd = !isNCHW; - const bool same_stride = inplaceOp || shape::strideEquals(x_shapeInfo, z_shapeInfo); - bool isContinuous = false; - int posOfNonUnityDim; - bias.isCommonVector(posOfNonUnityDim); - const Nd4jLong yStrideC = bias.strideAt(posOfNonUnityDim); - char order = input.ordering(); - - //for rank>5 - if (rank > 5) { - const int channelDim = isNCHW ? 1 : input.rankOf() - 1; // second or last - const_cast(input).applyBroadcast(sd::broadcast::Add, { channelDim }, bias, output); - return; - } - - if (same_order && same_stride) { - isContinuous = shape::elementWiseStride(x_shapeInfo) == 1 && shape::elementWiseStride(z_shapeInfo) == 1; - // check_continuity(order, bases, x_strides, rank); - }//if ( sameOrder && same_stride) - - bool treat_as_lastC = false; - // - if (rank == 2 && isNCHW) { - //we believe we better treat it as channel at the end case; - treat_as_lastC = true; - } - if (channel_atTheEnd || treat_as_lastC) { - //N..HWC case here - //flattened bias variables - constexpr size_t BSIZE1 = 3 * MIN_NN * MIN_NN; - constexpr size_t BSIZE2 = BSIZE1 + MIN_NN * MIN_NN; - X flatBias_stack[BSIZE2] align32; - std::unique_ptr flatBias_heap; - const X* bias_new; - X* bias_extra = nullptr; - size_t total_num = 1; - for (Nd4jLong i = 0; i < rank; i++) { - total_num *= bases[i]; - } - Nd4jLong inc; - size_t rank_skip = 1; - if (order == 'c') { - size_t b_stack_size = BSIZE2; - inc = bases[rank - 1]; - if (isContinuous) { - //for continous we need extra stack memory - // to create vectorizable bias from small size - b_stack_size = BSIZE1; - bias_extra = &(flatBias_stack[BSIZE1]); - } - bias_new = flattened_bias(b, (X*)flatBias_stack, b_stack_size, flatBias_heap, inc, yStrideC); - if (isContinuous && inc < MIN_NN_K * MIN_NN && total_num > inc * MIN_NN_K) { - //for small size where total_num is sufficient we need to recreate vectorizable buffer - size_t old_inc = inc; - //sizeof bias_extra is MIN_NN * MIN_NN - size_t new_inc = inc < MIN_NN ? inc * MIN_NN : inc * MIN_NN / MIN_NN_K; - //if there is a room then lets multiply - new_inc = (new_inc * MIN_NN_K <= total_num && new_inc < MIN_NN * MIN_NN / MIN_NN_K) ? MIN_NN_K * new_inc : new_inc; - for (size_t i = 0; i < new_inc; i += inc) { - //copy to our buffer - X* cp = &(bias_extra[i]); - for (size_t j = 0; j < inc; j++) { - cp[j] = bias_new[j]; - } - } - //vectorizable buffer - inc = new_inc; - bias_new = bias_extra; - } - } - else { - inc = bases[0]; - if (isContinuous) { - //we can choose other inc and index for that case - //but for now lets choose all till the last one - uint32_t req_numThreads = sd::Environment::getInstance()->maxMasterThreads(); - isContinuous = false; - if (rank > 2) { - if (req_numThreads < 2 || bases[rank - 1] >= req_numThreads) { - inc = total_num / bases[rank - 1]; - isContinuous = true; - rank_skip = rank - 1; - } - else if (rank > 3 && bases[rank - 1] * bases[rank - 2] >= req_numThreads) { - inc = total_num / bases[rank - 1] / bases[rank - 2]; //for continuous case it is its stride - rank_skip = rank - 2; - isContinuous = true; - } - } - } - } - - FUNC_1D func = [order, isContinuous, rank, x, b, bias_new, z, x_shapeInfo, z_shapeInfo, same_stride, same_order, yStrideC, rank_skip] - (uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { - const Nd4jLong rank = x_shapeInfo[0]; - auto bases = &(x_shapeInfo[1]); - auto x_strides = &(x_shapeInfo[rank + 1]); - auto z_strides = &(z_shapeInfo[rank + 1]); - const bool inplaceOp = (x == z); - if (order == 'c') { - if (isContinuous) { - channel_atTheEnd_continous_C(const_cast(x), bias_new, z, inplaceOp, start, stop, increment); - } - // rank is in [2,5] - else if (rank == 4) { - channel_atTheEnd_generic_C(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, const_cast(x), bias_new, z, start, stop, increment); - - } - else if (rank == 5) { - channel_atTheEnd_generic_C(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, const_cast(x), bias_new, z, start, stop, increment); - } - else if (rank == 2) { - channel_atTheEnd_generic_C(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, const_cast(x), bias_new, z, start, stop, increment); - } - else if (rank == 3) { - channel_atTheEnd_generic_C(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, const_cast(x), bias_new, z, start, stop, increment); - } - } - else { - //generic F case - if (isContinuous) { - if (rank == 4) { - if (rank_skip == rank - 2) { - channel_generic_stride_skip_F(x_strides, bases, const_cast(x), b, z, inplaceOp, yStrideC, start, stop, increment); - } - else { - channel_generic_stride_skip_F(x_strides, bases, const_cast(x), b, z, inplaceOp, yStrideC, start, stop, increment); - } - } - else if (rank == 5) { - if (rank_skip == rank - 2) { - //skip==3 - channel_generic_stride_skip_F(x_strides, bases, const_cast(x), b, z, inplaceOp, yStrideC, start, stop, increment); - } - else { - channel_generic_stride_skip_F(x_strides, bases, const_cast(x), b, z, inplaceOp, yStrideC, start, stop, increment); - } - } - else if (rank == 3) { - channel_generic_stride_skip_F(x_strides, bases, const_cast(x), b, z, inplaceOp, yStrideC, start, stop, increment); - } - } - else if (rank == 4) { - channel_generic_F(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - else if (rank == 5) { - channel_generic_F(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - else if (rank == 2) { - channel_generic_F(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - else if (rank == 3) { - channel_generic_F(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - - } - }; - // - samediff::Threads::parallel_aligned_increment(func, 0, total_num, inc); - } - else { - //NC...HW case here - size_t numNC = 1; - size_t numHW = 1; - for (size_t i = 0; i < 2; i++) { - numNC *= bases[i]; - } - for (Nd4jLong i = 2; i < rank; i++) { - numHW *= bases[i]; - } - Nd4jLong total_num = numNC * numHW; - Nd4jLong inc = (order == 'c') ? bases[rank - 1] : bases[0]; - if (order == 'c' && isContinuous) { - //sometimes last dimension is too big and multithreading could suffer using unfair partitioning - //so we will do it only when inc is smaller our value or multithreading turned off - uint32_t req_numThreads = sd::Environment::getInstance()->maxMasterThreads(); - if (req_numThreads < 2 || numNC >= req_numThreads || inc <= 2 * 8196 || rank == 3) { - inc = numHW; - } - else { - //treat it as stride1c case - isContinuous = false; - } - } - FUNC_1D func = [order, isContinuous, rank, x, b, z, x_shapeInfo, z_shapeInfo, same_stride, same_order, yStrideC] - (uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { - const Nd4jLong rank = x_shapeInfo[0]; - const Nd4jLong* bases = &(x_shapeInfo[1]); - const Nd4jLong* x_strides = &(x_shapeInfo[rank + 1]); - const Nd4jLong* z_strides = &(z_shapeInfo[rank + 1]); - const bool inplaceOp = (x == z); - if (order == 'c') { - if (isContinuous) { - channel_NC_continous_numHW_C(rank, bases, x_strides, const_cast(x), b, z, inplaceOp, yStrideC, start, stop, increment); - } - // rank is in [3,5] - else if (rank == 4) { - channel_NC_generic_C(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - - } - else if (rank == 5) { - channel_NC_generic_C(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - else if (rank == 3) { - channel_NC_generic_C(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - } - else { - //the same can be applied for NCHW case - //generic F case - //continous case is missing - - if (rank == 4) { - channel_generic_F(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - else if (rank == 5) { - channel_generic_F(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - else if (rank == 3) { - channel_generic_F(bases, x_strides, z_strides, inplaceOp, same_stride, same_order, yStrideC, const_cast(x), b, z, start, stop, increment); - } - } - }; - // - samediff::Threads::parallel_aligned_increment(func, 0, total_num, inc); - } - } - ////////////////////////////////////////////////////////////////////////// - void addBias(sd::graph::Context& block, const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW) { - - // bias.rankOf() == 1 ? bias : bias.reshape(bias.ordering(), {bias.lengthOf()}) - BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBias_, (input, bias, output, isNCHW), FLOAT_TYPES, FLOAT_TYPES); - } - - - BUILD_DOUBLE_TEMPLATE(template void addBias_, (const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW), FLOAT_TYPES, FLOAT_TYPES); - } - } +namespace ops { +namespace helpers { + +template +static FORCEINLINE void _add(const T* __restrict xx, const T* __restrict yy, + T* __restrict zz, const size_t& N) { + PRAGMA_OMP_SIMD + for (size_t c = 0; c < N; c++) zz[c] = xx[c] + yy[c]; +} + +template +static FORCEINLINE void _add_inplace(T* __restrict xx, const T* __restrict yy, + const size_t& N) { + PRAGMA_OMP_SIMD + for (size_t c = 0; c < N; c++) xx[c] = xx[c] + yy[c]; +} + +template +static FORCEINLINE void _add_broadcast_inplace(T* __restrict xx, const T yy, + const size_t& N) { + PRAGMA_OMP_SIMD + for (size_t c = 0; c < N; c++) xx[c] = xx[c] + yy; +} + +template +static FORCEINLINE void _add_broadcast(const T* __restrict xx, const T yy, + T* __restrict zz, const size_t& N) { + PRAGMA_OMP_SIMD + for (size_t c = 0; c < N; c++) zz[c] = xx[c] + yy; +} + +static constexpr size_t MIN_NN = 32; +static constexpr size_t MIN_NN_K = 2; + +template +static typename std::enable_if::value, const X*>::type +flattened_bias(const Y* b_real, X* b_stack, const size_t b_stack_size, + std::unique_ptr& b_heap, const Nd4jLong num, + Nd4jLong yStrideC) { + // best results when buffer used much , may result bad perf if buffer is used + // once + X* b_new = nullptr; + if (yStrideC != 1) { + if (num > b_stack_size) { + b_heap.reset(new X[num]); + b_new = b_heap.get(); + } else { + b_new = b_stack; + } + for (size_t i = 0; i < num; i++) { + b_new[i] = b_real[i * yStrideC]; + } + } else { + // no need , just pass normal bias + return static_cast(b_real); + } + return const_cast(b_new); +} + +template +static typename std::enable_if::value, const X*>::type +flattened_bias(const Y* b_real, X* b_stack, const size_t b_stack_size, + std::unique_ptr& b_heap, const Nd4jLong num, + Nd4jLong yStrideC) { + // best results when buffer used much , may result bad perf if buffer is used + // once + X* b_new = nullptr; + if (num > b_stack_size) { + b_heap.reset(new X[num]); + b_new = b_heap.get(); + } else { + b_new = b_stack; + } + if (yStrideC != 1) { + for (size_t i = 0; i < num; i++) { + b_new[i] = static_cast(b_real[i * yStrideC]); + } + } else { + for (size_t i = 0; i < num; i++) { + b_new[i] = static_cast(b_real[i]); + } + } + return const_cast(b_new); } + +template +static void channel_atTheEnd_stride1_C(const Nd4jLong*& x_strides, + const Nd4jLong*& bases, T* x, const T* b, + T* z, const bool& inplace, + const Nd4jLong& start, + const Nd4jLong& stop, + const Nd4jLong& inc) { + size_t loop_count = (stop - start) / inc; + sd::CoordsState cst; + size_t offset = sd::init_coords(cst, start, bases, x_strides); + + if (!inplace) { + for (size_t i = 0; i < loop_count; i++) { + _add(&(x[offset]), b, &(z[offset]), inc); + offset = sd::inc_coords(cst, offset); + } + } else { + for (size_t i = 0; i < loop_count; i++) { + _add_inplace(&(x[offset]), b, inc); + offset = sd::inc_coords(cst, offset); + } + } +} + +template +static void channel_atTheEnd_generic_C( + const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* z_strides, + const bool& inplaceOp, const bool same_stride, const bool same_order, T* x, + const T* b, T* z, Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { + // just ensure that passed sameStride is correct, because when bases are + // equal orders matters + bool sameOrderStride = same_order && same_stride; + if (sameOrderStride && x_strides[constRank - 1] == 1) { + channel_atTheEnd_stride1_C(x_strides, bases, x, b, z, + inplaceOp, start, stop, inc); + } else { + size_t loop_count = (stop - start) / inc; + sd::ZipCoordsState cst; + sd::zip_size_t offset = + sd::init_coords(cst, start, bases, x_strides, z_strides); + Nd4jLong x_stride = ZIP_STRIDE1(cst, constRank - 1); + Nd4jLong z_stride = ZIP_STRIDE2(cst, constRank - 1); + + if (same_order && x_stride == 1 && z_stride == 1) { + /* bases are equal with different strides , but the last one is 1. So we + * can still vectorize it */ + for (size_t i = 0; i < loop_count; i++) { + _add(&(x[offset.first]), b, &(z[offset.second]), inc); + offset = sd::inc_coords(cst, offset); + } + } else { + for (size_t i = 0; i < loop_count; i++) { + T* xx = &(x[offset.first]); + T* zz = &(z[offset.second]); + for (size_t j = 0; j < inc; j++) + zz[j * z_stride] = xx[j * x_stride] + b[j]; + offset = sd::inc_coords(cst, offset); + } + } + } +} + +/** + * this is our main optimization which benefits from everything for the + * continuous last_channel C order case as it is intended for full continous we + * do not need any rank info + */ +template +void channel_atTheEnd_continous_C(T* x, const T* b, T* z, bool inplaceOp, + Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { + size_t nums = (stop - start); + size_t num_inc = nums - nums % inc; + if (inplaceOp) { + size_t offset_p = start; + for (size_t i = 0; i < num_inc; i += inc) { + _add_inplace(&(x[offset_p]), b, inc); + offset_p += inc; + } + if (nums > num_inc) _add_inplace(&(x[offset_p]), b, nums - num_inc); + } else { + size_t offset_p = start; + for (size_t i = 0; i < num_inc; i += inc) { + _add(&(x[offset_p]), b, &(z[offset_p]), inc); + offset_p += inc; + } + if (nums > num_inc) + _add(&(x[offset_p]), b, &(z[offset_p]), nums - num_inc); + } +} + +template +static void channel_NC_stride1_C(const Nd4jLong*& x_strides, + const Nd4jLong*& bases, T* x, const T2* b, + T* z, const bool& inplace, + const Nd4jLong yStrideC, const Nd4jLong& start, + const Nd4jLong& stop, const Nd4jLong& inc) { + size_t loop_count = (stop - start) / inc; + sd::CoordsState cst; + size_t offset = sd::init_coords(cst, start, bases, x_strides); + + if (!inplace) { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); + _add_broadcast(&(x[offset]), yy, &(z[offset]), inc); + offset = sd::inc_coords(cst, offset); + } + } else { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); + _add_broadcast_inplace(&(x[offset]), yy, inc); + offset = sd::inc_coords(cst, offset); + } + } +} + +template +void channel_NC_generic_C(const Nd4jLong* bases, const Nd4jLong* x_strides, + const Nd4jLong* z_strides, const bool& inplaceOp, + const bool same_stride, const bool same_order, + const Nd4jLong yStrideC, T* x, const T2* b, T* z, + Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { + // just ensure that passed sameStride is correct, because when bases are + // equal orders matters + + bool sameOrderStride = same_order && same_stride; + + if (sameOrderStride && x_strides[constRank - 1] == 1) { + channel_NC_stride1_C(x_strides, bases, x, b, z, inplaceOp, + yStrideC, start, stop, inc); + } else { + // (stop-start) % inc == 0 because we handled inside partitioning using + // the channel size + size_t loop_count = (stop - start) / inc; + sd::ZipCoordsState cst; + sd::zip_size_t offset = + sd::init_coords(cst, start, bases, x_strides, z_strides); + Nd4jLong x_stride = ZIP_STRIDE1(cst, constRank - 1); + Nd4jLong z_stride = ZIP_STRIDE2(cst, constRank - 1); + if (same_order && z_stride == 1 && x_stride == 1) { + /* bases are equal with different strides , but the last one is 1. So we + * can still vectorize it */ + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[ZIP_COORDS(cst, 1) * yStrideC]); + _add_broadcast(&(x[offset.first]), yy, &(z[offset.second]), inc); + offset = sd::inc_coords(cst, offset); + } + } else { + for (size_t i = 0; i < loop_count; i++) { + T* xx = &(x[offset.first]); + T* zz = &(z[offset.second]); + T yy = static_cast(b[ZIP_COORDS(cst, 1) * yStrideC]); + for (size_t j = 0; j < inc; j++) + zz[j * z_stride] = xx[j * x_stride] + yy; + offset = sd::inc_coords(cst, offset); + } + } + } +} + +/// +template +void channel_NC_continous_numHW_C(Nd4jLong rank, const Nd4jLong* bases, + const Nd4jLong* x_strides, T* x, const T2* b, + T* z, bool inplaceOp, const Nd4jLong yStrideC, + Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { + // (stop-start) % inc == 0 because we handled inside partitioning using the + // channel size + size_t loop_count = (stop - start) / inc; + + sd::CoordsState<1> cst; + // note: we had to manually pass index + size_t offset_p = sd::init_coords<2>(cst, start / inc, bases, x_strides); + + // partitioning was done using numHW, so we can increment from rank 2 + if (inplaceOp) { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); + _add_broadcast_inplace(&(x[offset_p]), yy, inc); + offset_p = sd::inc_coords<2>(cst, offset_p); + } + } else { + if (yStrideC == 1) { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[COORDS(cst, 1)]); + _add_broadcast(&(x[offset_p]), yy, &(z[offset_p]), inc); + offset_p = sd::inc_coords<2>(cst, offset_p); + } + } else { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[COORDS(cst, 1) * yStrideC]); + _add_broadcast(&(x[offset_p]), yy, &(z[offset_p]), inc); + offset_p = sd::inc_coords<2>(cst, offset_p); + } + } + } +} + +// +template +static void channel_generic_stride_skip_F( + const Nd4jLong*& x_strides, const Nd4jLong*& bases, T* x, const T2* b, T* z, + const bool& inplace, const Nd4jLong yStrideC, const Nd4jLong& start, + const Nd4jLong& stop, const Nd4jLong& inc) { + // (stop-start) % inc == 0 because we handled inside partitioning using the + // channel size + size_t loop_count = (stop - start) / inc; + sd::CoordsState cst; + size_t offset_p = + sd::init_coords(cst, start, bases, x_strides); + if (!inplace) { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[COORDS(cst, b_index) * yStrideC]); + _add_broadcast(&(x[offset_p]), yy, &(z[offset_p]), inc); + offset_p = sd::inc_coords(cst, offset_p); + } + } else { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[COORDS(cst, b_index) * yStrideC]); + _add_broadcast_inplace(&(x[offset_p]), yy, inc); + offset_p = sd::inc_coords(cst, offset_p); + } + } +} + +/// +template +void channel_generic_F(const Nd4jLong* bases, const Nd4jLong* x_strides, + const Nd4jLong* z_strides, const bool& inplaceOp, + const bool same_stride, const bool same_order, + const Nd4jLong yStrideC, T* x, const T2* b, T* z, + Nd4jLong start, Nd4jLong stop, Nd4jLong inc) { + // just ensure that passed sameStride is correct, because when bases are + // equal orders matters + bool sameOrderStride = same_order && same_stride; + if (sameOrderStride && x_strides[0] == 1) { + channel_generic_stride_skip_F( + x_strides, bases, x, b, z, inplaceOp, yStrideC, start, stop, inc); + } else { + // (stop-start) % inc == 0 because we handled inside partitioning using + // the channel size + + size_t loop_count = (stop - start) / inc; + sd::ZipCoordsState cst; + sd::zip_size_t offset = sd::init_coords( + cst, start, bases, x_strides, z_strides); + Nd4jLong x_stride = ZIP_STRIDE1(cst, 0); + Nd4jLong z_stride = ZIP_STRIDE2(cst, 0); + if (same_order && z_stride == 1 && x_stride == 1) { + for (size_t i = 0; i < loop_count; i++) { + T yy = static_cast(b[ZIP_COORDS(cst, b_index) * yStrideC]); + _add_broadcast(&(x[offset.first]), yy, &(z[offset.second]), inc); + offset = sd::inc_coords(cst, offset); + } + } else { + for (size_t i = 0; i < loop_count; i++) { + T* xx = &(x[offset.first]); + T* zz = &(z[offset.second]); + T yy = static_cast(b[ZIP_COORDS(cst, b_index) * yStrideC]); + for (size_t j = 0; j < inc; j++) + zz[j * z_stride] = xx[j * x_stride] + yy; + offset = sd::inc_coords(cst, offset); + } + } + } +} + +template +static void addBias_(const NDArray& input, const NDArray& bias, NDArray& output, + const bool isNCHW) { + /* + if (input.rankOf() == 2 && bias.rankOf() == 1 && input.sizeAt(1) == +bias.sizeAt(0) && input.ordering() == 'c') { int rows = input.sizeAt(0); int +biasLen = bias.lengthOf(); + +auto inB = input.bufferAsT(); +auto bB = bias.bufferAsT(); +auto outB = output.bufferAsT(); + + for (int e = 0; e < rows; e++) { + auto row = inB + (e * biasLen); +auto out = outB + (e * biasLen); + + for (int t = 0; t < biasLen; t++) { + out[t] = row[t] + bB[t]; + } + } + +return; + } + */ + + auto x_shapeInfo = input.shapeInfo(); + auto z_shapeInfo = output.shapeInfo(); + auto x = input.bufferAsT(); + auto z = output.bufferAsT(); + auto b = bias.bufferAsT(); + const Nd4jLong rank = x_shapeInfo[0]; + auto bases = &(x_shapeInfo[1]); + auto x_strides = &(x_shapeInfo[rank + 1]); + auto z_strides = &(z_shapeInfo[rank + 1]); + const bool inplaceOp = (x == z); + const bool same_order = inplaceOp || (input.ordering() == output.ordering()); + const bool channel_atTheEnd = !isNCHW; + const bool same_stride = + inplaceOp || shape::strideEquals(x_shapeInfo, z_shapeInfo); + bool isContinuous = false; + int posOfNonUnityDim; + bias.isCommonVector(posOfNonUnityDim); + const Nd4jLong yStrideC = bias.strideAt(posOfNonUnityDim); + char order = input.ordering(); + + // for rank>5 + if (rank > 5) { + const int channelDim = isNCHW ? 1 : input.rankOf() - 1; // second or last + const_cast(input).applyBroadcast(sd::broadcast::Add, {channelDim}, + bias, output); + return; + } + + if (same_order && same_stride) { + isContinuous = shape::elementWiseStride(x_shapeInfo) == 1 && + shape::elementWiseStride(z_shapeInfo) == 1; + // check_continuity(order, bases, x_strides, rank); + } // if ( sameOrder && same_stride) + + bool treat_as_lastC = false; + // + if (rank == 2 && isNCHW) { + // we believe we better treat it as channel at the end case; + treat_as_lastC = true; + } + if (channel_atTheEnd || treat_as_lastC) { + // N..HWC case here + // flattened bias variables + constexpr size_t BSIZE1 = 3 * MIN_NN * MIN_NN; + constexpr size_t BSIZE2 = BSIZE1 + MIN_NN * MIN_NN; + X flatBias_stack[BSIZE2] align32; + std::unique_ptr flatBias_heap; + const X* bias_new; + X* bias_extra = nullptr; + size_t total_num = 1; + for (Nd4jLong i = 0; i < rank; i++) { + total_num *= bases[i]; + } + Nd4jLong inc; + size_t rank_skip = 1; + if (order == 'c') { + size_t b_stack_size = BSIZE2; + inc = bases[rank - 1]; + if (isContinuous) { + // for continous we need extra stack memory + // to create vectorizable bias from small size + b_stack_size = BSIZE1; + bias_extra = &(flatBias_stack[BSIZE1]); + } + bias_new = flattened_bias(b, (X*)flatBias_stack, b_stack_size, + flatBias_heap, inc, yStrideC); + if (isContinuous && inc < MIN_NN_K * MIN_NN && + total_num > inc * MIN_NN_K) { + // for small size where total_num is sufficient we need to recreate + // vectorizable buffer + size_t old_inc = inc; + // sizeof bias_extra is MIN_NN * MIN_NN + size_t new_inc = inc < MIN_NN ? inc * MIN_NN : inc * MIN_NN / MIN_NN_K; + // if there is a room then lets multiply + new_inc = (new_inc * MIN_NN_K <= total_num && + new_inc < MIN_NN * MIN_NN / MIN_NN_K) + ? MIN_NN_K * new_inc + : new_inc; + for (size_t i = 0; i < new_inc; i += inc) { + // copy to our buffer + X* cp = &(bias_extra[i]); + for (size_t j = 0; j < inc; j++) { + cp[j] = bias_new[j]; + } + } + // vectorizable buffer + inc = new_inc; + bias_new = bias_extra; + } + } else { + inc = bases[0]; + if (isContinuous) { + // we can choose other inc and index for that case + // but for now lets choose all till the last one + uint32_t req_numThreads = + sd::Environment::getInstance()->maxMasterThreads(); + isContinuous = false; + if (rank > 2) { + if (req_numThreads < 2 || bases[rank - 1] >= req_numThreads) { + inc = total_num / bases[rank - 1]; + isContinuous = true; + rank_skip = rank - 1; + } else if (rank > 3 && + bases[rank - 1] * bases[rank - 2] >= req_numThreads) { + inc = total_num / bases[rank - 1] / + bases[rank - 2]; // for continuous case it is its stride + rank_skip = rank - 2; + isContinuous = true; + } + } + } + } + + FUNC_1D func = [order, isContinuous, rank, x, b, bias_new, z, x_shapeInfo, + z_shapeInfo, same_stride, same_order, yStrideC, + rank_skip](uint64_t thread_id, int64_t start, int64_t stop, + int64_t increment) -> void { + const Nd4jLong rank = x_shapeInfo[0]; + auto bases = &(x_shapeInfo[1]); + auto x_strides = &(x_shapeInfo[rank + 1]); + auto z_strides = &(z_shapeInfo[rank + 1]); + const bool inplaceOp = (x == z); + if (order == 'c') { + if (isContinuous) { + channel_atTheEnd_continous_C(const_cast(x), bias_new, z, + inplaceOp, start, stop, increment); + } + // rank is in [2,5] + else if (rank == 4) { + channel_atTheEnd_generic_C( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + const_cast(x), bias_new, z, start, stop, increment); + + } else if (rank == 5) { + channel_atTheEnd_generic_C( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + const_cast(x), bias_new, z, start, stop, increment); + } else if (rank == 2) { + channel_atTheEnd_generic_C( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + const_cast(x), bias_new, z, start, stop, increment); + } else if (rank == 3) { + channel_atTheEnd_generic_C( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + const_cast(x), bias_new, z, start, stop, increment); + } + } else { + // generic F case + if (isContinuous) { + if (rank == 4) { + if (rank_skip == rank - 2) { + channel_generic_stride_skip_F( + x_strides, bases, const_cast(x), b, z, inplaceOp, + yStrideC, start, stop, increment); + } else { + channel_generic_stride_skip_F( + x_strides, bases, const_cast(x), b, z, inplaceOp, + yStrideC, start, stop, increment); + } + } else if (rank == 5) { + if (rank_skip == rank - 2) { + // skip==3 + channel_generic_stride_skip_F( + x_strides, bases, const_cast(x), b, z, inplaceOp, + yStrideC, start, stop, increment); + } else { + channel_generic_stride_skip_F( + x_strides, bases, const_cast(x), b, z, inplaceOp, + yStrideC, start, stop, increment); + } + } else if (rank == 3) { + channel_generic_stride_skip_F( + x_strides, bases, const_cast(x), b, z, inplaceOp, yStrideC, + start, stop, increment); + } + } else if (rank == 4) { + channel_generic_F( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } else if (rank == 5) { + channel_generic_F( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } else if (rank == 2) { + channel_generic_F( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } else if (rank == 3) { + channel_generic_F( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } + } + }; + // + samediff::Threads::parallel_aligned_increment(func, 0, total_num, inc); + } else { + // NC...HW case here + size_t numNC = 1; + size_t numHW = 1; + for (size_t i = 0; i < 2; i++) { + numNC *= bases[i]; + } + for (Nd4jLong i = 2; i < rank; i++) { + numHW *= bases[i]; + } + Nd4jLong total_num = numNC * numHW; + Nd4jLong inc = (order == 'c') ? bases[rank - 1] : bases[0]; + if (order == 'c' && isContinuous) { + // sometimes last dimension is too big and multithreading could suffer + // using unfair partitioning so we will do it only when inc is smaller our + // value or multithreading turned off + uint32_t req_numThreads = + sd::Environment::getInstance()->maxMasterThreads(); + if (req_numThreads < 2 || numNC >= req_numThreads || inc <= 2 * 8196 || + rank == 3) { + inc = numHW; + } else { + // treat it as stride1c case + isContinuous = false; + } + } + FUNC_1D func = [order, isContinuous, rank, x, b, z, x_shapeInfo, + z_shapeInfo, same_stride, same_order, + yStrideC](uint64_t thread_id, int64_t start, int64_t stop, + int64_t increment) -> void { + const Nd4jLong rank = x_shapeInfo[0]; + const Nd4jLong* bases = &(x_shapeInfo[1]); + const Nd4jLong* x_strides = &(x_shapeInfo[rank + 1]); + const Nd4jLong* z_strides = &(z_shapeInfo[rank + 1]); + const bool inplaceOp = (x == z); + if (order == 'c') { + if (isContinuous) { + channel_NC_continous_numHW_C(rank, bases, x_strides, + const_cast(x), b, z, inplaceOp, + yStrideC, start, stop, increment); + } + // rank is in [3,5] + else if (rank == 4) { + channel_NC_generic_C( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + + } else if (rank == 5) { + channel_NC_generic_C( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } else if (rank == 3) { + channel_NC_generic_C( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } + } else { + // the same can be applied for NCHW case + // generic F case + // continous case is missing + + if (rank == 4) { + channel_generic_F( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } else if (rank == 5) { + channel_generic_F( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } else if (rank == 3) { + channel_generic_F( + bases, x_strides, z_strides, inplaceOp, same_stride, same_order, + yStrideC, const_cast(x), b, z, start, stop, increment); + } + } + }; + // + samediff::Threads::parallel_aligned_increment(func, 0, total_num, inc); + } +} +////////////////////////////////////////////////////////////////////////// +void addBias(sd::graph::Context& block, const NDArray& input, + const NDArray& bias, NDArray& output, const bool isNCHW) { + // bias.rankOf() == 1 ? bias : bias.reshape(bias.ordering(), + // {bias.lengthOf()}) + BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBias_, + (input, bias, output, isNCHW), FLOAT_TYPES, + FLOAT_TYPES); +} + +BUILD_DOUBLE_TEMPLATE(template void addBias_, + (const NDArray& input, const NDArray& bias, + NDArray& output, const bool isNCHW), + FLOAT_TYPES, FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp b/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp index 20d91ee8b6d1..6c9680660f2d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp @@ -19,86 +19,85 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { - template -static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) { - - const T delta = deltaScalarArr->e(0); - const int rank = input->rankOf(); - - const T* x = input->bufferAsT(); - T* z = output->bufferAsT(); - - if(dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i += increment) { - T h, s, v; - - rgbToHsv(x[i], x[i + 1], x[i + 2], h, s, v); - - h += delta ; - if (h > (T)1) - h -= (T)1; - else if (h < 0) - h += (T)1; - - hsvToRgb(h, s, v, z[i], z[i + 1], z[i + 2]); - } - }; - - samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); - } - else { - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimC); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimC); - - const Nd4jLong numOfTads = packX.numberOfTads(); - const Nd4jLong xDimCstride = input->stridesOf()[dimC]; - const Nd4jLong zDimCstride = output->stridesOf()[dimC]; - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - - const T *xTad = x + packX.platformOffsets()[i]; - T *zTad = z + packZ.platformOffsets()[i]; - - T h, s, v; - - rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); - - h += delta ; - if (h > (T)1) - h -= (T)1; - else if (h < 0) - h += (T)1; - - hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfTads); - } +static void adjustHue_(const NDArray *input, const NDArray *deltaScalarArr, + NDArray *output, const int dimC) { + const T delta = deltaScalarArr->e(0); + const int rank = input->rankOf(); + + const T *x = input->bufferAsT(); + T *z = output->bufferAsT(); + + if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && + input->ordering() == 'c' && output->ordering() == 'c') { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) { + T h, s, v; + + rgbToHsv(x[i], x[i + 1], x[i + 2], h, s, v); + + h += delta; + if (h > (T)1) + h -= (T)1; + else if (h < 0) + h += (T)1; + + hsvToRgb(h, s, v, z[i], z[i + 1], z[i + 2]); + } + }; + + samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); + } else { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimC); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimC); + + const Nd4jLong numOfTads = packX.numberOfTads(); + const Nd4jLong xDimCstride = input->stridesOf()[dimC]; + const Nd4jLong zDimCstride = output->stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const T *xTad = x + packX.platformOffsets()[i]; + T *zTad = z + packZ.platformOffsets()[i]; + + T h, s, v; + + rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); + + h += delta; + if (h > (T)1) + h -= (T)1; + else if (h < 0) + h += (T)1; + + hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } } - -void adjustHue(sd::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), adjustHue_, (input, deltaScalarArr, output, dimC), FLOAT_TYPES); +void adjustHue(sd::LaunchContext *context, const NDArray *input, + const NDArray *deltaScalarArr, NDArray *output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjustHue_, + (input, deltaScalarArr, output, dimC), FLOAT_TYPES); } /* template -static void adjust_hue_single_(sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { +static void adjust_hue_single_(sd::LaunchContext * context, NDArray *array, +NDArray *output, float delta, bool isNHWC) { // we're 100% sure it's 3 const int numChannels = 3; int tuples = array->lengthOf() / numChannels; @@ -166,8 +165,8 @@ static void adjust_hue_single_(sd::LaunchContext * context, NDArray *array, NDAr } } -void adjust_hue_(sd::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) { - auto xType = array->dataType(); +void adjust_hue_(sd::LaunchContext * context, NDArray *array, NDArray *output, +NDArray* delta, bool isNHWC) { auto xType = array->dataType(); float d = delta->e(0); if (array->rankOf() == 4) { @@ -177,21 +176,24 @@ void adjust_hue_(sd::LaunchContext * context, NDArray *array, NDArray *output, N // FIXME: template selector should be moved out of loop PRAGMA_OMP_PARALLEL_FOR for (int e = 0; e < tSize; e++) { - BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, +tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES); } delete tadsIn; delete tadsOut; } else { - BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, array, output, d, isNHWC);, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, array, +output, d, isNHWC);, FLOAT_TYPES); } } -BUILD_SINGLE_TEMPLATE(template void adjust_hue_single_, (sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC);, FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void adjust_hue_single_, (sd::LaunchContext * +context, NDArray *array, NDArray *output, float delta, bool isNHWC);, +FLOAT_TYPES); */ - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp b/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp index 6610b69ac035..8bf1216af4e0 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp @@ -19,84 +19,88 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include #include +#include +#include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { template -static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) { - - const T factor = factorScalarArr->e(0); - const int rank = input->rankOf(); - - const T* x = input->bufferAsT(); - T* z = output->bufferAsT(); - - if(dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i += increment) { - T h, s, v; - - rgbToHsv(x[i], x[i + 1], x[i + 2], h, s, v); - - s *= factor; - if (s > 1.f) - s = 1.f; - else if (s < 0.f) - s = 0.f; - - hsvToRgb(h, s, v, z[i], z[i + 1], z[i + 2]); - } - }; - - samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); - } else { - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimC); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimC); - - const Nd4jLong numOfTads = packX.numberOfTads(); - const Nd4jLong xDimCstride = input->stridesOf()[dimC]; - const Nd4jLong zDimCstride = output->stridesOf()[dimC]; - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - const T *xTad = x + packX.platformOffsets()[i]; - T *zTad = z + packZ.platformOffsets()[i]; - - T h, s, v; - - rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); - - s *= factor; - if (s > 1.f) - s = 1.f; - else if (s < 0.f) - s = 0.f; - - hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfTads); - } +static void adjustSaturation_(const NDArray *input, + const NDArray *factorScalarArr, NDArray *output, + const int dimC) { + const T factor = factorScalarArr->e(0); + const int rank = input->rankOf(); + + const T *x = input->bufferAsT(); + T *z = output->bufferAsT(); + + if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && + input->ordering() == 'c' && output->ordering() == 'c') { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) { + T h, s, v; + + rgbToHsv(x[i], x[i + 1], x[i + 2], h, s, v); + + s *= factor; + if (s > 1.f) + s = 1.f; + else if (s < 0.f) + s = 0.f; + + hsvToRgb(h, s, v, z[i], z[i + 1], z[i + 2]); + } + }; + + samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); + } else { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimC); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimC); + + const Nd4jLong numOfTads = packX.numberOfTads(); + const Nd4jLong xDimCstride = input->stridesOf()[dimC]; + const Nd4jLong zDimCstride = output->stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const T *xTad = x + packX.platformOffsets()[i]; + T *zTad = z + packZ.platformOffsets()[i]; + + T h, s, v; + + rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); + + s *= factor; + if (s > 1.f) + s = 1.f; + else if (s < 0.f) + s = 0.f; + + hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } } - -void adjustSaturation(sd::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) { - - BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturation_, (input, factorScalarArr, output, dimC), FLOAT_TYPES); +void adjustSaturation(sd::LaunchContext *context, const NDArray *input, + const NDArray *factorScalarArr, NDArray *output, + const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturation_, + (input, factorScalarArr, output, dimC), FLOAT_TYPES); } /* template -static void adjust_saturation_single_(sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { +static void adjust_saturation_single_(sd::LaunchContext * context, NDArray +*array, NDArray *output, float delta, bool isNHWC) { // we're 100% sure it's 3 const int numChannels = 3; int tuples = array->lengthOf() / numChannels; @@ -114,7 +118,8 @@ static void adjust_saturation_single_(sd::LaunchContext * context, NDArray *arra T h, s, v; // Convert the RGB color to Hue/V-range. helpers::rgb_to_hsv(i[0], i[1], i[2], &h, &s, &v); - s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, s * delta)); + s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, +s * delta)); // Convert the hue and v-range back into RGB. helpers::hsv_to_rgb(h, s, v, o, o + 1, o + 2); } @@ -143,7 +148,8 @@ static void adjust_saturation_single_(sd::LaunchContext * context, NDArray *arra T h, s, v; // Convert the RGB color to Hue/V-range. helpers::rgb_to_hsv(_ri[0], _gi[0], _bi[0], &h, &s, &v); - s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, s * delta)); + s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, +s * delta)); // Convert the hue and v-range back into RGB. helpers::hsv_to_rgb(h, s, v, _ro, _go, _bo); } @@ -153,8 +159,8 @@ static void adjust_saturation_single_(sd::LaunchContext * context, NDArray *arra } } -void adjust_saturation(sd::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) { - auto xType = array->dataType(); +void adjust_saturation(sd::LaunchContext * context, NDArray *array, NDArray +*output, NDArray* delta, bool isNHWC) { auto xType = array->dataType(); float d = delta->e(0); if (array->rankOf() == 4) { @@ -165,7 +171,8 @@ void adjust_saturation(sd::LaunchContext * context, NDArray *array, NDArray *out // FIXME: template selector should be moved out of loop PRAGMA_OMP_PARALLEL_FOR for (int e = 0; e < tSize; e++) { - BUILD_SINGLE_SELECTOR(xType, adjust_saturation_single_, (context, tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(xType, adjust_saturation_single_, (context, +tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES); } @@ -173,13 +180,16 @@ void adjust_saturation(sd::LaunchContext * context, NDArray *array, NDArray *out delete tadsOut; } else { - BUILD_SINGLE_SELECTOR(xType, adjust_saturation_single_, (context, array, output, d, isNHWC);, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(xType, adjust_saturation_single_, (context, array, +output, d, isNHWC);, FLOAT_TYPES); } } -BUILD_SINGLE_TEMPLATE(template void adjust_saturation_single_, (sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void adjust_saturation_single_, +(sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool +isNHWC), FLOAT_TYPES); */ -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp b/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp index 36f7166309f5..7ea8d2458f38 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp @@ -20,29 +20,26 @@ #include - namespace sd { namespace ops { namespace helpers { - void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector& output) { - output.resize(axisVector->lengthOf()); - for (Nd4jLong e = 0; e < axisVector->lengthOf(); e++) { - auto ca = axisVector->e(e); - if (ca < 0) - ca += rank; - - output[e] = ca; - } - } +void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector& output) { + output.resize(axisVector->lengthOf()); + for (Nd4jLong e = 0; e < axisVector->lengthOf(); e++) { + auto ca = axisVector->e(e); + if (ca < 0) ca += rank; - void adjustAxis(Nd4jLong rank, std::vector &axisVector) { - for (size_t e = 0; e < axisVector.size(); e++) { - auto a = axisVector[e]; - if (a < 0) - axisVector[e] = a + rank; - } - } -} + output[e] = ca; + } } + +void adjustAxis(Nd4jLong rank, std::vector& axisVector) { + for (size_t e = 0; e < axisVector.size(); e++) { + auto a = axisVector[e]; + if (a < 0) axisVector[e] = a + rank; + } } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/batched_gemm.cpp b/libnd4j/include/ops/declarable/helpers/cpu/batched_gemm.cpp index ec8f040a9a6c..edf614c70d98 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/batched_gemm.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/batched_gemm.cpp @@ -18,114 +18,135 @@ // @author raver119@gmail.com // +#include +#include +#include #include #include -#include -#include -#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - template -void bgemm_(const std::vector& vA, const std::vector& vB, std::vector& vC, const NDArray* alphas, const NDArray* betas, int transA, int transB, int M, int N, int K, const int lda, const int ldb, const int ldc) { - - int batchSize = vA.size(); - - if (BlasHelper::getInstance()->hasBatchedGEMM()) { - auto arr = vA.at(0); - CBLAS_TRANSPOSE *tA, *tB; - int *tM, *tN, *tK, *tldA, *tldB, *tldC, *tsize; - // mkl requires mnk etc as arrays, cuda doesn't - ALLOCATE(tA, arr->getContext()->getWorkspace(), batchSize, CBLAS_TRANSPOSE); - ALLOCATE(tB, arr->getContext()->getWorkspace(), batchSize, CBLAS_TRANSPOSE); - ALLOCATE(tM, arr->getContext()->getWorkspace(), batchSize, int); - ALLOCATE(tN, arr->getContext()->getWorkspace(), batchSize, int); - ALLOCATE(tK, arr->getContext()->getWorkspace(), batchSize, int); - ALLOCATE(tldA, arr->getContext()->getWorkspace(), batchSize, int); - ALLOCATE(tldB, arr->getContext()->getWorkspace(), batchSize, int); - ALLOCATE(tldC, arr->getContext()->getWorkspace(), batchSize, int); - ALLOCATE(tsize, arr->getContext()->getWorkspace(), batchSize, int); - - shape::fill(tA, (CBLAS_TRANSPOSE) transA, batchSize); - shape::fill(tB, (CBLAS_TRANSPOSE) transB, batchSize); - - shape::fill(tM, M, batchSize); - shape::fill(tN, N, batchSize); - shape::fill(tK, K, batchSize); - shape::fill(tldA, lda, batchSize); - shape::fill(tldB, ldb, batchSize); - shape::fill(tldC, ldc, batchSize); - shape::fill(tsize, 1, batchSize); - - std::vector buffersA(batchSize); - std::vector buffersB(batchSize); - std::vector buffersC(batchSize); - - for (int e = 0; e < batchSize; e++) { - buffersA[e] = reinterpret_cast(vA[e]->buffer()); - buffersB[e] = reinterpret_cast(vB[e]->buffer()); - buffersC[e] = reinterpret_cast(vC[e]->buffer()); - } - - if (std::is_same::value) { - BlasHelper::getInstance()->dgemmBatched()(CblasColMajor, tA, tB, tM, tN, tK, (double *) alphas->buffer(), (double **) buffersA.data(), tldA, (double **) buffersB.data(), tldB, (double *) betas->buffer(),(double **) buffersC.data(), tldC, vA.size(), tsize); - } else if (std::is_same::value) { - BlasHelper::getInstance()->sgemmBatched()(CblasColMajor, tA, tB, tM, tN, tK, (float *) alphas->buffer(), (float **) buffersA.data(), tldA, (float **) buffersB.data(), tldB, (float *) betas->buffer(), (float **) buffersC.data(), tldC, vA.size(), tsize); - } +void bgemm_(const std::vector &vA, const std::vector &vB, + std::vector &vC, const NDArray *alphas, + const NDArray *betas, int transA, int transB, int M, int N, int K, + const int lda, const int ldb, const int ldc) { + int batchSize = vA.size(); + + if (BlasHelper::getInstance()->hasBatchedGEMM()) { + auto arr = vA.at(0); + CBLAS_TRANSPOSE *tA, *tB; + int *tM, *tN, *tK, *tldA, *tldB, *tldC, *tsize; + // mkl requires mnk etc as arrays, cuda doesn't + ALLOCATE(tA, arr->getContext()->getWorkspace(), batchSize, CBLAS_TRANSPOSE); + ALLOCATE(tB, arr->getContext()->getWorkspace(), batchSize, CBLAS_TRANSPOSE); + ALLOCATE(tM, arr->getContext()->getWorkspace(), batchSize, int); + ALLOCATE(tN, arr->getContext()->getWorkspace(), batchSize, int); + ALLOCATE(tK, arr->getContext()->getWorkspace(), batchSize, int); + ALLOCATE(tldA, arr->getContext()->getWorkspace(), batchSize, int); + ALLOCATE(tldB, arr->getContext()->getWorkspace(), batchSize, int); + ALLOCATE(tldC, arr->getContext()->getWorkspace(), batchSize, int); + ALLOCATE(tsize, arr->getContext()->getWorkspace(), batchSize, int); + + shape::fill(tA, (CBLAS_TRANSPOSE)transA, batchSize); + shape::fill(tB, (CBLAS_TRANSPOSE)transB, batchSize); + + shape::fill(tM, M, batchSize); + shape::fill(tN, N, batchSize); + shape::fill(tK, K, batchSize); + shape::fill(tldA, lda, batchSize); + shape::fill(tldB, ldb, batchSize); + shape::fill(tldC, ldc, batchSize); + shape::fill(tsize, 1, batchSize); + + std::vector buffersA(batchSize); + std::vector buffersB(batchSize); + std::vector buffersC(batchSize); + + for (int e = 0; e < batchSize; e++) { + buffersA[e] = reinterpret_cast(vA[e]->buffer()); + buffersB[e] = reinterpret_cast(vB[e]->buffer()); + buffersC[e] = reinterpret_cast(vC[e]->buffer()); + } - // release temporary arrays - RELEASE(tA, arr->getContext()->getWorkspace()); - RELEASE(tB, arr->getContext()->getWorkspace()); - RELEASE(tM, arr->getContext()->getWorkspace()); - RELEASE(tN, arr->getContext()->getWorkspace()); - RELEASE(tK, arr->getContext()->getWorkspace()); - RELEASE(tldA, arr->getContext()->getWorkspace()); - RELEASE(tldB, arr->getContext()->getWorkspace()); - RELEASE(tldC, arr->getContext()->getWorkspace()); - RELEASE(tsize, arr->getContext()->getWorkspace()); - } else { - CBLAS_TRANSPOSE tA = (CBLAS_TRANSPOSE) transA; - CBLAS_TRANSPOSE tB = (CBLAS_TRANSPOSE) transB; - - int vaSize = vA.size(); - - auto func = PRAGMA_THREADS_FOR { - for (auto p = start; p < stop; p++) { - auto A = reinterpret_cast(vA.at(p)->buffer()); - auto B = reinterpret_cast(vB.at(p)->buffer()); - auto C = reinterpret_cast(vC.at(p)->buffer()); - auto alpha = alphas->e(p); - auto beta = betas->e(p); - for (int m = 0; m < M; ++m) { - for (int n = 0; n < N; ++n) { - T c_mnp = 0; - - PRAGMA_OMP_SIMD - for (int k = 0; k < K; ++k) - c_mnp += A[tA == CblasNoTrans ? (m + k * lda) : (m * lda + k)] * B[tB == CblasNoTrans ? (k + n * ldb) : (k * ldb + n)]; - - C[m + n * ldc] = alpha * c_mnp + beta * C[m + n * ldc]; - } - } - } - }; - - samediff::Threads::parallel_tad(func, 0, vaSize); + if (std::is_same::value) { + BlasHelper::getInstance()->dgemmBatched()( + CblasColMajor, tA, tB, tM, tN, tK, (double *)alphas->buffer(), + (double **)buffersA.data(), tldA, (double **)buffersB.data(), tldB, + (double *)betas->buffer(), (double **)buffersC.data(), tldC, + vA.size(), tsize); + } else if (std::is_same::value) { + BlasHelper::getInstance()->sgemmBatched()( + CblasColMajor, tA, tB, tM, tN, tK, (float *)alphas->buffer(), + (float **)buffersA.data(), tldA, (float **)buffersB.data(), tldB, + (float *)betas->buffer(), (float **)buffersC.data(), tldC, vA.size(), + tsize); } -} + // release temporary arrays + RELEASE(tA, arr->getContext()->getWorkspace()); + RELEASE(tB, arr->getContext()->getWorkspace()); + RELEASE(tM, arr->getContext()->getWorkspace()); + RELEASE(tN, arr->getContext()->getWorkspace()); + RELEASE(tK, arr->getContext()->getWorkspace()); + RELEASE(tldA, arr->getContext()->getWorkspace()); + RELEASE(tldB, arr->getContext()->getWorkspace()); + RELEASE(tldC, arr->getContext()->getWorkspace()); + RELEASE(tsize, arr->getContext()->getWorkspace()); + } else { + CBLAS_TRANSPOSE tA = (CBLAS_TRANSPOSE)transA; + CBLAS_TRANSPOSE tB = (CBLAS_TRANSPOSE)transB; + + int vaSize = vA.size(); + + auto func = PRAGMA_THREADS_FOR { + for (auto p = start; p < stop; p++) { + auto A = reinterpret_cast(vA.at(p)->buffer()); + auto B = reinterpret_cast(vB.at(p)->buffer()); + auto C = reinterpret_cast(vC.at(p)->buffer()); + auto alpha = alphas->e(p); + auto beta = betas->e(p); + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + T c_mnp = 0; + + PRAGMA_OMP_SIMD + for (int k = 0; k < K; ++k) + c_mnp += A[tA == CblasNoTrans ? (m + k * lda) : (m * lda + k)] * + B[tB == CblasNoTrans ? (k + n * ldb) : (k * ldb + n)]; + + C[m + n * ldc] = alpha * c_mnp + beta * C[m + n * ldc]; + } + } + } + }; -void bgemm(const std::vector& vA, const std::vector& vB, std::vector& vC, const NDArray* alphas, const NDArray* betas, int transA, int transB, int M, int N, int K, const int lda, const int ldb, const int ldc) { - auto xType = vA.at(0)->dataType(); - BUILD_SINGLE_SELECTOR(xType, bgemm_, (vA, vB, vC, alphas, betas, transA, transB, M, N, K, lda, ldb, ldc), FLOAT_TYPES); + samediff::Threads::parallel_tad(func, 0, vaSize); + } } -BUILD_SINGLE_TEMPLATE(template void bgemm_, (const std::vector& vA, const std::vector& vB, std::vector& vC, const NDArray* alphas, const NDArray* betas, int transA, int transB, int M, int N, int K, const int lda, const int ldb, const int ldc), FLOAT_TYPES); - +void bgemm(const std::vector &vA, const std::vector &vB, + std::vector &vC, const NDArray *alphas, + const NDArray *betas, int transA, int transB, int M, int N, int K, + const int lda, const int ldb, const int ldc) { + auto xType = vA.at(0)->dataType(); + BUILD_SINGLE_SELECTOR( + xType, bgemm_, + (vA, vB, vC, alphas, betas, transA, transB, M, N, K, lda, ldb, ldc), + FLOAT_TYPES); } -} -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template void bgemm_, + (const std::vector &vA, + const std::vector &vB, + std::vector &vC, const NDArray *alphas, + const NDArray *betas, int transA, int transB, int M, + int N, int K, const int lda, const int ldb, + const int ldc), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp b/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp index 65c342d9ca09..988c703efc6c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp @@ -18,183 +18,213 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // - -#include -#include -#include #include +#include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// template -static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, - NDArray* output, +static void batchnorm_(const NDArray* input, const NDArray* mean, + const NDArray* variance, const NDArray* gamma, + const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon) { - - // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta - - const T* x = input->bufferAsT(); - T* z = output->bufferAsT(); - const T* m = mean->bufferAsT(); - const T* v = variance->bufferAsT(); - const T* g = gamma == nullptr ? nullptr : gamma->bufferAsT(); - const T* b = beta == nullptr ? nullptr : beta->bufferAsT(); - - const bool xzSameOffset = shape::haveSameShapeAndStrides(input->shapeInfo(), output->shapeInfo()); - - bool paramSameOffset = shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); - if(paramSameOffset && gamma != nullptr) - paramSameOffset &= shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); - if(paramSameOffset && beta != nullptr) - paramSameOffset &= shape::haveSameShapeAndStrides(mean->shapeInfo(), beta->shapeInfo()); - - const Nd4jLong lenBig = input->lengthOf(); - const Nd4jLong lenSmall = mean->lengthOf(); - - const Nd4jLong steps = lenBig / lenSmall; - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), axes); - - OmpLaunchHelper info(lenBig, lenSmall); - - auto func = PRAGMA_THREADS_DO { - - Nd4jLong* xOffsets = new Nd4jLong[steps]; - Nd4jLong* zOffsets = xzSameOffset ? xOffsets : new Nd4jLong[steps]; - int* auxBuff = new int[2 * input->rankOf()]; - - for (Nd4jLong j = 0; j < lenSmall; ++j) { - - const bool isOwner = (j < info._numThreads) ? thread_id == j : thread_id == (j % info._numThreads); - - if(!isOwner) - continue; - - const auto meanOffset = shape::getIndexOffset(j, mean->shapeInfo()); - const auto varOffset = paramSameOffset ? meanOffset : shape::getIndexOffset(j, variance->shapeInfo()); - - const auto meanVal = m[meanOffset]; - auto sigmaInvGam = static_cast(1) / sd::math::nd4j_sqrt(v[varOffset] + epsilon); - - if(g != nullptr) { - const auto gammaOffset = paramSameOffset ? meanOffset : shape::getIndexOffset(j, gamma->shapeInfo()); - sigmaInvGam *= g[gammaOffset]; - } - - T betaVal = static_cast(0); - if(b != nullptr) { - const auto betaOffset = paramSameOffset ? meanOffset : shape::getIndexOffset(j, beta->shapeInfo()); - betaVal = b[betaOffset]; - } - - // calculate offsets for input and output - shape::outerArrayOffsets(xOffsets, j, input->shapeInfo(), mean->shapeInfo(), auxBuff, dimsToExclude.data()); - if(!xzSameOffset) - shape::outerArrayOffsets(zOffsets, j, output->shapeInfo(), mean->shapeInfo(), auxBuff, dimsToExclude.data()); - - PRAGMA_OMP_SIMD - for (Nd4jLong i = 0; i < steps; ++i) - z[zOffsets[i]] = (x[xOffsets[i]] - meanVal) * sigmaInvGam + betaVal; - } - - delete []auxBuff; - delete []xOffsets; - if(!xzSameOffset) - delete []zOffsets; - }; - - samediff::Threads::parallel_do(func, info._numThreads); + // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + + // beta + + const T* x = input->bufferAsT(); + T* z = output->bufferAsT(); + const T* m = mean->bufferAsT(); + const T* v = variance->bufferAsT(); + const T* g = gamma == nullptr ? nullptr : gamma->bufferAsT(); + const T* b = beta == nullptr ? nullptr : beta->bufferAsT(); + + const bool xzSameOffset = + shape::haveSameShapeAndStrides(input->shapeInfo(), output->shapeInfo()); + + bool paramSameOffset = + shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); + if (paramSameOffset && gamma != nullptr) + paramSameOffset &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); + if (paramSameOffset && beta != nullptr) + paramSameOffset &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), beta->shapeInfo()); + + const Nd4jLong lenBig = input->lengthOf(); + const Nd4jLong lenSmall = mean->lengthOf(); + + const Nd4jLong steps = lenBig / lenSmall; + std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(input->rankOf(), axes); + + OmpLaunchHelper info(lenBig, lenSmall); + + auto func = PRAGMA_THREADS_DO { + Nd4jLong* xOffsets = new Nd4jLong[steps]; + Nd4jLong* zOffsets = xzSameOffset ? xOffsets : new Nd4jLong[steps]; + int* auxBuff = new int[2 * input->rankOf()]; + + for (Nd4jLong j = 0; j < lenSmall; ++j) { + const bool isOwner = (j < info._numThreads) + ? thread_id == j + : thread_id == (j % info._numThreads); + + if (!isOwner) continue; + + const auto meanOffset = shape::getIndexOffset(j, mean->shapeInfo()); + const auto varOffset = + paramSameOffset ? meanOffset + : shape::getIndexOffset(j, variance->shapeInfo()); + + const auto meanVal = m[meanOffset]; + auto sigmaInvGam = + static_cast(1) / sd::math::nd4j_sqrt(v[varOffset] + epsilon); + + if (g != nullptr) { + const auto gammaOffset = + paramSameOffset ? meanOffset + : shape::getIndexOffset(j, gamma->shapeInfo()); + sigmaInvGam *= g[gammaOffset]; + } + + T betaVal = static_cast(0); + if (b != nullptr) { + const auto betaOffset = + paramSameOffset ? meanOffset + : shape::getIndexOffset(j, beta->shapeInfo()); + betaVal = b[betaOffset]; + } + + // calculate offsets for input and output + shape::outerArrayOffsets(xOffsets, j, input->shapeInfo(), + mean->shapeInfo(), auxBuff, + dimsToExclude.data()); + if (!xzSameOffset) + shape::outerArrayOffsets(zOffsets, j, output->shapeInfo(), + mean->shapeInfo(), auxBuff, + dimsToExclude.data()); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < steps; ++i) + z[zOffsets[i]] = (x[xOffsets[i]] - meanVal) * sigmaInvGam + betaVal; + } + + delete[] auxBuff; + delete[] xOffsets; + if (!xzSameOffset) delete[] zOffsets; + }; + + samediff::Threads::parallel_do(func, info._numThreads); } ////////////////////////////////////////////////////////////////////////// template -static void batchnorm2_(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, - NDArray* output, +static void batchnorm2_(const NDArray* input, const NDArray* mean, + const NDArray* variance, const NDArray* gamma, + const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon) { - - // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta - - const auto x = input->bufferAsT(); - auto z = output->bufferAsT(); - const auto m = mean->bufferAsT(); - const auto v = variance->bufferAsT(); - const auto g = gamma == nullptr ? nullptr : gamma->bufferAsT(); - const auto b = beta == nullptr ? nullptr : beta->bufferAsT(); - - // xRank == zRank, minRank = meanRank = varianceRank = gammaRank = betaRank - const uint xRank = input->rankOf(); - const uint minRank = mean->rankOf(); - const uint numAxes = axes.size(); - - const bool xzSameOffset = shape::haveSameShapeAndStrides(input->shapeInfo(), output->shapeInfo()); - - bool paramSameOffset = shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); - if(paramSameOffset && gamma != nullptr) - paramSameOffset &= shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); - if(paramSameOffset && beta != nullptr) - paramSameOffset &= shape::haveSameShapeAndStrides(mean->shapeInfo(), beta->shapeInfo()); - - auto func = PRAGMA_THREADS_FOR { - - int xzCoords[MAX_RANK], minCoords[MAX_RANK]; - - for (uint i = 0, j = 0; i < xRank; ++i) - if(j < numAxes && i != axes[j]) - minCoords[i] = 0; - else - ++j; - - for (auto i = start; i < stop; i++) { - - shape::index2coordsCPU(start, i, input->shapeInfo(), xzCoords); - - const auto xOffset = shape::getOffset(input->shapeInfo(), xzCoords); - const auto zOffset = xzSameOffset ? xOffset : shape::getOffset(output->shapeInfo(), xzCoords); - - if(minRank == xRank) { - for (uint j = 0; j < numAxes; ++j) - minCoords[axes[j]] = xzCoords[axes[j]]; - } - else // minRank = numAxes = 1 in this case - minCoords[0] = xzCoords[axes[0]]; - - const auto meanOffset = shape::getOffset(mean->shapeInfo(), minCoords); - const auto varianceOffset = paramSameOffset ? meanOffset : shape::getOffset(variance->shapeInfo(), minCoords); - - T sigmaInvGam = 1. / sd::math::nd4j_sqrt(v[varianceOffset] + epsilon); - - if(g != nullptr) { - const auto gammaOffset = paramSameOffset ? meanOffset : shape::getOffset(gamma->shapeInfo(), minCoords); - sigmaInvGam *= g[gammaOffset]; - } - - z[zOffset] = (x[xOffset] - m[meanOffset]) * sigmaInvGam; - - if(b != nullptr) { - const auto betaOffset = paramSameOffset ? meanOffset : shape::getOffset(beta->shapeInfo(), minCoords); - z[zOffset] += b[betaOffset]; - } - } - }; - - samediff::Threads::parallel_for(func, 0, input->lengthOf()); + // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + + // beta + + const auto x = input->bufferAsT(); + auto z = output->bufferAsT(); + const auto m = mean->bufferAsT(); + const auto v = variance->bufferAsT(); + const auto g = gamma == nullptr ? nullptr : gamma->bufferAsT(); + const auto b = beta == nullptr ? nullptr : beta->bufferAsT(); + + // xRank == zRank, minRank = meanRank = varianceRank = gammaRank = betaRank + const uint xRank = input->rankOf(); + const uint minRank = mean->rankOf(); + const uint numAxes = axes.size(); + + const bool xzSameOffset = + shape::haveSameShapeAndStrides(input->shapeInfo(), output->shapeInfo()); + + bool paramSameOffset = + shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); + if (paramSameOffset && gamma != nullptr) + paramSameOffset &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); + if (paramSameOffset && beta != nullptr) + paramSameOffset &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), beta->shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int xzCoords[MAX_RANK], minCoords[MAX_RANK]; + + for (uint i = 0, j = 0; i < xRank; ++i) + if (j < numAxes && i != axes[j]) + minCoords[i] = 0; + else + ++j; + + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, input->shapeInfo(), xzCoords); + + const auto xOffset = shape::getOffset(input->shapeInfo(), xzCoords); + const auto zOffset = + xzSameOffset ? xOffset + : shape::getOffset(output->shapeInfo(), xzCoords); + + if (minRank == xRank) { + for (uint j = 0; j < numAxes; ++j) + minCoords[axes[j]] = xzCoords[axes[j]]; + } else // minRank = numAxes = 1 in this case + minCoords[0] = xzCoords[axes[0]]; + + const auto meanOffset = shape::getOffset(mean->shapeInfo(), minCoords); + const auto varianceOffset = + paramSameOffset ? meanOffset + : shape::getOffset(variance->shapeInfo(), minCoords); + + T sigmaInvGam = + 1. / sd::math::nd4j_sqrt(v[varianceOffset] + epsilon); + + if (g != nullptr) { + const auto gammaOffset = + paramSameOffset ? meanOffset + : shape::getOffset(gamma->shapeInfo(), minCoords); + sigmaInvGam *= g[gammaOffset]; + } + + z[zOffset] = (x[xOffset] - m[meanOffset]) * sigmaInvGam; + + if (b != nullptr) { + const auto betaOffset = + paramSameOffset ? meanOffset + : shape::getOffset(beta->shapeInfo(), minCoords); + z[zOffset] += b[betaOffset]; + } + } + }; + + samediff::Threads::parallel_for(func, 0, input->lengthOf()); } ////////////////////////////////////////////////////////////////////////// -void batchnorm(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon) { - - // batchnorm2_ is still slower ? - BUILD_SINGLE_SELECTOR(input->dataType(), batchnorm_, (input, mean, variance, gamma, beta, output, axes, epsilon), FLOAT_TYPES); +void batchnorm(const NDArray* input, const NDArray* mean, + const NDArray* variance, const NDArray* gamma, + const NDArray* beta, NDArray* output, + const std::vector& axes, const double epsilon) { + // batchnorm2_ is still slower ? + BUILD_SINGLE_SELECTOR( + input->dataType(), batchnorm_, + (input, mean, variance, gamma, beta, output, axes, epsilon), FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template void batchnorm_, + (const NDArray* input, const NDArray* mean, + const NDArray* variance, const NDArray* gamma, + const NDArray* beta, NDArray* output, + const std::vector& axes, const double epsilon), + FLOAT_TYPES); - -BUILD_SINGLE_TEMPLATE(template void batchnorm_, (const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon), FLOAT_TYPES); - -} -} -} - +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp b/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp index ec06610b8a44..219cac4b4ad8 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp @@ -18,11 +18,12 @@ // Created by Yurii Shyrma on 11.12.2017 // -#include #include -#include #include #include +#include + +#include namespace sd { namespace ops { @@ -30,111 +31,112 @@ namespace helpers { /////////////////////////////////////////////////////////////////// // modified Lentz’s algorithm for continued fractions, -// reference: Lentz, W.J. 1976, “Generating Bessel Functions in Mie Scattering Calculations Using Continued Fractions” +// reference: Lentz, W.J. 1976, “Generating Bessel Functions in Mie Scattering +// Calculations Using Continued Fractions” template static T continuedFraction(const T a, const T b, const T x) { - - const T min = DataTypeUtils::min() / DataTypeUtils::eps(); - const T aPlusb = a + b; - T val, aPlus2i; - - T t2 = 1; - T t1 = static_cast(1) - aPlusb * x / (a + static_cast(1)); - if(math::nd4j_abs(t1) < min) - t1 = min; - t1 = static_cast(1) / t1; - T result = t1; - - for(uint i = 1; i <= maxIter; ++i) { - - aPlus2i = a + static_cast(2*i); - val = i * (b - i) * x / ((aPlus2i - static_cast(1)) * aPlus2i); - // t1 - t1 = static_cast(1) + val * t1; - if(math::nd4j_abs(t1) < min) - t1 = min; - t1 = static_cast(1) / t1; - // t2 - t2 = static_cast(1) + val / t2; - if(math::nd4j_abs(t2) < min) - t2 = min; - // result - result *= t2 * t1; - val = -(a + i) * (aPlusb + i) * x / ((aPlus2i + static_cast(1)) * aPlus2i); - // t1 - t1 = static_cast(1) + val * t1; - if(math::nd4j_abs(t1) < min) - t1 = min; - t1 = static_cast(1) / t1; - // t2 - t2 = static_cast(1) + val / t2; - if(math::nd4j_abs(t2) < min) - t2 = min; - // result - val = t2 * t1; - result *= val; - - // condition to stop loop - if(math::nd4j_abs(val - static_cast(1)) <= DataTypeUtils::eps()) - return result; - } - - return DataTypeUtils::infOrMax(); // no convergence, more iterations is required, return infinity + const T min = DataTypeUtils::min() / DataTypeUtils::eps(); + const T aPlusb = a + b; + T val, aPlus2i; + + T t2 = 1; + T t1 = static_cast(1) - aPlusb * x / (a + static_cast(1)); + if (math::nd4j_abs(t1) < min) t1 = min; + t1 = static_cast(1) / t1; + T result = t1; + + for (uint i = 1; i <= maxIter; ++i) { + aPlus2i = a + static_cast(2 * i); + val = i * (b - i) * x / ((aPlus2i - static_cast(1)) * aPlus2i); + // t1 + t1 = static_cast(1) + val * t1; + if (math::nd4j_abs(t1) < min) t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + val / t2; + if (math::nd4j_abs(t2) < min) t2 = min; + // result + result *= t2 * t1; + val = + -(a + i) * (aPlusb + i) * x / ((aPlus2i + static_cast(1)) * aPlus2i); + // t1 + t1 = static_cast(1) + val * t1; + if (math::nd4j_abs(t1) < min) t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + val / t2; + if (math::nd4j_abs(t2) < min) t2 = min; + // result + val = t2 * t1; + result *= val; + + // condition to stop loop + if (math::nd4j_abs(val - static_cast(1)) <= DataTypeUtils::eps()) + return result; + } + + return DataTypeUtils::infOrMax(); // no convergence, more iterations is + // required, return infinity } /////////////////////////////////////////////////////////////////// -// evaluates incomplete beta function for positive a and b, and x between 0 and 1. +// evaluates incomplete beta function for positive a and b, and x between 0 +// and 1. template static T betaIncCore(T a, T b, T x) { - // if (a <= (T)0. || b <= (T)0.) - // throw("betaInc function: a and b must be > 0 !"); - - // if (x < (T)0. || x > (T)1.) - // throw("betaInc function: x must be within (0, 1) interval !"); + // if (a <= (T)0. || b <= (T)0.) + // throw("betaInc function: a and b must be > 0 !"); + // if (x < (T)0. || x > (T)1.) + // throw("betaInc function: x must be within (0, 1) interval !"); - // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5 - if(a == b && x == static_cast(0.5)) - return static_cast(0.5); + // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5 + if (a == b && x == static_cast(0.5)) return static_cast(0.5); - if (x == static_cast(0) || x == static_cast(1)) - return x; + if (x == static_cast(0) || x == static_cast(1)) return x; - const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); - const T front = math::nd4j_exp(math::nd4j_log(x) * a + math::nd4j_log(1.f - x) * b - gammaPart); + const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); + const T front = + math::nd4j_exp(math::nd4j_log(x) * a + + math::nd4j_log(1.f - x) * b - gammaPart); - if (x <= (a + static_cast(1)) / (a + b + static_cast(2))) - return front * continuedFraction(a, b, x) / a; - else // symmetry relation - return static_cast(1) - front * continuedFraction(b, a, static_cast(1) - x) / b; + if (x <= (a + static_cast(1)) / (a + b + static_cast(2))) + return front * continuedFraction(a, b, x) / a; + else // symmetry relation + return static_cast(1) - + front * continuedFraction(b, a, static_cast(1) - x) / b; } /////////////////////////////////////////////////////////////////// -template -static void betaIncForArray(sd::LaunchContext * context, const NDArray& a, const NDArray& b, const NDArray& x, NDArray& output) { - - int xLen = x.lengthOf(); +template +static void betaIncForArray(sd::LaunchContext* context, const NDArray& a, + const NDArray& b, const NDArray& x, + NDArray& output) { + int xLen = x.lengthOf(); - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - output.t(i) = betaIncCore(a.t(i), b.t(i), x.t(i)); - }; + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + output.t(i) = betaIncCore(a.t(i), b.t(i), x.t(i)); + }; - samediff::Threads::parallel_for(func, 0, xLen); + samediff::Threads::parallel_for(func, 0, xLen); } /////////////////////////////////////////////////////////////////// // overload betaInc for arrays, shapes of a, b and x must be the same !!! -void betaInc(sd::LaunchContext * context, const NDArray& a, const NDArray& b, const NDArray& x, NDArray& output) { - auto xType = a.dataType(); - BUILD_SINGLE_SELECTOR(xType, betaIncForArray, (context, a, b, x, output), FLOAT_TYPES); +void betaInc(sd::LaunchContext* context, const NDArray& a, const NDArray& b, + const NDArray& x, NDArray& output) { + auto xType = a.dataType(); + BUILD_SINGLE_SELECTOR(xType, betaIncForArray, (context, a, b, x, output), + FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void betaIncForArray, (sd::LaunchContext * context, const NDArray& a, const NDArray& b, const NDArray& x, NDArray& output), FLOAT_TYPES); - - -} -} -} +BUILD_SINGLE_TEMPLATE(template void betaIncForArray, + (sd::LaunchContext * context, const NDArray& a, + const NDArray& b, const NDArray& x, NDArray& output), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp b/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp index 2b02e1a46a57..b89a008ab925 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp @@ -18,231 +18,251 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -template -static void clipByNorm_(NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { - - const int rank = input.rankOf(); - const auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); - - const T normActual = norm2.e(0); - const T normClip = clipNorm.e(0); - - if (isInplace) { - - if(norm2.lengthOf() == 1) { - - if(normActual > normClip) - input *= (normClip / normActual); - } - else { - - auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - const T iNormActual = norm2.e(i); - if (iNormActual > normClip) - listOfInSubArrs.at(i) *= normClip / iNormActual; - } - }; - samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size()); +template +static void clipByNorm_(NDArray& input, NDArray& output, + const std::vector& dimensions, + const NDArray& clipNorm, const bool isInplace) { + const int rank = input.rankOf(); + const auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); + + const T normActual = norm2.e(0); + const T normClip = clipNorm.e(0); + + if (isInplace) { + if (norm2.lengthOf() == 1) { + if (normActual > normClip) input *= (normClip / normActual); + } else { + auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const T iNormActual = norm2.e(i); + if (iNormActual > normClip) + listOfInSubArrs.at(i) *= normClip / iNormActual; } + }; + samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size()); } - else { - - if(norm2.lengthOf() == 1) { - - if(normActual > normClip) - output.assign(input * (normClip / normActual)); - else - output.assign(input); - } - else { - - auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions); - auto listOfOutSubArrs = output.allTensorsAlongDimension(dimensions); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto inputSubArr = listOfInSubArrs.at(i); - auto outputSubArr = listOfOutSubArrs.at(i); - outputSubArr.assign(inputSubArr); - - const T iNormActual = norm2.e(i); - - if (iNormActual > clipNorm.e(0)) - outputSubArr *= clipNorm / iNormActual; - } - }; - samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size()); + } else { + if (norm2.lengthOf() == 1) { + if (normActual > normClip) + output.assign(input * (normClip / normActual)); + else + output.assign(input); + } else { + auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions); + auto listOfOutSubArrs = output.allTensorsAlongDimension(dimensions); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inputSubArr = listOfInSubArrs.at(i); + auto outputSubArr = listOfOutSubArrs.at(i); + outputSubArr.assign(inputSubArr); + + const T iNormActual = norm2.e(i); + + if (iNormActual > clipNorm.e(0)) + outputSubArr *= clipNorm / iNormActual; } + }; + samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size()); } + } } ////////////////////////////////////////////////////////////////////////// -void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { - BUILD_SINGLE_SELECTOR(output.dataType(), clipByNorm_, (input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES); +void clipByNorm(sd::LaunchContext* context, NDArray& input, NDArray& output, + const std::vector& dimensions, const NDArray& clipNorm, + const bool isInplace) { + BUILD_SINGLE_SELECTOR(output.dataType(), clipByNorm_, + (input, output, dimensions, clipNorm, isInplace), + FLOAT_TYPES); } - - template - static void clipByGlobalNorm_(std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace) { - T globalNorm = 0; //NDArrayFactory::create(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list])) -// PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(sumT : globalNorm) - for (size_t i = 0; i < inputs.size(); i++) { - auto input = inputs[i]; - auto l2norm = input->reduceNumber(reduce::Norm2); - globalNorm += l2norm.t(0) * l2norm.t(0); - } - - //globalNorm.applyTransform(transform::Sqrt, nullptr, nullptr);// = sd::math::nd4j_sqrt(globalNorm); - auto normS = sd::math::nd4j_sqrt(globalNorm); - outputs[inputs.size()]->p(0, normS); - - const T factor = clipNorm / normS; - -// PRAGMA_OMP_PARALLEL_FOR - for (size_t e = 0; e < inputs.size(); e++) { - // all-reduce - auto input = inputs[e]; - auto output = outputs[e]; - - if (normS <= clipNorm) { - output->assign(input); - } - else { - - auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; - input->applyLambda(lambda, *output); - } - } - } - void clipByGlobalNorm(sd::LaunchContext * context, std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace) { - BUILD_SINGLE_SELECTOR(outputs[0]->dataType(), clipByGlobalNorm_, (inputs, clipNorm, workspace, outputs, isInplace), FLOAT_TYPES); +template +static void clipByGlobalNorm_(std::vector const& inputs, + double clipNorm, sd::memory::Workspace* workspace, + std::vector& outputs, bool isInplace) { + T globalNorm = 0; // NDArrayFactory::create(0, inputs[0]->getContext()); + // //sqrt(sum([l2norm(t)**2 for t in t_list])) + // PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(sumT : globalNorm) + for (size_t i = 0; i < inputs.size(); i++) { + auto input = inputs[i]; + auto l2norm = input->reduceNumber(reduce::Norm2); + globalNorm += l2norm.t(0) * l2norm.t(0); + } + + // globalNorm.applyTransform(transform::Sqrt, nullptr, nullptr);// = + // sd::math::nd4j_sqrt(globalNorm); + auto normS = sd::math::nd4j_sqrt(globalNorm); + outputs[inputs.size()]->p(0, normS); + + const T factor = clipNorm / normS; + + // PRAGMA_OMP_PARALLEL_FOR + for (size_t e = 0; e < inputs.size(); e++) { + // all-reduce + auto input = inputs[e]; + auto output = outputs[e]; + + if (normS <= clipNorm) { + output->assign(input); + } else { + auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; + input->applyLambda(lambda, *output); } + } +} +void clipByGlobalNorm(sd::LaunchContext* context, + std::vector const& inputs, double clipNorm, + sd::memory::Workspace* workspace, + std::vector& outputs, bool isInplace) { + BUILD_SINGLE_SELECTOR(outputs[0]->dataType(), clipByGlobalNorm_, + (inputs, clipNorm, workspace, outputs, isInplace), + FLOAT_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, + (std::vector const& inputs, double clipNorm, + sd::memory::Workspace* workspace, + std::vector& outputs, bool isInplace), + FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// -template -static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector& dimensions, const NDArray& clipNorm) { - - const int rank = input.rankOf(); - - auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); - - if(norm2.lengthOf() == 1) { +template +static void clipByNormBP_(const NDArray& input, const NDArray& gradO, + NDArray& gradI /*output*/, + const std::vector& dimensions, + const NDArray& clipNorm) { + const int rank = input.rankOf(); - const T N = norm2.e(0); + auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); - auto cn = clipNorm.e(0); + if (norm2.lengthOf() == 1) { + const T N = norm2.e(0); - if(N > cn) { - - const T sumOfProd = (input * gradO).reduceNumber(reduce::Sum).e(0); // reduce to scalar - const T factor1 = static_cast(1.f) / N; - const T factor3 = factor1 / (N * N); // 1 / (N*N*N) - - auto lambda = LAMBDA_TT(elem1, elem2, cn, sumOfProd, factor1, factor3) { - return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd); - }; - - (const_cast(input)).applyPairwiseLambda(const_cast(gradO), lambda, gradI); - } - else - gradI.assign(gradO); - } - else { - - auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions}); - auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions}); - auto inputSubArrs = input.allTensorsAlongDimension({dimensions}); - - auto cn = clipNorm.e(0); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - T N = norm2.e(i); - - auto gradOSubArr = gradOSubArrs.at(i); - auto gradISubArr = gradISubArrs.at(i); + auto cn = clipNorm.e(0); - if (N > cn) { - auto inputSubArr = inputSubArrs.at(i); - const T sumOfProd = (inputSubArr * gradOSubArr).reduceNumber(reduce::Sum).e(0); // reduce to scalar - const T factor1 = static_cast(1.f) / N; - const T factor3 = factor1 / (N * N); // 1 / (N*N*N) + if (N > cn) { + const T sumOfProd = (input * gradO) + .reduceNumber(reduce::Sum) + .e(0); // reduce to scalar + const T factor1 = static_cast(1.f) / N; + const T factor3 = factor1 / (N * N); // 1 / (N*N*N) + + auto lambda = LAMBDA_TT(elem1, elem2, cn, sumOfProd, factor1, factor3) { + return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd); + }; + + (const_cast(input)) + .applyPairwiseLambda(const_cast(gradO), lambda, gradI); + } else + gradI.assign(gradO); + } else { + auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions}); + auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions}); + auto inputSubArrs = input.allTensorsAlongDimension({dimensions}); - auto lambda = LAMBDA_TT(elem1, elem2, cn, sumOfProd, factor1, factor3) { - return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd); - }; + auto cn = clipNorm.e(0); - inputSubArr.applyPairwiseLambda(gradOSubArr, lambda, gradISubArr); - } else - gradISubArr.assign(gradOSubArr); - } - }; - samediff::Threads::parallel_tad(func, 0, gradISubArrs.size()); - } + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + T N = norm2.e(i); + + auto gradOSubArr = gradOSubArrs.at(i); + auto gradISubArr = gradISubArrs.at(i); + + if (N > cn) { + auto inputSubArr = inputSubArrs.at(i); + const T sumOfProd = (inputSubArr * gradOSubArr) + .reduceNumber(reduce::Sum) + .e(0); // reduce to scalar + const T factor1 = static_cast(1.f) / N; + const T factor3 = factor1 / (N * N); // 1 / (N*N*N) + + auto lambda = + LAMBDA_TT(elem1, elem2, cn, sumOfProd, factor1, factor3) { + return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd); + }; + + inputSubArr.applyPairwiseLambda(gradOSubArr, lambda, gradISubArr); + } else + gradISubArr.assign(gradOSubArr); + } + }; + samediff::Threads::parallel_tad(func, 0, gradISubArrs.size()); + } } - void clipByNormBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector& dimensions, const NDArray& clipNorm) { - BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBP_, (input, gradO, gradI, dimensions, clipNorm), FLOAT_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void clipByNormBP_, (const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector& dimensions, const NDArray& clipNorm), FLOAT_TYPES); +void clipByNormBP(sd::LaunchContext* context, const NDArray& input, + const NDArray& gradO, NDArray& gradI /*output*/, + const std::vector& dimensions, const NDArray& clipNorm) { + BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBP_, + (input, gradO, gradI, dimensions, clipNorm), + FLOAT_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void clipByNormBP_, + (const NDArray& input, const NDArray& gradO, + NDArray& gradI /*output*/, + const std::vector& dimensions, + const NDArray& clipNorm), + FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// -template -static void clipByAveraged_(NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { - - auto cn = clipNorm.e(0); - if (dimensions.size() == 0) { - // all-reduce - T n2 = input.reduceNumber(reduce::Norm2).e(0) / input.lengthOf(); - if (n2 <= cn) { - if (!isInplace) - output.assign(input); - } - else { - const T factor = cn / n2; - auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; - input.applyLambda(lambda, output); - } +template +static void clipByAveraged_(NDArray& input, NDArray& output, + const std::vector& dimensions, + const NDArray& clipNorm, const bool isInplace) { + auto cn = clipNorm.e(0); + if (dimensions.size() == 0) { + // all-reduce + T n2 = input.reduceNumber(reduce::Norm2).e(0) / input.lengthOf(); + if (n2 <= cn) { + if (!isInplace) output.assign(input); + } else { + const T factor = cn / n2; + auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; + input.applyLambda(lambda, output); } - else { - // along dimension - auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false); - if (!isInplace) - output.assign(input); - auto tads = output.allTensorsAlongDimension(dimensions); - // TODO: make this CUDA-compliant somehow - for (int e = 0; e < tads.size(); e++) { - T n2 = norm2.e(e) / tads.at(e).lengthOf(); - const T factor = cn / n2; - if (n2 > cn) { - auto lambda = LAMBDA_T(_x, factor) {return _x * factor;}; - tads.at(e).applyLambda(lambda, output); - } - } + } else { + // along dimension + auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false); + if (!isInplace) output.assign(input); + auto tads = output.allTensorsAlongDimension(dimensions); + // TODO: make this CUDA-compliant somehow + for (int e = 0; e < tads.size(); e++) { + T n2 = norm2.e(e) / tads.at(e).lengthOf(); + const T factor = cn / n2; + if (n2 > cn) { + auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; + tads.at(e).applyLambda(lambda, output); + } } + } } - void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { - BUILD_SINGLE_SELECTOR(input.dataType(), clipByAveraged_, (input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES); - } +void clipByAveraged(sd::LaunchContext* context, NDArray& input, NDArray& output, + const std::vector& dimensions, const NDArray& clipNorm, + const bool isInplace) { + BUILD_SINGLE_SELECTOR(input.dataType(), clipByAveraged_, + (input, output, dimensions, clipNorm, isInplace), + FLOAT_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void clipByAveraged_, (NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void clipByAveraged_, + (NDArray & input, NDArray& output, + const std::vector& dimensions, + const NDArray& clipNorm, const bool isInplace), + FLOAT_TYPES); /* if (d1 > params[1]) @@ -252,23 +272,29 @@ static void clipByAveraged_(NDArray& input, NDArray& output, const std::vector - static void clipByValue_(NDArray& input, double leftBound, double rightBound, NDArray& output) { - auto routine = LAMBDA_T(_x, leftBound, rightBound) { - if (_x > rightBound) return rightBound; - if (_x < leftBound) return leftBound; - return _x; - }; +template +static void clipByValue_(NDArray& input, double leftBound, double rightBound, + NDArray& output) { + auto routine = LAMBDA_T(_x, leftBound, rightBound) { + if (_x > rightBound) return rightBound; + if (_x < leftBound) return leftBound; + return _x; + }; - input.applyLambda(routine, output); - } + input.applyLambda(routine, output); +} - void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) { - BUILD_SINGLE_SELECTOR(input.dataType(), clipByValue_, (input, leftBound, rightBound, output), FLOAT_TYPES); - } +void clipByValue(sd::LaunchContext* context, NDArray& input, double leftBound, + double rightBound, NDArray& output) { + BUILD_SINGLE_SELECTOR(input.dataType(), clipByValue_, + (input, leftBound, rightBound, output), FLOAT_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void clipByValue_, (NDArray& input, double leftBound, double rightBound, NDArray& output);, FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void clipByValue_, + (NDArray & input, double leftBound, double rightBound, + NDArray& output); + , FLOAT_TYPES); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp b/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp index 42d4af529217..9f13b3af177e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp @@ -18,8 +18,8 @@ // Created by raver119 on 30.11.17. // -#include #include +#include namespace sd { namespace ops { @@ -27,115 +27,120 @@ namespace helpers { // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] template -void col2im_(sd::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW) { - - auto imBuff = output.bufferAsT(); - auto colBuff = input.bufferAsT(); - auto imShapeBuffer = output.shapeInfo(); - auto colShapeBuffer = input.shapeInfo(); - auto colShape = shape::shapeOf(colShapeBuffer); - auto colStride = shape::stride(colShapeBuffer); - auto imShape = shape::shapeOf(imShapeBuffer); - auto imStride = shape::stride(imShapeBuffer); - - const int bS = imShape[0]; - const int iC = imShape[1]; - const int kH = colShape[2]; - const int kW = colShape[3]; - const int oH = colShape[4]; - const int oW = colShape[5]; - const Nd4jLong colStride0 = colStride[0]; - const Nd4jLong colStride1 = colStride[1]; - const Nd4jLong colStride2 = colStride[2]; - const Nd4jLong colStride3 = colStride[3]; - const Nd4jLong colStride4 = colStride[4]; - const Nd4jLong colStride5 = colStride[5]; - const Nd4jLong imStride0 = imStride[0]; - const Nd4jLong imStride1 = imStride[1]; - const Nd4jLong imStride2 = imStride[2]; - const Nd4jLong imStride3 = imStride[3]; - - - // if (shape::order(colShapeBuffer) == 'c' && shape::order(imShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(colShapeBuffer) && shape::strideDescendingCAscendingF(imShapeBuffer)) { - if (false) { - - auto func = PRAGMA_THREADS_FOR_2D { - T const* col; - T* im; - - int imRow, imCol; - - for (auto b = start_x; b < stop_x; b += inc_x) { - for (auto c = start_y; c < stop_y; c += inc_y) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - - imRow = (-pH + kRow * dH) + colH * sH; - imCol = (-pW + kCol * dW) + colW * sW; - - col = colBuff + b * colStride0 + c * colStride1 + kRow * colStride2 + kCol * colStride3 + colH * colStride4 + colW * colStride5; - im = imBuff + b * imStride0 + c * imStride1 + imRow * imStride2 + imCol * imStride3; - - if (static_cast(imRow) < static_cast(iH) && - static_cast(imCol) < static_cast(iW)) - *im += *col; - } - } - } - } +void col2im_(sd::LaunchContext& context, const NDArray& input, NDArray& output, + const int sH, const int sW, const int pH, const int pW, + const int iH, const int iW, const int dH, const int dW) { + auto imBuff = output.bufferAsT(); + auto colBuff = input.bufferAsT(); + auto imShapeBuffer = output.shapeInfo(); + auto colShapeBuffer = input.shapeInfo(); + auto colShape = shape::shapeOf(colShapeBuffer); + auto colStride = shape::stride(colShapeBuffer); + auto imShape = shape::shapeOf(imShapeBuffer); + auto imStride = shape::stride(imShapeBuffer); + + const int bS = imShape[0]; + const int iC = imShape[1]; + const int kH = colShape[2]; + const int kW = colShape[3]; + const int oH = colShape[4]; + const int oW = colShape[5]; + const Nd4jLong colStride0 = colStride[0]; + const Nd4jLong colStride1 = colStride[1]; + const Nd4jLong colStride2 = colStride[2]; + const Nd4jLong colStride3 = colStride[3]; + const Nd4jLong colStride4 = colStride[4]; + const Nd4jLong colStride5 = colStride[5]; + const Nd4jLong imStride0 = imStride[0]; + const Nd4jLong imStride1 = imStride[1]; + const Nd4jLong imStride2 = imStride[2]; + const Nd4jLong imStride3 = imStride[3]; + + // if (shape::order(colShapeBuffer) == 'c' && shape::order(imShapeBuffer) == + // 'c' && shape::strideDescendingCAscendingF(colShapeBuffer) && + // shape::strideDescendingCAscendingF(imShapeBuffer)) { + if (false) { + auto func = PRAGMA_THREADS_FOR_2D { + T const* col; + T* im; + + int imRow, imCol; + + for (auto b = start_x; b < stop_x; b += inc_x) { + for (auto c = start_y; c < stop_y; c += inc_y) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + imRow = (-pH + kRow * dH) + colH * sH; + imCol = (-pW + kCol * dW) + colW * sW; + + col = colBuff + b * colStride0 + c * colStride1 + + kRow * colStride2 + kCol * colStride3 + + colH * colStride4 + colW * colStride5; + im = imBuff + b * imStride0 + c * imStride1 + + imRow * imStride2 + imCol * imStride3; + + if (static_cast(imRow) < + static_cast(iH) && + static_cast(imCol) < static_cast(iW)) + *im += *col; } + } } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } - else { - - auto func = PRAGMA_THREADS_FOR { - T *col, *im; - - for (auto b = start; b < stop; b++) { - T *im0 = imBuff + b * imStride0; - T const* col4 = colBuff + b * colStride0; - for (int colH = 0; colH < oH; ++colH, col4 += colStride4) { - T const* col5 = col4; - for (int colW = 0; colW < oW; ++colW, col5 += colStride5) { - T const* col1 = col5; - T *im1 = im0; - for (int c = 0; c < iC; ++c, col1 += colStride1, im1 += imStride1) { - int imRow = (-pH + colH * sH); - T const* col2 = col1; - T *im2 = im1 + imRow * imStride2; - for (int kRow = 0; - kRow < kH; ++kRow, col2 += colStride2, imRow += dH, im2 += dH * imStride2) { - int imCol = -pW + colW * sW; - T const* col3 = col2; - T *im3 = im2 + imCol * imStride3; - for (int kCol = 0; - kCol < kW; ++kCol, col3 += colStride3, imCol += dW, im3 += dW * imStride3) { - - if (static_cast(imRow) < static_cast(iH) && - static_cast(imCol) < static_cast(iW)) - *im3 += *col3; - } - } - } - } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } else { + auto func = PRAGMA_THREADS_FOR { + T *col, *im; + + for (auto b = start; b < stop; b++) { + T* im0 = imBuff + b * imStride0; + T const* col4 = colBuff + b * colStride0; + for (int colH = 0; colH < oH; ++colH, col4 += colStride4) { + T const* col5 = col4; + for (int colW = 0; colW < oW; ++colW, col5 += colStride5) { + T const* col1 = col5; + T* im1 = im0; + for (int c = 0; c < iC; ++c, col1 += colStride1, im1 += imStride1) { + int imRow = (-pH + colH * sH); + T const* col2 = col1; + T* im2 = im1 + imRow * imStride2; + for (int kRow = 0; kRow < kH; ++kRow, col2 += colStride2, + imRow += dH, im2 += dH * imStride2) { + int imCol = -pW + colW * sW; + T const* col3 = col2; + T* im3 = im2 + imCol * imStride3; + for (int kCol = 0; kCol < kW; ++kCol, col3 += colStride3, + imCol += dW, im3 += dW * imStride3) { + if (static_cast(imRow) < + static_cast(iH) && + static_cast(imCol) < static_cast(iW)) + *im3 += *col3; } + } } - }; + } + } + } + }; - samediff::Threads::parallel_tad(func, 0, bS); - } + samediff::Threads::parallel_tad(func, 0, bS); + } } - -void col2im(sd::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(input.dataType(), col2im_, (context, input, output, sH, sW, pH, pW, iH, iW, dH, dW), FLOAT_TYPES); +void col2im(sd::LaunchContext& context, const NDArray& input, NDArray& output, + const int sH, const int sW, const int pH, const int pW, + const int iH, const int iW, const int dH, const int dW) { + BUILD_SINGLE_SELECTOR( + input.dataType(), col2im_, + (context, input, output, sH, sW, pH, pW, iH, iW, dH, dW), FLOAT_TYPES); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compare_elem.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compare_elem.cpp index 32dc3d7c79b9..e33d89b7d302 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compare_elem.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compare_elem.cpp @@ -14,62 +14,65 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - static void _compare_elem(NDArray *input, bool isStrictlyIncreasing, bool& output) { - auto length = shape::length(input->shapeInfo()); - - int elementsPerThread = length / ELEMENT_THRESHOLD; - int num_threads = sd::math::nd4j_max(1, elementsPerThread); - num_threads = sd::math::nd4j_min(num_threads, omp_get_max_threads()); - Nd4jLong sumt = 0; - - if(isStrictlyIncreasing) { - //PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:sum) - auto func = PRAGMA_REDUCE_LONG { - Nd4jLong sum = 0; - for (auto i = start; i < stop; i++) { - auto val0 = input->t(i); - auto val1 = input->t(i + 1); - sum += val0 >= val1 ? -1 : 0; - } - return sum; - }; - sumt = samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1); - } else { - //PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:sum) - auto func = PRAGMA_REDUCE_LONG { - Nd4jLong sum = 0; - for (auto i = start; i < stop; i++) { - auto val0 = input->t(i); - auto val1 = input->t(i + 1); - sum += val0 > val1 ? -1 : 0; - } +namespace ops { +namespace helpers { +template +static void _compare_elem(NDArray* input, bool isStrictlyIncreasing, + bool& output) { + auto length = shape::length(input->shapeInfo()); - return sum; - }; - sumt = samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1); - } + int elementsPerThread = length / ELEMENT_THRESHOLD; + int num_threads = sd::math::nd4j_max(1, elementsPerThread); + num_threads = sd::math::nd4j_min(num_threads, omp_get_max_threads()); + Nd4jLong sumt = 0; - //nd4j_printf("Sum: %lld\n", sumt) + if (isStrictlyIncreasing) { + // PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:sum) + auto func = PRAGMA_REDUCE_LONG { + Nd4jLong sum = 0; + for (auto i = start; i < stop; i++) { + auto val0 = input->t(i); + auto val1 = input->t(i + 1); + sum += val0 >= val1 ? -1 : 0; + } + return sum; + }; + sumt = samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1); + } else { + // PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:sum) + auto func = PRAGMA_REDUCE_LONG { + Nd4jLong sum = 0; + for (auto i = start; i < stop; i++) { + auto val0 = input->t(i); + auto val1 = input->t(i + 1); + sum += val0 > val1 ? -1 : 0; + } - output = (sumt > -1); + return sum; + }; + sumt = samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, length - 1); + } - } + // nd4j_printf("Sum: %lld\n", sumt) - void compare_elem(sd::LaunchContext * context, NDArray *input, bool isStrictlyIncreasing, bool& output) { - auto xType = input->dataType(); - - BUILD_SINGLE_SELECTOR(xType, _compare_elem, (input, isStrictlyIncreasing, output), LIBND4J_TYPES); - } + output = (sumt > -1); +} +void compare_elem(sd::LaunchContext* context, NDArray* input, + bool isStrictlyIncreasing, bool& output) { + auto xType = input->dataType(); - BUILD_SINGLE_TEMPLATE(template void _compare_elem, (NDArray *A, bool isStrictlyIncreasing, bool& output);, LIBND4J_TYPES); - } - } + BUILD_SINGLE_SELECTOR(xType, _compare_elem, + (input, isStrictlyIncreasing, output), LIBND4J_TYPES); } + +BUILD_SINGLE_TEMPLATE(template void _compare_elem, + (NDArray * A, bool isStrictlyIncreasing, bool& output); + , LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_0.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_0.cpp index 94e74cd84786..2f492a7d28ad 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_0.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_0.cpp @@ -19,12 +19,17 @@ // #include + #include "../crop_and_resize.hpp" namespace sd { - namespace ops { - namespace helpers { - BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, (NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops), NUMERIC_TYPES_0, FLOAT_TYPES, INTEGER_TYPES); - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { +BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, + (NDArray const *images, NDArray const *boxes, + NDArray const *indices, NDArray const *cropSize, + int method, double extrapolationVal, NDArray *crops), + NUMERIC_TYPES_0, FLOAT_TYPES, INTEGER_TYPES); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_1.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_1.cpp index 9820c1392204..8ca7c4e41ae7 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_1.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_1.cpp @@ -19,12 +19,17 @@ // #include + #include "../crop_and_resize.hpp" namespace sd { - namespace ops { - namespace helpers { - BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, (NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops), NUMERIC_TYPES_1, FLOAT_TYPES, INTEGER_TYPES); - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { +BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, + (NDArray const *images, NDArray const *boxes, + NDArray const *indices, NDArray const *cropSize, + int method, double extrapolationVal, NDArray *crops), + NUMERIC_TYPES_1, FLOAT_TYPES, INTEGER_TYPES); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_2.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_2.cpp index 2a78f285f75d..60f2c9c4eef6 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_2.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_2.cpp @@ -19,12 +19,17 @@ // #include + #include "../crop_and_resize.hpp" namespace sd { - namespace ops { - namespace helpers { - BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, (NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops), NUMERIC_TYPES_2, FLOAT_TYPES, INTEGER_TYPES); - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { +BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, + (NDArray const *images, NDArray const *boxes, + NDArray const *indices, NDArray const *cropSize, + int method, double extrapolationVal, NDArray *crops), + NUMERIC_TYPES_2, FLOAT_TYPES, INTEGER_TYPES); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_3.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_3.cpp index 13757997a522..0b5f1ba7f3a6 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_3.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_3.cpp @@ -19,12 +19,17 @@ // #include + #include "../crop_and_resize.hpp" namespace sd { - namespace ops { - namespace helpers { - BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, (NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops), NUMERIC_TYPES_3, FLOAT_TYPES, INTEGER_TYPES); - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { +BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, + (NDArray const *images, NDArray const *boxes, + NDArray const *indices, NDArray const *cropSize, + int method, double extrapolationVal, NDArray *crops), + NUMERIC_TYPES_3, FLOAT_TYPES, INTEGER_TYPES); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_4.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_4.cpp index ea3043eeb384..23390c9b0a5c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_4.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_4.cpp @@ -19,12 +19,17 @@ // #include + #include "../crop_and_resize.hpp" namespace sd { - namespace ops { - namespace helpers { - BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, (NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops), NUMERIC_TYPES_4, FLOAT_TYPES, INTEGER_TYPES); - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { +BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, + (NDArray const *images, NDArray const *boxes, + NDArray const *indices, NDArray const *cropSize, + int method, double extrapolationVal, NDArray *crops), + NUMERIC_TYPES_4, FLOAT_TYPES, INTEGER_TYPES); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_5.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_5.cpp index 60c1ae906101..747a64a31c1f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_5.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_5.cpp @@ -19,12 +19,17 @@ // #include + #include "../crop_and_resize.hpp" namespace sd { - namespace ops { - namespace helpers { - BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, (NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops), NUMERIC_TYPES_5, FLOAT_TYPES, INTEGER_TYPES); - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { +BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, + (NDArray const *images, NDArray const *boxes, + NDArray const *indices, NDArray const *cropSize, + int method, double extrapolationVal, NDArray *crops), + NUMERIC_TYPES_5, FLOAT_TYPES, INTEGER_TYPES); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_6.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_6.cpp index 6e33d5546f23..4253fd621e33 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_6.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_6.cpp @@ -19,12 +19,17 @@ // #include + #include "../crop_and_resize.hpp" namespace sd { - namespace ops { - namespace helpers { - BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, (NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops), NUMERIC_TYPES_6, FLOAT_TYPES, INTEGER_TYPES); - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { +BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, + (NDArray const *images, NDArray const *boxes, + NDArray const *indices, NDArray const *cropSize, + int method, double extrapolationVal, NDArray *crops), + NUMERIC_TYPES_6, FLOAT_TYPES, INTEGER_TYPES); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_7.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_7.cpp index ef4a199fd2cb..1dc0f49f1356 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_7.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_7.cpp @@ -19,12 +19,17 @@ // #include + #include "../crop_and_resize.hpp" namespace sd { - namespace ops { - namespace helpers { - BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, (NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops), NUMERIC_TYPES_7, FLOAT_TYPES, INTEGER_TYPES); - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { +BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, + (NDArray const *images, NDArray const *boxes, + NDArray const *indices, NDArray const *cropSize, + int method, double extrapolationVal, NDArray *crops), + NUMERIC_TYPES_7, FLOAT_TYPES, INTEGER_TYPES); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_8.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_8.cpp index 71cd2ebb87b7..2d2ea0ab3ac8 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_8.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_8.cpp @@ -19,12 +19,17 @@ // #include + #include "../crop_and_resize.hpp" namespace sd { - namespace ops { - namespace helpers { - BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, (NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops), NUMERIC_TYPES_8, FLOAT_TYPES, INTEGER_TYPES); - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { +BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, + (NDArray const *images, NDArray const *boxes, + NDArray const *indices, NDArray const *cropSize, + int method, double extrapolationVal, NDArray *crops), + NUMERIC_TYPES_8, FLOAT_TYPES, INTEGER_TYPES); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_9.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_9.cpp index e9db5c303cf7..99de466e6027 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_9.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_9.cpp @@ -19,12 +19,17 @@ // #include + #include "../crop_and_resize.hpp" namespace sd { - namespace ops { - namespace helpers { - BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, (NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops), NUMERIC_TYPES_9, FLOAT_TYPES, INTEGER_TYPES); - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { +BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, + (NDArray const *images, NDArray const *boxes, + NDArray const *indices, NDArray const *cropSize, + int method, double extrapolationVal, NDArray *crops), + NUMERIC_TYPES_9, FLOAT_TYPES, INTEGER_TYPES); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compression/compression.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compression/compression.cpp index 0911b0619478..53eade55f857 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compression/compression.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compression/compression.cpp @@ -17,21 +17,25 @@ // // @author sgazeos@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { - void decodeBitmap(sd::LaunchContext* context, const NDArray* input, NDArray* output) { - NativeOpExecutioner::decodeBitmap(input->buffer(), output->lengthOf(), output->buffer(), output->shapeInfo()); - } - - - Nd4jLong encodeBitmap(sd::LaunchContext* context, NDArray* input, NDArray* output, float threshold) { - return NativeOpExecutioner::encodeBitmap(input->buffer(), input->shapeInfo(), input->lengthOf(), output->bufferAsT(), threshold); - } -} +void decodeBitmap(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + NativeOpExecutioner::decodeBitmap(input->buffer(), output->lengthOf(), + output->buffer(), output->shapeInfo()); } + +Nd4jLong encodeBitmap(sd::LaunchContext* context, NDArray* input, + NDArray* output, float threshold) { + return NativeOpExecutioner::encodeBitmap(input->buffer(), input->shapeInfo(), + input->lengthOf(), + output->bufferAsT(), threshold); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compression/threshold.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compression/threshold.cpp index bac3812d1739..33f2cffb711f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compression/threshold.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compression/threshold.cpp @@ -18,45 +18,53 @@ // @author raver119@gmail.com // -#include #include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - static int32_t thresholdEstimate_(const NDArray &updates, const float threshold) { - auto N = updates.lengthOf(); - const auto buffer = updates.bufferAsT(); - - auto func = PRAGMA_REDUCE_LONG { - int64_t cnt = 0; - for (auto e = start; e < stop; e++) { - auto v = sd::math::nd4j_abs(buffer[e]); - if (v >= threshold) - cnt++; - } - - return cnt; - }; - - return samediff::Threads::parallel_long(func, LAMBDA_AL { return _old + _new; }, 0, N); - } - - int32_t thresholdEstimate(const NDArray &updates, const float threshold) { - BUILD_SINGLE_SELECTOR(updates.dataType(), return thresholdEstimate_, (updates, threshold), FLOAT_TYPES); - - return 0; - } - - void thresholdEncode(NDArray &updates, NDArray &encoded, float threshold) { - BUILD_SINGLE_SELECTOR(updates.dataType(), sd::TypeCast::convertToThreshold, (nullptr, updates.buffer(), updates.lengthOf(), encoded.buffer()), FLOAT_TYPES); - } - - void thresholdDecode(const NDArray &encoded, NDArray &updates) { - BUILD_SINGLE_SELECTOR(updates.dataType(), sd::TypeCast::convertFromThreshold, (nullptr, encoded.buffer(), updates.lengthOf(), updates.buffer()), FLOAT_TYPES); - } - } +namespace ops { +namespace helpers { +template +static int32_t thresholdEstimate_(const NDArray &updates, + const float threshold) { + auto N = updates.lengthOf(); + const auto buffer = updates.bufferAsT(); + + auto func = PRAGMA_REDUCE_LONG { + int64_t cnt = 0; + for (auto e = start; e < stop; e++) { + auto v = sd::math::nd4j_abs(buffer[e]); + if (v >= threshold) cnt++; } + + return cnt; + }; + + return samediff::Threads::parallel_long( + func, LAMBDA_AL { return _old + _new; }, 0, N); +} + +int32_t thresholdEstimate(const NDArray &updates, const float threshold) { + BUILD_SINGLE_SELECTOR(updates.dataType(), return thresholdEstimate_, + (updates, threshold), FLOAT_TYPES); + + return 0; +} + +void thresholdEncode(NDArray &updates, NDArray &encoded, float threshold) { + BUILD_SINGLE_SELECTOR( + updates.dataType(), sd::TypeCast::convertToThreshold, + (nullptr, updates.buffer(), updates.lengthOf(), encoded.buffer()), + FLOAT_TYPES); +} + +void thresholdDecode(const NDArray &encoded, NDArray &updates) { + BUILD_SINGLE_SELECTOR( + updates.dataType(), sd::TypeCast::convertFromThreshold, + (nullptr, encoded.buffer(), updates.lengthOf(), updates.buffer()), + FLOAT_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/concat.cpp b/libnd4j/include/ops/declarable/helpers/cpu/concat.cpp index 9fd2b5b02381..cdb74c618ca0 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/concat.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/concat.cpp @@ -18,24 +18,30 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - #include #include namespace sd { - namespace ops { - namespace helpers { - ////////////////////////////////////////////////////////////////////////// - template - static void concat_(const std::vector& inArrs, NDArray& output, const int axis) { - sd::SpecialMethods::concatCpuGeneric(inArrs, output, axis); - } +namespace ops { +namespace helpers { +////////////////////////////////////////////////////////////////////////// +template +static void concat_(const std::vector& inArrs, NDArray& output, + const int axis) { + sd::SpecialMethods::concatCpuGeneric(inArrs, output, axis); +} - void concat(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int axis) { - BUILD_SINGLE_SELECTOR(output.dataType(), concat_,(inArrs, output, axis), LIBND4J_TYPES); - } +void concat(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output, + const int axis) { + BUILD_SINGLE_SELECTOR(output.dataType(), concat_, (inArrs, output, axis), + LIBND4J_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void concat_, (const std::vector& inArrs, NDArray& output, const int axis), LIBND4J_TYPES); - } - } -} \ No newline at end of file +BUILD_SINGLE_TEMPLATE(template void concat_, + (const std::vector& inArrs, + NDArray& output, const int axis), + LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp b/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp index b524dc1eacb2..f4ec0c667dc1 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/confusion.cpp @@ -18,39 +18,44 @@ // @author GS // -#include #include - +#include namespace sd { namespace ops { namespace helpers { - template - void _confusionFunctor(NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) { - ResultSet arrs = output->allTensorsAlongDimension({1}); - int lLen = labels->lengthOf(); - - auto func = PRAGMA_THREADS_FOR { - for (int j = start; j < stop; j++) { - auto label = labels->e(j); - auto pred = predictions->e(j); - T value = (weights == nullptr ? (T) 1.0f : weights->e(j)); - arrs.at(label).p(pred, value); - } - }; - - samediff::Threads::parallel_for(func, 0, lLen); +template +void _confusionFunctor(NDArray* labels, NDArray* predictions, NDArray* weights, + NDArray* output) { + ResultSet arrs = output->allTensorsAlongDimension({1}); + int lLen = labels->lengthOf(); + + auto func = PRAGMA_THREADS_FOR { + for (int j = start; j < stop; j++) { + auto label = labels->e(j); + auto pred = predictions->e(j); + T value = (weights == nullptr ? (T)1.0f : weights->e(j)); + arrs.at(label).p(pred, value); } + }; - void confusionFunctor(sd::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) { - auto xType = output->dataType(); // weights can be null - - BUILD_SINGLE_SELECTOR(xType, _confusionFunctor, (labels, predictions, weights, output), NUMERIC_TYPES); - } + samediff::Threads::parallel_for(func, 0, lLen); +} - BUILD_SINGLE_TEMPLATE(template void _confusionFunctor, (NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output);, NUMERIC_TYPES); +void confusionFunctor(sd::LaunchContext* context, NDArray* labels, + NDArray* predictions, NDArray* weights, NDArray* output) { + auto xType = output->dataType(); // weights can be null + BUILD_SINGLE_SELECTOR(xType, _confusionFunctor, + (labels, predictions, weights, output), NUMERIC_TYPES); } -} -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template void _confusionFunctor, + (NDArray * labels, NDArray* predictions, NDArray* weights, + NDArray* output); + , NUMERIC_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_col2vol.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_col2vol.cpp index b12064cacda8..2715fb0ae539 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_col2vol.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_col2vol.cpp @@ -18,126 +18,150 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// // [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] template -static void col2vol_(const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - // initial zeroing of volume content - volume.nullify(); - - const int bS = volume.sizeAt(0); - const int iC = volume.sizeAt(1); - const int iD = volume.sizeAt(2); - const int iH = volume.sizeAt(3); - const int iW = volume.sizeAt(4); - const int kD = columns.sizeAt(2); - const int kH = columns.sizeAt(3); - const int kW = columns.sizeAt(4); - const int oD = columns.sizeAt(5); - const int oH = columns.sizeAt(6); - const int oW = columns.sizeAt(7); - const Nd4jLong colStride0 = columns.stridesOf()[0]; - const Nd4jLong colStride1 = columns.stridesOf()[1]; - const Nd4jLong colStride2 = columns.stridesOf()[2]; - const Nd4jLong colStride3 = columns.stridesOf()[3]; - const Nd4jLong colStride4 = columns.stridesOf()[4]; - const Nd4jLong colStride5 = columns.stridesOf()[5]; - const Nd4jLong colStride6 = columns.stridesOf()[6]; - const Nd4jLong colStride7 = columns.stridesOf()[7]; - const Nd4jLong volStride0 = volume.stridesOf()[0]; - const Nd4jLong volStride1 = volume.stridesOf()[1]; - const Nd4jLong volStride2 = volume.stridesOf()[2]; - const Nd4jLong volStride3 = volume.stridesOf()[3]; - const Nd4jLong volStride4 = volume.stridesOf()[4]; - - T* volBuff = volume.bufferAsT(); - T* colBuff = const_cast(columns).bufferAsT(); - - - if (volume.ordering() == 'c' && columns.ordering() == 'c' && shape::strideDescendingCAscendingF(volume.shapeInfo()) && shape::strideDescendingCAscendingF(columns.shapeInfo())) { - - auto func = PRAGMA_THREADS_FOR { - T* col, *vol; - int volDep, volRow, volCol; - - for (int b = start; b < stop; b++) { - for (int c = 0; c < iC; c++) { - for (int kDep = 0; kDep < kD; ++kDep) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - for (int colD = 0; colD < oD; ++colD) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - - volDep = -pD + kDep * dD + colD * sD; - volRow = -pH + kRow * dH + colH * sH; - volCol = -pW + kCol * dW + colW * sW; - - if (static_cast(volDep) < static_cast(iD) && static_cast(volRow) < static_cast(iH) && static_cast(volCol) < static_cast(iW)) { - col = colBuff + b * colStride0 + c * colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; - vol = volBuff + b * volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; - *vol += *col; - } - } - } - } - } - } - } - } +static void col2vol_(const NDArray& columns, NDArray& volume, const int sD, + const int sH, const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, const int dW) { + // initial zeroing of volume content + volume.nullify(); + + const int bS = volume.sizeAt(0); + const int iC = volume.sizeAt(1); + const int iD = volume.sizeAt(2); + const int iH = volume.sizeAt(3); + const int iW = volume.sizeAt(4); + const int kD = columns.sizeAt(2); + const int kH = columns.sizeAt(3); + const int kW = columns.sizeAt(4); + const int oD = columns.sizeAt(5); + const int oH = columns.sizeAt(6); + const int oW = columns.sizeAt(7); + const Nd4jLong colStride0 = columns.stridesOf()[0]; + const Nd4jLong colStride1 = columns.stridesOf()[1]; + const Nd4jLong colStride2 = columns.stridesOf()[2]; + const Nd4jLong colStride3 = columns.stridesOf()[3]; + const Nd4jLong colStride4 = columns.stridesOf()[4]; + const Nd4jLong colStride5 = columns.stridesOf()[5]; + const Nd4jLong colStride6 = columns.stridesOf()[6]; + const Nd4jLong colStride7 = columns.stridesOf()[7]; + const Nd4jLong volStride0 = volume.stridesOf()[0]; + const Nd4jLong volStride1 = volume.stridesOf()[1]; + const Nd4jLong volStride2 = volume.stridesOf()[2]; + const Nd4jLong volStride3 = volume.stridesOf()[3]; + const Nd4jLong volStride4 = volume.stridesOf()[4]; + + T* volBuff = volume.bufferAsT(); + T* colBuff = const_cast(columns).bufferAsT(); + + if (volume.ordering() == 'c' && columns.ordering() == 'c' && + shape::strideDescendingCAscendingF(volume.shapeInfo()) && + shape::strideDescendingCAscendingF(columns.shapeInfo())) { + auto func = PRAGMA_THREADS_FOR { + T *col, *vol; + int volDep, volRow, volCol; + + for (int b = start; b < stop; b++) { + for (int c = 0; c < iC; c++) { + for (int kDep = 0; kDep < kD; ++kDep) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + for (int colD = 0; colD < oD; ++colD) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + volDep = -pD + kDep * dD + colD * sD; + volRow = -pH + kRow * dH + colH * sH; + volCol = -pW + kCol * dW + colW * sW; + + if (static_cast(volDep) < + static_cast(iD) && + static_cast(volRow) < + static_cast(iH) && + static_cast(volCol) < + static_cast(iW)) { + col = colBuff + b * colStride0 + c * colStride1 + + kDep * colStride2 + kRow * colStride3 + + kCol * colStride4 + colD * colStride5 + + colH * colStride6 + colW * colStride7; + vol = volBuff + b * volStride0 + c * volStride1 + + volDep * volStride2 + volRow * volStride3 + + volCol * volStride4; + *vol += *col; + } } - }; - - samediff::Threads::parallel_tad(func, 0, bS); - - } else { - - auto func = PRAGMA_THREADS_FOR { - T* col, *vol; - int volDep, volRow, volCol; - - for (int b = start; b < stop; b++) { - for (int colD = 0; colD < oD; colD++) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - for (int c = 0; c < iC; ++c) { - for (int kDep = 0; kDep < kD; ++kDep) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - - volDep = (-pD + kDep * dD) + colD * sD; - volRow = (-pH + kRow * dH) + colH * sH; - volCol = (-pW + kCol * dW) + colW * sW; - - if (static_cast(volDep) < static_cast(iD) && static_cast(volRow) < static_cast(iH) && static_cast(volCol) < static_cast(iW)) { - col = colBuff + b * colStride0 + c * colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; - vol = volBuff + b * volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; - *vol += *col; - } - } - } - } - } - } - } - } + } + } + } + } + } + } + } + }; + + samediff::Threads::parallel_tad(func, 0, bS); + + } else { + auto func = PRAGMA_THREADS_FOR { + T *col, *vol; + int volDep, volRow, volCol; + + for (int b = start; b < stop; b++) { + for (int colD = 0; colD < oD; colD++) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + for (int c = 0; c < iC; ++c) { + for (int kDep = 0; kDep < kD; ++kDep) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + volDep = (-pD + kDep * dD) + colD * sD; + volRow = (-pH + kRow * dH) + colH * sH; + volCol = (-pW + kCol * dW) + colW * sW; + + if (static_cast(volDep) < + static_cast(iD) && + static_cast(volRow) < + static_cast(iH) && + static_cast(volCol) < + static_cast(iW)) { + col = colBuff + b * colStride0 + c * colStride1 + + kDep * colStride2 + kRow * colStride3 + + kCol * colStride4 + colD * colStride5 + + colH * colStride6 + colW * colStride7; + vol = volBuff + b * volStride0 + c * volStride1 + + volDep * volStride2 + volRow * volStride3 + + volCol * volStride4; + *vol += *col; + } } - }; - - samediff::Threads::parallel_tad(func, 0, bS); + } + } + } } + } } + } + }; -void ConvolutionUtils::col2vol(sd::graph::Context& block, const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); + samediff::Threads::parallel_tad(func, 0, bS); + } } +void ConvolutionUtils::col2vol(sd::graph::Context& block, + const NDArray& columns, NDArray& volume, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW) { + BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, + (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), + FLOAT_TYPES); } -} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2d.cpp index 45e66651c19c..cd05df093c7e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2d.cpp @@ -18,90 +18,115 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include -#include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // bias [oC] - // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - nd4j_debug("MKL-DNN is not used for conv2d!\n", 0); - - std::vector permutForOutput; - - if(isNCHW) - permutForOutput = {0, 3, 1, 2}; // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - else - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC - - std::vector wAxes; - if(0 == wFormat) - wAxes = {0, 1, 2}; - else if(1 == wFormat) - wAxes = {2, 3, 1}; - else - wAxes = {1, 2, 3}; - - NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); - NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} - NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); - - //----- calculation of output -----// - auto ctx = block.launchContext(); - helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] - - //----- assign outTemp to output -----// - if(isNCHW) { - mmulResult.reshapei({bS, oH, oW, oC}); - mmulResult.permutei(permutForOutput); - } - output->assign(mmulResult); - - //----- add biases if required -----// - if(bias) - // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCHW); - - if(!isNCHW) - delete input; - - } - -void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +static void conv2d_(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // bias [oC] + // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + nd4j_debug("MKL-DNN is not used for conv2d!\n", 0); + + std::vector permutForOutput; + + if (isNCHW) + permutForOutput = {0, 3, 1, 2}; // [bS, oH, oW, oC] -> [bS, oC, oH, oW] + else + input = new NDArray(input->permute( + {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC + + std::vector wAxes; + if (0 == wFormat) + wAxes = {0, 1, 2}; + else if (1 == wFormat) + wAxes = {2, 3, 1}; + else + wAxes = {1, 2, 3}; + + NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), + input->getContext()); + NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} + NDArray mmulResult('f', {bS * oH * oW, oC}, output->dataType(), + output->getContext()); + + //----- calculation of output -----// + auto ctx = block.launchContext(); + helpers::im2col( + *ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, + // iC, kH, kW, oH, oW] + MmulHelper::tensorDot( + &col, weights, &mmulResult, {3, 4, 5}, wAxes, + {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] + + //----- assign outTemp to output -----// + if (isNCHW) { + mmulResult.reshapei({bS, oH, oW, oC}); + mmulResult.permutei(permutForOutput); + } + output->assign(mmulResult); + + //----- add biases if required -----// + if (bias) + // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCHW); + + if (!isNCHW) delete input; } +void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), conv2d_, + (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, + paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2dBP.cpp index 6a01a4a4dda5..9c367713d961 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2dBP.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_conv2dBP.cpp @@ -18,110 +18,144 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include -#include -#include #include -#include #include +#include +#include +#include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // bias [oC] - // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // gradB [oC] - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NHWC, 1-NCHW - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0); - - std::vector gradOaxesForDot; - - if(!isNCHW) { - gradOaxesForDot = {0, 1, 2}; // bS, oH, oW - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - } else { - gradOaxesForDot = {0, 2, 3}; // bS, oH, oW - } - - std::vector wPermut, colPermut; - - if(0 == wFormat) { - wPermut = {2, 0, 1, 3}; - colPermut = {2, 3, 1, 0, 4, 5}; - } - else if(1 == wFormat) { - wPermut = {1, 2, 3, 0}; - colPermut = {1, 2, 3, 0, 4, 5}; - } - else { - wPermut = {3, 1, 2, 0}; - colPermut = {2, 3, 1, 0, 4, 5}; - } - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - - // ----- calculation of gradW ----- // - if(gradW) { - auto ctx = block.launchContext(); - helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] - } - - // ----- calculation of gradB ----- // - if(gradB) { - NDArray* gradBR = gradB; - if(gradB->rankOf() == 2) - gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot); // sum over bS, oH, oW - if(gradBR != gradB) - delete gradBR; - } - - //----- calculation of gradI -----// - // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - // [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW] - // [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); - - helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] - - if(!isNCHW) { - delete input; - delete gradI; - } - } - -void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +static void conv2dBP_(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + const NDArray* gradO, NDArray* gradI, NDArray* gradW, + NDArray* gradB, const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // bias [oC] + // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + + // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // gradB [oC] + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NHWC, 1-NCHW + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0); + + std::vector gradOaxesForDot; + + if (!isNCHW) { + gradOaxesForDot = {0, 1, 2}; // bS, oH, oW + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray( + gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + } else { + gradOaxesForDot = {0, 2, 3}; // bS, oH, oW + } + + std::vector wPermut, colPermut; + + if (0 == wFormat) { + wPermut = {2, 0, 1, 3}; + colPermut = {2, 3, 1, 0, 4, 5}; + } else if (1 == wFormat) { + wPermut = {1, 2, 3, 0}; + colPermut = {1, 2, 3, 0, 4, 5}; + } else { + wPermut = {3, 1, 2, 0}; + colPermut = {2, 3, 1, 0, 4, 5}; + } + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, + input->dataType(), input->getContext()); + + // ----- calculation of gradW ----- // + if (gradW) { + auto ctx = block.launchContext(); + helpers::im2col( + *ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to + // [bS, iC, kH, kW, oH, oW] + sd::MmulHelper::tensorDot( + &columns, gradO, gradW, {0, 4, 5}, gradOaxesForDot, + wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, + // oW] = [iC, kH, kW, oC] + } + + // ----- calculation of gradB ----- // + if (gradB) { + NDArray* gradBR = gradB; + if (gradB->rankOf() == 2) + gradBR = new NDArray( + gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); + gradO->reduceAlongDimension(reduce::Sum, *gradBR, + gradOaxesForDot); // sum over bS, oH, oW + if (gradBR != gradB) delete gradBR; + } + + //----- calculation of gradI -----// + // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, + // oW] [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, + // oH, oW] [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, + // bS, oH, oW] + sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, + colPermut); + + helpers::col2im( + *block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, + dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + + if (!isNCHW) { + delete input; + delete gradI; + } } +void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + const NDArray* gradO, NDArray* gradI, + NDArray* gradW, NDArray* gradB, const int kH, + const int kW, const int sH, const int sW, + int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), conv2dBP_, + (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, + pH, pW, dH, dW, paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} \ No newline at end of file + +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2d.cpp index fa86dbd6c199..6be9e4014710 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2d.cpp @@ -18,84 +18,119 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include -#include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // bias [oC] = iC*mC - // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NCHW, 1-NHWC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] - std::vector> modifOutput, modifWeights; - std::vector outReShape; - - if(!isNCHW) { - outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] - modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - } - else { - outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] - modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - } - - if(0 == wFormat) - modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; - else if(1 == wFormat) - modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; - else - modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; - - if(paddingMode == 1) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); - - helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] - - if(bias) - // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCHW); - - if(!isNCHW) - delete input; - } - -void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); - } - +static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // bias [oC] = iC*mC + // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NCHW, 1-NHWC + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector> modifColumns = { + {1, 0, 4, 5, 2, 3}, + {iC, bS * oH * oW, + kH * kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> + // [iC,bS*oH*oW,kH*kW] + std::vector> modifOutput, modifWeights; + std::vector outReShape; + + if (!isNCHW) { + outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] + modifOutput = { + {3, 0, 1, 2, 4}, + {iC, bS * oH * oW, + mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + } else { + outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] + modifOutput = { + {1, 0, 3, 4, 2}, + {iC, bS * oH * oW, + mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + } + + if (0 == wFormat) + modifWeights = {{2, 0, 1, 3}, {iC, kH * kW, mC}}; + else if (1 == wFormat) + modifWeights = {{1, 2, 3, 0}, {iC, kH * kW, mC}}; + else + modifWeights = {{3, 1, 2, 0}, {iC, kH * kW, mC}}; + + if (paddingMode == 1) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, + input->dataType(), input->getContext()); + NDArray outputReshaped = + output->reshape(output->ordering(), outReShape, false); + + helpers::im2col( + *output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, + // iC, kH, kW, oH, oW] + MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, + modifWeights, + modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, + // mC] = [iC, bS*oH*oW, mC] + + if (bias) + // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCHW); + + if (!isNCHW) delete input; } + +void ConvolutionUtils::depthwiseConv2d( + sd::graph::Context& block, const NDArray* input, const NDArray* weights, + const NDArray* bias, NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), depthwiseConv2d_, + (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, + paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2dBP.cpp index 7c0d933e235a..5b704ce7f464 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2dBP.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_depthwiseConv2dBP.cpp @@ -18,103 +18,155 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // +#include +#include +#include #include #include -#include -#include -#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // bias [oC] = [iC*mC] - // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // gradB [oC] - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NHWC, 1-NCHW - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] - std::vector> modifGradO1, modifGradO2, modifWeights; - std::vector gradOreShape; - - if(!isNCHW) { - gradOreShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] - modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - } - else { - gradOreShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] - modifGradO1 = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] - } - - if(0 == wFormat) - modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; - else if(1 == wFormat) - modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; - else - modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; - - if(paddingMode == 1) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - NDArray gradOreshaped = gradO->reshape(gradO->ordering(), gradOreShape); - - // ----- calculation of gradW and gradB ----- // - - helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] - - // ----- calculation of gradB ----- // - if(gradB) { - NDArray* gradBR = gradB; - if(gradB->rankOf() == 2) - gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); - gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW - - if(gradBR != gradB) - delete gradBR; - } - - //----- calculation of gradI -----// - sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] - helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] - - if(!isNCHW) { - delete input; - delete gradI; - } - } - -void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); - } +static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, + const NDArray* bias, const NDArray* gradO, + NDArray* gradI, NDArray* gradW, NDArray* gradB, + const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, + const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // bias [oC] = [iC*mC] + // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next + // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // gradB [oC] + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NHWC, 1-NCHW + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector> modifColumns = { + {1, 2, 3, 0, 4, 5}, + {iC, kH * kW, + bS * oH * oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] + std::vector> modifGradO1, modifGradO2, modifWeights; + std::vector gradOreShape; + + if (!isNCHW) { + gradOreShape = {bS, oH, oW, iC, + mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] + modifGradO1 = { + {3, 0, 1, 2, 4}, + {iC, bS * oH * oW, + mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + modifGradO2 = { + {3, 0, 1, 2}, + {iC, mC, + bS * oH * + oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + gradI = new NDArray( + gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + } else { + gradOreShape = {bS, iC, mC, oH, + oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] + modifGradO1 = { + {1, 0, 3, 4, 2}, + {iC, bS * oH * oW, + mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + modifGradO2 = { + {1, 0, 2, 3}, + {iC, mC, + bS * oH * + oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] + } + + if (0 == wFormat) + modifWeights = {{2, 0, 1, 3}, {iC, kH * kW, mC}}; + else if (1 == wFormat) + modifWeights = {{1, 2, 3, 0}, {iC, kH * kW, mC}}; + else + modifWeights = {{3, 1, 2, 0}, {iC, kH * kW, mC}}; + + if (paddingMode == 1) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, + input->dataType(), input->getContext()); + NDArray gradOreshaped = gradO->reshape(gradO->ordering(), gradOreShape); + + // ----- calculation of gradW and gradB ----- // + + helpers::im2col( + *input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, + // iC, kH, kW, oH, oW] + sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, + modifGradO1, + modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, + // bS*oH*oW, mC] = [iC, kH*kW, mC] + + // ----- calculation of gradB ----- // + if (gradB) { + NDArray* gradBR = gradB; + if (gradB->rankOf() == 2) + gradBR = new NDArray( + gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); + gradO->reduceAlongDimension( + reduce::Sum, *gradBR, {0, indOoH, indOoH + 1}); // sum over bS, oH, oW + + if (gradBR != gradB) delete gradBR; + } + + //----- calculation of gradI -----// + sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, + modifColumns); // [iC, kH*kW, mC] x [iC, mC, + // bS*oH*oW] = [iC, kW*kH, bS*oH*oW] + helpers::col2im( + *input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, + dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + + if (!isNCHW) { + delete input; + delete gradI; + } +} +void ConvolutionUtils::depthwiseConv2dBP( + sd::graph::Context& block, const NDArray* input, const NDArray* weights, + const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, + NDArray* gradB, const int kH, const int kW, const int sH, const int sW, + int pH, int pW, const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), depthwiseConv2dBP_, + (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, + dH, dW, paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} \ No newline at end of file + +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp index 26dc4f99eae3..7653d985624f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp @@ -18,206 +18,248 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// - template - static void pooling2d_(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - // input is [bS, iC, iH, iW] - // output is [bS, iC, oH, oW] - T* out = output.bufferAsT(); - T* in = const_cast(input).bufferAsT(); - - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); - - const int bS = input.sizeAt(0); - const int iC = input.sizeAt(1); - const int iH = input.sizeAt(2); - const int iW = input.sizeAt(3); - const int oC = output.sizeAt(1); - const int oH = output.sizeAt(2); - const int oW = output.sizeAt(3); - - nd4j_debug("MKL-DNN is not used for pooling2d!\n", 0); - - const Nd4jLong iStride0 = input.stridesOf()[0]; - const Nd4jLong iStride1 = input.stridesOf()[1]; - const Nd4jLong iStride2 = input.stridesOf()[2]; - const Nd4jLong iStride3 = input.stridesOf()[3]; - const Nd4jLong oStride0 = output.stridesOf()[0]; - const Nd4jLong oStride1 = output.stridesOf()[1]; - const Nd4jLong oStride2 = output.stridesOf()[2]; - const Nd4jLong oStride3 = output.stridesOf()[3]; - - const Nd4jLong iStep2 = dH*iStride2; - const Nd4jLong iStep3 = dW*iStride3; - const int kProd = kH*kW; - - if(poolingMode == 0) { // max - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart, hend, wend; - T *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - T max = -DataTypeUtils::max(); - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { - T val = pIn[kh + kw]; - if (val > max) - max = val; - } - out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = max; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart, hend, wend; - T *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - T sum = static_cast(0.f); - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - sum += pIn[kh + kw]; - - if (extraParam0 == 0) { //Exclude padding - int a = (hend - hstart) / iStep2 + ((hend - hstart) % iStep2 == 0 ? 0 : 1); - int r = (wend - wstart) / iStep3 + ((wend - wstart) % iStep3 == 0 ? 0 : 1); - sum /= static_cast(a * r); // Accounts for dilation - } else if (extraParam0 == 1) //Include padding - sum /= kProd; - - out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); +template +static void pooling2d_(sd::graph::Context& block, const NDArray& input, + NDArray& output, const int kH, const int kW, + const int sH, const int sW, const int pH, const int pW, + const int dH, const int dW, const int poolingMode, + const int extraParam0) { + // input is [bS, iC, iH, iW] + // output is [bS, iC, oH, oW] + T* out = output.bufferAsT(); + T* in = const_cast(input).bufferAsT(); + + const int kHEff = kH + (kH - 1) * (dH - 1); + const int kWEff = kW + (kW - 1) * (dW - 1); + + const int bS = input.sizeAt(0); + const int iC = input.sizeAt(1); + const int iH = input.sizeAt(2); + const int iW = input.sizeAt(3); + const int oC = output.sizeAt(1); + const int oH = output.sizeAt(2); + const int oW = output.sizeAt(3); + + nd4j_debug("MKL-DNN is not used for pooling2d!\n", 0); + + const Nd4jLong iStride0 = input.stridesOf()[0]; + const Nd4jLong iStride1 = input.stridesOf()[1]; + const Nd4jLong iStride2 = input.stridesOf()[2]; + const Nd4jLong iStride3 = input.stridesOf()[3]; + const Nd4jLong oStride0 = output.stridesOf()[0]; + const Nd4jLong oStride1 = output.stridesOf()[1]; + const Nd4jLong oStride2 = output.stridesOf()[2]; + const Nd4jLong oStride3 = output.stridesOf()[3]; + + const Nd4jLong iStep2 = dH * iStride2; + const Nd4jLong iStep3 = dW * iStride3; + const int kProd = kH * kW; + + if (poolingMode == 0) { // max + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend; + T* pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += + dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); + if (wstart < 0) + wstart += + dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); + if (hend > iH) + hend -= + dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); + if (wend > iW) + wend -= + dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); + + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + T max = -DataTypeUtils::max(); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { + T val = pIn[kh + kw]; + if (val > max) max = val; + } + out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = + max; } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart, hend, wend; - T *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - T sum = static_cast(0.f); - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kh + kw]), extraParam0); - - sum = sd::math::nd4j_pow(sum, static_cast((T) 1.f) / extraParam0); - - out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } - else { - nd4j_printf("ConvolutionUtils::pooling2d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw ""; + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + /*************************************************************************/ + else if (poolingMode == 1) { // avg + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend; + T* pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += + dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); + if (wstart < 0) + wstart += + dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); + if (hend > iH) + hend -= + dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); + if (wend > iW) + wend -= + dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); + + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + T sum = static_cast(0.f); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + sum += pIn[kh + kw]; + + if (extraParam0 == 0) { // Exclude padding + int a = (hend - hstart) / iStep2 + + ((hend - hstart) % iStep2 == 0 ? 0 : 1); + int r = (wend - wstart) / iStep3 + + ((wend - wstart) % iStep3 == 0 ? 0 : 1); + sum /= static_cast(a * r); // Accounts for dilation + } else if (extraParam0 == 1) // Include padding + sum /= kProd; + + out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = + sum; } + } } - - void ConvolutionUtils::pooling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + /*************************************************************************/ + else if (poolingMode == 2) { // pnorm + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend; + T* pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += + dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); + if (wstart < 0) + wstart += + dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); + if (hend > iH) + hend -= + dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); + if (wend > iW) + wend -= + dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); + + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + T sum = static_cast(0.f); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(pIn[kh + kw]), extraParam0); + + sum = sd::math::nd4j_pow( + sum, static_cast((T)1.f) / extraParam0); + + out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = + sum; + } + } } - + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } else { + nd4j_printf( + "ConvolutionUtils::pooling2d: pooling mode argument can take three " + "values only: 0, 1, 2, but got %i instead !\n", + poolingMode); + throw ""; + } } + +void ConvolutionUtils::pooling2d(sd::graph::Context& block, + const NDArray& input, NDArray& output, + const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW, + const PoolingType poolingMode, + const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, + (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, + poolingMode, extraParam0), + FLOAT_TYPES); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp index 03f34bfae342..7dd18e530e6f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp @@ -18,289 +18,322 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// - template - static void pooling2dBP_(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - // input [bS, iC, iH, iW] - // gradI [bS, iC, iH, iW] -> gradI is output in this function - // gradO [bS, iC, oH, oW] - - // initial zeroing of gradI - gradI.nullify(); - - T* in = const_cast(input).bufferAsT(); - T* gO = const_cast(gradO).bufferAsT(); - T* gI = gradI.bufferAsT(); - - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); - - const int bS = gradI.sizeAt(0); - const int iC = gradI.sizeAt(1); - const int iH = gradI.sizeAt(2); - const int iW = gradI.sizeAt(3); - const int oC = gradO.sizeAt(1); - const int oH = gradO.sizeAt(2); - const int oW = gradO.sizeAt(3); - - nd4j_debug("MKL-DNN is not used for pooling2d_bp!\n", 0); - - const Nd4jLong iStride0 = input.stridesOf()[0]; - const Nd4jLong iStride1 = input.stridesOf()[1]; - const Nd4jLong iStride2 = input.stridesOf()[2]; - const Nd4jLong iStride3 = input.stridesOf()[3]; - const Nd4jLong gIStride0 = gradI.stridesOf()[0]; - const Nd4jLong gIStride1 = gradI.stridesOf()[1]; - const Nd4jLong gIStride2 = gradI.stridesOf()[2]; - const Nd4jLong gIStride3 = gradI.stridesOf()[3]; - const Nd4jLong oStride0 = gradO.stridesOf()[0]; - const Nd4jLong oStride1 = gradO.stridesOf()[1]; - const Nd4jLong oStride2 = gradO.stridesOf()[2]; - const Nd4jLong oStride3 = gradO.stridesOf()[3]; - const Nd4jLong iStep2 = dH*iStride2; - const Nd4jLong iStep3 = dW*iStride3; - const Nd4jLong gIStep2 = dH*gIStride2; - const Nd4jLong gIStep3 = dW*gIStride3; - const int kProd = kH*kW; - - const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && iStride2 == gIStride2 && iStride3 == gIStride3; - - if(poolingMode == 0) { // max - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart,hend, wend, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - sum = -DataTypeUtils::max(); - valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3]; - - if (sameStrides) { - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - // we set these to default values - maxKH = hstart; - maxKW = wstart; - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { - T valIn = pIn[kh + kw]; - if (valIn > sum) { - sum = valIn; - maxKH = kh; - maxKW = kw; - } - } - gI[pIn - in + maxKH + maxKW] += valO; - } else { - - // we set these to default values - maxKH = hstart; - maxKW = wstart; - - for (Nd4jLong kh = hstart; kh < hend; kh += dH) - for (Nd4jLong kw = wstart; kw < wend; kw += dW) { - T valIn = pIn[kh * iStride2 + kw * iStride3]; - if (valIn > sum) { - sum = valIn; - maxKH = kh; - maxKW = kw; - } - } - - gI[b * gIStride0 + c * gIStride1 + maxKH * gIStride2 + maxKW * gIStride3] += valO; - } - } - } - } +template +static void pooling2dBP_(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int kH, + const int kW, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + // input [bS, iC, iH, iW] + // gradI [bS, iC, iH, iW] -> gradI is output in this function + // gradO [bS, iC, oH, oW] + + // initial zeroing of gradI + gradI.nullify(); + + T* in = const_cast(input).bufferAsT(); + T* gO = const_cast(gradO).bufferAsT(); + T* gI = gradI.bufferAsT(); + + const int kHEff = kH + (kH - 1) * (dH - 1); + const int kWEff = kW + (kW - 1) * (dW - 1); + + const int bS = gradI.sizeAt(0); + const int iC = gradI.sizeAt(1); + const int iH = gradI.sizeAt(2); + const int iW = gradI.sizeAt(3); + const int oC = gradO.sizeAt(1); + const int oH = gradO.sizeAt(2); + const int oW = gradO.sizeAt(3); + + nd4j_debug("MKL-DNN is not used for pooling2d_bp!\n", 0); + + const Nd4jLong iStride0 = input.stridesOf()[0]; + const Nd4jLong iStride1 = input.stridesOf()[1]; + const Nd4jLong iStride2 = input.stridesOf()[2]; + const Nd4jLong iStride3 = input.stridesOf()[3]; + const Nd4jLong gIStride0 = gradI.stridesOf()[0]; + const Nd4jLong gIStride1 = gradI.stridesOf()[1]; + const Nd4jLong gIStride2 = gradI.stridesOf()[2]; + const Nd4jLong gIStride3 = gradI.stridesOf()[3]; + const Nd4jLong oStride0 = gradO.stridesOf()[0]; + const Nd4jLong oStride1 = gradO.stridesOf()[1]; + const Nd4jLong oStride2 = gradO.stridesOf()[2]; + const Nd4jLong oStride3 = gradO.stridesOf()[3]; + const Nd4jLong iStep2 = dH * iStride2; + const Nd4jLong iStep3 = dW * iStride3; + const Nd4jLong gIStep2 = dH * gIStride2; + const Nd4jLong gIStep3 = dW * gIStride3; + const int kProd = kH * kW; + + const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && + iStride2 == gIStride2 && iStride3 == gIStride3; + + if (poolingMode == 0) { // max + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += + dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); + if (wstart < 0) + wstart += + dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); + if (hend > iH) + hend -= + dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); + if (wend > iW) + wend -= + dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); + + sum = -DataTypeUtils::max(); + valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + + ow * oStride3]; + + if (sameStrides) { + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + // we set these to default values + maxKH = hstart; + maxKW = wstart; + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { + T valIn = pIn[kh + kw]; + if (valIn > sum) { + sum = valIn; + maxKH = kh; + maxKW = kw; } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart, hend, wend, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pgI = gI + b * gIStride0 + c * gIStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / - dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / - dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / - dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / - dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - hstart *= gIStride2; - hend *= gIStride2; - wstart *= gIStride3; - wend *= gIStride3; - - valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3]; - - if ((int) extraParam0 == 0) //Exclude padding - valO /= static_cast(sd::math::nd4j_ceil( - static_cast(hend - hstart) / static_cast(gIStep2))) * - static_cast(sd::math::nd4j_ceil( - static_cast(wend - wstart) / - static_cast(gIStep3))); //Accounts for dilation - else if ((int) extraParam0 == 1) //Include padding - valO /= kProd; - - for (Nd4jLong kh = hstart; kh < hend; kh += gIStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += gIStep3) - pgI[kh + kw] += valO; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong hstart, wstart, hend, wend, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - pgI = sameStrides ? gI + (pIn - in) : gI + b * gIStride0 + c * gIStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / - dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / - dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / - dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / - dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - sum = static_cast(0.f); - valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3]; - - if (sameStrides) { - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - sum += sd::math::nd4j_pow( - sd::math::nd4j_abs(pIn[kh + kw]), extraParam0); - - valO *= sd::math::nd4j_pow(sum, - ((T) 1. - extraParam0) / extraParam0); - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - pgI[kh + kw] += valO * sd::math::nd4j_pow( - sd::math::nd4j_abs(pIn[kh + kw]), extraParam0 - 1.f) * - sd::math::nd4j_sgn(pIn[kh + kw]); - } else { - - for (Nd4jLong kh = hstart; kh < hend; kh += dH) - for (Nd4jLong kw = wstart; kw < wend; kw += dW) - sum += sd::math::nd4j_pow( - sd::math::nd4j_abs(pIn[kh * iStride2 + kw * iStride3]), - extraParam0); - - valO *= sd::math::nd4j_pow(sum, - ((T) 1. - extraParam0) / extraParam0); - - for (Nd4jLong kh = hstart; kh < hend; kh += dH) { - for (Nd4jLong kw = wstart; kw < wend; kw += dW) { - const auto inVal = pIn[kh * iStride2 + kw * iStride3]; - pgI[kh * gIStride2 + kw * gIStride3] += valO * - sd::math::nd4j_pow( - sd::math::nd4j_abs( - inVal), - extraParam0 - 1.f) * - sd::math::nd4j_sgn( - inVal); - } - } - } - } - } - } + } + gI[pIn - in + maxKH + maxKW] += valO; + } else { + // we set these to default values + maxKH = hstart; + maxKW = wstart; + + for (Nd4jLong kh = hstart; kh < hend; kh += dH) + for (Nd4jLong kw = wstart; kw < wend; kw += dW) { + T valIn = pIn[kh * iStride2 + kw * iStride3]; + if (valIn > sum) { + sum = valIn; + maxKH = kh; + maxKW = kw; } - }; + } - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + gI[b * gIStride0 + c * gIStride1 + maxKH * gIStride2 + + maxKW * gIStride3] += valO; + } } - else { - nd4j_printf("ConvolutionUtils::pooling2dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw std::runtime_error("Incorrect pooling2dBP mode"); + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + /*************************************************************************/ + else if (poolingMode == 1) { // avg + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pgI = gI + b * gIStride0 + c * gIStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += + dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); + if (wstart < 0) + wstart += + dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); + if (hend > iH) + hend -= + dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); + if (wend > iW) + wend -= + dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); + + hstart *= gIStride2; + hend *= gIStride2; + wstart *= gIStride3; + wend *= gIStride3; + + valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + + ow * oStride3]; + + if ((int)extraParam0 == 0) // Exclude padding + valO /= static_cast(sd::math::nd4j_ceil( + static_cast(hend - hstart) / + static_cast(gIStep2))) * + static_cast(sd::math::nd4j_ceil( + static_cast(wend - wstart) / + static_cast( + gIStep3))); // Accounts for dilation + else if ((int)extraParam0 == 1) // Include padding + valO /= kProd; + + for (Nd4jLong kh = hstart; kh < hend; kh += gIStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += gIStep3) + pgI[kh + kw] += valO; } + } } - -void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + /*************************************************************************/ + else if (poolingMode == 2) { // pnorm + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong hstart, wstart, hend, wend, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + pgI = sameStrides ? gI + (pIn - in) + : gI + b * gIStride0 + c * gIStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (hstart < 0) + hstart += + dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); + if (wstart < 0) + wstart += + dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); + if (hend > iH) + hend -= + dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); + if (wend > iW) + wend -= + dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); + + sum = static_cast(0.f); + valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + + ow * oStride3]; + + if (sameStrides) { + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(pIn[kh + kw]), extraParam0); + + valO *= sd::math::nd4j_pow( + sum, ((T)1. - extraParam0) / extraParam0); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + pgI[kh + kw] += valO * + sd::math::nd4j_pow( + sd::math::nd4j_abs(pIn[kh + kw]), + extraParam0 - 1.f) * + sd::math::nd4j_sgn(pIn[kh + kw]); + } else { + for (Nd4jLong kh = hstart; kh < hend; kh += dH) + for (Nd4jLong kw = wstart; kw < wend; kw += dW) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs( + pIn[kh * iStride2 + kw * iStride3]), + extraParam0); + + valO *= sd::math::nd4j_pow( + sum, ((T)1. - extraParam0) / extraParam0); + + for (Nd4jLong kh = hstart; kh < hend; kh += dH) { + for (Nd4jLong kw = wstart; kw < wend; kw += dW) { + const auto inVal = pIn[kh * iStride2 + kw * iStride3]; + pgI[kh * gIStride2 + kw * gIStride3] += + valO * + sd::math::nd4j_pow( + sd::math::nd4j_abs(inVal), extraParam0 - 1.f) * + sd::math::nd4j_sgn(inVal); + } + } + } + } + } } - + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } else { + nd4j_printf( + "ConvolutionUtils::pooling2dBP: pooling mode argument can take three " + "values only: 0, 1, 2, but got %i instead !\n", + poolingMode); + throw std::runtime_error("Incorrect pooling2dBP mode"); + } } + +void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, + const NDArray& input, const NDArray& gradO, + NDArray& gradI, const int kH, const int kW, + const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const int poolingMode, + const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, + (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, + dW, poolingMode, extraParam0), + FLOAT_TYPES); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3d.cpp index 04d5f993ad4b..860b5f774ba2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3d.cpp @@ -18,244 +18,251 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// - template - static void pooling3d_(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - // input is [bS, iC, iD, iH, iW] - // output is [bS, iC, oD, oH, oW] - T* out = output.bufferAsT(); - T* in = const_cast(input).bufferAsT(); - - const int kDEff = kD + (kD-1)*(dD-1); - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); - - const int bS = input.sizeAt(0); - const int iC = input.sizeAt(1); - const int iD = input.sizeAt(2); - const int iH = input.sizeAt(3); - const int iW = input.sizeAt(4); - const int oC = output.sizeAt(1); - const int oD = output.sizeAt(2); - const int oH = output.sizeAt(3); - const int oW = output.sizeAt(4); - - nd4j_debug("MKL-DNN is not used for pooling3d!\n", 0); - - const Nd4jLong iStride0 = input.stridesOf()[0]; - const Nd4jLong iStride1 = input.stridesOf()[1]; - const Nd4jLong iStride2 = input.stridesOf()[2]; - const Nd4jLong iStride3 = input.stridesOf()[3]; - const Nd4jLong iStride4 = input.stridesOf()[4]; - const Nd4jLong oStride0 = output.stridesOf()[0]; - const Nd4jLong oStride1 = output.stridesOf()[1]; - const Nd4jLong oStride2 = output.stridesOf()[2]; - const Nd4jLong oStride3 = output.stridesOf()[3]; - const Nd4jLong oStride4 = output.stridesOf()[4]; - const Nd4jLong iStep2 = dD*iStride2; - const Nd4jLong iStep3 = dH*iStride3; - const Nd4jLong iStep4 = dW*iStride4; - const int kProd = kD*kH*kW; - - if(poolingMode == 0) { // max - auto func = PRAGMA_THREADS_FOR_3D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend; - T sum, *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int od = start_z; od < stop_z; od += inc_z) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; - - sum = -DataTypeUtils::max(); - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { - T val = pIn[kd + kh + kw]; - if (val > sum) - sum = val; - } - - out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - auto func = PRAGMA_THREADS_FOR_3D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend; - T sum, *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int od = start_z; od < stop_z; od += inc_z) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; - - sum = static_cast(0.); - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - sum += pIn[kd + kh + kw]; - - if (extraParam0 == 0) //Exclude padding - sum /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(iStep2)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(iStep3)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(iStep4)); //Accounts for dilation - else if (extraParam0 == 1) //Include padding - sum /= kProd; - - out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; - } - } - } - } +template +static void pooling3d_(sd::graph::Context& block, const NDArray& input, + NDArray& output, const int kD, const int kH, + const int kW, const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, const int dD, + const int dH, const int dW, const int poolingMode, + const int extraParam0) { + // input is [bS, iC, iD, iH, iW] + // output is [bS, iC, oD, oH, oW] + T* out = output.bufferAsT(); + T* in = const_cast(input).bufferAsT(); + + const int kDEff = kD + (kD - 1) * (dD - 1); + const int kHEff = kH + (kH - 1) * (dH - 1); + const int kWEff = kW + (kW - 1) * (dW - 1); + + const int bS = input.sizeAt(0); + const int iC = input.sizeAt(1); + const int iD = input.sizeAt(2); + const int iH = input.sizeAt(3); + const int iW = input.sizeAt(4); + const int oC = output.sizeAt(1); + const int oD = output.sizeAt(2); + const int oH = output.sizeAt(3); + const int oW = output.sizeAt(4); + + nd4j_debug("MKL-DNN is not used for pooling3d!\n", 0); + + const Nd4jLong iStride0 = input.stridesOf()[0]; + const Nd4jLong iStride1 = input.stridesOf()[1]; + const Nd4jLong iStride2 = input.stridesOf()[2]; + const Nd4jLong iStride3 = input.stridesOf()[3]; + const Nd4jLong iStride4 = input.stridesOf()[4]; + const Nd4jLong oStride0 = output.stridesOf()[0]; + const Nd4jLong oStride1 = output.stridesOf()[1]; + const Nd4jLong oStride2 = output.stridesOf()[2]; + const Nd4jLong oStride3 = output.stridesOf()[3]; + const Nd4jLong oStride4 = output.stridesOf()[4]; + const Nd4jLong iStep2 = dD * iStride2; + const Nd4jLong iStep3 = dH * iStride3; + const Nd4jLong iStep4 = dW * iStride4; + const int kProd = kD * kH * kW; + + if (poolingMode == 0) { // max + auto func = PRAGMA_THREADS_FOR_3D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend; + T sum, *pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int od = start_z; od < stop_z; od += inc_z) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + sum = -DataTypeUtils::max(); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { + T val = pIn[kd + kh + kw]; + if (val > sum) sum = val; } - }; - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); + out[b * oStride0 + c * oStride1 + od * oStride2 + + oh * oStride3 + ow * oStride4] = sum; + } } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - auto func = PRAGMA_THREADS_FOR_3D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend; - T sum, *pIn; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int od = start_z; od < stop_z; od += inc_z) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; - - sum = static_cast(0.); - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0); - - sum = sd::math::nd4j_pow(sum, (T) 1.f / extraParam0); - - out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); - } - else { - nd4j_printf("ConvolutionUtils::pooling3d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw std::runtime_error("Incorrect poooling3d mode"); + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); + } + /*************************************************************************/ + else if (poolingMode == 1) { // avg + auto func = PRAGMA_THREADS_FOR_3D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend; + T sum, *pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int od = start_z; od < stop_z; od += inc_z) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + sum = static_cast(0.); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + sum += pIn[kd + kh + kw]; + + if (extraParam0 == 0) // Exclude padding + sum /= sd::math::nd4j_ceil( + static_cast(dend - dstart) / + static_cast(iStep2)) * + sd::math::nd4j_ceil( + static_cast(hend - hstart) / + static_cast(iStep3)) * + sd::math::nd4j_ceil( + static_cast(wend - wstart) / + static_cast( + iStep4)); // Accounts for dilation + else if (extraParam0 == 1) // Include padding + sum /= kProd; + + out[b * oStride0 + c * oStride1 + od * oStride2 + + oh * oStride3 + ow * oStride4] = sum; + } } + } } - -void ConvolutionUtils::pooling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); + } + /*************************************************************************/ + else if (poolingMode == 2) { // pnorm + auto func = PRAGMA_THREADS_FOR_3D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend; + T sum, *pIn; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int od = start_z; od < stop_z; od += inc_z) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + sum = static_cast(0.); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(pIn[kd + kh + kw]), + extraParam0); + + sum = sd::math::nd4j_pow(sum, (T)1.f / extraParam0); + + out[b * oStride0 + c * oStride1 + od * oStride2 + + oh * oStride3 + ow * oStride4] = sum; + } + } + } } - + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); + } else { + nd4j_printf( + "ConvolutionUtils::pooling3d: pooling mode argument can take three " + "values only: 0, 1, 2, but got %i instead !\n", + poolingMode); + throw std::runtime_error("Incorrect poooling3d mode"); + } } + +void ConvolutionUtils::pooling3d(sd::graph::Context& block, + const NDArray& input, NDArray& output, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, + (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, + pW, dD, dH, dW, poolingMode, extraParam0), + FLOAT_TYPES); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3dBP.cpp index 02f6f57aca77..7bea63eedb56 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3dBP.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling3dBP.cpp @@ -18,309 +18,337 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// - template - static void pooling3dBP_(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - // input [bS, iC, iD, iH, iW] - // gradI [bS, iC, iD, iH, iW] -> gradI is output in this function - // gradO [bS, iC, oD, oH, oW] - - // initial zeroing of gradI - gradI.nullify(); - - T* in = const_cast(input).bufferAsT(); - T* gO = const_cast(gradO).bufferAsT(); - T* gI = gradI.bufferAsT(); - - const int kDEff = kD + (kD-1)*(dD-1); - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); - - const int bS = gradI.sizeAt(0); - const int iC = gradI.sizeAt(1); - const int iD = gradI.sizeAt(2); - const int iH = gradI.sizeAt(3); - const int iW = gradI.sizeAt(4); - const int oC = gradO.sizeAt(1); - const int oD = gradO.sizeAt(2); - const int oH = gradO.sizeAt(3); - const int oW = gradO.sizeAt(4); - - nd4j_debug("MKL-DNN is not used for pooling3d_bp!\n", 0); - - const Nd4jLong iStride0 = input.stridesOf()[0]; - const Nd4jLong iStride1 = input.stridesOf()[1]; - const Nd4jLong iStride2 = input.stridesOf()[2]; - const Nd4jLong iStride3 = input.stridesOf()[3]; - const Nd4jLong iStride4 = input.stridesOf()[4]; - const Nd4jLong gIStride0 = gradI.stridesOf()[0]; - const Nd4jLong gIStride1 = gradI.stridesOf()[1]; - const Nd4jLong gIStride2 = gradI.stridesOf()[2]; - const Nd4jLong gIStride3 = gradI.stridesOf()[3]; - const Nd4jLong gIStride4 = gradI.stridesOf()[4]; - const Nd4jLong oStride0 = gradO.stridesOf()[0]; - const Nd4jLong oStride1 = gradO.stridesOf()[1]; - const Nd4jLong oStride2 = gradO.stridesOf()[2]; - const Nd4jLong oStride3 = gradO.stridesOf()[3]; - const Nd4jLong oStride4 = gradO.stridesOf()[4]; - const Nd4jLong iStep2 = dD*iStride2; - const Nd4jLong iStep3 = dH*iStride3; - const Nd4jLong iStep4 = dW*iStride4; - const Nd4jLong gIStep2 = dD*gIStride2; - const Nd4jLong gIStep3 = dH*gIStride3; - const Nd4jLong gIStep4 = dW*gIStride4; - const int kProd = kD*kH*kW; - - const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && iStride2 == gIStride2 && iStride3 == gIStride3 && iStride4 == gIStride4; - - if(poolingMode == 0) { // max - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b++) { - for (int c = start_y; c < stop_y; c++) { - for (int od = 0; od < oD; od++) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - sum = -DataTypeUtils::max(); - valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4]; - - if (sameStrides) { - - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; - - maxKD = dstart; - maxKH = hstart; - maxKW = wstart; - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { - T valIn = pIn[kd + kh + kw]; - if (valIn > sum) { - sum = valIn; - maxKD = kd; - maxKH = kh; - maxKW = kw; - } - } - gI[pIn - in + maxKD + maxKH + maxKW] += valO; - } else { - - // we set these to default values - maxKH = hstart; - maxKW = wstart; - maxKD = dstart; - - for (Nd4jLong kd = dstart; kd < dend; kd += dD) - for (Nd4jLong kh = hstart; kh < hend; kh += dH) - for (Nd4jLong kw = wstart; kw < wend; kw += dW) { - T valIn = pIn[kd * iStride2 + kh * iStride3 + kw * iStride4]; - if (valIn > sum) { - sum = valIn; - maxKD = kd; - maxKH = kh; - maxKW = kw; - } - } - - gI[b * gIStride0 + c * gIStride1 + maxKD * gIStride2 + maxKH * gIStride3 + maxKW * gIStride4] += valO; - } - } - } - } +template +static void pooling3dBP_(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int kD, + const int kH, const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + // input [bS, iC, iD, iH, iW] + // gradI [bS, iC, iD, iH, iW] -> gradI is output in this function + // gradO [bS, iC, oD, oH, oW] + + // initial zeroing of gradI + gradI.nullify(); + + T* in = const_cast(input).bufferAsT(); + T* gO = const_cast(gradO).bufferAsT(); + T* gI = gradI.bufferAsT(); + + const int kDEff = kD + (kD - 1) * (dD - 1); + const int kHEff = kH + (kH - 1) * (dH - 1); + const int kWEff = kW + (kW - 1) * (dW - 1); + + const int bS = gradI.sizeAt(0); + const int iC = gradI.sizeAt(1); + const int iD = gradI.sizeAt(2); + const int iH = gradI.sizeAt(3); + const int iW = gradI.sizeAt(4); + const int oC = gradO.sizeAt(1); + const int oD = gradO.sizeAt(2); + const int oH = gradO.sizeAt(3); + const int oW = gradO.sizeAt(4); + + nd4j_debug("MKL-DNN is not used for pooling3d_bp!\n", 0); + + const Nd4jLong iStride0 = input.stridesOf()[0]; + const Nd4jLong iStride1 = input.stridesOf()[1]; + const Nd4jLong iStride2 = input.stridesOf()[2]; + const Nd4jLong iStride3 = input.stridesOf()[3]; + const Nd4jLong iStride4 = input.stridesOf()[4]; + const Nd4jLong gIStride0 = gradI.stridesOf()[0]; + const Nd4jLong gIStride1 = gradI.stridesOf()[1]; + const Nd4jLong gIStride2 = gradI.stridesOf()[2]; + const Nd4jLong gIStride3 = gradI.stridesOf()[3]; + const Nd4jLong gIStride4 = gradI.stridesOf()[4]; + const Nd4jLong oStride0 = gradO.stridesOf()[0]; + const Nd4jLong oStride1 = gradO.stridesOf()[1]; + const Nd4jLong oStride2 = gradO.stridesOf()[2]; + const Nd4jLong oStride3 = gradO.stridesOf()[3]; + const Nd4jLong oStride4 = gradO.stridesOf()[4]; + const Nd4jLong iStep2 = dD * iStride2; + const Nd4jLong iStep3 = dH * iStride3; + const Nd4jLong iStep4 = dW * iStride4; + const Nd4jLong gIStep2 = dD * gIStride2; + const Nd4jLong gIStep3 = dH * gIStride3; + const Nd4jLong gIStep4 = dW * gIStride4; + const int kProd = kD * kH * kW; + + const bool sameStrides = iStride0 == gIStride0 && iStride1 == gIStride1 && + iStride2 == gIStride2 && iStride3 == gIStride3 && + iStride4 == gIStride4; + + if (poolingMode == 0) { // max + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b++) { + for (int c = start_y; c < stop_y; c++) { + for (int od = 0; od < oD; od++) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + sum = -DataTypeUtils::max(); + valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + + oh * oStride3 + ow * oStride4]; + + if (sameStrides) { + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + maxKD = dstart; + maxKH = hstart; + maxKW = wstart; + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { + T valIn = pIn[kd + kh + kw]; + if (valIn > sum) { + sum = valIn; + maxKD = kd; + maxKH = kh; + maxKW = kw; } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b++) { - for (int c = start_y; c < stop_y; c++) { - for (int od = 0; od < oD; od++) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pgI = gI + b * gIStride0 + c * gIStride1; - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - dstart *= gIStride2; - dend *= gIStride2; - hstart *= gIStride3; - hend *= gIStride3; - wstart *= gIStride4; - wend *= gIStride4; - - valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4]; - - if (extraParam0 == 0) //Exclude padding - valO /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(gIStep2)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(gIStep3)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(gIStep4)); //Accounts for dilation - else if (extraParam0 == 1) //Include padding - valO /= kProd; - - for (Nd4jLong kd = dstart; kd < dend; kd += gIStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += gIStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += gIStep4) - pgI[kd + kh + kw] += valO; - } - } - } + } + gI[pIn - in + maxKD + maxKH + maxKW] += valO; + } else { + // we set these to default values + maxKH = hstart; + maxKW = wstart; + maxKD = dstart; + + for (Nd4jLong kd = dstart; kd < dend; kd += dD) + for (Nd4jLong kh = hstart; kh < hend; kh += dH) + for (Nd4jLong kw = wstart; kw < wend; kw += dW) { + T valIn = + pIn[kd * iStride2 + kh * iStride3 + kw * iStride4]; + if (valIn > sum) { + sum = valIn; + maxKD = kd; + maxKH = kh; + maxKW = kw; } - } - }; + } - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + gI[b * gIStride0 + c * gIStride1 + maxKD * gIStride2 + + maxKH * gIStride3 + maxKW * gIStride4] += valO; + } + } } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - auto func = PRAGMA_THREADS_FOR_2D { - Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; - T sum, valO, *pIn, *pgI; - - for (int b = start_x; b < stop_x; b++) { - for (int c = start_y; c < stop_y; c++) { - for (int od = 0; od < oD; od++) { - for (int oh = 0; oh < oH; ++oh) { - for (int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - pgI = gI + (pIn - in); - - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if (dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if (hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if (wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if (dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if (hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if (wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - sum = static_cast(0.); - valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4]; - - if (sameStrides) { - - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0); - - valO *= sd::math::nd4j_pow(sum, ((T) 1.f - extraParam0) / extraParam0); - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - pgI[kd + kh + kw] += valO * sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd + kh + kw]),extraParam0 - (T) 1.f) * sd::math::nd4j_sgn(pIn[kd + kh + kw]); - } else { - for (Nd4jLong kd = dstart; kd < dend; kd += dD) - for (Nd4jLong kh = hstart; kh < hend; kh += dH) - for (Nd4jLong kw = wstart; kw < wend; kw += dW) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(pIn[kd * iStride2 + kh * iStride3 + kw * iStride4]), extraParam0); - - valO *= sd::math::nd4j_pow(sum, ((T) 1.f - extraParam0) / extraParam0); - - for (Nd4jLong kd = dstart; kd < dend; kd += dD) - for (Nd4jLong kh = hstart; kh < hend; kh += dH) - for (Nd4jLong kw = wstart; kw < wend; kw += dW) { - const auto inVal = pIn[kD * iStride2 + kh * iStride3 + kw * iStride4]; - pgI[kd * gIStride2 + kh * gIStride3 + kw * gIStride4] += valO * sd::math::nd4j_pow(sd::math::nd4j_abs(inVal), extraParam0 - 1.f) * sd::math::nd4j_sgn(inVal); - } - } - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + /*************************************************************************/ + else if (poolingMode == 1) { // avg + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b++) { + for (int c = start_y; c < stop_y; c++) { + for (int od = 0; od < oD; od++) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pgI = gI + b * gIStride0 + c * gIStride1; + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + dstart *= gIStride2; + dend *= gIStride2; + hstart *= gIStride3; + hend *= gIStride3; + wstart *= gIStride4; + wend *= gIStride4; + + valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + + oh * oStride3 + ow * oStride4]; + + if (extraParam0 == 0) // Exclude padding + valO /= sd::math::nd4j_ceil( + static_cast(dend - dstart) / + static_cast(gIStep2)) * + sd::math::nd4j_ceil( + static_cast(hend - hstart) / + static_cast(gIStep3)) * + sd::math::nd4j_ceil( + static_cast(wend - wstart) / + static_cast( + gIStep4)); // Accounts for dilation + else if (extraParam0 == 1) // Include padding + valO /= kProd; + + for (Nd4jLong kd = dstart; kd < dend; kd += gIStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += gIStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += gIStep4) + pgI[kd + kh + kw] += valO; + } } - else { - nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw ""; + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } + /*************************************************************************/ + else if (poolingMode == 2) { // pnorm + auto func = PRAGMA_THREADS_FOR_2D { + Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; + T sum, valO, *pIn, *pgI; + + for (int b = start_x; b < stop_x; b++) { + for (int c = start_y; c < stop_y; c++) { + for (int od = 0; od < oD; od++) { + for (int oh = 0; oh < oH; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pIn = in + b * iStride0 + c * iStride1; + pgI = gI + (pIn - in); + + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + sum = static_cast(0.); + valO = gO[b * oStride0 + c * oStride1 + od * oStride2 + + oh * oStride3 + ow * oStride4]; + + if (sameStrides) { + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(pIn[kd + kh + kw]), + extraParam0); + + valO *= sd::math::nd4j_pow( + sum, ((T)1.f - extraParam0) / extraParam0); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + pgI[kd + kh + kw] += + valO * + sd::math::nd4j_pow( + sd::math::nd4j_abs(pIn[kd + kh + kw]), + extraParam0 - (T)1.f) * + sd::math::nd4j_sgn(pIn[kd + kh + kw]); + } else { + for (Nd4jLong kd = dstart; kd < dend; kd += dD) + for (Nd4jLong kh = hstart; kh < hend; kh += dH) + for (Nd4jLong kw = wstart; kw < wend; kw += dW) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs( + pIn[kd * iStride2 + kh * iStride3 + + kw * iStride4]), + extraParam0); + + valO *= sd::math::nd4j_pow( + sum, ((T)1.f - extraParam0) / extraParam0); + + for (Nd4jLong kd = dstart; kd < dend; kd += dD) + for (Nd4jLong kh = hstart; kh < hend; kh += dH) + for (Nd4jLong kw = wstart; kw < wend; kw += dW) { + const auto inVal = + pIn[kD * iStride2 + kh * iStride3 + kw * iStride4]; + pgI[kd * gIStride2 + kh * gIStride3 + kw * gIStride4] += + valO * + sd::math::nd4j_pow( + sd::math::nd4j_abs(inVal), + extraParam0 - 1.f) * + sd::math::nd4j_sgn(inVal); + } + } + } } + } } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } else { + nd4j_printf( + "ConvolutionUtils::pooling3dBP: pooling mode argument can take three " + "values only: 0, 1, 2, but got %i instead !\n", + poolingMode); + throw ""; + } +} - void ConvolutionUtils::pooling3dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); - } - } +void ConvolutionUtils::pooling3dBP(sd::graph::Context& block, + const NDArray& input, const NDArray& gradO, + NDArray& gradI, const int kD, const int kH, + const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, + const int dW, const int poolingMode, + const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, + (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, + pH, pW, dD, dH, dW, poolingMode, extraParam0), + FLOAT_TYPES); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_sconv2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_sconv2d.cpp index 742f88c3bad2..0f76569916a3 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_sconv2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_sconv2d.cpp @@ -18,56 +18,84 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - // bias [oC], oC = iC*mC if weightsPoint=nullptr - // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) +static void sconv2d_(sd::graph::Context& block, const NDArray* input, + const NDArray* weightsDepth, const NDArray* weightsPoint, + const NDArray* bias, NDArray* output, const int kH, + const int kW, const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + // bias [oC], oC = iC*mC if weightsPoint=nullptr + // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weightsDepth->sizeAt(indWmC); // channels multiplier + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier, output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weightsDepth->sizeAt(indWmC); // channels multiplier - NDArray* outputDepth = output; - if(weightsPoint) // if pointwise convolution is expected - outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector({bS, oH, oW, iC*mC}) : std::vector({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); + NDArray* outputDepth = output; + if (weightsPoint) // if pointwise convolution is expected + outputDepth = + new NDArray(output->ordering(), + !isNCHW ? std::vector({bS, oH, oW, iC * mC}) + : std::vector({bS, iC * mC, oH, oW}), + input->dataType(), input->getContext()); - // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // - ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat); - - // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // - if (weightsPoint) { - ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat); // in this case oH=iH, oW=iW - delete outputDepth; - } - } - -void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); - } + // ----- perform depthwise convolution (if weightsPoint is absent then oC = + // iC*mC) ----- // + ConvolutionUtils::depthwiseConv2d( + block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, + kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); + // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // + if (weightsPoint) { + ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1, + 1, 1, 1, 0, 0, 1, 1, paddingMode, isNCHW, + wFormat); // in this case oH=iH, oW=iW + delete outputDepth; + } } + +void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, + const NDArray* weightsDepth, + const NDArray* weightsPoint, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, + const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), sconv2d_, + (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, + pH, pW, dH, dW, paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2d.cpp index ffdd5c34b361..6e07a5e0110b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2d.cpp @@ -18,63 +18,71 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void upsampling2d_(const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { - // input has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - // output has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); +static void upsampling2d_(const NDArray& input, NDArray& output, + const int factorH, const int factorW, + const bool isNCHW) { + // input has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) + // output has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, + // factorH*iH, factorW*iW, iC] (NHWC) - const uint dimIH = isNCHW ? 2 : 1; - const uint dimIC = isNCHW ? 1 : 3; + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); - const uint bS = input.sizeAt(0); - const uint iC = input.sizeAt(dimIC); - const uint oH = output.sizeAt(dimIH); - const uint oW = output.sizeAt(dimIH + 1); + const uint dimIH = isNCHW ? 2 : 1; + const uint dimIC = isNCHW ? 1 : 3; - const Nd4jLong xStride0 = input.stridesOf()[0]; - const Nd4jLong xStride1 = input.stridesOf()[dimIC]; - const Nd4jLong xStride2 = input.stridesOf()[dimIH]; - const Nd4jLong xStride3 = input.stridesOf()[dimIH + 1]; + const uint bS = input.sizeAt(0); + const uint iC = input.sizeAt(dimIC); + const uint oH = output.sizeAt(dimIH); + const uint oW = output.sizeAt(dimIH + 1); - const Nd4jLong zStride0 = output.stridesOf()[0]; - const Nd4jLong zStride1 = output.stridesOf()[dimIC]; - const Nd4jLong zStride2 = output.stridesOf()[dimIH]; - const Nd4jLong zStride3 = output.stridesOf()[dimIH + 1]; + const Nd4jLong xStride0 = input.stridesOf()[0]; + const Nd4jLong xStride1 = input.stridesOf()[dimIC]; + const Nd4jLong xStride2 = input.stridesOf()[dimIH]; + const Nd4jLong xStride3 = input.stridesOf()[dimIH + 1]; - // loop through output array - auto func = PRAGMA_THREADS_FOR_3D { - uint xCoord2, xCoord3; - for (uint b = start_x; b < stop_x; b += inc_x) { - for (uint c = start_y; c < stop_y; c += inc_y) { - for (uint h = start_z; h < stop_z; h += inc_z) { - for (uint w = 0; w < oW; ++w) { - xCoord2 = h / factorH; - xCoord3 = w / factorW; + const Nd4jLong zStride0 = output.stridesOf()[0]; + const Nd4jLong zStride1 = output.stridesOf()[dimIC]; + const Nd4jLong zStride2 = output.stridesOf()[dimIH]; + const Nd4jLong zStride3 = output.stridesOf()[dimIH + 1]; - z[b * zStride0 + c * zStride1 + h * zStride2 + w * zStride3] = x[b * xStride0 + c * xStride1 + xCoord2 * xStride2 + xCoord3 * xStride3]; - } - } - } - } - }; + // loop through output array + auto func = PRAGMA_THREADS_FOR_3D { + uint xCoord2, xCoord3; + for (uint b = start_x; b < stop_x; b += inc_x) { + for (uint c = start_y; c < stop_y; c += inc_y) { + for (uint h = start_z; h < stop_z; h += inc_z) { + for (uint w = 0; w < oW; ++w) { + xCoord2 = h / factorH; + xCoord3 = w / factorW; - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oH, 1); + z[b * zStride0 + c * zStride1 + h * zStride2 + w * zStride3] = + x[b * xStride0 + c * xStride1 + xCoord2 * xStride2 + + xCoord3 * xStride3]; + } } + } + } + }; - -void ConvolutionUtils::upsampling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oH, 1); } +void ConvolutionUtils::upsampling2d(sd::graph::Context& block, + const NDArray& input, NDArray& output, + const int factorH, const int factorW, + const bool isNCHW) { + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, + (input, output, factorH, factorW, isNCHW), FLOAT_TYPES); } -} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2dBP.cpp index aba46aabc119..623ae6645411 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2dBP.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling2dBP.cpp @@ -18,69 +18,74 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void upsampling2dBP_(const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - // gradO has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - // gradI has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - - const T* x = gradO.bufferAsT(); - T* z = gradI.bufferAsT(); - - const uint dimIH = isNCHW ? 2 : 1; - const uint dimIC = isNCHW ? 1 : 3; - - const uint bS = gradI.sizeAt(0); - const uint iC = gradI.sizeAt(dimIC); - const uint iH = gradI.sizeAt(dimIH); - const uint iW = gradI.sizeAt(dimIH + 1); - - const uint factorH = gradO.sizeAt(dimIH) / iH; - const uint factorW = gradO.sizeAt(dimIH + 1) / iW; - - const Nd4jLong xStride0 = gradO.stridesOf()[0]; - const Nd4jLong xStride1 = gradO.stridesOf()[dimIC]; - const Nd4jLong xStride2 = gradO.stridesOf()[dimIH]; - const Nd4jLong xStride3 = gradO.stridesOf()[dimIH + 1]; - - const Nd4jLong zStride0 = gradI.stridesOf()[0]; - const Nd4jLong zStride1 = gradI.stridesOf()[dimIC]; - const Nd4jLong zStride2 = gradI.stridesOf()[dimIH]; - const Nd4jLong zStride3 = gradI.stridesOf()[dimIH + 1]; - - // loop through output array - auto func = PRAGMA_THREADS_FOR_3D { - for (uint b = start_x; b < stop_x; b += inc_x) { - for (uint c = start_y; c < stop_y; c += inc_y) { - for (uint h = start_z; h < stop_z; h += inc_z) { - for (uint w = 0; w < iW; ++w) { - - const auto zOffset = b * zStride0 + c * zStride1 + h * zStride2 + w * zStride3; - - z[zOffset] = 0; - - for (uint xh = h * factorH; xh < h * factorH + factorH; ++xh) - for (uint xw = w * factorW; xw < w * factorW + factorW; ++xw) - z[zOffset] += x[b * xStride0 + c * xStride1 + xh * xStride2 + xw * xStride3]; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iH, 1); +static void upsampling2dBP_(const NDArray& gradO, NDArray& gradI, + const bool isNCHW) { + // gradO has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, + // factorH*iH, factorW*iW, iC] (NHWC) gradI has shape [bS, iC, iH, iW] (NCHW) + // or [bS, iH, iW, iC] (NHWC) + + const T* x = gradO.bufferAsT(); + T* z = gradI.bufferAsT(); + + const uint dimIH = isNCHW ? 2 : 1; + const uint dimIC = isNCHW ? 1 : 3; + + const uint bS = gradI.sizeAt(0); + const uint iC = gradI.sizeAt(dimIC); + const uint iH = gradI.sizeAt(dimIH); + const uint iW = gradI.sizeAt(dimIH + 1); + + const uint factorH = gradO.sizeAt(dimIH) / iH; + const uint factorW = gradO.sizeAt(dimIH + 1) / iW; + + const Nd4jLong xStride0 = gradO.stridesOf()[0]; + const Nd4jLong xStride1 = gradO.stridesOf()[dimIC]; + const Nd4jLong xStride2 = gradO.stridesOf()[dimIH]; + const Nd4jLong xStride3 = gradO.stridesOf()[dimIH + 1]; + + const Nd4jLong zStride0 = gradI.stridesOf()[0]; + const Nd4jLong zStride1 = gradI.stridesOf()[dimIC]; + const Nd4jLong zStride2 = gradI.stridesOf()[dimIH]; + const Nd4jLong zStride3 = gradI.stridesOf()[dimIH + 1]; + + // loop through output array + auto func = PRAGMA_THREADS_FOR_3D { + for (uint b = start_x; b < stop_x; b += inc_x) { + for (uint c = start_y; c < stop_y; c += inc_y) { + for (uint h = start_z; h < stop_z; h += inc_z) { + for (uint w = 0; w < iW; ++w) { + const auto zOffset = + b * zStride0 + c * zStride1 + h * zStride2 + w * zStride3; + + z[zOffset] = 0; + + for (uint xh = h * factorH; xh < h * factorH + factorH; ++xh) + for (uint xw = w * factorW; xw < w * factorW + factorW; ++xw) + z[zOffset] += x[b * xStride0 + c * xStride1 + xh * xStride2 + + xw * xStride3]; + } } + } + } + }; -void ConvolutionUtils::upsampling2dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iH, 1); } +void ConvolutionUtils::upsampling2dBP(sd::graph::Context& block, + const NDArray& gradO, NDArray& gradI, + const bool isNCHW) { + BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, + (gradO, gradI, isNCHW), FLOAT_TYPES); } -} + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3d.cpp index 7b86ec5a150f..b73981e763aa 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3d.cpp @@ -18,72 +18,80 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void upsampling3d_(const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - // output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); - - const uint dimID = isNCDHW ? 2 : 1; - const uint dimIC = isNCDHW ? 1 : 4; - - const uint bS = input.sizeAt(0); - const uint iC = input.sizeAt(dimIC); - const uint oD = output.sizeAt(dimID); - const uint oH = output.sizeAt(dimID + 1); - const uint oW = output.sizeAt(dimID + 2); - - const Nd4jLong xStride0 = input.stridesOf()[0]; - const Nd4jLong xStride1 = input.stridesOf()[dimIC]; - const Nd4jLong xStride2 = input.stridesOf()[dimID]; - const Nd4jLong xStride3 = input.stridesOf()[dimID + 1]; - const Nd4jLong xStride4 = input.stridesOf()[dimID + 2]; - - const Nd4jLong zStride0 = output.stridesOf()[0]; - const Nd4jLong zStride1 = output.stridesOf()[dimIC]; - const Nd4jLong zStride2 = output.stridesOf()[dimID]; - const Nd4jLong zStride3 = output.stridesOf()[dimID + 1]; - const Nd4jLong zStride4 = output.stridesOf()[dimID + 2]; - - // loop through output array - auto func = PRAGMA_THREADS_FOR_3D { - uint xCoord2, xCoord3, xCoord4; - - for (uint b = start_x; b < stop_x; b += inc_x) { - for (uint c = start_y; c < stop_y; c += inc_y) { - for (uint d = start_z; d < stop_z; d += inc_z) { - for (uint h = 0; h < oH; ++h) { - for (uint w = 0; w < oW; ++w) { - - xCoord2 = d / factorD; - xCoord3 = h / factorH; - xCoord4 = w / factorW; - - z[b * zStride0 + c * zStride1 + d * zStride2 + h * zStride3 + w * zStride4] = x[ - b * xStride0 + c * xStride1 + xCoord2 * xStride2 + xCoord3 * xStride3 + - xCoord4 * xStride4]; - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); - } - - void ConvolutionUtils::upsampling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), FLOAT_TYPES); +static void upsampling3d_(const NDArray& input, NDArray& output, + const int factorD, const int factorH, + const int factorW, const bool isNCDHW) { + // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] + // (NDHWC) output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] + // (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) + + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); + + const uint dimID = isNCDHW ? 2 : 1; + const uint dimIC = isNCDHW ? 1 : 4; + + const uint bS = input.sizeAt(0); + const uint iC = input.sizeAt(dimIC); + const uint oD = output.sizeAt(dimID); + const uint oH = output.sizeAt(dimID + 1); + const uint oW = output.sizeAt(dimID + 2); + + const Nd4jLong xStride0 = input.stridesOf()[0]; + const Nd4jLong xStride1 = input.stridesOf()[dimIC]; + const Nd4jLong xStride2 = input.stridesOf()[dimID]; + const Nd4jLong xStride3 = input.stridesOf()[dimID + 1]; + const Nd4jLong xStride4 = input.stridesOf()[dimID + 2]; + + const Nd4jLong zStride0 = output.stridesOf()[0]; + const Nd4jLong zStride1 = output.stridesOf()[dimIC]; + const Nd4jLong zStride2 = output.stridesOf()[dimID]; + const Nd4jLong zStride3 = output.stridesOf()[dimID + 1]; + const Nd4jLong zStride4 = output.stridesOf()[dimID + 2]; + + // loop through output array + auto func = PRAGMA_THREADS_FOR_3D { + uint xCoord2, xCoord3, xCoord4; + + for (uint b = start_x; b < stop_x; b += inc_x) { + for (uint c = start_y; c < stop_y; c += inc_y) { + for (uint d = start_z; d < stop_z; d += inc_z) { + for (uint h = 0; h < oH; ++h) { + for (uint w = 0; w < oW; ++w) { + xCoord2 = d / factorD; + xCoord3 = h / factorH; + xCoord4 = w / factorW; + + z[b * zStride0 + c * zStride1 + d * zStride2 + h * zStride3 + + w * zStride4] = + x[b * xStride0 + c * xStride1 + xCoord2 * xStride2 + + xCoord3 * xStride3 + xCoord4 * xStride4]; + } + } } + } + } + }; + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, oD, 1); } + +void ConvolutionUtils::upsampling3d(sd::graph::Context& block, + const NDArray& input, NDArray& output, + const int factorD, const int factorH, + const int factorW, const bool isNCDHW) { + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, + (input, output, factorD, factorH, factorW, isNCDHW), + FLOAT_TYPES); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3dBP.cpp index 93c2746fbfc6..40480b9e6360 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3dBP.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_upsampling3dBP.cpp @@ -18,78 +18,82 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void upsampling3dBP_(const NDArray& gradO, NDArray& gradI, const bool isNCDHW) { - - // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - // output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - - const T* x = gradO.bufferAsT(); - T* z = gradI.bufferAsT(); - - const uint dimID = isNCDHW ? 2 : 1; - const uint dimIC = isNCDHW ? 1 : 4; - - const uint bS = gradI.sizeAt(0); - const uint iC = gradI.sizeAt(dimIC); - const uint iD = gradI.sizeAt(dimID); - const uint iH = gradI.sizeAt(dimID + 1); - const uint iW = gradI.sizeAt(dimID + 2); - - const uint factorD = gradO.sizeAt(dimID) / iD; - const uint factorH = gradO.sizeAt(dimID + 1) / iH; - const uint factorW = gradO.sizeAt(dimID + 2) / iW; - - const Nd4jLong xStride0 = gradO.stridesOf()[0]; - const Nd4jLong xStride1 = gradO.stridesOf()[dimIC]; - const Nd4jLong xStride2 = gradO.stridesOf()[dimID]; - const Nd4jLong xStride3 = gradO.stridesOf()[dimID + 1]; - const Nd4jLong xStride4 = gradO.stridesOf()[dimID + 2]; - - const Nd4jLong zStride0 = gradI.stridesOf()[0]; - const Nd4jLong zStride1 = gradI.stridesOf()[dimIC]; - const Nd4jLong zStride2 = gradI.stridesOf()[dimID]; - const Nd4jLong zStride3 = gradI.stridesOf()[dimID + 1]; - const Nd4jLong zStride4 = gradI.stridesOf()[dimID + 2]; - - // loop through output array - auto func = PRAGMA_THREADS_FOR_3D { - for (uint b = start_x; b < stop_x; b += inc_x) { - for (uint c = start_y; c < stop_y; c += inc_y) { - for (uint d = start_z; d < stop_z; d += inc_z) { - for (uint h = 0; h < iH; ++h) { - for (uint w = 0; w < iW; ++w) { - - const auto zOffset = b * zStride0 + c * zStride1 + d * zStride2 + h * zStride3 + w * zStride4; - - z[zOffset] = 0; - - for (uint xd = d * factorD; xd < d * factorD + factorD; ++xd) - for (uint xh = h * factorH; xh < h * factorH + factorH; ++xh) - for (uint xw = w * factorW; xw < w * factorW + factorW; ++xw) - z[zOffset] += x[b * xStride0 + c * xStride1 + xd * xStride2 + xh * xStride3 + xw * xStride4]; - } - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iD, 1); - } - - - void ConvolutionUtils::upsampling3dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES); +static void upsampling3dBP_(const NDArray& gradO, NDArray& gradI, + const bool isNCDHW) { + // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] + // (NDHWC) output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] + // (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) + + const T* x = gradO.bufferAsT(); + T* z = gradI.bufferAsT(); + + const uint dimID = isNCDHW ? 2 : 1; + const uint dimIC = isNCDHW ? 1 : 4; + + const uint bS = gradI.sizeAt(0); + const uint iC = gradI.sizeAt(dimIC); + const uint iD = gradI.sizeAt(dimID); + const uint iH = gradI.sizeAt(dimID + 1); + const uint iW = gradI.sizeAt(dimID + 2); + + const uint factorD = gradO.sizeAt(dimID) / iD; + const uint factorH = gradO.sizeAt(dimID + 1) / iH; + const uint factorW = gradO.sizeAt(dimID + 2) / iW; + + const Nd4jLong xStride0 = gradO.stridesOf()[0]; + const Nd4jLong xStride1 = gradO.stridesOf()[dimIC]; + const Nd4jLong xStride2 = gradO.stridesOf()[dimID]; + const Nd4jLong xStride3 = gradO.stridesOf()[dimID + 1]; + const Nd4jLong xStride4 = gradO.stridesOf()[dimID + 2]; + + const Nd4jLong zStride0 = gradI.stridesOf()[0]; + const Nd4jLong zStride1 = gradI.stridesOf()[dimIC]; + const Nd4jLong zStride2 = gradI.stridesOf()[dimID]; + const Nd4jLong zStride3 = gradI.stridesOf()[dimID + 1]; + const Nd4jLong zStride4 = gradI.stridesOf()[dimID + 2]; + + // loop through output array + auto func = PRAGMA_THREADS_FOR_3D { + for (uint b = start_x; b < stop_x; b += inc_x) { + for (uint c = start_y; c < stop_y; c += inc_y) { + for (uint d = start_z; d < stop_z; d += inc_z) { + for (uint h = 0; h < iH; ++h) { + for (uint w = 0; w < iW; ++w) { + const auto zOffset = b * zStride0 + c * zStride1 + d * zStride2 + + h * zStride3 + w * zStride4; + + z[zOffset] = 0; + + for (uint xd = d * factorD; xd < d * factorD + factorD; ++xd) + for (uint xh = h * factorH; xh < h * factorH + factorH; ++xh) + for (uint xw = w * factorW; xw < w * factorW + factorW; ++xw) + z[zOffset] += + x[b * xStride0 + c * xStride1 + xd * xStride2 + + xh * xStride3 + xw * xStride4]; + } + } } + } + } + }; + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, iD, 1); } + +void ConvolutionUtils::upsampling3dBP(sd::graph::Context& block, + const NDArray& gradO, NDArray& gradI, + const bool isNCHW) { + BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, + (gradO, gradI, isNCHW), FLOAT_TYPES); } + +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_vol2col.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_vol2col.cpp index 4c8b5bad1474..5f560950d925 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_vol2col.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_vol2col.cpp @@ -18,130 +18,153 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include namespace sd { - namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] template -static void vol2col_(const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - const int bS = volume.sizeAt(0); - const int iC = volume.sizeAt(1); - const int iD = volume.sizeAt(2); - const int iH = volume.sizeAt(3); - const int iW = volume.sizeAt(4); - const int kD = columns.sizeAt(2); - const int kH = columns.sizeAt(3); - const int kW = columns.sizeAt(4); - const int oD = columns.sizeAt(5); - const int oH = columns.sizeAt(6); - const int oW = columns.sizeAt(7); - const Nd4jLong colStride0 = columns.stridesOf()[0]; - const Nd4jLong colStride1 = columns.stridesOf()[1]; - const Nd4jLong colStride2 = columns.stridesOf()[2]; - const Nd4jLong colStride3 = columns.stridesOf()[3]; - const Nd4jLong colStride4 = columns.stridesOf()[4]; - const Nd4jLong colStride5 = columns.stridesOf()[5]; - const Nd4jLong colStride6 = columns.stridesOf()[6]; - const Nd4jLong colStride7 = columns.stridesOf()[7]; - const Nd4jLong volStride0 = volume.stridesOf()[0]; - const Nd4jLong volStride1 = volume.stridesOf()[1]; - const Nd4jLong volStride2 = volume.stridesOf()[2]; - const Nd4jLong volStride3 = volume.stridesOf()[3]; - const Nd4jLong volStride4 = volume.stridesOf()[4]; - - T* colBuff = columns.bufferAsT(); - T* volBuff = const_cast(volume).bufferAsT(); - - - if (volume.ordering() == 'c' && columns.ordering() == 'c' && shape::strideDescendingCAscendingF(volume.shapeInfo()) && shape::strideDescendingCAscendingF(columns.shapeInfo())) { - - auto func = PRAGMA_THREADS_FOR_3D { - T *col, *vol; - int volDep, volRow, volCol; - - for (int b = start_x; b < stop_x; b += inc_x) { - for (int c = start_y; c < stop_y; c += inc_y) { - for (int kDep = start_z; kDep < stop_z; kDep += inc_z) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - for (int colD = 0; colD < oD; ++colD) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - - volDep = (-pD + kDep * dD) + colD * sD; - volRow = (-pH + kRow * dH) + colH * sH; - volCol = (-pW + kCol * dW) + colW * sW; - - col = colBuff + b * colStride0 + c * colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; - - if (static_cast(volDep) >= static_cast(iD) || static_cast(volRow) >= static_cast(iH) || static_cast(volCol) >= static_cast(iW)) - *col = static_cast(0.); - else { - vol = volBuff + b * volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; - *col = *vol; - } - } - } - } - } - } - } - } +static void vol2col_(const NDArray& volume, NDArray& columns, const int sD, + const int sH, const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, const int dW) { + const int bS = volume.sizeAt(0); + const int iC = volume.sizeAt(1); + const int iD = volume.sizeAt(2); + const int iH = volume.sizeAt(3); + const int iW = volume.sizeAt(4); + const int kD = columns.sizeAt(2); + const int kH = columns.sizeAt(3); + const int kW = columns.sizeAt(4); + const int oD = columns.sizeAt(5); + const int oH = columns.sizeAt(6); + const int oW = columns.sizeAt(7); + const Nd4jLong colStride0 = columns.stridesOf()[0]; + const Nd4jLong colStride1 = columns.stridesOf()[1]; + const Nd4jLong colStride2 = columns.stridesOf()[2]; + const Nd4jLong colStride3 = columns.stridesOf()[3]; + const Nd4jLong colStride4 = columns.stridesOf()[4]; + const Nd4jLong colStride5 = columns.stridesOf()[5]; + const Nd4jLong colStride6 = columns.stridesOf()[6]; + const Nd4jLong colStride7 = columns.stridesOf()[7]; + const Nd4jLong volStride0 = volume.stridesOf()[0]; + const Nd4jLong volStride1 = volume.stridesOf()[1]; + const Nd4jLong volStride2 = volume.stridesOf()[2]; + const Nd4jLong volStride3 = volume.stridesOf()[3]; + const Nd4jLong volStride4 = volume.stridesOf()[4]; + + T* colBuff = columns.bufferAsT(); + T* volBuff = const_cast(volume).bufferAsT(); + + if (volume.ordering() == 'c' && columns.ordering() == 'c' && + shape::strideDescendingCAscendingF(volume.shapeInfo()) && + shape::strideDescendingCAscendingF(columns.shapeInfo())) { + auto func = PRAGMA_THREADS_FOR_3D { + T *col, *vol; + int volDep, volRow, volCol; + + for (int b = start_x; b < stop_x; b += inc_x) { + for (int c = start_y; c < stop_y; c += inc_y) { + for (int kDep = start_z; kDep < stop_z; kDep += inc_z) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + for (int colD = 0; colD < oD; ++colD) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + volDep = (-pD + kDep * dD) + colD * sD; + volRow = (-pH + kRow * dH) + colH * sH; + volCol = (-pW + kCol * dW) + colW * sW; + + col = colBuff + b * colStride0 + c * colStride1 + + kDep * colStride2 + kRow * colStride3 + + kCol * colStride4 + colD * colStride5 + + colH * colStride6 + colW * colStride7; + + if (static_cast(volDep) >= + static_cast(iD) || + static_cast(volRow) >= + static_cast(iH) || + static_cast(volCol) >= + static_cast(iW)) + *col = static_cast(0.); + else { + vol = volBuff + b * volStride0 + c * volStride1 + + volDep * volStride2 + volRow * volStride3 + + volCol * volStride4; + *col = *vol; + } } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, kD, 1); - - } else { - - auto func = PRAGMA_THREADS_FOR_2D { - T *col, *vol; - int volDep, volRow, volCol; - for (int b = start_x; b < stop_x; b++) { - for (int colD = start_y; colD < stop_y; colD++) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - for (int c = 0; c < iC; ++c) { - for (int kDep = 0; kDep < kD; ++kDep) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - - volDep = (-pD + kDep * dD) + colD * sD; - volRow = (-pH + kRow * dH) + colH * sH; - volCol = (-pW + kCol * dW) + colW * sW; - - col = colBuff + b * colStride0 + c * colStride1 + kDep * colStride2 + kRow * colStride3 + kCol * colStride4 + colD * colStride5 + colH * colStride6 + colW * colStride7; - - if (static_cast(volDep) >= static_cast(iD) || static_cast(volRow) >= static_cast(iH) || static_cast(volCol) >= static_cast(iW)) - *col = static_cast(0.f); - else { - vol = volBuff + b * volStride0 + c * volStride1 + volDep * volStride2 + volRow * volStride3 + volCol * volStride4; - *col = *vol; - } - } - } - } - } - } - } - } + } + } + } + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1, 0, kD, 1); + + } else { + auto func = PRAGMA_THREADS_FOR_2D { + T *col, *vol; + int volDep, volRow, volCol; + for (int b = start_x; b < stop_x; b++) { + for (int colD = start_y; colD < stop_y; colD++) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + for (int c = 0; c < iC; ++c) { + for (int kDep = 0; kDep < kD; ++kDep) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + volDep = (-pD + kDep * dD) + colD * sD; + volRow = (-pH + kRow * dH) + colH * sH; + volCol = (-pW + kCol * dW) + colW * sW; + + col = colBuff + b * colStride0 + c * colStride1 + + kDep * colStride2 + kRow * colStride3 + + kCol * colStride4 + colD * colStride5 + + colH * colStride6 + colW * colStride7; + + if (static_cast(volDep) >= + static_cast(iD) || + static_cast(volRow) >= + static_cast(iH) || + static_cast(volCol) >= + static_cast(iW)) + *col = static_cast(0.f); + else { + vol = volBuff + b * volStride0 + c * volStride1 + + volDep * volStride2 + volRow * volStride3 + + volCol * volStride4; + *col = *vol; + } } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, oD, 1); - //func(0, 0, bS, 1, 0, oD, 1); + } + } + } } + } } + } + }; -void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, bS, 1, 0, oD, 1); + // func(0, 0, bS, 1, 0, oD, 1); + } } +void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& volume, + NDArray& columns, const int sD, const int sH, + const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, + const int dW) { + BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, + (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), + FLOAT_TYPES); } -} \ No newline at end of file + +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.cpp b/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.cpp index ab6503946c14..5165f6396d0b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.cpp @@ -33,31 +33,38 @@ limitations under the License. // @author sgazeos@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { +namespace ops { +namespace helpers { -// ------------------------------------------------------------------------------------------------------------------ // -// ------------------------------------------------------------------------------------------------------------------ // -// crop and resize helper functor: +// ------------------------------------------------------------------------------------------------------------------ +// // +// ------------------------------------------------------------------------------------------------------------------ +// // crop and resize helper functor: // \@param context - launch context for operation -// \@param images - batch of images (4D tensor) with shape {batch, width, height, channels} with given type +// \@param images - batch of images (4D tensor) with shape {batch, width, +// height, channels} with given type // \@param boxes - float boxes for crop // \@param indices - integer boxes indices for crop // \@param cropSize - integer size (newWidth, newHeight) -// \@param method - one of bilinear (0) or nearest neighbour (1) interpolation algorithm +// \@param method - one of bilinear (0) or nearest neighbour (1) interpolation +// algorithm // \@param extrapolationVal - radix to increase/decrease image // \@param crops - output image batch (4D with given type) // - void - cropAndResizeFunctor(sd::LaunchContext * context, NDArray const *images, NDArray const *boxes, - NDArray const *indices, NDArray const *cropSize, - int method, double extrapolationVal, NDArray *crops) { - BUILD_TRIPLE_SELECTOR(images->dataType(), boxes->dataType(), indices->dataType(), cropAndResizeFunctor_, (images, boxes, indices, cropSize, method, extrapolationVal, crops), NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); - } - } - } -} \ No newline at end of file +void cropAndResizeFunctor(sd::LaunchContext *context, NDArray const *images, + NDArray const *boxes, NDArray const *indices, + NDArray const *cropSize, int method, + double extrapolationVal, NDArray *crops) { + BUILD_TRIPLE_SELECTOR( + images->dataType(), boxes->dataType(), indices->dataType(), + cropAndResizeFunctor_, + (images, boxes, indices, cropSize, method, extrapolationVal, crops), + NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.hpp b/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.hpp index c7d29c47131e..58d09800f5b3 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.hpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/crop_and_resize.hpp @@ -18,106 +18,118 @@ // @author sgazeos@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - void cropAndResizeFunctor_(NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops) { - const int batchSize = images->sizeAt(0); - const int imageHeight = images->sizeAt(1); - const int imageWidth = images->sizeAt(2); - - const int numBoxes = crops->sizeAt(0); - const int cropHeight = crops->sizeAt(1); - const int cropWidth = crops->sizeAt(2); - const int depth = crops->sizeAt(3); +namespace ops { +namespace helpers { +template +void cropAndResizeFunctor_(NDArray const *images, NDArray const *boxes, + NDArray const *indices, NDArray const *cropSize, + int method, double extrapolationVal, + NDArray *crops) { + const int batchSize = images->sizeAt(0); + const int imageHeight = images->sizeAt(1); + const int imageWidth = images->sizeAt(2); - for (auto b = 0; b < numBoxes; ++b) { - T y1 = boxes->t(b, 0); - T x1 = boxes->t(b, 1); - T y2 = boxes->t(b, 2); - T x2 = boxes->t(b, 3); + const int numBoxes = crops->sizeAt(0); + const int cropHeight = crops->sizeAt(1); + const int cropWidth = crops->sizeAt(2); + const int depth = crops->sizeAt(3); - int bIn = indices->e(b); - if (bIn >= batchSize) { - continue; - } + for (auto b = 0; b < numBoxes; ++b) { + T y1 = boxes->t(b, 0); + T x1 = boxes->t(b, 1); + T y2 = boxes->t(b, 2); + T x2 = boxes->t(b, 3); - T heightScale = (cropHeight > 1) ? (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) : T(0); - T widthScale = (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : T(0); + int bIn = indices->e(b); + if (bIn >= batchSize) { + continue; + } - auto func = PRAGMA_THREADS_FOR { - for (auto y = start; y < stop; y++) { - const float inY = (cropHeight > 1) - ? y1 * (imageHeight - 1) + y * heightScale - : 0.5 * (y1 + y2) * (imageHeight - 1); + T heightScale = (cropHeight > 1) + ? (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) + : T(0); + T widthScale = + (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : T(0); - if (inY < 0 || inY > imageHeight - 1) { - for (auto x = 0; x < cropWidth; ++x) { - for (auto d = 0; d < depth; ++d) { - crops->p(b, y, x, d, extrapolationVal); - } - } - continue; - } - if (method == 0 /* bilinear */) { - const int topYIndex = sd::math::p_floor(inY); - const int bottomYIndex = sd::math::p_ceil(inY); - const float y_lerp = inY - topYIndex; + auto func = PRAGMA_THREADS_FOR { + for (auto y = start; y < stop; y++) { + const float inY = (cropHeight > 1) + ? y1 * (imageHeight - 1) + y * heightScale + : 0.5 * (y1 + y2) * (imageHeight - 1); - for (auto x = 0; x < cropWidth; ++x) { - const float in_x = (cropWidth > 1) - ? x1 * (imageWidth - 1) + x * widthScale - : 0.5 * (x1 + x2) * (imageWidth - 1); + if (inY < 0 || inY > imageHeight - 1) { + for (auto x = 0; x < cropWidth; ++x) { + for (auto d = 0; d < depth; ++d) { + crops->p(b, y, x, d, extrapolationVal); + } + } + continue; + } + if (method == 0 /* bilinear */) { + const int topYIndex = sd::math::p_floor(inY); + const int bottomYIndex = sd::math::p_ceil(inY); + const float y_lerp = inY - topYIndex; - if (in_x < 0 || in_x > imageWidth - 1) { - for (auto d = 0; d < depth; ++d) { - crops->p(b, y, x, d, extrapolationVal); - } - continue; - } - int left_x_index = math::p_floor(in_x); - int right_x_index = math::p_ceil(in_x); - T x_lerp = in_x - left_x_index; + for (auto x = 0; x < cropWidth; ++x) { + const float in_x = (cropWidth > 1) + ? x1 * (imageWidth - 1) + x * widthScale + : 0.5 * (x1 + x2) * (imageWidth - 1); - for (auto d = 0; d < depth; ++d) { - const float topLeft(images->e(bIn, topYIndex, left_x_index, d)); - const float topRight(images->e(bIn, topYIndex, right_x_index, d)); - const float bottomLeft(images->e(bIn, bottomYIndex, left_x_index, d)); - const float bottomRight(images->e(bIn, bottomYIndex, right_x_index, d)); - const float top = topLeft + (topRight - topLeft) * x_lerp; - const float bottom = bottomLeft + (bottomRight - bottomLeft) * x_lerp; - crops->p(b, y, x, d, top + (bottom - top) * y_lerp); - } - } - } else { // method is "nearest neighbor" - for (auto x = 0; x < cropWidth; ++x) { - const float inX = (cropWidth > 1) - ? x1 * (imageWidth - 1) + x * widthScale - : 0.5 * (x1 + x2) * (imageWidth - 1); + if (in_x < 0 || in_x > imageWidth - 1) { + for (auto d = 0; d < depth; ++d) { + crops->p(b, y, x, d, extrapolationVal); + } + continue; + } + int left_x_index = math::p_floor(in_x); + int right_x_index = math::p_ceil(in_x); + T x_lerp = in_x - left_x_index; - if (inX < 0 || inX > imageWidth - 1) { - for (auto d = 0; d < depth; ++d) { - crops->p(b, y, x, d, extrapolationVal); - } - continue; - } - const int closestXIndex = roundf(inX); - const int closestYIndex = roundf(inY); - for (auto d = 0; d < depth; ++d) { - crops->p(b, y, x, d, images->e(bIn, closestYIndex, closestXIndex, d)); - } - } - } - } - }; + for (auto d = 0; d < depth; ++d) { + const float topLeft( + images->e(bIn, topYIndex, left_x_index, d)); + const float topRight( + images->e(bIn, topYIndex, right_x_index, d)); + const float bottomLeft( + images->e(bIn, bottomYIndex, left_x_index, d)); + const float bottomRight( + images->e(bIn, bottomYIndex, right_x_index, d)); + const float top = topLeft + (topRight - topLeft) * x_lerp; + const float bottom = + bottomLeft + (bottomRight - bottomLeft) * x_lerp; + crops->p(b, y, x, d, top + (bottom - top) * y_lerp); + } + } + } else { // method is "nearest neighbor" + for (auto x = 0; x < cropWidth; ++x) { + const float inX = (cropWidth > 1) + ? x1 * (imageWidth - 1) + x * widthScale + : 0.5 * (x1 + x2) * (imageWidth - 1); - samediff::Threads::parallel_for(func, 0, cropHeight); - } + if (inX < 0 || inX > imageWidth - 1) { + for (auto d = 0; d < depth; ++d) { + crops->p(b, y, x, d, extrapolationVal); + } + continue; + } + const int closestXIndex = roundf(inX); + const int closestYIndex = roundf(inY); + for (auto d = 0; d < depth; ++d) { + crops->p(b, y, x, d, + images->e(bIn, closestYIndex, closestXIndex, d)); } + } } - } -} \ No newline at end of file + } + }; + + samediff::Threads::parallel_for(func, 0, cropHeight); + } +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp b/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp index 7f662e1c2845..33a5b57498fa 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/cross.cpp @@ -18,39 +18,39 @@ // @author GS (sgazeos@gmail.com), created on 10/1/2018 // - -#include #include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { -void crossBatched(sd::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) { - auto _a = a->reshape(a->ordering(), {-1, 3}); - auto _b = b->reshape(b->ordering(), {-1, 3}); - auto _o = o->reshape(o->ordering(), {-1, 3}, false); +void crossBatched(sd::LaunchContext *context, NDArray *a, NDArray *b, + NDArray *o) { + auto _a = a->reshape(a->ordering(), {-1, 3}); + auto _b = b->reshape(b->ordering(), {-1, 3}); + auto _o = o->reshape(o->ordering(), {-1, 3}, false); - auto tadsA = _a.allTensorsAlongDimension({1}); - auto tadsB = _b.allTensorsAlongDimension({1}); - auto tadsO = _o.allTensorsAlongDimension({1}); + auto tadsA = _a.allTensorsAlongDimension({1}); + auto tadsB = _b.allTensorsAlongDimension({1}); + auto tadsO = _o.allTensorsAlongDimension({1}); - int tads = tadsA.size(); + int tads = tadsA.size(); - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto a_ = tadsA.at(e); - auto b_ = tadsB.at(e); - auto o_ = tadsO.at(e); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto a_ = tadsA.at(e); + auto b_ = tadsB.at(e); + auto o_ = tadsO.at(e); - helpers::cross(context, &a_, &b_, &o_); - } - }; + helpers::cross(context, &a_, &b_, &o_); + } + }; - samediff::Threads::parallel_tad(func, 0, tads); + samediff::Threads::parallel_tad(func, 0, tads); } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp b/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp index 27b73d001eb5..b0e782113727 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp @@ -18,89 +18,104 @@ // // -#include #include +#include namespace sd { namespace ops { namespace helpers { - template - static void __depthToSpace(const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - T const*input_ptr = reinterpret_cast(input.buffer()); - T *output_ptr = reinterpret_cast(output->buffer()); - - const int batch_size = input.sizeAt(0); - const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1); - const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2); - const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3); - - const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1); - const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2); - const int output_width = isNHWC ? output->sizeAt(2) : output->sizeAt(3); - - const int input_area = input_width * input_height; - const int input_depth_by_input_area = input_depth * input_area; - const int output_depth_by_input_height = output_depth * input_height; - - if (isNHWC) { - const int total_count = batch_size * output_height * output_width * output_depth; - auto func = PRAGMA_THREADS_FOR { - for (auto out_idx = start; out_idx < stop; out_idx++) { - const int d = out_idx % output_depth; - const int out_idx2 = out_idx / output_depth; - const int w = out_idx2 % output_width; - const int out_idx3 = out_idx2 / output_width; - const int h = out_idx3 % output_height; - const int b = out_idx3 / output_height; - - const int in_h = h / block_size; - const int offset_h = h % block_size; - const int in_w = w / block_size; - const int offset_w = w % block_size; - const int offset_d = (offset_h * block_size + offset_w) * output_depth; - const int in_d = d + offset_d; - const int inp_idx = in_d + input_depth * (in_w + input_width * (in_h + input_height * b)); - (output_ptr + out_idx)[0] = (input_ptr + inp_idx)[0]; - } - }; - - samediff::Threads::parallel_for(func, 0, total_count); - } else { - const int total_count = batch_size * input_depth_by_input_area; - - auto func = PRAGMA_THREADS_FOR { - for (int input_idx = start; input_idx < stop; input_idx++) { - const int n_bY_bX_oC_iY = input_idx / input_width; - const int iX = input_idx - n_bY_bX_oC_iY * input_width; - - const int n_bY_bX = n_bY_bX_oC_iY / output_depth_by_input_height; - const int oC_iY = n_bY_bX_oC_iY - n_bY_bX * output_depth_by_input_height; - - const int n_bY = n_bY_bX / block_size; - const int bX = n_bY_bX - n_bY * block_size; - - const int n = n_bY / block_size; - const int bY = n_bY - n * block_size; - - const int output_idx = bX + block_size * (iX + input_width * (bY + block_size * (oC_iY + n * output_depth_by_input_height))); - - (output_ptr + output_idx)[0] = (input_ptr + input_idx)[0]; - } - }; - - samediff::Threads::parallel_for(func, 0, total_count); - } - } - - void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - auto xType = input.dataType(); - - BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (input, output, block_size, isNHWC), LIBND4J_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (const NDArray &input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES); - +template +static void __depthToSpace(const NDArray &input, NDArray *output, + int block_size, bool isNHWC) { + T const *input_ptr = reinterpret_cast(input.buffer()); + T *output_ptr = reinterpret_cast(output->buffer()); + + const int batch_size = input.sizeAt(0); + const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1); + const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2); + const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3); + + const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1); + const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2); + const int output_width = isNHWC ? output->sizeAt(2) : output->sizeAt(3); + + const int input_area = input_width * input_height; + const int input_depth_by_input_area = input_depth * input_area; + const int output_depth_by_input_height = output_depth * input_height; + + if (isNHWC) { + const int total_count = + batch_size * output_height * output_width * output_depth; + auto func = PRAGMA_THREADS_FOR { + for (auto out_idx = start; out_idx < stop; out_idx++) { + const int d = out_idx % output_depth; + const int out_idx2 = out_idx / output_depth; + const int w = out_idx2 % output_width; + const int out_idx3 = out_idx2 / output_width; + const int h = out_idx3 % output_height; + const int b = out_idx3 / output_height; + + const int in_h = h / block_size; + const int offset_h = h % block_size; + const int in_w = w / block_size; + const int offset_w = w % block_size; + const int offset_d = (offset_h * block_size + offset_w) * output_depth; + const int in_d = d + offset_d; + const int inp_idx = + in_d + + input_depth * (in_w + input_width * (in_h + input_height * b)); + (output_ptr + out_idx)[0] = (input_ptr + inp_idx)[0]; + } + }; + + samediff::Threads::parallel_for(func, 0, total_count); + } else { + const int total_count = batch_size * input_depth_by_input_area; + + auto func = PRAGMA_THREADS_FOR { + for (int input_idx = start; input_idx < stop; input_idx++) { + const int n_bY_bX_oC_iY = input_idx / input_width; + const int iX = input_idx - n_bY_bX_oC_iY * input_width; + + const int n_bY_bX = n_bY_bX_oC_iY / output_depth_by_input_height; + const int oC_iY = + n_bY_bX_oC_iY - n_bY_bX * output_depth_by_input_height; + + const int n_bY = n_bY_bX / block_size; + const int bX = n_bY_bX - n_bY * block_size; + + const int n = n_bY / block_size; + const int bY = n_bY - n * block_size; + + const int output_idx = + bX + block_size * + (iX + input_width * + (bY + block_size * + (oC_iY + + n * output_depth_by_input_height))); + + (output_ptr + output_idx)[0] = (input_ptr + input_idx)[0]; + } + }; + + samediff::Threads::parallel_for(func, 0, total_count); + } } + +void _depthToSpace(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC) { + auto xType = input.dataType(); + + BUILD_SINGLE_SELECTOR(xType, __depthToSpace, + (input, output, block_size, isNHWC), LIBND4J_TYPES); } -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template void __depthToSpace, + (const NDArray &input, NDArray *output, int block_size, + bool isNHWC); + , LIBND4J_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp b/libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp index 37abaf559883..bb53d65084ce 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp @@ -19,8 +19,8 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { namespace ops { @@ -30,24 +30,19 @@ namespace helpers { // calculate digamma function for array elements template static void diGamma_(const NDArray& x, NDArray& z) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - z.p(i, diGammaScalar(x.e(i))); - }; - samediff::Threads::parallel_for(func, 0, x.lengthOf()); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) z.p(i, diGammaScalar(x.e(i))); + }; + samediff::Threads::parallel_for(func, 0, x.lengthOf()); } void diGamma(sd::LaunchContext* context, const NDArray& x, NDArray& z) { - - BUILD_SINGLE_SELECTOR(x.dataType(), diGamma_, (x, z), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(x.dataType(), diGamma_, (x, z), FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void diGamma_, (const NDArray& x, NDArray& z), FLOAT_TYPES); - - - -} -} -} +BUILD_SINGLE_TEMPLATE(template void diGamma_, (const NDArray& x, NDArray& z), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/diag.cpp b/libnd4j/include/ops/declarable/helpers/cpu/diag.cpp index 670ad532239d..351df65532b3 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/diag.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/diag.cpp @@ -25,39 +25,41 @@ namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// // Returns a batched matrix tensor with new batched diagonal values. -// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag +// for detailed explanations please take a look on web page: +// https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag template static void _diagFunctor(const NDArray* input, NDArray* output) { + const int inLength = input->lengthOf(); - const int inLength = input->lengthOf(); - - for(int i = 0; i < inLength; ++i) - output->p(i * (inLength + 1), (*input).e(i)); + for (int i = 0; i < inLength; ++i) + output->p(i * (inLength + 1), (*input).e(i)); } - void diagFunctor(sd::LaunchContext * context, const NDArray* input, NDArray* output) { - auto xType = input->dataType(); - - BUILD_SINGLE_SELECTOR(xType, _diagFunctor, (input, output), LIBND4J_TYPES); - } +void diagFunctor(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + auto xType = input->dataType(); -BUILD_SINGLE_TEMPLATE(template void _diagFunctor, (const NDArray* input, NDArray* output);, LIBND4J_TYPES); - -void diagPartFunctor(sd::LaunchContext * context, NDArray const* input, NDArray* output) { - const int outLen = output->lengthOf(); - const int inLen = input->lengthOf(); - int i(0), j(0); - while (j < outLen) { - output->p(j, input->e(i)); - i += outLen + 1; - ++j; - } + BUILD_SINGLE_SELECTOR(xType, _diagFunctor, (input, output), LIBND4J_TYPES); } +BUILD_SINGLE_TEMPLATE(template void _diagFunctor, + (const NDArray* input, NDArray* output); + , LIBND4J_TYPES); +void diagPartFunctor(sd::LaunchContext* context, NDArray const* input, + NDArray* output) { + const int outLen = output->lengthOf(); + const int inLen = input->lengthOf(); + int i(0), j(0); + while (j < outLen) { + output->p(j, input->e(i)); + i += outLen + 1; + ++j; + } } -} -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp index 1688dcbc421d..ee426c1fd7b7 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp @@ -18,83 +18,87 @@ // @autkhor raver119@gmail.com // -#include #include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////// template -static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { - - // input [bS, iH, iW, iC] - // weights [kH, kW, iC] - // output [bS, oH, oW, iC] - - const X* x = input->bufferAsT(); - const X* y = weights->bufferAsT(); - Z* z = output->bufferAsT(); - - const Nd4jLong* xShapeInfo = input->shapeInfo(); - const Nd4jLong* yShapeInfo = weights->shapeInfo(); - const Nd4jLong* zShapeInfo = output->shapeInfo(); - - const uint bS = input->sizeAt(0); - const uint iH = input->sizeAt(1); - const uint iW = input->sizeAt(2); - const uint iC = input->sizeAt(3); - - const uint kH = weights->sizeAt(0); - const uint kW = weights->sizeAt(1); - - const uint oH = output->sizeAt(1); - const uint oW = output->sizeAt(2); - - auto func = PRAGMA_THREADS_FOR_2D { - - for (auto b = start_x; b < stop_x; b += inc_x) { - for (auto oh = start_y; oh < stop_y; oh += inc_y) { - for (uint ow = 0; ow < oW; ++ow) { - for (uint c = 0; c < iC; ++c) { - - X max = -DataTypeUtils::max(); - - for (uint kh = 0; kh < kH; ++kh) { - const int ih = oh * sH - pH + kh * dH; - if (ih < 0 || ih >= iH) continue; - - for (uint kw = 0; kw < kW; ++kw) { - const int iw = ow * sW - pW + kw * dW; - if (iw < 0 || iw >= iW) continue; - - uint xCoords[4] = { static_cast(b), static_cast(ih), static_cast(iw), c}; - uint yCoords[3] = {kh, kw, c}; - - const X val = x[shape::getOffset(xShapeInfo, xCoords)] + y[shape::getOffset(yShapeInfo, yCoords)]; - if (val > max) - max = val; - } - } - - uint zCoords[4] = { static_cast(b), static_cast(oh), ow, c}; - z[shape::getOffset(zShapeInfo, zCoords)] = static_cast(max); - } - } +static void dilation2d_(NDArray* input, NDArray* weights, NDArray* output, + const int sH, const int sW, const int pH, const int pW, + const int dH, const int dW) { + // input [bS, iH, iW, iC] + // weights [kH, kW, iC] + // output [bS, oH, oW, iC] + + const X* x = input->bufferAsT(); + const X* y = weights->bufferAsT(); + Z* z = output->bufferAsT(); + + const Nd4jLong* xShapeInfo = input->shapeInfo(); + const Nd4jLong* yShapeInfo = weights->shapeInfo(); + const Nd4jLong* zShapeInfo = output->shapeInfo(); + + const uint bS = input->sizeAt(0); + const uint iH = input->sizeAt(1); + const uint iW = input->sizeAt(2); + const uint iC = input->sizeAt(3); + + const uint kH = weights->sizeAt(0); + const uint kW = weights->sizeAt(1); + + const uint oH = output->sizeAt(1); + const uint oW = output->sizeAt(2); + + auto func = PRAGMA_THREADS_FOR_2D { + for (auto b = start_x; b < stop_x; b += inc_x) { + for (auto oh = start_y; oh < stop_y; oh += inc_y) { + for (uint ow = 0; ow < oW; ++ow) { + for (uint c = 0; c < iC; ++c) { + X max = -DataTypeUtils::max(); + + for (uint kh = 0; kh < kH; ++kh) { + const int ih = oh * sH - pH + kh * dH; + if (ih < 0 || ih >= iH) continue; + + for (uint kw = 0; kw < kW; ++kw) { + const int iw = ow * sW - pW + kw * dW; + if (iw < 0 || iw >= iW) continue; + + uint xCoords[4] = {static_cast(b), static_cast(ih), + static_cast(iw), c}; + uint yCoords[3] = {kh, kw, c}; + + const X val = x[shape::getOffset(xShapeInfo, xCoords)] + + y[shape::getOffset(yShapeInfo, yCoords)]; + if (val > max) max = val; + } } + + uint zCoords[4] = {static_cast(b), static_cast(oh), ow, + c}; + z[shape::getOffset(zShapeInfo, zCoords)] = static_cast(max); + } } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1); + samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1); } -void dilation2d(sd::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), dilation2d_, (input, weights, output, sH, sW, pH, pW, dH, dW), FLOAT_TYPES); +void dilation2d(sd::LaunchContext* context, NDArray* input, NDArray* weights, + NDArray* output, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW) { + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), dilation2d_, + (input, weights, output, sH, sW, pH, pW, dH, dW), + FLOAT_TYPES); } - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp b/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp index 1d5de5cc7d46..8041770a3749 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp @@ -18,157 +18,216 @@ // @author raver119@gmail.com // -#include +#include #include -#include +#include + #include -#include +#include namespace sd { namespace ops { namespace helpers { - template - static void dropoutSimple(NDArray const* input, NDArray* output, double probValue, int seed) { - - sd::graph::RandomGenerator nodeRng(3019L, seed); - int inLen = input->lengthOf(); - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - float val = nodeRng.relativeT(e, T(0.f), T(1.f)); +template +static void dropoutSimple(NDArray const* input, NDArray* output, + double probValue, int seed) { + sd::graph::RandomGenerator nodeRng(3019L, seed); + int inLen = input->lengthOf(); - if (val < probValue) - output->p(e, input->e(e) / probValue); - } - }; + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + float val = nodeRng.relativeT(e, T(0.f), T(1.f)); - samediff::Threads::parallel_for(func, 0, inLen); + if (val < probValue) output->p(e, input->e(e) / probValue); } - BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES); - - template - int dropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - //NativeOps native; - //sd::graph::RandomGenerator nodeRng(seed); //static int dropOutFunctor_(sd::random::RandomBuffer* rng, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - //NativeOps native; - //native.reSeedBuffer(nullptr, (long)seed, rng); - //if (newRng ) - if (reduceShape == nullptr){ - dropoutSimple(input, output, probValue, seed); - } - else { - REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "dropout: Noise shape should be fittable to input"); - - std::vector dims(reduceShape->lengthOf()); - - bool fit = true; - for(auto i = 0; i < dims.size(); i++ ) { - if (fit) { - dims[i] = reduceShape->e(i); - for (int e = 0; e < input->rankOf(); ++e) - if (fit) - if (input->sizeAt(e) % dims[i]) { - fit = false; - } - } - } - - // check dims to fit input - REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank."); - std::unique_ptr chunk(new NDArray('c', dims, output->dataType(), output->getContext())); - chunk->assign(1.f); - //chunk->applyRandom>(rng, nullptr, chunk.get(), &probValue); - //NativeOpExecutioner::execRandom(random::DropOutInverted, rng, chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), &prob); - dropoutSimple(chunk.get(), chunk.get(), probValue, seed); - // broadcast chunk to full matrix - std::unique_ptr dropOutMultiplier(new NDArray(*input)); - dropOutMultiplier->assign(1.f); - - *dropOutMultiplier += *chunk; - - output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr); - } + }; - return Status::OK(); - } - - int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - auto xType = input->dataType(); - - BUILD_SINGLE_SELECTOR(xType, return dropOutFunctor_, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template int dropOutFunctor_, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES); - -/////////////////////////////////// backrpopagations /////////////////////////////////////////////// - template - static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - - int res = dropOutFunctor(context, input, output, reduceShape, seed, probValue); - - if (ND4J_STATUS_OK == res) - for (Nd4jLong e = 0; e < output->lengthOf(); e++) { - if (output->e(e) != 0.f) output->p(e, gradOut->e(e) / probValue); -// else (*output)(e) = T(0.f); + samediff::Threads::parallel_for(func, 0, inLen); +} +BUILD_SINGLE_TEMPLATE(template void dropoutSimple, + (NDArray const* input, NDArray* output, double probValue, + int seed), + FLOAT_TYPES); + +template +int dropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output, + NDArray* reduceShape, int seed, double probValue) { + // NativeOps native; + // sd::graph::RandomGenerator nodeRng(seed); //static int + // dropOutFunctor_(sd::random::RandomBuffer* rng, NDArray* input, NDArray* + // output, NDArray* reduceShape, int seed, double probValue) { NativeOps + // native; native.reSeedBuffer(nullptr, (long)seed, rng); if (newRng ) + if (reduceShape == nullptr) { + dropoutSimple(input, output, probValue, seed); + } else { + REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, + "dropout: Noise shape should be fittable to input"); + + std::vector dims(reduceShape->lengthOf()); + + bool fit = true; + for (auto i = 0; i < dims.size(); i++) { + if (fit) { + dims[i] = reduceShape->e(i); + for (int e = 0; e < input->rankOf(); ++e) + if (fit) + if (input->sizeAt(e) % dims[i]) { + fit = false; } - - return res; + } } - template - static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output, - NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - - //NativeOps native; - //auto rng = context.getRNG(); - //native.reSeedBuffer(nullptr, (long)seed, rng); - //if (rng == nullptr) - // return ND4J_STATUS_BAD_RNG; - //T probValueArr[] = {probValue, alpha, alpha1, beta}; - //input->template applyRandom>(rng, nullptr, output, probValueArr); - sd::graph::RandomGenerator nodeRng(3019L, seed); - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - float randVal = nodeRng.relativeT(e, T(0.f), T(1.f)); - float xVal = input->e(e); - output->p(e, randVal >= probValue ? alpha * beta + alpha1 : alpha * xVal + alpha1); - } - }; + // check dims to fit input + REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank."); + std::unique_ptr chunk( + new NDArray('c', dims, output->dataType(), output->getContext())); + chunk->assign(1.f); + // chunk->applyRandom>(rng, nullptr, + // chunk.get(), &probValue); + // NativeOpExecutioner::execRandom(random::DropOutInverted, rng, + // chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), + // &prob); + dropoutSimple(chunk.get(), chunk.get(), probValue, seed); + // broadcast chunk to full matrix + std::unique_ptr dropOutMultiplier(new NDArray(*input)); + dropOutMultiplier->assign(1.f); + + *dropOutMultiplier += *chunk; + + output->assign( + *input * + *dropOutMultiplier); // input->applyPairwiseTransform(pairwise::Multiply, + // dropOutMultiplier.get(), output, nullptr); + } + + return Status::OK(); +} - samediff::Threads::parallel_for(func, 0, input->lengthOf()); +int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, + NDArray* reduceShape, int seed, double probValue) { + auto xType = input->dataType(); - return Status::OK(); + BUILD_SINGLE_SELECTOR(xType, return dropOutFunctor_, + (context, input, output, reduceShape, seed, probValue), + FLOAT_TYPES); +} + +BUILD_SINGLE_TEMPLATE(template int dropOutFunctor_, + (graph::Context & context, NDArray* input, + NDArray* output, NDArray* reduceShape, int seed, + double probValue); + , FLOAT_TYPES); + +/////////////////////////////////// backrpopagations +////////////////////////////////////////////////// +template +static int dropOutFunctorBP_(graph::Context& context, NDArray* input, + NDArray* gradOut, NDArray* output, + NDArray* reduceShape, int seed, double probValue) { + int res = + dropOutFunctor(context, input, output, reduceShape, seed, probValue); + + if (ND4J_STATUS_OK == res) + for (Nd4jLong e = 0; e < output->lengthOf(); e++) { + if (output->e(e) != 0.f) + output->p(e, gradOut->e(e) / probValue); + // else (*output)(e) = T(0.f); } - template - int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, - NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { + return res; +} - int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta); - if (res == ND4J_STATUS_OK) { - (*output) *= alpha; - (*output) *= (*gradOut); //->applyPairwiseTransform(gradOut, output, nullptr); - } - return res; +template +static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, + NDArray* output, NDArray* reduceShape, int seed, + double probValue, double alpha, double alpha1, + double beta) { + // NativeOps native; + // auto rng = context.getRNG(); + // native.reSeedBuffer(nullptr, (long)seed, rng); + // if (rng == nullptr) + // return ND4J_STATUS_BAD_RNG; + // T probValueArr[] = {probValue, alpha, alpha1, beta}; + // input->template applyRandom>(rng, nullptr, + // output, probValueArr); + sd::graph::RandomGenerator nodeRng(3019L, seed); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + float randVal = nodeRng.relativeT(e, T(0.f), T(1.f)); + float xVal = input->e(e); + output->p(e, randVal >= probValue ? alpha * beta + alpha1 + : alpha * xVal + alpha1); } + }; - int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - BUILD_SINGLE_SELECTOR(gradOut->dataType(), return dropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue), FLOAT_TYPES); - } - BUILD_SINGLE_TEMPLATE(template int dropOutFunctorBP_, (graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, input->lengthOf()); - int alphaDropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - BUILD_SINGLE_SELECTOR(output->dataType(), return alphaDropOutFunctor_, (context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); - } - BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctor_, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta), FLOAT_TYPES); + return Status::OK(); +} - int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - BUILD_SINGLE_SELECTOR(gradOut->dataType(), return alphaDropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); - } - BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctorBP_, (graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta), FLOAT_TYPES); +template +int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, + NDArray* gradOut, NDArray* output, + NDArray* reduceShape, int seed, double probValue, + double alpha, double alpha1, double beta) { + int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, + probValue, alpha, alpha1, beta); + if (res == ND4J_STATUS_OK) { + (*output) *= alpha; + (*output) *= (*gradOut); //->applyPairwiseTransform(gradOut, + //output, nullptr); + } + return res; +} +int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, + NDArray* output, NDArray* reduceShape, int seed, + double probValue) { + BUILD_SINGLE_SELECTOR( + gradOut->dataType(), return dropOutFunctorBP_, + (context, input, gradOut, output, reduceShape, seed, probValue), + FLOAT_TYPES); +} +BUILD_SINGLE_TEMPLATE(template int dropOutFunctorBP_, + (graph::Context & context, NDArray* input, + NDArray* gradOut, NDArray* output, NDArray* reduceShape, + int seed, double probValue), + FLOAT_TYPES); + +int alphaDropOutFunctor(graph::Context& context, NDArray* input, + NDArray* output, NDArray* reduceShape, int seed, + double probValue, double alpha, double alpha1, + double beta) { + BUILD_SINGLE_SELECTOR(output->dataType(), return alphaDropOutFunctor_, + (context, input, output, reduceShape, seed, probValue, + alpha, alpha1, beta), + FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctor_, + (graph::Context & context, NDArray* input, + NDArray* output, NDArray* reduceShape, int seed, + double probValue, double alpha, double alpha1, + double beta), + FLOAT_TYPES); + +int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, + NDArray* gradOut, NDArray* output, + NDArray* reduceShape, int seed, double probValue, + double alpha, double alpha1, double beta) { + BUILD_SINGLE_SELECTOR(gradOut->dataType(), return alphaDropOutFunctorBP_, + (context, input, gradOut, output, reduceShape, seed, + probValue, alpha, alpha1, beta), + FLOAT_TYPES); } -} \ No newline at end of file +BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctorBP_, + (graph::Context & context, NDArray* input, + NDArray* gradOut, NDArray* output, NDArray* reduceShape, + int seed, double probValue, double alpha, double alpha1, + double beta), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp b/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp index 5547c92832f6..722bfce1de92 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/dynamic.cpp @@ -17,209 +17,258 @@ // // Created by george on 05.04.18. // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - - template - static void _dynamicPartitionFunctor(NDArray const* input, NDArray const* indices, std::vector& outputList) { - std::vector> outputs(outputList.size()); - int sourceDimsLen = input->rankOf() - indices->rankOf(); - if (sourceDimsLen) { - std::vector sourceDims(sourceDimsLen); - - for (int i = sourceDimsLen; i > 0; i--) - sourceDims[sourceDimsLen - i] = input->rankOf() - i; - - ResultSet listOfTensors = input->allTensorsAlongDimension(sourceDims); - - unsigned int outSize = outputList.size(); - - //PRAGMA_OMP_PARALLEL_FOR_IF(outSize > Environment::getInstance()->tadThreshold()) - for (unsigned int i = 0; i < outSize; i++) { - outputs[i].first = outputList[i]; - std::vector outDims(outputs[i].first->rankOf() - 1); - - int r = outputs[i].first->rankOf(); - - for (int k = 1; k < r; k++) - outDims[k - 1] = k; - - ResultSet listOutForCurrent = outputs[i].first->allTensorsAlongDimension(outDims); - - outputs[i].second = 0; - - //PRAGMA_OMP_PARALLEL_FOR_IF(indices->lengthOf() > Environment::getInstance()->elementwiseThreshold()) - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - if ((*indices).e(e) == i) - listOutForCurrent.at(outputs[i].second++).assign(listOfTensors.at(e)); - } - - } else { - unsigned int outSize = outputList.size(); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - outputs[i].first = outputList[i]; - outputs[i].second = 0; - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - if (indices->e(e) == i) - outputs[i].first->p(outputs[i].second++, input->e(e)); - } - }; - - samediff::Threads::parallel_tad(func, 0, outSize); - } - } - template - static int _dynamicStitchFunctor(std::vector const& inputs, std::vector const& indices, NDArray* output){ - - int numOfData = inputs.size(); - - if (output->isVector()) { - for (int e = 0; e < numOfData; e++) { - auto data = inputs[e]; - auto index = indices[e]; - for (Nd4jLong i = 0; i < index->lengthOf(); i++) { - Nd4jLong pos = index->e(i); - if (pos < 0) { - nd4j_printf("dynamic_stitch: Index value should be non-negative. But %i was given", pos); - return ND4J_STATUS_VALIDATION; - } - if (pos >= output->lengthOf()) { - nd4j_printf("dynamic_stitch: Index should be less than %i. But %i was given", - output->lengthOf(), pos); - return ND4J_STATUS_VALIDATION; - } - output->p(pos, data->e(i)); - } - } - } - else { - std::vector restDims(output->rankOf() - 1); - for (auto i = restDims.size(); i > 0; i--) - restDims[restDims.size() - i] = output->rankOf() - i; - - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (int e = 0; e < numOfData; e++) { - auto data = inputs[e]; - auto index = indices[e]; - std::vector sourceDims(data->rankOf() - index->rankOf()); - for (auto i = sourceDims.size(); i > 0; i--) - sourceDims[sourceDims.size() - i] = data->rankOf() - i; - - ResultSet listOfTensors = data->allTensorsAlongDimension(sourceDims) ; - - for (Nd4jLong i = 0; i < index->lengthOf(); i++) { - auto pos = index->e(i); - if (pos < 0) { - nd4j_printf("dynamic_stitch: Index value should be non-negative. But %i was given", pos); - return ND4J_STATUS_VALIDATION; - } - if (pos >= output->lengthOf()) { - nd4j_printf("dynamic_stitch: Index should be less than %i. But %i was given", - output->lengthOf(), pos); - return ND4J_STATUS_VALIDATION; - } - - listOfOutTensors.at(pos).assign(listOfTensors.at(i)); - } - } - } - return ND4J_STATUS_OK; - } - - template - static void _dynamicPartitionFunctorBP(NDArray const* input, NDArray const* indices, std::vector const& inputGradientList, std::vector& outputList) { - std::vector> outputs(inputGradientList.size()); - - int sourceDimsLen = input->rankOf() - indices->rankOf(); - if (sourceDimsLen) { // multidimensional case - std::vector sourceDims(sourceDimsLen); - - for (int i = sourceDimsLen; i > 0; i--) - sourceDims[sourceDimsLen - i] = input->rankOf() - i; - - ResultSet listOfTensors = outputList[0]->allTensorsAlongDimension(sourceDims); - - for (auto i = 0; i < inputGradientList.size(); i++) { - outputs[i].first = inputGradientList[i]; - if (outputs[i].first->rankOf() < 1) continue; // skip empty gradient outs - std::vector outDims(outputs[i].first->rankOf() - 1); - - for (int k = 1; k < outputs[i].first->rankOf(); k++) - outDims[k - 1] = k; - - ResultSet listOutForCurrent = outputs[i].first->allTensorsAlongDimension(outDims); - - outputs[i].second = 0; - - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - if (indices->e(e) == i) - listOfTensors.at(e).assign(listOutForCurrent.at(outputs[i].second++)); - } - } - else { // one-dimensional case - auto output = outputList[0]; - unsigned int gradsSize = inputGradientList.size(); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - outputs[i].first = inputGradientList[i]; - outputs[i].second = 0; - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - if (indices->e(e) == i) - output->p(e, outputs[i].first->e(outputs[i].second++)); - } - }; - - samediff::Threads::parallel_tad(func, 0, gradsSize); - } - - outputList[1]->assign(indices); - } - - void dynamicPartitionFunctor(sd::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector& outputList) { - auto xType = input->dataType(); - - BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctor, (input, indices, outputList), LIBND4J_TYPES); - } - - template - static int _dynamicStitchFunctorBP(std::vector const& inputs, std::vector const& indices, NDArray const* gradInput, std::vector& outputList){ - throw std::runtime_error("Not umplemented yet"); - } - - int dynamicStitchFunctor(sd::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray* output){ - auto xType = inputs.at(0)->dataType(); +namespace ops { +namespace helpers { + +template +static void _dynamicPartitionFunctor(NDArray const* input, + NDArray const* indices, + std::vector& outputList) { + std::vector> outputs(outputList.size()); + int sourceDimsLen = input->rankOf() - indices->rankOf(); + if (sourceDimsLen) { + std::vector sourceDims(sourceDimsLen); - BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctor, (inputs, indices, output), LIBND4J_TYPES); - } + for (int i = sourceDimsLen; i > 0; i--) + sourceDims[sourceDimsLen - i] = input->rankOf() - i; - int dynamicStitchFunctorBP(sd::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray const* gradInput, std::vector& outputList) { - auto xType = inputs.at(0)->dataType(); + ResultSet listOfTensors = input->allTensorsAlongDimension(sourceDims); - BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctorBP, (inputs, indices, gradInput, outputList), LIBND4J_TYPES); - } + unsigned int outSize = outputList.size(); - void dynamicPartitionFunctorBP(sd::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector const& inputGradientList, std::vector& outputList) { - auto xType = input->dataType(); + // PRAGMA_OMP_PARALLEL_FOR_IF(outSize > + // Environment::getInstance()->tadThreshold()) + for (unsigned int i = 0; i < outSize; i++) { + outputs[i].first = outputList[i]; + std::vector outDims(outputs[i].first->rankOf() - 1); - BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctorBP, (input, indices, inputGradientList, outputList), LIBND4J_TYPES); - } + int r = outputs[i].first->rankOf(); - BUILD_SINGLE_TEMPLATE(template void _dynamicPartitionFunctorBP, (NDArray const* input, NDArray const* indices, std::vector const& inputGradientList, std::vector& outputList);, LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template int _dynamicStitchFunctorBP, (std::vector const& inputs, std::vector const& indices, NDArray const* gradInput, std::vector& outputList);, LIBND4J_TYPES); + for (int k = 1; k < r; k++) outDims[k - 1] = k; - BUILD_SINGLE_TEMPLATE(template void _dynamicPartitionFunctor, (NDArray const* input, NDArray const* indices, std::vector& outputList);, LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template int _dynamicStitchFunctor, (std::vector const& inputs, std::vector const& indices, NDArray* output);, LIBND4J_TYPES); + ResultSet listOutForCurrent = + outputs[i].first->allTensorsAlongDimension(outDims); + outputs[i].second = 0; + + // PRAGMA_OMP_PARALLEL_FOR_IF(indices->lengthOf() > + // Environment::getInstance()->elementwiseThreshold()) + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + if ((*indices).e(e) == i) + listOutForCurrent.at(outputs[i].second++).assign(listOfTensors.at(e)); + } + } else { + unsigned int outSize = outputList.size(); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + outputs[i].first = outputList[i]; + outputs[i].second = 0; + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + if (indices->e(e) == i) + outputs[i].first->p(outputs[i].second++, input->e(e)); + } + }; + + samediff::Threads::parallel_tad(func, 0, outSize); + } +} +template +static int _dynamicStitchFunctor(std::vector const& inputs, + std::vector const& indices, + NDArray* output) { + int numOfData = inputs.size(); + + if (output->isVector()) { + for (int e = 0; e < numOfData; e++) { + auto data = inputs[e]; + auto index = indices[e]; + for (Nd4jLong i = 0; i < index->lengthOf(); i++) { + Nd4jLong pos = index->e(i); + if (pos < 0) { + nd4j_printf( + "dynamic_stitch: Index value should be non-negative. But %i was " + "given", + pos); + return ND4J_STATUS_VALIDATION; + } + if (pos >= output->lengthOf()) { + nd4j_printf( + "dynamic_stitch: Index should be less than %i. But %i was given", + output->lengthOf(), pos); + return ND4J_STATUS_VALIDATION; + } + output->p(pos, data->e(i)); + } + } + } else { + std::vector restDims(output->rankOf() - 1); + for (auto i = restDims.size(); i > 0; i--) + restDims[restDims.size() - i] = output->rankOf() - i; + + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + for (int e = 0; e < numOfData; e++) { + auto data = inputs[e]; + auto index = indices[e]; + std::vector sourceDims(data->rankOf() - index->rankOf()); + for (auto i = sourceDims.size(); i > 0; i--) + sourceDims[sourceDims.size() - i] = data->rankOf() - i; + + ResultSet listOfTensors = data->allTensorsAlongDimension(sourceDims); + + for (Nd4jLong i = 0; i < index->lengthOf(); i++) { + auto pos = index->e(i); + if (pos < 0) { + nd4j_printf( + "dynamic_stitch: Index value should be non-negative. But %i was " + "given", + pos); + return ND4J_STATUS_VALIDATION; } + if (pos >= output->lengthOf()) { + nd4j_printf( + "dynamic_stitch: Index should be less than %i. But %i was given", + output->lengthOf(), pos); + return ND4J_STATUS_VALIDATION; + } + + listOfOutTensors.at(pos).assign(listOfTensors.at(i)); + } + } + } + return ND4J_STATUS_OK; +} + +template +static void _dynamicPartitionFunctorBP( + NDArray const* input, NDArray const* indices, + std::vector const& inputGradientList, + std::vector& outputList) { + std::vector> outputs(inputGradientList.size()); + + int sourceDimsLen = input->rankOf() - indices->rankOf(); + if (sourceDimsLen) { // multidimensional case + std::vector sourceDims(sourceDimsLen); + + for (int i = sourceDimsLen; i > 0; i--) + sourceDims[sourceDimsLen - i] = input->rankOf() - i; + + ResultSet listOfTensors = + outputList[0]->allTensorsAlongDimension(sourceDims); + + for (auto i = 0; i < inputGradientList.size(); i++) { + outputs[i].first = inputGradientList[i]; + if (outputs[i].first->rankOf() < 1) continue; // skip empty gradient outs + std::vector outDims(outputs[i].first->rankOf() - 1); + + for (int k = 1; k < outputs[i].first->rankOf(); k++) outDims[k - 1] = k; + + ResultSet listOutForCurrent = + outputs[i].first->allTensorsAlongDimension(outDims); + + outputs[i].second = 0; + + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + if (indices->e(e) == i) + listOfTensors.at(e).assign(listOutForCurrent.at(outputs[i].second++)); } + } else { // one-dimensional case + auto output = outputList[0]; + unsigned int gradsSize = inputGradientList.size(); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + outputs[i].first = inputGradientList[i]; + outputs[i].second = 0; + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + if (indices->e(e) == i) + output->p(e, outputs[i].first->e(outputs[i].second++)); + } + }; + + samediff::Threads::parallel_tad(func, 0, gradsSize); + } + + outputList[1]->assign(indices); +} + +void dynamicPartitionFunctor(sd::LaunchContext* context, NDArray const* input, + NDArray const* indices, + std::vector& outputList) { + auto xType = input->dataType(); + + BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctor, + (input, indices, outputList), LIBND4J_TYPES); +} + +template +static int _dynamicStitchFunctorBP(std::vector const& inputs, + std::vector const& indices, + NDArray const* gradInput, + std::vector& outputList) { + throw std::runtime_error("Not umplemented yet"); +} + +int dynamicStitchFunctor(sd::LaunchContext* context, + std::vector const& inputs, + std::vector const& indices, + NDArray* output) { + auto xType = inputs.at(0)->dataType(); + + BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctor, + (inputs, indices, output), LIBND4J_TYPES); +} + +int dynamicStitchFunctorBP(sd::LaunchContext* context, + std::vector const& inputs, + std::vector const& indices, + NDArray const* gradInput, + std::vector& outputList) { + auto xType = inputs.at(0)->dataType(); + + BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctorBP, + (inputs, indices, gradInput, outputList), + LIBND4J_TYPES); +} + +void dynamicPartitionFunctorBP(sd::LaunchContext* context, NDArray const* input, + NDArray const* indices, + std::vector const& inputGradientList, + std::vector& outputList) { + auto xType = input->dataType(); + + BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctorBP, + (input, indices, inputGradientList, outputList), + LIBND4J_TYPES); } +BUILD_SINGLE_TEMPLATE(template void _dynamicPartitionFunctorBP, + (NDArray const* input, NDArray const* indices, + std::vector const& inputGradientList, + std::vector& outputList); + , LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template int _dynamicStitchFunctorBP, + (std::vector const& inputs, + std::vector const& indices, + NDArray const* gradInput, + std::vector& outputList); + , LIBND4J_TYPES); + +BUILD_SINGLE_TEMPLATE(template void _dynamicPartitionFunctor, + (NDArray const* input, NDArray const* indices, + std::vector& outputList); + , LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template int _dynamicStitchFunctor, + (std::vector const& inputs, + std::vector const& indices, NDArray* output); + , LIBND4J_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp b/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp index 377ea559fffe..3174b4d2e6eb 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp @@ -18,82 +18,93 @@ // @author sgazeos@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { - template - static void _extractPatches(NDArray* images, NDArray* output, int sizeRow, int sizeCol, int strideRow, int strideCol, int rateRow, int rateCol, bool theSame){ - std::vector restDims({1, 2, 3}); // the first and the last dims - ResultSet listOfMatricies = images->allTensorsAlongDimension(restDims); - ResultSet listOfOutputs = output->allTensorsAlongDimension(restDims); - // 3D matricies - 2D matricies of vectors (if last dim is greater than 1) - //int e = 0; - const int ksizeRowsEffective = sizeRow + (sizeRow - 1) * (rateRow - 1); - const int ksizeColsEffective = sizeCol + (sizeCol - 1) * (rateCol - 1); - const int ksize = ksizeRowsEffective * ksizeColsEffective; - int batchCount = listOfMatricies.size(); //lengthOf() / ksize; - Nd4jLong lastDim = images->sizeAt(3); - Nd4jLong outLastDim = output->sizeAt(3); - Nd4jLong rowDim = images->sizeAt(1); - Nd4jLong colDim = images->sizeAt(2); - Nd4jLong outRowDim = output->sizeAt(1); - Nd4jLong outColDim = output->sizeAt(2); - auto rowCast = 1; //(sizeRow - 1)*rateRow < outRowDim/sizeRow ?0:1;///(ksize * lastDim > rowDim * ksizeColsEffective + lastDim?1:0); - auto colCast = 1; //colDim / ksizeColsEffective +2 <= sizeCol?0:1;//(ksize * lastDim > ksizeRowsEffective * colDim + lastDim?1:0); - if (sizeRow * rateRow < 3) - rowCast = 0; - if (sizeCol * rateCol < 3) - colCast = 0; +template +static void _extractPatches(NDArray* images, NDArray* output, int sizeRow, + int sizeCol, int strideRow, int strideCol, + int rateRow, int rateCol, bool theSame) { + std::vector restDims({1, 2, 3}); // the first and the last dims + ResultSet listOfMatricies = images->allTensorsAlongDimension(restDims); + ResultSet listOfOutputs = output->allTensorsAlongDimension(restDims); + // 3D matricies - 2D matricies of vectors (if last dim is greater than 1) + // int e = 0; + const int ksizeRowsEffective = sizeRow + (sizeRow - 1) * (rateRow - 1); + const int ksizeColsEffective = sizeCol + (sizeCol - 1) * (rateCol - 1); + const int ksize = ksizeRowsEffective * ksizeColsEffective; + int batchCount = listOfMatricies.size(); // lengthOf() / ksize; + Nd4jLong lastDim = images->sizeAt(3); + Nd4jLong outLastDim = output->sizeAt(3); + Nd4jLong rowDim = images->sizeAt(1); + Nd4jLong colDim = images->sizeAt(2); + Nd4jLong outRowDim = output->sizeAt(1); + Nd4jLong outColDim = output->sizeAt(2); + auto rowCast = 1; //(sizeRow - 1)*rateRow < outRowDim/sizeRow ?0:1;///(ksize + //* lastDim > rowDim * ksizeColsEffective + lastDim?1:0); + auto colCast = 1; // colDim / ksizeColsEffective +2 <= sizeCol?0:1;//(ksize * + // lastDim > ksizeRowsEffective * colDim + lastDim?1:0); + if (sizeRow * rateRow < 3) rowCast = 0; + if (sizeCol * rateCol < 3) colCast = 0; - auto func = PRAGMA_THREADS_FOR { - for (auto batch = 0; batch < stop; batch++) { - auto patch = listOfMatricies.at(batch); - auto outMatrix = listOfOutputs.at(batch); + auto func = PRAGMA_THREADS_FOR { + for (auto batch = 0; batch < stop; batch++) { + auto patch = listOfMatricies.at(batch); + auto outMatrix = listOfOutputs.at(batch); - for (Nd4jLong i = 0; i < outRowDim; i++) { - for (Nd4jLong j = 0; j < outColDim; j++) { - Nd4jLong pos = 0; - //for (Nd4jLong k = 0; k < outputLastDim; k++) { - auto rowStart = i * strideRow - (theSame ? rowCast : 0); - auto colStart = j * strideCol - (theSame ? colCast : 0); - auto rowEnd = rowStart + sizeRow * rateRow; - auto colEnd = colStart + sizeCol * rateCol; - if (!theSame) { - rowEnd = math::nd4j_min(rowStart + sizeRow * rateRow, rowDim); - colEnd = math::nd4j_min(colStart + sizeCol * rateCol, colDim); - } - //auto pixel = 0LL; - for (auto row = rowStart; row < rowEnd; row += rateRow) - for (auto col = colStart; col < colEnd; col += rateCol) - for (auto pixel = 0; pixel < lastDim; pixel++) { - bool setUp = (theSame && row >= 0 && col >= 0 && row < rowDim && col < colDim) || - (!theSame); - if (setUp) { - outMatrix.t(i, j, pos) = patch.e(row, col, pixel); - } - pos++; - } - } - } - } - }; - - samediff::Threads::parallel_tad(func, 0, batchCount); + for (Nd4jLong i = 0; i < outRowDim; i++) { + for (Nd4jLong j = 0; j < outColDim; j++) { + Nd4jLong pos = 0; + // for (Nd4jLong k = 0; k < outputLastDim; k++) { + auto rowStart = i * strideRow - (theSame ? rowCast : 0); + auto colStart = j * strideCol - (theSame ? colCast : 0); + auto rowEnd = rowStart + sizeRow * rateRow; + auto colEnd = colStart + sizeCol * rateCol; + if (!theSame) { + rowEnd = math::nd4j_min(rowStart + sizeRow * rateRow, rowDim); + colEnd = math::nd4j_min(colStart + sizeCol * rateCol, colDim); + } + // auto pixel = 0LL; + for (auto row = rowStart; row < rowEnd; row += rateRow) + for (auto col = colStart; col < colEnd; col += rateCol) + for (auto pixel = 0; pixel < lastDim; pixel++) { + bool setUp = (theSame && row >= 0 && col >= 0 && row < rowDim && + col < colDim) || + (!theSame); + if (setUp) { + outMatrix.t(i, j, pos) = patch.e(row, col, pixel); + } + pos++; + } + } + } } + }; + samediff::Threads::parallel_tad(func, 0, batchCount); +} - void extractPatches(sd::LaunchContext * context, NDArray* images, NDArray* output, int sizeRow, int sizeCol, int stradeRow, int stradeCol, int rateRow, int rateCol, bool theSame){ - auto xType = images->dataType(); +void extractPatches(sd::LaunchContext* context, NDArray* images, + NDArray* output, int sizeRow, int sizeCol, int stradeRow, + int stradeCol, int rateRow, int rateCol, bool theSame) { + auto xType = images->dataType(); - BUILD_SINGLE_SELECTOR(xType, _extractPatches, (images, output, sizeRow, sizeCol, stradeRow, stradeCol, rateRow, rateCol, theSame), LIBND4J_TYPES); - } + BUILD_SINGLE_SELECTOR(xType, _extractPatches, + (images, output, sizeRow, sizeCol, stradeRow, stradeCol, + rateRow, rateCol, theSame), + LIBND4J_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void _extractPatches, (NDArray* input, NDArray* output, int sizeRow, int sizeCol, int stradeRow, int stradeCol, int rateRow, int rateCol, bool theSame), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void _extractPatches, + (NDArray * input, NDArray* output, int sizeRow, + int sizeCol, int stradeRow, int stradeCol, int rateRow, + int rateCol, bool theSame), + LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/eye.cpp b/libnd4j/include/ops/declarable/helpers/cpu/eye.cpp index 9c343eafdf0b..1a7fc19e758d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/eye.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/eye.cpp @@ -18,28 +18,25 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -void eye(sd::LaunchContext * context, NDArray& output) { - - const int rank = output.rankOf(); - auto arrs = output.allTensorsAlongDimension({rank-2, rank-1}); +void eye(sd::LaunchContext* context, NDArray& output) { + const int rank = output.rankOf(); + auto arrs = output.allTensorsAlongDimension({rank - 2, rank - 1}); - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - arrs.at(i).setIdentity(); - }; + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) arrs.at(i).setIdentity(); + }; - samediff::Threads::parallel_tad(func, 0, arrs.size()); + samediff::Threads::parallel_tad(func, 0, arrs.size()); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp index d2c918da9081..11d1fe456860 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/fake_quantization.cpp @@ -18,105 +18,124 @@ // @author sgazeos@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { - // - // nudge - nudged min max over scale - // scale = (Max - Min) / (quantMax - quantMin) - // quantMin = 0 or 1, quantMax = 2^b - 1 == (1 << b) - 1 - // - template - static void nudge(T min, T max, int quantMin, int quantMax, T* scale, T* nudgedMin, T* nudgedMax) { - // floating point instead integers - T quantMaxF = static_cast(quantMax); - T quantMinF = static_cast(quantMin); - // compute scale - *scale = (max - min) / (quantMaxF - quantMinF); - // compute left bound point - auto zeroPointFromMin = quantMinF - min / *scale; - // bound zero point to conform with range [0 or 1, 2^b - 1] - uint16_t const nudged_zero_point = [zeroPointFromMin, quantMin, quantMax, quantMaxF, quantMinF] { - if (zeroPointFromMin < quantMinF) { - return static_cast(quantMin); - } - if (zeroPointFromMin > quantMaxF) { - return static_cast(quantMax); - } - return (uint16_t)sd::math::nd4j_round(zeroPointFromMin); - }(); - // compute nudged min and max with computed nudged zero point - *nudgedMin = (quantMinF - nudged_zero_point) * (*scale); - *nudgedMax = (quantMaxF - nudged_zero_point) * (*scale); +// +// nudge - nudged min max over scale +// scale = (Max - Min) / (quantMax - quantMin) +// quantMin = 0 or 1, quantMax = 2^b - 1 == (1 << b) - 1 +// +template +static void nudge(T min, T max, int quantMin, int quantMax, T* scale, + T* nudgedMin, T* nudgedMax) { + // floating point instead integers + T quantMaxF = static_cast(quantMax); + T quantMinF = static_cast(quantMin); + // compute scale + *scale = (max - min) / (quantMaxF - quantMinF); + // compute left bound point + auto zeroPointFromMin = quantMinF - min / *scale; + // bound zero point to conform with range [0 or 1, 2^b - 1] + uint16_t const nudged_zero_point = [zeroPointFromMin, quantMin, quantMax, + quantMaxF, quantMinF] { + if (zeroPointFromMin < quantMinF) { + return static_cast(quantMin); } + if (zeroPointFromMin > quantMaxF) { + return static_cast(quantMax); + } + return (uint16_t)sd::math::nd4j_round(zeroPointFromMin); + }(); + // compute nudged min and max with computed nudged zero point + *nudgedMin = (quantMinF - nudged_zero_point) * (*scale); + *nudgedMax = (quantMaxF - nudged_zero_point) * (*scale); +} - template - void fakeQuantWithMinMaxVarsPerChannel_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - int lowIntBound = narrowed ? 1 : 0; // 0 or 1 - int upperIntBound = (1 << numBits) - 1; // 2^b - 1 - auto channels = input->sizeAt(-1); // last dimension +template +void fakeQuantWithMinMaxVarsPerChannel_(NDArray* input, NDArray* min, + NDArray* max, int numBits, + bool narrowed, NDArray* output) { + int lowIntBound = narrowed ? 1 : 0; // 0 or 1 + int upperIntBound = (1 << numBits) - 1; // 2^b - 1 + auto channels = input->sizeAt(-1); // last dimension - PRAGMA_OMP_PARALLEL_FOR - for (auto i = 0; i < channels; i++) { - T scale, nudged_min, nudged_max; - // nudge min and max first, with scale computing - nudge(min->t(i), max->t(i), lowIntBound, upperIntBound, &scale, &nudged_min, &nudged_max); - // slide using last dimension and process all for given channel - for (auto e = 0; e < input->lengthOf(); e += channels) { - T val = input->t(e + i); - if ( val <= nudged_min) - val = nudged_min; - else if (val >= nudged_max) - val = nudged_max; - // quantization itself - output->t(e + i) = math::nd4j_floor((val - nudged_min)/scale + T(0.5)) * scale + nudged_min; - } - } + PRAGMA_OMP_PARALLEL_FOR + for (auto i = 0; i < channels; i++) { + T scale, nudged_min, nudged_max; + // nudge min and max first, with scale computing + nudge(min->t(i), max->t(i), lowIntBound, upperIntBound, &scale, + &nudged_min, &nudged_max); + // slide using last dimension and process all for given channel + for (auto e = 0; e < input->lengthOf(); e += channels) { + T val = input->t(e + i); + if (val <= nudged_min) + val = nudged_min; + else if (val >= nudged_max) + val = nudged_max; + // quantization itself + output->t(e + i) = + math::nd4j_floor((val - nudged_min) / scale + T(0.5)) * scale + + nudged_min; } + } +} // -//const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min); +// const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min); // const auto clamped_shifted = clamped - nudged_min; // outputs.device(d) = (clamped_shifted / nudged_scale_repl + 0.5f).floor() * // nudged_scale_repl + // nudged_min; // - template - void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - int lowIntBound = narrowed ? 1 : 0; - int upperIntBound = (1 << numBits) - 1; - - T nudgedMin, nudgedMax, scale; - // nudge with given min and max and compute scale and nudged min and max - nudge(min->t(0), max->t(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax); - // quantization as one - auto fakeQuantizationWithMinMax = LAMBDA_T(x, nudgedMin, nudgedMax, scale) { - T val = x; // boundign value between nudged min and max - if (val < nudgedMin) { - val = nudgedMin; - } - else if (val > nudgedMax) - val = nudgedMax; - // converse value with scale and shifted with nudged min - val -= nudgedMin; - return (sd::math::nd4j_floor(val / scale + T(0.5f)) * scale + nudgedMin); - }; +template +void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, + int numBits, bool narrowed, NDArray* output) { + int lowIntBound = narrowed ? 1 : 0; + int upperIntBound = (1 << numBits) - 1; - input->applyLambda(fakeQuantizationWithMinMax, *output); - } - - void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES); - } - void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVars_, (NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output), FLOAT_TYPES); + T nudgedMin, nudgedMax, scale; + // nudge with given min and max and compute scale and nudged min and max + nudge(min->t(0), max->t(0), lowIntBound, upperIntBound, &scale, + &nudgedMin, &nudgedMax); + // quantization as one + auto fakeQuantizationWithMinMax = LAMBDA_T(x, nudgedMin, nudgedMax, scale) { + T val = x; // boundign value between nudged min and max + if (val < nudgedMin) { + val = nudgedMin; + } else if (val > nudgedMax) + val = nudgedMax; + // converse value with scale and shifted with nudged min + val -= nudgedMin; + return (sd::math::nd4j_floor(val / scale + T(0.5f)) * scale + + nudgedMin); + }; + input->applyLambda(fakeQuantizationWithMinMax, *output); } + +void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, + int numBits, bool narrowed, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, + (input, min, max, numBits, narrowed, output), + FLOAT_TYPES); } +void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, + NDArray* min, NDArray* max, int numBits, + bool narrowed, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, + (input, min, max, numBits, narrowed, output), + FLOAT_TYPES); } + +BUILD_SINGLE_TEMPLATE(template void fakeQuantWithMinMaxVars_, + (NDArray * input, NDArray* min, NDArray* max, int numBits, + bool narrowed, NDArray* output), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/flatten.cpp b/libnd4j/include/ops/declarable/helpers/cpu/flatten.cpp index aadd74298e7a..9b1385ca2dac 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/flatten.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/flatten.cpp @@ -21,38 +21,40 @@ #include namespace sd { - namespace ops { - namespace helpers { - - template - static void flatten_(std::vector &inputs, NDArray *output, const char order) { - - int numArrays = inputs.size(); - std::vector offsets(numArrays); - Nd4jLong cOffset = 0; - - // calculating offsets in output - for (int e = 0; e < numArrays; e++) { - offsets[e] = cOffset; - cOffset += inputs[e]->lengthOf(); - } - - // actually transferring data - for (int e = 0; e < numArrays; e++) { - auto z = reinterpret_cast(output->bufferWithOffset(offsets[e])); - - auto xBuffer = inputs[e]->bufferAsT(); - auto xShapeInfo = inputs[e]->shapeInfo(); - auto xLength = inputs[e]->lengthOf(); - - for (Nd4jLong i = 0; i < xLength; i++) - z[i] = xBuffer[getIndexOffsetOrdered(i, xShapeInfo, order)]; - } - } - - void flatten(sd::LaunchContext *context, std::vector &inputs, NDArray *output, char order) { - BUILD_SINGLE_SELECTOR(output->dataType(), flatten_, (inputs, output, order), LIBND4J_TYPES); - } - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { + +template +static void flatten_(std::vector &inputs, NDArray *output, + const char order) { + int numArrays = inputs.size(); + std::vector offsets(numArrays); + Nd4jLong cOffset = 0; + + // calculating offsets in output + for (int e = 0; e < numArrays; e++) { + offsets[e] = cOffset; + cOffset += inputs[e]->lengthOf(); + } + + // actually transferring data + for (int e = 0; e < numArrays; e++) { + auto z = reinterpret_cast(output->bufferWithOffset(offsets[e])); + + auto xBuffer = inputs[e]->bufferAsT(); + auto xShapeInfo = inputs[e]->shapeInfo(); + auto xLength = inputs[e]->lengthOf(); + + for (Nd4jLong i = 0; i < xLength; i++) + z[i] = xBuffer[getIndexOffsetOrdered(i, xShapeInfo, order)]; + } +} + +void flatten(sd::LaunchContext *context, std::vector &inputs, + NDArray *output, char order) { + BUILD_SINGLE_SELECTOR(output->dataType(), flatten_, (inputs, output, order), + LIBND4J_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp index 1deb12752e9b..923d3eac6e68 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp @@ -18,160 +18,168 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 07.03.2019 // -#include -#include #include -#include #include +#include +#include + +#include namespace sd { namespace ops { namespace helpers { //////////////////////////////////////////////////////////////////////// -void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* indices, NDArray* output, const std::vector& intArgs) { - - int axis = intArgs.size() > 0 ? intArgs[0] : 0; - const int inputRank = input->rankOf(); - if(axis < 0) - axis += inputRank; - - const int numOfIntArgs = intArgs.size(); - - if (indices != nullptr) { - - // first case: indices consist of only one scalar - if(indices->isScalar()) { - - if(input->rankOf() <= 1){ - //For scalar indices, rank 0 or 1 input: can't do tensor along dimension 0 as this is whole array... instead, we want to get a scalar - auto idx = indices->e(0); - auto scalarNDArray = input->e(idx); - output->assign(scalarNDArray); +void gather(sd::LaunchContext* context, const NDArray* input, + const NDArray* indices, NDArray* output, + const std::vector& intArgs) { + int axis = intArgs.size() > 0 ? intArgs[0] : 0; + const int inputRank = input->rankOf(); + if (axis < 0) axis += inputRank; + + const int numOfIntArgs = intArgs.size(); + + if (indices != nullptr) { + // first case: indices consist of only one scalar + if (indices->isScalar()) { + if (input->rankOf() <= 1) { + // For scalar indices, rank 0 or 1 input: can't do tensor along + // dimension 0 as this is whole array... instead, we want to get a scalar + auto idx = indices->e(0); + auto scalarNDArray = input->e(idx); + output->assign(scalarNDArray); + } else { + NDArray inSubArr = (*input)(indices->e(0), {axis}); + output->assign(inSubArr); + } + } else { + if (input->rankOf() == 1 && output->rankOf() == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + output->p(i, input->e(indices->e(i))); + }; + + samediff::Threads::parallel_for(func, 0, output->lengthOf()); + + } else { + std::vector dimsOut; + for (int i = 0; i < axis; ++i) dimsOut.push_back(i); + for (int i = axis + indices->rankOf(); i < output->rankOf(); ++i) + dimsOut.push_back(i); + + std::vector dimsIn = + ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); + + const Nd4jLong numOfSubArrs = indices->lengthOf(); + + auto inTadPack = ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimsIn); + auto outTadPack = ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimsOut); + + auto inTadShapeInfo = inTadPack.primaryShapeInfo(); + auto outTadShapeInfo = outTadPack.primaryShapeInfo(); + + if (shape::order(inTadShapeInfo) == shape::order(outTadShapeInfo) && + shape::order(inTadShapeInfo) == 'c' && + input->dataType() == output->dataType() && + shape::elementWiseStride(inTadShapeInfo) == 1 && + shape::elementWiseStride(outTadShapeInfo) == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inBuff = input->bufferWithOffset( + inTadPack.primaryOffsets()[indices->e(i)]); + auto outBuff = + output->bufferWithOffset(outTadPack.primaryOffsets()[i]); + + memcpy(outBuff, inBuff, + shape::length(inTadShapeInfo) * input->sizeOfT()); } - else { - NDArray inSubArr = (*input)(indices->e(0), {axis}); - output->assign(inSubArr); + }; + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inBuff = input->bufferWithOffset( + inTadPack.primaryOffsets()[indices->e(i)]); + auto outBuff = + output->bufferWithOffset(outTadPack.primaryOffsets()[i]); + + NativeOpExecutioner::execTransformAny( + input->getContext(), transform::Assign, inBuff, + inTadShapeInfo, nullptr /*input specialBuffer*/, + nullptr /*input specialShapeInfo*/, outBuff, outTadShapeInfo, + nullptr /*output specialBuffer*/, + nullptr /*output specialShapeInfo*/, nullptr, nullptr, + nullptr, false /*allowParallelism*/); } - } - else { - - if(input->rankOf() == 1 && output->rankOf() == 1) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - output->p(i, input->e(indices->e(i))); - }; - - samediff::Threads::parallel_for(func, 0, output->lengthOf()); - - } - else { - - std::vector dimsOut; - for (int i = 0; i < axis; ++i) - dimsOut.push_back(i); - for (int i = axis+indices->rankOf(); i < output->rankOf(); ++i) - dimsOut.push_back(i); - - std::vector dimsIn = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); - - const Nd4jLong numOfSubArrs = indices->lengthOf(); - - auto inTadPack = ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimsIn); - auto outTadPack = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimsOut); - - auto inTadShapeInfo = inTadPack.primaryShapeInfo(); - auto outTadShapeInfo = outTadPack.primaryShapeInfo(); - - if (shape::order(inTadShapeInfo) == shape::order(outTadShapeInfo) && shape::order(inTadShapeInfo) == 'c' && input->dataType() == output->dataType() && shape::elementWiseStride(inTadShapeInfo) == 1 && shape::elementWiseStride(outTadShapeInfo) == 1) { - - auto func = PRAGMA_THREADS_FOR { - - for (auto i = start; i < stop; i++) { - auto inBuff = input->bufferWithOffset(inTadPack.primaryOffsets()[indices->e(i)]); - auto outBuff = output->bufferWithOffset(outTadPack.primaryOffsets()[i]); - - memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT()); - } - }; - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - } - else { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { + }; - auto inBuff = input->bufferWithOffset(inTadPack.primaryOffsets()[indices->e(i)]); - auto outBuff = output->bufferWithOffset(outTadPack.primaryOffsets()[i]); - - NativeOpExecutioner::execTransformAny(input->getContext(), transform::Assign, - inBuff, inTadShapeInfo, nullptr/*input specialBuffer*/, nullptr/*input specialShapeInfo*/, - outBuff, outTadShapeInfo, nullptr/*output specialBuffer*/, nullptr/*output specialShapeInfo*/, - nullptr, nullptr, nullptr, false/*allowParallelism*/); - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - } - } + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); } + } } - else { - - // we only allow scalar/vector case here - if (numOfIntArgs == 2) { // scalar case - - output->assign((*input)(intArgs[1], {axis})); - } - else { // vector case - - const Nd4jLong numOfSubArrs = intArgs.size() - 1; - - std::vector dims = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); - - auto inTadPack = ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dims); - auto outTadPack = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dims); - - auto inTadShapeInfo = inTadPack.primaryShapeInfo(); - auto outTadShapeInfo = outTadPack.primaryShapeInfo(); - - if (shape::order(inTadShapeInfo) == shape::order(outTadShapeInfo) && shape::order(inTadShapeInfo) == 'c' && input->dataType() == output->dataType() && shape::elementWiseStride(inTadShapeInfo) == 1 && shape::elementWiseStride(outTadShapeInfo) == 1) { - - auto func = PRAGMA_THREADS_FOR { - - for (auto i = start; i < stop; i++) { - auto inBuff = input->bufferWithOffset(inTadPack.primaryOffsets()[intArgs[i + 1]]); - void* outBuff = output->bufferWithOffset(outTadPack.primaryOffsets()[i]); - - std::memcpy(outBuff, inBuff, shape::length(inTadShapeInfo) * input->sizeOfT()); - } - }; - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - - } - else { - - auto func = PRAGMA_THREADS_FOR { - - for (auto i = start; i < stop; i++) { - auto inBuff = input->bufferWithOffset(inTadPack.primaryOffsets()[intArgs[i + 1]]); - auto outBuff = output->bufferWithOffset(outTadPack.primaryOffsets()[i]); - - NativeOpExecutioner::execTransformAny(input->getContext(), transform::Assign, - inBuff, inTadShapeInfo, nullptr/*input specialBuffer*/, nullptr/*input specialShapeInfo*/, - outBuff, outTadShapeInfo, nullptr/*output specialBuffer*/, nullptr/*output specialShapeInfo*/, - nullptr, nullptr, nullptr, false/*allowParallelism*/); - - } - }; - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - } - - } + } else { + // we only allow scalar/vector case here + if (numOfIntArgs == 2) { // scalar case + + output->assign((*input)(intArgs[1], {axis})); + } else { // vector case + + const Nd4jLong numOfSubArrs = intArgs.size() - 1; + + std::vector dims = + ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); + + auto inTadPack = ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dims); + auto outTadPack = ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dims); + + auto inTadShapeInfo = inTadPack.primaryShapeInfo(); + auto outTadShapeInfo = outTadPack.primaryShapeInfo(); + + if (shape::order(inTadShapeInfo) == shape::order(outTadShapeInfo) && + shape::order(inTadShapeInfo) == 'c' && + input->dataType() == output->dataType() && + shape::elementWiseStride(inTadShapeInfo) == 1 && + shape::elementWiseStride(outTadShapeInfo) == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inBuff = input->bufferWithOffset( + inTadPack.primaryOffsets()[intArgs[i + 1]]); + void* outBuff = + output->bufferWithOffset(outTadPack.primaryOffsets()[i]); + + std::memcpy(outBuff, inBuff, + shape::length(inTadShapeInfo) * input->sizeOfT()); + } + }; + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); + + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inBuff = input->bufferWithOffset( + inTadPack.primaryOffsets()[intArgs[i + 1]]); + auto outBuff = + output->bufferWithOffset(outTadPack.primaryOffsets()[i]); + + NativeOpExecutioner::execTransformAny( + input->getContext(), transform::Assign, inBuff, inTadShapeInfo, + nullptr /*input specialBuffer*/, + nullptr /*input specialShapeInfo*/, outBuff, outTadShapeInfo, + nullptr /*output specialBuffer*/, + nullptr /*output specialShapeInfo*/, nullptr, nullptr, nullptr, + false /*allowParallelism*/); + } + }; + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); + } } + } } - -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gatherTransforms.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gatherTransforms.cpp index db62c4b4f42c..b858ee9d7b37 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/gatherTransforms.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/gatherTransforms.cpp @@ -18,166 +18,174 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - - -#include +#include #include +#include + #include -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - //////////////////////////////////////////////////////////////////////// -template +template static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) { + const X* x = reinterpret_cast(input.buffer()); + const Y* y = reinterpret_cast(indices.buffer()); + X* z = reinterpret_cast(output.buffer()); - const X* x = reinterpret_cast(input.buffer()); - const Y* y = reinterpret_cast(indices.buffer()); - X* z = reinterpret_cast(output.buffer()); + const int xRank = input.rankOf(); + const int yRank = indices.rankOf(); + const int zRank = output.rankOf(); + const int maxRank = + sd::math::nd4j_max(yRank, sd::math::nd4j_max(xRank, zRank)); - const int xRank = input.rankOf(); - const int yRank = indices.rankOf(); - const int zRank = output.rankOf(); - const int maxRank = sd::math::nd4j_max(yRank, sd::math::nd4j_max(xRank, zRank)); + const Nd4jLong zLen = output.lengthOf(); - const Nd4jLong zLen = output.lengthOf(); + const uint yLastDim = indices.sizeAt(-1); - const uint yLastDim = indices.sizeAt(-1); + const int diff = zRank - xRank; + const bool bEqual = yLastDim == xRank; - const int diff = zRank - xRank; - const bool bEqual = yLastDim == xRank; + auto func = PRAGMA_THREADS_FOR { + int xCoords[MAX_RANK], zCoords[MAX_RANK], temp; - auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); - int xCoords[MAX_RANK], zCoords[MAX_RANK], temp; + const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); - for (auto i = start; i < stop; i++) { - - shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); - - const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); - - temp = zCoords[yRank - 1]; - zCoords[yRank - 1] = 0; - const auto yOffset = shape::getOffset(indices.shapeInfo(), zCoords); - zCoords[yRank - 1] = temp; + temp = zCoords[yRank - 1]; + zCoords[yRank - 1] = 0; + const auto yOffset = shape::getOffset(indices.shapeInfo(), zCoords); + zCoords[yRank - 1] = temp; - if(bEqual) - memcpy(xCoords, zCoords, zRank * sizeof(int)); - else if(diff >= 0) - memcpy(xCoords, zCoords + diff, xRank * sizeof(int)); - else - memcpy(xCoords - diff, zCoords, zRank * sizeof(int)); + if (bEqual) + memcpy(xCoords, zCoords, zRank * sizeof(int)); + else if (diff >= 0) + memcpy(xCoords, zCoords + diff, xRank * sizeof(int)); + else + memcpy(xCoords - diff, zCoords, zRank * sizeof(int)); - for (uint j = 0; j < yLastDim; ++j) - xCoords[j] = y[yOffset + j * indices.stridesOf()[yRank - 1]]; // last stride + for (uint j = 0; j < yLastDim; ++j) + xCoords[j] = + y[yOffset + j * indices.stridesOf()[yRank - 1]]; // last stride - const auto xOffset = shape::getOffset(input.shapeInfo(), xCoords); + const auto xOffset = shape::getOffset(input.shapeInfo(), xCoords); - z[zOffset] = x[xOffset]; - } - }; + z[zOffset] = x[xOffset]; + } + }; - samediff::Threads::parallel_tad(func, 0, zLen); + samediff::Threads::parallel_tad(func, 0, zLen); } //////////////////////////////////////////////////////////////////////// -void gatherND(sd::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) { - BUILD_DOUBLE_SELECTOR(input.dataType(), indices.dataType(), gatherND_, (input, indices, output), LIBND4J_TYPES, INDEXING_TYPES); +void gatherND(sd::LaunchContext* context, NDArray& input, NDArray& indices, + NDArray& output) { + BUILD_DOUBLE_SELECTOR(input.dataType(), indices.dataType(), gatherND_, + (input, indices, output), LIBND4J_TYPES, + INDEXING_TYPES); } - //////////////////////////////////////////////////////////////////////// -template -static void gather_(NDArray* input, const NDArray* indices, NDArray* output, const std::vector& intArgs) { - - int axis = intArgs.size() > 0 ? intArgs[0] : 0; - const int inputRank = input->rankOf(); - if(axis < 0) - axis += inputRank; - - const int numOfIntArgs = intArgs.size(); - - if (indices != nullptr) { - - for(Nd4jLong i = 0; i < indices->lengthOf(); ++i) - if(indices->e(i) >= input->sizeAt(axis)) - throw std::runtime_error("helpers::gather function: indices array contains wrong elements, each element must be smaller than corresponding dimension of input array !"); - - // first case: indices consist of only one scalar - if(indices->isScalar()) { - if(input->rankOf() <= 1){ - //For scalar indices, rank 0 or 1 input: can't do tensor along dimension 0 as this is whole array... instead, we want to get a scalar - auto idx = indices->e(0); - auto scalarNDArray = input->e(idx); - output->assign(scalarNDArray); - } else { - auto dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - - auto tadArr = NDArray(reinterpret_cast(reinterpret_cast(input->buffer()) + tadPack.primaryOffsets()[indices->e(0)]), tadPack.primaryShapeInfo(), output->getContext()); - output->assign(&tadArr); - } - } - else if (input->rankOf() == 1 && indices->isVector()) { - // special case - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - output->p(e, input->e(indices->e(e))); - }; - - samediff::Threads::parallel_for(func, 0, indices->lengthOf()); +template +static void gather_(NDArray* input, const NDArray* indices, NDArray* output, + const std::vector& intArgs) { + int axis = intArgs.size() > 0 ? intArgs[0] : 0; + const int inputRank = input->rankOf(); + if (axis < 0) axis += inputRank; + + const int numOfIntArgs = intArgs.size(); + + if (indices != nullptr) { + for (Nd4jLong i = 0; i < indices->lengthOf(); ++i) + if (indices->e(i) >= input->sizeAt(axis)) + throw std::runtime_error( + "helpers::gather function: indices array contains wrong elements, " + "each element must be smaller than corresponding dimension of " + "input array !"); + + // first case: indices consist of only one scalar + if (indices->isScalar()) { + if (input->rankOf() <= 1) { + // For scalar indices, rank 0 or 1 input: can't do tensor along + // dimension 0 as this is whole array... instead, we want to get a scalar + auto idx = indices->e(0); + auto scalarNDArray = input->e(idx); + output->assign(scalarNDArray); + } else { + auto dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {axis}); + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + + auto tadArr = + NDArray(reinterpret_cast( + reinterpret_cast(input->buffer()) + + tadPack.primaryOffsets()[indices->e(0)]), + tadPack.primaryShapeInfo(), output->getContext()); + output->assign(&tadArr); + } + } else if (input->rankOf() == 1 && indices->isVector()) { + // special case + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) + output->p(e, input->e(indices->e(e))); + }; + + samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } else { + std::vector dimsOut(indices->rankOf()); + std::iota(dimsOut.begin(), dimsOut.end(), + axis); // fill with axis, axis+1, ... indices->rankOf()-1 + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(output->shapeInfo(), dimsOut); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + NDArray subArrOut = (*output)(i, dimsOut); + NDArray subArrIn = (*input)(indices->e(i), {axis}); + subArrOut.assign(subArrIn); } - else { - - std::vector dimsOut(indices->rankOf()); - std::iota(dimsOut.begin(), dimsOut.end(), axis); // fill with axis, axis+1, ... indices->rankOf()-1 - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(output->shapeInfo(), dimsOut); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - NDArray subArrOut = (*output)(i, dimsOut); - NDArray subArrIn = (*input)(indices->e(i), {axis}); - subArrOut.assign(subArrIn); - } - }; + }; - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - } + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); } - else { - - for(int i = 1; i < numOfIntArgs; ++i) - if(intArgs[i] >= input->sizeAt(axis)) - throw std::runtime_error("helpers::gather function: some of input indexes is larger than corresponding shape of input array !"); - - // we only allow scalar/vector case here - if (numOfIntArgs == 2) { // scalar case - output->assign((*input)(intArgs[1], {axis})); - } - else { // vector case - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(output->shapeInfo(), {axis}); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - NDArray subArrOut = (*output)(i, {axis}); - NDArray subArrIn = (*input)(intArgs[i + 1], {axis}); - subArrOut.assign(subArrIn); - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); + } else { + for (int i = 1; i < numOfIntArgs; ++i) + if (intArgs[i] >= input->sizeAt(axis)) + throw std::runtime_error( + "helpers::gather function: some of input indexes is larger than " + "corresponding shape of input array !"); + + // we only allow scalar/vector case here + if (numOfIntArgs == 2) { // scalar case + output->assign((*input)(intArgs[1], {axis})); + } else { // vector case + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(output->shapeInfo(), {axis}); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + NDArray subArrOut = (*output)(i, {axis}); + NDArray subArrIn = (*input)(intArgs[i + 1], {axis}); + subArrOut.assign(subArrIn); } - } -} + }; - void gather(NDArray* input, const NDArray* indices, NDArray* output, const std::vector& intArgs) { - BUILD_SINGLE_SELECTOR(input->dataType(), gather_, (input, indices, output, intArgs), LIBND4J_TYPES); + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); } - -} + } } + +void gather(NDArray* input, const NDArray* indices, NDArray* output, + const std::vector& intArgs) { + BUILD_SINGLE_SELECTOR(input->dataType(), gather_, + (input, indices, output, intArgs), LIBND4J_TYPES); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp index df5ee1afcf7b..e8786a3a1ac0 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp @@ -25,18 +25,22 @@ namespace sd { namespace ops { namespace helpers { template -static void applyGradientDescent_(NDArray* input, NDArray* step, double weight, NDArray* output) { - auto lambda = LAMBDA_TT(_x, _y, weight) { - return _x - (_y * weight); - }; +static void applyGradientDescent_(NDArray* input, NDArray* step, double weight, + NDArray* output) { + auto lambda = LAMBDA_TT(_x, _y, weight) { return _x - (_y * weight); }; - input->applyPairwiseLambda(*step, lambda, *output); + input->applyPairwiseLambda(*step, lambda, *output); } -void applyGradientDescent(sd::LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), applyGradientDescent_, (input, step, weight, output), FLOAT_TYPES); -} -BUILD_SINGLE_TEMPLATE(template void applyGradientDescent_, (NDArray* input, NDArray* step, double weight, NDArray* output), FLOAT_TYPES); -} -} +void applyGradientDescent(sd::LaunchContext* context, NDArray* input, + NDArray* step, double weight, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), applyGradientDescent_, + (input, step, weight, output), FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template void applyGradientDescent_, + (NDArray * input, NDArray* step, double weight, + NDArray* output), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/hamming.cpp b/libnd4j/include/ops/declarable/helpers/cpu/hamming.cpp index 10b6a27e026d..22df1684e685 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/hamming.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/hamming.cpp @@ -18,87 +18,84 @@ // @author raver119@gmail.com // -#include -#include #include +#include +#include namespace sd { - namespace ops { - namespace helpers { - - static Nd4jLong hamming_distance(unsigned long long x, unsigned long long y) { - Nd4jLong dist = 0; - - for (unsigned long long val = x ^ y; val > 0; val /= 2) { - if (val & 1) - dist++; - } - return dist; - } - - - template - static void _hamming(NDArray &x, NDArray &y, NDArray &z) { - auto xEws = x.ews(); - auto yEws = y.ews(); - - auto xBuffer = x.bufferAsT(); - auto yBuffer = y.bufferAsT(); - - Nd4jLong distance = 0; - auto lengthOf = x.lengthOf(); - int maxThreads = sd::math::nd4j_min(256, omp_get_max_threads()); - Nd4jLong intermediate[256]; - - // nullify temp values - for (int e = 0; e < maxThreads; e++) - intermediate[e] = 0; - - if (xEws == 1 && yEws == 1 && x.ordering() == y.ordering()) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto _x = static_cast(xBuffer[e]); - auto _y = static_cast(yBuffer[e]); - - intermediate[thread_id] += hamming_distance(_x, _y); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf); - } else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto _x = static_cast(xBuffer[e * xEws]); - auto _y = static_cast(yBuffer[e * yEws]); - - intermediate[thread_id] += hamming_distance(_x, _y); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf); - } else { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto _x = static_cast(x.e(e)); - auto _y = static_cast(y.e(e)); - - intermediate[thread_id] += hamming_distance(_x, _y); - } - }; - - maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf); - } - - // accumulate intermediate variables into output array - for (int e = 0; e < maxThreads; e++) - distance += intermediate[e]; - - z.p(0, distance); - } - - void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) { - BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(), _hamming, (x, y, output), INTEGER_TYPES, INDEXING_TYPES); - } - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { + +static Nd4jLong hamming_distance(unsigned long long x, unsigned long long y) { + Nd4jLong dist = 0; + + for (unsigned long long val = x ^ y; val > 0; val /= 2) { + if (val & 1) dist++; + } + return dist; +} + +template +static void _hamming(NDArray &x, NDArray &y, NDArray &z) { + auto xEws = x.ews(); + auto yEws = y.ews(); + + auto xBuffer = x.bufferAsT(); + auto yBuffer = y.bufferAsT(); + + Nd4jLong distance = 0; + auto lengthOf = x.lengthOf(); + int maxThreads = sd::math::nd4j_min(256, omp_get_max_threads()); + Nd4jLong intermediate[256]; + + // nullify temp values + for (int e = 0; e < maxThreads; e++) intermediate[e] = 0; + + if (xEws == 1 && yEws == 1 && x.ordering() == y.ordering()) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto _x = static_cast(xBuffer[e]); + auto _y = static_cast(yBuffer[e]); + + intermediate[thread_id] += hamming_distance(_x, _y); + } + }; + + maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf); + } else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto _x = static_cast(xBuffer[e * xEws]); + auto _y = static_cast(yBuffer[e * yEws]); + + intermediate[thread_id] += hamming_distance(_x, _y); + } + }; + + maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto _x = static_cast(x.e(e)); + auto _y = static_cast(y.e(e)); + + intermediate[thread_id] += hamming_distance(_x, _y); + } + }; + + maxThreads = samediff::Threads::parallel_for(func, 0, lengthOf); + } + + // accumulate intermediate variables into output array + for (int e = 0; e < maxThreads; e++) distance += intermediate[e]; + + z.p(0, distance); +} + +void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) { + BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(), _hamming, + (x, y, output), INTEGER_TYPES, INDEXING_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/hashcode.cpp b/libnd4j/include/ops/declarable/helpers/cpu/hashcode.cpp index 5893b2c88363..822d155d0523 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/hashcode.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/hashcode.cpp @@ -18,88 +18,90 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - static void hashCode_(LaunchContext *context, NDArray &array, NDArray &result) { - Nd4jLong blockSize = 32; - auto length = array.lengthOf(); - int numBlocks = length / blockSize + ((length % blockSize == 0) ? 0 : 1); - auto tempA = NDArrayFactory::create('c', {numBlocks}, context); - auto tempB = NDArrayFactory::create('c', { numBlocks / blockSize + 1}, context); - - auto buffer = array.bufferAsT(); - auto tempBufferA = tempA.bufferAsT(); - auto tempBufferB = tempB.bufferAsT(); - - // default buffer is the first one, because it might be the last one in case of small arrays (< blockSize) - auto tempBuffer = tempBufferA; - auto tempResult = tempBufferB; - - // we divide array into 32 element chunks, and store intermediate results once - auto func = PRAGMA_THREADS_FOR { - for (auto b = start; b < stop; b++) { - auto blockBuffer = buffer + b * numBlocks; - - Nd4jLong r = 1; - for (Nd4jLong e = 0; e < blockSize && e + (b * numBlocks) < length; e++) { - auto v = longBytes(blockBuffer[e]); - r = 31 * r + v; - } - - tempBuffer[b] = r; - } - }; - samediff::Threads::parallel_tad(func, 0, numBlocks); - - // we replace pointer with intermediate one, and repeat only one chunk left - int iterationCount = 0; - while (numBlocks > 1) { - int lastLength = numBlocks; - numBlocks = lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1); - - - auto func2 = PRAGMA_THREADS_FOR { - for (auto b = start; b < stop; b++) { - auto blockBuffer = tempBuffer + b * numBlocks; - - Nd4jLong r = 1; - for (Nd4jLong e = 0; e < blockSize && e + (b * numBlocks) < lastLength; e++) { - auto v = longBytes(blockBuffer[e]); - r = 31 * r + v; - } - - tempResult[b] = r; - } - }; - samediff::Threads::parallel_tad(func2, 0, numBlocks); - - - iterationCount++; - // swapping buffers - if (iterationCount % 2 == 0) { - tempBuffer = tempBufferA; - tempResult = tempBufferB; - } else { - tempBuffer = tempBufferB; - tempResult = tempBufferA; - } - } +namespace ops { +namespace helpers { +template +static void hashCode_(LaunchContext *context, NDArray &array, NDArray &result) { + Nd4jLong blockSize = 32; + auto length = array.lengthOf(); + int numBlocks = length / blockSize + ((length % blockSize == 0) ? 0 : 1); + auto tempA = NDArrayFactory::create('c', {numBlocks}, context); + auto tempB = NDArrayFactory::create( + 'c', {numBlocks / blockSize + 1}, context); + + auto buffer = array.bufferAsT(); + auto tempBufferA = tempA.bufferAsT(); + auto tempBufferB = tempB.bufferAsT(); + + // default buffer is the first one, because it might be the last one in case + // of small arrays (< blockSize) + auto tempBuffer = tempBufferA; + auto tempResult = tempBufferB; + + // we divide array into 32 element chunks, and store intermediate results once + auto func = PRAGMA_THREADS_FOR { + for (auto b = start; b < stop; b++) { + auto blockBuffer = buffer + b * numBlocks; + + Nd4jLong r = 1; + for (Nd4jLong e = 0; e < blockSize && e + (b * numBlocks) < length; e++) { + auto v = longBytes(blockBuffer[e]); + r = 31 * r + v; + } + + tempBuffer[b] = r; + } + }; + samediff::Threads::parallel_tad(func, 0, numBlocks); + + // we replace pointer with intermediate one, and repeat only one chunk left + int iterationCount = 0; + while (numBlocks > 1) { + int lastLength = numBlocks; + numBlocks = + lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1); + + auto func2 = PRAGMA_THREADS_FOR { + for (auto b = start; b < stop; b++) { + auto blockBuffer = tempBuffer + b * numBlocks; + + Nd4jLong r = 1; + for (Nd4jLong e = 0; e < blockSize && e + (b * numBlocks) < lastLength; + e++) { + auto v = longBytes(blockBuffer[e]); + r = 31 * r + v; + } - if (length <= blockSize) - result.p(0, tempBufferA[0]); - else - result.p(0, tempResult[0]); - } + tempResult[b] = r; + } + }; + samediff::Threads::parallel_tad(func2, 0, numBlocks); + + iterationCount++; + // swapping buffers + if (iterationCount % 2 == 0) { + tempBuffer = tempBufferA; + tempResult = tempBufferB; + } else { + tempBuffer = tempBufferB; + tempResult = tempBufferA; + } + } + if (length <= blockSize) + result.p(0, tempBufferA[0]); + else + result.p(0, tempResult[0]); +} - void hashCode(LaunchContext *context, NDArray &array, NDArray &result) { - BUILD_SINGLE_SELECTOR(array.dataType(), hashCode_, (context, array, result), LIBND4J_TYPES); - } - } - } +void hashCode(LaunchContext *context, NDArray &array, NDArray &result) { + BUILD_SINGLE_SELECTOR(array.dataType(), hashCode_, (context, array, result), + LIBND4J_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp b/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp index cb815110d415..9dafacb11c55 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp @@ -21,50 +21,55 @@ #include namespace sd { - namespace ops { - namespace helpers { - template - static void histogram_(void const* xBuffer, Nd4jLong const* xShapeInfo, void *zBuffer, Nd4jLong const* zShapeInfo, Nd4jLong numBins, double min_val, double max_val) { - auto dx = reinterpret_cast(xBuffer); - auto result = reinterpret_cast(zBuffer); +namespace ops { +namespace helpers { +template +static void histogram_(void const *xBuffer, Nd4jLong const *xShapeInfo, + void *zBuffer, Nd4jLong const *zShapeInfo, + Nd4jLong numBins, double min_val, double max_val) { + auto dx = reinterpret_cast(xBuffer); + auto result = reinterpret_cast(zBuffer); - int length = shape::length(xShapeInfo); + int length = shape::length(xShapeInfo); - X binSize = (max_val - min_val) / (numBins); + X binSize = (max_val - min_val) / (numBins); - // FIXME: this op should be parallelized - { - int *bins = new int[numBins]; - std::memset(bins, 0, sizeof(int) * numBins); + // FIXME: this op should be parallelized + { + int *bins = new int[numBins]; + std::memset(bins, 0, sizeof(int) * numBins); - PRAGMA_OMP_SIMD - for (int x = 0; x < length; x++) { - int idx = (int) ((dx[x] - min_val) / binSize); - if (idx < 0) - idx = 0; - else if (idx >= numBins) - idx = numBins - 1; + PRAGMA_OMP_SIMD + for (int x = 0; x < length; x++) { + int idx = (int)((dx[x] - min_val) / binSize); + if (idx < 0) + idx = 0; + else if (idx >= numBins) + idx = numBins - 1; - bins[idx]++; - } - - PRAGMA_OMP_SIMD - for (Nd4jLong x = 0; x < numBins; x++) { - result[x] += bins[x]; - } + bins[idx]++; + } + PRAGMA_OMP_SIMD + for (Nd4jLong x = 0; x < numBins; x++) { + result[x] += bins[x]; + } - delete[] bins; - } - } + delete[] bins; + } +} - void histogramHelper(sd::LaunchContext *context, NDArray &input, NDArray &output) { - Nd4jLong numBins = output.lengthOf(); - double min_val = input.reduceNumber(reduce::SameOps::Min).e(0); - double max_val = input.reduceNumber(reduce::SameOps::Max).e(0); +void histogramHelper(sd::LaunchContext *context, NDArray &input, + NDArray &output) { + Nd4jLong numBins = output.lengthOf(); + double min_val = input.reduceNumber(reduce::SameOps::Min).e(0); + double max_val = input.reduceNumber(reduce::SameOps::Max).e(0); - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (input.buffer(), input.shapeInfo(), output.buffer(), output.shapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INDEXING_TYPES); - } - } - } -} \ No newline at end of file + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, + (input.buffer(), input.shapeInfo(), output.buffer(), + output.shapeInfo(), numBins, min_val, max_val), + LIBND4J_TYPES, INDEXING_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp b/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp index 9376e80bf8a4..a150c5652efb 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp @@ -24,45 +24,48 @@ namespace sd { namespace ops { namespace helpers { - template -void histogramFixedWidth_(const NDArray& input, const NDArray& range, NDArray& output) { - - const int nbins = output.lengthOf(); +void histogramFixedWidth_(const NDArray& input, const NDArray& range, + NDArray& output) { + const int nbins = output.lengthOf(); - // firstly initialize output with zeros - output.nullify(); + // firstly initialize output with zeros + output.nullify(); - const T leftEdge = range.e(0); - const T rightEdge = range.e(1); + const T leftEdge = range.e(0); + const T rightEdge = range.e(1); - const T binWidth = (rightEdge - leftEdge ) / nbins; - const T secondEdge = leftEdge + binWidth; - const T lastButOneEdge = rightEdge - binWidth; + const T binWidth = (rightEdge - leftEdge) / nbins; + const T secondEdge = leftEdge + binWidth; + const T lastButOneEdge = rightEdge - binWidth; - Nd4jLong inputLength = input.lengthOf(); + Nd4jLong inputLength = input.lengthOf(); - // FIXME: make this one parallel without CRITICAL section - for(Nd4jLong i = 0; i < inputLength; ++i) { - const T value = input.e(i); + // FIXME: make this one parallel without CRITICAL section + for (Nd4jLong i = 0; i < inputLength; ++i) { + const T value = input.e(i); - if(value < secondEdge) { - output.p(0, output.e(0) + 1); - } else if(value >= lastButOneEdge) { - output.p(nbins - 1, output.e(nbins - 1) + 1); - } else { - Nd4jLong currInd = static_cast((value - leftEdge) / binWidth); - output.p(currInd, output.e(currInd) + 1); - } + if (value < secondEdge) { + output.p(0, output.e(0) + 1); + } else if (value >= lastButOneEdge) { + output.p(nbins - 1, output.e(nbins - 1) + 1); + } else { + Nd4jLong currInd = static_cast((value - leftEdge) / binWidth); + output.p(currInd, output.e(currInd) + 1); } + } } -void histogramFixedWidth(sd::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) { - BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidth_, (input, range, output), LIBND4J_TYPES); +void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, + const NDArray& range, NDArray& output) { + BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidth_, + (input, range, output), LIBND4J_TYPES); } -BUILD_SINGLE_TEMPLATE(template void histogramFixedWidth_, (const NDArray& input, const NDArray& range, NDArray& output), LIBND4J_TYPES); - +BUILD_SINGLE_TEMPLATE(template void histogramFixedWidth_, + (const NDArray& input, const NDArray& range, + NDArray& output), + LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp b/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp index 2434fddcc658..60bb27dd6b56 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp @@ -18,123 +18,135 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 19.09.2018 // -#include #include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// template -static void im2col_(sd::LaunchContext & context, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal) { - - // input [bS, iC, iH, iW] is convoluted to output [bS, iC, kH, kW, oH, oW] - - auto imBuff = static_cast(input.buffer()); - auto colBuff = static_cast(output.buffer()); - auto imShapeBuffer = input.shapeInfo(); - auto colShapeBuffer = output.shapeInfo(); - auto colShape = shape::shapeOf(colShapeBuffer); - auto colStride = shape::stride(colShapeBuffer); - auto imShape = shape::shapeOf(imShapeBuffer); - auto imStride = shape::stride(imShapeBuffer); - - const T zeroPadVal = arrZeroPadVal.e(0); - - const int bS = imShape[0]; - const int iC = imShape[1]; - const int iH = imShape[2]; - const int iW = imShape[3]; - const int oH = colShape[4]; - const int oW = colShape[5]; - const Nd4jLong colStride0 = colStride[0]; - const Nd4jLong colStride1 = colStride[1]; - const Nd4jLong colStride2 = colStride[2]; - const Nd4jLong colStride3 = colStride[3]; - const Nd4jLong colStride4 = colStride[4]; - const Nd4jLong colStride5 = colStride[5]; - const Nd4jLong imStride0 = imStride[0]; - const Nd4jLong imStride1 = imStride[1]; - const Nd4jLong imStride2 = imStride[2]; - const Nd4jLong imStride3 = imStride[3]; - - - if (shape::order(imShapeBuffer) == 'c' && shape::order(colShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(imShapeBuffer) && shape::strideDescendingCAscendingF(colShapeBuffer)) { - - auto func = PRAGMA_THREADS_FOR_2D { - for (auto b = start_x; b < stop_x; b++) { - for (auto c = start_y; c < stop_y; c++) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - - int imRow = (-pH + kRow * dH) + colH * sH; - int imCol = (-pW + kCol * dW) + colW * sW; - - auto col = colBuff + b * colStride0 + c * colStride1 + kRow * colStride2 + kCol * colStride3 + colH * colStride4 + colW * colStride5; - - if (static_cast(imRow) >= static_cast(iH) || static_cast(imCol) >= static_cast(iW)) - *col = zeroPadVal; - else { - auto im = imBuff + b * imStride0 + c * imStride1 + imRow * imStride2 + imCol * imStride3; - *col = *im; - } - } - } - } - } +static void im2col_(sd::LaunchContext& context, const NDArray& input, + NDArray& output, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, + const int dW, const NDArray& arrZeroPadVal) { + // input [bS, iC, iH, iW] is convoluted to output [bS, iC, kH, kW, oH, oW] + + auto imBuff = static_cast(input.buffer()); + auto colBuff = static_cast(output.buffer()); + auto imShapeBuffer = input.shapeInfo(); + auto colShapeBuffer = output.shapeInfo(); + auto colShape = shape::shapeOf(colShapeBuffer); + auto colStride = shape::stride(colShapeBuffer); + auto imShape = shape::shapeOf(imShapeBuffer); + auto imStride = shape::stride(imShapeBuffer); + + const T zeroPadVal = arrZeroPadVal.e(0); + + const int bS = imShape[0]; + const int iC = imShape[1]; + const int iH = imShape[2]; + const int iW = imShape[3]; + const int oH = colShape[4]; + const int oW = colShape[5]; + const Nd4jLong colStride0 = colStride[0]; + const Nd4jLong colStride1 = colStride[1]; + const Nd4jLong colStride2 = colStride[2]; + const Nd4jLong colStride3 = colStride[3]; + const Nd4jLong colStride4 = colStride[4]; + const Nd4jLong colStride5 = colStride[5]; + const Nd4jLong imStride0 = imStride[0]; + const Nd4jLong imStride1 = imStride[1]; + const Nd4jLong imStride2 = imStride[2]; + const Nd4jLong imStride3 = imStride[3]; + + if (shape::order(imShapeBuffer) == 'c' && + shape::order(colShapeBuffer) == 'c' && + shape::strideDescendingCAscendingF(imShapeBuffer) && + shape::strideDescendingCAscendingF(colShapeBuffer)) { + auto func = PRAGMA_THREADS_FOR_2D { + for (auto b = start_x; b < stop_x; b++) { + for (auto c = start_y; c < stop_y; c++) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + int imRow = (-pH + kRow * dH) + colH * sH; + int imCol = (-pW + kCol * dW) + colW * sW; + + auto col = colBuff + b * colStride0 + c * colStride1 + + kRow * colStride2 + kCol * colStride3 + + colH * colStride4 + colW * colStride5; + + if (static_cast(imRow) >= + static_cast(iH) || + static_cast(imCol) >= static_cast(iW)) + *col = zeroPadVal; + else { + auto im = imBuff + b * imStride0 + c * imStride1 + + imRow * imStride2 + imCol * imStride3; + *col = *im; + } } + } } - }; - - samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); - } - else { - - auto func = PRAGMA_THREADS_FOR_2D { - T *col; - T const* im; - int imRow, imCol; - - for (auto b = start_x; b < stop_x; b += inc_x) { - for (auto colH = start_y; colH < stop_y; colH += inc_y) { - for (int colW = 0; colW < oW; ++colW) { - for (int c = 0; c < iC; ++c) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - - imRow = (-pH + kRow * dH) + colH * sH; - imCol = (-pW + kCol * dW) + colW * sW; - - col = colBuff + b * colStride0 + c * colStride1 + kRow * colStride2 + kCol * colStride3 + colH * colStride4 + colW * colStride5; - - if (static_cast(imRow) >= static_cast(iH) || static_cast(imCol) >= static_cast(iW)) - *col = zeroPadVal; - else { - im = imBuff + b * imStride0 + c * imStride1 + imRow * imStride2 + imCol * imStride3; - *col = *im; - } - } - } - } - } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0, bS, 1, 0, iC, 1); + } else { + auto func = PRAGMA_THREADS_FOR_2D { + T* col; + T const* im; + int imRow, imCol; + + for (auto b = start_x; b < stop_x; b += inc_x) { + for (auto colH = start_y; colH < stop_y; colH += inc_y) { + for (int colW = 0; colW < oW; ++colW) { + for (int c = 0; c < iC; ++c) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + imRow = (-pH + kRow * dH) + colH * sH; + imCol = (-pW + kCol * dW) + colW * sW; + + col = colBuff + b * colStride0 + c * colStride1 + + kRow * colStride2 + kCol * colStride3 + + colH * colStride4 + colW * colStride5; + + if (static_cast(imRow) >= + static_cast(iH) || + static_cast(imCol) >= static_cast(iW)) + *col = zeroPadVal; + else { + im = imBuff + b * imStride0 + c * imStride1 + + imRow * imStride2 + imCol * imStride3; + *col = *im; + } } + } } - }; + } + } + } + }; - samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1); - } + samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1); + } } - -void im2col(sd::LaunchContext & context, const NDArray& im, NDArray& col, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal) { - BUILD_SINGLE_SELECTOR(im.dataType(), im2col_, (context, im, col, kH, kW, sH, sW, pH, pW, dH, dW, arrZeroPadVal), FLOAT_TYPES); +void im2col(sd::LaunchContext& context, const NDArray& im, NDArray& col, + const int kH, const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, const int dW, + const NDArray& arrZeroPadVal) { + BUILD_SINGLE_SELECTOR( + im.dataType(), im2col_, + (context, im, col, kH, kW, sH, sW, pH, pW, dH, dW, arrZeroPadVal), + FLOAT_TYPES); } - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp index ee4faafb024f..d2de40cc7dc1 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_draw_bounding_boxes.cpp @@ -31,128 +31,139 @@ limitations under the License. // // @author sgazeos@gmail.com // -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - typedef std::vector> ColorTable_t; - static ColorTable_t DefaultColorTable(int depth) { - std::vector> colorTable; - colorTable.emplace_back(std::vector({1, 1, 0, 1})); // 0: yellow - colorTable.emplace_back(std::vector({0, 0, 1, 1})); // 1: blue - colorTable.emplace_back(std::vector({1, 0, 0, 1})); // 2: red - colorTable.emplace_back(std::vector({0, 1, 0, 1})); // 3: lime - colorTable.emplace_back(std::vector({0.5, 0, 0.5, 1})); // 4: purple - colorTable.emplace_back(std::vector({0.5, 0.5, 0, 1})); // 5: olive - colorTable.emplace_back(std::vector({0.5, 0, 0, 1})); // 6: maroon - colorTable.emplace_back(std::vector({0, 0, 0.5, 1})); // 7: navy blue - colorTable.emplace_back(std::vector({0, 1, 1, 1})); // 8: aqua - colorTable.emplace_back(std::vector({1, 0, 1, 1})); // 9: fuchsia +typedef std::vector> ColorTable_t; +static ColorTable_t DefaultColorTable(int depth) { + std::vector> colorTable; + colorTable.emplace_back(std::vector({1, 1, 0, 1})); // 0: yellow + colorTable.emplace_back(std::vector({0, 0, 1, 1})); // 1: blue + colorTable.emplace_back(std::vector({1, 0, 0, 1})); // 2: red + colorTable.emplace_back(std::vector({0, 1, 0, 1})); // 3: lime + colorTable.emplace_back(std::vector({0.5, 0, 0.5, 1})); // 4: purple + colorTable.emplace_back(std::vector({0.5, 0.5, 0, 1})); // 5: olive + colorTable.emplace_back(std::vector({0.5, 0, 0, 1})); // 6: maroon + colorTable.emplace_back(std::vector({0, 0, 0.5, 1})); // 7: navy blue + colorTable.emplace_back(std::vector({0, 1, 1, 1})); // 8: aqua + colorTable.emplace_back(std::vector({1, 0, 1, 1})); // 9: fuchsia - if (depth == 1) { - for (Nd4jLong i = 0; i < colorTable.size(); i++) { - colorTable[i][0] = 1; - } - } - return colorTable; + if (depth == 1) { + for (Nd4jLong i = 0; i < colorTable.size(); i++) { + colorTable[i][0] = 1; } + } + return colorTable; +} - void drawBoundingBoxesFunctor(sd::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output) { - // images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or RGBA (last dim = 4) channel set - // boxes - batch of 2D bounds with last dim (y_start, x_start, y_end, x_end) to compute i and j as - // floor((height - 1 ) * y_start) => rowStart, floor((height - 1) * y_end) => rowEnd - // floor((width - 1 ) * x_start) => colStart, floor((width - 1) * x_end) => colEnd - // height = images->sizeAt(1), width = images->sizeAt(2) - // colors - colors for each box given - // set up color for each box as frame - auto batchSize = images->sizeAt(0); - auto boxSize = boxes->sizeAt(0); - auto height = images->sizeAt(1); - auto width = images->sizeAt(2); - auto channels = images->sizeAt(3); - //auto imageList = images->allTensorsAlongDimension({1, 2, 3}); // split images by batch -// auto boxList = boxes->allTensorsAlongDimension({1, 2}); // split boxes by batch - //auto colorSet = colors->allTensorsAlongDimension({0}); - output->assign(images); // fill up all output with input images, then fill up boxes - ColorTable_t colorTable; - if (colors) { - for (auto i = 0; i < colors->sizeAt(0); i++) { - std::vector colorValue(4); - for (auto j = 0; j < 4; j++) { - colorValue[j] = j < colors->sizeAt(1) ? colors->e(i, j) : 1.f; - } - colorTable.emplace_back(colorValue); - } - } - if (colorTable.empty()) - colorTable = DefaultColorTable(channels); - auto func = PRAGMA_THREADS_FOR { - for (auto batch = start; batch < stop; ++batch) { // loop by batch - const Nd4jLong numBoxes = boxes->sizeAt(1); - for (auto boxIndex = 0; boxIndex < numBoxes; ++boxIndex) { - auto colorIndex = boxIndex % colorTable.size(); - auto rowStart = Nd4jLong((height - 1) * boxes->t(batch, boxIndex, 0)); - auto rowStartBound = sd::math::nd4j_max(Nd4jLong(0), rowStart); - auto rowEnd = Nd4jLong((height - 1) * boxes->t(batch, boxIndex, 2)); - auto rowEndBound = sd::math::nd4j_min(Nd4jLong(height - 1), rowEnd); - auto colStart = Nd4jLong((width - 1) * boxes->t(batch, boxIndex, 1)); - auto colStartBound = sd::math::nd4j_max(Nd4jLong(0), colStart); - auto colEnd = Nd4jLong((width - 1) * boxes->t(batch, boxIndex, 3)); - auto colEndBound = sd::math::nd4j_min(Nd4jLong(width - 1), colEnd); - - if (rowStart > rowEnd || colStart > colEnd) { - nd4j_debug( - "helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is inverted " - "and will not be drawn\n", rowStart, colStart, rowEnd, colEnd); - continue; - } - if (rowStart >= height || rowEnd < 0 || colStart >= width || - colEnd < 0) { - nd4j_debug( - "helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is completely " - "outside the image and not be drawn\n ", rowStart, colStart, rowEnd, colEnd); - continue; - } +void drawBoundingBoxesFunctor(sd::LaunchContext* context, NDArray* images, + NDArray* boxes, NDArray* colors, + NDArray* output) { + // images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or + // RGBA (last dim = 4) channel set boxes - batch of 2D bounds with last dim + // (y_start, x_start, y_end, x_end) to compute i and j as floor((height - 1 ) + // * y_start) => rowStart, floor((height - 1) * y_end) => rowEnd floor((width + // - 1 ) * x_start) => colStart, floor((width - 1) * x_end) => colEnd height = + // images->sizeAt(1), width = images->sizeAt(2) colors - colors for each box + // given set up color for each box as frame + auto batchSize = images->sizeAt(0); + auto boxSize = boxes->sizeAt(0); + auto height = images->sizeAt(1); + auto width = images->sizeAt(2); + auto channels = images->sizeAt(3); + // auto imageList = images->allTensorsAlongDimension({1, 2, 3}); // split + // images by batch + // auto boxList = boxes->allTensorsAlongDimension({1, 2}); // split + // boxes by batch + // auto colorSet = colors->allTensorsAlongDimension({0}); + output->assign( + images); // fill up all output with input images, then fill up boxes + ColorTable_t colorTable; + if (colors) { + for (auto i = 0; i < colors->sizeAt(0); i++) { + std::vector colorValue(4); + for (auto j = 0; j < 4; j++) { + colorValue[j] = j < colors->sizeAt(1) ? colors->e(i, j) : 1.f; + } + colorTable.emplace_back(colorValue); + } + } + if (colorTable.empty()) colorTable = DefaultColorTable(channels); + auto func = PRAGMA_THREADS_FOR { + for (auto batch = start; batch < stop; ++batch) { // loop by batch + const Nd4jLong numBoxes = boxes->sizeAt(1); + for (auto boxIndex = 0; boxIndex < numBoxes; ++boxIndex) { + auto colorIndex = boxIndex % colorTable.size(); + auto rowStart = + Nd4jLong((height - 1) * boxes->t(batch, boxIndex, 0)); + auto rowStartBound = sd::math::nd4j_max(Nd4jLong(0), rowStart); + auto rowEnd = + Nd4jLong((height - 1) * boxes->t(batch, boxIndex, 2)); + auto rowEndBound = sd::math::nd4j_min(Nd4jLong(height - 1), rowEnd); + auto colStart = + Nd4jLong((width - 1) * boxes->t(batch, boxIndex, 1)); + auto colStartBound = sd::math::nd4j_max(Nd4jLong(0), colStart); + auto colEnd = + Nd4jLong((width - 1) * boxes->t(batch, boxIndex, 3)); + auto colEndBound = sd::math::nd4j_min(Nd4jLong(width - 1), colEnd); - // Draw upper line - if (rowStart >= 0) { - for (auto j = colStartBound; j <= colEndBound; ++j) - for (auto c = 0; c < channels; c++) { - output->p(batch, rowStart, j, c, colorTable[colorIndex][c]); - } - } - // Draw bottom line. - if (rowEnd < height) { - for (auto j = colStartBound; j <= colEndBound; ++j) - for (auto c = 0; c < channels; c++) { - output->p(batch, rowEnd, j, c, colorTable[colorIndex][c]); - } - } + if (rowStart > rowEnd || colStart > colEnd) { + nd4j_debug( + "helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, " + "%lld, %lld) is inverted " + "and will not be drawn\n", + rowStart, colStart, rowEnd, colEnd); + continue; + } + if (rowStart >= height || rowEnd < 0 || colStart >= width || + colEnd < 0) { + nd4j_debug( + "helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, " + "%lld, %lld) is completely " + "outside the image and not be drawn\n ", + rowStart, colStart, rowEnd, colEnd); + continue; + } - // Draw left line. - if (colStart >= 0) { - for (auto i = rowStartBound; i <= rowEndBound; ++i) - for (auto c = 0; c < channels; c++) { - output->p(batch, i, colStart, c, colorTable[colorIndex][c]); - } - } - // Draw right line. - if (colEnd < width) { - for (auto i = rowStartBound; i <= rowEndBound; ++i) - for (auto c = 0; c < channels; c++) { - output->p(batch, i, colEnd, c, colorTable[colorIndex][c]); - } - } - } + // Draw upper line + if (rowStart >= 0) { + for (auto j = colStartBound; j <= colEndBound; ++j) + for (auto c = 0; c < channels; c++) { + output->p(batch, rowStart, j, c, colorTable[colorIndex][c]); } - }; - samediff::Threads::parallel_tad(func, 0, batchSize); + } + // Draw bottom line. + if (rowEnd < height) { + for (auto j = colStartBound; j <= colEndBound; ++j) + for (auto c = 0; c < channels; c++) { + output->p(batch, rowEnd, j, c, colorTable[colorIndex][c]); + } + } + // Draw left line. + if (colStart >= 0) { + for (auto i = rowStartBound; i <= rowEndBound; ++i) + for (auto c = 0; c < channels; c++) { + output->p(batch, i, colStart, c, colorTable[colorIndex][c]); + } + } + // Draw right line. + if (colEnd < width) { + for (auto i = rowStartBound; i <= rowEndBound; ++i) + for (auto c = 0; c < channels; c++) { + output->p(batch, i, colEnd, c, colorTable[colorIndex][c]); + } + } + } } - -} -} + }; + samediff::Threads::parallel_tad(func, 0, batchSize); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp index 2f0f007797bc..aa76b630e5fa 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp @@ -33,983 +33,1080 @@ limitations under the License. // @author sgazeos@gmail.com // -#include #include #include +#include + #include "../cross.h" namespace sd { namespace ops { namespace helpers { - struct BilinearInterpolationData { - Nd4jLong _bottomIndex; // Lower source index used in the interpolation - Nd4jLong _topIndex; // Upper source index used in the interpolation - // 1D linear iterpolation scale (see: - // https://en.wikipedia.org/wiki/Bilinear_interpolation) - double _interpolarValue; - }; - // calculateResizeScale determines the float scaling factor. - inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize, - bool alignCorners) { - return (alignCorners && outSize > 1) - ? (inSize - 1) / static_cast(outSize - 1) - : inSize / static_cast(outSize); - } - - template - struct ImageResizerStateCommon { - explicit ImageResizerStateCommon(bool alignCorners, bool halfPixelCenters) - : _alignCorners(alignCorners), - _halfPixelCenters(halfPixelCenters) {} - - // ValidateAndCalculateOutputSize checks the bounds on the input tensors - // and requested size, sets up some of the resizing state such as the - // heightScale and widthScale, and calculates the output size. - // If any of these operations fails, it sets an error status in - // the context, which the caller must check. - int validateAndCalculateOutputSize(NDArray const* input, int const width, int const height) { - // - batchSize = input->sizeAt(0);//.dim_size(0); - outHeight = height; - outWidth = width; //internal::SubtleMustCopy(Svec(1)); - inHeight = static_cast(input->sizeAt(1)); - inWidth = static_cast(input->sizeAt(2)); - channels = input->sizeAt(3); //.dim_size(3); - heightScale = calculateResizeScale(inHeight, outHeight, _alignCorners); - widthScale = calculateResizeScale(inWidth, outWidth, _alignCorners); - - // Guard against overflows - if (ceilf((outHeight - 1) * heightScale) > static_cast(DataTypeUtils::max())) { - nd4j_printf("resize_bicubic: Upper overflow occurs for resize height (%f)\n", ceilf((outHeight - 1) * heightScale)); - return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize height"); - } - if (ceilf((outWidth - 1) * heightScale) > static_cast(DataTypeUtils::max())) { - nd4j_printf("resize_bicubic: Upper overflow occurs for resize height (%f)\n", ceilf((outHeight - 1) * heightScale)); - return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize width"); - } - - return Status::OK(); - } - - // Calculates all the required variables, and allocates the output. - int validateAndCreateOutput(NDArray const* input, int const width, int const height) { - return validateAndCalculateOutputSize(input, width, height); - } - - I batchSize; - I outHeight; - I outWidth; - I inHeight; - I inWidth; - I channels; - F heightScale; - F widthScale; - NDArray* output = nullptr; - - private: - bool _alignCorners; - bool _halfPixelCenters; - }; +struct BilinearInterpolationData { + Nd4jLong _bottomIndex; // Lower source index used in the interpolation + Nd4jLong _topIndex; // Upper source index used in the interpolation + // 1D linear iterpolation scale (see: + // https://en.wikipedia.org/wiki/Bilinear_interpolation) + double _interpolarValue; +}; +// calculateResizeScale determines the float scaling factor. +inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize, + bool alignCorners) { + return (alignCorners && outSize > 1) + ? (inSize - 1) / static_cast(outSize - 1) + : inSize / static_cast(outSize); +} - typedef ImageResizerStateCommon ImageResizerState; +template +struct ImageResizerStateCommon { + explicit ImageResizerStateCommon(bool alignCorners, bool halfPixelCenters) + : _alignCorners(alignCorners), _halfPixelCenters(halfPixelCenters) {} + + // ValidateAndCalculateOutputSize checks the bounds on the input tensors + // and requested size, sets up some of the resizing state such as the + // heightScale and widthScale, and calculates the output size. + // If any of these operations fails, it sets an error status in + // the context, which the caller must check. + int validateAndCalculateOutputSize(NDArray const* input, int const width, + int const height) { + // + batchSize = input->sizeAt(0); //.dim_size(0); + outHeight = height; + outWidth = width; // internal::SubtleMustCopy(Svec(1)); + inHeight = static_cast(input->sizeAt(1)); + inWidth = static_cast(input->sizeAt(2)); + channels = input->sizeAt(3); //.dim_size(3); + heightScale = calculateResizeScale(inHeight, outHeight, _alignCorners); + widthScale = calculateResizeScale(inWidth, outWidth, _alignCorners); + + // Guard against overflows + if (ceilf((outHeight - 1) * heightScale) > + static_cast(DataTypeUtils::max())) { + nd4j_printf( + "resize_bicubic: Upper overflow occurs for resize height (%f)\n", + ceilf((outHeight - 1) * heightScale)); + return Status::CODE( + ND4J_STATUS_BAD_INPUT, + "resize_bicubic: Upper overflow occurs for resize height"); + } + if (ceilf((outWidth - 1) * heightScale) > + static_cast(DataTypeUtils::max())) { + nd4j_printf( + "resize_bicubic: Upper overflow occurs for resize height (%f)\n", + ceilf((outHeight - 1) * heightScale)); + return Status::CODE( + ND4J_STATUS_BAD_INPUT, + "resize_bicubic: Upper overflow occurs for resize width"); + } - // Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the + return Status::OK(); + } + + // Calculates all the required variables, and allocates the output. + int validateAndCreateOutput(NDArray const* input, int const width, + int const height) { + return validateAndCalculateOutputSize(input, width, height); + } + + I batchSize; + I outHeight; + I outWidth; + I inHeight; + I inWidth; + I channels; + F heightScale; + F widthScale; + NDArray* output = nullptr; + + private: + bool _alignCorners; + bool _halfPixelCenters; +}; + +typedef ImageResizerStateCommon ImageResizerState; + +// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the // floating point coordinates of the top,left pixel is 0.5,0.5. - struct HalfPixelScaler { - HalfPixelScaler(){}; - inline float operator()(const int x, const float scale) const { - // Note that we subtract 0.5 from the return value, as the existing bilinear - // sampling code etc assumes pixels are in the old coordinate system. - return (static_cast(x) + 0.5f) * scale - 0.5f; - } - }; - - // Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the +struct HalfPixelScaler { + HalfPixelScaler(){}; + inline float operator()(const int x, const float scale) const { + // Note that we subtract 0.5 from the return value, as the existing bilinear + // sampling code etc assumes pixels are in the old coordinate system. + return (static_cast(x) + 0.5f) * scale - 0.5f; + } +}; + +// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the // floating point coordinates of the top,left pixel is 0.5,0.5. - struct HalfPixelScalerNN { - HalfPixelScalerNN(){}; - inline float operator()(const int x, const float scale) const { - // Note that we subtract 0.5 from the return value, as the existing bilinear - // sampling code etc assumes pixels are in the old coordinate system. - return (static_cast(x) + 0.5f) * scale; - } - }; +struct HalfPixelScalerNN { + HalfPixelScalerNN(){}; + inline float operator()(const int x, const float scale) const { + // Note that we subtract 0.5 from the return value, as the existing bilinear + // sampling code etc assumes pixels are in the old coordinate system. + return (static_cast(x) + 0.5f) * scale; + } +}; // Older incorrect scaling method that causes all resizes to have a slight // translation leading to inconsistent results. For example, a flip then a // resize gives different results then a resize then a flip. - struct LegacyScaler { - LegacyScaler(){}; - inline float operator()(const int x, const float scale) const { - return static_cast(x) * scale; - } - }; - - struct WeightsAndIndices { - float _weight0; - float _weight1; - float _weight2; - float _weight3; - Nd4jLong _index0; - Nd4jLong _index1; - Nd4jLong _index2; - Nd4jLong _index3; - - int _advance; // advance value. - }; - - template - inline void computeInterpolationWeights(const Scaler scaler, Nd4jLong outSize, - Nd4jLong inSize, - double scale, - BilinearInterpolationData *interpolationData) { - interpolationData[outSize]._bottomIndex = 0; - interpolationData[outSize]._topIndex = 0; - - auto func = PRAGMA_THREADS_FOR { - for (auto k = start; k < stop; k++) { - auto i = (outSize - k - 1); - double const in = scaler(i, scale); - double const in_f = sd::math::nd4j_floor(in); - double const in_c = sd::math::nd4j_ceil(in); - interpolationData[i]._bottomIndex = sd::math::nd4j_max(static_cast(in_f), (Nd4jLong)0LL);//static_cast(in); - interpolationData[i]._topIndex = sd::math::nd4j_min(static_cast(in_c), inSize - 1); - interpolationData[i]._interpolarValue = in - in_f; - } - }; - samediff::Threads::parallel_for(func, 0, outSize); +struct LegacyScaler { + LegacyScaler(){}; + inline float operator()(const int x, const float scale) const { + return static_cast(x) * scale; + } +}; + +struct WeightsAndIndices { + float _weight0; + float _weight1; + float _weight2; + float _weight3; + Nd4jLong _index0; + Nd4jLong _index1; + Nd4jLong _index2; + Nd4jLong _index3; + + int _advance; // advance value. +}; + +template +inline void computeInterpolationWeights( + const Scaler scaler, Nd4jLong outSize, Nd4jLong inSize, double scale, + BilinearInterpolationData* interpolationData) { + interpolationData[outSize]._bottomIndex = 0; + interpolationData[outSize]._topIndex = 0; + + auto func = PRAGMA_THREADS_FOR { + for (auto k = start; k < stop; k++) { + auto i = (outSize - k - 1); + double const in = scaler(i, scale); + double const in_f = sd::math::nd4j_floor(in); + double const in_c = sd::math::nd4j_ceil(in); + interpolationData[i]._bottomIndex = + sd::math::nd4j_max(static_cast(in_f), + (Nd4jLong)0LL); // static_cast(in); + interpolationData[i]._topIndex = + sd::math::nd4j_min(static_cast(in_c), inSize - 1); + interpolationData[i]._interpolarValue = in - in_f; } + }; + samediff::Threads::parallel_for(func, 0, outSize); +} /** * Computes the bilinear interpolation from the appropriate 4 float points * and the linear interpolation weights. */ // static void -// resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, +// resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, +// Nd4jLong inWidth, Nd4jLong outHeight, // Nd4jLong outWidth, Nd4jLong channels, // std::vector const& xs, // std::vector const& ys, // NDArray *output); - template - static void - resizeImage_(T const* pInputBuf, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, - Nd4jLong outWidth, Nd4jLong channels, - std::vector const &xs, - std::vector const &ys, - Z* pOutputBuf) { - - Nd4jLong inRowSize = inWidth * channels; - Nd4jLong inBatchNumValues = inHeight * inRowSize; - Nd4jLong outRowSize = outWidth * channels; - -// T const *pInputBuf = images->getDataBuffer()->primaryAsT(); // this works only with 'c' direction - BilinearInterpolationData const* xsPtr = xs.data(); - -// T* pOutputBuf = output->dataBuffer()->primaryAsT(); - auto computeBilinear = [](double topLeft, double topRight, - double bottomLeft, double bottomRight, - double xVal, double yVal) { - double top = topLeft + (topRight - topLeft) * xVal; - double bottom = bottomLeft + (bottomRight - bottomLeft) * xVal; - return top + (bottom - top) * yVal; - }; - - auto func = PRAGMA_THREADS_FOR { - for (auto batch = start; batch < stop; ++batch) { - auto pInput = pInputBuf + batch * inBatchNumValues; - for (Nd4jLong y = 0; y < outHeight; ++y) { - auto pOutput = pOutputBuf + (batch * outHeight + y) * outRowSize; - const T* ysInputLowerPtr = pInput + ys[y]._bottomIndex * inRowSize; - const T* ysInputUpperPtr = pInput + ys[y]._topIndex * inRowSize; - double yVal = ys[y]._interpolarValue; - for (Nd4jLong x = 0; x < outWidth; ++x) { - auto xsBottom = xsPtr[x]._bottomIndex; - auto xsTop = xsPtr[x]._topIndex; - auto xVal = xsPtr[x]._interpolarValue; - for (Nd4jLong c = 0; c < channels; ++c) { - double topLeft(ysInputLowerPtr[xsBottom + c]); - double topRight(ysInputLowerPtr[xsTop + c]); - double bottomLeft(ysInputUpperPtr[xsBottom + c]); - double bottomRight(ysInputUpperPtr[xsTop + c]); - pOutput[x * channels + c] = computeBilinear(topLeft, topRight, bottomLeft, bottomRight, - xVal, yVal); - } - } - } - } - }; - samediff::Threads::parallel_tad(func, 0, batchSize); +template +static void resizeImage_(T const* pInputBuf, Nd4jLong batchSize, + Nd4jLong inHeight, Nd4jLong inWidth, + Nd4jLong outHeight, Nd4jLong outWidth, + Nd4jLong channels, + std::vector const& xs, + std::vector const& ys, + Z* pOutputBuf) { + Nd4jLong inRowSize = inWidth * channels; + Nd4jLong inBatchNumValues = inHeight * inRowSize; + Nd4jLong outRowSize = outWidth * channels; + + // T const *pInputBuf = images->getDataBuffer()->primaryAsT(); // + // this works only with 'c' direction + BilinearInterpolationData const* xsPtr = xs.data(); + + // T* pOutputBuf = output->dataBuffer()->primaryAsT(); + auto computeBilinear = [](double topLeft, double topRight, double bottomLeft, + double bottomRight, double xVal, double yVal) { + double top = topLeft + (topRight - topLeft) * xVal; + double bottom = bottomLeft + (bottomRight - bottomLeft) * xVal; + return top + (bottom - top) * yVal; + }; + + auto func = PRAGMA_THREADS_FOR { + for (auto batch = start; batch < stop; ++batch) { + auto pInput = pInputBuf + batch * inBatchNumValues; + for (Nd4jLong y = 0; y < outHeight; ++y) { + auto pOutput = pOutputBuf + (batch * outHeight + y) * outRowSize; + const T* ysInputLowerPtr = pInput + ys[y]._bottomIndex * inRowSize; + const T* ysInputUpperPtr = pInput + ys[y]._topIndex * inRowSize; + double yVal = ys[y]._interpolarValue; + for (Nd4jLong x = 0; x < outWidth; ++x) { + auto xsBottom = xsPtr[x]._bottomIndex; + auto xsTop = xsPtr[x]._topIndex; + auto xVal = xsPtr[x]._interpolarValue; + for (Nd4jLong c = 0; c < channels; ++c) { + double topLeft(ysInputLowerPtr[xsBottom + c]); + double topRight(ysInputLowerPtr[xsTop + c]); + double bottomLeft(ysInputUpperPtr[xsBottom + c]); + double bottomRight(ysInputUpperPtr[xsTop + c]); + pOutput[x * channels + c] = computeBilinear( + topLeft, topRight, bottomLeft, bottomRight, xVal, yVal); + } + } + } } + }; + samediff::Threads::parallel_tad(func, 0, batchSize); +} - template - static int resizeBilinearFunctor_(NDArray const *images, int const width, int const height, bool const alignCorners, - bool const halfPixelCenter, NDArray *output) { - ImageResizerState st(alignCorners, halfPixelCenter); - st.validateAndCalculateOutputSize(images, width, height); - - const Nd4jLong batchSize = images->sizeAt(0); - const Nd4jLong inHeight = images->sizeAt(1); - const Nd4jLong inWidth = images->sizeAt(2); - const Nd4jLong channels = images->sizeAt(3); - - const Nd4jLong outHeight = output->sizeAt(1); - const Nd4jLong outWidth = output->sizeAt(2); +template +static int resizeBilinearFunctor_(NDArray const* images, int const width, + int const height, bool const alignCorners, + bool const halfPixelCenter, NDArray* output) { + ImageResizerState st(alignCorners, halfPixelCenter); + st.validateAndCalculateOutputSize(images, width, height); + + const Nd4jLong batchSize = images->sizeAt(0); + const Nd4jLong inHeight = images->sizeAt(1); + const Nd4jLong inWidth = images->sizeAt(2); + const Nd4jLong channels = images->sizeAt(3); + + const Nd4jLong outHeight = output->sizeAt(1); + const Nd4jLong outWidth = output->sizeAt(2); + + // Handle no-op resizes efficiently. + if (outHeight == inHeight && outWidth == inWidth) { + output->assign(images); + return Status::OK(); + } + + std::vector ys(outHeight + 1); + std::vector xs(outWidth + 1); + if (halfPixelCenter) { + computeInterpolationWeights(HalfPixelScaler(), outHeight, inHeight, + st.heightScale, ys.data()); + computeInterpolationWeights(HalfPixelScaler(), outWidth, inWidth, + st.widthScale, xs.data()); + + } else { + // Compute the cached interpolation weights on the x and y dimensions. + computeInterpolationWeights(LegacyScaler(), outHeight, inHeight, + st.heightScale, ys.data()); + computeInterpolationWeights(LegacyScaler(), outWidth, inWidth, + st.widthScale, xs.data()); + } + int xsSize = xs.size(); + // Scale x interpolation weights to avoid a multiplication during iteration. + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + xs[i]._bottomIndex *= channels; + xs[i]._topIndex *= channels; + } + }; + samediff::Threads::parallel_for(func, 0, xsSize); - // Handle no-op resizes efficiently. - if (outHeight == inHeight && outWidth == inWidth) { - output->assign(images); - return Status::OK(); - } + resizeImage_(images->getDataBuffer()->primaryAsT(), batchSize, + inHeight, inWidth, outHeight, outWidth, channels, xs, ys, + output->dataBuffer()->primaryAsT()); + return Status::OK(); +} - std::vector ys(outHeight + 1); - std::vector xs(outWidth + 1); +template +void resizeNeighbor(ImageResizerState const& st, NDArray const* images, + bool const alignCorners, bool const halfPixelCenter, + NDArray* output) { + const Nd4jLong batchSize = st.batchSize; + const Nd4jLong inHeight = st.inHeight; + const Nd4jLong inWidth = st.inWidth; + const Nd4jLong channels = st.channels; + + const Nd4jLong outHeight = st.outHeight; + const Nd4jLong outWidth = st.outWidth; + Scaler scaler; + + auto func = PRAGMA_THREADS_FOR_2D { + for (auto b = start_x; b < stop_x; b += inc_x) { + for (auto y = start_y; y < stop_y; y += inc_y) { + auto posY = + alignCorners + ? static_cast( + sd::math::p_round(scaler(y, st.heightScale))) + : static_cast( + sd::math::p_floor(scaler(y, st.heightScale))); + Nd4jLong inY = sd::math::nd4j_min(posY, inHeight - 1); if (halfPixelCenter) { - computeInterpolationWeights(HalfPixelScaler(), outHeight, inHeight, st.heightScale, - ys.data()); - computeInterpolationWeights(HalfPixelScaler(), outWidth, inWidth, st.widthScale, xs.data()); - + inY = sd::math::nd4j_max(0LL, inY); } - else { - // Compute the cached interpolation weights on the x and y dimensions. - computeInterpolationWeights(LegacyScaler(), outHeight, inHeight, st.heightScale, - ys.data()); - computeInterpolationWeights(LegacyScaler(), outWidth, inWidth, st.widthScale, xs.data()); + for (Nd4jLong x = 0; x < outWidth; ++x) { + auto posX = + alignCorners + ? static_cast( + sd::math::p_round(scaler(x, st.widthScale))) + : static_cast( + sd::math::p_floor(scaler(x, st.widthScale))); + Nd4jLong inX = sd::math::nd4j_min(posX, inWidth - 1); + if (halfPixelCenter) { + inX = sd::math::nd4j_max(0LL, inX); + } + // copy pixel over all channels + for (Nd4jLong e = 0; e < channels; e++) + output->t(b, y, x, e) = images->t(b, inY, inX, e); } - int xsSize = xs.size(); - // Scale x interpolation weights to avoid a multiplication during iteration. - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - xs[i]._bottomIndex *= channels; - xs[i]._topIndex *= channels; - } - }; - samediff::Threads::parallel_for(func, 0, xsSize); - - resizeImage_(images->getDataBuffer()->primaryAsT(), batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output->dataBuffer()->primaryAsT()); - return Status::OK(); + } } + }; + samediff::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1); +} - template - void resizeNeighbor(ImageResizerState const& st, NDArray const *images, bool const alignCorners, bool const halfPixelCenter, NDArray *output) { - const Nd4jLong batchSize = st.batchSize; - const Nd4jLong inHeight = st.inHeight; - const Nd4jLong inWidth = st.inWidth; - const Nd4jLong channels = st.channels; - - const Nd4jLong outHeight = st.outHeight; - const Nd4jLong outWidth = st.outWidth; - Scaler scaler; - - auto func = PRAGMA_THREADS_FOR_2D { - for (auto b = start_x; b < stop_x; b += inc_x) { - for (auto y = start_y; y < stop_y; y += inc_y) { - auto posY = alignCorners ? static_cast(sd::math::p_round(scaler(y, st.heightScale))) : static_cast(sd::math::p_floor(scaler(y, st.heightScale))); - Nd4jLong inY = sd::math::nd4j_min(posY, inHeight - 1); - if (halfPixelCenter) { - inY = sd::math::nd4j_max(0LL, inY); - } - for (Nd4jLong x = 0; x < outWidth; ++x) { - auto posX = alignCorners ? static_cast(sd::math::p_round(scaler(x, st.widthScale))) : static_cast(sd::math::p_floor(scaler(x, st.widthScale))); - Nd4jLong inX = sd::math::nd4j_min(posX,inWidth - 1); - if (halfPixelCenter) { - inX = sd::math::nd4j_max(0LL, inX); - } - // copy pixel over all channels - for (Nd4jLong e = 0; e < channels; e++) - output->t(b, y, x, e) = images->t(b, inY, inX, e); - } - } - } - }; - samediff::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1); - } - - template - int resizeNeighborFunctor_(NDArray const *images, int const width, int const height, bool const alignCorners, bool const halfPixelCenter, NDArray *output) { - ImageResizerState st(alignCorners, halfPixelCenter); - st.validateAndCalculateOutputSize(images, width, height); - - // Handle no-op resizes efficiently. - if (output->sizeAt(1) == images->sizeAt(1) && output->sizeAt(2) == images->sizeAt(2)) { - output->assign(images); - return Status::OK(); - } - - if (halfPixelCenter) - resizeNeighbor(st, images, alignCorners, true, output); - else - resizeNeighbor(st, images, alignCorners, false, output); - - return Status::OK(); - } +template +int resizeNeighborFunctor_(NDArray const* images, int const width, + int const height, bool const alignCorners, + bool const halfPixelCenter, NDArray* output) { + ImageResizerState st(alignCorners, halfPixelCenter); + st.validateAndCalculateOutputSize(images, width, height); + + // Handle no-op resizes efficiently. + if (output->sizeAt(1) == images->sizeAt(1) && + output->sizeAt(2) == images->sizeAt(2)) { + output->assign(images); + return Status::OK(); + } + + if (halfPixelCenter) + resizeNeighbor(st, images, alignCorners, true, + output); + else + resizeNeighbor(st, images, alignCorners, false, output); + + return Status::OK(); +} -// void resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, +// void resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong +// inHeight, Nd4jLong inWidth, Nd4jLong outHeight, // Nd4jLong outWidth, Nd4jLong channels, // std::vector const &xs, // std::vector const &ys, // NDArray *output) { -// BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), resizeImage_, -// (images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output), +// BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), +// resizeImage_, +// (images, batchSize, inHeight, inWidth, +// outHeight, outWidth, channels, xs, ys, output), // NUMERIC_TYPES, FLOAT_TYPES); // } - int resizeBilinearFunctor(sd::LaunchContext * context, NDArray const *images, int const width, int const height, - bool const alignCorners, bool const halfPixelCenter, NDArray *output) { - BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, (images, width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES); - return Status::OK(); - } +int resizeBilinearFunctor(sd::LaunchContext* context, NDArray const* images, + int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, + NDArray* output) { + BUILD_DOUBLE_SELECTOR( + images->dataType(), output->dataType(), return resizeBilinearFunctor_, + (images, width, height, alignCorners, halfPixelCenter, output), + NUMERIC_TYPES, FLOAT_TYPES); + return Status::OK(); +} - int resizeNeighborFunctor(sd::LaunchContext * context, NDArray const *images, int const width, int const height, - bool const alignCorners, bool const halfPixelCenter, NDArray *output) { - BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, (images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES); - } +int resizeNeighborFunctor(sd::LaunchContext* context, NDArray const* images, + int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, + NDArray* output) { + BUILD_SINGLE_SELECTOR( + images->dataType(), return resizeNeighborFunctor_, + (images, width, height, alignCorners, halfPixelCenter, output), + LIBND4J_TYPES); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// ------------------------------------------------------------------------------------------------------------------ // -// Bicubic interpolation -// ------------------------------------------------------------------------------------------------------------------ // - class CachedInterpolationCalculator { - public: - CachedInterpolationCalculator() : _indexes{-1, -1, -1, -1} {} - - // Advances iteration. Returns the number of values that should be copied from - // the current point to the next point. The copying should always be done by - // copying the last values from the old point to the first - // values of the new point. - inline int Advance(const Nd4jLong x0, const Nd4jLong x1, const Nd4jLong x2, - const Nd4jLong x3) { - // We use 2 hands and walk through, copying from one to another where - // we already have values. - // Invariant, new_indicies_hand <= cached_values_hand - const Nd4jLong new_x_indices[] = {x0, x1, x2, x3}; - int cachedValuesHand = 0; - int newIndiciesHand = 0; - while (cachedValuesHand < 4) { - if (_indexes[cachedValuesHand] == new_x_indices[newIndiciesHand]) { - if (newIndiciesHand < cachedValuesHand) { - _indexes[newIndiciesHand] = _indexes[cachedValuesHand]; - } - newIndiciesHand++; - } - cachedValuesHand++; - } - switch (newIndiciesHand) { - case 0: - _indexes[0] = x0; - case 1: - _indexes[1] = x1; - case 2: - _indexes[2] = x2; - case 3: - _indexes[3] = x3; - break; - } - return newIndiciesHand; - } - - private: - Nd4jLong _indexes[4]; - }; - static const Nd4jLong kTableSize = 1024LL; //(1 << 10); - - const float* initCoeffsTable(const double a) { - // Allocate and initialize coefficients table using Bicubic - // convolution algorithm. - // https://en.wikipedia.org/wiki/Bicubic_interpolation - float* coeffs_table = new float[(kTableSize + 1) * 2]; - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i <= stop; ++i) { - float x = i * 1.0 / kTableSize; - coeffs_table[i * 2] = ((a + 2) * x - (a + 3)) * x * x + 1; - x += 1.0; - coeffs_table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; - } - }; - samediff::Threads::parallel_for(func, 0, kTableSize); - return coeffs_table; - } - - const float* getCoeffsTable(const bool use_keys_cubic) { - // Static so that we initialize it on first use - if (use_keys_cubic) { - // http://ieeexplore.ieee.org/document/1163711/ - // R. G. Keys. Cubic convolution interpolation for digital image - // processing. IEEE Transactions on Acoustics, Speech, and Signal - // Processing, 29(6):1153–1160, 1981. - static const float* coeffs_table = initCoeffsTable(-0.5f); - return coeffs_table; - } else { - static const float* coeffs_table = initCoeffsTable(-0.75f); - return coeffs_table; +// ------------------------------------------------------------------------------------------------------------------ +// // Bicubic interpolation +// ------------------------------------------------------------------------------------------------------------------ +// // +class CachedInterpolationCalculator { + public: + CachedInterpolationCalculator() : _indexes{-1, -1, -1, -1} {} + + // Advances iteration. Returns the number of values that should be copied from + // the current point to the next point. The copying should always be done by + // copying the last values from the old point to the first + // values of the new point. + inline int Advance(const Nd4jLong x0, const Nd4jLong x1, const Nd4jLong x2, + const Nd4jLong x3) { + // We use 2 hands and walk through, copying from one to another where + // we already have values. + // Invariant, new_indicies_hand <= cached_values_hand + const Nd4jLong new_x_indices[] = {x0, x1, x2, x3}; + int cachedValuesHand = 0; + int newIndiciesHand = 0; + while (cachedValuesHand < 4) { + if (_indexes[cachedValuesHand] == new_x_indices[newIndiciesHand]) { + if (newIndiciesHand < cachedValuesHand) { + _indexes[newIndiciesHand] = _indexes[cachedValuesHand]; } + newIndiciesHand++; + } + cachedValuesHand++; } - - inline Nd4jLong bound(Nd4jLong val, Nd4jLong limit) { - return math::nd4j_min(limit - 1ll, math::nd4j_max(Nd4jLong{0}, val)); + switch (newIndiciesHand) { + case 0: + _indexes[0] = x0; + case 1: + _indexes[1] = x1; + case 2: + _indexes[2] = x2; + case 3: + _indexes[3] = x3; + break; } - - template - int resizeBicubicFunctor_(sd::LaunchContext * context, NDArray const* image, int width, int height, - bool preserveAspectRatio, bool antialias, NDArray* output) { - return ND4J_STATUS_OK; + return newIndiciesHand; + } + + private: + Nd4jLong _indexes[4]; +}; +static const Nd4jLong kTableSize = 1024LL; //(1 << 10); + +const float* initCoeffsTable(const double a) { + // Allocate and initialize coefficients table using Bicubic + // convolution algorithm. + // https://en.wikipedia.org/wiki/Bicubic_interpolation + float* coeffs_table = new float[(kTableSize + 1) * 2]; + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i <= stop; ++i) { + float x = i * 1.0 / kTableSize; + coeffs_table[i * 2] = ((a + 2) * x - (a + 3)) * x * x + 1; + x += 1.0; + coeffs_table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; } + }; + samediff::Threads::parallel_for(func, 0, kTableSize); + return coeffs_table; +} - int resizeBicubicFunctor(sd::LaunchContext * context, NDArray const* image, int width, int height, - bool preserveAspectRatio, bool antialias, NDArray* output) { - BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctor_, (context, image, - width, height, preserveAspectRatio, antialias, output), NUMERIC_TYPES); - } -// ------------------------------------------------------------------------------------------------------------------ // - - template - inline float interpolate1D(const float weight0, const float weight1, const float weight2, const float weight3, - const T value0, const T value1, const T value2, const T value3) { - return static_cast(value0) * weight0 + - static_cast(value1) * weight1 + - static_cast(value2) * weight2 + - static_cast(value3) * weight3; - } +const float* getCoeffsTable(const bool use_keys_cubic) { + // Static so that we initialize it on first use + if (use_keys_cubic) { + // http://ieeexplore.ieee.org/document/1163711/ + // R. G. Keys. Cubic convolution interpolation for digital image + // processing. IEEE Transactions on Acoustics, Speech, and Signal + // Processing, 29(6):1153–1160, 1981. + static const float* coeffs_table = initCoeffsTable(-0.5f); + return coeffs_table; + } else { + static const float* coeffs_table = initCoeffsTable(-0.75f); + return coeffs_table; + } +} -// Compute the 1D interpolation for a given X index using the y_weights - static float compute(float values[4], const float xW0, const float xW1, const float xW2, const float xW3) { - return interpolate1D(xW0, xW1, xW2, xW3, values[0], values[1],values[2], values[3]); - } +inline Nd4jLong bound(Nd4jLong val, Nd4jLong limit) { + return math::nd4j_min(limit - 1ll, math::nd4j_max(Nd4jLong{0}, val)); +} - template - inline void getWeightsAndIndices(const float scale, const Nd4jLong out_loc, const Nd4jLong limit, WeightsAndIndices* out) { - const Scaler scaler; - const float in_loc_f = scaler(out_loc, scale); - const Nd4jLong in_loc = std::floor(in_loc_f); - const float delta = in_loc_f - in_loc; - const Nd4jLong offset = lrintf(delta * kTableSize); - const float* coeffs_table = getCoeffsTable(use_keys_cubic); - if (use_keys_cubic) { - // The legacy code placed more weight on the edge pixels, since bounding - // the set of inputs to sample could cause an edge pixel to be repeated. - // Here we change the behavior at borders to match that used by the - // scale_and_translate_op, where sampling locations outside the image have - // their weight set to 0, and the weights are renormalized so that their sum - // is 1.0. - out->_index0 = bound(in_loc - 1, limit); - out->_weight0 = - (out->_index0 == in_loc - 1 ? coeffs_table[offset * 2 + 1] : 0.0f); - out->_index1 = bound(in_loc, limit); - out->_weight1 = (out->_index1 == in_loc ? coeffs_table[offset * 2] : 0.0f); - out->_index2 = bound(in_loc + 1, limit); - out->_weight2 = - (out->_index2 == in_loc + 1 ? coeffs_table[(kTableSize - offset) * 2] - : 0.0f); - out->_index3 = bound(in_loc + 2, limit); - out->_weight3 = (out->_index3 == in_loc + 2 - ? coeffs_table[(kTableSize - offset) * 2 + 1] - : 0.0f); - - const float weight_sum = - out->_weight0 + out->_weight1 + out->_weight2 + out->_weight3; - if (std::abs(weight_sum) >= 1000.0f * std::numeric_limits::min()) { - const float one_over_weight_sum = 1.0f / weight_sum; - out->_weight0 *= one_over_weight_sum; - out->_weight1 *= one_over_weight_sum; - out->_weight2 *= one_over_weight_sum; - out->_weight3 *= one_over_weight_sum; - } - } else { - out->_weight0 = coeffs_table[offset * 2 + 1]; - out->_weight1 = coeffs_table[offset * 2]; - out->_weight2 = coeffs_table[(kTableSize - offset) * 2]; - out->_weight3 = coeffs_table[(kTableSize - offset) * 2 + 1]; - out->_index0 = bound(in_loc - 1, limit); - out->_index1 = bound(in_loc, limit); - out->_index2 = bound(in_loc + 1, limit); - out->_index3 = bound(in_loc + 2, limit); - } - } +template +int resizeBicubicFunctor_(sd::LaunchContext* context, NDArray const* image, + int width, int height, bool preserveAspectRatio, + bool antialias, NDArray* output) { + return ND4J_STATUS_OK; +} - static void computeXWeightsAndIndices(const ImageResizerState& resizer_state, - const bool half_pixel_centers, - std::vector* x_wais) { - CachedInterpolationCalculator calc; - if (half_pixel_centers) { - auto func = PRAGMA_THREADS_FOR { - for (auto x = start; x < stop; ++x) { - getWeightsAndIndices( - resizer_state.widthScale, x, resizer_state.inWidth, &(*x_wais)[x]); - auto &x_wai = (*x_wais)[x]; - x_wai._advance = calc.Advance(x_wai._index0, x_wai._index1, x_wai._index2, - x_wai._index3); - } - }; - samediff::Threads::parallel_for(func, 0, resizer_state.outWidth); - } else { - auto func = PRAGMA_THREADS_FOR { - for (auto x = start; x < stop; ++x) { - getWeightsAndIndices( - resizer_state.widthScale, x, resizer_state.inWidth, &(*x_wais)[x]); - auto& x_wai = (*x_wais)[x]; - x_wai._advance = calc.Advance(x_wai._index0, x_wai._index1, x_wai._index2, - x_wai._index3); - } - }; - samediff::Threads::parallel_for(func, 0, resizer_state.outWidth); - } - // Scale the values so they can be used as offsets into buffers. - auto func = PRAGMA_THREADS_FOR { - for (auto x = start; x < stop; ++x) { - (*x_wais)[x]._index0 *= resizer_state.channels; - (*x_wais)[x]._index1 *= resizer_state.channels; - (*x_wais)[x]._index2 *= resizer_state.channels; - (*x_wais)[x]._index3 *= resizer_state.channels; - } - }; - samediff::Threads::parallel_for(func, 0, resizer_state.outWidth); - } +int resizeBicubicFunctor(sd::LaunchContext* context, NDArray const* image, + int width, int height, bool preserveAspectRatio, + bool antialias, NDArray* output) { + BUILD_SINGLE_SELECTOR( + image->dataType(), return resizeBicubicFunctor_, + (context, image, width, height, preserveAspectRatio, antialias, output), + NUMERIC_TYPES); +} +// ------------------------------------------------------------------------------------------------------------------ +// // + +template +inline float interpolate1D(const float weight0, const float weight1, + const float weight2, const float weight3, + const T value0, const T value1, const T value2, + const T value3) { + return static_cast(value0) * weight0 + + static_cast(value1) * weight1 + + static_cast(value2) * weight2 + + static_cast(value3) * weight3; +} - template - static FORCEINLINE float computeYInterpolation( - int which, int channelNum, const WeightsAndIndices& yWai, - const T* pY0, const T* pY1, const T* pY2, const T* pY3, - const WeightsAndIndices& xWai) { - int xIndex; - switch (which) { - case 0: - xIndex = xWai._index0; - break; - case 1: - xIndex = xWai._index1; - break; - case 2: - xIndex = xWai._index2; - break; - default: - xIndex = xWai._index3; - break; - } - const Nd4jLong pt_index = xIndex + channelNum; - return interpolate1D(yWai._weight0, yWai._weight1, yWai._weight2, - yWai._weight3, pY0[pt_index], pY1[pt_index], - pY2[pt_index], pY3[pt_index]); - } +// Compute the 1D interpolation for a given X index using the y_weights +static float compute(float values[4], const float xW0, const float xW1, + const float xW2, const float xW3) { + return interpolate1D(xW0, xW1, xW2, xW3, values[0], values[1], values[2], + values[3]); +} - template - static void - bicubicInterpolateWithCaching(NDArray const* image, ImageResizerState const& resizerState, bool const halfPixelCenters, NDArray* output) { - std::vector xWais(resizerState.outWidth); - computeXWeightsAndIndices(resizerState, halfPixelCenters, &xWais); - - const auto numChannels = resizerState.channels; - const Nd4jLong inRowWidth = resizerState.inWidth * numChannels; - const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth; - const auto batchNum = resizerState.batchSize; - const auto outHeight = resizerState.outHeight; - const auto outWidth = resizerState.outWidth; - - auto func = PRAGMA_THREADS_FOR { - const T* inputPtr = image->getDataBuffer()->primaryAsT(); - F* pOutputY = output->dataBuffer()->primaryAsT(); // output is float anyway - std::vector cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0); - - for (auto b = start; b < stop; ++b) { - auto pInput = inputPtr + b * inBatchWidth; - - for (Nd4jLong y = 0; y < outHeight; ++y) { - auto pOutput = &pOutputY[(b * outHeight + y) * outWidth * numChannels]; - - WeightsAndIndices yWai; - if (halfPixelCenters) { - getWeightsAndIndices( - resizerState.heightScale, y, resizerState.inHeight, &yWai); - } else { - getWeightsAndIndices( - resizerState.heightScale, y, resizerState.inHeight, &yWai); - } - // Make pointers represent offsets of data in inputBPtr. - const T* y_ptr_0 = pInput + yWai._index0 * inRowWidth; - const T* y_ptr_1 = pInput + yWai._index1 * inRowWidth; - const T* y_ptr_2 = pInput + yWai._index2 * inRowWidth; - const T* y_ptr_3 = pInput + yWai._index3 * inRowWidth; - - if (numChannels == 3) { - // Manually unroll case of 3 channels. - F cached_value_0[4] = {0}; - F cached_value_1[4] = {0}; - F cached_value_2[4] = {0}; - for (Nd4jLong x = 0; x < resizerState.outWidth; ++x) { - const WeightsAndIndices &xWai = xWais[x]; - // Shift values in cached_value_* to fill first '_advance' values. - switch (xWai._advance) { - case 3: - cached_value_0[0] = cached_value_0[1]; - cached_value_0[1] = cached_value_0[2]; - cached_value_0[2] = cached_value_0[3]; - cached_value_1[0] = cached_value_1[1]; - cached_value_1[1] = cached_value_1[2]; - cached_value_1[2] = cached_value_1[3]; - cached_value_2[0] = cached_value_2[1]; - cached_value_2[1] = cached_value_2[2]; - cached_value_2[2] = cached_value_2[3]; - break; - case 2: - cached_value_0[0] = cached_value_0[2]; - cached_value_0[1] = cached_value_0[3]; - cached_value_1[0] = cached_value_1[2]; - cached_value_1[1] = cached_value_1[3]; - cached_value_2[0] = cached_value_2[2]; - cached_value_2[1] = cached_value_2[3]; - break; - case 1: { - cached_value_0[0] = cached_value_0[3]; - cached_value_1[0] = cached_value_1[3]; - cached_value_2[0] = cached_value_2[3]; - break; - } - } - - // Set the remaining '4-_advance' values by computing. - switch (xWai._advance) { - case 0: - cached_value_0[0] = computeYInterpolation( - 0, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[0] = computeYInterpolation( - 0, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[0] = computeYInterpolation( - 0, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - - case 1: - cached_value_0[1] = computeYInterpolation( - 1, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[1] = computeYInterpolation( - 1, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[1] = computeYInterpolation( - 1, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - - case 2: - cached_value_0[2] = computeYInterpolation( - 2, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[2] = computeYInterpolation( - 2, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[2] = computeYInterpolation( - 2, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - - case 3: - cached_value_0[3] = computeYInterpolation( - 3, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[3] = computeYInterpolation( - 3, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[3] = computeYInterpolation( - 3, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - break; - } - pOutput[x * numChannels + 0] = - compute(cached_value_0, xWai._weight0, xWai._weight1, - xWai._weight2, xWai._weight3); - pOutput[x * numChannels + 1] = - compute(cached_value_1, xWai._weight0, xWai._weight1, - xWai._weight2, xWai._weight3); - pOutput[x * numChannels + 2] = - compute(cached_value_2, xWai._weight0, xWai._weight1, - xWai._weight2, xWai._weight3); - } - } else { - for (Nd4jLong x = 0; x < resizerState.outWidth; ++x) { - const WeightsAndIndices &xWai = xWais[x]; - // Shift values in cachedValue to fill first '_advance' values. - switch (xWai._advance) { - case 3: - for (auto c = 0; c < numChannels; ++c) { - cachedValue[4 * c + 0] = cachedValue[4 * c + 1]; - cachedValue[4 * c + 1] = cachedValue[4 * c + 2]; - cachedValue[4 * c + 2] = cachedValue[4 * c + 3]; - } - break; - case 2: - for (auto c = 0; c < numChannels; ++c) { - cachedValue[4 * c + 0] = cachedValue[4 * c + 2]; - cachedValue[4 * c + 1] = cachedValue[4 * c + 3]; - } - break; - case 1: { - for (auto c = 0; c < numChannels; ++c) { - cachedValue[4 * c + 0] = cachedValue[4 * c + 3]; - } - break; - } - } - - // Set the remaining '4-_advance' values by computing. - switch (xWai._advance) { - case 0: - for (auto c = 0; c < numChannels; ++c) { - cachedValue[4 * c + 0] = computeYInterpolation( - 0, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - - case 1: - for (auto c = 0; c < numChannels; ++c) { - cachedValue[4 * c + 1] = computeYInterpolation( - 1, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - - case 2: - for (auto c = 0; c < numChannels; ++c) { - cachedValue[4 * c + 2] = computeYInterpolation( - 2, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - - case 3: - for (auto c = 0; c < numChannels; ++c) { - cachedValue[4 * c + 3] = computeYInterpolation( - 3, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - break; - } - for (auto c = 0; c < numChannels; ++c) { - pOutput[x * numChannels + c] = - (F)compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1, - xWai._weight2, xWai._weight3); - } - } - } - } - } - }; - samediff::Threads::parallel_tad(func, 0, batchNum); +template +inline void getWeightsAndIndices(const float scale, const Nd4jLong out_loc, + const Nd4jLong limit, WeightsAndIndices* out) { + const Scaler scaler; + const float in_loc_f = scaler(out_loc, scale); + const Nd4jLong in_loc = std::floor(in_loc_f); + const float delta = in_loc_f - in_loc; + const Nd4jLong offset = lrintf(delta * kTableSize); + const float* coeffs_table = getCoeffsTable(use_keys_cubic); + if (use_keys_cubic) { + // The legacy code placed more weight on the edge pixels, since bounding + // the set of inputs to sample could cause an edge pixel to be repeated. + // Here we change the behavior at borders to match that used by the + // scale_and_translate_op, where sampling locations outside the image have + // their weight set to 0, and the weights are renormalized so that their sum + // is 1.0. + out->_index0 = bound(in_loc - 1, limit); + out->_weight0 = + (out->_index0 == in_loc - 1 ? coeffs_table[offset * 2 + 1] : 0.0f); + out->_index1 = bound(in_loc, limit); + out->_weight1 = (out->_index1 == in_loc ? coeffs_table[offset * 2] : 0.0f); + out->_index2 = bound(in_loc + 1, limit); + out->_weight2 = + (out->_index2 == in_loc + 1 ? coeffs_table[(kTableSize - offset) * 2] + : 0.0f); + out->_index3 = bound(in_loc + 2, limit); + out->_weight3 = (out->_index3 == in_loc + 2 + ? coeffs_table[(kTableSize - offset) * 2 + 1] + : 0.0f); + + const float weight_sum = + out->_weight0 + out->_weight1 + out->_weight2 + out->_weight3; + if (std::abs(weight_sum) >= 1000.0f * std::numeric_limits::min()) { + const float one_over_weight_sum = 1.0f / weight_sum; + out->_weight0 *= one_over_weight_sum; + out->_weight1 *= one_over_weight_sum; + out->_weight2 *= one_over_weight_sum; + out->_weight3 *= one_over_weight_sum; } + } else { + out->_weight0 = coeffs_table[offset * 2 + 1]; + out->_weight1 = coeffs_table[offset * 2]; + out->_weight2 = coeffs_table[(kTableSize - offset) * 2]; + out->_weight3 = coeffs_table[(kTableSize - offset) * 2 + 1]; + out->_index0 = bound(in_loc - 1, limit); + out->_index1 = bound(in_loc, limit); + out->_index2 = bound(in_loc + 1, limit); + out->_index3 = bound(in_loc + 2, limit); + } +} -// simplified bicubic resize without antialiasing -// - template - int resizeBicubicFunctorA_(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool const alignCorners, bool const halfPixelAlign, NDArray* output) { - ImageResizerState st(alignCorners, halfPixelAlign); // align_corners, half_pixel_align - int res = st.validateAndCreateOutput(image, width, height); - if (res == Status::OK()) - bicubicInterpolateWithCaching(image, st, halfPixelAlign, output); - - return res; - } - int resizeBicubicFunctorA(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool const alignCorners, bool const halfPixelAlign, NDArray* output) { - BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context, image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES); - } -// ------------------------------------------------------------------------------------------------------------------ // - struct CachedInterpolation { - Nd4jLong start; - Nd4jLong end; - float startScale; - float endMinusOneScale; - bool needsBounding; +static void computeXWeightsAndIndices(const ImageResizerState& resizer_state, + const bool half_pixel_centers, + std::vector* x_wais) { + CachedInterpolationCalculator calc; + if (half_pixel_centers) { + auto func = PRAGMA_THREADS_FOR { + for (auto x = start; x < stop; ++x) { + getWeightsAndIndices( + resizer_state.widthScale, x, resizer_state.inWidth, &(*x_wais)[x]); + auto& x_wai = (*x_wais)[x]; + x_wai._advance = calc.Advance(x_wai._index0, x_wai._index1, + x_wai._index2, x_wai._index3); + } }; - - template - struct ScaleCache { - float yScale; - T const* yPtr; + samediff::Threads::parallel_for(func, 0, resizer_state.outWidth); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto x = start; x < stop; ++x) { + getWeightsAndIndices( + resizer_state.widthScale, x, resizer_state.inWidth, &(*x_wais)[x]); + auto& x_wai = (*x_wais)[x]; + x_wai._advance = calc.Advance(x_wai._index0, x_wai._index1, + x_wai._index2, x_wai._index3); + } }; - // Computes the sum of all x values defined by taken across - // the y offsets and scales defined by y_ptrs and y_scales, for channel c. - // - // Note that is a template parameter to avoid a performance - // penalty from dynamically checking it. - template - static void computePatchSumOf3Channels(float scale, - ImageResizerState const& st, - std::vector> const& yPtrs, - CachedInterpolation const& xCache, - float* outputPtr) { - - bool const needsXBounding = xCache.needsBounding; - - auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { - return (needsXBounding ? bound(x, y) : (x)); - }; + samediff::Threads::parallel_for(func, 0, resizer_state.outWidth); + } + // Scale the values so they can be used as offsets into buffers. + auto func = PRAGMA_THREADS_FOR { + for (auto x = start; x < stop; ++x) { + (*x_wais)[x]._index0 *= resizer_state.channels; + (*x_wais)[x]._index1 *= resizer_state.channels; + (*x_wais)[x]._index2 *= resizer_state.channels; + (*x_wais)[x]._index3 *= resizer_state.channels; + } + }; + samediff::Threads::parallel_for(func, 0, resizer_state.outWidth); +} - float sum_0 = 0; - float sum_1 = 0; - float sum_2 = 0; - for (size_t i = 0; i < yPtrs.size(); ++i) { - const T* ptr = yPtrs[i].yPtr; - float scaleX = xCache.startScale; - Nd4jLong offset = 3 * boundIfNeeded(xCache.start, st.inWidth); - float sum_y_0 = static_cast(ptr[offset + 0]) * scaleX; - float sum_y_1 = static_cast(ptr[offset + 1]) * scaleX; - float sum_y_2 = static_cast(ptr[offset + 2]) * scaleX; - - if (xCache.start + 1 != xCache.end) { - for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { - Nd4jLong offset = 3 * boundIfNeeded(x, st.inWidth); - sum_y_0 += static_cast(ptr[offset + 0]); - sum_y_1 += static_cast(ptr[offset + 1]); - sum_y_2 += static_cast(ptr[offset + 2]); - } - scaleX = xCache.endMinusOneScale; - offset = st.channels * boundIfNeeded(xCache.end - 1, st.inWidth); - sum_y_0 += static_cast(ptr[offset + 0]) * scaleX; - sum_y_1 += static_cast(ptr[offset + 1]) * scaleX; - sum_y_2 += static_cast(ptr[offset + 2]) * scaleX; - } - sum_0 += sum_y_0 * yPtrs[i].yScale; - sum_1 += sum_y_1 * yPtrs[i].yScale; - sum_2 += sum_y_2 * yPtrs[i].yScale; - } +template +static FORCEINLINE float computeYInterpolation(int which, int channelNum, + const WeightsAndIndices& yWai, + const T* pY0, const T* pY1, + const T* pY2, const T* pY3, + const WeightsAndIndices& xWai) { + int xIndex; + switch (which) { + case 0: + xIndex = xWai._index0; + break; + case 1: + xIndex = xWai._index1; + break; + case 2: + xIndex = xWai._index2; + break; + default: + xIndex = xWai._index3; + break; + } + const Nd4jLong pt_index = xIndex + channelNum; + return interpolate1D(yWai._weight0, yWai._weight1, yWai._weight2, + yWai._weight3, pY0[pt_index], pY1[pt_index], + pY2[pt_index], pY3[pt_index]); +} - outputPtr[0] = sum_0 * scale; - outputPtr[1] = sum_1 * scale; - outputPtr[2] = sum_2 * scale; - } +template +static void bicubicInterpolateWithCaching(NDArray const* image, + ImageResizerState const& resizerState, + bool const halfPixelCenters, + NDArray* output) { + std::vector xWais(resizerState.outWidth); + computeXWeightsAndIndices(resizerState, halfPixelCenters, &xWais); + + const auto numChannels = resizerState.channels; + const Nd4jLong inRowWidth = resizerState.inWidth * numChannels; + const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth; + const auto batchNum = resizerState.batchSize; + const auto outHeight = resizerState.outHeight; + const auto outWidth = resizerState.outWidth; + + auto func = PRAGMA_THREADS_FOR { + const T* inputPtr = image->getDataBuffer()->primaryAsT(); + F* pOutputY = + output->dataBuffer()->primaryAsT(); // output is float anyway + std::vector cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0); + + for (auto b = start; b < stop; ++b) { + auto pInput = inputPtr + b * inBatchWidth; + + for (Nd4jLong y = 0; y < outHeight; ++y) { + auto pOutput = &pOutputY[(b * outHeight + y) * outWidth * numChannels]; + + WeightsAndIndices yWai; + if (halfPixelCenters) { + getWeightsAndIndices( + resizerState.heightScale, y, resizerState.inHeight, &yWai); + } else { + getWeightsAndIndices( + resizerState.heightScale, y, resizerState.inHeight, &yWai); + } + // Make pointers represent offsets of data in inputBPtr. + const T* y_ptr_0 = pInput + yWai._index0 * inRowWidth; + const T* y_ptr_1 = pInput + yWai._index1 * inRowWidth; + const T* y_ptr_2 = pInput + yWai._index2 * inRowWidth; + const T* y_ptr_3 = pInput + yWai._index3 * inRowWidth; + + if (numChannels == 3) { + // Manually unroll case of 3 channels. + F cached_value_0[4] = {0}; + F cached_value_1[4] = {0}; + F cached_value_2[4] = {0}; + for (Nd4jLong x = 0; x < resizerState.outWidth; ++x) { + const WeightsAndIndices& xWai = xWais[x]; + // Shift values in cached_value_* to fill first '_advance' values. + switch (xWai._advance) { + case 3: + cached_value_0[0] = cached_value_0[1]; + cached_value_0[1] = cached_value_0[2]; + cached_value_0[2] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[1]; + cached_value_1[1] = cached_value_1[2]; + cached_value_1[2] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[1]; + cached_value_2[1] = cached_value_2[2]; + cached_value_2[2] = cached_value_2[3]; + break; + case 2: + cached_value_0[0] = cached_value_0[2]; + cached_value_0[1] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[2]; + cached_value_1[1] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[2]; + cached_value_2[1] = cached_value_2[3]; + break; + case 1: { + cached_value_0[0] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[3]; + break; + } + } - // Computes the sum of all x values defined by taken across - // the y offsets and scales defined by y_ptrs and y_scales, for channel c. - // - // Note that is a template parameter to avoid a performance - // penalty from dynamically checking it. - template - static void computePatchSum(float scale, const ImageResizerState& st, - const std::vector>& yPtrs, - const CachedInterpolation& xCache, - float* outputPtr) { - - bool const needsXBounding = xCache.needsBounding; - - auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { - return (needsXBounding ? bound(x, y) : (x)); - }; - - const auto numChannels = st.channels; - for (Nd4jLong c = 0; c < numChannels; ++c) { - float sum = 0; - for (size_t i = 0; i < yPtrs.size(); ++i) { - T const* ptr = yPtrs[i].yPtr; - float scaleX = xCache.startScale; - float sumY = static_cast(ptr[numChannels * boundIfNeeded(xCache.start, st.inWidth) + c]) * scaleX; - if (xCache.start + 1 != xCache.end) { - for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { - sumY += static_cast( - ptr[numChannels * boundIfNeeded(x, st.inWidth) + c]); - } - scaleX = xCache.endMinusOneScale; - sumY += static_cast(ptr[numChannels * boundIfNeeded(xCache.end - 1, st.inWidth) + c]) * scaleX; - } - sum += sumY * yPtrs[i].yScale; + // Set the remaining '4-_advance' values by computing. + switch (xWai._advance) { + case 0: + cached_value_0[0] = computeYInterpolation( + 0, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[0] = computeYInterpolation( + 0, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[0] = computeYInterpolation( + 0, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + + case 1: + cached_value_0[1] = computeYInterpolation( + 1, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[1] = computeYInterpolation( + 1, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[1] = computeYInterpolation( + 1, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + + case 2: + cached_value_0[2] = computeYInterpolation( + 2, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[2] = computeYInterpolation( + 2, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[2] = computeYInterpolation( + 2, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + + case 3: + cached_value_0[3] = computeYInterpolation( + 3, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[3] = computeYInterpolation( + 3, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[3] = computeYInterpolation( + 3, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + break; + } + pOutput[x * numChannels + 0] = + compute(cached_value_0, xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + pOutput[x * numChannels + 1] = + compute(cached_value_1, xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + pOutput[x * numChannels + 2] = + compute(cached_value_2, xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + } + } else { + for (Nd4jLong x = 0; x < resizerState.outWidth; ++x) { + const WeightsAndIndices& xWai = xWais[x]; + // Shift values in cachedValue to fill first '_advance' values. + switch (xWai._advance) { + case 3: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 0] = cachedValue[4 * c + 1]; + cachedValue[4 * c + 1] = cachedValue[4 * c + 2]; + cachedValue[4 * c + 2] = cachedValue[4 * c + 3]; + } + break; + case 2: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 0] = cachedValue[4 * c + 2]; + cachedValue[4 * c + 1] = cachedValue[4 * c + 3]; + } + break; + case 1: { + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 0] = cachedValue[4 * c + 3]; } - outputPtr[c] = sum * scale; + break; + } } - } + // Set the remaining '4-_advance' values by computing. + switch (xWai._advance) { + case 0: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 0] = computeYInterpolation( + 0, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } + case 1: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 1] = computeYInterpolation( + 1, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } + + case 2: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 2] = computeYInterpolation( + 2, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } - template - static void resizeArea(ImageResizerState const& st, std::vector const& caches, NDArray const* input, NDArray* output) { - T const* inputPtr = input->bufferAsT(); - float scale = 1.f / (st.heightScale * st.widthScale); - auto outputPtr = output->bufferAsT(); // output is always float. TO DO: provide another float types also with template declaration - - auto batchProcess = PRAGMA_THREADS_FOR { - for (auto batch = start; batch < stop; batch++) { - for (auto y = 0; y < st.outHeight; ++y) { - const float inY = y * st.heightScale; - const float inY1 = (y + 1) * st.heightScale; - // The start and end height indices of all the cells that could - // contribute to the target cell. - const Nd4jLong yStart = math::nd4j_floor(inY); - const Nd4jLong yEnd = math::nd4j_ceil(inY1); - - std::vector> yCaches; - auto cacheLen = yEnd - yStart; - if (cacheLen) { - yCaches.resize(cacheLen); - }; - - for (auto i = yStart, k = 0LL; i < yEnd; ++i, ++k) { - ScaleCache scaleCache; - if (i < inY) { - scaleCache.yScale = (i + 1 > inY1 ? st.heightScale : i + 1 - inY); - } else { - scaleCache.yScale = (i + 1 > inY1 ? inY1 - i : 1.0); - } - scaleCache.yPtr = inputPtr + (batch * st.inHeight * st.inWidth * st.channels + - bound(i, st.inHeight) * st.inWidth * st.channels); - yCaches[k] = scaleCache; - } - float* output = outputPtr + (batch * st.outHeight + y) * st.channels * st.outWidth; - - if (st.channels == 3) { - for (Nd4jLong x = 0; x < st.outWidth; ++x) { - const CachedInterpolation &xCache = caches[x]; - computePatchSumOf3Channels(scale, st, yCaches, xCache, output); - output += st.channels; - } - } else { - for (Nd4jLong x = 0; x < st.outWidth; ++x) { - const CachedInterpolation &xCache = caches[x]; - computePatchSum(scale, st, yCaches, xCache, output); - output += st.channels; - } - } + case 3: + for (auto c = 0; c < numChannels; ++c) { + cachedValue[4 * c + 3] = computeYInterpolation( + 3, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); } + break; } - }; - samediff::Threads::parallel_tad(batchProcess, 0, st.batchSize, 1); + for (auto c = 0; c < numChannels; ++c) { + pOutput[x * numChannels + c] = + (F)compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + } + } + } + } } + }; + samediff::Threads::parallel_tad(func, 0, batchNum); +} - template - int resizeAreaFunctor_(sd::LaunchContext* context, NDArray const* image, int const width, int const height, - bool const alignCorners, NDArray* output) { - ImageResizerState st(alignCorners, false); // Create resize info - auto res = st.validateAndCalculateOutputSize(image, width, height); - if (Status::OK() == res) { - std::vector xCached(st.outWidth); - auto cachingProcedure = PRAGMA_THREADS_FOR { - for (auto x = start; x < stop; x++) { - auto &xCache = xCached[x]; - const float inX = x * st.widthScale; - const float inX1 = (x + 1) * st.widthScale; - - Nd4jLong v = math::nd4j_floor(inX); - xCache.start = v; - xCache.startScale = - v < inX ? (v + 1 > inX1 ? st.widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v - : 1.f); - v = math::nd4j_ceil(inX1); - xCache.end = v--; - xCache.endMinusOneScale = - v < inX ? (v + 1 > inX1 ? st.widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v - : 1.f); - xCache.needsBounding = bound(xCache.start, st.inWidth) != xCache.start || - bound(xCache.end - 1, st.inWidth) != (xCache.end - 1); - - } - }; - samediff::Threads::parallel_for(cachingProcedure, 0, xCached.size(), 1); - - resizeArea(st, xCached, image, output); - } - return res; +// simplified bicubic resize without antialiasing +// +template +int resizeBicubicFunctorA_(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, bool const halfPixelAlign, + NDArray* output) { + ImageResizerState st(alignCorners, + halfPixelAlign); // align_corners, half_pixel_align + int res = st.validateAndCreateOutput(image, width, height); + if (res == Status::OK()) + bicubicInterpolateWithCaching(image, st, halfPixelAlign, output); + + return res; +} +int resizeBicubicFunctorA(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, bool const halfPixelAlign, + NDArray* output) { + BUILD_SINGLE_SELECTOR( + image->dataType(), return resizeBicubicFunctorA_, + (context, image, width, height, alignCorners, halfPixelAlign, output), + NUMERIC_TYPES); +} +// ------------------------------------------------------------------------------------------------------------------ +// // +struct CachedInterpolation { + Nd4jLong start; + Nd4jLong end; + float startScale; + float endMinusOneScale; + bool needsBounding; +}; + +template +struct ScaleCache { + float yScale; + T const* yPtr; +}; +// Computes the sum of all x values defined by taken across +// the y offsets and scales defined by y_ptrs and y_scales, for channel c. +// +// Note that is a template parameter to avoid a performance +// penalty from dynamically checking it. +template +static void computePatchSumOf3Channels(float scale, ImageResizerState const& st, + std::vector> const& yPtrs, + CachedInterpolation const& xCache, + float* outputPtr) { + bool const needsXBounding = xCache.needsBounding; + + auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { + return (needsXBounding ? bound(x, y) : (x)); + }; + + float sum_0 = 0; + float sum_1 = 0; + float sum_2 = 0; + for (size_t i = 0; i < yPtrs.size(); ++i) { + const T* ptr = yPtrs[i].yPtr; + float scaleX = xCache.startScale; + Nd4jLong offset = 3 * boundIfNeeded(xCache.start, st.inWidth); + float sum_y_0 = static_cast(ptr[offset + 0]) * scaleX; + float sum_y_1 = static_cast(ptr[offset + 1]) * scaleX; + float sum_y_2 = static_cast(ptr[offset + 2]) * scaleX; + + if (xCache.start + 1 != xCache.end) { + for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { + Nd4jLong offset = 3 * boundIfNeeded(x, st.inWidth); + sum_y_0 += static_cast(ptr[offset + 0]); + sum_y_1 += static_cast(ptr[offset + 1]); + sum_y_2 += static_cast(ptr[offset + 2]); + } + scaleX = xCache.endMinusOneScale; + offset = st.channels * boundIfNeeded(xCache.end - 1, st.inWidth); + sum_y_0 += static_cast(ptr[offset + 0]) * scaleX; + sum_y_1 += static_cast(ptr[offset + 1]) * scaleX; + sum_y_2 += static_cast(ptr[offset + 2]) * scaleX; } + sum_0 += sum_y_0 * yPtrs[i].yScale; + sum_1 += sum_y_1 * yPtrs[i].yScale; + sum_2 += sum_y_2 * yPtrs[i].yScale; + } + + outputPtr[0] = sum_0 * scale; + outputPtr[1] = sum_1 * scale; + outputPtr[2] = sum_2 * scale; +} - int resizeAreaFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool const alignCorners, NDArray* output) { - BUILD_SINGLE_SELECTOR(image->dataType(), return resizeAreaFunctor_, (context, image, width, height, alignCorners, output), NUMERIC_TYPES); +// Computes the sum of all x values defined by taken across +// the y offsets and scales defined by y_ptrs and y_scales, for channel c. +// +// Note that is a template parameter to avoid a performance +// penalty from dynamically checking it. +template +static void computePatchSum(float scale, const ImageResizerState& st, + const std::vector>& yPtrs, + const CachedInterpolation& xCache, + float* outputPtr) { + bool const needsXBounding = xCache.needsBounding; + + auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { + return (needsXBounding ? bound(x, y) : (x)); + }; + + const auto numChannels = st.channels; + for (Nd4jLong c = 0; c < numChannels; ++c) { + float sum = 0; + for (size_t i = 0; i < yPtrs.size(); ++i) { + T const* ptr = yPtrs[i].yPtr; + float scaleX = xCache.startScale; + float sumY = + static_cast( + ptr[numChannels * boundIfNeeded(xCache.start, st.inWidth) + c]) * + scaleX; + if (xCache.start + 1 != xCache.end) { + for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { + sumY += static_cast( + ptr[numChannels * boundIfNeeded(x, st.inWidth) + c]); + } + scaleX = xCache.endMinusOneScale; + sumY += + static_cast( + ptr[numChannels * boundIfNeeded(xCache.end - 1, st.inWidth) + + c]) * + scaleX; + } + sum += sumY * yPtrs[i].yScale; } + outputPtr[c] = sum * scale; + } +} + +template +static void resizeArea(ImageResizerState const& st, + std::vector const& caches, + NDArray const* input, NDArray* output) { + T const* inputPtr = input->bufferAsT(); + float scale = 1.f / (st.heightScale * st.widthScale); + auto outputPtr = + output->bufferAsT(); // output is always float. TO DO: provide + // another float types also with template + // declaration + + auto batchProcess = PRAGMA_THREADS_FOR { + for (auto batch = start; batch < stop; batch++) { + for (auto y = 0; y < st.outHeight; ++y) { + const float inY = y * st.heightScale; + const float inY1 = (y + 1) * st.heightScale; + // The start and end height indices of all the cells that could + // contribute to the target cell. + const Nd4jLong yStart = math::nd4j_floor(inY); + const Nd4jLong yEnd = math::nd4j_ceil(inY1); + + std::vector> yCaches; + auto cacheLen = yEnd - yStart; + if (cacheLen) { + yCaches.resize(cacheLen); + }; -// ------------------------------------------------------------------------------------------------------------------ // - int resizeFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) { - switch (method) { - case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, false, output); break; - case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, false, output); break; - case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break; - case kResizeArea: return resizeAreaFunctor(context, image, width, height, preserveAspectRatio, output); - case kResizeLanczos5: - case kResizeGaussian: - case kResizeMitchelcubic: - throw std::runtime_error("helper::resizeFunctor: Non implemented yet."); + for (auto i = yStart, k = 0LL; i < yEnd; ++i, ++k) { + ScaleCache scaleCache; + if (i < inY) { + scaleCache.yScale = (i + 1 > inY1 ? st.heightScale : i + 1 - inY); + } else { + scaleCache.yScale = (i + 1 > inY1 ? inY1 - i : 1.0); + } + scaleCache.yPtr = + inputPtr + (batch * st.inHeight * st.inWidth * st.channels + + bound(i, st.inHeight) * st.inWidth * st.channels); + yCaches[k] = scaleCache; + } + float* output = + outputPtr + (batch * st.outHeight + y) * st.channels * st.outWidth; + + if (st.channels == 3) { + for (Nd4jLong x = 0; x < st.outWidth; ++x) { + const CachedInterpolation& xCache = caches[x]; + computePatchSumOf3Channels(scale, st, yCaches, xCache, output); + output += st.channels; + } + } else { + for (Nd4jLong x = 0; x < st.outWidth; ++x) { + const CachedInterpolation& xCache = caches[x]; + computePatchSum(scale, st, yCaches, xCache, output); + output += st.channels; + } } - return ND4J_STATUS_OK; + } } + }; + samediff::Threads::parallel_tad(batchProcess, 0, st.batchSize, 1); +} +template +int resizeAreaFunctor_(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, NDArray* output) { + ImageResizerState st(alignCorners, false); // Create resize info + auto res = st.validateAndCalculateOutputSize(image, width, height); + if (Status::OK() == res) { + std::vector xCached(st.outWidth); + auto cachingProcedure = PRAGMA_THREADS_FOR { + for (auto x = start; x < stop; x++) { + auto& xCache = xCached[x]; + const float inX = x * st.widthScale; + const float inX1 = (x + 1) * st.widthScale; + + Nd4jLong v = math::nd4j_floor(inX); + xCache.start = v; + xCache.startScale = v < inX + ? (v + 1 > inX1 ? st.widthScale : v + 1 - inX) + : (v + 1 > inX1 ? inX1 - v : 1.f); + v = math::nd4j_ceil(inX1); + xCache.end = v--; + xCache.endMinusOneScale = + v < inX ? (v + 1 > inX1 ? st.widthScale : v + 1 - inX) + : (v + 1 > inX1 ? inX1 - v : 1.f); + xCache.needsBounding = + bound(xCache.start, st.inWidth) != xCache.start || + bound(xCache.end - 1, st.inWidth) != (xCache.end - 1); + } + }; + samediff::Threads::parallel_for(cachingProcedure, 0, xCached.size(), 1); + resizeArea(st, xCached, image, output); + } + return res; } + +int resizeAreaFunctor(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, NDArray* output) { + BUILD_SINGLE_SELECTOR(image->dataType(), return resizeAreaFunctor_, + (context, image, width, height, alignCorners, output), + NUMERIC_TYPES); +} + +// ------------------------------------------------------------------------------------------------------------------ +// // +int resizeFunctor(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, ImageResizeMethods method, + bool preserveAspectRatio, bool antialias, NDArray* output) { + switch (method) { + case kResizeBilinear: + return resizeBilinearFunctor(context, image, width, height, false, false, + output); + break; + case kResizeNearest: + return resizeNeighborFunctor(context, image, width, height, false, false, + output); + break; + case kResizeBicubic: + return resizeBicubicFunctor(context, image, width, height, + preserveAspectRatio, antialias, output); + break; + case kResizeArea: + return resizeAreaFunctor(context, image, width, height, + preserveAspectRatio, output); + case kResizeLanczos5: + case kResizeGaussian: + case kResizeMitchelcubic: + throw std::runtime_error("helper::resizeFunctor: Non implemented yet."); + } + return ND4J_STATUS_OK; } -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp index c3ad42db3181..741292a3cd6a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp @@ -18,8 +18,9 @@ // @author sgazeos@gmail.com // -#include #include +#include + #include #include #include @@ -28,235 +29,288 @@ namespace sd { namespace ops { namespace helpers { - template - static void nonMaxSuppressionV2_(NDArray* boxes, NDArray* scales, int maxSize, double overlapThreshold, - double scoreThreshold, NDArray* output) { - std::vector indices(scales->lengthOf()); - std::iota(indices.begin(), indices.end(), 0); - auto actualIndicesCount = indices.size(); - for (auto e = 0; e < scales->lengthOf(); e++) { - if (scales->e(e) < (float)scoreThreshold) { - indices[e] = -1; - actualIndicesCount--; - } - } - std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return i >= 0 && j >=0?scales->e(i) > scales->e(j):(i > j);}); - -// std::vector selected(output->lengthOf()); - std::vector selectedIndices(output->lengthOf(), 0); - auto needToSuppressWithThreshold = [] (NDArray& boxes, int previousIndex, int nextIndex, T threshold) -> bool { - if (previousIndex < 0 || nextIndex < 0) return true; - T minYPrev = sd::math::nd4j_min(boxes.t(previousIndex, 0), boxes.t(previousIndex, 2)); - T minXPrev = sd::math::nd4j_min(boxes.t(previousIndex, 1), boxes.t(previousIndex, 3)); - T maxYPrev = sd::math::nd4j_max(boxes.t(previousIndex, 0), boxes.t(previousIndex, 2)); - T maxXPrev = sd::math::nd4j_max(boxes.t(previousIndex, 1), boxes.t(previousIndex, 3)); - T minYNext = sd::math::nd4j_min(boxes.t(nextIndex, 0), boxes.t(nextIndex, 2)); - T minXNext = sd::math::nd4j_min(boxes.t(nextIndex, 1), boxes.t(nextIndex, 3)); - T maxYNext = sd::math::nd4j_max(boxes.t(nextIndex, 0), boxes.t(nextIndex, 2)); - T maxXNext = sd::math::nd4j_max(boxes.t(nextIndex, 1), boxes.t(nextIndex, 3)); - T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); - T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); - - if (areaNext <= T(0.f) || areaPrev <= T(0.f)) return false; - - T minIntersectionY = sd::math::nd4j_max(minYPrev, minYNext); - T minIntersectionX = sd::math::nd4j_max(minXPrev, minXNext); - T maxIntersectionY = sd::math::nd4j_min(maxYPrev, maxYNext); - T maxIntersectionX = sd::math::nd4j_min(maxXPrev, maxXNext); - T intersectionArea = - sd::math::nd4j_max(T(maxIntersectionY - minIntersectionY), T(0.0f)) * - sd::math::nd4j_max(T(maxIntersectionX - minIntersectionX), T(0.0f)); - T intersectionValue = intersectionArea / (areaPrev + areaNext - intersectionArea); - return intersectionValue > threshold; - - }; -// int numSelected = 0; - int numBoxes = actualIndicesCount; //boxes->sizeAt(0); - int numSelected = 0; - - for (int i = 0; i < numBoxes; ++i) { - bool shouldSelect = numSelected < output->lengthOf(); - - // FIXME: add parallelism here - for (int j = numSelected - 1; j >= 0; --j) { - if (shouldSelect) - if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(overlapThreshold))) { - shouldSelect = false; - } - } - if (shouldSelect) { - output->p(numSelected, indices[i]); - selectedIndices[numSelected++] = i; - } - } +template +static void nonMaxSuppressionV2_(NDArray* boxes, NDArray* scales, int maxSize, + double overlapThreshold, double scoreThreshold, + NDArray* output) { + std::vector indices(scales->lengthOf()); + std::iota(indices.begin(), indices.end(), 0); + auto actualIndicesCount = indices.size(); + for (auto e = 0; e < scales->lengthOf(); e++) { + if (scales->e(e) < (float)scoreThreshold) { + indices[e] = -1; + actualIndicesCount--; } -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Return intersection-over-union overlap between boxes i and j - template - static inline T similirityV3_(NDArray const& boxes, Nd4jLong i, Nd4jLong j) { - const T zero = static_cast(0.f); - const T yminI = math::nd4j_min(boxes.t(i, 0), boxes.t(i, 2)); - const T xminI = math::nd4j_min(boxes.t(i, 1), boxes.t(i, 3)); - const T ymaxI = math::nd4j_max(boxes.t(i, 0), boxes.t(i, 2)); - const T xmaxI = math::nd4j_max(boxes.t(i, 1), boxes.t(i, 3)); - const T yminJ = math::nd4j_min(boxes.t(j, 0), boxes.t(j, 2)); - const T xminJ = math::nd4j_min(boxes.t(j, 1), boxes.t(j, 3)); - const T ymaxJ = math::nd4j_max(boxes.t(j, 0), boxes.t(j, 2)); - const T xmaxJ = math::nd4j_max(boxes.t(j, 1), boxes.t(j, 3)); - const T areaI = (ymaxI - yminI) * (xmaxI - xminI); - const T areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ); - if (areaI <= zero || areaJ <= zero) { - return zero; + } + std::sort(indices.begin(), indices.end(), [scales](int i, int j) { + return i >= 0 && j >= 0 ? scales->e(i) > scales->e(j) : (i > j); + }); + + // std::vector selected(output->lengthOf()); + std::vector selectedIndices(output->lengthOf(), 0); + auto needToSuppressWithThreshold = [](NDArray& boxes, int previousIndex, + int nextIndex, T threshold) -> bool { + if (previousIndex < 0 || nextIndex < 0) return true; + T minYPrev = sd::math::nd4j_min(boxes.t(previousIndex, 0), + boxes.t(previousIndex, 2)); + T minXPrev = sd::math::nd4j_min(boxes.t(previousIndex, 1), + boxes.t(previousIndex, 3)); + T maxYPrev = sd::math::nd4j_max(boxes.t(previousIndex, 0), + boxes.t(previousIndex, 2)); + T maxXPrev = sd::math::nd4j_max(boxes.t(previousIndex, 1), + boxes.t(previousIndex, 3)); + T minYNext = + sd::math::nd4j_min(boxes.t(nextIndex, 0), boxes.t(nextIndex, 2)); + T minXNext = + sd::math::nd4j_min(boxes.t(nextIndex, 1), boxes.t(nextIndex, 3)); + T maxYNext = + sd::math::nd4j_max(boxes.t(nextIndex, 0), boxes.t(nextIndex, 2)); + T maxXNext = + sd::math::nd4j_max(boxes.t(nextIndex, 1), boxes.t(nextIndex, 3)); + T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); + T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); + + if (areaNext <= T(0.f) || areaPrev <= T(0.f)) return false; + + T minIntersectionY = sd::math::nd4j_max(minYPrev, minYNext); + T minIntersectionX = sd::math::nd4j_max(minXPrev, minXNext); + T maxIntersectionY = sd::math::nd4j_min(maxYPrev, maxYNext); + T maxIntersectionX = sd::math::nd4j_min(maxXPrev, maxXNext); + T intersectionArea = + sd::math::nd4j_max(T(maxIntersectionY - minIntersectionY), T(0.0f)) * + sd::math::nd4j_max(T(maxIntersectionX - minIntersectionX), T(0.0f)); + T intersectionValue = + intersectionArea / (areaPrev + areaNext - intersectionArea); + return intersectionValue > threshold; + }; + // int numSelected = 0; + int numBoxes = actualIndicesCount; // boxes->sizeAt(0); + int numSelected = 0; + + for (int i = 0; i < numBoxes; ++i) { + bool shouldSelect = numSelected < output->lengthOf(); + + // FIXME: add parallelism here + for (int j = numSelected - 1; j >= 0; --j) { + if (shouldSelect) + if (needToSuppressWithThreshold(*boxes, indices[i], + indices[selectedIndices[j]], + T(overlapThreshold))) { + shouldSelect = false; } - const T intersectionYmin = math::nd4j_max(yminI, yminJ); - const T intersectionXmin = math::nd4j_max(xminI, xminJ); - const T intersectionYmax = math::nd4j_min(ymaxI, ymaxJ); - const T intersectionXmax = math::nd4j_min(xmaxI, xmaxJ); - const T intersectionY = intersectionYmax - intersectionYmin; - const T intersectionX = intersectionXmax - intersectionXmin; - const T intersectionArea = math::nd4j_max(intersectionY, zero) * math::nd4j_max(intersectionX, zero); - return intersectionArea / (areaI + areaJ - intersectionArea); - } - - template - static inline T similiratyOverlaps_(NDArray const& boxes, Nd4jLong i, Nd4jLong j) { - return boxes.t(i, j); - } - - typedef NDArray (*SimiliratyFunc)(NDArray const& boxes, Nd4jLong i, Nd4jLong j); - - static NDArray similiratyOverlaps(NDArray const& boxes, Nd4jLong i, Nd4jLong j) { - NDArray res(boxes.dataType(), boxes.getContext()); // = NDArrayFactory::create(0.); - BUILD_SINGLE_SELECTOR(boxes.dataType(), res = similiratyOverlaps_, (boxes, i, j) , FLOAT_TYPES); - return res; } - - static NDArray similiratyV3(NDArray const& boxes, Nd4jLong i, Nd4jLong j) { - NDArray res(boxes.dataType(), boxes.getContext()); // = NDArrayFactory::create(0.); - BUILD_SINGLE_SELECTOR(boxes.dataType(), res = similirityV3_, (boxes, i, j) , FLOAT_TYPES); - return res; + if (shouldSelect) { + output->p(numSelected, indices[i]); + selectedIndices[numSelected++] = i; } - + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static Nd4jLong - nonMaxSuppressionGeneric_(sd::LaunchContext* context, NDArray* boxes, NDArray* scores, int outputSize, - float overlapThreshold, float scoreThreshold, NDArray* output, SimiliratyFunc f) { - - auto numBoxes = boxes->sizeAt(0); - T* scoresData = scores->dataBuffer()->primaryAsT(); - - // Data structure for a selection candidate in NMS. - struct Candidate { - int _boxIndex; - T _score; - int _suppressBeginIndex; - }; - - auto cmp = [](const Candidate& bsI, const Candidate& bsJ) -> bool{ - return ((bsI._score == bsJ._score) && (bsI._boxIndex > bsJ._boxIndex)) || - (bsI._score < bsJ._score); - }; - - std::priority_queue, decltype(cmp)> candidatePriorityQueue(cmp); - for (auto i = 0; i < scores->lengthOf(); ++i) { - if ((float)scoresData[i] > (float)scoreThreshold) { - candidatePriorityQueue.emplace(Candidate({i, scoresData[i], 0})); - } - } - - std::vector selected; - T similarity, originalScore; - Candidate nextCandidate; - - while (selected.size() < outputSize && !candidatePriorityQueue.empty()) { - nextCandidate = candidatePriorityQueue.top(); - originalScore = nextCandidate._score; - candidatePriorityQueue.pop(); - - // Overlapping boxes are likely to have similar scores, therefore we - // iterate through the previously selected boxes backwards in order to - // see if `nextCandidate` should be suppressed. We also enforce a property - // that a candidate can be suppressed by another candidate no more than - // once via `suppress_begin_index` which tracks which previously selected - // boxes have already been compared against next_candidate prior to a given - // iteration. These previous selected boxes are then skipped over in the - // following loop. - bool shouldHardSuppress = false; - for (int j = static_cast(selected.size()) - 1; j >= nextCandidate._suppressBeginIndex; --j) { - auto similarityA = f(*boxes, nextCandidate._boxIndex, selected[j]); //boxes->t(nextCandidate._boxIndex, selected[j]); - similarity = similarityA.template t(0); - nextCandidate._score *= T(similarity <= overlapThreshold?1.0:0.); //suppressWeightFunc(similarity); +// Return intersection-over-union overlap between boxes i and j +template +static inline T similirityV3_(NDArray const& boxes, Nd4jLong i, Nd4jLong j) { + const T zero = static_cast(0.f); + const T yminI = math::nd4j_min(boxes.t(i, 0), boxes.t(i, 2)); + const T xminI = math::nd4j_min(boxes.t(i, 1), boxes.t(i, 3)); + const T ymaxI = math::nd4j_max(boxes.t(i, 0), boxes.t(i, 2)); + const T xmaxI = math::nd4j_max(boxes.t(i, 1), boxes.t(i, 3)); + const T yminJ = math::nd4j_min(boxes.t(j, 0), boxes.t(j, 2)); + const T xminJ = math::nd4j_min(boxes.t(j, 1), boxes.t(j, 3)); + const T ymaxJ = math::nd4j_max(boxes.t(j, 0), boxes.t(j, 2)); + const T xmaxJ = math::nd4j_max(boxes.t(j, 1), boxes.t(j, 3)); + const T areaI = (ymaxI - yminI) * (xmaxI - xminI); + const T areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ); + if (areaI <= zero || areaJ <= zero) { + return zero; + } + const T intersectionYmin = math::nd4j_max(yminI, yminJ); + const T intersectionXmin = math::nd4j_max(xminI, xminJ); + const T intersectionYmax = math::nd4j_min(ymaxI, ymaxJ); + const T intersectionXmax = math::nd4j_min(xmaxI, xmaxJ); + const T intersectionY = intersectionYmax - intersectionYmin; + const T intersectionX = intersectionXmax - intersectionXmin; + const T intersectionArea = + math::nd4j_max(intersectionY, zero) * math::nd4j_max(intersectionX, zero); + return intersectionArea / (areaI + areaJ - intersectionArea); +} - // First decide whether to perform hard suppression - if ((float)similarity >= static_cast(overlapThreshold)) { - shouldHardSuppress = true; - break; - } +template +static inline T similiratyOverlaps_(NDArray const& boxes, Nd4jLong i, + Nd4jLong j) { + return boxes.t(i, j); +} - // If next_candidate survives hard suppression, apply soft suppression - if ((float)nextCandidate._score <= (float)scoreThreshold) break; - } - // If `nextCandidate._score` has not dropped below `scoreThreshold` - // by this point, then we know that we went through all of the previous - // selections and can safely update `suppress_begin_index` to - // `selected.size()`. If on the other hand `next_candidate.score` - // *has* dropped below the score threshold, then since `suppressWeight` - // always returns values in [0, 1], further suppression by items that were - // not covered in the above for loop would not have caused the algorithm - // to select this item. We thus do the same update to - // `suppressBeginIndex`, but really, this element will not be added back - // into the priority queue in the following. - nextCandidate._suppressBeginIndex = selected.size(); +typedef NDArray (*SimiliratyFunc)(NDArray const& boxes, Nd4jLong i, Nd4jLong j); - if (!shouldHardSuppress) { - if (nextCandidate._score == originalScore) { - // Suppression has not occurred, so select next_candidate - selected.push_back(nextCandidate._boxIndex); -// selected_scores.push_back(nextCandidate._score); - } - if ((float)nextCandidate._score > (float)scoreThreshold) { - // Soft suppression has occurred and current score is still greater than - // score_threshold; add next_candidate back onto priority queue. - candidatePriorityQueue.push(nextCandidate); - } - } - } +static NDArray similiratyOverlaps(NDArray const& boxes, Nd4jLong i, + Nd4jLong j) { + NDArray res(boxes.dataType(), + boxes.getContext()); // = NDArrayFactory::create(0.); + BUILD_SINGLE_SELECTOR(boxes.dataType(), res = similiratyOverlaps_, + (boxes, i, j), FLOAT_TYPES); + return res; +} - if (output) { - DataBuffer buf(selected.data(), selected.size() * sizeof(I), DataTypeUtils::fromT()); - output->dataBuffer()->copyBufferFrom(buf, buf.getLenInBytes()); - } +static NDArray similiratyV3(NDArray const& boxes, Nd4jLong i, Nd4jLong j) { + NDArray res(boxes.dataType(), + boxes.getContext()); // = NDArrayFactory::create(0.); + BUILD_SINGLE_SELECTOR(boxes.dataType(), res = similirityV3_, (boxes, i, j), + FLOAT_TYPES); + return res; +} - return (Nd4jLong)selected.size(); +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static Nd4jLong nonMaxSuppressionGeneric_(sd::LaunchContext* context, + NDArray* boxes, NDArray* scores, + int outputSize, + float overlapThreshold, + float scoreThreshold, NDArray* output, + SimiliratyFunc f) { + auto numBoxes = boxes->sizeAt(0); + T* scoresData = scores->dataBuffer()->primaryAsT(); + + // Data structure for a selection candidate in NMS. + struct Candidate { + int _boxIndex; + T _score; + int _suppressBeginIndex; + }; + + auto cmp = [](const Candidate& bsI, const Candidate& bsJ) -> bool { + return ((bsI._score == bsJ._score) && (bsI._boxIndex > bsJ._boxIndex)) || + (bsI._score < bsJ._score); + }; + + std::priority_queue, decltype(cmp)> + candidatePriorityQueue(cmp); + for (auto i = 0; i < scores->lengthOf(); ++i) { + if ((float)scoresData[i] > (float)scoreThreshold) { + candidatePriorityQueue.emplace(Candidate({i, scoresData[i], 0})); } - - Nd4jLong - nonMaxSuppressionGeneric(sd::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output) { - BUILD_DOUBLE_SELECTOR(boxes->dataType(), output == nullptr?DataType::INT32:output->dataType(), return nonMaxSuppressionGeneric_, (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output, similiratyOverlaps), FLOAT_TYPES, INTEGER_TYPES); - return 0; + } + + std::vector selected; + T similarity, originalScore; + Candidate nextCandidate; + + while (selected.size() < outputSize && !candidatePriorityQueue.empty()) { + nextCandidate = candidatePriorityQueue.top(); + originalScore = nextCandidate._score; + candidatePriorityQueue.pop(); + + // Overlapping boxes are likely to have similar scores, therefore we + // iterate through the previously selected boxes backwards in order to + // see if `nextCandidate` should be suppressed. We also enforce a property + // that a candidate can be suppressed by another candidate no more than + // once via `suppress_begin_index` which tracks which previously selected + // boxes have already been compared against next_candidate prior to a given + // iteration. These previous selected boxes are then skipped over in the + // following loop. + bool shouldHardSuppress = false; + for (int j = static_cast(selected.size()) - 1; + j >= nextCandidate._suppressBeginIndex; --j) { + auto similarityA = + f(*boxes, nextCandidate._boxIndex, + selected[j]); // boxes->t(nextCandidate._boxIndex, selected[j]); + similarity = similarityA.template t(0); + nextCandidate._score *= T(similarity <= overlapThreshold + ? 1.0 + : 0.); // suppressWeightFunc(similarity); + + // First decide whether to perform hard suppression + if ((float)similarity >= static_cast(overlapThreshold)) { + shouldHardSuppress = true; + break; + } + + // If next_candidate survives hard suppression, apply soft suppression + if ((float)nextCandidate._score <= (float)scoreThreshold) break; } - - Nd4jLong - nonMaxSuppressionV3(sd::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output) { - BUILD_DOUBLE_SELECTOR(boxes->dataType(), output == nullptr?DataType::INT32:output->dataType(), return nonMaxSuppressionGeneric_, (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output, similiratyV3), FLOAT_TYPES, INTEGER_TYPES); - return 0; + // If `nextCandidate._score` has not dropped below `scoreThreshold` + // by this point, then we know that we went through all of the previous + // selections and can safely update `suppress_begin_index` to + // `selected.size()`. If on the other hand `next_candidate.score` + // *has* dropped below the score threshold, then since `suppressWeight` + // always returns values in [0, 1], further suppression by items that were + // not covered in the above for loop would not have caused the algorithm + // to select this item. We thus do the same update to + // `suppressBeginIndex`, but really, this element will not be added back + // into the priority queue in the following. + nextCandidate._suppressBeginIndex = selected.size(); + + if (!shouldHardSuppress) { + if (nextCandidate._score == originalScore) { + // Suppression has not occurred, so select next_candidate + selected.push_back(nextCandidate._boxIndex); + // selected_scores.push_back(nextCandidate._score); + } + if ((float)nextCandidate._score > (float)scoreThreshold) { + // Soft suppression has occurred and current score is still greater than + // score_threshold; add next_candidate back onto priority queue. + candidatePriorityQueue.push(nextCandidate); + } } + } - BUILD_DOUBLE_TEMPLATE(template Nd4jLong nonMaxSuppressionGeneric_, (sd::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, - float overlapThreshold, float scoreThreshold, NDArray* output, SimiliratyFunc similiratyFunc), FLOAT_TYPES, INTEGER_TYPES); + if (output) { + DataBuffer buf(selected.data(), selected.size() * sizeof(I), + DataTypeUtils::fromT()); + output->dataBuffer()->copyBufferFrom(buf, buf.getLenInBytes()); + } - void - nonMaxSuppression(sd::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output) { - BUILD_SINGLE_SELECTOR(boxes->dataType(), nonMaxSuppressionV2_, (boxes, scales, maxSize, - overlapThreshold, scoreThreshold, output), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template void nonMaxSuppressionV2_, (NDArray* boxes, NDArray* scales, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output), NUMERIC_TYPES); + return (Nd4jLong)selected.size(); +} +Nd4jLong nonMaxSuppressionGeneric(sd::LaunchContext* context, NDArray* boxes, + NDArray* scores, int maxSize, + double overlapThreshold, + double scoreThreshold, NDArray* output) { + BUILD_DOUBLE_SELECTOR( + boxes->dataType(), + output == nullptr ? DataType::INT32 : output->dataType(), + return nonMaxSuppressionGeneric_, + (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, + output, similiratyOverlaps), + FLOAT_TYPES, INTEGER_TYPES); + return 0; } + +Nd4jLong nonMaxSuppressionV3(sd::LaunchContext* context, NDArray* boxes, + NDArray* scores, int maxSize, + double overlapThreshold, double scoreThreshold, + NDArray* output) { + BUILD_DOUBLE_SELECTOR( + boxes->dataType(), + output == nullptr ? DataType::INT32 : output->dataType(), + return nonMaxSuppressionGeneric_, + (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, + output, similiratyV3), + FLOAT_TYPES, INTEGER_TYPES); + return 0; +} + +BUILD_DOUBLE_TEMPLATE(template Nd4jLong nonMaxSuppressionGeneric_, + (sd::LaunchContext * context, NDArray* boxes, + NDArray* scores, int maxSize, float overlapThreshold, + float scoreThreshold, NDArray* output, + SimiliratyFunc similiratyFunc), + FLOAT_TYPES, INTEGER_TYPES); + +void nonMaxSuppression(sd::LaunchContext* context, NDArray* boxes, + NDArray* scales, int maxSize, double overlapThreshold, + double scoreThreshold, NDArray* output) { + BUILD_SINGLE_SELECTOR( + boxes->dataType(), nonMaxSuppressionV2_, + (boxes, scales, maxSize, overlapThreshold, scoreThreshold, output), + NUMERIC_TYPES); } -} \ No newline at end of file +BUILD_SINGLE_TEMPLATE(template void nonMaxSuppressionV2_, + (NDArray * boxes, NDArray* scales, int maxSize, + double overlapThreshold, double scoreThreshold, + NDArray* output), + NUMERIC_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/imagesHelpers.cpp b/libnd4j/include/ops/declarable/helpers/cpu/imagesHelpers.cpp index 2183b7d5adda..ff6e70c25dff 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/imagesHelpers.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/imagesHelpers.cpp @@ -19,10 +19,10 @@ // @author AbdelRauf (rauf@konduit.ai) // +#include +#include #include #include -#include -#include namespace sd { namespace ops { @@ -30,259 +30,271 @@ namespace helpers { template static void rgbToGrs_(const NDArray& input, NDArray& output, const int dimC) { - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); - const int rank = input.rankOf(); - - if(dimC == rank - 1 && 'c' == input.ordering() && 1 == input.ews() && - 'c' == output.ordering() && 1 == output.ews()){ - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - const auto xStep = i*3; - z[i] = 0.2989f*x[xStep] + 0.5870f*x[xStep + 1] + 0.1140f*x[xStep + 2]; - } - }; - - samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); - return; - } - - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, output.shapeInfo(), coords); - const auto zOffset = shape::getOffset(output.shapeInfo(), coords); - const auto xOffset0 = shape::getOffset(input.shapeInfo(), coords); - const auto xOffset1 = xOffset0 + input.strideAt(dimC); - const auto xOffset2 = xOffset1 + input.strideAt(dimC); - z[zOffset] = 0.2989f*x[xOffset0] + 0.5870f*x[xOffset1] + 0.1140f*x[xOffset2]; - } + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); + const int rank = input.rankOf(); + + if (dimC == rank - 1 && 'c' == input.ordering() && 1 == input.ews() && + 'c' == output.ordering() && 1 == output.ews()) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const auto xStep = i * 3; + z[i] = 0.2989f * x[xStep] + 0.5870f * x[xStep + 1] + + 0.1140f * x[xStep + 2]; + } }; samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); return; + } + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), coords); + const auto zOffset = shape::getOffset(output.shapeInfo(), coords); + const auto xOffset0 = shape::getOffset(input.shapeInfo(), coords); + const auto xOffset1 = xOffset0 + input.strideAt(dimC); + const auto xOffset2 = xOffset1 + input.strideAt(dimC); + z[zOffset] = + 0.2989f * x[xOffset0] + 0.5870f * x[xOffset1] + 0.1140f * x[xOffset2]; + } + }; + + samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); + return; } -void transformRgbGrs(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { - BUILD_SINGLE_SELECTOR(input.dataType(), rgbToGrs_, (input, output, dimC), NUMERIC_TYPES); +void transformRgbGrs(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC) { + BUILD_SINGLE_SELECTOR(input.dataType(), rgbToGrs_, (input, output, dimC), + NUMERIC_TYPES); } template -FORCEINLINE static void rgbToFromYuv_(const NDArray& input, NDArray& output, const int dimC, Op op) { - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); - const int rank = input.rankOf(); - bool bSimple = (dimC == rank - 1 && 'c' == input.ordering() && 1 == input.ews() && - 'c' == output.ordering() && 1 == output.ews()); - - if (bSimple) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i += increment) { - op(x[i], x[i + 1], x[i + 2], z[i], z[i + 1], z[i + 2]); - } - }; - - samediff::Threads::parallel_for(func, 0, input.lengthOf(), 3); - return; - } - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimC); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), dimC); - - const Nd4jLong numOfTads = packX.numberOfTads(); - const Nd4jLong xDimCstride = input.stridesOf()[dimC]; - const Nd4jLong zDimCstride = output.stridesOf()[dimC]; - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - const T* xTad = x + packX.platformOffsets()[i]; - T* zTad = z + packZ.platformOffsets()[i]; - op(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } +FORCEINLINE static void rgbToFromYuv_(const NDArray& input, NDArray& output, + const int dimC, Op op) { + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); + const int rank = input.rankOf(); + bool bSimple = + (dimC == rank - 1 && 'c' == input.ordering() && 1 == input.ews() && + 'c' == output.ordering() && 1 == output.ews()); + + if (bSimple) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) { + op(x[i], x[i + 1], x[i + 2], z[i], z[i + 1], z[i + 2]); + } }; - samediff::Threads::parallel_tad(func, 0, numOfTads); + samediff::Threads::parallel_for(func, 0, input.lengthOf(), 3); return; + } + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), dimC); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output.shapeInfo(), dimC); + + const Nd4jLong numOfTads = packX.numberOfTads(); + const Nd4jLong xDimCstride = input.stridesOf()[dimC]; + const Nd4jLong zDimCstride = output.stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const T* xTad = x + packX.platformOffsets()[i]; + T* zTad = z + packZ.platformOffsets()[i]; + op(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], + zTad[zDimCstride], zTad[2 * zDimCstride]); + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + return; } template -FORCEINLINE static void rgbYuv_(const NDArray& input, NDArray& output, const int dimC) { - auto op = sd::ops::helpers::rgbYuv; - return rgbToFromYuv_(input, output, dimC, op); +FORCEINLINE static void rgbYuv_(const NDArray& input, NDArray& output, + const int dimC) { + auto op = sd::ops::helpers::rgbYuv; + return rgbToFromYuv_(input, output, dimC, op); } -void transformRgbYuv(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { - BUILD_SINGLE_SELECTOR(input.dataType(), rgbYuv_, (input, output, dimC), FLOAT_TYPES); +void transformRgbYuv(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC) { + BUILD_SINGLE_SELECTOR(input.dataType(), rgbYuv_, (input, output, dimC), + FLOAT_TYPES); } template -FORCEINLINE static void yuvRgb_(const NDArray& input, NDArray& output, const int dimC) { - auto op = sd::ops::helpers::yuvRgb; - return rgbToFromYuv_(input, output, dimC, op); +FORCEINLINE static void yuvRgb_(const NDArray& input, NDArray& output, + const int dimC) { + auto op = sd::ops::helpers::yuvRgb; + return rgbToFromYuv_(input, output, dimC, op); } -void transformYuvRgb(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { - BUILD_SINGLE_SELECTOR(input.dataType(), yuvRgb_, (input, output, dimC), FLOAT_TYPES); +void transformYuvRgb(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC) { + BUILD_SINGLE_SELECTOR(input.dataType(), yuvRgb_, (input, output, dimC), + FLOAT_TYPES); } template -FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, const int dimC, Op op) { - - const int rank = input->rankOf(); - - const T* x = input->bufferAsT(); - T* z = output->bufferAsT(); - - if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i += increment) { - op(x[i], x[i + 1], x[i + 2], z[i], z[i + 1], z[i + 2]); - } - }; - - samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); - } - else { - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimC); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimC); - - const Nd4jLong numOfTads = packX.numberOfTads(); - const Nd4jLong xDimCstride = input->stridesOf()[dimC]; - const Nd4jLong zDimCstride = output->stridesOf()[dimC]; +FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, + const int dimC, Op op) { + const int rank = input->rankOf(); + + const T* x = input->bufferAsT(); + T* z = output->bufferAsT(); + + if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && + input->ordering() == 'c' && output->ordering() == 'c') { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) { + op(x[i], x[i + 1], x[i + 2], z[i], z[i + 1], z[i + 2]); + } + }; - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - const T* xTad = x + packX.platformOffsets()[i]; - T* zTad = z + packZ.platformOffsets()[i]; - op(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); + } else { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimC); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimC); - } - }; + const Nd4jLong numOfTads = packX.numberOfTads(); + const Nd4jLong xDimCstride = input->stridesOf()[dimC]; + const Nd4jLong zDimCstride = output->stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const T* xTad = x + packX.platformOffsets()[i]; + T* zTad = z + packZ.platformOffsets()[i]; + op(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], + zTad[zDimCstride], zTad[2 * zDimCstride]); + } + }; - samediff::Threads::parallel_tad(func, 0, numOfTads); - } + samediff::Threads::parallel_tad(func, 0, numOfTads); + } } - template -FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, const int dimC , T (&tr)[3][3] ) { - - const int rank = input->rankOf(); - - const T* x = input->bufferAsT(); - T* z = output->bufferAsT(); - // TODO: Use tensordot or other optimizied helpers to see if we can get better performance. - - if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') { +FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, + const int dimC, T (&tr)[3][3]) { + const int rank = input->rankOf(); + + const T* x = input->bufferAsT(); + T* z = output->bufferAsT(); + // TODO: Use tensordot or other optimizied helpers to see if we can get better + // performance. + + if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && + input->ordering() == 'c' && output->ordering() == 'c') { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) { + // simple M*v //tr.T*v.T // v * tr //rule: (AB)' =B'A' + // v.shape (1,3) row vector + T x0, x1, x2; + x0 = x[i]; // just additional hint + x1 = x[i + 1]; + x2 = x[i + 2]; + z[i] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0]; + z[i + 1] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1]; + z[i + 2] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2]; + } + }; - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i += increment) { - //simple M*v //tr.T*v.T // v * tr //rule: (AB)' =B'A' - // v.shape (1,3) row vector - T x0, x1, x2; - x0 = x[i]; //just additional hint - x1 = x[i + 1]; - x2 = x[i + 2]; - z[i] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0]; - z[i+1] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1]; - z[i+2] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2]; + samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); + } else { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimC); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimC); - } - }; + const Nd4jLong numOfTads = packX.numberOfTads(); + const Nd4jLong xDimCstride = input->stridesOf()[dimC]; + const Nd4jLong zDimCstride = output->stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const T* xTad = x + packX.platformOffsets()[i]; + T* zTad = z + packZ.platformOffsets()[i]; + // simple M*v //tr.T*v + T x0, x1, x2; + x0 = xTad[0]; + x1 = xTad[xDimCstride]; + x2 = xTad[2 * xDimCstride]; + zTad[0] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0]; + zTad[zDimCstride] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1]; + zTad[2 * zDimCstride] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2]; + } + }; - samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); - } - else { - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimC); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimC); - - const Nd4jLong numOfTads = packX.numberOfTads(); - const Nd4jLong xDimCstride = input->stridesOf()[dimC]; - const Nd4jLong zDimCstride = output->stridesOf()[dimC]; - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - const T* xTad = x + packX.platformOffsets()[i]; - T* zTad = z + packZ.platformOffsets()[i]; - //simple M*v //tr.T*v - T x0, x1, x2; - x0 = xTad[0]; - x1 = xTad[xDimCstride]; - x2 = xTad[2 * xDimCstride]; - zTad[0] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0]; - zTad[zDimCstride] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1]; - zTad[2 * zDimCstride] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2]; - - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfTads); - } + samediff::Threads::parallel_tad(func, 0, numOfTads); + } } - - - template -FORCEINLINE static void hsvRgb(const NDArray* input, NDArray* output, const int dimC) { - auto op = sd::ops::helpers::hsvToRgb; - return tripleTransformer(input, output, dimC, op); +FORCEINLINE static void hsvRgb(const NDArray* input, NDArray* output, + const int dimC) { + auto op = sd::ops::helpers::hsvToRgb; + return tripleTransformer(input, output, dimC, op); } template -FORCEINLINE static void rgbHsv(const NDArray* input, NDArray* output, const int dimC) { - auto op = sd::ops::helpers::rgbToHsv; - return tripleTransformer(input, output, dimC, op); +FORCEINLINE static void rgbHsv(const NDArray* input, NDArray* output, + const int dimC) { + auto op = sd::ops::helpers::rgbToHsv; + return tripleTransformer(input, output, dimC, op); } - template -FORCEINLINE static void rgbYiq(const NDArray* input, NDArray* output, const int dimC) { - T arr[3][3] = { - { (T)0.299, (T)0.59590059, (T)0.2115 }, - { (T)0.587, (T)-0.27455667, (T)-0.52273617 }, - { (T)0.114, (T)-0.32134392, (T)0.31119955 } - }; - return tripleTransformer(input, output, dimC, arr); +FORCEINLINE static void rgbYiq(const NDArray* input, NDArray* output, + const int dimC) { + T arr[3][3] = {{(T)0.299, (T)0.59590059, (T)0.2115}, + {(T)0.587, (T)-0.27455667, (T)-0.52273617}, + {(T)0.114, (T)-0.32134392, (T)0.31119955}}; + return tripleTransformer(input, output, dimC, arr); } template -FORCEINLINE static void yiqRgb(const NDArray* input, NDArray* output, const int dimC) { - //TODO: this operation does not use the clamp operation, so there is a possibility being out of range. - //Justify that it will not be out of range for images data - T arr[3][3] = { - { (T)1, (T)1, (T)1 }, - { (T)0.95598634, (T)-0.27201283, (T)-1.10674021 }, - { (T)0.6208248, (T)-0.64720424, (T)1.70423049 } - }; - return tripleTransformer(input, output, dimC, arr); +FORCEINLINE static void yiqRgb(const NDArray* input, NDArray* output, + const int dimC) { + // TODO: this operation does not use the clamp operation, so there is a + // possibility being out of range. Justify that it will not be out of range + // for images data + T arr[3][3] = {{(T)1, (T)1, (T)1}, + {(T)0.95598634, (T)-0.27201283, (T)-1.10674021}, + {(T)0.6208248, (T)-0.64720424, (T)1.70423049}}; + return tripleTransformer(input, output, dimC, arr); } - - -void transformHsvRgb(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), hsvRgb, (input, output, dimC), FLOAT_TYPES); +void transformHsvRgb(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), hsvRgb, (input, output, dimC), + FLOAT_TYPES); } -void transformRgbHsv(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), rgbHsv, (input, output, dimC), FLOAT_TYPES); +void transformRgbHsv(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), rgbHsv, (input, output, dimC), + FLOAT_TYPES); } -void transformYiqRgb(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, (input, output, dimC), FLOAT_TYPES); +void transformYiqRgb(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, (input, output, dimC), + FLOAT_TYPES); } -void transformRgbYiq(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, (input, output, dimC), FLOAT_TYPES); +void transformRgbYiq(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, (input, output, dimC), + FLOAT_TYPES); } - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/invertPermutation.cpp b/libnd4j/include/ops/declarable/helpers/cpu/invertPermutation.cpp index 5325ac282b95..1bff0a9ecea6 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/invertPermutation.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/invertPermutation.cpp @@ -18,34 +18,37 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { //////////////////////////////////////////////////////////////////////// -void invertPermutation(sd::LaunchContext * context, const NDArray& input, NDArray& output) { - - std::set uniqueElems; - const int length = input.lengthOf(); - - for(int i = 0; i < length; ++i) { - - int elem = input.e(i); - - if(!uniqueElems.insert(elem).second) // this operation forbids us to use #pragma omp - throw std::runtime_error("helpers::invertPermutation function: input array contains duplicates !"); - - if(elem < 0 || elem > length - 1) - throw std::runtime_error("helpers::invertPermutation function: element of input array is out of range (0, length-1) !"); - - output.p(elem, i); - } +void invertPermutation(sd::LaunchContext* context, const NDArray& input, + NDArray& output) { + std::set uniqueElems; + const int length = input.lengthOf(); + + for (int i = 0; i < length; ++i) { + int elem = input.e(i); + + if (!uniqueElems.insert(elem) + .second) // this operation forbids us to use #pragma omp + throw std::runtime_error( + "helpers::invertPermutation function: input array contains " + "duplicates !"); + + if (elem < 0 || elem > length - 1) + throw std::runtime_error( + "helpers::invertPermutation function: element of input array is out " + "of range (0, length-1) !"); + + output.p(elem, i); + } } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp index 687153f99780..432f893da454 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp @@ -19,193 +19,183 @@ // @author raver119@gmail.com // - -#include -#include -#include #include +#include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { template -static void ismax_(const NDArray* input, NDArray* output, const std::vector& dimensions) { - - if (input->isVector()) { - int dimensionsLength = dimensions.size(); - int length = input->lengthOf(); - if (!dimensions.empty() && (input->shapeOf())[dimensions[0]] == 1) { - for (int i = 0; i < length; i++) - output->p(i, 1); - } - else { - int eleStride = shape::elementWiseStride(input->shapeInfo()); - if (eleStride == 1) { - int maxIdx = 0; - auto currMax = input->e(0); - if (length < ELEMENT_THRESHOLD) { - - for (int i = 0; i < length; i++) { - if (currMax < input->e(i)) { - currMax = input->e(i); - maxIdx = i; - } - output->p(i, 0); - } - } - else { - - { - int maxIdxLocal = maxIdx; - auto currMaxLocal = currMax; - - for (int i = 0; i < length; i++) { - if (currMaxLocal < input->e(i)) { - currMaxLocal = input->e(i); - maxIdxLocal = i; - } - output->p(i, 0); - } - - PRAGMA_OMP_CRITICAL - { - if (currMax < currMaxLocal) { - currMax = currMaxLocal; - maxIdx = maxIdxLocal; - } - } - } - } - output->p(maxIdx, 1); +static void ismax_(const NDArray* input, NDArray* output, + const std::vector& dimensions) { + if (input->isVector()) { + int dimensionsLength = dimensions.size(); + int length = input->lengthOf(); + if (!dimensions.empty() && (input->shapeOf())[dimensions[0]] == 1) { + for (int i = 0; i < length; i++) output->p(i, 1); + } else { + int eleStride = shape::elementWiseStride(input->shapeInfo()); + if (eleStride == 1) { + int maxIdx = 0; + auto currMax = input->e(0); + if (length < ELEMENT_THRESHOLD) { + for (int i = 0; i < length; i++) { + if (currMax < input->e(i)) { + currMax = input->e(i); + maxIdx = i; } - else { - int maxIdx = 0; - auto currMax = input->e(0); - if (length < ELEMENT_THRESHOLD) { - - for (int i = 0; i < length; i++) { - if (currMax < input->e(i)) { - currMax = input->e(i); - maxIdx = i; - } - output->p(i, 0.f); - } - } - else { - - { - int maxIdxLocal = maxIdx; - auto currMaxLocal = currMax; - for (int i = 0; i < length; i++) { - if (currMaxLocal < input->e(i)) { - currMaxLocal = input->e(i); - maxIdxLocal = i; - } - output->p(i, 0.f); - } - - PRAGMA_OMP_CRITICAL - { - if (currMax < currMaxLocal) { - currMax = currMaxLocal; - maxIdx = maxIdxLocal; - } - } - } - } - output->p(maxIdx, 1); + output->p(i, 0); + } + } else { + { + int maxIdxLocal = maxIdx; + auto currMaxLocal = currMax; + + for (int i = 0; i < length; i++) { + if (currMaxLocal < input->e(i)) { + currMaxLocal = input->e(i); + maxIdxLocal = i; + } + output->p(i, 0); } + + PRAGMA_OMP_CRITICAL { + if (currMax < currMaxLocal) { + currMax = currMaxLocal; + maxIdx = maxIdxLocal; + } + } + } } - } - else { - int dimensionsLength = dimensions.size(); - //int tads = tad.numTads; - //decompose in to several sub tads after - //moving all dimensions (in sorted order) - //to the back. - //permuted version of the input shape info for setting up the tad problem - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), const_cast(dimensions.data()), dimensionsLength); - auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), const_cast(dimensions.data()), dimensionsLength); - - - auto tadShapeShapeInfo = tadPack.primaryShapeInfo(); - auto tadOffsets = tadPack.primaryOffsets(); - auto zOfsets = tadPackZ.platformOffsets(); - - int tadLength = shape::length(tadShapeShapeInfo); - int tads = tadPack.numberOfTads(); - - int tadsPerThread = tads / TAD_THRESHOLD; - int num_threads = sd::math::nd4j_max(1, tadsPerThread); - num_threads = sd::math::nd4j_min(num_threads, omp_get_max_threads()); - - auto tadEWS = shape::elementWiseStride(tadShapeShapeInfo); - auto zEWS = shape::elementWiseStride(tadPackZ.primaryShapeInfo()); - - int span = (tads / num_threads) + 8; - - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) { - auto rX = const_cast(input)->bufferAsT() + tadOffsets[r]; - auto rZ = output->bufferAsT() + zOfsets[r]; - - auto maxValue = rX[0]; - int maxIdx = 0; - if (tadEWS == 1 && zEWS == 1) { - for (int i = 0; i < tadLength; i++) { - if (rX[i] > maxValue) { - maxIdx = i; - maxValue = rX[i]; - } - } - - PRAGMA_OMP_SIMD - for (int i = 0; i < tadLength; i++) { - rZ[i] = maxIdx == i ? (Z) 1 : (Z) 0; - } - } - else if (tadEWS > 1 && zEWS > 1) { - for (int i = 0; i < tadLength; i++) { - if (rX[i * tadEWS] > maxValue) { - maxIdx = i; - maxValue = rX[i * tadEWS]; - } - } - - PRAGMA_OMP_SIMD - for (int i = 0; i < tadLength; i++) { - rZ[i * zEWS] = maxIdx == i ? (Z) 1 : (Z) 0; - } - } else { - for (int i = 0; i < tadLength; i++) { - auto xOffset = shape::getIndexOffset(i, tadShapeShapeInfo); - if (rX[xOffset] > maxValue) { - maxIdx = i; - maxValue = rX[xOffset]; - } - } - - PRAGMA_OMP_SIMD - for (int i = 0; i < tadLength; i++) { - auto zOffset = shape::getIndexOffset(i, tadPackZ.primaryShapeInfo()); - rZ[zOffset] = maxIdx == i ? (Z) 1 : (Z) 0; - } - } + output->p(maxIdx, 1); + } else { + int maxIdx = 0; + auto currMax = input->e(0); + if (length < ELEMENT_THRESHOLD) { + for (int i = 0; i < length; i++) { + if (currMax < input->e(i)) { + currMax = input->e(i); + maxIdx = i; + } + output->p(i, 0.f); + } + } else { + { + int maxIdxLocal = maxIdx; + auto currMaxLocal = currMax; + for (int i = 0; i < length; i++) { + if (currMaxLocal < input->e(i)) { + currMaxLocal = input->e(i); + maxIdxLocal = i; + } + output->p(i, 0.f); } - }; - samediff::Threads::parallel_tad(func, 0, tads); + PRAGMA_OMP_CRITICAL { + if (currMax < currMaxLocal) { + currMax = currMaxLocal; + maxIdx = maxIdxLocal; + } + } + } + } + output->p(maxIdx, 1); + } } -} - + } else { + int dimensionsLength = dimensions.size(); + // int tads = tad.numTads; + // decompose in to several sub tads after + // moving all dimensions (in sorted order) + // to the back. + // permuted version of the input shape info for setting up the tad problem + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), const_cast(dimensions.data()), + dimensionsLength); + auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), const_cast(dimensions.data()), + dimensionsLength); + + auto tadShapeShapeInfo = tadPack.primaryShapeInfo(); + auto tadOffsets = tadPack.primaryOffsets(); + auto zOfsets = tadPackZ.platformOffsets(); + + int tadLength = shape::length(tadShapeShapeInfo); + int tads = tadPack.numberOfTads(); + + int tadsPerThread = tads / TAD_THRESHOLD; + int num_threads = sd::math::nd4j_max(1, tadsPerThread); + num_threads = sd::math::nd4j_min(num_threads, omp_get_max_threads()); + + auto tadEWS = shape::elementWiseStride(tadShapeShapeInfo); + auto zEWS = shape::elementWiseStride(tadPackZ.primaryShapeInfo()); + + int span = (tads / num_threads) + 8; + + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) { + auto rX = const_cast(input)->bufferAsT() + tadOffsets[r]; + auto rZ = output->bufferAsT() + zOfsets[r]; + + auto maxValue = rX[0]; + int maxIdx = 0; + if (tadEWS == 1 && zEWS == 1) { + for (int i = 0; i < tadLength; i++) { + if (rX[i] > maxValue) { + maxIdx = i; + maxValue = rX[i]; + } + } + + PRAGMA_OMP_SIMD + for (int i = 0; i < tadLength; i++) { + rZ[i] = maxIdx == i ? (Z)1 : (Z)0; + } + } else if (tadEWS > 1 && zEWS > 1) { + for (int i = 0; i < tadLength; i++) { + if (rX[i * tadEWS] > maxValue) { + maxIdx = i; + maxValue = rX[i * tadEWS]; + } + } + + PRAGMA_OMP_SIMD + for (int i = 0; i < tadLength; i++) { + rZ[i * zEWS] = maxIdx == i ? (Z)1 : (Z)0; + } + } else { + for (int i = 0; i < tadLength; i++) { + auto xOffset = shape::getIndexOffset(i, tadShapeShapeInfo); + if (rX[xOffset] > maxValue) { + maxIdx = i; + maxValue = rX[xOffset]; + } + } + + PRAGMA_OMP_SIMD + for (int i = 0; i < tadLength; i++) { + auto zOffset = + shape::getIndexOffset(i, tadPackZ.primaryShapeInfo()); + rZ[zOffset] = maxIdx == i ? (Z)1 : (Z)0; + } + } + } + }; -void ismax(sd::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector& dimensions) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), ismax_, (input, output, dimensions), LIBND4J_TYPES, LIBND4J_TYPES); + samediff::Threads::parallel_tad(func, 0, tads); + } } - -} -} +void ismax(sd::LaunchContext* context, const NDArray* input, NDArray* output, + const std::vector& dimensions) { + BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), ismax_, + (input, output, dimensions), LIBND4J_TYPES, + LIBND4J_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp index b2a0e537f70d..45dc14c3eaf5 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp @@ -18,369 +18,425 @@ // @author GS // -#include #include +#include #include namespace sd { namespace ops { namespace helpers { - template - static void reluDerivative__(NDArray* theFirst, NDArray* theSecond) { - auto functor = LAMBDA_TT(x, y){ - return x > (T) 0.f ? y : T(0.f); - }; +template +static void reluDerivative__(NDArray* theFirst, NDArray* theSecond) { + auto functor = LAMBDA_TT(x, y) { return x > (T)0.f ? y : T(0.f); }; + + theFirst->applyPairwiseLambda(*theSecond, functor, *theFirst); +} + +void reluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative__, + (theFirst, theSecond), FLOAT_TYPES); +} + +template +static void reluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { + T zero = (T)0.f; + auto functor = LAMBDA_TT(x, y, zero) { return x > zero ? y : zero; }; + + input->applyPairwiseLambda(*epsilon, functor, *output); + + /* + auto x = input->bufferAsT(); + auto y = epsilon->bufferAsT(); + auto z = output->bufferAsT(); + + int length = input->lengthOf(); + + T zero = (T) 0.f; + + PRAGMA_OMP_PARALLEL_FOR + for (int e = 0; e < length; e++) { + z[e] = x[e] > zero ? y[e] : zero; + } + */ +} + +void reluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +template +static void relu6Derivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return x > (T)0.f && x < (T)6.f ? y : T(0.f); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void relu6Derivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), relu6Derivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +template +static void leakyReluDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output, const float alpha) { + const T alphaT = static_cast(alpha); + + auto functor = LAMBDA_TT(x, y, alphaT) { return x < 0 ? alphaT * y : y; }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void leakyReluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput, + const float alpha) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, + (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); +} + +template +static void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, + const float alpha) { + const T alphaT = static_cast(alpha); + + auto functor = LAMBDA_TT(x, y, alphaT) { + return y * sd::math::nd4j_eluderivative(x, alphaT); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void eluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput, const float alpha) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, + (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); +} + +template +static void seluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return y * simdOps::SELUDerivative::op(x, nullptr); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void seluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), seluDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +template +static void cubeDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { + auto functor = LAMBDA_TT(x, y) { return y * (3 * x * x); }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void cubeDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), cubeDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +// return (x >= X(0.f) ? y: -y); +template +static void reduceNorm1_(NDArray* input, NDArray* epsilon, NDArray* output) { + auto functor = LAMBDA_TT(x, y) { return x > T(0.f) ? y : -y; }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void reduceNorm1(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), reduceNorm1_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +//////////////////////////////////////////////////////////////////////// +template +static void sigmCrossEntropy_(NDArray* logits, NDArray* labels, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return sd::math::nd4j_max(x, (T)0.f) - x * y + + sd::math::nd4j_log( + (T)1.f + sd::math::nd4j_exp(-sd::math::nd4j_abs(x))); + }; + + logits->applyPairwiseLambda(*labels, functor, *output); +} + +void sigmCrossEntropy(sd::LaunchContext* context, NDArray* logits, + NDArray* labels, NDArray* output) { + BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropy_, + (logits, labels, output), FLOAT_TYPES); +} + +//////////////////////////////////////////////////////////////////////// +template +static void sigmCrossEntropyGrad_(NDArray* logits, NDArray* labels, + NDArray* output) { + // 1 - labels - 1 / (1 + exp(logits)) + auto functor = LAMBDA_TT(x, y) { + if (x <= 0) + return static_cast(1.) - y - + static_cast(1.) / + (static_cast(1.) + sd::math::nd4j_exp(x)); + auto e = sd::math::nd4j_exp(-x); + return static_cast(1.) - y - e / (static_cast(1.) + e); + }; + + logits->applyPairwiseLambda(*labels, functor, *output); +} + +void sigmCrossEntropyGrad(sd::LaunchContext* context, NDArray* logits, + NDArray* labels, NDArray* output) { + BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropyGrad_, + (logits, labels, output), FLOAT_TYPES); +} - theFirst->applyPairwiseLambda(*theSecond, functor, *theFirst); - } +//////////////////////////////////////////////////////////////////////// +template +static void tanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T th = sd::math::nd4j_tanh(x); + return y * ((T)1.0f - (th * th)); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} - void reluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative__, (theFirst, theSecond), FLOAT_TYPES); - } +void tanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), tanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +// return static_cast(d2) * simdOps::HardTanhDerivative::op(d1, nullptr); +template +static void hardTanhDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T th = sd::math::nd4j_tanh(x); + return y * simdOps::HardTanhDerivative::op(x, nullptr); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void hardTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardTanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +template +static void rationalTanhDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return y * simdOps::RationalTanhDerivative::op(x, nullptr); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} - template - static void reluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { +void rationalTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), rationalTanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - T zero = (T) 0.f; - auto functor = LAMBDA_TT(x, y, zero){ - return x > zero ? y : zero; - }; +template +static void rectifiedTanhDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return x > (T)0.0f ? y * (sd::math::nd4j_tanhderivative(x)) : (T)0.0f; + }; - input->applyPairwiseLambda(*epsilon, functor, *output); + input->applyPairwiseLambda(*epsilon, functor, *output); +} - /* - auto x = input->bufferAsT(); - auto y = epsilon->bufferAsT(); - auto z = output->bufferAsT(); +void rectifiedTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), rectifiedTanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - int length = input->lengthOf(); +// X f = (X) 1.0f + sd::math::nd4j_abs(d1); +// return (X) d2 * ((X) 1.0f / (f * f)); - T zero = (T) 0.f; +template +static void softSignDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T ss = (T)1.f + sd::math::nd4j_abs(x); + return y * ((T)1.0f / (ss * ss)); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} - PRAGMA_OMP_PARALLEL_FOR - for (int e = 0; e < length; e++) { - z[e] = x[e] > zero ? y[e] : zero; - } - */ - } +void softSignDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), softSignDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - void reluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +template +static void softPlusDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T p = sd::math::nd4j_pow(static_cast(M_E), x); + return y * (p / (p + 1.)); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} - template - static void relu6Derivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > (T)0.f && x < (T)6.f? y : T(0.f); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void relu6Derivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), relu6Derivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - static void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) { - - const T alphaT = static_cast(alpha); - - auto functor = LAMBDA_TT(x, y, alphaT) { - return x < 0 ? alphaT * y : y; - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void leakyReluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); - } - - template - static void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) { - - const T alphaT = static_cast(alpha); - - auto functor = LAMBDA_TT(x, y, alphaT){ - return y * sd::math::nd4j_eluderivative(x, alphaT); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void eluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); - } - - template - static void seluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * simdOps::SELUDerivative::op(x, nullptr); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void seluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), seluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - static void cubeDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * (3 * x * x); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void cubeDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), cubeDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - //return (x >= X(0.f) ? y: -y); - template - static void reduceNorm1_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > T(0.f)? y : -y; - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void reduceNorm1(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), reduceNorm1_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - //////////////////////////////////////////////////////////////////////// - template - static void sigmCrossEntropy_(NDArray* logits, NDArray* labels, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return sd::math::nd4j_max(x, (T)0.f) - x * y + sd::math::nd4j_log((T)1.f + sd::math::nd4j_exp(-sd::math::nd4j_abs(x))); - }; - - logits->applyPairwiseLambda(*labels, functor, *output); - } - - void sigmCrossEntropy(sd::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { - BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropy_, (logits, labels, output), FLOAT_TYPES); - } - - //////////////////////////////////////////////////////////////////////// - template - static void sigmCrossEntropyGrad_(NDArray* logits, NDArray* labels, NDArray* output) { - // 1 - labels - 1 / (1 + exp(logits)) - auto functor = LAMBDA_TT(x, y) { - if(x <= 0) - return static_cast(1.) - y - static_cast(1.) / (static_cast(1.) + sd::math::nd4j_exp(x)); - auto e = sd::math::nd4j_exp(-x); - return static_cast(1.) - y - e / (static_cast(1.) + e); - }; - - logits->applyPairwiseLambda(*labels, functor, *output); - } - - void sigmCrossEntropyGrad(sd::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { - BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropyGrad_, (logits, labels, output), FLOAT_TYPES); - } - - //////////////////////////////////////////////////////////////////////// - template - static void tanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T th = sd::math::nd4j_tanh(x); - return y * ((T)1.0f - (th * th)); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void tanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), tanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - // return static_cast(d2) * simdOps::HardTanhDerivative::op(d1, nullptr); - template - static void hardTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T th = sd::math::nd4j_tanh(x); - return y * simdOps::HardTanhDerivative::op(x, nullptr); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void hardTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - static void rationalTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * simdOps::RationalTanhDerivative::op(x, nullptr); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void rationalTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), rationalTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - static void rectifiedTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > (T) 0.0f ? y * (sd::math::nd4j_tanhderivative(x)) : (T) 0.0f; - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void rectifiedTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), rectifiedTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - // X f = (X) 1.0f + sd::math::nd4j_abs(d1); - // return (X) d2 * ((X) 1.0f / (f * f)); - - template - static void softSignDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T ss = (T)1.f + sd::math::nd4j_abs(x); - return y * ((T) 1.0f / (ss * ss)); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void softSignDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), softSignDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - static void softPlusDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T p = sd::math::nd4j_pow(static_cast(M_E), x); - return y * (p / (p + 1.)); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void softPlusDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), softPlusDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void softPlusDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), softPlusDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} /// /// \param theFirst /// \param theSecond /// \param theOutput - template - static void sigmoidDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T s = sd::math::nd4j_sigmoid(x); - return y * (s * ((T) 1.0f - s)); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void sigmoidDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), sigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - static void hardSigmoidDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * simdOps::HardSigmoidDerivative::op(x, nullptr); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void hardSigmoidDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardSigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - static void logSumExp_(NDArray* input, NDArray* axis, NDArray* output) { - // reduce along axis with - NDArray tempInput = input->dup(); - input->applyTransform(transform::Exp, tempInput); - std::vector axisVector; - if (axis != nullptr) { - axisVector.resize(axis->lengthOf()); - for (size_t i = 0; i < axisVector.size(); ++i) - axisVector[i] = axis->e(i); - } - tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); - output->applyTransform(transform::Log, *output); - } - - template - static void logSumExp_(NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { - // reduce along axis with - NDArray tempInput = input->dup(); - input->applyPairwiseTransform(pairwise::Subtract, *subtrah, tempInput); - tempInput.applyTransform(transform::Exp, tempInput); - - std::vector axisVector; - if (axis != nullptr) { - axisVector.resize(axis->lengthOf()); - for (size_t i = 0; i < axisVector.size(); ++i) - axisVector[i] = axis->e(i); - } - tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); - output->applyTransform(transform::Log, *output); - } - - void logSumExp(sd::LaunchContext * context, NDArray* input, NDArray* axis, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, axis, output), FLOAT_TYPES); - } - - void logSumExp(sd::LaunchContext * context, NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, subtrah, axis, output), FLOAT_TYPES); - } - -////////////////////////////////////////////////////////////////////////// template -static void weightedCrossEntropyWithLogitsFunctor_(NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output) { +static void sigmoidDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T s = sd::math::nd4j_sigmoid(x); + return y * (s * ((T)1.0f - s)); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} - T posWeight = weights->e(0); +void sigmoidDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), sigmoidDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - auto mainRoutineT1 = LAMBDA_TT(_x, _z, posWeight) { - T targetWeight = (1. + (posWeight - (T)1.f) * _z); - return (1. - _z) * _x + - targetWeight * (sd::math::nd4j_log((T)1.f + sd::math::nd4j_exp(-sd::math::nd4j_abs(_x))) + - sd::math::nd4j_max(-_x, T(0.f)) - ); - }; +template +static void hardSigmoidDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return y * simdOps::HardSigmoidDerivative::op(x, nullptr); + }; - auto mainRoutineT2 = LAMBDA_TTT(_x, _z, _w) { - return (((T)1.0 - _z) * _x) + - _w * (sd::math::nd4j_log(T(1.) + sd::math::nd4j_exp(-sd::math::nd4j_abs(_x))) + - sd::math::nd4j_max(-_x, T(0.f))); - }; + input->applyPairwiseLambda(*epsilon, functor, *output); +} +void hardSigmoidDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardSigmoidDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - if (weights->isScalar()) { - const_cast(input)->applyPairwiseLambda(const_cast(*targets), mainRoutineT1, *output); - } - else - { - std::unique_ptr targetVector(new NDArray(*weights)); - targetVector->applyScalar(scalar::Add, -1.f, *targetVector); +template +static void logSumExp_(NDArray* input, NDArray* axis, NDArray* output) { + // reduce along axis with + NDArray tempInput = input->dup(); + input->applyTransform(transform::Exp, tempInput); + std::vector axisVector; + if (axis != nullptr) { + axisVector.resize(axis->lengthOf()); + for (size_t i = 0; i < axisVector.size(); ++i) + axisVector[i] = axis->e(i); + } + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); +} - std::unique_ptr targetTensor(new NDArray(*targets)); - *targetTensor = (*targetVector * *targetTensor) + T(1.f); - const_cast(input)->applyTriplewiseLambda(const_cast(*targets), *targetTensor.get(), mainRoutineT2, *output); - } +template +static void logSumExp_(NDArray* input, NDArray* subtrah, NDArray* axis, + NDArray* output) { + // reduce along axis with + NDArray tempInput = input->dup(); + input->applyPairwiseTransform(pairwise::Subtract, *subtrah, tempInput); + tempInput.applyTransform(transform::Exp, tempInput); + + std::vector axisVector; + if (axis != nullptr) { + axisVector.resize(axis->lengthOf()); + for (size_t i = 0; i < axisVector.size(); ++i) + axisVector[i] = axis->e(i); + } + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); } -void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output) { - BUILD_SINGLE_SELECTOR(targets->dataType(), weightedCrossEntropyWithLogitsFunctor_, (targets, input, weights, output), FLOAT_TYPES); +void logSumExp(sd::LaunchContext* context, NDArray* input, NDArray* axis, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, axis, output), + FLOAT_TYPES); } +void logSumExp(sd::LaunchContext* context, NDArray* input, NDArray* subtrah, + NDArray* axis, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, + (input, subtrah, axis, output), FLOAT_TYPES); } + +////////////////////////////////////////////////////////////////////////// +template +static void weightedCrossEntropyWithLogitsFunctor_(NDArray const* targets, + NDArray const* input, + NDArray const* weights, + NDArray* output) { + T posWeight = weights->e(0); + + auto mainRoutineT1 = LAMBDA_TT(_x, _z, posWeight) { + T targetWeight = (1. + (posWeight - (T)1.f) * _z); + return (1. - _z) * _x + + targetWeight * (sd::math::nd4j_log( + (T)1.f + sd::math::nd4j_exp( + -sd::math::nd4j_abs(_x))) + + sd::math::nd4j_max(-_x, T(0.f))); + }; + + auto mainRoutineT2 = LAMBDA_TTT(_x, _z, _w) { + return (((T)1.0 - _z) * _x) + + _w * + (sd::math::nd4j_log( + T(1.) + sd::math::nd4j_exp(-sd::math::nd4j_abs(_x))) + + sd::math::nd4j_max(-_x, T(0.f))); + }; + + if (weights->isScalar()) { + const_cast(input)->applyPairwiseLambda( + const_cast(*targets), mainRoutineT1, *output); + } else { + std::unique_ptr targetVector(new NDArray(*weights)); + targetVector->applyScalar(scalar::Add, -1.f, *targetVector); + + std::unique_ptr targetTensor(new NDArray(*targets)); + *targetTensor = (*targetVector * *targetTensor) + T(1.f); + const_cast(input)->applyTriplewiseLambda( + const_cast(*targets), *targetTensor.get(), mainRoutineT2, + *output); + } } -} \ No newline at end of file + +void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext* context, + NDArray const* targets, + NDArray const* input, + NDArray const* weights, + NDArray* output) { + BUILD_SINGLE_SELECTOR(targets->dataType(), + weightedCrossEntropyWithLogitsFunctor_, + (targets, input, weights, output), FLOAT_TYPES); +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lgamma.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lgamma.cpp index 3b71f7ce9742..b6d6b66b2a0f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lgamma.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lgamma.cpp @@ -19,8 +19,8 @@ // @author George A. Shulinok // -#include #include +#include namespace sd { namespace ops { @@ -30,24 +30,23 @@ namespace helpers { // calculate digamma function for array elements template static void lgamma_(NDArray& x, NDArray& z) { - - auto lgammaProc = LAMBDA_T(x_) { - return T(DataTypeUtils::fromT() == DataType::DOUBLE?::lgamma(x_): ::lgammaf(x_)); //math::nd4j_log(math::nd4j_gamma(x)); - }; - - x.applyLambda(lgammaProc, z); + auto lgammaProc = LAMBDA_T(x_) { + return T( + DataTypeUtils::fromT() == DataType::DOUBLE + ? ::lgamma(x_) + : ::lgammaf(x_)); // math::nd4j_log(math::nd4j_gamma(x)); + }; + + x.applyLambda(lgammaProc, z); } void lgamma(sd::LaunchContext* context, NDArray& x, NDArray& z) { - - BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray& x, NDArray& z), FLOAT_TYPES); - - - -} -} -} +BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray & x, NDArray& z), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp index 8dc31d8c00ad..aa9591f9c298 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp @@ -19,314 +19,348 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include #include -#include +#include namespace sd { namespace ops { namespace helpers { template -static int lrnFunctor_(sd::graph::Context& block, NDArray* input, NDArray* output, int depth, float bias, float alpha, float beta) { - - nd4j_debug("MKL-DNN is not used for lrn!\n", 0); - - const int rank = input->rankOf(); - - TadPack inTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), {rank - 1}); - TadPack outTadPack; - - if(shape::haveSameShapeAndStrides(input->shapeInfo(), output->shapeInfo())) - outTadPack = inTadPack; - else - outTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {rank - 1}); - - const Nd4jLong numOfTads = inTadPack.numberOfTads(); - const Nd4jLong tadLen = input->sizeAt(-1); - - const Nd4jLong* inTadOffsets = inTadPack.primaryOffsets(); - const Nd4jLong* outTadOffsets = outTadPack.primaryOffsets(); - - const Nd4jLong inTadEws = shape::elementWiseStride(inTadPack.primaryShapeInfo()); - const Nd4jLong outTadEws = shape::elementWiseStride(outTadPack.primaryShapeInfo()); - - const T* inBuff = reinterpret_cast(input->buffer()); - T* outBuff = reinterpret_cast(output->buffer()); - - const T tbias = static_cast(bias); - const T tbeta = static_cast(beta); - const T talpha = static_cast(alpha); - - if(inTadEws == 1 && outTadEws == 1) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - const T *x = inBuff + inTadOffsets[i]; - T *y = outBuff + outTadOffsets[i]; - - T prev = 0; - - // calculate squared sum of elements per each j-th element range [j - depth, j + depth + 1] - // we store each squared sum in corresponding element of y array - for (Nd4jLong j = 0; j < tadLen; ++j) { - const uint begin = sd::math::nd4j_max(0, j - depth); - const uint last = depth + j + 1; - const uint end = sd::math::nd4j_min(last, tadLen); - - if (j == 0) { - for (uint s = begin; s < end; ++s) - prev = prev + x[s] * x[s]; - y[j] = prev; - } else if (begin == 0 && last <= tadLen) - y[j] = prev + x[end - 1] * x[end - 1]; - else if (begin > 0 && last <= tadLen) - y[j] = prev + x[end - 1] * x[end - 1] - x[begin - 1] * x[begin - 1]; - else if (begin > 0 && last > tadLen) - y[j] = prev - x[begin - 1] * x[begin - 1]; - else - y[j] = prev; - - if (j != 0) - prev = y[j]; - - y[j] = x[j] / sd::math::nd4j_pow(tbias + alpha * prev, tbeta); - } - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfTads); - } - else { - auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < numOfTads; ++i) { - const T *x = inBuff + inTadOffsets[i]; - T *y = outBuff + outTadOffsets[i]; - - T prev = 0; - - // calculate squared sum of elements per each j-th element range [j - depth, j + depth + 1] - // we store each squared sum in corresponding element of y array - for (Nd4jLong j = 0; j < tadLen; ++j) { - const uint begin = sd::math::nd4j_max(0, j - depth); - const uint last = depth + j + 1; - const uint end = sd::math::nd4j_min(last, tadLen); - - if (j == 0) { - for (uint s = begin; s < end; ++s) - prev = prev + x[s * inTadEws] * x[s * inTadEws]; - y[j * outTadEws] = prev; - } else if (begin == 0 && last <= tadLen) - y[j * outTadEws] = prev + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws]; - else if (begin > 0 && last <= tadLen) - y[j * outTadEws] = prev + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws] - x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; - else if (begin > 0 && last > tadLen) - y[j * outTadEws] = prev - x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; - else - y[j * outTadEws] = prev; - - if (j != 0) - prev = y[j * outTadEws]; - - y[j * outTadEws] = x[j * inTadEws] / sd::math::nd4j_pow(tbias + alpha * prev, tbeta); - } - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfTads); - } - return Status::OK(); +static int lrnFunctor_(sd::graph::Context& block, NDArray* input, + NDArray* output, int depth, float bias, float alpha, + float beta) { + nd4j_debug("MKL-DNN is not used for lrn!\n", 0); + + const int rank = input->rankOf(); + + TadPack inTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), {rank - 1}); + TadPack outTadPack; + + if (shape::haveSameShapeAndStrides(input->shapeInfo(), output->shapeInfo())) + outTadPack = inTadPack; + else + outTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), {rank - 1}); + + const Nd4jLong numOfTads = inTadPack.numberOfTads(); + const Nd4jLong tadLen = input->sizeAt(-1); + + const Nd4jLong* inTadOffsets = inTadPack.primaryOffsets(); + const Nd4jLong* outTadOffsets = outTadPack.primaryOffsets(); + + const Nd4jLong inTadEws = + shape::elementWiseStride(inTadPack.primaryShapeInfo()); + const Nd4jLong outTadEws = + shape::elementWiseStride(outTadPack.primaryShapeInfo()); + + const T* inBuff = reinterpret_cast(input->buffer()); + T* outBuff = reinterpret_cast(output->buffer()); + + const T tbias = static_cast(bias); + const T tbeta = static_cast(beta); + const T talpha = static_cast(alpha); + + if (inTadEws == 1 && outTadEws == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const T* x = inBuff + inTadOffsets[i]; + T* y = outBuff + outTadOffsets[i]; + + T prev = 0; + + // calculate squared sum of elements per each j-th element range [j - + // depth, j + depth + 1] we store each squared sum in corresponding + // element of y array + for (Nd4jLong j = 0; j < tadLen; ++j) { + const uint begin = sd::math::nd4j_max(0, j - depth); + const uint last = depth + j + 1; + const uint end = sd::math::nd4j_min(last, tadLen); + + if (j == 0) { + for (uint s = begin; s < end; ++s) prev = prev + x[s] * x[s]; + y[j] = prev; + } else if (begin == 0 && last <= tadLen) + y[j] = prev + x[end - 1] * x[end - 1]; + else if (begin > 0 && last <= tadLen) + y[j] = prev + x[end - 1] * x[end - 1] - x[begin - 1] * x[begin - 1]; + else if (begin > 0 && last > tadLen) + y[j] = prev - x[begin - 1] * x[begin - 1]; + else + y[j] = prev; + + if (j != 0) prev = y[j]; + + y[j] = + x[j] / sd::math::nd4j_pow(tbias + alpha * prev, tbeta); + } + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } else { + auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < numOfTads; ++i) { + const T* x = inBuff + inTadOffsets[i]; + T* y = outBuff + outTadOffsets[i]; + + T prev = 0; + + // calculate squared sum of elements per each j-th element range [j - + // depth, j + depth + 1] we store each squared sum in corresponding + // element of y array + for (Nd4jLong j = 0; j < tadLen; ++j) { + const uint begin = sd::math::nd4j_max(0, j - depth); + const uint last = depth + j + 1; + const uint end = sd::math::nd4j_min(last, tadLen); + + if (j == 0) { + for (uint s = begin; s < end; ++s) + prev = prev + x[s * inTadEws] * x[s * inTadEws]; + y[j * outTadEws] = prev; + } else if (begin == 0 && last <= tadLen) + y[j * outTadEws] = + prev + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws]; + else if (begin > 0 && last <= tadLen) + y[j * outTadEws] = + prev + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws] - + x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; + else if (begin > 0 && last > tadLen) + y[j * outTadEws] = + prev - x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; + else + y[j * outTadEws] = prev; + + if (j != 0) prev = y[j * outTadEws]; + + y[j * outTadEws] = x[j * inTadEws] / sd::math::nd4j_pow( + tbias + alpha * prev, tbeta); + } + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } + return Status::OK(); } - -BUILD_SINGLE_TEMPLATE(template int lrnFunctor_, (sd::graph::Context& block, NDArray* input, NDArray* output, int depth, float bias, float alpha, float beta), FLOAT_TYPES); -int lrnFunctor(sd::graph::Context& block, NDArray* input, NDArray* output, int depth, double bias, double alpha, double beta) { - BUILD_SINGLE_SELECTOR(input->dataType(), return lrnFunctor_, (block, input, output, depth, bias, alpha, beta), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template int lrnFunctor_, + (sd::graph::Context & block, NDArray* input, + NDArray* output, int depth, float bias, float alpha, + float beta), + FLOAT_TYPES); + +int lrnFunctor(sd::graph::Context& block, NDArray* input, NDArray* output, + int depth, double bias, double alpha, double beta) { + BUILD_SINGLE_SELECTOR(input->dataType(), return lrnFunctor_, + (block, input, output, depth, bias, alpha, beta), + FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// template -static void lrnBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta) { - - const int rank = input.rankOf(); - - TadPack inTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), {rank - 1}); - TadPack gradITadPack; - - if(shape::haveSameShapeAndStrides(input.shapeInfo(), gradI.shapeInfo())) - gradITadPack = inTadPack; - else - gradITadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(gradI.shapeInfo(), {rank - 1}); - - const Nd4jLong numOfTads = inTadPack.numberOfTads(); - const Nd4jLong tadLen = input.sizeAt(-1); - - const Nd4jLong* inTadOffsets = inTadPack.primaryOffsets(); - const Nd4jLong* gradITadOffsets = gradITadPack.primaryOffsets(); - - const Nd4jLong inTadEws = shape::elementWiseStride(inTadPack.primaryShapeInfo()); - const Nd4jLong gradITadEws = shape::elementWiseStride(gradITadPack.primaryShapeInfo()); - - const X* inBuff = reinterpret_cast(input.buffer()); - Y* gradIBuff = reinterpret_cast(gradI.buffer()); - - const Y tbias = static_cast(bias); - const Y tbeta = static_cast(beta); - const Y talpha = static_cast(alpha); - const Y coeff = talpha * tbeta; - - if(inTadEws == 1 && gradITadEws == 1) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - const X *x = inBuff + inTadOffsets[i]; - Y *y = gradIBuff + gradITadOffsets[i]; - - // this loop calculates squared sum of elements per each j-th element range [j - depth, j + depth + 1] - // we store each squared sum in corresponding element of y array - for (Nd4jLong j = 0; j < tadLen; ++j) { - const uint begin = sd::math::nd4j_max(0, j - depth); - const uint last = depth + j + 1; - const uint end = sd::math::nd4j_min(last, tadLen); - - if (j == 0) { - y[0] = 0; - for (uint s = begin; s < end; ++s) - y[0] = y[0] + x[s] * x[s]; - } else if (begin == 0 && last <= tadLen) - y[j] = y[j - 1] + x[end - 1] * x[end - 1]; - else if (begin > 0 && last <= tadLen) - y[j] = y[j - 1] + x[end - 1] * x[end - 1] - x[begin - 1] * x[begin - 1]; - else if (begin > 0 && last > tadLen) - y[j] = y[j - 1] - x[begin - 1] * x[begin - 1]; - else - y[j] = y[j - 1]; - } - - Y *factor = new Y[tadLen]; - - Y prev = 0; - // second loop calculates derivatives using information gained in first loop above - for (Nd4jLong j = 0; j < tadLen; ++j) { - const uint begin = sd::math::nd4j_max(0, j - depth); - const uint last = depth + j + 1; - const uint end = sd::math::nd4j_min(last, tadLen); - - Y init = tbias + talpha * y[j]; - - if (j == 0) { - for (uint s = begin; s < end; ++s) { - factor[s] = sd::math::nd4j_pow(tbias + talpha * y[s], -tbeta - 1); - prev = prev + x[s] * factor[s]; - } - y[0] = prev; - } else if (begin == 0 && last <= tadLen) { - factor[end - 1] = sd::math::nd4j_pow(tbias + talpha * y[end - 1], -tbeta - 1); - y[j] = prev + x[end - 1] * factor[end - 1]; - } else if (begin > 0 && last <= tadLen) { - factor[end - 1] = sd::math::nd4j_pow(tbias + talpha * y[end - 1], -tbeta - 1); - y[j] = prev + x[end - 1] * factor[end - 1] - x[begin - 1] * factor[begin - 1]; - } else if (begin > 0 && last > tadLen) - y[j] = prev - x[begin - 1] * factor[begin - 1]; - else - y[j] = prev; - - if (j != 0) - prev = y[j]; - - y[j] = factor[j] * init - 2 * x[j] * coeff * prev; - } - - delete[]factor; +static void lrnBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI, + const int depth, const float bias, const float alpha, + const float beta) { + const int rank = input.rankOf(); + + TadPack inTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), {rank - 1}); + TadPack gradITadPack; + + if (shape::haveSameShapeAndStrides(input.shapeInfo(), gradI.shapeInfo())) + gradITadPack = inTadPack; + else + gradITadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + gradI.shapeInfo(), {rank - 1}); + + const Nd4jLong numOfTads = inTadPack.numberOfTads(); + const Nd4jLong tadLen = input.sizeAt(-1); + + const Nd4jLong* inTadOffsets = inTadPack.primaryOffsets(); + const Nd4jLong* gradITadOffsets = gradITadPack.primaryOffsets(); + + const Nd4jLong inTadEws = + shape::elementWiseStride(inTadPack.primaryShapeInfo()); + const Nd4jLong gradITadEws = + shape::elementWiseStride(gradITadPack.primaryShapeInfo()); + + const X* inBuff = reinterpret_cast(input.buffer()); + Y* gradIBuff = reinterpret_cast(gradI.buffer()); + + const Y tbias = static_cast(bias); + const Y tbeta = static_cast(beta); + const Y talpha = static_cast(alpha); + const Y coeff = talpha * tbeta; + + if (inTadEws == 1 && gradITadEws == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const X* x = inBuff + inTadOffsets[i]; + Y* y = gradIBuff + gradITadOffsets[i]; + + // this loop calculates squared sum of elements per each j-th element + // range [j - depth, j + depth + 1] we store each squared sum in + // corresponding element of y array + for (Nd4jLong j = 0; j < tadLen; ++j) { + const uint begin = sd::math::nd4j_max(0, j - depth); + const uint last = depth + j + 1; + const uint end = sd::math::nd4j_min(last, tadLen); + + if (j == 0) { + y[0] = 0; + for (uint s = begin; s < end; ++s) y[0] = y[0] + x[s] * x[s]; + } else if (begin == 0 && last <= tadLen) + y[j] = y[j - 1] + x[end - 1] * x[end - 1]; + else if (begin > 0 && last <= tadLen) + y[j] = y[j - 1] + x[end - 1] * x[end - 1] - + x[begin - 1] * x[begin - 1]; + else if (begin > 0 && last > tadLen) + y[j] = y[j - 1] - x[begin - 1] * x[begin - 1]; + else + y[j] = y[j - 1]; + } + + Y* factor = new Y[tadLen]; + + Y prev = 0; + // second loop calculates derivatives using information gained in first + // loop above + for (Nd4jLong j = 0; j < tadLen; ++j) { + const uint begin = sd::math::nd4j_max(0, j - depth); + const uint last = depth + j + 1; + const uint end = sd::math::nd4j_min(last, tadLen); + + Y init = tbias + talpha * y[j]; + + if (j == 0) { + for (uint s = begin; s < end; ++s) { + factor[s] = sd::math::nd4j_pow(tbias + talpha * y[s], + -tbeta - 1); + prev = prev + x[s] * factor[s]; } - }; - - samediff::Threads::parallel_tad(func, 0, numOfTads); - } - else { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - const X *x = inBuff + inTadOffsets[i]; - Y *y = gradIBuff + gradITadOffsets[i]; - - // this loop calculates squared sum of elements per each j-th element range [j - depth, j + depth + 1] - // we store each squared sum in corresponding element of y array - for (Nd4jLong j = 0; j < tadLen; ++j) { - const uint begin = sd::math::nd4j_max(0, j - depth); - const uint last = depth + j + 1; - const uint end = sd::math::nd4j_min(last, tadLen); - - if (j == 0) { - y[0] = 0; - for (uint s = begin; s < end; ++s) - y[0] = y[0] + x[s * inTadEws] * x[s * inTadEws]; - } else if (begin == 0 && last <= tadLen) - y[j * gradITadEws] = - y[(j - 1) * gradITadEws] + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws]; - else if (begin > 0 && last <= tadLen) - y[j * gradITadEws] = - y[(j - 1) * gradITadEws] + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws] - - x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; - else if (begin > 0 && last > tadLen) - y[j * gradITadEws] = - y[(j - 1) * gradITadEws] - x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; - else - y[j * gradITadEws] = y[(j - 1) * gradITadEws]; - } - - Y *factor = new Y[tadLen]; - - Y prev = 0; - // second loop calculates derivatives using information gained in first loop above - for (Nd4jLong j = 0; j < tadLen; ++j) { - const uint begin = sd::math::nd4j_max(0, j - depth); - const uint last = depth + j + 1; - const uint end = sd::math::nd4j_min(last, tadLen); - - Y init = tbias + talpha * y[j * gradITadEws]; - - if (j == 0) { - for (uint s = begin; s < end; ++s) { - factor[s] = sd::math::nd4j_pow(tbias + talpha * y[s * gradITadEws], -tbeta - 1); - prev = prev + x[s * inTadEws] * factor[s]; - } - y[0] = prev; - } else if (begin == 0 && last <= tadLen) { - factor[end - 1] = sd::math::nd4j_pow(tbias + talpha * y[(end - 1) * gradITadEws], - -tbeta - 1); - y[j * gradITadEws] = prev + x[(end - 1) * inTadEws] * factor[end - 1]; - } else if (begin > 0 && last <= tadLen) { - factor[end - 1] = sd::math::nd4j_pow(tbias + talpha * y[(end - 1) * gradITadEws], - -tbeta - 1); - y[j * gradITadEws] = prev + x[(end - 1) * inTadEws] * factor[end - 1] - - x[(begin - 1) * inTadEws] * factor[begin - 1]; - } else if (begin > 0 && last > tadLen) - y[j * gradITadEws] = prev - x[(begin - 1) * inTadEws] * factor[begin - 1]; - else - y[j * gradITadEws] = prev; - - if (j != 0) - prev = y[j * gradITadEws]; - - y[j * gradITadEws] = factor[j] * init - 2 * x[j * inTadEws] * coeff * prev; - } - - delete[]factor; + y[0] = prev; + } else if (begin == 0 && last <= tadLen) { + factor[end - 1] = sd::math::nd4j_pow( + tbias + talpha * y[end - 1], -tbeta - 1); + y[j] = prev + x[end - 1] * factor[end - 1]; + } else if (begin > 0 && last <= tadLen) { + factor[end - 1] = sd::math::nd4j_pow( + tbias + talpha * y[end - 1], -tbeta - 1); + y[j] = prev + x[end - 1] * factor[end - 1] - + x[begin - 1] * factor[begin - 1]; + } else if (begin > 0 && last > tadLen) + y[j] = prev - x[begin - 1] * factor[begin - 1]; + else + y[j] = prev; + + if (j != 0) prev = y[j]; + + y[j] = factor[j] * init - 2 * x[j] * coeff * prev; + } + + delete[] factor; + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const X* x = inBuff + inTadOffsets[i]; + Y* y = gradIBuff + gradITadOffsets[i]; + + // this loop calculates squared sum of elements per each j-th element + // range [j - depth, j + depth + 1] we store each squared sum in + // corresponding element of y array + for (Nd4jLong j = 0; j < tadLen; ++j) { + const uint begin = sd::math::nd4j_max(0, j - depth); + const uint last = depth + j + 1; + const uint end = sd::math::nd4j_min(last, tadLen); + + if (j == 0) { + y[0] = 0; + for (uint s = begin; s < end; ++s) + y[0] = y[0] + x[s * inTadEws] * x[s * inTadEws]; + } else if (begin == 0 && last <= tadLen) + y[j * gradITadEws] = + y[(j - 1) * gradITadEws] + + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws]; + else if (begin > 0 && last <= tadLen) + y[j * gradITadEws] = + y[(j - 1) * gradITadEws] + + x[(end - 1) * inTadEws] * x[(end - 1) * inTadEws] - + x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; + else if (begin > 0 && last > tadLen) + y[j * gradITadEws] = + y[(j - 1) * gradITadEws] - + x[(begin - 1) * inTadEws] * x[(begin - 1) * inTadEws]; + else + y[j * gradITadEws] = y[(j - 1) * gradITadEws]; + } + + Y* factor = new Y[tadLen]; + + Y prev = 0; + // second loop calculates derivatives using information gained in first + // loop above + for (Nd4jLong j = 0; j < tadLen; ++j) { + const uint begin = sd::math::nd4j_max(0, j - depth); + const uint last = depth + j + 1; + const uint end = sd::math::nd4j_min(last, tadLen); + + Y init = tbias + talpha * y[j * gradITadEws]; + + if (j == 0) { + for (uint s = begin; s < end; ++s) { + factor[s] = sd::math::nd4j_pow( + tbias + talpha * y[s * gradITadEws], -tbeta - 1); + prev = prev + x[s * inTadEws] * factor[s]; } - }; - - samediff::Threads::parallel_tad(func, 0, numOfTads); - } - gradI *= gradO; + y[0] = prev; + } else if (begin == 0 && last <= tadLen) { + factor[end - 1] = sd::math::nd4j_pow( + tbias + talpha * y[(end - 1) * gradITadEws], -tbeta - 1); + y[j * gradITadEws] = + prev + x[(end - 1) * inTadEws] * factor[end - 1]; + } else if (begin > 0 && last <= tadLen) { + factor[end - 1] = sd::math::nd4j_pow( + tbias + talpha * y[(end - 1) * gradITadEws], -tbeta - 1); + y[j * gradITadEws] = prev + + x[(end - 1) * inTadEws] * factor[end - 1] - + x[(begin - 1) * inTadEws] * factor[begin - 1]; + } else if (begin > 0 && last > tadLen) + y[j * gradITadEws] = + prev - x[(begin - 1) * inTadEws] * factor[begin - 1]; + else + y[j * gradITadEws] = prev; + + if (j != 0) prev = y[j * gradITadEws]; + + y[j * gradITadEws] = + factor[j] * init - 2 * x[j * inTadEws] * coeff * prev; + } + + delete[] factor; + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } + gradI *= gradO; } - -void lrnBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta) { - BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, (input, gradO, gradI, depth, bias, alpha, beta), FLOAT_TYPES, FLOAT_TYPES); +void lrnBP(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int depth, + const float bias, const float alpha, const float beta) { + BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, + (input, gradO, gradI, depth, bias, alpha, beta), + FLOAT_TYPES, FLOAT_TYPES); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp index 02d4c985538e..40bc335aff59 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp @@ -20,227 +20,262 @@ // implementation of operation for LSTM cell with peep hole connections: // http://www.bioinf.jku.at/publications/older/2604.pdf -// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. -// and +// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural +// Computation, 9(8):1735-1780, 1997. and // https://research.google.com/pubs/archive/43905.pdf -// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. +// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory +// recurrent neural network architectures for large scale acoustic modeling." +// INTERSPEECH, 2014. - -#include +#include +#include #include +#include #include -#include #include -#include +#include +#include + #include -#include -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// -void lstmCell(sd::LaunchContext * context, const NDArray* xt, const NDArray* ht_1, const NDArray* ct_1, const NDArray* Wx, const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, const NDArray* b, - NDArray* ht, NDArray* ct, const std::vector& params) { - - // xt input [bS x nIn] - // ht_1 previous cell output [bS x numProj], that is at previous time step t-1, in case of projection=false -> numProj=nOut!!! - // ct_1 previous cell state [bS x nOut], that is at previous time step t-1 - - // Wx input-to-hidden weights, [nIn x 4*nOut] - // Wh hidden-to-hidden weights, [numProj x 4*nOut] - // Wc diagonal weights for peephole connections [3*nOut] - // Wp projection weights [nOut x numProj] - // b biases, [4*nOut] - - // ht current cell output [bS x numProj], that is at current time step t - // ct current cell state [bS x nOut], that is at current time step t - - const bool peephole = (bool)params[0]; // if true, provide peephole connections - const bool projection = (bool)params[1]; // if true, then projection is performed, if false then numProj==nOut is mandatory!!!! - double clippingCellValue = params[2]; // clipping value for ct, if it is not equal to zero, then cell state is clipped - double clippingProjValue = params[3]; // clipping value for projected ht, if it is not equal to zero, then projected cell output is clipped - const double forgetBias = params[4]; - - const int bS = xt->sizeAt(0); - const int nIn = xt->sizeAt(1); - const int numProj = ht_1->sizeAt(1); - const int nOut = ct_1->sizeAt(1); - - auto z = mmul(*xt, *Wx) + mmul(*ht_1, *Wh) + *b; // [bS x 4*nOut] + [bS x 4*nOut] + [1 x 4*nOut] = [bS x 4*nOut] - - auto zit = z({0,0, 0,nOut}); // z for input gate, = mmul(Wxi,xt) + mmul(Whi,ht_1) + bi = [bS x nOut] - auto zft = z({0,0, nOut,2*nOut}); // z for forget gate, = mmul(Wxf,xt) + mmul(Whf,ht_1) + bf = [bS x nOut] - auto zct = z({0,0, 2*nOut,3*nOut}); // z for cell state, = mmul(Wxc,xt) + mmul(Whc,ht_1) + bc = [bS x nOut] - auto zot = z({0,0, 3*nOut,4*nOut}); // z for output gate, = mmul(Wxo,xt) + mmul(Who,ht_1) + bo = [bS x nOut] - - if(peephole) { // add peephole connections: z + ct_1*Wc - zit += (*ct_1) * (*Wc)({0, nOut}); // add peephole connections to input gate - zft += (*ct_1) * (*Wc)({nOut, 2*nOut}); // add peephole connections to forget gate - } - - // current sell state = ft*ct_1 + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc - ct->assign( sigmoid(zft + forgetBias) * (*ct_1) + sigmoid(zit) * tanh(zct) ); - - // if clipping value is provided then cell state is clipped by this value prior to the cell output activation - if(clippingCellValue > 0.0) - ct->applyScalar(scalar::LstmClip, clippingCellValue, *ct); - - if(peephole) - zot += (*ct) * (*Wc)({{2*nOut, 3*nOut}}); // add peephole connections to output gate zot + ct*Wc - - // current cell output = ot*tanh(ct) - auto htNoPeepHole = sigmoid(zot) * tanh(*ct); // = [bS x nOut] - - // apply projection - if(projection) { - ht->assign( mmul(htNoPeepHole, *Wp) ); // [bS x nOut] * [ nOut x numProj] = [bS x numProj] - // if clipping projection is provided then projected cell output state is clipped by this value - if(clippingProjValue != 0.) - ht->applyScalar(scalar::LstmClip, clippingProjValue, *ht); - } - else - ht->assign(&htNoPeepHole); +void lstmCell(sd::LaunchContext* context, const NDArray* xt, + const NDArray* ht_1, const NDArray* ct_1, const NDArray* Wx, + const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, + const NDArray* b, NDArray* ht, NDArray* ct, + const std::vector& params) { + // xt input [bS x nIn] + // ht_1 previous cell output [bS x numProj], that is at previous time step + // t-1, in case of projection=false -> numProj=nOut!!! ct_1 previous cell + // state [bS x nOut], that is at previous time step t-1 + + // Wx input-to-hidden weights, [nIn x 4*nOut] + // Wh hidden-to-hidden weights, [numProj x 4*nOut] + // Wc diagonal weights for peephole connections [3*nOut] + // Wp projection weights [nOut x numProj] + // b biases, [4*nOut] + + // ht current cell output [bS x numProj], that is at current time step t + // ct current cell state [bS x nOut], that is at current time step t + + const bool peephole = + (bool)params[0]; // if true, provide peephole connections + const bool projection = + (bool)params[1]; // if true, then projection is performed, if false then + // numProj==nOut is mandatory!!!! + double clippingCellValue = + params[2]; // clipping value for ct, if it is not equal to zero, then + // cell state is clipped + double clippingProjValue = + params[3]; // clipping value for projected ht, if it is not equal to + // zero, then projected cell output is clipped + const double forgetBias = params[4]; + + const int bS = xt->sizeAt(0); + const int nIn = xt->sizeAt(1); + const int numProj = ht_1->sizeAt(1); + const int nOut = ct_1->sizeAt(1); + + auto z = mmul(*xt, *Wx) + mmul(*ht_1, *Wh) + + *b; // [bS x 4*nOut] + [bS x 4*nOut] + [1 x 4*nOut] = [bS x 4*nOut] + + auto zit = z({0, 0, 0, nOut}); // z for input gate, = mmul(Wxi,xt) + + // mmul(Whi,ht_1) + bi = [bS x nOut] + auto zft = z({0, 0, nOut, 2 * nOut}); // z for forget gate, = mmul(Wxf,xt) + + // mmul(Whf,ht_1) + bf = [bS x nOut] + auto zct = + z({0, 0, 2 * nOut, 3 * nOut}); // z for cell state, = mmul(Wxc,xt) + + // mmul(Whc,ht_1) + bc = [bS x nOut] + auto zot = + z({0, 0, 3 * nOut, 4 * nOut}); // z for output gate, = mmul(Wxo,xt) + + // mmul(Who,ht_1) + bo = [bS x nOut] + + if (peephole) { // add peephole connections: z + ct_1*Wc + zit += + (*ct_1) * (*Wc)({0, nOut}); // add peephole connections to input gate + zft += (*ct_1) * + (*Wc)({nOut, 2 * nOut}); // add peephole connections to forget gate + } + + // current sell state = ft*ct_1 + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc + ct->assign(sigmoid(zft + forgetBias) * (*ct_1) + sigmoid(zit) * tanh(zct)); + + // if clipping value is provided then cell state is clipped by this value + // prior to the cell output activation + if (clippingCellValue > 0.0) + ct->applyScalar(scalar::LstmClip, clippingCellValue, *ct); + + if (peephole) + zot += (*ct) * (*Wc)({{2 * nOut, 3 * nOut}}); // add peephole connections + // to output gate zot + ct*Wc + + // current cell output = ot*tanh(ct) + auto htNoPeepHole = sigmoid(zot) * tanh(*ct); // = [bS x nOut] + + // apply projection + if (projection) { + ht->assign(mmul(htNoPeepHole, + *Wp)); // [bS x nOut] * [ nOut x numProj] = [bS x numProj] + // if clipping projection is provided then projected cell output state is + // clipped by this value + if (clippingProjValue != 0.) + ht->applyScalar(scalar::LstmClip, clippingProjValue, *ht); + } else + ht->assign(&htNoPeepHole); } template -static void fusedTanh(NDArray *z, NDArray *i, NDArray *c, const NDArray *cLast, NDArray *f, NDArray *h) { - //cell state = blockInput .* inputGate + prevCellState .* forgetGate - /* - z->applyPairwiseTransform(pairwise::Multiply, i, c, nullptr); //c = z * i - auto temp = (*f) * (*cLast); - *c += temp; //c = (i * z) + (zf * (*cLast)) - c->applyTransform(transform::Tanh, h); //h = tanh(c) - */ - - auto uLen = static_cast(z->lengthOf()); - auto c_ = c->bufferAsT(); - auto z_ = z->bufferAsT(); - auto i_ = i->bufferAsT(); - auto f_ = f->bufferAsT(); - auto cLast_ = cLast->bufferAsT(); - auto h_ = h->bufferAsT(); - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - c_[e] = z_[e] * i_[e] + (f_[e] * cLast_[e]); - h_[e] = sd::math::nd4j_tanh(c_[e]); - } - }; - - samediff::Threads::parallel_for(func, 0, uLen); -} - -////////////////////////////////////////////////////////////////////////// - -void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast, - const NDArray* W, const NDArray* Wci, const NDArray* Wcf, const NDArray* Wco, const NDArray* b, - NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, NDArray* h, NDArray* y, const std::vector& params) { - - /* Input arrays: - * 0: xt - input [bS, nIn] at time t - * 1: cLast (cs_prev) - previous cell state [bS, nOut], time t-1 - * 2: yLast (h_prev) - previous output [bS, nOut], time t-1 - * 3: W - Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(nIn+nOut), 4*nOut] - * 4: Wci - weights - cell peephole (t-1) connections to input modulation gate, [nOut] - * 5: Wcf - weights - cell peephole (t-1) connections to forget gate, [nOut] - * 6: Wco - weights - cell peephole (t) connections to output gate, [nOut] - * 7: b - biases, [4*nOut] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * - * Input float arguments: - * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * - * Output arrays: - * 0: i - Input modulation gate activations [bS, nOut] - * 1: c (cs) - Cell state (pre tanh) [bs, nOut] (cs) - * 2: f - Output - forget gate activations [bs, nOut] - * 3: o - Output - output gate activations [bs, nOut] - * 4: z (ci) - Output - block input [bs, nOut] - * 5: h (co) - Cell state, post tanh [bs, nOut] - * 6: y (h) - Current cell output [bS, nOut], time t - */ - const bool peephole = (bool)params[0]; // if true, provide peephole connections - const double forgetBias = params[1]; - const double clippingCellValue = params[2]; // clipping value for ct, if it is not equal to zero, then cell state is clipped - - const int bS = xt->sizeAt(0); - const int nIn = xt->sizeAt(1); - const int nOut = cLast->sizeAt(1); - - //Concat inputs: [xt, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)] - NDArray concatOut(xt->ordering(), {xt->sizeAt(0), xt->sizeAt(1) + yLast->sizeAt(1)}, xt->dataType(), xt->getContext()); - helpers::concat(xt->getContext(), {const_cast(xt), const_cast(yLast)}, concatOut, {1}); - - auto m = mmul(concatOut, *W); // mmul: [bs, (nIn+nOut)] * [(nIn+nOut), 4*nOut] = [bs, 4*nOut] - m += (*b); // addiRowVector - - //Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o]) - auto zi = m({0,0, 0, nOut}); // z for input modulation gate, [bS, nOut] - auto zz = m({0,0, nOut, 2*nOut}); // z for block input, [bS, nOut] - auto zf = m({0,0, 2*nOut, 3*nOut}); // z for forget gate, [bS, nOut] - auto zo = m({0,0, 3*nOut, 4*nOut}); // z for output gate, [bS, nOut] - - if(peephole) { // add peephole connections: z + ct_1*Wc - zi += (*cLast) * (*Wci); // add peephole connections to input gate - zf += (*cLast) * (*Wcf); // add peephole connections to forget gate - } - - // current sell state = ft*cLast + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc - if(forgetBias != 0.0) - zf += forgetBias; - - PRAGMA_OMP_PARALLEL - PRAGMA_OMP_SINGLE - { - PRAGMA_OMP_TASK - zz.applyTransform(transform::Tanh, *z); //z = tanh(zz) - - PRAGMA_OMP_TASK - zi.applyTransform(transform::Sigmoid, *i); //i = sigmoid(zi) - - PRAGMA_OMP_TASK - zf.applyTransform(transform::Sigmoid, *f); //f = sigmoid(zf); +static void fusedTanh(NDArray* z, NDArray* i, NDArray* c, const NDArray* cLast, + NDArray* f, NDArray* h) { + // cell state = blockInput .* inputGate + prevCellState .* forgetGate + /* + z->applyPairwiseTransform(pairwise::Multiply, i, c, nullptr); //c = z * + i auto temp = (*f) * (*cLast); *c += temp; //c = + (i * z) + (zf * (*cLast)) c->applyTransform(transform::Tanh, h); //h = + tanh(c) + */ + + auto uLen = static_cast(z->lengthOf()); + auto c_ = c->bufferAsT(); + auto z_ = z->bufferAsT(); + auto i_ = i->bufferAsT(); + auto f_ = f->bufferAsT(); + auto cLast_ = cLast->bufferAsT(); + auto h_ = h->bufferAsT(); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + c_[e] = z_[e] * i_[e] + (f_[e] * cLast_[e]); + h_[e] = sd::math::nd4j_tanh(c_[e]); } + }; - if (z->ews() == 1 && i->ews() == 1 && c->ews() == 1 && cLast->ews() == 1 && f->ews() == 1 && h->ews() == 1 && - z->ordering() == i->ordering() && z->ordering() == c->ordering() && z->ordering() == cLast->ordering() && z->ordering() == f->ordering() && z->ordering() == h->ordering()) { - //cell state = blockInput .* inputGate + prevCellState .* forgetGate - BUILD_SINGLE_SELECTOR(z->dataType(), fusedTanh, (z, i, c, cLast, f, h), FLOAT_TYPES); - } else { - //cell state = blockInput .* inputGate + prevCellState .* forgetGate - z->applyPairwiseTransform(pairwise::Multiply, *i, *c); //c = z * i - auto temp = (*f) * (*cLast); - *c += temp; //c = (i * z) + (zf * (*cLast)) - c->applyTransform(transform::Tanh, *h); //h = tanh(c) - } - - // if clipping value is provided then cell state is clipped by this value prior to the cell output activation - if(clippingCellValue > 0.0) - c->applyScalar(scalar::LstmClip, clippingCellValue, *c); - - // add peephole connections to output gate zot + ct*Wc - if(peephole) { - auto prod = *c * (*Wco); - zo += prod; - } - - zo.applyTransform(transform::Sigmoid, *o); // o = sigmoid(zo) - - // current cell output = ot*tanh(ct) - c->applyTransform(transform::Tanh, *h); //h = tanh(c) - o->applyPairwiseTransform(pairwise::Multiply, *h, *y); //y = o * h + samediff::Threads::parallel_for(func, 0, uLen); } +////////////////////////////////////////////////////////////////////////// - - -} -} +void lstmBlockCell(const NDArray* xt, const NDArray* cLast, + const NDArray* yLast, const NDArray* W, const NDArray* Wci, + const NDArray* Wcf, const NDArray* Wco, const NDArray* b, + NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, + NDArray* h, NDArray* y, const std::vector& params) { + /* Input arrays: + * 0: xt - input [bS, nIn] at time t + * 1: cLast (cs_prev) - previous cell state [bS, nOut], time t-1 + * 2: yLast (h_prev) - previous output [bS, nOut], time t-1 + * 3: W - Weights - concatenated (input-to-hidden, + * hidden-to-hidden weights) weights, [(nIn+nOut), 4*nOut] 4: Wci - weights - + * cell peephole (t-1) connections to input modulation gate, [nOut] 5: Wcf - + * weights - cell peephole (t-1) connections to forget gate, [nOut] 6: Wco - + * weights - cell peephole (t) connections to output gate, [nOut] 7: b - + * biases, [4*nOut] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * + * Input float arguments: + * 0: the bias added to forget gates in order to reduce the scale of + * forgetting in the beginning of the training 1: clipping value for cell + * state, if it is not equal to zero, then cell state is clipped + * + * Output arrays: + * 0: i - Input modulation gate activations [bS, nOut] + * 1: c (cs) - Cell state (pre tanh) [bs, nOut] (cs) + * 2: f - Output - forget gate activations [bs, nOut] + * 3: o - Output - output gate activations [bs, nOut] + * 4: z (ci) - Output - block input [bs, nOut] + * 5: h (co) - Cell state, post tanh [bs, nOut] + * 6: y (h) - Current cell output [bS, nOut], time t + */ + const bool peephole = + (bool)params[0]; // if true, provide peephole connections + const double forgetBias = params[1]; + const double clippingCellValue = + params[2]; // clipping value for ct, if it is not equal to zero, then + // cell state is clipped + + const int bS = xt->sizeAt(0); + const int nIn = xt->sizeAt(1); + const int nOut = cLast->sizeAt(1); + + // Concat inputs: [xt, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)] + NDArray concatOut(xt->ordering(), + {xt->sizeAt(0), xt->sizeAt(1) + yLast->sizeAt(1)}, + xt->dataType(), xt->getContext()); + helpers::concat(xt->getContext(), + {const_cast(xt), const_cast(yLast)}, + concatOut, {1}); + + auto m = + mmul(concatOut, + *W); // mmul: [bs, (nIn+nOut)] * [(nIn+nOut), 4*nOut] = [bs, 4*nOut] + m += (*b); // addiRowVector + + // Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] + // to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o]) + auto zi = m({0, 0, 0, nOut}); // z for input modulation gate, [bS, nOut] + auto zz = m({0, 0, nOut, 2 * nOut}); // z for block input, [bS, nOut] + auto zf = m({0, 0, 2 * nOut, 3 * nOut}); // z for forget gate, [bS, nOut] + auto zo = m({0, 0, 3 * nOut, 4 * nOut}); // z for output gate, [bS, nOut] + + if (peephole) { // add peephole connections: z + ct_1*Wc + zi += (*cLast) * (*Wci); // add peephole connections to input gate + zf += (*cLast) * (*Wcf); // add peephole connections to forget gate + } + + // current sell state = ft*cLast + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc + if (forgetBias != 0.0) zf += forgetBias; + + PRAGMA_OMP_PARALLEL + PRAGMA_OMP_SINGLE { + PRAGMA_OMP_TASK + zz.applyTransform(transform::Tanh, *z); // z = tanh(zz) + + PRAGMA_OMP_TASK + zi.applyTransform(transform::Sigmoid, *i); // i = sigmoid(zi) + + PRAGMA_OMP_TASK + zf.applyTransform(transform::Sigmoid, *f); // f = sigmoid(zf); + } + + if (z->ews() == 1 && i->ews() == 1 && c->ews() == 1 && cLast->ews() == 1 && + f->ews() == 1 && h->ews() == 1 && z->ordering() == i->ordering() && + z->ordering() == c->ordering() && z->ordering() == cLast->ordering() && + z->ordering() == f->ordering() && z->ordering() == h->ordering()) { + // cell state = blockInput .* inputGate + prevCellState .* forgetGate + BUILD_SINGLE_SELECTOR(z->dataType(), fusedTanh, (z, i, c, cLast, f, h), + FLOAT_TYPES); + } else { + // cell state = blockInput .* inputGate + prevCellState .* forgetGate + z->applyPairwiseTransform(pairwise::Multiply, *i, *c); // c = z * i + auto temp = (*f) * (*cLast); + *c += temp; // c = (i * z) + (zf * (*cLast)) + c->applyTransform(transform::Tanh, *h); // h = tanh(c) + } + + // if clipping value is provided then cell state is clipped by this value + // prior to the cell output activation + if (clippingCellValue > 0.0) + c->applyScalar(scalar::LstmClip, clippingCellValue, *c); + + // add peephole connections to output gate zot + ct*Wc + if (peephole) { + auto prod = *c * (*Wco); + zo += prod; + } + + zo.applyTransform(transform::Sigmoid, *o); // o = sigmoid(zo) + + // current cell output = ot*tanh(ct) + c->applyTransform(transform::Tanh, *h); // h = tanh(c) + o->applyPairwiseTransform(pairwise::Multiply, *h, *y); // y = o * h } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp index 19f936c81b91..51cf033ff2ed 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp @@ -17,92 +17,114 @@ // // @author GS // -#include #include #include #include #include - -#include -#include #include +#include #include +#include +#include namespace sd { namespace ops { namespace helpers { - template - static void fillRegularizer(NDArray& ioMatrix, double const value) { - auto lastDims = ioMatrix.allTensorsAlongDimension({-2, -1}); - auto rows = ioMatrix.sizeAt(-2); - //auto cols = ioMatrix.sizeAt(-1); - - for (auto x = 0; x < lastDims.size(); x++) { - for (auto r = 0; r < rows; r++) { - lastDims[x].t(r,r) = (T)value; - } - } +template +static void fillRegularizer(NDArray& ioMatrix, double const value) { + auto lastDims = ioMatrix.allTensorsAlongDimension({-2, -1}); + auto rows = ioMatrix.sizeAt(-2); + // auto cols = ioMatrix.sizeAt(-1); + for (auto x = 0; x < lastDims.size(); x++) { + for (auto r = 0; r < rows; r++) { + lastDims[x].t(r, r) = (T)value; } + } +} - template - int leastSquaresSolveFunctor_(sd::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) { - NDArray::preparePrimaryUse({output}, {leftInput, rightInput}); - if (fast) { // Cholesky decomposition approach - // Equation for solve A^T * Ax = A^T * b, so - // 1. Computing A2: - auto tAtShape = ShapeUtils::evalShapeForMatmul(leftInput->shapeInfo(), leftInput->shapeInfo(), true, false); - //tAtShape[tAtShape.size() - 2] = output->sizeAt(-2); - NDArray leftOutput('c', tAtShape, output->dataType(), context); - MmulHelper::matmul(leftInput, leftInput, &leftOutput, true, false); // Computing A2 = A^T * A - // 2. Computing B' = A^T * b - auto rightOutput = output->ulike(); - - MmulHelper::matmul(leftInput, rightInput, &rightOutput, true, false); // Computing B' = A^T * b - // 3. due l2Regularizer = 0, skip regularization ( indeed A' = A2 - l2Regularizer * I) - auto regularizer = leftOutput.ulike(); - fillRegularizer(regularizer, l2Regularizer);https://mangapark.net/ -// regularizer *= l2Regularizer; - leftOutput += regularizer; - // 4. Cholesky decomposition -- output matrix is square and lower triangular -// auto leftOutputT = leftOutput.ulike(); - auto err = helpers::cholesky(context, &leftOutput, &leftOutput, true); // inplace decomposition - if (err) return err; - // alternate moment: inverse lower triangular matrix to solve equation A'x = b' => L^Tx = L^-1 * b' - // solve one upper triangular system (to avoid float problems) - - // 5. Solve two triangular systems: - auto rightB = rightOutput.ulike(); - helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, true, false, &rightB); - helpers::adjointMatrix(context, &leftOutput, true, &leftOutput); //.transposei(); - helpers::triangularSolveFunctor(context, &leftOutput, &rightB, false, false, output); - // All done - } - else { // QR decomposition approach - // Equation for solve Rx = Q^T * b, where A = Q * R, where Q - orthogonal matrix, and R - upper triangular - // 1. QR decomposition - auto qShape = leftInput->getShapeAsVector(); - auto rShape = leftInput->getShapeAsVector(); - qShape[leftInput->rankOf() - 1] = leftInput->sizeAt(-2); +template +int leastSquaresSolveFunctor_(sd::LaunchContext* context, + NDArray const* leftInput, + NDArray const* rightInput, + double const l2Regularizer, bool const fast, + NDArray* output) { + NDArray::preparePrimaryUse({output}, {leftInput, rightInput}); + if (fast) { // Cholesky decomposition approach + // Equation for solve A^T * Ax = A^T * b, so + // 1. Computing A2: + auto tAtShape = ShapeUtils::evalShapeForMatmul( + leftInput->shapeInfo(), leftInput->shapeInfo(), true, false); + // tAtShape[tAtShape.size() - 2] = output->sizeAt(-2); + NDArray leftOutput('c', tAtShape, output->dataType(), context); + MmulHelper::matmul(leftInput, leftInput, &leftOutput, true, + false); // Computing A2 = A^T * A + // 2. Computing B' = A^T * b + auto rightOutput = output->ulike(); - NDArray Q(leftInput->ordering(), qShape, leftInput->dataType(), context);// = leftInput->ulike(); - NDArray R(leftInput->ordering(), rShape, leftInput->dataType(), context); // = rightInput->ulike(); - helpers::qr(context, leftInput, &Q, &R, true); - // 2. b` = Q^t * b: - auto rightOutput = rightInput->ulike(); - MmulHelper::matmul(&Q, rightInput, &rightOutput, true, false); - // 3. Solve triangular system - helpers::triangularSolveFunctor(context, &R, &rightOutput, false, false, output); - } - NDArray::registerPrimaryUse({output}, {leftInput, rightInput}); - return Status::OK(); - } + MmulHelper::matmul(leftInput, rightInput, &rightOutput, true, + false); // Computing B' = A^T * b + // 3. due l2Regularizer = 0, skip regularization ( indeed A' = A2 - + // l2Regularizer * I) + auto regularizer = leftOutput.ulike(); + fillRegularizer(regularizer, l2Regularizer); + https: // mangapark.net/ + // regularizer *= l2Regularizer; + leftOutput += regularizer; + // 4. Cholesky decomposition -- output matrix is square and lower triangular + // auto leftOutputT = leftOutput.ulike(); + auto err = helpers::cholesky(context, &leftOutput, &leftOutput, + true); // inplace decomposition + if (err) return err; + // alternate moment: inverse lower triangular matrix to solve equation A'x = + // b' => L^Tx = L^-1 * b' solve one upper triangular system (to avoid float + // problems) - int leastSquaresSolveFunctor(sd::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) { - BUILD_SINGLE_SELECTOR(leftInput->dataType(), return leastSquaresSolveFunctor_, (context, leftInput, rightInput, l2Regularizer, fast, output), FLOAT_TYPES); - } + // 5. Solve two triangular systems: + auto rightB = rightOutput.ulike(); + helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, true, + false, &rightB); + helpers::adjointMatrix(context, &leftOutput, true, + &leftOutput); //.transposei(); + helpers::triangularSolveFunctor(context, &leftOutput, &rightB, false, false, + output); + // All done + } else { // QR decomposition approach + // Equation for solve Rx = Q^T * b, where A = Q * R, where Q - orthogonal + // matrix, and R - upper triangular + // 1. QR decomposition + auto qShape = leftInput->getShapeAsVector(); + auto rShape = leftInput->getShapeAsVector(); + qShape[leftInput->rankOf() - 1] = leftInput->sizeAt(-2); + NDArray Q(leftInput->ordering(), qShape, leftInput->dataType(), + context); // = leftInput->ulike(); + NDArray R(leftInput->ordering(), rShape, leftInput->dataType(), + context); // = rightInput->ulike(); + helpers::qr(context, leftInput, &Q, &R, true); + // 2. b` = Q^t * b: + auto rightOutput = rightInput->ulike(); + MmulHelper::matmul(&Q, rightInput, &rightOutput, true, false); + // 3. Solve triangular system + helpers::triangularSolveFunctor(context, &R, &rightOutput, false, false, + output); + } + NDArray::registerPrimaryUse({output}, {leftInput, rightInput}); + return Status::OK(); } + +int leastSquaresSolveFunctor(sd::LaunchContext* context, + NDArray const* leftInput, + NDArray const* rightInput, + double const l2Regularizer, bool const fast, + NDArray* output) { + BUILD_SINGLE_SELECTOR( + leftInput->dataType(), return leastSquaresSolveFunctor_, + (context, leftInput, rightInput, l2Regularizer, fast, output), + FLOAT_TYPES); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index 58c704bf4359..8290113ee55f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -18,599 +18,680 @@ // @author raver119@gmail.com // -#include -#include #include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - template - static void swapRows_(NDArray* matrix, int theFirst, int theSecond) { - - if (theFirst != theSecond) - for (int i = 0; i < matrix->columns(); i++) { - math::nd4j_swap(matrix->t(theFirst, i), matrix->t(theSecond, i)); - } - } - BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES); - - template - static void swapRows(T* matrixBuf, Nd4jLong const* matrixShape, Nd4jLong theFirst, Nd4jLong theSecond) { - if (theFirst != theSecond) { - auto n = shape::sizeAt(matrixShape, -1); - - auto loop = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - Nd4jLong theFirstPos[] = {theFirst, i}; - Nd4jLong theSecondPos[] = {theSecond, i}; - auto theFirstIndex = shape::getOffset(matrixShape, theFirstPos, 0); - auto theSecondIndex = shape::getOffset(matrixShape, theSecondPos, 0); - math::nd4j_swap(matrixBuf[theFirstIndex], matrixBuf[theSecondIndex]); - } - }; - - samediff::Threads::parallel_tad(loop, 0, n, 1); - } - } - - void swapRows(NDArray* matrix, int theFirst, int theSecond) { - BUILD_SINGLE_SELECTOR(matrix->dataType(), swapRows_, (matrix, theFirst, theSecond), FLOAT_TYPES); +template +static void swapRows_(NDArray* matrix, int theFirst, int theSecond) { + if (theFirst != theSecond) + for (int i = 0; i < matrix->columns(); i++) { + math::nd4j_swap(matrix->t(theFirst, i), matrix->t(theSecond, i)); } +} +BUILD_SINGLE_TEMPLATE(template void swapRows_, + (NDArray * matrix, int theFirst, int theSecond), + FLOAT_TYPES); - template - static void invertLowerMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) { - int n = inputMatrix->rows(); - invertedMatrix->setIdentity(); +template +static void swapRows(T* matrixBuf, Nd4jLong const* matrixShape, + Nd4jLong theFirst, Nd4jLong theSecond) { + if (theFirst != theSecond) { + auto n = shape::sizeAt(matrixShape, -1); + + auto loop = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + Nd4jLong theFirstPos[] = {theFirst, i}; + Nd4jLong theSecondPos[] = {theSecond, i}; + auto theFirstIndex = shape::getOffset(matrixShape, theFirstPos, 0); + auto theSecondIndex = shape::getOffset(matrixShape, theSecondPos, 0); + math::nd4j_swap(matrixBuf[theFirstIndex], matrixBuf[theSecondIndex]); + } + }; + + samediff::Threads::parallel_tad(loop, 0, n, 1); + } +} - if (inputMatrix->isIdentityMatrix()) return; +void swapRows(NDArray* matrix, int theFirst, int theSecond) { + BUILD_SINGLE_SELECTOR(matrix->dataType(), swapRows_, + (matrix, theFirst, theSecond), FLOAT_TYPES); +} - auto invertDiagonals = PRAGMA_THREADS_FOR { - for (int i = start; i < stop; i += increment) - invertedMatrix->t(i, i) /= inputMatrix->t(i, i); - }; +template +static void invertLowerMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) { + int n = inputMatrix->rows(); + invertedMatrix->setIdentity(); + + if (inputMatrix->isIdentityMatrix()) return; + + auto invertDiagonals = PRAGMA_THREADS_FOR { + for (int i = start; i < stop; i += increment) + invertedMatrix->t(i, i) /= inputMatrix->t(i, i); + }; + + auto invertSubDiagonals = PRAGMA_THREADS_FOR { + for (int i = start; i < stop; i += increment) + invertedMatrix->t(i, i - 1) -= + (inputMatrix->t(i, i - 1) * invertedMatrix->t(i - 1, i - 1) / + inputMatrix->t(i, i)); + }; + + samediff::Threads::parallel_for(invertDiagonals, 0, n, 1); + samediff::Threads::parallel_for(invertSubDiagonals, 1, n, 1); + + // PRAGMA_OMP_PARALLEL_FOR_SIMD + for (int i = 1; i < n; i++) { + for (int j = 0; j < i - 1; j++) + for (int k = 0; k < i; k++) + invertedMatrix->t(i, j) -= + ((invertedMatrix->t(k, j) * inputMatrix->t(i, k) / + inputMatrix->t(i, i))); + } +} - auto invertSubDiagonals = PRAGMA_THREADS_FOR { - for (int i = start; i < stop; i += increment) - invertedMatrix->t(i, i - 1) -= (inputMatrix->t(i, i - 1) * invertedMatrix->t(i - 1, i - 1) / inputMatrix->t(i, i)); - }; +BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, + (NDArray * inputMatrix, NDArray* invertedMatrix); + , FLOAT_TYPES); - samediff::Threads::parallel_for(invertDiagonals, 0, n, 1); - samediff::Threads::parallel_for(invertSubDiagonals, 1, n, 1); +void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { + BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, + (inputMatrix, invertedMatrix), FLOAT_TYPES); +} -// PRAGMA_OMP_PARALLEL_FOR_SIMD - for (int i = 1; i < n; i++) { - for (int j = 0; j < i - 1 ; j++) - for (int k = 0; k < i; k++) - invertedMatrix->t(i, j) -= ((invertedMatrix->t(k, j) * inputMatrix->t(i, k) / inputMatrix->t(i, i))); - } +template +static void _invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { + int n = inputMatrix->rows(); + invertedMatrix->setIdentity(); + + if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I + return; + } + + auto invertDiagonals = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) + invertedMatrix->t(i, i) /= inputMatrix->t(i, i); + }; + + // PRAGMA_OMP_PARALLEL_FOR_IF(n > + // Environment::getInstance()->elementwiseThreshold()) + auto invertUpDiagonals = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) + invertedMatrix->t(i, i + 1) -= + (inputMatrix->t(i, i + 1) * invertedMatrix->t(i + 1, i + 1) / + inputMatrix->t(i, i)); + }; + + samediff::Threads::parallel_for(invertDiagonals, 0, n, 1); + samediff::Threads::parallel_for(invertUpDiagonals, 0, n - 1, 1); + + // PRAGMA_OMP_PARALLEL_FOR_SIMD + for (auto i = n - 2; i >= 0; i--) { + for (auto j = i + 2; j < n; j++) + for (auto k = i; k < n; k++) + invertedMatrix->t(i, j) -= + ((invertedMatrix->t(k, j) * inputMatrix->t(i, k) / + inputMatrix->t(i, i))); + } +} - } +BUILD_SINGLE_TEMPLATE(template void _invertUpperMatrix, + (NDArray * inputMatrix, NDArray* invertedMatrix); + , FLOAT_TYPES); - BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES); +void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { + BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), _invertUpperMatrix, + (inputMatrix, invertedMatrix), FLOAT_TYPES); +} - void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { - BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_TYPES); +template +static NDArray lup_(LaunchContext* context, NDArray* input, NDArray* compound, + NDArray* permutation) { + const int rowNum = input->rows(); + const int columnNum = input->columns(); + + NDArray determinant = NDArrayFactory::create(1.f, context); + NDArray compoundMatrix = input->dup(); // copy + NDArray permutationMatrix( + input, false, context); // has same shape as input and contiguous strides + permutationMatrix.setIdentity(); + + T pivotValue; // = T(0.0); + int pivot; // = -1; + int swapCount = 0; + + for (int i = 0; i < rowNum; i++) { + pivotValue = T(0.0); + pivot = -1; + // PRAGMA_OMP_PARALLEL_FOR //_ARGS(firstprivate(pivot,pivotValue)) + for (int rowCounter = i; rowCounter < rowNum; rowCounter++) { + if (sd::math::nd4j_abs(compoundMatrix.t(rowCounter, i)) > pivotValue) { + pivotValue = sd::math::nd4j_abs(compoundMatrix.t(rowCounter, i)); + pivot = rowCounter; + } } - template - static void _invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { - int n = inputMatrix->rows(); - invertedMatrix->setIdentity(); - - if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I - return; + if (pivotValue > DataTypeUtils::min()) { + swapRows(&compoundMatrix, pivot, i); + swapRows(&permutationMatrix, pivot, i); + if (pivot != i) swapCount++; + + for (int j = i + 1; j < rowNum; j++) { + compoundMatrix.t(j, i) /= compoundMatrix.t(i, i); + // PRAGMA_OMP_PARALLEL_FOR + for (int k = i + 1; k < rowNum; k++) { + compoundMatrix.t(j, k) -= + compoundMatrix.t(j, i) * compoundMatrix.t(i, k); } - - auto invertDiagonals = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i += increment) - invertedMatrix->t(i, i) /= inputMatrix->t(i, i); - }; - - //PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold()) - auto invertUpDiagonals = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i += increment) - invertedMatrix->t(i, i + 1) -= (inputMatrix->t(i, i + 1) * invertedMatrix->t(i + 1, i + 1) / - inputMatrix->t(i, i)); - }; - - samediff::Threads::parallel_for(invertDiagonals, 0, n, 1); - samediff::Threads::parallel_for(invertUpDiagonals, 0, n - 1, 1); - -// PRAGMA_OMP_PARALLEL_FOR_SIMD - for (auto i = n - 2; i >= 0; i--) { - for (auto j = i + 2; j < n; j++) - for (auto k = i; k < n; k++) - invertedMatrix->t(i, j) -= ((invertedMatrix->t(k, j) * inputMatrix->t(i, k) / inputMatrix->t(i, i))); + } + } + } + + for (int e = 0; e < rowNum; e++) { + // nd4j_printf("Compound matrix diag %i %f.\n", e, (*compoundMatrix)(e, e)); + determinant *= compoundMatrix.e(e, e); + } + if (swapCount % 2) determinant = -determinant; + if (compound != nullptr) compound->assign(compoundMatrix); + if (permutation != nullptr) { + auto permutaionVector = NDArrayFactory::create( + 'c', {rowNum}, DataTypeUtils::fromT(), input->getContext()); + for (auto i = 0; i < rowNum; i++) { + for (auto j = 0; j < columnNum; j++) { + if (permutationMatrix.t(i, j) != 0) { + permutaionVector.template t(i) = j; } + } } - - BUILD_SINGLE_TEMPLATE(template void _invertUpperMatrix, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES); - - void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { - BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), _invertUpperMatrix, (inputMatrix, invertedMatrix), FLOAT_TYPES); + if (permutationMatrix.isSameShape(permutation)) + permutation->assign(permutationMatrix); + else if (permutation->isSameShape(permutaionVector)) { + permutation->assign(permutaionVector); } + } + return determinant; +} - - template - static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) { - - const int rowNum = input->rows(); - const int columnNum = input->columns(); - - NDArray determinant = NDArrayFactory::create(1.f, context); - NDArray compoundMatrix = input->dup(); // copy - NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides - permutationMatrix.setIdentity(); - - T pivotValue; // = T(0.0); - int pivot; // = -1; - int swapCount = 0; - - for(int i = 0; i < rowNum; i++ ) { - pivotValue = T(0.0); - pivot = -1; - //PRAGMA_OMP_PARALLEL_FOR //_ARGS(firstprivate(pivot,pivotValue)) - for(int rowCounter = i; rowCounter < rowNum; rowCounter++ ) { - if (sd::math::nd4j_abs(compoundMatrix.t(rowCounter, i)) > pivotValue) { - pivotValue = sd::math::nd4j_abs(compoundMatrix.t(rowCounter, i)); - pivot = rowCounter; - } - } - - if( pivotValue > DataTypeUtils::min()) { - swapRows(&compoundMatrix, pivot, i); - swapRows(&permutationMatrix, pivot, i); - if (pivot != i) - swapCount++; - - for( int j = i + 1; j < rowNum; j++ ) { - compoundMatrix.t(j, i) /= compoundMatrix.t(i, i); - //PRAGMA_OMP_PARALLEL_FOR - for( int k = i + 1; k < rowNum; k++ ) { - compoundMatrix.t(j, k) -= compoundMatrix.t(j, i) * compoundMatrix.t(i, k); - } - } - } - } - - for (int e = 0; e < rowNum; e++) { - // nd4j_printf("Compound matrix diag %i %f.\n", e, (*compoundMatrix)(e, e)); - determinant *= compoundMatrix.e(e, e); - } - if (swapCount % 2) determinant = -determinant; - if (compound != nullptr) - compound->assign(compoundMatrix); - if (permutation != nullptr) { - auto permutaionVector = NDArrayFactory::create('c', {rowNum}, DataTypeUtils::fromT(), input->getContext()); - for (auto i = 0; i < rowNum; i++) { - for (auto j = 0; j < columnNum; j++) { - if (permutationMatrix.t(i, j) != 0) { - permutaionVector.template t(i) = j; - } - } - } - if (permutationMatrix.isSameShape(permutation)) - permutation->assign(permutationMatrix); - else if (permutation->isSameShape(permutaionVector)) { - permutation->assign(permutaionVector); - } - } - return determinant; +BUILD_DOUBLE_TEMPLATE(template NDArray lup_, + (LaunchContext * context, NDArray* input, NDArray* output, + NDArray* permutation), + FLOAT_TYPES, INDEXING_TYPES); +/* + * lu decomposition with naive algorithm with partial pivoting + * */ +template +static I argmaxCol(I column, T* compoundBuffer, Nd4jLong const* compoundShape) { + auto rowNum = shape::sizeAt(compoundShape, 0); + Nd4jLong xInitial[] = {column, column}; + auto xInitialIndex = shape::getOffset(compoundShape, xInitial, 0); + auto maxValue = T(0); // sd::math::nd4j_abs(compoundBuffer[xInitialIndex]); + auto result = -1; + // auto loop = PRAGMA_THREADS_FOR { + auto start = column, stop = rowNum, increment = 1; + for (auto rowCounter = start; rowCounter < stop; rowCounter++) { + Nd4jLong xPos[] = {rowCounter, column}; + auto xIndex = shape::getOffset(compoundShape, xPos, 0); + if (sd::math::nd4j_abs(compoundBuffer[xIndex]) > maxValue) { + maxValue = sd::math::nd4j_max(maxValue, + sd::math::nd4j_abs(compoundBuffer[xIndex])); + result = rowCounter; } + } + //}; + // samediff::Threads::parallel_for(loop, column, rowNum, 1); + return result; +} - BUILD_DOUBLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES); - /* - * lu decomposition with naive algorithm with partial pivoting - * */ - template - static I argmaxCol(I column, T* compoundBuffer, Nd4jLong const* compoundShape) { - auto rowNum = shape::sizeAt(compoundShape, 0); - Nd4jLong xInitial[] = {column, column}; - auto xInitialIndex = shape::getOffset(compoundShape, xInitial, 0); - auto maxValue = T(0); //sd::math::nd4j_abs(compoundBuffer[xInitialIndex]); - auto result = -1; - //auto loop = PRAGMA_THREADS_FOR { - auto start = column, stop = rowNum, increment = 1; - for (auto rowCounter = start; rowCounter < stop; rowCounter++) { - Nd4jLong xPos[] = {rowCounter, column}; - auto xIndex = shape::getOffset(compoundShape, xPos, 0); - if (sd::math::nd4j_abs(compoundBuffer[xIndex]) > maxValue) { - maxValue = sd::math::nd4j_max(maxValue, sd::math::nd4j_abs(compoundBuffer[xIndex])); - result = rowCounter; - } - } - //}; - //samediff::Threads::parallel_for(loop, column, rowNum, 1); - return result; +template +void processColumns(int currentRow, int rowNum, T* compoundBuf, + Nd4jLong const* compoundShape) { + Nd4jLong xDiag[] = {currentRow, currentRow}; + auto diagIndex = shape::getOffset(compoundShape, xDiag, 0); + auto loop = PRAGMA_THREADS_FOR { + for (auto j = start; j < stop; j++) { + Nd4jLong xRow[] = {j, currentRow}; + auto rowIndex = shape::getOffset(compoundShape, xRow, 0); + compoundBuf[rowIndex] /= compoundBuf[diagIndex]; // output->t(i, i); + for (int k = currentRow + 1; k < rowNum; k++) { + Nd4jLong yRow[] = {j, k}; + Nd4jLong yCol[] = {currentRow, k}; + auto rowIndexY = shape::getOffset(compoundShape, yRow, 0); + auto colIndex = shape::getOffset(compoundShape, yCol, 0); + compoundBuf[rowIndexY] -= compoundBuf[rowIndex] * compoundBuf[colIndex]; + } } + }; + samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1); +} - template - void processColumns(int currentRow, int rowNum, T* compoundBuf, Nd4jLong const* compoundShape) { - Nd4jLong xDiag[] = {currentRow, currentRow}; - auto diagIndex = shape::getOffset(compoundShape, xDiag, 0); - auto loop = PRAGMA_THREADS_FOR { - for (auto j = start; j < stop; j++) { - Nd4jLong xRow[] = {j, currentRow}; - auto rowIndex = shape::getOffset(compoundShape, xRow, 0); - compoundBuf[rowIndex] /= compoundBuf[diagIndex]; //output->t(i, i); - for (int k = currentRow + 1; k < rowNum; k++) { - Nd4jLong yRow[] = {j, k}; - Nd4jLong yCol[] = {currentRow, k}; - auto rowIndexY = shape::getOffset(compoundShape, yRow, 0); - auto colIndex = shape::getOffset(compoundShape, yCol, 0); - compoundBuf[rowIndexY] -= compoundBuf[rowIndex] * compoundBuf[colIndex]; - } - } - }; - samediff::Threads::parallel_tad(loop, currentRow + 1, rowNum, 1); +template +static void doolitleLU(LaunchContext* context, NDArray* compound, + Nd4jLong rowNum) { + auto input = compound->dup(); + compound->nullify(); + + // Decomposing matrix into Upper and Lower + // triangular matrix + for (auto i = 0; i < rowNum; i++) { + // Upper Triangular + for (auto k = i; k < rowNum; k++) { + // Summation of L(i, j) * U(j, k) + int sum = 0; + for (int j = 0; j < i; j++) + sum += compound->t(i, j) * compound->t(j, k); + + // Evaluating U(i, k) + compound->t(i, k) = input.t(i, k) - sum; } - template - static void doolitleLU(LaunchContext* context, NDArray* compound, Nd4jLong rowNum) { - auto input = compound->dup(); - compound->nullify(); - - // Decomposing matrix into Upper and Lower - // triangular matrix - for (auto i = 0; i < rowNum; i++) { - - // Upper Triangular - for (auto k = i; k < rowNum; k++) { - - // Summation of L(i, j) * U(j, k) - int sum = 0; - for (int j = 0; j < i; j++) - sum += compound->t(i,j) * compound->t(j,k); - - // Evaluating U(i, k) - compound->t(i, k) = input.t(i, k) - sum; - } - - // Lower Triangular - for (int k = i + 1; k < rowNum; k++) { - // Summation of L(k, j) * U(j, i) - int sum = 0; - for (int j = 0; j < i; j++) - sum += compound->t(k,j) * compound->t(j, i); - - // Evaluating L(k, i) - compound->t(k, i) = (input.t(k, i) - sum) / compound->t(i,i); - } - } - } + // Lower Triangular + for (int k = i + 1; k < rowNum; k++) { + // Summation of L(k, j) * U(j, i) + int sum = 0; + for (int j = 0; j < i; j++) + sum += compound->t(k, j) * compound->t(j, i); - template - static void luNN_(LaunchContext *context, NDArray* compound, NDArray* permutation, Nd4jLong rowNum) { - - //const int rowNum = compound->rows(); -// const int columnNum = output->columns(); - if (permutation) { // LUP algorithm - permutation->linspace(0); - auto permutationBuf = permutation->bufferAsT(); //dataBuffer()->primaryAsT(); - auto compoundBuf = compound->bufferAsT(); - auto compoundShape = compound->shapeInfo(); - auto permutationShape = permutation->shapeInfo(); - for (auto i = 0; i < rowNum - 1; i++) { - auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape); - if (pivotIndex < 0) { - throw std::runtime_error("helpers::luNN_: input matrix is singular."); - } - math::nd4j_swap(permutationBuf[shape::getIndexOffset(i, permutationShape)], - permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]); - swapRows(compoundBuf, compoundShape, i, pivotIndex); - - processColumns(i, rowNum, compoundBuf, compoundShape); - } - } - else { // Doolitle algorithm with LU decomposition - doolitleLU(context, compound, rowNum); - } + // Evaluating L(k, i) + compound->t(k, i) = (input.t(k, i) - sum) / compound->t(i, i); } + } +} - template - static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) { - auto n = input->sizeAt(-1); - - output->assign(input); // fill up output tensor with zeros - ResultSet outputs = output->allTensorsAlongDimension({-2, -1}); - ResultSet permutations; - if (permutationVectors) - permutations = permutationVectors->allTensorsAlongDimension({-1}); - - auto loop = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - luNN_(context, &outputs.at(i), permutationVectors ? &permutations.at(i) : nullptr, n); - } - }; - samediff::Threads::parallel_for(loop, 0, outputs.size(), 1); +template +static void luNN_(LaunchContext* context, NDArray* compound, + NDArray* permutation, Nd4jLong rowNum) { + // const int rowNum = compound->rows(); + // const int columnNum = output->columns(); + if (permutation) { // LUP algorithm + permutation->linspace(0); + auto permutationBuf = + permutation->bufferAsT(); // dataBuffer()->primaryAsT(); + auto compoundBuf = compound->bufferAsT(); + auto compoundShape = compound->shapeInfo(); + auto permutationShape = permutation->shapeInfo(); + for (auto i = 0; i < rowNum - 1; i++) { + auto pivotIndex = argmaxCol(i, compoundBuf, compoundShape); + if (pivotIndex < 0) { + throw std::runtime_error("helpers::luNN_: input matrix is singular."); + } + math::nd4j_swap( + permutationBuf[shape::getIndexOffset(i, permutationShape)], + permutationBuf[shape::getIndexOffset(pivotIndex, permutationShape)]); + swapRows(compoundBuf, compoundShape, i, pivotIndex); + + processColumns(i, rowNum, compoundBuf, compoundShape); } + } else { // Doolitle algorithm with LU decomposition + doolitleLU(context, compound, rowNum); + } +} - void lu(LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation) { - BUILD_DOUBLE_SELECTOR(input->dataType(), permutation?permutation->dataType():DataType::INT32, lu_, (context, input, output, permutation), FLOAT_TYPES, INDEXING_TYPES); +template +static void lu_(LaunchContext* context, NDArray* input, NDArray* output, + NDArray* permutationVectors) { + auto n = input->sizeAt(-1); + + output->assign(input); // fill up output tensor with zeros + ResultSet outputs = output->allTensorsAlongDimension({-2, -1}); + ResultSet permutations; + if (permutationVectors) + permutations = permutationVectors->allTensorsAlongDimension({-1}); + + auto loop = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + luNN_(context, &outputs.at(i), + permutationVectors ? &permutations.at(i) : nullptr, n); } + }; + samediff::Threads::parallel_for(loop, 0, outputs.size(), 1); +} -// BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, INDEXING_TYPES); - - template - static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) { +void lu(LaunchContext* context, NDArray* input, NDArray* output, + NDArray* permutation) { + BUILD_DOUBLE_SELECTOR(input->dataType(), + permutation ? permutation->dataType() : DataType::INT32, + lu_, (context, input, output, permutation), FLOAT_TYPES, + INDEXING_TYPES); +} - Nd4jLong n = input->sizeAt(-1); - Nd4jLong n2 = n * n; +// BUILD_DOUBLE_TEMPLATE(template NDArray lu_, (LaunchContext *context, +// NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES, +// INDEXING_TYPES); - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.workspace()); +template +static int determinant_(LaunchContext* context, NDArray* input, + NDArray* output) { + Nd4jLong n = input->sizeAt(-1); + Nd4jLong n2 = n * n; + + auto matrix = + NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), + context); //, block.workspace()); + + for (int e = 0; e < output->lengthOf(); e++) { + for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) + matrix.p(row, input->e(k)); + output->p(e, lup_(context, &matrix, (NDArray*)nullptr, + (NDArray*)nullptr)); + } + + return Status::OK(); +} - for (int e = 0; e < output->lengthOf(); e++) { - for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) - matrix.p(row, input->e(k)); - output->p(e, lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr)); - } +int determinant(sd::LaunchContext* context, NDArray* input, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, + (context, input, output), FLOAT_TYPES); +} - return Status::OK(); +template +int logAbsDeterminant_(LaunchContext* context, NDArray* input, + NDArray* output) { + Nd4jLong n = input->sizeAt(-1); + Nd4jLong n2 = n * n; + + NDArray matrix = + NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), + context); //, block.workspace()); + for (int e = 0; e < output->lengthOf(); e++) { + for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) { + matrix.p(row, input->e(k)); } + NDArray det = + lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr); + if (det.e(0) != 0.f) + output->p(e, sd::math::nd4j_log(sd::math::nd4j_abs(det.t(0)))); + } - int determinant(sd::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES); - } + return ND4J_STATUS_OK; +} -template - int logAbsDeterminant_(LaunchContext *context, NDArray* input, NDArray* output) { - - Nd4jLong n = input->sizeAt(-1); - Nd4jLong n2 = n * n; - - NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.workspace()); - for (int e = 0; e < output->lengthOf(); e++) { - for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) { - matrix.p(row, input->e(k)); - } - NDArray det = lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr); - if (det.e(0) != 0.f) - output->p(e, sd::math::nd4j_log(sd::math::nd4j_abs(det.t(0)))); - } +int logAbsDeterminant(sd::LaunchContext* context, NDArray* input, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, + (context, input, output), FLOAT_TYPES); +} - return ND4J_STATUS_OK; +template +static int inverse_(LaunchContext* context, NDArray* input, NDArray* output) { + auto n = input->sizeAt(-1); + auto n2 = n * n; + auto totalCount = output->lengthOf() / n2; + + output->assign(0.f); // fill up output tensor with zeros + auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), + context); //, block.workspace()); + auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), + context); //, block.workspace()); + auto permutation = + NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + auto lowerMatrix = + NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + auto upperMatrix = + NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + + for (int e = 0; e < totalCount; e++) { + if (e) matrix.assign(0.f); + + for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { + matrix.p(row++, input->e(k)); } - - int logAbsDeterminant(sd::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES); + T det = lup_(context, &matrix, &compound, &permutation) + .template e(0); + + // FIXME: and how this is going to work on float16? + if (sd::math::nd4j_abs(det) < T(0.000001)) { + nd4j_printf( + "matrix_inverse: The matrix %i has no inverse due determinant is " + "%lf. Quiting...\n", + e, det); + matrix.printIndexedBuffer("Wrong matrix"); + return ND4J_STATUS_VALIDATION; + } + lowerMatrix.setIdentity(); // set up U to identity matrix + for (int k = 1; k < n; + k++) { // and then put all values under main diagonal on to it + for (int j = 0; j < k; j++) + lowerMatrix.template t(k, j) = compound.template t(k, j); + } + upperMatrix.setIdentity(); // set up U to identity matrix + for (int k = 0; k < n; + k++) { // and then put all values under main diagonal on to it + for (int j = k; j < n; j++) + upperMatrix.template t(k, j) = compound.template e(k, j); } + invertUpperMatrix(&upperMatrix, &matrix); - template - static int inverse_(LaunchContext *context, NDArray* input, NDArray* output) { - - auto n = input->sizeAt(-1); - auto n2 = n * n; - auto totalCount = output->lengthOf() / n2; - - output->assign(0.f); // fill up output tensor with zeros - auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.workspace()); - auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.workspace()); - auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); - auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); - auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); - - for (int e = 0; e < totalCount; e++) { - if (e) - matrix.assign(0.f); - - for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { - matrix.p(row++, input->e(k)); - } - T det = lup_(context, &matrix, &compound, &permutation).template e(0); - - // FIXME: and how this is going to work on float16? - if (sd::math::nd4j_abs(det) < T(0.000001)) { - nd4j_printf("matrix_inverse: The matrix %i has no inverse due determinant is %lf. Quiting...\n", e, det); - matrix.printIndexedBuffer("Wrong matrix"); - return ND4J_STATUS_VALIDATION; - } - lowerMatrix.setIdentity(); // set up U to identity matrix - for (int k = 1; k < n; k++) { // and then put all values under main diagonal on to it - for (int j = 0; j < k; j++) - lowerMatrix.template t(k, j) = compound.template t(k, j); - } - upperMatrix.setIdentity(); // set up U to identity matrix - for (int k = 0; k < n; k++) { // and then put all values under main diagonal on to it - for (int j = k; j < n; j++) - upperMatrix.template t(k, j) = compound.template e(k, j); - } - invertUpperMatrix(&upperMatrix, &matrix); - - invertLowerMatrix(&lowerMatrix, &upperMatrix); - - sd::MmulHelper::mmul(&matrix, &upperMatrix, &compound, 1.0, 0.0); - sd::MmulHelper::mmul(&compound, &permutation, &matrix, 1.0, 0.0); - for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { - output->t(k) = matrix.template t(row++); - } - } + invertLowerMatrix(&lowerMatrix, &upperMatrix); - return Status::OK(); + sd::MmulHelper::mmul(&matrix, &upperMatrix, &compound, 1.0, 0.0); + sd::MmulHelper::mmul(&compound, &permutation, &matrix, 1.0, 0.0); + for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { + output->t(k) = matrix.template t(row++); } + } - template - static int lowerInverse_(LaunchContext *context, NDArray* input, NDArray* output) { - - auto n = input->sizeAt(-1); - auto n2 = n * n; - auto totalCount = output->lengthOf() / n2; - - output->assign(0.f); // fill up output tensor with zeros - auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.workspace()); - auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.workspace()); - auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); - auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); - auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); - -// auto batchLoop = PRAGMA_THREADS_FOR { - for (int e = 0; e < totalCount; e++) { - if (e) - matrix.assign(0.f); - - for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { - matrix.p(row++, input->e(k)); - } - T det = T(1.f); - for (auto i = 0; i < n; i++) { - det *= matrix. template t(i, i); - } - - // FIXME: and how this is going to work on float16? - if (sd::math::nd4j_abs(det) < T(0.000001)) { - nd4j_printf("matrix_inverse: The matrix %i has no inverse due determinant is %lf. Quiting...\n", e, det); - matrix.printIndexedBuffer("Wrong matrix"); - return ND4J_STATUS_VALIDATION; - } - lowerMatrix.nullify(); - invertLowerMatrix(&matrix, &lowerMatrix); - - for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { - output->t(k) = lowerMatrix.template t(row++); - } - } + return Status::OK(); +} - return Status::OK(); +template +static int lowerInverse_(LaunchContext* context, NDArray* input, + NDArray* output) { + auto n = input->sizeAt(-1); + auto n2 = n * n; + auto totalCount = output->lengthOf() / n2; + + output->assign(0.f); // fill up output tensor with zeros + auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), + context); //, block.workspace()); + auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), + context); //, block.workspace()); + auto permutation = + NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + auto lowerMatrix = + NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + auto upperMatrix = + NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + + // auto batchLoop = PRAGMA_THREADS_FOR { + for (int e = 0; e < totalCount; e++) { + if (e) matrix.assign(0.f); + + for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { + matrix.p(row++, input->e(k)); } - - template - static int upperInverse_(LaunchContext *context, NDArray* input, NDArray* output) { - - auto n = input->sizeAt(-1); - auto n2 = n * n; - - output->nullify(); // fill up output tensor with zeros -// auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.workspace()); -// auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); -// auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); - auto inputPart = input->allTensorsAlongDimension({-2, -1}); - auto outputPart = output->allTensorsAlongDimension({-2, -1}); - auto totalCount = outputPart.size(); //lengthOf() / n2; - for (int e = 0; e < totalCount; e++) { - invertUpperMatrix(&inputPart.at(e), &outputPart.at(e)); - } - return Status::OK(); + T det = T(1.f); + for (auto i = 0; i < n; i++) { + det *= matrix.template t(i, i); } - int inverse(sd::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES); + // FIXME: and how this is going to work on float16? + if (sd::math::nd4j_abs(det) < T(0.000001)) { + nd4j_printf( + "matrix_inverse: The matrix %i has no inverse due determinant is " + "%lf. Quiting...\n", + e, det); + matrix.printIndexedBuffer("Wrong matrix"); + return ND4J_STATUS_VALIDATION; } + lowerMatrix.nullify(); + invertLowerMatrix(&matrix, &lowerMatrix); - int lowerInverseFunctor(sd::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return lowerInverse_, (context, input, output), FLOAT_TYPES); + for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { + output->t(k) = lowerMatrix.template t(row++); } + } - int upperInverseFunctor(sd::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return upperInverse_, (context, input, output), FLOAT_TYPES); - } + return Status::OK(); +} - template - static bool checkCholeskyInput_(sd::LaunchContext * context, NDArray const* input) { - //std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType())); //, block.workspace()); - ResultSet lastMatrixList = input->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf()-1}); - for (size_t i = 0; i < lastMatrixList.size(); i++) { - auto thisMatrix = lastMatrixList.at(i); - // check for symmetric - for (Nd4jLong r = 0; r < thisMatrix.rows(); r++) - for (Nd4jLong c = 0; c < thisMatrix.columns(); c++) - if (sd::math::nd4j_abs(thisMatrix.e(r, c) - lastMatrixList.at(i).e(c,r)) > DataTypeUtils::min()) return false; - - NDArray output = NDArrayFactory::create(0., context); - if (ND4J_STATUS_OK != determinant(context, &thisMatrix, &output)) return false; - if (output.e(0) <= T(0)) return 0; - NDArray reversedMatrix = thisMatrix.dup(); - if (ND4J_STATUS_OK != inverse(context, &thisMatrix, &reversedMatrix)) return false; - if (ND4J_STATUS_OK != determinant(context, &reversedMatrix, &output)) return false; - if (output.e(0) <= T(0)) return 0; +template +static int upperInverse_(LaunchContext* context, NDArray* input, + NDArray* output) { + auto n = input->sizeAt(-1); + auto n2 = n * n; + + output->nullify(); // fill up output tensor with zeros + // auto matrix = NDArrayFactory::create('c', {n, n}, + // DataTypeUtils::fromT(), context); //, block.workspace()); auto + // lowerMatrix = NDArrayFactory::create('c', {n, n}, + // DataTypeUtils::fromT(), context); auto upperMatrix = + // NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), + // context); + auto inputPart = input->allTensorsAlongDimension({-2, -1}); + auto outputPart = output->allTensorsAlongDimension({-2, -1}); + auto totalCount = outputPart.size(); // lengthOf() / n2; + for (int e = 0; e < totalCount; e++) { + invertUpperMatrix(&inputPart.at(e), &outputPart.at(e)); + } + return Status::OK(); +} - } +int inverse(sd::LaunchContext* context, NDArray* input, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, + (context, input, output), FLOAT_TYPES); +} +int lowerInverseFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return lowerInverse_, + (context, input, output), FLOAT_TYPES); +} - return true; - } +int upperInverseFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return upperInverse_, + (context, input, output), FLOAT_TYPES); +} - bool checkCholeskyInput(sd::LaunchContext * context, NDArray const* input) { - BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, (context, input), FLOAT_TYPES); - } +template +static bool checkCholeskyInput_(sd::LaunchContext* context, + NDArray const* input) { + // std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, + // input->dataType())); //, block.workspace()); + ResultSet lastMatrixList = input->allTensorsAlongDimension( + {input->rankOf() - 2, input->rankOf() - 1}); + for (size_t i = 0; i < lastMatrixList.size(); i++) { + auto thisMatrix = lastMatrixList.at(i); + // check for symmetric + for (Nd4jLong r = 0; r < thisMatrix.rows(); r++) + for (Nd4jLong c = 0; c < thisMatrix.columns(); c++) + if (sd::math::nd4j_abs(thisMatrix.e(r, c) - + lastMatrixList.at(i).e(c, r)) > + DataTypeUtils::min()) + return false; + + NDArray output = NDArrayFactory::create(0., context); + if (ND4J_STATUS_OK != determinant(context, &thisMatrix, &output)) + return false; + if (output.e(0) <= T(0)) return 0; + NDArray reversedMatrix = thisMatrix.dup(); + if (ND4J_STATUS_OK != inverse(context, &thisMatrix, &reversedMatrix)) + return false; + if (ND4J_STATUS_OK != determinant(context, &reversedMatrix, &output)) + return false; + if (output.e(0) <= T(0)) return 0; + } + + return true; +} - template - int cholesky_(LaunchContext *context, NDArray* input, NDArray* output, bool inplace) { - - auto n = input->sizeAt(-1); - auto n2 = n * n; - auto totalCount = output->lengthOf() / n2; - if (!inplace) - output->assign(0.f); // fill up output tensor with zeros only inplace=false - - std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); //, block.workspace()); - std::unique_ptr lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), context)); - - for (int e = 0; e < totalCount; e++) { - - // fill up matrix - for (int k = e * n2, l = 0; k < (e + 1) * n2; k++) { - matrix->p(l++, input->e(k)); - } - //if (e) // from the second loop need to zero matrix - lowerMatrix->assign(0.f); - - for (Nd4jLong col = 0; col < n; col++) { - for (Nd4jLong row = 0; row < col; row++) { - T rowSum = 0; - for (Nd4jLong k = 0; k < row; ++k) - rowSum += (lowerMatrix->e(col, k) * lowerMatrix->e(row, k)); - lowerMatrix->p(col, row, (matrix->e(row, col) - rowSum) / lowerMatrix->e(row, row)); - } - T diagonalSum = 0; - for (Nd4jLong k = 0; k < col; ++k) - diagonalSum += lowerMatrix->e(col, k) * lowerMatrix->e(col, k); - lowerMatrix->p(col, col, sd::math::nd4j_sqrt(matrix->e(col, col) - diagonalSum)); - //nd4j_printf("%i: ", col); - //lowerMatrix->printIndexedBuffer("Lower matrix"); - } - for (int k = e * n2, l = 0; k < (e + 1) * n2; k++) { - output->p(k, lowerMatrix->e(l++)); - } - } +bool checkCholeskyInput(sd::LaunchContext* context, NDArray const* input) { + BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, + (context, input), FLOAT_TYPES); +} - return ND4J_STATUS_OK; +template +int cholesky_(LaunchContext* context, NDArray* input, NDArray* output, + bool inplace) { + auto n = input->sizeAt(-1); + auto n2 = n * n; + auto totalCount = output->lengthOf() / n2; + if (!inplace) + output->assign(0.f); // fill up output tensor with zeros only inplace=false + + std::unique_ptr matrix(NDArrayFactory::create_( + 'c', {n, n}, input->dataType(), context)); //, block.workspace()); + std::unique_ptr lowerMatrix( + NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); + + for (int e = 0; e < totalCount; e++) { + // fill up matrix + for (int k = e * n2, l = 0; k < (e + 1) * n2; k++) { + matrix->p(l++, input->e(k)); } - - int cholesky(sd::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) { - BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES); + // if (e) // from the second loop need to zero matrix + lowerMatrix->assign(0.f); + + for (Nd4jLong col = 0; col < n; col++) { + for (Nd4jLong row = 0; row < col; row++) { + T rowSum = 0; + for (Nd4jLong k = 0; k < row; ++k) + rowSum += (lowerMatrix->e(col, k) * lowerMatrix->e(row, k)); + lowerMatrix->p(col, row, + (matrix->e(row, col) - rowSum) / + lowerMatrix->e(row, row)); + } + T diagonalSum = 0; + for (Nd4jLong k = 0; k < col; ++k) + diagonalSum += lowerMatrix->e(col, k) * lowerMatrix->e(col, k); + lowerMatrix->p( + col, col, + sd::math::nd4j_sqrt(matrix->e(col, col) - diagonalSum)); + // nd4j_printf("%i: ", col); + // lowerMatrix->printIndexedBuffer("Lower matrix"); } - - template - int logdetFunctor_(LaunchContext *context, NDArray* input, NDArray* output) { - auto tempOutput = input->dup(); - int res = cholesky_(context, input, &tempOutput, false); - if (res != ND4J_STATUS_OK) - return res; - auto n = input->sizeAt(-1); - auto totalCount = output->lengthOf(); - std::vector d(n); - ResultSet matricies = tempOutput.allTensorsAlongDimension({input->rankOf()-2, input->rankOf() - 1}); - - for (Nd4jLong e = 0; e < totalCount; e++) { - for (size_t i = 0; i < n; ++i) - output->t(e) += sd::math::nd4j_log(sd::math::nd4j_pow(matricies.at(e).t(i, i), T(2))); - } - return ND4J_STATUS_OK; + for (int k = e * n2, l = 0; k < (e + 1) * n2; k++) { + output->p(k, lowerMatrix->e(l++)); } + } - int logdetFunctor(sd::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES); - } + return ND4J_STATUS_OK; +} - int lup(sd::LaunchContext * context, NDArray* input, NDArray* compound, NDArray* permutation) { - BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_, (context, input, compound, permutation), FLOAT_NATIVE, INDEXING_TYPES); - return Status::OK(); - } +int cholesky(sd::LaunchContext* context, NDArray* input, NDArray* output, + bool inplace) { + BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, + (context, input, output, inplace), FLOAT_TYPES); +} +template +int logdetFunctor_(LaunchContext* context, NDArray* input, NDArray* output) { + auto tempOutput = input->dup(); + int res = cholesky_(context, input, &tempOutput, false); + if (res != ND4J_STATUS_OK) return res; + auto n = input->sizeAt(-1); + auto totalCount = output->lengthOf(); + std::vector d(n); + ResultSet matricies = tempOutput.allTensorsAlongDimension( + {input->rankOf() - 2, input->rankOf() - 1}); + + for (Nd4jLong e = 0; e < totalCount; e++) { + for (size_t i = 0; i < n; ++i) + output->t(e) += sd::math::nd4j_log( + sd::math::nd4j_pow(matricies.at(e).t(i, i), T(2))); + } + return ND4J_STATUS_OK; } + +int logdetFunctor(sd::LaunchContext* context, NDArray* input, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, + (context, input, output), FLOAT_TYPES); } + +int lup(sd::LaunchContext* context, NDArray* input, NDArray* compound, + NDArray* permutation) { + BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_, + (context, input, compound, permutation), FLOAT_NATIVE, + INDEXING_TYPES); + return Status::OK(); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp index 443048c5676a..07ce55dc9af3 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/matrixSetDiag.cpp @@ -19,61 +19,64 @@ // #include -#include #include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// -template -void matrixSetDiag_(const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) { - - // input and output are the same array (x == z) when zeroPad = true - // xRank = zRank, xRank = yRank + 1 - // xLen = zLen - - const T* x = input.bufferAsT(); - const T* y = diagonal.bufferAsT(); - T* z = output.bufferAsT(); - - const Nd4jLong* xShapeInfo = input.shapeInfo(); - const Nd4jLong* yShapeInfo = diagonal.shapeInfo(); - const Nd4jLong* zShapeInfo = output.shapeInfo(); - - const bool areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); // shapes are definitely the same, but strides might not - - const int xRank = input.rankOf(); - const auto xLen = input.lengthOf(); - - auto func = PRAGMA_THREADS_FOR { - - int coords[MAX_RANK]; - - for (Nd4jLong i = 0; i < xLen; ++i) { - - shape::index2coordsCPU(start, i, xShapeInfo, coords); - - const auto xOffset = shape::getOffset(xShapeInfo, coords); - const auto zOffset = areSameOffsets ? xOffset : shape::getOffset(zShapeInfo, coords); - - // condition to be on diagonal of innermost matrix - if (coords[xRank - 2] == coords[xRank - 1]) - z[zOffset] = y[shape::getOffset(yShapeInfo, coords)]; - else - z[zOffset] = zeroPad ? static_cast(0) : x[xOffset]; - } - }; - samediff::Threads::parallel_for(func, 0, xLen); +template +void matrixSetDiag_(const NDArray& input, const NDArray& diagonal, + NDArray& output, const bool zeroPad) { + // input and output are the same array (x == z) when zeroPad = true + // xRank = zRank, xRank = yRank + 1 + // xLen = zLen + + const T* x = input.bufferAsT(); + const T* y = diagonal.bufferAsT(); + T* z = output.bufferAsT(); + + const Nd4jLong* xShapeInfo = input.shapeInfo(); + const Nd4jLong* yShapeInfo = diagonal.shapeInfo(); + const Nd4jLong* zShapeInfo = output.shapeInfo(); + + const bool areSameOffsets = shape::haveSameShapeAndStrides( + xShapeInfo, + zShapeInfo); // shapes are definitely the same, but strides might not + + const int xRank = input.rankOf(); + const auto xLen = input.lengthOf(); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + + for (Nd4jLong i = 0; i < xLen; ++i) { + shape::index2coordsCPU(start, i, xShapeInfo, coords); + + const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto zOffset = + areSameOffsets ? xOffset : shape::getOffset(zShapeInfo, coords); + + // condition to be on diagonal of innermost matrix + if (coords[xRank - 2] == coords[xRank - 1]) + z[zOffset] = y[shape::getOffset(yShapeInfo, coords)]; + else + z[zOffset] = zeroPad ? static_cast(0) : x[xOffset]; + } + }; + samediff::Threads::parallel_for(func, 0, xLen); } ////////////////////////////////////////////////////////////////////////// -void matrixSetDiag(sd::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) { - BUILD_SINGLE_SELECTOR(input.dataType(), matrixSetDiag_, (input, diagonal, output, zeroPad), LIBND4J_TYPES); +void matrixSetDiag(sd::LaunchContext* context, const NDArray& input, + const NDArray& diagonal, NDArray& output, + const bool zeroPad) { + BUILD_SINGLE_SELECTOR(input.dataType(), matrixSetDiag_, + (input, diagonal, output, zeroPad), LIBND4J_TYPES); } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp index 660184afd93d..0e2d5b2ffff0 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/matrix_band.cpp @@ -23,51 +23,63 @@ namespace sd { namespace ops { namespace helpers { - template - void matrixBandPart_(NDArray* input, NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand) { - // TO DO: retrieve all 2D submatricies with last dimensions and process them with given bands - Nd4jLong M = input->sizeAt(-2); - Nd4jLong N = input->sizeAt(-1); - Nd4jLong lastDim = input->rankOf() - 1; - Nd4jLong preLastDim = input->rankOf() - 2; - ResultSet listOut = output->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}); - ResultSet listDiag = input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}); - for (Nd4jLong e = 0; e < static_cast(listOut.size()); ++e) { - auto inputMatrix = listDiag.at(e); - auto outputMatrix = listOut.at(e); - if (outputMatrix.platformBuffer() != inputMatrix.platformBuffer()) // if not inplace - outputMatrix.assign(inputMatrix); - if (lowerBand >= 0) { - for (Nd4jLong row = 0; row < inputMatrix.rows(); ++row) { - for (Nd4jLong col = 0; col < row; ++col) { - if ((row - col) > lowerBand) - outputMatrix.p(row, col, 0.); -// else - // (*outputMatrix)(row, col) = (*inputMatrix)(row, col); - } -// in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && (num_upper < 0 || (n-m) <= num_upper). - } - } - if (upperBand >= 0) { - for (Nd4jLong col = 0; col < inputMatrix.columns(); ++col) { - for (Nd4jLong row = 0; row < col; ++row) { - if ((col - row) > upperBand) - outputMatrix.p(row, col, 0.); -// else - // (*outputMatrix)(row, col) = (*inputMatrix)(row, col); - } -// in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && (num_upper < 0 || (n-m) <= num_upper). - } - - } +template +void matrixBandPart_(NDArray* input, NDArray* output, Nd4jLong lowerBand, + Nd4jLong upperBand) { + // TO DO: retrieve all 2D submatricies with last dimensions and process them + // with given bands + Nd4jLong M = input->sizeAt(-2); + Nd4jLong N = input->sizeAt(-1); + Nd4jLong lastDim = input->rankOf() - 1; + Nd4jLong preLastDim = input->rankOf() - 2; + ResultSet listOut = + output->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}); + ResultSet listDiag = + input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim}); + for (Nd4jLong e = 0; e < static_cast(listOut.size()); ++e) { + auto inputMatrix = listDiag.at(e); + auto outputMatrix = listOut.at(e); + if (outputMatrix.platformBuffer() != + inputMatrix.platformBuffer()) // if not inplace + outputMatrix.assign(inputMatrix); + if (lowerBand >= 0) { + for (Nd4jLong row = 0; row < inputMatrix.rows(); ++row) { + for (Nd4jLong col = 0; col < row; ++col) { + if ((row - col) > lowerBand) outputMatrix.p(row, col, 0.); + // else + // (*outputMatrix)(row, col) = + // (*inputMatrix)(row, col); } + // in_band(m, n) = (num_lower < 0 || (m-n) <= + // num_lower)) && (num_upper < 0 || (n-m) <= + // num_upper). + } } - - void matrixBandPart(sd::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand) { - BUILD_SINGLE_SELECTOR(input->dataType(), matrixBandPart_, (input, output, lowerBand, upperBand), FLOAT_TYPES); + if (upperBand >= 0) { + for (Nd4jLong col = 0; col < inputMatrix.columns(); ++col) { + for (Nd4jLong row = 0; row < col; ++row) { + if ((col - row) > upperBand) outputMatrix.p(row, col, 0.); + // else + // (*outputMatrix)(row, col) = + // (*inputMatrix)(row, col); + } + // in_band(m, n) = (num_lower < 0 || (m-n) <= + // num_lower)) && (num_upper < 0 || (n-m) <= + // num_upper). + } } - BUILD_SINGLE_TEMPLATE(template void matrixBandPart_, (NDArray* input, NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand), FLOAT_TYPES); -} -} + } } +void matrixBandPart(sd::LaunchContext* context, NDArray* input, NDArray* output, + Nd4jLong lowerBand, Nd4jLong upperBand) { + BUILD_SINGLE_SELECTOR(input->dataType(), matrixBandPart_, + (input, output, lowerBand, upperBand), FLOAT_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void matrixBandPart_, + (NDArray * input, NDArray* output, Nd4jLong lowerBand, + Nd4jLong upperBand), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp index afd4aa64d597..137d558bc541 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/matrix_diag_part.cpp @@ -19,49 +19,52 @@ // #include -#include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// // Returns a batched matrix tensor with new batched diagonal values. -// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag +// for detailed explanations please take a look on web page: +// https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag template int _matrixDiagPart(const NDArray* input, NDArray* output) { + auto listOut = output->allTensorsAlongDimension({output->rankOf() - 1}); + auto listDiag = input->allTensorsAlongDimension( + {input->rankOf() - 2, input->rankOf() - 1}); - auto listOut = output->allTensorsAlongDimension({output->rankOf() - 1}); - auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf() - 1}); + if (listOut.size() != listDiag.size()) { + nd4j_printf("matrix_diag_part: Input matrix has wrong shape.", ""); + return ND4J_STATUS_VALIDATION; + } + int lastDimension = sd::math::nd4j_min(input->sizeAt(-2), input->sizeAt(-1)); + // TODO: tune this properlys + int lO = listOut.size(); - if (listOut.size() != listDiag. size()) { - nd4j_printf("matrix_diag_part: Input matrix has wrong shape.", ""); - return ND4J_STATUS_VALIDATION; - } - int lastDimension = sd::math::nd4j_min(input->sizeAt(-2), input->sizeAt(-1)); - // TODO: tune this properlys - int lO = listOut.size(); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + for (int j = 0; j < lastDimension; ++j) + listOut.at(i).p(j, listDiag.at(i).e(j, j)); + }; - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - for (int j = 0; j < lastDimension; ++j) - listOut.at(i).p(j, listDiag.at(i).e(j, j)); - }; + samediff::Threads::parallel_tad(func, 0, lO); - samediff::Threads::parallel_tad(func, 0, lO); - - return Status::OK(); + return Status::OK(); } - int matrixDiagPart(sd::LaunchContext * context, const NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiagPart, (input, output), LIBND4J_TYPES); - } +int matrixDiagPart(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiagPart, + (input, output), LIBND4J_TYPES); +} - BUILD_SINGLE_TEMPLATE(template int _matrixDiagPart, (const NDArray* input, NDArray* output), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template int _matrixDiagPart, + (const NDArray* input, NDArray* output), LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp index a458b5eff15d..c7c32a37612b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp @@ -18,64 +18,71 @@ // @author raver119@gmail.com // -#include #include - +#include namespace sd { namespace ops { namespace helpers { - template - static void maxPoolingFunctor_(sd::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices) { - - int kY = params[0]; - int kX = params[1]; +template +static void maxPoolingFunctor_(sd::graph::Context& block, NDArray* input, + NDArray* values, std::vector const& params, + NDArray* indices) { + int kY = params[0]; + int kX = params[1]; - int sY = params[2]; - int sX = params[3]; + int sY = params[2]; + int sX = params[3]; - int pY = params[4]; - int pX = params[5]; + int pY = params[4]; + int pX = params[5]; - int dY = params[6]; - int dX = params[7]; + int dY = params[6]; + int dX = params[7]; - int oY = 0; - int oX = 0; + int oY = 0; + int oX = 0; - const int bSize = input->sizeAt(0); - const int inD = input->sizeAt(1); - const int inY = input->sizeAt(2); - const int inX = input->sizeAt(3); + const int bSize = input->sizeAt(0); + const int inD = input->sizeAt(1); + const int inY = input->sizeAt(2); + const int inX = input->sizeAt(3); - const bool isSameMode = params[8] != 0; + const bool isSameMode = params[8] != 0; - ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, inY, inX, isSameMode); + ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, + inY, inX, isSameMode); - if (isSameMode) - ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, params[0], params[1], params[2], params[3], params[6], params[7]); + if (isSameMode) + ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, params[0], + params[1], params[2], params[3], params[6], + params[7]); - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling2d(block, *input, *values, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::MAX_POOL, 1); - - if (nullptr != indices) { - // for max_pool_with_argmax - int total = input->lengthOf(); - int part = total / bSize; - - for (int k = 0; k < total; ) - for (int i = 0; i < part; i++) { - indices->p(k++, i); - } - } + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; + ConvolutionUtils::pooling2d(block, *input, *values, kY, kX, sY, sX, pY, pX, + dY, dX, PoolingType::MAX_POOL, 1); - } - - void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices) { - BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES); - } + if (nullptr != indices) { + // for max_pool_with_argmax + int total = input->lengthOf(); + int part = total / bSize; + for (int k = 0; k < total;) + for (int i = 0; i < part; i++) { + indices->p(k++, i); + } + } } + +void maxPoolingFunctor(sd::LaunchContext* context, sd::graph::Context& block, + NDArray* input, NDArray* values, + std::vector const& params, NDArray* indices) { + BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, + (block, input, values, params, indices), FLOAT_TYPES); } -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/merge.cpp b/libnd4j/include/ops/declarable/helpers/cpu/merge.cpp index 7874d6d67ba7..4cdae35293df 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/merge.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/merge.cpp @@ -20,258 +20,268 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// -template -static void mergeMaxIndex_(const std::vector& inArrs, NDArray& output) { - - const Nd4jLong numArgs = inArrs.size(); - auto x = inArrs[0]; - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - T max = -DataTypeUtils::max(); - Nd4jLong idx = 0; - - for (Nd4jLong i = 0; i < numArgs; i++) { - T v = inArrs[i]->e(e); - if (v > max) { - max = v; - idx = i; - } - } - output.p(e, idx); +template +static void mergeMaxIndex_(const std::vector& inArrs, + NDArray& output) { + const Nd4jLong numArgs = inArrs.size(); + auto x = inArrs[0]; + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + T max = -DataTypeUtils::max(); + Nd4jLong idx = 0; + + for (Nd4jLong i = 0; i < numArgs; i++) { + T v = inArrs[i]->e(e); + if (v > max) { + max = v; + idx = i; } - }; + } + output.p(e, idx); + } + }; - samediff::Threads::parallel_for(func, 0, x->lengthOf()); + samediff::Threads::parallel_for(func, 0, x->lengthOf()); } -void mergeMaxIndex(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output) { - BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), mergeMaxIndex_, (inArrs, output), LIBND4J_TYPES); +void mergeMaxIndex(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), mergeMaxIndex_, (inArrs, output), + LIBND4J_TYPES); } - ////////////////////////////////////////////////////////////////////////// -template -static void mergeMax_(const std::vector& inArrs, NDArray& output) { - - const Nd4jLong numArgs = inArrs.size(); - auto x = inArrs[0]; - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - T max = -DataTypeUtils::max(); - for (Nd4jLong i = 0; i < numArgs; i++) { - T v = inArrs[i]->e(e); - if (v > max) - max = v; - } - output.p(e, max); - } - }; +template +static void mergeMax_(const std::vector& inArrs, + NDArray& output) { + const Nd4jLong numArgs = inArrs.size(); + auto x = inArrs[0]; + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + T max = -DataTypeUtils::max(); + for (Nd4jLong i = 0; i < numArgs; i++) { + T v = inArrs[i]->e(e); + if (v > max) max = v; + } + output.p(e, max); + } + }; - samediff::Threads::parallel_for(func, 0, x->lengthOf()); + samediff::Threads::parallel_for(func, 0, x->lengthOf()); } -void mergeMax(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output) { - BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (inArrs, output), LIBND4J_TYPES); +void mergeMax(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (inArrs, output), + LIBND4J_TYPES); } - ////////////////////////////////////////////////////////////////////////// -template -static void mergeMaxBp_(const std::vector& inArrs, std::vector& outArrs) { - - // outArrs.size() == inArrs.size() - 1 - const Nd4jLong numArgs = outArrs.size(); - // last array is gradient - const auto gradient = inArrs[numArgs]->bufferAsT(); - auto length = inArrs[numArgs]->lengthOf(); - - bool bSameOrderAndEws1 = (1 == inArrs[numArgs]->ews()); - - if (bSameOrderAndEws1) { - auto gradOrdering = inArrs[numArgs]->ordering(); - - for (int i = 0; i < numArgs; ++i) { - bSameOrderAndEws1 &= (gradOrdering == inArrs[i]->ordering()); - bSameOrderAndEws1 &= (1 == inArrs[i]->ews()); - bSameOrderAndEws1 &= (gradOrdering == outArrs[i]->ordering()); - bSameOrderAndEws1 &= (1 == outArrs[i]->ews()); - } - } - - - if(bSameOrderAndEws1){ - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - T max = -DataTypeUtils::max(); - Nd4jLong nMaxIndex = 0; - for (Nd4jLong i = 0; i < numArgs; i++) { - const T* v = inArrs[i]->bufferAsT(); - if (v[e] > max) { - max = v[e]; - nMaxIndex = i; - } - } - T* z = outArrs[nMaxIndex]->bufferAsT(); - z[e] = gradient[e]; - } - }; - - samediff::Threads::parallel_for(func, 0, length); - return; - } - - auto gradShape = inArrs[numArgs]->shapeInfo(); - std::vector vbSameShaepeAndStrides(numArgs); +template +static void mergeMaxBp_(const std::vector& inArrs, + std::vector& outArrs) { + // outArrs.size() == inArrs.size() - 1 + const Nd4jLong numArgs = outArrs.size(); + // last array is gradient + const auto gradient = inArrs[numArgs]->bufferAsT(); + auto length = inArrs[numArgs]->lengthOf(); + + bool bSameOrderAndEws1 = (1 == inArrs[numArgs]->ews()); + + if (bSameOrderAndEws1) { + auto gradOrdering = inArrs[numArgs]->ordering(); + for (int i = 0; i < numArgs; ++i) { - vbSameShaepeAndStrides[i] = shape::haveSameShapeAndStrides(gradShape, inArrs[i]->shapeInfo()); + bSameOrderAndEws1 &= (gradOrdering == inArrs[i]->ordering()); + bSameOrderAndEws1 &= (1 == inArrs[i]->ews()); + bSameOrderAndEws1 &= (gradOrdering == outArrs[i]->ordering()); + bSameOrderAndEws1 &= (1 == outArrs[i]->ews()); } + } - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto e = start; e < stop; e++) { - - shape::index2coordsCPU(start, e, gradShape, coords); - - const auto gradOffset = shape::getOffset(gradShape, coords); - - T max = -DataTypeUtils::max(); - Nd4jLong nMaxIndex = 0; - - for (Nd4jLong i = 0; i < numArgs; i++) { - - const auto xOffset = vbSameShaepeAndStrides[i] ? gradOffset : shape::getOffset(inArrs[i]->shapeInfo(), coords); - const T* v = inArrs[i]->bufferAsT(); - if (v[xOffset] > max) { - max = v[xOffset]; - nMaxIndex = i; - } - } - - const auto zOffset = vbSameShaepeAndStrides[nMaxIndex] ? gradOffset : shape::getOffset(outArrs[nMaxIndex]->shapeInfo(), coords); - - T* z = outArrs[nMaxIndex]->bufferAsT(); - z[zOffset] = gradient[gradOffset]; - } + if (bSameOrderAndEws1) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + T max = -DataTypeUtils::max(); + Nd4jLong nMaxIndex = 0; + for (Nd4jLong i = 0; i < numArgs; i++) { + const T* v = inArrs[i]->bufferAsT(); + if (v[e] > max) { + max = v[e]; + nMaxIndex = i; + } + } + T* z = outArrs[nMaxIndex]->bufferAsT(); + z[e] = gradient[e]; + } }; samediff::Threads::parallel_for(func, 0, length); return; -} - -void mergeMaxBp(sd::LaunchContext* context, const std::vector& inArrs, std::vector& outArrs) { - BUILD_SINGLE_SELECTOR(outArrs[0]->dataType(), mergeMaxBp_, (inArrs, outArrs), LIBND4J_TYPES); -} + } + + auto gradShape = inArrs[numArgs]->shapeInfo(); + std::vector vbSameShaepeAndStrides(numArgs); + for (int i = 0; i < numArgs; ++i) { + vbSameShaepeAndStrides[i] = + shape::haveSameShapeAndStrides(gradShape, inArrs[i]->shapeInfo()); + } + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto e = start; e < stop; e++) { + shape::index2coordsCPU(start, e, gradShape, coords); + + const auto gradOffset = shape::getOffset(gradShape, coords); + + T max = -DataTypeUtils::max(); + Nd4jLong nMaxIndex = 0; + + for (Nd4jLong i = 0; i < numArgs; i++) { + const auto xOffset = + vbSameShaepeAndStrides[i] + ? gradOffset + : shape::getOffset(inArrs[i]->shapeInfo(), coords); + const T* v = inArrs[i]->bufferAsT(); + if (v[xOffset] > max) { + max = v[xOffset]; + nMaxIndex = i; + } + } -////////////////////////////////////////////////////////////////////////// -template -static void mergeAvg_(const std::vector& inArrs, NDArray& output) { - const Nd4jLong numArgs = inArrs.size(); - const T factor = 1.f / numArgs; - auto x = inArrs[0]; + const auto zOffset = + vbSameShaepeAndStrides[nMaxIndex] + ? gradOffset + : shape::getOffset(outArrs[nMaxIndex]->shapeInfo(), coords); - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - T sum = 0.; - for (Nd4jLong i = 0; i < numArgs; i++) { - T v = inArrs[i]->e(e); - sum += v; - } - output.p(e, sum * factor); - } - }; + T* z = outArrs[nMaxIndex]->bufferAsT(); + z[zOffset] = gradient[gradOffset]; + } + }; - samediff::Threads::parallel_for(func, 0, x->lengthOf()); + samediff::Threads::parallel_for(func, 0, length); + return; } -void mergeAvg(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output) { - BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (inArrs, output), LIBND4J_TYPES); +void mergeMaxBp(sd::LaunchContext* context, + const std::vector& inArrs, + std::vector& outArrs) { + BUILD_SINGLE_SELECTOR(outArrs[0]->dataType(), mergeMaxBp_, (inArrs, outArrs), + LIBND4J_TYPES); } ////////////////////////////////////////////////////////////////////////// -template -static void mergeAvgBp_(const NDArray& gradient, std::vector& outArrs) { - - const Nd4jLong numArgs = outArrs.size(); - - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - - T v = gradient.e(e) / numArgs; - - for (Nd4jLong i = 0; i < numArgs; i++) { - outArrs[i]->p(e, v); - } - } - }; +template +static void mergeAvg_(const std::vector& inArrs, + NDArray& output) { + const Nd4jLong numArgs = inArrs.size(); + const T factor = 1.f / numArgs; + auto x = inArrs[0]; + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + T sum = 0.; + for (Nd4jLong i = 0; i < numArgs; i++) { + T v = inArrs[i]->e(e); + sum += v; + } + output.p(e, sum * factor); + } + }; - samediff::Threads::parallel_for(func, 0, gradient.lengthOf()); + samediff::Threads::parallel_for(func, 0, x->lengthOf()); } -void mergeAvgBp(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAvgBp_, (gradient, outArrs), LIBND4J_TYPES); +void mergeAvg(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (inArrs, output), + LIBND4J_TYPES); } - ////////////////////////////////////////////////////////////////////////// -template -static void mergeAdd_(const std::vector& inArrs, NDArray& output) { - - const Nd4jLong numArgs = inArrs.size(); - auto x = inArrs[0]; - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - T sum = (T) 0.f; - for (Nd4jLong i = 0; i < numArgs; i++) - sum += inArrs[i]->e(e); +template +static void mergeAvgBp_(const NDArray& gradient, + std::vector& outArrs) { + const Nd4jLong numArgs = outArrs.size(); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + T v = gradient.e(e) / numArgs; + + for (Nd4jLong i = 0; i < numArgs; i++) { + outArrs[i]->p(e, v); + } + } + }; - output.p(e, sum); - } - }; + samediff::Threads::parallel_for(func, 0, gradient.lengthOf()); +} - samediff::Threads::parallel_for(func, 0, x->lengthOf()); +void mergeAvgBp(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAvgBp_, (gradient, outArrs), + LIBND4J_TYPES); } - void mergeAdd(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output) { - BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (inArrs, output), LIBND4J_TYPES); - } ////////////////////////////////////////////////////////////////////////// -template -static void mergeAddBp_(const NDArray& gradient, std::vector& outArrs) { - - const Nd4jLong numArgs = outArrs.size(); - - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - - T v = gradient.e(e); - - for (Nd4jLong i = 0; i < numArgs; i++) { - outArrs[i]->p(e, v); - } - } - }; +template +static void mergeAdd_(const std::vector& inArrs, + NDArray& output) { + const Nd4jLong numArgs = inArrs.size(); + auto x = inArrs[0]; + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + T sum = (T)0.f; + for (Nd4jLong i = 0; i < numArgs; i++) sum += inArrs[i]->e(e); + + output.p(e, sum); + } + }; - samediff::Threads::parallel_for(func, 0, gradient.lengthOf()); + samediff::Threads::parallel_for(func, 0, x->lengthOf()); } - -void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAddBp_, (gradient, outArrs), LIBND4J_TYPES); +void mergeAdd(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (inArrs, output), + LIBND4J_TYPES); } +////////////////////////////////////////////////////////////////////////// +template +static void mergeAddBp_(const NDArray& gradient, + std::vector& outArrs) { + const Nd4jLong numArgs = outArrs.size(); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + T v = gradient.e(e); + + for (Nd4jLong i = 0; i < numArgs; i++) { + outArrs[i]->p(e, v); + } + } + }; + samediff::Threads::parallel_for(func, 0, gradient.lengthOf()); } + +void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAddBp_, (gradient, outArrs), + LIBND4J_TYPES); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp b/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp index a6c167f27f87..7d98880ae34f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/meshgrid.cpp @@ -18,36 +18,33 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.04.2018 // - -#include #include +#include + #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// -void meshgrid(sd::LaunchContext * context, const std::vector& inArrs, const std::vector& outArrs, const bool swapFirst2Dims) { - - const int rank = inArrs.size(); - int inIndices[MAX_RANK]; - std::iota(inIndices, inIndices + rank, 0); - if(swapFirst2Dims && rank > 1) { - inIndices[0] = 1; - inIndices[1] = 0; - } - - for(int i = 0; i < rank; ++i) { - auto list = outArrs[i]->allTensorsAlongDimension({inIndices[i]}); - for(int j = 0; j < list.size(); ++j) - list.at(j).assign(inArrs[i]); - } -} - -} -} +void meshgrid(sd::LaunchContext* context, const std::vector& inArrs, + const std::vector& outArrs, const bool swapFirst2Dims) { + const int rank = inArrs.size(); + int inIndices[MAX_RANK]; + std::iota(inIndices, inIndices + rank, 0); + if (swapFirst2Dims && rank > 1) { + inIndices[0] = 1; + inIndices[1] = 0; + } + + for (int i = 0; i < rank; ++i) { + auto list = outArrs[i]->allTensorsAlongDimension({inIndices[i]}); + for (int j = 0; j < list.size(); ++j) list.at(j).assign(inArrs[i]); + } } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp index 6174151d6874..ce7fffb97381 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/minimax.cpp @@ -19,159 +19,160 @@ // #ifndef __MIN_I_MAX_H_HELPERS__ #define __MIN_I_MAX_H_HELPERS__ -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - template - static void minimumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { - - auto lambdaX = LAMBDA_TTT(_e, _x, _y) { - return _x <= _y ? _e : (T) 0.; - }; - - auto lambdaY = LAMBDA_TTT(_e, _x, _y) { - return _x >= _y ? _e : (T) 0.; - }; - - - if (x->isSameShape(y)) { - // PWT case case - - // X gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); - - // Y gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); - - } else if (y->isScalar()) { - T s = y->e(0); - auto lambdaS = LAMBDA_TT(_e, _x, s) { - return _x <= s ? _e : (T) 0.; - }; - - // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - if (x <= y) - gradY->assign(tmp); - else - gradY->assign(0.0f); - - epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); - } else { - // broadcast case - - // in this case we want to boost our X and Y shapes to the size of FF pass output (or epsNext, which has the same shape) - auto preX = x->dup(); - auto preY = y->dup(); - - auto targetShape = epsNext->getShapeAsVector(); +template +static void minimumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, + NDArray* gradX, NDArray* gradY) { + auto lambdaX = LAMBDA_TTT(_e, _x, _y) { return _x <= _y ? _e : (T)0.; }; - preX.tileToShape(targetShape, preX); - preY.tileToShape(targetShape, preY); + auto lambdaY = LAMBDA_TTT(_e, _x, _y) { return _x >= _y ? _e : (T)0.; }; - epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); - epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); + if (x->isSameShape(y)) { + // PWT case case - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); + // X gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); + // Y gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } + } else if (y->isScalar()) { + T s = y->e(0); + auto lambdaS = LAMBDA_TT(_e, _x, s) { return _x <= s ? _e : (T)0.; }; - } - template - void maximumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { + // scalar case + auto tmp = epsNext->reduceNumber(reduce::Sum); + if (x <= y) + gradY->assign(tmp); + else + gradY->assign(0.0f); - auto lambdaX = LAMBDA_TTT(_e, _x, _y) { - return _x >= _y ? _e : (T) 0.; - }; + epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); + } else { + // broadcast case - auto lambdaY = LAMBDA_TTT(_e, _x, _y) { - return _x <= _y ? _e : (T) 0.; - }; + // in this case we want to boost our X and Y shapes to the size of FF pass + // output (or epsNext, which has the same shape) + auto preX = x->dup(); + auto preY = y->dup(); + auto targetShape = epsNext->getShapeAsVector(); - if (x->isSameShape(y)) { - // PWT case case + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); - // X gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); + epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); + epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); - // Y gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); - } else if (y->isScalar()) { - T s = y->e(0); - auto lambdaS = LAMBDA_TT(_e, _x, s) { - return _x >= s ? _e : (T) 0.; - }; - - // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - if (x <= y) - gradY->assign(tmp); - else - gradY->assign(0.0f); - - epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); - } else { - // broadcast case - - // in this case we want to boost our X and Y shapes to the size of FF pass output (or epsNext, which has the same shape) - auto preX = x->dup(); - auto preY = y->dup(); - - auto targetShape = epsNext->getShapeAsVector(); - - preX.tileToShape(targetShape, preX); - preY.tileToShape(targetShape, preY); - - epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); - epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); - - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); - - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } - } - - void minimumBPFunctor(sd::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { - BUILD_SINGLE_SELECTOR(x->dataType(), minimumBPFunctor_, (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); - } - - void maximumBPFunctor(sd::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { - BUILD_SINGLE_SELECTOR(x->dataType(), maximumBPFunctor_, (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template void minimumBPFunctor_, (NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY), NUMERIC_TYPES); - BUILD_SINGLE_TEMPLATE(template void maximumBPFunctor_, (NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY), NUMERIC_TYPES); + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } } +template +void maximumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, + NDArray* gradY) { + auto lambdaX = LAMBDA_TTT(_e, _x, _y) { return _x >= _y ? _e : (T)0.; }; + + auto lambdaY = LAMBDA_TTT(_e, _x, _y) { return _x <= _y ? _e : (T)0.; }; + + if (x->isSameShape(y)) { + // PWT case case + + // X gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); + + // Y gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); + + } else if (y->isScalar()) { + T s = y->e(0); + auto lambdaS = LAMBDA_TT(_e, _x, s) { return _x >= s ? _e : (T)0.; }; + + // scalar case + auto tmp = epsNext->reduceNumber(reduce::Sum); + if (x <= y) + gradY->assign(tmp); + else + gradY->assign(0.0f); + + epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); + } else { + // broadcast case + + // in this case we want to boost our X and Y shapes to the size of FF pass + // output (or epsNext, which has the same shape) + auto preX = x->dup(); + auto preY = y->dup(); + + auto targetShape = epsNext->getShapeAsVector(); + + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); + + epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); + epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); + + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); + + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } } + +void minimumBPFunctor(sd::LaunchContext* context, NDArray* x, NDArray* y, + NDArray* epsNext, NDArray* gradX, NDArray* gradY) { + BUILD_SINGLE_SELECTOR(x->dataType(), minimumBPFunctor_, + (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); +} + +void maximumBPFunctor(sd::LaunchContext* context, NDArray* x, NDArray* y, + NDArray* epsNext, NDArray* gradX, NDArray* gradY) { + BUILD_SINGLE_SELECTOR(x->dataType(), maximumBPFunctor_, + (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); } +BUILD_SINGLE_TEMPLATE(template void minimumBPFunctor_, + (NDArray * x, NDArray* y, NDArray* epsNext, + NDArray* gradX, NDArray* gradY), + NUMERIC_TYPES); +BUILD_SINGLE_TEMPLATE(template void maximumBPFunctor_, + (NDArray * x, NDArray* y, NDArray* epsNext, + NDArray* gradX, NDArray* gradY), + NUMERIC_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp b/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp index 90d687aba3b5..642e857dbaa2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/nth_element.cpp @@ -18,59 +18,70 @@ // @author sgazeos@gmail.com // -#include -#include -#include -#include #include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - template - void nthElementFunctor_(NDArray* input, Nd4jLong n, NDArray* output, bool reverse) { +template +void nthElementFunctor_(NDArray* input, Nd4jLong n, NDArray* output, + bool reverse) { + NDArray sortedVals(*input); + if (input->isVector()) { + // std::vector data(input->lengthOf()); + // memcpy(&data[0], input->buffer(), sizeof(T) * data.size()); + // size_t l = 0; + // for (size_t l = 0; l < data.size(); ++l) + // data[l] = input->e(l); + // auto nthPos = data.begin(); + // nthPos += n; + // std::nth_element(data.begin(), nthPos, data.end()); + SpecialMethods::sortGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), + reverse); + output->p(0, sortedVals.e(n)); + } else { // rank greater than 1 + std::vector lastDims( + {input->rankOf() - + 1}); // = ShapeUtils::evalDimsToExclude(input->rankOf(), + // {input->rankOf() - 1}); - NDArray sortedVals(*input); - if (input->isVector()) { - //std::vector data(input->lengthOf()); - //memcpy(&data[0], input->buffer(), sizeof(T) * data.size()); - //size_t l = 0; - //for (size_t l = 0; l < data.size(); ++l) - // data[l] = input->e(l); - //auto nthPos = data.begin(); - //nthPos += n; - //std::nth_element(data.begin(), nthPos, data.end()); - SpecialMethods::sortGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), reverse); - output->p(0, sortedVals.e(n)); - } - else { // rank greater than 1 - std::vector lastDims({input->rankOf() - 1});// = ShapeUtils::evalDimsToExclude(input->rankOf(), {input->rankOf() - 1}); + auto pack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + sortedVals.shapeInfo(), lastDims); - auto pack = sd::ConstantTadHelper::getInstance()->tadForDimensions(sortedVals.shapeInfo(), lastDims); + SpecialMethods::sortTadGeneric(sortedVals.buffer(), + sortedVals.shapeInfo(), lastDims.data(), + lastDims.size(), pack.primaryShapeInfo(), + pack.primaryOffsets(), reverse); - SpecialMethods::sortTadGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), lastDims.data(), lastDims.size(), pack.primaryShapeInfo(), pack.primaryOffsets(), reverse); + ResultSet rows = sortedVals.allTensorsAlongDimension(lastDims); + Nd4jLong oL = output->lengthOf(); - ResultSet rows = sortedVals.allTensorsAlongDimension(lastDims); - Nd4jLong oL = output->lengthOf(); - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto row = rows.at(e); - output->p(e, row.e(n)); - } - }; - - samediff::Threads::parallel_for(func, 0, oL); - } - } - - void nthElementFunctor(sd::LaunchContext *launchContext, NDArray* input, Nd4jLong n, NDArray* output, bool reverse) { - BUILD_SINGLE_SELECTOR(input->dataType(), nthElementFunctor_, (input, n, output, reverse), LIBND4J_TYPES); - - } - BUILD_SINGLE_TEMPLATE(template void nthElementFunctor_, (NDArray* input, Nd4jLong n, NDArray* output, bool reverse), LIBND4J_TYPES); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto row = rows.at(e); + output->p(e, row.e(n)); + } + }; + samediff::Threads::parallel_for(func, 0, oL); + } } + +void nthElementFunctor(sd::LaunchContext* launchContext, NDArray* input, + Nd4jLong n, NDArray* output, bool reverse) { + BUILD_SINGLE_SELECTOR(input->dataType(), nthElementFunctor_, + (input, n, output, reverse), LIBND4J_TYPES); } -} +BUILD_SINGLE_TEMPLATE(template void nthElementFunctor_, + (NDArray * input, Nd4jLong n, NDArray* output, + bool reverse), + LIBND4J_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/one_hot.cpp b/libnd4j/include/ops/declarable/helpers/cpu/one_hot.cpp index 2aa14585bbc3..8a717efd90d4 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/one_hot.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/one_hot.cpp @@ -18,86 +18,102 @@ // @author raver119@gmail.com // -#include -#include -#include #include "../one_hot.h" +#include +#include +#include + namespace sd { - namespace ops { - namespace helpers { - template - static void onehot_(void *voutput, Nd4jLong const* zShapeInfo, void const* vindices, Nd4jLong const* iShapeInfo, int axis, double on, double off) { - auto output = reinterpret_cast(voutput); - auto indices = reinterpret_cast(vindices); - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(zShapeInfo, {axis}); - - auto iLen = static_cast(shape::length(iShapeInfo)); - auto tLen = static_cast(shape::length(tadPack.primaryShapeInfo())); - auto numTads = static_cast(tadPack.numberOfTads()); - auto tadEws = shape::elementWiseStride(tadPack.primaryShapeInfo()); - - if (iLen != numTads) - throw std::runtime_error("OneHot: number of TADs should be equal to number of indices"); - - if (shape::elementWiseStride(zShapeInfo) != 1 || shape::elementWiseStride(iShapeInfo) != 1) - throw std::runtime_error("OneHot: op expects output and indices to have elementWiseStride to be equal to 1"); - - Z zero = static_cast(off); - Z one = static_cast(on); - - if (tadEws >= 1) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = 0; e < stop; e++) { - auto cO = output + tadPack.primaryOffsets()[e]; - - auto idx = static_cast(indices[e]); - if (idx < 0 || idx >= tLen) { - PRAGMA_OMP_SIMD - for (unsigned int t = 0; t < tLen; t++) { - cO[t * tadEws] = zero; - } - } else { - PRAGMA_OMP_SIMD - for (unsigned int t = 0; t < tLen; t++) { - cO[t * tadEws] = idx == t ? one : zero; - } - } - } - }; - - samediff::Threads::parallel_tad(func, 0, numTads); - } else { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto cO = output + tadPack.primaryOffsets()[e]; - - auto idx = static_cast(indices[e]); - if (idx < 0 || idx >= tLen) { - PRAGMA_OMP_SIMD - for (unsigned int t = 0; t < tLen; t++) { - cO[shape::getIndexOffset(t, tadPack.primaryShapeInfo())] = zero; - } - } else { - PRAGMA_OMP_SIMD - for (unsigned int t = 0; t < tLen; t++) { - cO[shape::getIndexOffset(t, tadPack.primaryShapeInfo())] = idx == t ? one : zero; - } - } - } - }; - - samediff::Threads::parallel_tad(func, 0, numTads); - } - } - - void onehot(const sd::LaunchContext* context, const NDArray *indices, NDArray *output, const uint axis, const uint depth, const double on, const double off) { - auto zType = output->dataType(); - auto iType = indices->dataType(); - - BUILD_DOUBLE_SELECTOR(zType, iType, onehot_, (output->buffer(), output->shapeInfo(), indices->buffer(), indices->shapeInfo(), axis, on, off), LIBND4J_TYPES, LIBND4J_TYPES); - } +namespace ops { +namespace helpers { +template +static void onehot_(void* voutput, Nd4jLong const* zShapeInfo, + void const* vindices, Nd4jLong const* iShapeInfo, int axis, + double on, double off) { + auto output = reinterpret_cast(voutput); + auto indices = reinterpret_cast(vindices); + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + zShapeInfo, {axis}); + + auto iLen = static_cast(shape::length(iShapeInfo)); + auto tLen = + static_cast(shape::length(tadPack.primaryShapeInfo())); + auto numTads = static_cast(tadPack.numberOfTads()); + auto tadEws = shape::elementWiseStride(tadPack.primaryShapeInfo()); + + if (iLen != numTads) + throw std::runtime_error( + "OneHot: number of TADs should be equal to number of indices"); + + if (shape::elementWiseStride(zShapeInfo) != 1 || + shape::elementWiseStride(iShapeInfo) != 1) + throw std::runtime_error( + "OneHot: op expects output and indices to have elementWiseStride to be " + "equal to 1"); + + Z zero = static_cast(off); + Z one = static_cast(on); + + if (tadEws >= 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = 0; e < stop; e++) { + auto cO = output + tadPack.primaryOffsets()[e]; + + auto idx = static_cast(indices[e]); + if (idx < 0 || idx >= tLen) { + PRAGMA_OMP_SIMD + for (unsigned int t = 0; t < tLen; t++) { + cO[t * tadEws] = zero; + } + } else { + PRAGMA_OMP_SIMD + for (unsigned int t = 0; t < tLen; t++) { + cO[t * tadEws] = idx == t ? one : zero; + } + } + } + }; + + samediff::Threads::parallel_tad(func, 0, numTads); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto cO = output + tadPack.primaryOffsets()[e]; + + auto idx = static_cast(indices[e]); + if (idx < 0 || idx >= tLen) { + PRAGMA_OMP_SIMD + for (unsigned int t = 0; t < tLen; t++) { + cO[shape::getIndexOffset(t, tadPack.primaryShapeInfo())] = zero; + } + } else { + PRAGMA_OMP_SIMD + for (unsigned int t = 0; t < tLen; t++) { + cO[shape::getIndexOffset(t, tadPack.primaryShapeInfo())] = + idx == t ? one : zero; + } } - } + } + }; + + samediff::Threads::parallel_tad(func, 0, numTads); + } +} + +void onehot(const sd::LaunchContext* context, const NDArray* indices, + NDArray* output, const uint axis, const uint depth, const double on, + const double off) { + auto zType = output->dataType(); + auto iType = indices->dataType(); + + BUILD_DOUBLE_SELECTOR( + zType, iType, onehot_, + (output->buffer(), output->shapeInfo(), indices->buffer(), + indices->shapeInfo(), axis, on, off), + LIBND4J_TYPES, LIBND4J_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/pad.cpp b/libnd4j/include/ops/declarable/helpers/cpu/pad.cpp index a0efd44c1d87..003ef27bf30c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/pad.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/pad.cpp @@ -18,123 +18,119 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// -template -void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, const NDArray& padValue) { - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); - - const Nd4jLong* xShape = input.shapeOf(); - const Nd4jLong* zShape = output.shapeOf(); - - const int rank = input.rankOf(); // both input and output have the same rank - const int rankMinusOne = rank - 1; - - const auto zLen = output.lengthOf(); - - if(mode == 0) { // CONSTANT case - - const T padVal = padValue.e(0); +template +void pad_(const int mode, const NDArray& input, const NDArray& paddings, + NDArray& output, const NDArray& padValue) { + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); - auto func = PRAGMA_THREADS_FOR { + const Nd4jLong* xShape = input.shapeOf(); + const Nd4jLong* zShape = output.shapeOf(); - int zCoords[MAX_RANK], xCoords[MAX_RANK]; + const int rank = input.rankOf(); // both input and output have the same rank + const int rankMinusOne = rank - 1; - for (auto i = start; i < stop; i++) { + const auto zLen = output.lengthOf(); - shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); - const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); + if (mode == 0) { // CONSTANT case - memcpy(xCoords, zCoords, rank * sizeof(int)); + const T padVal = padValue.e(0); - bool within = true; + auto func = PRAGMA_THREADS_FOR { + int zCoords[MAX_RANK], xCoords[MAX_RANK]; - for (int j = rankMinusOne; j >= 0; --j) { + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); + const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); - if (xShape[j] == zShape[j]) - continue; + memcpy(xCoords, zCoords, rank * sizeof(int)); - const auto left = paddings.e(j, 0); + bool within = true; - if (zCoords[j] < left || zCoords[j] >= left + xShape[j]) { - within = false; - break; - } - else - xCoords[j] = zCoords[j] - left; - } - - if (within) - z[zOffset] = x[shape::getOffset(input.shapeInfo(), xCoords)]; - else - z[zOffset] = padVal; - } - }; + for (int j = rankMinusOne; j >= 0; --j) { + if (xShape[j] == zShape[j]) continue; - samediff::Threads::parallel_tad(func, 0, zLen); - } - else { // REFLECT and SYMMETRIC cases + const auto left = paddings.e(j, 0); - const Nd4jLong shift1 = mode == 1 ? 0 : 1; // REFLECT : SYMMETRIC - const Nd4jLong shift2 = mode == 1 ? 2 : 1; // REFLECT : SYMMETRIC + if (zCoords[j] < left || zCoords[j] >= left + xShape[j]) { + within = false; + break; + } else + xCoords[j] = zCoords[j] - left; + } - auto func = PRAGMA_THREADS_FOR { + if (within) + z[zOffset] = x[shape::getOffset(input.shapeInfo(), xCoords)]; + else + z[zOffset] = padVal; + } + }; - int zCoords[MAX_RANK], xCoords[MAX_RANK]; + samediff::Threads::parallel_tad(func, 0, zLen); + } else { // REFLECT and SYMMETRIC cases - for (auto i = start; i < stop; i++) { + const Nd4jLong shift1 = mode == 1 ? 0 : 1; // REFLECT : SYMMETRIC + const Nd4jLong shift2 = mode == 1 ? 2 : 1; // REFLECT : SYMMETRIC - shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); - const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); + auto func = PRAGMA_THREADS_FOR { + int zCoords[MAX_RANK], xCoords[MAX_RANK]; - memcpy(xCoords, zCoords, rank * sizeof(int)); + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); + const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); - for (int j = rankMinusOne; j >= 0; --j) { + memcpy(xCoords, zCoords, rank * sizeof(int)); - if (xShape[j] == zShape[j]) - continue; + for (int j = rankMinusOne; j >= 0; --j) { + if (xShape[j] == zShape[j]) continue; - xCoords[j] = zCoords[j] - paddings.e(j, 0); // are ready to fill middle (within input dimension range) + xCoords[j] = + zCoords[j] - + paddings.e(j, 0); // are ready to fill middle (within + // input dimension range) - if (xCoords[j] < 0) - xCoords[j] = -xCoords[j] - shift1; // means fill from left - else if (xCoords[j] >= xShape[j]) - xCoords[j] = 2 * xShape[j] - xCoords[j] - shift2; // means fill from right - } + if (xCoords[j] < 0) + xCoords[j] = -xCoords[j] - shift1; // means fill from left + else if (xCoords[j] >= xShape[j]) + xCoords[j] = + 2 * xShape[j] - xCoords[j] - shift2; // means fill from right + } - const auto xOffset = shape::getOffset(input.shapeInfo(), xCoords); - z[zOffset] = x[xOffset]; - } - }; + const auto xOffset = shape::getOffset(input.shapeInfo(), xCoords); + z[zOffset] = x[xOffset]; + } + }; - samediff::Threads::parallel_tad(func, 0, zLen); - } + samediff::Threads::parallel_tad(func, 0, zLen); + } } // ////////////////////////////////////////////////////////////////////////// // template -// void pad2_(const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, NDArray const& padValue) { +// void pad2_(const int mode, const NDArray& input, const NDArray& paddings, +// NDArray& output, NDArray const& padValue) { // const int rank = output.rankOf(); // std::vector dimsToExclude(rank); -// std::iota(dimsToExclude.begin(), dimsToExclude.end(), 0); // fill with 0, 1, ... rank-1 +// std::iota(dimsToExclude.begin(), dimsToExclude.end(), 0); // +// fill with 0, 1, ... rank-1 // Nd4jLong numLeft = paddings.e(rank-1,0); // Nd4jLong numRight = paddings.e(rank-1,1); // Nd4jLong inDimSize = input.sizeAt(rank-1); // Nd4jLong outDimSize = output.sizeAt(rank-1); -// std::vector> outIdx = { std::vector(2*rank), {numLeft, numLeft + inDimSize}, {0, numLeft}, {numLeft + inDimSize, outDimSize} }; +// std::vector> outIdx = { +// std::vector(2*rank), {numLeft, numLeft + inDimSize}, {0, +// numLeft}, {numLeft + inDimSize, outDimSize} }; // for(int i = 0; i < rank-1; ++i) { // outIdx[0][2*i] = paddings.e(i, 0); @@ -145,10 +141,12 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray // // ***** populate innermost sub-arrays firstly ***** // // dimsToExclude.pop_back(); -// Nd4jLong startL = mode == 1 ? 1 : 0; // REFLECT or SYMMETRIC -// Nd4jLong startR = mode == 1 ? inDimSize-2 : inDimSize-1; // REFLECT or SYMMETRIC +// Nd4jLong startL = mode == 1 ? 1 : 0; // +// REFLECT or SYMMETRIC Nd4jLong startR = mode == 1 ? inDimSize-2 : +// inDimSize-1; // REFLECT or SYMMETRIC -// Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.shapeInfo(), dimsToExclude); +// Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.shapeInfo(), +// dimsToExclude); // NDArray outSubArr0 = output(outIdx[0], true); @@ -171,12 +169,14 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray // temp.assign(padValue); // assign right // } // } -// else { // REFLECT or SYMMETRIC +// else { // REFLECT or SYMMETRIC -// for(Nd4jLong k = numLeft-1, e = startL; k >= 0; --k, ++e) // fill left side +// for(Nd4jLong k = numLeft-1, e = startL; k >= 0; --k, ++e) // +// fill left side // outSubArr1.t(k) = inSubArr.t(e); -// for(Nd4jLong k = numLeft + inDimSize, e = startR; k < outDimSize; ++k, --e) // fill right side +// for(Nd4jLong k = numLeft + inDimSize, e = startR; k < outDimSize; +// ++k, --e) // fill right side // outSubArr1.t(k) = inSubArr.t(e); // } // } @@ -203,13 +203,16 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray // if(mode == 0) { // outIdxOuter[0] = 0; outIdxOuter[1] = numLeft; -// outIdxInner[0] = numLeft + inDimSize; outIdxInner[1] = outDimSize; +// outIdxInner[0] = numLeft + inDimSize; outIdxInner[1] = +// outDimSize; // } -// startL = mode == 1 ? numLeft + 1 : numLeft; // REFLECT or SYMMETRIC -// startR = mode == 1 ? numLeft + inDimSize - 2 : numLeft + inDimSize-1; // REFLECT or SYMMETRIC +// startL = mode == 1 ? numLeft + 1 : numLeft; // REFLECT or SYMMETRIC +// startR = mode == 1 ? numLeft + inDimSize - 2 : numLeft + inDimSize-1; +// // REFLECT or SYMMETRIC -// numOfSubArrs = ShapeUtils::getNumOfSubArrs(output.shapeInfo(), dimsToExclude); +// numOfSubArrs = ShapeUtils::getNumOfSubArrs(output.shapeInfo(), +// dimsToExclude); // PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(outIdxOuter, outIdxInner)) // for(Nd4jLong j = 0; j < numOfSubArrs; ++j) { @@ -220,17 +223,20 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray // if(numLeft != 0) { // NDArray tempO = outSubArr(outIdxOuter); -// tempO.assign(padValue); // assign left +// tempO.assign(padValue); // +// assign left // } // if(numRight != 0) { // NDArray tempI = outSubArr(outIdxInner); -// tempI.assign(padValue); // assign right +// tempI.assign(padValue); // +// assign right // } // } -// else { // REFLECT or SYMMETRIC +// else { // REFLECT or SYMMETRIC -// for(Nd4jLong k = numLeft-1, e = startL; k >= 0; --k, ++e) { // fill left side +// for(Nd4jLong k = numLeft-1, e = startL; k >= 0; --k, ++e) { +// // fill left side // outIdxOuter[0] = k; // outIdxOuter[1] = k+1; // outIdxInner[0] = e; @@ -240,7 +246,8 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray // outSubArrOuter.assign(outSubArrInner); // } -// for(Nd4jLong k = numLeft + inDimSize, e = startR; k < outDimSize; ++k, --e) { // fill right side +// for(Nd4jLong k = numLeft + inDimSize, e = startR; k < +// outDimSize; ++k, --e) { // fill right side // outIdxOuter[0] = k; // outIdxOuter[1] = k+1; // outIdxInner[0] = e; @@ -254,106 +261,115 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray // } // } -void pad(sd::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, NDArray const& padValue) { - BUILD_SINGLE_SELECTOR(input.dataType(), pad_, (mode, input, paddings, output, padValue), LIBND4J_TYPES); +void pad(sd::LaunchContext* context, const int mode, const NDArray& input, + const NDArray& paddings, NDArray& output, NDArray const& padValue) { + BUILD_SINGLE_SELECTOR(input.dataType(), pad_, + (mode, input, paddings, output, padValue), + LIBND4J_TYPES); } ////////////////////////////////////////////////////////////////////////// -template -static void mirrorPad_(const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) { - - // mode: 0 - REFLECT, else - SYMMETRIC - const int reflBorder = (bool)mode ? 1 : 0; - const int rank = input.rankOf(); - const Nd4jLong outLen = output.lengthOf(); - - if(rank <= 1) { - - const Nd4jLong inLen = input.lengthOf(); - const auto leftSide = paddings.e(0); - const auto leftSideCorrected = leftSide - reflBorder; - const Nd4jLong len = 2*(inLen-1) + leftSide + reflBorder; - - for(int i = 0; i < outLen; ++i) { - - if (i < leftSide) // left side - output.p(i, input.e(leftSideCorrected - i)); - - else if(i >= leftSide && i < leftSide + inLen) // middle - output.p(i, input.e(i - leftSide)); - - else // right side - output.p(i, input.e(len - i)); - } +template +static void mirrorPad_(const NDArray& input, const NDArray& paddings, + NDArray& output, const int mode) { + // mode: 0 - REFLECT, else - SYMMETRIC + const int reflBorder = (bool)mode ? 1 : 0; + const int rank = input.rankOf(); + const Nd4jLong outLen = output.lengthOf(); + + if (rank <= 1) { + const Nd4jLong inLen = input.lengthOf(); + const auto leftSide = paddings.e(0); + const auto leftSideCorrected = leftSide - reflBorder; + const Nd4jLong len = 2 * (inLen - 1) + leftSide + reflBorder; + + for (int i = 0; i < outLen; ++i) { + if (i < leftSide) // left side + output.p(i, input.e(leftSideCorrected - i)); + + else if (i >= leftSide && i < leftSide + inLen) // middle + output.p(i, input.e(i - leftSide)); + + else // right side + output.p(i, input.e(len - i)); } - else { - - auto func = PRAGMA_THREADS_FOR { - - int inIdx[MAX_RANK], outIdx[MAX_RANK]; - - for (auto i = start; i < stop; i++) { + } else { + auto func = PRAGMA_THREADS_FOR { + int inIdx[MAX_RANK], outIdx[MAX_RANK]; - shape::index2coordsCPU(start, i, output.shapeInfo(), outIdx); + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), outIdx); - for (int j = 0; j < rank; ++j) { - const Nd4jLong inLen = input.sizeAt(j); - const auto leftSide = paddings.e(j, 0); - const auto leftSideCorrected = leftSide - reflBorder; - const Nd4jLong len = 2 * (inLen - 1) + leftSide + reflBorder; + for (int j = 0; j < rank; ++j) { + const Nd4jLong inLen = input.sizeAt(j); + const auto leftSide = paddings.e(j, 0); + const auto leftSideCorrected = leftSide - reflBorder; + const Nd4jLong len = 2 * (inLen - 1) + leftSide + reflBorder; - if (outIdx[j] < leftSide) // left side - inIdx[j] = leftSideCorrected - outIdx[j]; + if (outIdx[j] < leftSide) // left side + inIdx[j] = leftSideCorrected - outIdx[j]; - else if (outIdx[j] >= leftSide && outIdx[j] < leftSide + inLen) // middle - inIdx[j] = outIdx[j] - leftSide; + else if (outIdx[j] >= leftSide && + outIdx[j] < leftSide + inLen) // middle + inIdx[j] = outIdx[j] - leftSide; - else // right side - inIdx[j] = len - outIdx[j]; - } + else // right side + inIdx[j] = len - outIdx[j]; + } - auto outOffset = shape::getOffset(output.shapeInfo(), outIdx); - auto inOffset = shape::getOffset(input.shapeInfo(), inIdx); - reinterpret_cast(output.buffer())[outOffset] = reinterpret_cast(input.buffer())[inOffset]; - } - }; + auto outOffset = shape::getOffset(output.shapeInfo(), outIdx); + auto inOffset = shape::getOffset(input.shapeInfo(), inIdx); + reinterpret_cast(output.buffer())[outOffset] = + reinterpret_cast(input.buffer())[inOffset]; + } + }; - samediff::Threads::parallel_for(func, 0, outLen); - } + samediff::Threads::parallel_for(func, 0, outLen); + } } - void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) { - BUILD_SINGLE_SELECTOR(input.dataType(), mirrorPad_, (input, paddings, output, mode), LIBND4J_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void mirrorPad_, (const NDArray& input, const NDArray& paddings, NDArray& output, const int mode), LIBND4J_TYPES); +void mirrorPad(sd::LaunchContext* context, const NDArray& input, + const NDArray& paddings, NDArray& output, const int mode) { + BUILD_SINGLE_SELECTOR(input.dataType(), mirrorPad_, + (input, paddings, output, mode), LIBND4J_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void mirrorPad_, + (const NDArray& input, const NDArray& paddings, + NDArray& output, const int mode), + LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// /*// initial values of inIdx, outIdx, dim must be equal to zero template -static void recursiveLoopForPad_(const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector dimensions, int dim, int inIdx, int outIdx, NDArray& padValue ) { +static void recursiveLoopForPad_(const int mode, NDArray& input, const NDArray& +paddings, NDArray& output, std::vector dimensions, int dim, int inIdx, int +outIdx, NDArray& padValue ) { int leftOffset; - // dimensions are array of input dimensions, it is sorted in increasing order - // every time at the beginning we erase first element from it (not good idea to use vector for this purpose, but luckily it is small enough) - // then we use this array for tads building, every time while recursion the number of built tads becomes bigger - dimensions.erase(dimensions.begin()); - // build tad basing on output array, also create auxiliary arrays pointing on required output array ranges - shape::TAD tadOut(output.shapeInfo(), dimensions.data(), dimensions.size()); - tadOut.createTadOnlyShapeInfo(); + // dimensions are array of input dimensions, it is sorted in increasing +order + // every time at the beginning we erase first element from it (not good idea +to use vector for this purpose, but luckily it is small enough) + // then we use this array for tads building, every time while recursion the +number of built tads becomes bigger dimensions.erase(dimensions.begin()); + // build tad basing on output array, also create auxiliary arrays pointing +on required output array ranges shape::TAD tadOut(output.shapeInfo(), +dimensions.data(), dimensions.size()); tadOut.createTadOnlyShapeInfo(); tadOut.createOffsets(); - auto subArrOut = NDArray(output.getBuffer(), tadOut.tadOnlyShapeInfo, output.getContext()); - auto subArr = NDArray(output.getBuffer(), tadOut.tadOnlyShapeInfo, output.getContext()); - // build tad basing on input array, also create auxiliary array pointing on required input array range - shape::TAD tadIn(input.shapeInfo(), dimensions.data(), dimensions.size()); - tadIn.createTadOnlyShapeInfo(); + auto subArrOut = NDArray(output.getBuffer(), tadOut.tadOnlyShapeInfo, +output.getContext()); auto subArr = NDArray(output.getBuffer(), +tadOut.tadOnlyShapeInfo, output.getContext()); + // build tad basing on input array, also create auxiliary array pointing on +required input array range shape::TAD tadIn(input.shapeInfo(), +dimensions.data(), dimensions.size()); tadIn.createTadOnlyShapeInfo(); tadIn.createOffsets(); - auto subArrIn = NDArray(input.getBuffer(), tadIn.tadOnlyShapeInfo, output.getContext()); - // these indices take into account recursion and always point to actual tads numbers - if (input.rankOf() > 1 && output.rankOf() > 1) {// only for non-vector cases - outIdx = outIdx * output.sizeAt(dim + 1); - inIdx = inIdx * input.sizeAt(dim + 1); + auto subArrIn = NDArray(input.getBuffer(), tadIn.tadOnlyShapeInfo, +output.getContext()); + // these indices take into account recursion and always point to actual tads +numbers if (input.rankOf() > 1 && output.rankOf() > 1) {// only for non-vector +cases outIdx = outIdx * output.sizeAt(dim + 1); inIdx = inIdx * input.sizeAt(dim ++ 1); } // current input tad number, we add to it unity in a loop int k = -1; @@ -366,15 +382,16 @@ static void recursiveLoopForPad_(const int mode, NDArray& input, const NDArray& // increase input tads number ++k; - // recursion condition allows for the fact that tad can't reduce to scalar - if(dim < input.rankOf() - 2) - recursiveLoopForPad(mode, input, paddings, output, dimensions, dim + 1, inIdx + k, outIdx + i, padValue); - else if (paddings.sizeAt(0) > dim + 1){ - leftOffset = paddings.e(dim + 1, 0); + // recursion condition allows for the fact that tad can't reduce to +scalar if(dim < input.rankOf() - 2) recursiveLoopForPad(mode, input, paddings, +output, dimensions, dim + 1, inIdx + k, outIdx + i, padValue); else if +(paddings.sizeAt(0) > dim + 1){ leftOffset = paddings.e(dim + 1, 0); // shift buffers pointers to actual element position if (output.rankOf() > 1) { - subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + i]); - subArrIn.setBuffer(reinterpret_cast(input.getBuffer()) + tadIn.tadOffsets[inIdx + i - paddings.e(dim, 0)]); + subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + i]); + subArrIn.setBuffer(reinterpret_cast(input.getBuffer()) + +tadIn.tadOffsets[inIdx + i - paddings.e(dim, 0)]); } else { subArrOut.p(i, subArrIn.e(i - leftOffset)); @@ -383,35 +400,37 @@ static void recursiveLoopForPad_(const int mode, NDArray& input, const NDArray& switch (mode) { case 0: // CONSTANT mode for(int j = 0; j < subArrOut.lengthOf(); ++j) - if(j < leftOffset || j >= (subArrIn.lengthOf() + leftOffset) ) // firstly fill with zeros outer ranges + if(j < leftOffset || j >= (subArrIn.lengthOf() + +leftOffset) ) // firstly fill with zeros outer ranges subArrOut.p(j, (T)0.f); else - subArrOut.p(j, subArrIn.e(j - leftOffset)); // fill middle with elements of input array - break; + subArrOut.p(j, subArrIn.e(j - leftOffset)); +// fill middle with elements of input array break; case 1: // REFLECT mode - for(int j = 1; j <= leftOffset; ++j) // fill firstly left side - subArrOut.p(leftOffset - j, subArrIn.e(j)); - for(int j = 0; j < subArrIn.lengthOf(); ++j) // fill middle + for(int j = 1; j <= leftOffset; ++j) // fill firstly left +side subArrOut.p(leftOffset - j, subArrIn.e(j)); for(int j = 0; j < +subArrIn.lengthOf(); ++j) // fill middle subArrOut.p(leftOffset + j, subArrIn.e(j)); - for(int j = (subArrOut.lengthOf() - leftOffset); j < subArrOut.lengthOf(); ++j) // fill right side - subArrOut.p(j, subArrIn.e(subArrOut.lengthOf() - j - 1)); - break; + for(int j = (subArrOut.lengthOf() - leftOffset); j < +subArrOut.lengthOf(); ++j) // fill right side subArrOut.p(j, +subArrIn.e(subArrOut.lengthOf() - j - 1)); break; case 2: // SYMMETRIC mode - for(int j = 1; j <= leftOffset; ++j) // fill firstly left side - subArrOut.p(leftOffset - j, subArrIn.e(j-1)); - for(int j = 0; j < subArrIn.lengthOf(); ++j) // fill middle + for(int j = 1; j <= leftOffset; ++j) // fill firstly left +side subArrOut.p(leftOffset - j, subArrIn.e(j-1)); for(int j = 0; j < +subArrIn.lengthOf(); ++j) // fill middle subArrOut.p(leftOffset + j, subArrIn.e(j)); - for(int j = (subArrOut.lengthOf() - leftOffset); j < subArrOut.lengthOf(); ++j) // fill right side - subArrOut.p(j, subArrIn.e(subArrOut.lengthOf() - j)); - break; + for(int j = (subArrOut.lengthOf() - leftOffset); j < +subArrOut.lengthOf(); ++j) // fill right side subArrOut.p(j, +subArrIn.e(subArrOut.lengthOf() - j)); break; } } else { if (mode == 0 && input.rankOf() < 2) - subArrOut.p(i, subArrIn.e(i - leftOffset)); // fill middle with elements of input array + subArrOut.p(i, subArrIn.e(i - leftOffset)); // fill middle +with elements of input array } } // populate sub-array formed previously @@ -422,18 +441,19 @@ static void recursiveLoopForPad_(const int mode, NDArray& input, const NDArray& // fill left side with padValue if (output.rankOf() > 1) { subArrOut.setBuffer( - reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + leftOffset - j]); - subArrOut.assign(padValue); + reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + leftOffset - j]); subArrOut.assign(padValue); } else { subArrOut.p(j - 1, padValue); } } // output.printIndexedBuffer("Output at"); - for(int j = (output.sizeAt(dim) - leftOffset); j < output.sizeAt(dim); ++j) { // fill left side with zeros - if (output.rankOf() > 1) { - subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + j]); - subArrOut.assign(padValue); + for(int j = (output.sizeAt(dim) - leftOffset); j < +output.sizeAt(dim); ++j) { // fill left side with zeros if +(output.rankOf() > 1) { + subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) ++ tadOut.tadOffsets[outIdx + j]); subArrOut.assign(padValue); } else { subArrOut.p(j, padValue); @@ -442,42 +462,54 @@ static void recursiveLoopForPad_(const int mode, NDArray& input, const NDArray& break; case 1: // REFLECT mode - for(int j = 1; j <= leftOffset; ++j) { // fill left side - subArr.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + leftOffset + j]); - subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + leftOffset - j]); - subArrOut.assign(&subArr); + for(int j = 1; j <= leftOffset; ++j) { // fill left side + subArr.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + leftOffset + j]); + subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + leftOffset - j]); subArrOut.assign(&subArr); } - for(int j = (output.sizeAt(dim) - leftOffset); j < output.sizeAt(dim); ++j) { // fill right side - subArr.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + output.sizeAt(dim) + leftOffset - 1 - j]); - subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + j]); - subArrOut.assign(&subArr); + for(int j = (output.sizeAt(dim) - leftOffset); j < +output.sizeAt(dim); ++j) { // fill right side + subArr.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + output.sizeAt(dim) + leftOffset - 1 - j]); + subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + j]); subArrOut.assign(&subArr); } break; case 2: // SYMMETRIC mode - for(int j = 1; j <= leftOffset; ++j) { // fill left side - subArr.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + leftOffset + j - 1]); - subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + leftOffset - j]); - subArrOut.assign(&subArr); + for(int j = 1; j <= leftOffset; ++j) { // fill left side + subArr.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + leftOffset + j - 1]); + subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + leftOffset - j]); subArrOut.assign(&subArr); } - for(int j = (output.sizeAt(dim) - leftOffset); j < output.sizeAt(dim); ++j) { // fill right side - subArr.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + output.sizeAt(dim) + leftOffset - j]); - subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + tadOut.tadOffsets[outIdx + j]); - subArrOut.assign(&subArr); + for(int j = (output.sizeAt(dim) - leftOffset); j < +output.sizeAt(dim); ++j) { // fill right side + subArr.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + output.sizeAt(dim) + leftOffset - j]); + subArrOut.setBuffer(reinterpret_cast(output.getBuffer()) + +tadOut.tadOffsets[outIdx + j]); subArrOut.assign(&subArr); } break; } } */ /* - void recursiveLoopForPad(const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector dimensions, int dim, int inIdx, int outIdx, NDArray& padValue ) { - BUILD_SINGLE_SELECTOR(input.dataType(), recursiveLoopForPad_, (mode, input, paddings, output, dimensions, dim, inIdx, outIdx, padValue), LIBND4J_TYPES); + void recursiveLoopForPad(const int mode, NDArray& input, const NDArray& + paddings, NDArray& output, std::vector dimensions, int dim, int inIdx, + int outIdx, NDArray& padValue ) { BUILD_SINGLE_SELECTOR(input.dataType(), + recursiveLoopForPad_, (mode, input, paddings, output, dimensions, dim, inIdx, + outIdx, padValue), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void recursiveLoopForPad_, (const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector dimensions, int dim, int inIdx, int outIdx, NDArray& padValue), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void recursiveLoopForPad_, (const int mode, + NDArray& input, const NDArray& paddings, NDArray& output, std::vector + dimensions, int dim, int inIdx, int outIdx, NDArray& padValue), + LIBND4J_TYPES); */ -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp b/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp index b8e7229de7e2..9686c83317c6 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/percentile.cpp @@ -18,70 +18,82 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 17.05.2018 // -#include #include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// template -static void _percentile(const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation) { - - const int inputRank = input.rankOf(); - - if(axises.empty()) - for(int i=0; i shapeOfSubArr(listOfSubArrs.at(0).rankOf()); - for(int i=0; i(math::nd4j_ceil((len - 1) * fraction)); - break; - case 1: // higher - position = static_cast(math::nd4j_floor((len - 1) * fraction)); - break; - case 2: // nearest - position = static_cast(math::nd4j_round((len - 1) * fraction)); - break; - } - position = len - position - 1; - - // FIXME: our sort impl should be used instead, so this operation might be implemented as generic - // FIXME: parallelism ! - for(int i=0; i(flattenedArr.buffer()); - flattenedArr.assign(listOfSubArrs.at(i)); - std::sort(buff, buff + len); - output.p(i, flattenedArr.e(position)); - } +static void _percentile(const NDArray& input, NDArray& output, + std::vector& axises, const float q, + const int interpolation) { + const int inputRank = input.rankOf(); + + if (axises.empty()) + for (int i = 0; i < inputRank; ++i) axises.push_back(i); + else + shape::checkDimensions(inputRank, + axises); // check, sort dimensions and remove + // duplicates if they are present + + auto listOfSubArrs = input.allTensorsAlongDimension(axises); + + std::vector shapeOfSubArr(listOfSubArrs.at(0).rankOf()); + for (int i = 0; i < shapeOfSubArr.size(); ++i) + shapeOfSubArr[i] = listOfSubArrs.at(0).shapeOf()[i]; + + auto flattenedArr = NDArrayFactory::create( + 'c', shapeOfSubArr, input.dataType(), input.getContext()); + const int len = flattenedArr.lengthOf(); + + const float fraction = 1.f - q / 100.; + Nd4jLong position = 0; + + switch (interpolation) { + case 0: // lower + position = static_cast( + math::nd4j_ceil((len - 1) * fraction)); + break; + case 1: // higher + position = static_cast( + math::nd4j_floor((len - 1) * fraction)); + break; + case 2: // nearest + position = static_cast( + math::nd4j_round((len - 1) * fraction)); + break; + } + position = len - position - 1; + + // FIXME: our sort impl should be used instead, so this operation might be + // implemented as generic + // FIXME: parallelism ! + for (int i = 0; i < listOfSubArrs.size(); ++i) { + auto buff = reinterpret_cast(flattenedArr.buffer()); + flattenedArr.assign(listOfSubArrs.at(i)); + std::sort(buff, buff + len); + output.p(i, flattenedArr.e(position)); + } } - void percentile(sd::LaunchContext * context, const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation) { - BUILD_SINGLE_SELECTOR(input.dataType(), _percentile, (input, output, axises, q, interpolation), LIBND4J_TYPES); - } +void percentile(sd::LaunchContext* context, const NDArray& input, + NDArray& output, std::vector& axises, const float q, + const int interpolation) { + BUILD_SINGLE_SELECTOR(input.dataType(), _percentile, + (input, output, axises, q, interpolation), + LIBND4J_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void _percentile, (const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void _percentile, + (const NDArray& input, NDArray& output, + std::vector& axises, const float q, + const int interpolation), + LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp b/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp index 2c93cee08895..57bf5fa51aa1 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp @@ -18,80 +18,82 @@ // Created by Yurii Shyrma on 12.12.2017 // -#include -#include #include #include +#include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// // calculate factorial template static FORCEINLINE T getFactorial(const int n) { - if (n < 0) - throw std::runtime_error("factorial is not defined for negative number !"); + if (n < 0) + throw std::runtime_error("factorial is not defined for negative number !"); - if(n==0 || n==1) - return (T)1.f; + if (n == 0 || n == 1) return (T)1.f; - T result = (T)1.f; + T result = (T)1.f; - for(int i = 2; i <= n; ++i) - result *= i; + for (int i = 2; i <= n; ++i) result *= i; - return result; + return result; } ////////////////////////////////////////////////////////////////////////// -// implementation is based on serial representation written in terms of the Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x) +// implementation is based on serial representation written in terms of the +// Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x) template -static FORCEINLINE T polyGammaScalar(sd::LaunchContext * context, const int n, const T x) { - - // if (n < 0) - // throw("polyGamma function: n must be >= 0 !"); +static FORCEINLINE T polyGammaScalar(sd::LaunchContext* context, const int n, + const T x) { + // if (n < 0) + // throw("polyGamma function: n must be >= 0 !"); - // if (x <= (T)0.) - // throw("polyGamma function: x must be > 0 !"); + // if (x <= (T)0.) + // throw("polyGamma function: x must be > 0 !"); - int sign = (n + 1) % 2 ? -1 : 1; - // T factorial = (T)std::tgamma(n + 1); + int sign = (n + 1) % 2 ? -1 : 1; + // T factorial = (T)std::tgamma(n + 1); - return sign * getFactorial(n) * zetaScalar((T)(n + 1), x); + return sign * getFactorial(n) * zetaScalar((T)(n + 1), x); } - ////////////////////////////////////////////////////////////////////////// // calculate polygamma function for arrays template -static void polyGamma_(sd::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - const T order = n.e(i); - if(order != static_cast(order)) // if order has fractional part then do not perform calculations and return NAN - output.p(i, std::numeric_limits::quiet_NaN()); - else if (order == 0) // polygamma function of zero order is digamma function - output.p(i, diGammaScalar(x.e(i))); - else - output.p(i, polyGammaScalar(context, order, x.e(i))); - } - }; - samediff::Threads::parallel_for(func, 0, x.lengthOf()); +static void polyGamma_(sd::LaunchContext* context, const NDArray& n, + const NDArray& x, NDArray& output) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const T order = n.e(i); + if (order != + static_cast(order)) // if order has fractional part then do not + // perform calculations and return NAN + output.p(i, std::numeric_limits::quiet_NaN()); + else if (order == + 0) // polygamma function of zero order is digamma function + output.p(i, diGammaScalar(x.e(i))); + else + output.p(i, polyGammaScalar(context, order, x.e(i))); + } + }; + samediff::Threads::parallel_for(func, 0, x.lengthOf()); } - void polyGamma(sd::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) { - BUILD_SINGLE_SELECTOR(x.dataType(), polyGamma_, (context, n, x, output), FLOAT_TYPES); - } - -BUILD_SINGLE_TEMPLATE(template void polyGamma_, (sd::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output), FLOAT_TYPES); - - - -} -} +void polyGamma(sd::LaunchContext* context, const NDArray& n, const NDArray& x, + NDArray& output) { + BUILD_SINGLE_SELECTOR(x.dataType(), polyGamma_, (context, n, x, output), + FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template void polyGamma_, + (sd::LaunchContext * context, const NDArray& n, + const NDArray& x, NDArray& output), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp b/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp index 4f57170ec507..ea1d66040635 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/prefix.cpp @@ -18,118 +18,136 @@ // @author raver119@gmail.com // -#include -#include #include +#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - static void prefix_(scalar::Ops op, const void* vx, Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, bool exclusive, bool reverse) { - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - auto length = shape::length(xShapeInfo); - - T prevSum = op == scalar::Add ? (T) 0 : (T) 1; - T sum = prevSum; - - if (reverse) { - if (shape::elementWiseStride(xShapeInfo) == 1 && shape::elementWiseStride(zShapeInfo) == 1 && - shape::order(xShapeInfo) == 'c' && shape::order(zShapeInfo) == 'c') { - - for (Nd4jLong e = length - 1; e >= 0; --e) { - sum = op == scalar::Add ? simdOps::Add::op(sum, x[e]) : simdOps::Multiply::op(sum, x[e]); - if (!exclusive) - prevSum = sum; - - z[e] = prevSum; - - prevSum = sum; - } - } - else { - - for (Nd4jLong e = length - 1; e >= 0; --e) { - - auto xOffset = shape::getIndexOffset(e, xShapeInfo); - auto zOffset = shape::getIndexOffset(e, zShapeInfo); - sum = op == scalar::Add ? simdOps::Add::op(sum, x[xOffset]) : simdOps::Multiply::op(sum, x[xOffset]); - - if (!exclusive) - prevSum = sum; - - z[zOffset] = prevSum; - prevSum = sum; - } - } - } else { - if (shape::elementWiseStride(xShapeInfo) == 1 && shape::elementWiseStride(zShapeInfo) == 1 && - shape::order(xShapeInfo) == 'c' && shape::order(zShapeInfo) == 'c') { - - for (Nd4jLong e = 0; e < length; e++) { - sum = op == scalar::Add ? simdOps::Add::op(sum, x[e]) : simdOps::Multiply::op(sum, x[e]); - - if (!exclusive) - prevSum = sum; - - z[e] = prevSum; - - prevSum = sum; - } - } - else { - - for (Nd4jLong e = 0; e < length; e++) { - - auto xOffset = shape::getIndexOffset(e, xShapeInfo); - auto zOffset = shape::getIndexOffset(e, zShapeInfo); - sum = op == scalar::Add ? simdOps::Add::op(sum, x[xOffset]) : simdOps::Multiply::op(sum, x[xOffset]); - - if (!exclusive) - prevSum = sum; - - z[zOffset] = prevSum; - prevSum = sum; - } - } - } - }; - - template - static void prefix_(scalar::Ops op, const NDArray* x, NDArray* z, const std::vector& dims, bool exclusive, bool reverse) { - auto xTads = x->allTensorsAlongDimension(dims); - auto zTads = z->allTensorsAlongDimension(dims); - auto t = xTads.size(); - - for (int e = 0; e < t; e++) { - auto tx = xTads.at(e); - auto tz = zTads.at(e); - - prefix_(op, tx.buffer(), tx.shapeInfo(), tz.buffer(), tz.shapeInfo(), exclusive, reverse); - } - }; - - template - static void prefix_(scalar::Ops op, const NDArray* x, NDArray* z, bool exclusive, bool reverse) { - prefix_(op, x->buffer(), x->shapeInfo(), z->buffer(), z->shapeInfo(), exclusive, reverse); - }; - - void prefix(sd::LaunchContext * context, scalar::Ops op, const NDArray* x, NDArray* z, bool exclusive, bool reverse) { - BUILD_SINGLE_SELECTOR(x->dataType(), prefix_, (op, x, z, exclusive, reverse), LIBND4J_TYPES); - } - - void prefix(sd::LaunchContext * context, scalar::Ops op, const NDArray* x, NDArray* z, const std::vector& dims, bool exclusive, bool reverse) { - BUILD_SINGLE_SELECTOR(x->dataType(), prefix_, (op, x, z, dims, exclusive, reverse), LIBND4J_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void prefix_, (scalar::Ops op, const void* vx, Nd4jLong const* xShapeInfo, void* vz, Nd4jLong const* zShapeInfo, bool exclusive, bool reverse), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void prefix_, (scalar::Ops op, const NDArray* x, NDArray* z, const std::vector& dims, bool exclusive, bool reverse), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void prefix_, (scalar::Ops op, const NDArray* x, NDArray* z, bool exclusive, bool reverse), LIBND4J_TYPES); - - - - } +namespace ops { +namespace helpers { +template +static void prefix_(scalar::Ops op, const void* vx, Nd4jLong const* xShapeInfo, + void* vz, Nd4jLong const* zShapeInfo, bool exclusive, + bool reverse) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto length = shape::length(xShapeInfo); + + T prevSum = op == scalar::Add ? (T)0 : (T)1; + T sum = prevSum; + + if (reverse) { + if (shape::elementWiseStride(xShapeInfo) == 1 && + shape::elementWiseStride(zShapeInfo) == 1 && + shape::order(xShapeInfo) == 'c' && shape::order(zShapeInfo) == 'c') { + for (Nd4jLong e = length - 1; e >= 0; --e) { + sum = op == scalar::Add ? simdOps::Add::op(sum, x[e]) + : simdOps::Multiply::op(sum, x[e]); + if (!exclusive) prevSum = sum; + + z[e] = prevSum; + + prevSum = sum; + } + } else { + for (Nd4jLong e = length - 1; e >= 0; --e) { + auto xOffset = shape::getIndexOffset(e, xShapeInfo); + auto zOffset = shape::getIndexOffset(e, zShapeInfo); + sum = op == scalar::Add + ? simdOps::Add::op(sum, x[xOffset]) + : simdOps::Multiply::op(sum, x[xOffset]); + + if (!exclusive) prevSum = sum; + + z[zOffset] = prevSum; + prevSum = sum; + } + } + } else { + if (shape::elementWiseStride(xShapeInfo) == 1 && + shape::elementWiseStride(zShapeInfo) == 1 && + shape::order(xShapeInfo) == 'c' && shape::order(zShapeInfo) == 'c') { + for (Nd4jLong e = 0; e < length; e++) { + sum = op == scalar::Add ? simdOps::Add::op(sum, x[e]) + : simdOps::Multiply::op(sum, x[e]); + + if (!exclusive) prevSum = sum; + + z[e] = prevSum; + + prevSum = sum; + } + } else { + for (Nd4jLong e = 0; e < length; e++) { + auto xOffset = shape::getIndexOffset(e, xShapeInfo); + auto zOffset = shape::getIndexOffset(e, zShapeInfo); + sum = op == scalar::Add + ? simdOps::Add::op(sum, x[xOffset]) + : simdOps::Multiply::op(sum, x[xOffset]); + + if (!exclusive) prevSum = sum; + + z[zOffset] = prevSum; + prevSum = sum; + } } -} \ No newline at end of file + } +}; + +template +static void prefix_(scalar::Ops op, const NDArray* x, NDArray* z, + const std::vector& dims, bool exclusive, + bool reverse) { + auto xTads = x->allTensorsAlongDimension(dims); + auto zTads = z->allTensorsAlongDimension(dims); + auto t = xTads.size(); + + for (int e = 0; e < t; e++) { + auto tx = xTads.at(e); + auto tz = zTads.at(e); + + prefix_(op, tx.buffer(), tx.shapeInfo(), tz.buffer(), tz.shapeInfo(), + exclusive, reverse); + } +}; + +template +static void prefix_(scalar::Ops op, const NDArray* x, NDArray* z, + bool exclusive, bool reverse) { + prefix_(op, x->buffer(), x->shapeInfo(), z->buffer(), z->shapeInfo(), + exclusive, reverse); +}; + +void prefix(sd::LaunchContext* context, scalar::Ops op, const NDArray* x, + NDArray* z, bool exclusive, bool reverse) { + BUILD_SINGLE_SELECTOR(x->dataType(), prefix_, (op, x, z, exclusive, reverse), + LIBND4J_TYPES); +} + +void prefix(sd::LaunchContext* context, scalar::Ops op, const NDArray* x, + NDArray* z, const std::vector& dims, bool exclusive, + bool reverse) { + BUILD_SINGLE_SELECTOR(x->dataType(), prefix_, + (op, x, z, dims, exclusive, reverse), LIBND4J_TYPES); +} + +BUILD_SINGLE_TEMPLATE(template void prefix_, + (scalar::Ops op, const void* vx, + Nd4jLong const* xShapeInfo, void* vz, + Nd4jLong const* zShapeInfo, bool exclusive, + bool reverse), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void prefix_, + (scalar::Ops op, const NDArray* x, NDArray* z, + const std::vector& dims, bool exclusive, + bool reverse), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void prefix_, + (scalar::Ops op, const NDArray* x, NDArray* z, + bool exclusive, bool reverse), + LIBND4J_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/print_variable.cpp b/libnd4j/include/ops/declarable/helpers/cpu/print_variable.cpp index 26a24a5afcc2..7790f2a6f156 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/print_variable.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/print_variable.cpp @@ -21,11 +21,12 @@ #include namespace sd { - namespace ops { - namespace helpers { - void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message) { - array.printIndexedBuffer(message.c_str()); - } - } - } +namespace ops { +namespace helpers { +void print_special(LaunchContext &ctx, const NDArray &array, + const std::string &message) { + array.printIndexedBuffer(message.c_str()); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp b/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp index cab9132eb43e..45e29354ec65 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/qr.cpp @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2019-2020 Konduit K.K. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -17,117 +17,132 @@ // // @author George A. Shulinok // -#include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - template - NDArray matrixMinor(NDArray& in, Nd4jLong col) { - NDArray m = in.ulike(); - m.setIdentity(); - m({col, m.rows(), col, m.columns()}).assign(in({col, m.rows(), col, m.columns()})); +template +NDArray matrixMinor(NDArray& in, Nd4jLong col) { + NDArray m = in.ulike(); + m.setIdentity(); + m({col, m.rows(), col, m.columns()}) + .assign(in({col, m.rows(), col, m.columns()})); - return m; - } + return m; +} /* m = I - v v^T */ - template - NDArray vmul(NDArray const& v, int n) - { - NDArray res('c', {n,n}, v.dataType(), v.getContext()); // x = matrix_new(n, n); - T const* vBuf = v.getDataBuffer()->primaryAsT(); - T* resBuf = res.dataBuffer()->primaryAsT(); - auto interloop = PRAGMA_THREADS_FOR_2D { - for (auto i = start_x; i < n; i += inc_x) - for (auto j = start_y; j < n; j += inc_y) - resBuf[i * n + j] = -2 * vBuf[i] * vBuf[j] + (i == j ? T(1) : T(0)); - }; - - samediff::Threads::parallel_for(interloop, 0, n, 1, 0, n, 1); - return res; - } - - template - void qrSingle(NDArray* matrix, NDArray* Q, NDArray* R, bool const fullMatricies) { - Nd4jLong M = matrix->sizeAt(-2); - Nd4jLong N = matrix->sizeAt(-1); - auto resQ = fullMatricies?Q->ulike():NDArrayFactory::create(matrix->ordering(), {M,M}, Q->getContext()); - auto resR = fullMatricies?R->ulike():matrix->ulike(); - std::vector q(M); - - NDArray z = matrix->dup(); - NDArray e('c', {M}, DataTypeUtils::fromT(), Q->getContext()); // two internal buffers and scalar for squared norm - - for (Nd4jLong k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number - e.nullify(); - z = matrixMinor(z, k); // minor computing for current column with given matrix z (initally is a input matrix) -// z.printIndexedBuffer("Minor!!!"); - - auto currentColumn = z({0, 0, k, k + 1}); // retrieve k column from z to x buffer - auto norm = currentColumn.reduceAlongDimension(reduce::Norm2, {0}); - if (matrix->t(k,k) > T(0.f)) // negate on positive matrix diagonal element - norm *= T(-1.f);//.applyTransform(transform::Neg, nullptr, nullptr); //t(0) = -norm.t(0); - //e.t(k) = T(1.f); // e - is filled by 0 vector except diagonal element (filled by 1) - //auto tE = e; - //tE *= norm; -// norm.printIndexedBuffer("Norm!!!"); - e.p(k, norm); - e += currentColumn;// e += tE; // e[i] = x[i] + a * e[i] for each i from 0 to n - 1 - auto normE = e.reduceAlongDimension(reduce::Norm2, {0}); - e /= normE; - q[k] = vmul(e, M); - auto qQ = z.ulike(); - MmulHelper::matmul(&q[k], &z, &qQ, false, false); - z = std::move(qQ); - } - resQ.assign(q[0]); // -// MmulHelper::matmul(&q[0], matrix, &resR, false, false); - for (Nd4jLong i = 1; i < N && i < M - 1; i++) { - auto tempResQ = resQ.ulike(); - MmulHelper::matmul(&q[i], &resQ, &tempResQ, false, false); // use mmulMxM? - resQ = std::move(tempResQ); - } - MmulHelper::matmul(&resQ, matrix, &resR, false, false); - // resR *= -1.f; - resQ.transposei(); - if (fullMatricies) { - Q->assign(resQ); - R->assign(resR); - } - else { - Q->assign(resQ({0,0, 0, N})); - R->assign(resR({0,N, 0, 0})); - } - } - - template - void qr_(NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) { - Nd4jLong lastDim = input->rankOf() - 1; - Nd4jLong preLastDim = input->rankOf() - 2; - ResultSet listOutQ(outputQ->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - ResultSet listOutR(outputR->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - ResultSet listInput(input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - auto batching = PRAGMA_THREADS_FOR { - for (auto batch = start; batch < stop; batch++) { - //qr here - qrSingle(&listInput.at(batch), &listOutQ.at(batch), &listOutR.at(batch), fullMatricies); - } - }; - - samediff::Threads::parallel_tad(batching, 0, listOutQ.size(), 1); +template +NDArray vmul(NDArray const& v, int n) { + NDArray res('c', {n, n}, v.dataType(), + v.getContext()); // x = matrix_new(n, n); + T const* vBuf = v.getDataBuffer()->primaryAsT(); + T* resBuf = res.dataBuffer()->primaryAsT(); + auto interloop = PRAGMA_THREADS_FOR_2D { + for (auto i = start_x; i < n; i += inc_x) + for (auto j = start_y; j < n; j += inc_y) + resBuf[i * n + j] = -2 * vBuf[i] * vBuf[j] + (i == j ? T(1) : T(0)); + }; + + samediff::Threads::parallel_for(interloop, 0, n, 1, 0, n, 1); + return res; +} - } +template +void qrSingle(NDArray* matrix, NDArray* Q, NDArray* R, + bool const fullMatricies) { + Nd4jLong M = matrix->sizeAt(-2); + Nd4jLong N = matrix->sizeAt(-1); + auto resQ = fullMatricies ? Q->ulike() + : NDArrayFactory::create( + matrix->ordering(), {M, M}, Q->getContext()); + auto resR = fullMatricies ? R->ulike() : matrix->ulike(); + std::vector q(M); + + NDArray z = matrix->dup(); + NDArray e( + 'c', {M}, DataTypeUtils::fromT(), + Q->getContext()); // two internal buffers and scalar for squared norm + + for (Nd4jLong k = 0; k < N && k < M - 1; + k++) { // loop for columns, but not further then row number + e.nullify(); + z = matrixMinor(z, k); // minor computing for current column with given + // matrix z (initally is a input matrix) + // z.printIndexedBuffer("Minor!!!"); + + auto currentColumn = + z({0, 0, k, k + 1}); // retrieve k column from z to x buffer + auto norm = currentColumn.reduceAlongDimension(reduce::Norm2, {0}); + if (matrix->t(k, k) > + T(0.f)) // negate on positive matrix diagonal element + norm *= T(-1.f); //.applyTransform(transform::Neg, nullptr, nullptr); + ////t(0) = -norm.t(0); + // e.t(k) = T(1.f); // e - is filled by 0 vector except diagonal element + // (filled by 1) auto tE = e; tE *= norm; + // norm.printIndexedBuffer("Norm!!!"); + e.p(k, norm); + e += currentColumn; // e += tE; // e[i] = x[i] + a * e[i] for each i from + // 0 to n - 1 + auto normE = e.reduceAlongDimension(reduce::Norm2, {0}); + e /= normE; + q[k] = vmul(e, M); + auto qQ = z.ulike(); + MmulHelper::matmul(&q[k], &z, &qQ, false, false); + z = std::move(qQ); + } + resQ.assign(q[0]); // + // MmulHelper::matmul(&q[0], matrix, &resR, false, false); + for (Nd4jLong i = 1; i < N && i < M - 1; i++) { + auto tempResQ = resQ.ulike(); + MmulHelper::matmul(&q[i], &resQ, &tempResQ, false, false); // use mmulMxM? + resQ = std::move(tempResQ); + } + MmulHelper::matmul(&resQ, matrix, &resR, false, false); + // resR *= -1.f; + resQ.transposei(); + if (fullMatricies) { + Q->assign(resQ); + R->assign(resR); + } else { + Q->assign(resQ({0, 0, 0, N})); + R->assign(resR({0, N, 0, 0})); + } +} - void qr(sd::LaunchContext* context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) { - BUILD_SINGLE_SELECTOR(input->dataType(), qr_, (input, outputQ, outputR, fullMatricies), FLOAT_TYPES); +template +void qr_(NDArray const* input, NDArray* outputQ, NDArray* outputR, + bool const fullMatricies) { + Nd4jLong lastDim = input->rankOf() - 1; + Nd4jLong preLastDim = input->rankOf() - 2; + ResultSet listOutQ( + outputQ->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + ResultSet listOutR( + outputR->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + ResultSet listInput( + input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + auto batching = PRAGMA_THREADS_FOR { + for (auto batch = start; batch < stop; batch++) { + // qr here + qrSingle(&listInput.at(batch), &listOutQ.at(batch), + &listOutR.at(batch), fullMatricies); } + }; + samediff::Threads::parallel_tad(batching, 0, listOutQ.size(), 1); } -} + +void qr(sd::LaunchContext* context, NDArray const* input, NDArray* outputQ, + NDArray* outputR, bool const fullMatricies) { + BUILD_SINGLE_SELECTOR(input->dataType(), qr_, + (input, outputQ, outputR, fullMatricies), FLOAT_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp index 1e96211b3141..bda893bddaf7 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp @@ -22,193 +22,207 @@ //#include #include //#include -#include -#include #include #include +#include +#include namespace sd { namespace ops { namespace helpers { - template - void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) { - - auto broadcasted = alpha->shapeInfo(); - if (beta != nullptr) { - const Nd4jLong* broadcastedShape = nullptr; - ShapeUtils::evalBroadcastShapeInfo(*alpha, *beta, true, broadcastedShape, context->getWorkspace()); - broadcasted = broadcastedShape; - } - - auto step = shape::length(broadcasted); - auto shift = output->lengthOf() / step; - - auto copyAlpha = alpha; - auto copyBeta = beta; - if (beta != nullptr) { - NDArray alphaBroadcasted(broadcasted, alpha->dataType(), false, context); - NDArray betaBroadcasted(broadcasted, beta->dataType(), false, context); - - copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *alpha)); - copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta)); - - } -// bool directAlpha = alpha->ews() == 1 && alpha->ordering() == 'c'; - bool directOutput = output->ews() == 1 && output->ordering() == 'c'; - T* outputBuf = output->dataBuffer()->primaryAsT(); - - PRAGMA_OMP_PARALLEL_FOR - for (Nd4jLong k = 0; k < shift; k++) { - auto pos = k * step; - auto u = rng.relativeT(k, 0., 1.); - for (Nd4jLong e = 0; e < step; e++) - if (directOutput) { - outputBuf[pos + e] = math::nd4j_igamma(copyAlpha->t(e), - beta != nullptr ? copyBeta->t(e) * u : u); - } - else { - output->t(pos + e) = math::nd4j_igamma(copyAlpha->t(e), - beta != nullptr ? copyBeta->t(e) * u : u); - } - } - - if (beta != nullptr) { - delete copyAlpha; - delete copyBeta; - //delete broadcasted; - } - } +template +void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* alpha, NDArray* beta, NDArray* output) { + auto broadcasted = alpha->shapeInfo(); + if (beta != nullptr) { + const Nd4jLong* broadcastedShape = nullptr; + ShapeUtils::evalBroadcastShapeInfo(*alpha, *beta, true, broadcastedShape, + context->getWorkspace()); + broadcasted = broadcastedShape; + } + + auto step = shape::length(broadcasted); + auto shift = output->lengthOf() / step; + + auto copyAlpha = alpha; + auto copyBeta = beta; + if (beta != nullptr) { + NDArray alphaBroadcasted(broadcasted, alpha->dataType(), false, context); + NDArray betaBroadcasted(broadcasted, beta->dataType(), false, context); + + copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast( + BroadcastOpsTuple::Assign(), *alpha)); + copyBeta = new NDArray( + betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta)); + } + // bool directAlpha = alpha->ews() == 1 && alpha->ordering() == 'c'; + bool directOutput = output->ews() == 1 && output->ordering() == 'c'; + T* outputBuf = output->dataBuffer()->primaryAsT(); + + PRAGMA_OMP_PARALLEL_FOR + for (Nd4jLong k = 0; k < shift; k++) { + auto pos = k * step; + auto u = rng.relativeT(k, 0., 1.); + for (Nd4jLong e = 0; e < step; e++) + if (directOutput) { + outputBuf[pos + e] = math::nd4j_igamma( + copyAlpha->t(e), beta != nullptr ? copyBeta->t(e) * u : u); + } else { + output->t(pos + e) = math::nd4j_igamma( + copyAlpha->t(e), beta != nullptr ? copyBeta->t(e) * u : u); + } + } + + if (beta != nullptr) { + delete copyAlpha; + delete copyBeta; + // delete broadcasted; + } +} - void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomGamma_, (context, rng, alpha, beta, output), FLOAT_NATIVE); - } - BUILD_SINGLE_TEMPLATE(template void fillRandomGamma_, (LaunchContext* context, - graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output), FLOAT_NATIVE); - - /* - * algorithm Poisson generator based upon the inversion by sequential search:[48]:505 - init: - Let x ← 0, p ← e−λ, s ← p. - Generate uniform random number u in [0,1]. - while u > s do: - x ← x + 1. - p ← p * λ / x. - s ← s + p. - return x. - * */ - template - void fillRandomPoisson_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) { - auto shift = output->lengthOf() / lambda->lengthOf(); - auto step = lambda->lengthOf(); - T* lambdaBuf = lambda->dataBuffer()->primaryAsT(); - T* outputBuf = output->dataBuffer()->primaryAsT(); - bool directLa = lambda->ews() == 1 && lambda->ordering() == 'c'; - bool directOut = output->ews() == 1 && output->ordering() == 'c'; - PRAGMA_OMP_PARALLEL_FOR - for (Nd4jLong k = 0; k < shift; k++) { - auto pos = k * step; - auto u = rng.relativeT(k, 0., 1.); - for (Nd4jLong e = 0; e < step; e++) { - auto p = math::nd4j_exp(-lambda->t(e)); - auto s = p; - auto x = T(0.f); - while (u > s) { - x += 1.f; - p *= directLa?lambdaBuf[e]/x:lambda->t(e) / x; - s += p; - } - if (directOut) - outputBuf[pos + e] = x; - else - output->t(pos + e) = x; - } - } +void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* alpha, NDArray* beta, NDArray* output) { + BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomGamma_, + (context, rng, alpha, beta, output), FLOAT_NATIVE); +} +BUILD_SINGLE_TEMPLATE(template void fillRandomGamma_, + (LaunchContext * context, graph::RandomGenerator& rng, + NDArray* alpha, NDArray* beta, NDArray* output), + FLOAT_NATIVE); + +/* + * algorithm Poisson generator based upon the inversion by sequential +search:[48]:505 init: Let x ← 0, p ← e−λ, s ← p. Generate uniform random number +u in [0,1]. while u > s do: x ← x + 1. p ← p * λ / x. s ← s + p. return x. + * */ +template +void fillRandomPoisson_(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* lambda, NDArray* output) { + auto shift = output->lengthOf() / lambda->lengthOf(); + auto step = lambda->lengthOf(); + T* lambdaBuf = lambda->dataBuffer()->primaryAsT(); + T* outputBuf = output->dataBuffer()->primaryAsT(); + bool directLa = lambda->ews() == 1 && lambda->ordering() == 'c'; + bool directOut = output->ews() == 1 && output->ordering() == 'c'; + PRAGMA_OMP_PARALLEL_FOR + for (Nd4jLong k = 0; k < shift; k++) { + auto pos = k * step; + auto u = rng.relativeT(k, 0., 1.); + for (Nd4jLong e = 0; e < step; e++) { + auto p = math::nd4j_exp(-lambda->t(e)); + auto s = p; + auto x = T(0.f); + while (u > s) { + x += 1.f; + p *= directLa ? lambdaBuf[e] / x : lambda->t(e) / x; + s += p; + } + if (directOut) + outputBuf[pos + e] = x; + else + output->t(pos + e) = x; } + } +} - void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomPoisson_, (context, rng, lambda, output), FLOAT_NATIVE); - } - BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context, - graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_TYPES); - - template - void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { - T minVal = T(0); - T maxVal = DataTypeUtils::max(); - if (min) - minVal = min->t(0); - if (max) - maxVal = max->t(0); - - if (output->isR()) - RandomLauncher::fillUniform(context, rng, output, minVal, maxVal); - else { - PRAGMA_OMP_PARALLEL_FOR - for (Nd4jLong i = 0; i < output->lengthOf(); i++) { - output->t(i) = rng.relativeT(i, minVal, maxVal); - } - } +void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* lambda, NDArray* output) { + BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomPoisson_, + (context, rng, lambda, output), FLOAT_NATIVE); +} +BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, + (LaunchContext * context, graph::RandomGenerator& rng, + NDArray* lambda, NDArray* output), + FLOAT_TYPES); + +template +void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* min, NDArray* max, NDArray* output) { + T minVal = T(0); + T maxVal = DataTypeUtils::max(); + if (min) minVal = min->t(0); + if (max) maxVal = max->t(0); + + if (output->isR()) + RandomLauncher::fillUniform(context, rng, output, minVal, maxVal); + else { + PRAGMA_OMP_PARALLEL_FOR + for (Nd4jLong i = 0; i < output->lengthOf(); i++) { + output->t(i) = rng.relativeT(i, minVal, maxVal); } + } +} - void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES); - } +void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* min, NDArray* max, NDArray* output) { + BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, + (context, rng, min, max, output), NUMERIC_TYPES); +} - // used https://en.wikipedia.org/wiki/Categorical_distribution - // methods: gumbel trick + softmax + argmax - template - void fillRandomMultiNomial_(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) { - - const Tx* x = input.bufferAsT(); - Tz* z = output.bufferAsT(); - - Tx minVal = DataTypeUtils::min(); - Tx maxVal = 1.0; - - auto dimA = (0 == dimC) ? 1 : 0; - const Nd4jLong batchValue = output.sizeAt(dimC); - const Nd4jLong numOfClassX = input.sizeAt(dimA); - - const Nd4jLong zDimAstride = output.stridesOf()[dimA]; - const Nd4jLong xDimAstride = input.stridesOf()[dimA]; - const Nd4jLong zDimCstride = output.stridesOf()[dimC]; - const Nd4jLong xDimCstride = input.stridesOf()[dimC]; - - auto func = PRAGMA_THREADS_FOR_2D{ - for (auto nBatchIndex = start_x; nBatchIndex < stop_x; nBatchIndex += inc_x) { - for (auto nSampleIndexInBatch = start_y; nSampleIndexInBatch < stop_y; nSampleIndexInBatch += inc_y) { - - const Tx* xTad = x + (nBatchIndex * xDimCstride); - Tz* zTad = z + (nBatchIndex * zDimCstride); - Tz& arg = zTad[nSampleIndexInBatch * zDimAstride]; - Tx Max = -minVal; - - auto nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples; - auto nClassesPerSample = nSampleIndexInBatch * numOfClassX; - for (Nd4jLong nClass = 0; nClass < numOfClassX; nClass += 1) { - auto nIndex = nSamplesPerBatch + nClassesPerSample + nClass; - auto unifornLog = sd::math::nd4j_log(-sd::math::nd4j_log(rng.relativeT(nIndex, minVal, maxVal))); - Tx tValue = (xTad[nClass * xDimAstride] - unifornLog); - if (tValue > Max) { - Max = tValue; - arg = nClass; - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, batchValue, 1, 0, numOfSamples, 1); - rng.rewindH(output.lengthOf()*numOfClassX); - - return; +// used https://en.wikipedia.org/wiki/Categorical_distribution +// methods: gumbel trick + softmax + argmax +template +void fillRandomMultiNomial_(LaunchContext* context, graph::RandomGenerator& rng, + NDArray& input, NDArray& output, + const Nd4jLong numOfSamples, const int dimC) { + const Tx* x = input.bufferAsT(); + Tz* z = output.bufferAsT(); + + Tx minVal = DataTypeUtils::min(); + Tx maxVal = 1.0; + + auto dimA = (0 == dimC) ? 1 : 0; + const Nd4jLong batchValue = output.sizeAt(dimC); + const Nd4jLong numOfClassX = input.sizeAt(dimA); + + const Nd4jLong zDimAstride = output.stridesOf()[dimA]; + const Nd4jLong xDimAstride = input.stridesOf()[dimA]; + const Nd4jLong zDimCstride = output.stridesOf()[dimC]; + const Nd4jLong xDimCstride = input.stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR_2D { + for (auto nBatchIndex = start_x; nBatchIndex < stop_x; + nBatchIndex += inc_x) { + for (auto nSampleIndexInBatch = start_y; nSampleIndexInBatch < stop_y; + nSampleIndexInBatch += inc_y) { + const Tx* xTad = x + (nBatchIndex * xDimCstride); + Tz* zTad = z + (nBatchIndex * zDimCstride); + Tz& arg = zTad[nSampleIndexInBatch * zDimAstride]; + Tx Max = -minVal; + + auto nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples; + auto nClassesPerSample = nSampleIndexInBatch * numOfClassX; + for (Nd4jLong nClass = 0; nClass < numOfClassX; nClass += 1) { + auto nIndex = nSamplesPerBatch + nClassesPerSample + nClass; + auto unifornLog = + sd::math::nd4j_log(-sd::math::nd4j_log( + rng.relativeT(nIndex, minVal, maxVal))); + Tx tValue = (xTad[nClass * xDimAstride] - unifornLog); + if (tValue > Max) { + Max = tValue; + arg = nClass; + } + } + } } + }; - void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) { - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), fillRandomMultiNomial_, (context, rng, input, output, numOfSamples, dimC), FLOAT_TYPES, INDEXING_TYPES); - } + samediff::Threads::parallel_for(func, 0, batchValue, 1, 0, numOfSamples, 1); + rng.rewindH(output.lengthOf() * numOfClassX); + return; } + +void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, + NDArray& input, NDArray& output, + const Nd4jLong numOfSamples, const int dimC) { + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), + fillRandomMultiNomial_, + (context, rng, input, output, numOfSamples, dimC), + FLOAT_TYPES, INDEXING_TYPES); } -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp b/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp index f24ad9025609..ff3d53f9da57 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp @@ -18,109 +18,101 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - - -#include -#include #include -#include +#include #include +#include + +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// template -void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) { - - // check edge cases first - int temp; - const int firstDim = input.sizeAt(0); - if(input.lengthOf() == 1 || firstDim == 1) { - - if(!isInplace) - output.assign(input); +void randomShuffle_(NDArray& input, NDArray& output, + sd::graph::RandomGenerator& rng, const bool isInplace) { + // check edge cases first + int temp; + const int firstDim = input.sizeAt(0); + if (input.lengthOf() == 1 || firstDim == 1) { + if (!isInplace) output.assign(input); + } else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) { + // apply Fisher-Yates shuffle + if (isInplace) { + // PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > + // Environment::getInstance()->tadThreshold()) + for (int i = firstDim - 1; i > 0; --i) { + int r = rng.relativeInt(i) % i; + if (i == r) continue; + T t0 = input.t(i); + T t1 = input.t(r); + // math::nd4j_swap(input(i), input(r)); + input.t(i) = t1; + input.t(r) = t0; + } + } else { + std::vector indices(firstDim); + std::iota(indices.begin(), indices.end(), 0); + output.p(Nd4jLong(0), input.e(0)); + + // FIXME: parallelism!! + for (int i = firstDim - 1; i > 0; --i) { + int r = rng.relativeInt(i) % i; + output.t(i) = input.t(indices[r]); + if (i == r) continue; + + output.t(r) = input.t(indices[i]); + math::nd4j_swap(indices[i], indices[r]); + } + rng.rewindH(firstDim - 1); } - else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) { - - // apply Fisher-Yates shuffle - if(isInplace) { - //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold()) - for(int i = firstDim-1; i > 0; --i) { - int r = rng.relativeInt(i) % i; - if(i == r) - continue; - T t0 = input.t(i); - T t1 = input.t(r); - //math::nd4j_swap(input(i), input(r)); - input.t(i) = t1; - input.t(r) = t0; - } - } - else { - std::vector indices(firstDim); - std::iota(indices.begin(), indices.end(), 0); - output.p(Nd4jLong(0), input.e(0)); - - // FIXME: parallelism!! - for(int i = firstDim-1; i > 0; --i) { - int r = rng.relativeInt(i) % i; - output.t(i) = input.t(indices[r]); - if(i == r) - continue; - - output.t(r) = input.t(indices[i]); - math::nd4j_swap(indices[i], indices[r]); - } - rng.rewindH(firstDim-1); - } + } else { + // evaluate sub-arrays list of input array through all dimensions excluding + // first one + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input.rankOf(), {0}); + auto subArrsListIn = input.allTensorsAlongDimension(dimensions); + + // apply Fisher-Yates shuffle + if (isInplace) { + // PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > + // Environment::getInstance()->elementwiseThreshold()) + for (int i = firstDim - 1; i > 0; --i) { + int r = rng.relativeInt(i) % i; + + if (i == r) continue; + subArrsListIn.at(i).swapUnsafe(subArrsListIn.at(r)); + } + } else { + // evaluate sub-arrays list of output array through all dimensions + // excluding first one + auto subArrsListOut = output.allTensorsAlongDimension(dimensions); + std::vector indices(firstDim); + std::iota(indices.begin(), indices.end(), 0); + bool isZeroShuffled = false; + // PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > + // Environment::getInstance()->tadThreshold()) + for (int i = firstDim - 1; i > 0; --i) { + int r = rng.relativeInt(i) % i; + subArrsListOut.at(i).assign(subArrsListIn.at(indices[r])); + if (r == 0) isZeroShuffled = true; + if (i == r) continue; + subArrsListOut.at(r).assign(subArrsListIn.at(indices[i])); + math::nd4j_swap(indices[i], indices[r]); + } + if (!isZeroShuffled) subArrsListOut.at(0).assign(subArrsListIn.at(0)); } - else { - - // evaluate sub-arrays list of input array through all dimensions excluding first one - std::vector dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0}); - auto subArrsListIn = input.allTensorsAlongDimension(dimensions); - - // apply Fisher-Yates shuffle - if(isInplace) { - //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->elementwiseThreshold()) - for(int i = firstDim - 1; i > 0; --i) { - int r = rng.relativeInt(i) % i; - - if(i == r) - continue; - subArrsListIn.at(i).swapUnsafe(subArrsListIn.at(r)); - } - } - else { - // evaluate sub-arrays list of output array through all dimensions excluding first one - auto subArrsListOut = output.allTensorsAlongDimension(dimensions); - std::vector indices(firstDim); - std::iota(indices.begin(), indices.end(), 0); - bool isZeroShuffled = false; - //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold()) - for(int i = firstDim - 1; i > 0; --i) { - int r = rng.relativeInt(i) % i; - subArrsListOut.at(i).assign(subArrsListIn.at(indices[r])); - if(r == 0) - isZeroShuffled = true; - if(i == r) - continue; - subArrsListOut.at(r).assign(subArrsListIn.at(indices[i])); - math::nd4j_swap(indices[i], indices[r]); - } - if(!isZeroShuffled) - subArrsListOut.at(0).assign(subArrsListIn.at(0)); - } - rng.rewindH(firstDim-1); - } - + rng.rewindH(firstDim - 1); + } } - void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) { - BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES); - } -} -} +void randomShuffle(sd::LaunchContext* context, NDArray& input, NDArray& output, + sd::graph::RandomGenerator& rng, const bool isInplace) { + BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, + (input, output, rng, isInplace), LIBND4J_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/random_crop.cpp b/libnd4j/include/ops/declarable/helpers/cpu/random_crop.cpp index 34be299b7619..7083b9f21f96 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/random_crop.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/random_crop.cpp @@ -20,50 +20,65 @@ #include //#include -#include -#include #include + +#include +#include namespace sd { namespace ops { namespace helpers { - template - static int _randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, NDArray* output, int seed) { - graph::RandomGenerator rngX(context.getRng()); - //functions::random::RandomFunction::template execTransform>(rng, output->buffer(), output->shapeInfo(), std::vector({T(0.), shape->e(last)}).data()); - //NativeOpExecutioner::execRandom(random::UniformDistribution, rng, output->buffer(), output->shapeInfo(), std::vector({T(0.), shape->e(last)}).data()); - Nd4jLong last = shape->lengthOf() - 1; +template +static int _randomCropFunctor(graph::Context& context, NDArray* input, + NDArray* shape, NDArray* output, int seed) { + graph::RandomGenerator rngX(context.getRng()); + // functions::random::RandomFunction::template + // execTransform>(rng, output->buffer(), + // output->shapeInfo(), std::vector({T(0.), shape->e(last)}).data()); + // NativeOpExecutioner::execRandom(random::UniformDistribution, rng, + // output->buffer(), output->shapeInfo(), std::vector({T(0.), + // shape->e(last)}).data()); + Nd4jLong last = shape->lengthOf() - 1; - rngX.setSeed(seed); - //functions::random::RandomFunction::template execTransform>(rng, output->buffer(), output->shapeInfo(), std::vector({T(0.), shape->getScalar(last)}).data()); - for (Nd4jLong e = 0; e < output->lengthOf(); ++e) { - output->p(e, rngX.relativeT(e, 0, shape->e(last))); - } - Nd4jLong maxIndex = output->argMax(); - Nd4jLong startPos = output->e(maxIndex); - Nd4jLong lastDim = input->sizeAt(-1); - // nd4j_printf("Before processing: %i %i. Output length %i\n", maxIndex, startPos, output->lengthOf()); - Nd4jLong pos = 0; - Nd4jLong width = startPos + shape->e(last); - if (width >= lastDim) { - startPos -= (width - lastDim); - width = lastDim; - } + rngX.setSeed(seed); + // functions::random::RandomFunction::template + // execTransform>(rng, output->buffer(), + // output->shapeInfo(), std::vector({T(0.), + // shape->getScalar(last)}).data()); + for (Nd4jLong e = 0; e < output->lengthOf(); ++e) { + output->p(e, rngX.relativeT(e, 0, shape->e(last))); + } + Nd4jLong maxIndex = output->argMax(); + Nd4jLong startPos = output->e(maxIndex); + Nd4jLong lastDim = input->sizeAt(-1); + // nd4j_printf("Before processing: %i %i. Output length %i\n", maxIndex, + // startPos, output->lengthOf()); + Nd4jLong pos = 0; + Nd4jLong width = startPos + shape->e(last); + if (width >= lastDim) { + startPos -= (width - lastDim); + width = lastDim; + } - for (Nd4jLong i = 0; i < input->lengthOf(); i += lastDim) { - for (Nd4jLong k = startPos; k < width && pos < output->lengthOf(); k++) { - output->p(pos++, input->e(i + k)); - } - } - return ND4J_STATUS_OK; + for (Nd4jLong i = 0; i < input->lengthOf(); i += lastDim) { + for (Nd4jLong k = startPos; k < width && pos < output->lengthOf(); k++) { + output->p(pos++, input->e(i + k)); } + } + return ND4J_STATUS_OK; +} - int randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, NDArray* output, int seed) { - BUILD_SINGLE_SELECTOR(input->dataType(), return _randomCropFunctor, (context, input, shape, output, seed), FLOAT_TYPES); - } +int randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, + NDArray* output, int seed) { + BUILD_SINGLE_SELECTOR(input->dataType(), return _randomCropFunctor, + (context, input, shape, output, seed), FLOAT_TYPES); +} - BUILD_SINGLE_TEMPLATE(template int _randomCropFunctor, (graph::Context& context, NDArray* input, NDArray* shape, NDArray* output, int seed), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template int _randomCropFunctor, + (graph::Context & context, NDArray* input, NDArray* shape, + NDArray* output, int seed), + FLOAT_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/range.cpp b/libnd4j/include/ops/declarable/helpers/cpu/range.cpp index eb2cbd76047f..9a82e461a6f9 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/range.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/range.cpp @@ -18,40 +18,41 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 27.08.2018 // - -#include #include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// // be careful: outVector must have c-order and ews = 1 !!! template -static void _range(const NDArray& start, const NDArray& delta, NDArray& outVector) { - - const Nd4jLong len = outVector.lengthOf(); - - auto buff = reinterpret_cast(outVector.buffer()); - auto s = start.e(0); - auto d = delta.e(0); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - buff[i] = s + i * d; - }; - samediff::Threads::parallel_for(func, 0, len); +static void _range(const NDArray& start, const NDArray& delta, + NDArray& outVector) { + const Nd4jLong len = outVector.lengthOf(); + + auto buff = reinterpret_cast(outVector.buffer()); + auto s = start.e(0); + auto d = delta.e(0); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) buff[i] = s + i * d; + }; + samediff::Threads::parallel_for(func, 0, len); } - void range(sd::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector) { - BUILD_SINGLE_SELECTOR(outVector.dataType(), _range, (start, delta, outVector), LIBND4J_TYPES); - } - -BUILD_SINGLE_TEMPLATE(template void _range, (const NDArray& start, const NDArray& delta, NDArray& outVector), LIBND4J_TYPES); +void range(sd::LaunchContext* context, const NDArray& start, + const NDArray& delta, NDArray& outVector) { + BUILD_SINGLE_SELECTOR(outVector.dataType(), _range, (start, delta, outVector), + LIBND4J_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void _range, + (const NDArray& start, const NDArray& delta, + NDArray& outVector), + LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp b/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp index 654a9decde94..476568877c4e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp @@ -18,201 +18,209 @@ // @author Yurii Shyrma, created on 16.04.2018 // -#include -#include #include #include +#include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { template inline void swap(T* arr, Nd4jLong from, Nd4jLong to) { - T tmp = arr[from]; - arr[from] = arr[to]; - arr[to] = tmp; + T tmp = arr[from]; + arr[from] = arr[to]; + arr[to] = tmp; } ///////////////////////////////////////////////////////////////////////////////////// // this legacy op is written by raver119@gmail.com -template -static void reverseArray(sd::LaunchContext * context, void const* vinArr, Nd4jLong const*inShapeBuffer, void *voutArr, Nd4jLong const*outShapeBuffer, int numOfElemsToReverse = 0) { - auto inArr = reinterpret_cast(vinArr); - auto outArr = reinterpret_cast(voutArr); - - Nd4jLong inLength = shape::length(inShapeBuffer); - Nd4jLong outLength = shape::length(outShapeBuffer); - if(numOfElemsToReverse == 0) - numOfElemsToReverse = inLength; - int inEWS = shape::elementWiseStride(inShapeBuffer); - char inOrder = shape::order(inShapeBuffer); - auto sLength = numOfElemsToReverse - 1; - - // two step phase here - if (inArr == outArr) { - if (inEWS == 1) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto idx = sLength - e; - swap(const_cast(inArr), e, idx); - } - }; - samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2); - } - else if (inEWS > 1) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto idx1 = (sLength - e) * inEWS; - Nd4jLong idx2 = e * inEWS; - swap(const_cast(inArr), idx1, idx2); - } - }; - - samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2); - } - else { - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto inOffset = shape::getIndexOffset(e, inShapeBuffer); - auto outOffset = shape::getIndexOffset(sLength - e, inShapeBuffer); - swap(outArr, inOffset, outOffset); - } - }; - - samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2); - } - } - else { - // single step phase here - auto outEWS = shape::elementWiseStride(outShapeBuffer); - char outOrder = shape::order(outShapeBuffer); - - if (inEWS == 1 && outEWS == 1 && inOrder == outOrder) { - - auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong e = start; e < stop; e++) - outArr[sLength - e] = inArr[e]; - }; - samediff::Threads::parallel_for(func, 0, numOfElemsToReverse); - - if(inLength != numOfElemsToReverse) { - auto f2 = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - outArr[e] = inArr[e]; - }; - samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength); - } - } - else if (inEWS >= 1 && outEWS >= 1 && inOrder == outOrder) { - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - outArr[(sLength - e) * outEWS] = inArr[e * inEWS]; - }; - samediff::Threads::parallel_for(func, 0, numOfElemsToReverse); - - if(inLength != numOfElemsToReverse) { - auto f2 = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) - outArr[e * outEWS] = inArr[e * inEWS]; - }; - samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength); - } - } - else { - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto inOffset = shape::getIndexOffset(e, inShapeBuffer); - auto outOffset = shape::getIndexOffset(sLength - e, outShapeBuffer); - outArr[outOffset] = inArr[inOffset]; - } - }; - samediff::Threads::parallel_for(func, 0, numOfElemsToReverse); - - if(inLength != numOfElemsToReverse) { - - auto f2 = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto inOffset = shape::getIndexOffset(e, inShapeBuffer); - auto outOffset = shape::getIndexOffset(e, outShapeBuffer); - outArr[outOffset] = inArr[inOffset]; - } - }; - samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength); - } - } - } -} - - -/////////////////////////////////////////////////////////////////// template -static void reverseSequence_(sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim){ - - int posOfNonUnityDim = -1; - if(input->isVector() || shape::isLikeVector(input->shapeInfo(), posOfNonUnityDim)) { +static void reverseArray(sd::LaunchContext* context, void const* vinArr, + Nd4jLong const* inShapeBuffer, void* voutArr, + Nd4jLong const* outShapeBuffer, + int numOfElemsToReverse = 0) { + auto inArr = reinterpret_cast(vinArr); + auto outArr = reinterpret_cast(voutArr); + + Nd4jLong inLength = shape::length(inShapeBuffer); + Nd4jLong outLength = shape::length(outShapeBuffer); + if (numOfElemsToReverse == 0) numOfElemsToReverse = inLength; + int inEWS = shape::elementWiseStride(inShapeBuffer); + char inOrder = shape::order(inShapeBuffer); + auto sLength = numOfElemsToReverse - 1; + + // two step phase here + if (inArr == outArr) { + if (inEWS == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto idx = sLength - e; + swap(const_cast(inArr), e, idx); + } + }; + samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2); + } else if (inEWS > 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto idx1 = (sLength - e) * inEWS; + Nd4jLong idx2 = e * inEWS; + swap(const_cast(inArr), idx1, idx2); + } + }; + + samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto inOffset = shape::getIndexOffset(e, inShapeBuffer); + auto outOffset = shape::getIndexOffset(sLength - e, inShapeBuffer); + swap(outArr, inOffset, outOffset); + } + }; - if((seqDim == 0 && input->sizeAt(0) == 1) || (batchDim == posOfNonUnityDim)) - output->assign(input); - else - helpers::reverseArray(context, const_cast(input)->buffer(), const_cast(input)->shapeInfo(), output->buffer(), output->shapeInfo(), seqLengths->e(0)); + samediff::Threads::parallel_for(func, 0, numOfElemsToReverse / 2); } - else { - - if(seqDim > batchDim) - --seqDim; - - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {batchDim}); - - auto inSubArrsSet = input->allTensorsAlongDimension(dimensions); - auto outSubArrsSet = output->allTensorsAlongDimension(dimensions); - - for(int i = 0; i < inSubArrsSet.size(); ++i) { - - Nd4jLong numOfElemsToReverse = seqLengths->e(i); - - if(numOfElemsToReverse == 0 || numOfElemsToReverse == 1) { - outSubArrsSet.at(i).assign(inSubArrsSet.at(i)); - } - else { - auto inInnerSet = inSubArrsSet.at(i).allTensorsAlongDimension({seqDim}); - auto outInnerSet = outSubArrsSet.at(i).allTensorsAlongDimension({seqDim}); - for(int j = 0; j < inInnerSet.size(); ++j) - helpers::reverseArray(context, inInnerSet.at(j).buffer(), inInnerSet.at(j).shapeInfo(), outInnerSet.at(j).buffer(), outInnerSet.at(j).shapeInfo(), numOfElemsToReverse); - } + } else { + // single step phase here + auto outEWS = shape::elementWiseStride(outShapeBuffer); + char outOrder = shape::order(outShapeBuffer); + + if (inEWS == 1 && outEWS == 1 && inOrder == outOrder) { + auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong e = start; e < stop; e++) outArr[sLength - e] = inArr[e]; + }; + samediff::Threads::parallel_for(func, 0, numOfElemsToReverse); + + if (inLength != numOfElemsToReverse) { + auto f2 = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) outArr[e] = inArr[e]; + }; + samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength); + } + } else if (inEWS >= 1 && outEWS >= 1 && inOrder == outOrder) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) + outArr[(sLength - e) * outEWS] = inArr[e * inEWS]; + }; + samediff::Threads::parallel_for(func, 0, numOfElemsToReverse); + + if (inLength != numOfElemsToReverse) { + auto f2 = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) + outArr[e * outEWS] = inArr[e * inEWS]; + }; + samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength); + } + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto inOffset = shape::getIndexOffset(e, inShapeBuffer); + auto outOffset = shape::getIndexOffset(sLength - e, outShapeBuffer); + outArr[outOffset] = inArr[inOffset]; } + }; + samediff::Threads::parallel_for(func, 0, numOfElemsToReverse); + + if (inLength != numOfElemsToReverse) { + auto f2 = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto inOffset = shape::getIndexOffset(e, inShapeBuffer); + auto outOffset = shape::getIndexOffset(e, outShapeBuffer); + outArr[outOffset] = inArr[inOffset]; + } + }; + samediff::Threads::parallel_for(f2, numOfElemsToReverse, inLength); + } } + } } - void reverseSequence(sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) { - BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES); - } - -////////////////////////////////////////////////////////////////////////// -void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector* intArgs, bool isBackProp) { - - // we need to reverse axis only if that's new op - std::vector dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs; - - auto listOut = output->allTensorsAlongDimension(dimensions); - auto listIn = input->allTensorsAlongDimension(dimensions); - - for(int i = 0; i < listIn.size(); ++i) { // listIn.size() = listOut.size() - auto subArrIn = listIn.at(i); - auto subArrOut = listOut.at(i); - BUILD_SINGLE_SELECTOR(input->dataType(), helpers::reverseArray, (context, subArrIn.buffer(), subArrIn.shapeInfo(), subArrOut.buffer(), subArrOut.shapeInfo()), LIBND4J_TYPES); +/////////////////////////////////////////////////////////////////// +template +static void reverseSequence_(sd::LaunchContext* context, const NDArray* input, + const NDArray* seqLengths, NDArray* output, + int seqDim, const int batchDim) { + int posOfNonUnityDim = -1; + if (input->isVector() || + shape::isLikeVector(input->shapeInfo(), posOfNonUnityDim)) { + if ((seqDim == 0 && input->sizeAt(0) == 1) || + (batchDim == posOfNonUnityDim)) + output->assign(input); + else + helpers::reverseArray(context, const_cast(input)->buffer(), + const_cast(input)->shapeInfo(), + output->buffer(), output->shapeInfo(), + seqLengths->e(0)); + } else { + if (seqDim > batchDim) --seqDim; + + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {batchDim}); + + auto inSubArrsSet = input->allTensorsAlongDimension(dimensions); + auto outSubArrsSet = output->allTensorsAlongDimension(dimensions); + + for (int i = 0; i < inSubArrsSet.size(); ++i) { + Nd4jLong numOfElemsToReverse = seqLengths->e(i); + + if (numOfElemsToReverse == 0 || numOfElemsToReverse == 1) { + outSubArrsSet.at(i).assign(inSubArrsSet.at(i)); + } else { + auto inInnerSet = inSubArrsSet.at(i).allTensorsAlongDimension({seqDim}); + auto outInnerSet = + outSubArrsSet.at(i).allTensorsAlongDimension({seqDim}); + for (int j = 0; j < inInnerSet.size(); ++j) + helpers::reverseArray( + context, inInnerSet.at(j).buffer(), inInnerSet.at(j).shapeInfo(), + outInnerSet.at(j).buffer(), outInnerSet.at(j).shapeInfo(), + numOfElemsToReverse); + } } + } } -BUILD_SINGLE_TEMPLATE(template void reverseSequence_, (sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim), LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template void reverseArray, (sd::LaunchContext * context, void const*inArr, Nd4jLong const*inShapeBuffer, void* outArr, Nd4jLong const* outShapeBuffer, int numOfElemsToReverse), LIBND4J_TYPES); - - -} +void reverseSequence(sd::LaunchContext* context, const NDArray* input, + const NDArray* seqLengths, NDArray* output, int seqDim, + const int batchDim) { + BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, + (context, input, seqLengths, output, seqDim, batchDim), + LIBND4J_TYPES); } + +////////////////////////////////////////////////////////////////////////// +void reverse(sd::LaunchContext* context, const NDArray* input, NDArray* output, + const std::vector* intArgs, bool isBackProp) { + // we need to reverse axis only if that's new op + std::vector dimensions = + isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) + : *intArgs; + + auto listOut = output->allTensorsAlongDimension(dimensions); + auto listIn = input->allTensorsAlongDimension(dimensions); + + for (int i = 0; i < listIn.size(); ++i) { // listIn.size() = listOut.size() + auto subArrIn = listIn.at(i); + auto subArrOut = listOut.at(i); + BUILD_SINGLE_SELECTOR(input->dataType(), helpers::reverseArray, + (context, subArrIn.buffer(), subArrIn.shapeInfo(), + subArrOut.buffer(), subArrOut.shapeInfo()), + LIBND4J_TYPES); + } } +BUILD_SINGLE_TEMPLATE(template void reverseSequence_, + (sd::LaunchContext * context, const NDArray* input, + const NDArray* seqLengths, NDArray* output, int seqDim, + const int batchDim), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void reverseArray, + (sd::LaunchContext * context, void const* inArr, + Nd4jLong const* inShapeBuffer, void* outArr, + Nd4jLong const* outShapeBuffer, int numOfElemsToReverse), + LIBND4J_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp b/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp index 228435d2b64d..379f17e6c362 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp @@ -24,138 +24,143 @@ namespace sd { namespace ops { namespace helpers { - template - static void rollFunctorLinear_(NDArray* input, NDArray* output, int shift, bool inplace){ - auto source = input; - if (!inplace) - output->assign(input); - - int fullLen = source->lengthOf(); - int actualShift = shift; // % fullLen; // shift already non-negative then - if (actualShift < 0) { - actualShift -= fullLen * (actualShift / fullLen - 1); - } - else - actualShift %= fullLen; - - if (actualShift) { - int shiftCount = fullLen / actualShift - 1; - int remainShift = fullLen % actualShift; - - // stage 1) swap last actualShift elements with first ones. - //PRAGMA_OMP_PARALLEL_FOR //_IF(actualShift > Environment::getInstance()->elementwiseThreshold()) - for (int e = 0; e < actualShift; ++e) { - int sourceIndex = fullLen - actualShift + e; - - auto _e0 = output->e(e); - auto _e1 = output->e(sourceIndex); - - //sd::math::nd4j_swap((*output)(e), (*output)(sourceIndex)); - output->p(e, _e1); - output->p(sourceIndex, _e0); - } - - // stage 2) swap swapped actualShift elements with rest remainShiftCount times. - //PRAGMA_OMP_PARALLEL_FOR //_IF(shiftCount > Environment::getInstance()->tadThreshold()) - for (int count = 1; count < shiftCount; ++count) { - for (int e = 0; e < actualShift; ++e) { - int destinationIndex = fullLen - (count + 1) * actualShift + e; - int sourceIndex = fullLen - count * actualShift + e; - - auto _e0 = output->e(destinationIndex); - auto _e1 = output->e(sourceIndex); - - //sd::math::nd4j_swap((*output)(destinationIndex), (*output)(sourceIndex)); - output->p(destinationIndex, _e1); - output->p(sourceIndex, _e0); - } - } - - // stage 3) swap remainer of items. - if (remainShift && shiftCount) - for (int i = actualShift; i < 2 * actualShift; ++i) { - auto _e0 = output->e(i); - auto _e1 = output->e(i + remainShift); - - //sd::math::nd4j_swap((*output)(i), (*output)(i + remainShift)); - - output->p(i, _e1); - output->p(i + remainShift, _e0); - } - } +template +static void rollFunctorLinear_(NDArray* input, NDArray* output, int shift, + bool inplace) { + auto source = input; + if (!inplace) output->assign(input); + + int fullLen = source->lengthOf(); + int actualShift = shift; // % fullLen; // shift already non-negative then + if (actualShift < 0) { + actualShift -= fullLen * (actualShift / fullLen - 1); + } else + actualShift %= fullLen; + + if (actualShift) { + int shiftCount = fullLen / actualShift - 1; + int remainShift = fullLen % actualShift; + + // stage 1) swap last actualShift elements with first ones. + // PRAGMA_OMP_PARALLEL_FOR //_IF(actualShift > + // Environment::getInstance()->elementwiseThreshold()) + for (int e = 0; e < actualShift; ++e) { + int sourceIndex = fullLen - actualShift + e; + + auto _e0 = output->e(e); + auto _e1 = output->e(sourceIndex); + + // sd::math::nd4j_swap((*output)(e), (*output)(sourceIndex)); + output->p(e, _e1); + output->p(sourceIndex, _e0); } - void rollFunctorFull(sd::LaunchContext * context, NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace){ - - if (!inplace) - output->assign(input); - - auto source = output; //input; - for (size_t i = 0; i < axes.size(); i++) { - int axe = axes[i]; - if (axe == source->rankOf() - 1) {// last dimension - ResultSet listOfTensors = source->allTensorsAlongDimension({axe}); - ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe}); - int fullLen = listOfTensors.size(); - int theShift = shifts[i]; - if (theShift > 0) { - theShift %= fullLen; - } - else { - theShift -= fullLen * (theShift / fullLen - 1); - } - for (int k = 0; k < fullLen; k++) { - rollFunctorLinear(context, &listOfTensors.at(k), &listOfOutTensors.at(k), theShift, true); - } - } - else { - std::vector dims(source->rankOf() - axe - 1); - for (size_t i = 0; i < dims.size(); ++i) - dims[i] = axe + 1 + i; - - ResultSet listOfTensors = source->allTensorsAlongDimension({dims}); - ResultSet listOfOutTensors = output->allTensorsAlongDimension({dims}); - // - int fullLen = listOfTensors.size(); - int sizeAt = input->sizeAt(axe); - - int theShift = shifts[i]; - - if (theShift > 0) { - theShift %= sizeAt; - } - else { - theShift -= sizeAt * (theShift / sizeAt - 1); - } - - if (theShift) { - for (int dim = 0; dim < fullLen / sizeAt; ++dim) { - for (int e = theShift; e < sizeAt - theShift; ++e) { - auto sourceM = listOfTensors.at(dim * sizeAt + e - theShift); - auto targetM = listOfOutTensors.at(dim * sizeAt + e); - sourceM.swapUnsafe(targetM); - } - - for (int e = 0; e < theShift; ++e) { - int sourceIndex = dim * sizeAt + sizeAt - theShift + e; - auto sourceM = listOfTensors.at(sourceIndex); - auto targetM = listOfOutTensors.at(dim * sizeAt + e); - - sourceM.swapUnsafe(targetM); - } - } - } - } -// if (!inplace) -// source = output; - } + // stage 2) swap swapped actualShift elements with rest remainShiftCount + // times. + // PRAGMA_OMP_PARALLEL_FOR //_IF(shiftCount > + // Environment::getInstance()->tadThreshold()) + for (int count = 1; count < shiftCount; ++count) { + for (int e = 0; e < actualShift; ++e) { + int destinationIndex = fullLen - (count + 1) * actualShift + e; + int sourceIndex = fullLen - count * actualShift + e; + + auto _e0 = output->e(destinationIndex); + auto _e1 = output->e(sourceIndex); + + // sd::math::nd4j_swap((*output)(destinationIndex), + // (*output)(sourceIndex)); + output->p(destinationIndex, _e1); + output->p(sourceIndex, _e0); + } } - void rollFunctorLinear(sd::LaunchContext * context, NDArray* input, NDArray* output, int shift, bool inplace){ - BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorLinear_, (input, output, shift, inplace), LIBND4J_TYPES); - } + // stage 3) swap remainer of items. + if (remainShift && shiftCount) + for (int i = actualShift; i < 2 * actualShift; ++i) { + auto _e0 = output->e(i); + auto _e1 = output->e(i + remainShift); + + // sd::math::nd4j_swap((*output)(i), (*output)(i + remainShift)); - BUILD_SINGLE_TEMPLATE(template void rollFunctorLinear_, (NDArray* input, NDArray* output, int shift, bool inplace), LIBND4J_TYPES); + output->p(i, _e1); + output->p(i + remainShift, _e0); + } + } } + +void rollFunctorFull(sd::LaunchContext* context, NDArray* input, + NDArray* output, std::vector const& shifts, + std::vector const& axes, bool inplace) { + if (!inplace) output->assign(input); + + auto source = output; // input; + for (size_t i = 0; i < axes.size(); i++) { + int axe = axes[i]; + if (axe == source->rankOf() - 1) { // last dimension + ResultSet listOfTensors = source->allTensorsAlongDimension({axe}); + ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe}); + int fullLen = listOfTensors.size(); + int theShift = shifts[i]; + if (theShift > 0) { + theShift %= fullLen; + } else { + theShift -= fullLen * (theShift / fullLen - 1); + } + for (int k = 0; k < fullLen; k++) { + rollFunctorLinear(context, &listOfTensors.at(k), + &listOfOutTensors.at(k), theShift, true); + } + } else { + std::vector dims(source->rankOf() - axe - 1); + for (size_t i = 0; i < dims.size(); ++i) dims[i] = axe + 1 + i; + + ResultSet listOfTensors = source->allTensorsAlongDimension({dims}); + ResultSet listOfOutTensors = output->allTensorsAlongDimension({dims}); + // + int fullLen = listOfTensors.size(); + int sizeAt = input->sizeAt(axe); + + int theShift = shifts[i]; + + if (theShift > 0) { + theShift %= sizeAt; + } else { + theShift -= sizeAt * (theShift / sizeAt - 1); + } + + if (theShift) { + for (int dim = 0; dim < fullLen / sizeAt; ++dim) { + for (int e = theShift; e < sizeAt - theShift; ++e) { + auto sourceM = listOfTensors.at(dim * sizeAt + e - theShift); + auto targetM = listOfOutTensors.at(dim * sizeAt + e); + sourceM.swapUnsafe(targetM); + } + + for (int e = 0; e < theShift; ++e) { + int sourceIndex = dim * sizeAt + sizeAt - theShift + e; + auto sourceM = listOfTensors.at(sourceIndex); + auto targetM = listOfOutTensors.at(dim * sizeAt + e); + + sourceM.swapUnsafe(targetM); + } + } + } + } + // if (!inplace) + // source = output; + } } -} \ No newline at end of file + +void rollFunctorLinear(sd::LaunchContext* context, NDArray* input, + NDArray* output, int shift, bool inplace) { + BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorLinear_, + (input, output, shift, inplace), LIBND4J_TYPES); +} + +BUILD_SINGLE_TEMPLATE(template void rollFunctorLinear_, + (NDArray * input, NDArray* output, int shift, + bool inplace), + LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/s_t_b.cpp b/libnd4j/include/ops/declarable/helpers/cpu/s_t_b.cpp index 99a172c0282d..8b70dad9d863 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/s_t_b.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/s_t_b.cpp @@ -19,399 +19,433 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// template -static void batchToSpace_(const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight) { - - // input [bS, H * blockSize, W * blockSize, iC] - // output [bS, H * blockSize - cropBottom - cropTop, W * blockSize - cropLeft - cropRight, iC] - - // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same - // else: - // oH -> [cropBottom, iH - cropTop] - // oW -> [cropLeft, iH - cropRight] - // xLen > zLen - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); +static void batchToSpace_(const NDArray& input, NDArray& output, + const uint cropBottom, const uint cropTop, + const uint cropLeft, const uint cropRight) { + // input [bS, H * blockSize, W * blockSize, iC] + // output [bS, H * blockSize - cropBottom - cropTop, W * blockSize - cropLeft + // - cropRight, iC] + + // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same + // else: + // oH -> [cropBottom, iH - cropTop] + // oW -> [cropLeft, iH - cropRight] + // xLen > zLen + + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); + + const int rank = 4; + + const Nd4jLong* xShapeInfo = input.shapeInfo(); + const Nd4jLong* zShapeInfo = output.shapeInfo(); + + const uint bS = xShapeInfo[1]; + const uint iH = xShapeInfo[2]; + const uint iW = xShapeInfo[3]; + const uint iC = xShapeInfo[4]; + + // loop through output array + auto func = PRAGMA_THREADS_FOR_3D { + for (auto b = start_x; b < stop_x; b += inc_x) { + for (auto h = start_y; h < stop_y; h += inc_y) { + for (auto w = start_z; w < stop_z; w += inc_z) { + for (uint c = 0; c < iC; ++c) { + const Nd4jLong xOffset = b * xShapeInfo[5] + h * xShapeInfo[6] + + w * xShapeInfo[7] + c * xShapeInfo[8]; + const Nd4jLong zOffset = + b * zShapeInfo[5] + (h - cropBottom) * zShapeInfo[6] + + (w - cropLeft) * zShapeInfo[7] + c * zShapeInfo[8]; - const int rank = 4; - - const Nd4jLong* xShapeInfo = input.shapeInfo(); - const Nd4jLong* zShapeInfo = output.shapeInfo(); - - const uint bS = xShapeInfo[1]; - const uint iH = xShapeInfo[2]; - const uint iW = xShapeInfo[3]; - const uint iC = xShapeInfo[4]; - - // loop through output array - auto func = PRAGMA_THREADS_FOR_3D { - for (auto b = start_x; b < stop_x; b += inc_x) { - for (auto h = start_y; h < stop_y; h += inc_y) { - for (auto w = start_z; w < stop_z; w += inc_z) { - for (uint c = 0; c < iC; ++c) { - const Nd4jLong xOffset = b * xShapeInfo[5] + h * xShapeInfo[6] + w * xShapeInfo[7] + c * xShapeInfo[8]; - const Nd4jLong zOffset = b * zShapeInfo[5] + (h - cropBottom) * zShapeInfo[6] + (w - cropLeft) * zShapeInfo[7] + c * zShapeInfo[8]; - - z[zOffset] = x[xOffset]; - } - } - } + z[zOffset] = x[xOffset]; + } } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0, bS, 1, cropBottom, iH - cropTop, 1, cropLeft, iW - cropRight, 1); + samediff::Threads::parallel_for(func, 0, bS, 1, cropBottom, iH - cropTop, 1, + cropLeft, iW - cropRight, 1); } -BUILD_SINGLE_TEMPLATE(template void batchToSpace_, (const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void batchToSpace_, + (const NDArray& input, NDArray& output, + const uint cropBottom, const uint cropTop, + const uint cropLeft, const uint cropRight), + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -void batchToSpace(sd::LaunchContext* context, const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight, const uint blockSize) { - - // [bS*blockSize*blockSize, H/blockSize, W/blockSize, iC] is rearranged/permuted to [bS, oH, oW, iC] - // oH = H - cropTop - cropBottom - // oW = W - cropLeft - cropRight - - NDArray inputRearranged0 = input.reshape(input.ordering(), {blockSize, blockSize, output.sizeAt(0), input.sizeAt(1), input.sizeAt(2), input.sizeAt(3)}); - inputRearranged0.permutei({2, 3,0, 4,1, 5}); - - if(input.lengthOf() == output.lengthOf()) - output.assign(inputRearranged0); - else { - NDArray inputRearranged1 = inputRearranged0.reshape(input.ordering(), {output.sizeAt(0), input.sizeAt(1) * blockSize, input.sizeAt(2) * blockSize, input.sizeAt(3)}); - BUILD_SINGLE_SELECTOR(input.dataType(), batchToSpace_, (inputRearranged1, output, cropBottom, cropTop, cropLeft, cropRight), LIBND4J_TYPES); - } +void batchToSpace(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const uint cropBottom, const uint cropTop, + const uint cropLeft, const uint cropRight, + const uint blockSize) { + // [bS*blockSize*blockSize, H/blockSize, W/blockSize, iC] is + // rearranged/permuted to [bS, oH, oW, iC] oH = H - cropTop - cropBottom oW = + // W - cropLeft - cropRight + + NDArray inputRearranged0 = input.reshape( + input.ordering(), {blockSize, blockSize, output.sizeAt(0), + input.sizeAt(1), input.sizeAt(2), input.sizeAt(3)}); + inputRearranged0.permutei({2, 3, 0, 4, 1, 5}); + + if (input.lengthOf() == output.lengthOf()) + output.assign(inputRearranged0); + else { + NDArray inputRearranged1 = inputRearranged0.reshape( + input.ordering(), {output.sizeAt(0), input.sizeAt(1) * blockSize, + input.sizeAt(2) * blockSize, input.sizeAt(3)}); + BUILD_SINGLE_SELECTOR( + input.dataType(), batchToSpace_, + (inputRearranged1, output, cropBottom, cropTop, cropLeft, cropRight), + LIBND4J_TYPES); + } } ////////////////////////////////////////////////////////////////////////// template -static void batchToSpaceND_(const NDArray& input, const NDArray& crop, NDArray& output, const uint numOfSpatialDims) { - - // input [bS, H * blockShape[0], W * blockShape[1], iC] - // output [bS, H * blockShape[0] - cropBottom - cropTop, W * blockShape[1] - cropLeft - cropRight, iC] - - // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same - // else: - // oH -> [cropBottom, iH - cropTop] - // oW -> [cropLeft, iH - cropRight] - // xLen >= zLen +static void batchToSpaceND_(const NDArray& input, const NDArray& crop, + NDArray& output, const uint numOfSpatialDims) { + // input [bS, H * blockShape[0], W * blockShape[1], iC] + // output [bS, H * blockShape[0] - cropBottom - cropTop, W * blockShape[1] - + // cropLeft - cropRight, iC] - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); + // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same + // else: + // oH -> [cropBottom, iH - cropTop] + // oW -> [cropLeft, iH - cropRight] + // xLen >= zLen - const int rank = input.rankOf(); - const Nd4jLong zLen = output.lengthOf(); + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); - // loop through input array - auto func = PRAGMA_THREADS_FOR { + const int rank = input.rankOf(); + const Nd4jLong zLen = output.lengthOf(); - int zCoords[MAX_RANK], xCoords[MAX_RANK]; + // loop through input array + auto func = PRAGMA_THREADS_FOR { + int zCoords[MAX_RANK], xCoords[MAX_RANK]; - for (auto i = start; i < stop; i++) { + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); - shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); + memcpy(xCoords, zCoords, rank * sizeof(int)); - memcpy(xCoords, zCoords, rank * sizeof(int)); + // evaluate spatial coordinates for x + for (uint j = 1; j <= numOfSpatialDims; ++j) + xCoords[j] += crop.e(j - 1, 0); // add crop left - // evaluate spatial coordinates for x - for (uint j = 1; j <= numOfSpatialDims; ++j) - xCoords[j] += crop.e(j - 1, 0); // add crop left + const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); + const auto xOffset = shape::getOffset(input.shapeInfo(), xCoords); - const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); - const auto xOffset = shape::getOffset(input.shapeInfo(), xCoords); - - z[zOffset] = x[xOffset]; - } - }; + z[zOffset] = x[xOffset]; + } + }; - samediff::Threads::parallel_tad(func, 0, zLen); + samediff::Threads::parallel_tad(func, 0, zLen); } -BUILD_SINGLE_TEMPLATE(template void batchToSpaceND_, (const NDArray& input, const NDArray& crop, NDArray& output, const uint numOfSpatialDims), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void batchToSpaceND_, + (const NDArray& input, const NDArray& crop, + NDArray& output, const uint numOfSpatialDims), + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -void batchToSpaceND(sd::LaunchContext* context, const NDArray& input, const NDArray& blockShape, const NDArray& crop, NDArray& output) { +void batchToSpaceND(sd::LaunchContext* context, const NDArray& input, + const NDArray& blockShape, const NDArray& crop, + NDArray& output) { + // 4D example, numOfSpatialDims = 2 - two spatial dimensions + // [bS*blockShape[0]*blockShape[1], iH, iW, iC] is rearranged/permuted to [bS, + // iH*blockShape[0] - cropTop - cropBottom, iW*blockShape[1] - cropLeft - + // cropRight, iC] - // 4D example, numOfSpatialDims = 2 - two spatial dimensions - // [bS*blockShape[0]*blockShape[1], iH, iW, iC] is rearranged/permuted to [bS, iH*blockShape[0] - cropTop - cropBottom, iW*blockShape[1] - cropLeft - cropRight, iC] + const uint rank = input.rankOf(); + const uint numOfSpatialDims = blockShape.sizeAt(0); - const uint rank = input.rankOf(); - const uint numOfSpatialDims = blockShape.sizeAt(0); + //*** construct reshaping std::vector for first reshape of input array ***// - //*** construct reshaping std::vector for first reshape of input array ***// + std::vector temp(numOfSpatialDims + rank); - std::vector temp(numOfSpatialDims + rank); + uint i; + for (i = 0; i < numOfSpatialDims; ++i) temp[i] = blockShape.e(i); + temp[i++] = output.sizeAt(0); + for (uint j = 1; j < rank; ++i, ++j) temp[i] = input.sizeAt(j); - uint i; - for(i = 0; i < numOfSpatialDims; ++i) - temp[i] = blockShape.e(i); - temp[i++] = output.sizeAt(0); - for(uint j = 1; j < rank; ++i, ++j) - temp[i] = input.sizeAt(j); + NDArray inputRearranged0 = input.reshape(input.ordering(), temp); - NDArray inputRearranged0 = input.reshape(input.ordering(), temp); + //*** construct permuting std::vector for permutation of input array ***// - //*** construct permuting std::vector for permutation of input array ***// + temp[0] = numOfSpatialDims; - temp[0] = numOfSpatialDims; + for (i = 1; i <= numOfSpatialDims; ++i) { + temp[2 * i - 1] = numOfSpatialDims + i; + temp[2 * i] = i - 1; + } + for (i = 2 * numOfSpatialDims + 1; i < static_cast(temp.size()); ++i) + temp[i] = i; - for(i = 1; i <= numOfSpatialDims; ++i) { - temp[2*i - 1] = numOfSpatialDims + i; - temp[2*i] = i - 1; - } - for(i = 2 * numOfSpatialDims + 1; i < static_cast(temp.size()); ++i) - temp[i] = i; + inputRearranged0.permutei(temp); - inputRearranged0.permutei(temp); + if (input.lengthOf() == output.lengthOf()) { + output.assign(inputRearranged0); + } else { + //*** construct reshaping std::vector for second reshape of input array + //***// + temp.resize(rank); - if(input.lengthOf() == output.lengthOf()) { - output.assign(inputRearranged0); - } - else { - //*** construct reshaping std::vector for second reshape of input array ***// - - temp.resize(rank); - - temp[0] = output.sizeAt(0); + temp[0] = output.sizeAt(0); - for(i = 1; i < rank; ++i) - temp[i] = (i <= numOfSpatialDims) ? input.sizeAt(i) * blockShape.e(i - 1) : input.sizeAt(i); + for (i = 1; i < rank; ++i) + temp[i] = (i <= numOfSpatialDims) + ? input.sizeAt(i) * blockShape.e(i - 1) + : input.sizeAt(i); - NDArray inputRearranged1 = inputRearranged0.reshape(input.ordering(), temp); + NDArray inputRearranged1 = inputRearranged0.reshape(input.ordering(), temp); - BUILD_SINGLE_SELECTOR(input.dataType(), batchToSpaceND_, (inputRearranged1, crop, output, numOfSpatialDims), LIBND4J_TYPES); - } + BUILD_SINGLE_SELECTOR(input.dataType(), batchToSpaceND_, + (inputRearranged1, crop, output, numOfSpatialDims), + LIBND4J_TYPES); + } } ////////////////////////////////////////////////////////////////////////// template -static void spaceToBatch_(const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight) { - - // input [bS, H * blockSize - padBottom - padTop, W * blockSize - padLeft - padRight, iC] - // output [bS, H * blockSize, W * blockSize, iC] - - // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same - // else: - // iH -> [padBottom, oH - padTop] - // iW -> [padLeft, oW - padRight] - // zLen > xLen - - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); - - const int rank = 4; - - const Nd4jLong* xShapeInfo = input.shapeInfo(); - const Nd4jLong* zShapeInfo = output.shapeInfo(); - - const uint bS = zShapeInfo[1]; - const uint oH = zShapeInfo[2]; - const uint oW = zShapeInfo[3]; - const uint iC = zShapeInfo[4]; - - // loop through output array - auto func = PRAGMA_THREADS_FOR_2D { - for (auto b = start_x; b < stop_x; b += inc_x) { - for (auto h = start_y; h < stop_y; h += inc_y) { - for (uint w = 0; w < oW; ++w) { - for (uint c = 0; c < iC; ++c) { - - const Nd4jLong zOffset = b * zShapeInfo[5] + h * zShapeInfo[6] + w * zShapeInfo[7] + c * zShapeInfo[8]; - - if (h >= padBottom && h < oH - padTop && w >= padLeft && w < oW - padRight) { - const Nd4jLong xOffset = b * xShapeInfo[5] + (h - padBottom) * xShapeInfo[6] + (w - padLeft) * xShapeInfo[7] + c * xShapeInfo[8]; - z[zOffset] = x[xOffset]; - } else - z[zOffset] = 0.f; - } - } - } +static void spaceToBatch_(const NDArray& input, NDArray& output, + const uint padBottom, const uint padTop, + const uint padLeft, const uint padRight) { + // input [bS, H * blockSize - padBottom - padTop, W * blockSize - padLeft - + // padRight, iC] output [bS, H * blockSize, W * blockSize, iC] + + // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same + // else: + // iH -> [padBottom, oH - padTop] + // iW -> [padLeft, oW - padRight] + // zLen > xLen + + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); + + const int rank = 4; + + const Nd4jLong* xShapeInfo = input.shapeInfo(); + const Nd4jLong* zShapeInfo = output.shapeInfo(); + + const uint bS = zShapeInfo[1]; + const uint oH = zShapeInfo[2]; + const uint oW = zShapeInfo[3]; + const uint iC = zShapeInfo[4]; + + // loop through output array + auto func = PRAGMA_THREADS_FOR_2D { + for (auto b = start_x; b < stop_x; b += inc_x) { + for (auto h = start_y; h < stop_y; h += inc_y) { + for (uint w = 0; w < oW; ++w) { + for (uint c = 0; c < iC; ++c) { + const Nd4jLong zOffset = b * zShapeInfo[5] + h * zShapeInfo[6] + + w * zShapeInfo[7] + c * zShapeInfo[8]; + + if (h >= padBottom && h < oH - padTop && w >= padLeft && + w < oW - padRight) { + const Nd4jLong xOffset = + b * xShapeInfo[5] + (h - padBottom) * xShapeInfo[6] + + (w - padLeft) * xShapeInfo[7] + c * xShapeInfo[8]; + z[zOffset] = x[xOffset]; + } else + z[zOffset] = 0.f; + } } - }; + } + } + }; - samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1); + samediff::Threads::parallel_for(func, 0, bS, 1, 0, oH, 1); } -BUILD_SINGLE_TEMPLATE(template void spaceToBatch_, (const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void spaceToBatch_, + (const NDArray& input, NDArray& output, + const uint padBottom, const uint padTop, + const uint padLeft, const uint padRight), + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -void spaceToBatch(sd::LaunchContext* context, const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight, const uint blockSize) { - - // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC] - - NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), output.sizeAt(3)}, false); - outputRearranged0.permutei({2, 3,0, 4,1, 5}); - - if(input.lengthOf() == output.lengthOf()) { - outputRearranged0.assign(input); - } - else { - NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, output.sizeAt(3)}, false); - BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatch_, (input, outputRearranged1, padBottom, padTop, padLeft, padRight), LIBND4J_TYPES); - - if(output.buffer() != outputRearranged1.buffer()) - outputRearranged0.assign(outputRearranged1); - } +void spaceToBatch(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const uint padBottom, const uint padTop, + const uint padLeft, const uint padRight, + const uint blockSize) { + // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + + // padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC] + + NDArray outputRearranged0 = + output.reshape(output.ordering(), + {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), + output.sizeAt(2), output.sizeAt(3)}, + false); + outputRearranged0.permutei({2, 3, 0, 4, 1, 5}); + + if (input.lengthOf() == output.lengthOf()) { + outputRearranged0.assign(input); + } else { + NDArray outputRearranged1 = outputRearranged0.reshape( + output.ordering(), + {input.sizeAt(0), output.sizeAt(1) * blockSize, + output.sizeAt(2) * blockSize, output.sizeAt(3)}, + false); + BUILD_SINGLE_SELECTOR( + input.dataType(), spaceToBatch_, + (input, outputRearranged1, padBottom, padTop, padLeft, padRight), + LIBND4J_TYPES); + + if (output.buffer() != outputRearranged1.buffer()) + outputRearranged0.assign(outputRearranged1); + } } - - - - - - - - - - - - - - - - - - ////////////////////////////////////////////////////////////////////////// template -static void spaceToBatchND_(const NDArray& input, const NDArray& padding, NDArray& output, const uint numOfSpatialDims) { - - // 4D example - // input [bS, H * blockShape[0] - padBottom - padTop, W * blockShape[1] - padLeft - padRight, iC] - // output [bS, H * blockShape[0], W * blockShape[1], iC] +static void spaceToBatchND_(const NDArray& input, const NDArray& padding, + NDArray& output, const uint numOfSpatialDims) { + // 4D example + // input [bS, H * blockShape[0] - padBottom - padTop, W * blockShape[1] - + // padLeft - padRight, iC] output [bS, H * blockShape[0], W * blockShape[1], + // iC] - // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same - // else: - // iH -> [padBottom, oH - padTop] - // iW -> [padLeft, oW - padRight] - // zLen > xLen + // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same + // else: + // iH -> [padBottom, oH - padTop] + // iW -> [padLeft, oW - padRight] + // zLen > xLen - const T* x = input.bufferAsT(); - T* z = output.bufferAsT(); + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); - const int rank = input.rankOf(); - const Nd4jLong zLen = output.lengthOf(); + const int rank = input.rankOf(); + const Nd4jLong zLen = output.lengthOf(); - // loop through output array - auto func = PRAGMA_THREADS_FOR { + // loop through output array + auto func = PRAGMA_THREADS_FOR { + int zCoords[MAX_RANK], xCoords[MAX_RANK]; - int zCoords[MAX_RANK], xCoords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); - for (auto i = start; i < stop; i++) { + const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); - shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords); + memcpy(xCoords, zCoords, rank * sizeof(int)); - const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords); + bool within = true; - memcpy(xCoords, zCoords, rank * sizeof(int)); + for (uint j = 1; j <= numOfSpatialDims; ++j) { + const auto padLeft = padding.e(j - 1, 0); + const auto padRight = padding.e(j - 1, 1); - bool within = true; + within &= + zCoords[j] >= padLeft && zCoords[j] < output.sizeAt(j) - padRight; - for (uint j = 1; j <= numOfSpatialDims; ++j) { + if (!within) break; - const auto padLeft = padding.e(j - 1, 0); - const auto padRight = padding.e(j - 1, 1); + xCoords[j] = zCoords[j] - padLeft; // get coordinates for x + } - within &= zCoords[j] >= padLeft && zCoords[j] < output.sizeAt(j) - padRight; - - if (!within) - break; - - xCoords[j] = zCoords[j] - padLeft; // get coordinates for x - } - - if (within) - z[zOffset] = x[shape::getOffset(input.shapeInfo(), xCoords)]; - else - z[zOffset] = 0.f; - } - }; + if (within) + z[zOffset] = x[shape::getOffset(input.shapeInfo(), xCoords)]; + else + z[zOffset] = 0.f; + } + }; - samediff::Threads::parallel_tad(func, 0, zLen); + samediff::Threads::parallel_tad(func, 0, zLen); } -BUILD_SINGLE_TEMPLATE(template void spaceToBatchND_, (const NDArray& input, const NDArray& padding, NDArray& output, const uint numOfSpatialDims), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void spaceToBatchND_, + (const NDArray& input, const NDArray& padding, + NDArray& output, const uint numOfSpatialDims), + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDArray& blockShape, const NDArray& padding, NDArray& output ) { - - // 4D example with two spatial dimensions - // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockShape[0]*blockShape[1], (iH + padBottom + padTop)/blockShape[0], (iW + padLeft + padRight)/blockShape[1], iC] +void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, + const NDArray& blockShape, const NDArray& padding, + NDArray& output) { + // 4D example with two spatial dimensions + // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockShape[0]*blockShape[1], + // (iH + padBottom + padTop)/blockShape[0], (iW + padLeft + + // padRight)/blockShape[1], iC] - const uint rank = input.rankOf(); + const uint rank = input.rankOf(); - const uint numOfSpatialDims = blockShape.sizeAt(0); + const uint numOfSpatialDims = blockShape.sizeAt(0); - //*** construct reshaping std::vector for first reshape of output array ***// - std::vector temp(numOfSpatialDims + rank); + //*** construct reshaping std::vector for first reshape of output array ***// + std::vector temp(numOfSpatialDims + rank); - int i; - for(i = 0; i < numOfSpatialDims; ++i) - temp[i] = blockShape.e(i); - temp[i++] = input.sizeAt(0); - for(int j = 1; j < rank; ++i, ++j) - temp[i] = output.sizeAt(j); + int i; + for (i = 0; i < numOfSpatialDims; ++i) temp[i] = blockShape.e(i); + temp[i++] = input.sizeAt(0); + for (int j = 1; j < rank; ++i, ++j) temp[i] = output.sizeAt(j); - NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false); + NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false); - //*** construct permuting std::vector for permutation of output array ***// + //*** construct permuting std::vector for permutation of output array ***// - temp[0] = numOfSpatialDims; - - for(i = 1; i <= numOfSpatialDims; ++i) { - temp[2*i - 1] = numOfSpatialDims + i; - temp[2*i] = i - 1; - } - for(i = 2 * numOfSpatialDims + 1; i < temp.size(); ++i) - temp[i] = i; + temp[0] = numOfSpatialDims; - outputRearranged0.permutei(temp); + for (i = 1; i <= numOfSpatialDims; ++i) { + temp[2 * i - 1] = numOfSpatialDims + i; + temp[2 * i] = i - 1; + } + for (i = 2 * numOfSpatialDims + 1; i < temp.size(); ++i) temp[i] = i; - // ****** // + outputRearranged0.permutei(temp); - if(input.lengthOf() == output.lengthOf()) { - outputRearranged0.assign(input); - } - else { + // ****** // - //*** construct reshaping std::vector for second reshape of output array ***// - temp.resize(rank); + if (input.lengthOf() == output.lengthOf()) { + outputRearranged0.assign(input); + } else { + //*** construct reshaping std::vector for second reshape of output array + //***// + temp.resize(rank); - temp[0] = input.sizeAt(0); + temp[0] = input.sizeAt(0); - for(i = 1; i < rank; ++i) - temp[i] = (i <= numOfSpatialDims) ? output.sizeAt(i) * blockShape.e(i - 1) : output.sizeAt(i); + for (i = 1; i < rank; ++i) + temp[i] = (i <= numOfSpatialDims) + ? output.sizeAt(i) * blockShape.e(i - 1) + : output.sizeAt(i); - NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp, false); + NDArray outputRearranged1 = + outputRearranged0.reshape(output.ordering(), temp, false); - BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatchND_, (input, padding, outputRearranged1, numOfSpatialDims), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatchND_, + (input, padding, outputRearranged1, numOfSpatialDims), + LIBND4J_TYPES); - if(output.buffer() != outputRearranged1.buffer()) - outputRearranged0.assign(outputRearranged1); - } + if (output.buffer() != outputRearranged1.buffer()) + outputRearranged0.assign(outputRearranged1); + } } - /* template struct SpaceToBatchHelper { template - static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { - for (int batch_pos = 0; batch_pos < batch_shape[0]; ++batch_pos) { - const int space_pos = batch_pos * block_shape[0] + block_offsets[0] - pad_start[0]; - if (space_pos >= 0 && space_pos < space_shape[0]) { - SpaceToBatchHelper::run(ptrSpace + space_pos * space_strides[0], space_shape + 1, space_strides + 1, block_shape + 1, pad_start + 1, block_offsets + 1, ptrBatch, batch_shape + 1, batch_strides + 1); - } else { + static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong +*space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const +Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const +Nd4jLong *batch_strides) { for (int batch_pos = 0; batch_pos < batch_shape[0]; +++batch_pos) { const int space_pos = batch_pos * block_shape[0] + +block_offsets[0] - pad_start[0]; if (space_pos >= 0 && space_pos < +space_shape[0]) { SpaceToBatchHelper::run(ptrSpace + space_pos * +space_strides[0], space_shape + 1, space_strides + 1, block_shape + 1, pad_start ++ 1, block_offsets + 1, ptrBatch, batch_shape + 1, batch_strides + 1); } else { if (!B2S) for (int i = 0; i < batch_strides[0]; i++) ptrBatch[i] = (T) 0.f; @@ -425,26 +459,30 @@ void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDAr template struct SpaceToBatchHelper<0, B2S> { template - static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { - int str = batch_strides[-1]; - for (int i = 0; i < str; i++) - if (B2S) - ptrSpace[i] = ptrBatch[i]; - else - ptrBatch[i] = ptrSpace[i]; + static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong +*space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const +Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const +Nd4jLong *batch_strides) { int str = batch_strides[-1]; for (int i = 0; i < str; +i++) if (B2S) ptrSpace[i] = ptrBatch[i]; else ptrBatch[i] = ptrSpace[i]; } }; template - void _execute(sd::LaunchContext * context, void *vptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *vptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { - auto ptrSpace = reinterpret_cast(vptrSpace); - auto ptrBatch = reinterpret_cast(vptrBatch); - SpaceToBatchHelper::run(ptrSpace, space_shape, space_strides, block_shape, pad_start, block_offsets, ptrBatch, batch_shape, batch_strides); + void _execute(sd::LaunchContext * context, void *vptrSpace, const Nd4jLong +*space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const +Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *vptrBatch, const +Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { auto ptrSpace = +reinterpret_cast(vptrSpace); auto ptrBatch = reinterpret_cast(vptrBatch); SpaceToBatchHelper::run(ptrSpace, +space_shape, space_strides, block_shape, pad_start, block_offsets, ptrBatch, +batch_shape, batch_strides); }; - Nd4jStatus _spaceToBatch(sd::LaunchContext * context, int internal_block_dims, NDArray *input, NDArray *output, std::vector &internal_input_shape, std::vector &internal_output_shape, Nd4jLong *block_shape, Nd4jLong *paddings) { - auto in = input->reshape('c', internal_input_shape); - auto out = output->reshape('c', internal_output_shape); + Nd4jStatus _spaceToBatch(sd::LaunchContext * context, int +internal_block_dims, NDArray *input, NDArray *output, std::vector +&internal_input_shape, std::vector &internal_output_shape, Nd4jLong +*block_shape, Nd4jLong *paddings) { auto in = input->reshape('c', +internal_input_shape); auto out = output->reshape('c', internal_output_shape); switch (internal_block_dims) { case 1: _prepare<1, false>(context, &in, &out, block_shape, paddings); @@ -459,16 +497,19 @@ void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDAr _prepare<4, false>(context, &in, &out, block_shape, paddings); break; default: { - return Status::THROW("SpaceToBatch: Wrong number of internal_block_dims"); + return Status::THROW("SpaceToBatch: Wrong number of +internal_block_dims"); } } return Status::OK(); } - Nd4jStatus _batchToSpace(sd::LaunchContext * context, int internal_block_dims, NDArray *input, NDArray *output, std::vector &internal_input_shape, std::vector &internal_output_shape, Nd4jLong *block_shape, Nd4jLong *crops) { - auto in = input->reshape('c', internal_input_shape); - auto out = output->reshape('c', internal_output_shape); + Nd4jStatus _batchToSpace(sd::LaunchContext * context, int +internal_block_dims, NDArray *input, NDArray *output, std::vector +&internal_input_shape, std::vector &internal_output_shape, Nd4jLong +*block_shape, Nd4jLong *crops) { auto in = input->reshape('c', +internal_input_shape); auto out = output->reshape('c', internal_output_shape); switch (internal_block_dims) { case 1: _prepare<1, true>(context, &in, &out, block_shape, crops); @@ -483,7 +524,8 @@ void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDAr _prepare<4, true>(context, &in, &out, block_shape, crops); break; default: { - return Status::THROW("BatchToSpace: Wrong number of internal_block_dims"); + return Status::THROW("BatchToSpace: Wrong number of +internal_block_dims"); } } @@ -498,12 +540,16 @@ void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDAr #define STB_BOOL (0, false),\ (1, true) - BUILD_TRIPLE_TEMPLATE(template void _execute, (sd::LaunchContext * context, void *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides), LIBND4J_TYPES, STB_DIM, STB_BOOL); + BUILD_TRIPLE_TEMPLATE(template void _execute, (sd::LaunchContext * context, +void *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, +const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong +*block_offsets, void *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong +*batch_strides), LIBND4J_TYPES, STB_DIM, STB_BOOL); #undef STB_BOOL #undef STB_DIM */ -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp index b51a4adc9606..dd1e84e4d135 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp @@ -18,91 +18,103 @@ // // -#include #include +#include namespace sd { namespace ops { namespace helpers { - template - static void _spaceTodepth_(const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - auto input_ptr = reinterpret_cast(input.buffer()); - auto output_ptr = reinterpret_cast(output->buffer()); - - const int batch_size = input.sizeAt(0); - const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1); - const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2); - const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3); - - const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1); - const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2); - const int output_width = isNHWC ? output->sizeAt(2) : output->sizeAt(3); - - const int input_depth_by_output_height = input_depth * output_height; - - const int output_area = output_width * output_height; - const int output_depth_by_output_area = output_depth * output_area; - - - if (isNHWC) { - const int total_count = batch_size * input_height * input_width * input_depth; - - auto func = PRAGMA_THREADS_FOR { - for (auto inp_idx = start; inp_idx < stop; inp_idx++) { - // inp_idx = d + input_depth * (w + input_width * (h + input_height * b)) - const int d = inp_idx % input_depth; - const int inp_idx2 = inp_idx / input_depth; - const int w = inp_idx2 % input_width; - const int inp_idx3 = inp_idx2 / input_width; - const int h = inp_idx3 % input_height; - const int b = inp_idx3 / input_height; - - const int out_h = h / block_size; - const int offset_h = h % block_size; - const int out_w = w / block_size; - const int offset_w = w % block_size; - const int offset_d = (offset_h * block_size + offset_w) * input_depth; - const int out_d = d + offset_d; - - const int out_idx = out_d + output_depth * (out_w + output_width * (out_h + output_height * b)); - *(output_ptr + out_idx) = *(input_ptr + inp_idx); - } - }; - - samediff::Threads::parallel_for(func, 0, total_count); - } else { - const int total_count = batch_size * output_depth_by_output_area; - - auto func = PRAGMA_THREADS_FOR { - for (auto inp_idx = start; inp_idx < stop; inp_idx++) { - const int n_iC_oY_bY_oX = inp_idx / block_size; - const int bX = inp_idx - n_iC_oY_bY_oX * block_size; - - const int n_iC_oY_bY = n_iC_oY_bY_oX / output_width; - const int oX = n_iC_oY_bY_oX - n_iC_oY_bY * output_width; - - const int n_iC_oY = n_iC_oY_bY / block_size; - const int bY = n_iC_oY_bY - n_iC_oY * block_size; - - const int n = n_iC_oY / input_depth_by_output_height; - const int iC_oY = n_iC_oY - n * input_depth_by_output_height; - - const int output_idx = oX + (((n * block_size + bY) * block_size + bX) * input_depth_by_output_height + iC_oY) * output_width; - - *(output_ptr + output_idx) = *(input_ptr + inp_idx); - } - }; - - samediff::Threads::parallel_for(func, 0, total_count); - } - } +template +static void _spaceTodepth_(const NDArray &input, NDArray *output, + int block_size, bool isNHWC) { + auto input_ptr = reinterpret_cast(input.buffer()); + auto output_ptr = reinterpret_cast(output->buffer()); + + const int batch_size = input.sizeAt(0); + const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1); + const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2); + const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3); + + const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1); + const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2); + const int output_width = isNHWC ? output->sizeAt(2) : output->sizeAt(3); + + const int input_depth_by_output_height = input_depth * output_height; + + const int output_area = output_width * output_height; + const int output_depth_by_output_area = output_depth * output_area; + + if (isNHWC) { + const int total_count = + batch_size * input_height * input_width * input_depth; + + auto func = PRAGMA_THREADS_FOR { + for (auto inp_idx = start; inp_idx < stop; inp_idx++) { + // inp_idx = d + input_depth * (w + input_width * (h + input_height * + // b)) + const int d = inp_idx % input_depth; + const int inp_idx2 = inp_idx / input_depth; + const int w = inp_idx2 % input_width; + const int inp_idx3 = inp_idx2 / input_width; + const int h = inp_idx3 % input_height; + const int b = inp_idx3 / input_height; + + const int out_h = h / block_size; + const int offset_h = h % block_size; + const int out_w = w / block_size; + const int offset_w = w % block_size; + const int offset_d = (offset_h * block_size + offset_w) * input_depth; + const int out_d = d + offset_d; + + const int out_idx = + out_d + + output_depth * (out_w + output_width * (out_h + output_height * b)); + *(output_ptr + out_idx) = *(input_ptr + inp_idx); + } + }; + + samediff::Threads::parallel_for(func, 0, total_count); + } else { + const int total_count = batch_size * output_depth_by_output_area; + + auto func = PRAGMA_THREADS_FOR { + for (auto inp_idx = start; inp_idx < stop; inp_idx++) { + const int n_iC_oY_bY_oX = inp_idx / block_size; + const int bX = inp_idx - n_iC_oY_bY_oX * block_size; + + const int n_iC_oY_bY = n_iC_oY_bY_oX / output_width; + const int oX = n_iC_oY_bY_oX - n_iC_oY_bY * output_width; + + const int n_iC_oY = n_iC_oY_bY / block_size; + const int bY = n_iC_oY_bY - n_iC_oY * block_size; + + const int n = n_iC_oY / input_depth_by_output_height; + const int iC_oY = n_iC_oY - n * input_depth_by_output_height; + + const int output_idx = oX + (((n * block_size + bY) * block_size + bX) * + input_depth_by_output_height + + iC_oY) * + output_width; + + *(output_ptr + output_idx) = *(input_ptr + inp_idx); + } + }; + + samediff::Threads::parallel_for(func, 0, total_count); + } +} - void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, (input, output, block_size, isNHWC), LIBND4J_TYPES); - } +void _spaceTodepth(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC) { + BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, + (input, output, block_size, isNHWC), LIBND4J_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (const NDArray &input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, + (const NDArray &input, NDArray *output, int block_size, + bool isNHWC), + LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp index e19eb5deaa0e..cdf0e3bd3b32 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp @@ -18,178 +18,190 @@ // @author raver119@gmail.com // +#include +#include #include + #include -#include -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// // x - indices, z - input/output -template -Nd4jLong checkIndices_(const NDArray& indices, const NDArray& output, const int axis) { +template +Nd4jLong checkIndices_(const NDArray& indices, const NDArray& output, + const int axis) { + std::atomic numOfBadIndx{0}; - std::atomic numOfBadIndx{0}; + const auto x = indices.bufferAsT(); - const auto x = indices.bufferAsT(); + const auto xShapeInfo = indices.shapeInfo(); + const auto zShapeInfo = output.shapeInfo(); - const auto xShapeInfo = indices.shapeInfo(); - const auto zShapeInfo = output.shapeInfo(); + const auto xRank = indices.rankOf(); - const auto xRank = indices.rankOf(); + auto func = PRAGMA_THREADS_FOR { + int xCoords[MAX_RANK]; - auto func = PRAGMA_THREADS_FOR { - - int xCoords[MAX_RANK]; - - for (auto i = start; i < stop; i++) { + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, xShapeInfo, xCoords); - shape::index2coordsCPU(start, i, xShapeInfo, xCoords); + const Nd4jLong currentInd = x[shape::getOffset(xShapeInfo, xCoords)]; - const Nd4jLong currentInd = x[shape::getOffset(xShapeInfo, xCoords)]; - - if(currentInd >= shape::sizeAt(zShapeInfo, axis == -1 ? xCoords[xRank-1] : axis)) { - printf("checkIndices: out of range element %lld at index %ld \n", currentInd, i); - ++numOfBadIndx; - } - } - }; + if (currentInd >= + shape::sizeAt(zShapeInfo, axis == -1 ? xCoords[xRank - 1] : axis)) { + printf("checkIndices: out of range element %lld at index %ld \n", + currentInd, i); + ++numOfBadIndx; + } + } + }; - samediff::Threads::parallel_for(func, 0, indices.lengthOf()); + samediff::Threads::parallel_for(func, 0, indices.lengthOf()); - return numOfBadIndx; + return numOfBadIndx; } /////////////////////////////////////////////////////////////////// -Nd4jLong checkIndices(sd::LaunchContext *context, const NDArray& indices, const NDArray& output, const int axis) { - - BUILD_SINGLE_SELECTOR(indices.dataType(), return checkIndices_, (indices, output, axis), INDEXING_TYPES); +Nd4jLong checkIndices(sd::LaunchContext* context, const NDArray& indices, + const NDArray& output, const int axis) { + BUILD_SINGLE_SELECTOR(indices.dataType(), return checkIndices_, + (indices, output, axis), INDEXING_TYPES); } /////////////////////////////////////////////////////////////////// -void scatter(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { - - const int outRank = output.rankOf(); - const int indRank = indices.rankOf(); - const int updRank = updates.rankOf(); - const Nd4jLong indLen = indices.lengthOf(); - - if(outRank == 1) { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - Nd4jLong idx = indices.e(i); - NDArray out = output({idx, idx + 1}); +void scatter(sd::LaunchContext* context, pairwise::Ops op, + const NDArray& indices, const NDArray& updates, NDArray& output, + const bool lock) { + const int outRank = output.rankOf(); + const int indRank = indices.rankOf(); + const int updRank = updates.rankOf(); + const Nd4jLong indLen = indices.lengthOf(); + + if (outRank == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + Nd4jLong idx = indices.e(i); + NDArray out = output({idx, idx + 1}); - out.applyPairwiseTransform(op, updates.e(i)); - } - }; + out.applyPairwiseTransform(op, updates.e(i)); + } + }; - samediff::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads()); - } - else { // outRank > 1 + samediff::Threads::parallel_tad( + func, 0, indLen, 1, + lock ? 1 : sd::Environment::getInstance()->maxThreads()); + } else { // outRank > 1 - int sizeOfDims = indRank; - if(outRank == updRank && indices.isVector()) - sizeOfDims = 1; + int sizeOfDims = indRank; + if (outRank == updRank && indices.isVector()) sizeOfDims = 1; - std::vector dimsToExcludeUpd(sizeOfDims); - std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); + std::vector dimsToExcludeUpd(sizeOfDims); + std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - NDArray outSubArr = output(indices.e(i), std::vector({0})); - NDArray updSubArr = updates(i, dimsToExcludeUpd); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + NDArray outSubArr = + output(indices.e(i), std::vector({0})); + NDArray updSubArr = updates(i, dimsToExcludeUpd); - outSubArr.applyPairwiseTransform(op, updSubArr); - } - }; + outSubArr.applyPairwiseTransform(op, updSubArr); + } + }; - samediff::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads()); - } + samediff::Threads::parallel_tad( + func, 0, indLen, 1, + lock ? 1 : sd::Environment::getInstance()->maxThreads()); + } } /////////////////////////////////////////////////////////////////// -void scatterND(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { - - const Nd4jLong indLen = indices.lengthOf(); - const int outRank = output.rankOf(); - const int indRank = indices.rankOf(); - const Nd4jLong indLastDim = indices.sizeAt(-1); - - if(outRank == 1) { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - Nd4jLong idx = indices.e(i); - NDArray out = output({idx, idx + 1}); +void scatterND(sd::LaunchContext* context, pairwise::Ops op, + const NDArray& indices, const NDArray& updates, NDArray& output, + const bool lock) { + const Nd4jLong indLen = indices.lengthOf(); + const int outRank = output.rankOf(); + const int indRank = indices.rankOf(); + const Nd4jLong indLastDim = indices.sizeAt(-1); + + if (outRank == 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + Nd4jLong idx = indices.e(i); + NDArray out = output({idx, idx + 1}); - out.applyPairwiseTransform(op, updates.e(i), nullptr); - } - }; + out.applyPairwiseTransform(op, updates.e(i), nullptr); + } + }; - samediff::Threads::parallel_tad(func, 0, indLen, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads()); - } - else { - std::vector dimsToExcludeInd = ShapeUtils::evalDimsToExclude(indRank, {indRank-1}); - std::vector dimsToExcludeUpd(indRank - 1); - std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); + samediff::Threads::parallel_tad( + func, 0, indLen, 1, + lock ? 1 : sd::Environment::getInstance()->maxThreads()); + } else { + std::vector dimsToExcludeInd = + ShapeUtils::evalDimsToExclude(indRank, {indRank - 1}); + std::vector dimsToExcludeUpd(indRank - 1); + std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); - auto func = PRAGMA_THREADS_FOR { - std::vector idxRangeOut(2*outRank, 0); + auto func = PRAGMA_THREADS_FOR { + std::vector idxRangeOut(2 * outRank, 0); - for (auto i = start; i < stop; i++) { - NDArray indSubArr = indices(i, dimsToExcludeInd); + for (auto i = start; i < stop; i++) { + NDArray indSubArr = indices(i, dimsToExcludeInd); - for (Nd4jLong j = 0; j < indLastDim; ++j) { - idxRangeOut[2 * j] = indSubArr.e(j); - idxRangeOut[2 * j + 1] = idxRangeOut[2 * j] + 1; - } + for (Nd4jLong j = 0; j < indLastDim; ++j) { + idxRangeOut[2 * j] = indSubArr.e(j); + idxRangeOut[2 * j + 1] = idxRangeOut[2 * j] + 1; + } - NDArray outSubArr = output(idxRangeOut); - NDArray updSubArr = updates(i, dimsToExcludeUpd); + NDArray outSubArr = output(idxRangeOut); + NDArray updSubArr = updates(i, dimsToExcludeUpd); - outSubArr.applyPairwiseTransform(op, updSubArr); - } - }; + outSubArr.applyPairwiseTransform(op, updSubArr); + } + }; - samediff::Threads::parallel_tad(func, 0, indLen / indLastDim, 1, lock ? 1 : sd::Environment::getInstance()->maxThreads()); - } + samediff::Threads::parallel_tad( + func, 0, indLen / indLastDim, 1, + lock ? 1 : sd::Environment::getInstance()->maxThreads()); + } } -void scatterForLoss(sd::LaunchContext *context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad) { +void scatterForLoss(sd::LaunchContext* context, const NDArray& indices, + NDArray& updates, NDArray& output, const bool calcGrad) { + // shapes of indices and output must be the same + // shape of indices should be the same as updates shape with last dimension + // excluded for example if updates is {a,b,c} then indices should be {a,b} - // shapes of indices and output must be the same - // shape of indices should be the same as updates shape with last dimension excluded - // for example if updates is {a,b,c} then indices should be {a,b} + const Nd4jLong indicesLen = indices.lengthOf(); - const Nd4jLong indicesLen = indices.lengthOf(); + std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(updates.rankOf(), {-1}); - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(updates.rankOf(), {-1}); - - if(!calcGrad) { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto subArr = updates(i, dimsToExclude); - output.p(i, subArr.e(indices.e(i))); - } - }; + if (!calcGrad) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto subArr = updates(i, dimsToExclude); + output.p(i, subArr.e(indices.e(i))); + } + }; - samediff::Threads::parallel_for(func, 0, indicesLen); - } else { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto subArr = updates(i, dimsToExclude); - auto ind = indices.e(i); - subArr.p(ind, subArr.e(ind) - 1.); - } - }; + samediff::Threads::parallel_for(func, 0, indicesLen); + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto subArr = updates(i, dimsToExclude); + auto ind = indices.e(i); + subArr.p(ind, subArr.e(ind) - 1.); + } + }; - samediff::Threads::parallel_for(func, 0, indicesLen); - } + samediff::Threads::parallel_for(func, 0, indicesLen); + } } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/scatterUpdateAndSimple.cpp b/libnd4j/include/ops/declarable/helpers/cpu/scatterUpdateAndSimple.cpp index fe41c5d43335..042acde058ad 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/scatterUpdateAndSimple.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/scatterUpdateAndSimple.cpp @@ -18,98 +18,103 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // -#include -#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -void scatterUpdate(sd::LaunchContext * context, NDArray& input, NDArray& updates, const std::vector* intArgs) { - - int opCode = (*intArgs)[0]; - int dimSize = (*intArgs)[1]; - Nd4jLong e; - Nd4jLong limg = 2 + dimSize; - std::vector tadDimensions(dimSize); - for (e = 2; e < limg; e++) - tadDimensions[e-2] = (*intArgs)[e]; - - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), tadDimensions); - - // increasing counter to skip numIndices - e++; - std::vector indices; - for (; e < static_cast(intArgs->size()); e++) - indices.push_back((*intArgs)[e]); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto inSubArr = input(indices[i], dimsToExclude, true); - auto updSubArr = updates(i, dimsToExclude, true); - - if (inSubArr.lengthOf() != updSubArr.lengthOf()) - continue; - - switch (opCode) { - case 0: - inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr); - break; - case 1: - inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, inSubArr); - break; - case 2: - inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, inSubArr); - break; - case 3: - inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, inSubArr); - break; - case 4: - inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, inSubArr); - break; - case 5: - inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, inSubArr); - break; - case 6: - inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, inSubArr); - break; - default: - continue; - } - } - }; +void scatterUpdate(sd::LaunchContext* context, NDArray& input, NDArray& updates, + const std::vector* intArgs) { + int opCode = (*intArgs)[0]; + int dimSize = (*intArgs)[1]; + Nd4jLong e; + Nd4jLong limg = 2 + dimSize; + std::vector tadDimensions(dimSize); + for (e = 2; e < limg; e++) tadDimensions[e - 2] = (*intArgs)[e]; + + std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(input.rankOf(), tadDimensions); + + // increasing counter to skip numIndices + e++; + std::vector indices; + for (; e < static_cast(intArgs->size()); e++) + indices.push_back((*intArgs)[e]); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inSubArr = input(indices[i], dimsToExclude, true); + auto updSubArr = updates(i, dimsToExclude, true); + + if (inSubArr.lengthOf() != updSubArr.lengthOf()) continue; + + switch (opCode) { + case 0: + inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr); + break; + case 1: + inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, + inSubArr); + break; + case 2: + inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, + inSubArr); + break; + case 3: + inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, + inSubArr); + break; + case 4: + inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, + inSubArr); + break; + case 5: + inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, + inSubArr); + break; + case 6: + inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, + inSubArr); + break; + default: + continue; + } + } + }; - samediff::Threads::parallel_tad(func, 0, indices.size()); + samediff::Threads::parallel_tad(func, 0, indices.size()); } - ////////////////////////////////////////////////////////////////////////// -void scatterSimple(sd::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector& dimensions) { - - // updates and indices have same length - const Nd4jLong len = indices.lengthOf(); - - switch (opId) { - - case 6: { // copy - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto inSubArr = input(i, dimensions); - inSubArr.p(indices.t(i), updates.e(i)); - } - }; - - samediff::Threads::parallel_for(func, 0, len); +void scatterSimple(sd::LaunchContext* context, const int opId, NDArray& input, + const NDArray& updates, const NDArray& indices, + const std::vector& dimensions) { + // updates and indices have same length + const Nd4jLong len = indices.lengthOf(); + + switch (opId) { + case 6: { // copy + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inSubArr = input(i, dimensions); + inSubArr.p(indices.t(i), updates.e(i)); } - break; + }; - default: - throw std::invalid_argument("helpers::scatterSimple: operation is not implemented for given id !"); - } -} + samediff::Threads::parallel_for(func, 0, len); + } break; + default: + throw std::invalid_argument( + "helpers::scatterSimple: operation is not implemented for given id " + "!"); + } } -} -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp index 241e3e131a2a..09166d7d7ff9 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp @@ -19,1073 +19,1151 @@ // @author GS // -#include -#include #include +#include +#include + #include namespace sd { namespace ops { namespace helpers { - // segment max - template - static void segmentMaxFunctor_(NDArray* input, NDArray* indices, NDArray* output) { - //int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - Nd4jLong idx = indices->e(0); - if (input->isVector()) { - T val = input->e(0); - - for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { - if (idx == indices->e(e)) { - // max - val = sd::math::nd4j_max(val, input->t(e)); - } - else { - idx = indices->e(e); - val = input->t(e); - } - output->t(idx) = val; - } - } - else { - std::vector restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto listOfTensors = input->allTensorsAlongDimension(restDims); - auto listOfOutTensors = output->allTensorsAlongDimension(restDims); - - auto numOfClasses = output->sizeAt(0); // number of classes - std::vector> outputs(numOfClasses); - auto maxT = listOfOutTensors.at(idx); - - //int pos = 0; - maxT.assign(listOfTensors.at(0)); - - for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { - if (indices->e(i) == idx) { - - for (Nd4jLong e = 0; e < maxT.lengthOf(); e++) { - maxT.t(e) = sd::math::nd4j_max(maxT.t(e), listOfTensors.at(i).t(e)); - } - } - else { - idx = indices->e(i); - maxT = listOfOutTensors.at(idx); - maxT.assign(listOfTensors.at(i)); - } - - } - } +// segment max +template +static void segmentMaxFunctor_(NDArray* input, NDArray* indices, + NDArray* output) { + // int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + Nd4jLong idx = indices->e(0); + if (input->isVector()) { + T val = input->e(0); + + for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { + if (idx == indices->e(e)) { + // max + val = sd::math::nd4j_max(val, input->t(e)); + } else { + idx = indices->e(e); + val = input->t(e); + } + output->t(idx) = val; } - - // segmen min - template - static void segmentMinFunctor_(NDArray* input, NDArray* indices, NDArray* output) { - //int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - Nd4jLong idx = indices->e(0); - if (input->isVector()) { - T val = input->e(0); - - for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { - if (idx == indices->e(e)) { - // min - val = sd::math::nd4j_min(val, input->t(e)); - } - else { - idx = indices->e(e); - val = input->t(e); - } - output->t(idx) = val; - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - int numOfClasses = output->sizeAt(0); // number of classes - std::vector> outputs(numOfClasses); - auto minT = listOfOutTensors.at(idx); - - int pos = 0; - minT.assign(listOfTensors.at(0)); - - for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { - if (indices->e(i) == idx) { - - for (Nd4jLong e = 0; e < minT.lengthOf(); e++) { - minT.p(e, sd::math::nd4j_min(minT.e(e), listOfTensors.at(i).e(e))); - } - } - else { - idx = indices->e(i); - minT = listOfOutTensors.at(idx); - minT.assign(listOfTensors.at(i)); - } - } + } else { + std::vector restDims = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto listOfTensors = input->allTensorsAlongDimension(restDims); + auto listOfOutTensors = output->allTensorsAlongDimension(restDims); + + auto numOfClasses = output->sizeAt(0); // number of classes + std::vector> outputs(numOfClasses); + auto maxT = listOfOutTensors.at(idx); + + // int pos = 0; + maxT.assign(listOfTensors.at(0)); + + for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { + if (indices->e(i) == idx) { + for (Nd4jLong e = 0; e < maxT.lengthOf(); e++) { + maxT.t(e) = + sd::math::nd4j_max(maxT.t(e), listOfTensors.at(i).t(e)); } + } else { + idx = indices->e(i); + maxT = listOfOutTensors.at(idx); + maxT.assign(listOfTensors.at(i)); + } } + } +} - // segmen mean - template - static void segmentMeanFunctor_(NDArray* input, NDArray* indices, NDArray* output) { - int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - int idx = indices->e(0); - if (input->isVector()) { - T val = T(0.f); - int count = 0; - - for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { - if (idx == indices->e(e)) { - // mean - val += input->e(e); - count++; - } - else { - output->p(idx, val / count); - idx = indices->e(e); - val = input->e(e); - count = 1; - } - output->p(idx, val / count); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - auto listOfTensors = input->allTensorsAlongDimension(restDims); - auto listOfOutTensors = output->allTensorsAlongDimension(restDims); - - int numOfClasses = output->sizeAt(0); // number of classes - std::vector> outputs(numOfClasses); - auto meanT = listOfOutTensors.at(idx); - int count = 1; - auto meanV = meanT.dup(); - meanV.assign(listOfTensors.at(0)); - - for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { - if (indices->e(i) == idx) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - meanV.p(e, meanV.e(e) + listOfTensors.at(i).e(e)); - } - }; - samediff::Threads::parallel_for(func, 0, meanT.lengthOf()); - - count++; - } - else { - //meanT->assign(meanV); - meanV.applyScalar(scalar::Divide, count, meanT); - idx = indices->e(i); - meanT = listOfOutTensors.at(idx); - meanV.assign(listOfTensors.at(i)); - count = 1; - } - meanV.applyScalar(scalar::Divide, count, meanT); - } - } +// segmen min +template +static void segmentMinFunctor_(NDArray* input, NDArray* indices, + NDArray* output) { + // int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + Nd4jLong idx = indices->e(0); + if (input->isVector()) { + T val = input->e(0); + + for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { + if (idx == indices->e(e)) { + // min + val = sd::math::nd4j_min(val, input->t(e)); + } else { + idx = indices->e(e); + val = input->t(e); + } + output->t(idx) = val; } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - template - static void segmentSumFunctor_(NDArray* input, NDArray* indices, NDArray* output) { - int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - int idx = indices->e(0); - if (input->isVector()) { - T val = T(0.f); - int count = 0; - for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { - if (idx == indices->e(e)) { - // sum - val += input->t(e); - } - else { - idx = indices->e(e); - val = input->t(e); - } - output->p(idx, val); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - auto listOfTensors = input->allTensorsAlongDimension(restDims); - auto listOfOutTensors = output->allTensorsAlongDimension(restDims); - - int numOfClasses = output->sizeAt(0); // number of classes - std::vector> outputs(numOfClasses); - auto sumT = listOfOutTensors.at(idx); - - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - if (indices->e(i) == idx) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - sumT.p(e, sumT.e(e) + listOfTensors.at(i).e(e)); - } - }; - samediff::Threads::parallel_for(func, 0, sumT.lengthOf()); - } - else { - idx = indices->e(i); - sumT = listOfOutTensors.at(idx); - sumT.assign(listOfTensors.at(i)); - } - } - } - } + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - template - static void segmentProdFunctor_(NDArray* input, NDArray* indices, NDArray* output) { - //int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - int idx = indices->e(0); - output->assign(1.f); - if (input->isVector()) { - T val = input->e(0); - int count = 0; - - for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { - if (idx == indices->e(e)) { - // sum - val *= input->e(e); - } - else { - idx = indices->e(e); - val = input->e(e); - } - output->p(idx, val); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - auto listOfTensors = input->allTensorsAlongDimension(restDims); - auto listOfOutTensors = output->allTensorsAlongDimension(restDims); - - int numOfClasses = output->sizeAt(0); // number of classes - auto sumT = listOfOutTensors.at(idx); - sumT.assign(listOfTensors.at(0)); - for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { - if (indices->e(i) == idx) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - sumT.p(e, sumT.e(e) * listOfTensors.at(i).e(e)); - } - }; - samediff::Threads::parallel_for(func, 0, sumT.lengthOf()); - } - else { - idx = indices->e(i); - sumT = listOfOutTensors.at(idx); - sumT.assign(listOfTensors.at(i)); - } - } - } - } + int numOfClasses = output->sizeAt(0); // number of classes + std::vector> outputs(numOfClasses); + auto minT = listOfOutTensors.at(idx); -// template -// static bool segmentIndicesValidate_(NDArray* indices, NDArray& aexpected, NDArray& anOutput) { -// } + int pos = 0; + minT.assign(listOfTensors.at(0)); - void segmentMaxFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), segmentMaxFunctor_, (input, indices, output), LIBND4J_TYPES); + for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { + if (indices->e(i) == idx) { + for (Nd4jLong e = 0; e < minT.lengthOf(); e++) { + minT.p(e, + sd::math::nd4j_min(minT.e(e), listOfTensors.at(i).e(e))); + } + } else { + idx = indices->e(i); + minT = listOfOutTensors.at(idx); + minT.assign(listOfTensors.at(i)); + } } + } +} - void segmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), segmentMinFunctor_, (input, indices, output), LIBND4J_TYPES); +// segmen mean +template +static void segmentMeanFunctor_(NDArray* input, NDArray* indices, + NDArray* output) { + int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + int idx = indices->e(0); + if (input->isVector()) { + T val = T(0.f); + int count = 0; + + for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { + if (idx == indices->e(e)) { + // mean + val += input->e(e); + count++; + } else { + output->p(idx, val / count); + idx = indices->e(e); + val = input->e(e); + count = 1; + } + output->p(idx, val / count); } - - void segmentMeanFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), segmentMeanFunctor_, (input, indices, output), LIBND4J_TYPES); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + auto listOfTensors = input->allTensorsAlongDimension(restDims); + auto listOfOutTensors = output->allTensorsAlongDimension(restDims); + + int numOfClasses = output->sizeAt(0); // number of classes + std::vector> outputs(numOfClasses); + auto meanT = listOfOutTensors.at(idx); + int count = 1; + auto meanV = meanT.dup(); + meanV.assign(listOfTensors.at(0)); + + for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { + if (indices->e(i) == idx) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + meanV.p(e, meanV.e(e) + listOfTensors.at(i).e(e)); + } + }; + samediff::Threads::parallel_for(func, 0, meanT.lengthOf()); + + count++; + } else { + // meanT->assign(meanV); + meanV.applyScalar(scalar::Divide, count, meanT); + idx = indices->e(i); + meanT = listOfOutTensors.at(idx); + meanV.assign(listOfTensors.at(i)); + count = 1; + } + meanV.applyScalar(scalar::Divide, count, meanT); } + } +} - void segmentSumFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), segmentSumFunctor_, (input, indices, output), LIBND4J_TYPES); +template +static void segmentSumFunctor_(NDArray* input, NDArray* indices, + NDArray* output) { + int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + int idx = indices->e(0); + if (input->isVector()) { + T val = T(0.f); + int count = 0; + for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { + if (idx == indices->e(e)) { + // sum + val += input->t(e); + } else { + idx = indices->e(e); + val = input->t(e); + } + output->p(idx, val); } - - void segmentProdFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), segmentProdFunctor_, (input, indices, output), LIBND4J_TYPES); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + auto listOfTensors = input->allTensorsAlongDimension(restDims); + auto listOfOutTensors = output->allTensorsAlongDimension(restDims); + + int numOfClasses = output->sizeAt(0); // number of classes + std::vector> outputs(numOfClasses); + auto sumT = listOfOutTensors.at(idx); + + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + if (indices->e(i) == idx) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + sumT.p(e, sumT.e(e) + listOfTensors.at(i).e(e)); + } + }; + samediff::Threads::parallel_for(func, 0, sumT.lengthOf()); + } else { + idx = indices->e(i); + sumT = listOfOutTensors.at(idx); + sumT.assign(listOfTensors.at(i)); + } } + } +} - bool segmentIndicesValidate(sd::LaunchContext * context, NDArray* indices, NDArray& expected, NDArray& output) { - auto val = indices->e(0); - for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { - output = indices->e(e); - if (val.e(0) > output.e(0)) - return false; - val = indices->e(e); - } - - return true; +template +static void segmentProdFunctor_(NDArray* input, NDArray* indices, + NDArray* output) { + // int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + int idx = indices->e(0); + output->assign(1.f); + if (input->isVector()) { + T val = input->e(0); + int count = 0; + + for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { + if (idx == indices->e(e)) { + // sum + val *= input->e(e); + } else { + idx = indices->e(e); + val = input->e(e); + } + output->p(idx, val); } - - //BUILD_SINGLE_TEMPLATE(template bool segmentIndicesValidate_, (NDArray*, NDArray&, NDArray&), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void segmentProdFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void segmentSumFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void segmentMeanFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void segmentMinFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void segmentMaxFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); - // -------------------------------------------------------------------------------------------------------------- // - // Unsorted segment ops - // -------------------------------------------------------------------------------------------------------------- // - - bool unsortedSegmentIndicesValidate(sd::LaunchContext * context, NDArray* indices, Nd4jLong expected, Nd4jLong& output) { - Nd4jLong val = indices->e(0); - - Nd4jLong maxInd = indices->argMax(); - if (indices->e(maxInd) >= expected) { - output = val; - return false; - } - output = expected; - return true; + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + auto listOfTensors = input->allTensorsAlongDimension(restDims); + auto listOfOutTensors = output->allTensorsAlongDimension(restDims); + + int numOfClasses = output->sizeAt(0); // number of classes + auto sumT = listOfOutTensors.at(idx); + sumT.assign(listOfTensors.at(0)); + for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { + if (indices->e(i) == idx) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + sumT.p(e, sumT.e(e) * listOfTensors.at(i).e(e)); + } + }; + samediff::Threads::parallel_for(func, 0, sumT.lengthOf()); + } else { + idx = indices->e(i); + sumT = listOfOutTensors.at(idx); + sumT.assign(listOfTensors.at(i)); + } } + } +} - template - static void unsortedSegmentMaxFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - - // if input is a vector: (as if in doc sample) - //int idx = static_cast((*indices)(0.)); - MAP_IMPL> idxs;//(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); - - //std::sort(idxs.begin(), idxs.end()); +// template +// static bool segmentIndicesValidate_(NDArray* indices, NDArray& aexpected, +// NDArray& anOutput) { +// } - if (input->isVector()) { // 1D case - T maxVal = DataTypeUtils::max(); - output->assign(-maxVal); +void segmentMaxFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), segmentMaxFunctor_, + (input, indices, output), LIBND4J_TYPES); +} - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - T val = input->e(fi->second.at(0)); - for (Nd4jLong idx = 1; idx < static_cast(fi->second.size()); ++idx) { - val = sd::math::nd4j_max(val, input->e(fi->second.at(idx))); - } - output->p(fi->first, val); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - T maxVal = DataTypeUtils::max(); - output->assign(-maxVal); - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT.assign(listOfTensors.at(fi->second.at(0))); - for (Nd4jLong idx = 1; idx < static_cast(fi->second.size()); ++idx) { - auto maxT = listOfTensors.at(fi->second.at(idx)); - for (Nd4jLong e = 0; e < outputT.lengthOf(); ++e) { - T val = sd::math::nd4j_max(maxT.e(e), outputT.e(e)); - - outputT.p(e, val); - } - } - } - } - } - void unsortedSegmentMaxFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentMaxFunctor_, (input, indices, numOfClasses, output), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMaxFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); +void segmentMinFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), segmentMinFunctor_, + (input, indices, output), LIBND4J_TYPES); +} - template - static void unsortedSegmentMinFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - // if input is a vector: (as if in doc sample) - //int idx = static_cast((*indices)(0.)); - MAP_IMPL> idxs;//(indices->lengthOf()); +void segmentMeanFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), segmentMeanFunctor_, + (input, indices, output), LIBND4J_TYPES); +} - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); +void segmentSumFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), segmentSumFunctor_, + (input, indices, output), LIBND4J_TYPES); +} - //std::sort(idxs.begin(), idxs.end()); +void segmentProdFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), segmentProdFunctor_, + (input, indices, output), LIBND4J_TYPES); +} - if (input->isVector()) { // 1D case - T maxVal = DataTypeUtils::max(); - output->assign(maxVal); +bool segmentIndicesValidate(sd::LaunchContext* context, NDArray* indices, + NDArray& expected, NDArray& output) { + auto val = indices->e(0); + for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { + output = indices->e(e); + if (val.e(0) > output.e(0)) return false; + val = indices->e(e); + } - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - T val = input->t(fi->second.at(0)); + return true; +} - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - val = sd::math::nd4j_min(val, input->t(fi->second.at(idx))); - } - output->t(fi->first) = val; - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - T maxVal = DataTypeUtils::max(); - output->assign(maxVal); - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT.assign(listOfTensors.at(fi->second.at(0))); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - auto minT = listOfTensors.at(fi->second.at(idx)); - - for (Nd4jLong e = 0; e < outputT.lengthOf(); ++e) { - outputT.t(e) = sd::math::nd4j_min(minT.t(e), outputT.t(e)); - } - } - //outputT->assign(maxT); - } - } +// BUILD_SINGLE_TEMPLATE(template bool segmentIndicesValidate_, (NDArray*, +// NDArray&, NDArray&), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void segmentProdFunctor_, + (NDArray * input, NDArray* indices, NDArray* output), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void segmentSumFunctor_, + (NDArray * input, NDArray* indices, NDArray* output), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void segmentMeanFunctor_, + (NDArray * input, NDArray* indices, NDArray* output), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void segmentMinFunctor_, + (NDArray * input, NDArray* indices, NDArray* output), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void segmentMaxFunctor_, + (NDArray * input, NDArray* indices, NDArray* output), + LIBND4J_TYPES); +// -------------------------------------------------------------------------------------------------------------- +// // Unsorted segment ops +// -------------------------------------------------------------------------------------------------------------- +// // + +bool unsortedSegmentIndicesValidate(sd::LaunchContext* context, + NDArray* indices, Nd4jLong expected, + Nd4jLong& output) { + Nd4jLong val = indices->e(0); + + Nd4jLong maxInd = indices->argMax(); + if (indices->e(maxInd) >= expected) { + output = val; + return false; + } + output = expected; + return true; +} +template +static void unsortedSegmentMaxFunctor_(NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, NDArray* output) { + // if input is a vector: (as if in doc sample) + // int idx = static_cast((*indices)(0.)); + MAP_IMPL> idxs; //(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); + + // std::sort(idxs.begin(), idxs.end()); + + if (input->isVector()) { // 1D case + T maxVal = DataTypeUtils::max(); + output->assign(-maxVal); + + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + T val = input->e(fi->second.at(0)); + for (Nd4jLong idx = 1; idx < static_cast(fi->second.size()); + ++idx) { + val = sd::math::nd4j_max(val, input->e(fi->second.at(idx))); + } + output->p(fi->first, val); } - void unsortedSegmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentMinFunctor_, (input, indices, numOfClasses, output), - NUMERIC_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMinFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); - - void unsortedSegmentMeanFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - MAP_IMPL> idxs;//(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - //std::sort(idxs.begin(), idxs.end()); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - if (input->isVector()) { // 1D case + T maxVal = DataTypeUtils::max(); + output->assign(-maxVal); - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - double sumValue = input->e(fi->second.at(0)); - int loop_size = fi->second.size(); + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT.assign(listOfTensors.at(fi->second.at(0))); + for (Nd4jLong idx = 1; idx < static_cast(fi->second.size()); + ++idx) { + auto maxT = listOfTensors.at(fi->second.at(idx)); + for (Nd4jLong e = 0; e < outputT.lengthOf(); ++e) { + T val = sd::math::nd4j_max(maxT.e(e), outputT.e(e)); - // FIXME: parallelism here? - for (size_t idx = 1; idx < loop_size; ++idx) { - sumValue += input->e(fi->second.at(idx)); - } - - output->p(fi->first, sumValue / fi->second.size()); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - // FIXME: parallelism here? - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT.assign(listOfTensors.at(fi->second.at(0))); - Nd4jLong loopSize = fi->second.size(); - - for (Nd4jLong idx = 1; idx < loopSize; ++idx) { - auto current = listOfTensors.at(fi->second.at(idx)); - outputT += current; - } - (outputT) /= double(fi->second.size()); - } + outputT.p(e, val); } + } + } + } +} +void unsortedSegmentMaxFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentMaxFunctor_, + (input, indices, numOfClasses, output), NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMaxFunctor_, + (NDArray * input, NDArray* indices, Nd4jLong numOfClasses, + NDArray* output), + NUMERIC_TYPES); + +template +static void unsortedSegmentMinFunctor_(NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, NDArray* output) { + // if input is a vector: (as if in doc sample) + // int idx = static_cast((*indices)(0.)); + MAP_IMPL> idxs; //(indices->lengthOf()); + + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); + + // std::sort(idxs.begin(), idxs.end()); + + if (input->isVector()) { // 1D case + T maxVal = DataTypeUtils::max(); + output->assign(maxVal); + + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + T val = input->t(fi->second.at(0)); + + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + val = sd::math::nd4j_min(val, input->t(fi->second.at(idx))); + } + output->t(fi->first) = val; } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - void unsortedSegmentSumFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - MAP_IMPL> idxs;//(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - if (input->isVector()) { // 1D case + T maxVal = DataTypeUtils::max(); + output->assign(maxVal); - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - double sumValue = input->e(fi->second.at(0)); - Nd4jLong loop_size = fi->second.size(); + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT.assign(listOfTensors.at(fi->second.at(0))); + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + auto minT = listOfTensors.at(fi->second.at(idx)); - // FIXME: parallelism here? - for (Nd4jLong idx = 1; idx < loop_size; ++idx) { - sumValue += input->e(fi->second.at(idx)); - } - output->p(fi->first, sumValue); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT.assign(listOfTensors.at(fi->second.at(0))); - Nd4jLong loop_size = fi->second.size(); - - // FIXME: parallelism here? - for (Nd4jLong idx = 1; idx < loop_size; ++idx) { - auto current = listOfTensors.at(fi->second.at(idx)); - outputT += current; - } - //outputT->assign(maxT); - } + for (Nd4jLong e = 0; e < outputT.lengthOf(); ++e) { + outputT.t(e) = sd::math::nd4j_min(minT.t(e), outputT.t(e)); } + } + // outputT->assign(maxT); } + } +} +void unsortedSegmentMinFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentMinFunctor_, + (input, indices, numOfClasses, output), NUMERIC_TYPES); +} - template - void unsortedSegmentProdFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - MAP_IMPL> idxs;//(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); +BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMinFunctor_, + (NDArray * input, NDArray* indices, Nd4jLong numOfClasses, + NDArray* output), + NUMERIC_TYPES); - //std::sort(idxs.begin(), idxs.end()); +void unsortedSegmentMeanFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + MAP_IMPL> idxs; //(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); - output->assign(1.f); + // std::sort(idxs.begin(), idxs.end()); - if (input->isVector()) { // 1D case - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - T prodValue = input->e(fi->second.at(0)); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - prodValue *= input->e(fi->second.at(idx)); - } - output->p(fi->first, prodValue); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + if (input->isVector()) { // 1D case - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + double sumValue = input->e(fi->second.at(0)); + int loop_size = fi->second.size(); - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT.assign(listOfTensors.at(fi->second.at(0))); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - auto current = listOfTensors.at(fi->second.at(idx)); + // FIXME: parallelism here? + for (size_t idx = 1; idx < loop_size; ++idx) { + sumValue += input->e(fi->second.at(idx)); + } - outputT *= current; - } - } - } - } - - void unsortedSegmentProdFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentProdFunctor_, (input, indices, numOfClasses, output), NUMERIC_TYPES); + output->p(fi->first, sumValue / fi->second.size()); } - BUILD_SINGLE_TEMPLATE(template void unsortedSegmentProdFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); - - void unsortedSegmentSqrtNFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - MAP_IMPL> idxs;//(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); - - //std::sort(idxs.begin(), idxs.end()); - - if (input->isVector()) { // 1D case - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - double sumValue = input->e(fi->second.at(0)); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - sumValue += input->e(fi->second.at(idx)); - } - output->p(fi->first, sumValue / sd::math::nd4j_sqrt(fi->second.size())); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT.assign(listOfTensors.at(fi->second.at(0))); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - auto current = listOfTensors.at(fi->second.at(idx)); - outputT += current; - } - //outputT->assign(maxT); - outputT /= sd::math::nd4j_sqrt(fi->second.size()); - } - } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + // FIXME: parallelism here? + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT.assign(listOfTensors.at(fi->second.at(0))); + Nd4jLong loopSize = fi->second.size(); + + for (Nd4jLong idx = 1; idx < loopSize; ++idx) { + auto current = listOfTensors.at(fi->second.at(idx)); + outputT += current; + } + (outputT) /= double(fi->second.size()); } + } +} - // -------------------------------------------------------------------------------------------------------------- // - // Backpropagate ops helpers - // -------------------------------------------------------------------------------------------------------------- // - // Sorted backpropagate ops - // - // segment max - template - int segmentMaxFunctorBP_(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - //int numOfClasses = gradOut->sizeAt(0); - // if input is a vector: (as if in doc sample) - auto tempRes = gradOut->dup(); - segmentMaxFunctor_(input, indices, &tempRes); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto classNum = indices->e(e); - if (sd::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) <= T(1.e-6)) - output->p(e, gradOut->e(classNum)); - } - }; - samediff::Threads::parallel_for(func, 0, loop_size); - } - else { - std::vector restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //int numOfClasses = tempRes.sizeAt(0); // number of classes - //std::vector> outputs(numOfClasses); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - for (Nd4jLong e = 0; e < current.lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).e(e) - current.e(e)) <= T(1.e-6)) - currentOut.p(e, currentGradOut.e(e)); - } - } - }; - - samediff::Threads::parallel_tad(func, 0, indices->lengthOf()); - } +void unsortedSegmentSumFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + MAP_IMPL> idxs; //(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); - return ND4J_STATUS_OK; - } + if (input->isVector()) { // 1D case + + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + double sumValue = input->e(fi->second.at(0)); + Nd4jLong loop_size = fi->second.size(); - int segmentMaxFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), return segmentMaxFunctorBP_, (context, input, indices, gradOut, output), NUMERIC_TYPES); + // FIXME: parallelism here? + for (Nd4jLong idx = 1; idx < loop_size; ++idx) { + sumValue += input->e(fi->second.at(idx)); + } + output->p(fi->first, sumValue); } - BUILD_SINGLE_TEMPLATE(template int segmentMaxFunctorBP_, (sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES); - - // segmen min - int segmentMinFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - NDArray tempRes = gradOut->dup(); - segmentMinFunctor(context, input, indices, &tempRes); - if (input->isVector()) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto classNum = indices->e(e); - if (sd::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) < 1.e-5) - output->p(e, gradOut->e(classNum)); - } - }; - samediff::Threads::parallel_for(func, 0, input->lengthOf()); - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //int numOfClasses = tempRes.sizeAt(0); // number of classes - //std::vector> outputs(numOfClasses); - output->assign(0.); - int pos = 0; - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - for (Nd4jLong e = 0; e < current.lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).e(e) - current.e(e)) < - 1.e-5) - currentOut.p(e, currentGradOut.e(e)); - } - } - }; - - samediff::Threads::parallel_tad(func, 0, indices->lengthOf()); - } - return ND4J_STATUS_OK; + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT.assign(listOfTensors.at(fi->second.at(0))); + Nd4jLong loop_size = fi->second.size(); + + // FIXME: parallelism here? + for (Nd4jLong idx = 1; idx < loop_size; ++idx) { + auto current = listOfTensors.at(fi->second.at(idx)); + outputT += current; + } + // outputT->assign(maxT); } + } +} - // segmen mean - int segmentMeanFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - int numClasses = output->sizeAt(0); - MAP_IMPL classCount;//(numClasses); +template +void unsortedSegmentProdFunctor_(NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, NDArray* output) { + MAP_IMPL> idxs; //(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); - for (Nd4jLong count = 0; count < numClasses; ++count) { - classCount[count] = 0; - } + // std::sort(idxs.begin(), idxs.end()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - classCount[indices->e(e)] ++; - } + output->assign(1.f); - // if input is a vector: (as if in doc sample) - if (input->isVector()) { - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum) / classCount[classNum]); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); -; - - int pos = 0; - //auto func = [&](uint64_t thread_id, uint64_t start, uint64_t stop, uint64_t increment) -> void { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - for (Nd4jLong e = 0; e < current.lengthOf(); e++) { - currentOut.p(e, currentGradOut.e(e) / classCount.at(classNum)); - } - } - //}; - - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - return ND4J_STATUS_OK; + if (input->isVector()) { // 1D case + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + T prodValue = input->e(fi->second.at(0)); + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + prodValue *= input->e(fi->second.at(idx)); + } + output->p(fi->first, prodValue); } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - int segmentSumFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { -// int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - Nd4jLong idx = indices->e(0); - if (input->isVector()) { - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum)); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT.assign(listOfTensors.at(fi->second.at(0))); + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + auto current = listOfTensors.at(fi->second.at(idx)); - currentOut.assign(currentGradOut); - } - //}; + outputT *= current; + } + } + } +} - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - return Status::OK(); +void unsortedSegmentProdFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentProdFunctor_, + (input, indices, numOfClasses, output), NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void unsortedSegmentProdFunctor_, + (NDArray * input, NDArray* indices, Nd4jLong numOfClasses, + NDArray* output), + NUMERIC_TYPES); + +void unsortedSegmentSqrtNFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + MAP_IMPL> idxs; //(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); + + // std::sort(idxs.begin(), idxs.end()); + + if (input->isVector()) { // 1D case + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + double sumValue = input->e(fi->second.at(0)); + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + sumValue += input->e(fi->second.at(idx)); + } + output->p(fi->first, sumValue / sd::math::nd4j_sqrt( + fi->second.size())); } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT.assign(listOfTensors.at(fi->second.at(0))); + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + auto current = listOfTensors.at(fi->second.at(idx)); + outputT += current; + } + // outputT->assign(maxT); + outputT /= sd::math::nd4j_sqrt(fi->second.size()); + } + } +} - int segmentProdFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - auto tempRes = gradOut->dup(); - segmentProdFunctor(context, input, indices, &tempRes); - if (input->isVector()) { - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum) * tempRes.e(classNum)/ input->e(e)); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //int numOfClasses = tempRes.sizeAt(0); // number of classes - //std::vector> outputs(numOfClasses); - - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - auto currentFFOut = listOfBPTensors.at(classNum); - - currentOut.assign(currentFFOut * currentGradOut / current); - } - //}; - - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); +// -------------------------------------------------------------------------------------------------------------- +// // Backpropagate ops helpers +// -------------------------------------------------------------------------------------------------------------- +// // Sorted backpropagate ops +// +// segment max +template +int segmentMaxFunctorBP_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + // int numOfClasses = gradOut->sizeAt(0); + // if input is a vector: (as if in doc sample) + auto tempRes = gradOut->dup(); + segmentMaxFunctor_(input, indices, &tempRes); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto classNum = indices->e(e); + if (sd::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) <= + T(1.e-6)) + output->p(e, gradOut->e(classNum)); + } + }; + samediff::Threads::parallel_for(func, 0, loop_size); + } else { + std::vector restDims = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + // int numOfClasses = tempRes.sizeAt(0); // number of classes + // std::vector> outputs(numOfClasses); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + for (Nd4jLong e = 0; e < current.lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).e(e) - + current.e(e)) <= T(1.e-6)) + currentOut.p(e, currentGradOut.e(e)); } + } + }; - return ND4J_STATUS_OK; - } + samediff::Threads::parallel_tad(func, 0, indices->lengthOf()); + } - // -------------------------------------------------------------------------------------------------------------- // - // Unsorted backpropagate segment ops - // -------------------------------------------------------------------------------------------------------------- // - - template - static int unsortedSegmentMaxFunctorBP_(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { -// int numOfClasses = gradOut->sizeAt(0); - // if input is a vector: (as if in doc sample) - auto tempRes = gradOut->dup(); - unsortedSegmentMaxFunctor(context, input, indices, numOfClasses, &tempRes); - if (input->isVector()) { - - for (Nd4jLong e = 0; e < input->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - if (sd::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) < 1.e-5) - output->p(e, gradOut->e(classNum)); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - Nd4jLong classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - for (int e = 0; e < current.lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).e(e) - current.e(e)) < 1.e-5) - currentOut.p(e, currentGradOut.e(e)); - } - } + return ND4J_STATUS_OK; +} + +int segmentMaxFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + BUILD_SINGLE_SELECTOR(output->dataType(), return segmentMaxFunctorBP_, + (context, input, indices, gradOut, output), + NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template int segmentMaxFunctorBP_, + (sd::LaunchContext * context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output), + NUMERIC_TYPES); + +// segmen min +int segmentMinFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray tempRes = gradOut->dup(); + segmentMinFunctor(context, input, indices, &tempRes); + if (input->isVector()) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto classNum = indices->e(e); + if (sd::math::nd4j_abs(tempRes.e(classNum) - + input->e(e)) < 1.e-5) + output->p(e, gradOut->e(classNum)); + } + }; + samediff::Threads::parallel_for(func, 0, input->lengthOf()); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + // int numOfClasses = tempRes.sizeAt(0); // number of classes + // std::vector> outputs(numOfClasses); + output->assign(0.); + int pos = 0; + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + for (Nd4jLong e = 0; e < current.lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).e(e) - + current.e(e)) < 1.e-5) + currentOut.p(e, currentGradOut.e(e)); } + } + }; - return ND4J_STATUS_OK; - } + samediff::Threads::parallel_tad(func, 0, indices->lengthOf()); + } + return ND4J_STATUS_OK; +} - int unsortedSegmentMaxFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES); +// segmen mean +int segmentMeanFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + int numClasses = output->sizeAt(0); + MAP_IMPL classCount; //(numClasses); + + for (Nd4jLong count = 0; count < numClasses; ++count) { + classCount[count] = 0; + } + + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + classCount[indices->e(e)]++; + } + + // if input is a vector: (as if in doc sample) + if (input->isVector()) { + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + output->p(e, gradOut->e(classNum) / classCount[classNum]); } - BUILD_SINGLE_TEMPLATE(template int unsortedSegmentMaxFunctorBP_, (sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); - - template - static int unsortedSegmentMinFunctorBP_(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - auto tempRes = gradOut->dup(); - unsortedSegmentMinFunctor(context, input, indices, numOfClasses, &tempRes); - if (input->isVector()) { - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto classNum = indices->e(e); - if (sd::math::nd4j_abs(tempRes.t(classNum) - input->t(e)) < 1.e-6) - output->t(e) = gradOut->t(classNum); - } - }; - - samediff::Threads::parallel_for(func, 0, input->lengthOf()); - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - for (Nd4jLong e = 0; e < current.lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).t(e) - current.t(e)) < 1.e-6) - currentOut.t(e) = currentGradOut.t(e); - } - } - //}; - - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - - return ND4J_STATUS_OK; + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + ; + + int pos = 0; + // auto func = [&](uint64_t thread_id, uint64_t start, uint64_t stop, + // uint64_t increment) -> void { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + for (Nd4jLong e = 0; e < current.lengthOf(); e++) { + currentOut.p(e, currentGradOut.e(e) / classCount.at(classNum)); + } } + //}; + + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } + return ND4J_STATUS_OK; +} - int unsortedSegmentMinFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES); +int segmentSumFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + // int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + Nd4jLong idx = indices->e(0); + if (input->isVector()) { + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + output->p(e, gradOut->e(classNum)); } - BUILD_SINGLE_TEMPLATE(template int unsortedSegmentMinFunctorBP_, (sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - int unsortedSegmentMeanFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - MAP_IMPL classCount;//(numClasses); + // auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); - for (Nd4jLong count = 0; count < numOfClasses; ++count) { - classCount[count] = 0; - } + currentOut.assign(currentGradOut); + } + //}; - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - classCount[indices->e(e)]++; - } + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } + return Status::OK(); +} - // if input is a vector: (as if in doc sample) - if (input->isVector()) { - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum) / classCount[classNum]); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - Nd4jLong classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - currentOut.assign(currentGradOut / double(classCount[classNum])); - } - } - return ND4J_STATUS_OK; +int segmentProdFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + auto tempRes = gradOut->dup(); + segmentProdFunctor(context, input, indices, &tempRes); + if (input->isVector()) { + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + output->p(e, gradOut->e(classNum) * tempRes.e(classNum) / + input->e(e)); } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + // int numOfClasses = tempRes.sizeAt(0); // number of classes + // std::vector> outputs(numOfClasses); + + // auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + auto currentFFOut = listOfBPTensors.at(classNum); + + currentOut.assign(currentFFOut * currentGradOut / current); + } + //}; - int unsortedSegmentSumFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } - // if input is a vector: (as if in doc sample) - Nd4jLong idx = indices->e(0); - if (input->isVector()) { - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum)); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + return ND4J_STATUS_OK; +} - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); +// -------------------------------------------------------------------------------------------------------------- +// // Unsorted backpropagate segment ops +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static int unsortedSegmentMaxFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, Nd4jLong numOfClasses, + NDArray* output) { + // int numOfClasses = gradOut->sizeAt(0); + // if input is a vector: (as if in doc sample) + auto tempRes = gradOut->dup(); + unsortedSegmentMaxFunctor(context, input, indices, numOfClasses, &tempRes); + if (input->isVector()) { + for (Nd4jLong e = 0; e < input->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + if (sd::math::nd4j_abs(tempRes.e(classNum) - + input->e(e)) < 1.e-5) + output->p(e, gradOut->e(classNum)); + } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + Nd4jLong classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + for (int e = 0; e < current.lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).e(e) - + current.e(e)) < 1.e-5) + currentOut.p(e, currentGradOut.e(e)); + } + } + } - currentOut.assign(currentGradOut); - } - //}; + return ND4J_STATUS_OK; +} - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - return Status::OK(); +int unsortedSegmentMaxFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + BUILD_SINGLE_SELECTOR( + output->dataType(), return unsortedSegmentMaxFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template int unsortedSegmentMaxFunctorBP_, + (sd::LaunchContext * context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output), + NUMERIC_TYPES); + +template +static int unsortedSegmentMinFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, Nd4jLong numOfClasses, + NDArray* output) { + auto tempRes = gradOut->dup(); + unsortedSegmentMinFunctor(context, input, indices, numOfClasses, &tempRes); + if (input->isVector()) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto classNum = indices->e(e); + if (sd::math::nd4j_abs(tempRes.t(classNum) - input->t(e)) < 1.e-6) + output->t(e) = gradOut->t(classNum); + } + }; + + samediff::Threads::parallel_for(func, 0, input->lengthOf()); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + // auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + for (Nd4jLong e = 0; e < current.lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).t(e) - + current.t(e)) < 1.e-6) + currentOut.t(e) = currentGradOut.t(e); + } } + //}; - int unsortedSegmentProdFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } - auto tempRes = gradOut->dup(); - unsortedSegmentProdFunctor(context, input, indices, numOfClasses, &tempRes); - if (input->isVector()) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto classNum = indices->e(e); - output->p(e, gradOut->e(classNum) * tempRes.e(classNum) / input->e(e)); - } - }; + return ND4J_STATUS_OK; +} - samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - auto currentFFOut = listOfBPTensors.at(classNum); - - currentOut.assign(currentFFOut * currentGradOut / current); - } - //}; - - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } +int unsortedSegmentMinFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + BUILD_SINGLE_SELECTOR( + output->dataType(), return unsortedSegmentMinFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template int unsortedSegmentMinFunctorBP_, + (sd::LaunchContext * context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output), + NUMERIC_TYPES); + +int unsortedSegmentMeanFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + MAP_IMPL classCount; //(numClasses); + + for (Nd4jLong count = 0; count < numOfClasses; ++count) { + classCount[count] = 0; + } + + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + classCount[indices->e(e)]++; + } + + // if input is a vector: (as if in doc sample) + if (input->isVector()) { + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + output->p(e, gradOut->e(classNum) / classCount[classNum]); + } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + Nd4jLong classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + currentOut.assign(currentGradOut / double(classCount[classNum])); + } + } + return ND4J_STATUS_OK; +} - return Status::OK(); +int unsortedSegmentSumFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + // if input is a vector: (as if in doc sample) + Nd4jLong idx = indices->e(0); + if (input->isVector()) { + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + output->p(e, gradOut->e(classNum)); } + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); -// template - int unsortedSegmentSqrtNFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - MAP_IMPL classCount;//(numClasses); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - for (Nd4jLong count = 0; count < numOfClasses; ++count) { - classCount[count] = 0; - } + // auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - classCount[indices->e(e)]++; - } + currentOut.assign(currentGradOut); + } + //}; - // if input is a vector: (as if in doc sample) - if (input->isVector()) { - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { - auto classNum = indices->e(e); - output->p(e, gradOut->e(classNum) / sd::math::nd4j_sqrt(classCount[classNum])); - } - //}; + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } + return Status::OK(); +} - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfGradOuts =gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors =input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors =output->allTensorsAlongDimension(restDims); - - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - for (int e = 0; e < current.lengthOf(); e++) { - currentOut.p(e, currentGradOut.e(e) / sd::math::nd4j_sqrt(classCount[classNum])); - } - } - //}; - - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - return Status::OK(); +int unsortedSegmentProdFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + auto tempRes = gradOut->dup(); + unsortedSegmentProdFunctor(context, input, indices, numOfClasses, &tempRes); + if (input->isVector()) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto classNum = indices->e(e); + output->p(e, gradOut->e(classNum) * + tempRes.e(classNum) / + input->e(e)); + } + }; + + samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + // auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + auto currentFFOut = listOfBPTensors.at(classNum); + + currentOut.assign(currentFFOut * currentGradOut / current); } + //}; + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } + + return Status::OK(); } + +// template +int unsortedSegmentSqrtNFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + MAP_IMPL classCount; //(numClasses); + + for (Nd4jLong count = 0; count < numOfClasses; ++count) { + classCount[count] = 0; + } + + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + classCount[indices->e(e)]++; + } + + // if input is a vector: (as if in doc sample) + if (input->isVector()) { + // auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { + auto classNum = indices->e(e); + output->p(e, + gradOut->e(classNum) / + sd::math::nd4j_sqrt(classCount[classNum])); + } + //}; + + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + // auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + for (int e = 0; e < current.lengthOf(); e++) { + currentOut.p( + e, currentGradOut.e(e) / + sd::math::nd4j_sqrt(classCount[classNum])); + } + } + //}; + + // samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } + return Status::OK(); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp b/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp index 8e25c4690202..4b4be8eb1d0b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp @@ -18,30 +18,36 @@ // @author GS // -#include #include +#include namespace sd { namespace ops { namespace helpers { - template - static void sequenceMask_(NDArray* input, NDArray* output, int maxIndex) { - auto func = PRAGMA_THREADS_FOR_2D { - for (auto i = start_x; i < stop_x; i += inc_x) - for (auto k = start_y; k < stop_y; k += inc_y) - if (i < input->t(k)) - output->t(k * maxIndex + i) = B(true); //, T(1.0f)); - }; - - samediff::Threads::parallel_for(func, 0, maxIndex, 1, 0, input->lengthOf(), 1); - } +template +static void sequenceMask_(NDArray* input, NDArray* output, int maxIndex) { + auto func = PRAGMA_THREADS_FOR_2D { + for (auto i = start_x; i < stop_x; i += inc_x) + for (auto k = start_y; k < stop_y; k += inc_y) + if (i < input->t(k)) + output->t(k * maxIndex + i) = B(true); //, T(1.0f)); + }; - void sequenceMask(sd::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); - } - - BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); + samediff::Threads::parallel_for(func, 0, maxIndex, 1, 0, input->lengthOf(), + 1); } + +void sequenceMask(sd::LaunchContext* context, NDArray* input, NDArray* output, + int maxIndex) { + BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, + (input, output, maxIndex), INTEGER_TYPES, + LIBND4J_TYPES_EXTENDED); } -} \ No newline at end of file + +BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, + (NDArray * input, NDArray* output, int maxIndex), + INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/sg_cb.cpp b/libnd4j/include/ops/declarable/helpers/cpu/sg_cb.cpp index 07cbca04ea5c..0eceb0127c1d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/sg_cb.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/sg_cb.cpp @@ -18,609 +18,724 @@ // @author raver119@gmail.com // +#include #include #include -#include #define HS_MAX_EXP 6.0f namespace sd { - namespace ops { - namespace helpers { - template - void hSoftmax_(void *vsyn0, void *vsyn1, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference) { - auto syn0 = reinterpret_cast(vsyn0); - auto syn1 = reinterpret_cast(vsyn1); - auto expTable = reinterpret_cast(vexpTable); - auto neu1e = reinterpret_cast(vneu1e); - - T dot(0.0f); - T g(0.0f); - T f(0.0f); - - // dot - for (int e = 0; e < vectorLength; e++) { - dot += syn0[e] * syn1[e]; - } - - // gradient - if (dot < (T) - HS_MAX_EXP || dot >= (T) HS_MAX_EXP) - return; - - - int idx = static_cast((dot + HS_MAX_EXP) * ((float) expLength / HS_MAX_EXP / 2.0f)); - - if (idx >= expLength || idx < 0) - return; - - f = expTable[idx]; - g = (static_cast(1.0f) - static_cast(code) - f) * (T) alpha; - - // axpy1 - - for (int e = 0; e < vectorLength; e++) { - neu1e[e] = g * syn1[e] + neu1e[e]; - } - - // axpy2 - if (!isInference) { - for (int e = 0; e < vectorLength; e++) { - syn1[e] = g * syn0[e] + syn1[e]; - } - } - } - - template - void nSampling_(void *vsyn0, void *vsyn1Neg, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference) { - auto syn0 = reinterpret_cast(vsyn0); - auto syn1Neg = reinterpret_cast(vsyn1Neg); - auto expTable = reinterpret_cast(vexpTable); - auto neu1e = reinterpret_cast(vneu1e); - - T dot = (T) 0.0f; - T g = (T) 0.0f; - - for (int e = 0; e < vectorLength; e++) { - dot += syn0[e] * syn1Neg[e]; - } - - if (dot > HS_MAX_EXP) - g = (code - 1) * alpha; - else if (dot < (T) - HS_MAX_EXP) - g = (code - 0) * alpha; - else { - int idx = (int) ((dot + (T) HS_MAX_EXP) * ((T) expLength / HS_MAX_EXP / 2.0)); - if (idx >= expLength) - return; - - if (idx < 0) - return; - - g = ((T) code - expTable[idx]) * alpha; - } - - // axpy1 - for (int e = 0; e < vectorLength; e++) { - neu1e[e] = g * syn1Neg[e] + neu1e[e]; - } - - // axpy2 - if (!isInference) { - for (int e = 0; e < vectorLength; e++) { - syn1Neg[e] = g * syn0[e] + syn1Neg[e]; - } - } - } - - template - void cbow_(void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *context, int *lockedWords, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int contextWidth, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const int numLabels, const bool trainWords) { - auto syn0 = reinterpret_cast(vsyn0); - auto syn1 = reinterpret_cast(vsyn1); - auto syn1Neg = reinterpret_cast(vsyn1Neg); - auto expTable = reinterpret_cast(vexpTable); - auto negTable = reinterpret_cast(vnegTable); - auto infVector = reinterpret_cast(vinfVector); - - auto neu1 = new T[vectorLength]; - auto neu1e = new T[vectorLength]; - memset(neu1, 0, vectorLength * sizeof(T)); - memset(neu1e, 0, vectorLength * sizeof(T)); - - // building neu1 for current window - for (int c = 0; c < contextWidth; c++) { - if (context[c] >= vocabSize) - throw std::runtime_error("Bad context 4"); - - T *syn0word = syn0 + (context[c] * vectorLength); - - for (int i = 0; i < vectorLength; i++) { - neu1[i] += syn0word[i]; - } - } - - // for inference we add additional inference vector - if (infVector != nullptr) { - - for (int i = 0; i < vectorLength; i++) { - neu1[i] += infVector[i]; - } - } - - - // average neu1 - if (contextWidth > 0) { - for (int i = 0; i < vectorLength; i++) { - neu1[i] /= contextWidth + (infVector != nullptr ? 1 : 0); - } - } - - // softmax round - if (hsRounds > 0) { - for (int i = 0; i < hsRounds; i++) { - if (indices[i] < 0 || indices[i] >= vocabSize) - throw std::runtime_error("Bad context 5"); - - hSoftmax_(neu1, syn1 + (indices[i] * vectorLength), expTable, neu1e, alpha, vectorLength, codes[i], expLength, infVector != nullptr); - } - } - - auto nsStarter = ngStarter; - auto irow = nsStarter; - if (nsRounds > 0) { - for (int r = 0; r < nsRounds + 1; r++) { - if (r == 0) { - // target is known in advance - } else { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1; - if (irow == nsStarter) - continue; - } - - nSampling_(neu1, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr); - } - } - - // if we don't train words - we skip start of idxSyn0 - int starter = trainWords == 1 ? 0 : contextWidth - numLabels; - - // propagate neu1e -> syn0 - if (infVector == nullptr) { - for (int c = starter; c < contextWidth; c++) { - if (lockedWords[c] == 1) - continue; - - T *syn0word = syn0 + (context[c] * vectorLength); - - for (int i = 0; i < vectorLength; i++) { - syn0word[i] += neu1e[i]; - } - } - } else { - - for (int i = 0; i < vectorLength; i++) { - infVector[i] += neu1e[i]; - } - } - - - delete[] neu1; - delete[] neu1e; - } - BUILD_SINGLE_TEMPLATE(template void cbow_, (void *syn0, void *syn1, void *syn1Neg, void *expTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *context, int *lockedWords, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int contextWidth, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const int numLabels, const bool trainWords), FLOAT_TYPES); - - - template - void skipgram_(void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength) { - auto syn0 = reinterpret_cast(vsyn0); - auto syn1 = reinterpret_cast(vsyn1); - auto syn1Neg = reinterpret_cast(vsyn1Neg); - auto expTable = reinterpret_cast(vexpTable); - auto negTable = reinterpret_cast(vnegTable); - auto infVector = reinterpret_cast(vinfVector); - - auto neu1e = new T[vectorLength]; - memset(neu1e, 0, vectorLength * sizeof(T)); - - // hierarchic softmax goes first (if enabled) - auto syn0row = infVector != nullptr ? infVector : syn0 + (target * vectorLength); - auto irow = 0; - if (hsRounds > 0) { - for (int r = 0; r < hsRounds; r++) { - irow = indices[r]; - if (irow < 0 || irow >= vocabSize) - break; - - hSoftmax_(syn0row, syn1 + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, codes[r], expLength, infVector != nullptr); - } - } - - // negative sampling goes second (if enabled) - auto nsStarter = ngStarter; - irow = nsStarter; - if (nsRounds > 0) { - for (int r = 0; r < nsRounds + 1; r++) { - if (r == 0) { - // target is known in advance - } else { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1; - if (irow == nsStarter) - continue; - } - - nSampling_(syn0row, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr); - } - } - - if (infVector == nullptr) { - for (int e = 0; e < vectorLength; e++) { - syn0row[e] += neu1e[e]; - } - } else { - for (int e = 0; e < vectorLength; e++) { - infVector[e] += neu1e[e]; - } - } - - delete[] neu1e; - } - BUILD_SINGLE_TEMPLATE(template void skipgram_, (void *syn0, void *syn1, void *syn1Neg, void *expTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength), FLOAT_TYPES); - - int binarySearch(const int *haystack, const int needle, const int totalElements) { - int firstIndex = 0; - int lastIndex = totalElements - 1; - int halfIndex = sd::math::nd4j_floor((lastIndex + firstIndex) / (float) 2); - - while(haystack[halfIndex] != needle && firstIndex < lastIndex) { - if (needle < haystack[halfIndex]) { - lastIndex = halfIndex - 1; - } else if (needle > haystack[halfIndex]) { - firstIndex = halfIndex + 1; - } - halfIndex = sd::math::nd4j_floor((lastIndex + firstIndex) / (float) 2); - } - - return (haystack[halfIndex] == needle) ? halfIndex : -1; - } - - template - static void do_update(const int target, const int rowIndex, const int count, T *syn0, T *neu1t, const int vectorLength) { - - auto syn0row = syn0 + (target * vectorLength); - auto neu1e = neu1t + (rowIndex * vectorLength); - for (int e = 0; e< vectorLength; e++) - syn0row[e] += neu1e[e] / count; - } - - template - static void do_positive(const int target, const int postive, T* syn0, T* syn1Neg, T* expTable, T* neu1e, const double alpha, const int vectorLength, const int expLength) { - //nd4j_printf("Target: [%i]; Positive: [%i]; TID: [%i];\n", target, postive, omp_get_thread_num()); - nSampling_(syn0, syn1Neg, expTable, neu1e, alpha, vectorLength, 1, expLength, false); - } - - template - static void do_negative(int target, int positive, T* syn0, T* syn1Neg, T* expTable, T* negTable, T* neu1e, int *sStarters, const double alpha, const unsigned long long rv, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const int nsRounds, const int numThreads, const int numTargets) { - int irow = 0; - unsigned long long randomValue = rv; - for (int r = 0; r < nsRounds; r++) { - randomValue = sd::math::nd4j_abs(randomValue * (unsigned long long) 25214903917 + 11); - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) - irow = randomValue % (vocabSize - 1) + 1; - - if (irow == positive) - continue; - - // we shift irow here to guarantee independence - - int dim = irow % numThreads; - if (dim != omp_get_thread_num()) { - irow += (numThreads - dim + omp_get_thread_num()); - - // roll back to nearest affilated word - while (irow >= vocabSize) - irow -= numThreads; - - // if this row was processed as first step somewhere - skip it - if (binarySearch(sStarters, irow, numTargets) > 0) { - r--; - continue; - } - } - - - nSampling_(syn0, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, 0, expLength, false); - } - } - - template - void skipgramBatchExec_(NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool preciseMode, const int numThreads) { - //auto syn0 = reinterpret_cast(vsyn0); - //auto syn1 = reinterpret_cast(vsyn1); - //auto syn1Neg = reinterpret_cast(vsyn1Neg); - const auto expTable = reinterpret_cast(vexpTable); - const auto negTable = reinterpret_cast(vnegTable); - const auto infVector = reinterpret_cast(vinfVector); - - //const auto numThreads = omp_get_max_threads(); - const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); - const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); - - // regular mode provides 0 guarantees for reproducibility - auto numTargets = targets.lengthOf(); - auto bTarget = targets.bufferAsT(); - auto bIndices = indices.bufferAsT(); - auto bCodes = codes.bufferAsT(); - - auto func = PRAGMA_THREADS_FOR { - T sneu1e[600]; - - for (auto t = start; t < stop; t++) { - T *neu1e = vectorLength <= 600 ? sneu1e : new T[vectorLength]; - memset(neu1e, 0, vectorLength * sizeof(T)); - - auto target = bTarget[t]; - auto alpha = lr.e(t); - unsigned long long randomValue = nextRandom.e(t); - - auto syn0row = reinterpret_cast(s0.bufferWithOffset(target * vectorLength)); - - if (hsRounds > 0) { - int irow = 0; - auto cShift = t * idxShift; - - for (Nd4jLong e = 0; e < hsRounds; e++) { - irow = bIndices[e + cShift]; - if (irow < 0 || irow >= vocabSize) - continue; - - auto syn1row = s1.bufferWithOffset(irow * vectorLength); - auto code = bCodes[e + cShift]; - - //nd4j_printf("syn0: [%i]; syn1: [%i]; code: [%i]\n", target, irow, code); - hSoftmax_(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, code, - expLength, false); - } - } - - - if (nsRounds > 0) { - int irow = negStarters.e(t); - int nsStarter = irow; - for (int r = 0; r < nsRounds + 1; r++) { - if (r == 0) { - // target is known in advance - } else { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) - irow = randomValue % (vocabSize - 1) + 1; - - if (irow == nsStarter) - continue; - } - - nSampling_(syn0row, s1n.bufferWithOffset(irow * vectorLength), expTable, neu1e, - alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr); - } - } - - for (int e = 0; e < vectorLength; e++) - syn0row[e] += neu1e[e]; - - // optionally release temp arrays - if (vectorLength > 600) - delete[] neu1e; - } - }; - - samediff::Threads::parallel_tad(func, 0, numTargets, 1, numThreads); - } - BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool preciseMode, const int numThreads), FLOAT_TYPES); - - - template - void cbowBatchExec_(NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, NDArray &nLabels, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool trainWords, const int numThreads) { - const auto syn0 = s0.bufferAsT(); - const auto syn1 = s1.bufferAsT(); - const auto syn1Neg = s1n.bufferAsT(); - - const auto expTable = reinterpret_cast(vexpTable); - const auto negTable = reinterpret_cast(vnegTable); - const auto infVector = reinterpret_cast(vinfVector); - - //const auto numThreads = omp_get_max_threads(); - const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); - const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); - const auto numTargets = context.sizeAt(0); - const int contextWidth = context.sizeAt(1); - - const auto bContext = context.bufferAsT(); - const auto bLocker = lockedWords.bufferAsT(); - const auto bIndices = indices.bufferAsT(); - const auto bCodes = codes.bufferAsT(); - const auto bStarters = negStarters.bufferAsT(); - const auto numIndices = indices.isEmpty() ? 0 : indices.sizeAt(1); - - auto func = PRAGMA_THREADS_FOR { - T sneu1[600]; - T sneu1e[600]; - - for (auto e = start; e < stop; e++) { - T *neu1 = vectorLength <= 600 ? sneu1 : new T[vectorLength]; - T *neu1e = vectorLength <= 600 ? sneu1e : new T[vectorLength]; - - // optionally we nullify temp arrays after successful (and on first) cycle - memset(neu1, 0, sizeof(T) * vectorLength); - memset(neu1e, 0, sizeof(T) * vectorLength); - - auto alpha = lr.e(e); - auto numLabels = nLabels.isEmpty() ? 0 : nLabels.e(e); - - int actualContext = 0; - - // building neu1 for current window - for (int c = 0; c < contextWidth; c++) { - // getting next context word - auto cContext = bContext[c + (e * contextWidth)]; - - // skipping padded values - if (cContext < 0) - continue; - - if (cContext >= vocabSize) - throw std::runtime_error("ContextID can't be >= vocab size"); - - T *syn0word = syn0 + (cContext * vectorLength); - - for (int i = 0; i < vectorLength; i++) - neu1[i] += syn0word[i]; - - actualContext++; - } - - if (infVector != nullptr) - actualContext++; - - if (actualContext > 1) { - for (int i = 0; i < vectorLength; i++) - neu1[i] /= actualContext; - } - - // hierarchic softmax step - if (!indices.isEmpty()) { - for (Nd4jLong i = 0; i < numIndices; i++) { - const int cIndex = bIndices[(e * numIndices) + i]; - const int cCode = bCodes[(e * numIndices) + i]; - - // we're skipping padded values - if (cIndex < 0) - continue; - - if (cIndex >= vocabSize) - throw std::runtime_error("Index can't be > vocab size"); - - hSoftmax_(neu1, syn1 + (cIndex * vectorLength), expTable, neu1e, alpha, vectorLength, - cCode, expLength, false); - } - } - - // negative sampling step - if (!negStarters.isEmpty() && nsRounds > 0) { - int irow = bStarters[e]; - const int nsStarter = irow; - unsigned long long randomValue = nextRandom.e(e); - - for (int r = 0; r < nsRounds + 1; r++) { - // we're skipping rng on 0 step - if (r != 0) { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1; - if (irow == nsStarter) - continue; - - nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), expTable, neu1e, - alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr); - } else { - nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), expTable, neu1e, - alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr); - } - - //nd4j_printf("Thread <%i>: syn0: [%i]; s1n: [%i];\n", omp_get_thread_num(), 0, irow); - } - } - - - // if we're skipping labels - int starter = trainWords == 1 ? 0 : contextWidth - numLabels; - - // applying previously averaged results - for (int c = starter; c < contextWidth; c++) { - // getting context - auto cContext = bContext[c + (e * contextWidth)]; - auto cLock = bLocker[c + (e * contextWidth)]; - - // skipping padded values - if (cContext < 0 || cLock == 1) - continue; - - if (cContext >= vocabSize) - throw std::runtime_error("ContextID can't be > vocab size"); - - // one word from context - T *syn0word = syn0 + (cContext * vectorLength); - - for (int i = 0; i < vectorLength; i++) - syn0word[i] += neu1e[i]; - - } - - // optionally release temp arrays - if (vectorLength > 600) { - delete[] neu1; - delete[] neu1e; - } - } - }; - - samediff::Threads::parallel_tad(func, 0, numTargets, 1, numThreads); - } - BUILD_SINGLE_TEMPLATE(template void cbowBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, NDArray &nLabels, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool trainWords, const int numThreads), FLOAT_TYPES); - - void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &inferenceVector, const bool preciseMode, const int numWorkers) { - auto xType = syn0.dataType(); - - // single round case - if ((ngStarter.isScalar() && !ngStarter.isEmpty())|| (target.isScalar() && !target.isEmpty())) { - auto hsRounds = codes.lengthOf(); - - BUILD_SINGLE_SELECTOR(xType, skipgram_, (syn0.buffer(), syn1.buffer(), syn1Neg.buffer(), expTable.buffer(), negTable.buffer(), inferenceVector.buffer(), target.isEmpty() ? -1 : target.e(0), ngStarter.isEmpty() ? -1 : ngStarter.e(0), reinterpret_cast(indices.buffer()), reinterpret_cast(codes.buffer()), alpha.e(0), randomValue.e(0), hsRounds, nsRounds, (int) syn0.sizeAt(0), (int) syn0.sizeAt(1), (int) expTable.lengthOf(), (int) negTable.lengthOf()), FLOAT_TYPES); - } else if (ngStarter.isVector() || target.isVector()){ - // batch mode - - BUILD_SINGLE_SELECTOR(xType, skipgramBatchExec_, (syn0, syn1, syn1Neg, expTable.buffer(), negTable.buffer(), nullptr, target, ngStarter, indices, codes, alpha, randomValue, nsRounds, syn0.sizeAt(0), syn0.sizeAt(1), expTable.lengthOf(), negTable.lengthOf(), preciseMode, numWorkers), FLOAT_TYPES); - } else - throw std::runtime_error("SkipGram: target must have rank 0 or 1"); - } - - void cbow(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &context, NDArray &lockedWords, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &numLabels, NDArray &inferenceVector, const bool trainWords, int numWorkers) { - auto xType = syn0.dataType(); - - if ((context.rankOf() == 0 || context.rankOf() == 1) && (indices.rankOf() == 1 || indices.rankOf() == 0)) { - // single round case - /*nd4j_printf("Row exec; ContextWidth: %i; LockedWords: %i; numLabels: %i; Train words: %i\n", (int) context.lengthOf(), (int) lockedWords.lengthOf(), numLabels.isEmpty() ? 0 : numLabels.e(0), (int) trainWords); - if (context.lengthOf() == 2) { - context.printBuffer("context"); - lockedWords.printBuffer("locked"); - codes.printBuffer("codes"); - indices.printBuffer("indices"); - }*/ - - auto hsRounds = codes.lengthOf(); - - BUILD_SINGLE_SELECTOR(xType, cbow_, (syn0.buffer(), syn1.buffer(), syn1Neg.buffer(), expTable.buffer(), negTable.buffer(), inferenceVector.buffer(), target.isEmpty() ? -1 : target.e(0), ngStarter.isEmpty() ? -1 : ngStarter.e(0), reinterpret_cast(context.buffer()), reinterpret_cast(lockedWords.buffer()),reinterpret_cast(indices.buffer()), reinterpret_cast(codes.buffer()), alpha.e( 0), randomValue.e(0), (int) context.lengthOf(), hsRounds, nsRounds, (int) syn0.sizeAt(0), (int) syn0.sizeAt(1), (int) expTable.lengthOf(), (int) negTable.lengthOf(), numLabels.isEmpty() ? 0 : numLabels.e(0), trainWords), FLOAT_TYPES); - } else if (context.rankOf() == 2 && indices.rankOf() == 2) { - // batch mode - //nd4j_printf("Batch exec\n",""); - - BUILD_SINGLE_SELECTOR(xType, cbowBatchExec_, (syn0, syn1, syn1Neg, expTable.buffer(), negTable.buffer(), nullptr, context, lockedWords, target, ngStarter, indices, codes, alpha, randomValue, numLabels, nsRounds, syn0.sizeAt(0), syn0.sizeAt(1), expTable.lengthOf(), negTable.isEmpty() ? 0 : negTable.lengthOf(), trainWords, numWorkers), FLOAT_TYPES); - } else - throw std::runtime_error("CBOW: context must have rank 0/1 or 2"); - } +namespace ops { +namespace helpers { +template +void hSoftmax_(void *vsyn0, void *vsyn1, void *vexpTable, void *vneu1e, + double alpha, int vectorLength, int code, int expLength, + bool isInference) { + auto syn0 = reinterpret_cast(vsyn0); + auto syn1 = reinterpret_cast(vsyn1); + auto expTable = reinterpret_cast(vexpTable); + auto neu1e = reinterpret_cast(vneu1e); + + T dot(0.0f); + T g(0.0f); + T f(0.0f); + + // dot + for (int e = 0; e < vectorLength; e++) { + dot += syn0[e] * syn1[e]; + } + + // gradient + if (dot < (T)-HS_MAX_EXP || dot >= (T)HS_MAX_EXP) return; + + int idx = static_cast((dot + HS_MAX_EXP) * + ((float)expLength / HS_MAX_EXP / 2.0f)); + + if (idx >= expLength || idx < 0) return; + + f = expTable[idx]; + g = (static_cast(1.0f) - static_cast(code) - f) * (T)alpha; + + // axpy1 + + for (int e = 0; e < vectorLength; e++) { + neu1e[e] = g * syn1[e] + neu1e[e]; + } + + // axpy2 + if (!isInference) { + for (int e = 0; e < vectorLength; e++) { + syn1[e] = g * syn0[e] + syn1[e]; + } + } +} + +template +void nSampling_(void *vsyn0, void *vsyn1Neg, void *vexpTable, void *vneu1e, + double alpha, int vectorLength, int code, int expLength, + bool isInference) { + auto syn0 = reinterpret_cast(vsyn0); + auto syn1Neg = reinterpret_cast(vsyn1Neg); + auto expTable = reinterpret_cast(vexpTable); + auto neu1e = reinterpret_cast(vneu1e); + + T dot = (T)0.0f; + T g = (T)0.0f; + + for (int e = 0; e < vectorLength; e++) { + dot += syn0[e] * syn1Neg[e]; + } + + if (dot > HS_MAX_EXP) + g = (code - 1) * alpha; + else if (dot < (T)-HS_MAX_EXP) + g = (code - 0) * alpha; + else { + int idx = (int)((dot + (T)HS_MAX_EXP) * ((T)expLength / HS_MAX_EXP / 2.0)); + if (idx >= expLength) return; + + if (idx < 0) return; + + g = ((T)code - expTable[idx]) * alpha; + } + + // axpy1 + for (int e = 0; e < vectorLength; e++) { + neu1e[e] = g * syn1Neg[e] + neu1e[e]; + } + + // axpy2 + if (!isInference) { + for (int e = 0; e < vectorLength; e++) { + syn1Neg[e] = g * syn0[e] + syn1Neg[e]; + } + } +} + +template +void cbow_(void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, + void *vnegTable, void *vinfVector, int target, int ngStarter, + int *context, int *lockedWords, int *indices, int8_t *codes, + double alpha, Nd4jLong randomValue, const int contextWidth, + const int hsRounds, const int nsRounds, const int vocabSize, + const int vectorLength, const int expLength, const int negLength, + const int numLabels, const bool trainWords) { + auto syn0 = reinterpret_cast(vsyn0); + auto syn1 = reinterpret_cast(vsyn1); + auto syn1Neg = reinterpret_cast(vsyn1Neg); + auto expTable = reinterpret_cast(vexpTable); + auto negTable = reinterpret_cast(vnegTable); + auto infVector = reinterpret_cast(vinfVector); + + auto neu1 = new T[vectorLength]; + auto neu1e = new T[vectorLength]; + memset(neu1, 0, vectorLength * sizeof(T)); + memset(neu1e, 0, vectorLength * sizeof(T)); + + // building neu1 for current window + for (int c = 0; c < contextWidth; c++) { + if (context[c] >= vocabSize) throw std::runtime_error("Bad context 4"); + + T *syn0word = syn0 + (context[c] * vectorLength); + + for (int i = 0; i < vectorLength; i++) { + neu1[i] += syn0word[i]; + } + } + + // for inference we add additional inference vector + if (infVector != nullptr) { + for (int i = 0; i < vectorLength; i++) { + neu1[i] += infVector[i]; + } + } + + // average neu1 + if (contextWidth > 0) { + for (int i = 0; i < vectorLength; i++) { + neu1[i] /= contextWidth + (infVector != nullptr ? 1 : 0); + } + } + + // softmax round + if (hsRounds > 0) { + for (int i = 0; i < hsRounds; i++) { + if (indices[i] < 0 || indices[i] >= vocabSize) + throw std::runtime_error("Bad context 5"); + + hSoftmax_(neu1, syn1 + (indices[i] * vectorLength), expTable, neu1e, + alpha, vectorLength, codes[i], expLength, + infVector != nullptr); + } + } + + auto nsStarter = ngStarter; + auto irow = nsStarter; + if (nsRounds > 0) { + for (int r = 0; r < nsRounds + 1; r++) { + if (r == 0) { + // target is known in advance + } else { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + if (irow == nsStarter) continue; + } + + nSampling_(neu1, syn1Neg + (irow * vectorLength), expTable, neu1e, + alpha, vectorLength, r == 0 ? 1 : 0, expLength, + infVector != nullptr); + } + } + + // if we don't train words - we skip start of idxSyn0 + int starter = trainWords == 1 ? 0 : contextWidth - numLabels; + + // propagate neu1e -> syn0 + if (infVector == nullptr) { + for (int c = starter; c < contextWidth; c++) { + if (lockedWords[c] == 1) continue; + + T *syn0word = syn0 + (context[c] * vectorLength); + + for (int i = 0; i < vectorLength; i++) { + syn0word[i] += neu1e[i]; + } + } + } else { + for (int i = 0; i < vectorLength; i++) { + infVector[i] += neu1e[i]; + } + } + + delete[] neu1; + delete[] neu1e; +} +BUILD_SINGLE_TEMPLATE(template void cbow_, + (void *syn0, void *syn1, void *syn1Neg, void *expTable, + void *vnegTable, void *vinfVector, int target, + int ngStarter, int *context, int *lockedWords, + int *indices, int8_t *codes, double alpha, + Nd4jLong randomValue, const int contextWidth, + const int hsRounds, const int nsRounds, + const int vocabSize, const int vectorLength, + const int expLength, const int negLength, + const int numLabels, const bool trainWords), + FLOAT_TYPES); + +template +void skipgram_(void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, + void *vnegTable, void *vinfVector, int target, int ngStarter, + int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, + const int hsRounds, const int nsRounds, const int vocabSize, + const int vectorLength, const int expLength, + const int negLength) { + auto syn0 = reinterpret_cast(vsyn0); + auto syn1 = reinterpret_cast(vsyn1); + auto syn1Neg = reinterpret_cast(vsyn1Neg); + auto expTable = reinterpret_cast(vexpTable); + auto negTable = reinterpret_cast(vnegTable); + auto infVector = reinterpret_cast(vinfVector); + + auto neu1e = new T[vectorLength]; + memset(neu1e, 0, vectorLength * sizeof(T)); + + // hierarchic softmax goes first (if enabled) + auto syn0row = + infVector != nullptr ? infVector : syn0 + (target * vectorLength); + auto irow = 0; + if (hsRounds > 0) { + for (int r = 0; r < hsRounds; r++) { + irow = indices[r]; + if (irow < 0 || irow >= vocabSize) break; + + hSoftmax_(syn0row, syn1 + (irow * vectorLength), expTable, neu1e, + alpha, vectorLength, codes[r], expLength, + infVector != nullptr); + } + } + + // negative sampling goes second (if enabled) + auto nsStarter = ngStarter; + irow = nsStarter; + if (nsRounds > 0) { + for (int r = 0; r < nsRounds + 1; r++) { + if (r == 0) { + // target is known in advance + } else { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + if (irow == nsStarter) continue; + } + + nSampling_(syn0row, syn1Neg + (irow * vectorLength), expTable, neu1e, + alpha, vectorLength, r == 0 ? 1 : 0, expLength, + infVector != nullptr); + } + } + + if (infVector == nullptr) { + for (int e = 0; e < vectorLength; e++) { + syn0row[e] += neu1e[e]; + } + } else { + for (int e = 0; e < vectorLength; e++) { + infVector[e] += neu1e[e]; + } + } + + delete[] neu1e; +} +BUILD_SINGLE_TEMPLATE(template void skipgram_, + (void *syn0, void *syn1, void *syn1Neg, void *expTable, + void *vnegTable, void *vinfVector, int target, + int ngStarter, int *indices, int8_t *codes, double alpha, + Nd4jLong randomValue, const int hsRounds, + const int nsRounds, const int vocabSize, + const int vectorLength, const int expLength, + const int negLength), + FLOAT_TYPES); + +int binarySearch(const int *haystack, const int needle, + const int totalElements) { + int firstIndex = 0; + int lastIndex = totalElements - 1; + int halfIndex = + sd::math::nd4j_floor((lastIndex + firstIndex) / (float)2); + + while (haystack[halfIndex] != needle && firstIndex < lastIndex) { + if (needle < haystack[halfIndex]) { + lastIndex = halfIndex - 1; + } else if (needle > haystack[halfIndex]) { + firstIndex = halfIndex + 1; + } + halfIndex = + sd::math::nd4j_floor((lastIndex + firstIndex) / (float)2); + } + + return (haystack[halfIndex] == needle) ? halfIndex : -1; +} + +template +static void do_update(const int target, const int rowIndex, const int count, + T *syn0, T *neu1t, const int vectorLength) { + auto syn0row = syn0 + (target * vectorLength); + auto neu1e = neu1t + (rowIndex * vectorLength); + for (int e = 0; e < vectorLength; e++) syn0row[e] += neu1e[e] / count; +} + +template +static void do_positive(const int target, const int postive, T *syn0, + T *syn1Neg, T *expTable, T *neu1e, const double alpha, + const int vectorLength, const int expLength) { + // nd4j_printf("Target: [%i]; Positive: [%i]; TID: [%i];\n", target, postive, + // omp_get_thread_num()); + nSampling_(syn0, syn1Neg, expTable, neu1e, alpha, vectorLength, 1, + expLength, false); +} + +template +static void do_negative(int target, int positive, T *syn0, T *syn1Neg, + T *expTable, T *negTable, T *neu1e, int *sStarters, + const double alpha, const unsigned long long rv, + const int vocabSize, const int vectorLength, + const int expLength, const int negLength, + const int nsRounds, const int numThreads, + const int numTargets) { + int irow = 0; + unsigned long long randomValue = rv; + for (int r = 0; r < nsRounds; r++) { + randomValue = sd::math::nd4j_abs( + randomValue * (unsigned long long)25214903917 + 11); + auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1; + + if (irow == positive) continue; + + // we shift irow here to guarantee independence + + int dim = irow % numThreads; + if (dim != omp_get_thread_num()) { + irow += (numThreads - dim + omp_get_thread_num()); + + // roll back to nearest affilated word + while (irow >= vocabSize) irow -= numThreads; + + // if this row was processed as first step somewhere - skip it + if (binarySearch(sStarters, irow, numTargets) > 0) { + r--; + continue; + } + } + + nSampling_(syn0, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, + vectorLength, 0, expLength, false); + } +} + +template +void skipgramBatchExec_(NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, + void *vnegTable, void *vinfVector, NDArray &targets, + NDArray &negStarters, NDArray &indices, NDArray &codes, + NDArray &lr, NDArray &nextRandom, const int nsRounds, + const int vocabSize, const int vectorLength, + const int expLength, const int negLength, + const bool preciseMode, const int numThreads) { + // auto syn0 = reinterpret_cast(vsyn0); + // auto syn1 = reinterpret_cast(vsyn1); + // auto syn1Neg = reinterpret_cast(vsyn1Neg); + const auto expTable = reinterpret_cast(vexpTable); + const auto negTable = reinterpret_cast(vnegTable); + const auto infVector = reinterpret_cast(vinfVector); + + // const auto numThreads = omp_get_max_threads(); + const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); + const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); + + // regular mode provides 0 guarantees for reproducibility + auto numTargets = targets.lengthOf(); + auto bTarget = targets.bufferAsT(); + auto bIndices = indices.bufferAsT(); + auto bCodes = codes.bufferAsT(); + + auto func = PRAGMA_THREADS_FOR { + T sneu1e[600]; + + for (auto t = start; t < stop; t++) { + T *neu1e = vectorLength <= 600 ? sneu1e : new T[vectorLength]; + memset(neu1e, 0, vectorLength * sizeof(T)); + + auto target = bTarget[t]; + auto alpha = lr.e(t); + unsigned long long randomValue = nextRandom.e(t); + + auto syn0row = + reinterpret_cast(s0.bufferWithOffset(target * vectorLength)); + + if (hsRounds > 0) { + int irow = 0; + auto cShift = t * idxShift; + + for (Nd4jLong e = 0; e < hsRounds; e++) { + irow = bIndices[e + cShift]; + if (irow < 0 || irow >= vocabSize) continue; + + auto syn1row = s1.bufferWithOffset(irow * vectorLength); + auto code = bCodes[e + cShift]; + + // nd4j_printf("syn0: [%i]; syn1: [%i]; code: [%i]\n", target, irow, + // code); + hSoftmax_(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, + code, expLength, false); + } + } + + if (nsRounds > 0) { + int irow = negStarters.e(t); + int nsStarter = irow; + for (int r = 0; r < nsRounds + 1; r++) { + if (r == 0) { + // target is known in advance + } else { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + + if (irow == nsStarter) continue; + } + + nSampling_(syn0row, s1n.bufferWithOffset(irow * vectorLength), + expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, + expLength, infVector != nullptr); + } + } + + for (int e = 0; e < vectorLength; e++) syn0row[e] += neu1e[e]; + + // optionally release temp arrays + if (vectorLength > 600) delete[] neu1e; + } + }; + + samediff::Threads::parallel_tad(func, 0, numTargets, 1, numThreads); +} +BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, + (NDArray & s0, NDArray &s1, NDArray &s1n, void *vexpTable, + void *vnegTable, void *vinfVector, NDArray &targets, + NDArray &negStarters, NDArray &indices, NDArray &codes, + NDArray &lr, NDArray &nextRandom, const int nsRounds, + const int vocabSize, const int vectorLength, + const int expLength, const int negLength, + const bool preciseMode, const int numThreads), + FLOAT_TYPES); + +template +void cbowBatchExec_(NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, + void *vnegTable, void *vinfVector, NDArray &context, + NDArray &lockedWords, NDArray &targets, + NDArray &negStarters, NDArray &indices, NDArray &codes, + NDArray &lr, NDArray &nextRandom, NDArray &nLabels, + const int nsRounds, const int vocabSize, + const int vectorLength, const int expLength, + const int negLength, const bool trainWords, + const int numThreads) { + const auto syn0 = s0.bufferAsT(); + const auto syn1 = s1.bufferAsT(); + const auto syn1Neg = s1n.bufferAsT(); + + const auto expTable = reinterpret_cast(vexpTable); + const auto negTable = reinterpret_cast(vnegTable); + const auto infVector = reinterpret_cast(vinfVector); + + // const auto numThreads = omp_get_max_threads(); + const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); + const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); + const auto numTargets = context.sizeAt(0); + const int contextWidth = context.sizeAt(1); + + const auto bContext = context.bufferAsT(); + const auto bLocker = lockedWords.bufferAsT(); + const auto bIndices = indices.bufferAsT(); + const auto bCodes = codes.bufferAsT(); + const auto bStarters = negStarters.bufferAsT(); + const auto numIndices = indices.isEmpty() ? 0 : indices.sizeAt(1); + + auto func = PRAGMA_THREADS_FOR { + T sneu1[600]; + T sneu1e[600]; + + for (auto e = start; e < stop; e++) { + T *neu1 = vectorLength <= 600 ? sneu1 : new T[vectorLength]; + T *neu1e = vectorLength <= 600 ? sneu1e : new T[vectorLength]; + + // optionally we nullify temp arrays after successful (and on first) cycle + memset(neu1, 0, sizeof(T) * vectorLength); + memset(neu1e, 0, sizeof(T) * vectorLength); + + auto alpha = lr.e(e); + auto numLabels = nLabels.isEmpty() ? 0 : nLabels.e(e); + + int actualContext = 0; + + // building neu1 for current window + for (int c = 0; c < contextWidth; c++) { + // getting next context word + auto cContext = bContext[c + (e * contextWidth)]; + + // skipping padded values + if (cContext < 0) continue; + + if (cContext >= vocabSize) + throw std::runtime_error("ContextID can't be >= vocab size"); + + T *syn0word = syn0 + (cContext * vectorLength); + + for (int i = 0; i < vectorLength; i++) neu1[i] += syn0word[i]; + + actualContext++; + } + + if (infVector != nullptr) actualContext++; + + if (actualContext > 1) { + for (int i = 0; i < vectorLength; i++) neu1[i] /= actualContext; + } + + // hierarchic softmax step + if (!indices.isEmpty()) { + for (Nd4jLong i = 0; i < numIndices; i++) { + const int cIndex = bIndices[(e * numIndices) + i]; + const int cCode = bCodes[(e * numIndices) + i]; + + // we're skipping padded values + if (cIndex < 0) continue; + + if (cIndex >= vocabSize) + throw std::runtime_error("Index can't be > vocab size"); + + hSoftmax_(neu1, syn1 + (cIndex * vectorLength), expTable, neu1e, + alpha, vectorLength, cCode, expLength, false); + } + } + + // negative sampling step + if (!negStarters.isEmpty() && nsRounds > 0) { + int irow = bStarters[e]; + const int nsStarter = irow; + unsigned long long randomValue = nextRandom.e(e); + + for (int r = 0; r < nsRounds + 1; r++) { + // we're skipping rng on 0 step + if (r != 0) { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + if (irow == nsStarter) continue; + + nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), + expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, + expLength, infVector != nullptr); + } else { + nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), + expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, + expLength, infVector != nullptr); + } + + // nd4j_printf("Thread <%i>: syn0: [%i]; s1n: [%i];\n", + // omp_get_thread_num(), 0, irow); } + } + + // if we're skipping labels + int starter = trainWords == 1 ? 0 : contextWidth - numLabels; + + // applying previously averaged results + for (int c = starter; c < contextWidth; c++) { + // getting context + auto cContext = bContext[c + (e * contextWidth)]; + auto cLock = bLocker[c + (e * contextWidth)]; + + // skipping padded values + if (cContext < 0 || cLock == 1) continue; + + if (cContext >= vocabSize) + throw std::runtime_error("ContextID can't be > vocab size"); + + // one word from context + T *syn0word = syn0 + (cContext * vectorLength); + + for (int i = 0; i < vectorLength; i++) syn0word[i] += neu1e[i]; + } + + // optionally release temp arrays + if (vectorLength > 600) { + delete[] neu1; + delete[] neu1e; + } } -} \ No newline at end of file + }; + + samediff::Threads::parallel_tad(func, 0, numTargets, 1, numThreads); +} +BUILD_SINGLE_TEMPLATE( + template void cbowBatchExec_, + (NDArray & s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, + void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, + NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, + NDArray &nextRandom, NDArray &nLabels, const int nsRounds, + const int vocabSize, const int vectorLength, const int expLength, + const int negLength, const bool trainWords, const int numThreads), + FLOAT_TYPES); + +void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, + NDArray &negTable, NDArray &target, NDArray &ngStarter, + int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, + NDArray &randomValue, NDArray &inferenceVector, + const bool preciseMode, const int numWorkers) { + auto xType = syn0.dataType(); + + // single round case + if ((ngStarter.isScalar() && !ngStarter.isEmpty()) || + (target.isScalar() && !target.isEmpty())) { + auto hsRounds = codes.lengthOf(); + + BUILD_SINGLE_SELECTOR( + xType, skipgram_, + (syn0.buffer(), syn1.buffer(), syn1Neg.buffer(), expTable.buffer(), + negTable.buffer(), inferenceVector.buffer(), + target.isEmpty() ? -1 : target.e(0), + ngStarter.isEmpty() ? -1 : ngStarter.e(0), + reinterpret_cast(indices.buffer()), + reinterpret_cast(codes.buffer()), alpha.e(0), + randomValue.e(0), hsRounds, nsRounds, (int)syn0.sizeAt(0), + (int)syn0.sizeAt(1), (int)expTable.lengthOf(), + (int)negTable.lengthOf()), + FLOAT_TYPES); + } else if (ngStarter.isVector() || target.isVector()) { + // batch mode + + BUILD_SINGLE_SELECTOR( + xType, skipgramBatchExec_, + (syn0, syn1, syn1Neg, expTable.buffer(), negTable.buffer(), nullptr, + target, ngStarter, indices, codes, alpha, randomValue, nsRounds, + syn0.sizeAt(0), syn0.sizeAt(1), expTable.lengthOf(), + negTable.lengthOf(), preciseMode, numWorkers), + FLOAT_TYPES); + } else + throw std::runtime_error("SkipGram: target must have rank 0 or 1"); +} + +void cbow(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, + NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, + NDArray &context, NDArray &lockedWords, NDArray &indices, + NDArray &codes, NDArray &alpha, NDArray &randomValue, + NDArray &numLabels, NDArray &inferenceVector, const bool trainWords, + int numWorkers) { + auto xType = syn0.dataType(); + + if ((context.rankOf() == 0 || context.rankOf() == 1) && + (indices.rankOf() == 1 || indices.rankOf() == 0)) { + // single round case + /*nd4j_printf("Row exec; ContextWidth: %i; LockedWords: %i; numLabels: %i; + Train words: %i\n", (int) context.lengthOf(), (int) lockedWords.lengthOf(), + numLabels.isEmpty() ? 0 : numLabels.e(0), (int) trainWords); if + (context.lengthOf() == 2) { context.printBuffer("context"); + lockedWords.printBuffer("locked"); + codes.printBuffer("codes"); + indices.printBuffer("indices"); + }*/ + + auto hsRounds = codes.lengthOf(); + + BUILD_SINGLE_SELECTOR( + xType, cbow_, + (syn0.buffer(), syn1.buffer(), syn1Neg.buffer(), expTable.buffer(), + negTable.buffer(), inferenceVector.buffer(), + target.isEmpty() ? -1 : target.e(0), + ngStarter.isEmpty() ? -1 : ngStarter.e(0), + reinterpret_cast(context.buffer()), + reinterpret_cast(lockedWords.buffer()), + reinterpret_cast(indices.buffer()), + reinterpret_cast(codes.buffer()), alpha.e(0), + randomValue.e(0), (int)context.lengthOf(), hsRounds, + nsRounds, (int)syn0.sizeAt(0), (int)syn0.sizeAt(1), + (int)expTable.lengthOf(), (int)negTable.lengthOf(), + numLabels.isEmpty() ? 0 : numLabels.e(0), trainWords), + FLOAT_TYPES); + } else if (context.rankOf() == 2 && indices.rankOf() == 2) { + // batch mode + // nd4j_printf("Batch exec\n",""); + + BUILD_SINGLE_SELECTOR( + xType, cbowBatchExec_, + (syn0, syn1, syn1Neg, expTable.buffer(), negTable.buffer(), nullptr, + context, lockedWords, target, ngStarter, indices, codes, alpha, + randomValue, numLabels, nsRounds, syn0.sizeAt(0), syn0.sizeAt(1), + expTable.lengthOf(), negTable.isEmpty() ? 0 : negTable.lengthOf(), + trainWords, numWorkers), + FLOAT_TYPES); + } else + throw std::runtime_error("CBOW: context must have rank 0/1 or 2"); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp b/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp index 9dfeac2ec16e..387bcf54a70b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp @@ -21,61 +21,65 @@ #include namespace sd { - namespace ops { - namespace helpers { - template - void rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto lambda = LAMBDA_T(x, shift) { - return x >> shift; - }; +namespace ops { +namespace helpers { +template +void rshift_bits_(LaunchContext *launchContext, NDArray &input, NDArray &output, + uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { return x >> shift; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } +void rshift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), rshift_bits_, + (launchContext, x, z, shift), INTEGER_TYPES); +} - template - void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto lambda = LAMBDA_T(x, shift) { - return x << shift; - }; +template +void shift_bits_(LaunchContext *launchContext, NDArray &input, NDArray &output, + uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { return x << shift; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } +void shift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), + INTEGER_TYPES); +} - template - void cyclic_rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto step = (sizeof(T) * 8) - shift; - auto lambda = LAMBDA_T(x, shift, step) { - return x >> shift | x << step; - }; +template +void cyclic_rshift_bits_(LaunchContext *launchContext, NDArray &input, + NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { return x >> shift | x << step; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } +void cyclic_rshift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_rshift_bits_, + (launchContext, x, z, shift), INTEGER_TYPES); +} - template - void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto step = (sizeof(T) * 8) - shift; - auto lambda = LAMBDA_T(x, shift, step) { - return x << shift | x >> step; - }; +template +void cyclic_shift_bits_(LaunchContext *launchContext, NDArray &input, + NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { return x << shift | x >> step; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } - } - } -} \ No newline at end of file +void cyclic_shift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, + (launchContext, x, z, shift), INTEGER_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/softmax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/softmax.cpp index bfd44629cf27..d0605848015d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/softmax.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/softmax.cpp @@ -19,237 +19,239 @@ // @author raver119@gmail.com // -#include +#include +#include #include +#include + #include -#include -#include namespace sd { - namespace ops { - namespace helpers { - - template - static void softMaxForVector_(void const* input, Nd4jLong const* inShapeInfo, void *output, Nd4jLong const* outShapeInfo) { - - auto inBuff = reinterpret_cast(input); - auto outBuff = reinterpret_cast(output); - - T max = -DataTypeUtils::max(); - T sum = 0.; - int inEWS = shape::elementWiseStride(inShapeInfo); - int outEWS = shape::elementWiseStride(outShapeInfo); - int length = shape::length(inShapeInfo); - - if (inEWS >= 1 && outEWS >= 1) { - - if (inEWS == 1 && outEWS == 1) { - - for (int i = 0; i < length; i++) - max = sd::math::nd4j_max(max, inBuff[i]); - - for (int i = 0; i < length; i++) { - outBuff[i] = sd::math::nd4j_exp(inBuff[i] - max); - sum += outBuff[i]; - } - - for (int i = 0; i < length; i++) - outBuff[i] /= sum; - } - else { - - for (int i = 0; i < length; i++) - max = sd::math::nd4j_max(max, inBuff[i * inEWS]); - - for (int i = 0; i < length; i++) { - T r = sd::math::nd4j_exp(inBuff[i * inEWS] - max); - outBuff[i * outEWS] = r; - sum += r; - } - - for (int i = 0; i < length; i++) - outBuff[i * outEWS] /= sum; - } - } - } - - /////////////////////////////////////////////////////////////////// - void softMaxForVector(sd::LaunchContext * context, const NDArray& input, NDArray& output) { - - if(!input.isVector() || !output.isVector()) - throw std::runtime_error("ops::helpers::softMaxForVector function: input and output arrays must be vectors !"); - - auto xType = input.dataType(); - BUILD_SINGLE_SELECTOR(xType, softMaxForVector_, (input.buffer(), input.shapeInfo(), output.buffer(), output.shapeInfo()), FLOAT_TYPES); - } - - template - void softmax_loop(const T* input, T *output, const Nd4jLong * offsets, Nd4jLong numOfSubArrs, uint32_t tadLen); +namespace ops { +namespace helpers { + +template +static void softMaxForVector_(void const* input, Nd4jLong const* inShapeInfo, + void* output, Nd4jLong const* outShapeInfo) { + auto inBuff = reinterpret_cast(input); + auto outBuff = reinterpret_cast(output); + + T max = -DataTypeUtils::max(); + T sum = 0.; + int inEWS = shape::elementWiseStride(inShapeInfo); + int outEWS = shape::elementWiseStride(outShapeInfo); + int length = shape::length(inShapeInfo); + + if (inEWS >= 1 && outEWS >= 1) { + if (inEWS == 1 && outEWS == 1) { + for (int i = 0; i < length; i++) + max = sd::math::nd4j_max(max, inBuff[i]); + + for (int i = 0; i < length; i++) { + outBuff[i] = sd::math::nd4j_exp(inBuff[i] - max); + sum += outBuff[i]; + } + + for (int i = 0; i < length; i++) outBuff[i] /= sum; + } else { + for (int i = 0; i < length; i++) + max = sd::math::nd4j_max(max, inBuff[i * inEWS]); + + for (int i = 0; i < length; i++) { + T r = sd::math::nd4j_exp(inBuff[i * inEWS] - max); + outBuff[i * outEWS] = r; + sum += r; + } + + for (int i = 0; i < length; i++) outBuff[i * outEWS] /= sum; + } + } +} + +/////////////////////////////////////////////////////////////////// +void softMaxForVector(sd::LaunchContext* context, const NDArray& input, + NDArray& output) { + if (!input.isVector() || !output.isVector()) + throw std::runtime_error( + "ops::helpers::softMaxForVector function: input and output arrays must " + "be vectors !"); + + auto xType = input.dataType(); + BUILD_SINGLE_SELECTOR( + xType, softMaxForVector_, + (input.buffer(), input.shapeInfo(), output.buffer(), output.shapeInfo()), + FLOAT_TYPES); +} + +template +void softmax_loop(const T* input, T* output, const Nd4jLong* offsets, + Nd4jLong numOfSubArrs, uint32_t tadLen); #ifdef _OPENMP - template <> - FORCEINLINE void softmax_loop(const float* input, float *output, const Nd4jLong * offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) { +template <> +FORCEINLINE void softmax_loop(const float* input, float* output, + const Nd4jLong* offsets, Nd4jLong numOfSubArrs, + uint32_t tadLen) { #pragma omp parallel for default(shared) - for (Nd4jLong i = 0; i < numOfSubArrs; i++) { - auto inBuff = input + offsets[i]; - auto outBuff = output + offsets[i]; - - float max = -DataTypeUtils::max(); - float sum = 0.f; - -#pragma omp simd reduction(max:max) - for (uint j = 0; j < tadLen; ++j) - max = sd::math::nd4j_max(max, inBuff[j]); - -#pragma omp simd reduction(+:sum) - for (uint j = 0; j < tadLen; ++j) { - float temp = sd::math::nd4j_exp(inBuff[j] - max); - outBuff[j] = temp; - sum += temp; - } - - for (uint j = 0; j < tadLen; ++j) - outBuff[j] /= sum; - } - } -#else - template <> - FORCEINLINE void softmax_loop(const float *input, float *output, const Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto inBuff = input + offsets[i]; - auto outBuff = output + offsets[i]; - - float max = -DataTypeUtils::max(); - float sum = 0.f; - - for (uint j = 0; j < tadLen; ++j) - max = sd::math::nd4j_max(max, inBuff[j]); - - for (uint j = 0; j < tadLen; ++j) { - float temp = sd::math::nd4j_exp(inBuff[j] - max); - outBuff[j] = temp; - sum += temp; - } + for (Nd4jLong i = 0; i < numOfSubArrs; i++) { + auto inBuff = input + offsets[i]; + auto outBuff = output + offsets[i]; + + float max = -DataTypeUtils::max(); + float sum = 0.f; + +#pragma omp simd reduction(max : max) + for (uint j = 0; j < tadLen; ++j) + max = sd::math::nd4j_max(max, inBuff[j]); + +#pragma omp simd reduction(+ : sum) + for (uint j = 0; j < tadLen; ++j) { + float temp = sd::math::nd4j_exp(inBuff[j] - max); + outBuff[j] = temp; + sum += temp; + } - for (uint j = 0; j < tadLen; ++j) - outBuff[j] /= sum; - } - }; + for (uint j = 0; j < tadLen; ++j) outBuff[j] /= sum; + } +} +#else +template <> +FORCEINLINE void softmax_loop(const float* input, float* output, + const Nd4jLong* offsets, Nd4jLong numOfSubArrs, + uint32_t tadLen) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inBuff = input + offsets[i]; + auto outBuff = output + offsets[i]; + + float max = -DataTypeUtils::max(); + float sum = 0.f; + + for (uint j = 0; j < tadLen; ++j) + max = sd::math::nd4j_max(max, inBuff[j]); + + for (uint j = 0; j < tadLen; ++j) { + float temp = sd::math::nd4j_exp(inBuff[j] - max); + outBuff[j] = temp; + sum += temp; + } + + for (uint j = 0; j < tadLen; ++j) outBuff[j] /= sum; + } + }; - samediff::Threads::parallel_tad(func,0, numOfSubArrs); - } + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); +} #endif +template +FORCEINLINE void softmax_loop(const T* input, T* output, + const Nd4jLong* offsets, Nd4jLong numOfSubArrs, + uint32_t tadLen) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inBuff = input + offsets[i]; + auto outBuff = output + offsets[i]; + + T max = -DataTypeUtils::max(); + T sum(0.f); + +#pragma omp simd reduction(maxT : max) + for (uint j = 0; j < tadLen; ++j) + max = sd::math::nd4j_max(max, inBuff[j]); + +#pragma omp simd reduction(sumT : sum) + for (uint j = 0; j < tadLen; ++j) { + T temp = sd::math::nd4j_exp(inBuff[j] - max); + outBuff[j] = temp; + sum += temp; + } + + for (uint j = 0; j < tadLen; ++j) outBuff[j] /= sum; + } + }; - template - FORCEINLINE void softmax_loop(const T *input, T *output, const Nd4jLong *offsets, Nd4jLong numOfSubArrs, uint32_t tadLen) { - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto inBuff = input + offsets[i]; - auto outBuff = output + offsets[i]; - - T max = -DataTypeUtils::max(); - T sum(0.f); - -#pragma omp simd reduction(maxT:max) - for (uint j = 0; j < tadLen; ++j) - max = sd::math::nd4j_max(max, inBuff[j]); - -#pragma omp simd reduction(sumT:sum) - for (uint j = 0; j < tadLen; ++j) { - T temp = sd::math::nd4j_exp(inBuff[j] - max); - outBuff[j] = temp; - sum += temp; - } - - for (uint j = 0; j < tadLen; ++j) - outBuff[j] /= sum; - } - }; - - samediff::Threads::parallel_tad(func,0, numOfSubArrs); - } + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); +} ////////////////////////////////////////////////////////////////////////// - template - static void softmax_(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { - - const int rank = input.rankOf(); - - if(input.isVector()) { - - if(rank == 1 || input.sizeAt(dimension) != 1) - softMaxForVector_(input.buffer(), input.shapeInfo(), output.buffer(), output.shapeInfo()); - else - output = 1.; - } - else if(input.isSameShapeStrict(output)) { - - TadPack tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimension); - auto tadShapeInfo = tadPack.primaryShapeInfo(); - auto tadOffsets = tadPack.primaryOffsets(); - const uint numOfSubArrs = tadPack.numberOfTads(); - const uint tadLen = shape::length(tadShapeInfo); - - if(shape::elementWiseStride(tadShapeInfo) == 1){ - auto inBuff = input.bufferAsT(); - T *outBuff = output.bufferAsT(); - - softmax_loop(inBuff, outBuff, tadOffsets, numOfSubArrs, tadLen); - } - else { - - uint inShapeInfoCast[MAX_RANK]; - bool canCast = sd::DataTypeUtils::castShapeInfo(tadShapeInfo, inShapeInfoCast); - - auto offsets = new Nd4jLong[tadLen]; - shape::calcOffsets(tadShapeInfo, offsets); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto inBuff = input.bufferAsT() + tadOffsets[i]; - auto outBuff = output.bufferAsT() + tadOffsets[i]; - - T max = -DataTypeUtils::max(); - T sum = 0.f; - - for (uint j = 0; j < tadLen; ++j) - max = sd::math::nd4j_max(max, inBuff[offsets[j]]); - - for (uint j = 0; j < tadLen; ++j) { - T temp = sd::math::nd4j_exp(inBuff[offsets[j]] - max); - outBuff[offsets[j]] = temp; - sum += temp; - } - - for (uint j = 0; j < tadLen; ++j) - outBuff[offsets[j]] /= sum; - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - - delete []offsets; - } - } - else { - NDArray max = input.reduceAlongDimension(sd::reduce::Max, {dimension}, true); - input.applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), max, output, false); - output.applyTransform(sd::transform::Exp, output); - NDArray sum = output.reduceAlongDimension(sd::reduce::Sum, {dimension}, true); - output /= sum; - } - } - - - /////////////////////////////////////////////////////////////////// - void softmax(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { +template +static void softmax_(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimension) { + const int rank = input.rankOf(); + + if (input.isVector()) { + if (rank == 1 || input.sizeAt(dimension) != 1) + softMaxForVector_(input.buffer(), input.shapeInfo(), output.buffer(), + output.shapeInfo()); + else + output = 1.; + } else if (input.isSameShapeStrict(output)) { + TadPack tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), dimension); + auto tadShapeInfo = tadPack.primaryShapeInfo(); + auto tadOffsets = tadPack.primaryOffsets(); + const uint numOfSubArrs = tadPack.numberOfTads(); + const uint tadLen = shape::length(tadShapeInfo); + + if (shape::elementWiseStride(tadShapeInfo) == 1) { + auto inBuff = input.bufferAsT(); + T* outBuff = output.bufferAsT(); + + softmax_loop(inBuff, outBuff, tadOffsets, numOfSubArrs, tadLen); + } else { + uint inShapeInfoCast[MAX_RANK]; + bool canCast = + sd::DataTypeUtils::castShapeInfo(tadShapeInfo, inShapeInfoCast); + + auto offsets = new Nd4jLong[tadLen]; + shape::calcOffsets(tadShapeInfo, offsets); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto inBuff = input.bufferAsT() + tadOffsets[i]; + auto outBuff = output.bufferAsT() + tadOffsets[i]; + + T max = -DataTypeUtils::max(); + T sum = 0.f; + + for (uint j = 0; j < tadLen; ++j) + max = sd::math::nd4j_max(max, inBuff[offsets[j]]); + + for (uint j = 0; j < tadLen; ++j) { + T temp = sd::math::nd4j_exp(inBuff[offsets[j]] - max); + outBuff[offsets[j]] = temp; + sum += temp; + } + + for (uint j = 0; j < tadLen; ++j) outBuff[offsets[j]] /= sum; + } + }; - BUILD_SINGLE_SELECTOR(input.dataType(), softmax_, (context, input, output, dimension), FLOAT_TYPES); - } + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - } + delete[] offsets; } -} \ No newline at end of file + } else { + NDArray max = + input.reduceAlongDimension(sd::reduce::Max, {dimension}, true); + input.applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), max, output, + false); + output.applyTransform(sd::transform::Exp, output); + NDArray sum = + output.reduceAlongDimension(sd::reduce::Sum, {dimension}, true); + output /= sum; + } +} + +/////////////////////////////////////////////////////////////////// +void softmax(sd::LaunchContext* context, const NDArray& input, NDArray& output, + const int dimension) { + BUILD_SINGLE_SELECTOR(input.dataType(), softmax_, + (context, input, output, dimension), FLOAT_TYPES); +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp index 24e8ef317e86..56e1287e7fa0 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp @@ -17,86 +17,103 @@ // // @author GS // -#include +#include "../solve.h" + #include #include #include #include +#include -#include "../triangular_solve.h" #include "../lup.h" -#include "../solve.h" +#include "../triangular_solve.h" namespace sd { namespace ops { namespace helpers { -// --------------------------------------------------------------------------------------------------------------------------------------- // - template - static void adjointMatrix_(sd::LaunchContext* context, NDArray const* input, NDArray* output) { - auto inputPart = input->allTensorsAlongDimension({-2, -1}); - auto outputPart = output->allTensorsAlongDimension({-2, -1}); - auto rows = input->sizeAt(-2); - output->assign(input); - - auto batchLoop = PRAGMA_THREADS_FOR { - for (auto batch = start; batch < stop; batch++) { - for (Nd4jLong r = 0; r < rows; r++) { - for (Nd4jLong c = 0; c < r; c++) { - math::nd4j_swap(outputPart[batch].t(r, c) , outputPart[batch].t(c, r)); - } - } - } - }; - samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1); - } - -// --------------------------------------------------------------------------------------------------------------------------------------- // - template - static int solveFunctor_(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) { - - // stage 1: LU decomposition batched - auto leftOutput = leftInput->ulike(); - auto permuShape = rightInput->getShapeAsVector(); permuShape.pop_back(); - auto permutations = NDArrayFactory::create('c', permuShape, context); - helpers::lu(context, leftInput, &leftOutput, &permutations); - auto P = leftInput->ulike(); //permutations batched matrix - P.nullify(); // to fill up matricies with zeros - auto PPart = P.allTensorsAlongDimension({-2,-1}); - auto permutationsPart = permutations.allTensorsAlongDimension({-1}); +// --------------------------------------------------------------------------------------------------------------------------------------- +// // +template +static void adjointMatrix_(sd::LaunchContext* context, NDArray const* input, + NDArray* output) { + auto inputPart = input->allTensorsAlongDimension({-2, -1}); + auto outputPart = output->allTensorsAlongDimension({-2, -1}); + auto rows = input->sizeAt(-2); + output->assign(input); - for (auto batch = 0; batch < permutationsPart.size(); ++batch) { - for (Nd4jLong row = 0; row < PPart[batch].rows(); ++row) { - PPart[batch].t(row, permutationsPart[batch].t(row)) = T(1.f); - } + auto batchLoop = PRAGMA_THREADS_FOR { + for (auto batch = start; batch < stop; batch++) { + for (Nd4jLong r = 0; r < rows; r++) { + for (Nd4jLong c = 0; c < r; c++) { + math::nd4j_swap(outputPart[batch].t(r, c), + outputPart[batch].t(c, r)); } + } + } + }; + samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1); +} - auto leftLower = leftOutput.dup(); - auto rightOutput = rightInput->ulike(); - auto rightPermuted = rightOutput.ulike(); - MmulHelper::matmul(&P, rightInput, &rightPermuted, 0, 0); - ResultSet leftLowerPart = leftLower.allTensorsAlongDimension({-2, -1}); - for (auto i = 0; i < leftLowerPart.size(); i++) { - for (Nd4jLong r = 0; r < leftLowerPart[i].rows(); r++) - leftLowerPart[i].t(r,r) = (T)1.f; - } - // stage 2: triangularSolveFunctor for Lower with given b - helpers::triangularSolveFunctor(context, &leftLower, &rightPermuted, true, false, &rightOutput); - // stage 3: triangularSolveFunctor for Upper with output of previous stage - helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, false, output); +// --------------------------------------------------------------------------------------------------------------------------------------- +// // +template +static int solveFunctor_(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool const adjoint, + NDArray* output) { + // stage 1: LU decomposition batched + auto leftOutput = leftInput->ulike(); + auto permuShape = rightInput->getShapeAsVector(); + permuShape.pop_back(); + auto permutations = NDArrayFactory::create('c', permuShape, context); + helpers::lu(context, leftInput, &leftOutput, &permutations); + auto P = leftInput->ulike(); // permutations batched matrix + P.nullify(); // to fill up matricies with zeros + auto PPart = P.allTensorsAlongDimension({-2, -1}); + auto permutationsPart = permutations.allTensorsAlongDimension({-1}); - return Status::OK(); + for (auto batch = 0; batch < permutationsPart.size(); ++batch) { + for (Nd4jLong row = 0; row < PPart[batch].rows(); ++row) { + PPart[batch].t(row, permutationsPart[batch].t(row)) = T(1.f); } + } -// --------------------------------------------------------------------------------------------------------------------------------------- // - int solveFunctor(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool const adjoint, NDArray* output) { - BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES); - } -// --------------------------------------------------------------------------------------------------------------------------------------- // - void adjointMatrix(sd::LaunchContext* context, NDArray const* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, (context, input, output), FLOAT_TYPES); - } -// --------------------------------------------------------------------------------------------------------------------------------------- // + auto leftLower = leftOutput.dup(); + auto rightOutput = rightInput->ulike(); + auto rightPermuted = rightOutput.ulike(); + MmulHelper::matmul(&P, rightInput, &rightPermuted, 0, 0); + ResultSet leftLowerPart = leftLower.allTensorsAlongDimension({-2, -1}); + for (auto i = 0; i < leftLowerPart.size(); i++) { + for (Nd4jLong r = 0; r < leftLowerPart[i].rows(); r++) + leftLowerPart[i].t(r, r) = (T)1.f; + } + // stage 2: triangularSolveFunctor for Lower with given b + helpers::triangularSolveFunctor(context, &leftLower, &rightPermuted, true, + false, &rightOutput); + // stage 3: triangularSolveFunctor for Upper with output of previous stage + helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, + false, output); + + return Status::OK(); } + +// --------------------------------------------------------------------------------------------------------------------------------------- +// // +int solveFunctor(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool const adjoint, NDArray* output) { + BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, + (context, leftInput, rightInput, adjoint, output), + FLOAT_TYPES); } +// --------------------------------------------------------------------------------------------------------------------------------------- +// // +void adjointMatrix(sd::LaunchContext* context, NDArray const* input, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, + (context, input, output), FLOAT_TYPES); } +// --------------------------------------------------------------------------------------------------------------------------------------- +// // +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/split.cpp b/libnd4j/include/ops/declarable/helpers/cpu/split.cpp index 48c6c490318d..fd70c0462444 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/split.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/split.cpp @@ -14,118 +14,117 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleh Semeniv (oleg.semeniv@gmail.com) - // +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// -#include #include +#include namespace sd { namespace ops { namespace helpers { +////////////////////////////////////////////////////////////////////////// +template +static void split_(const NDArray& input, const std::vector& outArrs, + const int axis) { + uint numSplits = outArrs.size(); - ////////////////////////////////////////////////////////////////////////// - template - static void split_(const NDArray& input, const std::vector& outArrs, const int axis) { - uint numSplits = outArrs.size(); - - const auto sizeofT = input.sizeOfT(); - - auto xBuff = input.bufferAsT(); - - bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || (axis == input.rankOf() - 1 && input.ordering() == 'f')) && input.ews() == 1; - - if (luckCase1) { - for (uint i = 0; i < numSplits; ++i) { - luckCase1 &= outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1; - if (!luckCase1) - break; - } - } - - if (luckCase1) { - - T* x = const_cast(xBuff); - for (uint i = 0; i < numSplits; ++i) { - const auto memAmountToCopy = outArrs[i]->lengthOf(); - memcpy(outArrs[i]->bufferAsT(), x, memAmountToCopy * sizeofT); - x += memAmountToCopy; - } - return; - } - - const bool isXcontin = input.strideAt(axis) == 1 && input.ordering() == 'c'; - bool areOutsContin = true; - bool allSameOrder = true; - - if (isXcontin) { - for (uint i = 0; i < numSplits; ++i) { - areOutsContin &= outArrs[i]->strideAt(axis) == 1; - allSameOrder &= outArrs[i]->ordering() == input.ordering(); - if (!areOutsContin || !allSameOrder) - break; - } - } - - const bool luckCase2 = isXcontin && areOutsContin && allSameOrder; + const auto sizeofT = input.sizeOfT(); - if (luckCase2) { + auto xBuff = input.bufferAsT(); - const auto xDim = input.sizeAt(axis); + bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || + (axis == input.rankOf() - 1 && input.ordering() == 'f')) && + input.ews() == 1; - for (Nd4jLong i = 0; i < input.lengthOf() / xDim; ++i) { + if (luckCase1) { + for (uint i = 0; i < numSplits; ++i) { + luckCase1 &= + outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1; + if (!luckCase1) break; + } + } + + if (luckCase1) { + T* x = const_cast(xBuff); + for (uint i = 0; i < numSplits; ++i) { + const auto memAmountToCopy = outArrs[i]->lengthOf(); + memcpy(outArrs[i]->bufferAsT(), x, memAmountToCopy * sizeofT); + x += memAmountToCopy; + } + return; + } + + const bool isXcontin = input.strideAt(axis) == 1 && input.ordering() == 'c'; + bool areOutsContin = true; + bool allSameOrder = true; + + if (isXcontin) { + for (uint i = 0; i < numSplits; ++i) { + areOutsContin &= outArrs[i]->strideAt(axis) == 1; + allSameOrder &= outArrs[i]->ordering() == input.ordering(); + if (!areOutsContin || !allSameOrder) break; + } + } - auto x = xBuff + xDim * i; + const bool luckCase2 = isXcontin && areOutsContin && allSameOrder; - for (uint j = 0; j < numSplits; ++j) { - const auto zDim = outArrs[j]->sizeAt(axis); - T* z = outArrs[j]->bufferAsT() + zDim * i; - memcpy(z, x, zDim * sizeofT); - z += zDim; - x += zDim; - } - } + if (luckCase2) { + const auto xDim = input.sizeAt(axis); - return; - } + for (Nd4jLong i = 0; i < input.lengthOf() / xDim; ++i) { + auto x = xBuff + xDim * i; - uint zDim = outArrs[0]->sizeAt(axis); - // general case + for (uint j = 0; j < numSplits; ++j) { + const auto zDim = outArrs[j]->sizeAt(axis); + T* z = outArrs[j]->bufferAsT() + zDim * i; + memcpy(z, x, zDim * sizeofT); + z += zDim; + x += zDim; + } + } - auto func = PRAGMA_THREADS_FOR{ + return; + } - int coords[MAX_RANK], temp; + uint zDim = outArrs[0]->sizeAt(axis); + // general case - for (auto i = start; i < stop; i += increment) { + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK], temp; - shape::index2coordsCPU(start, i, input.shapeInfo(), coords); - const auto xOffset = shape::getOffset(input.shapeInfo(), coords); + for (auto i = start; i < stop; i += increment) { + shape::index2coordsCPU(start, i, input.shapeInfo(), coords); + const auto xOffset = shape::getOffset(input.shapeInfo(), coords); - uint outArrIdx = 0; + uint outArrIdx = 0; - temp = coords[axis]; + temp = coords[axis]; - while (coords[axis] >= zDim) { - coords[axis] -= zDim; - ++outArrIdx; - } + while (coords[axis] >= zDim) { + coords[axis] -= zDim; + ++outArrIdx; + } - T* z = outArrs[outArrIdx]->bufferAsT(); - const auto zOffset = shape::getOffset(outArrs[outArrIdx]->shapeInfo(), coords); - z[zOffset] = xBuff[xOffset]; + T* z = outArrs[outArrIdx]->bufferAsT(); + const auto zOffset = + shape::getOffset(outArrs[outArrIdx]->shapeInfo(), coords); + z[zOffset] = xBuff[xOffset]; - coords[axis] = temp; - } - }; + coords[axis] = temp; + } + }; - samediff::Threads::parallel_for(func, 0, input.lengthOf()); - } + samediff::Threads::parallel_for(func, 0, input.lengthOf()); +} - void split(sd::LaunchContext* context, const NDArray& input, std::vector& outArrs, const int axis) { - BUILD_SINGLE_SELECTOR(input.dataType(), split_, (input, outArrs, axis), LIBND4J_TYPES); - } - } - } +void split(sd::LaunchContext* context, const NDArray& input, + std::vector& outArrs, const int axis) { + BUILD_SINGLE_SELECTOR(input.dataType(), split_, (input, outArrs, axis), + LIBND4J_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp b/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp index ecd5ead2be2d..e106750b3f77 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp @@ -15,337 +15,366 @@ ******************************************************************************/ // -// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 [cs.CL] 12 Sep 2017 +// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 +// [cs.CL] 12 Sep 2017 // // @author Yurii Shyrma, created on 05.12.2017 // -#include #include -#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// static FORCEINLINE NDArray activation(const NDArray& arr) { - - // return (const_cast&>(arr)).template transform>(); - auto result = NDArray(&arr, false, arr.getContext()); - (const_cast(arr)).applyTransform(transform::Tanh, result); - return result; + // return (const_cast&>(arr)).template + // transform>(); + auto result = NDArray(&arr, false, arr.getContext()); + (const_cast(arr)).applyTransform(transform::Tanh, result); + return result; } - ////////////////////////////////////////////////////////////////////////// static FORCEINLINE NDArray sigmoid(const NDArray& arr) { - return (const_cast(arr)).transform(transform::Sigmoid); + return (const_cast(arr)).transform(transform::Sigmoid); } - ////////////////////////////////////////////////////////////////////////// -void sruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) { - - // x input [bS x inSize], bS - batch size, inSize - number of features - // c0 previous cell state c [bS x inSize], that is at previous time step t-1 - // w weights [inSize x 3*inSize] - // b biases [2*inSize] +void sruCell(sd::LaunchContext* context, const NDArray* x, const NDArray* c0, + const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) { + // x input [bS x inSize], bS - batch size, inSize - number of features + // c0 previous cell state c [bS x inSize], that is at previous time step t-1 + // w weights [inSize x 3*inSize] + // b biases [2*inSize] - // h current cell output [bS x inSize], that is at current time step t - // c current cell state [bS x inSize], that is at current time step t + // h current cell output [bS x inSize], that is at current time step t + // c current cell state [bS x inSize], that is at current time step t - const int inSize = x->sizeAt(1); // inSize - number of features + const int inSize = x->sizeAt(1); // inSize - number of features - auto z = mmul(*x, *w); // [bS x 3*inSize] + auto z = mmul(*x, *w); // [bS x 3*inSize] - // forget gate = sigmoid(x*Wf + bf) - auto f = sigmoid(z({0,0, inSize, 2*inSize}) + (*b)({0, inSize})); + // forget gate = sigmoid(x*Wf + bf) + auto f = sigmoid(z({0, 0, inSize, 2 * inSize}) + (*b)({0, inSize})); - // reset gate = sigmoid(x*Wr + br) - auto r = sigmoid(z({0,0, 2*inSize, 3*inSize}) + (*b)({inSize, 2*inSize})); + // reset gate = sigmoid(x*Wr + br) + auto r = + sigmoid(z({0, 0, 2 * inSize, 3 * inSize}) + (*b)({inSize, 2 * inSize})); - // ◦ means element-wise product or so called Hadamard product - // current sell state = f◦c0 + (1 - f)◦(x*Wc) - c->assign(f * (*c0) + (1.f - f) * z({0, 0 ,0, inSize}) ); - // *c = f*(*c0 - z({},{0, inSize})) + z({{},{0, inSize}}); + // ◦ means element-wise product or so called Hadamard product + // current sell state = f◦c0 + (1 - f)◦(x*Wc) + c->assign(f * (*c0) + (1.f - f) * z({0, 0, 0, inSize})); + // *c = f*(*c0 - z({},{0, inSize})) + z({{},{0, inSize}}); - // current cell output = r◦activation(c) + (1 - r)◦x - h->assign( r * activation(*c) + (1.f - r) * (*x) ); - // *h = r * (activation(c) - *x) + *x; + // current cell output = r◦activation(c) + (1 - r)◦x + h->assign(r * activation(*c) + (1.f - r) * (*x)); + // *h = r * (activation(c) - *x) + *x; } - ////////////////////////////////////////////////////////////////////////// -void sruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) { +void sruTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* c0, const NDArray* w, const NDArray* b, + NDArray* h, NDArray* c) { + // x input [bS x inSize x time] + // c0 initial cell state (at time step = 0) [bS x inSize], + // w weights, [3*inSize x inSize] + // b biases, [2*inSize] - // x input [bS x inSize x time] - // c0 initial cell state (at time step = 0) [bS x inSize], - // w weights, [3*inSize x inSize] - // b biases, [2*inSize] + // h cell outputs [bS x inSize x time] + // c cell states [bS x inSize x time] - // h cell outputs [bS x inSize x time] - // c cell states [bS x inSize x time] + auto wT = w->transpose(); // [3*inSize x inSize] -> [inSize x 3*inSize] - auto wT = w->transpose(); // [3*inSize x inSize] -> [inSize x 3*inSize] + const int time = x->sizeAt(2); - const int time = x->sizeAt(2); + NDArray ct_1(*c0); - NDArray ct_1(*c0); + // loop through time steps + for (int t = 0; t < time; ++t) { + auto xt = (*x)({0, 0, 0, 0, t, t + 1}); + auto ht = (*h)({0, 0, 0, 0, t, t + 1}); + auto ct = (*c)({0, 0, 0, 0, t, t + 1}); - // loop through time steps - for (int t = 0; t < time; ++t) { - - auto xt = (*x)({0,0, 0,0, t,t+1}); - auto ht = (*h)({0,0, 0,0, t,t+1}); - auto ct = (*c)({0,0, 0,0, t,t+1}); - - helpers::sruCell(context, &xt, &ct_1, &wT, b, &ht, &ct); - ct_1.assign(ct); - } + helpers::sruCell(context, &xt, &ct_1, &wT, b, &ht, &ct); + ct_1.assign(ct); + } } ////////////////////////////////////////////////////////////////////////// template -static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) { - - // x input 3d tensor [time x bS x 2*K], time - number of time steps, bS - batch size, K - number of features - // w 2d tensor of weights [2*K x 6*K] - // b row of biases with twice length [4*K] - // c0 2d tensor of initial state [bS x 2*K] at time t=0 - // mask optional, 2d tensor of dropout mask [bS x 2*K] - - // ht [time x bS x 2*K] - // ct [time x bS x 2*K] - - const Nd4jLong time = x->sizeAt(0); // time - number of time steps - const Nd4jLong bS = x->sizeAt(1); // bS - batch size - const Nd4jLong K = x->sizeAt(2) / 2; // K - number of features - - // x = x * mask - if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask - - // U = x * w - NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] - - const Nd4jLong d2 = 2*K; - const Nd4jLong ncols = bS*d2; - const Nd4jLong ncolsWi = 3*ncols; - - T* pI = x->bufferAsT(); - T* pWi = wi.bufferAsT(); - T* pBias = const_cast(b)->bufferAsT(); - T* pInit = const_cast(c0)->bufferAsT(); - T* pMask = mask ? const_cast(mask)->bufferAsT() : nullptr; - T* pHt = ht->bufferAsT(); - T* pCt = ct->bufferAsT(); - - auto func = PRAGMA_THREADS_FOR { - for (auto col = start; col < stop; col++) { - const auto colNum = col % d2; - bool flip = colNum >= K; - T maskVal = mask ? *(pMask + col) : T(1); - T cur = *(pInit + col); - T bF = *(pBias + colNum); - T bR = *(pBias + colNum + d2); - T *pWiVal = pWi + 3 * col; - T *pIVal = pI + col; - T *pHtVal = pHt + col; - T *pCtVal = pCt + col; - - if (flip) { - const auto step = (time - 1) * ncols; - pIVal += step; - pHtVal += step; - pCtVal += step; - pWiVal += (time - 1) * ncolsWi; - } - - auto ncolsRev = flip ? -ncols : ncols; - auto ncolsWiRev = flip ? -ncolsWi : ncolsWi; - - for (Nd4jLong t = 0; t < time; ++t) { - // evaluate sigmoids - T ft = (1.) / (1. + sd::math::nd4j_exp(-(pWiVal[1] + bF))); - T rt = (1.) / (1. + sd::math::nd4j_exp(-(pWiVal[2] + bR))); - - cur = (cur - *pWiVal) * ft + *pWiVal; - *pCtVal = cur; - T val = sd::math::nd4j_tanh(cur); - *pHtVal = (val * maskVal - *pIVal) * rt + *pIVal; - - pIVal += ncolsRev; - pWiVal += ncolsWiRev; - pCtVal += ncolsRev; - pHtVal += ncolsRev; - } - } - }; - - samediff::Threads::parallel_tad(func, 0, ncols); +static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, + const NDArray* c0, const NDArray* mask, NDArray* ht, + NDArray* ct) { + // x input 3d tensor [time x bS x 2*K], time - number of time steps, bS - + // batch size, K - number of features w 2d tensor of weights [2*K x 6*K] + // b row of biases with twice length [4*K] + // c0 2d tensor of initial state [bS x 2*K] at time t=0 + // mask optional, 2d tensor of dropout mask [bS x 2*K] + + // ht [time x bS x 2*K] + // ct [time x bS x 2*K] + + const Nd4jLong time = x->sizeAt(0); // time - number of time steps + const Nd4jLong bS = x->sizeAt(1); // bS - batch size + const Nd4jLong K = x->sizeAt(2) / 2; // K - number of features + + // x = x * mask + if (mask) + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask + + // U = x * w + NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] + + const Nd4jLong d2 = 2 * K; + const Nd4jLong ncols = bS * d2; + const Nd4jLong ncolsWi = 3 * ncols; + + T* pI = x->bufferAsT(); + T* pWi = wi.bufferAsT(); + T* pBias = const_cast(b)->bufferAsT(); + T* pInit = const_cast(c0)->bufferAsT(); + T* pMask = mask ? const_cast(mask)->bufferAsT() : nullptr; + T* pHt = ht->bufferAsT(); + T* pCt = ct->bufferAsT(); + + auto func = PRAGMA_THREADS_FOR { + for (auto col = start; col < stop; col++) { + const auto colNum = col % d2; + bool flip = colNum >= K; + T maskVal = mask ? *(pMask + col) : T(1); + T cur = *(pInit + col); + T bF = *(pBias + colNum); + T bR = *(pBias + colNum + d2); + T* pWiVal = pWi + 3 * col; + T* pIVal = pI + col; + T* pHtVal = pHt + col; + T* pCtVal = pCt + col; + + if (flip) { + const auto step = (time - 1) * ncols; + pIVal += step; + pHtVal += step; + pCtVal += step; + pWiVal += (time - 1) * ncolsWi; + } + + auto ncolsRev = flip ? -ncols : ncols; + auto ncolsWiRev = flip ? -ncolsWi : ncolsWi; + + for (Nd4jLong t = 0; t < time; ++t) { + // evaluate sigmoids + T ft = (1.) / (1. + sd::math::nd4j_exp(-(pWiVal[1] + bF))); + T rt = (1.) / (1. + sd::math::nd4j_exp(-(pWiVal[2] + bR))); + + cur = (cur - *pWiVal) * ft + *pWiVal; + *pCtVal = cur; + T val = sd::math::nd4j_tanh(cur); + *pHtVal = (val * maskVal - *pIVal) * rt + *pIVal; + + pIVal += ncolsRev; + pWiVal += ncolsWiRev; + pCtVal += ncolsRev; + pHtVal += ncolsRev; + } + } + }; + + samediff::Threads::parallel_tad(func, 0, ncols); } ////////////////////////////////////////////////////////////////////////// template -static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradHt, const NDArray* mask, - NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) { - - // x input 3d tensor [time x bS x 2*K], time - number of time steps, bS - batch size, K - number of features - // w 2d tensor of weights [2*K x 6*K] - // b row of biases with twice length 4*K] - // c0 2d tensor of initial state [bS x 2*K] at time t=0 - // ct [time x bS x 2*K] - // inGradC0 [bS x 2*K] - // inGradHt [time x bS x 2*K] - // mask optional, 2d tensor of dropout mask [bS x 2*K] - - // gradI [time x bS x 2*K] - // gradW [time x 2*K x 6*K] - // gradB [4*K] - // gradC0 [bS x 2*K] - - const Nd4jLong time = x->sizeAt(0); // time - number of time steps - const Nd4jLong bS = x->sizeAt(1); - const Nd4jLong K = x->sizeAt(2) / 2; - - // x = x * mask - if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask - - // U = x * w - NDArray wi = mmul(*x, *w); // [time x bS x 2*K] * [2*K x 6*K] = [time x bS x 6*K] - NDArray gradBias(x->ordering(), {bS, 4*K}, x->dataType(), x->getContext()); - NDArray gradWi (x->ordering(), {time, bS, 6*K}, x->dataType(), x->getContext()); - - const Nd4jLong d2 = 2*K; - const Nd4jLong ncols = bS*d2; - const Nd4jLong ncolsWi = 3*ncols; - T* pInput = x->bufferAsT(); - T* pWi = wi.bufferAsT(); - T* pBias = const_cast(b)->bufferAsT(); - T* pInit = const_cast(c0)->bufferAsT(); - T* pMask = mask ? const_cast(mask)->bufferAsT() : nullptr; - T* pState = const_cast(ct)->bufferAsT(); - T* pInGradCt = const_cast(inGradC0)->bufferAsT(); - T* pInGradHt = const_cast(inGradHt)->bufferAsT(); - T* pGradWi = gradWi.bufferAsT(); - T* pGradInput = gradI->bufferAsT(); - T* pGradBias = gradBias.bufferAsT(); - T* pGradInit = gradC0->bufferAsT(); - - auto func = PRAGMA_THREADS_FOR { - for (auto col = start; col < stop; col++) { - T gbF = 0.f; - T gbR = 0.f; - const auto colNum = col % d2; - const bool flip = colNum >= K; - T maskVal = mask ? *(pMask + col) : T(1.); - T cur = *(pInGradCt + col); - T bF = *(pBias + colNum); - T bR = *(pBias + colNum + d2); - T *pWiVal = pWi + 3 * col; - T *pInputVal = pInput + col; - T *pStateVal = pState + col; - T *pInGradHtVal = pInGradHt + col; - T *pGradWiVal = pGradWi + 3 * col; - T *pGradInputVal = pGradInput + col; - - if (!flip) { - const auto stepI = (time - 1) * ncols; - const auto stepW = (time - 1) * ncolsWi; - pInputVal += stepI; - pStateVal += stepI; - pInGradHtVal += stepI; - pGradInputVal += stepI; - pWiVal += stepW; - pGradWiVal += stepW; - } - - Nd4jLong ncolsRev = flip ? -ncols : ncols; - Nd4jLong ncolsWiRev = flip ? -ncolsWi : ncolsWi; - - for (Nd4jLong t = 0; t < time; ++t) { - // evaluate sigmoids - T ft = ((T) 1.) / ((T) 1. + sd::math::nd4j_exp(-(*(pWiVal + 1) + bF))); - T rt = ((T) 1.) / ((T) 1. + sd::math::nd4j_exp(-(*(pWiVal + 2) + bR))); - - T val = sd::math::nd4j_tanh(*pStateVal); - T prevVal = (t < time - 1) ? (*(pStateVal - ncolsRev)) : (*(pInit + col)); - // grad wrt input - *pGradInputVal = *pInGradHtVal - (*pInGradHtVal) * rt; - // grad wrt rt, wiR and bR - T grt = (*pInGradHtVal) * (val * maskVal - *pInputVal) * (rt - rt * rt); - *(pGradWiVal + 2) = grt; - gbR += grt; - // grad wrt state - T gradSateVal = (*pInGradHtVal) * maskVal * (rt - rt * val * val) + cur; - // grad wrt wi0 - *pGradWiVal = gradSateVal - gradSateVal * ft; - // grad wrt ft, wi1, and bF - T gft = gradSateVal * (prevVal - *pWiVal) * (ft - ft * ft); - *(pGradWiVal + 1) = gft; - gbF += gft; - // grad wrt c_previous - cur = gradSateVal * ft; - - pInputVal -= ncolsRev; - pWiVal -= ncolsWiRev; - pStateVal -= ncolsRev; - pGradWiVal -= ncolsWiRev; - pGradInputVal -= ncolsRev; - pInGradHtVal -= ncolsRev; - } - *(pGradBias + col) = gbF; - *(pGradBias + col + ncols) = gbR; - *(pGradInit + col) = cur; - } - }; - - samediff::Threads::parallel_tad(func, 0, ncols); - - // gradB - gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K] - - // gradW - x->permutei({0, 2, 1}); // [time x bS x 2*K] -> [time x 2*K x bS] - MmulHelper::mmul(x, &gradWi, gradW, 1., 0.); // [time x 2*K x bS ] * [time x bS x 6*K] = [time x 2*K x 6*K] -} - - -void sruBI(sd::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) { - BUILD_SINGLE_SELECTOR(x->dataType(), sruBI_, (x, w, b, c0, mask, ht, ct), FLOAT_TYPES); -} -void sruBIBP(sd::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) { - BUILD_SINGLE_SELECTOR(x->dataType(), sruBIBP_, (x, w, b, c0, ct, inGradC0, inGradH, mask, gradI, gradW, gradB, gradC0), FLOAT_TYPES); -} - +static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, + const NDArray* c0, const NDArray* ct, + const NDArray* inGradC0, const NDArray* inGradHt, + const NDArray* mask, NDArray* gradI, NDArray* gradW, + NDArray* gradB, NDArray* gradC0) { + // x input 3d tensor [time x bS x 2*K], time - number of time steps, bS - + // batch size, K - number of features w 2d tensor of weights [2*K x 6*K] b + // row of biases with twice length 4*K] c0 2d tensor of initial state [bS x + // 2*K] at time t=0 ct [time x bS x 2*K] inGradC0 [bS x 2*K] inGradHt [time x + // bS x 2*K] mask optional, 2d tensor of dropout mask [bS x 2*K] + + // gradI [time x bS x 2*K] + // gradW [time x 2*K x 6*K] + // gradB [4*K] + // gradC0 [bS x 2*K] + + const Nd4jLong time = x->sizeAt(0); // time - number of time steps + const Nd4jLong bS = x->sizeAt(1); + const Nd4jLong K = x->sizeAt(2) / 2; + + // x = x * mask + if (mask) + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask + + // U = x * w + NDArray wi = + mmul(*x, *w); // [time x bS x 2*K] * [2*K x 6*K] = [time x bS x 6*K] + NDArray gradBias(x->ordering(), {bS, 4 * K}, x->dataType(), x->getContext()); + NDArray gradWi(x->ordering(), {time, bS, 6 * K}, x->dataType(), + x->getContext()); + + const Nd4jLong d2 = 2 * K; + const Nd4jLong ncols = bS * d2; + const Nd4jLong ncolsWi = 3 * ncols; + T* pInput = x->bufferAsT(); + T* pWi = wi.bufferAsT(); + T* pBias = const_cast(b)->bufferAsT(); + T* pInit = const_cast(c0)->bufferAsT(); + T* pMask = mask ? const_cast(mask)->bufferAsT() : nullptr; + T* pState = const_cast(ct)->bufferAsT(); + T* pInGradCt = const_cast(inGradC0)->bufferAsT(); + T* pInGradHt = const_cast(inGradHt)->bufferAsT(); + T* pGradWi = gradWi.bufferAsT(); + T* pGradInput = gradI->bufferAsT(); + T* pGradBias = gradBias.bufferAsT(); + T* pGradInit = gradC0->bufferAsT(); + + auto func = PRAGMA_THREADS_FOR { + for (auto col = start; col < stop; col++) { + T gbF = 0.f; + T gbR = 0.f; + const auto colNum = col % d2; + const bool flip = colNum >= K; + T maskVal = mask ? *(pMask + col) : T(1.); + T cur = *(pInGradCt + col); + T bF = *(pBias + colNum); + T bR = *(pBias + colNum + d2); + T* pWiVal = pWi + 3 * col; + T* pInputVal = pInput + col; + T* pStateVal = pState + col; + T* pInGradHtVal = pInGradHt + col; + T* pGradWiVal = pGradWi + 3 * col; + T* pGradInputVal = pGradInput + col; + + if (!flip) { + const auto stepI = (time - 1) * ncols; + const auto stepW = (time - 1) * ncolsWi; + pInputVal += stepI; + pStateVal += stepI; + pInGradHtVal += stepI; + pGradInputVal += stepI; + pWiVal += stepW; + pGradWiVal += stepW; + } + + Nd4jLong ncolsRev = flip ? -ncols : ncols; + Nd4jLong ncolsWiRev = flip ? -ncolsWi : ncolsWi; + + for (Nd4jLong t = 0; t < time; ++t) { + // evaluate sigmoids + T ft = + ((T)1.) / ((T)1. + sd::math::nd4j_exp(-(*(pWiVal + 1) + bF))); + T rt = + ((T)1.) / ((T)1. + sd::math::nd4j_exp(-(*(pWiVal + 2) + bR))); + + T val = sd::math::nd4j_tanh(*pStateVal); + T prevVal = + (t < time - 1) ? (*(pStateVal - ncolsRev)) : (*(pInit + col)); + // grad wrt input + *pGradInputVal = *pInGradHtVal - (*pInGradHtVal) * rt; + // grad wrt rt, wiR and bR + T grt = (*pInGradHtVal) * (val * maskVal - *pInputVal) * (rt - rt * rt); + *(pGradWiVal + 2) = grt; + gbR += grt; + // grad wrt state + T gradSateVal = (*pInGradHtVal) * maskVal * (rt - rt * val * val) + cur; + // grad wrt wi0 + *pGradWiVal = gradSateVal - gradSateVal * ft; + // grad wrt ft, wi1, and bF + T gft = gradSateVal * (prevVal - *pWiVal) * (ft - ft * ft); + *(pGradWiVal + 1) = gft; + gbF += gft; + // grad wrt c_previous + cur = gradSateVal * ft; + + pInputVal -= ncolsRev; + pWiVal -= ncolsWiRev; + pStateVal -= ncolsRev; + pGradWiVal -= ncolsWiRev; + pGradInputVal -= ncolsRev; + pInGradHtVal -= ncolsRev; + } + *(pGradBias + col) = gbF; + *(pGradBias + col + ncols) = gbR; + *(pGradInit + col) = cur; + } + }; -BUILD_SINGLE_TEMPLATE(template void sruBI_, (NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct), FLOAT_TYPES); -BUILD_SINGLE_TEMPLATE(template void sruBIBP_, (NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0), FLOAT_TYPES); + samediff::Threads::parallel_tad(func, 0, ncols); + // gradB + gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K] + // gradW + x->permutei({0, 2, 1}); // [time x bS x 2*K] -> [time x 2*K x bS] + MmulHelper::mmul( + x, &gradWi, gradW, 1., + 0.); // [time x 2*K x bS ] * [time x bS x 6*K] = [time x 2*K x 6*K] } + +void sruBI(sd::LaunchContext* context, NDArray* x, const NDArray* w, + const NDArray* b, const NDArray* c0, const NDArray* mask, + NDArray* ht, NDArray* ct) { + BUILD_SINGLE_SELECTOR(x->dataType(), sruBI_, (x, w, b, c0, mask, ht, ct), + FLOAT_TYPES); } +void sruBIBP(sd::LaunchContext* context, NDArray* x, const NDArray* w, + const NDArray* b, const NDArray* c0, const NDArray* ct, + const NDArray* inGradC0, const NDArray* inGradH, + const NDArray* mask, NDArray* gradI, NDArray* gradW, + NDArray* gradB, NDArray* gradC0) { + BUILD_SINGLE_SELECTOR( + x->dataType(), sruBIBP_, + (x, w, b, c0, ct, inGradC0, inGradH, mask, gradI, gradW, gradB, gradC0), + FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template void sruBI_, + (NDArray * x, const NDArray* w, const NDArray* b, + const NDArray* c0, const NDArray* mask, NDArray* ht, + NDArray* ct), + FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void sruBIBP_, + (NDArray * x, const NDArray* w, const NDArray* b, + const NDArray* c0, const NDArray* ct, + const NDArray* inGradC0, const NDArray* inGradH, + const NDArray* mask, NDArray* gradI, NDArray* gradW, + NDArray* gradB, NDArray* gradC0), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd ////////////////////////////////////////////////////////////////////////// // template -// void sruCellBP(const std::vector*>& inArrs, const std::vector*>& outArrs) { - -// NDArray* x = inArrs[0]; // input [bS x inSize], bS - batch size, inSize - number of features -// NDArray* c0 = inArrs[1]; // previous cell state c [bS x inSize], that is at previous time step t-1 -// NDArray* w = inArrs[2]; // weights [inSize x 3*inSize] -// NDArray* b = inArrs[3]; // biases [2*inSize] -// NDArray* dLdC = inArrs[4]; // gradient of the loss func with respect to cell output [bS x inSize] -// NDArray* dLdH = inArrs[5]; // gradient of the loss func with respect to cell state [bS x inSize] - -// NDArray* dLdX = outArrs[0]; // gradient of the loss func with respect to input [bS x inSize], so called epsilon -// NDArray* dLdW = outArrs[1]; // gradient of the loss func with respect to weights [inSize x 3*inSize] -// NDArray* dLdB = outArrs[2]; // gradient of the loss func with respect to biases [2*inSize] -// NDArray* dLdC0 = outArrs[3]; // gradient of the loss func with respect to previous cell state [bS, inSize] +// void sruCellBP(const std::vector*>& inArrs, const +// std::vector*>& outArrs) { + +// NDArray* x = inArrs[0]; // input [bS x inSize], bS - +// batch size, inSize - number of features NDArray* c0 = inArrs[1]; // +// previous cell state c [bS x inSize], that is at previous time step t-1 +// NDArray* w = inArrs[2]; // weights [inSize x +// 3*inSize] NDArray* b = inArrs[3]; // biases +// [2*inSize] NDArray* dLdC = inArrs[4]; // gradient of the +// loss func with respect to cell output [bS x inSize] NDArray* dLdH = +// inArrs[5]; // gradient of the loss func with respect to +// cell state [bS x inSize] + +// NDArray* dLdX = outArrs[0]; // gradient of the loss func +// with respect to input [bS x inSize], so called epsilon NDArray* dLdW +// = outArrs[1]; // gradient of the loss func with respect to +// weights [inSize x 3*inSize] NDArray* dLdB = outArrs[2]; // gradient +// of the loss func with respect to biases [2*inSize] NDArray* dLdC0 = +// outArrs[3]; // gradient of the loss func with respect to +// previous cell state [bS, inSize] // const int inSize = x->sizeAt(1); // inSize - number of features @@ -353,16 +382,17 @@ BUILD_SINGLE_TEMPLATE(template void sruBIBP_, (NDArray* x, const NDArray* w, con // NDArray z = mmul(*x, *w); // [bS x 3*inSize] // // forget gate = sigmoid(x*Wf + bf) -// NDArray f = sigmoid(z({{},{inSize, 2*inSize}}) + (*b)({{0, inSize}})); // [bS, inSize] -// NDArray oneMinusF = 1. - f; +// NDArray f = sigmoid(z({{},{inSize, 2*inSize}}) + (*b)({{0, +// inSize}})); // [bS, inSize] NDArray oneMinusF = 1. - f; // // reset gate = sigmoid(x*Wr + br) -// NDArray r = sigmoid(z({{},{2*inSize, 3*inSize}}) + (*b)({{inSize, 2*inSize}})); // [bS, inSize] -// NDArray oneMinusR = 1. - r; - -// // current sell state = f◦c0 + (1 - f)◦(x*Wc) ---> c->assign( f*(*c0) + ((T)1. - f) * z({{},{0, inSize}}) ); -// // current cell output = r◦activation(c) + (1 - r)◦x ---> h->assign( r*activation(*c) + ((T)1. - r) * (*x) ); +// NDArray r = sigmoid(z({{},{2*inSize, 3*inSize}}) + (*b)({{inSize, +// 2*inSize}})); // [bS, inSize] NDArray oneMinusR = 1. - r; +// // current sell state = f◦c0 + (1 - f)◦(x*Wc) ---> c->assign( +// f*(*c0) + ((T)1. - f) * z({{},{0, inSize}}) ); +// // current cell output = r◦activation(c) + (1 - r)◦x ---> h->assign( +// r*activation(*c) + ((T)1. - r) * (*x) ); // //*********** back propagation ***********// // // dCdC0 = f; @@ -378,27 +408,30 @@ BUILD_SINGLE_TEMPLATE(template void sruBIBP_, (NDArray* x, const NDArray* w, con // // dHdC = r * (1 - tanh*tanh) // NDArray dHdC = r * (1. - tanh * tanh); // // dCdX = dCdX + dCdF*dFdX = (1-f)*Wc + dCdF*Wf -// NDArray dCdX = oneMinusF * (*w)({{},{0, inSize}}) + dCdF * (*w)({{},{inSize, 2*inSize}}); - +// NDArray dCdX = oneMinusF * (*w)({{},{0, inSize}}) + dCdF * +// (*w)({{},{inSize, 2*inSize}}); // // dLdC0 = dLdC * dCdC0 = dLdC * f // dLdC0->assign((*dLdC) * f); - -// // dLdBf = dLdH*dHdBf + dLdC*dCdBf = dLdH*dHdC*dCdBf + dLdC*dCdF*dFdBf = dLdH*dHdC*dCdF*dFdBf + dLdC*dCdF*dFdBf = (dLdH*dHdC + dLdC)*dCdF*dFdBf +// // dLdBf = dLdH*dHdBf + dLdC*dCdBf = dLdH*dHdC*dCdBf + dLdC*dCdF*dFdBf = +// dLdH*dHdC*dCdF*dFdBf + dLdC*dCdF*dFdBf = (dLdH*dHdC + dLdC)*dCdF*dFdBf // (*dLdB)({{0, inSize}}).assign(((*dLdH) * dHdC + *dLdC) * dCdF * dFdBf); // // dLdBr = dLdH * dHdR * dRdBr // (*dLdB)({{inSize, 2*inSize}}).assign((*dLdH) * dHdR * dRdBr) - -// // dLdWc = dLdH*dHdWc + dLdC*dCdWc = dLdH*dHdC*dCdWc + dLdC*dCdWc = (dLdH*dHdC + dLdC) * dCdWc = (dLdH*dHdC + dLdC) * (1-f)*x -// (*dLdW)({{}, {0, inSize}}).assign(((*dLdH) * dHdC + *dLdC) * oneMinusF * (*x)); +// // dLdWc = dLdH*dHdWc + dLdC*dCdWc = dLdH*dHdC*dCdWc + dLdC*dCdWc = +// (dLdH*dHdC + dLdC) * dCdWc = (dLdH*dHdC + dLdC) * (1-f)*x +// (*dLdW)({{}, {0, inSize}}).assign(((*dLdH) * dHdC + *dLdC) * oneMinusF * +// (*x)); // // dLdWf = dLdBf * x // (*dLdW)({{}, {inSize, 2*inSize}}).assign((*dLdB)({{0, inSize}}) * (*x)); // // dLdWr = dLdBr * x -// (*dLdW)({{}, {2*inSize, 3*inSize}}).assign((*dLdB)({{inSize, 2*inSize}}) * (*x)); - +// (*dLdW)({{}, {2*inSize, 3*inSize}}).assign((*dLdB)({{inSize, 2*inSize}}) +// * (*x)); -// // dLdX = dLdH*dHdX + dLdC*dCdX = dLdH*(dHdX + dHdR*dRdX + dHdC*dCdX) + dLdC*dCdF*dFdX = dLdH*(1 - r + dHdR*dRdX + dHdC*dCdX) + dLdC*dCdX -// dLdX->assign((*dLdH) * (oneMinusR + dHdR * (*w)({{},{2*inSize, 3*inSize}}) + dHdC * dCdX) + (*dLdC) * dCdX); +// // dLdX = dLdH*dHdX + dLdC*dCdX = dLdH*(dHdX + dHdR*dRdX + dHdC*dCdX) + +// dLdC*dCdF*dFdX = dLdH*(1 - r + dHdR*dRdX + dHdC*dCdX) + dLdC*dCdX +// dLdX->assign((*dLdH) * (oneMinusR + dHdR * (*w)({{},{2*inSize, +// 3*inSize}}) + dHdC * dCdX) + (*dLdC) * dCdX); // } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp b/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp index 694ced4cbbd1..aee01cf90a73 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/stack.cpp @@ -18,105 +18,111 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include #include #include #include - +#include +#include namespace sd { namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// template -static void stack_(const std::vector& inArrs, NDArray& output, const int dim) { - - const int numOfSubArrs = inArrs.size(); - - if(inArrs[0]->rankOf() == 0) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - output.p(i, inArrs[i]->t(0)); - }; - - samediff::Threads::parallel_for(func, 0, numOfSubArrs); - } - else { - - auto zTadPack = ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), ShapeUtils::evalDimsToExclude(output.rankOf(), {dim})); - auto zTadShapeInfo = zTadPack.primaryShapeInfo(); - - auto func = PRAGMA_THREADS_FOR { - - for (auto i = start; i < stop; i++) { - - void* zBuff = output.bufferWithOffset(zTadPack.primaryOffsets()[i]); - - NativeOpExecutioner::execTransformAny(inArrs[0]->getContext(), transform::Assign, - inArrs[i]->buffer(), inArrs[i]->shapeInfo(), nullptr/*input specialBuffer*/, nullptr/*input specialShapeInfo*/, - zBuff, zTadShapeInfo, nullptr/*output specialBuffer*/, nullptr/*output specialShapeInfo*/, - nullptr, nullptr, nullptr, false/*allowParallelism*/); - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - } - +static void stack_(const std::vector& inArrs, NDArray& output, + const int dim) { + const int numOfSubArrs = inArrs.size(); + + if (inArrs[0]->rankOf() == 0) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) output.p(i, inArrs[i]->t(0)); + }; + + samediff::Threads::parallel_for(func, 0, numOfSubArrs); + } else { + auto zTadPack = ConstantTadHelper::getInstance()->tadForDimensions( + output.shapeInfo(), + ShapeUtils::evalDimsToExclude(output.rankOf(), {dim})); + auto zTadShapeInfo = zTadPack.primaryShapeInfo(); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + void* zBuff = output.bufferWithOffset(zTadPack.primaryOffsets()[i]); + + NativeOpExecutioner::execTransformAny( + inArrs[0]->getContext(), transform::Assign, inArrs[i]->buffer(), + inArrs[i]->shapeInfo(), nullptr /*input specialBuffer*/, + nullptr /*input specialShapeInfo*/, zBuff, zTadShapeInfo, + nullptr /*output specialBuffer*/, + nullptr /*output specialShapeInfo*/, nullptr, nullptr, nullptr, + false /*allowParallelism*/); + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); + } } //////////////////////////////////////////////////////////////////////// -void stack(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int dim) { - BUILD_SINGLE_SELECTOR(output.dataType(), stack_, (inArrs, output, dim), LIBND4J_TYPES); +void stack(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output, + const int dim) { + BUILD_SINGLE_SELECTOR(output.dataType(), stack_, (inArrs, output, dim), + LIBND4J_TYPES); } -BUILD_SINGLE_TEMPLATE(template void stack_ , (const std::vector& inArrs, NDArray& output, const int dim), LIBND4J_TYPES); - +BUILD_SINGLE_TEMPLATE(template void stack_, + (const std::vector& inArrs, + NDArray& output, const int dim), + LIBND4J_TYPES); /////////////////////////////////////////////////////////////////// template -static void unstack_(const NDArray& input, const std::vector& outArrs, const int dim) { - - const int numOfSubArrs = outArrs.size(); - - if(outArrs[0]->rankOf() == 0) { - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - outArrs[i]->p(0, input.t(i)); - }; - - samediff::Threads::parallel_for(func, 0, numOfSubArrs); - } - else { - - auto xTadPack = ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), ShapeUtils::evalDimsToExclude(input.rankOf(), {dim})); - auto xTadShapeInfo = xTadPack.primaryShapeInfo(); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto xBuff = input.bufferWithOffset(xTadPack.primaryOffsets()[i]); - - NativeOpExecutioner::execTransformAny(input.getContext(), transform::Assign, - xBuff, xTadShapeInfo, nullptr/*input specialBuffer*/, nullptr/*input specialShapeInfo*/, - outArrs[i]->buffer(), outArrs[i]->shapeInfo(), nullptr/*output specialBuffer*/, nullptr/*output specialShapeInfo*/, - nullptr, nullptr, nullptr, false/*allowParallelism*/); - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfSubArrs); - } +static void unstack_(const NDArray& input, const std::vector& outArrs, + const int dim) { + const int numOfSubArrs = outArrs.size(); + + if (outArrs[0]->rankOf() == 0) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) outArrs[i]->p(0, input.t(i)); + }; + + samediff::Threads::parallel_for(func, 0, numOfSubArrs); + } else { + auto xTadPack = ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), + ShapeUtils::evalDimsToExclude(input.rankOf(), {dim})); + auto xTadShapeInfo = xTadPack.primaryShapeInfo(); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto xBuff = input.bufferWithOffset(xTadPack.primaryOffsets()[i]); + + NativeOpExecutioner::execTransformAny( + input.getContext(), transform::Assign, xBuff, xTadShapeInfo, + nullptr /*input specialBuffer*/, nullptr /*input specialShapeInfo*/, + outArrs[i]->buffer(), outArrs[i]->shapeInfo(), + nullptr /*output specialBuffer*/, + nullptr /*output specialShapeInfo*/, nullptr, nullptr, nullptr, + false /*allowParallelism*/); + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfSubArrs); + } } //////////////////////////////////////////////////////////////////////// -void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector& outArrs, const int dim) { - BUILD_SINGLE_SELECTOR(input.dataType(), unstack_, (input, outArrs, dim), LIBND4J_TYPES); -} -BUILD_SINGLE_TEMPLATE(template void unstack_, (const NDArray& input, const std::vector& outArrs, const int dim), LIBND4J_TYPES); - +void unstack(sd::LaunchContext* context, const NDArray& input, + const std::vector& outArrs, const int dim) { + BUILD_SINGLE_SELECTOR(input.dataType(), unstack_, (input, outArrs, dim), + LIBND4J_TYPES); } -} -} - +BUILD_SINGLE_TEMPLATE(template void unstack_, + (const NDArray& input, + const std::vector& outArrs, const int dim), + LIBND4J_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp index a5a07776d325..3fd116924e87 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp @@ -18,966 +18,988 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 03.01.2018 // -#include #include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// template -SVD::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const bool calcV, const bool fullUV ) { - - if(matrix.rankOf() != 2 || matrix.isScalar()) - throw std::runtime_error("ops::helpers::SVD constructor: input array must be 2D matrix !"); - - const int rows = matrix.sizeAt(0); - const int cols = matrix.sizeAt(1); - - if(cols > rows) { - - _transp = true; - _diagSize = rows; - } - else { - - _transp = false; - _diagSize = cols; - } - - _switchSize = switchSize; - _calcU = calcU; - _calcV = calcV; - _fullUV = fullUV; - - if (_transp) - math::nd4j_swap(_calcU, _calcV); - - _s = NDArrayFactory::create(matrix.ordering(), {_diagSize, 1}, matrix.getContext()); - _m = NDArrayFactory::create(matrix.ordering(), {_diagSize + 1, _diagSize}, matrix.getContext()); - _m.assign(0.); - - if (_calcU) - _u = NDArrayFactory::create(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext()); - else - _u = NDArrayFactory::create(matrix.ordering(), {2, _diagSize + 1}, matrix.getContext()); - _u.assign(0.); - - if (_calcV) { - _v = NDArrayFactory::create(matrix.ordering(), {_diagSize, _diagSize}, matrix.getContext()); - _v.assign(0.); - } - - evalData(matrix); +SVD::SVD(const NDArray& matrix, const int switchSize, const bool calcU, + const bool calcV, const bool fullUV) { + if (matrix.rankOf() != 2 || matrix.isScalar()) + throw std::runtime_error( + "ops::helpers::SVD constructor: input array must be 2D matrix !"); + + const int rows = matrix.sizeAt(0); + const int cols = matrix.sizeAt(1); + + if (cols > rows) { + _transp = true; + _diagSize = rows; + } else { + _transp = false; + _diagSize = cols; + } + + _switchSize = switchSize; + _calcU = calcU; + _calcV = calcV; + _fullUV = fullUV; + + if (_transp) math::nd4j_swap(_calcU, _calcV); + + _s = NDArrayFactory::create(matrix.ordering(), {_diagSize, 1}, + matrix.getContext()); + _m = NDArrayFactory::create(matrix.ordering(), {_diagSize + 1, _diagSize}, + matrix.getContext()); + _m.assign(0.); + + if (_calcU) + _u = NDArrayFactory::create( + matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext()); + else + _u = NDArrayFactory::create(matrix.ordering(), {2, _diagSize + 1}, + matrix.getContext()); + _u.assign(0.); + + if (_calcV) { + _v = NDArrayFactory::create(matrix.ordering(), {_diagSize, _diagSize}, + matrix.getContext()); + _v.assign(0.); + } + + evalData(matrix); } ////////////////////////////////////////////////////////////////////////// template -SVD::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const bool calcV, const bool fullUV, const char t) { - - if(matrix.rankOf() != 2 || matrix.isScalar()) - throw std::runtime_error("ops::helpers::SVD constructor: input array must be 2D matrix !"); - - const int rows = matrix.sizeAt(0); - const int cols = matrix.sizeAt(1); - - if(cols > rows) { - - _transp = true; - _diagSize = rows; - } - else { - - _transp = false; - _diagSize = cols; - } - - _switchSize = switchSize; - _calcU = calcU; - _calcV = calcV; - _fullUV = fullUV; - - if (_transp) - math::nd4j_swap(_calcU, _calcV); - - _s = NDArrayFactory::create(matrix.ordering(), {_diagSize, 1}, matrix.getContext()); - _m = NDArrayFactory::create(matrix.ordering(), {_diagSize + 1, _diagSize}, matrix.getContext()); - _m.assign(0.f); - - if (_calcU) - _u = NDArrayFactory::create(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext()); - else - _u = NDArrayFactory::create(matrix.ordering(), {2, _diagSize + 1}, matrix.getContext()); - _u.assign(0.); - - if (_calcV) { - _v = NDArrayFactory::create(matrix.ordering(), {_diagSize, _diagSize}, matrix.getContext()); - _v.assign(0.); - } +SVD::SVD(const NDArray& matrix, const int switchSize, const bool calcU, + const bool calcV, const bool fullUV, const char t) { + if (matrix.rankOf() != 2 || matrix.isScalar()) + throw std::runtime_error( + "ops::helpers::SVD constructor: input array must be 2D matrix !"); + + const int rows = matrix.sizeAt(0); + const int cols = matrix.sizeAt(1); + + if (cols > rows) { + _transp = true; + _diagSize = rows; + } else { + _transp = false; + _diagSize = cols; + } + + _switchSize = switchSize; + _calcU = calcU; + _calcV = calcV; + _fullUV = fullUV; + + if (_transp) math::nd4j_swap(_calcU, _calcV); + + _s = NDArrayFactory::create(matrix.ordering(), {_diagSize, 1}, + matrix.getContext()); + _m = NDArrayFactory::create(matrix.ordering(), {_diagSize + 1, _diagSize}, + matrix.getContext()); + _m.assign(0.f); + + if (_calcU) + _u = NDArrayFactory::create( + matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext()); + else + _u = NDArrayFactory::create(matrix.ordering(), {2, _diagSize + 1}, + matrix.getContext()); + _u.assign(0.); + + if (_calcV) { + _v = NDArrayFactory::create(matrix.ordering(), {_diagSize, _diagSize}, + matrix.getContext()); + _v.assign(0.); + } } - ////////////////////////////////////////////////////////////////////////// template void SVD::deflation1(int col1, int shift, int ind, int size) { - - if(ind <= 0) - throw std::runtime_error("ops::helpers::SVD::deflation1 method: input int must satisfy condition ind > 0 !"); - - int first = col1 + shift; - T cos = _m.e(first, first); - T sin = _m.e(first+ind, first); - T denom = math::nd4j_sqrt(cos*cos + sin*sin); - - if (denom == (T)0.) { - - _m.p(first+ind, first+ind, 0.f); - return; - } - - cos /= denom; - sin /= denom; - - _m.p(first,first, denom); - _m.p(first+ind, first, 0.f); - _m.p(first+ind, first+ind, 0.f); - - auto rotation = NDArrayFactory::create(_m.ordering(), {2, 2}, _m.getContext()); - rotation.p(0, 0, cos); - rotation.p(0, 1, -sin); - rotation.p(1, 0, sin); - rotation.p(1, 1, cos); - - if (_calcU) { - auto temp = _u({col1,col1+size+1, 0,0}, true); - JacobiSVD::mulRotationOnRight(col1, col1+ind, temp, rotation); - } - else - JacobiSVD::mulRotationOnRight(col1, col1+ind, _u, rotation); + if (ind <= 0) + throw std::runtime_error( + "ops::helpers::SVD::deflation1 method: input int must satisfy " + "condition ind > 0 !"); + + int first = col1 + shift; + T cos = _m.e(first, first); + T sin = _m.e(first + ind, first); + T denom = math::nd4j_sqrt(cos * cos + sin * sin); + + if (denom == (T)0.) { + _m.p(first + ind, first + ind, 0.f); + return; + } + + cos /= denom; + sin /= denom; + + _m.p(first, first, denom); + _m.p(first + ind, first, 0.f); + _m.p(first + ind, first + ind, 0.f); + + auto rotation = + NDArrayFactory::create(_m.ordering(), {2, 2}, _m.getContext()); + rotation.p(0, 0, cos); + rotation.p(0, 1, -sin); + rotation.p(1, 0, sin); + rotation.p(1, 1, cos); + + if (_calcU) { + auto temp = _u({col1, col1 + size + 1, 0, 0}, true); + JacobiSVD::mulRotationOnRight(col1, col1 + ind, temp, rotation); + } else + JacobiSVD::mulRotationOnRight(col1, col1 + ind, _u, rotation); } ////////////////////////////////////////////////////////////////////////// template -void SVD::deflation2(int col1U , int col1M, int row1W, int col1W, int ind1, int ind2, int size) { - - if(ind1 >= ind2) - throw std::runtime_error("ops::helpers::SVD::deflation2 method: input intes must satisfy condition ind1 < ind2 !"); - - if(size <= 0) - throw std::runtime_error("ops::helpers::SVD::deflation2 method: input size must satisfy condition size > 0 !"); - - T cos = _m.e(col1M+ind1, col1M); - T sin = _m.e(col1M+ind2, col1M); - T denom = math::nd4j_sqrt(cos*cos + sin*sin); - - if (denom == (T)0.) { - - _m.p(col1M + ind1, col1M + ind1, _m.e(col1M + ind2, col1M + ind2)); - return; - } - - cos /= denom; - sin /= denom; - _m.p(col1M + ind1, col1M, denom); - _m.p(col1M + ind2, col1M + ind2, _m.e(col1M + ind1, col1M + ind1)); - _m.p(col1M + ind2, col1M, 0.f); - - auto rotation = NDArrayFactory::create(_m.ordering(), {2, 2}, _m.getContext()); - rotation.p(0,0, cos); - rotation.p(1,1, cos); - - rotation.p(0,1, -sin); - rotation.p(1,0, sin); - - if (_calcU) { - auto temp = _u({col1U,col1U+size+1, 0,0}, true); - JacobiSVD::mulRotationOnRight(col1U+ind1, col1U+ind2, temp, rotation); - } - else - JacobiSVD::mulRotationOnRight(col1U+ind1, col1U+ind2, _u, rotation); - - if (_calcV) { - auto temp = _v({row1W,row1W+size, 0,0}, true); - JacobiSVD::mulRotationOnRight(col1W+ind1, col1W+ind2, temp, rotation); - } +void SVD::deflation2(int col1U, int col1M, int row1W, int col1W, int ind1, + int ind2, int size) { + if (ind1 >= ind2) + throw std::runtime_error( + "ops::helpers::SVD::deflation2 method: input intes must satisfy " + "condition ind1 < ind2 !"); + + if (size <= 0) + throw std::runtime_error( + "ops::helpers::SVD::deflation2 method: input size must satisfy " + "condition size > 0 !"); + + T cos = _m.e(col1M + ind1, col1M); + T sin = _m.e(col1M + ind2, col1M); + T denom = math::nd4j_sqrt(cos * cos + sin * sin); + + if (denom == (T)0.) { + _m.p(col1M + ind1, col1M + ind1, _m.e(col1M + ind2, col1M + ind2)); + return; + } + + cos /= denom; + sin /= denom; + _m.p(col1M + ind1, col1M, denom); + _m.p(col1M + ind2, col1M + ind2, _m.e(col1M + ind1, col1M + ind1)); + _m.p(col1M + ind2, col1M, 0.f); + + auto rotation = + NDArrayFactory::create(_m.ordering(), {2, 2}, _m.getContext()); + rotation.p(0, 0, cos); + rotation.p(1, 1, cos); + + rotation.p(0, 1, -sin); + rotation.p(1, 0, sin); + + if (_calcU) { + auto temp = _u({col1U, col1U + size + 1, 0, 0}, true); + JacobiSVD::mulRotationOnRight(col1U + ind1, col1U + ind2, temp, + rotation); + } else + JacobiSVD::mulRotationOnRight(col1U + ind1, col1U + ind2, _u, rotation); + + if (_calcV) { + auto temp = _v({row1W, row1W + size, 0, 0}, true); + JacobiSVD::mulRotationOnRight(col1W + ind1, col1W + ind2, temp, + rotation); + } } ////////////////////////////////////////////////////////////////////////// -// has effect on block from (col1+shift, col1+shift) to (col2+shift, col2+shift) inclusively +// has effect on block from (col1+shift, col1+shift) to (col2+shift, col2+shift) +// inclusively template -void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, int shift) -{ - - const int len = col2 + 1 - col1; - - auto colVec0 = new NDArray(_m({col1+shift,col1+shift+len, col1+shift,col1+shift+1}, true)); - - auto diagInterval = _m({col1+shift, col1+shift+len, col1+shift,col1+shift+len}, true).diagonal('c'); - - const T almostZero = DataTypeUtils::min(); - T maxElem; - if(len == 1) - maxElem = math::nd4j_abs(diagInterval.template e(0)); - else - maxElem = diagInterval({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e(0); - T maxElem0 = colVec0->reduceNumber(reduce::AMax).template e(0); - - T eps = math::nd4j_max(almostZero, DataTypeUtils::eps() * maxElem); - T epsBig = (T)8. * DataTypeUtils::eps() * math::nd4j_max(maxElem0, maxElem); - - if(diagInterval.template e(0) < epsBig) - diagInterval.p(Nd4jLong(0), epsBig); - - for(int i=1; i < len; ++i) - if(math::nd4j_abs(colVec0->template e(i)) < eps) - colVec0->p(i, 0.f); - - for(int i=1; i < len; i++) - if(diagInterval.template e(i) < epsBig) { - deflation1(col1, shift, i, len); - for(int i = 0; i < len; ++i) - diagInterval.p(i, _m.e(col1+shift+i,col1+shift+i)); - } - +void SVD::deflation(int col1, int col2, int ind, int row1W, int col1W, + int shift) { + const int len = col2 + 1 - col1; + + auto colVec0 = new NDArray( + _m({col1 + shift, col1 + shift + len, col1 + shift, col1 + shift + 1}, + true)); + + auto diagInterval = + _m({col1 + shift, col1 + shift + len, col1 + shift, col1 + shift + len}, + true) + .diagonal('c'); + + const T almostZero = DataTypeUtils::min(); + T maxElem; + if (len == 1) + maxElem = math::nd4j_abs(diagInterval.template e(0)); + else + maxElem = diagInterval({1, -1, 0, 0}, true) + .reduceNumber(reduce::AMax) + .template e(0); + T maxElem0 = colVec0->reduceNumber(reduce::AMax).template e(0); + + T eps = math::nd4j_max(almostZero, DataTypeUtils::eps() * maxElem); + T epsBig = + (T)8. * DataTypeUtils::eps() * math::nd4j_max(maxElem0, maxElem); + + if (diagInterval.template e(0) < epsBig) + diagInterval.p(Nd4jLong(0), epsBig); + + for (int i = 1; i < len; ++i) + if (math::nd4j_abs(colVec0->template e(i)) < eps) colVec0->p(i, 0.f); + + for (int i = 1; i < len; i++) + if (diagInterval.template e(i) < epsBig) { + deflation1(col1, shift, i, len); + for (int i = 0; i < len; ++i) + diagInterval.p(i, _m.e(col1 + shift + i, col1 + shift + i)); + } + + { + bool totDefl = true; + for (int i = 1; i < len; i++) + if (colVec0->template e(i) >= almostZero) { + totDefl = false; + break; + } + + int* permut = nullptr; + ALLOCATE(permut, _m.getContext()->getWorkspace(), 3 * _diagSize, int); { - - bool totDefl = true; - for(int i=1; i < len; i++) - if(colVec0->template e(i) >= almostZero) { - totDefl = false; - break; - } - - int* permut = nullptr; - ALLOCATE(permut, _m.getContext()->getWorkspace(), 3*_diagSize, int); - { - permut[0] = 0; - int p = 1; - - for(int i=1; i(diagInterval.template e(i)) < almostZero) - permut[p++] = i; - - int k = 1, m = ind+1; - - for( ; p < len; ++p) { - if(k > ind) - permut[p] = m++; - else if(m >= len) - permut[p] = k++; - else if(diagInterval.template e(k) < diagInterval.template e(m)) - permut[p] = m++; - else - permut[p] = k++; - } - } - - if(totDefl) { - for(int i=1; i(diagInterval.template e(ki)) < almostZero || diagInterval.template e(0) < diagInterval.template e(ki)) - permut[i-1] = permut[i]; - else { - permut[i-1] = 0; - break; - } - } - } - - int *tInd = permut + len; - int *tCol = permut + 2*len; - - for(int m = 0; m < len; m++) { - tCol[m] = m; - tInd[m] = m; - } - - for(int i = totDefl ? 0 : 1; i < len; i++) { - - const int ki = permut[len - (totDefl ? i+1 : i)]; - const int jac = tCol[ki]; - - T _e0 = diagInterval.template e(jac); - //math::nd4j_swap(diagInterval)(i), (*diagInterval)(jac)); - diagInterval.p(jac, diagInterval.template e(i)); - diagInterval.p(i, _e0); - - if(i!=0 && jac!=0) { - _e0 = colVec0->template e(jac); - //math::nd4j_swap((*colVec0)(i), (*colVec0)(jac)); - colVec0->p(jac, colVec0->template e(i)); - colVec0->p(i, _e0); - } - - if (_calcU) { - auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1}, true); - auto temp2 = _u({col1,col1+len+1, col1+jac,col1+jac+1}, true); - auto temp3 = temp1.dup(); - temp1.assign(temp2); - temp2.assign(temp3); - } - else { - auto temp1 = _u({0,2, col1+i, col1+i+1}, true); - auto temp2 = _u({0,2, col1+jac, col1+jac+1}, true); - auto temp3 = temp1.dup(); - temp1.assign(temp2); - temp2.assign(temp3); - } - - if(_calcV) { - auto temp1 = _v({row1W,row1W+len, col1W+i, col1W+i+1}, true); - auto temp2 = _v({row1W,row1W+len, col1W+jac, col1W+jac+1}, true); - auto temp3 = temp1.dup(); - temp1.assign(temp2); - temp2.assign(temp3); - } - - const int tI = tInd[i]; - tCol[tI] = jac; - tCol[ki] = i; - tInd[jac] = tI; - tInd[i] = ki; - } - - RELEASE(permut, _m.getContext()); - } - - { - int i = len-1; - - while(i > 0 && (math::nd4j_abs(diagInterval.template e(i)) < almostZero || math::nd4j_abs(colVec0->template e(i)) < almostZero)) - --i; - - for(; i > 1; --i) { - if( (diagInterval.template e(i) - diagInterval.template e(i-1)) < DataTypeUtils::eps()*maxElem ) { - if (math::nd4j_abs(diagInterval.template e(i) - diagInterval.template e(i-1)) >= epsBig) - throw std::runtime_error("ops::helpers::SVD::deflation: diagonal elements are not properly sorted !"); - deflation2(col1, col1 + shift, row1W, col1W, i-1, i, len); - } + permut[0] = 0; + int p = 1; + + for (int i = 1; i < len; ++i) + if (math::nd4j_abs(diagInterval.template e(i)) < almostZero) + permut[p++] = i; + + int k = 1, m = ind + 1; + + for (; p < len; ++p) { + if (k > ind) + permut[p] = m++; + else if (m >= len) + permut[p] = k++; + else if (diagInterval.template e(k) < diagInterval.template e(m)) + permut[p] = m++; + else + permut[p] = k++; + } + } + + if (totDefl) { + for (int i = 1; i < len; ++i) { + int ki = permut[i]; + if (math::nd4j_abs(diagInterval.template e(ki)) < almostZero || + diagInterval.template e(0) < diagInterval.template e(ki)) + permut[i - 1] = permut[i]; + else { + permut[i - 1] = 0; + break; } - } - - delete colVec0; + } + } + + int* tInd = permut + len; + int* tCol = permut + 2 * len; + + for (int m = 0; m < len; m++) { + tCol[m] = m; + tInd[m] = m; + } + + for (int i = totDefl ? 0 : 1; i < len; i++) { + const int ki = permut[len - (totDefl ? i + 1 : i)]; + const int jac = tCol[ki]; + + T _e0 = diagInterval.template e(jac); + // math::nd4j_swap(diagInterval)(i), (*diagInterval)(jac)); + diagInterval.p(jac, diagInterval.template e(i)); + diagInterval.p(i, _e0); + + if (i != 0 && jac != 0) { + _e0 = colVec0->template e(jac); + // math::nd4j_swap((*colVec0)(i), (*colVec0)(jac)); + colVec0->p(jac, colVec0->template e(i)); + colVec0->p(i, _e0); + } + + if (_calcU) { + auto temp1 = _u({col1, col1 + len + 1, col1 + i, col1 + i + 1}, true); + auto temp2 = + _u({col1, col1 + len + 1, col1 + jac, col1 + jac + 1}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); + } else { + auto temp1 = _u({0, 2, col1 + i, col1 + i + 1}, true); + auto temp2 = _u({0, 2, col1 + jac, col1 + jac + 1}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); + } + + if (_calcV) { + auto temp1 = _v({row1W, row1W + len, col1W + i, col1W + i + 1}, true); + auto temp2 = + _v({row1W, row1W + len, col1W + jac, col1W + jac + 1}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); + } + + const int tI = tInd[i]; + tCol[tI] = jac; + tCol[ki] = i; + tInd[jac] = tI; + tInd[i] = ki; + } + + RELEASE(permut, _m.getContext()); + } + + { + int i = len - 1; + + while (i > 0 && + (math::nd4j_abs(diagInterval.template e(i)) < almostZero || + math::nd4j_abs(colVec0->template e(i)) < almostZero)) + --i; + + for (; i > 1; --i) { + if ((diagInterval.template e(i) - diagInterval.template e(i - 1)) < + DataTypeUtils::eps() * maxElem) { + if (math::nd4j_abs(diagInterval.template e(i) - + diagInterval.template e(i - 1)) >= epsBig) + throw std::runtime_error( + "ops::helpers::SVD::deflation: diagonal elements are not " + "properly sorted !"); + deflation2(col1, col1 + shift, row1W, col1W, i - 1, i, len); + } + } + } + + delete colVec0; } - ////////////////////////////////////////////////////////////////////////// template -T SVD::secularEq(const T diff, const NDArray& col0, const NDArray& diag, const NDArray& permut, const NDArray& diagShifted, const T shift) { - - auto len = permut.lengthOf(); - T res = 1.; - T item; - for(int i=0; i(i); - item = col0.e(j) / ((diagShifted.e(j) - diff) * (diag.e(j) + shift + diff)); - res += item * col0.e(j); - } - - return res; +T SVD::secularEq(const T diff, const NDArray& col0, const NDArray& diag, + const NDArray& permut, const NDArray& diagShifted, + const T shift) { + auto len = permut.lengthOf(); + T res = 1.; + T item; + for (int i = 0; i < len; ++i) { + auto j = permut.e(i); + item = col0.e(j) / + ((diagShifted.e(j) - diff) * (diag.e(j) + shift + diff)); + res += item * col0.e(j); + } + + return res; } - ////////////////////////////////////////////////////////////////////////// template -void SVD::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArray& permut, NDArray& singVals, NDArray& shifts, NDArray& mus) { - - auto len = col0.lengthOf(); - auto curLen = len; - - while(curLen > 1 && col0.e(curLen-1) == (T)0.f) - --curLen; - - for (int k = 0; k < len; ++k) { - - if (col0.e(k) == (T)0.f || curLen==1) { - - singVals.p(k, k==0 ? col0.e(0) : diag.e(k)); - mus.p(k, 0.f); - shifts.p(k, k==0 ? col0.e(0) : diag.e(k)); - continue; - } +void SVD::calcSingVals(const NDArray& col0, const NDArray& diag, + const NDArray& permut, NDArray& singVals, + NDArray& shifts, NDArray& mus) { + auto len = col0.lengthOf(); + auto curLen = len; - T left = diag.e(k); - T right; + while (curLen > 1 && col0.e(curLen - 1) == (T)0.f) --curLen; - if(k==curLen-1) - right = diag.e(curLen-1) + col0.reduceNumber(reduce::Norm2).e(0); - else { - - int l = k+1; - while(col0.e(l) == (T)0.f) { - ++l; - if(l >= curLen) - throw std::runtime_error("ops::helpers::SVD::calcSingVals method: l >= curLen !"); - } - - right = diag.e(l); - } - - T mid = left + (right - left) / (T)2.; - T fMid = secularEq(mid, col0, diag, permut, diag, 0.); - T shift = (k == curLen-1 || fMid > (T)0.) ? left : right; + for (int k = 0; k < len; ++k) { + if (col0.e(k) == (T)0.f || curLen == 1) { + singVals.p(k, k == 0 ? col0.e(0) : diag.e(k)); + mus.p(k, 0.f); + shifts.p(k, k == 0 ? col0.e(0) : diag.e(k)); + continue; + } - auto diagShifted = diag - shift; + T left = diag.e(k); + T right; - T muPrev, muCur; - if (shift == left) { - muPrev = (right - left) * 0.1; - if (k == curLen-1) - muCur = right - left; - else - muCur = (right - left) * 0.5; - } + if (k == curLen - 1) + right = diag.e(curLen - 1) + col0.reduceNumber(reduce::Norm2).e(0); + else { + int l = k + 1; + while (col0.e(l) == (T)0.f) { + ++l; + if (l >= curLen) + throw std::runtime_error( + "ops::helpers::SVD::calcSingVals method: l >= curLen !"); + } + + right = diag.e(l); + } + + T mid = left + (right - left) / (T)2.; + T fMid = secularEq(mid, col0, diag, permut, diag, 0.); + T shift = (k == curLen - 1 || fMid > (T)0.) ? left : right; + + auto diagShifted = diag - shift; + + T muPrev, muCur; + if (shift == left) { + muPrev = (right - left) * 0.1; + if (k == curLen - 1) + muCur = right - left; + else + muCur = (right - left) * 0.5; + } else { + muPrev = -(right - left) * 0.1; + muCur = -(right - left) * 0.5; + } + + T fPrev = secularEq(muPrev, col0, diag, permut, diagShifted, shift); + T fCur = secularEq(muCur, col0, diag, permut, diagShifted, shift); + + if (math::nd4j_abs(fPrev) < math::nd4j_abs(fCur)) { + math::nd4j_swap(fPrev, fCur); + math::nd4j_swap(muPrev, muCur); + } + + bool useBisection = fPrev * fCur > (T)0.; + while (fCur != (T).0 && + math::nd4j_abs(muCur - muPrev) > + (T)8. * DataTypeUtils::eps() * + math::nd4j_max(math::nd4j_abs(muCur), + math::nd4j_abs(muPrev)) && + math::nd4j_abs(fCur - fPrev) > DataTypeUtils::eps() && + !useBisection) { + T a = (fCur - fPrev) / ((T)1. / muCur - (T)1. / muPrev); + T jac = fCur - a / muCur; + T muZero = -a / jac; + T fZero = secularEq(muZero, col0, diag, permut, diagShifted, shift); + + muPrev = muCur; + fPrev = fCur; + muCur = muZero; + fCur = fZero; + + if (shift == left && (muCur < (T)0. || muCur > right - left)) + useBisection = true; + if (shift == right && (muCur < -(right - left) || muCur > (T)0.)) + useBisection = true; + if (math::nd4j_abs(fCur) > math::nd4j_abs(fPrev) && + math::nd4j_abs(fCur - fPrev) > (T)16. * DataTypeUtils::eps()) + useBisection = true; + } + + if (useBisection) { + T leftShifted, rightShifted; + if (shift == left) { + leftShifted = DataTypeUtils::min(); + rightShifted = (k == curLen - 1) ? right : ((right - left) * (T)0.6); + } else { + leftShifted = -(right - left) * (T)0.6; + rightShifted = -DataTypeUtils::min(); + } + + T fLeft = secularEq(leftShifted, col0, diag, permut, diagShifted, shift); + T fRight = + secularEq(rightShifted, col0, diag, permut, diagShifted, shift); + // if(fLeft * fRight >= (T)0.) + // throw "ops::helpers::SVD::calcSingVals method: fLeft * fRight >= (T)0. + // !"; + + while (rightShifted - leftShifted > + (T)2.f * DataTypeUtils::eps() * + math::nd4j_max(math::nd4j_abs(leftShifted), + math::nd4j_abs(rightShifted))) { + T midShifted = (leftShifted + rightShifted) / (T)2.; + fMid = secularEq(midShifted, col0, diag, permut, diagShifted, shift); + if (fLeft * fMid < (T)0.) + rightShifted = midShifted; else { - muPrev = -(right - left) * 0.1; - muCur = -(right - left) * 0.5; - } - - T fPrev = secularEq(muPrev, col0, diag, permut, diagShifted, shift); - T fCur = secularEq(muCur, col0, diag, permut, diagShifted, shift); - - if (math::nd4j_abs(fPrev) < math::nd4j_abs(fCur)) { - math::nd4j_swap(fPrev, fCur); - math::nd4j_swap(muPrev, muCur); + leftShifted = midShifted; + fLeft = fMid; } - - bool useBisection = fPrev * fCur > (T)0.; - while (fCur != (T).0 && - math::nd4j_abs(muCur - muPrev) > (T)8. * DataTypeUtils::eps() * math::nd4j_max(math::nd4j_abs(muCur), math::nd4j_abs(muPrev)) - && math::nd4j_abs(fCur - fPrev) > DataTypeUtils::eps() && !useBisection) { - - T a = (fCur - fPrev) / ((T)1./muCur - (T)1./muPrev); - T jac = fCur - a / muCur; - T muZero = -a/jac; - T fZero = secularEq(muZero, col0, diag, permut, diagShifted, shift); - - muPrev = muCur; - fPrev = fCur; - muCur = muZero; - fCur = fZero; - - if (shift == left && (muCur < (T)0. || muCur > right - left)) - useBisection = true; - if (shift == right && (muCur < -(right - left) || muCur > (T)0.)) - useBisection = true; - if (math::nd4j_abs(fCur) > math::nd4j_abs(fPrev) && math::nd4j_abs(fCur - fPrev) > (T)16. * DataTypeUtils::eps()) - useBisection = true; - } - - - if (useBisection) { - - T leftShifted, rightShifted; - if (shift == left) { - leftShifted = DataTypeUtils::min(); - rightShifted = (k==curLen-1) ? right : ((right - left) * (T)0.6); - } - else { - - leftShifted = -(right - left) * (T)0.6; - rightShifted = -DataTypeUtils::min(); - } - - T fLeft = secularEq(leftShifted, col0, diag, permut, diagShifted, shift); - T fRight = secularEq(rightShifted, col0, diag, permut, diagShifted, shift); - // if(fLeft * fRight >= (T)0.) - // throw "ops::helpers::SVD::calcSingVals method: fLeft * fRight >= (T)0. !"; - - while (rightShifted - leftShifted > (T)2.f * DataTypeUtils::eps() * math::nd4j_max(math::nd4j_abs(leftShifted), math::nd4j_abs(rightShifted))) { - - T midShifted = (leftShifted + rightShifted) / (T)2.; - fMid = secularEq(midShifted, col0, diag, permut, diagShifted, shift); - if (fLeft * fMid < (T)0.) - rightShifted = midShifted; - else { - leftShifted = midShifted; - fLeft = fMid; - } - } - muCur = (leftShifted + rightShifted) / (T)2.; - } - singVals.p(k, shift + muCur); - shifts.p(k, shift); - mus.p(k, muCur); + } + muCur = (leftShifted + rightShifted) / (T)2.; } - + singVals.p(k, shift + muCur); + shifts.p(k, shift); + mus.p(k, muCur); + } } - ////////////////////////////////////////////////////////////////////////// template -void SVD::perturb(const NDArray& col0, const NDArray& diag, const NDArray& permut, const NDArray& singVals, const NDArray& shifts, const NDArray& mus, NDArray& zhat) { - - int n = col0.lengthOf(); - int m = permut.lengthOf(); - if(m==0) { - zhat.assign(0.); - return; - } - - int last = permut.e(m-1); - - for (int k = 0; k < n; ++k) { - - if (col0.e(k) == (T)0.f) - zhat.p(k, (T)0.f); - else { - T dk = diag.e(k); - T prod = (singVals.e(last) + dk) * (mus.e(last) + (shifts.e(last) - dk)); - - for(int l = 0; l(l); - if(i!=k) { - int j = i(l-1); - prod *= ((singVals.e(j)+dk) / ((diag.e(i)+dk))) * ((mus.e(j)+(shifts.e(j)-dk)) / ((diag.e(i)-dk))); - } - } - T tmp = math::nd4j_sqrt(prod); - zhat.p(k, col0.e(k) > (T)0.f ? tmp : -tmp); +void SVD::perturb(const NDArray& col0, const NDArray& diag, + const NDArray& permut, const NDArray& singVals, + const NDArray& shifts, const NDArray& mus, NDArray& zhat) { + int n = col0.lengthOf(); + int m = permut.lengthOf(); + if (m == 0) { + zhat.assign(0.); + return; + } + + int last = permut.e(m - 1); + + for (int k = 0; k < n; ++k) { + if (col0.e(k) == (T)0.f) + zhat.p(k, (T)0.f); + else { + T dk = diag.e(k); + T prod = (singVals.e(last) + dk) * + (mus.e(last) + (shifts.e(last) - dk)); + + for (int l = 0; l < m; ++l) { + int i = permut.e(l); + if (i != k) { + int j = i < k ? i : permut.e(l - 1); + prod *= + ((singVals.e(j) + dk) / ((diag.e(i) + dk))) * + ((mus.e(j) + (shifts.e(j) - dk)) / ((diag.e(i) - dk))); } + } + T tmp = math::nd4j_sqrt(prod); + zhat.p(k, col0.e(k) > (T)0.f ? tmp : -tmp); } + } } - ////////////////////////////////////////////////////////////////////////// template -void SVD::calcSingVecs(const NDArray& zhat, const NDArray& diag, const NDArray& perm, const NDArray& singVals, - const NDArray& shifts, const NDArray& mus, NDArray& U, NDArray& V) { - - int n = zhat.lengthOf(); - int m = perm.lengthOf(); - - for (int k = 0; k < n; ++k) { - - auto colU = new NDArray(U({0,0, k,k+1}, true)); - *colU = 0.; - NDArray* colV = nullptr; +void SVD::calcSingVecs(const NDArray& zhat, const NDArray& diag, + const NDArray& perm, const NDArray& singVals, + const NDArray& shifts, const NDArray& mus, NDArray& U, + NDArray& V) { + int n = zhat.lengthOf(); + int m = perm.lengthOf(); + + for (int k = 0; k < n; ++k) { + auto colU = new NDArray(U({0, 0, k, k + 1}, true)); + *colU = 0.; + NDArray* colV = nullptr; - if (_calcV) { - colV = new NDArray(V({0,0, k,k+1}, true)); - *colV = 0.; - } - - if (zhat.e(k) == (T)0.f) { - colU->p(k, 1.f); - - if (_calcV) - colV->p(k, 1.f); - } - else { - - for(int l = 0; l < m; ++l) { - int i = perm.e(l); - U.p(i,k, zhat.e(i)/(((diag.e(i) - shifts.e(k)) - mus.e(k)) )/( (diag.e(i) + singVals.e(k)))); - } - U.p(n,k, 0.f); - *colU /= colU->reduceNumber(reduce::Norm2); - - if (_calcV) { - - for(int l = 1; l < m; ++l){ - int i = perm.e(l); - V.p(i,k, diag.e(i) * zhat.e(i) / (((diag.e(i) - shifts.e(k)) - mus.e(k)) )/( (diag.e(i) + singVals.e(k)))); - } - V.p(0,k, -1.f); - *colV /= colV->reduceNumber(reduce::Norm2); - } + if (_calcV) { + colV = new NDArray(V({0, 0, k, k + 1}, true)); + *colV = 0.; + } + + if (zhat.e(k) == (T)0.f) { + colU->p(k, 1.f); + + if (_calcV) colV->p(k, 1.f); + } else { + for (int l = 0; l < m; ++l) { + int i = perm.e(l); + U.p(i, k, + zhat.e(i) / (((diag.e(i) - shifts.e(k)) - mus.e(k))) / + ((diag.e(i) + singVals.e(k)))); + } + U.p(n, k, 0.f); + *colU /= colU->reduceNumber(reduce::Norm2); + + if (_calcV) { + for (int l = 1; l < m; ++l) { + int i = perm.e(l); + V.p(i, k, + diag.e(i) * zhat.e(i) / + (((diag.e(i) - shifts.e(k)) - mus.e(k))) / + ((diag.e(i) + singVals.e(k)))); } - delete colU; - if (_calcV) - delete colV; + V.p(0, k, -1.f); + *colV /= colV->reduceNumber(reduce::Norm2); + } } + delete colU; + if (_calcV) delete colV; + } - auto colU = U({0,0, n,n+1}, true); - colU = 0.; - colU.p(n, 1.); + auto colU = U({0, 0, n, n + 1}, true); + colU = 0.; + colU.p(n, 1.); } - ////////////////////////////////////////////////////////////////////////// template -void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDArray& V) { - - const T almostZero = DataTypeUtils::min(); - auto col0 = _m({col1, col1+size, col1, col1+1}, true); - auto diag = static_cast(_m({col1, col1+size, col1, col1+size}, true).diagonal('c')); - - diag.p(Nd4jLong(0), T(0)); - singVals = NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); - U = NDArrayFactory::create(_u.ordering(), {size+1, size+1}, _u.getContext()); - if (_calcV) - V = NDArrayFactory::create(_v.ordering(), {size, size}, _v.getContext()); - - int curSize = size; - while(curSize > 1 && diag.template e(curSize-1) == (T)0.f) - --curSize; - - int m = 0; - std::vector indices; - for(int k = 0; k < curSize; ++k) - if(math::nd4j_abs(col0.template e(k)) > almostZero) - indices.push_back((T)k); - - auto permut = NDArrayFactory::create(_m.ordering(), {1, (int)indices.size()}, indices, _m.getContext()); - auto shifts = NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); - auto mus = NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); - auto zhat = NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); - - calcSingVals(col0, diag, permut, singVals, shifts, mus); - perturb(col0, diag, permut, singVals, shifts, mus, zhat); - calcSingVecs(zhat, diag, permut, singVals, shifts, mus, U, V); - - for(int i=0; i(i) > singVals.e(i+1)) { - T _e0 = singVals.e(i); - T _e1 = singVals.e(i+1); - //math::nd4j_swap(singVals(i),singVals(i+1)); - singVals.p(i, _e1); - singVals.p(i+1, _e0); - - auto temp1 = U({0,0, i,i+1}, true); - auto temp2 = U({0,0, i+1,i+2}, true); - auto temp3 = temp1.dup(); - temp1.assign(temp2); - temp2.assign(temp3); - - if(_calcV) { - auto temp1 = V({0,0, i,i+1}, true); - auto temp2 = V({0,0, i+1,i+2}, true); - auto temp3 = temp1.dup(); - temp1.assign(temp2); - temp2.assign(temp3); - } - } - } - - auto temp1 = singVals({0,curSize, 0,0}, true); - for (int e = 0; e < curSize / 2; ++e) { - T tmp = temp1.e(e); - temp1.p(e, temp1.e(curSize-1-e)); - temp1.p(curSize-1-e, tmp); - } - - auto temp2 = U({0,0, 0,curSize}, true); - for(int i = 0; i < curSize/2; ++i) { - auto temp3 = temp2({0,0, i,i+1}, true); - auto temp4 = temp2({0,0, curSize-1-i,curSize-i}, true); - auto temp5 = temp3.dup(); - temp3.assign(temp4); - temp4.assign(temp5); - } - - if (_calcV) { - auto temp2 = V({0,0, 0,curSize}, true); - for(int i = 0; i < curSize/2; ++i) { - auto temp3 = temp2({0,0, i,i+1}, true); - auto temp4 = temp2({0,0, curSize-1-i,curSize-i}, true); - auto temp5 = temp3.dup(); - temp3.assign(temp4); - temp4.assign(temp5); - } - } +void SVD::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, + NDArray& V) { + const T almostZero = DataTypeUtils::min(); + auto col0 = _m({col1, col1 + size, col1, col1 + 1}, true); + auto diag = static_cast( + _m({col1, col1 + size, col1, col1 + size}, true).diagonal('c')); + + diag.p(Nd4jLong(0), T(0)); + singVals = + NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); + U = NDArrayFactory::create(_u.ordering(), {size + 1, size + 1}, + _u.getContext()); + if (_calcV) + V = NDArrayFactory::create(_v.ordering(), {size, size}, _v.getContext()); + + int curSize = size; + while (curSize > 1 && diag.template e(curSize - 1) == (T)0.f) --curSize; + + int m = 0; + std::vector indices; + for (int k = 0; k < curSize; ++k) + if (math::nd4j_abs(col0.template e(k)) > almostZero) + indices.push_back((T)k); + + auto permut = NDArrayFactory::create( + _m.ordering(), {1, (int)indices.size()}, indices, _m.getContext()); + auto shifts = + NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); + auto mus = + NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); + auto zhat = + NDArrayFactory::create(_m.ordering(), {size, 1}, _m.getContext()); + + calcSingVals(col0, diag, permut, singVals, shifts, mus); + perturb(col0, diag, permut, singVals, shifts, mus, zhat); + calcSingVecs(zhat, diag, permut, singVals, shifts, mus, U, V); + + for (int i = 0; i < curSize - 1; ++i) { + if (singVals.e(i) > singVals.e(i + 1)) { + T _e0 = singVals.e(i); + T _e1 = singVals.e(i + 1); + // math::nd4j_swap(singVals(i),singVals(i+1)); + singVals.p(i, _e1); + singVals.p(i + 1, _e0); + + auto temp1 = U({0, 0, i, i + 1}, true); + auto temp2 = U({0, 0, i + 1, i + 2}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); + + if (_calcV) { + auto temp1 = V({0, 0, i, i + 1}, true); + auto temp2 = V({0, 0, i + 1, i + 2}, true); + auto temp3 = temp1.dup(); + temp1.assign(temp2); + temp2.assign(temp3); + } + } + } + + auto temp1 = singVals({0, curSize, 0, 0}, true); + for (int e = 0; e < curSize / 2; ++e) { + T tmp = temp1.e(e); + temp1.p(e, temp1.e(curSize - 1 - e)); + temp1.p(curSize - 1 - e, tmp); + } + + auto temp2 = U({0, 0, 0, curSize}, true); + for (int i = 0; i < curSize / 2; ++i) { + auto temp3 = temp2({0, 0, i, i + 1}, true); + auto temp4 = temp2({0, 0, curSize - 1 - i, curSize - i}, true); + auto temp5 = temp3.dup(); + temp3.assign(temp4); + temp4.assign(temp5); + } + + if (_calcV) { + auto temp2 = V({0, 0, 0, curSize}, true); + for (int i = 0; i < curSize / 2; ++i) { + auto temp3 = temp2({0, 0, i, i + 1}, true); + auto temp4 = temp2({0, 0, curSize - 1 - i, curSize - i}, true); + auto temp5 = temp3.dup(); + temp3.assign(temp4); + temp4.assign(temp5); + } + } } - ////////////////////////////////////////////////////////////////////////// -template -void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shift) { - - // requires rows = cols + 1; - const int n = col2 - col1 + 1; - const int k = n/2; - const T almostZero = DataTypeUtils::min(); - T alphaK; - T betaK; - T r0; - T lambda, phi, c0, s0; - auto l = NDArrayFactory::create(_u.ordering(), {1, k}, _u.getContext()); - auto f = NDArrayFactory::create(_u.ordering(), {1, n-k-1}, _u.getContext()); - - if(n < _switchSize) { - - JacobiSVD jac(_m({col1,col1+n+1, col1,col1+n}, true), _calcU, _calcV, _fullUV); - - if (_calcU) { - auto temp = _u({col1,col1+n+1, col1,col1+n+1}, true); - temp.assign(jac._u); - } - else { - auto temp1 = _u({0,1, col1,col1+n+1}, true); - temp1.assign(jac._u({0,1, 0,0}, true)); - auto temp2 = _u({1,2, col1,col1+n+1}, true); - temp2.assign(jac._u({n,n+1, 0,0}, true)); - } - - if (_calcV) { - auto temp = _v({row1W,row1W+n, col1W,col1W+n}, true); - temp.assign(jac._v); - } - - auto temp = _m({col1+shift,col1+shift+n+1, col1+shift,col1+shift+n}, true); - temp.assign(0.); - auto diag = _m.diagonal('c'); - diag({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true)); - - return; - } - - alphaK = _m.e(col1 + k, col1 + k); - betaK = _m.e(col1 + k + 1, col1 + k); - - DivideAndConquer(k + 1 + col1, col2, k + 1 + row1W, k + 1 + col1W, shift); - DivideAndConquer(col1, k - 1 + col1, row1W, col1W + 1, shift + 1); - - if (_calcU) { - lambda = _u.e(col1 + k, col1 + k); - phi = _u.e(col1 + k + 1, col2 + 1); - } - else { - lambda = _u.e(1, col1 + k); - phi = _u.e(0, col2 + 1); - } - - r0 = math::nd4j_sqrt((math::nd4j_abs(alphaK * lambda) * math::nd4j_abs(alphaK * lambda)) + math::nd4j_abs(betaK * phi) * math::nd4j_abs(betaK * phi)); - - if(_calcU) { - l.assign(_u({col1+k, col1+k+1, col1,col1+k}, true)); - f.assign(_u({col1+k+1,col1+k+2, col1+k+1,col1+n}, true)); - } - else { - l.assign(_u({1,2, col1, col1+k}, true)); - f.assign(_u({0,1, col1+k+1, col1+n}, true)); - } - - // UofSVD.printIndexedBuffer(); - // VofSVD.printIndexedBuffer(); - // singVals.printIndexedBuffer(); - // printf("!! \n"); - - if (_calcV) - _v.p(row1W+k, col1W, 1.f); - - if (r0 < almostZero){ - c0 = 1.; - s0 = 0.; - } - else { - c0 = alphaK * lambda / r0; - s0 = betaK * phi / r0; - } +template +void SVD::DivideAndConquer(int col1, int col2, int row1W, int col1W, + int shift) { + // requires rows = cols + 1; + const int n = col2 - col1 + 1; + const int k = n / 2; + const T almostZero = DataTypeUtils::min(); + T alphaK; + T betaK; + T r0; + T lambda, phi, c0, s0; + auto l = NDArrayFactory::create(_u.ordering(), {1, k}, _u.getContext()); + auto f = + NDArrayFactory::create(_u.ordering(), {1, n - k - 1}, _u.getContext()); + + if (n < _switchSize) { + JacobiSVD jac(_m({col1, col1 + n + 1, col1, col1 + n}, true), _calcU, + _calcV, _fullUV); if (_calcU) { - - auto temp = _u({col1,col1+k+1, col1+k,col1+k+1}, true); - NDArray q1(temp); - - for (int i = col1 + k - 1; i >= col1; --i) { - auto temp = _u({col1,col1+k+1, i+1,i+2}, true); - temp.assign(_u({col1, col1+k+1, i, i+1}, true)); - } - - _u({col1,col1+k+1, col1,col1+1}, true).assign(q1 * c0); - _u({col1,col1+k+1, col2+1,col2+2}, true).assign(q1 * (-s0)); - _u({col1+k+1,col1+n+1, col1, col1+1}, true).assign(static_cast(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true)) * s0); - _u({col1+k+1,col1+n+1, col2+1,col2+2}, true) *= c0; - } - else { - - T q1 = _u.e(0, col1 + k); - - for (int i = col1 + k - 1; i >= col1; --i) - _u.p(0, i+1, _u.e(0, i)); - - _u.p(0, col1, q1 * c0); - _u.p(0, col2+1, -q1*s0); - _u.p(1, col1, _u.e(1, col2+1) * s0); - _u.p(1, col2 + 1, _u.e(1, col2 + 1) * c0); - _u({1,2, col1+1, col1+k+1}, true) = 0.f; - _u({0,1, col1+k+1, col1+n}, true) = 0.f; - } - - _m.p(col1 + shift, col1 + shift, r0); - auto temp1 = _m({col1+shift+1,col1+shift+k+1, col1+shift,col1+shift+1}, true); - temp1.assign(l*alphaK); - auto temp2 = _m({col1+shift+k+1,col1+shift+n, col1+shift,col1+shift+1}, true); - temp2.assign(f*betaK); - - deflation(col1, col2, k, row1W, col1W, shift); - - NDArray UofSVD, VofSVD, singVals; - calcBlockSVD(col1 + shift, n, UofSVD, singVals, VofSVD); - - if(_calcU) { - auto pTemp = _u({col1, col1+n+1, col1,col1+n+1}, true); - auto temp = pTemp.dup(); - pTemp.assign(mmul(temp, UofSVD)); - } - else { - auto pTemp = _u({0,0, col1,col1+n+1}, true); - auto temp = pTemp.dup(); - pTemp.assign(mmul(temp, UofSVD)); + auto temp = _u({col1, col1 + n + 1, col1, col1 + n + 1}, true); + temp.assign(jac._u); + } else { + auto temp1 = _u({0, 1, col1, col1 + n + 1}, true); + temp1.assign(jac._u({0, 1, 0, 0}, true)); + auto temp2 = _u({1, 2, col1, col1 + n + 1}, true); + temp2.assign(jac._u({n, n + 1, 0, 0}, true)); } if (_calcV) { - auto pTemp = _v({row1W,row1W+n, row1W,row1W+n}, true); - auto temp = pTemp.dup(); - pTemp.assign(mmul(temp, VofSVD)); - } - - auto blockM = _m({col1+shift,col1+shift+n, col1+shift,col1+shift+n}, true); - blockM = 0.f; - auto diag = blockM.diagonal('c'); - diag.assign(singVals); + auto temp = _v({row1W, row1W + n, col1W, col1W + n}, true); + temp.assign(jac._v); + } + + auto temp = + _m({col1 + shift, col1 + shift + n + 1, col1 + shift, col1 + shift + n}, + true); + temp.assign(0.); + auto diag = _m.diagonal('c'); + diag({col1 + shift, col1 + shift + n, 0, 0}, true) + .assign(jac._s({0, n, 0, 0}, true)); + + return; + } + + alphaK = _m.e(col1 + k, col1 + k); + betaK = _m.e(col1 + k + 1, col1 + k); + + DivideAndConquer(k + 1 + col1, col2, k + 1 + row1W, k + 1 + col1W, shift); + DivideAndConquer(col1, k - 1 + col1, row1W, col1W + 1, shift + 1); + + if (_calcU) { + lambda = _u.e(col1 + k, col1 + k); + phi = _u.e(col1 + k + 1, col2 + 1); + } else { + lambda = _u.e(1, col1 + k); + phi = _u.e(0, col2 + 1); + } + + r0 = math::nd4j_sqrt((math::nd4j_abs(alphaK * lambda) * + math::nd4j_abs(alphaK * lambda)) + + math::nd4j_abs(betaK * phi) * + math::nd4j_abs(betaK * phi)); + + if (_calcU) { + l.assign(_u({col1 + k, col1 + k + 1, col1, col1 + k}, true)); + f.assign(_u({col1 + k + 1, col1 + k + 2, col1 + k + 1, col1 + n}, true)); + } else { + l.assign(_u({1, 2, col1, col1 + k}, true)); + f.assign(_u({0, 1, col1 + k + 1, col1 + n}, true)); + } + + // UofSVD.printIndexedBuffer(); + // VofSVD.printIndexedBuffer(); + // singVals.printIndexedBuffer(); + // printf("!! \n"); + + if (_calcV) _v.p(row1W + k, col1W, 1.f); + + if (r0 < almostZero) { + c0 = 1.; + s0 = 0.; + } else { + c0 = alphaK * lambda / r0; + s0 = betaK * phi / r0; + } + + if (_calcU) { + auto temp = _u({col1, col1 + k + 1, col1 + k, col1 + k + 1}, true); + NDArray q1(temp); + + for (int i = col1 + k - 1; i >= col1; --i) { + auto temp = _u({col1, col1 + k + 1, i + 1, i + 2}, true); + temp.assign(_u({col1, col1 + k + 1, i, i + 1}, true)); + } + + _u({col1, col1 + k + 1, col1, col1 + 1}, true).assign(q1 * c0); + _u({col1, col1 + k + 1, col2 + 1, col2 + 2}, true).assign(q1 * (-s0)); + _u({col1 + k + 1, col1 + n + 1, col1, col1 + 1}, true) + .assign(static_cast(_u( + {col1 + k + 1, col1 + n + 1, col2 + 1, col2 + 2}, true)) * + s0); + _u({col1 + k + 1, col1 + n + 1, col2 + 1, col2 + 2}, true) *= c0; + } else { + T q1 = _u.e(0, col1 + k); + + for (int i = col1 + k - 1; i >= col1; --i) _u.p(0, i + 1, _u.e(0, i)); + + _u.p(0, col1, q1 * c0); + _u.p(0, col2 + 1, -q1 * s0); + _u.p(1, col1, _u.e(1, col2 + 1) * s0); + _u.p(1, col2 + 1, _u.e(1, col2 + 1) * c0); + _u({1, 2, col1 + 1, col1 + k + 1}, true) = 0.f; + _u({0, 1, col1 + k + 1, col1 + n}, true) = 0.f; + } + + _m.p(col1 + shift, col1 + shift, r0); + auto temp1 = _m( + {col1 + shift + 1, col1 + shift + k + 1, col1 + shift, col1 + shift + 1}, + true); + temp1.assign(l * alphaK); + auto temp2 = _m( + {col1 + shift + k + 1, col1 + shift + n, col1 + shift, col1 + shift + 1}, + true); + temp2.assign(f * betaK); + + deflation(col1, col2, k, row1W, col1W, shift); + + NDArray UofSVD, VofSVD, singVals; + calcBlockSVD(col1 + shift, n, UofSVD, singVals, VofSVD); + + if (_calcU) { + auto pTemp = _u({col1, col1 + n + 1, col1, col1 + n + 1}, true); + auto temp = pTemp.dup(); + pTemp.assign(mmul(temp, UofSVD)); + } else { + auto pTemp = _u({0, 0, col1, col1 + n + 1}, true); + auto temp = pTemp.dup(); + pTemp.assign(mmul(temp, UofSVD)); + } + + if (_calcV) { + auto pTemp = _v({row1W, row1W + n, row1W, row1W + n}, true); + auto temp = pTemp.dup(); + pTemp.assign(mmul(temp, VofSVD)); + } + + auto blockM = _m( + {col1 + shift, col1 + shift + n, col1 + shift, col1 + shift + n}, true); + blockM = 0.f; + auto diag = blockM.diagonal('c'); + diag.assign(singVals); } ////////////////////////////////////////////////////////////////////////// -template -void SVD::exchangeUV(const HHsequence& hhU, const HHsequence& hhV, const NDArray& U, const NDArray& V) { - - if (_calcU) { - - int colsU = _fullUV ? hhU.rows() : _diagSize; - auto temp1 = NDArrayFactory::create(_u.ordering(), {hhU.rows(), colsU}, _u.getContext()); - temp1.setIdentity(); - _u = temp1; - - auto temp2 = _u({0,_diagSize, 0,_diagSize}, true); - temp2.assign(V({0,_diagSize, 0,_diagSize}, true)); - const_cast(hhU).mulLeft(_u); - } - - if (_calcV) { - - int colsV = _fullUV ? hhV.rows() : _diagSize; - auto temp1 = NDArrayFactory::create(_v.ordering(), {hhV.rows(), colsV}, _v.getContext()); - temp1.setIdentity(); - _v = temp1; - - auto temp2 = _v({0,_diagSize, 0,_diagSize}, true); - temp2.assign(U({0,_diagSize, 0,_diagSize}, true)); - const_cast(hhV).mulLeft(_v); - } +template +void SVD::exchangeUV(const HHsequence& hhU, const HHsequence& hhV, + const NDArray& U, const NDArray& V) { + if (_calcU) { + int colsU = _fullUV ? hhU.rows() : _diagSize; + auto temp1 = NDArrayFactory::create(_u.ordering(), {hhU.rows(), colsU}, + _u.getContext()); + temp1.setIdentity(); + _u = temp1; + + auto temp2 = _u({0, _diagSize, 0, _diagSize}, true); + temp2.assign(V({0, _diagSize, 0, _diagSize}, true)); + const_cast(hhU).mulLeft(_u); + } + + if (_calcV) { + int colsV = _fullUV ? hhV.rows() : _diagSize; + auto temp1 = NDArrayFactory::create(_v.ordering(), {hhV.rows(), colsV}, + _v.getContext()); + temp1.setIdentity(); + _v = temp1; + + auto temp2 = _v({0, _diagSize, 0, _diagSize}, true); + temp2.assign(U({0, _diagSize, 0, _diagSize}, true)); + const_cast(hhV).mulLeft(_v); + } } ////////////////////////////////////////////////////////////////////////// template void SVD::evalData(const NDArray& matrix) { + const T almostZero = DataTypeUtils::min(); - const T almostZero = DataTypeUtils::min(); + if (matrix.sizeAt(1) < _switchSize) { + JacobiSVD jac(matrix, _calcU, _calcV, _fullUV); - if(matrix.sizeAt(1) < _switchSize) { + if (_calcU) _u = jac._u; + if (_calcV) _v = jac._v; - JacobiSVD jac(matrix, _calcU, _calcV, _fullUV); + _s.assign(jac._s); - if(_calcU) - _u = jac._u; - if(_calcV) - _v = jac._v; - - _s.assign(jac._s); - - return; - } + return; + } - T scale = matrix.reduceNumber(reduce::AMax).e(0); + T scale = matrix.reduceNumber(reduce::AMax).e(0); - if(scale == (T)0.) - scale = 1.; + if (scale == (T)0.) scale = 1.; - NDArray copy; - if(_transp) - copy = matrix.transpose(); - else - copy = matrix / scale; + NDArray copy; + if (_transp) + copy = matrix.transpose(); + else + copy = matrix / scale; - BiDiagonalUp biDiag(copy); + BiDiagonalUp biDiag(copy); - _u = 0.; - _v = 0.; + _u = 0.; + _v = 0.; - auto temp1 = biDiag._HHbidiag.transpose(); - auto temp2 = _m({0,_diagSize, 0,0}, true); - temp2.assign(temp1); + auto temp1 = biDiag._HHbidiag.transpose(); + auto temp2 = _m({0, _diagSize, 0, 0}, true); + temp2.assign(temp1); - auto temp3 = _m({_m.sizeAt(0)-1,_m.sizeAt(0), 0,0}, true); - temp3.assign(0.); + auto temp3 = _m({_m.sizeAt(0) - 1, _m.sizeAt(0), 0, 0}, true); + temp3.assign(0.); - DivideAndConquer(0, _diagSize - 1, 0, 0, 0); + DivideAndConquer(0, _diagSize - 1, 0, 0, 0); - for (int i = 0; i < _diagSize; ++i) { - T a = math::nd4j_abs(_m.e(i, i)); - _s.p(i, a * scale); - if (a < almostZero) { - auto temp = _s({i+1,_diagSize, 0,0}, true); - temp.assign(0.); - break; - } - else if (i == _diagSize-1) - break; - } + for (int i = 0; i < _diagSize; ++i) { + T a = math::nd4j_abs(_m.e(i, i)); + _s.p(i, a * scale); + if (a < almostZero) { + auto temp = _s({i + 1, _diagSize, 0, 0}, true); + temp.assign(0.); + break; + } else if (i == _diagSize - 1) + break; + } - if(_transp) - exchangeUV(biDiag.makeHHsequence('v'), biDiag.makeHHsequence('u'), _v, _u); - else - exchangeUV(biDiag.makeHHsequence('u'), biDiag.makeHHsequence('v'), _u, _v); + if (_transp) + exchangeUV(biDiag.makeHHsequence('v'), biDiag.makeHHsequence('u'), _v, _u); + else + exchangeUV(biDiag.makeHHsequence('u'), biDiag.makeHHsequence('v'), _u, _v); } - -BUILD_SINGLE_TEMPLATE(template class SD_EXPORT SVD,,FLOAT_TYPES); - +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT SVD, , FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// -// svd operation, this function is not method of SVD class, it is standalone function +// svd operation, this function is not method of SVD class, it is standalone +// function template -static void svd_(const NDArray* x, const std::vector& outArrs, const bool fullUV, const bool calcUV, const int switchNum) { - - auto s = outArrs[0]; - auto u = outArrs[1]; - auto v = outArrs[2]; - - const int rank = x->rankOf(); - const int sRank = rank - 1; - - auto listX = x->allTensorsAlongDimension({rank-2, rank-1}); - auto listS = s->allTensorsAlongDimension({sRank-1}); - ResultSet* listU(nullptr), *listV(nullptr); - - if(calcUV) { - listU = new ResultSet(u->allTensorsAlongDimension({rank-2, rank-1})); - listV = new ResultSet(v->allTensorsAlongDimension({rank-2, rank-1})); - } - - for(int i = 0; i < listX.size(); ++i) { - - // NDArray matrix(x->ordering(), {listX.at(i)->sizeAt(0), listX.at(i)->sizeAt(1)}, block.getContext()); - // matrix.assign(listX.at(i)); - helpers::SVD svdObj(listX.at(i), switchNum, calcUV, calcUV, fullUV); - listS.at(i).assign(svdObj._s); - - if(calcUV) { - listU->at(i).assign(svdObj._u); - listV->at(i).assign(svdObj._v); - } - } - - if(calcUV) { - delete listU; - delete listV; - } +static void svd_(const NDArray* x, const std::vector& outArrs, + const bool fullUV, const bool calcUV, const int switchNum) { + auto s = outArrs[0]; + auto u = outArrs[1]; + auto v = outArrs[2]; + + const int rank = x->rankOf(); + const int sRank = rank - 1; + + auto listX = x->allTensorsAlongDimension({rank - 2, rank - 1}); + auto listS = s->allTensorsAlongDimension({sRank - 1}); + ResultSet *listU(nullptr), *listV(nullptr); + + if (calcUV) { + listU = new ResultSet(u->allTensorsAlongDimension({rank - 2, rank - 1})); + listV = new ResultSet(v->allTensorsAlongDimension({rank - 2, rank - 1})); + } + + for (int i = 0; i < listX.size(); ++i) { + // NDArray matrix(x->ordering(), {listX.at(i)->sizeAt(0), + // listX.at(i)->sizeAt(1)}, block.getContext()); matrix.assign(listX.at(i)); + helpers::SVD svdObj(listX.at(i), switchNum, calcUV, calcUV, fullUV); + listS.at(i).assign(svdObj._s); + + if (calcUV) { + listU->at(i).assign(svdObj._u); + listV->at(i).assign(svdObj._v); + } + } + + if (calcUV) { + delete listU; + delete listV; + } } - void svd(sd::LaunchContext * context, const NDArray* x, const std::vector& outArrs, const bool fullUV, const bool calcUV, const int switchNum) { - BUILD_SINGLE_SELECTOR(x->dataType(), svd_, (x, outArrs, fullUV, calcUV, switchNum), FLOAT_TYPES); - } - - -} -} +void svd(sd::LaunchContext* context, const NDArray* x, + const std::vector& outArrs, const bool fullUV, + const bool calcUV, const int switchNum) { + BUILD_SINGLE_SELECTOR(x->dataType(), svd_, + (x, outArrs, fullUV, calcUV, switchNum), FLOAT_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/tile.cpp b/libnd4j/include/ops/declarable/helpers/cpu/tile.cpp index 4edb9e2a0142..aabc0c56882b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/tile.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/tile.cpp @@ -18,74 +18,68 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include -#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// template -static void tileBP_(const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector reps) { - - T* gradIBuff = reinterpret_cast(gradI.buffer()); - auto gradOBuff = reinterpret_cast(gradO.buffer()); - const Nd4jLong gradILen = gradI.lengthOf(); - const Nd4jLong gradOLen = gradO.lengthOf(); // gradOLen >= gradILen - const Nd4jLong gradIEWS = sd::math::nd4j_abs(gradI.ews()); - const Nd4jLong gradOEWS = gradO.ews(); - - // initial zeroing of gradI content - if(gradIEWS == 1) - memset(gradIBuff, 0, gradILen * sizeof(T)); - else { - //PRAGMA_OMP_PARALLEL_FOR_SIMD - for (Nd4jLong i = 0; i < gradILen * gradIEWS; i += gradIEWS) - gradIBuff[i] = static_cast(0.f); - } - - - if(gradO.ordering() == 'c' && gradOEWS == 1) { - - //PRAGMA_OMP_PARALLEL_FOR_SIMD - for(Nd4jLong i=0; i(idx) + gradOBuff[i]); - } +static void tileBP_(const NDArray& gradO /*input*/, NDArray& gradI /*output*/, + const std::vector reps) { + T* gradIBuff = reinterpret_cast(gradI.buffer()); + auto gradOBuff = reinterpret_cast(gradO.buffer()); + const Nd4jLong gradILen = gradI.lengthOf(); + const Nd4jLong gradOLen = gradO.lengthOf(); // gradOLen >= gradILen + const Nd4jLong gradIEWS = sd::math::nd4j_abs(gradI.ews()); + const Nd4jLong gradOEWS = gradO.ews(); + + // initial zeroing of gradI content + if (gradIEWS == 1) + memset(gradIBuff, 0, gradILen * sizeof(T)); + else { + // PRAGMA_OMP_PARALLEL_FOR_SIMD + for (Nd4jLong i = 0; i < gradILen * gradIEWS; i += gradIEWS) + gradIBuff[i] = static_cast(0.f); + } + + if (gradO.ordering() == 'c' && gradOEWS == 1) { + // PRAGMA_OMP_PARALLEL_FOR_SIMD + for (Nd4jLong i = 0; i < gradOLen; ++i) { + auto idx = shape::subArrayIndex(i, gradO.shapeInfo(), gradI.shapeInfo()); + gradI.p(idx, gradI.e(idx) + gradOBuff[i]); } - else if(gradO.ordering() == 'c' && gradOEWS > 1) { - - //PRAGMA_OMP_PARALLEL_FOR_SIMD - for(Nd4jLong i=0; i(idx) + gradOBuff[i * gradOEWS]); - } + } else if (gradO.ordering() == 'c' && gradOEWS > 1) { + // PRAGMA_OMP_PARALLEL_FOR_SIMD + for (Nd4jLong i = 0; i < gradOLen; ++i) { + auto idx = shape::subArrayIndex(i, gradO.shapeInfo(), gradI.shapeInfo()); + gradI.p(idx, gradI.e(idx) + gradOBuff[i * gradOEWS]); } - else { - - //PRAGMA_OMP_PARALLEL_FOR_SIMD - for(Nd4jLong i=0; i(fidx) + gradOBuff[shape::getIndexOffset(i, gradO.shapeInfo())]); - } + } else { + // PRAGMA_OMP_PARALLEL_FOR_SIMD + for (Nd4jLong i = 0; i < gradOLen; ++i) { + auto fidx = shape::subArrayIndex(i, gradO.shapeInfo(), gradI.shapeInfo()); + gradI.p(fidx, gradI.e(fidx) + + gradOBuff[shape::getIndexOffset(i, gradO.shapeInfo())]); } + } } -void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector reps) { - BUILD_SINGLE_SELECTOR(gradI.dataType(), tileBP_, (gradO, gradI, reps), FLOAT_TYPES); +void tileBP(sd::LaunchContext* context, const NDArray& gradO /*input*/, + NDArray& gradI /*output*/, const std::vector reps) { + BUILD_SINGLE_SELECTOR(gradI.dataType(), tileBP_, (gradO, gradI, reps), + FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template void tileBP_, + (const NDArray& gradO /*input*/, + NDArray& gradI /*output*/, + const std::vector reps), + FLOAT_TYPES); -BUILD_SINGLE_TEMPLATE(template void tileBP_, (const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector reps), FLOAT_TYPES); - - - - - -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp b/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp index 67b4e0f776d8..e5ab75571229 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp @@ -18,24 +18,22 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - void toggle_bits__(NDArray &in, NDArray &out) { - auto lambda = LAMBDA_T(_x) { - return BitwiseUtils::flip_bits(_x); - }; +namespace ops { +namespace helpers { +template +void toggle_bits__(NDArray& in, NDArray& out) { + auto lambda = LAMBDA_T(_x) { return BitwiseUtils::flip_bits(_x); }; - in.applyLambda(lambda, out); - } + in.applyLambda(lambda, out); +} - void __toggle_bits(sd::LaunchContext * context, NDArray& in, NDArray& out) { - BUILD_SINGLE_SELECTOR(in.dataType(), toggle_bits__, (in, out), INTEGER_TYPES); - } - } - } -} \ No newline at end of file +void __toggle_bits(sd::LaunchContext* context, NDArray& in, NDArray& out) { + BUILD_SINGLE_SELECTOR(in.dataType(), toggle_bits__, (in, out), INTEGER_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp b/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp index fdab43261d87..39cd5efe0649 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/top_k.cpp @@ -18,167 +18,186 @@ // @author raver119@gmail.com // -#include -#include #include #include +#include +#include namespace sd { namespace ops { namespace helpers { - template - static int topKFunctor_(const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) { - Nd4jLong width = input->sizeAt(-1); - int lastDim = input->rankOf() - 1; -// ----------------------------------------------------------------------------------------------- // -// this assumption is right: -// if (values->lengthOf() != k * lastDimList->size()) { -// nd4j_printf("top_k: something is wrong. %i expected, but %i given.\n", -// values->lengthOf(), k * lastDimList->size()); -// } -// ----------------------------------------------------------------------------------------------- // - std::vector dimsToExclude(input->rankOf() - 1); - for (size_t d = 0; d < dimsToExclude.size(); ++d) - dimsToExclude[d] = d; - - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input->shapeInfo(), dimsToExclude); - - if (k == 1) { - for (Nd4jLong e = 0; e < numOfSubArrs; ++e) { - auto trial = (*input)(e, dimsToExclude); - //int maxPos = //lastDimList->at(e)->argMax(); - Nd4jLong maxPos = 0; - //trial.printIndexedBuffer("TRIAL:"); - T maxVal = trial.e(0); - for (Nd4jLong pos = 1; pos < trial.lengthOf(); pos++) - if (maxVal < trial.e(pos)) { - maxPos = pos; - maxVal = trial.e(pos); - } - if (indices) - indices->p(e, maxPos); //topIndex; - if (values) - values->p(e, maxVal); - } - } - else { - int nextPos = 0; - - for (Nd4jLong e = 0; e < numOfSubArrs; ++e) { - auto trial = (*input)(e, dimsToExclude); - - // fill up the first k elements - NDArray topValues = NDArrayFactory::create('c', {k}, input->getContext()); - NDArray sortedVals = NDArrayFactory::create('c', {k}, input->getContext()); - NDArray topIndices = NDArrayFactory::create('c', {k}, input->getContext()); - for (uint pos = 0; pos < k; ++pos) { - topIndices.t(pos) = pos; - topValues.t(pos) = trial.t(pos); - } - //std::vector sortedVals(topValues); - sortedVals.assign(topValues);// = NDArrayFactory::create('c', {k}); - //std::sort(sortedVals.begin(), sortedVals.end()); // sorted in ascending order - SpecialMethods::sortGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), false); - for (Nd4jLong i = static_cast(k); i < width; ++i) { - T val = trial.e(i); - T minTopVal = sortedVals.t(0); - if (minTopVal < val) { // value should be inserted to top k - // only if it is not contained in - T* begin = reinterpret_cast(sortedVals.buffer()); - T* end = begin + k; - bool exists = std::binary_search(begin, end, val); - if (!exists) { - //exchangePos - a distance between begin and minimal existed to be suppressed by val - T* topBegin = reinterpret_cast(topValues.buffer()); - T* topEnd = topBegin + k; - auto exchangePos = std::distance(topBegin, std::find(topBegin, topEnd, sortedVals.t(0))); - topValues.t(exchangePos) = val; //*exchangeIt = val; - topIndices.t(exchangePos) = i; - sortedVals.t(0) = val; // suppress in sorted - //std::sort(sortedVals.begin(), sortedVals.end()); // sorted in ascending order - SpecialMethods::sortGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), false); - } - } - } - if (needSort) { - SpecialMethods::sortGeneric(topValues.buffer(), topValues.shapeInfo(), true); - - for (Nd4jLong j = 0; j < width; j++) - for (uint pos = 0; pos < k; ++pos) - if (topValues.t(pos) == trial.t(j)) - topIndices.t(pos) = j; - } - else { // else sort by indices - std::map sortValsMap; - //std::vector> data(topValues.lengthOf()); - for (Nd4jLong e = 0; e < topValues.lengthOf(); ++e) { - sortValsMap[topIndices.t(e)] = topValues.t(e); - } - - //std::sort(data.begin(), data.end(), [](std::pair const& a, std::pair const& b) { - // return a.first < b.first; - //}); - Nd4jLong e = 0; - for (auto it = sortValsMap.begin(); it != sortValsMap.end(); ++it, e++) { - topIndices.t(e) = it->first; - topValues.t(e) = it->second; - } - - } - if (values) - (*values)(e, dimsToExclude).assign(topValues); - if (indices) - (*indices)(e, dimsToExclude).assign(topIndices); - } - //indices->printIndexedBuffer("Indices as is"); +template +static int topKFunctor_(const NDArray* input, NDArray* values, NDArray* indices, + const uint k, bool needSort) { + Nd4jLong width = input->sizeAt(-1); + int lastDim = input->rankOf() - 1; + // ----------------------------------------------------------------------------------------------- + // // this assumption is right: + // if (values->lengthOf() != k * lastDimList->size()) { + // nd4j_printf("top_k: something is wrong. %i expected, but %i + // given.\n", + // values->lengthOf(), k * lastDimList->size()); + // } + // ----------------------------------------------------------------------------------------------- + // // + std::vector dimsToExclude(input->rankOf() - 1); + for (size_t d = 0; d < dimsToExclude.size(); ++d) dimsToExclude[d] = d; + + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(input->shapeInfo(), dimsToExclude); + + if (k == 1) { + for (Nd4jLong e = 0; e < numOfSubArrs; ++e) { + auto trial = (*input)(e, dimsToExclude); + // int maxPos = //lastDimList->at(e)->argMax(); + Nd4jLong maxPos = 0; + // trial.printIndexedBuffer("TRIAL:"); + T maxVal = trial.e(0); + for (Nd4jLong pos = 1; pos < trial.lengthOf(); pos++) + if (maxVal < trial.e(pos)) { + maxPos = pos; + maxVal = trial.e(pos); } - return Status::OK(); - } -// ----------------------------------------------------------------------------------------------- // - - template - static int inTopKFunctor_(sd::LaunchContext* context, const NDArray* input, const NDArray* target, NDArray* result, const uint k) { - - std::vector shapeI(input->rankOf()); - for (int i = 0; i < input->rankOf() - 1; i++) - shapeI[i] = input->sizeAt(i); - shapeI[input->rankOf() - 1] = k; - std::unique_ptr indices(NDArrayFactory::create_(input->ordering(), shapeI, context)); - NDArray* values = nullptr; - int status = topKFunctor(context, input, values, indices.get(), k, true); - result->assign(0); - if (status == ND4J_STATUS_OK) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - bool found = false; - for (uint j = 0; j < k; j++) { - if (target->e(e) == indices->e(e * k + j)) { - found = true; - break; - } - } - if (found) - result->p(e, true); - } - }; - - samediff::Threads::parallel_tad(func, 0, target->lengthOf()); - } - return status; - + if (indices) indices->p(e, maxPos); // topIndex; + if (values) values->p(e, maxVal); } - - int topKFunctor(sd::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) { - BUILD_SINGLE_SELECTOR(input->dataType(), return topKFunctor_, (input, values, indices, k, needSort), NUMERIC_TYPES); + } else { + int nextPos = 0; + + for (Nd4jLong e = 0; e < numOfSubArrs; ++e) { + auto trial = (*input)(e, dimsToExclude); + + // fill up the first k elements + NDArray topValues = + NDArrayFactory::create('c', {k}, input->getContext()); + NDArray sortedVals = + NDArrayFactory::create('c', {k}, input->getContext()); + NDArray topIndices = + NDArrayFactory::create('c', {k}, input->getContext()); + for (uint pos = 0; pos < k; ++pos) { + topIndices.t(pos) = pos; + topValues.t(pos) = trial.t(pos); + } + // std::vector sortedVals(topValues); + sortedVals.assign(topValues); // = NDArrayFactory::create('c', {k}); + // std::sort(sortedVals.begin(), sortedVals.end()); // sorted in ascending + // order + SpecialMethods::sortGeneric(sortedVals.buffer(), + sortedVals.shapeInfo(), false); + for (Nd4jLong i = static_cast(k); i < width; ++i) { + T val = trial.e(i); + T minTopVal = sortedVals.t(0); + if (minTopVal < val) { // value should be inserted to top k + // only if it is not contained in + T* begin = reinterpret_cast(sortedVals.buffer()); + T* end = begin + k; + bool exists = std::binary_search(begin, end, val); + if (!exists) { + // exchangePos - a distance between begin and minimal existed to be + // suppressed by val + T* topBegin = reinterpret_cast(topValues.buffer()); + T* topEnd = topBegin + k; + auto exchangePos = std::distance( + topBegin, std::find(topBegin, topEnd, sortedVals.t(0))); + topValues.t(exchangePos) = val; //*exchangeIt = val; + topIndices.t(exchangePos) = i; + sortedVals.t(0) = val; // suppress in sorted + // std::sort(sortedVals.begin(), sortedVals.end()); // sorted in + // ascending order + SpecialMethods::sortGeneric(sortedVals.buffer(), + sortedVals.shapeInfo(), false); + } + } + } + if (needSort) { + SpecialMethods::sortGeneric(topValues.buffer(), + topValues.shapeInfo(), true); + + for (Nd4jLong j = 0; j < width; j++) + for (uint pos = 0; pos < k; ++pos) + if (topValues.t(pos) == trial.t(j)) + topIndices.t(pos) = j; + } else { // else sort by indices + std::map sortValsMap; + // std::vector> data(topValues.lengthOf()); + for (Nd4jLong e = 0; e < topValues.lengthOf(); ++e) { + sortValsMap[topIndices.t(e)] = topValues.t(e); } - int inTopKFunctor(sd::LaunchContext * context, const NDArray* input, const NDArray* target, NDArray* result, const uint k) { - BUILD_SINGLE_SELECTOR(input->dataType(), return inTopKFunctor_, (context, input, target, result, k), NUMERIC_TYPES); + // std::sort(data.begin(), data.end(), [](std::pair const& a, + // std::pair const& b) { + // return a.first < b.first; + //}); + Nd4jLong e = 0; + for (auto it = sortValsMap.begin(); it != sortValsMap.end(); + ++it, e++) { + topIndices.t(e) = it->first; + topValues.t(e) = it->second; } + } + if (values) (*values)(e, dimsToExclude).assign(topValues); + if (indices) (*indices)(e, dimsToExclude).assign(topIndices); + } + // indices->printIndexedBuffer("Indices as is"); + } + return Status::OK(); +} +// ----------------------------------------------------------------------------------------------- +// // + +template +static int inTopKFunctor_(sd::LaunchContext* context, const NDArray* input, + const NDArray* target, NDArray* result, + const uint k) { + std::vector shapeI(input->rankOf()); + for (int i = 0; i < input->rankOf() - 1; i++) shapeI[i] = input->sizeAt(i); + shapeI[input->rankOf() - 1] = k; + std::unique_ptr indices( + NDArrayFactory::create_(input->ordering(), shapeI, context)); + NDArray* values = nullptr; + int status = topKFunctor(context, input, values, indices.get(), k, true); + result->assign(0); + if (status == ND4J_STATUS_OK) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + bool found = false; + for (uint j = 0; j < k; j++) { + if (target->e(e) == indices->e(e * k + j)) { + found = true; + break; + } + } + if (found) result->p(e, true); + } + }; - BUILD_SINGLE_TEMPLATE(template int topKFunctor_, (const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort), NUMERIC_TYPES); - BUILD_SINGLE_TEMPLATE(template int inTopKFunctor_, (sd::LaunchContext * context, const NDArray* input, const NDArray* target, NDArray* result, const uint k), NUMERIC_TYPES); + samediff::Threads::parallel_tad(func, 0, target->lengthOf()); + } + return status; } + +int topKFunctor(sd::LaunchContext* context, const NDArray* input, + NDArray* values, NDArray* indices, const uint k, + bool needSort) { + BUILD_SINGLE_SELECTOR(input->dataType(), return topKFunctor_, + (input, values, indices, k, needSort), NUMERIC_TYPES); } + +int inTopKFunctor(sd::LaunchContext* context, const NDArray* input, + const NDArray* target, NDArray* result, const uint k) { + BUILD_SINGLE_SELECTOR(input->dataType(), return inTopKFunctor_, + (context, input, target, result, k), NUMERIC_TYPES); } + +BUILD_SINGLE_TEMPLATE(template int topKFunctor_, + (const NDArray* input, NDArray* values, NDArray* indices, + const uint k, bool needSort), + NUMERIC_TYPES); +BUILD_SINGLE_TEMPLATE(template int inTopKFunctor_, + (sd::LaunchContext * context, const NDArray* input, + const NDArray* target, NDArray* result, const uint k), + NUMERIC_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/trace.cpp b/libnd4j/include/ops/declarable/helpers/cpu/trace.cpp index c80829aabf34..d6353bd72ff5 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/trace.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/trace.cpp @@ -18,30 +18,30 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// template static void trace_(const NDArray& input, NDArray& output) { - const int inRank = input.rankOf(); - auto setOfSubArrs = input.allTensorsAlongDimension({inRank-2, inRank-1}); + const int inRank = input.rankOf(); + auto setOfSubArrs = input.allTensorsAlongDimension({inRank - 2, inRank - 1}); - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - output.p(i, setOfSubArrs.at(i).getTrace()); - }; - samediff::Threads::parallel_for(func, 0, setOfSubArrs.size()); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + output.p(i, setOfSubArrs.at(i).getTrace()); + }; + samediff::Threads::parallel_for(func, 0, setOfSubArrs.size()); } - void trace(sd::LaunchContext * context, const NDArray& input, NDArray& output) { - BUILD_SINGLE_SELECTOR(input.dataType(), trace_, (input, output), LIBND4J_TYPES); - } -} -} +void trace(sd::LaunchContext* context, const NDArray& input, NDArray& output) { + BUILD_SINGLE_SELECTOR(input.dataType(), trace_, (input, output), + LIBND4J_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp index 4bb09378930d..5ce9131b6165 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp @@ -17,127 +17,143 @@ // // @author GS // -#include +#include "../triangular_solve.h" + #include #include -#include "../triangular_solve.h" +#include namespace sd { namespace ops { namespace helpers { - /* - * lower triangular process for system of linear equations - * x_1 = b_1/a_1,1 - * x_2 = (b_2 - a_2,1 * x_1) / a_2,2 - * x_3 = (b_3 - a_3,1 * x_1 - a_3,2 * x_2) / a_3,3 - * ... - * x_M = (b_M - a_M,1 * x_1 - ... a_M,M-1 * x_M-1)/ a_M,M - * - * output == x - * a == leftInput - * b == rightInput - * - * */ - template - static void lowerTriangularSolve(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) { - auto rows = leftInput->rows(); - auto cols = rightInput->columns(); - //output->t(0,0) = rightInput->t(0,0) / leftInput->t(0,0); - for (Nd4jLong r = 0; r < rows; r++) { - for (Nd4jLong j = 0; j < cols; j++) { - auto sum = rightInput->t(r, j); - for (Nd4jLong c = 0; c < r; c++) { - sum -= leftInput->t(r, c) * output->t(c, j); - } - output->t(r, j) = sum / leftInput->t(r, r); - } - } +/* + * lower triangular process for system of linear equations + * x_1 = b_1/a_1,1 + * x_2 = (b_2 - a_2,1 * x_1) / a_2,2 + * x_3 = (b_3 - a_3,1 * x_1 - a_3,2 * x_2) / a_3,3 + * ... + * x_M = (b_M - a_M,1 * x_1 - ... a_M,M-1 * x_M-1)/ a_M,M + * + * output == x + * a == leftInput + * b == rightInput + * + * */ +template +static void lowerTriangularSolve(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool adjoint, + NDArray* output) { + auto rows = leftInput->rows(); + auto cols = rightInput->columns(); + // output->t(0,0) = rightInput->t(0,0) / leftInput->t(0,0); + for (Nd4jLong r = 0; r < rows; r++) { + for (Nd4jLong j = 0; j < cols; j++) { + auto sum = rightInput->t(r, j); + for (Nd4jLong c = 0; c < r; c++) { + sum -= leftInput->t(r, c) * output->t(c, j); + } + output->t(r, j) = sum / leftInput->t(r, r); } + } +} - /* - * upper triangular process for system of linear equations - * x_M = b_M/a_M,M - * x_M-1 = (b_M-1 - a_M-1,M-2 * x_M) / a_M-1,M-1 - * x_M-2 = (b_M-2 - a_M-2,M-3 * x_M-2 - a_M-2,M-1 * x_M) / a_3,3 - * ... - * x_1 = (b_1 - a_1,2 * x_2 - ... a_1,M * x_M)/ a_1,1 - * - * output == x - * a == leftInput - * b == rightInput - * - * */ +/* + * upper triangular process for system of linear equations + * x_M = b_M/a_M,M + * x_M-1 = (b_M-1 - a_M-1,M-2 * x_M) / a_M-1,M-1 + * x_M-2 = (b_M-2 - a_M-2,M-3 * x_M-2 - a_M-2,M-1 * x_M) / a_3,3 + * ... + * x_1 = (b_1 - a_1,2 * x_2 - ... a_1,M * x_M)/ a_1,1 + * + * output == x + * a == leftInput + * b == rightInput + * + * */ - template - static void upperTriangularSolve(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) { - auto rows = leftInput->rows(); - auto cols = rightInput->columns(); - for (Nd4jLong r = rows; r > 0; r--) { - for (Nd4jLong j = 0; j < cols; j++) { - auto sum = rightInput->t(r - 1, j); - for (Nd4jLong c = r; c < rows; c++) { - sum -= leftInput->t(r - 1, c) * output->t(c, j); - } - output->t(r - 1, j) = sum / leftInput->t(r - 1, r - 1); - } - } +template +static void upperTriangularSolve(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool adjoint, + NDArray* output) { + auto rows = leftInput->rows(); + auto cols = rightInput->columns(); + for (Nd4jLong r = rows; r > 0; r--) { + for (Nd4jLong j = 0; j < cols; j++) { + auto sum = rightInput->t(r - 1, j); + for (Nd4jLong c = r; c < rows; c++) { + sum -= leftInput->t(r - 1, c) * output->t(c, j); + } + output->t(r - 1, j) = sum / leftInput->t(r - 1, r - 1); } + } +} - template - static int triangularSolveFunctor_(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) { - auto leftPart = leftInput->allTensorsAlongDimension({-2, -1}); - auto rightPart = rightInput->allTensorsAlongDimension({-2, -1}); - auto outputPart = output->allTensorsAlongDimension({-2, -1}); - - auto batchLoop = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - if (lower) { - lowerTriangularSolve(context, &leftPart[i], &rightPart[i], adjoint, &outputPart[i]); - } else { - upperTriangularSolve(context, &leftPart[i], &rightPart[i], adjoint, &outputPart[i]); - } - } - }; - - samediff::Threads::parallel_tad(batchLoop, 0, leftPart.size(), 1); - - return Status::OK(); +template +static int triangularSolveFunctor_(sd::LaunchContext* context, + NDArray* leftInput, NDArray* rightInput, + bool lower, bool adjoint, NDArray* output) { + auto leftPart = leftInput->allTensorsAlongDimension({-2, -1}); + auto rightPart = rightInput->allTensorsAlongDimension({-2, -1}); + auto outputPart = output->allTensorsAlongDimension({-2, -1}); + auto batchLoop = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + if (lower) { + lowerTriangularSolve(context, &leftPart[i], &rightPart[i], adjoint, + &outputPart[i]); + } else { + upperTriangularSolve(context, &leftPart[i], &rightPart[i], adjoint, + &outputPart[i]); + } } - template - static void adjointTriangularMatrix_(sd::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) { - auto inputPart = input->allTensorsAlongDimension({-2, -1}); - auto outputPart = output->allTensorsAlongDimension({-2, -1}); - auto cols = input->sizeAt(-1); - auto rows = input->sizeAt(-2); + }; - auto batchLoop = PRAGMA_THREADS_FOR { - for (auto batch = start; batch < stop; batch++) { - if (!lower) { - for (Nd4jLong r = 0; r < rows; r++) { - for (Nd4jLong c = 0; c <= r; c++) { - outputPart[batch].t(r, c) = inputPart[batch].t(c, r); - } - } - } else { - for (Nd4jLong r = 0; r < rows; r++) { - for (Nd4jLong c = r; c < cols; c++) { - outputPart[batch].t(r, c) = inputPart[batch].t(c, r); - } - } - } - } - }; - samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1); - } + samediff::Threads::parallel_tad(batchLoop, 0, leftPart.size(), 1); - int triangularSolveFunctor(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) { - BUILD_SINGLE_SELECTOR(leftInput->dataType(), return triangularSolveFunctor_, (context, leftInput, rightInput, lower, adjoint, output), FLOAT_NATIVE); - } + return Status::OK(); +} +template +static void adjointTriangularMatrix_(sd::LaunchContext* context, + NDArray const* input, bool const lower, + NDArray* output) { + auto inputPart = input->allTensorsAlongDimension({-2, -1}); + auto outputPart = output->allTensorsAlongDimension({-2, -1}); + auto cols = input->sizeAt(-1); + auto rows = input->sizeAt(-2); - void adjointMatrix(sd::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), adjointTriangularMatrix_, (context, input, lower, output), FLOAT_NATIVE); + auto batchLoop = PRAGMA_THREADS_FOR { + for (auto batch = start; batch < stop; batch++) { + if (!lower) { + for (Nd4jLong r = 0; r < rows; r++) { + for (Nd4jLong c = 0; c <= r; c++) { + outputPart[batch].t(r, c) = inputPart[batch].t(c, r); + } + } + } else { + for (Nd4jLong r = 0; r < rows; r++) { + for (Nd4jLong c = r; c < cols; c++) { + outputPart[batch].t(r, c) = inputPart[batch].t(c, r); + } + } + } } + }; + samediff::Threads::parallel_tad(batchLoop, 0, inputPart.size(), 1); } + +int triangularSolveFunctor(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool lower, bool adjoint, + NDArray* output) { + BUILD_SINGLE_SELECTOR( + leftInput->dataType(), return triangularSolveFunctor_, + (context, leftInput, rightInput, lower, adjoint, output), FLOAT_NATIVE); } + +void adjointMatrix(sd::LaunchContext* context, NDArray const* input, + bool const lower, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjointTriangularMatrix_, + (context, input, lower, output), FLOAT_NATIVE); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/triu.cpp b/libnd4j/include/ops/declarable/helpers/cpu/triu.cpp index 4194e976c37d..1b9a1b5a5ea4 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/triu.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/triu.cpp @@ -18,39 +18,41 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// template -static void triuBP_(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) { - - auto dOdI = NDArray(&gradO); // dO/dI - const_cast(input).fillAsTriangular(0, diagonal, dOdI.sizeAt(-1), dOdI, 'b'); - int dLen = dOdI.lengthOf(); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - if (dOdI.t(i) != static_cast(0.f)) - dOdI.t(i) = static_cast(1.f); - } - }; - samediff::Threads::parallel_for(func, 0, dLen); - - // FIXME: !!! - gradI.assign(dOdI * gradO); // chain rule: dLoss/dI = dO/dI * dLoss/dO -} - - void triuBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) { - BUILD_SINGLE_SELECTOR(gradO.dataType(), triuBP_, (context, input, gradO, gradI, diagonal), LIBND4J_TYPES); +static void triuBP_(sd::LaunchContext* context, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int diagonal) { + auto dOdI = NDArray(&gradO); // dO/dI + const_cast(input).fillAsTriangular(0, diagonal, dOdI.sizeAt(-1), + dOdI, 'b'); + int dLen = dOdI.lengthOf(); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + if (dOdI.t(i) != static_cast(0.f)) + dOdI.t(i) = static_cast(1.f); } + }; + samediff::Threads::parallel_for(func, 0, dLen); + // FIXME: !!! + gradI.assign(dOdI * gradO); // chain rule: dLoss/dI = dO/dI * dLoss/dO } + +void triuBP(sd::LaunchContext* context, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int diagonal) { + BUILD_SINGLE_SELECTOR(gradO.dataType(), triuBP_, + (context, input, gradO, gradI, diagonal), + LIBND4J_TYPES); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp index 78268b2dc12e..118c8fba8f89 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp @@ -18,10 +18,10 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { @@ -29,80 +29,106 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template -static void adaDeltaUpdater_(const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, - NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) { - - const T* grad = gradient.bufferAsT(); - const T* initMsg = initStateMsg.bufferAsT(); - const T* initMsdx = initStateMsdx.bufferAsT(); - - T* up = update.bufferAsT(); - T* stMsg = stateMsg.bufferAsT(); - T* stMsdx = stateMsdx.bufferAsT(); - - const T rho = static_cast(dRho); - const T epsilon = static_cast(dEpsilon); - const T rhoT = (1 - rho); - - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateMsg.ews() && 1 == initStateMsg.ews() && 1 == stateMsdx.ews() && 1 == initStateMsdx.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && - update.ordering() == stateMsdx.ordering() && - stateMsdx.ordering() == initStateMsdx.ordering() && - stateMsdx.ordering() == initStateMsg.ordering() && stateMsg.ordering() == initStateMsg.ordering(); - - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - stMsg[i] = rho * initMsg[i] + grad[i] * grad[i] * rhoT; - - up[i] = grad[i] * (sd::math::nd4j_sqrt(initMsdx[i] + epsilon) / sd::math::nd4j_sqrt(stMsg[i] + epsilon)); - - stMsdx[i] = rho * initMsdx[i] + up[i] * up[i] * rhoT; - } - }; - - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInMsgSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateMsg.shapeInfo()); - bool bXStMsgSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateMsg.shapeInfo()); - bool bXInMsdxSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateMsdx.shapeInfo()); - bool bXStMsdxSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateMsdx.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto i = start; i < gradient.lengthOf(); i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initMsgOffset = bXInMsgSame ? xOffset : shape::getOffset(initStateMsg.shapeInfo(), coords); - const auto stMsgOffset = bXStMsgSame ? xOffset : shape::getOffset(stateMsg.shapeInfo(), coords); - const auto initMsdxOffset = bXInMsdxSame ? xOffset : shape::getOffset(initStateMsdx.shapeInfo(), coords); - const auto stMsdxOffset = bXStMsdxSame ? xOffset : shape::getOffset(stateMsdx.shapeInfo(), coords); - - - stMsg[stMsgOffset] = rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT; - - up[zOffset] = grad[xOffset] * (sd::math::nd4j_sqrt(initMsdx[initMsdxOffset] + epsilon) / sd::math::nd4j_sqrt(stMsg[stMsgOffset] + epsilon)); - - stMsdx[stMsdxOffset] = rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT; - } +static void adaDeltaUpdater_(const NDArray& gradient, + const NDArray& initStateMsg, + const NDArray& initStateMsdx, NDArray& update, + NDArray& stateMsg, NDArray& stateMsdx, + const double dRho, const double dEpsilon) { + const T* grad = gradient.bufferAsT(); + const T* initMsg = initStateMsg.bufferAsT(); + const T* initMsdx = initStateMsdx.bufferAsT(); + + T* up = update.bufferAsT(); + T* stMsg = stateMsg.bufferAsT(); + T* stMsdx = stateMsdx.bufferAsT(); + + const T rho = static_cast(dRho); + const T epsilon = static_cast(dEpsilon); + const T rhoT = (1 - rho); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && + 1 == stateMsg.ews() && 1 == initStateMsg.ews() && + 1 == stateMsdx.ews() && 1 == initStateMsdx.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateMsdx.ordering() && + stateMsdx.ordering() == initStateMsdx.ordering() && + stateMsdx.ordering() == initStateMsg.ordering() && + stateMsg.ordering() == initStateMsg.ordering(); + + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + stMsg[i] = rho * initMsg[i] + grad[i] * grad[i] * rhoT; + + up[i] = grad[i] * (sd::math::nd4j_sqrt(initMsdx[i] + epsilon) / + sd::math::nd4j_sqrt(stMsg[i] + epsilon)); + + stMsdx[i] = rho * initMsdx[i] + up[i] * up[i] * rhoT; + } }; samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); return; -} + } + + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInMsgSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateMsg.shapeInfo()); + bool bXStMsgSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + stateMsg.shapeInfo()); + bool bXInMsdxSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateMsdx.shapeInfo()); + bool bXStMsdxSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + stateMsdx.shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < gradient.lengthOf(); i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initMsgOffset = + bXInMsgSame ? xOffset + : shape::getOffset(initStateMsg.shapeInfo(), coords); + const auto stMsgOffset = + bXStMsgSame ? xOffset + : shape::getOffset(stateMsg.shapeInfo(), coords); + const auto initMsdxOffset = + bXInMsdxSame ? xOffset + : shape::getOffset(initStateMsdx.shapeInfo(), coords); + const auto stMsdxOffset = + bXStMsdxSame ? xOffset + : shape::getOffset(stateMsdx.shapeInfo(), coords); + + stMsg[stMsgOffset] = + rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT; + + up[zOffset] = + grad[xOffset] * + (sd::math::nd4j_sqrt(initMsdx[initMsdxOffset] + epsilon) / + sd::math::nd4j_sqrt(stMsg[stMsgOffset] + epsilon)); + + stMsdx[stMsdxOffset] = + rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT; + } + }; -void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, - NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), adaDeltaUpdater_, (gradient, initStateMsg, initStateMsdx, update, stateMsg, stateMsdx, dRho, dEpsilon), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } +void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateMsg, const NDArray& initStateMsdx, + NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, + const double dRho, const double dEpsilon) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaDeltaUpdater_, + (gradient, initStateMsg, initStateMsdx, update, + stateMsg, stateMsdx, dRho, dEpsilon), + FLOAT_TYPES); } -} -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp index e65f34e72bf6..9f80f59dd666 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp @@ -18,10 +18,10 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { @@ -29,63 +29,75 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template -static void adaGradUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon) { - - const T* grad = gradient.bufferAsT(); - const T* init = initState.bufferAsT(); - - T* up = update.bufferAsT(); - T* st = stateH.bufferAsT(); - - const T lr = static_cast(dLr); - const T epsilon = static_cast(dEpsilon); - - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateH.ews() && 1 == initState.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateH.ordering() && stateH.ordering() == initState.ordering(); - - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - st[i] = init[i] + grad[i] * grad[i]; - up[i] = (lr * grad[i]) / (math::nd4j_sqrt(st[i]) + epsilon); - } - }; - - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initState.shapeInfo()); - bool bXStSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateH.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.shapeInfo(), coords); - const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateH.shapeInfo(), coords); - - st[stOffset] = init[initOffset] + grad[xOffset] * grad[xOffset]; - up[zOffset] = (lr * grad[xOffset]) / (math::nd4j_sqrt(st[stOffset]) + epsilon); - } +static void adaGradUpdater_(const NDArray& gradient, const NDArray& initState, + NDArray& update, NDArray& stateH, const double dLr, + const double dEpsilon) { + const T* grad = gradient.bufferAsT(); + const T* init = initState.bufferAsT(); + + T* up = update.bufferAsT(); + T* st = stateH.bufferAsT(); + + const T lr = static_cast(dLr); + const T epsilon = static_cast(dEpsilon); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateH.ews() && + 1 == initState.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateH.ordering() && + stateH.ordering() == initState.ordering(); + + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + st[i] = init[i] + grad[i] * grad[i]; + up[i] = (lr * grad[i]) / (math::nd4j_sqrt(st[i]) + epsilon); + } }; samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); return; -} + } + + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initState.shapeInfo()); + bool bXStSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateH.shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initOffset = + bXInSame ? xOffset : shape::getOffset(initState.shapeInfo(), coords); + const auto stOffset = + bXStSame ? xOffset : shape::getOffset(stateH.shapeInfo(), coords); + + st[stOffset] = init[initOffset] + grad[xOffset] * grad[xOffset]; + up[zOffset] = (lr * grad[xOffset]) / + (math::nd4j_sqrt(st[stOffset]) + epsilon); + } + }; -void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, - const double dLr, const double dEpsilon) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), adaGradUpdater_, (gradient, initState, update, stateH, dLr, dEpsilon), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } +void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, NDArray& stateH, + const double dLr, const double dEpsilon) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaGradUpdater_, + (gradient, initState, update, stateH, dLr, dEpsilon), + FLOAT_TYPES); } -} -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp index 6c7d0d3225f8..315f5997ae3f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp @@ -18,10 +18,10 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { @@ -29,85 +29,112 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template -static void adaMaxUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - const T* grad = gradient.bufferAsT(); - const T* initU = initStateU.bufferAsT(); - const T* initM = initStateM.bufferAsT(); - - T* up = update.bufferAsT(); - T* stU = stateU.bufferAsT(); - T* stM = stateM.bufferAsT(); - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - const T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); - T epsilonT = lr / (1.0 - beta1T); - if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) - epsilonT = epsilon; - - - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && - update.ordering() == stateU.ordering() && - stateU.ordering() == initStateU.ordering() && - stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering(); - - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - //m = B_1 * m + (1-B_1)*grad - stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1); - //u = max(B_2 * u, |grad|) - stU[i] = sd::math::nd4j_max((beta2 * initU[i]), sd::math::nd4j_abs(grad[i])) + 1e-32; - - up[i] = stM[i] * epsilonT / stU[i]; - } - }; - - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateU.shapeInfo()); - bool bXStVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateU.shapeInfo()); - bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateM.shapeInfo()); - bool bXStMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.shapeInfo(), coords); - const auto stUOffset = bXStVSame ? xOffset : shape::getOffset(stateU.shapeInfo(), coords); - const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.shapeInfo(), coords); - const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); - - //m = B_1 * m + (1-B_1)*grad - stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); - //u = max(B_2 * u, |grad|) - stU[stUOffset] = sd::math::nd4j_max((beta2 * initU[initUOffset]), sd::math::nd4j_abs(grad[xOffset])) + 1e-32; - - up[zOffset] = stM[stMOffset] * epsilonT / stU[stUOffset]; - } +static void adaMaxUpdater_(const NDArray& gradient, const NDArray& initStateU, + const NDArray& initStateM, NDArray& update, + NDArray& stateU, NDArray& stateM, const double dLr, + const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + const T* grad = gradient.bufferAsT(); + const T* initU = initStateU.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + + T* up = update.bufferAsT(); + T* stU = stateU.bufferAsT(); + T* stM = stateM.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + const T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); + T epsilonT = lr / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || + sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && + 1 == initStateM.ews() && 1 == stateU.ews() && + 1 == initStateU.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateU.ordering() && + stateU.ordering() == initStateU.ordering() && + stateU.ordering() == initStateM.ordering() && + stateM.ordering() == initStateM.ordering(); + + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + // m = B_1 * m + (1-B_1)*grad + stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1); + // u = max(B_2 * u, |grad|) + stU[i] = sd::math::nd4j_max((beta2 * initU[i]), + sd::math::nd4j_abs(grad[i])) + + 1e-32; + + up[i] = stM[i] * epsilonT / stU[i]; + } }; samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); return; -} + } + + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateU.shapeInfo()); + bool bXStVSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateU.shapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateM.shapeInfo()); + bool bXStMSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initUOffset = + bXInVSame ? xOffset + : shape::getOffset(initStateU.shapeInfo(), coords); + const auto stUOffset = + bXStVSame ? xOffset : shape::getOffset(stateU.shapeInfo(), coords); + const auto initMOffset = + bXInMSame ? xOffset + : shape::getOffset(initStateM.shapeInfo(), coords); + const auto stMOffset = + bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); + + // m = B_1 * m + (1-B_1)*grad + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + // u = max(B_2 * u, |grad|) + stU[stUOffset] = sd::math::nd4j_max((beta2 * initU[initUOffset]), + sd::math::nd4j_abs(grad[xOffset])) + + 1e-32; + + up[zOffset] = stM[stMOffset] * epsilonT / stU[stUOffset]; + } + }; -void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), adaMaxUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } +void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaMaxUpdater_, + (gradient, initStateU, initStateM, update, stateU, + stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), + FLOAT_TYPES); } -} -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp index 2d670949f748..2d2e5d1fe2a2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp @@ -18,10 +18,10 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { @@ -29,85 +29,110 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template -static void adamUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, - NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, - const double dEpsilon, const int nIteration) { - - const T* grad = gradient.bufferAsT(); - const T* initU = initStateU.bufferAsT(); - const T* initM = initStateM.bufferAsT(); - - T* up = update.bufferAsT(); - T* stU = stateU.bufferAsT(); - T* stM = stateM.bufferAsT(); - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - - const T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); - const T beta2T = sd::math::nd4j_pow(beta2, (iteration + 1)); - - T epsilonT = lr * sd::math::nd4j_sqrt(1. - beta2T) / (1.0 - beta1T); - if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) - epsilonT = epsilon; - - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && - update.ordering() == stateU.ordering() && - stateU.ordering() == initStateU.ordering() && - stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering(); - - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1); - stU[i] = beta2 * initU[i] + grad[i] * grad[i] * (1 - beta2); - - up[i] = (stM[i] * epsilonT) / (sd::math::nd4j_sqrt(stU[i]) + epsilon); - } - }; - - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateU.shapeInfo()); - bool bXStVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateU.shapeInfo()); - bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateM.shapeInfo()); - bool bXStMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.shapeInfo(), coords); - const auto stUOffset = bXStVSame ? xOffset : shape::getOffset(stateU.shapeInfo(), coords); - const auto initMOffset = bXInVSame ? xOffset : shape::getOffset(initStateM.shapeInfo(), coords); - const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); - - stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); - stU[stUOffset] = beta2 * initU[initUOffset] + grad[xOffset] * grad[xOffset] * (1 - beta2); - - up[zOffset] = (stM[stMOffset] * epsilonT) / (sd::math::nd4j_sqrt(stU[stUOffset]) + epsilon); - } +static void adamUpdater_(const NDArray& gradient, const NDArray& initStateU, + const NDArray& initStateM, NDArray& update, + NDArray& stateU, NDArray& stateM, const double dLr, + const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + const T* grad = gradient.bufferAsT(); + const T* initU = initStateU.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + + T* up = update.bufferAsT(); + T* stU = stateU.bufferAsT(); + T* stM = stateM.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + const T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); + const T beta2T = sd::math::nd4j_pow(beta2, (iteration + 1)); + + T epsilonT = lr * sd::math::nd4j_sqrt(1. - beta2T) / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || + sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && + 1 == initStateM.ews() && 1 == stateU.ews() && + 1 == initStateU.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateU.ordering() && + stateU.ordering() == initStateU.ordering() && + stateU.ordering() == initStateM.ordering() && + stateM.ordering() == initStateM.ordering(); + + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1); + stU[i] = beta2 * initU[i] + grad[i] * grad[i] * (1 - beta2); + + up[i] = + (stM[i] * epsilonT) / (sd::math::nd4j_sqrt(stU[i]) + epsilon); + } }; samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); return; -} + } + + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateU.shapeInfo()); + bool bXStVSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateU.shapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateM.shapeInfo()); + bool bXStMSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initUOffset = + bXInVSame ? xOffset + : shape::getOffset(initStateU.shapeInfo(), coords); + const auto stUOffset = + bXStVSame ? xOffset : shape::getOffset(stateU.shapeInfo(), coords); + const auto initMOffset = + bXInVSame ? xOffset + : shape::getOffset(initStateM.shapeInfo(), coords); + const auto stMOffset = + bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); + + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + stU[stUOffset] = beta2 * initU[initUOffset] + + grad[xOffset] * grad[xOffset] * (1 - beta2); + + up[zOffset] = (stM[stMOffset] * epsilonT) / + (sd::math::nd4j_sqrt(stU[stUOffset]) + epsilon); + } + }; -void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), adamUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } +void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adamUpdater_, + (gradient, initStateU, initStateM, update, stateU, + stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), + FLOAT_TYPES); } -} -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp index 7cb05075c24b..cf84bf49bf66 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp @@ -18,10 +18,10 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { @@ -29,98 +29,134 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template -static void amsGradUpdater_(const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, - NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - const T* grad = gradient.bufferAsT(); - const T* initV = initStateV.bufferAsT(); - const T* initM = initStateM.bufferAsT(); - const T* initH = initStateH.bufferAsT(); - - T* up = update.bufferAsT(); - T* stV = stateV.bufferAsT(); - T* stM = stateM.bufferAsT(); - T* stH = stateH.bufferAsT(); - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - - T epsilonT = lr * sd::math::nd4j_sqrt(1.0 - sd::math::nd4j_pow(beta2, (iteration + 1))) / (1.0 - sd::math::nd4j_pow(beta1, (iteration + 1))); - - if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) - epsilonT = epsilon; - - const T mbeta1 = (1 - beta1); - const T mbeta2 = (1 - beta2); - - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && - 1 == stateV.ews() && 1 == initStateV.ews() && 1 == stateH.ews() && 1 == initStateH.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && - update.ordering() == stateV.ordering() && - stateV.ordering() == initStateV.ordering() && - stateV.ordering() == initStateM.ordering() && - stateM.ordering() == initStateM.ordering() && - stateM.ordering() == initStateH.ordering() && stateH.ordering() == initStateH.ordering(); - - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - stM[i] = beta1 * initM[i] + grad[i] * mbeta1; - stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2; - stH[i] = sd::math::nd4j_max(initH[i], stV[i]); - - up[i] = epsilonT * stM[i] / (sd::math::nd4j_sqrt(stH[i]) + epsilon); - } - }; - - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateV.shapeInfo()); - bool bXStVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateV.shapeInfo()); - bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateM.shapeInfo()); - bool bXStMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); - bool bXInHSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateH.shapeInfo()); - bool bXStHSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateH.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initVOffset = bXInVSame ? xOffset : shape::getOffset(initStateV.shapeInfo(), coords); - const auto stVOffset = bXStVSame ? xOffset : shape::getOffset(stateV.shapeInfo(), coords); - const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.shapeInfo(), coords); - const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); - const auto initHOffset = bXInHSame ? xOffset : shape::getOffset(initStateH.shapeInfo(), coords); - const auto stHOffset = bXStHSame ? xOffset : shape::getOffset(stateH.shapeInfo(), coords); - - stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1; - stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; - stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]); - - up[zOffset] = epsilonT * stM[stMOffset] / (sd::math::nd4j_sqrt(stH[stHOffset]) + epsilon); - } +static void amsGradUpdater_(const NDArray& gradient, const NDArray& initStateV, + const NDArray& initStateM, + const NDArray& initStateH, NDArray& update, + NDArray& stateV, NDArray& stateM, NDArray& stateH, + const double dLr, const double dBeta1, + const double dBeta2, const double dEpsilon, + const int nIteration) { + const T* grad = gradient.bufferAsT(); + const T* initV = initStateV.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + const T* initH = initStateH.bufferAsT(); + + T* up = update.bufferAsT(); + T* stV = stateV.bufferAsT(); + T* stM = stateM.bufferAsT(); + T* stH = stateH.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + T epsilonT = lr * + sd::math::nd4j_sqrt( + 1.0 - sd::math::nd4j_pow(beta2, (iteration + 1))) / + (1.0 - sd::math::nd4j_pow(beta1, (iteration + 1))); + + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || + sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + const T mbeta1 = (1 - beta1); + const T mbeta2 = (1 - beta2); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && + 1 == initStateM.ews() && 1 == stateV.ews() && + 1 == initStateV.ews() && 1 == stateH.ews() && + 1 == initStateH.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateV.ordering() && + stateV.ordering() == initStateV.ordering() && + stateV.ordering() == initStateM.ordering() && + stateM.ordering() == initStateM.ordering() && + stateM.ordering() == initStateH.ordering() && + stateH.ordering() == initStateH.ordering(); + + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + stM[i] = beta1 * initM[i] + grad[i] * mbeta1; + stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2; + stH[i] = sd::math::nd4j_max(initH[i], stV[i]); + + up[i] = + epsilonT * stM[i] / (sd::math::nd4j_sqrt(stH[i]) + epsilon); + } }; samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); return; -} + } + + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateV.shapeInfo()); + bool bXStVSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateV.shapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateM.shapeInfo()); + bool bXStMSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); + bool bXInHSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateH.shapeInfo()); + bool bXStHSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateH.shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initVOffset = + bXInVSame ? xOffset + : shape::getOffset(initStateV.shapeInfo(), coords); + const auto stVOffset = + bXStVSame ? xOffset : shape::getOffset(stateV.shapeInfo(), coords); + const auto initMOffset = + bXInMSame ? xOffset + : shape::getOffset(initStateM.shapeInfo(), coords); + const auto stMOffset = + bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); + const auto initHOffset = + bXInHSame ? xOffset + : shape::getOffset(initStateH.shapeInfo(), coords); + const auto stHOffset = + bXStHSame ? xOffset : shape::getOffset(stateH.shapeInfo(), coords); + + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1; + stV[stVOffset] = + beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]); + + up[zOffset] = epsilonT * stM[stMOffset] / + (sd::math::nd4j_sqrt(stH[stHOffset]) + epsilon); + } + }; -void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, - NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), amsGradUpdater_, (gradient, initStateV, initStateM, initStateH, update, stateV, stateM, stateH, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } - -} -} +void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateV, const NDArray& initStateM, + const NDArray& initStateH, NDArray& update, NDArray& stateV, + NDArray& stateM, NDArray& stateH, const double dLr, + const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR( + gradient.dataType(), amsGradUpdater_, + (gradient, initStateV, initStateM, initStateH, update, stateV, stateM, + stateH, dLr, dBeta1, dBeta2, dEpsilon, nIteration), + FLOAT_TYPES); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp index 40f9c9407789..5d5278be2aff 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp @@ -18,10 +18,10 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { @@ -29,88 +29,111 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template -static void nadamUpdater_(const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, - NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, - const double dBeta2, const double dEpsilon, const int nIteration) { - - const T* grad = gradient.bufferAsT(); - const T* initV = initStateV.bufferAsT(); - const T* initM = initStateM.bufferAsT(); - - T* up = update.bufferAsT(); - T* stV = stateV.bufferAsT(); - T* stM = stateM.bufferAsT(); - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - - const T mbeta1T = 1.0 - sd::math::nd4j_pow(beta1, (iteration + 1)); - const T mbeta1 = (1 - beta1); - const T mbeta2 = (1 - beta2); - - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateV.ews() && 1 == initStateV.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && - update.ordering() == stateV.ordering() && - stateV.ordering() == initStateV.ordering() && - stateV.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering(); - - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - auto oneMinusBeta1Grad = grad[i] * mbeta1; - - stM[i] = beta1 * initM[i] + oneMinusBeta1Grad; - stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2; - - up[i] = (lr * ((stM[i] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt(stV[i]) + epsilon); - } - }; - - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateV.shapeInfo()); - bool bXStVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateV.shapeInfo()); - bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateM.shapeInfo()); - bool bXStMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR{ - - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initVOffset = bXInVSame ? xOffset : shape::getOffset(initStateV.shapeInfo(), coords); - const auto stVOffset = bXStVSame ? xOffset : shape::getOffset(stateV.shapeInfo(), coords); - const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.shapeInfo(), coords); - const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); - - auto oneMinusBeta1Grad = grad[xOffset] * mbeta1; - - stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad; - stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; - - up[zOffset] = (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt(stV[stVOffset]) + epsilon); - } +static void nadamUpdater_(const NDArray& gradient, const NDArray& initStateV, + const NDArray& initStateM, NDArray& update, + NDArray& stateV, NDArray& stateM, const double dLr, + const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + const T* grad = gradient.bufferAsT(); + const T* initV = initStateV.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + + T* up = update.bufferAsT(); + T* stV = stateV.bufferAsT(); + T* stM = stateM.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + const T mbeta1T = 1.0 - sd::math::nd4j_pow(beta1, (iteration + 1)); + const T mbeta1 = (1 - beta1); + const T mbeta2 = (1 - beta2); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && + 1 == initStateM.ews() && 1 == stateV.ews() && + 1 == initStateV.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateV.ordering() && + stateV.ordering() == initStateV.ordering() && + stateV.ordering() == initStateM.ordering() && + stateM.ordering() == initStateM.ordering(); + + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto oneMinusBeta1Grad = grad[i] * mbeta1; + + stM[i] = beta1 * initM[i] + oneMinusBeta1Grad; + stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2; + + up[i] = (lr * ((stM[i] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / + (sd::math::nd4j_sqrt(stV[i]) + epsilon); + } }; samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); return; -} + } + + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateV.shapeInfo()); + bool bXStVSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateV.shapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initStateM.shapeInfo()); + bool bXStMSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo()); + + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initVOffset = + bXInVSame ? xOffset + : shape::getOffset(initStateV.shapeInfo(), coords); + const auto stVOffset = + bXStVSame ? xOffset : shape::getOffset(stateV.shapeInfo(), coords); + const auto initMOffset = + bXInMSame ? xOffset + : shape::getOffset(initStateM.shapeInfo(), coords); + const auto stMOffset = + bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords); + + auto oneMinusBeta1Grad = grad[xOffset] * mbeta1; + + stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad; + stV[stVOffset] = + beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + + up[zOffset] = + (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / + (sd::math::nd4j_sqrt(stV[stVOffset]) + epsilon); + } + }; -void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, - NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), nadamUpdater_, (gradient, initStateV, initStateM, update, stateV, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } - -} -} +void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateV, const NDArray& initStateM, + NDArray& update, NDArray& stateV, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), nadamUpdater_, + (gradient, initStateV, initStateM, update, stateV, + stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), + FLOAT_TYPES); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp index 1d8bb8d45fb8..68b81826d881 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp @@ -18,10 +18,10 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { @@ -29,63 +29,76 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template -static void nesterovsUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) { +static void nesterovsUpdater_(const NDArray& gradient, const NDArray& initState, + NDArray& update, NDArray& stateV, + const double dLr, const double dMomentum) { + const T* grad = gradient.bufferAsT(); + const T* init = initState.bufferAsT(); - const T* grad = gradient.bufferAsT(); - const T* init = initState.bufferAsT(); + T* up = update.bufferAsT(); + T* st = stateV.bufferAsT(); - T* up = update.bufferAsT(); - T* st = stateV.bufferAsT(); + const T lr = static_cast(dLr); + const T momentum = static_cast(dMomentum); + const T momentumT = (-momentum - 1); - const T lr = static_cast(dLr); - const T momentum = static_cast(dMomentum); - const T momentumT = (-momentum - 1); + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateV.ews() && + 1 == initState.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateV.ordering() && + stateV.ordering() == initState.ordering(); - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateV.ews() && 1 == initState.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateV.ordering() && stateV.ordering() == initState.ordering(); + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + T prevState = momentum * init[i]; + st[i] = prevState - lr * grad[i]; + up[i] = prevState + momentumT * st[i]; + } + }; - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - T prevState = momentum * init[i]; - st[i] = prevState - lr * grad[i]; - up[i] = prevState + momentumT * st[i]; - } - }; + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initState.shapeInfo()); - bool bXStSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateV.shapeInfo()); + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initState.shapeInfo()); + bool bXStSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateV.shapeInfo()); - auto func = PRAGMA_THREADS_FOR{ + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initOffset = + bXInSame ? xOffset : shape::getOffset(initState.shapeInfo(), coords); + const auto stOffset = + bXStSame ? xOffset : shape::getOffset(stateV.shapeInfo(), coords); - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.shapeInfo(), coords); - const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateV.shapeInfo(), coords); - - T prevState = momentum * init[initOffset]; - st[stOffset] = prevState - lr * grad[xOffset]; - up[zOffset] = prevState + momentumT * st[stOffset]; - } - }; + T prevState = momentum * init[initOffset]; + st[stOffset] = prevState - lr * grad[xOffset]; + up[zOffset] = prevState + momentumT * st[stOffset]; + } + }; - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } -void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), nesterovsUpdater_, (gradient, initState, update, stateV, dLr, dMomentum), FLOAT_TYPES); +void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, + NDArray& stateV, const double dLr, + const double dMomentum) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), nesterovsUpdater_, + (gradient, initState, update, stateV, dLr, dMomentum), + FLOAT_TYPES); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp index 473b43cf8c03..8e2ed9d88837 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp @@ -18,74 +18,87 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include #include #include #include +#include namespace sd { namespace ops { namespace helpers { template -static void rmsPropUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, - const double dLr, const double dRmsDecay, const double dEpsilon) { - - const T* grad = gradient.bufferAsT(); - const T* init = initState.bufferAsT(); +static void rmsPropUpdater_(const NDArray& gradient, const NDArray& initState, + NDArray& update, NDArray& stateG, const double dLr, + const double dRmsDecay, const double dEpsilon) { + const T* grad = gradient.bufferAsT(); + const T* init = initState.bufferAsT(); - T* up = update.bufferAsT(); - T* st = stateG.bufferAsT(); - - const T lr = static_cast(dLr); - const T rmsDecay = static_cast(dRmsDecay); - const T epsilon = static_cast(dEpsilon); - - bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateG.ews() && 1 == initState.ews(); - bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateG.ordering() && stateG.ordering() == initState.ordering(); + T* up = update.bufferAsT(); + T* st = stateG.bufferAsT(); - if (bEws1 && bSameOrdering) { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - st[i] = init[i] * rmsDecay + grad[i] * grad[i] * (1 - rmsDecay) ; - up[i] = (lr * grad[i]) / ( math::nd4j_sqrt(st[i]) + epsilon); - } - }; + const T lr = static_cast(dLr); + const T rmsDecay = static_cast(dRmsDecay); + const T epsilon = static_cast(dEpsilon); - samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); - return; - } - - bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); - bool bXInSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initState.shapeInfo()); - bool bXStSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateG.shapeInfo()); - - auto func = PRAGMA_THREADS_FOR{ + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateG.ews() && + 1 == initState.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateG.ordering() && + stateG.ordering() == initState.ordering(); - int coords[MAX_RANK]; - for (auto i = start; i < stop; i++) { - shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); - const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); - const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); - const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.shapeInfo(), coords); - const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateG.shapeInfo(), coords); - - st[stOffset] = init[initOffset] * rmsDecay + grad[xOffset] * grad[xOffset] * (1 - rmsDecay) ; - up[zOffset] = (lr * grad[xOffset]) / ( math::nd4j_sqrt(st[stOffset]) + epsilon); - } + if (bEws1 && bSameOrdering) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + st[i] = init[i] * rmsDecay + grad[i] * grad[i] * (1 - rmsDecay); + up[i] = (lr * grad[i]) / (math::nd4j_sqrt(st[i]) + epsilon); + } }; samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); return; -} + } -void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, - const double dLr, const double dRmsDecay, const double dEpsilon) { - BUILD_SINGLE_SELECTOR(gradient.dataType(), rmsPropUpdater_, (gradient, initState, update, stateG, dLr, dRmsDecay, dEpsilon), FLOAT_TYPES); -} + bool bXZsame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo()); + bool bXInSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), + initState.shapeInfo()); + bool bXStSame = + shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateG.shapeInfo()); + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords); + const auto zOffset = + bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords); + const auto initOffset = + bXInSame ? xOffset : shape::getOffset(initState.shapeInfo(), coords); + const auto stOffset = + bXStSame ? xOffset : shape::getOffset(stateG.shapeInfo(), coords); + st[stOffset] = init[initOffset] * rmsDecay + + grad[xOffset] * grad[xOffset] * (1 - rmsDecay); + up[zOffset] = (lr * grad[xOffset]) / + (math::nd4j_sqrt(st[stOffset]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; } + +void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, NDArray& stateG, + const double dLr, const double dRmsDecay, + const double dEpsilon) { + BUILD_SINGLE_SELECTOR( + gradient.dataType(), rmsPropUpdater_, + (gradient, initState, update, stateG, dLr, dRmsDecay, dEpsilon), + FLOAT_TYPES); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cpu/weights.cpp b/libnd4j/include/ops/declarable/helpers/cpu/weights.cpp index ebdfc674b196..da95201836c2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/weights.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/weights.cpp @@ -24,24 +24,31 @@ namespace sd { namespace ops { namespace helpers { - template - static void adjustWeights_(NDArray* input, NDArray* weights, NDArray* output, int minLength, int maxLength) { - for (Nd4jLong e = 0; e < input->lengthOf(); e++) { - int val = input->e(e); - if (val < maxLength) { - if (weights != nullptr) - output->p(val, output->e(val) + weights->e(e)); - else - output->p(val, output->e(val) + 1); - } - } +template +static void adjustWeights_(NDArray* input, NDArray* weights, NDArray* output, + int minLength, int maxLength) { + for (Nd4jLong e = 0; e < input->lengthOf(); e++) { + int val = input->e(e); + if (val < maxLength) { + if (weights != nullptr) + output->p(val, output->e(val) + weights->e(e)); + else + output->p(val, output->e(val) + 1); } - - void adjustWeights(sd::LaunchContext * context, NDArray* input, NDArray* weights, NDArray* output, int minLength, int maxLength) { - BUILD_SINGLE_SELECTOR(output->dataType(), adjustWeights_, (input, weights, output, minLength, maxLength), LIBND4J_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void adjustWeights_, (NDArray* input, NDArray* weights, NDArray* output, int minLength, int maxLength), LIBND4J_TYPES); + } } + +void adjustWeights(sd::LaunchContext* context, NDArray* input, NDArray* weights, + NDArray* output, int minLength, int maxLength) { + BUILD_SINGLE_SELECTOR(output->dataType(), adjustWeights_, + (input, weights, output, minLength, maxLength), + LIBND4J_TYPES); } -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template void adjustWeights_, + (NDArray * input, NDArray* weights, NDArray* output, + int minLength, int maxLength), + LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/zeta.cpp b/libnd4j/include/ops/declarable/helpers/cpu/zeta.cpp index d127fc1665c9..e4ca9ff0fe6a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/zeta.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/zeta.cpp @@ -18,66 +18,66 @@ // Created by Yurii Shyrma on 12.12.2017 // -#include #include +#include namespace sd { namespace ops { namespace helpers { -const int maxIter = 1000000; // max number of loop iterations +const int maxIter = 1000000; // max number of loop iterations ////////////////////////////////////////////////////////////////////////// // slow implementation template static FORCEINLINE T zetaScalarSlow(const T x, const T q) { - - const T precision = (T)1e-7; // function stops the calculation of series when next item is <= precision - - // if (x <= (T)1.) - // throw("zeta function: x must be > 1 !"); - - // if (q <= (T)0.) - // throw("zeta function: q must be > 0 !"); - - T item; - T result = (T)0.; - for(int i = 0; i < maxIter; ++i) { - - item = math::nd4j_pow((q + i),-x); - result += item; - - if(item <= precision) - break; - } - - return result; -} + const T precision = (T)1e-7; // function stops the calculation of series when + // next item is <= precision + // if (x <= (T)1.) + // throw("zeta function: x must be > 1 !"); -////////////////////////////////////////////////////////////////////////// -// calculate the Hurwitz zeta function for arrays -template -static void zeta_(sd::LaunchContext * context, const NDArray& x, const NDArray& q, NDArray &z) { + // if (q <= (T)0.) + // throw("zeta function: q must be > 0 !"); - //auto result = NDArray(&x, false, context); - int xLen = x.lengthOf(); + T item; + T result = (T)0.; + for (int i = 0; i < maxIter; ++i) { + item = math::nd4j_pow((q + i), -x); + result += item; - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) - z.p(i, zetaScalar(x.e(i), q.e(i))); - }; + if (item <= precision) break; + } - samediff::Threads::parallel_for(func, 0, xLen); + return result; } -void zeta(sd::LaunchContext * context, const NDArray& x, const NDArray& q, NDArray& z) { - BUILD_SINGLE_SELECTOR(x.dataType(), zeta_, (context, x, q, z), FLOAT_TYPES); -} +////////////////////////////////////////////////////////////////////////// +// calculate the Hurwitz zeta function for arrays +template +static void zeta_(sd::LaunchContext* context, const NDArray& x, + const NDArray& q, NDArray& z) { + // auto result = NDArray(&x, false, context); + int xLen = x.lengthOf(); -BUILD_SINGLE_TEMPLATE(template void zeta_, (sd::LaunchContext * context, const NDArray& x, const NDArray& q, NDArray& z), FLOAT_TYPES); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) + z.p(i, zetaScalar(x.e(i), q.e(i))); + }; + samediff::Threads::parallel_for(func, 0, xLen); } + +void zeta(sd::LaunchContext* context, const NDArray& x, const NDArray& q, + NDArray& z) { + BUILD_SINGLE_SELECTOR(x.dataType(), zeta_, (context, x, q, z), FLOAT_TYPES); } -} +BUILD_SINGLE_TEMPLATE(template void zeta_, + (sd::LaunchContext * context, const NDArray& x, + const NDArray& q, NDArray& z), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/crop_and_resize.h b/libnd4j/include/ops/declarable/helpers/crop_and_resize.h index cff96f93e5b2..18d9ea16d313 100644 --- a/libnd4j/include/ops/declarable/helpers/crop_and_resize.h +++ b/libnd4j/include/ops/declarable/helpers/crop_and_resize.h @@ -14,7 +14,6 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // @author sgazeos@gmail.com // @@ -22,19 +21,23 @@ #ifndef SD_CROP_AND_RESIZE_H #define SD_CROP_AND_RESIZE_H -#include #include - +#include namespace sd { - namespace ops { - namespace helpers { - template - void cropAndResizeFunctor_(NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops); - - void cropAndResizeFunctor(sd::LaunchContext * context, NDArray const* images, NDArray const* boxes, NDArray const* indices, NDArray const* cropSize, int method, double extrapolationVal, NDArray* crops); - } - } -} - -#endif //SD_CROP_AND_RESIZE_H +namespace ops { +namespace helpers { +template +void cropAndResizeFunctor_(NDArray const* images, NDArray const* boxes, + NDArray const* indices, NDArray const* cropSize, + int method, double extrapolationVal, NDArray* crops); + +void cropAndResizeFunctor(sd::LaunchContext* context, NDArray const* images, + NDArray const* boxes, NDArray const* indices, + NDArray const* cropSize, int method, + double extrapolationVal, NDArray* crops); +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // SD_CROP_AND_RESIZE_H diff --git a/libnd4j/include/ops/declarable/helpers/cross.h b/libnd4j/include/ops/declarable/helpers/cross.h index 2a3ee22c1626..818613749184 100644 --- a/libnd4j/include/ops/declarable/helpers/cross.h +++ b/libnd4j/include/ops/declarable/helpers/cross.h @@ -18,69 +18,75 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { -void crossBatched(sd::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o); - -void FORCEINLINE cross(sd::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) { - - if (a->isR()) { - auto a0 = a->e(0); - auto a1 = a->e(1); - auto a2 = a->e(2); - - auto b0 = b->e(0); - auto b1 = b->e(1); - auto b2 = b->e(2); - - o->p(Nd4jLong(0L), a1 * b2 - a2 * b1); - o->p(1L, a2 * b0 - a0 * b2); - o->p(2L, a0 * b1 - a1 * b0); - } else { - auto a0 = a->e(0); - auto a1 = a->e(1); - auto a2 = a->e(2); - - auto b0 = b->e(0); - auto b1 = b->e(1); - auto b2 = b->e(2); - - o->p(Nd4jLong(0L), a1 * b2 - a2 * b1); - o->p(1L, a2 * b0 - a0 * b2); - o->p(2L, a0 * b1 - a1 * b0); - } +void crossBatched(sd::LaunchContext *context, NDArray *a, NDArray *b, + NDArray *o); + +void FORCEINLINE cross(sd::LaunchContext *context, NDArray *a, NDArray *b, + NDArray *o) { + if (a->isR()) { + auto a0 = a->e(0); + auto a1 = a->e(1); + auto a2 = a->e(2); + + auto b0 = b->e(0); + auto b1 = b->e(1); + auto b2 = b->e(2); + + o->p(Nd4jLong(0L), a1 * b2 - a2 * b1); + o->p(1L, a2 * b0 - a0 * b2); + o->p(2L, a0 * b1 - a1 * b0); + } else { + auto a0 = a->e(0); + auto a1 = a->e(1); + auto a2 = a->e(2); + + auto b0 = b->e(0); + auto b1 = b->e(1); + auto b2 = b->e(2); + + o->p(Nd4jLong(0L), a1 * b2 - a2 * b1); + o->p(1L, a2 * b0 - a0 * b2); + o->p(2L, a0 * b1 - a1 * b0); + } } - void FORCEINLINE _crossBatched(sd::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) { - auto a_ = a->reshape(a->ordering(), {-1, 3}); - auto b_ = b->reshape(b->ordering(), {-1, 3}); - auto o_ = o->reshape(o->ordering(), {-1, 3}, false); +void FORCEINLINE _crossBatched(sd::LaunchContext *context, NDArray *a, + NDArray *b, NDArray *o) { + auto a_ = a->reshape(a->ordering(), {-1, 3}); + auto b_ = b->reshape(b->ordering(), {-1, 3}); + auto o_ = o->reshape(o->ordering(), {-1, 3}, false); - auto tadsA = a_.allTensorsAlongDimension({1}); - auto tadsB = b_.allTensorsAlongDimension({1}); - auto tadsO = o_.allTensorsAlongDimension({1}); + auto tadsA = a_.allTensorsAlongDimension({1}); + auto tadsB = b_.allTensorsAlongDimension({1}); + auto tadsO = o_.allTensorsAlongDimension({1}); - int tads = tadsA.size(); + int tads = tadsA.size(); - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto a_ = tadsA.at(e); - auto b_ = tadsB.at(e); - auto o_ = tadsO.at(e); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto a_ = tadsA.at(e); + auto b_ = tadsB.at(e); + auto o_ = tadsO.at(e); - helpers::cross(context, &a_, &b_, &o_); - } - }; - - samediff::Threads::parallel_tad(func, 0, tads); + helpers::cross(context, &a_, &b_, &o_); } + }; - void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output); -} + samediff::Threads::parallel_tad(func, 0, tads); } -} \ No newline at end of file + +void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext *context, + NDArray const *targets, + NDArray const *input, + NDArray const *weights, + NDArray *output); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu b/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu index 70ff75b96336..afffabe80a7f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/BarnesHutTsne.cu @@ -24,47 +24,49 @@ namespace sd { namespace ops { namespace helpers { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// count rows kernel - count input pRows and pCols and put result onto pRowCounts -// pRowCounts - array of ints, with length N -// pRows - array of ints with length N, vals from 0 to N-1 -// pCols - array of ints with length < N and vals between 0 and max(pRows) +// count rows kernel - count input pRows and pCols and put result onto +// pRowCounts pRowCounts - array of ints, with length N pRows - array of ints +// with length N, vals from 0 to N-1 pCols - array of ints with length < N and +// vals between 0 and max(pRows) // - static __global__ void countRowsKernel(int* pRowCounts, int const* pRows, int const* pCols, Nd4jLong N) { - auto start = blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - for (int n = threadIdx.x + start; n < N; n += step) { - int begin = pRows[n];//->e(n); - int end = pRows[n + 1];//rowP->e(n + 1); - for (int i = begin; i < end; i++) { - bool present = false; - // loop between near pRows - for (int m = pRows[pCols[i]]; m < pRows[pCols[i] + 1]; m++) - if (pCols[m] == n) { // mark index as existed with columns array - present = true; - break; - } +static __global__ void countRowsKernel(int* pRowCounts, int const* pRows, + int const* pCols, Nd4jLong N) { + auto start = blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + for (int n = threadIdx.x + start; n < N; n += step) { + int begin = pRows[n]; //->e(n); + int end = pRows[n + 1]; // rowP->e(n + 1); + for (int i = begin; i < end; i++) { + bool present = false; + // loop between near pRows + for (int m = pRows[pCols[i]]; m < pRows[pCols[i] + 1]; m++) + if (pCols[m] == n) { // mark index as existed with columns array + present = true; + break; + } - atomicAdd(&pRowCounts[n], 1); + atomicAdd(&pRowCounts[n], 1); - if (!present) // increment row counter for given index - atomicAdd(&pRowCounts[pCols[i]], 1); - } - } + if (!present) // increment row counter for given index + atomicAdd(&pRowCounts[pCols[i]], 1); } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // row counter caller - Nd4jLong barnes_row_count(const NDArray* rowP, const NDArray* colP, Nd4jLong N, NDArray& rowCounts) { - - int* pRowCounts = reinterpret_cast(rowCounts.specialBuffer()); - int const* pRows = reinterpret_cast(rowP->specialBuffer()); - int const* pCols = reinterpret_cast(colP->specialBuffer()); - auto stream = rowCounts.getContext()->getCudaStream(); - countRowsKernel<<<1, 1, 128, *stream>>>(pRowCounts, pRows, pCols, N); - NDArray numElementsArr = rowCounts.sumNumber(); //reduceAlongDimension(reduce::Sum, {}); - //rowCounts.printBuffer("Row counts"); - auto numElements = numElementsArr.e(0); - return numElements; - } +Nd4jLong barnes_row_count(const NDArray* rowP, const NDArray* colP, Nd4jLong N, + NDArray& rowCounts) { + int* pRowCounts = reinterpret_cast(rowCounts.specialBuffer()); + int const* pRows = reinterpret_cast(rowP->specialBuffer()); + int const* pCols = reinterpret_cast(colP->specialBuffer()); + auto stream = rowCounts.getContext()->getCudaStream(); + countRowsKernel<<<1, 1, 128, *stream>>>(pRowCounts, pRows, pCols, N); + NDArray numElementsArr = + rowCounts.sumNumber(); // reduceAlongDimension(reduce::Sum, {}); + // rowCounts.printBuffer("Row counts"); + auto numElements = numElementsArr.e(0); + return numElements; +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // extend symRowP with pRowCounts array vals @@ -72,18 +74,17 @@ namespace helpers { // symRowP - int array with length N+1 // N - given array length // - static __global__ void fillUpsymRow(int const* pRowCounts, int* symRowP, int N) { - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int n = start; n < N + 1; n += step) { // to avoid race condition use shift only for given index - symRowP[n] = 0; - for (int i = 0; i < n; i++) - atomicAdd(&symRowP[n], pRowCounts[i]); - } +static __global__ void fillUpsymRow(int const* pRowCounts, int* symRowP, + int N) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; - } + for (int n = start; n < N + 1; + n += step) { // to avoid race condition use shift only for given index + symRowP[n] = 0; + for (int i = 0; i < n; i++) atomicAdd(&symRowP[n], pRowCounts[i]); + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // symmetrize routine kernel @@ -96,183 +97,225 @@ namespace helpers { // pOutput - result matrix (floats) // N - pRows length // - template - static __global__ void symmetrizeKernel(int const* pRows, int const* pCols, T const* pVals, int* symRowP, int* symColP, int* offset, T* pOutput, int N) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; +template +static __global__ void symmetrizeKernel(int const* pRows, int const* pCols, + T const* pVals, int* symRowP, + int* symColP, int* offset, T* pOutput, + int N) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; - for (int n = start; n < N; n += step) { - int begin = pRows[n]; - int bound = pRows[n + 1]; + for (int n = start; n < N; n += step) { + int begin = pRows[n]; + int bound = pRows[n + 1]; - for (int i = begin; i < bound; i++) { - bool present = false; - int colPI = pCols[i]; - int start = pRows[colPI]; - int end = pRows[colPI + 1]; + for (int i = begin; i < bound; i++) { + bool present = false; + int colPI = pCols[i]; + int start = pRows[colPI]; + int end = pRows[colPI + 1]; - for (int m = start; m < end; m++) { - if (pCols[m] == n) { - present = true; - if (n <= colPI) { - symColP[symRowP[n] + offset[n]] = colPI; - symColP[symRowP[colPI] + offset[colPI]] = n; - pOutput[symRowP[n] + offset[n]] = pVals[i] + pVals[m]; - pOutput[symRowP[colPI] + offset[colPI]] = pVals[i] + pVals[m]; - } - } - } + for (int m = start; m < end; m++) { + if (pCols[m] == n) { + present = true; + if (n <= colPI) { + symColP[symRowP[n] + offset[n]] = colPI; + symColP[symRowP[colPI] + offset[colPI]] = n; + pOutput[symRowP[n] + offset[n]] = pVals[i] + pVals[m]; + pOutput[symRowP[colPI] + offset[colPI]] = pVals[i] + pVals[m]; + } + } + } - // If (colP[i], n) is not present, there is no addition involved - if (!present) { - symColP[symRowP[n] + offset[n]] = colPI; - symColP[symRowP[pCols[i]] + offset[colPI]] = n; - pOutput[symRowP[n] + offset[n]] = pVals[i]; - pOutput[symRowP[colPI] + offset[colPI]] = pVals[i]; - } - // Update offsets - if (!present || (present && n <= colPI)) { - atomicAdd(&offset[n], 1); + // If (colP[i], n) is not present, there is no addition involved + if (!present) { + symColP[symRowP[n] + offset[n]] = colPI; + symColP[symRowP[pCols[i]] + offset[colPI]] = n; + pOutput[symRowP[n] + offset[n]] = pVals[i]; + pOutput[symRowP[colPI] + offset[colPI]] = pVals[i]; + } + // Update offsets + if (!present || (present && n <= colPI)) { + atomicAdd(&offset[n], 1); - if (colPI != n) - atomicAdd(&offset[colPI], 1); - } - } - } + if (colPI != n) atomicAdd(&offset[colPI], 1); + } } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // symmetrize algorithm itself // - template - static void barnes_symmetrize_(const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts) { - int const* pRows = reinterpret_cast(rowP->specialBuffer()); - int* symRowP = reinterpret_cast(outputRows->specialBuffer()); - int* pRowCounts = reinterpret_cast(rowCounts->specialBuffer()); - auto stream = outputCols->getContext()->getCudaStream(); - // fill up syRowP array - fillUpsymRow<<<1, N, 128, *stream>>>(pRowCounts, symRowP, N); - outputRows->syncToHost(); -// outputRows->printBuffer("output rows"); - int* symColP = reinterpret_cast(outputCols->specialBuffer()); -// outputRows->printBuffer("SymRows are"); - int const* pCols = reinterpret_cast(colP->specialBuffer()); - T const* pVals = reinterpret_cast(valP->specialBuffer()); - T* pOutput = reinterpret_cast(outputVals->specialBuffer()); - //std::vector rowCountsV = rowCounts->getBufferAsVector(); - auto offsetArr = NDArrayFactory::create('c', {N}); - int* offset = reinterpret_cast(offsetArr.specialBuffer()); - // symmetrize itself - symmetrizeKernel<<<1, 1, 1024, *stream>>>(pRows, pCols, pVals, symRowP, symColP, offset, pOutput, N); - } +template +static void barnes_symmetrize_(const NDArray* rowP, const NDArray* colP, + const NDArray* valP, Nd4jLong N, + NDArray* outputRows, NDArray* outputCols, + NDArray* outputVals, NDArray* rowCounts) { + int const* pRows = reinterpret_cast(rowP->specialBuffer()); + int* symRowP = reinterpret_cast(outputRows->specialBuffer()); + int* pRowCounts = reinterpret_cast(rowCounts->specialBuffer()); + auto stream = outputCols->getContext()->getCudaStream(); + // fill up syRowP array + fillUpsymRow<<<1, N, 128, *stream>>>(pRowCounts, symRowP, N); + outputRows->syncToHost(); + // outputRows->printBuffer("output rows"); + int* symColP = reinterpret_cast(outputCols->specialBuffer()); + // outputRows->printBuffer("SymRows are"); + int const* pCols = reinterpret_cast(colP->specialBuffer()); + T const* pVals = reinterpret_cast(valP->specialBuffer()); + T* pOutput = reinterpret_cast(outputVals->specialBuffer()); + // std::vector rowCountsV = rowCounts->getBufferAsVector(); + auto offsetArr = NDArrayFactory::create('c', {N}); + int* offset = reinterpret_cast(offsetArr.specialBuffer()); + // symmetrize itself + symmetrizeKernel<<<1, 1, 1024, *stream>>>(pRows, pCols, pVals, symRowP, + symColP, offset, pOutput, N); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // symmetrize caller and adoption // - void barnes_symmetrize(const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts) { - BUILD_SINGLE_SELECTOR(valP->dataType(), barnes_symmetrize_, (rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCounts), NUMERIC_TYPES); +void barnes_symmetrize(const NDArray* rowP, const NDArray* colP, + const NDArray* valP, Nd4jLong N, NDArray* outputRows, + NDArray* outputCols, NDArray* outputVals, + NDArray* rowCounts) { + BUILD_SINGLE_SELECTOR( + valP->dataType(), barnes_symmetrize_, + (rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCounts), + NUMERIC_TYPES); - *outputVals /= 2.0; - } - BUILD_SINGLE_TEMPLATE(template void barnes_symmetrize_, (const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts), NUMERIC_TYPES); + *outputVals /= 2.0; +} +BUILD_SINGLE_TEMPLATE(template void barnes_symmetrize_, + (const NDArray* rowP, const NDArray* colP, + const NDArray* valP, Nd4jLong N, NDArray* outputRows, + NDArray* outputCols, NDArray* outputVals, + NDArray* rowCounts), + NUMERIC_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // edge forces implementation // - template - static __global__ void edgeForcesKernel(int const* pRows, int const* pCols, T const* dataP, T const* vals, T* outputP, int N, int colCount, int rowSize) { -// std::vector buffer(colCount); +template +static __global__ void edgeForcesKernel(int const* pRows, int const* pCols, + T const* dataP, T const* vals, + T* outputP, int N, int colCount, + int rowSize) { + // std::vector buffer(colCount); - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; - for (int n = start; n < N; n += step) { - int start = pRows[n]; - int end = pRows[n + 1]; - int shift = n * colCount; - for (int i = start; i < end; i++) { - T const* thisSlice = dataP + pCols[i] * colCount; - T res = 1; + for (int n = start; n < N; n += step) { + int start = pRows[n]; + int end = pRows[n + 1]; + int shift = n * colCount; + for (int i = start; i < end; i++) { + T const* thisSlice = dataP + pCols[i] * colCount; + T res = 1; - for (int k = 0; k < colCount; k++) { - auto valTemp = dataP[shift + k] - thisSlice[k];//thisSlice[k]; - res += valTemp * valTemp; // (dataP[shift + k] * dataP[shift + k] - 2 * dataP[shift + k] * thisSlice[k] + thisSlice[k] * thisSlice[k]) - } - res = vals[i] / res; - for (int k = 0; k < colCount; k++) - math::atomics::nd4j_atomicAdd(&outputP[shift + k], T((dataP[shift + k] - thisSlice[k]) * res)); - } - } + for (int k = 0; k < colCount; k++) { + auto valTemp = dataP[shift + k] - thisSlice[k]; // thisSlice[k]; + res += + valTemp * + valTemp; // (dataP[shift + k] * dataP[shift + k] - 2 * dataP[shift + // + k] * thisSlice[k] + thisSlice[k] * thisSlice[k]) + } + res = vals[i] / res; + for (int k = 0; k < colCount; k++) + math::atomics::nd4j_atomicAdd( + &outputP[shift + k], T((dataP[shift + k] - thisSlice[k]) * res)); } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // edge forces algorithm // - template - static void barnes_edge_forces_(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray const* data, NDArray* output) { - NDArray::prepareSpecialUse({output}, {data, rowP, colP, valP, valP}); - T const* dataP = reinterpret_cast(data->specialBuffer()); - T const* vals = reinterpret_cast(valP->specialBuffer()); - T* outputP = reinterpret_cast(output->specialBuffer()); - int const* pRows = reinterpret_cast(rowP->specialBuffer()); - int const* pCols = reinterpret_cast(colP->specialBuffer()); - int colCount = data->columns(); - //auto shift = 0; - auto rowSize = sizeof(T) * colCount; - auto stream = output->getContext()->getCudaStream(); - edgeForcesKernel<<<1, 128, 1024, *stream>>>(pRows, pCols, dataP, vals, outputP, N, colCount, rowSize); - NDArray::registerSpecialUse({output}, {rowP, colP, valP, data}); - } +template +static void barnes_edge_forces_(const NDArray* rowP, NDArray const* colP, + NDArray const* valP, int N, NDArray const* data, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {data, rowP, colP, valP, valP}); + T const* dataP = reinterpret_cast(data->specialBuffer()); + T const* vals = reinterpret_cast(valP->specialBuffer()); + T* outputP = reinterpret_cast(output->specialBuffer()); + int const* pRows = reinterpret_cast(rowP->specialBuffer()); + int const* pCols = reinterpret_cast(colP->specialBuffer()); + int colCount = data->columns(); + // auto shift = 0; + auto rowSize = sizeof(T) * colCount; + auto stream = output->getContext()->getCudaStream(); + edgeForcesKernel<<<1, 128, 1024, *stream>>>(pRows, pCols, dataP, vals, + outputP, N, colCount, rowSize); + NDArray::registerSpecialUse({output}, {rowP, colP, valP, data}); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // edge forces caller // - void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray* output, NDArray const& data) { - // Loop over all edges in the graph - BUILD_SINGLE_SELECTOR(output->dataType(), barnes_edge_forces_, (rowP, colP, valP, N, &data, output), FLOAT_TYPES); - } - BUILD_SINGLE_TEMPLATE(template void barnes_edge_forces_, (const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray const* data, NDArray* output), FLOAT_TYPES); +void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, + NDArray const* valP, int N, NDArray* output, + NDArray const& data) { + // Loop over all edges in the graph + BUILD_SINGLE_SELECTOR(output->dataType(), barnes_edge_forces_, + (rowP, colP, valP, N, &data, output), FLOAT_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void barnes_edge_forces_, + (const NDArray* rowP, NDArray const* colP, + NDArray const* valP, int N, NDArray const* data, + NDArray* output), + FLOAT_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// gains - run a function T((x + 2.) * sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps)) + T(x * 0.8 * sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps)); -// for all members in input and put all in output +// gains - run a function T((x + 2.) * sd::math::nd4j_sign(grad) != +// sd::math::nd4j_sign(eps)) + T(x * 0.8 * sd::math::nd4j_sign(grad) +// != sd::math::nd4j_sign(eps)); for all members in input and put all in +// output // - template - void barnes_gains_(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output) { - auto gainsInternal = LAMBDA_TTT(x, grad, eps) { - T res = sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps) ? x + T(.2) : x * T(.8); - if(res < .01) res = .01; - return res; - }; +template +void barnes_gains_(NDArray* input, NDArray* gradX, NDArray* epsilon, + NDArray* output) { + auto gainsInternal = LAMBDA_TTT(x, grad, eps) { + T res = sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps) + ? x + T(.2) + : x * T(.8); + if (res < .01) res = .01; + return res; + }; - input->applyTriplewiseLambda(*gradX, *epsilon, gainsInternal, *output); - } + input->applyTriplewiseLambda(*gradX, *epsilon, gainsInternal, *output); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // gains caller - void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), barnes_gains_, (input, gradX, epsilon, output), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template void barnes_gains_, (NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output), NUMERIC_TYPES); +void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), barnes_gains_, + (input, gradX, epsilon, output), NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void barnes_gains_, + (NDArray * input, NDArray* gradX, NDArray* epsilon, + NDArray* output), + NUMERIC_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // cell contains - check cells for given point // - bool cell_contains(NDArray* corner, NDArray* width, NDArray* point, Nd4jLong dimension) { - auto cornerMinusWidth = *corner - *width; - auto cornerPlusWidth = *corner + *width; - // executes on host side, so sync all to host memory - cornerMinusWidth.syncToHost(); - cornerPlusWidth.syncToHost(); - for (Nd4jLong i = 0; i < dimension; i++) { - if (cornerMinusWidth.e(i) > point->e(i)) - return false; - if (cornerPlusWidth.e(i) < point->e(i)) - return false; - } +bool cell_contains(NDArray* corner, NDArray* width, NDArray* point, + Nd4jLong dimension) { + auto cornerMinusWidth = *corner - *width; + auto cornerPlusWidth = *corner + *width; + // executes on host side, so sync all to host memory + cornerMinusWidth.syncToHost(); + cornerPlusWidth.syncToHost(); + for (Nd4jLong i = 0; i < dimension; i++) { + if (cornerMinusWidth.e(i) > point->e(i)) return false; + if (cornerPlusWidth.e(i) < point->e(i)) return false; + } - return true; - } -} + return true; } -} - +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu index c8bc709a0bd6..7cc229484a04 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu @@ -19,593 +19,690 @@ // @author raver119@gmail.com // -#include -#include +#include +#include #include +#include +#include + #include -#include -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template +template __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); + const void *vy, const Nd4jLong *yShapeInfo, + void *vz) { + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); - __shared__ Nd4jLong xzLen; - __shared__ int xzRank, yRank; + __shared__ Nd4jLong xzLen; + __shared__ int xzRank, yRank; - if (threadIdx.x == 0) { - xzLen = shape::length(xShapeInfo); + if (threadIdx.x == 0) { + xzLen = shape::length(xShapeInfo); - xzRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - } - __syncthreads(); + xzRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); + } + __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int coords[MAX_RANK]; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int coords[MAX_RANK]; - for (int i = tid; i < xzLen; i += blockDim.x * gridDim.x) { - shape::index2coords(i, xShapeInfo, coords); + for (int i = tid; i < xzLen; i += blockDim.x * gridDim.x) { + shape::index2coords(i, xShapeInfo, coords); - const auto xzOffset = shape::getOffset(xShapeInfo, coords); - const auto xVal = x[xzOffset]; + const auto xzOffset = shape::getOffset(xShapeInfo, coords); + const auto xVal = x[xzOffset]; - if(xVal < 0) { - for (uint j = 0; j < yRank; ++j) - if(yShapeInfo[j + 1] == 1) - coords[j + 1] = 0; + if (xVal < 0) { + for (uint j = 0; j < yRank; ++j) + if (yShapeInfo[j + 1] == 1) coords[j + 1] = 0; - z[xzOffset] = xVal * y[shape::getOffset(yShapeInfo, coords + 1)]; - } - else - z[xzOffset] = xVal; - } + z[xzOffset] = xVal * y[shape::getOffset(yShapeInfo, coords + 1)]; + } else + z[xzOffset] = xVal; + } } /////////////////////////////////////////////////////////////////// -template -linkage void preluCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz) { - preluCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz); +template +linkage void preluCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz) { + preluCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz); } /////////////////////////////////////////////////////////////////// -void prelu(sd::LaunchContext * context, const NDArray& input, const NDArray& alpha, NDArray& output) { - - PointersManager manager(context, "prelu"); - - const int threadsPerBlock = 256; - const int blocksPerGrid = 512; - const int sharedMem = 512; - - const auto xType = input.dataType(); - const auto yType = alpha.dataType(); - - NDArray::prepareSpecialUse({&output}, {&input, &alpha}); - BUILD_SINGLE_SELECTOR_TWICE(xType, preluCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), alpha.specialBuffer(), alpha.specialShapeInfo(), output.specialBuffer()), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&input, &alpha}); - - manager.synchronize(); +void prelu(sd::LaunchContext *context, const NDArray &input, + const NDArray &alpha, NDArray &output) { + PointersManager manager(context, "prelu"); + + const int threadsPerBlock = 256; + const int blocksPerGrid = 512; + const int sharedMem = 512; + + const auto xType = input.dataType(); + const auto yType = alpha.dataType(); + + NDArray::prepareSpecialUse({&output}, {&input, &alpha}); + BUILD_SINGLE_SELECTOR_TWICE( + xType, preluCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), alpha.specialBuffer(), + alpha.specialShapeInfo(), output.specialBuffer()), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input, &alpha}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// -template -__global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeInfo, - const void *vAlpha, const Nd4jLong *alphaShapeInfo, - const void *vdLdO, const Nd4jLong *dLdOShapeInfo, - void *vdLdI, const Nd4jLong *dLdIShapeInfo, - void *vdLdA, const Nd4jLong *dLdAShapeInfo) { - - const auto in = reinterpret_cast(vIn); - const auto alpha = reinterpret_cast(vAlpha); - const auto dLdO = reinterpret_cast(vdLdO); - auto dLdI = reinterpret_cast(vdLdI); - auto dLdA = reinterpret_cast(vdLdA); - - __shared__ Nd4jLong inLen, totalThreads; - __shared__ int inRank, alphaRank; - - if (threadIdx.x == 0) { - inLen = shape::length(inShapeInfo); - totalThreads = gridDim.x * blockDim.x; - - inRank = shape::rank(inShapeInfo); - alphaRank = shape::rank(alphaShapeInfo); - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int coords[MAX_RANK]; - - for (int i = tid; i < inLen; i += totalThreads) { - shape::index2coords(i, inShapeInfo, coords); - - const auto inOffset = shape::getOffset(inShapeInfo, coords); - const auto dLdOOffset = shape::getOffset(dLdOShapeInfo, coords); - const auto dLdIOffset = shape::getOffset(dLdIShapeInfo, coords); - - const auto xVal = in[inOffset]; - const auto grO = dLdO[dLdOOffset]; - - if(xVal < 0) { - - for (uint j = 0; j < alphaRank; ++j) - if(alphaShapeInfo[j + 1] == 1) - coords[j + 1] = 0; - - const auto alphaOffset = shape::getOffset(alphaShapeInfo, coords + 1); - const auto dLdAOffset = shape::getOffset(dLdAShapeInfo, coords + 1); - - dLdI[dLdIOffset] = grO * alpha[alphaOffset]; - - sd::math::atomics::nd4j_atomicAdd(&dLdA[dLdAOffset], static_cast(grO * xVal)); - } - else - dLdI[dLdIOffset] = grO; - } +template +__global__ linkage void preluBPCuda( + const void *vIn, const Nd4jLong *inShapeInfo, const void *vAlpha, + const Nd4jLong *alphaShapeInfo, const void *vdLdO, + const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, + void *vdLdA, const Nd4jLong *dLdAShapeInfo) { + const auto in = reinterpret_cast(vIn); + const auto alpha = reinterpret_cast(vAlpha); + const auto dLdO = reinterpret_cast(vdLdO); + auto dLdI = reinterpret_cast(vdLdI); + auto dLdA = reinterpret_cast(vdLdA); + + __shared__ Nd4jLong inLen, totalThreads; + __shared__ int inRank, alphaRank; + + if (threadIdx.x == 0) { + inLen = shape::length(inShapeInfo); + totalThreads = gridDim.x * blockDim.x; + + inRank = shape::rank(inShapeInfo); + alphaRank = shape::rank(alphaShapeInfo); + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + int coords[MAX_RANK]; + + for (int i = tid; i < inLen; i += totalThreads) { + shape::index2coords(i, inShapeInfo, coords); + + const auto inOffset = shape::getOffset(inShapeInfo, coords); + const auto dLdOOffset = shape::getOffset(dLdOShapeInfo, coords); + const auto dLdIOffset = shape::getOffset(dLdIShapeInfo, coords); + + const auto xVal = in[inOffset]; + const auto grO = dLdO[dLdOOffset]; + + if (xVal < 0) { + for (uint j = 0; j < alphaRank; ++j) + if (alphaShapeInfo[j + 1] == 1) coords[j + 1] = 0; + + const auto alphaOffset = shape::getOffset(alphaShapeInfo, coords + 1); + const auto dLdAOffset = shape::getOffset(dLdAShapeInfo, coords + 1); + + dLdI[dLdIOffset] = grO * alpha[alphaOffset]; + + sd::math::atomics::nd4j_atomicAdd(&dLdA[dLdAOffset], + static_cast(grO * xVal)); + } else + dLdI[dLdIOffset] = grO; + } } ////////////////////////////////////////////////////////////////////////// -template -__host__ linkage void preluBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vIn, const Nd4jLong *inShapeInfo, const void *vAlpha, const Nd4jLong *alphaShapeInfo, const void *vdLdO, const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, void *vdLdA, const Nd4jLong *dLdAShapeInfo) { - - preluBPCuda<<>>(vIn, inShapeInfo, vAlpha, alphaShapeInfo, vdLdO, dLdOShapeInfo, vdLdI, dLdIShapeInfo, vdLdA, dLdAShapeInfo); +template +__host__ linkage void preluBPCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const void *vIn, const Nd4jLong *inShapeInfo, + const void *vAlpha, const Nd4jLong *alphaShapeInfo, const void *vdLdO, + const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, + void *vdLdA, const Nd4jLong *dLdAShapeInfo) { + preluBPCuda<<>>( + vIn, inShapeInfo, vAlpha, alphaShapeInfo, vdLdO, dLdOShapeInfo, vdLdI, + dLdIShapeInfo, vdLdA, dLdAShapeInfo); } ////////////////////////////////////////////////////////////////////////// -void preluBP(sd::LaunchContext* context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) { - dLdA.nullify(); - - PointersManager manager(context, "preluBP"); - - const int threadsPerBlock = 256; - const int blocksPerGrid = 512; - const int sharedMem = 512; - - const auto xType = input.dataType(); - const auto zType = alpha.dataType(); - - NDArray::prepareSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO}); - BUILD_SINGLE_SELECTOR_TWICE(xType, preluBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), alpha.specialBuffer(), alpha.specialShapeInfo(), dLdO.specialBuffer(), dLdO.specialShapeInfo(), dLdI.specialBuffer(), dLdI.specialShapeInfo(), dLdA.specialBuffer(), dLdA.specialShapeInfo()), FLOAT_TYPES); - NDArray::registerSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO}); - - manager.synchronize(); +void preluBP(sd::LaunchContext *context, const NDArray &input, + const NDArray &alpha, const NDArray &dLdO, NDArray &dLdI, + NDArray &dLdA) { + dLdA.nullify(); + + PointersManager manager(context, "preluBP"); + + const int threadsPerBlock = 256; + const int blocksPerGrid = 512; + const int sharedMem = 512; + + const auto xType = input.dataType(); + const auto zType = alpha.dataType(); + + NDArray::prepareSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO}); + BUILD_SINGLE_SELECTOR_TWICE( + xType, preluBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), alpha.specialBuffer(), + alpha.specialShapeInfo(), dLdO.specialBuffer(), dLdO.specialShapeInfo(), + dLdI.specialBuffer(), dLdI.specialShapeInfo(), dLdA.specialBuffer(), + dLdA.specialShapeInfo()), + FLOAT_TYPES); + NDArray::registerSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO}); + + manager.synchronize(); } - /////////////////////////////////////////////////////////////////// -template -__device__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - // logic of this kernel is based on assumption gridDim = 1 - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong len; - __shared__ int numOfIters; - __shared__ T* shmem; - - if (threadIdx.x == 0) { - extern __shared__ char shared[]; - shmem = reinterpret_cast(shared); - len = shape::length(xShapeInfo); - numOfIters = (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x) - } - __syncthreads(); - - T temp = -DataTypeUtils::max(); // set start value to compare with at first iteration, FIXME: what if T is unsigned ?? - - // ************ evaluate max element in input array x ************ // - for (int i = 0; i < numOfIters; ++i) { - - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx < len) { - const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo); - shmem[threadIdx.x] = (threadIdx.x != 0) ? x[xOffset] : sd::math::nd4j_max(x[xOffset], temp); // take into account max element evaluated on previous iteration and stored in temp - } - else - shmem[threadIdx.x] = -DataTypeUtils::max(); // FIXME: what if T is unsigned ?? - - __syncthreads(); - - for (int s = blockDim.x / 2; s > 0; s /= 2) { - if(threadIdx.x < s) - shmem[threadIdx.x] = sd::math::nd4j_max(shmem[threadIdx.x], shmem[threadIdx.x + s]); - __syncthreads(); - } - - temp = shmem[0]; // save max value calculated at current iteration - } - - const T max = temp; - temp = 0; - - // ************ evaluate value of exp(x[offset] - max) per each element, store it to shared memory shmem ************ // - // at the same evaluate sum of exponents, sum will be stored in shmem[0] - for (int i = 0; i < numOfIters; ++i) { - - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx < len) { - const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo); - const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo); - z[zOffset] = sd::math::nd4j_exp(x[xOffset] - max); - shmem[threadIdx.x] = (threadIdx.x != 0) ? z[zOffset] : (z[zOffset] + temp); // take into account sum element evaluated on previous iteration and stored in temp - } - else - shmem[threadIdx.x] = 0; - - __syncthreads(); - - for (int s = blockDim.x / 2; s > 0; s /= 2) { - if(threadIdx.x < s) - shmem[threadIdx.x] += shmem[threadIdx.x + s]; - __syncthreads(); - } - - temp = shmem[0]; // save sum calculated at current iteration - } - - // ************ evaluate z[offset] / sum ************ // - for (int i = 0; i < numOfIters; ++i) { - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx >= len) continue; - const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo); - z[zOffset] /= shmem[0]; - } +template +__device__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + // logic of this kernel is based on assumption gridDim = 1 + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong len; + __shared__ int numOfIters; + __shared__ T *shmem; + + if (threadIdx.x == 0) { + extern __shared__ char shared[]; + shmem = reinterpret_cast(shared); + len = shape::length(xShapeInfo); + numOfIters = + (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x) + } + __syncthreads(); + + T temp = + -DataTypeUtils::max(); // set start value to compare with at first + // iteration, FIXME: what if T is unsigned ?? + + // ************ evaluate max element in input array x ************ // + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx < len) { + const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo); + shmem[threadIdx.x] = + (threadIdx.x != 0) + ? x[xOffset] + : sd::math::nd4j_max( + x[xOffset], + temp); // take into account max element evaluated on + // previous iteration and stored in temp + } else + shmem[threadIdx.x] = + -DataTypeUtils::max(); // FIXME: what if T is unsigned ?? + + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s /= 2) { + if (threadIdx.x < s) + shmem[threadIdx.x] = + sd::math::nd4j_max(shmem[threadIdx.x], shmem[threadIdx.x + s]); + __syncthreads(); + } + + temp = shmem[0]; // save max value calculated at current iteration + } + + const T max = temp; + temp = 0; + + // ************ evaluate value of exp(x[offset] - max) per each element, store + // it to shared memory shmem ************ // at the same evaluate sum of + // exponents, sum will be stored in shmem[0] + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx < len) { + const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo); + const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo); + z[zOffset] = sd::math::nd4j_exp(x[xOffset] - max); + shmem[threadIdx.x] = + (threadIdx.x != 0) + ? z[zOffset] + : (z[zOffset] + + temp); // take into account sum element evaluated on previous + // iteration and stored in temp + } else + shmem[threadIdx.x] = 0; + + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s /= 2) { + if (threadIdx.x < s) shmem[threadIdx.x] += shmem[threadIdx.x + s]; + __syncthreads(); + } + + temp = shmem[0]; // save sum calculated at current iteration + } + + // ************ evaluate z[offset] / sum ************ // + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx >= len) continue; + const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo); + z[zOffset] /= shmem[0]; + } } -template -__global__ void softMaxForVectorCudaGlobal(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - softMaxForVectorCuda(vx, xShapeInfo, vz, zShapeInfo); +template +__global__ void softMaxForVectorCudaGlobal(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + softMaxForVectorCuda(vx, xShapeInfo, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// template -linkage void softMaxForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - softMaxForVectorCudaGlobal<<<1, MAX_NUM_THREADS / 4 , (MAX_NUM_THREADS / 4) * sizeof(T) + 512, *stream>>>(vx, xShapeInfo, vz, zShapeInfo); +linkage void softMaxForVectorCudaLauncher(const cudaStream_t *stream, + const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + softMaxForVectorCudaGlobal + <<<1, MAX_NUM_THREADS / 4, (MAX_NUM_THREADS / 4) * sizeof(T) + 512, + *stream>>>(vx, xShapeInfo, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// -template -__global__ static void softMaxCuda(const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - const auto* xTad = x + xOffsets[blockIdx.x]; - auto* zTad = z + zOffsets[blockIdx.x]; - - softMaxForVectorCuda(xTad, xTadShapeInfo, zTad, zTadShapeInfo); +template +__global__ static void softMaxCuda(const void *vx, + const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xOffsets, void *vz, + const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zOffsets) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + const auto *xTad = x + xOffsets[blockIdx.x]; + auto *zTad = z + zOffsets[blockIdx.x]; + + softMaxForVectorCuda(xTad, xTadShapeInfo, zTad, zTadShapeInfo); } /////////////////////////////////////////////////////////////////// -template -static void softMaxCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets) { - - softMaxCuda<<>>(vx, xTadShapeInfo, xOffsets, vz, zTadShapeInfo, zOffsets); +template +static void softMaxCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xTadShapeInfo, + const Nd4jLong *xOffsets, void *vz, + const Nd4jLong *zTadShapeInfo, + const Nd4jLong *zOffsets) { + softMaxCuda<<>>( + vx, xTadShapeInfo, xOffsets, vz, zTadShapeInfo, zOffsets); } - ////////////////////////////////////////////////////////////////////////// -void softmax(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { - - if(!input.isActualOnDeviceSide()) input.syncToDevice(); - const int rank = input.rankOf(); - - PointersManager manager(context, "helpers::softmax"); - - if(input.isVector()) { - - if(rank == 1 || input.sizeAt(dimension) != 1) { - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), softMaxForVectorCudaLauncher, (context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - } - else - output = 1.; - } - else { - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), {dimension}); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), {dimension}); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = packZ.numberOfTads(); - const int sharedMem = input.sizeOfT() * threadsPerBlock + 512; - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), softMaxCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), packX.specialShapeInfo(), packX.specialOffsets(), output.specialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets()), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - // auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); - // (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily - // auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); - // output /= sumAlongDim; - // input.tickReadDevice(); - } - - - manager.synchronize(); - - output.tickWriteDevice(); +void softmax(sd::LaunchContext *context, const NDArray &input, NDArray &output, + const int dimension) { + if (!input.isActualOnDeviceSide()) input.syncToDevice(); + const int rank = input.rankOf(); + + PointersManager manager(context, "helpers::softmax"); + + if (input.isVector()) { + if (rank == 1 || input.sizeAt(dimension) != 1) { + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR(input.dataType(), softMaxForVectorCudaLauncher, + (context->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo()), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + } else + output = 1.; + } else { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), {dimension}); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output.shapeInfo(), {dimension}); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = packZ.numberOfTads(); + const int sharedMem = input.sizeOfT() * threadsPerBlock + 512; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), softMaxCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), packX.specialShapeInfo(), + packX.specialOffsets(), output.specialBuffer(), + packZ.specialShapeInfo(), packZ.specialOffsets()), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + // auto maxAlongDim = + // const_cast(input).reduceAlongDimension(reduce::Max, + // {dimension}, true); (input - maxAlongDim).applyTransform(transform::Exp, + // &output); // output contains exponents temporarily auto sumAlongDim = + // output.reduceAlongDimension(reduce::Sum, {dimension}, true); output /= + // sumAlongDim; input.tickReadDevice(); + } + + manager.synchronize(); + + output.tickWriteDevice(); } /////////////////////////////////////////////////////////////////// -template -__global__ void logSoftMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo, void *vz) { - - // logic of this kernel is based on assumption gridDim = 1 - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong len; - __shared__ int numOfIters; - __shared__ T* shmem; - - if (threadIdx.x == 0) { - extern __shared__ char shared[]; - shmem = reinterpret_cast(shared); - len = shape::length(xzShapeInfo); - numOfIters = (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x) - } - __syncthreads(); - - T temp = -DataTypeUtils::max(); // set start value to compare with at first iteration, FIXME: what if T is unsigned ?? - - // ************ evaluate max element in input array x ************ // - for (int i = 0; i < numOfIters; ++i) { - - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx < len) { - const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); - shmem[threadIdx.x] = (threadIdx.x != 0) ? x[offset] : sd::math::nd4j_max(x[offset], temp); // take into account max element evaluated on previous iteration and stored in temp - } - else - shmem[threadIdx.x] = -DataTypeUtils::max(); // FIXME: what if T is unsigned ?? - - __syncthreads(); - - for (int s = blockDim.x / 2; s > 0; s /= 2) { - if(threadIdx.x < s) - shmem[threadIdx.x] = sd::math::nd4j_max(shmem[threadIdx.x], shmem[threadIdx.x + s]); - __syncthreads(); - } - - temp = shmem[0]; // save max value calculated at current iteration - } - - const T max = temp; - temp = 0; - - // ************ evaluate value of exp(x[offset] - max) per each element, store it to shared memory shmem ************ // - // at the same time evaluate sum of exponents, sum will be stored in shmem[0] - for (int i = 0; i < numOfIters; ++i) { - - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx < len) { - const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); - z[offset] = sd::math::nd4j_exp(x[offset] - max); - shmem[threadIdx.x] = (threadIdx.x != 0) ? z[offset] : (z[offset] + temp); // take into account sum element evaluated on previous iteration and stored in temp - } - else - shmem[threadIdx.x] = 0; - - __syncthreads(); - - for (int s = blockDim.x / 2; s > 0; s /= 2) { - if(threadIdx.x < s) - shmem[threadIdx.x] += shmem[threadIdx.x + s]; - __syncthreads(); - } - - temp = shmem[0]; // save sum calculated at current iteration - } - - // ************ evaluate log(z[offset] / sum) ************ // - for (int i = 0; i < numOfIters; ++i) { - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx >= len) continue; - const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); - z[offset] = sd::math::nd4j_log(z[offset] / shmem[0]); - } +template +__global__ void logSoftMaxForVectorCuda(const void *vx, + const Nd4jLong *xzShapeInfo, void *vz) { + // logic of this kernel is based on assumption gridDim = 1 + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong len; + __shared__ int numOfIters; + __shared__ T *shmem; + + if (threadIdx.x == 0) { + extern __shared__ char shared[]; + shmem = reinterpret_cast(shared); + len = shape::length(xzShapeInfo); + numOfIters = + (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x) + } + __syncthreads(); + + T temp = + -DataTypeUtils::max(); // set start value to compare with at first + // iteration, FIXME: what if T is unsigned ?? + + // ************ evaluate max element in input array x ************ // + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx < len) { + const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); + shmem[threadIdx.x] = + (threadIdx.x != 0) + ? x[offset] + : sd::math::nd4j_max( + x[offset], + temp); // take into account max element evaluated on + // previous iteration and stored in temp + } else + shmem[threadIdx.x] = + -DataTypeUtils::max(); // FIXME: what if T is unsigned ?? + + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s /= 2) { + if (threadIdx.x < s) + shmem[threadIdx.x] = + sd::math::nd4j_max(shmem[threadIdx.x], shmem[threadIdx.x + s]); + __syncthreads(); + } + + temp = shmem[0]; // save max value calculated at current iteration + } + + const T max = temp; + temp = 0; + + // ************ evaluate value of exp(x[offset] - max) per each element, store + // it to shared memory shmem ************ // at the same time evaluate sum of + // exponents, sum will be stored in shmem[0] + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx < len) { + const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); + z[offset] = sd::math::nd4j_exp(x[offset] - max); + shmem[threadIdx.x] = + (threadIdx.x != 0) + ? z[offset] + : (z[offset] + temp); // take into account sum element evaluated + // on previous iteration and stored in temp + } else + shmem[threadIdx.x] = 0; + + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s /= 2) { + if (threadIdx.x < s) shmem[threadIdx.x] += shmem[threadIdx.x + s]; + __syncthreads(); + } + + temp = shmem[0]; // save sum calculated at current iteration + } + + // ************ evaluate log(z[offset] / sum) ************ // + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx >= len) continue; + const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); + z[offset] = sd::math::nd4j_log(z[offset] / shmem[0]); + } } /////////////////////////////////////////////////////////////////// template -linkage void logSoftMaxForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz) { - - logSoftMaxForVectorCuda<<<1, MAX_NUM_THREADS, MAX_NUM_THREADS * sizeof(T) + 512, *stream>>>(vx, xzShapeInfo, vz); +linkage void logSoftMaxForVectorCudaLauncher(const cudaStream_t *stream, + const void *vx, + const Nd4jLong *xzShapeInfo, + void *vz) { + logSoftMaxForVectorCuda + <<<1, MAX_NUM_THREADS, MAX_NUM_THREADS * sizeof(T) + 512, *stream>>>( + vx, xzShapeInfo, vz); } ////////////////////////////////////////////////////////////////////////// -void logSoftmax(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { - - if(!input.isActualOnDeviceSide()) input.syncToDevice(); - const int rank = input.rankOf(); - - if(input.isVector()) { - - if(rank == 1 || input.sizeAt(dimension) != 1) { - BUILD_SINGLE_SELECTOR(input.dataType(), logSoftMaxForVectorCudaLauncher, (context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer()), FLOAT_TYPES); - input.tickReadDevice(); - } - else - output = 0.; - } - else { - - auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); - output /= sumAlongDim; - output.applyTransform(transform::Log, output); - input.tickReadDevice(); - } - - PointersManager manager(context, "helpers::logSoftmax"); - manager.synchronize(); - - output.tickWriteDevice(); +void logSoftmax(sd::LaunchContext *context, const NDArray &input, + NDArray &output, const int dimension) { + if (!input.isActualOnDeviceSide()) input.syncToDevice(); + const int rank = input.rankOf(); + + if (input.isVector()) { + if (rank == 1 || input.sizeAt(dimension) != 1) { + BUILD_SINGLE_SELECTOR(input.dataType(), logSoftMaxForVectorCudaLauncher, + (context->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer()), + FLOAT_TYPES); + input.tickReadDevice(); + } else + output = 0.; + } else { + auto maxAlongDim = const_cast(input).reduceAlongDimension( + reduce::Max, {dimension}, true); + (input - maxAlongDim) + .applyTransform(transform::Exp, + output); // output contains exponents temporarily + auto sumAlongDim = + output.reduceAlongDimension(reduce::Sum, {dimension}, true); + output /= sumAlongDim; + output.applyTransform(transform::Log, output); + input.tickReadDevice(); + } + + PointersManager manager(context, "helpers::logSoftmax"); + manager.synchronize(); + + output.tickWriteDevice(); } /////////////////////////////////////////////////////////////////// -template -__global__ linkage void softMaxDerivForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo, void *vz) { - - // logic of this kernel is based on assumption gridDim = 1 - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong len; - __shared__ int numOfIters; - __shared__ T* shmem; - - if (threadIdx.x == 0) { - extern __shared__ char shared[]; - shmem = reinterpret_cast(shared); - len = shape::length(xzShapeInfo); - numOfIters = (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x) - } - __syncthreads(); - - T temp = -DataTypeUtils::max(); // set start value to compare with at first iteration, FIXME: what if T is unsigned ?? - - // ************ evaluate max element in input array x ************ // - for (int i = 0; i < numOfIters; ++i) { - - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx < len) { - const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); - shmem[threadIdx.x] = (threadIdx.x != 0) ? x[offset] : sd::math::nd4j_max(x[offset], temp); // take into account max element evaluated on previous iteration and stored in temp - } - else - shmem[threadIdx.x] = -DataTypeUtils::max(); // FIXME: what if T is unsigned ?? - - __syncthreads(); - - for (int s = blockDim.x / 2; s > 0; s /= 2) { - if(threadIdx.x < s) - shmem[threadIdx.x] = sd::math::nd4j_max(shmem[threadIdx.x], shmem[threadIdx.x + s]); - __syncthreads(); - } - - temp = shmem[0]; // save max value calculated at current iteration - } - - const T max = temp; - temp = 0; - - // ************ evaluate value of exp(x[offset] - max) per each element, store it to shared memory shmem ************ // - // at the same evaluate sum of exponents, sum will be stored in shmem[0] - for (int i = 0; i < numOfIters; ++i) { - - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx < len) { - const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); - z[offset] = sd::math::nd4j_exp(x[offset] - max); - shmem[threadIdx.x] = (threadIdx.x != 0) ? z[offset] : (z[offset] + temp); // take into account sum element evaluated on previous iteration and stored in temp - } - else - shmem[threadIdx.x] = 0; - - __syncthreads(); - - for (int s = blockDim.x / 2; s > 0; s /= 2) { - if(threadIdx.x < s) - shmem[threadIdx.x] += shmem[threadIdx.x + s]; - __syncthreads(); - } - - temp = shmem[0]; // save sum calculated at current iteration - } - - // ************ evaluate (z[offset] / sum) and derivative z[offset] = z[offset] * (1 - z[offset]) ************ // - for (int i = 0; i < numOfIters; ++i) { - const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; - if(elemIdx >= len) continue; - const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); - z[offset] /= shmem[0]; - z[offset] *= (1.f - z[offset]); // derivative - } +template +__global__ linkage void softMaxDerivForVectorCuda(const void *vx, + const Nd4jLong *xzShapeInfo, + void *vz) { + // logic of this kernel is based on assumption gridDim = 1 + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong len; + __shared__ int numOfIters; + __shared__ T *shmem; + + if (threadIdx.x == 0) { + extern __shared__ char shared[]; + shmem = reinterpret_cast(shared); + len = shape::length(xzShapeInfo); + numOfIters = + (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x) + } + __syncthreads(); + + T temp = + -DataTypeUtils::max(); // set start value to compare with at first + // iteration, FIXME: what if T is unsigned ?? + + // ************ evaluate max element in input array x ************ // + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx < len) { + const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); + shmem[threadIdx.x] = + (threadIdx.x != 0) + ? x[offset] + : sd::math::nd4j_max( + x[offset], + temp); // take into account max element evaluated on + // previous iteration and stored in temp + } else + shmem[threadIdx.x] = + -DataTypeUtils::max(); // FIXME: what if T is unsigned ?? + + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s /= 2) { + if (threadIdx.x < s) + shmem[threadIdx.x] = + sd::math::nd4j_max(shmem[threadIdx.x], shmem[threadIdx.x + s]); + __syncthreads(); + } + + temp = shmem[0]; // save max value calculated at current iteration + } + + const T max = temp; + temp = 0; + + // ************ evaluate value of exp(x[offset] - max) per each element, store + // it to shared memory shmem ************ // at the same evaluate sum of + // exponents, sum will be stored in shmem[0] + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx < len) { + const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); + z[offset] = sd::math::nd4j_exp(x[offset] - max); + shmem[threadIdx.x] = + (threadIdx.x != 0) + ? z[offset] + : (z[offset] + temp); // take into account sum element evaluated + // on previous iteration and stored in temp + } else + shmem[threadIdx.x] = 0; + + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s /= 2) { + if (threadIdx.x < s) shmem[threadIdx.x] += shmem[threadIdx.x + s]; + __syncthreads(); + } + + temp = shmem[0]; // save sum calculated at current iteration + } + + // ************ evaluate (z[offset] / sum) and derivative z[offset] = + // z[offset] * (1 - z[offset]) ************ // + for (int i = 0; i < numOfIters; ++i) { + const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; + if (elemIdx >= len) continue; + const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo); + z[offset] /= shmem[0]; + z[offset] *= (1.f - z[offset]); // derivative + } } /////////////////////////////////////////////////////////////////// template -linkage void softMaxDerivForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz) { - - softMaxDerivForVectorCuda<<<1, MAX_NUM_THREADS, MAX_NUM_THREADS * sizeof(T) + 512, *stream>>>(vx, xzShapeInfo, vz); +linkage void softMaxDerivForVectorCudaLauncher(const cudaStream_t *stream, + const void *vx, + const Nd4jLong *xzShapeInfo, + void *vz) { + softMaxDerivForVectorCuda + <<<1, MAX_NUM_THREADS, MAX_NUM_THREADS * sizeof(T) + 512, *stream>>>( + vx, xzShapeInfo, vz); } /////////////////////////////////////////////////////////////////// -void softmaxDerivative(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { - - if(!input.isActualOnDeviceSide()) input.syncToDevice(); - const int rank = input.rankOf(); - int temp; - - if(shape::isCommonVector(input.shapeInfo(), temp)) { - - BUILD_SINGLE_SELECTOR(input.dataType(), softMaxDerivForVectorCudaLauncher, (context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer()), FLOAT_TYPES); - input.tickReadDevice(); - } - else { - - auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily - auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); - output /= sumAlongDim; - output *= (1.f - output); // derivative - input.tickReadDevice(); - } - - PointersManager manager(context, "helpers::softmaxDerivative"); - manager.synchronize(); - - output.tickWriteDevice(); +void softmaxDerivative(sd::LaunchContext *context, const NDArray &input, + NDArray &output, const int dimension) { + if (!input.isActualOnDeviceSide()) input.syncToDevice(); + const int rank = input.rankOf(); + int temp; + + if (shape::isCommonVector(input.shapeInfo(), temp)) { + BUILD_SINGLE_SELECTOR(input.dataType(), softMaxDerivForVectorCudaLauncher, + (context->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer()), + FLOAT_TYPES); + input.tickReadDevice(); + } else { + auto maxAlongDim = const_cast(input).reduceAlongDimension( + reduce::Max, {dimension}, true); + (input - maxAlongDim) + .applyTransform(transform::Exp, + output); // output contains exponents temporarily + auto sumAlongDim = + output.reduceAlongDimension(reduce::Sum, {dimension}, true); + output /= sumAlongDim; + output *= (1.f - output); // derivative + input.tickReadDevice(); + } + + PointersManager manager(context, "helpers::softmaxDerivative"); + manager.synchronize(); + + output.tickWriteDevice(); } +template +linkage void thresholdRelu_(NDArray const &input, double threshold, + NDArray &output) { + auto routine = LAMBDA_T(_x, threshold) { + return _x > (T)threshold ? _x : (T)0.f; + }; + const_cast(input).applyLambda(routine, output); +} - template - linkage void thresholdRelu_(NDArray const& input, double threshold, NDArray& output) { - auto routine = LAMBDA_T(_x, threshold) { - return _x > (T)threshold ? _x: (T)0.f; - }; - const_cast(input).applyLambda(routine, output); - } - - void thresholdRelu(sd::LaunchContext * context, NDArray const& input, double threshold, NDArray& output) { - BUILD_SINGLE_SELECTOR(input.dataType(), thresholdRelu_, (input, threshold, output), FLOAT_TYPES); - } - - template - linkage void thresholdReluDerivative_(NDArray* input, double theta, NDArray* dLdO, NDArray* output) { - auto derivative = LAMBDA_TT(_x, grO, theta) {if (_x > theta) return grO; else return static_cast(0); }; - - input->applyPairwiseLambda(*dLdO, derivative, *output); - } - - void thresholdReluDerivative(sd::LaunchContext * context, NDArray* input, double threshold, NDArray* dLdO, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), thresholdReluDerivative_, (input, threshold, dLdO, output), FLOAT_TYPES); - } - +void thresholdRelu(sd::LaunchContext *context, NDArray const &input, + double threshold, NDArray &output) { + BUILD_SINGLE_SELECTOR(input.dataType(), thresholdRelu_, + (input, threshold, output), FLOAT_TYPES); } + +template +linkage void thresholdReluDerivative_(NDArray *input, double theta, + NDArray *dLdO, NDArray *output) { + auto derivative = LAMBDA_TT(_x, grO, theta) { + if (_x > theta) + return grO; + else + return static_cast(0); + }; + + input->applyPairwiseLambda(*dLdO, derivative, *output); } + +void thresholdReluDerivative(sd::LaunchContext *context, NDArray *input, + double threshold, NDArray *dLdO, NDArray *output) { + BUILD_SINGLE_SELECTOR(input->dataType(), thresholdReluDerivative_, + (input, threshold, dLdO, output), FLOAT_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu b/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu index 18474f2c79ba..e646ef0f6440 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/addBias.cu @@ -18,131 +18,139 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // - -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////// -template -__global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const bool isNCHW) { - - // bias [oC] - - // if(input_rank == 4) - // input and output have same shapes: [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - // if(input_rank == 5) - // input and output have same shapes: [bS, oD, oH, oW, oC] (NHWC) or [bS, oD, oC, oH, oW] (NCHW) - - const X* x = reinterpret_cast(vx); - const Y* y = reinterpret_cast(vy); - X* z = reinterpret_cast(vz); - - __shared__ int rank, channelPosition, posOfNonUnityDim; - __shared__ Nd4jLong len, *sharedMem; - __shared__ bool xzSameOffsets, xzAreSame; - - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - rank = shape::rank(xShapeInfo); // xRank == zRank - xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - len = shape::length(xShapeInfo); - channelPosition = isNCHW ? 1 : rank - 1; // second or last - xzAreSame = x == z; - - shape::isCommonVector(yShapeInfo, posOfNonUnityDim); - } - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * rank; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += blockDim.x * gridDim.x) { - - shape::index2coords(i, xShapeInfo, coords); - - const auto xOffsets = shape::getOffset(xShapeInfo, coords); - const auto zOffsets = xzSameOffsets ? xOffsets : shape::getOffset(zShapeInfo, coords); - const auto yOffsets = coords[channelPosition] * shape::stride(yShapeInfo)[posOfNonUnityDim]; - - if(xzAreSame) - z[zOffsets] += static_cast(y[yOffsets]); - else - z[zOffsets] = x[xOffsets] + static_cast(y[yOffsets]); - } +template +__global__ static void addBiasCuda(const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const bool isNCHW) { + // bias [oC] + + // if(input_rank == 4) + // input and output have same shapes: [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, + // oW] (NCHW) + // if(input_rank == 5) + // input and output have same shapes: [bS, oD, oH, oW, oC] (NHWC) or [bS, oD, + // oC, oH, oW] (NCHW) + + const X* x = reinterpret_cast(vx); + const Y* y = reinterpret_cast(vy); + X* z = reinterpret_cast(vz); + + __shared__ int rank, channelPosition, posOfNonUnityDim; + __shared__ Nd4jLong len, *sharedMem; + __shared__ bool xzSameOffsets, xzAreSame; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + rank = shape::rank(xShapeInfo); // xRank == zRank + xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + len = shape::length(xShapeInfo); + channelPosition = isNCHW ? 1 : rank - 1; // second or last + xzAreSame = x == z; + + shape::isCommonVector(yShapeInfo, posOfNonUnityDim); + } + __syncthreads(); + + auto coords = sharedMem + threadIdx.x * rank; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < len; + i += blockDim.x * gridDim.x) { + shape::index2coords(i, xShapeInfo, coords); + + const auto xOffsets = shape::getOffset(xShapeInfo, coords); + const auto zOffsets = + xzSameOffsets ? xOffsets : shape::getOffset(zShapeInfo, coords); + const auto yOffsets = + coords[channelPosition] * shape::stride(yShapeInfo)[posOfNonUnityDim]; + + if (xzAreSame) + z[zOffsets] += static_cast(y[yOffsets]); + else + z[zOffsets] = x[xOffsets] + static_cast(y[yOffsets]); + } } ////////////////////////////////////////////////////////////////////////// -template -static void addBiasCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const bool isNCHW) { - - addBiasCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, isNCHW); +template +static void addBiasCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const bool isNCHW) { + addBiasCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, isNCHW); } -template -__global__ static void addBias2DCuda( const void* vx, - const void* vy, - void* vz, - uint32_t blocks, uint32_t length) { +template +__global__ static void addBias2DCuda(const void* vx, const void* vy, void* vz, + uint32_t blocks, uint32_t length) { + auto y = reinterpret_cast(vy); - auto y = reinterpret_cast(vy); + for (uint32_t b = blockIdx.x; b < blocks; b += gridDim.x) { + auto x = reinterpret_cast(vx) + length * b; + auto z = reinterpret_cast(vz) + length * b; - for (uint32_t b = blockIdx.x; b < blocks; b += gridDim.x) { - auto x = reinterpret_cast(vx) + length * b; - auto z = reinterpret_cast(vz) + length * b; - - for (uint32_t e = threadIdx.x; e < length; e += blockDim.x) { - z[e] = x[e] + y[e]; - } + for (uint32_t e = threadIdx.x; e < length; e += blockDim.x) { + z[e] = x[e] + y[e]; } + } } -template -static void addBias2DCudaLauncher(const cudaStream_t *stream, const void* vx, - const void* vy, - void* vz, - uint32_t blocks, uint32_t length) { - - addBias2DCuda<<<256, 1024, 128, *stream>>>(vx, vy, vz, blocks, length); +template +static void addBias2DCudaLauncher(const cudaStream_t* stream, const void* vx, + const void* vy, void* vz, uint32_t blocks, + uint32_t length) { + addBias2DCuda<<<256, 1024, 128, *stream>>>(vx, vy, vz, blocks, length); } ////////////////////////////////////////////////////////////////////////// -void addBias(sd::graph::Context& block, const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW) { - - PointersManager manager(block.launchContext(), "addBias"); - NDArray::prepareSpecialUse({&output}, {&input, &bias}); - - if (input.rankOf() == 2 && bias.rankOf() == 1 && input.ordering() == 'c' && output.ordering() == 'c' && input.ews() == 1 && bias.ews() == 1 && input.sizeAt(1) == bias.sizeAt(0)) { - BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBias2DCudaLauncher, - (block.launchContext()->getCudaStream(), input.specialBuffer(), bias.specialBuffer(), output.specialBuffer(), input.sizeAt(0), bias.sizeAt(0)), - FLOAT_TYPES, FLOAT_TYPES); - } else { - // default case - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - - BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBiasCudaLauncher, - (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), bias.specialBuffer(), bias.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), isNCHW), - FLOAT_TYPES, FLOAT_TYPES); - } - NDArray::registerSpecialUse({&output}, {&input, &bias}); - manager.synchronize(); +void addBias(sd::graph::Context& block, const NDArray& input, + const NDArray& bias, NDArray& output, const bool isNCHW) { + PointersManager manager(block.launchContext(), "addBias"); + NDArray::prepareSpecialUse({&output}, {&input, &bias}); + + if (input.rankOf() == 2 && bias.rankOf() == 1 && input.ordering() == 'c' && + output.ordering() == 'c' && input.ews() == 1 && bias.ews() == 1 && + input.sizeAt(1) == bias.sizeAt(0)) { + BUILD_DOUBLE_SELECTOR( + input.dataType(), bias.dataType(), addBias2DCudaLauncher, + (block.launchContext()->getCudaStream(), input.specialBuffer(), + bias.specialBuffer(), output.specialBuffer(), input.sizeAt(0), + bias.sizeAt(0)), + FLOAT_TYPES, FLOAT_TYPES); + } else { + // default case + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + BUILD_DOUBLE_SELECTOR( + input.dataType(), bias.dataType(), addBiasCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), bias.specialBuffer(), + bias.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), isNCHW), + FLOAT_TYPES, FLOAT_TYPES); + } + NDArray::registerSpecialUse({&output}, {&input, &bias}); + manager.synchronize(); } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu b/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu index 9ce00f318303..9d5e494505c8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu @@ -19,91 +19,102 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// template -static void _CUDA_G adjustHueCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const T delta, const int dimC) { - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ int rank; - __shared__ Nd4jLong xDimCstride, zDimCstride; - - if (threadIdx.x == 0) { - rank = shape::rank(xShapeInfo); - xDimCstride = shape::stride(xShapeInfo)[dimC]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { - - const T* xTad = x + xTadOffsets[i]; - T* zTad = z + zTadOffsets[i]; - - T h, s, v; - - rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); - - h += delta ; - if(h > 1) - h -= 1; - else if(h < 0) - h += 1; - - hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } +static void _CUDA_G adjustHueCuda(const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const T delta, + const int dimC) { + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; + + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; + + T h, s, v; + + rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); + + h += delta; + if (h > 1) + h -= 1; + else if (h < 0) + h += 1; + + hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } } /////////////////////////////////////////////////////////////////// -template -static _CUDA_H void adjustHueCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const NDArray* deltaScalarArr, const int dimC) { - - adjustHueCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, deltaScalarArr->e(0), dimC); +template +static _CUDA_H void adjustHueCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, + const NDArray* deltaScalarArr, const int dimC) { + adjustHueCuda<<>>( + vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, + deltaScalarArr->e(0), dimC); } //////////////////////////////////////////////////////////////////////// -void adjustHue(sd::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) { - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), {dimC}); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {dimC}); - - const Nd4jLong numOfTads = packX.numberOfTads(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "adjustHue"); - - NDArray::prepareSpecialUse({output}, {input, deltaScalarArr}); - BUILD_SINGLE_SELECTOR(input->dataType(), adjustHueCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, deltaScalarArr, dimC), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input, deltaScalarArr}); - - manager.synchronize(); +void adjustHue(sd::LaunchContext* context, const NDArray* input, + const NDArray* deltaScalarArr, NDArray* output, const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), {dimC}); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "adjustHue"); + + NDArray::prepareSpecialUse({output}, {input, deltaScalarArr}); + BUILD_SINGLE_SELECTOR( + input->dataType(), adjustHueCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input->specialBuffer(), input->specialShapeInfo(), + packX.platformOffsets(), output->specialBuffer(), + output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, + deltaScalarArr, dimC), + FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input, deltaScalarArr}); + + manager.synchronize(); } - /* template -static void _CUDA_G adjustHueSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) { - int numChannels = 3; - auto tid = threadIdx.x + blockIdx.x * blockDim.x; +static void _CUDA_G adjustHueSingleNHWCKernel(void *xBuffer, Nd4jLong +*xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) +{ int numChannels = 3; auto tid = threadIdx.x + blockIdx.x * blockDim.x; auto bIn = reinterpret_cast(xBuffer); auto bOut = reinterpret_cast(zBuffer); @@ -128,10 +139,11 @@ static void _CUDA_G adjustHueSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInf } template -static void _CUDA_G adjustHueSingleNCHWKernel(void *xBuffer, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, void *zBuffer, Nd4jLong *zTadShapeInfo, Nd4jLong *zOffsets, Nd4jLong tadLength, Nd4jLong tuples, float delta) { - int numChannels = 3; - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - static const int kChannelRange = 6; +static void _CUDA_G adjustHueSingleNCHWKernel(void *xBuffer, Nd4jLong +*xTadShapeInfo, Nd4jLong *xOffsets, void *zBuffer, Nd4jLong *zTadShapeInfo, +Nd4jLong *zOffsets, Nd4jLong tadLength, Nd4jLong tuples, float delta) { int +numChannels = 3; auto tid = threadIdx.x + blockIdx.x * blockDim.x; static const +int kChannelRange = 6; auto bufferR = reinterpret_cast(xBuffer) + xOffsets[0]; auto bufferG = reinterpret_cast(xBuffer) + xOffsets[1]; @@ -166,56 +178,70 @@ static void _CUDA_G adjustHueSingleNCHWKernel(void *xBuffer, Nd4jLong *xTadShape } template -static void _adjust_hue_single(sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { +static void _adjust_hue_single(sd::LaunchContext * context, NDArray *array, +NDArray *output, float delta, bool isNHWC) { // numChannels is always 3 auto tuples = array->lengthOf() / 3; if (isNHWC) { - adjustHueSingleNHWCKernel<<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), array->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), tuples, delta); - } else { + adjustHueSingleNHWCKernel<<<256, 256, 1024, +*context->getCudaStream()>>>(array->specialBuffer(), array->specialShapeInfo(), +output->specialBuffer(), output->specialShapeInfo(), tuples, delta); } else { // TODO: check this one - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(array->shapeInfo(), {1, 2}); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {1, 2}); + auto packX = +sd::ConstantTadHelper::getInstance()->tadForDimensions(array->shapeInfo(), {1, +2}); auto packZ = +sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {1, +2}); auto tadLength = shape::length(packX.primaryShapeInfo()); - adjustHueSingleNCHWKernel<<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta); + adjustHueSingleNCHWKernel<<<256, 256, 1024, +*context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), +packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), +packZ.platformOffsets(), tadLength, tuples, delta); } } template -static void _adjust_hue_batch(sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { - auto xType = array->dataType(); +static void _adjust_hue_batch(sd::LaunchContext * context, NDArray *array, +NDArray *output, float delta, bool isNHWC) { auto xType = array->dataType(); // numChannels is always 3 auto tuples = array->lengthOf() / 3; if (isNHWC) { - // in case of nhwc batch, we don't really care about examples: it's still bunch of RGB values - BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, delta, isNHWC);, FLOAT_TYPES); - } else { + // in case of nhwc batch, we don't really care about examples: it's +still bunch of RGB values BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, +(context, array, output, delta, isNHWC);, FLOAT_TYPES); } else { // TODO: check this one - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(array->shapeInfo(), {0, 2, 3}); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {0, 2, 3}); + auto packX = +sd::ConstantTadHelper::getInstance()->tadForDimensions(array->shapeInfo(), {0, +2, 3}); auto packZ = +sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {0, +2, 3}); auto tadLength = shape::length(packX.primaryShapeInfo()); - adjustHueSingleNCHWKernel<<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta); + adjustHueSingleNCHWKernel<<<256, 256, 1024, +*context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), +packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), +packZ.platformOffsets(), tadLength, tuples, delta); } } -void _adjust_hue(sd::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) { - auto xType = array->dataType(); +void _adjust_hue(sd::LaunchContext * context, NDArray *array, NDArray *output, +NDArray* delta, bool isNHWC) { auto xType = array->dataType(); float d = delta->e(0); if (array->rankOf() == 4) { - BUILD_SINGLE_SELECTOR(xType, _adjust_hue_batch, (context, array, output, d, isNHWC);, FLOAT_TYPES); - } else { - BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(xType, _adjust_hue_batch, (context, array, output, +d, isNHWC);, FLOAT_TYPES); } else { BUILD_SINGLE_SELECTOR(xType, +_adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); } } */ -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu b/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu index fd413f8cd4eb..b5d8224473b4 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu @@ -19,92 +19,102 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include #include #include +#include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// template -static void _CUDA_G adjustSaturationCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const T factor, const int dimC) { +static void _CUDA_G adjustSaturationCuda( + const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, + void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const T factor, const int dimC) { + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; - __shared__ int rank; - __shared__ Nd4jLong xDimCstride, zDimCstride; + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); - if (threadIdx.x == 0) { - rank = shape::rank(xShapeInfo); - xDimCstride = shape::stride(xShapeInfo)[dimC]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const T* xTad = x + xTadOffsets[i]; - T* zTad = z + zTadOffsets[i]; + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; - T h, s, v; + T h, s, v; - rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); + rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); - s *= factor; - if(s > 1.f) - s = 1.f; - else if(s < 0.f) - s = 0.f; + s *= factor; + if (s > 1.f) + s = 1.f; + else if (s < 0.f) + s = 0.f; - hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } + hsvToRgb(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } } /////////////////////////////////////////////////////////////////// -template -static _CUDA_H void adjustSaturationCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const NDArray* factorScalarArr, const int dimC) { - - adjustSaturationCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, factorScalarArr->e(0), dimC); +template +static _CUDA_H void adjustSaturationCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, + const NDArray* factorScalarArr, const int dimC) { + adjustSaturationCuda<<>>( + vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, + factorScalarArr->e(0), dimC); } //////////////////////////////////////////////////////////////////////// -void adjustSaturation(sd::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) { - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), {dimC}); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {dimC}); - - const Nd4jLong numOfTads = packX.numberOfTads(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "adjustSaturation"); - - NDArray::prepareSpecialUse({output}, {input, factorScalarArr}); - BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturationCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, factorScalarArr, dimC), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input, factorScalarArr}); - - manager.synchronize(); +void adjustSaturation(sd::LaunchContext* context, const NDArray* input, + const NDArray* factorScalarArr, NDArray* output, + const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), {dimC}); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "adjustSaturation"); + + NDArray::prepareSpecialUse({output}, {input, factorScalarArr}); + BUILD_SINGLE_SELECTOR( + input->dataType(), adjustSaturationCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input->specialBuffer(), input->specialShapeInfo(), + packX.platformOffsets(), output->specialBuffer(), + output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, + factorScalarArr, dimC), + FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input, factorScalarArr}); + + manager.synchronize(); } /* template -static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) { - int numChannels = 3; - auto tid = threadIdx.x + blockIdx.x * blockDim.x; +static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong +*xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) +{ int numChannels = 3; auto tid = threadIdx.x + blockIdx.x * blockDim.x; auto bIn = reinterpret_cast(xBuffer); auto bOut = reinterpret_cast(zBuffer); @@ -117,7 +127,8 @@ static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong *xS T h, s, v; // Convert the RGB color to Hue/V-range. helpers::rgb_to_hsv(i[0], i[1], i[2], &h, &s, &v); - s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, s * delta)); + s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, s * +delta)); // Convert the hue and v-range back into RGB. helpers::hsv_to_rgb(h, s, v, o, o + 1, o + 2); @@ -125,10 +136,11 @@ static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong *xS } template -static void _CUDA_G adjustSaturationSingleNCHWKernel(void *xBuffer, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, void *zBuffer, Nd4jLong *zTadShapeInfo, Nd4jLong *zOffsets, Nd4jLong tadLength, Nd4jLong tuples, float delta) { - int numChannels = 3; - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - static const int kChannelRange = 6; +static void _CUDA_G adjustSaturationSingleNCHWKernel(void *xBuffer, Nd4jLong +*xTadShapeInfo, Nd4jLong *xOffsets, void *zBuffer, Nd4jLong *zTadShapeInfo, +Nd4jLong *zOffsets, Nd4jLong tadLength, Nd4jLong tuples, float delta) { int +numChannels = 3; auto tid = threadIdx.x + blockIdx.x * blockDim.x; static const +int kChannelRange = 6; auto bufferR = reinterpret_cast(xBuffer) + xOffsets[0]; auto bufferG = reinterpret_cast(xBuffer) + xOffsets[1]; @@ -150,62 +162,79 @@ static void _CUDA_G adjustSaturationSingleNCHWKernel(void *xBuffer, Nd4jLong *xT T h, s, v; // Convert the RGB color to Hue/V-range. helpers::rgb_to_hsv(_ri[0], _gi[0], _bi[0], &h, &s, &v); - s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, s * delta)); + s = sd::math::nd4j_min((T) 1.0f, sd::math::nd4j_max((T) 0.0f, s * +delta)); // Convert the hue and v-range back into RGB. helpers::hsv_to_rgb(h, s, v, _ro, _go, _bo); } } template -static void _adjust_saturation_single(sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { +static void _adjust_saturation_single(sd::LaunchContext * context, NDArray +*array, NDArray *output, float delta, bool isNHWC) { // numChannels is always 3 auto tuples = array->lengthOf() / 3; if (isNHWC) { - adjustSaturationSingleNHWCKernel<<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), array->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), tuples, delta); - } else { - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(array->shapeInfo(), {1, 2}); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {1, 2}); + adjustSaturationSingleNHWCKernel<<<256, 256, 1024, +*context->getCudaStream()>>>(array->specialBuffer(), array->specialShapeInfo(), +output->specialBuffer(), output->specialShapeInfo(), tuples, delta); } else { + auto packX = +sd::ConstantTadHelper::getInstance()->tadForDimensions(array->shapeInfo(), {1, +2}); auto packZ = +sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {1, +2}); auto tadLength = shape::length(packX.primaryShapeInfo()); - adjustSaturationSingleNCHWKernel<<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta); + adjustSaturationSingleNCHWKernel<<<256, 256, 1024, +*context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), +packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), +packZ.platformOffsets(), tadLength, tuples, delta); } } template -static void _adjust_saturation_batch(sd::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { - auto xType = array->dataType(); +static void _adjust_saturation_batch(sd::LaunchContext * context, NDArray +*array, NDArray *output, float delta, bool isNHWC) { auto xType = +array->dataType(); // numChannels is always 3 auto tuples = array->lengthOf() / 3; if (isNHWC) { - // in case of nhwc batch, we don't really care about examples: it's still bunch of RGB values - BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (context, array, output, delta, isNHWC);, FLOAT_TYPES); - } else { + // in case of nhwc batch, we don't really care about examples: it's +still bunch of RGB values BUILD_SINGLE_SELECTOR(xType, +_adjust_saturation_single, (context, array, output, delta, isNHWC);, +FLOAT_TYPES); } else { // TODO: check this one - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(array->shapeInfo(), {0, 2, 3}); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {0, 2, 3}); + auto packX = +sd::ConstantTadHelper::getInstance()->tadForDimensions(array->shapeInfo(), {0, +2, 3}); auto packZ = +sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {0, +2, 3}); auto tadLength = shape::length(packX.primaryShapeInfo()); - adjustSaturationSingleNCHWKernel<<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta); + adjustSaturationSingleNCHWKernel<<<256, 256, 1024, +*context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), +packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), +packZ.platformOffsets(), tadLength, tuples, delta); } } -void adjust_saturation(sd::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) { - auto xType = array->dataType(); +void adjust_saturation(sd::LaunchContext * context, NDArray *array, NDArray +*output, NDArray* delta, bool isNHWC) { auto xType = array->dataType(); float d = delta->e(0); if (array->rankOf() == 4) { - BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_batch, (context, array, output, d, isNHWC);, FLOAT_TYPES); - } else { - BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_batch, (context, array, +output, d, isNHWC);, FLOAT_TYPES); } else { BUILD_SINGLE_SELECTOR(xType, +_adjust_saturation_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); } } */ -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/axis.cu b/libnd4j/include/ops/declarable/helpers/cuda/axis.cu index 1dd00f688fd2..b5fd8249b838 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/axis.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/axis.cu @@ -20,32 +20,31 @@ #include - namespace sd { namespace ops { namespace helpers { - void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector& output) { - output.resize(axisVector->lengthOf()); - axisVector->tickReadDevice(); // mark input as read on device - axisVector->syncToHost(); // sync to host - for (int e = 0; e < axisVector->lengthOf(); e++) { - auto ca = axisVector->e(e); - if (ca < 0) // shift values on rank for negative vals - ca += rank; - - output[e] = ca; - } - } - - void adjustAxis(Nd4jLong rank, std::vector &axisVector) { - for (int e = 0; e < axisVector.size(); e++) { - auto a = axisVector[e]; - if (a < 0) // shift vals on rank for negative vals - axisVector[e] = a + rank; - } - } - -} +void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector& output) { + output.resize(axisVector->lengthOf()); + axisVector->tickReadDevice(); // mark input as read on device + axisVector->syncToHost(); // sync to host + for (int e = 0; e < axisVector->lengthOf(); e++) { + auto ca = axisVector->e(e); + if (ca < 0) // shift values on rank for negative vals + ca += rank; + + output[e] = ca; + } } + +void adjustAxis(Nd4jLong rank, std::vector& axisVector) { + for (int e = 0; e < axisVector.size(); e++) { + auto a = axisVector[e]; + if (a < 0) // shift vals on rank for negative vals + axisVector[e] = a + rank; + } } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu b/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu index 40540f65d3a1..b9678fdaa809 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/batched_gemm.cu @@ -19,14 +19,13 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include +#include +#include #include #include #include -#include -#include - namespace sd { namespace ops { @@ -34,145 +33,157 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////////// // bsxMXK x bSxKxN = bSxMxN -void bgemm(const std::vector& vA, const std::vector& vB, std::vector& vC, const NDArray* alphas, const NDArray* betas, int transA, int transB, int M, int N, int K, const int lda, const int ldb, const int ldc) { - - const auto bS = vA.size(); // batch size - - std::vector pA(bS), pB(bS), pC(bS); - - std::vector toDelete; - - for(int i = 0; i < bS; ++i) { - - if(vA[i]->ews() != 1) { - pA[i] = new NDArray(vA[i]->dup('f')); - toDelete.emplace_back(pA[i]); - } - else - pA[i] = vA[i]; - - if(vB[i]->ews() != 1) { - pB[i] = new NDArray(vB[i]->dup('f')); - toDelete.emplace_back(pB[i]); - } - else - pB[i] = vB[i]; - - if(vC[i]->ews() != 1) { - pC[i] = new NDArray(vC[i]->dup('f')); - toDelete.emplace_back(pC[i]); - } - else - pC[i] = vC[i]; - - if(pC[i]->ordering() != 'f') { - auto temp = pA[i]; - pA[i] = new NDArray(pB[i]->permute({1,0})); - pB[i] = new NDArray(temp ->permute({1,0})); - pC[i] = new NDArray(pC[i]->permute({1,0})); - toDelete.push_back(pA[i]); - toDelete.push_back(pB[i]); - toDelete.push_back(pC[i]); - M = pA[i]->sizeAt(0); - K = pA[i]->sizeAt(1); - N = pB[i]->sizeAt(1); - } - - NDArray::prepareSpecialUse ({pC[i]}, {pA[i], pB[i]}); - NDArray::registerSpecialUse({pC[i]}, {pA[i], pB[i]}); - } - - NDArray::prepareSpecialUse ({}, {alphas, betas}); - NDArray::registerSpecialUse({}, {alphas, betas}); - - std::vector pAbuffs(bS), pBbuffs(bS), pCbuffs(bS); - for(int i = 0; i < bS; ++i) { - pAbuffs[i] = pA[i]->specialBuffer(); - pBbuffs[i] = pB[i]->specialBuffer(); - pCbuffs[i] = pC[i]->specialBuffer(); +void bgemm(const std::vector& vA, const std::vector& vB, + std::vector& vC, const NDArray* alphas, + const NDArray* betas, int transA, int transB, int M, int N, int K, + const int lda, const int ldb, const int ldc) { + const auto bS = vA.size(); // batch size + + std::vector pA(bS), pB(bS), pC(bS); + + std::vector toDelete; + + for (int i = 0; i < bS; ++i) { + if (vA[i]->ews() != 1) { + pA[i] = new NDArray(vA[i]->dup('f')); + toDelete.emplace_back(pA[i]); + } else + pA[i] = vA[i]; + + if (vB[i]->ews() != 1) { + pB[i] = new NDArray(vB[i]->dup('f')); + toDelete.emplace_back(pB[i]); + } else + pB[i] = vB[i]; + + if (vC[i]->ews() != 1) { + pC[i] = new NDArray(vC[i]->dup('f')); + toDelete.emplace_back(pC[i]); + } else + pC[i] = vC[i]; + + if (pC[i]->ordering() != 'f') { + auto temp = pA[i]; + pA[i] = new NDArray(pB[i]->permute({1, 0})); + pB[i] = new NDArray(temp->permute({1, 0})); + pC[i] = new NDArray(pC[i]->permute({1, 0})); + toDelete.push_back(pA[i]); + toDelete.push_back(pB[i]); + toDelete.push_back(pC[i]); + M = pA[i]->sizeAt(0); + K = pA[i]->sizeAt(1); + N = pB[i]->sizeAt(1); } - sd::LaunchContext* context = vA[0]->getContext(); - PointersManager manager(context, "helpers::bgemm cuda"); - - const void** aBuffers = reinterpret_cast(manager.replicatePointer(pAbuffs.data(), bS * sizeof(void*))); - const void** bBuffers = reinterpret_cast(manager.replicatePointer(pBbuffs.data(), bS * sizeof(void*))); - void** cBuffers = reinterpret_cast(manager.replicatePointer(pCbuffs.data(), bS * sizeof(void*))); - - // const auto aOrder = pA->ordering(); - // const auto bOrder = pB->ordering(); - - // const bool transA = aOrder != 'f'; - // const bool transB = bOrder != 'f'; - - const cublasOperation_t transAblas = transA == 112 ? CUBLAS_OP_T : CUBLAS_OP_N; - const cublasOperation_t transBblas = transB == 112 ? CUBLAS_OP_T : CUBLAS_OP_N; - - // const int lda = aOrder == 'f' ? M : K; - // const int ldb = bOrder == 'f' ? K : N; - // const int ldc = M; // cOrder == 'f' ? M : N; - - const auto aType = pA[0]->dataType(); - const auto bType = pB[0]->dataType(); - const auto cType = pC[0]->dataType(); - - std::lock_guard lock(*LaunchContext::deviceMutex()); - - auto handle = reinterpret_cast(context->getCublasHandle()); - auto stream = context->getCudaStream(); - - auto status = cublasSetStream_v2(*handle, *stream); - - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); - - const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); - - // choose appropriate cuda gemm api depending on data types - if(ABC && aType == DataType::DOUBLE) { - double alpha = alphas->e(0); - double beta = betas->e(0); - status = cublasDgemmBatched(*handle, transAblas, transBblas, M, N, K, &alpha, (const double**)aBuffers, lda, (const double**)bBuffers, ldb, &beta, (double**)cBuffers, ldc, bS); - } - else if(ABC && aType == DataType::FLOAT32) { - float alpha = alphas->e(0); - float beta = betas->e(0); - status = cublasSgemmBatched(*handle, transAblas, transBblas, M, N, K, &alpha, (const float**)aBuffers, lda, (const float**)bBuffers, ldb, &beta, (float**)cBuffers, ldc, bS); - } - else if(ABC && aType == DataType::HALF) { - __half alpha = alphas->e(0); - __half beta = betas->e(0); - status = cublasHgemmBatched(*handle, transAblas, transBblas, M, N, K, &alpha, (const __half**)aBuffers, lda, (const __half**)bBuffers, ldb, &beta, (__half**)cBuffers, ldc, bS); - } - else if(AB && aType == DataType::INT8 && cType == DataType::FLOAT32) { - float alpha = alphas->e(0); - float beta = betas->e(0); - status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, &alpha, aBuffers, CUDA_R_8I, lda, bBuffers, CUDA_R_8I, ldb, &beta, cBuffers, CUDA_R_32F, ldc, bS, CUDA_R_32F, CUBLAS_GEMM_DEFAULT); - } - else if(AB && aType == DataType::HALF && cType == DataType::FLOAT32) { - float alpha = alphas->e(0); - float beta = betas->e(0); - status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, &alpha, aBuffers, CUDA_R_16F, lda, bBuffers, CUDA_R_16F, ldb, &beta, cBuffers, CUDA_R_32F, ldc, bS, CUDA_R_32F, CUBLAS_GEMM_DEFAULT); - } - else - throw std::runtime_error("batched gemm cuda: this mode is not implemented yet !"); - - if (status != CUBLAS_STATUS_SUCCESS) - throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); - - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) - throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult); - - for(int i = 0; i < bS; ++i) - if(vC[i]->ews() != 1) - vC[i]->assign(pC[i]); - - for(int i = toDelete.size() - 1; i >= 0; --i) - delete toDelete[i]; -} - -} -} + NDArray::prepareSpecialUse({pC[i]}, {pA[i], pB[i]}); + NDArray::registerSpecialUse({pC[i]}, {pA[i], pB[i]}); + } + + NDArray::prepareSpecialUse({}, {alphas, betas}); + NDArray::registerSpecialUse({}, {alphas, betas}); + + std::vector pAbuffs(bS), pBbuffs(bS), pCbuffs(bS); + for (int i = 0; i < bS; ++i) { + pAbuffs[i] = pA[i]->specialBuffer(); + pBbuffs[i] = pB[i]->specialBuffer(); + pCbuffs[i] = pC[i]->specialBuffer(); + } + + sd::LaunchContext* context = vA[0]->getContext(); + PointersManager manager(context, "helpers::bgemm cuda"); + + const void** aBuffers = reinterpret_cast( + manager.replicatePointer(pAbuffs.data(), bS * sizeof(void*))); + const void** bBuffers = reinterpret_cast( + manager.replicatePointer(pBbuffs.data(), bS * sizeof(void*))); + void** cBuffers = reinterpret_cast( + manager.replicatePointer(pCbuffs.data(), bS * sizeof(void*))); + + // const auto aOrder = pA->ordering(); + // const auto bOrder = pB->ordering(); + + // const bool transA = aOrder != 'f'; + // const bool transB = bOrder != 'f'; + + const cublasOperation_t transAblas = + transA == 112 ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t transBblas = + transB == 112 ? CUBLAS_OP_T : CUBLAS_OP_N; + + // const int lda = aOrder == 'f' ? M : K; + // const int ldb = bOrder == 'f' ? K : N; + // const int ldc = M; // cOrder == 'f' ? M : N; + + const auto aType = pA[0]->dataType(); + const auto bType = pB[0]->dataType(); + const auto cType = pC[0]->dataType(); + + std::lock_guard lock(*LaunchContext::deviceMutex()); + + auto handle = reinterpret_cast(context->getCublasHandle()); + auto stream = context->getCudaStream(); + + auto status = cublasSetStream_v2(*handle, *stream); + + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + + const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); + + // choose appropriate cuda gemm api depending on data types + if (ABC && aType == DataType::DOUBLE) { + double alpha = alphas->e(0); + double beta = betas->e(0); + status = cublasDgemmBatched(*handle, transAblas, transBblas, M, N, K, + &alpha, (const double**)aBuffers, lda, + (const double**)bBuffers, ldb, &beta, + (double**)cBuffers, ldc, bS); + } else if (ABC && aType == DataType::FLOAT32) { + float alpha = alphas->e(0); + float beta = betas->e(0); + status = cublasSgemmBatched(*handle, transAblas, transBblas, M, N, K, + &alpha, (const float**)aBuffers, lda, + (const float**)bBuffers, ldb, &beta, + (float**)cBuffers, ldc, bS); + } else if (ABC && aType == DataType::HALF) { + __half alpha = alphas->e(0); + __half beta = betas->e(0); + status = cublasHgemmBatched(*handle, transAblas, transBblas, M, N, K, + &alpha, (const __half**)aBuffers, lda, + (const __half**)bBuffers, ldb, &beta, + (__half**)cBuffers, ldc, bS); + } else if (AB && aType == DataType::INT8 && cType == DataType::FLOAT32) { + float alpha = alphas->e(0); + float beta = betas->e(0); + status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, + &alpha, aBuffers, CUDA_R_8I, lda, bBuffers, + CUDA_R_8I, ldb, &beta, cBuffers, CUDA_R_32F, + ldc, bS, CUDA_R_32F, CUBLAS_GEMM_DEFAULT); + } else if (AB && aType == DataType::HALF && cType == DataType::FLOAT32) { + float alpha = alphas->e(0); + float beta = betas->e(0); + status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, + &alpha, aBuffers, CUDA_R_16F, lda, bBuffers, + CUDA_R_16F, ldb, &beta, cBuffers, CUDA_R_32F, + ldc, bS, CUDA_R_32F, CUBLAS_GEMM_DEFAULT); + } else + throw std::runtime_error( + "batched gemm cuda: this mode is not implemented yet !"); + + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", + cudaResult); + + for (int i = 0; i < bS; ++i) + if (vC[i]->ews() != 1) vC[i]->assign(pC[i]); + + for (int i = toDelete.size() - 1; i >= 0; --i) delete toDelete[i]; } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu b/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu index 791953ab715f..130426548d04 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu @@ -18,29 +18,25 @@ // @author Yurii Shyrma, created on 25.02.2018 // - -#include -#include -#include #include +#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// // template -// __global__ static void batchnormCuda(const void* vx, const Nd4jLong* xShapeInfo, -// const void* vMean, const Nd4jLong* meanShapeInfo, -// const void* vVariance, const Nd4jLong* varianceShapeInfo, -// const void* vGamma, const Nd4jLong* gammaShapeInfo, -// const void* vBeta, const Nd4jLong* betaShapeInfo, -// void* vz, const Nd4jLong* zShapeInfo, -// const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, -// const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, -// const T epsilon) { +// __global__ static void batchnormCuda(const void* vx, const Nd4jLong* +// xShapeInfo, const void* vMean, const Nd4jLong* meanShapeInfo, const void* +// vVariance, const Nd4jLong* varianceShapeInfo, const void* vGamma, const +// Nd4jLong* gammaShapeInfo, const void* vBeta, const Nd4jLong* betaShapeInfo, +// void* vz, const Nd4jLong* +// zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const +// Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, const T epsilon) { // const auto x = reinterpret_cast(vx); // auto z = reinterpret_cast(vz); @@ -49,7 +45,8 @@ namespace helpers { // const auto gamma = reinterpret_cast(vGamma); // const auto beta = reinterpret_cast(vBeta); -// // maxRank = xRank = zRank, minRank = meanRank = varianceRank = gammaRank = betaRank +// // maxRank = xRank = zRank, minRank = meanRank = varianceRank = gammaRank +// = betaRank // __shared__ Nd4jLong minLen, tadLen, totalThreads; // if (threadIdx.x == 0) { @@ -64,10 +61,12 @@ namespace helpers { // for (uint i = tid; i < minLen; i += totalThreads) { -// const auto meanOffset = shape::getIndexOffset(i, meanShapeInfo); +// const auto meanOffset = shape::getIndexOffset(i, +// meanShapeInfo); // const auto varianceOffset = shape::getIndexOffset(i, varianceShapeInfo); -// T sigmaInvGam = 1. / sd::math::nd4j_sqrt(variance[varianceOffset] + epsilon); +// T sigmaInvGam = 1. / sd::math::nd4j_sqrt(variance[varianceOffset] +// + epsilon); // if(gamma != nullptr) // sigmaInvGam *= gamma[shape::getIndexOffset(i, gammaShapeInfo)]; @@ -84,7 +83,8 @@ namespace helpers { // const auto xTadOffset = shape::getIndexOffset(j, xTadShapeInfo); // const auto zTadOffset = shape::getIndexOffset(j, zTadShapeInfo); -// zTad[zTadOffset] = (xTad[xTadOffset] - mean[meanOffset]) * sigmaInvGam; +// zTad[zTadOffset] = (xTad[xTadOffset] - mean[meanOffset]) * +// sigmaInvGam; // if(beta != nullptr) // zTad[zTadOffset] += beta[betaOffset]; @@ -93,145 +93,175 @@ namespace helpers { // } ////////////////////////////////////////////////////////////////////////// -template -__global__ static void batchnormCuda2(const void* vx, const Nd4jLong* xShapeInfo, - const void* vMean, const Nd4jLong* meanShapeInfo, - const void* vVariance, const Nd4jLong* varianceShapeInfo, - const void* vGamma, const Nd4jLong* gammaShapeInfo, - const void* vBeta, const Nd4jLong* betaShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int numDims, const int* dims, - const T epsilon) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - const auto mean = reinterpret_cast(vMean); - const auto variance = reinterpret_cast(vVariance); - const auto gamma = reinterpret_cast(vGamma); - const auto beta = reinterpret_cast(vBeta); - - __shared__ int xRank, minRank; // xRank == zRank, minRank = meanRank = varianceRank = gammaRank = betaRank - __shared__ Nd4jLong xLen, totalThreads; // xLen = zLen - - - if (threadIdx.x == 0) { - - totalThreads = gridDim.x * blockDim.x; - - xLen = shape::length(xShapeInfo); - xRank = shape::rank(xShapeInfo); - minRank = shape::rank(meanShapeInfo); +template +__global__ static void batchnormCuda2( + const void* vx, const Nd4jLong* xShapeInfo, const void* vMean, + const Nd4jLong* meanShapeInfo, const void* vVariance, + const Nd4jLong* varianceShapeInfo, const void* vGamma, + const Nd4jLong* gammaShapeInfo, const void* vBeta, + const Nd4jLong* betaShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + const int numDims, const int* dims, const T epsilon) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + const auto mean = reinterpret_cast(vMean); + const auto variance = reinterpret_cast(vVariance); + const auto gamma = reinterpret_cast(vGamma); + const auto beta = reinterpret_cast(vBeta); + + __shared__ int xRank, minRank; // xRank == zRank, minRank = meanRank = + // varianceRank = gammaRank = betaRank + __shared__ Nd4jLong xLen, totalThreads; // xLen = zLen + + if (threadIdx.x == 0) { + totalThreads = gridDim.x * blockDim.x; + + xLen = shape::length(xShapeInfo); + xRank = shape::rank(xShapeInfo); + minRank = shape::rank(meanShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (uint i = tid; i < xLen; i += totalThreads) { + shape::index2coords(i, xShapeInfo, coords); + + const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); + + if (minRank == xRank) { + for (uint i = 0, j = 0; i < xRank; ++i) { + if (j < numDims && i != dims[j]) + coords[i] = 0; + else + ++j; + } + } else // minRank = numDims = 1 in this case + coords[0] = coords[dims[0]]; + + const auto meanOffset = shape::getOffset(meanShapeInfo, coords); + const auto varianceOffset = shape::getOffset(varianceShapeInfo, coords); + + T sigmaInvGam = + 1. / sd::math::nd4j_sqrt(variance[varianceOffset] + epsilon); + + if (gamma != nullptr) { + const auto gammaOffset = shape::getOffset(gammaShapeInfo, coords); + sigmaInvGam *= gamma[gammaOffset]; } - __syncthreads(); - - int coords[MAX_RANK]; - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (uint i = tid; i < xLen; i += totalThreads) { + z[zOffset] = (x[xOffset] - mean[meanOffset]) * sigmaInvGam; - shape::index2coords(i, xShapeInfo, coords); - - const auto xOffset = shape::getOffset(xShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - if(minRank == xRank) { - for (uint i = 0, j = 0; i < xRank; ++i) { - if(j < numDims && i != dims[j]) - coords[i] = 0; - else - ++j; - } - } - else // minRank = numDims = 1 in this case - coords[0] = coords[dims[0]]; - - const auto meanOffset = shape::getOffset(meanShapeInfo, coords); - const auto varianceOffset = shape::getOffset(varianceShapeInfo, coords); - - T sigmaInvGam = 1. / sd::math::nd4j_sqrt(variance[varianceOffset] + epsilon); - - if(gamma != nullptr) { - const auto gammaOffset = shape::getOffset(gammaShapeInfo, coords); - sigmaInvGam *= gamma[gammaOffset]; - } - - z[zOffset] = (x[xOffset] - mean[meanOffset]) * sigmaInvGam; - - if(beta != nullptr) { - const auto betaOffset = shape::getOffset(betaShapeInfo, coords); - z[zOffset] += beta[betaOffset]; - } + if (beta != nullptr) { + const auto betaOffset = shape::getOffset(betaShapeInfo, coords); + z[zOffset] += beta[betaOffset]; } + } } /////////////////////////////////////////////////////////////////// // template -// __host__ static void batchnormCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, -// const void* vx, const Nd4jLong* xShapeInfo, -// const void* vMean, const Nd4jLong* meanShapeInfo, -// const void* vVariance, const Nd4jLong* varianceShapeInfo, -// const void* vGamma, const Nd4jLong* gammaShapeInfo, -// const void* vBeta, const Nd4jLong* betaShapeInfo, -// void* vz, const Nd4jLong* zShapeInfo, -// const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, -// const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, -// const double epsilon) { - -// batchnormCuda<<>>(vx, xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, vGamma, gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, static_cast(epsilon)); +// __host__ static void batchnormCudaLauncher(const int blocksPerGrid, const int +// threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* +// xShapeInfo, +// const void* vMean, const +// Nd4jLong* meanShapeInfo, +// const void* vVariance, +// const Nd4jLong* varianceShapeInfo, const void* vGamma, const Nd4jLong* +// gammaShapeInfo, const void* vBeta, const Nd4jLong* betaShapeInfo, void* vz, +// const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* +// xTadOffsets, const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, +// const double epsilon) +// { + +// batchnormCuda<<>>(vx, +// xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, vGamma, +// gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, xTadShapeInfo, +// xTadOffsets, zTadShapeInfo, zTadOffsets, static_cast(epsilon)); // } /////////////////////////////////////////////////////////////////// -template -__host__ static void batchnormCudaLauncher2(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vMean, const Nd4jLong* meanShapeInfo, - const void* vVariance, const Nd4jLong* varianceShapeInfo, - const void* vGamma, const Nd4jLong* gammaShapeInfo, - const void* vBeta, const Nd4jLong* betaShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int numDims, const int* dims, - const double epsilon) { - - batchnormCuda2<<>>(vx, xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, vGamma, gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, numDims, dims, static_cast(epsilon)); +template +__host__ static void batchnormCudaLauncher2( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vMean, const Nd4jLong* meanShapeInfo, const void* vVariance, + const Nd4jLong* varianceShapeInfo, const void* vGamma, + const Nd4jLong* gammaShapeInfo, const void* vBeta, + const Nd4jLong* betaShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + const int numDims, const int* dims, const double epsilon) { + batchnormCuda2<<>>( + vx, xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, + vGamma, gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, numDims, + dims, static_cast(epsilon)); } ////////////////////////////////////////////////////////////////////////// -void batchnorm(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon) { - - // std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), axes); - - // auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimsToExclude); - // auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimsToExclude); - - // const int threadsPerBlock = MAX_NUM_THREADS / 2; - // const int blocksPerGrid = (mean->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - // PointersManager manager(input->getContext(), "batchnorm"); - - // NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, beta}); - // BUILD_SINGLE_SELECTOR(input->dataType(), batchnormCudaLauncher, (blocksPerGrid, threadsPerBlock, input->getContext()->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), mean->specialBuffer(), mean->specialShapeInfo(), variance->specialBuffer(), variance->specialShapeInfo(), gamma ? gamma->specialBuffer() : nullptr, gamma ? gamma->specialShapeInfo() : nullptr, beta ? beta->specialBuffer() : nullptr, beta ? beta->specialShapeInfo() : nullptr, output->specialBuffer(), output->specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), epsilon), FLOAT_TYPES); - // NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); - - // manager.synchronize(); - - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (input->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(input->getContext(), "batchnorm"); - - const int* dims = reinterpret_cast(manager.replicatePointer(axes.data(), axes.size() * sizeof(int))); - - NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, beta}); - BUILD_SINGLE_SELECTOR(input->dataType(), batchnormCudaLauncher2, (blocksPerGrid, threadsPerBlock, input->getContext()->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), mean->specialBuffer(), mean->specialShapeInfo(), variance->specialBuffer(), variance->specialShapeInfo(), gamma ? gamma->specialBuffer() : nullptr, gamma ? gamma->specialShapeInfo() : nullptr, beta ? beta->specialBuffer() : nullptr, beta ? beta->specialShapeInfo() : nullptr, output->specialBuffer(), output->specialShapeInfo(), axes.size(), dims, epsilon), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); - - manager.synchronize(); -} - - -} -} +void batchnorm(const NDArray* input, const NDArray* mean, + const NDArray* variance, const NDArray* gamma, + const NDArray* beta, NDArray* output, + const std::vector& axes, const double epsilon) { + // std::vector dimsToExclude = + // ShapeUtils::evalDimsToExclude(input->rankOf(), axes); + + // auto packX = + // sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), + // dimsToExclude); + // auto packZ = + // sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), + // dimsToExclude); + + // const int threadsPerBlock = MAX_NUM_THREADS / 2; + // const int blocksPerGrid = (mean->lengthOf() + threadsPerBlock - 1) / + // threadsPerBlock; + + // PointersManager manager(input->getContext(), "batchnorm"); + + // NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, + // beta}); BUILD_SINGLE_SELECTOR(input->dataType(), batchnormCudaLauncher, + // (blocksPerGrid, threadsPerBlock, input->getContext()->getCudaStream(), + // input->specialBuffer(), input->specialShapeInfo(), + // mean->specialBuffer(), mean->specialShapeInfo(), + // variance->specialBuffer(), variance->specialShapeInfo(), gamma ? + // gamma->specialBuffer() : nullptr, gamma ? gamma->specialShapeInfo() : + // nullptr, beta ? beta->specialBuffer() : nullptr, beta ? + // beta->specialShapeInfo() : nullptr, output->specialBuffer(), + // output->specialShapeInfo(), packX.platformShapeInfo(), + // packX.platformOffsets(), packZ.platformShapeInfo(), + // packZ.platformOffsets(), epsilon), FLOAT_TYPES); + // NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, + // beta}); + + // manager.synchronize(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (input->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(input->getContext(), "batchnorm"); + + const int* dims = reinterpret_cast( + manager.replicatePointer(axes.data(), axes.size() * sizeof(int))); + + NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, beta}); + BUILD_SINGLE_SELECTOR( + input->dataType(), batchnormCudaLauncher2, + (blocksPerGrid, threadsPerBlock, input->getContext()->getCudaStream(), + input->specialBuffer(), input->specialShapeInfo(), mean->specialBuffer(), + mean->specialShapeInfo(), variance->specialBuffer(), + variance->specialShapeInfo(), gamma ? gamma->specialBuffer() : nullptr, + gamma ? gamma->specialShapeInfo() : nullptr, + beta ? beta->specialBuffer() : nullptr, + beta ? beta->specialShapeInfo() : nullptr, output->specialBuffer(), + output->specialShapeInfo(), axes.size(), dims, epsilon), + FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); + + manager.synchronize(); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu b/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu index a18ec1fda8f7..fb7805405688 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu @@ -18,177 +18,179 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include -#include #include +#include + +#include namespace sd { namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// // modified Lentz’s algorithm for continued fractions, -// reference: Lentz, W.J. 1976, “Generating Bessel Functions in Mie Scattering Calculations Using Continued Fractions,” +// reference: Lentz, W.J. 1976, “Generating Bessel Functions in Mie Scattering +// Calculations Using Continued Fractions,” template __device__ T continuedFractionCuda(const T a, const T b, const T x) { - - extern __shared__ unsigned char shmem[]; - T* coeffs = reinterpret_cast(shmem); - - const T min = DataTypeUtils::min() / DataTypeUtils::eps(); - const T aPlusb = a + b; - T val, aPlus2i; - - T t2 = coeffs[1]; - T t1 = coeffs[0]; - if(math::nd4j_abs(t1) < min) - t1 = min; - t1 = static_cast(1) / t1; - T result = t1; - - for(uint i = 1; i <= maxIter; ++i) { - - const uint i2 = 2*i; - aPlus2i = a + static_cast(i2); - - // t1 - t1 = static_cast(1) + coeffs[i2] * t1; - if(math::nd4j_abs(t1) < min) - t1 = min; - t1 = static_cast(1) / t1; - // t2 - t2 = static_cast(1) + coeffs[i2] / t2; - if(math::nd4j_abs(t2) < min) - t2 = min; - // result - result *= t2 * t1; - // t1 - t1 = static_cast(1) + coeffs[i2 + 1] * t1; - if(math::nd4j_abs(t1) < min) - t1 = min; - t1 = static_cast(1) / t1; - // t2 - t2 = static_cast(1) + coeffs[i2 + 1] / t2; - if(math::nd4j_abs(t2) < min) - t2 = min; - // result - val = t2 * t1; - result *= val; - - // condition to stop loop - if(math::nd4j_abs(val - static_cast(1)) <= DataTypeUtils::eps()) - return result; - } - - return DataTypeUtils::infOrMax(); // no convergence, more iterations is required, return infinity + extern __shared__ unsigned char shmem[]; + T* coeffs = reinterpret_cast(shmem); + + const T min = DataTypeUtils::min() / DataTypeUtils::eps(); + const T aPlusb = a + b; + T val, aPlus2i; + + T t2 = coeffs[1]; + T t1 = coeffs[0]; + if (math::nd4j_abs(t1) < min) t1 = min; + t1 = static_cast(1) / t1; + T result = t1; + + for (uint i = 1; i <= maxIter; ++i) { + const uint i2 = 2 * i; + aPlus2i = a + static_cast(i2); + + // t1 + t1 = static_cast(1) + coeffs[i2] * t1; + if (math::nd4j_abs(t1) < min) t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + coeffs[i2] / t2; + if (math::nd4j_abs(t2) < min) t2 = min; + // result + result *= t2 * t1; + // t1 + t1 = static_cast(1) + coeffs[i2 + 1] * t1; + if (math::nd4j_abs(t1) < min) t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + coeffs[i2 + 1] / t2; + if (math::nd4j_abs(t2) < min) t2 = min; + // result + val = t2 * t1; + result *= val; + + // condition to stop loop + if (math::nd4j_abs(val - static_cast(1)) <= DataTypeUtils::eps()) + return result; + } + + return DataTypeUtils::infOrMax(); // no convergence, more iterations is + // required, return infinity } /////////////////////////////////////////////////////////////////// -template +template __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo, - const void* vb, const Nd4jLong* bShapeInfo, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo) { - - extern __shared__ unsigned char shmem[]; - T* sharedMem = reinterpret_cast(shmem); - - const Nd4jLong j = blockIdx.x; // one block per each element - - T& z = *(reinterpret_cast(vz) + shape::getIndexOffset(j, zShapeInfo)); - - __shared__ T a, b, x; - __shared__ bool symmCond; - - if (threadIdx.x == 0) { - - a = *(reinterpret_cast(va) + shape::getIndexOffset(j, aShapeInfo)); - b = *(reinterpret_cast(vb) + shape::getIndexOffset(j, bShapeInfo)); - x = *(reinterpret_cast(vx) + shape::getIndexOffset(j, xShapeInfo)); - - symmCond = x > (a + static_cast(1)) / (a + b + static_cast(2)); - - if(symmCond) { // swap a and b, x = 1 - x - T temp = a; - a = b; - b = temp; - x = static_cast(1) - x; - } - + const void* vb, const Nd4jLong* bShapeInfo, + const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo) { + extern __shared__ unsigned char shmem[]; + T* sharedMem = reinterpret_cast(shmem); + + const Nd4jLong j = blockIdx.x; // one block per each element + + T& z = *(reinterpret_cast(vz) + shape::getIndexOffset(j, zShapeInfo)); + + __shared__ T a, b, x; + __shared__ bool symmCond; + + if (threadIdx.x == 0) { + a = *(reinterpret_cast(va) + + shape::getIndexOffset(j, aShapeInfo)); + b = *(reinterpret_cast(vb) + + shape::getIndexOffset(j, bShapeInfo)); + x = *(reinterpret_cast(vx) + + shape::getIndexOffset(j, xShapeInfo)); + + symmCond = x > (a + static_cast(1)) / (a + b + static_cast(2)); + + if (symmCond) { // swap a and b, x = 1 - x + T temp = a; + a = b; + b = temp; + x = static_cast(1) - x; } - __syncthreads(); - - // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5 - if(a == b && x == static_cast(0.5)) { - z = static_cast(0.5); - return; - } - - if (x == static_cast(0) || x == static_cast(1)) { - z = symmCond ? static_cast(1) - x : x; - return; - } - - // calculate two coefficients per thread - if(threadIdx.x != 0) { - - const int i = threadIdx.x; - const T aPlus2i = a + 2*i; - sharedMem[2*i] = i * (b - i) * x / ((aPlus2i - static_cast(1)) * aPlus2i); - sharedMem[2*i + 1] = -(a + i) * (a + b + i) * x / ((aPlus2i + static_cast(1)) * aPlus2i); - } - - __syncthreads(); - - if(threadIdx.x == 0) { - - const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); - const T front = math::nd4j_exp(math::nd4j_log(x) * a + math::nd4j_log(1.f - x) * b - gammaPart); - - sharedMem[0] = static_cast(1) - (a + b) * x / (a + static_cast(1)); - sharedMem[1] = static_cast(1); - - z = front * continuedFractionCuda(a, b, x) / a; - - if(symmCond) // symmetry relation - z = static_cast(1) - z; - } + } + __syncthreads(); + + // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5 + if (a == b && x == static_cast(0.5)) { + z = static_cast(0.5); + return; + } + + if (x == static_cast(0) || x == static_cast(1)) { + z = symmCond ? static_cast(1) - x : x; + return; + } + + // calculate two coefficients per thread + if (threadIdx.x != 0) { + const int i = threadIdx.x; + const T aPlus2i = a + 2 * i; + sharedMem[2 * i] = + i * (b - i) * x / ((aPlus2i - static_cast(1)) * aPlus2i); + sharedMem[2 * i + 1] = + -(a + i) * (a + b + i) * x / ((aPlus2i + static_cast(1)) * aPlus2i); + } + + __syncthreads(); + + if (threadIdx.x == 0) { + const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); + const T front = + math::nd4j_exp(math::nd4j_log(x) * a + + math::nd4j_log(1.f - x) * b - gammaPart); + + sharedMem[0] = static_cast(1) - (a + b) * x / (a + static_cast(1)); + sharedMem[1] = static_cast(1); + + z = front * continuedFractionCuda(a, b, x) / a; + + if (symmCond) // symmetry relation + z = static_cast(1) - z; + } } /////////////////////////////////////////////////////////////////// -template -static void betaIncForArrayCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* va, const Nd4jLong* aShapeInfo, - const void* vb, const Nd4jLong* bShapeInfo, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo) { - - betaIncForArrayCuda<<>>(va, aShapeInfo, vb, bShapeInfo, vx, xShapeInfo, vz, zShapeInfo); +template +static void betaIncForArrayCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* va, const Nd4jLong* aShapeInfo, + const void* vb, const Nd4jLong* bShapeInfo, const void* vx, + const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo) { + betaIncForArrayCuda + <<>>( + va, aShapeInfo, vb, bShapeInfo, vx, xShapeInfo, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// // overload betaInc for arrays, shapes of a, b and x must be the same !!! -void betaInc(sd::LaunchContext* context, const NDArray& a, const NDArray& b, const NDArray& x, NDArray& output) { - - const int threadsPerBlock = maxIter; - const int blocksPerGrid = output.lengthOf(); - const int sharedMem = 2 * output.sizeOfT() * threadsPerBlock + 128; - - const auto xType = x.dataType(); - - PointersManager manager(context, "betaInc"); - - NDArray::prepareSpecialUse({&output}, {&a, &b, &x}); - BUILD_SINGLE_SELECTOR(xType, betaIncForArrayCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), a.specialBuffer(), a.specialShapeInfo(), b.specialBuffer(), b.specialShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&a, &b, &x}); - - manager.synchronize(); -} - - -} -} +void betaInc(sd::LaunchContext* context, const NDArray& a, const NDArray& b, + const NDArray& x, NDArray& output) { + const int threadsPerBlock = maxIter; + const int blocksPerGrid = output.lengthOf(); + const int sharedMem = 2 * output.sizeOfT() * threadsPerBlock + 128; + + const auto xType = x.dataType(); + + PointersManager manager(context, "betaInc"); + + NDArray::prepareSpecialUse({&output}, {&a, &b, &x}); + BUILD_SINGLE_SELECTOR( + xType, betaIncForArrayCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + a.specialBuffer(), a.specialShapeInfo(), b.specialBuffer(), + b.specialShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + output.specialBuffer(), output.specialShapeInfo()), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&a, &b, &x}); + + manager.synchronize(); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu b/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu index 62f60cc733b2..5525584da992 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu @@ -19,188 +19,210 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -// columns [bS, iC, kH, kW, oH, oW] to be de-convoluted to image [bS, iC, iH, iW] +// columns [bS, iC, kH, kW, oH, oW] to be de-convoluted to image [bS, iC, iH, +// iW] template -static __global__ void col2imCuda(const void* columns, const Nd4jLong* colShapeInfo, void* image, const Nd4jLong* imShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { - - const T* col = reinterpret_cast(columns); - T* im = reinterpret_cast(image); +static __global__ void col2imCuda(const void* columns, + const Nd4jLong* colShapeInfo, void* image, + const Nd4jLong* imShapeInfo, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW) { + const T* col = reinterpret_cast(columns); + T* im = reinterpret_cast(image); - __shared__ uint kH, kW, oH, oW, *sharedMem; - __shared__ Nd4jLong imLen; + __shared__ uint kH, kW, oH, oW, *sharedMem; + __shared__ Nd4jLong imLen; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - kH = dH * (colShapeInfo[3] - 1) + 1; - kW = dW * (colShapeInfo[4] - 1) + 1; + kH = dH * (colShapeInfo[3] - 1) + 1; + kW = dW * (colShapeInfo[4] - 1) + 1; - oH = colShapeInfo[5]; - oW = colShapeInfo[6]; + oH = colShapeInfo[5]; + oW = colShapeInfo[6]; - imLen = shape::length(imShapeInfo); - } - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * 6; + imLen = shape::length(imShapeInfo); + } + __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto coords = sharedMem + threadIdx.x * 6; - for (Nd4jLong i = tid; i < imLen; i += gridDim.x * blockDim.x) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - shape::index2coords(i, imShapeInfo, coords); + for (Nd4jLong i = tid; i < imLen; i += gridDim.x * blockDim.x) { + shape::index2coords(i, imShapeInfo, coords); - const auto imOffset = shape::getOffset(imShapeInfo, coords); + const auto imOffset = shape::getOffset(imShapeInfo, coords); - const auto bSiCoffset = coords[0] * colShapeInfo[7] + coords[1] * colShapeInfo[8]; + const auto bSiCoffset = + coords[0] * colShapeInfo[7] + coords[1] * colShapeInfo[8]; - const uint imH = coords[2] + pH; - const uint imW = coords[3] + pW; + const uint imH = coords[2] + pH; + const uint imW = coords[3] + pW; - const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1; - const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1; + const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1; + const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1; - const uint colHend = sd::math::nd4j_min(imH / sH + 1, oH); - const uint colWend = sd::math::nd4j_min(imW / sW + 1, oW); + const uint colHend = sd::math::nd4j_min(imH / sH + 1, oH); + const uint colWend = sd::math::nd4j_min(imW / sW + 1, oW); - T val = 0; + T val = 0; - for(coords[4] = colHstart; coords[4] < colHend; ++coords[4]) { - coords[2] = imH - coords[4] * sH; - if(coords[2] % dH != 0) continue; + for (coords[4] = colHstart; coords[4] < colHend; ++coords[4]) { + coords[2] = imH - coords[4] * sH; + if (coords[2] % dH != 0) continue; - for(coords[5] = colWstart; coords[5] < colWend; ++coords[5]) { - coords[3] = imW - coords[5] * sW; - if(coords[3] % dW != 0) continue; + for (coords[5] = colWstart; coords[5] < colWend; ++coords[5]) { + coords[3] = imW - coords[5] * sW; + if (coords[3] % dW != 0) continue; - val += col[bSiCoffset + (coords[2]/dH)*colShapeInfo[9] + (coords[3]/dW)*colShapeInfo[10] + coords[4]*colShapeInfo[11] + coords[5]*colShapeInfo[12]]; - } - } - im[imOffset] = val; + val += col[bSiCoffset + (coords[2] / dH) * colShapeInfo[9] + + (coords[3] / dW) * colShapeInfo[10] + + coords[4] * colShapeInfo[11] + coords[5] * colShapeInfo[12]]; + } } + im[imOffset] = val; + } } //////////////////////////////////////////////////////////////////////// -// columns [bS, iC, kH, kW, oH, oW] to be de-convoluted to image [bS, iC, iH, iW] -template -__global__ static void col2imCuda2(const void *columns, void *image, const Nd4jLong *colShapeInfo, const Nd4jLong *imShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { - - const auto col = reinterpret_cast(columns); - auto im = reinterpret_cast(image); - - auto colShape = shape::shapeOf(const_cast(colShapeInfo)); - auto colStride = shape::stride(const_cast(colShapeInfo)); - - int colStride0 = colStride[0]; - int colStride1 = colStride[1]; - int colStride2 = colStride[2]; - int colStride3 = colStride[3]; - int colStride4 = colStride[4]; - int colStride5 = colStride[5]; - - int kH = colShape[2]; - int kW = colShape[3]; - - auto imShape = shape::shapeOf(const_cast(imShapeInfo)); - auto imOrder = shape::order(const_cast(imShapeInfo)); - auto imStride = shape::stride(const_cast(imShapeInfo)); - - int bS = imShape[0]; - int iC = imShape[1]; - int iH = imShape[2]; - int iW = imShape[3]; - - int oH = colShape[4];//(iH + 2 * pH - kH) / sW + 1; - int oW = colShape[5];//(iW + 2 * pW - kW) / sH + 1; - - int n = bS * iC * iH * iW; - - //Effective kernel size, accounting for dilation - int kHeff = kH + (kH - 1) * (dH - 1); - int kWeff = kW + (kW - 1) * (dW - 1); - - for (int i = (blockDim.x * blockIdx.x) + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { - T val = 0; - - int w_im = i % iW + pW; - int h_im = (i / iW) % iH + pH; - int c_im = i / (iW * iH); - int b = c_im / iC; - int c = c_im % iC; - - // compute the start and end of the output - // These are the indexes for dimensions ??? in the 6d col matrix - int w_col_start = (w_im < kWeff) ? 0 : (w_im - kWeff) / sW + 1; - int w_col_end = sd::math::nd4j_min(w_im / sW + 1, oW); - - int h_col_start = (h_im < kHeff) ? 0 : (h_im - kHeff) / sH + 1; - int h_col_end = sd::math::nd4j_min(h_im / sH + 1, oH); - - //Iterate over col entries in the 6d array... these are added up - for (int colH = h_col_start; colH < h_col_end; colH += 1) { - for (int colW = w_col_start; colW < w_col_end; colW += 1) { - int kRow = (h_im - colH * sH); - int kCol = (w_im - colW * sW); - - if(kRow % dH == 0 && kCol % dW == 0){ - kRow /= dH; - kCol /= dW; - - int data_col_index = b * colStride0 + c * colStride1 + kRow * colStride2 + kCol * colStride3 + colH * colStride4 + colW * colStride5; - val += col[data_col_index]; - } - } - } - - int i_f = 0; - int i_c = i; - for (int dim = 3; dim >= 0; dim--) { - i_f += (i_c % imShape[dim]) * imStride[dim]; - i_c = i_c / imShape[dim]; - } - - im[i_f] = val; - } +// columns [bS, iC, kH, kW, oH, oW] to be de-convoluted to image [bS, iC, iH, +// iW] +template +__global__ static void col2imCuda2(const void* columns, void* image, + const Nd4jLong* colShapeInfo, + const Nd4jLong* imShapeInfo, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW) { + const auto col = reinterpret_cast(columns); + auto im = reinterpret_cast(image); + + auto colShape = shape::shapeOf(const_cast(colShapeInfo)); + auto colStride = shape::stride(const_cast(colShapeInfo)); + + int colStride0 = colStride[0]; + int colStride1 = colStride[1]; + int colStride2 = colStride[2]; + int colStride3 = colStride[3]; + int colStride4 = colStride[4]; + int colStride5 = colStride[5]; + + int kH = colShape[2]; + int kW = colShape[3]; + + auto imShape = shape::shapeOf(const_cast(imShapeInfo)); + auto imOrder = shape::order(const_cast(imShapeInfo)); + auto imStride = shape::stride(const_cast(imShapeInfo)); + + int bS = imShape[0]; + int iC = imShape[1]; + int iH = imShape[2]; + int iW = imShape[3]; + + int oH = colShape[4]; //(iH + 2 * pH - kH) / sW + 1; + int oW = colShape[5]; //(iW + 2 * pW - kW) / sH + 1; + + int n = bS * iC * iH * iW; + + // Effective kernel size, accounting for dilation + int kHeff = kH + (kH - 1) * (dH - 1); + int kWeff = kW + (kW - 1) * (dW - 1); + + for (int i = (blockDim.x * blockIdx.x) + threadIdx.x; i < n; + i += blockDim.x * gridDim.x) { + T val = 0; + + int w_im = i % iW + pW; + int h_im = (i / iW) % iH + pH; + int c_im = i / (iW * iH); + int b = c_im / iC; + int c = c_im % iC; + + // compute the start and end of the output + // These are the indexes for dimensions ??? in the 6d col matrix + int w_col_start = (w_im < kWeff) ? 0 : (w_im - kWeff) / sW + 1; + int w_col_end = sd::math::nd4j_min(w_im / sW + 1, oW); + + int h_col_start = (h_im < kHeff) ? 0 : (h_im - kHeff) / sH + 1; + int h_col_end = sd::math::nd4j_min(h_im / sH + 1, oH); + + // Iterate over col entries in the 6d array... these are added up + for (int colH = h_col_start; colH < h_col_end; colH += 1) { + for (int colW = w_col_start; colW < w_col_end; colW += 1) { + int kRow = (h_im - colH * sH); + int kCol = (w_im - colW * sW); + + if (kRow % dH == 0 && kCol % dW == 0) { + kRow /= dH; + kCol /= dW; + + int data_col_index = b * colStride0 + c * colStride1 + + kRow * colStride2 + kCol * colStride3 + + colH * colStride4 + colW * colStride5; + val += col[data_col_index]; + } + } + } + + int i_f = 0; + int i_c = i; + for (int dim = 3; dim >= 0; dim--) { + i_f += (i_c % imShape[dim]) * imStride[dim]; + i_c = i_c / imShape[dim]; + } + + im[i_f] = val; + } } ////////////////////////////////////////////////////////////////////////// template -static void col2imCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* columns, const Nd4jLong* colShapeInfo, - void* image, const Nd4jLong* imShapeInfo, - const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { - - // col2imCuda2<<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW); - col2imCuda<<>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW); +static void col2imCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* columns, + const Nd4jLong* colShapeInfo, void* image, + const Nd4jLong* imShapeInfo, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW) { + // col2imCuda2<<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, + // imShapeInfo, sH, sW, pH, pW, dH, dW); + col2imCuda<<>>( + columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW); } ////////////////////////////////////////////////////////////////////////// -void col2im(sd::LaunchContext& context, const NDArray& col, NDArray& im, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW) { - - PointersManager manager(&context, "col2im"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (im.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256; - - NDArray::prepareSpecialUse({&im}, {&col}); - BUILD_SINGLE_SELECTOR(im.dataType(), col2imCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context.getCudaStream(), col.specialBuffer(), col.specialShapeInfo(), im.specialBuffer(), im.specialShapeInfo(), sH, sW, pH, pW, dH, dW), FLOAT_TYPES); - NDArray::registerSpecialUse({&im}, {&col}); - - manager.synchronize(); +void col2im(sd::LaunchContext& context, const NDArray& col, NDArray& im, + const int sH, const int sW, const int pH, const int pW, + const int iH, const int iW, const int dH, const int dW) { + PointersManager manager(&context, "col2im"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (im.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256; + + NDArray::prepareSpecialUse({&im}, {&col}); + BUILD_SINGLE_SELECTOR( + im.dataType(), col2imCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context.getCudaStream(), + col.specialBuffer(), col.specialShapeInfo(), im.specialBuffer(), + im.specialShapeInfo(), sH, sW, pH, pW, dH, dW), + FLOAT_TYPES); + NDArray::registerSpecialUse({&im}, {&col}); + + manager.synchronize(); } - - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu b/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu index 8d0bede625e5..dc757716e6d2 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu @@ -17,119 +17,131 @@ #include namespace sd { - namespace ops { - namespace helpers { - - template - static _CUDA_G void comparator(void *vx, const Nd4jLong *xShapeInfo, Nd4jLong length, const bool isStrict, void *reductionBuffer, bool *z) { - auto x = reinterpret_cast(vx); - auto reduction = reinterpret_cast(reductionBuffer); - - extern __shared__ uint32_t shared[]; - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - shared[threadIdx.x] = 0; - - // each thread will compare 2 elements: E and E+1 - for (int e = tid; e < length - 1; e += blockDim.x * gridDim.x) { - auto val0 = x[shape::getIndexOffset(e, xShapeInfo)]; - auto val1 = x[shape::getIndexOffset(e+1, xShapeInfo)]; - - bool v = false; - if (isStrict) - v = val1 > val0; - else - v = val1 >= val0; - - // store comparison result in shared memory - shared[threadIdx.x] += v ? 0 : 1; - } - __syncthreads(); - - // aggregate sums in shared memory - for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - if (threadIdx.x < activeThreads) - shared[threadIdx.x] += shared[threadIdx.x + activeThreads]; - __syncthreads(); - } - - - // store over the grid if we have more than 1 block - if (gridDim.x > 1) { - - auto tc = reinterpret_cast(reductionBuffer); - __shared__ bool amLast; - - tid = threadIdx.x; - if (threadIdx.x == 0) - reduction[blockIdx.x] = shared[0]; - - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - - __syncthreads(); - - if (amLast) { - tc[16384] = 0; - shared[threadIdx.x] = 0; +namespace ops { +namespace helpers { + +template +static _CUDA_G void comparator(void *vx, const Nd4jLong *xShapeInfo, + Nd4jLong length, const bool isStrict, + void *reductionBuffer, bool *z) { + auto x = reinterpret_cast(vx); + auto reduction = reinterpret_cast(reductionBuffer); + + extern __shared__ uint32_t shared[]; + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + + shared[threadIdx.x] = 0; + + // each thread will compare 2 elements: E and E+1 + for (int e = tid; e < length - 1; e += blockDim.x * gridDim.x) { + auto val0 = x[shape::getIndexOffset(e, xShapeInfo)]; + auto val1 = x[shape::getIndexOffset(e + 1, xShapeInfo)]; + + bool v = false; + if (isStrict) + v = val1 > val0; + else + v = val1 >= val0; + + // store comparison result in shared memory + shared[threadIdx.x] += v ? 0 : 1; + } + __syncthreads(); + + // aggregate sums in shared memory + for (uint activeThreads = blockDim.x / 2; activeThreads > 0; + activeThreads /= 2) { + if (threadIdx.x < activeThreads) + shared[threadIdx.x] += shared[threadIdx.x + activeThreads]; + __syncthreads(); + } + + // store over the grid if we have more than 1 block + if (gridDim.x > 1) { + auto tc = reinterpret_cast(reductionBuffer); + __shared__ bool amLast; + + tid = threadIdx.x; + if (threadIdx.x == 0) reduction[blockIdx.x] = shared[0]; + + __threadfence(); + __syncthreads(); + + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) - shared[threadIdx.x] += reduction[i]; + __syncthreads(); - __syncthreads(); + if (amLast) { + tc[16384] = 0; + shared[threadIdx.x] = 0; - for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - if (threadIdx.x < activeThreads) - shared[threadIdx.x] += shared[threadIdx.x + activeThreads]; - __syncthreads(); - } + for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) + shared[threadIdx.x] += reduction[i]; - __syncthreads(); + __syncthreads(); - if (threadIdx.x == 0) { - z[0] = shared[0] == 0; - } - } - } - else { - // if we have only 1 block, we just store results right away - if (threadIdx.x == 0) { - auto tc = reinterpret_cast(reductionBuffer); - tc[16384] = 0; - z[0] = shared[0] == 0; - } - } - } + for (uint activeThreads = blockDim.x / 2; activeThreads > 0; + activeThreads /= 2) { + if (threadIdx.x < activeThreads) + shared[threadIdx.x] += shared[threadIdx.x + activeThreads]; + __syncthreads(); + } - template - static void _compare_elem(sd::LaunchContext * context, NDArray *input, bool isStrictlyIncreasing, bool& output) { - auto z = NDArrayFactory::create(false, context); + __syncthreads(); - const int numThreads = 256; - const int numBlocks = sd::math::nd4j_min(128, sd::math::nd4j_max(1, input->lengthOf() / numThreads)); + if (threadIdx.x == 0) { + z[0] = shared[0] == 0; + } + } + } else { + // if we have only 1 block, we just store results right away + if (threadIdx.x == 0) { + auto tc = reinterpret_cast(reductionBuffer); + tc[16384] = 0; + z[0] = shared[0] == 0; + } + } +} - comparator<<getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), input->lengthOf(), isStrictlyIncreasing, context->getReductionPointer(), reinterpret_cast(z.specialBuffer())); +template +static void _compare_elem(sd::LaunchContext *context, NDArray *input, + bool isStrictlyIncreasing, bool &output) { + auto z = NDArrayFactory::create(false, context); - z.tickWriteDevice(); - sd::DebugHelper::checkErrorCode(context->getCudaStream(), "is_strictly_increasing"); + const int numThreads = 256; + const int numBlocks = sd::math::nd4j_min( + 128, sd::math::nd4j_max(1, input->lengthOf() / numThreads)); - output = z.e(0); - } + comparator<<getCudaStream()>>>( + input->specialBuffer(), input->specialShapeInfo(), input->lengthOf(), + isStrictlyIncreasing, context->getReductionPointer(), + reinterpret_cast(z.specialBuffer())); - void compare_elem(sd::LaunchContext * context, NDArray *input, bool isStrictlyIncreasing, bool& output) { - auto xType = input->dataType(); - input->syncToDevice(); + z.tickWriteDevice(); + sd::DebugHelper::checkErrorCode(context->getCudaStream(), + "is_strictly_increasing"); - BUILD_SINGLE_SELECTOR(xType, _compare_elem, (context, input, isStrictlyIncreasing, output), LIBND4J_TYPES); - } + output = z.e(0); +} +void compare_elem(sd::LaunchContext *context, NDArray *input, + bool isStrictlyIncreasing, bool &output) { + auto xType = input->dataType(); + input->syncToDevice(); - BUILD_SINGLE_TEMPLATE(template void _compare_elem, (sd::LaunchContext * context, NDArray *A, bool isStrictlyIncreasing, bool& output);, LIBND4J_TYPES); - } - } + BUILD_SINGLE_SELECTOR(xType, _compare_elem, + (context, input, isStrictlyIncreasing, output), + LIBND4J_TYPES); } + +BUILD_SINGLE_TEMPLATE(template void _compare_elem, + (sd::LaunchContext * context, NDArray *A, + bool isStrictlyIncreasing, bool &output); + , LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/compression/compression.cu b/libnd4j/include/ops/declarable/helpers/cuda/compression/compression.cu index 5de20c57f852..8503f0a9d4f2 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/compression/compression.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/compression/compression.cu @@ -17,50 +17,60 @@ // // @author sgazeos@gmail.com // -#include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { - void decodeBitmap(sd::LaunchContext* context, const NDArray* input, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input}); +void decodeBitmap(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input}); - dim3 launchDims(512, 512, 16384); - auto xType = output->dataType(); - BUILD_SINGLE_SELECTOR(xType, cudaDecodeBitmapGeneric, (launchDims, stream, input->specialBuffer(), output->lengthOf(), output->specialBuffer()), FLOAT_TYPES); + dim3 launchDims(512, 512, 16384); + auto xType = output->dataType(); + BUILD_SINGLE_SELECTOR(xType, cudaDecodeBitmapGeneric, + (launchDims, stream, input->specialBuffer(), + output->lengthOf(), output->specialBuffer()), + FLOAT_TYPES); - sd::DebugHelper::checkErrorCode(stream, "decodeBitmapFloat(...) failed"); + sd::DebugHelper::checkErrorCode(stream, "decodeBitmapFloat(...) failed"); - NDArray::registerSpecialUse({output}, {input}); - } + NDArray::registerSpecialUse({output}, {input}); +} - Nd4jLong encodeBitmap(sd::LaunchContext* context, NDArray* input, NDArray* output, float threshold) { - auto stream = LaunchContext::defaultContext()->getCudaStream(); - int *resultPointer = reinterpret_cast(LaunchContext::defaultContext()->getScalarPointer()); - int *reductionPointer = reinterpret_cast(LaunchContext::defaultContext()->getReductionPointer()); +Nd4jLong encodeBitmap(sd::LaunchContext* context, NDArray* input, + NDArray* output, float threshold) { + auto stream = LaunchContext::defaultContext()->getCudaStream(); + int* resultPointer = reinterpret_cast( + LaunchContext::defaultContext()->getScalarPointer()); + int* reductionPointer = reinterpret_cast( + LaunchContext::defaultContext()->getReductionPointer()); - // nullify result pointer before use - resultPointer[0] = 0; + // nullify result pointer before use + resultPointer[0] = 0; - NDArray::prepareSpecialUse({},{output, input}); + NDArray::prepareSpecialUse({}, {output, input}); - dim3 launchDims(512, 512, 32768); - auto xType = input->dataType(); - BUILD_SINGLE_SELECTOR(xType, cudaEncodeBitmapGeneric, - (launchDims, stream, input->specialBuffer(), input->lengthOf(), reinterpret_cast(output->specialBuffer()), resultPointer, reductionPointer, threshold), - FLOAT_TYPES); + dim3 launchDims(512, 512, 32768); + auto xType = input->dataType(); + BUILD_SINGLE_SELECTOR( + xType, cudaEncodeBitmapGeneric, + (launchDims, stream, input->specialBuffer(), input->lengthOf(), + reinterpret_cast(output->specialBuffer()), resultPointer, + reductionPointer, threshold), + FLOAT_TYPES); - sd::DebugHelper::checkErrorCode(stream, "encodeBitmapFloat(...) failed"); + sd::DebugHelper::checkErrorCode(stream, "encodeBitmapFloat(...) failed"); - Nd4jLong dZ = (Nd4jLong) resultPointer[0]; - resultPointer[0] = 0; + Nd4jLong dZ = (Nd4jLong)resultPointer[0]; + resultPointer[0] = 0; - NDArray::registerSpecialUse({output, input}, {}); - return dZ; - } -} -} + NDArray::registerSpecialUse({output, input}, {}); + return dZ; } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/compression/threshold.cu b/libnd4j/include/ops/declarable/helpers/cuda/compression/threshold.cu index 6b5af0df4d61..5272f32ec050 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/compression/threshold.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/compression/threshold.cu @@ -18,214 +18,241 @@ // @author raver119@gmail.com // -#include -#include #include +#include +#include + #include namespace sd { - namespace ops { - namespace helpers { - void prescanArrayRecursive(int** g_scanBlockSums, int *dZ, int *dX, int numElements, int level) { - auto stream = LaunchContext::defaultContext()->getCudaStream(); - - - int blockSize = 512; // max size of the thread blocks - int numBlocks = sd::math::nd4j_max(1, static_cast(ceil(static_cast(numElements) / (2.f * blockSize)))); - int numThreads; - - if (numBlocks > 1) - numThreads = blockSize; - else if (sd::isPowerOfTwo(numElements)) - numThreads = numElements / 2; - else - numThreads = sd::floorPow2(numElements); - - int numEltsPerBlock = numThreads * 2; - - // if this is a non-power-of-2 array, the last block will be non-full - // compute the smallest power of 2 able to compute its scan. - int numEltsLastBlock = - numElements - (numBlocks-1) * numEltsPerBlock; - int numThreadsLastBlock = sd::math::nd4j_max(1, numEltsLastBlock / 2); - int np2LastBlock = 0; - int sharedMemLastBlock = 0; - - if (numEltsLastBlock != numEltsPerBlock) { - np2LastBlock = 1; - - if(!isPowerOfTwo(numEltsLastBlock)) - numThreadsLastBlock = floorPow2(numEltsLastBlock); - - unsigned int extraSpace = (2 * numThreadsLastBlock) / NUM_BANKS; - sharedMemLastBlock = sizeof(int) * (2 * numThreadsLastBlock + extraSpace); - } - - // padding space is used to avoid shared memory bank conflicts - int extraSpace = numEltsPerBlock / NUM_BANKS; - int sharedMemSize = sizeof(int) * (numEltsPerBlock + extraSpace); - - // setup execution parameters - // if NP2, we process the last block separately - dim3 grid(sd::math::nd4j_max(1, numBlocks - np2LastBlock), 1, 1); - dim3 threads(numThreads, 1, 1); - dim3 gridOnes(1, 1, 1); - dim3 threadsOnes(numThreadsLastBlock, 1, 1); - - if (sharedMemSize < 2048) - sharedMemSize = 2048; - - if (sharedMemLastBlock < 2048) - sharedMemLastBlock = 2048; - - // execute the scan - if (numBlocks > 1) { - sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, dX, g_scanBlockSums[level], numThreads * 2, 0, 0); - if (np2LastBlock) { - sd::prescanLauncher(gridOnes, threadsOnes, sharedMemLastBlock, stream, dZ, dX, g_scanBlockSums[level], numEltsLastBlock, numBlocks - 1, numElements - numEltsLastBlock); - } - - // After scanning all the sub-blocks, we are mostly done. But now we - // need to take all of the last values of the sub-blocks and scan those. - // This will give us a new value that must be sdded to each block to - // get the final results. - // recursive (CPU) call - prescanArrayRecursive(g_scanBlockSums, g_scanBlockSums[level], g_scanBlockSums[level], numBlocks, level+1); - - sd::uniformAdd<<>>(dZ, g_scanBlockSums[level], numElements - numEltsLastBlock, 0, 0); - - if (np2LastBlock) { - sd::uniformAdd<<<1, numThreadsLastBlock, 1024, *stream>>>(dZ, g_scanBlockSums[level], numEltsLastBlock, numBlocks - 1, numElements - numEltsLastBlock); - } - } else if (isPowerOfTwo(numElements)) { - sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, dX, 0, numThreads * 2, 0, 0); - } else { - sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, dX, 0, numElements, 0, 0); - } - - sd::DebugHelper::checkErrorCode(stream, "prescanArray(...) failed"); - } - - static void encodeThresholdP2Int_(void **prs, int *dx, Nd4jLong N, int *dz) { - auto stream = LaunchContext::defaultContext()->getCudaStream(); - - prescanArrayRecursive(reinterpret_cast(prs), dz, dx + 1, (int) N, 0); - sd::DebugHelper::checkErrorCode(stream, "encodeThresholdP2Int(...) failed"); - } - - static void encodeThresholdP3_(void *dx, const Nd4jLong *hXShapeInfo, int *offsets, Nd4jLong N, int *dz){ - auto stream = LaunchContext::defaultContext()->getCudaStream(); - - int blockSize = 512; - int numBlocks = N / blockSize + (N % blockSize ? 1 : 0); - - dim3 launchDims(numBlocks, blockSize, 8192); - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, encoderKernelP3Generic, (launchDims, stream, dx, offsets, N, dz), FLOAT_TYPES); - - sd::DebugHelper::checkErrorCode(stream, "encodeThresholdP3Float(...) failed"); - } - - - static NDArray thresholdEstimate_(const NDArray &updates, const float threshold) { - const int numThreads = 512; - const int numBlocks = updates.lengthOf() / numThreads + (updates.lengthOf() % numThreads ? 1 : 0); - - auto tmp = NDArrayFactory::create('c', {numBlocks + 1}); +namespace ops { +namespace helpers { +void prescanArrayRecursive(int **g_scanBlockSums, int *dZ, int *dX, + int numElements, int level) { + auto stream = LaunchContext::defaultContext()->getCudaStream(); + + int blockSize = 512; // max size of the thread blocks + int numBlocks = sd::math::nd4j_max( + 1, static_cast( + ceil(static_cast(numElements) / (2.f * blockSize)))); + int numThreads; + + if (numBlocks > 1) + numThreads = blockSize; + else if (sd::isPowerOfTwo(numElements)) + numThreads = numElements / 2; + else + numThreads = sd::floorPow2(numElements); + + int numEltsPerBlock = numThreads * 2; + + // if this is a non-power-of-2 array, the last block will be non-full + // compute the smallest power of 2 able to compute its scan. + int numEltsLastBlock = numElements - (numBlocks - 1) * numEltsPerBlock; + int numThreadsLastBlock = sd::math::nd4j_max(1, numEltsLastBlock / 2); + int np2LastBlock = 0; + int sharedMemLastBlock = 0; + + if (numEltsLastBlock != numEltsPerBlock) { + np2LastBlock = 1; + + if (!isPowerOfTwo(numEltsLastBlock)) + numThreadsLastBlock = floorPow2(numEltsLastBlock); + + unsigned int extraSpace = (2 * numThreadsLastBlock) / NUM_BANKS; + sharedMemLastBlock = sizeof(int) * (2 * numThreadsLastBlock + extraSpace); + } + + // padding space is used to avoid shared memory bank conflicts + int extraSpace = numEltsPerBlock / NUM_BANKS; + int sharedMemSize = sizeof(int) * (numEltsPerBlock + extraSpace); + + // setup execution parameters + // if NP2, we process the last block separately + dim3 grid(sd::math::nd4j_max(1, numBlocks - np2LastBlock), 1, 1); + dim3 threads(numThreads, 1, 1); + dim3 gridOnes(1, 1, 1); + dim3 threadsOnes(numThreadsLastBlock, 1, 1); + + if (sharedMemSize < 2048) sharedMemSize = 2048; + + if (sharedMemLastBlock < 2048) sharedMemLastBlock = 2048; + + // execute the scan + if (numBlocks > 1) { + sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, + dX, g_scanBlockSums[level], numThreads * 2, + 0, 0); + if (np2LastBlock) { + sd::prescanLauncher(gridOnes, threadsOnes, sharedMemLastBlock, + stream, dZ, dX, g_scanBlockSums[level], + numEltsLastBlock, numBlocks - 1, + numElements - numEltsLastBlock); + } - dim3 launchDims(numBlocks, numThreads, 1024); - auto xType = updates.dataType(); - - NDArray::prepareSpecialUse({&tmp}, {&updates}); - BUILD_SINGLE_SELECTOR(xType, encoderKernelP1Generic, (launchDims, LaunchContext::defaultContext()->getCudaStream(), updates.specialBuffer(), updates.lengthOf(), tmp.specialBuffer(), threshold), FLOAT_TYPES); - NDArray::registerSpecialUse({&tmp}, {&updates}); + // After scanning all the sub-blocks, we are mostly done. But now we + // need to take all of the last values of the sub-blocks and scan those. + // This will give us a new value that must be sdded to each block to + // get the final results. + // recursive (CPU) call + prescanArrayRecursive(g_scanBlockSums, g_scanBlockSums[level], + g_scanBlockSums[level], numBlocks, level + 1); + + sd::uniformAdd<<>>( + dZ, g_scanBlockSums[level], numElements - numEltsLastBlock, 0, 0); + + if (np2LastBlock) { + sd::uniformAdd<<<1, numThreadsLastBlock, 1024, *stream>>>( + dZ, g_scanBlockSums[level], numEltsLastBlock, numBlocks - 1, + numElements - numEltsLastBlock); + } + } else if (isPowerOfTwo(numElements)) { + sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, + dX, 0, numThreads * 2, 0, 0); + } else { + sd::prescanLauncher(grid, threads, sharedMemSize, stream, dZ, + dX, 0, numElements, 0, 0); + } + + sd::DebugHelper::checkErrorCode(stream, "prescanArray(...) failed"); +} - return std::move(tmp); - } +static void encodeThresholdP2Int_(void **prs, int *dx, Nd4jLong N, int *dz) { + auto stream = LaunchContext::defaultContext()->getCudaStream(); - int32_t thresholdEstimate(const NDArray &updates, const float threshold) { - return thresholdEstimate_(updates, threshold).e(0); - } + prescanArrayRecursive(reinterpret_cast(prs), dz, dx + 1, (int)N, 0); + sd::DebugHelper::checkErrorCode(stream, "encodeThresholdP2Int(...) failed"); +} - void thresholdEncode(NDArray &updates, NDArray &encoded, float threshold) { - // we need these blocks in order to know, how many "updates" will be processed by each GPU block - auto blocks = thresholdEstimate_(updates, threshold); +static void encodeThresholdP3_(void *dx, const Nd4jLong *hXShapeInfo, + int *offsets, Nd4jLong N, int *dz) { + auto stream = LaunchContext::defaultContext()->getCudaStream(); - const int numThreads = 512; - const int numBlocks = updates.lengthOf() / numThreads + (updates.lengthOf() % numThreads ? 1 : 0); + int blockSize = 512; + int numBlocks = N / blockSize + (N % blockSize ? 1 : 0); - const int prefixThreads = 512; - int numElts = numBlocks; - int level = 0; + dim3 launchDims(numBlocks, blockSize, 8192); + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + BUILD_SINGLE_SELECTOR(xType, encoderKernelP3Generic, + (launchDims, stream, dx, offsets, N, dz), FLOAT_TYPES); - // here we just calculate number of sumBlock arrays - do { - int numPrefixBlocks = sd::math::nd4j_max(1, sd::math::nd4j_ceil((float) numElts / (2.0f * prefixThreads))); - if (numBlocks > 1) { - level++; - } - numElts = numPrefixBlocks; - } while (numElts > 1); + sd::DebugHelper::checkErrorCode(stream, "encodeThresholdP3Float(...) failed"); +} +static NDArray thresholdEstimate_(const NDArray &updates, + const float threshold) { + const int numThreads = 512; + const int numBlocks = updates.lengthOf() / numThreads + + (updates.lengthOf() % numThreads ? 1 : 0); + auto tmp = NDArrayFactory::create('c', {numBlocks + 1}); - std::vector tempArrays(level); - std::vector pointers(level); + dim3 launchDims(numBlocks, numThreads, 1024); + auto xType = updates.dataType(); - level = 0; - numElts = numBlocks; + NDArray::prepareSpecialUse({&tmp}, {&updates}); + BUILD_SINGLE_SELECTOR( + xType, encoderKernelP1Generic, + (launchDims, LaunchContext::defaultContext()->getCudaStream(), + updates.specialBuffer(), updates.lengthOf(), tmp.specialBuffer(), + threshold), + FLOAT_TYPES); + NDArray::registerSpecialUse({&tmp}, {&updates}); - do { - int numPrefixBlocks = sd::math::nd4j_max(1, sd::math::nd4j_ceil((float) numElts / (2.0f * prefixThreads))); - if (numPrefixBlocks > 1) { - tempArrays[level] = std::move(NDArrayFactory::create('c', {numPrefixBlocks})); - pointers[level] = tempArrays[level++].specialBuffer(); - } - numElts = numPrefixBlocks; - } while (numElts > 1); + return std::move(tmp); +} - PointersManager pm(LaunchContext::defaultContext(), "thresholdEncode"); - auto dptr = pm.replicatePointer(pointers.data(), pointers.size() * 8); - auto offsets = NDArrayFactory::create('c', {numBlocks}); +int32_t thresholdEstimate(const NDArray &updates, const float threshold) { + return thresholdEstimate_(updates, threshold).e(0); +} - // we want to check, if we're hiting external limit on number of encoded elements - auto numMatches = blocks.e(0); - if (numMatches > encoded.lengthOf() - 4) { - blocks.p(0, encoded.lengthOf() - 4); - blocks.syncToDevice(); - } +void thresholdEncode(NDArray &updates, NDArray &encoded, float threshold) { + // we need these blocks in order to know, how many "updates" will be processed + // by each GPU block + auto blocks = thresholdEstimate_(updates, threshold); + + const int numThreads = 512; + const int numBlocks = updates.lengthOf() / numThreads + + (updates.lengthOf() % numThreads ? 1 : 0); + + const int prefixThreads = 512; + int numElts = numBlocks; + int level = 0; + + // here we just calculate number of sumBlock arrays + do { + int numPrefixBlocks = sd::math::nd4j_max( + 1, sd::math::nd4j_ceil((float)numElts / + (2.0f * prefixThreads))); + if (numBlocks > 1) { + level++; + } + numElts = numPrefixBlocks; + } while (numElts > 1); + + std::vector tempArrays(level); + std::vector pointers(level); + + level = 0; + numElts = numBlocks; + + do { + int numPrefixBlocks = sd::math::nd4j_max( + 1, sd::math::nd4j_ceil((float)numElts / + (2.0f * prefixThreads))); + if (numPrefixBlocks > 1) { + tempArrays[level] = + std::move(NDArrayFactory::create('c', {numPrefixBlocks})); + pointers[level] = tempArrays[level++].specialBuffer(); + } + numElts = numPrefixBlocks; + } while (numElts > 1); - NDArray::prepareSpecialUse({}, {&encoded, &updates}); + PointersManager pm(LaunchContext::defaultContext(), "thresholdEncode"); + auto dptr = pm.replicatePointer(pointers.data(), pointers.size() * 8); + auto offsets = NDArrayFactory::create('c', {numBlocks}); - // filling offsets - encodeThresholdP2Int_(reinterpret_cast(dptr), - reinterpret_cast(blocks.specialBuffer()), - numBlocks, - reinterpret_cast(offsets.specialBuffer())); + // we want to check, if we're hiting external limit on number of encoded + // elements + auto numMatches = blocks.e(0); + if (numMatches > encoded.lengthOf() - 4) { + blocks.p(0, encoded.lengthOf() - 4); + blocks.syncToDevice(); + } - NDArray::registerSpecialUse({&blocks, &offsets}, {}); - pm.synchronize(); + NDArray::prepareSpecialUse({}, {&encoded, &updates}); + // filling offsets + encodeThresholdP2Int_(reinterpret_cast(dptr), + reinterpret_cast(blocks.specialBuffer()), + numBlocks, + reinterpret_cast(offsets.specialBuffer())); - encodeThresholdP3_(updates.specialBuffer(), - updates.shapeInfo(), - reinterpret_cast(offsets.specialBuffer()), - updates.lengthOf(), - reinterpret_cast(encoded.specialBuffer())); + NDArray::registerSpecialUse({&blocks, &offsets}, {}); + pm.synchronize(); - pm.synchronize(); + encodeThresholdP3_(updates.specialBuffer(), updates.shapeInfo(), + reinterpret_cast(offsets.specialBuffer()), + updates.lengthOf(), + reinterpret_cast(encoded.specialBuffer())); - NDArray::registerSpecialUse({&encoded, &updates}, {}); - } + pm.synchronize(); - void thresholdDecode(const NDArray &encoded, NDArray &updates) { - dim3 launchDims(128, 512, 512); - auto xType = updates.dataType(); + NDArray::registerSpecialUse({&encoded, &updates}, {}); +} - NDArray::prepareSpecialUse({&updates}, {&encoded}); - BUILD_SINGLE_SELECTOR(xType, decoderKernelGeneric, (launchDims, LaunchContext::defaultContext()->getCudaStream(), encoded.specialBuffer(), updates.lengthOf(), updates.specialBuffer()), FLOAT_TYPES); - NDArray::registerSpecialUse({&updates}, {&encoded}); - } - } - } +void thresholdDecode(const NDArray &encoded, NDArray &updates) { + dim3 launchDims(128, 512, 512); + auto xType = updates.dataType(); + + NDArray::prepareSpecialUse({&updates}, {&encoded}); + BUILD_SINGLE_SELECTOR( + xType, decoderKernelGeneric, + (launchDims, LaunchContext::defaultContext()->getCudaStream(), + encoded.specialBuffer(), updates.lengthOf(), updates.specialBuffer()), + FLOAT_TYPES); + NDArray::registerSpecialUse({&updates}, {&encoded}); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu index cbcd35ffea9b..1916f2957c87 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu @@ -18,176 +18,196 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include -namespace sd { -namespace ops { -namespace helpers { +#include +namespace sd { +namespace ops { +namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) { +template +__global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int axis) { + T* z = reinterpret_cast(vz); + __shared__ Nd4jLong zLen, totalThreads; + __shared__ int rank; - T* z = reinterpret_cast(vz); - __shared__ Nd4jLong zLen, totalThreads; - __shared__ int rank; + if (threadIdx.x == 0) { + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); - if (threadIdx.x == 0) { - zLen = shape::length(zShapeInfo); - rank = shape::rank(zShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - int coords[MAX_RANK]; + int coords[MAX_RANK]; - for (uint64_t i = tid; i < zLen; i += totalThreads) { - shape::index2coords(i, zShapeInfo, coords); + for (uint64_t i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - int inArrIdx = 0; - Nd4jLong *xShapeInfo = reinterpret_cast(pxShapeInfo)[inArrIdx]; + int inArrIdx = 0; + Nd4jLong* xShapeInfo = reinterpret_cast(pxShapeInfo)[inArrIdx]; - while (coords[axis] >= xShapeInfo[axis + 1]) { - coords[axis] -= xShapeInfo[axis + 1]; - xShapeInfo = reinterpret_cast(pxShapeInfo)[++inArrIdx]; - } + while (coords[axis] >= xShapeInfo[axis + 1]) { + coords[axis] -= xShapeInfo[axis + 1]; + xShapeInfo = reinterpret_cast(pxShapeInfo)[++inArrIdx]; + } - const auto *x = reinterpret_cast(reinterpret_cast(pVx)[inArrIdx]); - const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto* x = + reinterpret_cast(reinterpret_cast(pVx)[inArrIdx]); + const auto xOffset = shape::getOffset(xShapeInfo, coords); - z[zOffset] = x[xOffset]; - } + z[zOffset] = x[xOffset]; + } } /////////////////////////////////////////////////////////////////// -template -__host__ static void concatCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - void* pVx, void* pxShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) { - - concatCuda<<>>(pVx, pxShapeInfo, vz, zShapeInfo, axis); +template +__host__ static void concatCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, void* pVx, void* pxShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int axis) { + concatCuda<<>>( + pVx, pxShapeInfo, vz, zShapeInfo, axis); } ////////////////////////////////////////////////////////////////////////// -void concat(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int axis) { - - const int numOfInArrs = inArrs.size(); - const auto sizeofT = output.sizeOfT(); - - NDArray::prepareSpecialUse({&output}, inArrs); - - bool luckCase1 = ((axis == 0 && output.ordering() == 'c') || (axis == output.rankOf() - 1 && output.ordering() == 'f')) && output.ews() == 1; - - if(luckCase1) { - for (uint i = 0; i < numOfInArrs; ++i) { - luckCase1 &= inArrs[i]->ordering() == output.ordering() && inArrs[i]->ews() == 1; - if(!luckCase1) - break; - } +void concat(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output, + const int axis) { + const int numOfInArrs = inArrs.size(); + const auto sizeofT = output.sizeOfT(); + + NDArray::prepareSpecialUse({&output}, inArrs); + + bool luckCase1 = + ((axis == 0 && output.ordering() == 'c') || + (axis == output.rankOf() - 1 && output.ordering() == 'f')) && + output.ews() == 1; + + if (luckCase1) { + for (uint i = 0; i < numOfInArrs; ++i) { + luckCase1 &= + inArrs[i]->ordering() == output.ordering() && inArrs[i]->ews() == 1; + if (!luckCase1) break; } + } - if(luckCase1) { // for example {1,10} + {2,10} + {3,10} = {6, 10} order c; or {10,1} + {10,2} + {10,3} = {10, 6} order f + if (luckCase1) { // for example {1,10} + {2,10} + {3,10} = {6, 10} order c; + // or {10,1} + {10,2} + {10,3} = {10, 6} order f - void* z = static_cast(output.specialBuffer()); + void* z = static_cast(output.specialBuffer()); - for (uint i = 0; i < numOfInArrs; ++i) { - const auto memAmountToCopy = inArrs[i]->lengthOf() * sizeofT; - cudaMemcpyAsync(z, reinterpret_cast(inArrs[i]->specialBuffer()), memAmountToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream()); - z = static_cast(z) + memAmountToCopy; - } + for (uint i = 0; i < numOfInArrs; ++i) { + const auto memAmountToCopy = inArrs[i]->lengthOf() * sizeofT; + cudaMemcpyAsync( + z, reinterpret_cast(inArrs[i]->specialBuffer()), + memAmountToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream()); + z = static_cast(z) + memAmountToCopy; + } - if(cudaStreamSynchronize(*context->getCudaStream()) != 0) - throw std::runtime_error("concat cuda: luckCase1 failed!"); + if (cudaStreamSynchronize(*context->getCudaStream()) != 0) + throw std::runtime_error("concat cuda: luckCase1 failed!"); - for(int i = 0; i < numOfInArrs; ++i) - inArrs[i]->tickReadDevice(); - output.tickWriteDevice(); + for (int i = 0; i < numOfInArrs; ++i) inArrs[i]->tickReadDevice(); + output.tickWriteDevice(); - return; - } + return; + } - // const bool isZcontin = output.strideAt(axis) == 1; - // bool areInputsContin = true; - // bool allSameOrder = true; - // std::vector strideOfContigStride(numOfInArrs); + // const bool isZcontin = output.strideAt(axis) == 1; + // bool areInputsContin = true; + // bool allSameOrder = true; + // std::vector strideOfContigStride(numOfInArrs); - // if(isZcontin) { + // if(isZcontin) { - // for (uint i = 0; i < inArrs.size(); ++i) { + // for (uint i = 0; i < inArrs.size(); ++i) { - // areInputsContin &= inArrs[i]->strideAt(axis) == 1; - // allSameOrder &= output.ordering() == inArrs[i]->ordering(); - // if(!areInputsContin || !allSameOrder) - // break; + // areInputsContin &= inArrs[i]->strideAt(axis) == 1; + // allSameOrder &= output.ordering() == inArrs[i]->ordering(); + // if(!areInputsContin || !allSameOrder) + // break; - // strideOfContigStride[i] = shape::strideOverContigAxis(axis, inArrs[i]->shapeInfo()); - // } - // } + // strideOfContigStride[i] = shape::strideOverContigAxis(axis, + // inArrs[i]->shapeInfo()); + // } + // } - // const bool luckCase2 = isZcontin && areInputsContin && allSameOrder; + // const bool luckCase2 = isZcontin && areInputsContin && allSameOrder; - // if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 shoud have stride = 1 for all inputs arrays and output array + // if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, + // here axis 1 shoud have stride = 1 for all inputs arrays and output array - // const auto zStep = shape::strideOverContigAxis(axis, output.shapeInfo()); + // const auto zStep = shape::strideOverContigAxis(axis, + // output.shapeInfo()); - // for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) { + // for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) { - // const auto iShift = i * sizeofT; - // void* z = static_cast(output.specialBuffer()) + zStep * iShift; + // const auto iShift = i * sizeofT; + // void* z = static_cast(output.specialBuffer()) + zStep * + // iShift; - // for (uint j = 0; j < numOfInArrs; ++j) { - // const auto xDim = inArrs[j]->sizeAt(axis); - // void* x = static_cast(inArrs[j]->specialBuffer()) + strideOfContigStride[j] * iShift; - // const auto memSizeToCopy = xDim * sizeofT; - // cudaMemcpyAsync(z, x, memSizeToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream()); - // z = static_cast(z) + memSizeToCopy; - // } - // } + // for (uint j = 0; j < numOfInArrs; ++j) { + // const auto xDim = inArrs[j]->sizeAt(axis); + // void* x = static_cast(inArrs[j]->specialBuffer()) + + // strideOfContigStride[j] * iShift; const auto memSizeToCopy = + // xDim * sizeofT; cudaMemcpyAsync(z, x, memSizeToCopy, + // cudaMemcpyDeviceToDevice, *context->getCudaStream()); z = + // static_cast(z) + memSizeToCopy; + // } + // } - // if(cudaStreamSynchronize(*context->getCudaStream()) != 0) - // throw std::runtime_error("concat cuda: luckCase2 failed!"); - // } - // else { // general (slower) case + // if(cudaStreamSynchronize(*context->getCudaStream()) != 0) + // throw std::runtime_error("concat cuda: luckCase2 failed!"); + // } + // else { // general (slower) case - const int threadsPerBlock = 256; - const int blocksPerGrid = 512; - const int sharedMem = 512; + const int threadsPerBlock = 256; + const int blocksPerGrid = 512; + const int sharedMem = 512; - // prepare arrays of pointers on buffers and shapes - std::vector hInBuffers(numOfInArrs); - std::vector hInShapeInfo(numOfInArrs); + // prepare arrays of pointers on buffers and shapes + std::vector hInBuffers(numOfInArrs); + std::vector hInShapeInfo(numOfInArrs); - for(int i = 0; i < numOfInArrs; ++i) { - hInBuffers[i] = inArrs[i]->specialBuffer(); - hInShapeInfo[i] = inArrs[i]->specialShapeInfo(); - } + for (int i = 0; i < numOfInArrs; ++i) { + hInBuffers[i] = inArrs[i]->specialBuffer(); + hInShapeInfo[i] = inArrs[i]->specialShapeInfo(); + } - PointersManager manager(context, "helpers::concat"); + PointersManager manager(context, "helpers::concat"); - void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*)); - void* dInShapeInfo = manager.replicatePointer(hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*)); + void* dInBuffers = manager.replicatePointer( + hInBuffers.data(), hInBuffers.size() * sizeof(void*)); + void* dInShapeInfo = manager.replicatePointer( + hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*)); - BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), dInBuffers, dInShapeInfo, output.specialBuffer(), output.specialShapeInfo(), axis), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR( + inArrs[0]->dataType(), concatCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + dInBuffers, dInShapeInfo, output.specialBuffer(), + output.specialShapeInfo(), axis), + LIBND4J_TYPES); - manager.synchronize(); - // } + manager.synchronize(); + // } - NDArray::registerSpecialUse({&output}, inArrs); + NDArray::registerSpecialUse({&output}, inArrs); } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/confusion.cu b/libnd4j/include/ops/declarable/helpers/cuda/confusion.cu index dfa86124a90c..8bb98a035dce 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/confusion.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/confusion.cu @@ -18,105 +18,135 @@ // @author GS // -#include #include -#include -#include #include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - template - __global__ static void copyBuffers(Nd4jLong* destination, void const* source, Nd4jLong bufferLength) { - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - for (int t = tid; t < bufferLength; t += step) { - destination[t] = static_cast(reinterpret_cast(source)[t]); - } - } - - template - __global__ static void confusionFunctorKernel(Nd4jLong* labelsBuffer, Nd4jLong* predictionBuffer, Nd4jLong bufferLength, void const* weightsBuffer, void* outputBuffer, const Nd4jLong* tadShape, const Nd4jLong* tadOffsets) { - __shared__ int arrIdx, blocksPerArr; - __shared__ T *z; - __shared__ T const* w; - __shared__ Nd4jLong *zShapeInfo, *xShapeInfo, arrLen; - - if (threadIdx.x == 0) { - z = reinterpret_cast(outputBuffer); - w = reinterpret_cast(weightsBuffer); - arrLen = shape::length(tadShape); - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - for (int t = tid; t < bufferLength; t += step) { - auto label = labelsBuffer[t]; //->e(j); - auto pred = predictionBuffer[t]; //->e(j); - auto tZ = z + tadOffsets[label]; - T val = (weightsBuffer == nullptr ? (T)1.0f : w[t]); - - auto idx = shape::getIndexOffset(pred, tadShape); - tZ[idx] = val; - } - } - - template - void _confusionFunctor(sd::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) { - auto stream = context->getCudaStream(); - - auto pack = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), 1); - - PointersManager manager(context, "helpers::confusion"); - - Nd4jLong* labelsLongBuffer = labels->dataType() == sd::DataType::INT64?(Nd4jLong*)labels->specialBuffer():nullptr; - Nd4jLong* predictionLongBuffer = predictions->dataType() == sd::DataType::INT64?(Nd4jLong*)predictions->specialBuffer():nullptr; - - if (labelsLongBuffer == nullptr) { - auto err = cudaMalloc(&labelsLongBuffer, labels->lengthOf() * sizeof(Nd4jLong)); - if (err != 0) - throw sd::cuda_exception::build("Cannot allocate memory for labels long buffer", err); - // copy with type conversion - copyBuffers<<<256, 512, 1024, *stream>>>(labelsLongBuffer, labels->specialBuffer(), labels->lengthOf()); - } - - if (predictionLongBuffer == nullptr) { - auto err = cudaMalloc(&predictionLongBuffer, predictions->lengthOf() * sizeof(Nd4jLong)); - if (err != 0) - throw sd::cuda_exception::build("Cannot allocate memory for predictions long buffer", err); - // copy with type conversion - copyBuffers<<<256, 512, 1024, *stream>>>(predictionLongBuffer, predictions->specialBuffer(), predictions->lengthOf()); - } - - auto bufferLength = labels->lengthOf(); - dim3 launchDims(32, 32, 1024); - confusionFunctorKernel<<>>(labelsLongBuffer, predictionLongBuffer, bufferLength, weights != nullptr? weights->specialBuffer():nullptr, output->specialBuffer(), pack.specialShapeInfo(), pack.specialOffsets()); - - manager.synchronize(); - - if (predictionLongBuffer != predictions->specialBuffer()) { - cudaError_t err = cudaFree(predictionLongBuffer); - if (err != 0) - throw sd::cuda_exception::build("Cannot deallocate memory for predictions long buffer", err); - } - - if (labelsLongBuffer != labels->specialBuffer()) { - cudaError_t err = cudaFree(labelsLongBuffer); - if (err != 0) - throw sd::cuda_exception::build("Cannot deallocate memory for labels long buffer", err); - } - } - - void confusionFunctor(sd::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) { - auto xType = predictions->dataType(); - auto zType = output->dataType(); // weights can be null - NDArray::prepareSpecialUse({output}, {labels, predictions, weights}); - BUILD_DOUBLE_SELECTOR(xType, zType, _confusionFunctor, (context, labels, predictions, weights, output), INDEXING_TYPES, NUMERIC_TYPES); - NDArray::registerSpecialUse({output}, {labels, predictions, weights}); - } +template +__global__ static void copyBuffers(Nd4jLong* destination, void const* source, + Nd4jLong bufferLength) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + for (int t = tid; t < bufferLength; t += step) { + destination[t] = + static_cast(reinterpret_cast(source)[t]); + } } + +template +__global__ static void confusionFunctorKernel( + Nd4jLong* labelsBuffer, Nd4jLong* predictionBuffer, Nd4jLong bufferLength, + void const* weightsBuffer, void* outputBuffer, const Nd4jLong* tadShape, + const Nd4jLong* tadOffsets) { + __shared__ int arrIdx, blocksPerArr; + __shared__ T* z; + __shared__ T const* w; + __shared__ Nd4jLong *zShapeInfo, *xShapeInfo, arrLen; + + if (threadIdx.x == 0) { + z = reinterpret_cast(outputBuffer); + w = reinterpret_cast(weightsBuffer); + arrLen = shape::length(tadShape); + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + for (int t = tid; t < bufferLength; t += step) { + auto label = labelsBuffer[t]; //->e(j); + auto pred = predictionBuffer[t]; //->e(j); + auto tZ = z + tadOffsets[label]; + T val = (weightsBuffer == nullptr ? (T)1.0f : w[t]); + + auto idx = shape::getIndexOffset(pred, tadShape); + tZ[idx] = val; + } +} + +template +void _confusionFunctor(sd::LaunchContext* context, NDArray* labels, + NDArray* predictions, NDArray* weights, + NDArray* output) { + auto stream = context->getCudaStream(); + + auto pack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), 1); + + PointersManager manager(context, "helpers::confusion"); + + Nd4jLong* labelsLongBuffer = labels->dataType() == sd::DataType::INT64 + ? (Nd4jLong*)labels->specialBuffer() + : nullptr; + Nd4jLong* predictionLongBuffer = + predictions->dataType() == sd::DataType::INT64 + ? (Nd4jLong*)predictions->specialBuffer() + : nullptr; + + if (labelsLongBuffer == nullptr) { + auto err = + cudaMalloc(&labelsLongBuffer, labels->lengthOf() * sizeof(Nd4jLong)); + if (err != 0) + throw sd::cuda_exception::build( + "Cannot allocate memory for labels long buffer", err); + // copy with type conversion + copyBuffers<<<256, 512, 1024, *stream>>>( + labelsLongBuffer, labels->specialBuffer(), labels->lengthOf()); + } + + if (predictionLongBuffer == nullptr) { + auto err = cudaMalloc(&predictionLongBuffer, + predictions->lengthOf() * sizeof(Nd4jLong)); + if (err != 0) + throw sd::cuda_exception::build( + "Cannot allocate memory for predictions long buffer", err); + // copy with type conversion + copyBuffers<<<256, 512, 1024, *stream>>>(predictionLongBuffer, + predictions->specialBuffer(), + predictions->lengthOf()); + } + + auto bufferLength = labels->lengthOf(); + dim3 launchDims(32, 32, 1024); + confusionFunctorKernel + <<>>( + labelsLongBuffer, predictionLongBuffer, bufferLength, + weights != nullptr ? weights->specialBuffer() : nullptr, + output->specialBuffer(), pack.specialShapeInfo(), + pack.specialOffsets()); + + manager.synchronize(); + + if (predictionLongBuffer != predictions->specialBuffer()) { + cudaError_t err = cudaFree(predictionLongBuffer); + if (err != 0) + throw sd::cuda_exception::build( + "Cannot deallocate memory for predictions long buffer", err); + } + + if (labelsLongBuffer != labels->specialBuffer()) { + cudaError_t err = cudaFree(labelsLongBuffer); + if (err != 0) + throw sd::cuda_exception::build( + "Cannot deallocate memory for labels long buffer", err); + } +} + +void confusionFunctor(sd::LaunchContext* context, NDArray* labels, + NDArray* predictions, NDArray* weights, NDArray* output) { + auto xType = predictions->dataType(); + auto zType = output->dataType(); // weights can be null + NDArray::prepareSpecialUse({output}, {labels, predictions, weights}); + BUILD_DOUBLE_SELECTOR(xType, zType, _confusionFunctor, + (context, labels, predictions, weights, output), + INDEXING_TYPES, NUMERIC_TYPES); + NDArray::registerSpecialUse({output}, {labels, predictions, weights}); } -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_col2vol.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_col2vol.cu index 80df76c91ae0..2f6b31688859 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_col2vol.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_col2vol.cu @@ -19,113 +19,136 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// -// columns [bS, iC, kD, kH, kW, oD, oH, oW] to be de-convoluted to volume [bS, iC, iD, iH, iW] +// columns [bS, iC, kD, kH, kW, oD, oH, oW] to be de-convoluted to volume [bS, +// iC, iD, iH, iW] template -static __global__ void col2volCuda(const void* columns, const Nd4jLong* colShapeInfo, void* volume, const Nd4jLong* volShapeInfo, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - const T* col = reinterpret_cast(columns); - T* vol = reinterpret_cast(volume); - - __shared__ uint kD, kH, kW, oD, oH, oW, *sharedMem; - __shared__ Nd4jLong volLen; +static __global__ void col2volCuda(const void* columns, + const Nd4jLong* colShapeInfo, void* volume, + const Nd4jLong* volShapeInfo, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW) { + const T* col = reinterpret_cast(columns); + T* vol = reinterpret_cast(volume); - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + __shared__ uint kD, kH, kW, oD, oH, oW, *sharedMem; + __shared__ Nd4jLong volLen; - oD = colShapeInfo[6]; - oH = colShapeInfo[7]; - oW = colShapeInfo[8]; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - kD = dD * (colShapeInfo[3] - 1) + 1; - kH = dH * (colShapeInfo[4] - 1) + 1; - kW = dW * (colShapeInfo[5] - 1) + 1; + oD = colShapeInfo[6]; + oH = colShapeInfo[7]; + oW = colShapeInfo[8]; - volLen = shape::length(volShapeInfo); - } - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * 8; + kD = dD * (colShapeInfo[3] - 1) + 1; + kH = dH * (colShapeInfo[4] - 1) + 1; + kW = dW * (colShapeInfo[5] - 1) + 1; - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + volLen = shape::length(volShapeInfo); + } + __syncthreads(); - for (Nd4jLong i = tid; i < volLen; i += gridDim.x * blockDim.x) { + auto coords = sharedMem + threadIdx.x * 8; - shape::index2coords(i, volShapeInfo, coords); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto volOffset = shape::getOffset(volShapeInfo, coords); + for (Nd4jLong i = tid; i < volLen; i += gridDim.x * blockDim.x) { + shape::index2coords(i, volShapeInfo, coords); - const auto bSiCoffset = coords[0] * colShapeInfo[9] + coords[1] * colShapeInfo[10]; + const auto volOffset = shape::getOffset(volShapeInfo, coords); - const uint imD = coords[2] + pD; - const uint imH = coords[3] + pH; - const uint imW = coords[4] + pW; + const auto bSiCoffset = + coords[0] * colShapeInfo[9] + coords[1] * colShapeInfo[10]; - const uint colDstart = (imD < kD) ? 0 : (imD - kD) / sD + 1; - const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1; - const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1; + const uint imD = coords[2] + pD; + const uint imH = coords[3] + pH; + const uint imW = coords[4] + pW; - const uint colDend = sd::math::nd4j_min(imD / sD + 1, oD); - const uint colHend = sd::math::nd4j_min(imH / sH + 1, oH); - const uint colWend = sd::math::nd4j_min(imW / sW + 1, oW); + const uint colDstart = (imD < kD) ? 0 : (imD - kD) / sD + 1; + const uint colHstart = (imH < kH) ? 0 : (imH - kH) / sH + 1; + const uint colWstart = (imW < kW) ? 0 : (imW - kW) / sW + 1; - T val = 0; + const uint colDend = sd::math::nd4j_min(imD / sD + 1, oD); + const uint colHend = sd::math::nd4j_min(imH / sH + 1, oH); + const uint colWend = sd::math::nd4j_min(imW / sW + 1, oW); - for(uint colD = colDstart; colD < colDend; ++colD) { - coords[2] = imD - colD * sD; - if(coords[2] % dD != 0) continue; + T val = 0; - for(uint colH = colHstart; colH < colHend; ++colH) { - coords[3] = imH - colH * sH; - if(coords[3] % dH != 0) continue; + for (uint colD = colDstart; colD < colDend; ++colD) { + coords[2] = imD - colD * sD; + if (coords[2] % dD != 0) continue; - for(uint colW = colWstart; colW < colWend; ++colW) { - coords[4] = imW - colW * sW; - if(coords[4] % dW != 0) continue; + for (uint colH = colHstart; colH < colHend; ++colH) { + coords[3] = imH - colH * sH; + if (coords[3] % dH != 0) continue; - val += col[bSiCoffset + (coords[2]/dD)*colShapeInfo[11] + (coords[3]/dH)*colShapeInfo[12] + (coords[4]/dW)*colShapeInfo[13] + colD*colShapeInfo[14] + colH*colShapeInfo[15] + colW*colShapeInfo[16]]; + for (uint colW = colWstart; colW < colWend; ++colW) { + coords[4] = imW - colW * sW; + if (coords[4] % dW != 0) continue; - } - } + val += col[bSiCoffset + (coords[2] / dD) * colShapeInfo[11] + + (coords[3] / dH) * colShapeInfo[12] + + (coords[4] / dW) * colShapeInfo[13] + + colD * colShapeInfo[14] + colH * colShapeInfo[15] + + colW * colShapeInfo[16]]; } - - vol[volOffset] = val; + } } + + vol[volOffset] = val; + } } ////////////////////////////////////////////////////////////////////////// template -static void col2volCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* columns, const Nd4jLong* colShapeInfo, - void* volume, const Nd4jLong* volShapeInfo, - const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - col2volCuda<<>>(columns, colShapeInfo, volume, volShapeInfo, sD, sH, sW, pD, pH, pW, dD, dH, dW); +static void col2volCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* columns, + const Nd4jLong* colShapeInfo, void* volume, + const Nd4jLong* volShapeInfo, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW) { + col2volCuda<<>>( + columns, colShapeInfo, volume, volShapeInfo, sD, sH, sW, pD, pH, pW, dD, + dH, dW); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::col2vol(sd::graph::Context& block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - PointersManager manager(block.launchContext(), "col2vol"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (vol.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256; - - NDArray::prepareSpecialUse({&vol}, {&col}); - BUILD_SINGLE_SELECTOR(vol.dataType(), col2volCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), col.specialBuffer(), col.specialShapeInfo(), vol.specialBuffer(), vol.specialShapeInfo(), sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); - NDArray::registerSpecialUse({&vol}, {&col}); - - manager.synchronize(); +void ConvolutionUtils::col2vol(sd::graph::Context& block, const NDArray& col, + NDArray& vol, const int sD, const int sH, + const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, + const int dW) { + PointersManager manager(block.launchContext(), "col2vol"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (vol.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = col.rankOf() * sizeof(uint) * threadsPerBlock + 256; + + NDArray::prepareSpecialUse({&vol}, {&col}); + BUILD_SINGLE_SELECTOR( + vol.dataType(), col2volCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), col.specialBuffer(), + col.specialShapeInfo(), vol.specialBuffer(), vol.specialShapeInfo(), sD, + sH, sW, pD, pH, pW, dD, dH, dW), + FLOAT_TYPES); + NDArray::registerSpecialUse({&vol}, {&col}); + + manager.synchronize(); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2d.cu index 494ce4a815a6..ce4be2c3928e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2d.cu @@ -19,87 +19,113 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include #include #include +#include +#include +#include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void conv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // bias [oC] - // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector permutForOutput; - - if(isNCHW) - permutForOutput = {0, 3, 1, 2}; // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - else - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC - - std::vector wAxes; - if(0 == wFormat) - wAxes = {0, 1, 2}; - else if(1 == wFormat) - wAxes = {2, 3, 1}; - else - wAxes = {1, 2, 3}; - - NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); - NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} - NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); - - //----- calculation of output -----// - auto ctx = block.launchContext(); - helpers::im2col(*ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, wAxes, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] - - //----- assign outTemp to output -----// - if(isNCHW) { - mmulResult.reshapei({bS, oH, oW, oC}); - mmulResult.permutei(permutForOutput); - } - output->assign(mmulResult); - - //----- add biases if required -----// - if(bias) - // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCHW); - - if(!isNCHW) - delete input; - +static void conv2d_(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // bias [oC] + // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector permutForOutput; + + if (isNCHW) + permutForOutput = {0, 3, 1, 2}; // [bS, oH, oW, oC] -> [bS, oC, oH, oW] + else + input = new NDArray(input->permute( + {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC + + std::vector wAxes; + if (0 == wFormat) + wAxes = {0, 1, 2}; + else if (1 == wFormat) + wAxes = {2, 3, 1}; + else + wAxes = {1, 2, 3}; + + NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), + input->getContext()); + NDArray colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} + NDArray mmulResult('f', {bS * oH * oW, oC}, output->dataType(), + output->getContext()); + + //----- calculation of output -----// + auto ctx = block.launchContext(); + helpers::im2col( + *ctx, *input, colP, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, + // iC, kH, kW, oH, oW] + MmulHelper::tensorDot( + &col, weights, &mmulResult, {3, 4, 5}, wAxes, + {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] + + //----- assign outTemp to output -----// + if (isNCHW) { + mmulResult.reshapei({bS, oH, oW, oC}); + mmulResult.permutei(permutForOutput); + } + output->assign(mmulResult); + + //----- add biases if required -----// + if (bias) + // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCHW); + + if (!isNCHW) delete input; } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +void ConvolutionUtils::conv2d(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), conv2d_, + (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, + paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2dBP.cu index dbf4ee39012f..7317e0009fb0 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2dBP.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_conv2dBP.cu @@ -19,107 +19,144 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include #include #include +#include +#include +#include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void conv2dBP_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // bias [oC] - // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - // gradB [oC] - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NHWC, 1-NCHW - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector gradOaxesForDot; - - if(!isNCHW) { - gradOaxesForDot = {0, 1, 2}; // bS, oH, oW - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - } else { - gradOaxesForDot = {0, 2, 3}; // bS, oH, oW - } - - std::vector wPermut, colPermut; - if(0 == wFormat) { - wPermut = {2, 0, 1, 3}; - colPermut = {2, 3, 1, 0, 4, 5}; - } - else if(1 == wFormat) { - wPermut = {1, 2, 3, 0}; - colPermut = {1, 2, 3, 0, 4, 5}; - } - else { - wPermut = {3, 1, 2, 0}; - colPermut = {2, 3, 1, 0, 4, 5}; - } - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - - // ----- calculation of gradW ----- // - if(gradW) { - auto ctx = block.launchContext(); - helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] - } - - // ----- calculation of gradB ----- // - if(gradB) { - NDArray* gradBR = gradB; - if(gradB->rankOf() == 2) - gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot, false); // sum over bS, oH, oW - if(gradBR != gradB) - delete gradBR; - } - - //----- calculation of gradI -----// - // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - // [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, oH, oW] - // [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - sd::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, colPermut); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - - helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] - - if(!isNCHW) { - delete input; - delete gradI; - } +static void conv2dBP_(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + const NDArray* gradO, NDArray* gradI, NDArray* gradW, + NDArray* gradB, const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // bias [oC] + // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + + // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + // gradW [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + // gradB [oC] + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NHWC, 1-NCHW + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector gradOaxesForDot; + + if (!isNCHW) { + gradOaxesForDot = {0, 1, 2}; // bS, oH, oW + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray( + gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + } else { + gradOaxesForDot = {0, 2, 3}; // bS, oH, oW + } + + std::vector wPermut, colPermut; + if (0 == wFormat) { + wPermut = {2, 0, 1, 3}; + colPermut = {2, 3, 1, 0, 4, 5}; + } else if (1 == wFormat) { + wPermut = {1, 2, 3, 0}; + colPermut = {1, 2, 3, 0, 4, 5}; + } else { + wPermut = {3, 1, 2, 0}; + colPermut = {2, 3, 1, 0, 4, 5}; + } + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, + input->dataType(), input->getContext()); + + // ----- calculation of gradW ----- // + if (gradW) { + auto ctx = block.launchContext(); + helpers::im2col( + *ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to + // [bS, iC, kH, kW, oH, oW] + sd::MmulHelper::tensorDot( + &columns, gradO, gradW, {0, 4, 5}, gradOaxesForDot, + wPermut); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, + // oW] = [iC, kH, kW, oC] + } + + // ----- calculation of gradB ----- // + if (gradB) { + NDArray* gradBR = gradB; + if (gradB->rankOf() == 2) + gradBR = new NDArray( + gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); + gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot, + false); // sum over bS, oH, oW + if (gradBR != gradB) delete gradBR; + } + + //----- calculation of gradI -----// + // [kH, kW, iC, oC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, + // oW] [oC, iC, kH, kW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, bS, + // oH, oW] [oC, kH, kW, iC] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, + // bS, oH, oW] + sd::MmulHelper::tensorDot( + weights, gradO, &columns, {indWoC}, {indIOioC}, + colPermut); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, + // oC, oH, oW] = [kH, kW, iC, bS, oH, oW] + + helpers::col2im( + *block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, + dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + + if (!isNCHW) { + delete input; + delete gradI; + } } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +void ConvolutionUtils::conv2dBP(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + const NDArray* gradO, NDArray* gradI, + NDArray* gradW, NDArray* gradB, const int kH, + const int kW, const int sH, const int sW, + int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), conv2dBP_, + (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, + pH, pW, dH, dW, paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2d.cu index bbf5d5892617..42f29c050cdc 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2d.cu @@ -19,83 +19,119 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include #include #include +#include +#include +#include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // bias [oC] = iC*mC - // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NCHW, 1-NHWC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] - std::vector> modifOutput, modifWeights; - std::vector outReShape; - - if(!isNCHW) { - outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] - modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - } - else { - outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] - modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - } - - if(0 == wFormat) - modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; - else if(1 == wFormat) - modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; - else - modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; - - if(paddingMode == 1) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); - - helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, modifWeights, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] - - if(bias) - // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - helpers::addBias(block, *output, *bias, *output, isNCHW); - - if(!isNCHW) - delete input; +static void depthwiseConv2d_(sd::graph::Context& block, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // bias [oC] = iC*mC + // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NCHW, 1-NHWC + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector> modifColumns = { + {1, 0, 4, 5, 2, 3}, + {iC, bS * oH * oW, + kH * kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> + // [iC,bS*oH*oW,kH*kW] + std::vector> modifOutput, modifWeights; + std::vector outReShape; + + if (!isNCHW) { + outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] + modifOutput = { + {3, 0, 1, 2, 4}, + {iC, bS * oH * oW, + mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + } else { + outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] + modifOutput = { + {1, 0, 3, 4, 2}, + {iC, bS * oH * oW, + mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + } + + if (0 == wFormat) + modifWeights = {{2, 0, 1, 3}, {iC, kH * kW, mC}}; + else if (1 == wFormat) + modifWeights = {{1, 2, 3, 0}, {iC, kH * kW, mC}}; + else + modifWeights = {{3, 1, 2, 0}, {iC, kH * kW, mC}}; + + if (paddingMode == 1) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, + input->dataType(), input->getContext()); + NDArray outputReshaped = + output->reshape(output->ordering(), outReShape, false); + + helpers::im2col( + *output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, + // iC, kH, kW, oH, oW] + MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, + modifWeights, + modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, + // mC] = [iC, bS*oH*oW, mC] + + if (bias) + // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + helpers::addBias(block, *output, *bias, *output, isNCHW); + + if (!isNCHW) delete input; } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::depthwiseConv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +void ConvolutionUtils::depthwiseConv2d( + sd::graph::Context& block, const NDArray* input, const NDArray* weights, + const NDArray* bias, NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), depthwiseConv2d_, + (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, + paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2dBP.cu index b06af61665dd..545ae0daa5e9 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2dBP.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_depthwiseConv2dBP.cu @@ -19,102 +19,155 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include #include #include +#include +#include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - - // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // bias [oC] = [iC*mC] - // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // gradB [oC] - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 0-NHWC, 1-NCHW - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] - std::vector> modifGradO1, modifGradO2, modifWeights; - std::vector gradOreShape; - - if(!isNCHW) { - gradOreShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] - modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] - input = new NDArray(input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - } - else { - gradOreShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] - modifGradO1 = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] - } - - if(0 == wFormat) - modifWeights = {{2,0,1,3},{iC,kH*kW,mC}}; - else if(1 == wFormat) - modifWeights = {{1,2,3,0},{iC,kH*kW,mC}}; - else - modifWeights = {{3,1,2,0},{iC,kH*kW,mC}}; - - if(paddingMode == 1) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - NDArray gradOreshaped = gradO->reshape(gradO->ordering(), gradOreShape); - - // ----- calculation of gradW and gradB ----- // - - helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, modifGradO1, modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] - - // ----- calculation of gradB ----- // - if(gradB) { - NDArray* gradBR = gradB; - if(gradB->rankOf() == 2) - gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); - gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}, false); // sum over bS, oH, oW - if(gradBR != gradB) - delete gradBR; - } - - //----- calculation of gradI -----// - sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] - helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] - - if(!isNCHW) { - delete input; - delete gradI; - } +static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, + const NDArray* bias, const NDArray* gradO, + NDArray* gradI, NDArray* gradW, NDArray* gradB, + const int kH, const int kW, const int sH, + const int sW, int pH, int pW, const int dH, + const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + // weights [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // bias [oC] = [iC*mC] + // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next + // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + // gradW [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // gradB [oC] + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 0-NHWC, 1-NCHW + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector> modifColumns = { + {1, 2, 3, 0, 4, 5}, + {iC, kH * kW, + bS * oH * oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] + std::vector> modifGradO1, modifGradO2, modifWeights; + std::vector gradOreShape; + + if (!isNCHW) { + gradOreShape = {bS, oH, oW, iC, + mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] + modifGradO1 = { + {3, 0, 1, 2, 4}, + {iC, bS * oH * oW, + mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + modifGradO2 = { + {3, 0, 1, 2}, + {iC, mC, + bS * oH * + oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + gradI = new NDArray( + gradI->permute({0, 3, 1, 2})); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + } else { + gradOreShape = {bS, iC, mC, oH, + oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] + modifGradO1 = { + {1, 0, 3, 4, 2}, + {iC, bS * oH * oW, + mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + modifGradO2 = { + {1, 0, 2, 3}, + {iC, mC, + bS * oH * + oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] + } + + if (0 == wFormat) + modifWeights = {{2, 0, 1, 3}, {iC, kH * kW, mC}}; + else if (1 == wFormat) + modifWeights = {{1, 2, 3, 0}, {iC, kH * kW, mC}}; + else + modifWeights = {{3, 1, 2, 0}, {iC, kH * kW, mC}}; + + if (paddingMode == 1) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, + input->dataType(), input->getContext()); + NDArray gradOreshaped = gradO->reshape(gradO->ordering(), gradOreShape); + + // ----- calculation of gradW and gradB ----- // + + helpers::im2col( + *input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, + NDArrayFactory::create( + 0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, + // iC, kH, kW, oH, oW] + sd::MmulHelper::tensorDot(&columns, &gradOreshaped, gradW, modifColumns, + modifGradO1, + modifWeights); // [iC, kW*kH, bS*oH*oW] x [iC, + // bS*oH*oW, mC] = [iC, kH*kW, mC] + + // ----- calculation of gradB ----- // + if (gradB) { + NDArray* gradBR = gradB; + if (gradB->rankOf() == 2) + gradBR = new NDArray( + gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); + gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0, indOoH, indOoH + 1}, + false); // sum over bS, oH, oW + if (gradBR != gradB) delete gradBR; + } + + //----- calculation of gradI -----// + sd::MmulHelper::tensorDot(weights, gradO, &columns, modifWeights, modifGradO2, + modifColumns); // [iC, kH*kW, mC] x [iC, mC, + // bS*oH*oW] = [iC, kW*kH, bS*oH*oW] + helpers::col2im( + *input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, + dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + + if (!isNCHW) { + delete input; + delete gradI; + } } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::depthwiseConv2dBP(sd::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +void ConvolutionUtils::depthwiseConv2dBP( + sd::graph::Context& block, const NDArray* input, const NDArray* weights, + const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, + NDArray* gradB, const int kH, const int kW, const int sH, const int sW, + int pH, int pW, const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), depthwiseConv2dBP_, + (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, + dH, dW, paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2d.cu index c146be7bff2b..1ad9a69a29cf 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2d.cu @@ -19,324 +19,381 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include #include #include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static __global__ void avgPooling2dCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - - // input is [bS, iC, iH, iW] - // output is [bS, iC, oH, oW] - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); +static __global__ void avgPooling2dCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, const int dW, + const int extraParam0) { + // input is [bS, iC, iH, iW] + // output is [bS, iC, oH, oW] + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, + strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; + + if (threadIdx.x == 0) { + bS = shape::sizeAt(xShapeInfo, 0); + iC = shape::sizeAt(xShapeInfo, 1); + oH = shape::sizeAt(zShapeInfo, 2); + oW = shape::sizeAt(zShapeInfo, 3); + iH = shape::sizeAt(xShapeInfo, 2); + iW = shape::sizeAt(xShapeInfo, 3); + + strideB = shape::stride(xShapeInfo)[0]; + strideC = shape::stride(xShapeInfo)[1]; + strideY = shape::stride(xShapeInfo)[2]; + strideX = shape::stride(xShapeInfo)[3]; + + strideOB = shape::stride(zShapeInfo)[0]; + strideOC = shape::stride(zShapeInfo)[1]; + strideOY = shape::stride(zShapeInfo)[2]; + strideOX = shape::stride(zShapeInfo)[3]; + + length = shape::length(zShapeInfo); + + // Replace kernel H/W with *effective* kernel H/W accounting for dilatyon + kHEff = kH + (kH - 1) * (dH - 1); + kWEff = kW + (kW - 1) * (dW - 1); + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (int index = tid; index < length; index += blockDim.x * gridDim.x) { + const int pw = index % oW; + const int ph = (index / oW) % oH; + const int c = (index / oW / oH) % iC; + const int n = index / oW / oH / iC; + + int hstart = sH * ph - pH; + int wstart = sW * pw - pW; + int hend = hstart + kHEff; + int wend = wstart + kWEff; + + if (hstart < 0) { + int f = sd::math::nd4j_ceil((Z)-hstart / (Z)dH); + hstart += f * dH; + } + if (wstart < 0) { + int f = sd::math::nd4j_ceil((Z)-wstart / (Z)dW); + wstart += f * dW; + } + if (hend > iH) { + int f = sd::math::nd4j_ceil((Z)(hend - iH) / (Z)dH); + hend -= f * dH; + } + if (wend > iW) { + int f = sd::math::nd4j_ceil((Z)(wend - iW) / (Z)dW); + wend -= f * dW; + } - __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; + // Accounts for dilation + int pool_size = + sd::math::nd4j_ceil((double)(hend - hstart) / (double)dH) * + sd::math::nd4j_ceil((double)(wend - wstart) / (double)dW); - if (threadIdx.x == 0) { - bS = shape::sizeAt(xShapeInfo, 0); - iC = shape::sizeAt(xShapeInfo, 1); - oH = shape::sizeAt(zShapeInfo, 2); - oW = shape::sizeAt(zShapeInfo, 3); - iH = shape::sizeAt(xShapeInfo, 2); - iW = shape::sizeAt(xShapeInfo, 3); + Z sum = 0.0f; - strideB = shape::stride(xShapeInfo)[0]; - strideC = shape::stride(xShapeInfo)[1]; - strideY = shape::stride(xShapeInfo)[2]; - strideX = shape::stride(xShapeInfo)[3]; + const X *inSlice = x + (n * strideB + c * strideC); - strideOB = shape::stride(zShapeInfo)[0]; - strideOC = shape::stride(zShapeInfo)[1]; - strideOY = shape::stride(zShapeInfo)[2]; - strideOX = shape::stride(zShapeInfo)[3]; + for (int h = hstart; h < hend; h += dH) + for (int w = wstart; w < wend; w += dW) + sum += static_cast(inSlice[h * strideY + w * strideX]); - length = shape::length(zShapeInfo); + int divide_factor = pool_size; // Case 0: exclude padding + if (extraParam0 == 1) // Case 1: include padding + divide_factor = kH * kW; - //Replace kernel H/W with *effective* kernel H/W accounting for dilatyon - kHEff = kH + (kH-1)*(dH-1); - kWEff = kW + (kW-1)*(dW-1); - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (int index = tid; index < length; index += blockDim.x * gridDim.x) { - - const int pw = index % oW; - const int ph = (index / oW) % oH; - const int c = (index / oW / oH) % iC; - const int n = index / oW / oH / iC; - - int hstart = sH * ph - pH; - int wstart = sW * pw - pW; - int hend = hstart + kHEff; - int wend = wstart + kWEff; - - if(hstart < 0){ - int f = sd::math::nd4j_ceil((Z) -hstart / (Z)dH); - hstart += f * dH; - } - if(wstart < 0){ - int f = sd::math::nd4j_ceil((Z) -wstart / (Z) dW); - wstart += f * dW; - } - if(hend > iH){ - int f = sd::math::nd4j_ceil((Z) (hend-iH) / (Z) dH); - hend -= f * dH; - } - if(wend > iW){ - int f = sd::math::nd4j_ceil((Z) (wend-iW) / (Z) dW); - wend -= f * dW; - } - - //Accounts for dilation - int pool_size = sd::math::nd4j_ceil((double) (hend-hstart) / (double) dH) * sd::math::nd4j_ceil((double) (wend-wstart) / (double) dW); - - Z sum = 0.0f; - - const X *inSlice = x + (n * strideB + c * strideC); - - for (int h = hstart; h < hend; h += dH) - for (int w = wstart; w < wend; w += dW) - sum += static_cast(inSlice[h * strideY + w * strideX]); - - int divide_factor = pool_size; //Case 0: exclude padding - if (extraParam0 == 1) //Case 1: include padding - divide_factor = kH * kW; - - z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = sum / static_cast(divide_factor); - } + z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = + sum / static_cast(divide_factor); + } } ////////////////////////////////////////////////////////////////////////// template -static void avgPooling2dCudaLauncher(sd::LaunchContext & block, const void *vx, const Nd4jLong *vxShapeInfo, void *vz, const Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - avgPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); +static void avgPooling2dCudaLauncher(sd::LaunchContext &block, const void *vx, + const Nd4jLong *vxShapeInfo, void *vz, + const Nd4jLong *vzShapeInfo, const int kH, + const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, + const int dW, const int extraParam0) { + avgPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>( + vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, + extraParam0); } ////////////////////////////////////////////////////////////////////////// template -static __global__ void pnormPooling2dCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - - // input is [bS, iC, iH, iW] - // output is [bS, iC, oH, oW] - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; - __shared__ bool fOrder; - - if (threadIdx.x == 0) { - bS = shape::sizeAt(xShapeInfo, 0); - iC = shape::sizeAt(xShapeInfo, 1); - oH = shape::sizeAt(zShapeInfo, 2); - oW = shape::sizeAt(zShapeInfo, 3); - iH = shape::sizeAt(xShapeInfo, 2); - iW = shape::sizeAt(xShapeInfo, 3); +static __global__ void pnormPooling2dCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, const int dW, + const int extraParam0) { + // input is [bS, iC, iH, iW] + // output is [bS, iC, oH, oW] + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, + strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; + __shared__ bool fOrder; + + if (threadIdx.x == 0) { + bS = shape::sizeAt(xShapeInfo, 0); + iC = shape::sizeAt(xShapeInfo, 1); + oH = shape::sizeAt(zShapeInfo, 2); + oW = shape::sizeAt(zShapeInfo, 3); + iH = shape::sizeAt(xShapeInfo, 2); + iW = shape::sizeAt(xShapeInfo, 3); + + strideB = shape::stride(xShapeInfo)[0]; + strideC = shape::stride(xShapeInfo)[1]; + strideY = shape::stride(xShapeInfo)[2]; + strideX = shape::stride(xShapeInfo)[3]; + + strideOB = shape::stride(zShapeInfo)[0]; + strideOC = shape::stride(zShapeInfo)[1]; + strideOY = shape::stride(zShapeInfo)[2]; + strideOX = shape::stride(zShapeInfo)[3]; + + length = shape::length(zShapeInfo); + + // Replace kernel H/W with *effective* kernel H/W accounting for dilatyon + kHEff = kH + (kH - 1) * (dH - 1); + kWEff = kW + (kW - 1) * (dW - 1); + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (int index = tid; index < length; index += blockDim.x * gridDim.x) { + const int pw = index % oW; + const int ph = (index / oW) % oH; + const int c = (index / oW / oH) % iC; + const int n = index / oW / oH / iC; + + int hstart = sH * ph - pH; + int wstart = sW * pw - pW; + int hend = hstart + kHEff; + int wend = wstart + kWEff; + + if (hstart < 0) { + int f = sd::math::nd4j_ceil((Z)-hstart / (Z)dH); + hstart += f * dH; + } + if (wstart < 0) { + int f = sd::math::nd4j_ceil((Z)-wstart / (Z)dW); + wstart += f * dW; + } + if (hend > iH) { + int f = sd::math::nd4j_ceil((Z)(hend - iH) / (Z)dH); + hend -= f * dH; + } + if (wend > iW) { + int f = sd::math::nd4j_ceil((Z)(wend - iW) / (Z)dW); + wend -= f * dW; + } + // Accounts for dilation + int pool_size = + sd::math::nd4j_ceil((double)(hend - hstart) / (double)dH) * + sd::math::nd4j_ceil((double)(wend - wstart) / (double)dW); - strideB = shape::stride(xShapeInfo)[0]; - strideC = shape::stride(xShapeInfo)[1]; - strideY = shape::stride(xShapeInfo)[2]; - strideX = shape::stride(xShapeInfo)[3]; + Z sum = 0.f; - strideOB = shape::stride(zShapeInfo)[0]; - strideOC = shape::stride(zShapeInfo)[1]; - strideOY = shape::stride(zShapeInfo)[2]; - strideOX = shape::stride(zShapeInfo)[3]; + const X *inSlice = x + (n * strideB + c * strideC); - length = shape::length(zShapeInfo); + for (int h = hstart; h < hend; h += dH) + for (int w = wstart; w < wend; w += dW) + sum += sd::math::nd4j_pow( + static_cast( + sd::math::nd4j_abs(inSlice[h * strideY + w * strideX])), + extraParam0); - //Replace kernel H/W with *effective* kernel H/W accounting for dilatyon - kHEff = kH + (kH-1)*(dH-1); - kWEff = kW + (kW-1)*(dW-1); - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (int index = tid; index < length; index += blockDim.x * gridDim.x) { - - const int pw = index % oW; - const int ph = (index / oW) % oH; - const int c = (index / oW / oH) % iC; - const int n = index / oW / oH / iC; - - int hstart = sH * ph - pH; - int wstart = sW * pw - pW; - int hend = hstart + kHEff; - int wend = wstart + kWEff; - - if (hstart < 0) { - int f = sd::math::nd4j_ceil((Z) -hstart / (Z) dH); - hstart += f * dH; - } - if (wstart < 0) { - int f = sd::math::nd4j_ceil((Z) -wstart / (Z) dW); - wstart += f * dW; - } - if (hend > iH) { - int f = sd::math::nd4j_ceil((Z) (hend - iH) / (Z) dH); - hend -= f * dH; - } - if (wend > iW) { - int f = sd::math::nd4j_ceil((Z) (wend - iW) / (Z) dW); - wend -= f * dW; - } - //Accounts for dilation - int pool_size = sd::math::nd4j_ceil((double) (hend - hstart) / (double) dH) * - sd::math::nd4j_ceil((double) (wend - wstart) / (double) dW); - - Z sum = 0.f; - - const X *inSlice = x + (n * strideB + c * strideC); - - for (int h = hstart; h < hend; h += dH) - for (int w = wstart; w < wend; w += dW) - sum += sd::math::nd4j_pow(static_cast(sd::math::nd4j_abs(inSlice[h * strideY + w * strideX])), extraParam0); - - z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = sd::math::nd4j_pow(sum, (Z) 1.0f / extraParam0); - } + z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = + sd::math::nd4j_pow(sum, (Z)1.0f / extraParam0); + } } ////////////////////////////////////////////////////////////////////////// template -static void pnormPooling2dCudaLauncher(sd::LaunchContext & block, const void *vx, const Nd4jLong *vxShapeInfo, void *vz, const Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - pnormPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); +static void pnormPooling2dCudaLauncher(sd::LaunchContext &block, const void *vx, + const Nd4jLong *vxShapeInfo, void *vz, + const Nd4jLong *vzShapeInfo, + const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW, + const int extraParam0) { + pnormPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>( + vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, + extraParam0); } ////////////////////////////////////////////////////////////////////////// template -static __global__ void maxPooling2dCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - - // input is [bS, iC, iH, iW] - // output is [bS, iC, oH, oW] - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; - __shared__ bool fOrder; - - if (threadIdx.x == 0) { - bS = shape::sizeAt(xShapeInfo, 0); - iC = shape::sizeAt(xShapeInfo, 1); - oH = shape::sizeAt(zShapeInfo, 2); - oW = shape::sizeAt(zShapeInfo, 3); - iH = shape::sizeAt(xShapeInfo, 2); - iW = shape::sizeAt(xShapeInfo, 3); - - strideB = shape::stride(xShapeInfo)[0]; - strideC = shape::stride(xShapeInfo)[1]; - strideY = shape::stride(xShapeInfo)[2]; - strideX = shape::stride(xShapeInfo)[3]; +static __global__ void maxPooling2dCuda( + const void *vx, const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, const int dW, + const int extraParam0) { + // input is [bS, iC, iH, iW] + // output is [bS, iC, oH, oW] + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ int bS, iC, oH, oW, iH, iW, strideB, strideC, strideY, strideX, + strideOB, strideOC, strideOY, strideOX, length, kHEff, kWEff; + __shared__ bool fOrder; + + if (threadIdx.x == 0) { + bS = shape::sizeAt(xShapeInfo, 0); + iC = shape::sizeAt(xShapeInfo, 1); + oH = shape::sizeAt(zShapeInfo, 2); + oW = shape::sizeAt(zShapeInfo, 3); + iH = shape::sizeAt(xShapeInfo, 2); + iW = shape::sizeAt(xShapeInfo, 3); + + strideB = shape::stride(xShapeInfo)[0]; + strideC = shape::stride(xShapeInfo)[1]; + strideY = shape::stride(xShapeInfo)[2]; + strideX = shape::stride(xShapeInfo)[3]; + + strideOB = shape::stride(zShapeInfo)[0]; + strideOC = shape::stride(zShapeInfo)[1]; + strideOY = shape::stride(zShapeInfo)[2]; + strideOX = shape::stride(zShapeInfo)[3]; + + length = shape::length(zShapeInfo); + + // Replace kernel H/W with *effective* kernel H/W accounting for dilatyon + kHEff = kH + (kH - 1) * (dH - 1); + kWEff = kW + (kW - 1) * (dW - 1); + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (int index = tid; index < length; index += blockDim.x * gridDim.x) { + const int pw = index % oW; + const int ph = (index / oW) % oH; + const int c = (index / oW / oH) % iC; + const int n = index / oW / oH / iC; + + int hstart = sH * ph - pH; + int wstart = sW * pw - pW; + int hend = hstart + kHEff; + int wend = wstart + kWEff; + + if (hstart < 0) { + int f = sd::math::nd4j_ceil((Z)-hstart / (Z)dH); + hstart += f * dH; + } + if (wstart < 0) { + int f = sd::math::nd4j_ceil((Z)-wstart / (Z)dW); + wstart += f * dW; + } + if (hend > iH) { + int f = sd::math::nd4j_ceil((Z)(hend - iH) / (Z)dH); + hend -= f * dH; + } + if (wend > iW) { + int f = sd::math::nd4j_ceil((Z)(wend - iW) / (Z)dW); + wend -= f * dW; + } + // Accounts for dilation + int pool_size = + sd::math::nd4j_ceil((double)(hend - hstart) / (double)dH) * + sd::math::nd4j_ceil((double)(wend - wstart) / (double)dW); - strideOB = shape::stride(zShapeInfo)[0]; - strideOC = shape::stride(zShapeInfo)[1]; - strideOY = shape::stride(zShapeInfo)[2]; - strideOX = shape::stride(zShapeInfo)[3]; + Z max = -sd::DataTypeUtils::max(); - length = shape::length(zShapeInfo); + const X *inSlice = x + (n * strideB + c * strideC); - //Replace kernel H/W with *effective* kernel H/W accounting for dilatyon - kHEff = kH + (kH-1)*(dH-1); - kWEff = kW + (kW-1)*(dW-1); - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (int index = tid; index < length; index += blockDim.x * gridDim.x) { - - const int pw = index % oW; - const int ph = (index / oW) % oH; - const int c = (index / oW / oH) % iC; - const int n = index / oW / oH / iC; - - int hstart = sH * ph - pH; - int wstart = sW * pw - pW; - int hend = hstart + kHEff; - int wend = wstart + kWEff; - - if(hstart < 0){ - int f = sd::math::nd4j_ceil((Z) -hstart / (Z)dH); - hstart += f * dH; - } - if(wstart < 0){ - int f = sd::math::nd4j_ceil((Z) -wstart / (Z) dW); - wstart += f * dW; - } - if(hend > iH){ - int f = sd::math::nd4j_ceil((Z) (hend-iH) / (Z) dH); - hend -= f * dH; - } - if(wend > iW){ - int f = sd::math::nd4j_ceil((Z) (wend-iW) / (Z) dW); - wend -= f * dW; - } - //Accounts for dilation - int pool_size = sd::math::nd4j_ceil((double) (hend-hstart) / (double) dH) * sd::math::nd4j_ceil((double) (wend-wstart) / (double) dW); - - Z max = -sd::DataTypeUtils::max(); - - const X *inSlice = x + (n * strideB + c * strideC); - - for (int h = hstart; h < hend; h += dH) { - for (int w = wstart; w < wend; w += dW) { - Z v = static_cast(inSlice[h * strideY + w * strideX]); - if (v > max) - max = v; - } - } - - z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = max; + for (int h = hstart; h < hend; h += dH) { + for (int w = wstart; w < wend; w += dW) { + Z v = static_cast(inSlice[h * strideY + w * strideX]); + if (v > max) max = v; + } } + + z[n * strideOB + c * strideOC + pw * strideOX + ph * strideOY] = max; + } } ////////////////////////////////////////////////////////////////////////// template -static void maxPooling2dCudaLauncher(sd::LaunchContext & block, const void *vx, const Nd4jLong *vxShapeInfo, void *vz, const Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { - maxPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); +static void maxPooling2dCudaLauncher(sd::LaunchContext &block, const void *vx, + const Nd4jLong *vxShapeInfo, void *vz, + const Nd4jLong *vzShapeInfo, const int kH, + const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, + const int dW, const int extraParam0) { + maxPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>( + vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, + extraParam0); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::pooling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) { - - if(!input.isActualOnDeviceSide()) input.syncToDevice(); - - switch (poolingMode) { - - case MAX_POOL: { - BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), maxPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); - } - break; - case AVG_POOL: { - BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), avgPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); - } - break; - case PNORM_POOL: { - BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), pnormPooling2dCudaLauncher, (*block.launchContext(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); - } - break; - default: - throw std::runtime_error("Pooling2D: Unknown PoolingType used"); - } - - output.tickWriteDevice(); - input.tickReadDevice(); - - auto result = cudaStreamSynchronize(*block.launchContext()->getCudaStream()); - if (result != 0) - throw cuda_exception::build("Pooling2D failed", result); +void ConvolutionUtils::pooling2d(sd::graph::Context &block, + const NDArray &input, NDArray &output, + const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW, + const PoolingType poolingMode, + const int extraParam0) { + if (!input.isActualOnDeviceSide()) input.syncToDevice(); + + switch (poolingMode) { + case MAX_POOL: { + BUILD_SINGLE_SELECTOR_TWICE( + input.dataType(), maxPooling2dCudaLauncher, + (*block.launchContext(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, + extraParam0), + FLOAT_TYPES); + } break; + case AVG_POOL: { + BUILD_SINGLE_SELECTOR_TWICE( + input.dataType(), avgPooling2dCudaLauncher, + (*block.launchContext(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, + extraParam0), + FLOAT_TYPES); + } break; + case PNORM_POOL: { + BUILD_SINGLE_SELECTOR_TWICE( + input.dataType(), pnormPooling2dCudaLauncher, + (*block.launchContext(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, + extraParam0), + FLOAT_TYPES); + } break; + default: + throw std::runtime_error("Pooling2D: Unknown PoolingType used"); + } + + output.tickWriteDevice(); + input.tickReadDevice(); + + auto result = cudaStreamSynchronize(*block.launchContext()->getCudaStream()); + if (result != 0) throw cuda_exception::build("Pooling2D failed", result); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2dBP.cu index 62f4787ddc02..398fd872f6c5 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2dBP.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling2dBP.cu @@ -19,170 +19,189 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -__global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - // x: input [bS, iC, iH, iW] - // y: gradO [bS, iC, oH, oW] - // z: gradI [bS, iC, iH, iW] -> gradI is output in this function - - const T* x = reinterpret_cast(vx); - const T* y = reinterpret_cast(vy); - T* z = reinterpret_cast(vz); - - Nd4jLong coord2, coord3; - __shared__ int rank, kHeff, kWeff, iH, iW, kProd; - __shared__ Nd4jLong yLen, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - yLen = shape::length(yShapeInfo); - rank = 4; - - kHeff = kH + (kH - 1) * (dH - 1); - kWeff = kW + (kW - 1) * (dW - 1); - - iH = xShapeInfo[3]; - iW = xShapeInfo[4]; - - kProd = kH * kW; - } - __syncthreads(); - - const auto yInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(yInd >= yLen) - return; - - auto coords = sharedMem + threadIdx.x * rank; - - shape::index2coords(yInd, yShapeInfo, coords); - - const auto yOffset = shape::getOffset(yShapeInfo, coords); - - int hstart = coords[2] * sH - pH; - int wstart = coords[3] * sW - pW; - int hend = hstart + kHeff; - int wend = wstart + kWeff; - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if(wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if(hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if(wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - - switch (poolingMode) { - - /*** max ***/ - case 0: { - coord2 = hstart; - coord3 = wstart; - - T max = -DataTypeUtils::max(); - for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { - for (coords[3] = wstart; coords[3] < wend; coords[3] += dW){ - T val = x[shape::getOffset(xShapeInfo, coords)]; - if (val > max) { - max = val; - coord2 = coords[2]; - coord3 = coords[3]; - } - } - } - coords[2] = coord2; - coords[3] = coord3; - auto zOffset = shape::getOffset(zShapeInfo, coords); - sd::math::atomics::nd4j_atomicAdd(&z[zOffset], y[yOffset]); - //z[zOffset] += y[yOffset]; +__global__ static void pooling2dBPCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + const int kH, const int kW, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, const int poolingMode, + const int extraParam0) { + // x: input [bS, iC, iH, iW] + // y: gradO [bS, iC, oH, oW] + // z: gradI [bS, iC, iH, iW] -> gradI is output in this function + + const T* x = reinterpret_cast(vx); + const T* y = reinterpret_cast(vy); + T* z = reinterpret_cast(vz); + + Nd4jLong coord2, coord3; + __shared__ int rank, kHeff, kWeff, iH, iW, kProd; + __shared__ Nd4jLong yLen, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + yLen = shape::length(yShapeInfo); + rank = 4; + + kHeff = kH + (kH - 1) * (dH - 1); + kWeff = kW + (kW - 1) * (dW - 1); + + iH = xShapeInfo[3]; + iW = xShapeInfo[4]; + + kProd = kH * kW; + } + __syncthreads(); + + const auto yInd = threadIdx.x + blockIdx.x * blockDim.x; + + if (yInd >= yLen) return; + + auto coords = sharedMem + threadIdx.x * rank; + + shape::index2coords(yInd, yShapeInfo, coords); + + const auto yOffset = shape::getOffset(yShapeInfo, coords); + + int hstart = coords[2] * sH - pH; + int wstart = coords[3] * sW - pW; + int hend = hstart + kHeff; + int wend = wstart + kWeff; + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + switch (poolingMode) { + /*** max ***/ + case 0: { + coord2 = hstart; + coord3 = wstart; + + T max = -DataTypeUtils::max(); + for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { + for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) { + T val = x[shape::getOffset(xShapeInfo, coords)]; + if (val > max) { + max = val; + coord2 = coords[2]; + coord3 = coords[3]; + } } - break; - - /*** avg ***/ - case 1: { - - T val = y[yOffset]; - - if (extraParam0 == 0) //Exclude padding - val /= sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(dH)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(dW)); //Accounts for dilation - else if (extraParam0 == 1) //Include padding - val /= kProd; - - for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) - for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) - sd::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(zShapeInfo, coords)], val); - } - break; - - /*** pnorm ***/ - case 2: { - - T sum = static_cast(0.); - T val = y[yOffset]; - - for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) - for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), extraParam0); - - val *= sd::math::nd4j_pow(sum, ((T)1.f - extraParam0) / extraParam0); - - for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { - for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) { - const auto xOffset = shape::getOffset(xShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); - sd::math::atomics::nd4j_atomicAdd(&z[zOffset], val * sd::math::nd4j_pow(sd::math::nd4j_abs(x[xOffset]), extraParam0 - 1.f) * sd::math::nd4j_sgn(x[xOffset])); - } - } + } + coords[2] = coord2; + coords[3] = coord3; + auto zOffset = shape::getOffset(zShapeInfo, coords); + sd::math::atomics::nd4j_atomicAdd(&z[zOffset], y[yOffset]); + // z[zOffset] += y[yOffset]; + } break; + + /*** avg ***/ + case 1: { + T val = y[yOffset]; + + if (extraParam0 == 0) // Exclude padding + val /= + sd::math::nd4j_ceil(static_cast(hend - hstart) / + static_cast(dH)) * + sd::math::nd4j_ceil( + static_cast(wend - wstart) / + static_cast(dW)); // Accounts for dilation + else if (extraParam0 == 1) // Include padding + val /= kProd; + + for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) + for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) + sd::math::atomics::nd4j_atomicAdd( + &z[shape::getOffset(zShapeInfo, coords)], val); + } break; + + /*** pnorm ***/ + case 2: { + T sum = static_cast(0.); + T val = y[yOffset]; + + for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) + for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), + extraParam0); + + val *= sd::math::nd4j_pow(sum, + ((T)1.f - extraParam0) / extraParam0); + + for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { + for (coords[3] = wstart; coords[3] < wend; coords[3] += dW) { + const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); + sd::math::atomics::nd4j_atomicAdd( + &z[zOffset], + val * + sd::math::nd4j_pow(sd::math::nd4j_abs(x[xOffset]), + extraParam0 - 1.f) * + sd::math::nd4j_sgn(x[xOffset])); } - break; - } + } + } break; + } } ////////////////////////////////////////////////////////////////////////// template -static void pooling2dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const int poolingMode, const int extraParam0) { - - pooling2dBPCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0); +static void pooling2dBPCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + pooling2dBPCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kH, kW, sH, sW, pH, pW, + dH, dW, poolingMode, extraParam0); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - // initial zeroing of gradI - gradI.nullify(); - - PointersManager manager(block.launchContext(), "pooling2dBP"); - - const int threadsPerBlock = 256; - const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); - NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); - - manager.synchronize(); +void ConvolutionUtils::pooling2dBP(sd::graph::Context& block, + const NDArray& input, const NDArray& gradO, + NDArray& gradI, const int kH, const int kW, + const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const int poolingMode, + const int extraParam0) { + // initial zeroing of gradI + gradI.nullify(); + + PointersManager manager(block.launchContext(), "pooling2dBP"); + + const int threadsPerBlock = 256; + const int blocksPerGrid = + (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), + gradO.specialBuffer(), gradO.specialShapeInfo(), + gradI.specialBuffer(), gradI.specialShapeInfo(), kH, + kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), + FLOAT_TYPES); + NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); + + manager.synchronize(); } -} -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3d.cu index 0a3bfc9b6119..5f69a30ba8c3 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3d.cu @@ -19,163 +19,177 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include #include +#include namespace sd { -namespace ops { - +namespace ops { ////////////////////////////////////////////////////////////////////////// template -__global__ static void pooling3dCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - // x input is [bS, iC, iD, iH, iW] - // z output is [bS, iC, oD, oH, oW] - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ int rank, kDeff, kHeff, kWeff, iD, iH, iW, kProd; - __shared__ Nd4jLong zLen, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - zLen = shape::length(zShapeInfo); - rank = 5; - - kDeff = kD + (kD - 1) * (dD - 1); - kHeff = kH + (kH - 1) * (dH - 1); - kWeff = kW + (kW - 1) * (dW - 1); - - iD = xShapeInfo[3]; - iH = xShapeInfo[4]; - iW = xShapeInfo[5]; - - kProd = kD * kH * kW; - } - __syncthreads(); - - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(zInd >= zLen) - return; - - auto coords = sharedMem + threadIdx.x * rank; - - shape::index2coords(zInd, zShapeInfo, coords); - - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - int dstart = coords[2] * sD - pD; - int hstart = coords[3] * sH - pH; - int wstart = coords[4] * sW - pW; - int dend = dstart + kDeff; - int hend = hstart + kHeff; - int wend = wstart + kWeff; - - if(dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if(wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if(dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if(hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if(wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - - switch (poolingMode) { - - /*** max ***/ - case 0: { - T max = -DataTypeUtils::max(); - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH){ - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { - T val = x[shape::getOffset(xShapeInfo, coords)]; - if (val > max) - max = val; - } - } - } - z[zOffset] = max; - } - break; - - /*** avg ***/ - case 1: { - T sum = static_cast(0.); - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) - sum += x[shape::getOffset(xShapeInfo, coords)]; - - if (extraParam0 == 0) { //Exclude padding - uint a = (dend - dstart) / dD + ((dend - dstart) % dD == 0 ? 0 : 1); - uint b = (hend - hstart) / dH + ((hend - hstart) % dH == 0 ? 0 : 1); - uint c = (wend - wstart) / dW + ((wend - wstart) % dW == 0 ? 0 : 1); - sum /= static_cast(a * b * c); // /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(dD)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(dH)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(dW)); //Accounts for dilation - } - else if (extraParam0 == 1) //Include padding - sum /= kProd; - - z[zOffset] = sum; - } - break; - - /*** pnorm ***/ - case 2: { - T sum = static_cast(0.); - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), extraParam0); - - sum = sd::math::nd4j_pow(sum, (T) 1.f / extraParam0); - - z[zOffset] = sum; +__global__ static void pooling3dCuda(const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int poolingMode, + const int extraParam0) { + // x input is [bS, iC, iD, iH, iW] + // z output is [bS, iC, oD, oH, oW] + + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank, kDeff, kHeff, kWeff, iD, iH, iW, kProd; + __shared__ Nd4jLong zLen, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + zLen = shape::length(zShapeInfo); + rank = 5; + + kDeff = kD + (kD - 1) * (dD - 1); + kHeff = kH + (kH - 1) * (dH - 1); + kWeff = kW + (kW - 1) * (dW - 1); + + iD = xShapeInfo[3]; + iH = xShapeInfo[4]; + iW = xShapeInfo[5]; + + kProd = kD * kH * kW; + } + __syncthreads(); + + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + + if (zInd >= zLen) return; + + auto coords = sharedMem + threadIdx.x * rank; + + shape::index2coords(zInd, zShapeInfo, coords); + + const auto zOffset = shape::getOffset(zShapeInfo, coords); + + int dstart = coords[2] * sD - pD; + int hstart = coords[3] * sH - pH; + int wstart = coords[4] * sW - pW; + int dend = dstart + kDeff; + int hend = hstart + kHeff; + int wend = wstart + kWeff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + switch (poolingMode) { + /*** max ***/ + case 0: { + T max = -DataTypeUtils::max(); + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) { + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { + T val = x[shape::getOffset(xShapeInfo, coords)]; + if (val > max) max = val; + } } - break; - } + } + z[zOffset] = max; + } break; + + /*** avg ***/ + case 1: { + T sum = static_cast(0.); + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) + sum += x[shape::getOffset(xShapeInfo, coords)]; + + if (extraParam0 == 0) { // Exclude padding + uint a = (dend - dstart) / dD + ((dend - dstart) % dD == 0 ? 0 : 1); + uint b = (hend - hstart) / dH + ((hend - hstart) % dH == 0 ? 0 : 1); + uint c = (wend - wstart) / dW + ((wend - wstart) % dW == 0 ? 0 : 1); + sum /= static_cast( + a * b * + c); // /= sd::math::nd4j_ceil(static_cast(dend - + // dstart) / static_cast(dD)) * + // sd::math::nd4j_ceil(static_cast(hend - + // hstart) / static_cast(dH)) * + // sd::math::nd4j_ceil(static_cast(wend - + // wstart) / static_cast(dW)); //Accounts for + // dilation + } else if (extraParam0 == 1) // Include padding + sum /= kProd; + + z[zOffset] = sum; + } break; + + /*** pnorm ***/ + case 2: { + T sum = static_cast(0.); + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), + extraParam0); + + sum = sd::math::nd4j_pow(sum, (T)1.f / extraParam0); + + z[zOffset] = sum; + } break; + } } ////////////////////////////////////////////////////////////////////////// template -static void pooling3dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const int poolingMode, const int extraParam0) { - - pooling3dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0); +static void pooling3dCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const int kD, const int kH, + const int kW, const int sD, const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + pooling3dCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, + dH, dW, poolingMode, extraParam0); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::pooling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - PointersManager manager(block.launchContext(), "pooling3d"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); +void ConvolutionUtils::pooling3d(sd::graph::Context& block, + const NDArray& input, NDArray& output, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + PointersManager manager(block.launchContext(), "pooling3d"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), pooling3dCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, + dW, poolingMode, extraParam0), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } - -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3dBP.cu index fd78bb80beb0..70fa1a43e0a0 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3dBP.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_pooling3dBP.cu @@ -19,184 +19,205 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -__global__ static void pooling3dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - // x: input [bS, iC, iD, iH, iW] - // y: gradO [bS, iC, oD, oH, oW] - // z: gradI [bS, iC, iD, iH, iW] -> gradI is output in this function - - - const T* x = reinterpret_cast(vx); - const T* y = reinterpret_cast(vy); - T* z = reinterpret_cast(vz); - - Nd4jLong coord2, coord3, coord4; - __shared__ int rank, kDeff, kHeff, kWeff, iD, iH, iW, kProd; - __shared__ Nd4jLong yLen, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - yLen = shape::length(yShapeInfo); - rank = 5; - - kDeff = kD + (kD - 1) * (dD - 1); - kHeff = kH + (kH - 1) * (dH - 1); - kWeff = kW + (kW - 1) * (dW - 1); - - iD = xShapeInfo[3]; - iH = xShapeInfo[4]; - iW = xShapeInfo[5]; - - kProd = kD * kH * kW; - } - __syncthreads(); - - const auto yInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(yInd >= yLen) - return; - - auto coords = sharedMem + threadIdx.x * rank; - - shape::index2coords(yInd, yShapeInfo, coords); - - const auto yOffset = shape::getOffset(yShapeInfo, coords); - - int dstart = coords[2] * sD - pD; - int hstart = coords[3] * sH - pH; - int wstart = coords[4] * sW - pW; - int dend = dstart + kDeff; - int hend = hstart + kHeff; - int wend = wstart + kWeff; - - if(dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if(wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if(dend > iD) - dend -= dD * ((dend - iD + dD - 1) / dD); - if(hend > iH) - hend -= dH * ((hend - iH + dH - 1) / dH); - if(wend > iW) - wend -= dW * ((wend - iW + dW - 1) / dW); - - - switch (poolingMode) { - - /*** max ***/ - case 0: { - - T max = -DataTypeUtils::max(); - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH){ - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { - T val = x[shape::getOffset(xShapeInfo, coords)]; - if (val > max) { - max = val; - coord2 = coords[2]; - coord3 = coords[3]; - coord4 = coords[4]; - } - } - } +__global__ static void pooling3dBPCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + const int kD, const int kH, const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, const int pW, const int dD, + const int dH, const int dW, const int poolingMode, const int extraParam0) { + // x: input [bS, iC, iD, iH, iW] + // y: gradO [bS, iC, oD, oH, oW] + // z: gradI [bS, iC, iD, iH, iW] -> gradI is output in this function + + const T* x = reinterpret_cast(vx); + const T* y = reinterpret_cast(vy); + T* z = reinterpret_cast(vz); + + Nd4jLong coord2, coord3, coord4; + __shared__ int rank, kDeff, kHeff, kWeff, iD, iH, iW, kProd; + __shared__ Nd4jLong yLen, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + yLen = shape::length(yShapeInfo); + rank = 5; + + kDeff = kD + (kD - 1) * (dD - 1); + kHeff = kH + (kH - 1) * (dH - 1); + kWeff = kW + (kW - 1) * (dW - 1); + + iD = xShapeInfo[3]; + iH = xShapeInfo[4]; + iW = xShapeInfo[5]; + + kProd = kD * kH * kW; + } + __syncthreads(); + + const auto yInd = threadIdx.x + blockIdx.x * blockDim.x; + + if (yInd >= yLen) return; + + auto coords = sharedMem + threadIdx.x * rank; + + shape::index2coords(yInd, yShapeInfo, coords); + + const auto yOffset = shape::getOffset(yShapeInfo, coords); + + int dstart = coords[2] * sD - pD; + int hstart = coords[3] * sH - pH; + int wstart = coords[4] * sW - pW; + int dend = dstart + kDeff; + int hend = hstart + kHeff; + int wend = wstart + kWeff; + + if (dstart < 0) dstart += dD * ((-dstart + dD - 1) / dD); + if (hstart < 0) hstart += dH * ((-hstart + dH - 1) / dH); + if (wstart < 0) wstart += dW * ((-wstart + dW - 1) / dW); + if (dend > iD) dend -= dD * ((dend - iD + dD - 1) / dD); + if (hend > iH) hend -= dH * ((hend - iH + dH - 1) / dH); + if (wend > iW) wend -= dW * ((wend - iW + dW - 1) / dW); + + switch (poolingMode) { + /*** max ***/ + case 0: { + T max = -DataTypeUtils::max(); + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) { + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { + T val = x[shape::getOffset(xShapeInfo, coords)]; + if (val > max) { + max = val; + coord2 = coords[2]; + coord3 = coords[3]; + coord4 = coords[4]; } - coords[2] = coord2; - coords[3] = coord3; - coords[4] = coord4; - sd::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(zShapeInfo, coords)], y[yOffset]); - } - break; - - /*** avg ***/ - case 1: { - - T val = y[yOffset]; - - if (extraParam0 == 0) //Exclude padding - val /= sd::math::nd4j_ceil(static_cast(dend - dstart) / static_cast(dD)) * sd::math::nd4j_ceil(static_cast(hend - hstart) / static_cast(dH)) * sd::math::nd4j_ceil(static_cast(wend - wstart) / static_cast(dW)); //Accounts for dilation - else if (extraParam0 == 1) //Include padding - val /= kProd; - - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) - sd::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(zShapeInfo, coords)], val); + } } - break; - - /*** pnorm ***/ - case 2: { - - T sum = static_cast(0.); - T val = y[yOffset]; - - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) - sum += sd::math::nd4j_pow(sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), extraParam0); - - val *= sd::math::nd4j_pow(sum, ((T)1.f - extraParam0) / extraParam0); - - for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { - for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) { - for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { - const auto xOffset = shape::getOffset(xShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); - sd::math::atomics::nd4j_atomicAdd(&z[zOffset], val * sd::math::nd4j_pow(sd::math::nd4j_abs(x[xOffset]), extraParam0 - 1.f) * sd::math::nd4j_sgn(x[xOffset])); - } - } - } + } + coords[2] = coord2; + coords[3] = coord3; + coords[4] = coord4; + sd::math::atomics::nd4j_atomicAdd( + &z[shape::getOffset(zShapeInfo, coords)], y[yOffset]); + } break; + + /*** avg ***/ + case 1: { + T val = y[yOffset]; + + if (extraParam0 == 0) // Exclude padding + val /= + sd::math::nd4j_ceil(static_cast(dend - dstart) / + static_cast(dD)) * + sd::math::nd4j_ceil(static_cast(hend - hstart) / + static_cast(dH)) * + sd::math::nd4j_ceil( + static_cast(wend - wstart) / + static_cast(dW)); // Accounts for dilation + else if (extraParam0 == 1) // Include padding + val /= kProd; + + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) + sd::math::atomics::nd4j_atomicAdd( + &z[shape::getOffset(zShapeInfo, coords)], val); + } break; + + /*** pnorm ***/ + case 2: { + T sum = static_cast(0.); + T val = y[yOffset]; + + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) + sum += sd::math::nd4j_pow( + sd::math::nd4j_abs(x[shape::getOffset(xShapeInfo, coords)]), + extraParam0); + + val *= sd::math::nd4j_pow(sum, + ((T)1.f - extraParam0) / extraParam0); + + for (coords[2] = dstart; coords[2] < dend; coords[2] += dD) { + for (coords[3] = hstart; coords[3] < hend; coords[3] += dH) { + for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) { + const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); + sd::math::atomics::nd4j_atomicAdd( + &z[zOffset], + val * + sd::math::nd4j_pow( + sd::math::nd4j_abs(x[xOffset]), extraParam0 - 1.f) * + sd::math::nd4j_sgn(x[xOffset])); + } } - break; - } + } + } break; + } } ////////////////////////////////////////////////////////////////////////// template -static void pooling3dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const int poolingMode, const int extraParam0) { - - pooling3dBPCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0); +static void pooling3dBPCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, const int dW, + const int poolingMode, const int extraParam0) { + pooling3dBPCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, + pD, pH, pW, dD, dH, dW, poolingMode, extraParam0); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::pooling3dBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - - // initial zeroing of gradI - gradI.nullify(); - - PointersManager manager(block.launchContext(), "pooling3dBP"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); - NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); - - manager.synchronize(); +void ConvolutionUtils::pooling3dBP(sd::graph::Context& block, + const NDArray& input, const NDArray& gradO, + NDArray& gradI, const int kD, const int kH, + const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, + const int dW, const int poolingMode, + const int extraParam0) { + // initial zeroing of gradI + gradI.nullify(); + + PointersManager manager(block.launchContext(), "pooling3dBP"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); + BUILD_SINGLE_SELECTOR( + input.dataType(), pooling3dBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), gradO.specialBuffer(), + gradO.specialShapeInfo(), gradI.specialBuffer(), + gradI.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + poolingMode, extraParam0), + FLOAT_TYPES); + NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); + + manager.synchronize(); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_sconv2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_sconv2d.cu index 3a9ed5364071..cd22dc728343 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_sconv2d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_sconv2d.cu @@ -22,52 +22,81 @@ #include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -static void sconv2d_(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { +static void sconv2d_(sd::graph::Context& block, const NDArray* input, + const NDArray* weightsDepth, const NDArray* weightsPoint, + const NDArray* bias, NDArray* output, const int kH, + const int kW, const int sH, const int sW, int pH, int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + // weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] + // bias [oC], oC = iC*mC if weightsPoint=nullptr + // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weightsDepth [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - // weightsPoint [1, 1, iC*mC, oC], [oC, iC*mC, 1, 1], [oC, 1, 1, iC*mC] - // bias [oC], oC = iC*mC if weightsPoint=nullptr - // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // paddingMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // paddingMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier, output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weightsDepth->sizeAt(indWmC); // channels multiplier - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weightsDepth->sizeAt(indWmC); // channels multiplier + NDArray* outputDepth = output; + if (weightsPoint) // if pointwise convolution is expected + outputDepth = + new NDArray(output->ordering(), + !isNCHW ? std::vector({bS, oH, oW, iC * mC}) + : std::vector({bS, iC * mC, oH, oW}), + input->dataType(), input->getContext()); - NDArray* outputDepth = output; - if(weightsPoint) // if pointwise convolution is expected - outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector({bS, oH, oW, iC*mC}) : std::vector({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); + // ----- perform depthwise convolution (if weightsPoint is absent then oC = + // iC*mC) ----- // + ConvolutionUtils::depthwiseConv2d( + block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, + kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); - // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // - ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, isNCHW, wFormat); - - // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // - if (weightsPoint) { - ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, paddingMode, isNCHW, wFormat); // in this case oH=iH, oW=iW - delete outputDepth; - } + // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // + if (weightsPoint) { + ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1, + 1, 1, 1, 0, 0, 1, 1, paddingMode, isNCHW, + wFormat); // in this case oH=iH, oW=iW + delete outputDepth; + } } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW, const int wFormat) { - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat), FLOAT_TYPES); +void ConvolutionUtils::sconv2d(sd::graph::Context& block, const NDArray* input, + const NDArray* weightsDepth, + const NDArray* weightsPoint, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, int pH, int pW, + const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), sconv2d_, + (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, + pH, pW, dH, dW, paddingMode, isNCHW, wFormat), + FLOAT_TYPES); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2d.cu index ee1fa892416f..988e1fc310e6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2d.cu @@ -19,79 +19,93 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -__global__ static void upsampling2dCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int factorH, const int factorW, const bool isNCHW) { - - // x has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - // z has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) +__global__ static void upsampling2dCuda(const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const int factorH, const int factorW, + const bool isNCHW) { + // x has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) + // z has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, + // factorW*iW, iC] (NHWC) - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); - __shared__ int rank, dimIH; - __shared__ Nd4jLong zLen, *sharedMem; + __shared__ int rank, dimIH; + __shared__ Nd4jLong zLen, *sharedMem; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - dimIH = isNCHW ? 2 : 1; - zLen = shape::length(zShapeInfo); - rank = 4; - } - __syncthreads(); + dimIH = isNCHW ? 2 : 1; + zLen = shape::length(zShapeInfo); + rank = 4; + } + __syncthreads(); - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - if(zInd >= zLen) - return; + if (zInd >= zLen) return; - auto coords = sharedMem + threadIdx.x * rank; + auto coords = sharedMem + threadIdx.x * rank; - shape::index2coords(zInd, zShapeInfo, coords); + shape::index2coords(zInd, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - coords[dimIH] /= factorH; - coords[dimIH + 1] /= factorW; + coords[dimIH] /= factorH; + coords[dimIH + 1] /= factorW; - const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto xOffset = shape::getOffset(xShapeInfo, coords); - z[zOffset] = x[xOffset]; + z[zOffset] = x[xOffset]; } ////////////////////////////////////////////////////////////////////////// template -static void upsampling2dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int factorH, const int factorW, const bool isNCHW) { - - upsampling2dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, factorH, factorW, isNCHW); +static void upsampling2dCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const int factorH, const int factorW, + const bool isNCHW) { + upsampling2dCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, factorH, factorW, isNCHW); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::upsampling2d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { - - PointersManager manager(block.launchContext(), "upsampling2d"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorH, factorW, isNCHW), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); +void ConvolutionUtils::upsampling2d(sd::graph::Context& block, + const NDArray& input, NDArray& output, + const int factorH, const int factorW, + const bool isNCHW) { + PointersManager manager(block.launchContext(), "upsampling2d"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), upsampling2dCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), factorH, factorW, isNCHW), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2dBP.cu index c6864c48a109..519380c2e35e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2dBP.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling2dBP.cu @@ -19,85 +19,98 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -__global__ static void upsampling2dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool isNCHW) { - - // x (gradO) has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - // z (gradI) has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) +__global__ static void upsampling2dBPCuda(const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const bool isNCHW) { + // x (gradO) has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, + // factorH*iH, factorW*iW, iC] (NHWC) z (gradI) has shape [bS, iC, iH, iW] + // (NCHW) or [bS, iH, iW, iC] (NHWC) - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); - __shared__ int rank, dimIH; - __shared__ uint factorH, factorW; - __shared__ Nd4jLong zLen, *sharedMem; + __shared__ int rank, dimIH; + __shared__ uint factorH, factorW; + __shared__ Nd4jLong zLen, *sharedMem; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - dimIH = isNCHW ? 2 : 1; - zLen = shape::length(zShapeInfo); - rank = 4; + dimIH = isNCHW ? 2 : 1; + zLen = shape::length(zShapeInfo); + rank = 4; - factorH = xShapeInfo[dimIH + 1] / zShapeInfo[dimIH + 1]; - factorW = xShapeInfo[dimIH + 2] / zShapeInfo[dimIH + 2]; - } - __syncthreads(); + factorH = xShapeInfo[dimIH + 1] / zShapeInfo[dimIH + 1]; + factorW = xShapeInfo[dimIH + 2] / zShapeInfo[dimIH + 2]; + } + __syncthreads(); - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - if(zInd >= zLen) - return; + if (zInd >= zLen) return; - auto coords = sharedMem + threadIdx.x * rank; + auto coords = sharedMem + threadIdx.x * rank; - shape::index2coords(zInd, zShapeInfo, coords); + shape::index2coords(zInd, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - z[zOffset] = 0; + z[zOffset] = 0; - const Nd4jLong zCoord2 = coords[dimIH] * factorH; - const Nd4jLong zCoord3 = coords[dimIH + 1] * factorW; + const Nd4jLong zCoord2 = coords[dimIH] * factorH; + const Nd4jLong zCoord3 = coords[dimIH + 1] * factorW; - for(coords[dimIH] = zCoord2; coords[dimIH] < zCoord2 + factorH; ++coords[dimIH]) - for(coords[dimIH + 1] = zCoord3; coords[dimIH + 1] < zCoord3 + factorW; ++coords[dimIH + 1]) - z[zOffset] += x[shape::getOffset(xShapeInfo, coords)]; + for (coords[dimIH] = zCoord2; coords[dimIH] < zCoord2 + factorH; + ++coords[dimIH]) + for (coords[dimIH + 1] = zCoord3; coords[dimIH + 1] < zCoord3 + factorW; + ++coords[dimIH + 1]) + z[zOffset] += x[shape::getOffset(xShapeInfo, coords)]; } ////////////////////////////////////////////////////////////////////////// template -static void upsampling2dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const bool isNCHW) { - - upsampling2dBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, isNCHW); +static void upsampling2dBPCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const bool isNCHW) { + upsampling2dBPCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, isNCHW); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::upsampling2dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - - PointersManager manager(block.launchContext(), "upsampling2d_bp"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&gradI}, {&gradO}); - BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), isNCHW), FLOAT_TYPES); - NDArray::registerSpecialUse({&gradI}, {&gradO}); - - manager.synchronize(); +void ConvolutionUtils::upsampling2dBP(sd::graph::Context& block, + const NDArray& gradO, NDArray& gradI, + const bool isNCHW) { + PointersManager manager(block.launchContext(), "upsampling2d_bp"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&gradI}, {&gradO}); + BUILD_SINGLE_SELECTOR( + gradI.dataType(), upsampling2dBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), gradO.specialBuffer(), + gradO.specialShapeInfo(), gradI.specialBuffer(), + gradI.specialShapeInfo(), isNCHW), + FLOAT_TYPES); + NDArray::registerSpecialUse({&gradI}, {&gradO}); + + manager.synchronize(); } -} -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3d.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3d.cu index 1acb4307f3d4..f774c7475edf 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3d.cu @@ -19,80 +19,94 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -__global__ static void upsampling3dCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - - // x has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - // z has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) +__global__ static void upsampling3dCuda(const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const int factorD, const int factorH, + const int factorW, const bool isNCDHW) { + // x has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) + // z has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, + // factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); - __shared__ int rank, dimID; - __shared__ Nd4jLong zLen, *sharedMem; + __shared__ int rank, dimID; + __shared__ Nd4jLong zLen, *sharedMem; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - dimID = isNCDHW ? 2 : 1; - zLen = shape::length(zShapeInfo); - rank = 5; - } - __syncthreads(); + dimID = isNCDHW ? 2 : 1; + zLen = shape::length(zShapeInfo); + rank = 5; + } + __syncthreads(); - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - if(zInd >= zLen) - return; + if (zInd >= zLen) return; - auto coords = sharedMem + threadIdx.x * rank; + auto coords = sharedMem + threadIdx.x * rank; - shape::index2coords(zInd, zShapeInfo, coords); + shape::index2coords(zInd, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - coords[dimID] /= factorD; - coords[dimID + 1] /= factorH; - coords[dimID + 2] /= factorW; + coords[dimID] /= factorD; + coords[dimID + 1] /= factorH; + coords[dimID + 2] /= factorW; - const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto xOffset = shape::getOffset(xShapeInfo, coords); - z[zOffset] = x[xOffset]; + z[zOffset] = x[xOffset]; } ////////////////////////////////////////////////////////////////////////// template -static void upsampling3dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - - upsampling3dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, factorD, factorH, factorW, isNCDHW); +static void upsampling3dCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const int factorD, const int factorH, + const int factorW, const bool isNCDHW) { + upsampling3dCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, factorD, factorH, factorW, isNCDHW); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::upsampling3d(sd::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - - PointersManager manager(block.launchContext(), "upsampling3d"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorD, factorH, factorW, isNCDHW), FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); +void ConvolutionUtils::upsampling3d(sd::graph::Context& block, + const NDArray& input, NDArray& output, + const int factorD, const int factorH, + const int factorW, const bool isNCDHW) { + PointersManager manager(block.launchContext(), "upsampling3d"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), upsampling3dCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), input.specialBuffer(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), factorD, factorH, factorW, isNCDHW), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } -} -} +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3dBP.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3dBP.cu index 5a1e08c07d1f..9d98eac3967f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3dBP.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_upsampling3dBP.cu @@ -19,89 +19,102 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// template -__global__ static void upsampling3dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool isNCDHW) { - - // x (gradO) has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - // z (gradI) has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) +__global__ static void upsampling3dBPCuda(const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const bool isNCDHW) { + // x (gradO) has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] + // (NDHWC) z (gradI) has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] + // (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); - __shared__ int rank, dimID; - __shared__ uint factorD, factorH, factorW; - __shared__ Nd4jLong zLen, *sharedMem; + __shared__ int rank, dimID; + __shared__ uint factorD, factorH, factorW; + __shared__ Nd4jLong zLen, *sharedMem; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - dimID = isNCDHW ? 2 : 1; - zLen = shape::length(zShapeInfo); - rank = 5; + dimID = isNCDHW ? 2 : 1; + zLen = shape::length(zShapeInfo); + rank = 5; - factorD = xShapeInfo[dimID + 1] / zShapeInfo[dimID + 1]; - factorH = xShapeInfo[dimID + 2] / zShapeInfo[dimID + 2]; - factorW = xShapeInfo[dimID + 3] / zShapeInfo[dimID + 3]; - } - __syncthreads(); + factorD = xShapeInfo[dimID + 1] / zShapeInfo[dimID + 1]; + factorH = xShapeInfo[dimID + 2] / zShapeInfo[dimID + 2]; + factorW = xShapeInfo[dimID + 3] / zShapeInfo[dimID + 3]; + } + __syncthreads(); - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - if(zInd >= zLen) - return; + if (zInd >= zLen) return; - auto coords = sharedMem + threadIdx.x * rank; + auto coords = sharedMem + threadIdx.x * rank; - shape::index2coords(zInd, zShapeInfo, coords); + shape::index2coords(zInd, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - z[zOffset] = 0; + z[zOffset] = 0; - const Nd4jLong zCoord2 = coords[dimID] * factorD; - const Nd4jLong zCoord3 = coords[dimID + 1] * factorH; - const Nd4jLong zCoord4 = coords[dimID + 2] * factorW; + const Nd4jLong zCoord2 = coords[dimID] * factorD; + const Nd4jLong zCoord3 = coords[dimID + 1] * factorH; + const Nd4jLong zCoord4 = coords[dimID + 2] * factorW; - for(coords[dimID] = zCoord2; coords[dimID] < zCoord2 + factorD; ++coords[dimID]) - for(coords[dimID + 1] = zCoord3; coords[dimID + 1] < zCoord3 + factorH; ++coords[dimID + 1]) - for(coords[dimID + 2] = zCoord4; coords[dimID + 2] < zCoord4 + factorW; ++coords[dimID + 2]) - z[zOffset] += x[shape::getOffset(xShapeInfo, coords)]; + for (coords[dimID] = zCoord2; coords[dimID] < zCoord2 + factorD; + ++coords[dimID]) + for (coords[dimID + 1] = zCoord3; coords[dimID + 1] < zCoord3 + factorH; + ++coords[dimID + 1]) + for (coords[dimID + 2] = zCoord4; coords[dimID + 2] < zCoord4 + factorW; + ++coords[dimID + 2]) + z[zOffset] += x[shape::getOffset(xShapeInfo, coords)]; } ////////////////////////////////////////////////////////////////////////// template -static void upsampling3dBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const bool isNCDHW) { - - upsampling3dBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, isNCDHW); +static void upsampling3dBPCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const bool isNCDHW) { + upsampling3dBPCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, isNCDHW); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::upsampling3dBP(sd::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW) { - - PointersManager manager(block.launchContext(), "upsampling3d_bp"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&gradI}, {&gradO}); - BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), isNCDHW), FLOAT_TYPES); - NDArray::registerSpecialUse({&gradI}, {&gradO}); - - manager.synchronize(); +void ConvolutionUtils::upsampling3dBP(sd::graph::Context& block, + const NDArray& gradO, NDArray& gradI, + const bool isNCDHW) { + PointersManager manager(block.launchContext(), "upsampling3d_bp"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&gradI}, {&gradO}); + BUILD_SINGLE_SELECTOR( + gradI.dataType(), upsampling3dBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), gradO.specialBuffer(), + gradO.specialShapeInfo(), gradI.specialBuffer(), + gradI.specialShapeInfo(), isNCDHW), + FLOAT_TYPES); + NDArray::registerSpecialUse({&gradI}, {&gradO}); + + manager.synchronize(); } - -} -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_vol2col.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_vol2col.cu index c2c5fb3efb86..12d987a47377 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions_vol2col.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions_vol2col.cu @@ -19,93 +19,120 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { -namespace ops { +namespace ops { ////////////////////////////////////////////////////////////////////////// -// vol [bS, iC, iD, iH, iW] is convoluted to col [bS, iC, kD, kH, kW, oD, oH, oW] +// vol [bS, iC, iD, iH, iW] is convoluted to col [bS, iC, kD, kH, kW, oD, oH, +// oW] template -static __global__ void vol2colCuda(const void* volume, const Nd4jLong* volShapeInfo, void* columns, const Nd4jLong* colShapeInfo, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - const T* vol = reinterpret_cast(volume); - T* col = reinterpret_cast(columns); - - __shared__ int colRank, volRank; - __shared__ Nd4jLong colLen, iD, iH, iW, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - volRank = 5; - colRank = 8; - - colLen = shape::length(colShapeInfo); - - iD = volShapeInfo[3]; - iH = volShapeInfo[4]; - iW = volShapeInfo[5]; - } - __syncthreads(); - - const auto colInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(colInd >= colLen) - return; - - auto coords = sharedMem + threadIdx.x * colRank; - - shape::index2coords(colInd, colShapeInfo, coords); - - // const auto colW = coords[7]; - // const auto colH = coords[6]; - // const auto colD = coords[5]; - // const auto kCol = coords[4]; - // const auto kRow = coords[3]; - // const auto kDep = coords[2]; - // const auto c = coords[1]; - // const auto b = coords[0]; - - const auto colOffset = shape::getOffset(colShapeInfo, coords); - - coords[2] = -pD + coords[2] * dD + coords[5] * sD; // const auto volDep = (-pD + kDep * dD) + colD * sD; - coords[3] = -pH + coords[3] * dH + coords[6] * sH; // const auto volRow = (-pH + kRow * dH) + colH * sH; - coords[4] = -pW + coords[4] * dW + coords[7] * sW; // const auto volCol = (-pW + kCol * dW) + colW * sW; - - if (static_cast(coords[2]) >= static_cast(iD) || static_cast(coords[3]) >= static_cast(iH) || static_cast(coords[4]) >= static_cast(iW)) - col[colOffset] = static_cast(0.); - else - col[colOffset] = vol[shape::getOffset(volShapeInfo, coords)]; +static __global__ void vol2colCuda(const void* volume, + const Nd4jLong* volShapeInfo, void* columns, + const Nd4jLong* colShapeInfo, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW) { + const T* vol = reinterpret_cast(volume); + T* col = reinterpret_cast(columns); + + __shared__ int colRank, volRank; + __shared__ Nd4jLong colLen, iD, iH, iW, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + volRank = 5; + colRank = 8; + + colLen = shape::length(colShapeInfo); + + iD = volShapeInfo[3]; + iH = volShapeInfo[4]; + iW = volShapeInfo[5]; + } + __syncthreads(); + + const auto colInd = threadIdx.x + blockIdx.x * blockDim.x; + + if (colInd >= colLen) return; + + auto coords = sharedMem + threadIdx.x * colRank; + + shape::index2coords(colInd, colShapeInfo, coords); + + // const auto colW = coords[7]; + // const auto colH = coords[6]; + // const auto colD = coords[5]; + // const auto kCol = coords[4]; + // const auto kRow = coords[3]; + // const auto kDep = coords[2]; + // const auto c = coords[1]; + // const auto b = coords[0]; + + const auto colOffset = shape::getOffset(colShapeInfo, coords); + + coords[2] = + -pD + coords[2] * dD + + coords[5] * sD; // const auto volDep = (-pD + kDep * dD) + colD * sD; + coords[3] = + -pH + coords[3] * dH + + coords[6] * sH; // const auto volRow = (-pH + kRow * dH) + colH * sH; + coords[4] = + -pW + coords[4] * dW + + coords[7] * sW; // const auto volCol = (-pW + kCol * dW) + colW * sW; + + if (static_cast(coords[2]) >= static_cast(iD) || + static_cast(coords[3]) >= static_cast(iH) || + static_cast(coords[4]) >= static_cast(iW)) + col[colOffset] = static_cast(0.); + else + col[colOffset] = vol[shape::getOffset(volShapeInfo, coords)]; } ////////////////////////////////////////////////////////////////////////// template -static void vol2colCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* volume, const Nd4jLong* volShapeInfo, - void* columns, const Nd4jLong* colShapeInfo, - const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - vol2colCuda<<>>(volume, volShapeInfo, columns, colShapeInfo, sD, sH, sW, pD, pH, pW, dD, dH, dW); +static void vol2colCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* volume, + const Nd4jLong* volShapeInfo, void* columns, + const Nd4jLong* colShapeInfo, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW) { + vol2colCuda<<>>( + volume, volShapeInfo, columns, colShapeInfo, sD, sH, sW, pD, pH, pW, dD, + dH, dW); } ////////////////////////////////////////////////////////////////////////// -void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - - PointersManager manager(block.launchContext(), "vol2col"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (col.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&col}, {&vol}); - BUILD_SINGLE_SELECTOR(vol.dataType(), vol2colCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), vol.specialBuffer(), vol.specialShapeInfo(), col.specialBuffer(), col.specialShapeInfo(), sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); - NDArray::registerSpecialUse({&col}, {&vol}); - - manager.synchronize(); +void ConvolutionUtils::vol2col(sd::graph::Context& block, const NDArray& vol, + NDArray& col, const int sD, const int sH, + const int sW, const int pD, const int pH, + const int pW, const int dD, const int dH, + const int dW) { + PointersManager manager(block.launchContext(), "vol2col"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (col.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&col}, {&vol}); + BUILD_SINGLE_SELECTOR( + vol.dataType(), vol2colCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + block.launchContext()->getCudaStream(), vol.specialBuffer(), + vol.specialShapeInfo(), col.specialBuffer(), col.specialShapeInfo(), sD, + sH, sW, pD, pH, pW, dD, dH, dW), + FLOAT_TYPES); + NDArray::registerSpecialUse({&col}, {&vol}); + + manager.synchronize(); } -} -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/cross.cu b/libnd4j/include/ops/declarable/helpers/cuda/cross.cu index 8de4f65fd690..ad7893ae8936 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/cross.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/cross.cu @@ -18,105 +18,114 @@ // @author Yurii Shyrma, created on 10.06.2019 // - -#include #include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -template +template __global__ static void crossCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo) { + void* vz, const Nd4jLong* zShapeInfo) { + __shared__ const T* x; + __shared__ const T* y; + __shared__ T* z; + __shared__ int rank, *sharedMem; + __shared__ Nd4jLong lenWithoutLastDim, totalThreads; - __shared__ const T* x; - __shared__ const T* y; - __shared__ T* z; - __shared__ int rank, *sharedMem; - __shared__ Nd4jLong lenWithoutLastDim, totalThreads; + if (threadIdx.x == 0) { + x = reinterpret_cast(vx); + y = reinterpret_cast(vy); + z = reinterpret_cast(vz); - if (threadIdx.x == 0) { - x = reinterpret_cast(vx); - y = reinterpret_cast(vy); - z = reinterpret_cast(vz); + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + totalThreads = gridDim.x * blockDim.x; - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - totalThreads = gridDim.x * blockDim.x; + rank = shape::rank(xShapeInfo); + lenWithoutLastDim = shape::length(xShapeInfo) / + xShapeInfo[rank]; // shape::length(xShapeInfo) / 3; + } + __syncthreads(); - rank = shape::rank(xShapeInfo); - lenWithoutLastDim = shape::length(xShapeInfo) / xShapeInfo[rank]; // shape::length(xShapeInfo) / 3; - } - __syncthreads(); + auto coords = sharedMem + threadIdx.x * rank; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto coords = sharedMem + threadIdx.x * rank; - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + for (uint i = tid; i < lenWithoutLastDim; i += totalThreads) { + shape::index2coords(i, rank - 1, xShapeInfo + 1, coords); - for (uint i = tid; i < lenWithoutLastDim; i += totalThreads) { + coords[rank - 1] = 0; - shape::index2coords(i, rank - 1, xShapeInfo + 1, coords); + auto xOffset = shape::getOffset(xShapeInfo, coords); + auto yOffset = shape::getOffset(yShapeInfo, coords); - coords[rank - 1] = 0; + const auto x0 = x[xOffset]; + const auto y0 = y[yOffset]; - auto xOffset = shape::getOffset(xShapeInfo, coords); - auto yOffset = shape::getOffset(yShapeInfo, coords); + xOffset += shape::stride(const_cast(xShapeInfo))[rank - 1]; + yOffset += shape::stride(const_cast(yShapeInfo))[rank - 1]; - const auto x0 = x[xOffset]; - const auto y0 = y[yOffset]; + const auto x1 = x[xOffset]; + const auto y1 = y[yOffset]; - xOffset += shape::stride(const_cast(xShapeInfo))[rank - 1]; - yOffset += shape::stride(const_cast(yShapeInfo))[rank - 1]; + xOffset += shape::stride(const_cast(xShapeInfo))[rank - 1]; + yOffset += shape::stride(const_cast(yShapeInfo))[rank - 1]; - const auto x1 = x[xOffset]; - const auto y1 = y[yOffset]; + const auto x2 = x[xOffset]; + const auto y2 = y[yOffset]; - xOffset += shape::stride(const_cast(xShapeInfo))[rank - 1]; - yOffset += shape::stride(const_cast(yShapeInfo))[rank - 1]; + auto zOffset = shape::getOffset(zShapeInfo, coords); + z[zOffset] = x1 * y2 - x2 * y1; - const auto x2 = x[xOffset]; - const auto y2 = y[yOffset]; + zOffset += shape::stride(const_cast(zShapeInfo))[rank - 1]; + z[zOffset] = x2 * y0 - x0 * y2; - auto zOffset = shape::getOffset(zShapeInfo, coords); - z[zOffset] = x1 * y2 - x2 * y1; - - zOffset += shape::stride(const_cast(zShapeInfo))[rank - 1]; - z[zOffset] = x2 * y0 - x0 * y2; - - zOffset += shape::stride(const_cast(zShapeInfo))[rank - 1]; - z[zOffset] = x0 * y1 - x1 * y0; - } + zOffset += shape::stride(const_cast(zShapeInfo))[rank - 1]; + z[zOffset] = x0 * y1 - x1 * y0; + } } -template -__host__ static void crossCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo) { - - crossCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); +template +__host__ static void crossCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo) { + crossCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } -BUILD_SINGLE_TEMPLATE(template void crossCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo), NUMERIC_TYPES); - - -void crossBatched(sd::LaunchContext* context, NDArray *x, NDArray *y, NDArray *z) { - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (x->lengthOf() / x->sizeAt(-1) + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = sizeof(int) * threadsPerBlock * x->rankOf() + 128; - - PointersManager manager(context, "cross"); - - NDArray::prepareSpecialUse({z}, {x, y}); - BUILD_SINGLE_SELECTOR(x->dataType(), crossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), x->specialBuffer(), x->specialShapeInfo(), y->specialBuffer(), y->specialShapeInfo(), z->specialBuffer(), z->specialShapeInfo()), NUMERIC_TYPES); - NDArray::registerSpecialUse({z}, {x, y}); - - manager.synchronize(); +BUILD_SINGLE_TEMPLATE(template void crossCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo), + NUMERIC_TYPES); + +void crossBatched(sd::LaunchContext* context, NDArray* x, NDArray* y, + NDArray* z) { + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (x->lengthOf() / x->sizeAt(-1) + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = sizeof(int) * threadsPerBlock * x->rankOf() + 128; + + PointersManager manager(context, "cross"); + + NDArray::prepareSpecialUse({z}, {x, y}); + BUILD_SINGLE_SELECTOR( + x->dataType(), crossCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + x->specialBuffer(), x->specialShapeInfo(), y->specialBuffer(), + y->specialShapeInfo(), z->specialBuffer(), z->specialShapeInfo()), + NUMERIC_TYPES); + NDArray::registerSpecialUse({z}, {x, y}); + + manager.synchronize(); } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu b/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu index 35d8bf0335ae..c9c87db35147 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu @@ -24,82 +24,106 @@ namespace sd { namespace ops { namespace helpers { - template - static _CUDA_G void depthToSpaceKernel(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int block_size, const bool isNHWC) { - auto input_ptr = reinterpret_cast(vx); - auto output_ptr = reinterpret_cast(vz); - - const int batch_size = shape::sizeAt(xShapeInfo, 0); - const int input_depth = isNHWC ? shape::sizeAt(xShapeInfo, 3) : shape::sizeAt(xShapeInfo, 1); - const int input_height = isNHWC ? shape::sizeAt(xShapeInfo, 1) : shape::sizeAt(xShapeInfo, 2); - const int input_width = isNHWC ? shape::sizeAt(xShapeInfo, 2) : shape::sizeAt(xShapeInfo, 3); - - const int output_depth = isNHWC ? shape::sizeAt(zShapeInfo, 3) : shape::sizeAt(zShapeInfo, 1); - const int output_height = isNHWC ? shape::sizeAt(zShapeInfo, 1) : shape::sizeAt(zShapeInfo, 2); - const int output_width = isNHWC ? shape::sizeAt(zShapeInfo, 2) : shape::sizeAt(zShapeInfo, 3); - - const int input_area = input_width * input_height; - const int input_depth_by_input_area = input_depth * input_area; - const int output_depth_by_input_height = output_depth * input_height; - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (isNHWC) { - const int total_count = batch_size * output_height * output_width * output_depth; - for (int out_idx = tid; out_idx < total_count; out_idx += blockDim.x * gridDim.x) { - const int d = out_idx % output_depth; - const int out_idx2 = out_idx / output_depth; - const int w = out_idx2 % output_width; - const int out_idx3 = out_idx2 / output_width; - const int h = out_idx3 % output_height; - const int b = out_idx3 / output_height; - - const int in_h = h / block_size; - const int offset_h = h % block_size; - const int in_w = w / block_size; - const int offset_w = w % block_size; - const int offset_d = (offset_h * block_size + offset_w) * output_depth; - const int in_d = d + offset_d; - const int inp_idx = in_d + input_depth * (in_w + input_width * (in_h + input_height * b)); - (output_ptr + out_idx)[0] = (input_ptr + inp_idx)[0]; - } - } else { - const int total_count = batch_size * input_depth_by_input_area; - - for (int input_idx = tid; input_idx < total_count; input_idx += blockDim.x * gridDim.x) { - const int n_bY_bX_oC_iY = input_idx / input_width; - const int iX = input_idx - n_bY_bX_oC_iY * input_width; - - const int n_bY_bX = n_bY_bX_oC_iY / output_depth_by_input_height; - const int oC_iY = n_bY_bX_oC_iY - n_bY_bX * output_depth_by_input_height; - - const int n_bY = n_bY_bX / block_size; - const int bX = n_bY_bX - n_bY * block_size; - - const int n = n_bY / block_size; - const int bY = n_bY - n * block_size; - - const int output_idx = bX + block_size * (iX + input_width * (bY + block_size * (oC_iY + n * output_depth_by_input_height))); - - (output_ptr + output_idx)[0] = (input_ptr + input_idx)[0]; - } - } +template +static _CUDA_G void depthToSpaceKernel(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + const int block_size, + const bool isNHWC) { + auto input_ptr = reinterpret_cast(vx); + auto output_ptr = reinterpret_cast(vz); + + const int batch_size = shape::sizeAt(xShapeInfo, 0); + const int input_depth = + isNHWC ? shape::sizeAt(xShapeInfo, 3) : shape::sizeAt(xShapeInfo, 1); + const int input_height = + isNHWC ? shape::sizeAt(xShapeInfo, 1) : shape::sizeAt(xShapeInfo, 2); + const int input_width = + isNHWC ? shape::sizeAt(xShapeInfo, 2) : shape::sizeAt(xShapeInfo, 3); + + const int output_depth = + isNHWC ? shape::sizeAt(zShapeInfo, 3) : shape::sizeAt(zShapeInfo, 1); + const int output_height = + isNHWC ? shape::sizeAt(zShapeInfo, 1) : shape::sizeAt(zShapeInfo, 2); + const int output_width = + isNHWC ? shape::sizeAt(zShapeInfo, 2) : shape::sizeAt(zShapeInfo, 3); + + const int input_area = input_width * input_height; + const int input_depth_by_input_area = input_depth * input_area; + const int output_depth_by_input_height = output_depth * input_height; + + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + + if (isNHWC) { + const int total_count = + batch_size * output_height * output_width * output_depth; + for (int out_idx = tid; out_idx < total_count; + out_idx += blockDim.x * gridDim.x) { + const int d = out_idx % output_depth; + const int out_idx2 = out_idx / output_depth; + const int w = out_idx2 % output_width; + const int out_idx3 = out_idx2 / output_width; + const int h = out_idx3 % output_height; + const int b = out_idx3 / output_height; + + const int in_h = h / block_size; + const int offset_h = h % block_size; + const int in_w = w / block_size; + const int offset_w = w % block_size; + const int offset_d = (offset_h * block_size + offset_w) * output_depth; + const int in_d = d + offset_d; + const int inp_idx = + in_d + input_depth * (in_w + input_width * (in_h + input_height * b)); + (output_ptr + out_idx)[0] = (input_ptr + inp_idx)[0]; } + } else { + const int total_count = batch_size * input_depth_by_input_area; + for (int input_idx = tid; input_idx < total_count; + input_idx += blockDim.x * gridDim.x) { + const int n_bY_bX_oC_iY = input_idx / input_width; + const int iX = input_idx - n_bY_bX_oC_iY * input_width; - template - static void __depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - depthToSpaceKernel<<<512, 512, 1024, *context->getCudaStream()>>>(input.specialBuffer(), input.specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC); - } + const int n_bY_bX = n_bY_bX_oC_iY / output_depth_by_input_height; + const int oC_iY = n_bY_bX_oC_iY - n_bY_bX * output_depth_by_input_height; + + const int n_bY = n_bY_bX / block_size; + const int bX = n_bY_bX - n_bY * block_size; - void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - auto xType = input.dataType(); + const int n = n_bY / block_size; + const int bY = n_bY - n * block_size; - NDArray::prepareSpecialUse({output}, {&input}); + const int output_idx = + bX + + block_size * + (iX + input_width * + (bY + block_size * + (oC_iY + n * output_depth_by_input_height))); - BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (context, input, output, block_size, isNHWC), LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {&input}); + (output_ptr + output_idx)[0] = (input_ptr + input_idx)[0]; } + } +} + +template +static void __depthToSpace(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC) { + depthToSpaceKernel<<<512, 512, 1024, *context->getCudaStream()>>>( + input.specialBuffer(), input.specialShapeInfo(), output->specialBuffer(), + output->specialShapeInfo(), block_size, isNHWC); } + +void _depthToSpace(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC) { + auto xType = input.dataType(); + + NDArray::prepareSpecialUse({output}, {&input}); + + BUILD_SINGLE_SELECTOR(xType, __depthToSpace, + (context, input, output, block_size, isNHWC), + LIBND4J_TYPES); + NDArray::registerSpecialUse({output}, {&input}); } -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu b/libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu index ff217bdb6c80..c475d3ea58e8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu @@ -19,60 +19,72 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include namespace sd { namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template +template __global__ static void diGammaCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong len; - __shared__ bool sameOffset; - - if (threadIdx.x == 0) { - len = shape::length(xShapeInfo); - sameOffset = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - } - __syncthreads(); - - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += gridDim.x * blockDim.x) { - - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const auto zOffset = sameOffset ? xOffset : shape::getIndexOffset(i, zShapeInfo); - - z[zOffset] = diGammaScalar(x[xOffset]); - } + void *vz, const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong len; + __shared__ bool sameOffset; + + if (threadIdx.x == 0) { + len = shape::length(xShapeInfo); + sameOffset = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + } + __syncthreads(); + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < len; + i += gridDim.x * blockDim.x) { + const auto xOffset = shape::getIndexOffset(i, xShapeInfo); + const auto zOffset = + sameOffset ? xOffset : shape::getIndexOffset(i, zShapeInfo); + + z[zOffset] = diGammaScalar(x[xOffset]); + } } /////////////////////////////////////////////////////////////////// -template -static void diGammaCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - diGammaCuda<<>>(vx, xShapeInfo, vz, zShapeInfo); +template +static void diGammaCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + diGammaCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// -void diGamma(sd::LaunchContext* context, const NDArray& x, NDArray& z) { - - int threadsPerBlock = MAX_NUM_THREADS / 2; - int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({&z}, {&x}); - BUILD_SINGLE_SELECTOR(x.dataType(), diGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), x.specialBuffer(), x.specialShapeInfo(), z.specialBuffer(), z.specialShapeInfo()), FLOAT_TYPES); - NDArray::registerSpecialUse({&z}, {&x}); +void diGamma(sd::LaunchContext *context, const NDArray &x, NDArray &z) { + int threadsPerBlock = MAX_NUM_THREADS / 2; + int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&z}, {&x}); + BUILD_SINGLE_SELECTOR( + x.dataType(), diGammaCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + x.specialBuffer(), x.specialShapeInfo(), z.specialBuffer(), + z.specialShapeInfo()), + FLOAT_TYPES); + NDArray::registerSpecialUse({&z}, {&x}); } -BUILD_SINGLE_TEMPLATE(template void diGammaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES); - -} -} -} +BUILD_SINGLE_TEMPLATE(template void diGammaCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/diag.cu b/libnd4j/include/ops/declarable/helpers/cuda/diag.cu index f011f409526a..95bbd5b9893a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/diag.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/diag.cu @@ -28,110 +28,130 @@ namespace helpers { // diag functor cuda kernel // outputBuffer - output tensor buffer // outputShape - output tensor shape -// inputBuffer - input tensor buffer - this tensor should be placed on diagonal position of output -// inputShape - input tensor shape -// inputLength - length for input tensor +// inputBuffer - input tensor buffer - this tensor should be placed on diagonal +// position of output inputShape - input tensor shape inputLength - length for +// input tensor // template -static __global__ void diagFunctorKernel(void* outputBuffer, const Nd4jLong* outputShape, void const* inputBuffer, const Nd4jLong* inputShape, Nd4jLong inputLength) { - __shared__ T *z; - __shared__ T const* x; - __shared__ Nd4jLong outputLength; - - if (threadIdx.x == 0) { - z = reinterpret_cast(outputBuffer); - x = reinterpret_cast(inputBuffer); - - outputLength = shape::length(outputShape); - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - - for (int t = tid; t < inputLength; t += step) { // for all vals in input, put all on diagonal position to output - z[shape::getIndexOffset(t * (inputLength + 1), outputShape)] = x[shape::getIndexOffset(t, inputShape)]; //tX]; - } - +static __global__ void diagFunctorKernel(void* outputBuffer, + const Nd4jLong* outputShape, + void const* inputBuffer, + const Nd4jLong* inputShape, + Nd4jLong inputLength) { + __shared__ T* z; + __shared__ T const* x; + __shared__ Nd4jLong outputLength; + + if (threadIdx.x == 0) { + z = reinterpret_cast(outputBuffer); + x = reinterpret_cast(inputBuffer); + + outputLength = shape::length(outputShape); + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + for (int t = tid; t < inputLength; + t += + step) { // for all vals in input, put all on diagonal position to output + z[shape::getIndexOffset(t * (inputLength + 1), outputShape)] = + x[shape::getIndexOffset(t, inputShape)]; // tX]; + } } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // diag part functor cuda kernel // outputBuffer - output tensor buffer - linear sequence of diagonal values // outputShape - output tensor shape -// inputBuffer - input tensor buffer - this tensor should be placed on diagonal position of output -// inputShape - input tensor shape -// outputLength - given length of output -// inputLength - given length for input tensor +// inputBuffer - input tensor buffer - this tensor should be placed on diagonal +// position of output inputShape - input tensor shape outputLength - given +// length of output inputLength - given length for input tensor // - template - static __global__ void diagPartFunctorKernel(void* outputBuffer, const Nd4jLong* outputShape, void const* inputBuffer, const Nd4jLong* inputShape, Nd4jLong outputLength, Nd4jLong inputLength) { - __shared__ T *z; - __shared__ T const* x; - - if (threadIdx.x == 0) { - z = reinterpret_cast(outputBuffer); - x = reinterpret_cast(inputBuffer); - - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - Nd4jLong i = threadIdx.x * (outputLength + 1); // pos to diagonal value - for (int t = tid; t < outputLength && i < inputLength; t += step) { // loop by output, but input matrix may not be square - // put diagonal val from input onto output - z[shape::getIndexOffset(t, outputShape)] = x[shape::getIndexOffset(i, inputShape)]; - i += outputLength + 1; // shift to next diagonal value - } - } +template +static __global__ void diagPartFunctorKernel( + void* outputBuffer, const Nd4jLong* outputShape, void const* inputBuffer, + const Nd4jLong* inputShape, Nd4jLong outputLength, Nd4jLong inputLength) { + __shared__ T* z; + __shared__ T const* x; + + if (threadIdx.x == 0) { + z = reinterpret_cast(outputBuffer); + x = reinterpret_cast(inputBuffer); + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + Nd4jLong i = threadIdx.x * (outputLength + 1); // pos to diagonal value + for (int t = tid; t < outputLength && i < inputLength; + t += step) { // loop by output, but input matrix may not be square + // put diagonal val from input onto output + z[shape::getIndexOffset(t, outputShape)] = + x[shape::getIndexOffset(i, inputShape)]; + i += outputLength + 1; // shift to next diagonal value + } +} ////////////////////////////////////////////////////////////////////////// // Returns a batched matrix tensor with new batched diagonal values. -// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag - template - static void _diagFunctor(sd::LaunchContext * context, const NDArray* input, NDArray* output) { - auto stream = context->getCudaStream(); - auto inputLength = input->lengthOf(); - dim3 launchDims(256, 512, 8192); - if (!input->isActualOnDeviceSide()) - input->syncToDevice(); - diagFunctorKernel<<>>(output->specialBuffer(), output->specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), inputLength); - } +// for detailed explanations please take a look on web page: +// https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag +template +static void _diagFunctor(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + auto stream = context->getCudaStream(); + auto inputLength = input->lengthOf(); + dim3 launchDims(256, 512, 8192); + if (!input->isActualOnDeviceSide()) input->syncToDevice(); + diagFunctorKernel<<>>( + output->specialBuffer(), output->specialShapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), inputLength); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // diagFunctor - caller for diag functor processor - void diagFunctor(sd::LaunchContext * context, const NDArray* input, NDArray* output) { - auto xType = input->dataType(); +void diagFunctor(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + auto xType = input->dataType(); - BUILD_SINGLE_SELECTOR(xType, _diagFunctor, (context, input, output), LIBND4J_TYPES); - } + BUILD_SINGLE_SELECTOR(xType, _diagFunctor, (context, input, output), + LIBND4J_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void _diagFunctor, (sd::LaunchContext * context, const NDArray* input, NDArray* output);, LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void _diagFunctor, + (sd::LaunchContext * context, const NDArray* input, + NDArray* output); + , LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // diagPartFunctor - caller for diag part functor kernel - template - void _diagPartFunctor(sd::LaunchContext * context, NDArray const* input, NDArray* output) { - const int outLen = output->lengthOf(); - const int inLen = input->lengthOf(); - auto stream = context->getCudaStream(); - - dim3 launchDims(256, 512, 8192); - if (!input->isActualOnDeviceSide()) - input->syncToDevice(); - - diagPartFunctorKernel<<>>(output->specialBuffer(), output->specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), outLen, inLen); - } +template +void _diagPartFunctor(sd::LaunchContext* context, NDArray const* input, + NDArray* output) { + const int outLen = output->lengthOf(); + const int inLen = input->lengthOf(); + auto stream = context->getCudaStream(); + + dim3 launchDims(256, 512, 8192); + if (!input->isActualOnDeviceSide()) input->syncToDevice(); + + diagPartFunctorKernel + <<>>( + output->specialBuffer(), output->specialShapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), outLen, inLen); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // diagPartFunctor - caller for diag part functor processor - void diagPartFunctor(sd::LaunchContext * context, NDArray const* input, NDArray* output) { - auto zType = output->dataType(); - BUILD_SINGLE_SELECTOR(zType, _diagPartFunctor, (context, input, output), NUMERIC_TYPES); - - } - +void diagPartFunctor(sd::LaunchContext* context, NDArray const* input, + NDArray* output) { + auto zType = output->dataType(); + BUILD_SINGLE_SELECTOR(zType, _diagPartFunctor, (context, input, output), + NUMERIC_TYPES); } -} -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu index 0d25552c93a5..2b7cbcc98afa 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu @@ -18,117 +18,125 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// template -__global__ static void dilation2dCuda(const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW) { +__global__ static void dilation2dCuda(const void* vx, + const Nd4jLong* xShapeInfo, + const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW) { + // x [bS, iH, iW, iC] + // y [kH, kW, iC] + // z [bS, oH, oW, iC] - // x [bS, iH, iW, iC] - // y [kH, kW, iC] - // z [bS, oH, oW, iC] + const X* x = reinterpret_cast(vx); + const X* y = reinterpret_cast(vy); + Z* z = reinterpret_cast(vz); - const X* x = reinterpret_cast(vx); - const X* y = reinterpret_cast(vy); - Z* z = reinterpret_cast(vz); + __shared__ int xzRank, yRank, *sharedMem; + __shared__ uint iH, iW, kH, kW; + __shared__ Nd4jLong zLen; - __shared__ int xzRank, yRank, *sharedMem; - __shared__ uint iH, iW, kH, kW; - __shared__ Nd4jLong zLen; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + zLen = shape::length(zShapeInfo); - zLen = shape::length(zShapeInfo); + xzRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); - xzRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); + iH = xShapeInfo[2]; + iW = xShapeInfo[3]; - iH = xShapeInfo[2]; - iW = xShapeInfo[3]; - - kH = yShapeInfo[1]; - kW = yShapeInfo[2]; - } - __syncthreads(); + kH = yShapeInfo[1]; + kW = yShapeInfo[2]; + } + __syncthreads(); - const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; + const auto zInd = threadIdx.x + blockIdx.x * blockDim.x; - if(zInd >= zLen) - return; + if (zInd >= zLen) return; - auto xzCoords = sharedMem + threadIdx.x * (xzRank + yRank); - auto yCoords = xzCoords + xzRank; + auto xzCoords = sharedMem + threadIdx.x * (xzRank + yRank); + auto yCoords = xzCoords + xzRank; - shape::index2coords(zInd, zShapeInfo, xzCoords); + shape::index2coords(zInd, zShapeInfo, xzCoords); - const auto zOffset = shape::getOffset(zShapeInfo, xzCoords); + const auto zOffset = shape::getOffset(zShapeInfo, xzCoords); - yCoords[2] = xzCoords[3]; // iC coordinate is same for x, y and z + yCoords[2] = xzCoords[3]; // iC coordinate is same for x, y and z - const auto oh = xzCoords[1]; - const auto ow = xzCoords[2]; + const auto oh = xzCoords[1]; + const auto ow = xzCoords[2]; - X max = -DataTypeUtils::max(); + X max = -DataTypeUtils::max(); - for (yCoords[0] = 0; yCoords[0] < kH; ++yCoords[0]) { - xzCoords[1] = oh * sH - pH + yCoords[0] * dH; - if (xzCoords[1] < 0 || xzCoords[1] >= iH) continue; + for (yCoords[0] = 0; yCoords[0] < kH; ++yCoords[0]) { + xzCoords[1] = oh * sH - pH + yCoords[0] * dH; + if (xzCoords[1] < 0 || xzCoords[1] >= iH) continue; - for (yCoords[1] = 0; yCoords[1] < kW; ++yCoords[1]) { - xzCoords[2] = ow * sW - pW + yCoords[1] * dW; - if(xzCoords[2] < 0 || xzCoords[2] >= iW) continue; + for (yCoords[1] = 0; yCoords[1] < kW; ++yCoords[1]) { + xzCoords[2] = ow * sW - pW + yCoords[1] * dW; + if (xzCoords[2] < 0 || xzCoords[2] >= iW) continue; - const X val = x[shape::getOffset(xShapeInfo, xzCoords)] + y[shape::getOffset(yShapeInfo, yCoords)]; - if (val > max) - max = val; - } - } + const X val = x[shape::getOffset(xShapeInfo, xzCoords)] + + y[shape::getOffset(yShapeInfo, yCoords)]; + if (val > max) max = val; + } + } - z[zOffset] = static_cast(max); + z[zOffset] = static_cast(max); } ////////////////////////////////////////////////////////////////////////// template -static void dilation2dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW) { - - dilation2dCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, sH, sW, pH, pW, dH, dW); +static void dilation2dCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW) { + dilation2dCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, sH, sW, pH, pW, dH, dW); } -void dilation2d(sd::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { - - PointersManager manager(context, "dilation2d"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = (weights->rankOf() + output->rankOf()) * sizeof(int) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({output}, {input, weights}); - BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), dilation2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), weights->specialBuffer(), weights->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), sH, sW, pH, pW, dH, dW), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input, weights}); - - manager.synchronize(); +void dilation2d(sd::LaunchContext* context, NDArray* input, NDArray* weights, + NDArray* output, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW) { + PointersManager manager(context, "dilation2d"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + (weights->rankOf() + output->rankOf()) * sizeof(int) * threadsPerBlock + + 128; + + NDArray::prepareSpecialUse({output}, {input, weights}); + BUILD_SINGLE_SELECTOR_TWICE( + input->dataType(), dilation2dCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input->specialBuffer(), input->specialShapeInfo(), + weights->specialBuffer(), weights->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), sH, sW, pH, pW, dH, + dW), + FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input, weights}); + + manager.synchronize(); } - -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu index 4e0fdb377e10..e35e50da000d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu @@ -18,258 +18,344 @@ // @author raver119@gmail.com // -#include +#include #include -#include +#include + #include -#include +#include namespace sd { namespace ops { namespace helpers { - template - static __global__ void dropoutSimpleKernel(void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, Nd4jLong const* outputShape, double probVal, int inLen, sd::graph::RandomGenerator* nodeRng) { - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - T const* input = reinterpret_cast(inputBuf); - T* output = reinterpret_cast(outputBuf); - - // trivial idea: loop through all elements, get independent probability for each element to be nullified - for (Nd4jLong e = 0; e < inLen; ++e) { - T val = nodeRng->relativeT(e, T(0.f), T(1.f)); - - // if probability is ok - we're saving scaled value - if (double(val) < probVal) - output[shape::getIndexOffset(e, outputShape)] = T(input[shape::getIndexOffset(e, inputShape)] / probVal); - } - } +template +static __global__ void dropoutSimpleKernel( + void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, + Nd4jLong const* outputShape, double probVal, int inLen, + sd::graph::RandomGenerator* nodeRng) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + T const* input = reinterpret_cast(inputBuf); + T* output = reinterpret_cast(outputBuf); + + // trivial idea: loop through all elements, get independent probability for + // each element to be nullified + for (Nd4jLong e = 0; e < inLen; ++e) { + T val = nodeRng->relativeT(e, T(0.f), T(1.f)); + + // if probability is ok - we're saving scaled value + if (double(val) < probVal) + output[shape::getIndexOffset(e, outputShape)] = + T(input[shape::getIndexOffset(e, inputShape)] / probVal); + } +} - template - static void dropoutSimple(sd::LaunchContext* context, NDArray const* input, NDArray* output, double probValue, int seed) { - sd::graph::RandomGenerator nodeRng(3019L, seed); - int inLen = input->lengthOf(); - sd::graph::RandomGenerator* dRandom; - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input}); - - auto err = cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator)); - if (err) { - throw cuda_exception::build("helpers::dropoutSimple: Cannot allocate device memory for random generator.", err); - } - err = cudaMemcpy(dRandom, &nodeRng, sizeof(sd::graph::RandomGenerator), cudaMemcpyHostToDevice); - if (err) { - throw cuda_exception::build("helpers::dropoutSimple: Cannot set up device memory for random generator.", err); - } - - dropoutSimpleKernel<<<128, 256, 1024, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), probValue, inLen, dRandom); - err = cudaFree(dRandom); - if (err) { - throw cuda_exception::build("helpers::dropoutSimple: Cannot deallocate device memory for random generator.", err); - } - NDArray::registerSpecialUse({output}, {input}); - } +template +static void dropoutSimple(sd::LaunchContext* context, NDArray const* input, + NDArray* output, double probValue, int seed) { + sd::graph::RandomGenerator nodeRng(3019L, seed); + int inLen = input->lengthOf(); + sd::graph::RandomGenerator* dRandom; + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input}); + + auto err = cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator)); + if (err) { + throw cuda_exception::build( + "helpers::dropoutSimple: Cannot allocate device memory for random " + "generator.", + err); + } + err = cudaMemcpy(dRandom, &nodeRng, sizeof(sd::graph::RandomGenerator), + cudaMemcpyHostToDevice); + if (err) { + throw cuda_exception::build( + "helpers::dropoutSimple: Cannot set up device memory for random " + "generator.", + err); + } + + dropoutSimpleKernel<<<128, 256, 1024, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), probValue, inLen, + dRandom); + err = cudaFree(dRandom); + if (err) { + throw cuda_exception::build( + "helpers::dropoutSimple: Cannot deallocate device memory for random " + "generator.", + err); + } + NDArray::registerSpecialUse({output}, {input}); +} - template - int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - - if (reduceShape == nullptr){ - dropoutSimple(context.launchContext(), input, output, probValue, seed); - } - else { - REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "dropout: Noise shape should be fittable to input"); - - std::vector dims(reduceShape->lengthOf()); - reduceShape->syncToHost(); // to ensure that follows are actual - bool fit = true; - - for( int i = 0; i < dims.size(); i++ ) { - if (fit) { - dims[i] = reduceShape->e(i); - for (int e = 0; e < input->rankOf(); ++e) - if (fit) - if (input->sizeAt(e) % dims[i]) { - fit = false; - } - } +template +int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, + NDArray* reduceShape, int seed, double probValue) { + if (reduceShape == nullptr) { + dropoutSimple(context.launchContext(), input, output, probValue, seed); + } else { + REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, + "dropout: Noise shape should be fittable to input"); + + std::vector dims(reduceShape->lengthOf()); + reduceShape->syncToHost(); // to ensure that follows are actual + bool fit = true; + + for (int i = 0; i < dims.size(); i++) { + if (fit) { + dims[i] = reduceShape->e(i); + for (int e = 0; e < input->rankOf(); ++e) + if (fit) + if (input->sizeAt(e) % dims[i]) { + fit = false; } - - // check dims to fit input - REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank."); - std::unique_ptr chunk(new NDArray('c', dims, output->dataType(), context.launchContext())); - chunk->assign(1.f); - - dropoutSimple(context.launchContext(), chunk.get(), chunk.get(), probValue, seed); - // broadcast chunk to full matrix - std::unique_ptr dropOutMultiplier(new NDArray(*input)); - dropOutMultiplier->assign(1.f); - - *dropOutMultiplier += *chunk; - - // FIXME: we could do this in one step, aren't we? - output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr); - } - - return Status::OK(); + } } - int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - auto xType = input->dataType(); - NDArray::prepareSpecialUse({output}, {input}); - - BUILD_SINGLE_SELECTOR(xType, return _dropOutFunctor, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES); + // check dims to fit input + REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank."); + std::unique_ptr chunk( + new NDArray('c', dims, output->dataType(), context.launchContext())); + chunk->assign(1.f); - NDArray::registerSpecialUse({output}, {input}); - } + dropoutSimple(context.launchContext(), chunk.get(), chunk.get(), + probValue, seed); + // broadcast chunk to full matrix + std::unique_ptr dropOutMultiplier(new NDArray(*input)); + dropOutMultiplier->assign(1.f); -/////////////////////////////////// backrpopagations /////////////////////////////////////////////// - template - static __global__ void dropoutBPKernel(void* outputBuf, Nd4jLong const* outputShape, void* gradOutBuf, Nd4jLong const* gradOutShape, double probValue) { - __shared__ T* output; - __shared__ T* input; - __shared__ int len; + *dropOutMultiplier += *chunk; - if (threadIdx.x == 0) { - len = shape::length(outputShape); - output = reinterpret_cast(outputBuf); - input = reinterpret_cast(gradOutBuf); - } - __syncthreads(); + // FIXME: we could do this in one step, aren't we? + output->assign( + *input * + *dropOutMultiplier); // input->applyPairwiseTransform(pairwise::Multiply, + // dropOutMultiplier.get(), output, nullptr); + } - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int e = tid; e < len; e += step) { - const auto zOffset = shape::getIndexOffset(e, outputShape); - - // if probability was non-zero on FF step, we'll scale grads back - if (output[zOffset] != T(0.)) - output[zOffset] = T(input[shape::getIndexOffset(e, gradOutShape)] / probValue); - - } - } - template - static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - // we're making additional FF run to see how probabilities played out with given seeds - int res = dropOutFunctor(context, input, output, reduceShape, seed, probValue); - auto stream = context.launchContext()->getCudaStream(); + return Status::OK(); +} - NDArray::prepareSpecialUse({output}, {input, gradOut}); +int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, + NDArray* reduceShape, int seed, double probValue) { + auto xType = input->dataType(); + NDArray::prepareSpecialUse({output}, {input}); - if (ND4J_STATUS_OK == res) - dropoutBPKernel<<<128, 256, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), probValue); + BUILD_SINGLE_SELECTOR(xType, return _dropOutFunctor, + (context, input, output, reduceShape, seed, probValue), + FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input, gradOut}); + NDArray::registerSpecialUse({output}, {input}); +} - return res; - } +/////////////////////////////////// backrpopagations +////////////////////////////////////////////////// +template +static __global__ void dropoutBPKernel(void* outputBuf, + Nd4jLong const* outputShape, + void* gradOutBuf, + Nd4jLong const* gradOutShape, + double probValue) { + __shared__ T* output; + __shared__ T* input; + __shared__ int len; + + if (threadIdx.x == 0) { + len = shape::length(outputShape); + output = reinterpret_cast(outputBuf); + input = reinterpret_cast(gradOutBuf); + } + __syncthreads(); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int e = tid; e < len; e += step) { + const auto zOffset = shape::getIndexOffset(e, outputShape); + + // if probability was non-zero on FF step, we'll scale grads back + if (output[zOffset] != T(0.)) + output[zOffset] = + T(input[shape::getIndexOffset(e, gradOutShape)] / probValue); + } +} +template +static int dropOutFunctorBP_(graph::Context& context, NDArray* input, + NDArray* gradOut, NDArray* output, + NDArray* reduceShape, int seed, double probValue) { + // we're making additional FF run to see how probabilities played out with + // given seeds + int res = + dropOutFunctor(context, input, output, reduceShape, seed, probValue); + auto stream = context.launchContext()->getCudaStream(); + + NDArray::prepareSpecialUse({output}, {input, gradOut}); + + if (ND4J_STATUS_OK == res) + dropoutBPKernel<<<128, 256, 1024, *stream>>>( + output->specialBuffer(), output->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), probValue); + + NDArray::registerSpecialUse({output}, {input, gradOut}); + + return res; +} - template - static __global__ void alphaDropoutSimpleKernel(void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, Nd4jLong const* outputShape, double probValue, double alpha, double alpha1, double beta, int inLen, sd::graph::RandomGenerator* nodeRng) { - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - T const* input = reinterpret_cast(inputBuf); - T* output = reinterpret_cast(outputBuf); - - for (auto e = tid; e < inLen; e += step) { - T val = nodeRng->relativeT(e, T(0.f), T(1.f)); - T xVal = input[shape::getIndexOffset(e, inputShape)]; - output[shape::getIndexOffset(e, outputShape)] = (val >= T(probValue) ? T(alpha * beta + alpha1) : T(alpha * (double)xVal + alpha1)); - } - } - template - static void alphaDropoutSimple(sd::LaunchContext* context, NDArray const* input, NDArray* output, int seed, double probValue, double alpha, double alpha1, double beta) { - sd::graph::RandomGenerator nodeRng(3019L, seed), *dRandom; - auto stream = context->getCudaStream(); - auto err = cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator)); - NDArray::prepareSpecialUse({output}, {input}); - if (err) { - throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot allocate device memory for random generator.", err); - } - err = cudaMemcpy(dRandom, &nodeRng, sizeof(sd::graph::RandomGenerator), cudaMemcpyHostToDevice); - if (err) { - throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot set up device memory for random generator.", err); - } - - alphaDropoutSimpleKernel<<<128, 256, 1024, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), probValue, alpha, alpha1, beta, output->lengthOf(), dRandom); - - err = cudaFree(dRandom); - if (err) { - throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot deallocate device memory for random generator.", err); - } - NDArray::registerSpecialUse({output}, {input}); - } +template +static __global__ void alphaDropoutSimpleKernel( + void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, + Nd4jLong const* outputShape, double probValue, double alpha, double alpha1, + double beta, int inLen, sd::graph::RandomGenerator* nodeRng) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + T const* input = reinterpret_cast(inputBuf); + T* output = reinterpret_cast(outputBuf); + + for (auto e = tid; e < inLen; e += step) { + T val = nodeRng->relativeT(e, T(0.f), T(1.f)); + T xVal = input[shape::getIndexOffset(e, inputShape)]; + output[shape::getIndexOffset(e, outputShape)] = + (val >= T(probValue) ? T(alpha * beta + alpha1) + : T(alpha * (double)xVal + alpha1)); + } +} +template +static void alphaDropoutSimple(sd::LaunchContext* context, NDArray const* input, + NDArray* output, int seed, double probValue, + double alpha, double alpha1, double beta) { + sd::graph::RandomGenerator nodeRng(3019L, seed), *dRandom; + auto stream = context->getCudaStream(); + auto err = cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator)); + NDArray::prepareSpecialUse({output}, {input}); + if (err) { + throw cuda_exception::build( + "helpers::alphaDropoutSimple: Cannot allocate device memory for random " + "generator.", + err); + } + err = cudaMemcpy(dRandom, &nodeRng, sizeof(sd::graph::RandomGenerator), + cudaMemcpyHostToDevice); + if (err) { + throw cuda_exception::build( + "helpers::alphaDropoutSimple: Cannot set up device memory for random " + "generator.", + err); + } + + alphaDropoutSimpleKernel<<<128, 256, 1024, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), probValue, alpha, + alpha1, beta, output->lengthOf(), dRandom); + + err = cudaFree(dRandom); + if (err) { + throw cuda_exception::build( + "helpers::alphaDropoutSimple: Cannot deallocate device memory for " + "random generator.", + err); + } + NDArray::registerSpecialUse({output}, {input}); +} - template - static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output, - NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - - if (reduceShape == nullptr){ - alphaDropoutSimple(context.launchContext(), input, output, seed, probValue, alpha, alpha1, beta); - } - else { - REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "dropout: Noise shape should be fittable to input"); - - std::vector dims(reduceShape->lengthOf()); - reduceShape->syncToHost(); // to ensure that follows are actual - bool fit = true; - - for( int i = 0; i < dims.size(); i++ ) { - if (fit) { - dims[i] = reduceShape->e(i); - for (int e = 0; e < input->rankOf(); ++e) - if (fit) - if (input->sizeAt(e) % dims[i]) { - fit = false; - } - } +template +static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, + NDArray* output, NDArray* reduceShape, int seed, + double probValue, double alpha, double alpha1, + double beta) { + if (reduceShape == nullptr) { + alphaDropoutSimple(context.launchContext(), input, output, seed, + probValue, alpha, alpha1, beta); + } else { + REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, + "dropout: Noise shape should be fittable to input"); + + std::vector dims(reduceShape->lengthOf()); + reduceShape->syncToHost(); // to ensure that follows are actual + bool fit = true; + + for (int i = 0; i < dims.size(); i++) { + if (fit) { + dims[i] = reduceShape->e(i); + for (int e = 0; e < input->rankOf(); ++e) + if (fit) + if (input->sizeAt(e) % dims[i]) { + fit = false; } + } + } - // check dims to fit input - REQUIRE_TRUE(fit, 0, "alpha_dropout: Noise shape should fit to input rank."); - std::unique_ptr chunk(new NDArray('c', dims, output->dataType(), context.launchContext())); - chunk->assign(1.f); - - alphaDropoutSimple(context.launchContext(), chunk.get(), chunk.get(), seed, probValue, alpha, alpha1, beta); - - // broadcast chunk to full matrix - std::unique_ptr dropOutMultiplier(new NDArray(*input)); - dropOutMultiplier->assign(1.f); - - *dropOutMultiplier += *chunk; + // check dims to fit input + REQUIRE_TRUE(fit, 0, + "alpha_dropout: Noise shape should fit to input rank."); + std::unique_ptr chunk( + new NDArray('c', dims, output->dataType(), context.launchContext())); + chunk->assign(1.f); - output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr); - } + alphaDropoutSimple(context.launchContext(), chunk.get(), chunk.get(), + seed, probValue, alpha, alpha1, beta); + // broadcast chunk to full matrix + std::unique_ptr dropOutMultiplier(new NDArray(*input)); + dropOutMultiplier->assign(1.f); - return Status::OK(); - } + *dropOutMultiplier += *chunk; - template - int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, - NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - - int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta); - if (res == ND4J_STATUS_OK) { - // FIXME: can we make it single-loop? - (*output) *= alpha; - (*output) *= (*gradOut); //->applyPairwiseTransform(gradOut, output, nullptr); - } - return res; - } + output->assign( + *input * + *dropOutMultiplier); // input->applyPairwiseTransform(pairwise::Multiply, + // dropOutMultiplier.get(), output, nullptr); + } - int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { - BUILD_SINGLE_SELECTOR(context.dataType(), return dropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue), FLOAT_TYPES); - } + return Status::OK(); +} - int alphaDropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctor_, (context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); - } +template +int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, + NDArray* gradOut, NDArray* output, + NDArray* reduceShape, int seed, double probValue, + double alpha, double alpha1, double beta) { + int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, + probValue, alpha, alpha1, beta); + if (res == ND4J_STATUS_OK) { + // FIXME: can we make it single-loop? + (*output) *= alpha; + (*output) *= (*gradOut); //->applyPairwiseTransform(gradOut, + //output, nullptr); + } + return res; +} - int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); - } +int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, + NDArray* output, NDArray* reduceShape, int seed, + double probValue) { + BUILD_SINGLE_SELECTOR( + context.dataType(), return dropOutFunctorBP_, + (context, input, gradOut, output, reduceShape, seed, probValue), + FLOAT_TYPES); +} +int alphaDropOutFunctor(graph::Context& context, NDArray* input, + NDArray* output, NDArray* reduceShape, int seed, + double probValue, double alpha, double alpha1, + double beta) { + BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctor_, + (context, input, output, reduceShape, seed, probValue, + alpha, alpha1, beta), + FLOAT_TYPES); } + +int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, + NDArray* gradOut, NDArray* output, + NDArray* reduceShape, int seed, double probValue, + double alpha, double alpha1, double beta) { + BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctorBP_, + (context, input, gradOut, output, reduceShape, seed, + probValue, alpha, alpha1, beta), + FLOAT_TYPES); } -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu index 6f29995d3ae5..6fb057518150 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu @@ -17,354 +17,429 @@ // // @author raver119@gmail.com // -#include -#include #include +#include +#include namespace sd { - namespace ops { - namespace helpers { - - - template - static _CUDA_G void dynamicPartitionScalarKernel(const void *vx, const Nd4jLong *xShapeInfo, const void *vi, const Nd4jLong *iShapeInfo, void **vz, Nd4jLong **zShapeInfos, const Nd4jLong numOutputs) { - auto x = reinterpret_cast(vx); - auto i = reinterpret_cast(vi); - auto xLength = shape::length(xShapeInfo); - auto iLength = shape::length(iShapeInfo); - - extern __shared__ char shmem[]; - __shared__ Y *rawIndices; - __shared__ Y *trueIndices; - - if (threadIdx.x == 0) { - rawIndices = reinterpret_cast(shmem); - trueIndices = rawIndices + blockDim.x; - } - __syncthreads(); - - // we run things in blocks, 1 partition per block of threads - for (Nd4jLong o = blockIdx.x; o < numOutputs; o += gridDim.x) { - auto z = reinterpret_cast(vz[o]); - - auto zShapeInfo = zShapeInfos[o]; - auto zLength = shape::length(zShapeInfo); - - // iLimit should be multiple of blockDim.x - auto iLimit = iLength <= blockDim.x ? blockDim.x : (iLength + (blockDim.x - (iLength % blockDim.x))); - int cnt = 0; - - for (Nd4jLong e = threadIdx.x; e < iLimit; e += blockDim.x) { - // load set of indices into shared memory - if (e < iLength) - rawIndices[threadIdx.x] = i[shape::getIndexOffset(e, iShapeInfo)]; - __syncthreads(); - - // now we need to find out where our actual updates will be mapped - // TODO: this can be improved obviously, by using prefix-sum like approach - if (threadIdx.x == 0) { - for (int f = 0; f < blockDim.x; f++) { - if (rawIndices[f] == static_cast(o)) - trueIndices[f] = cnt++; - else - trueIndices[f] = -1; - } - } - __syncthreads(); - - - // doing actual update - if (e < iLength) - if (trueIndices[threadIdx.x] >= 0) { - z[trueIndices[threadIdx.x]] = x[shape::getIndexOffset(e, xShapeInfo)]; - } - - __syncthreads(); - } - } - } - - template - static _CUDA_G void dynamicPartitionTadKernel(const void *vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, Nd4jLong xLength, const void *vindices, const Nd4jLong *iShapeInfo, Nd4jLong iLength, void **vz, Nd4jLong **zTadShapeInfos, Nd4jLong **zTadOffsets, Nd4jLong numOutputs) { - auto x = reinterpret_cast(vx); - auto indices = reinterpret_cast(vindices); - - // we run things in blocks, 1 partition per block of threads - for (int i = blockIdx.x; i < numOutputs; i += gridDim.x) { - auto z = reinterpret_cast(vz[i]); - - // each thread has own counter for partitions - int outCnt = 0; - - for (Nd4jLong e = 0; e < iLength; e++) { - if (indices[shape::getIndexOffset(e, iShapeInfo)] == i) { - auto dx = x + xTadOffsets[e]; - auto dz = z + zTadOffsets[i][outCnt++]; - - for (int f = threadIdx.x; f < xLength; f += blockDim.x) { - dz[shape::getIndexOffset(f, zTadShapeInfos[i])] = dx[shape::getIndexOffset(f, xTadShapeInfo)]; - } - } - } - } - } - - template - static void _dynamicPartitionFunctor(sd::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector& outputList) { - std::vector> outputs(outputList.size()); - int sourceDimsLen = input->rankOf() - indices->rankOf(); - - unsigned int outSize = outputList.size(); - - PointersManager pm(context, "dynamicPartition"); - - if (sourceDimsLen) { // non-linear case - std::vector sourceDims(sourceDimsLen); - - for (int i = sourceDimsLen; i > 0; i--) - sourceDims[sourceDimsLen - i] = input->rankOf() - i; - //compute tad array for given dimensions - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), sourceDims); - - std::vector outBuffers(outSize); - std::vector tadShapes(outSize); - std::vector tadOffsets(outSize); - std::vector numTads(outSize); - // fill up dimensions array for before kernel - for (unsigned int i = 0; i < outSize; i++) { - outputs[i].first = outputList[i]; - std::vector outDims(outputs[i].first->rankOf() - 1); - - int r = outputs[i].first->rankOf(); - - for (int k = 1; k < r; k++) - outDims[k - 1] = k; - - auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(outputList.at(i)->shapeInfo(), outDims); - - outBuffers[i] = outputList.at(i)->specialBuffer(); - tadShapes[i] = packZ.platformShapeInfo(); - tadOffsets[i] = packZ.platformOffsets(); - } +namespace ops { +namespace helpers { + +template +static _CUDA_G void dynamicPartitionScalarKernel( + const void *vx, const Nd4jLong *xShapeInfo, const void *vi, + const Nd4jLong *iShapeInfo, void **vz, Nd4jLong **zShapeInfos, + const Nd4jLong numOutputs) { + auto x = reinterpret_cast(vx); + auto i = reinterpret_cast(vi); + auto xLength = shape::length(xShapeInfo); + auto iLength = shape::length(iShapeInfo); + + extern __shared__ char shmem[]; + __shared__ Y *rawIndices; + __shared__ Y *trueIndices; + + if (threadIdx.x == 0) { + rawIndices = reinterpret_cast(shmem); + trueIndices = rawIndices + blockDim.x; + } + __syncthreads(); + + // we run things in blocks, 1 partition per block of threads + for (Nd4jLong o = blockIdx.x; o < numOutputs; o += gridDim.x) { + auto z = reinterpret_cast(vz[o]); + + auto zShapeInfo = zShapeInfos[o]; + auto zLength = shape::length(zShapeInfo); + + // iLimit should be multiple of blockDim.x + auto iLimit = iLength <= blockDim.x + ? blockDim.x + : (iLength + (blockDim.x - (iLength % blockDim.x))); + int cnt = 0; + + for (Nd4jLong e = threadIdx.x; e < iLimit; e += blockDim.x) { + // load set of indices into shared memory + if (e < iLength) + rawIndices[threadIdx.x] = i[shape::getIndexOffset(e, iShapeInfo)]; + __syncthreads(); + + // now we need to find out where our actual updates will be mapped + // TODO: this can be improved obviously, by using prefix-sum like approach + if (threadIdx.x == 0) { + for (int f = 0; f < blockDim.x; f++) { + if (rawIndices[f] == static_cast(o)) + trueIndices[f] = cnt++; + else + trueIndices[f] = -1; + } + } + __syncthreads(); - // we copy pointers to device - auto dOutBuffers = reinterpret_cast(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *))); - auto dOutTadShapes = reinterpret_cast(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *))); - auto dOutTadOffsets = reinterpret_cast(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *))); - // run kernel on device - dynamicPartitionTadKernel<<<256, 256, 1024, *context->getCudaStream()>>>(input->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->specialBuffer(), indices->specialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize); - - } else { // linear case - auto numThreads = 256; - auto shmemSize = numThreads * sizeof(Y) * 2 + 1024; - - std::vector outBuffers; - std::vector outShapes; - - for (auto v:outputList) { - outBuffers.emplace_back(v->specialBuffer()); - outShapes.emplace_back(v->specialShapeInfo()); - } + // doing actual update + if (e < iLength) + if (trueIndices[threadIdx.x] >= 0) { + z[trueIndices[threadIdx.x]] = x[shape::getIndexOffset(e, xShapeInfo)]; + } - auto dOutBuffers = reinterpret_cast(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *))); - auto dOutShapes = reinterpret_cast(pm.replicatePointer(outShapes.data(), outShapes.size() * sizeof(Nd4jLong *))); + __syncthreads(); + } + } +} - dynamicPartitionScalarKernel<<<256, numThreads, shmemSize, *context->getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), dOutBuffers, dOutShapes, outSize); - } +template +static _CUDA_G void dynamicPartitionTadKernel( + const void *vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, + Nd4jLong xLength, const void *vindices, const Nd4jLong *iShapeInfo, + Nd4jLong iLength, void **vz, Nd4jLong **zTadShapeInfos, + Nd4jLong **zTadOffsets, Nd4jLong numOutputs) { + auto x = reinterpret_cast(vx); + auto indices = reinterpret_cast(vindices); + + // we run things in blocks, 1 partition per block of threads + for (int i = blockIdx.x; i < numOutputs; i += gridDim.x) { + auto z = reinterpret_cast(vz[i]); + + // each thread has own counter for partitions + int outCnt = 0; + + for (Nd4jLong e = 0; e < iLength; e++) { + if (indices[shape::getIndexOffset(e, iShapeInfo)] == i) { + auto dx = x + xTadOffsets[e]; + auto dz = z + zTadOffsets[i][outCnt++]; + + for (int f = threadIdx.x; f < xLength; f += blockDim.x) { + dz[shape::getIndexOffset(f, zTadShapeInfos[i])] = + dx[shape::getIndexOffset(f, xTadShapeInfo)]; + } + } + } + } +} - pm.synchronize(); - } +template +static void _dynamicPartitionFunctor(sd::LaunchContext *context, + NDArray const *input, + NDArray const *indices, + std::vector &outputList) { + std::vector> outputs(outputList.size()); + int sourceDimsLen = input->rankOf() - indices->rankOf(); + unsigned int outSize = outputList.size(); - template - static _CUDA_G void dynamicStitchScalarKernel(void **vx, Nd4jLong **xShapeInfos, void **vindices, Nd4jLong **iShapeInfos, int inputSize, void *vz, const Nd4jLong *zShapeInfo, Nd4jLong zLength) { - auto z = reinterpret_cast(vz); + PointersManager pm(context, "dynamicPartition"); - for (int e = blockIdx.x; e < inputSize; e += gridDim.x) { - auto x = reinterpret_cast(vx[e]); - auto indices = reinterpret_cast(vindices[e]); + if (sourceDimsLen) { // non-linear case + std::vector sourceDims(sourceDimsLen); - auto xShapeInfo = xShapeInfos[e]; - auto iShapeInfo = iShapeInfos[e]; + for (int i = sourceDimsLen; i > 0; i--) + sourceDims[sourceDimsLen - i] = input->rankOf() - i; + // compute tad array for given dimensions + auto packX = ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), sourceDims); - auto iLength = shape::length(iShapeInfo); + std::vector outBuffers(outSize); + std::vector tadShapes(outSize); + std::vector tadOffsets(outSize); + std::vector numTads(outSize); + // fill up dimensions array for before kernel + for (unsigned int i = 0; i < outSize; i++) { + outputs[i].first = outputList[i]; + std::vector outDims(outputs[i].first->rankOf() - 1); - for (int i = threadIdx.x; i < iLength; i += blockDim.x) { - auto idx = indices[shape::getIndexOffset(i, iShapeInfo)]; - if (idx >= 0 && idx < zLength) - z[shape::getIndexOffset(idx, zShapeInfo)] = x[shape::getIndexOffset(i, xShapeInfo)]; - } - } - } + int r = outputs[i].first->rankOf(); - template - static _CUDA_G void dynamicStitchTadKernel(void **vx, Nd4jLong **xTadShapeInfos, Nd4jLong **xTadOffsets, void **vindices, Nd4jLong **iShapeInfos, int inputSize, void *vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets) { - auto bz = reinterpret_cast(vz); + for (int k = 1; k < r; k++) outDims[k - 1] = k; - for (int e = blockIdx.x; e < inputSize; e += gridDim.x) { - auto indices = reinterpret_cast(vindices[e]); - auto iShapeInfo = iShapeInfos[e]; + auto packZ = ConstantTadHelper::getInstance()->tadForDimensions( + outputList.at(i)->shapeInfo(), outDims); - if (shape::isEmpty(iShapeInfo)) - continue; + outBuffers[i] = outputList.at(i)->specialBuffer(); + tadShapes[i] = packZ.platformShapeInfo(); + tadOffsets[i] = packZ.platformOffsets(); + } - auto iLength = shape::length(iShapeInfo); - auto zLength = shape::length(zTadShapeInfo); + // we copy pointers to device + auto dOutBuffers = reinterpret_cast(pm.replicatePointer( + outBuffers.data(), outBuffers.size() * sizeof(void *))); + auto dOutTadShapes = reinterpret_cast(pm.replicatePointer( + tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *))); + auto dOutTadOffsets = reinterpret_cast(pm.replicatePointer( + tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *))); + // run kernel on device + dynamicPartitionTadKernel + <<<256, 256, 1024, *context->getCudaStream()>>>( + input->specialBuffer(), packX.platformShapeInfo(), + packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), + indices->specialBuffer(), indices->specialShapeInfo(), + indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, + outSize); + + } else { // linear case + auto numThreads = 256; + auto shmemSize = numThreads * sizeof(Y) * 2 + 1024; + + std::vector outBuffers; + std::vector outShapes; + + for (auto v : outputList) { + outBuffers.emplace_back(v->specialBuffer()); + outShapes.emplace_back(v->specialShapeInfo()); + } - auto xShapeInfo = xTadShapeInfos[e]; - auto xLength = shape::length(xShapeInfo); + auto dOutBuffers = reinterpret_cast(pm.replicatePointer( + outBuffers.data(), outBuffers.size() * sizeof(void *))); + auto dOutShapes = reinterpret_cast(pm.replicatePointer( + outShapes.data(), outShapes.size() * sizeof(Nd4jLong *))); - for (int i = 0; i < iLength; i++) { - auto idx = indices[shape::getIndexOffset(i, iShapeInfo)]; + dynamicPartitionScalarKernel + <<<256, numThreads, shmemSize, *context->getCudaStream()>>>( + input->specialBuffer(), input->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), dOutBuffers, + dOutShapes, outSize); + } - auto z = bz + zTadOffsets[idx]; - auto x = reinterpret_cast(vx[e]) + xTadOffsets[e][i]; + pm.synchronize(); +} - for (int f = threadIdx.x; f < zLength; f += blockDim.x) { - z[shape::getIndexOffset(f, zTadShapeInfo)] = x[shape::getIndexOffset(f, xShapeInfo)]; - } +template +static _CUDA_G void dynamicStitchScalarKernel( + void **vx, Nd4jLong **xShapeInfos, void **vindices, Nd4jLong **iShapeInfos, + int inputSize, void *vz, const Nd4jLong *zShapeInfo, Nd4jLong zLength) { + auto z = reinterpret_cast(vz); - __syncthreads(); - } - } - } + for (int e = blockIdx.x; e < inputSize; e += gridDim.x) { + auto x = reinterpret_cast(vx[e]); + auto indices = reinterpret_cast(vindices[e]); - template - static int _dynamicStitchFunctor(sd::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray* output){ + auto xShapeInfo = xShapeInfos[e]; + auto iShapeInfo = iShapeInfos[e]; - int inputSize = inputs.size(); + auto iLength = shape::length(iShapeInfo); - PointersManager pm(context, "dynamicStitch"); + for (int i = threadIdx.x; i < iLength; i += blockDim.x) { + auto idx = indices[shape::getIndexOffset(i, iShapeInfo)]; + if (idx >= 0 && idx < zLength) + z[shape::getIndexOffset(idx, zShapeInfo)] = + x[shape::getIndexOffset(i, xShapeInfo)]; + } + } +} - if (output->isVector()) { - std::vector inputBuffers(inputSize); - std::vector inputShapes(inputSize); - std::vector indicesBuffers(inputSize); - std::vector indicesShapes(inputSize); +template +static _CUDA_G void dynamicStitchTadKernel( + void **vx, Nd4jLong **xTadShapeInfos, Nd4jLong **xTadOffsets, + void **vindices, Nd4jLong **iShapeInfos, int inputSize, void *vz, + const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets) { + auto bz = reinterpret_cast(vz); - for (int e = 0; e < inputSize; e++) { - inputBuffers[e] = inputs.at(e)->specialBuffer(); - indicesBuffers[e] = indices.at(e)->specialBuffer(); + for (int e = blockIdx.x; e < inputSize; e += gridDim.x) { + auto indices = reinterpret_cast(vindices[e]); + auto iShapeInfo = iShapeInfos[e]; - inputShapes[e] = inputs.at(e)->specialShapeInfo(); - indicesShapes[e] = indices.at(e)->specialShapeInfo(); - } + if (shape::isEmpty(iShapeInfo)) continue; - // copying pointers to buffers to device - auto dInputBuffers = reinterpret_cast(pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *))); - auto dIndicesBuffers = reinterpret_cast(pm.replicatePointer(indicesBuffers.data(), inputSize * sizeof(void *))); - auto dInputShapes = reinterpret_cast(pm.replicatePointer(inputShapes.data(), inputSize * sizeof(Nd4jLong *))); - auto dIndicesShapes = reinterpret_cast(pm.replicatePointer(indicesShapes.data(), inputSize * sizeof(Nd4jLong *))); + auto iLength = shape::length(iShapeInfo); + auto zLength = shape::length(zTadShapeInfo); - dynamicStitchScalarKernel<<<256, 256, 1024, *context->getCudaStream()>>>(dInputBuffers, dInputShapes, dIndicesBuffers, dIndicesShapes, inputSize, output->specialBuffer(), output->specialShapeInfo(), output->lengthOf()); - } else { - std::vector restDims(output->rankOf() - 1); - for (int i = restDims.size(); i > 0; i--) - restDims[restDims.size() - i] = output->rankOf() - i; + auto xShapeInfo = xTadShapeInfos[e]; + auto xLength = shape::length(xShapeInfo); - auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), restDims); + for (int i = 0; i < iLength; i++) { + auto idx = indices[shape::getIndexOffset(i, iShapeInfo)]; - std::vector inputBuffers(inputSize); - std::vector inputTadShapes(inputSize); - std::vector inputTadOffsets(inputSize); + auto z = bz + zTadOffsets[idx]; + auto x = reinterpret_cast(vx[e]) + xTadOffsets[e][i]; - std::vector indicesBuffers(inputSize); - std::vector indicesShapes(inputSize); + for (int f = threadIdx.x; f < zLength; f += blockDim.x) { + z[shape::getIndexOffset(f, zTadShapeInfo)] = + x[shape::getIndexOffset(f, xShapeInfo)]; + } - for (int e = 0; e < inputSize; e++) { - std::vector sourceDims(inputs[e]->rankOf() - indices[e]->rankOf()); - for (int i = sourceDims.size(); i > 0; i--) - sourceDims[sourceDims.size() - i] = inputs[e]->rankOf() - i; + __syncthreads(); + } + } +} - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(inputs[e]->shapeInfo(), sourceDims); +template +static int _dynamicStitchFunctor(sd::LaunchContext *context, + std::vector const &inputs, + std::vector const &indices, + NDArray *output) { + int inputSize = inputs.size(); - indicesBuffers[e] = indices[e]->specialBuffer(); - indicesShapes[e] = indices[e]->specialShapeInfo(); + PointersManager pm(context, "dynamicStitch"); - inputBuffers[e] = inputs[e]->specialBuffer(); - inputTadShapes[e] = packX.platformShapeInfo(); - inputTadOffsets[e] = packX.platformOffsets(); - } + if (output->isVector()) { + std::vector inputBuffers(inputSize); + std::vector inputShapes(inputSize); + std::vector indicesBuffers(inputSize); + std::vector indicesShapes(inputSize); - // copying pointers to buffers to device - auto dInputBuffers = reinterpret_cast(pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *))); - auto dInputTadShapes = reinterpret_cast(pm.replicatePointer(inputTadShapes.data(), inputSize * sizeof(Nd4jLong *))); - auto dInputTadOffsets = reinterpret_cast(pm.replicatePointer(inputTadOffsets.data(), inputSize * sizeof(Nd4jLong *))); + for (int e = 0; e < inputSize; e++) { + inputBuffers[e] = inputs.at(e)->specialBuffer(); + indicesBuffers[e] = indices.at(e)->specialBuffer(); - auto dIndicesBuffers = reinterpret_cast(pm.replicatePointer(indicesBuffers.data(), inputSize * sizeof(void *))); - auto dIndicesShapes = reinterpret_cast(pm.replicatePointer(indicesShapes.data(), inputSize * sizeof(Nd4jLong *))); + inputShapes[e] = inputs.at(e)->specialShapeInfo(); + indicesShapes[e] = indices.at(e)->specialShapeInfo(); + } - dynamicStitchTadKernel<<<256, 256, 1024, *context->getCudaStream()>>>(dInputBuffers, dInputTadShapes, dInputTadOffsets, dIndicesBuffers, dIndicesShapes, inputSize, output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets()); - } + // copying pointers to buffers to device + auto dInputBuffers = reinterpret_cast( + pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *))); + auto dIndicesBuffers = reinterpret_cast( + pm.replicatePointer(indicesBuffers.data(), inputSize * sizeof(void *))); + auto dInputShapes = reinterpret_cast(pm.replicatePointer( + inputShapes.data(), inputSize * sizeof(Nd4jLong *))); + auto dIndicesShapes = reinterpret_cast(pm.replicatePointer( + indicesShapes.data(), inputSize * sizeof(Nd4jLong *))); + + dynamicStitchScalarKernel + <<<256, 256, 1024, *context->getCudaStream()>>>( + dInputBuffers, dInputShapes, dIndicesBuffers, dIndicesShapes, + inputSize, output->specialBuffer(), output->specialShapeInfo(), + output->lengthOf()); + } else { + std::vector restDims(output->rankOf() - 1); + for (int i = restDims.size(); i > 0; i--) + restDims[restDims.size() - i] = output->rankOf() - i; + + auto packZ = ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), restDims); + + std::vector inputBuffers(inputSize); + std::vector inputTadShapes(inputSize); + std::vector inputTadOffsets(inputSize); + + std::vector indicesBuffers(inputSize); + std::vector indicesShapes(inputSize); + + for (int e = 0; e < inputSize; e++) { + std::vector sourceDims(inputs[e]->rankOf() - indices[e]->rankOf()); + for (int i = sourceDims.size(); i > 0; i--) + sourceDims[sourceDims.size() - i] = inputs[e]->rankOf() - i; + + auto packX = ConstantTadHelper::getInstance()->tadForDimensions( + inputs[e]->shapeInfo(), sourceDims); + + indicesBuffers[e] = indices[e]->specialBuffer(); + indicesShapes[e] = indices[e]->specialShapeInfo(); + + inputBuffers[e] = inputs[e]->specialBuffer(); + inputTadShapes[e] = packX.platformShapeInfo(); + inputTadOffsets[e] = packX.platformOffsets(); + } - pm.synchronize(); + // copying pointers to buffers to device + auto dInputBuffers = reinterpret_cast( + pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *))); + auto dInputTadShapes = reinterpret_cast(pm.replicatePointer( + inputTadShapes.data(), inputSize * sizeof(Nd4jLong *))); + auto dInputTadOffsets = reinterpret_cast(pm.replicatePointer( + inputTadOffsets.data(), inputSize * sizeof(Nd4jLong *))); - return Status::OK(); - } + auto dIndicesBuffers = reinterpret_cast( + pm.replicatePointer(indicesBuffers.data(), inputSize * sizeof(void *))); + auto dIndicesShapes = reinterpret_cast(pm.replicatePointer( + indicesShapes.data(), inputSize * sizeof(Nd4jLong *))); - template - static void _dynamicPartitionFunctorBP(NDArray const* input, NDArray const* indices, std::vector const& inputGradientList, std::vector& outputList) { + dynamicStitchTadKernel<<<256, 256, 1024, *context->getCudaStream()>>>( + dInputBuffers, dInputTadShapes, dInputTadOffsets, dIndicesBuffers, + dIndicesShapes, inputSize, output->specialBuffer(), + packZ.platformShapeInfo(), packZ.platformOffsets()); + } - } + pm.synchronize(); - void dynamicPartitionFunctor(sd::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector& outputList) { - auto xType = input->dataType(); - auto yType = indices->dataType(); + return Status::OK(); +} - NDArray::prepareSpecialUse({}, {indices, input}); +template +static void _dynamicPartitionFunctorBP( + NDArray const *input, NDArray const *indices, + std::vector const &inputGradientList, + std::vector &outputList) {} - BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicPartitionFunctor, (context, input, indices, outputList), NUMERIC_TYPES, INDEXING_TYPES); +void dynamicPartitionFunctor(sd::LaunchContext *context, NDArray const *input, + NDArray const *indices, + std::vector &outputList) { + auto xType = input->dataType(); + auto yType = indices->dataType(); - NDArray::registerSpecialUse({}, {indices, input}); + NDArray::prepareSpecialUse({}, {indices, input}); - // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list - for (auto v:outputList) { - v->tickWriteDevice(); - } - } + BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicPartitionFunctor, + (context, input, indices, outputList), NUMERIC_TYPES, + INDEXING_TYPES); - template - static int _dynamicStitchFunctorBP(std::vector const& inputs, std::vector const& indices, NDArray const* gradInput, std::vector& outputList){ - throw std::runtime_error("Not umplemented yet"); - } + NDArray::registerSpecialUse({}, {indices, input}); - int dynamicStitchFunctor(sd::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray* output){ - auto xType = inputs.at(0)->dataType(); - auto yType = indices.at(0)->dataType(); + // TODO: it would be nice to have NDArray::registerSpecialUse signature that + // accepts something else beyond initializer_list + for (auto v : outputList) { + v->tickWriteDevice(); + } +} - for (auto v:indices) { - v->syncToDevice(); - v->tickReadDevice(); - } +template +static int _dynamicStitchFunctorBP(std::vector const &inputs, + std::vector const &indices, + NDArray const *gradInput, + std::vector &outputList) { + throw std::runtime_error("Not umplemented yet"); +} - for (auto v:inputs) { - v->syncToDevice(); - v->tickReadDevice(); - } +int dynamicStitchFunctor(sd::LaunchContext *context, + std::vector const &inputs, + std::vector const &indices, + NDArray *output) { + auto xType = inputs.at(0)->dataType(); + auto yType = indices.at(0)->dataType(); - NDArray::prepareSpecialUse({output}, {}); + for (auto v : indices) { + v->syncToDevice(); + v->tickReadDevice(); + } + for (auto v : inputs) { + v->syncToDevice(); + v->tickReadDevice(); + } - BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicStitchFunctor, (context, inputs, indices, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::prepareSpecialUse({output}, {}); - NDArray::registerSpecialUse({output}, {}); + BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicStitchFunctor, + (context, inputs, indices, output), NUMERIC_TYPES, + INDEXING_TYPES); - return Status::OK(); - } + NDArray::registerSpecialUse({output}, {}); - int dynamicStitchFunctorBP(sd::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray const* gradInput, std::vector& outputList) { - auto xType = inputs.at(0)->dataType(); + return Status::OK(); +} - BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctorBP, (inputs, indices, gradInput, outputList), NUMERIC_TYPES); - } +int dynamicStitchFunctorBP(sd::LaunchContext *context, + std::vector const &inputs, + std::vector const &indices, + NDArray const *gradInput, + std::vector &outputList) { + auto xType = inputs.at(0)->dataType(); - void dynamicPartitionFunctorBP(sd::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector const& inputGradientList, std::vector& outputList) { - auto xType = input->dataType(); + BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctorBP, + (inputs, indices, gradInput, outputList), + NUMERIC_TYPES); +} - BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctorBP, (input, indices, inputGradientList, outputList), NUMERIC_TYPES); - } +void dynamicPartitionFunctorBP(sd::LaunchContext *context, NDArray const *input, + NDArray const *indices, + std::vector const &inputGradientList, + std::vector &outputList) { + auto xType = input->dataType(); - } - } + BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctorBP, + (input, indices, inputGradientList, outputList), + NUMERIC_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/extract_patches.cu b/libnd4j/include/ops/declarable/helpers/cuda/extract_patches.cu index c5e8848cb0bd..377a734dc218 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/extract_patches.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/extract_patches.cu @@ -19,11 +19,12 @@ // @author sgazeos@gmail.com // -#include +#include #include #include +#include + #include -#include namespace sd { namespace ops { @@ -46,101 +47,120 @@ namespace helpers { // - outTadShape - output TAD shape // - outputOffsets - output TAD offsets // - template - static __global__ void globalExtractPatchesKernel(bool theSame, int batchCount, int sizeRow, int sizeCol, int rowDim, int colDim, int outRowDim, int outColDim, int strideRow, int strideCol, int rateRow, int rateCol, int rowCast, int colCast, int lastDim, const T* input, const Nd4jLong* patchShape, const Nd4jLong* inputOffsets, T* output, const Nd4jLong* outTadShape, const Nd4jLong* outputOffsets) { - - auto start = threadIdx.x + blockIdx.x * blockDim.x; - - auto step = blockDim.x * gridDim.x; - // batch input by 3 last dims and extrapole input onto output with outColDim/outRowDim - for (Nd4jLong batch = start; batch < batchCount; batch += step) { - auto patch = input + inputOffsets[batch];// listOfMatricies->at(batch); - auto outMatrix = output + outputOffsets[batch]; //listOfOutputs->at(batch); - - for (Nd4jLong i = 0; i < outRowDim; i++) { - for (Nd4jLong j = 0; j < outColDim; j++) { - Nd4jLong pos = 0; - auto rowStart = i * strideRow - (theSame?rowCast:0); - auto colStart = j * strideCol - (theSame?colCast:0); - auto rowEnd = rowStart + sizeRow * rateRow; - auto colEnd = colStart + sizeCol * rateCol; - if (!theSame) { - rowEnd = math::nd4j_min(rowStart + sizeRow * rateRow, Nd4jLong (rowDim)); - colEnd = math::nd4j_min(colStart + sizeCol * rateCol, Nd4jLong (colDim)); - } - - for (auto row = rowStart; row < rowEnd; row += rateRow) { - for (auto col = colStart; col < colEnd; col += rateCol) { - for (auto pixel = 0; pixel < lastDim; pixel++) { - Nd4jLong zPos[] = {i, j, pos}; - Nd4jLong xPos[] = {row, col, pixel}; - bool setUp = - (theSame && row >= 0 && col >= 0 && row < rowDim && col < colDim) || (!theSame); - - if (setUp) { // VALID or SAME cases - outMatrix[shape::getOffset(outTadShape, zPos)] = patch[shape::getOffset(patchShape, xPos)]; - } - pos++; - } - } - } - } - } +template +static __global__ void globalExtractPatchesKernel( + bool theSame, int batchCount, int sizeRow, int sizeCol, int rowDim, + int colDim, int outRowDim, int outColDim, int strideRow, int strideCol, + int rateRow, int rateCol, int rowCast, int colCast, int lastDim, + const T* input, const Nd4jLong* patchShape, const Nd4jLong* inputOffsets, + T* output, const Nd4jLong* outTadShape, const Nd4jLong* outputOffsets) { + auto start = threadIdx.x + blockIdx.x * blockDim.x; + + auto step = blockDim.x * gridDim.x; + // batch input by 3 last dims and extrapole input onto output with + // outColDim/outRowDim + for (Nd4jLong batch = start; batch < batchCount; batch += step) { + auto patch = input + inputOffsets[batch]; // listOfMatricies->at(batch); + auto outMatrix = output + outputOffsets[batch]; // listOfOutputs->at(batch); + + for (Nd4jLong i = 0; i < outRowDim; i++) { + for (Nd4jLong j = 0; j < outColDim; j++) { + Nd4jLong pos = 0; + auto rowStart = i * strideRow - (theSame ? rowCast : 0); + auto colStart = j * strideCol - (theSame ? colCast : 0); + auto rowEnd = rowStart + sizeRow * rateRow; + auto colEnd = colStart + sizeCol * rateCol; + if (!theSame) { + rowEnd = + math::nd4j_min(rowStart + sizeRow * rateRow, Nd4jLong(rowDim)); + colEnd = + math::nd4j_min(colStart + sizeCol * rateCol, Nd4jLong(colDim)); } + for (auto row = rowStart; row < rowEnd; row += rateRow) { + for (auto col = colStart; col < colEnd; col += rateCol) { + for (auto pixel = 0; pixel < lastDim; pixel++) { + Nd4jLong zPos[] = {i, j, pos}; + Nd4jLong xPos[] = {row, col, pixel}; + bool setUp = (theSame && row >= 0 && col >= 0 && row < rowDim && + col < colDim) || + (!theSame); + + if (setUp) { // VALID or SAME cases + outMatrix[shape::getOffset(outTadShape, zPos)] = + patch[shape::getOffset(patchShape, xPos)]; + } + pos++; + } + } + } + } } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static void _extractPatches(sd::LaunchContext * context, NDArray* images, NDArray* output, int sizeRow, int sizeCol, int strideRow, int strideCol, int rateRow, int rateCol, bool theSame){ - NDArray::prepareSpecialUse({output}, {images}); - std::vector restDims({1, 2, 3}); // the first and the last dims - // 3D matricies - 2D matricies of vectors (if last dim is greater than 1) - //int e = 0; - const int ksizeRowsEffective = sizeRow + (sizeRow - 1) * (rateRow - 1); - const int ksizeColsEffective = sizeCol + (sizeCol - 1) * (rateCol - 1); - const int ksize = ksizeRowsEffective * ksizeColsEffective; - Nd4jLong lastDim = images->sizeAt(3); - Nd4jLong outLastDim = output->sizeAt(3); - Nd4jLong rowDim = images->sizeAt(1); - Nd4jLong colDim = images->sizeAt(2); - Nd4jLong outRowDim = output->sizeAt(1); - Nd4jLong outColDim = output->sizeAt(2); - auto rowCast = 1; - auto colCast = 1; - // validate shifts - if (sizeRow * rateRow < 3) - rowCast = 0; - if (sizeCol * rateCol < 3) - colCast = 0; - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(images->shapeInfo(), restDims.data(), restDims.size()); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), restDims.data(), restDims.size()); - int batchCount = packX.numberOfTads(); - - PointersManager manager(context, "helpers::extractPatches"); - - auto stream = context->getCudaStream(); - auto imagesBuffer = reinterpret_cast(images->specialBuffer()); - auto outputBuffer = reinterpret_cast(output->specialBuffer()); - - globalExtractPatchesKernel<<<128, 128, 1024, *stream>>>(theSame, batchCount, sizeRow, sizeCol, - rowDim, colDim, outRowDim, outColDim, strideRow, strideCol, rateRow, rateCol, rowCast, colCast, lastDim, - imagesBuffer, packX.specialShapeInfo(), packX.specialOffsets(), outputBuffer, packZ.specialShapeInfo(), - packZ.specialOffsets()); - - manager.synchronize(); - NDArray::registerSpecialUse({output}, {images}); - } - BUILD_SINGLE_TEMPLATE(template void _extractPatches, (sd::LaunchContext * context, NDArray* input, NDArray* output, int sizeRow, int sizeCol, int stradeRow, int stradeCol, int rateRow, int rateCol, bool theSame), LIBND4J_TYPES); - - - - void extractPatches(sd::LaunchContext * context, NDArray* images, NDArray* output, int sizeRow, int sizeCol, int stradeRow, int stradeCol, int rateRow, int rateCol, bool theSame){ - auto xType = images->dataType(); - - BUILD_SINGLE_SELECTOR(xType, _extractPatches, (context, images, output, sizeRow, sizeCol, stradeRow, stradeCol, rateRow, rateCol, theSame), LIBND4J_TYPES); - } +template +static void _extractPatches(sd::LaunchContext* context, NDArray* images, + NDArray* output, int sizeRow, int sizeCol, + int strideRow, int strideCol, int rateRow, + int rateCol, bool theSame) { + NDArray::prepareSpecialUse({output}, {images}); + std::vector restDims({1, 2, 3}); // the first and the last dims + // 3D matricies - 2D matricies of vectors (if last dim is greater than 1) + // int e = 0; + const int ksizeRowsEffective = sizeRow + (sizeRow - 1) * (rateRow - 1); + const int ksizeColsEffective = sizeCol + (sizeCol - 1) * (rateCol - 1); + const int ksize = ksizeRowsEffective * ksizeColsEffective; + Nd4jLong lastDim = images->sizeAt(3); + Nd4jLong outLastDim = output->sizeAt(3); + Nd4jLong rowDim = images->sizeAt(1); + Nd4jLong colDim = images->sizeAt(2); + Nd4jLong outRowDim = output->sizeAt(1); + Nd4jLong outColDim = output->sizeAt(2); + auto rowCast = 1; + auto colCast = 1; + // validate shifts + if (sizeRow * rateRow < 3) rowCast = 0; + if (sizeCol * rateCol < 3) colCast = 0; + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + images->shapeInfo(), restDims.data(), restDims.size()); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), restDims.data(), restDims.size()); + int batchCount = packX.numberOfTads(); + + PointersManager manager(context, "helpers::extractPatches"); + + auto stream = context->getCudaStream(); + auto imagesBuffer = reinterpret_cast(images->specialBuffer()); + auto outputBuffer = reinterpret_cast(output->specialBuffer()); + + globalExtractPatchesKernel<<<128, 128, 1024, *stream>>>( + theSame, batchCount, sizeRow, sizeCol, rowDim, colDim, outRowDim, + outColDim, strideRow, strideCol, rateRow, rateCol, rowCast, colCast, + lastDim, imagesBuffer, packX.specialShapeInfo(), packX.specialOffsets(), + outputBuffer, packZ.specialShapeInfo(), packZ.specialOffsets()); + + manager.synchronize(); + NDArray::registerSpecialUse({output}, {images}); } +BUILD_SINGLE_TEMPLATE(template void _extractPatches, + (sd::LaunchContext * context, NDArray* input, + NDArray* output, int sizeRow, int sizeCol, int stradeRow, + int stradeCol, int rateRow, int rateCol, bool theSame), + LIBND4J_TYPES); + +void extractPatches(sd::LaunchContext* context, NDArray* images, + NDArray* output, int sizeRow, int sizeCol, int stradeRow, + int stradeCol, int rateRow, int rateCol, bool theSame) { + auto xType = images->dataType(); + + BUILD_SINGLE_SELECTOR(xType, _extractPatches, + (context, images, output, sizeRow, sizeCol, stradeRow, + stradeCol, rateRow, rateCol, theSame), + LIBND4J_TYPES); } -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu index 7fcd71dba4c0..2db249d5a014 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu @@ -18,8 +18,8 @@ // @author sgazeos@gmail.com // -#include #include +#include namespace sd { namespace ops { @@ -33,104 +33,116 @@ namespace helpers { // narrowed - shrink is true // output - output tensor // - template - static __host__ __device__ void - nudge(T min, T max, int quantMin, int quantMax, T* scale, T* nudgedMin, T* nudgedMax) { - T quantMaxF = static_cast(quantMax); - T quantMinF = static_cast(quantMin); - *scale = (max - min) / (quantMaxF - quantMinF); - auto zeroPointFromMin = quantMinF - min / *scale; - uint16_t const nudgedZeroPoint = [zeroPointFromMin, quantMin, quantMax, quantMaxF, quantMinF] { - if (zeroPointFromMin < quantMinF) { - return static_cast(quantMin); - } - if (zeroPointFromMin > quantMaxF) { - return static_cast(quantMax); - } - return sd::math::nd4j_round(zeroPointFromMin); - }(); - *nudgedMax = (quantMaxF - static_cast(nudgedZeroPoint)) * (*scale); - *nudgedMin = (quantMinF - static_cast(nudgedZeroPoint)) * (*scale); +template +static __host__ __device__ void nudge(T min, T max, int quantMin, int quantMax, + T* scale, T* nudgedMin, T* nudgedMax) { + T quantMaxF = static_cast(quantMax); + T quantMinF = static_cast(quantMin); + *scale = (max - min) / (quantMaxF - quantMinF); + auto zeroPointFromMin = quantMinF - min / *scale; + uint16_t const nudgedZeroPoint = [zeroPointFromMin, quantMin, quantMax, + quantMaxF, quantMinF] { + if (zeroPointFromMin < quantMinF) { + return static_cast(quantMin); } - - template - void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - int lowIntBound = narrowed?1:0; - int upperIntBound = (1 << numBits) - 1; - min->syncToHost(); // these are scalars, so nothing much happened - max->syncToHost(); - T scale, nudgedMin, nudgedMax; - nudge(min->t(0), max->t(0), lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax); - - auto wiseMinMaxAndSoOn = LAMBDA_T(x, nudgedMin, nudgedMax, scale) { - T val = x; - if (x < nudgedMin) { - val = nudgedMin; - } - else if (x > nudgedMax) { - val = nudgedMax; - } - else - val = x; - return (math::nd4j_floor((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin); - }; - - input->applyLambda(wiseMinMaxAndSoOn, *output); + if (zeroPointFromMin > quantMaxF) { + return static_cast(quantMax); } + return sd::math::nd4j_round(zeroPointFromMin); + }(); + *nudgedMax = (quantMaxF - static_cast(nudgedZeroPoint)) * (*scale); + *nudgedMin = (quantMinF - static_cast(nudgedZeroPoint)) * (*scale); +} - template - static __global__ void fakeQuantWithMinMaxKernel(const T* input, const Nd4jLong* inputShape, - T* min, T* max, - int lowIntBound, int upperIntBound, Nd4jLong channels, - T* output, const Nd4jLong* outputShape, - Nd4jLong length) { - __shared__ int block; - if (threadIdx.x == 0) { - block = length / channels; // to loop with last dimension as block - } - __syncthreads(); +template +void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, + int numBits, bool narrowed, NDArray* output) { + int lowIntBound = narrowed ? 1 : 0; + int upperIntBound = (1 << numBits) - 1; + min->syncToHost(); // these are scalars, so nothing much happened + max->syncToHost(); + T scale, nudgedMin, nudgedMax; + nudge(min->t(0), max->t(0), lowIntBound, upperIntBound, &scale, + &nudgedMin, &nudgedMax); - for (auto i = blockIdx.x; i < (int)channels; i += gridDim.x) { - T scale, nudgedMin, nudgedMax; - nudge(min[i], max[i], lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax); - // loop over blocks to quantization between nudged min and max - for (auto b = threadIdx.x; b < block; b += blockDim.x) { - T val = input[shape::getIndexOffset(b * channels + i, inputShape)]; - if (val < nudgedMin) { - val = nudgedMin; - } else if (val > nudgedMax) { - val = nudgedMax; - } - output[shape::getIndexOffset(b * channels + i, outputShape)] = - (math::nd4j_floor((val - nudgedMin) / scale + T(0.5f)) * scale + nudgedMin); - }; - } - } + auto wiseMinMaxAndSoOn = LAMBDA_T(x, nudgedMin, nudgedMax, scale) { + T val = x; + if (x < nudgedMin) { + val = nudgedMin; + } else if (x > nudgedMax) { + val = nudgedMax; + } else + val = x; + return (math::nd4j_floor((val - nudgedMin) / scale + T(0.5)) * scale + + nudgedMin); + }; - template - void fakeQuantWithMinMaxVarsPerChannel_(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - int lowIntBound = narrowed?1:0; - int upperIntBound = (1 << numBits) - 1; - auto channels = min->lengthOf(); - auto length = input->lengthOf(); - NDArray::prepareSpecialUse({output}, {min, max, input}); - auto stream = context->getCudaStream(); - T* inputBuf = input->dataBuffer()->specialAsT(); - T* outputBuf = output->dataBuffer()->specialAsT(); - T* minBuf = min->dataBuffer()->specialAsT(); - T* maxBuf = max->dataBuffer()->specialAsT(); - fakeQuantWithMinMaxKernel<<<128, 256, 256, *stream>>>(inputBuf, input->specialShapeInfo(), - minBuf, maxBuf, lowIntBound, upperIntBound, channels, outputBuf, output->specialShapeInfo(), length); - NDArray::registerSpecialUse({output}, {min, max, input}); + input->applyLambda(wiseMinMaxAndSoOn, *output); +} - } +template +static __global__ void fakeQuantWithMinMaxKernel( + const T* input, const Nd4jLong* inputShape, T* min, T* max, int lowIntBound, + int upperIntBound, Nd4jLong channels, T* output, + const Nd4jLong* outputShape, Nd4jLong length) { + __shared__ int block; + if (threadIdx.x == 0) { + block = length / channels; // to loop with last dimension as block + } + __syncthreads(); - void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, (input, min, max, numBits, narrowed, output), FLOAT_TYPES); - } - void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, (context, input, min, max, numBits, narrowed, output), FLOAT_TYPES); - } + for (auto i = blockIdx.x; i < (int)channels; i += gridDim.x) { + T scale, nudgedMin, nudgedMax; + nudge(min[i], max[i], lowIntBound, upperIntBound, &scale, &nudgedMin, + &nudgedMax); + // loop over blocks to quantization between nudged min and max + for (auto b = threadIdx.x; b < block; b += blockDim.x) { + T val = input[shape::getIndexOffset(b * channels + i, inputShape)]; + if (val < nudgedMin) { + val = nudgedMin; + } else if (val > nudgedMax) { + val = nudgedMax; + } + output[shape::getIndexOffset(b * channels + i, outputShape)] = + (math::nd4j_floor((val - nudgedMin) / scale + T(0.5f)) * scale + + nudgedMin); + }; + } } + +template +void fakeQuantWithMinMaxVarsPerChannel_(LaunchContext* context, NDArray* input, + NDArray* min, NDArray* max, int numBits, + bool narrowed, NDArray* output) { + int lowIntBound = narrowed ? 1 : 0; + int upperIntBound = (1 << numBits) - 1; + auto channels = min->lengthOf(); + auto length = input->lengthOf(); + NDArray::prepareSpecialUse({output}, {min, max, input}); + auto stream = context->getCudaStream(); + T* inputBuf = input->dataBuffer()->specialAsT(); + T* outputBuf = output->dataBuffer()->specialAsT(); + T* minBuf = min->dataBuffer()->specialAsT(); + T* maxBuf = max->dataBuffer()->specialAsT(); + fakeQuantWithMinMaxKernel<<<128, 256, 256, *stream>>>( + inputBuf, input->specialShapeInfo(), minBuf, maxBuf, lowIntBound, + upperIntBound, channels, outputBuf, output->specialShapeInfo(), length); + NDArray::registerSpecialUse({output}, {min, max, input}); +} + +void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, + int numBits, bool narrowed, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVars_, + (input, min, max, numBits, narrowed, output), + FLOAT_TYPES); } +void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, + NDArray* min, NDArray* max, int numBits, + bool narrowed, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), fakeQuantWithMinMaxVarsPerChannel_, + (context, input, min, max, numBits, narrowed, output), + FLOAT_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu b/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu index aa2ff8297ca8..5dbaba63a7c2 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu @@ -18,68 +18,75 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - void _CUDA_G flattenKernel(void **xBuffers, Nd4jLong **xShapeInfos, Nd4jLong *offsets, Nd4jLong numInputs, void *zBuffer, const Nd4jLong *zShapeInfo, char order) { - - int xCoord[MAX_RANK]; - - // each block of threads works on 1 input array - for (Nd4jLong e = blockIdx.x; e < numInputs; e += gridDim.x) { - auto z = reinterpret_cast(zBuffer) + offsets[e]; - - auto xBuffer = reinterpret_cast(xBuffers[e]); - auto xShapeInfo = xShapeInfos[e]; - auto xLength = shape::length(xShapeInfo); - - // each element of this input array has own place within common output array - for (uint i = threadIdx.x; i < xLength; i += blockDim.x) - z[i] = xBuffer[getIndexOffsetOrdered(i, xShapeInfo, order)]; - } - } - - template - void flatten_(sd::LaunchContext *context, std::vector &inputs, NDArray *output, char order) { - PointersManager pm(context, "flatten"); - - std::vector hdBuffers(inputs.size()); - std::vector hOffsets(inputs.size()); - std::vector hdShapes(inputs.size()); - Nd4jLong cOffset = 0; - - // calculating offsets in output - for (int e = 0; e < inputs.size(); e++) { - hOffsets[e] = cOffset; - cOffset += inputs[e]->lengthOf(); - - hdBuffers[e] = inputs[e]->specialBuffer(); - hdShapes[e] = inputs[e]->specialShapeInfo(); - } - - // copying pointers to device - auto dBuffers = (void **) pm.replicatePointer(hdBuffers.data(), inputs.size() * sizeof(void*)); - auto dShapes = (Nd4jLong **)pm.replicatePointer(hdShapes.data(), inputs.size() * sizeof(Nd4jLong*)); - auto dOffsets = (Nd4jLong *) pm.replicatePointer(hOffsets.data(), inputs.size() * sizeof(Nd4jLong)); - - - flattenKernel<<<256, 512, 8192, *context->getCudaStream()>>>(dBuffers, dShapes, dOffsets, inputs.size(), output->specialBuffer(), output->specialShapeInfo(), order); - - pm.synchronize(); - } - - void flatten(sd::LaunchContext *context, std::vector &inputs, NDArray *output, char order) { - // FIXME: we want NDArrayFactory::prepareSpecialUse here eventually - for (auto v:inputs) - v->syncToDevice(); - - BUILD_SINGLE_SELECTOR(output->dataType(), flatten_, (context, inputs, output, order), LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {}); - } - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { +template +void _CUDA_G flattenKernel(void **xBuffers, Nd4jLong **xShapeInfos, + Nd4jLong *offsets, Nd4jLong numInputs, void *zBuffer, + const Nd4jLong *zShapeInfo, char order) { + int xCoord[MAX_RANK]; + + // each block of threads works on 1 input array + for (Nd4jLong e = blockIdx.x; e < numInputs; e += gridDim.x) { + auto z = reinterpret_cast(zBuffer) + offsets[e]; + + auto xBuffer = reinterpret_cast(xBuffers[e]); + auto xShapeInfo = xShapeInfos[e]; + auto xLength = shape::length(xShapeInfo); + + // each element of this input array has own place within common output array + for (uint i = threadIdx.x; i < xLength; i += blockDim.x) + z[i] = xBuffer[getIndexOffsetOrdered(i, xShapeInfo, order)]; + } +} + +template +void flatten_(sd::LaunchContext *context, std::vector &inputs, + NDArray *output, char order) { + PointersManager pm(context, "flatten"); + + std::vector hdBuffers(inputs.size()); + std::vector hOffsets(inputs.size()); + std::vector hdShapes(inputs.size()); + Nd4jLong cOffset = 0; + + // calculating offsets in output + for (int e = 0; e < inputs.size(); e++) { + hOffsets[e] = cOffset; + cOffset += inputs[e]->lengthOf(); + + hdBuffers[e] = inputs[e]->specialBuffer(); + hdShapes[e] = inputs[e]->specialShapeInfo(); + } + + // copying pointers to device + auto dBuffers = (void **)pm.replicatePointer(hdBuffers.data(), + inputs.size() * sizeof(void *)); + auto dShapes = (Nd4jLong **)pm.replicatePointer( + hdShapes.data(), inputs.size() * sizeof(Nd4jLong *)); + auto dOffsets = (Nd4jLong *)pm.replicatePointer( + hOffsets.data(), inputs.size() * sizeof(Nd4jLong)); + + flattenKernel<<<256, 512, 8192, *context->getCudaStream()>>>( + dBuffers, dShapes, dOffsets, inputs.size(), output->specialBuffer(), + output->specialShapeInfo(), order); + + pm.synchronize(); +} + +void flatten(sd::LaunchContext *context, std::vector &inputs, + NDArray *output, char order) { + // FIXME: we want NDArrayFactory::prepareSpecialUse here eventually + for (auto v : inputs) v->syncToDevice(); + + BUILD_SINGLE_SELECTOR(output->dataType(), flatten_, + (context, inputs, output, order), LIBND4J_TYPES); + NDArray::registerSpecialUse({output}, {}); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gather.cu b/libnd4j/include/ops/declarable/helpers/cuda/gather.cu index 26778aa63165..efd8fd119074 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gather.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gather.cu @@ -18,166 +18,183 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 07.03.2019 // - -#include -#include #include #include +#include -namespace sd { -namespace ops { -namespace helpers { - - template - __global__ static void gatherCudaLinearKernel(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo) { +#include +namespace sd { +namespace ops { +namespace helpers { - __shared__ const X* x; - __shared__ const Y* y; - __shared__ X* z; - __shared__ Nd4jLong xLen, yLen, zLen; +template +__global__ static void gatherCudaLinearKernel( + const void* vx, const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) { + __shared__ const X* x; + __shared__ const Y* y; + __shared__ X* z; + __shared__ Nd4jLong xLen, yLen, zLen; + + if (threadIdx.x == 0) { + x = reinterpret_cast(vx); + z = reinterpret_cast(vz); + y = reinterpret_cast(vy); + xLen = shape::length(xShapeInfo); + yLen = shape::length(yShapeInfo); + zLen = shape::length(zShapeInfo); + } + __syncthreads(); + // const Nd4jLong zLen = shape::length(zShapeInfo); + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int j = start; j < zLen; j += step) { + auto zIndex = shape::getIndexOffset(j, zShapeInfo); + auto yIndex = shape::getIndexOffset(j, yShapeInfo); + auto xIndex = shape::getIndexOffset(y[yIndex], xShapeInfo); + z[zIndex] = x[xIndex]; + } +} +////////////////////////////////////////////////////////////////////// +template +__global__ static void gatherCuda(const int numOfSubArrs, const void* vx, + const Nd4jLong* xShapeInfo, + const Nd4jLong* xOffsets, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong* zOffsets) { + const Y* y = reinterpret_cast(vy); + __shared__ const X* x; + __shared__ X* z; + + const Nd4jLong len = shape::length(xShapeInfo); + // const Nd4jLong zLen = shape::length(zShapeInfo); + for (int i = blockIdx.x; i < numOfSubArrs; i += gridDim.x) { if (threadIdx.x == 0) { - x = reinterpret_cast(vx); - z = reinterpret_cast(vz); - y = reinterpret_cast(vy); - xLen = shape::length(xShapeInfo); - yLen = shape::length(yShapeInfo); - zLen = shape::length(zShapeInfo); + x = reinterpret_cast(vx) + + xOffsets[y[shape::getIndexOffset(i, yShapeInfo)]]; + z = reinterpret_cast(vz) + zOffsets[i]; } __syncthreads(); - //const Nd4jLong zLen = shape::length(zShapeInfo); - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int j = start; j < zLen; j += step) { - auto zIndex = shape::getIndexOffset(j, zShapeInfo); - auto yIndex = shape::getIndexOffset(j, yShapeInfo); - auto xIndex = shape::getIndexOffset(y[yIndex], xShapeInfo); - z[zIndex] = x[xIndex]; - } -} -////////////////////////////////////////////////////////////////////// -template -__global__ static void gatherCuda(const int numOfSubArrs, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xOffsets, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zOffsets) { - - const Y* y = reinterpret_cast(vy); - __shared__ const X* x; - __shared__ X* z; - - const Nd4jLong len = shape::length(xShapeInfo); - //const Nd4jLong zLen = shape::length(zShapeInfo); - for (int i = blockIdx.x; i < numOfSubArrs; i += gridDim.x) { - - if (threadIdx.x == 0) { - x = reinterpret_cast(vx) + xOffsets[y[shape::getIndexOffset(i, yShapeInfo)]]; - z = reinterpret_cast(vz) + zOffsets[i]; - } - __syncthreads(); - - for (int j = threadIdx.x; j < len; j += blockDim.x) { - auto zIndex = shape::getIndexOffset(j, zShapeInfo); - auto xIndex = shape::getIndexOffset(j, xShapeInfo); - z[zIndex] = x[xIndex]; - } - __syncthreads(); + for (int j = threadIdx.x; j < len; j += blockDim.x) { + auto zIndex = shape::getIndexOffset(j, zShapeInfo); + auto xIndex = shape::getIndexOffset(j, xShapeInfo); + z[zIndex] = x[xIndex]; } + __syncthreads(); + } } -template -__host__ static void gatherCudaLinear(const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo) { - gatherCudaLinearKernel<<<128, 256, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); +template +__host__ static void gatherCudaLinear(const cudaStream_t* stream, + const void* vx, + const Nd4jLong* xShapeInfo, + const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo) { + gatherCudaLinearKernel<<<128, 256, 1024, *stream>>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } ////////////////////////////////////////////////////////////////////// -template -__host__ static void gatherCudaLauncher(const cudaStream_t *stream, const int numOfSubArrs, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xOffsets, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zOffsets) { - gatherCuda<<>>(numOfSubArrs, vx, xShapeInfo, xOffsets, vy, yShapeInfo, vz, zShapeInfo, zOffsets); +template +__host__ static void gatherCudaLauncher( + const cudaStream_t* stream, const int numOfSubArrs, const void* vx, + const Nd4jLong* xShapeInfo, const Nd4jLong* xOffsets, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zOffsets) { + gatherCuda<<>>( + numOfSubArrs, vx, xShapeInfo, xOffsets, vy, yShapeInfo, vz, zShapeInfo, + zOffsets); } ////////////////////////////////////////////////////////////////////// -void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* indices, NDArray* output, const std::vector& intArgs) { - - const int inputRank = input->rankOf(); - const int numOfIntArgs = intArgs.size(); - - int axis = numOfIntArgs > 0 ? intArgs[0] : 0; - if(axis < 0) - axis += inputRank; - - if (indices == nullptr && numOfIntArgs == 2) { // scalar case - output->assign((*input)(intArgs[1], {axis})); +void gather(sd::LaunchContext* context, const NDArray* input, + const NDArray* indices, NDArray* output, + const std::vector& intArgs) { + const int inputRank = input->rankOf(); + const int numOfIntArgs = intArgs.size(); + + int axis = numOfIntArgs > 0 ? intArgs[0] : 0; + if (axis < 0) axis += inputRank; + + if (indices == nullptr && numOfIntArgs == 2) { // scalar case + output->assign((*input)(intArgs[1], {axis})); + } else if (indices != nullptr && indices->isScalar()) { + if (input->rankOf() <= 1) { // For scalar indices, rank 0 or 1 input: can't + // do tensor along dimension 0 as this is whole + // array... instead, we want to get a scalar + auto idx = indices->e(0); + auto scalarNDArray = input->e(idx); + output->assign(scalarNDArray); + } else { + NDArray inSubArr = (*input)(indices->e(0), {axis}); + output->assign(inSubArr); } - else if (indices != nullptr && indices->isScalar()) { - - if(input->rankOf() <= 1) { //For scalar indices, rank 0 or 1 input: can't do tensor along dimension 0 as this is whole array... instead, we want to get a scalar - auto idx = indices->e(0); - auto scalarNDArray = input->e(idx); - output->assign(scalarNDArray); - } - else { - NDArray inSubArr = (*input)(indices->e(0), {axis}); - output->assign(inSubArr); - } + } else { + NDArray* pIndices = const_cast(indices); + if (indices == nullptr) + pIndices = + new NDArray(input->ordering(), {numOfIntArgs - 1}, + std::vector(intArgs.begin() + 1, intArgs.end()), + DataType::INT64, input->getContext()); + + std::vector dimsOut(pIndices->rankOf()); + std::iota(dimsOut.begin(), dimsOut.end(), + axis); // fill with axis, axis+1, ... axis+pIndices->rankOf()-1 + + const Nd4jLong numOfSubArrs = pIndices->lengthOf(); + + Nd4jLong *outSubArrShapeInfo(nullptr), *inSubArrShapeInfo(nullptr), + *outSubArrOffsets(nullptr), *inSubArrOffsets(nullptr); + input->getSubArrShapeAndOffsets({axis}, inSubArrShapeInfo, inSubArrOffsets); + output->getSubArrShapeAndOffsets(dimsOut, outSubArrShapeInfo, + outSubArrOffsets); + if (output->rankOf() > 1) { + PointersManager manager(context, "gather"); + auto xShapeInfo = reinterpret_cast(manager.replicatePointer( + inSubArrShapeInfo, shape::shapeInfoByteLength(inSubArrShapeInfo))); + auto zShapeInfo = reinterpret_cast(manager.replicatePointer( + outSubArrShapeInfo, shape::shapeInfoByteLength(outSubArrShapeInfo))); + auto xOffsets = reinterpret_cast(manager.replicatePointer( + inSubArrOffsets, + (input->lengthOf() / shape::length(inSubArrShapeInfo)) * + sizeof(Nd4jLong))); + auto zOffsets = reinterpret_cast(manager.replicatePointer( + outSubArrOffsets, + (output->lengthOf() / shape::length(outSubArrShapeInfo)) * + sizeof(Nd4jLong))); + + NDArray::prepareSpecialUse({output}, {input, pIndices}); + BUILD_DOUBLE_SELECTOR( + input->dataType(), pIndices->dataType(), gatherCudaLauncher, + (context->getCudaStream(), numOfSubArrs, input->specialBuffer(), + xShapeInfo, xOffsets, pIndices->specialBuffer(), + pIndices->specialShapeInfo(), output->specialBuffer(), zShapeInfo, + zOffsets), + LIBND4J_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, pIndices}); + manager.synchronize(); + } else { + NDArray::prepareSpecialUse({output}, {input, pIndices}); + BUILD_DOUBLE_SELECTOR( + input->dataType(), pIndices->dataType(), gatherCudaLinear, + (context->getCudaStream(), input->specialBuffer(), + input->specialShapeInfo(), pIndices->specialBuffer(), + pIndices->specialShapeInfo(), output->specialBuffer(), + output->specialShapeInfo()), + LIBND4J_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, pIndices}); } - else { - - NDArray* pIndices = const_cast(indices); - if(indices == nullptr) - pIndices = new NDArray(input->ordering(), {numOfIntArgs-1}, std::vector(intArgs.begin() + 1, intArgs.end()), DataType::INT64, input->getContext()); - - std::vector dimsOut(pIndices->rankOf()); - std::iota(dimsOut.begin(), dimsOut.end(), axis); // fill with axis, axis+1, ... axis+pIndices->rankOf()-1 - - const Nd4jLong numOfSubArrs = pIndices->lengthOf(); - - Nd4jLong *outSubArrShapeInfo(nullptr), *inSubArrShapeInfo(nullptr), *outSubArrOffsets(nullptr), *inSubArrOffsets(nullptr); - input-> getSubArrShapeAndOffsets({axis}, inSubArrShapeInfo, inSubArrOffsets); - output->getSubArrShapeAndOffsets(dimsOut, outSubArrShapeInfo, outSubArrOffsets); - if (output->rankOf() > 1) { - PointersManager manager(context, "gather"); - auto xShapeInfo = reinterpret_cast(manager.replicatePointer(inSubArrShapeInfo, - shape::shapeInfoByteLength( - inSubArrShapeInfo))); - auto zShapeInfo = reinterpret_cast(manager.replicatePointer(outSubArrShapeInfo, - shape::shapeInfoByteLength( - outSubArrShapeInfo))); - auto xOffsets = reinterpret_cast(manager.replicatePointer(inSubArrOffsets, (input->lengthOf() / - shape::length( - inSubArrShapeInfo)) * - sizeof(Nd4jLong))); - auto zOffsets = reinterpret_cast(manager.replicatePointer(outSubArrOffsets, - (output->lengthOf() / - shape::length(outSubArrShapeInfo)) * - sizeof(Nd4jLong))); - - NDArray::prepareSpecialUse({output}, {input, pIndices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), pIndices->dataType(), gatherCudaLauncher, (context->getCudaStream(), numOfSubArrs, input->specialBuffer(), xShapeInfo, xOffsets, pIndices->specialBuffer(), pIndices->specialShapeInfo(), output->specialBuffer(), zShapeInfo, zOffsets), LIBND4J_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, pIndices}); - manager.synchronize(); - } - else { - NDArray::prepareSpecialUse({output}, {input, pIndices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), pIndices->dataType(), gatherCudaLinear, (context->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), pIndices->specialBuffer(), pIndices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()), LIBND4J_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, pIndices}); - - } - - if(indices == nullptr) - delete pIndices; - } + if (indices == nullptr) delete pIndices; + } } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu b/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu index d72f3e1bc126..1ab8aeba2790 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu @@ -18,128 +18,136 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include + +#include namespace sd { - namespace ops { - namespace helpers { - /////////////////////////////////////////////////////////////////// +namespace ops { +namespace helpers { +/////////////////////////////////////////////////////////////////// // x - input, y - indices, z - output - template - __global__ static void gatherNDCuda(const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); +template +__global__ static void gatherNDCuda(const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); - __shared__ int xRank, yRank, zRank, maxRank, yLastDim; - __shared__ Nd4jLong zLen, totalThreads, *sharedMem; + __shared__ int xRank, yRank, zRank, maxRank, yLastDim; + __shared__ Nd4jLong zLen, totalThreads, *sharedMem; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - xRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - zRank = shape::rank(zShapeInfo); - maxRank = sd::math::nd4j_max(yRank, sd::math::nd4j_max(xRank, zRank)); + xRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); + zRank = shape::rank(zShapeInfo); + maxRank = + sd::math::nd4j_max(yRank, sd::math::nd4j_max(xRank, zRank)); - zLen = shape::length(zShapeInfo); - yLastDim = yShapeInfo[yRank]; + zLen = shape::length(zShapeInfo); + yLastDim = yShapeInfo[yRank]; - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); - auto coord = sharedMem + threadIdx.x * maxRank; + auto coord = sharedMem + threadIdx.x * maxRank; - Nd4jLong *zCoordStart, *xCoordStart; + Nd4jLong *zCoordStart, *xCoordStart; - if(yLastDim == xRank) { - zCoordStart = coord; - xCoordStart = coord; - } - if(zRank >= xRank) { - zCoordStart = coord; - xCoordStart = coord + zRank - xRank; - } - else { - zCoordStart = coord + xRank - zRank; - xCoordStart = coord; - } + if (yLastDim == xRank) { + zCoordStart = coord; + xCoordStart = coord; + } + if (zRank >= xRank) { + zCoordStart = coord; + xCoordStart = coord + zRank - xRank; + } else { + zCoordStart = coord + xRank - zRank; + xCoordStart = coord; + } - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, zCoordStart); - shape::index2coords(i, zShapeInfo, zCoordStart); + const auto zOffset = shape::getOffset(zShapeInfo, zCoordStart); - const auto zOffset = shape::getOffset(zShapeInfo, zCoordStart); + // last y coordinate + int coordToRestore; + if (yLastDim != xRank) + coordToRestore = static_cast(zCoordStart[yRank - 1]); - // last y coordinate - int coordToRestore; - if(yLastDim != xRank) - coordToRestore = static_cast(zCoordStart[yRank - 1]); + zCoordStart[yRank - 1] = 0; // last y coordinate + const auto yOffset = shape::getOffset(yShapeInfo, zCoordStart); - zCoordStart[yRank - 1] = 0; // last y coordinate - const auto yOffset = shape::getOffset(yShapeInfo, zCoordStart); + // restore z coordinate + if (yLastDim != xRank) zCoordStart[yRank - 1] = coordToRestore; - //restore z coordinate - if(yLastDim != xRank) - zCoordStart[yRank - 1] = coordToRestore; + // construct coordinates for x + for (uint j = 0; j < yLastDim; ++j) + xCoordStart[j] = y[yOffset + j * yShapeInfo[2 * yRank]]; // last stride - // construct coordinates for x - for(uint j = 0; j < yLastDim; ++j) - xCoordStart[j] = y[yOffset + j * yShapeInfo[2 * yRank]]; // last stride + const auto xOffset = shape::getOffset(xShapeInfo, xCoordStart); - const auto xOffset = shape::getOffset(xShapeInfo, xCoordStart); - - z[zOffset] = x[xOffset]; - // printf("z[%lld] = x[%lld] = %f\n", zOffset, xOffset, (float) z[zOffset]); - } - } + z[zOffset] = x[xOffset]; + // printf("z[%lld] = x[%lld] = %f\n", zOffset, xOffset, (float) z[zOffset]); + } +} /////////////////////////////////////////////////////////////////// - template - static void gatherNDCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - gatherNDCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); - } +template +static void gatherNDCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + gatherNDCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); +} /////////////////////////////////////////////////////////////////// - void gatherND(sd::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) { - - const int maxRank = sd::math::nd4j_max(indices.rankOf(), sd::math::nd4j_max(input.rankOf(), output.rankOf())); - - const int threadsPerBlock = 256; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = 8 * threadsPerBlock * maxRank + 128; - - const auto xType = input.dataType(); - const auto yType = indices.dataType(); - - PointersManager manager(context, "gatherND"); - - NDArray::prepareSpecialUse({&output}, {&input, &indices}); - BUILD_DOUBLE_SELECTOR(xType, yType, gatherNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), indices.specialBuffer(), indices.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), LIBND4J_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({&output}, {&input, &indices}); - - manager.synchronize(); - } - } - } -} \ No newline at end of file +void gatherND(sd::LaunchContext *context, NDArray &input, NDArray &indices, + NDArray &output) { + const int maxRank = sd::math::nd4j_max( + indices.rankOf(), + sd::math::nd4j_max(input.rankOf(), output.rankOf())); + + const int threadsPerBlock = 256; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = 8 * threadsPerBlock * maxRank + 128; + + const auto xType = input.dataType(); + const auto yType = indices.dataType(); + + PointersManager manager(context, "gatherND"); + + NDArray::prepareSpecialUse({&output}, {&input, &indices}); + BUILD_DOUBLE_SELECTOR( + xType, yType, gatherNDCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), indices.specialBuffer(), + indices.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo()), + LIBND4J_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({&output}, {&input, &indices}); + + manager.synchronize(); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu b/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu index f165d88b7e78..2fa6b4eb33a1 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu @@ -25,19 +25,23 @@ namespace sd { namespace ops { namespace helpers { template -void applyGradientDescent_(LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) { - // classic one - auto lambda = LAMBDA_TT(_x, _y, weight) { - return _x - (_y * weight); - }; +void applyGradientDescent_(LaunchContext* context, NDArray* input, + NDArray* step, double weight, NDArray* output) { + // classic one + auto lambda = LAMBDA_TT(_x, _y, weight) { return _x - (_y * weight); }; - input->applyPairwiseLambda(*step, lambda, *output); + input->applyPairwiseLambda(*step, lambda, *output); } -void applyGradientDescent(sd::LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), applyGradientDescent_, (context, input, step, weight, output), FLOAT_TYPES); -} -BUILD_SINGLE_TEMPLATE(template void applyGradientDescent_, (LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output), FLOAT_TYPES); -} -} +void applyGradientDescent(sd::LaunchContext* context, NDArray* input, + NDArray* step, double weight, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), applyGradientDescent_, + (context, input, step, weight, output), FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template void applyGradientDescent_, + (LaunchContext * context, NDArray* input, NDArray* step, + double weight, NDArray* output), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/hamming.cu b/libnd4j/include/ops/declarable/helpers/cuda/hamming.cu index e3fdd94116ea..67905940dbbf 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/hamming.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/hamming.cu @@ -18,78 +18,91 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - static _CUDA_G void _hammingKernel(const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, void *reductionBuffer, Nd4jLong length) { - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong *shared; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - shared = reinterpret_cast(shmem); - } - __syncthreads(); - - // we want to nullify temporary memory before accumulating intermediate results - shared[threadIdx.x] = 0; - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - for (Nd4jLong e = tid; e < length; e += blockDim.x * gridDim.x) { - auto _x = static_cast(x[shape::getIndexOffset(e, xShapeInfo)]); - auto _y = static_cast(y[shape::getIndexOffset(e, yShapeInfo)]); - - // we save intermediate result into shared memory - shared[threadIdx.x] += __popcll(_x ^ _y); - } - __syncthreads(); - - // now we accumulate values - auto numItems = sd::math::nd4j_min(blockDim.x, length); - auto floorPow2 = numItems; - if (floorPow2 & (floorPow2 - 1)) { - - while (floorPow2 & (floorPow2 - 1)) - floorPow2 &= floorPow2 - 1; - - if (threadIdx.x >= floorPow2) - shared[threadIdx.x - floorPow2] = shared[threadIdx.x - floorPow2] + shared[threadIdx.x]; - - __syncthreads(); - } - __syncthreads(); - - for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { - if (threadIdx.x < activeThreads && threadIdx.x + activeThreads < numItems) - shared[threadIdx.x] = shared[threadIdx.x] + shared[threadIdx.x + activeThreads]; - - __syncthreads(); - } - __syncthreads(); - - // FIXME: do we really want atomicAdd on global memory here - // and store them to output - if (threadIdx.x == 0 && shared[0] > 0) - sd::math::atomics::nd4j_atomicAdd(&z[0], static_cast(shared[threadIdx.x])); - } - - template - static void _hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &z) { - _hammingKernel<<<256, 256, 256 * sizeof(Nd4jLong) + 256, *context->getCudaStream()>>>(x.specialBuffer(), x.specialShapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.specialBuffer(), nullptr, x.lengthOf()); - } - - void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) { - NDArray::prepareSpecialUse({&output}, {&x, &y}); - BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(), _hamming, (context, x, y, output), INTEGER_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({&output}, {&x, &y}); - } - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { +template +static _CUDA_G void _hammingKernel(const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, + void *vz, void *reductionBuffer, + Nd4jLong length) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong *shared; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + shared = reinterpret_cast(shmem); + } + __syncthreads(); + + // we want to nullify temporary memory before accumulating intermediate + // results + shared[threadIdx.x] = 0; + + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + for (Nd4jLong e = tid; e < length; e += blockDim.x * gridDim.x) { + auto _x = static_cast( + x[shape::getIndexOffset(e, xShapeInfo)]); + auto _y = static_cast( + y[shape::getIndexOffset(e, yShapeInfo)]); + + // we save intermediate result into shared memory + shared[threadIdx.x] += __popcll(_x ^ _y); + } + __syncthreads(); + + // now we accumulate values + auto numItems = sd::math::nd4j_min(blockDim.x, length); + auto floorPow2 = numItems; + if (floorPow2 & (floorPow2 - 1)) { + while (floorPow2 & (floorPow2 - 1)) floorPow2 &= floorPow2 - 1; + + if (threadIdx.x >= floorPow2) + shared[threadIdx.x - floorPow2] = + shared[threadIdx.x - floorPow2] + shared[threadIdx.x]; + + __syncthreads(); + } + __syncthreads(); + + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; + activeThreads >>= 1) { + if (threadIdx.x < activeThreads && threadIdx.x + activeThreads < numItems) + shared[threadIdx.x] = + shared[threadIdx.x] + shared[threadIdx.x + activeThreads]; + + __syncthreads(); + } + __syncthreads(); + + // FIXME: do we really want atomicAdd on global memory here + // and store them to output + if (threadIdx.x == 0 && shared[0] > 0) + sd::math::atomics::nd4j_atomicAdd(&z[0], + static_cast(shared[threadIdx.x])); +} + +template +static void _hamming(LaunchContext *context, NDArray &x, NDArray &y, + NDArray &z) { + _hammingKernel + <<<256, 256, 256 * sizeof(Nd4jLong) + 256, *context->getCudaStream()>>>( + x.specialBuffer(), x.specialShapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), z.specialBuffer(), nullptr, x.lengthOf()); +} + +void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) { + NDArray::prepareSpecialUse({&output}, {&x, &y}); + BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(), _hamming, + (context, x, y, output), INTEGER_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({&output}, {&x, &y}); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/hashcode.cu b/libnd4j/include/ops/declarable/helpers/cuda/hashcode.cu index 1c4ca9152345..a66893b6cc8f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/hashcode.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/hashcode.cu @@ -20,112 +20,127 @@ #include - namespace sd { - namespace ops { - namespace helpers { - template - static __global__ void splitBufferToChuncks(T* buffer, Nd4jLong* tempBuffer, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong length) { - - for (int b = blockIdx.x * blockDim.x + threadIdx.x; b < numBlocks; b += gridDim.x*blockDim.x) { - auto blockBuffer = buffer + b * numBlocks; - - Nd4jLong r = 1LL; - for (int e = 0; e < blockSize && e + (b * numBlocks) < length; e++) { - auto v = longBytes(blockBuffer[e]); - r = 31LL * r + v; - } - - tempBuffer[b] = r; - } - } - - template - static __global__ void internalHash(Nd4jLong* tempBuffer, Nd4jLong* tempResult, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong lastLength) { - - for (int b = blockIdx.x * blockDim.x + threadIdx.x; b < numBlocks; b += gridDim.x * blockDim.x) { - auto blockBuffer = tempBuffer + b * numBlocks; - Nd4jLong r = 1LL; - - for (Nd4jLong e = 0; e < blockSize && e + (b * numBlocks) < lastLength; e++) { - auto v = longBytes(blockBuffer[e]); - r = 31LL * r + v; - } - - tempResult[b] = r; - - } - - } - - - static __global__ void lastStep(Nd4jLong* resultBuf, Nd4jLong* tempBufferA, Nd4jLong* tempResult, Nd4jLong length, Nd4jLong blockSize) { - if (threadIdx.x == 0) { - if (length <= blockSize) - *resultBuf = *tempBufferA; - else - *resultBuf = *tempResult; - } - } - - template - void hashCode_(LaunchContext *context, NDArray &array, NDArray &result) { - auto blockSize = 32; - auto stream = context->getCudaStream(); - array.syncToDevice(); - - NDArray::prepareSpecialUse({&result}, {&array}); - auto length = array.lengthOf(); - int numBlocks = length / blockSize + ((length % blockSize == 0) ? 0 : 1); - auto tempA = NDArrayFactory::create('c', {numBlocks}, context); - auto tempB = NDArrayFactory::create('c', { numBlocks / blockSize + 1}, context); - - auto buffer = reinterpret_cast(array.specialBuffer()); //bufferAsT(); - auto tempBufferA = reinterpret_cast(tempA.specialBuffer()); //bufferAsT(); - auto tempBufferB = reinterpret_cast(tempB.specialBuffer()); //bufferAsT(); - - // default buffer is the first one, because it might be the last one in case of small arrays (< blockSize) - auto tempBuffer = tempBufferA; - auto tempResult = tempBufferB; - - // we divide array into 32 element chunks, and store intermediate results once - splitBufferToChuncks<<>>(buffer, tempBuffer, numBlocks, blockSize, length); - - // we replace pointer with intermediate one, and repeat only one chunk left - int iterationCount = 0; - while (numBlocks > 1) { - int lastLength = numBlocks; - numBlocks = lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1); - - - internalHash<<>>(tempBuffer, tempResult, numBlocks, blockSize, lastLength); - - - iterationCount++; - // swapping buffers - if (iterationCount % 2 == 0) { - tempBuffer = tempBufferA; - tempResult = tempBufferB; - } else { - tempBuffer = tempBufferB; - tempResult = tempBufferA; - } - } - - lastStep<<<1,1,128, *stream>>>(reinterpret_cast(result.specialBuffer()), tempBufferA, tempResult, length, blockSize); -// tempA.syncToHost(); -// tempB.syncToHost(); -// result.assign((length <= blockSize?tempA.e(0) : tempB.e(0))); - - NDArray::registerSpecialUse({&result}, {&array}); - } - - void hashCode(LaunchContext *context, NDArray &array, NDArray &result) { - BUILD_SINGLE_SELECTOR(array.dataType(), hashCode_, (context, array, result), LIBND4J_TYPES); - } +namespace ops { +namespace helpers { +template +static __global__ void splitBufferToChuncks(T* buffer, Nd4jLong* tempBuffer, + Nd4jLong numBlocks, + Nd4jLong blockSize, + Nd4jLong length) { + for (int b = blockIdx.x * blockDim.x + threadIdx.x; b < numBlocks; + b += gridDim.x * blockDim.x) { + auto blockBuffer = buffer + b * numBlocks; + + Nd4jLong r = 1LL; + for (int e = 0; e < blockSize && e + (b * numBlocks) < length; e++) { + auto v = longBytes(blockBuffer[e]); + r = 31LL * r + v; + } + + tempBuffer[b] = r; + } +} + +template +static __global__ void internalHash(Nd4jLong* tempBuffer, Nd4jLong* tempResult, + Nd4jLong numBlocks, Nd4jLong blockSize, + Nd4jLong lastLength) { + for (int b = blockIdx.x * blockDim.x + threadIdx.x; b < numBlocks; + b += gridDim.x * blockDim.x) { + auto blockBuffer = tempBuffer + b * numBlocks; + Nd4jLong r = 1LL; + + for (Nd4jLong e = 0; e < blockSize && e + (b * numBlocks) < lastLength; + e++) { + auto v = longBytes(blockBuffer[e]); + r = 31LL * r + v; + } + + tempResult[b] = r; + } +} + +static __global__ void lastStep(Nd4jLong* resultBuf, Nd4jLong* tempBufferA, + Nd4jLong* tempResult, Nd4jLong length, + Nd4jLong blockSize) { + if (threadIdx.x == 0) { + if (length <= blockSize) + *resultBuf = *tempBufferA; + else + *resultBuf = *tempResult; + } +} - BUILD_SINGLE_TEMPLATE(template void hashCode_, (LaunchContext* context, NDArray& array, NDArray& result), LIBND4J_TYPES); - } +template +void hashCode_(LaunchContext* context, NDArray& array, NDArray& result) { + auto blockSize = 32; + auto stream = context->getCudaStream(); + array.syncToDevice(); + + NDArray::prepareSpecialUse({&result}, {&array}); + auto length = array.lengthOf(); + int numBlocks = length / blockSize + ((length % blockSize == 0) ? 0 : 1); + auto tempA = NDArrayFactory::create('c', {numBlocks}, context); + auto tempB = NDArrayFactory::create( + 'c', {numBlocks / blockSize + 1}, context); + + auto buffer = reinterpret_cast(array.specialBuffer()); // bufferAsT(); + auto tempBufferA = reinterpret_cast( + tempA.specialBuffer()); // bufferAsT(); + auto tempBufferB = reinterpret_cast( + tempB.specialBuffer()); // bufferAsT(); + + // default buffer is the first one, because it might be the last one in case + // of small arrays (< blockSize) + auto tempBuffer = tempBufferA; + auto tempResult = tempBufferB; + + // we divide array into 32 element chunks, and store intermediate results once + splitBufferToChuncks<<>>( + buffer, tempBuffer, numBlocks, blockSize, length); + + // we replace pointer with intermediate one, and repeat only one chunk left + int iterationCount = 0; + while (numBlocks > 1) { + int lastLength = numBlocks; + numBlocks = + lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1); + + internalHash<<>>( + tempBuffer, tempResult, numBlocks, blockSize, lastLength); + + iterationCount++; + // swapping buffers + if (iterationCount % 2 == 0) { + tempBuffer = tempBufferA; + tempResult = tempBufferB; + } else { + tempBuffer = tempBufferB; + tempResult = tempBufferA; } + } + + lastStep<<<1, 1, 128, *stream>>>( + reinterpret_cast(result.specialBuffer()), tempBufferA, + tempResult, length, blockSize); + // tempA.syncToHost(); + // tempB.syncToHost(); + // result.assign((length <= blockSize?tempA.e(0) : + // tempB.e(0))); + + NDArray::registerSpecialUse({&result}, {&array}); +} + +void hashCode(LaunchContext* context, NDArray& array, NDArray& result) { + BUILD_SINGLE_SELECTOR(array.dataType(), hashCode_, (context, array, result), + LIBND4J_TYPES); } +BUILD_SINGLE_TEMPLATE(template void hashCode_, + (LaunchContext * context, NDArray& array, + NDArray& result), + LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu index 6d7310fc0c20..dbdd24d65048 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu @@ -18,118 +18,137 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - void _CUDA_G histogramKernel(void *xBuffer, const Nd4jLong *xShapeInfo, void *zBuffer, const Nd4jLong *zShapeInfo, void *allocationPointer, void *reductionPointer, Nd4jLong numBins, X* min_val, X* max_val) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - auto dx = reinterpret_cast(xBuffer); - auto result = reinterpret_cast(zBuffer); - - __shared__ Z *bins; - __shared__ int length; - __shared__ Z *reductor; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - bins = (Z *) shmem; - reductor = ((Z *) allocationPointer) + (numBins * blockIdx.x); - - length = shape::length(xShapeInfo); - } - __syncthreads(); - - X binSize = X((*max_val - *min_val) / numBins); - - // nullify bins - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - bins[e] = (Z) 0; - } - __syncthreads(); - - for (int e = tid; e < length; e += blockDim.x * gridDim.x) { - int idx = int((dx[e] - *min_val) / binSize); - idx = math::nd4j_max(idx, 0); //atomicMax(&idx, 0);//atomicMax(&idx, 0); - idx = math::nd4j_min(idx, int(numBins - 1)); //atomicMin(&idx, int(numBins - 1)); - sd::math::atomics::nd4j_atomicAdd(&bins[idx], (Z)1); - } - __syncthreads(); - // at this point all bins in shared memory are calculated, so we aggregate them now via threadfence trick - - // transfer shared memory to reduction memory - if (gridDim.x > 1) { - unsigned int *tc = (unsigned int *)reductionPointer; - __shared__ bool amLast; - - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - reductor[e] = bins[e]; - } - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - __syncthreads(); - - if (amLast) { - tc[16384] = 0; - - // nullify shared memory for future accumulation - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - bins[e] = (Z) 0; - } - - // accumulate reduced bins - for (int r = 0; r < gridDim.x; r++) { - Z *ptrBuf = ((Z *)allocationPointer) + (r * numBins); - - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - math::atomics::nd4j_atomicAdd(&bins[e], ptrBuf[e]); - } - } - __syncthreads(); - - // write them out to Z - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - result[e] = bins[e]; - } - } - } else { - // if there's only 1 block - just write away data - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - result[e] = bins[e]; - } - } - } - - template - static void histogram_(sd::LaunchContext *context, void *xBuffer, const Nd4jLong *xShapeInfo, const Nd4jLong *dxShapeInfo, void *zBuffer, const Nd4jLong *zShapeInfo, Nd4jLong numBins, void* min_val, void* max_val) { - int numThreads = 256; - int numBlocks = sd::math::nd4j_max(256, sd::math::nd4j_min(1, shape::length(xShapeInfo) / numThreads)); - int workspaceSize = numBlocks * numBins; - auto tmp = NDArrayFactory::create('c', {workspaceSize}, context); - - histogramKernel<<getCudaStream()>>>(xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.specialBuffer(), context->getReductionPointer(), numBins, reinterpret_cast(min_val), reinterpret_cast(max_val)); - - cudaStreamSynchronize(*context->getCudaStream()); - } - - void histogramHelper(sd::LaunchContext *context, NDArray &input, NDArray &output) { - Nd4jLong numBins = output.lengthOf(); - NDArray::registerSpecialUse({&output}, {&input}); - - auto min_val = input.reduceNumber(reduce::SameOps::Min); - auto max_val = input.reduceNumber(reduce::SameOps::Max); -// min_val.printIndexedBuffer("MIN"); -// max_val.printIndexedBuffer("MAX"); - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.shapeInfo(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), numBins, min_val.specialBuffer(), max_val.specialBuffer()), LIBND4J_TYPES, INTEGER_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - } +namespace ops { +namespace helpers { +template +void _CUDA_G histogramKernel(void *xBuffer, const Nd4jLong *xShapeInfo, + void *zBuffer, const Nd4jLong *zShapeInfo, + void *allocationPointer, void *reductionPointer, + Nd4jLong numBins, X *min_val, X *max_val) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + auto dx = reinterpret_cast(xBuffer); + auto result = reinterpret_cast(zBuffer); + + __shared__ Z *bins; + __shared__ int length; + __shared__ Z *reductor; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + bins = (Z *)shmem; + reductor = ((Z *)allocationPointer) + (numBins * blockIdx.x); + + length = shape::length(xShapeInfo); + } + __syncthreads(); + + X binSize = X((*max_val - *min_val) / numBins); + + // nullify bins + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + bins[e] = (Z)0; + } + __syncthreads(); + + for (int e = tid; e < length; e += blockDim.x * gridDim.x) { + int idx = int((dx[e] - *min_val) / binSize); + idx = math::nd4j_max(idx, 0); // atomicMax(&idx, 0);//atomicMax(&idx, 0); + idx = math::nd4j_min( + idx, int(numBins - 1)); // atomicMin(&idx, int(numBins - 1)); + sd::math::atomics::nd4j_atomicAdd(&bins[idx], (Z)1); + } + __syncthreads(); + // at this point all bins in shared memory are calculated, so we aggregate + // them now via threadfence trick + + // transfer shared memory to reduction memory + if (gridDim.x > 1) { + unsigned int *tc = (unsigned int *)reductionPointer; + __shared__ bool amLast; + + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + reductor[e] = bins[e]; + } + __threadfence(); + __syncthreads(); + + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } + __syncthreads(); + + if (amLast) { + tc[16384] = 0; + + // nullify shared memory for future accumulation + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + bins[e] = (Z)0; + } + + // accumulate reduced bins + for (int r = 0; r < gridDim.x; r++) { + Z *ptrBuf = ((Z *)allocationPointer) + (r * numBins); + + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + math::atomics::nd4j_atomicAdd(&bins[e], ptrBuf[e]); } + } + __syncthreads(); + + // write them out to Z + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + result[e] = bins[e]; + } + } + } else { + // if there's only 1 block - just write away data + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + result[e] = bins[e]; } -} \ No newline at end of file + } +} + +template +static void histogram_(sd::LaunchContext *context, void *xBuffer, + const Nd4jLong *xShapeInfo, const Nd4jLong *dxShapeInfo, + void *zBuffer, const Nd4jLong *zShapeInfo, + Nd4jLong numBins, void *min_val, void *max_val) { + int numThreads = 256; + int numBlocks = sd::math::nd4j_max( + 256, sd::math::nd4j_min(1, shape::length(xShapeInfo) / numThreads)); + int workspaceSize = numBlocks * numBins; + auto tmp = NDArrayFactory::create('c', {workspaceSize}, context); + + histogramKernel + <<getCudaStream()>>>( + xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.specialBuffer(), + context->getReductionPointer(), numBins, + reinterpret_cast(min_val), reinterpret_cast(max_val)); + + cudaStreamSynchronize(*context->getCudaStream()); +} + +void histogramHelper(sd::LaunchContext *context, NDArray &input, + NDArray &output) { + Nd4jLong numBins = output.lengthOf(); + NDArray::registerSpecialUse({&output}, {&input}); + + auto min_val = input.reduceNumber(reduce::SameOps::Min); + auto max_val = input.reduceNumber(reduce::SameOps::Max); + // min_val.printIndexedBuffer("MIN"); + // max_val.printIndexedBuffer("MAX"); + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, + (context, input.specialBuffer(), input.shapeInfo(), + input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), numBins, + min_val.specialBuffer(), max_val.specialBuffer()), + LIBND4J_TYPES, INTEGER_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu b/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu index adb5a3ec4009..9593dfd59a79 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu @@ -18,104 +18,115 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 31.08.2018 // -#include #include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ static void histogramFixedWidthCuda( const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const X leftEdge, const X rightEdge) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong xLen, zLen, totalThreads, nbins; - __shared__ X binWidth, secondEdge, lastButOneEdge; +template +__global__ static void histogramFixedWidthCuda( + const void* vx, const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const X leftEdge, const X rightEdge) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - if (threadIdx.x == 0) { + __shared__ Nd4jLong xLen, zLen, totalThreads, nbins; + __shared__ X binWidth, secondEdge, lastButOneEdge; - xLen = shape::length(xShapeInfo); - nbins = shape::length(zShapeInfo); // nbins = zLen - totalThreads = gridDim.x * blockDim.x; + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + nbins = shape::length(zShapeInfo); // nbins = zLen + totalThreads = gridDim.x * blockDim.x; - binWidth = (rightEdge - leftEdge ) / nbins; - secondEdge = leftEdge + binWidth; - lastButOneEdge = rightEdge - binWidth; - } + binWidth = (rightEdge - leftEdge) / nbins; + secondEdge = leftEdge + binWidth; + lastButOneEdge = rightEdge - binWidth; + } - __syncthreads(); + __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong i = tid; i < xLen; i += totalThreads) { + for (Nd4jLong i = tid; i < xLen; i += totalThreads) { + const X value = x[shape::getIndexOffset(i, xShapeInfo)]; - const X value = x[shape::getIndexOffset(i, xShapeInfo)]; + Nd4jLong zIndex; - Nd4jLong zIndex; + if (value < secondEdge) + zIndex = 0; + else if (value >= lastButOneEdge) + zIndex = nbins - 1; + else + zIndex = static_cast((value - leftEdge) / binWidth); - if(value < secondEdge) - zIndex = 0; - else if(value >= lastButOneEdge) - zIndex = nbins - 1; - else - zIndex = static_cast((value - leftEdge) / binWidth); - - sd::math::atomics::nd4j_atomicAdd(&z[shape::getIndexOffset(zIndex, zShapeInfo)], 1); - } + sd::math::atomics::nd4j_atomicAdd( + &z[shape::getIndexOffset(zIndex, zShapeInfo)], 1); + } } /////////////////////////////////////////////////////////////////// -template -__host__ static void histogramFixedWidthCudaLauncher(const cudaStream_t *stream, const NDArray& input, const NDArray& range, NDArray& output) { - - const X leftEdge = range.e(0); - const X rightEdge = range.e(1); - - histogramFixedWidthCuda<<<256, 256, 1024, *stream>>>(input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftEdge, rightEdge); +template +__host__ static void histogramFixedWidthCudaLauncher(const cudaStream_t* stream, + const NDArray& input, + const NDArray& range, + NDArray& output) { + const X leftEdge = range.e(0); + const X rightEdge = range.e(1); + + histogramFixedWidthCuda<<<256, 256, 1024, *stream>>>( + input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), leftEdge, rightEdge); } //////////////////////////////////////////////////////////////////////// -void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, const NDArray& range, NDArray& output) { +void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, + const NDArray& range, NDArray& output) { + // firstly initialize output with zeros + output.nullify(); - // firstly initialize output with zeros - output.nullify(); + PointersManager manager(context, "histogramFixedWidth"); - PointersManager manager(context, "histogramFixedWidth"); + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), + histogramFixedWidthCudaLauncher, + (context->getCudaStream(), input, range, output), + LIBND4J_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogramFixedWidthCudaLauncher, (context->getCudaStream(), input, range, output), LIBND4J_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); + manager.synchronize(); } - // template -// __global__ static void copyBuffers(Nd4jLong* destination, void const* source, Nd4jLong* sourceShape, Nd4jLong bufferLength) { +// __global__ static void copyBuffers(Nd4jLong* destination, void const* +// source, Nd4jLong* sourceShape, Nd4jLong bufferLength) { // const auto tid = blockIdx.x * gridDim.x + threadIdx.x; // const auto step = gridDim.x * blockDim.x; // for (int t = tid; t < bufferLength; t += step) { -// destination[t] = reinterpret_cast(source)[shape::getIndexOffset(t, sourceShape)]; +// destination[t] = reinterpret_cast(source)[shape::getIndexOffset(t, sourceShape)]; // } // } // template -// __global__ static void returnBuffers(void* destination, Nd4jLong const* source, Nd4jLong* destinationShape, Nd4jLong bufferLength) { +// __global__ static void returnBuffers(void* destination, Nd4jLong const* +// source, Nd4jLong* destinationShape, Nd4jLong bufferLength) { // const auto tid = blockIdx.x * gridDim.x + threadIdx.x; // const auto step = gridDim.x * blockDim.x; // for (int t = tid; t < bufferLength; t += step) { -// reinterpret_cast(destination)[shape::getIndexOffset(t, destinationShape)] = source[t]; +// reinterpret_cast(destination)[shape::getIndexOffset(t, +// destinationShape)] = source[t]; // } // } // template -// static __global__ void histogramFixedWidthKernel(void* outputBuffer, Nd4jLong outputLength, void const* inputBuffer, Nd4jLong* inputShape, Nd4jLong inputLength, double const leftEdge, double binWidth, double secondEdge, double lastButOneEdge) { +// static __global__ void histogramFixedWidthKernel(void* outputBuffer, +// Nd4jLong outputLength, void const* inputBuffer, Nd4jLong* inputShape, +// Nd4jLong inputLength, double const leftEdge, double binWidth, double +// secondEdge, double lastButOneEdge) { // __shared__ T const* x; // __shared__ Nd4jLong* z; // output buffer @@ -131,7 +142,8 @@ void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, const // for(auto i = tid; i < inputLength; i += step) { // const T value = x[shape::getIndexOffset(i, inputShape)]; -// Nd4jLong currInd = static_cast((value - leftEdge) / binWidth); +// Nd4jLong currInd = static_cast((value - leftEdge) / +// binWidth); // if(value < secondEdge) // currInd = 0; @@ -141,9 +153,9 @@ void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, const // } // } - // template -// void histogramFixedWidth_(sd::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) { +// void histogramFixedWidth_(sd::LaunchContext * context, const NDArray& +// input, const NDArray& range, NDArray& output) { // const int nbins = output.lengthOf(); // auto stream = context->getCudaStream(); // // firstly initialize output with zeros @@ -161,16 +173,23 @@ void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, const // const double secondEdge = leftEdge + binWidth; // double lastButOneEdge = rightEdge - binWidth; // Nd4jLong* outputBuffer; -// cudaError_t err = cudaMalloc(&outputBuffer, output.lengthOf() * sizeof(Nd4jLong)); -// if (err != 0) -// throw cuda_exception::build("helpers::histogramFixedWidth: Cannot allocate memory for output", err); -// copyBuffers<<<256, 512, 8192, *stream>>>(outputBuffer, output.specialBuffer(), output.specialShapeInfo(), output.lengthOf()); -// histogramFixedWidthKernel<<<256, 512, 8192, *stream>>>(outputBuffer, output.lengthOf(), input.specialBuffer(), input.specialShapeInfo(), input.lengthOf(), leftEdge, binWidth, secondEdge, lastButOneEdge); -// returnBuffers<<<256, 512, 8192, *stream>>>(output.specialBuffer(), outputBuffer, output.specialShapeInfo(), output.lengthOf()); +// cudaError_t err = cudaMalloc(&outputBuffer, output.lengthOf() * +// sizeof(Nd4jLong)); if (err != 0) +// throw cuda_exception::build("helpers::histogramFixedWidth: Cannot +// allocate memory for output", err); +// copyBuffers<<<256, 512, 8192, *stream>>>(outputBuffer, +// output.specialBuffer(), output.specialShapeInfo(), +// output.lengthOf()); histogramFixedWidthKernel<<<256, 512, 8192, +// *stream>>>(outputBuffer, output.lengthOf(), input.specialBuffer(), +// input.specialShapeInfo(), input.lengthOf(), leftEdge, binWidth, +// secondEdge, lastButOneEdge); returnBuffers<<<256, 512, +// 8192, *stream>>>(output.specialBuffer(), outputBuffer, +// output.specialShapeInfo(), output.lengthOf()); // //cudaSyncStream(*stream); // err = cudaFree(outputBuffer); // if (err != 0) -// throw cuda_exception::build("helpers::histogramFixedWidth: Cannot deallocate memory for output buffer", err); +// throw cuda_exception::build("helpers::histogramFixedWidth: Cannot +// deallocate memory for output buffer", err); // output.tickWriteDevice(); // //#pragma omp parallel for schedule(guided) // // for(Nd4jLong i = 0; i < input.lengthOf(); ++i) { @@ -182,20 +201,27 @@ void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, const // // output.p(0, output.e(0) + 1); // // else if(value >= lastButOneEdge) // //#pragma omp critical -// // output.p(nbins-1, output.e(nbins-1) + 1); +// // output.p(nbins-1, output.e(nbins-1) + +// 1); // // else { -// // Nd4jLong currInd = static_cast((value - leftEdge) / binWidth); +// // Nd4jLong currInd = static_cast((value - leftEdge) +// / binWidth); // //#pragma omp critical -// // output.p(currInd, output.e(currInd) + 1); +// // output.p(currInd, output.e(currInd) + +// 1); // // } // // } // } -// void histogramFixedWidth(sd::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) { -// BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidth_, (context, input, range, output), LIBND4J_TYPES); +// void histogramFixedWidth(sd::LaunchContext * context, const NDArray& +// input, const NDArray& range, NDArray& output) { +// BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidth_, +// (context, input, range, output), LIBND4J_TYPES); // } -// BUILD_SINGLE_TEMPLATE(template void histogramFixedWidth_, (sd::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output), LIBND4J_TYPES); +// BUILD_SINGLE_TEMPLATE(template void histogramFixedWidth_, +// (sd::LaunchContext * context, const NDArray& input, const NDArray& range, +// NDArray& output), LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu b/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu index 08f5959e82c5..ad2992fb7df3 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu @@ -18,91 +18,104 @@ // Created by raver119 on 30.11.17. // -#include #include +#include namespace sd { namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// // input [bS, iC, iH, iW] is convoluted to output [bS, iC, kH, kW, oH, oW] template __global__ static void im2colCuda(const void *image, void *columns, - const Nd4jLong *imShapeInfo, const Nd4jLong *colShapeInfo, - const int sH, const int sW, - const int pH, const int pW, + const Nd4jLong *imShapeInfo, + const Nd4jLong *colShapeInfo, const int sH, + const int sW, const int pH, const int pW, const int dH, const int dW, const double zeroPadValD) { + T zeroPadVal = + static_cast(zeroPadValD); // Value to use when value is padding. + // Usually 0 but not always + const auto im = reinterpret_cast(image); + auto col = reinterpret_cast(columns); - T zeroPadVal = static_cast(zeroPadValD); //Value to use when value is padding. Usually 0 but not always - const auto im = reinterpret_cast(image); - auto col = reinterpret_cast(columns); - - __shared__ Nd4jLong colLen, iH, iW; - __shared__ int imRank, colRank, *sharedMem; + __shared__ Nd4jLong colLen, iH, iW; + __shared__ int imRank, colRank, *sharedMem; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - colRank = 6; - imRank = 4; + colRank = 6; + imRank = 4; - colLen = shape::length(colShapeInfo); + colLen = shape::length(colShapeInfo); - iH = imShapeInfo[3]; - iW = imShapeInfo[4]; - } - __syncthreads(); + iH = imShapeInfo[3]; + iW = imShapeInfo[4]; + } + __syncthreads(); - const auto colInd = threadIdx.x + blockIdx.x * blockDim.x; + const auto colInd = threadIdx.x + blockIdx.x * blockDim.x; - if(colInd >= colLen) - return; + if (colInd >= colLen) return; - auto coords = sharedMem + threadIdx.x * colRank; + auto coords = sharedMem + threadIdx.x * colRank; - shape::index2coords(colInd, colShapeInfo, coords); + shape::index2coords(colInd, colShapeInfo, coords); - const auto colOffset = shape::getOffset(colShapeInfo, coords); + const auto colOffset = shape::getOffset(colShapeInfo, coords); - coords[2] = (-pH + coords[2] * dH) + coords[4] * sH; // imH - coords[3] = (-pW + coords[3] * dW) + coords[5] * sW; // imW + coords[2] = (-pH + coords[2] * dH) + coords[4] * sH; // imH + coords[3] = (-pW + coords[3] * dW) + coords[5] * sW; // imW - if (static_cast(coords[2]) >= static_cast(iH) || static_cast(coords[3]) >= static_cast(iW)) - col[colOffset] = zeroPadVal; - else - col[colOffset] = im[shape::getOffset(imShapeInfo, coords)]; + if (static_cast(coords[2]) >= static_cast(iH) || + static_cast(coords[3]) >= static_cast(iW)) + col[colOffset] = zeroPadVal; + else + col[colOffset] = im[shape::getOffset(imShapeInfo, coords)]; } - ////////////////////////////////////////////////////////////////////////// template -static void im2colCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, sd::LaunchContext & context, const void *image, void *columns, const Nd4jLong *imShapeInfo, const Nd4jLong *colShapeInfo, int sH, int sW, int pH, int pW, int dH, int dW, double zeroPadVal) { - im2colCuda<<>>(image, columns, imShapeInfo, colShapeInfo, sH, sW, pH, pW, dH, dW, zeroPadVal); +static void im2colCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + sd::LaunchContext &context, const void *image, + void *columns, const Nd4jLong *imShapeInfo, + const Nd4jLong *colShapeInfo, int sH, int sW, + int pH, int pW, int dH, int dW, + double zeroPadVal) { + im2colCuda<<>>(image, columns, imShapeInfo, + colShapeInfo, sH, sW, pH, pW, dH, + dW, zeroPadVal); } ////////////////////////////////////////////////////////////////////////// -void im2col(sd::LaunchContext& context, const NDArray& image, NDArray& columns, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal) { - - PointersManager manager(&context, "im2col"); - - const int threadsPerBlock = 512; - const int blocksPerGrid = (columns.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({&columns}, {&image}); - BUILD_SINGLE_SELECTOR(columns.dataType(), im2colCudaLauncher, (blocksPerGrid, threadsPerBlock, context, image.specialBuffer(), columns.specialBuffer(), image.specialShapeInfo(), columns.specialShapeInfo(), sH, sW, pH, pW, dH, dW, arrZeroPadVal.e(0)), FLOAT_TYPES); - NDArray::registerSpecialUse({&columns}, {&image}); - - manager.synchronize(); +void im2col(sd::LaunchContext &context, const NDArray &image, NDArray &columns, + const int kH, const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, const int dW, + const NDArray &arrZeroPadVal) { + PointersManager manager(&context, "im2col"); + + const int threadsPerBlock = 512; + const int blocksPerGrid = + (columns.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&columns}, {&image}); + BUILD_SINGLE_SELECTOR(columns.dataType(), im2colCudaLauncher, + (blocksPerGrid, threadsPerBlock, context, + image.specialBuffer(), columns.specialBuffer(), + image.specialShapeInfo(), columns.specialShapeInfo(), + sH, sW, pH, pW, dH, dW, arrZeroPadVal.e(0)), + FLOAT_TYPES); + NDArray::registerSpecialUse({&columns}, {&image}); + + manager.synchronize(); } - - - - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu index 47319f10041e..0a5c4be6becf 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_draw_bounding_boxes.cu @@ -17,164 +17,193 @@ // // @author sgazeos@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { - typedef NDArray ColorTable_t; - static NDArray DefaultColorTable(int depth, sd::LaunchContext* context) { - //std::vector> colorTable; - const Nd4jLong kDefaultTableLength = 10; - const Nd4jLong kDefaultChannelLength = 4; - NDArray colorTable('c', {kDefaultTableLength, kDefaultChannelLength}, { - 1,1,0,1, // yellow - 0, 0, 1, 1, // 1: blue - 1, 0, 0, 1, // 2: red - 0, 1, 0, 1, // 3: lime - 0.5, 0, 0.5, 1, // 4: purple - 0.5, 0.5, 0, 1, // 5: olive - 0.5, 0, 0, 1, // 6: maroon - 0, 0, 0.5, 1, // 7: navy blue - 0, 1, 1, 1, // 8: aqua - 1, 0, 1, 1 // 9: fuchsia - }, DataType::FLOAT32, context); - - if (depth == 1) { - colorTable.assign(1.f); // all to white when black and white colors - } - return colorTable; - } - - template - static __global__ void drawBoundingBoxesKernel(T const* images, const Nd4jLong* imagesShape, - float const* boxes, const Nd4jLong* boxesShape, - float const* colorTable, const Nd4jLong* colorTableShape, - T* output, const Nd4jLong* outputShape, - Nd4jLong batchSize, Nd4jLong width, Nd4jLong height, - Nd4jLong channels, Nd4jLong boxSize, Nd4jLong colorTableLen) { - - for (auto batch = blockIdx.x; batch < (int)batchSize; batch += gridDim.x) { // loop by batch - for (auto boxIndex = 0; boxIndex < boxSize; ++boxIndex) { - // box with shape - //auto internalBox = &boxes[b * colorSetSize * 4 + c * 4];//(*boxes)(b, {0})(c, {0});//internalBoxes->at(c); - auto colorIndex = boxIndex % colorTableLen;//colorSet->at(c); -// auto rowStart = sd::math::nd4j_max(Nd4jLong (0), Nd4jLong ((height - 1) * internalBox[0])); -// auto rowEnd = sd::math::nd4j_min(Nd4jLong (height - 1), Nd4jLong ((height - 1) * internalBox[2])); -// auto colStart = sd::math::nd4j_max(Nd4jLong (0), Nd4jLong ((width - 1) * internalBox[1])); -// auto colEnd = sd::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong ((width - 1) * internalBox[3])); - Nd4jLong indices0[] = {batch, boxIndex, 0}; - Nd4jLong indices1[] = {batch, boxIndex, 1}; - Nd4jLong indices2[] = {batch, boxIndex, 2}; - Nd4jLong indices3[] = {batch, boxIndex, 3}; - auto rowStart = Nd4jLong ((height - 1) * boxes[shape::getOffset(boxesShape, indices0, 0)]); - auto rowStartBound = sd::math::nd4j_max(Nd4jLong (0), rowStart); - auto rowEnd = Nd4jLong ((height - 1) * boxes[shape::getOffset(boxesShape, indices2, 0)]); - auto rowEndBound = sd::math::nd4j_min(Nd4jLong (height - 1), rowEnd); - auto colStart = Nd4jLong ((width - 1) * boxes[shape::getOffset(boxesShape, indices1, 0)]); - auto colStartBound = sd::math::nd4j_max(Nd4jLong (0), colStart); - auto colEnd = Nd4jLong ((width - 1) * boxes[shape::getOffset(boxesShape, indices3, 0)]); - auto colEndBound = sd::math::nd4j_min(Nd4jLong(width - 1), colEnd); - if (rowStart > rowEnd || colStart > colEnd) { -// printf("helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is inverted " -// "and will not be drawn\n", rowStart, colStart, rowEnd, colEnd); - continue; - } - if (rowStart >= height || rowEnd < 0 || colStart >= width || - colEnd < 0) { -// printf("helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is completely " -// "outside the image and not be drawn\n", rowStart, colStart, rowEnd, colEnd); - continue; - } - - // Draw upper line - if (rowStart >= 0) { - for (auto j = colStartBound + threadIdx.x; j <= colEndBound; j += blockDim.x) - for (auto c = 0; c < channels; c++) { - Nd4jLong zPos[] = {batch, rowStart, j, c}; - Nd4jLong cPos[] = {colorIndex, c}; - auto cIndex = shape::getOffset(colorTableShape, cPos, 0); - auto zIndex = shape::getOffset(outputShape, zPos, 0); - output[zIndex] = (T)colorTable[cIndex]; - } - } - // Draw bottom line. - if (rowEnd < height) { - for (auto j = colStartBound + threadIdx.x; j <= colEndBound; j += blockDim.x) - for (auto c = 0; c < channels; c++) { - Nd4jLong zPos[] = {batch, rowEnd, j, c}; - Nd4jLong cPos[] = {colorIndex, c}; - auto cIndex = shape::getOffset(colorTableShape, cPos, 0); - auto zIndex = shape::getOffset(outputShape, zPos, 0); - output[zIndex] = (T)colorTable[cIndex]; - } - } +typedef NDArray ColorTable_t; +static NDArray DefaultColorTable(int depth, sd::LaunchContext* context) { + // std::vector> colorTable; + const Nd4jLong kDefaultTableLength = 10; + const Nd4jLong kDefaultChannelLength = 4; + NDArray colorTable('c', {kDefaultTableLength, kDefaultChannelLength}, + { + 1, 1, 0, 1, // yellow + 0, 0, 1, 1, // 1: blue + 1, 0, 0, 1, // 2: red + 0, 1, 0, 1, // 3: lime + 0.5, 0, 0.5, 1, // 4: purple + 0.5, 0.5, 0, 1, // 5: olive + 0.5, 0, 0, 1, // 6: maroon + 0, 0, 0.5, 1, // 7: navy blue + 0, 1, 1, 1, // 8: aqua + 1, 0, 1, 1 // 9: fuchsia + }, + DataType::FLOAT32, context); - // Draw left line. - if (colStart >= 0) { - for (auto i = rowStartBound + threadIdx.x; i <= rowEndBound; i += blockDim.x) - for (auto c = 0; c < channels; c++) { - Nd4jLong zPos[] = {batch, i, colStart, c}; - Nd4jLong cPos[] = {colorIndex, c}; - auto cIndex = shape::getOffset(colorTableShape, cPos, 0); - auto zIndex = shape::getOffset(outputShape, zPos, 0); - output[zIndex] = (T)colorTable[cIndex]; - } - } - // Draw right line. - if (colEnd < width) { - for (auto i = rowStartBound + threadIdx.x; i <= rowEndBound; i += blockDim.x) - for (auto c = 0; c < channels; c++) { - Nd4jLong zPos[] = {batch, i, colEnd, c}; - Nd4jLong cPos[] = {colorIndex, c}; - auto cIndex = shape::getOffset(colorTableShape, cPos, 0); - auto zIndex = shape::getOffset(outputShape, zPos, 0); - output[zIndex] = (T)colorTable[cIndex]; - } - } - } - } + if (depth == 1) { + colorTable.assign(1.f); // all to white when black and white colors + } + return colorTable; +} - } +template +static __global__ void drawBoundingBoxesKernel( + T const* images, const Nd4jLong* imagesShape, float const* boxes, + const Nd4jLong* boxesShape, float const* colorTable, + const Nd4jLong* colorTableShape, T* output, const Nd4jLong* outputShape, + Nd4jLong batchSize, Nd4jLong width, Nd4jLong height, Nd4jLong channels, + Nd4jLong boxSize, Nd4jLong colorTableLen) { + for (auto batch = blockIdx.x; batch < (int)batchSize; + batch += gridDim.x) { // loop by batch + for (auto boxIndex = 0; boxIndex < boxSize; ++boxIndex) { + // box with shape + // auto internalBox = &boxes[b * colorSetSize * 4 + c * 4];//(*boxes)(b, + // {0})(c, {0});//internalBoxes->at(c); + auto colorIndex = boxIndex % colorTableLen; // colorSet->at(c); + // auto rowStart = sd::math::nd4j_max(Nd4jLong (0), + // Nd4jLong ((height - 1) * internalBox[0])); auto rowEnd = + // sd::math::nd4j_min(Nd4jLong (height - 1), Nd4jLong + // ((height - 1) * internalBox[2])); auto colStart = + // sd::math::nd4j_max(Nd4jLong (0), Nd4jLong ((width - 1) * + // internalBox[1])); auto colEnd = + // sd::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong ((width + // - 1) * internalBox[3])); + Nd4jLong indices0[] = {batch, boxIndex, 0}; + Nd4jLong indices1[] = {batch, boxIndex, 1}; + Nd4jLong indices2[] = {batch, boxIndex, 2}; + Nd4jLong indices3[] = {batch, boxIndex, 3}; + auto rowStart = Nd4jLong( + (height - 1) * boxes[shape::getOffset(boxesShape, indices0, 0)]); + auto rowStartBound = sd::math::nd4j_max(Nd4jLong(0), rowStart); + auto rowEnd = Nd4jLong((height - 1) * + boxes[shape::getOffset(boxesShape, indices2, 0)]); + auto rowEndBound = sd::math::nd4j_min(Nd4jLong(height - 1), rowEnd); + auto colStart = Nd4jLong( + (width - 1) * boxes[shape::getOffset(boxesShape, indices1, 0)]); + auto colStartBound = sd::math::nd4j_max(Nd4jLong(0), colStart); + auto colEnd = Nd4jLong((width - 1) * + boxes[shape::getOffset(boxesShape, indices3, 0)]); + auto colEndBound = sd::math::nd4j_min(Nd4jLong(width - 1), colEnd); + if (rowStart > rowEnd || colStart > colEnd) { + // printf("helpers::drawBoundingBoxesFunctor: + // Bounding box (%lld, %lld, %lld, %lld) is inverted + // " + // "and will not be drawn\n", rowStart, + // colStart, rowEnd, colEnd); + continue; + } + if (rowStart >= height || rowEnd < 0 || colStart >= width || colEnd < 0) { + // printf("helpers::drawBoundingBoxesFunctor: + // Bounding box (%lld, %lld, %lld, %lld) is + // completely " + // "outside the image and not be + // drawn\n", rowStart, colStart, rowEnd, + // colEnd); + continue; + } - template - void drawBoundingBoxesH(sd::LaunchContext* context, NDArray const* images, NDArray const* boxes, NDArray const* colors, NDArray* output) { - auto batchSize = images->sizeAt(0); - auto height = images->sizeAt(1); - auto width = images->sizeAt(2); - auto channels = images->sizeAt(3); - auto stream = context->getCudaStream(); - auto boxSize = boxes->sizeAt(1); - NDArray colorsTable = DefaultColorTable(channels, context); - if ((colors != nullptr && colors->lengthOf() > 0)) { - colorsTable = *colors; - } + // Draw upper line + if (rowStart >= 0) { + for (auto j = colStartBound + threadIdx.x; j <= colEndBound; + j += blockDim.x) + for (auto c = 0; c < channels; c++) { + Nd4jLong zPos[] = {batch, rowStart, j, c}; + Nd4jLong cPos[] = {colorIndex, c}; + auto cIndex = shape::getOffset(colorTableShape, cPos, 0); + auto zIndex = shape::getOffset(outputShape, zPos, 0); + output[zIndex] = (T)colorTable[cIndex]; + } + } + // Draw bottom line. + if (rowEnd < height) { + for (auto j = colStartBound + threadIdx.x; j <= colEndBound; + j += blockDim.x) + for (auto c = 0; c < channels; c++) { + Nd4jLong zPos[] = {batch, rowEnd, j, c}; + Nd4jLong cPos[] = {colorIndex, c}; + auto cIndex = shape::getOffset(colorTableShape, cPos, 0); + auto zIndex = shape::getOffset(outputShape, zPos, 0); + output[zIndex] = (T)colorTable[cIndex]; + } + } - auto imagesBuf = images->getDataBuffer()->specialAsT(); - auto boxesBuf = boxes->getDataBuffer()->specialAsT(); // boxes should be float32 - auto colorsTableBuf = colorsTable.getDataBuffer()->specialAsT(); // color table is float32 - auto outputBuf = output->dataBuffer()->specialAsT(); - drawBoundingBoxesKernel<<<128, 128, 1024, *stream>>>(imagesBuf, images->specialShapeInfo(), - boxesBuf, boxes->specialShapeInfo(), colorsTableBuf, colorsTable.specialShapeInfo(), - outputBuf, output->specialShapeInfo(), batchSize, width, height, channels, boxSize, colorsTable.lengthOf()); + // Draw left line. + if (colStart >= 0) { + for (auto i = rowStartBound + threadIdx.x; i <= rowEndBound; + i += blockDim.x) + for (auto c = 0; c < channels; c++) { + Nd4jLong zPos[] = {batch, i, colStart, c}; + Nd4jLong cPos[] = {colorIndex, c}; + auto cIndex = shape::getOffset(colorTableShape, cPos, 0); + auto zIndex = shape::getOffset(outputShape, zPos, 0); + output[zIndex] = (T)colorTable[cIndex]; + } + } + // Draw right line. + if (colEnd < width) { + for (auto i = rowStartBound + threadIdx.x; i <= rowEndBound; + i += blockDim.x) + for (auto c = 0; c < channels; c++) { + Nd4jLong zPos[] = {batch, i, colEnd, c}; + Nd4jLong cPos[] = {colorIndex, c}; + auto cIndex = shape::getOffset(colorTableShape, cPos, 0); + auto zIndex = shape::getOffset(outputShape, zPos, 0); + output[zIndex] = (T)colorTable[cIndex]; + } + } } + } +} - void drawBoundingBoxesFunctor(sd::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output) { - // images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or RGBA (last dim = 4) channel set - // boxes - batch of 2D bounds with last dim (y_start, x_start, y_end, x_end) to compute i and j as - // floor((height - 1 ) * y_start) => rowStart, floor((height - 1) * y_end) => rowEnd - // floor((width - 1 ) * x_start) => colStart, floor((width - 1) * x_end) => colEnd - // height = images->sizeAt(1), width = images->sizeAt(2) - // colors - colors for each box given - // set up color for each box as frame - NDArray::prepareSpecialUse({output}, {images, boxes, colors}); - output->assign(images); - BUILD_SINGLE_SELECTOR(output->dataType(), drawBoundingBoxesH, (context, images, boxes, colors, output), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {images, boxes, colors}); - } +template +void drawBoundingBoxesH(sd::LaunchContext* context, NDArray const* images, + NDArray const* boxes, NDArray const* colors, + NDArray* output) { + auto batchSize = images->sizeAt(0); + auto height = images->sizeAt(1); + auto width = images->sizeAt(2); + auto channels = images->sizeAt(3); + auto stream = context->getCudaStream(); + auto boxSize = boxes->sizeAt(1); + NDArray colorsTable = DefaultColorTable(channels, context); + if ((colors != nullptr && colors->lengthOf() > 0)) { + colorsTable = *colors; + } + auto imagesBuf = images->getDataBuffer()->specialAsT(); + auto boxesBuf = + boxes->getDataBuffer()->specialAsT(); // boxes should be float32 + auto colorsTableBuf = colorsTable.getDataBuffer() + ->specialAsT(); // color table is float32 + auto outputBuf = output->dataBuffer()->specialAsT(); + drawBoundingBoxesKernel<<<128, 128, 1024, *stream>>>( + imagesBuf, images->specialShapeInfo(), boxesBuf, + boxes->specialShapeInfo(), colorsTableBuf, colorsTable.specialShapeInfo(), + outputBuf, output->specialShapeInfo(), batchSize, width, height, channels, + boxSize, colorsTable.lengthOf()); } + +void drawBoundingBoxesFunctor(sd::LaunchContext* context, NDArray* images, + NDArray* boxes, NDArray* colors, + NDArray* output) { + // images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or + // RGBA (last dim = 4) channel set boxes - batch of 2D bounds with last dim + // (y_start, x_start, y_end, x_end) to compute i and j as floor((height - 1 ) + // * y_start) => rowStart, floor((height - 1) * y_end) => rowEnd floor((width + // - 1 ) * x_start) => colStart, floor((width - 1) * x_end) => colEnd height = + // images->sizeAt(1), width = images->sizeAt(2) colors - colors for each box + // given set up color for each box as frame + NDArray::prepareSpecialUse({output}, {images, boxes, colors}); + output->assign(images); + BUILD_SINGLE_SELECTOR(output->dataType(), drawBoundingBoxesH, + (context, images, boxes, colors, output), FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {images, boxes, colors}); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu index d483f87b36d8..c85be400238c 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu @@ -33,51 +33,50 @@ limitations under the License. // @author George A. Shulinok // -#include #include +#include namespace sd { namespace ops { namespace helpers { - struct BilinearInterpolationData { - Nd4jLong bottomIndex; // Lower source index used in the interpolation - Nd4jLong topIndex; // Upper source index used in the interpolation - // 1-D linear iterpolation scale (see: - // https://en.wikipedia.org/wiki/Bilinear_interpolation) - double interpolarValue; - }; +struct BilinearInterpolationData { + Nd4jLong bottomIndex; // Lower source index used in the interpolation + Nd4jLong topIndex; // Upper source index used in the interpolation + // 1-D linear iterpolation scale (see: + // https://en.wikipedia.org/wiki/Bilinear_interpolation) + double interpolarValue; +}; // Older incorrect scaling method that causes all resizes to have a slight // translation leading to inconsistent results. For example, a flip then a // resize gives different results then a resize then a flip. - struct LegacyScaler { - _CUDA_HD LegacyScaler(){}; - inline _CUDA_HD float operator()(const int x, const float scale) const { - return static_cast(x) * scale; - } - }; +struct LegacyScaler { + _CUDA_HD LegacyScaler(){}; + inline _CUDA_HD float operator()(const int x, const float scale) const { + return static_cast(x) * scale; + } +}; // Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the // floating point coordinates of the top,left pixel is 0.5,0.5. - struct HalfPixelScaler { - _CUDA_HD HalfPixelScaler(){}; - inline _CUDA_HD float operator()(const int x, const float scale) const { - // Note that we subtract 0.5 from the return value, as the existing bilinear - // sampling code etc assumes pixels are in the old coordinate system. - return (static_cast(x) + 0.5f) * scale - 0.5f; - } - }; - - - // Utility functions - // calculateResizeScale determines the float scaling factor. - inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize, - bool alignCorners) { - return (alignCorners && outSize > 1) - ? (inSize - 1) / static_cast(outSize - 1) - : inSize / static_cast(outSize); - } +struct HalfPixelScaler { + _CUDA_HD HalfPixelScaler(){}; + inline _CUDA_HD float operator()(const int x, const float scale) const { + // Note that we subtract 0.5 from the return value, as the existing bilinear + // sampling code etc assumes pixels are in the old coordinate system. + return (static_cast(x) + 0.5f) * scale - 0.5f; + } +}; + +// Utility functions +// calculateResizeScale determines the float scaling factor. +inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize, + bool alignCorners) { + return (alignCorners && outSize > 1) + ? (inSize - 1) / static_cast(outSize - 1) + : inSize / static_cast(outSize); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // computeInterpolationWeights kernel @@ -86,1233 +85,1486 @@ namespace helpers { // scale - input scale // interporationData - result // - template - static __global__ void computeInterpolationWeights(Nd4jLong outSize, - Nd4jLong inSize, - double scale, - Nd4jLong channels, - BilinearInterpolationData* interpolationData) { - interpolationData[outSize].bottomIndex = 0; - interpolationData[outSize].topIndex = 0; - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - Scaler scaler; - for (Nd4jLong i = outSize - tid; i >= 0; i -= step) { - double in = scaler(i, scale); -// interpolationData[i].bottomIndex = static_cast(in); -// interpolationData[i].topIndex = sd::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1); -// interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex; - double const in_f = sd::math::p_floor(in); - double const in_c = sd::math::p_ceil(in); - interpolationData[i].bottomIndex = sd::math::nd4j_max(static_cast(in_f), (Nd4jLong)0LL);//static_cast(in); - interpolationData[i].topIndex = sd::math::nd4j_min(static_cast(in_c), inSize - 1); - interpolationData[i].interpolarValue = in - in_f; - - if (channels) { - math::atomics::nd4j_atomicMul(&interpolationData[i].bottomIndex, channels); - math::atomics::nd4j_atomicMul(&interpolationData[i].topIndex, channels); - } - } +template +static __global__ void computeInterpolationWeights( + Nd4jLong outSize, Nd4jLong inSize, double scale, Nd4jLong channels, + BilinearInterpolationData* interpolationData) { + interpolationData[outSize].bottomIndex = 0; + interpolationData[outSize].topIndex = 0; + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + Scaler scaler; + for (Nd4jLong i = outSize - tid; i >= 0; i -= step) { + double in = scaler(i, scale); + // interpolationData[i].bottomIndex = static_cast(in); + // interpolationData[i].topIndex = + // sd::math::nd4j_min(interpolationData[i].bottomIndex + 1, + // inSize - 1); interpolationData[i].interpolarValue = in - + // interpolationData[i].bottomIndex; + double const in_f = sd::math::p_floor(in); + double const in_c = sd::math::p_ceil(in); + interpolationData[i].bottomIndex = + sd::math::nd4j_max(static_cast(in_f), + (Nd4jLong)0LL); // static_cast(in); + interpolationData[i].topIndex = + sd::math::nd4j_min(static_cast(in_c), inSize - 1); + interpolationData[i].interpolarValue = in - in_f; + + if (channels) { + math::atomics::nd4j_atomicMul(&interpolationData[i].bottomIndex, + channels); + math::atomics::nd4j_atomicMul(&interpolationData[i].topIndex, channels); } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resize image with bilinear interpolation algorithm // - static void resizeImage(sd::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, - Nd4jLong outWidth, Nd4jLong channels, - BilinearInterpolationData* xs_, - BilinearInterpolationData* ys_, - NDArray* output); +static void resizeImage(sd::LaunchContext* context, NDArray const* images, + Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, + Nd4jLong outHeight, Nd4jLong outWidth, + Nd4jLong channels, BilinearInterpolationData* xs_, + BilinearInterpolationData* ys_, NDArray* output); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resize image with bilinear interpolation algorithm kernel // - template - static __global__ void resizeImageKernel(T const* input, Nd4jLong const* inputShape, Z* outputYptr, - Nd4jLong const* outputShape, Nd4jLong batchSize, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, - Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues, - BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) { - - for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x ) { // blockIdx.x as batch index - auto pX = input + batch * inBatchNumValues; - for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) { - const T* ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize; - const T* ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize; - double yVal = ys_[y].interpolarValue; - auto pZ = outputYptr + (batch * outHeight + y) * outRowSize; - for (Nd4jLong x = 0; x < outWidth; x++) { - auto xsBottom = xs_[x].bottomIndex; - auto xsTop = xs_[x].topIndex; - auto xVal = xs_[x].interpolarValue; - // process interpolation for all channels - for (int c = 0; c < channels; c++) { - Z topLeft(ys_input_lower_ptr[xsBottom + c]); - Z topRight(ys_input_lower_ptr[xsTop + c]); - Z bottomLeft(ys_input_upper_ptr[xsBottom + c]); - Z bottomRight(ys_input_upper_ptr[xsTop + c]); - Z top = topLeft + (topRight - topLeft) * xVal; - Z bottom = bottomLeft + (bottomRight - bottomLeft) * xVal; - Z resVal = Z(top + (bottom - top) * yVal); - pZ[x * channels + c] = resVal; - } - } - } +template +static __global__ void resizeImageKernel( + T const* input, Nd4jLong const* inputShape, Z* outputYptr, + Nd4jLong const* outputShape, Nd4jLong batchSize, Nd4jLong outWidth, + Nd4jLong outHeight, Nd4jLong channels, Nd4jLong inRowSize, + Nd4jLong outRowSize, Nd4jLong inBatchNumValues, + BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) { + for (auto batch = blockIdx.x; batch < batchSize; + batch += gridDim.x) { // blockIdx.x as batch index + auto pX = input + batch * inBatchNumValues; + for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) { + const T* ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize; + const T* ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize; + double yVal = ys_[y].interpolarValue; + auto pZ = outputYptr + (batch * outHeight + y) * outRowSize; + for (Nd4jLong x = 0; x < outWidth; x++) { + auto xsBottom = xs_[x].bottomIndex; + auto xsTop = xs_[x].topIndex; + auto xVal = xs_[x].interpolarValue; + // process interpolation for all channels + for (int c = 0; c < channels; c++) { + Z topLeft(ys_input_lower_ptr[xsBottom + c]); + Z topRight(ys_input_lower_ptr[xsTop + c]); + Z bottomLeft(ys_input_upper_ptr[xsBottom + c]); + Z bottomRight(ys_input_upper_ptr[xsTop + c]); + Z top = topLeft + (topRight - topLeft) * xVal; + Z bottom = bottomLeft + (bottomRight - bottomLeft) * xVal; + Z resVal = Z(top + (bottom - top) * yVal); + pZ[x * channels + c] = resVal; } + } } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resize image with - template - static void resizeImage_(sd::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, - Nd4jLong outWidth, Nd4jLong channels, - BilinearInterpolationData* xs_, - BilinearInterpolationData* ys_, - NDArray* output) { - Nd4jLong inRowSize = inWidth * channels; - Nd4jLong inBatchNumValues = inHeight * inRowSize; - Nd4jLong outRowSize = outWidth * channels; - auto stream = context->getCudaStream(); - T const* pInput = images->getDataBuffer()->specialAsT(); //reinterpret_cast(images->specialBuffer()); // this works only with 'c' direction - F* pOutput = output->dataBuffer()->specialAsT();//reinterpret_cast(output->specialBuffer()); - dim3 batchSizeBlock(batchSize, 1, 1); - dim3 pictureBlock(outHeight, outWidth, channels); - resizeImageKernel<<<256, 256, 256, *stream>>>(pInput, images->specialShapeInfo(), pOutput, - output->specialShapeInfo(), batchSize, outWidth, outHeight, channels, inRowSize, outRowSize, - inBatchNumValues, xs_, ys_); - - auto err = cudaStreamSynchronize(*stream); - if (err != 0) { - throw cuda_exception::build("helpers::resizeImage_: Cannot synchronize kernel execution", err); - } - } +template +static void resizeImage_(sd::LaunchContext* context, NDArray const* images, + Nd4jLong batchSize, Nd4jLong inHeight, + Nd4jLong inWidth, Nd4jLong outHeight, + Nd4jLong outWidth, Nd4jLong channels, + BilinearInterpolationData* xs_, + BilinearInterpolationData* ys_, NDArray* output) { + Nd4jLong inRowSize = inWidth * channels; + Nd4jLong inBatchNumValues = inHeight * inRowSize; + Nd4jLong outRowSize = outWidth * channels; + auto stream = context->getCudaStream(); + T const* pInput = images->getDataBuffer() + ->specialAsT(); // reinterpret_cast(images->specialBuffer()); // + // this works only with 'c' direction + F* pOutput = + output->dataBuffer() + ->specialAsT(); // reinterpret_cast(output->specialBuffer()); + dim3 batchSizeBlock(batchSize, 1, 1); + dim3 pictureBlock(outHeight, outWidth, channels); + resizeImageKernel<<<256, 256, 256, *stream>>>( + pInput, images->specialShapeInfo(), pOutput, output->specialShapeInfo(), + batchSize, outWidth, outHeight, channels, inRowSize, outRowSize, + inBatchNumValues, xs_, ys_); + + auto err = cudaStreamSynchronize(*stream); + if (err != 0) { + throw cuda_exception::build( + "helpers::resizeImage_: Cannot synchronize kernel execution", err); + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static int resizeBilinearFunctor_(sd::LaunchContext* context, NDArray const* images, int const width, - int const height, bool const alignCorners, bool const halfPixelCenter, NDArray* output) { - const Nd4jLong batchSize = images->sizeAt(0); - const Nd4jLong inHeight = images->sizeAt(1); - const Nd4jLong inWidth = images->sizeAt(2); - const Nd4jLong channels = images->sizeAt(3); - - const Nd4jLong outHeight = output->sizeAt(1); - const Nd4jLong outWidth = output->sizeAt(2); - - // Handle no-op resizes efficiently. - if (outHeight == inHeight && outWidth == inWidth) { - output->assign(images); - return ND4J_STATUS_OK; - } - - float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners); - float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners); - - BilinearInterpolationData* xs_;// = xs.data(); - BilinearInterpolationData* ys_;// = xs.data(); - - cudaError_t err = cudaMalloc(&xs_, sizeof(BilinearInterpolationData) * (outWidth + 1)); - if (err != 0) { - throw cuda_exception::build("helpers::resize_image: Cannot allocate memory for vertical parts rectangulars", err); - } - - err = cudaMalloc(&ys_, sizeof(BilinearInterpolationData) * (outHeight + 1)); - if (err != 0) { - throw cuda_exception::build("helpers::resize_image: Cannot allocate memory for horizontal parts rectangulars", err); - } - auto stream = context->getCudaStream(); - // Compute the cached interpolation weights on the x and y dimensions. - if (halfPixelCenter) { - computeInterpolationWeights < - HalfPixelScaler ><<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_); - computeInterpolationWeights < - HalfPixelScaler ><<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_); - } - else { - computeInterpolationWeights < - LegacyScaler ><<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_); - computeInterpolationWeights < - LegacyScaler ><<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_); - } - printf("Input is %dx%d, Output is %dx%d\n", inHeight, inWidth, outHeight, outWidth); - NDArray::prepareSpecialUse({output}, {images}); - resizeImage_(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output); - err = cudaStreamSynchronize(*stream); - NDArray::registerSpecialUse({output}, {images}); - - err = cudaFree(xs_); - if (err != 0) { - throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for vertical parts rectangulars", err); - } - - err = cudaFree(ys_); - if (err != 0) { - throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for horizontical parts rectangulars", err); - } - - return Status::OK(); - } +template +static int resizeBilinearFunctor_(sd::LaunchContext* context, + NDArray const* images, int const width, + int const height, bool const alignCorners, + bool const halfPixelCenter, NDArray* output) { + const Nd4jLong batchSize = images->sizeAt(0); + const Nd4jLong inHeight = images->sizeAt(1); + const Nd4jLong inWidth = images->sizeAt(2); + const Nd4jLong channels = images->sizeAt(3); + + const Nd4jLong outHeight = output->sizeAt(1); + const Nd4jLong outWidth = output->sizeAt(2); + + // Handle no-op resizes efficiently. + if (outHeight == inHeight && outWidth == inWidth) { + output->assign(images); + return ND4J_STATUS_OK; + } + + float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners); + float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners); + + BilinearInterpolationData* xs_; // = xs.data(); + BilinearInterpolationData* ys_; // = xs.data(); + + cudaError_t err = + cudaMalloc(&xs_, sizeof(BilinearInterpolationData) * (outWidth + 1)); + if (err != 0) { + throw cuda_exception::build( + "helpers::resize_image: Cannot allocate memory for vertical parts " + "rectangulars", + err); + } + + err = cudaMalloc(&ys_, sizeof(BilinearInterpolationData) * (outHeight + 1)); + if (err != 0) { + throw cuda_exception::build( + "helpers::resize_image: Cannot allocate memory for horizontal parts " + "rectangulars", + err); + } + auto stream = context->getCudaStream(); + // Compute the cached interpolation weights on the x and y dimensions. + if (halfPixelCenter) { + computeInterpolationWeights + <<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_); + computeInterpolationWeights<<<256, 512, 512, *stream>>>( + outWidth, inWidth, widthScale, channels, xs_); + } else { + computeInterpolationWeights + <<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_); + computeInterpolationWeights<<<256, 512, 512, *stream>>>( + outWidth, inWidth, widthScale, channels, xs_); + } + printf("Input is %dx%d, Output is %dx%d\n", inHeight, inWidth, outHeight, + outWidth); + NDArray::prepareSpecialUse({output}, {images}); + resizeImage_(context, images, batchSize, inHeight, inWidth, outHeight, + outWidth, channels, xs_, ys_, output); + err = cudaStreamSynchronize(*stream); + NDArray::registerSpecialUse({output}, {images}); + + err = cudaFree(xs_); + if (err != 0) { + throw cuda_exception::build( + "helpers::resize_image: Cannot deallocate memory for vertical parts " + "rectangulars", + err); + } + + err = cudaFree(ys_); + if (err != 0) { + throw cuda_exception::build( + "helpers::resize_image: Cannot deallocate memory for horizontical " + "parts rectangulars", + err); + } + + return Status::OK(); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resize by interpolation nearest neighbor algorithm kernel // - template - static __global__ void resizeNeighborKernel(T const* input, Nd4jLong const* inputShape, T* output, Nd4jLong const* outputShape, - Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool alignCorners, bool halfPixelCenters) { - - //for (int b = blockIdx.x; b < batchSize; b += gridDim.x) - if (blockIdx.x < batchSize) - { - auto b = blockIdx.x; - for (int y = threadIdx.x; y < outHeight; y += blockDim.x) { - auto posY = alignCorners ? static_cast(sd::math::p_round(halfPixelCenters?((float)y + 0.5f) * heightScale:(float)y * heightScale)) : static_cast(sd::math::p_floor( - halfPixelCenters?((float)y + 0.5f) * heightScale:(float)y * heightScale)); - Nd4jLong inY = sd::math::nd4j_min(posY, inHeight - 1); - if (halfPixelCenters) { - inY = sd::math::nd4j_max(0LL, inY); - } - - for (int x = threadIdx.y; x < outWidth; x += blockDim.y) { - auto posX = alignCorners ? static_cast(sd::math::p_round(halfPixelCenters?((float)x + 0.5f) * widthScale:(float)x * widthScale)) : static_cast(sd::math::p_floor( - halfPixelCenters?((float)x + 0.5f) * widthScale:(float)x * widthScale)); - Nd4jLong inX = sd::math::nd4j_min(posX, inWidth - 1); - if (halfPixelCenters) { - inX = sd::math::nd4j_max(0LL, inX); - } - - auto start = blockIdx.z * blockDim.z + threadIdx.z; - auto step = blockDim.z * gridDim.z; - - for (Nd4jLong e = start; e < channels; e += step) { - Nd4jLong posX[] = {b, inY, inX, e}; - Nd4jLong posZ[] = {b, y, x, e}; - auto xIndex = shape::getOffset(inputShape, posX); - auto zIndex = shape::getOffset(outputShape, posZ); - output[zIndex] = input[xIndex]; - } - } - } +template +static __global__ void resizeNeighborKernel( + T const* input, Nd4jLong const* inputShape, T* output, + Nd4jLong const* outputShape, Nd4jLong batchSize, Nd4jLong inWidth, + Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, + double widthScale, double heightScale, bool alignCorners, + bool halfPixelCenters) { + // for (int b = blockIdx.x; b < batchSize; b += gridDim.x) + if (blockIdx.x < batchSize) { + auto b = blockIdx.x; + for (int y = threadIdx.x; y < outHeight; y += blockDim.x) { + auto posY = alignCorners + ? static_cast(sd::math::p_round( + halfPixelCenters ? ((float)y + 0.5f) * heightScale + : (float)y * heightScale)) + : static_cast(sd::math::p_floor( + halfPixelCenters ? ((float)y + 0.5f) * heightScale + : (float)y * heightScale)); + Nd4jLong inY = sd::math::nd4j_min(posY, inHeight - 1); + if (halfPixelCenters) { + inY = sd::math::nd4j_max(0LL, inY); + } + + for (int x = threadIdx.y; x < outWidth; x += blockDim.y) { + auto posX = alignCorners + ? static_cast(sd::math::p_round( + halfPixelCenters ? ((float)x + 0.5f) * widthScale + : (float)x * widthScale)) + : static_cast(sd::math::p_floor( + halfPixelCenters ? ((float)x + 0.5f) * widthScale + : (float)x * widthScale)); + Nd4jLong inX = sd::math::nd4j_min(posX, inWidth - 1); + if (halfPixelCenters) { + inX = sd::math::nd4j_max(0LL, inX); } + auto start = blockIdx.z * blockDim.z + threadIdx.z; + auto step = blockDim.z * gridDim.z; + + for (Nd4jLong e = start; e < channels; e += step) { + Nd4jLong posX[] = {b, inY, inX, e}; + Nd4jLong posZ[] = {b, y, x, e}; + auto xIndex = shape::getOffset(inputShape, posX); + auto zIndex = shape::getOffset(outputShape, posZ); + output[zIndex] = input[xIndex]; + } + } } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resizeNeighborFunctor - main algorithm by nearest neighbor // - template - int resizeNeighborFunctor_(sd::LaunchContext* context, NDArray const* images, int const width, int const height, - bool const alignCorners, bool const halfPixelCenters, NDArray* output) { - const Nd4jLong batchSize = images->sizeAt(0); - const Nd4jLong inHeight = images->sizeAt(1); - const Nd4jLong inWidth = images->sizeAt(2); - const Nd4jLong channels = images->sizeAt(3); - - const Nd4jLong outHeight = output->sizeAt(1); - const Nd4jLong outWidth = output->sizeAt(2); - - // Handle no-op resizes efficiently. - if (outHeight == inHeight && outWidth == inWidth) { - output->assign(images); - return ND4J_STATUS_OK; - } - -// if ((alignCorners && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (alignCorners && outHeight < 2) || -// (alignCorners && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) { -// // wrong input data -// nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", ""); -// return ND4J_STATUS_BAD_ARGUMENTS; -// } -// float heightScale = alignCorners ? (inHeight - 1.f) / float(outHeight - 1.f) : (inHeight / float(outHeight)); -// float widthScale = alignCorners ? (inWidth - 1.f) / float(outWidth - 1.f) : (inWidth / float(outWidth)); - float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners); - float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners); - - auto imagesBuffer = images->getDataBuffer()->specialAsT();//reinterpret_cast(images->specialBuffer()); - auto outputBuffer = output->dataBuffer()->specialAsT();//reinterpret_cast(output->specialBuffer()); - auto stream = context->getCudaStream(); - - NDArray::prepareSpecialUse({output}, {images}); - resizeNeighborKernel<<>>(imagesBuffer, images->specialShapeInfo(), outputBuffer, output->specialShapeInfo(), - batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, alignCorners, halfPixelCenters); - NDArray::registerSpecialUse({output}, {images}); - - return Status::OK(); - } +template +int resizeNeighborFunctor_(sd::LaunchContext* context, NDArray const* images, + int const width, int const height, + bool const alignCorners, bool const halfPixelCenters, + NDArray* output) { + const Nd4jLong batchSize = images->sizeAt(0); + const Nd4jLong inHeight = images->sizeAt(1); + const Nd4jLong inWidth = images->sizeAt(2); + const Nd4jLong channels = images->sizeAt(3); + + const Nd4jLong outHeight = output->sizeAt(1); + const Nd4jLong outWidth = output->sizeAt(2); + + // Handle no-op resizes efficiently. + if (outHeight == inHeight && outWidth == inWidth) { + output->assign(images); + return ND4J_STATUS_OK; + } + + // if ((alignCorners && inHeight < 2) || (inHeight < 1) || (outHeight < + // 1) || (alignCorners && outHeight < 2) || + // (alignCorners && inWidth < 2) || (inWidth < 1) || (outWidth < 1) + // || (center && outWidth < 2)) { + // // wrong input data + // nd4j_printf("image.resize_nearest_neighbor: Wrong input or + // output size to resize\n", ""); return ND4J_STATUS_BAD_ARGUMENTS; + // } + // float heightScale = alignCorners ? (inHeight - 1.f) / + // float(outHeight - 1.f) : (inHeight / float(outHeight)); float + // widthScale = alignCorners ? (inWidth - 1.f) / float(outWidth - 1.f) + // : (inWidth / float(outWidth)); + float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners); + float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners); + + auto imagesBuffer = + images->getDataBuffer() + ->specialAsT(); // reinterpret_cast(images->specialBuffer()); + auto outputBuffer = + output->dataBuffer() + ->specialAsT(); // reinterpret_cast(output->specialBuffer()); + auto stream = context->getCudaStream(); + + NDArray::prepareSpecialUse({output}, {images}); + resizeNeighborKernel<<>>( + imagesBuffer, images->specialShapeInfo(), outputBuffer, + output->specialShapeInfo(), batchSize, inWidth, inHeight, outWidth, + outHeight, channels, widthScale, heightScale, alignCorners, + halfPixelCenters); + NDArray::registerSpecialUse({output}, {images}); + + return Status::OK(); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resizeImage - resize bilinear algorithm caller // - void resizeImage(sd::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, - Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, BilinearInterpolationData* xs_, - BilinearInterpolationData* ys_, NDArray* output) { - BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), - resizeImage_, (context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, - xs_, ys_, output), NUMERIC_TYPES, FLOAT_TYPES); - } +void resizeImage(sd::LaunchContext* context, NDArray const* images, + Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, + Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, + BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, + NDArray* output) { + BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), resizeImage_, + (context, images, batchSize, inHeight, inWidth, + outHeight, outWidth, channels, xs_, ys_, output), + NUMERIC_TYPES, FLOAT_TYPES); +} - BUILD_DOUBLE_TEMPLATE(template void resizeImage_,(sd::LaunchContext* context, NDArray const* images, - Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, - Nd4jLong channels, BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, NDArray* output), - NUMERIC_TYPES, FLOAT_TYPES); +BUILD_DOUBLE_TEMPLATE(template void resizeImage_, + (sd::LaunchContext * context, NDArray const* images, + Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, + Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, + BilinearInterpolationData* xs_, + BilinearInterpolationData* ys_, NDArray* output), + NUMERIC_TYPES, FLOAT_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - int resizeBilinearFunctor(sd::LaunchContext* context, NDArray const* images, int width, int height, - bool const alignCorners, bool const halfPixelCenter, NDArray* output) { - BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, (context, images, - width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES); - } -// BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_, (sd::LaunchContext* context, -// NDArray const* images, int const width, int const height, bool const alignCorners, -// bool const halfPixelCenter, NDArray* output), LIBND4J_TYPES); +int resizeBilinearFunctor(sd::LaunchContext* context, NDArray const* images, + int width, int height, bool const alignCorners, + bool const halfPixelCenter, NDArray* output) { + BUILD_DOUBLE_SELECTOR( + images->dataType(), output->dataType(), return resizeBilinearFunctor_, + (context, images, width, height, alignCorners, halfPixelCenter, output), + NUMERIC_TYPES, FLOAT_TYPES); +} +// BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_, +// (sd::LaunchContext* context, +// NDArray const* images, int const width, int const height, bool +// const alignCorners, bool const halfPixelCenter, NDArray* output), +// LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - int resizeNeighborFunctor(sd::LaunchContext* context, NDArray const* images, int const width, int const height, - bool const alignCorners, bool const halfPixelCenter, NDArray* output) { - BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, - (context, images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES); - } -// BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (sd::LaunchContext* context, NDArray const* images, -// int width, int height, bool const alignCorners, bool const halfPixelCenter, NDArray* output), LIBND4J_TYPES); +int resizeNeighborFunctor(sd::LaunchContext* context, NDArray const* images, + int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, + NDArray* output) { + BUILD_SINGLE_SELECTOR( + images->dataType(), return resizeNeighborFunctor_, + (context, images, width, height, alignCorners, halfPixelCenter, output), + LIBND4J_TYPES); +} +// BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, +// (sd::LaunchContext* context, NDArray const* images, +// int width, int height, bool const alignCorners, bool const +// halfPixelCenter, NDArray* output), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Bicubic interpolation //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - struct ImageResizerState { - explicit ImageResizerState(bool alignCorners, bool halfPixelCenters) - : _alignCorners(alignCorners), - _halfPixelCenters(halfPixelCenters) {} - - // ValidateAndCalculateOutputSize checks the bounds on the input tensors - // and requested size, sets up some of the resizing state such as the - // heightScale and widthScale, and calculates the output size. - // If any of these operations fails, it sets an error status in - // the context, which the caller must check. - int validateAndCalculateOutputSize(NDArray const* input, int const width, int const height) { - // - batchSize = input->sizeAt(0);//.dim_size(0); - outHeight = height; - outWidth = width; //internal::SubtleMustCopy(Svec(1)); - inHeight = static_cast(input->sizeAt(1)); - inWidth = static_cast(input->sizeAt(2)); - channels = input->sizeAt(3); //.dim_size(3); - heightScale = calculateResizeScale(inHeight, outHeight, _alignCorners); - widthScale = calculateResizeScale(inWidth, outWidth, _alignCorners); - - // Guard against overflows - if (ceilf((outHeight - 1) * heightScale) > static_cast(DataTypeUtils::max())) { - nd4j_printf("resize_bicubic: Upper overflow occurs for resize height (%f)\n", ceilf((outHeight - 1) * heightScale)); - return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize height"); - } - if (ceilf((outWidth - 1) * heightScale) > static_cast(DataTypeUtils::max())) { - nd4j_printf("resize_bicubic: Upper overflow occurs for resize height (%f)\n", ceilf((outHeight - 1) * heightScale)); - return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize width"); - } - - return Status::OK(); - } - - // Calculates all the required variables, and allocates the output. - int validateAndCreateOutput(NDArray const* input, int const width, int const height) { - return validateAndCalculateOutputSize(input, width, height); - } - - Nd4jLong batchSize; - Nd4jLong outHeight; - Nd4jLong outWidth; - Nd4jLong inHeight; - Nd4jLong inWidth; - Nd4jLong channels; - float heightScale; - float widthScale; - NDArray* output = nullptr; - cudaStream_t* stream; - private: - bool _alignCorners; - bool _halfPixelCenters; - }; - - struct WeightsAndIndices { - float _weight0; - float _weight1; - float _weight2; - float _weight3; - Nd4jLong _index0; - Nd4jLong _index1; - Nd4jLong _index2; - Nd4jLong _index3; - - int _advance; // advance value. - }; - - class CachedInterpolationCalculator { - public: - _CUDA_HD CachedInterpolationCalculator() : _indexes{-1, -1, -1, -1} {} - - // Advances iteration. Returns the number of values that should be copied from - // the current point to the next point. The copying should always be done by - // copying the last values from the old point to the first - // values of the new point. - inline _CUDA_HD int Advance(const Nd4jLong x0, const Nd4jLong x1, const Nd4jLong x2, - const Nd4jLong x3) { - // We use 2 hands and walk through, copying from one to another where - // we already have values. - // Invariant, new_indicies_hand <= cached_values_hand - const Nd4jLong new_x_indices[4] = {x0, x1, x2, x3}; - int cachedValuesHand = 0; - int newIndiciesHand = 0; - while (cachedValuesHand < 4) { - if (_indexes[cachedValuesHand] == new_x_indices[newIndiciesHand]) { - if (newIndiciesHand < cachedValuesHand) { - _indexes[newIndiciesHand] = _indexes[cachedValuesHand]; - } - newIndiciesHand++; - } - cachedValuesHand++; - } - switch (newIndiciesHand) { - case 0: - _indexes[0] = x0; - case 1: - _indexes[1] = x1; - case 2: - _indexes[2] = x2; - case 3: - _indexes[3] = x3; - break; - } - return newIndiciesHand; - } - - private: - Nd4jLong _indexes[4]; - }; - - - static __global__ void initCoefTableKernel(const double a, float* table, Nd4jLong tableSize) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - for (int i = start; i <= tableSize; i += step) { - float x = i * 1.0 / tableSize; - table[i * 2] = ((a + 2) * x - (a + 3)) * x * x + 1; - x += 1.0; - table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; - } +struct ImageResizerState { + explicit ImageResizerState(bool alignCorners, bool halfPixelCenters) + : _alignCorners(alignCorners), _halfPixelCenters(halfPixelCenters) {} + + // ValidateAndCalculateOutputSize checks the bounds on the input tensors + // and requested size, sets up some of the resizing state such as the + // heightScale and widthScale, and calculates the output size. + // If any of these operations fails, it sets an error status in + // the context, which the caller must check. + int validateAndCalculateOutputSize(NDArray const* input, int const width, + int const height) { + // + batchSize = input->sizeAt(0); //.dim_size(0); + outHeight = height; + outWidth = width; // internal::SubtleMustCopy(Svec(1)); + inHeight = static_cast(input->sizeAt(1)); + inWidth = static_cast(input->sizeAt(2)); + channels = input->sizeAt(3); //.dim_size(3); + heightScale = calculateResizeScale(inHeight, outHeight, _alignCorners); + widthScale = calculateResizeScale(inWidth, outWidth, _alignCorners); + + // Guard against overflows + if (ceilf((outHeight - 1) * heightScale) > + static_cast(DataTypeUtils::max())) { + nd4j_printf( + "resize_bicubic: Upper overflow occurs for resize height (%f)\n", + ceilf((outHeight - 1) * heightScale)); + return Status::CODE( + ND4J_STATUS_BAD_INPUT, + "resize_bicubic: Upper overflow occurs for resize height"); + } + if (ceilf((outWidth - 1) * heightScale) > + static_cast(DataTypeUtils::max())) { + nd4j_printf( + "resize_bicubic: Upper overflow occurs for resize height (%f)\n", + ceilf((outHeight - 1) * heightScale)); + return Status::CODE( + ND4J_STATUS_BAD_INPUT, + "resize_bicubic: Upper overflow occurs for resize width"); } - static const Nd4jLong kTableSize = (1 << 10); - float* initCoeffsTable(const double a, cudaStream_t* stream) { - // Allocate and initialize coefficients table using Bicubic - // convolution algorithm. - // https://en.wikipedia.org/wiki/Bicubic_interpolation - float* coeffs_table; // = new float[(kTableSize + 1) * 2]; - auto err = cudaMalloc(&coeffs_table, sizeof(float) * ((kTableSize + 1) * 2)); - if (err != 0) { - throw cuda_exception::build("helpers::initCoeffsTable: Cannot allocate memory for vertical parts rectangulars", err); - } - - - initCoefTableKernel<<<128,128,128, *stream>>>(a, coeffs_table, kTableSize); - err = cudaStreamSynchronize(*stream); - if (err != 0) { - throw cuda_exception::build("helpers::initCoeffsTable: Cannot syncronize kernel", err); + return Status::OK(); + } + + // Calculates all the required variables, and allocates the output. + int validateAndCreateOutput(NDArray const* input, int const width, + int const height) { + return validateAndCalculateOutputSize(input, width, height); + } + + Nd4jLong batchSize; + Nd4jLong outHeight; + Nd4jLong outWidth; + Nd4jLong inHeight; + Nd4jLong inWidth; + Nd4jLong channels; + float heightScale; + float widthScale; + NDArray* output = nullptr; + cudaStream_t* stream; + + private: + bool _alignCorners; + bool _halfPixelCenters; +}; + +struct WeightsAndIndices { + float _weight0; + float _weight1; + float _weight2; + float _weight3; + Nd4jLong _index0; + Nd4jLong _index1; + Nd4jLong _index2; + Nd4jLong _index3; + + int _advance; // advance value. +}; + +class CachedInterpolationCalculator { + public: + _CUDA_HD CachedInterpolationCalculator() : _indexes{-1, -1, -1, -1} {} + + // Advances iteration. Returns the number of values that should be copied from + // the current point to the next point. The copying should always be done by + // copying the last values from the old point to the first + // values of the new point. + inline _CUDA_HD int Advance(const Nd4jLong x0, const Nd4jLong x1, + const Nd4jLong x2, const Nd4jLong x3) { + // We use 2 hands and walk through, copying from one to another where + // we already have values. + // Invariant, new_indicies_hand <= cached_values_hand + const Nd4jLong new_x_indices[4] = {x0, x1, x2, x3}; + int cachedValuesHand = 0; + int newIndiciesHand = 0; + while (cachedValuesHand < 4) { + if (_indexes[cachedValuesHand] == new_x_indices[newIndiciesHand]) { + if (newIndiciesHand < cachedValuesHand) { + _indexes[newIndiciesHand] = _indexes[cachedValuesHand]; } - - return coeffs_table; + newIndiciesHand++; + } + cachedValuesHand++; } + switch (newIndiciesHand) { + case 0: + _indexes[0] = x0; + case 1: + _indexes[1] = x1; + case 2: + _indexes[2] = x2; + case 3: + _indexes[3] = x3; + break; + } + return newIndiciesHand; + } + + private: + Nd4jLong _indexes[4]; +}; + +static __global__ void initCoefTableKernel(const double a, float* table, + Nd4jLong tableSize) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + for (int i = start; i <= tableSize; i += step) { + float x = i * 1.0 / tableSize; + table[i * 2] = ((a + 2) * x - (a + 3)) * x * x + 1; + x += 1.0; + table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; + } +} + +static const Nd4jLong kTableSize = (1 << 10); +float* initCoeffsTable(const double a, cudaStream_t* stream) { + // Allocate and initialize coefficients table using Bicubic + // convolution algorithm. + // https://en.wikipedia.org/wiki/Bicubic_interpolation + float* coeffs_table; // = new float[(kTableSize + 1) * 2]; + auto err = cudaMalloc(&coeffs_table, sizeof(float) * ((kTableSize + 1) * 2)); + if (err != 0) { + throw cuda_exception::build( + "helpers::initCoeffsTable: Cannot allocate memory for vertical parts " + "rectangulars", + err); + } + + initCoefTableKernel<<<128, 128, 128, *stream>>>(a, coeffs_table, kTableSize); + err = cudaStreamSynchronize(*stream); + if (err != 0) { + throw cuda_exception::build( + "helpers::initCoeffsTable: Cannot syncronize kernel", err); + } + + return coeffs_table; +} // _CUDA_HD const float* getCoeffsTable(const bool use_keys_cubic) { // // Static so that we initialize it on first use // if (use_keys_cubic) { // // http://ieeexplore.ieee.org/document/1163711/ -// // R. G. Keys. Cubic convolution interpolation for digital image -// // processing. IEEE Transactions on Acoustics, Speech, and Signal +// // R. G. Keys. Cubic convolution interpolation for digital +// image +// // processing. IEEE Transactions on Acoustics, Speech, and +// Signal // // Processing, 29(6):1153–1160, 1981. -// //static const float* coeffs_table = initCoeffsTable(-0.5f, stream); -// return sCoeffsTableHalf; +// //static const float* coeffs_table = initCoeffsTable(-0.5f, +// stream); return sCoeffsTableHalf; // } else { -// //static const float* coeffs_table = initCoeffsTable(-0.75f, stream); -// return sCoeffsTableThreeFourth; +// //static const float* coeffs_table = initCoeffsTable(-0.75f, +// stream); return sCoeffsTableThreeFourth; // } // } - inline _CUDA_HD Nd4jLong bound(Nd4jLong val, Nd4jLong limit) { - return math::nd4j_min(limit - 1ll, math::nd4j_max(Nd4jLong{0}, val)); - } - +inline _CUDA_HD Nd4jLong bound(Nd4jLong val, Nd4jLong limit) { + return math::nd4j_min(limit - 1ll, math::nd4j_max(Nd4jLong{0}, val)); +} - template - inline _CUDA_HD float interpolate1D(const float weight0, const float weight1, const float weight2, const float weight3, - const T value0, const T value1, const T value2, const T value3) { - return static_cast(value0) * weight0 + - static_cast(value1) * weight1 + - static_cast(value2) * weight2 + - static_cast(value3) * weight3; - } +template +inline _CUDA_HD float interpolate1D(const float weight0, const float weight1, + const float weight2, const float weight3, + const T value0, const T value1, + const T value2, const T value3) { + return static_cast(value0) * weight0 + + static_cast(value1) * weight1 + + static_cast(value2) * weight2 + + static_cast(value3) * weight3; +} // Compute the 1D interpolation for a given X index using the y_weights - static _CUDA_HD float compute(float values[4], const float xW0, const float xW1, const float xW2, const float xW3) { - return interpolate1D(xW0, xW1, xW2, xW3, values[0], values[1],values[2], values[3]); - } - - +static _CUDA_HD float compute(float values[4], const float xW0, const float xW1, + const float xW2, const float xW3) { + return interpolate1D(xW0, xW1, xW2, xW3, values[0], values[1], values[2], + values[3]); +} - template - inline _CUDA_HD void getWeightsAndIndices(float const* coeffs_table, const float scale, const Nd4jLong out_loc, const Nd4jLong limit, WeightsAndIndices* out) { - const Scaler scaler; - const float in_loc_f = scaler(out_loc, scale); - const Nd4jLong in_loc = math::nd4j_floor(in_loc_f); - const float delta = in_loc_f - in_loc; - const Nd4jLong offset = math::nd4j_round(delta * kTableSize); - //const float* coeffs_table = getCoeffsTable(use_keys_cubic); - if (use_keys_cubic) { - // The legacy code placed more weight on the edge pixels, since bounding - // the set of inputs to sample could cause an edge pixel to be repeated. - // Here we change the behavior at borders to match that used by the - // scale_and_translate_op, where sampling locations outside the image have - // their weight set to 0, and the weights are renormalized so that their sum - // is 1.0. - out->_index0 = bound(in_loc - 1, limit); - out->_weight0 = - (out->_index0 == in_loc - 1 ? coeffs_table[offset * 2 + 1] : 0.0f); - out->_index1 = bound(in_loc, limit); - out->_weight1 = (out->_index1 == in_loc ? coeffs_table[offset * 2] : 0.0f); - out->_index2 = bound(in_loc + 1, limit); - out->_weight2 = - (out->_index2 == in_loc + 1 ? coeffs_table[(kTableSize - offset) * 2] - : 0.0f); - out->_index3 = bound(in_loc + 2, limit); - out->_weight3 = (out->_index3 == in_loc + 2 - ? coeffs_table[(kTableSize - offset) * 2 + 1] - : 0.0f); - - const float weight_sum = - out->_weight0 + out->_weight1 + out->_weight2 + out->_weight3; - if (math::nd4j_abs(weight_sum) >= 1000.0f * DataTypeUtils::min()) { - const float one_over_weight_sum = 1.0f / weight_sum; - out->_weight0 *= one_over_weight_sum; - out->_weight1 *= one_over_weight_sum; - out->_weight2 *= one_over_weight_sum; - out->_weight3 *= one_over_weight_sum; - } - } else { - out->_weight0 = coeffs_table[offset * 2 + 1]; - out->_weight1 = coeffs_table[offset * 2]; - out->_weight2 = coeffs_table[(kTableSize - offset) * 2]; - out->_weight3 = coeffs_table[(kTableSize - offset) * 2 + 1]; - out->_index0 = bound(in_loc - 1, limit); - out->_index1 = bound(in_loc, limit); - out->_index2 = bound(in_loc + 1, limit); - out->_index3 = bound(in_loc + 2, limit); - } +template +inline _CUDA_HD void getWeightsAndIndices(float const* coeffs_table, + const float scale, + const Nd4jLong out_loc, + const Nd4jLong limit, + WeightsAndIndices* out) { + const Scaler scaler; + const float in_loc_f = scaler(out_loc, scale); + const Nd4jLong in_loc = math::nd4j_floor(in_loc_f); + const float delta = in_loc_f - in_loc; + const Nd4jLong offset = math::nd4j_round(delta * kTableSize); + // const float* coeffs_table = getCoeffsTable(use_keys_cubic); + if (use_keys_cubic) { + // The legacy code placed more weight on the edge pixels, since bounding + // the set of inputs to sample could cause an edge pixel to be repeated. + // Here we change the behavior at borders to match that used by the + // scale_and_translate_op, where sampling locations outside the image have + // their weight set to 0, and the weights are renormalized so that their sum + // is 1.0. + out->_index0 = bound(in_loc - 1, limit); + out->_weight0 = + (out->_index0 == in_loc - 1 ? coeffs_table[offset * 2 + 1] : 0.0f); + out->_index1 = bound(in_loc, limit); + out->_weight1 = (out->_index1 == in_loc ? coeffs_table[offset * 2] : 0.0f); + out->_index2 = bound(in_loc + 1, limit); + out->_weight2 = + (out->_index2 == in_loc + 1 ? coeffs_table[(kTableSize - offset) * 2] + : 0.0f); + out->_index3 = bound(in_loc + 2, limit); + out->_weight3 = (out->_index3 == in_loc + 2 + ? coeffs_table[(kTableSize - offset) * 2 + 1] + : 0.0f); + + const float weight_sum = + out->_weight0 + out->_weight1 + out->_weight2 + out->_weight3; + if (math::nd4j_abs(weight_sum) >= 1000.0f * DataTypeUtils::min()) { + const float one_over_weight_sum = 1.0f / weight_sum; + out->_weight0 *= one_over_weight_sum; + out->_weight1 *= one_over_weight_sum; + out->_weight2 *= one_over_weight_sum; + out->_weight3 *= one_over_weight_sum; } + } else { + out->_weight0 = coeffs_table[offset * 2 + 1]; + out->_weight1 = coeffs_table[offset * 2]; + out->_weight2 = coeffs_table[(kTableSize - offset) * 2]; + out->_weight3 = coeffs_table[(kTableSize - offset) * 2 + 1]; + out->_index0 = bound(in_loc - 1, limit); + out->_index1 = bound(in_loc, limit); + out->_index2 = bound(in_loc + 1, limit); + out->_index3 = bound(in_loc + 2, limit); + } +} - static __global__ void accumulateChannelsKernel(WeightsAndIndices* pXWais, Nd4jLong outWidth, Nd4jLong channels) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (auto x = start; x < outWidth; x += step) { - pXWais[x]._index0 *= channels; - pXWais[x]._index1 *= channels; - pXWais[x]._index2 *= channels; - pXWais[x]._index3 *= channels; - } - } +static __global__ void accumulateChannelsKernel(WeightsAndIndices* pXWais, + Nd4jLong outWidth, + Nd4jLong channels) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (auto x = start; x < outWidth; x += step) { + pXWais[x]._index0 *= channels; + pXWais[x]._index1 *= channels; + pXWais[x]._index2 *= channels; + pXWais[x]._index3 *= channels; + } +} - static __global__ void advaceWeightsAndIndicesKernel(float const* cacheTable, CachedInterpolationCalculator* calc, WeightsAndIndices* pXWais, Nd4jLong inWidth, float widthScale, - Nd4jLong outWidth, Nd4jLong channels, bool halfPixelCenters) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (auto x = start; x < outWidth; x += step) { - if (halfPixelCenters) - getWeightsAndIndices(cacheTable, widthScale, x, inWidth, &pXWais[x]); - else - getWeightsAndIndices(cacheTable, widthScale, x, inWidth, &pXWais[x]); - pXWais[x]._advance = calc->Advance(pXWais[x]._index0, pXWais[x]._index1, pXWais[x]._index2, pXWais[x]._index3); - } - } - // resizerState and xWais are device allocated - static void computeXWeightsAndIndices(float const* coeffsTable, const ImageResizerState& resizerState, - const bool halfPixelCenters, - WeightsAndIndices* pXWais) { - - auto stream = resizerState.stream; - auto outWidth = resizerState.outWidth; - CachedInterpolationCalculator calc; // = new CachedInterpolationCalculator; - CachedInterpolationCalculator* pCalcD; - auto err = cudaMalloc(&pCalcD, sizeof(CachedInterpolationCalculator)); - if (err != 0) { - cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot allocated device memory for interpolate calculator", err); - } - err = cudaMemcpyAsync(pCalcD, &calc, sizeof(CachedInterpolationCalculator), cudaMemcpyHostToDevice, *stream); - if (err != 0) { - cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot set up device memory for interpolate calculator", err); - } +static __global__ void advaceWeightsAndIndicesKernel( + float const* cacheTable, CachedInterpolationCalculator* calc, + WeightsAndIndices* pXWais, Nd4jLong inWidth, float widthScale, + Nd4jLong outWidth, Nd4jLong channels, bool halfPixelCenters) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (auto x = start; x < outWidth; x += step) { + if (halfPixelCenters) + getWeightsAndIndices(cacheTable, widthScale, x, + inWidth, &pXWais[x]); + else + getWeightsAndIndices(cacheTable, widthScale, x, + inWidth, &pXWais[x]); + pXWais[x]._advance = calc->Advance(pXWais[x]._index0, pXWais[x]._index1, + pXWais[x]._index2, pXWais[x]._index3); + } +} +// resizerState and xWais are device allocated +static void computeXWeightsAndIndices(float const* coeffsTable, + const ImageResizerState& resizerState, + const bool halfPixelCenters, + WeightsAndIndices* pXWais) { + auto stream = resizerState.stream; + auto outWidth = resizerState.outWidth; + CachedInterpolationCalculator calc; // = new CachedInterpolationCalculator; + CachedInterpolationCalculator* pCalcD; + auto err = cudaMalloc(&pCalcD, sizeof(CachedInterpolationCalculator)); + if (err != 0) { + cuda_exception::build( + "helpers::computeXWeightsAndIndices: Cannot allocated device memory " + "for interpolate calculator", + err); + } + err = cudaMemcpyAsync(pCalcD, &calc, sizeof(CachedInterpolationCalculator), + cudaMemcpyHostToDevice, *stream); + if (err != 0) { + cuda_exception::build( + "helpers::computeXWeightsAndIndices: Cannot set up device memory for " + "interpolate calculator", + err); + } + + advaceWeightsAndIndicesKernel<<<128, 128, 128, *stream>>>( + coeffsTable, pCalcD, pXWais, resizerState.inWidth, + resizerState.widthScale, outWidth, resizerState.channels, + halfPixelCenters); + err = cudaFree(pCalcD); + if (err != 0) { + cuda_exception::build( + "helpers::computeXWeightsAndIndices: Cannot deallocated device memory " + "for interpolate calculator", + err); + } + err = cudaStreamSynchronize(*stream); + if (err != 0) { + cuda_exception::build( + "helpers::computeXWeightsAndIndices: Cannot synchronize stream after " + "advance weights and indicers", + err); + } + // Scale the values so they can be used as offsets into buffers. + accumulateChannelsKernel<<<128, 128, 512, *stream>>>(pXWais, outWidth, + resizerState.channels); + err = cudaStreamSynchronize(*stream); + if (err != 0) { + cuda_exception::build( + "helpers::computeXWeightsAndIndices: Cannot synchronize stream after " + "accumulate channels", + err); + } +} - advaceWeightsAndIndicesKernel<<<128, 128, 128, *stream>>>(coeffsTable, pCalcD, pXWais, resizerState.inWidth, resizerState.widthScale, outWidth, resizerState.channels, halfPixelCenters); - err = cudaFree(pCalcD); - if (err != 0) { - cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot deallocated device memory for interpolate calculator", err); - } - err = cudaStreamSynchronize(*stream); - if (err != 0) { - cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot synchronize stream after advance weights and indicers", err); - } - // Scale the values so they can be used as offsets into buffers. - accumulateChannelsKernel<<<128, 128, 512, *stream>>>(pXWais, outWidth, resizerState.channels); - err = cudaStreamSynchronize(*stream); - if (err != 0) { - cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot synchronize stream after accumulate channels", err); - } +template +static _CUDA_HD FORCEINLINE float computeYInterpolation( + int which, int channelNum, const WeightsAndIndices& yWai, const T* pY0, + const T* pY1, const T* pY2, const T* pY3, const WeightsAndIndices& xWai) { + int xIndex; + switch (which) { + case 0: + xIndex = xWai._index0; + break; + case 1: + xIndex = xWai._index1; + break; + case 2: + xIndex = xWai._index2; + break; + default: + xIndex = xWai._index3; + break; + } + const Nd4jLong pt_index = xIndex + channelNum; + return interpolate1D(yWai._weight0, yWai._weight1, yWai._weight2, + yWai._weight3, pY0[pt_index], pY1[pt_index], + pY2[pt_index], pY3[pt_index]); +} - } +template +static __global__ void bicubicInterpolateWithCachingKernel( + float const* cachedTable, T const* inputPtr, + ImageResizerState* pResizerState, WeightsAndIndices* xWais, + bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, + float* outputPtr) { + // auto numChannels = pResizerState->channels; + + for (Nd4jLong b = blockIdx.x; b < pResizerState->batchSize; b += gridDim.x) { + auto pInput = inputPtr + b * inBatchWidth; + float* cachedValue; + for (Nd4jLong y = threadIdx.x; y < pResizerState->outHeight; + y += blockDim.x) { + if (threadIdx.x == 0) { + extern __shared__ char sharedChar[]; + cachedValue = reinterpret_cast(sharedChar); + } + auto pos = (b * pResizerState->outHeight + y) * pResizerState->outWidth * + pResizerState->channels; + auto pOutput = &outputPtr[pos]; + struct WeightsAndIndices yWai; + if (halfPixelCenters) { + getWeightsAndIndices( + cachedTable, pResizerState->heightScale, y, pResizerState->inHeight, + &yWai); + } else { + getWeightsAndIndices( + cachedTable, pResizerState->heightScale, y, pResizerState->inHeight, + &yWai); + } + // Make pointers represent offsets of data in inputBPtr. + const T* y_ptr_0 = pInput + yWai._index0 * inRowWidth; + const T* y_ptr_1 = pInput + yWai._index1 * inRowWidth; + const T* y_ptr_2 = pInput + yWai._index2 * inRowWidth; + const T* y_ptr_3 = pInput + yWai._index3 * inRowWidth; + + if (pResizerState->channels == 3) { + // Manually unroll case of 3 channels. + float cached_value_0[4] = {0}; + float cached_value_1[4] = {0}; + float cached_value_2[4] = {0}; + for (Nd4jLong x = 0; x < pResizerState->outWidth; ++x) { + const WeightsAndIndices& xWai = xWais[x]; + // Shift values in cached_value_* to fill first '_advance' values. + switch (xWai._advance) { + case 3: + cached_value_0[0] = cached_value_0[1]; + cached_value_0[1] = cached_value_0[2]; + cached_value_0[2] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[1]; + cached_value_1[1] = cached_value_1[2]; + cached_value_1[2] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[1]; + cached_value_2[1] = cached_value_2[2]; + cached_value_2[2] = cached_value_2[3]; + break; + case 2: + cached_value_0[0] = cached_value_0[2]; + cached_value_0[1] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[2]; + cached_value_1[1] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[2]; + cached_value_2[1] = cached_value_2[3]; + break; + case 1: { + cached_value_0[0] = cached_value_0[3]; + cached_value_1[0] = cached_value_1[3]; + cached_value_2[0] = cached_value_2[3]; + break; + } + } - template - static _CUDA_HD FORCEINLINE float computeYInterpolation( - int which, int channelNum, const WeightsAndIndices& yWai, - const T* pY0, const T* pY1, const T* pY2, const T* pY3, - const WeightsAndIndices& xWai) { - int xIndex; - switch (which) { + // Set the remaining '4-_advance' values by computing. + switch (xWai._advance) { case 0: - xIndex = xWai._index0; - break; + cached_value_0[0] = computeYInterpolation( + 0, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[0] = computeYInterpolation( + 0, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[0] = computeYInterpolation( + 0, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); case 1: - xIndex = xWai._index1; - break; + cached_value_0[1] = computeYInterpolation( + 1, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[1] = computeYInterpolation( + 1, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[1] = computeYInterpolation( + 1, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); case 2: - xIndex = xWai._index2; - break; - default: - xIndex = xWai._index3; - break; + cached_value_0[2] = computeYInterpolation( + 2, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[2] = computeYInterpolation( + 2, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[2] = computeYInterpolation( + 2, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + case 3: + cached_value_0[3] = computeYInterpolation( + 3, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_1[3] = computeYInterpolation( + 3, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + cached_value_2[3] = computeYInterpolation( + 3, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + // break; + } + pOutput[x * pResizerState->channels + 0] = + compute(cached_value_0, xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + pOutput[x * pResizerState->channels + 1] = + compute(cached_value_1, xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + pOutput[x * pResizerState->channels + 2] = + compute(cached_value_2, xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); } - const Nd4jLong pt_index = xIndex + channelNum; - return interpolate1D(yWai._weight0, yWai._weight1, yWai._weight2, - yWai._weight3, pY0[pt_index], pY1[pt_index], - pY2[pt_index], pY3[pt_index]); - } - - template - static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, float* outputPtr) { -// auto numChannels = pResizerState->channels; - - for (Nd4jLong b = blockIdx.x; b < pResizerState->batchSize; b += gridDim.x) { - auto pInput = inputPtr + b * inBatchWidth; - float* cachedValue; - for (Nd4jLong y = threadIdx.x; y < pResizerState->outHeight; y += blockDim.x) { - if (threadIdx.x == 0) { - extern __shared__ char sharedChar[]; - cachedValue = reinterpret_cast(sharedChar); - } - auto pos = (b * pResizerState->outHeight + y) * pResizerState->outWidth * pResizerState->channels; - auto pOutput = &outputPtr[pos]; - struct WeightsAndIndices yWai; - if (halfPixelCenters) { - getWeightsAndIndices(cachedTable, pResizerState->heightScale, y, pResizerState->inHeight, &yWai); - } else { - getWeightsAndIndices(cachedTable, pResizerState->heightScale, y, pResizerState->inHeight, &yWai); - } - // Make pointers represent offsets of data in inputBPtr. - const T* y_ptr_0 = pInput + yWai._index0 * inRowWidth; - const T* y_ptr_1 = pInput + yWai._index1 * inRowWidth; - const T* y_ptr_2 = pInput + yWai._index2 * inRowWidth; - const T* y_ptr_3 = pInput + yWai._index3 * inRowWidth; - - if (pResizerState->channels == 3) { - // Manually unroll case of 3 channels. - float cached_value_0[4] = {0}; - float cached_value_1[4] = {0}; - float cached_value_2[4] = {0}; - for (Nd4jLong x = 0; x < pResizerState->outWidth; ++x) { - const WeightsAndIndices& xWai = xWais[x]; - // Shift values in cached_value_* to fill first '_advance' values. - switch (xWai._advance) { - case 3: - cached_value_0[0] = cached_value_0[1]; - cached_value_0[1] = cached_value_0[2]; - cached_value_0[2] = cached_value_0[3]; - cached_value_1[0] = cached_value_1[1]; - cached_value_1[1] = cached_value_1[2]; - cached_value_1[2] = cached_value_1[3]; - cached_value_2[0] = cached_value_2[1]; - cached_value_2[1] = cached_value_2[2]; - cached_value_2[2] = cached_value_2[3]; - break; - case 2: - cached_value_0[0] = cached_value_0[2]; - cached_value_0[1] = cached_value_0[3]; - cached_value_1[0] = cached_value_1[2]; - cached_value_1[1] = cached_value_1[3]; - cached_value_2[0] = cached_value_2[2]; - cached_value_2[1] = cached_value_2[3]; - break; - case 1: { - cached_value_0[0] = cached_value_0[3]; - cached_value_1[0] = cached_value_1[3]; - cached_value_2[0] = cached_value_2[3]; - break; - } - } - - // Set the remaining '4-_advance' values by computing. - switch (xWai._advance) { - case 0: - cached_value_0[0] = computeYInterpolation(0, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[0] = computeYInterpolation(0, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[0] = computeYInterpolation(0, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - case 1: - cached_value_0[1] = computeYInterpolation(1, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[1] = computeYInterpolation(1, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[1] = computeYInterpolation(1, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - case 2: - cached_value_0[2] = computeYInterpolation(2, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[2] = computeYInterpolation(2, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[2] = computeYInterpolation(2, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - case 3: - cached_value_0[3] = computeYInterpolation(3, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_1[3] = computeYInterpolation(3, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - cached_value_2[3] = computeYInterpolation(3, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - // break; - } - pOutput[x * pResizerState->channels + 0] = compute(cached_value_0, xWai._weight0, xWai._weight1, - xWai._weight2, xWai._weight3); - pOutput[x * pResizerState->channels + 1] = compute(cached_value_1, xWai._weight0, xWai._weight1, - xWai._weight2, xWai._weight3); - pOutput[x * pResizerState->channels + 2] = compute(cached_value_2, xWai._weight0, xWai._weight1, - xWai._weight2, xWai._weight3); - } - } else { - for (Nd4jLong x = 0; x < pResizerState->outWidth; ++x) { - const WeightsAndIndices& xWai = xWais[x]; - // Shift values in cachedValue to fill first '_advance' values. - switch (xWai._advance) { - case 3: - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - cachedValue[4 * c + 0] = cachedValue[4 * c + 1]; - cachedValue[4 * c + 1] = cachedValue[4 * c + 2]; - cachedValue[4 * c + 2] = cachedValue[4 * c + 3]; - } - break; - case 2: - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - cachedValue[4 * c + 0] = cachedValue[4 * c + 2]; - cachedValue[4 * c + 1] = cachedValue[4 * c + 3]; - } - break; - case 1: { - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - cachedValue[4 * c + 0] = cachedValue[4 * c + 3]; - } - break; - } - } - - // Set the remaining '4-_advance' values by computing. - switch (xWai._advance) { - case 0: - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - cachedValue[4 * c + 0] = computeYInterpolation(0, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - case 1: - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - cachedValue[4 * c + 1] = computeYInterpolation(1, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - case 2: - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - cachedValue[4 * c + 2] = computeYInterpolation(2, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - case 3: - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - cachedValue[4 * c + 3] = computeYInterpolation(3, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); - } - // break; - } - for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { - pOutput[x * pResizerState->channels + c] = compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1, xWai._weight2, xWai._weight3); - } - } - } + } else { + for (Nd4jLong x = 0; x < pResizerState->outWidth; ++x) { + const WeightsAndIndices& xWai = xWais[x]; + // Shift values in cachedValue to fill first '_advance' values. + switch (xWai._advance) { + case 3: + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + cachedValue[4 * c + 0] = cachedValue[4 * c + 1]; + cachedValue[4 * c + 1] = cachedValue[4 * c + 2]; + cachedValue[4 * c + 2] = cachedValue[4 * c + 3]; + } + break; + case 2: + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + cachedValue[4 * c + 0] = cachedValue[4 * c + 2]; + cachedValue[4 * c + 1] = cachedValue[4 * c + 3]; + } + break; + case 1: { + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + cachedValue[4 * c + 0] = cachedValue[4 * c + 3]; + } + break; } - } - - } - - - template - static void - bicubicInterpolateWithCaching(NDArray const* image, ImageResizerState const& resizerState, bool const halfPixelCenters, NDArray* output) { - const auto numChannels = resizerState.channels; - const Nd4jLong inRowWidth = resizerState.inWidth * numChannels; - const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth; - - auto stream = resizerState.stream; //output->getContext()->getCudaStream(); - ImageResizerState* resizerStateD; - auto err = cudaMalloc(&resizerStateD, sizeof(ImageResizerState)); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot allocate memory for resizerState", err); - } - err = cudaMemcpyAsync(resizerStateD, &resizerState, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot set up memory for resizerState", err); - } - -// float* cachedValue = nullptr; -// size_t cachedSize = sizeof(float) * (numChannels == 3 ? 0 : 4 * numChannels); -// if (cachedSize) { -// err = cudaMalloc(reinterpret_cast(&cachedValue), cachedSize); -// if (err != 0) { -// throw cuda_exception::build( -// "helpers::bicubicInterpolateWithCaching: Cannot allocate memory for cached values", err); -// } -// err = cudaMemset(cachedValue, 0, cachedSize); -// if (err != 0) { -// throw cuda_exception::build( -// "helpers::bicubicInterpolateWithCaching: Cannot set up memory for cached values", err); -// } -// } + } - WeightsAndIndices* xWais; //(resizerState.outWidth); - err = cudaMalloc(&xWais, sizeof(WeightsAndIndices) * resizerState.outWidth); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot allocate memory for weights and indices", err); - } - - auto coeffsTable = halfPixelCenters?initCoeffsTable(-0.5, stream): initCoeffsTable(-0.75, stream); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: computeXWeigtsAndInidces finished with error", err); - } - computeXWeightsAndIndices(coeffsTable, resizerState, halfPixelCenters, xWais); - err = cudaStreamQuery(*stream); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: computeXWeigtsAndInidces finished with error", err); - } - const T* pInput = image->getDataBuffer()->specialAsT(); - float* pOutput = output->dataBuffer()->specialAsT(); //_data.data(); - bicubicInterpolateWithCachingKernel<<<128, 1, 512, *stream>>>(coeffsTable, pInput, - resizerStateD, xWais, halfPixelCenters, inBatchWidth, inRowWidth, pOutput); - err = cudaStreamSynchronize(*stream); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Kernels finished with error", err); + // Set the remaining '4-_advance' values by computing. + switch (xWai._advance) { + case 0: + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + cachedValue[4 * c + 0] = computeYInterpolation( + 0, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } + case 1: + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + cachedValue[4 * c + 1] = computeYInterpolation( + 1, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } + case 2: + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + cachedValue[4 * c + 2] = computeYInterpolation( + 2, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } + case 3: + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + cachedValue[4 * c + 3] = computeYInterpolation( + 3, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai); + } + // break; + } + for (Nd4jLong c = 0; c < pResizerState->channels; ++c) { + pOutput[x * pResizerState->channels + c] = + compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1, + xWai._weight2, xWai._weight3); + } } + } + } + } +} - err = cudaFree(resizerStateD); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for resizerState", err); - } -// if (cachedSize) -// err = cudaFree(cachedValue); -// if (err != 0) { -// throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for cached values", err); -// } +template +static void bicubicInterpolateWithCaching(NDArray const* image, + ImageResizerState const& resizerState, + bool const halfPixelCenters, + NDArray* output) { + const auto numChannels = resizerState.channels; + const Nd4jLong inRowWidth = resizerState.inWidth * numChannels; + const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth; + + auto stream = resizerState.stream; // output->getContext()->getCudaStream(); + ImageResizerState* resizerStateD; + auto err = cudaMalloc(&resizerStateD, sizeof(ImageResizerState)); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: Cannot allocate memory for " + "resizerState", + err); + } + err = cudaMemcpyAsync(resizerStateD, &resizerState, sizeof(ImageResizerState), + cudaMemcpyHostToDevice, *stream); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: Cannot set up memory for " + "resizerState", + err); + } + + // float* cachedValue = nullptr; + // size_t cachedSize = sizeof(float) * (numChannels == 3 ? 0 : 4 * + // numChannels); if (cachedSize) { + // err = cudaMalloc(reinterpret_cast(&cachedValue), + // cachedSize); if (err != 0) { + // throw cuda_exception::build( + // "helpers::bicubicInterpolateWithCaching: Cannot + // allocate memory for cached values", err); + // } + // err = cudaMemset(cachedValue, 0, cachedSize); + // if (err != 0) { + // throw cuda_exception::build( + // "helpers::bicubicInterpolateWithCaching: Cannot set + // up memory for cached values", err); + // } + // } + + WeightsAndIndices* xWais; //(resizerState.outWidth); + err = cudaMalloc(&xWais, sizeof(WeightsAndIndices) * resizerState.outWidth); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: Cannot allocate memory for " + "weights and indices", + err); + } + + auto coeffsTable = halfPixelCenters ? initCoeffsTable(-0.5, stream) + : initCoeffsTable(-0.75, stream); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: computeXWeigtsAndInidces " + "finished with error", + err); + } + computeXWeightsAndIndices(coeffsTable, resizerState, halfPixelCenters, xWais); + err = cudaStreamQuery(*stream); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: computeXWeigtsAndInidces " + "finished with error", + err); + } + const T* pInput = image->getDataBuffer()->specialAsT(); + float* pOutput = output->dataBuffer()->specialAsT(); //_data.data(); + bicubicInterpolateWithCachingKernel<<<128, 1, 512, *stream>>>( + coeffsTable, pInput, resizerStateD, xWais, halfPixelCenters, inBatchWidth, + inRowWidth, pOutput); + err = cudaStreamSynchronize(*stream); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: Kernels finished with error", + err); + } + + err = cudaFree(resizerStateD); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for " + "resizerState", + err); + } + // if (cachedSize) + // err = cudaFree(cachedValue); + // if (err != 0) { + // throw + // cuda_exception::build("helpers::bicubicInterpolateWithCaching: + // Cannot deallocate memory for cached values", err); + // } + + err = cudaFree(xWais); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for " + "weights and indices", + err); + } + + err = cudaFree(coeffsTable); + if (err != 0) { + throw cuda_exception::build( + "helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for " + "coefficients table", + err); + } +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +int resizeBicubicFunctor_(sd::LaunchContext* context, NDArray const* image, + int width, int height, bool preserveAspectRatio, + bool antialias, NDArray* output) { + return Status::OK(); +} - err = cudaFree(xWais); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for weights and indices", err); - } +int resizeBicubicFunctor(sd::LaunchContext* context, NDArray const* image, + int width, int height, bool preserveAspectRatio, + bool antialias, NDArray* output) { + BUILD_SINGLE_SELECTOR( + image->dataType(), return resizeBicubicFunctor_, + (context, image, width, height, preserveAspectRatio, antialias, output), + NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template int resizeBicubicFunctor_, + (sd::LaunchContext * context, NDArray const* image, + int width, int height, bool preserveAspectRatio, + bool antialias, NDArray* output), + NUMERIC_TYPES); +// ------------------------------------------------------------------------------------------------------------------ +// // +struct CachedInterpolation { + Nd4jLong start; + Nd4jLong end; + float startScale; + float endMinusOneScale; + bool needsBounding; +}; + +static __global__ void fillInterpolationCache(CachedInterpolation* xCached, + Nd4jLong cacheLen, + Nd4jLong inWidth, + float widthScale) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto increment = blockDim.x * gridDim.x; + + for (auto x = start; x < cacheLen; x += increment) { + auto& xCache = xCached[x]; + const float inX = x * widthScale; + const float inX1 = (x + 1) * widthScale; + + Nd4jLong v = math::nd4j_floor(inX); + xCache.start = v; + xCache.startScale = v < inX ? (v + 1 > inX1 ? widthScale : v + 1 - inX) + : (v + 1 > inX1 ? inX1 - v : 1.f); + v = math::nd4j_ceil(inX1); + xCache.end = v--; + xCache.endMinusOneScale = v < inX + ? (v + 1 > inX1 ? widthScale : v + 1 - inX) + : (v + 1 > inX1 ? inX1 - v : 1.f); + xCache.needsBounding = bound(xCache.start, inWidth) != xCache.start || + bound(xCache.end - 1, inWidth) != (xCache.end - 1); + } +} - err = cudaFree(coeffsTable); - if (err != 0) { - throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for coefficients table", err); - } +// ------------------------------------------------------------------------------------------------------------------ +// // +template +struct ScaleCache { + float yScale; + T const* yPtr; +}; +// Computes the sum of all x values defined by taken across +// the y offsets and scales defined by y_ptrs and y_scales, for channel c. +// +// Note that is a template parameter to avoid a performance +// penalty from dynamically checking it. +template +static __device__ void computePatchSumOf3Channels( + float scale, const ImageResizerState& st, ScaleCache const* yScaleCache, + Nd4jLong ptrsLen, const CachedInterpolation& xCache, float* outputPtr) { + bool const needsXBounding = xCache.needsBounding; + + auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { + return (needsXBounding ? bound(x, y) : (x)); + }; + + float sum_0 = 0; + float sum_1 = 0; + float sum_2 = 0; + for (int i = 0; i < ptrsLen; ++i) { + const T* ptr = yScaleCache[i].yPtr; + float scaleX = xCache.startScale; + Nd4jLong offset = 3 * boundIfNeeded(xCache.start, st.inWidth); + float sum_y_0 = static_cast(ptr[offset + 0]) * scaleX; + float sum_y_1 = static_cast(ptr[offset + 1]) * scaleX; + float sum_y_2 = static_cast(ptr[offset + 2]) * scaleX; + + if (xCache.start + 1 != xCache.end) { + for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { + Nd4jLong offset = 3 * boundIfNeeded(x, st.inWidth); + sum_y_0 += static_cast(ptr[offset + 0]); + sum_y_1 += static_cast(ptr[offset + 1]); + sum_y_2 += static_cast(ptr[offset + 2]); + } + scaleX = xCache.endMinusOneScale; + offset = st.channels * boundIfNeeded(xCache.end - 1, st.inWidth); + sum_y_0 += static_cast(ptr[offset + 0]) * scaleX; + sum_y_1 += static_cast(ptr[offset + 1]) * scaleX; + sum_y_2 += static_cast(ptr[offset + 2]) * scaleX; } -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - int resizeBicubicFunctor_(sd::LaunchContext * context, NDArray const* image, int width, int height, - bool preserveAspectRatio, bool antialias, NDArray* output) { - return Status::OK(); - } + sum_0 += sum_y_0 * yScaleCache[i].yScale; + sum_1 += sum_y_1 * yScaleCache[i].yScale; + sum_2 += sum_y_2 * yScaleCache[i].yScale; + } + + outputPtr[0] = sum_0 * scale; + outputPtr[1] = sum_1 * scale; + outputPtr[2] = sum_2 * scale; +} - int resizeBicubicFunctor(sd::LaunchContext * context, NDArray const* image, int width, int height, - bool preserveAspectRatio, bool antialias, NDArray* output) { - BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctor_, (context, image, - width, height, preserveAspectRatio, antialias, output), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template int resizeBicubicFunctor_, (sd::LaunchContext * context, NDArray const* image, int width, int height, - bool preserveAspectRatio, bool antialias, NDArray* output), NUMERIC_TYPES); -// ------------------------------------------------------------------------------------------------------------------ // - struct CachedInterpolation { - Nd4jLong start; - Nd4jLong end; - float startScale; - float endMinusOneScale; - bool needsBounding; - }; - - static __global__ void fillInterpolationCache(CachedInterpolation* xCached, Nd4jLong cacheLen, Nd4jLong inWidth, float widthScale) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto increment = blockDim.x * gridDim.x; - - for (auto x = start; x < cacheLen; x += increment) { - auto& xCache = xCached[x]; - const float inX = x * widthScale; - const float inX1 = (x + 1) * widthScale; - - Nd4jLong v = math::nd4j_floor(inX); - xCache.start = v; - xCache.startScale = v < inX ? (v + 1 > inX1 ? widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v : 1.f); - v = math::nd4j_ceil(inX1); - xCache.end = v--; - xCache.endMinusOneScale = v < inX ? (v + 1 > inX1 ? widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v : 1.f); - xCache.needsBounding = bound(xCache.start, inWidth) != xCache.start || bound(xCache.end - 1, inWidth) != (xCache.end - 1); +// Computes the sum of all x values defined by taken across +// the y offsets and scales defined by y_ptrs and y_scales, for channel c. +// +// Note that is a template parameter to avoid a performance +// penalty from dynamically checking it. +template +static __device__ void computePatchSum(float scale, const ImageResizerState& st, + ScaleCache const* yScaleCache, + Nd4jLong ptrsLen, + const CachedInterpolation& xCache, + float* outputPtr) { + bool const needsXBounding = xCache.needsBounding; + + auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { + return (needsXBounding ? bound(x, y) : (x)); + }; + + const auto numChannels = st.channels; + for (Nd4jLong c = 0; c < numChannels; ++c) { + float sum = 0; + for (int i = 0; i < ptrsLen; ++i) { + T const* ptr = yScaleCache[i].yPtr; + float scaleX = xCache.startScale; + float sumY = + static_cast( + ptr[numChannels * boundIfNeeded(xCache.start, st.inWidth) + c]) * + scaleX; + if (xCache.start + 1 != xCache.end) { + for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { + sumY += static_cast( + ptr[numChannels * boundIfNeeded(x, st.inWidth) + c]); } + scaleX = xCache.endMinusOneScale; + sumY += + static_cast( + ptr[numChannels * boundIfNeeded(xCache.end - 1, st.inWidth) + + c]) * + scaleX; + } + sum += sumY * yScaleCache[i].yScale; } + outputPtr[c] = sum * scale; + } +} -// ------------------------------------------------------------------------------------------------------------------ // - template - struct ScaleCache { - float yScale; - T const* yPtr; - }; - - // Computes the sum of all x values defined by taken across - // the y offsets and scales defined by y_ptrs and y_scales, for channel c. - // - // Note that is a template parameter to avoid a performance - // penalty from dynamically checking it. - template - static __device__ void computePatchSumOf3Channels(float scale, - const ImageResizerState& st, - ScaleCache const* yScaleCache, - Nd4jLong ptrsLen, - const CachedInterpolation& xCache, - float* outputPtr) { - - bool const needsXBounding = xCache.needsBounding; - - auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { - return (needsXBounding ? bound(x, y) : (x)); - }; - - float sum_0 = 0; - float sum_1 = 0; - float sum_2 = 0; - for (int i = 0; i < ptrsLen; ++i) { - const T* ptr = yScaleCache[i].yPtr; - float scaleX = xCache.startScale; - Nd4jLong offset = 3 * boundIfNeeded(xCache.start, st.inWidth); - float sum_y_0 = static_cast(ptr[offset + 0]) * scaleX; - float sum_y_1 = static_cast(ptr[offset + 1]) * scaleX; - float sum_y_2 = static_cast(ptr[offset + 2]) * scaleX; - - if (xCache.start + 1 != xCache.end) { - for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { - Nd4jLong offset = 3 * boundIfNeeded(x, st.inWidth); - sum_y_0 += static_cast(ptr[offset + 0]); - sum_y_1 += static_cast(ptr[offset + 1]); - sum_y_2 += static_cast(ptr[offset + 2]); - } - scaleX = xCache.endMinusOneScale; - offset = st.channels * boundIfNeeded(xCache.end - 1, st.inWidth); - sum_y_0 += static_cast(ptr[offset + 0]) * scaleX; - sum_y_1 += static_cast(ptr[offset + 1]) * scaleX; - sum_y_2 += static_cast(ptr[offset + 2]) * scaleX; - } - sum_0 += sum_y_0 * yScaleCache[i].yScale; - sum_1 += sum_y_1 * yScaleCache[i].yScale; - sum_2 += sum_y_2 * yScaleCache[i].yScale; +template +static __global__ void resizeAreaKernel( + ImageResizerState const* pSt, CachedInterpolation const* caches, + float scale, T const* inputPtr, Nd4jLong const* inputShape, + float* outputPtr, Nd4jLong const* outputShape, + ScaleCache* cachePool) { // batch * outWidth * outHeight + + for (auto batch = blockIdx.x; batch < pSt->batchSize; batch += gridDim.x) { + for (auto y = threadIdx.x; y < pSt->outHeight; y += blockDim.x) { + const float inY = y * pSt->heightScale; + const float inY1 = (y + 1) * pSt->heightScale; + // The start and end height indices of all the cells that could + // contribute to the target cell. + const Nd4jLong yStart = math::nd4j_floor(inY); + const Nd4jLong yEnd = math::nd4j_ceil(inY1); + auto scalesDim = yEnd - yStart; + auto yScaleCache = cachePool + (batch * pSt->outWidth + y) * scalesDim * + sizeof(ScaleCache); + + // auto startPtr = sharedPtr + y * scalesDim * sizeof(float); + // float* yScales = yScalesShare + y * sizeof(float) * + // scalesDim;//reinterpret_cast(startPtr); //shared + y * scalesDim + // * y + scalesDim * sizeof(T const *) [scalesDim]; T const** yPtrs = + // yPtrsShare + y * sizeof(T const*) * scalesDim; //[scalesDim]; yPtrs = + // reinterpret_cast(sharedBuf); + float* output = outputPtr + (batch * pSt->outHeight + y) * pSt->channels * + pSt->outWidth; + // int k = 0; + for (Nd4jLong i = yStart, k = 0; i < yEnd; ++i, ++k) { + float scaleY; + if (i < inY) { + scaleY = (i + 1 > inY1 ? pSt->heightScale : i + 1 - inY); + } else { + scaleY = (i + 1 > inY1 ? inY1 - i : 1.0); } - - outputPtr[0] = sum_0 * scale; - outputPtr[1] = sum_1 * scale; - outputPtr[2] = sum_2 * scale; - } - - // Computes the sum of all x values defined by taken across - // the y offsets and scales defined by y_ptrs and y_scales, for channel c. - // - // Note that is a template parameter to avoid a performance - // penalty from dynamically checking it. - template - static __device__ void computePatchSum(float scale, const ImageResizerState& st, - ScaleCache const* yScaleCache, Nd4jLong ptrsLen, - const CachedInterpolation& xCache, - float* outputPtr) { - - bool const needsXBounding = xCache.needsBounding; - - auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong { - return (needsXBounding ? bound(x, y) : (x)); - }; - - const auto numChannels = st.channels; - for (Nd4jLong c = 0; c < numChannels; ++c) { - float sum = 0; - for (int i = 0; i < ptrsLen; ++i) { - T const* ptr = yScaleCache[i].yPtr; - float scaleX = xCache.startScale; - float sumY = static_cast(ptr[numChannels * boundIfNeeded(xCache.start, st.inWidth) + c]) * scaleX; - if (xCache.start + 1 != xCache.end) { - for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) { - sumY += static_cast( - ptr[numChannels * boundIfNeeded(x, st.inWidth) + c]); - } - scaleX = xCache.endMinusOneScale; - sumY += static_cast(ptr[numChannels * boundIfNeeded(xCache.end - 1, st.inWidth) + c]) * scaleX; - } - sum += sumY * yScaleCache[i].yScale; - } - outputPtr[c] = sum * scale; + yScaleCache[k].yScale = scaleY; + yScaleCache[k].yPtr = + inputPtr + (batch * pSt->inHeight * pSt->inWidth * pSt->channels + + bound(i, pSt->inHeight) * pSt->inWidth * pSt->channels); + } + + if (pSt->channels == 3) { + for (Nd4jLong x = 0; x < pSt->outWidth; ++x) { + const CachedInterpolation& xCache = caches[x]; + computePatchSumOf3Channels(scale, *pSt, yScaleCache, scalesDim, + xCache, output); + output += pSt->channels; } - } - - template - static __global__ void resizeAreaKernel(ImageResizerState const* pSt, CachedInterpolation const* caches, float scale, - T const* inputPtr, Nd4jLong const* inputShape, float* outputPtr, Nd4jLong const* outputShape, ScaleCache* cachePool) { //batch * outWidth * outHeight - - for (auto batch = blockIdx.x; batch < pSt->batchSize; batch += gridDim.x) { - for (auto y = threadIdx.x; y < pSt->outHeight; y += blockDim.x) { - const float inY = y * pSt->heightScale; - const float inY1 = (y + 1) * pSt->heightScale; - // The start and end height indices of all the cells that could - // contribute to the target cell. - const Nd4jLong yStart = math::nd4j_floor(inY); - const Nd4jLong yEnd = math::nd4j_ceil(inY1); - auto scalesDim = yEnd - yStart; - auto yScaleCache = cachePool + (batch * pSt->outWidth + y) * scalesDim * sizeof(ScaleCache); - - //auto startPtr = sharedPtr + y * scalesDim * sizeof(float); - //float* yScales = yScalesShare + y * sizeof(float) * scalesDim;//reinterpret_cast(startPtr); //shared + y * scalesDim * y + scalesDim * sizeof(T const *) [scalesDim]; - //T const** yPtrs = yPtrsShare + y * sizeof(T const*) * scalesDim; //[scalesDim]; - //yPtrs = reinterpret_cast(sharedBuf); - float* output = outputPtr + (batch * pSt->outHeight + y) * pSt->channels * pSt->outWidth; - //int k = 0; - for (Nd4jLong i = yStart, k = 0; i < yEnd; ++i, ++k) { - float scaleY; - if (i < inY) { - scaleY = (i + 1 > inY1 ? pSt->heightScale : i + 1 - inY); - } else { - scaleY = (i + 1 > inY1 ? inY1 - i : 1.0); - } - yScaleCache[k].yScale = scaleY; - yScaleCache[k].yPtr = inputPtr + (batch * pSt->inHeight * pSt->inWidth * pSt->channels + bound(i, pSt->inHeight) * pSt->inWidth * pSt->channels); - } - - if (pSt->channels == 3) { - for (Nd4jLong x = 0; x < pSt->outWidth; ++x) { - const CachedInterpolation& xCache = caches[x]; - computePatchSumOf3Channels(scale, *pSt, yScaleCache, scalesDim, xCache, output); - output += pSt->channels; - } - } else { - for (Nd4jLong x = 0; x < pSt->outWidth; ++x) { - const CachedInterpolation &xCache = caches[x]; - computePatchSum(scale, *pSt, yScaleCache, scalesDim, xCache, output); - output += pSt->channels; - } - } - } + } else { + for (Nd4jLong x = 0; x < pSt->outWidth; ++x) { + const CachedInterpolation& xCache = caches[x]; + computePatchSum(scale, *pSt, yScaleCache, scalesDim, xCache, + output); + output += pSt->channels; } + } } + } +} - template - static void resizeArea(cudaStream_t* stream, ImageResizerState const& st, CachedInterpolation* cache, - NDArray const* input, NDArray* output) { - - T const* inputPtr = reinterpret_cast(input->specialBuffer()); -// float* yScales; -// T const** yPtrs; - float scale = 1.f / (st.heightScale * st.widthScale); - auto outputPtr = reinterpret_cast(output->specialBuffer()); // output is always float. TO DO: provide another float types also with template declaration - ImageResizerState* pSt; - auto err = cudaMalloc(&pSt, sizeof(ImageResizerState)); - err = cudaMemcpyAsync(pSt, &st, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream); - ScaleCache* cachePool; - err = cudaMalloc(&cachePool, sizeof(ScaleCache) * st.batchSize * st.outWidth * st.outHeight); - resizeAreaKernel<<<128, 2, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->specialShapeInfo(), outputPtr, - output->specialShapeInfo(), cachePool); - err = cudaStreamSynchronize(*stream); - err = cudaFree(cachePool); - err = cudaFree(pSt); - } -// ------------------------------------------------------------------------------------------------------------------ // - template - int resizeAreaFunctor_(sd::LaunchContext* context, NDArray const* image, int const width, int const height, - bool const alignCorners, NDArray* output) { - - ImageResizerState st(alignCorners, false); // Create resize info - auto res = st.validateAndCalculateOutputSize(image, width, height); - auto stream = context->getCudaStream(); - if (Status::OK() == res) { - CachedInterpolation* xCached; - //(st.outWidth); - auto err = cudaMalloc(&xCached, sizeof(CachedInterpolation) * st.outWidth); - NDArray::prepareSpecialUse({output}, {image}); - fillInterpolationCache<<<128, 128, 256, *stream>>>(xCached, st.outWidth, st.inWidth, st.widthScale); - resizeArea(stream, st, xCached, image, output); - err = cudaStreamSynchronize(*stream); - err = cudaFree(xCached); - NDArray::registerSpecialUse({output}, {image}); - } - - return res; - } - int resizeAreaFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool const alignCorners, NDArray* output) { - BUILD_SINGLE_SELECTOR(image->dataType(), return resizeAreaFunctor_, (context, image, width, height, alignCorners, output), NUMERIC_TYPES); - } +template +static void resizeArea(cudaStream_t* stream, ImageResizerState const& st, + CachedInterpolation* cache, NDArray const* input, + NDArray* output) { + T const* inputPtr = reinterpret_cast(input->specialBuffer()); + // float* yScales; + // T const** yPtrs; + float scale = 1.f / (st.heightScale * st.widthScale); + auto outputPtr = reinterpret_cast( + output->specialBuffer()); // output is always float. TO DO: provide + // another float types also with template + // declaration + ImageResizerState* pSt; + auto err = cudaMalloc(&pSt, sizeof(ImageResizerState)); + err = cudaMemcpyAsync(pSt, &st, sizeof(ImageResizerState), + cudaMemcpyHostToDevice, *stream); + ScaleCache* cachePool; + err = cudaMalloc(&cachePool, sizeof(ScaleCache) * st.batchSize * + st.outWidth * st.outHeight); + resizeAreaKernel<<<128, 2, 2048, *stream>>>( + pSt, cache, scale, inputPtr, input->specialShapeInfo(), outputPtr, + output->specialShapeInfo(), cachePool); + err = cudaStreamSynchronize(*stream); + err = cudaFree(cachePool); + err = cudaFree(pSt); +} +// ------------------------------------------------------------------------------------------------------------------ +// // +template +int resizeAreaFunctor_(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, NDArray* output) { + ImageResizerState st(alignCorners, false); // Create resize info + auto res = st.validateAndCalculateOutputSize(image, width, height); + auto stream = context->getCudaStream(); + if (Status::OK() == res) { + CachedInterpolation* xCached; + //(st.outWidth); + auto err = cudaMalloc(&xCached, sizeof(CachedInterpolation) * st.outWidth); + NDArray::prepareSpecialUse({output}, {image}); + fillInterpolationCache<<<128, 128, 256, *stream>>>( + xCached, st.outWidth, st.inWidth, st.widthScale); + resizeArea(stream, st, xCached, image, output); + err = cudaStreamSynchronize(*stream); + err = cudaFree(xCached); + NDArray::registerSpecialUse({output}, {image}); + } + + return res; +} +int resizeAreaFunctor(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, NDArray* output) { + BUILD_SINGLE_SELECTOR(image->dataType(), return resizeAreaFunctor_, + (context, image, width, height, alignCorners, output), + NUMERIC_TYPES); +} -// ------------------------------------------------------------------------------------------------------------------ // -// simplified bicubic resize without antialiasing +// ------------------------------------------------------------------------------------------------------------------ +// // simplified bicubic resize without antialiasing // - template - int resizeBicubicFunctorA_(sd::LaunchContext * context, NDArray const* image, int width, int height, - bool const alignCorners, bool const halfPixelCenters, NDArray* output) { - - ImageResizerState st(alignCorners, halfPixelCenters); // align_corners, half_pixel_align - st.stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {image}); - int res = st.validateAndCreateOutput(image, width, height); - if (res == Status::OK()) - bicubicInterpolateWithCaching(image, st, halfPixelCenters, output); - NDArray::registerSpecialUse({output}, {image}); - return res; - } +template +int resizeBicubicFunctorA_(sd::LaunchContext* context, NDArray const* image, + int width, int height, bool const alignCorners, + bool const halfPixelCenters, NDArray* output) { + ImageResizerState st(alignCorners, + halfPixelCenters); // align_corners, half_pixel_align + st.stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {image}); + int res = st.validateAndCreateOutput(image, width, height); + if (res == Status::OK()) + bicubicInterpolateWithCaching(image, st, halfPixelCenters, output); + NDArray::registerSpecialUse({output}, {image}); + return res; +} - int resizeBicubicFunctorA(sd::LaunchContext * context, NDArray const* image, int width, int height, - bool const alignCorners, bool const halfPixelCenters, NDArray* output) { - BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context, - image, width, height, alignCorners, halfPixelCenters, output), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template int resizeBicubicFunctorA_, (sd::LaunchContext * context, - NDArray const* image, int width, int height, bool const alignCorners, bool const halfPixelCenters, NDArray* output), NUMERIC_TYPES); +int resizeBicubicFunctorA(sd::LaunchContext* context, NDArray const* image, + int width, int height, bool const alignCorners, + bool const halfPixelCenters, NDArray* output) { + BUILD_SINGLE_SELECTOR( + image->dataType(), return resizeBicubicFunctorA_, + (context, image, width, height, alignCorners, halfPixelCenters, output), + NUMERIC_TYPES); +} +BUILD_SINGLE_TEMPLATE(template int resizeBicubicFunctorA_, + (sd::LaunchContext * context, NDArray const* image, + int width, int height, bool const alignCorners, + bool const halfPixelCenters, NDArray* output), + NUMERIC_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - int resizeFunctor(sd::LaunchContext * context, NDArray const* image, int width, int height, - ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) { - switch (method) { - case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, false, output); break; - case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, false, output); break; - case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break; - case kResizeLanczos5: - case kResizeGaussian: - case kResizeArea: - case kResizeMitchelcubic: - throw std::runtime_error("helper::resizeFunctor: Non implemented yet."); - } - return ND4J_STATUS_OK; +int resizeFunctor(sd::LaunchContext* context, NDArray const* image, int width, + int height, ImageResizeMethods method, + bool preserveAspectRatio, bool antialias, NDArray* output) { + switch (method) { + case kResizeBilinear: + return resizeBilinearFunctor(context, image, width, height, false, false, + output); + break; + case kResizeNearest: + return resizeNeighborFunctor(context, image, width, height, false, false, + output); + break; + case kResizeBicubic: + return resizeBicubicFunctor(context, image, width, height, + preserveAspectRatio, antialias, output); + break; + case kResizeLanczos5: + case kResizeGaussian: + case kResizeArea: + case kResizeMitchelcubic: + throw std::runtime_error("helper::resizeFunctor: Non implemented yet."); + } + return ND4J_STATUS_OK; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// --------------------------------------------------------------------------------------------------------------- +// // Crop and Resize helper implementation +// -------------------------------------------------------------------------------------------------------------- +// // cropAndResize kernel type of input(images) and output should be the same +// +template +static __global__ void cropAndResizeKernel( + T const* images, Nd4jLong const* imagesShape, Z const* boxes, + Nd4jLong const* boxesShape, I const* indices, Nd4jLong const* indexShape, + I const* cropSize, Nd4jLong const* cropShape, int method, + double extrapolationVal, T* output, Nd4jLong const* outputShape, + int numBoxes, int cropHeight, int cropWidth, int batchSize, int imageHeight, + int imageWidth, int depth) { + for (int b = blockIdx.x; b < numBoxes; b += gridDim.x) { + Nd4jLong x1Pos[] = {b, 1}; + Nd4jLong y1Pos[] = {b, 0}; + Nd4jLong y2Pos[] = {b, 2}; + Nd4jLong x2Pos[] = {b, 3}; + Z y1 = boxes[shape::getOffset(boxesShape, y1Pos)]; //->t(b, 0)]; + Z x1 = boxes[shape::getOffset(boxesShape, x1Pos)]; + Z y2 = boxes[shape::getOffset(boxesShape, y2Pos)]; + Z x2 = boxes[shape::getOffset(boxesShape, x2Pos)]; + + int bIn = indices[b]; + if (bIn >= batchSize) { + continue; } - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // --------------------------------------------------------------------------------------------------------------- // - // Crop and Resize helper implementation - // -------------------------------------------------------------------------------------------------------------- // - // cropAndResize kernel type of input(images) and output should be the same - // - template - static __global__ void cropAndResizeKernel(T const *images, Nd4jLong const* imagesShape, Z const* boxes, Nd4jLong const* boxesShape, - I const* indices, Nd4jLong const* indexShape, I const* cropSize, Nd4jLong const* cropShape, int method, - double extrapolationVal, T* output, Nd4jLong const* outputShape, int numBoxes, int cropHeight, int cropWidth, - int batchSize, int imageHeight, int imageWidth, int depth) { - - for (int b = blockIdx.x; b < numBoxes; b += gridDim.x) - { - Nd4jLong x1Pos[] = {b, 1}; - Nd4jLong y1Pos[] = {b, 0}; - Nd4jLong y2Pos[] = {b, 2}; - Nd4jLong x2Pos[] = {b, 3}; - Z y1 = boxes[shape::getOffset(boxesShape, y1Pos)];//->t(b, 0)]; - Z x1 = boxes[shape::getOffset(boxesShape, x1Pos)]; - Z y2 = boxes[shape::getOffset(boxesShape, y2Pos)]; - Z x2 = boxes[shape::getOffset(boxesShape, x2Pos)]; - - int bIn = indices[b]; - if (bIn >= batchSize) { - continue; + Z heightScale = (cropHeight > 1) + ? (y2 - y1) * (imageHeight - 1) / Z(cropHeight - 1) + : Z(0); + Z widthScale = (cropWidth > 1) + ? (x2 - x1) * (imageWidth - 1) / Z(cropWidth - 1) + : Z(0); + + for (int y = threadIdx.x; y < cropHeight; y += blockDim.x) { + const float inY = (cropHeight > 1) + ? y1 * (imageHeight - 1) + y * heightScale + : 0.5 * (y1 + y2) * (imageHeight - 1); + if (inY < 0 || inY > imageHeight - 1) { + for (int x = threadIdx.y; x < cropWidth; x += blockDim.y) { + auto start = blockIdx.z * blockDim.x + threadIdx.z; + auto step = blockDim.z * gridDim.z; + for (int d = start; d < depth; d += step) { + Nd4jLong zPos[] = {b, y, x, d}; + auto zIndex = shape::getOffset(outputShape, zPos); + output[zIndex] = (Z)extrapolationVal; + // crops->p(b, y, x, d, extrapolationVal); + } + } + continue; + } + + if (method == 0 /* bilinear */) { + const int topYIndex = sd::math::p_floor(inY); + const int bottomYIndex = sd::math::p_ceil(inY); + const float y_lerp = inY - topYIndex; + + for (int x = 0; x < cropWidth; ++x) { + const float in_x = (cropWidth > 1) + ? x1 * (imageWidth - 1) + x * widthScale + : 0.5 * (x1 + x2) * (imageWidth - 1); + if (in_x < 0 || in_x > imageWidth - 1) { + auto start = blockIdx.z * blockDim.x + threadIdx.z; + auto step = blockDim.z * gridDim.z; + for (int d = start; d < depth; d += step) { + Nd4jLong zPos[] = {b, y, x, d}; + auto zIndex = shape::getOffset(outputShape, zPos); + output[zIndex] = (Z)extrapolationVal; + // crops->p(b, y, x, d, + // extrapolationVal); } - - Z heightScale = (cropHeight > 1) ? (y2 - y1) * (imageHeight - 1) / Z(cropHeight - 1) : Z(0); - Z widthScale = (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / Z(cropWidth - 1) : Z(0); - - for (int y = threadIdx.x; y < cropHeight; y += blockDim.x) { - const float inY = (cropHeight > 1) - ? y1 * (imageHeight - 1) + y * heightScale - : 0.5 * (y1 + y2) * (imageHeight - 1); - if (inY < 0 || inY > imageHeight - 1) { - for (int x = threadIdx.y; x < cropWidth; x += blockDim.y) { - auto start = blockIdx.z * blockDim.x + threadIdx.z; - auto step = blockDim.z * gridDim.z; - for (int d = start; d < depth; d += step) { - Nd4jLong zPos[] = {b, y, x, d}; - auto zIndex = shape::getOffset(outputShape, zPos); - output[zIndex] = (Z)extrapolationVal; - //crops->p(b, y, x, d, extrapolationVal); - } - } - continue; - } - - if (method == 0 /* bilinear */) { - const int topYIndex = sd::math::p_floor(inY); - const int bottomYIndex = sd::math::p_ceil(inY); - const float y_lerp = inY - topYIndex; - - for (int x = 0; x < cropWidth; ++x) { - const float in_x = (cropWidth > 1) - ? x1 * (imageWidth - 1) + x * widthScale - : 0.5 * (x1 + x2) * (imageWidth - 1); - if (in_x < 0 || in_x > imageWidth - 1) { - auto start = blockIdx.z * blockDim.x + threadIdx.z; - auto step = blockDim.z * gridDim.z; - for (int d = start; d < depth; d += step) { - Nd4jLong zPos[] = {b, y, x, d}; - auto zIndex = shape::getOffset(outputShape, zPos); - output[zIndex] = (Z)extrapolationVal; -// crops->p(b, y, x, d, extrapolationVal); - } - continue; - } - int left_x_index = math::p_floor(in_x); - int right_x_index = math::p_ceil(in_x); - T x_lerp = in_x - left_x_index; - - auto start = blockIdx.z * blockDim.x + threadIdx.z; - auto step = blockDim.z * gridDim.z; - for (int d = start; d < depth; d += step) { - Nd4jLong topLeftPos[] = {bIn, topYIndex, left_x_index, d}; - Nd4jLong topRightPos[] = {bIn, topYIndex, right_x_index, d}; - Nd4jLong bottomLeftPos[] = {bIn, bottomYIndex, left_x_index, d}; - Nd4jLong bottomRightPos[] = {bIn, bottomYIndex, right_x_index, d}; - const T topLeft(images[shape::getOffset(imagesShape, topLeftPos)]); //->e(bIn, topYIndex, left_x_index, d)); - const T topRight(images[shape::getOffset(imagesShape, topRightPos)]); //->e(bIn, topYIndex, right_x_index, d)); - const T bottomLeft(images[shape::getOffset(imagesShape, bottomLeftPos)]);//->e(bIn, bottomYIndex, left_x_index, d)); - const T bottomRight(images[shape::getOffset(imagesShape, bottomRightPos)]); //->e(bIn, bottomYIndex, right_x_index, d)); - const T top = topLeft + (topRight - topLeft) * x_lerp; - const T bottom = bottomLeft + (bottomRight - bottomLeft) * x_lerp; - Nd4jLong zPos[] = {b, y, x, d}; - auto zIndex = shape::getOffset(outputShape, zPos); - output[zIndex] = Z(top + (bottom - top) * y_lerp); - } - } - } else { // method is "nearest neighbor" - for (int x = 0; x < cropWidth; ++x) { - const float inX = (cropWidth > 1) - ? x1 * (imageWidth - 1) + x * widthScale - : 0.5 * (x1 + x2) * (imageWidth - 1); - if (inX < 0 || inX > imageWidth - 1) { - auto start = blockIdx.z * blockDim.x + threadIdx.z; - auto step = blockDim.z * gridDim.z; - for (int d = start; d < depth; d += step) { - Nd4jLong zPos[] = {b, y, x, d}; - auto zIndex = shape::getOffset(outputShape, zPos); - output[zIndex] = (Z)extrapolationVal; - } - continue; - } - const int closestXIndex = roundf(inX); - const int closestYIndex = roundf(inY); - auto start = blockIdx.z * blockDim.x + threadIdx.z; - auto step = blockDim.z * gridDim.z; - for (int d = start; d < depth; d += step) { - Nd4jLong zPos[] = {b, y, x, d}; - Nd4jLong xPos[] = {bIn, closestYIndex, closestXIndex, d}; - auto zIndex = shape::getOffset(outputShape, zPos); - auto xIndex = shape::getOffset(imagesShape, xPos); - output[zIndex] = images[xIndex]; - } - } - } + continue; + } + int left_x_index = math::p_floor(in_x); + int right_x_index = math::p_ceil(in_x); + T x_lerp = in_x - left_x_index; + + auto start = blockIdx.z * blockDim.x + threadIdx.z; + auto step = blockDim.z * gridDim.z; + for (int d = start; d < depth; d += step) { + Nd4jLong topLeftPos[] = {bIn, topYIndex, left_x_index, d}; + Nd4jLong topRightPos[] = {bIn, topYIndex, right_x_index, d}; + Nd4jLong bottomLeftPos[] = {bIn, bottomYIndex, left_x_index, d}; + Nd4jLong bottomRightPos[] = {bIn, bottomYIndex, right_x_index, d}; + const T topLeft(images[shape::getOffset( + imagesShape, + topLeftPos)]); //->e(bIn, topYIndex, left_x_index, d)); + const T topRight(images[shape::getOffset( + imagesShape, topRightPos)]); //->e(bIn, topYIndex, + //right_x_index, d)); + const T bottomLeft(images[shape::getOffset( + imagesShape, bottomLeftPos)]); //->e(bIn, bottomYIndex, + //left_x_index, d)); + const T bottomRight(images[shape::getOffset( + imagesShape, bottomRightPos)]); //->e(bIn, bottomYIndex, + //right_x_index, d)); + const T top = topLeft + (topRight - topLeft) * x_lerp; + const T bottom = bottomLeft + (bottomRight - bottomLeft) * x_lerp; + Nd4jLong zPos[] = {b, y, x, d}; + auto zIndex = shape::getOffset(outputShape, zPos); + output[zIndex] = Z(top + (bottom - top) * y_lerp); + } + } + } else { // method is "nearest neighbor" + for (int x = 0; x < cropWidth; ++x) { + const float inX = (cropWidth > 1) + ? x1 * (imageWidth - 1) + x * widthScale + : 0.5 * (x1 + x2) * (imageWidth - 1); + if (inX < 0 || inX > imageWidth - 1) { + auto start = blockIdx.z * blockDim.x + threadIdx.z; + auto step = blockDim.z * gridDim.z; + for (int d = start; d < depth; d += step) { + Nd4jLong zPos[] = {b, y, x, d}; + auto zIndex = shape::getOffset(outputShape, zPos); + output[zIndex] = (Z)extrapolationVal; } + continue; + } + const int closestXIndex = roundf(inX); + const int closestYIndex = roundf(inY); + auto start = blockIdx.z * blockDim.x + threadIdx.z; + auto step = blockDim.z * gridDim.z; + for (int d = start; d < depth; d += step) { + Nd4jLong zPos[] = {b, y, x, d}; + Nd4jLong xPos[] = {bIn, closestYIndex, closestXIndex, d}; + auto zIndex = shape::getOffset(outputShape, zPos); + auto xIndex = shape::getOffset(imagesShape, xPos); + output[zIndex] = images[xIndex]; + } } - + } } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // cropAndResizeFunctor main algorithm @@ -1325,43 +1577,59 @@ namespace helpers { // extrapolationVal - double value of extrapolation // crops - output (4D tensor - [batch, outWidth, outHeight, pixels]) // - template - void cropAndResizeFunctor_(sd::LaunchContext* context, NDArray const *images, NDArray const *boxes, NDArray const *indices, - NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops) { - const int batchSize = images->sizeAt(0); - const int imageHeight = images->sizeAt(1); - const int imageWidth = images->sizeAt(2); - - const int numBoxes = crops->sizeAt(0); - const int cropHeight = crops->sizeAt(1); - const int cropWidth = crops->sizeAt(2); - const int depth = crops->sizeAt(3); - auto stream = context->getCudaStream(); - T const* imagesBuf = reinterpret_cast(images->specialBuffer()); - Z const* boxesBuf = reinterpret_cast(boxes->specialBuffer()); - I const* indexBuf = reinterpret_cast(indices->specialBuffer()); - I const* cropSizes = reinterpret_cast(cropSize->specialBuffer()); - T* outBuf = reinterpret_cast(crops->specialBuffer()); - - int threadsPerBlock = math::nd4j_max(imageHeight * imageWidth, cropHeight * cropWidth); - if(threadsPerBlock > MAX_NUM_THREADS/4) - threadsPerBlock = MAX_NUM_THREADS/4; - - NDArray::prepareSpecialUse({crops}, {images, boxes, indices, cropSize}); - cropAndResizeKernel<<>>(imagesBuf, images->specialShapeInfo(), boxesBuf, boxes->specialShapeInfo(), indexBuf, indices->specialShapeInfo(), - cropSizes, cropSize->specialShapeInfo(), method, extrapolationVal, outBuf, crops->specialShapeInfo(), numBoxes, cropHeight, cropWidth, batchSize, imageHeight, imageWidth, depth); - NDArray::registerSpecialUse({crops}, {images, boxes, indices, cropSize}); - } +template +void cropAndResizeFunctor_(sd::LaunchContext* context, NDArray const* images, + NDArray const* boxes, NDArray const* indices, + NDArray const* cropSize, int method, + double extrapolationVal, NDArray* crops) { + const int batchSize = images->sizeAt(0); + const int imageHeight = images->sizeAt(1); + const int imageWidth = images->sizeAt(2); + + const int numBoxes = crops->sizeAt(0); + const int cropHeight = crops->sizeAt(1); + const int cropWidth = crops->sizeAt(2); + const int depth = crops->sizeAt(3); + auto stream = context->getCudaStream(); + T const* imagesBuf = reinterpret_cast(images->specialBuffer()); + Z const* boxesBuf = reinterpret_cast(boxes->specialBuffer()); + I const* indexBuf = reinterpret_cast(indices->specialBuffer()); + I const* cropSizes = reinterpret_cast(cropSize->specialBuffer()); + T* outBuf = reinterpret_cast(crops->specialBuffer()); + + int threadsPerBlock = + math::nd4j_max(imageHeight * imageWidth, cropHeight * cropWidth); + if (threadsPerBlock > MAX_NUM_THREADS / 4) + threadsPerBlock = MAX_NUM_THREADS / 4; + + NDArray::prepareSpecialUse({crops}, {images, boxes, indices, cropSize}); + cropAndResizeKernel<<>>( + imagesBuf, images->specialShapeInfo(), boxesBuf, + boxes->specialShapeInfo(), indexBuf, indices->specialShapeInfo(), + cropSizes, cropSize->specialShapeInfo(), method, extrapolationVal, outBuf, + crops->specialShapeInfo(), numBoxes, cropHeight, cropWidth, batchSize, + imageHeight, imageWidth, depth); + NDArray::registerSpecialUse({crops}, {images, boxes, indices, cropSize}); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void cropAndResizeFunctor(sd::LaunchContext * context, NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, int method, double extrapolationVal, NDArray *crops) { - BUILD_TRIPLE_SELECTOR(images->dataType(), boxes->dataType(), indices->dataType(), cropAndResizeFunctor_, - (context, images, boxes, indices, cropSize, method, extrapolationVal, crops), NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); - // - } - BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, - (sd::LaunchContext * context, NDArray const* images, NDArray const* boxes, NDArray const* indices, NDArray const* cropSize, int method, double extrapolationVal, NDArray* crops), - NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); -} +void cropAndResizeFunctor(sd::LaunchContext* context, NDArray const* images, + NDArray const* boxes, NDArray const* indices, + NDArray const* cropSize, int method, + double extrapolationVal, NDArray* crops) { + BUILD_TRIPLE_SELECTOR(images->dataType(), boxes->dataType(), + indices->dataType(), cropAndResizeFunctor_, + (context, images, boxes, indices, cropSize, method, + extrapolationVal, crops), + NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); + // } -} \ No newline at end of file +BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, + (sd::LaunchContext * context, NDArray const* images, + NDArray const* boxes, NDArray const* indices, + NDArray const* cropSize, int method, + double extrapolationVal, NDArray* crops), + NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu index 8b7e8ee57074..87f6d6fc7982 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu @@ -18,10 +18,11 @@ // @author sgazeos@gmail.com // -#include #include -#include #include +#include +#include + #include namespace sd { @@ -37,378 +38,484 @@ namespace helpers { // // return value: true, if threshold is overcome, false otherwise // - template - static __device__ bool needToSuppressWithThreshold(T* boxes, Nd4jLong const* boxesShape, int previousIndex, int nextIndex, T threshold) { - Nd4jLong previous0[] = {previousIndex, 0}; - Nd4jLong previous1[] = {previousIndex, 1}; - Nd4jLong previous2[] = {previousIndex, 2}; - Nd4jLong previous3[] = {previousIndex, 3}; - Nd4jLong next0[] = {nextIndex, 0}; - Nd4jLong next1[] = {nextIndex, 1}; - Nd4jLong next2[] = {nextIndex, 2}; - Nd4jLong next3[] = {nextIndex, 3}; - - // we have rectangle with given max values. Compute vexes of rectangle first - - T minYPrev = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous0)], boxes[shape::getOffset(boxesShape, previous2)]); - T minXPrev = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous1)], boxes[shape::getOffset(boxesShape, previous3)]); - T maxYPrev = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous0)], boxes[shape::getOffset(boxesShape, previous2)]); - T maxXPrev = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous1)], boxes[shape::getOffset(boxesShape, previous3)]); - T minYNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next0)], boxes[shape::getOffset(boxesShape, next2)]); - T minXNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next1)], boxes[shape::getOffset(boxesShape, next3)]); - T maxYNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next0)], boxes[shape::getOffset(boxesShape, next2)]); - T maxXNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next1)], boxes[shape::getOffset(boxesShape, next3)]); - - // compute areas for comparation - T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); - T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); - - // of course, areas should be positive - if (areaNext <= T(0.f) || areaPrev <= T(0.f)) return false; - - // compute intersection of rectangles - T minIntersectionY = sd::math::nd4j_max(minYPrev, minYNext); - T minIntersectionX = sd::math::nd4j_max(minXPrev, minXNext); - T maxIntersectionY = sd::math::nd4j_min(maxYPrev, maxYNext); - T maxIntersectionX = sd::math::nd4j_min(maxXPrev, maxXNext); - T intersectionArea = - sd::math::nd4j_max(T(maxIntersectionY - minIntersectionY), T(0.0f)) * - sd::math::nd4j_max(T(maxIntersectionX - minIntersectionX), T(0.0f)); - T intersectionValue = intersectionArea / (areaPrev + areaNext - intersectionArea); - // final check - return intersectionValue > threshold; - } +template +static __device__ bool needToSuppressWithThreshold(T* boxes, + Nd4jLong const* boxesShape, + int previousIndex, + int nextIndex, T threshold) { + Nd4jLong previous0[] = {previousIndex, 0}; + Nd4jLong previous1[] = {previousIndex, 1}; + Nd4jLong previous2[] = {previousIndex, 2}; + Nd4jLong previous3[] = {previousIndex, 3}; + Nd4jLong next0[] = {nextIndex, 0}; + Nd4jLong next1[] = {nextIndex, 1}; + Nd4jLong next2[] = {nextIndex, 2}; + Nd4jLong next3[] = {nextIndex, 3}; + + // we have rectangle with given max values. Compute vexes of rectangle first + + T minYPrev = + sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous0)], + boxes[shape::getOffset(boxesShape, previous2)]); + T minXPrev = + sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous1)], + boxes[shape::getOffset(boxesShape, previous3)]); + T maxYPrev = + sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous0)], + boxes[shape::getOffset(boxesShape, previous2)]); + T maxXPrev = + sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous1)], + boxes[shape::getOffset(boxesShape, previous3)]); + T minYNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next0)], + boxes[shape::getOffset(boxesShape, next2)]); + T minXNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next1)], + boxes[shape::getOffset(boxesShape, next3)]); + T maxYNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next0)], + boxes[shape::getOffset(boxesShape, next2)]); + T maxXNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next1)], + boxes[shape::getOffset(boxesShape, next3)]); + + // compute areas for comparation + T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); + T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); + + // of course, areas should be positive + if (areaNext <= T(0.f) || areaPrev <= T(0.f)) return false; + + // compute intersection of rectangles + T minIntersectionY = sd::math::nd4j_max(minYPrev, minYNext); + T minIntersectionX = sd::math::nd4j_max(minXPrev, minXNext); + T maxIntersectionY = sd::math::nd4j_min(maxYPrev, maxYNext); + T maxIntersectionX = sd::math::nd4j_min(maxXPrev, maxXNext); + T intersectionArea = + sd::math::nd4j_max(T(maxIntersectionY - minIntersectionY), T(0.0f)) * + sd::math::nd4j_max(T(maxIntersectionX - minIntersectionX), T(0.0f)); + T intersectionValue = + intersectionArea / (areaPrev + areaNext - intersectionArea); + // final check + return intersectionValue > threshold; +} - template - static __device__ T similirityV3(T* boxes, Nd4jLong const* boxesShape, int previousIndex, int nextIndex) { - Nd4jLong previous0[] = {previousIndex, 0}; - Nd4jLong previous1[] = {previousIndex, 1}; - Nd4jLong previous2[] = {previousIndex, 2}; - Nd4jLong previous3[] = {previousIndex, 3}; - Nd4jLong next0[] = {nextIndex, 0}; - Nd4jLong next1[] = {nextIndex, 1}; - Nd4jLong next2[] = {nextIndex, 2}; - Nd4jLong next3[] = {nextIndex, 3}; - - // we have rectangle with given max values. Compute vexes of rectangle first - - T minYPrev = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous0)], boxes[shape::getOffset(boxesShape, previous2)]); - T minXPrev = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous1)], boxes[shape::getOffset(boxesShape, previous3)]); - T maxYPrev = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous0)], boxes[shape::getOffset(boxesShape, previous2)]); - T maxXPrev = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous1)], boxes[shape::getOffset(boxesShape, previous3)]); - T minYNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next0)], boxes[shape::getOffset(boxesShape, next2)]); - T minXNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next1)], boxes[shape::getOffset(boxesShape, next3)]); - T maxYNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next0)], boxes[shape::getOffset(boxesShape, next2)]); - T maxXNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next1)], boxes[shape::getOffset(boxesShape, next3)]); - - // compute areas for comparation - T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); - T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); - - // of course, areas should be positive - if (areaNext <= T(0.f) || areaPrev <= T(0.f)) return false; - - // compute intersection of rectangles - T minIntersectionY = sd::math::nd4j_max(minYPrev, minYNext); - T minIntersectionX = sd::math::nd4j_max(minXPrev, minXNext); - T maxIntersectionY = sd::math::nd4j_min(maxYPrev, maxYNext); - T maxIntersectionX = sd::math::nd4j_min(maxXPrev, maxXNext); - T intersectionArea = - sd::math::nd4j_max(T(maxIntersectionY - minIntersectionY), T(0.0f)) * - sd::math::nd4j_max(T(maxIntersectionX - minIntersectionX), T(0.0f)); - T intersectionValue = intersectionArea / (areaPrev + areaNext - intersectionArea); - // final check - return intersectionValue; - } +template +static __device__ T similirityV3(T* boxes, Nd4jLong const* boxesShape, + int previousIndex, int nextIndex) { + Nd4jLong previous0[] = {previousIndex, 0}; + Nd4jLong previous1[] = {previousIndex, 1}; + Nd4jLong previous2[] = {previousIndex, 2}; + Nd4jLong previous3[] = {previousIndex, 3}; + Nd4jLong next0[] = {nextIndex, 0}; + Nd4jLong next1[] = {nextIndex, 1}; + Nd4jLong next2[] = {nextIndex, 2}; + Nd4jLong next3[] = {nextIndex, 3}; + + // we have rectangle with given max values. Compute vexes of rectangle first + + T minYPrev = + sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous0)], + boxes[shape::getOffset(boxesShape, previous2)]); + T minXPrev = + sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous1)], + boxes[shape::getOffset(boxesShape, previous3)]); + T maxYPrev = + sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous0)], + boxes[shape::getOffset(boxesShape, previous2)]); + T maxXPrev = + sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous1)], + boxes[shape::getOffset(boxesShape, previous3)]); + T minYNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next0)], + boxes[shape::getOffset(boxesShape, next2)]); + T minXNext = sd::math::nd4j_min(boxes[shape::getOffset(boxesShape, next1)], + boxes[shape::getOffset(boxesShape, next3)]); + T maxYNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next0)], + boxes[shape::getOffset(boxesShape, next2)]); + T maxXNext = sd::math::nd4j_max(boxes[shape::getOffset(boxesShape, next1)], + boxes[shape::getOffset(boxesShape, next3)]); + + // compute areas for comparation + T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); + T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); + + // of course, areas should be positive + if (areaNext <= T(0.f) || areaPrev <= T(0.f)) return false; + + // compute intersection of rectangles + T minIntersectionY = sd::math::nd4j_max(minYPrev, minYNext); + T minIntersectionX = sd::math::nd4j_max(minXPrev, minXNext); + T maxIntersectionY = sd::math::nd4j_min(maxYPrev, maxYNext); + T maxIntersectionX = sd::math::nd4j_min(maxXPrev, maxXNext); + T intersectionArea = + sd::math::nd4j_max(T(maxIntersectionY - minIntersectionY), T(0.0f)) * + sd::math::nd4j_max(T(maxIntersectionX - minIntersectionX), T(0.0f)); + T intersectionValue = + intersectionArea / (areaPrev + areaNext - intersectionArea); + // final check + return intersectionValue; +} - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // shouldSelectKernel - compute status for all selected rectangles (boxes) // -// we compute boolean flag as shared uint32 and return it on final only for the first thread +// we compute boolean flag as shared uint32 and return it on final only for the +// first thread // - template - static __global__ void shouldSelectKernel(T* boxesBuf, Nd4jLong const* boxesShape, I* indexBuf, I* selectedIndicesData, double threshold, int numSelected, int i, bool* shouldSelect) { - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; - __shared__ unsigned int shouldSelectShared; - if (threadIdx.x == 0) { - shouldSelectShared = (unsigned int)shouldSelect[0]; - } - __syncthreads(); - for (int j = numSelected - 1 - tid; j >= 0; j -= step) { - if (shouldSelectShared) { - if (needToSuppressWithThreshold(boxesBuf, boxesShape, indexBuf[i], - indexBuf[selectedIndicesData[j]], T(threshold))) - atomicCAS(&shouldSelectShared, 1, 0); // exchange only when need to suppress - } - } - __syncthreads(); - - // final move: collect result - if (threadIdx.x == 0) { - *shouldSelect = shouldSelectShared > 0; - } +template +static __global__ void shouldSelectKernel(T* boxesBuf, + Nd4jLong const* boxesShape, + I* indexBuf, I* selectedIndicesData, + double threshold, int numSelected, + int i, bool* shouldSelect) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; + __shared__ unsigned int shouldSelectShared; + if (threadIdx.x == 0) { + shouldSelectShared = (unsigned int)shouldSelect[0]; + } + __syncthreads(); + for (int j = numSelected - 1 - tid; j >= 0; j -= step) { + if (shouldSelectShared) { + if (needToSuppressWithThreshold(boxesBuf, boxesShape, indexBuf[i], + indexBuf[selectedIndicesData[j]], + T(threshold))) + atomicCAS(&shouldSelectShared, 1, + 0); // exchange only when need to suppress } + } + __syncthreads(); + + // final move: collect result + if (threadIdx.x == 0) { + *shouldSelect = shouldSelectShared > 0; + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // indices - type depended, indicesLong - type defined (only 64bit integers) // - template - static __global__ void copyIndices(void* indices, void* indicesLong, Nd4jLong len) { - I* indexBuf = reinterpret_cast(indices); - Nd4jLong* srcBuf = reinterpret_cast(indicesLong);; +template +static __global__ void copyIndices(void* indices, void* indicesLong, + Nd4jLong len) { + I* indexBuf = reinterpret_cast(indices); + Nd4jLong* srcBuf = reinterpret_cast(indicesLong); + ; - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; - for (auto i = tid; i < len; i += step) - indexBuf[i] = (I)srcBuf[i]; - } + for (auto i = tid; i < len; i += step) indexBuf[i] = (I)srcBuf[i]; +} - template - static __global__ void suppressScores(T* scores, I* indices, Nd4jLong length, T scoreThreshold) { - auto start = blockIdx.x * blockDim.x; - auto step = gridDim.x * blockDim.x; - - for (auto e = start + threadIdx.x; e < (int)length; e += step) { - if (scores[e] < scoreThreshold) { - scores[e] = scoreThreshold; - indices[e] = -1; - } - else { - indices[e] = I(e); - } - } +template +static __global__ void suppressScores(T* scores, I* indices, Nd4jLong length, + T scoreThreshold) { + auto start = blockIdx.x * blockDim.x; + auto step = gridDim.x * blockDim.x; + + for (auto e = start + threadIdx.x; e < (int)length; e += step) { + if (scores[e] < scoreThreshold) { + scores[e] = scoreThreshold; + indices[e] = -1; + } else { + indices[e] = I(e); } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// nonMaxSuppressionV2 algorithm - given from TF NonMaxSuppressionV2 implementation +// nonMaxSuppressionV2 algorithm - given from TF NonMaxSuppressionV2 +// implementation // - template - static void nonMaxSuppressionV2_(sd::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {boxes, scales}); - std::unique_ptr indices(NDArrayFactory::create_('c', {scales->lengthOf()}, context)); // - 1, scales->lengthOf()); //, scales->getContext()); - - NDArray scores(*scales); - Nd4jPointer extras[2] = {nullptr, stream}; - auto indexBuf = indices->dataBuffer()->specialAsT();///reinterpret_cast(indices->specialBuffer()); - auto scoreBuf = scores.dataBuffer()->specialAsT(); - suppressScores<<<128, 128, 128, *stream>>>(scoreBuf, indexBuf, scores.lengthOf(), T(scoreThreshold)); - indices->tickWriteDevice(); - sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true); - indices->tickWriteDevice(); - NDArray selectedIndices = NDArrayFactory::create('c', {output->lengthOf()}, context); - int numSelected = 0; - int numBoxes = boxes->sizeAt(0); - auto boxesBuf = reinterpret_cast(boxes->specialBuffer()); - - auto selectedIndicesData = reinterpret_cast(selectedIndices.specialBuffer()); - auto outputBuf = reinterpret_cast(output->specialBuffer()); - - bool* shouldSelectD; - auto err = cudaMalloc(&shouldSelectD, sizeof(bool)); - if (err) { - throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot allocate memory for bool flag", err); - } - for (I i = 0; i < boxes->sizeAt(0); ++i) { - bool shouldSelect = numSelected < output->lengthOf(); - if (shouldSelect) { - err = cudaMemcpy(shouldSelectD, &shouldSelect, sizeof(bool), cudaMemcpyHostToDevice); - if (err) { - throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to device", err); - } - - shouldSelectKernel<<<128, 256, 1024, *stream>>>(boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, threshold, numSelected, i, shouldSelectD); - err = cudaMemcpy(&shouldSelect, shouldSelectD, sizeof(bool), cudaMemcpyDeviceToHost); - if (err) { - throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to host", err); - } - } - - if (shouldSelect) { - cudaMemcpy(reinterpret_cast(output->specialBuffer()) + numSelected, indexBuf + i, sizeof(I), cudaMemcpyDeviceToDevice); - cudaMemcpy(selectedIndicesData + numSelected, &i, sizeof(I), cudaMemcpyHostToDevice); - numSelected++; - } - } - - err = cudaFree(shouldSelectD); - if (err) { - throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot deallocate memory for bool flag", err); - } - +template +static void nonMaxSuppressionV2_(sd::LaunchContext* context, NDArray* boxes, + NDArray* scales, int maxSize, double threshold, + double scoreThreshold, NDArray* output) { + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {boxes, scales}); + std::unique_ptr indices(NDArrayFactory::create_( + 'c', {scales->lengthOf()}, + context)); // - 1, scales->lengthOf()); //, scales->getContext()); + + NDArray scores(*scales); + Nd4jPointer extras[2] = {nullptr, stream}; + auto indexBuf = + indices->dataBuffer() + ->specialAsT< + I>(); /// reinterpret_cast(indices->specialBuffer()); + auto scoreBuf = scores.dataBuffer()->specialAsT(); + suppressScores<<<128, 128, 128, *stream>>>( + scoreBuf, indexBuf, scores.lengthOf(), T(scoreThreshold)); + indices->tickWriteDevice(); + sortByValue(extras, indices->buffer(), indices->shapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), + scores.specialShapeInfo(), true); + indices->tickWriteDevice(); + NDArray selectedIndices = + NDArrayFactory::create('c', {output->lengthOf()}, context); + int numSelected = 0; + int numBoxes = boxes->sizeAt(0); + auto boxesBuf = reinterpret_cast(boxes->specialBuffer()); + + auto selectedIndicesData = + reinterpret_cast(selectedIndices.specialBuffer()); + auto outputBuf = reinterpret_cast(output->specialBuffer()); + + bool* shouldSelectD; + auto err = cudaMalloc(&shouldSelectD, sizeof(bool)); + if (err) { + throw cuda_exception::build( + "helpers::nonMaxSuppressionV2: Cannot allocate memory for bool flag", + err); + } + for (I i = 0; i < boxes->sizeAt(0); ++i) { + bool shouldSelect = numSelected < output->lengthOf(); + if (shouldSelect) { + err = cudaMemcpy(shouldSelectD, &shouldSelect, sizeof(bool), + cudaMemcpyHostToDevice); + if (err) { + throw cuda_exception::build( + "helpers::nonMaxSuppressionV2: Cannot set up bool flag to device", + err); + } + + shouldSelectKernel<<<128, 256, 1024, *stream>>>( + boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, + threshold, numSelected, i, shouldSelectD); + err = cudaMemcpy(&shouldSelect, shouldSelectD, sizeof(bool), + cudaMemcpyDeviceToHost); + if (err) { + throw cuda_exception::build( + "helpers::nonMaxSuppressionV2: Cannot set up bool flag to host", + err); + } } -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static __device__ bool checkOverlapBoxes(T* boxes, Nd4jLong const* shape, T* scores, I* indices, I* selectedIndices, I* startIndices, I selectedSize, I nextCandidateIndex, T overlapThreshold, T scoreThreshold, bool simple) { - bool shouldHardSuppress = false; - T& nextCandidateScore = scores[nextCandidateIndex]; - I selectedIndex = indices[nextCandidateIndex]; - I finish = startIndices[nextCandidateIndex]; - - for (int j = selectedSize; j > finish; --j) { - T boxVal; - if (simple) { - Nd4jLong xPos[] = {selectedIndex, selectedIndices[j - 1]}; - auto xShift = shape::getOffset(shape, xPos, 0); - boxVal = boxes[xShift]; - } - else { - boxVal = similirityV3(boxes, shape, selectedIndex, selectedIndices[j - 1]); - } - if (boxVal > static_cast(overlapThreshold)) - nextCandidateScore = static_cast(0.f); - - // First decide whether to perform hard suppression - if (boxVal >= overlapThreshold) { - shouldHardSuppress = true; - break; - } - - // If nextCandidate survives hard suppression, apply soft suppression - if (nextCandidateScore <= static_cast(scoreThreshold)) break; - } - - return shouldHardSuppress; + if (shouldSelect) { + cudaMemcpy(reinterpret_cast(output->specialBuffer()) + numSelected, + indexBuf + i, sizeof(I), cudaMemcpyDeviceToDevice); + cudaMemcpy(selectedIndicesData + numSelected, &i, sizeof(I), + cudaMemcpyHostToDevice); + numSelected++; } + } + + err = cudaFree(shouldSelectD); + if (err) { + throw cuda_exception::build( + "helpers::nonMaxSuppressionV2: Cannot deallocate memory for bool flag", + err); + } +} + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static __global__ void - suppressNonMaxOverlapKernel(T* boxes, Nd4jLong const* boxesShape, T* scoresData, I* indices, I* startIndices, Nd4jLong length, I maxOutputLen, - T overlapThreshold, T scoreThreshold, I* output, Nd4jLong const* outputShape, I* outputLength, bool simple) { - - __shared__ I selectedSize; - __shared__ I* tempOutput; - - if (threadIdx.x == 0) { - selectedSize = outputLength?*outputLength:maxOutputLen; - extern __shared__ unsigned char shmem[]; - tempOutput = (I*)shmem; - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for (I nextCandidateIndex = start + threadIdx.x; selectedSize < maxOutputLen && nextCandidateIndex < (I)length; ) { - auto originalScore = scoresData[nextCandidateIndex];//nextCandidate._score; - I nextCandidateBoxIndex = indices[nextCandidateIndex]; - auto selectedSizeMark = selectedSize; - - // skip for cases when index is less than 0 (under score threshold) - if (nextCandidateBoxIndex < 0) { - nextCandidateIndex += step; - continue; - } - // check for overlaps - bool shouldHardSuppress = checkOverlapBoxes(boxes, boxesShape, scoresData, indices, tempOutput, startIndices, selectedSize, - nextCandidateIndex, overlapThreshold, scoreThreshold, simple);//false; - T nextCandidateScore = scoresData[nextCandidateIndex]; - - startIndices[nextCandidateIndex] = selectedSize; - if (!shouldHardSuppress) { - if (nextCandidateScore == originalScore) { - // Suppression has not occurred, so select nextCandidate - if (output) - output[selectedSize] = nextCandidateBoxIndex; - tempOutput[selectedSize] = nextCandidateBoxIndex; - math::atomics::nd4j_atomicAdd(&selectedSize, (I)1); - } - - if (nextCandidateScore > scoreThreshold) { - // Soft suppression has occurred and current score is still greater than - // scoreThreshold; add nextCandidate back onto priority queue. - continue; // in some cases, this index not 0 - } - } - nextCandidateIndex += step; - } - - if (threadIdx.x == 0) { - if (outputLength) - *outputLength = selectedSize; - } +template +static __device__ bool checkOverlapBoxes(T* boxes, Nd4jLong const* shape, + T* scores, I* indices, + I* selectedIndices, I* startIndices, + I selectedSize, I nextCandidateIndex, + T overlapThreshold, T scoreThreshold, + bool simple) { + bool shouldHardSuppress = false; + T& nextCandidateScore = scores[nextCandidateIndex]; + I selectedIndex = indices[nextCandidateIndex]; + I finish = startIndices[nextCandidateIndex]; + + for (int j = selectedSize; j > finish; --j) { + T boxVal; + if (simple) { + Nd4jLong xPos[] = {selectedIndex, selectedIndices[j - 1]}; + auto xShift = shape::getOffset(shape, xPos, 0); + boxVal = boxes[xShift]; + } else { + boxVal = + similirityV3(boxes, shape, selectedIndex, selectedIndices[j - 1]); } + if (boxVal > static_cast(overlapThreshold)) + nextCandidateScore = static_cast(0.f); -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static Nd4jLong - nonMaxSuppressionGeneric_(sd::LaunchContext* context, NDArray* boxes, NDArray* scores, int outputSize, - double overlapThreshold, double scoreThreshold, NDArray* output, bool simple) { - auto stream = context->getCudaStream(); - if (output) - NDArray::prepareSpecialUse({output}, {boxes, scores}); - else { - if (!boxes->isActualOnDeviceSide()) - boxes->syncToDevice(); - if (!scores->isActualOnDeviceSide()) - scores->syncToDevice(); - } - - NDArray indices = NDArrayFactory::create('c', {scores->lengthOf()}, context); // - 1, scales->lengthOf()); //, scales->getContext()); - NDArray startPositions = NDArrayFactory::create('c', {scores->lengthOf()}, context); - NDArray selectedScores(*scores); - Nd4jPointer extras[2] = {nullptr, stream}; - auto indexBuf = indices.dataBuffer()->specialAsT();///reinterpret_cast(indices->specialBuffer()); - - suppressScores<<<128, 128, 128, *stream>>>(selectedScores.dataBuffer()->specialAsT(), indexBuf, selectedScores.lengthOf(), T(scoreThreshold)); - - sortByValue(extras, indices.buffer(), indices.shapeInfo(), indices.specialBuffer(), indices.specialShapeInfo(), selectedScores.buffer(), selectedScores.shapeInfo(), selectedScores.specialBuffer(), selectedScores.specialShapeInfo(), true); - indices.tickWriteDevice(); - selectedScores.tickWriteDevice(); - - auto scoresData = selectedScores.dataBuffer()->specialAsT();//, numBoxes, scoresData.begin()); - - auto startIndices = startPositions.dataBuffer()->specialAsT(); - I selectedSize = 0; - Nd4jLong res = 0; - if (output) { // this part used when output shape already calculated to fill up values on output - DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), DataTypeUtils::fromT()); - suppressNonMaxOverlapKernel<<<1, 1, 1024, *stream >>> (boxes->dataBuffer()->specialAsT(), - boxes->specialShapeInfo(), scoresData, indexBuf, startIndices, scores->lengthOf(), (I) outputSize, - T(overlapThreshold), T(scoreThreshold), output->dataBuffer()->specialAsT(), output->specialShapeInfo(), - selectedSizeBuf.specialAsT(), simple); - } - else { // this case used on calculation of output shape. Output and output shape shoulde be nullptr. - DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), DataTypeUtils::fromT()); - suppressNonMaxOverlapKernel<<<1, 1, 1024, *stream >>> (boxes->dataBuffer()->specialAsT(), - boxes->specialShapeInfo(), scoresData, indexBuf, startIndices, scores->lengthOf(), (I)outputSize, - T(overlapThreshold), T(scoreThreshold), (I*)nullptr, (Nd4jLong*) nullptr, selectedSizeBuf.specialAsT(), simple); - selectedSizeBuf.syncToPrimary(context, true); - res = *selectedSizeBuf.primaryAsT(); - } - - if (output) - NDArray::registerSpecialUse({output}, {boxes, scores}); - - return res; + // First decide whether to perform hard suppression + if (boxVal >= overlapThreshold) { + shouldHardSuppress = true; + break; } + + // If nextCandidate survives hard suppression, apply soft suppression + if (nextCandidateScore <= static_cast(scoreThreshold)) break; + } + + return shouldHardSuppress; +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void nonMaxSuppression(sd::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { - BUILD_DOUBLE_SELECTOR(boxes->dataType(), output->dataType(), nonMaxSuppressionV2_, - (context, boxes, scales, maxSize, threshold, scoreThreshold, output), - FLOAT_TYPES, INDEXING_TYPES); +template +static __global__ void suppressNonMaxOverlapKernel( + T* boxes, Nd4jLong const* boxesShape, T* scoresData, I* indices, + I* startIndices, Nd4jLong length, I maxOutputLen, T overlapThreshold, + T scoreThreshold, I* output, Nd4jLong const* outputShape, I* outputLength, + bool simple) { + __shared__ I selectedSize; + __shared__ I* tempOutput; + + if (threadIdx.x == 0) { + selectedSize = outputLength ? *outputLength : maxOutputLen; + extern __shared__ unsigned char shmem[]; + tempOutput = (I*)shmem; + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (I nextCandidateIndex = start + threadIdx.x; + selectedSize < maxOutputLen && nextCandidateIndex < (I)length;) { + auto originalScore = + scoresData[nextCandidateIndex]; // nextCandidate._score; + I nextCandidateBoxIndex = indices[nextCandidateIndex]; + auto selectedSizeMark = selectedSize; + + // skip for cases when index is less than 0 (under score threshold) + if (nextCandidateBoxIndex < 0) { + nextCandidateIndex += step; + continue; } -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - - Nd4jLong nonMaxSuppressionGeneric(sd::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { - BUILD_DOUBLE_SELECTOR(boxes->dataType(), output ? output->dataType():DataType::INT32, return nonMaxSuppressionGeneric_, - (context, boxes, scales, maxSize, threshold, scoreThreshold, output, true), - FLOAT_TYPES, INDEXING_TYPES); - return boxes->sizeAt(0); + // check for overlaps + bool shouldHardSuppress = + checkOverlapBoxes(boxes, boxesShape, scoresData, indices, tempOutput, + startIndices, selectedSize, nextCandidateIndex, + overlapThreshold, scoreThreshold, simple); // false; + T nextCandidateScore = scoresData[nextCandidateIndex]; + + startIndices[nextCandidateIndex] = selectedSize; + if (!shouldHardSuppress) { + if (nextCandidateScore == originalScore) { + // Suppression has not occurred, so select nextCandidate + if (output) output[selectedSize] = nextCandidateBoxIndex; + tempOutput[selectedSize] = nextCandidateBoxIndex; + math::atomics::nd4j_atomicAdd(&selectedSize, (I)1); + } + + if (nextCandidateScore > scoreThreshold) { + // Soft suppression has occurred and current score is still greater than + // scoreThreshold; add nextCandidate back onto priority queue. + continue; // in some cases, this index not 0 + } } + nextCandidateIndex += step; + } - Nd4jLong - nonMaxSuppressionV3(sd::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output) { - BUILD_DOUBLE_SELECTOR(boxes->dataType(), output ? output->dataType():DataType::INT32, return nonMaxSuppressionGeneric_, - (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output, false), - FLOAT_TYPES, INDEXING_TYPES); - return boxes->sizeAt(0); - } + if (threadIdx.x == 0) { + if (outputLength) *outputLength = selectedSize; + } +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static Nd4jLong nonMaxSuppressionGeneric_(sd::LaunchContext* context, + NDArray* boxes, NDArray* scores, + int outputSize, + double overlapThreshold, + double scoreThreshold, + NDArray* output, bool simple) { + auto stream = context->getCudaStream(); + if (output) + NDArray::prepareSpecialUse({output}, {boxes, scores}); + else { + if (!boxes->isActualOnDeviceSide()) boxes->syncToDevice(); + if (!scores->isActualOnDeviceSide()) scores->syncToDevice(); + } + + NDArray indices = NDArrayFactory::create( + 'c', {scores->lengthOf()}, + context); // - 1, scales->lengthOf()); //, scales->getContext()); + NDArray startPositions = + NDArrayFactory::create('c', {scores->lengthOf()}, context); + NDArray selectedScores(*scores); + Nd4jPointer extras[2] = {nullptr, stream}; + auto indexBuf = + indices.dataBuffer() + ->specialAsT< + I>(); /// reinterpret_cast(indices->specialBuffer()); + + suppressScores<<<128, 128, 128, *stream>>>( + selectedScores.dataBuffer()->specialAsT(), indexBuf, + selectedScores.lengthOf(), T(scoreThreshold)); + + sortByValue(extras, indices.buffer(), indices.shapeInfo(), + indices.specialBuffer(), indices.specialShapeInfo(), + selectedScores.buffer(), selectedScores.shapeInfo(), + selectedScores.specialBuffer(), selectedScores.specialShapeInfo(), + true); + indices.tickWriteDevice(); + selectedScores.tickWriteDevice(); + + auto scoresData = selectedScores.dataBuffer() + ->specialAsT(); //, numBoxes, scoresData.begin()); + + auto startIndices = startPositions.dataBuffer()->specialAsT(); + I selectedSize = 0; + Nd4jLong res = 0; + if (output) { // this part used when output shape already calculated to fill + // up values on output + DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), + DataTypeUtils::fromT()); + suppressNonMaxOverlapKernel<<<1, 1, 1024, *stream>>>( + boxes->dataBuffer()->specialAsT(), boxes->specialShapeInfo(), + scoresData, indexBuf, startIndices, scores->lengthOf(), (I)outputSize, + T(overlapThreshold), T(scoreThreshold), + output->dataBuffer()->specialAsT(), output->specialShapeInfo(), + selectedSizeBuf.specialAsT(), simple); + } else { // this case used on calculation of output shape. Output and output + // shape shoulde be nullptr. + DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), + DataTypeUtils::fromT()); + suppressNonMaxOverlapKernel<<<1, 1, 1024, *stream>>>( + boxes->dataBuffer()->specialAsT(), boxes->specialShapeInfo(), + scoresData, indexBuf, startIndices, scores->lengthOf(), (I)outputSize, + T(overlapThreshold), T(scoreThreshold), (I*)nullptr, (Nd4jLong*)nullptr, + selectedSizeBuf.specialAsT(), simple); + selectedSizeBuf.syncToPrimary(context, true); + res = *selectedSizeBuf.primaryAsT(); + } + + if (output) NDArray::registerSpecialUse({output}, {boxes, scores}); + + return res; +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void nonMaxSuppression(sd::LaunchContext* context, NDArray* boxes, + NDArray* scales, int maxSize, double threshold, + double scoreThreshold, NDArray* output) { + BUILD_DOUBLE_SELECTOR( + boxes->dataType(), output->dataType(), nonMaxSuppressionV2_, + (context, boxes, scales, maxSize, threshold, scoreThreshold, output), + FLOAT_TYPES, INDEXING_TYPES); } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +Nd4jLong nonMaxSuppressionGeneric(sd::LaunchContext* context, NDArray* boxes, + NDArray* scales, int maxSize, + double threshold, double scoreThreshold, + NDArray* output) { + BUILD_DOUBLE_SELECTOR(boxes->dataType(), + output ? output->dataType() : DataType::INT32, + return nonMaxSuppressionGeneric_, + (context, boxes, scales, maxSize, threshold, + scoreThreshold, output, true), + FLOAT_TYPES, INDEXING_TYPES); + return boxes->sizeAt(0); } + +Nd4jLong nonMaxSuppressionV3(sd::LaunchContext* context, NDArray* boxes, + NDArray* scores, int maxSize, + double overlapThreshold, double scoreThreshold, + NDArray* output) { + BUILD_DOUBLE_SELECTOR(boxes->dataType(), + output ? output->dataType() : DataType::INT32, + return nonMaxSuppressionGeneric_, + (context, boxes, scores, maxSize, overlapThreshold, + scoreThreshold, output, false), + FLOAT_TYPES, INDEXING_TYPES); + return boxes->sizeAt(0); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/imagesHelpers.cu b/libnd4j/include/ops/declarable/helpers/cuda/imagesHelpers.cu index c26b79ee6692..65fa6d01d2e6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/imagesHelpers.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/imagesHelpers.cu @@ -19,407 +19,498 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include -#include #include +#include +#include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// -template -__global__ void rgbToYuvCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ int rank; - __shared__ Nd4jLong xDimCstride, zDimCstride; - - if (threadIdx.x == 0) { - rank = shape::rank(xShapeInfo); - xDimCstride = shape::stride(xShapeInfo)[dimC]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { - const T* xTad = x + xTadOffsets[i]; - T* zTad = z + zTadOffsets[i]; - - rgbYuv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } - +template +__global__ void rgbToYuvCuda(const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; + + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; + + rgbYuv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], + zTad[zDimCstride], zTad[2 * zDimCstride]); + } } /////////////////////////////////////////////////////////////////// -template -linkage void rgbToYuvCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { - - rgbToYuvCuda << > > (vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); +template +linkage void rgbToYuvCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { + rgbToYuvCuda<<>>( + vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, + dimC); } /////////////////////////////////////////////////////////////////// -void transformRgbYuv(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), { dimC }); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), { dimC }); - - const Nd4jLong numOfTads = packX.numberOfTads(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "yuv_to_rgb"); - - NDArray::prepareSpecialUse({ &output }, { &input }); - BUILD_SINGLE_SELECTOR(input.dataType(), rgbToYuvCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), packX.platformOffsets(), output.specialBuffer(), output.specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); - NDArray::registerSpecialUse({ &output }, { &input }); - - manager.synchronize(); +void transformRgbYuv(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), {dimC}); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output.shapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "yuv_to_rgb"); + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), rgbToYuvCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), packX.platformOffsets(), + output.specialBuffer(), output.specialShapeInfo(), + packZ.platformOffsets(), numOfTads, dimC), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// -template -__global__ void yuvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ int rank; - __shared__ Nd4jLong xDimCstride, zDimCstride; - - if (threadIdx.x == 0) { - rank = shape::rank(xShapeInfo); - xDimCstride = shape::stride(xShapeInfo)[dimC]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { - const T* xTad = x + xTadOffsets[i]; - T* zTad = z + zTadOffsets[i]; - - yuvRgb(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } - +template +__global__ void yuvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; + + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; + + yuvRgb(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], + zTad[zDimCstride], zTad[2 * zDimCstride]); + } } /////////////////////////////////////////////////////////////////// -template -linkage void yuvToRgbCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { - - yuvToRgbCuda << > > (vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); +template +linkage void yuvToRgbCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { + yuvToRgbCuda<<>>( + vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, + dimC); } /////////////////////////////////////////////////////////////////// -void transformYuvRgb(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), { dimC }); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), { dimC }); - - const Nd4jLong numOfTads = packX.numberOfTads(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "yuv_to_rgb"); - - NDArray::prepareSpecialUse({ &output }, { &input }); - BUILD_SINGLE_SELECTOR(input.dataType(), yuvToRgbCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), packX.platformOffsets(), output.specialBuffer(), output.specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); - NDArray::registerSpecialUse({ &output }, { &input }); - - manager.synchronize(); +void transformYuvRgb(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), {dimC}); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output.shapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "yuv_to_rgb"); + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), yuvToRgbCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), packX.platformOffsets(), + output.specialBuffer(), output.specialShapeInfo(), + packZ.platformOffsets(), numOfTads, dimC), + FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// // for example xShapeInfo = {2,3,4}, zShapeInfo = {2,1,4} -template -__global__ void rgbToGrsCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int dimC) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong zLen; - __shared__ int rank, *sharedMem; // xRank == zRank - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - zLen = shape::length(zShapeInfo); - rank = shape::rank(zShapeInfo); - } - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * rank; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; i += gridDim.x * blockDim.x) { - - if (dimC == (rank - 1) && 'c' == shape::order(xShapeInfo) && 1 == shape::elementWiseStride(xShapeInfo) && 'c' == shape::order(zShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo)) { - const auto xStep = i*3; - z[i] = 0.2989f * x[xStep] + 0.5870f * x[xStep + 1] + 0.1140f * x[xStep + 2]; - } - else { - - shape::index2coords(i, zShapeInfo, coords); +template +__global__ void rgbToGrsCuda(const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const int dimC) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong zLen; + __shared__ int rank, *sharedMem; // xRank == zRank + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); + } + __syncthreads(); + + auto coords = sharedMem + threadIdx.x * rank; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; + i += gridDim.x * blockDim.x) { + if (dimC == (rank - 1) && 'c' == shape::order(xShapeInfo) && + 1 == shape::elementWiseStride(xShapeInfo) && + 'c' == shape::order(zShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo)) { + const auto xStep = i * 3; + z[i] = + 0.2989f * x[xStep] + 0.5870f * x[xStep + 1] + 0.1140f * x[xStep + 2]; + } else { + shape::index2coords(i, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); - const auto xOffset0 = shape::getOffset(xShapeInfo, coords); - const auto xOffset1 = xOffset0 + shape::stride(xShapeInfo)[dimC]; - const auto xOffset2 = xOffset1 + shape::stride(xShapeInfo)[dimC]; + const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto xOffset0 = shape::getOffset(xShapeInfo, coords); + const auto xOffset1 = xOffset0 + shape::stride(xShapeInfo)[dimC]; + const auto xOffset2 = xOffset1 + shape::stride(xShapeInfo)[dimC]; - z[zOffset] = 0.2989f * x[xOffset0] + 0.5870f * x[xOffset1] + 0.1140f * x[xOffset2]; - } - } + z[zOffset] = + 0.2989f * x[xOffset0] + 0.5870f * x[xOffset1] + 0.1140f * x[xOffset2]; + } + } } /////////////////////////////////////////////////////////////////// -template -linkage void rgbToGrsCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int dimC) { - - rgbToGrsCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, dimC); +template +linkage void rgbToGrsCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + const int sharedMem, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int dimC) { + rgbToGrsCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, dimC); } /////////////////////////////////////////////////////////////////// -void transformRgbGrs(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { - - PointersManager manager(context, "rgbToGrs"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = input.rankOf() * sizeof(int) * threadsPerBlock + 128; - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), rgbToGrsCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), dimC), NUMERIC_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); +void transformRgbGrs(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC) { + PointersManager manager(context, "rgbToGrs"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = input.rankOf() * sizeof(int) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), rgbToGrsCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), dimC), + NUMERIC_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } - /////////////////////////////////////////////////////////////////// template -static void _CUDA_G rgbToHsvCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const int dimC) { - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); +static void _CUDA_G rgbToHsvCuda(const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); - __shared__ int rank; - __shared__ Nd4jLong xDimCstride, zDimCstride; + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; - if (threadIdx.x == 0) { - rank = shape::rank(xShapeInfo); - xDimCstride = shape::stride(xShapeInfo)[dimC]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - } - __syncthreads(); + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { - const T* xTad = x + xTadOffsets[i]; - T* zTad = z + zTadOffsets[i]; + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; - rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } + rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], + zTad[zDimCstride], zTad[2 * zDimCstride]); + } } /////////////////////////////////////////////////////////////////// template -static void _CUDA_G hsvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, +static void _CUDA_G hsvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; - __shared__ int rank; - __shared__ Nd4jLong xDimCstride, zDimCstride; + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); - if (threadIdx.x == 0) { - rank = shape::rank(xShapeInfo); - xDimCstride = shape::stride(xShapeInfo)[dimC]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { - const T* xTad = x + xTadOffsets[i]; - T* zTad = z + zTadOffsets[i]; + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; - hsvToRgb(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } + hsvToRgb(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], + zTad[zDimCstride], zTad[2 * zDimCstride]); + } } /////////////////////////////////////////////////////////////////// -template -static _CUDA_H void hsvToRgbCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const int dimC) { - - hsvToRgbCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); +template +static _CUDA_H void hsvToRgbCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { + hsvToRgbCuda<<>>( + vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, + dimC); } -template -static _CUDA_H void rgbToHsvCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const int dimC) { - - rgbToHsvCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); +template +static _CUDA_H void rgbToHsvCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) { + rgbToHsvCuda<<>>( + vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, + dimC); } /////////////////////////////////////////////////////////////////// -void transformHsvRgb(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), {dimC}); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {dimC}); - - const Nd4jLong numOfTads = packX.numberOfTads(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "hsv_to_rgb"); - - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), hsvToRgbCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input}); - - manager.synchronize(); +void transformHsvRgb(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), {dimC}); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "hsv_to_rgb"); + + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR( + input->dataType(), hsvToRgbCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input->specialBuffer(), input->specialShapeInfo(), + packX.platformOffsets(), output->specialBuffer(), + output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), + FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// -void transformRgbHsv(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), {dimC}); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {dimC}); - - const Nd4jLong numOfTads = packX.numberOfTads(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "rgb_to_hsv"); - - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), rgbToHsvCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->specialBuffer(), input->specialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input}); - - manager.synchronize(); +void transformRgbHsv(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), {dimC}); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "rgb_to_hsv"); + + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR( + input->dataType(), rgbToHsvCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input->specialBuffer(), input->specialShapeInfo(), + packX.platformOffsets(), output->specialBuffer(), + output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), + FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input}); + + manager.synchronize(); } -template -__global__ void tripleTransformerCuda(const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, const int dimC, int mode, uint64_t numTads) { - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong zLen, *sharedMem; - __shared__ int rank; // xRank == zRank - - float yiqarr[3][3] = { - { 0.299f, 0.59590059f, 0.2115f }, - { 0.587f, -0.27455667f, -0.52273617f }, - { 0.114f, -0.32134392f, 0.31119955f } - }; - - float rgbarr[3][3] = { - { 1.f, 1.f, 1.f }, - { 0.95598634f, -0.27201283f, -1.10674021f }, - { 0.6208248f, -0.64720424f, 1.70423049f } - }; - - auto tr = mode == 1? yiqarr : rgbarr; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - zLen = shape::length(zShapeInfo); - rank = shape::rank(zShapeInfo); +template +__global__ void tripleTransformerCuda( + const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadShapeInfo, + const Nd4jLong* xOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadShapeInfo, const Nd4jLong* zOffsets, const int dimC, + int mode, uint64_t numTads) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong zLen, *sharedMem; + __shared__ int rank; // xRank == zRank + + float yiqarr[3][3] = {{0.299f, 0.59590059f, 0.2115f}, + {0.587f, -0.27455667f, -0.52273617f}, + {0.114f, -0.32134392f, 0.31119955f}}; + + float rgbarr[3][3] = {{1.f, 1.f, 1.f}, + {0.95598634f, -0.27201283f, -1.10674021f}, + {0.6208248f, -0.64720424f, 1.70423049f}}; + + auto tr = mode == 1 ? yiqarr : rgbarr; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); + } + __syncthreads(); + + Nd4jLong* coords = sharedMem + threadIdx.x * rank; + + if (dimC == (rank - 1) && 'c' == shape::order(xShapeInfo) && + 1 == shape::elementWiseStride(xShapeInfo) && + 'c' == shape::order(zShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo)) { + for (uint64_t f = blockIdx.x * blockDim.x + threadIdx.x; f < zLen / 3; + f += gridDim.x * blockDim.x) { + auto i = f * 3; + + auto xi0 = x[i]; + auto xi1 = x[i + 1]; + auto xi2 = x[i + 2]; + + for (int e = 0; e < 3; e++) + z[i + e] = xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e]; } - __syncthreads(); - - Nd4jLong* coords = sharedMem + threadIdx.x * rank; - - if (dimC == (rank - 1) && 'c' == shape::order(xShapeInfo) && 1 == shape::elementWiseStride(xShapeInfo) && 'c' == shape::order(zShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo)) { - for (uint64_t f = blockIdx.x * blockDim.x + threadIdx.x; f < zLen / 3; f += gridDim.x * blockDim.x) { - auto i = f * 3; - - auto xi0 = x[i]; - auto xi1 = x[i+1]; - auto xi2 = x[i+2]; - - for (int e = 0; e < 3; e++) - z[i + e] = xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e]; - } - } else { - // TAD based case - const Nd4jLong xDimCstride = shape::stride(xShapeInfo)[dimC]; - const Nd4jLong zDimCstride = shape::stride(zShapeInfo)[dimC]; - - for (uint64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numTads; i += blockDim.x * gridDim.x) { - const T* xTad = x + xOffsets[i]; - T* zTad = z + zOffsets[i]; - - auto xi0 = xTad[0]; - auto xi1 = xTad[xDimCstride]; - auto xi2 = xTad[xDimCstride * 2]; - - for (int e = 0; e < 3; e++) - zTad[zDimCstride * e] = xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e]; - } + } else { + // TAD based case + const Nd4jLong xDimCstride = shape::stride(xShapeInfo)[dimC]; + const Nd4jLong zDimCstride = shape::stride(zShapeInfo)[dimC]; + + for (uint64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numTads; + i += blockDim.x * gridDim.x) { + const T* xTad = x + xOffsets[i]; + T* zTad = z + zOffsets[i]; + + auto xi0 = xTad[0]; + auto xi1 = xTad[xDimCstride]; + auto xi2 = xTad[xDimCstride * 2]; + + for (int e = 0; e < 3; e++) + zTad[zDimCstride * e] = + xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e]; } + } } - template -static void rgbYiq(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimC); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimC); - - NDArray::prepareSpecialUse({output}, {input}); - return tripleTransformerCuda<<<256, 256, 8192, *context->getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 1, packZ.numberOfTads()); - NDArray::registerSpecialUse({output}, {input}); +static void rgbYiq(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimC); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimC); + + NDArray::prepareSpecialUse({output}, {input}); + return tripleTransformerCuda + <<<256, 256, 8192, *context->getCudaStream()>>>( + input->specialBuffer(), input->specialShapeInfo(), + packX.platformShapeInfo(), packX.platformOffsets(), + output->specialBuffer(), output->specialShapeInfo(), + packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 1, + packZ.numberOfTads()); + NDArray::registerSpecialUse({output}, {input}); } template -FORCEINLINE static void yiqRgb(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimC); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimC); - - NDArray::prepareSpecialUse({output}, {input}); - return tripleTransformerCuda<<<256, 256, 8192, *context->getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 2, packZ.numberOfTads()); - NDArray::registerSpecialUse({output}, {input}); -} - -void transformYiqRgb(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, (context, input, output, dimC), FLOAT_TYPES); +FORCEINLINE static void yiqRgb(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimC); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimC); + + NDArray::prepareSpecialUse({output}, {input}); + return tripleTransformerCuda + <<<256, 256, 8192, *context->getCudaStream()>>>( + input->specialBuffer(), input->specialShapeInfo(), + packX.platformShapeInfo(), packX.platformOffsets(), + output->specialBuffer(), output->specialShapeInfo(), + packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 2, + packZ.numberOfTads()); + NDArray::registerSpecialUse({output}, {input}); } -void transformRgbYiq(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, (context, input, output, dimC), FLOAT_TYPES); +void transformYiqRgb(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, + (context, input, output, dimC), FLOAT_TYPES); } - - - - -} -} +void transformRgbYiq(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, + (context, input, output, dimC), FLOAT_TYPES); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu b/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu index 723b0f215dc5..f47d82d6a65e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu @@ -19,74 +19,91 @@ // @author raver119@gmail.com // - -#include -#include -#include -#include #include -#include #include +#include +#include +#include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { template -static void ismax_(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector& dimensions) { - auto stream = context->getCudaStream(); - - auto xRank = input->rankOf(); - auto zRank = output->rankOf(); - auto xType = input->dataType(); - auto zType = output->dataType(); - input->syncToDevice(); - Nd4jLong* special = nullptr; - PointersManager manager(context, "IsMaxHelper"); - if (dimensions.size() == 0) { - /** - * In case of vector-input for IsMax, it just turns into IndexReduce call + subsequent filler call - */ - auto indexMax = input->applyIndexReduce(indexreduce::IndexMax, dimensions); - auto targetIdx = indexMax.e(0); - - dim3 launchDims(128, 512, 1024); - BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, output->specialBuffer(), output->specialShapeInfo(), output->lengthOf(), targetIdx), LIBND4J_TYPES); - manager.synchronize(); - - } else { - Nd4jLong* hostYShapeInfo = nullptr; - Nd4jLong* hostTShapeInfo = nullptr; - int* dimension = nullptr; - int dimensionLength = dimensions.size(); - std::vector copy(dimensions); - - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), copy.data(), copy.size()); - - // we launch legacy IndexMax op, to get indices of max values along dimension - auto indexMaxArr = input->applyIndexReduce(indexreduce::IndexMax, dimensions); - - dim3 launchDims(256, 256, 16384); - dimension = (int *) manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)); - - // at this point, all IMax indexes are gathered, and we execute filler - BUILD_SINGLE_SELECTOR(zType, fillDimensionalIsMaxGeneric, (launchDims, stream, indexMaxArr.specialBuffer(), output->specialBuffer(), output->specialShapeInfo(), packZ.specialShapeInfo(), dimension, dimensionLength, packZ.specialOffsets()), LIBND4J_TYPES); - manager.synchronize(); - } +static void ismax_(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const std::vector& dimensions) { + auto stream = context->getCudaStream(); + + auto xRank = input->rankOf(); + auto zRank = output->rankOf(); + auto xType = input->dataType(); + auto zType = output->dataType(); + input->syncToDevice(); + Nd4jLong* special = nullptr; + PointersManager manager(context, "IsMaxHelper"); + if (dimensions.size() == 0) { + /** + * In case of vector-input for IsMax, it just turns into IndexReduce call + + * subsequent filler call + */ + auto indexMax = input->applyIndexReduce(indexreduce::IndexMax, dimensions); + auto targetIdx = indexMax.e(0); + + dim3 launchDims(128, 512, 1024); + BUILD_SINGLE_SELECTOR( + zType, fillIsMaxGeneric, + (launchDims, stream, output->specialBuffer(), + output->specialShapeInfo(), output->lengthOf(), targetIdx), + LIBND4J_TYPES); + manager.synchronize(); + + } else { + Nd4jLong* hostYShapeInfo = nullptr; + Nd4jLong* hostTShapeInfo = nullptr; + int* dimension = nullptr; + int dimensionLength = dimensions.size(); + std::vector copy(dimensions); + + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), copy.data(), copy.size()); + + // we launch legacy IndexMax op, to get indices of max values along + // dimension + auto indexMaxArr = + input->applyIndexReduce(indexreduce::IndexMax, dimensions); + + dim3 launchDims(256, 256, 16384); + dimension = (int*)manager.replicatePointer(dimensions.data(), + dimensions.size() * sizeof(int)); + + // at this point, all IMax indexes are gathered, and we execute filler + BUILD_SINGLE_SELECTOR(zType, fillDimensionalIsMaxGeneric, + (launchDims, stream, indexMaxArr.specialBuffer(), + output->specialBuffer(), output->specialShapeInfo(), + packZ.specialShapeInfo(), dimension, dimensionLength, + packZ.specialOffsets()), + LIBND4J_TYPES); + manager.synchronize(); + } } +void ismax(sd::LaunchContext* context, const NDArray* input, NDArray* output, + const std::vector& dimensions) { + NDArray::prepareSpecialUse({output}, {input}); -void ismax(sd::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector& dimensions) { - NDArray::prepareSpecialUse({output}, {input}); - - BUILD_SINGLE_SELECTOR(input->dataType(), ismax_, (context, input, output, dimensions), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), ismax_, + (context, input, output, dimensions), LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {input}); + NDArray::registerSpecialUse({output}, {input}); } -BUILD_SINGLE_TEMPLATE(template void ismax_, (sd::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector& dimensions), LIBND4J_TYPES); - -} -} -} +BUILD_SINGLE_TEMPLATE(template void ismax_, + (sd::LaunchContext * context, const NDArray* input, + NDArray* output, const std::vector& dimensions), + LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu index ab65ed96b36d..2203f4d07c6b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu @@ -18,98 +18,108 @@ // @author GS // +#include #include #include -#include #include namespace sd { - namespace ops { - namespace helpers { - - template - linkage void reluDerivative__(NDArray* theFirst, NDArray* theSecond) { - auto functor = LAMBDA_TT(x, y){ - return x > (T) 0.f ? y : T(0.f); - }; - - theFirst->applyPairwiseLambda(*theSecond, functor, *theFirst); - } - - void reluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative__, (theFirst, theSecond), FLOAT_TYPES); - } - - template - linkage void reluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > (T)0.f ? y : T(0.f); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void reluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - linkage void relu6Derivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > (T)0.f && x < (T)6.f? y : T(0.f); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void relu6Derivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), relu6Derivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - linkage void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) { - - const T alphaT = static_cast(alpha); - - auto functor = LAMBDA_TT(x, y, alphaT) { - return x < 0 ? alphaT * y : y; - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void leakyReluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); - } - - template - linkage void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) { - - const T alphaT = static_cast(alpha); - - auto functor = LAMBDA_TT(x, y, alphaT){ - return y * sd::math::nd4j_eluderivative(x, alphaT); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void eluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); - } - - template - linkage void seluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * simdOps::SELUDerivative::op(x, nullptr); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void seluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), seluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - } - } -} \ No newline at end of file +namespace ops { +namespace helpers { + +template +linkage void reluDerivative__(NDArray* theFirst, NDArray* theSecond) { + auto functor = LAMBDA_TT(x, y) { return x > (T)0.f ? y : T(0.f); }; + + theFirst->applyPairwiseLambda(*theSecond, functor, *theFirst); +} + +void reluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative__, + (theFirst, theSecond), FLOAT_TYPES); +} + +template +linkage void reluDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { return x > (T)0.f ? y : T(0.f); }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void reluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +template +linkage void relu6Derivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return x > (T)0.f && x < (T)6.f ? y : T(0.f); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void relu6Derivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), relu6Derivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +template +linkage void leakyReluDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output, const float alpha) { + const T alphaT = static_cast(alpha); + + auto functor = LAMBDA_TT(x, y, alphaT) { return x < 0 ? alphaT * y : y; }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void leakyReluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput, + const float alpha) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, + (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); +} + +template +linkage void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, + const float alpha) { + const T alphaT = static_cast(alpha); + + auto functor = LAMBDA_TT(x, y, alphaT) { + return y * sd::math::nd4j_eluderivative(x, alphaT); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void eluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput, const float alpha) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, + (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); +} + +template +linkage void seluDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return y * simdOps::SELUDerivative::op(x, nullptr); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void seluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), seluDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu index 56a57614f4c2..b11e76be7ad2 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu @@ -18,69 +18,81 @@ // @author GS // -#include #include +#include #include #include namespace sd { - namespace ops { - namespace helpers { - //////////////////////////////////////////////////////////////////////// - template - linkage void tanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T th = sd::math::nd4j_tanh(x); - return y * ((T)1.0f - (th * th)); - }; +namespace ops { +namespace helpers { +//////////////////////////////////////////////////////////////////////// +template +linkage void tanhDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T th = sd::math::nd4j_tanh(x); + return y * ((T)1.0f - (th * th)); + }; - input->applyPairwiseLambda(*epsilon, functor, *output); - } + input->applyPairwiseLambda(*epsilon, functor, *output); +} - void tanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), tanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void tanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), tanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - // return static_cast(d2) * simdOps::HardTanhDerivative::op(d1, nullptr); - template - linkage void hardTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T th = sd::math::nd4j_tanh(x); - return y * simdOps::HardTanhDerivative::op(x, nullptr); - }; +// return static_cast(d2) * simdOps::HardTanhDerivative::op(d1, nullptr); +template +linkage void hardTanhDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T th = sd::math::nd4j_tanh(x); + return y * simdOps::HardTanhDerivative::op(x, nullptr); + }; - input->applyPairwiseLambda(*epsilon, functor, *output); - } + input->applyPairwiseLambda(*epsilon, functor, *output); +} - void hardTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void hardTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardTanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - template - linkage void rationalTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * simdOps::RationalTanhDerivative::op(x, nullptr); - }; +template +linkage void rationalTanhDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return y * simdOps::RationalTanhDerivative::op(x, nullptr); + }; - input->applyPairwiseLambda(*epsilon, functor, *output); - } + input->applyPairwiseLambda(*epsilon, functor, *output); +} - void rationalTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), rationalTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void rationalTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), rationalTanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} - template - linkage void rectifiedTanhDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > (T) 0.0f ? y * (sd::math::nd4j_tanhderivative(x)) : (T) 0.0f; - }; +template +linkage void rectifiedTanhDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return x > (T)0.0f ? y * (sd::math::nd4j_tanhderivative(x)) : (T)0.0f; + }; - input->applyPairwiseLambda(*epsilon, functor, *output); - } + input->applyPairwiseLambda(*epsilon, functor, *output); +} - void rectifiedTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), rectifiedTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - } - } -} \ No newline at end of file +void rectifiedTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), rectifiedTanhDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu index 181050919e06..b729ca137d5e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu @@ -18,232 +18,271 @@ // @author GS // -#include #include -#include +#include #include +#include namespace sd { namespace ops { namespace helpers { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - linkage void cubeDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * (3 * x * x); - }; +template +linkage void cubeDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { return y * (3 * x * x); }; - input->applyPairwiseLambda(*epsilon, functor, *output); - } + input->applyPairwiseLambda(*epsilon, functor, *output); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void cubeDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), cubeDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void cubeDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), cubeDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - //return (x >= X(0.f) ? y: -y); - template - linkage void reduceNorm1_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > T(0.f)? y : -y; - }; +// return (x >= X(0.f) ? y: -y); +template +linkage void reduceNorm1_(NDArray* input, NDArray* epsilon, NDArray* output) { + auto functor = LAMBDA_TT(x, y) { return x > T(0.f) ? y : -y; }; - input->applyPairwiseLambda(*epsilon, functor, *output); - } + input->applyPairwiseLambda(*epsilon, functor, *output); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void reduceNorm1(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), reduceNorm1_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void reduceNorm1(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), reduceNorm1_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////// - template - linkage void sigmCrossEntropy_(NDArray* logits, NDArray* labels, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return sd::math::nd4j_max(x, (T)0.f) - x * y + sd::math::nd4j_log((T)1.f + sd::math::nd4j_exp(-sd::math::nd4j_abs(x))); - }; - - logits->applyPairwiseLambda(*labels, functor, *output); - } +//////////////////////////////////////////////////////////////////////// +template +linkage void sigmCrossEntropy_(NDArray* logits, NDArray* labels, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return sd::math::nd4j_max(x, (T)0.f) - x * y + + sd::math::nd4j_log( + (T)1.f + sd::math::nd4j_exp(-sd::math::nd4j_abs(x))); + }; + + logits->applyPairwiseLambda(*labels, functor, *output); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void sigmCrossEntropy(sd::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { - BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropy_, (logits, labels, output), FLOAT_TYPES); - } +void sigmCrossEntropy(sd::LaunchContext* context, NDArray* logits, + NDArray* labels, NDArray* output) { + BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropy_, + (logits, labels, output), FLOAT_TYPES); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////// - template - linkage void sigmCrossEntropyGrad_(NDArray* logits, NDArray* labels, NDArray* output) { - // 1 - labels - 1 / (1 + exp(logits)) - auto functor = LAMBDA_TT(x, y) { - if(x <= 0) - return static_cast(1.) - y - static_cast(1.) / (static_cast(1.) + sd::math::nd4j_exp(x)); - auto e = sd::math::nd4j_exp(-x); - return static_cast(1.) - y - e / (static_cast(1.) + e); - }; - - logits->applyPairwiseLambda(*labels, functor, *output); - } +//////////////////////////////////////////////////////////////////////// +template +linkage void sigmCrossEntropyGrad_(NDArray* logits, NDArray* labels, + NDArray* output) { + // 1 - labels - 1 / (1 + exp(logits)) + auto functor = LAMBDA_TT(x, y) { + if (x <= 0) + return static_cast(1.) - y - + static_cast(1.) / + (static_cast(1.) + sd::math::nd4j_exp(x)); + auto e = sd::math::nd4j_exp(-x); + return static_cast(1.) - y - e / (static_cast(1.) + e); + }; + + logits->applyPairwiseLambda(*labels, functor, *output); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void sigmCrossEntropyGrad(sd::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { - BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropyGrad_, (logits, labels, output), FLOAT_TYPES); - } +void sigmCrossEntropyGrad(sd::LaunchContext* context, NDArray* logits, + NDArray* labels, NDArray* output) { + BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropyGrad_, + (logits, labels, output), FLOAT_TYPES); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // X f = (X) 1.0f + sd::math::nd4j_abs(d1); - // return (X) d2 * ((X) 1.0f / (f * f)); - // - template - linkage void softSignDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T ss = (T)1.f + sd::math::nd4j_abs(x); - return y * ((T) 1.0f / (ss * ss)); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } +// X f = (X) 1.0f + sd::math::nd4j_abs(d1); +// return (X) d2 * ((X) 1.0f / (f * f)); +// +template +linkage void softSignDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T ss = (T)1.f + sd::math::nd4j_abs(x); + return y * ((T)1.0f / (ss * ss)); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void softSignDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), softSignDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void softSignDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), softSignDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - linkage void softPlusDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T p = sd::math::nd4j_pow(static_cast(M_E), x); - return y * (p / (p + 1.)); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } +template +linkage void softPlusDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T p = sd::math::nd4j_pow(static_cast(M_E), x); + return y * (p / (p + 1.)); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} - void softPlusDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), softPlusDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } +void softPlusDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), softPlusDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// /// /// \param input /// \param epsilon /// \param output - template - linkage void sigmoidDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - T s = sd::math::nd4j_sigmoid(x); - return y * (s * ((T) 1.0f - s)); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void sigmoidDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), sigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - - template - linkage void hardSigmoidDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * simdOps::HardSigmoidDerivative::op(x, nullptr); - }; - - input->applyPairwiseLambda(*epsilon, functor, *output); - } - - void hardSigmoidDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardSigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); - } - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - linkage void logSumExp_(NDArray* input, NDArray* axis, NDArray* output) { - // reduce along axis with - NDArray tempInput = input->dup(); - input->applyTransform(transform::Exp, tempInput); - std::vector axisVector; - if (axis != nullptr) { - axisVector.resize(axis->lengthOf()); - for (size_t i = 0; i < axisVector.size(); ++i) - axisVector[i] = axis->e(i); - } - tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); - output->applyTransform(transform::Log, *output); - } - - template - linkage void logSumExp_(NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { - // reduce along axis with - NDArray tempInput = input->dup(); - input->applyPairwiseTransform(pairwise::Subtract, *subtrah, tempInput); - tempInput.applyTransform(transform::Exp, tempInput); - - std::vector axisVector; - if (axis != nullptr) { - axisVector.resize(axis->lengthOf()); - for (size_t i = 0; i < axisVector.size(); ++i) - axisVector[i] = axis->e(i); - } - tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); - output->applyTransform(transform::Log, *output); - } - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void logSumExp(sd::LaunchContext * context, NDArray* input, NDArray* axis, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, axis, output), FLOAT_TYPES); - } - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void logSumExp(sd::LaunchContext * context, NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, subtrah, axis, output), FLOAT_TYPES); - } - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - void weightedCrossEntropyWithLogitsFunctor_(NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output) { - - T posWeight = weights->e(0); - - auto mainRoutineT1 = LAMBDA_TT(_x, _z, posWeight) { - T targetWeight = (1. + (posWeight - (T)1.f) * _z); - return (1. - _z) * _x + - targetWeight * (sd::math::nd4j_log((T)1.f + sd::math::nd4j_exp(-sd::math::nd4j_abs(_x))) + - sd::math::nd4j_max(-_x, T(0.f)) - ); - }; - - auto mainRoutineT2 = LAMBDA_TTT(_x, _z, _w) { - return (((T)1.0 - _z) * _x) + - _w * (sd::math::nd4j_log(T(1.) + sd::math::nd4j_exp(-sd::math::nd4j_abs(_x))) + - sd::math::nd4j_max(-_x, T(0.f))); - }; - - - if (weights->isScalar()) { - const_cast(input)->applyPairwiseLambda(const_cast(*targets), mainRoutineT1, *output); - } - else - { - std::unique_ptr targetVector(new NDArray(*weights)); - targetVector->applyScalar(scalar::Add, -1.f, *targetVector); - - std::unique_ptr targetTensor(new NDArray(*targets)); - *targetTensor = (*targetVector * *targetTensor) + T(1.f); - const_cast(input)->applyTriplewiseLambda(const_cast(*targets), *targetTensor.get(), mainRoutineT2, *output); - } - } -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output) { - NDArray::prepareSpecialUse({output}, {targets, input, weights}); - - BUILD_SINGLE_SELECTOR(targets->dataType(), weightedCrossEntropyWithLogitsFunctor_, (targets, input, weights, output), FLOAT_TYPES); - - NDArray::registerSpecialUse({output}, {targets, input, weights}); - } +template +linkage void sigmoidDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + T s = sd::math::nd4j_sigmoid(x); + return y * (s * ((T)1.0f - s)); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} +void sigmoidDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), sigmoidDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); } -} -} \ No newline at end of file + +template +linkage void hardSigmoidDerivative_(NDArray* input, NDArray* epsilon, + NDArray* output) { + auto functor = LAMBDA_TT(x, y) { + return y * simdOps::HardSigmoidDerivative::op(x, nullptr); + }; + + input->applyPairwiseLambda(*epsilon, functor, *output); +} + +void hardSigmoidDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardSigmoidDerivative_, + (theFirst, theSecond, theOutput), FLOAT_TYPES); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +linkage void logSumExp_(NDArray* input, NDArray* axis, NDArray* output) { + // reduce along axis with + NDArray tempInput = input->dup(); + input->applyTransform(transform::Exp, tempInput); + std::vector axisVector; + if (axis != nullptr) { + axisVector.resize(axis->lengthOf()); + for (size_t i = 0; i < axisVector.size(); ++i) + axisVector[i] = axis->e(i); + } + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); +} + +template +linkage void logSumExp_(NDArray* input, NDArray* subtrah, NDArray* axis, + NDArray* output) { + // reduce along axis with + NDArray tempInput = input->dup(); + input->applyPairwiseTransform(pairwise::Subtract, *subtrah, tempInput); + tempInput.applyTransform(transform::Exp, tempInput); + + std::vector axisVector; + if (axis != nullptr) { + axisVector.resize(axis->lengthOf()); + for (size_t i = 0; i < axisVector.size(); ++i) + axisVector[i] = axis->e(i); + } + tempInput.reduceAlongDimension(reduce::Sum, *output, axisVector); + output->applyTransform(transform::Log, *output); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void logSumExp(sd::LaunchContext* context, NDArray* input, NDArray* axis, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, axis, output), + FLOAT_TYPES); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void logSumExp(sd::LaunchContext* context, NDArray* input, NDArray* subtrah, + NDArray* axis, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, + (input, subtrah, axis, output), FLOAT_TYPES); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +void weightedCrossEntropyWithLogitsFunctor_(NDArray const* targets, + NDArray const* input, + NDArray const* weights, + NDArray* output) { + T posWeight = weights->e(0); + + auto mainRoutineT1 = LAMBDA_TT(_x, _z, posWeight) { + T targetWeight = (1. + (posWeight - (T)1.f) * _z); + return (1. - _z) * _x + + targetWeight * (sd::math::nd4j_log( + (T)1.f + sd::math::nd4j_exp( + -sd::math::nd4j_abs(_x))) + + sd::math::nd4j_max(-_x, T(0.f))); + }; + + auto mainRoutineT2 = LAMBDA_TTT(_x, _z, _w) { + return (((T)1.0 - _z) * _x) + + _w * + (sd::math::nd4j_log( + T(1.) + sd::math::nd4j_exp(-sd::math::nd4j_abs(_x))) + + sd::math::nd4j_max(-_x, T(0.f))); + }; + + if (weights->isScalar()) { + const_cast(input)->applyPairwiseLambda( + const_cast(*targets), mainRoutineT1, *output); + } else { + std::unique_ptr targetVector(new NDArray(*weights)); + targetVector->applyScalar(scalar::Add, -1.f, *targetVector); + + std::unique_ptr targetTensor(new NDArray(*targets)); + *targetTensor = (*targetVector * *targetTensor) + T(1.f); + const_cast(input)->applyTriplewiseLambda( + const_cast(*targets), *targetTensor.get(), mainRoutineT2, + *output); + } +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext* context, + NDArray const* targets, + NDArray const* input, + NDArray const* weights, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {targets, input, weights}); + + BUILD_SINGLE_SELECTOR(targets->dataType(), + weightedCrossEntropyWithLogitsFunctor_, + (targets, input, weights, output), FLOAT_TYPES); + + NDArray::registerSpecialUse({output}, {targets, input, weights}); +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lgamma.cu b/libnd4j/include/ops/declarable/helpers/cuda/lgamma.cu index 2de455d0f8cf..2828380ab839 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lgamma.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lgamma.cu @@ -19,7 +19,7 @@ // @author George A. Shulinok // -#include +#include //#include //#include @@ -31,24 +31,24 @@ namespace helpers { // calculate digamma function for array elements template void lgamma_(NDArray& x, NDArray& z) { - //auto dtype = x.dataType(); - auto lgammaProc = LAMBDA_T(x_, dtype) { - return T(DataTypeUtils::fromT() == DataType::DOUBLE?::lgamma(x_): ::lgammaf(x_)); //math::nd4j_log(math::nd4j_gamma(x)); - }; - - x.applyLambda(lgammaProc, z); + // auto dtype = x.dataType(); + auto lgammaProc = LAMBDA_T(x_, dtype) { + return T( + DataTypeUtils::fromT() == DataType::DOUBLE + ? ::lgamma(x_) + : ::lgammaf(x_)); // math::nd4j_log(math::nd4j_gamma(x)); + }; + + x.applyLambda(lgammaProc, z); } void lgamma(sd::LaunchContext* context, NDArray& x, NDArray& z) { - - BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray& x, NDArray& z), FLOAT_TYPES); - - - -} -} -} +BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray & x, NDArray& z), + FLOAT_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu b/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu index ebc0732e2411..337da2a4b626 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu @@ -18,154 +18,192 @@ // @author raver119@gmail.com // -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - template - static _CUDA_G void lrnKernel(void *vx, Nd4jLong const*xTadShapeInfo, Nd4jLong const*xTadOffsets, void *vz, Nd4jLong const*zTadShapeInfo, Nd4jLong const*zTadOffsets, Nd4jLong numTads, Nd4jLong tadLength, int depth, double bias, double alpha, double beta) { - extern __shared__ char sharedChar[]; - T* shared = reinterpret_cast(sharedChar); - - auto xEws = shape::elementWiseStride(xTadShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - auto xOrder = shape::order(xTadShapeInfo); - auto zOrder = shape::order(zTadShapeInfo); - - const T tbias = static_cast(bias); - const T tbeta = static_cast(beta); - const T talpha = static_cast(alpha); - - // one block of threads processes 1 example within batch - for (uint i = blockIdx.x; i < numTads; i += gridDim.x) { - auto x = reinterpret_cast(vx) + xTadOffsets[i]; - auto z = reinterpret_cast(vz) + zTadOffsets[i]; - - // load everything into shared memory, so we'll operate on shared memory from now on - shared[threadIdx.x] = x[threadIdx.x * xEws]; - __syncthreads(); - - const uint begin = sd::math::nd4j_max(0, threadIdx.x - depth); - const uint last = depth + threadIdx.x + 1; - const uint end = sd::math::nd4j_min(last, tadLength); - - T prev = 0.; - for (int s = begin; s < end; s++) - prev = prev + shared[s] * shared[s]; - - z[threadIdx.x * zEws] = shared[threadIdx.x] / sd::math::nd4j_pow(tbias + alpha * prev, tbeta); - } - } - - template - static _CUDA_G void lrnBPKernel(void const* vx, Nd4jLong const* xTadShapeInfo, Nd4jLong const* xTadOffsets, void *vz, Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets, Nd4jLong numTads, Nd4jLong tadLength, int depth, double bias, double alpha, double beta) { - extern __shared__ char sharedChar[]; - X* sharedX = reinterpret_cast(sharedChar); - Z* sharedY = reinterpret_cast(sharedX + blockDim.x); - - auto xEws = shape::elementWiseStride(xTadShapeInfo); - auto zEws = shape::elementWiseStride(zTadShapeInfo); - - auto xOrder = shape::order(xTadShapeInfo); - auto zOrder = shape::order(zTadShapeInfo); - - const Z tbias = static_cast(bias); - const Z tbeta = static_cast(beta); - const Z talpha = static_cast(alpha); - const Z coeff = talpha * tbeta; - - - - for (uint i = blockIdx.x; i < numTads; i += gridDim.x) { - auto x = reinterpret_cast(vx) + xTadOffsets[i]; - auto z = reinterpret_cast(vz) + zTadOffsets[i]; - - const uint begin = sd::math::nd4j_max(0, threadIdx.x - depth); - const uint last = depth + threadIdx.x + 1; - const uint end = sd::math::nd4j_min(last, tadLength); - - // load everything into shared memory - sharedX[threadIdx.x] = x[threadIdx.x * xEws]; - sharedY[threadIdx.x] = 0.f; - __syncthreads(); - - // we're operating in shared memory - for (int s = begin; s < end; s++) - sharedY[threadIdx.x] = sharedY[threadIdx.x] + sharedX[s] * sharedX[s]; - __syncthreads(); - - Z factor[1024]; - Z init = tbias + talpha * sharedY[threadIdx.x]; - - Z prev = 0.f; - for (uint s = begin; s < end; ++s) { - factor[s] = sd::math::nd4j_pow(tbias + talpha * sharedY[s], -tbeta - 1); - prev = prev + sharedX[s] * factor[s]; - } - - z[threadIdx.x * zEws] = factor[threadIdx.x] * init - 2 * sharedX[threadIdx.x] * coeff * prev; - } - } - - - template - static void lrnBP_(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta) { - auto rank = input.rankOf(); - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), {rank - 1}); - auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(gradI.shapeInfo(), {rank - 1}); - - const auto tadLength = shape::length(packX.primaryShapeInfo()); - const int numBlocks = sd::math::nd4j_min(1024, packX.numberOfTads()); - const int numThreads = tadLength; - - if (tadLength > 1024 || tadLength < 1) - throw std::runtime_error("LRN: tadLength > 1024 isn't implemented yet"); - - lrnBPKernel<<getCudaStream()>>>(input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), gradI.specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), packX.numberOfTads(), tadLength, depth, bias, alpha, beta); +template +static _CUDA_G void lrnKernel(void* vx, Nd4jLong const* xTadShapeInfo, + Nd4jLong const* xTadOffsets, void* vz, + Nd4jLong const* zTadShapeInfo, + Nd4jLong const* zTadOffsets, Nd4jLong numTads, + Nd4jLong tadLength, int depth, double bias, + double alpha, double beta) { + extern __shared__ char sharedChar[]; + T* shared = reinterpret_cast(sharedChar); + + auto xEws = shape::elementWiseStride(xTadShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + auto xOrder = shape::order(xTadShapeInfo); + auto zOrder = shape::order(zTadShapeInfo); + + const T tbias = static_cast(bias); + const T tbeta = static_cast(beta); + const T talpha = static_cast(alpha); + + // one block of threads processes 1 example within batch + for (uint i = blockIdx.x; i < numTads; i += gridDim.x) { + auto x = reinterpret_cast(vx) + xTadOffsets[i]; + auto z = reinterpret_cast(vz) + zTadOffsets[i]; + + // load everything into shared memory, so we'll operate on shared memory + // from now on + shared[threadIdx.x] = x[threadIdx.x * xEws]; + __syncthreads(); + + const uint begin = sd::math::nd4j_max(0, threadIdx.x - depth); + const uint last = depth + threadIdx.x + 1; + const uint end = sd::math::nd4j_min(last, tadLength); + + T prev = 0.; + for (int s = begin; s < end; s++) prev = prev + shared[s] * shared[s]; + + z[threadIdx.x * zEws] = + shared[threadIdx.x] / + sd::math::nd4j_pow(tbias + alpha * prev, tbeta); + } +} - gradI.tickWriteDevice(); - gradI *= gradO; +template +static _CUDA_G void lrnBPKernel(void const* vx, Nd4jLong const* xTadShapeInfo, + Nd4jLong const* xTadOffsets, void* vz, + Nd4jLong const* zTadShapeInfo, + Nd4jLong const* zTadOffsets, Nd4jLong numTads, + Nd4jLong tadLength, int depth, double bias, + double alpha, double beta) { + extern __shared__ char sharedChar[]; + X* sharedX = reinterpret_cast(sharedChar); + Z* sharedY = reinterpret_cast(sharedX + blockDim.x); + + auto xEws = shape::elementWiseStride(xTadShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + auto xOrder = shape::order(xTadShapeInfo); + auto zOrder = shape::order(zTadShapeInfo); + + const Z tbias = static_cast(bias); + const Z tbeta = static_cast(beta); + const Z talpha = static_cast(alpha); + const Z coeff = talpha * tbeta; + + for (uint i = blockIdx.x; i < numTads; i += gridDim.x) { + auto x = reinterpret_cast(vx) + xTadOffsets[i]; + auto z = reinterpret_cast(vz) + zTadOffsets[i]; + + const uint begin = sd::math::nd4j_max(0, threadIdx.x - depth); + const uint last = depth + threadIdx.x + 1; + const uint end = sd::math::nd4j_min(last, tadLength); + + // load everything into shared memory + sharedX[threadIdx.x] = x[threadIdx.x * xEws]; + sharedY[threadIdx.x] = 0.f; + __syncthreads(); + + // we're operating in shared memory + for (int s = begin; s < end; s++) + sharedY[threadIdx.x] = sharedY[threadIdx.x] + sharedX[s] * sharedX[s]; + __syncthreads(); + + Z factor[1024]; + Z init = tbias + talpha * sharedY[threadIdx.x]; + + Z prev = 0.f; + for (uint s = begin; s < end; ++s) { + factor[s] = + sd::math::nd4j_pow(tbias + talpha * sharedY[s], -tbeta - 1); + prev = prev + sharedX[s] * factor[s]; } - void lrnBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta) { - input.syncToDevice(); - gradO.syncToDevice(); - - BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, (block, input, gradO, gradI, depth, bias, alpha, beta), FLOAT_TYPES, FLOAT_TYPES); + z[threadIdx.x * zEws] = + factor[threadIdx.x] * init - 2 * sharedX[threadIdx.x] * coeff * prev; + } +} - gradI.tickWriteDevice(); - } +template +static void lrnBP_(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int depth, + const float bias, const float alpha, const float beta) { + auto rank = input.rankOf(); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), {rank - 1}); + auto packZ = ConstantTadHelper::getInstance()->tadForDimensions( + gradI.shapeInfo(), {rank - 1}); + + const auto tadLength = shape::length(packX.primaryShapeInfo()); + const int numBlocks = + sd::math::nd4j_min(1024, packX.numberOfTads()); + const int numThreads = tadLength; + + if (tadLength > 1024 || tadLength < 1) + throw std::runtime_error("LRN: tadLength > 1024 isn't implemented yet"); + + lrnBPKernel<<getCudaStream()>>>( + input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), + gradI.specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), + packX.numberOfTads(), tadLength, depth, bias, alpha, beta); + + gradI.tickWriteDevice(); + gradI *= gradO; +} - template - static void lrnFunctor_(sd::graph::Context& block, NDArray* input, NDArray* output, int depth, double bias, double alpha, double beta) { - auto rank = input->rankOf(); - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), {rank - 1}); - auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {rank - 1}); +void lrnBP(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int depth, + const float bias, const float alpha, const float beta) { + input.syncToDevice(); + gradO.syncToDevice(); - const auto tadLength = shape::length(packX.primaryShapeInfo()); - const int numBlocks = sd::math::nd4j_min(1024, packX.numberOfTads()); - const int numThreads = tadLength; + BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, + (block, input, gradO, gradI, depth, bias, alpha, beta), + FLOAT_TYPES, FLOAT_TYPES); - if (tadLength > 1024 || tadLength < 1) - throw std::runtime_error("LRN: tadLength > 1024 isn't implemented yet"); + gradI.tickWriteDevice(); +} - lrnKernel<<getCudaStream()>>>(input->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), packX.numberOfTads(), tadLength, depth, bias, alpha, beta); - } +template +static void lrnFunctor_(sd::graph::Context& block, NDArray* input, + NDArray* output, int depth, double bias, double alpha, + double beta) { + auto rank = input->rankOf(); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), {rank - 1}); + auto packZ = ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), {rank - 1}); + + const auto tadLength = shape::length(packX.primaryShapeInfo()); + const int numBlocks = + sd::math::nd4j_min(1024, packX.numberOfTads()); + const int numThreads = tadLength; + + if (tadLength > 1024 || tadLength < 1) + throw std::runtime_error("LRN: tadLength > 1024 isn't implemented yet"); + + lrnKernel<<getCudaStream()>>>( + input->specialBuffer(), packX.platformShapeInfo(), + packX.platformOffsets(), output->specialBuffer(), + packZ.platformShapeInfo(), packZ.platformOffsets(), packX.numberOfTads(), + tadLength, depth, bias, alpha, beta); +} - int lrnFunctor(sd::graph::Context& block, NDArray* input, NDArray* output, int depth, double bias, double alpha, double beta) { - input->syncToDevice(); +int lrnFunctor(sd::graph::Context& block, NDArray* input, NDArray* output, + int depth, double bias, double alpha, double beta) { + input->syncToDevice(); - BUILD_SINGLE_SELECTOR(input->dataType(), lrnFunctor_, (block, input, output, depth, bias, alpha, beta), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), lrnFunctor_, + (block, input, output, depth, bias, alpha, beta), + FLOAT_TYPES); - output->tickWriteDevice(); + output->tickWriteDevice(); - return Status::OK(); - } -} -} + return Status::OK(); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu b/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu index af0c413d6701..619dc76a0b79 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu @@ -20,179 +20,211 @@ // implementation of operation for LSTM cell with peep hole connections: // http://www.bioinf.jku.at/publications/older/2604.pdf -// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. -// and +// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural +// Computation, 9(8):1735-1780, 1997. and // https://research.google.com/pubs/archive/43905.pdf -// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. +// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory +// recurrent neural network architectures for large scale acoustic modeling." +// INTERSPEECH, 2014. - -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include + #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - - ////////////////////////////////////////////////////////////////////////// -void lstmCell(sd::LaunchContext * context, const NDArray* xt, const NDArray* ht_1, const NDArray* ct_1, const NDArray* Wx, const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, const NDArray* b, - NDArray* ht, NDArray* ct, const std::vector& params) { - - // xt input [bS x nIn] - // ht_1 previous cell output [bS x numProj], that is at previous time step t-1, in case of projection=false -> numProj=nOut!!! - // ct_1 previous cell state [bS x nOut], that is at previous time step t-1 - - // Wx input-to-hidden weights, [nIn x 4*nOut] - // Wh hidden-to-hidden weights, [numProj x 4*nOut] - // Wc diagonal weights for peephole connections [3*nOut] - // Wp projection weights [nOut x numProj] - // b biases, [4*nOut] - - // ht current cell output [bS x numProj], that is at current time step t - // ct current cell state [bS x nOut], that is at current time step t - - const bool peephole = (bool)params[0]; // if true, provide peephole connections - const bool projection = (bool)params[1]; // if true, then projection is performed, if false then numProj==nOut is mandatory!!!! - double clippingCellValue = params[2]; // clipping value for ct, if it is not equal to zero, then cell state is clipped - double clippingProjValue = params[3]; // clipping value for projected ht, if it is not equal to zero, then projected cell output is clipped - const double forgetBias = params[4]; - - const int bS = xt->sizeAt(0); - const int nIn = xt->sizeAt(1); - const int numProj = ht_1->sizeAt(1); - const int nOut = ct_1->sizeAt(1); - - auto z = mmul(*xt, *Wx) + mmul(*ht_1, *Wh) + *b; // [bS x 4*nOut] + [bS x 4*nOut] + [1 x 4*nOut] = [bS x 4*nOut] - - auto zit = z({0,0, 0,nOut}); // z for input gate, = mmul(Wxi,xt) + mmul(Whi,ht_1) + bi = [bS x nOut] - auto zft = z({0,0, nOut,2*nOut}); // z for forget gate, = mmul(Wxf,xt) + mmul(Whf,ht_1) + bf = [bS x nOut] - auto zct = z({0,0, 2*nOut,3*nOut}); // z for cell state, = mmul(Wxc,xt) + mmul(Whc,ht_1) + bc = [bS x nOut] - auto zot = z({0,0, 3*nOut,4*nOut}); // z for output gate, = mmul(Wxo,xt) + mmul(Who,ht_1) + bo = [bS x nOut] - - if(peephole) { // add peephole connections: z + ct_1*Wc - zit += (*ct_1) * (*Wc)({0, nOut}); // add peephole connections to input gate - zft += (*ct_1) * (*Wc)({nOut, 2*nOut}); // add peephole connections to forget gate - } - - // current sell state = ft*ct_1 + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc - ct->assign( sigmoid(zft + forgetBias) * (*ct_1) + sigmoid(zit) * tanh(zct) ); - - // if clipping value is provided then cell state is clipped by this value prior to the cell output activation - if(clippingCellValue > 0.0) - ct->applyScalar(scalar::LstmClip, clippingCellValue, *ct); - - if(peephole) - zot += (*ct) * (*Wc)({{2*nOut, 3*nOut}}); // add peephole connections to output gate zot + ct*Wc - - // current cell output = ot*tanh(ct) - auto htNoPeepHole = sigmoid(zot) * tanh(*ct); // = [bS x nOut] - - // apply projection - if(projection) { - ht->assign( mmul(htNoPeepHole, *Wp) ); // [bS x nOut] * [ nOut x numProj] = [bS x numProj] - // if clipping projection is provided then projected cell output state is clipped by this value - if(clippingProjValue != 0.) - ht->applyScalar(scalar::LstmClip, clippingProjValue, *ht); - } - else - ht->assign(&htNoPeepHole); +void lstmCell(sd::LaunchContext* context, const NDArray* xt, + const NDArray* ht_1, const NDArray* ct_1, const NDArray* Wx, + const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, + const NDArray* b, NDArray* ht, NDArray* ct, + const std::vector& params) { + // xt input [bS x nIn] + // ht_1 previous cell output [bS x numProj], that is at previous time step + // t-1, in case of projection=false -> numProj=nOut!!! ct_1 previous cell + // state [bS x nOut], that is at previous time step t-1 + + // Wx input-to-hidden weights, [nIn x 4*nOut] + // Wh hidden-to-hidden weights, [numProj x 4*nOut] + // Wc diagonal weights for peephole connections [3*nOut] + // Wp projection weights [nOut x numProj] + // b biases, [4*nOut] + + // ht current cell output [bS x numProj], that is at current time step t + // ct current cell state [bS x nOut], that is at current time step t + + const bool peephole = + (bool)params[0]; // if true, provide peephole connections + const bool projection = + (bool)params[1]; // if true, then projection is performed, if false then + // numProj==nOut is mandatory!!!! + double clippingCellValue = + params[2]; // clipping value for ct, if it is not equal to zero, then + // cell state is clipped + double clippingProjValue = + params[3]; // clipping value for projected ht, if it is not equal to + // zero, then projected cell output is clipped + const double forgetBias = params[4]; + + const int bS = xt->sizeAt(0); + const int nIn = xt->sizeAt(1); + const int numProj = ht_1->sizeAt(1); + const int nOut = ct_1->sizeAt(1); + + auto z = mmul(*xt, *Wx) + mmul(*ht_1, *Wh) + + *b; // [bS x 4*nOut] + [bS x 4*nOut] + [1 x 4*nOut] = [bS x 4*nOut] + + auto zit = z({0, 0, 0, nOut}); // z for input gate, = mmul(Wxi,xt) + + // mmul(Whi,ht_1) + bi = [bS x nOut] + auto zft = z({0, 0, nOut, 2 * nOut}); // z for forget gate, = mmul(Wxf,xt) + + // mmul(Whf,ht_1) + bf = [bS x nOut] + auto zct = + z({0, 0, 2 * nOut, 3 * nOut}); // z for cell state, = mmul(Wxc,xt) + + // mmul(Whc,ht_1) + bc = [bS x nOut] + auto zot = + z({0, 0, 3 * nOut, 4 * nOut}); // z for output gate, = mmul(Wxo,xt) + + // mmul(Who,ht_1) + bo = [bS x nOut] + + if (peephole) { // add peephole connections: z + ct_1*Wc + zit += + (*ct_1) * (*Wc)({0, nOut}); // add peephole connections to input gate + zft += (*ct_1) * + (*Wc)({nOut, 2 * nOut}); // add peephole connections to forget gate + } + + // current sell state = ft*ct_1 + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc + ct->assign(sigmoid(zft + forgetBias) * (*ct_1) + sigmoid(zit) * tanh(zct)); + + // if clipping value is provided then cell state is clipped by this value + // prior to the cell output activation + if (clippingCellValue > 0.0) + ct->applyScalar(scalar::LstmClip, clippingCellValue, *ct); + + if (peephole) + zot += (*ct) * (*Wc)({{2 * nOut, 3 * nOut}}); // add peephole connections + // to output gate zot + ct*Wc + + // current cell output = ot*tanh(ct) + auto htNoPeepHole = sigmoid(zot) * tanh(*ct); // = [bS x nOut] + + // apply projection + if (projection) { + ht->assign(mmul(htNoPeepHole, + *Wp)); // [bS x nOut] * [ nOut x numProj] = [bS x numProj] + // if clipping projection is provided then projected cell output state is + // clipped by this value + if (clippingProjValue != 0.) + ht->applyScalar(scalar::LstmClip, clippingProjValue, *ht); + } else + ht->assign(&htNoPeepHole); } - - -void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast, - const NDArray* W, const NDArray* Wci, const NDArray* Wcf, const NDArray* Wco, const NDArray* b, - NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, NDArray* h, NDArray* y, const std::vector& params) { - /* Input arrays: - * 0: xt - input [bS, nIn] at time t - * 1: cLast (cs_prev) - previous cell state [bS, nOut], time t-1 - * 2: yLast (h_prev) - previous output [bS, nOut], time t-1 - * 3: W - Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(nIn+nOut), 4*nOut] - * 4: Wci - weights - cell peephole (t-1) connections to input modulation gate, [nOut] - * 5: Wcf - weights - cell peephole (t-1) connections to forget gate, [nOut] - * 6: Wco - weights - cell peephole (t) connections to output gate, [nOut] - * 7: b - biases, [4*nOut] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * - * Input float arguments: - * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * - * Output arrays: - * 0: i - Input modulation gate activations [bS, nOut] - * 1: c (cs) - Cell state (pre tanh) [bs, nOut] (cs) - * 2: f - Output - forget gate activations [bs, nOut] - * 3: o - Output - output gate activations [bs, nOut] - * 4: z (ci) - Output - block input [bs, nOut] - * 5: h (co) - Cell state, post tanh [bs, nOut] - * 6: y (h) - Current cell output [bS, nOut], time t - */ - const bool peephole = (bool)params[0]; // if true, provide peephole connections - const double forgetBias = params[1]; - const double clippingCellValue = params[2]; // clipping value for ct, if it is not equal to zero, then cell state is clipped - - const int bS = xt->sizeAt(0); - const int nIn = xt->sizeAt(1); - const int nOut = cLast->sizeAt(1); - - //Concat inputs: [xt, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)] - NDArray concatOut(xt->ordering(), {xt->sizeAt(0), xt->sizeAt(1) + yLast->sizeAt(1)}, xt->dataType(), xt->getContext()); - helpers::concat(xt->getContext(), {const_cast(xt), const_cast(yLast)}, concatOut, {1}); - - auto m = mmul(concatOut, *W); // mmul: [bs, (nIn+nOut)] * [(nIn+nOut), 4*nOut] = [bs, 4*nOut] - m += (*b); // addiRowVector - - //Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o]) - auto zi = m({0,0, 0, nOut}); // z for input modulation gate, [bS, nOut] - auto zz = m({0,0, nOut, 2*nOut}); // z for block input, [bS, nOut] - auto zf = m({0,0, 2*nOut, 3*nOut}); // z for forget gate, [bS, nOut] - auto zo = m({0,0, 3*nOut, 4*nOut}); // z for output gate, [bS, nOut] - - if(peephole) { // add peephole connections: z + ct_1*Wc - zi += (*cLast) * (*Wci); // add peephole connections to input gate - zf += (*cLast) * (*Wcf); // add peephole connections to forget gate - } - - // current sell state = ft*cLast + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc - if(forgetBias != 0.0) - zf += forgetBias; - - zz.applyTransform(transform::Tanh, *z); //z = tanh(zz) - zi.applyTransform(transform::Sigmoid, *i); //i = sigmoid(zi) - zf.applyTransform(transform::Sigmoid, *f); //f = sigmoid(zf); - - //cell state = blockInput .* inputGate + prevCellState .* forgetGate - z->applyPairwiseTransform(pairwise::Multiply, *i, *c); //c = z * i - auto temp = (*f) * (*cLast); - *c += temp; //c = (i * z) + (zf * (*cLast)) - c->applyTransform(transform::Tanh, *h); //h = tanh(c) - - // if clipping value is provided then cell state is clipped by this value prior to the cell output activation - if(clippingCellValue > 0.0) - c->applyScalar(scalar::LstmClip, clippingCellValue, *c); - - if(peephole) { - // add peephole connections to output gate zot + ct*Wc - auto prod = *c * (*Wco); - zo += prod; - } - zo.applyTransform(transform::Sigmoid, *o); // o = sigmoid(zo) - - // current cell output = ot*tanh(ct) - c->applyTransform(transform::Tanh, *h); //h = tanh(c) - o->applyPairwiseTransform(pairwise::Multiply, *h, *y); //y = o * h -} - - -} -} +void lstmBlockCell(const NDArray* xt, const NDArray* cLast, + const NDArray* yLast, const NDArray* W, const NDArray* Wci, + const NDArray* Wcf, const NDArray* Wco, const NDArray* b, + NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, + NDArray* h, NDArray* y, const std::vector& params) { + /* Input arrays: + * 0: xt - input [bS, nIn] at time t + * 1: cLast (cs_prev) - previous cell state [bS, nOut], time t-1 + * 2: yLast (h_prev) - previous output [bS, nOut], time t-1 + * 3: W - Weights - concatenated (input-to-hidden, + * hidden-to-hidden weights) weights, [(nIn+nOut), 4*nOut] 4: Wci - weights - + * cell peephole (t-1) connections to input modulation gate, [nOut] 5: Wcf - + * weights - cell peephole (t-1) connections to forget gate, [nOut] 6: Wco - + * weights - cell peephole (t) connections to output gate, [nOut] 7: b - + * biases, [4*nOut] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * + * Input float arguments: + * 0: the bias added to forget gates in order to reduce the scale of + * forgetting in the beginning of the training 1: clipping value for cell + * state, if it is not equal to zero, then cell state is clipped + * + * Output arrays: + * 0: i - Input modulation gate activations [bS, nOut] + * 1: c (cs) - Cell state (pre tanh) [bs, nOut] (cs) + * 2: f - Output - forget gate activations [bs, nOut] + * 3: o - Output - output gate activations [bs, nOut] + * 4: z (ci) - Output - block input [bs, nOut] + * 5: h (co) - Cell state, post tanh [bs, nOut] + * 6: y (h) - Current cell output [bS, nOut], time t + */ + const bool peephole = + (bool)params[0]; // if true, provide peephole connections + const double forgetBias = params[1]; + const double clippingCellValue = + params[2]; // clipping value for ct, if it is not equal to zero, then + // cell state is clipped + + const int bS = xt->sizeAt(0); + const int nIn = xt->sizeAt(1); + const int nOut = cLast->sizeAt(1); + + // Concat inputs: [xt, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)] + NDArray concatOut(xt->ordering(), + {xt->sizeAt(0), xt->sizeAt(1) + yLast->sizeAt(1)}, + xt->dataType(), xt->getContext()); + helpers::concat(xt->getContext(), + {const_cast(xt), const_cast(yLast)}, + concatOut, {1}); + + auto m = + mmul(concatOut, + *W); // mmul: [bs, (nIn+nOut)] * [(nIn+nOut), 4*nOut] = [bs, 4*nOut] + m += (*b); // addiRowVector + + // Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] + // to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o]) + auto zi = m({0, 0, 0, nOut}); // z for input modulation gate, [bS, nOut] + auto zz = m({0, 0, nOut, 2 * nOut}); // z for block input, [bS, nOut] + auto zf = m({0, 0, 2 * nOut, 3 * nOut}); // z for forget gate, [bS, nOut] + auto zo = m({0, 0, 3 * nOut, 4 * nOut}); // z for output gate, [bS, nOut] + + if (peephole) { // add peephole connections: z + ct_1*Wc + zi += (*cLast) * (*Wci); // add peephole connections to input gate + zf += (*cLast) * (*Wcf); // add peephole connections to forget gate + } + + // current sell state = ft*cLast + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc + if (forgetBias != 0.0) zf += forgetBias; + + zz.applyTransform(transform::Tanh, *z); // z = tanh(zz) + zi.applyTransform(transform::Sigmoid, *i); // i = sigmoid(zi) + zf.applyTransform(transform::Sigmoid, *f); // f = sigmoid(zf); + + // cell state = blockInput .* inputGate + prevCellState .* forgetGate + z->applyPairwiseTransform(pairwise::Multiply, *i, *c); // c = z * i + auto temp = (*f) * (*cLast); + *c += temp; // c = (i * z) + (zf * (*cLast)) + c->applyTransform(transform::Tanh, *h); // h = tanh(c) + + // if clipping value is provided then cell state is clipped by this value + // prior to the cell output activation + if (clippingCellValue > 0.0) + c->applyScalar(scalar::LstmClip, clippingCellValue, *c); + + if (peephole) { + // add peephole connections to output gate zot + ct*Wc + auto prod = *c * (*Wco); + zo += prod; + } + zo.applyTransform(transform::Sigmoid, *o); // o = sigmoid(zo) + + // current cell output = ot*tanh(ct) + c->applyTransform(transform::Tanh, *h); // h = tanh(c) + o->applyPairwiseTransform(pairwise::Multiply, *h, *y); // y = o * h } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lstsq.cu b/libnd4j/include/ops/declarable/helpers/cuda/lstsq.cu index 8d8548be5582..75879e839196 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lstsq.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lstsq.cu @@ -17,99 +17,128 @@ // // @author GS // -#include #include +#include #include #include -#include - -#include +#include #include #include -#include +#include +#include namespace sd { namespace ops { namespace helpers { - template - static __global__ void fillRegularizerKernel(T* ioMatrixData, const Nd4jLong* ioMatrixShape, const Nd4jLong* ioMatrixTads, const Nd4jLong* ioMatrixOffsets, Nd4jLong batchSize, Nd4jLong rows, T const value) { - - for (auto x = blockIdx.x; x < batchSize; x += gridDim.x) { - auto z = ioMatrixData + ioMatrixOffsets[x]; - for (auto r = threadIdx.x; r < rows; r += blockDim.x) { - Nd4jLong pos[] = {r,r}; - auto zIndex = shape::getOffset(ioMatrixTads, pos); - z[zIndex] = value; - } - } - +template +static __global__ void fillRegularizerKernel(T* ioMatrixData, + const Nd4jLong* ioMatrixShape, + const Nd4jLong* ioMatrixTads, + const Nd4jLong* ioMatrixOffsets, + Nd4jLong batchSize, Nd4jLong rows, + T const value) { + for (auto x = blockIdx.x; x < batchSize; x += gridDim.x) { + auto z = ioMatrixData + ioMatrixOffsets[x]; + for (auto r = threadIdx.x; r < rows; r += blockDim.x) { + Nd4jLong pos[] = {r, r}; + auto zIndex = shape::getOffset(ioMatrixTads, pos); + z[zIndex] = value; } + } +} - template - static void fillRegularizer(sd::LaunchContext* context, NDArray& ioMatrix, double const value) { - auto lastDimsTads = ConstantTadHelper::getInstance()->tadForDimensions(ioMatrix.shapeInfo(), {-2, -1}); - auto stream = context->getCudaStream(); - auto rows = ioMatrix.sizeAt(-2); - //auto cols = ioMatrix.sizeAt(-1); - fillRegularizerKernel<<<256, 256, 128, *stream>>>(ioMatrix.dataBuffer()->specialAsT(), ioMatrix.specialShapeInfo(), lastDimsTads.specialShapeInfo(), lastDimsTads.specialOffsets(), lastDimsTads.numberOfTads(), rows, (T)value); - - } - - template - int leastSquaresSolveFunctor_(sd::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) { - if (fast) { // Cholesky decomposition approach - // Equation for solve A^T * Ax = A^T * b, so - // 1. Computing A2: - auto tAtShape = ShapeUtils::evalShapeForMatmul(leftInput->shapeInfo(), leftInput->shapeInfo(), true, false); - //tAtShape[tAtShape.size() - 2] = output->sizeAt(-2); - NDArray leftOutput(leftInput->ordering(), tAtShape, output->dataType(), context); - MmulHelper::matmul(leftInput, leftInput, &leftOutput, true, false); // Computing A2 = A^T * A - // 2. Computing B' = A^T * b - auto rightOutput = output->ulike(); - - MmulHelper::matmul(leftInput, rightInput, &rightOutput, true, false); // Computing B' = A^T * b - // 3. Regularization ( indeed A' = A2 - l2Regularizer * I) - if (l2Regularizer != 0.0) { - auto regularizer = leftOutput.ulike(); regularizer.nullify(); - fillRegularizer(context, regularizer, (T)l2Regularizer); - leftOutput += regularizer; - } - - // 4. Cholesky decomposition -- output matrix is square and lower triangular - helpers::cholesky(context, &leftOutput, &leftOutput, true); // inplace decomposition - // 5. Solve two triangular systems: - auto rightB = rightOutput.ulike(); rightB.nullify(); - - helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, true, false, &rightB); - - helpers::adjointMatrix(context, &leftOutput, true, &leftOutput); - helpers::triangularSolveFunctor(context, &leftOutput, &rightB, false, false, output); - // All done - } - else { // QR decomposition approach - // Equation for solve Rx = Q^T * b, where A = Q * R, where Q - orthogonal matrix, and R - upper triangular - // 1. QR decomposition - auto qShape = leftInput->getShapeAsVector(); - auto rShape = leftInput->getShapeAsVector(); - qShape[leftInput->rankOf() - 1] = leftInput->sizeAt(-2); - - NDArray Q(leftInput->ordering(), qShape, leftInput->dataType(), context);// = leftInput->ulike(); - NDArray R(leftInput->ordering(), rShape, leftInput->dataType(), context); // = rightInput->ulike(); - helpers::qr(context, leftInput, &Q, &R, true); - // 2. b` = Q^t * b: - auto rightOutput = rightInput->ulike(); - MmulHelper::matmul(&Q, rightInput, &rightOutput, true, false); - // 3. Solve triangular system - helpers::triangularSolveFunctor(context, &R, &rightOutput, false, false, output); - } - return Status::OK(); - } +template +static void fillRegularizer(sd::LaunchContext* context, NDArray& ioMatrix, + double const value) { + auto lastDimsTads = ConstantTadHelper::getInstance()->tadForDimensions( + ioMatrix.shapeInfo(), {-2, -1}); + auto stream = context->getCudaStream(); + auto rows = ioMatrix.sizeAt(-2); + // auto cols = ioMatrix.sizeAt(-1); + fillRegularizerKernel<<<256, 256, 128, *stream>>>( + ioMatrix.dataBuffer()->specialAsT(), ioMatrix.specialShapeInfo(), + lastDimsTads.specialShapeInfo(), lastDimsTads.specialOffsets(), + lastDimsTads.numberOfTads(), rows, (T)value); +} - int leastSquaresSolveFunctor(sd::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output) { - BUILD_SINGLE_SELECTOR(leftInput->dataType(), return leastSquaresSolveFunctor_, (context, leftInput, rightInput, l2Regularizer, fast, output), FLOAT_TYPES); +template +int leastSquaresSolveFunctor_(sd::LaunchContext* context, + NDArray const* leftInput, + NDArray const* rightInput, + double const l2Regularizer, bool const fast, + NDArray* output) { + if (fast) { // Cholesky decomposition approach + // Equation for solve A^T * Ax = A^T * b, so + // 1. Computing A2: + auto tAtShape = ShapeUtils::evalShapeForMatmul( + leftInput->shapeInfo(), leftInput->shapeInfo(), true, false); + // tAtShape[tAtShape.size() - 2] = output->sizeAt(-2); + NDArray leftOutput(leftInput->ordering(), tAtShape, output->dataType(), + context); + MmulHelper::matmul(leftInput, leftInput, &leftOutput, true, + false); // Computing A2 = A^T * A + // 2. Computing B' = A^T * b + auto rightOutput = output->ulike(); + + MmulHelper::matmul(leftInput, rightInput, &rightOutput, true, + false); // Computing B' = A^T * b + // 3. Regularization ( indeed A' = A2 - l2Regularizer * I) + if (l2Regularizer != 0.0) { + auto regularizer = leftOutput.ulike(); + regularizer.nullify(); + fillRegularizer(context, regularizer, (T)l2Regularizer); + leftOutput += regularizer; } + // 4. Cholesky decomposition -- output matrix is square and lower triangular + helpers::cholesky(context, &leftOutput, &leftOutput, + true); // inplace decomposition + // 5. Solve two triangular systems: + auto rightB = rightOutput.ulike(); + rightB.nullify(); + + helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, true, + false, &rightB); + + helpers::adjointMatrix(context, &leftOutput, true, &leftOutput); + helpers::triangularSolveFunctor(context, &leftOutput, &rightB, false, false, + output); + // All done + } else { // QR decomposition approach + // Equation for solve Rx = Q^T * b, where A = Q * R, where Q - orthogonal + // matrix, and R - upper triangular + // 1. QR decomposition + auto qShape = leftInput->getShapeAsVector(); + auto rShape = leftInput->getShapeAsVector(); + qShape[leftInput->rankOf() - 1] = leftInput->sizeAt(-2); + + NDArray Q(leftInput->ordering(), qShape, leftInput->dataType(), + context); // = leftInput->ulike(); + NDArray R(leftInput->ordering(), rShape, leftInput->dataType(), + context); // = rightInput->ulike(); + helpers::qr(context, leftInput, &Q, &R, true); + // 2. b` = Q^t * b: + auto rightOutput = rightInput->ulike(); + MmulHelper::matmul(&Q, rightInput, &rightOutput, true, false); + // 3. Solve triangular system + helpers::triangularSolveFunctor(context, &R, &rightOutput, false, false, + output); + } + return Status::OK(); } + +int leastSquaresSolveFunctor(sd::LaunchContext* context, + NDArray const* leftInput, + NDArray const* rightInput, + double const l2Regularizer, bool const fast, + NDArray* output) { + BUILD_SINGLE_SELECTOR( + leftInput->dataType(), return leastSquaresSolveFunctor_, + (context, leftInput, rightInput, l2Regularizer, fast, output), + FLOAT_TYPES); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index e9a2fc9e336d..13a3d3329539 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -18,12 +18,12 @@ // @author raver119@gmail.com // -#include -#include #include #include #include +#include #include +#include //#include #include @@ -33,978 +33,1074 @@ namespace sd { namespace ops { namespace helpers { -// ------------------------------------------------------------------------------------------------------------------ // +// ------------------------------------------------------------------------------------------------------------------ +// // // invert the second diagonal for lower diagonal matrix - template - static __global__ void - invertKernelLow(void *invertedBuf, const Nd4jLong *invertedShape, const void *inputBuf, const Nd4jLong *inputShape, Nd4jLong n) { - auto inverted = reinterpret_cast(invertedBuf); - auto input = reinterpret_cast(inputBuf); - - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for (int i = start + 1; i < n; i += step) { - Nd4jLong pos[] = {i, i - 1}; - Nd4jLong posX[] = {i, i}; - Nd4jLong posY[] = {i - 1, i - 1}; - auto xIndex = shape::getOffset(inputShape, pos); - auto dxIndex = shape::getOffset(inputShape, posX); - auto dyIndex = shape::getOffset(inputShape, posY); - auto zIndex = shape::getOffset(invertedShape, pos); - // invert lower triangular matrix - inverted[zIndex] = -input[xIndex] / (input[dxIndex] * input[dyIndex]); -// math::atomics::nd4j_atomicAdd(&inverted[zIndex], - input[xIndex] * inverted[iIndex] / input[dIndex]); - } - } -// ------------------------------------------------------------------------------------------------------------------ // -// invert diagonal vals to upper diagonal matrix - template - static __global__ void - upvertKernel(void *invertedBuf, const Nd4jLong *invertedShape, const void *inputBuf, const Nd4jLong *inputShape, Nd4jLong n) { - auto inverted = reinterpret_cast(invertedBuf); - auto input = reinterpret_cast(inputBuf); - - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for (int i = start; i < n; i += step) { - Nd4jLong pos[] = {i, i}; - auto xIndex = shape::getOffset(inputShape, pos); - auto zIndex = shape::getOffset(invertedShape, pos); - - // invert diagonal elements - inverted[zIndex] /= input[xIndex]; - } - } +template +static __global__ void invertKernelLow(void *invertedBuf, + const Nd4jLong *invertedShape, + const void *inputBuf, + const Nd4jLong *inputShape, Nd4jLong n) { + auto inverted = reinterpret_cast(invertedBuf); + auto input = reinterpret_cast(inputBuf); + + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (int i = start + 1; i < n; i += step) { + Nd4jLong pos[] = {i, i - 1}; + Nd4jLong posX[] = {i, i}; + Nd4jLong posY[] = {i - 1, i - 1}; + auto xIndex = shape::getOffset(inputShape, pos); + auto dxIndex = shape::getOffset(inputShape, posX); + auto dyIndex = shape::getOffset(inputShape, posY); + auto zIndex = shape::getOffset(invertedShape, pos); + // invert lower triangular matrix + inverted[zIndex] = -input[xIndex] / (input[dxIndex] * input[dyIndex]); + // math::atomics::nd4j_atomicAdd(&inverted[zIndex], - + // input[xIndex] * inverted[iIndex] / input[dIndex]); + } +} +// ------------------------------------------------------------------------------------------------------------------ +// // invert diagonal vals to upper diagonal matrix +template +static __global__ void upvertKernel(void *invertedBuf, + const Nd4jLong *invertedShape, + const void *inputBuf, + const Nd4jLong *inputShape, Nd4jLong n) { + auto inverted = reinterpret_cast(invertedBuf); + auto input = reinterpret_cast(inputBuf); + + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (int i = start; i < n; i += step) { + Nd4jLong pos[] = {i, i}; + auto xIndex = shape::getOffset(inputShape, pos); + auto zIndex = shape::getOffset(invertedShape, pos); + + // invert diagonal elements + inverted[zIndex] /= input[xIndex]; + } +} -// ------------------------------------------------------------------------------------------------------------------ // +// ------------------------------------------------------------------------------------------------------------------ +// // // invert upper second diagonal - template - static __global__ void - upvertKernelUp(void *invertedBuf, const Nd4jLong *invertedShape, const void *inputBuf, const Nd4jLong *inputShape, Nd4jLong n) { - - __shared__ T* inverted; - __shared__ const T* input; - if (threadIdx.x == 0) { - inverted = reinterpret_cast(invertedBuf); - input = reinterpret_cast(inputBuf); - } - __syncthreads(); - - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for (int i = start; i < n - 1; i += step) { - Nd4jLong pos[] = {i, i + 1}; - Nd4jLong posX[] = {i + 1, i + 1}; - auto xIndex = shape::getOffset(inputShape, pos); - auto iIndex = shape::getOffset(invertedShape, posX); - auto zIndex = shape::getOffset(invertedShape, pos); - // invert upper matrix - math::atomics::nd4j_atomicAdd(&inverted[zIndex], -input[xIndex] * inverted[iIndex]); // / input[yIndex]); - //inputMatrix->t(i, i + 1) * invertedMatrix->t(i + 1, i + 1) / inputMatrix->t(i, i) - } - } - -// ------------------------------------------------------------------------------------------------------------------ // - template - static __global__ void - invertLowKernel(void *invertedBuf, const Nd4jLong *invertedShape, const void *inputBuf, const Nd4jLong *inputShape, Nd4jLong n) { - - auto input = reinterpret_cast(inputBuf); - auto inverted = reinterpret_cast(invertedBuf); - - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; - - for (int i = tid + 2; i < n; i += step) { - for (int j = i - 2; j >= 0; --j) - for (int k = 0; k < i; k++) { - Nd4jLong posZ[] = {i, j}; - Nd4jLong posY[] = {k, j}; - Nd4jLong posX[] = {i, k}; - Nd4jLong posD[] = {i, i}; - - auto xIndex = shape::getOffset(inputShape, posX); - auto yIndex = shape::getOffset(invertedShape, posY); - auto dIndex = shape::getOffset(inputShape, posD); - auto zIndex = shape::getOffset(invertedShape, posZ); - // invert non-diagonal elements - math::atomics::nd4j_atomicAdd(&inverted[zIndex], -inverted[yIndex] * input[xIndex] / input[dIndex]); - } - } - } - -// ------------------------------------------------------------------------------------------------------------------ // -// Invertion of upper triangular matrix non-diagonal elements when main and second diagonals already processed - template - static __global__ void - invertUpKernel( - void *invertedBuf, const Nd4jLong *invertedShape, - const void *inputBuf, const Nd4jLong *inputShape, - Nd4jLong n) { - - auto inverted = reinterpret_cast(invertedBuf);; - auto input = reinterpret_cast(inputBuf); - - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int i = (int)n - tid - 2; i >= 0; i -= step) { - for (int j = i + 2; j < (int)n; j++) - for (int k = i; k < (int)n; k++) { - Nd4jLong posZ[] = {i, j}; - Nd4jLong posY[] = {k, j}; - Nd4jLong posX[] = {i, k}; - // inversion with Joardan Gauss transformation - auto xIndex = shape::getOffset(inputShape, posX); - auto yIndex = shape::getOffset(invertedShape, posY); - auto zIndex = shape::getOffset(invertedShape, posZ); - // invert upper non-diagonal elements - math::atomics::nd4j_atomicAdd(&inverted[zIndex], -inverted[yIndex] * input[xIndex]); - } - } - } - -// ------------------------------------------------------------------------------------------------------------------ // -// procedure to invert lower-triangular matrix. -// In current case lower triangular matrix has main diagonal with general values -// - template - static void invertLowerMatrix_(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) { - int n = inputMatrix->rows(); - invertedMatrix->setIdentity(); - - if (inputMatrix->isIdentityMatrix()) return; - - auto stream = context->getCudaStream(); +template +static __global__ void upvertKernelUp(void *invertedBuf, + const Nd4jLong *invertedShape, + const void *inputBuf, + const Nd4jLong *inputShape, Nd4jLong n) { + __shared__ T *inverted; + __shared__ const T *input; + if (threadIdx.x == 0) { + inverted = reinterpret_cast(invertedBuf); + input = reinterpret_cast(inputBuf); + } + __syncthreads(); + + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (int i = start; i < n - 1; i += step) { + Nd4jLong pos[] = {i, i + 1}; + Nd4jLong posX[] = {i + 1, i + 1}; + auto xIndex = shape::getOffset(inputShape, pos); + auto iIndex = shape::getOffset(invertedShape, posX); + auto zIndex = shape::getOffset(invertedShape, pos); + // invert upper matrix + math::atomics::nd4j_atomicAdd( + &inverted[zIndex], + -input[xIndex] * inverted[iIndex]); // / input[yIndex]); + // inputMatrix->t(i, i + 1) * invertedMatrix->t(i + 1, i + 1) / + // inputMatrix->t(i, i) + } +} - // invert lower matrix - // invert main diagonal - upvertKernel<<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); - // invert the second diagonal - invertKernelLow<<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); +// ------------------------------------------------------------------------------------------------------------------ +// // +template +static __global__ void invertLowKernel(void *invertedBuf, + const Nd4jLong *invertedShape, + const void *inputBuf, + const Nd4jLong *inputShape, Nd4jLong n) { + auto input = reinterpret_cast(inputBuf); + auto inverted = reinterpret_cast(invertedBuf); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; + + for (int i = tid + 2; i < n; i += step) { + for (int j = i - 2; j >= 0; --j) + for (int k = 0; k < i; k++) { + Nd4jLong posZ[] = {i, j}; + Nd4jLong posY[] = {k, j}; + Nd4jLong posX[] = {i, k}; + Nd4jLong posD[] = {i, i}; + + auto xIndex = shape::getOffset(inputShape, posX); + auto yIndex = shape::getOffset(invertedShape, posY); + auto dIndex = shape::getOffset(inputShape, posD); + auto zIndex = shape::getOffset(invertedShape, posZ); // invert non-diagonal elements - invertLowKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); - } + math::atomics::nd4j_atomicAdd( + &inverted[zIndex], + -inverted[yIndex] * input[xIndex] / input[dIndex]); + } + } +} -// ------------------------------------------------------------------------------------------------------------------ // -// caller for invert lower matrix routine - void invertLowerMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) { - NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); - BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE); - NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix}); - } +// ------------------------------------------------------------------------------------------------------------------ +// // Invertion of upper triangular matrix non-diagonal elements when main and +// second diagonals already processed +template +static __global__ void invertUpKernel(void *invertedBuf, + const Nd4jLong *invertedShape, + const void *inputBuf, + const Nd4jLong *inputShape, Nd4jLong n) { + auto inverted = reinterpret_cast(invertedBuf); + ; + auto input = reinterpret_cast(inputBuf); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int i = (int)n - tid - 2; i >= 0; i -= step) { + for (int j = i + 2; j < (int)n; j++) + for (int k = i; k < (int)n; k++) { + Nd4jLong posZ[] = {i, j}; + Nd4jLong posY[] = {k, j}; + Nd4jLong posX[] = {i, k}; + // inversion with Joardan Gauss transformation + auto xIndex = shape::getOffset(inputShape, posX); + auto yIndex = shape::getOffset(invertedShape, posY); + auto zIndex = shape::getOffset(invertedShape, posZ); + // invert upper non-diagonal elements + math::atomics::nd4j_atomicAdd(&inverted[zIndex], + -inverted[yIndex] * input[xIndex]); + } + } +} -// ------------------------------------------------------------------------------------------------------------------ // -// procedure to invert upper-triangular matrix. -// In current case upper triangular matrix has main diagonal with all ones on it. - template - static void invertUpperMatrix_(LaunchContext *context, NDArray* inputMatrix, NDArray* invertedMatrix) { - int n = inputMatrix->rows(); - invertedMatrix->setIdentity(); - auto stream = context->getCudaStream(); - if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I - return; - } +// ------------------------------------------------------------------------------------------------------------------ +// // procedure to invert lower-triangular matrix. In current case lower +// triangular matrix has main diagonal with general values +// +template +static void invertLowerMatrix_(LaunchContext *context, NDArray *inputMatrix, + NDArray *invertedMatrix) { + int n = inputMatrix->rows(); + invertedMatrix->setIdentity(); + + if (inputMatrix->isIdentityMatrix()) return; + + auto stream = context->getCudaStream(); + + // invert lower matrix + // invert main diagonal + upvertKernel<<<1, n, 512, *stream>>>( + invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), + inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + // invert the second diagonal + invertKernelLow<<<1, n, 512, *stream>>>( + invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), + inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + // invert non-diagonal elements + invertLowKernel<<>>( + invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), + inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); +} - // invert upper matrix - // invert the second diagonal - upvertKernelUp<<<1, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), - inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); +// ------------------------------------------------------------------------------------------------------------------ +// // caller for invert lower matrix routine +void invertLowerMatrix(LaunchContext *context, NDArray *inputMatrix, + NDArray *invertedMatrix) { + NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); + BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, + (context, inputMatrix, invertedMatrix), FLOAT_NATIVE); + NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix}); +} - // invert other elements - invertUpKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); - } +// ------------------------------------------------------------------------------------------------------------------ +// // procedure to invert upper-triangular matrix. In current case upper +// triangular matrix has main diagonal with all ones on it. +template +static void invertUpperMatrix_(LaunchContext *context, NDArray *inputMatrix, + NDArray *invertedMatrix) { + int n = inputMatrix->rows(); + invertedMatrix->setIdentity(); + auto stream = context->getCudaStream(); + if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I + return; + } + + // invert upper matrix + // invert the second diagonal + upvertKernelUp<<<1, n, 512, *stream>>>( + invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), + inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + + // invert other elements + invertUpKernel<<>>( + invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), + inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); +} -// ------------------------------------------------------------------------------------------------------------------ // +// ------------------------------------------------------------------------------------------------------------------ +// // // invertion of upper triangular matrix - runner routine - void invertUpperMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) { - NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); - BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE); - NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); - } - -// ------------------------------------------------------------------------------------------------------------------ // - // determinant kernel - accumulation product of all values on the main diagonal - template - static __global__ void determinantKernel(T *compound, T *result, Nd4jLong len) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - for (auto i = start; i < len; i += step) { - auto pos = i * len + i; //shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2); - // multiply all diagonal elements - math::atomics::nd4j_atomicMul(&result[0], compound[pos]); - } - } - -// ------------------------------------------------------------------------------------------------------------------ // - // determinant logarithm - accumulation sum of all logarithm values on the main diagonal. All in logarithic values - // should be positive - template - static __global__ void determinantLogKernel(T *compound, T *result, Nd4jLong len) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - for (auto i = start; i < len; i += step) { - auto pos = i * len + i; //shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2); - // sum logs of all diagonal elements - math::atomics::nd4j_atomicAdd(result, math::nd4j_log(math::nd4j_abs(compound[pos]))); - } - } - -// ------------------------------------------------------------------------------------------------------------------ // - // kernel to copy matrix with given shape to compound tensor with given pos - // output - a N-D tensor buffer with rank not less than 2, input - 2D square n x n matrix with n = rowLen - template - static __global__ void - fillMatrix(void *output, const Nd4jLong *outShape, const void *input, const Nd4jLong *inputShape, Nd4jLong pos, Nd4jLong rowLen) { - __shared__ F *matrix; - __shared__ const T *inputBuf; - __shared__ Nd4jLong inputLen; - __shared__ Nd4jLong n2; - - if (threadIdx.x == 0) { - matrix = reinterpret_cast(output); - inputBuf = reinterpret_cast(input); - inputLen = shape::length(inputShape); - n2 = rowLen * rowLen; - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int k = pos + start, j = start; j < n2; k += step, j += step) { - auto xIndex = shape::getIndexOffset(k, inputShape); - matrix[j] = (F) inputBuf[xIndex]; - } - } +void invertUpperMatrix(LaunchContext *context, NDArray *inputMatrix, + NDArray *invertedMatrix) { + NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); + BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, + (context, inputMatrix, invertedMatrix), FLOAT_NATIVE); + NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); +} -// ------------------------------------------------------------------------------------------------------------------ // -// same as above, but without type conversion - template - static __global__ void - returnMatrix(void *output, const Nd4jLong *outputShape, const void *input, const Nd4jLong *inputShape, Nd4jLong pos, Nd4jLong rowLen) { - __shared__ Nd4jLong outputLen; - __shared__ Nd4jLong n2; - auto matrix = reinterpret_cast(input); - auto outputBuf = reinterpret_cast(output); +// ------------------------------------------------------------------------------------------------------------------ +// // determinant kernel - accumulation product of all values on the main +// diagonal +template +static __global__ void determinantKernel(T *compound, T *result, Nd4jLong len) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + for (auto i = start; i < len; i += step) { + auto pos = i * len + i; // shape::getOffset(0, shape::shapeOf(shape), + // shape::stride(shape), di, 2); + // multiply all diagonal elements + math::atomics::nd4j_atomicMul(&result[0], compound[pos]); + } +} - if (threadIdx.x == 0) { +// ------------------------------------------------------------------------------------------------------------------ +// // determinant logarithm - accumulation sum of all logarithm values on the +// main diagonal. All in logarithic values should be positive +template +static __global__ void determinantLogKernel(T *compound, T *result, + Nd4jLong len) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + for (auto i = start; i < len; i += step) { + auto pos = i * len + i; // shape::getOffset(0, shape::shapeOf(shape), + // shape::stride(shape), di, 2); + // sum logs of all diagonal elements + math::atomics::nd4j_atomicAdd( + result, math::nd4j_log(math::nd4j_abs(compound[pos]))); + } +} - outputLen = shape::length(inputShape); - n2 = rowLen * rowLen; - } - __syncthreads(); - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; +// ------------------------------------------------------------------------------------------------------------------ +// // kernel to copy matrix with given shape to compound tensor with given pos +// output - a N-D tensor buffer with rank not less than 2, input - 2D square n x +// n matrix with n = rowLen +template +static __global__ void fillMatrix(void *output, const Nd4jLong *outShape, + const void *input, const Nd4jLong *inputShape, + Nd4jLong pos, Nd4jLong rowLen) { + __shared__ F *matrix; + __shared__ const T *inputBuf; + __shared__ Nd4jLong inputLen; + __shared__ Nd4jLong n2; + + if (threadIdx.x == 0) { + matrix = reinterpret_cast(output); + inputBuf = reinterpret_cast(input); + inputLen = shape::length(inputShape); + n2 = rowLen * rowLen; + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int k = pos + start, j = start; j < n2; k += step, j += step) { + auto xIndex = shape::getIndexOffset(k, inputShape); + matrix[j] = (F)inputBuf[xIndex]; + } +} - for (int k = pos + start, j = start; j < n2; k += step, j += step) { - auto zIndex = shape::getIndexOffset(k, outputShape); - outputBuf[zIndex] = matrix[j]; - } - } +// ------------------------------------------------------------------------------------------------------------------ +// // same as above, but without type conversion +template +static __global__ void returnMatrix(void *output, const Nd4jLong *outputShape, + const void *input, + const Nd4jLong *inputShape, Nd4jLong pos, + Nd4jLong rowLen) { + __shared__ Nd4jLong outputLen; + __shared__ Nd4jLong n2; + auto matrix = reinterpret_cast(input); + auto outputBuf = reinterpret_cast(output); + + if (threadIdx.x == 0) { + outputLen = shape::length(inputShape); + n2 = rowLen * rowLen; + } + __syncthreads(); + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int k = pos + start, j = start; j < n2; k += step, j += step) { + auto zIndex = shape::getIndexOffset(k, outputShape); + outputBuf[zIndex] = matrix[j]; + } +} -// ------------------------------------------------------------------------------------------------------------------ // - // fill up permutaion matrix kernel. Permutation matrix filled with zeros and ones - template - static __global__ void fillUpPermutation(void *output, const Nd4jLong *shape, int *source, int rowNum) { - F *permutation = reinterpret_cast(output); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - for (auto i = start; i < rowNum; i += step) { - int val = source[i] - 1; - Nd4jLong posF[] = {i, val}; - auto pos = shape::getOffset(shape, posF); - permutation[pos] = F(1.f); - } - } +// ------------------------------------------------------------------------------------------------------------------ +// // fill up permutaion matrix kernel. Permutation matrix filled with zeros and +// ones +template +static __global__ void fillUpPermutation(void *output, const Nd4jLong *shape, + int *source, int rowNum) { + F *permutation = reinterpret_cast(output); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + for (auto i = start; i < rowNum; i += step) { + int val = source[i] - 1; + Nd4jLong posF[] = {i, val}; + auto pos = shape::getOffset(shape, posF); + permutation[pos] = F(1.f); + } +} -// ------------------------------------------------------------------------------------------------------------------ // - // LUP decomposition runner - using CUBLAS SOLVER - // if permutation is given, then using LUP decomposition, LU decomposition otherwise - // L - lower triangular, U - upper triangular, P - permutation matricies - // PA = LU - // - // input - A matrix nxn - // compound - C matrix L + U - I, or main diagonal and lower - L matrix, from the 2nd diagonal - U matrix - template - static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) { - auto stream = context->getCudaStream(); - auto n = input->rows(); - std::lock_guard lock(*LaunchContext::deviceMutex()); - - cusolverDnHandle_t* cusolverH = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr; - // create solver handle - cusolverStatus_t status; //cusolverDnCreate(&cusolverH); -// if (CUSOLVER_STATUS_SUCCESS != status) { -// throw cuda_exception::build("Cannot create cuSolver handle", status); -// } - // set solver stream - status = cusolverDnSetStream(*cusolverH, *stream); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("Cannot set up stream for cuda solver", status); +// ------------------------------------------------------------------------------------------------------------------ +// // LUP decomposition runner - using CUBLAS SOLVER if permutation is given, +// then using LUP decomposition, LU decomposition otherwise L - lower +// triangular, U - upper triangular, P - permutation matricies PA = LU +// +// input - A matrix nxn +// compound - C matrix L + U - I, or main diagonal and lower - L matrix, from +// the 2nd diagonal - U matrix +template +static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, + NDArray *permutation) { + auto stream = context->getCudaStream(); + auto n = input->rows(); + std::lock_guard lock(*LaunchContext::deviceMutex()); + + cusolverDnHandle_t *cusolverH = + (cusolverDnHandle_t *)context->getCusolverHandle(); // nullptr; + // create solver handle + cusolverStatus_t status; // cusolverDnCreate(&cusolverH); + // if (CUSOLVER_STATUS_SUCCESS != status) { + // throw cuda_exception::build("Cannot create cuSolver handle", + // status); + // } + // set solver stream + status = cusolverDnSetStream(*cusolverH, *stream); + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build("Cannot set up stream for cuda solver", status); + } + int lwork = 0; + int *d_info = nullptr; + // allocate memory for permutation vector + auto err = cudaMalloc((void **)&d_info, sizeof(int)); + if (err) { + throw cuda_exception::build( + "helpers::lup_: Cannot allocate memory for solver info buffer", err); + } + + DataType dtype = input->dataType(); + switch (dtype) { // there are two implementations with cublas for LUP + // decomposition - double and float + + case DataType::DOUBLE: { + double *d_work = nullptr; + // compute internal buffer size + double *matrix = reinterpret_cast(input->specialBuffer()); + status = cusolverDnDgetrf_bufferSize(*cusolverH, n, n, matrix, n, &lwork); + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build( + "helpers::lup_: Cannot create cuSolver handle", status); + } + + err = cudaMalloc((void **)&d_work, sizeof(float) * lwork); + if (err) { + throw cuda_exception::build( + "helpers::lup_: Cannot allocate memory for solver data buffer", + err); + } + + if (permutation == nullptr) { + status = cusolverDnDgetrf(*cusolverH, n, n, matrix, n, d_work, nullptr, + d_info); + + if (status != CUSOLVER_STATUS_SUCCESS) { + throw cuda_exception::build( + "helpers::lup_: LU factorization is failed due ", status); } - int lwork = 0; - int *d_info = nullptr; - // allocate memory for permutation vector - auto err = cudaMalloc((void **) &d_info, sizeof(int)); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver info buffer", err); + } else { + NDArray permutVector('c', {n}, sd::DataType::INT32, context); + int *permutationBuf = permutVector.dataBuffer()->specialAsT(); + status = cusolverDnDgetrf(*cusolverH, n, n, matrix, n, d_work, + permutationBuf, d_info); + if (status != CUSOLVER_STATUS_SUCCESS) { + throw cuda_exception::build( + "helpers::lup_: LU factorization is failed due ", status); } - DataType dtype = input->dataType(); - switch (dtype) { // there are two implementations with cublas for LUP decomposition - double and float - - case DataType::DOUBLE: { - double *d_work = nullptr; - // compute internal buffer size - double *matrix = reinterpret_cast(input->specialBuffer()); - status = cusolverDnDgetrf_bufferSize( - *cusolverH, - n, - n, - matrix, - n, - &lwork); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::lup_: Cannot create cuSolver handle", status); - } - - err = cudaMalloc((void **) &d_work, sizeof(float) * lwork); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer", - err); - } - - if (permutation == nullptr) { - status = cusolverDnDgetrf( - *cusolverH, - n, - n, - matrix, - n, - d_work, - nullptr, - d_info); - - if (status != CUSOLVER_STATUS_SUCCESS) { - throw cuda_exception::build("helpers::lup_: LU factorization is failed due ", - status); - } - } - else { - NDArray permutVector('c', {n}, sd::DataType::INT32, context); - int* permutationBuf = permutVector.dataBuffer()->specialAsT(); - status = cusolverDnDgetrf( - *cusolverH, - n, - n, - matrix, - n, - d_work, - permutationBuf, - d_info); - if (status != CUSOLVER_STATUS_SUCCESS) { - throw cuda_exception::build("helpers::lup_: LU factorization is failed due ", - status); - } - - if (permutation->rankOf() == 2) { - fillUpPermutation <<< n, n, 1024, *stream >>> - (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); - } - else { - permutVector.tickWriteDevice(); - input->tickWriteDevice(); - compound->assign(input); - permutation->assign(permutVector); - } - } - err = cudaFree(d_work); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer", - err); - } - } - break; - case DataType::FLOAT32: { - float *matrix = reinterpret_cast(input->specialBuffer()); - float *d_work = nullptr; - - status = cusolverDnSgetrf_bufferSize( - *cusolverH, - n, - n, - matrix, - n, - &lwork); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::lup_: Cannot create cuSolver handle", status); - } - - err = cudaMalloc((void **) &d_work, sizeof(float) * lwork); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer", - err); - } - - if (permutation == nullptr) - status = cusolverDnSgetrf( - *cusolverH, - n, - n, - matrix, - n, - d_work, - nullptr, - d_info); - else { - NDArray permutVector('c', {n}, DataType::INT32, context); - int *permutationBuf = reinterpret_cast(permutVector.specialBuffer()); - status = cusolverDnSgetrf( - *cusolverH, - n, - n, - matrix, - n, - d_work, - permutationBuf, - d_info); - if (permutation->rankOf() == 2) { - fillUpPermutation <<< n, n, 128, *stream >>> - (permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); - permutation->tickWriteDevice(); - } - else { - input->tickWriteDevice(); - compound->assign(input); - permutation->assign(permutVector); - } - } - err = cudaFree(d_work); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer", - err); - } - - } - } - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::lup_: Cannot make LU decomposition", status); + if (permutation->rankOf() == 2) { + fillUpPermutation<<>>( + permutation->specialBuffer(), permutation->specialShapeInfo(), + permutationBuf, n); + } else { + permutVector.tickWriteDevice(); + input->tickWriteDevice(); + compound->assign(input); + permutation->assign(permutVector); } - err = cudaFree(d_info); - if (err) { - throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver info buffer", err); + } + err = cudaFree(d_work); + if (err) { + throw cuda_exception::build( + "helpers::lup_: Cannot deallocate memory for solver data buffer", + err); + } + } break; + case DataType::FLOAT32: { + float *matrix = reinterpret_cast(input->specialBuffer()); + float *d_work = nullptr; + + status = cusolverDnSgetrf_bufferSize(*cusolverH, n, n, matrix, n, &lwork); + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build( + "helpers::lup_: Cannot create cuSolver handle", status); + } + + err = cudaMalloc((void **)&d_work, sizeof(float) * lwork); + if (err) { + throw cuda_exception::build( + "helpers::lup_: Cannot allocate memory for solver data buffer", + err); + } + + if (permutation == nullptr) + status = cusolverDnSgetrf(*cusolverH, n, n, matrix, n, d_work, nullptr, + d_info); + else { + NDArray permutVector('c', {n}, DataType::INT32, context); + int *permutationBuf = + reinterpret_cast(permutVector.specialBuffer()); + status = cusolverDnSgetrf(*cusolverH, n, n, matrix, n, d_work, + permutationBuf, d_info); + if (permutation->rankOf() == 2) { + fillUpPermutation<<>>( + permutation->specialBuffer(), permutation->specialShapeInfo(), + permutationBuf, n); + permutation->tickWriteDevice(); + } else { + input->tickWriteDevice(); + compound->assign(input); + permutation->assign(permutVector); } -// cusolverDnDestroy(cusolverH); -// NDArray::registerSpecialUse({input}, {input}); - input->tickWriteDevice(); + } + err = cudaFree(d_work); + if (err) { + throw cuda_exception::build( + "helpers::lup_: Cannot deallocate memory for solver data buffer", + err); + } } -// ------------------------------------------------------------------------------------------------------------------ // - - BUILD_DOUBLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE, INDEXING_TYPES); - - template - static __device__ void swapRows(T* matrix, const Nd4jLong* shape, Nd4jLong theFirst, Nd4jLong theSecond, Nd4jLong n) { - if (theFirst != theSecond) { - for (auto i = 0; i < n; i++) { - Nd4jLong theFirstPos[] = {theFirst, i}; - Nd4jLong theSecondPos[] = {theSecond, i}; - auto theFirstIndex = shape::getOffset(shape, theFirstPos, 0); - auto theSecondIndex = shape::getOffset(shape, theSecondPos, 0); - math::nd4j_swap(matrix[theFirstIndex], matrix[theSecondIndex]); - } - } + } + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build("helpers::lup_: Cannot make LU decomposition", + status); + } + err = cudaFree(d_info); + if (err) { + throw cuda_exception::build( + "helpers::lup_: Cannot deallocate memory for solver info buffer", err); + } + // cusolverDnDestroy(cusolverH); + // NDArray::registerSpecialUse({input}, {input}); + input->tickWriteDevice(); +} +// ------------------------------------------------------------------------------------------------------------------ +// // + +BUILD_DOUBLE_TEMPLATE(template void lup_, + (LaunchContext * context, NDArray *input, NDArray *output, + NDArray *permutation), + FLOAT_NATIVE, INDEXING_TYPES); + +template +static __device__ void swapRows(T *matrix, const Nd4jLong *shape, + Nd4jLong theFirst, Nd4jLong theSecond, + Nd4jLong n) { + if (theFirst != theSecond) { + for (auto i = 0; i < n; i++) { + Nd4jLong theFirstPos[] = {theFirst, i}; + Nd4jLong theSecondPos[] = {theSecond, i}; + auto theFirstIndex = shape::getOffset(shape, theFirstPos, 0); + auto theSecondIndex = shape::getOffset(shape, theSecondPos, 0); + math::nd4j_swap(matrix[theFirstIndex], matrix[theSecondIndex]); } + } +} - template - static __device__ void processColumns(Nd4jLong currentRow, Nd4jLong rowNum, T* compoundBuf, const Nd4jLong* compoundShape) { - Nd4jLong xDiag[] = {currentRow, currentRow}; - auto diagIndex = shape::getOffset(compoundShape, xDiag, 0); - for (auto j = currentRow + 1; j < rowNum; j++) { - Nd4jLong xRow[] = {j, currentRow}; - auto rowIndex = shape::getOffset(compoundShape, xRow, 0); - compoundBuf[rowIndex] /= compoundBuf[diagIndex]; //output->t(i, i); - for (auto k = currentRow + 1; k < rowNum; k++) { - Nd4jLong yRow[] = {j, k}; - Nd4jLong yCol[] = {currentRow, k}; - auto rowIndexY = shape::getOffset(compoundShape, yRow, 0); - auto colIndex = shape::getOffset(compoundShape, yCol, 0); - compoundBuf[rowIndexY] -= compoundBuf[rowIndex] * compoundBuf[colIndex]; - } - } +template +static __device__ void processColumns(Nd4jLong currentRow, Nd4jLong rowNum, + T *compoundBuf, + const Nd4jLong *compoundShape) { + Nd4jLong xDiag[] = {currentRow, currentRow}; + auto diagIndex = shape::getOffset(compoundShape, xDiag, 0); + for (auto j = currentRow + 1; j < rowNum; j++) { + Nd4jLong xRow[] = {j, currentRow}; + auto rowIndex = shape::getOffset(compoundShape, xRow, 0); + compoundBuf[rowIndex] /= compoundBuf[diagIndex]; // output->t(i, i); + for (auto k = currentRow + 1; k < rowNum; k++) { + Nd4jLong yRow[] = {j, k}; + Nd4jLong yCol[] = {currentRow, k}; + auto rowIndexY = shape::getOffset(compoundShape, yRow, 0); + auto colIndex = shape::getOffset(compoundShape, yCol, 0); + compoundBuf[rowIndexY] -= compoundBuf[rowIndex] * compoundBuf[colIndex]; } + } +} - template - __device__ Nd4jLong argmaxCol(Nd4jLong column, T* compoundBuffer, const Nd4jLong* compoundShape) { - auto rowNum = shape::sizeAt(compoundShape, 0); - Nd4jLong xInitial[] = {column, column}; - auto xInitialIndex = shape::getOffset(compoundShape, xInitial, 0); - auto maxValue = T(0); //sd::math::nd4j_abs(compoundBuffer[xInitialIndex]); - auto result = -1LL; - - for (auto rowCounter = column; rowCounter < rowNum; rowCounter++) { - Nd4jLong xPos[] = {rowCounter, column}; - auto xIndex = shape::getOffset(compoundShape, xPos, 0); - if (sd::math::nd4j_abs(compoundBuffer[xIndex]) > maxValue) { - maxValue = sd::math::nd4j_max(maxValue, sd::math::nd4j_abs(compoundBuffer[xIndex])); - result = rowCounter; - } - } - return result; +template +__device__ Nd4jLong argmaxCol(Nd4jLong column, T *compoundBuffer, + const Nd4jLong *compoundShape) { + auto rowNum = shape::sizeAt(compoundShape, 0); + Nd4jLong xInitial[] = {column, column}; + auto xInitialIndex = shape::getOffset(compoundShape, xInitial, 0); + auto maxValue = T(0); // sd::math::nd4j_abs(compoundBuffer[xInitialIndex]); + auto result = -1LL; + + for (auto rowCounter = column; rowCounter < rowNum; rowCounter++) { + Nd4jLong xPos[] = {rowCounter, column}; + auto xIndex = shape::getOffset(compoundShape, xPos, 0); + if (sd::math::nd4j_abs(compoundBuffer[xIndex]) > maxValue) { + maxValue = sd::math::nd4j_max(maxValue, + sd::math::nd4j_abs(compoundBuffer[xIndex])); + result = rowCounter; } + } + return result; +} - template - static __device__ int luNN(T* matrix, const Nd4jLong* shape, I* permutation, const Nd4jLong* permuShape, Nd4jLong n) { - - for (auto i = 0; i < n - 1; i++) { - auto pivotIndex = argmaxCol(i, matrix, shape); - if (pivotIndex < 0) { - return -1;//throw std::runtime_error("helpers::luNN_: input matrix is singular."); - } - math::nd4j_swap(permutation[shape::getIndexOffset(i, permuShape)], permutation[shape::getIndexOffset(pivotIndex, permuShape)]); - swapRows(matrix, shape, (Nd4jLong)i, pivotIndex, n); - - processColumns(i, n, matrix, shape); - } - return 0; +template +static __device__ int luNN(T *matrix, const Nd4jLong *shape, I *permutation, + const Nd4jLong *permuShape, Nd4jLong n) { + for (auto i = 0; i < n - 1; i++) { + auto pivotIndex = argmaxCol(i, matrix, shape); + if (pivotIndex < 0) { + return -1; // throw std::runtime_error("helpers::luNN_: input matrix is + // singular."); } + math::nd4j_swap(permutation[shape::getIndexOffset(i, permuShape)], + permutation[shape::getIndexOffset(pivotIndex, permuShape)]); + swapRows(matrix, shape, (Nd4jLong)i, pivotIndex, n); - template - static __global__ void luBatchedKernel( - T* outputBuf, const Nd4jLong* outputShape, - I* permutations, const Nd4jLong* permuShape, - const Nd4jLong* outputTadShape, const Nd4jLong* outputTadOffsets, - const Nd4jLong* permuTadShape, const Nd4jLong* permuTadOffsets, - Nd4jLong batchNum) { - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (auto b = start; b < batchNum; b += step) { - T* matrix = outputBuf + outputTadOffsets[b]; - I* permutation = permutations + permuTadOffsets[b]; - - if (0 != luNN(matrix, outputTadShape, permutation, permuTadShape, shape::length(permuTadShape))) break; - } - } + processColumns(i, n, matrix, shape); + } + return 0; +} - template - static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) { - auto n = input->sizeAt(-1); - auto stream = context->getCudaStream(); - NDArray iota('c', {n}, permutationVectors->dataType(), context);// = NDArrayFactory::create(); // ('c', {n}); - iota.linspace(0); iota.syncToDevice(); - - output->assign(input); // fill up output tensor with zeros -// output->tickWriteDevice(); - permutationVectors->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), iota, *permutationVectors, true, nullptr); -// permutationVectors->tickWriteDevice(); - auto tads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-2, -1}); - auto permutaionTads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-1}); - auto batchNum = tads.numberOfTads(); - luBatchedKernel<<>>(reinterpret_cast(output->platformBuffer()), - output->specialShapeInfo(), reinterpret_cast(permutationVectors->platformBuffer()), - permutationVectors->specialShapeInfo(), tads.specialShapeInfo(), tads.specialOffsets(), - permutaionTads.specialShapeInfo(), permutaionTads.specialOffsets(), batchNum); - } +template +static __global__ void luBatchedKernel( + T *outputBuf, const Nd4jLong *outputShape, I *permutations, + const Nd4jLong *permuShape, const Nd4jLong *outputTadShape, + const Nd4jLong *outputTadOffsets, const Nd4jLong *permuTadShape, + const Nd4jLong *permuTadOffsets, Nd4jLong batchNum) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (auto b = start; b < batchNum; b += step) { + T *matrix = outputBuf + outputTadOffsets[b]; + I *permutation = permutations + permuTadOffsets[b]; + + if (0 != luNN(matrix, outputTadShape, permutation, permuTadShape, + shape::length(permuTadShape))) + break; + } +} - void lu(LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutations) { - NDArray::prepareSpecialUse({output, permutations}, {input}); - BUILD_DOUBLE_SELECTOR(input->dataType(), permutations->dataType(), lu_, (context, input, output, permutations), FLOAT_NATIVE, INDEXING_TYPES); - NDArray::registerSpecialUse({output, permutations}, {input}); - } -// ------------------------------------------------------------------------------------------------------------------ // - template - static int determinant_(sd::LaunchContext *context, NDArray *input, NDArray *output) { - Nd4jLong n = input->sizeAt(-1); - Nd4jLong n2 = n * n; - std::vector dims(); - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); - //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); -// DataType dtype = input->dataType(); -// if (dtype != DataType::DOUBLE) -// dtype = DataType::FLOAT32; - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT(), context); //, block.workspace()); - auto det = NDArrayFactory::create(1, context); - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input}); - dim3 launchDims(256, 256, 1024); - output->assign(1.f); - for (int e = 0; e < output->lengthOf(); e++) { - Nd4jLong pos = e * n2; -// if (matrix.dataType() == input->dataType()) - fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); -// else -// fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); - lup_(context, &matrix, nullptr, nullptr); -// else -// lup_(context, &matrix, nullptr, nullptr); - auto offset = shape::getIndexOffset(e, output->shapeInfo()); - auto inputBuf = reinterpret_cast(matrix.specialBuffer()); - auto outputBuf = reinterpret_cast(output->specialBuffer()) + offset; -// if (matrix.dataType() == input->dataType()) - determinantKernel<<< launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuf, outputBuf, n); -// else -// determinantKernel<<>> (inputBuf, outputBuf, n); - } - NDArray::registerSpecialUse({output}, {input}); +template +static void lu_(LaunchContext *context, NDArray *input, NDArray *output, + NDArray *permutationVectors) { + auto n = input->sizeAt(-1); + auto stream = context->getCudaStream(); + NDArray iota('c', {n}, permutationVectors->dataType(), + context); // = NDArrayFactory::create(); // ('c', {n}); + iota.linspace(0); + iota.syncToDevice(); + + output->assign(input); // fill up output tensor with zeros + // output->tickWriteDevice(); + permutationVectors->applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), iota, + *permutationVectors, true, nullptr); + // permutationVectors->tickWriteDevice(); + auto tads = ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), {-2, -1}); + auto permutaionTads = ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), {-1}); + auto batchNum = tads.numberOfTads(); + luBatchedKernel<<>>( + reinterpret_cast(output->platformBuffer()), + output->specialShapeInfo(), + reinterpret_cast(permutationVectors->platformBuffer()), + permutationVectors->specialShapeInfo(), tads.specialShapeInfo(), + tads.specialOffsets(), permutaionTads.specialShapeInfo(), + permutaionTads.specialOffsets(), batchNum); +} - return Status::OK(); - } +void lu(LaunchContext *context, NDArray *input, NDArray *output, + NDArray *permutations) { + NDArray::prepareSpecialUse({output, permutations}, {input}); + BUILD_DOUBLE_SELECTOR(input->dataType(), permutations->dataType(), lu_, + (context, input, output, permutations), FLOAT_NATIVE, + INDEXING_TYPES); + NDArray::registerSpecialUse({output, permutations}, {input}); +} +// ------------------------------------------------------------------------------------------------------------------ +// // +template +static int determinant_(sd::LaunchContext *context, NDArray *input, + NDArray *output) { + Nd4jLong n = input->sizeAt(-1); + Nd4jLong n2 = n * n; + std::vector dims(); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); + // auto packZ = + // ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), + // {output->rankOf() - 1}); + // DataType dtype = input->dataType(); + // if (dtype != DataType::DOUBLE) + // dtype = DataType::FLOAT32; + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, + DataTypeUtils::fromT(), + context); //, block.workspace()); + auto det = NDArrayFactory::create(1, context); + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input}); + dim3 launchDims(256, 256, 1024); + output->assign(1.f); + for (int e = 0; e < output->lengthOf(); e++) { + Nd4jLong pos = e * n2; + // if (matrix.dataType() == input->dataType()) + fillMatrix<<>>( + matrix.specialBuffer(), matrix.specialShapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), pos, n); + // else + // fillMatrix<<>>(matrix.specialBuffer(), + // matrix.specialShapeInfo(), input->specialBuffer(), + // input->specialShapeInfo(), pos, n); + lup_(context, &matrix, nullptr, nullptr); + // else + // lup_(context, &matrix, nullptr, nullptr); + auto offset = shape::getIndexOffset(e, output->shapeInfo()); + auto inputBuf = reinterpret_cast(matrix.specialBuffer()); + auto outputBuf = reinterpret_cast(output->specialBuffer()) + offset; + // if (matrix.dataType() == input->dataType()) + determinantKernel<<>>( + inputBuf, outputBuf, n); + // else + // determinantKernel<<>> (inputBuf, outputBuf, n); + } + NDArray::registerSpecialUse({output}, {input}); + + return Status::OK(); +} - int determinant(sd::LaunchContext *context, NDArray *input, NDArray *output) { - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE); - NDArray::registerSpecialUse({output}, {input}); - } +int determinant(sd::LaunchContext *context, NDArray *input, NDArray *output) { + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, + (context, input, output), FLOAT_NATIVE); + NDArray::registerSpecialUse({output}, {input}); +} - template - int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) { - Nd4jLong n = input->sizeAt(-1); - Nd4jLong n2 = n * n; - std::vector dims(); - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); - //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); - DataType dtype = input->dataType(); - if (dtype != DataType::DOUBLE) - dtype = DataType::FLOAT32; - - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.workspace()); - auto det = NDArrayFactory::create(1, context); - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input}); - dim3 launchDims(256, 256, 1024); - output->assign(0.f); - for (int e = 0; e < output->lengthOf(); e++) { - Nd4jLong pos = e * n2; -// if (matrix.dataType() == input->dataType()) - fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); -// else -// fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); - -// if (matrix.dataType() == input->dataType()) - lup_(context, &matrix, nullptr, nullptr); -// else -// lup_(context, &matrix, nullptr, nullptr); - auto offset = shape::getIndexOffset(e, output->shapeInfo()); - auto inputBuf = reinterpret_cast(matrix.specialBuffer()); - auto outputBuf = reinterpret_cast(output->specialBuffer()) + offset; -// if (matrix.dataType() == input->dataType()) - determinantLogKernel<<>>(inputBuf, outputBuf, n); -// else -// determinantLogKernel<<>> (inputBuf, outputBuf, n); - } - NDArray::registerSpecialUse({output}, {input}); - - return Status::OK(); - - return ND4J_STATUS_OK; - } +template +int logAbsDeterminant_(LaunchContext *context, NDArray *input, + NDArray *output) { + Nd4jLong n = input->sizeAt(-1); + Nd4jLong n2 = n * n; + std::vector dims(); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); + // auto packZ = + // ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), + // {output->rankOf() - 1}); + DataType dtype = input->dataType(); + if (dtype != DataType::DOUBLE) dtype = DataType::FLOAT32; + + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, + context); //, block.workspace()); + auto det = NDArrayFactory::create(1, context); + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input}); + dim3 launchDims(256, 256, 1024); + output->assign(0.f); + for (int e = 0; e < output->lengthOf(); e++) { + Nd4jLong pos = e * n2; + // if (matrix.dataType() == input->dataType()) + fillMatrix<<>>( + matrix.specialBuffer(), matrix.specialShapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), pos, n); + // else + // fillMatrix<<>>(matrix.specialBuffer(), + // matrix.specialShapeInfo(), input->specialBuffer(), + // input->specialShapeInfo(), pos, n); + + // if (matrix.dataType() == input->dataType()) + lup_(context, &matrix, nullptr, nullptr); + // else + // lup_(context, &matrix, nullptr, nullptr); + auto offset = shape::getIndexOffset(e, output->shapeInfo()); + auto inputBuf = reinterpret_cast(matrix.specialBuffer()); + auto outputBuf = reinterpret_cast(output->specialBuffer()) + offset; + // if (matrix.dataType() == input->dataType()) + determinantLogKernel + <<>>(inputBuf, + outputBuf, n); + // else + // determinantLogKernel<<>> (inputBuf, + // outputBuf, n); + } + NDArray::registerSpecialUse({output}, {input}); + + return Status::OK(); + + return ND4J_STATUS_OK; +} - int logAbsDeterminant(sd::LaunchContext *context, NDArray *input, NDArray *output) { - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE); - NDArray::registerSpecialUse({output}, {input}); - } +int logAbsDeterminant(sd::LaunchContext *context, NDArray *input, + NDArray *output) { + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, + (context, input, output), FLOAT_NATIVE); + NDArray::registerSpecialUse({output}, {input}); +} - template - static __global__ void - fillLowerUpperKernel( - void *lowerBuf, const Nd4jLong *lowerShape, - void *upperBuf, const Nd4jLong *upperShape, - void *matrixBuf, const Nd4jLong *matrixShape, - Nd4jLong n) { - - __shared__ T *lowerMatrix; - __shared__ T *upperMatrix; - __shared__ T *matrix; - - if (threadIdx.x == 0) { - lowerMatrix = reinterpret_cast(lowerBuf); - upperMatrix = reinterpret_cast(upperBuf); - matrix = reinterpret_cast(matrixBuf); - } - __syncthreads(); - - for (int k = blockIdx.x; k < n; k += gridDim.x) { // and then put all values under main diagonal on to it - for (int j = threadIdx.x; j < n; j += blockDim.x) { - Nd4jLong posX[] = {k, j}; - Nd4jLong posD[] = {j, j}; - auto xPos = shape::getOffset(lowerShape, posX); - auto yPos = shape::getOffset(upperShape, posX); - auto iPos = shape::getOffset(matrixShape, posX); - auto dPos = shape::getOffset(matrixShape, posD); - if (k >= j) - lowerMatrix[xPos] = matrix[iPos];//(k, j); - else - upperMatrix[yPos] = matrix[iPos]; //k, j); - } - } - } +template +static __global__ void fillLowerUpperKernel( + void *lowerBuf, const Nd4jLong *lowerShape, void *upperBuf, + const Nd4jLong *upperShape, void *matrixBuf, const Nd4jLong *matrixShape, + Nd4jLong n) { + __shared__ T *lowerMatrix; + __shared__ T *upperMatrix; + __shared__ T *matrix; + + if (threadIdx.x == 0) { + lowerMatrix = reinterpret_cast(lowerBuf); + upperMatrix = reinterpret_cast(upperBuf); + matrix = reinterpret_cast(matrixBuf); + } + __syncthreads(); + + for (int k = blockIdx.x; k < n; + k += + gridDim.x) { // and then put all values under main diagonal on to it + for (int j = threadIdx.x; j < n; j += blockDim.x) { + Nd4jLong posX[] = {k, j}; + Nd4jLong posD[] = {j, j}; + auto xPos = shape::getOffset(lowerShape, posX); + auto yPos = shape::getOffset(upperShape, posX); + auto iPos = shape::getOffset(matrixShape, posX); + auto dPos = shape::getOffset(matrixShape, posD); + if (k >= j) + lowerMatrix[xPos] = matrix[iPos]; //(k, j); + else + upperMatrix[yPos] = matrix[iPos]; // k, j); + } + } +} - template - static int inverse_(sd::LaunchContext *context, NDArray *input, NDArray *output) { - auto n = input->sizeAt(-1); - auto n2 = n * n; - auto dtype = DataTypeUtils::fromT(); //input->dataType(); -// if (dtype != DataType::DOUBLE) -// dtype = DataType::FLOAT32; - NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, context); - NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, context); - NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, context); - NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, context); - NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, context); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), - {input->rankOf() - 2, - input->rankOf() - 1}); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), - {output->rankOf() - 2, - output->rankOf() - 1}); - auto stream = context->getCudaStream(); - - for (auto i = 0LL; i < packX.numberOfTads(); i++) { - fillMatrix<<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n); - matrix.tickWriteDevice(); - //compound.assign(matrix); -// if (matrix.dataType() == input->dataType()) - lup_(context, &matrix, nullptr, nullptr); - fillLowerUpperKernel<<>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n); - lower.tickWriteDevice(); - upper.tickWriteDevice(); -// lower.printIndexedBuffer("LOWER"); -// upper.printIndexedBuffer("UPPER"); - matrix.assign(0); - invertUpperMatrix(context, &upper, &matrix); // U^{-1} - matrix.tickWriteDevice(); -// matrix.printIndexedBuffer("Upper Inverted"); - compound.assign(0); - invertLowerMatrix(context, &lower, &compound); // L{-1} - compound.tickWriteDevice(); -// compound.printIndexedBuffer("Lower Inverted"); -// matrix.tickWriteDevice(); -// compound.tickWriteDevice(); - sd::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0); - upper.tickWriteDevice(); -// upper.printIndexedBuffer("Full inverted"); - returnMatrix<<<1, n2, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n); - } - return Status::OK(); - } +template +static int inverse_(sd::LaunchContext *context, NDArray *input, + NDArray *output) { + auto n = input->sizeAt(-1); + auto n2 = n * n; + auto dtype = DataTypeUtils::fromT(); // input->dataType(); + // if (dtype != DataType::DOUBLE) + // dtype = DataType::FLOAT32; + NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, context); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), {output->rankOf() - 2, output->rankOf() - 1}); + auto stream = context->getCudaStream(); + + for (auto i = 0LL; i < packX.numberOfTads(); i++) { + fillMatrix<<<1, n2, 1024, *stream>>>( + matrix.specialBuffer(), matrix.specialShapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), i * n2, n); + matrix.tickWriteDevice(); + // compound.assign(matrix); + // if (matrix.dataType() == input->dataType()) + lup_(context, &matrix, nullptr, nullptr); + fillLowerUpperKernel<<>>( + lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), + upper.specialShapeInfo(), matrix.specialBuffer(), + matrix.specialShapeInfo(), n); + lower.tickWriteDevice(); + upper.tickWriteDevice(); + // lower.printIndexedBuffer("LOWER"); + // upper.printIndexedBuffer("UPPER"); + matrix.assign(0); + invertUpperMatrix(context, &upper, &matrix); // U^{-1} + matrix.tickWriteDevice(); + // matrix.printIndexedBuffer("Upper Inverted"); + compound.assign(0); + invertLowerMatrix(context, &lower, &compound); // L{-1} + compound.tickWriteDevice(); + // compound.printIndexedBuffer("Lower Inverted"); + // matrix.tickWriteDevice(); + // compound.tickWriteDevice(); + sd::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0); + upper.tickWriteDevice(); + // upper.printIndexedBuffer("Full inverted"); + returnMatrix<<<1, n2, 1024, *stream>>>( + output->specialBuffer(), output->specialShapeInfo(), + upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n); + } + return Status::OK(); +} - int inverse(sd::LaunchContext *context, NDArray *input, NDArray *output) { - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE); - NDArray::registerSpecialUse({output}, {input}); - } +int inverse(sd::LaunchContext *context, NDArray *input, NDArray *output) { + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, + (context, input, output), FLOAT_NATIVE); + NDArray::registerSpecialUse({output}, {input}); +} - bool checkCholeskyInput(sd::LaunchContext *context, NDArray const *input) { - return true; - } +bool checkCholeskyInput(sd::LaunchContext *context, NDArray const *input) { + return true; +} - template - __global__ void fillBatchKernel(F **dArrayBatch, F *buf, const Nd4jLong *offsets, Nd4jLong batchSize) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; +template +__global__ void fillBatchKernel(F **dArrayBatch, F *buf, + const Nd4jLong *offsets, Nd4jLong batchSize) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; - for (auto i = start; i < batchSize; i += step) { - dArrayBatch[i] = buf + offsets[i]; - } - } + for (auto i = start; i < batchSize; i += step) { + dArrayBatch[i] = buf + offsets[i]; + } +} - template - __global__ void - adjustResultsKernel(F *dArray, const Nd4jLong *shape, const Nd4jLong *offsets, Nd4jLong batchSize, Nd4jLong n) { - //auto i = blockIdx.x * blockDim.x + threadIdx.x; - Nd4jLong *shapeOf = shape::shapeOf(shape); - Nd4jLong *strideOf = shape::stride(shape); - - for (auto i = blockIdx.x; i < batchSize; i += gridDim.x) { - auto current = dArray + offsets[i]; - for (auto r = threadIdx.x; r < n; r += blockDim.x) { - for (auto c = r + 1; c < n; c++) { - Nd4jLong posRC[] = {r, c}; - auto pos = r * n + c; //shape::getOffset(0, shapeOf, strideOf, posRC, 2); - current[pos] = 0.; - } - } - } - } +template +__global__ void adjustResultsKernel(F *dArray, const Nd4jLong *shape, + const Nd4jLong *offsets, Nd4jLong batchSize, + Nd4jLong n) { + // auto i = blockIdx.x * blockDim.x + threadIdx.x; + Nd4jLong *shapeOf = shape::shapeOf(shape); + Nd4jLong *strideOf = shape::stride(shape); + + for (auto i = blockIdx.x; i < batchSize; i += gridDim.x) { + auto current = dArray + offsets[i]; + for (auto r = threadIdx.x; r < n; r += blockDim.x) { + for (auto c = r + 1; c < n; c++) { + Nd4jLong posRC[] = {r, c}; + auto pos = + r * n + c; // shape::getOffset(0, shapeOf, strideOf, posRC, 2); + current[pos] = 0.; + } + } + } +} - template - int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { - if (!inplace) - output->assign(input); - auto tempOutput =output->dup(); - cusolverDnHandle_t handle = nullptr; - auto n = input->sizeAt(-1); - auto n2 = n * n; - NDArray::prepareSpecialUse({output}, {input}); - auto status = cusolverDnCreate(&handle); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::cholesky_: Cannot create solver handle", status); - } - F **dArrayBatch = nullptr; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput.shapeInfo(), - {tempOutput.rankOf() - 2, - tempOutput.rankOf() - 1}); - const Nd4jLong batchSize = packX.numberOfTads(); - int *dInfoArray = nullptr; - auto err = cudaMalloc((void **) &dArrayBatch, sizeof(F *) * batchSize); - if (err) { - throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver batch data buffer", - err); - } - err = cudaMalloc((void **) &dInfoArray, sizeof(int) * batchSize); - if (err) { - throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver errors buffer", err); - } - auto stream = context->getCudaStream(); - fillBatchKernel<<<1, batchSize, 128, *stream>>>(dArrayBatch, reinterpret_cast(tempOutput.specialBuffer()), packX.specialOffsets(), batchSize); - - status = cusolverDnSetStream(handle, *stream); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::cholesky_: Cannot set stream to solver handle", status); - } - const cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; - if (input->dataType() == DataType::DOUBLE) - status = cusolverDnDpotrfBatched( - handle, - uplo, - n, - (double **) dArrayBatch, - n, - dInfoArray, - batchSize); - else - status = cusolverDnSpotrfBatched( - handle, - uplo, - n, - (float **) dArrayBatch, - n, - dInfoArray, - batchSize); - - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("helpers::cholesky_: Cholesky factorization failed for batch", status); - } - adjustResultsKernel<<>>(reinterpret_cast(tempOutput.specialBuffer()), packX.specialShapeInfo(), packX.specialOffsets(), batchSize, n); - - err = cudaFree(dArrayBatch); - if (err) { - throw cuda_exception::build("helpers::cholesky_: Cannot deallocate memory for solver batch data buffer", - err); - } - err = cudaFree(dInfoArray); - if (err) { - throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver errors buffer", err); - } - - if (!inplace) - output->assign(tempOutput); - else - input->assign(tempOutput); - - NDArray::registerSpecialUse({output}, {input}); - return Status::OK(); - } +template +int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, + bool inplace) { + if (!inplace) output->assign(input); + auto tempOutput = output->dup(); + cusolverDnHandle_t handle = nullptr; + auto n = input->sizeAt(-1); + auto n2 = n * n; + NDArray::prepareSpecialUse({output}, {input}); + auto status = cusolverDnCreate(&handle); + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build( + "helpers::cholesky_: Cannot create solver handle", status); + } + F **dArrayBatch = nullptr; + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + tempOutput.shapeInfo(), + {tempOutput.rankOf() - 2, tempOutput.rankOf() - 1}); + const Nd4jLong batchSize = packX.numberOfTads(); + int *dInfoArray = nullptr; + auto err = cudaMalloc((void **)&dArrayBatch, sizeof(F *) * batchSize); + if (err) { + throw cuda_exception::build( + "helpers::cholesky_: Cannot allocate memory for solver batch data " + "buffer", + err); + } + err = cudaMalloc((void **)&dInfoArray, sizeof(int) * batchSize); + if (err) { + throw cuda_exception::build( + "helpers::cholesky_: Cannot allocate memory for solver errors buffer", + err); + } + auto stream = context->getCudaStream(); + fillBatchKernel<<<1, batchSize, 128, *stream>>>( + dArrayBatch, reinterpret_cast(tempOutput.specialBuffer()), + packX.specialOffsets(), batchSize); + + status = cusolverDnSetStream(handle, *stream); + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build( + "helpers::cholesky_: Cannot set stream to solver handle", status); + } + const cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; + if (input->dataType() == DataType::DOUBLE) + status = cusolverDnDpotrfBatched(handle, uplo, n, (double **)dArrayBatch, n, + dInfoArray, batchSize); + else + status = cusolverDnSpotrfBatched(handle, uplo, n, (float **)dArrayBatch, n, + dInfoArray, batchSize); + + if (CUSOLVER_STATUS_SUCCESS != status) { + throw cuda_exception::build( + "helpers::cholesky_: Cholesky factorization failed for batch", status); + } + adjustResultsKernel<<>>( + reinterpret_cast(tempOutput.specialBuffer()), + packX.specialShapeInfo(), packX.specialOffsets(), batchSize, n); + + err = cudaFree(dArrayBatch); + if (err) { + throw cuda_exception::build( + "helpers::cholesky_: Cannot deallocate memory for solver batch data " + "buffer", + err); + } + err = cudaFree(dInfoArray); + if (err) { + throw cuda_exception::build( + "helpers::cholesky_: Cannot allocate memory for solver errors buffer", + err); + } + + if (!inplace) + output->assign(tempOutput); + else + input->assign(tempOutput); + + NDArray::registerSpecialUse({output}, {input}); + return Status::OK(); +} // template - int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { - NDArray::prepareSpecialUse({output}, {input}); - if (input->dataType() == DataType::DOUBLE) - cholesky__(context, input, output, inplace); - else if (input->dataType() == DataType::FLOAT32) - cholesky__(context, input, output, inplace); - else { - std::unique_ptr tempOutput( - NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, context)); - tempOutput->assign(input); - cholesky__(context, tempOutput.get(), tempOutput.get(), true); - output->assign(tempOutput.get()); - } - NDArray::registerSpecialUse({output}, {input}); - return Status::OK(); - } +int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, + bool inplace) { + NDArray::prepareSpecialUse({output}, {input}); + if (input->dataType() == DataType::DOUBLE) + cholesky__(context, input, output, inplace); + else if (input->dataType() == DataType::FLOAT32) + cholesky__(context, input, output, inplace); + else { + std::unique_ptr tempOutput(NDArrayFactory::create_( + 'c', input->getShapeAsVector(), DataType::FLOAT32, context)); + tempOutput->assign(input); + cholesky__(context, tempOutput.get(), tempOutput.get(), true); + output->assign(tempOutput.get()); + } + NDArray::registerSpecialUse({output}, {input}); + return Status::OK(); +} - int cholesky(sd::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { -// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES); - return cholesky_(context, input, output, inplace); - } -// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); - BUILD_SINGLE_TEMPLATE(template int inverse_, (sd::LaunchContext * context, NDArray * input, NDArray * output), - FLOAT_NATIVE); - - template - __global__ void logDetKernel( - const T *inputBuf, const Nd4jLong *inputShape, - Nd4jLong batchNum, - const Nd4jLong *tadShape, const Nd4jLong *tadOffsets, - T *outputBuf, const Nd4jLong *outputShape) { - - __shared__ int n; - if (threadIdx.x == 0) { - n = shape::sizeAt(inputShape, -1); // * shape::sizeAt(inputShape, -1); - } - __syncthreads(); - - auto output = outputBuf; - auto input = inputBuf; - - for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) { - auto current = input + tadOffsets[i]; - - auto zIndex = shape::getIndexOffset(i, outputShape); - for (auto e = threadIdx.x; e < n; e += blockDim.x) { - Nd4jLong diag[] = {e, e}; - auto xIndex = shape::getOffset(tadShape, diag); - math::atomics::nd4j_atomicAdd(&output[zIndex],math::nd4j_log(current[xIndex] * current[xIndex])); - } - } - } +int cholesky(sd::LaunchContext *context, NDArray *input, NDArray *output, + bool inplace) { + // BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, + // input, output, inplace), FLOAT_TYPES); + return cholesky_(context, input, output, inplace); +} +// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, +// NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template int inverse_, + (sd::LaunchContext * context, NDArray *input, + NDArray *output), + FLOAT_NATIVE); + +template +__global__ void logDetKernel(const T *inputBuf, const Nd4jLong *inputShape, + Nd4jLong batchNum, const Nd4jLong *tadShape, + const Nd4jLong *tadOffsets, T *outputBuf, + const Nd4jLong *outputShape) { + __shared__ int n; + if (threadIdx.x == 0) { + n = shape::sizeAt(inputShape, -1); // * shape::sizeAt(inputShape, -1); + } + __syncthreads(); + + auto output = outputBuf; + auto input = inputBuf; + + for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) { + auto current = input + tadOffsets[i]; + + auto zIndex = shape::getIndexOffset(i, outputShape); + for (auto e = threadIdx.x; e < n; e += blockDim.x) { + Nd4jLong diag[] = {e, e}; + auto xIndex = shape::getOffset(tadShape, diag); + math::atomics::nd4j_atomicAdd( + &output[zIndex], + math::nd4j_log(current[xIndex] * current[xIndex])); + } + } +} - template - int logdetFunctor_(sd::LaunchContext *context, NDArray *input, NDArray *output) { - NDArray::prepareSpecialUse({output}, {input}); - auto n2 = input->sizeAt(-1) * input->sizeAt(-2); - auto stream = context->getCudaStream(); - NDArray tempOutput(*input); - - cholesky(context, input, &tempOutput, false); - - auto outputBuf = output->dataBuffer()->specialAsT(); //reinterpret_cast(output->specialBuffer()); // + e * n2; // + e * n2; - auto inputBuf = tempOutput.dataBuffer()->specialAsT(); //reinterpret_cast(tempOutput.specialBuffer()); - output->nullify(); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput.shapeInfo(), - {tempOutput.rankOf() - 2, - tempOutput.rankOf() - 1}); - logDetKernel<<<128, 512, 256, *stream>>>(inputBuf, tempOutput.specialShapeInfo(), - packX.numberOfTads(), packX.specialShapeInfo(), - packX.specialOffsets(), outputBuf, output->specialShapeInfo()); - output->tickWriteDevice(); - NDArray::registerSpecialUse({output}, {input}); - return Status::OK(); - } +template +int logdetFunctor_(sd::LaunchContext *context, NDArray *input, + NDArray *output) { + NDArray::prepareSpecialUse({output}, {input}); + auto n2 = input->sizeAt(-1) * input->sizeAt(-2); + auto stream = context->getCudaStream(); + NDArray tempOutput(*input); + + cholesky(context, input, &tempOutput, false); + + auto outputBuf = + output->dataBuffer()->specialAsT(); // reinterpret_cast(output->specialBuffer()); + // // + e * n2; // + e * n2; + auto inputBuf = + tempOutput.dataBuffer() + ->specialAsT< + T>(); // reinterpret_cast(tempOutput.specialBuffer()); + output->nullify(); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + tempOutput.shapeInfo(), + {tempOutput.rankOf() - 2, tempOutput.rankOf() - 1}); + logDetKernel<<<128, 512, 256, *stream>>>( + inputBuf, tempOutput.specialShapeInfo(), packX.numberOfTads(), + packX.specialShapeInfo(), packX.specialOffsets(), outputBuf, + output->specialShapeInfo()); + output->tickWriteDevice(); + NDArray::registerSpecialUse({output}, {input}); + return Status::OK(); +} - int logdetFunctor(sd::LaunchContext *context, NDArray *input, NDArray *output) { - BUILD_SINGLE_SELECTOR(output->dataType(), return logdetFunctor_, (context, input, output), FLOAT_NATIVE); - } +int logdetFunctor(sd::LaunchContext *context, NDArray *input, NDArray *output) { + BUILD_SINGLE_SELECTOR(output->dataType(), return logdetFunctor_, + (context, input, output), FLOAT_NATIVE); +} - /* - * lup - batched input, batched outputs - * */ - int lup(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) { - BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_,(context, input, compound, permutation), FLOAT_NATIVE, INDEXING_TYPES); - return Status::OK(); - } +/* + * lup - batched input, batched outputs + * */ +int lup(LaunchContext *context, NDArray *input, NDArray *compound, + NDArray *permutation) { + BUILD_DOUBLE_SELECTOR(input->dataType(), permutation->dataType(), lup_, + (context, input, compound, permutation), FLOAT_NATIVE, + INDEXING_TYPES); + return Status::OK(); +} // BUILD_SINGLE_TEMPLATE(template int logdetFunctor_, -// (sd::LaunchContext * context, NDArray * input, NDArray * output), FLOAT_NATIVE); - } -} -} +// (sd::LaunchContext * context, NDArray * input, +// NDArray * output), FLOAT_NATIVE); +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu b/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu index 97124c3dba69..a8dff33b48ae 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/matrixSetDiag.cu @@ -19,84 +19,104 @@ // #include -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ static void matrixSetDiagCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool zeroPad) { - - // x - input, shape [A,B,C] - // y - diagonal, shape [A,B] - // z - output, shape [A,B,C] - // input and output are the same array (x == z) when zeroPad = true - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ int xRank, *sharedMem; // xRank = zRank, xRank = yRank + 1 - __shared__ Nd4jLong xLen; // xLen = zLen - __shared__ bool areSameOffsets; - - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); // shapes are definitely the same, but strides might not - - xRank = shape::rank(xShapeInfo); - xLen = shape::length(xShapeInfo); - } - - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * xRank; // we provide (xRank * sizeof(int) * threadIdx.x) amount of shared memory per each thread - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < xLen; i += gridDim.x * blockDim.x) { - - shape::index2coords(i, xShapeInfo, coords); - - const auto xOffset = shape::getOffset(xShapeInfo, coords); - const auto zOffset = areSameOffsets ? xOffset : shape::getOffset(zShapeInfo, coords); - - // condition to be on diagonal of innermost matrix - if(coords[xRank - 2] == coords[xRank - 1]) - z[zOffset] = y[shape::getOffset(yShapeInfo, coords)]; - else - z[zOffset] = zeroPad ? static_cast(0) : x[xOffset]; - } +template +__global__ static void matrixSetDiagCuda(const void* vx, + const Nd4jLong* xShapeInfo, + const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const bool zeroPad) { + // x - input, shape [A,B,C] + // y - diagonal, shape [A,B] + // z - output, shape [A,B,C] + // input and output are the same array (x == z) when zeroPad = true + + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ int xRank, *sharedMem; // xRank = zRank, xRank = yRank + 1 + __shared__ Nd4jLong xLen; // xLen = zLen + __shared__ bool areSameOffsets; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + areSameOffsets = shape::haveSameShapeAndStrides( + xShapeInfo, + zShapeInfo); // shapes are definitely the same, but strides might not + + xRank = shape::rank(xShapeInfo); + xLen = shape::length(xShapeInfo); + } + + __syncthreads(); + + auto coords = + sharedMem + + threadIdx.x * xRank; // we provide (xRank * sizeof(int) * threadIdx.x) + // amount of shared memory per each thread + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < xLen; i += gridDim.x * blockDim.x) { + shape::index2coords(i, xShapeInfo, coords); + + const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto zOffset = + areSameOffsets ? xOffset : shape::getOffset(zShapeInfo, coords); + + // condition to be on diagonal of innermost matrix + if (coords[xRank - 2] == coords[xRank - 1]) + z[zOffset] = y[shape::getOffset(yShapeInfo, coords)]; + else + z[zOffset] = zeroPad ? static_cast(0) : x[xOffset]; + } } /////////////////////////////////////////////////////////////////// -template -static void matrixSetDiagCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool zeroPad) { - - matrixSetDiagCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, zeroPad); +template +static void matrixSetDiagCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const bool zeroPad) { + matrixSetDiagCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, zeroPad); } /////////////////////////////////////////////////////////////////// -void matrixSetDiag(sd::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad) { - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * input.rankOf() + 128; - - PointersManager manager(context, "matrixSetDiag"); - - NDArray::prepareSpecialUse({&output}, {&input, &diagonal}); - BUILD_SINGLE_SELECTOR(input.dataType(), matrixSetDiagCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), diagonal.specialBuffer(), diagonal.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), zeroPad), LIBND4J_TYPES); - NDArray::registerSpecialUse({&output}, {&input, &diagonal}); - - manager.synchronize(); +void matrixSetDiag(sd::LaunchContext* context, const NDArray& input, + const NDArray& diagonal, NDArray& output, + const bool zeroPad) { + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * input.rankOf() + 128; + + PointersManager manager(context, "matrixSetDiag"); + + NDArray::prepareSpecialUse({&output}, {&input, &diagonal}); + BUILD_SINGLE_SELECTOR( + input.dataType(), matrixSetDiagCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), + diagonal.specialBuffer(), diagonal.specialShapeInfo(), + output.specialBuffer(), output.specialShapeInfo(), zeroPad), + LIBND4J_TYPES); + NDArray::registerSpecialUse({&output}, {&input, &diagonal}); + + manager.synchronize(); } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/matrix_band.cu b/libnd4j/include/ops/declarable/helpers/cuda/matrix_band.cu index 78249bc381b4..158247beda72 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/matrix_band.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/matrix_band.cu @@ -17,11 +17,11 @@ // // @author George A. Shulinok // -#include -#include #include -#include #include +#include +#include +#include namespace sd { namespace ops { @@ -42,75 +42,87 @@ namespace helpers { // numTads - number of subarrays // inputLength - input subarray length // - template - static __global__ void matrixBandKernel(const void* inputBuffer, const Nd4jLong* inputShape, - void* outputBuffer, const Nd4jLong* outputShape, - Nd4jLong lowerBand, Nd4jLong upperBand, - const Nd4jLong* tadOnlyInputShapeInfo, const Nd4jLong* tadInputOffsets, - const Nd4jLong* tadOnlyOutputShapeInfo, const Nd4jLong* tadOutputOffsets, - Nd4jLong numTads, - Nd4jLong inputLength) { - int totalThreads = blockDim.x; - Nd4jLong rows = shape::sizeAt(inputShape, -2); - Nd4jLong cols = shape::sizeAt(inputShape, -1); - for (Nd4jLong e = blockIdx.x; e < numTads; e += gridDim.x) { - auto yOffset = tadInputOffsets[e]; - auto xOffset = tadOutputOffsets[e]; - for (Nd4jLong i = blockIdx.y; i < rows; i += gridDim.y) { - for (Nd4jLong j = threadIdx.x; j < cols; j += totalThreads) { - Nd4jLong coords[2] = {i, j}; - Nd4jLong tadOffsetOut = shape::getOffset(tadOnlyOutputShapeInfo, coords); - Nd4jLong tadOffsetIn = shape::getOffset(tadOnlyInputShapeInfo, coords); +template +static __global__ void matrixBandKernel( + const void* inputBuffer, const Nd4jLong* inputShape, void* outputBuffer, + const Nd4jLong* outputShape, Nd4jLong lowerBand, Nd4jLong upperBand, + const Nd4jLong* tadOnlyInputShapeInfo, const Nd4jLong* tadInputOffsets, + const Nd4jLong* tadOnlyOutputShapeInfo, const Nd4jLong* tadOutputOffsets, + Nd4jLong numTads, Nd4jLong inputLength) { + int totalThreads = blockDim.x; + Nd4jLong rows = shape::sizeAt(inputShape, -2); + Nd4jLong cols = shape::sizeAt(inputShape, -1); + for (Nd4jLong e = blockIdx.x; e < numTads; e += gridDim.x) { + auto yOffset = tadInputOffsets[e]; + auto xOffset = tadOutputOffsets[e]; + for (Nd4jLong i = blockIdx.y; i < rows; i += gridDim.y) { + for (Nd4jLong j = threadIdx.x; j < cols; j += totalThreads) { + Nd4jLong coords[2] = {i, j}; + Nd4jLong tadOffsetOut = + shape::getOffset(tadOnlyOutputShapeInfo, coords); + Nd4jLong tadOffsetIn = shape::getOffset(tadOnlyInputShapeInfo, coords); - if (i >= j) { // check lower diagonals - if (lowerBand > 0) { - if ((i - j) > lowerBand) - *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = T(0); - else - *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = *( - reinterpret_cast(inputBuffer) + yOffset + tadOffsetIn); - } - } else if (j > i) { - if (upperBand > 0) - if ((j - i) > upperBand) - *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = T(0); - else - *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = *( - reinterpret_cast(inputBuffer) + yOffset + tadOffsetIn); - } - } - } + if (i >= j) { // check lower diagonals + if (lowerBand > 0) { + if ((i - j) > lowerBand) + *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = + T(0); + else + *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = + *(reinterpret_cast(inputBuffer) + yOffset + + tadOffsetIn); + } + } else if (j > i) { + if (upperBand > 0) + if ((j - i) > upperBand) + *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = + T(0); + else + *(reinterpret_cast(outputBuffer) + xOffset + tadOffsetOut) = + *(reinterpret_cast(inputBuffer) + yOffset + + tadOffsetIn); } - + } } + } +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // matrixBandPart_ - main algorithm caller // - template - void matrixBandPart_(sd::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand) { - dim3 launchDims(256, 512, 8192); - auto stream = context->getCudaStream(); +template +void matrixBandPart_(sd::LaunchContext* context, NDArray* input, + NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand) { + dim3 launchDims(256, 512, 8192); + auto stream = context->getCudaStream(); - std::vector lastDims({input->rankOf() - 2, input->rankOf() - 1}); - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), lastDims); + std::vector lastDims({input->rankOf() - 2, input->rankOf() - 1}); + std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(input->rankOf(), lastDims); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), lastDims); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), lastDims); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), lastDims); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), lastDims); - const Nd4jLong numTads = packX.numberOfTads(); + const Nd4jLong numTads = packX.numberOfTads(); - NDArray::prepareSpecialUse({output}, {input}); - matrixBandKernel<<>>(input->specialBuffer(), - input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - lowerBand, upperBand, packX.specialShapeInfo(), packX.specialOffsets(), packZ.specialShapeInfo(), packZ.specialOffsets(), numTads, input->lengthOf()); - NDArray::registerSpecialUse({output}, {input}); - } - - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void matrixBandPart(sd::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand) { - BUILD_SINGLE_SELECTOR(input->dataType(), matrixBandPart_, (context, input, output, lowerBand, upperBand), FLOAT_TYPES); - } -} -} + NDArray::prepareSpecialUse({output}, {input}); + matrixBandKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), lowerBand, upperBand, + packX.specialShapeInfo(), packX.specialOffsets(), + packZ.specialShapeInfo(), packZ.specialOffsets(), numTads, + input->lengthOf()); + NDArray::registerSpecialUse({output}, {input}); } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void matrixBandPart(sd::LaunchContext* context, NDArray* input, NDArray* output, + Nd4jLong lowerBand, Nd4jLong upperBand) { + BUILD_SINGLE_SELECTOR(input->dataType(), matrixBandPart_, + (context, input, output, lowerBand, upperBand), + FLOAT_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu b/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu index 30d5f0ef95f4..9d43a373bb6e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/matrix_diag_part.cu @@ -19,13 +19,12 @@ // #include -#include +#include #include -#include +#include #include #include -#include -#include +#include namespace sd { namespace ops { @@ -33,70 +32,89 @@ namespace helpers { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // put diagonals from input batched matricies to output batched vectors - template - static __global__ void matrixDiagPartKernel(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, - const Nd4jLong* tadOnlyInputShapeInfo, const Nd4jLong *tadInputOffsets, - const Nd4jLong* tadOnlyOutputShapeInfo, const Nd4jLong *tadOutputOffsets) { - int totalThreads = blockDim.x; - for (Nd4jLong i = blockIdx.x; i < numTads; i += gridDim.x) { - auto yOffset = tadInputOffsets[i]; - auto xOffset = tadOutputOffsets[i]; - for (Nd4jLong j = threadIdx.x; j < inputLength; j += totalThreads) { - Nd4jLong coords[2] = {j, j}; - Nd4jLong tadOffset = shape::getOffset(tadOnlyInputShapeInfo, coords); - *(reinterpret_cast(outputBuffer) + xOffset + shape::getIndexOffset(j, tadOnlyOutputShapeInfo)) = *(reinterpret_cast(inputBuffer) + yOffset + tadOffset); - } - } +template +static __global__ void matrixDiagPartKernel( + void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, + Nd4jLong inputLength, const Nd4jLong* tadOnlyInputShapeInfo, + const Nd4jLong* tadInputOffsets, const Nd4jLong* tadOnlyOutputShapeInfo, + const Nd4jLong* tadOutputOffsets) { + int totalThreads = blockDim.x; + for (Nd4jLong i = blockIdx.x; i < numTads; i += gridDim.x) { + auto yOffset = tadInputOffsets[i]; + auto xOffset = tadOutputOffsets[i]; + for (Nd4jLong j = threadIdx.x; j < inputLength; j += totalThreads) { + Nd4jLong coords[2] = {j, j}; + Nd4jLong tadOffset = shape::getOffset(tadOnlyInputShapeInfo, coords); + *(reinterpret_cast(outputBuffer) + xOffset + + shape::getIndexOffset(j, tadOnlyOutputShapeInfo)) = + *(reinterpret_cast(inputBuffer) + yOffset + tadOffset); } + } +} ////////////////////////////////////////////////////////////////////////// // Returns a batched matrix tensor with new batched diagonal values. -// for detailed explanations please take a look on web page: https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag +// for detailed explanations please take a look on web page: +// https://www.tensorflow.org/api_docs/python/tf/matrix_set_diag // - template - int _matrixDiagPart(sd::LaunchContext * context, const NDArray* input, NDArray* output) { - auto stream = context->getCudaStream(); - auto listOut = output->allTensorsAlongDimension({output->rankOf() - 1}); - auto listDiag = input->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf() - 1}); - - if (listOut.size() != listDiag.size()) { - nd4j_printf("matrix_diag_part: Input matrix has wrong shape.", ""); - return ND4J_STATUS_VALIDATION; - } - Nd4jLong lastDimension = sd::math::nd4j_min(input->sizeAt(-2), input->sizeAt(-1)); - - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(output->rankOf(), {output->rankOf() - 1}); - const Nd4jLong numTads = ShapeUtils::getNumOfSubArrs(input->shapeInfo(), dimsToExclude); //this->tensorsAlongDimension({dimension}); - //printf("Repeat delta %lld, numTads %lld\n", repeatDelta, numTads); - //tadOnlyInputShapeInfo, tadInputOffsets, tadOnlyOutputShapeInfo, tadOutputOffsets; - std::vector outputDims({output->rankOf() - 1}); - std::vector inputDims({input->rankOf() - 2, input->rankOf() - 1}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), inputDims); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), outputDims); - - - if (!output->isActualOnDeviceSide()) - input->syncToDevice(); - - if (!input->isActualOnDeviceSide()) - input->syncToDevice(); - - - dim3 launchDims(256, 512, 8192); - matrixDiagPartKernel<<>>(input->specialBuffer(), output->specialBuffer(), numTads, lastDimension, packX.specialShapeInfo(), packX.specialOffsets(), packZ.specialShapeInfo(), packZ.specialOffsets()); - - return Status::OK(); - } +template +int _matrixDiagPart(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + auto stream = context->getCudaStream(); + auto listOut = output->allTensorsAlongDimension({output->rankOf() - 1}); + auto listDiag = input->allTensorsAlongDimension( + {input->rankOf() - 2, input->rankOf() - 1}); + + if (listOut.size() != listDiag.size()) { + nd4j_printf("matrix_diag_part: Input matrix has wrong shape.", ""); + return ND4J_STATUS_VALIDATION; + } + Nd4jLong lastDimension = + sd::math::nd4j_min(input->sizeAt(-2), input->sizeAt(-1)); + + std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(output->rankOf(), {output->rankOf() - 1}); + const Nd4jLong numTads = ShapeUtils::getNumOfSubArrs( + input->shapeInfo(), + dimsToExclude); // this->tensorsAlongDimension({dimension}); + // printf("Repeat delta %lld, numTads %lld\n", repeatDelta, numTads); + // tadOnlyInputShapeInfo, tadInputOffsets, tadOnlyOutputShapeInfo, + // tadOutputOffsets; + std::vector outputDims({output->rankOf() - 1}); + std::vector inputDims({input->rankOf() - 2, input->rankOf() - 1}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), inputDims); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), outputDims); + + if (!output->isActualOnDeviceSide()) input->syncToDevice(); + + if (!input->isActualOnDeviceSide()) input->syncToDevice(); + + dim3 launchDims(256, 512, 8192); + matrixDiagPartKernel + <<>>( + input->specialBuffer(), output->specialBuffer(), numTads, + lastDimension, packX.specialShapeInfo(), packX.specialOffsets(), + packZ.specialShapeInfo(), packZ.specialOffsets()); + + return Status::OK(); +} //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // caller for _matrixDiagPart // - int matrixDiagPart(sd::LaunchContext * context, const NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiagPart, (context, input, output), LIBND4J_TYPES); - } +int matrixDiagPart(sd::LaunchContext* context, const NDArray* input, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), return _matrixDiagPart, + (context, input, output), LIBND4J_TYPES); +} - BUILD_SINGLE_TEMPLATE(template int _matrixDiagPart, (sd::LaunchContext * context, const NDArray* input, NDArray* output), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template int _matrixDiagPart, + (sd::LaunchContext * context, const NDArray* input, + NDArray* output), + LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu index 6e70d4510ae8..f0fec4f52b94 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu @@ -18,80 +18,93 @@ // @author raver119@gmail.com // -#include #include - +#include namespace sd { namespace ops { namespace helpers { - template - static _CUDA_G void indicesFiller(void *vz, Nd4jLong const* zShapeInfo, Nd4jLong part, Nd4jLong bSize) { - auto z = reinterpret_cast(vz); +template +static _CUDA_G void indicesFiller(void* vz, Nd4jLong const* zShapeInfo, + Nd4jLong part, Nd4jLong bSize) { + auto z = reinterpret_cast(vz); - for (int b = blockIdx.x; b < bSize; b += gridDim.x) { - for (Nd4jLong e = threadIdx.x; e < part; e += blockDim.x) { - z[shape::getIndexOffset(e + b * part, zShapeInfo)] = static_cast(e); - } - } + for (int b = blockIdx.x; b < bSize; b += gridDim.x) { + for (Nd4jLong e = threadIdx.x; e < part; e += blockDim.x) { + z[shape::getIndexOffset(e + b * part, zShapeInfo)] = static_cast(e); } + } +} - template - static void maxPoolingFunctor_(sd::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices) { - int kY = params[0]; - int kX = params[1]; +template +static void maxPoolingFunctor_(sd::graph::Context& block, NDArray* input, + NDArray* values, std::vector const& params, + NDArray* indices) { + int kY = params[0]; + int kX = params[1]; - int sY = params[2]; - int sX = params[3]; + int sY = params[2]; + int sX = params[3]; - int pY = params[4]; - int pX = params[5]; + int pY = params[4]; + int pX = params[5]; - int dY = params[6]; - int dX = params[7]; + int dY = params[6]; + int dX = params[7]; - int oY = 0; - int oX = 0; + int oY = 0; + int oX = 0; - const int bSize = input->sizeAt(0); - const int inD = input->sizeAt(1); - const int inY = input->sizeAt(2); - const int inX = input->sizeAt(3); + const int bSize = input->sizeAt(0); + const int inD = input->sizeAt(1); + const int inY = input->sizeAt(2); + const int inX = input->sizeAt(3); - const bool isSameMode = params[8] != 0; + const bool isSameMode = params[8] != 0; - ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, inY, inX, isSameMode); + ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, + inY, inX, isSameMode); - if (isSameMode) - ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, params[0], params[1], params[2], params[3], params[6], params[7]); + if (isSameMode) + ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, params[0], + params[1], params[2], params[3], params[6], + params[7]); - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling2d(block, *input, *values, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::MAX_POOL, 1); + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; + ConvolutionUtils::pooling2d(block, *input, *values, kY, kX, sY, sX, pY, pX, + dY, dX, PoolingType::MAX_POOL, 1); - if (nullptr != indices) { - // for max_pool_with_argmax - auto total = input->lengthOf(); - auto part = total / bSize; + if (nullptr != indices) { + // for max_pool_with_argmax + auto total = input->lengthOf(); + auto part = total / bSize; - indicesFiller<<<256, 256, 1024, *block.launchContext()->getCudaStream()>>>(indices->specialBuffer(), indices->specialShapeInfo(), part, bSize); + indicesFiller + <<<256, 256, 1024, *block.launchContext()->getCudaStream()>>>( + indices->specialBuffer(), indices->specialShapeInfo(), part, bSize); - /* - for (int k = 0; k < total; ) - for (int i = 0; i < part; i++) { - indices->p(k++, i); - } - */ + /* + for (int k = 0; k < total; ) + for (int i = 0; i < part; i++) { + indices->p(k++, i); } - } - - void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices) { - NDArray::prepareSpecialUse({values, indices}, {input}); - auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType(); - BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({values, indices}, {input}); - } - + */ + } } + +void maxPoolingFunctor(sd::LaunchContext* context, sd::graph::Context& block, + NDArray* input, NDArray* values, + std::vector const& params, NDArray* indices) { + NDArray::prepareSpecialUse({values, indices}, {input}); + auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType(); + BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, + (block, input, values, params, indices), FLOAT_TYPES, + INDEXING_TYPES); + NDArray::registerSpecialUse({values, indices}, {input}); } -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu b/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu index c4c8783ffe38..2177eb0282e4 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu @@ -18,90 +18,87 @@ // @author sgazeos@gmail.com // -#include #include #include - +#include namespace sd { - namespace ops { - namespace helpers { - - template - void maximumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { - - auto lambdaX = LAMBDA_TTT(_e, _x, _y) { - return _x >= _y ? _e : (T) 0.; - }; +namespace ops { +namespace helpers { - auto lambdaY = LAMBDA_TTT(_e, _x, _y) { - return _x <= _y ? _e : (T) 0.; - }; +template +void maximumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, + NDArray* gradY) { + auto lambdaX = LAMBDA_TTT(_e, _x, _y) { return _x >= _y ? _e : (T)0.; }; + auto lambdaY = LAMBDA_TTT(_e, _x, _y) { return _x <= _y ? _e : (T)0.; }; - if (x->isSameShape(y)) { - // PWT case case + if (x->isSameShape(y)) { + // PWT case case - // X gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); + // X gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); - // Y gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); + // Y gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); - } else if (y->isScalar()) { - T s = y->e(0); - auto lambdaS = LAMBDA_TT(_e, _x, s) { - return _x >= s ? _e : (T) 0.; - }; + } else if (y->isScalar()) { + T s = y->e(0); + auto lambdaS = LAMBDA_TT(_e, _x, s) { return _x >= s ? _e : (T)0.; }; - // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - if (x <= y) - gradY->assign(tmp); - else - gradY->assign(0.0f); + // scalar case + auto tmp = epsNext->reduceNumber(reduce::Sum); + if (x <= y) + gradY->assign(tmp); + else + gradY->assign(0.0f); - epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); - } else { - // broadcast case + epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); + } else { + // broadcast case - // in this case we want to boost our X and Y shapes to the size of FF pass output (or epsNext, which has the same shape) - auto preX = x->dup(); - auto preY = y->dup(); + // in this case we want to boost our X and Y shapes to the size of FF pass + // output (or epsNext, which has the same shape) + auto preX = x->dup(); + auto preY = y->dup(); - auto targetShape = epsNext->getShapeAsVector(); + auto targetShape = epsNext->getShapeAsVector(); - preX.tileToShape(targetShape, preX); - preY.tileToShape(targetShape, preY); + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); - epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); - epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); + epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); + epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } - } + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } +} - void maximumBPFunctor(sd::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { - NDArray::prepareSpecialUse({gradX, gradY}, {x, y, epsNext}); +void maximumBPFunctor(sd::LaunchContext* context, NDArray* x, NDArray* y, + NDArray* epsNext, NDArray* gradX, NDArray* gradY) { + NDArray::prepareSpecialUse({gradX, gradY}, {x, y, epsNext}); - BUILD_SINGLE_SELECTOR(x->dataType(), maximumBPFunctor_, (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); + BUILD_SINGLE_SELECTOR(x->dataType(), maximumBPFunctor_, + (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); - NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext}); - } + NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext}); +} - } - } -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/merge.cu b/libnd4j/include/ops/declarable/helpers/cuda/merge.cu index 3c580ee339eb..383cfb657905 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/merge.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/merge.cu @@ -14,533 +14,599 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 - // +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 +// - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include + +#include namespace sd { - namespace ops { - namespace helpers { - ////////////////////////////////////////////////////////////////////////// - template - static __global__ void mergeMaxIndexCudaLauncher(void** inArrs, void** inShapes, const int numArrays, void* voutput, const Nd4jLong* outputShape, Nd4jLong length) { - auto output = reinterpret_cast(voutput); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - - for (Nd4jLong e = tid; e < length; e += step) { - T mVal = -DataTypeUtils::max(); - Z mIdx(0); - - for (int i = 0; i < numArrays; i++) { - auto x = reinterpret_cast(inArrs[i]); - auto xShape = reinterpret_cast(inShapes[i]); - auto val = x[shape::getIndexOffset(e, xShape)];; - if (mVal < val) { - mIdx = static_cast(i); - mVal = val; - } - } - - output[shape::getIndexOffset(e, outputShape)] = mIdx; - } - } - - template - static void mergeMaxIndex_(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - int nArrSize = static_cast(inArrs.size()); - std::vector inBuffers(nArrSize), inShapes(nArrSize); - - for (int e = 0; e < nArrSize; e++) { - inBuffers[e] = inArrs[e]->specialBuffer(); - inShapes[e] = inArrs[e]->specialShapeInfo(); - } - - PointersManager manager(context, "mergeMaxIndex"); - - auto pInBuffers = reinterpret_cast(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void*))); - auto pInShapes = reinterpret_cast(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void*))); - auto length = output.lengthOf(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; - - mergeMaxIndexCudaLauncher<<getCudaStream()>>>(pInBuffers, pInShapes, nArrSize, output.specialBuffer(), output.specialShapeInfo(), length); - - manager.synchronize(); - } - - void mergeMaxIndex(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - NDArray::prepareSpecialUse({ &output }, inArrs); - - BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (context, inArrs, output), LIBND4J_TYPES, INDEXING_TYPES); - - NDArray::registerSpecialUse({ &output }, inArrs); - } - - - ////////////////////////////////////////////////////////////////////////// - template - static __global__ void mergeMaxCudaLauncher(void** inArrs, void** inShapes, const int numArrays, void* voutput, const Nd4jLong* outputShape, Nd4jLong length) { - auto output = reinterpret_cast(voutput); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - - for (Nd4jLong e = tid; e < length; e += step) { - T mVal = -DataTypeUtils::max(); - - for (int i = 0; i < numArrays; i++) { - auto x = reinterpret_cast(inArrs[i]); - auto xShape = reinterpret_cast(inShapes[i]); - auto val = x[shape::getIndexOffset(e, xShape)];; - if (mVal < val) - mVal = val; - } - - output[shape::getIndexOffset(e, outputShape)] = mVal; - } - } - - template - static void mergeMax_(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - int nArrsSize = static_cast(inArrs.size()); - - std::vector inBuffers(nArrsSize), inShapes(nArrsSize); - - for (int e = 0; e < nArrsSize; e++) { - inBuffers[e] = inArrs[e]->specialBuffer(); - inShapes[e] = inArrs[e]->specialShapeInfo(); - } - - PointersManager manager(context, "mergeMax"); - - auto pInBuffers = reinterpret_cast(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void*))); - auto pInShapes = reinterpret_cast(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void*))); - auto length = output.lengthOf(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; - - mergeMaxCudaLauncher<<getCudaStream()>>>(pInBuffers, pInShapes, nArrsSize, output.specialBuffer(), output.specialShapeInfo(), length); - - manager.synchronize(); - } - - void mergeMax(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - NDArray::prepareSpecialUse({ &output }, inArrs); - - BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (context, inArrs, output), LIBND4J_TYPES); - - NDArray::registerSpecialUse({ &output }, inArrs); - } - - ////////////////////////////////////////////////////////////////////////// - template - static __global__ void mergeMaxBpCudaLauncher( - void** inArrs, void** inShapes, - const void* vgradient, const Nd4jLong* gradientShape, - const int numArrays, - void** outArrs, void** outShapes, - Nd4jLong length, - bool bSameOrderAndEws1) { +namespace ops { +namespace helpers { +////////////////////////////////////////////////////////////////////////// +template +static __global__ void mergeMaxIndexCudaLauncher(void** inArrs, void** inShapes, + const int numArrays, + void* voutput, + const Nd4jLong* outputShape, + Nd4jLong length) { + auto output = reinterpret_cast(voutput); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + for (Nd4jLong e = tid; e < length; e += step) { + T mVal = -DataTypeUtils::max(); + Z mIdx(0); + + for (int i = 0; i < numArrays; i++) { + auto x = reinterpret_cast(inArrs[i]); + auto xShape = reinterpret_cast(inShapes[i]); + auto val = x[shape::getIndexOffset(e, xShape)]; + ; + if (mVal < val) { + mIdx = static_cast(i); + mVal = val; + } + } + + output[shape::getIndexOffset(e, outputShape)] = mIdx; + } +} - auto grad = reinterpret_cast(vgradient); +template +static void mergeMaxIndex_(sd::LaunchContext* context, + const std::vector& inArrs, + NDArray& output) { + int nArrSize = static_cast(inArrs.size()); + std::vector inBuffers(nArrSize), inShapes(nArrSize); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; + for (int e = 0; e < nArrSize; e++) { + inBuffers[e] = inArrs[e]->specialBuffer(); + inShapes[e] = inArrs[e]->specialShapeInfo(); + } - int coords[MAX_RANK]; + PointersManager manager(context, "mergeMaxIndex"); - for (Nd4jLong e = tid; e < length; e += step) { + auto pInBuffers = reinterpret_cast(manager.replicatePointer( + inBuffers.data(), inBuffers.size() * sizeof(void*))); + auto pInShapes = reinterpret_cast(manager.replicatePointer( + inShapes.data(), inShapes.size() * sizeof(void*))); + auto length = output.lengthOf(); - T mVal = -DataTypeUtils::max(); - int nMaxIndex = 0; - auto xOffset = e, zOffset = e, gradOffset = e; + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; + + mergeMaxIndexCudaLauncher + <<getCudaStream()>>>( + pInBuffers, pInShapes, nArrSize, output.specialBuffer(), + output.specialShapeInfo(), length); + + manager.synchronize(); +} - if (!bSameOrderAndEws1) { - shape::index2coords(e, gradientShape, coords); - gradOffset = shape::getOffset(gradientShape, coords); - } +void mergeMaxIndex(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, inArrs); - for (int i = 0; i < numArrays; i++) { - auto x = reinterpret_cast(inArrs[i]); + BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), + mergeMaxIndex_, (context, inArrs, output), + LIBND4J_TYPES, INDEXING_TYPES); - if (!bSameOrderAndEws1) { - auto xShape = reinterpret_cast(inShapes[i]); - xOffset = shape::getOffset(xShape, coords); - } + NDArray::registerSpecialUse({&output}, inArrs); +} - auto val = x[xOffset]; - if (mVal < val) { - mVal = val; - nMaxIndex = i; - } - } - - // outputs have to be pre-nullify - if (!bSameOrderAndEws1) { - auto outShape = reinterpret_cast(outShapes[nMaxIndex]); - zOffset = shape::getOffset(outShape, coords); - } +////////////////////////////////////////////////////////////////////////// +template +static __global__ void mergeMaxCudaLauncher(void** inArrs, void** inShapes, + const int numArrays, void* voutput, + const Nd4jLong* outputShape, + Nd4jLong length) { + auto output = reinterpret_cast(voutput); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + for (Nd4jLong e = tid; e < length; e += step) { + T mVal = -DataTypeUtils::max(); + + for (int i = 0; i < numArrays; i++) { + auto x = reinterpret_cast(inArrs[i]); + auto xShape = reinterpret_cast(inShapes[i]); + auto val = x[shape::getIndexOffset(e, xShape)]; + ; + if (mVal < val) mVal = val; + } - auto output = reinterpret_cast(outArrs[nMaxIndex]); + output[shape::getIndexOffset(e, outputShape)] = mVal; + } +} - output[zOffset] = grad[gradOffset]; - } - } +template +static void mergeMax_(sd::LaunchContext* context, + const std::vector& inArrs, + NDArray& output) { + int nArrsSize = static_cast(inArrs.size()); - template - static void mergeMaxBp_(sd::LaunchContext* context, const std::vector& inArrs, std::vector& outArrs, int nArrSize, bool bSameOrderAndEws1) { + std::vector inBuffers(nArrsSize), inShapes(nArrsSize); - std::vector inBuffers(nArrSize), inShapes(nArrSize), outBuffers(nArrSize), outShapes(nArrSize); + for (int e = 0; e < nArrsSize; e++) { + inBuffers[e] = inArrs[e]->specialBuffer(); + inShapes[e] = inArrs[e]->specialShapeInfo(); + } - for (int e = 0; e < nArrSize; e++) { - inBuffers[e] = inArrs[e]->specialBuffer(); - inShapes[e] = inArrs[e]->specialShapeInfo(); - outBuffers[e] = outArrs[e]->specialBuffer(); - outShapes[e] = outArrs[e]->specialShapeInfo(); - } + PointersManager manager(context, "mergeMax"); - PointersManager manager(context, "mergeMaxBp"); + auto pInBuffers = reinterpret_cast(manager.replicatePointer( + inBuffers.data(), inBuffers.size() * sizeof(void*))); + auto pInShapes = reinterpret_cast(manager.replicatePointer( + inShapes.data(), inShapes.size() * sizeof(void*))); + auto length = output.lengthOf(); - auto pInBuffers = reinterpret_cast(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void*))); - auto pInShapes = reinterpret_cast(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void*))); + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; - auto pOutBuffers = reinterpret_cast(manager.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void*))); - auto pOutShapes = reinterpret_cast(manager.replicatePointer(outShapes.data(), outShapes.size() * sizeof(void*))); + mergeMaxCudaLauncher + <<getCudaStream()>>>( + pInBuffers, pInShapes, nArrsSize, output.specialBuffer(), + output.specialShapeInfo(), length); - auto length = inArrs[nArrSize]->lengthOf(); + manager.synchronize(); +} - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; +void mergeMax(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, inArrs); - mergeMaxBpCudaLauncher<<getCudaStream()>>>(pInBuffers, pInShapes, inArrs[nArrSize]->specialBuffer(), - inArrs[nArrSize]->specialShapeInfo(), nArrSize, pOutBuffers, pOutShapes, - length, bSameOrderAndEws1); - - manager.synchronize(); - } + BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (context, inArrs, output), + LIBND4J_TYPES); - void mergeMaxBp(sd::LaunchContext* context, const std::vector& inArrs, std::vector& outArrs) { + NDArray::registerSpecialUse({&output}, inArrs); +} - // not use gradient - int nArrSize = static_cast(inArrs.size() - 1); - - const std::vector& out = reinterpret_cast&>(outArrs); +////////////////////////////////////////////////////////////////////////// +template +static __global__ void mergeMaxBpCudaLauncher( + void** inArrs, void** inShapes, const void* vgradient, + const Nd4jLong* gradientShape, const int numArrays, void** outArrs, + void** outShapes, Nd4jLong length, bool bSameOrderAndEws1) { + auto grad = reinterpret_cast(vgradient); - NDArray::prepareSpecialUse(out, inArrs); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; - bool bSameOrderAndEws1 = (1 == inArrs[nArrSize]->ews()); - auto ordering = inArrs[nArrSize]->ordering(); - - for (int i = 0; i < nArrSize; ++i) { - bSameOrderAndEws1 &= (ordering == inArrs[i]->ordering()); - bSameOrderAndEws1 &= (1 == inArrs[i]->ews()); - - bSameOrderAndEws1 &= (ordering == outArrs[i]->ordering()); - bSameOrderAndEws1 &= (1 == outArrs[i]->ews()); - } + int coords[MAX_RANK]; - BUILD_SINGLE_SELECTOR(inArrs[nArrSize]->dataType(), mergeMaxBp_, (context, inArrs, outArrs, nArrSize, bSameOrderAndEws1), LIBND4J_TYPES); + for (Nd4jLong e = tid; e < length; e += step) { + T mVal = -DataTypeUtils::max(); + int nMaxIndex = 0; + auto xOffset = e, zOffset = e, gradOffset = e; - NDArray::registerSpecialUse( out, inArrs ); - } + if (!bSameOrderAndEws1) { + shape::index2coords(e, gradientShape, coords); + gradOffset = shape::getOffset(gradientShape, coords); + } + for (int i = 0; i < numArrays; i++) { + auto x = reinterpret_cast(inArrs[i]); - ////////////////////////////////////////////////////////////////////////// - template - static __global__ void mergeAvgCudaLauncher(void** inArrs, void** inShapes, const int numArrays, void* voutput, const Nd4jLong* outputShape, Nd4jLong length) { - auto output = reinterpret_cast(voutput); + if (!bSameOrderAndEws1) { + auto xShape = reinterpret_cast(inShapes[i]); + xOffset = shape::getOffset(xShape, coords); + } - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; + auto val = x[xOffset]; + if (mVal < val) { + mVal = val; + nMaxIndex = i; + } + } - for (Nd4jLong e = tid; e < length; e += step) { - T sum(0.0f); + // outputs have to be pre-nullify + if (!bSameOrderAndEws1) { + auto outShape = reinterpret_cast(outShapes[nMaxIndex]); + zOffset = shape::getOffset(outShape, coords); + } - for (int i = 0; i < numArrays; i++) { - auto x = reinterpret_cast(inArrs[i]); - auto xShape = reinterpret_cast(inShapes[i]); + auto output = reinterpret_cast(outArrs[nMaxIndex]); - sum += x[shape::getIndexOffset(e, xShape)]; - } + output[zOffset] = grad[gradOffset]; + } +} - output[shape::getIndexOffset(e, outputShape)] = sum / numArrays; - } - } +template +static void mergeMaxBp_(sd::LaunchContext* context, + const std::vector& inArrs, + std::vector& outArrs, int nArrSize, + bool bSameOrderAndEws1) { + std::vector inBuffers(nArrSize), inShapes(nArrSize), + outBuffers(nArrSize), outShapes(nArrSize); + + for (int e = 0; e < nArrSize; e++) { + inBuffers[e] = inArrs[e]->specialBuffer(); + inShapes[e] = inArrs[e]->specialShapeInfo(); + outBuffers[e] = outArrs[e]->specialBuffer(); + outShapes[e] = outArrs[e]->specialShapeInfo(); + } + + PointersManager manager(context, "mergeMaxBp"); + + auto pInBuffers = reinterpret_cast(manager.replicatePointer( + inBuffers.data(), inBuffers.size() * sizeof(void*))); + auto pInShapes = reinterpret_cast(manager.replicatePointer( + inShapes.data(), inShapes.size() * sizeof(void*))); + + auto pOutBuffers = reinterpret_cast(manager.replicatePointer( + outBuffers.data(), outBuffers.size() * sizeof(void*))); + auto pOutShapes = reinterpret_cast(manager.replicatePointer( + outShapes.data(), outShapes.size() * sizeof(void*))); + + auto length = inArrs[nArrSize]->lengthOf(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; + + mergeMaxBpCudaLauncher + <<getCudaStream()>>>( + pInBuffers, pInShapes, inArrs[nArrSize]->specialBuffer(), + inArrs[nArrSize]->specialShapeInfo(), nArrSize, pOutBuffers, + pOutShapes, length, bSameOrderAndEws1); + + manager.synchronize(); +} - template - static void mergeAvg_(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - std::vector inBuffers(inArrs.size()), inShapes(inArrs.size()); +void mergeMaxBp(sd::LaunchContext* context, + const std::vector& inArrs, + std::vector& outArrs) { + // not use gradient + int nArrSize = static_cast(inArrs.size() - 1); - for (int e = 0; e < inArrs.size(); e++) { - inBuffers[e] = inArrs[e]->specialBuffer(); - inShapes[e] = inArrs[e]->specialShapeInfo(); - } + const std::vector& out = + reinterpret_cast&>(outArrs); - PointersManager manager(context, "mergeAvg"); + NDArray::prepareSpecialUse(out, inArrs); - auto pInBuffers = reinterpret_cast(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void*))); - auto pInShapes = reinterpret_cast(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void*))); - auto length = output.lengthOf(); + bool bSameOrderAndEws1 = (1 == inArrs[nArrSize]->ews()); + auto ordering = inArrs[nArrSize]->ordering(); - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; + for (int i = 0; i < nArrSize; ++i) { + bSameOrderAndEws1 &= (ordering == inArrs[i]->ordering()); + bSameOrderAndEws1 &= (1 == inArrs[i]->ews()); - mergeAvgCudaLauncher<<getCudaStream()>>>(pInBuffers, pInShapes, (int)inArrs.size(), output.specialBuffer(), output.specialShapeInfo(), length); + bSameOrderAndEws1 &= (ordering == outArrs[i]->ordering()); + bSameOrderAndEws1 &= (1 == outArrs[i]->ews()); + } - manager.synchronize(); - } + BUILD_SINGLE_SELECTOR(inArrs[nArrSize]->dataType(), mergeMaxBp_, + (context, inArrs, outArrs, nArrSize, bSameOrderAndEws1), + LIBND4J_TYPES); - void mergeAvg(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - NDArray::prepareSpecialUse({ &output }, inArrs); + NDArray::registerSpecialUse(out, inArrs); +} - BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), FLOAT_TYPES); +////////////////////////////////////////////////////////////////////////// +template +static __global__ void mergeAvgCudaLauncher(void** inArrs, void** inShapes, + const int numArrays, void* voutput, + const Nd4jLong* outputShape, + Nd4jLong length) { + auto output = reinterpret_cast(voutput); - NDArray::registerSpecialUse({ &output }, inArrs); - } - ////////////////////////////////////////////////////////////////////////// - template - static __global__ void mergeAvgBpCudaLauncher( - const void* vgradient, const Nd4jLong* gradientShape, - void** outArrs, void** outShapes, - const int numArrays, - Nd4jLong length, - bool bSameOrderAndEws1) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; - auto grad = reinterpret_cast(vgradient); + for (Nd4jLong e = tid; e < length; e += step) { + T sum(0.0f); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; + for (int i = 0; i < numArrays; i++) { + auto x = reinterpret_cast(inArrs[i]); + auto xShape = reinterpret_cast(inShapes[i]); - int coords[MAX_RANK]; + sum += x[shape::getIndexOffset(e, xShape)]; + } - for (Nd4jLong e = tid; e < length; e += step) { + output[shape::getIndexOffset(e, outputShape)] = sum / numArrays; + } +} - auto zOffset = e, gradOffset = e; - if (!bSameOrderAndEws1) { - shape::index2coords(e, gradientShape, coords); - gradOffset = shape::getOffset(gradientShape, coords); - } +template +static void mergeAvg_(sd::LaunchContext* context, + const std::vector& inArrs, + NDArray& output) { + std::vector inBuffers(inArrs.size()), inShapes(inArrs.size()); - for (int i = 0; i < numArrays; i++) { + for (int e = 0; e < inArrs.size(); e++) { + inBuffers[e] = inArrs[e]->specialBuffer(); + inShapes[e] = inArrs[e]->specialShapeInfo(); + } - if (!bSameOrderAndEws1) { - auto outShape = reinterpret_cast(outShapes[i]); - zOffset = shape::getOffset(outShape, coords); - } + PointersManager manager(context, "mergeAvg"); - auto output = reinterpret_cast(outArrs[i]); + auto pInBuffers = reinterpret_cast(manager.replicatePointer( + inBuffers.data(), inBuffers.size() * sizeof(void*))); + auto pInShapes = reinterpret_cast(manager.replicatePointer( + inShapes.data(), inShapes.size() * sizeof(void*))); + auto length = output.lengthOf(); - output[zOffset] = grad[gradOffset] / numArrays; - } - } - } + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; - template - static void mergeAvgBp_(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs, bool bSameOrderAndEws1) { + mergeAvgCudaLauncher + <<getCudaStream()>>>( + pInBuffers, pInShapes, (int)inArrs.size(), output.specialBuffer(), + output.specialShapeInfo(), length); - int nArrSize = static_cast(outArrs.size()); + manager.synchronize(); +} - std::vector outBuffers(nArrSize), outShapes(nArrSize); +void mergeAvg(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, inArrs); - for (int e = 0; e < nArrSize; e++) { - outBuffers[e] = outArrs[e]->specialBuffer(); - outShapes[e] = outArrs[e]->specialShapeInfo(); - } + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), + FLOAT_TYPES); - PointersManager manager(context, "mergeAvgBp"); + NDArray::registerSpecialUse({&output}, inArrs); +} +////////////////////////////////////////////////////////////////////////// +template +static __global__ void mergeAvgBpCudaLauncher(const void* vgradient, + const Nd4jLong* gradientShape, + void** outArrs, void** outShapes, + const int numArrays, + Nd4jLong length, + bool bSameOrderAndEws1) { + auto grad = reinterpret_cast(vgradient); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + int coords[MAX_RANK]; + + for (Nd4jLong e = tid; e < length; e += step) { + auto zOffset = e, gradOffset = e; + if (!bSameOrderAndEws1) { + shape::index2coords(e, gradientShape, coords); + gradOffset = shape::getOffset(gradientShape, coords); + } - auto pOutBuffers = reinterpret_cast(manager.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void*))); - auto pOutShapes = reinterpret_cast(manager.replicatePointer(outShapes.data(), outShapes.size() * sizeof(void*))); + for (int i = 0; i < numArrays; i++) { + if (!bSameOrderAndEws1) { + auto outShape = reinterpret_cast(outShapes[i]); + zOffset = shape::getOffset(outShape, coords); + } - auto length = gradient.lengthOf(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; + auto output = reinterpret_cast(outArrs[i]); - mergeAvgBpCudaLauncher<<getCudaStream()>>>(gradient.specialBuffer(), gradient.specialShapeInfo(), - pOutBuffers, pOutShapes, nArrSize, length, bSameOrderAndEws1); + output[zOffset] = grad[gradOffset] / numArrays; + } + } +} - manager.synchronize(); - } +template +static void mergeAvgBp_(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs, + bool bSameOrderAndEws1) { + int nArrSize = static_cast(outArrs.size()); - void mergeAvgBp(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs) { + std::vector outBuffers(nArrSize), outShapes(nArrSize); - const std::vector& out = reinterpret_cast&>(outArrs); + for (int e = 0; e < nArrSize; e++) { + outBuffers[e] = outArrs[e]->specialBuffer(); + outShapes[e] = outArrs[e]->specialShapeInfo(); + } - NDArray::prepareSpecialUse( out, { &gradient }); + PointersManager manager(context, "mergeAvgBp"); - bool bSameOrderAndEws1 = (1 == gradient.ews()); - auto ordering = gradient.ordering(); + auto pOutBuffers = reinterpret_cast(manager.replicatePointer( + outBuffers.data(), outBuffers.size() * sizeof(void*))); + auto pOutShapes = reinterpret_cast(manager.replicatePointer( + outShapes.data(), outShapes.size() * sizeof(void*))); - for (const auto& v : outArrs) { - bSameOrderAndEws1 &= (ordering == v->ordering()); - bSameOrderAndEws1 &= (1 == v->ews()); - } + auto length = gradient.lengthOf(); - BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAvgBp_, (context, gradient, outArrs, bSameOrderAndEws1), LIBND4J_TYPES); + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; - NDArray::prepareSpecialUse(out, { &gradient }); - } + mergeAvgBpCudaLauncher + <<getCudaStream()>>>( + gradient.specialBuffer(), gradient.specialShapeInfo(), pOutBuffers, + pOutShapes, nArrSize, length, bSameOrderAndEws1); - ////////////////////////////////////////////////////////////////////////// - template - static __global__ void mergeAddCudaLauncher(void** inArrs, void** inShapes, const int numArrays, void* voutput, const Nd4jLong* outputShape, Nd4jLong length) { - - auto output = reinterpret_cast(voutput); + manager.synchronize(); +} - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; +void mergeAvgBp(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs) { + const std::vector& out = + reinterpret_cast&>(outArrs); - for (Nd4jLong e = tid; e < length; e += step) { - T sum(0.0f); + NDArray::prepareSpecialUse(out, {&gradient}); - for (int i = 0; i < numArrays; i++) { - auto x = reinterpret_cast(inArrs[i]); - auto xShape = reinterpret_cast(inShapes[i]); + bool bSameOrderAndEws1 = (1 == gradient.ews()); + auto ordering = gradient.ordering(); - sum += x[shape::getIndexOffset(e, xShape)]; - } + for (const auto& v : outArrs) { + bSameOrderAndEws1 &= (ordering == v->ordering()); + bSameOrderAndEws1 &= (1 == v->ews()); + } - output[shape::getIndexOffset(e, outputShape)] = sum; - } - } + BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAvgBp_, + (context, gradient, outArrs, bSameOrderAndEws1), + LIBND4J_TYPES); - template - static void mergeAdd_(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - int nArrSize = static_cast(inArrs.size()); - std::vector inBuffers(nArrSize), inShapes(nArrSize); + NDArray::prepareSpecialUse(out, {&gradient}); +} - for (int e = 0; e < nArrSize; e++) { - inBuffers[e] = inArrs[e]->specialBuffer(); - inShapes[e] = inArrs[e]->specialShapeInfo(); - } +////////////////////////////////////////////////////////////////////////// +template +static __global__ void mergeAddCudaLauncher(void** inArrs, void** inShapes, + const int numArrays, void* voutput, + const Nd4jLong* outputShape, + Nd4jLong length) { + auto output = reinterpret_cast(voutput); - PointersManager manager(context, "mergeAdd"); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; - auto pInBuffers = reinterpret_cast(manager.replicatePointer(inBuffers.data(), inBuffers.size() * sizeof(void*))); - auto pInShapes = reinterpret_cast(manager.replicatePointer(inShapes.data(), inShapes.size() * sizeof(void*))); - auto length = output.lengthOf(); + for (Nd4jLong e = tid; e < length; e += step) { + T sum(0.0f); - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; + for (int i = 0; i < numArrays; i++) { + auto x = reinterpret_cast(inArrs[i]); + auto xShape = reinterpret_cast(inShapes[i]); - mergeAddCudaLauncher<<getCudaStream()>>>(pInBuffers, pInShapes, nArrSize, output.specialBuffer(), output.specialShapeInfo(), length); + sum += x[shape::getIndexOffset(e, xShape)]; + } - manager.synchronize(); - } - BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (sd::LaunchContext* context, const std::vector& inArrs, NDArray& output), NUMERIC_TYPES); + output[shape::getIndexOffset(e, outputShape)] = sum; + } +} - void mergeAdd(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output) { - - NDArray::prepareSpecialUse({ &output }, inArrs); - - BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), NUMERIC_TYPES); +template +static void mergeAdd_(sd::LaunchContext* context, + const std::vector& inArrs, + NDArray& output) { + int nArrSize = static_cast(inArrs.size()); + std::vector inBuffers(nArrSize), inShapes(nArrSize); - NDArray::registerSpecialUse({ &output }, inArrs); - } + for (int e = 0; e < nArrSize; e++) { + inBuffers[e] = inArrs[e]->specialBuffer(); + inShapes[e] = inArrs[e]->specialShapeInfo(); + } - ////////////////////////////////////////////////////////////////////////// - template - static __global__ void mergeAddBpCudaLauncher(const void* vgradient, const Nd4jLong* gradientShape, void** outArrs, void** outShapes, - const int numArrays, Nd4jLong length, bool bSameOrderAndEws1) { + PointersManager manager(context, "mergeAdd"); - auto grad = reinterpret_cast(vgradient); + auto pInBuffers = reinterpret_cast(manager.replicatePointer( + inBuffers.data(), inBuffers.size() * sizeof(void*))); + auto pInShapes = reinterpret_cast(manager.replicatePointer( + inShapes.data(), inShapes.size() * sizeof(void*))); + auto length = output.lengthOf(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; - int coords[MAX_RANK]; + mergeAddCudaLauncher + <<getCudaStream()>>>( + pInBuffers, pInShapes, nArrSize, output.specialBuffer(), + output.specialShapeInfo(), length); - for (Nd4jLong e = tid; e < length; e += step) { + manager.synchronize(); +} +BUILD_SINGLE_TEMPLATE(template void mergeAdd_, + (sd::LaunchContext * context, + const std::vector& inArrs, + NDArray& output), + NUMERIC_TYPES); - auto zOffset = e, gradOffset = e; - if (!bSameOrderAndEws1) { - shape::index2coords(e, gradientShape, coords); - gradOffset = shape::getOffset(gradientShape, coords); - } +void mergeAdd(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, inArrs); - for (int i = 0; i < numArrays; i++) { - - if (!bSameOrderAndEws1) { - auto outShape = reinterpret_cast(outShapes[i]); - zOffset = shape::getOffset(outShape, coords); - } + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), + NUMERIC_TYPES); - auto output = reinterpret_cast(outArrs[i]); + NDArray::registerSpecialUse({&output}, inArrs); +} - output[zOffset] = grad[gradOffset]; - } - } - } +////////////////////////////////////////////////////////////////////////// +template +static __global__ void mergeAddBpCudaLauncher(const void* vgradient, + const Nd4jLong* gradientShape, + void** outArrs, void** outShapes, + const int numArrays, + Nd4jLong length, + bool bSameOrderAndEws1) { + auto grad = reinterpret_cast(vgradient); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + int coords[MAX_RANK]; + + for (Nd4jLong e = tid; e < length; e += step) { + auto zOffset = e, gradOffset = e; + if (!bSameOrderAndEws1) { + shape::index2coords(e, gradientShape, coords); + gradOffset = shape::getOffset(gradientShape, coords); + } - template - static void mergeAddBp_(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs, bool bSameOrderAndEws1) { + for (int i = 0; i < numArrays; i++) { + if (!bSameOrderAndEws1) { + auto outShape = reinterpret_cast(outShapes[i]); + zOffset = shape::getOffset(outShape, coords); + } - int nArrSize = static_cast(outArrs.size()); + auto output = reinterpret_cast(outArrs[i]); - std::vector outBuffers(nArrSize), outShapes(nArrSize); + output[zOffset] = grad[gradOffset]; + } + } +} - for (int e = 0; e < nArrSize; e++) { - outBuffers[e] = outArrs[e]->specialBuffer(); - outShapes[e] = outArrs[e]->specialShapeInfo(); - } +template +static void mergeAddBp_(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs, + bool bSameOrderAndEws1) { + int nArrSize = static_cast(outArrs.size()); - PointersManager manager(context, "mergeAddBp"); + std::vector outBuffers(nArrSize), outShapes(nArrSize); - auto pOutBuffers = reinterpret_cast(manager.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void*))); - auto pOutShapes = reinterpret_cast(manager.replicatePointer(outShapes.data(), outShapes.size() * sizeof(void*))); + for (int e = 0; e < nArrSize; e++) { + outBuffers[e] = outArrs[e]->specialBuffer(); + outShapes[e] = outArrs[e]->specialShapeInfo(); + } - auto length = gradient.lengthOf(); + PointersManager manager(context, "mergeAddBp"); - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; + auto pOutBuffers = reinterpret_cast(manager.replicatePointer( + outBuffers.data(), outBuffers.size() * sizeof(void*))); + auto pOutShapes = reinterpret_cast(manager.replicatePointer( + outShapes.data(), outShapes.size() * sizeof(void*))); - mergeAddBpCudaLauncher<<getCudaStream()>>>(gradient.specialBuffer(), gradient.specialShapeInfo(), - pOutBuffers, pOutShapes, nArrSize, length, bSameOrderAndEws1); + auto length = gradient.lengthOf(); - manager.synchronize(); - } + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (length + threadsPerBlock - 1) / threadsPerBlock; - void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs) { + mergeAddBpCudaLauncher + <<getCudaStream()>>>( + gradient.specialBuffer(), gradient.specialShapeInfo(), pOutBuffers, + pOutShapes, nArrSize, length, bSameOrderAndEws1); - const std::vector& out = reinterpret_cast& >(outArrs); - NDArray::prepareSpecialUse( out, { &gradient }); + manager.synchronize(); +} - bool bSameOrderAndEws1 = (1 == gradient.ews()); - auto ordering = gradient.ordering(); +void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs) { + const std::vector& out = + reinterpret_cast&>(outArrs); + NDArray::prepareSpecialUse(out, {&gradient}); - for (const auto& v : outArrs) { - bSameOrderAndEws1 &= (ordering == v->ordering()); - bSameOrderAndEws1 &= (1 == v->ews()); - } + bool bSameOrderAndEws1 = (1 == gradient.ews()); + auto ordering = gradient.ordering(); - BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAddBp_, (context, gradient, outArrs, bSameOrderAndEws1), LIBND4J_TYPES); + for (const auto& v : outArrs) { + bSameOrderAndEws1 &= (ordering == v->ordering()); + bSameOrderAndEws1 &= (1 == v->ews()); + } - NDArray::prepareSpecialUse( out, { &gradient }); - } + BUILD_SINGLE_SELECTOR(gradient.dataType(), mergeAddBp_, + (context, gradient, outArrs, bSameOrderAndEws1), + LIBND4J_TYPES); - } - } + NDArray::prepareSpecialUse(out, {&gradient}); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu b/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu index 3f2ed13b5218..dce4e6df952b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu @@ -18,131 +18,144 @@ // @author raver119@gmail.com // - -#include -#include -#include #include +#include +#include +#include + #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - template - static _CUDA_D void assign_(void *vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); +template +static _CUDA_D void assign_(void *vx, Nd4jLong *xShapeInfo, void *vz, + Nd4jLong *zShapeInfo) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - auto xOrder = shape::order(xShapeInfo); - auto zOrder = shape::order(zShapeInfo); + auto tid = threadIdx.x + blockIdx.x * blockDim.x; - __shared__ Nd4jLong length; + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); - if (threadIdx.x == 0) { - length = shape::length(xShapeInfo); - } - __syncthreads(); + auto xOrder = shape::order(xShapeInfo); + auto zOrder = shape::order(zShapeInfo); - if (xEws > 0 && zEws > 0 && xOrder == zOrder) { - for (int i = threadIdx.x; i < length; i += blockDim.x) { - z[i * zEws] = x[i * xEws]; - } - } else { - for (int i = threadIdx.x; i < length; i += blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo); - auto zOffset = shape::getIndexOffset(i, zShapeInfo); + __shared__ Nd4jLong length; - z[zOffset] = x[xOffset]; - } - } + if (threadIdx.x == 0) { + length = shape::length(xShapeInfo); + } + __syncthreads(); + if (xEws > 0 && zEws > 0 && xOrder == zOrder) { + for (int i = threadIdx.x; i < length; i += blockDim.x) { + z[i * zEws] = x[i * xEws]; } + } else { + for (int i = threadIdx.x; i < length; i += blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo); + auto zOffset = shape::getIndexOffset(i, zShapeInfo); - template - static _CUDA_G void meshgridKernel(int rank, void **outBuffers, Nd4jLong **tadShapes, Nd4jLong **tadOffsets, Nd4jLong *numTads, void **inBuffers, Nd4jLong **inShapes) { - // for all arrays - for (int i = blockIdx.x; i < rank; i += gridDim.x) { - - // for all tads in this array - for(Nd4jLong j = 0; j < numTads[i]; j++) { - assign_(inBuffers[i], inShapes[i], reinterpret_cast(outBuffers[i]) + tadOffsets[i][j], tadShapes[i]); - } - __syncthreads(); - } + z[zOffset] = x[xOffset]; } + } +} - template - static void meshgrid_(sd::LaunchContext * context, const std::vector& inArrs, const std::vector& outArrs, const bool swapFirst2Dims) { - const int rank = inArrs.size(); - int inIndices[MAX_RANK]; - std::iota(inIndices, inIndices + rank, 0); - if(swapFirst2Dims && rank > 1) { - inIndices[0] = 1; - inIndices[1] = 0; - } - - PointersManager pm(context, "meshgrid"); - std::vector hInBuffers(rank); - std::vector hOutBuffers(rank); - std::vector hInShapes(rank); - - std::vector hOutTadShapes(rank); - std::vector hOutTadOffsets(rank); - - std::vector hNumTads(rank); - - for(int i = 0; i < rank; ++i) { - hInBuffers[i] = inArrs[i]->specialBuffer(); - hInShapes[i] = inArrs[i]->specialShapeInfo(); - - hOutBuffers[i] = outArrs[i]->specialBuffer(); - - - auto pack = ConstantTadHelper::getInstance()->tadForDimensions(outArrs[i]->shapeInfo(), {inIndices[i]}); - hOutTadShapes[i] = pack.specialShapeInfo(); - hOutTadOffsets[i] = pack.specialOffsets(); - hNumTads[i] = pack.numberOfTads(); - - - //auto list = outArrs[i]->allTensorsAlongDimension({inIndices[i]}); - //for(int j = 0; j < list->size(); ++j) - // list->at(j)->assign(inArrs[i]); - - //delete list; - } - - auto dInBuffers = reinterpret_cast(pm.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void *))); - auto dOutBuffers = reinterpret_cast(pm.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void *))); - - - auto dInShapes = reinterpret_cast(pm.replicatePointer(hInShapes.data(), hInShapes.size() * sizeof(Nd4jLong *))); - auto dOutTadShapes = reinterpret_cast(pm.replicatePointer(hOutTadShapes.data(), hOutTadShapes.size() * sizeof(Nd4jLong *))); - auto dOutTadOffsets = reinterpret_cast(pm.replicatePointer(hOutTadOffsets.data(), hOutTadOffsets.size() * sizeof(Nd4jLong *))); - - auto dNumTads = reinterpret_cast(pm.replicatePointer(hNumTads.data(), hNumTads.size() * sizeof(Nd4jLong))); - - - meshgridKernel<<<256, 256, 1024, *context->getCudaStream()>>>(rank, dOutBuffers, dOutTadShapes, dOutTadOffsets, dNumTads, dInBuffers, dInShapes); - - pm.synchronize(); +template +static _CUDA_G void meshgridKernel(int rank, void **outBuffers, + Nd4jLong **tadShapes, Nd4jLong **tadOffsets, + Nd4jLong *numTads, void **inBuffers, + Nd4jLong **inShapes) { + // for all arrays + for (int i = blockIdx.x; i < rank; i += gridDim.x) { + // for all tads in this array + for (Nd4jLong j = 0; j < numTads[i]; j++) { + assign_(inBuffers[i], inShapes[i], + reinterpret_cast(outBuffers[i]) + tadOffsets[i][j], + tadShapes[i]); } + __syncthreads(); + } +} - ////////////////////////////////////////////////////////////////////////// - void meshgrid(sd::LaunchContext * context, const std::vector& inArrs, const std::vector& outArrs, const bool swapFirst2Dims) { - - BUILD_SINGLE_SELECTOR(inArrs.at(0)->dataType(), meshgrid_, (context, inArrs, outArrs, swapFirst2Dims), NUMERIC_TYPES); +template +static void meshgrid_(sd::LaunchContext *context, + const std::vector &inArrs, + const std::vector &outArrs, + const bool swapFirst2Dims) { + const int rank = inArrs.size(); + int inIndices[MAX_RANK]; + std::iota(inIndices, inIndices + rank, 0); + if (swapFirst2Dims && rank > 1) { + inIndices[0] = 1; + inIndices[1] = 0; + } + + PointersManager pm(context, "meshgrid"); + std::vector hInBuffers(rank); + std::vector hOutBuffers(rank); + std::vector hInShapes(rank); + + std::vector hOutTadShapes(rank); + std::vector hOutTadOffsets(rank); + + std::vector hNumTads(rank); + + for (int i = 0; i < rank; ++i) { + hInBuffers[i] = inArrs[i]->specialBuffer(); + hInShapes[i] = inArrs[i]->specialShapeInfo(); + + hOutBuffers[i] = outArrs[i]->specialBuffer(); + + auto pack = ConstantTadHelper::getInstance()->tadForDimensions( + outArrs[i]->shapeInfo(), {inIndices[i]}); + hOutTadShapes[i] = pack.specialShapeInfo(); + hOutTadOffsets[i] = pack.specialOffsets(); + hNumTads[i] = pack.numberOfTads(); + + // auto list = outArrs[i]->allTensorsAlongDimension({inIndices[i]}); + // for(int j = 0; j < list->size(); ++j) + // list->at(j)->assign(inArrs[i]); + + // delete list; + } + + auto dInBuffers = reinterpret_cast(pm.replicatePointer( + hInBuffers.data(), hInBuffers.size() * sizeof(void *))); + auto dOutBuffers = reinterpret_cast(pm.replicatePointer( + hOutBuffers.data(), hOutBuffers.size() * sizeof(void *))); + + auto dInShapes = reinterpret_cast(pm.replicatePointer( + hInShapes.data(), hInShapes.size() * sizeof(Nd4jLong *))); + auto dOutTadShapes = reinterpret_cast(pm.replicatePointer( + hOutTadShapes.data(), hOutTadShapes.size() * sizeof(Nd4jLong *))); + auto dOutTadOffsets = reinterpret_cast(pm.replicatePointer( + hOutTadOffsets.data(), hOutTadOffsets.size() * sizeof(Nd4jLong *))); + + auto dNumTads = reinterpret_cast( + pm.replicatePointer(hNumTads.data(), hNumTads.size() * sizeof(Nd4jLong))); + + meshgridKernel<<<256, 256, 1024, *context->getCudaStream()>>>( + rank, dOutBuffers, dOutTadShapes, dOutTadOffsets, dNumTads, dInBuffers, + dInShapes); + + pm.synchronize(); +} - for (auto v:outArrs) - v->tickWriteDevice(); - } +////////////////////////////////////////////////////////////////////////// +void meshgrid(sd::LaunchContext *context, const std::vector &inArrs, + const std::vector &outArrs, + const bool swapFirst2Dims) { + BUILD_SINGLE_SELECTOR(inArrs.at(0)->dataType(), meshgrid_, + (context, inArrs, outArrs, swapFirst2Dims), + NUMERIC_TYPES); -} -} + for (auto v : outArrs) v->tickWriteDevice(); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu b/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu index b43bb418eb39..078894e3f382 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu @@ -19,91 +19,88 @@ // #ifndef __MIN_I_MAX_H_HELPERS__ #define __MIN_I_MAX_H_HELPERS__ -#include #include #include +#include namespace sd { - namespace ops { - namespace helpers { - - template - void minimumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { - - auto lambdaX = LAMBDA_TTT(_e, _x, _y) { - return _x <= _y ? _e : (T) 0.; - }; - - auto lambdaY = LAMBDA_TTT(_e, _x, _y) { - return _x >= _y ? _e : (T) 0.; - }; - - - if (x->isSameShape(y)) { - // PWT case case - - // X gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); - - // Y gradient - epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); - - } else if (y->isScalar()) { - T s = y->e(0); - auto lambdaS = LAMBDA_TT(_e, _x, s) { - return _x <= s ? _e : (T) 0.; - }; - - // scalar case - auto tmp = epsNext->reduceNumber(reduce::Sum); - if (x <= y) - gradY->assign(tmp); - else - gradY->assign(0.0f); - - epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); - } else { - // broadcast case - - // in this case we want to boost our X and Y shapes to the size of FF pass output (or epsNext, which has the same shape) - auto preX = x->dup(); - auto preY = y->dup(); - - auto targetShape = epsNext->getShapeAsVector(); - - preX.tileToShape(targetShape, preX); - preY.tileToShape(targetShape, preY); - - epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); - epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); - - auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); - auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); - - if (axisX.size() > 0) { - auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); - gradX->assign(sum); - } else - gradX->assign(preX); - - if (axisY.size() > 0) { - auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); - gradY->assign(sum); - } else - gradY->assign(preY); - } - - } - - void minimumBPFunctor(sd::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY) { - NDArray::prepareSpecialUse({gradX, gradY}, {x, y, epsNext}); +namespace ops { +namespace helpers { + +template +void minimumBPFunctor_(NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, + NDArray* gradY) { + auto lambdaX = LAMBDA_TTT(_e, _x, _y) { return _x <= _y ? _e : (T)0.; }; + + auto lambdaY = LAMBDA_TTT(_e, _x, _y) { return _x >= _y ? _e : (T)0.; }; + + if (x->isSameShape(y)) { + // PWT case case + + // X gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaX, *gradX); + + // Y gradient + epsNext->applyTriplewiseLambda(*x, *y, lambdaY, *gradY); + + } else if (y->isScalar()) { + T s = y->e(0); + auto lambdaS = LAMBDA_TT(_e, _x, s) { return _x <= s ? _e : (T)0.; }; + + // scalar case + auto tmp = epsNext->reduceNumber(reduce::Sum); + if (x <= y) + gradY->assign(tmp); + else + gradY->assign(0.0f); + + epsNext->applyPairwiseLambda(*x, lambdaS, *gradX); + } else { + // broadcast case + + // in this case we want to boost our X and Y shapes to the size of FF pass + // output (or epsNext, which has the same shape) + auto preX = x->dup(); + auto preY = y->dup(); + + auto targetShape = epsNext->getShapeAsVector(); + + preX.tileToShape(targetShape, preX); + preY.tileToShape(targetShape, preY); + + epsNext->applyTriplewiseLambda(preX, preY, lambdaX, preX); + epsNext->applyTriplewiseLambda(preX, preY, lambdaY, preY); + + auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), + epsNext->shapeInfo()); + auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), + epsNext->shapeInfo()); + + if (axisX.size() > 0) { + auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); + gradX->assign(sum); + } else + gradX->assign(preX); + + if (axisY.size() > 0) { + auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); + gradY->assign(sum); + } else + gradY->assign(preY); + } +} - BUILD_SINGLE_SELECTOR(x->dataType(), minimumBPFunctor_, (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); +void minimumBPFunctor(sd::LaunchContext* context, NDArray* x, NDArray* y, + NDArray* epsNext, NDArray* gradX, NDArray* gradY) { + NDArray::prepareSpecialUse({gradX, gradY}, {x, y, epsNext}); - NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext}); - } + BUILD_SINGLE_SELECTOR(x->dataType(), minimumBPFunctor_, + (x, y, epsNext, gradX, gradY), NUMERIC_TYPES); - } - } + NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext}); } + +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu b/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu index c3b4abc51777..222d3ae897c5 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu @@ -18,72 +18,88 @@ // @author sgazeos@gmail.com // -#include -#include -#include +#include #include +#include +#include #include -#include +#include namespace sd { namespace ops { namespace helpers { - template - static __global__ void fillUpElementKernel(void* outputBuffer, Nd4jLong const* outputShapeInfo, void* inputBuffer, Nd4jLong const* inputShapeInfo, Nd4jLong const* pTadShape, Nd4jLong const* pTadOffsets, Nd4jLong n) { - __shared__ Nd4jLong bufferLength; - - auto z = reinterpret_cast(outputBuffer); - auto x = reinterpret_cast(inputBuffer); - - if (threadIdx.x == 0) - bufferLength = shape::length(outputShapeInfo); - - __syncthreads(); +template +static __global__ void fillUpElementKernel( + void* outputBuffer, Nd4jLong const* outputShapeInfo, void* inputBuffer, + Nd4jLong const* inputShapeInfo, Nd4jLong const* pTadShape, + Nd4jLong const* pTadOffsets, Nd4jLong n) { + __shared__ Nd4jLong bufferLength; - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - for (int t = tid; t < bufferLength; t += step) { - auto tX = x + pTadOffsets[t]; - z[shape::getIndexOffset(t, outputShapeInfo)] = tX[shape::getIndexOffset(n, pTadShape)]; //tX]; - } - } + auto z = reinterpret_cast(outputBuffer); + auto x = reinterpret_cast(inputBuffer); - template - void nthElementFunctor_(sd::LaunchContext * context, NDArray* input, Nd4jLong n, NDArray* output, bool reverse) { + if (threadIdx.x == 0) bufferLength = shape::length(outputShapeInfo); - NDArray::prepareSpecialUse({output}, {input}); - NDArray sortedVals(*input); - Nd4jPointer params[2]; - params[0] = context; - params[1] = context->getCudaStream(); - // Nth element in sorted sequence : basic algorithm sort and retrieve nth element in sorted - if (input->isVector()) { - sort(params, nullptr, sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), reverse); + __syncthreads(); - cudaMemcpy(reinterpret_cast(output->specialBuffer()), reinterpret_cast(sortedVals.specialBuffer()) + n, sizeof(T), cudaMemcpyDeviceToDevice); - } - else { // rank greater than 1 - std::vector lastDims({input->rankOf() - 1});// = ShapeUtils::evalDimsToExclude(input->rankOf(), {input->rankOf() - 1}); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + for (int t = tid; t < bufferLength; t += step) { + auto tX = x + pTadOffsets[t]; + z[shape::getIndexOffset(t, outputShapeInfo)] = + tX[shape::getIndexOffset(n, pTadShape)]; // tX]; + } +} - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(sortedVals.shapeInfo(), lastDims); +template +void nthElementFunctor_(sd::LaunchContext* context, NDArray* input, Nd4jLong n, + NDArray* output, bool reverse) { + NDArray::prepareSpecialUse({output}, {input}); + NDArray sortedVals(*input); + Nd4jPointer params[2]; + params[0] = context; + params[1] = context->getCudaStream(); + // Nth element in sorted sequence : basic algorithm sort and retrieve nth + // element in sorted + if (input->isVector()) { + sort(params, nullptr, sortedVals.shapeInfo(), sortedVals.specialBuffer(), + sortedVals.specialShapeInfo(), reverse); - auto pTadShape = packX.specialShapeInfo(); - auto pTadShapeH = packX.primaryShapeInfo(); - auto pTadOffsets = packX.specialOffsets(); - sortTad(params, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), lastDims.data(), lastDims.size(), pTadShape, pTadOffsets, reverse); - sortedVals.tickWriteDevice(); - sortedVals.syncToHost(); - auto stream = context->getCudaStream(); - fillUpElementKernel<<<32, 64, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), pTadShape, pTadOffsets, n); - } - NDArray::registerSpecialUse({output}, {input}); - } - void nthElementFunctor(sd::LaunchContext * context, NDArray* input, Nd4jLong n, NDArray* output, bool reverse) { - BUILD_SINGLE_SELECTOR(input->dataType(), nthElementFunctor_, (context, input, n, output, reverse), LIBND4J_TYPES); + cudaMemcpy(reinterpret_cast(output->specialBuffer()), + reinterpret_cast(sortedVals.specialBuffer()) + n, sizeof(T), + cudaMemcpyDeviceToDevice); + } else { // rank greater than 1 + std::vector lastDims( + {input->rankOf() - + 1}); // = ShapeUtils::evalDimsToExclude(input->rankOf(), + // {input->rankOf() - 1}); - } + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + sortedVals.shapeInfo(), lastDims); + auto pTadShape = packX.specialShapeInfo(); + auto pTadShapeH = packX.primaryShapeInfo(); + auto pTadOffsets = packX.specialOffsets(); + sortTad(params, sortedVals.buffer(), sortedVals.shapeInfo(), + sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), + lastDims.data(), lastDims.size(), pTadShape, pTadOffsets, reverse); + sortedVals.tickWriteDevice(); + sortedVals.syncToHost(); + auto stream = context->getCudaStream(); + fillUpElementKernel<<<32, 64, 1024, *stream>>>( + output->specialBuffer(), output->specialShapeInfo(), + sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), pTadShape, + pTadOffsets, n); + } + NDArray::registerSpecialUse({output}, {input}); } +void nthElementFunctor(sd::LaunchContext* context, NDArray* input, Nd4jLong n, + NDArray* output, bool reverse) { + BUILD_SINGLE_SELECTOR(input->dataType(), nthElementFunctor_, + (context, input, n, output, reverse), LIBND4J_TYPES); } -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/one_hot.cu b/libnd4j/include/ops/declarable/helpers/cuda/one_hot.cu index f1520045955f..0ee16f576f90 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/one_hot.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/one_hot.cu @@ -18,93 +18,104 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 30.05.2019 // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include +#include -namespace sd { -namespace ops { -namespace helpers { +namespace sd { +namespace ops { +namespace helpers { /////////////////////////////////////////////////////////////////// // x - indices, z - output -template -__global__ static void onehotCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const uint axis, const uint depth, const Z on, const Z off) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ int xRank, zRank; - __shared__ Nd4jLong zLen, totalThreads, *sharedMem; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - xRank = shape::rank(xShapeInfo); - zRank = shape::rank(zShapeInfo); - zLen = shape::length(zShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); - - auto coord = sharedMem + threadIdx.x * zRank; - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { - - shape::index2coords(i, zShapeInfo, coord); - const auto zOffset = shape::getOffset(zShapeInfo, coord); - const auto depthCoord = coord[axis]; - - for (uint j = axis; j < zRank - 1; ++j) - coord[j] = coord[j + 1]; - - const auto xOffset = shape::getOffset(xShapeInfo, coord); - const Nd4jLong idx = x[xOffset]; - z[zOffset] = depthCoord == idx ? on : off; - } +template +__global__ static void onehotCuda(const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + const uint axis, const uint depth, const Z on, + const Z off) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ int xRank, zRank; + __shared__ Nd4jLong zLen, totalThreads, *sharedMem; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + xRank = shape::rank(xShapeInfo); + zRank = shape::rank(zShapeInfo); + zLen = shape::length(zShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); + + auto coord = sharedMem + threadIdx.x * zRank; + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, coord); + const auto zOffset = shape::getOffset(zShapeInfo, coord); + const auto depthCoord = coord[axis]; + + for (uint j = axis; j < zRank - 1; ++j) coord[j] = coord[j + 1]; + + const auto xOffset = shape::getOffset(xShapeInfo, coord); + const Nd4jLong idx = x[xOffset]; + z[zOffset] = depthCoord == idx ? on : off; + } } /////////////////////////////////////////////////////////////////// -template -static void onehotCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const uint axis, const uint depth, - const double on, const double off) { - - onehotCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, axis, depth, static_cast(on), static_cast(off)); +template +static void onehotCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, const uint axis, + const uint depth, const double on, + const double off) { + onehotCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, axis, depth, static_cast(on), + static_cast(off)); } /////////////////////////////////////////////////////////////////// -void onehot(const sd::LaunchContext* context, const NDArray *indices, NDArray *output, const uint axis, const uint depth, const double on, const double off) { - - const auto xType = indices->dataType(); - const auto zType = output->dataType(); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (output->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(decltype(*output->shapeInfo())) * output->rankOf() + 128; - - PointersManager manager(context, "onehot"); - - NDArray::prepareSpecialUse({output}, {indices}); - BUILD_DOUBLE_SELECTOR(xType, zType, onehotCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), axis, depth, on, off), LIBND4J_TYPES, LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {indices}); - - manager.synchronize(); +void onehot(const sd::LaunchContext *context, const NDArray *indices, + NDArray *output, const uint axis, const uint depth, const double on, + const double off) { + const auto xType = indices->dataType(); + const auto zType = output->dataType(); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (output->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * + sizeof(decltype(*output->shapeInfo())) * + output->rankOf() + + 128; + + PointersManager manager(context, "onehot"); + + NDArray::prepareSpecialUse({output}, {indices}); + BUILD_DOUBLE_SELECTOR(xType, zType, onehotCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + context->getCudaStream(), indices->specialBuffer(), + indices->specialShapeInfo(), output->specialBuffer(), + output->specialShapeInfo(), axis, depth, on, off), + LIBND4J_TYPES, LIBND4J_TYPES); + NDArray::registerSpecialUse({output}, {indices}); + + manager.synchronize(); } - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/pad.cu b/libnd4j/include/ops/declarable/helpers/cuda/pad.cu index 842a41ced7cf..5728b48c431d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/pad.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/pad.cu @@ -18,250 +18,281 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include + +#include namespace sd { - namespace ops { - namespace helpers { +namespace ops { +namespace helpers { /////////////////////////////////////////////////////////////////// // x - input, y - paddings, z - output - template - __global__ static void padCuda(const int mode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const void *vPadVal) { - - const X padVal = *reinterpret_cast(vPadVal); - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ int rank, rankMinusOne; - __shared__ Nd4jLong zLen, totalThreads, *coords, *xShape, *zShape, shift1, shift2, yStride0; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); - zLen = shape::length(zShapeInfo); - xShape = shape::shapeOf(const_cast(xShapeInfo)); - zShape = shape::shapeOf(const_cast(zShapeInfo)); - yStride0 = shape::stride(const_cast(yShapeInfo))[0]; - rank = shape::rank(xShapeInfo); - zLen = shape::length(zShapeInfo); - rankMinusOne = rank - 1; - totalThreads = gridDim.x * blockDim.x; - shift1 = mode == 1 ? 0 : 1; // REFLECT : SYMMETRIC - shift2 = mode == 1 ? 2 : 1; // REFLECT : SYMMETRIC - } - - __syncthreads(); - - auto xzCoord = coords + threadIdx.x * rank; // we use xzCoord storage both for x and z arrays - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - if(mode == 0) { // CONSTANT case - - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { - - shape::index2coords(i, zShapeInfo, xzCoord); - const auto zOffset = shape::getOffset(zShapeInfo, xzCoord); - - bool within = true; - for(int j = rankMinusOne; j >= 0; --j) { - if(xShape[j] == zShape[j]) continue; - const auto left = y[shape::getIndexOffset(yStride0 * j, yShapeInfo)]; - if(xzCoord[j] < left || xzCoord[j] >= left + xShape[j]) {within = false; break;} - else {xzCoord[j] = xzCoord[j] - left;} - } - - if(within) - z[zOffset] = x[shape::getOffset(xShapeInfo, xzCoord)]; - else - z[zOffset] = padVal; - } - } - else { // REFLECT and SYMMETRIC cases - - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { - - shape::index2coords(i, zShapeInfo, xzCoord); - const auto zOffset = shape::getOffset(zShapeInfo, xzCoord); - - for(int j = rankMinusOne; j >= 0; --j) { - - if(xShape[j] == zShape[j]) continue; - xzCoord[j] = xzCoord[j] - y[shape::getIndexOffset(yStride0 * j, yShapeInfo)]; // are ready to fill middle (within input dimension range) - if(xzCoord[j] < 0) xzCoord[j] = -xzCoord[j] - shift1; // means fill from left - else if(xzCoord[j] >= xShape[j]) xzCoord[j] = 2 * xShape[j] - xzCoord[j] - shift2; // means fill from right - } - - const auto xOffset = shape::getOffset(xShapeInfo, xzCoord); - z[zOffset] = x[xOffset]; - } - } - } - -/////////////////////////////////////////////////////////////////// - template - static void padCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const int mode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const void* padVal) { +template +__global__ static void padCuda(const int mode, const void* vx, + const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const void* vPadVal) { + const X padVal = *reinterpret_cast(vPadVal); + + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ int rank, rankMinusOne; + __shared__ Nd4jLong zLen, totalThreads, *coords, *xShape, *zShape, shift1, + shift2, yStride0; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + zLen = shape::length(zShapeInfo); + xShape = shape::shapeOf(const_cast(xShapeInfo)); + zShape = shape::shapeOf(const_cast(zShapeInfo)); + yStride0 = shape::stride(const_cast(yShapeInfo))[0]; + rank = shape::rank(xShapeInfo); + zLen = shape::length(zShapeInfo); + rankMinusOne = rank - 1; + totalThreads = gridDim.x * blockDim.x; + shift1 = mode == 1 ? 0 : 1; // REFLECT : SYMMETRIC + shift2 = mode == 1 ? 2 : 1; // REFLECT : SYMMETRIC + } + + __syncthreads(); + + auto xzCoord = + coords + + threadIdx.x * rank; // we use xzCoord storage both for x and z arrays + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (mode == 0) { // CONSTANT case + + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, xzCoord); + const auto zOffset = shape::getOffset(zShapeInfo, xzCoord); + + bool within = true; + for (int j = rankMinusOne; j >= 0; --j) { + if (xShape[j] == zShape[j]) continue; + const auto left = y[shape::getIndexOffset(yStride0 * j, yShapeInfo)]; + if (xzCoord[j] < left || xzCoord[j] >= left + xShape[j]) { + within = false; + break; + } else { + xzCoord[j] = xzCoord[j] - left; + } + } - padCuda<<>>(mode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, padVal); - } + if (within) + z[zOffset] = x[shape::getOffset(xShapeInfo, xzCoord)]; + else + z[zOffset] = padVal; + } + } else { // REFLECT and SYMMETRIC cases + + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, xzCoord); + const auto zOffset = shape::getOffset(zShapeInfo, xzCoord); + + for (int j = rankMinusOne; j >= 0; --j) { + if (xShape[j] == zShape[j]) continue; + xzCoord[j] = + xzCoord[j] - + y[shape::getIndexOffset( + yStride0 * j, yShapeInfo)]; // are ready to fill middle (within + // input dimension range) + if (xzCoord[j] < 0) + xzCoord[j] = -xzCoord[j] - shift1; // means fill from left + else if (xzCoord[j] >= xShape[j]) + xzCoord[j] = + 2 * xShape[j] - xzCoord[j] - shift2; // means fill from right + } + + const auto xOffset = shape::getOffset(xShapeInfo, xzCoord); + z[zOffset] = x[xOffset]; + } + } +} /////////////////////////////////////////////////////////////////// - void pad(sd::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, const NDArray& padValue) { - - PointersManager manager(context, "pad"); - - NDArray::prepareSpecialUse({&output}, {&input, &paddings, &padValue}); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = 8 * threadsPerBlock * output.rankOf() + 128; - - const auto xType = input.dataType(); - const auto yType = paddings.dataType(); - - BUILD_DOUBLE_SELECTOR(xType, yType, padCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), mode, input.specialBuffer(), input.specialShapeInfo(), paddings.specialBuffer(), paddings.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), padValue.specialBuffer()), LIBND4J_TYPES, INDEXING_TYPES); - - NDArray::registerSpecialUse({&output}, {&input, &paddings, &padValue}); - manager.synchronize(); - } - - - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static __global__ void mirrorPadLinearKernel(void const* vx, const Nd4jLong* xShape, void* vz, const Nd4jLong* zShape, Nd4jLong leftSide, Nd4jLong leftSideCorrected, Nd4jLong xLen, Nd4jLong len, Nd4jLong zLen) { - - __shared__ T const* x; - __shared__ T* z; - if (threadIdx.x == 0) { - x = reinterpret_cast(vx); - z = reinterpret_cast(vz); - } - __syncthreads(); - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for(int i = start; i < zLen; i+= step) { - auto zIndex = shape::getIndexOffset(i, zShape); - auto xIndex = shape::getIndexOffset(len - i, xShape); - - if (i < leftSide) // left side - xIndex = shape::getIndexOffset(leftSideCorrected - i, xShape); - - else if(i >= leftSide && i < leftSide + xLen) // middle - xIndex = shape::getIndexOffset(i - leftSide, xShape); - -// else // right side -// z[i] = x[len - i]; - z[zIndex] = x[xIndex]; - } - - } - - template - static __global__ void mirrorPadKernel(void const* vx, const Nd4jLong* xShape, void* vz, const Nd4jLong* zShape, Nd4jLong outLen, void const* paddings, const Nd4jLong* paddingShape, int reflBorder) { - - __shared__ F const* x; - __shared__ I const* pads; - __shared__ F* z; - __shared__ Nd4jLong zRank, rank; - __shared__ Nd4jLong* xIdx; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - xIdx = reinterpret_cast(shmem); - rank = shape::rank(xShape); - - x = reinterpret_cast(vx);// - pads = reinterpret_cast(paddings); - z = reinterpret_cast(vz); - } - __syncthreads(); - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for(Nd4jLong i = start; i < outLen; i+= step) { - auto xzCoord = xIdx + threadIdx.x * rank; - //auto zxCoord = xIdx + (threadIdx.x + threadIdx.x % 2 + 1) * rank; - - shape::index2coords(i, zShape, xzCoord); - auto outOffset = shape::getOffset(zShape, xzCoord); -// auto intStep = blockDim.y * gridDim.y; - for(int j = 0; j < rank; j++) { - - const Nd4jLong inLen = shape::sizeAt(xShape, j); - Nd4jLong coords[2] = {j, 0}; - auto padOffset = shape::getOffset(paddingShape, coords); // padding already has rank 2 - const auto leftSide = pads[padOffset]; - const auto leftSideCorrected = leftSide - reflBorder; - const Nd4jLong len = 2 * (inLen - 1) + leftSide + reflBorder; - - if(xzCoord[j] < leftSide) // left side - xzCoord[j] = leftSideCorrected - xzCoord[j]; - - else if(xzCoord[j] >= leftSide && xzCoord[j] < leftSide + inLen) // middle - xzCoord[j] = xzCoord[j] - leftSide; - - else if (len > xzCoord[j]) // right side - xzCoord[j] = len - xzCoord[j]; - else - xzCoord[j] = xzCoord[j] - len; - } - - auto inOffset = shape::getOffset(xShape, xzCoord); - z[outOffset] = x[inOffset]; - } - } - - template - static void mirrorPad_(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) { - // mode: 0 - REFLECT, else - SYMMETRIC - const int reflBorder = (bool)mode ? 1 : 0; - const int rank = input.rankOf(); - const Nd4jLong outLen = output.lengthOf(); - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({&output}, {&input, &paddings}); - - if(rank <= 1) { - - const Nd4jLong inLen = input.lengthOf(); - const auto leftSide = paddings.e(0); - const auto leftSideCorrected = leftSide - reflBorder; - const Nd4jLong len = 2*(inLen-1) + leftSide + reflBorder; - - mirrorPadLinearKernel<<<256, 512, 256, *stream>>>(input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftSide, leftSideCorrected, inLen, len, outLen); - sd::DebugHelper::checkErrorCode(stream, "helpers::mirrorPadLinearKernel(...) failed"); - } - else { - mirrorPadKernel<<<256, 256, 8192, *stream>>>(input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), outLen, paddings.specialBuffer(), paddings.specialShapeInfo(), reflBorder); - sd::DebugHelper::checkErrorCode(stream, "helpers::mirrorPadKernel(...) failed"); - } - NDArray::registerSpecialUse({&output}, {&input, &paddings}); - } - - void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) { - BUILD_DOUBLE_SELECTOR(input.dataType(), paddings.dataType(), mirrorPad_, (context, input, paddings, output, mode), LIBND4J_TYPES, INDEXING_TYPES); - } - +template +static void padCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const int mode, const void* vx, + const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const void* padVal) { + padCuda<<>>( + mode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, padVal); +} - } +/////////////////////////////////////////////////////////////////// +void pad(sd::LaunchContext* context, const int mode, const NDArray& input, + const NDArray& paddings, NDArray& output, const NDArray& padValue) { + PointersManager manager(context, "pad"); + + NDArray::prepareSpecialUse({&output}, {&input, &paddings, &padValue}); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = 8 * threadsPerBlock * output.rankOf() + 128; + + const auto xType = input.dataType(); + const auto yType = paddings.dataType(); + + BUILD_DOUBLE_SELECTOR(xType, yType, padCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, + context->getCudaStream(), mode, input.specialBuffer(), + input.specialShapeInfo(), paddings.specialBuffer(), + paddings.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), padValue.specialBuffer()), + LIBND4J_TYPES, INDEXING_TYPES); + + NDArray::registerSpecialUse({&output}, {&input, &paddings, &padValue}); + manager.synchronize(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static __global__ void mirrorPadLinearKernel( + void const* vx, const Nd4jLong* xShape, void* vz, const Nd4jLong* zShape, + Nd4jLong leftSide, Nd4jLong leftSideCorrected, Nd4jLong xLen, Nd4jLong len, + Nd4jLong zLen) { + __shared__ T const* x; + __shared__ T* z; + if (threadIdx.x == 0) { + x = reinterpret_cast(vx); + z = reinterpret_cast(vz); + } + __syncthreads(); + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int i = start; i < zLen; i += step) { + auto zIndex = shape::getIndexOffset(i, zShape); + auto xIndex = shape::getIndexOffset(len - i, xShape); + + if (i < leftSide) // left side + xIndex = shape::getIndexOffset(leftSideCorrected - i, xShape); + + else if (i >= leftSide && i < leftSide + xLen) // middle + xIndex = shape::getIndexOffset(i - leftSide, xShape); + + // else // right + // side + // z[i] = x[len - i]; + z[zIndex] = x[xIndex]; + } +} + +template +static __global__ void mirrorPadKernel(void const* vx, const Nd4jLong* xShape, + void* vz, const Nd4jLong* zShape, + Nd4jLong outLen, void const* paddings, + const Nd4jLong* paddingShape, + int reflBorder) { + __shared__ F const* x; + __shared__ I const* pads; + __shared__ F* z; + __shared__ Nd4jLong zRank, rank; + __shared__ Nd4jLong* xIdx; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + xIdx = reinterpret_cast(shmem); + rank = shape::rank(xShape); + + x = reinterpret_cast(vx); // + pads = reinterpret_cast(paddings); + z = reinterpret_cast(vz); + } + __syncthreads(); + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (Nd4jLong i = start; i < outLen; i += step) { + auto xzCoord = xIdx + threadIdx.x * rank; + // auto zxCoord = xIdx + (threadIdx.x + threadIdx.x % 2 + 1) * rank; + + shape::index2coords(i, zShape, xzCoord); + auto outOffset = shape::getOffset(zShape, xzCoord); + // auto intStep = blockDim.y * gridDim.y; + for (int j = 0; j < rank; j++) { + const Nd4jLong inLen = shape::sizeAt(xShape, j); + Nd4jLong coords[2] = {j, 0}; + auto padOffset = + shape::getOffset(paddingShape, coords); // padding already has rank 2 + const auto leftSide = pads[padOffset]; + const auto leftSideCorrected = leftSide - reflBorder; + const Nd4jLong len = 2 * (inLen - 1) + leftSide + reflBorder; + + if (xzCoord[j] < leftSide) // left side + xzCoord[j] = leftSideCorrected - xzCoord[j]; + + else if (xzCoord[j] >= leftSide && + xzCoord[j] < leftSide + inLen) // middle + xzCoord[j] = xzCoord[j] - leftSide; + + else if (len > xzCoord[j]) // right side + xzCoord[j] = len - xzCoord[j]; + else + xzCoord[j] = xzCoord[j] - len; } -} \ No newline at end of file + + auto inOffset = shape::getOffset(xShape, xzCoord); + z[outOffset] = x[inOffset]; + } +} + +template +static void mirrorPad_(sd::LaunchContext* context, const NDArray& input, + const NDArray& paddings, NDArray& output, + const int mode) { + // mode: 0 - REFLECT, else - SYMMETRIC + const int reflBorder = (bool)mode ? 1 : 0; + const int rank = input.rankOf(); + const Nd4jLong outLen = output.lengthOf(); + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({&output}, {&input, &paddings}); + + if (rank <= 1) { + const Nd4jLong inLen = input.lengthOf(); + const auto leftSide = paddings.e(0); + const auto leftSideCorrected = leftSide - reflBorder; + const Nd4jLong len = 2 * (inLen - 1) + leftSide + reflBorder; + + mirrorPadLinearKernel<<<256, 512, 256, *stream>>>( + input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), leftSide, leftSideCorrected, inLen, len, + outLen); + sd::DebugHelper::checkErrorCode( + stream, "helpers::mirrorPadLinearKernel(...) failed"); + } else { + mirrorPadKernel<<<256, 256, 8192, *stream>>>( + input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), outLen, paddings.specialBuffer(), + paddings.specialShapeInfo(), reflBorder); + sd::DebugHelper::checkErrorCode(stream, + "helpers::mirrorPadKernel(...) failed"); + } + NDArray::registerSpecialUse({&output}, {&input, &paddings}); +} + +void mirrorPad(sd::LaunchContext* context, const NDArray& input, + const NDArray& paddings, NDArray& output, const int mode) { + BUILD_DOUBLE_SELECTOR(input.dataType(), paddings.dataType(), mirrorPad_, + (context, input, paddings, output, mode), LIBND4J_TYPES, + INDEXING_TYPES); +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu b/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu index 7f2bcdcfdbd3..e08a38388b89 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/percentile.cu @@ -19,119 +19,137 @@ // @author raver119@gmail.com // -#include #include +#include #include #include -#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - template - static _CUDA_G void percentileKernel(void *vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, - const Nd4jLong numTads, const Nd4jLong tadLength, - void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong zLength, - const Nd4jLong position) { - for (int t = blockIdx.x; t < numTads; t += gridDim.x) { - auto x = reinterpret_cast(vx) + xTadOffsets[t]; - auto z = reinterpret_cast(vz); - - - // sort tad - if (tadLength > 1) { - for (int m = 0; m < tadLength; m++) { - if (m % 2 == 0) { - for (int tid = threadIdx.x; tid < tadLength; tid += blockDim.x) { - auto top = 2 * tid + 1; - if (top < tadLength) { - auto t0 = shape::getIndexOffset(top - 1, xTadShapeInfo); - auto t1 = shape::getIndexOffset(top, xTadShapeInfo); - - if (x[t0] > x[t1]) { - //swap values - X dz0 = x[t0]; - x[t0] = x[t1]; - x[t1] = dz0; - } - } - } - } else { - for (int tid = threadIdx.x; tid < tadLength; tid += blockDim.x) { - auto top = 2 * tid + 2; - if (top < tadLength) { - auto t0 = shape::getIndexOffset(top - 1, xTadShapeInfo); - auto t1 = shape::getIndexOffset(top, xTadShapeInfo); - - if (x[t0] > x[t1]) { - //swap values - X dz0 = x[t0]; - x[t0] = x[t1]; - x[t1] = dz0; - } - } - } - } - __syncthreads(); - } +template +static _CUDA_G void percentileKernel(void* vx, const Nd4jLong* xTadShapeInfo, + const Nd4jLong* xTadOffsets, + const Nd4jLong numTads, + const Nd4jLong tadLength, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong zLength, + const Nd4jLong position) { + for (int t = blockIdx.x; t < numTads; t += gridDim.x) { + auto x = reinterpret_cast(vx) + xTadOffsets[t]; + auto z = reinterpret_cast(vz); + + // sort tad + if (tadLength > 1) { + for (int m = 0; m < tadLength; m++) { + if (m % 2 == 0) { + for (int tid = threadIdx.x; tid < tadLength; tid += blockDim.x) { + auto top = 2 * tid + 1; + if (top < tadLength) { + auto t0 = shape::getIndexOffset(top - 1, xTadShapeInfo); + auto t1 = shape::getIndexOffset(top, xTadShapeInfo); + + if (x[t0] > x[t1]) { + // swap values + X dz0 = x[t0]; + x[t0] = x[t1]; + x[t1] = dz0; + } } - - // saving final value - if (threadIdx.x == 0) - z[shape::getIndexOffset(t, zShapeInfo)] = x[shape::getIndexOffset(position, xTadShapeInfo)]; - __syncthreads(); + } + } else { + for (int tid = threadIdx.x; tid < tadLength; tid += blockDim.x) { + auto top = 2 * tid + 2; + if (top < tadLength) { + auto t0 = shape::getIndexOffset(top - 1, xTadShapeInfo); + auto t1 = shape::getIndexOffset(top, xTadShapeInfo); + + if (x[t0] > x[t1]) { + // swap values + X dz0 = x[t0]; + x[t0] = x[t1]; + x[t1] = dz0; + } + } + } } + __syncthreads(); + } } + // saving final value + if (threadIdx.x == 0) + z[shape::getIndexOffset(t, zShapeInfo)] = + x[shape::getIndexOffset(position, xTadShapeInfo)]; + __syncthreads(); + } +} +template +static void _percentile(sd::LaunchContext* context, const NDArray& input, + NDArray& output, std::vector& axis, const float q, + const int interpolation) { + const int inputRank = input.rankOf(); + + if (axis.empty()) + for (int i = 0; i < inputRank; ++i) axis.push_back(i); + else + shape::checkDimensions(inputRank, axis); + + auto tempArray = input.dup(); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions( + tempArray.shapeInfo(), axis); + + auto tadLength = shape::length(packX.primaryShapeInfo()); + + const float fraction = 1.f - q / 100.; + Nd4jLong position = 0; + + switch (interpolation) { + case 0: // lower + position = static_cast( + math::nd4j_ceil((tadLength - 1) * fraction)); + break; + case 1: // higher + position = static_cast( + math::nd4j_floor((tadLength - 1) * fraction)); + break; + case 2: // nearest + position = static_cast( + math::nd4j_round((tadLength - 1) * fraction)); + break; + } + position = tadLength - position - 1; + + percentileKernel<<<256, 512, 1024, *context->getCudaStream()>>>( + tempArray.specialBuffer(), packX.platformShapeInfo(), + packX.platformOffsets(), packX.numberOfTads(), tadLength, + output.specialBuffer(), output.specialShapeInfo(), output.lengthOf(), + position); + + sd::DebugHelper::checkErrorCode(context->getCudaStream(), "percentile"); +} - template - static void _percentile(sd::LaunchContext * context, const NDArray& input, NDArray& output, std::vector& axis, const float q, const int interpolation) { - const int inputRank = input.rankOf(); - - if(axis.empty()) - for(int i=0; itadForDimensions(tempArray.shapeInfo(), axis); - - auto tadLength = shape::length(packX.primaryShapeInfo()); - - const float fraction = 1.f - q / 100.; - Nd4jLong position = 0; - - switch(interpolation) { - case 0: // lower - position = static_cast(math::nd4j_ceil((tadLength - 1) * fraction)); - break; - case 1: // higher - position = static_cast(math::nd4j_floor((tadLength - 1) * fraction)); - break; - case 2: // nearest - position = static_cast(math::nd4j_round((tadLength - 1) * fraction)); - break; - } - position = tadLength - position - 1; - - percentileKernel<<<256, 512, 1024, *context->getCudaStream()>>>(tempArray.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), packX.numberOfTads(), tadLength, output.specialBuffer(), output.specialShapeInfo(), output.lengthOf(), position); - - sd::DebugHelper::checkErrorCode(context->getCudaStream(), "percentile"); - } - - void percentile(sd::LaunchContext * context, const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation) { - NDArray::prepareSpecialUse({&output}, {&input}); +void percentile(sd::LaunchContext* context, const NDArray& input, + NDArray& output, std::vector& axises, const float q, + const int interpolation) { + NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), _percentile, (context, input, output, axises, q, interpolation), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), _percentile, + (context, input, output, axises, q, interpolation), + LIBND4J_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - } + NDArray::registerSpecialUse({&output}, {&input}); +} - BUILD_SINGLE_TEMPLATE(template void _percentile, (sd::LaunchContext * context, const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void _percentile, + (sd::LaunchContext * context, const NDArray& input, + NDArray& output, std::vector& axises, const float q, + const int interpolation), + LIBND4J_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu b/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu index 3e82632e27f8..ccf0cf6e234a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu @@ -18,86 +18,98 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 26.04.2019 // -#include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template +template __global__ static void polyGammaCuda(const void *vn, const Nd4jLong *nShapeInfo, - const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto n = reinterpret_cast(vn); - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong len; - __shared__ bool sameOffsetNX, sameOffsetNZ; - - if (threadIdx.x == 0) { - len = shape::length(nShapeInfo); - sameOffsetNX = shape::haveSameShapeAndStrides(xShapeInfo, nShapeInfo); - sameOffsetNZ = shape::haveSameShapeAndStrides(zShapeInfo, nShapeInfo); - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto totalThreads = gridDim.x * blockDim.x; - - for (int i = tid; i < len; i += totalThreads) { - - const auto nOffset = shape::getIndexOffset(i, nShapeInfo); - const auto xOffset = sameOffsetNX ? nOffset : shape::getIndexOffset(i, xShapeInfo); - const auto zOffset = sameOffsetNZ ? nOffset : shape::getIndexOffset(i, zShapeInfo); - - const T order = n[nOffset]; - - int sign = (static_cast(order) + 1) % 2 ? -1 : 1; - - if(order != static_cast(order)) { - z[zOffset] = DataTypeUtils::nanOrZero(); - } - else if(order == 0) { - z[zOffset] = diGammaScalar(x[xOffset]); - } - else { - T factorial = 1; - for(int i = 2; i <= order; ++i) - factorial *= i; - - z[zOffset] = sign * factorial * zetaScalar(order + 1, x[xOffset]); - } + const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + const auto n = reinterpret_cast(vn); + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong len; + __shared__ bool sameOffsetNX, sameOffsetNZ; + + if (threadIdx.x == 0) { + len = shape::length(nShapeInfo); + sameOffsetNX = shape::haveSameShapeAndStrides(xShapeInfo, nShapeInfo); + sameOffsetNZ = shape::haveSameShapeAndStrides(zShapeInfo, nShapeInfo); + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto totalThreads = gridDim.x * blockDim.x; + + for (int i = tid; i < len; i += totalThreads) { + const auto nOffset = shape::getIndexOffset(i, nShapeInfo); + const auto xOffset = + sameOffsetNX ? nOffset : shape::getIndexOffset(i, xShapeInfo); + const auto zOffset = + sameOffsetNZ ? nOffset : shape::getIndexOffset(i, zShapeInfo); + + const T order = n[nOffset]; + + int sign = (static_cast(order) + 1) % 2 ? -1 : 1; + + if (order != static_cast(order)) { + z[zOffset] = DataTypeUtils::nanOrZero(); + } else if (order == 0) { + z[zOffset] = diGammaScalar(x[xOffset]); + } else { + T factorial = 1; + for (int i = 2; i <= order; ++i) factorial *= i; + + z[zOffset] = sign * factorial * zetaScalar(order + 1, x[xOffset]); } + } } /////////////////////////////////////////////////////////////////// -template -static void polyGammaCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vn, const Nd4jLong *nShapeInfo, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - polyGammaCuda<<>>(vn, nShapeInfo, vx, xShapeInfo, vz, zShapeInfo); +template +static void polyGammaCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + const cudaStream_t *stream, const void *vn, + const Nd4jLong *nShapeInfo, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + polyGammaCuda<<>>( + vn, nShapeInfo, vx, xShapeInfo, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// -void polyGamma(sd::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& z) { +void polyGamma(sd::LaunchContext *context, const NDArray &n, const NDArray &x, + NDArray &z) { + NDArray::prepareSpecialUse({&z}, {&n, &x}); - NDArray::prepareSpecialUse({&z}, {&n, &x}); + int threadsPerBlock = MAX_NUM_THREADS / 2; + int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - int threadsPerBlock = MAX_NUM_THREADS / 2; - int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + BUILD_SINGLE_SELECTOR( + n.dataType(), polyGammaCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + n.specialBuffer(), n.specialShapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), z.specialBuffer(), z.specialShapeInfo()), + FLOAT_TYPES); - BUILD_SINGLE_SELECTOR(n.dataType(), polyGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), n.specialBuffer(), n.specialShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), z.specialBuffer(), z.specialShapeInfo()), FLOAT_TYPES); - - NDArray::registerSpecialUse({&z}, {&n, &x}); -} - -BUILD_SINGLE_TEMPLATE(template void polyGammaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vn, const Nd4jLong *nShapeInfo, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES); - -} -} + NDArray::registerSpecialUse({&z}, {&n, &x}); } +BUILD_SINGLE_TEMPLATE(template void polyGammaCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t *stream, const void *vn, + const Nd4jLong *nShapeInfo, const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu b/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu index d2832ec80931..0c960fa9b927 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu @@ -18,11 +18,11 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 12.06.2019 // -#include #include #include #include #include +#include namespace sd { namespace ops { @@ -30,148 +30,162 @@ namespace helpers { /////////////////////////////////////////////////////////////////// template -__global__ static void prefixPerBlockCuda(scalar::Ops op, - const void* vx, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numTads, const Nd4jLong tadLen, - const bool exclusive, const bool reverse) { - - __shared__ T *shared, lastElemInChunk; - __shared__ uint numTadChunks, blockDim2; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - shared = reinterpret_cast(shmem); - blockDim2 = 2 * blockDim.x; - numTadChunks = (tadLen + blockDim2 - 1) / blockDim2; // ceil +__global__ static void prefixPerBlockCuda( + scalar::Ops op, const void* vx, const Nd4jLong* xTadShapeInfo, + const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zTadShapeInfo, + const Nd4jLong* zTadOffsets, const Nd4jLong numTads, const Nd4jLong tadLen, + const bool exclusive, const bool reverse) { + __shared__ T *shared, lastElemInChunk; + __shared__ uint numTadChunks, blockDim2; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + shared = reinterpret_cast(shmem); + blockDim2 = 2 * blockDim.x; + numTadChunks = (tadLen + blockDim2 - 1) / blockDim2; // ceil + } + __syncthreads(); + + const auto xTad = reinterpret_cast(vx) + xTadOffsets[blockIdx.x]; + auto zTad = reinterpret_cast(vz) + zTadOffsets[blockIdx.x]; + + Nd4jLong sharedInd(2 * threadIdx.x), leftArrInd, rightArrInd, step; + T xLeft, xRight; + + for (uint i = 0; i < numTadChunks; ++i) { + leftArrInd = sharedInd + i * blockDim2; + rightArrInd = leftArrInd + 1; + + if (reverse) { + if (rightArrInd < tadLen) { + rightArrInd = tadLen - 1 - rightArrInd; + leftArrInd = tadLen - 1 - leftArrInd; + } else if (leftArrInd < tadLen) + leftArrInd = tadLen - 1 - leftArrInd; } - __syncthreads(); - const auto xTad = reinterpret_cast(vx) + xTadOffsets[blockIdx.x]; - auto zTad = reinterpret_cast(vz) + zTadOffsets[blockIdx.x]; - - Nd4jLong sharedInd(2 * threadIdx.x), leftArrInd, rightArrInd, step; - T xLeft, xRight; - - for (uint i = 0; i < numTadChunks; ++i) { - - leftArrInd = sharedInd + i * blockDim2; - rightArrInd = leftArrInd + 1; - - if(reverse) { - if(rightArrInd < tadLen) { - rightArrInd = tadLen - 1 - rightArrInd; - leftArrInd = tadLen - 1 - leftArrInd; - } - else if(leftArrInd < tadLen) - leftArrInd = tadLen - 1 - leftArrInd; - } - - if(leftArrInd < tadLen) - shared[sharedInd] = xLeft = xTad[shape::getIndexOffset(leftArrInd, xTadShapeInfo)]; - // else - // shared[sharedInd] = (op == scalar::Add) ? 0 : 1; - - if(rightArrInd < tadLen) - shared[sharedInd + 1] = xRight = xTad[shape::getIndexOffset(rightArrInd, xTadShapeInfo)]; - // else - // shared[sharedInd + 1] = (op == scalar::Add) ? 0 : 1; - - - step = 1; - - for (uint d = blockDim.x; d > 0; d /= 2) { - - __syncthreads(); - if(threadIdx.x < d) { - uint left = step * (sharedInd + 1) - 1; - uint right = step * (sharedInd + 2) - 1; - shared[right] = (op == scalar::Add) ? (shared[right] + shared[left]) : (shared[right] * shared[left]); - } - step *= 2; - } - - if (threadIdx.x == 0) - shared[blockDim2 - 1] = (op == scalar::Add) ? 0 : 1; - __syncthreads(); - - for (uint d = 1; d < blockDim2; d *= 2) { - - step /= 2; - - __syncthreads(); - if(threadIdx.x < d) { - uint left = step * (sharedInd + 1) - 1; - uint right = step * (sharedInd + 2) - 1; - T temp = shared[left]; - shared[left] = shared[right]; - shared[right] = (op == scalar::Add) ? (shared[right] + temp) : (shared[right] * temp); - } - } - - __syncthreads(); - - if(leftArrInd < tadLen) { - T result = shared[sharedInd]; - if(!exclusive) - result = (op == scalar::Add) ? result + xLeft : result * xLeft; - if(i > 0) - result = (op == scalar::Add) ? result + lastElemInChunk : result * lastElemInChunk; - zTad[shape::getIndexOffset(leftArrInd, zTadShapeInfo)] = result; - } - - if(rightArrInd < tadLen) { - T result = shared[sharedInd + 1]; - if(!exclusive) - result = (op == scalar::Add) ? result + xRight : result * xRight; - if(i > 0) - result = (op == scalar::Add) ? result + lastElemInChunk : result * lastElemInChunk; - if(i < numTadChunks - 1 && threadIdx.x == blockDim.x - 1) // last element in chunk - lastElemInChunk = !exclusive ? result : (op == scalar::Add) ? result + xRight : result * xRight; - zTad[shape::getIndexOffset(rightArrInd, zTadShapeInfo)] = result; - } + if (leftArrInd < tadLen) + shared[sharedInd] = xLeft = + xTad[shape::getIndexOffset(leftArrInd, xTadShapeInfo)]; + // else + // shared[sharedInd] = (op == scalar::Add) ? 0 : 1; + + if (rightArrInd < tadLen) + shared[sharedInd + 1] = xRight = + xTad[shape::getIndexOffset(rightArrInd, xTadShapeInfo)]; + // else + // shared[sharedInd + 1] = (op == scalar::Add) ? 0 : 1; + + step = 1; + + for (uint d = blockDim.x; d > 0; d /= 2) { + __syncthreads(); + if (threadIdx.x < d) { + uint left = step * (sharedInd + 1) - 1; + uint right = step * (sharedInd + 2) - 1; + shared[right] = (op == scalar::Add) ? (shared[right] + shared[left]) + : (shared[right] * shared[left]); + } + step *= 2; } -} - -/////////////////////////////////////////////////////////////////// -template -static void prefixPerBlockCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - scalar::Ops op, - const void* vx, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numTads, const Nd4jLong tadLen, - const bool exclusive, const bool reverse) { - - prefixPerBlockCuda<<>>(op, vx, xTadShapeInfo, xTadOffsets, vz, zTadShapeInfo, zTadOffsets, numTads, tadLen, exclusive, reverse); -} - -/////////////////////////////////////////////////////////////////// -void prefix(sd::LaunchContext * context, scalar::Ops op, const NDArray* x, NDArray* z, const std::vector& dims, bool exclusive, bool reverse) { - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x->shapeInfo(), dims); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(z->shapeInfo(), dims); - const Nd4jLong numTads = packX.numberOfTads(); - const Nd4jLong tadLen = x->lengthOf() / numTads; + if (threadIdx.x == 0) shared[blockDim2 - 1] = (op == scalar::Add) ? 0 : 1; + __syncthreads(); - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = numTads; - const int sharedMem = 2 * threadsPerBlock * x->sizeOfT() + 128; + for (uint d = 1; d < blockDim2; d *= 2) { + step /= 2; + + __syncthreads(); + if (threadIdx.x < d) { + uint left = step * (sharedInd + 1) - 1; + uint right = step * (sharedInd + 2) - 1; + T temp = shared[left]; + shared[left] = shared[right]; + shared[right] = (op == scalar::Add) ? (shared[right] + temp) + : (shared[right] * temp); + } + } - PointersManager manager(context, "prefix"); + __syncthreads(); - NDArray::prepareSpecialUse({z}, {x}); - BUILD_SINGLE_SELECTOR(x->dataType(), prefixPerBlockCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, x->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), z->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, tadLen, exclusive, reverse), NUMERIC_TYPES); - NDArray::registerSpecialUse({z}, {x}); + if (leftArrInd < tadLen) { + T result = shared[sharedInd]; + if (!exclusive) + result = (op == scalar::Add) ? result + xLeft : result * xLeft; + if (i > 0) + result = (op == scalar::Add) ? result + lastElemInChunk + : result * lastElemInChunk; + zTad[shape::getIndexOffset(leftArrInd, zTadShapeInfo)] = result; + } - manager.synchronize(); + if (rightArrInd < tadLen) { + T result = shared[sharedInd + 1]; + if (!exclusive) + result = (op == scalar::Add) ? result + xRight : result * xRight; + if (i > 0) + result = (op == scalar::Add) ? result + lastElemInChunk + : result * lastElemInChunk; + if (i < numTadChunks - 1 && + threadIdx.x == blockDim.x - 1) // last element in chunk + lastElemInChunk = !exclusive ? result + : (op == scalar::Add) ? result + xRight + : result * xRight; + zTad[shape::getIndexOffset(rightArrInd, zTadShapeInfo)] = result; + } + } } /////////////////////////////////////////////////////////////////// -void prefix(sd::LaunchContext * context, scalar::Ops op, const NDArray* x, NDArray* z, bool exclusive, bool reverse) { - prefix(context, op, x, z, {}, exclusive, reverse); +template +static void prefixPerBlockCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, scalar::Ops op, const void* vx, + const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, void* vz, + const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numTads, const Nd4jLong tadLen, const bool exclusive, + const bool reverse) { + prefixPerBlockCuda<<>>( + op, vx, xTadShapeInfo, xTadOffsets, vz, zTadShapeInfo, zTadOffsets, + numTads, tadLen, exclusive, reverse); } +/////////////////////////////////////////////////////////////////// +void prefix(sd::LaunchContext* context, scalar::Ops op, const NDArray* x, + NDArray* z, const std::vector& dims, bool exclusive, + bool reverse) { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x->shapeInfo(), dims); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + z->shapeInfo(), dims); + + const Nd4jLong numTads = packX.numberOfTads(); + const Nd4jLong tadLen = x->lengthOf() / numTads; + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = numTads; + const int sharedMem = 2 * threadsPerBlock * x->sizeOfT() + 128; + + PointersManager manager(context, "prefix"); + + NDArray::prepareSpecialUse({z}, {x}); + BUILD_SINGLE_SELECTOR( + x->dataType(), prefixPerBlockCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, + x->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), + z->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), + numTads, tadLen, exclusive, reverse), + NUMERIC_TYPES); + NDArray::registerSpecialUse({z}, {x}); + + manager.synchronize(); } + +/////////////////////////////////////////////////////////////////// +void prefix(sd::LaunchContext* context, scalar::Ops op, const NDArray* x, + NDArray* z, bool exclusive, bool reverse) { + prefix(context, op, x, z, {}, exclusive, reverse); } -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/print_variable.cu b/libnd4j/include/ops/declarable/helpers/cuda/print_variable.cu index 6733ce642087..2f60ce274562 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/print_variable.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/print_variable.cu @@ -18,44 +18,48 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - static _CUDA_G void print_device(const void *special, const Nd4jLong *shapeInfo) { - auto length = shape::length(shapeInfo); - auto x = reinterpret_cast(special); - - // TODO: add formatting here - printf("["); - - for (uint64_t e = 0; e < length; e++) { - printf("%f", (float) x[shape::getIndexOffset(e, shapeInfo)]); - - if (e < length - 1) - printf(", "); - } - - printf("]\n"); - } - - template - static _CUDA_H void exec_print_device(LaunchContext &ctx, const void *special, const Nd4jLong *shapeInfo) { - print_device<<<1, 1, 1024, *ctx.getCudaStream()>>>(special, shapeInfo); - } - - void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message) { - NDArray::prepareSpecialUse({}, {&array}); - - PointersManager pm(&ctx, "print_device"); - BUILD_SINGLE_SELECTOR(array.dataType(), exec_print_device, (ctx, array.specialBuffer(), array.specialShapeInfo()), LIBND4J_TYPES) - pm.synchronize(); - - NDArray::registerSpecialUse({}, {&array}); - } - } - } +namespace ops { +namespace helpers { +template +static _CUDA_G void print_device(const void *special, + const Nd4jLong *shapeInfo) { + auto length = shape::length(shapeInfo); + auto x = reinterpret_cast(special); + + // TODO: add formatting here + printf("["); + + for (uint64_t e = 0; e < length; e++) { + printf("%f", (float)x[shape::getIndexOffset(e, shapeInfo)]); + + if (e < length - 1) printf(", "); + } + + printf("]\n"); +} + +template +static _CUDA_H void exec_print_device(LaunchContext &ctx, const void *special, + const Nd4jLong *shapeInfo) { + print_device<<<1, 1, 1024, *ctx.getCudaStream()>>>(special, shapeInfo); +} + +void print_special(LaunchContext &ctx, const NDArray &array, + const std::string &message) { + NDArray::prepareSpecialUse({}, {&array}); + + PointersManager pm(&ctx, "print_device"); + BUILD_SINGLE_SELECTOR(array.dataType(), exec_print_device, + (ctx, array.specialBuffer(), array.specialShapeInfo()), + LIBND4J_TYPES) + pm.synchronize(); + + NDArray::registerSpecialUse({}, {&array}); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/qr.cu b/libnd4j/include/ops/declarable/helpers/cuda/qr.cu index 828867b4e2d0..1d802f5c1a31 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/qr.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/qr.cu @@ -17,163 +17,188 @@ // // @author George A. Shulinok // -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - template - static __global__ void matrixMinorKernel(T* outBuffer, Nd4jLong* outShape, T* inBuffer, Nd4jLong* inShape, Nd4jLong column, Nd4jLong rows, Nd4jLong columns) { -// auto tid = threadIdx.x + blockDim.x * blockIdx.x; -// auto step = blockDim.x * gridDim.x; -// if (threadIdx.x == 0) { -// for (auto i = tid; i < column; i += step) { -// Nd4jLong diagPos[] = {i, i}; -// auto zIndex = shape::getOffset(outShape, diagPos); -// outBuffer[zIndex] = T(1.f); -// } -// } -// __syncthreads(); - - for (auto i = blockIdx.x; i < rows; i += gridDim.x) - for (auto j = threadIdx.x; j < columns; j += blockDim.x) { - Nd4jLong pos[] = {i,j}; - auto zIndex = shape::getOffset(outShape, pos); - auto xIndex = shape::getOffset(inShape, pos); - if (i < column || j < column) { - outBuffer[zIndex] = i != j?T(0.f):T(1.f); - } - else - outBuffer[zIndex] = inBuffer[xIndex]; //m.t(i,j) = in.t(i,j); - } - - +template +static __global__ void matrixMinorKernel(T* outBuffer, Nd4jLong* outShape, + T* inBuffer, Nd4jLong* inShape, + Nd4jLong column, Nd4jLong rows, + Nd4jLong columns) { + // auto tid = threadIdx.x + blockDim.x * blockIdx.x; + // auto step = blockDim.x * gridDim.x; + // if (threadIdx.x == 0) { + // for (auto i = tid; i < column; i += step) { + // Nd4jLong diagPos[] = {i, i}; + // auto zIndex = shape::getOffset(outShape, diagPos); + // outBuffer[zIndex] = T(1.f); + // } + // } + // __syncthreads(); + + for (auto i = blockIdx.x; i < rows; i += gridDim.x) + for (auto j = threadIdx.x; j < columns; j += blockDim.x) { + Nd4jLong pos[] = {i, j}; + auto zIndex = shape::getOffset(outShape, pos); + auto xIndex = shape::getOffset(inShape, pos); + if (i < column || j < column) { + outBuffer[zIndex] = i != j ? T(0.f) : T(1.f); + } else + outBuffer[zIndex] = inBuffer[xIndex]; // m.t(i,j) = in.t(i,j); } +} - template - NDArray matrixMinor(LaunchContext* context, NDArray& in, Nd4jLong col) { - NDArray m = in.ulike(); - m.setIdentity(); - m({col, m.rows(), col, m.columns()}).assign(in({col, m.rows(), col, m.columns()})); - -// auto stream = context->getCudaStream(); -// matrixMinorKernel<<<128, 128, 256, *stream>>>(m.dataBuffer()->specialAsT(), m.specialShapeInfo(), -// matrixMinorKernel<<<128, 128, 256, *stream>>>(m.dataBuffer()->specialAsT(), m.specialShapeInfo(), -// reinterpret_cast(in.specialBuffer()), in.specialShapeInfo(), col, in.rows(), in.columns()); -// - m.tickWriteDevice(); - return m; - } +template +NDArray matrixMinor(LaunchContext* context, NDArray& in, Nd4jLong col) { + NDArray m = in.ulike(); + m.setIdentity(); + m({col, m.rows(), col, m.columns()}) + .assign(in({col, m.rows(), col, m.columns()})); + + // auto stream = context->getCudaStream(); + // matrixMinorKernel<<<128, 128, 256, + // *stream>>>(m.dataBuffer()->specialAsT(), m.specialShapeInfo(), + // matrixMinorKernel<<<128, 128, 256, + // *stream>>>(m.dataBuffer()->specialAsT(), m.specialShapeInfo(), + // reinterpret_cast(in.specialBuffer()), + // in.specialShapeInfo(), col, in.rows(), in.columns()); + // + m.tickWriteDevice(); + return m; +} /* m = I - v v^T */ - template - static __global__ void vmulKernel(T* resBuf, const Nd4jLong* resShape, T const* vBuff, Nd4jLong const* vShape, Nd4jLong n) { - for (auto i = blockIdx.x; i < n; i += gridDim.x) - for (auto j = threadIdx.x; j < n; j += blockDim.x) { - Nd4jLong posR[] = {i, j}; - auto indexR = shape::getOffset(resShape, posR); - auto indexX = shape::getIndexOffset(i, vShape); - auto indexY = shape::getIndexOffset(j, vShape); - - resBuf[indexR] = T(-2.f) * vBuff[indexX] * vBuff[indexY] + (i != j?T(0.f):T(1.f)); - } - } - - template - NDArray vmul(LaunchContext* context, NDArray const& v, int n) - { - NDArray res('c', {n,n}, v.dataType(), context); // x = matrix_new(n, n); - - auto stream = context->getCudaStream(); - vmulKernel<<<128, 128, 128, *stream>>>(res.dataBuffer()->specialAsT(), res.specialShapeInfo(), - reinterpret_cast(v.specialBuffer()), v.specialShapeInfo(), n); - return res; - } - - template - static bool diagonalIsPositive(NDArray* matrix, Nd4jLong k) { - T hVal; - Nd4jLong pos[] = {k, k}; - auto shift = shape::getOffset(matrix->shapeInfo(), pos); - cudaMemcpy(&hVal, matrix->specialBuffer(), sizeof(T), cudaMemcpyDeviceToHost); - return hVal > T(0.f); +template +static __global__ void vmulKernel(T* resBuf, const Nd4jLong* resShape, + T const* vBuff, Nd4jLong const* vShape, + Nd4jLong n) { + for (auto i = blockIdx.x; i < n; i += gridDim.x) + for (auto j = threadIdx.x; j < n; j += blockDim.x) { + Nd4jLong posR[] = {i, j}; + auto indexR = shape::getOffset(resShape, posR); + auto indexX = shape::getIndexOffset(i, vShape); + auto indexY = shape::getIndexOffset(j, vShape); + + resBuf[indexR] = + T(-2.f) * vBuff[indexX] * vBuff[indexY] + (i != j ? T(0.f) : T(1.f)); } +} - template - void qrSingle(LaunchContext* context, NDArray* matrix, NDArray* Q, NDArray* R, bool const fullMatricies) { - Nd4jLong M = matrix->sizeAt(0); - Nd4jLong N = matrix->sizeAt(1); - auto resQ = fullMatricies?Q->ulike():NDArrayFactory::create(matrix->ordering(), {M,M}, Q->getContext()); - auto resR = fullMatricies?R->ulike():matrix->ulike(); - std::vector q(M); - NDArray z = *matrix; - NDArray e('c', {M}, DataTypeUtils::fromT(), context); // two internal buffers and scalar for squared norm - for (auto k = 0; k < N && k < M - 1; k++) { // loop for columns, but not further then row number - e.nullify(); - z = matrixMinor(context, z, k); // minor computing for current column with given matrix z (initally is a input matrix) - - auto currentColumn = z({0, 0, k, k + 1}); // retrieve k column from z to x buffer - auto norm = currentColumn.reduceAlongDimension(reduce::Norm2, {0}); - if (diagonalIsPositive(matrix, k)) //matrix->t(k,k) > T(0.f)) // negate on positive matrix diagonal element - norm.applyTransform(transform::Neg, norm); // *= -1.f;//-norm.t(0); - - e.p(k, norm); // e - is filled by 0 vector except diagonal element (filled by 1) - e += currentColumn; // e[i] = x[i] + a * e[i] for each i from 0 to n - 1 - auto normE = e.reduceAlongDimension(reduce::Norm2, {0}); - e /= normE; - q[k] = vmul(context, e, M); - auto qQ = z.ulike(); - MmulHelper::matmul(&q[k], &z, &qQ, false, false); - z = std::move(qQ); - } - resQ.assign(q[0]); // -// MmulHelper::matmul(&q[0], matrix, &resR, false, false); - for (int i = 1; i < N && i < M - 1; i++) { - auto tempResQ = resQ; - MmulHelper::matmul(&q[i], &resQ, &tempResQ, false, false); - resQ = std::move(tempResQ); - } - MmulHelper::matmul(&resQ, matrix, &resR, false, false); - // resR *= -1.f; - resQ.transposei(); - - if (fullMatricies) { - Q->assign(resQ); - R->assign(resR); - } - else { - Q->assign(resQ({0, 0, 0, N})); - R->assign(resR({0, N, 0, 0})); - } - } +template +NDArray vmul(LaunchContext* context, NDArray const& v, int n) { + NDArray res('c', {n, n}, v.dataType(), context); // x = matrix_new(n, n); - template - void qr_(LaunchContext* context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) { - Nd4jLong lastDim = input->rankOf() - 1; - Nd4jLong preLastDim = input->rankOf() - 2; - - NDArray::prepareSpecialUse({outputQ, outputR}, {input}); - ResultSet listOutQ(outputQ->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - ResultSet listOutR(outputR->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - ResultSet listInput(input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); - auto start = 0; - auto stop = listInput.size(); - auto increment = 1; - - for (auto batch = start; batch < stop; batch += increment) { - //qr here - qrSingle(context, listInput.at(batch), listOutQ.at(batch), listOutR.at(batch), fullMatricies); - } - NDArray::registerSpecialUse({outputQ, outputR}, {input}); - } + auto stream = context->getCudaStream(); + vmulKernel<<<128, 128, 128, *stream>>>( + res.dataBuffer()->specialAsT(), res.specialShapeInfo(), + reinterpret_cast(v.specialBuffer()), v.specialShapeInfo(), n); + return res; +} - void qr(sd::LaunchContext* context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies) { - BUILD_SINGLE_SELECTOR(input->dataType(), qr_, (context, input, outputQ, outputR, fullMatricies), FLOAT_TYPES); - } +template +static bool diagonalIsPositive(NDArray* matrix, Nd4jLong k) { + T hVal; + Nd4jLong pos[] = {k, k}; + auto shift = shape::getOffset(matrix->shapeInfo(), pos); + cudaMemcpy(&hVal, matrix->specialBuffer(), sizeof(T), cudaMemcpyDeviceToHost); + return hVal > T(0.f); +} +template +void qrSingle(LaunchContext* context, NDArray* matrix, NDArray* Q, NDArray* R, + bool const fullMatricies) { + Nd4jLong M = matrix->sizeAt(0); + Nd4jLong N = matrix->sizeAt(1); + auto resQ = fullMatricies ? Q->ulike() + : NDArrayFactory::create( + matrix->ordering(), {M, M}, Q->getContext()); + auto resR = fullMatricies ? R->ulike() : matrix->ulike(); + std::vector q(M); + NDArray z = *matrix; + NDArray e('c', {M}, DataTypeUtils::fromT(), + context); // two internal buffers and scalar for squared norm + for (auto k = 0; k < N && k < M - 1; + k++) { // loop for columns, but not further then row number + e.nullify(); + z = matrixMinor(context, z, + k); // minor computing for current column with given + // matrix z (initally is a input matrix) + + auto currentColumn = + z({0, 0, k, k + 1}); // retrieve k column from z to x buffer + auto norm = currentColumn.reduceAlongDimension(reduce::Norm2, {0}); + if (diagonalIsPositive(matrix, + k)) // matrix->t(k,k) > T(0.f)) // negate on + // positive matrix diagonal element + norm.applyTransform(transform::Neg, norm); // *= -1.f;//-norm.t(0); + + e.p(k, norm); // e - is filled by 0 vector except diagonal element (filled + // by 1) + e += currentColumn; // e[i] = x[i] + a * e[i] for each i from 0 to n - 1 + auto normE = e.reduceAlongDimension(reduce::Norm2, {0}); + e /= normE; + q[k] = vmul(context, e, M); + auto qQ = z.ulike(); + MmulHelper::matmul(&q[k], &z, &qQ, false, false); + z = std::move(qQ); + } + resQ.assign(q[0]); // + // MmulHelper::matmul(&q[0], matrix, &resR, false, false); + for (int i = 1; i < N && i < M - 1; i++) { + auto tempResQ = resQ; + MmulHelper::matmul(&q[i], &resQ, &tempResQ, false, false); + resQ = std::move(tempResQ); + } + MmulHelper::matmul(&resQ, matrix, &resR, false, false); + // resR *= -1.f; + resQ.transposei(); + + if (fullMatricies) { + Q->assign(resQ); + R->assign(resR); + } else { + Q->assign(resQ({0, 0, 0, N})); + R->assign(resR({0, N, 0, 0})); + } } + +template +void qr_(LaunchContext* context, NDArray const* input, NDArray* outputQ, + NDArray* outputR, bool const fullMatricies) { + Nd4jLong lastDim = input->rankOf() - 1; + Nd4jLong preLastDim = input->rankOf() - 2; + + NDArray::prepareSpecialUse({outputQ, outputR}, {input}); + ResultSet listOutQ( + outputQ->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + ResultSet listOutR( + outputR->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + ResultSet listInput( + input->allTensorsAlongDimension({(int)preLastDim, (int)lastDim})); + auto start = 0; + auto stop = listInput.size(); + auto increment = 1; + + for (auto batch = start; batch < stop; batch += increment) { + // qr here + qrSingle(context, listInput.at(batch), listOutQ.at(batch), + listOutR.at(batch), fullMatricies); + } + NDArray::registerSpecialUse({outputQ, outputR}, {input}); } + +void qr(sd::LaunchContext* context, NDArray const* input, NDArray* outputQ, + NDArray* outputR, bool const fullMatricies) { + BUILD_SINGLE_SELECTOR(input->dataType(), qr_, + (context, input, outputQ, outputR, fullMatricies), + FLOAT_TYPES); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/random.cu b/libnd4j/include/ops/declarable/helpers/cuda/random.cu index fe692a0df5ae..34a480027b4c 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/random.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/random.cu @@ -20,343 +20,391 @@ #include //#include -#include -#include -#include -#include -#include #include #include +#include #include #include +#include +#include + +#include +#include namespace sd { namespace ops { namespace helpers { - /* - * fillGammaKernel - fill up output with gamma distributed values - * - * uList - uniformly distributed values set - * uLength - length of uList - * alpha - alpha param - * beta - beta param - * output - distributed output. - * */ - template - static __global__ void fillGammaKernel(T* uList, Nd4jLong uLength, T* alpha, const Nd4jLong* alphaShape, - T* beta, const Nd4jLong* betaShape, T* output, const Nd4jLong* outputShape) { - // fill up - __shared__ Nd4jLong aLength; - if (threadIdx.x == 0) { - aLength = shape::length(alphaShape); - } - __syncthreads(); - - for (auto k = blockIdx.x; k < (int)uLength; k += gridDim.x) { - auto pos = k * aLength; - auto u = uList[k]; // this is a vector - for (auto e = threadIdx.x; e < (int)aLength; e += blockDim.x) { - auto aIndex = shape::getIndexOffset(e, alphaShape); - auto bIndex = betaShape?shape::getIndexOffset(e, betaShape):-1LL; - auto betaV = T(beta != nullptr ? beta[bIndex] * u : u); - auto zIndex = shape::getIndexOffset(e + pos, outputShape); - - output[zIndex] = math::nd4j_igamma(alpha[aIndex], betaV); - } - } - } - - template - static void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) { - // To fill up output need to broadcast alpha and beta to the same shape and in - const Nd4jLong* broadcasted = nullptr; - if (beta != nullptr) - ShapeUtils::evalBroadcastShapeInfo(*alpha, *beta, true, broadcasted, context->getWorkspace()); - else - broadcasted = alpha->shapeInfo(); - auto step = shape::length(broadcasted); - auto shift = output->lengthOf() / step; - - auto copyAlpha = alpha; - auto copyBeta = beta; - if (beta != nullptr) { - NDArray alphaBroadcasted(broadcasted, alpha->dataType(), true, context); - NDArray betaBroadcasted(broadcasted, beta->dataType(), true, context); - - copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *alpha)); - copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta)); - copyAlpha->tickWriteDevice(); copyBeta->tickWriteDevice(); - } - - auto stream = context->getCudaStream(); - NDArray uniform = NDArrayFactory::create('c', {shift}, context); - uniform.syncToDevice(); - // fill up uniform with given length - RandomLauncher::fillUniform(context, rng, &uniform, 0., 1.); - - fillGammaKernel<<<128, 128, 256, *stream>>>(uniform.dataBuffer()->specialAsT(), shift, - copyAlpha->dataBuffer()->specialAsT(), copyAlpha->specialShapeInfo(), - beta?copyBeta->dataBuffer()->specialAsT():(T*)nullptr, - beta?copyBeta->specialShapeInfo():(Nd4jLong*)nullptr, - output->dataBuffer()->specialAsT(), output->specialShapeInfo()); - - if (beta != nullptr) { - delete copyAlpha; - delete copyBeta; - //delete broadcasted; - } - +/* + * fillGammaKernel - fill up output with gamma distributed values + * + * uList - uniformly distributed values set + * uLength - length of uList + * alpha - alpha param + * beta - beta param + * output - distributed output. + * */ +template +static __global__ void fillGammaKernel(T* uList, Nd4jLong uLength, T* alpha, + const Nd4jLong* alphaShape, T* beta, + const Nd4jLong* betaShape, T* output, + const Nd4jLong* outputShape) { + // fill up + __shared__ Nd4jLong aLength; + if (threadIdx.x == 0) { + aLength = shape::length(alphaShape); + } + __syncthreads(); + + for (auto k = blockIdx.x; k < (int)uLength; k += gridDim.x) { + auto pos = k * aLength; + auto u = uList[k]; // this is a vector + for (auto e = threadIdx.x; e < (int)aLength; e += blockDim.x) { + auto aIndex = shape::getIndexOffset(e, alphaShape); + auto bIndex = betaShape ? shape::getIndexOffset(e, betaShape) : -1LL; + auto betaV = T(beta != nullptr ? beta[bIndex] * u : u); + auto zIndex = shape::getIndexOffset(e + pos, outputShape); + + output[zIndex] = math::nd4j_igamma(alpha[aIndex], betaV); } + } +} - void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) { - if (beta) - NDArray::prepareSpecialUse({output}, {alpha, beta}); - else - NDArray::prepareSpecialUse({output}, {alpha}); - BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomGamma_, (context, rng, alpha, beta, output), FLOAT_NATIVE); - if (beta) - NDArray::registerSpecialUse({output}, {alpha, beta}); - else - NDArray::prepareSpecialUse({output}, {alpha}); - } - BUILD_SINGLE_TEMPLATE(template void fillRandomGamma_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output), FLOAT_NATIVE); - - - /* - * algorithm Poisson generator based upon the inversion by sequential search - * - init: - Let x ← 0, p ← e−λ, s ← p. - using uniformly random sequence U (u in U) distributed at [0, 1]. - while u > s do: - x ← x + 1. - p ← p * λ / x. - s ← s + p. - return x. - * */ - template - static __global__ void fillPoissonKernel(T* uList, Nd4jLong uLength, T* lambda, const Nd4jLong* lambdaShape, - T* output, const Nd4jLong* outputShape) { - - __shared__ Nd4jLong step; - - if (threadIdx.x == 0) { - step = shape::length(lambdaShape); - } - __syncthreads(); - - for (auto k = blockIdx.x; k < (int)uLength; k += gridDim.x) { - auto pos = k * step; - auto u = uList[k]; - for (auto e = threadIdx.x; e < step; e += blockDim.x) { - auto p = math::nd4j_exp(-lambda[e]); - auto s = p; - auto x = T(0.f); - auto lIndex = shape::getIndexOffset(e, lambdaShape); - auto zIndex = shape::getIndexOffset(e + pos, outputShape); - while (u > s) { - x += T(1.); - p *= lambda[lIndex] / x; - s += p; - } - output[zIndex] = x; - } - } - } +template +static void fillRandomGamma_(LaunchContext* context, + graph::RandomGenerator& rng, NDArray* alpha, + NDArray* beta, NDArray* output) { + // To fill up output need to broadcast alpha and beta to the same shape and in + const Nd4jLong* broadcasted = nullptr; + if (beta != nullptr) + ShapeUtils::evalBroadcastShapeInfo(*alpha, *beta, true, broadcasted, + context->getWorkspace()); + else + broadcasted = alpha->shapeInfo(); + auto step = shape::length(broadcasted); + auto shift = output->lengthOf() / step; + + auto copyAlpha = alpha; + auto copyBeta = beta; + if (beta != nullptr) { + NDArray alphaBroadcasted(broadcasted, alpha->dataType(), true, context); + NDArray betaBroadcasted(broadcasted, beta->dataType(), true, context); + + copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast( + BroadcastOpsTuple::Assign(), *alpha)); + copyBeta = new NDArray( + betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta)); + copyAlpha->tickWriteDevice(); + copyBeta->tickWriteDevice(); + } + + auto stream = context->getCudaStream(); + NDArray uniform = NDArrayFactory::create('c', {shift}, context); + uniform.syncToDevice(); + // fill up uniform with given length + RandomLauncher::fillUniform(context, rng, &uniform, 0., 1.); + + fillGammaKernel<<<128, 128, 256, *stream>>>( + uniform.dataBuffer()->specialAsT(), shift, + copyAlpha->dataBuffer()->specialAsT(), copyAlpha->specialShapeInfo(), + beta ? copyBeta->dataBuffer()->specialAsT() : (T*)nullptr, + beta ? copyBeta->specialShapeInfo() : (Nd4jLong*)nullptr, + output->dataBuffer()->specialAsT(), output->specialShapeInfo()); + + if (beta != nullptr) { + delete copyAlpha; + delete copyBeta; + // delete broadcasted; + } +} - template - static void fillRandomPoisson_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) { - auto shift = output->lengthOf() / lambda->lengthOf(); - NDArray uniform('c', {shift}, output->dataType()); - auto stream = context->getCudaStream(); - // fill up uniform with given length - RandomLauncher::fillUniform(context, rng, &uniform, 0., 1.); - fillPoissonKernel<<<128, 256, 128, *stream>>>(uniform.dataBuffer()->specialAsT(), uniform.lengthOf(), - lambda->dataBuffer()->specialAsT(), lambda->specialShapeInfo(), - output->dataBuffer()->specialAsT(), output->specialShapeInfo()); - } +void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* alpha, NDArray* beta, NDArray* output) { + if (beta) + NDArray::prepareSpecialUse({output}, {alpha, beta}); + else + NDArray::prepareSpecialUse({output}, {alpha}); + BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomGamma_, + (context, rng, alpha, beta, output), FLOAT_NATIVE); + if (beta) + NDArray::registerSpecialUse({output}, {alpha, beta}); + else + NDArray::prepareSpecialUse({output}, {alpha}); +} +BUILD_SINGLE_TEMPLATE(template void fillRandomGamma_, + (LaunchContext * context, graph::RandomGenerator& rng, + NDArray* alpha, NDArray* beta, NDArray* output), + FLOAT_NATIVE); - void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) { - NDArray::prepareSpecialUse({output}, {lambda}); - BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomPoisson_, (context, rng, lambda, output), FLOAT_NATIVE); - NDArray::registerSpecialUse({output}, {lambda}); +/* + * algorithm Poisson generator based upon the inversion by sequential search + * +init: + Let x ← 0, p ← e−λ, s ← p. + using uniformly random sequence U (u in U) distributed at [0, 1]. +while u > s do: + x ← x + 1. + p ← p * λ / x. + s ← s + p. +return x. + * */ +template +static __global__ void fillPoissonKernel(T* uList, Nd4jLong uLength, T* lambda, + const Nd4jLong* lambdaShape, T* output, + const Nd4jLong* outputShape) { + __shared__ Nd4jLong step; + + if (threadIdx.x == 0) { + step = shape::length(lambdaShape); + } + __syncthreads(); + + for (auto k = blockIdx.x; k < (int)uLength; k += gridDim.x) { + auto pos = k * step; + auto u = uList[k]; + for (auto e = threadIdx.x; e < step; e += blockDim.x) { + auto p = math::nd4j_exp(-lambda[e]); + auto s = p; + auto x = T(0.f); + auto lIndex = shape::getIndexOffset(e, lambdaShape); + auto zIndex = shape::getIndexOffset(e + pos, outputShape); + while (u > s) { + x += T(1.); + p *= lambda[lIndex] / x; + s += p; + } + output[zIndex] = x; } + } +} - BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_NATIVE); - - template - static __global__ void fillUniformKernel(graph::RandomGenerator* devRng, T from, T to, T* output, const Nd4jLong* outputShape) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - __shared__ Nd4jLong outputLen; +template +static void fillRandomPoisson_(LaunchContext* context, + graph::RandomGenerator& rng, NDArray* lambda, + NDArray* output) { + auto shift = output->lengthOf() / lambda->lengthOf(); + NDArray uniform('c', {shift}, output->dataType()); + auto stream = context->getCudaStream(); + // fill up uniform with given length + RandomLauncher::fillUniform(context, rng, &uniform, 0., 1.); + fillPoissonKernel<<<128, 256, 128, *stream>>>( + uniform.dataBuffer()->specialAsT(), uniform.lengthOf(), + lambda->dataBuffer()->specialAsT(), lambda->specialShapeInfo(), + output->dataBuffer()->specialAsT(), output->specialShapeInfo()); +} - if (0 == threadIdx.x) { - outputLen = shape::length(outputShape); - } - __syncthreads(); +void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* lambda, NDArray* output) { + NDArray::prepareSpecialUse({output}, {lambda}); + BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomPoisson_, + (context, rng, lambda, output), FLOAT_NATIVE); + NDArray::registerSpecialUse({output}, {lambda}); +} - for (auto i = start; i < outputLen; i += step) { - auto zIndex = shape::getIndexOffset(i, outputShape); - output[zIndex] = devRng->relativeT(i, from, to); - } +BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, + (LaunchContext * context, graph::RandomGenerator& rng, + NDArray* lambda, NDArray* output), + FLOAT_NATIVE); + +template +static __global__ void fillUniformKernel(graph::RandomGenerator* devRng, T from, + T to, T* output, + const Nd4jLong* outputShape) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + __shared__ Nd4jLong outputLen; + + if (0 == threadIdx.x) { + outputLen = shape::length(outputShape); + } + __syncthreads(); + + for (auto i = start; i < outputLen; i += step) { + auto zIndex = shape::getIndexOffset(i, outputShape); + output[zIndex] = devRng->relativeT(i, from, to); + } +} +template +static void fillRandomUniform_(LaunchContext* context, + graph::RandomGenerator& rng, NDArray* min, + NDArray* max, NDArray* output) { + T minVal = T(0); + T maxVal = DataTypeUtils::infOrMax(); + if (min) minVal = min->t(0); + if (max) maxVal = max->t(0); + + if (output->isR()) + RandomLauncher::fillUniform(context, rng, output, minVal, maxVal); + else { + auto stream = context->getCudaStream(); + graph::RandomGenerator* devRng; + auto err = cudaMalloc(&devRng, sizeof(graph::RandomGenerator)); + if (err != 0) { + cuda_exception::build( + "fillRandomUniform_: Cannot allocate device memory for random " + "generator due error", + err); } - template - static void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { - T minVal = T(0); - T maxVal = DataTypeUtils::infOrMax(); - if (min) - minVal = min->t(0); - if (max) - maxVal = max->t(0); - - if (output->isR()) - RandomLauncher::fillUniform(context, rng, output, minVal, maxVal); - else { - auto stream = context->getCudaStream(); - graph::RandomGenerator *devRng; - auto err = cudaMalloc(&devRng, sizeof(graph::RandomGenerator)); - if (err != 0) { - cuda_exception::build("fillRandomUniform_: Cannot allocate device memory for random generator due error", err); - } - - err = cudaMemcpy(devRng, &rng, sizeof(graph::RandomGenerator), cudaMemcpyHostToDevice); - if (err != 0) { - cuda_exception::build("fillRandomUniform_: Cannot copy random generator to device", err); - } - auto outputBuf = output->dataBuffer()->specialAsT(); - auto outputShape = output->specialShapeInfo(); - fillUniformKernel<<<128, 128, 128, *stream>>>(devRng, minVal, maxVal, outputBuf, outputShape); - - err = cudaStreamSynchronize(*stream); - if (err != 0) { - cuda_exception::build("fillRandomUniform_: Cannot successfully finish kernel call", err); - } - - err = cudaFree(devRng); - if (err != 0) { - cuda_exception::build("fillRandomUniform_: Cannot deallocate device memory for random generator", err); - } - } + err = cudaMemcpy(devRng, &rng, sizeof(graph::RandomGenerator), + cudaMemcpyHostToDevice); + if (err != 0) { + cuda_exception::build( + "fillRandomUniform_: Cannot copy random generator to device", err); + } + auto outputBuf = output->dataBuffer()->specialAsT(); + auto outputShape = output->specialShapeInfo(); + fillUniformKernel<<<128, 128, 128, *stream>>>(devRng, minVal, maxVal, + outputBuf, outputShape); + + err = cudaStreamSynchronize(*stream); + if (err != 0) { + cuda_exception::build( + "fillRandomUniform_: Cannot successfully finish kernel call", err); } - void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES); + err = cudaFree(devRng); + if (err != 0) { + cuda_exception::build( + "fillRandomUniform_: Cannot deallocate device memory for random " + "generator", + err); } + } +} + +void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* min, NDArray* max, NDArray* output) { + BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, + (context, rng, min, max, output), NUMERIC_TYPES); +} /////////////////////////////////////////////////////////////////// // used https://en.wikipedia.org/wiki/Categorical_distribution // methods: gumbel trick + softmax + argmax -template -__global__ static void fillMultiNomialCuda_(graph::RandomGenerator* devRng, const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong batchValue, - const Nd4jLong numOfSamples, const Nd4jLong numOfClassX, - const Nd4jLong dimA, const X minVal, const X maxVal) { - - - const X* x = reinterpret_cast(vx); - Z* z = reinterpret_cast(vz); - - __shared__ Nd4jLong xDimAstride, zDimAstride, xDimCstride, zDimCstride, dimC; - - if (0 == threadIdx.x) { - dimC = (0 == dimA) ? 1 : 0; - zDimAstride = shape::stride(zShapeInfo)[dimA]; - xDimAstride = shape::stride(xShapeInfo)[dimA]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - xDimCstride = shape::stride(xShapeInfo)[dimC]; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong index = tid; index < batchValue*numOfSamples; index += gridDim.x * blockDim.x) { - - Nd4jLong nBatchIndex = index / numOfSamples; - Nd4jLong nSampleIndexInBatch = index - (nBatchIndex * numOfSamples); - - const X* xTad = x + (nBatchIndex * xDimCstride); - Z* zTad = z + (nBatchIndex * zDimCstride); - Z& arg = zTad[nSampleIndexInBatch * zDimAstride]; - - X Max = -minVal; - Nd4jLong nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples; - Nd4jLong nClassPerSamples = nSampleIndexInBatch * numOfClassX; - - for (Nd4jLong nClass = 0; nClass < numOfClassX; nClass++) { - Nd4jLong nIndex = nSamplesPerBatch + nClassPerSamples + nClass; - X tValue = (xTad[nClass * xDimAstride] - sd::math::nd4j_log(-sd::math::nd4j_log(devRng->relativeT(nIndex, minVal, maxVal)))); - if (tValue > Max) { - Max = tValue; - arg = nClass; - } - } +template +__global__ static void fillMultiNomialCuda_( + graph::RandomGenerator* devRng, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong batchValue, + const Nd4jLong numOfSamples, const Nd4jLong numOfClassX, + const Nd4jLong dimA, const X minVal, const X maxVal) { + const X* x = reinterpret_cast(vx); + Z* z = reinterpret_cast(vz); + + __shared__ Nd4jLong xDimAstride, zDimAstride, xDimCstride, zDimCstride, dimC; + + if (0 == threadIdx.x) { + dimC = (0 == dimA) ? 1 : 0; + zDimAstride = shape::stride(zShapeInfo)[dimA]; + xDimAstride = shape::stride(xShapeInfo)[dimA]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + xDimCstride = shape::stride(xShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong index = tid; index < batchValue * numOfSamples; + index += gridDim.x * blockDim.x) { + Nd4jLong nBatchIndex = index / numOfSamples; + Nd4jLong nSampleIndexInBatch = index - (nBatchIndex * numOfSamples); + + const X* xTad = x + (nBatchIndex * xDimCstride); + Z* zTad = z + (nBatchIndex * zDimCstride); + Z& arg = zTad[nSampleIndexInBatch * zDimAstride]; + + X Max = -minVal; + Nd4jLong nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples; + Nd4jLong nClassPerSamples = nSampleIndexInBatch * numOfClassX; + + for (Nd4jLong nClass = 0; nClass < numOfClassX; nClass++) { + Nd4jLong nIndex = nSamplesPerBatch + nClassPerSamples + nClass; + X tValue = (xTad[nClass * xDimAstride] - + sd::math::nd4j_log(-sd::math::nd4j_log( + devRng->relativeT(nIndex, minVal, maxVal)))); + if (tValue > Max) { + Max = tValue; + arg = nClass; + } } + } } ////////////////////////////////////////////////////////////////////////// -template +template __host__ static void fillMultiNomialCudaLauncher( - const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, - graph::RandomGenerator* devRng, const void* vx, const Nd4jLong* xShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const Nd4jLong batchValue, const Nd4jLong numOfSamples, - const Nd4jLong numOfClassX, const Nd4jLong dimA){ - - const X minVal = DataTypeUtils::min(); - const X maxVal = 1.0; - - fillMultiNomialCuda_ <<< blocksPerGrid, threadsPerBlock, 256, * stream >>> ( - devRng, vx, xShapeInfo, vz, zShapeInfo, batchValue, - numOfSamples, numOfClassX, dimA, minVal, maxVal); + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, graph::RandomGenerator* devRng, const void* vx, + const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong batchValue, const Nd4jLong numOfSamples, + const Nd4jLong numOfClassX, const Nd4jLong dimA) { + const X minVal = DataTypeUtils::min(); + const X maxVal = 1.0; + + fillMultiNomialCuda_<<>>( + devRng, vx, xShapeInfo, vz, zShapeInfo, batchValue, numOfSamples, + numOfClassX, dimA, minVal, maxVal); } - -/////////////////////////////////////////////////////////////////// -void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) { - - Nd4jLong dimA = (0 == dimC) ? 1 : 0; - - const Nd4jLong batchValue = output.sizeAt(dimC); - const Nd4jLong numOfClassX = input.sizeAt(dimA); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (batchValue * numOfSamples + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "fillMultinomial"); - graph::RandomGenerator *devRng; - - auto err = cudaMalloc(&devRng, sizeof(graph::RandomGenerator)); - if (err != 0) { - cuda_exception::build("fillRandomMultiNomial: Cannot allocate device memory for random generator due error", err); - } - err = cudaStreamSynchronize(*context->getCudaStream()); - if (err != 0) { - cuda_exception::build("fillRandomMultiNomial: Cannot synchronize stream for random generator due error", err); - } - err = cudaMemcpyAsync(devRng, &rng, sizeof(graph::RandomGenerator), cudaMemcpyHostToDevice, *context->getCudaStream()); - if (err != 0) { - cuda_exception::build("fillRandomMultiNomial: Cannot copy random generator to device", err); - } - - NDArray::prepareSpecialUse({ &output }, { &input }); - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), fillMultiNomialCudaLauncher, - (blocksPerGrid, threadsPerBlock, context->getCudaStream(), devRng, input.specialBuffer(), - input.specialShapeInfo(), output.specialBuffer(), - output.specialShapeInfo(), batchValue, numOfSamples, - numOfClassX, dimA), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({ &output }, { &input }); - manager.synchronize(); - - err = cudaFree(devRng); - if (err != 0) { - cuda_exception::build("fillRandomMultiNomial: Cannot deallocate device memory for random generator", err); - } - rng.rewindH(output.lengthOf() * numOfClassX); - } +/////////////////////////////////////////////////////////////////// +void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, + NDArray& input, NDArray& output, + const Nd4jLong numOfSamples, const int dimC) { + Nd4jLong dimA = (0 == dimC) ? 1 : 0; + + const Nd4jLong batchValue = output.sizeAt(dimC); + const Nd4jLong numOfClassX = input.sizeAt(dimA); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (batchValue * numOfSamples + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "fillMultinomial"); + graph::RandomGenerator* devRng; + + auto err = cudaMalloc(&devRng, sizeof(graph::RandomGenerator)); + if (err != 0) { + cuda_exception::build( + "fillRandomMultiNomial: Cannot allocate device memory for random " + "generator due error", + err); + } + err = cudaStreamSynchronize(*context->getCudaStream()); + if (err != 0) { + cuda_exception::build( + "fillRandomMultiNomial: Cannot synchronize stream for random generator " + "due error", + err); + } + err = cudaMemcpyAsync(devRng, &rng, sizeof(graph::RandomGenerator), + cudaMemcpyHostToDevice, *context->getCudaStream()); + if (err != 0) { + cuda_exception::build( + "fillRandomMultiNomial: Cannot copy random generator to device", err); + } + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_DOUBLE_SELECTOR( + input.dataType(), output.dataType(), fillMultiNomialCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), devRng, + input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), batchValue, numOfSamples, numOfClassX, dimA), + FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + manager.synchronize(); + + err = cudaFree(devRng); + if (err != 0) { + cuda_exception::build( + "fillRandomMultiNomial: Cannot deallocate device memory for random " + "generator", + err); + } + rng.rewindH(output.lengthOf() * numOfClassX); } -} -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/random_crop.cu b/libnd4j/include/ops/declarable/helpers/cuda/random_crop.cu index 0489103f9287..3401d1e89250 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/random_crop.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/random_crop.cu @@ -20,24 +20,31 @@ #include //#include -#include -#include #include + +#include +#include namespace sd { namespace ops { namespace helpers { - template - static int _randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, NDArray* output, int seed) { - return Status::OK(); - } +template +static int _randomCropFunctor(graph::Context& context, NDArray* input, + NDArray* shape, NDArray* output, int seed) { + return Status::OK(); +} - int randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, NDArray* output, int seed) { - BUILD_SINGLE_SELECTOR(input->dataType(), return _randomCropFunctor, (context, input, shape, output, seed), FLOAT_TYPES); - } +int randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, + NDArray* output, int seed) { + BUILD_SINGLE_SELECTOR(input->dataType(), return _randomCropFunctor, + (context, input, shape, output, seed), FLOAT_TYPES); +} - BUILD_SINGLE_TEMPLATE(template int _randomCropFunctor, (graph::Context& context, NDArray* input, NDArray* shape, NDArray* output, int seed), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template int _randomCropFunctor, + (graph::Context & context, NDArray* input, NDArray* shape, + NDArray* output, int seed), + FLOAT_TYPES); -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/range.cu b/libnd4j/include/ops/declarable/helpers/cuda/range.cu index e33f95c52fc3..de50a258fa12 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/range.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/range.cu @@ -18,36 +18,40 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 27.08.2018 // - #include namespace sd { namespace ops { namespace helpers { - template - static __global__ void global_range(void *output, Nd4jLong length, T start, T delta) { - auto buff = reinterpret_cast(output); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - - for(Nd4jLong i = tid; i < length; i += step) - buff[i] = start + i * delta; - } - - ////////////////////////////////////////////////////////////////////////// - // be careful: outVector must have c-order and ews = 1 !!! - template - static void _range(sd::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector) { - global_range<<<512, 512, 2048, *context->getCudaStream()>>>(outVector.specialBuffer(), outVector.lengthOf(), start.e(0), delta.e(0)); - } - - void range(sd::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector) { - NDArray::prepareSpecialUse({&outVector}, {&start, &delta}); - BUILD_SINGLE_SELECTOR(outVector.dataType(), _range, (context, start, delta, outVector), LIBND4J_TYPES); - NDArray::registerSpecialUse({&outVector}, {&start, &delta}); - } +template +static __global__ void global_range(void* output, Nd4jLong length, T start, + T delta) { + auto buff = reinterpret_cast(output); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + for (Nd4jLong i = tid; i < length; i += step) buff[i] = start + i * delta; } + +////////////////////////////////////////////////////////////////////////// +// be careful: outVector must have c-order and ews = 1 !!! +template +static void _range(sd::LaunchContext* context, const NDArray& start, + const NDArray& delta, NDArray& outVector) { + global_range<<<512, 512, 2048, *context->getCudaStream()>>>( + outVector.specialBuffer(), outVector.lengthOf(), start.e(0), + delta.e(0)); } -} \ No newline at end of file + +void range(sd::LaunchContext* context, const NDArray& start, + const NDArray& delta, NDArray& outVector) { + NDArray::prepareSpecialUse({&outVector}, {&start, &delta}); + BUILD_SINGLE_SELECTOR(outVector.dataType(), _range, + (context, start, delta, outVector), LIBND4J_TYPES); + NDArray::registerSpecialUse({&outVector}, {&start, &delta}); +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu index b6bbeea4cb71..eb4e11c82459 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu @@ -18,218 +18,265 @@ // @author Yurii Shyrma, created on 16.04.2018 // -#include -#include #include -#include -#include #include +#include +#include +#include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - template - static __global__ void reverseTadKernel(const void* vinput, const Nd4jLong *inputShape, void* voutput, const Nd4jLong *outputShape, const Nd4jLong *inputTadShape, const Nd4jLong *inputTadOffsets, const Nd4jLong *outputTadShape, const Nd4jLong *outputTadOffsets, uint64_t limit, uint64_t numOfElemsToReverse, uint64_t numTads) { - auto input = reinterpret_cast(vinput); - auto output = reinterpret_cast(voutput); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - - // this means that we'll have additional cycle, to move middle element - auto div = numOfElemsToReverse / 2; - auto odd = numOfElemsToReverse % 2 != 0; - auto rlimit = odd ? limit / 2 + 1 : limit / 2; - - // all threads operate in the same input/output space - for (uint64_t e = tid; e < rlimit; e += step) { - // finding out the TAD we're going to process - auto tadId = e / div; - - if (tadId >= numTads) - continue; - - // now finding out element within tad - auto idx = e % div; - - //printf("TID: %i; numTads: %lld; tadLength: %lld; tadId: %i, idx: %lld\n", tid, numTads, numOfElemsToReverse, tadId, idx); - - auto tadInput = input + inputTadOffsets[tadId]; - auto tadOutput = output + outputTadOffsets[tadId]; - - // we're calculating offsets within input TAD - auto fOffset = shape::getIndexOffset(idx, inputTadShape); - auto lOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, inputTadShape); - - // now we're storing input values - auto v1 = tadInput[fOffset]; - auto v2 = tadInput[lOffset]; - - // now we're calculating offsets within output TAD - auto zfOffset = shape::getIndexOffset(idx, outputTadShape); - auto zlOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, outputTadShape); - - // and saving values to output arrays - tadOutput[zfOffset] = v2; - tadOutput[zlOffset] = v1; - } - - // moving odd element in blocks - if (odd && threadIdx.x == 0) { - for (uint64_t e = blockIdx.x; e < numTads; e += gridDim.x) { - auto tadInput = input + inputTadOffsets[e]; - auto tadOutput = output + outputTadOffsets[e]; - - auto xOffset = shape::getIndexOffset(numOfElemsToReverse / 2, inputTadShape); - auto zOffset = shape::getIndexOffset(numOfElemsToReverse / 2, outputTadShape); - - tadOutput[zOffset] = tadInput[xOffset]; - } - } - - } - - - template - static __global__ void reverseArrayKernel(const void* input, const Nd4jLong *inputShape, void* output, const Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) { - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - __shared__ int linearStatus; - __shared__ const T* inputArr; - __shared__ T* outputArr; - __shared__ char inputOrder, outputOrder; - - if (threadIdx.x == 0) { - linearStatus = (shape::elementWiseStride(inputShape) == shape::elementWiseStride(outputShape)) && (inputOrder == outputOrder)? shape::elementWiseStride(inputShape):0; - - char inputOrder = shape::order(inputShape); - char outputOrder = shape::order(outputShape); - inputArr = reinterpret_cast(input); - outputArr = reinterpret_cast(output); - } - __syncthreads(); - - auto odd = numOfElemsToReverse % 2 != 0; - auto limit = numOfElemsToReverse / 2; - - for (uint64_t e = tid; e < limit; e += step) { - // we're calculating offsets within input array - auto fOffset = shape::getIndexOffset(e, inputShape); - auto lOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape); - - // now we're storing input values - auto v1 = inputArr[fOffset]; - auto v2 = inputArr[lOffset]; - - // now we're calculating offsets within output array - auto zfOffset = shape::getIndexOffset(e, outputShape); - auto zlOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, outputShape); - - // and saving values to output arrays - outputArr[zfOffset] = v2; - outputArr[zlOffset] = v1; - } - - // in case of odd array we'll have to move middle value - if (odd && tid == 0) { - auto xOffset = shape::getIndexOffset(limit, inputShape); - auto zOffset = shape::getIndexOffset(limit, outputShape); - - outputArr[zOffset] = inputArr[xOffset]; - } - } - - template - static void reverseTad(sd::LaunchContext * context, const NDArray* input, NDArray* output, const Nd4jLong *inputTadShape, const Nd4jLong *inputTadOffsets, const Nd4jLong *outputTadShape, const Nd4jLong *outputTadOffsets, uint64_t tadLength) { - auto stream = context->getCudaStream(); - reverseTadKernel<<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), inputTadShape, inputTadOffsets, outputTadShape, outputTadOffsets, input->lengthOf(), tadLength, input->lengthOf() / tadLength); - } - - template - static void reverseArray(sd::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) { - auto stream = context->getCudaStream(); - Nd4jLong numOfReverse = numOfElemsToReverse; - if (numOfElemsToReverse == 0) - numOfReverse = input->lengthOf(); - - reverseArrayKernel<<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse); +template +static __global__ void reverseTadKernel( + const void* vinput, const Nd4jLong* inputShape, void* voutput, + const Nd4jLong* outputShape, const Nd4jLong* inputTadShape, + const Nd4jLong* inputTadOffsets, const Nd4jLong* outputTadShape, + const Nd4jLong* outputTadOffsets, uint64_t limit, + uint64_t numOfElemsToReverse, uint64_t numTads) { + auto input = reinterpret_cast(vinput); + auto output = reinterpret_cast(voutput); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + // this means that we'll have additional cycle, to move middle element + auto div = numOfElemsToReverse / 2; + auto odd = numOfElemsToReverse % 2 != 0; + auto rlimit = odd ? limit / 2 + 1 : limit / 2; + + // all threads operate in the same input/output space + for (uint64_t e = tid; e < rlimit; e += step) { + // finding out the TAD we're going to process + auto tadId = e / div; + + if (tadId >= numTads) continue; + + // now finding out element within tad + auto idx = e % div; + + // printf("TID: %i; numTads: %lld; tadLength: %lld; tadId: %i, idx: %lld\n", + // tid, numTads, numOfElemsToReverse, tadId, idx); + + auto tadInput = input + inputTadOffsets[tadId]; + auto tadOutput = output + outputTadOffsets[tadId]; + + // we're calculating offsets within input TAD + auto fOffset = shape::getIndexOffset(idx, inputTadShape); + auto lOffset = + shape::getIndexOffset(numOfElemsToReverse - idx - 1, inputTadShape); + + // now we're storing input values + auto v1 = tadInput[fOffset]; + auto v2 = tadInput[lOffset]; + + // now we're calculating offsets within output TAD + auto zfOffset = shape::getIndexOffset(idx, outputTadShape); + auto zlOffset = + shape::getIndexOffset(numOfElemsToReverse - idx - 1, outputTadShape); + + // and saving values to output arrays + tadOutput[zfOffset] = v2; + tadOutput[zlOffset] = v1; + } + + // moving odd element in blocks + if (odd && threadIdx.x == 0) { + for (uint64_t e = blockIdx.x; e < numTads; e += gridDim.x) { + auto tadInput = input + inputTadOffsets[e]; + auto tadOutput = output + outputTadOffsets[e]; + + auto xOffset = + shape::getIndexOffset(numOfElemsToReverse / 2, inputTadShape); + auto zOffset = + shape::getIndexOffset(numOfElemsToReverse / 2, outputTadShape); + + tadOutput[zOffset] = tadInput[xOffset]; } + } +} +template +static __global__ void reverseArrayKernel(const void* input, + const Nd4jLong* inputShape, + void* output, + const Nd4jLong* outputShape, + Nd4jLong numOfElemsToReverse) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + __shared__ int linearStatus; + __shared__ const T* inputArr; + __shared__ T* outputArr; + __shared__ char inputOrder, outputOrder; + + if (threadIdx.x == 0) { + linearStatus = (shape::elementWiseStride(inputShape) == + shape::elementWiseStride(outputShape)) && + (inputOrder == outputOrder) + ? shape::elementWiseStride(inputShape) + : 0; + + char inputOrder = shape::order(inputShape); + char outputOrder = shape::order(outputShape); + inputArr = reinterpret_cast(input); + outputArr = reinterpret_cast(output); + } + __syncthreads(); + + auto odd = numOfElemsToReverse % 2 != 0; + auto limit = numOfElemsToReverse / 2; + + for (uint64_t e = tid; e < limit; e += step) { + // we're calculating offsets within input array + auto fOffset = shape::getIndexOffset(e, inputShape); + auto lOffset = + shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape); + + // now we're storing input values + auto v1 = inputArr[fOffset]; + auto v2 = inputArr[lOffset]; + + // now we're calculating offsets within output array + auto zfOffset = shape::getIndexOffset(e, outputShape); + auto zlOffset = + shape::getIndexOffset(numOfElemsToReverse - e - 1, outputShape); + + // and saving values to output arrays + outputArr[zfOffset] = v2; + outputArr[zlOffset] = v1; + } + + // in case of odd array we'll have to move middle value + if (odd && tid == 0) { + auto xOffset = shape::getIndexOffset(limit, inputShape); + auto zOffset = shape::getIndexOffset(limit, outputShape); + + outputArr[zOffset] = inputArr[xOffset]; + } +} - /////////////////////////////////////////////////////////////////// - template - static void reverseSequence_(sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim){ - int posOfNonUnityDim = -1; - seqLengths->syncToHost(); - auto stream = context->getCudaStream(); - - if(input->isVector() || shape::isLikeVector(input->shapeInfo(), posOfNonUnityDim) || seqLengths->lengthOf() == 1) { - int numOfElemsToReverse = seqLengths->e(0); - if((seqDim == 0 && input->sizeAt(0) == 1) || (batchDim == posOfNonUnityDim)) - output->assign(input); - else - reverseArrayKernel<<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfElemsToReverse);//helpers::reverseArray(context, const_cast(input), output, numOfElemsToReverse); - } - else { - - if(seqDim > batchDim) - --seqDim; - - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {batchDim}); - - auto inSubArrsSet = input->allTensorsAlongDimension(dimensions); - auto outSubArrsSet = output->allTensorsAlongDimension(dimensions); - - for(int i = 0; i < inSubArrsSet.size(); ++i) { - - int numOfElemsToReverse = seqLengths->e(i); - - if(numOfElemsToReverse == 0 || numOfElemsToReverse == 1) { - outSubArrsSet.at(i)->assign(inSubArrsSet.at(i)); - } - else { - auto inInnerSet = inSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); - auto outInnerSet = outSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); - for(int j = 0; j < inInnerSet.size(); ++j) - reverseArray(context, inInnerSet.at(j), outInnerSet.at(j), numOfElemsToReverse); - } - } - } - } +template +static void reverseTad(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const Nd4jLong* inputTadShape, + const Nd4jLong* inputTadOffsets, + const Nd4jLong* outputTadShape, + const Nd4jLong* outputTadOffsets, uint64_t tadLength) { + auto stream = context->getCudaStream(); + reverseTadKernel<<<256, 512, 8192, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTadShape, + inputTadOffsets, outputTadShape, outputTadOffsets, input->lengthOf(), + tadLength, input->lengthOf() / tadLength); +} - void reverseSequence(sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) { - NDArray::prepareSpecialUse({output}, {input, seqLengths}); +template +static void reverseArray(sd::LaunchContext* context, const NDArray* input, + NDArray* output, Nd4jLong numOfElemsToReverse) { + auto stream = context->getCudaStream(); + Nd4jLong numOfReverse = numOfElemsToReverse; + if (numOfElemsToReverse == 0) numOfReverse = input->lengthOf(); - // if op isn't inplace - copy original data into output array - if (output->specialBuffer() != input->specialBuffer()) - output->assign(input); + reverseArrayKernel<<<256, 512, 8192, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), numOfReverse); +} - BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {input, seqLengths}); +/////////////////////////////////////////////////////////////////// +template +static void reverseSequence_(sd::LaunchContext* context, const NDArray* input, + const NDArray* seqLengths, NDArray* output, + int seqDim, const int batchDim) { + int posOfNonUnityDim = -1; + seqLengths->syncToHost(); + auto stream = context->getCudaStream(); + + if (input->isVector() || + shape::isLikeVector(input->shapeInfo(), posOfNonUnityDim) || + seqLengths->lengthOf() == 1) { + int numOfElemsToReverse = seqLengths->e(0); + if ((seqDim == 0 && input->sizeAt(0) == 1) || + (batchDim == posOfNonUnityDim)) + output->assign(input); + else + reverseArrayKernel<<<256, 512, 8192, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), + numOfElemsToReverse); // helpers::reverseArray(context, + // const_cast(input), output, + // numOfElemsToReverse); + } else { + if (seqDim > batchDim) --seqDim; + + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {batchDim}); + + auto inSubArrsSet = input->allTensorsAlongDimension(dimensions); + auto outSubArrsSet = output->allTensorsAlongDimension(dimensions); + + for (int i = 0; i < inSubArrsSet.size(); ++i) { + int numOfElemsToReverse = seqLengths->e(i); + + if (numOfElemsToReverse == 0 || numOfElemsToReverse == 1) { + outSubArrsSet.at(i)->assign(inSubArrsSet.at(i)); + } else { + auto inInnerSet = + inSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); + auto outInnerSet = + outSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); + for (int j = 0; j < inInnerSet.size(); ++j) + reverseArray(context, inInnerSet.at(j), outInnerSet.at(j), + numOfElemsToReverse); + } } + } +} - ////////////////////////////////////////////////////////////////////////// - void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector* intArgs, bool isBackProp) { - // we need to reverse axis only if that's new op - std::vector dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs; - std::vector axis = ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - - - - NDArray::prepareSpecialUse({output}, {input}); +void reverseSequence(sd::LaunchContext* context, const NDArray* input, + const NDArray* seqLengths, NDArray* output, int seqDim, + const int batchDim) { + NDArray::prepareSpecialUse({output}, {input, seqLengths}); - if (packX.numberOfTads() == 1) { - BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, input, output, 0), LIBND4J_TYPES); - } else { - BUILD_SINGLE_SELECTOR(input->dataType(), reverseTad, (context, input, output, packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), (uint64_t) (input->lengthOf() / packX.numberOfTads())), LIBND4J_TYPES); - } + // if op isn't inplace - copy original data into output array + if (output->specialBuffer() != input->specialBuffer()) output->assign(input); - NDArray::registerSpecialUse({output}, {input}); - } -} -} + BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, + (context, input, seqLengths, output, seqDim, batchDim), + LIBND4J_TYPES); + NDArray::registerSpecialUse({output}, {input, seqLengths}); } +////////////////////////////////////////////////////////////////////////// +void reverse(sd::LaunchContext* context, const NDArray* input, NDArray* output, + const std::vector* intArgs, bool isBackProp) { + // we need to reverse axis only if that's new op + std::vector dimensions = + isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) + : *intArgs; + std::vector axis = + ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + + NDArray::prepareSpecialUse({output}, {input}); + + if (packX.numberOfTads() == 1) { + BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, + (context, input, output, 0), LIBND4J_TYPES); + } else { + BUILD_SINGLE_SELECTOR( + input->dataType(), reverseTad, + (context, input, output, packX.platformShapeInfo(), + packX.platformOffsets(), packZ.platformShapeInfo(), + packZ.platformOffsets(), + (uint64_t)(input->lengthOf() / packX.numberOfTads())), + LIBND4J_TYPES); + } + + NDArray::registerSpecialUse({output}, {input}); +} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/roll.cu b/libnd4j/include/ops/declarable/helpers/cuda/roll.cu index 773f7279ddb7..c09e8750b02e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/roll.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/roll.cu @@ -18,314 +18,374 @@ // @author raver119@gmail.com // -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - template - static void _CUDA_D rollKernelLinearStage1Dev(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Nd4jLong fullLength, int actualShift) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); +template +static void _CUDA_D rollKernelLinearStage1Dev( + const void *vx, const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, Nd4jLong fullLength, int actualShift) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); - auto xOrder = shape::order(xShapeInfo); - auto zOrder = shape::order(zShapeInfo); + auto xOrder = shape::order(xShapeInfo); + auto zOrder = shape::order(zShapeInfo); - auto tid = threadIdx.x + blockIdx.x * blockDim.x; + auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (xEws > 0 && zEws > 0 && xOrder == zOrder) { - for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { - int sourceIndex = fullLength - actualShift + i; + if (xEws > 0 && zEws > 0 && xOrder == zOrder) { + for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { + int sourceIndex = fullLength - actualShift + i; - auto eA = x[sourceIndex * xEws]; - auto eB = x[i * xEws]; + auto eA = x[sourceIndex * xEws]; + auto eB = x[i * xEws]; - z[i * zEws] = eA; - z[sourceIndex * zEws] = eB; - } - } else { - for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { - int sourceIndex = fullLength - actualShift + i; + z[i * zEws] = eA; + z[sourceIndex * zEws] = eB; + } + } else { + for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { + int sourceIndex = fullLength - actualShift + i; - auto xOffsetA = shape::getIndexOffset(i, xShapeInfo); - auto xOffsetB = shape::getIndexOffset(sourceIndex, xShapeInfo); + auto xOffsetA = shape::getIndexOffset(i, xShapeInfo); + auto xOffsetB = shape::getIndexOffset(sourceIndex, xShapeInfo); - auto zOffsetA = shape::getIndexOffset(i, zShapeInfo); - auto zOffsetB = shape::getIndexOffset(sourceIndex, zShapeInfo); + auto zOffsetA = shape::getIndexOffset(i, zShapeInfo); + auto zOffsetB = shape::getIndexOffset(sourceIndex, zShapeInfo); - auto eA = x[xOffsetA]; - auto eB = x[xOffsetB]; + auto eA = x[xOffsetA]; + auto eB = x[xOffsetB]; - z[zOffsetA] = eB; - z[zOffsetB] = eA; - } - } + z[zOffsetA] = eB; + z[zOffsetB] = eA; } + } +} - template - static void _CUDA_G rollKernelLinearStage1(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Nd4jLong fullLength, int actualShift) { - rollKernelLinearStage1Dev(vx, xShapeInfo, vz, zShapeInfo, fullLength, actualShift); - } +template +static void _CUDA_G rollKernelLinearStage1(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Nd4jLong fullLength, + int actualShift) { + rollKernelLinearStage1Dev(vx, xShapeInfo, vz, zShapeInfo, fullLength, + actualShift); +} - template - static void _CUDA_G rollKernelLinearStage2(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Nd4jLong fullLength, int actualShift, int shiftCount) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); +template +static void _CUDA_G rollKernelLinearStage2(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Nd4jLong fullLength, int actualShift, + int shiftCount) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); - auto xOrder = shape::order(xShapeInfo); - auto zOrder = shape::order(zShapeInfo); + auto xOrder = shape::order(xShapeInfo); + auto zOrder = shape::order(zShapeInfo); - auto tid = threadIdx.x + blockIdx.x * blockDim.x; + auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (xEws > 0 && zEws > 0 && xOrder == zOrder) { - for (int count = 1; count < shiftCount; ++count) { - for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { - int destinationIndex = fullLength - (count + 1) * actualShift + i; - int sourceIndex = fullLength - count * actualShift + i; + if (xEws > 0 && zEws > 0 && xOrder == zOrder) { + for (int count = 1; count < shiftCount; ++count) { + for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { + int destinationIndex = fullLength - (count + 1) * actualShift + i; + int sourceIndex = fullLength - count * actualShift + i; - auto eA = x[sourceIndex * xEws]; - auto eB = x[destinationIndex * xEws]; + auto eA = x[sourceIndex * xEws]; + auto eB = x[destinationIndex * xEws]; - z[destinationIndex * zEws] = eA; - z[sourceIndex * zEws] = eB; - } + z[destinationIndex * zEws] = eA; + z[sourceIndex * zEws] = eB; + } - __syncthreads(); - } - } else { - for (int count = 1; count < shiftCount; ++count) { - for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { - int destinationIndex = fullLength - (count + 1) * actualShift + i; - int sourceIndex = fullLength - count * actualShift + i; + __syncthreads(); + } + } else { + for (int count = 1; count < shiftCount; ++count) { + for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { + int destinationIndex = fullLength - (count + 1) * actualShift + i; + int sourceIndex = fullLength - count * actualShift + i; - auto xOffsetA = shape::getIndexOffset(destinationIndex, xShapeInfo); - auto xOffsetB = shape::getIndexOffset(sourceIndex, xShapeInfo); + auto xOffsetA = shape::getIndexOffset(destinationIndex, xShapeInfo); + auto xOffsetB = shape::getIndexOffset(sourceIndex, xShapeInfo); - auto zOffsetA = shape::getIndexOffset(destinationIndex, zShapeInfo); - auto zOffsetB = shape::getIndexOffset(sourceIndex, zShapeInfo); + auto zOffsetA = shape::getIndexOffset(destinationIndex, zShapeInfo); + auto zOffsetB = shape::getIndexOffset(sourceIndex, zShapeInfo); - auto eA = x[xOffsetA]; - auto eB = x[xOffsetB]; + auto eA = x[xOffsetA]; + auto eB = x[xOffsetB]; - z[zOffsetA] = eB; - z[zOffsetB] = eA; - } + z[zOffsetA] = eB; + z[zOffsetB] = eA; + } - __syncthreads(); - } - } + __syncthreads(); } + } +} - template - static void _CUDA_G rollKernelLinearStage3(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Nd4jLong fullLength, int actualShift, int remainShift) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - auto xEws = shape::elementWiseStride(xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); - - auto xOrder = shape::order(xShapeInfo); - auto zOrder = shape::order(zShapeInfo); - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (xEws > 0 && zEws > 0 && xOrder == zOrder) { - for (int i = tid ; i < actualShift; i += blockDim.x * gridDim.x) { - int remainIdx = i + actualShift; - int sourceIndex = remainIdx + remainShift; +template +static void _CUDA_G rollKernelLinearStage3(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + Nd4jLong fullLength, int actualShift, + int remainShift) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - auto eA = x[sourceIndex * xEws]; - auto eB = x[remainIdx * xEws]; + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); - z[remainIdx * zEws] = eA; - z[sourceIndex * zEws] = eB; - } - } else { - for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { - int remainIdx = i + actualShift; - int sourceIndex = remainIdx + remainShift; + auto xOrder = shape::order(xShapeInfo); + auto zOrder = shape::order(zShapeInfo); - auto xOffsetA = shape::getIndexOffset(remainIdx, xShapeInfo); - auto xOffsetB = shape::getIndexOffset(sourceIndex, xShapeInfo); + auto tid = threadIdx.x + blockIdx.x * blockDim.x; - auto zOffsetA = shape::getIndexOffset(remainIdx, zShapeInfo); - auto zOffsetB = shape::getIndexOffset(sourceIndex, zShapeInfo); + if (xEws > 0 && zEws > 0 && xOrder == zOrder) { + for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { + int remainIdx = i + actualShift; + int sourceIndex = remainIdx + remainShift; - auto eA = x[xOffsetA]; - auto eB = x[xOffsetB]; + auto eA = x[sourceIndex * xEws]; + auto eB = x[remainIdx * xEws]; - z[zOffsetA] = eB; - z[zOffsetB] = eA; - } - } + z[remainIdx * zEws] = eA; + z[sourceIndex * zEws] = eB; } + } else { + for (int i = tid; i < actualShift; i += blockDim.x * gridDim.x) { + int remainIdx = i + actualShift; + int sourceIndex = remainIdx + remainShift; - template - static void _CUDA_D swapTadsKernel(void *vx, void *vz, const Nd4jLong *zShapeInfo, Nd4jLong tadLength) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); + auto xOffsetA = shape::getIndexOffset(remainIdx, xShapeInfo); + auto xOffsetB = shape::getIndexOffset(sourceIndex, xShapeInfo); - auto zEws = shape::elementWiseStride(zShapeInfo); + auto zOffsetA = shape::getIndexOffset(remainIdx, zShapeInfo); + auto zOffsetB = shape::getIndexOffset(sourceIndex, zShapeInfo); - auto zOrder = shape::order(zShapeInfo); + auto eA = x[xOffsetA]; + auto eB = x[xOffsetB]; - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (zEws > 0) { - for (int e = threadIdx.x; e < tadLength; e += blockDim.x) { - auto eA = x[e * zEws]; - auto eB = z[e * zEws]; - - x[e * zEws] = eB; - z[e * zEws] = eA; - } - } else { - for (int e = threadIdx.x; e < tadLength; e += blockDim.x) { - auto zOffset = shape::getIndexOffset(e, zShapeInfo); - - auto eA = x[zOffset]; - auto eB = z[zOffset]; - - x[zOffset] = eB; - z[zOffset] = eA; - } - } + z[zOffsetA] = eB; + z[zOffsetB] = eA; } + } +} - template - static void _CUDA_G rollKernelFullAnyDimensionStage1(const void *vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, void *vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, int numTads, Nd4jLong tadLength, int dim, Nd4jLong sizeAt, int theShift) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); +template +static void _CUDA_D swapTadsKernel(void *vx, void *vz, + const Nd4jLong *zShapeInfo, + Nd4jLong tadLength) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - for (int e = blockIdx.x + theShift; e < sizeAt - theShift; e += gridDim.x) { - int sourceIndex = dim * sizeAt + e - theShift; - int targetIndex = dim * sizeAt + e; + auto zEws = shape::elementWiseStride(zShapeInfo); - swapTadsKernel(z + xTadOffsets[sourceIndex], z + xTadOffsets[targetIndex], zTadShapeInfo, tadLength); - } - } + auto zOrder = shape::order(zShapeInfo); - template - static void _CUDA_G rollKernelFullAnyDimensionStage2(void *vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, void *vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, int numTads, Nd4jLong tadLength, int dim, Nd4jLong sizeAt, int theShift) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); + auto tid = threadIdx.x + blockIdx.x * blockDim.x; - for (int e = blockIdx.x; e < theShift; e += gridDim.x) { - int sourceIndex = dim * sizeAt + sizeAt - theShift + e; - int targetIndex = dim * sizeAt + e; - - swapTadsKernel(z + zTadOffsets[sourceIndex], z + zTadOffsets[targetIndex], zTadShapeInfo, tadLength); - } - } + if (zEws > 0) { + for (int e = threadIdx.x; e < tadLength; e += blockDim.x) { + auto eA = x[e * zEws]; + auto eB = z[e * zEws]; - template - static void rollFunctorFull_(NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace){ - if (!inplace) - output->assign(input); - - for (size_t i = 0; i < axes.size(); i++) { - int axe = axes[i]; - if (axe == input->rankOf() - 1) { // last dimension - ResultSet listOfTensors = output->allTensorsAlongDimension({axe}); - ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe}); - int fullLen = listOfTensors.size(); - int theShift = shifts[i]; -// if (theShift > 0) { -// theShift %= fullLen; -// } -// else { -// theShift -= fullLen * (theShift / fullLen - 1); -// } - for (int k = 0; k < fullLen; k++) { - rollFunctorLinear(output->getContext(), listOfTensors.at(k), listOfOutTensors.at(k), theShift, true); - } - } else { - std::vector dims(input->rankOf() - axe - 1); - for (int i = 0; i < dims.size(); ++i) - dims[i] = axe + 1 + i; - - auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dims); - - int numTads = packZ.numberOfTads(); - int sizeAt = input->sizeAt(axe); - auto tadLength = shape::length(packZ.primaryShapeInfo()); - - int theShift = shifts[i]; - -// if (theShift > 0) -// theShift %= sizeAt; -// else -// theShift -= sizeAt * (theShift / sizeAt - 1); - - if (theShift) { - for (int dim = 0; dim < numTads / sizeAt; ++dim) { - - rollKernelFullAnyDimensionStage1<<<1, 256, 1024, *(output->getContext()->getCudaStream())>>>(output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, tadLength, dim, sizeAt, theShift); - - rollKernelFullAnyDimensionStage2<<<1, 256, 1024, *(output->getContext()->getCudaStream())>>>(output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, tadLength, dim, sizeAt, theShift); - } - } - } - } + x[e * zEws] = eB; + z[e * zEws] = eA; } + } else { + for (int e = threadIdx.x; e < tadLength; e += blockDim.x) { + auto zOffset = shape::getIndexOffset(e, zShapeInfo); - template - static void rollFunctorLinear_(NDArray* input, NDArray* output, int shift, bool inplace){ - if (!inplace) - output->assign(input); + auto eA = x[zOffset]; + auto eB = z[zOffset]; - auto fullLen = input->lengthOf(); - int actualShift = shift; // % fullLen; // shift already non-negative then - if (actualShift < 0) { - actualShift -= fullLen * (actualShift / fullLen - 1); - } - else - actualShift %= fullLen; - - if (actualShift) { - int shiftCount = fullLen / actualShift - 1; - int remainShift = fullLen % actualShift; + x[zOffset] = eB; + z[zOffset] = eA; + } + } +} - // stage 1) swap last actualShift elements with first ones. - rollKernelLinearStage1<<<1, 1, 1024, *(output->getContext()->getCudaStream())>>>(output->specialBuffer(), output->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), fullLen, actualShift); +template +static void _CUDA_G rollKernelFullAnyDimensionStage1( + const void *vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, + void *vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, + int numTads, Nd4jLong tadLength, int dim, Nd4jLong sizeAt, int theShift) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + for (int e = blockIdx.x + theShift; e < sizeAt - theShift; e += gridDim.x) { + int sourceIndex = dim * sizeAt + e - theShift; + int targetIndex = dim * sizeAt + e; + + swapTadsKernel(z + xTadOffsets[sourceIndex], + z + xTadOffsets[targetIndex], zTadShapeInfo, tadLength); + } +} - // stage 2) swap swapped actualShift elements with rest remainShiftCount times. - rollKernelLinearStage2<<<1, 1, 1024, *(output->getContext()->getCudaStream())>>>(output->specialBuffer(), output->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), fullLen, actualShift, shiftCount); +template +static void _CUDA_G rollKernelFullAnyDimensionStage2( + void *vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xTadOffsets, + void *vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zTadOffsets, + int numTads, Nd4jLong tadLength, int dim, Nd4jLong sizeAt, int theShift) { + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + for (int e = blockIdx.x; e < theShift; e += gridDim.x) { + int sourceIndex = dim * sizeAt + sizeAt - theShift + e; + int targetIndex = dim * sizeAt + e; + + swapTadsKernel(z + zTadOffsets[sourceIndex], + z + zTadOffsets[targetIndex], zTadShapeInfo, tadLength); + } +} - // FIXME: no parallelism here :( - // stage 3) swap remainer of items. - if (remainShift && shiftCount) - rollKernelLinearStage3<<<1, 1, 1024, *(output->getContext()->getCudaStream())>>>(output->specialBuffer(), output->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), fullLen, actualShift, remainShift); +template +static void rollFunctorFull_(NDArray *input, NDArray *output, + std::vector const &shifts, + std::vector const &axes, bool inplace) { + if (!inplace) output->assign(input); + + for (size_t i = 0; i < axes.size(); i++) { + int axe = axes[i]; + if (axe == input->rankOf() - 1) { // last dimension + ResultSet listOfTensors = output->allTensorsAlongDimension({axe}); + ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe}); + int fullLen = listOfTensors.size(); + int theShift = shifts[i]; + // if (theShift > 0) { + // theShift %= fullLen; + // } + // else { + // theShift -= fullLen * (theShift / fullLen - 1); + // } + for (int k = 0; k < fullLen; k++) { + rollFunctorLinear(output->getContext(), listOfTensors.at(k), + listOfOutTensors.at(k), theShift, true); + } + } else { + std::vector dims(input->rankOf() - axe - 1); + for (int i = 0; i < dims.size(); ++i) dims[i] = axe + 1 + i; + + auto packZ = ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dims); + + int numTads = packZ.numberOfTads(); + int sizeAt = input->sizeAt(axe); + auto tadLength = shape::length(packZ.primaryShapeInfo()); + + int theShift = shifts[i]; + + // if (theShift > 0) + // theShift %= sizeAt; + // else + // theShift -= sizeAt * (theShift / sizeAt - 1); + + if (theShift) { + for (int dim = 0; dim < numTads / sizeAt; ++dim) { + rollKernelFullAnyDimensionStage1 + <<<1, 256, 1024, *(output->getContext()->getCudaStream())>>>( + output->specialBuffer(), packZ.platformShapeInfo(), + packZ.platformOffsets(), output->specialBuffer(), + packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, + tadLength, dim, sizeAt, theShift); + + rollKernelFullAnyDimensionStage2 + <<<1, 256, 1024, *(output->getContext()->getCudaStream())>>>( + output->specialBuffer(), packZ.platformShapeInfo(), + packZ.platformOffsets(), output->specialBuffer(), + packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, + tadLength, dim, sizeAt, theShift); } + } } + } +} - void rollFunctorFull(sd::LaunchContext * context, NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace){ - input->syncToDevice(); +template +static void rollFunctorLinear_(NDArray *input, NDArray *output, int shift, + bool inplace) { + if (!inplace) output->assign(input); + + auto fullLen = input->lengthOf(); + int actualShift = shift; // % fullLen; // shift already non-negative then + if (actualShift < 0) { + actualShift -= fullLen * (actualShift / fullLen - 1); + } else + actualShift %= fullLen; + + if (actualShift) { + int shiftCount = fullLen / actualShift - 1; + int remainShift = fullLen % actualShift; + + // stage 1) swap last actualShift elements with first ones. + rollKernelLinearStage1 + <<<1, 1, 1024, *(output->getContext()->getCudaStream())>>>( + output->specialBuffer(), output->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), fullLen, + actualShift); + + // stage 2) swap swapped actualShift elements with rest remainShiftCount + // times. + rollKernelLinearStage2 + <<<1, 1, 1024, *(output->getContext()->getCudaStream())>>>( + output->specialBuffer(), output->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), fullLen, + actualShift, shiftCount); + + // FIXME: no parallelism here :( + // stage 3) swap remainer of items. + if (remainShift && shiftCount) + rollKernelLinearStage3 + <<<1, 1, 1024, *(output->getContext()->getCudaStream())>>>( + output->specialBuffer(), output->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), fullLen, + actualShift, remainShift); + } +} - BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, (input, output, shifts, axes, inplace), LIBND4J_TYPES); +void rollFunctorFull(sd::LaunchContext *context, NDArray *input, + NDArray *output, std::vector const &shifts, + std::vector const &axes, bool inplace) { + input->syncToDevice(); - output->tickWriteDevice(); - } + BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, + (input, output, shifts, axes, inplace), LIBND4J_TYPES); - void rollFunctorLinear(sd::LaunchContext * context, NDArray* input, NDArray* output, int shift, bool inplace){ - input->syncToDevice(); + output->tickWriteDevice(); +} - BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorLinear_, (input, output, shift, inplace), LIBND4J_TYPES); +void rollFunctorLinear(sd::LaunchContext *context, NDArray *input, + NDArray *output, int shift, bool inplace) { + input->syncToDevice(); - output->tickWriteDevice(); - } + BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorLinear_, + (input, output, shift, inplace), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void rollFunctorLinear_, (NDArray* input, NDArray* output, int shift, bool inplace), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, (NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace), LIBND4J_TYPES); + output->tickWriteDevice(); } -} -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template void rollFunctorLinear_, + (NDArray * input, NDArray *output, int shift, + bool inplace), + LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, + (NDArray * input, NDArray *output, + std::vector const &shifts, + std::vector const &axes, bool inplace), + LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/s_t_b.cu b/libnd4j/include/ops/declarable/helpers/cuda/s_t_b.cu index 8b7bfb2b5e64..265dc9c46411 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/s_t_b.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/s_t_b.cu @@ -19,486 +19,577 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// -template -__global__ static void batchToSpaceCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint cropBottom, const uint cropLeft) { +template +__global__ static void batchToSpaceCuda(const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const uint cropBottom, + const uint cropLeft) { + // input [bS, H * blockSize, W * blockSize, iC] + // output [bS, H * blockSize - cropBottom - cropTop, W * blockSize - cropLeft + // - cropRight, iC] - // input [bS, H * blockSize, W * blockSize, iC] - // output [bS, H * blockSize - cropBottom - cropTop, W * blockSize - cropLeft - cropRight, iC] + // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same + // else: + // oH -> [cropBottom, iH - cropTop] + // oW -> [cropLeft, iH - cropRight] + // xLen >= zLen - // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same - // else: - // oH -> [cropBottom, iH - cropTop] - // oW -> [cropLeft, iH - cropRight] - // xLen >= zLen + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); + __shared__ int rank, *sharedMem; + __shared__ Nd4jLong zLen; - __shared__ int rank, *sharedMem; - __shared__ Nd4jLong zLen; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + rank = shape::rank(zShapeInfo); + zLen = shape::length(zShapeInfo); + } + __syncthreads(); - rank = shape::rank(zShapeInfo); - zLen = shape::length(zShapeInfo); - } - __syncthreads(); + auto coords = sharedMem + threadIdx.x * rank; - auto coords = sharedMem + threadIdx.x * rank; + const auto i = blockIdx.x * blockDim.x + threadIdx.x; - const auto i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= zLen) return; - if(i >= zLen) - return; + shape::index2coords(i, zShapeInfo, coords); - shape::index2coords(i, zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + coords[1] += cropBottom; + coords[2] += cropLeft; - coords[1] += cropBottom; - coords[2] += cropLeft; - - const auto xOffset = shape::getOffset(xShapeInfo, coords); - - z[zOffset] = x[xOffset]; + const auto xOffset = shape::getOffset(xShapeInfo, coords); + z[zOffset] = x[xOffset]; } /////////////////////////////////////////////////////////////////// -template -static void batchToSpaceCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint cropBottom, const uint cropLeft) { - - batchToSpaceCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, cropBottom, cropLeft); +template +static void batchToSpaceCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const uint cropBottom, + const uint cropLeft) { + batchToSpaceCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, cropBottom, cropLeft); } -BUILD_SINGLE_TEMPLATE(template void batchToSpaceCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint cropBottom, const uint cropLeft), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void batchToSpaceCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint cropBottom, + const uint cropLeft), + LIBND4J_TYPES); /////////////////////////////////////////////////////////////////// -void batchToSpace(sd::LaunchContext* context, const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight, const uint blockSize) { - - // [bS*blockSize*blockSize, H/blockSize, W/blockSize, iC] is rearranged/permuted to [bS, oH, oW, iC] - // oH = H - cropTop - cropBottom - // oW = W - cropLeft - cropRight - - NDArray inputRearranged0 = input.reshape(input.ordering(), {blockSize, blockSize, output.sizeAt(0), input.sizeAt(1), input.sizeAt(2), input.sizeAt(3)}); - inputRearranged0.permutei({2, 3,0, 4,1, 5}); - - if(input.lengthOf() == output.lengthOf()) { - - output.assign(inputRearranged0); - } - else { - - NDArray inputRearranged1 = inputRearranged0.reshape(input.ordering(), {output.sizeAt(0), input.sizeAt(1) * blockSize, input.sizeAt(2) * blockSize, input.sizeAt(3)}); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; - - PointersManager manager(context, "batchToSpace"); - - NDArray::prepareSpecialUse({&output}, {&inputRearranged1}); - BUILD_SINGLE_SELECTOR(input.dataType(), batchToSpaceCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), inputRearranged1.specialBuffer(), inputRearranged1.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), cropBottom, cropLeft), LIBND4J_TYPES); - NDArray::registerSpecialUse({&output}, {&inputRearranged1}); - - manager.synchronize(); - } +void batchToSpace(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const uint cropBottom, const uint cropTop, + const uint cropLeft, const uint cropRight, + const uint blockSize) { + // [bS*blockSize*blockSize, H/blockSize, W/blockSize, iC] is + // rearranged/permuted to [bS, oH, oW, iC] oH = H - cropTop - cropBottom oW = + // W - cropLeft - cropRight + + NDArray inputRearranged0 = input.reshape( + input.ordering(), {blockSize, blockSize, output.sizeAt(0), + input.sizeAt(1), input.sizeAt(2), input.sizeAt(3)}); + inputRearranged0.permutei({2, 3, 0, 4, 1, 5}); + + if (input.lengthOf() == output.lengthOf()) { + output.assign(inputRearranged0); + } else { + NDArray inputRearranged1 = inputRearranged0.reshape( + input.ordering(), {output.sizeAt(0), input.sizeAt(1) * blockSize, + input.sizeAt(2) * blockSize, input.sizeAt(3)}); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; + + PointersManager manager(context, "batchToSpace"); + + NDArray::prepareSpecialUse({&output}, {&inputRearranged1}); + BUILD_SINGLE_SELECTOR( + input.dataType(), batchToSpaceCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + inputRearranged1.specialBuffer(), inputRearranged1.specialShapeInfo(), + output.specialBuffer(), output.specialShapeInfo(), cropBottom, + cropLeft), + LIBND4J_TYPES); + NDArray::registerSpecialUse({&output}, {&inputRearranged1}); + + manager.synchronize(); + } } - - /////////////////////////////////////////////////////////////////// -template -__global__ static void batchToSpaceNDCuda(const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, +template +__global__ static void batchToSpaceNDCuda(const void* vx, + const Nd4jLong* xShapeInfo, + const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint numOfSpatialDims) { + // 4D example, numOfSpatialDims = 2 + // input [bS, H * blockShape[0], W * blockShape[1], iC] + // output [bS, H * blockShape[0] - cropBottom - cropTop, W * blockShape[1] - + // cropLeft - cropRight, iC] - // 4D example, numOfSpatialDims = 2 - // input [bS, H * blockShape[0], W * blockShape[1], iC] - // output [bS, H * blockShape[0] - cropBottom - cropTop, W * blockShape[1] - cropLeft - cropRight, iC] - - // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same - // else: - // oH -> [cropBottom, iH - cropTop] - // oW -> [cropLeft, iH - cropRight] - // xLen >= zLen + // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same + // else: + // oH -> [cropBottom, iH - cropTop] + // oW -> [cropLeft, iH - cropRight] + // xLen >= zLen - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); - __shared__ int rank, *sharedMem; - __shared__ Nd4jLong zLen; + __shared__ int rank, *sharedMem; + __shared__ Nd4jLong zLen; - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - rank = shape::rank(zShapeInfo); - zLen = shape::length(zShapeInfo); - } + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - __syncthreads(); + rank = shape::rank(zShapeInfo); + zLen = shape::length(zShapeInfo); + } - auto coords = sharedMem + threadIdx.x * rank; + __syncthreads(); - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; i += gridDim.x * blockDim.x) { + auto coords = sharedMem + threadIdx.x * rank; - shape::index2coords(i, zShapeInfo, coords); + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; + i += gridDim.x * blockDim.x) { + shape::index2coords(i, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - // evaluate spatial coordinates for x - for(uint j = 1; j <= numOfSpatialDims; ++j) { - const auto yOffset = (j - 1) * yShapeInfo[3]; // yRank = 2, calculate offset manually - coords[j] += y[yOffset]; // add crop left - } + // evaluate spatial coordinates for x + for (uint j = 1; j <= numOfSpatialDims; ++j) { + const auto yOffset = + (j - 1) * yShapeInfo[3]; // yRank = 2, calculate offset manually + coords[j] += y[yOffset]; // add crop left + } - const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto xOffset = shape::getOffset(xShapeInfo, coords); - z[zOffset] = x[xOffset]; - } + z[zOffset] = x[xOffset]; + } } /////////////////////////////////////////////////////////////////// -template -static void batchToSpaceNDCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint numOfSpatialDims) { - - batchToSpaceNDCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, numOfSpatialDims); +template +static void batchToSpaceNDCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint numOfSpatialDims) { + batchToSpaceNDCuda + <<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, numOfSpatialDims); } -BUILD_DOUBLE_TEMPLATE(template void batchToSpaceNDCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint numOfSpatialDims), LIBND4J_TYPES, INTEGER_TYPES); +BUILD_DOUBLE_TEMPLATE(template void batchToSpaceNDCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint numOfSpatialDims), + LIBND4J_TYPES, INTEGER_TYPES); ////////////////////////////////////////////////////////////////////////// -void batchToSpaceND(sd::LaunchContext* context, const NDArray& input, const NDArray& blockShape, const NDArray& crop, NDArray& output) { - - // 4D example, numOfSpatialDims = 2 - two spatial dimensions - // [bS*blockShape[0]*blockShape[1], iH, iW, iC] is rearranged/permuted to [bS, iH*blockShape[0] - cropTop - cropBottom, iW*blockShape[1] - cropLeft - cropRight, iC] - - const uint rank = input.rankOf(); - const uint numOfSpatialDims = blockShape.sizeAt(0); +void batchToSpaceND(sd::LaunchContext* context, const NDArray& input, + const NDArray& blockShape, const NDArray& crop, + NDArray& output) { + // 4D example, numOfSpatialDims = 2 - two spatial dimensions + // [bS*blockShape[0]*blockShape[1], iH, iW, iC] is rearranged/permuted to [bS, + // iH*blockShape[0] - cropTop - cropBottom, iW*blockShape[1] - cropLeft - + // cropRight, iC] - //*** construct reshaping std::vector for first reshape of input array ***// + const uint rank = input.rankOf(); + const uint numOfSpatialDims = blockShape.sizeAt(0); - std::vector temp(numOfSpatialDims + rank); + //*** construct reshaping std::vector for first reshape of input array ***// - int i; - for(i = 0; i < numOfSpatialDims; ++i) - temp[i] = blockShape.e(i); - temp[i++] = output.sizeAt(0); - for(int j = 1; j < rank; ++i, ++j) - temp[i] = input.sizeAt(j); + std::vector temp(numOfSpatialDims + rank); - NDArray inputRearranged0 = input.reshape(input.ordering(), temp); + int i; + for (i = 0; i < numOfSpatialDims; ++i) temp[i] = blockShape.e(i); + temp[i++] = output.sizeAt(0); + for (int j = 1; j < rank; ++i, ++j) temp[i] = input.sizeAt(j); - //*** construct permuting std::vector for permutation of input array ***// + NDArray inputRearranged0 = input.reshape(input.ordering(), temp); - temp[0] = numOfSpatialDims; + //*** construct permuting std::vector for permutation of input array ***// - for(i = 1; i <= numOfSpatialDims; ++i) { - temp[2*i - 1] = numOfSpatialDims + i; - temp[2*i] = i - 1; - } - for(i = 2 * numOfSpatialDims + 1; i < temp.size(); ++i) - temp[i] = i; - - inputRearranged0.permutei(temp); + temp[0] = numOfSpatialDims; + for (i = 1; i <= numOfSpatialDims; ++i) { + temp[2 * i - 1] = numOfSpatialDims + i; + temp[2 * i] = i - 1; + } + for (i = 2 * numOfSpatialDims + 1; i < temp.size(); ++i) temp[i] = i; - if(input.lengthOf() == output.lengthOf()) { + inputRearranged0.permutei(temp); - output.assign(inputRearranged0); - } - else { - //*** construct reshaping std::vector for second reshape of input array ***// + if (input.lengthOf() == output.lengthOf()) { + output.assign(inputRearranged0); + } else { + //*** construct reshaping std::vector for second reshape of input array + //***// - temp.resize(rank); + temp.resize(rank); - temp[0] = output.sizeAt(0); + temp[0] = output.sizeAt(0); - for(i = 1; i < rank; ++i) - temp[i] = (i <= numOfSpatialDims) ? input.sizeAt(i) * blockShape.e(i - 1) : input.sizeAt(i); + for (i = 1; i < rank; ++i) + temp[i] = (i <= numOfSpatialDims) + ? input.sizeAt(i) * blockShape.e(i - 1) + : input.sizeAt(i); - NDArray inputRearranged1 = inputRearranged0.reshape(input.ordering(), temp); + NDArray inputRearranged1 = inputRearranged0.reshape(input.ordering(), temp); - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; - PointersManager manager(context, "batchToSpaceND"); + PointersManager manager(context, "batchToSpaceND"); - NDArray::prepareSpecialUse({&output}, {&inputRearranged1, &crop}); - BUILD_DOUBLE_SELECTOR(input.dataType(), crop.dataType(), batchToSpaceNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), inputRearranged1.specialBuffer(), inputRearranged1.specialShapeInfo(), crop.specialBuffer(), crop.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), numOfSpatialDims), LIBND4J_TYPES, INTEGER_TYPES); - NDArray::registerSpecialUse({&output}, {&inputRearranged1, &crop}); + NDArray::prepareSpecialUse({&output}, {&inputRearranged1, &crop}); + BUILD_DOUBLE_SELECTOR( + input.dataType(), crop.dataType(), batchToSpaceNDCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + inputRearranged1.specialBuffer(), inputRearranged1.specialShapeInfo(), + crop.specialBuffer(), crop.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), numOfSpatialDims), + LIBND4J_TYPES, INTEGER_TYPES); + NDArray::registerSpecialUse({&output}, {&inputRearranged1, &crop}); - manager.synchronize(); - } + manager.synchronize(); + } } - - /////////////////////////////////////////////////////////////////// -template -__global__ static void spaceToBatchCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight) { - - // input [bS, H * blockSize - padBottom - padTop, W * blockSize - padLeft - padRight, iC] - // output [bs, H * blockSize, W * blockSize, iC] - - // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same - // else: - // iH -> [padBottom, oH - padTop] - // iW -> [padLeft, oW - padRight] - // zLen > xLen +template +__global__ static void spaceToBatchCuda(const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const uint padBottom, const uint padTop, + const uint padLeft, + const uint padRight) { + // input [bS, H * blockSize - padBottom - padTop, W * blockSize - padLeft - + // padRight, iC] output [bs, H * blockSize, W * blockSize, iC] - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); + // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same + // else: + // iH -> [padBottom, oH - padTop] + // iW -> [padLeft, oW - padRight] + // zLen > xLen - __shared__ int rank, *sharedMem; - __shared__ Nd4jLong zLen; + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + __shared__ int rank, *sharedMem; + __shared__ Nd4jLong zLen; - rank = shape::rank(zShapeInfo); - zLen = shape::length(zShapeInfo); - } - __syncthreads(); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - auto coords = sharedMem + threadIdx.x * rank; + rank = shape::rank(zShapeInfo); + zLen = shape::length(zShapeInfo); + } + __syncthreads(); - const auto i = blockIdx.x * blockDim.x + threadIdx.x; + auto coords = sharedMem + threadIdx.x * rank; - if(i >= zLen) - return; + const auto i = blockIdx.x * blockDim.x + threadIdx.x; - shape::index2coords(i, zShapeInfo, coords); + if (i >= zLen) return; - const auto zOffset = shape::getOffset(zShapeInfo, coords); + shape::index2coords(i, zShapeInfo, coords); - if(coords[1] >= padBottom && coords[1] < zShapeInfo[2] - padTop && coords[2] >= padLeft && coords[2] < zShapeInfo[3] - padRight) { + const auto zOffset = shape::getOffset(zShapeInfo, coords); - coords[1] -= padBottom; - coords[2] -= padLeft; + if (coords[1] >= padBottom && coords[1] < zShapeInfo[2] - padTop && + coords[2] >= padLeft && coords[2] < zShapeInfo[3] - padRight) { + coords[1] -= padBottom; + coords[2] -= padLeft; - const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto xOffset = shape::getOffset(xShapeInfo, coords); - z[zOffset] = x[xOffset]; - } - else - z[zOffset] = 0.f; + z[zOffset] = x[xOffset]; + } else + z[zOffset] = 0.f; } /////////////////////////////////////////////////////////////////// -template -static void spaceToBatchCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight) { - - spaceToBatchCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, padBottom, padTop, padLeft, padRight); +template +static void spaceToBatchCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, const uint padBottom, + const uint padTop, const uint padLeft, const uint padRight) { + spaceToBatchCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, padBottom, padTop, padLeft, padRight); } -BUILD_SINGLE_TEMPLATE(template void spaceToBatchCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void spaceToBatchCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint padBottom, + const uint padTop, const uint padLeft, + const uint padRight), + LIBND4J_TYPES); /////////////////////////////////////////////////////////////////// -void spaceToBatch(sd::LaunchContext* context, const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight, const uint blockSize) { - - // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC] - - NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), input.sizeAt(3)}, false); - outputRearranged0.permutei({2, 3,0, 4,1, 5}); - - if(input.lengthOf() == output.lengthOf()) { - - outputRearranged0.assign(input); - } - else { - - NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, input.sizeAt(3)}, false); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; - - PointersManager manager(context, "spaceToBatch"); - - NDArray::prepareSpecialUse({&outputRearranged1}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatchCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), outputRearranged1.specialBuffer(), outputRearranged1.specialShapeInfo(), padBottom, padTop, padLeft, padRight), LIBND4J_TYPES); - NDArray::registerSpecialUse({&outputRearranged1}, {&input}); - - manager.synchronize(); - - if(output.specialBuffer() != outputRearranged1.specialBuffer()) - outputRearranged0.assign(outputRearranged1); - } +void spaceToBatch(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const uint padBottom, const uint padTop, + const uint padLeft, const uint padRight, + const uint blockSize) { + // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + + // padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC] + + NDArray outputRearranged0 = + output.reshape(output.ordering(), + {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), + output.sizeAt(2), input.sizeAt(3)}, + false); + outputRearranged0.permutei({2, 3, 0, 4, 1, 5}); + + if (input.lengthOf() == output.lengthOf()) { + outputRearranged0.assign(input); + } else { + NDArray outputRearranged1 = outputRearranged0.reshape( + output.ordering(), + {input.sizeAt(0), output.sizeAt(1) * blockSize, + output.sizeAt(2) * blockSize, input.sizeAt(3)}, + false); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; + + PointersManager manager(context, "spaceToBatch"); + + NDArray::prepareSpecialUse({&outputRearranged1}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), spaceToBatchCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), + outputRearranged1.specialBuffer(), + outputRearranged1.specialShapeInfo(), padBottom, padTop, padLeft, + padRight), + LIBND4J_TYPES); + NDArray::registerSpecialUse({&outputRearranged1}, {&input}); + + manager.synchronize(); + + if (output.specialBuffer() != outputRearranged1.specialBuffer()) + outputRearranged0.assign(outputRearranged1); + } } /////////////////////////////////////////////////////////////////// -template -__global__ static void spaceToBatchNDCuda(const void* vx, const Nd4jLong* xShapeInfo, - const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, +template +__global__ static void spaceToBatchNDCuda(const void* vx, + const Nd4jLong* xShapeInfo, + const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint numOfSpatialDims) { + // x - input, y - padding, z - output - // x - input, y - padding, z - output - - // 4D example - // input [bS, H * blockShape[0] - padBottom - padTop, W * blockShape[1] - padLeft - padRight, iC] - // output [bS, H * blockShape[0], W * blockShape[1], iC] - - // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same - // else: - // iH -> [padBottom, oH - padTop] - // iW -> [padLeft, oW - padRight] - // zLen > xLen + // 4D example + // input [bS, H * blockShape[0] - padBottom - padTop, W * blockShape[1] - + // padLeft - padRight, iC] output [bS, H * blockShape[0], W * blockShape[1], + // iC] - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); + // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same + // else: + // iH -> [padBottom, oH - padTop] + // iW -> [padLeft, oW - padRight] + // zLen > xLen - __shared__ int rank, *sharedMem; // xRank = zRank, yRank = 2; - __shared__ Nd4jLong zLen, totalThreads; + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - rank = shape::rank(zShapeInfo); - zLen = shape::length(zShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } + __shared__ int rank, *sharedMem; // xRank = zRank, yRank = 2; + __shared__ Nd4jLong zLen, totalThreads; - __syncthreads(); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - auto coords = sharedMem + threadIdx.x * rank; + rank = shape::rank(zShapeInfo); + zLen = shape::length(zShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } - for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < zLen; i += totalThreads) { + __syncthreads(); - shape::index2coords(i, zShapeInfo, coords); + auto coords = sharedMem + threadIdx.x * rank; - const auto zOffset = shape::getOffset(zShapeInfo, coords); + for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < zLen; + i += totalThreads) { + shape::index2coords(i, zShapeInfo, coords); - bool within = true; + const auto zOffset = shape::getOffset(zShapeInfo, coords); - for(uint j = 1; j <= numOfSpatialDims; ++j) { + bool within = true; - // yRank = 2, calculate offset manually - const auto yOffset = (j - 1) * yShapeInfo[3]; - const auto padLeft = y[yOffset]; - const auto padRight = y[yOffset + yShapeInfo[4]]; + for (uint j = 1; j <= numOfSpatialDims; ++j) { + // yRank = 2, calculate offset manually + const auto yOffset = (j - 1) * yShapeInfo[3]; + const auto padLeft = y[yOffset]; + const auto padRight = y[yOffset + yShapeInfo[4]]; - within &= (coords[j] >= padLeft && coords[j] < shape::shapeOf(const_cast(zShapeInfo))[j] - padRight); + within &= + (coords[j] >= padLeft && + coords[j] < + shape::shapeOf(const_cast(zShapeInfo))[j] - padRight); - if(!within) - break; + if (!within) break; - coords[j] -= padLeft; // get coordinates for x - } - - if(within) - z[zOffset] = x[shape::getOffset(xShapeInfo, coords)]; - else - z[zOffset] = 0.f; + coords[j] -= padLeft; // get coordinates for x } + + if (within) + z[zOffset] = x[shape::getOffset(xShapeInfo, coords)]; + else + z[zOffset] = 0.f; + } } /////////////////////////////////////////////////////////////////// -template -static void spaceToBatchNDCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint numOfSpatialDims) { - - spaceToBatchNDCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, numOfSpatialDims); +template +static void spaceToBatchNDCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint numOfSpatialDims) { + spaceToBatchNDCuda + <<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, numOfSpatialDims); } -BUILD_DOUBLE_TEMPLATE(template void spaceToBatchNDCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint numOfSpatialDims), LIBND4J_TYPES, INTEGER_TYPES); +BUILD_DOUBLE_TEMPLATE(template void spaceToBatchNDCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, + const void* vy, const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint numOfSpatialDims), + LIBND4J_TYPES, INTEGER_TYPES); ////////////////////////////////////////////////////////////////////////// -void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDArray& blockShape, const NDArray& padding, NDArray& output ) { - - // 4D example with two spatial dimensions - // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockShape[0]*blockShape[1], (iH + padBottom + padTop)/blockShape[0], (iW + padLeft + padRight)/blockShape[1], iC] +void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, + const NDArray& blockShape, const NDArray& padding, + NDArray& output) { + // 4D example with two spatial dimensions + // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockShape[0]*blockShape[1], + // (iH + padBottom + padTop)/blockShape[0], (iW + padLeft + + // padRight)/blockShape[1], iC] - const uint rank = input.rankOf(); + const uint rank = input.rankOf(); - const uint numOfSpatialDims = blockShape.sizeAt(0); + const uint numOfSpatialDims = blockShape.sizeAt(0); - //*** construct reshaping std::vector for first reshape of output array ***// - std::vector temp(numOfSpatialDims + rank); + //*** construct reshaping std::vector for first reshape of output array ***// + std::vector temp(numOfSpatialDims + rank); - int i; - for(i = 0; i < numOfSpatialDims; ++i) - temp[i] = blockShape.e(i); - temp[i++] = input.sizeAt(0); - for(int j = 1; j < rank; ++i, ++j) - temp[i] = output.sizeAt(j); + int i; + for (i = 0; i < numOfSpatialDims; ++i) temp[i] = blockShape.e(i); + temp[i++] = input.sizeAt(0); + for (int j = 1; j < rank; ++i, ++j) temp[i] = output.sizeAt(j); - NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false); + NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false); - //*** construct permuting std::vector for permutation of output array ***// + //*** construct permuting std::vector for permutation of output array ***// - temp[0] = numOfSpatialDims; - - for(i = 1; i <= numOfSpatialDims; ++i) { - temp[2*i - 1] = numOfSpatialDims + i; - temp[2*i] = i - 1; - } - for(i = 2 * numOfSpatialDims + 1; i < temp.size(); ++i) - temp[i] = i; + temp[0] = numOfSpatialDims; - outputRearranged0.permutei(temp); + for (i = 1; i <= numOfSpatialDims; ++i) { + temp[2 * i - 1] = numOfSpatialDims + i; + temp[2 * i] = i - 1; + } + for (i = 2 * numOfSpatialDims + 1; i < temp.size(); ++i) temp[i] = i; - // ****** // + outputRearranged0.permutei(temp); - if(input.lengthOf() == output.lengthOf()) { - outputRearranged0.assign(input); - } - else { + // ****** // - //*** construct reshaping std::vector for second reshape of output array ***// - temp.resize(rank); + if (input.lengthOf() == output.lengthOf()) { + outputRearranged0.assign(input); + } else { + //*** construct reshaping std::vector for second reshape of output array + //***// + temp.resize(rank); - temp[0] = input.sizeAt(0); + temp[0] = input.sizeAt(0); - for(i = 1; i < rank; ++i) - temp[i] = (i <= numOfSpatialDims) ? output.sizeAt(i) * blockShape.e(i - 1) : output.sizeAt(i); + for (i = 1; i < rank; ++i) + temp[i] = (i <= numOfSpatialDims) + ? output.sizeAt(i) * blockShape.e(i - 1) + : output.sizeAt(i); - NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp, false); + NDArray outputRearranged1 = + outputRearranged0.reshape(output.ordering(), temp, false); - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * output.rankOf() + 128; - PointersManager manager(context, "spaceToBatchND"); + PointersManager manager(context, "spaceToBatchND"); - NDArray::prepareSpecialUse({&outputRearranged1}, {&input, &padding}); - BUILD_DOUBLE_SELECTOR(input.dataType(), padding.dataType(), spaceToBatchNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), padding.specialBuffer(), padding.specialShapeInfo(), outputRearranged1.specialBuffer(), outputRearranged1.specialShapeInfo(), numOfSpatialDims), LIBND4J_TYPES, INTEGER_TYPES); - NDArray::registerSpecialUse({&outputRearranged1}, {&input, &padding}); + NDArray::prepareSpecialUse({&outputRearranged1}, {&input, &padding}); + BUILD_DOUBLE_SELECTOR( + input.dataType(), padding.dataType(), spaceToBatchNDCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), + padding.specialBuffer(), padding.specialShapeInfo(), + outputRearranged1.specialBuffer(), + outputRearranged1.specialShapeInfo(), numOfSpatialDims), + LIBND4J_TYPES, INTEGER_TYPES); + NDArray::registerSpecialUse({&outputRearranged1}, {&input, &padding}); - manager.synchronize(); + manager.synchronize(); - if(output.specialBuffer() != outputRearranged1.specialBuffer()) - outputRearranged0.assign(outputRearranged1); - } + if (output.specialBuffer() != outputRearranged1.specialBuffer()) + outputRearranged0.assign(outputRearranged1); + } } - /* template struct SpaceToBatchHelper { template - static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { - for (int batch_pos = 0; batch_pos < batch_shape[0]; ++batch_pos) { - const int space_pos = batch_pos * block_shape[0] + block_offsets[0] - pad_start[0]; - if (space_pos >= 0 && space_pos < space_shape[0]) { - SpaceToBatchHelper::run(ptrSpace + space_pos * space_strides[0], space_shape + 1, space_strides + 1, block_shape + 1, pad_start + 1, block_offsets + 1, ptrBatch, batch_shape + 1, batch_strides + 1); - } else { + static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong +*space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const +Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const +Nd4jLong *batch_strides) { for (int batch_pos = 0; batch_pos < batch_shape[0]; +++batch_pos) { const int space_pos = batch_pos * block_shape[0] + +block_offsets[0] - pad_start[0]; if (space_pos >= 0 && space_pos < +space_shape[0]) { SpaceToBatchHelper::run(ptrSpace + space_pos * +space_strides[0], space_shape + 1, space_strides + 1, block_shape + 1, pad_start ++ 1, block_offsets + 1, ptrBatch, batch_shape + 1, batch_strides + 1); } else { if (!B2S) for (int i = 0; i < batch_strides[0]; i++) ptrBatch[i] = (T) 0.f; @@ -512,24 +603,29 @@ void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDAr template struct SpaceToBatchHelper<0, B2S> { template - static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { - int str = batch_strides[-1]; - for (int i = 0; i < str; i++) - if (B2S) - ptrSpace[i] = ptrBatch[i]; - else - ptrBatch[i] = ptrSpace[i]; + static void run(T *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong +*space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const +Nd4jLong *block_offsets, T *ptrBatch, const Nd4jLong *batch_shape, const +Nd4jLong *batch_strides) { int str = batch_strides[-1]; for (int i = 0; i < str; +i++) if (B2S) ptrSpace[i] = ptrBatch[i]; else ptrBatch[i] = ptrSpace[i]; } }; template - void _execute(sd::LaunchContext * context, void *vptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *vptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { - auto ptrSpace = reinterpret_cast(vptrSpace); - auto ptrBatch = reinterpret_cast(vptrBatch); - SpaceToBatchHelper::run(ptrSpace, space_shape, space_strides, block_shape, pad_start, block_offsets, ptrBatch, batch_shape, batch_strides); + void _execute(sd::LaunchContext * context, void *vptrSpace, const Nd4jLong +*space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const +Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *vptrBatch, const +Nd4jLong *batch_shape, const Nd4jLong *batch_strides) { auto ptrSpace = +reinterpret_cast(vptrSpace); auto ptrBatch = reinterpret_cast(vptrBatch); SpaceToBatchHelper::run(ptrSpace, +space_shape, space_strides, block_shape, pad_start, block_offsets, ptrBatch, +batch_shape, batch_strides); }; - Nd4jStatus _batchToSpace(sd::LaunchContext * context, int internal_block_dims, NDArray *input, NDArray *output, std::vector &internal_input_shape, std::vector &internal_output_shape, Nd4jLong *block_shape, Nd4jLong *crops) { + Nd4jStatus _batchToSpace(sd::LaunchContext * context, int +internal_block_dims, NDArray *input, NDArray *output, std::vector +&internal_input_shape, std::vector &internal_output_shape, Nd4jLong +*block_shape, Nd4jLong *crops) { return Status::OK(); } @@ -542,12 +638,16 @@ void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDAr #define STB_BOOL (0, false),\ (1, true) - BUILD_TRIPLE_TEMPLATE(template void _execute, (sd::LaunchContext * context, void *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides), LIBND4J_TYPES, STB_DIM, STB_BOOL); + BUILD_TRIPLE_TEMPLATE(template void _execute, (sd::LaunchContext * context, +void *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, +const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong +*block_offsets, void *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong +*batch_strides), LIBND4J_TYPES, STB_DIM, STB_BOOL); #undef STB_BOOL #undef STB_DIM */ -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu b/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu index 19a1937ddbe9..1203ce7e5ec6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu @@ -23,86 +23,106 @@ namespace sd { namespace ops { namespace helpers { - template - static _CUDA_G void spaceToDepthKernel( - const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const int block_size, - const bool isNHWC) { - auto input_ptr = reinterpret_cast(vx); - auto output_ptr = reinterpret_cast(vz); - - const int batch_size = shape::sizeAt(xShapeInfo, 0); - const int input_depth = isNHWC ? shape::sizeAt(xShapeInfo, 3) : shape::sizeAt(xShapeInfo, 1); - const int input_height = isNHWC ? shape::sizeAt(xShapeInfo, 1) : shape::sizeAt(xShapeInfo, 2); - const int input_width = isNHWC ? shape::sizeAt(xShapeInfo, 2) : shape::sizeAt(xShapeInfo, 3); - - const int output_depth = isNHWC ? shape::sizeAt(zShapeInfo, 3) : shape::sizeAt(zShapeInfo, 1); - const int output_height = isNHWC ? shape::sizeAt(zShapeInfo, 1) : shape::sizeAt(zShapeInfo, 2); - const int output_width = isNHWC ? shape::sizeAt(zShapeInfo, 2) : shape::sizeAt(zShapeInfo, 3); - - const int input_depth_by_output_height = input_depth * output_height; - - const int output_area = output_width * output_height; - const int output_depth_by_output_area = output_depth * output_area; - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (isNHWC) { - const int total_count = batch_size * input_height * input_width * input_depth; - - for (int inp_idx = tid; inp_idx < total_count; inp_idx += blockDim.x * gridDim.x){ - // inp_idx = d + input_depth * (w + input_width * (h + input_height * b)) - const int d = inp_idx % input_depth; - const int inp_idx2 = inp_idx / input_depth; - const int w = inp_idx2 % input_width; - const int inp_idx3 = inp_idx2 / input_width; - const int h = inp_idx3 % input_height; - const int b = inp_idx3 / input_height; - - const int out_h = h / block_size; - const int offset_h = h % block_size; - const int out_w = w / block_size; - const int offset_w = w % block_size; - const int offset_d = (offset_h * block_size + offset_w) * input_depth; - const int out_d = d + offset_d; - - const int out_idx = out_d + output_depth * (out_w + output_width * (out_h + output_height * b)); - *(output_ptr + out_idx) = *(input_ptr + inp_idx); - } - } else { - const int total_count = batch_size * output_depth_by_output_area; - - for (int inp_idx = tid; inp_idx < total_count; inp_idx += blockDim.x * gridDim.x) { - const int n_iC_oY_bY_oX = inp_idx / block_size; - const int bX = inp_idx - n_iC_oY_bY_oX * block_size; - - const int n_iC_oY_bY = n_iC_oY_bY_oX / output_width; - const int oX = n_iC_oY_bY_oX - n_iC_oY_bY * output_width; - - const int n_iC_oY = n_iC_oY_bY / block_size; - const int bY = n_iC_oY_bY - n_iC_oY * block_size; - - const int n = n_iC_oY / input_depth_by_output_height; - const int iC_oY = n_iC_oY - n * input_depth_by_output_height; - - const int output_idx = oX + (((n * block_size + bY) * block_size + bX) * input_depth_by_output_height + iC_oY) * output_width; - - *(output_ptr + output_idx) = *(input_ptr + inp_idx); - } - } +template +static _CUDA_G void spaceToDepthKernel(const void *vx, + const Nd4jLong *xShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, + const int block_size, + const bool isNHWC) { + auto input_ptr = reinterpret_cast(vx); + auto output_ptr = reinterpret_cast(vz); + + const int batch_size = shape::sizeAt(xShapeInfo, 0); + const int input_depth = + isNHWC ? shape::sizeAt(xShapeInfo, 3) : shape::sizeAt(xShapeInfo, 1); + const int input_height = + isNHWC ? shape::sizeAt(xShapeInfo, 1) : shape::sizeAt(xShapeInfo, 2); + const int input_width = + isNHWC ? shape::sizeAt(xShapeInfo, 2) : shape::sizeAt(xShapeInfo, 3); + + const int output_depth = + isNHWC ? shape::sizeAt(zShapeInfo, 3) : shape::sizeAt(zShapeInfo, 1); + const int output_height = + isNHWC ? shape::sizeAt(zShapeInfo, 1) : shape::sizeAt(zShapeInfo, 2); + const int output_width = + isNHWC ? shape::sizeAt(zShapeInfo, 2) : shape::sizeAt(zShapeInfo, 3); + + const int input_depth_by_output_height = input_depth * output_height; + + const int output_area = output_width * output_height; + const int output_depth_by_output_area = output_depth * output_area; + + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + + if (isNHWC) { + const int total_count = + batch_size * input_height * input_width * input_depth; + + for (int inp_idx = tid; inp_idx < total_count; + inp_idx += blockDim.x * gridDim.x) { + // inp_idx = d + input_depth * (w + input_width * (h + input_height * b)) + const int d = inp_idx % input_depth; + const int inp_idx2 = inp_idx / input_depth; + const int w = inp_idx2 % input_width; + const int inp_idx3 = inp_idx2 / input_width; + const int h = inp_idx3 % input_height; + const int b = inp_idx3 / input_height; + + const int out_h = h / block_size; + const int offset_h = h % block_size; + const int out_w = w / block_size; + const int offset_w = w % block_size; + const int offset_d = (offset_h * block_size + offset_w) * input_depth; + const int out_d = d + offset_d; + + const int out_idx = + out_d + + output_depth * (out_w + output_width * (out_h + output_height * b)); + *(output_ptr + out_idx) = *(input_ptr + inp_idx); } + } else { + const int total_count = batch_size * output_depth_by_output_area; - template - static void _spaceTodepth_(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - spaceToDepthKernel<<<512, 512, 1024, *context->getCudaStream()>>>(input.specialBuffer(), input.specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC); - } + for (int inp_idx = tid; inp_idx < total_count; + inp_idx += blockDim.x * gridDim.x) { + const int n_iC_oY_bY_oX = inp_idx / block_size; + const int bX = inp_idx - n_iC_oY_bY_oX * block_size; + + const int n_iC_oY_bY = n_iC_oY_bY_oX / output_width; + const int oX = n_iC_oY_bY_oX - n_iC_oY_bY * output_width; + + const int n_iC_oY = n_iC_oY_bY / block_size; + const int bY = n_iC_oY_bY - n_iC_oY * block_size; + + const int n = n_iC_oY / input_depth_by_output_height; + const int iC_oY = n_iC_oY - n * input_depth_by_output_height; - void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { - NDArray::prepareSpecialUse({output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, (context, input, output, block_size, isNHWC), LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {&input}); + const int output_idx = oX + (((n * block_size + bY) * block_size + bX) * + input_depth_by_output_height + + iC_oY) * + output_width; + + *(output_ptr + output_idx) = *(input_ptr + inp_idx); } + } +} + +template +static void _spaceTodepth_(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC) { + spaceToDepthKernel<<<512, 512, 1024, *context->getCudaStream()>>>( + input.specialBuffer(), input.specialShapeInfo(), output->specialBuffer(), + output->specialShapeInfo(), block_size, isNHWC); } + +void _spaceTodepth(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC) { + NDArray::prepareSpecialUse({output}, {&input}); + BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, + (context, input, output, block_size, isNHWC), + LIBND4J_TYPES); + NDArray::registerSpecialUse({output}, {&input}); } -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu b/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu index 94b0e008027a..7bdc60484fd8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu @@ -19,698 +19,773 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include -#include #include #include #include +#include +#include +#include -namespace sd { -namespace ops { -namespace helpers { +#include +namespace sd { +namespace ops { +namespace helpers { /////////////////////////////////////////////////////////////////// // x - indices, y - contains number of bad indices, z - input/output -template -__global__ static void checkIndicesCuda(const void *vx, const Nd4jLong *xShapeInfo, Nd4jLong* y, const Nd4jLong *zShapeInfo, const int axis) { - - const auto x = reinterpret_cast(vx); - - __shared__ int xRank, *coords, xLastDim; - __shared__ Nd4jLong xLen, numOfBadIndxPerBlock; +template +__global__ static void checkIndicesCuda(const void *vx, + const Nd4jLong *xShapeInfo, Nd4jLong *y, + const Nd4jLong *zShapeInfo, + const int axis) { + const auto x = reinterpret_cast(vx); - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); + __shared__ int xRank, *coords, xLastDim; + __shared__ Nd4jLong xLen, numOfBadIndxPerBlock; - xRank = shape::rank(xShapeInfo); - xLen = shape::length(xShapeInfo); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); - numOfBadIndxPerBlock = 0; - } - __syncthreads(); + xRank = shape::rank(xShapeInfo); + xLen = shape::length(xShapeInfo); - auto xCoords = coords + threadIdx.x * xRank; + numOfBadIndxPerBlock = 0; + } + __syncthreads(); - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + auto xCoords = coords + threadIdx.x * xRank; - shape::index2coords(i, xShapeInfo, xCoords); + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + shape::index2coords(i, xShapeInfo, xCoords); - const Nd4jLong currentInd = x[shape::getOffset(xShapeInfo, xCoords)]; + const Nd4jLong currentInd = x[shape::getOffset(xShapeInfo, xCoords)]; - if(currentInd >= shape::sizeAt(zShapeInfo, axis == -1 ? xCoords[xRank-1] : axis)) { - printf("checkIndices cuda: out of range element %lld at index %lld \n", currentInd, i); - sd::math::atomics::nd4j_atomicAdd(&numOfBadIndxPerBlock, 1); - } + if (currentInd >= + shape::sizeAt(zShapeInfo, axis == -1 ? xCoords[xRank - 1] : axis)) { + printf("checkIndices cuda: out of range element %lld at index %lld \n", + currentInd, i); + sd::math::atomics::nd4j_atomicAdd(&numOfBadIndxPerBlock, 1); } - __syncthreads(); + } + __syncthreads(); - if (threadIdx.x == 0 && numOfBadIndxPerBlock != 0) - sd::math::atomics::nd4j_atomicAdd(y, numOfBadIndxPerBlock); + if (threadIdx.x == 0 && numOfBadIndxPerBlock != 0) + sd::math::atomics::nd4j_atomicAdd(y, numOfBadIndxPerBlock); } /////////////////////////////////////////////////////////////////// -template -static void checkIndicesCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, Nd4jLong* y, const Nd4jLong *zShapeInfo, const int axis) { - - checkIndicesCuda<<>>(vx, xShapeInfo, y, zShapeInfo, axis); +template +static void checkIndicesCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, + Nd4jLong *y, const Nd4jLong *zShapeInfo, const int axis) { + checkIndicesCuda<<>>( + vx, xShapeInfo, y, zShapeInfo, axis); } - /////////////////////////////////////////////////////////////////// -Nd4jLong checkIndices(sd::LaunchContext *context, const NDArray& indices, const NDArray& output, const int axis) { - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (indices.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * indices.rankOf() + 256; +Nd4jLong checkIndices(sd::LaunchContext *context, const NDArray &indices, + const NDArray &output, const int axis) { + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (indices.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * indices.rankOf() + 256; - const auto xType = indices.dataType(); + const auto xType = indices.dataType(); - PointersManager manager(context, "scatterNDcheckIndices"); + PointersManager manager(context, "scatterNDcheckIndices"); - // scalar, initial value = 0 - NDArray numOfBadIndx(sd::DataType::INT64, context, true); + // scalar, initial value = 0 + NDArray numOfBadIndx(sd::DataType::INT64, context, true); - NDArray::prepareSpecialUse({&numOfBadIndx}, {&indices}); - BUILD_SINGLE_SELECTOR(xType, checkIndicesCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.specialBuffer(), indices.specialShapeInfo(), reinterpret_cast(numOfBadIndx.specialBuffer()), output.specialShapeInfo(), axis), INDEXING_TYPES); - NDArray::registerSpecialUse({&numOfBadIndx}, {&indices}); + NDArray::prepareSpecialUse({&numOfBadIndx}, {&indices}); + BUILD_SINGLE_SELECTOR( + xType, checkIndicesCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + indices.specialBuffer(), indices.specialShapeInfo(), + reinterpret_cast(numOfBadIndx.specialBuffer()), + output.specialShapeInfo(), axis), + INDEXING_TYPES); + NDArray::registerSpecialUse({&numOfBadIndx}, {&indices}); - manager.synchronize(); + manager.synchronize(); - return numOfBadIndx.t(0); + return numOfBadIndx.t(0); } /////////////////////////////////////////////////////////////////// // x - indices, y - updates, z - input/output -template -__global__ static void scatterLockCuda(const int opCode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ int xRank, yRank, zRank, xNonUnitDim, yNonUnitDim, zNonUnitDim, *coords; - __shared__ Nd4jLong xLen, zLen; - __shared__ bool is1Dcase, xySameStride; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); - - xLen = shape::length(xShapeInfo); - zLen = shape::length(zShapeInfo); - - xRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - zRank = shape::rank(zShapeInfo); - - xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; - - is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || shape::isScalar(zShapeInfo)) && (shape::isCommonVector(yShapeInfo, yNonUnitDim) || shape::isScalar(yShapeInfo)) && (shape::isCommonVector(xShapeInfo, xNonUnitDim) || shape::isScalar(xShapeInfo)); - - if(is1Dcase) - xySameStride = shape::stride(xShapeInfo)[xNonUnitDim] = shape::stride(yShapeInfo)[yNonUnitDim]; +template +__global__ static void scatterLockCuda(const int opCode, const void *vx, + const Nd4jLong *xShapeInfo, + const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ int xRank, yRank, zRank, xNonUnitDim, yNonUnitDim, zNonUnitDim, + *coords; + __shared__ Nd4jLong xLen, zLen; + __shared__ bool is1Dcase, xySameStride; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + + xLen = shape::length(xShapeInfo); + zLen = shape::length(zShapeInfo); + + xRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); + zRank = shape::rank(zShapeInfo); + + xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; + + is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || + shape::isScalar(zShapeInfo)) && + (shape::isCommonVector(yShapeInfo, yNonUnitDim) || + shape::isScalar(yShapeInfo)) && + (shape::isCommonVector(xShapeInfo, xNonUnitDim) || + shape::isScalar(xShapeInfo)); + + if (is1Dcase) + xySameStride = shape::stride(xShapeInfo)[xNonUnitDim] = + shape::stride(yShapeInfo)[yNonUnitDim]; + } + __syncthreads(); + + Nd4jLong yOffset, zOffset; + int zFirstCoord, *yCoords, *zCoords; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; + i += gridDim.x * blockDim.x) { + if (!is1Dcase) { + yCoords = coords + threadIdx.x * (yRank + zRank); + zCoords = yCoords + yRank; + shape::index2coords(i, zShapeInfo, zCoords); } - __syncthreads(); - - - Nd4jLong yOffset, zOffset; - int zFirstCoord, *yCoords, *zCoords; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; i += gridDim.x * blockDim.x) { - - if(!is1Dcase) { - - yCoords = coords + threadIdx.x * (yRank + zRank); - zCoords = yCoords + yRank; - shape::index2coords(i, zShapeInfo, zCoords); - } - - for (Nd4jLong j = 0; j < xLen; ++j) { - - if(is1Dcase) { - - yOffset = j * shape::stride(yShapeInfo)[yNonUnitDim]; - zFirstCoord = x[xySameStride ? yOffset : j * shape::stride(xShapeInfo)[xNonUnitDim]]; - - if(i != zFirstCoord) - continue; - - zOffset = i * shape::stride(zShapeInfo)[zNonUnitDim]; - } - else { - - shape::index2coords(j, xShapeInfo, yCoords); // first xRank coordinates in yCoords are the same for y and x - - zFirstCoord = x[shape::getOffset(xShapeInfo, yCoords)]; - - if(zCoords[0] != zFirstCoord) - continue; - - for (uint k = 0; k < yRank - xRank; ++k) - yCoords[xRank + k] = zCoords[k + 1]; - - yOffset = shape::getOffset(yShapeInfo, yCoords); - zOffset = shape::getOffset(zShapeInfo, zCoords); - } - - switch (opCode) { - case pairwise::Add: - z[zOffset] += y[yOffset]; - break; - case pairwise::Subtract: - z[zOffset] -= y[yOffset]; - break; - case pairwise::Multiply: - z[zOffset] *= y[yOffset]; - break; - case pairwise::Divide: - z[zOffset] /= y[yOffset]; - break; - case pairwise::ReverseSubtract: - z[zOffset] = y[yOffset] - z[zOffset]; - break; - case pairwise::ReverseDivide: - z[zOffset] = y[yOffset] / z[zOffset]; - break; - case pairwise::CopyPws: - z[zOffset] = y[yOffset]; - break; - case pairwise::MaxPairwise: - if(z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; - break; - case pairwise::MinPairwise: - if(z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; - break; - default: - continue; - } - } + for (Nd4jLong j = 0; j < xLen; ++j) { + if (is1Dcase) { + yOffset = j * shape::stride(yShapeInfo)[yNonUnitDim]; + zFirstCoord = + x[xySameStride ? yOffset + : j * shape::stride(xShapeInfo)[xNonUnitDim]]; + + if (i != zFirstCoord) continue; + + zOffset = i * shape::stride(zShapeInfo)[zNonUnitDim]; + } + + else { + shape::index2coords(j, xShapeInfo, + yCoords); // first xRank coordinates in yCoords are + // the same for y and x + + zFirstCoord = x[shape::getOffset(xShapeInfo, yCoords)]; + + if (zCoords[0] != zFirstCoord) continue; + + for (uint k = 0; k < yRank - xRank; ++k) + yCoords[xRank + k] = zCoords[k + 1]; + + yOffset = shape::getOffset(yShapeInfo, yCoords); + zOffset = shape::getOffset(zShapeInfo, zCoords); + } + + switch (opCode) { + case pairwise::Add: + z[zOffset] += y[yOffset]; + break; + case pairwise::Subtract: + z[zOffset] -= y[yOffset]; + break; + case pairwise::Multiply: + z[zOffset] *= y[yOffset]; + break; + case pairwise::Divide: + z[zOffset] /= y[yOffset]; + break; + case pairwise::ReverseSubtract: + z[zOffset] = y[yOffset] - z[zOffset]; + break; + case pairwise::ReverseDivide: + z[zOffset] = y[yOffset] / z[zOffset]; + break; + case pairwise::CopyPws: + z[zOffset] = y[yOffset]; + break; + case pairwise::MaxPairwise: + if (z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; + break; + case pairwise::MinPairwise: + if (z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; + break; + default: + continue; + } } + } } - /////////////////////////////////////////////////////////////////// // x - indices, y - updates, z - input/output -template -__global__ static void scatterCuda(const int opCode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ int xRank, yRank, zRank, xNonUnitDim, yNonUnitDim, zNonUnitDim, *coords; - __shared__ Nd4jLong yLen; - __shared__ bool is1Dcase, xySameStride; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); - - yLen = shape::length(yShapeInfo); - - xRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - zRank = shape::rank(zShapeInfo); - - xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; - - is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || shape::isScalar(zShapeInfo)) && (shape::isCommonVector(yShapeInfo, yNonUnitDim) || shape::isScalar(yShapeInfo)) && (shape::isCommonVector(xShapeInfo, xNonUnitDim) || shape::isScalar(xShapeInfo)); - - if(is1Dcase) - xySameStride = shape::stride(xShapeInfo)[xNonUnitDim] = shape::stride(yShapeInfo)[yNonUnitDim]; +template +__global__ static void scatterCuda(const int opCode, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ int xRank, yRank, zRank, xNonUnitDim, yNonUnitDim, zNonUnitDim, + *coords; + __shared__ Nd4jLong yLen; + __shared__ bool is1Dcase, xySameStride; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + + yLen = shape::length(yShapeInfo); + + xRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); + zRank = shape::rank(zShapeInfo); + + xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; + + is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || + shape::isScalar(zShapeInfo)) && + (shape::isCommonVector(yShapeInfo, yNonUnitDim) || + shape::isScalar(yShapeInfo)) && + (shape::isCommonVector(xShapeInfo, xNonUnitDim) || + shape::isScalar(xShapeInfo)); + + if (is1Dcase) + xySameStride = shape::stride(xShapeInfo)[xNonUnitDim] = + shape::stride(yShapeInfo)[yNonUnitDim]; + } + __syncthreads(); + + Nd4jLong xOffset, yOffset, zOffset; + int *yCoords, *zCoords; + + if (!is1Dcase) { + yCoords = coords + threadIdx.x * (yRank + zRank); + zCoords = yCoords + yRank; + } + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < yLen; + i += gridDim.x * blockDim.x) { + if (is1Dcase) { + yOffset = i * shape::stride(yShapeInfo)[yNonUnitDim]; + zOffset = x[xySameStride ? yOffset + : i * shape::stride(xShapeInfo)[xNonUnitDim]] * + shape::stride(zShapeInfo)[zNonUnitDim]; + } else { + shape::index2coords(i, yShapeInfo, yCoords); + + yOffset = shape::getOffset(yShapeInfo, yCoords); + xOffset = shape::getOffset( + xShapeInfo, yCoords); // first xRank coordinates in yCoords are the + // same for y and x -> for (uint j = 0; j < + // xRank; ++j) xCoords[j] = yCoords[j]; + + zCoords[0] = x[xOffset]; + + for (uint j = 0; j < yRank - xRank; ++j) + zCoords[j + 1] = yCoords[xRank + j]; + + zOffset = shape::getOffset(zShapeInfo, zCoords); } - __syncthreads(); - - - Nd4jLong xOffset, yOffset, zOffset; - int *yCoords, *zCoords; - if(!is1Dcase) { - yCoords = coords + threadIdx.x * (yRank + zRank); - zCoords = yCoords + yRank; - } - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < yLen; i += gridDim.x * blockDim.x) { - - if(is1Dcase) { - - yOffset = i * shape::stride(yShapeInfo)[yNonUnitDim]; - zOffset = x[xySameStride ? yOffset : i * shape::stride(xShapeInfo)[xNonUnitDim]] * shape::stride(zShapeInfo)[zNonUnitDim]; - } - else { - shape::index2coords(i, yShapeInfo, yCoords); - - yOffset = shape::getOffset(yShapeInfo, yCoords); - xOffset = shape::getOffset(xShapeInfo, yCoords); // first xRank coordinates in yCoords are the same for y and x -> for (uint j = 0; j < xRank; ++j) xCoords[j] = yCoords[j]; - - zCoords[0] = x[xOffset]; - - for (uint j = 0; j < yRank - xRank; ++j) - zCoords[j + 1] = yCoords[xRank + j]; - - zOffset = shape::getOffset(zShapeInfo, zCoords); - } - - switch (opCode) { - case pairwise::Add: - z[zOffset] += y[yOffset]; - break; - case pairwise::Subtract: - z[zOffset] -= y[yOffset]; - break; - case pairwise::Multiply: - z[zOffset] *= y[yOffset]; - break; - case pairwise::Divide: - z[zOffset] /= y[yOffset]; - break; - case pairwise::ReverseSubtract: - z[zOffset] = y[yOffset] - z[zOffset]; - break; - case pairwise::ReverseDivide: - z[zOffset] = y[yOffset] / z[zOffset]; - break; - case pairwise::CopyPws: - z[zOffset] = y[yOffset]; - break; - case pairwise::MaxPairwise: - if(z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; - break; - case pairwise::MinPairwise: - if(z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; - break; - default: - continue; - } + switch (opCode) { + case pairwise::Add: + z[zOffset] += y[yOffset]; + break; + case pairwise::Subtract: + z[zOffset] -= y[yOffset]; + break; + case pairwise::Multiply: + z[zOffset] *= y[yOffset]; + break; + case pairwise::Divide: + z[zOffset] /= y[yOffset]; + break; + case pairwise::ReverseSubtract: + z[zOffset] = y[yOffset] - z[zOffset]; + break; + case pairwise::ReverseDivide: + z[zOffset] = y[yOffset] / z[zOffset]; + break; + case pairwise::CopyPws: + z[zOffset] = y[yOffset]; + break; + case pairwise::MaxPairwise: + if (z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; + break; + case pairwise::MinPairwise: + if (z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; + break; + default: + continue; } + } } /////////////////////////////////////////////////////////////////// -template -static void scatterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const int opCode, +template +static void scatterCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const int opCode, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, const bool lock) { - - if(lock) - scatterLockCuda<<>>(opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); - else - scatterCuda<<>>(opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); + if (lock) + scatterLockCuda + <<>>( + opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); + else + scatterCuda<<>>( + opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } - /////////////////////////////////////////////////////////////////// -void scatter(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { - - const auto xType = indices.dataType(); - const auto yType = updates.dataType(); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = ((lock ? output.lengthOf() : updates.lengthOf()) + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = sizeof(int) * threadsPerBlock * (updates.rankOf() + output.rankOf()) + 256; - - PointersManager manager(context, "scatter"); - - NDArray::prepareSpecialUse({&output}, {&updates, &indices}); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.specialBuffer(), indices.specialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), lock), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); - NDArray::registerSpecialUse({&output}, {&updates, &indices}); - - manager.synchronize(); +void scatter(sd::LaunchContext *context, pairwise::Ops op, + const NDArray &indices, const NDArray &updates, NDArray &output, + const bool lock) { + const auto xType = indices.dataType(); + const auto yType = updates.dataType(); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + ((lock ? output.lengthOf() : updates.lengthOf()) + threadsPerBlock - 1) / + threadsPerBlock; + const int sharedMem = + sizeof(int) * threadsPerBlock * (updates.rankOf() + output.rankOf()) + + 256; + + PointersManager manager(context, "scatter"); + + NDArray::prepareSpecialUse({&output}, {&updates, &indices}); + BUILD_DOUBLE_SELECTOR( + xType, yType, scatterCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, + indices.specialBuffer(), indices.specialShapeInfo(), + updates.specialBuffer(), updates.specialShapeInfo(), + output.specialBuffer(), output.specialShapeInfo(), lock), + INDEXING_TYPES, GENERIC_NUMERIC_TYPES); + NDArray::registerSpecialUse({&output}, {&updates, &indices}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// // x - indices, y - updates, z - output -template -__global__ static void scatterNDLockCuda(const int opCode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ int xRank, yRank, zRank, biggerXYRank, xLastDim, *coords, xNonUnitDim, yNonUnitDim, zNonUnitDim; - __shared__ Nd4jLong zLen, len; - __shared__ bool is1Dcase; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); - - xRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - zRank = shape::rank(zShapeInfo); - xLastDim = shape::sizeAt(xShapeInfo, -1); - - biggerXYRank = xRank > yRank ? xRank : yRank; - - xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; - - is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || shape::isScalar(zShapeInfo)) && (shape::isCommonVector(yShapeInfo, yNonUnitDim) || shape::isScalar(yShapeInfo)) && (shape::isCommonVector(xShapeInfo, xNonUnitDim) || shape::isScalar(xShapeInfo)); - - len = is1Dcase ? shape::length(xShapeInfo) : shape::length(xShapeInfo) / xLastDim; - zLen = shape::length(zShapeInfo); - } - __syncthreads(); - - Nd4jLong yOffset, zOffset, xOffset; - int *yCoords, *zCoords; - - if(!is1Dcase) { - yCoords = coords + threadIdx.x * (biggerXYRank + zRank); - zCoords = yCoords + biggerXYRank; - } - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; i += gridDim.x * blockDim.x) { - - if(!is1Dcase) - shape::index2coords(i, zShapeInfo, zCoords); - - for (Nd4jLong j = 0; j < len; ++j) { // if !is1Dcase then we loop through first xRank-1 dimensions of x, that is we exclude last x dimension - - if(is1Dcase) { - - if(x[j * shape::stride(xShapeInfo)[xNonUnitDim]] != i) - continue; - - yOffset = j * shape::stride(yShapeInfo)[yNonUnitDim]; - zOffset = i * shape::stride(zShapeInfo)[zNonUnitDim]; - } - else { - - shape::index2coords(j, xRank-1, shape::shapeOf(const_cast(xShapeInfo)), yCoords); // first xRank-1 coordinates in yCoords are the same for y and x - - // first iteration - yCoords[xRank - 1] = 0; - xOffset = shape::getOffset(xShapeInfo, yCoords); - if(zCoords[0] != x[xOffset]) - continue; - - // rest iterations - bool matched = true; - for (uint k = 1; k < xLastDim; ++k) { - yCoords[xRank - 1] = k; - xOffset += shape::stride(xShapeInfo)[xRank-1]; - if(zCoords[k] != x[xOffset]) { - matched = false; - break; - } - } - - if(!matched) - continue; - - for (uint k = xLastDim; k < zRank; ++k) - yCoords[yRank - zRank + k] = zCoords[k]; - - yOffset = shape::getOffset(yShapeInfo, yCoords); - zOffset = shape::getOffset(zShapeInfo, zCoords); - } - - switch (opCode) { - case pairwise::Add: - z[zOffset] += y[yOffset]; - break; - case pairwise::Subtract: - z[zOffset] -= y[yOffset]; - break; - case pairwise::Multiply: - z[zOffset] *= y[yOffset]; - break; - case pairwise::Divide: - z[zOffset] /= y[yOffset]; - break; - case pairwise::ReverseSubtract: - z[zOffset] = y[yOffset] - z[zOffset]; - break; - case pairwise::ReverseDivide: - z[zOffset] = y[yOffset] / z[zOffset]; - break; - case pairwise::CopyPws: - z[zOffset] = y[yOffset]; - break; - case pairwise::MaxPairwise: - if(z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; - break; - case pairwise::MinPairwise: - if(z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; - break; - default: - continue; - } +template +__global__ static void scatterNDLockCuda(const int opCode, const void *vx, + const Nd4jLong *xShapeInfo, + const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ int xRank, yRank, zRank, biggerXYRank, xLastDim, *coords, + xNonUnitDim, yNonUnitDim, zNonUnitDim; + __shared__ Nd4jLong zLen, len; + __shared__ bool is1Dcase; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + + xRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); + zRank = shape::rank(zShapeInfo); + xLastDim = shape::sizeAt(xShapeInfo, -1); + + biggerXYRank = xRank > yRank ? xRank : yRank; + + xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; + + is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || + shape::isScalar(zShapeInfo)) && + (shape::isCommonVector(yShapeInfo, yNonUnitDim) || + shape::isScalar(yShapeInfo)) && + (shape::isCommonVector(xShapeInfo, xNonUnitDim) || + shape::isScalar(xShapeInfo)); + + len = is1Dcase ? shape::length(xShapeInfo) + : shape::length(xShapeInfo) / xLastDim; + zLen = shape::length(zShapeInfo); + } + __syncthreads(); + + Nd4jLong yOffset, zOffset, xOffset; + int *yCoords, *zCoords; + + if (!is1Dcase) { + yCoords = coords + threadIdx.x * (biggerXYRank + zRank); + zCoords = yCoords + biggerXYRank; + } + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; + i += gridDim.x * blockDim.x) { + if (!is1Dcase) shape::index2coords(i, zShapeInfo, zCoords); + + for (Nd4jLong j = 0; j < len; + ++j) { // if !is1Dcase then we loop through first xRank-1 dimensions + // of x, that is we exclude last x dimension + + if (is1Dcase) { + if (x[j * shape::stride(xShapeInfo)[xNonUnitDim]] != i) continue; + + yOffset = j * shape::stride(yShapeInfo)[yNonUnitDim]; + zOffset = i * shape::stride(zShapeInfo)[zNonUnitDim]; + } else { + shape::index2coords(j, xRank - 1, + shape::shapeOf(const_cast(xShapeInfo)), + yCoords); // first xRank-1 coordinates in yCoords + // are the same for y and x + + // first iteration + yCoords[xRank - 1] = 0; + xOffset = shape::getOffset(xShapeInfo, yCoords); + if (zCoords[0] != x[xOffset]) continue; + + // rest iterations + bool matched = true; + for (uint k = 1; k < xLastDim; ++k) { + yCoords[xRank - 1] = k; + xOffset += shape::stride(xShapeInfo)[xRank - 1]; + if (zCoords[k] != x[xOffset]) { + matched = false; + break; + } } + + if (!matched) continue; + + for (uint k = xLastDim; k < zRank; ++k) + yCoords[yRank - zRank + k] = zCoords[k]; + + yOffset = shape::getOffset(yShapeInfo, yCoords); + zOffset = shape::getOffset(zShapeInfo, zCoords); + } + + switch (opCode) { + case pairwise::Add: + z[zOffset] += y[yOffset]; + break; + case pairwise::Subtract: + z[zOffset] -= y[yOffset]; + break; + case pairwise::Multiply: + z[zOffset] *= y[yOffset]; + break; + case pairwise::Divide: + z[zOffset] /= y[yOffset]; + break; + case pairwise::ReverseSubtract: + z[zOffset] = y[yOffset] - z[zOffset]; + break; + case pairwise::ReverseDivide: + z[zOffset] = y[yOffset] / z[zOffset]; + break; + case pairwise::CopyPws: + z[zOffset] = y[yOffset]; + break; + case pairwise::MaxPairwise: + if (z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; + break; + case pairwise::MinPairwise: + if (z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; + break; + default: + continue; + } } + } } /////////////////////////////////////////////////////////////////// // x - indices, y - updates, z - output -template -__global__ static void scatterNDCuda(const int opCode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ int xRank, yRank, zRank, biggerXYRank, xLastDim, *coords, xNonUnitDim, yNonUnitDim, zNonUnitDim; - __shared__ Nd4jLong yLen; - __shared__ bool is1Dcase; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - coords = reinterpret_cast(shmem); - - yLen = shape::length(yShapeInfo); - xRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - zRank = shape::rank(zShapeInfo); - xLastDim = shape::sizeAt(xShapeInfo, -1); - - biggerXYRank = xRank > yRank ? xRank : yRank; - - xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; - - is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || shape::isScalar(zShapeInfo)) && (shape::isCommonVector(yShapeInfo, yNonUnitDim) || shape::isScalar(yShapeInfo)) && (shape::isCommonVector(xShapeInfo, xNonUnitDim) || shape::isScalar(xShapeInfo)); +template +__global__ static void scatterNDCuda(const int opCode, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, + const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ int xRank, yRank, zRank, biggerXYRank, xLastDim, *coords, + xNonUnitDim, yNonUnitDim, zNonUnitDim; + __shared__ Nd4jLong yLen; + __shared__ bool is1Dcase; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + + yLen = shape::length(yShapeInfo); + xRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); + zRank = shape::rank(zShapeInfo); + xLastDim = shape::sizeAt(xShapeInfo, -1); + + biggerXYRank = xRank > yRank ? xRank : yRank; + + xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; + + is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || + shape::isScalar(zShapeInfo)) && + (shape::isCommonVector(yShapeInfo, yNonUnitDim) || + shape::isScalar(yShapeInfo)) && + (shape::isCommonVector(xShapeInfo, xNonUnitDim) || + shape::isScalar(xShapeInfo)); + } + __syncthreads(); + + Nd4jLong yOffset, zOffset; + int *yCoords, *zCoords; + + if (!is1Dcase) { + yCoords = coords + threadIdx.x * (biggerXYRank + zRank); + zCoords = yCoords + biggerXYRank; + } + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < yLen; + i += gridDim.x * blockDim.x) { + if (is1Dcase) { + yOffset = i * shape::stride(yShapeInfo)[zNonUnitDim]; + zOffset = x[i * shape::stride(xShapeInfo)[xNonUnitDim]] * + shape::stride(zShapeInfo)[zNonUnitDim]; + } else { + shape::index2coords(i, yShapeInfo, yCoords); + + yOffset = shape::getOffset(yShapeInfo, yCoords); + + if (yRank >= xRank) + zCoords[xLastDim] = + yCoords[xRank - 1]; // saving y coordinate, since it might be + // changed in next instructions + + for (uint j = 0; j < xLastDim; ++j) { // first xRank-1 coordinates in + // yCoords are the same for y and x + yCoords[xRank - 1] = j; + zCoords[j] = x[shape::getOffset(xShapeInfo, yCoords)]; + } + + for (uint j = xLastDim + 1; j < zRank; ++j) + zCoords[j] = yCoords[yRank - zRank + j]; + + zOffset = shape::getOffset(zShapeInfo, zCoords); } - __syncthreads(); - - Nd4jLong yOffset, zOffset; - int *yCoords, *zCoords; - - if(!is1Dcase) { - yCoords = coords + threadIdx.x * (biggerXYRank + zRank); - zCoords = yCoords + biggerXYRank; - } - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < yLen; i += gridDim.x * blockDim.x) { - - if(is1Dcase) { - - yOffset = i * shape::stride(yShapeInfo)[zNonUnitDim]; - zOffset = x[i * shape::stride(xShapeInfo)[xNonUnitDim]] * shape::stride(zShapeInfo)[zNonUnitDim]; - } - else { - - shape::index2coords(i, yShapeInfo, yCoords); - - yOffset = shape::getOffset(yShapeInfo, yCoords); - - if(yRank >= xRank) - zCoords[xLastDim] = yCoords[xRank - 1]; // saving y coordinate, since it might be changed in next instructions - - for (uint j = 0; j < xLastDim; ++j) { // first xRank-1 coordinates in yCoords are the same for y and x - yCoords[xRank - 1] = j; - zCoords[j] = x[shape::getOffset(xShapeInfo, yCoords)]; - } - - for (uint j = xLastDim + 1; j < zRank; ++j) - zCoords[j] = yCoords[yRank - zRank + j]; - - zOffset = shape::getOffset(zShapeInfo, zCoords); - } - switch (opCode) { - case pairwise::Add: - z[zOffset] += y[yOffset]; - break; - case pairwise::Subtract: - z[zOffset] -= y[yOffset]; - break; - case pairwise::Multiply: - z[zOffset] *= y[yOffset]; - break; - case pairwise::Divide: - z[zOffset] /= y[yOffset]; - break; - case pairwise::ReverseSubtract: - z[zOffset] = y[yOffset] - z[zOffset]; - break; - case pairwise::ReverseDivide: - z[zOffset] = y[yOffset] / z[zOffset]; - break; - case pairwise::CopyPws: - z[zOffset] = y[yOffset]; - break; - case pairwise::MaxPairwise: - if(z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; - break; - case pairwise::MinPairwise: - if(z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; - break; - default: - continue; - } + switch (opCode) { + case pairwise::Add: + z[zOffset] += y[yOffset]; + break; + case pairwise::Subtract: + z[zOffset] -= y[yOffset]; + break; + case pairwise::Multiply: + z[zOffset] *= y[yOffset]; + break; + case pairwise::Divide: + z[zOffset] /= y[yOffset]; + break; + case pairwise::ReverseSubtract: + z[zOffset] = y[yOffset] - z[zOffset]; + break; + case pairwise::ReverseDivide: + z[zOffset] = y[yOffset] / z[zOffset]; + break; + case pairwise::CopyPws: + z[zOffset] = y[yOffset]; + break; + case pairwise::MaxPairwise: + if (z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; + break; + case pairwise::MinPairwise: + if (z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; + break; + default: + continue; } + } } /////////////////////////////////////////////////////////////////// -template -static void scatterNDCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const int opCode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const bool lock) { - - if(lock) - scatterNDLockCuda<<>>(opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); - else - scatterNDCuda<<>>(opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); +template +static void scatterNDCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const int opCode, const void *vx, + const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, const bool lock) { + if (lock) + scatterNDLockCuda + <<>>( + opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); + else + scatterNDCuda<<>>( + opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// -void scatterND(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { - - const int xRank = indices.rankOf(); - const int yRank = updates.rankOf(); - const int zRank = output.rankOf(); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = ((lock ? output.lengthOf() : updates.lengthOf()) + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * ((yRank > xRank ? yRank : xRank) + zRank) + 256; - - const auto xType = indices.dataType(); - const auto yType = updates.dataType(); - - PointersManager manager(context, "scatterND"); - - NDArray::prepareSpecialUse({&output}, {&updates, &indices}); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.specialBuffer(), indices.specialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), lock), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); - NDArray::registerSpecialUse({&output}, {&updates, &indices}); - - manager.synchronize(); +void scatterND(sd::LaunchContext *context, pairwise::Ops op, + const NDArray &indices, const NDArray &updates, NDArray &output, + const bool lock) { + const int xRank = indices.rankOf(); + const int yRank = updates.rankOf(); + const int zRank = output.rankOf(); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + ((lock ? output.lengthOf() : updates.lengthOf()) + threadsPerBlock - 1) / + threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * + ((yRank > xRank ? yRank : xRank) + zRank) + + 256; + + const auto xType = indices.dataType(); + const auto yType = updates.dataType(); + + PointersManager manager(context, "scatterND"); + + NDArray::prepareSpecialUse({&output}, {&updates, &indices}); + BUILD_DOUBLE_SELECTOR( + xType, yType, scatterNDCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, + indices.specialBuffer(), indices.specialShapeInfo(), + updates.specialBuffer(), updates.specialShapeInfo(), + output.specialBuffer(), output.specialShapeInfo(), lock), + INDEXING_TYPES, GENERIC_NUMERIC_TYPES); + NDArray::registerSpecialUse({&output}, {&updates, &indices}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// -template +template __global__ void scatterForLossCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); + void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); - __shared__ Nd4jLong xLen; - __shared__ int xRank, *sharedMem; // xRank = zRank, yRank = xRank + 1 + __shared__ Nd4jLong xLen; + __shared__ int xRank, *sharedMem; // xRank = zRank, yRank = xRank + 1 - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - xLen = shape::length(xShapeInfo); - xRank = shape::rank(xShapeInfo); - } - __syncthreads(); + xLen = shape::length(xShapeInfo); + xRank = shape::rank(xShapeInfo); + } + __syncthreads(); - const auto xInd = threadIdx.x + blockIdx.x * blockDim.x; + const auto xInd = threadIdx.x + blockIdx.x * blockDim.x; - if(xInd >= xLen) - return; + if (xInd >= xLen) return; - auto coords = sharedMem + threadIdx.x * (xRank + 1); + auto coords = sharedMem + threadIdx.x * (xRank + 1); - shape::index2coords(xInd, xShapeInfo, coords); + shape::index2coords(xInd, xShapeInfo, coords); - // y last coordinate - coords[xRank] = x[shape::getOffset(xShapeInfo, coords)]; + // y last coordinate + coords[xRank] = x[shape::getOffset(xShapeInfo, coords)]; - const auto yOffset = shape::getOffset(yShapeInfo, coords); + const auto yOffset = shape::getOffset(yShapeInfo, coords); - if(z == nullptr) { // gradient calculation - y[yOffset] -= 1.f; - } - else { - z[shape::getOffset(zShapeInfo, coords)] = y[yOffset]; - } + if (z == nullptr) { // gradient calculation + y[yOffset] -= 1.f; + } else { + z[shape::getOffset(zShapeInfo, coords)] = y[yOffset]; + } } /////////////////////////////////////////////////////////////////// -template -static void scatterForLossCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong* xShapeInfo, void *vy, const Nd4jLong* yShapeInfo, void *vz, const Nd4jLong* zShapeInfo) { - - scatterForLossCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); +template +static void scatterForLossCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, + void *vy, const Nd4jLong *yShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + scatterForLossCuda + <<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// -void scatterForLoss(sd::LaunchContext* context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad) { - // shapes of indices and output must be the same - // shape of indices should be the same as updates shape with last dimension excluded, for example if updates is {a,b,c} then indices should be {a,b} - - PointersManager manager(context, "scatterForLoss"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (indices.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = updates.rankOf() * sizeof(int) * threadsPerBlock + 128; - - if(calcGrad) { - NDArray::prepareSpecialUse({&updates}, {&indices}); - BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.specialBuffer(), indices.specialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), nullptr, nullptr), INDEXING_TYPES, FLOAT_TYPES); - NDArray::registerSpecialUse({&updates}, {&indices}); - } - else { - NDArray::prepareSpecialUse({&output}, {&indices, &updates}); - BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.specialBuffer(), indices.specialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), INDEXING_TYPES, FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&indices, &updates}); - } - - manager.synchronize(); -} - -} -} +void scatterForLoss(sd::LaunchContext *context, const NDArray &indices, + NDArray &updates, NDArray &output, const bool calcGrad) { + // shapes of indices and output must be the same + // shape of indices should be the same as updates shape with last dimension + // excluded, for example if updates is {a,b,c} then indices should be {a,b} + + PointersManager manager(context, "scatterForLoss"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (indices.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = updates.rankOf() * sizeof(int) * threadsPerBlock + 128; + + if (calcGrad) { + NDArray::prepareSpecialUse({&updates}, {&indices}); + BUILD_DOUBLE_SELECTOR( + indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + indices.specialBuffer(), indices.specialShapeInfo(), + updates.specialBuffer(), updates.specialShapeInfo(), nullptr, nullptr), + INDEXING_TYPES, FLOAT_TYPES); + NDArray::registerSpecialUse({&updates}, {&indices}); + } else { + NDArray::prepareSpecialUse({&output}, {&indices, &updates}); + BUILD_DOUBLE_SELECTOR( + indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + indices.specialBuffer(), indices.specialShapeInfo(), + updates.specialBuffer(), updates.specialShapeInfo(), + output.specialBuffer(), output.specialShapeInfo()), + INDEXING_TYPES, FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&indices, &updates}); + } + + manager.synchronize(); } +} // namespace helpers +} // namespace ops +} // namespace sd /* /////////////////////////////////////////////////////////////////// template -static void scatterLockCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const int opCode, - const void* vx, const Nd4jLong *xShapeInfo, - const void* vy, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, - const Nd4jLong xLen, const Nd4jLong yTadLen, const Nd4jLong zTadLen) { - - scatterLockCuda<<>>(opCode, vx, xShapeInfo, vy, yTadShapeInfo, yOffsets, vz, zTadShapeInfo, zOffsets, xLen, yTadLen, zTadLen); +static void scatterLockCudaLauncher(const int blocksPerGrid, const int +threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int +opCode, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const +Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, void* vz, const Nd4jLong +*zTadShapeInfo, const Nd4jLong *zOffsets, const Nd4jLong xLen, const Nd4jLong +yTadLen, const Nd4jLong zTadLen) { + + scatterLockCuda<<>>(opCode, vx, xShapeInfo, vy, yTadShapeInfo, yOffsets, vz, +zTadShapeInfo, zOffsets, xLen, yTadLen, zTadLen); } @@ -718,16 +793,18 @@ static void scatterLockCudaLauncher(const int blocksPerGrid, const int threadsPe // x - indices, y - updates, z - input/output template __global__ static void scatterLockCuda(const int opCode, - const void* vx, const Nd4jLong *xShapeInfo, - const void* vy, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, - const Nd4jLong xLen, const Nd4jLong yTadLen, const Nd4jLong zTadLen) { + const void* vx, const Nd4jLong +*xShapeInfo, const void* vy, const Nd4jLong *yTadShapeInfo, const Nd4jLong +*yOffsets, void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, + const Nd4jLong xLen, const Nd4jLong +yTadLen, const Nd4jLong zTadLen) { const int xRank = indices.rankOf(); - std::vector zTadDims = ShapeUtils::evalDimsToExclude(output.rankOf(), {0}); + std::vector zTadDims = +ShapeUtils::evalDimsToExclude(output.rankOf(), {0}); int sizeOfUpdDims = xRank; if(output.rankOf() == updates.rankOf() && indices.isVector()) @@ -736,19 +813,28 @@ __global__ static void scatterLockCuda(const int opCode, std::vector yTadDims(sizeOfUpdDims); std::iota(yTadDims.begin(), yTadDims.end(), 0); - auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions(updates.shapeInfo(), ShapeUtils::evalDimsToExclude(updates.rankOf(), yTadDims)); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), zTadDims); + auto packY = +sd::ConstantTadHelper::getInstance()->tadForDimensions(updates.shapeInfo(), +ShapeUtils::evalDimsToExclude(updates.rankOf(), yTadDims)); auto packZ = +sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), +zTadDims); const Nd4jLong zTadLen = shape::length(packZ.primaryShapeInfo()); const Nd4jLong yTadLen = shape::length(packY.primaryShapeInfo()); - const auto threadsPerBlock = sd::math::nd4j_max(32, sd::math::nd4j_min(zTadLen, 1024)); - const auto blocksPerGrid = indices.lengthOf(); + const auto threadsPerBlock = sd::math::nd4j_max(32, +sd::math::nd4j_min(zTadLen, 1024)); const auto blocksPerGrid = +indices.lengthOf(); const auto xType = indices.dataType(); const auto yType = updates.dataType(); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterLockCudaLauncher, (blocksPerGrid, threadsPerBlock, 1024, context->getCudaStream(), op, indices.specialBuffer(), indices.specialShapeInfo(), updates.specialBuffer(), packY.specialShapeInfo(), packY.specialOffsets(), output.specialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets(), indices.lengthOf(), yTadLen, zTadLen), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, scatterLockCudaLauncher, +(blocksPerGrid, threadsPerBlock, 1024, context->getCudaStream(), op, +indices.specialBuffer(), indices.specialShapeInfo(), updates.specialBuffer(), +packY.specialShapeInfo(), packY.specialOffsets(), output.specialBuffer(), +packZ.specialShapeInfo(), packZ.specialOffsets(), indices.lengthOf(), yTadLen, +zTadLen), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); @@ -764,12 +850,14 @@ __global__ static void scatterLockCuda(const int opCode, for (int e = 0; e < xLen; e++) { const Nd4jLong zIndex = x[shape::getIndexOffset(e, xShapeInfo)]; - const bool isOwner = zIndex < gridDim.x ? blockIdx.x == zIndex : blockIdx.x == zIndex % gridDim.x; + const bool isOwner = zIndex < gridDim.x ? blockIdx.x == zIndex : +blockIdx.x == zIndex % gridDim.x; if (!isOwner) continue; - if(vectorCase) { // means z_rank = 1 and might be yTadLen != zTadLen in this case + if(vectorCase) { // means z_rank = 1 and might be yTadLen != zTadLen in +this case if(threadIdx.x != 0) continue; @@ -842,13 +930,9 @@ __global__ static void scatterLockCuda(const int opCode, zTad[zOffset] = yTad[yOffset]; break; case pairwise::MaxPairwise: - if(zTad[zOffset] < yTad[yOffset]) zTad[zOffset] = yTad[yOffset]; - break; - case pairwise::MinPairwise: - if(zTad[zOffset] > yTad[yOffset]) zTad[zOffset] = yTad[yOffset]; - break; - default: - continue; + if(zTad[zOffset] < yTad[yOffset]) zTad[zOffset] = +yTad[yOffset]; break; case pairwise::MinPairwise: if(zTad[zOffset] > +yTad[yOffset]) zTad[zOffset] = yTad[yOffset]; break; default: continue; } } } @@ -856,10 +940,11 @@ __global__ static void scatterLockCuda(const int opCode, } template - __global__ static void scatterCuda(const int opCode, const int numOfSubArrs, - void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, - void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, - const int* indexes, unsigned int arrLenX, unsigned int arrLenY) { + __global__ static void scatterCuda(const int opCode, const int +numOfSubArrs, void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, + void* vy, const Nd4jLong +*yShapeInfo, const Nd4jLong *yOffsets, const int* indexes, unsigned int arrLenX, +unsigned int arrLenY) { __shared__ T *x, *y; @@ -868,7 +953,8 @@ __global__ static void scatterLockCuda(const int opCode, for (int e = 0; e < numOfSubArrs; e++) { const auto xIndex = indexes[e]; - const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex : blockIdx.x == xIndex % gridDim.x; + const bool isOwner = xIndex < gridDim.x ? blockIdx.x == +xIndex : blockIdx.x == xIndex % gridDim.x; if (!isOwner) continue; @@ -879,10 +965,11 @@ __global__ static void scatterLockCuda(const int opCode, } __syncthreads(); - for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { + for (Nd4jLong i = threadIdx.x; i < arrLenX; i += +blockDim.x) { - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const auto yOffset = shape::getIndexOffset(i, yShapeInfo); + const auto xOffset = shape::getIndexOffset(i, +xShapeInfo); const auto yOffset = shape::getIndexOffset(i, yShapeInfo); switch (opCode) { case pairwise::Add: @@ -922,9 +1009,9 @@ __global__ static void scatterLockCuda(const int opCode, } __syncthreads(); - for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const auto yOffset = shape::getIndexOffset(i, yShapeInfo); + for (Nd4jLong i = threadIdx.x; i < arrLenX; i += +blockDim.x) { const auto xOffset = shape::getIndexOffset(i, xShapeInfo); const +auto yOffset = shape::getIndexOffset(i, yShapeInfo); switch (opCode) { case pairwise::Add: @@ -959,12 +1046,17 @@ __global__ static void scatterLockCuda(const int opCode, template - void scatter_(sd::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { + void scatter_(sd::LaunchContext *context, pairwise::Ops op, const +NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { std::vector dims = {0}; - auto inverted = ShapeUtils::evalDimsToExclude(output.rankOf(), dims); + auto inverted = ShapeUtils::evalDimsToExclude(output.rankOf(), +dims); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), inverted); - auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions(updates.shapeInfo(), inverted); + auto packX = +sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), +inverted); auto packY = +sd::ConstantTadHelper::getInstance()->tadForDimensions(updates.shapeInfo(), +inverted); auto psX = packX.specialShapeInfo(); auto psY = packY.specialShapeInfo(); @@ -976,17 +1068,23 @@ __global__ static void scatterLockCuda(const int opCode, NDArray::prepareSpecialUse({&output}, {&updates, &indices}); - unsigned int tadLengthX = shape::length(packX.primaryShapeInfo()); - unsigned int tadLengthY = shape::length(packY.primaryShapeInfo()); - if (tadLengthX != tadLengthY) - throw std::runtime_error("scatter: Lengths of TADs must be equal"); + unsigned int tadLengthX = +shape::length(packX.primaryShapeInfo()); unsigned int tadLengthY = +shape::length(packY.primaryShapeInfo()); if (tadLengthX != tadLengthY) throw +std::runtime_error("scatter: Lengths of TADs must be equal"); - auto blockSize = sd::math::nd4j_max(32, sd::math::nd4j_min(tadLengthX, 1024)); + auto blockSize = sd::math::nd4j_max(32, +sd::math::nd4j_min(tadLengthX, 1024)); if (lock) - scatterCuda<<<512, blockSize, 1024, *context->getCudaStream()>>>(op, indices.lengthOf(), output.specialBuffer(), psX, poX, updates.specialBuffer(), psY, poY, reinterpret_cast(indices.specialBuffer()), tadLengthX, tadLengthY); - else - scatterCuda<<<512, blockSize, 1024, *context->getCudaStream()>>>(op, indices.lengthOf(), output.specialBuffer(), psX, poX, updates.specialBuffer(), psY, poY, reinterpret_cast(indices.specialBuffer()), tadLengthX, tadLengthY); + scatterCuda<<<512, blockSize, 1024, +*context->getCudaStream()>>>(op, indices.lengthOf(), output.specialBuffer(), +psX, poX, updates.specialBuffer(), psY, poY, reinterpret_cast(indices.specialBuffer()), tadLengthX, tadLengthY); else scatterCuda<<<512, blockSize, 1024, *context->getCudaStream()>>>(op, +indices.lengthOf(), output.specialBuffer(), psX, poX, updates.specialBuffer(), +psY, poY, reinterpret_cast(indices.specialBuffer()), tadLengthX, +tadLengthY); NDArray::registerSpecialUse({&output}, {&updates, &indices}); manager.synchronize(); @@ -998,11 +1096,11 @@ __global__ static void scatterLockCuda(const int opCode, // x - indices, y - updates, z - output template __global__ static void scatterNDLockCuda(const int opCode, - const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const void* vy, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, - const Nd4jLong *zShapeInfo, - const Nd4jLong numOfXTads, const Nd4jLong numOfZTads, const Nd4jLong yTadLen) { + const void* vx, const Nd4jLong +*xTadShapeInfo, const Nd4jLong *xOffsets, const void* vy, const Nd4jLong +*yTadShapeInfo, const Nd4jLong *yOffsets, void* vz, const Nd4jLong +*zTadShapeInfo, const Nd4jLong *zOffsets, const Nd4jLong *zShapeInfo, const +Nd4jLong numOfXTads, const Nd4jLong numOfZTads, const Nd4jLong yTadLen) { @@ -1016,17 +1114,23 @@ const int xLastDim = indices.sizeAt(-1); zTadDims[i] = zRank - 1 - j; } - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(indices.shapeInfo(), {xRank - 1}); - auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions(updates.shapeInfo(), yTadDims); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), zTadDims); + auto packX = +sd::ConstantTadHelper::getInstance()->tadForDimensions(indices.shapeInfo(), +{xRank - 1}); auto packY = +sd::ConstantTadHelper::getInstance()->tadForDimensions(updates.shapeInfo(), +yTadDims); auto packZ = +sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), +zTadDims); const int threadsPerBlock = MAX_NUM_THREADS / 4; const int blocksPerGrid = packZ.numberOfTads(); const int sharedMem = 8 * threadsPerBlock * xLastDim + 128; --------------------------------------------------------------------------- - // zTadLen == yTadLen if numOfZTads > 1, in opposite case z and y are vectors - // numOfXTads == numOfYTads if numOfZTads > 1, in opposite case z and y are vectors + // zTadLen == yTadLen if numOfZTads > 1, in opposite case z and y are +vectors + // numOfXTads == numOfYTads if numOfZTads > 1, in opposite case z and y are +vectors const auto x = reinterpret_cast(vx); const auto y = reinterpret_cast(vy); @@ -1049,11 +1153,14 @@ const int xLastDim = indices.sizeAt(-1); const X* xTad = x + xOffsets[i]; for (uint k = 0; k < xLastDim; ++k) - zTadCoordsPerThread[k] = xTad[shape::getIndexOffset(k, xTadShapeInfo)]; + zTadCoordsPerThread[k] = xTad[shape::getIndexOffset(k, +xTadShapeInfo)]; - const auto zTadIndex = shape::coords2index(xLastDim, zShapeInfo + 1, zTadCoordsPerThread); + const auto zTadIndex = shape::coords2index(xLastDim, zShapeInfo + 1, +zTadCoordsPerThread); - const bool isOwner = zTadIndex < gridDim.x ? blockIdx.x == zTadIndex : blockIdx.x == zTadIndex % gridDim.x; + const bool isOwner = zTadIndex < gridDim.x ? blockIdx.x == zTadIndex : +blockIdx.x == zTadIndex % gridDim.x; if(!isOwner) continue; @@ -1063,8 +1170,8 @@ const int xLastDim = indices.sizeAt(-1); if(threadIdx.x != 0) continue; - const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo); - const auto zOffset = shape::getIndexOffset(zTadIndex, zTadShapeInfo); + const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo); const +auto zOffset = shape::getIndexOffset(zTadIndex, zTadShapeInfo); switch (opCode) { case pairwise::Add: @@ -1130,13 +1237,9 @@ const int xLastDim = indices.sizeAt(-1); zTad[zOffset] = yTad[yOffset]; break; case pairwise::MaxPairwise: - if(zTad[zOffset] < yTad[yOffset]) zTad[zOffset] = yTad[yOffset]; - break; - case pairwise::MinPairwise: - if(zTad[zOffset] > yTad[yOffset]) zTad[zOffset] = yTad[yOffset]; - break; - default: - continue; + if(zTad[zOffset] < yTad[yOffset]) zTad[zOffset] = +yTad[yOffset]; break; case pairwise::MinPairwise: if(zTad[zOffset] > +yTad[yOffset]) zTad[zOffset] = yTad[yOffset]; break; default: continue; } } } @@ -1144,24 +1247,33 @@ const int xLastDim = indices.sizeAt(-1); } */ - // PointersManager manager(&context, "NativeOps::concat"); - // PointersManager::printDevContentOnDev(vx, 2); - // PointersManager::printDevContentOnDev(xShapeInfo, 8); - // PointersManager::printDevContentOnDev(vy, 8); - // PointersManager::printDevContentOnDev(yShapeInfo, 8); - // PointersManager::printDevContentOnDev(zShapeInfo, 8); - - // manager.printDevContentOnHost(indices.specialBuffer(), indices.lengthOf()); - // manager.printDevContentOnHost(indices.specialShapeInfo(), shape::shapeInfoLength(indices.rankOf())); - // manager.printDevContentOnHost(updates.specialBuffer(), updates.lengthOf()); - // manager.printDevContentOnHost(updates.specialShapeInfo(), shape::shapeInfoLength(updates.rankOf())); - // manager.printDevContentOnHost(output.specialShapeInfo(), shape::shapeInfoLength(output.rankOf())); - // printf("!!!!!!!\n"); - // manager.printDevContentOnHost(packX.specialShapeInfo(), 2*shape::rank(packX.primaryShapeInfo()) + 4); - // manager.printDevContentOnHost(packX.specialOffsets(), packX.numberOfTads()); - // manager.printDevContentOnHost(packY.specialShapeInfo(), 2*shape::rank(packY.primaryShapeInfo()) + 4); - // manager.printDevContentOnHost(packY.specialOffsets(), packY.numberOfTads()); - // manager.printDevContentOnHost(packZ.specialShapeInfo(), 2*shape::rank(packZ.primaryShapeInfo()) + 4); - // manager.printDevContentOnHost(packZ.specialOffsets(), packZ.numberOfTads()); - // printf("dddddddd\n"); - // shape::printShapeInfoLinear(packY.primaryShapeInfo()); \ No newline at end of file +// PointersManager manager(&context, "NativeOps::concat"); +// PointersManager::printDevContentOnDev(vx, 2); +// PointersManager::printDevContentOnDev(xShapeInfo, 8); +// PointersManager::printDevContentOnDev(vy, 8); +// PointersManager::printDevContentOnDev(yShapeInfo, 8); +// PointersManager::printDevContentOnDev(zShapeInfo, 8); + +// manager.printDevContentOnHost(indices.specialBuffer(), +// indices.lengthOf()); +// manager.printDevContentOnHost(indices.specialShapeInfo(), +// shape::shapeInfoLength(indices.rankOf())); +// manager.printDevContentOnHost(updates.specialBuffer(), +// updates.lengthOf()); +// manager.printDevContentOnHost(updates.specialShapeInfo(), +// shape::shapeInfoLength(updates.rankOf())); +// manager.printDevContentOnHost(output.specialShapeInfo(), +// shape::shapeInfoLength(output.rankOf())); printf("!!!!!!!\n"); +// manager.printDevContentOnHost(packX.specialShapeInfo(), +// 2*shape::rank(packX.primaryShapeInfo()) + 4); +// manager.printDevContentOnHost(packX.specialOffsets(), +// packX.numberOfTads()); +// manager.printDevContentOnHost(packY.specialShapeInfo(), +// 2*shape::rank(packY.primaryShapeInfo()) + 4); +// manager.printDevContentOnHost(packY.specialOffsets(), +// packY.numberOfTads()); +// manager.printDevContentOnHost(packZ.specialShapeInfo(), +// 2*shape::rank(packZ.primaryShapeInfo()) + 4); +// manager.printDevContentOnHost(packZ.specialOffsets(), +// packZ.numberOfTads()); printf("dddddddd\n"); +// shape::printShapeInfoLinear(packY.primaryShapeInfo()); \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu b/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu index a17464cbdf3e..f783c27392ad 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu @@ -18,62 +18,75 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include -namespace sd { - namespace ops { - namespace helpers { - template - static _CUDA_G void scatterSimpleKernel(void *vx, const Nd4jLong *xTadShape, const Nd4jLong *xTadOffsets, Nd4jLong xLength, Nd4jLong numTads, const void *vi, const Nd4jLong *iShapeInfo, Nd4jLong iLength, const void *vu, const Nd4jLong *uShapeInfo, Nd4jLong uLength) { - auto u = reinterpret_cast(vu); - auto indices = reinterpret_cast(vi); - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - for (int i = tid; i < iLength; i += blockDim.x * gridDim.x) { - auto x = reinterpret_cast(vx) + xTadOffsets[i]; - auto idx = indices[shape::getIndexOffset(i, iShapeInfo)]; - - x[shape::getIndexOffset(idx, xTadShape)] = u[shape::getIndexOffset(i, uShapeInfo)]; - } - } - - - template - void scatterSimple_(sd::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector& dimensions) { - - auto dims = ShapeUtils::evalDimsToExclude(input.rankOf(), dimensions); - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dims); - - auto xLength = shape::length(packX.primaryShapeInfo()); - auto iLength = indices.lengthOf(); - auto uLength = updates.lengthOf(); - - scatterSimpleKernel<<<256, 256, 1024, *context->getCudaStream()>>>(input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), xLength, packX.numberOfTads(), indices.specialBuffer(), indices.specialShapeInfo(), iLength, updates.specialBuffer(), updates.specialShapeInfo(), uLength); - } - - - void scatterSimple(sd::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector& dimensions) { - auto xType = input.dataType(); - auto yType = indices.dataType(); - - if (opId != 6) - throw std::runtime_error("scatterSimple: only copy op is supported"); - - NDArray::prepareSpecialUse({&input}, {&updates, &indices}); - - BUILD_DOUBLE_SELECTOR(xType, yType, scatterSimple_, (context, opId, input, updates, indices, dimensions), LIBND4J_TYPES, INDEXING_TYPES); +#include - NDArray::registerSpecialUse({&input}, {&updates, &indices}); - } - } - } -} \ No newline at end of file +namespace sd { +namespace ops { +namespace helpers { +template +static _CUDA_G void scatterSimpleKernel( + void* vx, const Nd4jLong* xTadShape, const Nd4jLong* xTadOffsets, + Nd4jLong xLength, Nd4jLong numTads, const void* vi, + const Nd4jLong* iShapeInfo, Nd4jLong iLength, const void* vu, + const Nd4jLong* uShapeInfo, Nd4jLong uLength) { + auto u = reinterpret_cast(vu); + auto indices = reinterpret_cast(vi); + + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + for (int i = tid; i < iLength; i += blockDim.x * gridDim.x) { + auto x = reinterpret_cast(vx) + xTadOffsets[i]; + auto idx = indices[shape::getIndexOffset(i, iShapeInfo)]; + + x[shape::getIndexOffset(idx, xTadShape)] = + u[shape::getIndexOffset(i, uShapeInfo)]; + } +} + +template +void scatterSimple_(sd::LaunchContext* context, const int opId, NDArray& input, + const NDArray& updates, const NDArray& indices, + const std::vector& dimensions) { + auto dims = ShapeUtils::evalDimsToExclude(input.rankOf(), dimensions); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), dims); + + auto xLength = shape::length(packX.primaryShapeInfo()); + auto iLength = indices.lengthOf(); + auto uLength = updates.lengthOf(); + + scatterSimpleKernel<<<256, 256, 1024, *context->getCudaStream()>>>( + input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), + xLength, packX.numberOfTads(), indices.specialBuffer(), + indices.specialShapeInfo(), iLength, updates.specialBuffer(), + updates.specialShapeInfo(), uLength); +} + +void scatterSimple(sd::LaunchContext* context, const int opId, NDArray& input, + const NDArray& updates, const NDArray& indices, + const std::vector& dimensions) { + auto xType = input.dataType(); + auto yType = indices.dataType(); + + if (opId != 6) + throw std::runtime_error("scatterSimple: only copy op is supported"); + + NDArray::prepareSpecialUse({&input}, {&updates, &indices}); + + BUILD_DOUBLE_SELECTOR(xType, yType, scatterSimple_, + (context, opId, input, updates, indices, dimensions), + LIBND4J_TYPES, INDEXING_TYPES); + + NDArray::registerSpecialUse({&input}, {&updates, &indices}); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/scatter_update.cu b/libnd4j/include/ops/declarable/helpers/cuda/scatter_update.cu index 51f917a7913d..16bf7e7c3273 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/scatter_update.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/scatter_update.cu @@ -18,115 +18,124 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include -namespace sd { - namespace ops { - namespace helpers { - /////////////////////////////////////////////////////////////////// - template - __global__ static void scatterUpdateCuda(const int opCode, const int numOfInd, - void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, - void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, - const int* indexes) { - - __shared__ T *x, *y; - __shared__ Nd4jLong arrLenX, arrLenY; - - for (int e = 0; e < numOfInd; e++ ) { - - const auto xIndex = indexes[e]; - const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex : blockIdx.x == xIndex % gridDim.x; - - if (!isOwner) - continue; - - if (threadIdx.x == 0) { - x = reinterpret_cast(vx) + xOffsets[xIndex]; - y = reinterpret_cast(vy) + yOffsets[e]; - arrLenX = shape::length(xShapeInfo); - arrLenY = shape::length(yShapeInfo); - } - __syncthreads(); - - if (arrLenX != arrLenY) - return; - - for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { - - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const auto yOffset = shape::getIndexOffset(i, yShapeInfo); - - switch (opCode) { - case 0: - x[xOffset] += y[yOffset]; - break; - case 1: - x[xOffset] -= y[yOffset]; - break; - case 2: - x[xOffset] *= y[yOffset]; - break; - case 3: - x[xOffset] /= y[yOffset]; - break; - case 4: - x[xOffset] = y[yOffset] - x[xOffset]; - break; - case 5: - x[xOffset] = y[yOffset] / x[xOffset]; - break; - case 6: - x[xOffset] = y[yOffset]; - break; - default: - continue; - } - } - __syncthreads(); - } - } - - template - __host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfInd, void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const int* indexes) { - - scatterUpdateCuda<<<512, 256, MAX_NUM_THREADS, *stream>>>(opCode, numOfInd, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, indexes); - } +#include +namespace sd { +namespace ops { +namespace helpers { +/////////////////////////////////////////////////////////////////// +template +__global__ static void scatterUpdateCuda(const int opCode, const int numOfInd, + void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xOffsets, void* vy, + const Nd4jLong* yShapeInfo, + const Nd4jLong* yOffsets, + const int* indexes) { + __shared__ T *x, *y; + __shared__ Nd4jLong arrLenX, arrLenY; + + for (int e = 0; e < numOfInd; e++) { + const auto xIndex = indexes[e]; + const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex + : blockIdx.x == xIndex % gridDim.x; + + if (!isOwner) continue; + + if (threadIdx.x == 0) { + x = reinterpret_cast(vx) + xOffsets[xIndex]; + y = reinterpret_cast(vy) + yOffsets[e]; + arrLenX = shape::length(xShapeInfo); + arrLenY = shape::length(yShapeInfo); + } + __syncthreads(); + + if (arrLenX != arrLenY) return; + + for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { + const auto xOffset = shape::getIndexOffset(i, xShapeInfo); + const auto yOffset = shape::getIndexOffset(i, yShapeInfo); + + switch (opCode) { + case 0: + x[xOffset] += y[yOffset]; + break; + case 1: + x[xOffset] -= y[yOffset]; + break; + case 2: + x[xOffset] *= y[yOffset]; + break; + case 3: + x[xOffset] /= y[yOffset]; + break; + case 4: + x[xOffset] = y[yOffset] - x[xOffset]; + break; + case 5: + x[xOffset] = y[yOffset] / x[xOffset]; + break; + case 6: + x[xOffset] = y[yOffset]; + break; + default: + continue; + } + } + __syncthreads(); + } +} + +template +__host__ static void scatterUpdateCudaLauncher( + const cudaStream_t* stream, const int opCode, const int numOfInd, void* vx, + const Nd4jLong* xShapeInfo, const Nd4jLong* xOffsets, void* vy, + const Nd4jLong* yShapeInfo, const Nd4jLong* yOffsets, const int* indexes) { + scatterUpdateCuda<<<512, 256, MAX_NUM_THREADS, *stream>>>( + opCode, numOfInd, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, + indexes); +} ////////////////////////////////////////////////////////////////////////// - void scatterUpdate(sd::LaunchContext* context, NDArray& input, NDArray& updates, const std::vector* intArgs) { - - const int opCode = (*intArgs)[0]; - const int numOfDims = (*intArgs)[1]; - const int numOfInd = (*intArgs)[2 + numOfDims]; - - std::vector tadDimensions(numOfDims); - for (int e = 2; e < 2 + numOfDims; e++) - tadDimensions[e-2] = (*intArgs)[e]; - - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), tadDimensions); - auto packY = ConstantTadHelper::getInstance()->tadForDimensions(updates.shapeInfo(), tadDimensions); - - NDArray indices(const_cast(intArgs->data()) + numOfDims + 3, 'c', {numOfInd}, sd::DataType::INT32, context); - - PointersManager manager(context, "scatterUpdate"); - - NDArray::prepareSpecialUse({&input}, {&input, &updates, &indices}); - BUILD_SINGLE_SELECTOR(input.dataType(), scatterUpdateCudaLauncher, (context->getCudaStream(), opCode, numOfInd, input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), updates.specialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), reinterpret_cast(indices.specialBuffer())), LIBND4J_TYPES); - NDArray::registerSpecialUse({&input}, {&input, &updates, &indices}); - - manager.synchronize(); - } - } - } -} \ No newline at end of file +void scatterUpdate(sd::LaunchContext* context, NDArray& input, NDArray& updates, + const std::vector* intArgs) { + const int opCode = (*intArgs)[0]; + const int numOfDims = (*intArgs)[1]; + const int numOfInd = (*intArgs)[2 + numOfDims]; + + std::vector tadDimensions(numOfDims); + for (int e = 2; e < 2 + numOfDims; e++) tadDimensions[e - 2] = (*intArgs)[e]; + + auto packX = ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), tadDimensions); + auto packY = ConstantTadHelper::getInstance()->tadForDimensions( + updates.shapeInfo(), tadDimensions); + + NDArray indices(const_cast(intArgs->data()) + numOfDims + 3, 'c', + {numOfInd}, sd::DataType::INT32, context); + + PointersManager manager(context, "scatterUpdate"); + + NDArray::prepareSpecialUse({&input}, {&input, &updates, &indices}); + BUILD_SINGLE_SELECTOR(input.dataType(), scatterUpdateCudaLauncher, + (context->getCudaStream(), opCode, numOfInd, + input.specialBuffer(), packX.platformShapeInfo(), + packX.platformOffsets(), updates.specialBuffer(), + packY.platformShapeInfo(), packY.platformOffsets(), + reinterpret_cast(indices.specialBuffer())), + LIBND4J_TYPES); + NDArray::registerSpecialUse({&input}, {&input, &updates, &indices}); + + manager.synchronize(); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment.cu index 60d00fb60f57..f077d8df7ce1 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment.cu @@ -18,117 +18,149 @@ // @author GS // -#include -#include #include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - // -------------------------------------------------------------------------------------------------------------- // - // Sorted segments ops implementations - - template - static bool segmentIndicesValidate_(NDArray* indices, NDArray& aexpected, NDArray& aoutput) { - return true; - } - - bool segmentIndicesValidate(sd::LaunchContext* context , NDArray* indices, NDArray& expected, NDArray& output) { - BUILD_DOUBLE_SELECTOR(output.dataType(), indices->dataType(), return segmentIndicesValidate_, (indices, expected, output), NUMERIC_TYPES, INDEXING_TYPES); - } - - // -------------------------------------------------------------------------------------------------------------- // - // Unsorted segment ops functors implementation - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void unsortedSegmentIndexValidateKernel(const I* indices, const Nd4jLong* indicesShape, I expected, I* found) { - __shared__ bool onlyTrue; - __shared__ Nd4jLong len; - - if (threadIdx.x == 0) { - onlyTrue = true; - len = shape::length(indicesShape); - } - __syncthreads(); - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = gridDim.x * blockDim.x; - for (int e = start; e < len && onlyTrue; e += step) { - sd::math::atomics::nd4j_atomicMax(found, indices[e]); - if (expected < *found) - onlyTrue = false; - } - } - - template - static bool unsortedSegmentIndicesValidate_(sd::LaunchContext* context , NDArray* indices, Nd4jLong expected, Nd4jLong& output) { - output = expected; - I found = output; - I exp = expected; - auto stream = context->getCudaStream(); - I* devFound; - cudaMalloc(&devFound, sizeof(I)); - cudaMemcpy(devFound, &found, sizeof(I), cudaMemcpyHostToDevice); - unsortedSegmentIndexValidateKernel<<<1, indices->lengthOf(), 128, *stream>>>(reinterpret_cast(indices->specialBuffer()), indices->specialShapeInfo(), exp, devFound); - cudaMemcpy(&found, devFound, sizeof(I), cudaMemcpyDeviceToHost); - cudaFree(devFound); - output = found; - return expected == output; - } - - bool unsortedSegmentIndicesValidate(sd::LaunchContext* context , NDArray* indices, Nd4jLong expected, Nd4jLong& output) { - BUILD_SINGLE_SELECTOR(indices->dataType(), return unsortedSegmentIndicesValidate_, (context, indices, expected, output), INDEXING_TYPES); - } - - // -------------------------------------------------------------------------------------------------------------- // - - // -------------------------------------------------------------------------------------------------------------- // - // fill up segments starts and ends - splitted ordered case - template - static __global__ void fillUpSegmentsKernel(const void* indices, const Nd4jLong* indexShape, int numClasses, int* classesRangesStart, int* classesRangesLenghts) { - __shared__ const I* idxBuf; - __shared__ Nd4jLong idxLen; - __shared__ int* result; - if (threadIdx.x == 0) { - idxBuf = reinterpret_cast(indices); - idxLen = shape::length(indexShape); - } - __syncthreads(); - - auto tid = threadIdx.x + blockDim.x * blockIdx.x; - auto step = blockDim.x * gridDim.x; - - for (auto j = tid; j < idxLen; j += step) { - auto pos = idxBuf[j]; - sd::math::atomics::nd4j_atomicMin(&classesRangesStart[pos], (int)j); - sd::math::atomics::nd4j_atomicAdd(&classesRangesLenghts[pos], 1); - } - } - - // -------------------------------------------------------------------------------------------------------------- // - - template - static void fillUpSegments_(NDArray* indices, Nd4jLong numClasses, NDArray& classesRangesBegs, NDArray& classesRangesLens) { - dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - auto stream = classesRangesBegs.getContext()->getCudaStream(); - fillUpSegmentsKernel<<>>(indices->specialBuffer(), indices->specialShapeInfo(), numClasses, begins, lengths); - } - // -------------------------------------------------------------------------------------------------------------- // - - void fillUpSegments(NDArray* indices, Nd4jLong numClasses, NDArray& classesRangesBegs, NDArray& classesRangesLens) { - BUILD_SINGLE_SELECTOR(indices->dataType(), fillUpSegments_, (indices, numClasses, classesRangesBegs, classesRangesLens), INDEXING_TYPES); - } - // -------------------------------------------------------------------------------------------------------------- // +// -------------------------------------------------------------------------------------------------------------- +// // Sorted segments ops implementations +template +static bool segmentIndicesValidate_(NDArray* indices, NDArray& aexpected, + NDArray& aoutput) { + return true; } + +bool segmentIndicesValidate(sd::LaunchContext* context, NDArray* indices, + NDArray& expected, NDArray& output) { + BUILD_DOUBLE_SELECTOR( + output.dataType(), indices->dataType(), return segmentIndicesValidate_, + (indices, expected, output), NUMERIC_TYPES, INDEXING_TYPES); +} + +// -------------------------------------------------------------------------------------------------------------- +// // Unsorted segment ops functors implementation +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void unsortedSegmentIndexValidateKernel( + const I* indices, const Nd4jLong* indicesShape, I expected, I* found) { + __shared__ bool onlyTrue; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + onlyTrue = true; + len = shape::length(indicesShape); + } + __syncthreads(); + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = gridDim.x * blockDim.x; + for (int e = start; e < len && onlyTrue; e += step) { + sd::math::atomics::nd4j_atomicMax(found, indices[e]); + if (expected < *found) onlyTrue = false; + } +} + +template +static bool unsortedSegmentIndicesValidate_(sd::LaunchContext* context, + NDArray* indices, Nd4jLong expected, + Nd4jLong& output) { + output = expected; + I found = output; + I exp = expected; + auto stream = context->getCudaStream(); + I* devFound; + cudaMalloc(&devFound, sizeof(I)); + cudaMemcpy(devFound, &found, sizeof(I), cudaMemcpyHostToDevice); + unsortedSegmentIndexValidateKernel + <<<1, indices->lengthOf(), 128, *stream>>>( + reinterpret_cast(indices->specialBuffer()), + indices->specialShapeInfo(), exp, devFound); + cudaMemcpy(&found, devFound, sizeof(I), cudaMemcpyDeviceToHost); + cudaFree(devFound); + output = found; + return expected == output; +} + +bool unsortedSegmentIndicesValidate(sd::LaunchContext* context, + NDArray* indices, Nd4jLong expected, + Nd4jLong& output) { + BUILD_SINGLE_SELECTOR(indices->dataType(), + return unsortedSegmentIndicesValidate_, + (context, indices, expected, output), INDEXING_TYPES); +} + +// -------------------------------------------------------------------------------------------------------------- +// // + +// -------------------------------------------------------------------------------------------------------------- +// // fill up segments starts and ends - splitted ordered case +template +static __global__ void fillUpSegmentsKernel(const void* indices, + const Nd4jLong* indexShape, + int numClasses, + int* classesRangesStart, + int* classesRangesLenghts) { + __shared__ const I* idxBuf; + __shared__ Nd4jLong idxLen; + __shared__ int* result; + if (threadIdx.x == 0) { + idxBuf = reinterpret_cast(indices); + idxLen = shape::length(indexShape); + } + __syncthreads(); + + auto tid = threadIdx.x + blockDim.x * blockIdx.x; + auto step = blockDim.x * gridDim.x; + + for (auto j = tid; j < idxLen; j += step) { + auto pos = idxBuf[j]; + sd::math::atomics::nd4j_atomicMin(&classesRangesStart[pos], (int)j); + sd::math::atomics::nd4j_atomicAdd(&classesRangesLenghts[pos], 1); + } +} + +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static void fillUpSegments_(NDArray* indices, Nd4jLong numClasses, + NDArray& classesRangesBegs, + NDArray& classesRangesLens) { + dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + auto stream = classesRangesBegs.getContext()->getCudaStream(); + fillUpSegmentsKernel<<>>( + indices->specialBuffer(), indices->specialShapeInfo(), numClasses, begins, + lengths); } +// -------------------------------------------------------------------------------------------------------------- +// // + +void fillUpSegments(NDArray* indices, Nd4jLong numClasses, + NDArray& classesRangesBegs, NDArray& classesRangesLens) { + BUILD_SINGLE_SELECTOR( + indices->dataType(), fillUpSegments_, + (indices, numClasses, classesRangesBegs, classesRangesLens), + INDEXING_TYPES); } -// -------------------------------------------------------------------------------------------------------------- // -// -------------------------------------------------------------------------------------------------------------- // +// -------------------------------------------------------------------------------------------------------------- +// // + +} // namespace helpers +} // namespace ops +} // namespace sd +// -------------------------------------------------------------------------------------------------------------- +// // +// -------------------------------------------------------------------------------------------------------------- +// // diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu index 927b1bb2f5d9..b84412e12e74 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu @@ -18,413 +18,509 @@ // @author GS // -#include -#include - #include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { - namespace ops { - namespace helpers { - - // -------------------------------------------------------------------------------------------------------------- // - // Segment ops linear kernels - // -------------------------------------------------------------------------------------------------------------- // - - template - static __global__ void - segmentMaxLinearKernel(void *input, Nd4jLong const* inputShape, int *starts, int *lengths, Nd4jLong numOfClasses, - void *output, Nd4jLong const* outputShape) { - __shared__ T *val; - __shared__ Nd4jLong xLen, zLen, zIndex; - __shared__ T *x; - __shared__ T *z; - __shared__ int threadsPerSegment, start, finish; - - auto segment = blockIdx.x; - if (threadIdx.x == 0) { -// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; -// segment = blockIdx.x / threadsPerSegment; - x = reinterpret_cast(input); - z = reinterpret_cast(output); - extern __shared__ unsigned char shmem[]; - val = reinterpret_cast(shmem); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - - if (segment < numOfClasses) { - zIndex = shape::getIndexOffset(segment, outputShape); - start = starts[segment]; - finish = start + lengths[segment]; - z[zIndex] = x[shape::getIndexOffset(start, inputShape)]; - val[segment] = z[zIndex]; - } - - } - __syncthreads(); - - for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); - } - } - // -------------------------------------------------------------------------------------------------------------- // - - template - static __global__ void - unsortedSegmentMaxLinearKernel(void *input, Nd4jLong const* inputShape, void *indices, Nd4jLong const* indicesShape, - int *starts, int *lengths, Nd4jLong numOfClasses, void *output, - Nd4jLong const* outputShape) { - __shared__ T *val; - __shared__ Nd4jLong xLen, zLen, zIndex; - __shared__ T *x; - __shared__ T *z; - __shared__ I *y; //int threadsPerSegment, start, finish; - auto segment = blockIdx.x; - - if (threadIdx.x == 0) { - x = reinterpret_cast(input); - z = reinterpret_cast(output); - y = reinterpret_cast(indices); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - - zIndex = shape::getIndexOffset(segment, outputShape); - //start = starts[segment]; - //finish = start + lengths[segment]; - if (lengths[segment] > 0) - z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)]; - else - z[zIndex] = -DataTypeUtils::max(); - } - __syncthreads(); - if (lengths[segment] > 0) - for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - auto yIndex = shape::getIndexOffset(e, indicesShape); - if (y[yIndex] == segment) { - sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentMaxTadKernel(void* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, - Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, - Nd4jLong const* outputShape, Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets, T filler = 0) { - - __shared__ T* val; - __shared__ Nd4jLong len, zIndex, total; - __shared__ T* z; - __shared__ int start, finish; - __shared__ I segment; - - if (threadIdx.x == 0) { - segment = indices[blockIdx.x]; // / threadsPerSegment; - z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; - len = shape::length(inputTads); - - start = starts[segment]; - finish = start + lengths[segment]; - total = shape::sizeAt(inputShape, 0); - } - __syncthreads(); - - auto idx = blockIdx.x; - if (idx <= total) { - auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; - if (blockIdx.x == start) { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); - //z[zIndex] = x[xIndex]; - } - } - else { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - if (lengths[segment]) - sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); - } - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - - template - static void segmentMaxFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { - //int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - //Nd4jLong idx = indices->e(0); - output->assign(-DataTypeUtils::infOrMax()); - auto stream = context->getCudaStream(); - indices->syncToHost(); - Nd4jLong numOfClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); - - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(256, 512, 256); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); - - NDArray::prepareSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens}); - - if (input->isVector()) { - - segmentMaxLinearKernel<<lengthOf(), numOfClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - segmentMaxTadKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens}); - } - // -------------------------------------------------------------------------------------------------------------- // - void segmentMaxFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } - // -------------------------------------------------------------------------------------------------------------- // - - template - static void unsortedSegmentMaxFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); -// NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - output->assign(DataTypeUtils::infOrMax()); - - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); -// NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); -// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), row, classes); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); -// int* classesBuf = reinterpret_cast(classes.specialBuffer()); - fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - - if (input->isVector()) { - unsortedSegmentMaxLinearKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - dims.x = input->sizeAt(0); - output->assign(-DataTypeUtils::max()); - segmentMaxTadKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - - } - // -------------------------------------------------------------------------------------------------------------- // - void unsortedSegmentMaxFunctor(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - output->nullify(); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } - - // -------------------------------------------------------------------------------------------------------------- // - // segment max - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentMaxBPLinearKernel(void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, - Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, - void* outputBuf, Nd4jLong const* outputShape) { - __shared__ T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, gradLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - gradIn = reinterpret_cast(forwardOutput); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; - - for (auto e = start; e < xLen; e += step) { - - auto zOffset = shape::getIndexOffset(e, outputShape); - auto xOffset = shape::getIndexOffset(e, inputShape); - auto yOffset = shape::getIndexOffset(e, indicesShape); - auto classIndex = y[yOffset]; - auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape); - auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); - - if (sd::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) { - z[zOffset] = gradOut[gradOffsetO]; - } - } - } - - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentMaxBPTadKernel(void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, - Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, - void* outputBuf, Nd4jLong const* outputShape,Nd4jLong const* inputTad, - Nd4jLong const* inputOffsets, Nd4jLong const* gradInTad, Nd4jLong const* gradInOffsets, - Nd4jLong const* gradOutTad, Nd4jLong const* gradOutOffsets, Nd4jLong const* outTad, - Nd4jLong const* outOffsets) { - __shared__ T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - yLen = shape::length(indicesShape); - gradOut = reinterpret_cast(eps); - gradIn = reinterpret_cast(forwardOutput); - gradLen = shape::length(epsShape); - currentLen = shape::length(outTad); - } - __syncthreads(); - - for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { - auto yIndex = shape::getIndexOffset(i, indicesShape); - auto segment = y[yIndex]; - T* current = x + inputOffsets[i]; - T* currentOut = z + outOffsets[i]; - T* in = gradIn + gradInOffsets[segment]; - T* outGrad = gradOut + gradOutOffsets[segment]; - - for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { - if (sd::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6)) - currentOut[e] = outGrad[e]; - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - template - int segmentMaxFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - //int numOfClasses = gradOut->sizeAt(0); - // if input is a vector: (as if in doc sample) - auto stream = context->getCudaStream(); - NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT(), context);//->shapeInfo(), context); - segmentMaxFunctor_(context, input, indices, &tempRes); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentMaxBPLinearKernel<<<1 + gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto packGradIn = sd::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->shapeInfo(), dimensions); - Nd4jLong const* inputTads = packX.specialShapeInfo(); - Nd4jLong const* inputTadOffsets = packX.specialOffsets(); - Nd4jLong const* outputTads = packZ.specialShapeInfo(); - Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); - Nd4jLong const* gradInTads = packGradIn.specialShapeInfo(); - Nd4jLong const* gradInTadOffsets = packGradIn.specialOffsets(); - Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); - Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentMaxBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); - return Status::OK(); - } - // -------------------------------------------------------------------------------------------------------------- // - int segmentMaxFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, (context, input, - indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - } - - // -------------------------------------------------------------------------------------------------------------- // - template - static int unsortedSegmentMaxFunctorBP_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - //int numOfClasses = gradOut->sizeAt(0); - // if input is a vector: (as if in doc sample) - auto stream = context->getCudaStream(); - NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT(), context);//->shapeInfo(), context); - unsortedSegmentMaxFunctor_(context, input, indices, numOfClasses, &tempRes); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentMaxBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto packGradIn = sd::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->shapeInfo(), dimensions); - Nd4jLong const* inputTads = packX.specialShapeInfo(); - Nd4jLong const* inputTadOffsets = packX.specialOffsets(); - Nd4jLong const* outputTads = packZ.specialShapeInfo(); - Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); - Nd4jLong const* gradInTads = packGradIn.specialShapeInfo(); - Nd4jLong const* gradInTadOffsets = packGradIn.specialOffsets(); - Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); - Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentMaxBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); - return Status::OK(); - } - // -------------------------------------------------------------------------------------------------------------- // - int unsortedSegmentMaxFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - } - } +namespace ops { +namespace helpers { + +// -------------------------------------------------------------------------------------------------------------- +// // Segment ops linear kernels +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static __global__ void segmentMaxLinearKernel( + void* input, Nd4jLong const* inputShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, zIndex; + __shared__ T* x; + __shared__ T* z; + __shared__ int threadsPerSegment, start, finish; + + auto segment = blockIdx.x; + if (threadIdx.x == 0) { + // threadsPerSegment = (gridDim.x + numOfClasses - 1) / + // numOfClasses; segment = blockIdx.x / + // threadsPerSegment; + x = reinterpret_cast(input); + z = reinterpret_cast(output); + extern __shared__ unsigned char shmem[]; + val = reinterpret_cast(shmem); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + if (segment < numOfClasses) { + zIndex = shape::getIndexOffset(segment, outputShape); + start = starts[segment]; + finish = start + lengths[segment]; + z[zIndex] = x[shape::getIndexOffset(start, inputShape)]; + val[segment] = z[zIndex]; + } + } + __syncthreads(); + + for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static __global__ void unsortedSegmentMaxLinearKernel( + void* input, Nd4jLong const* inputShape, void* indices, + Nd4jLong const* indicesShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, zIndex; + __shared__ T* x; + __shared__ T* z; + __shared__ I* y; // int threadsPerSegment, start, finish; + auto segment = blockIdx.x; + + if (threadIdx.x == 0) { + x = reinterpret_cast(input); + z = reinterpret_cast(output); + y = reinterpret_cast(indices); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + zIndex = shape::getIndexOffset(segment, outputShape); + // start = starts[segment]; + // finish = start + lengths[segment]; + if (lengths[segment] > 0) + z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)]; + else + z[zIndex] = -DataTypeUtils::max(); + } + __syncthreads(); + if (lengths[segment] > 0) + for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + auto yIndex = shape::getIndexOffset(e, indicesShape); + if (y[yIndex] == segment) { + sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); + } + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentMaxTadKernel( + void* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, + Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, + Nd4jLong numOfClasses, void* outputBuf, Nd4jLong const* outputShape, + Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets, + T filler = 0) { + __shared__ T* val; + __shared__ Nd4jLong len, zIndex, total; + __shared__ T* z; + __shared__ int start, finish; + __shared__ I segment; + + if (threadIdx.x == 0) { + segment = indices[blockIdx.x]; // / threadsPerSegment; + z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; + len = shape::length(inputTads); + + start = starts[segment]; + finish = start + lengths[segment]; + total = shape::sizeAt(inputShape, 0); + } + __syncthreads(); + + auto idx = blockIdx.x; + if (idx <= total) { + auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; + if (blockIdx.x == start) { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); + // z[zIndex] = x[xIndex]; + } + } else { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + if (lengths[segment]) + sd::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); + } + } + } +} +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static void segmentMaxFunctor_(LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + // int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + // Nd4jLong idx = indices->e(0); + output->assign(-DataTypeUtils::infOrMax()); + auto stream = context->getCudaStream(); + indices->syncToHost(); + Nd4jLong numOfClasses = indices->e(indices->lengthOf() - 1) + 1; + NDArray classesRangesLens = + NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numOfClasses}, context); + + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(256, 512, 256); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); + + NDArray::prepareSpecialUse( + {output}, {input, indices, &classesRangesBegs, &classesRangesLens}); + + if (input->isVector()) { + segmentMaxLinearKernel + <<lengthOf(), numOfClasses * 32 + 32, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), begins, lengths, + numOfClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + segmentMaxTadKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numOfClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse( + {output}, {input, indices, &classesRangesBegs, &classesRangesLens}); +} +// -------------------------------------------------------------------------------------------------------------- +// // +void segmentMaxFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + segmentMaxFunctor_, (context, input, indices, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static void unsortedSegmentMaxFunctor_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, NDArray* output) { + auto stream = context->getCudaStream(); + // NDArray classes = NDArrayFactory::create('c', {numOfClasses, + // 2}); + output->assign(DataTypeUtils::infOrMax()); + + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = + NDArrayFactory::create('c', {numOfClasses}, context); + // NDArray row = NDArrayFactory::create('c', {1, 2}, + // {(int)indices->lengthOf(), (int)0}); + // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), row, + // classes); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); + // int* classesBuf = reinterpret_cast(classes.specialBuffer()); + fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + + if (input->isVector()) { + unsortedSegmentMaxLinearKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, + numOfClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + dims.x = input->sizeAt(0); + output->assign(-DataTypeUtils::max()); + segmentMaxTadKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numOfClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +void unsortedSegmentMaxFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + output->nullify(); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + unsortedSegmentMaxFunctor_, + (context, input, indices, numOfClasses, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} + +// -------------------------------------------------------------------------------------------------------------- +// // segment max +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentMaxBPLinearKernel( + void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, + Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, + void* indicesBuf, Nd4jLong const* indicesShape, void* outputBuf, + Nd4jLong const* outputShape) { + __shared__ T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, gradLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + gradIn = reinterpret_cast(forwardOutput); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; + + for (auto e = start; e < xLen; e += step) { + auto zOffset = shape::getIndexOffset(e, outputShape); + auto xOffset = shape::getIndexOffset(e, inputShape); + auto yOffset = shape::getIndexOffset(e, indicesShape); + auto classIndex = y[yOffset]; + auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape); + auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); + + if (sd::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) { + z[zOffset] = gradOut[gradOffsetO]; + } + } +} + +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentMaxBPTadKernel( + void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, + Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, + void* indicesBuf, Nd4jLong const* indicesShape, void* outputBuf, + Nd4jLong const* outputShape, Nd4jLong const* inputTad, + Nd4jLong const* inputOffsets, Nd4jLong const* gradInTad, + Nd4jLong const* gradInOffsets, Nd4jLong const* gradOutTad, + Nd4jLong const* gradOutOffsets, Nd4jLong const* outTad, + Nd4jLong const* outOffsets) { + __shared__ T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + yLen = shape::length(indicesShape); + gradOut = reinterpret_cast(eps); + gradIn = reinterpret_cast(forwardOutput); + gradLen = shape::length(epsShape); + currentLen = shape::length(outTad); + } + __syncthreads(); + + for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { + auto yIndex = shape::getIndexOffset(i, indicesShape); + auto segment = y[yIndex]; + T* current = x + inputOffsets[i]; + T* currentOut = z + outOffsets[i]; + T* in = gradIn + gradInOffsets[segment]; + T* outGrad = gradOut + gradOutOffsets[segment]; + + for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { + if (sd::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6)) + currentOut[e] = outGrad[e]; } -} \ No newline at end of file + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +int segmentMaxFunctorBP_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + // int numOfClasses = gradOut->sizeAt(0); + // if input is a vector: (as if in doc sample) + auto stream = context->getCudaStream(); + NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), + DataTypeUtils::fromT(), + context); //->shapeInfo(), context); + segmentMaxFunctor_(context, input, indices, &tempRes); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentMaxBPLinearKernel + <<<1 + gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradIn = sd::ConstantTadHelper::getInstance()->tadForDimensions( + tempRes.shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions( + gradOut->shapeInfo(), dimensions); + Nd4jLong const* inputTads = packX.specialShapeInfo(); + Nd4jLong const* inputTadOffsets = packX.specialOffsets(); + Nd4jLong const* outputTads = packZ.specialShapeInfo(); + Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); + Nd4jLong const* gradInTads = packGradIn.specialShapeInfo(); + Nd4jLong const* gradInTadOffsets = packGradIn.specialOffsets(); + Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); + Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentMaxBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, + gradOutTadOffsets, outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); + return Status::OK(); +} +// -------------------------------------------------------------------------------------------------------------- +// // +int segmentMaxFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, + (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} + +// -------------------------------------------------------------------------------------------------------------- +// // +template +static int unsortedSegmentMaxFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, Nd4jLong numOfClasses, + NDArray* output) { + // int numOfClasses = gradOut->sizeAt(0); + // if input is a vector: (as if in doc sample) + auto stream = context->getCudaStream(); + NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), + DataTypeUtils::fromT(), + context); //->shapeInfo(), context); + unsortedSegmentMaxFunctor_(context, input, indices, numOfClasses, + &tempRes); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentMaxBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradIn = sd::ConstantTadHelper::getInstance()->tadForDimensions( + tempRes.shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions( + gradOut->shapeInfo(), dimensions); + Nd4jLong const* inputTads = packX.specialShapeInfo(); + Nd4jLong const* inputTadOffsets = packX.specialOffsets(); + Nd4jLong const* outputTads = packZ.specialShapeInfo(); + Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); + Nd4jLong const* gradInTads = packGradIn.specialShapeInfo(); + Nd4jLong const* gradInTadOffsets = packGradIn.specialOffsets(); + Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); + Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentMaxBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, + gradOutTadOffsets, outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); + return Status::OK(); +} +// -------------------------------------------------------------------------------------------------------------- +// // +int unsortedSegmentMaxFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), + return unsortedSegmentMaxFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, + INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu index c75293c1da83..68262f509b10 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu @@ -18,400 +18,502 @@ // @author GS // -#include -#include #include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - // -------------------------------------------------------------------------------------------------------------- // - // Segment ops linear kernels - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentMeanLinearKernel(void* input, Nd4jLong const* inputShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { - __shared__ T* val; - __shared__ Nd4jLong xLen, zLen, segment, zIndex; - __shared__ T* x; - __shared__ T* z; - __shared__ int threadsPerSegment, start, finish; - - if (threadIdx.x == 0) { - threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; - segment = blockIdx.x / threadsPerSegment; - x = reinterpret_cast(input); - z = reinterpret_cast(output); -// extern __shared__ unsigned char shmem[]; -// val = reinterpret_cast(shmem); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - - //[zIndex] = - if (segment < numOfClasses) { - zIndex = shape::getIndexOffset(segment, outputShape); - start = starts[segment]; - finish = start + lengths[segment]; - //val[segment] = ; - z[zIndex] = T(x[shape::getIndexOffset(start, inputShape)] / lengths[segment]); -// val[segment] = z[zIndex]; - } - - } - __syncthreads(); - - for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - if (lengths[segment]) - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex] / lengths[segment])); - } - } - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void unsortedSegmentMeanLinearKernel(void* input, Nd4jLong const* inputShape, void* indices, Nd4jLong const* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { - __shared__ T* val; - __shared__ Nd4jLong xLen, zLen, zIndex; - __shared__ T* x; - __shared__ T* z; - __shared__ I* y; //int threadsPerSegment, start, finish; - auto segment = blockIdx.x;// / - if (threadIdx.x == 0) { -// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; -// threadsPerSegment; - x = reinterpret_cast(input); - z = reinterpret_cast(output); - y = reinterpret_cast(indices); -// extern __shared__ unsigned char shmem[]; -// val = reinterpret_cast(shmem); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - -// if (segment < numOfClasses) { - zIndex = shape::getIndexOffset(segment, outputShape); - //start = starts[segment]; - //finish = start + lengths[segment]; - if (lengths[segment] > 0) - z[zIndex] = T(x[shape::getIndexOffset(starts[segment], inputShape)] / T(lengths[segment])); - else - z[zIndex] = 0; //DataTypeUtils::max(); -// val[segment] = z[zIndex]; -// } - - } - __syncthreads(); - if (lengths[segment] > 0) - for (auto e = threadIdx.x; e < xLen; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - auto yIndex = shape::getIndexOffset(e, indicesShape); - if (y[yIndex] == segment && e != starts[segment]) { - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex]/T(lengths[segment]))); - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - // SegmentMean kernel - template - static __global__ void segmentMeanTadKernel(void* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong const* outputShape, Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets) { - __shared__ T* val; - __shared__ Nd4jLong len, zIndex, total; - __shared__ T* z; - __shared__ int threadsPerSegment, start, finish; - auto segment = indices[blockIdx.x]; // / threadsPerSegment; - - if (threadIdx.x == 0) { - z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; - len = shape::length(inputTads); - start = starts[segment]; - finish = start + lengths[segment]; - total = shape::sizeAt(inputShape, 0); - - } - __syncthreads(); - - auto idx = blockIdx.x; - if (blockIdx.x <= total) { - auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; - if (blockIdx.x == start) { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex]/lengths[segment])); - } - } - else { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - if (lengths[segment]) - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex]/lengths[segment])); - } - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - // segmen mean - template - static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { - auto stream = context->getCudaStream(); - Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); - - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - NDArray::prepareSpecialUse({output}, {input, indices}); - dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); - - if (input->isVector()) { - segmentMeanLinearKernel<<lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - segmentMeanTadKernel<<sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices}); - - } - // -------------------------------------------------------------------------------------------------------------- // - void segmentMeanFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } - - // -------------------------------------------------------------------------------------------------------------- // - template - static void unsortedSegmentMeanFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); -// NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); -// NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); -// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); -// int* classesBuf = reinterpret_cast(classes.specialBuffer()); - fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - - if (input->isVector()) { - unsortedSegmentMeanLinearKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - output->assign(0); - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - Nd4jLong const* inputTads = packX.specialShapeInfo(); - Nd4jLong const* inputTadOffsets = packX.specialOffsets(); - Nd4jLong const* outputTads = packZ.specialShapeInfo(); - Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); - dims.x = input->sizeAt(0); - segmentMeanTadKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - +// -------------------------------------------------------------------------------------------------------------- +// // Segment ops linear kernels +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentMeanLinearKernel( + void* input, Nd4jLong const* inputShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, segment, zIndex; + __shared__ T* x; + __shared__ T* z; + __shared__ int threadsPerSegment, start, finish; + + if (threadIdx.x == 0) { + threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; + segment = blockIdx.x / threadsPerSegment; + x = reinterpret_cast(input); + z = reinterpret_cast(output); + // extern __shared__ unsigned char shmem[]; + // val = reinterpret_cast(shmem); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + //[zIndex] = + if (segment < numOfClasses) { + zIndex = shape::getIndexOffset(segment, outputShape); + start = starts[segment]; + finish = start + lengths[segment]; + // val[segment] = ; + z[zIndex] = + T(x[shape::getIndexOffset(start, inputShape)] / lengths[segment]); + // val[segment] = z[zIndex]; } - // -------------------------------------------------------------------------------------------------------------- // - void unsortedSegmentMeanFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMeanFunctor_, (context, input, indices, numOfClasses, output), - NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } - - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentMeanBPLinearKernel(void* inputBuf, Nd4jLong const* inputShape, void* eps, Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, - int* lengths, void* outputBuf, Nd4jLong const* outputShape) { - __shared__ T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, gradLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; - - for (auto e = start; e < xLen; e += step) { - - auto zOffset = shape::getIndexOffset(e, outputShape); - auto xOffset = shape::getIndexOffset(e, inputShape); - auto yOffset = shape::getIndexOffset(e, indicesShape); - auto classIndex = y[yOffset]; - auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); - - z[zOffset] = T(gradOut[gradOffsetO] / float(lengths[classIndex])); - } - } - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentMeanBPTadKernel(void* inputBuf, Nd4jLong const* inputShape, void* eps, Nd4jLong const* epsShape, - void* indicesBuf, Nd4jLong const* indicesShape, int* lengths, void* outputBuf, Nd4jLong const* outputShape,Nd4jLong const* inputTad, - Nd4jLong const* inputOffsets, Nd4jLong const* gradOutTad, Nd4jLong const* gradOutOffsets, Nd4jLong const* outTad, Nd4jLong const* outOffsets) { - __shared__ T* x; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - yLen = shape::length(indicesShape); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - currentLen = shape::length(outTad); - } - __syncthreads(); - - for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { -// auto yIndex = shape::getIndexOffset(i, indicesShape); - auto segment = y[i]; //yIndex]; - T* currentOut = z + outOffsets[i]; - T* outGrad = gradOut + gradOutOffsets[segment]; - - for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { - auto zIndex = shape::getIndexOffset(e, outTad); - auto gradIndex = shape::getIndexOffset(e, gradOutTad); - if (lengths[segment] > 0) - currentOut[zIndex] = T(outGrad[gradIndex] / float(lengths[segment])); - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - // backrop for mean - template - int segmentMeanFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - auto numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); - - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); - fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentMeanBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), - input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); -// auto packGradIn = sd::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->shapeInfo(), dimensions); - Nd4jLong const* inputTads = packX.specialShapeInfo(); - Nd4jLong const* inputTadOffsets = packX.specialOffsets(); - Nd4jLong const* outputTads = packZ.specialShapeInfo(); - Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); - Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); - Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentMeanBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths, - output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - return Status::OK(); + } + __syncthreads(); + + for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + if (lengths[segment]) + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], + T(x[xIndex] / lengths[segment])); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void unsortedSegmentMeanLinearKernel( + void* input, Nd4jLong const* inputShape, void* indices, + Nd4jLong const* indicesShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, zIndex; + __shared__ T* x; + __shared__ T* z; + __shared__ I* y; // int threadsPerSegment, start, finish; + auto segment = blockIdx.x; // / + if (threadIdx.x == 0) { + // threadsPerSegment = (gridDim.x + numOfClasses - 1) / + // numOfClasses; threadsPerSegment; + x = reinterpret_cast(input); + z = reinterpret_cast(output); + y = reinterpret_cast(indices); + // extern __shared__ unsigned char shmem[]; + // val = reinterpret_cast(shmem); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + // if (segment < numOfClasses) { + zIndex = shape::getIndexOffset(segment, outputShape); + // start = starts[segment]; + // finish = start + lengths[segment]; + if (lengths[segment] > 0) + z[zIndex] = T(x[shape::getIndexOffset(starts[segment], inputShape)] / + T(lengths[segment])); + else + z[zIndex] = 0; // DataTypeUtils::max(); + // val[segment] = z[zIndex]; + // } + } + __syncthreads(); + if (lengths[segment] > 0) + for (auto e = threadIdx.x; e < xLen; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + auto yIndex = shape::getIndexOffset(e, indicesShape); + if (y[yIndex] == segment && e != starts[segment]) { + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], + T(x[xIndex] / T(lengths[segment]))); + } } - // -------------------------------------------------------------------------------------------------------------- // - // segmen mean bp main - int segmentMeanFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, (context, input, - indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} +// -------------------------------------------------------------------------------------------------------------- +// // SegmentMean kernel +template +static __global__ void segmentMeanTadKernel( + void* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, + Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, + Nd4jLong numOfClasses, void* outputBuf, Nd4jLong const* outputShape, + Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets) { + __shared__ T* val; + __shared__ Nd4jLong len, zIndex, total; + __shared__ T* z; + __shared__ int threadsPerSegment, start, finish; + auto segment = indices[blockIdx.x]; // / threadsPerSegment; + + if (threadIdx.x == 0) { + z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; + len = shape::length(inputTads); + start = starts[segment]; + finish = start + lengths[segment]; + total = shape::sizeAt(inputShape, 0); + } + __syncthreads(); + + auto idx = blockIdx.x; + if (blockIdx.x <= total) { + auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; + if (blockIdx.x == start) { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], + T(x[xIndex] / lengths[segment])); + } + } else { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + if (lengths[segment]) + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], + T(x[xIndex] / lengths[segment])); + } } - // -------------------------------------------------------------------------------------------------------------- // - - template - static int unsortedSegmentMeanFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - auto numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); - - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); - fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // segmen mean +template +static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + auto stream = context->getCudaStream(); + Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; + NDArray classesRangesLens = + NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numClasses}, context); + + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + NDArray::prepareSpecialUse({output}, {input, indices}); + dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); + + if (input->isVector()) { + segmentMeanLinearKernel + <<lengthOf(), numClasses * 32 + 32, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), begins, lengths, + numClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + segmentMeanTadKernel<<sizeAt(0), 512, 2048, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices}); +} +// -------------------------------------------------------------------------------------------------------------- +// // +void segmentMeanFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), + segmentMeanFunctor_, (context, input, indices, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentMeanBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), - input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); -// auto packGradIn = sd::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->shapeInfo(), dimensions); - Nd4jLong const* inputTads = packX.specialShapeInfo(); - Nd4jLong const* inputTadOffsets = packX.specialOffsets(); - Nd4jLong const* outputTads = packZ.specialShapeInfo(); - Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); - Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); - Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); +// -------------------------------------------------------------------------------------------------------------- +// // +template +static void unsortedSegmentMeanFunctor_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, + NDArray* output) { + auto stream = context->getCudaStream(); + // NDArray classes = NDArrayFactory::create('c', {numOfClasses, + // 2}); + + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = + NDArrayFactory::create('c', {numOfClasses}, context); + // NDArray row = NDArrayFactory::create('c', {1, 2}, + // {(int)indices->lengthOf(), (int)0}); + // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, + // &classes); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); + // int* classesBuf = reinterpret_cast(classes.specialBuffer()); + fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + + if (input->isVector()) { + unsortedSegmentMeanLinearKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, + numOfClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + output->assign(0); + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + Nd4jLong const* inputTads = packX.specialShapeInfo(); + Nd4jLong const* inputTadOffsets = packX.specialOffsets(); + Nd4jLong const* outputTads = packZ.specialShapeInfo(); + Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); + dims.x = input->sizeAt(0); + segmentMeanTadKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numOfClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +void unsortedSegmentMeanFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + unsortedSegmentMeanFunctor_, + (context, input, indices, numOfClasses, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - segmentMeanBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths, - output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - return Status::OK(); - } - // -------------------------------------------------------------------------------------------------------------- // - int unsortedSegmentMeanFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentMeanBPLinearKernel( + void* inputBuf, Nd4jLong const* inputShape, void* eps, + Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, + int* lengths, void* outputBuf, Nd4jLong const* outputShape) { + __shared__ T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, gradLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; + + for (auto e = start; e < xLen; e += step) { + auto zOffset = shape::getIndexOffset(e, outputShape); + auto xOffset = shape::getIndexOffset(e, inputShape); + auto yOffset = shape::getIndexOffset(e, indicesShape); + auto classIndex = y[yOffset]; + auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); + + z[zOffset] = T(gradOut[gradOffsetO] / float(lengths[classIndex])); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentMeanBPTadKernel( + void* inputBuf, Nd4jLong const* inputShape, void* eps, + Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, + int* lengths, void* outputBuf, Nd4jLong const* outputShape, + Nd4jLong const* inputTad, Nd4jLong const* inputOffsets, + Nd4jLong const* gradOutTad, Nd4jLong const* gradOutOffsets, + Nd4jLong const* outTad, Nd4jLong const* outOffsets) { + __shared__ T* x; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + yLen = shape::length(indicesShape); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + currentLen = shape::length(outTad); + } + __syncthreads(); + + for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { + // auto yIndex = shape::getIndexOffset(i, indicesShape); + auto segment = y[i]; // yIndex]; + T* currentOut = z + outOffsets[i]; + T* outGrad = gradOut + gradOutOffsets[segment]; + + for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { + auto zIndex = shape::getIndexOffset(e, outTad); + auto gradIndex = shape::getIndexOffset(e, gradOutTad); + if (lengths[segment] > 0) + currentOut[zIndex] = T(outGrad[gradIndex] / float(lengths[segment])); } - + } +} +// -------------------------------------------------------------------------------------------------------------- +// // backrop for mean +template +int segmentMeanFunctorBP_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + auto numClasses = indices->e(indices->lengthOf() - 1) + 1; + NDArray classesRangesLens = + NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numClasses}, context); + + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); + fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentMeanBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), lengths, + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + // auto packGradIn = + // sd::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.shapeInfo(), + // dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions( + gradOut->shapeInfo(), dimensions); + Nd4jLong const* inputTads = packX.specialShapeInfo(); + Nd4jLong const* inputTadOffsets = packX.specialOffsets(); + Nd4jLong const* outputTads = packZ.specialShapeInfo(); + Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); + Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); + Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentMeanBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), lengths, + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradOutTads, gradOutTadOffsets, outputTads, + outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); + return Status::OK(); } +// -------------------------------------------------------------------------------------------------------------- +// // segmen mean bp main +int segmentMeanFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, + (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } -} \ No newline at end of file +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static int unsortedSegmentMeanFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, + Nd4jLong numOfClasses, + NDArray* output) { + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + auto numClasses = indices->e(indices->lengthOf() - 1) + 1; + NDArray classesRangesLens = + NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numClasses}, context); + + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); + fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentMeanBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), lengths, + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + // auto packGradIn = + // sd::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.shapeInfo(), + // dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions( + gradOut->shapeInfo(), dimensions); + Nd4jLong const* inputTads = packX.specialShapeInfo(); + Nd4jLong const* inputTadOffsets = packX.specialOffsets(); + Nd4jLong const* outputTads = packZ.specialShapeInfo(); + Nd4jLong const* outputTadOffsets = packZ.specialOffsets(); + Nd4jLong const* gradOutTads = packGradOut.specialShapeInfo(); + Nd4jLong const* gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentMeanBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), lengths, + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradOutTads, gradOutTadOffsets, outputTads, + outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); + return Status::OK(); +} +// -------------------------------------------------------------------------------------------------------------- +// // +int unsortedSegmentMeanFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), + return unsortedSegmentMeanFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, + INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu index c6f2d4ed25bd..ed667105ccca 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu @@ -18,411 +18,493 @@ // @author GS // -#include -#include #include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - // -------------------------------------------------------------------------------------------------------------- // - // Segment ops linear kernels - // -------------------------------------------------------------------------------------------------------------- // - - template - static __global__ void - segmentMinLinearKernel(const void *input, const Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses, - void *output, const Nd4jLong *outputShape) { - __shared__ T *val; - __shared__ Nd4jLong xLen, zLen, zIndex; - __shared__ const T *x; - __shared__ T *z; - __shared__ int threadsPerSegment, start, finish; - - auto segment = blockIdx.x; - if (threadIdx.x == 0) { -// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; -// segment = blockIdx.x / threadsPerSegment; - x = reinterpret_cast(input); - z = reinterpret_cast(output); - extern __shared__ unsigned char shmem[]; - val = reinterpret_cast(shmem); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - - if (segment < numOfClasses) { - zIndex = shape::getIndexOffset(segment, outputShape); - start = starts[segment]; - finish = start + lengths[segment]; - z[zIndex] = x[shape::getIndexOffset(start, inputShape)]; - val[segment] = z[zIndex]; - } - - } - __syncthreads(); - - for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); - } - +// -------------------------------------------------------------------------------------------------------------- +// // Segment ops linear kernels +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static __global__ void segmentMinLinearKernel( + const void* input, const Nd4jLong* inputShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, const Nd4jLong* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, zIndex; + __shared__ const T* x; + __shared__ T* z; + __shared__ int threadsPerSegment, start, finish; + + auto segment = blockIdx.x; + if (threadIdx.x == 0) { + // threadsPerSegment = (gridDim.x + numOfClasses - 1) / + // numOfClasses; segment = blockIdx.x / threadsPerSegment; + x = reinterpret_cast(input); + z = reinterpret_cast(output); + extern __shared__ unsigned char shmem[]; + val = reinterpret_cast(shmem); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + if (segment < numOfClasses) { + zIndex = shape::getIndexOffset(segment, outputShape); + start = starts[segment]; + finish = start + lengths[segment]; + z[zIndex] = x[shape::getIndexOffset(start, inputShape)]; + val[segment] = z[zIndex]; } - // -------------------------------------------------------------------------------------------------------------- // - - template - static __global__ void - unsortedSegmentMinLinearKernel(const void *input, const Nd4jLong *inputShape, const void *indices, const Nd4jLong *indicesShape, - int *starts, int *lengths, Nd4jLong numOfClasses, void *output, - const Nd4jLong *outputShape) { - __shared__ - T *val; - __shared__ - Nd4jLong xLen, zLen, segment, zIndex; - __shared__ - const T *x; - __shared__ - T *z; - __shared__ - const I *y; //int threadsPerSegment, start, finish; - - if (threadIdx.x == 0) { - segment = blockIdx.x; - x = reinterpret_cast(input); - z = reinterpret_cast(output); - y = reinterpret_cast(indices); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - - zIndex = shape::getIndexOffset(segment, outputShape); - if (lengths[segment] > 0) - z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)]; - else - z[zIndex] = DataTypeUtils::max(); + } + __syncthreads(); - } - __syncthreads(); - - if (lengths[segment] > 0) - for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - auto yIndex = shape::getIndexOffset(e, indicesShape); - if (y[yIndex] == segment) { - sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); - } - } + for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static __global__ void unsortedSegmentMinLinearKernel( + const void* input, const Nd4jLong* inputShape, const void* indices, + const Nd4jLong* indicesShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, const Nd4jLong* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, segment, zIndex; + __shared__ const T* x; + __shared__ T* z; + __shared__ const I* y; // int threadsPerSegment, start, finish; + + if (threadIdx.x == 0) { + segment = blockIdx.x; + x = reinterpret_cast(input); + z = reinterpret_cast(output); + y = reinterpret_cast(indices); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + zIndex = shape::getIndexOffset(segment, outputShape); + if (lengths[segment] > 0) + z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)]; + else + z[zIndex] = DataTypeUtils::max(); + } + __syncthreads(); + + if (lengths[segment] > 0) + for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + auto yIndex = shape::getIndexOffset(e, indicesShape); + if (y[yIndex] == segment) { + sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); + } } - // -------------------------------------------------------------------------------------------------------------- // +} +// -------------------------------------------------------------------------------------------------------------- +// // // SegmentMin kernel - template - static __global__ void segmentMinTadKernel(const void* inputBuf, const Nd4jLong* inputShape, const Nd4jLong* inputTads, const Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, const Nd4jLong* outputShape, const Nd4jLong* outputTads, const Nd4jLong* outputTadOffsets) { - __shared__ T* val; - __shared__ Nd4jLong len, zIndex, total; - __shared__ T* z; - __shared__ int threadsPerSegment, start, finish; - - auto segment = indices[blockIdx.x]; // / threadsPerSegment; - if (threadIdx.x == 0) { - z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; - len = shape::length(inputTads); - start = starts[segment]; - finish = start + lengths[segment]; - total = shape::sizeAt(inputShape, 0); - - } - __syncthreads(); - - auto idx = blockIdx.x; - if (blockIdx.x <= total) { - auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; - if (blockIdx.x == start) { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); - } - } - else { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); -// if (lengths[indices[idx]]) - sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); - } - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - // segmen min - template - static void segmentMinFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { - auto stream = context->getCudaStream(); - Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - auto classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); - auto classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); - output->assign(DataTypeUtils::infOrMax()); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - - fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); - NDArray::prepareSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens}); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - if (input->isVector()) { - segmentMinLinearKernel<<lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - segmentMinTadKernel<<sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - - } - NDArray::registerSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens}); - +template +static __global__ void segmentMinTadKernel( + const void* inputBuf, const Nd4jLong* inputShape, const Nd4jLong* inputTads, + const Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, + Nd4jLong numOfClasses, void* outputBuf, const Nd4jLong* outputShape, + const Nd4jLong* outputTads, const Nd4jLong* outputTadOffsets) { + __shared__ T* val; + __shared__ Nd4jLong len, zIndex, total; + __shared__ T* z; + __shared__ int threadsPerSegment, start, finish; + + auto segment = indices[blockIdx.x]; // / threadsPerSegment; + if (threadIdx.x == 0) { + z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; + len = shape::length(inputTads); + start = starts[segment]; + finish = start + lengths[segment]; + total = shape::sizeAt(inputShape, 0); + } + __syncthreads(); + + auto idx = blockIdx.x; + if (blockIdx.x <= total) { + auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; + if (blockIdx.x == start) { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); + } + } else { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + // if (lengths[indices[idx]]) + sd::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); + } } - // -------------------------------------------------------------------------------------------------------------- // - void segmentMinFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - output->nullify(); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } - - // -------------------------------------------------------------------------------------------------------------- // - - template - static void unsortedSegmentMinFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); -// NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); -// NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); -// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); - output->assign(DataTypeUtils::infOrMax()); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); -// int* classesBuf = reinterpret_cast(classes.specialBuffer()); - fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - NDArray::prepareSpecialUse({output}, {input, indices}); - if (input->isVector()) { - unsortedSegmentMinLinearKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - output->assign(DataTypeUtils::max()); - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - dims.x = input->sizeAt(0); - segmentMinTadKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices}); - - } - // -------------------------------------------------------------------------------------------------------------- // - void unsortedSegmentMinFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - output->nullify(); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output), - NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } - - template - static __global__ void segmentMinBPLinearKernel(const void* inputBuf, const Nd4jLong* inputShape, void* forwardOutput, - const Nd4jLong* forwardShape, void* eps, const Nd4jLong* epsShape, const void* indicesBuf, const Nd4jLong* indicesShape, - void* outputBuf, const Nd4jLong* outputShape) { - __shared__ const T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ const I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, gradLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - gradIn = reinterpret_cast(forwardOutput); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; - - for (auto e = start; e < xLen; e += step) { - - auto zOffset = shape::getIndexOffset(e, outputShape); - auto xOffset = shape::getIndexOffset(e, inputShape); - auto yOffset = shape::getIndexOffset(e, indicesShape); - auto classIndex = y[yOffset]; - auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape); - auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); - - if (sd::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) { - z[zOffset] = gradOut[gradOffsetO]; - } - } - } - - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentMinBPTadKernel(const void* inputBuf, const Nd4jLong* inputShape, void* forwardOutput, - const Nd4jLong* forwardShape, void* eps, const Nd4jLong* epsShape, - const void* indicesBuf, const Nd4jLong* indicesShape, - void* outputBuf, const Nd4jLong* outputShape, - const Nd4jLong* inputTad, const Nd4jLong* inputOffsets, - const Nd4jLong* gradInTad, const Nd4jLong* gradInOffsets, - const Nd4jLong* gradOutTad, const Nd4jLong* gradOutOffsets, - const Nd4jLong* outTad, const Nd4jLong* outOffsets) { - __shared__ const T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ const I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - yLen = shape::length(indicesShape); - gradOut = reinterpret_cast(eps); - gradIn = reinterpret_cast(forwardOutput); - gradLen = shape::length(epsShape); - currentLen = shape::length(outTad); - } - __syncthreads(); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // segmen min +template +static void segmentMinFunctor_(LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + auto stream = context->getCudaStream(); + Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; + auto classesRangesLens = + NDArrayFactory::create('c', {numClasses}, context); + auto classesRangesBegs = + NDArrayFactory::create('c', {numClasses}, context); + output->assign(DataTypeUtils::infOrMax()); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + + fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); + NDArray::prepareSpecialUse( + {output}, {input, indices, &classesRangesBegs, &classesRangesLens}); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + if (input->isVector()) { + segmentMinLinearKernel + <<lengthOf(), numClasses * 32 + 32, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), begins, lengths, + numClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + segmentMinTadKernel<<sizeAt(0), 512, 2048, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse( + {output}, {input, indices, &classesRangesBegs, &classesRangesLens}); +} +// -------------------------------------------------------------------------------------------------------------- +// // +void segmentMinFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + output->nullify(); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + segmentMinFunctor_, (context, input, indices, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { - auto yIndex = shape::getIndexOffset(i, indicesShape); - auto segment = y[yIndex]; - auto current = x + inputOffsets[i]; - auto currentOut = z + outOffsets[i]; - auto in = gradIn + gradInOffsets[segment]; - auto outGrad = gradOut + gradOutOffsets[segment]; +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static void unsortedSegmentMinFunctor_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, NDArray* output) { + auto stream = context->getCudaStream(); + // NDArray classes = NDArrayFactory::create('c', {numOfClasses, + // 2}); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = + NDArrayFactory::create('c', {numOfClasses}, context); + // NDArray row = NDArrayFactory::create('c', {1, 2}, + // {(int)indices->lengthOf(), (int)0}); + // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, + // &classes); + output->assign(DataTypeUtils::infOrMax()); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); + // int* classesBuf = reinterpret_cast(classes.specialBuffer()); + fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + NDArray::prepareSpecialUse({output}, {input, indices}); + if (input->isVector()) { + unsortedSegmentMinLinearKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, + numOfClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + output->assign(DataTypeUtils::max()); + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + dims.x = input->sizeAt(0); + segmentMinTadKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numOfClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices}); +} +// -------------------------------------------------------------------------------------------------------------- +// // +void unsortedSegmentMinFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + output->nullify(); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + unsortedSegmentMinFunctor_, + (context, input, indices, numOfClasses, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { - if (sd::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6)) - currentOut[e] = outGrad[e]; - } - } +template +static __global__ void segmentMinBPLinearKernel( + const void* inputBuf, const Nd4jLong* inputShape, void* forwardOutput, + const Nd4jLong* forwardShape, void* eps, const Nd4jLong* epsShape, + const void* indicesBuf, const Nd4jLong* indicesShape, void* outputBuf, + const Nd4jLong* outputShape) { + __shared__ const T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ const I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, gradLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + gradIn = reinterpret_cast(forwardOutput); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; + + for (auto e = start; e < xLen; e += step) { + auto zOffset = shape::getIndexOffset(e, outputShape); + auto xOffset = shape::getIndexOffset(e, inputShape); + auto yOffset = shape::getIndexOffset(e, indicesShape); + auto classIndex = y[yOffset]; + auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape); + auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); + + if (sd::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) { + z[zOffset] = gradOut[gradOffsetO]; } + } +} - // -------------------------------------------------------------------------------------------------------------- // - template - int segmentMinFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - //int numOfClasses = gradOut->sizeAt(0); - // if input is a vector: (as if in doc sample) - auto stream = context->getCudaStream(); - NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT(), context);//->shapeInfo(), context); - segmentMinFunctor_(context, input, indices, &tempRes); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - - segmentMinBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto packGradIn = sd::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - auto gradInTads = packGradIn.specialShapeInfo(); - auto gradInTadOffsets = packGradIn.specialOffsets(); - auto gradOutTads = packGradOut.specialShapeInfo(); - auto gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentMinBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); - return Status::OK(); - } - // -------------------------------------------------------------------------------------------------------------- // - // segmen min - int segmentMinFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMinFunctorBP_, (context, input, - indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentMinBPTadKernel( + const void* inputBuf, const Nd4jLong* inputShape, void* forwardOutput, + const Nd4jLong* forwardShape, void* eps, const Nd4jLong* epsShape, + const void* indicesBuf, const Nd4jLong* indicesShape, void* outputBuf, + const Nd4jLong* outputShape, const Nd4jLong* inputTad, + const Nd4jLong* inputOffsets, const Nd4jLong* gradInTad, + const Nd4jLong* gradInOffsets, const Nd4jLong* gradOutTad, + const Nd4jLong* gradOutOffsets, const Nd4jLong* outTad, + const Nd4jLong* outOffsets) { + __shared__ const T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ const I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + yLen = shape::length(indicesShape); + gradOut = reinterpret_cast(eps); + gradIn = reinterpret_cast(forwardOutput); + gradLen = shape::length(epsShape); + currentLen = shape::length(outTad); + } + __syncthreads(); + + for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { + auto yIndex = shape::getIndexOffset(i, indicesShape); + auto segment = y[yIndex]; + auto current = x + inputOffsets[i]; + auto currentOut = z + outOffsets[i]; + auto in = gradIn + gradInOffsets[segment]; + auto outGrad = gradOut + gradOutOffsets[segment]; + + for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { + if (sd::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6)) + currentOut[e] = outGrad[e]; } + } +} - template - static int unsortedSegmentMinFunctorBP_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - //int numOfClasses = gradOut->sizeAt(0); - // if input is a vector: (as if in doc sample) - auto stream = context->getCudaStream(); - NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT(), context);//->shapeInfo(), context); - unsortedSegmentMinFunctor_(context, input, indices, numOfClasses, &tempRes); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentMinBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto packGradIn = sd::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - auto gradInTads = packGradIn.specialShapeInfo(); - auto gradInTadOffsets = packGradIn.specialOffsets(); - auto gradOutTads = packGradOut.specialShapeInfo(); - auto gradOutTadOffsets = packGradOut.specialOffsets(); +// -------------------------------------------------------------------------------------------------------------- +// // +template +int segmentMinFunctorBP_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + // int numOfClasses = gradOut->sizeAt(0); + // if input is a vector: (as if in doc sample) + auto stream = context->getCudaStream(); + NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), + DataTypeUtils::fromT(), + context); //->shapeInfo(), context); + segmentMinFunctor_(context, input, indices, &tempRes); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + + segmentMinBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradIn = sd::ConstantTadHelper::getInstance()->tadForDimensions( + tempRes.shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions( + gradOut->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + auto gradInTads = packGradIn.specialShapeInfo(); + auto gradInTadOffsets = packGradIn.specialOffsets(); + auto gradOutTads = packGradOut.specialShapeInfo(); + auto gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentMinBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, + gradOutTadOffsets, outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); + return Status::OK(); +} +// -------------------------------------------------------------------------------------------------------------- +// // segmen min +int segmentMinFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), return segmentMinFunctorBP_, + (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} - segmentMinBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); - return Status::OK(); - } - // -------------------------------------------------------------------------------------------------------------- // - int unsortedSegmentMinFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - } +template +static int unsortedSegmentMinFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, Nd4jLong numOfClasses, + NDArray* output) { + // int numOfClasses = gradOut->sizeAt(0); + // if input is a vector: (as if in doc sample) + auto stream = context->getCudaStream(); + NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), + DataTypeUtils::fromT(), + context); //->shapeInfo(), context); + unsortedSegmentMinFunctor_(context, input, indices, numOfClasses, + &tempRes); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentMinBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradIn = sd::ConstantTadHelper::getInstance()->tadForDimensions( + tempRes.shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions( + gradOut->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + auto gradInTads = packGradIn.specialShapeInfo(); + auto gradInTadOffsets = packGradIn.specialOffsets(); + auto gradOutTads = packGradOut.specialShapeInfo(); + auto gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentMinBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, + gradOutTadOffsets, outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); + return Status::OK(); } +// -------------------------------------------------------------------------------------------------------------- +// // +int unsortedSegmentMinFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), + return unsortedSegmentMinFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, + INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu index 026ded3e7525..e59980e4635c 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu @@ -18,366 +18,462 @@ // @author GS // -#include -#include #include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - // -------------------------------------------------------------------------------------------------------------- // - // Segment Prod ops linear kernels - // -------------------------------------------------------------------------------------------------------------- // - - template - static __global__ void segmentProdLinearKernel(void* input, Nd4jLong const* inputShape, int* starts, int* lengths, - Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { - - __shared__ Nd4jLong xLen, zLen; - __shared__ T* x; - __shared__ T* z; - - if (threadIdx.x == 0) { - x = reinterpret_cast(input); - z = reinterpret_cast(output); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - } - __syncthreads(); - - for(auto segment = blockIdx.x; segment < numOfClasses; segment += gridDim.x) { - auto zIndex = shape::getIndexOffset(segment, outputShape); - auto start = starts[segment]; - auto finish = start + lengths[segment]; - if (lengths[segment] == 0) { - continue; - } - for (auto e = start + threadIdx.x; e < finish; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - sd::math::atomics::nd4j_atomicMul(&z[segment], x[xIndex]); - } - } - - } - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void unsortedSegmentProdLinearKernel(T* input, Nd4jLong const* inputShape, I* indices, Nd4jLong const* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, T* output, Nd4jLong const* outputShape) { - __shared__ Nd4jLong xLen, zLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - } - __syncthreads(); - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - for (auto idx = start; idx < xLen; idx += step) { - auto xIndex = shape::getIndexOffset(idx, inputShape); - auto yIndex = shape::getIndexOffset(idx, indicesShape); - auto segment = indices[yIndex]; - auto zIndex = shape::getIndexOffset(segment, outputShape); - if (lengths[segment] == 0) { - continue; - } - sd::math::atomics::nd4j_atomicMul(&output[zIndex], input[xIndex]); - } +// -------------------------------------------------------------------------------------------------------------- +// // Segment Prod ops linear kernels +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static __global__ void segmentProdLinearKernel( + void* input, Nd4jLong const* inputShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, Nd4jLong const* outputShape) { + __shared__ Nd4jLong xLen, zLen; + __shared__ T* x; + __shared__ T* z; + + if (threadIdx.x == 0) { + x = reinterpret_cast(input); + z = reinterpret_cast(output); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + } + __syncthreads(); + + for (auto segment = blockIdx.x; segment < numOfClasses; + segment += gridDim.x) { + auto zIndex = shape::getIndexOffset(segment, outputShape); + auto start = starts[segment]; + auto finish = start + lengths[segment]; + if (lengths[segment] == 0) { + continue; } - // -------------------------------------------------------------------------------------------------------------- // - // SegmentProd kernel - template - static __global__ void segmentProdTadKernel(void* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, - Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, - Nd4jLong const* outputShape, Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets) { - - __shared__ Nd4jLong len, total; - - if (threadIdx.x == 0) { - total = shape::sizeAt(inputShape, 0); - len = shape::length(inputTads); - } - __syncthreads(); - - for (auto idx = blockIdx.x; idx < total; idx += gridDim.x) { - auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; - auto segment = indices[idx]; // / threadsPerSegment; - auto z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; - auto start = starts[segment]; - auto finish = start + lengths[segment]; - if (lengths[segment] == 0) continue; - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - sd::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]); - } - } + for (auto e = start + threadIdx.x; e < finish; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + sd::math::atomics::nd4j_atomicMul(&z[segment], x[xIndex]); } - // -------------------------------------------------------------------------------------------------------------- // - - template - static void segmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { - auto stream = context->getCudaStream(); - Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); - output->assign(1); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - - dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); - fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - - if (input->isVector()) { - segmentProdLinearKernel<<<128, 256, 128, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - segmentProdTadKernel<<<128, 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void unsortedSegmentProdLinearKernel( + T* input, Nd4jLong const* inputShape, I* indices, + Nd4jLong const* indicesShape, int* starts, int* lengths, + Nd4jLong numOfClasses, T* output, Nd4jLong const* outputShape) { + __shared__ Nd4jLong xLen, zLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + } + __syncthreads(); + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + for (auto idx = start; idx < xLen; idx += step) { + auto xIndex = shape::getIndexOffset(idx, inputShape); + auto yIndex = shape::getIndexOffset(idx, indicesShape); + auto segment = indices[yIndex]; + auto zIndex = shape::getIndexOffset(segment, outputShape); + if (lengths[segment] == 0) { + continue; } - // -------------------------------------------------------------------------------------------------------------- // - void segmentProdFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentProdFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); + sd::math::atomics::nd4j_atomicMul(&output[zIndex], input[xIndex]); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // SegmentProd kernel +template +static __global__ void segmentProdTadKernel( + void* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, + Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, + Nd4jLong numOfClasses, void* outputBuf, Nd4jLong const* outputShape, + Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets) { + __shared__ Nd4jLong len, total; + + if (threadIdx.x == 0) { + total = shape::sizeAt(inputShape, 0); + len = shape::length(inputTads); + } + __syncthreads(); + + for (auto idx = blockIdx.x; idx < total; idx += gridDim.x) { + auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; + auto segment = indices[idx]; // / threadsPerSegment; + auto z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; + auto start = starts[segment]; + auto finish = start + lengths[segment]; + if (lengths[segment] == 0) continue; + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + sd::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]); } + } +} +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static void segmentProdFunctor_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + auto stream = context->getCudaStream(); + Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; + NDArray classesRangesLens = + NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numClasses}, context); + output->assign(1); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + + dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); + fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + + if (input->isVector()) { + segmentProdLinearKernel<<<128, 256, 128, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), begins, lengths, + numClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + segmentProdTadKernel<<<128, 512, 2048, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +void segmentProdFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), + segmentProdFunctor_, (context, input, indices, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - // -------------------------------------------------------------------------------------------------------------- // - template - static void unsortedSegmentProdFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); -// NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); -// NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); -// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); -// int* classesBuf = reinterpret_cast(classes.specialBuffer()); - fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - output->assign(1); - - if (input->isVector()) { - unsortedSegmentProdLinearKernel<<<128, 256, 256, *stream>>>( - input->dataBuffer()->specialAsT(), input->specialShapeInfo(), - indices->dataBuffer()->specialAsT(), indices->specialShapeInfo(), begins, lengths, numOfClasses, - output->dataBuffer()->specialAsT(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - dims.x = input->sizeAt(0); - segmentProdTadKernel<<<128, 256, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - - } - // -------------------------------------------------------------------------------------------------------------- // - void unsortedSegmentProdFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentProdFunctor_, (context, input, indices, numOfClasses, output), - NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } +// -------------------------------------------------------------------------------------------------------------- +// // +template +static void unsortedSegmentProdFunctor_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, + NDArray* output) { + auto stream = context->getCudaStream(); + // NDArray classes = NDArrayFactory::create('c', {numOfClasses, + // 2}); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = + NDArrayFactory::create('c', {numOfClasses}, context); + // NDArray row = NDArrayFactory::create('c', {1, 2}, + // {(int)indices->lengthOf(), (int)0}); + // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, + // &classes); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); + // int* classesBuf = reinterpret_cast(classes.specialBuffer()); + fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + output->assign(1); + + if (input->isVector()) { + unsortedSegmentProdLinearKernel<<<128, 256, 256, *stream>>>( + input->dataBuffer()->specialAsT(), input->specialShapeInfo(), + indices->dataBuffer()->specialAsT(), indices->specialShapeInfo(), + begins, lengths, numOfClasses, output->dataBuffer()->specialAsT(), + output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + dims.x = input->sizeAt(0); + segmentProdTadKernel<<<128, 256, 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numOfClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +void unsortedSegmentProdFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + unsortedSegmentProdFunctor_, + (context, input, indices, numOfClasses, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentProdBPLinearKernel(void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, - Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, - void* outputBuf, Nd4jLong const* outputShape) { - __shared__ T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, gradLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - gradIn = reinterpret_cast(forwardOutput); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; - - for (auto e = start; e < xLen; e += step) { - - auto zOffset = shape::getIndexOffset(e, outputShape); - auto xOffset = shape::getIndexOffset(e, inputShape); - auto yOffset = shape::getIndexOffset(e, indicesShape); - auto classIndex = y[yOffset]; - auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape); - auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); - - z[zOffset] = gradOut[gradOffsetO] * gradIn[gradOffsetI] / x[xOffset]; - } - } - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentProdBPTadKernel(void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, - Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, - void* outputBuf, Nd4jLong const* outputShape, Nd4jLong const* inputTad, - Nd4jLong const* inputOffsets, Nd4jLong const* gradInTad, Nd4jLong const* gradInOffsets, - Nd4jLong const* gradOutTad, Nd4jLong const* gradOutOffsets, Nd4jLong const* outTad, - Nd4jLong const* outOffsets) { - __shared__ T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - yLen = shape::length(indicesShape); - gradOut = reinterpret_cast(eps); - gradIn = reinterpret_cast(forwardOutput); - gradLen = shape::length(epsShape); - currentLen = shape::length(outTad); - } - __syncthreads(); - - for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { - auto yIndex = shape::getIndexOffset(i, indicesShape); - auto segment = y[yIndex]; - T* current = x + inputOffsets[i]; - T* currentOut = z + outOffsets[i]; - T* in = gradIn + gradInOffsets[segment]; - T* outGrad = gradOut + gradOutOffsets[segment]; - - for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { - currentOut[e] = outGrad[e] * in[e] / current[e]; - } - } +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentProdBPLinearKernel( + void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, + Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, + void* indicesBuf, Nd4jLong const* indicesShape, void* outputBuf, + Nd4jLong const* outputShape) { + __shared__ T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, gradLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + gradIn = reinterpret_cast(forwardOutput); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; + + for (auto e = start; e < xLen; e += step) { + auto zOffset = shape::getIndexOffset(e, outputShape); + auto xOffset = shape::getIndexOffset(e, inputShape); + auto yOffset = shape::getIndexOffset(e, indicesShape); + auto classIndex = y[yOffset]; + auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape); + auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); + + z[zOffset] = gradOut[gradOffsetO] * gradIn[gradOffsetI] / x[xOffset]; + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentProdBPTadKernel( + void* inputBuf, Nd4jLong const* inputShape, void* forwardOutput, + Nd4jLong const* forwardShape, void* eps, Nd4jLong const* epsShape, + void* indicesBuf, Nd4jLong const* indicesShape, void* outputBuf, + Nd4jLong const* outputShape, Nd4jLong const* inputTad, + Nd4jLong const* inputOffsets, Nd4jLong const* gradInTad, + Nd4jLong const* gradInOffsets, Nd4jLong const* gradOutTad, + Nd4jLong const* gradOutOffsets, Nd4jLong const* outTad, + Nd4jLong const* outOffsets) { + __shared__ T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + yLen = shape::length(indicesShape); + gradOut = reinterpret_cast(eps); + gradIn = reinterpret_cast(forwardOutput); + gradLen = shape::length(epsShape); + currentLen = shape::length(outTad); + } + __syncthreads(); + + for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { + auto yIndex = shape::getIndexOffset(i, indicesShape); + auto segment = y[yIndex]; + T* current = x + inputOffsets[i]; + T* currentOut = z + outOffsets[i]; + T* in = gradIn + gradInOffsets[segment]; + T* outGrad = gradOut + gradOutOffsets[segment]; + + for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { + currentOut[e] = outGrad[e] * in[e] / current[e]; } + } +} - // -------------------------------------------------------------------------------------------------------------- // - template - int segmentProdFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT(), context);//->shapeInfo(), context); - segmentProdFunctor_(context, input, indices, &tempRes); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - if (input->isVector()) { - Nd4jLong loopSize = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentProdBPLinearKernel<<lengthOf(), loopSize, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto packGradIn = sd::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - auto gradInTads = packGradIn.specialShapeInfo(); - auto gradInTadOffsets = packGradIn.specialOffsets(); - auto gradOutTads = packGradOut.specialShapeInfo(); - auto gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentProdBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - return Status::OK(); - } +// -------------------------------------------------------------------------------------------------------------- +// // +template +int segmentProdFunctorBP_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + auto stream = context->getCudaStream(); + NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), + DataTypeUtils::fromT(), + context); //->shapeInfo(), context); + segmentProdFunctor_(context, input, indices, &tempRes); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + if (input->isVector()) { + Nd4jLong loopSize = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentProdBPLinearKernel + <<lengthOf(), loopSize, 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradIn = sd::ConstantTadHelper::getInstance()->tadForDimensions( + tempRes.shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions( + gradOut->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + auto gradInTads = packGradIn.specialShapeInfo(); + auto gradInTadOffsets = packGradIn.specialOffsets(); + auto gradOutTads = packGradOut.specialShapeInfo(); + auto gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentProdBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, + gradOutTadOffsets, outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); + return Status::OK(); +} - // -------------------------------------------------------------------------------------------------------------- // +// -------------------------------------------------------------------------------------------------------------- +// // - int segmentProdFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentProdFunctorBP_, (context, input, - indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - } +int segmentProdFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), return segmentProdFunctorBP_, + (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} - // -------------------------------------------------------------------------------------------------------------- // - - template - static int unsortedSegmentProdFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); - - NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT(), context);//->shapeInfo(), context); - unsortedSegmentProdFunctor_(context, input, indices, numOfClasses, &tempRes); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - if (input->isVector()) { - Nd4jLong loopSize = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentProdBPLinearKernel<<lengthOf(), loopSize, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto packGradIn = sd::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - auto gradInTads = packGradIn.specialShapeInfo(); - auto gradInTadOffsets = packGradIn.specialOffsets(); - auto gradOutTads = packGradOut.specialShapeInfo(); - auto gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentProdBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - return Status::OK(); - } +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static int unsortedSegmentProdFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, + Nd4jLong numOfClasses, + NDArray* output) { + auto stream = context->getCudaStream(); + + NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), + DataTypeUtils::fromT(), + context); //->shapeInfo(), context); + unsortedSegmentProdFunctor_(context, input, indices, numOfClasses, + &tempRes); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + if (input->isVector()) { + Nd4jLong loopSize = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentProdBPLinearKernel + <<lengthOf(), loopSize, 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradIn = sd::ConstantTadHelper::getInstance()->tadForDimensions( + tempRes.shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions( + gradOut->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + auto gradInTads = packGradIn.specialShapeInfo(); + auto gradInTadOffsets = packGradIn.specialOffsets(); + auto gradOutTads = packGradOut.specialShapeInfo(); + auto gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentProdBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + tempRes.specialBuffer(), tempRes.specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, + gradOutTadOffsets, outputTads, outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); + return Status::OK(); +} - // -------------------------------------------------------------------------------------------------------------- // - int unsortedSegmentProdFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentProdFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - } +// -------------------------------------------------------------------------------------------------------------- +// // +int unsortedSegmentProdFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), + return unsortedSegmentProdFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, + INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} - // -------------------------------------------------------------------------------------------------------------- // +// -------------------------------------------------------------------------------------------------------------- +// // -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu index b72abeffc14a..e40388072250 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu @@ -18,240 +18,307 @@ // @author GS // -#include -#include #include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void unsortedSegmentSqrtNLinearKernel(T* input, Nd4jLong const* inputShape, I* indices, Nd4jLong const* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, T* output, Nd4jLong const* outputShape) { - __shared__ Nd4jLong xLen, zLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - } - __syncthreads(); +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void unsortedSegmentSqrtNLinearKernel( + T* input, Nd4jLong const* inputShape, I* indices, + Nd4jLong const* indicesShape, int* starts, int* lengths, + Nd4jLong numOfClasses, T* output, Nd4jLong const* outputShape) { + __shared__ Nd4jLong xLen, zLen; - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + } + __syncthreads(); - for (auto idx = start; idx < xLen; idx += step) { - auto yIndex = shape::getIndexOffset(idx, indicesShape); - auto segment = indices[yIndex]; - auto zIndex = shape::getIndexOffset(segment, outputShape); - if (lengths[segment] == 0) continue; - auto xIndex = shape::getIndexOffset(idx, inputShape); + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; - sd::math::atomics::nd4j_atomicAdd(&output[zIndex], input[xIndex] / sd::math::nd4j_sqrt(lengths[segment])); - } - } - // -------------------------------------------------------------------------------------------------------------- // - // SegmentSqrtN kernel - template - static __global__ void segmentSqrtNTadKernel(T* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong const* outputShape, Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets) { + for (auto idx = start; idx < xLen; idx += step) { + auto yIndex = shape::getIndexOffset(idx, indicesShape); + auto segment = indices[yIndex]; + auto zIndex = shape::getIndexOffset(segment, outputShape); + if (lengths[segment] == 0) continue; + auto xIndex = shape::getIndexOffset(idx, inputShape); - __shared__ Nd4jLong len, total; + sd::math::atomics::nd4j_atomicAdd( + &output[zIndex], + input[xIndex] / sd::math::nd4j_sqrt(lengths[segment])); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // SegmentSqrtN kernel +template +static __global__ void segmentSqrtNTadKernel( + T* inputBuf, Nd4jLong const* inputShape, Nd4jLong const* inputTads, + Nd4jLong const* inputTadOffsets, I* indices, int* starts, int* lengths, + Nd4jLong numOfClasses, void* outputBuf, Nd4jLong const* outputShape, + Nd4jLong const* outputTads, Nd4jLong const* outputTadOffsets) { + __shared__ Nd4jLong len, total; - if (threadIdx.x == 0) { - total = shape::sizeAt(inputShape, 0); - len = shape::length(inputTads); - } - __syncthreads(); + if (threadIdx.x == 0) { + total = shape::sizeAt(inputShape, 0); + len = shape::length(inputTads); + } + __syncthreads(); - for (auto idx = blockIdx.x; idx < total; idx += gridDim.x) { - auto segment = indices[idx]; - auto x = inputBuf + inputTadOffsets[idx]; - auto z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; - auto start = starts[segment]; - auto finish = start + lengths[segment]; + for (auto idx = blockIdx.x; idx < total; idx += gridDim.x) { + auto segment = indices[idx]; + auto x = inputBuf + inputTadOffsets[idx]; + auto z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; + auto start = starts[segment]; + auto finish = start + lengths[segment]; - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex] / sd::math::nd4j_sqrt(lengths[segment])); - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - template - static void unsortedSegmentSqrtNFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); -// NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); -// NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); -// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); -// dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); - dim3 dims(128, 256, 256); -// int* classesBuf = reinterpret_cast(classes.specialBuffer()); - fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - output->nullify(); - if (input->isVector()) { - unsortedSegmentSqrtNLinearKernel<<>>( - input->dataBuffer()->specialAsT(), input->specialShapeInfo(), - indices->dataBuffer()->specialAsT(), indices->specialShapeInfo(), begins, lengths, numOfClasses, - output->dataBuffer()->specialAsT(), output->specialShapeInfo()); - } - else { - output->nullify(); - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - dims.x = input->sizeAt(0); - segmentSqrtNTadKernel<<>>( - input->dataBuffer()->specialAsT(), input->specialShapeInfo(), inputTads, inputTadOffsets, indices->dataBuffer()->specialAsT(), - begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - } - // -------------------------------------------------------------------------------------------------------------- // - void unsortedSegmentSqrtNFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSqrtNFunctor_, (context, input, indices, numOfClasses, output), - FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + sd::math::atomics::nd4j_atomicAdd( + &z[zIndex], + x[xIndex] / sd::math::nd4j_sqrt(lengths[segment])); } - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentSqrtNBPLinearKernel(void* inputBuf, Nd4jLong const* inputShape, void* eps, Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, - int* lengths, void* outputBuf, Nd4jLong const* outputShape) { - __shared__ T* x; - __shared__ T* gradIn; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, gradLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - } - __syncthreads(); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static void unsortedSegmentSqrtNFunctor_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, + NDArray* output) { + auto stream = context->getCudaStream(); + // NDArray classes = NDArrayFactory::create('c', {numOfClasses, + // 2}); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = + NDArrayFactory::create('c', {numOfClasses}, context); + // NDArray row = NDArrayFactory::create('c', {1, 2}, + // {(int)indices->lengthOf(), (int)0}); + // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, + // &classes); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + // dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + + // 32); + dim3 dims(128, 256, 256); + // int* classesBuf = reinterpret_cast(classes.specialBuffer()); + fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + output->nullify(); + if (input->isVector()) { + unsortedSegmentSqrtNLinearKernel<<>>( + input->dataBuffer()->specialAsT(), input->specialShapeInfo(), + indices->dataBuffer()->specialAsT(), indices->specialShapeInfo(), + begins, lengths, numOfClasses, output->dataBuffer()->specialAsT(), + output->specialShapeInfo()); + } else { + output->nullify(); + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + dims.x = input->sizeAt(0); + segmentSqrtNTadKernel<<>>( + input->dataBuffer()->specialAsT(), input->specialShapeInfo(), + inputTads, inputTadOffsets, indices->dataBuffer()->specialAsT(), + begins, lengths, numOfClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +void unsortedSegmentSqrtNFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + unsortedSegmentSqrtNFunctor_, + (context, input, indices, numOfClasses, output), + FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentSqrtNBPLinearKernel( + void* inputBuf, Nd4jLong const* inputShape, void* eps, + Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, + int* lengths, void* outputBuf, Nd4jLong const* outputShape) { + __shared__ T* x; + __shared__ T* gradIn; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, gradLen; - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + } + __syncthreads(); - for (auto e = start; e < xLen; e += step) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; - auto zOffset = shape::getIndexOffset(e, outputShape); - auto xOffset = shape::getIndexOffset(e, inputShape); - auto yOffset = shape::getIndexOffset(e, indicesShape); - auto classIndex = y[yOffset]; - auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); + for (auto e = start; e < xLen; e += step) { + auto zOffset = shape::getIndexOffset(e, outputShape); + auto xOffset = shape::getIndexOffset(e, inputShape); + auto yOffset = shape::getIndexOffset(e, indicesShape); + auto classIndex = y[yOffset]; + auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); - z[zOffset] = T(gradOut[gradOffsetO] / math::nd4j_sqrt(lengths[classIndex])); - } - } - // -------------------------------------------------------------------------------------------------------------- // + z[zOffset] = T(gradOut[gradOffsetO] / + math::nd4j_sqrt(lengths[classIndex])); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // - template - static __global__ void segmentSqrtNBPTadKernel(void* inputBuf, Nd4jLong const* inputShape, void* eps, Nd4jLong const* epsShape, - void* indicesBuf, Nd4jLong const* indicesShape, int* lengths, void* outputBuf, Nd4jLong const* outputShape,Nd4jLong const* inputTad, - Nd4jLong const* inputOffsets, Nd4jLong const* gradOutTad, Nd4jLong const* gradOutOffsets, Nd4jLong const* outTad, Nd4jLong const* outOffsets) { - __shared__ T* x; - __shared__ T* gradOut; - __shared__ I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; +template +static __global__ void segmentSqrtNBPTadKernel( + void* inputBuf, Nd4jLong const* inputShape, void* eps, + Nd4jLong const* epsShape, void* indicesBuf, Nd4jLong const* indicesShape, + int* lengths, void* outputBuf, Nd4jLong const* outputShape, + Nd4jLong const* inputTad, Nd4jLong const* inputOffsets, + Nd4jLong const* gradOutTad, Nd4jLong const* gradOutOffsets, + Nd4jLong const* outTad, Nd4jLong const* outOffsets) { + __shared__ T* x; + __shared__ T* gradOut; + __shared__ I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - yLen = shape::length(indicesShape); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - currentLen = shape::length(outTad); - } - __syncthreads(); + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + yLen = shape::length(indicesShape); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + currentLen = shape::length(outTad); + } + __syncthreads(); - for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { -// auto yIndex = shape::getIndexOffset(i, indicesShape); - auto segment = y[i]; //yIndex]; - T* currentOut = z + outOffsets[i]; - T* outGrad = gradOut + gradOutOffsets[segment]; + for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { + // auto yIndex = shape::getIndexOffset(i, indicesShape); + auto segment = y[i]; // yIndex]; + T* currentOut = z + outOffsets[i]; + T* outGrad = gradOut + gradOutOffsets[segment]; - for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { - auto zIndex = shape::getIndexOffset(e, outTad); - auto gradIndex = shape::getIndexOffset(e, gradOutTad); - if (lengths[segment] > 0) - currentOut[zIndex] = T(outGrad[gradIndex] / math::nd4j_sqrt(lengths[segment])); - } - } + for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { + auto zIndex = shape::getIndexOffset(e, outTad); + auto gradIndex = shape::getIndexOffset(e, gradOutTad); + if (lengths[segment] > 0) + currentOut[zIndex] = T(outGrad[gradIndex] / + math::nd4j_sqrt(lengths[segment])); } - // -------------------------------------------------------------------------------------------------------------- // + } +} +// -------------------------------------------------------------------------------------------------------------- +// // - template - static int unsortedSegmentSqrtNFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - auto numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); +template +static int unsortedSegmentSqrtNFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, + Nd4jLong numOfClasses, + NDArray* output) { + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + auto numClasses = indices->e(indices->lengthOf() - 1) + 1; + NDArray classesRangesLens = + NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numClasses}, context); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); - fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); + fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentSqrtNBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), - input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); -// auto packGradIn = sd::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - auto gradOutTads = packGradOut.specialShapeInfo(); - auto gradOutTadOffsets = packGradOut.specialOffsets(); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentSqrtNBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), lengths, + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + // auto packGradIn = + // sd::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.shapeInfo(), + // dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions( + gradOut->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + auto gradOutTads = packGradOut.specialShapeInfo(); + auto gradOutTadOffsets = packGradOut.specialOffsets(); - segmentSqrtNBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths, - output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); + segmentSqrtNBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), lengths, + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradOutTads, gradOutTadOffsets, outputTads, + outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - return Status::OK(); - } - // -------------------------------------------------------------------------------------------------------------- // - int unsortedSegmentSqrtNFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSqrtNFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - } + return Status::OK(); } +// -------------------------------------------------------------------------------------------------------------- +// // +int unsortedSegmentSqrtNFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), + return unsortedSegmentSqrtNFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, + INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu index 7a762a52602a..c1030d2993b6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu @@ -18,393 +18,448 @@ // @author GS // -#include -#include #include -#include -#include #include -#include #include +#include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - // -------------------------------------------------------------------------------------------------------------- // - // Segment ops linear kernels - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void - segmentSumLinearKernel( - const void *input, const Nd4jLong *inputShape, - int *starts, int *lengths, Nd4jLong numOfClasses, - void *output, const Nd4jLong *outputShape) { - __shared__ - T *val; - __shared__ - Nd4jLong xLen, zLen, segment, zIndex; - __shared__ - const T *x; - __shared__ - T *z; - __shared__ int threadsPerSegment, start, finish; - - if (threadIdx.x == 0) { - threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; - segment = blockIdx.x / threadsPerSegment; - x = reinterpret_cast(input); - z = reinterpret_cast(output); - - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - - - if (segment < numOfClasses) { - zIndex = shape::getIndexOffset(segment, outputShape); - start = starts[segment]; - finish = start + lengths[segment]; - //val[segment] = ; - z[zIndex] = x[shape::getIndexOffset(start, inputShape)]; - } - - } - __syncthreads(); - - for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); - } - } - // -------------------------------------------------------------------------------------------------------------- // - - template - static __global__ void - unsortedSegmentSumLinearKernel( - const void *input, const Nd4jLong *inputShape, - const void *indices, const Nd4jLong *indicesShape, - int *starts, int *lengths, Nd4jLong numOfClasses, - void *output, const Nd4jLong *outputShape) { - __shared__ - T *val; - __shared__ - Nd4jLong xLen, zLen, segment, zIndex; - __shared__ - const T *x; - __shared__ - T *z; - __shared__ - const I *y; //int threadsPerSegment, start, finish; - - if (threadIdx.x == 0) { - segment = blockIdx.x; - x = reinterpret_cast(input); - z = reinterpret_cast(output); - y = reinterpret_cast(indices); - xLen = shape::length(inputShape); - zLen = shape::length(outputShape); - - zIndex = shape::getIndexOffset(segment, outputShape); - if (lengths[segment] > 0) - z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)]; - else - z[zIndex] = 0; //DataTypeUtils::max(); - } - __syncthreads(); - - if (lengths[segment] > 0) - for (auto e = threadIdx.x; e < xLen; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputShape); - auto yIndex = shape::getIndexOffset(e, indicesShape); - if (y[yIndex] == segment && e != starts[segment]) { - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - // SegmentSum kernel - template - static __global__ void segmentSumTadKernel( - const void* inputBuf, const Nd4jLong* inputShape, const Nd4jLong* inputTads, const Nd4jLong* inputTadOffsets, - const I* indices, - int* starts, int* lengths, Nd4jLong numOfClasses, - void* outputBuf, const Nd4jLong* outputShape, const Nd4jLong* outputTads, const Nd4jLong* outputTadOffsets) { - __shared__ T* val; - __shared__ Nd4jLong len, zIndex, total; - __shared__ T* z; - __shared__ int start, finish; - - if (threadIdx.x == 0) { - auto segment = indices[blockIdx.x]; // / threadsPerSegment; - z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; - len = shape::length(inputTads); - start = starts[segment]; - finish = start + lengths[segment]; - total = shape::sizeAt(inputShape, 0); - - } - __syncthreads(); - - auto idx = blockIdx.x; - if (blockIdx.x <= total) { - auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; - if (blockIdx.x == start) { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); - } - } - else { - for (auto e = threadIdx.x; e < len; e += blockDim.x) { - auto xIndex = shape::getIndexOffset(e, inputTads); - auto zIndex = shape::getIndexOffset(e, outputTads); - if (lengths[indices[idx]]) - sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); - } - } - } - } - // -------------------------------------------------------------------------------------------------------------- // - - template - static void segmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { - auto stream = context->getCudaStream(); - Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; - NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}, context); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}, context); - - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - - dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); - fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - - if (input->isVector()) { - segmentSumLinearKernel<<lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - segmentSumTadKernel<<sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - - } - // -------------------------------------------------------------------------------------------------------------- // - void segmentSumFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - output->nullify(); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - } - - // -------------------------------------------------------------------------------------------------------------- // - template - static void unsortedSegmentSumFunctor_(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); -// NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); - NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}, context); - NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}, context); -// NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); -// classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, &classes); - classesRangesBegs.assign(indices->lengthOf()); - classesRangesLens.assign(0); - dim3 dims(numOfClasses, indices->lengthOf(), (numOfClasses + 1) * 64); -// int* classesBuf = reinterpret_cast(classes.specialBuffer()); - fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); - int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); - int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); - - if (input->isVector()) { - unsortedSegmentSumLinearKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); - } - else { - output->assign(0); - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - dims.x = input->sizeAt(0); - segmentSumTadKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); - } - +// -------------------------------------------------------------------------------------------------------------- +// // Segment ops linear kernels +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentSumLinearKernel( + const void* input, const Nd4jLong* inputShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, const Nd4jLong* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, segment, zIndex; + __shared__ const T* x; + __shared__ T* z; + __shared__ int threadsPerSegment, start, finish; + + if (threadIdx.x == 0) { + threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; + segment = blockIdx.x / threadsPerSegment; + x = reinterpret_cast(input); + z = reinterpret_cast(output); + + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + if (segment < numOfClasses) { + zIndex = shape::getIndexOffset(segment, outputShape); + start = starts[segment]; + finish = start + lengths[segment]; + // val[segment] = ; + z[zIndex] = x[shape::getIndexOffset(start, inputShape)]; } - // -------------------------------------------------------------------------------------------------------------- // - void unsortedSegmentSumFunctor(sd::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices}); - output->nullify(); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output), - NUMERIC_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices}); - - } - - // -------------------------------------------------------------------------------------------------------------- // - // Backpropagate ops - // -------------------------------------------------------------------------------------------------------------- // - // Sorted sum backpropagate - template - static __global__ void segmentSumBPLinearKernel( - const void* inputBuf, const Nd4jLong* inputShape, - const void* eps, const Nd4jLong* epsShape, - const void* indicesBuf, const Nd4jLong* indicesShape, - void* outputBuf, const Nd4jLong* outputShape) { - auto x = reinterpret_cast(inputBuf); - auto y = reinterpret_cast(indicesBuf); - auto z = reinterpret_cast(outputBuf); - auto gradOut = reinterpret_cast(eps); - __shared__ Nd4jLong xLen, gradLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - gradLen = shape::length(epsShape); - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = gridDim.x * blockDim.x; - - for (auto e = start; e < xLen; e += step) { - - auto zOffset = shape::getIndexOffset(e, outputShape); - auto xOffset = shape::getIndexOffset(e, inputShape); - auto yOffset = shape::getIndexOffset(e, indicesShape); - auto classIndex = y[yOffset]; - auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); - - z[zOffset] = gradOut[gradOffsetO]; - } - } - // -------------------------------------------------------------------------------------------------------------- // - template - static __global__ void segmentSumBPTadKernel( - const void* inputBuf, const Nd4jLong* inputShape, - const void* eps, const Nd4jLong* epsShape, - const void* indicesBuf, const Nd4jLong* indicesShape, - void* outputBuf, const Nd4jLong* outputShape, - const Nd4jLong* inputTad, const Nd4jLong* inputOffsets, - const Nd4jLong* gradOutTad, const Nd4jLong* gradOutOffsets, - const Nd4jLong* outTad, const Nd4jLong* outOffsets) { - __shared__ const T* x; - __shared__ const T* gradOut; - __shared__ const I* y; - __shared__ T* z; - __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; - - if (threadIdx.x == 0) { - xLen = shape::length(inputShape); - x = reinterpret_cast(inputBuf); - y = reinterpret_cast(indicesBuf); - z = reinterpret_cast(outputBuf); - yLen = shape::length(indicesShape); - gradOut = reinterpret_cast(eps); - gradLen = shape::length(epsShape); - currentLen = shape::length(outTad); - } - __syncthreads(); - - for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { - auto yIndex = shape::getIndexOffset(i, indicesShape); - auto segment = y[yIndex]; - auto currentOut = z + outOffsets[i]; - auto outGrad = gradOut + gradOutOffsets[segment]; - - for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { - currentOut[e] = outGrad[e]; - } - } + } + __syncthreads(); + for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static __global__ void unsortedSegmentSumLinearKernel( + const void* input, const Nd4jLong* inputShape, const void* indices, + const Nd4jLong* indicesShape, int* starts, int* lengths, + Nd4jLong numOfClasses, void* output, const Nd4jLong* outputShape) { + __shared__ T* val; + __shared__ Nd4jLong xLen, zLen, segment, zIndex; + __shared__ const T* x; + __shared__ T* z; + __shared__ const I* y; // int threadsPerSegment, start, finish; + + if (threadIdx.x == 0) { + segment = blockIdx.x; + x = reinterpret_cast(input); + z = reinterpret_cast(output); + y = reinterpret_cast(indices); + xLen = shape::length(inputShape); + zLen = shape::length(outputShape); + + zIndex = shape::getIndexOffset(segment, outputShape); + if (lengths[segment] > 0) + z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)]; + else + z[zIndex] = 0; // DataTypeUtils::max(); + } + __syncthreads(); + + if (lengths[segment] > 0) + for (auto e = threadIdx.x; e < xLen; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputShape); + auto yIndex = shape::getIndexOffset(e, indicesShape); + if (y[yIndex] == segment && e != starts[segment]) { + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); + } } - // -------------------------------------------------------------------------------------------------------------- // - template - int segmentSumFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentSumBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), - input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - auto gradOutTads = packGradOut.specialShapeInfo(); - auto gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentSumBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - return Status::OK(); +} +// -------------------------------------------------------------------------------------------------------------- +// // SegmentSum kernel +template +static __global__ void segmentSumTadKernel( + const void* inputBuf, const Nd4jLong* inputShape, const Nd4jLong* inputTads, + const Nd4jLong* inputTadOffsets, const I* indices, int* starts, + int* lengths, Nd4jLong numOfClasses, void* outputBuf, + const Nd4jLong* outputShape, const Nd4jLong* outputTads, + const Nd4jLong* outputTadOffsets) { + __shared__ T* val; + __shared__ Nd4jLong len, zIndex, total; + __shared__ T* z; + __shared__ int start, finish; + + if (threadIdx.x == 0) { + auto segment = indices[blockIdx.x]; // / threadsPerSegment; + z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; + len = shape::length(inputTads); + start = starts[segment]; + finish = start + lengths[segment]; + total = shape::sizeAt(inputShape, 0); + } + __syncthreads(); + + auto idx = blockIdx.x; + if (blockIdx.x <= total) { + auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; + if (blockIdx.x == start) { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); + } + } else { + for (auto e = threadIdx.x; e < len; e += blockDim.x) { + auto xIndex = shape::getIndexOffset(e, inputTads); + auto zIndex = shape::getIndexOffset(e, outputTads); + if (lengths[indices[idx]]) + sd::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); + } } - // -------------------------------------------------------------------------------------------------------------- // + } +} +// -------------------------------------------------------------------------------------------------------------- +// // + +template +static void segmentSumFunctor_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + auto stream = context->getCudaStream(); + Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; + NDArray classesRangesLens = + NDArrayFactory::create('c', {numClasses}, context); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numClasses}, context); + + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + + dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); + fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + + if (input->isVector()) { + segmentSumLinearKernel + <<lengthOf(), numClasses * 32 + 32, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), begins, lengths, + numClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + segmentSumTadKernel<<sizeAt(0), 512, 2048, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +void segmentSumFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + output->nullify(); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + segmentSumFunctor_, (context, input, indices, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - int segmentSumFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentSumFunctorBP_, (context, input, - indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - } +// -------------------------------------------------------------------------------------------------------------- +// // +template +static void unsortedSegmentSumFunctor_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + Nd4jLong numOfClasses, NDArray* output) { + auto stream = context->getCudaStream(); + // NDArray classes = NDArrayFactory::create('c', {numOfClasses, + // 2}); + NDArray classesRangesBegs = + NDArrayFactory::create('c', {numOfClasses}, context); + NDArray classesRangesLens = + NDArrayFactory::create('c', {numOfClasses}, context); + // NDArray row = NDArrayFactory::create('c', {1, 2}, + // {(int)indices->lengthOf(), (int)0}); + // classes.applyTrueBroadcast(sd::BroadcastOpsTuple::Assign(), &row, + // &classes); + classesRangesBegs.assign(indices->lengthOf()); + classesRangesLens.assign(0); + dim3 dims(numOfClasses, indices->lengthOf(), (numOfClasses + 1) * 64); + // int* classesBuf = reinterpret_cast(classes.specialBuffer()); + fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); + int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); + int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); + + if (input->isVector()) { + unsortedSegmentSumLinearKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, + numOfClasses, output->specialBuffer(), output->specialShapeInfo()); + } else { + output->assign(0); + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + dims.x = input->sizeAt(0); + segmentSumTadKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), inputTads, + inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, + lengths, numOfClasses, output->specialBuffer(), + output->specialShapeInfo(), outputTads, outputTadOffsets); + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +void unsortedSegmentSumFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + output->nullify(); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), + unsortedSegmentSumFunctor_, + (context, input, indices, numOfClasses, output), + NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); +} - template - static int unsortedSegmentSumFunctorBP_(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - auto numOfClasses = gradOut->lengthOf(); //indices->e(loop_size - 1); - segmentSumBPLinearKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), - input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); - } - else { - std::vector dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->shapeInfo(), dimensions); - auto inputTads = packX.specialShapeInfo(); - auto inputTadOffsets = packX.specialOffsets(); - auto outputTads = packZ.specialShapeInfo(); - auto outputTadOffsets = packZ.specialOffsets(); - auto gradOutTads = packGradOut.specialShapeInfo(); - auto gradOutTadOffsets = packGradOut.specialOffsets(); - - segmentSumBPTadKernel<<lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), - gradOut->specialBuffer(), gradOut->specialShapeInfo(), - indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), - inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets, - outputTads, outputTadOffsets); - } - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); - return Status::OK(); - } - // -------------------------------------------------------------------------------------------------------------- // - int unsortedSegmentSumFunctorBP(sd::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +// -------------------------------------------------------------------------------------------------------------- +// // Backpropagate ops +// -------------------------------------------------------------------------------------------------------------- +// // Sorted sum backpropagate +template +static __global__ void segmentSumBPLinearKernel( + const void* inputBuf, const Nd4jLong* inputShape, const void* eps, + const Nd4jLong* epsShape, const void* indicesBuf, + const Nd4jLong* indicesShape, void* outputBuf, + const Nd4jLong* outputShape) { + auto x = reinterpret_cast(inputBuf); + auto y = reinterpret_cast(indicesBuf); + auto z = reinterpret_cast(outputBuf); + auto gradOut = reinterpret_cast(eps); + __shared__ Nd4jLong xLen, gradLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + gradLen = shape::length(epsShape); + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = gridDim.x * blockDim.x; + + for (auto e = start; e < xLen; e += step) { + auto zOffset = shape::getIndexOffset(e, outputShape); + auto xOffset = shape::getIndexOffset(e, inputShape); + auto yOffset = shape::getIndexOffset(e, indicesShape); + auto classIndex = y[yOffset]; + auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape); + + z[zOffset] = gradOut[gradOffsetO]; + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +static __global__ void segmentSumBPTadKernel( + const void* inputBuf, const Nd4jLong* inputShape, const void* eps, + const Nd4jLong* epsShape, const void* indicesBuf, + const Nd4jLong* indicesShape, void* outputBuf, const Nd4jLong* outputShape, + const Nd4jLong* inputTad, const Nd4jLong* inputOffsets, + const Nd4jLong* gradOutTad, const Nd4jLong* gradOutOffsets, + const Nd4jLong* outTad, const Nd4jLong* outOffsets) { + __shared__ const T* x; + __shared__ const T* gradOut; + __shared__ const I* y; + __shared__ T* z; + __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; + + if (threadIdx.x == 0) { + xLen = shape::length(inputShape); + x = reinterpret_cast(inputBuf); + y = reinterpret_cast(indicesBuf); + z = reinterpret_cast(outputBuf); + yLen = shape::length(indicesShape); + gradOut = reinterpret_cast(eps); + gradLen = shape::length(epsShape); + currentLen = shape::length(outTad); + } + __syncthreads(); + + for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { + auto yIndex = shape::getIndexOffset(i, indicesShape); + auto segment = y[yIndex]; + auto currentOut = z + outOffsets[i]; + auto outGrad = gradOut + gradOutOffsets[segment]; + + for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { + currentOut[e] = outGrad[e]; } + } +} +// -------------------------------------------------------------------------------------------------------------- +// // +template +int segmentSumFunctorBP_(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentSumBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions( + gradOut->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + auto gradOutTads = packGradOut.specialShapeInfo(); + auto gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentSumBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradOutTads, gradOutTadOffsets, outputTads, + outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); + return Status::OK(); +} +// -------------------------------------------------------------------------------------------------------------- +// // + +int segmentSumFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), return segmentSumFunctorBP_, + (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); +} +template +static int unsortedSegmentSumFunctorBP_(sd::LaunchContext* context, + NDArray* input, NDArray* indices, + NDArray* gradOut, Nd4jLong numOfClasses, + NDArray* output) { + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + if (input->isVector()) { + Nd4jLong loop_size = input->lengthOf(); + auto numOfClasses = + gradOut->lengthOf(); // indices->e(loop_size - 1); + segmentSumBPLinearKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo()); + } else { + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), dimensions); + auto packGradOut = sd::ConstantTadHelper::getInstance()->tadForDimensions( + gradOut->shapeInfo(), dimensions); + auto inputTads = packX.specialShapeInfo(); + auto inputTadOffsets = packX.specialOffsets(); + auto outputTads = packZ.specialShapeInfo(); + auto outputTadOffsets = packZ.specialOffsets(); + auto gradOutTads = packGradOut.specialShapeInfo(); + auto gradOutTadOffsets = packGradOut.specialOffsets(); + + segmentSumBPTadKernel + <<lengthOf(), input->lengthOf(), 256, *stream>>>( + input->specialBuffer(), input->specialShapeInfo(), + gradOut->specialBuffer(), gradOut->specialShapeInfo(), + indices->specialBuffer(), indices->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), inputTads, + inputTadOffsets, gradOutTads, gradOutTadOffsets, outputTads, + outputTadOffsets); + } + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); + return Status::OK(); } +// -------------------------------------------------------------------------------------------------------------- +// // +int unsortedSegmentSumFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); + BUILD_DOUBLE_SELECTOR( + output->dataType(), indices->dataType(), + return unsortedSegmentSumFunctorBP_, + (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, + INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu b/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu index 51b7590c0ad0..bd3a854fc408 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu @@ -24,41 +24,53 @@ namespace sd { namespace ops { namespace helpers { - template - static __global__ void sequenceMaskKernel(const void* inputBuf, const Nd4jLong* inputShape, void* outputBuf, const Nd4jLong* outputShape, int maxIndex) { +template +static __global__ void sequenceMaskKernel(const void* inputBuf, + const Nd4jLong* inputShape, + void* outputBuf, + const Nd4jLong* outputShape, + int maxIndex) { + __shared__ const I* input; + __shared__ B* output; + __shared__ Nd4jLong inputLen, outputLen; + if (threadIdx.x == 0) { + input = reinterpret_cast(inputBuf); + output = reinterpret_cast(outputBuf); + inputLen = shape::length(inputShape); + outputLen = shape::length(outputShape); + } + __syncthreads(); - __shared__ const I* input; - __shared__ B* output; - __shared__ Nd4jLong inputLen, outputLen; - if (threadIdx.x == 0) { - input = reinterpret_cast(inputBuf); - output = reinterpret_cast(outputBuf); - inputLen = shape::length(inputShape); - outputLen = shape::length(outputShape); - } - __syncthreads(); - - for (auto i = blockIdx.x; i < maxIndex; i += gridDim.x) - for(auto k = threadIdx.x; k < inputLen; k += blockDim.x) - if (i < input[shape::getIndexOffset(k, inputShape)]) - output[shape::getIndexOffset(k * maxIndex + i, outputShape)] = B(true); - - } - - template - static void sequenceMask_(LaunchContext* context, NDArray* input, NDArray* output, int maxIndex) { - dim3 launchDims(maxIndex, input->lengthOf(), 128); - NDArray::prepareSpecialUse({output}, {input}); - auto stream = context->getCudaStream(); - sequenceMaskKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), maxIndex); - NDArray::registerSpecialUse({output}, {input}); - } - - void sequenceMask(sd::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); - } + for (auto i = blockIdx.x; i < maxIndex; i += gridDim.x) + for (auto k = threadIdx.x; k < inputLen; k += blockDim.x) + if (i < input[shape::getIndexOffset(k, inputShape)]) + output[shape::getIndexOffset(k * maxIndex + i, outputShape)] = B(true); +} - BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (sd::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); +template +static void sequenceMask_(LaunchContext* context, NDArray* input, + NDArray* output, int maxIndex) { + dim3 launchDims(maxIndex, input->lengthOf(), 128); + NDArray::prepareSpecialUse({output}, {input}); + auto stream = context->getCudaStream(); + sequenceMaskKernel + <<>>( + input->specialBuffer(), input->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), maxIndex); + NDArray::registerSpecialUse({output}, {input}); } + +void sequenceMask(sd::LaunchContext* context, NDArray* input, NDArray* output, + int maxIndex) { + BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, + (context, input, output, maxIndex), INTEGER_TYPES, + LIBND4J_TYPES_EXTENDED); } -} \ No newline at end of file + +BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, + (sd::LaunchContext * context, NDArray* input, + NDArray* output, int maxIndex), + INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu b/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu index 3957f23d5ada..99e6c2b8d1e1 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu @@ -18,745 +18,934 @@ // @author raver119@gmail.com // -#include -#include #include +#include +#include #define HS_MAX_EXP 6.0f namespace sd { - namespace ops { - namespace helpers { - template - __global__ void hSoftmaxKernel(void *vsyn0, void *vsyn1, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference) { - - auto syn0 = reinterpret_cast(vsyn0); - auto syn1 = reinterpret_cast(vsyn1); - auto expTable = reinterpret_cast(vexpTable); - auto neu1e = reinterpret_cast(vneu1e); - - T dot(0.0f); - T g(0.0f); - T f(0.0f); - - // dot - for (int e = 0; e < vectorLength; e++) { - dot += syn0[e] * syn1[e]; - } - - // gradient - if (dot < (T) - HS_MAX_EXP || dot >= (T) HS_MAX_EXP) - return; - - - int idx = static_cast((dot + HS_MAX_EXP) * ((float) expLength / HS_MAX_EXP / 2.0f)); - - if (idx >= expLength || idx < 0) - return; - - f = expTable[idx]; - g = (static_cast(1.0f) - static_cast(code) - f) * (T) alpha; - - // axpy1 - - for (int e = 0; e < vectorLength; e++) { - neu1e[e] = g * syn1[e] + neu1e[e]; - } - - // axpy2 - if (!isInference) { - for (int e = 0; e < vectorLength; e++) { - syn1[e] = g * syn0[e] + syn1[e]; - } - } - } - - template - void hSoftmax_(void *vsyn0, void *vsyn1, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference, cudaStream_t* stream) { - hSoftmaxKernel<<<1,1,128, *stream>>>(vsyn0, vsyn1, vexpTable, vneu1e, alpha, vectorLength, code, expLength, isInference); - } - - template - __global__ void nSamplingKernel(void *vsyn0, void *vsyn1Neg, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference) { - auto syn0 = reinterpret_cast(vsyn0); - auto syn1Neg = reinterpret_cast(vsyn1Neg); - auto expTable = reinterpret_cast(vexpTable); - auto neu1e = reinterpret_cast(vneu1e); - - T dot = (T) 0.0f; - T g = (T) 0.0f; - - for (int e = 0; e < vectorLength; e++) { - dot += syn0[e] * syn1Neg[e]; - } - - if (dot > HS_MAX_EXP) - g = (code - 1) * alpha; - else if (dot < (T) - HS_MAX_EXP) - g = (code - 0) * alpha; - else { - int idx = (int) ((dot + (T) HS_MAX_EXP) * ((T) expLength / HS_MAX_EXP / 2.0)); - if (idx >= expLength) - return; - - if (idx < 0) - return; - - g = ((T) code - expTable[idx]) * alpha; - } - - // axpy1 - for (int e = 0; e < vectorLength; e++) { - neu1e[e] = g * syn1Neg[e] + neu1e[e]; - } - - // axpy2 - if (!isInference) { - for (int e = 0; e < vectorLength; e++) { - syn1Neg[e] = g * syn0[e] + syn1Neg[e]; - } - } - } - - template - void nSampling_(void *vsyn0, void *vsyn1Neg, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference, cudaStream_t* stream) { - nSamplingKernel<<<1,1,128, *stream>>>(vsyn0, vsyn1Neg, vexpTable, vneu1e, alpha, vectorLength, code, expLength, isInference); - } - - /* - * binarySearch - find element in haystack buffer (haystack - sorted device memory) - * */ - int binarySearch(const int *haystack, const int needle, const int totalElements) { - int firstIndex = 0; - int lastIndex = totalElements - 1; - int halfIndex = sd::math::nd4j_floor((lastIndex + firstIndex) / (float) 2); - - while(haystack[halfIndex] != needle && firstIndex < lastIndex) { - if (needle < haystack[halfIndex]) { - lastIndex = halfIndex - 1; - } else if (needle > haystack[halfIndex]) { - firstIndex = halfIndex + 1; - } - halfIndex = sd::math::nd4j_floor((lastIndex + firstIndex) / (float) 2); - } - - return (haystack[halfIndex] == needle) ? halfIndex : -1; - } - template - __global__ void addInfVectorKernel(T* neu1, T* infVector, int vectorLength) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (auto i = start; i < vectorLength; i += step) { - neu1[i] += infVector[i]; - } - } - - template - void skipgram_(NDArray& s0, NDArray& s1, NDArray& s1n, NDArray& expTableV, NDArray& negTableV, NDArray& infV, int target, int ngStarter, NDArray& indices, NDArray& codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds) { -// void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength) { - auto syn0 = reinterpret_cast(s0.specialBuffer()); - auto syn1 = reinterpret_cast(s1.specialBuffer()); - auto syn1Neg = reinterpret_cast(s1n.specialBuffer()); - auto expTable = reinterpret_cast(expTableV.specialBuffer()); - auto negTable = reinterpret_cast(negTableV.specialBuffer()); - auto infVector = reinterpret_cast(infV.specialBuffer()); - const int vocabSize = s0.sizeAt(0); - const int vectorLength = s0.sizeAt(1); - const int expLength = expTableV.lengthOf(); - const int negLength = negTableV.lengthOf(); - indices.tickReadDevice(); - indices.syncToHost(); - codes.tickReadDevice(); - codes.syncToHost(); - auto stream = s0.getContext()->getCudaStream(); - - T* neu1e; // = new T[vectorLength]; - //memset(neu1e, 0, vectorLength * sizeof(T)); - auto err = cudaMalloc(&neu1e, sizeof(T) * vectorLength); - err = cudaMemset(neu1e, 0, sizeof(T) * vectorLength); - // hierarchic softmax goes first (if enabled) - - auto syn0row = infVector != nullptr ? infVector : syn0 + (target * vectorLength); - auto irow = 0; - if (hsRounds > 0) { - for (int r = 0; r < hsRounds; r++) { - irow = indices.t(r); - if (irow < 0 || irow >= vocabSize) - break; - - hSoftmax_(syn0row, syn1 + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, codes.t(r), expLength, infVector != nullptr, stream); - } - } - - // negative sampling goes second (if enabled) - auto nsStarter = ngStarter; - irow = nsStarter; - if (nsRounds > 0) { - for (int r = 0; r < nsRounds + 1; r++) { - if (r == 0) { - // target is known in advance - } else { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : negTableV.e(idx); - - if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1; - if (irow == nsStarter) - continue; - } - - nSampling_(syn0row, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr, stream); - } - } - - if (infVector == nullptr) { - addInfVectorKernel<<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength); - } else { - addInfVectorKernel<<<128, 256, 256, *stream>>>(infVector, neu1e, vectorLength); - } - err = cudaStreamSynchronize(*stream); - if (0 != err) { - throw cuda_exception::build("helpers::skipgram_: Cannot synchronize stream after addInfVectorKernel", err); - } - - err = cudaFree(neu1e); - if (0 != err) { - throw cuda_exception::build("helpers::skipgram_: Cannot deallocate temp memory for lingual net", err); - } - } - BUILD_SINGLE_TEMPLATE(template void skipgram_, (NDArray& syn0, NDArray& syn1, NDArray& syn1Neg, NDArray& expTable, NDArray& negTable, NDArray& infVector, int target, int ngStarter, NDArray& indices, NDArray& codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds), FLOAT_TYPES); - - /* - * batched version of skipgram routine - * */ - template - void skipgramBatchExec_(NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTableV, NDArray& negTableV, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const bool preciseMode, const int numThreads) { -// (NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTable, NDArray& negTable, NDArray& infVector, NDArray& targets, NDArray& negStarters, NDArray& indices, NDArray& codes, NDArray& lr, NDArray& nextRandom, const int nsRounds, const bool preciseMode, const int numThreads) { - //auto syn0 = reinterpret_cast(vsyn0); - //auto syn1 = reinterpret_cast(vsyn1); - //auto syn1Neg = reinterpret_cast(vsyn1Neg); - auto stream = s0.getContext()->getCudaStream(); - negTableV.tickReadDevice(); - negTableV.syncToHost(); - const auto expTable = reinterpret_cast(expTableV.specialBuffer()); - const auto negTable = reinterpret_cast(negTableV.buffer()); - const auto infVector = (T*)nullptr; //reinterpret_cast(infVector.specialBuffer()); - - const int vocabSize = s0.sizeAt(0); - const int vectorLength = s0.sizeAt(1); - const int expLength = expTableV.lengthOf(); - const int negLength = negTableV.lengthOf(); - - //T sneu1e[600]; - - //const auto numThreads = omp_get_max_threads(); - const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); - const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); - - // regular mode provides 0 guarantees for reproducibility - auto numTargets = targets.lengthOf(); - targets.syncToHost(); - indices.syncToHost(); - codes.syncToHost(); - lr.syncToHost(); - nextRandom.syncToHost(); - negStarters.tickReadDevice(); - negStarters.syncToHost(); - auto bTarget = reinterpret_cast(targets.buffer()); //targets.bufferAsT(); - auto bIndices = reinterpret_cast(indices.buffer()); //indices.bufferAsT(); - auto bCodes = reinterpret_cast(codes.buffer()); //codes.bufferAsT(); - -// PRAGMA_OMP_PARALLEL_FOR_ARGS(num_threads(numThreads)) - for (int t = 0; t < numTargets; t++) { - T* neu1e;//lvectorLength <= 600 ? sneu1e : new T[vectorLength]; - auto err = cudaMalloc(&neu1e, vectorLength * sizeof(T)); - err = cudaMemset(neu1e, 0, vectorLength * sizeof(T)); - //memset(neu1e, 0, vectorLength * sizeof(T)); - - auto target = bTarget[t]; - auto alpha = lr.e(t); - unsigned long long randomValue = nextRandom.e(t); - - auto syn0row = reinterpret_cast(s0.specialBuffer()) + (target * vectorLength); - - if (hsRounds > 0) { - int irow = 0; - auto cShift = t * idxShift; - - for (int e = 0; e < hsRounds; e++) { - irow = bIndices[e + cShift]; - if (irow < 0 || irow >= vocabSize) - continue; - - auto syn1row = reinterpret_cast(s1.specialBuffer()) + (irow * vectorLength); - auto code = bCodes[e + cShift]; - - //nd4j_printf("syn0: [%i]; syn1: [%i]; code: [%i]\n", target, irow, code); - hSoftmax_(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, code, expLength, false, stream); - } - } - - - if (nsRounds > 0) { - int irow = negStarters.e(t); - int nsStarter = irow; - for (int r = 0; r < nsRounds + 1; r++) { - if (r == 0) { - // target is known in advance - } else { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) - irow = randomValue % (vocabSize - 1) + 1; - - if (irow == nsStarter) - continue; - } - auto syn1row = reinterpret_cast(s1n.specialBuffer()) + (irow * vectorLength); - - nSampling_(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, false, stream); - } - } - addInfVectorKernel<<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength); - err = cudaStreamSynchronize(*stream); - if (0 != err) { - throw cuda_exception::build("helpers::skipgramBatchExec_: Cannot synchronize stream after addInfVectorKernel", err); - } - - // optionally release temp arrays - err = cudaFree(neu1e); - if (err != 0) { - throw cuda_exception::build("helpers::skipgramBatchExec_: Cannot deallocate memory with stage", err); - break; - } -// if (vectorLength > 600) -// delete[] neu1e; - } - } - BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTable, NDArray& negTable, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const bool preciseMode, const int numThreads), FLOAT_TYPES); - - void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, - NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &inferenceVector, const bool preciseMode, const int numWorkers) { - auto xType = syn0.dataType(); - // single round case - if ((ngStarter.isScalar() && !ngStarter.isEmpty())|| (target.isScalar() && !target.isEmpty())) { - auto hsRounds = codes.lengthOf(); - target.syncToHost(); - ngStarter.syncToHost(); - alpha.syncToHost(); - randomValue.syncToHost(); - - auto targetV = target.isEmpty() ? -1 : target.e(0); - auto starterV = ngStarter.isEmpty() ? -1 : ngStarter.e(0); - auto alphaV = alpha.e(0); - auto randomV = randomValue.e(0); - BUILD_SINGLE_SELECTOR(xType, skipgram_, (syn0, syn1, syn1Neg, expTable, negTable, inferenceVector, targetV, starterV, indices, codes, alphaV, randomV, hsRounds, nsRounds), FLOAT_TYPES); - } else if (ngStarter.isVector() || target.isVector()){ - // batch mode -// NDArray& infVector, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const bool preciseMode, const int numThreads) - BUILD_SINGLE_SELECTOR(xType, skipgramBatchExec_, (syn0, syn1, syn1Neg, expTable, negTable, target, ngStarter, indices, codes, alpha, randomValue, nsRounds, preciseMode, numWorkers), FLOAT_TYPES); - } else - throw std::runtime_error("SkipGram: target must have rank 0 or 1"); - } - - template - static __global__ void checkContextKernel(int* context, T* syn0, T* neu1, int contextWidth, int vectorLength, int vocabSize) { - __shared__ bool hasError; - if (0 == threadIdx.x) { - hasError = false; - } - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int c = start; c < contextWidth; c += step) { - if (context[c] >= vocabSize) - hasError = true; //throw std::runtime_error("Bad context 4"); - if (!hasError) { - T *syn0word = syn0 + (context[c] * vectorLength); - - for (int i = 0; i < vectorLength; i++) { - neu1[i] += syn0word[i]; - } - } - } - if (threadIdx.x == 0) { - if (hasError) - neu1[0] = DataTypeUtils::infOrMax(); - } - __syncthreads(); - } - - template - __global__ void shiftKernel(T* neu1, T* infVector, int contextWidth, int vectorLength) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int i = start; i < vectorLength; i += step) { - neu1[i] /= contextWidth + int(infVector != nullptr); // ? 1 : 0); - } - } - - template - __global__ void fillUpSynonymsKernel(int starter, int contextWidth, int vectorLength, int* lockedWords, int* context, T* neu1e, T* syn0) { - auto start = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for (int c = starter + start; c < contextWidth; c += step) { - if (lockedWords[c] == 1) - continue; - - T *syn0word = syn0 + (context[c] * vectorLength); - - for (int i = 0; i < vectorLength; i++) { - syn0word[i] += neu1e[i]; - } - } - } - - template - void cbow_(LaunchContext* lc, void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *context, int *lockedWords, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int contextWidth, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const int numLabels, const bool trainWords) { - auto syn0 = reinterpret_cast(vsyn0); - auto syn1 = reinterpret_cast(vsyn1); - auto syn1Neg = reinterpret_cast(vsyn1Neg); - auto expTable = reinterpret_cast(vexpTable); - auto negTable = reinterpret_cast(vnegTable); - auto infVector = reinterpret_cast(vinfVector); - auto stream = lc->getCudaStream(); - - T* neu1; // = new T[vectorLength]; - T* neu1e; // = new T[vectorLength]; - size_t buffSize = sizeof(T) * vectorLength; - auto err = cudaMalloc(&neu1, buffSize); - err = cudaMalloc(&neu1e, buffSize); - err = cudaMemset(neu1, 0, buffSize); - err = cudaMemset(neu1e, 0, buffSize); - - // building neu1 for current window - checkContextKernel<<<1,1,128,*stream>>>(context, syn0, neu1, contextWidth, vectorLength, vocabSize); - - T checkVal; - err = cudaMemcpy(&checkVal, neu1, sizeof(T), cudaMemcpyDeviceToHost); - if (DataTypeUtils::infOrMax() == checkVal) - throw std::runtime_error("Bad context 4"); - // for inference we add additional inference vector - if (infVector != nullptr) { - addInfVectorKernel<<<128, 256, 128, *stream>>>(neu1, infVector, vectorLength); - } - - - // average neu1 - if (contextWidth > 0) { - shiftKernel<<<128, 256, 128, *stream>>>(neu1, infVector, contextWidth, vectorLength); - } - - // softmax round - if (hsRounds > 0) { - for (int i = 0; i < hsRounds; i++) { - if (indices[i] < 0 || indices[i] >= vocabSize) - throw std::runtime_error("Bad context 5"); - T* syn1Shifted = syn1 + (indices[i] * vectorLength); - hSoftmax_(neu1, syn1Shifted, expTable, neu1e, alpha, vectorLength, codes[i], expLength, infVector != nullptr, stream); - } - } - - auto nsStarter = ngStarter; - auto irow = nsStarter; - if (nsRounds > 0) { - for (int r = 0; r < nsRounds + 1; r++) { - if (r == 0) { - // target is known in advance - } else { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1; - if (irow == nsStarter) - continue; - } - - nSampling_(neu1, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr, stream); - } - } - - // if we don't train words - we skip start of idxSyn0 - int starter = trainWords == 1 ? 0 : contextWidth - numLabels; - - // propagate neu1e -> syn0 - if (infVector == nullptr) { - fillUpSynonymsKernel<<<1,1,128, *stream>>>(starter, contextWidth, vectorLength, lockedWords, context, neu1e, syn0); - } else { - - for (int i = 0; i < vectorLength; i++) { - infVector[i] += neu1e[i]; - } - } - err = cudaStreamSynchronize(*stream); - if (0 != err) { - throw cuda_exception::build( - "helpers::cbow_: Cannot synchronize stream after kernel executing", err); - } - err = cudaFree(neu1); - if (0 != err) { - throw cuda_exception::build( - "helpers::cbow_: Cannot deallocate memory for synonims table", err); - } - - err = cudaFree(neu1e); - if (0 != err) { - throw cuda_exception::build( - "helpers::cbow_: Cannot deallocate memory for antonims table", err); - } - } - BUILD_SINGLE_TEMPLATE(template void cbow_, (LaunchContext* lc, void *syn0, void *syn1, void *syn1Neg, void *expTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *context, int *lockedWords, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int contextWidth, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const int numLabels, const bool trainWords), FLOAT_TYPES); - - template - static __global__ void buildCurrentWindowKernel(int vocabSize, int contextWidth, int vectorLength, int* bContext, T* syn0, T* neu1, int* actualContext, int e) { - // building neu1 for current window - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int c = start; c < contextWidth; c += step) { - // getting next context word - auto cContext = bContext[c + (e * contextWidth)]; - - // skipping padded values - if (cContext < 0) - continue; - -// if (cContext >= vocabSize) -// throw std::runtime_error("ContextID can't be >= vocab size"); - - T *syn0word = syn0 + (cContext * vectorLength); - - for (int i = 0; i < vectorLength; i++) - neu1[i] += syn0word[i]; - - atomicAdd(actualContext, 1); - } - } - - template - __global__ void arrangeNeuKernel(int vectorLength, T* neu1, T* infVector, int* actualContext) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (int i = start; i < vectorLength && *actualContext > 0; i += step) - neu1[i] /= (*actualContext + int(infVector != nullptr)); - } - - template - __global__ void applyShiftKernel(int* bContext, int* bLocker, T* syn0, T* neu1e, int contextWidth, int vectorLength, int e, int starter) { - auto step = blockDim.x * gridDim.x; - auto start = blockDim.x * blockIdx.x + threadIdx.x; - - for (int c = starter + start; c < contextWidth; c += step) { - // getting context - auto cContext = bContext[c + (e * contextWidth)]; - auto cLock = bLocker[c + (e * contextWidth)]; - - // skipping padded values - if (cContext < 0 || cLock == 1) - continue; - -// if (cContext >= vocabSize) -// throw std::runtime_error("ContextID can't be > vocab size"); - - // one word from context - T *syn0word = syn0 + (cContext * vectorLength); - - for (int i = 0; i < vectorLength; i++) - syn0word[i] += neu1e[i]; - - } - } - - template - void cbowBatchExec_(LaunchContext* lc, NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, NDArray &nLabels, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool trainWords, const int numThreads) { - const auto syn0 = reinterpret_cast(s0.specialBuffer()); //bufferAsT(); - const auto syn1 = reinterpret_cast(s1.specialBuffer()); //bufferAsT(); - const auto syn1Neg = reinterpret_cast(s1n.specialBuffer()); //bufferAsT(); - - const auto expTable = reinterpret_cast(vexpTable); - const auto negTable = reinterpret_cast(vnegTable); - const auto infVector = reinterpret_cast(vinfVector); - - auto stream = lc->getCudaStream(); - - indices.syncToHost(); - codes.syncToHost(); - negStarters.syncToHost(); - context.syncToHost(); - - //const auto numThreads = omp_get_max_threads(); - const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); - const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); - const auto numTargets = context.sizeAt(0); - const int contextWidth = context.sizeAt(1); - //const auto bContext = reinterpret_cast(context.buffer()); //bufferAsT(); - const auto dContext = context.dataBuffer()->specialAsT(); //bufferAsT(); -// const auto bLocker = reinterpret_cast(lockedWords.buffer()); //lockedWords.bufferAsT(); - const auto dLocker = lockedWords.dataBuffer()->specialAsT(); //.specialBuffer()); //lockedWords.bufferAsT(); - const auto bIndices = indices.dataBuffer()->primaryAsT(); //buffer());//AsT(); - const auto bCodes = codes.dataBuffer()->primaryAsT(); //reinterpret_cast(codes.buffer()); //bufferAsT(); - const auto bStarters = negStarters.dataBuffer()->primaryAsT(); //reinterpret_cast(negStarters.buffer()); //AsT(); - const auto numIndices = indices.isEmpty() ? 0 : indices.sizeAt(1); - lr.syncToHost(); - nLabels.syncToHost(); - //PRAGMA_OMP_PARALLEL_FOR_ARGS(num_threads(numThreads) private(sneu1, sneu1e)) - //NDArray neuVector('c', {vectorLength}, DataTypeUtils::fromT()); - // auto neuEVector = neuVector; //NDArrayFactory::create('c', {vectorLength}); - T* neu1; // = reinterpret_cast(neuVector.specialBuffer());// = vectorLength <= 600 ? sneu1 : new T[vectorLength]; - T* neu1e; // = reinterpret_cast(neuVector.specialBuffer()); // = vectorLength <= 600 ? sneu1e : new T[vectorLength]; - auto cerr = cudaMalloc(&neu1, sizeof(T) * vectorLength); - if (cerr) { - throw cuda_exception::build("Cannot allocate temp vector buffer", cerr); - } - cerr = cudaMalloc(&neu1e, sizeof(T) * vectorLength); - if (cerr) { - throw cuda_exception::build("Cannot allocate temp vector buffer", cerr); - } - int* actualContext; - cerr = cudaMalloc(&actualContext, sizeof(int)); - if (cerr) { - throw cuda_exception::build("Cannot allocate counter buffer", cerr); - } - - for (int e = 0; e < numTargets; e++) { - -// auto err = cudaMalloc(&neu1, sizeof(T)* vectorLength); -// q err = cudaMalloc(&neu1e, sizeof(T)*vectorLength); -// -// // optionally we nullify temp arrays after successful (and on first) cycle -// memset(neu1, 0, sizeof(T) * vectorLength); -// memset(neu1e, 0, sizeof(T) * vectorLength); - - auto alpha = lr.e(e); - auto numLabels = nLabels.isEmpty() ? 0 : nLabels.e(e); - -// auto err = cudaMemset(actualContext, 0, sizeof(int)); -// if (err) { -// printf("Cuda error %d\n", err); break; -// } - - buildCurrentWindowKernel<<<1,1,128, *stream>>>(vocabSize, contextWidth, vectorLength, dContext, syn0, neu1, actualContext, e); - arrangeNeuKernel<<<1,1,128, *stream>>>(vectorLength, neu1, infVector, actualContext); - - // hierarchic softmax step - if (!indices.isEmpty()) { - for (int i = 0; i < numIndices; i++) { - const int cIndex = bIndices[(e * numIndices) + i]; - const int cCode = bCodes[(e * numIndices) + i]; - - // we're skipping padded values - if (cIndex < 0) - continue; - - if (cIndex >= vocabSize) - throw std::runtime_error("Index can't be > vocab size"); - - hSoftmax_(neu1, syn1 + (cIndex * vectorLength), expTable, neu1e, alpha, vectorLength, cCode, expLength, false, stream); - } - } - - // negative sampling step - if (!negStarters.isEmpty() && nsRounds > 0) { - int irow = bStarters[e]; - const int nsStarter = irow; - unsigned long long randomValue = nextRandom.e(e); - - for (int r = 0; r < nsRounds + 1; r++) { - // we're skipping rng on 0 step - if (r != 0) { - randomValue = randomValue * (unsigned long long) 25214903917 + 11; - auto idx = sd::math::nd4j_abs((randomValue >> 16) % negLength); - irow = idx >= negLength ? -1 : static_cast(negTable[idx]); - - if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1; - if (irow == nsStarter) - continue; - - nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr, stream); - } else { - nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr, stream); - } - - //nd4j_printf("Thread <%i>: syn0: [%i]; s1n: [%i];\n", omp_get_thread_num(), 0, irow); - } - } - - - // if we're skipping labels - int starter = trainWords == 1 ? 0 : contextWidth - numLabels; - - // applying previously averaged results - applyShiftKernel<<<1,1,128, *stream>>>(dContext, dLocker, syn0, neu1e, contextWidth, vectorLength, e, starter); - - // optionally release temp arrays -// if (vectorLength > 600) { -// } - - } - cerr = cudaStreamSynchronize(*stream); - if (cerr) { - throw cuda_exception::build("Cannot syncronize stream before memory deallocation", cerr); - } - - cerr = cudaFree(neu1); - if (cerr) { - throw cuda_exception::build("Cannot deallocate temp buffer1", cerr); - } - cerr = cudaFree(neu1e); - if (cerr) { - throw cuda_exception::build("Cannot deallocate temp buffer1 E", cerr); - } - cerr = cudaFree(actualContext); - if (cerr) { - throw cuda_exception::build("Cannot deallocate temp buffer1", cerr); - } - - } - BUILD_SINGLE_TEMPLATE(template void cbowBatchExec_, (LaunchContext* lc, NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, NDArray &nLabels, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool trainWords, const int numThreads), FLOAT_TYPES); - - void cbow(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &context, NDArray &lockedWords, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &numLabels, NDArray &inferenceVector, const bool trainWords, int numWorkers) { - auto xType = syn0.dataType(); - auto lc = context.getContext(); - indices.syncToHost(); - NDArray::prepareSpecialUse({&syn0, &syn1, &syn1Neg, &expTable, &negTable, &target, &ngStarter}, {&context, &lockedWords, &indices, &codes, &alpha, &randomValue, &numLabels, &inferenceVector}); - //auto stream = lc->getCudaStream(); - if ((context.rankOf() == 0 || context.rankOf() == 1) && (indices.rankOf() == 1 || indices.rankOf() == 0)) { - // single round case - /*nd4j_printf("Row exec; ContextWidth: %i; LockedWords: %i; numLabels: %i; Train words: %i\n", (int) context.lengthOf(), (int) lockedWords.lengthOf(), numLabels.isEmpty() ? 0 : numLabels.e(0), (int) trainWords); - if (context.lengthOf() == 2) { - context.printBuffer("context"); - lockedWords.printBuffer("locked"); - codes.printBuffer("codes"); - indices.printBuffer("indices"); - }*/ - - auto hsRounds = codes.lengthOf(); - target.syncToHost(); - numLabels.syncToHost(); - target.syncToHost(); - alpha.syncToHost(); - numLabels.syncToHost(); - codes.syncToHost(); - negTable.syncToHost(); - BUILD_SINGLE_SELECTOR(xType, cbow_, (lc, syn0.specialBuffer(), syn1.specialBuffer(), syn1Neg.specialBuffer(), expTable.specialBuffer(), negTable.buffer(), inferenceVector.specialBuffer(), target.isEmpty() ? -1 : target.e(0), ngStarter.isEmpty() ? -1 : ngStarter.e(0), reinterpret_cast(context.specialBuffer()), reinterpret_cast(lockedWords.specialBuffer()),reinterpret_cast(indices.buffer()), reinterpret_cast(codes.buffer()), alpha.e( 0), randomValue.e(0), (int) context.lengthOf(), hsRounds, nsRounds, (int) syn0.sizeAt(0), (int) syn0.sizeAt(1), (int) expTable.lengthOf(), (int) negTable.lengthOf(), numLabels.isEmpty() ? 0 : numLabels.e(0), trainWords), FLOAT_TYPES); - } else if (context.rankOf() == 2 && indices.rankOf() == 2) { - // batch mode - //nd4j_printf("Batch exec\n",""); - - BUILD_SINGLE_SELECTOR(xType, cbowBatchExec_, (lc, syn0, syn1, syn1Neg, expTable.specialBuffer(), negTable.specialBuffer(), nullptr, context, lockedWords, target, ngStarter, indices, codes, alpha, randomValue, numLabels, nsRounds, syn0.sizeAt(0), syn0.sizeAt(1), expTable.lengthOf(), negTable.isEmpty() ? 0 : negTable.lengthOf(), trainWords, numWorkers), FLOAT_TYPES); - } else - throw std::runtime_error("CBOW: context must have rank 0/1 or 2"); - - NDArray::registerSpecialUse({&syn0, &syn1, &syn1Neg, &expTable, &negTable, &target, &ngStarter}, {&context, &lockedWords, &indices, &codes, &alpha, &randomValue, &numLabels, &inferenceVector}); - } +namespace ops { +namespace helpers { +template +__global__ void hSoftmaxKernel(void *vsyn0, void *vsyn1, void *vexpTable, + void *vneu1e, double alpha, int vectorLength, + int code, int expLength, bool isInference) { + auto syn0 = reinterpret_cast(vsyn0); + auto syn1 = reinterpret_cast(vsyn1); + auto expTable = reinterpret_cast(vexpTable); + auto neu1e = reinterpret_cast(vneu1e); + + T dot(0.0f); + T g(0.0f); + T f(0.0f); + + // dot + for (int e = 0; e < vectorLength; e++) { + dot += syn0[e] * syn1[e]; + } + + // gradient + if (dot < (T)-HS_MAX_EXP || dot >= (T)HS_MAX_EXP) return; + + int idx = static_cast((dot + HS_MAX_EXP) * + ((float)expLength / HS_MAX_EXP / 2.0f)); + + if (idx >= expLength || idx < 0) return; + + f = expTable[idx]; + g = (static_cast(1.0f) - static_cast(code) - f) * (T)alpha; + + // axpy1 + + for (int e = 0; e < vectorLength; e++) { + neu1e[e] = g * syn1[e] + neu1e[e]; + } + + // axpy2 + if (!isInference) { + for (int e = 0; e < vectorLength; e++) { + syn1[e] = g * syn0[e] + syn1[e]; + } + } +} + +template +void hSoftmax_(void *vsyn0, void *vsyn1, void *vexpTable, void *vneu1e, + double alpha, int vectorLength, int code, int expLength, + bool isInference, cudaStream_t *stream) { + hSoftmaxKernel<<<1, 1, 128, *stream>>>(vsyn0, vsyn1, vexpTable, vneu1e, + alpha, vectorLength, code, + expLength, isInference); +} + +template +__global__ void nSamplingKernel(void *vsyn0, void *vsyn1Neg, void *vexpTable, + void *vneu1e, double alpha, int vectorLength, + int code, int expLength, bool isInference) { + auto syn0 = reinterpret_cast(vsyn0); + auto syn1Neg = reinterpret_cast(vsyn1Neg); + auto expTable = reinterpret_cast(vexpTable); + auto neu1e = reinterpret_cast(vneu1e); + + T dot = (T)0.0f; + T g = (T)0.0f; + + for (int e = 0; e < vectorLength; e++) { + dot += syn0[e] * syn1Neg[e]; + } + + if (dot > HS_MAX_EXP) + g = (code - 1) * alpha; + else if (dot < (T)-HS_MAX_EXP) + g = (code - 0) * alpha; + else { + int idx = (int)((dot + (T)HS_MAX_EXP) * ((T)expLength / HS_MAX_EXP / 2.0)); + if (idx >= expLength) return; + + if (idx < 0) return; + + g = ((T)code - expTable[idx]) * alpha; + } + + // axpy1 + for (int e = 0; e < vectorLength; e++) { + neu1e[e] = g * syn1Neg[e] + neu1e[e]; + } + + // axpy2 + if (!isInference) { + for (int e = 0; e < vectorLength; e++) { + syn1Neg[e] = g * syn0[e] + syn1Neg[e]; + } + } +} + +template +void nSampling_(void *vsyn0, void *vsyn1Neg, void *vexpTable, void *vneu1e, + double alpha, int vectorLength, int code, int expLength, + bool isInference, cudaStream_t *stream) { + nSamplingKernel<<<1, 1, 128, *stream>>>(vsyn0, vsyn1Neg, vexpTable, vneu1e, + alpha, vectorLength, code, + expLength, isInference); +} + +/* + * binarySearch - find element in haystack buffer (haystack - sorted device + * memory) + * */ +int binarySearch(const int *haystack, const int needle, + const int totalElements) { + int firstIndex = 0; + int lastIndex = totalElements - 1; + int halfIndex = + sd::math::nd4j_floor((lastIndex + firstIndex) / (float)2); + + while (haystack[halfIndex] != needle && firstIndex < lastIndex) { + if (needle < haystack[halfIndex]) { + lastIndex = halfIndex - 1; + } else if (needle > haystack[halfIndex]) { + firstIndex = halfIndex + 1; + } + halfIndex = + sd::math::nd4j_floor((lastIndex + firstIndex) / (float)2); + } + + return (haystack[halfIndex] == needle) ? halfIndex : -1; +} +template +__global__ void addInfVectorKernel(T *neu1, T *infVector, int vectorLength) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (auto i = start; i < vectorLength; i += step) { + neu1[i] += infVector[i]; + } +} + +template +void skipgram_(NDArray &s0, NDArray &s1, NDArray &s1n, NDArray &expTableV, + NDArray &negTableV, NDArray &infV, int target, int ngStarter, + NDArray &indices, NDArray &codes, double alpha, + Nd4jLong randomValue, const int hsRounds, const int nsRounds) { + // void *vsyn0, void *vsyn1, void *vsyn1Neg, void + // *vexpTable, void *vnegTable, void *vinfVector, int + // target, int ngStarter, int *indices, int8_t *codes, + // double alpha, Nd4jLong randomValue, const int hsRounds, + // const int nsRounds, const int vocabSize, const int + // vectorLength, const int expLength, const int negLength) + // { + auto syn0 = reinterpret_cast(s0.specialBuffer()); + auto syn1 = reinterpret_cast(s1.specialBuffer()); + auto syn1Neg = reinterpret_cast(s1n.specialBuffer()); + auto expTable = reinterpret_cast(expTableV.specialBuffer()); + auto negTable = reinterpret_cast(negTableV.specialBuffer()); + auto infVector = reinterpret_cast(infV.specialBuffer()); + const int vocabSize = s0.sizeAt(0); + const int vectorLength = s0.sizeAt(1); + const int expLength = expTableV.lengthOf(); + const int negLength = negTableV.lengthOf(); + indices.tickReadDevice(); + indices.syncToHost(); + codes.tickReadDevice(); + codes.syncToHost(); + auto stream = s0.getContext()->getCudaStream(); + + T *neu1e; // = new T[vectorLength]; + // memset(neu1e, 0, vectorLength * sizeof(T)); + auto err = cudaMalloc(&neu1e, sizeof(T) * vectorLength); + err = cudaMemset(neu1e, 0, sizeof(T) * vectorLength); + // hierarchic softmax goes first (if enabled) + + auto syn0row = + infVector != nullptr ? infVector : syn0 + (target * vectorLength); + auto irow = 0; + if (hsRounds > 0) { + for (int r = 0; r < hsRounds; r++) { + irow = indices.t(r); + if (irow < 0 || irow >= vocabSize) break; + + hSoftmax_(syn0row, syn1 + (irow * vectorLength), expTable, neu1e, + alpha, vectorLength, codes.t(r), expLength, + infVector != nullptr, stream); + } + } + + // negative sampling goes second (if enabled) + auto nsStarter = ngStarter; + irow = nsStarter; + if (nsRounds > 0) { + for (int r = 0; r < nsRounds + 1; r++) { + if (r == 0) { + // target is known in advance + } else { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : negTableV.e(idx); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + if (irow == nsStarter) continue; + } + + nSampling_(syn0row, syn1Neg + (irow * vectorLength), expTable, neu1e, + alpha, vectorLength, r == 0 ? 1 : 0, expLength, + infVector != nullptr, stream); + } + } + + if (infVector == nullptr) { + addInfVectorKernel + <<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength); + } else { + addInfVectorKernel + <<<128, 256, 256, *stream>>>(infVector, neu1e, vectorLength); + } + err = cudaStreamSynchronize(*stream); + if (0 != err) { + throw cuda_exception::build( + "helpers::skipgram_: Cannot synchronize stream after " + "addInfVectorKernel", + err); + } + + err = cudaFree(neu1e); + if (0 != err) { + throw cuda_exception::build( + "helpers::skipgram_: Cannot deallocate temp memory for lingual net", + err); + } +} +BUILD_SINGLE_TEMPLATE(template void skipgram_, + (NDArray & syn0, NDArray &syn1, NDArray &syn1Neg, + NDArray &expTable, NDArray &negTable, NDArray &infVector, + int target, int ngStarter, NDArray &indices, + NDArray &codes, double alpha, Nd4jLong randomValue, + const int hsRounds, const int nsRounds), + FLOAT_TYPES); + +/* + * batched version of skipgram routine + * */ +template +void skipgramBatchExec_(NDArray &s0, NDArray &s1, NDArray &s1n, + NDArray &expTableV, NDArray &negTableV, + NDArray &targets, NDArray &negStarters, + NDArray &indices, NDArray &codes, NDArray &lr, + NDArray &nextRandom, const int nsRounds, + const bool preciseMode, const int numThreads) { + // (NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTable, + // NDArray& negTable, NDArray& infVector, NDArray& targets, + // NDArray& negStarters, NDArray& indices, NDArray& codes, NDArray& + // lr, NDArray& nextRandom, const int nsRounds, const bool + // preciseMode, const int numThreads) { + // auto syn0 = reinterpret_cast(vsyn0); + // auto syn1 = reinterpret_cast(vsyn1); + // auto syn1Neg = reinterpret_cast(vsyn1Neg); + auto stream = s0.getContext()->getCudaStream(); + negTableV.tickReadDevice(); + negTableV.syncToHost(); + const auto expTable = reinterpret_cast(expTableV.specialBuffer()); + const auto negTable = reinterpret_cast(negTableV.buffer()); + const auto infVector = + (T *)nullptr; // reinterpret_cast(infVector.specialBuffer()); + + const int vocabSize = s0.sizeAt(0); + const int vectorLength = s0.sizeAt(1); + const int expLength = expTableV.lengthOf(); + const int negLength = negTableV.lengthOf(); + + // T sneu1e[600]; + + // const auto numThreads = omp_get_max_threads(); + const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); + const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); + + // regular mode provides 0 guarantees for reproducibility + auto numTargets = targets.lengthOf(); + targets.syncToHost(); + indices.syncToHost(); + codes.syncToHost(); + lr.syncToHost(); + nextRandom.syncToHost(); + negStarters.tickReadDevice(); + negStarters.syncToHost(); + auto bTarget = + reinterpret_cast(targets.buffer()); // targets.bufferAsT(); + auto bIndices = + reinterpret_cast(indices.buffer()); // indices.bufferAsT(); + auto bCodes = + reinterpret_cast(codes.buffer()); // codes.bufferAsT(); + + // PRAGMA_OMP_PARALLEL_FOR_ARGS(num_threads(numThreads)) + for (int t = 0; t < numTargets; t++) { + T *neu1e; // lvectorLength <= 600 ? sneu1e : new T[vectorLength]; + auto err = cudaMalloc(&neu1e, vectorLength * sizeof(T)); + err = cudaMemset(neu1e, 0, vectorLength * sizeof(T)); + // memset(neu1e, 0, vectorLength * sizeof(T)); + + auto target = bTarget[t]; + auto alpha = lr.e(t); + unsigned long long randomValue = nextRandom.e(t); + + auto syn0row = + reinterpret_cast(s0.specialBuffer()) + (target * vectorLength); + + if (hsRounds > 0) { + int irow = 0; + auto cShift = t * idxShift; + + for (int e = 0; e < hsRounds; e++) { + irow = bIndices[e + cShift]; + if (irow < 0 || irow >= vocabSize) continue; + + auto syn1row = + reinterpret_cast(s1.specialBuffer()) + (irow * vectorLength); + auto code = bCodes[e + cShift]; + + // nd4j_printf("syn0: [%i]; syn1: [%i]; code: [%i]\n", target, irow, + // code); + hSoftmax_(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, + code, expLength, false, stream); + } + } + + if (nsRounds > 0) { + int irow = negStarters.e(t); + int nsStarter = irow; + for (int r = 0; r < nsRounds + 1; r++) { + if (r == 0) { + // target is known in advance + } else { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + + if (irow == nsStarter) continue; + } + auto syn1row = + reinterpret_cast(s1n.specialBuffer()) + (irow * vectorLength); + nSampling_(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, + r == 0 ? 1 : 0, expLength, false, stream); + } + } + addInfVectorKernel + <<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength); + err = cudaStreamSynchronize(*stream); + if (0 != err) { + throw cuda_exception::build( + "helpers::skipgramBatchExec_: Cannot synchronize stream after " + "addInfVectorKernel", + err); + } + + // optionally release temp arrays + err = cudaFree(neu1e); + if (err != 0) { + throw cuda_exception::build( + "helpers::skipgramBatchExec_: Cannot deallocate memory with stage", + err); + break; + } + // if (vectorLength > 600) + // delete[] neu1e; + } +} +BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, + (NDArray & s0, NDArray &s1, NDArray &s1n, + NDArray &expTable, NDArray &negTable, NDArray &targets, + NDArray &negStarters, NDArray &indices, NDArray &codes, + NDArray &lr, NDArray &nextRandom, const int nsRounds, + const bool preciseMode, const int numThreads), + FLOAT_TYPES); + +void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, + NDArray &negTable, NDArray &target, NDArray &ngStarter, + int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, + NDArray &randomValue, NDArray &inferenceVector, + const bool preciseMode, const int numWorkers) { + auto xType = syn0.dataType(); + // single round case + if ((ngStarter.isScalar() && !ngStarter.isEmpty()) || + (target.isScalar() && !target.isEmpty())) { + auto hsRounds = codes.lengthOf(); + target.syncToHost(); + ngStarter.syncToHost(); + alpha.syncToHost(); + randomValue.syncToHost(); + + auto targetV = target.isEmpty() ? -1 : target.e(0); + auto starterV = ngStarter.isEmpty() ? -1 : ngStarter.e(0); + auto alphaV = alpha.e(0); + auto randomV = randomValue.e(0); + BUILD_SINGLE_SELECTOR( + xType, skipgram_, + (syn0, syn1, syn1Neg, expTable, negTable, inferenceVector, targetV, + starterV, indices, codes, alphaV, randomV, hsRounds, nsRounds), + FLOAT_TYPES); + } else if (ngStarter.isVector() || target.isVector()) { + // batch mode + // NDArray& infVector, NDArray &targets, NDArray + // &negStarters, NDArray &indices, NDArray &codes, + // NDArray &lr, NDArray &nextRandom, const int nsRounds, + // const bool preciseMode, const int numThreads) + BUILD_SINGLE_SELECTOR( + xType, skipgramBatchExec_, + (syn0, syn1, syn1Neg, expTable, negTable, target, ngStarter, indices, + codes, alpha, randomValue, nsRounds, preciseMode, numWorkers), + FLOAT_TYPES); + } else + throw std::runtime_error("SkipGram: target must have rank 0 or 1"); +} + +template +static __global__ void checkContextKernel(int *context, T *syn0, T *neu1, + int contextWidth, int vectorLength, + int vocabSize) { + __shared__ bool hasError; + if (0 == threadIdx.x) { + hasError = false; + } + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int c = start; c < contextWidth; c += step) { + if (context[c] >= vocabSize) + hasError = true; // throw std::runtime_error("Bad context 4"); + if (!hasError) { + T *syn0word = syn0 + (context[c] * vectorLength); + + for (int i = 0; i < vectorLength; i++) { + neu1[i] += syn0word[i]; + } + } + } + if (threadIdx.x == 0) { + if (hasError) neu1[0] = DataTypeUtils::infOrMax(); + } + __syncthreads(); +} + +template +__global__ void shiftKernel(T *neu1, T *infVector, int contextWidth, + int vectorLength) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int i = start; i < vectorLength; i += step) { + neu1[i] /= contextWidth + int(infVector != nullptr); // ? 1 : 0); + } +} + +template +__global__ void fillUpSynonymsKernel(int starter, int contextWidth, + int vectorLength, int *lockedWords, + int *context, T *neu1e, T *syn0) { + auto start = threadIdx.x + blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (int c = starter + start; c < contextWidth; c += step) { + if (lockedWords[c] == 1) continue; + + T *syn0word = syn0 + (context[c] * vectorLength); + + for (int i = 0; i < vectorLength; i++) { + syn0word[i] += neu1e[i]; + } + } +} + +template +void cbow_(LaunchContext *lc, void *vsyn0, void *vsyn1, void *vsyn1Neg, + void *vexpTable, void *vnegTable, void *vinfVector, int target, + int ngStarter, int *context, int *lockedWords, int *indices, + int8_t *codes, double alpha, Nd4jLong randomValue, + const int contextWidth, const int hsRounds, const int nsRounds, + const int vocabSize, const int vectorLength, const int expLength, + const int negLength, const int numLabels, const bool trainWords) { + auto syn0 = reinterpret_cast(vsyn0); + auto syn1 = reinterpret_cast(vsyn1); + auto syn1Neg = reinterpret_cast(vsyn1Neg); + auto expTable = reinterpret_cast(vexpTable); + auto negTable = reinterpret_cast(vnegTable); + auto infVector = reinterpret_cast(vinfVector); + auto stream = lc->getCudaStream(); + + T *neu1; // = new T[vectorLength]; + T *neu1e; // = new T[vectorLength]; + size_t buffSize = sizeof(T) * vectorLength; + auto err = cudaMalloc(&neu1, buffSize); + err = cudaMalloc(&neu1e, buffSize); + err = cudaMemset(neu1, 0, buffSize); + err = cudaMemset(neu1e, 0, buffSize); + + // building neu1 for current window + checkContextKernel<<<1, 1, 128, *stream>>>( + context, syn0, neu1, contextWidth, vectorLength, vocabSize); + + T checkVal; + err = cudaMemcpy(&checkVal, neu1, sizeof(T), cudaMemcpyDeviceToHost); + if (DataTypeUtils::infOrMax() == checkVal) + throw std::runtime_error("Bad context 4"); + // for inference we add additional inference vector + if (infVector != nullptr) { + addInfVectorKernel + <<<128, 256, 128, *stream>>>(neu1, infVector, vectorLength); + } + + // average neu1 + if (contextWidth > 0) { + shiftKernel<<<128, 256, 128, *stream>>>(neu1, infVector, contextWidth, + vectorLength); + } + + // softmax round + if (hsRounds > 0) { + for (int i = 0; i < hsRounds; i++) { + if (indices[i] < 0 || indices[i] >= vocabSize) + throw std::runtime_error("Bad context 5"); + T *syn1Shifted = syn1 + (indices[i] * vectorLength); + hSoftmax_(neu1, syn1Shifted, expTable, neu1e, alpha, vectorLength, + codes[i], expLength, infVector != nullptr, stream); + } + } + + auto nsStarter = ngStarter; + auto irow = nsStarter; + if (nsRounds > 0) { + for (int r = 0; r < nsRounds + 1; r++) { + if (r == 0) { + // target is known in advance + } else { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + if (irow == nsStarter) continue; + } + + nSampling_(neu1, syn1Neg + (irow * vectorLength), expTable, neu1e, + alpha, vectorLength, r == 0 ? 1 : 0, expLength, + infVector != nullptr, stream); + } + } + + // if we don't train words - we skip start of idxSyn0 + int starter = trainWords == 1 ? 0 : contextWidth - numLabels; + + // propagate neu1e -> syn0 + if (infVector == nullptr) { + fillUpSynonymsKernel<<<1, 1, 128, *stream>>>( + starter, contextWidth, vectorLength, lockedWords, context, neu1e, syn0); + } else { + for (int i = 0; i < vectorLength; i++) { + infVector[i] += neu1e[i]; + } + } + err = cudaStreamSynchronize(*stream); + if (0 != err) { + throw cuda_exception::build( + "helpers::cbow_: Cannot synchronize stream after kernel executing", + err); + } + err = cudaFree(neu1); + if (0 != err) { + throw cuda_exception::build( + "helpers::cbow_: Cannot deallocate memory for synonims table", err); + } + + err = cudaFree(neu1e); + if (0 != err) { + throw cuda_exception::build( + "helpers::cbow_: Cannot deallocate memory for antonims table", err); + } +} +BUILD_SINGLE_TEMPLATE( + template void cbow_, + (LaunchContext * lc, void *syn0, void *syn1, void *syn1Neg, void *expTable, + void *vnegTable, void *vinfVector, int target, int ngStarter, int *context, + int *lockedWords, int *indices, int8_t *codes, double alpha, + Nd4jLong randomValue, const int contextWidth, const int hsRounds, + const int nsRounds, const int vocabSize, const int vectorLength, + const int expLength, const int negLength, const int numLabels, + const bool trainWords), + FLOAT_TYPES); + +template +static __global__ void buildCurrentWindowKernel(int vocabSize, int contextWidth, + int vectorLength, int *bContext, + T *syn0, T *neu1, + int *actualContext, int e) { + // building neu1 for current window + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int c = start; c < contextWidth; c += step) { + // getting next context word + auto cContext = bContext[c + (e * contextWidth)]; + + // skipping padded values + if (cContext < 0) continue; + + // if (cContext >= vocabSize) + // throw std::runtime_error("ContextID can't be >= + // vocab size"); + + T *syn0word = syn0 + (cContext * vectorLength); + + for (int i = 0; i < vectorLength; i++) neu1[i] += syn0word[i]; + + atomicAdd(actualContext, 1); + } +} + +template +__global__ void arrangeNeuKernel(int vectorLength, T *neu1, T *infVector, + int *actualContext) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (int i = start; i 0; i += step) + neu1[i] /= (*actualContext + int(infVector != nullptr)); +} + +template +__global__ void applyShiftKernel(int *bContext, int *bLocker, T *syn0, T *neu1e, + int contextWidth, int vectorLength, int e, + int starter) { + auto step = blockDim.x * gridDim.x; + auto start = blockDim.x * blockIdx.x + threadIdx.x; + + for (int c = starter + start; c < contextWidth; c += step) { + // getting context + auto cContext = bContext[c + (e * contextWidth)]; + auto cLock = bLocker[c + (e * contextWidth)]; + + // skipping padded values + if (cContext < 0 || cLock == 1) continue; + + // if (cContext >= vocabSize) + // throw std::runtime_error("ContextID can't be > + // vocab size"); + + // one word from context + T *syn0word = syn0 + (cContext * vectorLength); + + for (int i = 0; i < vectorLength; i++) syn0word[i] += neu1e[i]; + } +} + +template +void cbowBatchExec_(LaunchContext *lc, NDArray &s0, NDArray &s1, NDArray &s1n, + void *vexpTable, void *vnegTable, void *vinfVector, + NDArray &context, NDArray &lockedWords, NDArray &targets, + NDArray &negStarters, NDArray &indices, NDArray &codes, + NDArray &lr, NDArray &nextRandom, NDArray &nLabels, + const int nsRounds, const int vocabSize, + const int vectorLength, const int expLength, + const int negLength, const bool trainWords, + const int numThreads) { + const auto syn0 = + reinterpret_cast(s0.specialBuffer()); // bufferAsT(); + const auto syn1 = + reinterpret_cast(s1.specialBuffer()); // bufferAsT(); + const auto syn1Neg = + reinterpret_cast(s1n.specialBuffer()); // bufferAsT(); + + const auto expTable = reinterpret_cast(vexpTable); + const auto negTable = reinterpret_cast(vnegTable); + const auto infVector = reinterpret_cast(vinfVector); + + auto stream = lc->getCudaStream(); + + indices.syncToHost(); + codes.syncToHost(); + negStarters.syncToHost(); + context.syncToHost(); + + // const auto numThreads = omp_get_max_threads(); + const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); + const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); + const auto numTargets = context.sizeAt(0); + const int contextWidth = context.sizeAt(1); + // const auto bContext = reinterpret_cast(context.buffer()); + // //bufferAsT(); + const auto dContext = + context.dataBuffer()->specialAsT(); // bufferAsT(); + // const auto bLocker = + // reinterpret_cast(lockedWords.buffer()); + // //lockedWords.bufferAsT(); + const auto dLocker = + lockedWords.dataBuffer() + ->specialAsT(); //.specialBuffer()); + ////lockedWords.bufferAsT(); + const auto bIndices = + indices.dataBuffer()->primaryAsT(); // buffer());//AsT(); + const auto bCodes = + codes.dataBuffer()->primaryAsT(); // reinterpret_cast(codes.buffer()); + // //bufferAsT(); + const auto bStarters = + negStarters.dataBuffer()->primaryAsT(); // reinterpret_cast(negStarters.buffer()); + // //AsT(); + const auto numIndices = indices.isEmpty() ? 0 : indices.sizeAt(1); + lr.syncToHost(); + nLabels.syncToHost(); + // PRAGMA_OMP_PARALLEL_FOR_ARGS(num_threads(numThreads) private(sneu1, + // sneu1e)) NDArray neuVector('c', {vectorLength}, DataTypeUtils::fromT()); + // auto neuEVector = neuVector; //NDArrayFactory::create('c', + // {vectorLength}); + T *neu1; // = reinterpret_cast(neuVector.specialBuffer());// = + // vectorLength <= 600 ? sneu1 : new T[vectorLength]; + T *neu1e; // = reinterpret_cast(neuVector.specialBuffer()); // = + // vectorLength <= 600 ? sneu1e : new T[vectorLength]; + auto cerr = cudaMalloc(&neu1, sizeof(T) * vectorLength); + if (cerr) { + throw cuda_exception::build("Cannot allocate temp vector buffer", cerr); + } + cerr = cudaMalloc(&neu1e, sizeof(T) * vectorLength); + if (cerr) { + throw cuda_exception::build("Cannot allocate temp vector buffer", cerr); + } + int *actualContext; + cerr = cudaMalloc(&actualContext, sizeof(int)); + if (cerr) { + throw cuda_exception::build("Cannot allocate counter buffer", cerr); + } + + for (int e = 0; e < numTargets; e++) { + // auto err = cudaMalloc(&neu1, sizeof(T)* vectorLength); + // q err = cudaMalloc(&neu1e, sizeof(T)*vectorLength); + // + // // optionally we nullify temp arrays after successful + // (and on first) cycle memset(neu1, 0, sizeof(T) * + // vectorLength); memset(neu1e, 0, sizeof(T) * + // vectorLength); + + auto alpha = lr.e(e); + auto numLabels = nLabels.isEmpty() ? 0 : nLabels.e(e); + + // auto err = cudaMemset(actualContext, 0, sizeof(int)); + // if (err) { + // printf("Cuda error %d\n", err); break; + // } + + buildCurrentWindowKernel + <<<1, 1, 128, *stream>>>(vocabSize, contextWidth, vectorLength, + dContext, syn0, neu1, actualContext, e); + arrangeNeuKernel + <<<1, 1, 128, *stream>>>(vectorLength, neu1, infVector, actualContext); + + // hierarchic softmax step + if (!indices.isEmpty()) { + for (int i = 0; i < numIndices; i++) { + const int cIndex = bIndices[(e * numIndices) + i]; + const int cCode = bCodes[(e * numIndices) + i]; + + // we're skipping padded values + if (cIndex < 0) continue; + + if (cIndex >= vocabSize) + throw std::runtime_error("Index can't be > vocab size"); + + hSoftmax_(neu1, syn1 + (cIndex * vectorLength), expTable, neu1e, + alpha, vectorLength, cCode, expLength, false, stream); + } + } + + // negative sampling step + if (!negStarters.isEmpty() && nsRounds > 0) { + int irow = bStarters[e]; + const int nsStarter = irow; + unsigned long long randomValue = nextRandom.e(e); + + for (int r = 0; r < nsRounds + 1; r++) { + // we're skipping rng on 0 step + if (r != 0) { + randomValue = randomValue * (unsigned long long)25214903917 + 11; + auto idx = + sd::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + if (irow == nsStarter) continue; + + nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), + expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, + expLength, infVector != nullptr, stream); + } else { + nSampling_(neu1, s1n.bufferWithOffset(irow * vectorLength), + expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, + expLength, infVector != nullptr, stream); } + + // nd4j_printf("Thread <%i>: syn0: [%i]; s1n: [%i];\n", + // omp_get_thread_num(), 0, irow); + } } -} \ No newline at end of file + + // if we're skipping labels + int starter = trainWords == 1 ? 0 : contextWidth - numLabels; + + // applying previously averaged results + applyShiftKernel<<<1, 1, 128, *stream>>>( + dContext, dLocker, syn0, neu1e, contextWidth, vectorLength, e, starter); + + // optionally release temp arrays + // if (vectorLength > 600) { + // } + } + cerr = cudaStreamSynchronize(*stream); + if (cerr) { + throw cuda_exception::build( + "Cannot syncronize stream before memory deallocation", cerr); + } + + cerr = cudaFree(neu1); + if (cerr) { + throw cuda_exception::build("Cannot deallocate temp buffer1", cerr); + } + cerr = cudaFree(neu1e); + if (cerr) { + throw cuda_exception::build("Cannot deallocate temp buffer1 E", cerr); + } + cerr = cudaFree(actualContext); + if (cerr) { + throw cuda_exception::build("Cannot deallocate temp buffer1", cerr); + } +} +BUILD_SINGLE_TEMPLATE(template void cbowBatchExec_, + (LaunchContext * lc, NDArray &s0, NDArray &s1, + NDArray &s1n, void *vexpTable, void *vnegTable, + void *vinfVector, NDArray &context, NDArray &lockedWords, + NDArray &targets, NDArray &negStarters, NDArray &indices, + NDArray &codes, NDArray &lr, NDArray &nextRandom, + NDArray &nLabels, const int nsRounds, + const int vocabSize, const int vectorLength, + const int expLength, const int negLength, + const bool trainWords, const int numThreads), + FLOAT_TYPES); + +void cbow(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, + NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, + NDArray &context, NDArray &lockedWords, NDArray &indices, + NDArray &codes, NDArray &alpha, NDArray &randomValue, + NDArray &numLabels, NDArray &inferenceVector, const bool trainWords, + int numWorkers) { + auto xType = syn0.dataType(); + auto lc = context.getContext(); + indices.syncToHost(); + NDArray::prepareSpecialUse( + {&syn0, &syn1, &syn1Neg, &expTable, &negTable, &target, &ngStarter}, + {&context, &lockedWords, &indices, &codes, &alpha, &randomValue, + &numLabels, &inferenceVector}); + // auto stream = lc->getCudaStream(); + if ((context.rankOf() == 0 || context.rankOf() == 1) && + (indices.rankOf() == 1 || indices.rankOf() == 0)) { + // single round case + /*nd4j_printf("Row exec; ContextWidth: %i; LockedWords: %i; numLabels: %i; + Train words: %i\n", (int) context.lengthOf(), (int) lockedWords.lengthOf(), + numLabels.isEmpty() ? 0 : numLabels.e(0), (int) trainWords); if + (context.lengthOf() == 2) { context.printBuffer("context"); + lockedWords.printBuffer("locked"); + codes.printBuffer("codes"); + indices.printBuffer("indices"); + }*/ + + auto hsRounds = codes.lengthOf(); + target.syncToHost(); + numLabels.syncToHost(); + target.syncToHost(); + alpha.syncToHost(); + numLabels.syncToHost(); + codes.syncToHost(); + negTable.syncToHost(); + BUILD_SINGLE_SELECTOR( + xType, cbow_, + (lc, syn0.specialBuffer(), syn1.specialBuffer(), + syn1Neg.specialBuffer(), expTable.specialBuffer(), negTable.buffer(), + inferenceVector.specialBuffer(), + target.isEmpty() ? -1 : target.e(0), + ngStarter.isEmpty() ? -1 : ngStarter.e(0), + reinterpret_cast(context.specialBuffer()), + reinterpret_cast(lockedWords.specialBuffer()), + reinterpret_cast(indices.buffer()), + reinterpret_cast(codes.buffer()), alpha.e(0), + randomValue.e(0), (int)context.lengthOf(), hsRounds, + nsRounds, (int)syn0.sizeAt(0), (int)syn0.sizeAt(1), + (int)expTable.lengthOf(), (int)negTable.lengthOf(), + numLabels.isEmpty() ? 0 : numLabels.e(0), trainWords), + FLOAT_TYPES); + } else if (context.rankOf() == 2 && indices.rankOf() == 2) { + // batch mode + // nd4j_printf("Batch exec\n",""); + + BUILD_SINGLE_SELECTOR( + xType, cbowBatchExec_, + (lc, syn0, syn1, syn1Neg, expTable.specialBuffer(), + negTable.specialBuffer(), nullptr, context, lockedWords, target, + ngStarter, indices, codes, alpha, randomValue, numLabels, nsRounds, + syn0.sizeAt(0), syn0.sizeAt(1), expTable.lengthOf(), + negTable.isEmpty() ? 0 : negTable.lengthOf(), trainWords, numWorkers), + FLOAT_TYPES); + } else + throw std::runtime_error("CBOW: context must have rank 0/1 or 2"); + + NDArray::registerSpecialUse( + {&syn0, &syn1, &syn1Neg, &expTable, &negTable, &target, &ngStarter}, + {&context, &lockedWords, &indices, &codes, &alpha, &randomValue, + &numLabels, &inferenceVector}); +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/shift.cu b/libnd4j/include/ops/declarable/helpers/cuda/shift.cu index c69285ef29b9..cc241d3eb30d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/shift.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/shift.cu @@ -21,61 +21,65 @@ #include namespace sd { - namespace ops { - namespace helpers { - template - void rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto lambda = LAMBDA_T(x, shift) { - return x >> shift; - }; +namespace ops { +namespace helpers { +template +void rshift_bits_(LaunchContext *launchContext, NDArray &input, NDArray &output, + uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { return x >> shift; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } +void rshift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), rshift_bits_, + (launchContext, x, z, shift), INTEGER_TYPES); +} - template - void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto lambda = LAMBDA_T(x, shift) { - return x << shift; - }; +template +void shift_bits_(LaunchContext *launchContext, NDArray &input, NDArray &output, + uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { return x << shift; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } +void shift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), + INTEGER_TYPES); +} - template - void cyclic_rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto step = (sizeof(T) * 8) - shift; - auto lambda = LAMBDA_T(x, shift, step) { - return x >> shift | x << step; - }; +template +void cyclic_rshift_bits_(LaunchContext *launchContext, NDArray &input, + NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { return x >> shift | x << step; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } +void cyclic_rshift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_rshift_bits_, + (launchContext, x, z, shift), INTEGER_TYPES); +} - template - void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { - auto step = (sizeof(T) * 8) - shift; - auto lambda = LAMBDA_T(x, shift, step) { - return x << shift | x >> step; - }; +template +void cyclic_shift_bits_(LaunchContext *launchContext, NDArray &input, + NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { return x << shift | x >> step; }; - input.applyLambda(lambda, output); - } + input.applyLambda(lambda, output); +} - void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { - BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); - } - } - } -} \ No newline at end of file +void cyclic_shift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, + (launchContext, x, z, shift), INTEGER_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/solve.cu b/libnd4j/include/ops/declarable/helpers/cuda/solve.cu index cf8308bbe987..ad36a8933c3d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/solve.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/solve.cu @@ -18,123 +18,154 @@ // @author GS // -#include #include #include -#include - #include #include -#include "../triangular_solve.h" +#include +#include + #include "../lup.h" #include "../solve.h" +#include "../triangular_solve.h" namespace sd { - namespace ops { - namespace helpers { - - template - static __global__ void oneOnDiagonalKernel(T* ioBuf, Nd4jLong const* ioShape, Nd4jLong const* tadShape, Nd4jLong const* tadOffsets, Nd4jLong batchNum, Nd4jLong rowNum) { - for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) { - auto matrixPart = ioBuf + tadOffsets[i]; - for (auto j = threadIdx.x; j < rowNum; j += blockDim.x) { - Nd4jLong pos[] = {j, j}; - auto offset = shape::getOffset(tadShape, pos); - - matrixPart[offset] = T(1.f); - } - } - } - - template - static __global__ void restorePermutationsKernel(T* PBuf, Nd4jLong const* PShapeInfo, int const* permutationsBuf, - Nd4jLong const* PTadShapeInfo, Nd4jLong const* PTadSOffsets, Nd4jLong const* permutationsTadShapeInfo, Nd4jLong const* permutationsTadOffsets, Nd4jLong batchNum, Nd4jLong rowNum) { - for (auto batch = blockIdx.x; batch < batchNum; batch += gridDim.x) { - auto permutations = permutationsBuf + permutationsTadOffsets[batch]; - auto P = PBuf + PTadSOffsets[batch]; - - for (auto row = threadIdx.x; row < rowNum; row += blockDim.x) { - //auto posX[] = {row}; - Nd4jLong posZ[] = {row, permutations[row]}; - auto zOffset = shape::getOffset(PTadShapeInfo, posZ); - P[zOffset] = T(1.f); - } - } - } - - template - static int solveFunctor_(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, - bool adjoint, NDArray* output) { - NDArray::prepareSpecialUse({output}, {leftInput, rightInput}); - // stage 1: LU decomposition batched - auto leftOutput = leftInput->ulike(); - auto permuShape = rightInput->getShapeAsVector(); permuShape.pop_back(); - auto permutations = NDArrayFactory::create('c', permuShape, context); - helpers::lu(context, leftInput, &leftOutput, &permutations); - auto leftLower = leftOutput.dup(); - auto rightOutput = rightInput->ulike(); - auto leftLowerTad = ConstantTadHelper::getInstance()->tadForDimensions(leftLower.shapeInfo(), {-2, -1}); - auto stream = context->getCudaStream(); - oneOnDiagonalKernel<<<128, 256, 256, *stream>>>(leftLower.dataBuffer()->specialAsT(), leftLower.specialShapeInfo(), leftLowerTad.specialShapeInfo(), leftLowerTad.specialOffsets(), leftLowerTad.numberOfTads(), leftLower.sizeAt(-1)); - auto P = leftOutput.ulike(); P.nullify(); - auto PTad = ConstantTadHelper::getInstance()->tadForDimensions(P.shapeInfo(), {-2, -1}); - auto permutationsTad = ConstantTadHelper::getInstance()->tadForDimensions(permutations.shapeInfo(), {-1}); - restorePermutationsKernel<<<128, 256, 256, *stream>>>(P.dataBuffer()->specialAsT(), P.specialShapeInfo(), permutations.dataBuffer()->specialAsT(), - PTad.specialShapeInfo(), PTad.specialOffsets(), permutationsTad.specialShapeInfo(), permutationsTad.specialOffsets(), permutationsTad.numberOfTads(), permutations.sizeAt(-1)); - P.tickWriteDevice(); - auto rightPart = rightInput->ulike(); - MmulHelper::matmul(&P, rightInput, &rightPart, 0, 0); - - // stage 2: triangularSolveFunctor for Lower with given b - helpers::triangularSolveFunctor(context, &leftLower, &rightPart, true, false, &rightOutput); - // stage 3: triangularSolveFunctor for Upper with output of previous stage - helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, false, output); - NDArray::registerSpecialUse({output}, {leftInput, rightInput}); - - return Status::OK(); - } - - int solveFunctor(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output) { - BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, (context, leftInput, rightInput, adjoint, output), FLOAT_TYPES); - } - - template - static __global__ void adjointKernel(T* output, Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns, Nd4jLong const* outputTads, - Nd4jLong const* outputOffsets) { - - for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { - auto outputPart = output + outputOffsets[b]; - for (auto r = threadIdx.x; r < rows; r += blockDim.x) { - for (auto c = threadIdx.y; c < r; c += blockDim.y) { - Nd4jLong zPos[] = {r, c}; - Nd4jLong xPos[] = {c, r}; - auto zIndex = shape::getOffset(outputTads, zPos); - auto xIndex = shape::getOffset(outputTads, xPos); - math::nd4j_swap(outputPart[zIndex], outputPart[xIndex]); - } - } - } - - } - - template - static void adjointMatrix_(sd::LaunchContext* context, NDArray const* input, NDArray* output) { - NDArray::prepareSpecialUse({output}, {input}); - auto inputTads = ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), {-2, -1}); - auto outputTads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-2, -1}); - auto stream = context->getCudaStream(); - auto outputBuf = reinterpret_cast(output->specialBuffer()); - auto rows = input->sizeAt(-2); - auto columns = input->sizeAt(-1); - output->assign(input); - adjointKernel<<<128, 256, 256, *stream>>>(outputBuf, outputTads.numberOfTads(), rows, columns, outputTads.specialShapeInfo(), outputTads.specialOffsets()); - NDArray::registerSpecialUse({output}, {input}); - } - - void adjointMatrix(sd::LaunchContext* context, NDArray const* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, (context, input, output), FLOAT_TYPES); - } - - } +namespace ops { +namespace helpers { + +template +static __global__ void oneOnDiagonalKernel(T* ioBuf, Nd4jLong const* ioShape, + Nd4jLong const* tadShape, + Nd4jLong const* tadOffsets, + Nd4jLong batchNum, Nd4jLong rowNum) { + for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) { + auto matrixPart = ioBuf + tadOffsets[i]; + for (auto j = threadIdx.x; j < rowNum; j += blockDim.x) { + Nd4jLong pos[] = {j, j}; + auto offset = shape::getOffset(tadShape, pos); + + matrixPart[offset] = T(1.f); + } + } +} + +template +static __global__ void restorePermutationsKernel( + T* PBuf, Nd4jLong const* PShapeInfo, int const* permutationsBuf, + Nd4jLong const* PTadShapeInfo, Nd4jLong const* PTadSOffsets, + Nd4jLong const* permutationsTadShapeInfo, + Nd4jLong const* permutationsTadOffsets, Nd4jLong batchNum, + Nd4jLong rowNum) { + for (auto batch = blockIdx.x; batch < batchNum; batch += gridDim.x) { + auto permutations = permutationsBuf + permutationsTadOffsets[batch]; + auto P = PBuf + PTadSOffsets[batch]; + + for (auto row = threadIdx.x; row < rowNum; row += blockDim.x) { + // auto posX[] = {row}; + Nd4jLong posZ[] = {row, permutations[row]}; + auto zOffset = shape::getOffset(PTadShapeInfo, posZ); + P[zOffset] = T(1.f); } + } } + +template +static int solveFunctor_(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool adjoint, NDArray* output) { + NDArray::prepareSpecialUse({output}, {leftInput, rightInput}); + // stage 1: LU decomposition batched + auto leftOutput = leftInput->ulike(); + auto permuShape = rightInput->getShapeAsVector(); + permuShape.pop_back(); + auto permutations = NDArrayFactory::create('c', permuShape, context); + helpers::lu(context, leftInput, &leftOutput, &permutations); + auto leftLower = leftOutput.dup(); + auto rightOutput = rightInput->ulike(); + auto leftLowerTad = ConstantTadHelper::getInstance()->tadForDimensions( + leftLower.shapeInfo(), {-2, -1}); + auto stream = context->getCudaStream(); + oneOnDiagonalKernel<<<128, 256, 256, *stream>>>( + leftLower.dataBuffer()->specialAsT(), leftLower.specialShapeInfo(), + leftLowerTad.specialShapeInfo(), leftLowerTad.specialOffsets(), + leftLowerTad.numberOfTads(), leftLower.sizeAt(-1)); + auto P = leftOutput.ulike(); + P.nullify(); + auto PTad = ConstantTadHelper::getInstance()->tadForDimensions(P.shapeInfo(), + {-2, -1}); + auto permutationsTad = ConstantTadHelper::getInstance()->tadForDimensions( + permutations.shapeInfo(), {-1}); + restorePermutationsKernel<<<128, 256, 256, *stream>>>( + P.dataBuffer()->specialAsT(), P.specialShapeInfo(), + permutations.dataBuffer()->specialAsT(), PTad.specialShapeInfo(), + PTad.specialOffsets(), permutationsTad.specialShapeInfo(), + permutationsTad.specialOffsets(), permutationsTad.numberOfTads(), + permutations.sizeAt(-1)); + P.tickWriteDevice(); + auto rightPart = rightInput->ulike(); + MmulHelper::matmul(&P, rightInput, &rightPart, 0, 0); + + // stage 2: triangularSolveFunctor for Lower with given b + helpers::triangularSolveFunctor(context, &leftLower, &rightPart, true, false, + &rightOutput); + // stage 3: triangularSolveFunctor for Upper with output of previous stage + helpers::triangularSolveFunctor(context, &leftOutput, &rightOutput, false, + false, output); + NDArray::registerSpecialUse({output}, {leftInput, rightInput}); + + return Status::OK(); +} + +int solveFunctor(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool adjoint, NDArray* output) { + BUILD_SINGLE_SELECTOR(leftInput->dataType(), return solveFunctor_, + (context, leftInput, rightInput, adjoint, output), + FLOAT_TYPES); +} + +template +static __global__ void adjointKernel(T* output, Nd4jLong batchSize, + Nd4jLong rows, Nd4jLong columns, + Nd4jLong const* outputTads, + Nd4jLong const* outputOffsets) { + for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { + auto outputPart = output + outputOffsets[b]; + for (auto r = threadIdx.x; r < rows; r += blockDim.x) { + for (auto c = threadIdx.y; c < r; c += blockDim.y) { + Nd4jLong zPos[] = {r, c}; + Nd4jLong xPos[] = {c, r}; + auto zIndex = shape::getOffset(outputTads, zPos); + auto xIndex = shape::getOffset(outputTads, xPos); + math::nd4j_swap(outputPart[zIndex], outputPart[xIndex]); + } + } + } +} + +template +static void adjointMatrix_(sd::LaunchContext* context, NDArray const* input, + NDArray* output) { + NDArray::prepareSpecialUse({output}, {input}); + auto inputTads = ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), {-2, -1}); + auto outputTads = ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), {-2, -1}); + auto stream = context->getCudaStream(); + auto outputBuf = reinterpret_cast(output->specialBuffer()); + auto rows = input->sizeAt(-2); + auto columns = input->sizeAt(-1); + output->assign(input); + adjointKernel<<<128, 256, 256, *stream>>>( + outputBuf, outputTads.numberOfTads(), rows, columns, + outputTads.specialShapeInfo(), outputTads.specialOffsets()); + NDArray::registerSpecialUse({output}, {input}); +} + +void adjointMatrix(sd::LaunchContext* context, NDArray const* input, + NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjointMatrix_, + (context, input, output), FLOAT_TYPES); +} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/split.cu b/libnd4j/include/ops/declarable/helpers/cuda/split.cu index 19c58b89eba0..71e97993f205 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/split.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/split.cu @@ -19,174 +19,195 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include -namespace sd { -namespace ops { -namespace helpers { +#include +namespace sd { +namespace ops { +namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ static void splitCuda(const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) { - - const T* x = reinterpret_cast(vx); +template +__global__ static void splitCuda(const void* vx, const Nd4jLong* xShapeInfo, + void* pVz, const Nd4jLong* zTadShapeInfo, + const int axis) { + const T* x = reinterpret_cast(vx); - __shared__ Nd4jLong xLen, totalThreads; - __shared__ int xRank, zDim; + __shared__ Nd4jLong xLen, totalThreads; + __shared__ int xRank, zDim; - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - xRank = shape::rank(xShapeInfo); - zDim = shape::shapeOf(zTadShapeInfo)[axis]; // same for all input arrays - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + xRank = shape::rank(xShapeInfo); + zDim = shape::shapeOf(zTadShapeInfo)[axis]; // same for all input arrays + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); - int coords[MAX_RANK]; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (uint64_t i = tid; i < xLen; i += totalThreads) { + int coords[MAX_RANK]; - shape::index2coords(i, xShapeInfo, coords); + for (uint64_t i = tid; i < xLen; i += totalThreads) { + shape::index2coords(i, xShapeInfo, coords); - const auto xOffset = shape::getOffset(xShapeInfo, coords); + const auto xOffset = shape::getOffset(xShapeInfo, coords); - auto *z = reinterpret_cast(reinterpret_cast(pVz)[coords[axis] / zDim]); + auto* z = reinterpret_cast( + reinterpret_cast(pVz)[coords[axis] / zDim]); - coords[axis] %= zDim; + coords[axis] %= zDim; - const auto zOffset = shape::getOffset(zTadShapeInfo, coords); + const auto zOffset = shape::getOffset(zTadShapeInfo, coords); - z[zOffset] = x[xOffset]; - } + z[zOffset] = x[xOffset]; + } } /////////////////////////////////////////////////////////////////// -template -__host__ static void splitCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) { - - splitCuda<<>>(vx, xShapeInfo, pVz, zTadShapeInfo, axis); +template +__host__ static void splitCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) { + splitCuda<<>>( + vx, xShapeInfo, pVz, zTadShapeInfo, axis); } -BUILD_SINGLE_TEMPLATE(template void splitCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void splitCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, void* pVz, + const Nd4jLong* zTadShapeInfo, const int axis), + LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -void split(sd::LaunchContext* context, const NDArray& input, std::vector& outArrs, const int axis) { - - const int numOfSubArrs = outArrs.size(); - const auto sizeofT = input.sizeOfT(); - - for(int i = 0; i < numOfSubArrs; ++i) - outArrs[i]->syncToDevice(); - input.syncToDevice(); - - bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || (axis == input.rankOf() - 1 && input.ordering() == 'f')) && input.ews() == 1; - - if(luckCase1) { - for (uint i = 0; i < numOfSubArrs; ++i) { - luckCase1 &= outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1; - if(!luckCase1) - break; - } +void split(sd::LaunchContext* context, const NDArray& input, + std::vector& outArrs, const int axis) { + const int numOfSubArrs = outArrs.size(); + const auto sizeofT = input.sizeOfT(); + + for (int i = 0; i < numOfSubArrs; ++i) outArrs[i]->syncToDevice(); + input.syncToDevice(); + + bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || + (axis == input.rankOf() - 1 && input.ordering() == 'f')) && + input.ews() == 1; + + if (luckCase1) { + for (uint i = 0; i < numOfSubArrs; ++i) { + luckCase1 &= + outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1; + if (!luckCase1) break; } + } - if(luckCase1) { // for example {1,10} + {2,10} + {3,10} = {6, 10} order c; or {10,1} + {10,2} + {10,3} = {10, 6} order f + if (luckCase1) { // for example {1,10} + {2,10} + {3,10} = {6, 10} order c; + // or {10,1} + {10,2} + {10,3} = {10, 6} order f - auto x = static_cast(input.specialBuffer()); + auto x = static_cast(input.specialBuffer()); - for (uint i = 0; i < numOfSubArrs; ++i) { - const auto memAmountToCopy = outArrs[i]->lengthOf() * sizeofT; - cudaMemcpyAsync(static_cast(outArrs[i]->specialBuffer()), x, memAmountToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream()); - x = static_cast(x) + memAmountToCopy; - } + for (uint i = 0; i < numOfSubArrs; ++i) { + const auto memAmountToCopy = outArrs[i]->lengthOf() * sizeofT; + cudaMemcpyAsync(static_cast(outArrs[i]->specialBuffer()), x, + memAmountToCopy, cudaMemcpyDeviceToDevice, + *context->getCudaStream()); + x = static_cast(x) + memAmountToCopy; + } - if(cudaStreamSynchronize(*context->getCudaStream()) != 0) - throw std::runtime_error("split cuda: luckCase1 failed!"); + if (cudaStreamSynchronize(*context->getCudaStream()) != 0) + throw std::runtime_error("split cuda: luckCase1 failed!"); - for(int i = 0; i < numOfSubArrs; ++i) - outArrs[i]->tickWriteDevice(); - input.tickReadDevice(); + for (int i = 0; i < numOfSubArrs; ++i) outArrs[i]->tickWriteDevice(); + input.tickReadDevice(); - return; - } + return; + } - // const bool isXcontin = input.strideAt(axis) == 1; - // bool areOutputsContin = true; - // bool allSameOrder = true; - // std::vector strideOfContigStride(outArrs.size()); + // const bool isXcontin = input.strideAt(axis) == 1; + // bool areOutputsContin = true; + // bool allSameOrder = true; + // std::vector strideOfContigStride(outArrs.size()); - // if(isXcontin) { + // if(isXcontin) { - // for (uint i = 0; i < outArrs.size(); ++i) { + // for (uint i = 0; i < outArrs.size(); ++i) { - // areOutputsContin &= outArrs[i]->strideAt(axis) == 1; - // allSameOrder &= input.ordering() == outArrs[i]->ordering(); - // if(!areOutputsContin || !allSameOrder) - // break; + // areOutputsContin &= outArrs[i]->strideAt(axis) == 1; + // allSameOrder &= input.ordering() == outArrs[i]->ordering(); + // if(!areOutputsContin || !allSameOrder) + // break; - // strideOfContigStride[i] = shape::strideOverContigAxis(axis, outArrs[i]->shapeInfo()); - // } - // } + // strideOfContigStride[i] = shape::strideOverContigAxis(axis, + // outArrs[i]->shapeInfo()); + // } + // } - // const bool luckCase2 = isXcontin && areOutputsContin && allSameOrder; + // const bool luckCase2 = isXcontin && areOutputsContin && allSameOrder; - // if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 shoud have stride = 1 for all inputs arrays and input array + // if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, + // here axis 1 shoud have stride = 1 for all inputs arrays and input array - // const auto xStep = shape::strideOverContigAxis(axis, input.shapeInfo()); - // const auto zDim = outArrs[0]->sizeAt(axis); // same for all outArrs + // const auto xStep = shape::strideOverContigAxis(axis, + // input.shapeInfo()); const auto zDim = outArrs[0]->sizeAt(axis); // + // same for all outArrs - // for (uint i = 0; i < input.lengthOf() / input.sizeAt(axis); ++i) { + // for (uint i = 0; i < input.lengthOf() / input.sizeAt(axis); ++i) { - // const auto iShift = i * sizeofT; - // void* x = static_cast(input.specialBuffer()) + xStep * iShift; + // const auto iShift = i * sizeofT; + // void* x = static_cast(input.specialBuffer()) + xStep * + // iShift; - // for (uint j = 0; j < numOfSubArrs; ++j) { - // void* z = static_cast(outArrs[j]->specialBuffer()) + strideOfContigStride[j] * iShift; - // const auto memSizeToCopy = zDim * sizeofT; - // cudaMemcpyAsync(z, x, memSizeToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream()); - // x = static_cast(x) + memSizeToCopy; - // } - // } + // for (uint j = 0; j < numOfSubArrs; ++j) { + // void* z = static_cast(outArrs[j]->specialBuffer()) + + // strideOfContigStride[j] * iShift; const auto memSizeToCopy = + // zDim * sizeofT; cudaMemcpyAsync(z, x, memSizeToCopy, + // cudaMemcpyDeviceToDevice, *context->getCudaStream()); x = + // static_cast(x) + memSizeToCopy; + // } + // } - // if(cudaStreamSynchronize(*context->getCudaStream()) != 0) - // throw std::runtime_error("split cuda: luckCase2 failed!"); - // } - // else { // general (slower) case + // if(cudaStreamSynchronize(*context->getCudaStream()) != 0) + // throw std::runtime_error("split cuda: luckCase2 failed!"); + // } + // else { // general (slower) case - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - // prepare arrays of pointers on buffers and shapes - std::vector hOutBuffers(numOfSubArrs); + // prepare arrays of pointers on buffers and shapes + std::vector hOutBuffers(numOfSubArrs); - for(int i = 0; i < numOfSubArrs; ++i) - hOutBuffers[i] = outArrs[i]->specialBuffer(); + for (int i = 0; i < numOfSubArrs; ++i) + hOutBuffers[i] = outArrs[i]->specialBuffer(); - PointersManager manager(context, "helpers::split"); + PointersManager manager(context, "helpers::split"); - void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*)); + void* dOutBuffers = manager.replicatePointer( + hOutBuffers.data(), hOutBuffers.size() * sizeof(void*)); - BUILD_SINGLE_SELECTOR(input.dataType(), splitCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), dOutBuffers, outArrs[0]->specialShapeInfo(), axis), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR( + input.dataType(), splitCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), dOutBuffers, + outArrs[0]->specialShapeInfo(), axis), + LIBND4J_TYPES); - manager.synchronize(); - // } + manager.synchronize(); + // } - for(int i = 0; i < numOfSubArrs; ++i) - outArrs[i]->tickWriteDevice(); - input.tickReadDevice(); + for (int i = 0; i < numOfSubArrs; ++i) outArrs[i]->tickWriteDevice(); + input.tickReadDevice(); } -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sru.cu b/libnd4j/include/ops/declarable/helpers/cuda/sru.cu index b59ac00524b2..27d9beadaa38 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/sru.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/sru.cu @@ -15,521 +15,539 @@ ******************************************************************************/ // -// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 [cs.CL] 12 Sep 2017 +// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 +// [cs.CL] 12 Sep 2017 // // @author Yurii Shyrma, created on 05.12.2017 // -#include #include -#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// - static FORCEINLINE NDArray activation(const NDArray& arr) { - // return (const_cast&>(arr)).template transform>(); - auto result = NDArray(&arr, false, arr.getContext()); - (const_cast(arr)).applyTransform(transform::Tanh, result); - return result; - } - - - ////////////////////////////////////////////////////////////////////////// - static FORCEINLINE NDArray sigmoid(const NDArray& arr) { - return (const_cast(arr)).transform(transform::Sigmoid); - } - +////////////////////////////////////////////////////////////////////////// +static FORCEINLINE NDArray activation(const NDArray& arr) { + // return (const_cast&>(arr)).template + // transform>(); + auto result = NDArray(&arr, false, arr.getContext()); + (const_cast(arr)).applyTransform(transform::Tanh, result); + return result; +} ////////////////////////////////////////////////////////////////////////// -void sruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) { +static FORCEINLINE NDArray sigmoid(const NDArray& arr) { + return (const_cast(arr)).transform(transform::Sigmoid); +} - // x input [bS x inSize], bS - batch size, inSize - number of features - // c0 previous cell state c [bS x inSize], that is at previous time step t-1 - // w weights [inSize x 3*inSize] - // b biases [2*inSize] +////////////////////////////////////////////////////////////////////////// +void sruCell(sd::LaunchContext* context, const NDArray* x, const NDArray* c0, + const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) { + // x input [bS x inSize], bS - batch size, inSize - number of features + // c0 previous cell state c [bS x inSize], that is at previous time step t-1 + // w weights [inSize x 3*inSize] + // b biases [2*inSize] - // h current cell output [bS x inSize], that is at current time step t - // c current cell state [bS x inSize], that is at current time step t + // h current cell output [bS x inSize], that is at current time step t + // c current cell state [bS x inSize], that is at current time step t - const int inSize = x->sizeAt(1); // inSize - number of features + const int inSize = x->sizeAt(1); // inSize - number of features - auto z = mmul(*x, *w); // [bS x 3*inSize] + auto z = mmul(*x, *w); // [bS x 3*inSize] - // forget gate = sigmoid(x*Wf + bf) - auto f = sigmoid(z({0,0, inSize, 2*inSize}) + (*b)({0, inSize})); + // forget gate = sigmoid(x*Wf + bf) + auto f = sigmoid(z({0, 0, inSize, 2 * inSize}) + (*b)({0, inSize})); - // reset gate = sigmoid(x*Wr + br) - auto r = sigmoid(z({0,0, 2*inSize, 3*inSize}) + (*b)({inSize, 2*inSize})); + // reset gate = sigmoid(x*Wr + br) + auto r = + sigmoid(z({0, 0, 2 * inSize, 3 * inSize}) + (*b)({inSize, 2 * inSize})); - // ◦ means element-wise product or so called Hadamard product - // current sell state = f◦c0 + (1 - f)◦(x*Wc) - c->assign(f * (*c0) + (1.f - f) * z({0, 0 ,0, inSize}) ); - // *c = f*(*c0 - z({},{0, inSize})) + z({{},{0, inSize}}); + // ◦ means element-wise product or so called Hadamard product + // current sell state = f◦c0 + (1 - f)◦(x*Wc) + c->assign(f * (*c0) + (1.f - f) * z({0, 0, 0, inSize})); + // *c = f*(*c0 - z({},{0, inSize})) + z({{},{0, inSize}}); - // current cell output = r◦activation(c) + (1 - r)◦x - h->assign( r * activation(*c) + (1.f - r) * (*x) ); - // *h = r * (activation(c) - *x) + *x; + // current cell output = r◦activation(c) + (1 - r)◦x + h->assign(r * activation(*c) + (1.f - r) * (*x)); + // *h = r * (activation(c) - *x) + *x; } ////////////////////////////////////////////////////////////////////////// -void sruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) { - - // x input [bS x inSize x time] - // c0 initial cell state (at time step = 0) [bS x inSize], - // w weights, [3*inSize x inSize] - // b biases, [2*inSize] +void sruTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* c0, const NDArray* w, const NDArray* b, + NDArray* h, NDArray* c) { + // x input [bS x inSize x time] + // c0 initial cell state (at time step = 0) [bS x inSize], + // w weights, [3*inSize x inSize] + // b biases, [2*inSize] - // h cell outputs [bS x inSize x time] - // c cell states [bS x inSize x time] + // h cell outputs [bS x inSize x time] + // c cell states [bS x inSize x time] - auto wT = w->transpose(); // [3*inSize x inSize] -> [inSize x 3*inSize] + auto wT = w->transpose(); // [3*inSize x inSize] -> [inSize x 3*inSize] - const int time = x->sizeAt(2); + const int time = x->sizeAt(2); - NDArray ct_1(*c0); + NDArray ct_1(*c0); - // loop through time steps - for (int t = 0; t < time; ++t) { + // loop through time steps + for (int t = 0; t < time; ++t) { + auto xt = (*x)({0, 0, 0, 0, t, t + 1}); + auto ht = (*h)({0, 0, 0, 0, t, t + 1}); + auto ct = (*c)({0, 0, 0, 0, t, t + 1}); - auto xt = (*x)({0,0, 0,0, t,t+1}); - auto ht = (*h)({0,0, 0,0, t,t+1}); - auto ct = (*c)({0,0, 0,0, t,t+1}); - - helpers::sruCell(context, &xt, &ct_1, &wT, b, &ht, &ct); - ct_1.assign(ct); - } + helpers::sruCell(context, &xt, &ct_1, &wT, b, &ht, &ct); + ct_1.assign(ct); + } } - ////////////////////////////////////////////////////////////////////////// template -__global__ static void sruBICuda(const void* vx, const Nd4jLong* xShapeInfo, - const void* vwi, const Nd4jLong* wiShapeInfo, - const void* vb, const Nd4jLong* bShapeInfo, - const void* vc0, const Nd4jLong* c0ShapeInfo, - const void* vmask, const Nd4jLong* maskShapeInfo, - void* vht, const Nd4jLong* htShapeInfo, - void* vct, const Nd4jLong* ctShapeInfo) { - // inputs: - // x [time, bS, 2*K] - // wi [time, bS, 6*K], wi = mmul(x, weights); - // b [4*K] - // c0 [bS, 2*K] - // mask [bS, 2*K], optional - - // outputs - // ht [time, bS, 2*K] - // ct [time, bS, 2*K] - - const auto x = reinterpret_cast(vx); - const auto wi = reinterpret_cast(vwi); - const auto b = reinterpret_cast(vb); - const auto c0 = reinterpret_cast(vc0); - const auto mask = reinterpret_cast(vmask); - auto ht = reinterpret_cast(vht); - auto ct = reinterpret_cast(vct); - - const int rank = 3; - - __shared__ int time, K, *sharedMem; - __shared__ Nd4jLong len, totalThreads; - - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - time = xShapeInfo[1]; - K = xShapeInfo[3] / 2; - len = xShapeInfo[2] * xShapeInfo[3]; // 2*K*bS - - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto coords = sharedMem + threadIdx.x * rank; - - if(tid >= len) - return; - - shape::index2coords(tid, rank - 1, xShapeInfo + 2, coords + 1); // loop through last two dimensions of x : {bS, 2*K} - - const auto maskOffst = mask ? shape::getOffset(maskShapeInfo, coords + 1) : 0; - const auto c0Offset = shape::getOffset(c0ShapeInfo, coords + 1); - const auto bFOffset = shape::getOffset(bShapeInfo, coords + 2); - const auto bROffset = bFOffset + 2 * K * bShapeInfo[2]; // 2*K*b_stride - - const T maskVal = mask ? mask[maskOffst] : static_cast(1); - const T bF = b[bFOffset]; - const T bR = b[bROffset]; - T c0Val = c0[c0Offset]; - - const bool flip = coords[2] >= K; - - if(flip) - coords[0] = time - 1; - else - coords[0] = 0; - - auto xOffset = shape::getOffset(xShapeInfo, coords); - auto htOffset = shape::getOffset(htShapeInfo, coords); - auto ctOffset = shape::getOffset(ctShapeInfo, coords); - - coords[2] *= 3; - auto wiOffset0 = shape::getOffset(wiShapeInfo, coords); - auto wiOffset1 = wiOffset0 + wiShapeInfo[rank + 3]; // add last stride - auto wiOffset2 = wiOffset1 + wiShapeInfo[rank + 3]; // add last stride - - // time loop - for (uint t = 0; t < time; ++t) { - - // evaluate sigmoids - T ft = (1.f)/(1.f + sd::math::nd4j_exp(-(wi[wiOffset1] + bF))); - T rt = (1.f)/(1.f + sd::math::nd4j_exp(-(wi[wiOffset2] + bR))); - - c0Val = (c0Val - wi[wiOffset0]) * ft + wi[wiOffset0]; - ct[ctOffset] = c0Val; - T val = sd::math::nd4j_tanh(c0Val); - T xVal = x[xOffset]; - ht[htOffset] = (val * maskVal - xVal) * rt + xVal; - - if(flip) { - xOffset -= xShapeInfo[rank + 1]; // first stride, corresponds to time step - htOffset -= htShapeInfo[rank + 1]; - ctOffset -= htShapeInfo[rank + 1]; - wiOffset0 -= wiShapeInfo[rank + 1]; - wiOffset1 -= wiShapeInfo[rank + 1]; - wiOffset2 -= wiShapeInfo[rank + 1]; - } - else { - xOffset += xShapeInfo[rank + 1]; // first stride, corresponds to time step - htOffset += htShapeInfo[rank + 1]; - ctOffset += htShapeInfo[rank + 1]; - wiOffset0 += wiShapeInfo[rank + 1]; - wiOffset1 += wiShapeInfo[rank + 1]; - wiOffset2 += wiShapeInfo[rank + 1]; - } +__global__ static void sruBICuda(const void* vx, const Nd4jLong* xShapeInfo, + const void* vwi, const Nd4jLong* wiShapeInfo, + const void* vb, const Nd4jLong* bShapeInfo, + const void* vc0, const Nd4jLong* c0ShapeInfo, + const void* vmask, + const Nd4jLong* maskShapeInfo, void* vht, + const Nd4jLong* htShapeInfo, void* vct, + const Nd4jLong* ctShapeInfo) { + // inputs: + // x [time, bS, 2*K] + // wi [time, bS, 6*K], wi = mmul(x, weights); + // b [4*K] + // c0 [bS, 2*K] + // mask [bS, 2*K], optional + + // outputs + // ht [time, bS, 2*K] + // ct [time, bS, 2*K] + + const auto x = reinterpret_cast(vx); + const auto wi = reinterpret_cast(vwi); + const auto b = reinterpret_cast(vb); + const auto c0 = reinterpret_cast(vc0); + const auto mask = reinterpret_cast(vmask); + auto ht = reinterpret_cast(vht); + auto ct = reinterpret_cast(vct); + + const int rank = 3; + + __shared__ int time, K, *sharedMem; + __shared__ Nd4jLong len, totalThreads; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + time = xShapeInfo[1]; + K = xShapeInfo[3] / 2; + len = xShapeInfo[2] * xShapeInfo[3]; // 2*K*bS + + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto coords = sharedMem + threadIdx.x * rank; + + if (tid >= len) return; + + shape::index2coords( + tid, rank - 1, xShapeInfo + 2, + coords + 1); // loop through last two dimensions of x : {bS, 2*K} + + const auto maskOffst = mask ? shape::getOffset(maskShapeInfo, coords + 1) : 0; + const auto c0Offset = shape::getOffset(c0ShapeInfo, coords + 1); + const auto bFOffset = shape::getOffset(bShapeInfo, coords + 2); + const auto bROffset = bFOffset + 2 * K * bShapeInfo[2]; // 2*K*b_stride + + const T maskVal = mask ? mask[maskOffst] : static_cast(1); + const T bF = b[bFOffset]; + const T bR = b[bROffset]; + T c0Val = c0[c0Offset]; + + const bool flip = coords[2] >= K; + + if (flip) + coords[0] = time - 1; + else + coords[0] = 0; + + auto xOffset = shape::getOffset(xShapeInfo, coords); + auto htOffset = shape::getOffset(htShapeInfo, coords); + auto ctOffset = shape::getOffset(ctShapeInfo, coords); + + coords[2] *= 3; + auto wiOffset0 = shape::getOffset(wiShapeInfo, coords); + auto wiOffset1 = wiOffset0 + wiShapeInfo[rank + 3]; // add last stride + auto wiOffset2 = wiOffset1 + wiShapeInfo[rank + 3]; // add last stride + + // time loop + for (uint t = 0; t < time; ++t) { + // evaluate sigmoids + T ft = (1.f) / (1.f + sd::math::nd4j_exp(-(wi[wiOffset1] + bF))); + T rt = (1.f) / (1.f + sd::math::nd4j_exp(-(wi[wiOffset2] + bR))); + + c0Val = (c0Val - wi[wiOffset0]) * ft + wi[wiOffset0]; + ct[ctOffset] = c0Val; + T val = sd::math::nd4j_tanh(c0Val); + T xVal = x[xOffset]; + ht[htOffset] = (val * maskVal - xVal) * rt + xVal; + + if (flip) { + xOffset -= + xShapeInfo[rank + 1]; // first stride, corresponds to time step + htOffset -= htShapeInfo[rank + 1]; + ctOffset -= htShapeInfo[rank + 1]; + wiOffset0 -= wiShapeInfo[rank + 1]; + wiOffset1 -= wiShapeInfo[rank + 1]; + wiOffset2 -= wiShapeInfo[rank + 1]; + } else { + xOffset += + xShapeInfo[rank + 1]; // first stride, corresponds to time step + htOffset += htShapeInfo[rank + 1]; + ctOffset += htShapeInfo[rank + 1]; + wiOffset0 += wiShapeInfo[rank + 1]; + wiOffset1 += wiShapeInfo[rank + 1]; + wiOffset2 += wiShapeInfo[rank + 1]; } + } } ////////////////////////////////////////////////////////////////////////// template -static void sruBICudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vwi, const Nd4jLong* wiShapeInfo, - const void* vb, const Nd4jLong* bShapeInfo, - const void* vc0, const Nd4jLong* c0ShapeInfo, - const void* vmask, const Nd4jLong* maskShapeInfo, - void* vht, const Nd4jLong* htShapeInfo, - void* vct, const Nd4jLong* ctShapeInfo) { - - sruBICuda<<>>(vx, xShapeInfo, vwi, wiShapeInfo, vb, bShapeInfo, vc0, c0ShapeInfo, vmask, maskShapeInfo, vht, htShapeInfo, vct, ctShapeInfo); +static void sruBICudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vwi, const Nd4jLong* wiShapeInfo, const void* vb, + const Nd4jLong* bShapeInfo, const void* vc0, const Nd4jLong* c0ShapeInfo, + const void* vmask, const Nd4jLong* maskShapeInfo, void* vht, + const Nd4jLong* htShapeInfo, void* vct, const Nd4jLong* ctShapeInfo) { + sruBICuda<<>>( + vx, xShapeInfo, vwi, wiShapeInfo, vb, bShapeInfo, vc0, c0ShapeInfo, vmask, + maskShapeInfo, vht, htShapeInfo, vct, ctShapeInfo); } ////////////////////////////////////////////////////////////////////////// -void sruBI(sd::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) { - - // x = x * mask - if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask - - // U = x * w - NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] - - PointersManager manager(context, "sru_bi"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (x->sizeAt(1) * x->sizeAt(2) + threadsPerBlock - 1) / threadsPerBlock; // loop through last two dimensions of x array -> bS, 2*K - const int sharedMem = threadsPerBlock * sizeof(int) * x->rankOf() + 128; - - NDArray::prepareSpecialUse({ht, ct}, {x, &wi, b, c0, mask}); - BUILD_SINGLE_SELECTOR(x->dataType(), sruBICudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), x->specialBuffer(), x->specialShapeInfo(), wi.specialBuffer(), wi.specialShapeInfo(), b->specialBuffer(), b->specialShapeInfo(), c0->specialBuffer(), c0->specialShapeInfo(), mask ? mask->specialBuffer() : nullptr, mask ? mask->specialShapeInfo() : nullptr, ht->specialBuffer(), ht->specialShapeInfo(), ct->specialBuffer(), ct->specialShapeInfo()), FLOAT_TYPES); - NDArray::registerSpecialUse({ht, ct}, {x, &wi, b, c0, mask}); - - manager.synchronize(); +void sruBI(sd::LaunchContext* context, NDArray* x, const NDArray* w, + const NDArray* b, const NDArray* c0, const NDArray* mask, + NDArray* ht, NDArray* ct) { + // x = x * mask + if (mask) + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask + + // U = x * w + NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] + + PointersManager manager(context, "sru_bi"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (x->sizeAt(1) * x->sizeAt(2) + threadsPerBlock - 1) / + threadsPerBlock; // loop through last two dimensions of x array -> bS, + // 2*K + const int sharedMem = threadsPerBlock * sizeof(int) * x->rankOf() + 128; + + NDArray::prepareSpecialUse({ht, ct}, {x, &wi, b, c0, mask}); + BUILD_SINGLE_SELECTOR( + x->dataType(), sruBICudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + x->specialBuffer(), x->specialShapeInfo(), wi.specialBuffer(), + wi.specialShapeInfo(), b->specialBuffer(), b->specialShapeInfo(), + c0->specialBuffer(), c0->specialShapeInfo(), + mask ? mask->specialBuffer() : nullptr, + mask ? mask->specialShapeInfo() : nullptr, ht->specialBuffer(), + ht->specialShapeInfo(), ct->specialBuffer(), ct->specialShapeInfo()), + FLOAT_TYPES); + NDArray::registerSpecialUse({ht, ct}, {x, &wi, b, c0, mask}); + + manager.synchronize(); } - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - ////////////////////////////////////////////////////////////////////////// template -__global__ static void sruBIBPCuda(const void* vx, const Nd4jLong* xShapeInfo, - const void* vwi, const Nd4jLong* wiShapeInfo, - const void* vb, const Nd4jLong* bShapeInfo, - const void* vc0, const Nd4jLong* c0ShapeInfo, - const void* vmask, const Nd4jLong* maskShapeInfo, - const void* vct, const Nd4jLong* ctShapeInfo, - const void* vgradHt, const Nd4jLong* gradHtShapeInfo, - const void* vgradCt, const Nd4jLong* gradCtShapeInfo, - void* vgradI, const Nd4jLong* gradIShapeInfo, - void* vgradWi, const Nd4jLong* gradWiShapeInfo, - void* vgradB, const Nd4jLong* gradBShapeInfo, - void* vgradC0, const Nd4jLong* gradC0ShapeInfo) { - // inputs: - // x [time, bS, 2*K] - // wi [time, bS, 6*K], wi = mmul(x, weights); - // b [4*K] - // c0 [bS, 2*K] - // mask [bS, 2*K], optional - // ct [time, bS, 2*K] - // gradHt [time, bS, 2*K] - // gradCt [bS, 2*K] - - // outputs - // gradI [time, bS, 2*K] - // gradWi [time, 2*K, 6*K] - // gradB [bS, 4*K] - // gradC0 [bS, 2*K] - - const auto x = reinterpret_cast(vx); - const auto wi = reinterpret_cast(vwi); - const auto b = reinterpret_cast(vb); - const auto c0 = reinterpret_cast(vc0); - const auto mask = reinterpret_cast(vmask); - const auto ct = reinterpret_cast(vct); - const auto gradHt = reinterpret_cast(vgradHt); - const auto gradCt = reinterpret_cast(vgradCt); - - auto gradI = reinterpret_cast(vgradI); - auto gradWi = reinterpret_cast(vgradWi); - auto gradB = reinterpret_cast(vgradB); - auto gradC0 = reinterpret_cast(vgradC0); - - const int rank = 3; - - __shared__ int time, K, *sharedMem; - __shared__ Nd4jLong len, totalThreads; - - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - time = xShapeInfo[1]; - K = xShapeInfo[3] / 2; - len = xShapeInfo[2] * xShapeInfo[3]; // 2*K*bS - - totalThreads = gridDim.x * blockDim.x; - } - - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - auto coords = sharedMem + threadIdx.x * rank; - - if(tid >= len) - return; - - shape::index2coords(tid, rank - 1, xShapeInfo + 2, coords + 1); // loop through last two dimensions of x : {bS, 2*K} - - const auto maskOffst = mask ? shape::getOffset(maskShapeInfo, coords + 1) : 0; - const auto c0Offset = shape::getOffset(c0ShapeInfo, coords + 1); - const auto gradCtOffset = shape::getOffset(gradCtShapeInfo, coords + 1); - const auto gradC0Offset = shape::getOffset(gradC0ShapeInfo, coords + 1); - const auto bFOffset = shape::getOffset(bShapeInfo, coords + 2); - const auto bROffset = bFOffset + 2 * K * bShapeInfo[2]; // 2*K*b_stride - // const auto gradBFOffset = shape::getOffset(gradBShapeInfo, coords + 1); - const auto gradBFOffset = coords[1] * gradBShapeInfo[3] / 2 + coords[2] * gradBShapeInfo[4]; - const auto gradBROffset = gradBFOffset + gradBShapeInfo[3]; - - const bool flip = coords[2] >= K; - - if(flip) - coords[0] = 0; +__global__ static void sruBIBPCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vwi, + const Nd4jLong* wiShapeInfo, const void* vb, const Nd4jLong* bShapeInfo, + const void* vc0, const Nd4jLong* c0ShapeInfo, const void* vmask, + const Nd4jLong* maskShapeInfo, const void* vct, const Nd4jLong* ctShapeInfo, + const void* vgradHt, const Nd4jLong* gradHtShapeInfo, const void* vgradCt, + const Nd4jLong* gradCtShapeInfo, void* vgradI, + const Nd4jLong* gradIShapeInfo, void* vgradWi, + const Nd4jLong* gradWiShapeInfo, void* vgradB, + const Nd4jLong* gradBShapeInfo, void* vgradC0, + const Nd4jLong* gradC0ShapeInfo) { + // inputs: + // x [time, bS, 2*K] + // wi [time, bS, 6*K], wi = mmul(x, weights); + // b [4*K] + // c0 [bS, 2*K] + // mask [bS, 2*K], optional + // ct [time, bS, 2*K] + // gradHt [time, bS, 2*K] + // gradCt [bS, 2*K] + + // outputs + // gradI [time, bS, 2*K] + // gradWi [time, 2*K, 6*K] + // gradB [bS, 4*K] + // gradC0 [bS, 2*K] + + const auto x = reinterpret_cast(vx); + const auto wi = reinterpret_cast(vwi); + const auto b = reinterpret_cast(vb); + const auto c0 = reinterpret_cast(vc0); + const auto mask = reinterpret_cast(vmask); + const auto ct = reinterpret_cast(vct); + const auto gradHt = reinterpret_cast(vgradHt); + const auto gradCt = reinterpret_cast(vgradCt); + + auto gradI = reinterpret_cast(vgradI); + auto gradWi = reinterpret_cast(vgradWi); + auto gradB = reinterpret_cast(vgradB); + auto gradC0 = reinterpret_cast(vgradC0); + + const int rank = 3; + + __shared__ int time, K, *sharedMem; + __shared__ Nd4jLong len, totalThreads; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + time = xShapeInfo[1]; + K = xShapeInfo[3] / 2; + len = xShapeInfo[2] * xShapeInfo[3]; // 2*K*bS + + totalThreads = gridDim.x * blockDim.x; + } + + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto coords = sharedMem + threadIdx.x * rank; + + if (tid >= len) return; + + shape::index2coords( + tid, rank - 1, xShapeInfo + 2, + coords + 1); // loop through last two dimensions of x : {bS, 2*K} + + const auto maskOffst = mask ? shape::getOffset(maskShapeInfo, coords + 1) : 0; + const auto c0Offset = shape::getOffset(c0ShapeInfo, coords + 1); + const auto gradCtOffset = shape::getOffset(gradCtShapeInfo, coords + 1); + const auto gradC0Offset = shape::getOffset(gradC0ShapeInfo, coords + 1); + const auto bFOffset = shape::getOffset(bShapeInfo, coords + 2); + const auto bROffset = bFOffset + 2 * K * bShapeInfo[2]; // 2*K*b_stride + // const auto gradBFOffset = shape::getOffset(gradBShapeInfo, coords + 1); + const auto gradBFOffset = + coords[1] * gradBShapeInfo[3] / 2 + coords[2] * gradBShapeInfo[4]; + const auto gradBROffset = gradBFOffset + gradBShapeInfo[3]; + + const bool flip = coords[2] >= K; + + if (flip) + coords[0] = 0; + else + coords[0] = time - 1; + + auto xOffset = shape::getOffset(xShapeInfo, coords); + auto ctOffset = shape::getOffset(ctShapeInfo, coords); + auto gradIOffset = shape::getOffset(gradIShapeInfo, coords); + auto gradHtOffset = shape::getOffset(gradHtShapeInfo, coords); + + coords[2] *= 3; + auto gradWiOffset0 = shape::getOffset(gradWiShapeInfo, coords); + auto gradWiOffset1 = + gradWiOffset0 + gradWiShapeInfo[rank + 3]; // add last stride + auto gradWiOffset2 = + gradWiOffset1 + gradWiShapeInfo[rank + 3]; // add last stride + auto wiOffset0 = shape::getOffset(wiShapeInfo, coords); + auto wiOffset1 = wiOffset0 + wiShapeInfo[rank + 3]; // add last stride + auto wiOffset2 = wiOffset1 + wiShapeInfo[rank + 3]; // add last stride + + const T xVal = x[xOffset]; + const T maskVal = mask ? mask[maskOffst] : static_cast(1); + const T c0Val = c0[c0Offset]; + const T bF = b[bFOffset]; + const T bR = b[bROffset]; + T gradCtVal = gradCt[gradCtOffset]; + T gbF = 0.f; + T gbR = 0.f; + + // time loop + for (uint t = 0; t < time; ++t) { + // evaluate sigmoids + T ft = (1.f) / (1.f + sd::math::nd4j_exp(-(wi[wiOffset1] + bF))); + T rt = (1.f) / (1.f + sd::math::nd4j_exp(-(wi[wiOffset2] + bR))); + + T val = sd::math::nd4j_tanh(ct[ctOffset]); + + T prevVal; + if (t < time - 1) + prevVal = + ct[ctOffset += flip ? ctShapeInfo[rank + 1] : -ctShapeInfo[rank + 1]]; else - coords[0] = time - 1; - - auto xOffset = shape::getOffset(xShapeInfo, coords); - auto ctOffset = shape::getOffset(ctShapeInfo, coords); - auto gradIOffset = shape::getOffset(gradIShapeInfo, coords); - auto gradHtOffset = shape::getOffset(gradHtShapeInfo, coords); - - coords[2] *= 3; - auto gradWiOffset0 = shape::getOffset(gradWiShapeInfo, coords); - auto gradWiOffset1 = gradWiOffset0 + gradWiShapeInfo[rank + 3]; // add last stride - auto gradWiOffset2 = gradWiOffset1 + gradWiShapeInfo[rank + 3]; // add last stride - auto wiOffset0 = shape::getOffset(wiShapeInfo, coords); - auto wiOffset1 = wiOffset0 + wiShapeInfo[rank + 3]; // add last stride - auto wiOffset2 = wiOffset1 + wiShapeInfo[rank + 3]; // add last stride - - const T xVal = x[xOffset]; - const T maskVal = mask ? mask[maskOffst] : static_cast(1); - const T c0Val = c0[c0Offset]; - const T bF = b[bFOffset]; - const T bR = b[bROffset]; - T gradCtVal = gradCt[gradCtOffset]; - T gbF = 0.f; - T gbR = 0.f; - - // time loop - for (uint t = 0; t < time; ++t) { - - // evaluate sigmoids - T ft = (1.f)/(1.f + sd::math::nd4j_exp(-(wi[wiOffset1] + bF))); - T rt = (1.f)/(1.f + sd::math::nd4j_exp(-(wi[wiOffset2] + bR))); - - T val = sd::math::nd4j_tanh(ct[ctOffset]); - - T prevVal; - if(t < time-1) - prevVal = ct[ctOffset += flip ? ctShapeInfo[rank + 1] : -ctShapeInfo[rank + 1]]; - else - prevVal = c0Val; - - // grad wrt input - gradI[gradIOffset] = gradHt[gradHtOffset] - gradHt[gradHtOffset] * rt ; - - // grad wrt rt, wiR and bR - T grt = gradHt[gradHtOffset] * (val * maskVal - x[xOffset]) * (rt - rt * rt); - gradWi[gradWiOffset2] = grt; - gbR += grt; - - // grad wrt state - T gradC0Val = gradHt[gradHtOffset] * maskVal * (rt - rt * val * val) + gradCtVal; - - // grad wrt wi0 - gradWi[gradWiOffset0] = gradC0Val - gradC0Val * ft; - - // grad wrt ft, wi1, and bF - T gft = gradC0Val * (prevVal - wi[wiOffset0]) * (ft - ft * ft); - gradWi[gradWiOffset1] = gft; - gbF += gft; - - // grad wrt c_previous - gradCtVal = gradC0Val * ft; - - if(flip) { - xOffset += xShapeInfo[rank + 1]; // first stride, corresponds to time step - gradHtOffset += gradHtShapeInfo[rank + 1]; - gradIOffset += gradIShapeInfo[rank + 1]; - wiOffset0 += wiShapeInfo[rank + 1]; - wiOffset1 += wiShapeInfo[rank + 1]; - wiOffset2 += wiShapeInfo[rank + 1]; - gradWiOffset0 += gradWiShapeInfo[rank + 1]; - gradWiOffset1 += gradWiShapeInfo[rank + 1]; - gradWiOffset2 += gradWiShapeInfo[rank + 1]; - } - else { - xOffset -= xShapeInfo[rank + 1]; // first stride, corresponds to time step - gradHtOffset -= gradHtShapeInfo[rank + 1]; - gradIOffset -= gradIShapeInfo[rank + 1]; - wiOffset0 -= wiShapeInfo[rank + 1]; - wiOffset1 -= wiShapeInfo[rank + 1]; - wiOffset2 -= wiShapeInfo[rank + 1]; - gradWiOffset0 -= gradWiShapeInfo[rank + 1]; - gradWiOffset1 -= gradWiShapeInfo[rank + 1]; - gradWiOffset2 -= gradWiShapeInfo[rank + 1]; - } + prevVal = c0Val; + + // grad wrt input + gradI[gradIOffset] = gradHt[gradHtOffset] - gradHt[gradHtOffset] * rt; + + // grad wrt rt, wiR and bR + T grt = + gradHt[gradHtOffset] * (val * maskVal - x[xOffset]) * (rt - rt * rt); + gradWi[gradWiOffset2] = grt; + gbR += grt; + + // grad wrt state + T gradC0Val = + gradHt[gradHtOffset] * maskVal * (rt - rt * val * val) + gradCtVal; + + // grad wrt wi0 + gradWi[gradWiOffset0] = gradC0Val - gradC0Val * ft; + + // grad wrt ft, wi1, and bF + T gft = gradC0Val * (prevVal - wi[wiOffset0]) * (ft - ft * ft); + gradWi[gradWiOffset1] = gft; + gbF += gft; + + // grad wrt c_previous + gradCtVal = gradC0Val * ft; + + if (flip) { + xOffset += + xShapeInfo[rank + 1]; // first stride, corresponds to time step + gradHtOffset += gradHtShapeInfo[rank + 1]; + gradIOffset += gradIShapeInfo[rank + 1]; + wiOffset0 += wiShapeInfo[rank + 1]; + wiOffset1 += wiShapeInfo[rank + 1]; + wiOffset2 += wiShapeInfo[rank + 1]; + gradWiOffset0 += gradWiShapeInfo[rank + 1]; + gradWiOffset1 += gradWiShapeInfo[rank + 1]; + gradWiOffset2 += gradWiShapeInfo[rank + 1]; + } else { + xOffset -= + xShapeInfo[rank + 1]; // first stride, corresponds to time step + gradHtOffset -= gradHtShapeInfo[rank + 1]; + gradIOffset -= gradIShapeInfo[rank + 1]; + wiOffset0 -= wiShapeInfo[rank + 1]; + wiOffset1 -= wiShapeInfo[rank + 1]; + wiOffset2 -= wiShapeInfo[rank + 1]; + gradWiOffset0 -= gradWiShapeInfo[rank + 1]; + gradWiOffset1 -= gradWiShapeInfo[rank + 1]; + gradWiOffset2 -= gradWiShapeInfo[rank + 1]; } + } - gradB[gradBFOffset] = gbF; - gradB[gradBROffset] = gbR; - gradC0[gradC0Offset] = gradCtVal; + gradB[gradBFOffset] = gbF; + gradB[gradBROffset] = gbR; + gradC0[gradC0Offset] = gradCtVal; } ////////////////////////////////////////////////////////////////////////// template -static void sruBIBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, - const void* vwi, const Nd4jLong* wiShapeInfo, - const void* vb, const Nd4jLong* bShapeInfo, - const void* vc0, const Nd4jLong* c0ShapeInfo, - const void* vmask, const Nd4jLong* maskShapeInfo, - const void* vct, const Nd4jLong* ctShapeInfo, - const void* vgradHt, const Nd4jLong* gradHtShapeInfo, - const void* vgradCt, const Nd4jLong* gradCtShapeInfo, - void* vgradI, const Nd4jLong* gradIShapeInfo, - void* vgradWi, const Nd4jLong* gradWiShapeInfo, - void* vgradB, const Nd4jLong* gradBShapeInfo, - void* vgradC0, const Nd4jLong* gradC0ShapeInfo) { - - sruBIBPCuda<<>>(vx, xShapeInfo, vwi, wiShapeInfo, vb, bShapeInfo, vc0, c0ShapeInfo, vmask, maskShapeInfo, vct, ctShapeInfo, vgradHt, gradHtShapeInfo, vgradCt, gradCtShapeInfo, vgradI, gradIShapeInfo, vgradWi, gradWiShapeInfo, vgradB, gradBShapeInfo, vgradC0, gradC0ShapeInfo); +static void sruBIBPCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vwi, const Nd4jLong* wiShapeInfo, const void* vb, + const Nd4jLong* bShapeInfo, const void* vc0, const Nd4jLong* c0ShapeInfo, + const void* vmask, const Nd4jLong* maskShapeInfo, const void* vct, + const Nd4jLong* ctShapeInfo, const void* vgradHt, + const Nd4jLong* gradHtShapeInfo, const void* vgradCt, + const Nd4jLong* gradCtShapeInfo, void* vgradI, + const Nd4jLong* gradIShapeInfo, void* vgradWi, + const Nd4jLong* gradWiShapeInfo, void* vgradB, + const Nd4jLong* gradBShapeInfo, void* vgradC0, + const Nd4jLong* gradC0ShapeInfo) { + sruBIBPCuda<<>>( + vx, xShapeInfo, vwi, wiShapeInfo, vb, bShapeInfo, vc0, c0ShapeInfo, vmask, + maskShapeInfo, vct, ctShapeInfo, vgradHt, gradHtShapeInfo, vgradCt, + gradCtShapeInfo, vgradI, gradIShapeInfo, vgradWi, gradWiShapeInfo, vgradB, + gradBShapeInfo, vgradC0, gradC0ShapeInfo); } -BUILD_SINGLE_TEMPLATE(template void sruBIBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vwi, const Nd4jLong* wiShapeInfo, const void* vb, const Nd4jLong* bShapeInfo, const void* vc0, const Nd4jLong* c0ShapeInfo, const void* vmask, const Nd4jLong* maskShapeInfo, const void* vct, const Nd4jLong* ctShapeInfo, const void* vgradHt, const Nd4jLong* gradHtShapeInfo, const void* vgradCt, const Nd4jLong* gradCtShapeInfo, void* vgradI, const Nd4jLong* gradIShapeInfo, void* vgradWi, const Nd4jLong* gradWiShapeInfo, void* vgradB, const Nd4jLong* gradBShapeInfo, void* vgradC0, const Nd4jLong* gradC0ShapeInfo), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void sruBIBPCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, + const void* vwi, const Nd4jLong* wiShapeInfo, + const void* vb, const Nd4jLong* bShapeInfo, + const void* vc0, const Nd4jLong* c0ShapeInfo, + const void* vmask, const Nd4jLong* maskShapeInfo, + const void* vct, const Nd4jLong* ctShapeInfo, + const void* vgradHt, const Nd4jLong* gradHtShapeInfo, + const void* vgradCt, const Nd4jLong* gradCtShapeInfo, + void* vgradI, const Nd4jLong* gradIShapeInfo, + void* vgradWi, const Nd4jLong* gradWiShapeInfo, + void* vgradB, const Nd4jLong* gradBShapeInfo, + void* vgradC0, const Nd4jLong* gradC0ShapeInfo), + FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// -void sruBIBP(sd::LaunchContext* context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, - const NDArray* gradCt, const NDArray* gradHt, const NDArray* mask, - NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) { - - // x = x * mask - if(mask) - x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask - - // U = x * w - NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] - - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int K = x->sizeAt(2) / 2; - - NDArray gradBias(x->ordering(), {bS, 4*K}, x->dataType(), context); - NDArray gradWi (x->ordering(), {time, bS, 6*K}, x->dataType(), context); - - PointersManager manager(context, "sru_bi_bp"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (x->sizeAt(1) * x->sizeAt(2) + threadsPerBlock - 1) / threadsPerBlock; // loop through last two dimensions of x array -> bS, 2*K - const int sharedMem = threadsPerBlock * sizeof(int) * x->rankOf() + 128; - - NDArray::prepareSpecialUse({gradI, &gradWi, &gradBias, gradC0}, {x, &wi, b, c0, ct, gradCt, gradHt, mask}); - BUILD_SINGLE_SELECTOR(x->dataType(), sruBIBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), x->specialBuffer(), x->specialShapeInfo(), wi.specialBuffer(), wi.specialShapeInfo(), b->specialBuffer(), b->specialShapeInfo(), c0->specialBuffer(), c0->specialShapeInfo(), mask ? mask->specialBuffer() : nullptr, mask ? mask->specialShapeInfo() : nullptr, ct->specialBuffer(), ct->specialShapeInfo(), gradHt->specialBuffer(), gradHt->specialShapeInfo(), gradCt->specialBuffer(), gradCt->specialShapeInfo(), gradI->specialBuffer(), gradI->specialShapeInfo(), gradWi.specialBuffer(), gradWi.specialShapeInfo(), gradBias.specialBuffer(), gradBias.specialShapeInfo(), gradC0->specialBuffer(), gradC0->specialShapeInfo()), FLOAT_TYPES); - NDArray::registerSpecialUse({gradI, &gradWi, &gradBias, gradC0}, {x, &wi, b, c0, ct, gradCt, gradHt, mask}); - - manager.synchronize(); - - // gradB - gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K] - - // gradW - x->permutei({0, 2, 1}); // [time, bS, 2*K] -> [time, 2*K, bS] - MmulHelper::mmul(x, &gradWi, gradW, 1., 0.); // [time, 2*K, bS] x [time, bS , 6*K] = [time, 2*K, 6*K] +void sruBIBP(sd::LaunchContext* context, NDArray* x, const NDArray* w, + const NDArray* b, const NDArray* c0, const NDArray* ct, + const NDArray* gradCt, const NDArray* gradHt, const NDArray* mask, + NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) { + // x = x * mask + if (mask) + x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x); // apply mask + + // U = x * w + NDArray wi = mmul(*x, *w); // U [time x bS x 6*K] + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int K = x->sizeAt(2) / 2; + + NDArray gradBias(x->ordering(), {bS, 4 * K}, x->dataType(), context); + NDArray gradWi(x->ordering(), {time, bS, 6 * K}, x->dataType(), context); + + PointersManager manager(context, "sru_bi_bp"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (x->sizeAt(1) * x->sizeAt(2) + threadsPerBlock - 1) / + threadsPerBlock; // loop through last two dimensions of x array -> bS, + // 2*K + const int sharedMem = threadsPerBlock * sizeof(int) * x->rankOf() + 128; + + NDArray::prepareSpecialUse({gradI, &gradWi, &gradBias, gradC0}, + {x, &wi, b, c0, ct, gradCt, gradHt, mask}); + BUILD_SINGLE_SELECTOR( + x->dataType(), sruBIBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + x->specialBuffer(), x->specialShapeInfo(), wi.specialBuffer(), + wi.specialShapeInfo(), b->specialBuffer(), b->specialShapeInfo(), + c0->specialBuffer(), c0->specialShapeInfo(), + mask ? mask->specialBuffer() : nullptr, + mask ? mask->specialShapeInfo() : nullptr, ct->specialBuffer(), + ct->specialShapeInfo(), gradHt->specialBuffer(), + gradHt->specialShapeInfo(), gradCt->specialBuffer(), + gradCt->specialShapeInfo(), gradI->specialBuffer(), + gradI->specialShapeInfo(), gradWi.specialBuffer(), + gradWi.specialShapeInfo(), gradBias.specialBuffer(), + gradBias.specialShapeInfo(), gradC0->specialBuffer(), + gradC0->specialShapeInfo()), + FLOAT_TYPES); + NDArray::registerSpecialUse({gradI, &gradWi, &gradBias, gradC0}, + {x, &wi, b, c0, ct, gradCt, gradHt, mask}); + + manager.synchronize(); + + // gradB + gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0}); // [4*K] + + // gradW + x->permutei({0, 2, 1}); // [time, bS, 2*K] -> [time, 2*K, bS] + MmulHelper::mmul( + x, &gradWi, gradW, 1., + 0.); // [time, 2*K, bS] x [time, bS , 6*K] = [time, 2*K, 6*K] } - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/stack.cu b/libnd4j/include/ops/declarable/helpers/cuda/stack.cu index f0983b76c9db..f899da6db46d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/stack.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/stack.cu @@ -18,196 +18,216 @@ // Created by Yurii Shyrma on 02.01.2018 // -#include -#include #include #include -#include -#include #include +#include +#include +#include +#include namespace sd { namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// template -static __global__ void stackScalarsCuda(void* pVx, void* vz, const Nd4jLong* zShapeInfo) { +static __global__ void stackScalarsCuda(void* pVx, void* vz, + const Nd4jLong* zShapeInfo) { + T* z = reinterpret_cast(vz); - T* z = reinterpret_cast(vz); + __shared__ Nd4jLong zLen, totalThreads; - __shared__ Nd4jLong zLen, totalThreads; - - if (threadIdx.x == 0) { - zLen = shape::length(zShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); + if (threadIdx.x == 0) { + zLen = shape::length(zShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { - - const T *x = reinterpret_cast(reinterpret_cast(pVx)[i]); - z[shape::getIndexOffset(i, zShapeInfo)] = *x; - } + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + const T* x = reinterpret_cast(reinterpret_cast(pVx)[i]); + z[shape::getIndexOffset(i, zShapeInfo)] = *x; + } } - /////////////////////////////////////////////////////////////////// -template -__host__ static void stackScalarsCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - void* pVx, void* vz, const Nd4jLong* zShapeInfo) { - - stackScalarsCuda<<>>(pVx, vz, zShapeInfo); +template +__host__ static void stackScalarsCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + const cudaStream_t* stream, + void* pVx, void* vz, + const Nd4jLong* zShapeInfo) { + stackScalarsCuda + <<>>(pVx, vz, zShapeInfo); } /////////////////////////////////////////////////////////////////// template -static void stack_(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output, const int dim) { +static void stack_(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output, + const int dim) { + const int numOfSubArrs = inArrs.size(); - const int numOfSubArrs = inArrs.size(); + NDArray::prepareSpecialUse({&output}, inArrs); - NDArray::prepareSpecialUse({&output}, inArrs); + if (inArrs[0]->rankOf() == 0) { + std::vector hInBuffers(numOfSubArrs); - if(inArrs[0]->rankOf() == 0) { + for (int i = 0; i < numOfSubArrs; ++i) + hInBuffers[i] = inArrs[i]->specialBuffer(); - std::vector hInBuffers(numOfSubArrs); + PointersManager manager(context, "helpers::stack cuda"); - for(int i = 0; i < numOfSubArrs; ++i) - hInBuffers[i] = inArrs[i]->specialBuffer(); + void* dInBuffers = manager.replicatePointer( + hInBuffers.data(), hInBuffers.size() * sizeof(void*)); - PointersManager manager(context, "helpers::stack cuda"); - - void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*)); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - stackScalarsCudaLauncher(blocksPerGrid, threadsPerBlock, context->getCudaStream(), dInBuffers, output.specialBuffer(), output.specialShapeInfo()); - - manager.synchronize(); - } - else { + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - auto zTadPack = ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), ShapeUtils::evalDimsToExclude(output.rankOf(), {dim})); - auto zTadShapeInfo = zTadPack.primaryShapeInfo(); + stackScalarsCudaLauncher( + blocksPerGrid, threadsPerBlock, context->getCudaStream(), dInBuffers, + output.specialBuffer(), output.specialShapeInfo()); - for (uint i = 0; i < numOfSubArrs; ++i) { + manager.synchronize(); + } else { + auto zTadPack = ConstantTadHelper::getInstance()->tadForDimensions( + output.shapeInfo(), + ShapeUtils::evalDimsToExclude(output.rankOf(), {dim})); + auto zTadShapeInfo = zTadPack.primaryShapeInfo(); - void* zBuff = output.specialBufferWithOffset(zTadPack.primaryOffsets()[i]); + for (uint i = 0; i < numOfSubArrs; ++i) { + void* zBuff = + output.specialBufferWithOffset(zTadPack.primaryOffsets()[i]); - NativeOpExecutioner::execTransformAny(context, transform::Assign, - nullptr, inArrs[i]->shapeInfo(), inArrs[i]->specialBuffer(), inArrs[i]->specialShapeInfo(), - nullptr, zTadShapeInfo, zBuff, zTadPack.specialShapeInfo(), - nullptr, nullptr, nullptr, false/*allowParallelism*/); - } + NativeOpExecutioner::execTransformAny( + context, transform::Assign, nullptr, inArrs[i]->shapeInfo(), + inArrs[i]->specialBuffer(), inArrs[i]->specialShapeInfo(), nullptr, + zTadShapeInfo, zBuff, zTadPack.specialShapeInfo(), nullptr, nullptr, + nullptr, false /*allowParallelism*/); } + } - NDArray::registerSpecialUse({&output}, inArrs); + NDArray::registerSpecialUse({&output}, inArrs); } //////////////////////////////////////////////////////////////////////// -void stack(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output, const int dim) { - BUILD_SINGLE_SELECTOR(output.dataType(), stack_, (context, inArrs, output, dim), LIBND4J_TYPES); +void stack(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output, + const int dim) { + BUILD_SINGLE_SELECTOR(output.dataType(), stack_, + (context, inArrs, output, dim), LIBND4J_TYPES); } -BUILD_SINGLE_TEMPLATE(template void stack_ , (sd::LaunchContext* context, const std::vector& inArrs, NDArray& output, const int dim), LIBND4J_TYPES); - +BUILD_SINGLE_TEMPLATE(template void stack_, + (sd::LaunchContext * context, + const std::vector& inArrs, + NDArray& output, const int dim), + LIBND4J_TYPES); /////////////////////////////////////////////////////////////////// template -static __global__ void unstackScalarsCuda(const void* vx, const Nd4jLong* xShapeInfo, void* pVz) { - - const T* x = reinterpret_cast(vx); - - __shared__ Nd4jLong xLen, totalThreads; +static __global__ void unstackScalarsCuda(const void* vx, + const Nd4jLong* xShapeInfo, + void* pVz) { + const T* x = reinterpret_cast(vx); - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); + __shared__ Nd4jLong xLen, totalThreads; - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + __syncthreads(); - for (Nd4jLong i = tid; i < xLen; i += totalThreads) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - T* z = reinterpret_cast(reinterpret_cast(pVz)[i]); - *z = x[shape::getIndexOffset(i, xShapeInfo)]; - } + for (Nd4jLong i = tid; i < xLen; i += totalThreads) { + T* z = reinterpret_cast(reinterpret_cast(pVz)[i]); + *z = x[shape::getIndexOffset(i, xShapeInfo)]; + } } - /////////////////////////////////////////////////////////////////// -template -__host__ static void unstackScalarsCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, void* pVz) { - - unstackScalarsCuda<<>>(vx, xShapeInfo, pVz); +template +__host__ static void unstackScalarsCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + const cudaStream_t* stream, + const void* vx, + const Nd4jLong* xShapeInfo, + void* pVz) { + unstackScalarsCuda + <<>>(vx, xShapeInfo, pVz); } /////////////////////////////////////////////////////////////////// template -static void unstack_(sd::LaunchContext* context, const NDArray& input, const std::vector& outArrs, const int dim) { - - const int numOfSubArrs = outArrs.size(); +static void unstack_(sd::LaunchContext* context, const NDArray& input, + const std::vector& outArrs, const int dim) { + const int numOfSubArrs = outArrs.size(); - // NDArray::prepareSpecialUse(outArrs, {&input}); - input.syncToDevice(); - for (const auto a : outArrs) - a->getDataBuffer()->allocateSpecial(); + // NDArray::prepareSpecialUse(outArrs, {&input}); + input.syncToDevice(); + for (const auto a : outArrs) a->getDataBuffer()->allocateSpecial(); + if (outArrs[0]->rankOf() == 0) { + std::vector hOutBuffers(numOfSubArrs); - if(outArrs[0]->rankOf() == 0) { + for (int i = 0; i < numOfSubArrs; ++i) + hOutBuffers[i] = outArrs[i]->specialBuffer(); - std::vector hOutBuffers(numOfSubArrs); + PointersManager manager(context, "helpers::unstack cuda"); - for(int i = 0; i < numOfSubArrs; ++i) - hOutBuffers[i] = outArrs[i]->specialBuffer(); + void* dOutBuffers = manager.replicatePointer( + hOutBuffers.data(), hOutBuffers.size() * sizeof(void*)); - PointersManager manager(context, "helpers::unstack cuda"); + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = + (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*)); + unstackScalarsCudaLauncher( + blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), dOutBuffers); - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + manager.synchronize(); + } else { + auto xTadPack = ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), + ShapeUtils::evalDimsToExclude(input.rankOf(), {dim})); + auto xTadShapeInfo = xTadPack.primaryShapeInfo(); - unstackScalarsCudaLauncher(blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), dOutBuffers); + for (uint i = 0; i < numOfSubArrs; ++i) { + auto xBuff = input.specialBufferWithOffset(xTadPack.primaryOffsets()[i]); - manager.synchronize(); + NativeOpExecutioner::execTransformAny( + input.getContext(), transform::Assign, nullptr, xTadShapeInfo, xBuff, + xTadPack.specialShapeInfo(), nullptr, outArrs[i]->shapeInfo(), + outArrs[i]->specialBuffer(), outArrs[i]->specialShapeInfo(), nullptr, + nullptr, nullptr, false /*allowParallelism*/); } - else { - - auto xTadPack = ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), ShapeUtils::evalDimsToExclude(input.rankOf(), {dim})); - auto xTadShapeInfo = xTadPack.primaryShapeInfo(); - - for (uint i = 0; i < numOfSubArrs; ++i) { + } - auto xBuff = input.specialBufferWithOffset(xTadPack.primaryOffsets()[i]); - - NativeOpExecutioner::execTransformAny(input.getContext(), transform::Assign, - nullptr, xTadShapeInfo, xBuff, xTadPack.specialShapeInfo(), - nullptr, outArrs[i]->shapeInfo(), outArrs[i]->specialBuffer(), outArrs[i]->specialShapeInfo(), - nullptr, nullptr, nullptr, false/*allowParallelism*/); - } - } - - // NDArray::registerSpecialUse(outArrs, {&input}); - input.tickReadDevice(); - for (const auto p : outArrs) - p->tickWriteDevice(); + // NDArray::registerSpecialUse(outArrs, {&input}); + input.tickReadDevice(); + for (const auto p : outArrs) p->tickWriteDevice(); } //////////////////////////////////////////////////////////////////////// -void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector& outArrs, const int dim) { - BUILD_SINGLE_SELECTOR(input.dataType(), unstack_, (context, input, outArrs, dim), LIBND4J_TYPES); +void unstack(sd::LaunchContext* context, const NDArray& input, + const std::vector& outArrs, const int dim) { + BUILD_SINGLE_SELECTOR(input.dataType(), unstack_, + (context, input, outArrs, dim), LIBND4J_TYPES); } -BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const NDArray& input, const std::vector& outArrs, const int dim), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void unstack_, + (sd::LaunchContext * context, const NDArray& input, + const std::vector& outArrs, const int dim), + LIBND4J_TYPES); /////////////////////////////////////////////////////////////////// // template -// static __global__ void unstackCuda(const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) { +// static __global__ void unstackCuda(const void* vx, const Nd4jLong* +// xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) { // const T* x = reinterpret_cast(vx); // __shared__ Nd4jLong xLen, totalThreads; @@ -230,10 +250,11 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // const auto xOffset = shape::getOffset(xShapeInfo, coords); -// T *z = reinterpret_cast(reinterpret_cast(pVz)[coords[axis]]); +// T *z = reinterpret_cast(reinterpret_cast(pVz)[coords[axis]]); -// for (uint j = axis; j < xRank - 1; ++j) // shift coords staring from axis position -// coords[j] = coords[j + 1]; +// for (uint j = axis; j < xRank - 1; ++j) // shift coords staring +// from axis position coords[j] = coords[j + 1]; // const auto zOffset = shape::getOffset(zTadShapeInfo, coords); @@ -243,19 +264,25 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // /////////////////////////////////////////////////////////////////// // template -// __host__ static void unstackCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, -// const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) { +// __host__ static void unstackCudaLauncher(const int blocksPerGrid, const int +// threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* +// xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) { -// unstackCuda<<>>(vx, xShapeInfo, pVz, zTadShapeInfo, axis); +// unstackCuda<<>>(vx, +// xShapeInfo, pVz, zTadShapeInfo, axis); // } -// BUILD_SINGLE_TEMPLATE(template void unstackCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis), LIBND4J_TYPES); - +// BUILD_SINGLE_TEMPLATE(template void unstackCudaLauncher, (const int +// blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const +// void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* +// zTadShapeInfo, const int axis), LIBND4J_TYPES); // /////////////////////////////////////////////////////////////////// -// void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector& outArrs, const int axis) { +// void unstack(sd::LaunchContext* context, const NDArray& input, const +// std::vector& outArrs, const int axis) { // const int threadsPerBlock = MAX_NUM_THREADS / 2; -// const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; +// const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / +// threadsPerBlock; // const int numOfSubArrs = outArrs.size(); @@ -266,13 +293,17 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // PointersManager manager(context, "helpers::unstack"); -// void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*)); +// void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), +// hOutBuffers.size() * sizeof(void*)); // for(uint i = 0; i < numOfSubArrs; ++i) // outArrs[i]->syncToDevice(); // input.syncToDevice(); -// BUILD_SINGLE_SELECTOR(input.dataType(), unstackCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), dOutBuffers, outArrs[0]->specialShapeInfo(), axis), LIBND4J_TYPES); +// BUILD_SINGLE_SELECTOR(input.dataType(), unstackCudaLauncher, +// (blocksPerGrid, threadsPerBlock, context->getCudaStream(), +// input.specialBuffer(), input.specialShapeInfo(), dOutBuffers, +// outArrs[0]->specialShapeInfo(), axis), LIBND4J_TYPES); // manager.synchronize(); @@ -281,10 +312,10 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // input.tickWriteDevice(); // } - // /////////////////////////////////////////////////////////////////// // template -// static __global__ void stackCuda(void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) { +// static __global__ void stackCuda(void* pVx, const Nd4jLong* xTadShapeInfo, +// void* vz, const Nd4jLong* zShapeInfo, const int axis) { // T* z = reinterpret_cast(vz); @@ -308,10 +339,11 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // const auto zOffset = shape::getOffset(zShapeInfo, coords); -// const T *x = reinterpret_cast(reinterpret_cast(pVx)[coords[axis]]); +// const T *x = reinterpret_cast(reinterpret_cast(pVx)[coords[axis]]); -// for (uint j = axis; j < zRank - 1; ++j) // shift coords staring from axis position -// coords[j] = coords[j + 1]; +// for (uint j = axis; j < zRank - 1; ++j) // shift coords staring +// from axis position coords[j] = coords[j + 1]; // const auto xOffset = shape::getOffset(xTadShapeInfo, coords); @@ -321,19 +353,25 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // /////////////////////////////////////////////////////////////////// // template -// __host__ static void stackCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, -// void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) { +// __host__ static void stackCudaLauncher(const int blocksPerGrid, const int +// threadsPerBlock, const cudaStream_t *stream, void* pVx, const Nd4jLong* +// xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) { -// stackCuda<<>>(pVx, xTadShapeInfo, vz, zShapeInfo, axis); +// stackCuda<<>>(pVx, +// xTadShapeInfo, vz, zShapeInfo, axis); // } -// BUILD_SINGLE_TEMPLATE(template void stackCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis), LIBND4J_TYPES); - +// BUILD_SINGLE_TEMPLATE(template void stackCudaLauncher, (const int +// blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, void* +// pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, +// const int axis), LIBND4J_TYPES); // /////////////////////////////////////////////////////////////////// -// void stack(sd::LaunchContext* context, const std::vector& inArrs, NDArray& output, const int axis) { +// void stack(sd::LaunchContext* context, const std::vector& +// inArrs, NDArray& output, const int axis) { // const int threadsPerBlock = MAX_NUM_THREADS / 2; -// const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; +// const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / +// threadsPerBlock; // const int numOfSubArrs = inArrs.size(); @@ -344,13 +382,17 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // PointersManager manager(context, "helpers::stack"); -// void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*)); +// void* dInBuffers = manager.replicatePointer(hInBuffers.data(), +// hInBuffers.size() * sizeof(void*)); // for(uint i = 0; i < numOfSubArrs; ++i) // inArrs[i]->syncToDevice(); // output.syncToDevice(); -// BUILD_SINGLE_SELECTOR(output.dataType(), stackCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), dInBuffers, inArrs[0]->specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), axis), LIBND4J_TYPES); +// BUILD_SINGLE_SELECTOR(output.dataType(), stackCudaLauncher, +// (blocksPerGrid, threadsPerBlock, context->getCudaStream(), dInBuffers, +// inArrs[0]->specialShapeInfo(), output.specialBuffer(), +// output.specialShapeInfo(), axis), LIBND4J_TYPES); // manager.synchronize(); @@ -359,7 +401,6 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const // output.tickWriteDevice(); // } -} -} -} - +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu index 33dd0251a5fd..bb0668cb95db 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu @@ -18,660 +18,728 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include #include +#include #include #include #include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - -// FIXME -> we should optimize these helpers for the case when input matrices have c order (perform transpositions appropriately) +// FIXME -> we should optimize these helpers for the case when input matrices +// have c order (perform transpositions appropriately) template -__global__ static void inverseColumnSignCuda(void* vu, const Nd4jLong* uShapeInfo, void* vv, const Nd4jLong* vShapeInfo) { - - T* u = reinterpret_cast(vu); - T* v = reinterpret_cast(vv); - - __shared__ int rank, uLastButOneColumn, vLastButOneColumn; // uRank = vRank - __shared__ Nd4jLong uLen, vLen; - __shared__ Nd4jLong *sharedMem; - - if (threadIdx.x == 0) { +__global__ static void inverseColumnSignCuda(void* vu, + const Nd4jLong* uShapeInfo, + void* vv, + const Nd4jLong* vShapeInfo) { + T* u = reinterpret_cast(vu); + T* v = reinterpret_cast(vv); - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + __shared__ int rank, uLastButOneColumn, vLastButOneColumn; // uRank = vRank + __shared__ Nd4jLong uLen, vLen; + __shared__ Nd4jLong* sharedMem; - rank = shape::rank(uShapeInfo); - uLen = shape::length(uShapeInfo); - vLen = shape::length(vShapeInfo); + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - uLastButOneColumn = uShapeInfo[rank] - 2; - vLastButOneColumn = vShapeInfo[rank - 1] - 2; - } - - __syncthreads(); + rank = shape::rank(uShapeInfo); + uLen = shape::length(uShapeInfo); + vLen = shape::length(vShapeInfo); - const auto ind = threadIdx.x + blockIdx.x * blockDim.x; + uLastButOneColumn = uShapeInfo[rank] - 2; + vLastButOneColumn = vShapeInfo[rank - 1] - 2; + } - auto coords = sharedMem + threadIdx.x * rank; + __syncthreads(); - // u - for (Nd4jLong i = ind; i < uLen; i += gridDim.x * blockDim.x) { + const auto ind = threadIdx.x + blockIdx.x * blockDim.x; - shape::index2coords(i, uShapeInfo, coords); + auto coords = sharedMem + threadIdx.x * rank; - if(coords[rank - 1] == 0 || coords[rank - 1] == uLastButOneColumn) // do not change sign in first and last but one columns - continue; + // u + for (Nd4jLong i = ind; i < uLen; i += gridDim.x * blockDim.x) { + shape::index2coords(i, uShapeInfo, coords); - const auto uOffset = shape::getOffset(uShapeInfo, coords); + if (coords[rank - 1] == 0 || + coords[rank - 1] == uLastButOneColumn) // do not change sign in first + // and last but one columns + continue; - u[uOffset] = -u[uOffset]; - } + const auto uOffset = shape::getOffset(uShapeInfo, coords); - // v - for (Nd4jLong i = ind; i < vLen; i += gridDim.x * blockDim.x) { + u[uOffset] = -u[uOffset]; + } - shape::index2coords(i, vShapeInfo, coords); + // v + for (Nd4jLong i = ind; i < vLen; i += gridDim.x * blockDim.x) { + shape::index2coords(i, vShapeInfo, coords); - if(coords[rank - 2] == 0 || coords[rank - 2] == vLastButOneColumn) // do not change sign in first and last but one columns - continue; + if (coords[rank - 2] == 0 || + coords[rank - 2] == vLastButOneColumn) // do not change sign in first + // and last but one columns + continue; - const auto vOffset = shape::getOffset(vShapeInfo, coords); + const auto vOffset = shape::getOffset(vShapeInfo, coords); - v[vOffset] = -v[vOffset]; - } + v[vOffset] = -v[vOffset]; + } } ////////////////////////////////////////////////////////////////////////// template -static void inverseColumnSignCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - void* vu, const Nd4jLong* uShapeInfo, - void* vv, const Nd4jLong* vShapeInfo) { - - inverseColumnSignCuda<<>>(vu, uShapeInfo, vv, vShapeInfo); +static void inverseColumnSignCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, + const int sharedMem, + const cudaStream_t* stream, void* vu, + const Nd4jLong* uShapeInfo, void* vv, + const Nd4jLong* vShapeInfo) { + inverseColumnSignCuda + <<>>(vu, uShapeInfo, + vv, vShapeInfo); } -BUILD_SINGLE_TEMPLATE(template void inverseColumnSignCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, void* vu, const Nd4jLong* uShapeInfo, void* vv, const Nd4jLong* vShapeInfo), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void inverseColumnSignCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const int sharedMem, const cudaStream_t* stream, + void* vu, const Nd4jLong* uShapeInfo, void* vv, + const Nd4jLong* vShapeInfo), + FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// -static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* VT, const bool fullUV, const bool calcUV) { - - // since cusa api cusolverDnDgesvd/cusolverDnSgesvd have following constrain on input matrix A: A_rows >= A_columns && A_order = 'f' - // we make this function to have deal with 2 valid cases only: - // 1) A_rows >= A_columns and A_corder = 'f' - // 2) A_rows <= A_columns and A_corder = 'c' - int this case perform transposition to get f order - // if 1) or 2) are not met then throw exception - - // A [m, n] - // S [n] - // U [m, m] or [m, n] if fullUV = false and m > n - // VT [n, n] or [m, n] if fullUV = false and m < n - - if(A->rankOf() != 2) - throw std::runtime_error("svdQR: rank of A array is not equal 2 !"); - - auto m = A->sizeAt(0); - auto n = A->sizeAt(1); - const int minDim = m < n ? m : n; - const char orderA = A->ordering(); - - if(m < n) - throw std::runtime_error("svdQR: due to cuda api input constrains given shape of A array are not valid !"); - - if(std::vector({minDim}) != S->getShapeAsVector()) - throw std::runtime_error("svdQR: wrong shape of S array !"); - - if(calcUV) { - - if(fullUV && std::vector({m,m}) != U->getShapeAsVector()) - throw std::runtime_error("svdQR: wrong shape of U array !"); - else if(!fullUV && std::vector({m,minDim}) != U->getShapeAsVector()) - throw std::runtime_error("svdQR: wrong shape of U array !"); - - if(fullUV && std::vector({n,n}) != VT->getShapeAsVector()) - throw std::runtime_error("svdQR: wrong shape of VT array !"); - else if(!fullUV && std::vector({minDim,n}) != VT->getShapeAsVector()) - throw std::runtime_error("svdQR: wrong shape of VT array !"); - } - - NDArray* pA = const_cast(A); - NDArray* pS = S; - NDArray* pU = U; - NDArray* pVT = VT; - - std::vector toDelete; - - if(pA->ews() != 1 || pA->ordering() == 'c') { - pA = new NDArray(A->dup('f')); - toDelete.push_back(pA); - } - - if(S->ews() != 1) { - pS = new NDArray(S->dup('f')); - toDelete.push_back(pS); +static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, + NDArray* U, NDArray* VT, const bool fullUV, + const bool calcUV) { + // since cusa api cusolverDnDgesvd/cusolverDnSgesvd have following constrain + // on input matrix A: A_rows >= A_columns && A_order = 'f' we make this + // function to have deal with 2 valid cases only: 1) A_rows >= A_columns and + // A_corder = 'f' 2) A_rows <= A_columns and A_corder = 'c' - int this case + // perform transposition to get f order if 1) or 2) are not met then throw + // exception + + // A [m, n] + // S [n] + // U [m, m] or [m, n] if fullUV = false and m > n + // VT [n, n] or [m, n] if fullUV = false and m < n + + if (A->rankOf() != 2) + throw std::runtime_error("svdQR: rank of A array is not equal 2 !"); + + auto m = A->sizeAt(0); + auto n = A->sizeAt(1); + const int minDim = m < n ? m : n; + const char orderA = A->ordering(); + + if (m < n) + throw std::runtime_error( + "svdQR: due to cuda api input constrains given shape of A array are " + "not valid !"); + + if (std::vector({minDim}) != S->getShapeAsVector()) + throw std::runtime_error("svdQR: wrong shape of S array !"); + + if (calcUV) { + if (fullUV && std::vector({m, m}) != U->getShapeAsVector()) + throw std::runtime_error("svdQR: wrong shape of U array !"); + else if (!fullUV && + std::vector({m, minDim}) != U->getShapeAsVector()) + throw std::runtime_error("svdQR: wrong shape of U array !"); + + if (fullUV && std::vector({n, n}) != VT->getShapeAsVector()) + throw std::runtime_error("svdQR: wrong shape of VT array !"); + else if (!fullUV && + std::vector({minDim, n}) != VT->getShapeAsVector()) + throw std::runtime_error("svdQR: wrong shape of VT array !"); + } + + NDArray* pA = const_cast(A); + NDArray* pS = S; + NDArray* pU = U; + NDArray* pVT = VT; + + std::vector toDelete; + + if (pA->ews() != 1 || pA->ordering() == 'c') { + pA = new NDArray(A->dup('f')); + toDelete.push_back(pA); + } + + if (S->ews() != 1) { + pS = new NDArray(S->dup('f')); + toDelete.push_back(pS); + } + + if (calcUV) { + if (pU->ews() != 1 || pU->ordering() == 'c') { + pU = new NDArray(U->dup('f')); + toDelete.push_back(pU); } - if(calcUV) { - - if(pU->ews() != 1 || pU->ordering() == 'c') { - pU = new NDArray(U->dup('f')); - toDelete.push_back(pU); - } - - if(pVT->ews() != 1 || pVT->ordering() == 'c') { - pVT = new NDArray(VT->dup('f')); - toDelete.push_back(pVT); - } + if (pVT->ews() != 1 || pVT->ordering() == 'c') { + pVT = new NDArray(VT->dup('f')); + toDelete.push_back(pVT); } - - std::lock_guard lock(*LaunchContext::deviceMutex()); - - // create cusolverDn handle - cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr; - //cusolverStatus_t status = cusolverDnCreate(&handle); - if(handle == nullptr) - throw cuda_exception::build("svdQR: cuda failed !", -1); - - // stream - auto status = cusolverDnSetStream(*handle, *context->getCudaStream()); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdQR: cuda failed !", status); - - // query working space of SVD - int lwork = 0; - if(A->dataType() == DataType::DOUBLE) - status = cusolverDnDgesvd_bufferSize(*handle, m, n, &lwork); - else if(A->dataType() == DataType::FLOAT32) - status = cusolverDnSgesvd_bufferSize(*handle, m, n, &lwork); + } + + std::lock_guard lock(*LaunchContext::deviceMutex()); + + // create cusolverDn handle + cusolverDnHandle_t* handle = + (cusolverDnHandle_t*)context->getCusolverHandle(); // nullptr; + // cusolverStatus_t status = cusolverDnCreate(&handle); + if (handle == nullptr) + throw cuda_exception::build("svdQR: cuda failed !", -1); + + // stream + auto status = cusolverDnSetStream(*handle, *context->getCudaStream()); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdQR: cuda failed !", status); + + // query working space of SVD + int lwork = 0; + if (A->dataType() == DataType::DOUBLE) + status = cusolverDnDgesvd_bufferSize(*handle, m, n, &lwork); + else if (A->dataType() == DataType::FLOAT32) + status = cusolverDnSgesvd_bufferSize(*handle, m, n, &lwork); + else + throw std::invalid_argument("svdQR: given data type is unsupported !"); + + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdQR: cuda failed !", status); + + // allocate memory for dWork + void* dWork = nullptr; + cudaError_t status2 = cudaMalloc((void**)&dWork, A->sizeOfT() * lwork); + if (status2 != cudaSuccess) + throw cuda_exception::build("svdQR: cuda failed !", status2); + + signed char jobu, jobvt; + + if (calcUV) { + if (fullUV) + jobu = jobvt = 'A'; else - throw std::invalid_argument("svdQR: given data type is unsupported !"); - - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdQR: cuda failed !", status); - - // allocate memory for dWork - void* dWork = nullptr; - cudaError_t status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork); - if(status2 != cudaSuccess) - throw cuda_exception::build("svdQR: cuda failed !", status2); - - signed char jobu, jobvt; - - if(calcUV) { - if(fullUV) - jobu = jobvt = 'A'; - else - jobu = jobvt = 'S'; - } - else { - jobu = jobvt = 'N'; - } - - int *devInfo = nullptr; - void* rWork = nullptr; - - int lda(m), ldu, ldvt; - - if(calcUV) { - ldu = pU->sizeAt(0); - ldvt = pVT->sizeAt(0); - } - - PointersManager manager(context, "svdQR"); - - NDArray::prepareSpecialUse({pS, pU, pVT}, {pA}); - - // choose appropriate cuda gemm api depending on data types - if(A->dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->specialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); - } - else if(A->dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->specialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); - } - else - throw std::invalid_argument("svdQR: given data type is unsupported !"); - - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdQR: cuda failed !", status); - - manager.synchronize(); - - NDArray::registerSpecialUse({pS, pU, pVT}, {pA}); - - S->assign(pS); - - if(calcUV) { - U->assign(pU); - VT->assign(pVT); - } - - for (int i = toDelete.size() - 1; i >= 0; --i) - delete toDelete[i]; - - if (devInfo) - cudaFree(devInfo); - if (dWork ) - cudaFree(dWork); - if (rWork) - cudaFree(rWork); - -// if(handle) -// cusolverDnDestroy(handle); - - // cudaDeviceReset(); + jobu = jobvt = 'S'; + } else { + jobu = jobvt = 'N'; + } + + int* devInfo = nullptr; + void* rWork = nullptr; + + int lda(m), ldu, ldvt; + + if (calcUV) { + ldu = pU->sizeAt(0); + ldvt = pVT->sizeAt(0); + } + + PointersManager manager(context, "svdQR"); + + NDArray::prepareSpecialUse({pS, pU, pVT}, {pA}); + + // choose appropriate cuda gemm api depending on data types + if (A->dataType() == DataType::DOUBLE) { + status = cusolverDnDgesvd( + *handle, jobu, jobvt, m, n, + reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, + calcUV ? reinterpret_cast(pVT->specialBuffer()) : nullptr, + ldvt, reinterpret_cast(dWork), lwork, + reinterpret_cast(rWork), devInfo); + } else if (A->dataType() == DataType::FLOAT32) { + status = cusolverDnSgesvd( + *handle, jobu, jobvt, m, n, + reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, + calcUV ? reinterpret_cast(pVT->specialBuffer()) : nullptr, ldvt, + reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), + devInfo); + } else + throw std::invalid_argument("svdQR: given data type is unsupported !"); + + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdQR: cuda failed !", status); + + manager.synchronize(); + + NDArray::registerSpecialUse({pS, pU, pVT}, {pA}); + + S->assign(pS); + + if (calcUV) { + U->assign(pU); + VT->assign(pVT); + } + + for (int i = toDelete.size() - 1; i >= 0; --i) delete toDelete[i]; + + if (devInfo) cudaFree(devInfo); + if (dWork) cudaFree(dWork); + if (rWork) cudaFree(rWork); + + // if(handle) + // cusolverDnDestroy(handle); + + // cudaDeviceReset(); } ////////////////////////////////////////////////////////////////////////// -static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* V, const bool fullUV, const bool calcUV) { - - // A [m, n] - // S [n] - // U [m, m] or [m, n] if fullUV = false and m > n - // V [n, n] or [n, m] if fullUV = false and m < n +static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, + NDArray* U, NDArray* V, const bool fullUV, + const bool calcUV) { + // A [m, n] + // S [n] + // U [m, m] or [m, n] if fullUV = false and m > n + // V [n, n] or [n, m] if fullUV = false and m < n - if(A->rankOf() != 2) - throw std::runtime_error("svdJcb: rank of A array is not equal 2 !"); + if (A->rankOf() != 2) + throw std::runtime_error("svdJcb: rank of A array is not equal 2 !"); - int m = A->sizeAt(0); - int n = A->sizeAt(1); - const int minDim = m < n ? m : n; + int m = A->sizeAt(0); + int n = A->sizeAt(1); + const int minDim = m < n ? m : n; - if(std::vector({minDim}) != S->getShapeAsVector()) - throw std::runtime_error("svdJcb: wrong shape of S array !"); + if (std::vector({minDim}) != S->getShapeAsVector()) + throw std::runtime_error("svdJcb: wrong shape of S array !"); - if(calcUV) { + if (calcUV) { + if (fullUV && std::vector({m, m}) != U->getShapeAsVector()) + throw std::runtime_error("svdJcb: wrong shape of U array !"); + else if (!fullUV && + std::vector({m, minDim}) != U->getShapeAsVector()) + throw std::runtime_error("svdJcb: wrong shape of U array !"); - if(fullUV && std::vector({m,m}) != U->getShapeAsVector()) - throw std::runtime_error("svdJcb: wrong shape of U array !"); - else if(!fullUV && std::vector({m,minDim}) != U->getShapeAsVector()) - throw std::runtime_error("svdJcb: wrong shape of U array !"); + if (fullUV && std::vector({n, n}) != V->getShapeAsVector()) + throw std::runtime_error("svdJcb: wrong shape of V array !"); + else if (!fullUV && + std::vector({n, minDim}) != V->getShapeAsVector()) + throw std::runtime_error("svdJcb: wrong shape of V array !"); + } - if(fullUV && std::vector({n,n}) != V->getShapeAsVector()) - throw std::runtime_error("svdJcb: wrong shape of V array !"); - else if(!fullUV && std::vector({n,minDim}) != V->getShapeAsVector()) - throw std::runtime_error("svdJcb: wrong shape of V array !"); - } + NDArray* pA = const_cast(A); - NDArray* pA = const_cast(A); + const bool aForder = m == 1 || A->strideAt(0) == 1; + const bool aCorder = n == 1 || A->strideAt(1) == 1; - const bool aForder = m == 1 || A->strideAt(0) == 1; - const bool aCorder = n == 1 || A->strideAt(1) == 1; + const bool transA = !aForder && aCorder; + const bool dupA = !aForder && !aCorder; - const bool transA = !aForder && aCorder; - const bool dupA = !aForder && !aCorder; + std::vector toDelete; - std::vector toDelete; + if (dupA) { + pA = new NDArray(A->dup('f')); + toDelete.push_back(pA); + } - if(dupA) { - pA = new NDArray(A->dup('f')); - toDelete.push_back(pA); - } - - NDArray* pS = S; - - if(S->ews() != 1) { - pS = new NDArray(S->dup('f')); - toDelete.push_back(pS); - } + NDArray* pS = S; - NDArray *pU(nullptr), *pV(nullptr); + if (S->ews() != 1) { + pS = new NDArray(S->dup('f')); + toDelete.push_back(pS); + } - int lda = transA ? pA->strideAt(0) : pA->strideAt(1); - int ldu(transA ? n : m), ldv(transA ? m : n); - bool uForder(true), vForder(true); + NDArray *pU(nullptr), *pV(nullptr); - if(calcUV) { + int lda = transA ? pA->strideAt(0) : pA->strideAt(1); + int ldu(transA ? n : m), ldv(transA ? m : n); + bool uForder(true), vForder(true); - pU = transA ? V : U; - pV = transA ? U : V; + if (calcUV) { + pU = transA ? V : U; + pV = transA ? U : V; - uForder = pU->sizeAt(0) == 1 || pU->strideAt(0) == 1; - vForder = pV->sizeAt(0) == 1 || pV->strideAt(0) == 1; + uForder = pU->sizeAt(0) == 1 || pU->strideAt(0) == 1; + vForder = pV->sizeAt(0) == 1 || pV->strideAt(0) == 1; - if(!uForder) { - pU = new NDArray(pU->dup('f')); - toDelete.push_back(pU); - } - - if(!vForder) { - pV = new NDArray(pV->dup('f')); - toDelete.push_back(pV); - } - - ldu = pU->strideAt(1); - ldv = pV->strideAt(1); - } - - std::lock_guard lock(*LaunchContext::deviceMutex()); - - // create cusolverDn handle - cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle(); - //cusolverStatus_t status = cusolverDnCreate(&handle); - if(handle == nullptr) - throw cuda_exception::build("svdJcb: cuda failed !", -1); - - // stream - auto status = cusolverDnSetStream(*handle, *context->getCudaStream()); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdJcb: cuda failed !", status); - - // set parameters - gesvdjInfo_t gesvdjParams = nullptr; - status = cusolverDnCreateGesvdjInfo(&gesvdjParams); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdJcb: cuda failed !", status); - status = cusolverDnXgesvdjSetTolerance(gesvdjParams, 1.e-7); // tolerance - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdJcb: cuda failed !", status); - status = cusolverDnXgesvdjSetMaxSweeps(gesvdjParams, 15); // max_sweeps - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdJcb: cuda failed !", status); - - int *devInfo = nullptr; - const cusolverEigMode_t jobz = calcUV ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; - const int econ = !fullUV; - - if(transA) - math::nd4j_swap(m, n); - - // *** avoid bug in cuda API *** - void* nullPtr = nullptr; - NDArray* arrToAvoidBugInAPI = nullptr; - if(!calcUV && m != n) { - int maxDim = m > n ? m : n; - arrToAvoidBugInAPI = new NDArray('c', {maxDim, maxDim}, pA->dataType(), context); - nullPtr = arrToAvoidBugInAPI->specialBuffer(); + if (!uForder) { + pU = new NDArray(pU->dup('f')); + toDelete.push_back(pU); } - // ****************** - - NDArray::prepareSpecialUse({pS, pU, pV}, {pA}); - - // query working space of SVD - int lwork = 0; - if(A->dataType() == DataType::DOUBLE) - status = cusolverDnDgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); - else if(A->dataType() == DataType::FLOAT32) - status = cusolverDnSgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); - else - throw std::invalid_argument("svdJcb: given data type is unsupported !"); - - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdJcb: cuda failed !", status); - // allocate memory dWork - void* dWork = nullptr; - auto status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork); - if(status2 != cudaSuccess) - throw cuda_exception::build("svdJcb: cuda failed !", status2); - - PointersManager manager(context, "svdJcb"); - - // choose appropriate cuda gemm api depending on data types - if(A->dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvdj(*handle, jobz, econ, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + if (!vForder) { + pV = new NDArray(pV->dup('f')); + toDelete.push_back(pV); } - else if(A->dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvdj(*handle, jobz, econ, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); - } - else - throw std::invalid_argument("svdJcb: given data type is unsupported !"); - - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdJcb: cuda failed !", status); - - manager.synchronize(); - - NDArray::registerSpecialUse({pS, pU, pV}, {pA}); - - if(S->ews() != 1) - S->assign(pS); - if(calcUV) { - - if(!uForder) - U->assign(transA ? pV : pU); - if(!vForder) - V->assign(transA ? pU : pV); - } - - if(!calcUV && m != n) - delete arrToAvoidBugInAPI; - - for (int i = toDelete.size() - 1; i >= 0; --i) - delete toDelete[i]; - - if (devInfo) - cudaFree(devInfo); - if (dWork ) - cudaFree(dWork); -// if(handle) -// cusolverDnDestroy(handle); - if(gesvdjParams) - cusolverDnDestroyGesvdjInfo(gesvdjParams); - - // cudaDeviceReset(); + ldu = pU->strideAt(1); + ldv = pV->strideAt(1); + } + + std::lock_guard lock(*LaunchContext::deviceMutex()); + + // create cusolverDn handle + cusolverDnHandle_t* handle = + (cusolverDnHandle_t*)context->getCusolverHandle(); + // cusolverStatus_t status = cusolverDnCreate(&handle); + if (handle == nullptr) + throw cuda_exception::build("svdJcb: cuda failed !", -1); + + // stream + auto status = cusolverDnSetStream(*handle, *context->getCudaStream()); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdJcb: cuda failed !", status); + + // set parameters + gesvdjInfo_t gesvdjParams = nullptr; + status = cusolverDnCreateGesvdjInfo(&gesvdjParams); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdJcb: cuda failed !", status); + status = cusolverDnXgesvdjSetTolerance(gesvdjParams, 1.e-7); // tolerance + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdJcb: cuda failed !", status); + status = cusolverDnXgesvdjSetMaxSweeps(gesvdjParams, 15); // max_sweeps + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdJcb: cuda failed !", status); + + int* devInfo = nullptr; + const cusolverEigMode_t jobz = + calcUV ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; + const int econ = !fullUV; + + if (transA) math::nd4j_swap(m, n); + + // *** avoid bug in cuda API *** + void* nullPtr = nullptr; + NDArray* arrToAvoidBugInAPI = nullptr; + if (!calcUV && m != n) { + int maxDim = m > n ? m : n; + arrToAvoidBugInAPI = + new NDArray('c', {maxDim, maxDim}, pA->dataType(), context); + nullPtr = arrToAvoidBugInAPI->specialBuffer(); + } + // ****************** + + NDArray::prepareSpecialUse({pS, pU, pV}, {pA}); + + // query working space of SVD + int lwork = 0; + if (A->dataType() == DataType::DOUBLE) + status = cusolverDnDgesvdj_bufferSize( + *handle, jobz, econ, m, n, + reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) + : reinterpret_cast(nullPtr), + ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) + : reinterpret_cast(nullPtr), + ldv, &lwork, gesvdjParams); + else if (A->dataType() == DataType::FLOAT32) + status = cusolverDnSgesvdj_bufferSize( + *handle, jobz, econ, m, n, + reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) + : reinterpret_cast(nullPtr), + ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) + : reinterpret_cast(nullPtr), + ldv, &lwork, gesvdjParams); + else + throw std::invalid_argument("svdJcb: given data type is unsupported !"); + + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdJcb: cuda failed !", status); + + // allocate memory dWork + void* dWork = nullptr; + auto status2 = cudaMalloc((void**)&dWork, A->sizeOfT() * lwork); + if (status2 != cudaSuccess) + throw cuda_exception::build("svdJcb: cuda failed !", status2); + + PointersManager manager(context, "svdJcb"); + + // choose appropriate cuda gemm api depending on data types + if (A->dataType() == DataType::DOUBLE) { + status = cusolverDnDgesvdj( + *handle, jobz, econ, m, n, + reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) + : reinterpret_cast(nullPtr), + ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) + : reinterpret_cast(nullPtr), + ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + } else if (A->dataType() == DataType::FLOAT32) { + status = cusolverDnSgesvdj( + *handle, jobz, econ, m, n, + reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) + : reinterpret_cast(nullPtr), + ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) + : reinterpret_cast(nullPtr), + ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + } else + throw std::invalid_argument("svdJcb: given data type is unsupported !"); + + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdJcb: cuda failed !", status); + + manager.synchronize(); + + NDArray::registerSpecialUse({pS, pU, pV}, {pA}); + + if (S->ews() != 1) S->assign(pS); + + if (calcUV) { + if (!uForder) U->assign(transA ? pV : pU); + if (!vForder) V->assign(transA ? pU : pV); + } + + if (!calcUV && m != n) delete arrToAvoidBugInAPI; + + for (int i = toDelete.size() - 1; i >= 0; --i) delete toDelete[i]; + + if (devInfo) cudaFree(devInfo); + if (dWork) cudaFree(dWork); + // if(handle) + // cusolverDnDestroy(handle); + if (gesvdjParams) cusolverDnDestroyGesvdjInfo(gesvdjParams); + + // cudaDeviceReset(); } ////////////////////////////////////////////////////////////////////////// -static void svdBatched(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* V, const bool fullUV, const bool calcUV) { - - // A [..., m, n] - // S [..., n] - // U [..., m, m] or [..., m, n] if fullUV = false and m > n - // V [..., n, n] or [..., n, m] if fullUV = false and m < n - - auto m = A->sizeAt(-2); - auto n = A->sizeAt(-1); - const int minDim = m < n ? m : n; - const Nd4jLong bS = A->lengthOf() / (m * n); - - if(m > 32 || n > 32) - throw std::runtime_error("svdBatched: numbers of rows and columns should be <= 32 !"); - - if(minDim != S->sizeAt(-1)) - throw std::runtime_error("svdBatched: wrong shape of S array !"); - - if(calcUV) { - - if(U->sizeAt(-2) != m) - throw std::runtime_error("svdBatched: wrong shape of U array !"); - if(U->sizeAt(-1) != (fullUV ? m : minDim)) - throw std::runtime_error("svdBatched: wrong shape of U array !"); - if(U->lengthOf() / (U->sizeAt(-2) * U->sizeAt(-1)) != bS) - throw std::runtime_error("svdBatched: wrong shape of U array !"); - - if(V->sizeAt(-2) != n) - throw std::runtime_error("svdBatched: wrong shape of V array !"); - if(V->sizeAt(-1) != (fullUV ? n : minDim)) - throw std::runtime_error("svdBatched: wrong shape of V array !"); - if(V->lengthOf() / (V->sizeAt(-2) * V->sizeAt(-1)) != bS) - throw std::runtime_error("svdBatched: wrong shape of V array !"); - } - - NDArray* pA = const_cast(A); - NDArray* pS = S; - NDArray* pU = U; - NDArray* pV = V; - - std::vector toDelete; - - if(pA->ews() != 1 || pA->ordering() == 'c') { - pA = new NDArray(A->dup('f')); - toDelete.push_back(pA); - } - - if(S->ews() != 1) { - pS = new NDArray(S->dup('f')); - toDelete.push_back(pS); - } - - if(calcUV) { - - if(pU->ews() != 1 || pU->ordering() == 'c') { - pU = new NDArray(U->dup('f')); - toDelete.push_back(pU); - } - - if(pV->ews() != 1 || pV->ordering() == 'c') { - pV = new NDArray(V->dup('f')); - toDelete.push_back(pV); - } - } - - // create cusolverDn handle - cusolverDnHandle_t handle = nullptr; - cusolverStatus_t status = cusolverDnCreate(&handle); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdBatched: cuda failed !", status); - - // stream - status = cusolverDnSetStream(handle, *context->getCudaStream()); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdBatched: cuda failed !", status); - - // set parameters - gesvdjInfo_t gesvdjParams = nullptr; - status = cusolverDnCreateGesvdjInfo(&gesvdjParams); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdBatched: cuda failed !", status); - status = cusolverDnXgesvdjSetTolerance(gesvdjParams, 1.e-7); // tolerance - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdBatched: cuda failed !", status); - status = cusolverDnXgesvdjSetMaxSweeps(gesvdjParams, 15); // max_sweeps - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdBatched: cuda failed !", status); - - // devInfo - int *devInfo = nullptr; - auto status2 = cudaMalloc((void**)&devInfo, sizeof(int) * bS); - if(status2 != cudaSuccess) - throw cuda_exception::build("svdBatched: cuda failed !", status2); - status2 = cudaDeviceSynchronize(); - if(status2 != cudaSuccess) - throw cuda_exception::build("svdJcb: cuda failed !", status2); - - const cusolverEigMode_t jobz = calcUV ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; - - int lda(m), ldu, ldv; - - if(calcUV) { - ldu = pU->sizeAt(-2); - ldv = pV->sizeAt(-2); +static void svdBatched(sd::LaunchContext* context, const NDArray* A, NDArray* S, + NDArray* U, NDArray* V, const bool fullUV, + const bool calcUV) { + // A [..., m, n] + // S [..., n] + // U [..., m, m] or [..., m, n] if fullUV = false and m > n + // V [..., n, n] or [..., n, m] if fullUV = false and m < n + + auto m = A->sizeAt(-2); + auto n = A->sizeAt(-1); + const int minDim = m < n ? m : n; + const Nd4jLong bS = A->lengthOf() / (m * n); + + if (m > 32 || n > 32) + throw std::runtime_error( + "svdBatched: numbers of rows and columns should be <= 32 !"); + + if (minDim != S->sizeAt(-1)) + throw std::runtime_error("svdBatched: wrong shape of S array !"); + + if (calcUV) { + if (U->sizeAt(-2) != m) + throw std::runtime_error("svdBatched: wrong shape of U array !"); + if (U->sizeAt(-1) != (fullUV ? m : minDim)) + throw std::runtime_error("svdBatched: wrong shape of U array !"); + if (U->lengthOf() / (U->sizeAt(-2) * U->sizeAt(-1)) != bS) + throw std::runtime_error("svdBatched: wrong shape of U array !"); + + if (V->sizeAt(-2) != n) + throw std::runtime_error("svdBatched: wrong shape of V array !"); + if (V->sizeAt(-1) != (fullUV ? n : minDim)) + throw std::runtime_error("svdBatched: wrong shape of V array !"); + if (V->lengthOf() / (V->sizeAt(-2) * V->sizeAt(-1)) != bS) + throw std::runtime_error("svdBatched: wrong shape of V array !"); + } + + NDArray* pA = const_cast(A); + NDArray* pS = S; + NDArray* pU = U; + NDArray* pV = V; + + std::vector toDelete; + + if (pA->ews() != 1 || pA->ordering() == 'c') { + pA = new NDArray(A->dup('f')); + toDelete.push_back(pA); + } + + if (S->ews() != 1) { + pS = new NDArray(S->dup('f')); + toDelete.push_back(pS); + } + + if (calcUV) { + if (pU->ews() != 1 || pU->ordering() == 'c') { + pU = new NDArray(U->dup('f')); + toDelete.push_back(pU); } - // Ak (i,j) = A[i + 5*j + 25*k] - - // query working space of SVD - int lwork = 0; - if(A->dataType() == DataType::DOUBLE) - status = cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, &lwork, gesvdjParams, bS); - else if(A->dataType() == DataType::FLOAT32) - status = cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, &lwork, gesvdjParams, bS); - else - throw std::invalid_argument("svdBatched: given data type is unsupported !"); - - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdBatched: cuda failed !", status); - - // allocate memory dWork - void* dWork = nullptr; - status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork); - if(status2 != cudaSuccess) - throw cuda_exception::build("svdBatched: cuda failed !", status2); - status2 = cudaDeviceSynchronize(); - if(status2 != cudaSuccess) - throw cuda_exception::build("svdBatched: cuda failed !", status2); - - PointersManager manager(context, "svdBatched"); - - NDArray::prepareSpecialUse({pS, pU, pV}, {pA}); - - // choose appropriate cuda gemm api depending on data types - if(A->dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvdjBatched(handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams, bS); + if (pV->ews() != 1 || pV->ordering() == 'c') { + pV = new NDArray(V->dup('f')); + toDelete.push_back(pV); } - else if(A->dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvdjBatched(handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, reinterpret_cast(pS->specialBuffer()), calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams, bS); - } - else - throw std::invalid_argument("svdBatched: given data type is unsupported !"); - - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdBatched: cuda failed !", status); - - manager.synchronize(); - - NDArray::registerSpecialUse({pS, pU, pV}, {pA}); - - S->assign(pS); - - if(calcUV) { - U->assign(pU); - V->assign(pV); - } - - for (int i = toDelete.size() - 1; i >= 0; --i) - delete toDelete[i]; - - if (devInfo) - cudaFree(devInfo); - if (dWork ) - cudaFree(dWork); - if(handle) - cusolverDnDestroy(handle); - if(gesvdjParams) - cusolverDnDestroyGesvdjInfo(gesvdjParams); - - // cudaDeviceReset(); + } + + // create cusolverDn handle + cusolverDnHandle_t handle = nullptr; + cusolverStatus_t status = cusolverDnCreate(&handle); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdBatched: cuda failed !", status); + + // stream + status = cusolverDnSetStream(handle, *context->getCudaStream()); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdBatched: cuda failed !", status); + + // set parameters + gesvdjInfo_t gesvdjParams = nullptr; + status = cusolverDnCreateGesvdjInfo(&gesvdjParams); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdBatched: cuda failed !", status); + status = cusolverDnXgesvdjSetTolerance(gesvdjParams, 1.e-7); // tolerance + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdBatched: cuda failed !", status); + status = cusolverDnXgesvdjSetMaxSweeps(gesvdjParams, 15); // max_sweeps + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdBatched: cuda failed !", status); + + // devInfo + int* devInfo = nullptr; + auto status2 = cudaMalloc((void**)&devInfo, sizeof(int) * bS); + if (status2 != cudaSuccess) + throw cuda_exception::build("svdBatched: cuda failed !", status2); + status2 = cudaDeviceSynchronize(); + if (status2 != cudaSuccess) + throw cuda_exception::build("svdJcb: cuda failed !", status2); + + const cusolverEigMode_t jobz = + calcUV ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; + + int lda(m), ldu, ldv; + + if (calcUV) { + ldu = pU->sizeAt(-2); + ldv = pV->sizeAt(-2); + } + + // Ak (i,j) = A[i + 5*j + 25*k] + + // query working space of SVD + int lwork = 0; + if (A->dataType() == DataType::DOUBLE) + status = cusolverDnDgesvdjBatched_bufferSize( + handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, + &lwork, gesvdjParams, bS); + else if (A->dataType() == DataType::FLOAT32) + status = cusolverDnSgesvdjBatched_bufferSize( + handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, + &lwork, gesvdjParams, bS); + else + throw std::invalid_argument("svdBatched: given data type is unsupported !"); + + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdBatched: cuda failed !", status); + + // allocate memory dWork + void* dWork = nullptr; + status2 = cudaMalloc((void**)&dWork, A->sizeOfT() * lwork); + if (status2 != cudaSuccess) + throw cuda_exception::build("svdBatched: cuda failed !", status2); + status2 = cudaDeviceSynchronize(); + if (status2 != cudaSuccess) + throw cuda_exception::build("svdBatched: cuda failed !", status2); + + PointersManager manager(context, "svdBatched"); + + NDArray::prepareSpecialUse({pS, pU, pV}, {pA}); + + // choose appropriate cuda gemm api depending on data types + if (A->dataType() == DataType::DOUBLE) { + status = cusolverDnDgesvdjBatched( + handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, + reinterpret_cast(dWork), lwork, devInfo, gesvdjParams, bS); + } else if (A->dataType() == DataType::FLOAT32) { + status = cusolverDnSgesvdjBatched( + handle, jobz, m, n, reinterpret_cast(pA->specialBuffer()), lda, + reinterpret_cast(pS->specialBuffer()), + calcUV ? reinterpret_cast(pU->specialBuffer()) : nullptr, ldu, + calcUV ? reinterpret_cast(pV->specialBuffer()) : nullptr, ldv, + reinterpret_cast(dWork), lwork, devInfo, gesvdjParams, bS); + } else + throw std::invalid_argument("svdBatched: given data type is unsupported !"); + + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("svdBatched: cuda failed !", status); + + manager.synchronize(); + + NDArray::registerSpecialUse({pS, pU, pV}, {pA}); + + S->assign(pS); + + if (calcUV) { + U->assign(pU); + V->assign(pV); + } + + for (int i = toDelete.size() - 1; i >= 0; --i) delete toDelete[i]; + + if (devInfo) cudaFree(devInfo); + if (dWork) cudaFree(dWork); + if (handle) cusolverDnDestroy(handle); + if (gesvdjParams) cusolverDnDestroyGesvdjInfo(gesvdjParams); + + // cudaDeviceReset(); } //////////////////////////////////////////////////////////////////// -void svd(sd::LaunchContext* context, const NDArray* x, const std::vector& outArrs, const bool fullUV, const bool calcUV, const int switchNum) { - - NDArray* S = outArrs[0]; - NDArray* U = outArrs[1]; - // NDArray VT = outArrs[2]->transpose(); - NDArray* V = outArrs[2]; - - NDArray::prepareSpecialUse({S, U, V}, {x}); - - if(x->rankOf() == 2) { - // svdQR(context, x, S, U, VT, fullUV, calcUV); - svdJcb(context, x, S, U, V, fullUV, calcUV); +void svd(sd::LaunchContext* context, const NDArray* x, + const std::vector& outArrs, const bool fullUV, + const bool calcUV, const int switchNum) { + NDArray* S = outArrs[0]; + NDArray* U = outArrs[1]; + // NDArray VT = outArrs[2]->transpose(); + NDArray* V = outArrs[2]; + + NDArray::prepareSpecialUse({S, U, V}, {x}); + + if (x->rankOf() == 2) { + // svdQR(context, x, S, U, VT, fullUV, calcUV); + svdJcb(context, x, S, U, V, fullUV, calcUV); + } else { + // svdBatched(context, *x, *S, *U, *V, fullUV, calcUV); + + ResultSet *tadsU(nullptr), *tadsV(nullptr); + + auto tadsX = + x->allTensorsAlongDimension({x->rankOf() - 2, x->rankOf() - 1}); + auto tadsS = S->allTensorsAlongDimension({S->rankOf() - 1}); + + if (calcUV) { + tadsU = new ResultSet( + U->allTensorsAlongDimension({U->rankOf() - 2, U->rankOf() - 1})); + tadsV = new ResultSet( + V->allTensorsAlongDimension({V->rankOf() - 2, V->rankOf() - 1})); } - else { - // svdBatched(context, *x, *S, *U, *V, fullUV, calcUV); + for (int i = 0; i < tadsX.size(); ++i) + svdJcb(context, tadsX.at(i), tadsS.at(i), calcUV ? tadsU->at(i) : nullptr, + calcUV ? tadsV->at(i) : nullptr, fullUV, calcUV); - ResultSet *tadsU(nullptr), *tadsV(nullptr); - - auto tadsX = x->allTensorsAlongDimension({x->rankOf() - 2, x->rankOf() - 1}); - auto tadsS = S->allTensorsAlongDimension({S->rankOf() - 1}); - - if(calcUV) { - tadsU = new ResultSet(U->allTensorsAlongDimension({U->rankOf() - 2, U->rankOf() - 1})); - tadsV = new ResultSet(V->allTensorsAlongDimension({V->rankOf() - 2, V->rankOf() - 1})); - } - - for (int i = 0; i < tadsX.size(); ++i) - svdJcb(context, tadsX.at(i), tadsS.at(i), calcUV ? tadsU->at(i) : nullptr, calcUV ? tadsV->at(i) : nullptr, fullUV, calcUV); - - if(calcUV) { - delete tadsU; - delete tadsV; - } + if (calcUV) { + delete tadsU; + delete tadsV; } + } - NDArray::registerSpecialUse({S, U, V}, {x}); + NDArray::registerSpecialUse({S, U, V}, {x}); } - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu b/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu index 2138e1188d8b..43501948e282 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu @@ -18,25 +18,26 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - void toggle_bits__(NDArray &in, NDArray &out) { - auto lambda = LAMBDA_T(_x) { - return ~_x;//eUtils::flip_bits(_x); - }; +namespace ops { +namespace helpers { +template +void toggle_bits__(NDArray &in, NDArray &out) { + auto lambda = LAMBDA_T(_x) { + return ~_x; // eUtils::flip_bits(_x); + }; - in.applyLambda(lambda, out); - } - BUILD_SINGLE_TEMPLATE(template void toggle_bits__, (NDArray &in, NDArray &out), INTEGER_TYPES); + in.applyLambda(lambda, out); +} +BUILD_SINGLE_TEMPLATE(template void toggle_bits__, (NDArray & in, NDArray &out), + INTEGER_TYPES); - void __toggle_bits(sd::LaunchContext * context, NDArray& in, NDArray& out) { - BUILD_SINGLE_SELECTOR(in.dataType(), toggle_bits__, (in, out), INTEGER_TYPES); - } - } - } -} \ No newline at end of file +void __toggle_bits(sd::LaunchContext *context, NDArray &in, NDArray &out) { + BUILD_SINGLE_SELECTOR(in.dataType(), toggle_bits__, (in, out), INTEGER_TYPES); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu b/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu index ce19d41ccdd4..b5bc3eaff75d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu @@ -18,265 +18,308 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -template +template __global__ static void inTopKCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, - const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, - const uint k) { - - - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ uint* sharedMem; - __shared__ X elemToCompare; - __shared__ const X* xTad; - __shared__ Nd4jLong idx, xTadLen; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - xTadLen = shape::length(xTadShapeInfo); - - xTad = reinterpret_cast(vx) + xTadOffsets[blockIdx.x]; - idx = y[shape::getIndexOffset(blockIdx.x, yShapeInfo)]; // shape::length(yShapeInfo) == numTads - elemToCompare = xTad[shape::getIndexOffset(idx, xTadShapeInfo)]; - } - - __syncthreads(); - - sharedMem[threadIdx.x] = 0; - for (Nd4jLong i = threadIdx.x; i < xTadLen; i += blockDim.x) - if(elemToCompare < xTad[shape::getIndexOffset(i, xTadShapeInfo)]) - ++sharedMem[threadIdx.x]; - + void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, + const Nd4jLong* xTadOffsets, const uint k) { + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ uint* sharedMem; + __shared__ X elemToCompare; + __shared__ const X* xTad; + __shared__ Nd4jLong idx, xTadLen; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + xTadLen = shape::length(xTadShapeInfo); + + xTad = reinterpret_cast(vx) + xTadOffsets[blockIdx.x]; + idx = y[shape::getIndexOffset( + blockIdx.x, yShapeInfo)]; // shape::length(yShapeInfo) == numTads + elemToCompare = xTad[shape::getIndexOffset(idx, xTadShapeInfo)]; + } + + __syncthreads(); + + sharedMem[threadIdx.x] = 0; + for (Nd4jLong i = threadIdx.x; i < xTadLen; i += blockDim.x) + if (elemToCompare < xTad[shape::getIndexOffset(i, xTadShapeInfo)]) + ++sharedMem[threadIdx.x]; + + __syncthreads(); + + // aggregate sum + for (uint activeThreads = blockDim.x / 2; activeThreads > 0; + activeThreads /= 2) { + if (threadIdx.x < activeThreads) + sharedMem[threadIdx.x] += sharedMem[threadIdx.x + activeThreads]; __syncthreads(); + } - // aggregate sum - for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - if (threadIdx.x < activeThreads) - sharedMem[threadIdx.x] += sharedMem[threadIdx.x + activeThreads]; - __syncthreads(); - } - - if (threadIdx.x == 0) - z[shape::getIndexOffset(blockIdx.x, zShapeInfo)] = *sharedMem < k; + if (threadIdx.x == 0) + z[shape::getIndexOffset(blockIdx.x, zShapeInfo)] = *sharedMem < k; } /////////////////////////////////////////////////////////////////// -template -static void inTopKCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, - const uint k) { - - inTopKCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, xTadShapeInfo, xTadOffsets, k); +template +static void inTopKCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, + const Nd4jLong* xTadShapeInfo, + const Nd4jLong* xTadOffsets, const uint k) { + inTopKCuda<<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, xTadShapeInfo, + xTadOffsets, k); } /////////////////////////////////////////////////////////////////// -int inTopKFunctor(sd::LaunchContext * context, const NDArray* predictions, const NDArray* targets, NDArray* output, const uint k) { - - PointersManager manager(context, "in_top_k"); - - const auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(predictions->shapeInfo(), {1}); - - const int threadsPerBlock = MAX_NUM_THREADS; - const int blocksPerGrid = static_cast(packX.numberOfTads()); - const int sharedMem = sizeof(uint) * threadsPerBlock + 128; - - const auto xType = predictions->dataType(); - const auto yType = targets->dataType(); - - NDArray::prepareSpecialUse({output}, {predictions, targets}); - BUILD_DOUBLE_SELECTOR(xType, yType, inTopKCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), predictions->specialBuffer(), predictions->specialShapeInfo(), targets->specialBuffer(), targets->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets(), k), FLOAT_TYPES, INDEXING_TYPES); - NDArray::registerSpecialUse({output}, {predictions, targets}); - - manager.synchronize(); - - return Status::OK(); +int inTopKFunctor(sd::LaunchContext* context, const NDArray* predictions, + const NDArray* targets, NDArray* output, const uint k) { + PointersManager manager(context, "in_top_k"); + + const auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + predictions->shapeInfo(), {1}); + + const int threadsPerBlock = MAX_NUM_THREADS; + const int blocksPerGrid = static_cast(packX.numberOfTads()); + const int sharedMem = sizeof(uint) * threadsPerBlock + 128; + + const auto xType = predictions->dataType(); + const auto yType = targets->dataType(); + + NDArray::prepareSpecialUse({output}, {predictions, targets}); + BUILD_DOUBLE_SELECTOR( + xType, yType, inTopKCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + predictions->specialBuffer(), predictions->specialShapeInfo(), + targets->specialBuffer(), targets->specialShapeInfo(), + output->specialBuffer(), output->specialShapeInfo(), + packX.specialShapeInfo(), packX.specialOffsets(), k), + FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {predictions, targets}); + + manager.synchronize(); + + return Status::OK(); } - template - static _CUDA_G void topValuesMover(void const* vx, Nd4jLong const* xTadShapeInfo, Nd4jLong const* xTadOffsets, void const* vi, Nd4jLong const* iTadShapeInfo, Nd4jLong const* iTadOffsets, void *vz, Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets, Nd4jLong tadLength, int numTads, int k) { - for (int t = blockIdx.x; t < numTads; t += gridDim.x) { - auto x = reinterpret_cast(vx) + xTadOffsets[t]; - auto i = reinterpret_cast(vi) + iTadOffsets[t]; - auto z = reinterpret_cast(vz) + zTadOffsets[t]; - - for (int e = threadIdx.x; e < k; e += blockDim.x) { - auto idx = i[shape::getIndexOffset(e, iTadShapeInfo)]; - - z[shape::getIndexOffset(e, zTadShapeInfo)] = x[shape::getIndexOffset(idx, xTadShapeInfo)]; - } - } +template +static _CUDA_G void topValuesMover( + void const* vx, Nd4jLong const* xTadShapeInfo, Nd4jLong const* xTadOffsets, + void const* vi, Nd4jLong const* iTadShapeInfo, Nd4jLong const* iTadOffsets, + void* vz, Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets, + Nd4jLong tadLength, int numTads, int k) { + for (int t = blockIdx.x; t < numTads; t += gridDim.x) { + auto x = reinterpret_cast(vx) + xTadOffsets[t]; + auto i = reinterpret_cast(vi) + iTadOffsets[t]; + auto z = reinterpret_cast(vz) + zTadOffsets[t]; + + for (int e = threadIdx.x; e < k; e += blockDim.x) { + auto idx = i[shape::getIndexOffset(e, iTadShapeInfo)]; + + z[shape::getIndexOffset(e, zTadShapeInfo)] = + x[shape::getIndexOffset(idx, xTadShapeInfo)]; } + } +} - - template - static _CUDA_G void indicesAlongDimension(void const* vx, Nd4jLong const* xTadShapeInfo, Nd4jLong const* xTadOffsets, void* vi, Nd4jLong const* iTadShapeInfo, Nd4jLong const* iTadOffsets, void *vz, Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets, Nd4jLong tadLength, int numTads, int k, int scanWidth, bool needSort) { - extern __shared__ char _shmem[]; - - X* tempValues = reinterpret_cast(_shmem) + threadIdx.x * scanWidth; - Y* tempIndices = reinterpret_cast(reinterpret_cast(_shmem) + blockDim.x * scanWidth) + threadIdx.x * scanWidth; - - __shared__ X localMaximum; - if (threadIdx.x == 0) - localMaximum = -DataTypeUtils::max(); - __syncthreads(); - - for (int t = blockIdx.x; t < numTads; t += gridDim.x) { - auto x = reinterpret_cast(vx) + xTadOffsets[t]; - auto i = reinterpret_cast(vi) + iTadOffsets[t]; - auto z = reinterpret_cast(vz) + zTadOffsets[t]; - - // we'll do multiple reads here - for (int p = 0; p < k; p += scanWidth) { - - // resetting temporary storage - for (int p = 0; p < scanWidth; p++) { - tempValues[p] = -DataTypeUtils::max(); - tempIndices[p] = DataTypeUtils::max(); - } - - // local max values/indices - for (int e = threadIdx.x; e < tadLength; e++) { - auto value = x[shape::getIndexOffset(e, xTadShapeInfo)]; - - // we'll compare this value to current stored ones - for (int f = 0; f < scanWidth; f++) { - if (value > tempValues[f] && (p == 0 || value < localMaximum)) { - tempValues[f] = value; - tempIndices[f] = e; - } - } - } - __syncthreads(); - - // at this point we have local part ready for merge and define global maximum for this iteration, and local maximum for next iteration - for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - if (threadIdx.x < activeThreads) { - if (tempValues[0] < tempValues[0 + activeThreads * scanWidth]) { - tempValues[0] = tempValues[0 + activeThreads * scanWidth]; - tempIndices[0] = tempIndices[0 + activeThreads * scanWidth]; - } - } - __syncthreads(); - } - __syncthreads(); - - // at this point we know local minimum for next iteration - if (threadIdx.x == 0) { - localMaximum = tempValues[scanWidth - 1]; - z[shape::getIndexOffset(p, zTadShapeInfo)] = tempValues[scanWidth - 1]; - i[shape::getIndexOffset(p, iTadShapeInfo)] = tempIndices[scanWidth - 1]; - } - __syncthreads(); - } - - __syncthreads(); - if (!needSort) { - // if we don't need sort, we need to return values based on their indices (ascending) - for (int m = 0; m < k; m++) { - if (m % 2 == 0) { - for (int tid = threadIdx.x; tid < k; tid += blockDim.x) { - auto top = 2 * tid + 1; - if (top < k) { - auto t0 = shape::getIndexOffset(top - 1, iTadShapeInfo); - auto t1 = shape::getIndexOffset(top, iTadShapeInfo); - - if (i[t0] > i[t1]) { - // swap indices first - Y di0 = i[t0]; - i[t0] = i[t1]; - i[t1] = di0; - - //swap values next - - X dz0 = z[t0]; - z[t0] = z[t1]; - z[t1] = dz0; - } - } - } - } else { - for (int tid = threadIdx.x; tid < k; tid += blockDim.x) { - auto top = 2 * tid + 2; - if (top < k) { - auto t0 = shape::getIndexOffset(top - 1, iTadShapeInfo); - auto t1 = shape::getIndexOffset(top, iTadShapeInfo); - - if (i[t0] > i[t1]) { - // swap indices first - Y di0 = i[t0]; - i[t0] = i[t1]; - i[t1] = di0; - - //swap values next - - X dz0 = z[t0]; - z[t0] = z[t1]; - z[t1] = dz0; - } - } - } - } - __syncthreads(); - } - } +template +static _CUDA_G void indicesAlongDimension( + void const* vx, Nd4jLong const* xTadShapeInfo, Nd4jLong const* xTadOffsets, + void* vi, Nd4jLong const* iTadShapeInfo, Nd4jLong const* iTadOffsets, + void* vz, Nd4jLong const* zTadShapeInfo, Nd4jLong const* zTadOffsets, + Nd4jLong tadLength, int numTads, int k, int scanWidth, bool needSort) { + extern __shared__ char _shmem[]; + + X* tempValues = reinterpret_cast(_shmem) + threadIdx.x * scanWidth; + Y* tempIndices = reinterpret_cast(reinterpret_cast(_shmem) + + blockDim.x * scanWidth) + + threadIdx.x * scanWidth; + + __shared__ X localMaximum; + if (threadIdx.x == 0) localMaximum = -DataTypeUtils::max(); + __syncthreads(); + + for (int t = blockIdx.x; t < numTads; t += gridDim.x) { + auto x = reinterpret_cast(vx) + xTadOffsets[t]; + auto i = reinterpret_cast(vi) + iTadOffsets[t]; + auto z = reinterpret_cast(vz) + zTadOffsets[t]; + + // we'll do multiple reads here + for (int p = 0; p < k; p += scanWidth) { + // resetting temporary storage + for (int p = 0; p < scanWidth; p++) { + tempValues[p] = -DataTypeUtils::max(); + tempIndices[p] = DataTypeUtils::max(); + } + + // local max values/indices + for (int e = threadIdx.x; e < tadLength; e++) { + auto value = x[shape::getIndexOffset(e, xTadShapeInfo)]; + + // we'll compare this value to current stored ones + for (int f = 0; f < scanWidth; f++) { + if (value > tempValues[f] && (p == 0 || value < localMaximum)) { + tempValues[f] = value; + tempIndices[f] = e; + } + } + } + __syncthreads(); + + // at this point we have local part ready for merge and define global + // maximum for this iteration, and local maximum for next iteration + for (uint activeThreads = blockDim.x / 2; activeThreads > 0; + activeThreads /= 2) { + if (threadIdx.x < activeThreads) { + if (tempValues[0] < tempValues[0 + activeThreads * scanWidth]) { + tempValues[0] = tempValues[0 + activeThreads * scanWidth]; + tempIndices[0] = tempIndices[0 + activeThreads * scanWidth]; + } } + __syncthreads(); + } + __syncthreads(); + + // at this point we know local minimum for next iteration + if (threadIdx.x == 0) { + localMaximum = tempValues[scanWidth - 1]; + z[shape::getIndexOffset(p, zTadShapeInfo)] = tempValues[scanWidth - 1]; + i[shape::getIndexOffset(p, iTadShapeInfo)] = tempIndices[scanWidth - 1]; + } + __syncthreads(); } - - template - static int topKFunctor_(sd::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) { - - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), {input->rankOf() - 1}); - auto packI = ConstantTadHelper::getInstance()->tadForDimensions(indices->shapeInfo(), {input->rankOf() - 1}); - auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(values->shapeInfo(), {input->rankOf() - 1}); - - auto tadLength = shape::length(packX.primaryShapeInfo()); - - // we get top K values first - if (k == 1) { - input->applyIndexReduce(indexreduce::IndexMax, *indices, {input->rankOf() - 1}); - - // copy values on specified indices - topValuesMover<<<256, 256, 1024, *context->getCudaStream()>>>(input->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), indices->specialBuffer(), packI.platformShapeInfo(), packI.platformOffsets(), values->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, packX.numberOfTads(), k); + __syncthreads(); + if (!needSort) { + // if we don't need sort, we need to return values based on their indices + // (ascending) + for (int m = 0; m < k; m++) { + if (m % 2 == 0) { + for (int tid = threadIdx.x; tid < k; tid += blockDim.x) { + auto top = 2 * tid + 1; + if (top < k) { + auto t0 = shape::getIndexOffset(top - 1, iTadShapeInfo); + auto t1 = shape::getIndexOffset(top, iTadShapeInfo); + + if (i[t0] > i[t1]) { + // swap indices first + Y di0 = i[t0]; + i[t0] = i[t1]; + i[t1] = di0; + + // swap values next + + X dz0 = z[t0]; + z[t0] = z[t1]; + z[t1] = dz0; + } + } + } } else { - int scanWidth = 1; - int numTreads = 256; - int shMemSize = (numTreads * sizeof(X) * scanWidth) + (numTreads * sizeof(Y) * scanWidth) + 512; - - indicesAlongDimension<<<256, numTreads, shMemSize, *context->getCudaStream()>>>(input->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), indices->specialBuffer(), packI.platformShapeInfo(), packI.platformOffsets(), values->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, packX.numberOfTads(), k, scanWidth, needSort); + for (int tid = threadIdx.x; tid < k; tid += blockDim.x) { + auto top = 2 * tid + 2; + if (top < k) { + auto t0 = shape::getIndexOffset(top - 1, iTadShapeInfo); + auto t1 = shape::getIndexOffset(top, iTadShapeInfo); + + if (i[t0] > i[t1]) { + // swap indices first + Y di0 = i[t0]; + i[t0] = i[t1]; + i[t1] = di0; + + // swap values next + + X dz0 = z[t0]; + z[t0] = z[t1]; + z[t1] = dz0; + } + } + } } - - return Status::OK(); + __syncthreads(); + } } + } +} - int topKFunctor(sd::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) { - input->syncToDevice(); +template +static int topKFunctor_(sd::LaunchContext* context, const NDArray* input, + NDArray* values, NDArray* indices, const uint k, + bool needSort) { + auto packX = ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), {input->rankOf() - 1}); + auto packI = ConstantTadHelper::getInstance()->tadForDimensions( + indices->shapeInfo(), {input->rankOf() - 1}); + auto packZ = ConstantTadHelper::getInstance()->tadForDimensions( + values->shapeInfo(), {input->rankOf() - 1}); + + auto tadLength = shape::length(packX.primaryShapeInfo()); + + // we get top K values first + if (k == 1) { + input->applyIndexReduce(indexreduce::IndexMax, *indices, + {input->rankOf() - 1}); + + // copy values on specified indices + topValuesMover<<<256, 256, 1024, *context->getCudaStream()>>>( + input->specialBuffer(), packX.platformShapeInfo(), + packX.platformOffsets(), indices->specialBuffer(), + packI.platformShapeInfo(), packI.platformOffsets(), + values->specialBuffer(), packZ.platformShapeInfo(), + packZ.platformOffsets(), tadLength, packX.numberOfTads(), k); + } else { + int scanWidth = 1; + int numTreads = 256; + int shMemSize = (numTreads * sizeof(X) * scanWidth) + + (numTreads * sizeof(Y) * scanWidth) + 512; + + indicesAlongDimension + <<<256, numTreads, shMemSize, *context->getCudaStream()>>>( + input->specialBuffer(), packX.platformShapeInfo(), + packX.platformOffsets(), indices->specialBuffer(), + packI.platformShapeInfo(), packI.platformOffsets(), + values->specialBuffer(), packZ.platformShapeInfo(), + packZ.platformOffsets(), tadLength, packX.numberOfTads(), k, + scanWidth, needSort); + } + + return Status::OK(); +} - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), topKFunctor_, (context, input, values, indices, k, needSort), LIBND4J_TYPES, INDEXING_TYPES); +int topKFunctor(sd::LaunchContext* context, const NDArray* input, + NDArray* values, NDArray* indices, const uint k, + bool needSort) { + input->syncToDevice(); - values->tickWriteDevice(); - indices->tickWriteDevice(); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), topKFunctor_, + (context, input, values, indices, k, needSort), + LIBND4J_TYPES, INDEXING_TYPES); - return Status::OK(); - } + values->tickWriteDevice(); + indices->tickWriteDevice(); + return Status::OK(); } -} -} \ No newline at end of file + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu index f016491a60e5..0e41a4e3a409 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu @@ -19,869 +19,1000 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // - -#include -#include -#include -#include #include -#include +#include #include -#include #include +#include +#include +#include +#include -namespace sd { -namespace ops { +#include + +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ static void invertPermutationCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo) { - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ Nd4jLong len, totalThreads; - - if (threadIdx.x == 0) { - - len = shape::length(xShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < len; i += totalThreads) { - - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const Nd4jLong index = x[xOffset]; - const auto zOffset = shape::getIndexOffset(index, zShapeInfo); - z[zOffset] = i; - } +template +__global__ static void invertPermutationCuda(const void* vx, + const Nd4jLong* xShapeInfo, + void* vz, + const Nd4jLong* zShapeInfo) { + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ Nd4jLong len, totalThreads; + + if (threadIdx.x == 0) { + len = shape::length(xShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < len; i += totalThreads) { + const auto xOffset = shape::getIndexOffset(i, xShapeInfo); + const Nd4jLong index = x[xOffset]; + const auto zOffset = shape::getIndexOffset(index, zShapeInfo); + z[zOffset] = i; + } } /////////////////////////////////////////////////////////////////// -template -__host__ static void invertPermutationCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo) { - - invertPermutationCuda<<>>(vx, xShapeInfo, vz, zShapeInfo); +template +__host__ static void invertPermutationCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo) { + invertPermutationCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo); } //////////////////////////////////////////////////////////////////////// -void invertPermutation(sd::LaunchContext* context, const NDArray& input, NDArray& output) { - - const int threadsPerBlock = MAX_NUM_THREADS; - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "invertPermutation"); - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), invertPermutationCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), LIBND4J_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); +void invertPermutation(sd::LaunchContext* context, const NDArray& input, + NDArray& output) { + const int threadsPerBlock = MAX_NUM_THREADS; + const int blocksPerGrid = + (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "invertPermutation"); + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), invertPermutationCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo()), + LIBND4J_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } ////////////////////////////////////////////////////////////////////////// -template -__global__ static void traceCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint diagLen) { - - const auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vz); - - __shared__ T* sharedMem; - __shared__ int xRank, zRank, *coordsMem; // xRank = zRank + 2 - __shared__ Nd4jLong xLen, zLen; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - coordsMem = reinterpret_cast(shmem + blockDim.x * sizeof(T)); - - xRank = shape::rank(xShapeInfo); - zRank = shape::rank(zShapeInfo); - xLen = shape::length(xShapeInfo); - zLen = shape::length(zShapeInfo); // corresponds to number of matrices - +template +__global__ static void traceCuda(const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const uint diagLen) { + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ T* sharedMem; + __shared__ int xRank, zRank, *coordsMem; // xRank = zRank + 2 + __shared__ Nd4jLong xLen, zLen; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + coordsMem = reinterpret_cast(shmem + blockDim.x * sizeof(T)); + + xRank = shape::rank(xShapeInfo); + zRank = shape::rank(zShapeInfo); + xLen = shape::length(xShapeInfo); + zLen = shape::length(zShapeInfo); // corresponds to number of matrices + } + __syncthreads(); + + auto coords = coordsMem + threadIdx.x * xRank; + + for (uint m = blockIdx.x; m < zLen; + m += + gridDim.x) { // one block per each element of z, that is per each matrix + + shape::index2coords(m, zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); + + sharedMem[threadIdx.x] = 0; + + for (uint i = threadIdx.x; i < diagLen; i += blockDim.x) { + coords[zRank] = coords[zRank + 1] = i; + const auto xOffset = shape::getOffset(xShapeInfo, coords); + sharedMem[threadIdx.x] += x[xOffset]; } - __syncthreads(); - - auto coords = coordsMem + threadIdx.x * xRank; - - for (uint m = blockIdx.x; m < zLen; m += gridDim.x) { // one block per each element of z, that is per each matrix - - shape::index2coords(m, zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - sharedMem[threadIdx.x] = 0; - - for (uint i = threadIdx.x; i < diagLen; i += blockDim.x) { - - coords[zRank] = coords[zRank + 1] = i; - const auto xOffset = shape::getOffset(xShapeInfo, coords); - sharedMem[threadIdx.x] += x[xOffset]; - } - __syncthreads(); - - // aggregate sum - for (Nd4jLong activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - if (threadIdx.x < activeThreads) - sharedMem[threadIdx.x] += sharedMem[threadIdx.x + activeThreads]; - __syncthreads(); - } + __syncthreads(); - if (threadIdx.x == 0) - z[zOffset] = *sharedMem; - __syncthreads(); + // aggregate sum + for (Nd4jLong activeThreads = blockDim.x / 2; activeThreads > 0; + activeThreads /= 2) { + if (threadIdx.x < activeThreads) + sharedMem[threadIdx.x] += sharedMem[threadIdx.x + activeThreads]; + __syncthreads(); } + if (threadIdx.x == 0) z[zOffset] = *sharedMem; + __syncthreads(); + } } /////////////////////////////////////////////////////////////////// -template -static void traceCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, - const uint diagLen) { - - traceCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, diagLen); +template +static void traceCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const uint diagLen) { + traceCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, diagLen); } - /////////////////////////////////////////////////////////////////// void trace(sd::LaunchContext* context, const NDArray& input, NDArray& output) { - - PointersManager manager(context, "trace"); - - const uint diagLen = input.sizeAt(-1) < input.sizeAt(-2) ? input.sizeAt(-1) : input.sizeAt(-2); - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * (sizeof(int) * input.rankOf() + input.sizeOfT()) + 128; - - NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), traceCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), diagLen), LIBND4J_TYPES); - NDArray::registerSpecialUse({&output}, {&input}); - - manager.synchronize(); + PointersManager manager(context, "trace"); + + const uint diagLen = + input.sizeAt(-1) < input.sizeAt(-2) ? input.sizeAt(-1) : input.sizeAt(-2); + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + threadsPerBlock * (sizeof(int) * input.rankOf() + input.sizeOfT()) + 128; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR( + input.dataType(), traceCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), diagLen), + LIBND4J_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// -template -__global__ static void triuBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int diag) { - - // x and z have same shapes - const auto x = reinterpret_cast(vx); // gradO - auto z = reinterpret_cast(vz); // gradI - - __shared__ int rank, areSameOffsets, *sharedMem; // xRank = zRank - __shared__ Nd4jLong len, totalThreads; // xLen = zLen - - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - rank = shape::rank(xShapeInfo); - len = shape::length(zShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - - __syncthreads(); - - auto coords = sharedMem + threadIdx.x * rank; - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < len; i += totalThreads) { - - shape::index2coords(i, zShapeInfo, coords); - - const auto zOffset = shape::getOffset(zShapeInfo, coords); - - if((coords[rank - 2] + diag > coords[rank - 1])) // row + diag > col - z[zOffset] = 0; - else - z[zOffset] = x[areSameOffsets ? zOffset : shape::getOffset(xShapeInfo, coords)]; - } +template +__global__ static void triuBPCuda(const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const int diag) { + // x and z have same shapes + const auto x = reinterpret_cast(vx); // gradO + auto z = reinterpret_cast(vz); // gradI + + __shared__ int rank, areSameOffsets, *sharedMem; // xRank = zRank + __shared__ Nd4jLong len, totalThreads; // xLen = zLen + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + rank = shape::rank(xShapeInfo); + len = shape::length(zShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + + __syncthreads(); + + auto coords = sharedMem + threadIdx.x * rank; + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < len; i += totalThreads) { + shape::index2coords(i, zShapeInfo, coords); + + const auto zOffset = shape::getOffset(zShapeInfo, coords); + + if ((coords[rank - 2] + diag > coords[rank - 1])) // row + diag > col + z[zOffset] = 0; + else + z[zOffset] = + x[areSameOffsets ? zOffset : shape::getOffset(xShapeInfo, coords)]; + } } /////////////////////////////////////////////////////////////////// -template -static void triuBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int diag) { - - triuBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, diag); +template +static void triuBPCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, const int diag) { + triuBPCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, diag); } /////////////////////////////////////////////////////////////////// -void triuBP(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) { - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * gradO.rankOf() + 128; - - PointersManager manager(context, "triuBP"); - - NDArray::prepareSpecialUse({&gradI}, {&gradO}); - BUILD_SINGLE_SELECTOR(gradI.dataType(), triuBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), diagonal), LIBND4J_TYPES); - NDArray::registerSpecialUse({&gradI}, {&gradO}); - - manager.synchronize(); +void triuBP(sd::LaunchContext* context, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int diagonal) { + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * gradO.rankOf() + 128; + + PointersManager manager(context, "triuBP"); + + NDArray::prepareSpecialUse({&gradI}, {&gradO}); + BUILD_SINGLE_SELECTOR( + gradI.dataType(), triuBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), + gradI.specialShapeInfo(), diagonal), + LIBND4J_TYPES); + NDArray::registerSpecialUse({&gradI}, {&gradO}); + + manager.synchronize(); } /////////////////////////////////////////////////////////////////// -template -__global__ static void tileBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* globMem) { - - // x and z have same shapes - const auto x = reinterpret_cast(vx); // gradO - auto z = reinterpret_cast(vz); // gradI - - __shared__ int xRank, zRank, *sharedMem; // xRank >= zRank - __shared__ Nd4jLong numOfXOffsets, zLen, totalThreads; // xLen >= zLen - - if (threadIdx.x == 0) { +template +__global__ static void tileBPCuda(const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + Nd4jLong* globMem) { + // x and z have same shapes + const auto x = reinterpret_cast(vx); // gradO + auto z = reinterpret_cast(vz); // gradI - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); + __shared__ int xRank, zRank, *sharedMem; // xRank >= zRank + __shared__ Nd4jLong numOfXOffsets, zLen, totalThreads; // xLen >= zLen - xRank = shape::rank(zShapeInfo); - zLen = shape::length(zShapeInfo); - numOfXOffsets = shape::length(xShapeInfo) / zLen; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - totalThreads = gridDim.x * blockDim.x; - } + xRank = shape::rank(zShapeInfo); + zLen = shape::length(zShapeInfo); + numOfXOffsets = shape::length(xShapeInfo) / zLen; - __syncthreads(); + totalThreads = gridDim.x * blockDim.x; + } - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + __syncthreads(); - auto memBuff = sharedMem + threadIdx.x * 2 * xRank; - auto xOffsets = globMem + tid * numOfXOffsets; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + auto memBuff = sharedMem + threadIdx.x * 2 * xRank; + auto xOffsets = globMem + tid * numOfXOffsets; - const auto zOffset = shape::getIndexOffset(i, zShapeInfo); + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + const auto zOffset = shape::getIndexOffset(i, zShapeInfo); - shape::outerArrayOffsets(xOffsets, i, xShapeInfo, zShapeInfo, memBuff); + shape::outerArrayOffsets(xOffsets, i, xShapeInfo, zShapeInfo, memBuff); - z[zOffset] = x[xOffsets[0]]; // first offset - for (Nd4jLong j = 1; j < numOfXOffsets; ++j) // rest offsets - z[zOffset] += x[xOffsets[j]]; - } + z[zOffset] = x[xOffsets[0]]; // first offset + for (Nd4jLong j = 1; j < numOfXOffsets; ++j) // rest offsets + z[zOffset] += x[xOffsets[j]]; + } } /////////////////////////////////////////////////////////////////// -template -static void tileBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* globMem) { - - tileBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, globMem); +template +static void tileBPCudaLauncher(const int blocksPerGrid, + const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, + const Nd4jLong* xShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, Nd4jLong* globMem) { + tileBPCuda<<>>( + vx, xShapeInfo, vz, zShapeInfo, globMem); } - ////////////////////////////////////////////////////////////////////////// -void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector reps) { - - NDArray memBuff('c', gradO.getShapeAsVector(), sd::DataType::INT64, context); // empty auxiliary array for storing device memory which will be used in kernel calculations - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(int) * 2 * gradO.rankOf() + 128; - - PointersManager manager(context, "tileBP"); - - NDArray::prepareSpecialUse({&gradI}, {&gradO, &memBuff}); - BUILD_SINGLE_SELECTOR(gradI.dataType(), tileBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), reinterpret_cast(memBuff.specialBuffer())), FLOAT_TYPES); - NDArray::registerSpecialUse({&gradI}, {&gradO, &memBuff}); - - manager.synchronize(); +void tileBP(sd::LaunchContext* context, const NDArray& gradO /*input*/, + NDArray& gradI /*output*/, const std::vector reps) { + NDArray memBuff('c', gradO.getShapeAsVector(), sd::DataType::INT64, + context); // empty auxiliary array for storing device memory + // which will be used in kernel calculations + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = + threadsPerBlock * sizeof(int) * 2 * gradO.rankOf() + 128; + + PointersManager manager(context, "tileBP"); + + NDArray::prepareSpecialUse({&gradI}, {&gradO, &memBuff}); + BUILD_SINGLE_SELECTOR( + gradI.dataType(), tileBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + gradO.specialBuffer(), gradO.specialShapeInfo(), gradI.specialBuffer(), + gradI.specialShapeInfo(), + reinterpret_cast(memBuff.specialBuffer())), + FLOAT_TYPES); + NDArray::registerSpecialUse({&gradI}, {&gradO, &memBuff}); + + manager.synchronize(); } ////////////////////////////////////////////////////////////////////////// // x - input, y - gradO, z - gradI -template -__global__ static void clipByNormBPWholeArrCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vreducBuff, const Z clipNormVal) { - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - if(tid >= shape::length(zShapeInfo)) - return; - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - auto reducBuff = reinterpret_cast(vreducBuff); - uint* count = reinterpret_cast(vreducBuff) + 16384; - - __shared__ Z* shMem; - __shared__ Nd4jLong len; - __shared__ bool amIinLastBlock; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - shMem = reinterpret_cast(shmem); - - len = shape::length(zShapeInfo); // xLen = yLen = zLen +template +__global__ static void clipByNormBPWholeArrCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vreducBuff, const Z clipNormVal) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (tid >= shape::length(zShapeInfo)) return; + + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + auto reducBuff = reinterpret_cast(vreducBuff); + uint* count = reinterpret_cast(vreducBuff) + 16384; + + __shared__ Z* shMem; + __shared__ Nd4jLong len; + __shared__ bool amIinLastBlock; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + shMem = reinterpret_cast(shmem); + + len = shape::length(zShapeInfo); // xLen = yLen = zLen + } + __syncthreads(); + + // fill shared memory with array elements + const auto xVal = x[shape::getIndexOffset(tid, xShapeInfo)]; + const auto yVal = y[shape::getIndexOffset(tid, yShapeInfo)]; + + shMem[2 * threadIdx.x] = static_cast(xVal * xVal); // for norm + shMem[2 * threadIdx.x + 1] = + static_cast(xVal * yVal); // for input * gradO + + __syncthreads(); + + // accumulate sum per block + for (int activeThreads = blockDim.x / 2; activeThreads > 0; + activeThreads /= 2) { + if (threadIdx.x < activeThreads && tid + activeThreads < len) { + shMem[2 * threadIdx.x] += shMem[2 * (threadIdx.x + activeThreads)]; + shMem[2 * threadIdx.x + 1] += + shMem[2 * (threadIdx.x + activeThreads) + 1]; } __syncthreads(); - - // fill shared memory with array elements - const auto xVal = x[shape::getIndexOffset(tid, xShapeInfo)]; - const auto yVal = y[shape::getIndexOffset(tid, yShapeInfo)]; - - shMem[2*threadIdx.x] = static_cast(xVal * xVal); // for norm - shMem[2*threadIdx.x + 1] = static_cast(xVal * yVal); // for input * gradO - + } + + // store accumulated sums in reduction buffer (reducBuff) + if (threadIdx.x == 0) { + reducBuff[2 * blockIdx.x] = shMem[0]; + reducBuff[2 * blockIdx.x + 1] = shMem[1]; + + __threadfence(); + + amIinLastBlock = + gridDim.x == 1 || (atomicInc(count, gridDim.x) == gridDim.x - 1); + } + __syncthreads(); + + // shared memory of last block is used for final summation of values stored in + // reduction buffer + if (amIinLastBlock) { + for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) { + shMem[2 * threadIdx.x] = (i == threadIdx.x) + ? reducBuff[2 * i] + : reducBuff[2 * i] + shMem[2 * threadIdx.x]; + shMem[2 * threadIdx.x + 1] = + (i == threadIdx.x) + ? reducBuff[2 * i + 1] + : reducBuff[2 * i + 1] + shMem[2 * threadIdx.x + 1]; + } __syncthreads(); - // accumulate sum per block - for (int activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - - if (threadIdx.x < activeThreads && tid + activeThreads < len) { - - shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)]; - shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1]; - } - __syncthreads(); + // accumulate sum + for (int activeThreads = blockDim.x / 2; activeThreads > 0; + activeThreads /= 2) { + if (threadIdx.x < activeThreads && + threadIdx.x + activeThreads < gridDim.x) { + shMem[2 * threadIdx.x] += shMem[2 * (threadIdx.x + activeThreads)]; + shMem[2 * threadIdx.x + 1] += + shMem[2 * (threadIdx.x + activeThreads) + 1]; + } + __syncthreads(); } - // store accumulated sums in reduction buffer (reducBuff) if (threadIdx.x == 0) { - - reducBuff[2*blockIdx.x] = shMem[0]; - reducBuff[2*blockIdx.x + 1] = shMem[1]; - - __threadfence(); - - amIinLastBlock = gridDim.x == 1 || (atomicInc(count, gridDim.x) == gridDim.x - 1); - } - __syncthreads(); - - // shared memory of last block is used for final summation of values stored in reduction buffer - if (amIinLastBlock) { - - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) { - - shMem[2*threadIdx.x] = (i == threadIdx.x ) ? reducBuff[2*i] : reducBuff[2*i] + shMem[2*threadIdx.x]; - shMem[2*threadIdx.x + 1] = (i == threadIdx.x ) ? reducBuff[2*i + 1] : reducBuff[2*i + 1] + shMem[2*threadIdx.x + 1]; - } - __syncthreads(); - - // accumulate sum - for (int activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - - if (threadIdx.x < activeThreads && threadIdx.x + activeThreads < gridDim.x) { - shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)]; - shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1]; - } - __syncthreads(); - } - - if (threadIdx.x == 0) { - - reducBuff[0] = math::nd4j_sqrt(shMem[0]); - reducBuff[1] = shMem[1]; - count = 0; - } + reducBuff[0] = math::nd4j_sqrt(shMem[0]); + reducBuff[1] = shMem[1]; + count = 0; } + } } ////////////////////////////////////////////////////////////////////////// // x - input, y - gradO, z - gradI -template -__global__ static void clipByNormBPCalcGradCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vreducBuff, const Z clipNormVal) { - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; +template +__global__ static void clipByNormBPCalcGradCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vy, + const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vreducBuff, const Z clipNormVal) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const Nd4jLong len = shape::length(zShapeInfo); // xLen = yLen = zLen + const Nd4jLong len = shape::length(zShapeInfo); // xLen = yLen = zLen - if(tid >= len) - return; + if (tid >= len) return; - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); - __shared__ Z norm, sumOfProd; - - if (threadIdx.x == 0) { - - norm = reinterpret_cast(vreducBuff)[0]; - sumOfProd = reinterpret_cast(vreducBuff)[1]; - } - __syncthreads(); + __shared__ Z norm, sumOfProd; - const auto yOffset = shape::getIndexOffset(tid, yShapeInfo); - const auto zOffset = shape::getIndexOffset(tid, zShapeInfo); + if (threadIdx.x == 0) { + norm = reinterpret_cast(vreducBuff)[0]; + sumOfProd = reinterpret_cast(vreducBuff)[1]; + } + __syncthreads(); - if(norm > clipNormVal) { + const auto yOffset = shape::getIndexOffset(tid, yShapeInfo); + const auto zOffset = shape::getIndexOffset(tid, zShapeInfo); - const auto xOffset = shape::getIndexOffset(tid, xShapeInfo); + if (norm > clipNormVal) { + const auto xOffset = shape::getIndexOffset(tid, xShapeInfo); - const Z factor1 = static_cast(1) / norm; // 1 / norm - const Z factor2 = factor1 / (norm * norm); // 1 / (norm * norm * norm) + const Z factor1 = static_cast(1) / norm; // 1 / norm + const Z factor2 = factor1 / (norm * norm); // 1 / (norm * norm * norm) - z[zOffset] = clipNormVal * (factor1 * y[yOffset] - factor2 * sumOfProd * x[xOffset]); - } - else { - z[zOffset] = y[yOffset]; - } + z[zOffset] = + clipNormVal * (factor1 * y[yOffset] - factor2 * sumOfProd * x[xOffset]); + } else { + z[zOffset] = y[yOffset]; + } } ////////////////////////////////////////////////////////////////////////// // x - input, y - gradO, z - gradI -template -__global__ static void clipByNormBPTadsCuda(const void* vx, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const void* vy, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, void* vz, const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, const Z clipNormVal) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ Z* shMem; - __shared__ Nd4jLong tadLen; +template +__global__ static void clipByNormBPTadsCuda( + const void* vx, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, + const void* vy, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, + void* vz, const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, + const Z clipNormVal) { + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ Z* shMem; + __shared__ Nd4jLong tadLen; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + shMem = reinterpret_cast(shmem); + tadLen = shape::length(zTadShapeInfo); // xTadLen = yTadLen = zTadLen + } + __syncthreads(); + + const auto* xTad = x + xTadOffsets[blockIdx.x]; + const auto* yTad = y + yTadOffsets[blockIdx.x]; + auto* zTad = z + zTadOffsets[blockIdx.x]; + + // *** FIRST STAGE - ACCUMULATE REQUIRED SUMS *** // + + Z norm = 0; + Z sumOfProd = 0; + + for (uint i = threadIdx.x; i < tadLen; i += blockDim.x) { + const auto xOffset = shape::getIndexOffset(i, xTadShapeInfo); + const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo); + + shMem[2 * threadIdx.x] = + static_cast(xTad[xOffset] * xTad[xOffset]); // for norm + shMem[2 * threadIdx.x + 1] = + static_cast(xTad[xOffset] * yTad[yOffset]); // for input * gradO - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - shMem = reinterpret_cast(shmem); - tadLen = shape::length(zTadShapeInfo); // xTadLen = yTadLen = zTadLen - } __syncthreads(); - const auto* xTad = x + xTadOffsets[blockIdx.x]; - const auto* yTad = y + yTadOffsets[blockIdx.x]; - auto* zTad = z + zTadOffsets[blockIdx.x]; - - // *** FIRST STAGE - ACCUMULATE REQUIRED SUMS *** // - - Z norm = 0; - Z sumOfProd = 0; - - for (uint i = threadIdx.x; i < tadLen; i += blockDim.x) { - - const auto xOffset = shape::getIndexOffset(i, xTadShapeInfo); - const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo); - - shMem[2*threadIdx.x] = static_cast(xTad[xOffset] * xTad[xOffset]); // for norm - shMem[2*threadIdx.x + 1] = static_cast(xTad[xOffset] * yTad[yOffset]); // for input * gradO - - __syncthreads(); - - // accumulate sum per block - for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - - if (threadIdx.x < activeThreads && i + activeThreads < tadLen) { - - shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)]; - shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1]; - } - __syncthreads(); - } - - norm += shMem[0]; - sumOfProd += shMem[1]; + // accumulate sum per block + for (uint activeThreads = blockDim.x / 2; activeThreads > 0; + activeThreads /= 2) { + if (threadIdx.x < activeThreads && i + activeThreads < tadLen) { + shMem[2 * threadIdx.x] += shMem[2 * (threadIdx.x + activeThreads)]; + shMem[2 * threadIdx.x + 1] += + shMem[2 * (threadIdx.x + activeThreads) + 1]; + } + __syncthreads(); } - // *** SECOND STAGE - GRADIENT CALCULATION *** // - - norm = math::nd4j_sqrt(norm); + norm += shMem[0]; + sumOfProd += shMem[1]; + } - for (uint i = threadIdx.x; i < tadLen; i += blockDim.x) { + // *** SECOND STAGE - GRADIENT CALCULATION *** // - const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo); - const auto zOffset = shape::getIndexOffset(i, zTadShapeInfo); + norm = math::nd4j_sqrt(norm); - if(norm > clipNormVal) { + for (uint i = threadIdx.x; i < tadLen; i += blockDim.x) { + const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo); + const auto zOffset = shape::getIndexOffset(i, zTadShapeInfo); - const auto xOffset = shape::getIndexOffset(i, xTadShapeInfo); + if (norm > clipNormVal) { + const auto xOffset = shape::getIndexOffset(i, xTadShapeInfo); - const Z factor1 = static_cast(1) / norm; // 1 / norm - const Z factor2 = factor1 / (norm * norm); // 1 / (norm * norm * norm) + const Z factor1 = static_cast(1) / norm; // 1 / norm + const Z factor2 = factor1 / (norm * norm); // 1 / (norm * norm * norm) - zTad[zOffset] = clipNormVal * (factor1 * yTad[yOffset] - factor2 * sumOfProd * xTad[xOffset]); - } - else { - zTad[zOffset] = yTad[yOffset]; - } + zTad[zOffset] = clipNormVal * (factor1 * yTad[yOffset] - + factor2 * sumOfProd * xTad[xOffset]); + } else { + zTad[zOffset] = yTad[yOffset]; } + } } ////////////////////////////////////////////////////////////////////////// -template -static void clipByNormBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - const void* vy, const Nd4jLong* yShapeInfo, const Nd4jLong* yTadOffsets, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, - void* vreducBuff, const double clipNormVal) { - - if(xTadOffsets == nullptr) { // means whole array - clipByNormBPWholeArrCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vreducBuff, static_cast(clipNormVal)); - clipByNormBPCalcGradCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vreducBuff, static_cast(clipNormVal)); - } - else // means tads using - clipByNormBPTadsCuda<<>>(vx, xShapeInfo, xTadOffsets, vy, yShapeInfo, yTadOffsets, vz, zShapeInfo, zTadOffsets, static_cast(clipNormVal)); +template +static void clipByNormBPCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, const void* vy, const Nd4jLong* yShapeInfo, + const Nd4jLong* yTadOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, void* vreducBuff, const double clipNormVal) { + if (xTadOffsets == nullptr) { // means whole array + clipByNormBPWholeArrCuda + <<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vreducBuff, + static_cast(clipNormVal)); + clipByNormBPCalcGradCuda + <<>>( + vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vreducBuff, + static_cast(clipNormVal)); + } else // means tads using + clipByNormBPTadsCuda + <<>>( + vx, xShapeInfo, xTadOffsets, vy, yShapeInfo, yTadOffsets, vz, + zShapeInfo, zTadOffsets, static_cast(clipNormVal)); } -BUILD_DOUBLE_TEMPLATE(template void clipByNormBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* xTadOffsets, const void *vy, const Nd4jLong *yShapeInfo, const Nd4jLong* yTadOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, void* vreducBuff, const double clipNormVal), FLOAT_TYPES, FLOAT_TYPES); +BUILD_DOUBLE_TEMPLATE( + template void clipByNormBPCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const Nd4jLong* xTadOffsets, const void* vy, const Nd4jLong* yShapeInfo, + const Nd4jLong* yTadOffsets, void* vz, const Nd4jLong* zShapeInfo, + const Nd4jLong* zTadOffsets, void* vreducBuff, const double clipNormVal), + FLOAT_TYPES, FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// -void clipByNormBP(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector& dimensions, const NDArray& clipNorm) { - - PointersManager manager(context, "clipByNormBP"); - - const double clipNormVal = clipNorm.e(0); - - const auto xType = input.dataType(); - const auto zType = gradI.dataType(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int sharedMem = threadsPerBlock * 2 * input.sizeOfT() + 128; - - NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); - - - if(dimensions.empty() || dimensions.size() == input.rankOf()) { // means whole array - - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), nullptr, gradO.specialBuffer(), gradO.specialShapeInfo(), nullptr, gradI.specialBuffer(), gradI.specialShapeInfo(), nullptr, context->getReductionPointer(), clipNormVal), FLOAT_TYPES, FLOAT_TYPES); - } - else { // means tads using - - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions); - auto packY = ConstantTadHelper::getInstance()->tadForDimensions(gradO.shapeInfo(), dimensions); - auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(gradI.shapeInfo(), dimensions); +void clipByNormBP(sd::LaunchContext* context, const NDArray& input, + const NDArray& gradO, NDArray& gradI /*output*/, + const std::vector& dimensions, const NDArray& clipNorm) { + PointersManager manager(context, "clipByNormBP"); + + const double clipNormVal = clipNorm.e(0); + + const auto xType = input.dataType(); + const auto zType = gradI.dataType(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int sharedMem = threadsPerBlock * 2 * input.sizeOfT() + 128; + + NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); + + if (dimensions.empty() || + dimensions.size() == input.rankOf()) { // means whole array + + const int blocksPerGrid = + (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + BUILD_DOUBLE_SELECTOR( + xType, zType, clipByNormBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), input.specialShapeInfo(), nullptr, + gradO.specialBuffer(), gradO.specialShapeInfo(), nullptr, + gradI.specialBuffer(), gradI.specialShapeInfo(), nullptr, + context->getReductionPointer(), clipNormVal), + FLOAT_TYPES, FLOAT_TYPES); + } else { // means tads using + + auto packX = ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), dimensions); + auto packY = ConstantTadHelper::getInstance()->tadForDimensions( + gradO.shapeInfo(), dimensions); + auto packZ = ConstantTadHelper::getInstance()->tadForDimensions( + gradI.shapeInfo(), dimensions); + + const int blocksPerGrid = packX.numberOfTads(); + BUILD_DOUBLE_SELECTOR( + xType, zType, clipByNormBPCudaLauncher, + (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), + input.specialBuffer(), packX.platformShapeInfo(), + packX.platformOffsets(), gradO.specialBuffer(), + packY.platformShapeInfo(), packY.platformOffsets(), + gradI.specialBuffer(), packZ.platformShapeInfo(), + packZ.platformOffsets(), nullptr, clipNormVal), + FLOAT_TYPES, FLOAT_TYPES); + } + + NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); + + manager.synchronize(); +} - const int blocksPerGrid = packX.numberOfTads(); - BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), gradO.specialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), gradI.specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), nullptr, clipNormVal), FLOAT_TYPES, FLOAT_TYPES); +template +static __global__ void swapShuffleKernel(T* input, Nd4jLong const* shape, + Nd4jLong firstDim, + sd::graph::RandomGenerator* rng) { + auto tid = blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) { + int r = rng->relativeInt(i) % i; + if (i != r) { + const auto iOffset = shape::getIndexOffset(i, shape); + const auto rOffset = shape::getIndexOffset(r, shape); + T e0 = input[iOffset]; + T e1 = input[rOffset]; + // math::nd4j_swap(input(i), input(r)); + input[iOffset] = e1; + input[rOffset] = e0; } - - NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); - - manager.synchronize(); + } } - - template - static __global__ void swapShuffleKernel(T* input, Nd4jLong const* shape, Nd4jLong firstDim, sd::graph::RandomGenerator* rng) { - auto tid = blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for (int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) { - int r = rng->relativeInt(i) % i; - if (i != r) { - const auto iOffset = shape::getIndexOffset(i, shape); - const auto rOffset = shape::getIndexOffset(r, shape); - T e0 = input[iOffset]; - T e1 = input[rOffset]; - //math::nd4j_swap(input(i), input(r)); - input[iOffset] = e1; - input[rOffset] = e0; - } - } +template +static __global__ void fillShuffleKernel(T* input, Nd4jLong const* inputShape, + T* output, Nd4jLong const* outputShape, + Nd4jLong firstDim, int* indices, + sd::graph::RandomGenerator* rng) { + // PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > + // Environment::getInstance()->tadThreshold()) + auto tid = blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) { + int r = rng->relativeInt(i) % i; + output[shape::getIndexOffset(i, outputShape)] = + input[shape::getIndexOffset(indices[r], inputShape)]; + if (i != r) { + output[shape::getIndexOffset(r, outputShape)] = + input[shape::getIndexOffset(indices[i], inputShape)]; + // output.p(r, input.e(indices[i])); + // math::nd4j_swap(indices[i], indices[r]); + atomicExch(&indices[i], indices[r]); } - template - static __global__ void fillShuffleKernel(T* input, Nd4jLong const* inputShape, T* output, Nd4jLong const* outputShape, Nd4jLong firstDim, int* indices, sd::graph::RandomGenerator* rng) { - -// PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold()) - auto tid = blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for(int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) { - int r = rng->relativeInt(i) % i; - output[shape::getIndexOffset(i, outputShape)] = input[shape::getIndexOffset(indices[r], inputShape)]; - if(i != r) { - output[shape::getIndexOffset(r, outputShape)] = input[shape::getIndexOffset(indices[i], inputShape)]; -// output.p(r, input.e(indices[i])); -// math::nd4j_swap(indices[i], indices[r]); - atomicExch(&indices[i], indices[r]); - } - } - + } +} +////////////////////////////////////////////////////////////////////////// +template +void randomShuffle_(sd::LaunchContext* context, NDArray& input, NDArray& output, + sd::graph::RandomGenerator& rng, const bool isInplace) { + // check edge cases first + int temp; + const int firstDim = input.sizeAt(0); + auto stream = context->getCudaStream(); + NDArray::prepareSpecialUse({&output}, {&input}); + if (input.lengthOf() == 1 || firstDim == 1) { + if (!isInplace) output.assign(input); + } else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) { + // apply Fisher-Yates shuffle + sd::graph::RandomGenerator* dRandom = nullptr; + cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator)); + cudaMemcpy(dRandom, &rng, sizeof(sd::graph::RandomGenerator), + cudaMemcpyHostToDevice); + T* inputBuf = reinterpret_cast(input.specialBuffer()); + if (isInplace) { + swapShuffleKernel<<<128, 256, 1024, *stream>>>( + inputBuf, input.specialShapeInfo(), firstDim, dRandom); + } else { + std::vector indices(firstDim); + std::iota(indices.begin(), indices.end(), 0); + cudaMemcpy(output.specialBuffer(), input.specialBuffer(), sizeof(T), + cudaMemcpyDeviceToDevice); + // output.p(Nd4jLong(0), input.e(0)); + PointersManager pointersManager(context, "helper::randomShuffle_"); + int* indicesDev = reinterpret_cast(pointersManager.replicatePointer( + indices.data(), indices.size() * sizeof(int))); + T* outputBuf = reinterpret_cast(output.specialBuffer()); + fillShuffleKernel<<<128, 256, 1024, *stream>>>( + inputBuf, input.specialShapeInfo(), outputBuf, + output.specialShapeInfo(), firstDim, indicesDev, dRandom); + pointersManager.synchronize(); } - ////////////////////////////////////////////////////////////////////////// - template - void randomShuffle_(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) { - - // check edge cases first - int temp; - const int firstDim = input.sizeAt(0); - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({&output}, {&input}); - if(input.lengthOf() == 1 || firstDim == 1) { - if(!isInplace) - output.assign(input); - } - else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) { - - // apply Fisher-Yates shuffle - sd::graph::RandomGenerator* dRandom = nullptr; - cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator)); - cudaMemcpy(dRandom, &rng, sizeof(sd::graph::RandomGenerator), cudaMemcpyHostToDevice); - T* inputBuf = reinterpret_cast(input.specialBuffer()); - if(isInplace) { - swapShuffleKernel<<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), firstDim, dRandom); - } - else { - std::vector indices(firstDim); - std::iota(indices.begin(), indices.end(), 0); - cudaMemcpy(output.specialBuffer(), input.specialBuffer(), sizeof(T), cudaMemcpyDeviceToDevice); - //output.p(Nd4jLong(0), input.e(0)); - PointersManager pointersManager(context, "helper::randomShuffle_"); - int* indicesDev = reinterpret_cast(pointersManager.replicatePointer(indices.data(), indices.size() * sizeof(int))); - T* outputBuf = reinterpret_cast(output.specialBuffer()); - fillShuffleKernel<<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), outputBuf, output.specialShapeInfo(), firstDim, indicesDev, dRandom); - pointersManager.synchronize(); - } -// rng.rewindH(firstDim - 1); - cudaFree(dRandom); - } - else { - - // evaluate sub-arrays list of input array through all dimensions excluding first one - std::vector dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0}); - auto subArrsListIn = input.allTensorsAlongDimension(dimensions); - - // apply Fisher-Yates shuffle - if(isInplace) { - for(int i = firstDim - 1; i > 0; --i) { - int r = rng.relativeInt(i) % i; - - if(i != r) - subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r)); - } - } - else { - // evaluate sub-arrays list of output array through all dimensions excluding first one - auto subArrsListOut = output.allTensorsAlongDimension(dimensions); - std::vector indices(firstDim); - std::iota(indices.begin(), indices.end(), 0); - bool isZeroShuffled = false; - - for(int i = firstDim - 1; i > 0; --i) { - int r = rng.relativeInt(i) % i; - subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r])); - if(r == 0) - isZeroShuffled = true; - - if(i != r) { - subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i])); - math::nd4j_swap(indices[i], indices[r]); - } - } - if(!isZeroShuffled) - subArrsListOut.at(0)->assign(subArrsListIn.at(0)); - } - rng.rewindH(firstDim-1); + // rng.rewindH(firstDim - 1); + cudaFree(dRandom); + } else { + // evaluate sub-arrays list of input array through all dimensions excluding + // first one + std::vector dimensions = + ShapeUtils::evalDimsToExclude(input.rankOf(), {0}); + auto subArrsListIn = input.allTensorsAlongDimension(dimensions); + + // apply Fisher-Yates shuffle + if (isInplace) { + for (int i = firstDim - 1; i > 0; --i) { + int r = rng.relativeInt(i) % i; + + if (i != r) subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r)); + } + } else { + // evaluate sub-arrays list of output array through all dimensions + // excluding first one + auto subArrsListOut = output.allTensorsAlongDimension(dimensions); + std::vector indices(firstDim); + std::iota(indices.begin(), indices.end(), 0); + bool isZeroShuffled = false; + + for (int i = firstDim - 1; i > 0; --i) { + int r = rng.relativeInt(i) % i; + subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r])); + if (r == 0) isZeroShuffled = true; + + if (i != r) { + subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i])); + math::nd4j_swap(indices[i], indices[r]); } - NDArray::registerSpecialUse({&output}, {&input}); - - } - - void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) { - BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES); + } + if (!isZeroShuffled) subArrsListOut.at(0)->assign(subArrsListIn.at(0)); } + rng.rewindH(firstDim - 1); + } + NDArray::registerSpecialUse({&output}, {&input}); +} - BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES); +void randomShuffle(sd::LaunchContext* context, NDArray& input, NDArray& output, + sd::graph::RandomGenerator& rng, const bool isInplace) { + BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, + (context, input, output, rng, isInplace), + LIBND4J_TYPES); +} +BUILD_SINGLE_TEMPLATE(template void randomShuffle_, + (sd::LaunchContext * context, NDArray& input, + NDArray& output, sd::graph::RandomGenerator& rng, + const bool isInplace), + LIBND4J_TYPES); - ////////////////////////////////////////////////////////////////////////// - void eye(sd::LaunchContext * context, NDArray& output) { +////////////////////////////////////////////////////////////////////////// +void eye(sd::LaunchContext* context, NDArray& output) { output.setIdentity(); } + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static __global__ void clipByNormInplaceKernel( + Nd4jLong numOfSubArrs, T* inputBuffer, Nd4jLong const* shape, + Nd4jLong const* inputOffsets, T* norm2Buf, Nd4jLong const* norm2shape, + T clipNorm) { + for (int arr = blockIdx.x; arr < numOfSubArrs; arr += gridDim.x) { + __shared__ T* z; + __shared__ Nd4jLong len; + if (threadIdx.x == 0) { + len = shape::length(shape); + z = inputBuffer + inputOffsets[arr]; + } + __syncthreads(); + for (int j = threadIdx.x; j < len; j += blockDim.x) { + auto xIndex = shape::getIndexOffset(j, shape); - output.setIdentity(); + if (norm2Buf[arr] > clipNorm) + z[xIndex] *= + clipNorm / norm2Buf[arr]; // case with ews = 1 and ordering is 'c' } + } +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static __global__ void clipByNormKernel( + Nd4jLong numOfSubArrs, T* inputBuffer, Nd4jLong const* shape, + Nd4jLong const* inputOffsets, T* outputBuffer, Nd4jLong const* outputShape, + Nd4jLong const* outputOffsets, T* norm2Buf, Nd4jLong const* norm2shape, + T clipNorm) { + for (Nd4jLong arr = blockIdx.x; arr < numOfSubArrs; arr += gridDim.x) { + __shared__ T *x, *z; + __shared__ Nd4jLong lenZ; + __shared__ T norm2; - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static __global__ void clipByNormInplaceKernel(Nd4jLong numOfSubArrs, T* inputBuffer, Nd4jLong const* shape, Nd4jLong const* inputOffsets, T* norm2Buf, Nd4jLong const* norm2shape, T clipNorm) { - for (int arr = blockIdx.x; arr < numOfSubArrs; arr += gridDim.x) { - __shared__ T* z; - __shared__ Nd4jLong len; - if (threadIdx.x == 0) { - len = shape::length(shape); - z = inputBuffer + inputOffsets[arr]; - } - __syncthreads(); - for (int j = threadIdx.x; j < len; j+= blockDim.x) { - auto xIndex = shape::getIndexOffset(j, shape); - - if(norm2Buf[arr] > clipNorm) - z[xIndex] *= clipNorm / norm2Buf[arr]; // case with ews = 1 and ordering is 'c' - } - } + if (threadIdx.x == 0) { + x = inputBuffer + inputOffsets[arr]; + z = outputBuffer + outputOffsets[arr]; + lenZ = shape::length(outputShape); + norm2 = norm2Buf[shape::getIndexOffset(arr, norm2shape)]; } - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static __global__ void clipByNormKernel(Nd4jLong numOfSubArrs, T* inputBuffer, Nd4jLong const* shape, Nd4jLong const* inputOffsets, T* outputBuffer, Nd4jLong const* outputShape, Nd4jLong const* outputOffsets, T* norm2Buf, Nd4jLong const* norm2shape, T clipNorm) { - - for (Nd4jLong arr = blockIdx.x; arr < numOfSubArrs; arr += gridDim.x) { - __shared__ T* x, *z; - __shared__ Nd4jLong lenZ; - __shared__ T norm2; - - if (threadIdx.x == 0) { - x = inputBuffer + inputOffsets[arr]; - z = outputBuffer + outputOffsets[arr]; - lenZ = shape::length(outputShape); - norm2 = norm2Buf[shape::getIndexOffset(arr, norm2shape)]; - } - __syncthreads(); - for (Nd4jLong j = threadIdx.x; j < lenZ; j+= blockDim.x) { - auto xIndex = shape::getIndexOffset(j, shape); - auto zIndex = shape::getIndexOffset(j, outputShape); - if(norm2 > clipNorm) { - z[zIndex] = x[xIndex] * clipNorm / norm2; // case with ews = 1 and ordering is 'c' - } else { - z[zIndex] = x[xIndex]; - } - //printf("%lld: %lf %lf\n", j, z[zIndex], x[xIndex]); - } - __syncthreads(); - } + __syncthreads(); + for (Nd4jLong j = threadIdx.x; j < lenZ; j += blockDim.x) { + auto xIndex = shape::getIndexOffset(j, shape); + auto zIndex = shape::getIndexOffset(j, outputShape); + if (norm2 > clipNorm) { + z[zIndex] = x[xIndex] * clipNorm / + norm2; // case with ews = 1 and ordering is 'c' + } else { + z[zIndex] = x[xIndex]; + } + // printf("%lld: %lf %lf\n", j, z[zIndex], x[xIndex]); } + __syncthreads(); + } +} - ////////////////////////////////////////////////////////////////////////// - template - static void clipByNorm_(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, NDArray const& clipNormA, const bool isInplace) { - const int rank = input.rankOf(); - auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); - clipNormA.syncToHost(); - //norm2.printBuffer("Norm2"); - T const clipNorm = clipNormA.e(0); - //clipNormA.printBuffer("ClipNorm"); - auto stream = context->getCudaStream(); - if (isInplace) { - if(norm2.lengthOf() == 1) { - norm2.syncToHost(); - T norm2Val = norm2.e(0); - if(norm2Val > clipNorm) - input *= clipNorm / norm2Val; - } - else { - - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(rank, dimensions); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.shapeInfo(), dimsToExclude); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions); - //auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), dimsToExclude); - T* inputBuffer = reinterpret_cast(input.specialBuffer()); - T* norm2buf = reinterpret_cast(norm2.specialBuffer()); - - clipByNormInplaceKernel<<<256, 512, 1024, *stream>>>(numOfSubArrs, inputBuffer, packX.specialShapeInfo(), packX.specialOffsets(), norm2buf, norm2.specialShapeInfo(), clipNorm); - } - } - else { - - if(norm2.lengthOf() == 1) { - norm2.syncToHost(); - T norm2Val = norm2.e(0); - - if(norm2Val > clipNorm) - output.assign( input * (clipNorm / norm2Val)); - else - output.assign( input ); - } - else { - - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(rank, dimensions); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.shapeInfo(), dimsToExclude); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), dimensions); - T* inputBuffer = reinterpret_cast(input.specialBuffer()); - T* norm2buf = reinterpret_cast(norm2.specialBuffer()); - T* outputBuffer = reinterpret_cast(output.specialBuffer()); - - clipByNormKernel<<<256, 512, 1024, *stream>>>(numOfSubArrs, inputBuffer, packX.specialShapeInfo(), packX.specialOffsets(), outputBuffer, packZ.specialShapeInfo(), packZ.specialOffsets(), norm2buf, norm2.specialShapeInfo(), clipNorm); - } - } +////////////////////////////////////////////////////////////////////////// +template +static void clipByNorm_(sd::LaunchContext* context, NDArray& input, + NDArray& output, const std::vector& dimensions, + NDArray const& clipNormA, const bool isInplace) { + const int rank = input.rankOf(); + auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); + clipNormA.syncToHost(); + // norm2.printBuffer("Norm2"); + T const clipNorm = clipNormA.e(0); + // clipNormA.printBuffer("ClipNorm"); + auto stream = context->getCudaStream(); + if (isInplace) { + if (norm2.lengthOf() == 1) { + norm2.syncToHost(); + T norm2Val = norm2.e(0); + if (norm2Val > clipNorm) input *= clipNorm / norm2Val; + } else { + std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(rank, dimensions); + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(input.shapeInfo(), dimsToExclude); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), dimensions); + // auto packZ = + // sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), + // dimsToExclude); + T* inputBuffer = reinterpret_cast(input.specialBuffer()); + T* norm2buf = reinterpret_cast(norm2.specialBuffer()); + + clipByNormInplaceKernel<<<256, 512, 1024, *stream>>>( + numOfSubArrs, inputBuffer, packX.specialShapeInfo(), + packX.specialOffsets(), norm2buf, norm2.specialShapeInfo(), clipNorm); } - - void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { - BUILD_SINGLE_SELECTOR(output.dataType(), clipByNorm_, (context, input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES); + } else { + if (norm2.lengthOf() == 1) { + norm2.syncToHost(); + T norm2Val = norm2.e(0); + + if (norm2Val > clipNorm) + output.assign(input * (clipNorm / norm2Val)); + else + output.assign(input); + } else { + std::vector dimsToExclude = + ShapeUtils::evalDimsToExclude(rank, dimensions); + const Nd4jLong numOfSubArrs = + ShapeUtils::getNumOfSubArrs(input.shapeInfo(), dimsToExclude); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), dimensions); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + output.shapeInfo(), dimensions); + T* inputBuffer = reinterpret_cast(input.specialBuffer()); + T* norm2buf = reinterpret_cast(norm2.specialBuffer()); + T* outputBuffer = reinterpret_cast(output.specialBuffer()); + + clipByNormKernel<<<256, 512, 1024, *stream>>>( + numOfSubArrs, inputBuffer, packX.specialShapeInfo(), + packX.specialOffsets(), outputBuffer, packZ.specialShapeInfo(), + packZ.specialOffsets(), norm2buf, norm2.specialShapeInfo(), clipNorm); } + } +} - BUILD_SINGLE_TEMPLATE(template void clipByNorm_, (sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES); - - template - void clipByGlobalNorm_(sd::LaunchContext * context, std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace) { - NDArray globalNorm = NDArrayFactory::create(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list])) - - for (auto i = 0; i < inputs.size(); i++) { - auto input = inputs[i]; - auto l2norm = input->reduceNumber(reduce::Norm2); - globalNorm += l2norm * l2norm; - } - - globalNorm.applyTransform(transform::Sqrt, globalNorm); // = sd::math::nd4j_sqrt(globalNorm); - outputs[inputs.size()]->p(0, globalNorm); - globalNorm.syncToHost(); - const T factor = static_cast(clipNorm) / globalNorm.e(0); +void clipByNorm(sd::LaunchContext* context, NDArray& input, NDArray& output, + const std::vector& dimensions, const NDArray& clipNorm, + const bool isInplace) { + BUILD_SINGLE_SELECTOR( + output.dataType(), clipByNorm_, + (context, input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES); +} - for (size_t e = 0; e < inputs.size(); e++) { - // all-reduce - auto input = inputs[e]; - auto output = outputs[e]; +BUILD_SINGLE_TEMPLATE(template void clipByNorm_, + (sd::LaunchContext * context, NDArray& input, + NDArray& output, const std::vector& dimensions, + const NDArray& clipNorm, const bool isInplace), + FLOAT_TYPES); + +template +void clipByGlobalNorm_(sd::LaunchContext* context, + std::vector const& inputs, double clipNorm, + sd::memory::Workspace* workspace, + std::vector& outputs, bool isInplace) { + NDArray globalNorm = NDArrayFactory::create( + 0, inputs[0]->getContext()); // sqrt(sum([l2norm(t)**2 for t in t_list])) + + for (auto i = 0; i < inputs.size(); i++) { + auto input = inputs[i]; + auto l2norm = input->reduceNumber(reduce::Norm2); + globalNorm += l2norm * l2norm; + } + + globalNorm.applyTransform(transform::Sqrt, + globalNorm); // = sd::math::nd4j_sqrt(globalNorm); + outputs[inputs.size()]->p(0, globalNorm); + globalNorm.syncToHost(); + const T factor = static_cast(clipNorm) / globalNorm.e(0); + + for (size_t e = 0; e < inputs.size(); e++) { + // all-reduce + auto input = inputs[e]; + auto output = outputs[e]; + + if (globalNorm.e(0) <= clipNorm) { + output->assign(input); + } else { + auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; + input->applyLambda(lambda, *output); + } + } +} - if (globalNorm.e(0) <= clipNorm) { - output->assign(input); - } - else { +void clipByGlobalNorm(sd::LaunchContext* context, + std::vector const& inputs, double clipNorm, + sd::memory::Workspace* workspace, + std::vector& outputs, bool isInplace) { + BUILD_SINGLE_SELECTOR( + outputs[0]->dataType(), clipByGlobalNorm_, + (context, inputs, clipNorm, workspace, outputs, isInplace), FLOAT_TYPES); +} - auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; - input->applyLambda(lambda, *output); - } - } - } +BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, + (sd::LaunchContext * context, + std::vector const& inputs, double clipNorm, + sd::memory::Workspace* workspace, + std::vector& outputs, bool isInplace), + FLOAT_TYPES); - void clipByGlobalNorm(sd::LaunchContext * context, std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace) { - BUILD_SINGLE_SELECTOR(outputs[0]->dataType(), clipByGlobalNorm_, (context, inputs, clipNorm, workspace, outputs, isInplace), FLOAT_TYPES); +////////////////////////////////////////////////////////////////////////// +template +static void clipByAveraged_(sd::LaunchContext* context, NDArray& input, + NDArray& output, const std::vector& dimensions, + const NDArray& clipNorm, const bool isInplace) { + auto cn = clipNorm.e(0); + if (dimensions.size() == 0) { + // all-reduce + T n2 = input.reduceNumber(reduce::Norm2).e(0) / + static_cast(input.lengthOf()); + if (n2 <= cn) { + if (!isInplace) output.assign(input); + } else { + const T factor = cn / n2; + // auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; + // input.applyLambda(lambda, output); + output.assign(input * factor); } - - BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (sd::LaunchContext * context, std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace), FLOAT_TYPES); - - - ////////////////////////////////////////////////////////////////////////// - template - static void clipByAveraged_(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { - auto cn = clipNorm.e(0); - if (dimensions.size() == 0) { - // all-reduce - T n2 = input.reduceNumber(reduce::Norm2).e(0) / static_cast(input.lengthOf()); - if (n2 <= cn) { - if (!isInplace) - output.assign(input); - } - else { - const T factor = cn / n2; - //auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; - //input.applyLambda(lambda, output); - output.assign(input * factor); - } - } - else { - // along dimension - auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false); - if (!isInplace) - output.assign(input); - auto tads = output.allTensorsAlongDimension(dimensions); - auto outTads = output.allTensorsAlongDimension(dimensions); - // TODO: make this CUDA-compliant somehow - for (int e = 0; e < tads.size(); e++) { - T n2 = norm2.e(e) / static_cast(tads.at(e)->lengthOf()); - const T factor = cn / n2; - if (n2 > cn) { - //auto lambda = LAMBDA_T(_x, factor) {return _x * factor;}; - tads.at(e)->applyScalar(scalar::Multiply, factor, *outTads.at(e));//applyLambda(lambda, &output); - } - } - } + } else { + // along dimension + auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false); + if (!isInplace) output.assign(input); + auto tads = output.allTensorsAlongDimension(dimensions); + auto outTads = output.allTensorsAlongDimension(dimensions); + // TODO: make this CUDA-compliant somehow + for (int e = 0; e < tads.size(); e++) { + T n2 = norm2.e(e) / static_cast(tads.at(e)->lengthOf()); + const T factor = cn / n2; + if (n2 > cn) { + // auto lambda = LAMBDA_T(_x, factor) {return _x * factor;}; + tads.at(e)->applyScalar( + scalar::Multiply, factor, + *outTads.at(e)); // applyLambda(lambda, &output); + } } + } +} - void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { - BUILD_SINGLE_SELECTOR(input.dataType(), clipByAveraged_, (context, input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES); - } +void clipByAveraged(sd::LaunchContext* context, NDArray& input, NDArray& output, + const std::vector& dimensions, const NDArray& clipNorm, + const bool isInplace) { + BUILD_SINGLE_SELECTOR( + input.dataType(), clipByAveraged_, + (context, input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES); +} - BUILD_SINGLE_TEMPLATE(template void clipByAveraged_, (sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template void clipByAveraged_, + (sd::LaunchContext * context, NDArray& input, + NDArray& output, const std::vector& dimensions, + const NDArray& clipNorm, const bool isInplace), + FLOAT_TYPES); /* if (d1 > params[1]) @@ -890,55 +1021,73 @@ void clipByNormBP(sd::LaunchContext* context, const NDArray& input, const NDArra return params[0]; else return d1; */ - template - static void __global__ clipByValueKernel(void* input, Nd4jLong const* inputShape, void* output, Nd4jLong const* outputShape, double leftBound, double rightBound) { - __shared__ T* outputBuf; - __shared__ T* inputBuf; - __shared__ Nd4jLong length; - __shared__ bool linearBuffers; - if (threadIdx.x == 0) { - outputBuf = reinterpret_cast(output); - inputBuf = reinterpret_cast(input); - length = shape::length(inputShape); - linearBuffers = shape::elementWiseStride(inputShape) == shape::elementWiseStride(outputShape) && shape::elementWiseStride(inputShape) == 1; - } - __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - - for (Nd4jLong e = tid; e < length; e += step) { - if (linearBuffers) { - if (inputBuf[e] > rightBound) outputBuf[e] = (T) rightBound; - else if (inputBuf[e] < leftBound) outputBuf[e] = (T) leftBound; - else outputBuf[e] = inputBuf[e]; - } - else { - auto inputOffset = shape::getIndexOffset(e, inputShape); - auto outputOffset = shape::getIndexOffset(e, outputShape); - if (inputBuf[inputOffset] > rightBound) outputBuf[outputOffset] = (T) rightBound; - else if (inputBuf[inputOffset] < leftBound) outputBuf[outputOffset] = (T) leftBound; - else outputBuf[outputOffset] = inputBuf[outputOffset]; - } - } +template +static void __global__ clipByValueKernel(void* input, + Nd4jLong const* inputShape, + void* output, + Nd4jLong const* outputShape, + double leftBound, double rightBound) { + __shared__ T* outputBuf; + __shared__ T* inputBuf; + __shared__ Nd4jLong length; + __shared__ bool linearBuffers; + if (threadIdx.x == 0) { + outputBuf = reinterpret_cast(output); + inputBuf = reinterpret_cast(input); + length = shape::length(inputShape); + linearBuffers = shape::elementWiseStride(inputShape) == + shape::elementWiseStride(outputShape) && + shape::elementWiseStride(inputShape) == 1; + } + __syncthreads(); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + for (Nd4jLong e = tid; e < length; e += step) { + if (linearBuffers) { + if (inputBuf[e] > rightBound) + outputBuf[e] = (T)rightBound; + else if (inputBuf[e] < leftBound) + outputBuf[e] = (T)leftBound; + else + outputBuf[e] = inputBuf[e]; + } else { + auto inputOffset = shape::getIndexOffset(e, inputShape); + auto outputOffset = shape::getIndexOffset(e, outputShape); + if (inputBuf[inputOffset] > rightBound) + outputBuf[outputOffset] = (T)rightBound; + else if (inputBuf[inputOffset] < leftBound) + outputBuf[outputOffset] = (T)leftBound; + else + outputBuf[outputOffset] = inputBuf[outputOffset]; } - - template - static void clipByValue_(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) { - auto stream = context->getCudaStream(); - if (!input.isActualOnDeviceSide()) - input.syncToDevice(); - NDArray::prepareSpecialUse({&output}, {&input}); - clipByValueKernel<<<256, 512, 8192, *stream>>>(input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftBound, rightBound); - NDArray::registerSpecialUse({&output}, {&input}); - } - - void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) { - BUILD_SINGLE_SELECTOR(input.dataType(), clipByValue_, (context, input, leftBound, rightBound, output), FLOAT_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void clipByValue_, (sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output);, FLOAT_TYPES); - + } } + +template +static void clipByValue_(sd::LaunchContext* context, NDArray& input, + double leftBound, double rightBound, NDArray& output) { + auto stream = context->getCudaStream(); + if (!input.isActualOnDeviceSide()) input.syncToDevice(); + NDArray::prepareSpecialUse({&output}, {&input}); + clipByValueKernel<<<256, 512, 8192, *stream>>>( + input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), + output.specialShapeInfo(), leftBound, rightBound); + NDArray::registerSpecialUse({&output}, {&input}); } + +void clipByValue(sd::LaunchContext* context, NDArray& input, double leftBound, + double rightBound, NDArray& output) { + BUILD_SINGLE_SELECTOR(input.dataType(), clipByValue_, + (context, input, leftBound, rightBound, output), + FLOAT_TYPES); } +BUILD_SINGLE_TEMPLATE(template void clipByValue_, + (sd::LaunchContext * context, NDArray& input, + double leftBound, double rightBound, NDArray& output); + , FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu b/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu index c8f26de6f209..091e361cc845 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu @@ -18,217 +18,241 @@ // @author GS // -#include #include #include #include +#include + #include "../triangular_solve.h" namespace sd { - namespace ops { - namespace helpers { - /* - * lower triangular process for system of linear equations - * x_1 = b_1/a_1,1 - * x_2 = (b_2 - a_2,1 * x_1) / a_2,2 - * x_3 = (b_3 - a_3,1 * x_1 - a_3,2 * x_2) / a_3,3 - * ... - * x_M = (b_M - a_M,1 * x_1 - ... a_M,M-1 * x_M-1)/ a_M,M - * - * output == x - * a == leftInput - * b == rightInput - * - * */ - template - static __device__ void lowerTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape, - T const* rightInput, Nd4jLong const* rightInputShape, - bool const adjoint, T* output, Nd4jLong const* outputShape, - Nd4jLong rows, Nd4jLong cols) { - - for (auto r = 0; r < rows; r++) { - for (auto j = 0; j < cols; j++) { - Nd4jLong posY[] = {r, j}; - Nd4jLong posX[] = {r, r}; - auto xIndex = shape::getOffset(leftInputShape, posX, 0); - auto yIndex = shape::getOffset(rightInputShape, posY, 0); - auto zIndex = shape::getOffset(outputShape, posY, 0); - - auto sum = rightInput[yIndex]; - for (auto c = 0; c < r; c++) { - Nd4jLong posZ[] = {c, j}; - Nd4jLong pos[] = {r, c}; - auto xcIndex = shape::getOffset(leftInputShape, pos, 0); - auto zcIndex = shape::getOffset(outputShape, posZ, 0); - sum -= leftInput[xcIndex] * output[zcIndex]; - } - output[zIndex] = sum / leftInput[xIndex]; - } - } - } - - /* - * upper triangular process for system of linear equations - * x_M = b_M/a_M,M - * x_M-1 = (b_M-1 - a_M-1,M-2 * x_M) / a_M-1,M-1 - * x_M-2 = (b_M-2 - a_M-2,M-3 * x_M-2 - a_M-2,M-1 * x_M) / a_3,3 - * ... - * x_1 = (b_1 - a_1,2 * x_2 - ... a_1,M * x_M)/ a_1,1 - * - * output == x - * a == leftInput - * b == rightInput - * - * */ - - template - static __device__ void upperTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape, - T const* rightInput, Nd4jLong const* rightInputShape, bool const adjoint, T* output, - Nd4jLong const* outputShape, Nd4jLong rows, Nd4jLong cols) { - - for (auto r = rows; r > 0; r--) { - for (auto j = 0; j < cols; j++) { - Nd4jLong posY[] = {r - 1, j}; - Nd4jLong posX[] = {r - 1, r - 1}; - auto xIndex = shape::getOffset(leftInputShape, posX, 0); - auto yIndex = shape::getOffset(rightInputShape, posY, 0); - auto zIndex = shape::getOffset(outputShape, posY, 0); - auto sum = rightInput[yIndex]; - for (auto c = r; c < rows; c++) { - Nd4jLong posZ[] = {c, j}; - Nd4jLong pos[] = {r - 1, c}; - auto zcIndex = shape::getOffset(outputShape, posZ, 0); - auto xcIndex = shape::getOffset(leftInputShape, pos, 0); - sum -= leftInput[xcIndex] * output[zcIndex]; - } - output[zIndex] = sum / leftInput[xIndex]; - } - } - } - - template - static __global__ void triangularSolveKernel(T const* leftInput, Nd4jLong const* leftPartShape, - T const* rightInput, Nd4jLong const* rightPartShape, bool const lower, bool const adjoint, T* output, - Nd4jLong const* outputShape, Nd4jLong const* tadLeftShape, Nd4jLong const* tadLeftOffset, Nd4jLong const* tadRightShape, - Nd4jLong const* tadRightOffset, Nd4jLong const* tadOutputShape, Nd4jLong const* tadOutputOffset, Nd4jLong batchNum) { - - __shared__ Nd4jLong rows; - __shared__ Nd4jLong cols; - - if (threadIdx.x == 0) { - rows = shape::sizeAt(leftPartShape, -2); - cols = shape::sizeAt(rightPartShape, -1); - } - __syncthreads(); - - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto stop = batchNum; - auto increment = blockDim.x * gridDim.x; - - for (auto i = start; i < stop; i += increment) { - auto pLeftPart = leftInput + tadLeftOffset[i]; - auto pRightPart = rightInput + tadRightOffset[i]; - auto pOutputPart = output + tadOutputOffset[i]; - if (lower) { - lowerTriangularSolve(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows, cols); - } else { - upperTriangularSolve(pLeftPart, tadLeftShape, pRightPart, tadRightShape, adjoint, pOutputPart, tadOutputShape, rows, cols); - } - } - } - - template - static int triangularSolveFunctor_(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, - bool lower, bool adjoint, NDArray* output) { - NDArray::prepareSpecialUse({output}, {leftInput, rightInput}); - auto leftTads = ConstantTadHelper::getInstance()->tadForDimensions(leftInput->shapeInfo(), {-2, -1}); - auto rightTads = ConstantTadHelper::getInstance()->tadForDimensions(rightInput->shapeInfo(), {-2, -1}); - auto outputTads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-2, -1}); - - auto stream = context->getCudaStream(); - T const* leftBuf = reinterpret_cast(leftInput->specialBuffer()); - T const* rightBuf = reinterpret_cast(rightInput->specialBuffer()); - T* outputBuf = reinterpret_cast(output->specialBuffer()); - triangularSolveKernel<<<128, 128, 256, *stream>>>(leftBuf, leftInput->specialShapeInfo(), - rightBuf, rightInput->specialShapeInfo(), lower, adjoint, outputBuf, output->specialShapeInfo(), - leftTads.specialShapeInfo(), leftTads.specialOffsets(), rightTads.specialShapeInfo(), - rightTads.specialOffsets(), outputTads.specialShapeInfo(), outputTads.specialOffsets(), - leftTads.numberOfTads()); - - NDArray::registerSpecialUse({output}, {leftInput, rightInput}); - - return Status::OK(); - - } - - int triangularSolveFunctor(sd::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) { - BUILD_SINGLE_SELECTOR(leftInput->dataType(), return triangularSolveFunctor_, (context, leftInput, rightInput, lower, adjoint, output), FLOAT_NATIVE); - } - - template - static __global__ void upperAdjointKernel(T const* input, T* output, - Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns, - Nd4jLong const* inputTads, Nd4jLong const* inputOffsets, Nd4jLong const* outputTads, Nd4jLong const* outputOffsets) { - - for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { - auto inputPart = input + inputOffsets[b]; - auto outputPart = output + outputOffsets[b]; - for (auto r = threadIdx.x; r < rows; r += blockDim.x) { - for (auto c = threadIdx.y; c <= r; c += blockDim.y) { - Nd4jLong zPos[] = {r, c}; - Nd4jLong xPos[] = {c, r}; - auto zIndex = shape::getOffset(outputTads, zPos); - auto xIndex = shape::getOffset(inputTads, xPos); - outputPart[zIndex] = inputPart[xIndex]; - } - } - } - - } - - template - static __global__ void lowerAdjointKernel(T const* input, T* output, - Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns, - Nd4jLong const* inputTads, Nd4jLong const* inputOffsets, Nd4jLong const* outputTads, Nd4jLong const* outputOffsets) { - - for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { - auto inputPart = input + inputOffsets[b]; - auto outputPart = output + outputOffsets[b]; - for (auto r = threadIdx.x; r < rows; r += blockDim.x) { - for (auto c = r + threadIdx.y; c < columns; c += blockDim.y) { - Nd4jLong zPos[] = {r, c}; - Nd4jLong xPos[] = {c, r}; - auto zIndex = shape::getOffset(outputTads, zPos); - auto xIndex = shape::getOffset(inputTads, xPos); - outputPart[zIndex] = inputPart[xIndex]; - } - } - } - } - - template - static void adjointTriangularMatrix_(sd::LaunchContext* context, NDArray const* input, bool const lower, - NDArray* output) { - - auto inputTads = ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), {-2, -1}); - auto outputTads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-2, -1}); - auto stream = context->getCudaStream(); - auto inputBuf = reinterpret_cast(input->specialBuffer()); - auto outputBuf = reinterpret_cast(output->specialBuffer()); - auto rows = input->sizeAt(-2); - auto columns = input->sizeAt(-1); - - if (lower) { - lowerAdjointKernel<<<128, 256, 256, *stream>>>(inputBuf, outputBuf, outputTads.numberOfTads(), rows, columns, inputTads.specialShapeInfo(), inputTads.specialOffsets(), outputTads.specialShapeInfo(), outputTads.specialOffsets()); - } else { - upperAdjointKernel<<<128, 256, 256, *stream>>>(inputBuf, outputBuf, outputTads.numberOfTads(), rows, columns, inputTads.specialShapeInfo(), inputTads.specialOffsets(), outputTads.specialShapeInfo(), outputTads.specialOffsets()); - } - } - - void adjointMatrix(sd::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), adjointTriangularMatrix_, (context, input, lower, output), FLOAT_NATIVE); - } - - } +namespace ops { +namespace helpers { +/* + * lower triangular process for system of linear equations + * x_1 = b_1/a_1,1 + * x_2 = (b_2 - a_2,1 * x_1) / a_2,2 + * x_3 = (b_3 - a_3,1 * x_1 - a_3,2 * x_2) / a_3,3 + * ... + * x_M = (b_M - a_M,1 * x_1 - ... a_M,M-1 * x_M-1)/ a_M,M + * + * output == x + * a == leftInput + * b == rightInput + * + * */ +template +static __device__ void lowerTriangularSolve( + T const* leftInput, Nd4jLong const* leftInputShape, T const* rightInput, + Nd4jLong const* rightInputShape, bool const adjoint, T* output, + Nd4jLong const* outputShape, Nd4jLong rows, Nd4jLong cols) { + for (auto r = 0; r < rows; r++) { + for (auto j = 0; j < cols; j++) { + Nd4jLong posY[] = {r, j}; + Nd4jLong posX[] = {r, r}; + auto xIndex = shape::getOffset(leftInputShape, posX, 0); + auto yIndex = shape::getOffset(rightInputShape, posY, 0); + auto zIndex = shape::getOffset(outputShape, posY, 0); + + auto sum = rightInput[yIndex]; + for (auto c = 0; c < r; c++) { + Nd4jLong posZ[] = {c, j}; + Nd4jLong pos[] = {r, c}; + auto xcIndex = shape::getOffset(leftInputShape, pos, 0); + auto zcIndex = shape::getOffset(outputShape, posZ, 0); + sum -= leftInput[xcIndex] * output[zcIndex]; + } + output[zIndex] = sum / leftInput[xIndex]; } + } } + +/* + * upper triangular process for system of linear equations + * x_M = b_M/a_M,M + * x_M-1 = (b_M-1 - a_M-1,M-2 * x_M) / a_M-1,M-1 + * x_M-2 = (b_M-2 - a_M-2,M-3 * x_M-2 - a_M-2,M-1 * x_M) / a_3,3 + * ... + * x_1 = (b_1 - a_1,2 * x_2 - ... a_1,M * x_M)/ a_1,1 + * + * output == x + * a == leftInput + * b == rightInput + * + * */ + +template +static __device__ void upperTriangularSolve( + T const* leftInput, Nd4jLong const* leftInputShape, T const* rightInput, + Nd4jLong const* rightInputShape, bool const adjoint, T* output, + Nd4jLong const* outputShape, Nd4jLong rows, Nd4jLong cols) { + for (auto r = rows; r > 0; r--) { + for (auto j = 0; j < cols; j++) { + Nd4jLong posY[] = {r - 1, j}; + Nd4jLong posX[] = {r - 1, r - 1}; + auto xIndex = shape::getOffset(leftInputShape, posX, 0); + auto yIndex = shape::getOffset(rightInputShape, posY, 0); + auto zIndex = shape::getOffset(outputShape, posY, 0); + auto sum = rightInput[yIndex]; + for (auto c = r; c < rows; c++) { + Nd4jLong posZ[] = {c, j}; + Nd4jLong pos[] = {r - 1, c}; + auto zcIndex = shape::getOffset(outputShape, posZ, 0); + auto xcIndex = shape::getOffset(leftInputShape, pos, 0); + sum -= leftInput[xcIndex] * output[zcIndex]; + } + output[zIndex] = sum / leftInput[xIndex]; + } + } +} + +template +static __global__ void triangularSolveKernel( + T const* leftInput, Nd4jLong const* leftPartShape, T const* rightInput, + Nd4jLong const* rightPartShape, bool const lower, bool const adjoint, + T* output, Nd4jLong const* outputShape, Nd4jLong const* tadLeftShape, + Nd4jLong const* tadLeftOffset, Nd4jLong const* tadRightShape, + Nd4jLong const* tadRightOffset, Nd4jLong const* tadOutputShape, + Nd4jLong const* tadOutputOffset, Nd4jLong batchNum) { + __shared__ Nd4jLong rows; + __shared__ Nd4jLong cols; + + if (threadIdx.x == 0) { + rows = shape::sizeAt(leftPartShape, -2); + cols = shape::sizeAt(rightPartShape, -1); + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto stop = batchNum; + auto increment = blockDim.x * gridDim.x; + + for (auto i = start; i < stop; i += increment) { + auto pLeftPart = leftInput + tadLeftOffset[i]; + auto pRightPart = rightInput + tadRightOffset[i]; + auto pOutputPart = output + tadOutputOffset[i]; + if (lower) { + lowerTriangularSolve(pLeftPart, tadLeftShape, pRightPart, + tadRightShape, adjoint, pOutputPart, + tadOutputShape, rows, cols); + } else { + upperTriangularSolve(pLeftPart, tadLeftShape, pRightPart, + tadRightShape, adjoint, pOutputPart, + tadOutputShape, rows, cols); + } + } +} + +template +static int triangularSolveFunctor_(sd::LaunchContext* context, + NDArray* leftInput, NDArray* rightInput, + bool lower, bool adjoint, NDArray* output) { + NDArray::prepareSpecialUse({output}, {leftInput, rightInput}); + auto leftTads = ConstantTadHelper::getInstance()->tadForDimensions( + leftInput->shapeInfo(), {-2, -1}); + auto rightTads = ConstantTadHelper::getInstance()->tadForDimensions( + rightInput->shapeInfo(), {-2, -1}); + auto outputTads = ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), {-2, -1}); + + auto stream = context->getCudaStream(); + T const* leftBuf = reinterpret_cast(leftInput->specialBuffer()); + T const* rightBuf = reinterpret_cast(rightInput->specialBuffer()); + T* outputBuf = reinterpret_cast(output->specialBuffer()); + triangularSolveKernel<<<128, 128, 256, *stream>>>( + leftBuf, leftInput->specialShapeInfo(), rightBuf, + rightInput->specialShapeInfo(), lower, adjoint, outputBuf, + output->specialShapeInfo(), leftTads.specialShapeInfo(), + leftTads.specialOffsets(), rightTads.specialShapeInfo(), + rightTads.specialOffsets(), outputTads.specialShapeInfo(), + outputTads.specialOffsets(), leftTads.numberOfTads()); + + NDArray::registerSpecialUse({output}, {leftInput, rightInput}); + + return Status::OK(); +} + +int triangularSolveFunctor(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool lower, bool adjoint, + NDArray* output) { + BUILD_SINGLE_SELECTOR( + leftInput->dataType(), return triangularSolveFunctor_, + (context, leftInput, rightInput, lower, adjoint, output), FLOAT_NATIVE); +} + +template +static __global__ void upperAdjointKernel( + T const* input, T* output, Nd4jLong batchSize, Nd4jLong rows, + Nd4jLong columns, Nd4jLong const* inputTads, Nd4jLong const* inputOffsets, + Nd4jLong const* outputTads, Nd4jLong const* outputOffsets) { + for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { + auto inputPart = input + inputOffsets[b]; + auto outputPart = output + outputOffsets[b]; + for (auto r = threadIdx.x; r < rows; r += blockDim.x) { + for (auto c = threadIdx.y; c <= r; c += blockDim.y) { + Nd4jLong zPos[] = {r, c}; + Nd4jLong xPos[] = {c, r}; + auto zIndex = shape::getOffset(outputTads, zPos); + auto xIndex = shape::getOffset(inputTads, xPos); + outputPart[zIndex] = inputPart[xIndex]; + } + } + } +} + +template +static __global__ void lowerAdjointKernel( + T const* input, T* output, Nd4jLong batchSize, Nd4jLong rows, + Nd4jLong columns, Nd4jLong const* inputTads, Nd4jLong const* inputOffsets, + Nd4jLong const* outputTads, Nd4jLong const* outputOffsets) { + for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { + auto inputPart = input + inputOffsets[b]; + auto outputPart = output + outputOffsets[b]; + for (auto r = threadIdx.x; r < rows; r += blockDim.x) { + for (auto c = r + threadIdx.y; c < columns; c += blockDim.y) { + Nd4jLong zPos[] = {r, c}; + Nd4jLong xPos[] = {c, r}; + auto zIndex = shape::getOffset(outputTads, zPos); + auto xIndex = shape::getOffset(inputTads, xPos); + outputPart[zIndex] = inputPart[xIndex]; + } + } + } +} + +template +static void adjointTriangularMatrix_(sd::LaunchContext* context, + NDArray const* input, bool const lower, + NDArray* output) { + auto inputTads = ConstantTadHelper::getInstance()->tadForDimensions( + input->shapeInfo(), {-2, -1}); + auto outputTads = ConstantTadHelper::getInstance()->tadForDimensions( + output->shapeInfo(), {-2, -1}); + auto stream = context->getCudaStream(); + auto inputBuf = reinterpret_cast(input->specialBuffer()); + auto outputBuf = reinterpret_cast(output->specialBuffer()); + auto rows = input->sizeAt(-2); + auto columns = input->sizeAt(-1); + + if (lower) { + lowerAdjointKernel<<<128, 256, 256, *stream>>>( + inputBuf, outputBuf, outputTads.numberOfTads(), rows, columns, + inputTads.specialShapeInfo(), inputTads.specialOffsets(), + outputTads.specialShapeInfo(), outputTads.specialOffsets()); + } else { + upperAdjointKernel<<<128, 256, 256, *stream>>>( + inputBuf, outputBuf, outputTads.numberOfTads(), rows, columns, + inputTads.specialShapeInfo(), inputTads.specialOffsets(), + outputTads.specialShapeInfo(), outputTads.specialOffsets()); + } +} + +void adjointMatrix(sd::LaunchContext* context, NDArray const* input, + bool const lower, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), adjointTriangularMatrix_, + (context, input, lower, output), FLOAT_NATIVE); +} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu index c096c4294cd1..dae449321d78 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu @@ -18,112 +18,143 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ void adaDeltaUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinMsg, const Nd4jLong* inMsgShapeInfo, - const void* vinMsdx, const Nd4jLong* inMsdxShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstMsg, - const Nd4jLong* stMsgShapeInfo, void* vstMsdx, const Nd4jLong* stMsdxShapeInfo, const T rho, const T epsilon) { - - const auto grad = reinterpret_cast(vx); - const auto initMsg= reinterpret_cast(vinMsg); - const auto initMsdx = reinterpret_cast(vinMsdx); - - auto up = reinterpret_cast(vz); - auto stMsg = reinterpret_cast(vstMsg); - auto stMsdx = reinterpret_cast(vstMsdx); - - __shared__ Nd4jLong xLen; - __shared__ T rhoT; - __shared__ bool bEWS, bOrdering, bXZsame, bXInMsgSame, bXStMsgSame, bXInMsdxSame, bXStMsdxSame; - - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - - rhoT = (1 - rho); - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stMsgShapeInfo) && 1 == shape::elementWiseStride(inMsgShapeInfo) && - 1 == shape::elementWiseStride(stMsdxShapeInfo) && 1 == shape::elementWiseStride(inMsdxShapeInfo); - bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stMsgShapeInfo) && - shape::order(stMsgShapeInfo) == shape::order(inMsgShapeInfo) && shape::order(inMsgShapeInfo) == shape::order(stMsdxShapeInfo) && - shape::order(stMsdxShapeInfo) == shape::order(inMsdxShapeInfo); - - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, inMsgShapeInfo); - bXStMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsgShapeInfo); - bXInMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, inMsdxShapeInfo); - bXStMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsdxShapeInfo); +template +__global__ void adaDeltaUpdaterCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vinMsg, + const Nd4jLong* inMsgShapeInfo, const void* vinMsdx, + const Nd4jLong* inMsdxShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstMsg, const Nd4jLong* stMsgShapeInfo, void* vstMsdx, + const Nd4jLong* stMsdxShapeInfo, const T rho, const T epsilon) { + const auto grad = reinterpret_cast(vx); + const auto initMsg = reinterpret_cast(vinMsg); + const auto initMsdx = reinterpret_cast(vinMsdx); + + auto up = reinterpret_cast(vz); + auto stMsg = reinterpret_cast(vstMsg); + auto stMsdx = reinterpret_cast(vstMsdx); + + __shared__ Nd4jLong xLen; + __shared__ T rhoT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInMsgSame, bXStMsgSame, + bXInMsdxSame, bXStMsdxSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + rhoT = (1 - rho); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stMsgShapeInfo) && + 1 == shape::elementWiseStride(inMsgShapeInfo) && + 1 == shape::elementWiseStride(stMsdxShapeInfo) && + 1 == shape::elementWiseStride(inMsdxShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(zShapeInfo) == shape::order(stMsgShapeInfo) && + shape::order(stMsgShapeInfo) == shape::order(inMsgShapeInfo) && + shape::order(inMsgShapeInfo) == shape::order(stMsdxShapeInfo) && + shape::order(stMsdxShapeInfo) == shape::order(inMsdxShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, inMsgShapeInfo); + bXStMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsgShapeInfo); + bXInMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, inMsdxShapeInfo); + bXStMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsdxShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initMsgOffset = i, initMsdxOffset = i, + stMsgOffset = i, stMsdxOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initMsgOffset = + bXInMsgSame ? xOffset : shape::getOffset(inMsgShapeInfo, coords); + stMsgOffset = + bXStMsgSame ? xOffset : shape::getOffset(stMsgShapeInfo, coords); + initMsdxOffset = + bXInMsdxSame ? xOffset : shape::getOffset(inMsdxShapeInfo, coords); + stMsdxOffset = + bXStMsdxSame ? xOffset : shape::getOffset(stMsdxShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + stMsg[stMsgOffset] = + rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT; - auto xOffset = i, zOffset = i, initMsgOffset = i, initMsdxOffset = i, stMsgOffset = i, stMsdxOffset = i; + up[zOffset] = + grad[xOffset] * + (sd::math::nd4j_sqrt(initMsdx[initMsdxOffset] + epsilon) / + sd::math::nd4j_sqrt(stMsg[stMsgOffset] + epsilon)); - if (!bEWS || !bOrdering){ - - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initMsgOffset = bXInMsgSame ? xOffset : shape::getOffset(inMsgShapeInfo, coords); - stMsgOffset = bXStMsgSame ? xOffset : shape::getOffset(stMsgShapeInfo, coords); - initMsdxOffset = bXInMsdxSame ? xOffset : shape::getOffset(inMsdxShapeInfo, coords); - stMsdxOffset = bXStMsdxSame ? xOffset : shape::getOffset(stMsdxShapeInfo, coords); - } - - stMsg[stMsgOffset] = rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT; - - up[zOffset] = grad[xOffset] * (sd::math::nd4j_sqrt(initMsdx[initMsdxOffset] + epsilon) / sd::math::nd4j_sqrt(stMsg[stMsgOffset] + epsilon)); - - stMsdx[stMsdxOffset] = rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT; - } + stMsdx[stMsdxOffset] = + rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT; + } } /////////////////////////////////////////////////////////////////// -template -linkage void adaDeltaUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, - const void* vinMsg, const Nd4jLong* inMsgShapeInfo, const void* vinMsdx, const Nd4jLong* inMsdxShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vstMsg, const Nd4jLong* stMsgShapeInfo, - void* vstMsdx, const Nd4jLong* stMsdxShapeInfo, const double dRho, const double dEpsilon) { - - const T rho = static_cast(dRho); - const T epsilon = static_cast(dEpsilon); - - adaDeltaUpdaterCuda<<>>(vx, xShapeInfo, vinMsg, inMsgShapeInfo, - vinMsdx, inMsdxShapeInfo, vz, zShapeInfo, vstMsg, stMsgShapeInfo, vstMsdx, stMsdxShapeInfo, rho, epsilon); +template +linkage void adaDeltaUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinMsg, const Nd4jLong* inMsgShapeInfo, const void* vinMsdx, + const Nd4jLong* inMsdxShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstMsg, const Nd4jLong* stMsgShapeInfo, void* vstMsdx, + const Nd4jLong* stMsdxShapeInfo, const double dRho, const double dEpsilon) { + const T rho = static_cast(dRho); + const T epsilon = static_cast(dEpsilon); + + adaDeltaUpdaterCuda<<>>( + vx, xShapeInfo, vinMsg, inMsgShapeInfo, vinMsdx, inMsdxShapeInfo, vz, + zShapeInfo, vstMsg, stMsgShapeInfo, vstMsdx, stMsdxShapeInfo, rho, + epsilon); } /////////////////////////////////////////////////////////////////// -void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, - NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) { - - PointersManager manager(context, "adaDeltaUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({ &update, &stateMsg, &stateMsdx }, { &gradient, &initStateMsg, &initStateMsdx }); - BUILD_SINGLE_SELECTOR(gradient.dataType(), adaDeltaUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(), - initStateMsg.specialBuffer(), initStateMsg.specialShapeInfo(), initStateMsdx.specialBuffer(), initStateMsdx.specialShapeInfo(), - update.specialBuffer(), update.specialShapeInfo(),stateMsg.specialBuffer(), stateMsg.specialShapeInfo(), - stateMsdx.specialBuffer(), stateMsdx.specialShapeInfo(), dRho, dEpsilon), FLOAT_TYPES); - NDArray::registerSpecialUse({ &update, &stateMsg, &stateMsdx }, { &gradient, &initStateMsg, &initStateMsdx }); - - manager.synchronize(); +void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateMsg, const NDArray& initStateMsdx, + NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, + const double dRho, const double dEpsilon) { + PointersManager manager(context, "adaDeltaUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateMsg, &stateMsdx}, + {&gradient, &initStateMsg, &initStateMsdx}); + BUILD_SINGLE_SELECTOR( + gradient.dataType(), adaDeltaUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.specialBuffer(), gradient.specialShapeInfo(), + initStateMsg.specialBuffer(), initStateMsg.specialShapeInfo(), + initStateMsdx.specialBuffer(), initStateMsdx.specialShapeInfo(), + update.specialBuffer(), update.specialShapeInfo(), + stateMsg.specialBuffer(), stateMsg.specialShapeInfo(), + stateMsdx.specialBuffer(), stateMsdx.specialShapeInfo(), dRho, dEpsilon), + FLOAT_TYPES); + NDArray::registerSpecialUse({&update, &stateMsg, &stateMsdx}, + {&gradient, &initStateMsg, &initStateMsdx}); + + manager.synchronize(); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu index 50a43986c433..e40aa5a60804 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu @@ -18,100 +18,109 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ void adaGradUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, +template +__global__ void adaGradUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, + const void* vin, const Nd4jLong* inShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + void* vst, const Nd4jLong* stShapeInfo, const T lr, const T epsilon) { - - const auto x = reinterpret_cast(vx); - const auto init = reinterpret_cast(vin); - - auto up = reinterpret_cast(vz); - auto st = reinterpret_cast(vst); - - __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; - __shared__ Nd4jLong xLen; - - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo); - bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(stShapeInfo) && - shape::order(xShapeInfo) == shape::order(inShapeInfo); - - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); - bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); + const auto x = reinterpret_cast(vx); + const auto init = reinterpret_cast(vin); + + auto up = reinterpret_cast(vz); + auto st = reinterpret_cast(vst); + + __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; + __shared__ Nd4jLong xLen; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stShapeInfo) && + 1 == shape::elementWiseStride(inShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(xShapeInfo) == shape::order(stShapeInfo) && + shape::order(xShapeInfo) == shape::order(inShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); + bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); + stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { - - auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; - - if (!bEWS || !bOrdering) { - - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); - stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords); - } - st[stOffset] = init[initOffset] + x[xOffset] * x[xOffset]; - up[zOffset] = (lr * x[xOffset]) / (math::nd4j_sqrt(st[stOffset]) + epsilon); - - } + st[stOffset] = init[initOffset] + x[xOffset] * x[xOffset]; + up[zOffset] = + (lr * x[xOffset]) / (math::nd4j_sqrt(st[stOffset]) + epsilon); + } } /////////////////////////////////////////////////////////////////// -template -linkage void adaGradUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, - const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, - const double dLr, const double dEpsilon) { - - const T lr = static_cast(dLr); - const T epsilon = static_cast(dEpsilon); - - adaGradUpdaterCuda<<>>(vx, xShapeInfo, vin, inShapeInfo, - vz, zShapeInfo, vst, stShapeInfo, lr, epsilon); +template +linkage void adaGradUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vin, const Nd4jLong* inShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, + const double dLr, const double dEpsilon) { + const T lr = static_cast(dLr); + const T epsilon = static_cast(dEpsilon); + + adaGradUpdaterCuda<<>>( + vx, xShapeInfo, vin, inShapeInfo, vz, zShapeInfo, vst, stShapeInfo, lr, + epsilon); } /////////////////////////////////////////////////////////////////// -void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, - NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon) { - - PointersManager manager(context, "adaGradUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({ &update, &stateH }, { &gradient, &initState }); - BUILD_SINGLE_SELECTOR(gradient.dataType(), adaGradUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), - gradient.specialBuffer(), gradient.specialShapeInfo(), - initState.specialBuffer(), initState.specialShapeInfo(), - update.specialBuffer(), update.specialShapeInfo(), - stateH.specialBuffer(), stateH.specialShapeInfo(), dLr, dEpsilon), FLOAT_TYPES); - NDArray::registerSpecialUse({ &update, &stateH }, { &gradient, &initState }); - - manager.synchronize(); +void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, NDArray& stateH, + const double dLr, const double dEpsilon) { + PointersManager manager(context, "adaGradUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateH}, {&gradient, &initState}); + BUILD_SINGLE_SELECTOR( + gradient.dataType(), adaGradUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.specialBuffer(), gradient.specialShapeInfo(), + initState.specialBuffer(), initState.specialShapeInfo(), + update.specialBuffer(), update.specialShapeInfo(), + stateH.specialBuffer(), stateH.specialShapeInfo(), dLr, dEpsilon), + FLOAT_TYPES); + NDArray::registerSpecialUse({&update, &stateH}, {&gradient, &initState}); + + manager.synchronize(); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu index 09301d05a770..38a45d0d84c9 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu @@ -18,125 +18,150 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ void adaMaxUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, - const void* vinm, const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, - void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, - const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { - - const auto grad = reinterpret_cast(vx); - const auto initU = reinterpret_cast(vinv); - const auto initM = reinterpret_cast(vinm); - - auto up = reinterpret_cast(vz); - auto stU = reinterpret_cast(vstV); - auto stM = reinterpret_cast(vstM); - - __shared__ Nd4jLong xLen; - __shared__ T beta1T, epsilonT; - __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame; - - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - beta1T = sd::math::nd4j_pow(beta1, (iteration + 1) ); - - epsilonT = lr / (1.0 - beta1T); - if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) - epsilonT = epsilon; - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && - 1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo); - bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(stmShapeInfo) && - shape::order(xShapeInfo) == shape::order(inmShapeInfo) && shape::order(xShapeInfo) == shape::order(invShapeInfo) && - shape::order(xShapeInfo) == shape::order(stvShapeInfo); - - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); - bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); - bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); - bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); +template +__global__ void adaMaxUpdaterCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, + const Nd4jLong* invShapeInfo, const void* vinm, + const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, const T lr, const T beta1, const T beta2, + const T epsilon, const T iteration) { + const auto grad = reinterpret_cast(vx); + const auto initU = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + + auto up = reinterpret_cast(vz); + auto stU = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + + __shared__ Nd4jLong xLen; + __shared__ T beta1T, epsilonT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, + bXStMSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); + + epsilonT = lr / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || + sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && + 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && + 1 == shape::elementWiseStride(invShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(xShapeInfo) == shape::order(stmShapeInfo) && + shape::order(xShapeInfo) == shape::order(inmShapeInfo) && + shape::order(xShapeInfo) == shape::order(invShapeInfo) && + shape::order(xShapeInfo) == shape::order(stvShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, + stMOffset = i, stUOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initUOffset = + bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initMOffset = + bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { - - - auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i; - - if (!bEWS || !bOrdering) { - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); - stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); - initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); - stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); - } + // m = B_1 * m + (1-B_1)*grad + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + // u = max(B_2 * u, |grad|) + stU[stUOffset] = sd::math::nd4j_max((beta2 * initU[initUOffset]), + sd::math::nd4j_abs(grad[xOffset])) + + 1e-32; - //m = B_1 * m + (1-B_1)*grad - stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); - //u = max(B_2 * u, |grad|) - stU[stUOffset] = sd::math::nd4j_max( (beta2* initU[initUOffset]), sd::math::nd4j_abs(grad[xOffset])) + 1e-32; - - up[zOffset] = (stM[stMOffset] * epsilonT) / stU[stUOffset]; - } + up[zOffset] = (stM[stMOffset] * epsilonT) / stU[stUOffset]; + } } /////////////////////////////////////////////////////////////////// -template -linkage void adaMaxUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, - const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, - void* vstM, const Nd4jLong* stmShapeInfo, const double dLr, - const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - - adaMaxUpdaterCuda<<>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, vz, - zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration); +template +linkage void adaMaxUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, + const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, + const double dBeta2, const double dEpsilon, const int nIteration) { + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + adaMaxUpdaterCuda<<>>( + vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, vz, zShapeInfo, + vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, + iteration); } /////////////////////////////////////////////////////////////////// -void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, - NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, - const double dBeta2, const double dEpsilon, const int nIteration) { - - PointersManager manager(context, "adaMaxUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM }); - BUILD_SINGLE_SELECTOR(gradient.dataType(), adaMaxUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), - gradient.specialBuffer(), gradient.specialShapeInfo(), initStateU.specialBuffer(), - initStateU.specialShapeInfo(), initStateM.specialBuffer(), initStateM.specialShapeInfo(), - update.specialBuffer(), update.specialShapeInfo(), stateU.specialBuffer(), - stateU.specialShapeInfo(), stateM.specialBuffer(), stateM.specialShapeInfo(), - dLr, dBeta1, dBeta2, dEpsilon, nIteration ), FLOAT_TYPES); - NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM }); - - manager.synchronize(); +void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + PointersManager manager(context, "adaMaxUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateU, &stateM}, + {&gradient, &initStateU, &initStateM}); + BUILD_SINGLE_SELECTOR( + gradient.dataType(), adaMaxUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.specialBuffer(), gradient.specialShapeInfo(), + initStateU.specialBuffer(), initStateU.specialShapeInfo(), + initStateM.specialBuffer(), initStateM.specialShapeInfo(), + update.specialBuffer(), update.specialShapeInfo(), + stateU.specialBuffer(), stateU.specialShapeInfo(), + stateM.specialBuffer(), stateM.specialShapeInfo(), dLr, dBeta1, dBeta2, + dEpsilon, nIteration), + FLOAT_TYPES); + NDArray::registerSpecialUse({&update, &stateU, &stateM}, + {&gradient, &initStateU, &initStateM}); + + manager.synchronize(); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu index 91d79809c707..112cfabe9df8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu @@ -18,122 +18,152 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ void adamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, - const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstV, - const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, - const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { - - const auto grad = reinterpret_cast(vx); - const auto initU = reinterpret_cast(vinv); - const auto initM = reinterpret_cast(vinm); - - auto up = reinterpret_cast(vz); - auto stU = reinterpret_cast(vstV); - auto stM = reinterpret_cast(vstM); - - __shared__ Nd4jLong xLen; - __shared__ T epsilonT; - __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame; - - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - - T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); - T beta2T = sd::math::nd4j_pow(beta2, (iteration + 1)); - - epsilonT = lr * sd::math::nd4j_sqrt(1. - beta2T) / (1.0 - beta1T); - if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) - epsilonT = epsilon; - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && - 1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo); - bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) && - shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && - shape::order(stvShapeInfo) == shape::order(invShapeInfo); - - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); - bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); - bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); - bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); +template +__global__ void adamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, + const void* vinm, const Nd4jLong* inmShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, + void* vstM, const Nd4jLong* stmShapeInfo, + const T lr, const T beta1, const T beta2, + const T epsilon, const T iteration) { + const auto grad = reinterpret_cast(vx); + const auto initU = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + + auto up = reinterpret_cast(vz); + auto stU = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + + __shared__ Nd4jLong xLen; + __shared__ T epsilonT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, + bXStMSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); + T beta2T = sd::math::nd4j_pow(beta2, (iteration + 1)); + + epsilonT = lr * sd::math::nd4j_sqrt(1. - beta2T) / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || + sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && + 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && + 1 == shape::elementWiseStride(invShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(zShapeInfo) == shape::order(stmShapeInfo) && + shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && + shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && + shape::order(stvShapeInfo) == shape::order(invShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, + stMOffset = i, stUOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initUOffset = + bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initMOffset = + bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + stU[stUOffset] = beta2 * initU[initUOffset] + + grad[xOffset] * grad[xOffset] * (1 - beta2); - auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i; - - if (!bEWS || !bOrdering){ - - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); - stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); - initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); - stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); - } - - stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); - stU[stUOffset] = beta2 * initU[initUOffset] + grad[xOffset] * grad[xOffset] * (1 - beta2); - - up[zOffset] = (stM[stMOffset] * epsilonT) / ( sd::math::nd4j_sqrt(stU[stUOffset]) + epsilon); - } + up[zOffset] = (stM[stMOffset] * epsilonT) / + (sd::math::nd4j_sqrt(stU[stUOffset]) + epsilon); + } } /////////////////////////////////////////////////////////////////// -template -linkage void adamUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, - const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, - void* vstM, const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - adamUpdaterCuda<<>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, - vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration); +template +linkage void adamUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, + const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, + const double dBeta2, const double dEpsilon, const int nIteration) { + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + adamUpdaterCuda<<>>( + vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, vz, zShapeInfo, + vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, + iteration); } /////////////////////////////////////////////////////////////////// -void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, - NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, +void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - PointersManager manager(context, "adamUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM }); - - BUILD_SINGLE_SELECTOR(gradient.dataType(), adamUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(), - initStateU.specialBuffer(), initStateU.specialShapeInfo(), initStateM.specialBuffer(), initStateM.specialShapeInfo(), - update.specialBuffer(), update.specialShapeInfo(), stateU.specialBuffer(), stateU.specialShapeInfo(), - stateM.specialBuffer(), stateM.specialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); - - NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM }); - - manager.synchronize(); + PointersManager manager(context, "adamUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateU, &stateM}, + {&gradient, &initStateU, &initStateM}); + + BUILD_SINGLE_SELECTOR( + gradient.dataType(), adamUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.specialBuffer(), gradient.specialShapeInfo(), + initStateU.specialBuffer(), initStateU.specialShapeInfo(), + initStateM.specialBuffer(), initStateM.specialShapeInfo(), + update.specialBuffer(), update.specialShapeInfo(), + stateU.specialBuffer(), stateU.specialShapeInfo(), + stateM.specialBuffer(), stateM.specialShapeInfo(), dLr, dBeta1, dBeta2, + dEpsilon, nIteration), + FLOAT_TYPES); + + NDArray::registerSpecialUse({&update, &stateU, &stateM}, + {&gradient, &initStateU, &initStateM}); + + manager.synchronize(); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu index ff3bc1e4bed8..17d55d5054bd 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu @@ -18,135 +18,176 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ void amsGradUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, - const void* vinm, const Nd4jLong* inmShapeInfo, const void* vinh, const Nd4jLong* inhShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, - const Nd4jLong* stmShapeInfo, void* vstH, const Nd4jLong* sthShapeInfo, - const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { - - const auto grad = reinterpret_cast(vx); - const auto initV = reinterpret_cast(vinv); - const auto initM = reinterpret_cast(vinm); - const auto initH = reinterpret_cast(vinh); - - auto up = reinterpret_cast(vz); - auto stV = reinterpret_cast(vstV); - auto stM = reinterpret_cast(vstM); - auto stH = reinterpret_cast(vstH); - - __shared__ Nd4jLong xLen; - __shared__ T mbeta1, mbeta2, epsilonT; - __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame, bXInHSame, bXStHSame; - - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - - epsilonT = lr * sd::math::nd4j_sqrt(1.0 - sd::math::nd4j_pow(beta2, (iteration + 1))) / (1.0 - sd::math::nd4j_pow(beta1, (iteration + 1))); - - if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) - epsilonT = epsilon; - - mbeta1 = (1 - beta1); - mbeta2 = (1 - beta2); - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && - 1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo) && - 1 == shape::elementWiseStride(sthShapeInfo) && 1 == shape::elementWiseStride(inhShapeInfo); - - bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) && - shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && - shape::order(stvShapeInfo) == shape::order(invShapeInfo) && shape::order(invShapeInfo) == shape::order(sthShapeInfo) && - shape::order(sthShapeInfo) == shape::order(inhShapeInfo); - - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); - bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); - bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); - bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); - bXInHSame = shape::haveSameShapeAndStrides(xShapeInfo, inhShapeInfo); - bXStHSame = shape::haveSameShapeAndStrides(xShapeInfo, sthShapeInfo); +template +__global__ void amsGradUpdaterCuda( + const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, + const Nd4jLong* invShapeInfo, const void* vinm, + const Nd4jLong* inmShapeInfo, const void* vinh, + const Nd4jLong* inhShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, void* vstH, const Nd4jLong* sthShapeInfo, + const T lr, const T beta1, const T beta2, const T epsilon, + const T iteration) { + const auto grad = reinterpret_cast(vx); + const auto initV = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + const auto initH = reinterpret_cast(vinh); + + auto up = reinterpret_cast(vz); + auto stV = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + auto stH = reinterpret_cast(vstH); + + __shared__ Nd4jLong xLen; + __shared__ T mbeta1, mbeta2, epsilonT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, + bXStMSame, bXInHSame, bXStHSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + epsilonT = lr * + sd::math::nd4j_sqrt( + 1.0 - sd::math::nd4j_pow(beta2, (iteration + 1))) / + (1.0 - sd::math::nd4j_pow(beta1, (iteration + 1))); + + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || + sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + mbeta1 = (1 - beta1); + mbeta2 = (1 - beta2); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && + 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && + 1 == shape::elementWiseStride(invShapeInfo) && + 1 == shape::elementWiseStride(sthShapeInfo) && + 1 == shape::elementWiseStride(inhShapeInfo); + + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(zShapeInfo) == shape::order(stmShapeInfo) && + shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && + shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && + shape::order(stvShapeInfo) == shape::order(invShapeInfo) && + shape::order(invShapeInfo) == shape::order(sthShapeInfo) && + shape::order(sthShapeInfo) == shape::order(inhShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + bXInHSame = shape::haveSameShapeAndStrides(xShapeInfo, inhShapeInfo); + bXStHSame = shape::haveSameShapeAndStrides(xShapeInfo, sthShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initMOffset = i, initVOffset = i, + initHOffset = i, stMOffset = i, stVOffset = i, stHOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initMOffset = + bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); + initVOffset = + bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stVOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initHOffset = + bXInHSame ? xOffset : shape::getOffset(inhShapeInfo, coords); + stHOffset = bXStHSame ? xOffset : shape::getOffset(sthShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { - - auto xOffset = i, zOffset = i, initMOffset = i, initVOffset = i, initHOffset = i, stMOffset = i, stVOffset = i, stHOffset = i; - - if (!bEWS || !bOrdering){ - - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); - stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); - initVOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); - stVOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); - initHOffset = bXInHSame ? xOffset : shape::getOffset(inhShapeInfo, coords); - stHOffset = bXStHSame ? xOffset : shape::getOffset(sthShapeInfo, coords); - } - stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1; - stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; - stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]); + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1; + stV[stVOffset] = + beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]); - up[zOffset] = epsilonT * stM[stMOffset] / (sd::math::nd4j_sqrt(stH[stHOffset]) + epsilon); - } + up[zOffset] = epsilonT * stM[stMOffset] / + (sd::math::nd4j_sqrt(stH[stHOffset]) + epsilon); + } } /////////////////////////////////////////////////////////////////// -template -linkage void amsGradUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, - const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, - const void* vinh, const Nd4jLong* inhShapeInfo, void* vz, const Nd4jLong* zShapeInfo, - void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, - void* vstH, const Nd4jLong* sthShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - - amsGradUpdaterCuda<<>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, - vinh, inhShapeInfo, vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, vstH, sthShapeInfo, lr, beta1, beta2, epsilon, iteration); +template +linkage void amsGradUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, + const Nd4jLong* inmShapeInfo, const void* vinh, + const Nd4jLong* inhShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, void* vstH, const Nd4jLong* sthShapeInfo, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + amsGradUpdaterCuda<<>>( + vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, vinh, + inhShapeInfo, vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, + vstH, sthShapeInfo, lr, beta1, beta2, epsilon, iteration); } /////////////////////////////////////////////////////////////////// -void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, - NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - PointersManager manager(context, "amsGradUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({ &update, &stateV, &stateM, &stateH }, { &gradient, &initStateV, &initStateM, &initStateH }); - BUILD_SINGLE_SELECTOR(gradient.dataType(), amsGradUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(), - initStateV.specialBuffer(), initStateV.specialShapeInfo(), initStateM.specialBuffer(), initStateM.specialShapeInfo(), - initStateH.specialBuffer(), initStateH.specialShapeInfo(), update.specialBuffer(), update.specialShapeInfo(), - stateV.specialBuffer(), stateV.specialShapeInfo(), stateM.specialBuffer(), stateM.specialShapeInfo(), - stateH.specialBuffer(), stateH.specialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); - NDArray::registerSpecialUse({ &update, &stateV, &stateM , &stateH }, { &gradient, &initStateV, &initStateM, &initStateH }); - - manager.synchronize(); +void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateV, const NDArray& initStateM, + const NDArray& initStateH, NDArray& update, NDArray& stateV, + NDArray& stateM, NDArray& stateH, const double dLr, + const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + PointersManager manager(context, "amsGradUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse( + {&update, &stateV, &stateM, &stateH}, + {&gradient, &initStateV, &initStateM, &initStateH}); + BUILD_SINGLE_SELECTOR( + gradient.dataType(), amsGradUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.specialBuffer(), gradient.specialShapeInfo(), + initStateV.specialBuffer(), initStateV.specialShapeInfo(), + initStateM.specialBuffer(), initStateM.specialShapeInfo(), + initStateH.specialBuffer(), initStateH.specialShapeInfo(), + update.specialBuffer(), update.specialShapeInfo(), + stateV.specialBuffer(), stateV.specialShapeInfo(), + stateM.specialBuffer(), stateM.specialShapeInfo(), + stateH.specialBuffer(), stateH.specialShapeInfo(), dLr, dBeta1, dBeta2, + dEpsilon, nIteration), + FLOAT_TYPES); + NDArray::registerSpecialUse( + {&update, &stateV, &stateM, &stateH}, + {&gradient, &initStateV, &initStateM, &initStateH}); + + manager.synchronize(); } - -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu index 141ed27dbfbe..32bd086cad99 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu @@ -18,120 +18,150 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ void nadamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, - const void* vinm, const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, - void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, - const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { - - const auto grad = reinterpret_cast(vx); - const auto initV = reinterpret_cast(vinv); - const auto initM = reinterpret_cast(vinm); - - auto up = reinterpret_cast(vz); - auto stV = reinterpret_cast(vstV); - auto stM = reinterpret_cast(vstM); - - __shared__ Nd4jLong xLen; - __shared__ T mbeta1T, mbeta1, mbeta2; - __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame; - - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - - mbeta1T = 1.0 - sd::math::nd4j_pow(beta1, (iteration + 1)); - mbeta1 = (1 - beta1); - mbeta2 = (1 - beta2); - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && - 1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo); - bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) && - shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && - shape::order(stvShapeInfo) == shape::order(invShapeInfo); - - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); - bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); - bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); - bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); +template +__global__ void nadamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, + const void* vinm, const Nd4jLong* inmShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, + void* vstM, const Nd4jLong* stmShapeInfo, + const T lr, const T beta1, const T beta2, + const T epsilon, const T iteration) { + const auto grad = reinterpret_cast(vx); + const auto initV = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + + auto up = reinterpret_cast(vz); + auto stV = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + + __shared__ Nd4jLong xLen; + __shared__ T mbeta1T, mbeta1, mbeta2; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, + bXStMSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + mbeta1T = 1.0 - sd::math::nd4j_pow(beta1, (iteration + 1)); + mbeta1 = (1 - beta1); + mbeta2 = (1 - beta2); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && + 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && + 1 == shape::elementWiseStride(invShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(zShapeInfo) == shape::order(stmShapeInfo) && + shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && + shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && + shape::order(stvShapeInfo) == shape::order(invShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, + stMOffset = i, stUOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initUOffset = + bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initMOffset = + bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { - - auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i; - - if (!bEWS || !bOrdering){ - - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); - stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); - initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); - stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); - } - auto oneMinusBeta1Grad = grad[xOffset] * mbeta1; + auto oneMinusBeta1Grad = grad[xOffset] * mbeta1; - stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad; - stV[stUOffset] = beta2 * initV[initUOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad; + stV[stUOffset] = + beta2 * initV[initUOffset] + grad[xOffset] * grad[xOffset] * mbeta2; - up[zOffset] = (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt(stV[stUOffset]) + epsilon); - } + up[zOffset] = + (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / + (sd::math::nd4j_sqrt(stV[stUOffset]) + epsilon); + } } /////////////////////////////////////////////////////////////////// -template -linkage void nadamUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, - const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, - const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - const T lr = static_cast(dLr); - const T beta1 = static_cast(dBeta1); - const T beta2 = static_cast(dBeta2); - const T epsilon = static_cast(dEpsilon); - const T iteration = static_cast(nIteration); - - nadamUpdaterCuda<<>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, - vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration); +template +linkage void nadamUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, + const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, + const double dBeta2, const double dEpsilon, const int nIteration) { + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + nadamUpdaterCuda<<>>( + vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, vz, zShapeInfo, + vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, + iteration); } /////////////////////////////////////////////////////////////////// -void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, - NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { - - PointersManager manager(context, "nadamUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({ &update, &stateV, &stateM }, { &gradient, &initStateV, &initStateM }); - BUILD_SINGLE_SELECTOR(gradient.dataType(), nadamUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(), - initStateV.specialBuffer(), initStateV.specialShapeInfo(), initStateM.specialBuffer(), initStateM.specialShapeInfo(), - update.specialBuffer(), update.specialShapeInfo(), stateV.specialBuffer(), stateV.specialShapeInfo(), - stateM.specialBuffer(), stateM.specialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); - NDArray::registerSpecialUse({ &update, &stateV, &stateM }, { &gradient, &initStateV, &initStateM }); - - manager.synchronize(); +void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateV, const NDArray& initStateM, + NDArray& update, NDArray& stateV, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + PointersManager manager(context, "nadamUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateV, &stateM}, + {&gradient, &initStateV, &initStateM}); + BUILD_SINGLE_SELECTOR( + gradient.dataType(), nadamUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.specialBuffer(), gradient.specialShapeInfo(), + initStateV.specialBuffer(), initStateV.specialShapeInfo(), + initStateM.specialBuffer(), initStateM.specialShapeInfo(), + update.specialBuffer(), update.specialShapeInfo(), + stateV.specialBuffer(), stateV.specialShapeInfo(), + stateM.specialBuffer(), stateM.specialShapeInfo(), dLr, dBeta1, dBeta2, + dEpsilon, nIteration), + FLOAT_TYPES); + NDArray::registerSpecialUse({&update, &stateV, &stateM}, + {&gradient, &initStateV, &initStateM}); + + manager.synchronize(); } - -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu index 75e1f593884b..b09139ed8453 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu @@ -18,100 +18,111 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// -template -__global__ void nesterovsUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, const T lr, const T momentum) { - - const auto grad = reinterpret_cast(vx); - const auto init = reinterpret_cast(vin); - auto up = reinterpret_cast(vz); - auto st = reinterpret_cast(vst); - - __shared__ Nd4jLong xLen; - __shared__ T momentumT; - __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; - - if (threadIdx.x == 0) { - xLen = shape::length(xShapeInfo); - momentumT = (-momentum - 1); - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo); - bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(inShapeInfo) && - shape::order(xShapeInfo) == shape::order(stShapeInfo); - - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); - bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); +template +__global__ void nesterovsUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, + const void* vin, + const Nd4jLong* inShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, void* vst, + const Nd4jLong* stShapeInfo, const T lr, + const T momentum) { + const auto grad = reinterpret_cast(vx); + const auto init = reinterpret_cast(vin); + auto up = reinterpret_cast(vz); + auto st = reinterpret_cast(vst); + + __shared__ Nd4jLong xLen; + __shared__ T momentumT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + momentumT = (-momentum - 1); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stShapeInfo) && + 1 == shape::elementWiseStride(inShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && + shape::order(xShapeInfo) == shape::order(inShapeInfo) && + shape::order(xShapeInfo) == shape::order(stShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); + bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); + stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { - - auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; - - if (!bEWS || !bOrdering) { - - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); - stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords); - } - T prevState = momentum * init[initOffset]; - st[stOffset] = prevState - lr * grad[xOffset]; - up[zOffset] = prevState + momentumT * st[stOffset]; - } + T prevState = momentum * init[initOffset]; + st[stOffset] = prevState - lr * grad[xOffset]; + up[zOffset] = prevState + momentumT * st[stOffset]; + } } /////////////////////////////////////////////////////////////////// -template -linkage void nesterovsUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, - const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, - void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, - const double dLr, const double dMomentum) { - - const T lr = static_cast(dLr); - const T momentum = static_cast(dMomentum); - nesterovsUpdaterCuda<<>>(vx, xShapeInfo, vin, inShapeInfo, - vz, zShapeInfo, vst, stShapeInfo, lr, momentum); +template +linkage void nesterovsUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vin, const Nd4jLong* inShapeInfo, void* vz, + const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, + const double dLr, const double dMomentum) { + const T lr = static_cast(dLr); + const T momentum = static_cast(dMomentum); + nesterovsUpdaterCuda<<>>( + vx, xShapeInfo, vin, inShapeInfo, vz, zShapeInfo, vst, stShapeInfo, lr, + momentum); } /////////////////////////////////////////////////////////////////// -void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, - NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) { - - PointersManager manager(context, "nesterovsUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({ &update, &stateV }, { &gradient, &initState }); - BUILD_SINGLE_SELECTOR(gradient.dataType(), nesterovsUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, - context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(), - initState.specialBuffer(), initState.specialShapeInfo(), - update.specialBuffer(), update.specialShapeInfo(), - stateV.specialBuffer(), stateV.specialShapeInfo(), dLr, dMomentum), FLOAT_TYPES); - NDArray::registerSpecialUse({ &update, &stateV }, { &gradient, &initState }); - - manager.synchronize(); +void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, + NDArray& stateV, const double dLr, + const double dMomentum) { + PointersManager manager(context, "nesterovsUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateV}, {&gradient, &initState}); + BUILD_SINGLE_SELECTOR( + gradient.dataType(), nesterovsUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.specialBuffer(), gradient.specialShapeInfo(), + initState.specialBuffer(), initState.specialShapeInfo(), + update.specialBuffer(), update.specialShapeInfo(), + stateV.specialBuffer(), stateV.specialShapeInfo(), dLr, dMomentum), + FLOAT_TYPES); + NDArray::registerSpecialUse({&update, &stateV}, {&gradient, &initState}); + + manager.synchronize(); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu index 26f7253d2dea..f72092b8703c 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu @@ -18,104 +18,115 @@ // @author Oleh Semeniv (oleg.semeniv@gmail.com) // -#include -#include #include #include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// -template -__global__ void rmsPropUpdaterCuda(const void *vx, const Nd4jLong *xShapeInfo, const void *vin, const Nd4jLong *inShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, - const T lr, const T rmsDecay, const T epsilon) { - - const auto x = reinterpret_cast(vx); - const auto init = reinterpret_cast(vin); - - auto up = reinterpret_cast(vz); - auto st = reinterpret_cast(vst); - - __shared__ Nd4jLong xLen; - __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; - - if (threadIdx.x == 0) { - - xLen = shape::length(xShapeInfo); - - bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && - 1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo); - - bOrdering = shape::order(zShapeInfo) == shape::order(xShapeInfo) && shape::order(xShapeInfo) == shape::order(stShapeInfo) && - shape::order(xShapeInfo) == shape::order(inShapeInfo); - bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); - bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); +template +__global__ void rmsPropUpdaterCuda(const void *vx, const Nd4jLong *xShapeInfo, + const void *vin, const Nd4jLong *inShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + void *vst, const Nd4jLong *stShapeInfo, + const T lr, const T rmsDecay, + const T epsilon) { + const auto x = reinterpret_cast(vx); + const auto init = reinterpret_cast(vin); + + auto up = reinterpret_cast(vz); + auto st = reinterpret_cast(vst); + + __shared__ Nd4jLong xLen; + __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && + 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stShapeInfo) && + 1 == shape::elementWiseStride(inShapeInfo); + + bOrdering = shape::order(zShapeInfo) == shape::order(xShapeInfo) && + shape::order(xShapeInfo) == shape::order(stShapeInfo) && + shape::order(xShapeInfo) == shape::order(inShapeInfo); + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); + bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; + i += gridDim.x * blockDim.x) { + auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; + + if (!bEWS || !bOrdering) { + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); + stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords); } - __syncthreads(); - - int coords[MAX_RANK]; - - for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { - - auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; - if (!bEWS || !bOrdering) { - - shape::index2coords(i, xShapeInfo, coords); - xOffset = shape::getOffset(xShapeInfo, coords); - zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); - initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); - stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords); - } - - st[stOffset] = init[initOffset] * rmsDecay + x[xOffset] * x[xOffset] * (1 - rmsDecay) ; - up[zOffset] = (lr * x[xOffset]) / ( math::nd4j_sqrt(st[stOffset]) + epsilon); - } + st[stOffset] = + init[initOffset] * rmsDecay + x[xOffset] * x[xOffset] * (1 - rmsDecay); + up[zOffset] = + (lr * x[xOffset]) / (math::nd4j_sqrt(st[stOffset]) + epsilon); + } } /////////////////////////////////////////////////////////////////// -template -linkage void rmsPropUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void *vx, const Nd4jLong *xShapeInfo, const void *vin, const Nd4jLong *inShapeInfo, - void *vz, const Nd4jLong *zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, - const double dLr, const double dRmsDecay, const double dEpsilon) { - - const T lr = static_cast(dLr); - const T rmsDecay = static_cast(dRmsDecay); - const T epsilon = static_cast(dEpsilon); - - rmsPropUpdaterCuda<<>>(vx, xShapeInfo, vin, inShapeInfo, - vz, zShapeInfo, vst, stShapeInfo, lr, rmsDecay, epsilon); +template +linkage void rmsPropUpdaterCudaLauncher( + const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, + const void *vin, const Nd4jLong *inShapeInfo, void *vz, + const Nd4jLong *zShapeInfo, void *vst, const Nd4jLong *stShapeInfo, + const double dLr, const double dRmsDecay, const double dEpsilon) { + const T lr = static_cast(dLr); + const T rmsDecay = static_cast(dRmsDecay); + const T epsilon = static_cast(dEpsilon); + + rmsPropUpdaterCuda<<>>( + vx, xShapeInfo, vin, inShapeInfo, vz, zShapeInfo, vst, stShapeInfo, lr, + rmsDecay, epsilon); } /////////////////////////////////////////////////////////////////// -void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, - const double dLr, const double dRmsDecay, const double dEpsilon) { - - PointersManager manager(context, "rmsPropUpdater"); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - - NDArray::prepareSpecialUse({&update, &stateG}, {&gradient, &initState }); - - BUILD_SINGLE_SELECTOR(gradient.dataType(), rmsPropUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, - context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(), - initState.specialBuffer(), initState.specialShapeInfo(), - update.specialBuffer(), update.specialShapeInfo(), - stateG.specialBuffer(), stateG.specialShapeInfo(), - dLr, dRmsDecay, dEpsilon ), FLOAT_TYPES); - - NDArray::registerSpecialUse({&update, &stateG}, {&gradient, &initState}); - - manager.synchronize(); +void updaterRmsProp(sd::LaunchContext *context, const NDArray &gradient, + const NDArray &initState, NDArray &update, NDArray &stateG, + const double dLr, const double dRmsDecay, + const double dEpsilon) { + PointersManager manager(context, "rmsPropUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = + (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateG}, {&gradient, &initState}); + + BUILD_SINGLE_SELECTOR(gradient.dataType(), rmsPropUpdaterCudaLauncher, + (blocksPerGrid, threadsPerBlock, + context->getCudaStream(), gradient.specialBuffer(), + gradient.specialShapeInfo(), initState.specialBuffer(), + initState.specialShapeInfo(), update.specialBuffer(), + update.specialShapeInfo(), stateG.specialBuffer(), + stateG.specialShapeInfo(), dLr, dRmsDecay, dEpsilon), + FLOAT_TYPES); + + NDArray::registerSpecialUse({&update, &stateG}, {&gradient, &initState}); + + manager.synchronize(); } -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/weights.cu b/libnd4j/include/ops/declarable/helpers/cuda/weights.cu index 1620820a5896..9b5b01df8fef 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/weights.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/weights.cu @@ -24,96 +24,117 @@ namespace sd { namespace ops { namespace helpers { - - - template - static __device__ void adjustWeightsKernelD(void* inputBuffer, Nd4jLong const* inputShape, - void* weightsBuffer, Nd4jLong const* weightsShape, - void* outputBuffer, Nd4jLong inputLength, - Nd4jLong outputLength, int val) { - // typedef Nd4jLong T; - auto tid = threadIdx.x; - //int threadCount = gridDim.x * blockDim.x; - __shared__ T* outputPart; - __shared__ Nd4jLong offset; - //for (int e = 0; e < inputLength; e++) { - for (Nd4jLong e = tid; e < inputLength; e += blockDim.x) { - - Nd4jLong xOffset = shape::getIndexOffset(e, inputShape); - int current = *(reinterpret_cast(inputBuffer) + xOffset); - if (current == val) { - //printf("%lld\n", xOffset); - //Nd4jLong zOffset = shape::getIndexOffset(val, outputShape); - if (weightsBuffer != nullptr) { - Nd4jLong yOffset = shape::getIndexOffset(e, weightsShape); - //atomicAdd(); - //*reinterpret_cast(outputBuffer) += reinterpret_cast(weightsBuffer)[yOffset]; - sd::math::atomics::nd4j_atomicAdd(reinterpret_cast(outputBuffer), reinterpret_cast(weightsBuffer)[yOffset]); //output->p(val, output->e(val) + 1); -// atomicAdd(reinterpret_cast(outputBuffer), reinterpret_cast(weightsBuffer)[yOffset]); //output->p(val, output->e(val) + 1); - } - else { - //*reinterpret_cast(outputBuffer) += int(1); - //printf("outputBuffer[0] = %d\n", static_cast(*(reinterpret_cast(outputBuffer)))); - sd::math::atomics::nd4j_atomicAdd(reinterpret_cast(outputBuffer), T(1)); //output->p(val, output->e(val) + 1); -// atomicAdd(reinterpret_cast(outputBuffer), int(1)); //output->p(val, output->e(val) + 1); - // printf("outputBuffer[%ld] = %d\n", zOffset, static_cast(*(reinterpret_cast(outputBuffer) + zOffset))); - } - //printf("xOffset is %ld, zOffset is %ld\n", xOffset, zOffset); - } - } -// if (threadIdx.x + offset < outputLength) -// reinterpret_cast(outputBuffer)[threadIdx.x + offset] = outputPart[threadIdx.x]; +template +static __device__ void adjustWeightsKernelD( + void* inputBuffer, Nd4jLong const* inputShape, void* weightsBuffer, + Nd4jLong const* weightsShape, void* outputBuffer, Nd4jLong inputLength, + Nd4jLong outputLength, int val) { + // typedef Nd4jLong T; + auto tid = threadIdx.x; + // int threadCount = gridDim.x * blockDim.x; + __shared__ T* outputPart; + __shared__ Nd4jLong offset; + // for (int e = 0; e < inputLength; e++) { + for (Nd4jLong e = tid; e < inputLength; e += blockDim.x) { + Nd4jLong xOffset = shape::getIndexOffset(e, inputShape); + int current = *(reinterpret_cast(inputBuffer) + xOffset); + if (current == val) { + // printf("%lld\n", xOffset); + // Nd4jLong zOffset = shape::getIndexOffset(val, outputShape); + if (weightsBuffer != nullptr) { + Nd4jLong yOffset = shape::getIndexOffset(e, weightsShape); + // atomicAdd(); + //*reinterpret_cast(outputBuffer) += reinterpret_cast(weightsBuffer)[yOffset]; + sd::math::atomics::nd4j_atomicAdd( + reinterpret_cast(outputBuffer), + reinterpret_cast( + weightsBuffer)[yOffset]); // output->p(val, output->e(val) + + // 1); + // atomicAdd(reinterpret_cast(outputBuffer), + // reinterpret_cast(weightsBuffer)[yOffset]); + // //output->p(val, output->e(val) + 1); + } else { + //*reinterpret_cast(outputBuffer) += int(1); + // printf("outputBuffer[0] = %d\n", + // static_cast(*(reinterpret_cast(outputBuffer)))); + sd::math::atomics::nd4j_atomicAdd( + reinterpret_cast(outputBuffer), + T(1)); // output->p(val, output->e(val) + 1); + // atomicAdd(reinterpret_cast(outputBuffer), + // int(1)); //output->p(val, output->e(val) + 1); + // printf("outputBuffer[%ld] = %d\n", zOffset, + // static_cast(*(reinterpret_cast(outputBuffer) + + // zOffset))); + } + // printf("xOffset is %ld, zOffset is %ld\n", xOffset, zOffset); } + } + // if (threadIdx.x + offset < outputLength) + // reinterpret_cast(outputBuffer)[threadIdx.x + offset] = + // outputPart[threadIdx.x]; +} - template - static __global__ void adjustWeightsKernel(void* inputBuffer, Nd4jLong const* inputShape, - void* weightsBuffer, Nd4jLong const* weightsShape, - void* outputBuffer, Nd4jLong const* outputShape, - int minLength, int maxLength) { - - //auto tid = blockIdx.x * blockDim.x + threadIdx.x; // * blockDim.x; // + threadIdx.x; - int threadCount = gridDim.x * blockDim.x; - Nd4jLong inputLength = shape::length(inputShape); - - Nd4jLong outputLength = shape::length(outputShape); - Nd4jLong borderLen = 1; - - for (Nd4jLong e = blockIdx.x; e < outputLength; e += threadCount) { - //if (blockIdx.x < outputLength) { - //if (e + threadCount < outputLength) { - Nd4jLong zOffset = shape::getIndexOffset(e, outputShape); - //printf("%d %d %d\n", blockIdx.x, blockDim.x, threadIdx.x); - //Nd4jLong borderLen = 1; - T* outputBufferZ = reinterpret_cast(outputBuffer) + zOffset; - adjustWeightsKernelD(inputBuffer, inputShape, weightsBuffer, weightsShape, (void*)outputBufferZ, - inputLength, outputLength, (int)zOffset); - - } - } +template +static __global__ void adjustWeightsKernel( + void* inputBuffer, Nd4jLong const* inputShape, void* weightsBuffer, + Nd4jLong const* weightsShape, void* outputBuffer, + Nd4jLong const* outputShape, int minLength, int maxLength) { + // auto tid = blockIdx.x * blockDim.x + threadIdx.x; // * blockDim.x; // + + // threadIdx.x; + int threadCount = gridDim.x * blockDim.x; + Nd4jLong inputLength = shape::length(inputShape); - template - static void adjustWeights_(sd::LaunchContext * context, NDArray* input, NDArray* weights, NDArray* output, int minLength, int maxLength) { -// for (int e = 0; e < input->lengthOf(); e++) { -// int val = input->e(e); -// if (val < maxLength) { -// if (weights != nullptr) -// output->p(val, output->e(val) + weights->e(e)); -// else -// output->p(val, output->e(val) + 1); -// } -// } - dim3 launchDims(256, 512, 8192); - auto stream = context->getCudaStream(); - adjustWeightsKernel<<>>(input->specialBuffer(), - input->specialShapeInfo(), weights?weights->specialBuffer():nullptr, weights?weights->specialShapeInfo():nullptr, - output->specialBuffer(), output->specialShapeInfo(), minLength, maxLength); - } + Nd4jLong outputLength = shape::length(outputShape); + Nd4jLong borderLen = 1; - void adjustWeights(sd::LaunchContext * context, NDArray* input, NDArray* weights, NDArray* output, int minLength, int maxLength) { - BUILD_SINGLE_SELECTOR(output->dataType(), adjustWeights_, (context, input, weights, output, minLength, maxLength), GENERIC_NUMERIC_TYPES); - } + for (Nd4jLong e = blockIdx.x; e < outputLength; e += threadCount) { + // if (blockIdx.x < outputLength) { + // if (e + threadCount < outputLength) { + Nd4jLong zOffset = shape::getIndexOffset(e, outputShape); + // printf("%d %d %d\n", blockIdx.x, blockDim.x, threadIdx.x); + // Nd4jLong borderLen = 1; + T* outputBufferZ = reinterpret_cast(outputBuffer) + zOffset; + adjustWeightsKernelD(inputBuffer, inputShape, weightsBuffer, + weightsShape, (void*)outputBufferZ, inputLength, + outputLength, (int)zOffset); + } +} - BUILD_SINGLE_TEMPLATE(template void adjustWeights_, (sd::LaunchContext * context, NDArray* input, NDArray* weights, NDArray* output, int minLength, int maxLength), GENERIC_NUMERIC_TYPES); +template +static void adjustWeights_(sd::LaunchContext* context, NDArray* input, + NDArray* weights, NDArray* output, int minLength, + int maxLength) { + // for (int e = 0; e < input->lengthOf(); e++) { + // int val = input->e(e); + // if (val < maxLength) { + // if (weights != nullptr) + // output->p(val, output->e(val) + weights->e(e)); + // else + // output->p(val, output->e(val) + 1); + // } + // } + dim3 launchDims(256, 512, 8192); + auto stream = context->getCudaStream(); + adjustWeightsKernel<<>>( + input->specialBuffer(), input->specialShapeInfo(), + weights ? weights->specialBuffer() : nullptr, + weights ? weights->specialShapeInfo() : nullptr, output->specialBuffer(), + output->specialShapeInfo(), minLength, maxLength); } + +void adjustWeights(sd::LaunchContext* context, NDArray* input, NDArray* weights, + NDArray* output, int minLength, int maxLength) { + BUILD_SINGLE_SELECTOR(output->dataType(), adjustWeights_, + (context, input, weights, output, minLength, maxLength), + GENERIC_NUMERIC_TYPES); } -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template void adjustWeights_, + (sd::LaunchContext * context, NDArray* input, + NDArray* weights, NDArray* output, int minLength, + int maxLength), + GENERIC_NUMERIC_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/zeta.cu b/libnd4j/include/ops/declarable/helpers/cuda/zeta.cu index 660c49325845..dd397ac1e02f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/zeta.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/zeta.cu @@ -18,68 +18,77 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 26.04.2019 // -#include +#include namespace sd { namespace ops { namespace helpers { - /////////////////////////////////////////////////////////////////// -template +template __global__ static void zetaCuda(const void *vx, const Nd4jLong *xShapeInfo, const void *vq, const Nd4jLong *qShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - const auto q = reinterpret_cast(vq); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong len; + void *vz, const Nd4jLong *zShapeInfo) { + const auto x = reinterpret_cast(vx); + const auto q = reinterpret_cast(vq); + auto z = reinterpret_cast(vz); - if (threadIdx.x == 0) - len = shape::length(xShapeInfo); - __syncthreads(); + __shared__ Nd4jLong len; - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto totalThreads = gridDim.x * blockDim.x; + if (threadIdx.x == 0) len = shape::length(xShapeInfo); + __syncthreads(); - for (int i = tid; i < len; i += totalThreads) { + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto totalThreads = gridDim.x * blockDim.x; - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const auto qOffset = shape::getIndexOffset(i, qShapeInfo); - const auto zOffset = shape::getIndexOffset(i, zShapeInfo); + for (int i = tid; i < len; i += totalThreads) { + const auto xOffset = shape::getIndexOffset(i, xShapeInfo); + const auto qOffset = shape::getIndexOffset(i, qShapeInfo); + const auto zOffset = shape::getIndexOffset(i, zShapeInfo); - z[zOffset] = zetaScalar(x[xOffset], q[qOffset]); - } + z[zOffset] = zetaScalar(x[xOffset], q[qOffset]); + } } /////////////////////////////////////////////////////////////////// -template -static void zetaCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vq, const Nd4jLong *qShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - zetaCuda<<>>(vx, xShapeInfo, vq, qShapeInfo, vz, zShapeInfo); +template +static void zetaCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const void *vq, + const Nd4jLong *qShapeInfo, void *vz, + const Nd4jLong *zShapeInfo) { + zetaCuda<<>>( + vx, xShapeInfo, vq, qShapeInfo, vz, zShapeInfo); } -void zeta(sd::LaunchContext * context, const NDArray& x, const NDArray& q, NDArray& z) { - - if(!x.isActualOnDeviceSide()) x.syncToDevice(); - if(!q.isActualOnDeviceSide()) q.syncToDevice(); +void zeta(sd::LaunchContext *context, const NDArray &x, const NDArray &q, + NDArray &z) { + if (!x.isActualOnDeviceSide()) x.syncToDevice(); + if (!q.isActualOnDeviceSide()) q.syncToDevice(); - int threadsPerBlock = MAX_NUM_THREADS / 2; - int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + int threadsPerBlock = MAX_NUM_THREADS / 2; + int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - BUILD_SINGLE_SELECTOR(x.dataType(), zetaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), x.specialBuffer(), x.specialShapeInfo(), q.specialBuffer(), q.specialShapeInfo(), z.specialBuffer(), z.specialShapeInfo()), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR( + x.dataType(), zetaCudaLauncher, + (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + x.specialBuffer(), x.specialShapeInfo(), q.specialBuffer(), + q.specialShapeInfo(), z.specialBuffer(), z.specialShapeInfo()), + FLOAT_TYPES); - x.tickReadHost(); - q.tickReadHost(); - z.tickWriteDevice(); -} - -BUILD_SINGLE_TEMPLATE(template void zetaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vq, const Nd4jLong *qShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES); - - -} -} + x.tickReadHost(); + q.tickReadHost(); + z.tickWriteDevice(); } +BUILD_SINGLE_TEMPLATE(template void zetaCudaLauncher, + (const int blocksPerGrid, const int threadsPerBlock, + const cudaStream_t *stream, const void *vx, + const Nd4jLong *xShapeInfo, const void *vq, + const Nd4jLong *qShapeInfo, void *vz, + const Nd4jLong *zShapeInfo), + FLOAT_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/d_t_s.h b/libnd4j/include/ops/declarable/helpers/d_t_s.h index e5ac58e5aa24..b66d12b6ce3f 100644 --- a/libnd4j/include/ops/declarable/helpers/d_t_s.h +++ b/libnd4j/include/ops/declarable/helpers/d_t_s.h @@ -18,14 +18,15 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { - void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC); -} +void _depthToSpace(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC); } -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/diag.h b/libnd4j/include/ops/declarable/helpers/diag.h index af84eec01fc7..a216c9457454 100644 --- a/libnd4j/include/ops/declarable/helpers/diag.h +++ b/libnd4j/include/ops/declarable/helpers/diag.h @@ -19,17 +19,19 @@ // #ifndef __DIAG_H_HELPERS__ #define __DIAG_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void diagFunctor(sd::LaunchContext * context, NDArray const* input, NDArray* output); - void diagPartFunctor(sd::LaunchContext * context, NDArray const* input, NDArray* output); +void diagFunctor(sd::LaunchContext* context, NDArray const* input, + NDArray* output); +void diagPartFunctor(sd::LaunchContext* context, NDArray const* input, + NDArray* output); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/dilation2d.h b/libnd4j/include/ops/declarable/helpers/dilation2d.h index 281a2f26a347..1198a6b59d66 100644 --- a/libnd4j/include/ops/declarable/helpers/dilation2d.h +++ b/libnd4j/include/ops/declarable/helpers/dilation2d.h @@ -20,68 +20,77 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////// -void dilation2d(sd::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW); +void dilation2d(sd::LaunchContext *context, NDArray *input, NDArray *weights, + NDArray *output, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW); ////////////////////////////////////////////////////////////////////// -FORCEINLINE Nd4jStatus outputSize(sd::LaunchContext * context, const int inSize, const int k, const int d, const int s, bool isSameMode, int *outSize, int *padding_before, int *padding_after) { - if (s <= 0) - return Status::THROW("Dilation2D: Stride must be > 0"); - - if (d < 1) - return Status::THROW("Dilation2D: Dilation rate must be >= 1"); - - int kEff = (k - 1) * d + 1; - if (isSameMode) { - *outSize = (inSize + s - 1) / s; - const int padding_needed = sd::math::nd4j_max(0, (*outSize - 1) * s + kEff -inSize); - - *padding_before = padding_needed / 2; - *padding_after = padding_needed - *padding_before; - } else { - *outSize = (inSize - kEff + s) / s; - *padding_before = *padding_after = 0; - } - - if (*outSize < 0) - return Status::THROW("Dilation2D: outSize has negative value"); - - return Status::OK(); +FORCEINLINE Nd4jStatus outputSize(sd::LaunchContext *context, const int inSize, + const int k, const int d, const int s, + bool isSameMode, int *outSize, + int *padding_before, int *padding_after) { + if (s <= 0) return Status::THROW("Dilation2D: Stride must be > 0"); + + if (d < 1) return Status::THROW("Dilation2D: Dilation rate must be >= 1"); + + int kEff = (k - 1) * d + 1; + if (isSameMode) { + *outSize = (inSize + s - 1) / s; + const int padding_needed = + sd::math::nd4j_max(0, (*outSize - 1) * s + kEff - inSize); + + *padding_before = padding_needed / 2; + *padding_after = padding_needed - *padding_before; + } else { + *outSize = (inSize - kEff + s) / s; + *padding_before = *padding_after = 0; + } + + if (*outSize < 0) + return Status::THROW("Dilation2D: outSize has negative value"); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////// -FORCEINLINE Nd4jStatus dilation_hw(sd::LaunchContext * context, Nd4jLong const* in, Nd4jLong const* wh, std::vector &strides, std::vector &rates, bool isSameMode, int *sH, int *sW, int *pH, int *pW, int *dH, int *dW, int *oH, int *oW) { - const int iH = shape::sizeAt(in, 1); - const int iW = shape::sizeAt(in, 2); - const int iC = shape::sizeAt(in, 3); - - *sH = strides[1]; - *sW = strides[2]; - *dH = rates[1]; - *dW = rates[2]; - - const int kH = shape::sizeAt(wh, 0); - const int kW = shape::sizeAt(wh, 1); - - const int kHeff = kH + (kH - 1) * (*dH - 1); - const int kWeff = kW + (kW - 1) * (*dW - 1); - - int padding_after_unusedA, padding_after_unusedB; - if (outputSize(context, iH, kHeff, 1, *sH, isSameMode, oH, pH, &padding_after_unusedA) != Status::OK()) - return Status::THROW("Dilation2D: bad height"); - - if (outputSize(context, iW, kWeff, 1, *sW, isSameMode, oW, pW, &padding_after_unusedA) != Status::OK()) - return Status::THROW("Dilation2D: bad width"); - - return Status::OK(); +FORCEINLINE Nd4jStatus dilation_hw(sd::LaunchContext *context, + Nd4jLong const *in, Nd4jLong const *wh, + std::vector &strides, + std::vector &rates, bool isSameMode, + int *sH, int *sW, int *pH, int *pW, int *dH, + int *dW, int *oH, int *oW) { + const int iH = shape::sizeAt(in, 1); + const int iW = shape::sizeAt(in, 2); + const int iC = shape::sizeAt(in, 3); + + *sH = strides[1]; + *sW = strides[2]; + *dH = rates[1]; + *dW = rates[2]; + + const int kH = shape::sizeAt(wh, 0); + const int kW = shape::sizeAt(wh, 1); + + const int kHeff = kH + (kH - 1) * (*dH - 1); + const int kWeff = kW + (kW - 1) * (*dW - 1); + + int padding_after_unusedA, padding_after_unusedB; + if (outputSize(context, iH, kHeff, 1, *sH, isSameMode, oH, pH, + &padding_after_unusedA) != Status::OK()) + return Status::THROW("Dilation2D: bad height"); + + if (outputSize(context, iW, kWeff, 1, *sW, isSameMode, oW, pW, + &padding_after_unusedA) != Status::OK()) + return Status::THROW("Dilation2D: bad width"); + + return Status::OK(); } - - -} -} -} \ No newline at end of file +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/dropout.h b/libnd4j/include/ops/declarable/helpers/dropout.h index 052b68f33f6b..4a4c1c423ee3 100644 --- a/libnd4j/include/ops/declarable/helpers/dropout.h +++ b/libnd4j/include/ops/declarable/helpers/dropout.h @@ -19,20 +19,29 @@ // #ifndef __DROP_OUT_HELPERS__ #define __DROP_OUT_HELPERS__ -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue); - int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue); - int alphaDropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta); - int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta); +int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, + NDArray* reduceShape, int seed, double probValue); +int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, + NDArray* output, NDArray* reduceShape, int seed, + double probValue); +int alphaDropOutFunctor(graph::Context& context, NDArray* input, + NDArray* output, NDArray* reduceShape, int seed, + double probValue, double alpha, double alpha1, + double beta); +int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, + NDArray* gradOut, NDArray* output, + NDArray* reduceShape, int seed, double probValue, + double alpha, double alpha1, double beta); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/dynamic.h b/libnd4j/include/ops/declarable/helpers/dynamic.h index 29452cb8ff38..75c6ee21dbf1 100644 --- a/libnd4j/include/ops/declarable/helpers/dynamic.h +++ b/libnd4j/include/ops/declarable/helpers/dynamic.h @@ -20,21 +20,32 @@ #ifndef __DYNAMIC_H_HELPERS__ #define __DYNAMIC_H_HELPERS__ -#include #include +#include namespace sd { - namespace ops { - namespace helpers { +namespace ops { +namespace helpers { - void dynamicPartitionFunctor(sd::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector& outputList); +void dynamicPartitionFunctor(sd::LaunchContext* context, NDArray const* input, + NDArray const* indices, + std::vector& outputList); - int dynamicStitchFunctor(sd::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray* output); +int dynamicStitchFunctor(sd::LaunchContext* context, + std::vector const& inputs, + std::vector const& indices, NDArray* output); - void dynamicPartitionFunctorBP(sd::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector const& gradientInputList, std::vector& outputList); +void dynamicPartitionFunctorBP(sd::LaunchContext* context, NDArray const* input, + NDArray const* indices, + std::vector const& gradientInputList, + std::vector& outputList); - int dynamicStitchFunctorBP(sd::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray const* gradientInput, std::vector& outputList); - } - } -} +int dynamicStitchFunctorBP(sd::LaunchContext* context, + std::vector const& inputs, + std::vector const& indices, + NDArray const* gradientInput, + std::vector& outputList); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/extract_patches.h b/libnd4j/include/ops/declarable/helpers/extract_patches.h index 63d5e94f4846..1c64745d00c5 100644 --- a/libnd4j/include/ops/declarable/helpers/extract_patches.h +++ b/libnd4j/include/ops/declarable/helpers/extract_patches.h @@ -19,16 +19,18 @@ // #ifndef __EXTRACT_PATCHES_H_HELPERS__ #define __EXTRACT_PATCHES_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void extractPatches(sd::LaunchContext * context, NDArray* images, NDArray* output, int sizeRow, int sizeCol, int stradeRow, int stradeCol, int rateRow, int rateCol, bool theSame); +void extractPatches(sd::LaunchContext* context, NDArray* images, + NDArray* output, int sizeRow, int sizeCol, int stradeRow, + int stradeCol, int rateRow, int rateCol, bool theSame); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/fake_quantization.h b/libnd4j/include/ops/declarable/helpers/fake_quantization.h index b5f4dff00225..90d42a07e1b8 100644 --- a/libnd4j/include/ops/declarable/helpers/fake_quantization.h +++ b/libnd4j/include/ops/declarable/helpers/fake_quantization.h @@ -19,16 +19,19 @@ // #ifndef __FAKE_QUANTIZATION_H_HELPERS__ #define __FAKE_QUANTIZATION_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output); - void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output); -} -} -} +void fakeQuantWithMinMaxVars(NDArray* input, NDArray* min, NDArray* max, + int numBits, bool narrowed, NDArray* output); +void fakeQuantWithMinMaxVarsPerChannel(LaunchContext* context, NDArray* input, + NDArray* min, NDArray* max, int numBits, + bool narrowed, NDArray* output); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/flatten.h b/libnd4j/include/ops/declarable/helpers/flatten.h index bddf5362011e..8240d7ffa053 100644 --- a/libnd4j/include/ops/declarable/helpers/flatten.h +++ b/libnd4j/include/ops/declarable/helpers/flatten.h @@ -21,48 +21,45 @@ #ifndef SD_FLATTEN_H #define SD_FLATTEN_H -#include #include -namespace sd { -namespace ops { -namespace helpers { +#include +namespace sd { +namespace ops { +namespace helpers { ////////////////////////////////////////////////////////////////////// -void flatten(sd::LaunchContext *context, std::vector &inputs, NDArray *output, char order); - +void flatten(sd::LaunchContext *context, std::vector &inputs, + NDArray *output, char order); ////////////////////////////////////////////////////////////////////// -INLINEDEF _CUDA_HD Nd4jLong getIndexOffsetOrdered(Nd4jLong index, const Nd4jLong *shapeInfo, const char order) { - - Nd4jLong offset = 0; - - if (order == 'c') { - - for(uint i = shapeInfo[0]; i > 1; --i) { - offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; - index /= shapeInfo[i]; - } - - offset += index * shapeInfo[1 + shapeInfo[0]]; // last iteration +INLINEDEF _CUDA_HD Nd4jLong getIndexOffsetOrdered(Nd4jLong index, + const Nd4jLong *shapeInfo, + const char order) { + Nd4jLong offset = 0; + + if (order == 'c') { + for (uint i = shapeInfo[0]; i > 1; --i) { + offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; + index /= shapeInfo[i]; } - else { - - for(uint i = 1; i < shapeInfo[0]; ++i) { - offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; - index /= shapeInfo[i]; - } - offset += index * shapeInfo[2 * shapeInfo[0]]; // last iteration + offset += index * shapeInfo[1 + shapeInfo[0]]; // last iteration + } else { + for (uint i = 1; i < shapeInfo[0]; ++i) { + offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]]; + index /= shapeInfo[i]; } - return offset; -} - + offset += index * shapeInfo[2 * shapeInfo[0]]; // last iteration + } + return offset; } -} -} -#endif //SD_FLATTEN_H +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // SD_FLATTEN_H diff --git a/libnd4j/include/ops/declarable/helpers/gammaMathFunc.h b/libnd4j/include/ops/declarable/helpers/gammaMathFunc.h index 2f99f3777c31..8ed853a43173 100644 --- a/libnd4j/include/ops/declarable/helpers/gammaMathFunc.h +++ b/libnd4j/include/ops/declarable/helpers/gammaMathFunc.h @@ -23,78 +23,98 @@ #define LIBND4J_GAMMAMATHFUNC_H #include + #include "array/NDArray.h" namespace sd { namespace ops { namespace helpers { - // calculate the digamma function for each element for array - void diGamma(sd::LaunchContext* context, const NDArray& x, NDArray& z); - - // calculate the polygamma function - void polyGamma(sd::LaunchContext* context, const NDArray& n, const NDArray& x, NDArray& z); - - // calculate the digamma function for one element - // implementation is based on serial representation written in terms of the Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x) - template - _CUDA_HD T diGammaScalar(T x) { - - const int xInt = static_cast(x); - - // negative and zero - if(x <= 0) { - if(x == xInt) // integer - return DataTypeUtils::infOrMax(); - else - return diGammaScalar(1 - x) - M_PI / sd::math::nd4j_tan(M_PI * x); // use reflection formula psi(1-x) = psi(x) + pi*cot(pi*x) - } - - // positive integer - if(x == xInt && xInt <= 20) { // psi(n) = -Euler_Mascheroni_const + sum_from_k=1_to_n-1( 1/k ), for n = 1,2,3,...inf, we use this formula only for n <= 20 to avoid time consuming sum calculation for bigger n - T result = -0.577215664901532; - for (uint i = 1; i <= xInt - 1; ++i) { - result += static_cast(1) / i; - } - return result; - } - - // positive half-integer - if(x - xInt == 0.5 && xInt <= 20) { // psi(n+0.5) = -Euler_Mascheroni_const - 2*ln(2) + sum_from_k=1_to_n( 2/(2*k-1) ) , for n = 1,2,3,...inf, we use this formula only for n <= 20 to avoid time consuming sum calculation for bigger n - T result = -0.577215664901532 - 2 * sd::math::nd4j_log(2); - for (uint i = 1; i <= xInt; ++i) { - result += static_cast(2) / (2*i - 1); - } - return result; - } - - // positive, smaller then 5; we should use number > 5 in order to have satisfactory accuracy in asymptotic expansion - if(x < 5) - return diGammaScalar(1 + x) - static_cast(1) / x; // recurrence formula psi(x) = psi(x+1) - 1/x. - - // *** other positive **** // - - // truncated expansion formula (from wiki) - // psi(x) = log(x) - 1/(2*x) - 1/(12*x^2) + 1/(120*x^4) - 1/(252*x^6) + 1/(240*x^8) - 5/(660*x^10) + 691/(32760*x^12) - 1/(12*x^14) + ... - - if(x >= (sizeof(T) > 4 ? 1.e16 : 1.e8)) // if x is too big take into account only log(x) - return sd::math::nd4j_log(x); - - // coefficients used in truncated asymptotic expansion formula - const T coeffs[7] = {-(T)1/12, (T)1/120, -(T)1/252, (T)1/240, -(T)5/660, (T)691/32760, -(T)1/12}; - // const T coeffs[7] = {-0.0833333333333333, 0.00833333333333333, -0.00396825396825397, 0.00416666666666667, -0.00757575757575758, 0.0210927960927961, -0.0833333333333333}; - - const T x2Inv = static_cast(1) / (x * x); - T result = 0; - - for (int i = 6; i >= 0; --i) - result = (result + coeffs[i]) * x2Inv; - return result + sd::math::nd4j_log(x) - static_cast(0.5) / x; - } - -} -} +// calculate the digamma function for each element for array +void diGamma(sd::LaunchContext* context, const NDArray& x, NDArray& z); + +// calculate the polygamma function +void polyGamma(sd::LaunchContext* context, const NDArray& n, const NDArray& x, + NDArray& z); + +// calculate the digamma function for one element +// implementation is based on serial representation written in terms of the +// Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x) +template +_CUDA_HD T diGammaScalar(T x) { + const int xInt = static_cast(x); + + // negative and zero + if (x <= 0) { + if (x == xInt) // integer + return DataTypeUtils::infOrMax(); + else + return diGammaScalar(1 - x) - + M_PI / sd::math::nd4j_tan( + M_PI * x); // use reflection formula psi(1-x) = psi(x) + // + pi*cot(pi*x) + } + + // positive integer + if (x == xInt && + xInt <= + 20) { // psi(n) = -Euler_Mascheroni_const + sum_from_k=1_to_n-1( 1/k + // ), for n = 1,2,3,...inf, we use this formula only for n <= + // 20 to avoid time consuming sum calculation for bigger n + T result = -0.577215664901532; + for (uint i = 1; i <= xInt - 1; ++i) { + result += static_cast(1) / i; + } + return result; + } + + // positive half-integer + if (x - xInt == 0.5 && + xInt <= 20) { // psi(n+0.5) = -Euler_Mascheroni_const - 2*ln(2) + + // sum_from_k=1_to_n( 2/(2*k-1) ) , for n = 1,2,3,...inf, + // we use this formula only for n <= 20 to avoid time + // consuming sum calculation for bigger n + T result = -0.577215664901532 - 2 * sd::math::nd4j_log(2); + for (uint i = 1; i <= xInt; ++i) { + result += static_cast(2) / (2 * i - 1); + } + return result; + } + + // positive, smaller then 5; we should use number > 5 in order to have + // satisfactory accuracy in asymptotic expansion + if (x < 5) + return diGammaScalar(1 + x) - + static_cast(1) / + x; // recurrence formula psi(x) = psi(x+1) - 1/x. + + // *** other positive **** // + + // truncated expansion formula (from wiki) + // psi(x) = log(x) - 1/(2*x) - 1/(12*x^2) + 1/(120*x^4) - 1/(252*x^6) + + // 1/(240*x^8) - 5/(660*x^10) + 691/(32760*x^12) - 1/(12*x^14) + ... + + if (x >= (sizeof(T) > 4 + ? 1.e16 + : 1.e8)) // if x is too big take into account only log(x) + return sd::math::nd4j_log(x); + + // coefficients used in truncated asymptotic expansion formula + const T coeffs[7] = {-(T)1 / 12, (T)1 / 120, -(T)1 / 252, (T)1 / 240, + -(T)5 / 660, (T)691 / 32760, -(T)1 / 12}; + // const T coeffs[7] = {-0.0833333333333333, 0.00833333333333333, + // -0.00396825396825397, 0.00416666666666667, -0.00757575757575758, + // 0.0210927960927961, -0.0833333333333333}; + + const T x2Inv = static_cast(1) / (x * x); + T result = 0; + + for (int i = 6; i >= 0; --i) result = (result + coeffs[i]) * x2Inv; + return result + sd::math::nd4j_log(x) - static_cast(0.5) / x; } +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_GAMMAMATHFUNC_H +#endif // LIBND4J_GAMMAMATHFUNC_H diff --git a/libnd4j/include/ops/declarable/helpers/gather.h b/libnd4j/include/ops/declarable/helpers/gather.h index f3870838576c..bb324558627d 100644 --- a/libnd4j/include/ops/declarable/helpers/gather.h +++ b/libnd4j/include/ops/declarable/helpers/gather.h @@ -26,11 +26,13 @@ namespace sd { namespace ops { namespace helpers { - - void gather(sd::LaunchContext * context, const NDArray* input, const NDArray* indices, NDArray* output, const std::vector& intArgs); + +void gather(sd::LaunchContext* context, const NDArray* input, + const NDArray* indices, NDArray* output, + const std::vector& intArgs); } -} -} +} // namespace ops +} // namespace sd -#endif //LIBND4J_GATHER_H +#endif // LIBND4J_GATHER_H diff --git a/libnd4j/include/ops/declarable/helpers/gradient.h b/libnd4j/include/ops/declarable/helpers/gradient.h index 583396cf345f..b10298b0c7aa 100644 --- a/libnd4j/include/ops/declarable/helpers/gradient.h +++ b/libnd4j/include/ops/declarable/helpers/gradient.h @@ -19,19 +19,20 @@ // #ifndef __GRADIENT_H_HELPERS__ #define __GRADIENT_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - /* - * applyGradientDescent: calculate z = x - y * w. - * */ - void applyGradientDescent(sd::LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output); +/* + * applyGradientDescent: calculate z = x - y * w. + * */ +void applyGradientDescent(sd::LaunchContext* context, NDArray* input, + NDArray* step, double weight, NDArray* output); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/gru.h b/libnd4j/include/ops/declarable/helpers/gru.h index 9e98e4046616..afc2d9854058 100644 --- a/libnd4j/include/ops/declarable/helpers/gru.h +++ b/libnd4j/include/ops/declarable/helpers/gru.h @@ -27,33 +27,37 @@ namespace sd { namespace ops { namespace helpers { - void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc, - const NDArray* bru, const NDArray* bc, - NDArray* r, NDArray* u, NDArray* c, NDArray* h); - - void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc, const NDArray* b, - NDArray* gates, NDArray* h); - - void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h); - - void gruCellBp(sd::LaunchContext* context, - const NDArray* x, const NDArray* hLast, - const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, - const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, - NDArray* dLdx, NDArray* dLdhLast, - NDArray* dLdW, NDArray* dLdWc, - NDArray* dLdb, NDArray* dLdbc); - - void gruCellBp(sd::LaunchContext* context, - const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* gates, - NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb); - - void gruTimeLoopBp(sd::LaunchContext * context, - const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, - NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb); -} -} -} - - -#endif //LIBND4J_GRU_H \ No newline at end of file +void gruCell(sd::LaunchContext* context, const NDArray* x, const NDArray* hLast, + const NDArray* Wru, const NDArray* Wc, const NDArray* bru, + const NDArray* bc, NDArray* r, NDArray* u, NDArray* c, NDArray* h); + +void gruCell(sd::LaunchContext* context, const NDArray* x, const NDArray* hLast, + const NDArray* Wru, const NDArray* Wc, const NDArray* b, + NDArray* gates, NDArray* h); + +void gruTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* h0, const NDArray* Wx, const NDArray* Wh, + const NDArray* b, NDArray* h); + +void gruCellBp(sd::LaunchContext* context, const NDArray* x, + const NDArray* hLast, const NDArray* W, const NDArray* Wc, + const NDArray* b, const NDArray* bc, const NDArray* dLdr, + const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, + NDArray* dLdx, NDArray* dLdhLast, NDArray* dLdW, NDArray* dLdWc, + NDArray* dLdb, NDArray* dLdbc); + +void gruCellBp(sd::LaunchContext* context, const NDArray* x, const NDArray* hI, + const NDArray* Wx, const NDArray* Wh, const NDArray* b, + const NDArray* dLdh, const NDArray* gates, NDArray* dLdx, + NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb); + +void gruTimeLoopBp(sd::LaunchContext* context, const NDArray* x, + const NDArray* hI, const NDArray* Wx, const NDArray* Wh, + const NDArray* b, const NDArray* dLdh, NDArray* dLdx, + NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, + NDArray* dLdb); +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // LIBND4J_GRU_H \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/hamming.h b/libnd4j/include/ops/declarable/helpers/hamming.h index 2b6883a4f680..78e17c2dda3a 100644 --- a/libnd4j/include/ops/declarable/helpers/hamming.h +++ b/libnd4j/include/ops/declarable/helpers/hamming.h @@ -21,12 +21,15 @@ #ifndef SD_HAMMING_H #define SD_HAMMING_H +#include +#include + namespace sd { - namespace ops { - namespace helpers { - void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output); - } - } +namespace ops { +namespace helpers { +void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output); } +} // namespace ops +} // namespace sd -#endif //SD_HAMMING_H +#endif // SD_HAMMING_H diff --git a/libnd4j/include/ops/declarable/helpers/hashcode.h b/libnd4j/include/ops/declarable/helpers/hashcode.h index 6e76fc97609e..583317eb61d3 100644 --- a/libnd4j/include/ops/declarable/helpers/hashcode.h +++ b/libnd4j/include/ops/declarable/helpers/hashcode.h @@ -24,47 +24,46 @@ #include "helpers.h" namespace sd { - namespace ops { - namespace helpers { - template - FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value); +namespace ops { +namespace helpers { +template +FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value); - template <> - FORCEINLINE _CUDA_HD Nd4jLong longBytes(float value) { - int intie = *(int *)&value; - return static_cast(intie); - } - - template <> - FORCEINLINE _CUDA_HD Nd4jLong longBytes(double value) { - Nd4jLong longie = *(Nd4jLong *)&value; - return longie; - } - - template <> - FORCEINLINE _CUDA_HD Nd4jLong longBytes(float16 value) { - return longBytes((float) value); - } +template <> +FORCEINLINE _CUDA_HD Nd4jLong longBytes(float value) { + int intie = *(int *)&value; + return static_cast(intie); +} - template <> - FORCEINLINE _CUDA_HD Nd4jLong longBytes(Nd4jLong value) { - return value; - } +template <> +FORCEINLINE _CUDA_HD Nd4jLong longBytes(double value) { + Nd4jLong longie = *(Nd4jLong *)&value; + return longie; +} - template <> - FORCEINLINE _CUDA_HD Nd4jLong longBytes(bfloat16 value) { - return longBytes((float) value); - } +template <> +FORCEINLINE _CUDA_HD Nd4jLong longBytes(float16 value) { + return longBytes((float)value); +} - template - FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value) { - return longBytes((Nd4jLong) value); - } +template <> +FORCEINLINE _CUDA_HD Nd4jLong longBytes(Nd4jLong value) { + return value; +} +template <> +FORCEINLINE _CUDA_HD Nd4jLong longBytes(bfloat16 value) { + return longBytes((float)value); +} - void hashCode(LaunchContext *context, NDArray &array, NDArray &result); - } - } +template +FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value) { + return longBytes((Nd4jLong)value); } -#endif //SD_HASHCODE_H +void hashCode(LaunchContext *context, NDArray &array, NDArray &result); +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // SD_HASHCODE_H diff --git a/libnd4j/include/ops/declarable/helpers/helpers.h b/libnd4j/include/ops/declarable/helpers/helpers.h index c36387e6e57b..6651757329cb 100644 --- a/libnd4j/include/ops/declarable/helpers/helpers.h +++ b/libnd4j/include/ops/declarable/helpers/helpers.h @@ -21,29 +21,28 @@ #ifndef LIBND4J_OPS_HELPERS_H #define LIBND4J_OPS_HELPERS_H -#include -#include +#include +#include #include +#include +#include +#include +#include #include #include -#include -#include -#include + #include -#include -#include +#include #ifdef __CUDACC__ #include -#include -#include #include +#include +#include #include #include #include -#include - -#endif // CUDACC +#endif // CUDACC -#endif // LIBND4J_HELPERS_H +#endif // LIBND4J_HELPERS_H diff --git a/libnd4j/include/ops/declarable/helpers/histogram.h b/libnd4j/include/ops/declarable/helpers/histogram.h index 2963d5f0e37b..909b521e8774 100644 --- a/libnd4j/include/ops/declarable/helpers/histogram.h +++ b/libnd4j/include/ops/declarable/helpers/histogram.h @@ -24,11 +24,12 @@ #include namespace sd { - namespace ops { - namespace helpers { - void histogramHelper(sd::LaunchContext *context, NDArray &input, NDArray &output); - } - } +namespace ops { +namespace helpers { +void histogramHelper(sd::LaunchContext *context, NDArray &input, + NDArray &output); } +} // namespace ops +} // namespace sd -#endif //SD_HISTOGRAM_H +#endif // SD_HISTOGRAM_H diff --git a/libnd4j/include/ops/declarable/helpers/histogramFixedWidth.h b/libnd4j/include/ops/declarable/helpers/histogramFixedWidth.h index 40ba6ffecfff..065fc001dbbd 100644 --- a/libnd4j/include/ops/declarable/helpers/histogramFixedWidth.h +++ b/libnd4j/include/ops/declarable/helpers/histogramFixedWidth.h @@ -27,11 +27,11 @@ namespace sd { namespace ops { namespace helpers { -void histogramFixedWidth(sd::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output); +void histogramFixedWidth(sd::LaunchContext* context, const NDArray& input, + const NDArray& range, NDArray& output); - -} -} } +} // namespace ops +} // namespace sd -#endif //LIBND4J_HELPERS_HISTOGRAMFIXEDWIDTH_H +#endif // LIBND4J_HELPERS_HISTOGRAMFIXEDWIDTH_H diff --git a/libnd4j/include/ops/declarable/helpers/im2col.h b/libnd4j/include/ops/declarable/helpers/im2col.h index 87eaa3bbc05a..7c2cdfcaada7 100644 --- a/libnd4j/include/ops/declarable/helpers/im2col.h +++ b/libnd4j/include/ops/declarable/helpers/im2col.h @@ -27,9 +27,12 @@ namespace sd { namespace ops { namespace helpers { - SD_EXPORT void im2col(sd::LaunchContext & context, const NDArray& im, NDArray& col, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal); -} -} +SD_EXPORT void im2col(sd::LaunchContext& context, const NDArray& im, + NDArray& col, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, + const int dW, const NDArray& arrZeroPadVal); } +} // namespace ops +} // namespace sd -#endif //LIBND4J_HELPERS_H +#endif // LIBND4J_HELPERS_H diff --git a/libnd4j/include/ops/declarable/helpers/image_draw_bounding_boxes.h b/libnd4j/include/ops/declarable/helpers/image_draw_bounding_boxes.h index 758a02e31e34..cdb4303f4081 100644 --- a/libnd4j/include/ops/declarable/helpers/image_draw_bounding_boxes.h +++ b/libnd4j/include/ops/declarable/helpers/image_draw_bounding_boxes.h @@ -19,16 +19,17 @@ // #ifndef __IMAGE_DRAW_BOUNDING_BOXES_H_HELPERS__ #define __IMAGE_DRAW_BOUNDING_BOXES_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void drawBoundingBoxesFunctor(sd::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output); +void drawBoundingBoxesFunctor(sd::LaunchContext* context, NDArray* images, + NDArray* boxes, NDArray* colors, NDArray* output); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/image_resize.h b/libnd4j/include/ops/declarable/helpers/image_resize.h index c11e94ed4e25..374de5fdad68 100644 --- a/libnd4j/include/ops/declarable/helpers/image_resize.h +++ b/libnd4j/include/ops/declarable/helpers/image_resize.h @@ -20,37 +20,47 @@ // #ifndef __IMAGE_RESIZE_HELPERS__ #define __IMAGE_RESIZE_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - enum ImageResizeMethods { - kResizeBilinear = 1, - kResizeBicubic, - kResizeNearest, - kResizeGaussian, - kResizeLanczos5, - kResizeMitchelcubic, - kResizeArea - }; +enum ImageResizeMethods { + kResizeBilinear = 1, + kResizeBicubic, + kResizeNearest, + kResizeGaussian, + kResizeLanczos5, + kResizeMitchelcubic, + kResizeArea +}; - int resizeBilinearFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool const alignCorners, bool const halfPixelCenter, NDArray* output); - int resizeNeighborFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool const alignCorners, bool const halfPixelCenter, NDArray* output); - int resizeBicubicFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool preserveAspectRatio, bool antialias, NDArray* output); - int resizeBicubicFunctorA(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool const alignCorners, bool const halfPixelAlign, NDArray* output); - int resizeAreaFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - bool const alignCorners, NDArray* output); +int resizeBilinearFunctor(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, + NDArray* output); +int resizeNeighborFunctor(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, + NDArray* output); +int resizeBicubicFunctor(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool preserveAspectRatio, bool antialias, + NDArray* output); +int resizeBicubicFunctorA(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, bool const halfPixelAlign, + NDArray* output); +int resizeAreaFunctor(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, + bool const alignCorners, NDArray* output); - int resizeFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, - ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output); -} -} -} +int resizeFunctor(sd::LaunchContext* context, NDArray const* image, + int const width, int const height, ImageResizeMethods method, + bool preserveAspectRatio, bool antialias, NDArray* output); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/image_suppression.h b/libnd4j/include/ops/declarable/helpers/image_suppression.h index a8d2027b8708..139a70b4cb45 100644 --- a/libnd4j/include/ops/declarable/helpers/image_suppression.h +++ b/libnd4j/include/ops/declarable/helpers/image_suppression.h @@ -19,21 +19,26 @@ // #ifndef __IMAGE_SUPPRESSION_H_HELPERS__ #define __IMAGE_SUPPRESSION_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void nonMaxSuppression(sd::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output); - Nd4jLong nonMaxSuppressionV3(sd::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output); - Nd4jLong nonMaxSuppressionGeneric(sd::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, - double overlapThreshold, double scoreThreshold, NDArray* output); +void nonMaxSuppression(sd::LaunchContext* context, NDArray* boxes, + NDArray* scales, int maxSize, double overlapThreshold, + double scoreThreshold, NDArray* output); +Nd4jLong nonMaxSuppressionV3(sd::LaunchContext* context, NDArray* boxes, + NDArray* scales, int maxSize, + double overlapThreshold, double scoreThreshold, + NDArray* output); +Nd4jLong nonMaxSuppressionGeneric(sd::LaunchContext* context, NDArray* boxes, + NDArray* scores, int maxSize, + double overlapThreshold, + double scoreThreshold, NDArray* output); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/imagesHelpers.h b/libnd4j/include/ops/declarable/helpers/imagesHelpers.h index 3a1666c7aca0..25fb3d3e8a73 100644 --- a/libnd4j/include/ops/declarable/helpers/imagesHelpers.h +++ b/libnd4j/include/ops/declarable/helpers/imagesHelpers.h @@ -16,7 +16,7 @@ // // @author Oleh Semeniv (oleg.semeniv@gmail.com) -// +// // // @author AbdelRauf (rauf@konduit.ai) // @@ -24,27 +24,34 @@ #ifndef LIBND4J_HELPERS_IMAGES_H #define LIBND4J_HELPERS_IMAGES_H -#include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { - void transformRgbGrs(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC); +void transformRgbGrs(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC); - void transformHsvRgb(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); +void transformHsvRgb(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC); - void transformRgbHsv(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); - void transformYuvRgb(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC); - void transformRgbYuv(sd::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC); +void transformRgbHsv(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC); +void transformYuvRgb(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC); +void transformRgbYuv(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const int dimC); - void transformYiqRgb(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); +void transformYiqRgb(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC); - void transformRgbYiq(sd::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); -} -} -} +void transformRgbYiq(sd::LaunchContext* context, const NDArray* input, + NDArray* output, const int dimC); +} // namespace helpers +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/impl/choose.cpp b/libnd4j/include/ops/declarable/helpers/impl/choose.cpp index 2f574a52edc4..5c39c63068cc 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/choose.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/choose.cpp @@ -18,137 +18,137 @@ // @author sgazeos@gmail.com // -#include #include +#include #include namespace sd { namespace ops { namespace helpers { - template - static sd::NDArray* processCondition_(int mode,sd::NDArray *arg, sd::NDArray *comp, sd::NDArray& compScalar); - - template - static T processElementCondition(int mode,T d1,T d2); - - - template - sd::NDArray* processCondition_(int mode,sd::NDArray *arg, sd::NDArray *comp, sd::NDArray *output, sd::NDArray *numResult, sd::NDArray& compScalar) { - - //Convert to straight ndarray based on input - - int numResults = 0; - if(comp != nullptr) { - if (comp->isScalar()) { - //Other input for compare could be an ndarray or a secondary scalar - //for comparison -// sd::NDArray arg1 = *arg; -// sd::NDArray comp1 = *comp; - for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { - T result2 = processElementCondition(mode, arg->e(i), comp->e(0)); - if(result2 > static_cast(0)) { - if (output != nullptr) - output->p(numResults, arg->e(i)); - numResults++; - } - } - } else { - // REQUIRE_TRUE(comp.isSameShape(arg)); - //Other input for compare could be an ndarray or a secondary scalar - //for comparison - sd::NDArray arg1 = *arg; - for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { - T result2 = processElementCondition(mode, arg->e(i), comp->e(i)); - if(result2 > static_cast(0)) { - if (output != nullptr) - output->p(numResults, arg->e(i)); - numResults++; - } - } - } - +template +static sd::NDArray* processCondition_(int mode, sd::NDArray* arg, + sd::NDArray* comp, + sd::NDArray& compScalar); + +template +static T processElementCondition(int mode, T d1, T d2); + +template +sd::NDArray* processCondition_(int mode, sd::NDArray* arg, sd::NDArray* comp, + sd::NDArray* output, sd::NDArray* numResult, + sd::NDArray& compScalar) { + // Convert to straight ndarray based on input + + int numResults = 0; + if (comp != nullptr) { + if (comp->isScalar()) { + // Other input for compare could be an ndarray or a secondary scalar + // for comparison + // sd::NDArray arg1 = *arg; + // sd::NDArray comp1 = *comp; + for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { + T result2 = processElementCondition(mode, arg->e(i), comp->e(0)); + if (result2 > static_cast(0)) { + if (output != nullptr) output->p(numResults, arg->e(i)); + numResults++; } - else { - // sd::NDArray arg1 = *arg; - //Other input for compare could be an ndarray or a secondary scalar - //for comparison - for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { - T result2 = processElementCondition(mode, arg->e(i), compScalar.e(0)); - if(result2 > static_cast(0)) { - if (output != nullptr) - output->p(numResults, arg->e(i)); - numResults++; - } - } + } + } else { + // REQUIRE_TRUE(comp.isSameShape(arg)); + // Other input for compare could be an ndarray or a secondary scalar + // for comparison + sd::NDArray arg1 = *arg; + for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { + T result2 = processElementCondition(mode, arg->e(i), comp->e(i)); + if (result2 > static_cast(0)) { + if (output != nullptr) output->p(numResults, arg->e(i)); + numResults++; } - - if(numResult != nullptr) - numResult->p(0,numResults); - - return output; + } } - sd::NDArray* processCondition(sd::LaunchContext * context, int mode,sd::NDArray *arg, sd::NDArray *comp, sd::NDArray *output, sd::NDArray *numResult, sd::NDArray& compScalar) { - arg->syncToHost(); + } else { + // sd::NDArray arg1 = *arg; + // Other input for compare could be an ndarray or a secondary scalar + // for comparison + for (Nd4jLong i = 0; i < arg->lengthOf(); i++) { + T result2 = + processElementCondition(mode, arg->e(i), compScalar.e(0)); + if (result2 > static_cast(0)) { + if (output != nullptr) output->p(numResults, arg->e(i)); + numResults++; + } + } + } - if (comp != nullptr) - comp->syncToHost(); + if (numResult != nullptr) numResult->p(0, numResults); - if (output != nullptr) - output->syncToHost(); + return output; +} - if (numResult != nullptr) - numResult->syncToHost(); +sd::NDArray* processCondition(sd::LaunchContext* context, int mode, + sd::NDArray* arg, sd::NDArray* comp, + sd::NDArray* output, sd::NDArray* numResult, + sd::NDArray& compScalar) { + arg->syncToHost(); - compScalar.syncToHost(); + if (comp != nullptr) comp->syncToHost(); - BUILD_SINGLE_SELECTOR(arg->dataType(), return processCondition_, (mode, arg, comp, output, numResult, compScalar), FLOAT_TYPES); + if (output != nullptr) output->syncToHost(); - arg->syncToDevice(); + if (numResult != nullptr) numResult->syncToHost(); - if (comp != nullptr) - comp->syncToDevice(); + compScalar.syncToHost(); - if (output != nullptr) - output->syncToDevice(); + BUILD_SINGLE_SELECTOR(arg->dataType(), return processCondition_, + (mode, arg, comp, output, numResult, compScalar), + FLOAT_TYPES); - if (numResult != nullptr) - numResult->syncToDevice(); - - compScalar.syncToDevice(); + arg->syncToDevice(); - } - BUILD_SINGLE_TEMPLATE(template NDArray* processCondition_, (int mode,sd::NDArray *arg, sd::NDArray *comp, sd::NDArray *output, sd::NDArray *numResult, sd::NDArray& compScalar), FLOAT_TYPES); + if (comp != nullptr) comp->syncToDevice(); - template - T processElementCondition(int mode,T d1,T d2) { - T modePointer = (T ) mode; - T input[3] = {d2, (T) EPS, (T) mode}; - T res = simdOps::MatchCondition::op(d1, input); - return res; + if (output != nullptr) output->syncToDevice(); - } + if (numResult != nullptr) numResult->syncToDevice(); - void chooseFunctorArray(sd::LaunchContext * context, NDArray* arg, NDArray* comp, int mode, NDArray* result, NDArray* numResults) { - if(arg->isScalar() || comp->isScalar()) { - if(arg->isScalar()) { - processCondition(context, mode,comp,nullptr,result,numResults, *arg); - } - else { - processCondition(context, mode,arg,nullptr,result,numResults, *comp); - } - } - else { - auto zero = NDArrayFactory::create(0); - processCondition(context, mode,arg,comp,result,numResults, zero); - } - } + compScalar.syncToDevice(); +} +BUILD_SINGLE_TEMPLATE(template NDArray* processCondition_, + (int mode, sd::NDArray* arg, sd::NDArray* comp, + sd::NDArray* output, sd::NDArray* numResult, + sd::NDArray& compScalar), + FLOAT_TYPES); + +template +T processElementCondition(int mode, T d1, T d2) { + T modePointer = (T)mode; + T input[3] = {d2, (T)EPS, (T)mode}; + T res = simdOps::MatchCondition::op(d1, input); + return res; +} - void chooseFunctorScalar(sd::LaunchContext * context, NDArray* arg, double scalar, int mode, NDArray* result, NDArray* numResults) { - auto scalarA = NDArrayFactory::create(scalar); - processCondition(context, mode, arg, nullptr,result, numResults, scalarA); +void chooseFunctorArray(sd::LaunchContext* context, NDArray* arg, NDArray* comp, + int mode, NDArray* result, NDArray* numResults) { + if (arg->isScalar() || comp->isScalar()) { + if (arg->isScalar()) { + processCondition(context, mode, comp, nullptr, result, numResults, *arg); + } else { + processCondition(context, mode, arg, nullptr, result, numResults, *comp); } - -} + } else { + auto zero = NDArrayFactory::create(0); + processCondition(context, mode, arg, comp, result, numResults, zero); + } } + +void chooseFunctorScalar(sd::LaunchContext* context, NDArray* arg, + double scalar, int mode, NDArray* result, + NDArray* numResults) { + auto scalarA = NDArrayFactory::create(scalar); + processCondition(context, mode, arg, nullptr, result, numResults, scalarA); } + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/impl/gru.cpp b/libnd4j/include/ops/declarable/helpers/impl/gru.cpp index 7357ad862d47..98c5f191ab5e 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/gru.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/gru.cpp @@ -7,9 +7,9 @@ * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software - * dnIntributed under the License nIn dnIntributed on an "AS nIn" BASnIn, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permnInsions and limitations + * dnIntributed under the License nIn dnIntributed on an "AS nIn" BASnIn, + *WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See + *the License for the specific language governing permnInsions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 @@ -21,526 +21,554 @@ // implementation of gated Recurrent Unit cell // (cf. https://arxiv.org/abs/1406.1078). -// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio -// "Learning Phrase Representations using RNN Encoder-Decoder for StatnIntical Machine Translation" +// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi +// Bougares, Holger Schwenk, Yoshua Bengio "Learning Phrase Representations +// using RNN Encoder-Decoder for StatnIntical Machine Translation" - -#include +#include #include +#include #include -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* W, const NDArray* Wc, - const NDArray* b, const NDArray* bc, - NDArray* r, NDArray* u, NDArray* c, NDArray* h) { - - //Inputs: - // x input [bS, nIn], nIn - input size - // hI previous cell output [bS, nOut], that is at previous time step t-1, nOut - number of units - // W RU weights - [nIn+nOut, 2*nOut] - reset and update gates - // Wc C weights - [nIn+nOut, nOut] - cell gate - // b r and u biases, [2*nOut] - reset and update gates - // bc c biases, [nOut] - cell gate - - //Outputs: - // r Reset gate output [bS, nOut] - // u Update gate output [bS, nOut] - // c Cell gate output [bS, nOut] - // h current cell output [bS, nOut] - - /***************************************************************************************/ - /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ - /** however it is more math-friendly and convenient for backprop formulas derivation) **/ - - const int bS = x->sizeAt(0); - const int nIn = x->sizeAt(1); - const int nOut = hI->sizeAt(1); - - NDArray Wrx = (*W)({0,nIn, 0,nOut}); // [nIn, nOut] - NDArray Wux = (*W)({0,nIn, nOut,2*nOut}); // [nIn, nOut] - NDArray Wrh = (*W)({nIn,nIn+nOut, 0,nOut}); // [nOut, nOut] - NDArray Wuh = (*W)({nIn,nIn+nOut, nOut,2*nOut}); // [nOut, nOut] - - NDArray Wcx = (*Wc)({0,nIn, 0,0}); // reset cell weights [nIn, nOut] - NDArray Wch = (*Wc)({nIn,nIn+nOut, 0,0}); // updates cell weights [nOut, nOut] - - NDArray br = (*b)({0, nOut}); // [nOut] - NDArray bu = (*b)({nOut, 2*nOut}); // [nOut] - - // × means matrix multipication - // * means element-wise product or so called Hadamard product - - // reset gate - r->assign(mmul(*x, Wrx) + mmul(*hI, Wrh) + br); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut] - r->applyTransform(transform::Sigmoid, *r); - - // update gate - u->assign(mmul(*x, Wux) + mmul(*hI, Wuh) + bu); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut] - u->applyTransform(transform::Sigmoid, *u); - - // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) - c->assign(mmul(*x, Wcx) + mmul(*r * *hI, Wch) + *bc); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut] - c->applyTransform(transform::Tanh, *c); - - // cell output - h->assign(*u * *hI + (1.f - *u) * *c); +void gruCell(sd::LaunchContext* context, const NDArray* x, const NDArray* hI, + const NDArray* W, const NDArray* Wc, const NDArray* b, + const NDArray* bc, NDArray* r, NDArray* u, NDArray* c, + NDArray* h) { + // Inputs: + // x input [bS, nIn], nIn - input size + // hI previous cell output [bS, nOut], that is at previous time step + // t-1, nOut - number of units W RU weights - [nIn+nOut, 2*nOut] - + // reset and update gates Wc C weights - [nIn+nOut, nOut] - cell gate b + // r and u biases, [2*nOut] - reset and update gates bc c biases, [nOut] + // - cell gate + + // Outputs: + // r Reset gate output [bS, nOut] + // u Update gate output [bS, nOut] + // c Cell gate output [bS, nOut] + // h current cell output [bS, nOut] + + /***************************************************************************************/ + /************************ THIS IS NOT OPTIMAZED CODE + * ***********************************/ + /** however it is more math-friendly and convenient for backprop formulas + * derivation) **/ + + const int bS = x->sizeAt(0); + const int nIn = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray Wrx = (*W)({0, nIn, 0, nOut}); // [nIn, nOut] + NDArray Wux = (*W)({0, nIn, nOut, 2 * nOut}); // [nIn, nOut] + NDArray Wrh = (*W)({nIn, nIn + nOut, 0, nOut}); // [nOut, nOut] + NDArray Wuh = (*W)({nIn, nIn + nOut, nOut, 2 * nOut}); // [nOut, nOut] + + NDArray Wcx = (*Wc)({0, nIn, 0, 0}); // reset cell weights [nIn, nOut] + NDArray Wch = + (*Wc)({nIn, nIn + nOut, 0, 0}); // updates cell weights [nOut, nOut] + + NDArray br = (*b)({0, nOut}); // [nOut] + NDArray bu = (*b)({nOut, 2 * nOut}); // [nOut] + + // × means matrix multipication + // * means element-wise product or so called Hadamard product + + // reset gate + r->assign(mmul(*x, Wrx) + mmul(*hI, Wrh) + + br); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + + // [nOut] = [bS, nOut] + r->applyTransform(transform::Sigmoid, *r); + + // update gate + u->assign(mmul(*x, Wux) + mmul(*hI, Wuh) + + bu); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + + // [nOut] = [bS, nOut] + u->applyTransform(transform::Sigmoid, *u); + + // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) + c->assign(mmul(*x, Wcx) + mmul(*r * *hI, Wch) + + *bc); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + + // [nOut] = [bS, nOut] + c->applyTransform(transform::Tanh, *c); + + // cell output + h->assign(*u * *hI + (1.f - *u) * *c); } ////////////////////////////////////////////////////////////////////////// -void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, +void gruCell(sd::LaunchContext* context, const NDArray* x, const NDArray* hI, + const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* gates, NDArray* h) { - - //Inputs: - // x input [bS, nIn] - // hI previous cell output [bS, nOut], that is at previous time step t-1 - // Wx weights for x - [nIn, 3*nOut] - // Wh weights for h - [nOut, 3*nOut] - // b biases [3*nOut] - - // 3*nOut means following sequence: reset, update, cell - - //Outputs: - // gates [bS, 3*nOut] = reset gate [bS, nOut] + update gate [bS, nOut] + cell gate [bS, nOut] - // h current cell output [bS, nOut] - - // formulas: - // zr = x × Wxr + hI × Whr + br - // zu = x × Wxu + hI × Whu + bu - // r = sigmoid(zr) - // u = sigmoid(zu) - // zc = x × Wxc + (r * hI) × Whc + bc - // c = tanh(zc) - // h = (1-u)*c + u*hI - - const int bS = x->sizeAt(0); - const int nIn = x->sizeAt(1); - const int nOut = hI->sizeAt(1); - - NDArray temp = gates->ulike(); - MmulHelper::mmul(x, Wx, &temp); // [bS, nIn] × [nIn, 3*nOut] = [bS, 3*nOut] - temp += *b; - - MmulHelper::mmul(hI, Wh, gates); // [bS, nOut] × [nOut, 3*nOut] = [bS, 3*nOut] - - NDArray ru = (*gates)({0,0, 0,2*nOut}); // [bS, 2*nOut] - - NDArray r = (*gates)({0,0, 0,nOut}); // [bS, nOut] - NDArray u = (*gates)({0,0, nOut,2*nOut}); // [bS, nOut] - NDArray c = (*gates)({0,0, 2*nOut,3*nOut}); // [bS, nOut] - - // reset and update gates - ru += temp({0,0, 0,2*nOut}); - ru.applyTransform(transform::Sigmoid, ru); - - // cell gate - c.assign(c*r + temp({0,0, 2*nOut, 3*nOut})); - c.applyTransform(transform::Tanh, c); - - // cell output - h->assign(u * *hI + (1.f - u) * c); + // Inputs: + // x input [bS, nIn] + // hI previous cell output [bS, nOut], that is at previous time step + // t-1 Wx weights for x - [nIn, 3*nOut] Wh weights for h - [nOut, + // 3*nOut] b biases [3*nOut] + + // 3*nOut means following sequence: reset, update, cell + + // Outputs: + // gates [bS, 3*nOut] = reset gate [bS, nOut] + update gate [bS, nOut] + + // cell gate [bS, nOut] h current cell output [bS, nOut] + + // formulas: + // zr = x × Wxr + hI × Whr + br + // zu = x × Wxu + hI × Whu + bu + // r = sigmoid(zr) + // u = sigmoid(zu) + // zc = x × Wxc + (r * hI) × Whc + bc + // c = tanh(zc) + // h = (1-u)*c + u*hI + + const int bS = x->sizeAt(0); + const int nIn = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray temp = gates->ulike(); + MmulHelper::mmul(x, Wx, &temp); // [bS, nIn] × [nIn, 3*nOut] = [bS, 3*nOut] + temp += *b; + + MmulHelper::mmul(hI, Wh, + gates); // [bS, nOut] × [nOut, 3*nOut] = [bS, 3*nOut] + + NDArray ru = (*gates)({0, 0, 0, 2 * nOut}); // [bS, 2*nOut] + + NDArray r = (*gates)({0, 0, 0, nOut}); // [bS, nOut] + NDArray u = (*gates)({0, 0, nOut, 2 * nOut}); // [bS, nOut] + NDArray c = (*gates)({0, 0, 2 * nOut, 3 * nOut}); // [bS, nOut] + + // reset and update gates + ru += temp({0, 0, 0, 2 * nOut}); + ru.applyTransform(transform::Sigmoid, ru); + + // cell gate + c.assign(c * r + temp({0, 0, 2 * nOut, 3 * nOut})); + c.applyTransform(transform::Tanh, c); + + // cell output + h->assign(u * *hI + (1.f - u) * c); } ////////////////////////////////////////////////////////////////////////// -void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { - - // sL means time steps - - // x input [sL, bS, nIn] - // hI initial cell output (at time step = 0) [bS, nOut] - // Wx input-to-hidden weights, [nIn, 3*nOut] - // Wh hidden-to-hidden weights, [nOut, 3*nOut] - // b biases, [3*nOut] - - // h cell outputs at each time step [sL, bS, nOut] - - const int sL = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int nOut = hI->sizeAt(1); - - NDArray gates(h->ordering(), {bS, 3*nOut}, h->dataType(), context); - - auto xSet = x->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] - auto hSet = h->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] - - // time loop - for (int t = 0; t < sL; ++t) - gruCell(context, &xSet.at(t), t == 0 ? hI : &hSet.at(t-1), Wx, Wh, b, &gates, &hSet.at(t)); +void gruTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* hI, const NDArray* Wx, const NDArray* Wh, + const NDArray* b, NDArray* h) { + // sL means time steps + + // x input [sL, bS, nIn] + // hI initial cell output (at time step = 0) [bS, nOut] + // Wx input-to-hidden weights, [nIn, 3*nOut] + // Wh hidden-to-hidden weights, [nOut, 3*nOut] + // b biases, [3*nOut] + + // h cell outputs at each time step [sL, bS, nOut] + + const int sL = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray gates(h->ordering(), {bS, 3 * nOut}, h->dataType(), context); + + auto xSet = + x->allTensorsAlongDimension({1, 2}); // sub-arrays with shape [bS, nIn] + auto hSet = + h->allTensorsAlongDimension({1, 2}); // sub-arrays with shape [bS, nOut] + + // time loop + for (int t = 0; t < sL; ++t) + gruCell(context, &xSet.at(t), t == 0 ? hI : &hSet.at(t - 1), Wx, Wh, b, + &gates, &hSet.at(t)); } ////////////////////////////////////////////////////////////////////////// -void gruCellBp(sd::LaunchContext* context, - const NDArray* x, const NDArray* hLast, - const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, - const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, - NDArray* dLdx, NDArray* dLdhLast, - NDArray* dLdW, NDArray* dLdWc, - NDArray* dLdb, NDArray* dLdbc) { - - //Inputs: - // x input [bS, iS] - // hLast previous cell output [bS, nU], that is at previous time step t-1 - // W weights - [iS+nU, 2*nU] - reset and update gates - // Wc C weights - [iS+nU, nU] - cell gate - // b r and u biases, [2*nU] - reset and update gates - // bc c biases, [nU] - cell gate - // dLdr gradient wrt reset gate, [bS, nU] - // dLdu gradient wrt update gate, [bS, nU] - // dLdc gradient wrt cell state, [bS, nU] - // dLdh gradient wrt current cell output, [bS, nU] - - //Outputs: - // dLdx gradient wrt x, [bS, iS], - // dLdhLast gradient wrt hLast, [bS, nU] - // dLdW gradient wrt W, [iS+nU, 2*nU] - // dLdWc gradient wrt Wc, [iS+nU, nU] - // dLdb gradient wrt bru [2*nU] - // dLdbc gradient wrt bc [nU] - - // * means element-wise product or so called Hadamard product - // × means matrix multiplication - - /************************************************************************************************/ - /******************************* THIS IS NOT OPTIMAZED CODE *************************************/ - /*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/ - - const int bS = x->sizeAt(0); - const int iS = x->sizeAt(1); - const int nU = hLast->sizeAt(1); - - NDArray xT = x->transpose(); // [iS, bS] - NDArray hLastT = hLast->transpose(); // [nU, bS] - - NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] - NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] - NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU] - NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU] - - NDArray br = (*b)({0, nU}); // [nU] - NDArray bu = (*b)({nU, 2*nU}); // [nU] - - NDArray WrxT = Wrx.transpose(); // [nU, iS] - NDArray WuxT = Wux.transpose(); // [nU, iS] - NDArray WrhT = Wrh.transpose(); // [nU, nU] - NDArray WuhT = Wuh.transpose(); // [nU, nU] - - NDArray WcxT = Wcx.transpose(); // [nU, iS] - NDArray WchT = Wch.transpose(); // [nU, nU] - - NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU] - NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU] - NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU] - NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU] - - NDArray dLdbr = (*dLdb)({0, nU}); // [nU] - NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU] - - - // ***** feed forward step ***** // - - // reset gate - NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r.applyTransform(transform::Sigmoid, r); - - // update gate - NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u.applyTransform(transform::Sigmoid, u); - - // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc) - NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c.applyTransform(transform::Tanh, c); - - // h = (1 - u) * c + u * hPrev - - - // ***** back prop step ***** // - - // notations: - // Zr = x × Wrx + hLast × Wrh + br - // Zu = x × Wux + hLast × Wuh + bu - // Sr = sigmoid(Zr) - // Su = sigmoid(Zu) - // Zc = x × Wcx + (r * hlast) × Wch + bc - - - // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx - // = dLdx_u + dLdx_c - // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT - // dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1 - // dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT - // dZcdr = (... * hLast) × WchT - // dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT - // drdx = drdZr * dZrdx - // dZrdx = ... × WrxT - // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT - // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT - - - // dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast - // = dLdhLast_h + dLdhLast_u + dLdhLast_c - // dLdhLast_h = dLdh * dhdhLas = dLdh * u - // dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT - // dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) = - // = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast = - // = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1 - // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT - // dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT - // finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 = - // = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT - - - // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx - // dZrdWrx = xT × ... - // finally dLdWrx = xT × (dLdr * drdZr) - - - // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh - // dZrdWrh = hLastT × ... - // finally dLdWrh = hLastT × (dLdr * drdZr) - - - // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux - // dZudWux = xT × ... - // dLdu * dudZu * dZudWux = xT × (dLdu * dudZu) - - - // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh - // dZudWuh = hLastT × ... - // finally dLdWuh = hLastT × (dLdu * dudZu) - - - // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx - // dZcdWcx = xT × ... - // finally dLdWcx = xT × (dLdc * dcdZc) - - - // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch - // dZcdWch = (r*hLast)^T × ... - // finally dLdWch = (r*hLast)^T × (dLdc * dcdZc) - - - // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr = - // = dLdr * drdZr * dZrdbr - // dZrdbr = 1 - // finally dLdbr = dLdr * drdZr - - - // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu - // dZudbu = 1 - // finally dLdbu = dLdu * dudZu - - - // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc - // dZcdbc = 1 - // finally dLdbc = dLdc * dcdZc - - NDArray dhdc = 1.f - u; // [bS, nU] - NDArray dhdu = *hLast - c; // [bS, nU] - NDArray dudZu = u * dhdc; // [bS, nU] - NDArray drdZr = r * (1.f - r); // [bS, nU] - NDArray dcdZc = 1.f - c * c; // [bS, nU] - NDArray dLdZc = *dLdc * dcdZc; // [bS, nU] - NDArray dLdZu = *dLdu * dudZu; // [bS, nU] - NDArray dLdZr = *dLdr * drdZr; // [bS, nU] - - // NDArray dLdc = *dLdh * dhdc; // [bS, nU] - // NDArray dLdu = *dLdh * dhdu; // [bS, nU] - // NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU] - - dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS] - - dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU] - - dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU] - dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU] - - dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU] - - dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] - dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] - - dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] +void gruCellBp(sd::LaunchContext* context, const NDArray* x, + const NDArray* hLast, const NDArray* W, const NDArray* Wc, + const NDArray* b, const NDArray* bc, const NDArray* dLdr, + const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, + NDArray* dLdx, NDArray* dLdhLast, NDArray* dLdW, NDArray* dLdWc, + NDArray* dLdb, NDArray* dLdbc) { + // Inputs: + // x input [bS, iS] + // hLast previous cell output [bS, nU], that is at previous time + // step t-1 W weights - [iS+nU, 2*nU] - reset and update gates Wc + // C weights - [iS+nU, nU] - cell gate b r and u biases, [2*nU] - + // reset and update gates bc c biases, [nU] - cell gate dLdr + // gradient wrt reset gate, [bS, nU] dLdu gradient wrt update gate, + // [bS, nU] dLdc gradient wrt cell state, [bS, nU] dLdh gradient wrt + // current cell output, [bS, nU] + + // Outputs: + // dLdx gradient wrt x, [bS, iS], + // dLdhLast gradient wrt hLast, [bS, nU] + // dLdW gradient wrt W, [iS+nU, 2*nU] + // dLdWc gradient wrt Wc, [iS+nU, nU] + // dLdb gradient wrt bru [2*nU] + // dLdbc gradient wrt bc [nU] + + // * means element-wise product or so called Hadamard product + // × means matrix multiplication + + /************************************************************************************************/ + /******************************* THIS IS NOT OPTIMAZED CODE + * *************************************/ + /*** aim is to have math-readable code in order to keep track of backprop + * formulas derivation ***/ + + const int bS = x->sizeAt(0); + const int iS = x->sizeAt(1); + const int nU = hLast->sizeAt(1); + + NDArray xT = x->transpose(); // [iS, bS] + NDArray hLastT = hLast->transpose(); // [nU, bS] + + NDArray Wrx = (*W)({0, iS, 0, nU}); // [iS, nU] + NDArray Wux = (*W)({0, iS, nU, 2 * nU}); // [iS, nU] + NDArray Wrh = (*W)({iS, iS + nU, 0, nU}); // [nU, nU] + NDArray Wuh = (*W)({iS, iS + nU, nU, 2 * nU}); // [nU, nU] + + NDArray Wcx = (*Wc)({0, iS, 0, 0}); // reset cell weights [iS, nU] + NDArray Wch = (*Wc)({iS, iS + nU, 0, 0}); // updates cell weights [nU, nU] + + NDArray br = (*b)({0, nU}); // [nU] + NDArray bu = (*b)({nU, 2 * nU}); // [nU] + + NDArray WrxT = Wrx.transpose(); // [nU, iS] + NDArray WuxT = Wux.transpose(); // [nU, iS] + NDArray WrhT = Wrh.transpose(); // [nU, nU] + NDArray WuhT = Wuh.transpose(); // [nU, nU] + + NDArray WcxT = Wcx.transpose(); // [nU, iS] + NDArray WchT = Wch.transpose(); // [nU, nU] + + NDArray dLdWrx = (*dLdW)({0, iS, 0, nU}); // [iS, nU] + NDArray dLdWux = (*dLdW)({0, iS, nU, 2 * nU}); // [iS, nU] + NDArray dLdWrh = (*dLdW)({iS, iS + nU, 0, nU}); // [nU, nU] + NDArray dLdWuh = (*dLdW)({iS, iS + nU, nU, 2 * nU}); // [nU, nU] + + NDArray dLdWcx = (*dLdWc)({0, iS, 0, 0}); // [iS, nU] + NDArray dLdWch = (*dLdWc)({iS, iS + nU, 0, 0}); // [nU, nU] + + NDArray dLdbr = (*dLdb)({0, nU}); // [nU] + NDArray dLdbu = (*dLdb)({nU, 2 * nU}); // [nU] + + // ***** feed forward step ***** // + + // reset gate + NDArray r = + mmul(*x, Wrx) + mmul(*hLast, Wrh) + + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] + r.applyTransform(transform::Sigmoid, r); + + // update gate + NDArray u = + mmul(*x, Wux) + mmul(*hLast, Wuh) + + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] + u.applyTransform(transform::Sigmoid, u); + + // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc) + NDArray c = + mmul(*x, Wcx) + mmul(r * *hLast, Wch) + + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] + c.applyTransform(transform::Tanh, c); + + // h = (1 - u) * c + u * hPrev + + // ***** back prop step ***** // + + // notations: + // Zr = x × Wrx + hLast × Wrh + br + // Zu = x × Wux + hLast × Wuh + bu + // Sr = sigmoid(Zr) + // Su = sigmoid(Zu) + // Zc = x × Wcx + (r * hlast) × Wch + bc + + // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * + // dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx + // = dLdx_u + dLdx_c + // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu + // * dudZu) × WuxT dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * + // drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + + // dLdx_c1 dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * + // dcdZc) × WcxT dZcdr = (... * hLast) × WchT dLdc * dcdZc * dZcdr = dLdr = + // (dLdc * dcdZc * hLast) × WchT drdx = drdZr * dZrdx dZrdx = ... × WrxT + // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT + // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * + // dcdZc) × WcxT + (dLdr * drdZr) × WrxT + + // dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh + // * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast + // = dLdhLast_h + dLdhLast_u + dLdhLast_c + // dLdhLast_h = dLdh * dhdhLas = dLdh * u + // dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = + // ... × WuhT| = (dLdu * dudZu) × WuhT dLdhLast_c = dLdc * dcdhLast = dLdc * + // (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) = + // = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast = + // = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + + // dLdhLast_c1 + // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = + // (dLdc * dcdZc * r) × WchT dLdhLast_c1 = dLdr * drdhLast = |drdhLast = + // drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT finally + // dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 = + // = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × + // WchT + (dLdr * drdZr) × WrhT + + // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = + // dLdc * dcdZc * dZcdr * drdWrx = + // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx + // dZrdWrx = xT × ... + // finally dLdWrx = xT × (dLdr * drdZr) + + // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = + // dLdc * dcdZc * dZcdr * drdWrh = + // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh + // dZrdWrh = hLastT × ... + // finally dLdWrh = hLastT × (dLdr * drdZr) + + // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux + // dZudWux = xT × ... + // dLdu * dudZu * dZudWux = xT × (dLdu * dudZu) + + // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * + // dZudWuh = dLdu * dudZu * dZudWuh dZudWuh = hLastT × ... finally dLdWuh = + // hLastT × (dLdu * dudZu) + + // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * + // dZcdWcx = dLdc * dcdZc * dZcdWcx dZcdWcx = xT × ... finally dLdWcx = xT × + // (dLdc * dcdZc) + + // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * + // dZcdWch = dLdc * dcdZc * dZcdWch dZcdWch = (r*hLast)^T × ... finally dLdWch + // = (r*hLast)^T × (dLdc * dcdZc) + + // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc + // * dZcdbr = dLdc * dcdZc * dZcdr * drdbr = + // = dLdr * drdZr * dZrdbr + // dZrdbr = 1 + // finally dLdbr = dLdr * drdZr + + // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu + // dZudbu = 1 + // finally dLdbu = dLdu * dudZu + + // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc + // dZcdbc = 1 + // finally dLdbc = dLdc * dcdZc + + NDArray dhdc = 1.f - u; // [bS, nU] + NDArray dhdu = *hLast - c; // [bS, nU] + NDArray dudZu = u * dhdc; // [bS, nU] + NDArray drdZr = r * (1.f - r); // [bS, nU] + NDArray dcdZc = 1.f - c * c; // [bS, nU] + NDArray dLdZc = *dLdc * dcdZc; // [bS, nU] + NDArray dLdZu = *dLdu * dudZu; // [bS, nU] + NDArray dLdZr = *dLdr * drdZr; // [bS, nU] + + // NDArray dLdc = *dLdh * dhdc; // [bS, nU] + // NDArray dLdu = *dLdh * dhdu; // [bS, nU] + // NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU] + + dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + + mmul(dLdZr, WrxT)); // [bS, iS] + + dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + + mmul(dLdZr, WrhT)); // [bS, nU] + + dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU] + dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU] + dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU] + dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU] + + dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] + dLdWch.assign( + mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU] + + dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] + dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] + + dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] } - ////////////////////////////////////////////////////////////////////////// -void gruCellBp(sd::LaunchContext* context, - const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* gates, - NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) { - - //Inputs: - // x input [bS, nIn] - // hI previous cell output [bS, nOut], that nIn at previous time step t-1 - // Wx input-to-hidden weights - [nIn, 3*nOut] - // Wh hidden-to-hidden weights - [nOut, 3*nOut] - // b biases, [3*nOut] - reset and update gates - // dLdh gradient vs. ff output, [bS, nOut] +void gruCellBp(sd::LaunchContext* context, const NDArray* x, const NDArray* hI, + const NDArray* Wx, const NDArray* Wh, const NDArray* b, + const NDArray* dLdh, const NDArray* gates, NDArray* dLdx, + NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) { + // Inputs: + // x input [bS, nIn] + // hI previous cell output [bS, nOut], that nIn at previous time + // step t-1 Wx input-to-hidden weights - [nIn, 3*nOut] Wh + // hidden-to-hidden weights - [nOut, 3*nOut] b biases, [3*nOut] - + // reset and update gates dLdh gradient vs. ff output, [bS, nOut] - //Outputs: - // dLdx gradient vs. x, [bS, nIn], - // dLdhI gradient vs. hI, [bS, nOut] - // dLdWx gradient vs. W, [nIn, 3*nOut] - // dLdWh gradient vs. Wc, [nOut, 3*nOut] - // dLdb gradient vs. b [3*nOut] + // Outputs: + // dLdx gradient vs. x, [bS, nIn], + // dLdhI gradient vs. hI, [bS, nOut] + // dLdWx gradient vs. W, [nIn, 3*nOut] + // dLdWh gradient vs. Wc, [nOut, 3*nOut] + // dLdb gradient vs. b [3*nOut] - // 3*nOut means following sequence: reset, update, cell + // 3*nOut means following sequence: reset, update, cell - // * means element-wnIne product or so called Hadamard product - // × means matrix multiplication + // * means element-wnIne product or so called Hadamard product + // × means matrix multiplication - // formulas: - // zr = x × Wxr + hI × Whr + br - // zu = x × Wxu + hI × Whu + bu - // r = sigmoid(zr) - // u = sigmoid(zu) - // zc = x × Wxc + (r * hI) × Whc + bc - // c = tanh(zc) - // h = (1-u)*c + u*hI + // formulas: + // zr = x × Wxr + hI × Whr + br + // zu = x × Wxu + hI × Whu + bu + // r = sigmoid(zr) + // u = sigmoid(zu) + // zc = x × Wxc + (r * hI) × Whc + bc + // c = tanh(zc) + // h = (1-u)*c + u*hI - // dLdhI += dLdh; [bS, nOut] + // dLdhI += dLdh; [bS, nOut] + // dhdc = 1 - u [bS, nOut] dhdu = -c + hI [bS, nOut] - // dhdc = 1 - u [bS, nOut] - // dhdu = -c + hI [bS, nOut] + // dcdzc = 1 - c*c; [bS, nOut] dudzu = u*(1-u) [bS, nOut] drdzr = r(1-r) [bS, + // nOut] - // dcdzc = 1 - c*c; [bS, nOut] - // dudzu = u*(1-u) [bS, nOut] - // drdzr = r(1-r) [bS, nOut] + // dzcdr = (...*hI × WhcT) [bS, nOut] - // dzcdr = (...*hI × WhcT) [bS, nOut] + // dLdzr = dLdh*dhdc*dcdzc*dzcdr*drdzr = (dLdzc*hI*r(1-r) × WhcT); [bS, nOut] + // dLdzu = dLdh*dhdu*dudzu = dLdh*(hI-c)*u*(1-u) [bS, nOut] dLdzc = + // dLdh*dhdc*dcdzc = dLdh*(1-u)*(1-c*c) [bS, nOut] - // dLdzr = dLdh*dhdc*dcdzc*dzcdr*drdzr = (dLdzc*hI*r(1-r) × WhcT); [bS, nOut] - // dLdzu = dLdh*dhdu*dudzu = dLdh*(hI-c)*u*(1-u) [bS, nOut] - // dLdzc = dLdh*dhdc*dcdzc = dLdh*(1-u)*(1-c*c) [bS, nOut] + // dLdx = dLdzr × WxrT + dLdzu × WxuT + dLdzc × WxcT, [bs, nOut] × [nOut, + // nIn] + ... = [bS, nIn] - // dLdx = dLdzr × WxrT + dLdzu × WxuT + dLdzc × WxcT, [bs, nOut] × [nOut, nIn] + ... = [bS, nIn] + // dLdhI = dLdzr × WhrT + dLdzu × WhuT + dLdzc × WhcT, [bs, nOut] × [nOut, + // nOut] + ... = [bS, nOut] - // dLdhI = dLdzr × WhrT + dLdzu × WhuT + dLdzc × WhcT, [bs, nOut] × [nOut, nOut] + ... = [bS, nOut] + // dLdWxr = xT × dLdzr [nIn, bS] x [bS, nOut] = [nIn, + // nOut] dLdWxu = xT × dLdzu [nIn, bS] x [bS, nOut] = + // [nIn, nOut] dLdWxc = xT × dLdzc [nIn, bS] x [bS, + // nOut] = [nIn, nOut] - // dLdWxr = xT × dLdzr [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxu = xT × dLdzu [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxc = xT × dLdzc [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWhr = xT × dLdzr [nOut, bS] x [bS, nOut] = + // [nOut, nOut] dLdWhu = xT × dLdzu [nOut, bS] x [bS, + // nOut] = [nOut, nOut] dLdWhc = (r*hI)T × dLdzc [nOut, + // bS] x [bS, nOut] = [nOut, nOut] - // dLdWhr = xT × dLdzr [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWhu = xT × dLdzu [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWhc = (r*hI)T × dLdzc [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdbr = dLdzr.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbu = dLdzu.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbc = dLdzc.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdbr = dLdzr.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdbu = dLdzu.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdbc = dLdzc.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + const int nOut = hI->sizeAt(1); - const int nOut = hI->sizeAt(1); + NDArray dLdz = gates->ulike(); // [bS, 3*nOut] - NDArray dLdz = gates->ulike(); // [bS, 3*nOut] + NDArray dLdzru = dLdz({0, 0, 0, 2 * nOut}); // [bS, 2*nOut] - NDArray dLdzru = dLdz({0,0, 0,2*nOut}); // [bS, 2*nOut] + NDArray dLdzr = dLdz({0, 0, 0, nOut}); // [bS, nOut] + NDArray dLdzu = dLdz({0, 0, nOut, 2 * nOut}); // [bS, nOut] + NDArray dLdzc = dLdz({0, 0, 2 * nOut, 3 * nOut}); // [bS, nOut] - NDArray dLdzr = dLdz({0,0, 0,nOut}); // [bS, nOut] - NDArray dLdzu = dLdz({0,0, nOut,2*nOut}); // [bS, nOut] - NDArray dLdzc = dLdz({0,0, 2*nOut,3*nOut}); // [bS, nOut] + NDArray r = (*gates)({0, 0, 0, nOut}); // [bS, nOut] + NDArray u = (*gates)({0, 0, nOut, 2 * nOut}); // [bS, nOut] + NDArray c = (*gates)({0, 0, 2 * nOut, 3 * nOut}); // [bS, nOut] - NDArray r = (*gates)({0,0, 0,nOut}); // [bS, nOut] - NDArray u = (*gates)({0,0, nOut,2*nOut}); // [bS, nOut] - NDArray c = (*gates)({0,0, 2*nOut,3*nOut}); // [bS, nOut] + NDArray WhcT = (*Wh)({0, 0, 2 * nOut, 3 * nOut}).transpose(); - NDArray WhcT = (*Wh)({0,0, 2*nOut,3*nOut}).transpose(); + if (dLdh) *dLdhI += *dLdh; - if(dLdh) - *dLdhI += *dLdh; + NDArray temp1 = 1 - u; // [bS, nOut] - NDArray temp1 = 1 - u; // [bS, nOut] + // dLdzc + dLdzc.assign(*dLdhI * temp1 * (1 - c * c)); // [bS, nOut] - // dLdzc - dLdzc.assign(*dLdhI * temp1 * (1-c*c)); // [bS, nOut] + // dLdzu + dLdzu.assign(*dLdhI * (*hI - c) * u * temp1); // [bS, nOut] - // dLdzu - dLdzu.assign(*dLdhI * (*hI - c) * u * temp1); // [bS, nOut] + // dLdzr + NDArray temp2 = dLdzc * (*hI) * r * (1 - r); + MmulHelper::mmul(&temp2, &WhcT, + &dLdzr); // [bS, nOut] x [nOut, nOut] = [bS, nOut] - // dLdzr - NDArray temp2 = dLdzc * (*hI) * r *(1-r); - MmulHelper::mmul(&temp2, &WhcT, &dLdzr); // [bS, nOut] x [nOut, nOut] = [bS, nOut] + // dLdx + NDArray WxT = Wx->transpose(); + MmulHelper::mmul(&dLdz, &WxT, + dLdx); // [bS, 3*nOut] x [3*nOut, nIn] = [bS, nIn] - // dLdx - NDArray WxT = Wx->transpose(); - MmulHelper::mmul(&dLdz, &WxT, dLdx); // [bS, 3*nOut] x [3*nOut, nIn] = [bS, nIn] + // dLdWx + *dLdWx += + mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 3*nOut] = [nIn, 3*nOut] - // dLdWx - *dLdWx += mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 3*nOut] = [nIn, 3*nOut] + // dLdb + *dLdb += dLdz.reduceAlongDimension( + reduce::Sum, {0}); // [bS, 3*nOut] -> reduce -> [3*nOut]; - // dLdb - *dLdb += dLdz.reduceAlongDimension(reduce::Sum, {0}); // [bS, 3*nOut] -> reduce -> [3*nOut]; + dLdzc *= r; - dLdzc *= r; + // dLdhI + NDArray WhT = Wh->transpose(); + dLdhI->assign(*dLdhI * u + + mmul(dLdz, WhT)); // [bS, 3*nOut] x [3*nOut, nOut] = [bS, nOut] - // dLdhI - NDArray WhT = Wh->transpose(); - dLdhI->assign(*dLdhI*u + mmul(dLdz, WhT)); // [bS, 3*nOut] x [3*nOut, nOut] = [bS, nOut] - - // dLdWr - *dLdWh += mmul(hI->transpose(), dLdz); // [nOut, bS] x [bS, 3*nOut] = [nOut, 3*nOut] + // dLdWr + *dLdWh += mmul(hI->transpose(), + dLdz); // [nOut, bS] x [bS, 3*nOut] = [nOut, 3*nOut] } - ////////////////////////////////////////////////////////////////////////// -void gruTimeLoopBp(sd::LaunchContext * context, - const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, - NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) { - // sL means time steps - - // x input [sL, bS, nIn] - // hI initial cell output (at time step = 0) [bS, nOut] - // Wx input-to-hidden weights, [nIn, 3*nOut] - // Wh hidden-to-hidden weights, [nOut, 3*nOut] - // b biases, [3*nOut] - // dLdh gradient vs. ff output, [sL, bS, nOut] - - // dLdx gradient vs. x, [sL, bS, nIn], - // dLdhI gradient vs. hI, [bS, nOut] - // dLdWx gradient vs. W, [nIn, 3*nOut] - // dLdWh gradient vs. Wc, [nOut, 3*nOut] - // dLdb gradient vs. b [3*nOut] - - const int sL = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int nOut = hI->sizeAt(1); - - NDArray gates(x->ordering(), {sL, bS, 3*nOut}, dLdh->dataType(), x->getContext()); - NDArray h(x->ordering(), {sL+1, bS, nOut}, dLdh->dataType(), x->getContext()); - - auto xSet = x->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] - auto dLdhSet = dLdh->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] - auto hSet = h.allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] - auto gatesSet = gates.allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] - auto dLdxSet = dLdx->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] - - hSet.at(0).assign(hI); - - // forward time loop - for (int t = 0; t < sL; ++t) - gruCell(context, &xSet.at(t), &hSet.at(t), Wx, Wh, b, &gatesSet.at(t), &hSet.at(t+1)); - - // backward time loop - for (int t = sL-1; t >= 0; --t) - gruCellBp(context, &xSet.at(t), &hSet.at(t), Wx, Wh, b, &dLdhSet.at(t), &gatesSet.at(t), - &dLdxSet.at(t), dLdhI, dLdWx, dLdWh, dLdb); +void gruTimeLoopBp(sd::LaunchContext* context, const NDArray* x, + const NDArray* hI, const NDArray* Wx, const NDArray* Wh, + const NDArray* b, const NDArray* dLdh, NDArray* dLdx, + NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, + NDArray* dLdb) { + // sL means time steps + + // x input [sL, bS, nIn] + // hI initial cell output (at time step = 0) [bS, nOut] + // Wx input-to-hidden weights, [nIn, 3*nOut] + // Wh hidden-to-hidden weights, [nOut, 3*nOut] + // b biases, [3*nOut] + // dLdh gradient vs. ff output, [sL, bS, nOut] + + // dLdx gradient vs. x, [sL, bS, nIn], + // dLdhI gradient vs. hI, [bS, nOut] + // dLdWx gradient vs. W, [nIn, 3*nOut] + // dLdWh gradient vs. Wc, [nOut, 3*nOut] + // dLdb gradient vs. b [3*nOut] + + const int sL = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray gates(x->ordering(), {sL, bS, 3 * nOut}, dLdh->dataType(), + x->getContext()); + NDArray h(x->ordering(), {sL + 1, bS, nOut}, dLdh->dataType(), + x->getContext()); + + auto xSet = + x->allTensorsAlongDimension({1, 2}); // sub-arrays with shape [bS, nIn] + auto dLdhSet = dLdh->allTensorsAlongDimension( + {1, 2}); // sub-arrays with shape [bS, nOut] + auto hSet = + h.allTensorsAlongDimension({1, 2}); // sub-arrays with shape [bS, nOut] + auto gatesSet = gates.allTensorsAlongDimension( + {1, 2}); // sub-arrays with shape [bS, nOut] + auto dLdxSet = dLdx->allTensorsAlongDimension( + {1, 2}); // sub-arrays with shape [bS, nIn] + + hSet.at(0).assign(hI); + + // forward time loop + for (int t = 0; t < sL; ++t) + gruCell(context, &xSet.at(t), &hSet.at(t), Wx, Wh, b, &gatesSet.at(t), + &hSet.at(t + 1)); + + // backward time loop + for (int t = sL - 1; t >= 0; --t) + gruCellBp(context, &xSet.at(t), &hSet.at(t), Wx, Wh, b, &dLdhSet.at(t), + &gatesSet.at(t), &dLdxSet.at(t), dLdhI, dLdWx, dLdWh, dLdb); } - -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/impl/knn_mindistance.cpp b/libnd4j/include/ops/declarable/helpers/impl/knn_mindistance.cpp index 1a61587a3582..f7187265d896 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/knn_mindistance.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/knn_mindistance.cpp @@ -21,42 +21,47 @@ #include namespace sd { - namespace ops { - namespace helpers { - template - void mindistance_(const void* vinput, const void *vlow, const void *vhigh, int32_t length, void *vout) { - auto input = reinterpret_cast(vinput); - auto low = reinterpret_cast(vlow); - auto high = reinterpret_cast(vhigh); - auto output = reinterpret_cast(vout); - - T res = 0.0f; - T po = 2.f; - T o = 1.f; - -#pragma omp simd reduction(sumT:res) - for (auto e = 0; e < length; e++) { - T p = input[e]; - T l = low[e]; - T h = high[e]; - if (!(l <= p || h <= p)) { - if (p < l) - res += sd::math::nd4j_pow((p - o), po); - else - res += sd::math::nd4j_pow((p - h), po); - } - } - - output[0] = sd::math::nd4j_pow(res, (T) 0.5f); - } - - void knn_mindistance(const NDArray &input, const NDArray &lowest, const NDArray &highest, NDArray &output) { - NDArray::preparePrimaryUse({&output}, {&input, &lowest, &highest}); - - BUILD_SINGLE_SELECTOR(input.dataType(), mindistance_, (input.buffer(), lowest.buffer(), highest.buffer(), input.lengthOf(), output.buffer()), FLOAT_TYPES); - - NDArray::registerPrimaryUse({&output}, {&input, &lowest, &highest}); - } - } +namespace ops { +namespace helpers { +template +void mindistance_(const void *vinput, const void *vlow, const void *vhigh, + int32_t length, void *vout) { + auto input = reinterpret_cast(vinput); + auto low = reinterpret_cast(vlow); + auto high = reinterpret_cast(vhigh); + auto output = reinterpret_cast(vout); + + T res = 0.0f; + T po = 2.f; + T o = 1.f; + +#pragma omp simd reduction(sumT : res) + for (auto e = 0; e < length; e++) { + T p = input[e]; + T l = low[e]; + T h = high[e]; + if (!(l <= p || h <= p)) { + if (p < l) + res += sd::math::nd4j_pow((p - o), po); + else + res += sd::math::nd4j_pow((p - h), po); } -} \ No newline at end of file + } + + output[0] = sd::math::nd4j_pow(res, (T)0.5f); +} + +void knn_mindistance(const NDArray &input, const NDArray &lowest, + const NDArray &highest, NDArray &output) { + NDArray::preparePrimaryUse({&output}, {&input, &lowest, &highest}); + + BUILD_SINGLE_SELECTOR(input.dataType(), mindistance_, + (input.buffer(), lowest.buffer(), highest.buffer(), + input.lengthOf(), output.buffer()), + FLOAT_TYPES); + + NDArray::registerPrimaryUse({&output}, {&input, &lowest, &highest}); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp b/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp index 6d937dc0fbb2..e975f61f4d99 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp @@ -19,105 +19,122 @@ // #include + #include //#include namespace sd { namespace ops { namespace helpers { - template - static Nd4jLong listDiffCount_(NDArray* values, NDArray* keep) { - Nd4jLong saved = 0L; - for (Nd4jLong e = 0; e < values->lengthOf(); e++) { - auto v = values->e(e); - ExtraArguments extras({v, 0.0, 10.0}); - auto idx = keep->indexReduceNumber(indexreduce::FirstIndex, &extras); - auto index = idx.e(0); - if (index < 0) - saved++; - } - return saved; - } +template +static Nd4jLong listDiffCount_(NDArray* values, NDArray* keep) { + Nd4jLong saved = 0L; + for (Nd4jLong e = 0; e < values->lengthOf(); e++) { + auto v = values->e(e); + ExtraArguments extras({v, 0.0, 10.0}); + auto idx = keep->indexReduceNumber(indexreduce::FirstIndex, &extras); + auto index = idx.e(0); + if (index < 0) saved++; + } + return saved; +} - Nd4jLong listDiffCount(sd::LaunchContext * context, NDArray* values, NDArray* keep) { - auto xType = values->dataType(); +Nd4jLong listDiffCount(sd::LaunchContext* context, NDArray* values, + NDArray* keep) { + auto xType = values->dataType(); - NDArray::preparePrimaryUse({},{values, keep}); + NDArray::preparePrimaryUse({}, {values, keep}); - BUILD_SINGLE_SELECTOR(xType, return listDiffCount_, (values, keep), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, return listDiffCount_, (values, keep), + LIBND4J_TYPES); - NDArray::registerPrimaryUse({},{values, keep}); - } + NDArray::registerPrimaryUse({}, {values, keep}); +} - BUILD_SINGLE_TEMPLATE(template Nd4jLong listDiffCount_, (NDArray* values, NDArray* keep);, LIBND4J_TYPES); - - template - static int listDiffFunctor_(NDArray* values, NDArray* keep, NDArray* output1, NDArray* output2) { - - std::vector saved; - std::vector indices; - - for (Nd4jLong e = 0; e < values->lengthOf(); e++) { - auto v = values->e(e); - ExtraArguments extras({v, 0.0, 10.0}); - NDArray idxScalar = keep->indexReduceNumber(indexreduce::FirstIndex, &extras); - Nd4jLong idx = idxScalar.e(0); - - if (idx < 0) { - saved.emplace_back(v); - indices.emplace_back(e); - } - } - - - if (saved.size() == 0) { -// if (sd::ops::conditionHelper(__FILE__, __LINE__, false, 0, "ListDiff: search returned no results") != 0) - nd4j_printf("ListDiff: search returned no results", ""); - throw std::invalid_argument("Op validation failed"); - } else { - auto z0 = output1;//OUTPUT_VARIABLE(0); //new NDArray('c', {(int) saved.size()}); - auto z1 = output2; //OUTPUT_VARIABLE(1); //new NDArray('c', {(int) saved.size()}); - - if (z0->lengthOf() != saved.size()) { - nd4j_printf("ListDiff: output/actual size mismatch", ""); - throw std::invalid_argument("Op validation failed"); - } - - if (z1->lengthOf() != saved.size()) { - nd4j_printf("ListDiff: output/actual indices size mismatch", ""); - throw std::invalid_argument("Op validation failed"); - } - memcpy(z0->buffer(), saved.data(), saved.size() * sizeof(T)); - for (int e = 0; e < indices.size(); e++) { - z1->p(e, indices[e]); - } - } - return Status::OK(); +BUILD_SINGLE_TEMPLATE(template Nd4jLong listDiffCount_, + (NDArray * values, NDArray* keep); + , LIBND4J_TYPES); + +template +static int listDiffFunctor_(NDArray* values, NDArray* keep, NDArray* output1, + NDArray* output2) { + std::vector saved; + std::vector indices; + + for (Nd4jLong e = 0; e < values->lengthOf(); e++) { + auto v = values->e(e); + ExtraArguments extras({v, 0.0, 10.0}); + NDArray idxScalar = + keep->indexReduceNumber(indexreduce::FirstIndex, &extras); + Nd4jLong idx = idxScalar.e(0); + + if (idx < 0) { + saved.emplace_back(v); + indices.emplace_back(e); + } + } + + if (saved.size() == 0) { + // if (sd::ops::conditionHelper(__FILE__, __LINE__, false, 0, + // "ListDiff: search returned no results") != 0) + nd4j_printf("ListDiff: search returned no results", ""); + throw std::invalid_argument("Op validation failed"); + } else { + auto z0 = output1; // OUTPUT_VARIABLE(0); //new NDArray('c', {(int) + // saved.size()}); + auto z1 = output2; // OUTPUT_VARIABLE(1); //new NDArray('c', {(int) + // saved.size()}); + + if (z0->lengthOf() != saved.size()) { + nd4j_printf("ListDiff: output/actual size mismatch", ""); + throw std::invalid_argument("Op validation failed"); } - int listDiffFunctor(sd::LaunchContext * context, NDArray* values, NDArray* keep, NDArray* output1, NDArray* output2) { - auto xType = values->dataType(); - - NDArray::preparePrimaryUse({output1, output2}, {values, keep}); + if (z1->lengthOf() != saved.size()) { + nd4j_printf("ListDiff: output/actual indices size mismatch", ""); + throw std::invalid_argument("Op validation failed"); + } + memcpy(z0->buffer(), saved.data(), saved.size() * sizeof(T)); + for (int e = 0; e < indices.size(); e++) { + z1->p(e, indices[e]); + } + } + return Status::OK(); +} - int result = 0; +int listDiffFunctor(sd::LaunchContext* context, NDArray* values, NDArray* keep, + NDArray* output1, NDArray* output2) { + auto xType = values->dataType(); - if (DataTypeUtils::isR(xType)) { - BUILD_SINGLE_SELECTOR(xType, result = listDiffFunctor_, (values, keep, output1, output2), FLOAT_TYPES); - } else if (DataTypeUtils::isZ(xType)) { - BUILD_SINGLE_SELECTOR(xType, result = listDiffFunctor_, (values, keep, output1, output2), INTEGER_TYPES); - } else { - throw std::runtime_error("ListDiff: Only integer and floating point data types are supported"); - } + NDArray::preparePrimaryUse({output1, output2}, {values, keep}); - NDArray::registerPrimaryUse({output1, output2}, {values, keep}); + int result = 0; - return result; - } + if (DataTypeUtils::isR(xType)) { + BUILD_SINGLE_SELECTOR(xType, result = listDiffFunctor_, + (values, keep, output1, output2), FLOAT_TYPES); + } else if (DataTypeUtils::isZ(xType)) { + BUILD_SINGLE_SELECTOR(xType, result = listDiffFunctor_, + (values, keep, output1, output2), INTEGER_TYPES); + } else { + throw std::runtime_error( + "ListDiff: Only integer and floating point data types are supported"); + } - BUILD_SINGLE_TEMPLATE(template int listDiffFunctor_, (NDArray* values, NDArray* keep, NDArray* output1, NDArray* output2);, FLOAT_TYPES); - BUILD_SINGLE_TEMPLATE(template int listDiffFunctor_, (NDArray* values, NDArray* keep, NDArray* output1, NDArray* output2);, INTEGER_TYPES); + NDArray::registerPrimaryUse({output1, output2}, {values, keep}); + return result; } -} -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template int listDiffFunctor_, + (NDArray * values, NDArray* keep, NDArray* output1, + NDArray* output2); + , FLOAT_TYPES); +BUILD_SINGLE_TEMPLATE(template int listDiffFunctor_, + (NDArray * values, NDArray* keep, NDArray* output1, + NDArray* output2); + , INTEGER_TYPES); + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstm.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstm.cpp index 4ab585e26f83..5bb3942790f3 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/lstm.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/lstm.cpp @@ -20,122 +20,127 @@ // implementation of operation for LSTM cell with peep hole connections: // http://www.bioinf.jku.at/publications/older/2604.pdf -// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. -// and +// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural +// Computation, 9(8):1735-1780, 1997. and // https://research.google.com/pubs/archive/43905.pdf -// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. +// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory +// recurrent neural network architectures for large scale acoustic modeling." +// INTERSPEECH, 2014. - -#include +#include #include +#include #include -#include #include -#include +#include +#include + #include -#include namespace sd { - namespace ops { - namespace helpers { +namespace ops { +namespace helpers { ///////////////////////////////////////////////////////////////////////////// - void lstmBlockTimeLoop(const NDArray* maxSeqLength, const NDArray* xSeq, const NDArray* c0, const NDArray* y0, - const NDArray* W, const NDArray* Wci, const NDArray* Wcf, const NDArray* Wco, const NDArray* b, - const NDArray* iSeq, const NDArray* cSeq, const NDArray* fSeq, const NDArray* oSeq, const NDArray* zSeq, - const NDArray* hSeq, const NDArray* ySeq, const std::vector& params, const int dataFormat){ - - int seqLen, bS, nIn, nOut; - - if(dataFormat == 0) { - seqLen = xSeq->sizeAt(0); - bS = xSeq->sizeAt(1); - nIn = xSeq->sizeAt(2); - nOut = iSeq->sizeAt(2); - } - else if(dataFormat == 1) { - seqLen = xSeq->sizeAt(2); - bS = xSeq->sizeAt(0); - nIn = xSeq->sizeAt(1); - nOut = iSeq->sizeAt(1); - } - else if(dataFormat == 2) { - seqLen = xSeq->sizeAt(1); - bS = xSeq->sizeAt(0); - nIn = xSeq->sizeAt(2); - nOut = iSeq->sizeAt(2); - } - - const std::vector inSliceShape({bS,nIn}); - const std::vector outSliceShape({bS,nOut}); - - auto c_t1 = const_cast(c0); - auto y_t1 = const_cast(y0); - - // loop through time steps - for (int t = 0; t < seqLen; ++t) { - - auto xt = timeSubset(xSeq, t, dataFormat); - - auto it = timeSubset(iSeq, t, dataFormat); - auto ct = timeSubset(cSeq, t, dataFormat); - auto ft = timeSubset(fSeq, t, dataFormat); - auto ot = timeSubset(oSeq, t, dataFormat); - auto zt = timeSubset(zSeq, t, dataFormat); - auto ht = timeSubset(hSeq, t, dataFormat); - auto yt = timeSubset(ySeq, t, dataFormat); - - helpers::lstmBlockCell(&xt, c_t1, y_t1, W, Wci, Wcf, Wco, b, &it, &ct, &ft, &ot, &zt, &ht, &yt, params); - - if(t != 0) { - delete c_t1; - delete y_t1; - } - - if(t < seqLen - 1) { - c_t1 = new NDArray(std::move(ct)); - y_t1 = new NDArray(std::move(yt)); - } - } - } - - - - ////////////////////////////////////////////////////////////////////////// - void lstmTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* c0, const NDArray* Wx, const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, const NDArray* b, - NDArray* h, NDArray* c, const std::vector& params) { - - // x input [time x bS x nIn] - // h0 initial cell output (at time step = 0) [bS x numProj], in case of projection=false -> numProj == numUnits !!! - // c0 initial cell state (at time step = 0) [bS x numUnits], - - // Wx input-to-hidden weights, [nIn x 4*numUnits] - // Wh hidden-to-hidden weights, [numProj x 4*numUnits] - // Wc diagonal weights for peephole connections [3*numUnits] - // Wp projection weights [numUnits x numProj] - // b biases, [4*numUnits] - - // h cell outputs [time x bS x numProj], that is per each time step - // c cell states [time x bS x numUnits] that is per each time step - - const int time = x->sizeAt(0); - - NDArray currentH(*h0); - NDArray currentC(*c0); - - // loop through time steps - for (int t = 0; t < time; ++t) { - auto xt = (*x)({t,t+1, 0,0, 0,0}); - auto ht = (*h)({t,t+1, 0,0, 0,0}); - auto ct = (*c)({t,t+1, 0,0, 0,0}); - - helpers::lstmCell(context, &xt,¤tH,¤tC, Wx,Wh,Wc,Wp, b, &ht, &ct, params); - currentH.assign(ht); - currentC.assign(ct); - } - } - - - } +void lstmBlockTimeLoop(const NDArray* maxSeqLength, const NDArray* xSeq, + const NDArray* c0, const NDArray* y0, const NDArray* W, + const NDArray* Wci, const NDArray* Wcf, + const NDArray* Wco, const NDArray* b, + const NDArray* iSeq, const NDArray* cSeq, + const NDArray* fSeq, const NDArray* oSeq, + const NDArray* zSeq, const NDArray* hSeq, + const NDArray* ySeq, const std::vector& params, + const int dataFormat) { + int seqLen, bS, nIn, nOut; + + if (dataFormat == 0) { + seqLen = xSeq->sizeAt(0); + bS = xSeq->sizeAt(1); + nIn = xSeq->sizeAt(2); + nOut = iSeq->sizeAt(2); + } else if (dataFormat == 1) { + seqLen = xSeq->sizeAt(2); + bS = xSeq->sizeAt(0); + nIn = xSeq->sizeAt(1); + nOut = iSeq->sizeAt(1); + } else if (dataFormat == 2) { + seqLen = xSeq->sizeAt(1); + bS = xSeq->sizeAt(0); + nIn = xSeq->sizeAt(2); + nOut = iSeq->sizeAt(2); + } + + const std::vector inSliceShape({bS, nIn}); + const std::vector outSliceShape({bS, nOut}); + + auto c_t1 = const_cast(c0); + auto y_t1 = const_cast(y0); + + // loop through time steps + for (int t = 0; t < seqLen; ++t) { + auto xt = timeSubset(xSeq, t, dataFormat); + + auto it = timeSubset(iSeq, t, dataFormat); + auto ct = timeSubset(cSeq, t, dataFormat); + auto ft = timeSubset(fSeq, t, dataFormat); + auto ot = timeSubset(oSeq, t, dataFormat); + auto zt = timeSubset(zSeq, t, dataFormat); + auto ht = timeSubset(hSeq, t, dataFormat); + auto yt = timeSubset(ySeq, t, dataFormat); + + helpers::lstmBlockCell(&xt, c_t1, y_t1, W, Wci, Wcf, Wco, b, &it, &ct, &ft, + &ot, &zt, &ht, &yt, params); + + if (t != 0) { + delete c_t1; + delete y_t1; + } + + if (t < seqLen - 1) { + c_t1 = new NDArray(std::move(ct)); + y_t1 = new NDArray(std::move(yt)); } -} \ No newline at end of file + } +} + +////////////////////////////////////////////////////////////////////////// +void lstmTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* h0, const NDArray* c0, const NDArray* Wx, + const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, + const NDArray* b, NDArray* h, NDArray* c, + const std::vector& params) { + // x input [time x bS x nIn] + // h0 initial cell output (at time step = 0) [bS x numProj], in case of + // projection=false -> numProj == numUnits !!! c0 initial cell state (at time + // step = 0) [bS x numUnits], + + // Wx input-to-hidden weights, [nIn x 4*numUnits] + // Wh hidden-to-hidden weights, [numProj x 4*numUnits] + // Wc diagonal weights for peephole connections [3*numUnits] + // Wp projection weights [numUnits x numProj] + // b biases, [4*numUnits] + + // h cell outputs [time x bS x numProj], that is per each time step + // c cell states [time x bS x numUnits] that is per each time step + + const int time = x->sizeAt(0); + + NDArray currentH(*h0); + NDArray currentC(*c0); + + // loop through time steps + for (int t = 0; t < time; ++t) { + auto xt = (*x)({t, t + 1, 0, 0, 0, 0}); + auto ht = (*h)({t, t + 1, 0, 0, 0, 0}); + auto ct = (*c)({t, t + 1, 0, 0, 0, 0}); + + helpers::lstmCell(context, &xt, ¤tH, ¤tC, Wx, Wh, Wc, Wp, b, + &ht, &ct, params); + currentH.assign(ht); + currentC.assign(ct); + } +} + +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp index 125b0cd3ca09..6bfca621d6e6 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp @@ -21,17 +21,18 @@ // implementation of operation for LSTM cell with peep hole connections: // http://www.bioinf.jku.at/publications/older/2604.pdf -// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. -// and +// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural +// Computation, 9(8):1735-1780, 1997. and // https://research.google.com/pubs/archive/43905.pdf -// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. +// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory +// recurrent neural network architectures for large scale acoustic modeling." +// INTERSPEECH, 2014. - -#include #include -#include -#include #include +#include +#include +#include // #include // #include // #include @@ -39,1176 +40,1379 @@ // #include // #include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -static void applyActivation(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) { - - switch (opId) { - case 0: - (const_cast(x)).applyTransform(transform::Tanh, z); - break; - case 1: - (const_cast(x)).applyScalar(scalar::RELU, 0, z); - break; - case 2: - (const_cast(x)).applyTransform(transform::Sigmoid, z); - break; - case 3: { - ExtraArguments args({ static_cast(alpha), static_cast(beta)}); - (const_cast(x)).applyTransform(transform::Affine, z, &args); - break; - } - case 4: - (const_cast(x)).applyScalar(scalar::LeakyRELU, alpha, z); - break; - case 5: - thresholdRelu(x.getContext(), x, alpha, z); - break; - case 6: { - ExtraArguments args({ static_cast(alpha), static_cast(beta)}); - (const_cast(x)).applyTransform(transform::ScaledTanh, z, &args); - break; - } - case 7: - (const_cast(x)).applyTransform(transform::HardSigmoid, z); - break; - case 8: - (const_cast(x)).applyScalar(scalar::ELU, alpha, z); - break; - case 9: - (const_cast(x)).applyTransform(transform::SoftSign, z); - break; - case 10: - (const_cast(x)).applyTransform(transform::SoftPlus, z); - break; - default: - throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !"); +static void applyActivation(const NDArray& x, const int opId, const float alpha, + const float beta, NDArray& z) { + switch (opId) { + case 0: + (const_cast(x)).applyTransform(transform::Tanh, z); + break; + case 1: + (const_cast(x)).applyScalar(scalar::RELU, 0, z); + break; + case 2: + (const_cast(x)).applyTransform(transform::Sigmoid, z); + break; + case 3: { + ExtraArguments args( + {static_cast(alpha), static_cast(beta)}); + (const_cast(x)).applyTransform(transform::Affine, z, &args); + break; } + case 4: + (const_cast(x)).applyScalar(scalar::LeakyRELU, alpha, z); + break; + case 5: + thresholdRelu(x.getContext(), x, alpha, z); + break; + case 6: { + ExtraArguments args( + {static_cast(alpha), static_cast(beta)}); + (const_cast(x)).applyTransform(transform::ScaledTanh, z, &args); + break; + } + case 7: + (const_cast(x)).applyTransform(transform::HardSigmoid, z); + break; + case 8: + (const_cast(x)).applyScalar(scalar::ELU, alpha, z); + break; + case 9: + (const_cast(x)).applyTransform(transform::SoftSign, z); + break; + case 10: + (const_cast(x)).applyTransform(transform::SoftPlus, z); + break; + default: + throw std::invalid_argument( + "LSTM_LAYER operation: wrong id number of activation !"); + } } ////////////////////////////////////////////////////////////////////////// -static void activationDeriv(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) { - - switch (opId) { - case 0: - (const_cast(x)).applyTransform(transform::TanhDerivative, z); - break; - case 1: - (const_cast(x)).applyScalar(scalar::RELUDerivative, 0, z); - break; - case 2: - (const_cast(x)).applyTransform(transform::SigmoidDerivative, z); - break; - case 3: { - z = alpha; - break; - } - case 4: - (const_cast(x)).applyScalar(scalar::LeakyRELUDerivative, alpha, z); - break; - case 5: - (const_cast(x)).applyScalar(scalar::RELUDerivative, alpha, z); - break; - case 6: { - auto func = PRAGMA_THREADS_FOR { - for(Nd4jLong i = start; i < stop; ++i) { - auto val = beta * x.e(i); - z.p(i, alpha * beta * (1.f - sd::math::nd4j_tanh(val) * sd::math::nd4j_tanh(val))); - } - }; - samediff::Threads::parallel_for(func, 0, x.lengthOf()); - break; +static void activationDeriv(const NDArray& x, const int opId, const float alpha, + const float beta, NDArray& z) { + switch (opId) { + case 0: + (const_cast(x)).applyTransform(transform::TanhDerivative, z); + break; + case 1: + (const_cast(x)) + .applyScalar(scalar::RELUDerivative, 0, z); + break; + case 2: + (const_cast(x)).applyTransform(transform::SigmoidDerivative, z); + break; + case 3: { + z = alpha; + break; + } + case 4: + (const_cast(x)) + .applyScalar(scalar::LeakyRELUDerivative, alpha, z); + break; + case 5: + (const_cast(x)) + .applyScalar(scalar::RELUDerivative, alpha, z); + break; + case 6: { + auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = start; i < stop; ++i) { + auto val = beta * x.e(i); + z.p(i, alpha * beta * + (1.f - sd::math::nd4j_tanh(val) * + sd::math::nd4j_tanh(val))); } - case 7: - (const_cast(x)).applyTransform(transform::HardSigmoidDerivative, z); - break; - case 8: - (const_cast(x)).applyScalar(scalar::ELUDerivative, alpha, z); - break; - case 9: - (const_cast(x)).applyTransform(transform::SoftSignDerivative, z); - break; - case 10: { - auto func = PRAGMA_THREADS_FOR { - for(Nd4jLong i = start; i < stop; ++i) { - auto val = sd::math::nd4j_exp(x.e(i)); - z.p(i, val / (1.f + val)); - } - }; - samediff::Threads::parallel_for(func, 0, x.lengthOf()); - break; + }; + samediff::Threads::parallel_for(func, 0, x.lengthOf()); + break; + } + case 7: + (const_cast(x)) + .applyTransform(transform::HardSigmoidDerivative, z); + break; + case 8: + (const_cast(x)) + .applyScalar(scalar::ELUDerivative, alpha, z); + break; + case 9: + (const_cast(x)) + .applyTransform(transform::SoftSignDerivative, z); + break; + case 10: { + auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = start; i < stop; ++i) { + auto val = sd::math::nd4j_exp(x.e(i)); + z.p(i, val / (1.f + val)); } - default: - throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !"); + }; + samediff::Threads::parallel_for(func, 0, x.lengthOf()); + break; } + default: + throw std::invalid_argument( + "LSTM_LAYER operation: wrong id number of activation !"); + } } ////////////////////////////////////////////////////////////////////////// -// FIXME - derivative undefined when not-clipped c has element/elements equal to -clipVal or clipVal -static void clipDeriv(const float clipVal, const NDArray& c, NDArray& z0, NDArray& z1, NDArray& z2, NDArray& z3) { - - if(clipVal == 0) - return; - - auto func = PRAGMA_THREADS_FOR { - for(Nd4jLong i = start; i < stop; ++i) { - const auto val = c.e(i); - if(val == -clipVal || val == clipVal) { - z0.p(i, 0.f); - z1.p(i, 0.f); - z2.p(i, 0.f); - z3.p(i, 0.f); - } - } - }; - samediff::Threads::parallel_for(func, 0, c.lengthOf()); +// FIXME - derivative undefined when not-clipped c has element/elements equal to +// -clipVal or clipVal +static void clipDeriv(const float clipVal, const NDArray& c, NDArray& z0, + NDArray& z1, NDArray& z2, NDArray& z3) { + if (clipVal == 0) return; + + auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = start; i < stop; ++i) { + const auto val = c.e(i); + if (val == -clipVal || val == clipVal) { + z0.p(i, 0.f); + z1.p(i, 0.f); + z2.p(i, 0.f); + z3.p(i, 0.f); + } + } + }; + samediff::Threads::parallel_for(func, 0, c.lengthOf()); } ////////////////////////////////////////////////////////////////////////// -static NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) { - - if(dataFormat == 0 || dataFormat == 3) - return arr({t1,t2, b1,b2, 0,0}); // TNS: [sL, bS, nIn] +static NDArray tensorAlongTimeBatchDims(const NDArray& arr, + const int dataFormat, const int t1, + const int t2, const int b1, + const int b2) { + if (dataFormat == 0 || dataFormat == 3) + return arr({t1, t2, b1, b2, 0, 0}); // TNS: [sL, bS, nIn] - if(dataFormat == 1) - return arr({b1,b2, t1,t2, 0,0}); // NTS: [bS, sL ,nIn] + if (dataFormat == 1) + return arr({b1, b2, t1, t2, 0, 0}); // NTS: [bS, sL ,nIn] - return arr({b1,b2, 0,0, t1,t2}); // NST: [bS, nIn, sL] + return arr({b1, b2, 0, 0, t1, t2}); // NST: [bS, nIn, sL] } ////////////////////////////////////////////////////////////////////////// -static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) { - - if(dataFormat == 0 || dataFormat == 3) - return t * bS + b; // TNS: shape [sL, bS, nIn] +static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, + const int sL, const int bS, + const int t, const int b) { + if (dataFormat == 0 || dataFormat == 3) + return t * bS + b; // TNS: shape [sL, bS, nIn] - return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL] + return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL] } - ////////////////////////////////////////////////////////////////////////// void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const std::vector& params, + const NDArray* b, const NDArray* hI, const NDArray* cI, + const NDArray* Wp, const std::vector& params, NDArray* h, NDArray* c) { - - // * -> means element-wise multiplication - // × -> means matrix multiplication - - /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ - /** the objective is to provide math-readable code **/ - - // equations (no peephole connections) - // it = σ(Wxi × xt + Wri × ht-1 + bi) - // ft = σ(Wxf × xt + Wrf × ht-1 + bf) - // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) - // ct = ft * ct-1 + it * c't - // ot = σ(Wxo × xt + Wro × ht-1 + bo) - // ht = ot * tanh(ct) - - // equations (peephole connections are present) - // it = σ(Wxi × xt + Wri × ht-1 + Wpi * ct-1 + bi) - // ft = σ(Wxf × xt + Wrf × ht-1 + Wpf * ct-1 + bf) - // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) - // ct = ft * ct-1 + it * c't - // ot = σ(Wxo × xt + Wro × ht-1 + Wpo * ct + bo) - // ht = ot * tanh(ct) - - - // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus - - // params[0] - dataFormat, ignore - // params[1] - directionMode, ignore - // params[2] - cell clipping value, if it = 0 then do not apply clipping - - // params[3] - activation ID for input (i), forget (f) and output (o) gates - // params[4] - alpha value for gates activation - // params[5] - beta value for gates activation - - // params[6] - activation ID for cell state (c) - // params[7] - alpha value for cell state activation - // params[8] - beta value for cell state activation - - // params[9] - activation ID for output (h) - // params[10] - alpha value for output activation - // params[11] - beta value for output activation - - // INPUTS: - // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr - // Wx - input weights [nIn, 4*nOut] - // Wr - recurrent weights [nOut, 4*nOut] - // b - biases [4*nOut], optional, may be nullptr - // hI - (ht-1) previous (initial) output at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr - // cI - (ct-1) previous (initial) cell state at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr - // Wp - peephole weights [3*nOut], optional, may be nullptr - - // OUTPUTS: - // h - current output, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr - // c - current cell state, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr - - // !!! dimension 4*nOut implies order it, ft, c't, ot - // !!! dimension 3*nOut implies order it, ft, ot - - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut] - //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut] - - // add biases if they are given - if(b != nullptr) - z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut] - - auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) - auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut]) - auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut]) - auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut]) - - // peephole connections for input and forget gates - if(Wp != nullptr) { - zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) - zf += *cI * (*Wp)({nOut, 2*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) - } - - applyActivation(zi, params[3], params[4], params[5], zi); // inplace - applyActivation(zf, params[3], params[4], params[5], zf); // inplace - applyActivation(zg, params[6], params[7], params[8], zg); // inplace - - c->assign(zf * *cI + zi * zg); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut]) - - // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation - if(params[2] != 0) - c->applyScalar(scalar::LstmClip, params[2], *c); - - // peephole connections for output gate - if(Wp != nullptr) - zo += *c * (*Wp)({2*nOut, 3*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) - - applyActivation(zo, params[3], params[4], params[5], zo); - - applyActivation(*c, params[9], params[10], params[11], *h); - *h *= zo; // [bS, nOut] * [bS, nOut](or[nOut]) + // * -> means element-wise multiplication + // × -> means matrix multiplication + + /************************ THIS IS NOT OPTIMAZED CODE + * ***********************************/ + /** the objective is to provide math-readable code **/ + + // equations (no peephole connections) + // it = σ(Wxi × xt + Wri × ht-1 + bi) + // ft = σ(Wxf × xt + Wrf × ht-1 + bf) + // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) + // ct = ft * ct-1 + it * c't + // ot = σ(Wxo × xt + Wro × ht-1 + bo) + // ht = ot * tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi × xt + Wri × ht-1 + Wpi * ct-1 + bi) + // ft = σ(Wxf × xt + Wrf × ht-1 + Wpf * ct-1 + bf) + // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) + // ct = ft * ct-1 + it * c't + // ot = σ(Wxo × xt + Wro × ht-1 + Wpo * ct + bo) + // ht = ot * tanh(ct) + + // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= + // thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, + // 10=softplus + + // params[0] - dataFormat, ignore + // params[1] - directionMode, ignore + // params[2] - cell clipping value, if it = 0 then do not apply clipping + + // params[3] - activation ID for input (i), forget (f) and output (o) gates + // params[4] - alpha value for gates activation + // params[5] - beta value for gates activation + + // params[6] - activation ID for cell state (c) + // params[7] - alpha value for cell state activation + // params[8] - beta value for cell state activation + + // params[9] - activation ID for output (h) + // params[10] - alpha value for output activation + // params[11] - beta value for output activation + + // INPUTS: + // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // hI - (ht-1) previous (initial) output at time t-1, optional may be nullptr, + // [bS, nOut] or [nOut] if seqLen != nullptr cI - (ct-1) previous (initial) + // cell state at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if + // seqLen != nullptr Wp - peephole weights [3*nOut], optional, may be nullptr + + // OUTPUTS: + // h - current output, that is at current time step t, [bS, nOut] or [nOut] if + // seqLen != nullptr c - current cell state, that is at current time step t, + // [bS, nOut] or [nOut] if seqLen != nullptr + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + auto z = mmul(*x, *Wx) + + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, + // 4*nOut] = [bS, 4*nOut] + // or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, + // 4*nOut] = [4*nOut] + + // add biases if they are given + if (b != nullptr) + z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut] + + auto zi = x->rankOf() == 1 + ? z({0, nOut}) + : z({0, 0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) + auto zf = + x->rankOf() == 1 + ? z({nOut, 2 * nOut}) + : z({0, 0, nOut, 2 * nOut}); // forget gate ft, [bS, nOut](or[nOut]) + auto zg = x->rankOf() == 1 + ? z({2 * nOut, 3 * nOut}) + : z({0, 0, 2 * nOut, + 3 * nOut}); // cell gate c't, [bS, nOut](or[nOut]) + auto zo = x->rankOf() == 1 + ? z({3 * nOut, 4 * nOut}) + : z({0, 0, 3 * nOut, + 4 * nOut}); // output gate ot, [bS, nOut](or[nOut]) + + // peephole connections for input and forget gates + if (Wp != nullptr) { + zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * + // [nOut] = [bS, nOut](or[nOut]) + zf += *cI * (*Wp)({nOut, 2 * nOut}); // broadcast: [bS, nOut] + [bS, nOut] + // * [nOut] = [bS, nOut](or[nOut]) + } + + applyActivation(zi, params[3], params[4], params[5], zi); // inplace + applyActivation(zf, params[3], params[4], params[5], zf); // inplace + applyActivation(zg, params[6], params[7], params[8], zg); // inplace + + c->assign(zf * *cI + zi * zg); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, + // nOut] = [bS, nOut](or[nOut]) + + // if clipping value is non-zero then cell state is clipped by this value + // prior to the cell output activation + if (params[2] != 0) c->applyScalar(scalar::LstmClip, params[2], *c); + + // peephole connections for output gate + if (Wp != nullptr) + zo += + *c * (*Wp)({2 * nOut, 3 * nOut}); // broadcast: [bS, nOut] + [bS, nOut] + // * [nOut] = [bS, nOut](or[nOut]) + + applyActivation(zo, params[3], params[4], params[5], zo); + + applyActivation(*c, params[9], params[10], params[11], *h); + *h *= zo; // [bS, nOut] * [bS, nOut](or[nOut]) } - ////////////////////////////////////////////////////////////////////////// // this auxiliary ff should be running before backprop void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const std::vector& params, + const NDArray* b, const NDArray* hI, const NDArray* cI, + const NDArray* Wp, const std::vector& params, NDArray* z, NDArray* a, NDArray* h, NDArray* c) { - - // z - zi, zf, zg, zo - // a - i, f, g, o - - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - z->assign(mmul(*x, *Wx) + mmul(*hI, *Wr)); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut] - //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut] - // add biases if they are given - if(b != nullptr) - *z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut] - - auto zi = x->rankOf() == 1 ? (*z)({0, nOut}) : (*z)({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) - auto zf = x->rankOf() == 1 ? (*z)({nOut, 2*nOut}) : (*z)({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut]) - auto zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut]) - auto zo = x->rankOf() == 1 ? (*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut]) - - auto i = x->rankOf() == 1 ? (*a)({0, nOut}) : (*a)({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) - auto f = x->rankOf() == 1 ? (*a)({nOut, 2*nOut}) : (*a)({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut]) - auto g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut]) - auto o = x->rankOf() == 1 ? (*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut]) - - // peephole connections for input and forget gates - if(Wp != nullptr) { - zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) - zf += *cI * (*Wp)({nOut, 2*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) - } - - applyActivation(zi, params[3], params[4], params[5], i); - applyActivation(zf, params[3], params[4], params[5], f); - applyActivation(zg, params[6], params[7], params[8], g); - - c->assign(f * *cI + i * g); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut]) - - // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation - if(params[2] != 0) - c->applyScalar(scalar::LstmClip, params[2], *c); - - // peephole connections for output gate - if(Wp != nullptr) - zo += *c * (*Wp)({2*nOut, 3*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) - - applyActivation(zo, params[3], params[4], params[5], o); - - applyActivation(*c, params[9], params[10], params[11], *h); - *h *= o; // [bS, nOut] * [bS, nOut](or[nOut]) + // z - zi, zf, zg, zo + // a - i, f, g, o + + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + z->assign(mmul(*x, *Wx) + + mmul(*hI, *Wr)); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * + // [nOut, 4*nOut] = [bS, 4*nOut] + // or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, + // 4*nOut] = [4*nOut] + // add biases if they are given + if (b != nullptr) + *z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut] + + auto zi = x->rankOf() == 1 + ? (*z)({0, nOut}) + : (*z)({0, 0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) + auto zf = x->rankOf() == 1 + ? (*z)({nOut, 2 * nOut}) + : (*z)({0, 0, nOut, + 2 * nOut}); // forget gate ft, [bS, nOut](or[nOut]) + auto zg = x->rankOf() == 1 + ? (*z)({2 * nOut, 3 * nOut}) + : (*z)({0, 0, 2 * nOut, + 3 * nOut}); // cell gate c't, [bS, nOut](or[nOut]) + auto zo = x->rankOf() == 1 + ? (*z)({3 * nOut, 4 * nOut}) + : (*z)({0, 0, 3 * nOut, + 4 * nOut}); // output gate ot, [bS, nOut](or[nOut]) + + auto i = x->rankOf() == 1 + ? (*a)({0, nOut}) + : (*a)({0, 0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) + auto f = x->rankOf() == 1 + ? (*a)({nOut, 2 * nOut}) + : (*a)({0, 0, nOut, + 2 * nOut}); // forget gate ft, [bS, nOut](or[nOut]) + auto g = x->rankOf() == 1 + ? (*a)({2 * nOut, 3 * nOut}) + : (*a)({0, 0, 2 * nOut, + 3 * nOut}); // cell gate c't, [bS, nOut](or[nOut]) + auto o = x->rankOf() == 1 + ? (*a)({3 * nOut, 4 * nOut}) + : (*a)({0, 0, 3 * nOut, + 4 * nOut}); // output gate ot, [bS, nOut](or[nOut]) + + // peephole connections for input and forget gates + if (Wp != nullptr) { + zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * + // [nOut] = [bS, nOut](or[nOut]) + zf += *cI * (*Wp)({nOut, 2 * nOut}); // broadcast: [bS, nOut] + [bS, nOut] + // * [nOut] = [bS, nOut](or[nOut]) + } + + applyActivation(zi, params[3], params[4], params[5], i); + applyActivation(zf, params[3], params[4], params[5], f); + applyActivation(zg, params[6], params[7], params[8], g); + + c->assign(f * *cI + i * g); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, + // nOut] = [bS, nOut](or[nOut]) + + // if clipping value is non-zero then cell state is clipped by this value + // prior to the cell output activation + if (params[2] != 0) c->applyScalar(scalar::LstmClip, params[2], *c); + + // peephole connections for output gate + if (Wp != nullptr) + zo += + *c * (*Wp)({2 * nOut, 3 * nOut}); // broadcast: [bS, nOut] + [bS, nOut] + // * [nOut] = [bS, nOut](or[nOut]) + + applyActivation(zo, params[3], params[4], params[5], o); + + applyActivation(*c, params[9], params[10], params[11], *h); + *h *= o; // [bS, nOut] * [bS, nOut](or[nOut]) } - ////////////////////////////////////////////////////////////////////////// -void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, - const NDArray* z, const NDArray* a, const NDArray* c, const std::vector& params, - NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) { - - /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ - /** the objective is to provide math-readable code **/ - - // equations (no peephole connections) - // zi = x × Wxi + hI × Wri + bi - // zf = x × Wxf + hI × Wrf + bf - // zg = x × Wxg + hI × Wrg + bg - // zo = x × Wxo + hI × Wro + bo - // i = act(zi) - // f = act(zf) - // g = actC(zg) - // o = act(zo) - // c = clip(f * cI + i * g) - // h = o * actH(c) - - // equations (peephole connections are present) - // zi = x × Wxi + hI × Wri + cI * Wpi + bi - // zf = x × Wxf + hI × Wrf + cI * Wpf + bf - // zg = x × Wxg + hI × Wrg + bg - // zo = x × Wxo + hI × Wro + c * Wpo + bo - // i = act(zi) - // f = act(zf) - // g = actC(zg) - // o = act(zo) - // c = clip(f * cI + i * g) - // h = o * actH(c) - - // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus - - // params[0] - dataFormat, ignore - // params[1] - directionMode, ignore - // params[2] - cell clipping value, if it = 0 then do not apply clipping - - // params[3] - activation ID for input (i), forget (f) and output (o) gates - // params[4] - alpha value for gates activation - // params[5] - beta value for gates activation - - // params[6] - activation ID for cell state (c) - // params[7] - alpha value for cell state activation - // params[8] - beta value for cell state activation - - // params[9] - activation ID for output (h) - // params[10] - alpha value for output activation - // params[11] - beta value for output activation - - // INPUTS: - // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr - // Wx - input weights [nIn, 4*nOut] - // Wr - recurrent weights [nOut, 4*nOut] - // b - biases [4*nOut], optional, may be nullptr - // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr - // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr - // Wp - peephole weights [3*nOut], optional, may be nullptr - // dLdh - loss derivative with respect to h at each time step, [bS, nOut] or [nOut] if seqLen != nullptr - // dLdhL - loss derivative with respect to h at last time step, [bS, nOut] or [nOut] if seqLen != nullptr - // dLdcL - loss derivative with respect to c at last time step, [bS, nOut] or [nOut] if seqLen != nullptr - // z - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] - // a - i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] - // c - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut] - - // OUTPUTS: - // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr - // dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut] - // dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut] - // dLdb - loss derivative with respect to b, optional, may be nullptr, [4*nOut] - // dLdhI - loss derivative with respect to hI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr - // dLdcI - loss derivative with respect to cI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr - // dLdWp - loss derivative with respect to Wp, optional, may be nullptr, [3*nOut] - - // !!! dimension 4*nOut implies order i, f, g, o - // !!! dimension 3*nOut implies order i, f, o - - // dhdc = o*tanhDeriv + Wp ? tanh(c)*dodzo*dzodc : 0 [bS, nOut] - // dcdcI = f + Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0 [bS, nOut] - - // dLdhI += dLdh; [bS, nOut] - // dLdcI += dLdhI * dhdc; [bS, nOut] - - // dLdzi = dLdcI*dcdi*didzi; [bS, nOut](or[nOut]) - // dLdzf = dLdcI*dcdf*dfdzf; [bS, nOut](or[nOut]) - // dLdzg = dLdcI*dcdg*dgdzg; [bS, nOut](or[nOut]) - // dLdzo = dLdhI*dhdo*dodzo; [bS, nOut](or[nOut]) - - // dLdx = dLdzi×WxiT + dLdzf×WxfT + dLdzg×WxgT + dLdzo×WxoT, [bS, nIn] - // dLdhI = dLdzi×WriT + dLdzf×WrfT + dLdzg×WrgT + dLdzo×WroT, [bS, nOut] - // dLdcI = dLdcI*dcdcI, [bS, nOut] - - // dLdWxi = xT×dLdzi [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxf = xT×dLdzf [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxg = xT×dLdzg [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxo = xT×dLdzo [nIn, bS] x [bS, nOut] = [nIn, nOut] - - // dLdWri = hIT×dLdzi [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWrf = hIT×dLdzf [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWrg = hIT×dLdzg [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWro = hIT×dLdzo [nOut, bS] x [bS, nOut] = [nOut, nOut] - - // dLdbi = dLdzi.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdbf = dLdzf.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdbg = dLdzg.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdbo = dLdzo.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - - // dLdWpi = (dLdzi*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdWpf = (dLdzf*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - // dLdWpo = (dLdzo*c) .reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] - - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - const Nd4jLong nIn = x->sizeAt(-1); - - NDArray zi = x->rankOf() == 1 ? (*z)({0, nOut}) : (*z)({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) - NDArray zf = x->rankOf() == 1 ? (*z)({nOut, 2*nOut}) : (*z)({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut]) - NDArray zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) - NDArray zo = x->rankOf() == 1 ? (*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut]) - - NDArray i = x->rankOf() == 1 ? (*a)({0, nOut}) : (*a)({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) - NDArray f = x->rankOf() == 1 ? (*a)({nOut, 2*nOut}) : (*a)({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut]) - NDArray g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) - NDArray o = x->rankOf() == 1 ? (*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut]) - - NDArray dLdz = z->ulike(); // [bS, 4*nOut](or[4*nOut]) - NDArray dLdzi = x->rankOf() == 1 ? dLdz({0, nOut}) : dLdz({0,0, 0, nOut}); - NDArray dLdzf = x->rankOf() == 1 ? dLdz({nOut, 2*nOut}) : dLdz({0,0, nOut, 2*nOut}); - NDArray dLdzg = x->rankOf() == 1 ? dLdz({2*nOut, 3*nOut}) : dLdz({0,0, 2*nOut, 3*nOut}); - NDArray dLdzo = x->rankOf() == 1 ? dLdz({3*nOut, 4*nOut}) : dLdz({0,0, 3*nOut, 4*nOut}); - - // dcdzi = dcdi*didzi, [bS, nOut](or[nOut]) - activationDeriv(zi, params[3], params[4], params[5], dLdzi); // didzi, inplace - dLdzi *= g; // dcdi = g*clipDeriv - - // dcdzf = dcdf*dfdzf, [bS, nOut](or[nOut]) - activationDeriv(zf, params[3], params[4], params[5], dLdzf); // dfdzf, inplace - dLdzf *= *cI; // dcdf = cI*clipDeriv - - // dcdzg = dcde*dedzg, [bS, nOut](or[nOut]) - activationDeriv(zg, params[6], params[7], params[8], dLdzg); // dgdzg, inplace - dLdzg *= i; // dcdf = i*clipDeriv - - // dhdzo = dhdo*dodzo = actH(c)*dodzo, [bS, nOut](or[nOut]) - activationDeriv(zo, params[3], params[4], params[5], dLdzo); - NDArray temp = dLdzo.ulike(); - applyActivation(*c, params[9], params[10], params[11], temp); // actH(c), inplace - dLdzo *= temp; - - // dcdcI - NDArray dcdcI = f.dup(); // dcdcI = f*clipDeriv [bS, nOut](or[nOut]) - - // take into account possible deposit from clipping derivative - clipDeriv(params[2], *c, dLdzi, dLdzf, dLdzg, dcdcI); - - // dhdc - NDArray dhdc = c->ulike(); - activationDeriv(*c, params[9], params[10], params[11], dhdc); // [bS, nOut] - dhdc *= o; - - if(Wp) { - dhdc += dLdzo*(*Wp)({2*nOut, 3*nOut}); - dcdcI += dLdzi*(*Wp)({0, nOut}) + dLdzf*(*Wp)({nOut, 2*nOut}); // broadcast [bS, nOut] * nOut + ... - } - - if(dLdh) - *dLdhI += *dLdh; - if(dLdhL) - *dLdhI += *dLdhL; - if(dLdcL) - *dLdcI += *dLdcL; - - *dLdcI += *dLdhI * dhdc; - - dLdzi *= *dLdcI; // [bS, nOut](or[nOut]) - dLdzf *= *dLdcI; // [bS, nOut](or[nOut]) - dLdzg *= *dLdcI; // [bS, nOut](or[nOut]) - dLdzo *= *dLdhI; // [bS, nOut](or[nOut]) - - // dLdx - NDArray WxT = Wx->transpose(); - MmulHelper::mmul(&dLdz, &WxT, dLdx); // [bS, 4*nOut] x [4*nOut, nIn] (or [4*nOut] x [4*nOut, nIn]) = [bS, nIn] ( or[nIn] ) - - // dLdhI - NDArray WrT = Wr->transpose(); - MmulHelper::mmul(&dLdz, &WrT, dLdhI); // [bS, 4*nOut] x [4*nOut, nOut] (or [4*nOut] x [4*nOut, nOut]) = [bS, nOut] ( or[nOut] ) - - // dLdcI - dLdcI->assign(*dLdcI*dcdcI); // [bS, nOut](or[nOut]) - - if(x->rankOf() == 1) { - - NDArray xT = x->reshape(x->ordering(),{nIn, 1}); // [nIn] -> [nIn, 1] - NDArray hIT = hI->reshape(hI->ordering(),{nOut, 1}); // [nOut] -> [nOut, 1] - NDArray dLdzR = dLdz.reshape(dLdz.ordering(), {1, 4*nOut}); // [nOut] -> [1, 4*nOut] - - // dLdWx - *dLdWx += mmul(xT, dLdzR); // [nIn, 1] x [1, 4*nOut] = [nIn, 4*nOut] - - // dLdWr - *dLdWr += mmul(hIT, dLdzR); // [nOut, 1] x [1, 4*nOut] = [nOut, 4*nOut] - } - else { - - // dLdWx - *dLdWx += mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 4*nOut] = [nIn, 4*nOut] - - // dLdWr - *dLdWr += mmul(hI->transpose(), dLdz); // [nOut, bS] x [bS, 4*nOut] = [nOut, 4*nOut] - } - - // dLdb - if(b && x->rankOf() == 1) - *dLdb += dLdz; // [4*nOut] - else if(b) - *dLdb += dLdz.reduceAlongDimension(reduce::Sum, {0}); // [bS, 4*nOut] -> reduce -> [4*nOut]; - - // dLdWp - if(Wp && x->rankOf() == 1) { - (*dLdWp)({ 0,nOut}) += std::move(dLdzi)*(*cI); // [nOut] - (*dLdWp)({ nOut,2*nOut}) += std::move(dLdzf)*(*cI); // [nOut] - (*dLdWp)({2*nOut,3*nOut}) += std::move(dLdzo)*(*c); // [nOut] - } - else if(Wp) { - NDArray temp(Wp->ordering(), {nOut}, Wp->dataType(), Wp->getContext()); - (std::move(dLdzi)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut] - (*dLdWp)({0,nOut}) += temp; - (std::move(dLdzf)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut] - (*dLdWp)({nOut,2*nOut}) += temp; - (std::move(dLdzo)*(*c)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut] - (*dLdWp)({2*nOut,3*nOut}) += temp; - } +void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, + const NDArray* b, const NDArray* hI, const NDArray* cI, + const NDArray* Wp, const NDArray* dLdh, + const NDArray* dLdhL, const NDArray* dLdcL, + const NDArray* z, const NDArray* a, const NDArray* c, + const std::vector& params, NDArray* dLdx, + NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, + NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) { + /************************ THIS IS NOT OPTIMAZED CODE + * ***********************************/ + /** the objective is to provide math-readable code **/ + + // equations (no peephole connections) + // zi = x × Wxi + hI × Wri + bi + // zf = x × Wxf + hI × Wrf + bf + // zg = x × Wxg + hI × Wrg + bg + // zo = x × Wxo + hI × Wro + bo + // i = act(zi) + // f = act(zf) + // g = actC(zg) + // o = act(zo) + // c = clip(f * cI + i * g) + // h = o * actH(c) + + // equations (peephole connections are present) + // zi = x × Wxi + hI × Wri + cI * Wpi + bi + // zf = x × Wxf + hI × Wrf + cI * Wpf + bf + // zg = x × Wxg + hI × Wrg + bg + // zo = x × Wxo + hI × Wro + c * Wpo + bo + // i = act(zi) + // f = act(zf) + // g = actC(zg) + // o = act(zo) + // c = clip(f * cI + i * g) + // h = o * actH(c) + + // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= + // thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, + // 10=softplus + + // params[0] - dataFormat, ignore + // params[1] - directionMode, ignore + // params[2] - cell clipping value, if it = 0 then do not apply clipping + + // params[3] - activation ID for input (i), forget (f) and output (o) gates + // params[4] - alpha value for gates activation + // params[5] - beta value for gates activation + + // params[6] - activation ID for cell state (c) + // params[7] - alpha value for cell state activation + // params[8] - beta value for cell state activation + + // params[9] - activation ID for output (h) + // params[10] - alpha value for output activation + // params[11] - beta value for output activation + + // INPUTS: + // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] + // if seqLen != nullptr cI - (ct-1) previous (initial) cell state at time + // t-1, [bS, nOut] or [nOut] if seqLen != nullptr Wp - peephole weights + // [3*nOut], optional, may be nullptr dLdh - loss derivative with respect to + // h at each time step, [bS, nOut] or [nOut] if seqLen != nullptr dLdhL - loss + // derivative with respect to h at last time step, [bS, nOut] or [nOut] if + // seqLen != nullptr dLdcL - loss derivative with respect to c at last time + // step, [bS, nOut] or [nOut] if seqLen != nullptr z - zi,zf,zg,zo taken + // from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] a - + // i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, + // 4*nOut] c - taken from ff outputs to reduce amount of calculations in + // bp, [bS, nOut] + + // OUTPUTS: + // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != + // nullptr dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut] dLdWr - + // loss derivative with respect to Wr, [nOut, 4*nOut] dLdb - loss derivative + // with respect to b, optional, may be nullptr, [4*nOut] dLdhI - loss + // derivative with respect to hI, optional may be nullptr, [bS, nOut] or + // [nOut] if seqLen != nullptr dLdcI - loss derivative with respect to cI, + // optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr dLdWp - + // loss derivative with respect to Wp, optional, may be nullptr, [3*nOut] + + // !!! dimension 4*nOut implies order i, f, g, o + // !!! dimension 3*nOut implies order i, f, o + + // dhdc = o*tanhDeriv + Wp ? tanh(c)*dodzo*dzodc : 0 [bS, nOut] + // dcdcI = f + Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0 [bS, nOut] + + // dLdhI += dLdh; [bS, nOut] + // dLdcI += dLdhI * dhdc; [bS, nOut] + + // dLdzi = dLdcI*dcdi*didzi; [bS, nOut](or[nOut]) + // dLdzf = dLdcI*dcdf*dfdzf; [bS, nOut](or[nOut]) + // dLdzg = dLdcI*dcdg*dgdzg; [bS, nOut](or[nOut]) + // dLdzo = dLdhI*dhdo*dodzo; [bS, nOut](or[nOut]) + + // dLdx = dLdzi×WxiT + dLdzf×WxfT + dLdzg×WxgT + dLdzo×WxoT, [bS, nIn] + // dLdhI = dLdzi×WriT + dLdzf×WrfT + dLdzg×WrgT + dLdzo×WroT, [bS, nOut] + // dLdcI = dLdcI*dcdcI, [bS, nOut] + + // dLdWxi = xT×dLdzi [nIn, bS] x [bS, nOut] = [nIn, + // nOut] dLdWxf = xT×dLdzf [nIn, bS] x [bS, nOut] = + // [nIn, nOut] dLdWxg = xT×dLdzg [nIn, bS] x [bS, + // nOut] = [nIn, nOut] dLdWxo = xT×dLdzo [nIn, bS] + // x [bS, nOut] = [nIn, nOut] + + // dLdWri = hIT×dLdzi [nOut, bS] x [bS, nOut] = + // [nOut, nOut] dLdWrf = hIT×dLdzf [nOut, bS] x [bS, + // nOut] = [nOut, nOut] dLdWrg = hIT×dLdzg [nOut, + // bS] x [bS, nOut] = [nOut, nOut] dLdWro = hIT×dLdzo [nOut, bS] x [bS, nOut] + // = [nOut, nOut] + + // dLdbi = dLdzi.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbf = dLdzf.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbg = dLdzg.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbo = dLdzo.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + + // dLdWpi = (dLdzi*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> + // [nOut] dLdWpf = (dLdzf*cI).reduce_sum_along_0_axis [bS, nOut] -> + // reduce -> [nOut] dLdWpo = (dLdzo*c) .reduce_sum_along_0_axis [bS, + // nOut] -> reduce -> [nOut] + + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + const Nd4jLong nIn = x->sizeAt(-1); + + NDArray zi = + x->rankOf() == 1 + ? (*z)({0, nOut}) + : (*z)({0, 0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) + NDArray zf = x->rankOf() == 1 + ? (*z)({nOut, 2 * nOut}) + : (*z)({0, 0, nOut, + 2 * nOut}); // forget gate f, [bS, nOut](or[nOut]) + NDArray zg = x->rankOf() == 1 + ? (*z)({2 * nOut, 3 * nOut}) + : (*z)({0, 0, 2 * nOut, + 3 * nOut}); // cell gate g, [bS, nOut](or[nOut]) + NDArray zo = x->rankOf() == 1 + ? (*z)({3 * nOut, 4 * nOut}) + : (*z)({0, 0, 3 * nOut, + 4 * nOut}); // output gate o, [bS, nOut](or[nOut]) + + NDArray i = + x->rankOf() == 1 + ? (*a)({0, nOut}) + : (*a)({0, 0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) + NDArray f = x->rankOf() == 1 + ? (*a)({nOut, 2 * nOut}) + : (*a)({0, 0, nOut, + 2 * nOut}); // forget gate f, [bS, nOut](or[nOut]) + NDArray g = x->rankOf() == 1 + ? (*a)({2 * nOut, 3 * nOut}) + : (*a)({0, 0, 2 * nOut, + 3 * nOut}); // cell gate g, [bS, nOut](or[nOut]) + NDArray o = x->rankOf() == 1 + ? (*a)({3 * nOut, 4 * nOut}) + : (*a)({0, 0, 3 * nOut, + 4 * nOut}); // output gate o, [bS, nOut](or[nOut]) + + NDArray dLdz = z->ulike(); // [bS, 4*nOut](or[4*nOut]) + NDArray dLdzi = x->rankOf() == 1 ? dLdz({0, nOut}) : dLdz({0, 0, 0, nOut}); + NDArray dLdzf = + x->rankOf() == 1 ? dLdz({nOut, 2 * nOut}) : dLdz({0, 0, nOut, 2 * nOut}); + NDArray dLdzg = x->rankOf() == 1 ? dLdz({2 * nOut, 3 * nOut}) + : dLdz({0, 0, 2 * nOut, 3 * nOut}); + NDArray dLdzo = x->rankOf() == 1 ? dLdz({3 * nOut, 4 * nOut}) + : dLdz({0, 0, 3 * nOut, 4 * nOut}); + + // dcdzi = dcdi*didzi, [bS, nOut](or[nOut]) + activationDeriv(zi, params[3], params[4], params[5], + dLdzi); // didzi, inplace + dLdzi *= g; // dcdi = g*clipDeriv + + // dcdzf = dcdf*dfdzf, [bS, nOut](or[nOut]) + activationDeriv(zf, params[3], params[4], params[5], + dLdzf); // dfdzf, inplace + dLdzf *= *cI; // dcdf = cI*clipDeriv + + // dcdzg = dcde*dedzg, [bS, nOut](or[nOut]) + activationDeriv(zg, params[6], params[7], params[8], + dLdzg); // dgdzg, inplace + dLdzg *= i; // dcdf = i*clipDeriv + + // dhdzo = dhdo*dodzo = actH(c)*dodzo, [bS, nOut](or[nOut]) + activationDeriv(zo, params[3], params[4], params[5], dLdzo); + NDArray temp = dLdzo.ulike(); + applyActivation(*c, params[9], params[10], params[11], + temp); // actH(c), inplace + dLdzo *= temp; + + // dcdcI + NDArray dcdcI = f.dup(); // dcdcI = f*clipDeriv [bS, nOut](or[nOut]) + + // take into account possible deposit from clipping derivative + clipDeriv(params[2], *c, dLdzi, dLdzf, dLdzg, dcdcI); + + // dhdc + NDArray dhdc = c->ulike(); + activationDeriv(*c, params[9], params[10], params[11], dhdc); // [bS, nOut] + dhdc *= o; + + if (Wp) { + dhdc += dLdzo * (*Wp)({2 * nOut, 3 * nOut}); + dcdcI += + dLdzi * (*Wp)({0, nOut}) + + dLdzf * (*Wp)({nOut, 2 * nOut}); // broadcast [bS, nOut] * nOut + ... + } + + if (dLdh) *dLdhI += *dLdh; + if (dLdhL) *dLdhI += *dLdhL; + if (dLdcL) *dLdcI += *dLdcL; + + *dLdcI += *dLdhI * dhdc; + + dLdzi *= *dLdcI; // [bS, nOut](or[nOut]) + dLdzf *= *dLdcI; // [bS, nOut](or[nOut]) + dLdzg *= *dLdcI; // [bS, nOut](or[nOut]) + dLdzo *= *dLdhI; // [bS, nOut](or[nOut]) + + // dLdx + NDArray WxT = Wx->transpose(); + MmulHelper::mmul(&dLdz, &WxT, + dLdx); // [bS, 4*nOut] x [4*nOut, nIn] (or [4*nOut] x + // [4*nOut, nIn]) = [bS, nIn] ( or[nIn] ) + + // dLdhI + NDArray WrT = Wr->transpose(); + MmulHelper::mmul(&dLdz, &WrT, + dLdhI); // [bS, 4*nOut] x [4*nOut, nOut] (or [4*nOut] x + // [4*nOut, nOut]) = [bS, nOut] ( or[nOut] ) + + // dLdcI + dLdcI->assign(*dLdcI * dcdcI); // [bS, nOut](or[nOut]) + + if (x->rankOf() == 1) { + NDArray xT = x->reshape(x->ordering(), {nIn, 1}); // [nIn] -> [nIn, 1] + NDArray hIT = + hI->reshape(hI->ordering(), {nOut, 1}); // [nOut] -> [nOut, 1] + NDArray dLdzR = + dLdz.reshape(dLdz.ordering(), {1, 4 * nOut}); // [nOut] -> [1, 4*nOut] + + // dLdWx + *dLdWx += mmul(xT, dLdzR); // [nIn, 1] x [1, 4*nOut] = [nIn, 4*nOut] + + // dLdWr + *dLdWr += mmul(hIT, dLdzR); // [nOut, 1] x [1, 4*nOut] = [nOut, 4*nOut] + } else { + // dLdWx + *dLdWx += + mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 4*nOut] = [nIn, 4*nOut] + + // dLdWr + *dLdWr += mmul(hI->transpose(), + dLdz); // [nOut, bS] x [bS, 4*nOut] = [nOut, 4*nOut] + } + + // dLdb + if (b && x->rankOf() == 1) + *dLdb += dLdz; // [4*nOut] + else if (b) + *dLdb += dLdz.reduceAlongDimension( + reduce::Sum, {0}); // [bS, 4*nOut] -> reduce -> [4*nOut]; + + // dLdWp + if (Wp && x->rankOf() == 1) { + (*dLdWp)({0, nOut}) += std::move(dLdzi) * (*cI); // [nOut] + (*dLdWp)({nOut, 2 * nOut}) += std::move(dLdzf) * (*cI); // [nOut] + (*dLdWp)({2 * nOut, 3 * nOut}) += std::move(dLdzo) * (*c); // [nOut] + } else if (Wp) { + NDArray temp(Wp->ordering(), {nOut}, Wp->dataType(), Wp->getContext()); + (std::move(dLdzi) * (*cI)) + .reduceAlongDimension(reduce::Sum, temp, + {0}); // [bS, nOut] -> reduce -> [nOut] + (*dLdWp)({0, nOut}) += temp; + (std::move(dLdzf) * (*cI)) + .reduceAlongDimension(reduce::Sum, temp, + {0}); // [bS, nOut] -> reduce -> [nOut] + (*dLdWp)({nOut, 2 * nOut}) += temp; + (std::move(dLdzo) * (*c)) + .reduceAlongDimension(reduce::Sum, temp, + {0}); // [bS, nOut] -> reduce -> [nOut] + (*dLdWp)({2 * nOut, 3 * nOut}) += temp; + } } ////////////////////////////////////////////////////////////////////////// void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const std::vector& params, - const bool forward, + const NDArray* b, const NDArray* seqLen, + const NDArray* hI, const NDArray* cI, const NDArray* Wp, + const std::vector& params, const bool forward, NDArray* h, NDArray* hL, NDArray* cL) { - - // INPUTS: - // x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL], - // Wx - input weights [nIn, 4*nOut] - // Wr - recurrent weights [nOut, 4*nOut] - // b - biases [4*nOut], optional, may be nullptr - // seqLen - [bS], optional, may be nullptr - // hI - initial output [bS, nOut], optional, may be nullptr - // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr - // Wp - peephole weights [3*nOut], optional, may be nullptr - - // OUTPUTS: - // h - output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr - // hL - output at last step [bS, nOut], optional, may be nullptr - // cL - cell state at last step [bS, nOut], optional, may be nullptr - - // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; - // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL] - - const int dataFormat = params[0]; - const int directionMode = params[1]; - - const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); - const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - const std::vector shapeOut = {bS, nOut}; - - const auto type = h ? h->dataType() : (hL ? hL->dataType() : cL->dataType()); - - auto h0 = const_cast(hI); - if(!hI) { - h0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); - h0->nullify(); - } - - auto c0 = const_cast(cI); - if(!cI) { - c0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); - c0->nullify(); - } - - auto ct = cL; - if(!cL) - ct = new NDArray(x->ordering(), shapeOut, type, x->getContext()); - - auto ht = hL; - if(!h && !hL) - ht = new NDArray(x->ordering(), shapeOut, type, x->getContext()); - - // create sets of required (depends on seqLen presence) sub-arrays - std::vector dims; - ResultSet *xSet(nullptr), *hSet(nullptr), *h0Set(nullptr), *c0Set(nullptr), *htSet(nullptr), *ctSet(nullptr); - - if(!seqLen) { - - dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on bS and nIn/nOut axes - - xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] - if(h) - hSet = new ResultSet(h->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nOut] - } - else { - - dims = dataFormat == 2 ? std::vector({1}) : std::vector({2}); // points on nIn/nOut axis - - xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] - h0Set = new ResultSet(h0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - c0Set = new ResultSet(c0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - ctSet = new ResultSet(ct->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - if(h) - hSet = new ResultSet(h->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut] - if(ht) - htSet = new ResultSet(ht->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - } - - // loops - if(forward) { - - if(!seqLen) { - - if(!h) { // seqLen and h are absent - - lstmLayerCell(&xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step - for (Nd4jLong t = 1; t < sL; ++t) - lstmLayerCell(&xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps - } - else { // seqLen is absent and h is present - - lstmLayerCell(&xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, &hSet->at(0), ct); // first time step - for (Nd4jLong t = 1; t < sL; ++t) - lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t - 1), ct, Wp, params, &hSet->at(t), ct); // rest time steps - - if(hL) - hL->assign(hSet->at(sL - 1)); // assign last output to hL if it is not nullptr - } + // INPUTS: + // x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL], + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // seqLen - [bS], optional, may be nullptr + // hI - initial output [bS, nOut], optional, may be nullptr + // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr + // Wp - peephole weights [3*nOut], optional, may be nullptr + + // OUTPUTS: + // h - output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may + // be nullptr hL - output at last step [bS, nOut], optional, may be nullptr cL + // - cell state at last step [bS, nOut], optional, may be nullptr + + // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, + // gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL] + + const int dataFormat = params[0]; + const int directionMode = params[1]; + + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const Nd4jLong bS = + dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + const std::vector shapeOut = {bS, nOut}; + + const auto type = h ? h->dataType() : (hL ? hL->dataType() : cL->dataType()); + + auto h0 = const_cast(hI); + if (!hI) { + h0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); + h0->nullify(); + } + + auto c0 = const_cast(cI); + if (!cI) { + c0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); + c0->nullify(); + } + + auto ct = cL; + if (!cL) ct = new NDArray(x->ordering(), shapeOut, type, x->getContext()); + + auto ht = hL; + if (!h && !hL) + ht = new NDArray(x->ordering(), shapeOut, type, x->getContext()); + + // create sets of required (depends on seqLen presence) sub-arrays + std::vector dims; + ResultSet *xSet(nullptr), *hSet(nullptr), *h0Set(nullptr), *c0Set(nullptr), + *htSet(nullptr), *ctSet(nullptr); + + if (!seqLen) { + dims = ShapeUtils::evalDimsToExclude( + x->rankOf(), + {dataFormat < 3 ? dataFormat : 0}); // points on bS and nIn/nOut axes + + xSet = new ResultSet( + x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] + if (h) + hSet = new ResultSet(h->allTensorsAlongDimension( + dims)); // sub-arrays with shape [bS, nOut] + } else { + dims = dataFormat == 2 ? std::vector({1}) + : std::vector({2}); // points on nIn/nOut axis + + xSet = new ResultSet( + x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] + h0Set = new ResultSet( + h0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + c0Set = new ResultSet( + c0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + ctSet = new ResultSet( + ct->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + if (h) + hSet = new ResultSet( + h->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut] + if (ht) + htSet = new ResultSet( + ht->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + } + + // loops + if (forward) { + if (!seqLen) { + if (!h) { // seqLen and h are absent + + lstmLayerCell(&xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, ht, + ct); // first time step + for (Nd4jLong t = 1; t < sL; ++t) + lstmLayerCell(&xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, + ct); // rest time steps + } else { // seqLen is absent and h is present + + lstmLayerCell(&xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, &hSet->at(0), + ct); // first time step + for (Nd4jLong t = 1; t < sL; ++t) + lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t - 1), ct, Wp, + params, &hSet->at(t), ct); // rest time steps + + if (hL) + hL->assign(hSet->at( + sL - 1)); // assign last output to hL if it is not nullptr + } + } else { + if (!h) { // seqLen is present and h is absent + + for (Nd4jLong e = 0; e < bS; ++e) { + const int limit = seqLen->e(e); + + if (limit == 0) { + if (cL) ctSet->at(e).nullify(); + if (hL) htSet->at(e).nullify(); + continue; + } + + auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e); + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), + Wp, params, &htSet->at(e), + &ctSet->at(e)); // first time step + + for (int t = 1; t < limit; ++t) { + ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &htSet->at(e), + &ctSet->at(e), Wp, params, &htSet->at(e), + &ctSet->at(e)); // rest time steps + } } - else { - - if(!h) { // seqLen is present and h is absent - - for (Nd4jLong e = 0; e < bS; ++e) { - - const int limit = seqLen->e(e); - - if(limit == 0) { - if(cL) - ctSet->at(e).nullify(); - if(hL) - htSet->at(e).nullify(); - continue; - } - - auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e); - lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), Wp, params, &htSet->at(e), &ctSet->at(e)); // first time step + } else { // seqLen and h are present - for (int t = 1; t < limit; ++t) { - ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &htSet->at(e), &ctSet->at(e), Wp, params, &htSet->at(e), &ctSet->at(e)); // rest time steps - } - } - } - else { // seqLen and h are present + for (Nd4jLong e = 0; e < bS; ++e) { + int limit = seqLen->e(e); - for (Nd4jLong e = 0; e < bS; ++e) { + if (limit == 0) { + tensorAlongTimeBatchDims(*h, dataFormat, 0, 0, e, e + 1) + .nullify(); // nullify for given e and whole time range - int limit = seqLen->e(e); + if (cL) ctSet->at(e).nullify(); + if (hL) htSet->at(e).nullify(); - if(limit == 0) { + continue; + } - tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range + auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e); + lstmLayerCell(&xSet->at(indPrev), Wx, Wr, b, &h0Set->at(e), + &c0Set->at(e), Wp, params, &hSet->at(indPrev), + &ctSet->at(e)); // first time step - if(cL) - ctSet->at(e).nullify(); - if(hL) - htSet->at(e).nullify(); + for (int t = 1; t < limit; ++t) { + auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(&xSet->at(indCurr), Wx, Wr, b, &hSet->at(indPrev), + &ctSet->at(e), Wp, params, &hSet->at(indCurr), + &ctSet->at(e)); // rest time steps + indPrev = indCurr; + } - continue; - } + if (hL) + htSet->at(e).assign(hSet->at( + indPrev)); // assign last output to hL if hL is not nullptr - auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e); - lstmLayerCell(&xSet->at(indPrev), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), Wp, params, &hSet->at(indPrev), &ctSet->at(e)); // first time step - - for (int t = 1; t < limit; ++t) { - auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(&xSet->at(indCurr), Wx, Wr, b, &hSet->at(indPrev), &ctSet->at(e), Wp, params, &hSet->at(indCurr), &ctSet->at(e)); // rest time steps - indPrev = indCurr; - } - - if(hL) - htSet->at(e).assign(hSet->at(indPrev)); // assign last output to hL if hL is not nullptr - - if(limit != sL) - tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) - } - } + if (limit != sL) + tensorAlongTimeBatchDims(*h, dataFormat, limit, sL, e, e + 1) + .nullify(); // nullify for given e and time range [limit, sL) } + } } - else { // backward - - if(!seqLen) { - - if(!h) { // seqLen and h are absent - - lstmLayerCell(&xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step - for (Nd4jLong t = sL - 2; t >= 0; --t) - lstmLayerCell(&xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps - } - else { // seqLen is absent and h is present - - lstmLayerCell(&xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, &hSet->at(sL - 1), ct); // first time step - for (Nd4jLong t = sL - 2; t >= 0; --t) - lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t + 1), ct, Wp, params, &hSet->at(t), ct); // rest time steps - - if(hL) - hL->assign(hSet->at(0)); // assign last output to hL if it is not nullptr - } + } else { // backward + + if (!seqLen) { + if (!h) { // seqLen and h are absent + + lstmLayerCell(&xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, ht, + ct); // first time step + for (Nd4jLong t = sL - 2; t >= 0; --t) + lstmLayerCell(&xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, + ct); // rest time steps + } else { // seqLen is absent and h is present + + lstmLayerCell(&xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, + &hSet->at(sL - 1), ct); // first time step + for (Nd4jLong t = sL - 2; t >= 0; --t) + lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t + 1), ct, Wp, + params, &hSet->at(t), ct); // rest time steps + + if (hL) + hL->assign( + hSet->at(0)); // assign last output to hL if it is not nullptr + } + } else if (directionMode == 1) { // only backward, no bidirectional mode + + if (!h) { // h is absent and seqLen is present + + for (Nd4jLong e = 0; e < bS; ++e) { + const int limit = seqLen->e(e); + + if (limit == 0) { + if (cL) ctSet->at(e).nullify(); + if (hL) htSet->at(e).nullify(); + continue; + } + + auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e); + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), + Wp, params, &htSet->at(e), + &ctSet->at(e)); // first time step + + for (Nd4jLong t = sL - 2; t >= sL - limit; --t) { + ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &htSet->at(e), + &ctSet->at(e), Wp, params, &htSet->at(e), + &ctSet->at(e)); // rest time steps + } } - else if(directionMode == 1) { // only backward, no bidirectional mode - - if(!h) { // h is absent and seqLen is present - - for (Nd4jLong e = 0; e < bS; ++e) { + } else { // seqLen and h are present - const int limit = seqLen->e(e); + for (Nd4jLong e = 0; e < bS; ++e) { + int limit = seqLen->e(e); - if(limit == 0) { - if(cL) - ctSet->at(e).nullify(); - if(hL) - htSet->at(e).nullify(); - continue; - } + if (limit == 0) { + tensorAlongTimeBatchDims(*h, dataFormat, 0, 0, e, e + 1) + .nullify(); // nullify for given e and whole time range - auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e); - lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), Wp, params, &htSet->at(e), &ctSet->at(e)); // first time step + if (cL) ctSet->at(e).nullify(); + if (hL) htSet->at(e).nullify(); - for (Nd4jLong t = sL - 2; t >= sL - limit; --t) { - ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &htSet->at(e), &ctSet->at(e), Wp, params, &htSet->at(e), &ctSet->at(e)); // rest time steps - } - } - } - else { // seqLen and h are present + continue; + } - for (Nd4jLong e = 0; e < bS; ++e) { + auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e); + lstmLayerCell(&xSet->at(indPrev), Wx, Wr, b, &h0Set->at(e), + &c0Set->at(e), Wp, params, &hSet->at(indPrev), + &ctSet->at(e)); // first time step - int limit = seqLen->e(e); + for (Nd4jLong t = sL - 2; t >= sL - limit; --t) { + auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(&xSet->at(indCurr), Wx, Wr, b, &hSet->at(indPrev), + &ctSet->at(e), Wp, params, &hSet->at(indCurr), + &ctSet->at(e)); // rest time steps + indPrev = indCurr; + } - if(limit == 0) { + if (hL) + htSet->at(e).assign(hSet->at( + indPrev)); // assign last output to hL if it is not nullptr - tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range - - if(cL) - ctSet->at(e).nullify(); - if(hL) - htSet->at(e).nullify(); - - continue; - } - - auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e); - lstmLayerCell(&xSet->at(indPrev), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), Wp, params, &hSet->at(indPrev), &ctSet->at(e)); // first time step - - for (Nd4jLong t = sL - 2; t >= sL - limit; --t) { - auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(&xSet->at(indCurr), Wx, Wr, b, &hSet->at(indPrev), &ctSet->at(e), Wp, params, &hSet->at(indCurr), &ctSet->at(e)); // rest time steps - indPrev = indCurr; - } - - if(hL) - htSet->at(e).assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr - - if(limit != sL) - tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL) - } - } + if (limit != sL) + tensorAlongTimeBatchDims(*h, dataFormat, 0, sL - limit, e, e + 1) + .nullify(); // nullify for given e and time range [limit, sL) } - else { // backward in bidirectional mode - - if(!h) { // h is absent and seqLen is present - - for (Nd4jLong e = 0; e < bS; ++e) { - - const int limit = seqLen->e(e); - - if(limit == 0) { - if(cL) - ctSet->at(e).nullify(); - if(hL) - htSet->at(e).nullify(); - continue; - } - - auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e); - lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), Wp, params, &htSet->at(e), &ctSet->at(e)); // first time step - - for (int t = limit - 2; t >= 0; --t) { - ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &htSet->at(e), &ctSet->at(e), Wp, params, &htSet->at(e), &ctSet->at(e)); // rest time steps - } - } - } - else { // seqLen and h are present - - for (Nd4jLong e = 0; e < bS; ++e) { - - int limit = seqLen->e(e); - - if(limit == 0) { - - tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range - - if(cL) - ctSet->at(e).nullify(); - if(hL) - htSet->at(e).nullify(); - - continue; - } - - auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e); - lstmLayerCell(&xSet->at(indPrev), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), Wp, params, &hSet->at(indPrev), &ctSet->at(e)); // first time step - - for (int t = limit - 2; t >= 0; --t) { - auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - lstmLayerCell(&xSet->at(indCurr), Wx, Wr, b, &hSet->at(indPrev), &ctSet->at(e), Wp, params, &hSet->at(indCurr), &ctSet->at(e)); // rest time steps - indPrev = indCurr; - } - - if(hL) - htSet->at(e).assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr - - if(limit != sL) - tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) - } - } + } + } else { // backward in bidirectional mode + + if (!h) { // h is absent and seqLen is present + + for (Nd4jLong e = 0; e < bS; ++e) { + const int limit = seqLen->e(e); + + if (limit == 0) { + if (cL) ctSet->at(e).nullify(); + if (hL) htSet->at(e).nullify(); + continue; + } + + auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e); + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &h0Set->at(e), &c0Set->at(e), + Wp, params, &htSet->at(e), + &ctSet->at(e)); // first time step + + for (int t = limit - 2; t >= 0; --t) { + ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(&xSet->at(ind), Wx, Wr, b, &htSet->at(e), + &ctSet->at(e), Wp, params, &htSet->at(e), + &ctSet->at(e)); // rest time steps + } + } + } else { // seqLen and h are present + + for (Nd4jLong e = 0; e < bS; ++e) { + int limit = seqLen->e(e); + + if (limit == 0) { + tensorAlongTimeBatchDims(*h, dataFormat, 0, 0, e, e + 1) + .nullify(); // nullify for given e and whole time range + + if (cL) ctSet->at(e).nullify(); + if (hL) htSet->at(e).nullify(); + + continue; + } + + auto indPrev = + getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e); + lstmLayerCell(&xSet->at(indPrev), Wx, Wr, b, &h0Set->at(e), + &c0Set->at(e), Wp, params, &hSet->at(indPrev), + &ctSet->at(e)); // first time step + + for (int t = limit - 2; t >= 0; --t) { + auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + lstmLayerCell(&xSet->at(indCurr), Wx, Wr, b, &hSet->at(indPrev), + &ctSet->at(e), Wp, params, &hSet->at(indCurr), + &ctSet->at(e)); // rest time steps + indPrev = indCurr; + } + + if (hL) + htSet->at(e).assign(hSet->at( + indPrev)); // assign last output to hL if it is not nullptr + + if (limit != sL) + tensorAlongTimeBatchDims(*h, dataFormat, limit, sL, e, e + 1) + .nullify(); // nullify for given e and time range [limit, sL) } + } } - - delete xSet; - delete hSet; - delete h0Set; - delete c0Set; - delete htSet; - delete ctSet; - - if(!hI) - delete h0; - if(!cI) - delete c0; - if(!cL) - delete ct; - if(!h && !hL) - delete ht; + } + + delete xSet; + delete hSet; + delete h0Set; + delete c0Set; + delete htSet; + delete ctSet; + + if (!hI) delete h0; + if (!cI) delete c0; + if (!cL) delete ct; + if (!h && !hL) delete ht; } - ////////////////////////////////////////////////////////////////////////// void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp, - const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, + const NDArray* b, const NDArray* seqLen, NDArray* hI, + NDArray* cI, const NDArray* Wp, const NDArray* dLdh, + const NDArray* dLdhL, const NDArray* dLdcL, const std::vector& params, const bool forward, - NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp) { - - // INPUTS: - // x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL], - // Wx - input weights [nIn, 4*nOut] - // Wr - recurrent weights [nOut, 4*nOut] - // b - biases [4*nOut], optional, may be nullptr - // seqLen - [bS], optional, may be nullptr - // hI - initial output [bS, nOut], optional, may be nullptr - // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr - // Wp - peephole weights [3*nOut], optional, may be nullptr - // dLdh - gradient vs. output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr - // dLdhL - gradient vs. output at last time step [bS, nOut], optional, may be nullptr - // dLdcL - gradient vs. cell state at last time step [bS, nOut], optional, may be nullptr - - // OUTPUTS: - // dLdx - gradient vs. input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL] - // dLdWx - gradient vs. input weights [nIn, 4*nOut] - // dLdWr - gradient vs. recurrent weights [nOut, 4*nOut] - // dLdb - gradient vs. biases [4*nOut], optional, may be nullptr - // dLdhI - gradient vs. initial output [bS, nOut], optional, may be nullptr - // dLdcI - gradient vs. initial cell state at time t-1 [bS, nOut], optional, may be nullptr - // dLdWp - gradient vs. peephole weights [3*nOut], optional, may be nullptr - - // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; - // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL] - - const int dataFormat = params[0]; - const int directionMode = params[1]; - - const int sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); - const int bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); - const int nOut = Wx->sizeAt(-1) / 4; - - const auto type = dLdh ? dLdh->dataType() : (dLdhL ? dLdhL->dataType() : dLdcL->dataType()); - - auto dLdh0 = dLdhI; - if(!hI) - dLdh0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically - - auto dLdc0 = dLdcI; - if(!cI) - dLdc0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically - - NDArray z(x->ordering(), {sL, bS, 4*nOut}, type, x->getContext()); - NDArray a = z.ulike(); - NDArray h(x->ordering(), {sL+1, bS, nOut}, type, x->getContext()); - NDArray c = h.ulike(); - - // create sets of required (depends on seqLen presence) sub-arrays - std::vector dims; - ResultSet *xSet(nullptr), *dLdxSet(nullptr), *hSet(nullptr), *cSet(nullptr), *zSet(nullptr), *aSet(nullptr), *dLdhSet(nullptr), - *dLdh0Set(nullptr), *dLdc0Set(nullptr), *dLdhLSet(nullptr), *dLdcLSet(nullptr), *hISet(nullptr), *cISet(nullptr); - - if(!seqLen) { - - dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on [bS, nIn/nOut] - - xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] - dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] - hSet = new ResultSet(h.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, nOut] - cSet = new ResultSet(c.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, nOut] - zSet = new ResultSet(z.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, 4*nOut] - aSet = new ResultSet(a.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, 4*nOut] - if(dLdh) - dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nOut] - } - else { - - dims = dataFormat == 2 ? std::vector({1}) : std::vector({2}); // points on nIn/nOut axis - - xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] - dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] - hSet = new ResultSet(h.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut] - cSet = new ResultSet(c.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut] - zSet = new ResultSet(z.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut] - aSet = new ResultSet(a.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut] - - if(hI) - hISet = new ResultSet(hI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - if(cI) - cISet = new ResultSet(cI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - - dLdh0Set = new ResultSet(dLdh0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - dLdc0Set = new ResultSet(dLdc0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - - if(dLdh) - dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut] - if(dLdhL) - dLdhLSet = new ResultSet(dLdhL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - if(dLdcL) - dLdcLSet = new ResultSet(dLdcL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - } - - - // loops - if(forward) { - - if(!seqLen) { // seqLen is absent - - if(hI) - hSet->at(0).assign(hI); - else - hSet->at(0).nullify(); - if(cI) - cSet->at(0).assign(cI); - else - cSet->at(0).nullify(); - - // ff - for (int t = 0; t < sL; ++t) - lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t), &cSet->at(t), Wp, params, &zSet->at(t), &aSet->at(t), &hSet->at(t+1), &cSet->at(t+1)); - - // bp - for (int t = sL-1; t >= 0; --t) { - const NDArray* dLdhh = dLdh ? &dLdhSet->at(t) : nullptr; - const NDArray* dLdhhL = (t == sL-1 && dLdhL) ? dLdhL : nullptr; - const NDArray* dLdccL = (t == sL-1 && dLdcL) ? dLdcL : nullptr; - lstmLayerCellBp(&xSet->at(t), Wx, Wr, b, &hSet->at(t), &cSet->at(t), Wp, dLdhh, dLdhhL, dLdccL, - &zSet->at(t), &aSet->at(t), &cSet->at(t+1), params, &dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); - } + NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, + NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, + NDArray* dLdWp) { + // INPUTS: + // x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL], + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // seqLen - [bS], optional, may be nullptr + // hI - initial output [bS, nOut], optional, may be nullptr + // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr + // Wp - peephole weights [3*nOut], optional, may be nullptr + // dLdh - gradient vs. output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, + // sL], optional, may be nullptr dLdhL - gradient vs. output at last time step + // [bS, nOut], optional, may be nullptr dLdcL - gradient vs. cell state at + // last time step [bS, nOut], optional, may be nullptr + + // OUTPUTS: + // dLdx - gradient vs. input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL] + // dLdWx - gradient vs. input weights [nIn, 4*nOut] + // dLdWr - gradient vs. recurrent weights [nOut, 4*nOut] + // dLdb - gradient vs. biases [4*nOut], optional, may be nullptr + // dLdhI - gradient vs. initial output [bS, nOut], optional, may be nullptr + // dLdcI - gradient vs. initial cell state at time t-1 [bS, nOut], optional, + // may be nullptr dLdWp - gradient vs. peephole weights [3*nOut], optional, + // may be nullptr + + // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, + // gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL] + + const int dataFormat = params[0]; + const int directionMode = params[1]; + + const int sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const int bS = + dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const int nOut = Wx->sizeAt(-1) / 4; + + const auto type = + dLdh ? dLdh->dataType() : (dLdhL ? dLdhL->dataType() : dLdcL->dataType()); + + auto dLdh0 = dLdhI; + if (!hI) + dLdh0 = new NDArray( + x->ordering(), {bS, nOut}, type, + x->getContext()); // this constructor nullifies array automatically + + auto dLdc0 = dLdcI; + if (!cI) + dLdc0 = new NDArray( + x->ordering(), {bS, nOut}, type, + x->getContext()); // this constructor nullifies array automatically + + NDArray z(x->ordering(), {sL, bS, 4 * nOut}, type, x->getContext()); + NDArray a = z.ulike(); + NDArray h(x->ordering(), {sL + 1, bS, nOut}, type, x->getContext()); + NDArray c = h.ulike(); + + // create sets of required (depends on seqLen presence) sub-arrays + std::vector dims; + ResultSet *xSet(nullptr), *dLdxSet(nullptr), *hSet(nullptr), *cSet(nullptr), + *zSet(nullptr), *aSet(nullptr), *dLdhSet(nullptr), *dLdh0Set(nullptr), + *dLdc0Set(nullptr), *dLdhLSet(nullptr), *dLdcLSet(nullptr), + *hISet(nullptr), *cISet(nullptr); + + if (!seqLen) { + dims = ShapeUtils::evalDimsToExclude( + x->rankOf(), + {dataFormat < 3 ? dataFormat : 0}); // points on [bS, nIn/nOut] + + xSet = new ResultSet( + x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] + dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension( + dims)); // sub-arrays with shape [bS, nIn] + hSet = new ResultSet(h.allTensorsAlongDimension( + {1, 2})); // sub-arrays with shape [bS, nOut] + cSet = new ResultSet(c.allTensorsAlongDimension( + {1, 2})); // sub-arrays with shape [bS, nOut] + zSet = new ResultSet(z.allTensorsAlongDimension( + {1, 2})); // sub-arrays with shape [bS, 4*nOut] + aSet = new ResultSet(a.allTensorsAlongDimension( + {1, 2})); // sub-arrays with shape [bS, 4*nOut] + if (dLdh) + dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension( + dims)); // sub-arrays with shape [bS, nOut] + } else { + dims = dataFormat == 2 ? std::vector({1}) + : std::vector({2}); // points on nIn/nOut axis + + xSet = new ResultSet( + x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] + dLdxSet = new ResultSet( + dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] + hSet = new ResultSet( + h.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut] + cSet = new ResultSet( + c.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut] + zSet = new ResultSet( + z.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut] + aSet = new ResultSet( + a.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut] + + if (hI) + hISet = new ResultSet( + hI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + if (cI) + cISet = new ResultSet( + cI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + + dLdh0Set = new ResultSet( + dLdh0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + dLdc0Set = new ResultSet( + dLdc0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + + if (dLdh) + dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension( + dims)); // sub-arrays with shape [nOut] + if (dLdhL) + dLdhLSet = new ResultSet(dLdhL->allTensorsAlongDimension( + {1})); // sub-arrays with shape [nOut] + if (dLdcL) + dLdcLSet = new ResultSet(dLdcL->allTensorsAlongDimension( + {1})); // sub-arrays with shape [nOut] + } + + // loops + if (forward) { + if (!seqLen) { // seqLen is absent + + if (hI) + hSet->at(0).assign(hI); + else + hSet->at(0).nullify(); + if (cI) + cSet->at(0).assign(cI); + else + cSet->at(0).nullify(); + + // ff + for (int t = 0; t < sL; ++t) + lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t), &cSet->at(t), Wp, + params, &zSet->at(t), &aSet->at(t), &hSet->at(t + 1), + &cSet->at(t + 1)); + + // bp + for (int t = sL - 1; t >= 0; --t) { + const NDArray* dLdhh = dLdh ? &dLdhSet->at(t) : nullptr; + const NDArray* dLdhhL = (t == sL - 1 && dLdhL) ? dLdhL : nullptr; + const NDArray* dLdccL = (t == sL - 1 && dLdcL) ? dLdcL : nullptr; + lstmLayerCellBp(&xSet->at(t), Wx, Wr, b, &hSet->at(t), &cSet->at(t), Wp, + dLdhh, dLdhhL, dLdccL, &zSet->at(t), &aSet->at(t), + &cSet->at(t + 1), params, &dLdxSet->at(t), dLdWx, dLdWr, + dLdh0, dLdc0, dLdb, dLdWp); + } + } else { // seqLen is present + + for (int e = 0; e < bS; ++e) { + const int limit = seqLen->e(e); + + if (limit == 0) { + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0, 0, e, e + 1) + .nullify(); // nullify for given e and whole time range + continue; } - else { // seqLen is present - - for (int e = 0; e < bS; ++e) { - - const int limit = seqLen->e(e); - - if(limit == 0) { - tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range - continue; - } - - if(hI) - hSet->at(e).assign(hISet->at(e)); - else - hSet->at(e).nullify(); - if(cI) - cSet->at(e).assign(cISet->at(e)); - else - cSet->at(e).nullify(); - - // ff - for (int t = 0; t < limit; ++t) - lstmLayerCell(&xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, &hSet->at(t*bS + e), &cSet->at(t*bS + e), Wp, params, - &zSet->at(t*bS + e), &aSet->at(t*bS + e), &hSet->at((t+1)*bS + e), &cSet->at((t+1)*bS + e)); - - // bp - for (int t = limit-1; t >= 0; --t) { - const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? &dLdhSet->at(ind) : nullptr; - const NDArray* dLdhhL = (t == limit-1 && dLdhL) ? &dLdhLSet->at(e) : nullptr; - const NDArray* dLdccL = (t == limit-1 && dLdcL) ? &dLdcLSet->at(e) : nullptr; - lstmLayerCellBp(&xSet->at(ind), Wx, Wr, b, &hSet->at(t*bS + e), &cSet->at(t*bS + e), Wp, dLdhh, dLdhhL, dLdccL, - &zSet->at(t*bS + e), &aSet->at(t*bS + e), &cSet->at((t+1)*bS + e), params, &dLdxSet->at(ind), dLdWx, dLdWr, - &dLdh0Set->at(e), &dLdc0Set->at(e), dLdb, dLdWp); - } - - if(limit != sL) - tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) - } + + if (hI) + hSet->at(e).assign(hISet->at(e)); + else + hSet->at(e).nullify(); + if (cI) + cSet->at(e).assign(cISet->at(e)); + else + cSet->at(e).nullify(); + + // ff + for (int t = 0; t < limit; ++t) + lstmLayerCell( + &xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, + Wr, b, &hSet->at(t * bS + e), &cSet->at(t * bS + e), Wp, params, + &zSet->at(t * bS + e), &aSet->at(t * bS + e), + &hSet->at((t + 1) * bS + e), &cSet->at((t + 1) * bS + e)); + + // bp + for (int t = limit - 1; t >= 0; --t) { + const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + const NDArray* dLdhh = dLdh ? &dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = + (t == limit - 1 && dLdhL) ? &dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = + (t == limit - 1 && dLdcL) ? &dLdcLSet->at(e) : nullptr; + lstmLayerCellBp(&xSet->at(ind), Wx, Wr, b, &hSet->at(t * bS + e), + &cSet->at(t * bS + e), Wp, dLdhh, dLdhhL, dLdccL, + &zSet->at(t * bS + e), &aSet->at(t * bS + e), + &cSet->at((t + 1) * bS + e), params, + &dLdxSet->at(ind), dLdWx, dLdWr, &dLdh0Set->at(e), + &dLdc0Set->at(e), dLdb, dLdWp); } + + if (limit != sL) + tensorAlongTimeBatchDims(*dLdx, dataFormat, limit, sL, e, e + 1) + .nullify(); // nullify for given e and time range [limit, sL) + } } - else { // backward or bidirectional - - if(!seqLen) { // backward or bidirectional, seqLen is absent - - if(hI) - hSet->at(sL).assign(hI); - else - hSet->at(sL).nullify(); - if(cI) - cSet->at(sL).assign(cI); - else - cSet->at(sL).nullify(); - - // ff - for (int t = sL-1; t >= 0; --t) - lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t+1), &cSet->at(t+1), Wp, params, &zSet->at(t), &aSet->at(t), &hSet->at(t), &cSet->at(t)); - - // bp - for (int t = 0; t < sL; ++t) { - const NDArray* dLdhh = dLdh ? &dLdhSet->at(t) : nullptr; - const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhL : nullptr; - const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcL : nullptr; - lstmLayerCellBp(&xSet->at(t), Wx, Wr, b, &hSet->at(t+1), &cSet->at(t+1), Wp, dLdhh, dLdhhL, dLdccL, - &zSet->at(t), &aSet->at(t), &cSet->at(t), params, &dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); - } - } - else if(directionMode == 1) { // backward, seqLen is present - - for (int e = 0; e < bS; ++e) { - - const int limit = seqLen->e(e); - - if(limit == 0) { - tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range - continue; - } - - if(hI) - hSet->at(sL*bS + e).assign(hISet->at(e)); - else - hSet->at(sL*bS + e).nullify(); - if(cI) - cSet->at(sL*bS + e).assign(cISet->at(e)); - else - cSet->at(sL*bS + e).nullify(); - - // ff - for (int t = sL - 1; t >= sL-limit; --t) - lstmLayerCell(&xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, &hSet->at((t+1)*bS + e), &cSet->at((t+1)*bS + e), Wp, params, - &zSet->at(t*bS + e), &aSet->at(t*bS + e), &hSet->at(t*bS + e), &cSet->at(t*bS + e)); - - // bp - for (int t = sL-limit; t < sL; ++t) { - const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? &dLdhSet->at(ind) : nullptr; - const NDArray* dLdhhL = (t == sL-limit && dLdhL) ? &dLdhLSet->at(e) : nullptr; - const NDArray* dLdccL = (t == sL-limit && dLdcL) ? &dLdcLSet->at(e) : nullptr; - lstmLayerCellBp(&xSet->at(ind), Wx, Wr, b, &hSet->at((t+1)*bS + e), &cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, - &zSet->at(t*bS + e), &aSet->at(t*bS + e), &cSet->at(t*bS + e), params, &dLdxSet->at(ind), dLdWx, dLdWr, - &dLdh0Set->at(e), &dLdc0Set->at(e), dLdb, dLdWp); - } - - if(limit != sL) - tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL) - } + } else { // backward or bidirectional + + if (!seqLen) { // backward or bidirectional, seqLen is absent + + if (hI) + hSet->at(sL).assign(hI); + else + hSet->at(sL).nullify(); + if (cI) + cSet->at(sL).assign(cI); + else + cSet->at(sL).nullify(); + + // ff + for (int t = sL - 1; t >= 0; --t) + lstmLayerCell(&xSet->at(t), Wx, Wr, b, &hSet->at(t + 1), + &cSet->at(t + 1), Wp, params, &zSet->at(t), &aSet->at(t), + &hSet->at(t), &cSet->at(t)); + + // bp + for (int t = 0; t < sL; ++t) { + const NDArray* dLdhh = dLdh ? &dLdhSet->at(t) : nullptr; + const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhL : nullptr; + const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcL : nullptr; + lstmLayerCellBp( + &xSet->at(t), Wx, Wr, b, &hSet->at(t + 1), &cSet->at(t + 1), Wp, + dLdhh, dLdhhL, dLdccL, &zSet->at(t), &aSet->at(t), &cSet->at(t), + params, &dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); + } + } else if (directionMode == 1) { // backward, seqLen is present + + for (int e = 0; e < bS; ++e) { + const int limit = seqLen->e(e); + + if (limit == 0) { + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0, 0, e, e + 1) + .nullify(); // nullify for given e and whole time range + continue; } - else { // bidirectional mode, seqLen is present - - for (int e = 0; e < bS; ++e) { - - const int limit = seqLen->e(e); - - if(limit == 0) { - tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range - continue; - } - - if(hI) - h({limit,limit+1, e,e+1, 0,0}).assign(hISet->at(e)); - else - h({limit,limit+1, e,e+1, 0,0}).nullify(); - if(cI) - c({limit,limit+1, e,e+1, 0,0}).assign(cISet->at(e)); - else - c({limit,limit+1, e,e+1, 0,0}).nullify(); - - // ff - for (int t = limit - 1; t >= 0; --t) - lstmLayerCell(&xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, &hSet->at((t+1)*bS + e), &cSet->at((t+1)*bS + e), Wp, params, - &zSet->at(t*bS + e), &aSet->at(t*bS + e), &hSet->at(t*bS + e), &cSet->at(t*bS + e)); - - // bp - for (int t = 0; t < limit; ++t) { - const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? &dLdhSet->at(ind) : nullptr; - const NDArray* dLdhhL = (t == 0 && dLdhL) ? &dLdhLSet->at(e) : nullptr; - const NDArray* dLdccL = (t == 0 && dLdcL) ? &dLdcLSet->at(e) : nullptr; - lstmLayerCellBp(&xSet->at(ind), Wx, Wr, b, &hSet->at((t+1)*bS + e), &cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, - &zSet->at(t*bS + e), &aSet->at(t*bS + e), &cSet->at(t*bS + e), params, &dLdxSet->at(ind), dLdWx, dLdWr, - &dLdh0Set->at(e), &dLdc0Set->at(e), dLdb, dLdWp); - } - - if(limit != sL) - tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) - } + + if (hI) + hSet->at(sL * bS + e).assign(hISet->at(e)); + else + hSet->at(sL * bS + e).nullify(); + if (cI) + cSet->at(sL * bS + e).assign(cISet->at(e)); + else + cSet->at(sL * bS + e).nullify(); + + // ff + for (int t = sL - 1; t >= sL - limit; --t) + lstmLayerCell( + &xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, + Wr, b, &hSet->at((t + 1) * bS + e), &cSet->at((t + 1) * bS + e), + Wp, params, &zSet->at(t * bS + e), &aSet->at(t * bS + e), + &hSet->at(t * bS + e), &cSet->at(t * bS + e)); + + // bp + for (int t = sL - limit; t < sL; ++t) { + const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + const NDArray* dLdhh = dLdh ? &dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = + (t == sL - limit && dLdhL) ? &dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = + (t == sL - limit && dLdcL) ? &dLdcLSet->at(e) : nullptr; + lstmLayerCellBp( + &xSet->at(ind), Wx, Wr, b, &hSet->at((t + 1) * bS + e), + &cSet->at((t + 1) * bS + e), Wp, dLdhh, dLdhhL, dLdccL, + &zSet->at(t * bS + e), &aSet->at(t * bS + e), + &cSet->at(t * bS + e), params, &dLdxSet->at(ind), dLdWx, dLdWr, + &dLdh0Set->at(e), &dLdc0Set->at(e), dLdb, dLdWp); } - } - delete xSet; delete dLdxSet; delete hSet; delete cSet; delete aSet; delete zSet; - delete dLdhSet; delete dLdh0Set; delete dLdc0Set; delete dLdhLSet; delete dLdcLSet; delete hISet; delete cISet; + if (limit != sL) + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0, sL - limit, e, e + 1) + .nullify(); // nullify for given e and time range [limit, sL) + } + } else { // bidirectional mode, seqLen is present - if(!hI) - delete dLdh0; - if(!cI) - delete dLdc0; -} + for (int e = 0; e < bS; ++e) { + const int limit = seqLen->e(e); + if (limit == 0) { + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0, 0, e, e + 1) + .nullify(); // nullify for given e and whole time range + continue; + } -} -} -} + if (hI) + h({limit, limit + 1, e, e + 1, 0, 0}).assign(hISet->at(e)); + else + h({limit, limit + 1, e, e + 1, 0, 0}).nullify(); + if (cI) + c({limit, limit + 1, e, e + 1, 0, 0}).assign(cISet->at(e)); + else + c({limit, limit + 1, e, e + 1, 0, 0}).nullify(); + + // ff + for (int t = limit - 1; t >= 0; --t) + lstmLayerCell( + &xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, + Wr, b, &hSet->at((t + 1) * bS + e), &cSet->at((t + 1) * bS + e), + Wp, params, &zSet->at(t * bS + e), &aSet->at(t * bS + e), + &hSet->at(t * bS + e), &cSet->at(t * bS + e)); + + // bp + for (int t = 0; t < limit; ++t) { + const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + const NDArray* dLdhh = dLdh ? &dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = + (t == 0 && dLdhL) ? &dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = + (t == 0 && dLdcL) ? &dLdcLSet->at(e) : nullptr; + lstmLayerCellBp( + &xSet->at(ind), Wx, Wr, b, &hSet->at((t + 1) * bS + e), + &cSet->at((t + 1) * bS + e), Wp, dLdhh, dLdhhL, dLdccL, + &zSet->at(t * bS + e), &aSet->at(t * bS + e), + &cSet->at(t * bS + e), params, &dLdxSet->at(ind), dLdWx, dLdWr, + &dLdh0Set->at(e), &dLdc0Set->at(e), dLdb, dLdWp); + } + if (limit != sL) + tensorAlongTimeBatchDims(*dLdx, dataFormat, limit, sL, e, e + 1) + .nullify(); // nullify for given e and time range [limit, sL) + } + } + } + + delete xSet; + delete dLdxSet; + delete hSet; + delete cSet; + delete aSet; + delete zSet; + delete dLdhSet; + delete dLdh0Set; + delete dLdc0Set; + delete dLdhLSet; + delete dLdcLSet; + delete hISet; + delete cISet; + + if (!hI) delete dLdh0; + if (!cI) delete dLdc0; +} +} // namespace helpers +} // namespace ops +} // namespace sd ////////////////////////////////////////////////////////////////////////// -// void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, -// const NDArray* b, NDArray* hI, NDArray* cI, const NDArray* Wp, const NDArray* dLdh, -// const std::vector& params, const bool firstIter, - -// NDArray* dhIdcI, NDArray* dhIdWx, NDArray* dcIdWx, NDArray* dhIdWr, NDArray* dcIdWr, -// NDArray* dhIdb, NDArray* dcIdb, NDArray* dhIdWp, NDArray* dcIdWp, -// NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) { - -// /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ +// void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* +// Wr, +// const NDArray* b, NDArray* hI, NDArray* +// cI, const NDArray* Wp, const NDArray* dLdh, const +// std::vector& params, const bool firstIter, + +// NDArray* dhIdcI, NDArray* dhIdWx, NDArray* dcIdWx, +// NDArray* dhIdWr, NDArray* dcIdWr, NDArray* dhIdb, +// NDArray* dcIdb, NDArray* dhIdWp, NDArray* dcIdWp, +// NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* +// dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) { + +// /************************ THIS IS NOT OPTIMAZED CODE +// ***********************************/ // /** the objective is to provide math-readable code **/ // // equations (no peephole connections) @@ -1235,13 +1439,16 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // // c = clip(f * cI + i * g) // // h = o * actH(c) -// // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus +// // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky +// relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, +// 9=softsign, 10=softplus // // params[0] - dataFormat, ignore // // params[1] - directionMode, ignore // // params[2] - cell clipping value, if it = 0 then do not apply clipping -// // params[3] - activation ID for input (i), forget (f) and output (o) gates +// // params[3] - activation ID for input (i), forget (f) and output (o) +// gates // // params[4] - alpha value for gates activation // // params[5] - beta value for gates activation @@ -1254,32 +1461,50 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // // params[11] - beta value for output activation // // INPUTS: -// // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr +// // x - current input at time t, [bS, nIn] or [nIn] if seqLen != +// nullptr // // Wx - input weights [nIn, 4*nOut] // // Wr - recurrent weights [nOut, 4*nOut] // // b - biases [4*nOut], optional, may be nullptr -// // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr -// // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr +// // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or +// [nOut] if seqLen != nullptr +// // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or +// [nOut] if seqLen != nullptr // // Wp - peephole weights [3*nOut], optional, may be nullptr -// // dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if seqLen != nullptr -// // dhIdcI - derivative from previous time step, [bS, nOut] or [nOut] if seqLen != nullptr -// // dhIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr -// // dcIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr -// // dhIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr -// // dcIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr -// // dcIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr -// // dhIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr -// // dcIdb - derivative from previous time step, [4*nOut], optional, may be nullptr -// // dhIdb - derivative from previous time step, [4*nOut], optional, may be nullptr +// // dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if +// seqLen != nullptr +// // dhIdcI - derivative from previous time step, [bS, nOut] or [nOut] if +// seqLen != nullptr +// // dhIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, +// bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr +// // dcIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, +// bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr +// // dhIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, +// bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr +// // dcIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, +// bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr +// // dcIdWp - derivative from previous time step, [3*nOut], optional, may +// be nullptr +// // dhIdWp - derivative from previous time step, [3*nOut], optional, may +// be nullptr +// // dcIdb - derivative from previous time step, [4*nOut], optional, may +// be nullptr +// // dhIdb - derivative from previous time step, [4*nOut], optional, may +// be nullptr // // OUTPUTS: -// // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr +// // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if +// seqLen != nullptr // // dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut] // // dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut] -// // dLdb - loss derivative with respect to b, optional, may be nullptr, [4*nOut] -// // dLdhI - loss derivative with respect to hI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr -// // dLdcI - loss derivative with respect to cI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr -// // dLdWp - loss derivative with respect to Wp, optional, may be nullptr, [3*nOut] +// // dLdb - loss derivative with respect to b, optional, may be nullptr, +// [4*nOut] +// // dLdhI - loss derivative with respect to hI, optional may be nullptr, +// [bS, nOut] or [nOut] if seqLen != nullptr +// // dLdcI - loss derivative with respect to cI, optional may be nullptr, +// [bS, nOut] or [nOut] if seqLen != nullptr +// // dLdWp - loss derivative with respect to Wp, optional, may be nullptr, +// [3*nOut] // // !!! dimension 4*nOut implies order i, f, g, o // // !!! dimension 3*nOut implies order i, f, o @@ -1302,52 +1527,100 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // // dhIdcI = dhdc_from_previous_time_step -// // dLdx = iFactor×WxiT + fFactor×WxfT + eFactor×WxgT + oFactor×WxoT, [bS, nIn] -// // dLdhI = iFactor×WriT + fFactor×WrfT + eFactor×WrgT + oFactor×WroT, [bS, nOut] -// // dLdcI = factor*tempC + dLdhI * dhIdcI, dhIdcI=0 if firstIter, [bS, nOut] - -// // dcdWxi(dcIdWxi) = dcdzi*dzidWxi + tempIFE*dhIdWxi + tempC*dcIdWxi, dcIdWxi=dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] -// // dcdWxf(dcIdWxf) = dcdzf*dzfdWxf + tempIFE*dhIdWxf + tempC*dcIdWxf, dcIdWxf=dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut] -// // dcdWxg(dcIdWxg) = dcdzg*dzgdWxg + tempIFE*dhIdWxg + tempC*dcIdWxg, dcIdWxg=dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut] -// // dcdWxo(dcIdWxo) = 0 + tempIFE*dhIdWxo + tempC*dcIdWxo; dcIdWxo=dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut] - -// // dhdWxi(dhIdWxi) = 0 + dhdc*dcdWxi + tempO*dhIdWxi, dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] -// // dhdWxf(dhIdWxf) = 0 + dhdc*dcdWxf + tempO*dhIdWxf, dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut] -// // dhdWxg(dhIdWxg) = 0 + dhdc*dcdWxg + tempO*dhIdWxg, dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut] -// // dhdWxo(dhIdWxo) = dhdzo*dzodWxo + dhdc*dcdWxo + tempO*dhIdWxo, dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut] - -// // dhdWri(dhIdWri) = 0 + dhdc*dcdWri + tempO*dhIdWri, dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] -// // dhdWrf(dhIdWrf) = 0 + dhdc*dcdWrf + tempO*dhIdWrf, dhIdWrf= 0 if firstIter, [nOut, nOut, bS, nOut] -// // dhdWrg(dhIdWrg) = 0 + dhdc*dcdWrg + tempO*dhIdWrg, dhIdWrg= 0 if firstIter, [nOut, nOut, bS, nOut] -// // dhdWro(dhIdWro) = dhdzo*dzodWro + dhdc*dcdWro + tempO*dhIdWro, dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut] - -// // dcdWri(dcIdWri) = dcdzi*dzidWri + tempIFE*dhIdWri + tempC*dcIdWri, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] -// // dcdWrf(dcIdWrf) = dcdzf*dzfdWrf + tempIFE*dhIdWrf + tempC*dcIdWrf, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] -// // dcdWrg(dcIdWrg) = dcdzg*dzgdWrg + tempIFE*dhIdWrg + tempC*dcIdWrg, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] -// // dcdWro(dcIdWro) = 0 + tempIFE*dhIdWro + tempC*dcIdWro; dcIdWro=dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut] - -// // dcIdWpi = (dcdzi*cI + tempIFE*dhIdWpi + tempC*dcIdWpi).reduceALongFirstDim, dcIdWpi=dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dcIdWpf = (dcdzf*cI + tempIFE*dhIdWpf + tempC*dcIdWpf).reduceALongFirstDim, dcIdWpf=dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dcIdWpo = (0 + tempIFE*dhIdWpo + tempC*dcIdWpo).reduceALongFirstDim, dcIdWpo=dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[bS] - -// // dhdWpi(dhIdWpi) =( 0 + dhdc*dcdWpi + tempO*dhIdWpi).reduceALongFirstDim, dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dhdWpf(dhIdWpf) =( 0 + dhdc*dcdWpf + tempO*dhIdWpf).reduceALongFirstDim, dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dhdWpo(dhIdWpo) =(dhdzo*c + dhdc*dcdWpo + tempO*dhIdWpo).reduceALongFirstDim, dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[bS] - -// // dcdbi(dcIdbi) = (dcdzi + tempIFE*dhIdbi + tempC*dcIdbi).reduceALongFirstDim, dcIdbi=dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dcdbf(dcIdbf) = (dcdzf + tempIFE*dhIdbf + tempC*dcIdbf).reduceALongFirstDim, dcIdbf=dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dcdbg(dcIdbg) = (dcdzg + tempIFE*dhIdbg + tempC*dcIdbg).reduceALongFirstDim, dcIdbg=dhIdbg= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dcdbo(dcIdbo) = ( 0 + tempIFE*dhIdbo + tempC*dcIdbo).reduceALongFirstDim; dcIdbo=dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[bS] - -// // dhdbi(dhIdbi) = ( 0 + dhdc*dcdbi + tempO*dhIdbi).reduceALongFirstDim, dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dhdbf(dhIdbf) = ( 0 + dhdc*dcdbf + tempO*dhIdbf).reduceALongFirstDim, dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dhdbg(dhIdbg) = ( 0 + dhdc*dcdbg + tempO*dhIdbg).reduceALongFirstDim, dhIdbg= 0 if firstIter, [bS, nOut]->reduce->[bS] -// // dhdbo(dhIdbo) = (dhdzo + dhdc*dcdbo + tempO*dhIdbo).reduceALongFirstDim, dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dLdx = iFactor×WxiT + fFactor×WxfT + eFactor×WxgT + oFactor×WxoT, +// [bS, nIn] +// // dLdhI = iFactor×WriT + fFactor×WrfT + eFactor×WrgT + oFactor×WroT, +// [bS, nOut] +// // dLdcI = factor*tempC + dLdhI * dhIdcI, dhIdcI=0 if firstIter, [bS, +// nOut] + +// // dcdWxi(dcIdWxi) = dcdzi*dzidWxi + tempIFE*dhIdWxi + tempC*dcIdWxi, +// dcIdWxi=dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dcdWxf(dcIdWxf) = dcdzf*dzfdWxf + tempIFE*dhIdWxf + tempC*dcIdWxf, +// dcIdWxf=dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dcdWxg(dcIdWxg) = dcdzg*dzgdWxg + tempIFE*dhIdWxg + tempC*dcIdWxg, +// dcIdWxg=dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dcdWxo(dcIdWxo) = 0 + tempIFE*dhIdWxo + tempC*dcIdWxo; +// dcIdWxo=dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut] + +// // dhdWxi(dhIdWxi) = 0 + dhdc*dcdWxi + tempO*dhIdWxi, +// dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dhdWxf(dhIdWxf) = 0 + dhdc*dcdWxf + tempO*dhIdWxf, +// dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dhdWxg(dhIdWxg) = 0 + dhdc*dcdWxg + tempO*dhIdWxg, +// dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dhdWxo(dhIdWxo) = dhdzo*dzodWxo + dhdc*dcdWxo + tempO*dhIdWxo, +// dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut] + +// // dhdWri(dhIdWri) = 0 + dhdc*dcdWri + tempO*dhIdWri, +// dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dhdWrf(dhIdWrf) = 0 + dhdc*dcdWrf + tempO*dhIdWrf, +// dhIdWrf= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dhdWrg(dhIdWrg) = 0 + dhdc*dcdWrg + tempO*dhIdWrg, +// dhIdWrg= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dhdWro(dhIdWro) = dhdzo*dzodWro + dhdc*dcdWro + tempO*dhIdWro, +// dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut] + +// // dcdWri(dcIdWri) = dcdzi*dzidWri + tempIFE*dhIdWri + tempC*dcIdWri, +// dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dcdWrf(dcIdWrf) = dcdzf*dzfdWrf + tempIFE*dhIdWrf + tempC*dcIdWrf, +// dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dcdWrg(dcIdWrg) = dcdzg*dzgdWrg + tempIFE*dhIdWrg + tempC*dcIdWrg, +// dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dcdWro(dcIdWro) = 0 + tempIFE*dhIdWro + tempC*dcIdWro; +// dcIdWro=dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut] + +// // dcIdWpi = (dcdzi*cI + tempIFE*dhIdWpi + +// tempC*dcIdWpi).reduceALongFirstDim, dcIdWpi=dhIdWpi= 0 if firstIter, +// [bS, nOut]->reduce->[bS] +// // dcIdWpf = (dcdzf*cI + tempIFE*dhIdWpf + +// tempC*dcIdWpf).reduceALongFirstDim, dcIdWpf=dhIdWpf= 0 if firstIter, +// [bS, nOut]->reduce->[bS] +// // dcIdWpo = (0 + tempIFE*dhIdWpo + +// tempC*dcIdWpo).reduceALongFirstDim, dcIdWpo=dhIdWpo= 0 if firstIter, +// [bS, nOut]->reduce->[bS] + +// // dhdWpi(dhIdWpi) =( 0 + dhdc*dcdWpi + +// tempO*dhIdWpi).reduceALongFirstDim, dhIdWpi= 0 if firstIter, +// [bS, nOut]->reduce->[bS] +// // dhdWpf(dhIdWpf) =( 0 + dhdc*dcdWpf + +// tempO*dhIdWpf).reduceALongFirstDim, dhIdWpf= 0 if firstIter, +// [bS, nOut]->reduce->[bS] +// // dhdWpo(dhIdWpo) =(dhdzo*c + dhdc*dcdWpo + +// tempO*dhIdWpo).reduceALongFirstDim, dhIdWpo= 0 if firstIter, +// [bS, nOut]->reduce->[bS] + +// // dcdbi(dcIdbi) = (dcdzi + tempIFE*dhIdbi + +// tempC*dcIdbi).reduceALongFirstDim, dcIdbi=dhIdbi= 0 if +// firstIter, [bS, nOut]->reduce->[bS] +// // dcdbf(dcIdbf) = (dcdzf + tempIFE*dhIdbf + +// tempC*dcIdbf).reduceALongFirstDim, dcIdbf=dhIdbf= 0 if +// firstIter, [bS, nOut]->reduce->[bS] +// // dcdbg(dcIdbg) = (dcdzg + tempIFE*dhIdbg + +// tempC*dcIdbg).reduceALongFirstDim, dcIdbg=dhIdbg= 0 if +// firstIter, [bS, nOut]->reduce->[bS] +// // dcdbo(dcIdbo) = ( 0 + tempIFE*dhIdbo + +// tempC*dcIdbo).reduceALongFirstDim; dcIdbo=dhIdbo= 0 if +// firstIter, [bS, nOut]->reduce->[bS] + +// // dhdbi(dhIdbi) = ( 0 + dhdc*dcdbi + +// tempO*dhIdbi).reduceALongFirstDim, dhIdbi= 0 if firstIter, [bS, +// nOut]->reduce->[bS] +// // dhdbf(dhIdbf) = ( 0 + dhdc*dcdbf + +// tempO*dhIdbf).reduceALongFirstDim, dhIdbf= 0 if firstIter, [bS, +// nOut]->reduce->[bS] +// // dhdbg(dhIdbg) = ( 0 + dhdc*dcdbg + +// tempO*dhIdbg).reduceALongFirstDim, dhIdbg= 0 if firstIter, [bS, +// nOut]->reduce->[bS] +// // dhdbo(dhIdbo) = (dhdzo + dhdc*dcdbo + +// tempO*dhIdbo).reduceALongFirstDim, dhIdbo= 0 if firstIter, [bS, +// nOut]->reduce->[bS] // const Nd4jLong nOut = Wx->sizeAt(-1) / 4; -// NDArray *Wpi(nullptr), *Wpf(nullptr), *Wpo(nullptr), *dcIdWpi(nullptr), *dcIdWpf(nullptr), *dcIdWpo(nullptr), *dhIdWpi(nullptr), *dhIdWpf(nullptr), *dhIdWpo(nullptr); -// if(Wp) { +// NDArray *Wpi(nullptr), *Wpf(nullptr), *Wpo(nullptr), *dcIdWpi(nullptr), +// *dcIdWpf(nullptr), *dcIdWpo(nullptr), *dhIdWpi(nullptr), +// *dhIdWpf(nullptr), *dhIdWpo(nullptr); if(Wp) { // Wpi = new NDArray((*Wp)({0, nOut})); // Wpf = new NDArray((*Wp)({nOut, 2*nOut})); // Wpo = new NDArray((*Wp)({2*nOut, 3*nOut})); @@ -1359,8 +1632,9 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dcIdWpo = new NDArray((*dcIdWp)({2*nOut, 3*nOut})); // } -// NDArray *dcIdbi(nullptr), *dcIdbf(nullptr), *dcIdbg(nullptr), *dcIdbo(nullptr), *dhIdbi(nullptr), *dhIdbf(nullptr), *dhIdbg(nullptr), *dhIdbo(nullptr); -// if(b) { +// NDArray *dcIdbi(nullptr), *dcIdbf(nullptr), *dcIdbg(nullptr), +// *dcIdbo(nullptr), *dhIdbi(nullptr), *dhIdbf(nullptr), *dhIdbg(nullptr), +// *dhIdbo(nullptr); if(b) { // dhIdbi = new NDArray((*dhIdb)({0, nOut})); // dhIdbf = new NDArray((*dhIdb)({nOut, 2*nOut})); // dhIdbg = new NDArray((*dhIdb)({2*nOut, 3*nOut})); @@ -1371,53 +1645,91 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dcIdbo = new NDArray((*dcIdb)({3*nOut, 4*nOut})); // } -// NDArray dhIdWxi = x->rankOf() == 1 ? (*dhIdWx)({0,0, 0,nOut, 0,0}) : (*dhIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr -// NDArray dhIdWxf = x->rankOf() == 1 ? (*dhIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWx)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr -// NDArray dhIdWxg = x->rankOf() == 1 ? (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr -// NDArray dhIdWxo = x->rankOf() == 1 ? (*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0}) : (*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr - -// NDArray dhIdWri = x->rankOf() == 1 ? (*dhIdWr)({0,0, 0,nOut, 0,0}) : (*dhIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr -// NDArray dhIdWrf = x->rankOf() == 1 ? (*dhIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWr)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr -// NDArray dhIdWrg = x->rankOf() == 1 ? (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr -// NDArray dhIdWro = x->rankOf() == 1 ? (*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0}) : (*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr - -// NDArray dcIdWxi = x->rankOf() == 1 ? (*dcIdWx)({0,0, 0,nOut, 0,0}) : (*dcIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr -// NDArray dcIdWxf = x->rankOf() == 1 ? (*dcIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWx)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr -// NDArray dcIdWxg = x->rankOf() == 1 ? (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr -// NDArray dcIdWxo = x->rankOf() == 1 ? (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0}) : (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr - -// NDArray dcIdWri = x->rankOf() == 1 ? (*dcIdWr)({0,0, 0,nOut, 0,0}) : (*dcIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr -// NDArray dcIdWrf = x->rankOf() == 1 ? (*dcIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWr)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr -// NDArray dcIdWrg = x->rankOf() == 1 ? (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr -// NDArray dcIdWro = x->rankOf() == 1 ? (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0}) : (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr - -// NDArray WxiT = (*Wx)({0,0, 0, nOut}).transpose(); // [nOut, nIn] -// NDArray WxfT = (*Wx)({0,0, nOut, 2*nOut}).transpose(); // [nOut, nIn] -// NDArray WxgT = (*Wx)({0,0, 2*nOut,3*nOut}).transpose(); // [nOut, nIn] -// NDArray WxoT = (*Wx)({0,0, 3*nOut,4*nOut}).transpose(); // [nOut, nIn] - -// NDArray WriT = (*Wr)({0,0, 0, nOut}).transpose(); // [nOut, nOut] -// NDArray WrfT = (*Wr)({0,0, nOut, 2*nOut}).transpose(); // [nOut, nOut] -// NDArray WrgT = (*Wr)({0,0, 2*nOut,3*nOut}).transpose(); // [nOut, nOut] -// NDArray WroT = (*Wr)({0,0, 3*nOut,4*nOut}).transpose(); // [nOut, nOut] +// NDArray dhIdWxi = x->rankOf() == 1 ? (*dhIdWx)({0,0, 0,nOut, 0,0}) : +// (*dhIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] +// or [nIn, nOut, nOut] if seqLen != nullptr NDArray dhIdWxf = x->rankOf() +// == 1 ? (*dhIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWx)({0,0, +// nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, +// nOut] if seqLen != nullptr NDArray dhIdWxg = x->rankOf() == 1 ? +// (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWx)({0,0, 2*nOut,3*nOut, +// 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != +// nullptr NDArray dhIdWxo = x->rankOf() == 1 ? (*dhIdWx)({0,0, +// 3*nOut,4*nOut, 0,0}) : (*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // +// [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr + +// NDArray dhIdWri = x->rankOf() == 1 ? (*dhIdWr)({0,0, 0,nOut, 0,0}) : +// (*dhIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] +// or [nOut, nOut, nOut] if seqLen != nullptr NDArray dhIdWrf = x->rankOf() +// == 1 ? (*dhIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWr)({0,0, +// nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, +// nOut] if seqLen != nullptr NDArray dhIdWrg = x->rankOf() == 1 ? +// (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWr)({0,0, 2*nOut,3*nOut, +// 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen +// != nullptr NDArray dhIdWro = x->rankOf() == 1 ? (*dhIdWr)({0,0, +// 3*nOut,4*nOut, 0,0}) : (*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // +// [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr + +// NDArray dcIdWxi = x->rankOf() == 1 ? (*dcIdWx)({0,0, 0,nOut, 0,0}) : +// (*dcIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] +// or [nIn, nOut, nOut] if seqLen != nullptr NDArray dcIdWxf = x->rankOf() +// == 1 ? (*dcIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWx)({0,0, +// nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, +// nOut] if seqLen != nullptr NDArray dcIdWxg = x->rankOf() == 1 ? +// (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWx)({0,0, 2*nOut,3*nOut, +// 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != +// nullptr NDArray dcIdWxo = x->rankOf() == 1 ? (*dcIdWx)({0,0, +// 3*nOut,4*nOut, 0,0}) : (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // +// [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr + +// NDArray dcIdWri = x->rankOf() == 1 ? (*dcIdWr)({0,0, 0,nOut, 0,0}) : +// (*dcIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] +// or [nOut, nOut, nOut] if seqLen != nullptr NDArray dcIdWrf = x->rankOf() +// == 1 ? (*dcIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWr)({0,0, +// nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, +// nOut] if seqLen != nullptr NDArray dcIdWrg = x->rankOf() == 1 ? +// (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWr)({0,0, 2*nOut,3*nOut, +// 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen +// != nullptr NDArray dcIdWro = x->rankOf() == 1 ? (*dcIdWr)({0,0, +// 3*nOut,4*nOut, 0,0}) : (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // +// [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr + +// NDArray WxiT = (*Wx)({0,0, 0, nOut}).transpose(); // +// [nOut, nIn] NDArray WxfT = (*Wx)({0,0, nOut, 2*nOut}).transpose(); // +// [nOut, nIn] NDArray WxgT = (*Wx)({0,0, 2*nOut,3*nOut}).transpose(); // +// [nOut, nIn] NDArray WxoT = (*Wx)({0,0, 3*nOut,4*nOut}).transpose(); // +// [nOut, nIn] + +// NDArray WriT = (*Wr)({0,0, 0, nOut}).transpose(); // +// [nOut, nOut] NDArray WrfT = (*Wr)({0,0, nOut, 2*nOut}).transpose(); // +// [nOut, nOut] NDArray WrgT = (*Wr)({0,0, 2*nOut,3*nOut}).transpose(); // +// [nOut, nOut] NDArray WroT = (*Wr)({0,0, 3*nOut,4*nOut}).transpose(); // +// [nOut, nOut] // // ***** feed forward step ***** // -// auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut] -// //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut] +// auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, +// 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut] +// //or [nIn] * [nIn, 4*nOut] + +// [nOut] * [nOut, 4*nOut] = +// [4*nOut] // // add biases if they are given // if(b) -// z += *b; // broadcast [bS, 4*nOut] + [4*nOut] = [bS, 4*nOut](or[4*nOut]) +// z += *b; // broadcast [bS, 4*nOut] +// + [4*nOut] = [bS, 4*nOut](or[4*nOut]) -// auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) -// auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut]) -// auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) -// auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut]) +// auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // +// input gate i, [bS, nOut](or[nOut]) auto zf = x->rankOf() == 1 ? z({nOut, +// 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate f, [bS, +// nOut](or[nOut]) auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : +// z({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) +// auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, +// 4*nOut}); // output gate o, [bS, nOut](or[nOut]) // // peephole connections for input and forget gates // if(Wp) { -// zi += *cI * *Wpi; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) -// zf += *cI * *Wpf; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) +// zi += *cI * *Wpi; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] +// = [bS, nOut](or[nOut]) zf += *cI * *Wpf; // broadcast: [bS, +// nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) // } // NDArray i = zi.ulike(); // [bS, nOut] @@ -1427,69 +1739,79 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // applyActivation(zf, params[3], params[4], params[5], f); // applyActivation(zg, params[6], params[7], params[8], g); -// NDArray c = f * *cI + i * g; // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut]) +// NDArray c = f * *cI + i * g; // [bS, nOut] * [bS, nOut] + [bS, +// nOut] * [bS, nOut] = [bS, nOut](or[nOut]) -// // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation -// if(params[2] != 0) +// // if clipping value is non-zero then cell state is clipped by this value +// prior to the cell output activation if(params[2] != 0) // c.applyScalar(scalar::LstmClip, params[2], c); // // peephole connections for output gate // if(Wp) -// zo += c * *Wpo; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) +// zo += c * *Wpo; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = +// [bS, nOut](or[nOut]) // NDArray o = zo.ulike(); // [bS, nOut](or[nOut]) // applyActivation(zo, params[3], params[4], params[5], o); // // ***** back prop step ***** // -// NDArray dWxJacobian = mmulJacobianWeightsDeriv(nOut, *x); // [nIn, nOut, bS, nOut] (or [nIn, nOut, nOut]) -// NDArray dWrJacobian = mmulJacobianWeightsDeriv(nOut, *hI); // [nOut, nOut, bS, nOut] (or [nOut, nOut, nOut]) +// NDArray dWxJacobian = mmulJacobianWeightsDeriv(nOut, *x); // [nIn, +// nOut, bS, nOut] (or [nIn, nOut, nOut]) NDArray dWrJacobian = +// mmulJacobianWeightsDeriv(nOut, *hI); // [nOut, nOut, bS, nOut] (or +// [nOut, nOut, nOut]) // // dodzo -// NDArray dodzo = zo.ulike(); // [bS, nOut](or[nOut]) -// activationDeriv(zo, params[3], params[4], params[5], dodzo); +// NDArray dodzo = zo.ulike(); // [bS, nOut](or[nOut]) activationDeriv(zo, +// params[3], params[4], params[5], dodzo); // // dhdzo = dhdo*dodzo = actH(c)*dodzo -// NDArray dhdzo = zo.ulike(); // [bS, nOut](or[nOut]) -// applyActivation(c, params[9], params[10], params[11], dhdzo); // actH(c) +// NDArray dhdzo = zo.ulike(); // [bS, nOut](or[nOut]) applyActivation(c, +// params[9], params[10], params[11], dhdzo); // actH(c) // hI->assign(o*dhdzo); // dhdzo *= dodzo; // // dcdzi = dcdi*didzi -// NDArray dcdzi = zi.ulike(); // [bS, nOut](or[nOut]) -// activationDeriv(zi, params[3], params[4], params[5], dcdzi); // didzi -// dcdzi *= g; // dcdi = g*clipDeriv +// NDArray dcdzi = zi.ulike(); // [bS, nOut](or[nOut]) activationDeriv(zi, +// params[3], params[4], params[5], dcdzi); // didzi dcdzi *= g; +// // dcdi = g*clipDeriv // // dcdzf = dcdf*dfdzf -// NDArray dcdzf = zf.ulike(); // [bS, nOut](or[nOut]) -// activationDeriv(zf, params[3], params[4], params[5], dcdzf); // dfdzf -// dcdzf *= *cI; // dcdf = cI*clipDeriv +// NDArray dcdzf = zf.ulike(); // [bS, nOut](or[nOut]) activationDeriv(zf, +// params[3], params[4], params[5], dcdzf); // dfdzf dcdzf *= +// *cI; // dcdf = +// cI*clipDeriv // // dcdzg = dcde*dedzg -// NDArray dcdzg = zg.ulike(); // [bS, nOut](or[nOut]) -// activationDeriv(zg, params[6], params[7], params[8], dcdzg); // dedzg -// dcdzg *= i; // dcdf = i*clipDeriv +// NDArray dcdzg = zg.ulike(); // [bS, nOut](or[nOut]) activationDeriv(zg, +// params[6], params[7], params[8], dcdzg); // dedzg dcdzg *= i; +// // dcdf = i*clipDeriv // // dcdcI -// NDArray dcdcI = f.dup(); // [bS, nOut](or[nOut]) +// NDArray dcdcI = f.dup(); // [bS, nOut](or[nOut]) // // take into account possible deposit from clipping derivative // clipDeriv(params[2], c, dcdzi, dcdzf, dcdzg, dcdcI); // // dzodc -// NDArray* dzodc = Wpo; // [nOut], should be [bS, nOut] actually, however it will be broadcasted appropriately in future calcus (element-wise multiplication) +// NDArray* dzodc = Wpo; // [nOut], should be [bS, +// nOut] actually, however it will be broadcasted appropriately in future +// calcus (element-wise multiplication) // // dzidcI -// NDArray* dzidcI = Wpi; // [nOut], should be [bS, nOut] actually, however it will be broadcasted appropriately in future calcus (element-wise multiplication) +// NDArray* dzidcI = Wpi; // [nOut], should be [bS, +// nOut] actually, however it will be broadcasted appropriately in future +// calcus (element-wise multiplication) // // dzfdcI -// NDArray* dzfdcI = Wpf; // [nOut], should be [bS, nOut] actually, however it will be broadcasted appropriately in future calcus (element-wise multiplication) +// NDArray* dzfdcI = Wpf; // [nOut], should be [bS, +// nOut] actually, however it will be broadcasted appropriately in future +// calcus (element-wise multiplication) // // dhdc // NDArray dhdc = c.ulike(); -// activationDeriv(c, params[9], params[10], params[11], dhdc); // [bS, nOut] -// dhdc *= o; -// if(Wp) +// activationDeriv(c, params[9], params[10], params[11], dhdc); // [bS, +// nOut] dhdc *= o; if(Wp) // dhdc += dhdzo* *dzodc; // NDArray factor = *dLdh * dhdc; @@ -1504,25 +1826,33 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // tempC += dcdzi*(*dzidcI) + dcdzf*(*dzfdcI); // // dLdx -// dLdx->assign(mmul(iFactor, WxiT) + mmul(fFactor, WxfT) + mmul(eFactor, WxgT) + mmul(oFactor, WxoT)); // [bS, nIn](or[nOut]) +// dLdx->assign(mmul(iFactor, WxiT) + mmul(fFactor, WxfT) + mmul(eFactor, +// WxgT) + mmul(oFactor, WxoT)); // [bS, nIn](or[nOut]) // // NDArray temp = c.ulike(); -// // applyActivation(c, params[9], params[10], params[11], temp); // actH(c) -// // dLdx->assign(mmul(o*(1-temp*temp)*g*i*(1-i), WxiT) + mmul(o*(1-temp*temp)*(*cI)*f*(1-f), WxfT) + mmul(o*(1-temp*temp)*i*g*(1-g), WxgT) + mmul(temp*o*(1-o), WxoT)); // [bS, nIn](or[nOut]) +// // applyActivation(c, params[9], params[10], params[11], temp); // +// actH(c) +// // dLdx->assign(mmul(o*(1-temp*temp)*g*i*(1-i), WxiT) + +// mmul(o*(1-temp*temp)*(*cI)*f*(1-f), WxfT) + +// mmul(o*(1-temp*temp)*i*g*(1-g), WxgT) + mmul(temp*o*(1-o), WxoT)); // +// [bS, nIn](or[nOut]) // // dLdhI // NDArray* dLdhII = dLdhI; // if(dLdcI && !dLdhI) // dLdhII = new NDArray(dLdcI->ulike()); -// dLdhII->assign(mmul(iFactor, WriT) + mmul(fFactor, WrfT) + mmul(eFactor, WrgT) + mmul(oFactor, WroT)); // [bS, nOut](or[nOut]) +// dLdhII->assign(mmul(iFactor, WriT) + mmul(fFactor, WrfT) + mmul(eFactor, +// WrgT) + mmul(oFactor, WroT)); // [bS, nOut](or[nOut]) // if(firstIter) { // // dLdcI // if(dLdcI) -// dLdcI->assign(factor*tempC); // [bS, nOut](or[nOut]) +// dLdcI->assign(factor*tempC); // [bS, +// nOut](or[nOut]) // // dcIdWxi(dcdWxi) -// dcIdWxi.assign(dcdzi*dWxJacobian); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); +// dcIdWxi.assign(dcdzi*dWxJacobian); // broadcast [bS, nOut] * +// [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); // // dcIdWxf(dcdWxf) // dcIdWxf.assign(dcdzf*dWxJacobian); // // dcIdWxg(dcdWxg) @@ -1531,7 +1861,8 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dcIdWxo.nullify(); // // dhIdWxi -// dhIdWxi.assign(dhdc*dcIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); +// dhIdWxi.assign(dhdc*dcIdWxi); // broadcast [bS, nOut] * +// [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); // // dhIdWxf // dhIdWxf.assign(dhdc*dcIdWxf); // // dhIdWxg @@ -1540,7 +1871,8 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dhIdWxo.assign(dhdzo*dWxJacobian /*+ 0 */); // // dcIdWri(dcdWri) -// dcIdWri.assign(dcdzi*dWrJacobian); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]);; +// dcIdWri.assign(dcdzi*dWrJacobian); // broadcast [bS, nOut] * +// [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]);; // // dcIdWrf(dcdWrf) // dcIdWrf.assign(dcdzf*dWrJacobian); // // dcIdWrg(dcdWrg) @@ -1549,7 +1881,8 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dcIdWro.nullify(); // // dhIdWri -// dhIdWri.assign(dhdc*dcIdWri); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); +// dhIdWri.assign(dhdc*dcIdWri); // broadcast [bS, nOut] * +// [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); // // dhIdWrf // dhIdWrf.assign(dhdc*dcIdWrf); // // dhIdWrg @@ -1574,18 +1907,23 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // } // else if(Wp) { // // dcIdWpi -// (dcdzi*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); // [bS, nOut]->reduce->[nOut] +// (dcdzi*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); +// // [bS, nOut]->reduce->[nOut] // // dcIdWpf -// (dcdzf*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); // [bS, nOut]->reduce->[nOut] +// (dcdzf*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); +// // [bS, nOut]->reduce->[nOut] // // dcIdWpo -// dcIdWpo->nullify(); // [nOut] +// dcIdWpo->nullify(); // [nOut] // // dhIdWpi -// (*dLdh*dhdc*(dcdzi*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdh*dhdc*(dcdzi*(*cI))).reduceAlongDimension(reduce::Sum, +// *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpf -// (*dLdh*dhdc*(dcdzf*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdh*dhdc*(dcdzf*(*cI))).reduceAlongDimension(reduce::Sum, +// *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpo -// (*dLdh*dhdzo*c /* +0*/).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdh*dhdzo*c /* +0*/).reduceAlongDimension(reduce::Sum, +// *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // } // if(b && x->rankOf() == 1) { @@ -1610,36 +1948,45 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // } // else if(b) { // // dcIdbi -// dcdzi.reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // [bS, nOut]->reduce->[nOut] +// dcdzi.reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // +// [bS, nOut]->reduce->[nOut] // // dcIdbf -// dcdzf.reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // [bS, nOut]->reduce->[nOut] +// dcdzf.reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // +// [bS, nOut]->reduce->[nOut] // // dcIdbg -// dcdzg.reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // [bS, nOut]->reduce->[nOut] +// dcdzg.reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // +// [bS, nOut]->reduce->[nOut] // // dcIdbo // dcIdbo->nullify(); // [nOut] // //dhIdbi -// (*dLdh*dhdc*dcdzi).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdh*dhdc*dcdzi).reduceAlongDimension(reduce::Sum, *dhIdbi, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbf -// (*dLdh*dhdc*dcdzf).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdh*dhdc*dcdzf).reduceAlongDimension(reduce::Sum, *dhIdbf, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbg -// (*dLdh*dhdc*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdh*dhdc*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbo -// (*dLdh*dhdzo).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (*dLdh*dhdzo).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // +// ([bS, nOut] * [nOut])->reduce->[nOut] // } // } // else { -// NDArray tempIFE = mmul(dcdzi, WriT) + mmul(dcdzf, WrfT) + mmul(dcdzg, WrgT); -// NDArray tempO = mmul(dhdzo, WroT); +// NDArray tempIFE = mmul(dcdzi, WriT) + mmul(dcdzf, WrfT) + mmul(dcdzg, +// WrgT); NDArray tempO = mmul(dhdzo, WroT); // // dLdcI // if(dLdcI) // dLdcI->assign(factor*tempC + (*dLdhII)*(*dhIdcI)); // // dcIdWxi(dcdWxi) -// dcIdWxi.assign(dcdzi*dWxJacobian + tempIFE*dhIdWxi + tempC*dcIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// dcIdWxi.assign(dcdzi*dWxJacobian + tempIFE*dhIdWxi + tempC*dcIdWxi); +// // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, +// nOut, nOut]); // // dcIdWxf(dcdWxf) // dcIdWxf.assign(dcdzf*dWxJacobian + tempIFE*dhIdWxf + tempC*dcIdWxf); // // dcIdWxg(dcdWxg) @@ -1648,7 +1995,8 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dcIdWxo.assign(/* 0 + */tempIFE * dhIdWxo + tempC*dcIdWxo); // // dhIdWxi -// dhIdWxi.assign(dhdc*dcIdWxi + tempO*dhIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// dhIdWxi.assign(dhdc*dcIdWxi + tempO*dhIdWxi); // broadcast [bS, nOut] +// * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); // // dhIdWxf // dhIdWxf.assign(dhdc*dcIdWxf + tempO*dhIdWxf); // // dhIdWxg @@ -1657,7 +2005,9 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dhIdWxo.assign(dhdzo*dWxJacobian + dhdc*dcIdWxo + tempO*dhIdWxo); // // dcIdWri(dcdWri) -// dcIdWri.assign(dcdzi*dWrJacobian + tempIFE*dhIdWri + tempC*dcIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// dcIdWri.assign(dcdzi*dWrJacobian + tempIFE*dhIdWri + tempC*dcIdWri); +// // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, +// nOut, nOut]); // // dcIdWrf(dcdWrf) // dcIdWrf.assign(dcdzf*dWrJacobian + tempIFE*dhIdWrf + tempC*dcIdWrf); // // dcIdWrg(dcdWrg) @@ -1666,7 +2016,8 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // dcIdWro.assign(/* 0 + */tempIFE * dhIdWro + tempC*dcIdWro); // // dhIdWri -// dhIdWri.assign(dhdc*dcIdWri + tempO*dhIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// dhIdWri.assign(dhdc*dcIdWri + tempO*dhIdWri); // broadcast [bS, nOut] +// * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); // // dhIdWrf // dhIdWrf.assign(dhdc*dcIdWrf + tempO*dhIdWrf); // // dhIdWrg @@ -1676,99 +2027,150 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // if(Wp && x->rankOf() == 1) { // // dcIdWpi -// dcIdWpi->assign(dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi)); // [nOut] * [nOut] +// dcIdWpi->assign(dcdzi*(*cI) + tempIFE*(*dhIdWpi) + +// tempC*(*dcIdWpi)); // [nOut] * [nOut] // // dcIdWpf -// dcIdWpf->assign(dcdzf*(*cI) + tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf)); // [nOut] * [nOut] +// dcIdWpf->assign(dcdzf*(*cI) + tempIFE*(*dhIdWpf) + +// tempC*(*dcIdWpf)); // [nOut] * [nOut] // // dcIdWpo -// dcIdWpo->assign(/* 0 + */ tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo)); // [nOut] * [nOut] +// dcIdWpo->assign(/* 0 + */ tempIFE*(*dhIdWpo) + +// tempC*(*dcIdWpo)); // [nOut] * [nOut] // // dhdWpi -// dhIdWpi->assign(dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)); // [nOut] * [nOut] +// dhIdWpi->assign(dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)); // [nOut] * +// [nOut] // // dhdWpf -// dhIdWpf->assign(dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)); // [nOut] * [nOut] +// dhIdWpf->assign(dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)); // [nOut] * +// [nOut] // // dhdWpo -// dhIdWpo->assign(dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)); // [nOut] * [nOut] +// dhIdWpo->assign(dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)); // +// [nOut] * [nOut] // } // else if(Wp) { // // dcIdWpi -// (dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dcdzi*(*cI) + tempIFE*(*dhIdWpi) + +// tempC*(*dcIdWpi)).reduceAlongDimension(reduce::Sum, *dcIdWpi, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dcIdWpf -// (dcdzf*(*cI) + tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dcdzf*(*cI) + tempIFE*(*dhIdWpf) + +// tempC*(*dcIdWpf)).reduceAlongDimension(reduce::Sum, *dcIdWpf, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dcIdWpo -// (/* 0 + */ tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo)).reduceAlongDimension(reduce::Sum, *dcIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (/* 0 + */ tempIFE*(*dhIdWpo) + +// tempC*(*dcIdWpo)).reduceAlongDimension(reduce::Sum, *dcIdWpo, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpi -// (dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dhdc*(*dcIdWpi) + +// tempO*(*dhIdWpi)).reduceAlongDimension(reduce::Sum, *dhIdWpi, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpf -// (dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dhdc*(*dcIdWpf) + +// tempO*(*dhIdWpf)).reduceAlongDimension(reduce::Sum, *dhIdWpf, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpo -// (dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dhdzo*c + dhdc*(*dcIdWpo) + +// tempO*(*dhIdWpo)).reduceAlongDimension(reduce::Sum, *dhIdWpo, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // } // if(b && x->rankOf() == 1) { // // dcIdbi -// dcIdbi->assign(dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)); // [nOut] +// dcIdbi->assign(dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)); // +// [nOut] // // dcIdbf -// dcIdbf->assign(dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)); // [nOut] +// dcIdbf->assign(dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)); // +// [nOut] // // dcIdbg -// dcIdbg->assign(dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)); // [nOut] +// dcIdbg->assign(dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)); // +// [nOut] // // dcIdbo -// dcIdbo->assign(/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)); // [nOut] +// dcIdbo->assign(/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)); // +// [nOut] // //dhIdbi -// dhIdbi->assign(dhdc*(*dcIdbi) + tempO*(*dhIdbi)); // [nOut] +// dhIdbi->assign(dhdc*(*dcIdbi) + tempO*(*dhIdbi)); // +// [nOut] // //dhIdbf -// dhIdbf->assign(dhdc*(*dcIdbf) + tempO*(*dhIdbf)); // [nOut] +// dhIdbf->assign(dhdc*(*dcIdbf) + tempO*(*dhIdbf)); // +// [nOut] // //dhIdbg -// dhIdbg->assign(dhdc*(*dcIdbg) + tempO*(*dhIdbg)); // [nOut] +// dhIdbg->assign(dhdc*(*dcIdbg) + tempO*(*dhIdbg)); // +// [nOut] // //dhIdbo -// dhIdbo->assign(dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)); // [nOut] +// dhIdbo->assign(dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)); // +// [nOut] // } // else if(b) { // // dcIdbi -// (dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)).reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // [bS, nOut]->reduce->[nOut] +// (dcdzi + tempIFE*(*dhIdbi) + +// tempC*(*dcIdbi)).reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); +// // [bS, nOut]->reduce->[nOut] // // dcIdbf -// (dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)).reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // [bS, nOut]->reduce->[nOut] +// (dcdzf + tempIFE*(*dhIdbf) + +// tempC*(*dcIdbf)).reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); +// // [bS, nOut]->reduce->[nOut] // // dcIdbg -// (dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // [bS, nOut]->reduce->[nOut] +// (dcdzg + tempIFE*(*dhIdbg) + +// tempC*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); +// // [bS, nOut]->reduce->[nOut] // // dcIdbo -// (/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)).reduceAlongDimension(reduce::Sum, *dcIdbo, {0}); // [bS, nOut]->reduce->[nOut] +// (/*0+*/ tempIFE*(*dhIdbo) + +// tempC*(*dcIdbo)).reduceAlongDimension(reduce::Sum, *dcIdbo, {0}); +// // [bS, nOut]->reduce->[nOut] // //dhIdbi -// (dhdc*(*dcIdbi) + tempO*(*dhIdbi)).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dhdc*(*dcIdbi) + +// tempO*(*dhIdbi)).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); +// // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbf -// (dhdc*(*dcIdbf) + tempO*(*dhIdbf)).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dhdc*(*dcIdbf) + +// tempO*(*dhIdbf)).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); +// // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbg -// (dhdc*(*dcIdbg) + tempO*(*dhIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dhdc*(*dcIdbg) + +// tempO*(*dhIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); +// // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbo -// (dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// (dhdzo + dhdc*(*dcIdbo) + +// tempO*(*dhIdbo)).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); +// // ([bS, nOut] * [nOut])->reduce->[nOut] // } // } -// const std::vector dimsToExclude = x->rankOf() == 1 ? std::vector({2}) : std::vector({2, 3}); +// const std::vector dimsToExclude = x->rankOf() == 1 ? +// std::vector({2}) : std::vector({2, 3}); // // dLdWxi, dLdWxf, dLdWxg, dLdWxo -// (*dLdh*(*dhIdWx)).reduceAlongDimension(reduce::Sum, *dLdWx, dimsToExclude); +// (*dLdh*(*dhIdWx)).reduceAlongDimension(reduce::Sum, *dLdWx, +// dimsToExclude); // // dLdWri, dLdWrf, dLdWrg, dLdWro -// (*dLdh*(*dhIdWr)).reduceAlongDimension(reduce::Sum, *dLdWr, dimsToExclude); +// (*dLdh*(*dhIdWr)).reduceAlongDimension(reduce::Sum, *dLdWr, +// dimsToExclude); // // dLdWpi, dLdWpf, dLdWpo // if(Wp) { // if(x->rankOf() == 1) { -// (*dLdWp)({0, nOut}).assign(*dLdh*(*dhIdWpi)); // [nOut] * [nOut] -// (*dLdWp)({nOut, 2*nOut}).assign(*dLdh*(*dhIdWpf)); // [nOut] * [nOut] -// (*dLdWp)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdWpo)); // [nOut] * [nOut] +// (*dLdWp)({0, nOut}).assign(*dLdh*(*dhIdWpi)); // [nOut] +// * [nOut] +// (*dLdWp)({nOut, 2*nOut}).assign(*dLdh*(*dhIdWpf)); // [nOut] +// * [nOut] +// (*dLdWp)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdWpo)); // [nOut] +// * [nOut] // } // else { // // NDArray temp1 = (*dLdWp)({0, nOut}); // // NDArray temp2 = (*dLdWp)({nOut, 2*nOut}); // // NDArray temp3 = (*dLdWp)({2*nOut, 3*nOut}); -// // dhIdWpi->reduceAlongDimension(reduce::Sum, temp1, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] -// // dhIdWpf->reduceAlongDimension(reduce::Sum, temp2, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] -// // dhIdWpo->reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // dhIdWpi->reduceAlongDimension(reduce::Sum, temp1, {0}); // +// ([bS, nOut] * [nOut])->reduce->[nOut] +// // dhIdWpf->reduceAlongDimension(reduce::Sum, temp2, {0}); // +// ([bS, nOut] * [nOut])->reduce->[nOut] +// // dhIdWpo->reduceAlongDimension(reduce::Sum, temp3, {0}); // +// ([bS, nOut] * [nOut])->reduce->[nOut] // (*dLdWp)({0, nOut}).assign(dhIdWpi); // (*dLdWp)({nOut, 2*nOut}).assign(dhIdWpf); // (*dLdWp)({2*nOut, 3*nOut}).assign(dhIdWpo); @@ -1778,20 +2180,28 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // // dLdbi, dLdbf, dLdbg, dLdbo // if(b) { // if(x->rankOf() == 1) { -// (*dLdb)({0, nOut}).assign(*dLdh*(*dhIdbi)); // [nOut] * [nOut] -// (*dLdb)({nOut, 2*nOut}).assign(*dLdh*(*dhIdbf)); // [nOut] * [nOut] -// (*dLdb)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdbg)); // [nOut] * [nOut] -// (*dLdb)({3*nOut, 4*nOut}).assign(*dLdh*(*dhIdbo)); // [nOut] * [nOut] +// (*dLdb)({0, nOut}).assign(*dLdh*(*dhIdbi)); // [nOut] * +// [nOut] +// (*dLdb)({nOut, 2*nOut}).assign(*dLdh*(*dhIdbf)); // [nOut] * +// [nOut] +// (*dLdb)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdbg)); // [nOut] * +// [nOut] +// (*dLdb)({3*nOut, 4*nOut}).assign(*dLdh*(*dhIdbo)); // [nOut] * +// [nOut] // } // else { // // NDArray temp1 = (*dLdb)({0, nOut}); // // NDArray temp2 = (*dLdb)({nOut, 2*nOut}); // // NDArray temp3 = (*dLdb)({2*nOut, 3*nOut}); // // NDArray temp4 = (*dLdb)({3*nOut, 4*nOut}); -// // (*dLdh*(*dhIdbi)).reduceAlongDimension(reduce::Sum, temp1, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] -// // (*dLdh*(*dhIdbf)).reduceAlongDimension(reduce::Sum, temp2, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] -// // (*dLdh*(*dhIdbg)).reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] -// // (*dLdh*(*dhIdbo)).reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // (*dLdh*(*dhIdbi)).reduceAlongDimension(reduce::Sum, temp1, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // (*dLdh*(*dhIdbf)).reduceAlongDimension(reduce::Sum, temp2, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // (*dLdh*(*dhIdbg)).reduceAlongDimension(reduce::Sum, temp3, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // (*dLdh*(*dhIdbo)).reduceAlongDimension(reduce::Sum, temp3, +// {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // (*dLdb)({0, nOut}).assign(dhIdbi); // (*dLdb)({nOut, 2*nOut}).assign(dhIdbf); // (*dLdb)({2*nOut, 3*nOut}).assign(dhIdbg); @@ -1808,9 +2218,11 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // if(dLdcI && !dLdhI) // delete dLdhII; // if(Wp) { -// delete Wpi; delete Wpf; delete Wpo; delete dcIdWpi; delete dcIdWpf; delete dcIdWpo; delete dhIdWpi; delete dhIdWpf; delete dhIdWpo; +// delete Wpi; delete Wpf; delete Wpo; delete dcIdWpi; delete dcIdWpf; +// delete dcIdWpo; delete dhIdWpi; delete dhIdWpf; delete dhIdWpo; // } // if(b) { -// delete dcIdbi; delete dcIdbf; delete dcIdbg; delete dcIdbo; delete dhIdbi; delete dhIdbf; delete dhIdbg; delete dhIdbo; +// delete dcIdbi; delete dcIdbf; delete dcIdbg; delete dcIdbo; delete +// dhIdbi; delete dhIdbf; delete dhIdbg; delete dhIdbo; // } // } diff --git a/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp b/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp index ef04b9a4e240..710199b1ab3f 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/multiUnique.cpp @@ -18,50 +18,53 @@ // @author sgazeos@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - bool multiUnique(std::vector const& inputList, sd::memory::Workspace *workspace) { - Nd4jLong length = 0; - std::vector reshaped(inputList.size()); - int pos = 0; - Nd4jLong axis = 0; - Context cContext(1); - for (auto array: inputList) { - if (array->dataType() != sd::DataType::INT32) - throw std::runtime_error("multiUnique: this op support INT32 data type only."); - - reshaped[pos] = array->reshape(array->ordering(), {-1}); - cContext.setInputArray(pos, reshaped[pos]); +bool multiUnique(std::vector const& inputList, + sd::memory::Workspace* workspace) { + Nd4jLong length = 0; + std::vector reshaped(inputList.size()); + int pos = 0; + Nd4jLong axis = 0; + Context cContext(1); + for (auto array : inputList) { + if (array->dataType() != sd::DataType::INT32) + throw std::runtime_error( + "multiUnique: this op support INT32 data type only."); - length += array->lengthOf(); - pos++; - } - NDArray arrayFull('c', {length}, sd::DataType::INT32, inputList[0]->getContext()); - cContext.setOutputArray(0, arrayFull); - cContext.setIArguments(&axis, 1); + reshaped[pos] = array->reshape(array->ordering(), {-1}); + cContext.setInputArray(pos, reshaped[pos]); - sd::ops::concat opConcat; - auto cResult = opConcat.execute(&cContext); - if (Status::OK() != cResult) - throw std::runtime_error("multiUnique: cannot execute concat op properly."); + length += array->lengthOf(); + pos++; + } + NDArray arrayFull('c', {length}, sd::DataType::INT32, + inputList[0]->getContext()); + cContext.setOutputArray(0, arrayFull); + cContext.setIArguments(&axis, 1); - sd::ops::unique opUnique; - auto uResult = opUnique.evaluate({&arrayFull}); - if (Status::OK() != uResult.status()) - throw std::runtime_error("multiUnique: cannot execute unique op properly."); + sd::ops::concat opConcat; + auto cResult = opConcat.execute(&cContext); + if (Status::OK() != cResult) + throw std::runtime_error("multiUnique: cannot execute concat op properly."); - auto uniqueVals = uResult.at(0); + sd::ops::unique opUnique; + auto uResult = opUnique.evaluate({&arrayFull}); + if (Status::OK() != uResult.status()) + throw std::runtime_error("multiUnique: cannot execute unique op properly."); - bool res = uniqueVals.lengthOf() == arrayFull.lengthOf(); + auto uniqueVals = uResult.at(0); - return res; - } + bool res = uniqueVals.lengthOf() == arrayFull.lengthOf(); + return res; } -} -} + +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp b/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp index f910f07ed101..60bdcc20fc13 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/rnn.cpp @@ -18,80 +18,84 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 16.04.2018 // -// function nnCell implements an Elman RNN cell: output = activation(Wx*x + bx + Wh*ht + bh) +// function nnCell implements an Elman RNN cell: output = activation(Wx*x + bx +// + Wh*ht + bh) -#include #include +#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// -void rnnCell(sd::LaunchContext * context, const NDArray* xt, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* hPrev, NDArray* ht) { - - // xt input [bS x iS] - // Wx input-to-hidden weights, [iS x nU] - // Wh hidden-to-hidden weights, [nU x nU] - // b biases, [2*nU]: {0, nU} are input-to-hidden biases and {nU, 2*nU} are hidden-to-hidden biases - // hPrev previous cell output [bS x nU], that is at previous time step t-1, in case of projection=false -> nU=nU!!! - - const int nU = hPrev->sizeAt(1); - - // ht is current cell output [bS x nU], that is at current time step t - ht->assign(mmul(*xt, *Wx) + (*b)({{0, nU}}) + mmul(*hPrev, *Wh) + (*b)({{nU, 2*nU}})); // [bS x nU] + [nU] + [bS x nU] + [nU] = [bS x nU] - ht->applyTransform(transform::Tanh, *ht); +void rnnCell(sd::LaunchContext* context, const NDArray* xt, const NDArray* Wx, + const NDArray* Wh, const NDArray* b, const NDArray* hPrev, + NDArray* ht) { + // xt input [bS x iS] + // Wx input-to-hidden weights, [iS x nU] + // Wh hidden-to-hidden weights, [nU x nU] + // b biases, [2*nU]: {0, nU} are input-to-hidden biases and {nU, 2*nU} are + // hidden-to-hidden biases hPrev previous cell output [bS x nU], that is at + // previous time step t-1, in case of projection=false -> nU=nU!!! + + const int nU = hPrev->sizeAt(1); + + // ht is current cell output [bS x nU], that is at current time step t + ht->assign( + mmul(*xt, *Wx) + (*b)({{0, nU}}) + mmul(*hPrev, *Wh) + + (*b)({{nU, + 2 * nU}})); // [bS x nU] + [nU] + [bS x nU] + [nU] = [bS x nU] + ht->applyTransform(transform::Tanh, *ht); } ////////////////////////////////////////////////////////////////////////// -void rnnTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* h0, const NDArray* maxTimeStep, NDArray* h, NDArray* hFinal) { - - // x input [time x bS x iS] - // Wx input-to-hidden weights, [iS x nU] - // Wh hidden-to-hidden weights, [nU x nU] - // b biases for, [2*nU] - - // h0 initial cell output (at time step = 0) [bS x nU] - // maxTimeStep vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep - - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - - // at first time step - if(h0) - hFinal->assign(h0); - else - *hFinal = 0.; - - BlasHelper::getInstance(); // to avoid memory leak in pragma parallel loops - // loop through batch of inputs - for (int e = 0; e < bS; ++e) { - // loop through time steps - for (int t = 0; t < time; ++t) { - - int maxStep = maxTimeStep ? maxTimeStep->e(e) : time; - - auto xt = (*x)({t,t+1, e,e+1, 0,0}, true); - auto ht = (*h)({t,t+1, e,e+1, 0,0}, true); - auto hPrev = (*hFinal)({e,e+1, 0,0}, true); // previous state - - if(t >= maxStep) { - ht = 0.; - if(maxStep != 0) - hPrev.assign((*h)({maxStep-1,maxStep, e,e+1, 0,0})); - } - else { - helpers::rnnCell(context, &xt, Wx, Wh, b, &hPrev, &ht); - hPrev.assign(ht); - } - } +void rnnTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* Wx, const NDArray* Wh, const NDArray* b, + const NDArray* h0, const NDArray* maxTimeStep, NDArray* h, + NDArray* hFinal) { + // x input [time x bS x iS] + // Wx input-to-hidden weights, [iS x nU] + // Wh hidden-to-hidden weights, [nU x nU] + // b biases for, [2*nU] + + // h0 initial cell output (at time step = 0) [bS x nU] + // maxTimeStep vector [bS] containing integer values within [0,time), each + // element of this vector set max time step per each input in batch, this + // means there are no calculations for time >= maxTimeStep + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + + // at first time step + if (h0) + hFinal->assign(h0); + else + *hFinal = 0.; + + BlasHelper::getInstance(); // to avoid memory leak in pragma parallel loops + // loop through batch of inputs + for (int e = 0; e < bS; ++e) { + // loop through time steps + for (int t = 0; t < time; ++t) { + int maxStep = maxTimeStep ? maxTimeStep->e(e) : time; + + auto xt = (*x)({t, t + 1, e, e + 1, 0, 0}, true); + auto ht = (*h)({t, t + 1, e, e + 1, 0, 0}, true); + auto hPrev = (*hFinal)({e, e + 1, 0, 0}, true); // previous state + + if (t >= maxStep) { + ht = 0.; + if (maxStep != 0) + hPrev.assign((*h)({maxStep - 1, maxStep, e, e + 1, 0, 0})); + } else { + helpers::rnnCell(context, &xt, Wx, Wh, b, &hPrev, &ht); + hPrev.assign(ht); + } } + } } - -} -} -} - +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/impl/sparse_to_dense.cpp b/libnd4j/include/ops/declarable/helpers/impl/sparse_to_dense.cpp index bbcb1eca3765..103ca62e40ab 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/sparse_to_dense.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/sparse_to_dense.cpp @@ -18,106 +18,111 @@ // @author raver119@gmail.com // - -#include -#include #include +#include +#include namespace sd { - namespace ops { - namespace helpers { - template - static void fill_(const void *vvalues, const void *vindices, void *voutput, const Nd4jLong *zShapeInfo, uint8_t rank, uint64_t length) { - auto values = reinterpret_cast(vvalues); - auto indices = reinterpret_cast(vindices); - auto output = reinterpret_cast(voutput); - - int coords[MAX_RANK]; - uint64_t pos = 0; - for (uint64_t e = 0L; e < length; e++) { - // indices come in blocks - for (uint8_t p = 0; p < rank; p++) { - coords[p] = indices[pos++]; - } - - // fill output at given coords with sparse value - output[shape::getOffset(zShapeInfo, coords)] = values[e]; - } - - } - - void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output) { - // make sure host buffer is updated - values.syncToHost(); - indices.syncToHost(); - - auto rank = output.rankOf(); - - if (output.isS()) { - // string case is not so trivial, since elements might, and probably will, have different sizes - auto numValues = values.lengthOf(); - auto numElements = output.lengthOf(); - - // first of all we calculate final buffer sizes and offsets - auto defaultLength = def == nullptr ? 0 : StringUtils::byteLength(*def); - auto valuesLength = StringUtils::byteLength(values); - auto bufferLength = defaultLength * (output.lengthOf() - numValues) + valuesLength; - auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numElements); - - // now we make sure our output buffer can hold results - output.dataBuffer()->expand( bufferLength + headerLength); - - std::vector outputCoords(rank); - std::vector valueCoords(rank); - - auto offsetsBuffer = output.bufferAsT(); - auto dataBuffer = reinterpret_cast(offsetsBuffer + output.lengthOf()); - - offsetsBuffer[0] = 0; - - // getting initial value coords - for (int e = 0; e < rank; e++) - valueCoords[e] = indices.e(e); - - // write results individually - for (Nd4jLong e = 0; e < numElements; e++) { - auto vIndex = shape::coords2index(output.shapeInfo(), valueCoords.data()); - auto cLength = 0L; - std::string str; - if (vIndex == e) { - // we're writing down sparse value here - str = values.e(e); - } else { - // we're writing down default value if it exists - if (def != nullptr) - str = def->e(0); - else - str = ""; - } - - // TODO: make it unicode compliant - memcpy(&dataBuffer[offsetsBuffer[e]], str.c_str(), str.length()); - - // writing down offset - offsetsBuffer[e+1] = cLength; - } - } else { - // numeric case is trivial, since all elements have equal sizes - - // write out default values, if they are present - if (def != nullptr) { - output.assign(def); - - // make sure output is synced back - output.syncToHost(); - } - - // write out values - BUILD_DOUBLE_SELECTOR(values.dataType(), indices.dataType(), fill_, (values.buffer(), indices.buffer(), output.buffer(), output.shapeInfo(), rank, values.lengthOf()), LIBND4J_TYPES, INDEXING_TYPES); - } - // copy back to device, if there's any - output.syncToDevice(); - } - } +namespace ops { +namespace helpers { +template +static void fill_(const void *vvalues, const void *vindices, void *voutput, + const Nd4jLong *zShapeInfo, uint8_t rank, uint64_t length) { + auto values = reinterpret_cast(vvalues); + auto indices = reinterpret_cast(vindices); + auto output = reinterpret_cast(voutput); + + int coords[MAX_RANK]; + uint64_t pos = 0; + for (uint64_t e = 0L; e < length; e++) { + // indices come in blocks + for (uint8_t p = 0; p < rank; p++) { + coords[p] = indices[pos++]; } -} \ No newline at end of file + + // fill output at given coords with sparse value + output[shape::getOffset(zShapeInfo, coords)] = values[e]; + } +} + +void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, + NDArray *def, NDArray &output) { + // make sure host buffer is updated + values.syncToHost(); + indices.syncToHost(); + + auto rank = output.rankOf(); + + if (output.isS()) { + // string case is not so trivial, since elements might, and probably will, + // have different sizes + auto numValues = values.lengthOf(); + auto numElements = output.lengthOf(); + + // first of all we calculate final buffer sizes and offsets + auto defaultLength = def == nullptr ? 0 : StringUtils::byteLength(*def); + auto valuesLength = StringUtils::byteLength(values); + auto bufferLength = + defaultLength * (output.lengthOf() - numValues) + valuesLength; + auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numElements); + + // now we make sure our output buffer can hold results + output.dataBuffer()->expand(bufferLength + headerLength); + + std::vector outputCoords(rank); + std::vector valueCoords(rank); + + auto offsetsBuffer = output.bufferAsT(); + auto dataBuffer = + reinterpret_cast(offsetsBuffer + output.lengthOf()); + + offsetsBuffer[0] = 0; + + // getting initial value coords + for (int e = 0; e < rank; e++) valueCoords[e] = indices.e(e); + + // write results individually + for (Nd4jLong e = 0; e < numElements; e++) { + auto vIndex = shape::coords2index(output.shapeInfo(), valueCoords.data()); + auto cLength = 0L; + std::string str; + if (vIndex == e) { + // we're writing down sparse value here + str = values.e(e); + } else { + // we're writing down default value if it exists + if (def != nullptr) + str = def->e(0); + else + str = ""; + } + + // TODO: make it unicode compliant + memcpy(&dataBuffer[offsetsBuffer[e]], str.c_str(), str.length()); + + // writing down offset + offsetsBuffer[e + 1] = cLength; + } + } else { + // numeric case is trivial, since all elements have equal sizes + + // write out default values, if they are present + if (def != nullptr) { + output.assign(def); + + // make sure output is synced back + output.syncToHost(); + } + + // write out values + BUILD_DOUBLE_SELECTOR(values.dataType(), indices.dataType(), fill_, + (values.buffer(), indices.buffer(), output.buffer(), + output.shapeInfo(), rank, values.lengthOf()), + LIBND4J_TYPES, INDEXING_TYPES); + } + // copy back to device, if there's any + output.syncToDevice(); +} +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/impl/unique.cpp b/libnd4j/include/ops/declarable/helpers/impl/unique.cpp index c67b713c282d..03d9dbc3e831 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/unique.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/unique.cpp @@ -18,93 +18,98 @@ // @author sgazeos@gmail.com // -#include -#include #include +#include #include +#include namespace sd { namespace ops { namespace helpers { - template - static Nd4jLong uniqueCount_(NDArray* input) { - Nd4jLong count = 0; +template +static Nd4jLong uniqueCount_(NDArray* input) { + Nd4jLong count = 0; - std::vector values; + std::vector values; - for (Nd4jLong e = 0; e < input->lengthOf(); e++) { - T v = input->e(e); - if (std::find(values.begin(), values.end(), v) == values.end()) { - values.push_back(v); - count++; - } - } - return count; + for (Nd4jLong e = 0; e < input->lengthOf(); e++) { + T v = input->e(e); + if (std::find(values.begin(), values.end(), v) == values.end()) { + values.push_back(v); + count++; } + } + return count; +} + +Nd4jLong uniqueCount(sd::LaunchContext* context, NDArray* input) { + BUILD_SINGLE_SELECTOR(input->dataType(), return uniqueCount_, (input), + LIBND4J_TYPES); +} - Nd4jLong uniqueCount(sd::LaunchContext * context, NDArray* input) { - BUILD_SINGLE_SELECTOR(input->dataType(), return uniqueCount_, (input), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template Nd4jLong uniqueCount_, (NDArray * input), + LIBND4J_TYPES); + +template +static Nd4jStatus uniqueFunctor_(NDArray* input, NDArray* values, + NDArray* indices, NDArray* counts) { + std::vector valuesVector; + MAP_IMPL indicesMap; + MAP_IMPL countsMap; + + for (Nd4jLong e = 0; e < input->lengthOf(); e++) { + T v = input->e(e); + if (std::find(valuesVector.begin(), valuesVector.end(), v) == + valuesVector.end()) { + valuesVector.push_back(v); + indicesMap[v] = e; + countsMap[v] = 1; + } else { + countsMap[v]++; } + } - BUILD_SINGLE_TEMPLATE(template Nd4jLong uniqueCount_, (NDArray* input), LIBND4J_TYPES); - - template - static Nd4jStatus uniqueFunctor_(NDArray* input, NDArray* values, NDArray* indices, NDArray* counts) { - - std::vector valuesVector; - MAP_IMPL indicesMap; - MAP_IMPL countsMap; - - for (Nd4jLong e = 0; e < input->lengthOf(); e++) { - T v = input->e(e); - if (std::find(valuesVector.begin(), valuesVector.end(), v) == valuesVector.end()) { - valuesVector.push_back(v); - indicesMap[v] = e; - countsMap[v] = 1; - } - else { - countsMap[v]++; - } - } - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - values->p(e, static_cast(valuesVector[e])); - if (counts != nullptr) - counts->p(e, countsMap[valuesVector[e]]); - } - }; - samediff::Threads::parallel_for(func, 0, values->lengthOf()); - - for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { - auto posI = std::find(valuesVector.begin(), valuesVector.end(), input->e(e)); - auto dist = std::distance(valuesVector.begin(), posI); - indices->p(e, Nd4jLong(dist));//indicesMap[(*input)(e)]; - } - - return Status::OK(); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + values->p(e, static_cast(valuesVector[e])); + if (counts != nullptr) counts->p(e, countsMap[valuesVector[e]]); } + }; + samediff::Threads::parallel_for(func, 0, values->lengthOf()); - Nd4jStatus uniqueFunctor(sd::LaunchContext * context, NDArray* input, NDArray* values, NDArray* indices, NDArray* counts) { - input->syncToHost(); - values->syncToHost(); - indices->syncToHost(); + for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { + auto posI = + std::find(valuesVector.begin(), valuesVector.end(), input->e(e)); + auto dist = std::distance(valuesVector.begin(), posI); + indices->p(e, Nd4jLong(dist)); // indicesMap[(*input)(e)]; + } - if (counts != nullptr) - counts->syncToHost(); + return Status::OK(); +} - BUILD_SINGLE_SELECTOR(input->dataType(), return uniqueFunctor_,(input, values, indices, counts), LIBND4J_TYPES); +Nd4jStatus uniqueFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* values, NDArray* indices, NDArray* counts) { + input->syncToHost(); + values->syncToHost(); + indices->syncToHost(); - input->syncToDevice(); - values->syncToDevice(); - indices->syncToDevice(); + if (counts != nullptr) counts->syncToHost(); - if (counts != nullptr) - counts->syncToDevice(); - } + BUILD_SINGLE_SELECTOR(input->dataType(), return uniqueFunctor_, + (input, values, indices, counts), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template Nd4jStatus uniqueFunctor_, (NDArray* input, NDArray* values, NDArray* indices, NDArray* counts), LIBND4J_TYPES); -} + input->syncToDevice(); + values->syncToDevice(); + indices->syncToDevice(); + + if (counts != nullptr) counts->syncToDevice(); } -} \ No newline at end of file + +BUILD_SINGLE_TEMPLATE(template Nd4jStatus uniqueFunctor_, + (NDArray * input, NDArray* values, NDArray* indices, + NDArray* counts), + LIBND4J_TYPES); +} // namespace helpers +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/impl/where.cpp b/libnd4j/include/ops/declarable/helpers/impl/where.cpp index 9485f2e2d299..ba360f919dc3 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/where.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/where.cpp @@ -18,44 +18,49 @@ // Created by raver119 on 24/09/18. // -#include #include +#include namespace sd { - namespace ops { - namespace helpers { - template - static void __where(NDArray &condition, NDArray& output, memory::Workspace *workspace) { - NDArrayList list(0, true); - int cnt = 0; - - int idx[MAX_RANK]; +namespace ops { +namespace helpers { +template +static void __where(NDArray &condition, NDArray &output, + memory::Workspace *workspace) { + NDArrayList list(0, true); + int cnt = 0; - for (Nd4jLong e = 0; e < condition.lengthOf(); e++) { + int idx[MAX_RANK]; - shape::index2coordsCPU(0, e, condition.shapeInfo(), idx); + for (Nd4jLong e = 0; e < condition.lengthOf(); e++) { + shape::index2coordsCPU(0, e, condition.shapeInfo(), idx); - auto offset = shape::getOffset(condition.shapeInfo(), idx); + auto offset = shape::getOffset(condition.shapeInfo(), idx); - if (condition.e(offset)) { - auto array = NDArrayFactory::create('c', {1, condition.rankOf()}, output.dataType(), output.getContext()); - for (int f = 0; f < condition.rankOf(); f++) - array.p(f, (T) idx[f]); + if (condition.e(offset)) { + auto array = NDArrayFactory::create( + 'c', {1, condition.rankOf()}, output.dataType(), output.getContext()); + for (int f = 0; f < condition.rankOf(); f++) array.p(f, (T)idx[f]); - list.write(cnt++, array); - } - } + list.write(cnt++, array); + } + } - auto s = list.stack(); - output.assign(s); - } - BUILD_SINGLE_TEMPLATE(template void __where,(NDArray &condition, NDArray& output, memory::Workspace *workspace), LIBND4J_TYPES); + auto s = list.stack(); + output.assign(s); +} +BUILD_SINGLE_TEMPLATE(template void __where, + (NDArray & condition, NDArray &output, + memory::Workspace *workspace), + LIBND4J_TYPES); - void _where(sd::LaunchContext * context, NDArray &condition, NDArray& output, memory::Workspace *workspace) { - condition.syncToHost(); - BUILD_SINGLE_SELECTOR(output.dataType(), __where, (condition, output, workspace), LIBND4J_TYPES); - output.syncToDevice(); - } - } - } +void _where(sd::LaunchContext *context, NDArray &condition, NDArray &output, + memory::Workspace *workspace) { + condition.syncToHost(); + BUILD_SINGLE_SELECTOR(output.dataType(), __where, + (condition, output, workspace), LIBND4J_TYPES); + output.syncToDevice(); } +} // namespace helpers +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/ismax.h b/libnd4j/include/ops/declarable/helpers/ismax.h index 6052b362484b..32b322ea6aaf 100644 --- a/libnd4j/include/ops/declarable/helpers/ismax.h +++ b/libnd4j/include/ops/declarable/helpers/ismax.h @@ -24,15 +24,15 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - void ismax(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector& dimensions); +void ismax(sd::LaunchContext* context, const NDArray* input, NDArray* output, + const std::vector& dimensions); } -} -} - +} // namespace ops +} // namespace sd -#endif //LIBND4J_LSTM_H +#endif // LIBND4J_LSTM_H diff --git a/libnd4j/include/ops/declarable/helpers/knn.h b/libnd4j/include/ops/declarable/helpers/knn.h index 3a3494a121b3..695803e1fadc 100644 --- a/libnd4j/include/ops/declarable/helpers/knn.h +++ b/libnd4j/include/ops/declarable/helpers/knn.h @@ -24,11 +24,12 @@ #include namespace sd { - namespace ops { - namespace helpers { - void knn_mindistance(const NDArray &input, const NDArray &lowest, const NDArray &highest, NDArray &output); - } - } +namespace ops { +namespace helpers { +void knn_mindistance(const NDArray &input, const NDArray &lowest, + const NDArray &highest, NDArray &output); } +} // namespace ops +} // namespace sd -#endif //SAMEDIFF_KNN_H +#endif // SAMEDIFF_KNN_H diff --git a/libnd4j/include/ops/declarable/helpers/legacy_helpers.h b/libnd4j/include/ops/declarable/helpers/legacy_helpers.h index e3191425d936..84a9ff5d48b7 100644 --- a/libnd4j/include/ops/declarable/helpers/legacy_helpers.h +++ b/libnd4j/include/ops/declarable/helpers/legacy_helpers.h @@ -25,46 +25,77 @@ namespace sd { namespace ops { namespace helpers { /* - FORCEINLINE void reluDerivative(NDArray* theFirst, NDArray const* theSecond); - FORCEINLINE void reluDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void relu6Derivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void leakyReluDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void eluDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void seluDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void cubeDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void reduceNorm1(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void sxeLossWithLogits(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void tanhDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void hardTanhDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void rationalTanhDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void rectifiedTanhDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void softSignDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void softPlusDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void sigmoidDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - FORCEINLINE void hardSigmoidDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); + FORCEINLINE void reluDerivative(NDArray* theFirst, NDArray const* + theSecond); FORCEINLINE void reluDerivative(NDArray* theFirst, NDArray* + theSecond, NDArray* theOutput); FORCEINLINE void relu6Derivative(NDArray* + theFirst, NDArray* theSecond, NDArray* theOutput); FORCEINLINE void + leakyReluDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* + theOutput); FORCEINLINE void eluDerivative(NDArray* theFirst, NDArray* + theSecond, NDArray* theOutput); FORCEINLINE void seluDerivative(NDArray* + theFirst, NDArray* theSecond, NDArray* theOutput); FORCEINLINE void + cubeDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); + FORCEINLINE void reduceNorm1(NDArray* theFirst, NDArray* theSecond, NDArray* + theOutput); FORCEINLINE void sxeLossWithLogits(NDArray* theFirst, NDArray* + theSecond, NDArray* theOutput); FORCEINLINE void tanhDerivative(NDArray* + theFirst, NDArray* theSecond, NDArray* theOutput); FORCEINLINE void + hardTanhDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* + theOutput); FORCEINLINE void rationalTanhDerivative(NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); FORCEINLINE void + rectifiedTanhDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* + theOutput); FORCEINLINE void softSignDerivative(NDArray* theFirst, NDArray* + theSecond, NDArray* theOutput); FORCEINLINE void softPlusDerivative(NDArray* + theFirst, NDArray* theSecond, NDArray* theOutput); FORCEINLINE void + sigmoidDerivative(NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); + FORCEINLINE void hardSigmoidDerivative(NDArray* theFirst, NDArray* + theSecond, NDArray* theOutput); */ - void reluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond); - void reluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void relu6Derivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void leakyReluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha); - void eluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha); - void seluDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void cubeDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void reduceNorm1(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void sigmCrossEntropy(sd::LaunchContext * context, NDArray* logits, NDArray* lablels, NDArray* theOutput); - void sigmCrossEntropyGrad(sd::LaunchContext * context, NDArray* logits, NDArray* lablels, NDArray* theOutput); - void tanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void hardTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void rationalTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void rectifiedTanhDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void softSignDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void softPlusDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void sigmoidDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void hardSigmoidDerivative(sd::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void logSumExp(sd::LaunchContext * context, NDArray* input, NDArray* axis, NDArray* output); - void logSumExp(sd::LaunchContext * context, NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output); - void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output); -} -} -} +void reluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond); +void reluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void relu6Derivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void leakyReluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput, + const float alpha); +void eluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput, const float alpha); +void seluDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void cubeDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void reduceNorm1(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void sigmCrossEntropy(sd::LaunchContext* context, NDArray* logits, + NDArray* lablels, NDArray* theOutput); +void sigmCrossEntropyGrad(sd::LaunchContext* context, NDArray* logits, + NDArray* lablels, NDArray* theOutput); +void tanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void hardTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void rationalTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void rectifiedTanhDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void softSignDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void softPlusDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void sigmoidDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void hardSigmoidDerivative(sd::LaunchContext* context, NDArray* theFirst, + NDArray* theSecond, NDArray* theOutput); +void logSumExp(sd::LaunchContext* context, NDArray* input, NDArray* axis, + NDArray* output); +void logSumExp(sd::LaunchContext* context, NDArray* input, NDArray* subtrah, + NDArray* axis, NDArray* output); +void weightedCrossEntropyWithLogitsFunctor(sd::LaunchContext* context, + NDArray const* targets, + NDArray const* input, + NDArray const* weights, + NDArray* output); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/lgamma.h b/libnd4j/include/ops/declarable/helpers/lgamma.h index 184e33556777..507cad1587f3 100644 --- a/libnd4j/include/ops/declarable/helpers/lgamma.h +++ b/libnd4j/include/ops/declarable/helpers/lgamma.h @@ -23,18 +23,18 @@ #define __LIBND4J_L_GAMMA__H__ #include + #include "array/NDArray.h" namespace sd { namespace ops { namespace helpers { - // calculate the digamma function for each element for array - void lgamma(sd::LaunchContext* context, NDArray& x, NDArray& z); - -} -} -} +// calculate the digamma function for each element for array +void lgamma(sd::LaunchContext* context, NDArray& x, NDArray& z); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //__LIBND4J_L_GAMMA__H__ +#endif //__LIBND4J_L_GAMMA__H__ diff --git a/libnd4j/include/ops/declarable/helpers/listdiff.h b/libnd4j/include/ops/declarable/helpers/listdiff.h index 227eccac8f6a..4348fbbfd390 100644 --- a/libnd4j/include/ops/declarable/helpers/listdiff.h +++ b/libnd4j/include/ops/declarable/helpers/listdiff.h @@ -26,9 +26,11 @@ namespace sd { namespace ops { namespace helpers { - int listDiffFunctor(sd::LaunchContext * context, NDArray* values, NDArray* keep, NDArray* output1, NDArray* output2); - Nd4jLong listDiffCount(sd::LaunchContext * context, NDArray* values, NDArray* keep); -} -} -} +int listDiffFunctor(sd::LaunchContext* context, NDArray* values, NDArray* keep, + NDArray* output1, NDArray* output2); +Nd4jLong listDiffCount(sd::LaunchContext* context, NDArray* values, + NDArray* keep); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/lrn.h b/libnd4j/include/ops/declarable/helpers/lrn.h index f8c9089c7e53..24a06e30695b 100644 --- a/libnd4j/include/ops/declarable/helpers/lrn.h +++ b/libnd4j/include/ops/declarable/helpers/lrn.h @@ -19,19 +19,22 @@ // #ifndef __LRN_H_HELPERS__ #define __LRN_H_HELPERS__ -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - int lrnFunctor(sd::graph::Context& block, NDArray* input, NDArray* output, int depth, double bias, double alpha, double beta); +int lrnFunctor(sd::graph::Context& block, NDArray* input, NDArray* output, + int depth, double bias, double alpha, double beta); - void lrnBP(sd::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta); +void lrnBP(sd::graph::Context& block, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int depth, + const float bias, const float alpha, const float beta); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/lstm.h b/libnd4j/include/ops/declarable/helpers/lstm.h index 6eb5886f3f5f..eabd15aba331 100644 --- a/libnd4j/include/ops/declarable/helpers/lstm.h +++ b/libnd4j/include/ops/declarable/helpers/lstm.h @@ -23,57 +23,64 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - ////////////////////////////////////////////////////////////////////////// - static FORCEINLINE NDArray sigmoid(const NDArray& arr) { - return (const_cast(arr)).transform(transform::Sigmoid); - } - - static FORCEINLINE void sigmoidInplace(const NDArray& arr) { - (const_cast(arr)).applyTransform(transform::Sigmoid, const_cast(arr)); - } - ////////////////////////////////////////////////////////////////////////// - static FORCEINLINE NDArray tanh(const NDArray& arr) { - return (const_cast(arr)).transform(transform::Tanh); - } +static FORCEINLINE NDArray sigmoid(const NDArray& arr) { + return (const_cast(arr)).transform(transform::Sigmoid); +} - static FORCEINLINE void tanhInplace(const NDArray& arr) { - (const_cast(arr)).applyTransform(transform::Tanh, const_cast(arr)); - } +static FORCEINLINE void sigmoidInplace(const NDArray& arr) { + (const_cast(arr)) + .applyTransform(transform::Sigmoid, const_cast(arr)); +} ////////////////////////////////////////////////////////////////////////// - static NDArray timeSubset(const NDArray* arr, const int t, const int dataFormat){ - - if(dataFormat == 0) { // TNS: shape [timeLength, numExamples, inOutSize] - return (*arr)({t,t+1, 0,0, 0,0}); - } - else if(dataFormat == 1) { //NST: shape [numExamples, inOutSize, timeLength] - return (*arr)({0,0, 0,0, t,t+1}); - } - else { //NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout - return (*arr)({0,0, t,t+1, 0,0}); - } - } - - void lstmCell(sd::LaunchContext * context, const NDArray* xt, const NDArray* ht_1, const NDArray* ct_1, const NDArray* Wx, const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, const NDArray* b, - NDArray* ht, NDArray* ct, const std::vector& params); - - void lstmTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* c0, const NDArray* Wx, const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, const NDArray* b, - NDArray* h, NDArray* c, const std::vector& params); - - void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast, - const NDArray* W, const NDArray* Wci, const NDArray* Wcf, const NDArray* Wco, const NDArray* b, - NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, NDArray* h, NDArray* y, const std::vector& params); - - - -} +static FORCEINLINE NDArray tanh(const NDArray& arr) { + return (const_cast(arr)).transform(transform::Tanh); } + +static FORCEINLINE void tanhInplace(const NDArray& arr) { + (const_cast(arr)) + .applyTransform(transform::Tanh, const_cast(arr)); } +////////////////////////////////////////////////////////////////////////// +static NDArray timeSubset(const NDArray* arr, const int t, + const int dataFormat) { + if (dataFormat == 0) { // TNS: shape [timeLength, numExamples, inOutSize] + return (*arr)({t, t + 1, 0, 0, 0, 0}); + } else if (dataFormat == + 1) { // NST: shape [numExamples, inOutSize, timeLength] + return (*arr)({0, 0, 0, 0, t, t + 1}); + } else { // NTS: shape [numExamples, timeLength, inOutSize] - TF + // "time_major=false" layout + return (*arr)({0, 0, t, t + 1, 0, 0}); + } +} -#endif //LIBND4J_LSTM_H +void lstmCell(sd::LaunchContext* context, const NDArray* xt, + const NDArray* ht_1, const NDArray* ct_1, const NDArray* Wx, + const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, + const NDArray* b, NDArray* ht, NDArray* ct, + const std::vector& params); + +void lstmTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* h0, const NDArray* c0, const NDArray* Wx, + const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, + const NDArray* b, NDArray* h, NDArray* c, + const std::vector& params); + +void lstmBlockCell(const NDArray* xt, const NDArray* cLast, + const NDArray* yLast, const NDArray* W, const NDArray* Wci, + const NDArray* Wcf, const NDArray* Wco, const NDArray* b, + NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, + NDArray* h, NDArray* y, const std::vector& params); + +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // LIBND4J_LSTM_H diff --git a/libnd4j/include/ops/declarable/helpers/lstmBlock.h b/libnd4j/include/ops/declarable/helpers/lstmBlock.h index 7df9bb795c0d..9487e5c6295c 100644 --- a/libnd4j/include/ops/declarable/helpers/lstmBlock.h +++ b/libnd4j/include/ops/declarable/helpers/lstmBlock.h @@ -23,23 +23,28 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - - void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast, - const NDArray* W, const NDArray* Wci, const NDArray* Wcf, const NDArray* Wco, const NDArray* b, - NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, NDArray* h, NDArray* y, const std::vector& params); - - void lstmBlockTimeLoop(const NDArray* maxSeqLength, const NDArray* xSeq, const NDArray* c0, const NDArray* y0, - const NDArray* W, const NDArray* Wci, const NDArray* Wcf, const NDArray* Wco, const NDArray* b, - const NDArray* iSeq, const NDArray* cSeq, const NDArray* fSeq, const NDArray* oSeq, const NDArray* zSeq, - const NDArray* hSeq, const NDArray* ySeq, const std::vector& params, const int dataFormat); - -} -} -} - - -#endif //LIBND4J_LSTM_H +void lstmBlockCell(const NDArray* xt, const NDArray* cLast, + const NDArray* yLast, const NDArray* W, const NDArray* Wci, + const NDArray* Wcf, const NDArray* Wco, const NDArray* b, + NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, + NDArray* h, NDArray* y, const std::vector& params); + +void lstmBlockTimeLoop(const NDArray* maxSeqLength, const NDArray* xSeq, + const NDArray* c0, const NDArray* y0, const NDArray* W, + const NDArray* Wci, const NDArray* Wcf, + const NDArray* Wco, const NDArray* b, + const NDArray* iSeq, const NDArray* cSeq, + const NDArray* fSeq, const NDArray* oSeq, + const NDArray* zSeq, const NDArray* hSeq, + const NDArray* ySeq, const std::vector& params, + const int dataFormat); + +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // LIBND4J_LSTM_H diff --git a/libnd4j/include/ops/declarable/helpers/lstmLayer.h b/libnd4j/include/ops/declarable/helpers/lstmLayer.h index 7eee8917dcf3..6ba4f2e660b3 100644 --- a/libnd4j/include/ops/declarable/helpers/lstmLayer.h +++ b/libnd4j/include/ops/declarable/helpers/lstmLayer.h @@ -23,48 +23,59 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -void SD_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const std::vector& params, - NDArray* h, NDArray* c); +void SD_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, + const NDArray* Wr, const NDArray* b, + const NDArray* hI, const NDArray* cI, + const NDArray* Wp, + const std::vector& params, NDArray* h, + NDArray* c); ////////////////////////////////////////////////////////////////////////// // this auxiliary ff should be running before backprop -void SD_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const std::vector& params, - NDArray* z, NDArray* a, NDArray* h, NDArray* c); +void SD_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, + const NDArray* Wr, const NDArray* b, + const NDArray* hI, const NDArray* cI, + const NDArray* Wp, + const std::vector& params, NDArray* z, + NDArray* a, NDArray* h, NDArray* c); ////////////////////////////////////////////////////////////////////////// -void SD_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, - const NDArray* z, const NDArray* a, const NDArray* c, const std::vector& params, - NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp); - +void SD_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, + const NDArray* Wr, const NDArray* b, + const NDArray* hI, const NDArray* cI, + const NDArray* Wp, const NDArray* dLdh, + const NDArray* dLdhL, const NDArray* dLdcL, + const NDArray* z, const NDArray* a, + const NDArray* c, + const std::vector& params, NDArray* dLdx, + NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, + NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp); ////////////////////////////////////////////////////////////////////////// -void SD_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const std::vector& params, - const bool forward, - NDArray* h, NDArray* hL, NDArray* cL); +void SD_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, + const NDArray* Wr, const NDArray* b, + const NDArray* seqLen, const NDArray* hI, + const NDArray* cI, const NDArray* Wp, + const std::vector& params, + const bool forward, NDArray* h, NDArray* hL, + NDArray* cL); ////////////////////////////////////////////////////////////////////////// -void SD_EXPORT lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp, - const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, - const std::vector& params, const bool forward, - NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp); - - -} -} -} +void SD_EXPORT lstmLayerTimeLoopBp( + const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, + const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp, + const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, + const std::vector& params, const bool forward, NDArray* dLdx, + NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, + NDArray* dLdcI, NDArray* dLdWp); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_LSTMLAYER_H +#endif // LIBND4J_LSTMLAYER_H diff --git a/libnd4j/include/ops/declarable/helpers/lstsq.h b/libnd4j/include/ops/declarable/helpers/lstsq.h index 9cc6293837a0..de2322bb70c4 100644 --- a/libnd4j/include/ops/declarable/helpers/lstsq.h +++ b/libnd4j/include/ops/declarable/helpers/lstsq.h @@ -20,15 +20,19 @@ #ifndef __LST_SQ_SOLVE__H_HELPERS__ #define __LST_SQ_SOLVE__H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - int leastSquaresSolveFunctor(sd::LaunchContext* context, NDArray const* leftInput, NDArray const* rightInput, double const l2Regularizer, bool const fast, NDArray* output); -} -} +int leastSquaresSolveFunctor(sd::LaunchContext* context, + NDArray const* leftInput, + NDArray const* rightInput, + double const l2Regularizer, bool const fast, + NDArray* output); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/lup.h b/libnd4j/include/ops/declarable/helpers/lup.h index 1e58c2e3fe83..d3aa99919730 100644 --- a/libnd4j/include/ops/declarable/helpers/lup.h +++ b/libnd4j/include/ops/declarable/helpers/lup.h @@ -19,26 +19,32 @@ // #ifndef __LUP_H_HELPERS__ #define __LUP_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - int lup(sd::LaunchContext* context, NDArray* input, NDArray* lu, NDArray* permutation); - void lu(sd::LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation); - int determinant(sd::LaunchContext * context, NDArray* input, NDArray* output); - int logAbsDeterminant(sd::LaunchContext * context, NDArray* input, NDArray* output); +int lup(sd::LaunchContext* context, NDArray* input, NDArray* lu, + NDArray* permutation); +void lu(sd::LaunchContext* context, NDArray* input, NDArray* output, + NDArray* permutation); +int determinant(sd::LaunchContext* context, NDArray* input, NDArray* output); +int logAbsDeterminant(sd::LaunchContext* context, NDArray* input, + NDArray* output); - int inverse(sd::LaunchContext * context, NDArray* input, NDArray* output); - int upperInverseFunctor(sd::LaunchContext* context, NDArray* input, NDArray* output); - int lowerInverseFunctor(sd::LaunchContext* context, NDArray* input, NDArray* output); +int inverse(sd::LaunchContext* context, NDArray* input, NDArray* output); +int upperInverseFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* output); +int lowerInverseFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* output); - bool checkCholeskyInput(sd::LaunchContext * context, NDArray const* input); - int cholesky(sd::LaunchContext * context, NDArray* input, NDArray* output, bool inplace = false); - int logdetFunctor(sd::LaunchContext * context, NDArray* input, NDArray* output); -} -} -} +bool checkCholeskyInput(sd::LaunchContext* context, NDArray const* input); +int cholesky(sd::LaunchContext* context, NDArray* input, NDArray* output, + bool inplace = false); +int logdetFunctor(sd::LaunchContext* context, NDArray* input, NDArray* output); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/matmul.h b/libnd4j/include/ops/declarable/helpers/matmul.h index 2cce162d81b3..5051cc24f163 100644 --- a/libnd4j/include/ops/declarable/helpers/matmul.h +++ b/libnd4j/include/ops/declarable/helpers/matmul.h @@ -24,12 +24,13 @@ #include namespace sd { - namespace ops { - namespace helpers { +namespace ops { +namespace helpers { - void _matmul(sd::LaunchContext * context, NDArray *A, NDArray *B, NDArray *C, int transA, int transB, double alpha = 1., double beta = 0.); - } - } +void _matmul(sd::LaunchContext *context, NDArray *A, NDArray *B, NDArray *C, + int transA, int transB, double alpha = 1., double beta = 0.); } +} // namespace ops +} // namespace sd -#endif //LIBND4J_MATMUL_H +#endif // LIBND4J_MATMUL_H diff --git a/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h b/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h index 332c3134bac1..caf708675971 100644 --- a/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h +++ b/libnd4j/include/ops/declarable/helpers/matrixSetDiag.h @@ -22,17 +22,19 @@ #define LIBND4J_MATRIXSETDIAG_H #include + #include "array/NDArray.h" namespace sd { namespace ops { namespace helpers { - void matrixSetDiag(sd::LaunchContext* context, const NDArray& input, const NDArray& diagonal, NDArray& output, const bool zeroPad); +void matrixSetDiag(sd::LaunchContext* context, const NDArray& input, + const NDArray& diagonal, NDArray& output, + const bool zeroPad); } -} -} - +} // namespace ops +} // namespace sd -#endif //LIBND4J_MATRIXSETDIAG_H +#endif // LIBND4J_MATRIXSETDIAG_H diff --git a/libnd4j/include/ops/declarable/helpers/matrix_band.h b/libnd4j/include/ops/declarable/helpers/matrix_band.h index f997e4d56573..8aadd8bd47a1 100644 --- a/libnd4j/include/ops/declarable/helpers/matrix_band.h +++ b/libnd4j/include/ops/declarable/helpers/matrix_band.h @@ -19,17 +19,17 @@ // #ifndef __MATRIX_BAND_H_HELPERS__ #define __MATRIX_BAND_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void matrixBandPart(sd::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong lowerBand, Nd4jLong upperBand); - +void matrixBandPart(sd::LaunchContext* context, NDArray* input, NDArray* output, + Nd4jLong lowerBand, Nd4jLong upperBand); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/matrix_diag_part.h b/libnd4j/include/ops/declarable/helpers/matrix_diag_part.h index fd25c636c92a..01f3b40d4d02 100644 --- a/libnd4j/include/ops/declarable/helpers/matrix_diag_part.h +++ b/libnd4j/include/ops/declarable/helpers/matrix_diag_part.h @@ -19,16 +19,17 @@ // #ifndef __MATRIX_DIAG_PART_HELPERS__ #define __MATRIX_DIAG_PART_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - int matrixDiagPart(sd::LaunchContext * context, NDArray const* input, NDArray* output); +int matrixDiagPart(sd::LaunchContext* context, NDArray const* input, + NDArray* output); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/max_pooling.h b/libnd4j/include/ops/declarable/helpers/max_pooling.h index a3750798b752..f713e27545be 100644 --- a/libnd4j/include/ops/declarable/helpers/max_pooling.h +++ b/libnd4j/include/ops/declarable/helpers/max_pooling.h @@ -19,16 +19,18 @@ // #ifndef __MAX_POOLING_HELPERS__ #define __MAX_POOLING_HELPERS__ -#include #include #include +#include namespace sd { namespace ops { namespace helpers { - void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices); -} -} +void maxPoolingFunctor(sd::LaunchContext* context, sd::graph::Context& block, + NDArray* input, NDArray* values, + std::vector const& params, NDArray* indices); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/meshgrid.h b/libnd4j/include/ops/declarable/helpers/meshgrid.h index e6c38502940a..66723f2ed03c 100644 --- a/libnd4j/include/ops/declarable/helpers/meshgrid.h +++ b/libnd4j/include/ops/declarable/helpers/meshgrid.h @@ -23,15 +23,15 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - void meshgrid(sd::LaunchContext * context, const std::vector& inArrs, const std::vector& outArrs, const bool swapFirst2Dims); +void meshgrid(sd::LaunchContext* context, const std::vector& inArrs, + const std::vector& outArrs, const bool swapFirst2Dims); } -} -} - +} // namespace ops +} // namespace sd -#endif //LIBND4J_SRU_H +#endif // LIBND4J_SRU_H diff --git a/libnd4j/include/ops/declarable/helpers/minimax.h b/libnd4j/include/ops/declarable/helpers/minimax.h index f619a20f667e..58c0215629bc 100644 --- a/libnd4j/include/ops/declarable/helpers/minimax.h +++ b/libnd4j/include/ops/declarable/helpers/minimax.h @@ -19,17 +19,19 @@ // #ifndef __MIN_I_MAX_H_HELPERS__ #define __MIN_I_MAX_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void minimumBPFunctor(sd::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY); - void maximumBPFunctor(sd::LaunchContext * context, NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY); +void minimumBPFunctor(sd::LaunchContext* context, NDArray* x, NDArray* y, + NDArray* epsNext, NDArray* gradX, NDArray* gradY); +void maximumBPFunctor(sd::LaunchContext* context, NDArray* x, NDArray* y, + NDArray* epsNext, NDArray* gradX, NDArray* gradY); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/multiUnique.h b/libnd4j/include/ops/declarable/helpers/multiUnique.h index a7ce14818d27..d915ed7db0c3 100644 --- a/libnd4j/include/ops/declarable/helpers/multiUnique.h +++ b/libnd4j/include/ops/declarable/helpers/multiUnique.h @@ -19,16 +19,17 @@ // #ifndef __MULTI_UNIQUE_H_HELPERS__ #define __MULTI_UNIQUE_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - SD_EXPORT bool multiUnique(std::vector const& inputList, sd::memory::Workspace* workspace = nullptr); +SD_EXPORT bool multiUnique(std::vector const& inputList, + sd::memory::Workspace* workspace = nullptr); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/nth_element.h b/libnd4j/include/ops/declarable/helpers/nth_element.h index 1a2c28719af2..fac411d1833d 100644 --- a/libnd4j/include/ops/declarable/helpers/nth_element.h +++ b/libnd4j/include/ops/declarable/helpers/nth_element.h @@ -19,16 +19,17 @@ // #ifndef __NTH_ELEMENT__H_HELPERS__ #define __NTH_ELEMENT__H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void nthElementFunctor(sd::LaunchContext * context, NDArray* input, Nd4jLong n, NDArray* output, bool reverse); +void nthElementFunctor(sd::LaunchContext* context, NDArray* input, Nd4jLong n, + NDArray* output, bool reverse); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/one_hot.h b/libnd4j/include/ops/declarable/helpers/one_hot.h index 2c435a75e948..7b9f2b997ec9 100644 --- a/libnd4j/include/ops/declarable/helpers/one_hot.h +++ b/libnd4j/include/ops/declarable/helpers/one_hot.h @@ -21,17 +21,19 @@ #ifndef SD_ONE_HOT_H #define SD_ONE_HOT_H -#include #include +#include -namespace sd { -namespace ops { -namespace helpers { +namespace sd { +namespace ops { +namespace helpers { - void onehot(const sd::LaunchContext* context, const NDArray *indices, NDArray *output, const uint axis, const uint depth, const double on, const double off); +void onehot(const sd::LaunchContext *context, const NDArray *indices, + NDArray *output, const uint axis, const uint depth, const double on, + const double off); } -} -} +} // namespace ops +} // namespace sd -#endif //SD_ONE_HOT_H +#endif // SD_ONE_HOT_H diff --git a/libnd4j/include/ops/declarable/helpers/percentile.h b/libnd4j/include/ops/declarable/helpers/percentile.h index 5eddb42b2907..a0044ccfb6a2 100644 --- a/libnd4j/include/ops/declarable/helpers/percentile.h +++ b/libnd4j/include/ops/declarable/helpers/percentile.h @@ -22,18 +22,19 @@ #define LIBND4J_PERCENTILE_H #include + #include "array/NDArray.h" namespace sd { namespace ops { namespace helpers { - void percentile(sd::LaunchContext * context, const NDArray& input, NDArray& output, std::vector& axises, const float q, const int interpolation); - +void percentile(sd::LaunchContext* context, const NDArray& input, + NDArray& output, std::vector& axises, const float q, + const int interpolation); } -} -} - +} // namespace ops +} // namespace sd -#endif //LIBND4J_PERCENTILE_H \ No newline at end of file +#endif // LIBND4J_PERCENTILE_H \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/prefix.h b/libnd4j/include/ops/declarable/helpers/prefix.h index 757c5c94f72f..dd810a17a8d8 100644 --- a/libnd4j/include/ops/declarable/helpers/prefix.h +++ b/libnd4j/include/ops/declarable/helpers/prefix.h @@ -21,22 +21,28 @@ #ifndef LIBND4J_PREFIX_HELPER_H #define LIBND4J_PREFIX_HELPER_H +#include #include #include + #include -#include namespace sd { - namespace ops { - namespace helpers { - // template - // void prefix(sd::LaunchContext * context, sd::scalar::Ops op, void* x, Nd4jLong *xShapeInfo, void* z, Nd4jLong* zShapeInfo, bool exclusive, bool reverse); +namespace ops { +namespace helpers { +// template +// void prefix(sd::LaunchContext * context, sd::scalar::Ops op, void* x, +// Nd4jLong *xShapeInfo, void* z, Nd4jLong* zShapeInfo, bool exclusive, bool +// reverse); - void prefix(sd::LaunchContext* context, sd::scalar::Ops op, const NDArray* x, NDArray* z, bool exclusive, bool reverse); +void prefix(sd::LaunchContext* context, sd::scalar::Ops op, const NDArray* x, + NDArray* z, bool exclusive, bool reverse); - void prefix(sd::LaunchContext* context, sd::scalar::Ops op, const NDArray* x, NDArray* z, const std::vector& dims, bool exclusive, bool reverse); - } - } -} +void prefix(sd::LaunchContext* context, sd::scalar::Ops op, const NDArray* x, + NDArray* z, const std::vector& dims, bool exclusive, + bool reverse); +} // namespace helpers +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/print_variable.h b/libnd4j/include/ops/declarable/helpers/print_variable.h index 46cf4ee01bc1..d04a1cc7ce19 100644 --- a/libnd4j/include/ops/declarable/helpers/print_variable.h +++ b/libnd4j/include/ops/declarable/helpers/print_variable.h @@ -24,11 +24,12 @@ #include namespace sd { - namespace ops { - namespace helpers { - void print_special(LaunchContext &ctx, const NDArray &array, const std::string &message = {}); - } - } +namespace ops { +namespace helpers { +void print_special(LaunchContext &ctx, const NDArray &array, + const std::string &message = {}); } +} // namespace ops +} // namespace sd -#endif //LIBND4J_PRINT_VARIABLE_H +#endif // LIBND4J_PRINT_VARIABLE_H diff --git a/libnd4j/include/ops/declarable/helpers/qr.h b/libnd4j/include/ops/declarable/helpers/qr.h index 05de6ca4049e..115d374e62c8 100644 --- a/libnd4j/include/ops/declarable/helpers/qr.h +++ b/libnd4j/include/ops/declarable/helpers/qr.h @@ -19,17 +19,17 @@ // #ifndef __QR__H_HELPERS__ #define __QR__H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void qr(sd::LaunchContext * context, NDArray const* input, NDArray* outputQ, NDArray* outputR, bool const fullMatricies); - +void qr(sd::LaunchContext* context, NDArray const* input, NDArray* outputQ, + NDArray* outputR, bool const fullMatricies); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/random.h b/libnd4j/include/ops/declarable/helpers/random.h index 5ee75e141fc7..dbe93151e360 100644 --- a/libnd4j/include/ops/declarable/helpers/random.h +++ b/libnd4j/include/ops/declarable/helpers/random.h @@ -22,20 +22,25 @@ // #ifndef __RANDOM_HELPERS__ #define __RANDOM_HELPERS__ -#include #include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { - void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output); - void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output); - void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output); - void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC); -} -} -} +void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* alpha, NDArray* beta, NDArray* output); +void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* lambda, NDArray* output); +void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, + NDArray* min, NDArray* max, NDArray* output); +void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, + NDArray& input, NDArray& output, + const Nd4jLong numOfSamples, const int dimC); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/random_crop.h b/libnd4j/include/ops/declarable/helpers/random_crop.h index f4d36a850b19..bf1730c66791 100644 --- a/libnd4j/include/ops/declarable/helpers/random_crop.h +++ b/libnd4j/include/ops/declarable/helpers/random_crop.h @@ -19,18 +19,19 @@ // #ifndef __RANDOM_CROP_HELPERS__ #define __RANDOM_CROP_HELPERS__ -#include #include -#include #include +#include +#include namespace sd { namespace ops { namespace helpers { - int randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, NDArray* output, int seed); +int randomCropFunctor(graph::Context& context, NDArray* input, NDArray* shape, + NDArray* output, int seed); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/range.h b/libnd4j/include/ops/declarable/helpers/range.h index 13155fd70529..fc5e6ae85b6d 100644 --- a/libnd4j/include/ops/declarable/helpers/range.h +++ b/libnd4j/include/ops/declarable/helpers/range.h @@ -23,16 +23,16 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - // be careful: outVector must have c-order and ews = 1 !!! - void range(sd::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector); +// be careful: outVector must have c-order and ews = 1 !!! +void range(sd::LaunchContext* context, const NDArray& start, + const NDArray& delta, NDArray& outVector); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd - -#endif //LIBND4J_RANGE_H +#endif // LIBND4J_RANGE_H diff --git a/libnd4j/include/ops/declarable/helpers/reverse.h b/libnd4j/include/ops/declarable/helpers/reverse.h index d85d017ba832..9c5fa7364fa4 100644 --- a/libnd4j/include/ops/declarable/helpers/reverse.h +++ b/libnd4j/include/ops/declarable/helpers/reverse.h @@ -23,19 +23,19 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - void reverseSequence(sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim); +void reverseSequence(sd::LaunchContext* context, const NDArray* input, + const NDArray* seqLengths, NDArray* output, int seqDim, + const int batchDim); - void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector* intArgs, bool isBackProp); +void reverse(sd::LaunchContext* context, const NDArray* input, NDArray* output, + const std::vector* intArgs, bool isBackProp); - +} // namespace helpers +} // namespace ops +} // namespace sd -} -} -} - - -#endif //LIBND4J_REVERSESEQUENCE_H +#endif // LIBND4J_REVERSESEQUENCE_H diff --git a/libnd4j/include/ops/declarable/helpers/rnn.h b/libnd4j/include/ops/declarable/helpers/rnn.h index 32f49fe2e317..885d3be8bb77 100644 --- a/libnd4j/include/ops/declarable/helpers/rnn.h +++ b/libnd4j/include/ops/declarable/helpers/rnn.h @@ -23,18 +23,21 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { +void rnnCell(sd::LaunchContext* context, const NDArray* xt, const NDArray* Wx, + const NDArray* Wh, const NDArray* b, const NDArray* ht_1, + NDArray* ht); - void rnnCell(sd::LaunchContext * context, const NDArray* xt, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* ht_1, NDArray* ht); +void rnnTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* Wx, const NDArray* Wh, const NDArray* b, + const NDArray* h0, const NDArray* maxTimeStep, NDArray* h, + NDArray* hFinal); - void rnnTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* h0, const NDArray* maxTimeStep, NDArray* h, NDArray* hFinal); +} // namespace helpers +} // namespace ops +} // namespace sd -} -} -} - - -#endif //LIBND4J_RNN_H +#endif // LIBND4J_RNN_H diff --git a/libnd4j/include/ops/declarable/helpers/roll.h b/libnd4j/include/ops/declarable/helpers/roll.h index 3e637dbc43ea..715e1e40e4a2 100644 --- a/libnd4j/include/ops/declarable/helpers/roll.h +++ b/libnd4j/include/ops/declarable/helpers/roll.h @@ -24,10 +24,13 @@ namespace sd { namespace ops { namespace helpers { - void rollFunctorLinear(sd::LaunchContext * context, NDArray* input, NDArray* output, int shift, bool inplace = false); +void rollFunctorLinear(sd::LaunchContext* context, NDArray* input, + NDArray* output, int shift, bool inplace = false); - void rollFunctorFull(sd::LaunchContext * context, NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace = false); -} -} -} +void rollFunctorFull(sd::LaunchContext* context, NDArray* input, + NDArray* output, std::vector const& shifts, + std::vector const& axes, bool inplace = false); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/s_t_b.h b/libnd4j/include/ops/declarable/helpers/s_t_b.h index 1147f05ab92d..360aa68dd457 100644 --- a/libnd4j/include/ops/declarable/helpers/s_t_b.h +++ b/libnd4j/include/ops/declarable/helpers/s_t_b.h @@ -27,23 +27,38 @@ namespace sd { namespace ops { namespace helpers { - void batchToSpace(sd::LaunchContext* context, const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight, const uint blockSize); +void batchToSpace(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const uint cropBottom, const uint cropTop, + const uint cropLeft, const uint cropRight, + const uint blockSize); - void spaceToBatch(sd::LaunchContext* context, const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight, const uint blockSize); +void spaceToBatch(sd::LaunchContext* context, const NDArray& input, + NDArray& output, const uint padBottom, const uint padTop, + const uint padLeft, const uint padRight, + const uint blockSize); - void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, const NDArray& blockShape, const NDArray& padding, NDArray& output); +void spaceToBatchND(sd::LaunchContext* context, const NDArray& input, + const NDArray& blockShape, const NDArray& padding, + NDArray& output); - void batchToSpaceND(sd::LaunchContext* context, const NDArray& input, const NDArray& blockShape, const NDArray& crop, NDArray& output); +void batchToSpaceND(sd::LaunchContext* context, const NDArray& input, + const NDArray& blockShape, const NDArray& crop, + NDArray& output); /* // this method MUST be platform-specific template - void _execute(sd::LaunchContext * context, void *ptrSpace, const Nd4jLong *space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *ptrBatch, const Nd4jLong *batch_shape, const Nd4jLong *batch_strides); + void _execute(sd::LaunchContext * context, void *ptrSpace, const Nd4jLong +*space_shape, const Nd4jLong *space_strides, const Nd4jLong *block_shape, const +Nd4jLong *pad_start, const Nd4jLong *block_offsets, void *ptrBatch, const +Nd4jLong *batch_shape, const Nd4jLong *batch_strides); template - FORCEINLINE void _prepare(sd::LaunchContext * context, NDArray * space, NDArray *batch, const Nd4jLong block_array[NUM_BLOCK_DIMS], const Nd4jLong padding_array[NUM_BLOCK_DIMS * 2]) { + FORCEINLINE void _prepare(sd::LaunchContext * context, NDArray * space, +NDArray *batch, const Nd4jLong block_array[NUM_BLOCK_DIMS], const Nd4jLong +padding_array[NUM_BLOCK_DIMS * 2]) { Nd4jLong pad_start[NUM_BLOCK_DIMS]; Nd4jLong block_shape[NUM_BLOCK_DIMS]; @@ -69,26 +84,38 @@ namespace helpers { const Nd4jLong space_b = batch_b % space_size; Nd4jLong block_index = batch_b / space_size; Nd4jLong block_offsets[NUM_BLOCK_DIMS]; - for (Nd4jLong block_dim = NUM_BLOCK_DIMS - 1; block_dim >= 0; --block_dim) { - block_offsets[block_dim] = block_dim > 0 ? block_index % block_shape[block_dim] : block_index; - block_index /= block_shape[block_dim]; + for (Nd4jLong block_dim = NUM_BLOCK_DIMS - 1; block_dim >= 0; +--block_dim) { block_offsets[block_dim] = block_dim > 0 ? block_index % +block_shape[block_dim] : block_index; block_index /= block_shape[block_dim]; } Nd4jLong space_offset = space_b * space_strides[0]; Nd4jLong batch_offset = batch_b * batch_strides[0]; auto xType = space->dataType(); - //_execute(space->buffer() + space_offset, space_shape, &space_strides[1], block_shape, pad_start, block_offsets, batch->buffer() + batch_offset, batch_shape, &batch_strides[1]); - BUILD_SINGLE_PARTIAL_SELECTOR(xType, _execute<, (NUM_BLOCK_DIMS, B2S>(context, space->bufferWithOffset(space_offset), space_shape, &space_strides[1], block_shape, pad_start, block_offsets, batch->bufferWithOffset(batch_offset), batch_shape, &batch_strides[1])), LIBND4J_TYPES); + //_execute(space->buffer() + space_offset, +space_shape, &space_strides[1], block_shape, pad_start, block_offsets, +batch->buffer() + batch_offset, batch_shape, &batch_strides[1]); + BUILD_SINGLE_PARTIAL_SELECTOR(xType, _execute<, (NUM_BLOCK_DIMS, +B2S>(context, space->bufferWithOffset(space_offset), space_shape, +&space_strides[1], block_shape, pad_start, block_offsets, +batch->bufferWithOffset(batch_offset), batch_shape, &batch_strides[1])), +LIBND4J_TYPES); } }; - Nd4jStatus _spaceToBatch(sd::LaunchContext * context, int internal_block_dims, NDArray *input, NDArray *output, std::vector &internal_input_shape, std::vector &internal_output_shape, Nd4jLong *block_shape, Nd4jLong *paddings); + Nd4jStatus _spaceToBatch(sd::LaunchContext * context, int +internal_block_dims, NDArray *input, NDArray *output, std::vector +&internal_input_shape, std::vector &internal_output_shape, Nd4jLong +*block_shape, Nd4jLong *paddings); - Nd4jStatus _batchToSpace(sd::LaunchContext * context, int internal_block_dims, NDArray *input, NDArray *output, std::vector &internal_input_shape, std::vector &internal_output_shape, Nd4jLong *block_shape, Nd4jLong *crops); + Nd4jStatus _batchToSpace(sd::LaunchContext * context, int +internal_block_dims, NDArray *input, NDArray *output, std::vector +&internal_input_shape, std::vector &internal_output_shape, Nd4jLong +*block_shape, Nd4jLong *crops); */ -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_S_T_B_H +#endif // LIBND4J_S_T_B_H diff --git a/libnd4j/include/ops/declarable/helpers/s_t_d.h b/libnd4j/include/ops/declarable/helpers/s_t_d.h index 7ef500f03a29..21f588fd05d1 100644 --- a/libnd4j/include/ops/declarable/helpers/s_t_d.h +++ b/libnd4j/include/ops/declarable/helpers/s_t_d.h @@ -18,13 +18,14 @@ // @author raver119@gmail.com // -#include #include +#include namespace sd { namespace ops { namespace helpers { - void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC); -} +void _spaceTodepth(sd::LaunchContext *context, const NDArray &input, + NDArray *output, int block_size, bool isNHWC); } -} \ No newline at end of file +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/scatter.h b/libnd4j/include/ops/declarable/helpers/scatter.h index 0460a702d3cc..e888260a1f37 100644 --- a/libnd4j/include/ops/declarable/helpers/scatter.h +++ b/libnd4j/include/ops/declarable/helpers/scatter.h @@ -24,17 +24,23 @@ #include namespace sd { - namespace ops { - namespace helpers { - void scatter(sd::LaunchContext* context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock); - - void scatterND(sd::LaunchContext* context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock); - - void scatterForLoss(sd::LaunchContext* context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad); - - Nd4jLong checkIndices(sd::LaunchContext *context, const NDArray& indices, const NDArray& output, const int axis = -1); - } - } -} - -#endif //SD_SCATTER_H +namespace ops { +namespace helpers { +void scatter(sd::LaunchContext* context, pairwise::Ops op, + const NDArray& indices, const NDArray& updates, NDArray& output, + const bool lock); + +void scatterND(sd::LaunchContext* context, pairwise::Ops op, + const NDArray& indices, const NDArray& updates, NDArray& output, + const bool lock); + +void scatterForLoss(sd::LaunchContext* context, const NDArray& indices, + NDArray& updates, NDArray& output, const bool calcGrad); + +Nd4jLong checkIndices(sd::LaunchContext* context, const NDArray& indices, + const NDArray& output, const int axis = -1); +} // namespace helpers +} // namespace ops +} // namespace sd + +#endif // SD_SCATTER_H diff --git a/libnd4j/include/ops/declarable/helpers/segment.h b/libnd4j/include/ops/declarable/helpers/segment.h index 2433313ffbaf..900b80ba0c1a 100644 --- a/libnd4j/include/ops/declarable/helpers/segment.h +++ b/libnd4j/include/ops/declarable/helpers/segment.h @@ -17,66 +17,104 @@ // // @author sgazeos@gmail.com // @brief helpers fuctions for segment_* ops (segment_max, segment_min, etc.) -// @brief helpers fuctions for unsorted_segment_* ops (unsorted_segment_max, etc.) +// @brief helpers fuctions for unsorted_segment_* ops (unsorted_segment_max, +// etc.) // #ifndef __SEGMENT_HELPERS__ #define __SEGMENT_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - bool segmentIndicesValidate(sd::LaunchContext * context, NDArray* indices, NDArray& expected, NDArray& output); +bool segmentIndicesValidate(sd::LaunchContext* context, NDArray* indices, + NDArray& expected, NDArray& output); - bool unsortedSegmentIndicesValidate(sd::LaunchContext * context, NDArray* indices, Nd4jLong numOfClasses, Nd4jLong& output); +bool unsortedSegmentIndicesValidate(sd::LaunchContext* context, + NDArray* indices, Nd4jLong numOfClasses, + Nd4jLong& output); - void segmentMaxFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output); +void segmentMaxFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output); - void segmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output); +void segmentMinFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output); - void segmentMeanFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output); +void segmentMeanFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output); - void segmentSumFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output); +void segmentSumFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output); - void segmentProdFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output); +void segmentProdFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* output); - void unsortedSegmentSqrtNFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output); +void unsortedSegmentSqrtNFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output); - void unsortedSegmentMaxFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output); +void unsortedSegmentMaxFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output); - void unsortedSegmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output); +void unsortedSegmentMinFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output); - void unsortedSegmentMeanFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output); +void unsortedSegmentMeanFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output); - void unsortedSegmentSumFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output); +void unsortedSegmentSumFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output); - void unsortedSegmentProdFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output); +void unsortedSegmentProdFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* indices, Nd4jLong numOfClasses, + NDArray* output); - int segmentMaxFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output); +int segmentMaxFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output); - int segmentMinFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output); +int segmentMinFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output); - int segmentMeanFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output); +int segmentMeanFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output); - int segmentSumFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output); +int segmentSumFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output); - int segmentProdFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output); +int segmentProdFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, NDArray* output); - int unsortedSegmentSqrtNFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output); +int unsortedSegmentSqrtNFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output); - int unsortedSegmentMaxFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output); +int unsortedSegmentMaxFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output); - int unsortedSegmentMinFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output); +int unsortedSegmentMinFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output); - int unsortedSegmentMeanFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output); +int unsortedSegmentMeanFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output); - int unsortedSegmentSumFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output); +int unsortedSegmentSumFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output); - int unsortedSegmentProdFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output); +int unsortedSegmentProdFunctorBP(sd::LaunchContext* context, NDArray* input, + NDArray* indices, NDArray* gradOut, + Nd4jLong numOfClasses, NDArray* output); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/segment_common.h b/libnd4j/include/ops/declarable/helpers/segment_common.h index b0a92b8b33e1..aa67893e9ed7 100644 --- a/libnd4j/include/ops/declarable/helpers/segment_common.h +++ b/libnd4j/include/ops/declarable/helpers/segment_common.h @@ -16,21 +16,23 @@ // // @author sgazeos@gmail.com -// @brief helpers common fuctions for segment_* ops (segment_max, segment_min, etc.) -// @brief helpers common fuctions for unsorted_segment_* ops (unsorted_segment_max, etc.) +// @brief helpers common fuctions for segment_* ops (segment_max, segment_min, +// etc.) +// @brief helpers common fuctions for unsorted_segment_* ops +// (unsorted_segment_max, etc.) // #ifndef __SEGMENT_COMMON_HELPERS__ #define __SEGMENT_COMMON_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void fillUpSegments(NDArray* indices, Nd4jLong numClasses, NDArray& classesRangesBegs, NDArray& classesRangesLens); - +void fillUpSegments(NDArray* indices, Nd4jLong numClasses, + NDArray& classesRangesBegs, NDArray& classesRangesLens); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/sequence_mask.h b/libnd4j/include/ops/declarable/helpers/sequence_mask.h index 491640b9e769..2b97dcca580d 100644 --- a/libnd4j/include/ops/declarable/helpers/sequence_mask.h +++ b/libnd4j/include/ops/declarable/helpers/sequence_mask.h @@ -19,16 +19,17 @@ // #ifndef __SEQUENCE_MASK_HELPERS__ #define __SEQUENCE_MASK_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void sequenceMask(sd::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex); +void sequenceMask(sd::LaunchContext* context, NDArray* input, NDArray* output, + int maxIndex); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/sg_cb.h b/libnd4j/include/ops/declarable/helpers/sg_cb.h index abf073786b6d..9279c2d6e6db 100644 --- a/libnd4j/include/ops/declarable/helpers/sg_cb.h +++ b/libnd4j/include/ops/declarable/helpers/sg_cb.h @@ -21,20 +21,30 @@ #ifndef SD_SG_CB_H #define SD_SG_CB_H +#include #include #include -#include namespace sd { - namespace ops { - namespace helpers { - void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &inferenceVector, const bool preciseMode, const int numWorkers); +namespace ops { +namespace helpers { +void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, + NDArray &negTable, NDArray &target, NDArray &ngStarter, + int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, + NDArray &randomValue, NDArray &inferenceVector, + const bool preciseMode, const int numWorkers); - void cbow(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &context, NDArray &lockedWords, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &numLabels, NDArray &inferenceVector, const bool trainWords, const int numWorkers); +void cbow(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, + NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, + NDArray &context, NDArray &lockedWords, NDArray &indices, + NDArray &codes, NDArray &alpha, NDArray &randomValue, + NDArray &numLabels, NDArray &inferenceVector, const bool trainWords, + const int numWorkers); - int binarySearch(const int *haystack, const int needle, const int totalElements); - } - } -} +int binarySearch(const int *haystack, const int needle, + const int totalElements); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //SD_SG_CB_H +#endif // SD_SG_CB_H diff --git a/libnd4j/include/ops/declarable/helpers/shift.h b/libnd4j/include/ops/declarable/helpers/shift.h index da816a902640..6dfe541d1fc0 100644 --- a/libnd4j/include/ops/declarable/helpers/shift.h +++ b/libnd4j/include/ops/declarable/helpers/shift.h @@ -21,22 +21,26 @@ #ifndef SD_SHIFT_H #define SD_SHIFT_H +#include #include #include -#include namespace sd { - namespace ops { - namespace helpers { - void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); +namespace ops { +namespace helpers { +void rshift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift); - void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); +void shift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift); - void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); +void cyclic_rshift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift); - void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); - } - } -} +void cyclic_shift_bits(LaunchContext *launchContext, NDArray &x, NDArray &z, + uint32_t shift); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //SD_SHIFT_H +#endif // SD_SHIFT_H diff --git a/libnd4j/include/ops/declarable/helpers/solve.h b/libnd4j/include/ops/declarable/helpers/solve.h index 17234f313877..ee0bca0c4820 100644 --- a/libnd4j/include/ops/declarable/helpers/solve.h +++ b/libnd4j/include/ops/declarable/helpers/solve.h @@ -19,16 +19,18 @@ // #ifndef __SOLVE__H_HELPERS__ #define __SOLVE__H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - int solveFunctor(sd::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool adjoint, NDArray* output); - void adjointMatrix(sd::LaunchContext* context, NDArray const* input, NDArray* output); -} -} -} +int solveFunctor(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool adjoint, NDArray* output); +void adjointMatrix(sd::LaunchContext* context, NDArray const* input, + NDArray* output); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/sparse_to_dense.h b/libnd4j/include/ops/declarable/helpers/sparse_to_dense.h index 541621257548..ecb66ff9f9ef 100644 --- a/libnd4j/include/ops/declarable/helpers/sparse_to_dense.h +++ b/libnd4j/include/ops/declarable/helpers/sparse_to_dense.h @@ -24,11 +24,12 @@ #include namespace sd { - namespace ops { - namespace helpers { - void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output); - } - } +namespace ops { +namespace helpers { +void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, + NDArray *def, NDArray &output); } +} // namespace ops +} // namespace sd -#endif //SAMEDIFF_SPARSE_TO_DENSE_H +#endif // SAMEDIFF_SPARSE_TO_DENSE_H diff --git a/libnd4j/include/ops/declarable/helpers/sru.h b/libnd4j/include/ops/declarable/helpers/sru.h index 639247278512..a998ef8a345a 100644 --- a/libnd4j/include/ops/declarable/helpers/sru.h +++ b/libnd4j/include/ops/declarable/helpers/sru.h @@ -23,25 +23,28 @@ #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { +void sruCell(sd::LaunchContext* context, const NDArray* x, const NDArray* c0, + const NDArray* w, const NDArray* b, NDArray* h, NDArray* c); - void sruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c); +void sruTimeLoop(sd::LaunchContext* context, const NDArray* x, + const NDArray* c0, const NDArray* w, const NDArray* b, + NDArray* h, NDArray* c); - void sruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c); +void sruBI(sd::LaunchContext* context, NDArray* x, const NDArray* w, + const NDArray* b, const NDArray* c0, const NDArray* mask, + NDArray* ht, NDArray* ct); - void sruBI(sd::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct); +void sruBIBP(sd::LaunchContext* context, NDArray* x, const NDArray* w, + const NDArray* b, const NDArray* c0, const NDArray* ct, + const NDArray* inGradC0, const NDArray* inGradH, + const NDArray* mask, NDArray* gradI, NDArray* gradWeights, + NDArray* gradB, NDArray* gradC0); +} // namespace helpers +} // namespace ops +} // namespace sd - void sruBIBP(sd::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, - NDArray* gradI, NDArray* gradWeights, NDArray* gradB, NDArray* gradC0); -} -} -} - - -#endif //LIBND4J_SRU_H - - - \ No newline at end of file +#endif // LIBND4J_SRU_H diff --git a/libnd4j/include/ops/declarable/helpers/stack.h b/libnd4j/include/ops/declarable/helpers/stack.h index 0ab486a5d33c..af694cc57448 100644 --- a/libnd4j/include/ops/declarable/helpers/stack.h +++ b/libnd4j/include/ops/declarable/helpers/stack.h @@ -21,20 +21,21 @@ #ifndef LIBND4J_STACK_H #define LIBND4J_STACK_H -#include #include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { -void stack (sd::LaunchContext* context, const std::vector& inArrs, NDArray& outArr, const int dim); -void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector& outArrs, const int dim); - - -} -} -} +void stack(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& outArr, + const int dim); +void unstack(sd::LaunchContext* context, const NDArray& input, + const std::vector& outArrs, const int dim); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_STACK_H +#endif // LIBND4J_STACK_H diff --git a/libnd4j/include/ops/declarable/helpers/svd.h b/libnd4j/include/ops/declarable/helpers/svd.h index 027807191897..14843d09d9a6 100644 --- a/libnd4j/include/ops/declarable/helpers/svd.h +++ b/libnd4j/include/ops/declarable/helpers/svd.h @@ -22,19 +22,22 @@ #define LIBND4J_SVD_HELPER_H #include + #include "array/NDArray.h" -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -// svd operation, this function is not method of SVD class, it is standalone function -void svd(sd::LaunchContext* context, const NDArray* x, const std::vector& outArrs, const bool fullUV, const bool calcUV, const int switchNum); - +// svd operation, this function is not method of SVD class, it is standalone +// function +void svd(sd::LaunchContext* context, const NDArray* x, + const std::vector& outArrs, const bool fullUV, + const bool calcUV, const int switchNum); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_SVD_HELPER_H +#endif // LIBND4J_SVD_HELPER_H diff --git a/libnd4j/include/ops/declarable/helpers/threshold.h b/libnd4j/include/ops/declarable/helpers/threshold.h index 21ac0c8208a0..f77e0efb057c 100644 --- a/libnd4j/include/ops/declarable/helpers/threshold.h +++ b/libnd4j/include/ops/declarable/helpers/threshold.h @@ -24,14 +24,14 @@ #include namespace sd { - namespace ops { - namespace helpers { - int32_t thresholdEstimate(const NDArray &updates, float threshold); +namespace ops { +namespace helpers { +int32_t thresholdEstimate(const NDArray &updates, float threshold); - void thresholdEncode(NDArray &updates, NDArray &encoded, float threshold); - void thresholdDecode(const NDArray &encoded, NDArray &updates); - } - } -} +void thresholdEncode(NDArray &updates, NDArray &encoded, float threshold); +void thresholdDecode(const NDArray &encoded, NDArray &updates); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //SD_THRESHOLD_H +#endif // SD_THRESHOLD_H diff --git a/libnd4j/include/ops/declarable/helpers/toggle_bits.h b/libnd4j/include/ops/declarable/helpers/toggle_bits.h index 5c30765dd22c..763c494e3784 100644 --- a/libnd4j/include/ops/declarable/helpers/toggle_bits.h +++ b/libnd4j/include/ops/declarable/helpers/toggle_bits.h @@ -24,14 +24,15 @@ #define SD_TOGGLE_BITS_H namespace sd { - namespace ops { - namespace helpers { - template - static void toggle_bits__(sd::LaunchContext * context, NDArray& in, NDArray& out); +namespace ops { +namespace helpers { +template +static void toggle_bits__(sd::LaunchContext* context, NDArray& in, + NDArray& out); - void __toggle_bits(sd::LaunchContext * context, NDArray& in, NDArray& out); - } - } -} +void __toggle_bits(sd::LaunchContext* context, NDArray& in, NDArray& out); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //SD_TOGGLE_BITS_H +#endif // SD_TOGGLE_BITS_H diff --git a/libnd4j/include/ops/declarable/helpers/top_k.h b/libnd4j/include/ops/declarable/helpers/top_k.h index 6a459f925308..4c93d4d5bf85 100644 --- a/libnd4j/include/ops/declarable/helpers/top_k.h +++ b/libnd4j/include/ops/declarable/helpers/top_k.h @@ -19,18 +19,20 @@ // #ifndef __TOP_K_HELPERS__ #define __TOP_K_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - int topKFunctor(sd::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort); +int topKFunctor(sd::LaunchContext* context, const NDArray* input, + NDArray* values, NDArray* indices, const uint k, bool needSort); - int inTopKFunctor(sd::LaunchContext * context, const NDArray* predictions, const NDArray* targets, NDArray* output, const uint k); +int inTopKFunctor(sd::LaunchContext* context, const NDArray* predictions, + const NDArray* targets, NDArray* output, const uint k); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/transforms.h b/libnd4j/include/ops/declarable/helpers/transforms.h index 6ebecd8f7fdb..202b71be2186 100644 --- a/libnd4j/include/ops/declarable/helpers/transforms.h +++ b/libnd4j/include/ops/declarable/helpers/transforms.h @@ -21,70 +21,105 @@ #ifndef LIBND4J_TRANSFORMS_H #define LIBND4J_TRANSFORMS_H -#include -#include #include +#include +#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace helpers { - void triuBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal); - - void trace(sd::LaunchContext * context, const NDArray& input, NDArray& output); +void triuBP(sd::LaunchContext* context, const NDArray& input, + const NDArray& gradO, NDArray& gradI, const int diagonal); - void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace); +void trace(sd::LaunchContext* context, const NDArray& input, NDArray& output); - // auxiliary function which serves for recursion purpose and is used in pad operation - // void recursiveLoopForPad(const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector dimensions, int dim, int inIdx, int outIdx, NDArray& padValue); +void randomShuffle(sd::LaunchContext* context, NDArray& input, NDArray& output, + sd::graph::RandomGenerator& rng, const bool isInplace); - void pad(sd::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, NDArray const& padValue); +// auxiliary function which serves for recursion purpose and is used in pad +// operation void recursiveLoopForPad(const int mode, NDArray& input, const +// NDArray& paddings, NDArray& output, std::vector dimensions, int dim, int +// inIdx, int outIdx, NDArray& padValue); - void invertPermutation(sd::LaunchContext * context, const NDArray& input, NDArray& output); +void pad(sd::LaunchContext* context, const int mode, const NDArray& input, + const NDArray& paddings, NDArray& output, NDArray const& padValue); - void gatherND(sd::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output); +void invertPermutation(sd::LaunchContext* context, const NDArray& input, + NDArray& output); - void gather(sd::LaunchContext * context, NDArray* input, const NDArray* indices, NDArray* output, const std::vector& intArgs); +void gatherND(sd::LaunchContext* context, NDArray& input, NDArray& indices, + NDArray& output); - void eye(sd::LaunchContext * context, NDArray& output); +void gather(sd::LaunchContext* context, NDArray* input, const NDArray* indices, + NDArray* output, const std::vector& intArgs); - void scatterUpdate(sd::LaunchContext * context, NDArray& operand, NDArray& updates, const std::vector* intArgs); +void eye(sd::LaunchContext* context, NDArray& output); - void scatterSimple(sd::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector& dimensions); +void scatterUpdate(sd::LaunchContext* context, NDArray& operand, + NDArray& updates, const std::vector* intArgs); - void mergeMaxIndex(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output); +void scatterSimple(sd::LaunchContext* context, const int opId, NDArray& input, + const NDArray& updates, const NDArray& indices, + const std::vector& dimensions); - void mergeMax(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output); - void mergeMaxBp(sd::LaunchContext* context, const std::vector& inArrs, std::vector& outArrs); +void mergeMaxIndex(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output); - void mergeAvg(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output); - void mergeAvgBp(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs); +void mergeMax(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output); +void mergeMaxBp(sd::LaunchContext* context, + const std::vector& inArrs, + std::vector& outArrs); - void mergeAdd(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output); - void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs); +void mergeAvg(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output); +void mergeAvgBp(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs); - void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace); - void clipByGlobalNorm(sd::LaunchContext * context, std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace); +void mergeAdd(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output); +void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, + std::vector& outArrs); - void clipByNormBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector& dimensions, const NDArray& clipNorm); +void clipByNorm(sd::LaunchContext* context, NDArray& input, NDArray& output, + const std::vector& dimensions, const NDArray& clipNorm, + const bool isInplace); +void clipByGlobalNorm(sd::LaunchContext* context, + std::vector const& inputs, double clipNorm, + sd::memory::Workspace* workspace, + std::vector& outputs, bool isInplace); - void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace); - void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output); +void clipByNormBP(sd::LaunchContext* context, const NDArray& input, + const NDArray& gradO, NDArray& gradI /*output*/, + const std::vector& dimensions, const NDArray& clipNorm); - void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode); +void clipByAveraged(sd::LaunchContext* context, NDArray& input, NDArray& output, + const std::vector& dimensions, const NDArray& clipNorm, + const bool isInplace); +void clipByValue(sd::LaunchContext* context, NDArray& input, double leftBound, + double rightBound, NDArray& output); - void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output); +void mirrorPad(sd::LaunchContext* context, const NDArray& input, + const NDArray& paddings, NDArray& output, const int mode); - void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode); +void clipByValue(sd::LaunchContext* context, NDArray& input, double leftBound, + double rightBound, NDArray& output); - void concat(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int axis); +void mirrorPad(sd::LaunchContext* context, const NDArray& input, + const NDArray& paddings, NDArray& output, const int mode); - void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector reps); +void concat(sd::LaunchContext* context, + const std::vector& inArrs, NDArray& output, + const int axis); - void split(sd::LaunchContext* context, const NDArray& input, std::vector& outArrs, const int axis); -} -} -} +void tileBP(sd::LaunchContext* context, const NDArray& gradO /*input*/, + NDArray& gradI /*output*/, const std::vector reps); +void split(sd::LaunchContext* context, const NDArray& input, + std::vector& outArrs, const int axis); +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_TRANSFORMS_H +#endif // LIBND4J_TRANSFORMS_H diff --git a/libnd4j/include/ops/declarable/helpers/triangular_solve.h b/libnd4j/include/ops/declarable/helpers/triangular_solve.h index 73965f8c5681..8f43bd3279e4 100644 --- a/libnd4j/include/ops/declarable/helpers/triangular_solve.h +++ b/libnd4j/include/ops/declarable/helpers/triangular_solve.h @@ -19,16 +19,19 @@ // #ifndef __TRIANGULAR_SOLVE__H_HELPERS__ #define __TRIANGULAR_SOLVE__H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - int triangularSolveFunctor(sd::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output); - void adjointMatrix(sd::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output); -} -} -} +int triangularSolveFunctor(sd::LaunchContext* context, NDArray* leftInput, + NDArray* rightInput, bool lower, bool adjoint, + NDArray* output); +void adjointMatrix(sd::LaunchContext* context, NDArray const* input, + bool const lower, NDArray* output); +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/unique.h b/libnd4j/include/ops/declarable/helpers/unique.h index 8898be585cd9..74071451df16 100644 --- a/libnd4j/include/ops/declarable/helpers/unique.h +++ b/libnd4j/include/ops/declarable/helpers/unique.h @@ -20,18 +20,19 @@ #ifndef __UNIQUE_H_HELPERS__ #define __UNIQUE_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - Nd4jLong uniqueCount(sd::LaunchContext * context, NDArray* input); +Nd4jLong uniqueCount(sd::LaunchContext* context, NDArray* input); - Nd4jStatus uniqueFunctor(sd::LaunchContext * context, NDArray* input, NDArray* values, NDArray* indices, NDArray* counts); +Nd4jStatus uniqueFunctor(sd::LaunchContext* context, NDArray* input, + NDArray* values, NDArray* indices, NDArray* counts); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/updatersHelpers.h b/libnd4j/include/ops/declarable/helpers/updatersHelpers.h index 5bd89b487484..4bff0431152f 100644 --- a/libnd4j/include/ops/declarable/helpers/updatersHelpers.h +++ b/libnd4j/include/ops/declarable/helpers/updatersHelpers.h @@ -21,24 +21,52 @@ #ifndef LIBND4J_UPDATER_RMS_PROM_H #define LIBND4J_UPDATER_RMS_PROM_H -#include #include +#include namespace sd { namespace ops { namespace helpers { - void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, const double dLr, const double dRmsDecay, const double dEpsilon); - void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon); - void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double bMomentum); - void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); - void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); - void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon); - void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); - void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); +void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, NDArray& stateG, + const double dLr, const double dRmsDecay, + const double dEpsilon); +void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, NDArray& stateH, + const double dLr, const double dEpsilon); +void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initState, NDArray& update, + NDArray& stateV, const double dLr, + const double bMomentum); +void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration); +void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration); +void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateMsg, const NDArray& initStateMsdx, + NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, + const double dRho, const double dEpsilon); +void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateV, const NDArray& initStateM, + NDArray& update, NDArray& stateV, NDArray& stateM, + const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration); +void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, + const NDArray& initStateV, const NDArray& initStateM, + const NDArray& initStateH, NDArray& update, NDArray& stateV, + NDArray& stateM, NDArray& stateH, const double dLr, + const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration); -} -} -} +} // namespace helpers +} // namespace ops +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/weights.h b/libnd4j/include/ops/declarable/helpers/weights.h index 66246b641049..88f293601cf4 100644 --- a/libnd4j/include/ops/declarable/helpers/weights.h +++ b/libnd4j/include/ops/declarable/helpers/weights.h @@ -19,16 +19,17 @@ // #ifndef __WEIGHTS_H_HELPERS__ #define __WEIGHTS_H_HELPERS__ -#include #include +#include namespace sd { namespace ops { namespace helpers { - void adjustWeights(sd::LaunchContext * context, NDArray* input, NDArray* weights, NDArray* output, int minLength, int maxLength); +void adjustWeights(sd::LaunchContext* context, NDArray* input, NDArray* weights, + NDArray* output, int minLength, int maxLength); } -} -} +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/helpers/where.h b/libnd4j/include/ops/declarable/helpers/where.h index 284a365f8730..1538c61e8309 100644 --- a/libnd4j/include/ops/declarable/helpers/where.h +++ b/libnd4j/include/ops/declarable/helpers/where.h @@ -24,11 +24,12 @@ #include namespace sd { - namespace ops { - namespace helpers { - void _where(sd::LaunchContext * context, NDArray &condition, NDArray& output, memory::Workspace *workspace); - } - } +namespace ops { +namespace helpers { +void _where(sd::LaunchContext *context, NDArray &condition, NDArray &output, + memory::Workspace *workspace); } +} // namespace ops +} // namespace sd -#endif //SD_WHERE_H +#endif // SD_WHERE_H diff --git a/libnd4j/include/ops/declarable/helpers/zeta.h b/libnd4j/include/ops/declarable/helpers/zeta.h index 7aee45f1c1ef..c3bf280bf66f 100644 --- a/libnd4j/include/ops/declarable/helpers/zeta.h +++ b/libnd4j/include/ops/declarable/helpers/zeta.h @@ -22,79 +22,84 @@ #define LIBND4J_ZETA_H #include + #include "array/NDArray.h" namespace sd { namespace ops { namespace helpers { - - // calculate the Hurwitz zeta function for arrays - void zeta(sd::LaunchContext * context, const NDArray& x, const NDArray& q, NDArray& output); - - - - // calculate the Hurwitz zeta function for scalars - // fast implementation, it is based on Euler-Maclaurin summation formula - template - _CUDA_HD T zetaScalar(const T x, const T q) { - - const T machep = 1.11022302462515654042e-16; - - // FIXME: @raver119 - // expansion coeffZetaicients for Euler-Maclaurin summation formula (2k)! / B2k, where B2k are Bernoulli numbers - const T coeffZeta[] = { 12.0,-720.0,30240.0,-1209600.0,47900160.0,-1.8924375803183791606e9,7.47242496e10,-2.950130727918164224e12, 1.1646782814350067249e14, -4.5979787224074726105e15, 1.8152105401943546773e17, -7.1661652561756670113e18}; - - // if (x <= (T)1.) - // throw("zeta function: x must be > 1 !"); - - // if (q <= (T)0.) - // throw("zeta function: q must be > 0 !"); - - T a, b(0.), k, s, t, w; - - s = math::nd4j_pow(q, -x); - a = q; - int i = 0; - - while(i < 9 || a <= (T)9.) { - i += 1; - a += (T)1.0; - b = math::nd4j_pow(a, -x); - s += b; - if(math::nd4j_abs(b / s) < (T)machep) - return s; - } - - w = a; - s += b * (w / (x - (T)1.) - (T)0.5); - a = (T)1.; - k = (T)0.; - - for(i = 0; i < 12; ++i) { - a *= x + k; - b /= w; - t = a * b / coeffZeta[i]; - s += t; - t = math::nd4j_abs(t / s); - - if(t < (T)machep) - return s; - - k += (T)1.f; - a *= x + k; - b /= w; - k += (T)1.f; - } - - return s; - } - - - -} -} +// calculate the Hurwitz zeta function for arrays +void zeta(sd::LaunchContext* context, const NDArray& x, const NDArray& q, + NDArray& output); + +// calculate the Hurwitz zeta function for scalars +// fast implementation, it is based on Euler-Maclaurin summation formula +template +_CUDA_HD T zetaScalar(const T x, const T q) { + const T machep = 1.11022302462515654042e-16; + + // FIXME: @raver119 + // expansion coeffZetaicients for Euler-Maclaurin summation formula (2k)! / + // B2k, where B2k are Bernoulli numbers + const T coeffZeta[] = {12.0, + -720.0, + 30240.0, + -1209600.0, + 47900160.0, + -1.8924375803183791606e9, + 7.47242496e10, + -2.950130727918164224e12, + 1.1646782814350067249e14, + -4.5979787224074726105e15, + 1.8152105401943546773e17, + -7.1661652561756670113e18}; + + // if (x <= (T)1.) + // throw("zeta function: x must be > 1 !"); + + // if (q <= (T)0.) + // throw("zeta function: q must be > 0 !"); + + T a, b(0.), k, s, t, w; + + s = math::nd4j_pow(q, -x); + a = q; + int i = 0; + + while (i < 9 || a <= (T)9.) { + i += 1; + a += (T)1.0; + b = math::nd4j_pow(a, -x); + s += b; + if (math::nd4j_abs(b / s) < (T)machep) return s; + } + + w = a; + s += b * (w / (x - (T)1.) - (T)0.5); + a = (T)1.; + k = (T)0.; + + for (i = 0; i < 12; ++i) { + a *= x + k; + b /= w; + t = a * b / coeffZeta[i]; + s += t; + t = math::nd4j_abs(t / s); + + if (t < (T)machep) return s; + + k += (T)1.f; + a *= x + k; + b /= w; + k += (T)1.f; + } + + return s; } +} // namespace helpers +} // namespace ops +} // namespace sd -#endif //LIBND4J_ZETA_H +#endif // LIBND4J_ZETA_H diff --git a/libnd4j/include/ops/declarable/impl/BooleanOp.cpp b/libnd4j/include/ops/declarable/impl/BooleanOp.cpp index f584f803cdcb..5eb4c8c2d33d 100644 --- a/libnd4j/include/ops/declarable/impl/BooleanOp.cpp +++ b/libnd4j/include/ops/declarable/impl/BooleanOp.cpp @@ -19,117 +19,126 @@ // #include "ops/declarable/BooleanOp.h" -#include -#include -#include - -namespace sd { - namespace ops { - BooleanOp::BooleanOp(const char *name, int numInputs, bool scalar) : DeclarableOp::DeclarableOp(name, numInputs, scalar) { - // - } - /** - * Output shape of any BooleanOp is ALWAYS scalar - */ - ShapeList *BooleanOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL)); - } - - bool BooleanOp::verify(sd::graph::Context &block) { - // check if scalar or not +#include - // validation? +#include +#include - Nd4jStatus status = this->validateNonEmptyInput(block); - if (status != ND4J_STATUS_OK) { - nd4j_printf("Inputs should be not empty for BooleanOps",""); - throw std::runtime_error("Bad inputs"); - } +namespace sd { +namespace ops { +BooleanOp::BooleanOp(const char *name, int numInputs, bool scalar) + : DeclarableOp::DeclarableOp(name, numInputs, scalar) { + // +} - status = this->validateAndExecute(block); - if (status == ND4J_STATUS_TRUE) - return true; - else if (status == ND4J_STATUS_FALSE) - return false; - else { - nd4j_printf("Got error %i during [%s] evaluation: ", (int) status, this->getOpDescriptor()->getOpName()->c_str()); - throw std::runtime_error("Internal error"); - } - } +/** + * Output shape of any BooleanOp is ALWAYS scalar + */ +ShapeList *BooleanOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + return SHAPELIST( + ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL)); +} - bool BooleanOp::prepareOutputs(Context& ctx) { +bool BooleanOp::verify(sd::graph::Context &block) { + // check if scalar or not + + // validation? + + Nd4jStatus status = this->validateNonEmptyInput(block); + if (status != ND4J_STATUS_OK) { + nd4j_printf("Inputs should be not empty for BooleanOps", ""); + throw std::runtime_error("Bad inputs"); + } + + status = this->validateAndExecute(block); + if (status == ND4J_STATUS_TRUE) + return true; + else if (status == ND4J_STATUS_FALSE) + return false; + else { + nd4j_printf("Got error %i during [%s] evaluation: ", (int)status, + this->getOpDescriptor()->getOpName()->c_str()); + throw std::runtime_error("Internal error"); + } +} - auto variableSpace = ctx.getVariableSpace(); - if (ctx.isFastPath()) - return true; +bool BooleanOp::prepareOutputs(Context &ctx) { + auto variableSpace = ctx.getVariableSpace(); + if (ctx.isFastPath()) return true; - for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { - std::pair pair(ctx.nodeId(), e); + for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { + std::pair pair(ctx.nodeId(), e); - if (!variableSpace->hasVariable(pair)) - variableSpace->putVariable(pair, std::make_shared()); + if (!variableSpace->hasVariable(pair)) + variableSpace->putVariable(pair, std::make_shared()); - auto var = ctx.variable(pair); + auto var = ctx.variable(pair); - if (!var->hasNDArray()) { - var->setNDArray(std::make_shared(NDArrayFactory::create(false, ctx.launchContext()))); - } - } + if (!var->hasNDArray()) { + var->setNDArray(std::make_shared( + NDArrayFactory::create(false, ctx.launchContext()))); + } + } - return true; - } + return true; +} - Nd4jStatus sd::ops::BooleanOp::execute(Context* block) { +Nd4jStatus sd::ops::BooleanOp::execute(Context *block) { + // basic validation: ensure inputs are set + REQUIRE_OK(this->validateNonEmptyInput(*block)); - // basic validation: ensure inputs are set - REQUIRE_OK(this->validateNonEmptyInput(*block)); + // ensure number of IArgs, TArgs match our expectations + REQUIRE_OK(this->validateArguments(*block)); - // ensure number of IArgs, TArgs match our expectations - REQUIRE_OK(this->validateArguments(*block)); + // this method will allocate output NDArrays for this op + this->prepareOutputs(*block); - // this method will allocate output NDArrays for this op - this->prepareOutputs(*block); + auto timeStart = std::chrono::system_clock::now(); - auto timeStart = std::chrono::system_clock::now(); + Nd4jStatus status = this->validateAndExecute(*block); - Nd4jStatus status = this->validateAndExecute(*block); + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = + std::chrono::duration_cast(timeEnd - timeStart) + .count(); + block->setInnerTime(outerTime); - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - block->setInnerTime(outerTime); + // basically we're should be putting 0.0 as FALSE, and any non-0.0 value will + // be treated as TRUE + std::pair p(block->nodeId(), 0); + auto var = block->isFastPath() ? block->fastpath_out()[0] + : block->variable(p)->getNDArray(); + var->p(Nd4jLong(0), status == ND4J_STATUS_TRUE ? 1.0f : 0.0f); - // basically we're should be putting 0.0 as FALSE, and any non-0.0 value will be treated as TRUE - std::pair p(block->nodeId(), 0); - auto var = block->isFastPath() ? block->fastpath_out()[0] : block->variable(p)->getNDArray(); - var->p(Nd4jLong(0), status == ND4J_STATUS_TRUE ? 1.0f : 0.0f); + // for CPU backend that's nop, but for CUDA-like archs this will update + // special buffer + var->syncToDevice(); - // for CPU backend that's nop, but for CUDA-like archs this will update special buffer - var->syncToDevice(); + if (status == ND4J_STATUS_FALSE || status == ND4J_STATUS_TRUE) + return ND4J_STATUS_OK; - if (status == ND4J_STATUS_FALSE || status == ND4J_STATUS_TRUE) - return ND4J_STATUS_OK; - - nd4j_printf("%s: node_%i got unexpected result instead of boolean: [%i]\n", this->getOpName().c_str(), block->nodeId(), status); - return ND4J_STATUS_KERNEL_FAILURE; - } + nd4j_printf("%s: node_%i got unexpected result instead of boolean: [%i]\n", + this->getOpName().c_str(), block->nodeId(), status); + return ND4J_STATUS_KERNEL_FAILURE; +} - bool BooleanOp::verify(const std::vector &args) { - VariableSpace variableSpace; +bool BooleanOp::verify(const std::vector &args) { + VariableSpace variableSpace; - int cnt = -1; - std::vector in; - for (auto v: args) { - auto var = std::make_shared(*v, "", cnt); - in.emplace_back(cnt); - variableSpace.putVariable(cnt--, var); - } + int cnt = -1; + std::vector in; + for (auto v : args) { + auto var = std::make_shared(*v, "", cnt); + in.emplace_back(cnt); + variableSpace.putVariable(cnt--, var); + } - Context block(1, &variableSpace, false); - block.fillInputs(in); + Context block(1, &variableSpace, false); + block.fillInputs(in); - return this->verify(block); - } - } + return this->verify(block); } - +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/impl/BroadcastableBoolOp.cpp b/libnd4j/include/ops/declarable/impl/BroadcastableBoolOp.cpp index 8f0a6dcb88c1..7948bf12ecd1 100644 --- a/libnd4j/include/ops/declarable/impl/BroadcastableBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/BroadcastableBoolOp.cpp @@ -18,55 +18,69 @@ // Created by raver on 6/6/2018. // +#include +#include #include #include -#include -#include namespace sd { - namespace ops { - BroadcastableBoolOp::BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs) : DeclarableCustomOp::DeclarableCustomOp(2, 1, name, false, numTArgs, numIArgs) { - // - } +namespace ops { +BroadcastableBoolOp::BroadcastableBoolOp(const char *name, int numTArgs, + int numIArgs) + : DeclarableCustomOp::DeclarableCustomOp(2, 1, name, false, numTArgs, + numIArgs) { + // +} - ShapeList *BroadcastableBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto shapeList = SHAPELIST(); - auto x = inputShape->at(0); - auto y = inputShape->at(1); - sd::DataType dtype = sd::DataType::BOOL; +ShapeList *BroadcastableBoolOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto shapeList = SHAPELIST(); + auto x = inputShape->at(0); + auto y = inputShape->at(1); + sd::DataType dtype = sd::DataType::BOOL; - if(shape::isEmpty(x) || shape::isEmpty(y)) { - // this is edge case, [3, 4] + [] = [] - if ((shape::isEmpty(x) && shape::rank(x) == 0) || (shape::isEmpty(y) && shape::rank(y) == 0)) { - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor::emptyDescriptor(dtype))); - return shapeList; - } - - const Nd4jLong *newshape = nullptr; - ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newshape, dtype))); - } else if (shape::isScalar(x) && shape::isScalar(y)) { - if (shape::rank(x) >= shape::rank(y)) { - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype))); - } else { - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(y, dtype))); - } - } else if (shape::equalsSoft(x, y)) { - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype))); - } else if (shape::isScalar(x) && !shape::isScalar(y)) { - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(y, dtype))); - } else if (!shape::isScalar(x) && shape::isScalar(y)) { - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype))); - } else if (ShapeUtils::areShapesBroadcastable(x, y)) { - const Nd4jLong *newshape = nullptr; - ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newshape, dtype))); - } else { - // in this case we'll throw exception later - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype))); - } + if (shape::isEmpty(x) || shape::isEmpty(y)) { + // this is edge case, [3, 4] + [] = [] + if ((shape::isEmpty(x) && shape::rank(x) == 0) || + (shape::isEmpty(y) && shape::rank(y) == 0)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor::emptyDescriptor(dtype))); + return shapeList; + } - return shapeList; - } + const Nd4jLong *newshape = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(newshape, dtype))); + } else if (shape::isScalar(x) && shape::isScalar(y)) { + if (shape::rank(x) >= shape::rank(y)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(x, dtype))); + } else { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(y, dtype))); } -} \ No newline at end of file + } else if (shape::equalsSoft(x, y)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(x, dtype))); + } else if (shape::isScalar(x) && !shape::isScalar(y)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(y, dtype))); + } else if (!shape::isScalar(x) && shape::isScalar(y)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(x, dtype))); + } else if (ShapeUtils::areShapesBroadcastable(x, y)) { + const Nd4jLong *newshape = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(newshape, dtype))); + } else { + // in this case we'll throw exception later + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(x, dtype))); + } + + return shapeList; +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp b/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp index c567a27d7b88..9838a1daec7c 100644 --- a/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp @@ -18,69 +18,81 @@ // Created by raver on 6/6/2018. // +#include +#include #include #include -#include -#include namespace sd { - namespace ops { - BroadcastableOp::BroadcastableOp(const char *name, int numTArgs, int numIArgs) : DeclarableCustomOp::DeclarableCustomOp(2, 1, name, false, numTArgs, numIArgs) { - // - } - - ShapeList *BroadcastableOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto shapeList = SHAPELIST(); - auto x = inputShape->at(0); - auto y = inputShape->at(1); - auto outputs = _descriptor->getOutputTypesForOutput(0); - sd::DataType dtype; - if (!(outputs.size() == 1 && outputs[0] == sd::DataType::BOOL)) { - if (Environment::getInstance()->isExperimentalBuild()) { - if (shape::length(y) > shape::length(x)) { - dtype = DataTypeUtils::pickPairwiseResultType(y, x); - } else { - dtype = DataTypeUtils::pickPairwiseResultType(x, y); - } - } else { - dtype = ArrayOptions::dataType(x); - } - } else - dtype = sd::DataType::BOOL; - - if(shape::isEmpty(x) || shape::isEmpty(y)) { - // this is edge case, [3, 4] + [] = [] - if ((shape::isEmpty(x) && shape::rank(x) == 0) || (shape::isEmpty(y) && shape::rank(y) == 0)) { - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor::emptyDescriptor(dtype))); - return shapeList; - } +namespace ops { +BroadcastableOp::BroadcastableOp(const char *name, int numTArgs, int numIArgs) + : DeclarableCustomOp::DeclarableCustomOp(2, 1, name, false, numTArgs, + numIArgs) { + // +} +ShapeList *BroadcastableOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto shapeList = SHAPELIST(); + auto x = inputShape->at(0); + auto y = inputShape->at(1); + auto outputs = _descriptor->getOutputTypesForOutput(0); + sd::DataType dtype; + if (!(outputs.size() == 1 && outputs[0] == sd::DataType::BOOL)) { + if (Environment::getInstance()->isExperimentalBuild()) { + if (shape::length(y) > shape::length(x)) { + dtype = DataTypeUtils::pickPairwiseResultType(y, x); + } else { + dtype = DataTypeUtils::pickPairwiseResultType(x, y); + } + } else { + dtype = ArrayOptions::dataType(x); + } + } else + dtype = sd::DataType::BOOL; - const Nd4jLong *newshape = nullptr; - ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newshape, dtype))); - } else if (shape::isScalar(x) && shape::isScalar(y)) { - if (shape::rank(x) >= shape::rank(y)) { - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype))); - } else { - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(y, dtype))); - } - } else if (shape::equalsSoft(x, y)) { - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype))); - } else if (shape::isScalar(x) && !shape::isScalar(y)) { - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(y, dtype))); - } else if (!shape::isScalar(x) && shape::isScalar(y)) { - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype))); - } else if (ShapeUtils::areShapesBroadcastable(x, y)) { - const Nd4jLong *newshape = nullptr; - ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newshape, dtype))); - } else { - // in this case we'll throw exception later - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype))); - } + if (shape::isEmpty(x) || shape::isEmpty(y)) { + // this is edge case, [3, 4] + [] = [] + if ((shape::isEmpty(x) && shape::rank(x) == 0) || + (shape::isEmpty(y) && shape::rank(y) == 0)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor::emptyDescriptor(dtype))); + return shapeList; + } - return shapeList; - } + const Nd4jLong *newshape = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(newshape, dtype))); + } else if (shape::isScalar(x) && shape::isScalar(y)) { + if (shape::rank(x) >= shape::rank(y)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(x, dtype))); + } else { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(y, dtype))); } -} \ No newline at end of file + } else if (shape::equalsSoft(x, y)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(x, dtype))); + } else if (shape::isScalar(x) && !shape::isScalar(y)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(y, dtype))); + } else if (!shape::isScalar(x) && shape::isScalar(y)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(x, dtype))); + } else if (ShapeUtils::areShapesBroadcastable(x, y)) { + const Nd4jLong *newshape = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(newshape, dtype))); + } else { + // in this case we'll throw exception later + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(x, dtype))); + } + + return shapeList; +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/DeclarableCustomOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableCustomOp.cpp index d6227af0cd91..4169b055eca7 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableCustomOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableCustomOp.cpp @@ -22,9 +22,13 @@ #include namespace sd { - namespace ops { - DeclarableCustomOp::DeclarableCustomOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) : sd::ops::DeclarableOp(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs) { - // - } - } -} \ No newline at end of file +namespace ops { +DeclarableCustomOp::DeclarableCustomOp(int numInputs, int numOutputs, + const char *opName, bool allowsInplace, + int tArgs, int iArgs) + : sd::ops::DeclarableOp(numInputs, numOutputs, opName, allowsInplace, tArgs, + iArgs) { + // +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp index 64f7b09cff74..8dbf5275abeb 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableListOp.cpp @@ -18,130 +18,137 @@ // @author raver119@gmail.com // -#include -#include #include #include #include +#include +#include namespace sd { - namespace ops { - DeclarableListOp::DeclarableListOp(int numInputs, int numOutputs, const char* opName, int tArgs, int iArgs) : DeclarableOp::DeclarableOp(numInputs, numOutputs, opName, false, tArgs, iArgs) { - // This kind of operations work with sets: NDArrayList - this->getOpDescriptor()->setInputType(InputType_NUMERIC_SET); - } +namespace ops { +DeclarableListOp::DeclarableListOp(int numInputs, int numOutputs, + const char* opName, int tArgs, int iArgs) + : DeclarableOp::DeclarableOp(numInputs, numOutputs, opName, false, tArgs, + iArgs) { + // This kind of operations work with sets: NDArrayList + this->getOpDescriptor()->setInputType(InputType_NUMERIC_SET); +} /* template void DeclarableListOp::execute(Block& block) { // } */ - /** - * This method just outputs scalar buffer - * - * @tparam T - * @param inputShape - * @param block - * @return - */ - ShapeList* DeclarableListOp::calculateOutputShape(ShapeList* inputShape, sd::graph::Context& block) { - // TODO: ensure this method isn't ever called - - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(DataType::FLOAT32, 'c', {1, 1}); - return SHAPELIST(newShape); - } - - sd::NDArray* sd::ops::DeclarableListOp::getZ(Context& block, int inputId) { - //nd4j_printf("wow\n",""); - return nullptr; - } - - void DeclarableListOp::setupResult(const NDArray &array, Context& block) { - block.pushNDArrayToVariableSpace(block.getNodeId(), 0, array); - } - - void DeclarableListOp::setupResultList(const NDArrayList &arrayList, Context& block) { - block.pushNDArrayListToVariableSpace(block.getNodeId(), 0, arrayList); - } - - - Nd4jStatus DeclarableListOp::execute(Context* block) { - if (block == nullptr) - throw std::invalid_argument("Block is NULL"); - - nd4j_debug("Executing list op: [%s]\n", this->getOpName().c_str()); - - // ensure number of IArgs, TArgs match our expectations - REQUIRE_OK(this->validateArguments(*block)); - - // we shouldn't call for this in ListOp - //this->prepareOutputs(*block); - - auto timeStart = std::chrono::system_clock::now(); - - Nd4jStatus status = this->validateAndExecute(*block); - - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - block->setInnerTime(outerTime); - - return status; - } - - ResultSet DeclarableListOp::execute(const NDArrayList &list, const std::vector& inputs, const std::vector& tArgs, const std::vector& iArgs) { - VariableSpace varSpace; - int nodeId = 119; - - // should be never used in practice, since in-graph NDArrayList should have id set - int cnt = -1; - std::vector in; - - // first input must be our NDArrayList, except create_list op. it creates list itself. - if (getOpName() != "create_list") { - auto listVar = std::make_shared(list, "", cnt); - varSpace.putVariable(cnt, listVar); - in.push_back(cnt--); - } - - for (auto v: inputs) { - auto var = std::make_shared(*v, "", cnt); - in.push_back(cnt); - varSpace.putVariable(cnt--, var); - } - - Context block(1, &varSpace, false); - block.fillInputs(in); - - for (int e = 0; e < tArgs.size(); e++) - block.appendT(tArgs.at(e)); - - - for (int e = 0; e < iArgs.size(); e++) - block.appendI(iArgs.at(e)); - - - Nd4jStatus result = this->validateAndExecute(block); - ResultSet res; - res.setStatus(result); - - for (int e = 0; e < DataTypeUtils::max(); e++) { - std::pair pair(1, e); - if (varSpace.hasVariable(pair)) { - auto var = varSpace.getVariable(pair); - if (var->hasNDArray()) { - auto arr = var->getNDArray(); - if (arr->isAttached()) { - res.push_back(arr->detach()); - } else { - var->markRemovable(false); - res.push_back(*arr); - } - } - } else - break; - } - - return res; +/** + * This method just outputs scalar buffer + * + * @tparam T + * @param inputShape + * @param block + * @return + */ +ShapeList* DeclarableListOp::calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) { + // TODO: ensure this method isn't ever called + + auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + DataType::FLOAT32, 'c', {1, 1}); + return SHAPELIST(newShape); +} + +sd::NDArray* sd::ops::DeclarableListOp::getZ(Context& block, int inputId) { + // nd4j_printf("wow\n",""); + return nullptr; +} + +void DeclarableListOp::setupResult(const NDArray& array, Context& block) { + block.pushNDArrayToVariableSpace(block.getNodeId(), 0, array); +} + +void DeclarableListOp::setupResultList(const NDArrayList& arrayList, + Context& block) { + block.pushNDArrayListToVariableSpace(block.getNodeId(), 0, arrayList); +} + +Nd4jStatus DeclarableListOp::execute(Context* block) { + if (block == nullptr) throw std::invalid_argument("Block is NULL"); + + nd4j_debug("Executing list op: [%s]\n", this->getOpName().c_str()); + + // ensure number of IArgs, TArgs match our expectations + REQUIRE_OK(this->validateArguments(*block)); + + // we shouldn't call for this in ListOp + // this->prepareOutputs(*block); + + auto timeStart = std::chrono::system_clock::now(); + + Nd4jStatus status = this->validateAndExecute(*block); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = + std::chrono::duration_cast(timeEnd - timeStart) + .count(); + block->setInnerTime(outerTime); + + return status; +} + +ResultSet DeclarableListOp::execute(const NDArrayList& list, + const std::vector& inputs, + const std::vector& tArgs, + const std::vector& iArgs) { + VariableSpace varSpace; + int nodeId = 119; + + // should be never used in practice, since in-graph NDArrayList should have id + // set + int cnt = -1; + std::vector in; + + // first input must be our NDArrayList, except create_list op. it creates list + // itself. + if (getOpName() != "create_list") { + auto listVar = std::make_shared(list, "", cnt); + varSpace.putVariable(cnt, listVar); + in.push_back(cnt--); + } + + for (auto v : inputs) { + auto var = std::make_shared(*v, "", cnt); + in.push_back(cnt); + varSpace.putVariable(cnt--, var); + } + + Context block(1, &varSpace, false); + block.fillInputs(in); + + for (int e = 0; e < tArgs.size(); e++) block.appendT(tArgs.at(e)); + + for (int e = 0; e < iArgs.size(); e++) block.appendI(iArgs.at(e)); + + Nd4jStatus result = this->validateAndExecute(block); + ResultSet res; + res.setStatus(result); + + for (int e = 0; e < DataTypeUtils::max(); e++) { + std::pair pair(1, e); + if (varSpace.hasVariable(pair)) { + auto var = varSpace.getVariable(pair); + if (var->hasNDArray()) { + auto arr = var->getNDArray(); + if (arr->isAttached()) { + res.push_back(arr->detach()); + } else { + var->markRemovable(false); + res.push_back(*arr); } - } -} \ No newline at end of file + } + } else + break; + } + + return res; +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 6eb014d7edca..27323c1d65f2 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -18,1136 +18,1216 @@ // @author raver119@gmail.com // -#include -#include -#include #include +#include #include +#include #include -#include -#include +#include #include +#include +#include + #include namespace sd { - namespace ops { - Nd4jStatus conditionHelper(const char *file, int line, int condition, int argNumber, const char *format, ...) { - if (!condition) { - va_list args; - - printf("Error at [%s:%i:%i]:\n", file, line, argNumber); - va_start(args, format); - vprintf(format, args); - va_end(args); - printf("\n"); - fflush(stdout); - - return ND4J_STATUS_BAD_PARAMS; - } - return ND4J_STATUS_OK; - } - - DeclarableOp::DeclarableOp() { - // no-op - } - - DeclarableOp::DeclarableOp(const char *name, bool isLogical) { - _descriptor = new OpDescriptor(name, isLogical); - _name = name; - } - - DeclarableOp::DeclarableOp(const char *name, int numInputs, bool scalar) { - _descriptor = new OpDescriptor(numInputs, name, scalar); - _name = name; - } - - DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace) { - _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace); - _name = opName; - } - - DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent) { - _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, divergent); - _name = opName; - } - - DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) { - _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs); - _name = opName; - } - - DeclarableOp::~DeclarableOp() { - if (_descriptor != nullptr) - delete _descriptor; - - if (_scalar != nullptr) - delete _scalar; - } +namespace ops { +Nd4jStatus conditionHelper(const char *file, int line, int condition, + int argNumber, const char *format, ...) { + if (!condition) { + va_list args; + + printf("Error at [%s:%i:%i]:\n", file, line, argNumber); + va_start(args, format); + vprintf(format, args); + va_end(args); + printf("\n"); + fflush(stdout); + + return ND4J_STATUS_BAD_PARAMS; + } + return ND4J_STATUS_OK; +} + +DeclarableOp::DeclarableOp() { + // no-op +} + +DeclarableOp::DeclarableOp(const char *name, bool isLogical) { + _descriptor = new OpDescriptor(name, isLogical); + _name = name; +} + +DeclarableOp::DeclarableOp(const char *name, int numInputs, bool scalar) { + _descriptor = new OpDescriptor(numInputs, name, scalar); + _name = name; +} + +DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, + bool allowsInplace) { + _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace); + _name = opName; +} + +DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, + bool allowsInplace, bool divergent) { + _descriptor = + new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, divergent); + _name = opName; +} + +DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, + bool allowsInplace, int tArgs, int iArgs) { + _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, + tArgs, iArgs); + _name = opName; +} + +DeclarableOp::~DeclarableOp() { + if (_descriptor != nullptr) delete _descriptor; + + if (_scalar != nullptr) delete _scalar; +} + +OpDescriptor *DeclarableOp::getOpDescriptor() { return _descriptor; } + +const std::string &DeclarableOp::getOpName() const { + return *_descriptor->getOpName(); +} + +Nd4jLong DeclarableOp::getOpHash() const { return _descriptor->getHash(); } + +sd::NDArray *sd::ops::DeclarableOp::getNullifiedZ(Context &block, int inputId) { + auto result = getZ(block, inputId); + if (result != nullptr && !block.isInplace()) result->nullify(); + + return result; +} + +sd::NDArray *sd::ops::DeclarableOp::getZ(Context &ctx, int inputId) { + NDArray *z = nullptr; + + if (ctx.isFastPath()) { + if (ctx.fastpath_out().size() <= inputId) { + if (ctx.isInplace()) { + z = ctx.fastpath_in()[inputId].get(); + } else + throw std::runtime_error("fastpath_out: unresolved output array"); + } else { + z = ctx.fastpath_out()[inputId].get(); + } + } else { + std::pair pair(ctx.nodeId(), inputId); + + if (ctx.isInplace()) { + auto vz = ctx.variable(inputId)->getNDArray(); + z = vz.get(); + + // hypothetically it's possible to have no variable. chances are low, but + // who knows. let's just create it for now + if (!ctx.getVariableSpace()->hasVariable(pair)) { + auto var = std::make_shared(); + ctx.getVariableSpace()->putVariable(pair, var); + } + + // now we're saving input array as output array + auto var = ctx.getVariableSpace()->getVariable(pair); + var->markRemovable(false); + var->setNDArray(vz); + } else if (!ctx.isInplace()) { + auto var = ctx.variable(pair); + if (var->getNDArray() != nullptr && var->getNDArray()->nonNull()) { + z = var->getNDArray().get(); + } else { + nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId()); + } + } else { + nd4j_printf("BOOM!\n", ""); + throw std::runtime_error("Boom!"); + } + } - OpDescriptor* DeclarableOp::getOpDescriptor() { - return _descriptor; - } + if (z != nullptr && z->undefined()) return nullptr; - const std::string& DeclarableOp::getOpName() const { - return *_descriptor->getOpName(); - } + return z; +} - Nd4jLong DeclarableOp::getOpHash() const { - return _descriptor->getHash(); - } +int DeclarableOp::prepareOutputs(Context &ctx) { + auto workspace = ctx.workspace(); + GraphProfile *prof = nullptr; + NodeProfile *node = nullptr; + std::chrono::time_point inputEnd, inputStart, + shapeStart, shapeEnd, arrayStart, arrayEnd; + bool canUseFastPath = true; - sd::NDArray* sd::ops::DeclarableOp::getNullifiedZ(Context& block, int inputId) { - auto result = getZ(block, inputId); - if (result != nullptr && !block.isInplace()) - result->nullify(); + auto fp = ctx.isFastPath(); - return result; + if (Environment::getInstance()->isProfiling()) { + /* + if (ctx.getVariableSpace() != nullptr && ctx.getVariableSpace()->flowPath() + != nullptr) { prof = ctx.getVariableSpace()->flowPath()->profile(); node = + prof->nodeById(ctx.nodeId()); + } + */ + throw std::runtime_error( + "DeclarableOp::prepareOutputs - Not implemented yet"); + } + + if (ctx.isInplace()) { + if (Environment::getInstance()->isProfiling() && node != nullptr) { + if (fp) { + // + } else { + for (auto p : ctx.inputs()) { + auto var = ctx.variable(p); + if (var->variableType() == VariableType::NDARRAY) { + auto array = var->getNDArray().get(); + + node->addInputShape(array->shapeInfo()); + node->addOutputShape(array->shapeInfo()); + } } + } + } - - sd::NDArray* sd::ops::DeclarableOp::getZ(Context& ctx, int inputId) { - NDArray* z = nullptr; - - if (ctx.isFastPath()) { - if (ctx.fastpath_out().size() <= inputId) { - if (ctx.isInplace()) { - z = ctx.fastpath_in()[inputId].get(); - } else - throw std::runtime_error("fastpath_out: unresolved output array"); - } else { - z = ctx.fastpath_out()[inputId].get(); - } + // if that's not fp, we can still propagate inputs and outputs + if (!fp) { + int cnt = 0; + auto id = ctx.nodeId(); + auto vs = ctx.getVariableSpace(); + for (auto p : ctx.inputs()) { + auto var = ctx.variable(p); + if (var->variableType() == VariableType::NDARRAY) { + auto array = var->getNDArray(); + ctx.setInputArray(cnt, array); + ctx.setOutputArray(cnt, array); + + // in case of this override we might need to update outputs in the + // Graph VariableSpace as well + if (vs != nullptr) { + if (vs->hasVariable(id, cnt)) { + auto v2 = vs->getVariable(id, cnt); + if (!v2->hasNDArray()) { + v2->setNDArray(array); + v2->markRemovable(false); + } } else { - std::pair pair(ctx.nodeId(), inputId); - - if (ctx.isInplace()) { - auto vz = ctx.variable(inputId)->getNDArray(); - z = vz.get(); - - // hypothetically it's possible to have no variable. chances are low, but who knows. let's just create it for now - if (!ctx.getVariableSpace()->hasVariable(pair)) { - auto var = std::make_shared(); - ctx.getVariableSpace()->putVariable(pair, var); - } - - // now we're saving input array as output array - auto var = ctx.getVariableSpace()->getVariable(pair); - var->markRemovable(false); - var->setNDArray(vz); - } else if (!ctx.isInplace()) { - auto var = ctx.variable(pair); - if (var->getNDArray() != nullptr && var->getNDArray()->nonNull()) { - z = var->getNDArray().get(); - } else { - nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId()); - } - } else { - nd4j_printf("BOOM!\n", ""); - throw std::runtime_error("Boom!"); - } + auto v2 = vs->putVariable(id, cnt, array); + v2->markRemovable(false); } + } - if (z != nullptr && z->undefined()) - return nullptr; - - return z; + cnt++; + } else { + canUseFastPath = false; } + } + } - int DeclarableOp::prepareOutputs(Context &ctx) { - auto workspace = ctx.workspace(); - GraphProfile *prof = nullptr; - NodeProfile *node = nullptr; - std::chrono::time_point inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd; - bool canUseFastPath = true; - - auto fp = ctx.isFastPath(); - - if (Environment::getInstance()->isProfiling()) { - /* - if (ctx.getVariableSpace() != nullptr && ctx.getVariableSpace()->flowPath() != nullptr) { - prof = ctx.getVariableSpace()->flowPath()->profile(); - node = prof->nodeById(ctx.nodeId()); - } - */ - throw std::runtime_error("DeclarableOp::prepareOutputs - Not implemented yet"); - } - - if (ctx.isInplace()) { - if (Environment::getInstance()->isProfiling() && node != nullptr) { - if (fp) { - // - } else { - for (auto p: ctx.inputs()) { - auto var = ctx.variable(p); - if (var->variableType() == VariableType::NDARRAY) { - auto array = var->getNDArray().get(); - - node->addInputShape(array->shapeInfo()); - node->addOutputShape(array->shapeInfo()); - } - } - } - } - - // if that's not fp, we can still propagate inputs and outputs - if (!fp) { - int cnt = 0; - auto id = ctx.nodeId(); - auto vs = ctx.getVariableSpace(); - for (auto p: ctx.inputs()) { - auto var = ctx.variable(p); - if (var->variableType() == VariableType::NDARRAY) { - auto array = var->getNDArray(); - ctx.setInputArray(cnt, array); - ctx.setOutputArray(cnt, array); - - - // in case of this override we might need to update outputs in the Graph VariableSpace as well - if (vs != nullptr) { - if (vs->hasVariable(id, cnt)) { - auto v2 = vs->getVariable(id, cnt); - if (!v2->hasNDArray()) { - v2->setNDArray(array); - v2->markRemovable(false); - - } - } else { - auto v2 = vs->putVariable(id, cnt, array); - v2->markRemovable(false); - } - } - - cnt++; - } else { - canUseFastPath = false; - } - } - } - - if (!canUseFastPath) - ctx.forbidFastPath(true); - - // do nothing, getZ result will do the trick - return static_cast(ctx.width()); - } else { - // if op is not inplace - we should pre-allocate arrays - ShapeList inSha; - int results = 0; - - if (Environment::getInstance()->isProfiling() && node != nullptr) - inputStart = std::chrono::system_clock::now(); - - int cntIn = 0; - // we build list of input shapes - if (fp) { - for (const auto p:ctx.fastpath_in()) { - inSha.push_back(p == nullptr ? nullptr : p->shapeInfo()); - } - } else { - int arrCnt = 0; - for (auto p: ctx.inputs()) { - auto var = ctx.variable(p); - if (var->variableType() == VariableType::NDARRAY) { - auto array = var->getNDArray(); - if (array.get() == nullptr) - throw unresolved_input_exception::build("Variable wasn't resolved prior shape calculation", p); - - inSha.push_back(array->shapeInfo()); - - // we're also filling ctx with arrays - if (canUseFastPath) - ctx.setInputArray(arrCnt++, array); - } else { - canUseFastPath = false; - } - cntIn++; - } - } - - // if we override shape function, we'll return size of fastPath - if (fp && ctx.shapeFunctionOverride()) { - return (int) ctx.fastpath_out().size(); - } - - // optionally saving input time - if (Environment::getInstance()->isProfiling() && node != nullptr) { - inputEnd = std::chrono::system_clock::now(); - auto inputTime = std::chrono::duration_cast(inputEnd - inputStart).count(); - node->setInputTime(inputTime); - - // saving output shapes in profile - for (int e = 0; e < inSha.size(); e++) - node->addInputShape(inSha.at(e)); - - shapeStart = std::chrono::system_clock::now(); - } - - auto outSha = this->calculateOutputShape(&inSha, ctx); - results = outSha->size(); - - // optionally saving shapeTime - if (Environment::getInstance()->isProfiling() && node != nullptr) { - shapeEnd = std::chrono::system_clock::now(); - auto prepTime = std::chrono::duration_cast(shapeEnd - shapeStart).count(); - node->setShapeFunctionTime(prepTime); - - // saving output shapes in profile - for (int e = 0; e < outSha->size(); e++) - node->addOutputShape(outSha->at(e)); - - arrayStart = std::chrono::system_clock::now(); - } - - int cnt = 0; - - for (auto out: *outSha->asVector()) { - if (!fp) { - // we need to check, if Z is really needed - std::pair pair(ctx.nodeId(), cnt++); - - if (!ctx.isValueAvailable(ctx.name(), ctx.nodeId(), pair.second)) { - if (Environment::getInstance()->isDebugAndVerbose()) - shape::printShapeInfoLinear("Going to create variable with shape", out); - - // we're creating non-initialized array here - NDArray outArr(out, true, ctx.launchContext(), false); - - ctx.pushNDArrayToVariableSpace(pair, outArr); - - if (canUseFastPath) - ctx.setOutputArray(pair.second, outArr); - } else { - // validate/compare shapes here. existent vs provided in outSha - auto var = ctx.variable(pair); - auto shape = var->getNDArray()->shapeInfo(); - - if (canUseFastPath) - ctx.setOutputArray(pair.second, var->getNDArray()); - - if (!shape::equalsSoft(out, shape) || shape::isEmpty(out) != shape::isEmpty(shape)) { - auto eShape = ShapeUtils::shapeAsString(out); - auto aShape = ShapeUtils::shapeAsString(shape); - - //outSha->destroy(); - delete outSha; - - nd4j_printf("Expected vs provided shapes mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), pair.second); - throw std::runtime_error("Expected vs provided shapes mismatch"); - } - - /* - * FIXME: we want to uncomment this eventually, and check data types equality - //checking out data type equality - if (ArrayOptions::dataType(out) != ArrayOptions::dataType(shape)) { - std::string msg = "Provided array [" + StringUtils::valueToString(pair.second) + "] has unexpected data type"; - throw sd::datatype_exception::build(msg, ArrayOptions::dataType(out), ArrayOptions::dataType(shape)); - } - */ - } - } else { - auto fout = ctx.fastpath_out(); - auto idx = cnt++; - if (fout.size() <= idx) { - // array doesnt exist - auto outArr = std::make_shared(out, true, ctx.launchContext()); - ctx.setOutputArray(idx, outArr); - } else { - auto array = fout[idx]; - // checking out shape equality - if (!shape::equalsSoft(out, array->shapeInfo()) || shape::isEmpty(out) != array->isEmpty()) { - auto eShape = ShapeUtils::shapeAsString(out); - auto aShape = ShapeUtils::shapeAsString(array->shapeInfo()); - - //outSha->destroy(); - delete outSha; - - nd4j_printf("Expected vs provided shape mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), idx); - throw std::runtime_error("Expected vs provided shape mismatch"); - } - } - } - } - - if (!canUseFastPath) - ctx.forbidFastPath(true); - - delete outSha; - - // saving arrayTime - if (Environment::getInstance()->isProfiling() && node != nullptr) { - arrayEnd = std::chrono::system_clock::now(); - auto arrayTime = std::chrono::duration_cast(arrayEnd - arrayStart).count(); - node->setArrayTime(arrayTime); - } - - return results; - } + if (!canUseFastPath) ctx.forbidFastPath(true); + + // do nothing, getZ result will do the trick + return static_cast(ctx.width()); + } else { + // if op is not inplace - we should pre-allocate arrays + ShapeList inSha; + int results = 0; + + if (Environment::getInstance()->isProfiling() && node != nullptr) + inputStart = std::chrono::system_clock::now(); + + int cntIn = 0; + // we build list of input shapes + if (fp) { + for (const auto p : ctx.fastpath_in()) { + inSha.push_back(p == nullptr ? nullptr : p->shapeInfo()); + } + } else { + int arrCnt = 0; + for (auto p : ctx.inputs()) { + auto var = ctx.variable(p); + if (var->variableType() == VariableType::NDARRAY) { + auto array = var->getNDArray(); + if (array.get() == nullptr) + throw unresolved_input_exception::build( + "Variable wasn't resolved prior shape calculation", p); + + inSha.push_back(array->shapeInfo()); + + // we're also filling ctx with arrays + if (canUseFastPath) ctx.setInputArray(arrCnt++, array); + } else { + canUseFastPath = false; } + cntIn++; + } + } - void sd::ops::DeclarableOp::storeResult(Context &block, int outputNumber, NDArray* array) { - this->storeResult(block, outputNumber, *array); - } + // if we override shape function, we'll return size of fastPath + if (fp && ctx.shapeFunctionOverride()) { + return (int)ctx.fastpath_out().size(); + } - void sd::ops::DeclarableOp::storeResult(sd::graph::Context &ctx, int outputNumber, NDArray& array) { - ctx.pushNDArrayToVariableSpace(ctx.nodeId(), outputNumber, array); - } + // optionally saving input time + if (Environment::getInstance()->isProfiling() && node != nullptr) { + inputEnd = std::chrono::system_clock::now(); + auto inputTime = std::chrono::duration_cast( + inputEnd - inputStart) + .count(); + node->setInputTime(inputTime); - bool sd::ops::DeclarableOp::allocateResult(Context& block, Nd4jLong* shape) { - auto var = block.variable(block.getNodeId(), 0); + // saving output shapes in profile + for (int e = 0; e < inSha.size(); e++) node->addInputShape(inSha.at(e)); - auto workspace = block.workspace(); + shapeStart = std::chrono::system_clock::now(); + } - Nd4jLong len = shape::length(shape); - Nd4jLong* __shape; - ALLOCATE(__shape, workspace, shape::shapeInfoLength(shape), Nd4jLong); //new int[shape[0] * 2 + 4]; + auto outSha = this->calculateOutputShape(&inSha, ctx); + results = outSha->size(); - memcpy(__shape, shape, shape::shapeInfoByteLength(shape)); + // optionally saving shapeTime + if (Environment::getInstance()->isProfiling() && node != nullptr) { + shapeEnd = std::chrono::system_clock::now(); + auto prepTime = std::chrono::duration_cast( + shapeEnd - shapeStart) + .count(); + node->setShapeFunctionTime(prepTime); - // if that's first run - we probably have nothing here - if (var->getNDArray() == nullptr) { + // saving output shapes in profile + for (int e = 0; e < outSha->size(); e++) + node->addOutputShape(outSha->at(e)); - std::shared_ptr buffer = std::make_shared(len * sizeof(int8_t), ArrayOptions::dataType(__shape), workspace); - var->setNDArray(std::make_shared(buffer, ShapeDescriptor(__shape), block.launchContext())); - } - else if(var->getNDArray()->lengthOf() != len) { - // if length not match - lets reallocate array - std::shared_ptr buffer = std::make_shared(len * sizeof(int8_t), ArrayOptions::dataType(__shape), workspace); - var->setNDArray(std::make_shared(buffer, ShapeDescriptor(__shape), block.launchContext())); - } + arrayStart = std::chrono::system_clock::now(); + } - return true; + int cnt = 0; + + for (auto out : *outSha->asVector()) { + if (!fp) { + // we need to check, if Z is really needed + std::pair pair(ctx.nodeId(), cnt++); + + if (!ctx.isValueAvailable(ctx.name(), ctx.nodeId(), pair.second)) { + if (Environment::getInstance()->isDebugAndVerbose()) + shape::printShapeInfoLinear("Going to create variable with shape", + out); + + // we're creating non-initialized array here + NDArray outArr(out, true, ctx.launchContext(), false); + + ctx.pushNDArrayToVariableSpace(pair, outArr); + + if (canUseFastPath) ctx.setOutputArray(pair.second, outArr); + } else { + // validate/compare shapes here. existent vs provided in outSha + auto var = ctx.variable(pair); + auto shape = var->getNDArray()->shapeInfo(); + + if (canUseFastPath) + ctx.setOutputArray(pair.second, var->getNDArray()); + + if (!shape::equalsSoft(out, shape) || + shape::isEmpty(out) != shape::isEmpty(shape)) { + auto eShape = ShapeUtils::shapeAsString(out); + auto aShape = ShapeUtils::shapeAsString(shape); + + // outSha->destroy(); + delete outSha; + + nd4j_printf( + "Expected vs provided shapes mismatch %s vs %s at index %i\n", + eShape.c_str(), aShape.c_str(), pair.second); + throw std::runtime_error("Expected vs provided shapes mismatch"); + } + + /* + * FIXME: we want to uncomment this eventually, and check data types + equality + //checking out data type equality + if (ArrayOptions::dataType(out) != ArrayOptions::dataType(shape)) { + std::string msg = "Provided array [" + + StringUtils::valueToString(pair.second) + "] has unexpected data + type"; throw sd::datatype_exception::build(msg, + ArrayOptions::dataType(out), ArrayOptions::dataType(shape)); + } + */ } - - - bool sd::ops::DeclarableOp::allocateResult(Context& block, std::initializer_list& shape, char order) { - auto var = block.variable(block.getNodeId(), 0); - auto workspace = block.workspace(); - - Nd4jLong len = shape::length(shape); - // if that's first run - we probably have nothing here - if (var->getNDArray() == nullptr) { - var->setNDArray(std::make_shared(order, shape, DataType::FLOAT32, block.launchContext())); - } else if(var->getNDArray()->lengthOf() != len) { - // if length not match - lets reallocate array - var->setNDArray(std::make_shared(order, shape, DataType::FLOAT32, block.launchContext())); - } - - return true; + } else { + auto fout = ctx.fastpath_out(); + auto idx = cnt++; + if (fout.size() <= idx) { + // array doesnt exist + auto outArr = + std::make_shared(out, true, ctx.launchContext()); + ctx.setOutputArray(idx, outArr); + } else { + auto array = fout[idx]; + // checking out shape equality + if (!shape::equalsSoft(out, array->shapeInfo()) || + shape::isEmpty(out) != array->isEmpty()) { + auto eShape = ShapeUtils::shapeAsString(out); + auto aShape = ShapeUtils::shapeAsString(array->shapeInfo()); + + // outSha->destroy(); + delete outSha; + + nd4j_printf( + "Expected vs provided shape mismatch %s vs %s at index %i\n", + eShape.c_str(), aShape.c_str(), idx); + throw std::runtime_error("Expected vs provided shape mismatch"); + } } + } + } - Nd4jStatus sd::ops::DeclarableOp::validateDataTypes(Context& block) { - _registrator.lock(); - if (!_registered) { - _registered = true; - this->registerTypes(); - } - _registrator.unlock(); - - // rolling over inputs first - int cnt = 0, inT = 0; - std::vector inputTypes(block.width()); - if (block.isFastPath()) { - for (auto array: block.fastpath_in()) { - if (array == nullptr) - continue; - - inputTypes[inT++] = array->dataType(); - if (!_descriptor->checkInputMatch(cnt, array->dataType())) { - auto ctype = DataTypeUtils::asString(array->dataType()); - nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n", - _descriptor->getOpName()->data(), cnt, ctype.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - cnt++; - } - } else { - for (auto &p: block.inputs()) { - auto var = block.variable(p); - - // we're not checking validity, if ANY types were explicitly allowed - //if (block.dataType(cnt) == sd::DataType::ANY) - // continue; - - // only validating non-null variables - if (var != nullptr && var->hasNDArray()) { - auto array = var->getNDArray(); - - inputTypes[inT++] = array->dataType(); - if (!_descriptor->checkInputMatch(cnt, array->dataType())) { - auto ctype = DataTypeUtils::asString(array->dataType()); - nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n", - _descriptor->getOpName()->data(), cnt, ctype.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } - - cnt++; - } - } + if (!canUseFastPath) ctx.forbidFastPath(true); - if (block.isFastPath()) { - int index = 0; - for (auto array: block.fastpath_out()) { - if (array == nullptr) - continue; - - auto cType = array->dataType(); - - if (_descriptor->isSameMode()) { - - if (index >= block.width()) { - if (block.fastpath_in().size() == 0) - continue; - - auto ia = block.fastpath_in()[0]; - - if (ia->dataType() != cType) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } else { - // for same mode, output type must be the same as input type - auto ia = block.fastpath_in()[index]; - - if (ia->dataType() != cType) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } - } else if (_descriptor->isInherit(index)) { - // in inherit mode, output type must be the same as one of input types - if (std::find(inputTypes.begin(), inputTypes.end(), cType) == inputTypes.end()) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - - } else if (!_descriptor->checkOutputMatch(index, cType)) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - index++; - } - } else { - // checking optionally available outputs - auto varSpace = block.getVariableSpace(); - for (int index = 0; index < DataTypeUtils::max(); index++) { - if (varSpace != nullptr && varSpace->hasVariable(block.nodeId(), index)) { - auto var = block.variable(block.nodeId(), index); - - // only validating non-null variables - if (var != nullptr && var->hasNDArray()) { - auto array = var->getNDArray(); - auto cType = array->dataType(); - - if (_descriptor->isSameMode()) { - - if (index >= block.width()) { - if (block.width() == 0) - continue; - - auto iv = block.variable(0); - - if (iv->getNDArray()->dataType() != cType) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } else { - // for same mode, output type must be the same as input type - auto iv = block.variable(index); - - if (iv->getNDArray()->dataType() != cType) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } - } else if (_descriptor->isInherit(index)) { - // in inherit mode, output type must be the same as one of input types - if (std::find(inputTypes.begin(), inputTypes.end(), cType) == inputTypes.end()) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - - } else if (!_descriptor->checkOutputMatch(index, cType)) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n", - _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } - } else - break; - } - } + delete outSha; + // saving arrayTime + if (Environment::getInstance()->isProfiling() && node != nullptr) { + arrayEnd = std::chrono::system_clock::now(); + auto arrayTime = std::chrono::duration_cast( + arrayEnd - arrayStart) + .count(); + node->setArrayTime(arrayTime); + } - return ND4J_STATUS_OK; + return results; + } +} + +void sd::ops::DeclarableOp::storeResult(Context &block, int outputNumber, + NDArray *array) { + this->storeResult(block, outputNumber, *array); +} + +void sd::ops::DeclarableOp::storeResult(sd::graph::Context &ctx, + int outputNumber, NDArray &array) { + ctx.pushNDArrayToVariableSpace(ctx.nodeId(), outputNumber, array); +} + +bool sd::ops::DeclarableOp::allocateResult(Context &block, Nd4jLong *shape) { + auto var = block.variable(block.getNodeId(), 0); + + auto workspace = block.workspace(); + + Nd4jLong len = shape::length(shape); + Nd4jLong *__shape; + ALLOCATE(__shape, workspace, shape::shapeInfoLength(shape), + Nd4jLong); // new int[shape[0] * 2 + 4]; + + memcpy(__shape, shape, shape::shapeInfoByteLength(shape)); + + // if that's first run - we probably have nothing here + if (var->getNDArray() == nullptr) { + std::shared_ptr buffer = std::make_shared( + len * sizeof(int8_t), ArrayOptions::dataType(__shape), workspace); + var->setNDArray(std::make_shared(buffer, ShapeDescriptor(__shape), + block.launchContext())); + } else if (var->getNDArray()->lengthOf() != len) { + // if length not match - lets reallocate array + std::shared_ptr buffer = std::make_shared( + len * sizeof(int8_t), ArrayOptions::dataType(__shape), workspace); + var->setNDArray(std::make_shared(buffer, ShapeDescriptor(__shape), + block.launchContext())); + } + + return true; +} + +bool sd::ops::DeclarableOp::allocateResult( + Context &block, std::initializer_list &shape, char order) { + auto var = block.variable(block.getNodeId(), 0); + auto workspace = block.workspace(); + + Nd4jLong len = shape::length(shape); + // if that's first run - we probably have nothing here + if (var->getNDArray() == nullptr) { + var->setNDArray(std::make_shared(order, shape, DataType::FLOAT32, + block.launchContext())); + } else if (var->getNDArray()->lengthOf() != len) { + // if length not match - lets reallocate array + var->setNDArray(std::make_shared(order, shape, DataType::FLOAT32, + block.launchContext())); + } + + return true; +} + +Nd4jStatus sd::ops::DeclarableOp::validateDataTypes(Context &block) { + _registrator.lock(); + if (!_registered) { + _registered = true; + this->registerTypes(); + } + _registrator.unlock(); + + // rolling over inputs first + int cnt = 0, inT = 0; + std::vector inputTypes(block.width()); + if (block.isFastPath()) { + for (auto array : block.fastpath_in()) { + if (array == nullptr) continue; + + inputTypes[inT++] = array->dataType(); + if (!_descriptor->checkInputMatch(cnt, array->dataType())) { + auto ctype = DataTypeUtils::asString(array->dataType()); + nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), cnt, ctype.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } + cnt++; + } + } else { + for (auto &p : block.inputs()) { + auto var = block.variable(p); + + // we're not checking validity, if ANY types were explicitly allowed + // if (block.dataType(cnt) == sd::DataType::ANY) + // continue; + + // only validating non-null variables + if (var != nullptr && var->hasNDArray()) { + auto array = var->getNDArray(); + + inputTypes[inT++] = array->dataType(); + if (!_descriptor->checkInputMatch(cnt, array->dataType())) { + auto ctype = DataTypeUtils::asString(array->dataType()); + nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), cnt, ctype.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; } + } - Nd4jStatus DeclarableOp::execute(Context* block) { - nd4j_debug("Executing op: [%s]\n", this->getOpName().c_str()); - - std::chrono::time_point timeEnter, timeStart, timeEnd; - Nd4jLong prepTime, outerTime; - - Nd4jLong memoryBefore = block->workspace() == nullptr ? 0L : block->workspace()->getSpilledSize() + block->workspace()->getUsedSize(); - if (Environment::getInstance()->isProfiling()) - timeEnter = std::chrono::system_clock::now(); - - // make sure we're not trying to call non-inpace op inplace - if (block->isInplace() && !this->getOpDescriptor()->allowsInplace()) - throw std::runtime_error("DeclarableOp::execute - trying to execute non-inplace op as inplace"); - - // basic validation: ensure inputs are set - REQUIRE_OK(this->validateNonEmptyInput(*block)); - - // ensure number of IArgs, TArgs match our expectations - REQUIRE_OK(this->validateArguments(*block)); - - // validating data types for inputs and (optionally) outputs - REQUIRE_OK(this->validateDataTypes(*block)); - - - - - // this method will allocate output NDArrays for this op - auto numOutputs = this->prepareOutputs(*block); - - if (Environment::getInstance()->isProfiling()) { - timeStart = std::chrono::system_clock::now(); - prepTime = std::chrono::duration_cast(timeStart - timeEnter).count(); - } - - - Nd4jStatus status; - bool hasHelper = false; - - // platform helpers use might be forbidden for various reasons, so we'll check it out first - if (block->helpersAllowed() && sd::Environment::getInstance()->helpersAllowed()) { - // if we have platform-specific helper for this op - invoke it - if (OpRegistrator::getInstance()->hasHelper(this->getOpHash(), block->engine())) { - auto helper = OpRegistrator::getInstance()->getPlatformHelper(this->getOpHash(), block->engine()); - if (helper->isUsable(*block)) { - status = helper->invokeHelper(*block); - hasHelper = true; - } - } - } - - // if we don't have platform-specific helper - invoke generic implementation - if (!hasHelper) - status = this->validateAndExecute(*block); - - // optionally saving execution time - if (Environment::getInstance()->isProfiling()) { - timeEnd = std::chrono::system_clock::now(); - outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - block->setInnerTime(outerTime); - } - - if (Environment::getInstance()->isProfiling() && block->getVariableSpace() != nullptr) { - /* - auto fp = block->getVariableSpace()->flowPath(); - if (fp != nullptr) { - auto p = fp->profile(); - if (p != nullptr) { - Nd4jLong memoryAfter = block->workspace() == nullptr ? 0L : block->workspace()->getSpilledSize() + block->workspace()->getUsedSize(); - Nd4jLong memoryUsed = memoryAfter - memoryBefore; - p->nodeById(block->nodeId())->setPreparationTime(prepTime); - p->nodeById(block->nodeId())->setExecutionTime(outerTime); - p->nodeById(block->nodeId())->setTotalSize(memoryUsed); - } - } - */ - throw std::runtime_error("DeclarableOp::execute - Not implemented yet"); - } - - - // now we print out all outputs for this node - if (sd::Environment::getInstance()->isDebugAndVerbose()) { - auto vs = block->getVariableSpace(); - - for (int e = 0; e < numOutputs; e++) { - // if given output index doesn't exist - we're done - - if (!block->isFastPath()) { - if (!vs->hasVariable(block->nodeId(), e)) - break; - } else { - // we have to check either in or out stack, depending on isInplace() - if (block->isInplace()) { - if (block->fastpath_in().size() <= e) - break; - } else { - if (block->fastpath_out().size() <= e) - break; - } - } - - auto array = block->isFastPath() ? block->isInplace() ? block->fastpath_in()[e] : block->fastpath_out()[e] : vs->getVariable(block->nodeId(), e)->getNDArray(); - - auto shape = ShapeUtils::shapeAsString(array.get()); - auto first = array->isEmpty() ? std::string("Empty NDArray") : array->asString(32); - auto type = DataTypeUtils::asString(array->dataType()); - - nd4j_printf("node_%i:%i result shape: %s; dtype: %s; first values %s\n", block->nodeId(), e, shape.c_str(), type.c_str(), first.c_str()); - } - } - - return status; + cnt++; + } + } + + if (block.isFastPath()) { + int index = 0; + for (auto array : block.fastpath_out()) { + if (array == nullptr) continue; + + auto cType = array->dataType(); + + if (_descriptor->isSameMode()) { + if (index >= block.width()) { + if (block.fastpath_in().size() == 0) continue; + + auto ia = block.fastpath_in()[0]; + + if (ia->dataType() != cType) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf( + "Op [%s] failed check for output [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } + } else { + // for same mode, output type must be the same as input type + auto ia = block.fastpath_in()[index]; + + if (ia->dataType() != cType) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf( + "Op [%s] failed check for output [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } } - - void DeclarableOp::overwriteResult(Context &block, int outputIdx, NDArray *array) { - throw std::runtime_error("Overwrite result used!"); - //block.pushNDArrayToVariableSpace(block.nodeId(), outputIdx, array); - /* - auto varSpace = block.getVariableSpace(); - if (varSpace->hasVariable(block.getNodeId(), outputIdx)) { - auto var = varSpace->getVariable(block.getNodeId(), outputIdx); - if (var->getNDArray() != nullptr && var->isRemovable()) - delete var->getNDArray(); - - var->setNDArray(array); - var->markRemovable(true); - } else { - auto var = new Variable(array, nullptr, block.getNodeId(), outputIdx); - varSpace->putVariable(block.getNodeId(), outputIdx, var); - } - */ + } else if (_descriptor->isInherit(index)) { + // in inherit mode, output type must be the same as one of input types + if (std::find(inputTypes.begin(), inputTypes.end(), cType) == + inputTypes.end()) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; } - void DeclarableOp::overwriteResult(Context &block, int outputIdx, NDArrayList *list) { - throw std::runtime_error("Overwrite result used!"); - //block.pushNDArrayListToVariableSpace(block.nodeId(), outputIdx, list); - /* - auto varSpace = block.getVariableSpace(); - if (varSpace->hasVariable(block.getNodeId(), outputIdx)) { - auto var = varSpace->getVariable(block.getNodeId(), outputIdx); - var->setNDArrayList(list); + } else if (!_descriptor->checkOutputMatch(index, cType)) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } + index++; + } + } else { + // checking optionally available outputs + auto varSpace = block.getVariableSpace(); + for (int index = 0; index < DataTypeUtils::max(); index++) { + if (varSpace != nullptr && varSpace->hasVariable(block.nodeId(), index)) { + auto var = block.variable(block.nodeId(), index); + + // only validating non-null variables + if (var != nullptr && var->hasNDArray()) { + auto array = var->getNDArray(); + auto cType = array->dataType(); + + if (_descriptor->isSameMode()) { + if (index >= block.width()) { + if (block.width() == 0) continue; + + auto iv = block.variable(0); + + if (iv->getNDArray()->dataType() != cType) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf( + "Op [%s] failed check for output [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } } else { - auto var = new Variable(nullptr, nullptr, block.getNodeId(), outputIdx); - var->setNDArrayList(list); - varSpace->putVariable(block.getNodeId(), outputIdx, var); - } - */ - } - - Nd4jStatus sd::ops::DeclarableOp::validateArguments(Context& block) { - /* - * We're checking number of T and I arguments. If number of args is finite number - we check strict equality - * If number of args is variable (-1), but variables MUST be present - we check for non-zero number of arguments - */ - if (_descriptor->getNumberOfTArgs() > 0) { - if ((int) block.numT() < _descriptor->getNumberOfTArgs()) { - nd4j_printf("%s: %i T args expected, but %i received\n", this->getOpName().c_str(), _descriptor->getNumberOfTArgs(), block.numT()); - return ND4J_STATUS_BAD_PARAMS; - } - } else - if (_descriptor->getNumberOfTArgs() == -1) - if (block.numT() == 0) { - nd4j_printf("%s: Number of T arguments should be positive number, but got 0 arguments\n", this->getOpName().c_str()); - return ND4J_STATUS_BAD_PARAMS; - } - - if (_descriptor->getNumberOfIArgs() > 0) { - if ((int) block.numI() < _descriptor->getNumberOfIArgs()) { - nd4j_printf("%s: %i int args expected, but %i received\n", this->getOpName().c_str(), _descriptor->getNumberOfIArgs(), block.numI()); - return ND4J_STATUS_BAD_PARAMS; - } - } else - if (_descriptor->getNumberOfIArgs() == -1) - if (block.numI() == 0) { - nd4j_printf("%s: Number of Integer arguments should be positive number, but got 0 arguments\n", this->getOpName().c_str()); - return ND4J_STATUS_BAD_PARAMS; - } - - - return ND4J_STATUS_OK; - } - - Nd4jStatus sd::ops::DeclarableOp::validateInputDimensions(Context& block, int rank) { - if (block.width() == 0) - return ND4J_STATUS_OK; - - for (auto p: block.inputs()) { - auto v = block.variable(p); - NDArray *aV = v->getNDArray().get(); - - if (aV == nullptr) - return ND4J_STATUS_BAD_INPUT; - - if (aV->rankOf() != rank) - return ND4J_STATUS_BAD_DIMENSIONS; - } - - return ND4J_STATUS_OK; - } - - Nd4jStatus sd::ops::DeclarableOp::validateInput2D(Context& block) { - return validateInputDimensions(block, 2); - } - - Nd4jStatus sd::ops::DeclarableOp::validateInput3D(Context& block) { - return validateInputDimensions(block, 3); - } - - Nd4jStatus sd::ops::DeclarableOp::validateInput4D(Context& block) { - return validateInputDimensions(block, 4); - } - - Nd4jStatus sd::ops::DeclarableOp::validateNonEmptyInput(Context& block) { - if (this->getOpDescriptor()->getNumberOfInputs() == -2 || this->getOpDescriptor()->getNumberOfInputs() == 0) - return Status::OK(); - - if (block.width() < 1) { - nd4j_printf("%s: no operands provided for the op", this->getOpName().c_str()); - return ND4J_STATUS_BAD_INPUT; - } - - - int cnt = 0; - for (auto p: block.inputs()) { - auto v = block.variable(p); - if (v == nullptr) { - if (!this->getOpName().empty()) { - nd4j_printf("Node [%i:<%s>]: Variable [%i] (%i:%i) is NULL\n", block.getNodeId(), this->getOpName().c_str(), cnt, p.first, p.second); - } else { - nd4j_printf("Node [%i:]: Variable [%i] (%i:%i) is NULL\n", block.getNodeId(), cnt, p.first, p.second); - } - return ND4J_STATUS_BAD_INPUT; - } - - if (v->variableType() == VariableType::NDARRAY) { - // if array is empty intentionally - we're ok with that - if (v->hasNDArray() && v->isEmpty()) { - continue; - } - - NDArray *aV = v->getNDArray().get(); - - if (aV == nullptr || !aV->nonNull()) { - if (!this->getOpName().empty()) { - nd4j_printf("Node [%i:<%s>]: NDArray [%i] (%i:%i) is NULL\n", block.getNodeId(), this->getOpName().c_str(), cnt, p.first, p.second); - } else { - nd4j_printf("Node [%i:]: NDArray [%i] (%i:%i) is NULL\n", block.getNodeId(), cnt, p.first, p.second); - } - return ND4J_STATUS_BAD_INPUT; - } - } - - cnt++; - } - - return ND4J_STATUS_OK; - } - - Nd4jStatus sd::ops::DeclarableOp::validateOrdersMatch(Context& block) { - if (block.width() == 0) - return ND4J_STATUS_OK; - - NDArray *a0 = block.variable(0)->getNDArray().get(); - for (auto p: block.inputs()) { - auto v = block.variable(p); - NDArray *aV = v->getNDArray().get(); - if (a0->ordering() != aV->ordering()) - return ND4J_STATUS_BAD_ORDER; - } - - return ND4J_STATUS_OK; - } - - Nd4jStatus sd::ops::DeclarableOp::execute(sd::graph::RandomGenerator& rng, const std::vector& inputs, const std::vector& outputs, const std::vector& tArgs, const std::vector& iArgs, const std::vector& bArgs, const std::vector& dArgs, bool isInplace, sd::DataType type) { - VariableSpace variableSpace; - FlowPath fp; - //variableSpace.setFlowPath(&fp); - - int cnt = -1; - std::vector in; - for (auto v: inputs) { - if (v == nullptr) - continue; - - auto var = std::make_shared(*v, "", cnt); - - in.push_back(cnt); - variableSpace.putVariable(cnt--, var); - } - - int et = 0; - for (auto v: outputs) { - std::pair pair(1, et++); - auto var = std::make_shared(*v, "", pair.first, pair.second); - - variableSpace.putVariable(pair, var); - } - - Context block(1, &variableSpace, false); - block.fillInputs(in); - block.markInplace(isInplace); - - // we need this line for tests basically - //if (rng != nullptr) - block.setRng(rng); - - for (int e = 0; e < tArgs.size(); e++) - block.appendT(tArgs.at(e)); - - // FIXME: iargs should be Nd4jLong - for (int e = 0; e < iArgs.size(); e++) - block.appendI(static_cast(iArgs.at(e))); - - for (int e = 0; e < bArgs.size(); e++) - block.appendB(static_cast(bArgs.at(e))); - - for (int e = 0; e < dArgs.size(); e++) - block.appendD(dArgs.at(e)); - - Nd4jStatus result = this->execute(&block); - - return result; - } - - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs) { - return execute(inputs, outputs, std::vector(), std::vector(), std::vector(), std::vector()); - } - - template <> - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list tArgs) { - return execute(inputs, outputs, tArgs, std::vector(), std::vector(), std::vector()); - } - - template <> - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list dArgs) { - return execute(inputs, outputs, std::vector(), std::vector(), std::vector(), dArgs); - } - - template <> - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list tArgs) { - std::vector realArgs; - for (auto v:tArgs) - realArgs.emplace_back(v); - - return execute(inputs, outputs, realArgs, std::vector(), std::vector(), std::vector()); - } - - template <> - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list iArgs) { - return execute(inputs, outputs, std::vector(), iArgs, std::vector(), std::vector()); - } - - template <> - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list iArgs) { - std::vector realArgs; - for (auto v:iArgs) - realArgs.emplace_back(v); - - return execute(inputs, outputs, std::vector(), realArgs, std::vector(), std::vector()); - } - - template <> - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list bArgs) { - return execute(inputs, outputs, std::vector(), std::vector(), bArgs, std::vector()); - } - - Nd4jStatus DeclarableOp::execute(const std::vector &inputs, const std::vector &outputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs, const std::vector &dArgs, bool isInplace) { - Context ctx(1); - - for (int e = 0; e < inputs.size(); e++) { - ctx.setInputArray(e, inputs[e] == nullptr ? NDArray() : *inputs[e]); + // for same mode, output type must be the same as input type + auto iv = block.variable(index); + + if (iv->getNDArray()->dataType() != cType) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf( + "Op [%s] failed check for output [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } } - - for (int e = 0; e < outputs.size(); e++) { - ctx.setOutputArray(e, outputs[e] == nullptr ? NDArray() : *outputs[e]); + } else if (_descriptor->isInherit(index)) { + // in inherit mode, output type must be the same as one of input + // types + if (std::find(inputTypes.begin(), inputTypes.end(), cType) == + inputTypes.end()) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf( + "Op [%s] failed check for output [%i], DataType: [%s].\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; } - - if (isInplace) - ctx.markInplace(isInplace); - - ctx.setIArguments(iArgs); - ctx.setTArguments(tArgs); - ctx.setBArguments(bArgs); - ctx.setDArguments(dArgs); - - return execute(&ctx); - } - - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs) { - return evaluate(inputs, std::vector(), std::vector(), std::vector(), std::vector()); - } - - template <> - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list iArgs) { - std::vector realArgs; - for (auto v:iArgs) - realArgs.emplace_back(v); - - return evaluate(inputs, std::vector(), realArgs, std::vector(), std::vector()); + } else if (!_descriptor->checkOutputMatch(index, cType)) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf( + "Op [%s] failed check for output [%i], DataType: [%s];\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } } - - template <> - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list iArgs) { - return evaluate(inputs, std::vector(), iArgs, std::vector(), std::vector()); - } - - template <> - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list tArgs) { - std::vector realArgs; - for (auto v:tArgs) - realArgs.emplace_back(v); - - return evaluate(inputs, realArgs, std::vector(), std::vector(), std::vector()); - } - - template <> - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list tArgs) { - return evaluate(inputs, tArgs, std::vector(), std::vector(), std::vector()); - } - - template <> - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list bArgs) { - return evaluate(inputs, std::vector(), std::vector(), bArgs, std::vector()); - } - - template <> - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, std::initializer_list bArgs) { - return evaluate(inputs, std::vector(), std::vector(), std::vector(), bArgs); - } - - sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs, const std::vector &dArgs, bool isInplace) { - VariableSpace variableSpace; - //ResultSet arrayList; - FlowPath fp; - //variableSpace.setFlowPath(&fp); - - int cnt = -1; - std::vector in; - for (auto v: inputs) { - if (v == nullptr) - continue; - - auto var = std::make_shared(*v, "", cnt, 0); - var->markRemovable(false); - in.push_back(cnt); - variableSpace.putVariable(cnt--, var); - } - - Context block(1, &variableSpace, false); - block.fillInputs(in); - block.markInplace(isInplace); - // block.setRNG(ProviderRNG::getInstance().getRNG()); - - for (int e = 0; e < tArgs.size(); e++) - block.appendT(tArgs.at(e)); - - for (int e = 0; e < iArgs.size(); e++) - block.appendI(iArgs.at(e)); - - for (int e = 0; e < bArgs.size(); e++) - block.appendB(bArgs.at(e)); - - for (int e = 0; e < dArgs.size(); e++) - block.appendD(dArgs.at(e)); - - Nd4jStatus status = this->execute(&block); - ResultSet arrayList; - if (isInplace) - arrayList.setNonRemovable(); - - arrayList.setStatus(status); - if (status != ND4J_STATUS_OK) - return arrayList; - - if (!isInplace) { - for (int e = 0; e < DataTypeUtils::max(); e++) { - std::pair pair(1, e); - if (variableSpace.hasVariable(pair)) { - auto var = variableSpace.getVariable(pair); - auto arr = var->getNDArray(); - if (!arr->isAttached()) { - var->markRemovable(false); - arr->setContext(sd::LaunchContext::defaultContext()); - arrayList.push_back(*arr.get()); - } else { - arrayList.push_back(arr->detach()); - } - } else - break; - } - } else { - for (auto v:inputs) { - arrayList.push_back(*v); - } - } - - return arrayList; - } - - sd::ResultSet sd::ops::DeclarableOp::execute(const sd::OpArgsHolder& holder, bool isInplace) { - // FIXME: add DArgs to OpArgsHolder - return evaluate(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), holder.getBArgs(), std::vector(), isInplace); + } else + break; + } + } + + return ND4J_STATUS_OK; +} + +Nd4jStatus DeclarableOp::execute(Context *block) { + nd4j_debug("Executing op: [%s]\n", this->getOpName().c_str()); + + std::chrono::time_point timeEnter, timeStart, + timeEnd; + Nd4jLong prepTime, outerTime; + + Nd4jLong memoryBefore = block->workspace() == nullptr + ? 0L + : block->workspace()->getSpilledSize() + + block->workspace()->getUsedSize(); + if (Environment::getInstance()->isProfiling()) + timeEnter = std::chrono::system_clock::now(); + + // make sure we're not trying to call non-inpace op inplace + if (block->isInplace() && !this->getOpDescriptor()->allowsInplace()) + throw std::runtime_error( + "DeclarableOp::execute - trying to execute non-inplace op as inplace"); + + // basic validation: ensure inputs are set + REQUIRE_OK(this->validateNonEmptyInput(*block)); + + // ensure number of IArgs, TArgs match our expectations + REQUIRE_OK(this->validateArguments(*block)); + + // validating data types for inputs and (optionally) outputs + REQUIRE_OK(this->validateDataTypes(*block)); + + // this method will allocate output NDArrays for this op + auto numOutputs = this->prepareOutputs(*block); + + if (Environment::getInstance()->isProfiling()) { + timeStart = std::chrono::system_clock::now(); + prepTime = std::chrono::duration_cast(timeStart - + timeEnter) + .count(); + } + + Nd4jStatus status; + bool hasHelper = false; + + // platform helpers use might be forbidden for various reasons, so we'll check + // it out first + if (block->helpersAllowed() && + sd::Environment::getInstance()->helpersAllowed()) { + // if we have platform-specific helper for this op - invoke it + if (OpRegistrator::getInstance()->hasHelper(this->getOpHash(), + block->engine())) { + auto helper = OpRegistrator::getInstance()->getPlatformHelper( + this->getOpHash(), block->engine()); + if (helper->isUsable(*block)) { + status = helper->invokeHelper(*block); + hasHelper = true; + } + } + } + + // if we don't have platform-specific helper - invoke generic implementation + if (!hasHelper) status = this->validateAndExecute(*block); + + // optionally saving execution time + if (Environment::getInstance()->isProfiling()) { + timeEnd = std::chrono::system_clock::now(); + outerTime = std::chrono::duration_cast(timeEnd - + timeStart) + .count(); + block->setInnerTime(outerTime); + } + + if (Environment::getInstance()->isProfiling() && + block->getVariableSpace() != nullptr) { + /* + auto fp = block->getVariableSpace()->flowPath(); + if (fp != nullptr) { + auto p = fp->profile(); + if (p != nullptr) { + Nd4jLong memoryAfter = block->workspace() == nullptr ? 0L : + block->workspace()->getSpilledSize() + block->workspace()->getUsedSize(); + Nd4jLong memoryUsed = memoryAfter - memoryBefore; + p->nodeById(block->nodeId())->setPreparationTime(prepTime); + p->nodeById(block->nodeId())->setExecutionTime(outerTime); + p->nodeById(block->nodeId())->setTotalSize(memoryUsed); } - - Nd4jStatus sd::ops::DeclarableOp::validateInputDimensionsMatch(Context& block) { - if (block.width() == 0) - return ND4J_STATUS_OK; - - NDArray *a0 = block.array(0).get(); - for (int e = 0; e < block.width(); e++) { - auto aV = block.array(e); - if (!shape::equalsSoft(a0->shapeInfo(), aV->shapeInfo())) - return ND4J_STATUS_BAD_DIMENSIONS; - } - - return ND4J_STATUS_OK; + } + */ + throw std::runtime_error("DeclarableOp::execute - Not implemented yet"); + } + + // now we print out all outputs for this node + if (sd::Environment::getInstance()->isDebugAndVerbose()) { + auto vs = block->getVariableSpace(); + + for (int e = 0; e < numOutputs; e++) { + // if given output index doesn't exist - we're done + + if (!block->isFastPath()) { + if (!vs->hasVariable(block->nodeId(), e)) break; + } else { + // we have to check either in or out stack, depending on isInplace() + if (block->isInplace()) { + if (block->fastpath_in().size() <= e) break; + } else { + if (block->fastpath_out().size() <= e) break; } + } - Nd4jStatus sd::ops::DeclarableOp::validateInputLengthMatch(Context& block) { - if (block.width() == 0) - return ND4J_STATUS_OK; + auto array = block->isFastPath() + ? block->isInplace() ? block->fastpath_in()[e] + : block->fastpath_out()[e] + : vs->getVariable(block->nodeId(), e)->getNDArray(); + auto shape = ShapeUtils::shapeAsString(array.get()); + auto first = + array->isEmpty() ? std::string("Empty NDArray") : array->asString(32); + auto type = DataTypeUtils::asString(array->dataType()); - Nd4jLong l0 = block.array(0)->lengthOf(); - for (uint32_t e = 0; e < block.width(); e++) { - if (l0 != block.array(e)->lengthOf()) - return ND4J_STATUS_BAD_LENGTH; - } + nd4j_printf("node_%i:%i result shape: %s; dtype: %s; first values %s\n", + block->nodeId(), e, shape.c_str(), type.c_str(), + first.c_str()); + } + } + + return status; +} + +void DeclarableOp::overwriteResult(Context &block, int outputIdx, + NDArray *array) { + throw std::runtime_error("Overwrite result used!"); + // block.pushNDArrayToVariableSpace(block.nodeId(), outputIdx, array); + /* + auto varSpace = block.getVariableSpace(); + if (varSpace->hasVariable(block.getNodeId(), outputIdx)) { + auto var = varSpace->getVariable(block.getNodeId(), outputIdx); + if (var->getNDArray() != nullptr && var->isRemovable()) + delete var->getNDArray(); + + var->setNDArray(array); + var->markRemovable(true); + } else { + auto var = new Variable(array, nullptr, block.getNodeId(), outputIdx); + varSpace->putVariable(block.getNodeId(), outputIdx, var); + } + */ +} + +void DeclarableOp::overwriteResult(Context &block, int outputIdx, + NDArrayList *list) { + throw std::runtime_error("Overwrite result used!"); + // block.pushNDArrayListToVariableSpace(block.nodeId(), outputIdx, list); + /* + auto varSpace = block.getVariableSpace(); + if (varSpace->hasVariable(block.getNodeId(), outputIdx)) { + auto var = varSpace->getVariable(block.getNodeId(), outputIdx); + var->setNDArrayList(list); + } else { + auto var = new Variable(nullptr, nullptr, block.getNodeId(), outputIdx); + var->setNDArrayList(list); + varSpace->putVariable(block.getNodeId(), outputIdx, var); + } + */ +} + +Nd4jStatus sd::ops::DeclarableOp::validateArguments(Context &block) { + /* + * We're checking number of T and I arguments. If number of args is finite + * number - we check strict equality If number of args is variable (-1), but + * variables MUST be present - we check for non-zero number of arguments + */ + if (_descriptor->getNumberOfTArgs() > 0) { + if ((int)block.numT() < _descriptor->getNumberOfTArgs()) { + nd4j_printf("%s: %i T args expected, but %i received\n", + this->getOpName().c_str(), _descriptor->getNumberOfTArgs(), + block.numT()); + return ND4J_STATUS_BAD_PARAMS; + } + } else if (_descriptor->getNumberOfTArgs() == -1) + if (block.numT() == 0) { + nd4j_printf( + "%s: Number of T arguments should be positive number, but got 0 " + "arguments\n", + this->getOpName().c_str()); + return ND4J_STATUS_BAD_PARAMS; + } - return ND4J_STATUS_OK; - } + if (_descriptor->getNumberOfIArgs() > 0) { + if ((int)block.numI() < _descriptor->getNumberOfIArgs()) { + nd4j_printf("%s: %i int args expected, but %i received\n", + this->getOpName().c_str(), _descriptor->getNumberOfIArgs(), + block.numI()); + return ND4J_STATUS_BAD_PARAMS; + } + } else if (_descriptor->getNumberOfIArgs() == -1) + if (block.numI() == 0) { + nd4j_printf( + "%s: Number of Integer arguments should be positive number, but got " + "0 arguments\n", + this->getOpName().c_str()); + return ND4J_STATUS_BAD_PARAMS; + } - samediff::EmptyHandling DeclarableOp::emptyHandling() { - return samediff::EmptyHandling::EMPTY_SKIP; - } + return ND4J_STATUS_OK; +} + +Nd4jStatus sd::ops::DeclarableOp::validateInputDimensions(Context &block, + int rank) { + if (block.width() == 0) return ND4J_STATUS_OK; + + for (auto p : block.inputs()) { + auto v = block.variable(p); + NDArray *aV = v->getNDArray().get(); + + if (aV == nullptr) return ND4J_STATUS_BAD_INPUT; + + if (aV->rankOf() != rank) return ND4J_STATUS_BAD_DIMENSIONS; + } + + return ND4J_STATUS_OK; +} + +Nd4jStatus sd::ops::DeclarableOp::validateInput2D(Context &block) { + return validateInputDimensions(block, 2); +} + +Nd4jStatus sd::ops::DeclarableOp::validateInput3D(Context &block) { + return validateInputDimensions(block, 3); +} + +Nd4jStatus sd::ops::DeclarableOp::validateInput4D(Context &block) { + return validateInputDimensions(block, 4); +} + +Nd4jStatus sd::ops::DeclarableOp::validateNonEmptyInput(Context &block) { + if (this->getOpDescriptor()->getNumberOfInputs() == -2 || + this->getOpDescriptor()->getNumberOfInputs() == 0) + return Status::OK(); + + if (block.width() < 1) { + nd4j_printf("%s: no operands provided for the op", + this->getOpName().c_str()); + return ND4J_STATUS_BAD_INPUT; + } + + int cnt = 0; + for (auto p : block.inputs()) { + auto v = block.variable(p); + if (v == nullptr) { + if (!this->getOpName().empty()) { + nd4j_printf("Node [%i:<%s>]: Variable [%i] (%i:%i) is NULL\n", + block.getNodeId(), this->getOpName().c_str(), cnt, p.first, + p.second); + } else { + nd4j_printf("Node [%i:]: Variable [%i] (%i:%i) is NULL\n", + block.getNodeId(), cnt, p.first, p.second); + } + return ND4J_STATUS_BAD_INPUT; + } - void DeclarableOp::registerTypes() { - this->getOpDescriptor()->setSameMode(true); + if (v->variableType() == VariableType::NDARRAY) { + // if array is empty intentionally - we're ok with that + if (v->hasNDArray() && v->isEmpty()) { + continue; + } + + NDArray *aV = v->getNDArray().get(); + + if (aV == nullptr || !aV->nonNull()) { + if (!this->getOpName().empty()) { + nd4j_printf("Node [%i:<%s>]: NDArray [%i] (%i:%i) is NULL\n", + block.getNodeId(), this->getOpName().c_str(), cnt, + p.first, p.second); + } else { + nd4j_printf("Node [%i:]: NDArray [%i] (%i:%i) is NULL\n", + block.getNodeId(), cnt, p.first, p.second); } + return ND4J_STATUS_BAD_INPUT; + } + } - /* - template - int* sd::ops::DeclarableOp::calculateOutputShape(int* inputShape, sd::graph::Block& block) { - // default implementation suits transform, so just returns the same shape - - int* newshape; - ALLOCATE(newshape, block.workspace(), shape::shapeInfoLength(inputShape), int); - memcpy(newshape, inputShape, shape::shapeInfoByteLength(inputShape)); - - return newshape; + cnt++; + } + + return ND4J_STATUS_OK; +} + +Nd4jStatus sd::ops::DeclarableOp::validateOrdersMatch(Context &block) { + if (block.width() == 0) return ND4J_STATUS_OK; + + NDArray *a0 = block.variable(0)->getNDArray().get(); + for (auto p : block.inputs()) { + auto v = block.variable(p); + NDArray *aV = v->getNDArray().get(); + if (a0->ordering() != aV->ordering()) return ND4J_STATUS_BAD_ORDER; + } + + return ND4J_STATUS_OK; +} + +Nd4jStatus sd::ops::DeclarableOp::execute( + sd::graph::RandomGenerator &rng, const std::vector &inputs, + const std::vector &outputs, const std::vector &tArgs, + const std::vector &iArgs, const std::vector &bArgs, + const std::vector &dArgs, bool isInplace, sd::DataType type) { + VariableSpace variableSpace; + FlowPath fp; + // variableSpace.setFlowPath(&fp); + + int cnt = -1; + std::vector in; + for (auto v : inputs) { + if (v == nullptr) continue; + + auto var = std::make_shared(*v, "", cnt); + + in.push_back(cnt); + variableSpace.putVariable(cnt--, var); + } + + int et = 0; + for (auto v : outputs) { + std::pair pair(1, et++); + auto var = std::make_shared(*v, "", pair.first, pair.second); + + variableSpace.putVariable(pair, var); + } + + Context block(1, &variableSpace, false); + block.fillInputs(in); + block.markInplace(isInplace); + + // we need this line for tests basically + // if (rng != nullptr) + block.setRng(rng); + + for (int e = 0; e < tArgs.size(); e++) block.appendT(tArgs.at(e)); + + // FIXME: iargs should be Nd4jLong + for (int e = 0; e < iArgs.size(); e++) + block.appendI(static_cast(iArgs.at(e))); + + for (int e = 0; e < bArgs.size(); e++) + block.appendB(static_cast(bArgs.at(e))); + + for (int e = 0; e < dArgs.size(); e++) block.appendD(dArgs.at(e)); + + Nd4jStatus result = this->execute(&block); + + return result; +} + +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs) { + return execute(inputs, outputs, std::vector(), + std::vector(), std::vector(), + std::vector()); +} + +template <> +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs, + std::initializer_list tArgs) { + return execute(inputs, outputs, tArgs, std::vector(), + std::vector(), std::vector()); +} + +template <> +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs, + std::initializer_list dArgs) { + return execute(inputs, outputs, std::vector(), + std::vector(), std::vector(), dArgs); +} + +template <> +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs, + std::initializer_list tArgs) { + std::vector realArgs; + for (auto v : tArgs) realArgs.emplace_back(v); + + return execute(inputs, outputs, realArgs, std::vector(), + std::vector(), std::vector()); +} + +template <> +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs, + std::initializer_list iArgs) { + return execute(inputs, outputs, std::vector(), iArgs, + std::vector(), std::vector()); +} + +template <> +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs, + std::initializer_list iArgs) { + std::vector realArgs; + for (auto v : iArgs) realArgs.emplace_back(v); + + return execute(inputs, outputs, std::vector(), realArgs, + std::vector(), std::vector()); +} + +template <> +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs, + std::initializer_list bArgs) { + return execute(inputs, outputs, std::vector(), + std::vector(), bArgs, std::vector()); +} + +Nd4jStatus DeclarableOp::execute(const std::vector &inputs, + const std::vector &outputs, + const std::vector &tArgs, + const std::vector &iArgs, + const std::vector &bArgs, + const std::vector &dArgs, + bool isInplace) { + Context ctx(1); + + for (int e = 0; e < inputs.size(); e++) { + ctx.setInputArray(e, inputs[e] == nullptr ? NDArray() : *inputs[e]); + } + + for (int e = 0; e < outputs.size(); e++) { + ctx.setOutputArray(e, outputs[e] == nullptr ? NDArray() : *outputs[e]); + } + + if (isInplace) ctx.markInplace(isInplace); + + ctx.setIArguments(iArgs); + ctx.setTArguments(tArgs); + ctx.setBArguments(bArgs); + ctx.setDArguments(dArgs); + + return execute(&ctx); +} + +sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs) { + return evaluate(inputs, std::vector(), std::vector(), + std::vector(), std::vector()); +} + +template <> +sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, + std::initializer_list iArgs) { + std::vector realArgs; + for (auto v : iArgs) realArgs.emplace_back(v); + + return evaluate(inputs, std::vector(), realArgs, std::vector(), + std::vector()); +} + +template <> +sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, + std::initializer_list iArgs) { + return evaluate(inputs, std::vector(), iArgs, std::vector(), + std::vector()); +} + +template <> +sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, + std::initializer_list tArgs) { + std::vector realArgs; + for (auto v : tArgs) realArgs.emplace_back(v); + + return evaluate(inputs, realArgs, std::vector(), + std::vector(), std::vector()); +} + +template <> +sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, + std::initializer_list tArgs) { + return evaluate(inputs, tArgs, std::vector(), std::vector(), + std::vector()); +} + +template <> +sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, + std::initializer_list bArgs) { + return evaluate(inputs, std::vector(), std::vector(), bArgs, + std::vector()); +} + +template <> +sd::ResultSet DeclarableOp::evaluate( + const std::vector &inputs, + std::initializer_list bArgs) { + return evaluate(inputs, std::vector(), std::vector(), + std::vector(), bArgs); +} + +sd::ResultSet DeclarableOp::evaluate(const std::vector &inputs, + const std::vector &tArgs, + const std::vector &iArgs, + const std::vector &bArgs, + const std::vector &dArgs, + bool isInplace) { + VariableSpace variableSpace; + // ResultSet arrayList; + FlowPath fp; + // variableSpace.setFlowPath(&fp); + + int cnt = -1; + std::vector in; + for (auto v : inputs) { + if (v == nullptr) continue; + + auto var = std::make_shared(*v, "", cnt, 0); + var->markRemovable(false); + in.push_back(cnt); + variableSpace.putVariable(cnt--, var); + } + + Context block(1, &variableSpace, false); + block.fillInputs(in); + block.markInplace(isInplace); + // block.setRNG(ProviderRNG::getInstance().getRNG()); + + for (int e = 0; e < tArgs.size(); e++) block.appendT(tArgs.at(e)); + + for (int e = 0; e < iArgs.size(); e++) block.appendI(iArgs.at(e)); + + for (int e = 0; e < bArgs.size(); e++) block.appendB(bArgs.at(e)); + + for (int e = 0; e < dArgs.size(); e++) block.appendD(dArgs.at(e)); + + Nd4jStatus status = this->execute(&block); + ResultSet arrayList; + if (isInplace) arrayList.setNonRemovable(); + + arrayList.setStatus(status); + if (status != ND4J_STATUS_OK) return arrayList; + + if (!isInplace) { + for (int e = 0; e < DataTypeUtils::max(); e++) { + std::pair pair(1, e); + if (variableSpace.hasVariable(pair)) { + auto var = variableSpace.getVariable(pair); + auto arr = var->getNDArray(); + if (!arr->isAttached()) { + var->markRemovable(false); + arr->setContext(sd::LaunchContext::defaultContext()); + arrayList.push_back(*arr.get()); + } else { + arrayList.push_back(arr->detach()); } - */ + } else + break; + } + } else { + for (auto v : inputs) { + arrayList.push_back(*v); } -} \ No newline at end of file + } + + return arrayList; +} + +sd::ResultSet sd::ops::DeclarableOp::execute(const sd::OpArgsHolder &holder, + bool isInplace) { + // FIXME: add DArgs to OpArgsHolder + return evaluate(holder.getInArrs(), holder.getTArgs(), holder.getIArgs(), + holder.getBArgs(), std::vector(), isInplace); +} + +Nd4jStatus sd::ops::DeclarableOp::validateInputDimensionsMatch(Context &block) { + if (block.width() == 0) return ND4J_STATUS_OK; + + NDArray *a0 = block.array(0).get(); + for (int e = 0; e < block.width(); e++) { + auto aV = block.array(e); + if (!shape::equalsSoft(a0->shapeInfo(), aV->shapeInfo())) + return ND4J_STATUS_BAD_DIMENSIONS; + } + + return ND4J_STATUS_OK; +} + +Nd4jStatus sd::ops::DeclarableOp::validateInputLengthMatch(Context &block) { + if (block.width() == 0) return ND4J_STATUS_OK; + + Nd4jLong l0 = block.array(0)->lengthOf(); + for (uint32_t e = 0; e < block.width(); e++) { + if (l0 != block.array(e)->lengthOf()) return ND4J_STATUS_BAD_LENGTH; + } + + return ND4J_STATUS_OK; +} + +samediff::EmptyHandling DeclarableOp::emptyHandling() { + return samediff::EmptyHandling::EMPTY_SKIP; +} + +void DeclarableOp::registerTypes() { + this->getOpDescriptor()->setSameMode(true); +} + +/* +template +int* sd::ops::DeclarableOp::calculateOutputShape(int* inputShape, +sd::graph::Block& block) { + // default implementation suits transform, so just returns the same shape + + int* newshape; + ALLOCATE(newshape, block.workspace(), shape::shapeInfoLength(inputShape), +int); memcpy(newshape, inputShape, shape::shapeInfoByteLength(inputShape)); + + return newshape; +} +*/ +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp index 4e1b6db0b2c2..b258ab815435 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableReductionOp.cpp @@ -18,46 +18,51 @@ // Created by raver119 on 07.10.2017. // -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include namespace sd { - namespace ops { - DeclarableReductionOp::DeclarableReductionOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) : sd::ops::DeclarableOp(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs) { - // - } +namespace ops { +DeclarableReductionOp::DeclarableReductionOp(int numInputs, int numOutputs, + const char* opName, + bool allowsInplace, int tArgs, + int iArgs) + : sd::ops::DeclarableOp(numInputs, numOutputs, opName, allowsInplace, tArgs, + iArgs) { + // +} - sd::ShapeList* DeclarableReductionOp::calculateOutputShape(sd::ShapeList* inputShape, sd::graph::Context& block) { - // int numDims = INT_ARG(0); - std::vector dims; - if (inputShape->size() > 1) { - // the second argument is axis - auto axis = INPUT_VARIABLE(1); - for (int e = 0; e < axis->lengthOf(); e++) - dims.push_back(axis->e(e)); - } - else if (block.numI()) - for (int e = 0; e < block.numI(); e++) - dims.push_back(INT_ARG(e)); - else if (block.getAxis().size()) { - dims = block.getAxis(); //.push_back(axis->e(e)); - } +sd::ShapeList* DeclarableReductionOp::calculateOutputShape( + sd::ShapeList* inputShape, sd::graph::Context& block) { + // int numDims = INT_ARG(0); + std::vector dims; + if (inputShape->size() > 1) { + // the second argument is axis + auto axis = INPUT_VARIABLE(1); + for (int e = 0; e < axis->lengthOf(); e++) dims.push_back(axis->e(e)); + } else if (block.numI()) + for (int e = 0; e < block.numI(); e++) dims.push_back(INT_ARG(e)); + else if (block.getAxis().size()) { + dims = block.getAxis(); //.push_back(axis->e(e)); + } - if (dims.size() > 1) - std::sort(dims.begin(), dims.end()); + if (dims.size() > 1) std::sort(dims.begin(), dims.end()); - // special case - output is scalar - if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { - auto newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::FLOAT32); - return SHAPELIST(newShape); - } + // special case - output is scalar + if (dims.size() == 0 || + (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { + auto newShape = + ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::FLOAT32); + return SHAPELIST(newShape); + } - auto newShape = ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), false, false, block.workspace()); - return SHAPELIST(newShape); - } - } + auto newShape = ShapeUtils::evalReduceShapeInfo( + 'c', dims, inputShape->at(0), false, false, block.workspace()); + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp index c20c4e92d870..73b768bc777c 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp @@ -18,80 +18,115 @@ // Created by raver119 on 17.10.2017. // -#include +#include +#include #include #include -#include -#include - +#include namespace sd { - namespace ops { - Nd4jStatus LegacyBroadcastBoolOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - - auto z = OUTPUT_VARIABLE(0); - - std::vector dims(block.getIArguments()); - if (dims.size() > 0) - std::sort(dims.begin(), dims.end()); - - NDArray::prepareSpecialUse({z}, {x, y}); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x->shapeInfo(), dims); - - PointersManager manager(block.launchContext(), "LegacyBroadcastBoolOp"); - auto pTadShape = Environment::getInstance()->isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); - - REQUIRE_TRUE(shape::length(packX.primaryShapeInfo()) == y->lengthOf(), 0, "Length of broadcast TAD should be equal to length of Y operand, but got [%i] vs [%i]", (int) shape::length(packX.primaryShapeInfo()), (int) y->lengthOf()); - - if (x == z) - NativeOpExecutioner::execBroadcast(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), dims.size(), pTadShape, pTadOffsets, pTadShape, pTadOffsets); - else { - // this is rare, but possible use case - X and Z might have different shapes/strides/orders. In this case we prepare and pass separate TAD info - - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(z->shapeInfo(), dims); - - auto zTadShape = Environment::getInstance()->isCPU() ? packZ.primaryShapeInfo() : packZ.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tadZ.tadOnlyShapeInfo, shape::shapeInfoByteLength(tadZ.tadOnlyShapeInfo)); - auto zTadOffsets = Environment::getInstance()->isCPU() ? packZ.primaryOffsets() : packZ.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tadZ.tadOffsets, tadZ.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execBroadcast(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), dims.size(), pTadShape, pTadOffsets, zTadShape, zTadOffsets); - } - - manager.synchronize(); - STORE_RESULT(*z); - - return Status::OK(); - } +namespace ops { +Nd4jStatus LegacyBroadcastBoolOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + + auto z = OUTPUT_VARIABLE(0); + + std::vector dims(block.getIArguments()); + if (dims.size() > 0) std::sort(dims.begin(), dims.end()); + + NDArray::prepareSpecialUse({z}, {x, y}); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x->shapeInfo(), dims); + + PointersManager manager(block.launchContext(), "LegacyBroadcastBoolOp"); + auto pTadShape = + Environment::getInstance()->isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOnlyShapeInfo, + //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance()->isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOffsets, + //tad.numTads * sizeof(Nd4jLong)); + + REQUIRE_TRUE(shape::length(packX.primaryShapeInfo()) == y->lengthOf(), 0, + "Length of broadcast TAD should be equal to length of Y " + "operand, but got [%i] vs [%i]", + (int)shape::length(packX.primaryShapeInfo()), + (int)y->lengthOf()); + + if (x == z) + NativeOpExecutioner::execBroadcast( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), dims.size(), + pTadShape, pTadOffsets, pTadShape, pTadOffsets); + else { + // this is rare, but possible use case - X and Z might have different + // shapes/strides/orders. In this case we prepare and pass separate TAD info + + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + z->shapeInfo(), dims); + + auto zTadShape = + Environment::getInstance()->isCPU() + ? packZ.primaryShapeInfo() + : packZ + .specialShapeInfo(); //(Nd4jLong *) + //manager.replicatePointer(tadZ.tadOnlyShapeInfo, + //shape::shapeInfoByteLength(tadZ.tadOnlyShapeInfo)); + auto zTadOffsets = + Environment::getInstance()->isCPU() + ? packZ.primaryOffsets() + : packZ.specialOffsets(); //(Nd4jLong *) + //manager.replicatePointer(tadZ.tadOffsets, + //tadZ.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execBroadcast( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), dims.size(), + pTadShape, pTadOffsets, zTadShape, zTadOffsets); + } + + manager.synchronize(); + STORE_RESULT(*z); + + return Status::OK(); +} - LegacyBroadcastBoolOp::LegacyBroadcastBoolOp() : LegacyOp::LegacyOp(2) { - // - } +LegacyBroadcastBoolOp::LegacyBroadcastBoolOp() : LegacyOp::LegacyOp(2) { + // +} - LegacyBroadcastBoolOp::LegacyBroadcastBoolOp(int opNum) : LegacyOp::LegacyOp(2, opNum) { - // - } +LegacyBroadcastBoolOp::LegacyBroadcastBoolOp(int opNum) + : LegacyOp::LegacyOp(2, opNum) { + // +} - LegacyOp* LegacyBroadcastBoolOp::clone() { - return new LegacyBroadcastBoolOp(this->_opNum); - } +LegacyOp *LegacyBroadcastBoolOp::clone() { + return new LegacyBroadcastBoolOp(this->_opNum); +} - /** - * If external NDArray wasn't specified - the same shape is returned by all broadcast ops. - */ - ShapeList* LegacyBroadcastBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, DataType::BOOL))); - } - } +/** + * If external NDArray wasn't specified - the same shape is returned by all + * broadcast ops. + */ +ShapeList *LegacyBroadcastBoolOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(inShape, DataType::BOOL))); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp index 7dab80f0d108..516a88155d97 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp @@ -18,91 +18,125 @@ // Created by raver119 on 17.10.2017. // -#include +#include +#include +#include #include +#include #include -#include -#include -#include namespace sd { - namespace ops { - Nd4jStatus LegacyBroadcastOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x, y}); - - std::vector dims(block.getAxis()); - if (dims.size() == 0 && block.width() > 2) { - auto axis = INPUT_VARIABLE(2); - helpers::adjustAxis(x->rankOf(), axis, dims); - //dims = ShapeUtils::convertAxisToTadTarget(z->rankOf(), dims); - } - if (dims.size() > 0) - std::sort(dims.begin(), dims.end()); - - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x->shapeInfo(), dims); - - auto tadLen = shape::length(packX.primaryShapeInfo()); - REQUIRE_TRUE(tadLen == y->lengthOf(), 0, "Length of broadcast TAD should be equal to length of Y operand, but got [%i] vs [%i]",tadLen, (int) y->lengthOf()); - - PointersManager manager(block.launchContext(),"LegacyBroadcastOp"); - auto pTadShape = Environment::getInstance()->isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); - - if (x == z) - NativeOpExecutioner::execBroadcast(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), dims.data(), dims.size(), pTadShape, pTadOffsets, pTadShape, pTadOffsets); - else { - // this is rare, but possible use case - X and Z might have different shapes/strides/orders. In this case we prepare and pass separate TAD info - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(z->shapeInfo(), dims); - - auto zTadShape = Environment::getInstance()->isCPU() ? packZ.primaryShapeInfo() : packZ.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tadZ.tadOnlyShapeInfo, shape::shapeInfoByteLength(tadZ.tadOnlyShapeInfo)); - auto zTadOffsets = Environment::getInstance()->isCPU() ? packZ.primaryOffsets() : packZ.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tadZ.tadOffsets, tadZ.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execBroadcast(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), dims.size(), pTadShape, pTadOffsets, zTadShape, zTadOffsets); - } - - manager.synchronize(); - STORE_RESULT(*z); - - return Status::OK(); - } - - LegacyBroadcastOp::LegacyBroadcastOp() : LegacyOp::LegacyOp(2) { - // - } - - LegacyBroadcastOp::LegacyBroadcastOp(int opNum) : LegacyOp::LegacyOp(2, opNum) { - // - } +namespace ops { +Nd4jStatus LegacyBroadcastOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x, y}); + + std::vector dims(block.getAxis()); + if (dims.size() == 0 && block.width() > 2) { + auto axis = INPUT_VARIABLE(2); + helpers::adjustAxis(x->rankOf(), axis, dims); + // dims = ShapeUtils::convertAxisToTadTarget(z->rankOf(), dims); + } + if (dims.size() > 0) std::sort(dims.begin(), dims.end()); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x->shapeInfo(), dims); + + auto tadLen = shape::length(packX.primaryShapeInfo()); + REQUIRE_TRUE(tadLen == y->lengthOf(), 0, + "Length of broadcast TAD should be equal to length of Y " + "operand, but got [%i] vs [%i]", + tadLen, (int)y->lengthOf()); + + PointersManager manager(block.launchContext(), "LegacyBroadcastOp"); + auto pTadShape = + Environment::getInstance()->isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOnlyShapeInfo, + //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance()->isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOffsets, + //tad.numTads * sizeof(Nd4jLong)); + + if (x == z) + NativeOpExecutioner::execBroadcast( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), dims.size(), + pTadShape, pTadOffsets, pTadShape, pTadOffsets); + else { + // this is rare, but possible use case - X and Z might have different + // shapes/strides/orders. In this case we prepare and pass separate TAD info + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + z->shapeInfo(), dims); + + auto zTadShape = + Environment::getInstance()->isCPU() + ? packZ.primaryShapeInfo() + : packZ + .specialShapeInfo(); //(Nd4jLong *) + //manager.replicatePointer(tadZ.tadOnlyShapeInfo, + //shape::shapeInfoByteLength(tadZ.tadOnlyShapeInfo)); + auto zTadOffsets = + Environment::getInstance()->isCPU() + ? packZ.primaryOffsets() + : packZ.specialOffsets(); //(Nd4jLong *) + //manager.replicatePointer(tadZ.tadOffsets, + //tadZ.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execBroadcast( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), dims.size(), + pTadShape, pTadOffsets, zTadShape, zTadOffsets); + } + + manager.synchronize(); + STORE_RESULT(*z); + + return Status::OK(); +} - LegacyOp* LegacyBroadcastOp::clone() { - return new LegacyBroadcastOp(this->_opNum); - } +LegacyBroadcastOp::LegacyBroadcastOp() : LegacyOp::LegacyOp(2) { + // +} - /** - * If external NDArray wasn't specified - the same shape is returned by all broadcast ops. - */ - ShapeList* LegacyBroadcastOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); +LegacyBroadcastOp::LegacyBroadcastOp(int opNum) : LegacyOp::LegacyOp(2, opNum) { + // +} - // FIXME: remove memcpy - Nd4jLong *newShape; - ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(inShape), Nd4jLong); - memcpy(newShape, inShape, shape::shapeInfoByteLength(inShape)); +LegacyOp *LegacyBroadcastOp::clone() { + return new LegacyBroadcastOp(this->_opNum); +} - return SHAPELIST(CONSTANT(newShape)); - } - } +/** + * If external NDArray wasn't specified - the same shape is returned by all + * broadcast ops. + */ +ShapeList *LegacyBroadcastOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + // FIXME: remove memcpy + Nd4jLong *newShape; + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(inShape), + Nd4jLong); + memcpy(newShape, inShape, shape::shapeInfoByteLength(inShape)); + + return SHAPELIST(CONSTANT(newShape)); } +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp index baba609fbb04..e6a11814819e 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyIndexReduceOp.cpp @@ -18,180 +18,194 @@ // Created by raver119 on 16.10.2017. // -#include -#include -#include #include #include - +#include +#include +#include namespace sd { - namespace ops { - LegacyIndexReduceOp::LegacyIndexReduceOp() : LegacyOp::LegacyOp(1){ - // - } - - LegacyIndexReduceOp::LegacyIndexReduceOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - // - } - - LegacyOp* LegacyIndexReduceOp::clone() { - return new LegacyIndexReduceOp(this->_opNum); - } - - ShapeList *LegacyIndexReduceOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - if (block.getAxis().size() == 0 && block.width() == 1) { - Nd4jLong *newShape; - // in this case we just return scalar - ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); - newShape[0] = 2; - newShape[1] = 1; - newShape[2] = 1; - newShape[3] = 1; - newShape[4] = 1; - newShape[6] = 1; - newShape[7] = 99; - - auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newShape, DataType::INT64)); - RELEASE(newShape, block.workspace()); - return SHAPELIST(result); - } else if (block.getAxis().size()){ - // in this case we're building proper shape for reduction - auto array = INPUT_VARIABLE(0); //new NDArray(nullptr, inShape, block.workspace()); - - auto newShape = ShapeUtils::evalReduceShapeInfo('c', block.getAxis(), *array, DataType::INT64, false, true, block.workspace()); - return SHAPELIST(newShape); - } - else { - bool allAxes = false; - auto indices = INPUT_VARIABLE(1); - Nd4jLong rank = shape::rank(inShape); - if (indices->lengthOf() == rank) - allAxes = true; - - std::vector axis(indices->lengthOf()); - for (int e = 0; e < indices->lengthOf(); e++) { - // lol otherwise we segfault on macOS - int f = indices->e(e); - axis[e] = f >= 0 ? f : f += rank; - } - if (allAxes){ - Nd4jLong *newShape; - // in this case we just return scalar - ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); - newShape[0] = 2; - newShape[1] = 1; - newShape[2] = 1; - newShape[3] = 1; - newShape[4] = 1; - newShape[6] = 1; - newShape[7] = 99; - - auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newShape, DataType::INT64)); - RELEASE(newShape, block.workspace()); - return SHAPELIST(result); - } else { - // in this case we're building proper shape for reduction - auto array = INPUT_VARIABLE(0); //new NDArray(nullptr, inShape, block.workspace()); - return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', axis, *array, DataType::INT64, false, true, block.workspace())); - } - } - } - - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - Nd4jStatus LegacyIndexReduceOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x}); - - if (z->dataType() != INT64) { - throw std::runtime_error("IndexReduce operations require output to be INT64"); - } - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - bool allAxes = false; - - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyIndexReduceOp"); - - if (block.width() == 1) { - if (block.getAxis().size() == 0) { - // scalar - NativeOpExecutioner::execIndexReduceScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), - x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), - z->buffer(), z->shapeInfo(), - z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - std::vector dims(block.getAxis().size()); - for (size_t e = 0; e < dims.size(); e++) { - auto axe = block.getAxis().at(e); - dims[e] = axe < 0 ? axe + x->rankOf(): axe; - } - if (dims.size() > 1) - std::sort(dims.begin(), dims.end()); - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(x->shapeInfo(), dims); - - NativeOpExecutioner::execIndexReduce(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), - x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), - reinterpret_cast(z->buffer()), z->shapeInfo(), - z->specialBuffer(), z->specialShapeInfo(), - nullptr, (int) dims.size(), - Environment::getInstance()->isCPU() ? tadPack.primaryShapeInfo() : tadPack.specialShapeInfo(), Environment::getInstance()->isCPU() ? tadPack.primaryOffsets() : tadPack.specialOffsets()); - } - } else { - // TF mode - auto indices = INPUT_VARIABLE(1); - if (indices->lengthOf() == x->rankOf()) - allAxes = true; - - std::vector axis(indices->lengthOf()); - for (int e = 0; e < indices->lengthOf(); e++) { - // lol otherwise we segfault on macOS - int f = indices->e(e); - axis[e] = f >= 0 ? f : f += x->rankOf(); - } - - if (allAxes) { - NativeOpExecutioner::execIndexReduceScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), - x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), - z->buffer(), z->shapeInfo(), z->specialBuffer(), - z->specialShapeInfo()); - - } else { - if (indices->lengthOf() > 1) - std::sort(axis.begin(), axis.end()); - - REQUIRE_TRUE(axis.size() > 0, 0, "Some dimensions required for reduction!"); - - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(x->shapeInfo(), axis); - - NativeOpExecutioner::execIndexReduce(block.launchContext(), opNum, - x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), - reinterpret_cast(z->buffer()), - z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - nullptr, (int) axis.size(), - Environment::getInstance()->isCPU() ? tadPack.primaryShapeInfo() : tadPack.specialShapeInfo(), - Environment::getInstance()->isCPU() ? tadPack.primaryOffsets() : tadPack.specialOffsets()); - } - } - - manager.synchronize(); - STORE_RESULT(*z); - - return Status::OK(); - } +namespace ops { +LegacyIndexReduceOp::LegacyIndexReduceOp() : LegacyOp::LegacyOp(1) { + // +} + +LegacyIndexReduceOp::LegacyIndexReduceOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // +} + +LegacyOp *LegacyIndexReduceOp::clone() { + return new LegacyIndexReduceOp(this->_opNum); +} + +ShapeList *LegacyIndexReduceOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + if (block.getAxis().size() == 0 && block.width() == 1) { + Nd4jLong *newShape; + // in this case we just return scalar + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); + newShape[0] = 2; + newShape[1] = 1; + newShape[2] = 1; + newShape[3] = 1; + newShape[4] = 1; + newShape[6] = 1; + newShape[7] = 99; + + auto result = ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(newShape, DataType::INT64)); + RELEASE(newShape, block.workspace()); + return SHAPELIST(result); + } else if (block.getAxis().size()) { + // in this case we're building proper shape for reduction + auto array = + INPUT_VARIABLE(0); // new NDArray(nullptr, inShape, block.workspace()); + + auto newShape = ShapeUtils::evalReduceShapeInfo( + 'c', block.getAxis(), *array, DataType::INT64, false, true, + block.workspace()); + return SHAPELIST(newShape); + } else { + bool allAxes = false; + auto indices = INPUT_VARIABLE(1); + Nd4jLong rank = shape::rank(inShape); + if (indices->lengthOf() == rank) allAxes = true; + + std::vector axis(indices->lengthOf()); + for (int e = 0; e < indices->lengthOf(); e++) { + // lol otherwise we segfault on macOS + int f = indices->e(e); + axis[e] = f >= 0 ? f : f += rank; } -} \ No newline at end of file + if (allAxes) { + Nd4jLong *newShape; + // in this case we just return scalar + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), + Nd4jLong); + newShape[0] = 2; + newShape[1] = 1; + newShape[2] = 1; + newShape[3] = 1; + newShape[4] = 1; + newShape[6] = 1; + newShape[7] = 99; + + auto result = ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(newShape, DataType::INT64)); + RELEASE(newShape, block.workspace()); + return SHAPELIST(result); + } else { + // in this case we're building proper shape for reduction + auto array = INPUT_VARIABLE( + 0); // new NDArray(nullptr, inShape, block.workspace()); + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + 'c', axis, *array, DataType::INT64, false, true, block.workspace())); + } + } +} + +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +Nd4jStatus LegacyIndexReduceOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x}); + + if (z->dataType() != INT64) { + throw std::runtime_error( + "IndexReduce operations require output to be INT64"); + } + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + bool allAxes = false; + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyIndexReduceOp"); + + if (block.width() == 1) { + if (block.getAxis().size() == 0) { + // scalar + NativeOpExecutioner::execIndexReduceScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + std::vector dims(block.getAxis().size()); + for (size_t e = 0; e < dims.size(); e++) { + auto axe = block.getAxis().at(e); + dims[e] = axe < 0 ? axe + x->rankOf() : axe; + } + if (dims.size() > 1) std::sort(dims.begin(), dims.end()); + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x->shapeInfo(), dims); + + NativeOpExecutioner::execIndexReduce( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), + reinterpret_cast(z->buffer()), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), nullptr, (int)dims.size(), + Environment::getInstance()->isCPU() ? tadPack.primaryShapeInfo() + : tadPack.specialShapeInfo(), + Environment::getInstance()->isCPU() ? tadPack.primaryOffsets() + : tadPack.specialOffsets()); + } + } else { + // TF mode + auto indices = INPUT_VARIABLE(1); + if (indices->lengthOf() == x->rankOf()) allAxes = true; + + std::vector axis(indices->lengthOf()); + for (int e = 0; e < indices->lengthOf(); e++) { + // lol otherwise we segfault on macOS + int f = indices->e(e); + axis[e] = f >= 0 ? f : f += x->rankOf(); + } + + if (allAxes) { + NativeOpExecutioner::execIndexReduceScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + + } else { + if (indices->lengthOf() > 1) std::sort(axis.begin(), axis.end()); + + REQUIRE_TRUE(axis.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x->shapeInfo(), axis); + + NativeOpExecutioner::execIndexReduce( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), + reinterpret_cast(z->buffer()), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), nullptr, (int)axis.size(), + Environment::getInstance()->isCPU() ? tadPack.primaryShapeInfo() + : tadPack.specialShapeInfo(), + Environment::getInstance()->isCPU() ? tadPack.primaryOffsets() + : tadPack.specialOffsets()); + } + } + + manager.synchronize(); + STORE_RESULT(*z); + + return Status::OK(); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyOp.cpp index dfa4d42a0002..e5afc53cf39c 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyOp.cpp @@ -20,46 +20,45 @@ #include - namespace sd { - namespace ops { - LegacyOp::LegacyOp(int numInputs) : DeclarableOp::DeclarableOp(numInputs , 1, "LegacyOp", false) { - _numInputs = numInputs; - } +namespace ops { +LegacyOp::LegacyOp(int numInputs) + : DeclarableOp::DeclarableOp(numInputs, 1, "LegacyOp", false) { + _numInputs = numInputs; +} - LegacyOp::LegacyOp(int numInputs, int opNum) : DeclarableOp::DeclarableOp(numInputs , 1, "LegacyOp", false) { - _opNum = opNum; - _numInputs = numInputs; - } +LegacyOp::LegacyOp(int numInputs, int opNum) + : DeclarableOp::DeclarableOp(numInputs, 1, "LegacyOp", false) { + _opNum = opNum; + _numInputs = numInputs; +} - LegacyOp::LegacyOp(const LegacyOp &other) noexcept { - _numInputs = other._numInputs; - _opNum = other._opNum; - } +LegacyOp::LegacyOp(const LegacyOp &other) noexcept { + _numInputs = other._numInputs; + _opNum = other._opNum; +} - LegacyOp &LegacyOp::operator=(const LegacyOp &other) noexcept { - if (this == &other) - return *this; +LegacyOp &LegacyOp::operator=(const LegacyOp &other) noexcept { + if (this == &other) return *this; - _numInputs = other._numInputs; - _opNum = other._opNum; + _numInputs = other._numInputs; + _opNum = other._opNum; - return *this; - } + return *this; +} - LegacyOp::LegacyOp(LegacyOp &&other) noexcept { - _numInputs = other._numInputs; - _opNum = other._opNum; - } +LegacyOp::LegacyOp(LegacyOp &&other) noexcept { + _numInputs = other._numInputs; + _opNum = other._opNum; +} - LegacyOp &LegacyOp::operator=(LegacyOp &&other) noexcept { - if (this == &other) - return *this; +LegacyOp &LegacyOp::operator=(LegacyOp &&other) noexcept { + if (this == &other) return *this; - _numInputs = other._numInputs; - _opNum = other._opNum; + _numInputs = other._numInputs; + _opNum = other._opNum; - return *this; - } - } -} \ No newline at end of file + return *this; +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformBoolOp.cpp index 36a0b5a20415..f1008ab51df0 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformBoolOp.cpp @@ -21,53 +21,64 @@ #include #include - namespace sd { - namespace ops { - LegacyPairwiseTransformBoolOp::LegacyPairwiseTransformBoolOp() : LegacyOp::LegacyOp(2) { - // just a no-op - } +namespace ops { +LegacyPairwiseTransformBoolOp::LegacyPairwiseTransformBoolOp() + : LegacyOp::LegacyOp(2) { + // just a no-op +} - LegacyPairwiseTransformBoolOp::LegacyPairwiseTransformBoolOp(int opNum) : LegacyOp::LegacyOp(2, opNum) { - // just a no-op - } +LegacyPairwiseTransformBoolOp::LegacyPairwiseTransformBoolOp(int opNum) + : LegacyOp::LegacyOp(2, opNum) { + // just a no-op +} - LegacyOp* LegacyPairwiseTransformBoolOp::clone() { - return new LegacyPairwiseTransformBoolOp(this->_opNum); - } +LegacyOp *LegacyPairwiseTransformBoolOp::clone() { + return new LegacyPairwiseTransformBoolOp(this->_opNum); +} - Nd4jStatus LegacyPairwiseTransformBoolOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); +Nd4jStatus LegacyPairwiseTransformBoolOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - NDArray::prepareSpecialUse({z}, {x, y}); + NDArray::prepareSpecialUse({z}, {x, y}); - if (!x->isSameShape(y)) - REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, "Node_%i: For Pairwise transforms shapes of both operands should be equal but got %s vs %s", block.getNodeId(), ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); + if (!x->isSameShape(y)) + REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, + "Node_%i: For Pairwise transforms shapes of both operands " + "should be equal but got %s vs %s", + block.getNodeId(), ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyPairwiseTransformBoolOp"); + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), + "LegacyPairwiseTransformBoolOp"); - NativeOpExecutioner::execPairwiseTransform(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - extras.argumentsAsT(x->dataType())); + NativeOpExecutioner::execPairwiseTransform( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), + extras.argumentsAsT(x->dataType())); - manager.synchronize(); - STORE_RESULT(*z); + manager.synchronize(); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - /** - * Output shape of PWT operations always the same as input[0] shape, no exclusions. - */ - ShapeList *LegacyPairwiseTransformBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, DataType::BOOL))); - } - } -} \ No newline at end of file +/** + * Output shape of PWT operations always the same as input[0] shape, no + * exclusions. + */ +ShapeList *LegacyPairwiseTransformBoolOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(inShape, DataType::BOOL))); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformOp.cpp index be4307de52ea..92111c80a1fa 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyPairwiseTransformOp.cpp @@ -21,57 +21,65 @@ #include #include - namespace sd { - namespace ops { - LegacyPairwiseTransformOp::LegacyPairwiseTransformOp() : LegacyOp::LegacyOp(2) { - this->getOpDescriptor()->allowInplace(true); - } - - LegacyPairwiseTransformOp::LegacyPairwiseTransformOp(int opNum) : LegacyOp::LegacyOp(2, opNum) { - this->getOpDescriptor()->allowInplace(true); - } - - LegacyOp* LegacyPairwiseTransformOp::clone() { - return new LegacyPairwiseTransformOp(this->_opNum); - } - - Nd4jStatus LegacyPairwiseTransformOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x, y}); - - if (!x->isSameShape(y)) - REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, "Node_%i: For Pairwise transforms shapes of both operands should be equal but got %s vs %s", block.getNodeId(), ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyPairwiseTransformOp"); - - NativeOpExecutioner::execPairwiseTransform(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - extras.argumentsAsT(z->dataType())); - - manager.synchronize(); - STORE_RESULT(*z); - - return Status::OK(); - } - - /** - * Output shape of PWT operations always the same as input[0] shape, no exclusions. - */ - ShapeList *LegacyPairwiseTransformOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - } -} \ No newline at end of file +namespace ops { +LegacyPairwiseTransformOp::LegacyPairwiseTransformOp() : LegacyOp::LegacyOp(2) { + this->getOpDescriptor()->allowInplace(true); +} + +LegacyPairwiseTransformOp::LegacyPairwiseTransformOp(int opNum) + : LegacyOp::LegacyOp(2, opNum) { + this->getOpDescriptor()->allowInplace(true); +} + +LegacyOp *LegacyPairwiseTransformOp::clone() { + return new LegacyPairwiseTransformOp(this->_opNum); +} + +Nd4jStatus LegacyPairwiseTransformOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x, y}); + + if (!x->isSameShape(y)) + REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, + "Node_%i: For Pairwise transforms shapes of both operands " + "should be equal but got %s vs %s", + block.getNodeId(), ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyPairwiseTransformOp"); + + NativeOpExecutioner::execPairwiseTransform( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), + extras.argumentsAsT(z->dataType())); + + manager.synchronize(); + STORE_RESULT(*z); + + return Status::OK(); +} + +/** + * Output shape of PWT operations always the same as input[0] shape, no + * exclusions. + */ +ShapeList *LegacyPairwiseTransformOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp index dcf3b7d9efe7..5470a6af37cb 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp @@ -18,428 +18,467 @@ // Created by raver119 on 16.10.2017. // -#include -#include -#include #include #include +#include +#include #include +#include namespace sd { - namespace ops { - LegacyRandomOp::LegacyRandomOp() : LegacyOp::LegacyOp(1) { - // just a no-op - } - - LegacyRandomOp::LegacyRandomOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - // just a no-op - } - - LegacyOp* LegacyRandomOp::clone() { - return new LegacyRandomOp(this->_opNum); - } - - template - Nd4jStatus LegacyRandomOp::validateAndExecute_(Context &block) { - auto input = INPUT_VARIABLE(0); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - /* - (0, randomOps::UniformDistribution) ,\ - (1, randomOps::DropOut) ,\ - (2, randomOps::DropOutInverted) ,\ - (3, randomOps::ProbablisticMerge) ,\ - (4, randomOps::Linspace) ,\ - (5, randomOps::Choice) ,\ - (6, randomOps::GaussianDistribution) ,\ - (7, randomOps::BernoulliDistribution) ,\ - (8, randomOps::BinomialDistribution),\ - (9, randomOps::BinomialDistributionEx),\ - (10, randomOps::LogNormalDistribution) ,\ - (11, randomOps::TruncatedNormalDistribution) ,\ - (12, randomOps::AlphaDropOut) - */ - switch(opNum) { - case sd::random::UniformDistribution: { - // uniform distribution - T from, to; - if (block.width() > 2) { - auto arg1 = INPUT_VARIABLE(1); - auto arg2 = INPUT_VARIABLE(2); - REQUIRE_TRUE(arg1->isScalar(), 0, "Uniform: Second argument must be scalar"); - REQUIRE_TRUE(arg2->isScalar(), 0, "Uniform: Third argument must be scalar"); - - from = arg1->e(0); - to = arg2->e(0); - } else if (block.numT() == 2) { - from = T_ARG(0); - to = T_ARG(1); - } else { - REQUIRE_TRUE(false, 0, "Uniform requires either TArgs or 3 arguments to be present"); - } - - auto z = OUTPUT_VARIABLE(0); //NDArrayFactory::create_('c', shape, block.workspace()); - - RandomLauncher::fillUniform(block.launchContext(), block.randomGenerator(), z, from, to); - - // FIXME: - //OVERWRITE_RESULT(z); - } - break; - case sd::random::DropOut: { - auto z = OUTPUT_VARIABLE(0); - - T prob; - if (block.width() > 1) { - auto arg = INPUT_VARIABLE(1); - REQUIRE_TRUE(arg->isScalar(), 0, "DropOut: Second argument must be scalar"); - - prob = arg->e(0); - } else if (block.numT() > 0) { - prob = T_ARG(0); - } else { - REQUIRE_TRUE(false, 0, "DropOut requires either TArgs or second argument to be present"); - } - - if (!block.isInplace()) - z->assign(input); - - RandomLauncher::applyDropOut(block.launchContext(), block.randomGenerator(), z, prob); - } - break; - case sd::random::DropOutInverted: { - auto z = OUTPUT_VARIABLE(0); - sd::ops::dropout op; - return op.execute(&block); - } - break; - case sd::random::GaussianDistribution: { - // gaussian distribution - T mean, stdev; - if (block.width() > 2) { - auto arg1 = INPUT_VARIABLE(1); - auto arg2 = INPUT_VARIABLE(2); - REQUIRE_TRUE(arg1->isScalar(), 0, "Gaussian: Second argument must be scalar"); - REQUIRE_TRUE(arg2->isScalar(), 0, "Gaussian: Third argument must be scalar"); - - mean = arg1->e(0); - stdev = arg2->e(0); - } else if (block.numT() == 2) { - mean = T_ARG(0); - stdev = T_ARG(1); - } else { - REQUIRE_TRUE(false, 0, "Gaussian requires either TArgs or 3 arguments to be present"); - } - - REQUIRE_TRUE(input->isVector(), 0, "Gaussian requires pure shape as first argument"); - - std::vector shape(input->lengthOf()); - for (int e = 0; e < input->lengthOf(); e++) - shape[e] = input->e(e); - - auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_('c', shape, block.workspace()); - - RandomLauncher::fillGaussian(block.launchContext(), block.randomGenerator(), z, mean, stdev); - - // FIXME: !! - //OVERWRITE_RESULT(z); - } - break; - case sd::random::BernoulliDistribution: { - // bernoulli distribution - T prob; - if (block.width() > 1) { - auto arg1 = INPUT_VARIABLE(1); - REQUIRE_TRUE(arg1->isScalar(), 0, "Bernoulli: Second argument must be scalar"); - - prob = arg1->e(0); - } else if (block.numT() > 0) { - prob = T_ARG(0); - } else { - REQUIRE_TRUE(false, 0, "Bernoulli requires either 1 TArg or 2 arguments to be present"); - } - - REQUIRE_TRUE(input->isVector(), 0, "Bernoulli requires pure shape as first argument"); - - std::vector shape(input->lengthOf()); - for (int e = 0; e < input->lengthOf(); e++) - shape[e] = input->e(e); - - auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_('c', shape, block.workspace()); - - RandomLauncher::fillBernoulli(block.launchContext(), block.randomGenerator(), z, prob); - - // FIXME: - //OVERWRITE_RESULT(z); - } - break; - case sd::random::BinomialDistributionEx: { - // BinomialEx distribution - T prob; - int trials; - if (block.width() > 2) { - auto arg1 = INPUT_VARIABLE(1); - auto arg2 = INPUT_VARIABLE(2); - REQUIRE_TRUE(arg1->isScalar(), 0, "Binomial: Second argument must be scalar"); - REQUIRE_TRUE(arg2->isScalar(), 0, "Binomial: Third argument must be scalar"); - - trials = arg1->e(0); - prob = arg2->e(0); - } else if (block.numT() == 1 && block.numI() == 1) { - trials = INT_ARG(0); - prob = T_ARG(0); - } else { - REQUIRE_TRUE(false, 0, "Binomial requires either TArgs/IArgs or 3 arguments to be present"); - } - - REQUIRE_TRUE(input->isVector(), 0, "Binomial requires pure shape as first argument"); - - std::vector shape(input->lengthOf()); - for (int e = 0; e < input->lengthOf(); e++) - shape[e] = input->e(e); - - auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_('c', shape, block.workspace()); - - RandomLauncher::fillBinomial(block.launchContext(), block.randomGenerator(), z, trials, prob); - - // FIXME: !!! - //OVERWRITE_RESULT(z); - } - break; - case sd::random::LogNormalDistribution: { - // lognorm distribution - T mean, stdev; - if (block.width() > 2) { - auto arg1 = INPUT_VARIABLE(1); - auto arg2 = INPUT_VARIABLE(2); - REQUIRE_TRUE(arg1->isScalar(), 0, "LogNormal: Second argument must be scalar"); - REQUIRE_TRUE(arg2->isScalar(), 0, "LogNormal: Third argument must be scalar"); - - mean = arg1->e(0); - stdev = arg2->e(0); - } else if (block.numT() == 2) { - mean = T_ARG(0); - stdev = T_ARG(1); - } else { - REQUIRE_TRUE(false, 0, "LogNormal requires either TArgs or 3 arguments to be present"); - } - - REQUIRE_TRUE(input->isVector(), 0, "LogNormal requires pure shape as first argument"); - - std::vector shape(input->lengthOf()); - for (int e = 0; e < input->lengthOf(); e++) - shape[e] = input->e(e); - - auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_('c', shape, block.workspace()); - - RandomLauncher::fillLogNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev); - - // FIXME: !! - //OVERWRITE_RESULT(z); - } - break; - case sd::random::TruncatedNormalDistribution: { - // truncated norm distribution - T mean, stdev; - if (block.width() > 2) { - auto arg1 = INPUT_VARIABLE(1); - auto arg2 = INPUT_VARIABLE(2); - REQUIRE_TRUE(arg1->isScalar(), 0, "TruncatedNormal: Second argument must be scalar"); - REQUIRE_TRUE(arg2->isScalar(), 0, "TruncatedNormal: Third argument must be scalar"); - - mean = arg1->e(0); - stdev = arg2->e(0); - } else if (block.numT() == 2) { - mean = T_ARG(0); - stdev = T_ARG(1); - } else { - REQUIRE_TRUE(false, 0, "TruncatedNormal requires either TArgs or 3 arguments to be present"); - } - - REQUIRE_TRUE(input->isVector(), 0, "TruncatedNormal requires pure shape as first argument"); - - std::vector shape(input->lengthOf()); - for (int e = 0; e < input->lengthOf(); e++) - shape[e] = input->e(e); - - auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_('c', shape, block.workspace()); - - RandomLauncher::fillTruncatedNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev); - } - break; - case sd::random::AlphaDropOut: { - auto z = OUTPUT_VARIABLE(0); - - T prob, a, b, pa; - if (block.width() > 4) { - auto arg1 = INPUT_VARIABLE(1); - auto arg2 = INPUT_VARIABLE(2); - auto arg3 = INPUT_VARIABLE(3); - auto arg4 = INPUT_VARIABLE(4); - REQUIRE_TRUE(arg1->isScalar(), 0, "AlphaDropOut: Second argument must be scalar"); - REQUIRE_TRUE(arg2->isScalar(), 0, "AlphaDropOut: Third argument must be scalar"); - REQUIRE_TRUE(arg3->isScalar(), 0, "AlphaDropOut: Fourth argument must be scalar"); - REQUIRE_TRUE(arg4->isScalar(), 0, "AlphaDropOut: Fifth argument must be scalar"); - - prob = arg1->e(0); - a = arg2->e(0); - b = arg3->e(0); - pa = arg4->e(0); - } else if (block.numT() == 4) { - prob = T_ARG(0); - a = T_ARG(1); - b = T_ARG(2); - pa = T_ARG(3); - } else { - REQUIRE_TRUE(false, 0, "AlphaDropOut requires either TArgs or 5 arguments to be present"); - } - - if (!block.isInplace()) - z->assign(input); - - RandomLauncher::applyAlphaDropOut(block.launchContext(), block.randomGenerator(), z, prob, a, b, pa); - } - break; - case sd::random::Linspace: { - auto z = OUTPUT_VARIABLE(0); - auto start = INPUT_VARIABLE(0); - auto finish = INPUT_VARIABLE(1); - auto numOfElements = INPUT_VARIABLE(2); - - z->linspace(start->e(0), (finish->e(0) - start->e(0)) / (numOfElements->e(0) - 1.)); - } - break; - default: { - nd4j_printf("Unknown random op requested: [%i]\n", opNum); - return ND4J_STATUS_KERNEL_FAILURE; - } - } - - return Status::OK(); - } - - Nd4jStatus LegacyRandomOp::validateAndExecute(Context &block) { -// REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be provided for LegacyRandomOp, but got NULL instead at node_%i", block.nodeId()) - - auto z = OUTPUT_VARIABLE(0); - BUILD_SINGLE_SELECTOR(z->dataType(), return validateAndExecute_, (block), FLOAT_TYPES); - } - - /** - * For transform operations, output shape always equals to input shape. With just a few exclusions, like im2col and col2im. - * But these ops already have CustomOp implementations. - * - */ - ShapeList *LegacyRandomOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - auto xType = ArrayOptions::dataType(inShape); - Nd4jLong *newShape; - if (DataTypeUtils::isR(xType)) { - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } else if (DataTypeUtils::isZ(xType)) { - auto zShapeArr = INPUT_VARIABLE(0); - auto zShapeVector = zShapeArr->asVectorT(); - auto dtype = block.numD() > 0 ? D_ARG(0) : sd::DataType::FLOAT32; - - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', zShapeVector)); - } else - throw std::runtime_error("LegacyRandomOp: Unknown input data type!"); - } - - Nd4jStatus LegacyRandomOp::execute(Context* block) { - return DeclarableOp::execute(block); - } - - - sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, const std::vector& inputs, const std::vector& tArgs, const std::vector& iArgs, const std::vector& dArgs, bool isInplace) { - VariableSpace variableSpace; - ResultSet arrayList; - //ResultSet arrayList; - - if (isInplace) - arrayList.setNonRemovable(); - - int cnt = -1; - std::vector in; - for (auto v: inputs) { - if (v == nullptr) - continue; - - auto var = std::make_shared(*v, "", cnt); - - in.push_back(cnt); - variableSpace.putVariable(cnt--, var); - } - - Context block(1, &variableSpace, false); - // FIX ME: implement setRng method - block.setRng(rng); - block.fillInputs(in); - block.markInplace(isInplace); - - for (int e = 0; e < tArgs.size(); e++) - block.appendT(tArgs.at(e)); - - - for (int e = 0; e < iArgs.size(); e++) - block.appendI(iArgs.at(e)); - - for (int e = 0; e < dArgs.size(); e++) - block.appendD(dArgs.at(e)); - - Nd4jStatus status = this->execute(&block); - arrayList.setStatus(status); - if (status != ND4J_STATUS_OK) - return arrayList; - - - for (int e = 0; e < DataTypeUtils::max(); e++) { - std::pair pair(1, e); - if (variableSpace.hasVariable(pair)) { - auto var = variableSpace.getVariable(pair); - auto arr = var->getNDArray(); - if (!arr->isAttached()) { - var->markRemovable(false); - arrayList.push_back(*arr.get()); - } else { - arrayList.push_back(arr->detach()); - } - } else - break; - } - - return arrayList; - } - - Nd4jStatus LegacyRandomOp::validateDataTypes(Context& block) { - if (block.isFastPath()) { - // in this case we'll roll through pre-defined outputs - auto fpo = block.fastpath_out(); - for (auto v:fpo) { - if (v != nullptr) { - if (!v->isR()) - return ND4J_STATUS_BAD_ARGUMENTS; - } - } - } else { - std::pair pair(block.nodeId(), 0); - if (block.getVariableSpace()->hasVariable(pair)) { - auto var = block.variable(pair); - if (!var->hasNDArray()) - return ND4J_STATUS_BAD_ARGUMENTS; - - auto arr = var->getNDArray(); - if (!arr->isR()) - return ND4J_STATUS_BAD_ARGUMENTS; - } - } - - return Status::OK(); - } - - BUILD_SINGLE_TEMPLATE(template Nd4jStatus LegacyRandomOp::validateAndExecute_, (Context&), FLOAT_TYPES); +namespace ops { +LegacyRandomOp::LegacyRandomOp() : LegacyOp::LegacyOp(1) { + // just a no-op +} + +LegacyRandomOp::LegacyRandomOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { + // just a no-op +} + +LegacyOp* LegacyRandomOp::clone() { return new LegacyRandomOp(this->_opNum); } + +template +Nd4jStatus LegacyRandomOp::validateAndExecute_(Context& block) { + auto input = INPUT_VARIABLE(0); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + /* + (0, randomOps::UniformDistribution) ,\ + (1, randomOps::DropOut) ,\ + (2, randomOps::DropOutInverted) ,\ + (3, randomOps::ProbablisticMerge) ,\ + (4, randomOps::Linspace) ,\ + (5, randomOps::Choice) ,\ + (6, randomOps::GaussianDistribution) ,\ + (7, randomOps::BernoulliDistribution) ,\ + (8, randomOps::BinomialDistribution),\ + (9, randomOps::BinomialDistributionEx),\ + (10, randomOps::LogNormalDistribution) ,\ + (11, randomOps::TruncatedNormalDistribution) ,\ + (12, randomOps::AlphaDropOut) + */ + switch (opNum) { + case sd::random::UniformDistribution: { + // uniform distribution + T from, to; + if (block.width() > 2) { + auto arg1 = INPUT_VARIABLE(1); + auto arg2 = INPUT_VARIABLE(2); + REQUIRE_TRUE(arg1->isScalar(), 0, + "Uniform: Second argument must be scalar"); + REQUIRE_TRUE(arg2->isScalar(), 0, + "Uniform: Third argument must be scalar"); + + from = arg1->e(0); + to = arg2->e(0); + } else if (block.numT() == 2) { + from = T_ARG(0); + to = T_ARG(1); + } else { + REQUIRE_TRUE( + false, 0, + "Uniform requires either TArgs or 3 arguments to be present"); + } + + auto z = OUTPUT_VARIABLE( + 0); // NDArrayFactory::create_('c', shape, block.workspace()); + + RandomLauncher::fillUniform(block.launchContext(), + block.randomGenerator(), z, from, to); + + // FIXME: + // OVERWRITE_RESULT(z); + } break; + case sd::random::DropOut: { + auto z = OUTPUT_VARIABLE(0); + + T prob; + if (block.width() > 1) { + auto arg = INPUT_VARIABLE(1); + REQUIRE_TRUE(arg->isScalar(), 0, + "DropOut: Second argument must be scalar"); + + prob = arg->e(0); + } else if (block.numT() > 0) { + prob = T_ARG(0); + } else { + REQUIRE_TRUE( + false, 0, + "DropOut requires either TArgs or second argument to be present"); + } + + if (!block.isInplace()) z->assign(input); + + RandomLauncher::applyDropOut(block.launchContext(), + block.randomGenerator(), z, prob); + } break; + case sd::random::DropOutInverted: { + auto z = OUTPUT_VARIABLE(0); + sd::ops::dropout op; + return op.execute(&block); + } break; + case sd::random::GaussianDistribution: { + // gaussian distribution + T mean, stdev; + if (block.width() > 2) { + auto arg1 = INPUT_VARIABLE(1); + auto arg2 = INPUT_VARIABLE(2); + REQUIRE_TRUE(arg1->isScalar(), 0, + "Gaussian: Second argument must be scalar"); + REQUIRE_TRUE(arg2->isScalar(), 0, + "Gaussian: Third argument must be scalar"); + + mean = arg1->e(0); + stdev = arg2->e(0); + } else if (block.numT() == 2) { + mean = T_ARG(0); + stdev = T_ARG(1); + } else { + REQUIRE_TRUE( + false, 0, + "Gaussian requires either TArgs or 3 arguments to be present"); + } + + REQUIRE_TRUE(input->isVector(), 0, + "Gaussian requires pure shape as first argument"); + + std::vector shape(input->lengthOf()); + for (int e = 0; e < input->lengthOf(); e++) + shape[e] = input->e(e); + + auto z = OUTPUT_VARIABLE( + 0); // NDArrayFactory::create_('c', shape, block.workspace()); + + RandomLauncher::fillGaussian(block.launchContext(), + block.randomGenerator(), z, mean, stdev); + + // FIXME: !! + // OVERWRITE_RESULT(z); + } break; + case sd::random::BernoulliDistribution: { + // bernoulli distribution + T prob; + if (block.width() > 1) { + auto arg1 = INPUT_VARIABLE(1); + REQUIRE_TRUE(arg1->isScalar(), 0, + "Bernoulli: Second argument must be scalar"); + + prob = arg1->e(0); + } else if (block.numT() > 0) { + prob = T_ARG(0); + } else { + REQUIRE_TRUE( + false, 0, + "Bernoulli requires either 1 TArg or 2 arguments to be present"); + } + + REQUIRE_TRUE(input->isVector(), 0, + "Bernoulli requires pure shape as first argument"); + + std::vector shape(input->lengthOf()); + for (int e = 0; e < input->lengthOf(); e++) + shape[e] = input->e(e); + + auto z = OUTPUT_VARIABLE( + 0); // NDArrayFactory::create_('c', shape, block.workspace()); + + RandomLauncher::fillBernoulli(block.launchContext(), + block.randomGenerator(), z, prob); + + // FIXME: + // OVERWRITE_RESULT(z); + } break; + case sd::random::BinomialDistributionEx: { + // BinomialEx distribution + T prob; + int trials; + if (block.width() > 2) { + auto arg1 = INPUT_VARIABLE(1); + auto arg2 = INPUT_VARIABLE(2); + REQUIRE_TRUE(arg1->isScalar(), 0, + "Binomial: Second argument must be scalar"); + REQUIRE_TRUE(arg2->isScalar(), 0, + "Binomial: Third argument must be scalar"); + + trials = arg1->e(0); + prob = arg2->e(0); + } else if (block.numT() == 1 && block.numI() == 1) { + trials = INT_ARG(0); + prob = T_ARG(0); + } else { + REQUIRE_TRUE(false, 0, + "Binomial requires either TArgs/IArgs or 3 arguments to " + "be present"); + } + + REQUIRE_TRUE(input->isVector(), 0, + "Binomial requires pure shape as first argument"); + + std::vector shape(input->lengthOf()); + for (int e = 0; e < input->lengthOf(); e++) + shape[e] = input->e(e); + + auto z = OUTPUT_VARIABLE( + 0); // NDArrayFactory::create_('c', shape, block.workspace()); + + RandomLauncher::fillBinomial(block.launchContext(), + block.randomGenerator(), z, trials, prob); + + // FIXME: !!! + // OVERWRITE_RESULT(z); + } break; + case sd::random::LogNormalDistribution: { + // lognorm distribution + T mean, stdev; + if (block.width() > 2) { + auto arg1 = INPUT_VARIABLE(1); + auto arg2 = INPUT_VARIABLE(2); + REQUIRE_TRUE(arg1->isScalar(), 0, + "LogNormal: Second argument must be scalar"); + REQUIRE_TRUE(arg2->isScalar(), 0, + "LogNormal: Third argument must be scalar"); + + mean = arg1->e(0); + stdev = arg2->e(0); + } else if (block.numT() == 2) { + mean = T_ARG(0); + stdev = T_ARG(1); + } else { + REQUIRE_TRUE( + false, 0, + "LogNormal requires either TArgs or 3 arguments to be present"); + } + + REQUIRE_TRUE(input->isVector(), 0, + "LogNormal requires pure shape as first argument"); + + std::vector shape(input->lengthOf()); + for (int e = 0; e < input->lengthOf(); e++) + shape[e] = input->e(e); + + auto z = OUTPUT_VARIABLE( + 0); // NDArrayFactory::create_('c', shape, block.workspace()); + + RandomLauncher::fillLogNormal(block.launchContext(), + block.randomGenerator(), z, mean, stdev); + + // FIXME: !! + // OVERWRITE_RESULT(z); + } break; + case sd::random::TruncatedNormalDistribution: { + // truncated norm distribution + T mean, stdev; + if (block.width() > 2) { + auto arg1 = INPUT_VARIABLE(1); + auto arg2 = INPUT_VARIABLE(2); + REQUIRE_TRUE(arg1->isScalar(), 0, + "TruncatedNormal: Second argument must be scalar"); + REQUIRE_TRUE(arg2->isScalar(), 0, + "TruncatedNormal: Third argument must be scalar"); + + mean = arg1->e(0); + stdev = arg2->e(0); + } else if (block.numT() == 2) { + mean = T_ARG(0); + stdev = T_ARG(1); + } else { + REQUIRE_TRUE(false, 0, + "TruncatedNormal requires either TArgs or 3 arguments to " + "be present"); + } + + REQUIRE_TRUE(input->isVector(), 0, + "TruncatedNormal requires pure shape as first argument"); + + std::vector shape(input->lengthOf()); + for (int e = 0; e < input->lengthOf(); e++) + shape[e] = input->e(e); + + auto z = OUTPUT_VARIABLE( + 0); // NDArrayFactory::create_('c', shape, block.workspace()); + + RandomLauncher::fillTruncatedNormal( + block.launchContext(), block.randomGenerator(), z, mean, stdev); + } break; + case sd::random::AlphaDropOut: { + auto z = OUTPUT_VARIABLE(0); + + T prob, a, b, pa; + if (block.width() > 4) { + auto arg1 = INPUT_VARIABLE(1); + auto arg2 = INPUT_VARIABLE(2); + auto arg3 = INPUT_VARIABLE(3); + auto arg4 = INPUT_VARIABLE(4); + REQUIRE_TRUE(arg1->isScalar(), 0, + "AlphaDropOut: Second argument must be scalar"); + REQUIRE_TRUE(arg2->isScalar(), 0, + "AlphaDropOut: Third argument must be scalar"); + REQUIRE_TRUE(arg3->isScalar(), 0, + "AlphaDropOut: Fourth argument must be scalar"); + REQUIRE_TRUE(arg4->isScalar(), 0, + "AlphaDropOut: Fifth argument must be scalar"); + + prob = arg1->e(0); + a = arg2->e(0); + b = arg3->e(0); + pa = arg4->e(0); + } else if (block.numT() == 4) { + prob = T_ARG(0); + a = T_ARG(1); + b = T_ARG(2); + pa = T_ARG(3); + } else { + REQUIRE_TRUE( + false, 0, + "AlphaDropOut requires either TArgs or 5 arguments to be present"); + } + + if (!block.isInplace()) z->assign(input); + + RandomLauncher::applyAlphaDropOut( + block.launchContext(), block.randomGenerator(), z, prob, a, b, pa); + } break; + case sd::random::Linspace: { + auto z = OUTPUT_VARIABLE(0); + auto start = INPUT_VARIABLE(0); + auto finish = INPUT_VARIABLE(1); + auto numOfElements = INPUT_VARIABLE(2); + + z->linspace(start->e(0), + (finish->e(0) - start->e(0)) / + (numOfElements->e(0) - 1.)); + } break; + default: { + nd4j_printf("Unknown random op requested: [%i]\n", opNum); + return ND4J_STATUS_KERNEL_FAILURE; } + } + + return Status::OK(); +} + +Nd4jStatus LegacyRandomOp::validateAndExecute(Context& block) { + // REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be + // provided for LegacyRandomOp, but got NULL instead at node_%i", + // block.nodeId()) + + auto z = OUTPUT_VARIABLE(0); + BUILD_SINGLE_SELECTOR(z->dataType(), return validateAndExecute_, (block), + FLOAT_TYPES); } + +/** + * For transform operations, output shape always equals to input shape. With + * just a few exclusions, like im2col and col2im. But these ops already have + * CustomOp implementations. + * + */ +ShapeList* LegacyRandomOp::calculateOutputShape(ShapeList* inputShape, + sd::graph::Context& block) { + auto inShape = inputShape->at(0); + auto xType = ArrayOptions::dataType(inShape); + Nd4jLong* newShape; + if (DataTypeUtils::isR(xType)) { + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); + } else if (DataTypeUtils::isZ(xType)) { + auto zShapeArr = INPUT_VARIABLE(0); + auto zShapeVector = zShapeArr->asVectorT(); + auto dtype = block.numD() > 0 ? D_ARG(0) : sd::DataType::FLOAT32; + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + dtype, 'c', zShapeVector)); + } else + throw std::runtime_error("LegacyRandomOp: Unknown input data type!"); +} + +Nd4jStatus LegacyRandomOp::execute(Context* block) { + return DeclarableOp::execute(block); +} + +sd::ResultSet LegacyRandomOp::execute(sd::graph::RandomGenerator& rng, + const std::vector& inputs, + const std::vector& tArgs, + const std::vector& iArgs, + const std::vector& dArgs, + bool isInplace) { + VariableSpace variableSpace; + ResultSet arrayList; + // ResultSet arrayList; + + if (isInplace) arrayList.setNonRemovable(); + + int cnt = -1; + std::vector in; + for (auto v : inputs) { + if (v == nullptr) continue; + + auto var = std::make_shared(*v, "", cnt); + + in.push_back(cnt); + variableSpace.putVariable(cnt--, var); + } + + Context block(1, &variableSpace, false); + // FIX ME: implement setRng method + block.setRng(rng); + block.fillInputs(in); + block.markInplace(isInplace); + + for (int e = 0; e < tArgs.size(); e++) block.appendT(tArgs.at(e)); + + for (int e = 0; e < iArgs.size(); e++) block.appendI(iArgs.at(e)); + + for (int e = 0; e < dArgs.size(); e++) block.appendD(dArgs.at(e)); + + Nd4jStatus status = this->execute(&block); + arrayList.setStatus(status); + if (status != ND4J_STATUS_OK) return arrayList; + + for (int e = 0; e < DataTypeUtils::max(); e++) { + std::pair pair(1, e); + if (variableSpace.hasVariable(pair)) { + auto var = variableSpace.getVariable(pair); + auto arr = var->getNDArray(); + if (!arr->isAttached()) { + var->markRemovable(false); + arrayList.push_back(*arr.get()); + } else { + arrayList.push_back(arr->detach()); + } + } else + break; + } + + return arrayList; +} + +Nd4jStatus LegacyRandomOp::validateDataTypes(Context& block) { + if (block.isFastPath()) { + // in this case we'll roll through pre-defined outputs + auto fpo = block.fastpath_out(); + for (auto v : fpo) { + if (v != nullptr) { + if (!v->isR()) return ND4J_STATUS_BAD_ARGUMENTS; + } + } + } else { + std::pair pair(block.nodeId(), 0); + if (block.getVariableSpace()->hasVariable(pair)) { + auto var = block.variable(pair); + if (!var->hasNDArray()) return ND4J_STATUS_BAD_ARGUMENTS; + + auto arr = var->getNDArray(); + if (!arr->isR()) return ND4J_STATUS_BAD_ARGUMENTS; + } + } + + return Status::OK(); +} + +BUILD_SINGLE_TEMPLATE(template Nd4jStatus LegacyRandomOp::validateAndExecute_, + (Context&), FLOAT_TYPES); +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp index d5e9b6c57c37..52fcd2f46c34 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp @@ -18,106 +18,139 @@ // Created by raver119 on 17.10.2017. // -#include +#include +#include #include #include -#include -#include +#include namespace sd { - namespace ops { - Nd4jStatus LegacyReduce3Op::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x, y}); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - nd4j_debug("Executing LegacyReduce3Op: [%i]\n", opNum); - - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyReduce3Op"); - - if (x->isSameShape(y) && (block.numI() == 0 || (block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()))) { - // reduce3 to scalar - NativeOpExecutioner::execReduce3Scalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - std::vector dims(block.getAxis()); - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += x->rankOf(); - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x->shapeInfo(), dims); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(z->shapeInfo(), dims); - - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions requuired for reduction!"); - - auto xTadShape = Environment::getInstance()->isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tadX.tadOnlyShapeInfo, shape::shapeInfoByteLength(tadX.tadOnlyShapeInfo)); - auto xTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tadX.tadOffsets, tadX.numTads * sizeof(Nd4jLong)); - - auto yTadShape = Environment::getInstance()->isCPU() ? packZ.primaryShapeInfo() : packZ.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tadY.tadOnlyShapeInfo, shape::shapeInfoByteLength(tadY.tadOnlyShapeInfo)); - auto yTadOffsets = Environment::getInstance()->isCPU() ? packZ.primaryOffsets() : packZ.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tadY.tadOffsets, tadY.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execReduce3(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), - y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), dims.size(), xTadShape, xTadOffsets, yTadShape, yTadOffsets); - } - - manager.synchronize(); - STORE_RESULT(*z); - - return Status::OK(); - } - - LegacyReduce3Op::LegacyReduce3Op() : LegacyOp::LegacyOp(2) { - // - } - - LegacyReduce3Op::LegacyReduce3Op(int opNum) : LegacyOp::LegacyOp(2, opNum) { - // - } - - LegacyOp* LegacyReduce3Op::clone() { - return new LegacyReduce3Op(this->_opNum); - } - - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - ShapeList *LegacyReduce3Op::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto xShape = inputShape->at(0); - auto yShape = inputShape->at(1); - - Nd4jLong *zShape = nullptr; - - if (shape::equalsSoft(xShape, yShape) && (block.numI() == 0 || (block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()))) { - // reduce3 to scalar case - ALLOCATE(zShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); - zShape[0] = 2; - zShape[1] = 1; - zShape[2] = 1; - zShape[3] = 1; - zShape[4] = 1; - zShape[5] = 0; - zShape[6] = 1; - zShape[7] = 99; - } else { - auto array = new NDArray(nullptr, xShape, block.launchContext()); - - xShape = ShapeUtils::evalReduceShapeInfo('c', block.getIArguments(), *array, false, true); - - delete array; - } - - return SHAPELIST(zShape); - } - } -} \ No newline at end of file +namespace ops { +Nd4jStatus LegacyReduce3Op::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x, y}); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + nd4j_debug("Executing LegacyReduce3Op: [%i]\n", opNum); + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyReduce3Op"); + + if (x->isSameShape(y) && + (block.numI() == 0 || + (block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()))) { + // reduce3 to scalar + NativeOpExecutioner::execReduce3Scalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + std::vector dims(block.getAxis()); + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += x->rankOf(); + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x->shapeInfo(), dims); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + z->shapeInfo(), dims); + + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions requuired for reduction!"); + + auto xTadShape = + Environment::getInstance()->isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + //manager.replicatePointer(tadX.tadOnlyShapeInfo, + //shape::shapeInfoByteLength(tadX.tadOnlyShapeInfo)); + auto xTadOffsets = + Environment::getInstance()->isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); //(Nd4jLong *) + //manager.replicatePointer(tadX.tadOffsets, + //tadX.numTads * sizeof(Nd4jLong)); + + auto yTadShape = + Environment::getInstance()->isCPU() + ? packZ.primaryShapeInfo() + : packZ + .specialOffsets(); //(Nd4jLong *) + //manager.replicatePointer(tadY.tadOnlyShapeInfo, + //shape::shapeInfoByteLength(tadY.tadOnlyShapeInfo)); + auto yTadOffsets = + Environment::getInstance()->isCPU() + ? packZ.primaryOffsets() + : packZ.specialOffsets(); //(Nd4jLong *) + //manager.replicatePointer(tadY.tadOffsets, + //tadY.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduce3( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), dims.size(), + xTadShape, xTadOffsets, yTadShape, yTadOffsets); + } + + manager.synchronize(); + STORE_RESULT(*z); + + return Status::OK(); +} + +LegacyReduce3Op::LegacyReduce3Op() : LegacyOp::LegacyOp(2) { + // +} + +LegacyReduce3Op::LegacyReduce3Op(int opNum) : LegacyOp::LegacyOp(2, opNum) { + // +} + +LegacyOp *LegacyReduce3Op::clone() { return new LegacyReduce3Op(this->_opNum); } + +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +ShapeList *LegacyReduce3Op::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto xShape = inputShape->at(0); + auto yShape = inputShape->at(1); + + Nd4jLong *zShape = nullptr; + + if (shape::equalsSoft(xShape, yShape) && + (block.numI() == 0 || + (block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()))) { + // reduce3 to scalar case + ALLOCATE(zShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); + zShape[0] = 2; + zShape[1] = 1; + zShape[2] = 1; + zShape[3] = 1; + zShape[4] = 1; + zShape[5] = 0; + zShape[6] = 1; + zShape[7] = 99; + } else { + auto array = new NDArray(nullptr, xShape, block.launchContext()); + + xShape = ShapeUtils::evalReduceShapeInfo('c', block.getIArguments(), *array, + false, true); + + delete array; + } + + return SHAPELIST(zShape); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp index 56c4a4917d6c..d215bf8da546 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp @@ -18,134 +18,175 @@ // Created by raver119 on 16.10.2017. // -#include -#include -#include +#include #include #include -#include +#include +#include +#include namespace sd { - namespace ops { - LegacyReduceBoolOp::LegacyReduceBoolOp() : LegacyOp::LegacyOp(1) { - // - } - - LegacyReduceBoolOp::LegacyReduceBoolOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - //this->_opNum = opNum; - } - - LegacyOp* LegacyReduceBoolOp::clone() { - return new LegacyReduceBoolOp(this->_opNum); - } - - Nd4jStatus LegacyReduceBoolOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x}); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - nd4j_debug("Executing LegacyReduceFloatOp: [%i]\n", opNum); - - auto axis = block.getAxis(); - - bool allAxes = false; - - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(),"LegacyReduceBoolOp"); - - if (block.width() == 1) { - if (axis.size() == x->rankOf()) - allAxes = true; - - if ((axis.empty()) || - (axis.size() == 1 && axis[0] == sd::DataTypeUtils::max()) || allAxes) { - // scalar - NativeOpExecutioner::execReduceBoolScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - std::vector dims(axis); - - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += x->rankOf(); - - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x->shapeInfo(), dims); - - auto pTadShape = Environment::getInstance()->isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execReduceBool(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), (int) dims.size(), reinterpret_cast(pTadShape), reinterpret_cast(pTadOffsets)); - } - - STORE_RESULT(*z); - } else { - auto indices = INPUT_VARIABLE(1); - if (indices->lengthOf() == x->rankOf()) - allAxes = true; - - //indices->printIndexedBuffer("indices"); - - std::vector dims(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { - // lol otherwise we segfault on macOS - int f = indices->e(e); - dims[e] = f >= 0 ? f : f += x->rankOf(); - } - - if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { - // scalar - NativeOpExecutioner::execReduceBoolScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - if (indices->lengthOf() > 1) - std::sort(dims.begin(), dims.end()); +namespace ops { +LegacyReduceBoolOp::LegacyReduceBoolOp() : LegacyOp::LegacyOp(1) { + // +} + +LegacyReduceBoolOp::LegacyReduceBoolOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // this->_opNum = opNum; +} + +LegacyOp *LegacyReduceBoolOp::clone() { + return new LegacyReduceBoolOp(this->_opNum); +} + +Nd4jStatus LegacyReduceBoolOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x}); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + nd4j_debug("Executing LegacyReduceFloatOp: [%i]\n", opNum); + + auto axis = block.getAxis(); + + bool allAxes = false; + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyReduceBoolOp"); + + if (block.width() == 1) { + if (axis.size() == x->rankOf()) allAxes = true; + + if ((axis.empty()) || + (axis.size() == 1 && axis[0] == sd::DataTypeUtils::max()) || + allAxes) { + // scalar + NativeOpExecutioner::execReduceBoolScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + std::vector dims(axis); + + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += x->rankOf(); + + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance()->isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance()->isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceBool( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), reinterpret_cast(pTadShape), + reinterpret_cast(pTadOffsets)); + } - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); + STORE_RESULT(*z); + } else { + auto indices = INPUT_VARIABLE(1); + if (indices->lengthOf() == x->rankOf()) allAxes = true; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x->shapeInfo(), dims); + // indices->printIndexedBuffer("indices"); - auto pTadShape = Environment::getInstance()->isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); + std::vector dims(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { + // lol otherwise we segfault on macOS + int f = indices->e(e); + dims[e] = f >= 0 ? f : f += x->rankOf(); + } - NativeOpExecutioner::execReduceBool(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(x->dataType()), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), dims.data(), (int) dims.size(), pTadShape, pTadOffsets); - } - } + if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || + allAxes) { + // scalar + NativeOpExecutioner::execReduceBoolScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + if (indices->lengthOf() > 1) std::sort(dims.begin(), dims.end()); + + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance()->isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOnlyShapeInfo, + //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance()->isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOffsets, + //tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceBool( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), pTadShape, pTadOffsets); + } + } - manager.synchronize(); - return Status::OK(); - } + manager.synchronize(); + return Status::OK(); +} - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - ShapeList *LegacyReduceBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +ShapeList *LegacyReduceBoolOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); - Nd4jLong *newShape; + Nd4jLong *newShape; - bool allAxes = false; + bool allAxes = false; - auto keepDims = block.numB() > 0 ? B_ARG(0) : false; - auto newFormat = block.numB() > 1 ? B_ARG(1) : true; + auto keepDims = block.numB() > 0 ? B_ARG(0) : false; + auto newFormat = block.numB() > 1 ? B_ARG(1) : true; - auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); + auto axis = + block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); - if (axis.size() == shape::rank(inShape)) - allAxes = true; + if (axis.size() == shape::rank(inShape)) allAxes = true; - // in this case we're building proper shape for reduction - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inShape), axis, inShape, DataType::BOOL, keepDims, !newFormat, block.workspace())); - } - } -} \ No newline at end of file + // in this case we're building proper shape for reduction + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + shape::order(inShape), axis, inShape, DataType::BOOL, keepDims, + !newFormat, block.workspace())); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp index 1cc2dc75e391..552cfcc049a2 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp @@ -18,135 +18,173 @@ // Created by raver119 on 16.10.2017. // -#include -#include -#include +#include #include #include -#include +#include +#include +#include namespace sd { - namespace ops { - LegacyReduceFloatOp::LegacyReduceFloatOp() : LegacyOp::LegacyOp(1) { - // - } - - LegacyReduceFloatOp::LegacyReduceFloatOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - //this->_opNum = opNum; - } - - LegacyOp* LegacyReduceFloatOp::clone() { - return new LegacyReduceFloatOp(this->_opNum); - } - - Nd4jStatus LegacyReduceFloatOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x}); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - nd4j_debug("Executing LegacyReduceFloatOp: [%i]\n", opNum); - - bool allAxes = false; - auto axis = block.getAxis(); - - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyReduceFloatOp"); - - if (block.width() == 1) { - - if (axis.size() == x->rankOf()) - allAxes = true; - - // _axis.(block.getIArguments()->size() == 0) || - // (block.getIArguments()->size() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) - if (block.getAxis().empty() || allAxes) { - // scalar - NativeOpExecutioner::execReduceFloatScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - std::vector dims(block.getAxis()); - - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += x->rankOf(); - - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x->shapeInfo(), dims); - - auto pTadShape = Environment::getInstance()->isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execReduceFloat(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), (int) dims.size(), reinterpret_cast(pTadShape), reinterpret_cast(pTadOffsets)); - - } - - STORE_RESULT(*z); - } else { - auto indices = INPUT_VARIABLE(1); - if (indices->lengthOf() == x->rankOf()) - allAxes = true; - - //indices->printIndexedBuffer("indices"); - - std::vector dims(indices->lengthOf()); - for (int e = 0; e < indices->lengthOf(); e++) { - // lol otherwise we segfault on macOS - int f = indices->e(e); - dims[e] = f >= 0 ? f : f += x->rankOf(); - } - - if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { - // scalar - NativeOpExecutioner::execReduceFloatScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x->shapeInfo(), dims); +namespace ops { +LegacyReduceFloatOp::LegacyReduceFloatOp() : LegacyOp::LegacyOp(1) { + // +} + +LegacyReduceFloatOp::LegacyReduceFloatOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // this->_opNum = opNum; +} + +LegacyOp *LegacyReduceFloatOp::clone() { + return new LegacyReduceFloatOp(this->_opNum); +} + +Nd4jStatus LegacyReduceFloatOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x}); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + nd4j_debug("Executing LegacyReduceFloatOp: [%i]\n", opNum); + + bool allAxes = false; + auto axis = block.getAxis(); + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyReduceFloatOp"); + + if (block.width() == 1) { + if (axis.size() == x->rankOf()) allAxes = true; + + // _axis.(block.getIArguments()->size() == 0) || + // (block.getIArguments()->size() == 1 && INT_ARG(0) == + // sd::DataTypeUtils::max()) + if (block.getAxis().empty() || allAxes) { + // scalar + NativeOpExecutioner::execReduceFloatScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + std::vector dims(block.getAxis()); + + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += x->rankOf(); + + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance()->isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance()->isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceFloat( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), reinterpret_cast(pTadShape), + reinterpret_cast(pTadOffsets)); + } - auto pTadShape = Environment::getInstance()->isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); + STORE_RESULT(*z); + } else { + auto indices = INPUT_VARIABLE(1); + if (indices->lengthOf() == x->rankOf()) allAxes = true; - NativeOpExecutioner::execReduceFloat(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), (int) dims.size(), pTadShape, pTadOffsets); + // indices->printIndexedBuffer("indices"); + std::vector dims(indices->lengthOf()); + for (int e = 0; e < indices->lengthOf(); e++) { + // lol otherwise we segfault on macOS + int f = indices->e(e); + dims[e] = f >= 0 ? f : f += x->rankOf(); + } - } - } + if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || + allAxes) { + // scalar + NativeOpExecutioner::execReduceFloatScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance()->isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOnlyShapeInfo, + //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance()->isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOffsets, + //tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceFloat( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), pTadShape, pTadOffsets); + } + } - manager.synchronize(); - return Status::OK(); - } + manager.synchronize(); + return Status::OK(); +} - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - ShapeList *LegacyReduceFloatOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +ShapeList *LegacyReduceFloatOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); - bool allAxes = false; + bool allAxes = false; - auto keepDims = block.numB() > 0 ? B_ARG(0) : false; - auto newFormat = block.numB() > 1 ? B_ARG(1) : true; + auto keepDims = block.numB() > 0 ? B_ARG(0) : false; + auto newFormat = block.numB() > 1 ? B_ARG(1) : true; - auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); + auto axis = + block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); - if (axis.size() == shape::rank(inShape)) - allAxes = true; + if (axis.size() == shape::rank(inShape)) allAxes = true; - // in this case we're building proper shape for reduction - auto newShape = ShapeUtils::evalReduceShapeInfo(shape::order(inShape), axis, inShape, keepDims, !newFormat, block.workspace()); + // in this case we're building proper shape for reduction + auto newShape = + ShapeUtils::evalReduceShapeInfo(shape::order(inShape), axis, inShape, + keepDims, !newFormat, block.workspace()); - return SHAPELIST(newShape); - } - } -} \ No newline at end of file + return SHAPELIST(newShape); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp index 47a903d658df..47743b822d5c 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp @@ -18,135 +18,175 @@ // Created by raver119 on 16.10.2017. // -#include -#include -#include +#include #include #include -#include +#include +#include +#include namespace sd { - namespace ops { - LegacyReduceLongOp::LegacyReduceLongOp() : LegacyOp::LegacyOp(1) { - // - } - - LegacyReduceLongOp::LegacyReduceLongOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - //this->_opNum = opNum; - } - - LegacyOp* LegacyReduceLongOp::clone() { - return new LegacyReduceLongOp(this->_opNum); - } - - Nd4jStatus LegacyReduceLongOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x}); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - nd4j_debug("Executing LegacyReduceFloatOp: [%i]\n", opNum); - - auto axis = block.getAxis(); - bool allAxes = false; - - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(),"LegacyReduceLongOp"); - - if (block.width() == 1) { - - if (axis.size() == x->rankOf()) - allAxes = true; - - if ((axis.empty()) || - (axis.size() == 1 && axis[0] == sd::DataTypeUtils::max()) || allAxes) { - // scalar - NativeOpExecutioner::execReduceLongScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - std::vector dims(axis); - - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += x->rankOf(); - - if (dims.size() > 1) - std::sort(dims.begin(), dims.end()); - - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x->shapeInfo(), dims); - - auto pTadShape = Environment::getInstance()->isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execReduceLong(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(x->dataType()), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), (int) dims.size(), pTadShape, pTadOffsets); - } - - STORE_RESULT(*z); - } else { - auto indices = INPUT_VARIABLE(1); - if (indices->lengthOf() == x->rankOf()) - allAxes = true; - - //indices->printIndexedBuffer("indices"); - - std::vector dims(indices->lengthOf()); - for (int e = 0; e < indices->lengthOf(); e++) { - // lol otherwise we segfault on macOS - int f = indices->e(e); - dims[e] = f >= 0 ? f : f += x->rankOf(); - } - - if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { - // scalar - NativeOpExecutioner::execReduceLongScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); +namespace ops { +LegacyReduceLongOp::LegacyReduceLongOp() : LegacyOp::LegacyOp(1) { + // +} + +LegacyReduceLongOp::LegacyReduceLongOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // this->_opNum = opNum; +} + +LegacyOp *LegacyReduceLongOp::clone() { + return new LegacyReduceLongOp(this->_opNum); +} + +Nd4jStatus LegacyReduceLongOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x}); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + nd4j_debug("Executing LegacyReduceFloatOp: [%i]\n", opNum); + + auto axis = block.getAxis(); + bool allAxes = false; + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyReduceLongOp"); + + if (block.width() == 1) { + if (axis.size() == x->rankOf()) allAxes = true; + + if ((axis.empty()) || + (axis.size() == 1 && axis[0] == sd::DataTypeUtils::max()) || + allAxes) { + // scalar + NativeOpExecutioner::execReduceLongScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + std::vector dims(axis); + + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += x->rankOf(); + + if (dims.size() > 1) std::sort(dims.begin(), dims.end()); + + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance()->isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOnlyShapeInfo, + //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance()->isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOffsets, + //tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceLong( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), pTadShape, pTadOffsets); + } - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x->shapeInfo(), dims); + STORE_RESULT(*z); + } else { + auto indices = INPUT_VARIABLE(1); + if (indices->lengthOf() == x->rankOf()) allAxes = true; - auto pTadShape = Environment::getInstance()->isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); + // indices->printIndexedBuffer("indices"); - NativeOpExecutioner::execReduceLong(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(x->dataType()), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), dims.data(), (int) dims.size(), pTadShape, pTadOffsets); + std::vector dims(indices->lengthOf()); + for (int e = 0; e < indices->lengthOf(); e++) { + // lol otherwise we segfault on macOS + int f = indices->e(e); + dims[e] = f >= 0 ? f : f += x->rankOf(); + } - } - } + if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || + allAxes) { + // scalar + NativeOpExecutioner::execReduceLongScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance()->isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOnlyShapeInfo, + //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance()->isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOffsets, + //tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceLong( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(x->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), pTadShape, pTadOffsets); + } + } - manager.synchronize(); - return Status::OK(); - } + manager.synchronize(); + return Status::OK(); +} - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - ShapeList *LegacyReduceLongOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +ShapeList *LegacyReduceLongOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); - Nd4jLong *newShape; + Nd4jLong *newShape; - bool allAxes = false; + bool allAxes = false; - auto keepDims = block.numB() > 0 ? B_ARG(0) : false; - auto newFormat = block.numB() > 1 ? B_ARG(1) : true; + auto keepDims = block.numB() > 0 ? B_ARG(0) : false; + auto newFormat = block.numB() > 1 ? B_ARG(1) : true; - auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); + auto axis = + block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); - if (axis.size() == shape::rank(inShape)) - allAxes = true; + if (axis.size() == shape::rank(inShape)) allAxes = true; - // in this case we're building proper shape for reduction - return SHAPELIST(ShapeUtils::evalReduceShapeInfo(shape::order(inShape), axis, inShape, DataType::INT64, keepDims, !newFormat, block.workspace())); - } - } -} \ No newline at end of file + // in this case we're building proper shape for reduction + return SHAPELIST(ShapeUtils::evalReduceShapeInfo( + shape::order(inShape), axis, inShape, DataType::INT64, keepDims, + !newFormat, block.workspace())); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceOp.cpp index f91af666b792..363dbc294847 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceOp.cpp @@ -18,170 +18,186 @@ // Created by raver119 on 16.10.2017. // -#include -#include #include +#include +#include #ifdef LEGACY_REDUCE_SAME_ONLY namespace sd { - namespace ops { - LegacyReduceOp::LegacyReduceOp() : LegacyOp::LegacyOp(1) { - // - } - - LegacyReduceOp::LegacyReduceOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - //this->_opNum = opNum; - } - - LegacyOp* LegacyReduceOp::clone() { - return new LegacyReduceOp(this->_opNum); - } - - Nd4jStatus LegacyReduceOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - nd4j_debug("Executing LegacyReduceOp: [%i]\n", opNum); - - bool allAxes = false; - - if (block.width() == 1) { - auto z = OUTPUT_VARIABLE(0); +namespace ops { +LegacyReduceOp::LegacyReduceOp() : LegacyOp::LegacyOp(1) { + // +} - if (block.getIArguments()->size() == x->rankOf()) - allAxes = true; +LegacyReduceOp::LegacyReduceOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { + // this->_opNum = opNum; +} - if ((block.getIArguments()->size() == 0) || - (block.getIArguments()->size() == 1 && INT_ARG(0) == MAX_INT) || allAxes) { - // scalar - NativeOpExcutioner::execReduceFloatScalar(opNum, x->buffer(), x->shapeInfo(), block.getTArguments()->data(), z->buffer(), z->shapeInfo()); - } else { - // TAD - std::vector dims(*block.getIArguments()); +LegacyOp *LegacyReduceOp::clone() { return new LegacyReduceOp(this->_opNum); } - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += x->rankOf(); +Nd4jStatus LegacyReduceOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); - std::sort(dims.begin(), dims.end()); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + nd4j_debug("Executing LegacyReduceOp: [%i]\n", opNum); - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); + bool allAxes = false; - shape::TAD tad(x->shapeInfo(), dims.data(), dims.size()); - tad.createTadOnlyShapeInfo(); - tad.createOffsets(); + if (block.width() == 1) { + auto z = OUTPUT_VARIABLE(0); - NativeOpExcutioner::execReduceFloat(opNum, x->buffer(), x->shapeInfo(), block.getTArguments()->data(), z->buffer(), z->shapeInfo(), dims.data(), (int) dims.size(), tad.tadOnlyShapeInfo, tad.tadOffsets); - } + if (block.getIArguments()->size() == x->rankOf()) allAxes = true; - STORE_RESULT(*z); - } else { - auto indices = INPUT_VARIABLE(1); - if (indices->lengthOf() == x->rankOf()) - allAxes = true; + if ((block.getIArguments()->size() == 0) || + (block.getIArguments()->size() == 1 && INT_ARG(0) == MAX_INT) || + allAxes) { + // scalar + NativeOpExcutioner::execReduceFloatScalar( + opNum, x->buffer(), x->shapeInfo(), block.getTArguments()->data(), + z->buffer(), z->shapeInfo()); + } else { + // TAD + std::vector dims(*block.getIArguments()); - //indices->printIndexedBuffer("indices"); + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += x->rankOf(); - std::vector axis(indices->lengthOf()); - for (int e = 0; e < indices->lengthOf(); e++) { - // lol otherwise we segfault on macOS - int f = indices->e(e); - axis[e] = f >= 0 ? f : f += x->rankOf(); - } + std::sort(dims.begin(), dims.end()); - if ((block.getIArguments()->size() == 1 && INT_ARG(0) == MAX_INT) || allAxes) { - auto z = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); - auto b = x->buffer(); - auto s = x->shapeInfo(); - auto e = block.numT() > 0 ? block.getTArguments()->data() : nullptr; + shape::TAD tad(x->shapeInfo(), dims.data(), dims.size()); + tad.createTadOnlyShapeInfo(); + tad.createOffsets(); - //x->printIndexedBuffer("x"); + NativeOpExcutioner::execReduceFloat( + opNum, x->buffer(), x->shapeInfo(), block.getTArguments()->data(), + z->buffer(), z->shapeInfo(), dims.data(), (int)dims.size(), + tad.tadOnlyShapeInfo, tad.tadOffsets); + } - // scalar - NativeOpExcutioner::execReduceFloatScalar(opNum, b, s, e, z->buffer(), z->shapeInfo()); - } else { - // TAD - if (indices->lengthOf() > 1) - std::sort(axis.begin(), axis.end()); + STORE_RESULT(*z); + } else { + auto indices = INPUT_VARIABLE(1); + if (indices->lengthOf() == x->rankOf()) allAxes = true; - REQUIRE_TRUE(axis.size() > 0, 0, "Some dimensions required for reduction!"); + // indices->printIndexedBuffer("indices"); - shape::TAD tad(x->shapeInfo(), axis.data(), axis.size()); - tad.createTadOnlyShapeInfo(); - tad.createOffsets(); + std::vector axis(indices->lengthOf()); + for (int e = 0; e < indices->lengthOf(); e++) { + // lol otherwise we segfault on macOS + int f = indices->e(e); + axis[e] = f >= 0 ? f : f += x->rankOf(); + } - auto newShape = ShapeUtils::evalReduceShapeInfo(x->ordering(), axis, *x); - auto z = new NDArray(newShape, x->getWorkspace()); + if ((block.getIArguments()->size() == 1 && INT_ARG(0) == MAX_INT) || + allAxes) { + auto z = OUTPUT_VARIABLE(0); + + auto b = x->buffer(); + auto s = x->shapeInfo(); + auto e = block.numT() > 0 ? block.getTArguments()->data() : nullptr; + + // x->printIndexedBuffer("x"); + + // scalar + NativeOpExcutioner::execReduceFloatScalar(opNum, b, s, e, z->buffer(), + z->shapeInfo()); + } else { + // TAD + if (indices->lengthOf() > 1) std::sort(axis.begin(), axis.end()); + + REQUIRE_TRUE(axis.size() > 0, 0, + "Some dimensions required for reduction!"); + + shape::TAD tad(x->shapeInfo(), axis.data(), axis.size()); + tad.createTadOnlyShapeInfo(); + tad.createOffsets(); + + auto newShape = ShapeUtils::evalReduceShapeInfo(x->ordering(), axis, *x); + auto z = new NDArray(newShape, x->getWorkspace()); + + NativeOpExcutioner::execReduceFloat( + opNum, x->buffer(), x->shapeInfo(), block.getTArguments()->data(), + z->buffer(), z->shapeInfo(), axis.data(), (int)axis.size(), + tad.tadOnlyShapeInfo, tad.tadOffsets); + + // keepDims processing, for TF compatibility + if (block.getIArguments()->size() > 0 && + block.getIArguments()->at(0) == 1) { + // z->printShapeInfo("z shape before"); + std::vector newshape(z->getShapeAsVector()); + for (int e = 0; e < axis.size(); e++) { + auto a = axis.at(e); + newshape.insert(newshape.begin() + a, 1); + } + z->reshapei(z->ordering(), newshape); + // z->printShapeInfo("z shape after"); + } - NativeOpExcutioner::execReduceFloat(opNum, x->buffer(), x->shapeInfo(), block.getTArguments()->data(), z->buffer(), z->shapeInfo(), axis.data(), (int) axis.size(), tad.tadOnlyShapeInfo, tad.tadOffsets); + OVERWRITE_RESULT(z); + } + } + return ND4J_STATUS_OK; +} - // keepDims processing, for TF compatibility - if (block.getIArguments()->size() > 0 && block.getIArguments()->at(0) == 1) { - // z->printShapeInfo("z shape before"); - std::vector newshape(z->getShapeAsVector()); - for (int e = 0; e < axis.size(); e++) { - auto a = axis.at(e); - newshape.insert(newshape.begin() + a, 1); - } - z->reshapei(z->ordering(), newshape); - // z->printShapeInfo("z shape after"); - } +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +ShapeList *LegacyReduceOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + + bool allAxes = false; + + if (block.getIArguments()->size() == shape::rank(inShape)) allAxes = true; + + if (block.getIArguments()->size() == 0 || + (block.getIArguments()->size() == 1 && INT_ARG(0) == MAX_INT) || + allAxes) { + if (block.getIArguments()->size() > 0 && + block.getIArguments()->at(0) == 1) { + // in this case we just return legacy scalar + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), + Nd4jLong); + newShape[0] = 2; + newShape[1] = 1; + newShape[2] = 1; + newShape[3] = 1; + newShape[4] = 1; + newShape[5] = 0; + newShape[6] = 1; + newShape[7] = 99; + // ArrayOptions::setDataType(newShape, block.dataType() == + // DataType::BOOL?block.dataType():ArrayOptions::dataType(inShape)); + } else { + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(0), + Nd4jLong); + newShape[0] = 0; + newShape[1] = 0; + newShape[2] = 1; + newShape[3] = 99; + // ArrayOptions::setDataType(newShape, block.dataType() == + // DataType::BOOL?block.dataType():ArrayOptions::dataType(inShape)); + } + } else { + // in this case we're building proper shape for reduction + auto array = new NDArray(nullptr, inShape, block.workspace()); - OVERWRITE_RESULT(z); - } - } + newShape = ShapeUtils::evalReduceShapeInfo(shape::order(inShape), + *block.getIArguments(), *array, + false, false, block.workspace()); - return ND4J_STATUS_OK; - } + delete array; + } - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - ShapeList *LegacyReduceOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - - bool allAxes = false; - - if (block.getIArguments()->size() == shape::rank(inShape)) - allAxes = true; - - if (block.getIArguments()->size() == 0 || (block.getIArguments()->size() == 1 && INT_ARG(0) == MAX_INT) || allAxes) { - if (block.getIArguments()->size() > 0 && block.getIArguments()->at(0) == 1) { - // in this case we just return legacy scalar - ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); - newShape[0] = 2; - newShape[1] = 1; - newShape[2] = 1; - newShape[3] = 1; - newShape[4] = 1; - newShape[5] = 0; - newShape[6] = 1; - newShape[7] = 99; - //ArrayOptions::setDataType(newShape, block.dataType() == DataType::BOOL?block.dataType():ArrayOptions::dataType(inShape)); - } else { - ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(0), Nd4jLong); - newShape[0] = 0; - newShape[1] = 0; - newShape[2] = 1; - newShape[3] = 99; - //ArrayOptions::setDataType(newShape, block.dataType() == DataType::BOOL?block.dataType():ArrayOptions::dataType(inShape)); - } - } else { - // in this case we're building proper shape for reduction - auto array = new NDArray(nullptr, inShape, block.workspace()); - - newShape = ShapeUtils::evalReduceShapeInfo(shape::order(inShape), *block.getIArguments(), *array, false, false, block.workspace()); - - delete array; - } - - return SHAPELIST(newShape); - } - } + return SHAPELIST(newShape); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp index 1f0cb1c9b704..da1a60e3a606 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp @@ -18,131 +18,172 @@ // Created by raver119 on 16.10.2017. // -#include -#include -#include +#include #include #include -#include +#include +#include +#include namespace sd { - namespace ops { - LegacyReduceSameOp::LegacyReduceSameOp() : LegacyOp::LegacyOp(1) { - // - } - - LegacyReduceSameOp::LegacyReduceSameOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - //this->_opNum = opNum; - } - - LegacyOp* LegacyReduceSameOp::clone() { - return new LegacyReduceSameOp(this->_opNum); - } - - Nd4jStatus LegacyReduceSameOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x}); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - nd4j_debug("Executing LegacyReduceSameOp: [%i]\n", opNum); - - auto axis = block.getAxis(); - bool allAxes = false; - - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyReduceSameOp"); - - if (block.width() == 1) { - if (axis.size() == x->rankOf()) - allAxes = true; - - if (axis.empty() || allAxes) { - // scalar - NativeOpExecutioner::execReduceSameScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - std::vector dims(axis); - - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += x->rankOf(); - - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x->shapeInfo(), dims); - - auto pTadShape = Environment::getInstance()->isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execReduceSame(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), - z->buffer(), z->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - dims.data(), (int) dims.size(), pTadShape, pTadOffsets); - } - - STORE_RESULT(*z); - } else { - auto indices = INPUT_VARIABLE(1); - if (indices->lengthOf() == x->rankOf()) - allAxes = true; - - //indices->printIndexedBuffer("indices"); - - std::vector dims(indices->lengthOf()); - for (int e = 0; e < indices->lengthOf(); e++) { - // lol otherwise we segfault on macOS - int f = indices->e(e); - dims[e] = f >= 0 ? f : f += x->rankOf(); - } +namespace ops { +LegacyReduceSameOp::LegacyReduceSameOp() : LegacyOp::LegacyOp(1) { + // +} + +LegacyReduceSameOp::LegacyReduceSameOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // this->_opNum = opNum; +} + +LegacyOp *LegacyReduceSameOp::clone() { + return new LegacyReduceSameOp(this->_opNum); +} + +Nd4jStatus LegacyReduceSameOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x}); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + nd4j_debug("Executing LegacyReduceSameOp: [%i]\n", opNum); + + auto axis = block.getAxis(); + bool allAxes = false; + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyReduceSameOp"); + + if (block.width() == 1) { + if (axis.size() == x->rankOf()) allAxes = true; + + if (axis.empty() || allAxes) { + // scalar + NativeOpExecutioner::execReduceSameScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + std::vector dims(axis); + + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += x->rankOf(); + + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance()->isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOnlyShapeInfo, + //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance()->isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOffsets, + //tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceSame( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), dims.data(), + (int)dims.size(), pTadShape, pTadOffsets); + } - if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || allAxes) { - // scalar - NativeOpExecutioner::execReduceSameScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo()); - } else { - // TAD - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions required for reduction!"); + STORE_RESULT(*z); + } else { + auto indices = INPUT_VARIABLE(1); + if (indices->lengthOf() == x->rankOf()) allAxes = true; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x->shapeInfo(), dims); + // indices->printIndexedBuffer("indices"); - auto pTadShape = Environment::getInstance()->isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); + std::vector dims(indices->lengthOf()); + for (int e = 0; e < indices->lengthOf(); e++) { + // lol otherwise we segfault on macOS + int f = indices->e(e); + dims[e] = f >= 0 ? f : f += x->rankOf(); + } - NativeOpExecutioner::execReduceSame(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - dims.data(), (int) dims.size(), pTadShape, pTadOffsets); - } - } + if ((block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max()) || + allAxes) { + // scalar + NativeOpExecutioner::execReduceSameScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo()); + } else { + // TAD + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions required for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance()->isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOnlyShapeInfo, + //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance()->isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOffsets, + //tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execReduceSame( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), pTadShape, pTadOffsets); + } + } - manager.synchronize(); + manager.synchronize(); - return Status::OK(); - } + return Status::OK(); +} - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - ShapeList *LegacyReduceSameOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +ShapeList *LegacyReduceSameOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); - bool allAxes = false; + bool allAxes = false; - auto keepDims = block.numB() > 0 ? B_ARG(0) : false; - auto newFormat = block.numB() > 1 ? B_ARG(1) : true; + auto keepDims = block.numB() > 0 ? B_ARG(0) : false; + auto newFormat = block.numB() > 1 ? B_ARG(1) : true; - auto axis = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); + auto axis = + block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : block.getAxis(); - if (axis.size() == shape::rank(inShape)) - allAxes = true; + if (axis.size() == shape::rank(inShape)) allAxes = true; - // in this case we're building proper shape for reduction - auto newShape = ShapeUtils::evalReduceShapeInfo(shape::order(inShape), axis, inShape, keepDims, !newFormat, block.workspace()); + // in this case we're building proper shape for reduction + auto newShape = + ShapeUtils::evalReduceShapeInfo(shape::order(inShape), axis, inShape, + keepDims, !newFormat, block.workspace()); - return SHAPELIST(newShape); - } - } -} \ No newline at end of file + return SHAPELIST(newShape); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp index 1c38f7b058bb..29c2efa1ab6d 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp @@ -18,70 +18,87 @@ // Created by raver119 on 16.10.2017. // -#include #include #include - +#include namespace sd { - namespace ops { - LegacyScalarBoolOp::LegacyScalarBoolOp() : LegacyOp::LegacyOp(1) { - // no-op - } - - LegacyScalarBoolOp::LegacyScalarBoolOp(int opNum) : LegacyOp::LegacyOp(1, opNum){ - // no-op - } - - LegacyOp* LegacyScalarBoolOp::clone() { - return new LegacyScalarBoolOp(this->_opNum, *this->_scalar); - } - - LegacyScalarBoolOp::LegacyScalarBoolOp(int opNum, NDArray &scalar) : LegacyOp::LegacyOp(1, opNum){ - _scalar = new NDArray(scalar.dup(scalar.ordering())); - } - - ShapeList *LegacyScalarBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - - Nd4jStatus LegacyScalarBoolOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyScalarBoolOp"); - - if (block.width() > 1) { - auto y = INPUT_VARIABLE(1); - - NDArray::prepareSpecialUse({z}, {x, y}); - - NativeOpExecutioner::execScalarBool(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), extras.argumentsAsT(x->dataType())); - } else if (block.numT() > 0) { - auto y = NDArrayFactory::create(T_ARG(0), block.launchContext()); - - NDArray::prepareSpecialUse({z}, {x, &y}); - - NativeOpExecutioner::execScalarBool(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(),z->buffer(), z->shapeInfo(),z->specialBuffer(), z->specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), extras.argumentsAsT(x->dataType(), 1)); - - manager.synchronize(); - } else { - NDArray::prepareSpecialUse({z}, {x, _scalar}); - - NativeOpExecutioner::execScalarBool(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(),z->buffer(), z->shapeInfo(),z->specialBuffer(), z->specialShapeInfo(), _scalar->buffer(), _scalar->shapeInfo(), _scalar->specialBuffer(), _scalar->specialShapeInfo(), extras.argumentsAsT(x->dataType())); - } - manager.synchronize(); - STORE_RESULT(*z); - - return Status::OK(); - } - } -} \ No newline at end of file +namespace ops { +LegacyScalarBoolOp::LegacyScalarBoolOp() : LegacyOp::LegacyOp(1) { + // no-op +} + +LegacyScalarBoolOp::LegacyScalarBoolOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // no-op +} + +LegacyOp *LegacyScalarBoolOp::clone() { + return new LegacyScalarBoolOp(this->_opNum, *this->_scalar); +} + +LegacyScalarBoolOp::LegacyScalarBoolOp(int opNum, NDArray &scalar) + : LegacyOp::LegacyOp(1, opNum) { + _scalar = new NDArray(scalar.dup(scalar.ordering())); +} + +ShapeList *LegacyScalarBoolOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} + +Nd4jStatus LegacyScalarBoolOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyScalarBoolOp"); + + if (block.width() > 1) { + auto y = INPUT_VARIABLE(1); + + NDArray::prepareSpecialUse({z}, {x, y}); + + NativeOpExecutioner::execScalarBool( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), + extras.argumentsAsT(x->dataType())); + } else if (block.numT() > 0) { + auto y = NDArrayFactory::create(T_ARG(0), block.launchContext()); + + NDArray::prepareSpecialUse({z}, {x, &y}); + + NativeOpExecutioner::execScalarBool( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), y.buffer(), y.shapeInfo(), + y.specialBuffer(), y.specialShapeInfo(), + extras.argumentsAsT(x->dataType(), 1)); + + manager.synchronize(); + } else { + NDArray::prepareSpecialUse({z}, {x, _scalar}); + + NativeOpExecutioner::execScalarBool( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), _scalar->buffer(), + _scalar->shapeInfo(), _scalar->specialBuffer(), + _scalar->specialShapeInfo(), extras.argumentsAsT(x->dataType())); + } + manager.synchronize(); + STORE_RESULT(*z); + + return Status::OK(); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp index 454b94db4849..ec15186aaf87 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyScalarOp.cpp @@ -18,72 +18,88 @@ // Created by raver119 on 16.10.2017. // -#include #include #include - +#include namespace sd { - namespace ops { - LegacyScalarOp::LegacyScalarOp() : LegacyOp::LegacyOp(1) { - this->getOpDescriptor()->allowInplace(true); - } - - LegacyScalarOp::LegacyScalarOp(int opNum) : LegacyOp::LegacyOp(1, opNum){ - this->getOpDescriptor()->allowInplace(true); - } - - LegacyOp* LegacyScalarOp::clone() { - return new LegacyScalarOp(this->_opNum, *this->_scalar); - } - - LegacyScalarOp::LegacyScalarOp(int opNum, NDArray &scalar) : LegacyOp::LegacyOp(1, opNum){ - _scalar = new NDArray(scalar.dup(scalar.ordering())); - } - - ShapeList *LegacyScalarOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - - Nd4jStatus LegacyScalarOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyScalarOp"); - - if (block.width() > 1) { - auto y = INPUT_VARIABLE(1); - - NDArray::prepareSpecialUse({z}, {x, y}); - - NativeOpExecutioner::execScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), y->buffer(), y->shapeInfo(), y->specialBuffer(), y->specialShapeInfo(), extras.argumentsAsT(z->dataType())); - - NDArray::registerSpecialUse({z}, {x, y}); - } else if (block.numT() > 0) { - auto y = NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext()); - - x->applyScalarArr(static_cast(opNum), y, *z); - // NDArray::prepareSpecialUse({z}, {x, &y}); - // NativeOpExecutioner::execScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), extras.argumentsAsT(z->dataType(), 1)); - - manager.synchronize(); - } else { - NDArray::prepareSpecialUse({z}, {x, _scalar}); - - NativeOpExecutioner::execScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), _scalar->buffer(), _scalar->shapeInfo(), _scalar->specialBuffer(), _scalar->specialShapeInfo(), extras.argumentsAsT(z->dataType())); - - NDArray::registerSpecialUse({z}, {x, _scalar}); - } - - return Status::OK(); - } - } -} \ No newline at end of file +namespace ops { +LegacyScalarOp::LegacyScalarOp() : LegacyOp::LegacyOp(1) { + this->getOpDescriptor()->allowInplace(true); +} + +LegacyScalarOp::LegacyScalarOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { + this->getOpDescriptor()->allowInplace(true); +} + +LegacyOp *LegacyScalarOp::clone() { + return new LegacyScalarOp(this->_opNum, *this->_scalar); +} + +LegacyScalarOp::LegacyScalarOp(int opNum, NDArray &scalar) + : LegacyOp::LegacyOp(1, opNum) { + _scalar = new NDArray(scalar.dup(scalar.ordering())); +} + +ShapeList *LegacyScalarOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} + +Nd4jStatus LegacyScalarOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyScalarOp"); + + if (block.width() > 1) { + auto y = INPUT_VARIABLE(1); + + NDArray::prepareSpecialUse({z}, {x, y}); + + NativeOpExecutioner::execScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), y->buffer(), y->shapeInfo(), + y->specialBuffer(), y->specialShapeInfo(), + extras.argumentsAsT(z->dataType())); + + NDArray::registerSpecialUse({z}, {x, y}); + } else if (block.numT() > 0) { + auto y = + NDArrayFactory::create(x->dataType(), T_ARG(0), block.launchContext()); + + x->applyScalarArr(static_cast(opNum), y, *z); + // NDArray::prepareSpecialUse({z}, {x, &y}); + // NativeOpExecutioner::execScalar(block.launchContext(), opNum, + // x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), + // z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + // y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + // extras.argumentsAsT(z->dataType(), 1)); + + manager.synchronize(); + } else { + NDArray::prepareSpecialUse({z}, {x, _scalar}); + + NativeOpExecutioner::execScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), _scalar->buffer(), + _scalar->shapeInfo(), _scalar->specialBuffer(), + _scalar->specialShapeInfo(), extras.argumentsAsT(z->dataType())); + + NDArray::registerSpecialUse({z}, {x, _scalar}); + } + + return Status::OK(); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp index 3f95a1fbfe1c..9acce2851924 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp @@ -18,103 +18,125 @@ // Created by raver119 on 17.10.2017. // -#include +#include +#include #include #include -#include -#include - +#include namespace sd { - namespace ops { - Nd4jStatus LegacyStatsOp::validateAndExecute(Context &block) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - NDArray::prepareSpecialUse({z}, {x}); - - // we assume that opNuk is either stored in block, or was provided via op constructor - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - - // bias goes as first argument, unlike all other reductions - bool biasCorrected = false; - if (block.numI() > 0) - biasCorrected = INT_ARG(0) > 0; - - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(),"LegacyStatsOp"); - - if (block.numI() == 1 || (block.numI() == 2 && INT_ARG(1) == sd::DataTypeUtils::max())) { - // scalar - NativeOpExecutioner::execSummaryStatsScalar(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), - extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), biasCorrected); - } else { - // dimensions for TAD - // we should skip first argument here, because it's addressing bias correction - std::vector dims(block.getIArguments()); - for (int e = 0; e < dims.size(); e++) - if (dims[e] < 0) - dims[e] += x->rankOf(); - - REQUIRE_TRUE(dims.size() > 0, 0, "Some dimensions requuired for reduction!"); - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x->shapeInfo(), dims); - - auto pTadShape = Environment::getInstance()->isCPU() ? packX.primaryShapeInfo() : packX.specialShapeInfo(); //(Nd4jLong *) manager.replicatePointer(tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) manager.replicatePointer(tad.tadOffsets, tad.numTads * sizeof(Nd4jLong)); - - NativeOpExecutioner::execSummaryStats(block.launchContext(), opNum, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), extras.argumentsAsT(z->dataType()), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), dims.data(), (int) dims.size(), pTadShape, pTadOffsets, biasCorrected); - } - - manager.synchronize(); - STORE_RESULT(*z); - - return Status::OK(); - } - - LegacyStatsOp::LegacyStatsOp() : LegacyOp::LegacyOp(1) { - // - } - - LegacyStatsOp::LegacyStatsOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - // - } - - LegacyOp* LegacyStatsOp::clone() { - return new LegacyStatsOp(this->_opNum); - } - - /** - * For all reductions rules are simple: either you return scalar, or you return reduced NDArray. - * It solely depends on input shape, and requested dimensions - */ - ShapeList *LegacyStatsOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - if (block.numI() == 0 || (block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max())) { - // in this case we just return scalar - ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); - newShape[0] = 2; - newShape[1] = 1; - newShape[2] = 1; - newShape[3] = 1; - newShape[4] = 1; - newShape[5] = 0; - newShape[6] = 1; - newShape[7] = 99; - } else { - // in this case we're building proper shape for reduction - auto array = new NDArray(nullptr, inShape, block.launchContext()); - - auto newShape = ShapeUtils::evalReduceShapeInfo('c', block.getIArguments(), *array, false, true); - - delete array; - return SHAPELIST(newShape); - } - - return SHAPELIST(CONSTANT(newShape)); - } - } -} \ No newline at end of file +namespace ops { +Nd4jStatus LegacyStatsOp::validateAndExecute(Context &block) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + NDArray::prepareSpecialUse({z}, {x}); + + // we assume that opNuk is either stored in block, or was provided via op + // constructor + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + + // bias goes as first argument, unlike all other reductions + bool biasCorrected = false; + if (block.numI() > 0) biasCorrected = INT_ARG(0) > 0; + + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyStatsOp"); + + if (block.numI() == 1 || + (block.numI() == 2 && INT_ARG(1) == sd::DataTypeUtils::max())) { + // scalar + NativeOpExecutioner::execSummaryStatsScalar( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), biasCorrected); + } else { + // dimensions for TAD + // we should skip first argument here, because it's addressing bias + // correction + std::vector dims(block.getIArguments()); + for (int e = 0; e < dims.size(); e++) + if (dims[e] < 0) dims[e] += x->rankOf(); + + REQUIRE_TRUE(dims.size() > 0, 0, + "Some dimensions requuired for reduction!"); + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x->shapeInfo(), dims); + + auto pTadShape = + Environment::getInstance()->isCPU() + ? packX.primaryShapeInfo() + : packX + .specialShapeInfo(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOnlyShapeInfo, + //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + auto pTadOffsets = + Environment::getInstance()->isCPU() + ? packX.primaryOffsets() + : packX.specialOffsets(); //(Nd4jLong *) + //manager.replicatePointer(tad.tadOffsets, + //tad.numTads * sizeof(Nd4jLong)); + + NativeOpExecutioner::execSummaryStats( + block.launchContext(), opNum, x->buffer(), x->shapeInfo(), + x->specialBuffer(), x->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), z->buffer(), z->shapeInfo(), + z->specialBuffer(), z->specialShapeInfo(), dims.data(), + (int)dims.size(), pTadShape, pTadOffsets, biasCorrected); + } + + manager.synchronize(); + STORE_RESULT(*z); + + return Status::OK(); +} + +LegacyStatsOp::LegacyStatsOp() : LegacyOp::LegacyOp(1) { + // +} + +LegacyStatsOp::LegacyStatsOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { + // +} + +LegacyOp *LegacyStatsOp::clone() { return new LegacyStatsOp(this->_opNum); } + +/** + * For all reductions rules are simple: either you return scalar, or you + * return reduced NDArray. It solely depends on input shape, and requested + * dimensions + */ +ShapeList *LegacyStatsOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + if (block.numI() == 0 || + (block.numI() == 1 && INT_ARG(0) == sd::DataTypeUtils::max())) { + // in this case we just return scalar + ALLOCATE(newShape, block.workspace(), shape::shapeInfoLength(2), Nd4jLong); + newShape[0] = 2; + newShape[1] = 1; + newShape[2] = 1; + newShape[3] = 1; + newShape[4] = 1; + newShape[5] = 0; + newShape[6] = 1; + newShape[7] = 99; + } else { + // in this case we're building proper shape for reduction + auto array = new NDArray(nullptr, inShape, block.launchContext()); + + auto newShape = ShapeUtils::evalReduceShapeInfo('c', block.getIArguments(), + *array, false, true); + + delete array; + return SHAPELIST(newShape); + } + + return SHAPELIST(CONSTANT(newShape)); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformAnyOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformAnyOp.cpp index faa8836b2d2a..b8edd4eac4f5 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformAnyOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformAnyOp.cpp @@ -18,57 +18,61 @@ // Created by raver119 on 16.10.2017. // -#include - #include - +#include namespace sd { - namespace ops { - LegacyTransformAnyOp::LegacyTransformAnyOp() : LegacyOp::LegacyOp(1) { - // just a no-op - } +namespace ops { +LegacyTransformAnyOp::LegacyTransformAnyOp() : LegacyOp::LegacyOp(1) { + // just a no-op +} - LegacyTransformAnyOp::LegacyTransformAnyOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - // just a no-op - } +LegacyTransformAnyOp::LegacyTransformAnyOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // just a no-op +} - LegacyOp* LegacyTransformAnyOp::clone() { - return new LegacyTransformAnyOp(this->_opNum); - } +LegacyOp *LegacyTransformAnyOp::clone() { + return new LegacyTransformAnyOp(this->_opNum); +} - Nd4jStatus LegacyTransformAnyOp::validateAndExecute(Context &block) { - auto input = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +Nd4jStatus LegacyTransformAnyOp::validateAndExecute(Context &block) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - NDArray::prepareSpecialUse({z}, {input}); + NDArray::prepareSpecialUse({z}, {input}); - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(),"LegacyTransformAnyOp"); + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyTransformAnyOp"); - NativeOpExecutioner::execTransformAny(block.launchContext(), opNum, input->buffer(), input->shapeInfo(), input->specialBuffer(), input->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), extras.argumentsAsT(z->dataType()), nullptr, nullptr); + NativeOpExecutioner::execTransformAny( + block.launchContext(), opNum, input->buffer(), input->shapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), nullptr, nullptr); - manager.synchronize(); - STORE_RESULT(*z); + manager.synchronize(); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - /** - * For transform operations, output shape always equals to input shape. With just a few exclusions, like im2col and col2im. - * But these ops already have CustomOp implementations. - * - */ - ShapeList *LegacyTransformAnyOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - } -} \ No newline at end of file +/** + * For transform operations, output shape always equals to input shape. With + * just a few exclusions, like im2col and col2im. But these ops already have + * CustomOp implementations. + * + */ +ShapeList *LegacyTransformAnyOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformBoolOp.cpp index ca2c53b93644..fff75cdd8d61 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformBoolOp.cpp @@ -18,54 +18,58 @@ // Created by raver119 on 16.10.2017. // -#include - #include - +#include namespace sd { - namespace ops { - LegacyTransformBoolOp::LegacyTransformBoolOp() : LegacyOp::LegacyOp(1) { - // just a no-op - } +namespace ops { +LegacyTransformBoolOp::LegacyTransformBoolOp() : LegacyOp::LegacyOp(1) { + // just a no-op +} - LegacyTransformBoolOp::LegacyTransformBoolOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - // just a no-op - } +LegacyTransformBoolOp::LegacyTransformBoolOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // just a no-op +} - LegacyOp* LegacyTransformBoolOp::clone() { - return new LegacyTransformBoolOp(this->_opNum); - } +LegacyOp *LegacyTransformBoolOp::clone() { + return new LegacyTransformBoolOp(this->_opNum); +} - Nd4jStatus LegacyTransformBoolOp::validateAndExecute(Context &block) { - auto input = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +Nd4jStatus LegacyTransformBoolOp::validateAndExecute(Context &block) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - NDArray::prepareSpecialUse({z}, {input}); + NDArray::prepareSpecialUse({z}, {input}); - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(),"LegacyTransformBoolOp"); + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyTransformBoolOp"); - NativeOpExecutioner::execTransformBool(block.launchContext(), opNum, input->buffer(), input->shapeInfo(), input->specialBuffer(), input->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), - extras.argumentsAsT(input->dataType()), nullptr, nullptr); + NativeOpExecutioner::execTransformBool( + block.launchContext(), opNum, input->buffer(), input->shapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + extras.argumentsAsT(input->dataType()), nullptr, nullptr); - manager.synchronize(); - STORE_RESULT(*z); + manager.synchronize(); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - /** - * For transform operations, output shape always equals to input shape. With just a few exclusions, like im2col and col2im. - * But these ops already have CustomOp implementations. - * - */ - ShapeList *LegacyTransformBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inShape, DataType::BOOL))); - } - } -} \ No newline at end of file +/** + * For transform operations, output shape always equals to input shape. With + * just a few exclusions, like im2col and col2im. But these ops already have + * CustomOp implementations. + * + */ +ShapeList *LegacyTransformBoolOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo( + ShapeDescriptor(inShape, DataType::BOOL))); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformFloatOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformFloatOp.cpp index 58d0f0579d6c..5896f96f51a5 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformFloatOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformFloatOp.cpp @@ -18,57 +18,61 @@ // Created by raver119 on 16.10.2017. // -#include - #include - +#include namespace sd { - namespace ops { - LegacyTransformFloatOp::LegacyTransformFloatOp() : LegacyOp::LegacyOp(1) { - // just a no-op - } +namespace ops { +LegacyTransformFloatOp::LegacyTransformFloatOp() : LegacyOp::LegacyOp(1) { + // just a no-op +} - LegacyTransformFloatOp::LegacyTransformFloatOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - // just a no-op - } +LegacyTransformFloatOp::LegacyTransformFloatOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + // just a no-op +} - LegacyOp* LegacyTransformFloatOp::clone() { - return new LegacyTransformFloatOp(this->_opNum); - } +LegacyOp *LegacyTransformFloatOp::clone() { + return new LegacyTransformFloatOp(this->_opNum); +} - Nd4jStatus LegacyTransformFloatOp::validateAndExecute(Context &block) { - auto input = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +Nd4jStatus LegacyTransformFloatOp::validateAndExecute(Context &block) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - NDArray::prepareSpecialUse({z}, {input}); + NDArray::prepareSpecialUse({z}, {input}); - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyTransformFloatOp"); + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyTransformFloatOp"); - NativeOpExecutioner::execTransformFloat(block.launchContext(), opNum, input->buffer(), input->shapeInfo(), input->specialBuffer(), input->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), extras.argumentsAsT(z->dataType()), nullptr, nullptr); + NativeOpExecutioner::execTransformFloat( + block.launchContext(), opNum, input->buffer(), input->shapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), nullptr, nullptr); - manager.synchronize(); - STORE_RESULT(*z); + manager.synchronize(); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - /** - * For transform operations, output shape always equals to input shape. With just a few exclusions, like im2col and col2im. - * But these ops already have CustomOp implementations. - * - */ - ShapeList *LegacyTransformFloatOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - } -} \ No newline at end of file +/** + * For transform operations, output shape always equals to input shape. With + * just a few exclusions, like im2col and col2im. But these ops already have + * CustomOp implementations. + * + */ +ShapeList *LegacyTransformFloatOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp index d0a8f7604d52..1ebefda35482 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp @@ -18,51 +18,54 @@ // Created by raver119 on 16.10.2017. // -#include - #include +#include #ifdef ONLY_SAME_TRANSFORM namespace sd { - namespace ops { - LegacyTransformOp::LegacyTransformOp() : LegacyOp::LegacyOp(1) { - // just a no-op - } +namespace ops { +LegacyTransformOp::LegacyTransformOp() : LegacyOp::LegacyOp(1) { + // just a no-op +} - LegacyTransformOp::LegacyTransformOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - // just a no-op - } +LegacyTransformOp::LegacyTransformOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { + // just a no-op +} - LegacyOp* LegacyTransformOp::clone() { - return new LegacyTransformOp(this->_opNum); - } +LegacyOp *LegacyTransformOp::clone() { + return new LegacyTransformOp(this->_opNum); +} - Nd4jStatus LegacyTransformOp::validateAndExecute(Context &block) { - auto input = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +Nd4jStatus LegacyTransformOp::validateAndExecute(Context &block) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - NativeOpExcutioner::execTransformSame(opNum, input->buffer(), input->shapeInfo(), z->buffer(), z->shapeInfo(), block.getTArguments()->data(), nullptr, nullptr); + NativeOpExcutioner::execTransformSame( + opNum, input->buffer(), input->shapeInfo(), z->buffer(), z->shapeInfo(), + block.getTArguments()->data(), nullptr, nullptr); - STORE_RESULT(*z); + STORE_RESULT(*z); - return ND4J_STATUS_OK; - } + return ND4J_STATUS_OK; +} - /** - * For transform operations, output shape always equals to input shape. With just a few exclusions, like im2col and col2im. - * But these ops already have CustomOp implementations. - * - */ - ShapeList *LegacyTransformOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); +/** + * For transform operations, output shape always equals to input shape. With + * just a few exclusions, like im2col and col2im. But these ops already have + * CustomOp implementations. + * + */ +ShapeList *LegacyTransformOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + auto inShape = inputShape->at(0); - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); - return SHAPELIST(CONSTANT(newShape)); - } - } + return SHAPELIST(CONSTANT(newShape)); } +} // namespace ops +} // namespace sd #endif diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformSameOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformSameOp.cpp index cd1cc2b999f4..c353627fa3ff 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformSameOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformSameOp.cpp @@ -18,57 +18,61 @@ // Created by raver119 on 16.10.2017. // -#include - #include - +#include namespace sd { - namespace ops { - LegacyTransformSameOp::LegacyTransformSameOp() : LegacyOp::LegacyOp(1) { - this->getOpDescriptor()->allowInplace(true); - } +namespace ops { +LegacyTransformSameOp::LegacyTransformSameOp() : LegacyOp::LegacyOp(1) { + this->getOpDescriptor()->allowInplace(true); +} - LegacyTransformSameOp::LegacyTransformSameOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - this->getOpDescriptor()->allowInplace(true); - } +LegacyTransformSameOp::LegacyTransformSameOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + this->getOpDescriptor()->allowInplace(true); +} - LegacyOp* LegacyTransformSameOp::clone() { - return new LegacyTransformSameOp(this->_opNum); - } +LegacyOp *LegacyTransformSameOp::clone() { + return new LegacyTransformSameOp(this->_opNum); +} - Nd4jStatus LegacyTransformSameOp::validateAndExecute(Context &block) { - auto input = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +Nd4jStatus LegacyTransformSameOp::validateAndExecute(Context &block) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - NDArray::prepareSpecialUse({z}, {input}); + NDArray::prepareSpecialUse({z}, {input}); - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyTransformSameOp"); + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyTransformSameOp"); - NativeOpExecutioner::execTransformSame(block.launchContext(), opNum, input->buffer(), input->shapeInfo(), input->specialBuffer(), input->specialShapeInfo(), - z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), extras.argumentsAsT(z->dataType()), nullptr, nullptr); + NativeOpExecutioner::execTransformSame( + block.launchContext(), opNum, input->buffer(), input->shapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), nullptr, nullptr); - manager.synchronize(); - STORE_RESULT(*z); + manager.synchronize(); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - /** - * For transform operations, output shape always equals to input shape. With just a few exclusions, like im2col and col2im. - * But these ops already have CustomOp implementations. - * - */ - ShapeList *LegacyTransformSameOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - } -} \ No newline at end of file +/** + * For transform operations, output shape always equals to input shape. With + * just a few exclusions, like im2col and col2im. But these ops already have + * CustomOp implementations. + * + */ +ShapeList *LegacyTransformSameOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformStrictOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformStrictOp.cpp index 220a373753b8..baa3386afe12 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformStrictOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformStrictOp.cpp @@ -18,56 +18,61 @@ // Created by raver119 on 16.10.2017. // -#include - #include - +#include namespace sd { - namespace ops { - LegacyTransformStrictOp::LegacyTransformStrictOp() : LegacyOp::LegacyOp(1) { - this->getOpDescriptor()->allowInplace(true); - } +namespace ops { +LegacyTransformStrictOp::LegacyTransformStrictOp() : LegacyOp::LegacyOp(1) { + this->getOpDescriptor()->allowInplace(true); +} - LegacyTransformStrictOp::LegacyTransformStrictOp(int opNum) : LegacyOp::LegacyOp(1, opNum) { - this->getOpDescriptor()->allowInplace(true); - } +LegacyTransformStrictOp::LegacyTransformStrictOp(int opNum) + : LegacyOp::LegacyOp(1, opNum) { + this->getOpDescriptor()->allowInplace(true); +} - LegacyOp* LegacyTransformStrictOp::clone() { - return new LegacyTransformStrictOp(this->_opNum); - } +LegacyOp *LegacyTransformStrictOp::clone() { + return new LegacyTransformStrictOp(this->_opNum); +} - Nd4jStatus LegacyTransformStrictOp::validateAndExecute(Context &block) { - auto input = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); +Nd4jStatus LegacyTransformStrictOp::validateAndExecute(Context &block) { + auto input = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - NDArray::prepareSpecialUse({z}, {input}); + NDArray::prepareSpecialUse({z}, {input}); - int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); + int opNum = block.opNum() < 0 ? this->_opNum : block.opNum(); - ExtraArguments extras(block.getTArguments()); - PointersManager manager(block.launchContext(), "LegacyTransformStrictOp"); + ExtraArguments extras(block.getTArguments()); + PointersManager manager(block.launchContext(), "LegacyTransformStrictOp"); - NativeOpExecutioner::execTransformStrict(block.launchContext(), opNum, input->buffer(), input->shapeInfo(), input->specialBuffer(), input->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), extras.argumentsAsT(z->dataType()), nullptr, nullptr); + NativeOpExecutioner::execTransformStrict( + block.launchContext(), opNum, input->buffer(), input->shapeInfo(), + input->specialBuffer(), input->specialShapeInfo(), z->buffer(), + z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), + extras.argumentsAsT(z->dataType()), nullptr, nullptr); - manager.synchronize(); - STORE_RESULT(*z); + manager.synchronize(); + STORE_RESULT(*z); - return Status::OK(); - } + return Status::OK(); +} - /** - * For transform operations, output shape always equals to input shape. With just a few exclusions, like im2col and col2im. - * But these ops already have CustomOp implementations. - * - */ - ShapeList *LegacyTransformStrictOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - auto inShape = inputShape->at(0); - - Nd4jLong *newShape; - COPY_SHAPE(inShape, newShape); - - return SHAPELIST(CONSTANT(newShape)); - } - } -} \ No newline at end of file +/** + * For transform operations, output shape always equals to input shape. With + * just a few exclusions, like im2col and col2im. But these ops already have + * CustomOp implementations. + * + */ +ShapeList *LegacyTransformStrictOp::calculateOutputShape( + ShapeList *inputShape, sd::graph::Context &block) { + auto inShape = inputShape->at(0); + + Nd4jLong *newShape; + COPY_SHAPE(inShape, newShape); + + return SHAPELIST(CONSTANT(newShape)); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/LogicOp.cpp b/libnd4j/include/ops/declarable/impl/LogicOp.cpp index ae24b5631c9c..a90698a136b9 100644 --- a/libnd4j/include/ops/declarable/impl/LogicOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LogicOp.cpp @@ -21,20 +21,21 @@ #include "ops/declarable/LogicOp.h" namespace sd { - namespace ops { - LogicOp::LogicOp(const char *name) : DeclarableOp::DeclarableOp(name, true) { - // just using DeclarableOp constructor - //this->_descriptor-> - } +namespace ops { +LogicOp::LogicOp(const char *name) : DeclarableOp::DeclarableOp(name, true) { + // just using DeclarableOp constructor + // this->_descriptor-> +} - Nd4jStatus LogicOp::validateAndExecute(sd::graph::Context &block) { - nd4j_logger("WARNING: LogicOps should NOT be ever called\n", ""); - return ND4J_STATUS_BAD_INPUT; - } +Nd4jStatus LogicOp::validateAndExecute(sd::graph::Context &block) { + nd4j_logger("WARNING: LogicOps should NOT be ever called\n", ""); + return ND4J_STATUS_BAD_INPUT; +} - ShapeList* LogicOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) { - // FIXME: we probably want these ops to evaluate scopes - return SHAPELIST(); - } - } -} \ No newline at end of file +ShapeList *LogicOp::calculateOutputShape(ShapeList *inputShape, + sd::graph::Context &block) { + // FIXME: we probably want these ops to evaluate scopes + return SHAPELIST(); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/OpDescriptor.cpp b/libnd4j/include/ops/declarable/impl/OpDescriptor.cpp index 398c11729843..569856c5eb9d 100644 --- a/libnd4j/include/ops/declarable/impl/OpDescriptor.cpp +++ b/libnd4j/include/ops/declarable/impl/OpDescriptor.cpp @@ -21,281 +21,275 @@ #include namespace sd { - namespace ops { - - OpDescriptor::OpDescriptor(const char * opName, bool isLogic) { - _logic = isLogic; - _opName = opName; - } - - OpDescriptor::OpDescriptor(int numInputs, const char * opName, bool isScalar) { - _numInputs = numInputs; - _numOutputs = 1; - - _opName = opName; - _hash = sd::ops::HashHelper::getInstance()->getLongHash(_opName); - _opClass = sd::graph::OpClass_CONDITIONAL; - - _scalar = isScalar; - } - - OpDescriptor::OpDescriptor(int numInputs, std::string opName, bool isScalar) { - _numInputs = numInputs; - _numOutputs = 1; - - _opName = opName; - _hash = sd::ops::HashHelper::getInstance()->getLongHash(_opName); - _opClass = sd::graph::OpClass_CONDITIONAL; - - _scalar = isScalar; - } - - void OpDescriptor::allowInplace(bool reallyAllow){ - _allowsInplace = reallyAllow; - } - - bool OpDescriptor::operator==(const OpDescriptor& other) const { - if (_hash == -1 && other._hash == -1) - return this->_opNum == other._opNum; - else - return this->_hash == other._hash; - } - - OpDescriptor::OpDescriptor(int numInputs, int numOutputs, std::string opName, bool allowsInplace) : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName.c_str(), allowsInplace) { - // - } - - void OpDescriptor::setHash(Nd4jLong hash) { - _hash = hash; - } - - // default constructor - OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace) { - _numInputs = numInputs; - _numOutputs = numOutputs; - - std::string tmp(opName); - _opName = tmp; - _allowsInplace = allowsInplace; - _hash = sd::ops::HashHelper::getInstance()->getLongHash(tmp); - _divergent = false; - - // just default value - _opClass = sd::graph::OpClass_TRANSFORM; - } - - // constructor for configurable op - OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName, allowsInplace) { - _tArgs = tArgs; - _iArgs = iArgs; - } - - // constructor for non-configurable divergent op - OpDescriptor::OpDescriptor(int numInputs, int numOutputs, std::string opName, bool allowsInplace, bool divergent) : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName.c_str(), allowsInplace, divergent) { - - } - - // constructor for non-configurable divergent op - OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent) : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName, allowsInplace) { - _divergent = divergent; - } - - // constructor for configurable divergent op - OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent, int tArgs, int iArgs) : OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs) { - _divergent = divergent; - } - - // default destructor - OpDescriptor::~OpDescriptor() { - // - } - - int OpDescriptor::getNumberOfTArgs() { - return _tArgs; - } - - int OpDescriptor::getNumberOfIArgs() { - return _iArgs; - } - - int OpDescriptor::getNumberOfInputs() { - return _numInputs; - } - - Nd4jLong OpDescriptor::getHash() { - return _hash; - } - - int OpDescriptor::getNumberOfOutputs() { - return _numOutputs; - } - - std::string * OpDescriptor::getOpName() { - return &_opName; - } - - bool OpDescriptor::isDivergent() { - return _divergent; - } - - void OpDescriptor::setOpNum(int opNum) { - _opNum = opNum; - } - - bool OpDescriptor::allowsInplace() { - return _allowsInplace; - } - - int OpDescriptor::getOpNum() { - return _opNum; - } - - OpDescriptor* OpDescriptor::setInputType(const InputType type) { - _inputType = type; - return this; - } - - InputType OpDescriptor::inputType() { - return _inputType; - } - - OpDescriptor* OpDescriptor::setAllowedInputTypes(const std::initializer_list &dtypes) { - _allowedIns = dtypes; - return this; - } - - OpDescriptor* OpDescriptor::setAllowedOutputTypes(const std::initializer_list &dtypes) { - _allowedOuts = dtypes; - return this; - } - - OpDescriptor* OpDescriptor::allowOverride(bool allowOverride) { - _dtypeOverride = allowOverride; - return this; - } - - OpDescriptor* OpDescriptor::setAllowedInputTypes(const sd::DataType dtype) { - _allowedIns.clear(); - _allowedIns.emplace_back(dtype); - return this; - } - - OpDescriptor* OpDescriptor::setAllowedOutputTypes(const sd::DataType dtype) { - _allowedOuts.clear(); - _allowedOuts.emplace_back(dtype); - return this; - } - - OpDescriptor* OpDescriptor::setInputType(const int idx, const sd::DataType dtype) { - _inputTypes[idx] = { dtype }; - return this; - } - - OpDescriptor* OpDescriptor::setOutputType(const int idx, const sd::DataType dtype) { - _outputTypes[idx] = { dtype }; - return this; - } - - OpDescriptor* OpDescriptor::setSameMode(const bool reallySame) { - _sameMode = reallySame; - return this; - } - - OpDescriptor* OpDescriptor::setAllowedInputTypes(int index, const std::vector &dtype) { - _inputTypes[index] = dtype; - return this; - } - - OpDescriptor* OpDescriptor::setAllowedOutputTypes(int index, const std::vector &dtype) { - _outputTypes[index] = dtype; - return this; - } - - OpDescriptor* OpDescriptor::setAllowedInputTypes(int index, sd::DataType dtype) { - if (_inputTypes.count(index) == 0) - _inputTypes[index] = {dtype}; - else - _inputTypes[index].emplace_back(dtype); - - return this; - } - - OpDescriptor* OpDescriptor::setAllowedOutputTypes(int index, sd::DataType dtype) { - if (_outputTypes.count(index) == 0) - _outputTypes[index] = {dtype}; - else - _outputTypes[index].emplace_back(dtype); - - return this; - } - - bool OpDescriptor::checkDataTypesMatch(sd::DataType needle, std::vector &haystack) const { - // if haystack is empty - INHERIT is occurs - any type is perfect? - if (haystack.empty()) - return true; - - // first we're checking for direct input type match - if (std::find(haystack.begin(), haystack.end(), needle) == haystack.end()) { - - // if direct input match failed - we're checking for ANY as allowed input - if (std::find(haystack.begin(), haystack.end(), sd::DataType::ANY) == haystack.end()) - return false; - else - return true; - } else { - return true; - } - } - - bool OpDescriptor::checkInputMatch(int index, sd::DataType dataType) { - // we check for per-input types first - if (_inputTypes.empty() || _inputTypes.count(index) == 0) { - // checking global input types - return checkDataTypesMatch(dataType, _allowedIns); - } else { - // checking data type for specified input - auto allowed = _inputTypes[index]; - return checkDataTypesMatch(dataType, allowed); - } - return true; - } - - bool OpDescriptor::checkOutputMatch(int index, sd::DataType dataType) { - // we check for per-output types first - if (_outputTypes.empty() || _outputTypes.count(index) == 0) { - - // checking global output types - return checkDataTypesMatch(dataType, _allowedOuts); - } else { - // checking data type for specified output - auto allowed = _outputTypes[index]; - return checkDataTypesMatch(dataType, allowed); - } - return true; - } - - bool OpDescriptor::isSameMode() { - return _sameMode; - } - - bool OpDescriptor::isInherit(int index) { - if (std::find(_allowedOuts.begin(), _allowedOuts.end(), sd::DataType::INHERIT) != _allowedOuts.end()) - return true; - if (_outputTypes.count(index) > 0) { - auto vec = _outputTypes[index]; - - if (std::find(vec.begin(), vec.end(), sd::DataType::INHERIT) != vec.end()) - return true; - } - - return false; - } - - std::vector OpDescriptor::getOutputTypesForOutput(int index) { - if (_outputTypes.count(index) > 0) - return _outputTypes.at(index); - else - return std::vector(); - } - } -} \ No newline at end of file +namespace ops { + +OpDescriptor::OpDescriptor(const char* opName, bool isLogic) { + _logic = isLogic; + _opName = opName; +} + +OpDescriptor::OpDescriptor(int numInputs, const char* opName, bool isScalar) { + _numInputs = numInputs; + _numOutputs = 1; + + _opName = opName; + _hash = sd::ops::HashHelper::getInstance()->getLongHash(_opName); + _opClass = sd::graph::OpClass_CONDITIONAL; + + _scalar = isScalar; +} + +OpDescriptor::OpDescriptor(int numInputs, std::string opName, bool isScalar) { + _numInputs = numInputs; + _numOutputs = 1; + + _opName = opName; + _hash = sd::ops::HashHelper::getInstance()->getLongHash(_opName); + _opClass = sd::graph::OpClass_CONDITIONAL; + + _scalar = isScalar; +} + +void OpDescriptor::allowInplace(bool reallyAllow) { + _allowsInplace = reallyAllow; +} + +bool OpDescriptor::operator==(const OpDescriptor& other) const { + if (_hash == -1 && other._hash == -1) + return this->_opNum == other._opNum; + else + return this->_hash == other._hash; +} + +OpDescriptor::OpDescriptor(int numInputs, int numOutputs, std::string opName, + bool allowsInplace) + : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName.c_str(), + allowsInplace) { + // +} + +void OpDescriptor::setHash(Nd4jLong hash) { _hash = hash; } + +// default constructor +OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace) { + _numInputs = numInputs; + _numOutputs = numOutputs; + + std::string tmp(opName); + _opName = tmp; + _allowsInplace = allowsInplace; + _hash = sd::ops::HashHelper::getInstance()->getLongHash(tmp); + _divergent = false; + + // just default value + _opClass = sd::graph::OpClass_TRANSFORM; +} + +// constructor for configurable op +OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, int tArgs, int iArgs) + : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName, allowsInplace) { + _tArgs = tArgs; + _iArgs = iArgs; +} + +// constructor for non-configurable divergent op +OpDescriptor::OpDescriptor(int numInputs, int numOutputs, std::string opName, + bool allowsInplace, bool divergent) + : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName.c_str(), + allowsInplace, divergent) {} + +// constructor for non-configurable divergent op +OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, bool divergent) + : OpDescriptor::OpDescriptor(numInputs, numOutputs, opName, allowsInplace) { + _divergent = divergent; +} + +// constructor for configurable divergent op +OpDescriptor::OpDescriptor(int numInputs, int numOutputs, const char* opName, + bool allowsInplace, bool divergent, int tArgs, + int iArgs) + : OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs) { + _divergent = divergent; +} + +// default destructor +OpDescriptor::~OpDescriptor() { + // +} + +int OpDescriptor::getNumberOfTArgs() { return _tArgs; } + +int OpDescriptor::getNumberOfIArgs() { return _iArgs; } + +int OpDescriptor::getNumberOfInputs() { return _numInputs; } + +Nd4jLong OpDescriptor::getHash() { return _hash; } + +int OpDescriptor::getNumberOfOutputs() { return _numOutputs; } + +std::string* OpDescriptor::getOpName() { return &_opName; } + +bool OpDescriptor::isDivergent() { return _divergent; } + +void OpDescriptor::setOpNum(int opNum) { _opNum = opNum; } + +bool OpDescriptor::allowsInplace() { return _allowsInplace; } + +int OpDescriptor::getOpNum() { return _opNum; } + +OpDescriptor* OpDescriptor::setInputType(const InputType type) { + _inputType = type; + return this; +} + +InputType OpDescriptor::inputType() { return _inputType; } + +OpDescriptor* OpDescriptor::setAllowedInputTypes( + const std::initializer_list& dtypes) { + _allowedIns = dtypes; + return this; +} + +OpDescriptor* OpDescriptor::setAllowedOutputTypes( + const std::initializer_list& dtypes) { + _allowedOuts = dtypes; + return this; +} + +OpDescriptor* OpDescriptor::allowOverride(bool allowOverride) { + _dtypeOverride = allowOverride; + return this; +} + +OpDescriptor* OpDescriptor::setAllowedInputTypes(const sd::DataType dtype) { + _allowedIns.clear(); + _allowedIns.emplace_back(dtype); + return this; +} + +OpDescriptor* OpDescriptor::setAllowedOutputTypes(const sd::DataType dtype) { + _allowedOuts.clear(); + _allowedOuts.emplace_back(dtype); + return this; +} + +OpDescriptor* OpDescriptor::setInputType(const int idx, + const sd::DataType dtype) { + _inputTypes[idx] = {dtype}; + return this; +} + +OpDescriptor* OpDescriptor::setOutputType(const int idx, + const sd::DataType dtype) { + _outputTypes[idx] = {dtype}; + return this; +} + +OpDescriptor* OpDescriptor::setSameMode(const bool reallySame) { + _sameMode = reallySame; + return this; +} + +OpDescriptor* OpDescriptor::setAllowedInputTypes( + int index, const std::vector& dtype) { + _inputTypes[index] = dtype; + return this; +} + +OpDescriptor* OpDescriptor::setAllowedOutputTypes( + int index, const std::vector& dtype) { + _outputTypes[index] = dtype; + return this; +} + +OpDescriptor* OpDescriptor::setAllowedInputTypes(int index, + sd::DataType dtype) { + if (_inputTypes.count(index) == 0) + _inputTypes[index] = {dtype}; + else + _inputTypes[index].emplace_back(dtype); + + return this; +} + +OpDescriptor* OpDescriptor::setAllowedOutputTypes(int index, + sd::DataType dtype) { + if (_outputTypes.count(index) == 0) + _outputTypes[index] = {dtype}; + else + _outputTypes[index].emplace_back(dtype); + + return this; +} + +bool OpDescriptor::checkDataTypesMatch( + sd::DataType needle, std::vector& haystack) const { + // if haystack is empty - INHERIT is occurs - any type is perfect? + if (haystack.empty()) return true; + + // first we're checking for direct input type match + if (std::find(haystack.begin(), haystack.end(), needle) == haystack.end()) { + // if direct input match failed - we're checking for ANY as allowed input + if (std::find(haystack.begin(), haystack.end(), sd::DataType::ANY) == + haystack.end()) + return false; + else + return true; + } else { + return true; + } +} + +bool OpDescriptor::checkInputMatch(int index, sd::DataType dataType) { + // we check for per-input types first + if (_inputTypes.empty() || _inputTypes.count(index) == 0) { + // checking global input types + return checkDataTypesMatch(dataType, _allowedIns); + } else { + // checking data type for specified input + auto allowed = _inputTypes[index]; + return checkDataTypesMatch(dataType, allowed); + } + return true; +} + +bool OpDescriptor::checkOutputMatch(int index, sd::DataType dataType) { + // we check for per-output types first + if (_outputTypes.empty() || _outputTypes.count(index) == 0) { + // checking global output types + return checkDataTypesMatch(dataType, _allowedOuts); + } else { + // checking data type for specified output + auto allowed = _outputTypes[index]; + return checkDataTypesMatch(dataType, allowed); + } + return true; +} + +bool OpDescriptor::isSameMode() { return _sameMode; } + +bool OpDescriptor::isInherit(int index) { + if (std::find(_allowedOuts.begin(), _allowedOuts.end(), + sd::DataType::INHERIT) != _allowedOuts.end()) + return true; + if (_outputTypes.count(index) > 0) { + auto vec = _outputTypes[index]; + + if (std::find(vec.begin(), vec.end(), sd::DataType::INHERIT) != vec.end()) + return true; + } + + return false; +} + +std::vector OpDescriptor::getOutputTypesForOutput(int index) { + if (_outputTypes.count(index) > 0) + return _outputTypes.at(index); + else + return std::vector(); +} +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp index d94748eb9ddb..438779585e6c 100644 --- a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp +++ b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp @@ -18,255 +18,270 @@ // Created by raver119 on 07.10.2017. // - - #include + #include namespace sd { - namespace ops { - - /////////////////////////////// +namespace ops { - template - __registrator::__registrator() { - auto ptr = new OpName(); - OpRegistrator::getInstance()->registerOperation(ptr); - } +/////////////////////////////// +template +__registrator::__registrator() { + auto ptr = new OpName(); + OpRegistrator::getInstance()->registerOperation(ptr); +} - template - __registratorSynonym::__registratorSynonym(const char *name, const char *oname) { - auto ptr = reinterpret_cast(OpRegistrator::getInstance()->getOperation(oname)); - if (ptr == nullptr) { - std::string newName(name); - std::string oldName(oname); - - OpRegistrator::getInstance()->updateMSVC(sd::ops::HashHelper::getInstance()->getLongHash(newName), oldName); - return; - } - OpRegistrator::getInstance()->registerOperation(name, ptr); - } - - /////////////////////////////// - +template +__registratorSynonym::__registratorSynonym(const char* name, + const char* oname) { + auto ptr = reinterpret_cast( + OpRegistrator::getInstance()->getOperation(oname)); + if (ptr == nullptr) { + std::string newName(name); + std::string oldName(oname); + + OpRegistrator::getInstance()->updateMSVC( + sd::ops::HashHelper::getInstance()->getLongHash(newName), oldName); + return; + } + OpRegistrator::getInstance()->registerOperation(name, ptr); +} - OpRegistrator* OpRegistrator::getInstance() { - if (!_INSTANCE) - _INSTANCE = new sd::ops::OpRegistrator(); +/////////////////////////////// - return _INSTANCE; - } +OpRegistrator* OpRegistrator::getInstance() { + if (!_INSTANCE) _INSTANCE = new sd::ops::OpRegistrator(); + return _INSTANCE; +} - void OpRegistrator::updateMSVC(Nd4jLong newHash, std::string& oldName) { - std::pair pair(newHash, oldName); - _msvc.insert(pair); - } +void OpRegistrator::updateMSVC(Nd4jLong newHash, std::string& oldName) { + std::pair pair(newHash, oldName); + _msvc.insert(pair); +} - template - std::string OpRegistrator::local_to_string(T value) { - //create an output string stream - std::ostringstream os ; +template +std::string OpRegistrator::local_to_string(T value) { + // create an output string stream + std::ostringstream os; - //throw the value into the string stream - os << value ; + // throw the value into the string stream + os << value; - //convert the string stream into a string and return - return os.str() ; - } + // convert the string stream into a string and return + return os.str(); +} - template <> - std::string OpRegistrator::local_to_string(int value) { - //create an output string stream - std::ostringstream os ; +template <> +std::string OpRegistrator::local_to_string(int value) { + // create an output string stream + std::ostringstream os; - //throw the value into the string stream - os << value ; + // throw the value into the string stream + os << value; - //convert the string stream into a string and return - return os.str() ; - } + // convert the string stream into a string and return + return os.str(); +} - void OpRegistrator::sigIntHandler(int sig) { +void OpRegistrator::sigIntHandler(int sig) { #ifndef _RELEASE - delete OpRegistrator::getInstance(); + delete OpRegistrator::getInstance(); #endif - } +} - void OpRegistrator::exitHandler() { +void OpRegistrator::exitHandler() { #ifndef _RELEASE - delete OpRegistrator::getInstance(); + delete OpRegistrator::getInstance(); #endif - } +} - void OpRegistrator::sigSegVHandler(int sig) { +void OpRegistrator::sigSegVHandler(int sig) { #ifndef _RELEASE - delete OpRegistrator::getInstance(); + delete OpRegistrator::getInstance(); #endif - } +} - OpRegistrator::~OpRegistrator() { +OpRegistrator::~OpRegistrator() { #ifndef _RELEASE - _msvc.clear(); + _msvc.clear(); - for (auto x: _uniqueH) - delete x; + for (auto x : _uniqueH) delete x; - _uniqueH.clear(); + _uniqueH.clear(); - _declarablesD.clear(); + _declarablesD.clear(); - _declarablesLD.clear(); + _declarablesLD.clear(); #endif - } - - const char * OpRegistrator::getAllCustomOperations() { - _locker.lock(); - - if (!isInit) { - for (MAP_IMPL>::iterator it=_declarablesD.begin(); it!=_declarablesD.end(); ++it) { - std::string op = it->first + ":" - + local_to_string(it->second->getOpDescriptor()->getHash()) + ":" - + local_to_string(it->second->getOpDescriptor()->getNumberOfInputs()) + ":" - + local_to_string(it->second->getOpDescriptor()->getNumberOfOutputs()) + ":" - + local_to_string(it->second->getOpDescriptor()->allowsInplace()) + ":" - + local_to_string(it->second->getOpDescriptor()->getNumberOfTArgs()) + ":" - + local_to_string(it->second->getOpDescriptor()->getNumberOfIArgs()) + ":" - + ";" ; - _opsList += op; - } - - isInit = true; - } - - _locker.unlock(); - - return _opsList.c_str(); - } - - bool OpRegistrator::hasOperation(const std::string &opName) const { - return _declarablesD.count(opName) > 0; - } - - bool OpRegistrator::hasOperation(Nd4jLong opName) const { - return _declarablesLD.count(opName) > 0; - } - - bool OpRegistrator::registerOperation(const std::string &opName, std::shared_ptr op) { - std::pair> pair(opName, op); - _declarablesD.insert(pair); - - auto hash = sd::ops::HashHelper::getInstance()->getLongHash(opName); - std::pair> pair2(hash, op); - _declarablesLD.insert(pair2); - return true; - } - - /** - * This method registers operation - * - * @param op - */ - bool OpRegistrator::registerOperation(std::shared_ptr op) { - return registerOperation(op->getOpName(), op); - } - - void OpRegistrator::registerHelper(sd::ops::platforms::PlatformHelper* op) { - std::pair p = {op->hash(), op->engine()}; - if (_helpersLH.count(p) > 0) - throw std::runtime_error("Tried to double register PlatformHelper"); - - _uniqueH.emplace_back(op); - - nd4j_debug("Adding helper for op \"%s\": [%lld - %i]\n", op->name().c_str(), op->hash(), (int) op->engine()); - - std::pair, sd::ops::platforms::PlatformHelper*> pair({op->name(), op->engine()}, op); - _helpersH.insert(pair); - - std::pair, sd::ops::platforms::PlatformHelper*> pair2(p, op); - _helpersLH.insert(pair2); - } - - /** - * This method returns registered Op by name - * - * @param name - * @return - */ - std::shared_ptr OpRegistrator::getOperation(Nd4jLong hash) { - if (!_declarablesLD.count(hash)) { - if (!_msvc.count(hash)) { - nd4j_printf("Unknown D operation requested by hash: [%lld]\n", hash); - return nullptr; - } else { - std::lock_guard lock(_locker); - - auto str = _msvc.at(hash); - auto op = _declarablesD.at(str); - auto oHash = op->getOpDescriptor()->getHash(); - - std::pair> pair(oHash, op); - _declarablesLD.insert(pair); - } - } - - return _declarablesLD.at(hash); - } - - std::shared_ptr OpRegistrator::getOperation(const std::string& name) { - if (!_declarablesD.count(name)) { - nd4j_debug("Unknown operation requested: [%s]\n", name.c_str()); - return nullptr; - } - - return _declarablesD.at(name); - } - - sd::ops::platforms::PlatformHelper* OpRegistrator::getPlatformHelper(Nd4jLong hash, samediff::Engine engine) { - std::pair p = {hash, engine}; - if (_helpersLH.count(p) == 0) - throw std::runtime_error("Requested helper can't be found"); - - return _helpersLH[p]; - } - - bool OpRegistrator::hasHelper(Nd4jLong hash, samediff::Engine engine) { - std::pair p = {hash, engine}; - return _helpersLH.count(p) > 0; - } - - int OpRegistrator::numberOfOperations() { - return (int) _declarablesLD.size(); - } - - std::vector OpRegistrator::getAllHashes() { - std::vector result; - - for (auto &v:_declarablesLD) { - result.emplace_back(v.first); - } - - return result; - } - - sd::ops::OpRegistrator* sd::ops::OpRegistrator::_INSTANCE = 0; - } } -namespace std { - size_t hash>::operator()(const std::pair& k) const { - using std::hash; - auto res = std::hash()(k.first); - res ^= std::hash()((int) k.second) + 0x9e3779b9 + (res << 6) + (res >> 2); - return res; +const char* OpRegistrator::getAllCustomOperations() { + _locker.lock(); + + if (!isInit) { + for (MAP_IMPL>::iterator + it = _declarablesD.begin(); + it != _declarablesD.end(); ++it) { + std::string op = + it->first + ":" + + local_to_string(it->second->getOpDescriptor()->getHash()) + ":" + + local_to_string(it->second->getOpDescriptor()->getNumberOfInputs()) + + ":" + + local_to_string(it->second->getOpDescriptor()->getNumberOfOutputs()) + + ":" + + local_to_string(it->second->getOpDescriptor()->allowsInplace()) + + ":" + + local_to_string(it->second->getOpDescriptor()->getNumberOfTArgs()) + + ":" + + local_to_string(it->second->getOpDescriptor()->getNumberOfIArgs()) + + ":" + ";"; + _opsList += op; } - size_t hash>::operator()(const std::pair& k) const { - using std::hash; - auto res = std::hash()(k.first); - res ^= std::hash()((int) k.second) + 0x9e3779b9 + (res << 6) + (res >> 2); - return res; + isInit = true; + } + + _locker.unlock(); + + return _opsList.c_str(); +} + +bool OpRegistrator::hasOperation(const std::string& opName) const { + return _declarablesD.count(opName) > 0; +} + +bool OpRegistrator::hasOperation(Nd4jLong opName) const { + return _declarablesLD.count(opName) > 0; +} + +bool OpRegistrator::registerOperation( + const std::string& opName, std::shared_ptr op) { + std::pair> pair(opName, + op); + _declarablesD.insert(pair); + + auto hash = sd::ops::HashHelper::getInstance()->getLongHash(opName); + std::pair> pair2(hash, op); + _declarablesLD.insert(pair2); + return true; +} + +/** + * This method registers operation + * + * @param op + */ +bool OpRegistrator::registerOperation( + std::shared_ptr op) { + return registerOperation(op->getOpName(), op); +} + +void OpRegistrator::registerHelper(sd::ops::platforms::PlatformHelper* op) { + std::pair p = {op->hash(), op->engine()}; + if (_helpersLH.count(p) > 0) + throw std::runtime_error("Tried to double register PlatformHelper"); + + _uniqueH.emplace_back(op); + + nd4j_debug("Adding helper for op \"%s\": [%lld - %i]\n", op->name().c_str(), + op->hash(), (int)op->engine()); + + std::pair, + sd::ops::platforms::PlatformHelper*> + pair({op->name(), op->engine()}, op); + _helpersH.insert(pair); + + std::pair, + sd::ops::platforms::PlatformHelper*> + pair2(p, op); + _helpersLH.insert(pair2); +} + +/** + * This method returns registered Op by name + * + * @param name + * @return + */ +std::shared_ptr OpRegistrator::getOperation( + Nd4jLong hash) { + if (!_declarablesLD.count(hash)) { + if (!_msvc.count(hash)) { + nd4j_printf("Unknown D operation requested by hash: [%lld]\n", hash); + return nullptr; + } else { + std::lock_guard lock(_locker); + + auto str = _msvc.at(hash); + auto op = _declarablesD.at(str); + auto oHash = op->getOpDescriptor()->getHash(); + + std::pair> pair(oHash, + op); + _declarablesLD.insert(pair); } + } + + return _declarablesLD.at(hash); } +std::shared_ptr OpRegistrator::getOperation( + const std::string& name) { + if (!_declarablesD.count(name)) { + nd4j_debug("Unknown operation requested: [%s]\n", name.c_str()); + return nullptr; + } + + return _declarablesD.at(name); +} + +sd::ops::platforms::PlatformHelper* OpRegistrator::getPlatformHelper( + Nd4jLong hash, samediff::Engine engine) { + std::pair p = {hash, engine}; + if (_helpersLH.count(p) == 0) + throw std::runtime_error("Requested helper can't be found"); + + return _helpersLH[p]; +} + +bool OpRegistrator::hasHelper(Nd4jLong hash, samediff::Engine engine) { + std::pair p = {hash, engine}; + return _helpersLH.count(p) > 0; +} + +int OpRegistrator::numberOfOperations() { return (int)_declarablesLD.size(); } + +std::vector OpRegistrator::getAllHashes() { + std::vector result; + + for (auto& v : _declarablesLD) { + result.emplace_back(v.first); + } + + return result; +} + +sd::ops::OpRegistrator* sd::ops::OpRegistrator::_INSTANCE = 0; +} // namespace ops +} // namespace sd + +namespace std { +size_t hash>::operator()( + const std::pair& k) const { + using std::hash; + auto res = std::hash()(k.first); + res ^= std::hash()((int)k.second) + 0x9e3779b9 + (res << 6) + (res >> 2); + return res; +} + +size_t hash>::operator()( + const std::pair& k) const { + using std::hash; + auto res = std::hash()(k.first); + res ^= std::hash()((int)k.second) + 0x9e3779b9 + (res << 6) + (res >> 2); + return res; +} +} // namespace std diff --git a/libnd4j/include/ops/declarable/impl/OpTuple.cpp b/libnd4j/include/ops/declarable/impl/OpTuple.cpp index fc43739e8b2d..c64ae1a755cb 100644 --- a/libnd4j/include/ops/declarable/impl/OpTuple.cpp +++ b/libnd4j/include/ops/declarable/impl/OpTuple.cpp @@ -20,38 +20,40 @@ #include "ops/declarable/OpTuple.h" -sd::ops::OpTuple::OpTuple(const char *opName) { - _opName = opName; -} - -sd::ops::OpTuple::OpTuple(const char *opName, std::initializer_list &&inputs, std::initializer_list &&tArgs, std::initializer_list &&iArgs) { - _opName = opName; - _inputs = inputs; - _iArgs = iArgs; - _tArgs = tArgs; +sd::ops::OpTuple::OpTuple(const char *opName) { _opName = opName; } + +sd::ops::OpTuple::OpTuple(const char *opName, + std::initializer_list &&inputs, + std::initializer_list &&tArgs, + std::initializer_list &&iArgs) { + _opName = opName; + _inputs = inputs; + _iArgs = iArgs; + _tArgs = tArgs; } sd::ops::OpTuple::~OpTuple() { - for (auto v: _inputs) - delete v; + for (auto v : _inputs) delete v; } sd::ops::OpTuple *sd::ops::OpTuple::addInput(sd::NDArray *array) { - _inputs.emplace_back(array); - return this; + _inputs.emplace_back(array); + return this; } sd::ops::OpTuple *sd::ops::OpTuple::addOutput(sd::NDArray *array) { - _outputs.emplace_back(array); - return this; + _outputs.emplace_back(array); + return this; } -sd::ops::OpTuple *sd::ops::OpTuple::setTArgs(std::initializer_list tArgs) { - _tArgs = tArgs; - return this; +sd::ops::OpTuple *sd::ops::OpTuple::setTArgs( + std::initializer_list tArgs) { + _tArgs = tArgs; + return this; } -sd::ops::OpTuple *sd::ops::OpTuple::setIArgs(std::initializer_list iArgs) { - _iArgs = iArgs; - return this; +sd::ops::OpTuple *sd::ops::OpTuple::setIArgs( + std::initializer_list iArgs) { + _iArgs = iArgs; + return this; } diff --git a/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp b/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp index ccf69b586991..3509ca44df0e 100644 --- a/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp +++ b/libnd4j/include/ops/declarable/impl/PlatformHelper.cpp @@ -19,85 +19,79 @@ // #include "../PlatformHelper.h" + #include namespace sd { - namespace ops { - namespace platforms { - PlatformHelper::PlatformHelper(const char *name, samediff::Engine engine) { - // we just store name/hash of target operation - _name = std::string(name); - _hash = HashHelper::getInstance()->getLongHash(_name); - _engine = engine; - } - - sd::NDArray* PlatformHelper::getNullifiedZ(graph::Context& block, int inputId) { - auto result = getZ(block, inputId); - if (result != nullptr && result->undefined()) - return nullptr; - - if (result != nullptr && !block.isInplace()) - result->nullify(); - - return result; - } - - sd::NDArray* PlatformHelper::getZ(graph::Context &ctx, int inputId) { - NDArray *z = nullptr; - - if (ctx.isFastPath()) { - if (ctx.fastpath_out().size() <= inputId) { - if (ctx.isInplace()) { - z = ctx.fastpath_in()[inputId].get(); - } else - throw std::runtime_error("fastpath_out: unresolved output array"); - } else { - z = ctx.fastpath_out()[inputId].get(); - } - } else { - std::pair pair(ctx.nodeId(), inputId); - - if (ctx.isInplace()) { - auto vz = ctx.variable(inputId)->getNDArray(); - z = vz.get(); - - // hypothetically it's possible to have no variable. chances are low, but who knows. let's just create it for now - if (!ctx.getVariableSpace()->hasVariable(pair)) { - auto var = std::make_shared(); - ctx.getVariableSpace()->putVariable(pair, var); - } - - // now we're saving input array as output array - auto var = ctx.getVariableSpace()->getVariable(pair); - var->markRemovable(false); - var->setNDArray(vz); - } else if (!ctx.isInplace()) { - auto var = ctx.variable(pair); - if (var->getNDArray() != nullptr && var->getNDArray()->nonNull()) { - z = var->getNDArray().get(); - } else { - nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId()); - } - } else { - nd4j_printf("BOOM!\n", ""); - throw std::runtime_error("Boom!"); - } - } - - return z; - } - - samediff::Engine PlatformHelper::engine() { - return _engine; - } - - std::string PlatformHelper::name() { - return _name; - } - - Nd4jLong PlatformHelper::hash() { - return _hash; - } - } +namespace ops { +namespace platforms { +PlatformHelper::PlatformHelper(const char* name, samediff::Engine engine) { + // we just store name/hash of target operation + _name = std::string(name); + _hash = HashHelper::getInstance()->getLongHash(_name); + _engine = engine; +} + +sd::NDArray* PlatformHelper::getNullifiedZ(graph::Context& block, int inputId) { + auto result = getZ(block, inputId); + if (result != nullptr && result->undefined()) return nullptr; + + if (result != nullptr && !block.isInplace()) result->nullify(); + + return result; +} + +sd::NDArray* PlatformHelper::getZ(graph::Context& ctx, int inputId) { + NDArray* z = nullptr; + + if (ctx.isFastPath()) { + if (ctx.fastpath_out().size() <= inputId) { + if (ctx.isInplace()) { + z = ctx.fastpath_in()[inputId].get(); + } else + throw std::runtime_error("fastpath_out: unresolved output array"); + } else { + z = ctx.fastpath_out()[inputId].get(); + } + } else { + std::pair pair(ctx.nodeId(), inputId); + + if (ctx.isInplace()) { + auto vz = ctx.variable(inputId)->getNDArray(); + z = vz.get(); + + // hypothetically it's possible to have no variable. chances are low, but + // who knows. let's just create it for now + if (!ctx.getVariableSpace()->hasVariable(pair)) { + auto var = std::make_shared(); + ctx.getVariableSpace()->putVariable(pair, var); + } + + // now we're saving input array as output array + auto var = ctx.getVariableSpace()->getVariable(pair); + var->markRemovable(false); + var->setNDArray(vz); + } else if (!ctx.isInplace()) { + auto var = ctx.variable(pair); + if (var->getNDArray() != nullptr && var->getNDArray()->nonNull()) { + z = var->getNDArray().get(); + } else { + nd4j_printf("Can't get Z variable for node_%i!\n", ctx.nodeId()); + } + } else { + nd4j_printf("BOOM!\n", ""); + throw std::runtime_error("Boom!"); } -} \ No newline at end of file + } + + return z; +} + +samediff::Engine PlatformHelper::engine() { return _engine; } + +std::string PlatformHelper::name() { return _name; } + +Nd4jLong PlatformHelper::hash() { return _hash; } +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu index eb213f4c2ae6..737eaa01af41 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu @@ -18,121 +18,170 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool2d, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - const auto kH = INT_ARG(0); - const auto kW = INT_ARG(1); - const auto sH = INT_ARG(2); - const auto sW = INT_ARG(3); - auto pH = INT_ARG(4); - auto pW = INT_ARG(5); - const auto dH = INT_ARG(6); - const auto dW = INT_ARG(7); - const auto paddingMode = static_cast(INT_ARG(8)); - const auto extraParam0 = INT_ARG(9); - const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D CUDNN op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int oH = 0; - int oW = 0; - - const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); - const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); - - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); - - if (paddingMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - - pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, mode); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto paddingMode = static_cast(INT_ARG(8)); + const auto extraParam0 = INT_ARG(9); + const int isNCHW = block.getIArguments()->size() > 10 + ? !INT_ARG(10) + : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "AVGPOOL2D CUDNN op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "AVGPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", + dH, dW); + + int oH = 0; + int oW = 0; + + const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); + const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, paddingMode); + + if (paddingMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + const cudnnPoolingMode_t mode = + (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING + : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, + dH, dW, isNCHW, mode); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(avgpool2d, ENGINE_CUDA) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; - return goodType && input->dataType() == output->dataType(); + return goodType && input->dataType() == output->dataType(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool2d_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - const auto kH = INT_ARG(0); // filter(kernel) height - const auto kW = INT_ARG(1); // filter(kernel) width - const auto sH = INT_ARG(2); // strides height - const auto sW = INT_ARG(3); // strides width - auto pH = INT_ARG(4); // paddings height - auto pW = INT_ARG(5); // paddings width - const auto dH = INT_ARG(6); // dilations height - const auto dW = INT_ARG(7); // dilations width - const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - const auto extraParam0 = INT_ARG(9); - const auto isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP CUDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL2D_BP CUDNN op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - - pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, mode); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto kH = INT_ARG(0); // filter(kernel) height + const auto kW = INT_ARG(1); // filter(kernel) width + const auto sH = INT_ARG(2); // strides height + const auto sW = INT_ARG(3); // strides width + auto pH = INT_ARG(4); // paddings height + auto pW = INT_ARG(5); // paddings width + const auto dH = INT_ARG(6); // dilations height + const auto dW = INT_ARG(7); // dilations width + const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + const auto extraParam0 = INT_ARG(9); + const auto isNCHW = block.getIArguments()->size() > 10 + ? !INT_ARG(10) + : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "AVGPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "AVGPOOL2D_BP CUDNN op: dilation must not be zero, but got " + "instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "AVGPOOL2D_BP CUDNN op: wrong shape of output's gradients array " + "(next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "AVGPOOL2D_BP CUDNN op: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + const cudnnPoolingMode_t mode = + (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING + : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, + pH, pW, dH, dW, isNCHW, mode); + + return Status::OK(); } PLATFORM_CHECK(avgpool2d_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; - - return goodType && (input->dataType() == gradO->dataType()) - && (input->dataType() == gradI->dataType()) - && shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) && + (input->dataType() == gradI->dataType()) && + shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); } - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu index da2fdbc098f7..aabf3300905c 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu @@ -18,127 +18,188 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool3dnew, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int extraParam0 = INT_ARG(13); - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - - pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int extraParam0 = INT_ARG(13); + int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) + : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "AVGPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "AVGPOOL3DNEW CUDNN OP: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedOutputShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, + "AVGPOOL3DNEW CUDNN OP: wrong shape of output array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedOutputShape).c_str(), + ShapeUtils::shapeAsString(output).c_str()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + const cudnnPoolingMode_t mode = + (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING + : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, + pD, pH, pW, dD, dH, dW, isNCDHW, mode); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(avgpool3dnew, ENGINE_CUDA) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; - return goodType && input->dataType() == output->dataType(); + return goodType && input->dataType() == output->dataType(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const int kD = INT_ARG(0); // filter(kernel) depth - const int kH = INT_ARG(1); // filter(kernel) height - const int kW = INT_ARG(2); // filter(kernel) width - const int sD = INT_ARG(3); // strides depth - const int sH = INT_ARG(4); // strides height - const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - const int dD = INT_ARG(9); // dilations depth - const int dH = INT_ARG(10); // dilations height - const int dW = INT_ARG(11); // dilations width - const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging - const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - - pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + const int extraParam0 = + INT_ARG(13); // define what divisor to use while averaging + const int isNCDHW = block.getIArguments()->size() > 14 + ? !INT_ARG(14) + : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "AVGPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got " + "%i instead", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "AVGPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iD, iH, iW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "AVGPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array " + "(next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "AVGPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + const cudnnPoolingMode_t mode = + (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING + : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode); + + return Status::OK(); } PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; - - return goodType && (input->dataType() == gradO->dataType()) - && (input->dataType() == gradI->dataType()) - && shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) && + (input->dataType() == gradI->dataType()) && + shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); } - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu b/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu index 7568ba47a330..f22c00d33573 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu @@ -18,551 +18,716 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void batchnormCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* mean, const NDArray* variance, - const NDArray* gamma, const NDArray* beta, - NDArray* output, - const double epsilon, const bool isSpatialMode) { - - - // input, output -> 4D:nchw, 5D:ncdhw - // mean, variance, gamma, beta -> 1xCx1x1 for 4D and 1xCx1x1x1 for 5D for BATCHNORM_MODE_SPATIAL mode - // -> 1xCxHxW for 4D and 1xCxDxHxW for 5D for BATCHNORM_MODE_PER_ACTIVATION mode - - const cudnnDataType_t dataType = cudnnDataType(input->dataType()); - - const int xRank = input->rankOf(); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", err); - - const std::vector xShape = input->getShapeAsVectorInt(); // input and output have same shapes - - std::vector paramsShape, paramsStrides; // mean, variance, gamma and beta have same shapes - if(isSpatialMode) { // 1xCx1x1 - const int iC = mean->lengthOf(); - const int stride0 = mean->strideAt(0); - paramsShape = xRank == 4 ? std::vector({1, iC, 1, 1}) : std::vector({1, iC, 1, 1, 1}); - paramsStrides = xRank == 4 ? std::vector({iC*stride0, stride0, 1, 1}) : std::vector({iC*stride0, stride0, 1, 1, 1}); - } - else { - paramsShape = mean->getShapeAsVectorInt(); - paramsStrides = xRank == 4 ? std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3)}) : std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3), (int)mean->strideAt(4)}); - } - - std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3)}; - std::vector zStrides = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3)}; - - if(xRank > 4) { // 5D - xStrides.push_back((int)input->strideAt(4)); - zStrides.push_back((int)output->strideAt(4)); - } - - cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW; - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(x, format, dataType, xRank, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(x, dataType, xRank, xShape.data(), xStrides.data()); - if (err != 0) throw sd::cuda_exception::build("batchnormCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); - - // output descriptor - cudnnTensorDescriptor_t z; - cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(z, format, dataType, xRank, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(z, dataType, xRank, xShape.data(), zStrides.data()); - if (err != 0) throw sd::cuda_exception::build("batchnormCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output failed", err); - - // mean, variance, gamma and beta descriptor, the same descriptor for all of them - cudnnTensorDescriptor_t params; - cudnnCreateTensorDescriptor(¶ms); - if(mean->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(params, format, dataType, xRank, paramsShape.data()); - else - err = cudnnSetTensorNdDescriptor(params, dataType, xRank, paramsShape.data(), paramsStrides.data()); - if (err != 0) throw sd::cuda_exception::build("batchnormCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for mean/variance/gamma/beta failed", err); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* ptrAlpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* ptrBeta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, beta}); - - // calculations - err = cudnnBatchNormalizationForwardInference(*handle, isSpatialMode ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION, - ptrAlpha, ptrBeta, - x, input->specialBuffer(), - z, output->specialBuffer(), - params, - gamma->specialBuffer(), beta->specialBuffer(), - mean->specialBuffer(), variance->specialBuffer(), epsilon); - - if (err != 0) throw sd::cuda_exception::build("batchnormCUDNN: cudnnBatchNormalizationForwardInference failed", err); - - auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - if (cudaErr != 0) - throw cuda_exception::build("batchnormCUDNN: cudaStreamSynchronize failed !", cudaErr); - - NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); +static void batchnormCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* mean, const NDArray* variance, + const NDArray* gamma, const NDArray* beta, + NDArray* output, const double epsilon, + const bool isSpatialMode) { + // input, output -> 4D:nchw, 5D:ncdhw + // mean, variance, gamma, beta -> 1xCx1x1 for 4D and 1xCx1x1x1 for 5D for + // BATCHNORM_MODE_SPATIAL mode + // -> 1xCxHxW for 4D and 1xCxDxHxW for 5D for + // BATCHNORM_MODE_PER_ACTIVATION mode + + const cudnnDataType_t dataType = cudnnDataType(input->dataType()); + + const int xRank = input->rankOf(); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", + err); + + const std::vector xShape = + input->getShapeAsVectorInt(); // input and output have same shapes + + std::vector paramsShape, + paramsStrides; // mean, variance, gamma and beta have same shapes + if (isSpatialMode) { // 1xCx1x1 + const int iC = mean->lengthOf(); + const int stride0 = mean->strideAt(0); + paramsShape = xRank == 4 ? std::vector({1, iC, 1, 1}) + : std::vector({1, iC, 1, 1, 1}); + paramsStrides = xRank == 4 + ? std::vector({iC * stride0, stride0, 1, 1}) + : std::vector({iC * stride0, stride0, 1, 1, 1}); + } else { + paramsShape = mean->getShapeAsVectorInt(); + paramsStrides = + xRank == 4 + ? std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), + (int)mean->strideAt(2), (int)mean->strideAt(3)}) + : std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), + (int)mean->strideAt(2), (int)mean->strideAt(3), + (int)mean->strideAt(4)}); + } + + std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), + (int)input->strideAt(2), + (int)input->strideAt(3)}; + std::vector zStrides = { + (int)output->strideAt(0), (int)output->strideAt(1), + (int)output->strideAt(2), (int)output->strideAt(3)}; + + if (xRank > 4) { // 5D + xStrides.push_back((int)input->strideAt(4)); + zStrides.push_back((int)output->strideAt(4)); + } + + cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1) + err = + cudnnSetTensorNdDescriptorEx(x, format, dataType, xRank, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(x, dataType, xRank, xShape.data(), + xStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input " + "failed", + err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if (output->ews() == 1) + err = + cudnnSetTensorNdDescriptorEx(z, format, dataType, xRank, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(z, dataType, xRank, xShape.data(), + zStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output " + "failed", + err); + + // mean, variance, gamma and beta descriptor, the same descriptor for all of + // them + cudnnTensorDescriptor_t params; + cudnnCreateTensorDescriptor(¶ms); + if (mean->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(params, format, dataType, xRank, + paramsShape.data()); + else + err = cudnnSetTensorNdDescriptor(params, dataType, xRank, + paramsShape.data(), paramsStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for " + "mean/variance/gamma/beta failed", + err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* ptrAlpha = output->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* ptrBeta = output->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input, mean, variance, gamma, beta}); + + // calculations + err = cudnnBatchNormalizationForwardInference( + *handle, + isSpatialMode ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION, + ptrAlpha, ptrBeta, x, input->specialBuffer(), z, output->specialBuffer(), + params, gamma->specialBuffer(), beta->specialBuffer(), + mean->specialBuffer(), variance->specialBuffer(), epsilon); + + if (err != 0) + throw sd::cuda_exception::build( + "batchnormCUDNN: cudnnBatchNormalizationForwardInference failed", err); + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build( + "batchnormCUDNN: cudaStreamSynchronize failed !", cudaErr); + + NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); } ////////////////////////////////////////////////////////////////////////// -static void batchnormBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* gradO, - NDArray* gradI, NDArray* gradG, NDArray* gradB, - const double epsilon, const bool isSpatialMode) { - - // input, gradO, gradI -> 4D:nchw, 5D:ncdhw - // mean, variance, gamma, beta, gradM, gradV, gradG, gradB -> 1xCx1x1 for 4D and 1xCx1x1x1 for 5D for BATCHNORM_MODE_SPATIAL mode - // -> 1xCxHxW for 4D and 1xCxDxHxW for 5D for BATCHNORM_MODE_PER_ACTIVATION mode - - const cudnnDataType_t dataType = cudnnDataType(input->dataType()); - - const int xRank = input->rankOf(); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("batchnormBpCUDNN: can't set stream for cuDNN", err); - - const std::vector xShape = input->getShapeAsVectorInt(); // input and output have same shapes - - std::vector paramsShape, paramsStrides; // mean, variance, gamma and beta have same shapes - if(isSpatialMode) { // 1xCx1x1 - const int iC = mean->lengthOf(); - const int stride0 = mean->strideAt(0); - paramsShape = xRank == 4 ? std::vector({1, iC, 1, 1}) : std::vector({1, iC, 1, 1, 1}); - paramsStrides = xRank == 4 ? std::vector({iC*stride0, stride0, 1, 1}) : std::vector({iC*stride0, stride0, 1, 1, 1}); - } - else { - paramsShape = mean->getShapeAsVectorInt(); - paramsStrides = xRank == 4 ? std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3)}) : std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3), (int)mean->strideAt(4)}); - } - - std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3)}; - std::vector dxStrides = {(int)gradI->strideAt(0), (int)gradI->strideAt(1), (int)gradI->strideAt(2), (int)gradI->strideAt(3)}; - std::vector dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3)}; - - if(xRank > 4) { // 5D - xStrides.push_back((int)input->strideAt(4)); - dxStrides.push_back((int)gradI->strideAt(4)); - dzStrides.push_back((int)gradO->strideAt(4)); - } - - cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW; - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(x, format, dataType, xRank, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(x, dataType, xRank, xShape.data(), xStrides.data()); - if (err != 0) throw sd::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); - - // gradO descriptor - cudnnTensorDescriptor_t dz; - cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(dz, format, dataType, xRank, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(dz, dataType, xRank, xShape.data(), dzStrides.data()); - if (err != 0) throw sd::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err); - - // gradI descriptor - cudnnTensorDescriptor_t dx; - cudnnCreateTensorDescriptor(&dx); - if(input->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(dx, format, dataType, xRank, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(dx, dataType, xRank, xShape.data(), dxStrides.data()); - if (err != 0) throw sd::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradI failed", err); - - // mean, variance, gamma, gradG and gradB descriptor, the same descriptor for all of them - cudnnTensorDescriptor_t params; - cudnnCreateTensorDescriptor(¶ms); - if(mean->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(params, format, dataType, xRank, paramsShape.data()); - else - err = cudnnSetTensorNdDescriptor(params, dataType, xRank, paramsShape.data(), paramsStrides.data()); - if (err != 0) throw sd::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for mean/variance/gamma/gradG/gradB failed", err); - - // provide scaling parameters - const float alpha32(1), beta32(0); - double alpha64(1), beta64(0); - const void* ptrAlpha = input->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* ptrBeta = input->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({gradI, gradG, gradB}, {input, mean, variance, gamma, gradO}); - - // calculations - // TODO: we can use cache here - err = cudnnBatchNormalizationBackward(*handle, isSpatialMode ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION, - ptrAlpha, ptrBeta, ptrAlpha, ptrBeta, - x, input->specialBuffer(), - dz, gradO->specialBuffer(), - dx, gradI->specialBuffer(), - params, - gamma->specialBuffer(), gradG->specialBuffer(), gradB->specialBuffer(), - epsilon, - nullptr/*mean->specialBuffer()*/, nullptr/*variance->specialBuffer()*/); - - if (err != 0) throw sd::cuda_exception::build("batchnormBpCUDNN: cudnnBatchNormalizationBackward failed", err); - - auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - if (cudaErr != 0) - throw cuda_exception::build("batchnormBpCUDNN: cudaStreamSynchronize failed !", cudaErr); - - NDArray::registerSpecialUse({gradI, gradG, gradB}, {input, mean, variance, gamma, gradO}); +static void batchnormBpCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* mean, const NDArray* variance, + const NDArray* gamma, const NDArray* gradO, + NDArray* gradI, NDArray* gradG, NDArray* gradB, + const double epsilon, const bool isSpatialMode) { + // input, gradO, gradI -> 4D:nchw, 5D:ncdhw + // mean, variance, gamma, beta, gradM, gradV, gradG, gradB -> 1xCx1x1 for 4D + // and 1xCx1x1x1 for 5D for BATCHNORM_MODE_SPATIAL mode + // -> 1xCxHxW for 4D + // and 1xCxDxHxW for + // 5D for + // BATCHNORM_MODE_PER_ACTIVATION + // mode + + const cudnnDataType_t dataType = cudnnDataType(input->dataType()); + + const int xRank = input->rankOf(); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormBpCUDNN: can't set stream for cuDNN", err); + + const std::vector xShape = + input->getShapeAsVectorInt(); // input and output have same shapes + + std::vector paramsShape, + paramsStrides; // mean, variance, gamma and beta have same shapes + if (isSpatialMode) { // 1xCx1x1 + const int iC = mean->lengthOf(); + const int stride0 = mean->strideAt(0); + paramsShape = xRank == 4 ? std::vector({1, iC, 1, 1}) + : std::vector({1, iC, 1, 1, 1}); + paramsStrides = xRank == 4 + ? std::vector({iC * stride0, stride0, 1, 1}) + : std::vector({iC * stride0, stride0, 1, 1, 1}); + } else { + paramsShape = mean->getShapeAsVectorInt(); + paramsStrides = + xRank == 4 + ? std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), + (int)mean->strideAt(2), (int)mean->strideAt(3)}) + : std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), + (int)mean->strideAt(2), (int)mean->strideAt(3), + (int)mean->strideAt(4)}); + } + + std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), + (int)input->strideAt(2), + (int)input->strideAt(3)}; + std::vector dxStrides = { + (int)gradI->strideAt(0), (int)gradI->strideAt(1), (int)gradI->strideAt(2), + (int)gradI->strideAt(3)}; + std::vector dzStrides = { + (int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), + (int)gradO->strideAt(3)}; + + if (xRank > 4) { // 5D + xStrides.push_back((int)input->strideAt(4)); + dxStrides.push_back((int)gradI->strideAt(4)); + dzStrides.push_back((int)gradO->strideAt(4)); + } + + cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1) + err = + cudnnSetTensorNdDescriptorEx(x, format, dataType, xRank, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(x, dataType, xRank, xShape.data(), + xStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input " + "failed", + err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if (gradO->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(dz, format, dataType, xRank, + xShape.data()); + else + err = cudnnSetTensorNdDescriptor(dz, dataType, xRank, xShape.data(), + dzStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO " + "failed", + err); + + // gradI descriptor + cudnnTensorDescriptor_t dx; + cudnnCreateTensorDescriptor(&dx); + if (input->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(dx, format, dataType, xRank, + xShape.data()); + else + err = cudnnSetTensorNdDescriptor(dx, dataType, xRank, xShape.data(), + dxStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradI " + "failed", + err); + + // mean, variance, gamma, gradG and gradB descriptor, the same descriptor for + // all of them + cudnnTensorDescriptor_t params; + cudnnCreateTensorDescriptor(¶ms); + if (mean->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(params, format, dataType, xRank, + paramsShape.data()); + else + err = cudnnSetTensorNdDescriptor(params, dataType, xRank, + paramsShape.data(), paramsStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "batchnormBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for " + "mean/variance/gamma/gradG/gradB failed", + err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + double alpha64(1), beta64(0); + const void* ptrAlpha = input->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* ptrBeta = input->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI, gradG, gradB}, + {input, mean, variance, gamma, gradO}); + + // calculations + // TODO: we can use cache here + err = cudnnBatchNormalizationBackward( + *handle, + isSpatialMode ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION, + ptrAlpha, ptrBeta, ptrAlpha, ptrBeta, x, input->specialBuffer(), dz, + gradO->specialBuffer(), dx, gradI->specialBuffer(), params, + gamma->specialBuffer(), gradG->specialBuffer(), gradB->specialBuffer(), + epsilon, nullptr /*mean->specialBuffer()*/, + nullptr /*variance->specialBuffer()*/); + + if (err != 0) + throw sd::cuda_exception::build( + "batchnormBpCUDNN: cudnnBatchNormalizationBackward failed", err); + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build( + "batchnormBpCUDNN: cudaStreamSynchronize failed !", cudaErr); + + NDArray::registerSpecialUse({gradI, gradG, gradB}, + {input, mean, variance, gamma, gradO}); } - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(batchnorm, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); - auto mean = INPUT_VARIABLE(1); - auto variance = INPUT_VARIABLE(2); - NDArray* gamma = nullptr; - NDArray* beta = nullptr; - - auto output = OUTPUT_VARIABLE(0); - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - const double epsilon = T_ARG(0); - - if(applyScale) - gamma = INPUT_VARIABLE(3); - if(applyOffset) - beta = INPUT_VARIABLE(3 + (int)applyScale); - - const int numOfIntArgs = block.getIArguments()->size(); - const int inRank = input->rankOf(); - - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(inRank-1); // default dimension to reduce along is last dimension - - const int numOfAxes = axes.size(); - REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM CUDNN op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); - - // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes - // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5} - std::vector expShape; - if(numOfAxes == 1) - expShape.push_back(input->sizeAt(axes[0])); - else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3} - expShape = std::vector(inRank, 1); - for(uint i = 0; i < numOfAxes; ++i) - expShape[axes[i]] = input->sizeAt(axes[i]); - } - - REQUIRE_TRUE(mean->isSameShape(expShape) , 0, "BATCHNORM CUDNN op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM CUDNN op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); - if(gamma) - REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM CUDNN op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); - if(beta) - REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM CUDNN op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); - - // types of all input arrays should be the same - for(int i = 1; i < block.width(); ++i) - REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM CUDNN op: types of all input arrays should be the same !"); - - // cudnn supports NCHW format only - const bool needPermut = axes.size() == 1 && mean->lengthOf() == input->sizeAt(-1); - - if(needPermut) { // if NHWC - std::vector perm = inRank == 4 ? std::vector({0, 3, 1, 2}) : std::vector({0, 4, 1, 2, 3}); // NHWC -> NCHW - input = new NDArray(input->permute(perm)); - output = new NDArray(output->permute(perm)); - } - - // cudnn requires gamma and beta to be non-nullptr - if(!applyScale) { - gamma = new NDArray(mean); - *gamma = 1; - } - if(!applyOffset) { - beta = new NDArray(mean); - *beta = 0; - } - - // calculations - batchnormCUDNN(block.launchContext(), input, mean, variance, gamma, beta, output, epsilon, axes.size() == 1); - - if(needPermut) { - delete input; - delete output; - } - - if(!applyScale) - delete gamma; - - if(!applyOffset) - delete beta; - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto mean = INPUT_VARIABLE(1); + auto variance = INPUT_VARIABLE(2); + NDArray* gamma = nullptr; + NDArray* beta = nullptr; + + auto output = OUTPUT_VARIABLE(0); + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const double epsilon = T_ARG(0); + + if (applyScale) gamma = INPUT_VARIABLE(3); + if (applyOffset) beta = INPUT_VARIABLE(3 + (int)applyScale); + + const int numOfIntArgs = block.getIArguments()->size(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank - + 1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes <= inRank, 0, + "BATCHNORM CUDNN op: too big number of input axes to normalize " + "over, expected number should be less or equal to rank of input " + "array, but got %i and %i correspondingly !", + numOfAxes, inRank); + + // evaluate expected shape for mean, variance and gamma. These 3 arrays should + // have identical shapes for example if input shape is {2,3,4,5,6} and axes = + // {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then + // expected shape would be {5} + std::vector expShape; + if (numOfAxes == 1) + expShape.push_back(input->sizeAt(axes[0])); + else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if + // axes = {1, 3} + expShape = std::vector(inRank, 1); + for (uint i = 0; i < numOfAxes; ++i) + expShape[axes[i]] = input->sizeAt(axes[i]); + } + + REQUIRE_TRUE(mean->isSameShape(expShape), 0, + "BATCHNORM CUDNN op: wrong shape of mean array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->isSameShape(expShape), 0, + "BATCHNORM CUDNN op: wrong shape of variance array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(variance).c_str()); + if (gamma) + REQUIRE_TRUE(gamma->isSameShape(expShape), 0, + "BATCHNORM CUDNN op: wrong shape of gamma array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(gamma).c_str()); + if (beta) + REQUIRE_TRUE(beta->isSameShape(expShape), 0, + "BATCHNORM CUDNN op: wrong shape of beta array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(beta).c_str()); + + // types of all input arrays should be the same + for (int i = 1; i < block.width(); ++i) + REQUIRE_TRUE( + INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, + "BATCHNORM CUDNN op: types of all input arrays should be the same !"); + + // cudnn supports NCHW format only + const bool needPermut = + axes.size() == 1 && mean->lengthOf() == input->sizeAt(-1); + + if (needPermut) { // if NHWC + std::vector perm = + inRank == 4 ? std::vector({0, 3, 1, 2}) + : std::vector({0, 4, 1, 2, 3}); // NHWC -> NCHW + input = new NDArray(input->permute(perm)); + output = new NDArray(output->permute(perm)); + } + + // cudnn requires gamma and beta to be non-nullptr + if (!applyScale) { + gamma = new NDArray(mean); + *gamma = 1; + } + if (!applyOffset) { + beta = new NDArray(mean); + *beta = 0; + } + + // calculations + batchnormCUDNN(block.launchContext(), input, mean, variance, gamma, beta, + output, epsilon, axes.size() == 1); + + if (needPermut) { + delete input; + delete output; + } + + if (!applyScale) delete gamma; + + if (!applyOffset) delete beta; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(batchnorm, ENGINE_CUDA) { - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - - NDArray* input = INPUT_VARIABLE(0); - NDArray* mean = INPUT_VARIABLE(1); - NDArray* variance = INPUT_VARIABLE(2); - NDArray* gamma = applyScale ? INPUT_VARIABLE(3) : nullptr; - NDArray* beta = applyOffset ? INPUT_VARIABLE(3 + (int)applyScale) : nullptr; - - const int numOfIntArgs = block.getIArguments()->size(); - const int xRank = input->rankOf(); - - // *********************************** // - if(xRank != 4 && xRank != 5) - return false; - - // *********************************** // - const bool badType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - if(badType) - return false; - - // *********************************** // - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(xRank-1); // default dimension to reduce along is last dimension - - if(axes.size() != 1 && axes.size() != 3 && axes.size() != 4) - return false; - - // *********************************** // - bool allParamsHaveSameShapeAndStrides = shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); - if(gamma) - allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); - if(beta) - allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->shapeInfo(), beta->shapeInfo()); - - if(!allParamsHaveSameShapeAndStrides) - return false; - - // *********************************** // - bool isFormatGood = false; - if(axes.size() == 1) - isFormatGood = mean->lengthOf() == input->sizeAt(1) || mean->lengthOf() == input->sizeAt(-1); // mean [C] - else { - auto inputShapeModif = input->getShapeAsVector(); // [dim0,dim1,dim2,dim3] 4D or [dim0,dim1,dim2,dim3,dim4] - inputShapeModif[0] = 1; - isFormatGood = mean->isSameShape(inputShapeModif); // mean [1,dim1,dim2,dim3] 4D or [1,dim1,dim2,dim3,dim4] - } - if(!isFormatGood) - return false; - - return true; + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + + NDArray* input = INPUT_VARIABLE(0); + NDArray* mean = INPUT_VARIABLE(1); + NDArray* variance = INPUT_VARIABLE(2); + NDArray* gamma = applyScale ? INPUT_VARIABLE(3) : nullptr; + NDArray* beta = applyOffset ? INPUT_VARIABLE(3 + (int)applyScale) : nullptr; + + const int numOfIntArgs = block.getIArguments()->size(); + const int xRank = input->rankOf(); + + // *********************************** // + if (xRank != 4 && xRank != 5) return false; + + // *********************************** // + const bool badType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + if (badType) return false; + + // *********************************** // + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(xRank - + 1); // default dimension to reduce along is last dimension + + if (axes.size() != 1 && axes.size() != 3 && axes.size() != 4) return false; + + // *********************************** // + bool allParamsHaveSameShapeAndStrides = + shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); + if (gamma) + allParamsHaveSameShapeAndStrides &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); + if (beta) + allParamsHaveSameShapeAndStrides &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), beta->shapeInfo()); + + if (!allParamsHaveSameShapeAndStrides) return false; + + // *********************************** // + bool isFormatGood = false; + if (axes.size() == 1) + isFormatGood = mean->lengthOf() == input->sizeAt(1) || + mean->lengthOf() == input->sizeAt(-1); // mean [C] + else { + auto inputShapeModif = + input->getShapeAsVector(); // [dim0,dim1,dim2,dim3] 4D or + // [dim0,dim1,dim2,dim3,dim4] + inputShapeModif[0] = 1; + isFormatGood = + mean->isSameShape(inputShapeModif); // mean [1,dim1,dim2,dim3] 4D or + // [1,dim1,dim2,dim3,dim4] + } + if (!isFormatGood) return false; + + return true; } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(batchnorm_bp, ENGINE_CUDA) { - - NDArray* input = INPUT_VARIABLE(0); - NDArray* mean = INPUT_VARIABLE(1); - NDArray* variance = INPUT_VARIABLE(2); - NDArray* gamma = nullptr; - NDArray* beta = nullptr; - NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon - - NDArray* gradI = OUTPUT_VARIABLE(0); - NDArray* gradM = OUTPUT_VARIABLE(1); - NDArray* gradV = OUTPUT_VARIABLE(2); - NDArray* gradG = nullptr; - NDArray* gradB = nullptr; - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - const float epsilon = T_ARG(0); - - if(applyScale) { - gamma = INPUT_VARIABLE(3); - gradG = OUTPUT_VARIABLE(3); - } - if(applyOffset) { - beta = INPUT_VARIABLE(3 + (int)applyScale); - gradB = OUTPUT_VARIABLE(3 + (int)applyScale); - } - - const int numOfIntArgs = block.getIArguments()->size(); - const int inRank = input->rankOf(); - - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(inRank-1); // default dimension to reduce along is last dimension - - const int numOfAxes = axes.size(); - REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_BP CUDNN op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); - - // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes - // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5} - std::vector expShape; - if(numOfAxes == 1) - expShape.push_back(input->sizeAt(axes[0])); - else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3} - expShape = std::vector(inRank, 1); - for(uint i = 0; i < numOfAxes; ++i) - expShape[axes[i]] = input->sizeAt(axes[i]); - } - - REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); - if(gamma) - REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); - if(beta) - REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); - - REQUIRE_TRUE(input->isSameShape(gradO), 0, "BATCHNORM_BP CUDNN op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - - // types of all input arrays should be the same (except gradO) - for(int i = 1; i < block.width() - 2; ++i) - REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP CUDNN op: types of arrays (input, mean, variance, gamma, beta) should be the same !"); - - // cudnn supports NCHW format only - const bool needPermut = axes.size() == 1 && mean->lengthOf() != input->sizeAt(1); - - if(needPermut) { // if NHWC - std::vector perm = inRank == 4 ? std::vector({0, 3, 1, 2}) : std::vector({0, 4, 1, 2, 3}); // NHWC -> NCHW - input = new NDArray(input->permute(perm)); - gradO = new NDArray(gradO->permute(perm)); - gradI = new NDArray(gradI->permute(perm)); - } - - // cudnn requires gamma, gradG, gradB to be non-nullptr - if(!applyScale) { - gamma = new NDArray(mean); - gradG = new NDArray(mean); - *gamma = 1; - } - if(!applyOffset) - gradB = new NDArray(mean); - - // calculations - batchnormBpCUDNN(block.launchContext(), input, mean, variance, gamma, gradO, gradI, gradG, gradB, epsilon, axes.size() == 1); - - *gradM = 0; // put zeros so far - *gradV = 0; // put zeros so far - - if(needPermut) { - delete input; - delete gradO; - delete gradI; - } - - if(!applyScale) { - delete gamma; - delete gradG; - } - - if(!applyOffset) - delete gradB; - - return Status::OK(); - + NDArray* input = INPUT_VARIABLE(0); + NDArray* mean = INPUT_VARIABLE(1); + NDArray* variance = INPUT_VARIABLE(2); + NDArray* gamma = nullptr; + NDArray* beta = nullptr; + NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon + + NDArray* gradI = OUTPUT_VARIABLE(0); + NDArray* gradM = OUTPUT_VARIABLE(1); + NDArray* gradV = OUTPUT_VARIABLE(2); + NDArray* gradG = nullptr; + NDArray* gradB = nullptr; + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const float epsilon = T_ARG(0); + + if (applyScale) { + gamma = INPUT_VARIABLE(3); + gradG = OUTPUT_VARIABLE(3); + } + if (applyOffset) { + beta = INPUT_VARIABLE(3 + (int)applyScale); + gradB = OUTPUT_VARIABLE(3 + (int)applyScale); + } + + const int numOfIntArgs = block.getIArguments()->size(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank - + 1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes <= inRank, 0, + "BATCHNORM_BP CUDNN op: too big number of input axes to " + "normalize over, expected number should be less or equal to " + "rank of input array, but got %i and %i correspondingly !", + numOfAxes, inRank); + + // evaluate expected shape for mean, variance and gamma. These 3 arrays should + // have identical shapes for example if input shape is {2,3,4,5,6} and axes = + // {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then + // expected shape would be {5} + std::vector expShape; + if (numOfAxes == 1) + expShape.push_back(input->sizeAt(axes[0])); + else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if + // axes = {1, 3} + expShape = std::vector(inRank, 1); + for (uint i = 0; i < numOfAxes; ++i) + expShape[axes[i]] = input->sizeAt(axes[i]); + } + + REQUIRE_TRUE(mean->isSameShape(expShape), 0, + "BATCHNORM_BP CUDNN op: wrong shape of mean array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->isSameShape(expShape), 0, + "BATCHNORM_BP CUDNN op: wrong shape of variance array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(variance).c_str()); + if (gamma) + REQUIRE_TRUE(gamma->isSameShape(expShape), 0, + "BATCHNORM_BP CUDNN op: wrong shape of gamma array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(gamma).c_str()); + if (beta) + REQUIRE_TRUE(beta->isSameShape(expShape), 0, + "BATCHNORM_BP CUDNN op: wrong shape of beta array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expShape).c_str(), + ShapeUtils::shapeAsString(beta).c_str()); + + REQUIRE_TRUE(input->isSameShape(gradO), 0, + "BATCHNORM_BP CUDNN op: wrong shape of output gradients array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(input).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + + // types of all input arrays should be the same (except gradO) + for (int i = 1; i < block.width() - 2; ++i) + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), + 0, + "BATCHNORM_BP CUDNN op: types of arrays (input, mean, " + "variance, gamma, beta) should be the same !"); + + // cudnn supports NCHW format only + const bool needPermut = + axes.size() == 1 && mean->lengthOf() != input->sizeAt(1); + + if (needPermut) { // if NHWC + std::vector perm = + inRank == 4 ? std::vector({0, 3, 1, 2}) + : std::vector({0, 4, 1, 2, 3}); // NHWC -> NCHW + input = new NDArray(input->permute(perm)); + gradO = new NDArray(gradO->permute(perm)); + gradI = new NDArray(gradI->permute(perm)); + } + + // cudnn requires gamma, gradG, gradB to be non-nullptr + if (!applyScale) { + gamma = new NDArray(mean); + gradG = new NDArray(mean); + *gamma = 1; + } + if (!applyOffset) gradB = new NDArray(mean); + + // calculations + batchnormBpCUDNN(block.launchContext(), input, mean, variance, gamma, gradO, + gradI, gradG, gradB, epsilon, axes.size() == 1); + + *gradM = 0; // put zeros so far + *gradV = 0; // put zeros so far + + if (needPermut) { + delete input; + delete gradO; + delete gradI; + } + + if (!applyScale) { + delete gamma; + delete gradG; + } + + if (!applyOffset) delete gradB; + + return Status::OK(); } PLATFORM_CHECK(batchnorm_bp, ENGINE_CUDA) { - - NDArray* input = INPUT_VARIABLE(0); - NDArray* mean = INPUT_VARIABLE(1); - NDArray* variance = INPUT_VARIABLE(2); - NDArray* gamma = nullptr; - NDArray* beta = nullptr; - NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon - - NDArray* gradI = OUTPUT_VARIABLE(0); - NDArray* gradM = OUTPUT_VARIABLE(1); - NDArray* gradV = OUTPUT_VARIABLE(2); - NDArray* gradG = nullptr; - NDArray* gradB = nullptr; - - const int numOfIntArgs = block.getIArguments()->size(); - const int xRank = input->rankOf(); - - // *********************************** // - if(xRank != 4 && xRank != 5) - return false; - - // *********************************** // - const bool badType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - if(badType) - return false; - - // *********************************** // - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(xRank-1); // default dimension to reduce along is last dimension - - if(axes.size() != 1 && axes.size() != 3 && axes.size() != 4) - return false; - - // *********************************** // - bool allParamsHaveSameShapeAndStrides = shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); - if(gamma) - allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); - if(gradG) - allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->shapeInfo(), gradG->shapeInfo()); - if(gradB) - allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->shapeInfo(), gradB->shapeInfo()); - - if(!allParamsHaveSameShapeAndStrides) - return false; - - // *********************************** // - bool isFormatGood = false; - if(axes.size() == 1) - isFormatGood = mean->lengthOf() == input->sizeAt(1) || mean->lengthOf() == input->sizeAt(-1); // mean [C] - else { - auto inputShapeModif = input->getShapeAsVector(); // [dim0,dim1,dim2,dim3] 4D or [dim0,dim1,dim2,dim3,dim4] - inputShapeModif[0] = 1; - isFormatGood = mean->isSameShape(inputShapeModif); // mean [1,dim1,dim2,dim3] 4D or [1,dim1,dim2,dim3,dim4] - } - if(!isFormatGood) - return false; - - return true; + NDArray* input = INPUT_VARIABLE(0); + NDArray* mean = INPUT_VARIABLE(1); + NDArray* variance = INPUT_VARIABLE(2); + NDArray* gamma = nullptr; + NDArray* beta = nullptr; + NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon + + NDArray* gradI = OUTPUT_VARIABLE(0); + NDArray* gradM = OUTPUT_VARIABLE(1); + NDArray* gradV = OUTPUT_VARIABLE(2); + NDArray* gradG = nullptr; + NDArray* gradB = nullptr; + + const int numOfIntArgs = block.getIArguments()->size(); + const int xRank = input->rankOf(); + + // *********************************** // + if (xRank != 4 && xRank != 5) return false; + + // *********************************** // + const bool badType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + if (badType) return false; + + // *********************************** // + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(xRank - + 1); // default dimension to reduce along is last dimension + + if (axes.size() != 1 && axes.size() != 3 && axes.size() != 4) return false; + + // *********************************** // + bool allParamsHaveSameShapeAndStrides = + shape::haveSameShapeAndStrides(mean->shapeInfo(), variance->shapeInfo()); + if (gamma) + allParamsHaveSameShapeAndStrides &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), gamma->shapeInfo()); + if (gradG) + allParamsHaveSameShapeAndStrides &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), gradG->shapeInfo()); + if (gradB) + allParamsHaveSameShapeAndStrides &= + shape::haveSameShapeAndStrides(mean->shapeInfo(), gradB->shapeInfo()); + + if (!allParamsHaveSameShapeAndStrides) return false; + + // *********************************** // + bool isFormatGood = false; + if (axes.size() == 1) + isFormatGood = mean->lengthOf() == input->sizeAt(1) || + mean->lengthOf() == input->sizeAt(-1); // mean [C] + else { + auto inputShapeModif = + input->getShapeAsVector(); // [dim0,dim1,dim2,dim3] 4D or + // [dim0,dim1,dim2,dim3,dim4] + inputShapeModif[0] = 1; + isFormatGood = + mean->isSameShape(inputShapeModif); // mean [1,dim1,dim2,dim3] 4D or + // [1,dim1,dim2,dim3,dim4] + } + if (!isFormatGood) return false; + + return true; } - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu index a77faf6f73a4..f93cd5d42b28 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu @@ -19,520 +19,823 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void conv2dCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const int paddingMode, const bool isNCHW, const int wFormat) { - - // cudnn support only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC} - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", err); - - cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); - - // weights descriptor - cudnnFilterDescriptor_t w; - cudnnCreateFilterDescriptor(&w); - err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), formatW, oC, iC, kH, kW); - if(err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetFilter4dDescriptor failed", err); - - // output descriptor - cudnnTensorDescriptor_t z; - cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1 && output->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); - else - err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output failed", err); - - // description of convolution - cudnnConvolutionDescriptor_t conv; - cudnnCreateConvolutionDescriptor(&conv); - err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, CUDNN_CROSS_CORRELATION, cudnnDataType(output->dataType())); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetConvolution2dDescriptor failed", err); - - // algorithm description - cudnnConvolutionFwdAlgo_t algo; - err = cudnnGetConvolutionForwardAlgorithm(*handle, x, w, conv, z, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnGetConvolutionForwardAlgorithm failed", err); - - - // allocate auxiliary device memory, abbreviation ws means workspace - size_t wsSize; - err = cudnnGetConvolutionForwardWorkspaceSize(*handle, x, w, conv, z, algo, &wsSize); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnGetConvolutionForwardWorkspaceSize failed", err); - void* wsData; - auto cudaErr = cudaMalloc(&wsData, wsSize); - if (cudaErr != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudaMalloc for auxiliary workspace memory failed", cudaErr); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({output}, {input, weights, bias}); - - // run calculation - err = cudnnConvolutionForward(*handle, alpha, x, input->specialBuffer(), w, weights->specialBuffer(), conv, algo, wsData, wsSize, beta, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnConvolutionForward failed", err); - - // add bias if it is present - if (bias != nullptr) { - cudnnTensorDescriptor_t b; - cudnnCreateTensorDescriptor(&b); - // err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf()); - err = cudnnSetTensor4dDescriptor(b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); - err = cudnnAddTensor(*handle, alpha, b, bias->specialBuffer(), alpha, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudnnAddTensor bias failed", err); - } - - // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - // if (cudaErr != 0) - // throw cuda_exception::build("conv2dCUDNN: cudaStreamSynchronize failed !", cudaErr); - - cudaErr = cudaFree(wsData); - if (cudaErr != 0) throw sd::cuda_exception::build("conv2dCUDNN: cudaFree for auxiliary workspace memory failed", cudaErr); - - NDArray::registerSpecialUse({output}, {input, weights, bias}); +static void conv2dCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kH, const int kW, + const int sH, const int sW, const int pH, const int pW, + const int dH, const int dW, const int paddingMode, + const bool isNCHW, const int wFormat) { + // cudnn support only two formats for weights {oC,iC,kH,kW} and {oC,kH,kW,iC} + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build("conv2dCUDNN: can't set stream for cuDNN", + err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + cudnnTensorFormat_t formatW = + 0 == wFormat ? format + : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), + input->strideAt(indIOioC), input->strideAt(indIiH), + input->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx " + "for input failed", + err); + + // weights descriptor + cudnnFilterDescriptor_t w; + cudnnCreateFilterDescriptor(&w); + err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), + formatW, oC, iC, kH, kW); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnSetFilter4dDescriptor failed", err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if (output->ews() == 1 && output->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx( + z, cudnnDataType(output->dataType()), bS, oC, oH, oW, + output->strideAt(0), output->strideAt(indIOioC), + output->strideAt(indOoH), output->strideAt(indOoH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx " + "for output failed", + err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, + CUDNN_CROSS_CORRELATION, + cudnnDataType(output->dataType())); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnSetConvolution2dDescriptor failed", err); + + // algorithm description + cudnnConvolutionFwdAlgo_t algo; + err = cudnnGetConvolutionForwardAlgorithm( + *handle, x, w, conv, z, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnGetConvolutionForwardAlgorithm failed", err); + + // allocate auxiliary device memory, abbreviation ws means workspace + size_t wsSize; + err = cudnnGetConvolutionForwardWorkspaceSize(*handle, x, w, conv, z, algo, + &wsSize); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnGetConvolutionForwardWorkspaceSize failed", err); + void* wsData; + auto cudaErr = cudaMalloc(&wsData, wsSize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudaMalloc for auxiliary workspace memory failed", + cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input, weights, bias}); + + // run calculation + err = cudnnConvolutionForward(*handle, alpha, x, input->specialBuffer(), w, + weights->specialBuffer(), conv, algo, wsData, + wsSize, beta, z, output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnConvolutionForward failed", err); + + // add bias if it is present + if (bias != nullptr) { + cudnnTensorDescriptor_t b; + cudnnCreateTensorDescriptor(&b); + // err = cudnnSetTensor4dDescriptor(b, format, + // cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, + // isNCHW ? 1: bias->lengthOf()); + err = cudnnSetTensor4dDescriptor( + b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); + err = cudnnAddTensor(*handle, alpha, b, bias->specialBuffer(), alpha, z, + output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build("conv2dCUDNN: cudnnAddTensor bias failed", + err); + } + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("conv2dCUDNN: cudaStreamSynchronize failed + // !", cudaErr); + + cudaErr = cudaFree(wsData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv2dCUDNN: cudaFree for auxiliary workspace memory failed", cudaErr); + + NDArray::registerSpecialUse({output}, {input, weights, bias}); } ////////////////////////////////////////////////////////////////////////// -static void conv2dBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* weights, const NDArray* gradO, +static void conv2dBpCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const int paddingMode, const bool isNCHW, const int wFormat) { - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: can't set stream for cuDNN", err); - - cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); - - // gradO descriptor - cudnnTensorDescriptor_t dz; - cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1 && gradO->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); - else - err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO failed", err); - - // gradI descriptor - cudnnTensorDescriptor_t dx; - cudnnCreateTensorDescriptor(&dx); - if(gradI->ews() == 1 && gradI->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradI failed", err); - - // gradW descriptor - cudnnFilterDescriptor_t dw; - cudnnCreateFilterDescriptor(&dw); - err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), formatW, oC, iC, kH, kW); - if(err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW failed", err); - - // description of convolution - cudnnConvolutionDescriptor_t conv; - cudnnCreateConvolutionDescriptor(&conv); - err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, CUDNN_CROSS_CORRELATION, cudnnDataType(gradO->dataType())); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetConvolution2dDescriptor failed", err); - - // gradW algorithm description - cudnnConvolutionBwdFilterAlgo_t algoGradW; - err = cudnnGetConvolutionBackwardFilterAlgorithm(*handle, x, dz, conv, dw, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, &algoGradW); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnGetConvolutionBackwardFilterAlgorithm failed", err); - - // gradI algorithm description - cudnnConvolutionBwdDataAlgo_t algoGradI; - err = cudnnGetConvolutionBackwardDataAlgorithm(*handle, dw, dz, conv, x, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, &algoGradI); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnGetConvolutionBackwardDataAlgorithm failed", err); - - // allocate auxiliary device memory for gradW calculation, abbreviation ws means workspace - size_t wsGradWSize; - err = cudnnGetConvolutionBackwardFilterWorkspaceSize(*handle, x, dz, conv, dw, algoGradW, &wsGradWSize); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnGetConvolutionBackwardFilterWorkspaceSize failed", err); - void* wsGradWData; - auto cudaErr = cudaMalloc(&wsGradWData, wsGradWSize); - if (cudaErr != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradWData failed", cudaErr); - - // allocate auxiliary device memory for gradI calculation, abbreviation ws means workspace - size_t wsGradISize; - err = cudnnGetConvolutionBackwardDataWorkspaceSize(*handle, dw, dz, conv, dx, algoGradI, &wsGradISize); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnGetConvolutionBackwardDataWorkspaceSize failed", err); - void* wsGradIData; - cudaErr = cudaMalloc(&wsGradIData, wsGradISize); - if (cudaErr != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradIData failed", cudaErr); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); - - // run calculation for gradB (if not nullptr) - if(gradB != nullptr) { - cudnnTensorDescriptor_t db; - cudnnCreateTensorDescriptor(&db); - // err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf()); - err = cudnnSetTensor4dDescriptor(db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); - - err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->specialBuffer(), beta, db, gradB->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnConvolutionBackwardBias failed", err); - } - - // run calculation for gradW - err = cudnnConvolutionBackwardFilter(*handle, alpha, x, input->specialBuffer(), dz, gradO->specialBuffer(), conv, algoGradW, wsGradWData, wsGradWSize, beta, dw, gradW->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnConvolutionBackwardFilter failed", err); - - // run calculation for gradI - err = cudnnConvolutionBackwardData(*handle, alpha, dw, weights->specialBuffer(), dz, gradO->specialBuffer(), conv, algoGradI, wsGradIData, wsGradISize, beta, dx, gradI->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudnnConvolutionBackwardData failed", err); - - // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - // if (cudaErr != 0) - // throw cuda_exception::build("conv2dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); - - cudaErr = cudaFree(wsGradWData); - if (cudaErr != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradWData failed", cudaErr); - cudaErr = cudaFree(wsGradIData); - if (cudaErr != 0) throw sd::cuda_exception::build("conv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradIData failed", cudaErr); - - NDArray::registerSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); + const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW, const int paddingMode, + const bool isNCHW, const int wFormat) { + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build("conv2dBpCUDNN: can't set stream for cuDNN", + err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + cudnnTensorFormat_t formatW = + 0 == wFormat ? format + : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), + input->strideAt(indIOioC), input->strideAt(indIiH), + input->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input " + "failed", + err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if (gradO->ews() == 1 && gradO->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx( + dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, + gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), + gradO->strideAt(indOoH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO " + "failed", + err); + + // gradI descriptor + cudnnTensorDescriptor_t dx; + cudnnCreateTensorDescriptor(&dx); + if (gradI->ews() == 1 && gradI->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, + gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), + gradI->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradI " + "failed", + err); + + // gradW descriptor + cudnnFilterDescriptor_t dw; + cudnnCreateFilterDescriptor(&dw); + err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), + formatW, oC, iC, kH, kW); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW failed", err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, + CUDNN_CROSS_CORRELATION, + cudnnDataType(gradO->dataType())); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnSetConvolution2dDescriptor failed", err); + + // gradW algorithm description + cudnnConvolutionBwdFilterAlgo_t algoGradW; + err = cudnnGetConvolutionBackwardFilterAlgorithm( + *handle, x, dz, conv, dw, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, + &algoGradW); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnGetConvolutionBackwardFilterAlgorithm failed", + err); + + // gradI algorithm description + cudnnConvolutionBwdDataAlgo_t algoGradI; + err = cudnnGetConvolutionBackwardDataAlgorithm( + *handle, dw, dz, conv, x, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, + &algoGradI); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnGetConvolutionBackwardDataAlgorithm failed", err); + + // allocate auxiliary device memory for gradW calculation, abbreviation ws + // means workspace + size_t wsGradWSize; + err = cudnnGetConvolutionBackwardFilterWorkspaceSize(*handle, x, dz, conv, dw, + algoGradW, &wsGradWSize); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnGetConvolutionBackwardFilterWorkspaceSize failed", + err); + void* wsGradWData; + auto cudaErr = cudaMalloc(&wsGradWData, wsGradWSize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradWData " + "failed", + cudaErr); + + // allocate auxiliary device memory for gradI calculation, abbreviation ws + // means workspace + size_t wsGradISize; + err = cudnnGetConvolutionBackwardDataWorkspaceSize(*handle, dw, dz, conv, dx, + algoGradI, &wsGradISize); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnGetConvolutionBackwardDataWorkspaceSize failed", + err); + void* wsGradIData; + cudaErr = cudaMalloc(&wsGradIData, wsGradISize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradIData " + "failed", + cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); + + // run calculation for gradB (if not nullptr) + if (gradB != nullptr) { + cudnnTensorDescriptor_t db; + cudnnCreateTensorDescriptor(&db); + // err = cudnnSetTensor4dDescriptor(db, format, + // cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, + // isNCHW ? 1: gradB->lengthOf()); + err = cudnnSetTensor4dDescriptor( + db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); + + err = + cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->specialBuffer(), + beta, db, gradB->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnConvolutionBackwardBias failed", err); + } + + // run calculation for gradW + err = cudnnConvolutionBackwardFilter( + *handle, alpha, x, input->specialBuffer(), dz, gradO->specialBuffer(), + conv, algoGradW, wsGradWData, wsGradWSize, beta, dw, + gradW->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnConvolutionBackwardFilter failed", err); + + // run calculation for gradI + err = cudnnConvolutionBackwardData( + *handle, alpha, dw, weights->specialBuffer(), dz, gradO->specialBuffer(), + conv, algoGradI, wsGradIData, wsGradISize, beta, dx, + gradI->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudnnConvolutionBackwardData failed", err); + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("conv2dBpCUDNN: cudaStreamSynchronize + // failed !", cudaErr); + + cudaErr = cudaFree(wsGradWData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradWData " + "failed", + cudaErr); + cudaErr = cudaFree(wsGradIData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradIData " + "failed", + cudaErr); + + NDArray::registerSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) { - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - REQUIRE_TRUE((bias->rankOf() == 1 && bias->strideAt(0) == 1) || (bias->rankOf() == 2 && bias->sizeAt(0) == 1 && bias->strideAt(1) == 1) || (bias->rankOf() == 2 && bias->sizeAt(1) == 1 && bias->strideAt(0) == 1), 0, "CUSTOM CONV2D CUDNN OP: bias array should be contiguous in memory !"); - } - - NDArray* newWeights = weights; // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} - if(0 == wFormat) { - newWeights = new NDArray(weights->ordering(), isNCHW ? std::vector({oC, iC, kH, kW}) : std::vector({oC, kH, kW, iC}), weights->dataType(), weights->getContext()); - newWeights->assign(weights->permute(isNCHW ? std::vector({3,2,0,1}) : std::vector({3,0,1,2}))); // (kH, kW, iC, oC --> oC, iC, kH, kW) or (kH, kW, iC, oC --> oC, kH, kW, iC) - } - - NDArray* newInput = input; - NDArray* newGradI = nullptr; - if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings - checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); - - conv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW, paddingMode, isNCHW, wFormat); - - if(newInput != input) - delete newInput; - - if(0 == wFormat) - delete newWeights; - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = + OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + bool isNCHW = block.getIArguments()->size() > 9 + ? !INT_ARG(9) + : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 + ? INT_ARG(10) + : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - + // [oC, kH, kW, iC] + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM CONV2D CUDNN OP: rank of input array must be equal to " + "4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM CONV2D CUDNN OP: rank of weights array must be equal to " + "4, but got %i instead !", + weights->rankOf()); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV2D CUDNN OP: wrong shape of weights array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) { + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV2D CUDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + REQUIRE_TRUE( + (bias->rankOf() == 1 && bias->strideAt(0) == 1) || + (bias->rankOf() == 2 && bias->sizeAt(0) == 1 && + bias->strideAt(1) == 1) || + (bias->rankOf() == 2 && bias->sizeAt(1) == 1 && + bias->strideAt(0) == 1), + 0, + "CUSTOM CONV2D CUDNN OP: bias array should be contiguous in memory !"); + } + + NDArray* newWeights = weights; // cudnn support only two formats + // {oC,iC,kH,kW} and {oC,kH,kW,iC} + if (0 == wFormat) { + newWeights = new NDArray(weights->ordering(), + isNCHW ? std::vector({oC, iC, kH, kW}) + : std::vector({oC, kH, kW, iC}), + weights->dataType(), weights->getContext()); + newWeights->assign(weights->permute( + isNCHW + ? std::vector({3, 2, 0, 1}) + : std::vector( + {3, 0, 1, 2}))); // (kH, kW, iC, oC --> oC, iC, kH, kW) or + // (kH, kW, iC, oC --> oC, kH, kW, iC) + } + + NDArray* newInput = input; + NDArray* newGradI = nullptr; + if (paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric + // left/right top/bottopm paddings + checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, + sH, sW, pH, pW, dH, dW, isNCHW); + + conv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH, kW, + sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); + + if (newInput != input) delete newInput; + + if (0 == wFormat) delete newWeights; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(conv2d, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL - - const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; - const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); - - return paddingMode != 2 && !badInputType && !badWeightsType && !badBiasType; + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL + + const bool badInputType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && + weights->dataType() != DataType::FLOAT32 && + weights->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr + ? false + : (bias->dataType() != DataType::DOUBLE && + bias->dataType() != DataType::FLOAT32 && + bias->dataType() != DataType::HALF); + + return paddingMode != 2 && !badInputType && !badWeightsType && !badBiasType; } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradW = OUTPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = block.getIArguments()->size() > 9 + ? !INT_ARG(9) + : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 + ? INT_ARG(10) + : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - + // [oC, kH, kW, iC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM CONV2D_BP CUDNN OP: rank of input array must be equal " + "to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM CONV2D_BP CUDNN OP: rank of weights array must be equal " + "to 4, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "CUSTOM CONV2D_BP CUDNN OP: rank of output's gradients (next " + "epsilon) array must be equal to 4, but got %i instead !", + gradO->rankOf()); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, paddingMode); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM CONV2D_BP CUDNN OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV2D_BP CUDNN OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV2D_BP CUDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + NDArray *newWeights = weights, + *newGradW = gradW; // cudnn support only two formats {oC,iC,kH,kW} + // and {oC,kH,kW,iC} + if (0 == wFormat) { + newGradW = new NDArray(gradW->ordering(), + isNCHW ? std::vector({oC, iC, kH, kW}) + : std::vector({oC, kH, kW, iC}), + gradW->dataType(), gradW->getContext()); + newWeights = new NDArray(weights->ordering(), + isNCHW ? std::vector({oC, iC, kH, kW}) + : std::vector({oC, kH, kW, iC}), + weights->dataType(), weights->getContext()); + newWeights->assign(weights->permute( + isNCHW + ? std::vector({3, 2, 0, 1}) + : std::vector( + {3, 0, 1, 2}))); // (kH, kW, iC, oC --> oC, iC, kH, kW) or + // (kH, kW, iC, oC --> oC, kH, kW, iC) + } + + NDArray* newInput = input; + NDArray* newGradI = gradI; + if (paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric + // left/right top/bottopm paddings + checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, + sH, sW, pH, pW, dH, dW, isNCHW); + + conv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, + newGradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, + isNCHW, wFormat); + + if (0 == wFormat) { + newGradW->permutei( + isNCHW ? std::vector({2, 3, 1, 0}) + : std::vector( + {1, 2, 3, 0})); // (oC, iC, kH, kW --> kH, kW, iC, oC) or + // (oC, kH, kW, iC --> kH, kW, iC, oC) + gradW->assign(newGradW); + } + + if (newInput != input) { + if (isNCHW) + gradI->assign( + (*newGradI)({0, 0, 0, 0, 0, gradI->sizeAt(2), 0, gradI->sizeAt(3)})); + else + gradI->assign( + (*newGradI)({0, 0, 0, gradI->sizeAt(1), 0, gradI->sizeAt(2), 0, 0})); - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM CONV2D_BP CUDNN OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - NDArray *newWeights = weights, *newGradW = gradW; // cudnn support only two formats {oC,iC,kH,kW} and {oC,kH,kW,iC} - if(0 == wFormat) { - newGradW = new NDArray(gradW->ordering(), isNCHW ? std::vector({oC, iC, kH, kW}) : std::vector({oC, kH, kW, iC}), gradW->dataType(), gradW->getContext()); - newWeights = new NDArray(weights->ordering(), isNCHW ? std::vector({oC, iC, kH, kW}) : std::vector({oC, kH, kW, iC}), weights->dataType(), weights->getContext()); - newWeights->assign(weights->permute(isNCHW ? std::vector({3,2,0,1}) : std::vector({3,0,1,2}))); // (kH, kW, iC, oC --> oC, iC, kH, kW) or (kH, kW, iC, oC --> oC, kH, kW, iC) - } - - NDArray* newInput = input; - NDArray* newGradI = gradI; - if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings - checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); - - conv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW,wFormat); - - if(0 == wFormat) { - newGradW->permutei(isNCHW ? std::vector({2,3,1,0}) : std::vector({1,2,3,0})); // (oC, iC, kH, kW --> kH, kW, iC, oC) or (oC, kH, kW, iC --> kH, kW, iC, oC) - gradW->assign(newGradW); - } - - if(newInput != input) { - - if(isNCHW) - gradI->assign((*newGradI)({0,0, 0,0, 0,gradI->sizeAt(2), 0,gradI->sizeAt(3)})); - else - gradI->assign((*newGradI)({0,0, 0,gradI->sizeAt(1), 0,gradI->sizeAt(2), 0,0})); - - delete newInput; - delete newGradI; - } - - if(0 == wFormat) { - delete newWeights; - delete newGradW; - } - - return Status::OK(); -} - -PLATFORM_CHECK(conv2d_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL - const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + delete newInput; + delete newGradI; + } - const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; - const bool badGradOType = gradO->dataType() != DataType::DOUBLE && gradO->dataType() != DataType::FLOAT32 && gradO->dataType() != DataType::HALF; - const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); + if (0 == wFormat) { + delete newWeights; + delete newGradW; + } - return isNCHW && paddingMode != 2 && !badInputType && !badWeightsType && !badGradOType && !badBiasType; + return Status::OK(); } - - - - - +PLATFORM_CHECK(conv2d_bp, ENGINE_CUDA) { + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCHW), epsilon_next + + const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL + const int isNCHW = block.getIArguments()->size() > 9 + ? !INT_ARG(9) + : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + + const bool badInputType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && + weights->dataType() != DataType::FLOAT32 && + weights->dataType() != DataType::HALF; + const bool badGradOType = gradO->dataType() != DataType::DOUBLE && + gradO->dataType() != DataType::FLOAT32 && + gradO->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr + ? false + : (bias->dataType() != DataType::DOUBLE && + bias->dataType() != DataType::FLOAT32 && + bias->dataType() != DataType::HALF); + + return isNCHW && paddingMode != 2 && !badInputType && !badWeightsType && + !badGradOType && !badBiasType; +} // PLATFORM_IMPL(conv2d, ENGINE_CUDA) { -// auto handle = reinterpret_cast(block.launchContext()->getCuDnnHandle()); -// auto res = cudnnSetStream(*handle, *block.launchContext()->getCudaStream()); -// if (res != 0) +// auto handle = reinterpret_cast(block.launchContext()->getCuDnnHandle()); auto res = +// cudnnSetStream(*handle, *block.launchContext()->getCudaStream()); if (res +// != 0) // throw sd::cuda_exception::build("Can't set stream for cuDNN", res); -// auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) -// auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always -// auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] +// auto input = INPUT_VARIABLE(0); // +// [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto weights = +// INPUT_VARIABLE(1); // [kH, kW, iC, oC] +// always auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // +// [oC] -// auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) +// auto output = OUTPUT_VARIABLE(0); // +// [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) // NDArray::prepareSpecialUse({output}, {input, weights, bias}); -// int sH = INT_ARG(2); // strides height -// int sW = INT_ARG(3); // strides width -// int pH = INT_ARG(4); // paddings height -// int pW = INT_ARG(5); // paddings width -// int dH = INT_ARG(6); // dilations height -// int dW = INT_ARG(7); // dilations width -// int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME -// bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - -// int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height -// int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width - -// int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; -// int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes -// ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); -// ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, isSameMode); +// int sH = INT_ARG(2); // strides height int sW = INT_ARG(3); // strides +// width int pH = INT_ARG(4); // paddings height int pW = INT_ARG(5); // +// paddings width int dH = INT_ARG(6); // dilations height int dW = +// INT_ARG(7); // +// dilations width int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME bool +// isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // +// INT_ARG(9): 0-NCHW, 1-NHWC + +// int kH = INT_ARG(0) > 0 ? INT_ARG(0) : +// static_cast(weights->sizeAt(0)); // filter(kernel) height int kW = +// INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // +// filter(kernel) width + +// int bS, iC, iH, iW, oC, oH, oW; // batch +// size, input channels, input height/width, output channels, output +// height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // +// corresponding indexes ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, +// *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, +// indWoC, indWkH, indOoH); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, +// iH, iW, kH, kW, sH, sW, dH, dW, isSameMode); // auto dtype = cudnnDataType(input->dataType()); - // cudnnTensorDescriptor_t src; // cudnnCreateTensorDescriptor(&src); -// res = cudnnSetTensor4dDescriptorEx(src, dtype, input->sizeAt(0), input->sizeAt(1), input->sizeAt(2), input->sizeAt(3), input->strideAt(0), input->strideAt(1), input->strideAt(2), input->strideAt(3)); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx src failed", res); +// res = cudnnSetTensor4dDescriptorEx(src, dtype, input->sizeAt(0), +// input->sizeAt(1), input->sizeAt(2), input->sizeAt(3), input->strideAt(0), +// input->strideAt(1), input->strideAt(2), input->strideAt(3)); if (res != +// 0) +// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx src +// failed", res); // // TODO: we definitely want NHWC here as well // cudnnFilterDescriptor_t wght; // cudnnCreateFilterDescriptor(&wght); -// res = cudnnSetFilter4dDescriptor(wght, dtype, CUDNN_TENSOR_NCHW, oC, iC, kH, kW); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnSetFilter4dDescriptor failed", res); +// res = cudnnSetFilter4dDescriptor(wght, dtype, CUDNN_TENSOR_NCHW, oC, iC, +// kH, kW); if (res != 0) +// throw sd::cuda_exception::build("cudnnSetFilter4dDescriptor failed", +// res); // cudnnConvolutionDescriptor_t cdc; // cudnnCreateConvolutionDescriptor(&cdc); -// res = cudnnSetConvolution2dDescriptor(cdc, pH, pW, sH, sW, dH, dW, CUDNN_CROSS_CORRELATION, dtype); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnSetConvolution2dDescriptor failed", res); +// res = cudnnSetConvolution2dDescriptor(cdc, pH, pW, sH, sW, dH, dW, +// CUDNN_CROSS_CORRELATION, dtype); if (res != 0) +// throw sd::cuda_exception::build("cudnnSetConvolution2dDescriptor +// failed", res); // cudnnTensorDescriptor_t dst; // cudnnCreateTensorDescriptor(&dst); -// res = cudnnSetTensor4dDescriptorEx(dst, dtype, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3), output->strideAt(0), output->strideAt(1), output->strideAt(2), output->strideAt(3)); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx dst failed", res); - -// // TODO: workspace algorithms are supposed to be faster, so we should use it here if we have enough memory -// cudnnConvolutionFwdAlgo_t algo; -// res = cudnnGetConvolutionForwardAlgorithm(*handle, src, wght, cdc, dst, CUDNN_CONVOLUTION_FWD_NO_WORKSPACE, 0, &algo); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnGetConvolutionForwardAlgorithm failed", res); +// res = cudnnSetTensor4dDescriptorEx(dst, dtype, output->sizeAt(0), +// output->sizeAt(1), output->sizeAt(2), output->sizeAt(3), +// output->strideAt(0), output->strideAt(1), output->strideAt(2), +// output->strideAt(3)); if (res != 0) +// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx dst +// failed", res); + +// // TODO: workspace algorithms are supposed to be faster, so we should use +// it here if we have enough memory cudnnConvolutionFwdAlgo_t algo; res = +// cudnnGetConvolutionForwardAlgorithm(*handle, src, wght, cdc, dst, +// CUDNN_CONVOLUTION_FWD_NO_WORKSPACE, 0, &algo); if (res != 0) +// throw sd::cuda_exception::build("cudnnGetConvolutionForwardAlgorithm +// failed", res); // // TODO: should be float if dtype is half/float, and double otherwise // float alpha = 1.0f; // float beta = 0.0f; -// res = cudnnConvolutionForward(*handle, &alpha, src, input->specialBuffer(), wght, weights->specialBuffer(), cdc, algo, nullptr, 0, &beta, dst, output->specialBuffer()); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnConvolutionForward failed", res); - +// res = cudnnConvolutionForward(*handle, &alpha, src, +// input->specialBuffer(), wght, weights->specialBuffer(), cdc, algo, +// nullptr, 0, &beta, dst, output->specialBuffer()); if (res != 0) +// throw sd::cuda_exception::build("cudnnConvolutionForward failed", +// res); // if (bias != nullptr) { // cudnnTensorDescriptor_t bs; // cudnnCreateTensorDescriptor(&bs); // if (isNCHW) { -// res = cudnnSetTensor4dDescriptor(bs, CUDNN_TENSOR_NCHW, dtype, 1, bias->lengthOf(), 1, 1); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx bias NHWC failed", res); +// res = cudnnSetTensor4dDescriptor(bs, CUDNN_TENSOR_NCHW, dtype, 1, +// bias->lengthOf(), 1, 1); if (res != 0) +// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx +// bias NHWC failed", res); // } else { -// res = cudnnSetTensor4dDescriptor(bs, CUDNN_TENSOR_NHWC, dtype, 1, 1, 1, bias->lengthOf()); -// if (res != 0) -// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx bias NHWC failed", res); +// res = cudnnSetTensor4dDescriptor(bs, CUDNN_TENSOR_NHWC, dtype, 1, +// 1, 1, bias->lengthOf()); if (res != 0) +// throw sd::cuda_exception::build("cudnnSetTensor4dDescriptorEx +// bias NHWC failed", res); // } -// res = cudnnAddTensor(*handle, &alpha, bs, bias->specialBuffer(), &alpha, dst, output->specialBuffer()); -// if (res != 0) +// res = cudnnAddTensor(*handle, &alpha, bs, bias->specialBuffer(), +// &alpha, dst, output->specialBuffer()); if (res != 0) // throw sd::cuda_exception::build("cudnnAddTensor failed", res); // } - // NDArray::registerSpecialUse({output}, {input, weights, bias}); // return Status::OK(); // } - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu index 693ebeefa053..9a2007b7497c 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu @@ -19,453 +19,758 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void conv3dCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const int paddingMode, const bool isNCDHW, const int wFormat) { - - // cudnn support only one format for weights {oC,iC,kD,kH,kW} - - const int numDims = 5; - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: can't set stream for cuDNN", err); - - const std::vector pads = {pD, pH, pW}; - const std::vector filtStrides = {sD, sH, sW}; - const std::vector dilations = {dD, dH, dW}; - - const std::vector xShape = {bS, iC, iD, iH, iW}; - const std::vector zShape = {bS, oC, oD, oH, oW}; - const std::vector wShape = {oC, iC, kD, kH, kW}; - const std::vector bShape = {1, oC, 1, 1, 1}; // {1, (isNCDHW ? oC : 1), 1, 1, (isNCDHW ? 1 : oC)}; - - const std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; - const std::vector zStrides = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)}; - - cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape.data(), xStrides.data()); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); - - // weights descriptor - cudnnFilterDescriptor_t w; - cudnnCreateFilterDescriptor(&w); - err = cudnnSetFilterNdDescriptor(w, cudnnDataType(weights->dataType()), CUDNN_TENSOR_NCHW, numDims, wShape.data()); - if(err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetFilterNdDescriptor failed", err); - - // output descriptor - cudnnTensorDescriptor_t z; - cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape.data()); - else - err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape.data(), zStrides.data()); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output failed", err); - - // description of convolution - cudnnConvolutionDescriptor_t conv; - cudnnCreateConvolutionDescriptor(&conv); - err = cudnnSetConvolutionNdDescriptor(conv, numDims-2, pads.data(), filtStrides.data(), dilations.data(), CUDNN_CROSS_CORRELATION, cudnnDataType(output->dataType())); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetConvolutionNdDescriptor failed", err); - - // algorithm description - cudnnConvolutionFwdAlgo_t algo; - err = cudnnGetConvolutionForwardAlgorithm(*handle, x, w, conv, z, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnGetConvolutionForwardAlgorithm failed", err); - - // allocate auxiliary device memory, abbreviation ws means workspace - size_t wsSize; - err = cudnnGetConvolutionForwardWorkspaceSize(*handle, x, w, conv, z, algo, &wsSize); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnGetConvolutionForwardWorkspaceSize failed", err); - void* wsData; - auto cudaErr = cudaMalloc(&wsData, wsSize); - if (cudaErr != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudaMalloc for auxiliary workspace memory failed", cudaErr); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({output}, {input, weights, bias}); - - // run calculation - err = cudnnConvolutionForward(*handle, alpha, x, input->specialBuffer(), w, weights->specialBuffer(), conv, algo, wsData, wsSize, beta, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnConvolutionForward failed", err); - - // add bias if it is present - if (bias != nullptr) { - - cudnnTensorDescriptor_t b; - cudnnCreateTensorDescriptor(&b); - err = cudnnSetTensorNdDescriptorEx(b, /*format*/CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), numDims, bShape.data()); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnSetTensorNdDescriptor for bias failed", err); - err = cudnnAddTensor(*handle, alpha, b, bias->specialBuffer(), alpha, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudnnAddTensor bias failed", err); - } - - // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - // if (cudaErr != 0) - // throw cuda_exception::build("conv3dCUDNN: cudaStreamSynchronize failed !", cudaErr); - - cudaErr = cudaFree(wsData); - if (cudaErr != 0) throw sd::cuda_exception::build("conv3dCUDNN: cudaFree for auxiliary workspace memory failed", cudaErr); - - NDArray::registerSpecialUse({output}, {input, weights, bias}); +static void conv3dCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* weights, const NDArray* bias, + NDArray* output, const int kD, const int kH, + const int kW, const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, const int dD, + const int dH, const int dW, const int paddingMode, + const bool isNCDHW, const int wFormat) { + // cudnn support only one format for weights {oC,iC,kD,kH,kW} + + const int numDims = 5; + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build("conv3dCUDNN: can't set stream for cuDNN", + err); + + const std::vector pads = {pD, pH, pW}; + const std::vector filtStrides = {sD, sH, sW}; + const std::vector dilations = {dD, dH, dW}; + + const std::vector xShape = {bS, iC, iD, iH, iW}; + const std::vector zShape = {bS, oC, oD, oH, oW}; + const std::vector wShape = {oC, iC, kD, kH, kW}; + const std::vector bShape = { + 1, oC, 1, 1, 1}; // {1, (isNCDHW ? oC : 1), 1, 1, (isNCDHW ? 1 : oC)}; + + const std::vector xStrides = { + (int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), + (int)input->strideAt(3), (int)input->strideAt(4)}; + const std::vector zStrides = { + (int)output->strideAt(0), (int)output->strideAt(1), + (int)output->strideAt(2), (int)output->strideAt(3), + (int)output->strideAt(4)}; + + cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1) + err = cudnnSetTensorNdDescriptorEx( + x, format, cudnnDataType(input->dataType()), numDims, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), + numDims, xShape.data(), xStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx " + "for input failed", + err); + + // weights descriptor + cudnnFilterDescriptor_t w; + cudnnCreateFilterDescriptor(&w); + err = cudnnSetFilterNdDescriptor(w, cudnnDataType(weights->dataType()), + CUDNN_TENSOR_NCHW, numDims, wShape.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnSetFilterNdDescriptor failed", err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if (output->ews() == 1) + err = cudnnSetTensorNdDescriptorEx( + z, format, cudnnDataType(output->dataType()), numDims, zShape.data()); + else + err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), + numDims, zShape.data(), zStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx " + "for output failed", + err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolutionNdDescriptor( + conv, numDims - 2, pads.data(), filtStrides.data(), dilations.data(), + CUDNN_CROSS_CORRELATION, cudnnDataType(output->dataType())); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnSetConvolutionNdDescriptor failed", err); + + // algorithm description + cudnnConvolutionFwdAlgo_t algo; + err = cudnnGetConvolutionForwardAlgorithm( + *handle, x, w, conv, z, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnGetConvolutionForwardAlgorithm failed", err); + + // allocate auxiliary device memory, abbreviation ws means workspace + size_t wsSize; + err = cudnnGetConvolutionForwardWorkspaceSize(*handle, x, w, conv, z, algo, + &wsSize); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnGetConvolutionForwardWorkspaceSize failed", err); + void* wsData; + auto cudaErr = cudaMalloc(&wsData, wsSize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudaMalloc for auxiliary workspace memory failed", + cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input, weights, bias}); + + // run calculation + err = cudnnConvolutionForward(*handle, alpha, x, input->specialBuffer(), w, + weights->specialBuffer(), conv, algo, wsData, + wsSize, beta, z, output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnConvolutionForward failed", err); + + // add bias if it is present + if (bias != nullptr) { + cudnnTensorDescriptor_t b; + cudnnCreateTensorDescriptor(&b); + err = cudnnSetTensorNdDescriptorEx(b, /*format*/ CUDNN_TENSOR_NCHW, + cudnnDataType(bias->dataType()), numDims, + bShape.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudnnSetTensorNdDescriptor for bias failed", err); + err = cudnnAddTensor(*handle, alpha, b, bias->specialBuffer(), alpha, z, + output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build("conv3dCUDNN: cudnnAddTensor bias failed", + err); + } + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("conv3dCUDNN: cudaStreamSynchronize failed + // !", cudaErr); + + cudaErr = cudaFree(wsData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv3dCUDNN: cudaFree for auxiliary workspace memory failed", cudaErr); + + NDArray::registerSpecialUse({output}, {input, weights, bias}); } ////////////////////////////////////////////////////////////////////////// -static void conv3dBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* weights, const NDArray* gradO, +static void conv3dBpCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, - const int paddingMode, const bool isNCDHW, const int wFormat) { - - // cudnn supports only two formats {oC,iC,kD,kH,kW} and {oC,kD,kH,kW,iC} for weights/gradW - - const int numDims = 5; - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: can't set stream for cuDNN", err); - - const std::vector pads = {pD, pH, pW}; - const std::vector filtStrides = {sD, sH, sW}; - const std::vector dilations = {dD, dH, dW}; - - const std::vector xShape = {bS, iC, iD, iH, iW}; - const std::vector dzShape = {bS, oC, oD, oH, oW}; - const std::vector wShape = {oC, iC, kD, kH, kW}; - const std::vector dbShape = {1, (int)(isNCDHW ? oC : 1), 1, 1, (int)(isNCDHW ? 1 : oC)}; - - const std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; - const std::vector dxStrides = {(int)gradI->strideAt(0), (int)gradI->strideAt(1), (int)gradI->strideAt(2), (int)gradI->strideAt(3), (int)gradI->strideAt(4)}; - const std::vector dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)}; - - cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - cudnnTensorFormat_t formatW = 0 == wFormat ? format : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape.data(), xStrides.data()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); - - // gradO descriptor - cudnnTensorDescriptor_t dz; - cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape.data()); - else - err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape.data(), dzStrides.data()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err); - - // gradI descriptor - cudnnTensorDescriptor_t dx; - cudnnCreateTensorDescriptor(&dx); - if(gradI->ews() == 1) - err = cudnnSetTensorNdDescriptorEx(dx, format, cudnnDataType(gradI->dataType()), numDims, xShape.data()); - else - err = cudnnSetTensorNdDescriptor(dx, cudnnDataType(gradI->dataType()), numDims, xShape.data(), dxStrides.data()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradI failed", err); - - // gradW descriptor - cudnnFilterDescriptor_t dw; - cudnnCreateFilterDescriptor(&dw); - err = cudnnSetFilterNdDescriptor(dw, cudnnDataType(gradW->dataType()), formatW, numDims, wShape.data()); - if(err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetFilterNdDescriptor failed", err); - - // description of convolution - cudnnConvolutionDescriptor_t conv; - cudnnCreateConvolutionDescriptor(&conv); - err = cudnnSetConvolutionNdDescriptor(conv, numDims-2, pads.data(), filtStrides.data(), dilations.data(), CUDNN_CROSS_CORRELATION, cudnnDataType(gradO->dataType())); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetConvolutionNdDescriptor failed", err); - - // gradW algorithm description - cudnnConvolutionBwdFilterAlgo_t algoGradW; - err = cudnnGetConvolutionBackwardFilterAlgorithm(*handle, x, dz, conv, dw, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, &algoGradW); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnGetConvolutionBackwardFilterAlgorithm failed", err); - - // gradI algorithm description - cudnnConvolutionBwdDataAlgo_t algoGradI; - err = cudnnGetConvolutionBackwardDataAlgorithm(*handle, dw, dz, conv, x, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, &algoGradI); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnGetConvolutionBackwardDataAlgorithm failed", err); - - // allocate auxiliary device memory for gradW calculation, abbreviation ws means workspace - size_t wsGradWSize; - err = cudnnGetConvolutionBackwardFilterWorkspaceSize(*handle, x, dz, conv, dw, algoGradW, &wsGradWSize); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnGetConvolutionBackwardFilterWorkspaceSize failed", err); - void* wsGradWData; - auto cudaErr = cudaMalloc(&wsGradWData, wsGradWSize); - if (cudaErr != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradWData failed", cudaErr); - - // allocate auxiliary device memory for gradI calculation, abbreviation ws means workspace - size_t wsGradISize; - err = cudnnGetConvolutionBackwardDataWorkspaceSize(*handle, dw, dz, conv, dx, algoGradI, &wsGradISize); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnGetConvolutionBackwardDataWorkspaceSize failed", err); - void* wsGradIData; - cudaErr = cudaMalloc(&wsGradIData, wsGradISize); - if (cudaErr != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradIData failed", cudaErr); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); - - // run calculation for gradB (if not nullptr) - if(gradB != nullptr) { - - cudnnTensorDescriptor_t db; - cudnnCreateTensorDescriptor(&db); - err = cudnnSetTensorNdDescriptorEx(db, format, cudnnDataType(gradB->dataType()), numDims, dbShape.data()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnSetTensorNdDescriptor for gradB failed", err); - - err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->specialBuffer(), beta, db, gradB->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnConvolutionBackwardBias failed", err); - } - - // run calculation for gradW - err = cudnnConvolutionBackwardFilter(*handle, alpha, x, input->specialBuffer(), dz, gradO->specialBuffer(), conv, algoGradW, wsGradWData, wsGradWSize, beta, dw, gradW->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnConvolutionBackwardFilter failed", err); - - // run calculation for gradI - err = cudnnConvolutionBackwardData(*handle, alpha, dw, weights->specialBuffer(), dz, gradO->specialBuffer(), conv, algoGradI, wsGradIData, wsGradISize, beta, dx, gradI->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudnnConvolutionBackwardData failed", err); - - // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - // if (cudaErr != 0) - // throw cuda_exception::build("conv3dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); - - cudaErr = cudaFree(wsGradWData); - if (cudaErr != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudaFree for auxiliary workspace memory wsGradWData failed", cudaErr); - cudaErr = cudaFree(wsGradIData); - if (cudaErr != 0) throw sd::cuda_exception::build("conv3dBpCUDNN: cudaFree for auxiliary workspace memory wsGradIData failed", cudaErr); - - NDArray::registerSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); + const int paddingMode, const bool isNCDHW, + const int wFormat) { + // cudnn supports only two formats {oC,iC,kD,kH,kW} and {oC,kD,kH,kW,iC} for + // weights/gradW + + const int numDims = 5; + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build("conv3dBpCUDNN: can't set stream for cuDNN", + err); + + const std::vector pads = {pD, pH, pW}; + const std::vector filtStrides = {sD, sH, sW}; + const std::vector dilations = {dD, dH, dW}; + + const std::vector xShape = {bS, iC, iD, iH, iW}; + const std::vector dzShape = {bS, oC, oD, oH, oW}; + const std::vector wShape = {oC, iC, kD, kH, kW}; + const std::vector dbShape = {1, (int)(isNCDHW ? oC : 1), 1, 1, + (int)(isNCDHW ? 1 : oC)}; + + const std::vector xStrides = { + (int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), + (int)input->strideAt(3), (int)input->strideAt(4)}; + const std::vector dxStrides = { + (int)gradI->strideAt(0), (int)gradI->strideAt(1), (int)gradI->strideAt(2), + (int)gradI->strideAt(3), (int)gradI->strideAt(4)}; + const std::vector dzStrides = { + (int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), + (int)gradO->strideAt(3), (int)gradO->strideAt(4)}; + + cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + cudnnTensorFormat_t formatW = + 0 == wFormat ? format + : (1 == wFormat ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC); + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1) + err = cudnnSetTensorNdDescriptorEx( + x, format, cudnnDataType(input->dataType()), numDims, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), + numDims, xShape.data(), xStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input " + "failed", + err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if (gradO->ews() == 1) + err = cudnnSetTensorNdDescriptorEx( + dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape.data()); + else + err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), + numDims, dzShape.data(), dzStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO " + "failed", + err); + + // gradI descriptor + cudnnTensorDescriptor_t dx; + cudnnCreateTensorDescriptor(&dx); + if (gradI->ews() == 1) + err = cudnnSetTensorNdDescriptorEx( + dx, format, cudnnDataType(gradI->dataType()), numDims, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(dx, cudnnDataType(gradI->dataType()), + numDims, xShape.data(), dxStrides.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradI " + "failed", + err); + + // gradW descriptor + cudnnFilterDescriptor_t dw; + cudnnCreateFilterDescriptor(&dw); + err = cudnnSetFilterNdDescriptor(dw, cudnnDataType(gradW->dataType()), + formatW, numDims, wShape.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnSetFilterNdDescriptor failed", err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolutionNdDescriptor( + conv, numDims - 2, pads.data(), filtStrides.data(), dilations.data(), + CUDNN_CROSS_CORRELATION, cudnnDataType(gradO->dataType())); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnSetConvolutionNdDescriptor failed", err); + + // gradW algorithm description + cudnnConvolutionBwdFilterAlgo_t algoGradW; + err = cudnnGetConvolutionBackwardFilterAlgorithm( + *handle, x, dz, conv, dw, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, + &algoGradW); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnGetConvolutionBackwardFilterAlgorithm failed", + err); + + // gradI algorithm description + cudnnConvolutionBwdDataAlgo_t algoGradI; + err = cudnnGetConvolutionBackwardDataAlgorithm( + *handle, dw, dz, conv, x, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, + &algoGradI); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnGetConvolutionBackwardDataAlgorithm failed", err); + + // allocate auxiliary device memory for gradW calculation, abbreviation ws + // means workspace + size_t wsGradWSize; + err = cudnnGetConvolutionBackwardFilterWorkspaceSize(*handle, x, dz, conv, dw, + algoGradW, &wsGradWSize); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnGetConvolutionBackwardFilterWorkspaceSize failed", + err); + void* wsGradWData; + auto cudaErr = cudaMalloc(&wsGradWData, wsGradWSize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradWData " + "failed", + cudaErr); + + // allocate auxiliary device memory for gradI calculation, abbreviation ws + // means workspace + size_t wsGradISize; + err = cudnnGetConvolutionBackwardDataWorkspaceSize(*handle, dw, dz, conv, dx, + algoGradI, &wsGradISize); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnGetConvolutionBackwardDataWorkspaceSize failed", + err); + void* wsGradIData; + cudaErr = cudaMalloc(&wsGradIData, wsGradISize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradIData " + "failed", + cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); + + // run calculation for gradB (if not nullptr) + if (gradB != nullptr) { + cudnnTensorDescriptor_t db; + cudnnCreateTensorDescriptor(&db); + err = cudnnSetTensorNdDescriptorEx( + db, format, cudnnDataType(gradB->dataType()), numDims, dbShape.data()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnSetTensorNdDescriptor for gradB failed", err); + + err = + cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->specialBuffer(), + beta, db, gradB->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnConvolutionBackwardBias failed", err); + } + + // run calculation for gradW + err = cudnnConvolutionBackwardFilter( + *handle, alpha, x, input->specialBuffer(), dz, gradO->specialBuffer(), + conv, algoGradW, wsGradWData, wsGradWSize, beta, dw, + gradW->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnConvolutionBackwardFilter failed", err); + + // run calculation for gradI + err = cudnnConvolutionBackwardData( + *handle, alpha, dw, weights->specialBuffer(), dz, gradO->specialBuffer(), + conv, algoGradI, wsGradIData, wsGradISize, beta, dx, + gradI->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudnnConvolutionBackwardData failed", err); + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("conv3dBpCUDNN: cudaStreamSynchronize + // failed !", cudaErr); + + cudaErr = cudaFree(wsGradWData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudaFree for auxiliary workspace memory wsGradWData " + "failed", + cudaErr); + cudaErr = cudaFree(wsGradIData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "conv3dBpCUDNN: cudaFree for auxiliary workspace memory wsGradIData " + "failed", + cudaErr); + + NDArray::registerSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CONV3D CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CONV3D CUDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV3D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV3D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - NDArray* newWeights = weights; // cudnn support only one format {oC,iC,kD,kH,kW} - if(1 != wFormat) { - newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, weights->dataType(), weights->getContext()); - newWeights->assign(weights->permute(0 == wFormat ? std::vector({4,3,0,1,2}) : std::vector({0,4,1,2,3}))); // kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW or oC, kD, kH, kW, iC --> oC, iC, kD, kH, kW - } - - NDArray* newInput = input; - NDArray* newGradI = nullptr; - if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings - checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); - - conv3dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW, paddingMode, isNCDHW, wFormat); - - if(newInput != input) - delete newInput; - - if(1 != wFormat) - delete newWeights; - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CONV3D CUDNN OP: rank of input array must be equal to 5, but " + "got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CONV3D CUDNN OP: rank of weights array must be equal to 5, but " + "got %i instead !", + weights->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = block.getIArguments()->size() > 13 + ? !INT_ARG(13) + : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 + ? INT_ARG(14) + : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], + // 2-[oC, kD, kH, kW, iC] + + REQUIRE_TRUE(paddingMode < 2, 0, + "CONV3D CUDNN OP: causal padding mode (paddingMode = 2) is not " + "allowed for this operation !"); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW, paddingMode); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CONV3D CUDNN OP: wrong shape of weights array, expected is %s, " + "but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CONV3D CUDNN OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + NDArray* newWeights = + weights; // cudnn support only one format {oC,iC,kD,kH,kW} + if (1 != wFormat) { + newWeights = new NDArray(weights->ordering(), {oC, iC, kD, kH, kW}, + weights->dataType(), weights->getContext()); + newWeights->assign(weights->permute( + 0 == wFormat + ? std::vector({4, 3, 0, 1, 2}) + : std::vector( + {0, 4, 1, 2, + 3}))); // kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW or + // oC, kD, kH, kW, iC --> oC, iC, kD, kH, kW + } + + NDArray* newInput = input; + NDArray* newGradI = nullptr; + if (paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric + // left/right top/bottopm paddings + checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, + kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, + dW, isNCDHW); + + conv3dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kD, kH, + kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, + wFormat); + + if (newInput != input) delete newInput; + + if (1 != wFormat) delete newWeights; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(conv3dnew, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID - - const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; - const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); - - return paddingMode != 2 && !badInputType && !badWeightsType && !badBiasType; + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID + + const bool badInputType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && + weights->dataType() != DataType::FLOAT32 && + weights->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr + ? false + : (bias->dataType() != DataType::DOUBLE && + bias->dataType() != DataType::FLOAT32 && + bias->dataType() != DataType::HALF); + + return paddingMode != 2 && !badInputType && !badWeightsType && !badBiasType; } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CONV3D_BP CUDNN OP: rank of input array must be equal to 5, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CONV3D_BP CUDNN OP: rank of weights array must be equal to 5, " + "but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 5, 0, + "CONV3D_BP CUDNN OP: rank of output gradients (next epsilon) " + "array must be equal to 5, but got %i instead !", + gradO->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = block.getIArguments()->size() > 13 + ? !INT_ARG(13) + : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.getIArguments()->size() > 14 + ? INT_ARG(14) + : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], + // 2-[oC, kD, kH, kW, iC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + int trueoD, trueoH, trueoW; // true output depth/height/width + ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, + iW, paddingMode); + + REQUIRE_TRUE(paddingMode < 2, 0, + "CONV3D_BP CUDNN OP: causal padding mode (paddingMode = 2) is " + "not allowed for this operation !"); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, + 0, indIOioC, indIOioD, + indIOioD + 1, indIOioD + 2}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CONV3D_BP CUDNN OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradW->isSameShape(expectedWeightsShape), 0, + "CONV3D_BP CUDNN OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CONV3D_BP CUDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW, paddingMode); + + NDArray *newWeights = weights, + *newGradW = gradW; // cudnn support only two formats {oC,iC,kD,kH,kW} + // and {oC,kD,kH,kW,iC} + if (0 == wFormat) { + newGradW = + new NDArray(gradW->ordering(), + isNCDHW ? std::vector({oC, iC, kD, kH, kW}) + : std::vector({oC, kD, kH, kW, iC}), + gradW->dataType(), gradW->getContext()); + newWeights = + new NDArray(weights->ordering(), + isNCDHW ? std::vector({oC, iC, kD, kH, kW}) + : std::vector({oC, kD, kH, kW, iC}), + weights->dataType(), weights->getContext()); + newWeights->assign(weights->permute( + isNCDHW + ? std::vector({4, 3, 0, 1, 2}) + : std::vector( + {4, 0, 1, 2, + 3}))); // (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW) or + // (kD, kH, kW, iC, oC --> oC, kD, kH, kW, iC) + } + + NDArray* newInput = input; + NDArray* newGradI = gradI; + if (paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric + // left/right top/bottopm paddings + checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, + kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, + dW, isNCDHW); + + conv3dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, + newGradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, isNCDHW, wFormat); + + if (0 == wFormat) { + newGradW->permutei( + isNCDHW ? std::vector({2, 3, 4, 1, 0}) + : std::vector( + {1, 2, 3, 4, + 0})); // (oC, iC, kD, kH, kW --> kD, kH, kW, iC, oC) or + // (oC, kD, kH, kW, iC --> kD, kH, kW, iC, oC) + gradW->assign(newGradW); + } + + if (newInput != input) { + if (isNCDHW) + gradI->assign((*newGradI)({0, 0, 0, 0, 0, gradI->sizeAt(2), 0, + gradI->sizeAt(3), 0, gradI->sizeAt(4)})); + else + gradI->assign((*newGradI)({0, 0, 0, gradI->sizeAt(1), 0, gradI->sizeAt(2), + 0, gradI->sizeAt(3), 0, 0})); - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CONV3D_BP CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CONV3D_BP CUDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CONV3D_BP CUDNN OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - int trueoD, trueoH, trueoW; // true output depth/height/width - ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); - - REQUIRE_TRUE(paddingMode < 2, 0, "CONV3D_BP CUDNN OP: causal padding mode (paddingMode = 2) is not allowed for this operation !"); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV3D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradW->isSameShape(expectedWeightsShape), 0, "CONV3D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV3D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW, paddingMode); - - NDArray *newWeights = weights, *newGradW = gradW; // cudnn support only two formats {oC,iC,kD,kH,kW} and {oC,kD,kH,kW,iC} - if(0 == wFormat) { - newGradW = new NDArray(gradW->ordering(), isNCDHW ? std::vector({oC, iC, kD, kH, kW}) : std::vector({oC, kD, kH, kW, iC}), gradW->dataType(), gradW->getContext()); - newWeights = new NDArray(weights->ordering(), isNCDHW ? std::vector({oC, iC, kD, kH, kW}) : std::vector({oC, kD, kH, kW, iC}), weights->dataType(), weights->getContext()); - newWeights->assign(weights->permute(isNCDHW ? std::vector({4,3,0,1,2}) : std::vector({4,0,1,2,3}))); // (kD, kH, kW, iC, oC --> oC, iC, kD, kH, kW) or (kD, kH, kW, iC, oC --> oC, kD, kH, kW, iC) - } - - NDArray* newInput = input; - NDArray* newGradI = gradI; - if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings - checkConv3dCUDNNPadAsymmetric(newInput, newGradI, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); - - conv3dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kD,kH,kW,sD,sH,sW,pD,pH,pW,dD,dH,dW,paddingMode,isNCDHW,wFormat); - - if(0 == wFormat) { - newGradW->permutei(isNCDHW ? std::vector({2,3,4,1,0}) : std::vector({1,2,3,4,0})); // (oC, iC, kD, kH, kW --> kD, kH, kW, iC, oC) or (oC, kD, kH, kW, iC --> kD, kH, kW, iC, oC) - gradW->assign(newGradW); - } - - - if(newInput != input) { - - if(isNCDHW) - gradI->assign((*newGradI)({0,0, 0,0, 0,gradI->sizeAt(2), 0,gradI->sizeAt(3), 0,gradI->sizeAt(4)})); - else - gradI->assign((*newGradI)({0,0, 0,gradI->sizeAt(1), 0,gradI->sizeAt(2), 0,gradI->sizeAt(3), 0,0})); - - delete newInput; - delete newGradI; - } - - if(0 == wFormat) { - delete newWeights; - delete newGradW; - } - - return Status::OK(); -} - -PLATFORM_CHECK(conv3dnew_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + delete newInput; + delete newGradI; + } - const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; - const bool badGradOType = gradO->dataType() != DataType::DOUBLE && gradO->dataType() != DataType::FLOAT32 && gradO->dataType() != DataType::HALF; - const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); + if (0 == wFormat) { + delete newWeights; + delete newGradW; + } - return isNCDHW && paddingMode != 2 && !badInputType && !badWeightsType && !badGradOType && !badBiasType; + return Status::OK(); } +PLATFORM_CHECK(conv3dnew_bp, ENGINE_CUDA) { + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = block.getIArguments()->size() > 13 + ? !INT_ARG(13) + : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + + const bool badInputType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && + weights->dataType() != DataType::FLOAT32 && + weights->dataType() != DataType::HALF; + const bool badGradOType = gradO->dataType() != DataType::DOUBLE && + gradO->dataType() != DataType::FLOAT32 && + gradO->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr + ? false + : (bias->dataType() != DataType::DOUBLE && + bias->dataType() != DataType::FLOAT32 && + bias->dataType() != DataType::HALF); + + return isNCDHW && paddingMode != 2 && !badInputType && !badWeightsType && + !badGradOType && !badBiasType; } -} -} + +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu index 54f8a1f3bbca..821ad71fc841 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu @@ -18,395 +18,533 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, - const int iH, const int iW, - const int oH, const int oW, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const bool isNCHW) { +void checkConv2dCUDNNPadAsymmetric(NDArray*& input, NDArray*& gradI, + const int iH, const int iW, const int oH, + const int oW, const int kH, const int kW, + const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const bool isNCHW) { + const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); + const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); - const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); - const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); + const bool isPHasymm = pH != (pHsum - pH); + const bool isPWasymm = pW != (pWsum - pW); - const bool isPHasymm = pH != (pHsum - pH); - const bool isPWasymm = pW != (pWsum - pW); + if (!isPHasymm && !isPWasymm) return; - if(!isPHasymm && !isPWasymm) - return; + std::vector newShape = input->getShapeAsVector(); - std::vector newShape = input->getShapeAsVector(); + const int iHposition = isNCHW ? 2 : 1; - const int iHposition = isNCHW ? 2 : 1; + if (isPHasymm) newShape[iHposition] += 1; + if (isPWasymm) newShape[iHposition + 1] += 1; - if(isPHasymm) - newShape[iHposition] += 1; - if(isPWasymm) - newShape[iHposition + 1] += 1; + NDArray* newInput = new NDArray(input->ordering(), newShape, + input->dataType(), input->getContext()); - NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext()); + if (isNCHW) + (*newInput)({0, 0, 0, 0, 0, input->sizeAt(2), 0, input->sizeAt(3)}) + .assign(input); + else + (*newInput)({0, 0, 0, input->sizeAt(1), 0, input->sizeAt(2), 0, 0}) + .assign(input); - if(isNCHW) - (*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3)}).assign(input); - else - (*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,0}).assign(input); + input = newInput; - input = newInput; - - if(gradI != nullptr) - gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext()); + if (gradI != nullptr) + gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), + gradI->getContext()); } - ////////////////////////////////////////////////////////////////////////// -void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, - const int iD, const int iH, const int iW, - const int oD, const int oH, const int oW, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW) { - - const auto pDsum = ((oD - 1) * sD + ((kD - 1) * dD + 1) - iD); - const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); - const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); - - const bool isPDasymm = pD != (pDsum - pD); - const bool isPHasymm = pH != (pHsum - pH); - const bool isPWasymm = pW != (pWsum - pW); - - if(!isPDasymm && !isPHasymm && !isPWasymm) - return; - - std::vector newShape = input->getShapeAsVector(); - - const int iDposition = isNCDHW ? 2 : 1; - - if(isPDasymm) - newShape[iDposition] += 1; - if(isPHasymm) - newShape[iDposition + 1] += 1; - if(isPWasymm) - newShape[iDposition + 2] += 1; - - NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext()); - - if(isNCDHW) - (*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3), 0,input->sizeAt(4)}).assign(input); - else - (*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,input->sizeAt(3), 0,0}).assign(input); - - input = newInput; +void checkConv3dCUDNNPadAsymmetric(NDArray*& input, NDArray*& gradI, + const int iD, const int iH, const int iW, + const int oD, const int oH, const int oW, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW) { + const auto pDsum = ((oD - 1) * sD + ((kD - 1) * dD + 1) - iD); + const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); + const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); + + const bool isPDasymm = pD != (pDsum - pD); + const bool isPHasymm = pH != (pHsum - pH); + const bool isPWasymm = pW != (pWsum - pW); + + if (!isPDasymm && !isPHasymm && !isPWasymm) return; + + std::vector newShape = input->getShapeAsVector(); + + const int iDposition = isNCDHW ? 2 : 1; + + if (isPDasymm) newShape[iDposition] += 1; + if (isPHasymm) newShape[iDposition + 1] += 1; + if (isPWasymm) newShape[iDposition + 2] += 1; + + NDArray* newInput = new NDArray(input->ordering(), newShape, + input->dataType(), input->getContext()); + + if (isNCDHW) + (*newInput)({0, 0, 0, 0, 0, input->sizeAt(2), 0, input->sizeAt(3), 0, + input->sizeAt(4)}) + .assign(input); + else + (*newInput)({0, 0, 0, input->sizeAt(1), 0, input->sizeAt(2), 0, + input->sizeAt(3), 0, 0}) + .assign(input); + + input = newInput; + + if (gradI != nullptr) + gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), + gradI->getContext()); +} - if(gradI != nullptr) - gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext()); +////////////////////////////////////////////////////////////////////////// +void pooling2dCUDNN(const LaunchContext* context, const NDArray* input, + NDArray* output, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, + const int dW, const bool isNCHW, + const cudnnPoolingMode_t mode) { + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dCUDNN: can't set stream for cuDNN", err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), + input->strideAt(indIOioC), input->strideAt(indIiH), + input->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input " + "failed", + err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if (output->ews() == 1 && output->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx( + z, cudnnDataType(output->dataType()), bS, oC, oH, oW, + output->strideAt(0), output->strideAt(indIOioC), + output->strideAt(indOoH), output->strideAt(indOoH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output " + "failed", + err); + + // description of pooling + cudnnPoolingDescriptor_t pooling; + cudnnCreatePoolingDescriptor(&pooling); + err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, + pH, pW, sH, sW); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dCUDNN: cudnnSetPooling2dDescriptor failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input}); + + // run calculation + err = cudnnPoolingForward(*handle, pooling, alpha, x, input->specialBuffer(), + beta, z, output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dCUDNN: cudnnPoolingForward failed", err); + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build( + "pooling2dCUDNN: cudaStreamSynchronize failed !", cudaErr); + + NDArray::registerSpecialUse({output}, {input}); } ////////////////////////////////////////////////////////////////////////// -void pooling2dCUDNN(const LaunchContext* context, - const NDArray* input, NDArray* output, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const bool isNCHW, const cudnnPoolingMode_t mode) { - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("pooling2dCUDNN: can't set stream for cuDNN", err); - - cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("pooling2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); - - // output descriptor - cudnnTensorDescriptor_t z; - cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1 && output->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); - else - err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); - if (err != 0) throw sd::cuda_exception::build("pooling2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output failed", err); - - // description of pooling - cudnnPoolingDescriptor_t pooling; - cudnnCreatePoolingDescriptor(&pooling); - err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, pH, pW, sH, sW); - if (err != 0) throw sd::cuda_exception::build("pooling2dCUDNN: cudnnSetPooling2dDescriptor failed", err); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({output}, {input}); - - // run calculation - err = cudnnPoolingForward(*handle, pooling, alpha, x, input->specialBuffer(), beta, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("pooling2dCUDNN: cudnnPoolingForward failed", err); - - auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - if (cudaErr != 0) - throw cuda_exception::build("pooling2dCUDNN: cudaStreamSynchronize failed !", cudaErr); - - NDArray::registerSpecialUse({output}, {input}); +void pooling2dBpCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* gradO, NDArray* gradI, const int kH, + const int kW, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const bool isNCHW, const cudnnPoolingMode_t mode) { + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dBpCUDNN: can't set stream for cuDNN", err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input and gradI descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), + input->strideAt(indIOioC), input->strideAt(indIiH), + input->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for " + "input/gradI failed", + err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if (gradO->ews() == 1 && gradO->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx( + dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, + gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), + gradO->strideAt(indOoH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO " + "failed", + err); + + // description of pooling + cudnnPoolingDescriptor_t pooling; + cudnnCreatePoolingDescriptor(&pooling); + err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, + pH, pW, sH, sW); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dBpCUDNN: cudnnSetPooling2dDescriptor failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI}, {input, gradO}); + + // run calculation for gradI + err = cudnnPoolingBackward(*handle, pooling, alpha, dz, + gradO->specialBuffer(), dz, gradO->specialBuffer(), + x, input->specialBuffer(), beta, x, + gradI->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dBpCUDNN: cudnnPoolingBackward failed", err); + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build( + "pooling2dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); + + NDArray::registerSpecialUse({gradI}, {input, gradO}); } ////////////////////////////////////////////////////////////////////////// -void pooling2dBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* gradO, - NDArray* gradI, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const bool isNCHW, const cudnnPoolingMode_t mode) { - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("pooling2dBpCUDNN: can't set stream for cuDNN", err); - - cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - - // input and gradI descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("pooling2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input/gradI failed", err); - - // gradO descriptor - cudnnTensorDescriptor_t dz; - cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1 && gradO->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); - else - err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); - if (err != 0) throw sd::cuda_exception::build("pooling2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO failed", err); - - // description of pooling - cudnnPoolingDescriptor_t pooling; - cudnnCreatePoolingDescriptor(&pooling); - err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, pH, pW, sH, sW); - if (err != 0) throw sd::cuda_exception::build("pooling2dBpCUDNN: cudnnSetPooling2dDescriptor failed", err); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); +void pooling3dCUDNN(const LaunchContext* context, const NDArray* input, + NDArray* output, const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, const int dH, + const int dW, const bool isNCDHW, + const cudnnPoolingMode_t mode) { + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dCUDNN: can't set stream for cuDNN", err); + + const int numDims = 5; + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + const int pSizes[] = {pD, pH, pW}; + const int sSizes[] = {sD, sH, sW}; + const int kSizes[] = {kD, kH, kW}; + + const int xShape[] = {bS, iC, iD, iH, iW}; + const int zShape[] = {bS, oC, oD, oH, oW}; + + const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), + (int)input->strideAt(2), (int)input->strideAt(3), + (int)input->strideAt(4)}; + const int zStrides[] = {(int)output->strideAt(0), (int)output->strideAt(1), + (int)output->strideAt(2), (int)output->strideAt(3), + (int)output->strideAt(4)}; + + cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensorNdDescriptorEx( + x, format, cudnnDataType(input->dataType()), numDims, xShape); + else + err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), + numDims, xShape, xStrides); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input " + "failed", + err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if (output->ews() == 1 && output->ordering() == 'c') + err = cudnnSetTensorNdDescriptorEx( + z, format, cudnnDataType(output->dataType()), numDims, zShape); + else + err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), + numDims, zShape, zStrides); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output " + "failed", + err); + + // description of pooling + cudnnPoolingDescriptor_t pooling; + cudnnCreatePoolingDescriptor(&pooling); + err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, + numDims - 2, kSizes, pSizes, sSizes); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dCUDNN: cudnnSetPoolingNdDescriptor failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input}); + + // run calculation + err = cudnnPoolingForward(*handle, pooling, alpha, x, input->specialBuffer(), + beta, z, output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dCUDNN: cudnnPoolingForward failed", err); + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build( + "pooling3dCUDNN: cudaStreamSynchronize failed !", cudaErr); + + NDArray::registerSpecialUse({output}, {input}); +} +////////////////////////////////////////////////////////////////////////// +void pooling3dBpCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* gradO, NDArray* gradI, const int kD, + const int kH, const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW, const cudnnPoolingMode_t mode) { + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dBpCUDNN: can't set stream for cuDNN", err); + + const int numDims = 5; + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + const int pSizes[] = {pD, pH, pW}; + const int sSizes[] = {sD, sH, sW}; + const int kSizes[] = {kD, kH, kW}; + + const int xShape[] = {bS, iC, iD, iH, iW}; + const int dzShape[] = {bS, oC, oD, oH, oW}; + + const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), + (int)input->strideAt(2), (int)input->strideAt(3), + (int)input->strideAt(4)}; + const int dzStrides[] = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), + (int)gradO->strideAt(2), (int)gradO->strideAt(3), + (int)gradO->strideAt(4)}; + + cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input and gradI descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensorNdDescriptorEx( + x, format, cudnnDataType(input->dataType()), numDims, xShape); + else + err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), + numDims, xShape, xStrides); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for " + "input/gradI failed", + err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if (gradO->ews() == 1 && gradO->ordering() == 'c') + err = cudnnSetTensorNdDescriptorEx( + dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape); + else + err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), + numDims, dzShape, dzStrides); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dBpCUDNN: " + "cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO " + "failed", + err); + + // description of pooling + cudnnPoolingDescriptor_t pooling; + cudnnCreatePoolingDescriptor(&pooling); + err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, + numDims - 2, kSizes, pSizes, sSizes); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dBpCUDNN: cudnnSetPoolingNdDescriptor failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + // cudnn maxpool2d_bp api requires ff output as one of input arguments + if (mode == CUDNN_POOLING_MAX) { + NDArray temp(gradO); + + NDArray::prepareSpecialUse({gradI}, {input, gradO, &temp}); + + // run ff calculation + err = + cudnnPoolingForward(*handle, pooling, alpha, x, input->specialBuffer(), + beta, dz, temp.specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling3dCUDNN: cudnnPoolingForward failed", err); + + // run bp calculation for gradI + err = cudnnPoolingBackward(*handle, pooling, alpha, dz, + temp.specialBuffer(), dz, gradO->specialBuffer(), + x, input->specialBuffer(), beta, x, + gradI->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dBpCUDNN: cudnnPoolingBackward failed", err); + + NDArray::registerSpecialUse({gradI}, {input, gradO, &temp}); + } else { NDArray::prepareSpecialUse({gradI}, {input, gradO}); - // run calculation for gradI - err = cudnnPoolingBackward(*handle, pooling, alpha, dz, gradO->specialBuffer(), dz, gradO->specialBuffer(), x, input->specialBuffer(), beta, x, gradI->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err); - - auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - if (cudaErr != 0) - throw cuda_exception::build("pooling2dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); + // run bp calculation for gradI + err = cudnnPoolingBackward( + *handle, pooling, alpha, dz, gradO->specialBuffer(), dz, + gradO->specialBuffer(), x, input->specialBuffer(), beta, x, + gradI->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "pooling2dBpCUDNN: cudnnPoolingBackward failed", err); NDArray::registerSpecialUse({gradI}, {input, gradO}); -} + } -////////////////////////////////////////////////////////////////////////// -void pooling3dCUDNN(const LaunchContext* context, - const NDArray* input, NDArray* output, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW, const cudnnPoolingMode_t mode) { - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("pooling3dCUDNN: can't set stream for cuDNN", err); - - const int numDims = 5; - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - const int pSizes[] = {pD, pH, pW}; - const int sSizes[] = {sD, sH, sW}; - const int kSizes[] = {kD, kH, kW}; - - const int xShape[] = {bS, iC, iD, iH, iW}; - const int zShape[] = {bS, oC, oD, oH, oW}; - - const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; - const int zStrides[] = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)}; - - cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape); - else - err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides); - if (err != 0) throw sd::cuda_exception::build("pooling3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); - - // output descriptor - cudnnTensorDescriptor_t z; - cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1 && output->ordering() == 'c') - err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape); - else - err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape, zStrides); - if (err != 0) throw sd::cuda_exception::build("pooling3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output failed", err); - - // description of pooling - cudnnPoolingDescriptor_t pooling; - cudnnCreatePoolingDescriptor(&pooling); - err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, numDims - 2, kSizes, pSizes, sSizes); - if (err != 0) throw sd::cuda_exception::build("pooling3dCUDNN: cudnnSetPoolingNdDescriptor failed", err); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({output}, {input}); - - // run calculation - err = cudnnPoolingForward(*handle, pooling, alpha, x, input->specialBuffer(), beta, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("pooling3dCUDNN: cudnnPoolingForward failed", err); - - auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - if (cudaErr != 0) - throw cuda_exception::build("pooling3dCUDNN: cudaStreamSynchronize failed !", cudaErr); - - NDArray::registerSpecialUse({output}, {input}); + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build( + "pooling3dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); } -////////////////////////////////////////////////////////////////////////// -void pooling3dBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* gradO, - NDArray* gradI, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW, const cudnnPoolingMode_t mode) { - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("pooling3dBpCUDNN: can't set stream for cuDNN", err); - - const int numDims = 5; - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - const int pSizes[] = {pD, pH, pW}; - const int sSizes[] = {sD, sH, sW}; - const int kSizes[] = {kD, kH, kW}; - - const int xShape[] = {bS, iC, iD, iH, iW}; - const int dzShape[] = {bS, oC, oD, oH, oW}; - - const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; - const int dzStrides[] = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)}; - - cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - - // input and gradI descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape); - else - err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides); - if (err != 0) throw sd::cuda_exception::build("pooling3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input/gradI failed", err); - - // gradO descriptor - cudnnTensorDescriptor_t dz; - cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1 && gradO->ordering() == 'c') - err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape); - else - err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape, dzStrides); - if (err != 0) throw sd::cuda_exception::build("pooling3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err); - - // description of pooling - cudnnPoolingDescriptor_t pooling; - cudnnCreatePoolingDescriptor(&pooling); - err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, numDims - 2, kSizes, pSizes, sSizes); - if (err != 0) throw sd::cuda_exception::build("pooling3dBpCUDNN: cudnnSetPoolingNdDescriptor failed", err); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - // cudnn maxpool2d_bp api requires ff output as one of input arguments - if(mode == CUDNN_POOLING_MAX) { - - NDArray temp(gradO); - - NDArray::prepareSpecialUse({gradI}, {input, gradO, &temp}); - - // run ff calculation - err = cudnnPoolingForward(*handle, pooling, alpha, x, input->specialBuffer(), beta, dz, temp.specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("pooling3dCUDNN: cudnnPoolingForward failed", err); - - // run bp calculation for gradI - err = cudnnPoolingBackward(*handle, pooling, alpha, dz, temp.specialBuffer(), dz, gradO->specialBuffer(), x, input->specialBuffer(), beta, x, gradI->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err); - - NDArray::registerSpecialUse({gradI}, {input, gradO, &temp}); - } - else { - - NDArray::prepareSpecialUse({gradI}, {input, gradO}); - - // run bp calculation for gradI - err = cudnnPoolingBackward(*handle, pooling, alpha, dz, gradO->specialBuffer(), dz, gradO->specialBuffer(), x, input->specialBuffer(), beta, x, gradI->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err); - - NDArray::registerSpecialUse({gradI}, {input, gradO}); - } - - auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - if (cudaErr != 0) - throw cuda_exception::build("pooling3dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); -} - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h index 3379979a32b9..9d015abaec36 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h @@ -21,121 +21,110 @@ #ifndef SD_CUDNNUTILS_H #define SD_CUDNNUTILS_H -#include -#include -#include +#include #include #include +#include +#include #include +#include -#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - DECLARE_PLATFORM(conv2d, ENGINE_CUDA); - DECLARE_PLATFORM(conv2d_bp, ENGINE_CUDA); +DECLARE_PLATFORM(conv2d, ENGINE_CUDA); +DECLARE_PLATFORM(conv2d_bp, ENGINE_CUDA); - DECLARE_PLATFORM(conv3dnew, ENGINE_CUDA); - DECLARE_PLATFORM(conv3dnew_bp, ENGINE_CUDA); +DECLARE_PLATFORM(conv3dnew, ENGINE_CUDA); +DECLARE_PLATFORM(conv3dnew_bp, ENGINE_CUDA); - DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CUDA); - DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CUDA); +DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CUDA); +DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CUDA); - DECLARE_PLATFORM(batchnorm, ENGINE_CUDA); - DECLARE_PLATFORM(batchnorm_bp, ENGINE_CUDA); +DECLARE_PLATFORM(batchnorm, ENGINE_CUDA); +DECLARE_PLATFORM(batchnorm_bp, ENGINE_CUDA); - DECLARE_PLATFORM(avgpool2d, ENGINE_CUDA); - DECLARE_PLATFORM(avgpool2d_bp, ENGINE_CUDA); +DECLARE_PLATFORM(avgpool2d, ENGINE_CUDA); +DECLARE_PLATFORM(avgpool2d_bp, ENGINE_CUDA); - DECLARE_PLATFORM(maxpool2d, ENGINE_CUDA); - DECLARE_PLATFORM(maxpool2d_bp, ENGINE_CUDA); +DECLARE_PLATFORM(maxpool2d, ENGINE_CUDA); +DECLARE_PLATFORM(maxpool2d_bp, ENGINE_CUDA); - DECLARE_PLATFORM(avgpool3dnew, ENGINE_CUDA); - DECLARE_PLATFORM(avgpool3dnew_bp, ENGINE_CUDA); +DECLARE_PLATFORM(avgpool3dnew, ENGINE_CUDA); +DECLARE_PLATFORM(avgpool3dnew_bp, ENGINE_CUDA); - DECLARE_PLATFORM(maxpool3dnew, ENGINE_CUDA); - DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CUDA); +DECLARE_PLATFORM(maxpool3dnew, ENGINE_CUDA); +DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CUDA); ////////////////////////////////////////////////////////////////////////// FORCEINLINE cudnnDataType_t cudnnDataType(sd::DataType dataType) { - switch (dataType) { - case sd::DataType::FLOAT32: - return CUDNN_DATA_FLOAT; - case sd::DataType::DOUBLE: - return CUDNN_DATA_DOUBLE; - case sd::DataType::HALF: - return CUDNN_DATA_HALF; - case sd::DataType::INT32: - return CUDNN_DATA_INT32; - case sd::DataType::INT8: - return CUDNN_DATA_INT8; - default: - throw datatype_exception::build("Unsupported data type", dataType); - } + switch (dataType) { + case sd::DataType::FLOAT32: + return CUDNN_DATA_FLOAT; + case sd::DataType::DOUBLE: + return CUDNN_DATA_DOUBLE; + case sd::DataType::HALF: + return CUDNN_DATA_HALF; + case sd::DataType::INT32: + return CUDNN_DATA_INT32; + case sd::DataType::INT8: + return CUDNN_DATA_INT8; + default: + throw datatype_exception::build("Unsupported data type", dataType); + } } ////////////////////////////////////////////////////////////////////////// -void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, - const int iH, const int iW, - const int oH, const int oW, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const bool isNCHW); +void checkConv2dCUDNNPadAsymmetric(NDArray*& input, NDArray*& gradI, + const int iH, const int iW, const int oH, + const int oW, const int kH, const int kW, + const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const bool isNCHW); ////////////////////////////////////////////////////////////////////////// -void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, - const int iD, const int iH, const int iW, - const int oD, const int oH, const int oW, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW); +void checkConv3dCUDNNPadAsymmetric(NDArray*& input, NDArray*& gradI, + const int iD, const int iH, const int iW, + const int oD, const int oH, const int oW, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW); ////////////////////////////////////////////////////////////////////////// -void pooling2dCUDNN(const LaunchContext* context, - const NDArray* input, NDArray* output, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const bool isNCHW, const cudnnPoolingMode_t mode); +void pooling2dCUDNN(const LaunchContext* context, const NDArray* input, + NDArray* output, const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, const int dH, + const int dW, const bool isNCHW, + const cudnnPoolingMode_t mode); ////////////////////////////////////////////////////////////////////////// -void pooling2dBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* gradO, - NDArray* gradI, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, +void pooling2dBpCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* gradO, NDArray* gradI, const int kH, + const int kW, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, const bool isNCHW, const cudnnPoolingMode_t mode); ////////////////////////////////////////////////////////////////////////// -void pooling3dCUDNN(const LaunchContext* context, - const NDArray* input, NDArray* output, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW, const cudnnPoolingMode_t mode); +void pooling3dCUDNN(const LaunchContext* context, const NDArray* input, + NDArray* output, const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, const int dH, + const int dW, const bool isNCDHW, + const cudnnPoolingMode_t mode); ////////////////////////////////////////////////////////////////////////// -void pooling3dBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* gradO, - NDArray* gradI, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW, const cudnnPoolingMode_t mode); - -} -} -} - -#endif //SD_CUDNNUTILS_H +void pooling3dBpCUDNN(const LaunchContext* context, const NDArray* input, + const NDArray* gradO, NDArray* gradI, const int kD, + const int kH, const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW, const cudnnPoolingMode_t mode); + +} // namespace platforms +} // namespace ops +} // namespace sd + +#endif // SD_CUDNNUTILS_H diff --git a/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu index c268961ce840..512b51efad5e 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu @@ -18,453 +18,755 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - ////////////////////////////////////////////////////////////////////////// static void depthwiseConv2dCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const int paddingMode, const bool isNCHW) { - - // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) - - // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc - // weights [iC, mC, kH, kW] - // bias [oC], may be nullptr - // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc - // oC = iC*mC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(1); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: can't set stream for cuDNN", err); - - cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); - - // weights descriptor - cudnnFilterDescriptor_t w; - cudnnCreateFilterDescriptor(&w); - err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), CUDNN_TENSOR_NCHW, iC, mC, kH, kW); - if(err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetFilter4dDescriptor failed", err); - - // output descriptor - cudnnTensorDescriptor_t z; - cudnnCreateTensorDescriptor(&z); - if(output->ews() == 1 && output->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); - else - err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output failed", err); - - // description of convolution - cudnnConvolutionDescriptor_t conv; - cudnnCreateConvolutionDescriptor(&conv); - err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, CUDNN_CROSS_CORRELATION, cudnnDataType(output->dataType())); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetConvolution2dDescriptor failed", err); - err = cudnnSetConvolutionGroupCount(conv, iC); // set number of groups (depthwise mode) in description of convolution, groupCount == iC - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetConvolutionGroupCount failed", err); - - // algorithm description - cudnnConvolutionFwdAlgo_t algo; - err = cudnnGetConvolutionForwardAlgorithm(*handle, x, w, conv, z, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnGetConvolutionForwardAlgorithm failed", err); - - // allocate auxiliary device memory, abbreviation ws means workspace - size_t wsSize; - err = cudnnGetConvolutionForwardWorkspaceSize(*handle, x, w, conv, z, algo, &wsSize); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnGetConvolutionForwardWorkspaceSize failed", err); - void* wsData; - auto cudaErr = cudaMalloc(&wsData, wsSize); - if (cudaErr != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudaMalloc for auxiliary workspace memory failed", cudaErr); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({output}, {input, weights, bias}); - - // run calculation - err = cudnnConvolutionForward(*handle, alpha, x, input->specialBuffer(), w, weights->specialBuffer(), conv, algo, wsData, wsSize, beta, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnConvolutionForward failed", err); - - // add bias if it is present - if (bias != nullptr) { - - cudnnTensorDescriptor_t b; - cudnnCreateTensorDescriptor(&b); - // err = cudnnSetTensor4dDescriptor(b, format, cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, isNCHW ? 1: bias->lengthOf()); - err = cudnnSetTensor4dDescriptor(b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", err); - err = cudnnAddTensor(*handle, alpha, b, bias->specialBuffer(), alpha, z, output->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudnnAddTensor bias failed", err); - } - - // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - // if (cudaErr != 0) - // throw cuda_exception::build("depthwiseConv2dCUDNN: cudaStreamSynchronize failed !", cudaErr); - - cudaErr = cudaFree(wsData); - if (cudaErr != 0) throw sd::cuda_exception::build("depthwiseConv2dCUDNN: cudaFree for auxiliary workspace memory failed", cudaErr); - - NDArray::registerSpecialUse({output}, {input, weights, bias}); + const NDArray* input, const NDArray* weights, + const NDArray* bias, NDArray* output, + const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW, + const int paddingMode, const bool isNCHW) { + // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) + + // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc + // weights [iC, mC, kH, kW] + // bias [oC], may be nullptr + // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc + // oC = iC*mC + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(1); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: can't set stream for cuDNN", err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), + input->strideAt(indIOioC), input->strideAt(indIiH), + input->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input " + "failed", + err); + + // weights descriptor + cudnnFilterDescriptor_t w; + cudnnCreateFilterDescriptor(&w); + err = cudnnSetFilter4dDescriptor(w, cudnnDataType(weights->dataType()), + CUDNN_TENSOR_NCHW, iC, mC, kH, kW); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnSetFilter4dDescriptor failed", err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if (output->ews() == 1 && output->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx( + z, cudnnDataType(output->dataType()), bS, oC, oH, oW, + output->strideAt(0), output->strideAt(indIOioC), + output->strideAt(indOoH), output->strideAt(indOoH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output " + "failed", + err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, + CUDNN_CROSS_CORRELATION, + cudnnDataType(output->dataType())); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnSetConvolution2dDescriptor failed", err); + err = cudnnSetConvolutionGroupCount( + conv, iC); // set number of groups (depthwise mode) in description of + // convolution, groupCount == iC + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnSetConvolutionGroupCount failed", err); + + // algorithm description + cudnnConvolutionFwdAlgo_t algo; + err = cudnnGetConvolutionForwardAlgorithm( + *handle, x, w, conv, z, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0, &algo); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnGetConvolutionForwardAlgorithm failed", + err); + + // allocate auxiliary device memory, abbreviation ws means workspace + size_t wsSize; + err = cudnnGetConvolutionForwardWorkspaceSize(*handle, x, w, conv, z, algo, + &wsSize); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnGetConvolutionForwardWorkspaceSize failed", + err); + void* wsData; + auto cudaErr = cudaMalloc(&wsData, wsSize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudaMalloc for auxiliary workspace memory " + "failed", + cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input, weights, bias}); + + // run calculation + err = cudnnConvolutionForward(*handle, alpha, x, input->specialBuffer(), w, + weights->specialBuffer(), conv, algo, wsData, + wsSize, beta, z, output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnConvolutionForward failed", err); + + // add bias if it is present + if (bias != nullptr) { + cudnnTensorDescriptor_t b; + cudnnCreateTensorDescriptor(&b); + // err = cudnnSetTensor4dDescriptor(b, format, + // cudnnDataType(bias->dataType()), 1, isNCHW ? bias->lengthOf() : 1, 1, + // isNCHW ? 1: bias->lengthOf()); + err = cudnnSetTensor4dDescriptor( + b, CUDNN_TENSOR_NCHW, cudnnDataType(bias->dataType()), 1, oC, 1, 1); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnSetTensor4dDescriptor for bias failed", + err); + err = cudnnAddTensor(*handle, alpha, b, bias->specialBuffer(), alpha, z, + output->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudnnAddTensor bias failed", err); + } + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("depthwiseConv2dCUDNN: + // cudaStreamSynchronize failed !", cudaErr); + + cudaErr = cudaFree(wsData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dCUDNN: cudaFree for auxiliary workspace memory failed", + cudaErr); + + NDArray::registerSpecialUse({output}, {input, weights, bias}); } ////////////////////////////////////////////////////////////////////////// static void depthwiseConv2dBpCUDNN(const LaunchContext* context, - const NDArray* input, const NDArray* weights, const NDArray* gradO, - NDArray* gradI, NDArray* gradW, NDArray* gradB, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const int paddingMode, const bool isNCHW) { - - // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) - - // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc - // weights, gradW [iC, mC, kH, kW] - // gradB [oC], may be nullptr - // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc - // oC = iC*mC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(1); - - auto handle = reinterpret_cast(context->getCuDnnHandle()); - cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: can't set stream for cuDNN", err); - - cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; - - // input descriptor - cudnnTensorDescriptor_t x; - cudnnCreateTensorDescriptor(&x); - if(input->ews() == 1 && input->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); - - // gradO descriptor - cudnnTensorDescriptor_t dz; - cudnnCreateTensorDescriptor(&dz); - if(gradO->ews() == 1 && gradO->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); - else - err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO failed", err); - - // gradI descriptor - cudnnTensorDescriptor_t dx; - cudnnCreateTensorDescriptor(&dx); - if(gradI->ews() == 1 && gradI->ordering() == 'c') - err = cudnnSetTensor4dDescriptor(dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); - else - err = cudnnSetTensor4dDescriptorEx(dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), gradI->strideAt(indIiH + 1)); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradI failed", err); - - // gradW descriptor - cudnnFilterDescriptor_t dw; - cudnnCreateFilterDescriptor(&dw); - err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), CUDNN_TENSOR_NCHW, iC, mC, kH, kW); - if(err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW failed", err); - - // description of convolution - cudnnConvolutionDescriptor_t conv; - cudnnCreateConvolutionDescriptor(&conv); - err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, CUDNN_CROSS_CORRELATION, cudnnDataType(gradO->dataType())); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetConvolution2dDescriptor failed", err); - err = cudnnSetConvolutionGroupCount(conv, iC); // set number of groups (depthwise mode) in description of convolution, groupCount == iC - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetConvolutionGroupCount failed", err); - - // gradW algorithm description - cudnnConvolutionBwdFilterAlgo_t algoGradW; - err = cudnnGetConvolutionBackwardFilterAlgorithm(*handle, x, dz, conv, dw, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, &algoGradW); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardFilterAlgorithm failed", err); - - // gradI algorithm description - cudnnConvolutionBwdDataAlgo_t algoGradI; - err = cudnnGetConvolutionBackwardDataAlgorithm(*handle, dw, dz, conv, x, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, &algoGradI); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardDataAlgorithm failed", err); - - // allocate auxiliary device memory for gradW calculation, abbreviation ws means workspace - size_t wsGradWSize; - err = cudnnGetConvolutionBackwardFilterWorkspaceSize(*handle, x, dz, conv, dw, algoGradW, &wsGradWSize); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardFilterWorkspaceSize failed", err); - void* wsGradWData; - auto cudaErr = cudaMalloc(&wsGradWData, wsGradWSize); - if (cudaErr != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradWData failed", cudaErr); - - // allocate auxiliary device memory for gradI calculation, abbreviation ws means workspace - size_t wsGradISize; - err = cudnnGetConvolutionBackwardDataWorkspaceSize(*handle, dw, dz, conv, dx, algoGradI, &wsGradISize); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardDataWorkspaceSize failed", err); - void* wsGradIData; - cudaErr = cudaMalloc(&wsGradIData, wsGradISize); - if (cudaErr != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudaMalloc for auxiliary workspace memory wsGradIData failed", cudaErr); - - // provide scaling parameters - const float alpha32(1), beta32(0); - const double alpha64(1), beta64(0); - const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); - const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); - - NDArray::prepareSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); - - // run calculation for gradB (if not nullptr) - if(gradB != nullptr) { - cudnnTensorDescriptor_t db; - cudnnCreateTensorDescriptor(&db); - // err = cudnnSetTensor4dDescriptor(db, format, cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, isNCHW ? 1: gradB->lengthOf()); - err = cudnnSetTensor4dDescriptor(db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", err); - - err = cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->specialBuffer(), beta, db, gradB->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnConvolutionBackwardBias failed", err); - } - - // run calculation for gradW - err = cudnnConvolutionBackwardFilter(*handle, alpha, x, input->specialBuffer(), dz, gradO->specialBuffer(), conv, algoGradW, wsGradWData, wsGradWSize, beta, dw, gradW->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnConvolutionBackwardFilter failed", err); - - // run calculation for gradI - err = cudnnConvolutionBackwardData(*handle, alpha, dw, weights->specialBuffer(), dz, gradO->specialBuffer(), conv, algoGradI, wsGradIData, wsGradISize, beta, dx, gradI->specialBuffer()); - if (err != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudnnConvolutionBackwardData failed", err); - - // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - // if (cudaErr != 0) - // throw cuda_exception::build("depthwiseConv2dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); - - cudaErr = cudaFree(wsGradWData); - if (cudaErr != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradWData failed", cudaErr); - cudaErr = cudaFree(wsGradIData); - if (cudaErr != 0) throw sd::cuda_exception::build("depthwiseConv2dBpCUDNN: cudaFree for auxiliary workspace memory wsGradIData failed", cudaErr); - - NDArray::registerSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); + const NDArray* input, const NDArray* weights, + const NDArray* gradO, NDArray* gradI, + NDArray* gradW, NDArray* gradB, const int kH, + const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, + const int dW, const int paddingMode, + const bool isNCHW) { + // cudnn supports only following case: mC = 1, oC = iC (groupCount == iC) + + // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc + // weights, gradW [iC, mC, kH, kW] + // gradB [oC], may be nullptr + // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc + // oC = iC*mC + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(1); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: can't set stream for cuDNN", err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if (input->ews() == 1 && input->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), + input->strideAt(indIOioC), input->strideAt(indIiH), + input->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input " + "failed", + err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if (gradO->ews() == 1 && gradO->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx( + dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, + gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), + gradO->strideAt(indOoH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO " + "failed", + err); + + // gradI descriptor + cudnnTensorDescriptor_t dx; + cudnnCreateTensorDescriptor(&dx); + if (gradI->ews() == 1 && gradI->ordering() == 'c') + err = cudnnSetTensor4dDescriptor( + dx, format, cudnnDataType(gradI->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx( + dx, cudnnDataType(gradI->dataType()), bS, iC, iH, iW, + gradI->strideAt(0), gradI->strideAt(indIOioC), gradI->strideAt(indIiH), + gradI->strideAt(indIiH + 1)); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: " + "cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradI " + "failed", + err); + + // gradW descriptor + cudnnFilterDescriptor_t dw; + cudnnCreateFilterDescriptor(&dw); + err = cudnnSetFilter4dDescriptor(dw, cudnnDataType(gradW->dataType()), + CUDNN_TENSOR_NCHW, iC, mC, kH, kW); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnSetFilter4dDescriptor gradW failed", err); + + // description of convolution + cudnnConvolutionDescriptor_t conv; + cudnnCreateConvolutionDescriptor(&conv); + err = cudnnSetConvolution2dDescriptor(conv, pH, pW, sH, sW, dH, dW, + CUDNN_CROSS_CORRELATION, + cudnnDataType(gradO->dataType())); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnSetConvolution2dDescriptor failed", err); + err = cudnnSetConvolutionGroupCount( + conv, iC); // set number of groups (depthwise mode) in description of + // convolution, groupCount == iC + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnSetConvolutionGroupCount failed", err); + + // gradW algorithm description + cudnnConvolutionBwdFilterAlgo_t algoGradW; + err = cudnnGetConvolutionBackwardFilterAlgorithm( + *handle, x, dz, conv, dw, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, + &algoGradW); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardFilterAlgorithm " + "failed", + err); + + // gradI algorithm description + cudnnConvolutionBwdDataAlgo_t algoGradI; + err = cudnnGetConvolutionBackwardDataAlgorithm( + *handle, dw, dz, conv, x, CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, 0, + &algoGradI); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardDataAlgorithm " + "failed", + err); + + // allocate auxiliary device memory for gradW calculation, abbreviation ws + // means workspace + size_t wsGradWSize; + err = cudnnGetConvolutionBackwardFilterWorkspaceSize(*handle, x, dz, conv, dw, + algoGradW, &wsGradWSize); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: " + "cudnnGetConvolutionBackwardFilterWorkspaceSize failed", + err); + void* wsGradWData; + auto cudaErr = cudaMalloc(&wsGradWData, wsGradWSize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudaMalloc for auxiliary workspace memory " + "wsGradWData failed", + cudaErr); + + // allocate auxiliary device memory for gradI calculation, abbreviation ws + // means workspace + size_t wsGradISize; + err = cudnnGetConvolutionBackwardDataWorkspaceSize(*handle, dw, dz, conv, dx, + algoGradI, &wsGradISize); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnGetConvolutionBackwardDataWorkspaceSize " + "failed", + err); + void* wsGradIData; + cudaErr = cudaMalloc(&wsGradIData, wsGradISize); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudaMalloc for auxiliary workspace memory " + "wsGradIData failed", + cudaErr); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&alpha32) + : reinterpret_cast(&alpha64); + const void* beta = gradO->sizeOfT() <= 4 + ? reinterpret_cast(&beta32) + : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); + + // run calculation for gradB (if not nullptr) + if (gradB != nullptr) { + cudnnTensorDescriptor_t db; + cudnnCreateTensorDescriptor(&db); + // err = cudnnSetTensor4dDescriptor(db, format, + // cudnnDataType(gradB->dataType()), 1, isNCHW ? gradB->lengthOf() : 1, 1, + // isNCHW ? 1: gradB->lengthOf()); + err = cudnnSetTensor4dDescriptor( + db, CUDNN_TENSOR_NCHW, cudnnDataType(gradB->dataType()), 1, oC, 1, 1); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnSetTensor4dDescriptor for gradB failed", + err); + + err = + cudnnConvolutionBackwardBias(*handle, alpha, dz, gradO->specialBuffer(), + beta, db, gradB->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnConvolutionBackwardBias failed", err); + } + + // run calculation for gradW + err = cudnnConvolutionBackwardFilter( + *handle, alpha, x, input->specialBuffer(), dz, gradO->specialBuffer(), + conv, algoGradW, wsGradWData, wsGradWSize, beta, dw, + gradW->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnConvolutionBackwardFilter failed", err); + + // run calculation for gradI + err = cudnnConvolutionBackwardData( + *handle, alpha, dw, weights->specialBuffer(), dz, gradO->specialBuffer(), + conv, algoGradI, wsGradIData, wsGradISize, beta, dx, + gradI->specialBuffer()); + if (err != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudnnConvolutionBackwardData failed", err); + + // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + // if (cudaErr != 0) + // throw cuda_exception::build("depthwiseConv2dBpCUDNN: + // cudaStreamSynchronize failed !", cudaErr); + + cudaErr = cudaFree(wsGradWData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudaFree for auxiliary workspace memory " + "wsGradWData failed", + cudaErr); + cudaErr = cudaFree(wsGradIData); + if (cudaErr != 0) + throw sd::cuda_exception::build( + "depthwiseConv2dBpCUDNN: cudaFree for auxiliary workspace memory " + "wsGradIData failed", + cudaErr); + + NDArray::registerSpecialUse({gradI, gradW, gradB}, {input, weights, gradO}); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC - - auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - - REQUIRE_TRUE(input->rankOf() == 4, 0, "DEPTHWISECONV2D CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "DEPTHWISECONV2D CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "DEPTHWISECONV2D CUDNN OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - std::vector wPermut; // cudnn support format {oC, iC/groupCount, kH, kW} only, mC = 1, oC = iC (groupCount == iC) that is {iC, mC, kH, kW} in our case - if(0 == wFormat) - wPermut = {2,3,0,1}; // kH, kW, iC, mC -> iC, mC, kH, kW - else if(1 == wFormat) - wPermut = {1,0,2,3}; // mC, iC, kH, kW -> iC, mC, kH, kW - else - wPermut = {3,0,1,2}; // mC, kH, kW, iC -> iC, mC, kH, kW - - NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); - newWeights->assign(weights->permute(wPermut)); - - NDArray* newInput = input; - NDArray* newGradI = nullptr; - if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings - checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); - - depthwiseConv2dCUDNN(block.launchContext(), newInput, newWeights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW, paddingMode, isNCHW); - - if(newInput != input) - delete newInput; - - delete newWeights; - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC + + auto output = OUTPUT_VARIABLE( + 0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "DEPTHWISECONV2D CUDNN OP: rank of input array must be equal to " + "4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "DEPTHWISECONV2D CUDNN OP: rank of weights array must be equal " + "to 4, but got %i instead !", + weights->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = block.getIArguments()->size() > 9 + ? !INT_ARG(9) + : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = block.getIArguments()->size() > 10 + ? INT_ARG(10) + : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - + // [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "DEPTHWISECONV2D CUDNN OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + REQUIRE_TRUE(output->sizeAt(indIOioC) == iC * mC, 0, + "DEPTHWISECONV2D CUDNN OP: the output_channels must be equal to " + "input_channels * channels_multiplier = %i !", + iC * mC); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "DEPTHWISECONV2D CUDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + std::vector wPermut; // cudnn support format {oC, iC/groupCount, kH, kW} + // only, mC = 1, oC = iC (groupCount == iC) that is + // {iC, mC, kH, kW} in our case + if (0 == wFormat) + wPermut = {2, 3, 0, 1}; // kH, kW, iC, mC -> iC, mC, kH, kW + else if (1 == wFormat) + wPermut = {1, 0, 2, 3}; // mC, iC, kH, kW -> iC, mC, kH, kW + else + wPermut = {3, 0, 1, 2}; // mC, kH, kW, iC -> iC, mC, kH, kW + + NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, + weights->dataType(), weights->getContext()); + newWeights->assign(weights->permute(wPermut)); + + NDArray* newInput = input; + NDArray* newGradI = nullptr; + if (paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric + // left/right top/bottopm paddings + checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, + sH, sW, pH, pW, dH, dW, isNCHW); + + depthwiseConv2dCUDNN(block.launchContext(), newInput, newWeights, bias, + output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, + isNCHW); + + if (newInput != input) delete newInput; + + delete newWeights; + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(depthwise_conv2d, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC - - const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL - const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - const int mC = weights->sizeAt(0 == wFormat ? 3 : 0); - - const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; - const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); - - return mC == 1 && paddingMode != 2 && !badInputType && !badWeightsType && !badBiasType; + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC + + const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL + const int wFormat = block.getIArguments()->size() > 10 + ? INT_ARG(10) + : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 + // - [mC, kH, kW, iC] + + const int mC = weights->sizeAt(0 == wFormat ? 3 : 0); + + const bool badInputType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && + weights->dataType() != DataType::FLOAT32 && + weights->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr + ? false + : (bias->dataType() != DataType::DOUBLE && + bias->dataType() != DataType::FLOAT32 && + bias->dataType() != DataType::HALF); + + return mC == 1 && paddingMode != 2 && !badInputType && !badWeightsType && + !badBiasType; } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { + auto input = INPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = + block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "DEPTHWISECONV2D_BP CUDNN OP: rank of input array must be equal " + "to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "DEPTHWISECONV2D_BP CUDNN OP: rank of weights array must be " + "equal to 4, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "DEPTHWISECONV2D_BP CUDNN OP: rank of output gradients (next " + "epsilon) array must be equal to 4, but got %i instead !", + gradO->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = block.getIArguments()->size() > 9 + ? !INT_ARG(9) + : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = block.getIArguments()->size() > 10 + ? INT_ARG(10) + : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - + // [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + int trueoH, trueoW; // correct output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, paddingMode); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE( + bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + std::vector wPermut, + gradWPermut; // cudnn support format {oC, iC/groupCount, kH, kW} only, mC + // = 1, oC = iC (groupCount == iC) that is {iC, mC, kH, kW} + if (0 == wFormat) { + wPermut = {2, 3, 0, 1}; // kH, kW, iC, mC -> iC, mC, kH, kW + gradWPermut = {2, 3, 0, 1}; // iC, mC, kH, kW -> kH, kW, iC, mC + } else if (1 == wFormat) { + wPermut = {1, 0, 2, 3}; // mC, iC, kH, kW -> iC, mC, kH, kW + gradWPermut = {1, 0, 2, 3}; // iC, mC, kH, kW -> mC, iC, kH, kW + } else { + wPermut = {3, 0, 1, 2}; // mC, kH, kW, iC -> iC, mC, kH, kW + gradWPermut = {1, 2, 3, 0}; // iC, mC, kH, kW -> mC, kH, kW, iC + } + + NDArray* newGradW = new NDArray(gradW->ordering(), {iC, mC, kH, kW}, + gradW->dataType(), gradW->getContext()); + NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, + weights->dataType(), weights->getContext()); + + newWeights->assign(weights->permute(wPermut)); + + NDArray* newInput = input; + NDArray* newGradI = gradI; + if (paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric + // left/right top/bottopm paddings + checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, + sH, sW, pH, pW, dH, dW, isNCHW); + + depthwiseConv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, + newGradI, newGradW, gradB, kH, kW, sH, sW, pH, pW, dH, + dW, paddingMode, isNCHW); + + newGradW->permutei(gradWPermut); + gradW->assign(newGradW); + + if (newInput != input) { + if (isNCHW) + gradI->assign( + (*newGradI)({0, 0, 0, 0, 0, gradI->sizeAt(2), 0, gradI->sizeAt(3)})); + else + gradI->assign( + (*newGradI)({0, 0, 0, gradI->sizeAt(1), 0, gradI->sizeAt(2), 0, 0})); - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 4, 0, "DEPTHWISECONV2D_BP CUDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "DEPTHWISECONV2D_BP CUDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "DEPTHWISECONV2D_BP CUDNN OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - int trueoH, trueoW; // correct output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "DEPTHWISECONV2D_BP CUDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - std::vector wPermut, gradWPermut; // cudnn support format {oC, iC/groupCount, kH, kW} only, mC = 1, oC = iC (groupCount == iC) that is {iC, mC, kH, kW} - if(0 == wFormat) { - wPermut = {2,3,0,1}; // kH, kW, iC, mC -> iC, mC, kH, kW - gradWPermut = {2,3,0,1}; // iC, mC, kH, kW -> kH, kW, iC, mC - } - else if(1 == wFormat) { - wPermut = {1,0,2,3}; // mC, iC, kH, kW -> iC, mC, kH, kW - gradWPermut = {1,0,2,3}; // iC, mC, kH, kW -> mC, iC, kH, kW - } - else { - wPermut = {3,0,1,2}; // mC, kH, kW, iC -> iC, mC, kH, kW - gradWPermut = {1,2,3,0}; // iC, mC, kH, kW -> mC, kH, kW, iC - } - - NDArray* newGradW = new NDArray(gradW->ordering(), {iC, mC, kH, kW}, gradW->dataType(), gradW->getContext()); - NDArray* newWeights = new NDArray(weights->ordering(), {iC, mC, kH, kW}, weights->dataType(), weights->getContext()); - - newWeights->assign(weights->permute(wPermut)); - - NDArray* newInput = input; - NDArray* newGradI = gradI; - if(paddingMode == 1) // in same paddingMode cudnn doesn't support asymmetric left/right top/bottopm paddings - checkConv2dCUDNNPadAsymmetric(newInput, newGradI, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); - - depthwiseConv2dBpCUDNN(block.launchContext(), newInput, newWeights, gradO, newGradI, newGradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,paddingMode,isNCHW); - - newGradW->permutei(gradWPermut); - gradW->assign(newGradW); - - if(newInput != input) { - - if(isNCHW) - gradI->assign((*newGradI)({0,0, 0,0, 0,gradI->sizeAt(2), 0,gradI->sizeAt(3)})); - else - gradI->assign((*newGradI)({0,0, 0,gradI->sizeAt(1), 0,gradI->sizeAt(2), 0,0})); - - delete newInput; - delete newGradI; - } - - delete newWeights; - delete newGradW; - - return Status::OK(); -} - -PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL - const int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - const int wFormat = block.getIArguments()->size() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - const int mC = weights->sizeAt(0 == wFormat ? 3 : 0); + delete newInput; + delete newGradI; + } - const bool badInputType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; - const bool badWeightsType = weights->dataType() != DataType::DOUBLE && weights->dataType() != DataType::FLOAT32 && weights->dataType() != DataType::HALF; - const bool badGradOType = gradO->dataType() != DataType::DOUBLE && gradO->dataType() != DataType::FLOAT32 && gradO->dataType() != DataType::HALF; - const bool badBiasType = bias == nullptr ? false : (bias->dataType() != DataType::DOUBLE && bias->dataType() != DataType::FLOAT32 && bias->dataType() != DataType::HALF); + delete newWeights; + delete newGradW; - return mC == 1 && isNCHW && paddingMode != 2 && !badInputType && !badWeightsType && !badGradOType && !badBiasType; + return Status::OK(); } - -} -} +PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CUDA) { + auto input = INPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = + block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL + const int isNCHW = block.getIArguments()->size() > 9 + ? !INT_ARG(9) + : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + const int wFormat = block.getIArguments()->size() > 10 + ? INT_ARG(10) + : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 + // - [mC, kH, kW, iC] + + const int mC = weights->sizeAt(0 == wFormat ? 3 : 0); + + const bool badInputType = input->dataType() != DataType::DOUBLE && + input->dataType() != DataType::FLOAT32 && + input->dataType() != DataType::HALF; + const bool badWeightsType = weights->dataType() != DataType::DOUBLE && + weights->dataType() != DataType::FLOAT32 && + weights->dataType() != DataType::HALF; + const bool badGradOType = gradO->dataType() != DataType::DOUBLE && + gradO->dataType() != DataType::FLOAT32 && + gradO->dataType() != DataType::HALF; + const bool badBiasType = bias == nullptr + ? false + : (bias->dataType() != DataType::DOUBLE && + bias->dataType() != DataType::FLOAT32 && + bias->dataType() != DataType::HALF); + + return mC == 1 && isNCHW && paddingMode != 2 && !badInputType && + !badWeightsType && !badGradOType && !badBiasType; } + +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu index 5bb646f57c47..fb1682b9b2b3 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu @@ -18,115 +18,160 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool2d, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - paddingModee; - const auto kH = INT_ARG(0); - const auto kW = INT_ARG(1); - const auto sH = INT_ARG(2); - const auto sW = INT_ARG(3); - auto pH = INT_ARG(4); - auto pW = INT_ARG(5); - const auto dH = INT_ARG(6); - const auto dW = INT_ARG(7); - const auto paddingMode = static_cast(INT_ARG(8)); - const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D CUDNN op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int oH = 0; - int oW = 0; - - const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); - const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); - - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); - - if (paddingMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - paddingModee; + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto paddingMode = static_cast(INT_ARG(8)); + const int isNCHW = block.getIArguments()->size() > 10 + ? !INT_ARG(10) + : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "MAXPOOL2D CUDNN op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE( + dH != 0 && dW != 0, 0, + "MAXPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", + dH, dW); + + int oH = 0; + int oW = 0; + + const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); + const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, + iH, iW, paddingMode); + + if (paddingMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, + dH, dW, isNCHW, CUDNN_POOLING_MAX); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(maxpool2d, ENGINE_CUDA) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; - return goodType && input->dataType() == output->dataType(); + return goodType && input->dataType() == output->dataType(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool2d_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - const auto kH = INT_ARG(0); // filter(kernel) height - const auto kW = INT_ARG(1); // filter(kernel) width - const auto sH = INT_ARG(2); // strides height - const auto sW = INT_ARG(3); // strides width - auto pH = INT_ARG(4); // paddings height - auto pW = INT_ARG(5); // paddings width - const auto dH = INT_ARG(6); // dilations height - const auto dW = INT_ARG(7); // dilations width - const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - const auto isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP CUDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL2D_BP CUDNN op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto kH = INT_ARG(0); // filter(kernel) height + const auto kW = INT_ARG(1); // filter(kernel) width + const auto sH = INT_ARG(2); // strides height + const auto sW = INT_ARG(3); // strides width + auto pH = INT_ARG(4); // paddings height + auto pW = INT_ARG(5); // paddings width + const auto dH = INT_ARG(6); // dilations height + const auto dW = INT_ARG(7); // dilations width + const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + const auto isNCHW = block.getIArguments()->size() > 10 + ? !INT_ARG(10) + : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "MAXPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "MAXPOOL2D_BP CUDNN op: dilation must not be zero, but got " + "instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "MAXPOOL2D_BP CUDNN op: wrong shape of output's gradients array " + "(next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "MAXPOOL2D_BP CUDNN op: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, + pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX); + + return Status::OK(); } PLATFORM_CHECK(maxpool2d_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; - - return goodType && (input->dataType() == gradO->dataType()) - && (input->dataType() == gradI->dataType()) - && shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) && + (input->dataType() == gradI->dataType()) && + shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); } - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu index f7b9c8b50278..b9fdd58d1f58 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu @@ -18,123 +18,180 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include #include "cudnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool3dnew, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); + int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) + : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "MAXPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, " + "but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW CUDNN OP: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedOutputShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, + "MAXPOOL3DNEW CUDNN OP: wrong shape of output array, expected " + "is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedOutputShape).c_str(), + ShapeUtils::shapeAsString(output).c_str()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, + pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(maxpool3dnew, ENGINE_CUDA) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; - - return goodType && input->dataType() == output->dataType(); + return goodType && input->dataType() == output->dataType(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const int kD = INT_ARG(0); // filter(kernel) depth - const int kH = INT_ARG(1); // filter(kernel) height - const int kW = INT_ARG(2); // filter(kernel) width - const int sD = INT_ARG(3); // strides depth - const int sH = INT_ARG(4); // strides height - const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - const int dD = INT_ARG(9); // dilations depth - const int dH = INT_ARG(10); // dilations height - const int dW = INT_ARG(11); // dilations width - const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - // const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging - const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + // const int extraParam0 = INT_ARG(13); // define what divisor to use while + // averaging + const int isNCDHW = block.getIArguments()->size() > 14 + ? !INT_ARG(14) + : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "MAXPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got " + "%i instead", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + std::vector expectedGradIShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iD, iH, iW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "MAXPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array " + "(next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, + "MAXPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array " + "(epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradIShape).c_str(), + ShapeUtils::shapeAsString(gradI).c_str()); + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX); + + return Status::OK(); } PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CUDA) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; - - return goodType && (input->dataType() == gradO->dataType()) - && (input->dataType() == gradI->dataType()) - && shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || + input->dataType() == DataType::FLOAT32 || + input->dataType() == DataType::HALF || + input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) && + (input->dataType() == gradI->dataType()) && + shape::haveSameShapeAndStrides(input->shapeInfo(), gradI->shapeInfo()); } - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp index 2fc2d12c8342..fd6b6cb8be81 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp @@ -20,115 +20,151 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include using namespace dnnl; using namespace samediff; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool2d, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - const auto kH = INT_ARG(0); - const auto kW = INT_ARG(1); - const auto sH = INT_ARG(2); - const auto sW = INT_ARG(3); - auto pH = INT_ARG(4); - auto pW = INT_ARG(5); - const auto dH = INT_ARG(6); - const auto dW = INT_ARG(7); - const auto paddingMode = INT_ARG(8); - const auto extraParam0 = INT_ARG(9); - const int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - if (paddingMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; - - mkldnnUtils::poolingMKLDNN(input, output, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, mode); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto paddingMode = INT_ARG(8); + const auto extraParam0 = INT_ARG(9); + const int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "AVGPOOL2D MKLDNN op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "AVGPOOL2D MKLDNN op: dilation must not be zero, but got " + "instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + if (paddingMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding + : algorithm::pooling_avg_include_padding; + + mkldnnUtils::poolingMKLDNN(input, output, 0, kH, kW, 0, sH, sW, 0, pH, pW, + isNCHW, mode); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(avgpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int extraParam0 = INT_ARG(9); - int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; - - mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, mode); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int extraParam0 = INT_ARG(9); + int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "AVGPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "AVGPOOL2D_BP MKLDNN op: dilation must not be zero, but got " + "instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "AVGPOOL2D_BP MKLDNN op: wrong shape of output's gradients " + "array (next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding + : algorithm::pooling_avg_include_padding; + + mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0, kH, kW, 0, sH, sW, 0, pH, + pW, isNCHW, mode); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(avgpool2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } - -} -} -} \ No newline at end of file +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp index ff3199e3e206..a22f3abf6886 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp @@ -20,120 +20,155 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include using namespace dnnl; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool3dnew, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int extraParam0 = INT_ARG(13); - int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW MKLDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; - - mkldnnUtils::poolingMKLDNN(input, output, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, mode); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int extraParam0 = INT_ARG(13); + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "AVGPOOL3DNEW MKLDNN OP: rank of input array must be equal to " + "5, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "AVGPOOL3DNEW MKLDNN OP: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding + : algorithm::pooling_avg_include_padding; + + mkldnnUtils::poolingMKLDNN(input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, + isNCDHW, mode); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(avgpool3dnew, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const int kD = INT_ARG(0); // filter(kernel) depth - const int kH = INT_ARG(1); // filter(kernel) height - const int kW = INT_ARG(2); // filter(kernel) width - const int sD = INT_ARG(3); // strides depth - const int sH = INT_ARG(4); // strides height - const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - const int dD = INT_ARG(9); // dilations depth - const int dH = INT_ARG(10); // dilations height - const int dW = INT_ARG(11); // dilations width - const int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging - const int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - - REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; - - mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, mode); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + const int extraParam0 = + INT_ARG(13); // define what divisor to use while averaging + const int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "AVGPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but " + "got %i instead", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "AVGPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "AVGPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients " + "array (next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding + : algorithm::pooling_avg_include_padding; + + mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, + pH, pW, isNCDHW, mode); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } - -} -} -} \ No newline at end of file +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index ee0bb3a3012a..e768921e5cc7 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -21,439 +21,498 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include -#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - ////////////////////////////////////////////////////////////////////////// -static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, NDArray* z, - const float epsilon, const bool isNCHW) { - - // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x - - // x -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc - // mean -> 1D [c] - // variance -> 1D [c] - // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta - // z(output) - same shape as x - - const int xRank = x->rankOf(); - - // input type - dnnl::memory::data_type type = dnnl::memory::data_type::f32; - - // indicate whether gamma or/and beta are given - auto flags = dnnl::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch - if (weights != nullptr) - flags |= dnnl::normalization_flags::use_scale_shift; - - dnnl::memory::dims dims; - dnnl::memory::format_tag format; - - const int indHW = isNCHW ? 2 : 1; - const int bS = x->sizeAt(0); - const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1); - - int iD, iH, iW; - - if(xRank == 2) { - dims = {bS, iC}; - format = dnnl::memory::format_tag::nc; - } - else if(xRank == 4) { - iH = x->sizeAt(indHW); - iW = x->sizeAt(indHW + 1); - dims = {bS, iC, iH, iW}; - format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - } - else { // xRank = 5 - iD = x->sizeAt(indHW); - iH = x->sizeAt(indHW + 1); - iW = x->sizeAt(indHW + 2); - dims = {bS, iC, iD, iH, iW}; - format = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - } - - // memory descriptors for arrays - - // x - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); - dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); - - mkldnnUtils::setBlockStrides(x, x_user_md); - // z, output - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format); - - mkldnnUtils::setBlockStrides(z, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // batchnorm forward description - dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags); - dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory and check whether reorder is required - - // x - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // z - auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer()); - const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem; - if (zReorder) - dnnl::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem); - args[DNNL_ARG_DST] = z_mkl_mem; - - // mean - auto mean_mkl_mem = dnnl::memory(op_ff_prim_desc.mean_desc(), engine, const_cast(mean->buffer())); - args[DNNL_ARG_MEAN] = mean_mkl_mem; - - // variance - auto var_mkl_mem = dnnl::memory(op_ff_prim_desc.variance_desc(), engine, const_cast(variance->buffer())); - args[DNNL_ARG_VARIANCE] = var_mkl_mem; - - // gamma and beta (and their gradients) if they are present - if(weights != nullptr) { - - auto w_mkl_mem = dnnl::memory(op_ff_prim_desc.weights_desc(), engine, const_cast(weights->buffer())); - args[DNNL_ARG_WEIGHTS] = w_mkl_mem; - } - - // run calculations - dnnl::batch_normalization_forward(op_ff_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); +static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, + const NDArray* variance, const NDArray* weights, + NDArray* z, const float epsilon, + const bool isNCHW) { + // unfortunately mkl dnn doesn't support any format + // (dnnl::memory::format_tag::any) for x + + // x -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc + // mean -> 1D [c] + // variance -> 1D [c] + // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, + // 0,0}) contains beta z(output) - same shape as x + + const int xRank = x->rankOf(); + + // input type + dnnl::memory::data_type type = dnnl::memory::data_type::f32; + + // indicate whether gamma or/and beta are given + auto flags = + dnnl::normalization_flags::use_global_stats; // don't calculate the mean + // and variance for each + // mini-batch + if (weights != nullptr) flags |= dnnl::normalization_flags::use_scale_shift; + + dnnl::memory::dims dims; + dnnl::memory::format_tag format; + + const int indHW = isNCHW ? 2 : 1; + const int bS = x->sizeAt(0); + const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1); + + int iD, iH, iW; + + if (xRank == 2) { + dims = {bS, iC}; + format = dnnl::memory::format_tag::nc; + } else if (xRank == 4) { + iH = x->sizeAt(indHW); + iW = x->sizeAt(indHW + 1); + dims = {bS, iC, iH, iW}; + format = isNCHW ? dnnl::memory::format_tag::nchw + : dnnl::memory::format_tag::nhwc; + } else { // xRank = 5 + iD = x->sizeAt(indHW); + iH = x->sizeAt(indHW + 1); + iW = x->sizeAt(indHW + 2); + dims = {bS, iC, iD, iH, iW}; + format = isNCHW ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + } + + // memory descriptors for arrays + + // x + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); + dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); + + mkldnnUtils::setBlockStrides(x, x_user_md); + // z, output + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format); + + mkldnnUtils::setBlockStrides(z, z_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // batchnorm forward description + dnnl::batch_normalization_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags); + dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, + engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory and check whether reorder is required + + // x + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, + op_ff_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); + + // z + auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer()); + const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = + zReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem; + if (zReorder) + dnnl::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem); + args[DNNL_ARG_DST] = z_mkl_mem; + + // mean + auto mean_mkl_mem = dnnl::memory(op_ff_prim_desc.mean_desc(), engine, + const_cast(mean->buffer())); + args[DNNL_ARG_MEAN] = mean_mkl_mem; + + // variance + auto var_mkl_mem = dnnl::memory(op_ff_prim_desc.variance_desc(), engine, + const_cast(variance->buffer())); + args[DNNL_ARG_VARIANCE] = var_mkl_mem; + + // gamma and beta (and their gradients) if they are present + if (weights != nullptr) { + auto w_mkl_mem = dnnl::memory(op_ff_prim_desc.weights_desc(), engine, + const_cast(weights->buffer())); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + } + + // run calculations + dnnl::batch_normalization_forward(op_ff_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } - ////////////////////////////////////////////////////////////////////////// -static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray &dLdO, const NDArray* weights, - NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) { - - // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x - - // x -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc - // mean -> 1D [c] - // variance -> 1D [c] - // dLdO - same shape as x - // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta - // dLdI - same shape as x - // dLdW - same shape as weights, dLdW({0,1, 0,0}) contains grad_gamma and dLdW({1,2, 0,0}) contains grad_beta - - const int xRank = x->rankOf(); - - // input type - dnnl::memory::data_type type = dnnl::memory::data_type::f32; - - // indicate whether gamma or/and beta are given - auto flags = dnnl::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch - if (weights != nullptr) - flags |= dnnl::normalization_flags::use_scale_shift; - - dnnl::memory::dims dims; - dnnl::memory::format_tag format; - - const int indHW = isNCHW ? 2 : 1; - const int bS = x->sizeAt(0); - const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1); - - int iD, iH, iW; - - if(xRank == 2) { - dims = {bS, iC}; - format = dnnl::memory::format_tag::nc; - } - else if(xRank == 4) { - iH = x->sizeAt(indHW); - iW = x->sizeAt(indHW + 1); - dims = {bS, iC, iH, iW}; - format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - } - else { // xRank = 5 - iD = x->sizeAt(indHW); - iH = x->sizeAt(indHW + 1); - iW = x->sizeAt(indHW + 2); - dims = {bS, iC, iD, iH, iW}; - format = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - } - - // memory descriptors for arrays - - // x - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); - dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); - - mkldnnUtils::setBlockStrides(x, x_user_md); - - // dLdO - dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); - - mkldnnUtils::setBlockStrides(&dLdO, dLdO_user_md); - - // dLdI - dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format); - - mkldnnUtils::setBlockStrides(dLdI, dLdI_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // batchnorm forward description - dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags); - dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // batchnorm backprop description - dnnl::batch_normalization_backward::desc op_bp_desc(dnnl::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags); - dnnl::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory and check whether reorder is required - - // x - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // dLdO - mkldnnUtils::loadDataToMklStream(&dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); - - // mean - auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, const_cast(mean->buffer())); - args[DNNL_ARG_MEAN] = mean_mkl_mem; - - // variance - auto var_mkl_mem = dnnl::memory(op_bp_prim_desc.variance_desc(), engine, const_cast(variance->buffer())); - args[DNNL_ARG_VARIANCE] = var_mkl_mem; - - // dLdI - auto dLdI_user_mem = dnnl::memory(dLdI_user_md, engine, dLdI->buffer()); - const bool dLdIReorder = op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc(); - auto dLdI_mkl_mem = dLdIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem; - args[DNNL_ARG_DIFF_SRC] = dLdI_mkl_mem; - - // gamma and beta (and their gradients) if they are present - if(weights != nullptr) { - - auto w_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, const_cast(weights->buffer())); - args[DNNL_ARG_WEIGHTS] = w_mkl_mem; - - auto dLdW_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->buffer()); - args[DNNL_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem; - } - - // run calculations - dnnl::batch_normalization_backward(op_bp_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (dLdIReorder) - dnnl::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem); - - stream.wait(); - - // shape::printArray(dLdI_mkl_mem.map_data(),8); - - // notations: - // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output - // g = dLdO - // stdInv = 1 / (v + eps)^0.5 - // N - batch size (product of spatial dimensions) - - // formula for full derivative with respect to input (x) - // dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx) - - // !!! MKL CALCULATES ONLY FIRST TERM dfdx, SO WE SHOULD CALCULATE TERM (dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)) BY OURSELF !!! - - // dfdm = -gamma*stdInv*g_sum; - // dmdx = 1/N; - // dvdx = 2 * (x - m) / N - // dvdm = -2 * [(x - m)]_sum / N - // dfdv = -0.5 * [g*(x - m)]_sum * stdInv^3, drop gamma here for calc convenience - - // finally: - // dLdI = dfdm / N + (2/N) * dfdv * (dvdm/2 + (x - m)) - // dLdI = gamma * ( stdInv * -g_sum/N + (2/N) * dfdv * (dvdm/2 + (x - m)) ) - - std::vector axes = isNCHW ? std::vector{1} : std::vector{xRank - 1}; - const auto excludedAxes = ShapeUtils::evalDimsToExclude(x->rankOf(), axes); - - // inversed batch size 1 / N - const auto Ninv = 1.f * mean->lengthOf() / x->lengthOf(); - - // x - mean - NDArray xMinusMean(x); // empty array with same shape as x - const_cast(x)->applyBroadcast(sd::broadcast::Subtract, axes, *mean, xMinusMean); - - // stdInv - NDArray stdInv = *variance + epsilon; - stdInv.applyTransform(transform::Reciprocal, stdInv); // 1 / (variance + epsilon) - stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5 - - // dfdm / N - auto dfdm = dLdO.reduceAlongDimension(sd::reduce::Sum, excludedAxes); - dfdm *= stdInv; - dfdm *= -Ninv; - - // dvdm / 2 - NDArray dvdm(mean); // empty array with same shape as mean - xMinusMean.reduceAlongDimension(sd::reduce::Sum, dvdm, excludedAxes); - dvdm *= -Ninv; - - // (2/N)*dfdv - NDArray dfdv(variance); // empty array with same shape as variance - (xMinusMean * dLdO).reduceAlongDimension(sd::reduce::Sum, dfdv, excludedAxes); - dfdv *= stdInv*stdInv*stdInv; - dfdv *= -Ninv; - - // dvdm/2 + (x - m) - xMinusMean.applyBroadcast(sd::broadcast::Add, axes, dvdm, xMinusMean); - // dfdv * (dvdm/2 + (x - m)) - xMinusMean.applyBroadcast(sd::broadcast::Multiply, axes, dfdv, xMinusMean); - // add dfdm / N - xMinusMean.applyBroadcast(sd::broadcast::Add, axes, dfdm, xMinusMean); - // * gamma - auto gamma = (*weights)({0,1, 0,0}); - xMinusMean.applyBroadcast(sd::broadcast::Multiply, axes, gamma, xMinusMean); - - *dLdI += xMinusMean; +static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, + const NDArray* variance, + const NDArray& dLdO, const NDArray* weights, + NDArray* dLdI, NDArray* dLdW, + const float epsilon, const bool isNCHW) { + // unfortunately mkl dnn doesn't support any format + // (dnnl::memory::format_tag::any) for x + + // x -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc + // mean -> 1D [c] + // variance -> 1D [c] + // dLdO - same shape as x + // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, + // 0,0}) contains beta dLdI - same shape as x dLdW - same shape as weights, + // dLdW({0,1, 0,0}) contains grad_gamma and dLdW({1,2, 0,0}) contains + // grad_beta + + const int xRank = x->rankOf(); + + // input type + dnnl::memory::data_type type = dnnl::memory::data_type::f32; + + // indicate whether gamma or/and beta are given + auto flags = + dnnl::normalization_flags::use_global_stats; // don't calculate the mean + // and variance for each + // mini-batch + if (weights != nullptr) flags |= dnnl::normalization_flags::use_scale_shift; + + dnnl::memory::dims dims; + dnnl::memory::format_tag format; + + const int indHW = isNCHW ? 2 : 1; + const int bS = x->sizeAt(0); + const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1); + + int iD, iH, iW; + + if (xRank == 2) { + dims = {bS, iC}; + format = dnnl::memory::format_tag::nc; + } else if (xRank == 4) { + iH = x->sizeAt(indHW); + iW = x->sizeAt(indHW + 1); + dims = {bS, iC, iH, iW}; + format = isNCHW ? dnnl::memory::format_tag::nchw + : dnnl::memory::format_tag::nhwc; + } else { // xRank = 5 + iD = x->sizeAt(indHW); + iH = x->sizeAt(indHW + 1); + iW = x->sizeAt(indHW + 2); + dims = {bS, iC, iD, iH, iW}; + format = isNCHW ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + } + + // memory descriptors for arrays + + // x + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); + dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); + + mkldnnUtils::setBlockStrides(x, x_user_md); + + // dLdO + dnnl::memory::desc dLdO_mkl_md = + dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); + + mkldnnUtils::setBlockStrides(&dLdO, dLdO_user_md); + + // dLdI + dnnl::memory::desc dLdI_mkl_md = + dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format); + + mkldnnUtils::setBlockStrides(dLdI, dLdI_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // batchnorm forward description + dnnl::batch_normalization_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags); + dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, + engine); + + // batchnorm backprop description + dnnl::batch_normalization_backward::desc op_bp_desc( + dnnl::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags); + dnnl::batch_normalization_backward::primitive_desc op_bp_prim_desc( + op_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory and check whether reorder is required + + // x + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, + op_bp_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); + + // dLdO + mkldnnUtils::loadDataToMklStream(&dLdO, engine, stream, dLdO_user_md, + op_bp_prim_desc.diff_dst_desc(), + args[DNNL_ARG_DIFF_DST]); + + // mean + auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, + const_cast(mean->buffer())); + args[DNNL_ARG_MEAN] = mean_mkl_mem; + + // variance + auto var_mkl_mem = dnnl::memory(op_bp_prim_desc.variance_desc(), engine, + const_cast(variance->buffer())); + args[DNNL_ARG_VARIANCE] = var_mkl_mem; + + // dLdI + auto dLdI_user_mem = dnnl::memory(dLdI_user_md, engine, dLdI->buffer()); + const bool dLdIReorder = + op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc(); + auto dLdI_mkl_mem = + dLdIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) + : dLdI_user_mem; + args[DNNL_ARG_DIFF_SRC] = dLdI_mkl_mem; + + // gamma and beta (and their gradients) if they are present + if (weights != nullptr) { + auto w_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, + const_cast(weights->buffer())); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + + auto dLdW_mkl_mem = + dnnl::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->buffer()); + args[DNNL_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem; + } + + // run calculations + dnnl::batch_normalization_backward(op_bp_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (dLdIReorder) + dnnl::reorder(dLdI_mkl_mem, dLdI_user_mem) + .execute(stream, dLdI_mkl_mem, dLdI_user_mem); + + stream.wait(); + + // shape::printArray(dLdI_mkl_mem.map_data(),8); + + // notations: + // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * + // ff_output g = dLdO stdInv = 1 / (v + eps)^0.5 N - batch size (product of + // spatial dimensions) + + // formula for full derivative with respect to input (x) + // dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx) + + // !!! MKL CALCULATES ONLY FIRST TERM dfdx, SO WE SHOULD CALCULATE TERM + // (dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)) BY OURSELF !!! + + // dfdm = -gamma*stdInv*g_sum; + // dmdx = 1/N; + // dvdx = 2 * (x - m) / N + // dvdm = -2 * [(x - m)]_sum / N + // dfdv = -0.5 * [g*(x - m)]_sum * stdInv^3, drop gamma here for calc + // convenience + + // finally: + // dLdI = dfdm / N + (2/N) * dfdv * (dvdm/2 + (x - m)) + // dLdI = gamma * ( stdInv * -g_sum/N + (2/N) * dfdv * (dvdm/2 + (x - m)) ) + + std::vector axes = + isNCHW ? std::vector{1} : std::vector{xRank - 1}; + const auto excludedAxes = ShapeUtils::evalDimsToExclude(x->rankOf(), axes); + + // inversed batch size 1 / N + const auto Ninv = 1.f * mean->lengthOf() / x->lengthOf(); + + // x - mean + NDArray xMinusMean(x); // empty array with same shape as x + const_cast(x)->applyBroadcast(sd::broadcast::Subtract, axes, *mean, + xMinusMean); + + // stdInv + NDArray stdInv = *variance + epsilon; + stdInv.applyTransform(transform::Reciprocal, + stdInv); // 1 / (variance + epsilon) + stdInv.applyTransform(transform::Sqrt, + stdInv); // 1 / (variance + epsilon)^0.5 + + // dfdm / N + auto dfdm = dLdO.reduceAlongDimension(sd::reduce::Sum, excludedAxes); + dfdm *= stdInv; + dfdm *= -Ninv; + + // dvdm / 2 + NDArray dvdm(mean); // empty array with same shape as mean + xMinusMean.reduceAlongDimension(sd::reduce::Sum, dvdm, excludedAxes); + dvdm *= -Ninv; + + // (2/N)*dfdv + NDArray dfdv(variance); // empty array with same shape as variance + (xMinusMean * dLdO).reduceAlongDimension(sd::reduce::Sum, dfdv, excludedAxes); + dfdv *= stdInv * stdInv * stdInv; + dfdv *= -Ninv; + + // dvdm/2 + (x - m) + xMinusMean.applyBroadcast(sd::broadcast::Add, axes, dvdm, xMinusMean); + // dfdv * (dvdm/2 + (x - m)) + xMinusMean.applyBroadcast(sd::broadcast::Multiply, axes, dfdv, xMinusMean); + // add dfdm / N + xMinusMean.applyBroadcast(sd::broadcast::Add, axes, dfdm, xMinusMean); + // * gamma + auto gamma = (*weights)({0, 1, 0, 0}); + xMinusMean.applyBroadcast(sd::broadcast::Multiply, axes, gamma, xMinusMean); + + *dLdI += xMinusMean; } PLATFORM_IMPL(batchnorm, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc - auto mean = INPUT_VARIABLE(1); // [c] - auto variance = INPUT_VARIABLE(2); // [c] - NDArray* gamma = nullptr; // [c] - NDArray* beta = nullptr; // [c] - - auto output = OUTPUT_VARIABLE(0); // same shape as input - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - const double epsilon = T_ARG(0); - - if(applyScale) - gamma = INPUT_VARIABLE(3); - if(applyOffset) - beta = INPUT_VARIABLE(3 + (int)applyScale); - - const int numOfIntArgs = block.numI(); - const int inRank = input->rankOf(); - - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); + auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc + auto mean = INPUT_VARIABLE(1); // [c] + auto variance = INPUT_VARIABLE(2); // [c] + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + + auto output = OUTPUT_VARIABLE(0); // same shape as input + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const double epsilon = T_ARG(0); + + if (applyScale) gamma = INPUT_VARIABLE(3); + if (applyOffset) beta = INPUT_VARIABLE(3 + (int)applyScale); + + const int numOfIntArgs = block.numI(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank - + 1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes == 1, 0, + "BATCHNORM_MKLDNN op: mkl dnn library supports only one axis " + "which represents channel dimension, but got %i axes instead!", + numOfAxes); + REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, + "BATCHNORM_MKLDNN op: possible values for rank of input array " + "are 2, 4 or 5, but got %i instead!", + inRank); + REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), + 0, + "BATCHNORM_MKLDNN op: wrong shape of mean array, expected is " + "[%lld], but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE( + variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), + 0, + "BATCHNORM_MKLDNN op: wrong shape of variance array, expected is [%lld], " + "but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str()); + if (gamma != nullptr) + REQUIRE_TRUE( + gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, + "BATCHNORM_MKLDNN op: wrong shape of gamma array, expected is [%lld], " + "but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str()); + if (beta != nullptr) + REQUIRE_TRUE( + beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, + "BATCHNORM_MKLDNN op: wrong shape of beta array, expected is [%lld], " + "but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str()); + + // types of all input arrays should be the same (except dLdO) + for (int i = 1; i < block.width() - 1; ++i) + REQUIRE_TRUE( + INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, + "BATCHNORM_MKLDNN op: types of all input arrays should be the same !"); + + NDArray* weights = nullptr; + + if (applyScale || applyOffset) { + weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, + input->dataType()); + + if (applyScale) + (*weights)({0, 1, 0, 0}).assign(gamma); else - axes.push_back(inRank-1); // default dimension to reduce along is last dimension - - const int numOfAxes = axes.size(); - REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes); - REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank); - REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str()); - if(gamma != nullptr) - REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str()); - if(beta != nullptr) - REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str()); - - // types of all input arrays should be the same (except dLdO) - for(int i = 1; i < block.width() - 1; ++i) - REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_MKLDNN op: types of all input arrays should be the same !"); - - - NDArray *weights = nullptr; - - if(applyScale || applyOffset) { - - weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType()); - - if(applyScale) - (*weights)({0,1, 0,0}).assign(gamma); - else - (*weights)({0,1, 0,0}).assign(1); - if(applyOffset) - (*weights)({1,2, 0,0}).assign(beta); - else - (*weights)({1,2, 0,0}).assign(0); - } + (*weights)({0, 1, 0, 0}).assign(1); + if (applyOffset) + (*weights)({1, 2, 0, 0}).assign(beta); + else + (*weights)({1, 2, 0, 0}).assign(0); + } - const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2); + const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2); - batchnormMKLDNN(input, mean, variance, weights, output, epsilon, isNCHW); + batchnormMKLDNN(input, mean, variance, weights, output, epsilon, isNCHW); - delete weights; + delete weights; - return Status::OK(); + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(batchnorm, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc - auto mean = INPUT_VARIABLE(1); // [c] - auto variance = INPUT_VARIABLE(2); // [c] - NDArray* gamma = nullptr; // [c] - NDArray* beta = nullptr; // [c] - - auto output = OUTPUT_VARIABLE(0); // same shape as input - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - - if(applyScale) - gamma = INPUT_VARIABLE(3); - if(applyOffset) - beta = INPUT_VARIABLE(3 + (int)applyScale); - - - const int numOfIntArgs = block.numI(); - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(input->rankOf()-1); // default dimension to reduce along is last dimension - - DataType inputType = input->dataType(); - DataType meanType = mean->dataType(); - DataType varType = variance->dataType(); - DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32; - DataType betaType = beta != nullptr ? beta->dataType() : DataType::FLOAT32; - DataType outType = output->dataType(); - - const int inRank = input->rankOf(); - - return block.isUseMKLDNN() && axes.size() == 1 && (axes[0] == 1 || axes[0] == inRank - 1) && (inRank == 2 || inRank == 4 || inRank == 5) && - (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 && - gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && outType == DataType::FLOAT32); + auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc + auto mean = INPUT_VARIABLE(1); // [c] + auto variance = INPUT_VARIABLE(2); // [c] + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + + auto output = OUTPUT_VARIABLE(0); // same shape as input + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + + if (applyScale) gamma = INPUT_VARIABLE(3); + if (applyOffset) beta = INPUT_VARIABLE(3 + (int)applyScale); + + const int numOfIntArgs = block.numI(); + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(input->rankOf() - + 1); // default dimension to reduce along is last dimension + + DataType inputType = input->dataType(); + DataType meanType = mean->dataType(); + DataType varType = variance->dataType(); + DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32; + DataType betaType = beta != nullptr ? beta->dataType() : DataType::FLOAT32; + DataType outType = output->dataType(); + + const int inRank = input->rankOf(); + + return block.isUseMKLDNN() && axes.size() == 1 && + (axes[0] == 1 || axes[0] == inRank - 1) && + (inRank == 2 || inRank == 4 || inRank == 5) && + (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && + varType == DataType::FLOAT32 && gammaType == DataType::FLOAT32 && + betaType == DataType::FLOAT32 && outType == DataType::FLOAT32); } ////////////////////////////////////////////////////////////////////////// @@ -484,42 +543,53 @@ PLATFORM_CHECK(batchnorm, ENGINE_CPU) { // axes.push_back(input->rankOf() - 1); // std::vector shape({2, mean->lengthOf()}); -// NDArray weights = NDArrayFactory::create('c', shape, block.launchContext()); -// weights({0, 1, 0, 0}).assign(1.0f); -// weights({1, 2, 0, 0}).assign(0.0f); +// NDArray weights = NDArrayFactory::create('c', shape, +// block.launchContext()); weights({0, 1, 0, 0}).assign(1.0f); weights({1, +// 2, 0, 0}).assign(0.0f); // mkldnn_memory_desc_t empty; -// dnnl::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty); +// dnnl::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), +// user_src_md(empty), user_dst_md(empty); // auto flag = dnnl::normalization_flags::use_global_stats; // if (applyScale || applyOffset) // flag |= dnnl::normalization_flags::use_scale_shift; // mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output, -// &batchnorm_src_md, nullptr, &batchnorm_dst_md, -// &user_src_md, nullptr, &user_dst_md, axes[0]); +// &batchnorm_src_md, nullptr, +// &batchnorm_dst_md, +// &user_src_md, nullptr, +// &user_dst_md, axes[0]); -// auto batchnorm_desc = dnnl::batch_normalization_forward::desc(dnnl::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag); +// auto batchnorm_desc = +// dnnl::batch_normalization_forward::desc(dnnl::prop_kind::forward_inference, +// batchnorm_src_md, epsilon, flag); -// auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); +// auto engine = +// mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); // dnnl::stream stream(engine); -// auto batchnorm_prim_desc = dnnl::batch_normalization_forward::primitive_desc(batchnorm_desc, engine); -// auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); -// auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); -// auto batchnorm_mean_memory = dnnl::memory(batchnorm_prim_desc.mean_desc(), engine, +// auto batchnorm_prim_desc = +// dnnl::batch_normalization_forward::primitive_desc(batchnorm_desc, +// engine); auto user_src_memory = dnnl::memory(user_src_md, engine, +// input->buffer()); auto user_dst_memory = dnnl::memory(user_dst_md, +// engine, output->buffer()); auto batchnorm_mean_memory = +// dnnl::memory(batchnorm_prim_desc.mean_desc(), engine, // mean->buffer()); -// auto batchnorm_variance_memory = dnnl::memory(batchnorm_prim_desc.variance_desc(), engine, +// auto batchnorm_variance_memory = +// dnnl::memory(batchnorm_prim_desc.variance_desc(), engine, // variance->buffer()); // auto batchnorm_src_memory = user_src_memory; // dnnl::memory m(batchnorm_src_md, engine); // if (m.get_desc() != user_src_memory.get_desc()) { // batchnorm_src_memory = dnnl::memory(batchnorm_src_md, engine); -// dnnl::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory, +// dnnl::reorder(user_src_memory, batchnorm_src_memory).execute(stream, +// user_src_memory, // batchnorm_src_memory); // } // auto batchnorm_dst_memory = user_dst_memory; // if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { -// batchnorm_dst_memory = dnnl::memory(batchnorm_prim_desc.dst_desc(), engine); +// batchnorm_dst_memory = dnnl::memory(batchnorm_prim_desc.dst_desc(), +// engine); // } // if (applyScale || applyOffset) { // if (gamma != nullptr) { @@ -529,22 +599,34 @@ PLATFORM_CHECK(batchnorm, ENGINE_CPU) { // weights({1, 2, 0, 0}).assign(beta); // } -// auto batchnorm_weights_memory = dnnl::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer()); +// auto batchnorm_weights_memory = +// dnnl::memory(batchnorm_prim_desc.weights_desc(), engine, +// weights.buffer()); // dnnl::batch_normalization_forward(batchnorm_prim_desc).execute(stream, -// {{MKLDNN_ARG_SRC, batchnorm_src_memory}, -// {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, -// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, -// {MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory}, -// {MKLDNN_ARG_DST, batchnorm_dst_memory}}); +// {{MKLDNN_ARG_SRC, +// batchnorm_src_memory}, +// {MKLDNN_ARG_MEAN, +// batchnorm_mean_memory}, +// {MKLDNN_ARG_VARIANCE, +// batchnorm_variance_memory}, +// {MKLDNN_ARG_WEIGHTS, +// batchnorm_weights_memory}, +// {MKLDNN_ARG_DST, +// batchnorm_dst_memory}}); // } else { // dnnl::batch_normalization_forward(batchnorm_prim_desc).execute(stream, -// {{MKLDNN_ARG_SRC, batchnorm_src_memory}, -// {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, -// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, -// {MKLDNN_ARG_DST, batchnorm_dst_memory}}); +// {{MKLDNN_ARG_SRC, +// batchnorm_src_memory}, +// {MKLDNN_ARG_MEAN, +// batchnorm_mean_memory}, +// {MKLDNN_ARG_VARIANCE, +// batchnorm_variance_memory}, +// {MKLDNN_ARG_DST, +// batchnorm_dst_memory}}); // } // if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { -// dnnl::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory, +// dnnl::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, +// batchnorm_dst_memory, // user_dst_memory); // } // stream.wait(); @@ -583,160 +665,192 @@ PLATFORM_CHECK(batchnorm, ENGINE_CPU) { // axes.push_back(input->rankOf() - 1); // return block.isUseMKLDNN() && -// sd::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) && -// axes.size() == 1; +// sd::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, +// output}) && axes.size() == 1; // } - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) { - - NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc - NDArray* mean = INPUT_VARIABLE(1); // [c] - NDArray* variance = INPUT_VARIABLE(2); // [c] - NDArray* gamma = nullptr; // [c] - NDArray* beta = nullptr; // [c] - NDArray* dLdO = INPUT_VARIABLE(block.width() - 1); // same as input - - NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input - NDArray* dLdM = OUTPUT_VARIABLE(1); // [c] - NDArray* dLdV = OUTPUT_VARIABLE(2); // [c] - NDArray* dLdG = nullptr; // [c] - NDArray* dLdB = nullptr; // [c] - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - const float epsilon = T_ARG(0); - - if(applyScale) { - gamma = INPUT_VARIABLE(3); - dLdG = OUTPUT_VARIABLE(3); - } - if(applyOffset) { - beta = INPUT_VARIABLE(3 + (int)applyScale); - dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); - } - - const int numOfIntArgs = block.numI(); - const int inRank = input->rankOf(); - - // get axes args to normalize input array over - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); + NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc + NDArray* mean = INPUT_VARIABLE(1); // [c] + NDArray* variance = INPUT_VARIABLE(2); // [c] + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + NDArray* dLdO = INPUT_VARIABLE(block.width() - 1); // same as input + + NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input + NDArray* dLdM = OUTPUT_VARIABLE(1); // [c] + NDArray* dLdV = OUTPUT_VARIABLE(2); // [c] + NDArray* dLdG = nullptr; // [c] + NDArray* dLdB = nullptr; // [c] + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const float epsilon = T_ARG(0); + + if (applyScale) { + gamma = INPUT_VARIABLE(3); + dLdG = OUTPUT_VARIABLE(3); + } + if (applyOffset) { + beta = INPUT_VARIABLE(3 + (int)applyScale); + dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); + } + + const int numOfIntArgs = block.numI(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank - + 1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes == 1, 0, + "BATCHNORM_BP_MKLDNN op: mkl dnn library supports only one axis " + "which represents channel dimension, but got %i axes instead!", + numOfAxes); + REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, + "BATCHNORM_BP_MKLDNN op: possible values for rank of input " + "array are 2, 4 or 5, but got %i instead!", + inRank); + REQUIRE_TRUE(input->isSameShape(dLdO), 0, + "BATCHNORM_BP_MKLDNN op: wrong shape of gradients array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(input).c_str(), + ShapeUtils::shapeAsString(dLdO).c_str()); + REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), + 0, + "BATCHNORM_BP_MKLDNN op: wrong shape of mean array, expected is " + "[%lld], but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE( + variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), + 0, + "BATCHNORM_BP_MKLDNN op: wrong shape of variance array, expected is " + "[%lld], but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str()); + if (gamma != nullptr) + REQUIRE_TRUE( + gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, + "BATCHNORM_BP_MKLDNN op: wrong shape of gamma array, expected is " + "[%lld], but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str()); + if (beta != nullptr) + REQUIRE_TRUE( + beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, + "BATCHNORM_BP_MKLDNN op: wrong shape of beta array, expected is " + "[%lld], but got %s instead !", + input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str()); + + // types of all input arrays should be the same + for (int i = 1; i < block.width() - 1; ++i) + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), + 0, + "BATCHNORM_BP_MKLDNN op: types of all input arrays should be " + "the same !"); + + NDArray *weights = nullptr, *dLdW = nullptr; + + if (applyScale || applyOffset) { + weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, + input->dataType()); + dLdW = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, + input->dataType()); + if (applyScale) + (*weights)({0, 1, 0, 0}).assign(gamma); else - axes.push_back(inRank-1); // default dimension to reduce along is last dimension - - const int numOfAxes = axes.size(); - REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_BP_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes); - REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_BP_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank); - REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str()); - REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str()); - if(gamma != nullptr) - REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str()); - if(beta != nullptr) - REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str()); - - // types of all input arrays should be the same - for(int i = 1; i < block.width() - 1; ++i) - REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP_MKLDNN op: types of all input arrays should be the same !"); - - - NDArray *weights = nullptr, *dLdW = nullptr; - - if(applyScale || applyOffset) { - weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType()); - dLdW = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType()); - if(applyScale) - (*weights)({0,1, 0,0}).assign(gamma); - else - (*weights)({0,1, 0,0}).assign(1); - if(applyOffset) - (*weights)({1,2, 0,0}).assign(beta); - else - (*weights)({1,2, 0,0}).assign(0); - } - - const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2); - - if (shape::strideDescendingCAscendingF(dLdO->shapeInfo())) - batchnormBackPropMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, epsilon, isNCHW); + (*weights)({0, 1, 0, 0}).assign(1); + if (applyOffset) + (*weights)({1, 2, 0, 0}).assign(beta); else - batchnormBackPropMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, dLdW, epsilon, isNCHW); + (*weights)({1, 2, 0, 0}).assign(0); + } - *dLdM = 0; - *dLdV = 0; + const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2); - if(applyScale || applyOffset) { - if(applyScale) - dLdG->assign((*dLdW)({0,1, 0,0})); - if(applyOffset) - dLdB->assign((*dLdW)({1,2, 0,0})); + if (shape::strideDescendingCAscendingF(dLdO->shapeInfo())) + batchnormBackPropMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, + epsilon, isNCHW); + else + batchnormBackPropMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, + dLdW, epsilon, isNCHW); - delete weights; - delete dLdW; - } + *dLdM = 0; + *dLdV = 0; - return Status::OK(); + if (applyScale || applyOffset) { + if (applyScale) dLdG->assign((*dLdW)({0, 1, 0, 0})); + if (applyOffset) dLdB->assign((*dLdW)({1, 2, 0, 0})); + + delete weights; + delete dLdW; + } + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(batchnorm_bp, ENGINE_CPU) { - - NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw - NDArray* mean = INPUT_VARIABLE(1); // [c] - NDArray* variance = INPUT_VARIABLE(2); // [c] - NDArray* dLdO = INPUT_VARIABLE(3); // same as input - NDArray* gamma = nullptr; // [c] - NDArray* beta = nullptr; // [c] - - NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input - NDArray* dLdM = OUTPUT_VARIABLE(1); // [c] - NDArray* dLdV = OUTPUT_VARIABLE(2); // [c] - NDArray* dLdG = nullptr; // [c] - NDArray* dLdB = nullptr; // [c] - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - - if(applyScale) { - gamma = INPUT_VARIABLE(4); - dLdG = OUTPUT_VARIABLE(3); - } - if(applyOffset) { - beta = INPUT_VARIABLE(4 + (int)applyScale); - dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); - } - - const int numOfIntArgs = block.numI(); - std::vector axes; - if(numOfIntArgs > 2) - for(int i = 2; i < numOfIntArgs; ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(input->rankOf()-1); // default dimension to reduce along is last dimension - - DataType inputType = input->dataType(); - DataType meanType = mean->dataType(); - DataType varType = variance->dataType(); - DataType dLdOType = dLdO->dataType(); - DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32; - DataType betaType = beta != nullptr ? beta->dataType() : DataType::FLOAT32; - - DataType dLdIType = dLdI->dataType(); - DataType dLdGType = gamma != nullptr ? dLdG->dataType() : DataType::FLOAT32; - DataType dLdBType = beta != nullptr ? dLdB->dataType() : DataType::FLOAT32; - - const int inRank = input->rankOf(); - - return block.isUseMKLDNN() && axes.size() == 1 && (axes[0] == 1 || axes[0] == inRank - 1) && (inRank == 2 || inRank == 4 || inRank == 5) && - (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 && - dLdOType == DataType::FLOAT32 && gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && - dLdIType == DataType::FLOAT32 && dLdGType == DataType::FLOAT32 && dLdBType == DataType::FLOAT32); + NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw + NDArray* mean = INPUT_VARIABLE(1); // [c] + NDArray* variance = INPUT_VARIABLE(2); // [c] + NDArray* dLdO = INPUT_VARIABLE(3); // same as input + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + + NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input + NDArray* dLdM = OUTPUT_VARIABLE(1); // [c] + NDArray* dLdV = OUTPUT_VARIABLE(2); // [c] + NDArray* dLdG = nullptr; // [c] + NDArray* dLdB = nullptr; // [c] + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + + if (applyScale) { + gamma = INPUT_VARIABLE(4); + dLdG = OUTPUT_VARIABLE(3); + } + if (applyOffset) { + beta = INPUT_VARIABLE(4 + (int)applyScale); + dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); + } + + const int numOfIntArgs = block.numI(); + std::vector axes; + if (numOfIntArgs > 2) + for (int i = 2; i < numOfIntArgs; ++i) axes.push_back(INT_ARG(i)); + else + axes.push_back(input->rankOf() - + 1); // default dimension to reduce along is last dimension + + DataType inputType = input->dataType(); + DataType meanType = mean->dataType(); + DataType varType = variance->dataType(); + DataType dLdOType = dLdO->dataType(); + DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32; + DataType betaType = beta != nullptr ? beta->dataType() : DataType::FLOAT32; + + DataType dLdIType = dLdI->dataType(); + DataType dLdGType = gamma != nullptr ? dLdG->dataType() : DataType::FLOAT32; + DataType dLdBType = beta != nullptr ? dLdB->dataType() : DataType::FLOAT32; + + const int inRank = input->rankOf(); + + return block.isUseMKLDNN() && axes.size() == 1 && + (axes[0] == 1 || axes[0] == inRank - 1) && + (inRank == 2 || inRank == 4 || inRank == 5) && + (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && + varType == DataType::FLOAT32 && dLdOType == DataType::FLOAT32 && + gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && + dLdIType == DataType::FLOAT32 && dLdGType == DataType::FLOAT32 && + dLdBType == DataType::FLOAT32); } -} -} -} \ No newline at end of file +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp index ce9dea75022d..28eba154fc63 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp @@ -20,352 +20,445 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include using namespace dnnl; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////// static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, - const NDArray *bias, NDArray *output, - const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const int isNCHW, const int wFormat) { - - // mkl support weights in [oC, iC, kH, kW] format only - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d - - dnnl::memory::dims strides = { sH, sW }; - dnnl::memory::dims padding = { pH, pW }; - dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - dnnl::memory::dims dilation = { dH-1, dW-1}; - - auto xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; - - dnnl::memory::dims xDims = {bS, iC, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oH, oW}; - - auto type = dnnl::memory::data_type::f32; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { - w_user_md.data.format_kind = dnnl_blocked; // overrides format - uint i0, i1, i2, i3; - if(0 == wFormat) { - i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] - } - else if(1 == wFormat) { - i0 = 0; i1 = 1; i2 = 2; i3 = 3; - } - else { - i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW] - } - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); - } - - // bias - dnnl::memory::desc b_mkl_md; - if(bias != nullptr) - b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); - - // output - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(output, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // operation primitive description - dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // bias - if(bias != nullptr) { - auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); - args[DNNL_ARG_BIAS] = b_mkl_mem; + const NDArray *bias, NDArray *output, const int kH, + const int kW, const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW, + const int wFormat) { + // mkl support weights in [oC, iC, kH, kW] format only + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + const int pWSame = (paddingMode == 2 && dW > 1) + ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 + : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = {sH, sW}; + dnnl::memory::dims padding = {pH, pW}; + dnnl::memory::dims padding_r = {(oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pWSame}; + dnnl::memory::dims dilation = {dH - 1, dW - 1}; + + auto xzFormatMkl = + isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + auto type = dnnl::memory::data_type::f32; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); + if (weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { + w_user_md.data.format_kind = dnnl_blocked; // overrides format + uint i0, i1, i2, i3; + if (0 == wFormat) { + i0 = 3; + i1 = 2; + i2 = 0; + i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] + } else if (1 == wFormat) { + i0 = 0; + i1 = 1; + i2 = 2; + i3 = 3; + } else { + i0 = 0; + i1 = 3; + i2 = 1; + i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW] } - - // output - auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; - - // run calculations - dnnl::convolution_forward(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); - - stream.wait(); - // shape::printArray(z_mkl_mem.map_data(),8); + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + } + + // bias + dnnl::memory::desc b_mkl_md; + if (bias != nullptr) + b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); + + // output + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(output, z_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::convolution_forward::desc op_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, + x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, + padding_r); + dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, + op_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // bias + if (bias != nullptr) { + auto b_mkl_mem = + dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); + args[DNNL_ARG_BIAS] = b_mkl_mem; + } + + // output + auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = + zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; + + // run calculations + dnnl::convolution_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); + // shape::printArray(z_mkl_mem.map_data(),8); } ////////////////////////////////////////////////////////////////////// -static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, - NDArray *gradI, NDArray *gradW, NDArray *gradB, - const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const int isNCHW, const int wFormat) { - - // mkl support weights/gradW in [oC, iC, kH, kW] format only - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d - - dnnl::memory::dims strides = { sH, sW }; - dnnl::memory::dims padding = { pH, pW }; - dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - dnnl::memory::dims dilation = { dH-1, dW-1}; - - auto xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; - - dnnl::memory::dims xDims = {bS, iC, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oH, oW}; - - auto type = dnnl::memory::data_type::f32; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { - w_user_md.data.format_kind = dnnl_blocked; // overrides format - uint i0, i1, i2, i3; - if(0 == wFormat) { - i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] - } - else if(1 == wFormat) { - i0 = 0; i1 = 1; i2 = 2; i3 = 3; - } - else { - i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW] - } - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); - } - - // gradO - dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(gradO, gradO_user_md); - - // gradI - dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(gradI, gradI_user_md); - - // gradW - dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) { - gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - uint i0, i1, i2, i3; - if(0 == wFormat) { - i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] - } - else if(1 == wFormat) { - i0 = 0; i1 = 1; i2 = 2; i3 = 3; - } - else { - i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW] - } - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); - gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); +static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, + const NDArray *bias, const NDArray *gradO, + NDArray *gradI, NDArray *gradW, NDArray *gradB, + const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW, const int paddingMode, + const int isNCHW, const int wFormat) { + // mkl support weights/gradW in [oC, iC, kH, kW] format only + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + const int pWSame = (paddingMode == 2 && dW > 1) + ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 + : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = {sH, sW}; + dnnl::memory::dims padding = {pH, pW}; + dnnl::memory::dims padding_r = {(oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pWSame}; + dnnl::memory::dims dilation = {dH - 1, dW - 1}; + + auto xzFormatMkl = + isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + auto type = dnnl::memory::data_type::f32; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); + if (weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { + w_user_md.data.format_kind = dnnl_blocked; // overrides format + uint i0, i1, i2, i3; + if (0 == wFormat) { + i0 = 3; + i1 = 2; + i2 = 0; + i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] + } else if (1 == wFormat) { + i0 = 0; + i1 = 1; + i2 = 2; + i3 = 3; + } else { + i0 = 0; + i1 = 3; + i2 = 1; + i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW] } - - // gradB - dnnl::memory::desc gradB_mkl_md; - if(gradB != nullptr) - gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // forward primitive description - dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backward data primitive description - dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); - - // backward weights primitive description - dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // gradO - auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); - const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - if (gradOReorderW) - dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW); - if (gradOReorderD) - dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD); - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; - - // gradI - auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); - const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; - - // gradW - auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer()); - const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); - auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; - args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; - - // gradB - if(gradB != nullptr) { - auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); - args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + } + + // gradO + dnnl::memory::desc gradO_mkl_md = + dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = + dnnl::memory::desc(zDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(gradO, gradO_user_md); + + // gradI + dnnl::memory::desc gradI_mkl_md = + dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = + dnnl::memory::desc(xDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(gradI, gradI_user_md); + + // gradW + dnnl::memory::desc gradW_mkl_md = + dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradW_user_md = + dnnl::memory::desc(wDims, type, wFormatMkl); + if (gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) { + gradW_user_md.data.format_kind = dnnl_blocked; // overrides format + uint i0, i1, i2, i3; + if (0 == wFormat) { + i0 = 3; + i1 = 2; + i2 = 0; + i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] + } else if (1 == wFormat) { + i0 = 0; + i1 = 1; + i2 = 2; + i3 = 3; + } else { + i0 = 0; + i1 = 3; + i2 = 1; + i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW] } - - // run backward data calculations - dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); - - if(gradOReorderW || gradOReorderD) - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; - - // run backward weights calculations - dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); - - // reorder gradI if necessary - if (gradIReorder) - dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); - if (gradWReorder) - dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); + gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); + gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); + } + + // gradB + dnnl::memory::desc gradB_mkl_md; + if (gradB != nullptr) + gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // forward primitive description + dnnl::convolution_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, + x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, + padding, padding_r); + dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward data primitive description + dnnl::convolution_backward_data::desc op_data_bp_desc( + dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, + strides, dilation, padding, padding_r); + dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc( + op_data_bp_desc, engine, op_ff_prim_desc); + + // backward weights primitive description + dnnl::convolution_backward_weights::desc op_weights_bp_desc( + dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, + gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc( + op_weights_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, + op_weights_bp_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, + op_data_bp_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // gradO + auto gradO_user_mem = + dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); + const bool gradOReorderW = + op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + const bool gradOReorderD = + op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_memW = + gradOReorderW + ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + auto gradO_mkl_memD = + gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + if (gradOReorderW) + dnnl::reorder(gradO_user_mem, gradO_mkl_memW) + .execute(stream, gradO_user_mem, gradO_mkl_memW); + if (gradOReorderD) + dnnl::reorder(gradO_user_mem, gradO_mkl_memD) + .execute(stream, gradO_user_mem, gradO_mkl_memD); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; + + // gradI + auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); + const bool gradIReorder = + op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); + auto gradI_mkl_mem = + gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) + : gradI_user_mem; + args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + + // gradW + auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer()); + const bool gradWReorder = + op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); + auto gradW_mkl_mem = + gradWReorder + ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) + : gradW_user_mem; + args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; + + // gradB + if (gradB != nullptr) { + auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); + args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; + } + + // run backward data calculations + dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + + if (gradOReorderW || gradOReorderD) args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; + + // run backward weights calculations + dnnl::convolution_backward_weights(op_weights_bp_prim_desc) + .execute(stream, args); + + // reorder gradI if necessary + if (gradIReorder) + dnnl::reorder(gradI_mkl_mem, gradI_user_mem) + .execute(stream, gradI_mkl_mem, gradI_user_mem); + if (gradWReorder) + dnnl::reorder(gradW_mkl_mem, gradW_user_mem) + .execute(stream, gradW_mkl_mem, gradW_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } /* ////////////////////////////////////////////////////////////////////// -static void conv2dMKLDNN(sd::graph::Context &block, const NDArray *input, const NDArray *weights, - const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH, - const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, - const int isNCHW) { +static void conv2dMKLDNN(sd::graph::Context &block, const NDArray *input, const +NDArray *weights, const NDArray *bias, NDArray *output, const int kH, const int +kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, +const int paddingMode, const int isNCHW) { - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + int bS, iC, iH, iW, oC, oH, oW; // batch size, +input channels, input height/width, output channels, output height/width; int +indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, +iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, +dW, paddingMode); dnnl_memory_desc_t empty; - dnnl::memory::desc x_mkl_md(empty), w_mkl_md(empty), b_mkl_md(empty), z_mkl_md(empty); - dnnl::memory::desc x_user_md(empty), w_user_md(empty), b_user_md(empty), z_user_md(empty); + dnnl::memory::desc x_mkl_md(empty), w_mkl_md(empty), b_mkl_md(empty), +z_mkl_md(empty); dnnl::memory::desc x_user_md(empty), w_user_md(empty), +b_user_md(empty), z_user_md(empty); dnnl::memory::dims strides, padding, padding_r, dilation; - mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, - bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr, - bias, output, - &x_mkl_md, nullptr, &w_mkl_md, nullptr, - &b_mkl_md, &z_mkl_md, - &x_user_md, nullptr, &w_user_md, nullptr, - &b_user_md, &z_user_md, - strides, padding, padding_r, dilation); - - auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, x_mkl_md, - w_mkl_md, b_mkl_md, - z_mkl_md, strides, dilation, padding, - padding_r) - : convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, x_mkl_md, - w_mkl_md, - z_mkl_md, strides, dilation, padding, - padding_r); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, +paddingMode, isNCHW, bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, +nullptr, bias, output, &x_mkl_md, nullptr, &w_mkl_md, nullptr, &b_mkl_md, +&z_mkl_md, &x_user_md, nullptr, &w_user_md, nullptr, &b_user_md, &z_user_md, + strides, padding, padding_r, +dilation); + + auto conv_desc = bias != nullptr ? +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r) + : +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +x_mkl_md, w_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); auto +engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); dnnl::stream stream(engine); - auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine); - auto user_src_memory = dnnl::memory(x_user_md, engine, const_cast(input)->buffer()); - auto user_weights_memory = dnnl::memory(w_user_md, engine, - const_cast(weights)->buffer()); - auto user_dst_memory = dnnl::memory(z_user_md, engine, output->buffer()); - auto conv_src_memory = user_src_memory; - if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) { - conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine); - reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory); + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, +engine); auto user_src_memory = dnnl::memory(x_user_md, engine, +const_cast(input)->buffer()); auto user_weights_memory = +dnnl::memory(w_user_md, engine, const_cast(weights)->buffer()); auto +user_dst_memory = dnnl::memory(z_user_md, engine, output->buffer()); auto +conv_src_memory = user_src_memory; if (conv_prim_desc.src_desc() != +user_src_memory.get_desc()) { conv_src_memory = +dnnl::memory(conv_prim_desc.src_desc(), engine); reorder(user_src_memory, +conv_src_memory).execute(stream, user_src_memory, conv_src_memory); } auto conv_weights_memory = user_weights_memory; if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) { - conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine); - reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory, - conv_weights_memory); + conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), +engine); reorder(user_weights_memory, conv_weights_memory).execute(stream, +user_weights_memory, conv_weights_memory); } auto conv_dst_memory = user_dst_memory; if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { @@ -373,105 +466,133 @@ static void conv2dMKLDNN(sd::graph::Context &block, const NDArray *input, const } if (bias != nullptr) { auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, - const_cast(bias)->buffer()); - convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, - {DNNL_ARG_WEIGHTS, conv_weights_memory}, - {DNNL_ARG_BIAS, conv_bias_memory}, - {DNNL_ARG_DST, conv_dst_memory}}); - } else { - convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, - {DNNL_ARG_WEIGHTS, conv_weights_memory}, - {DNNL_ARG_DST, conv_dst_memory}}); + const_cast(bias)->buffer()); convolution_forward(conv_prim_desc).execute(stream, +{{DNNL_ARG_SRC, conv_src_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory}, + {DNNL_ARG_BIAS, +conv_bias_memory}, {DNNL_ARG_DST, conv_dst_memory}}); } else { + convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, +conv_src_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory}, {DNNL_ARG_DST, +conv_dst_memory}}); } if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory); + reorder(conv_dst_memory, user_dst_memory).execute(stream, +conv_dst_memory, user_dst_memory); } stream.wait(); } ////////////////////////////////////////////////////////////////////// static void conv2dBpMKLDNN(sd::graph::Context &block, - const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, - NDArray *gradI, NDArray *gradW, NDArray *gradB, - const int kH, const int kW, const int sH,const int sW, int pH, int pW, const int dH, const int dW, - const int paddingMode, const int isNCHW) { + const NDArray *input, const NDArray *weights, const +NDArray *bias, const NDArray *gradO, NDArray *gradI, NDArray *gradW, NDArray +*gradB, const int kH, const int kW, const int sH,const int sW, int pH, int pW, +const int dH, const int dW, const int paddingMode, const int isNCHW) { - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + int bS, iC, iH, iW, oC, oH, oW; // batch size, +input channels, input height/width, output channels, output height/width; int +indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, +iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, +dW, paddingMode); dnnl_memory_desc_t empty; - dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); + dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), +conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), +conv_dst_md(empty); dnnl::memory::desc user_src_md(empty), +user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), +user_bias_md(empty), user_dst_md(empty); - dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, +conv_dilation; - mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, - bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, + mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, +paddingMode, isNCHW, bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, gradB, gradO, - &conv_src_md, &conv_diff_src_md, &conv_weights_md, - &conv_diff_weights_md, &conv_bias_md, &conv_dst_md, - &user_src_md, &user_diff_src_md, &user_weights_md, - &user_diff_weights_md, &user_bias_md, &user_dst_md, - conv_strides, conv_padding, conv_padding_r, conv_dilation); - auto conv_desc = gradB != nullptr - ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); - - auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine( LaunchContext::defaultContext()->engine())); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); + &conv_src_md, &conv_diff_src_md, +&conv_weights_md, &conv_diff_weights_md, &conv_bias_md, &conv_dst_md, + &user_src_md, &user_diff_src_md, +&user_weights_md, &user_diff_weights_md, &user_bias_md, &user_dst_md, + conv_strides, conv_padding, +conv_padding_r, conv_dilation); auto conv_desc = gradB != nullptr ? +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, +conv_dilation, conv_padding, conv_padding_r) : +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, +conv_padding, conv_padding_r); + + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, +mkldnnUtils::getEngine( LaunchContext::defaultContext()->engine())); + + auto engine = +mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); dnnl::stream +stream(engine); if (gradW != nullptr) { - auto convW_desc = gradB != nullptr ? convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); + auto convW_desc = gradB != nullptr ? +convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, +conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, +conv_padding, conv_padding_r) : +convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, +conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, +conv_padding_r); - auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc); + auto convW_prim_desc = +convolution_backward_weights::primitive_desc(convW_desc, engine, +conv_prim_desc); - auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast(input)->buffer()); - auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); - auto userW_dst_memory = dnnl::memory(user_dst_md, engine,const_cast(gradO)->buffer()); + auto userW_src_memory = dnnl::memory(user_src_md, engine, +const_cast(input)->buffer()); auto userW_weights_memory = +dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); auto +userW_dst_memory = dnnl::memory(user_dst_md, engine,const_cast(gradO)->buffer()); auto convW_src_memory = userW_src_memory; if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) { convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine); - reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,convW_src_memory); + reorder(userW_src_memory, convW_src_memory).execute(stream, +userW_src_memory,convW_src_memory); } auto convW_weights_memory = userW_weights_memory; - if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { - convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine); + if (convW_prim_desc.diff_weights_desc() != +userW_weights_memory.get_desc()) { convW_weights_memory = +dnnl::memory(convW_prim_desc.diff_weights_desc(), engine); } auto convW_dst_memory = userW_dst_memory; if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) { - convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine); - reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory); + convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), +engine); reorder(userW_dst_memory, convW_dst_memory).execute(stream, +userW_dst_memory, convW_dst_memory); } if (gradB != nullptr) { - auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer()); + auto convW_bias_memory = +dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer()); convolution_backward_weights(convW_prim_desc).execute(stream, - {{DNNL_ARG_SRC, convW_src_memory}, - {DNNL_ARG_DIFF_DST, convW_dst_memory}, - {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}, - {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); + {{DNNL_ARG_SRC, +convW_src_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, +convW_weights_memory}, {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); } else { convolution_backward_weights(convW_prim_desc).execute(stream, - {{DNNL_ARG_SRC, convW_src_memory}, - {DNNL_ARG_DIFF_DST, convW_dst_memory}, - {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}}); + {{DNNL_ARG_SRC, +convW_src_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, +convW_weights_memory}}); } - if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { - reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, + if (convW_prim_desc.diff_weights_desc() != +userW_weights_memory.get_desc()) { reorder(convW_weights_memory, +userW_weights_memory).execute(stream, convW_weights_memory, userW_weights_memory); } @@ -480,38 +601,48 @@ static void conv2dBpMKLDNN(sd::graph::Context &block, if (gradI != nullptr) { - auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); + auto convI_desc = +convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, +conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, +conv_padding_r); - auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc); - auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); - auto userI_weights_memory = dnnl::memory(user_weights_md, engine,const_cast(weights)->buffer()); - auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); + auto convI_prim_desc = +convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc); + auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, +gradI->buffer()); auto userI_weights_memory = dnnl::memory(user_weights_md, +engine,const_cast(weights)->buffer()); auto userI_dst_memory = +dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); auto convI_src_memory = userI_src_memory; if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { - convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine); + convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), +engine); } auto convI_weights_memory = userI_weights_memory; if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) { - convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine); - reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory); + convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), +engine); reorder(userI_weights_memory, convI_weights_memory).execute(stream, +userI_weights_memory, convI_weights_memory); } auto convI_dst_memory = userI_dst_memory; if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) { - convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine); - reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory); + convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), +engine); reorder(userI_dst_memory, convI_dst_memory).execute(stream, +userI_dst_memory, convI_dst_memory); } convolution_backward_data(convI_prim_desc).execute(stream, - {{DNNL_ARG_DIFF_DST, convI_dst_memory}, - {DNNL_ARG_WEIGHTS, convI_weights_memory}, - {DNNL_ARG_DIFF_SRC, convI_src_memory}}); + {{DNNL_ARG_DIFF_DST, +convI_dst_memory}, {DNNL_ARG_WEIGHTS, convI_weights_memory}, + {DNNL_ARG_DIFF_SRC, +convI_src_memory}}); if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { - reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory); + reorder(convI_src_memory, userI_src_memory).execute(stream, +convI_src_memory, userI_src_memory); } stream.wait(); @@ -522,116 +653,172 @@ static void conv2dBpMKLDNN(sd::graph::Context &block, ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - bool isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.numI()> 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = + OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + bool isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CONV2D MKLDNN OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CONV2D MKLDNN OP: wrong shape of array with biases, expected " + "rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, + paddingMode, isNCHW, wFormat); + + return Status::OK(); } - PLATFORM_CHECK(conv2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); + auto input = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); - // conv2d is only available for float32 dtype - return block.isUseMKLDNN() && input->dataType() == sd::DataType::FLOAT32 && - weights->dataType() == sd::DataType::FLOAT32; + // conv2d is only available for float32 dtype + return block.isUseMKLDNN() && input->dataType() == sd::DataType::FLOAT32 && + weights->dataType() == sd::DataType::FLOAT32; } ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D_BP MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCHW), epsilon_next + + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradW = OUTPUT_NULLIFIED( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, paddingMode); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CONV2D_BP MKLDNN OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CONV2D_BP MKLDNN OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CONV2D_BP MKLDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, + sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); + + return Status::OK(); } PLATFORM_CHECK(conv2d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - - return block.isUseMKLDNN() && - sd::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB}); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + return block.isUseMKLDNN() && + sd::MKLDNNStream::isSupported( + {input, weights, bias, gradO, gradI, gradW, gradB}); } - - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp index 5ba9f081d8cb..429995cb2818 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp @@ -19,358 +19,468 @@ // @author raver119@gmail.com // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include using namespace dnnl; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////// static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, - const NDArray *bias, NDArray *output, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const int paddingMode, const int isNCDHW, const int wFormat) { - - // mkl support weights in [oC, iC, kD, kH, kW] format only - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d - - dnnl::memory::dims strides = {sD, sH, sW}; - dnnl::memory::dims padding = {pD, pH, pW}; - // dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; - dnnl::memory::dims dilation = {dD-1, dH-1, dW-1}; - - auto xzFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; - - dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; - - auto type = dnnl::memory::data_type::f32; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { - w_user_md.data.format_kind = dnnl_blocked; // overrides format - uint i0, i1, i2, i3, i4; - if(0 == wFormat) { - i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] - } - else if(1 == wFormat) { - i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4; - } - else { - i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW] - } - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); - } - - // bias - dnnl::memory::desc b_mkl_md; - if(bias != nullptr) - b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); - - // output - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(output, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // operation primitive description - dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // bias - if(bias != nullptr) { - auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); - args[DNNL_ARG_BIAS] = b_mkl_mem; + const NDArray *bias, NDArray *output, const int kD, + const int kH, const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int paddingMode, const int isNCDHW, + const int wFormat) { + // mkl support weights in [oC, iC, kD, kH, kW] format only + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) + // * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = {sD, sH, sW}; + dnnl::memory::dims padding = {pD, pH, pW}; + // dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * + // sW - iW + kW - pWSame }; + dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, + (oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW}; + dnnl::memory::dims dilation = {dD - 1, dH - 1, dW - 1}; + + auto xzFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; + + dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; + + auto type = dnnl::memory::data_type::f32; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); + if (weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { + w_user_md.data.format_kind = dnnl_blocked; // overrides format + uint i0, i1, i2, i3, i4; + if (0 == wFormat) { + i0 = 4; + i1 = 3; + i2 = 0; + i3 = 1; + i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] + } else if (1 == wFormat) { + i0 = 0; + i1 = 1; + i2 = 2; + i3 = 3; + i4 = 4; + } else { + i0 = 0; + i1 = 4; + i2 = 1; + i3 = 2; + i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW] } - - // output - auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; - - // run calculations - dnnl::convolution_forward(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); - - stream.wait(); + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); + } + + // bias + dnnl::memory::desc b_mkl_md; + if (bias != nullptr) + b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); + + // output + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(output, z_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::convolution_forward::desc op_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, + x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, + padding_r); + dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, + op_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // bias + if (bias != nullptr) { + auto b_mkl_mem = + dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); + args[DNNL_ARG_BIAS] = b_mkl_mem; + } + + // output + auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = + zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; + + // run calculations + dnnl::convolution_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); } ////////////////////////////////////////////////////////////////////// -static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, - NDArray *gradI, NDArray *gradW, NDArray *gradB, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const int paddingMode, const int isNCDHW, const int wFormat) { - - // mkl support weights/gradW in [oC, iC, kD, kH, kW] format only - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d - - dnnl::memory::dims strides = {sD, sH, sW}; - dnnl::memory::dims padding = {pD, pH, pW}; - // dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; - dnnl::memory::dims dilation = {dD-1, dH-1, dW-1}; - - auto xzFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; - - dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; - - auto type = dnnl::memory::data_type::f32; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { - w_user_md.data.format_kind = dnnl_blocked; // overrides format - uint i0, i1, i2, i3, i4; - if(0 == wFormat) { - i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] - } - else if(1 == wFormat) { - i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4; - } - else { - i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW] - } - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); +static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, + const NDArray *bias, const NDArray *gradO, + NDArray *gradI, NDArray *gradW, NDArray *gradB, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int paddingMode, const int isNCDHW, + const int wFormat) { + // mkl support weights/gradW in [oC, iC, kD, kH, kW] format only + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) + // * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = {sD, sH, sW}; + dnnl::memory::dims padding = {pD, pH, pW}; + // dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * + // sW - iW + kW - pWSame }; + dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, + (oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW}; + dnnl::memory::dims dilation = {dD - 1, dH - 1, dW - 1}; + + auto xzFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; + + dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; + + auto type = dnnl::memory::data_type::f32; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); + if (weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { + w_user_md.data.format_kind = dnnl_blocked; // overrides format + uint i0, i1, i2, i3, i4; + if (0 == wFormat) { + i0 = 4; + i1 = 3; + i2 = 0; + i3 = 1; + i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] + } else if (1 == wFormat) { + i0 = 0; + i1 = 1; + i2 = 2; + i3 = 3; + i4 = 4; + } else { + i0 = 0; + i1 = 4; + i2 = 1; + i3 = 2; + i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW] } - - // gradO - dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); - - mkldnnUtils::setBlockStrides(gradO, gradO_user_md); - - // gradI - dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - - mkldnnUtils::setBlockStrides(gradI, gradI_user_md); - - // gradW - dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) { - gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - uint i0, i1, i2, i3, i4; - if(0 == wFormat) { - i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] - } - else if(1 == wFormat) { - i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4; - } - else { - i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW] - } - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); - gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); - gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4); + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); + } + + // gradO + dnnl::memory::desc gradO_mkl_md = + dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = + dnnl::memory::desc(zDims, type, xzFormatMkl); + + mkldnnUtils::setBlockStrides(gradO, gradO_user_md); + + // gradI + dnnl::memory::desc gradI_mkl_md = + dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = + dnnl::memory::desc(xDims, type, xzFormatMkl); + + mkldnnUtils::setBlockStrides(gradI, gradI_user_md); + + // gradW + dnnl::memory::desc gradW_mkl_md = + dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradW_user_md = + dnnl::memory::desc(wDims, type, wFormatMkl); + if (gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) { + gradW_user_md.data.format_kind = dnnl_blocked; // overrides format + uint i0, i1, i2, i3, i4; + if (0 == wFormat) { + i0 = 4; + i1 = 3; + i2 = 0; + i3 = 1; + i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] + } else if (1 == wFormat) { + i0 = 0; + i1 = 1; + i2 = 2; + i3 = 3; + i4 = 4; + } else { + i0 = 0; + i1 = 4; + i2 = 1; + i3 = 2; + i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW] } - - // gradB - dnnl::memory::desc gradB_mkl_md; - if(gradB != nullptr) - gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // forward primitive description - dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backward data primitive description - dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); - - // backward weights primitive description - dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // gradO - auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); - const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - if (gradOReorderW) - dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW); - if (gradOReorderD) - dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD); - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; - - // gradI - auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); - const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; - - // gradW - auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer()); - const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); - auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; - args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; - - // gradB - if(gradB != nullptr) { - auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); - args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; - } - - // run backward data calculations - dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); - - if(gradOReorderW || gradOReorderD) - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; - - // run backward weights calculations - dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); - - // reorder gradI if necessary - if (gradIReorder) - dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); - if (gradWReorder) - dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); + gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); + gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); + gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4); + } + + // gradB + dnnl::memory::desc gradB_mkl_md; + if (gradB != nullptr) + gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // forward primitive description + dnnl::convolution_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, + x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, + padding, padding_r); + dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward data primitive description + dnnl::convolution_backward_data::desc op_data_bp_desc( + dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, + strides, dilation, padding, padding_r); + dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc( + op_data_bp_desc, engine, op_ff_prim_desc); + + // backward weights primitive description + dnnl::convolution_backward_weights::desc op_weights_bp_desc( + dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, + gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc( + op_weights_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, + op_weights_bp_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, + op_data_bp_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // gradO + auto gradO_user_mem = + dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); + const bool gradOReorderW = + op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + const bool gradOReorderD = + op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_memW = + gradOReorderW + ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + auto gradO_mkl_memD = + gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + if (gradOReorderW) + dnnl::reorder(gradO_user_mem, gradO_mkl_memW) + .execute(stream, gradO_user_mem, gradO_mkl_memW); + if (gradOReorderD) + dnnl::reorder(gradO_user_mem, gradO_mkl_memD) + .execute(stream, gradO_user_mem, gradO_mkl_memD); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; + + // gradI + auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); + const bool gradIReorder = + op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); + auto gradI_mkl_mem = + gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) + : gradI_user_mem; + args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + + // gradW + auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer()); + const bool gradWReorder = + op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); + auto gradW_mkl_mem = + gradWReorder + ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) + : gradW_user_mem; + args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; + + // gradB + if (gradB != nullptr) { + auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); + args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; + } + + // run backward data calculations + dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + + if (gradOReorderW || gradOReorderD) args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; + + // run backward weights calculations + dnnl::convolution_backward_weights(op_weights_bp_prim_desc) + .execute(stream, args); + + // reorder gradI if necessary + if (gradIReorder) + dnnl::reorder(gradI_mkl_mem, gradI_user_mem) + .execute(stream, gradI_mkl_mem, gradI_user_mem); + if (gradWReorder) + dnnl::reorder(gradW_mkl_mem, gradW_user_mem) + .execute(stream, gradW_mkl_mem, gradW_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } - /* ////////////////////////////////////////////////////////////////////// static void conv3dMKLDNN(sd::graph::Context &block, - const NDArray *input, const NDArray *weights, const NDArray *bias, - NDArray *output, - const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW, - const int paddingMode, const int isNCDHW) { - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + const NDArray *input, const NDArray *weights, const +NDArray *bias, NDArray *output, const int kD, const int kH, const int kW, const +int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const +int dH, const int dW, const int paddingMode, const int isNCDHW) { + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, +input channels, input depth/height/width, output channels, output +depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // +corresponding indexes ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, +*input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, +indWoC, indWkD); dnnl_memory_desc_t empty; - dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( empty); - dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( empty); - - dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; - - mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, - isNCDHW, - bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights, - nullptr, bias, output, - &conv_src_md, nullptr, &conv_weights_md, nullptr, - &conv_bias_md, &conv_dst_md, - &user_src_md, nullptr, &user_weights_md, nullptr, - &user_bias_md, &user_dst_md, - conv_strides, conv_padding, conv_padding_r, conv_dilation); - auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - - auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine); - auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast(input)->buffer()); - auto user_weights_memory = dnnl::memory(user_weights_md, engine, const_cast(weights)->buffer()); + dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), +conv_bias_md(empty), conv_dst_md( empty); dnnl::memory::desc user_src_md(empty), +user_weights_md(empty), user_bias_md(empty), user_dst_md( empty); + + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, +conv_dilation; + + mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, +dD, dH, dW, paddingMode, isNCDHW, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, +nullptr, weights, nullptr, bias, output, &conv_src_md, nullptr, +&conv_weights_md, nullptr, &conv_bias_md, &conv_dst_md, &user_src_md, nullptr, +&user_weights_md, nullptr, &user_bias_md, &user_dst_md, conv_strides, +conv_padding, conv_padding_r, conv_dilation); auto conv_desc = bias != nullptr ? +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, +conv_dilation, conv_padding, conv_padding_r) : +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, +conv_padding, conv_padding_r); + + auto engine = +mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); dnnl::stream +stream(engine); + + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, +engine); auto user_src_memory = dnnl::memory(user_src_md, engine, +const_cast(input)->buffer()); auto user_weights_memory = +dnnl::memory(user_weights_md, engine, const_cast(weights)->buffer()); auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); auto conv_src_memory = user_src_memory; if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) { conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine); - reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory); + reorder(user_src_memory, conv_src_memory).execute(stream, +user_src_memory, conv_src_memory); } auto conv_weights_memory = user_weights_memory; if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) { - conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine); - reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory, conv_weights_memory); + conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), +engine); reorder(user_weights_memory, conv_weights_memory).execute(stream, +user_weights_memory, conv_weights_memory); } auto conv_dst_memory = user_dst_memory; @@ -379,20 +489,21 @@ static void conv3dMKLDNN(sd::graph::Context &block, } if (bias != nullptr) { - auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->buffer()); - convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, - {DNNL_ARG_WEIGHTS, conv_weights_memory}, - {DNNL_ARG_BIAS, conv_bias_memory}, - {DNNL_ARG_DST, conv_dst_memory}}); + auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, +bias->buffer()); convolution_forward(conv_prim_desc).execute(stream, +{{DNNL_ARG_SRC, conv_src_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory}, + {DNNL_ARG_BIAS, +conv_bias_memory}, {DNNL_ARG_DST, conv_dst_memory}}); } else { - convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, - {DNNL_ARG_WEIGHTS, conv_weights_memory}, - {DNNL_ARG_DST, conv_dst_memory}}); + convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, +conv_src_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory}, {DNNL_ARG_DST, +conv_dst_memory}}); } if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) - reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory); + reorder(conv_dst_memory, user_dst_memory).execute(stream, +conv_dst_memory, user_dst_memory); stream.wait(); } @@ -400,247 +511,380 @@ static void conv3dMKLDNN(sd::graph::Context &block, ////////////////////////////////////////////////////////////////////// static void conv3dBpMKLDNN(sd::graph::Context &block, - const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, - NDArray *gradI, NDArray *gradW, NDArray *gradB, - const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW, + const NDArray *input, const NDArray *weights, const +NDArray *bias, const NDArray *gradO, NDArray *gradI, NDArray *gradW, NDArray +*gradB, const int kD, const int kH, const int kW, const int sD, const int sH, +const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW, const int paddingMode, const int isNCDHW) { - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, +input channels, input depth/height/width, output channels, output +depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // +corresponding indexes ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, +*input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, +indWoC, indWkD); dnnl_memory_desc_t empty; - dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); - - dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; - - mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, - isNCDHW, - bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights, - gradW, gradB, gradO, - &conv_src_md, &conv_diff_src_md, &conv_weights_md, - &conv_diff_weights_md, &conv_bias_md, &conv_dst_md, - &user_src_md, &user_diff_src_md, &user_weights_md, - &user_diff_weights_md, &user_bias_md, &user_dst_md, - conv_strides, conv_padding, conv_padding_r, conv_dilation); - - auto conv_desc = gradB != nullptr ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); - - auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine())); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); + dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), +conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), +conv_dst_md(empty); dnnl::memory::desc user_src_md(empty), +user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), +user_bias_md(empty), user_dst_md(empty); + + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, +conv_dilation; + + mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, +dD, dH, dW, paddingMode, isNCDHW, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, +gradI, weights, gradW, gradB, gradO, &conv_src_md, &conv_diff_src_md, +&conv_weights_md, &conv_diff_weights_md, &conv_bias_md, &conv_dst_md, + &user_src_md, &user_diff_src_md, +&user_weights_md, &user_diff_weights_md, &user_bias_md, &user_dst_md, + conv_strides, conv_padding, +conv_padding_r, conv_dilation); + + auto conv_desc = gradB != nullptr ? +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, +conv_dilation, conv_padding, conv_padding_r) : +convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, +conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, +conv_padding, conv_padding_r); + + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, +mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine())); + + auto engine = +mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); dnnl::stream +stream(engine); if (gradW != nullptr) { - auto convW_desc = gradB != nullptr ? convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc); - - auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast(input)->buffer()); - auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); - auto userW_dst_memory = dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); + auto convW_desc = gradB != nullptr ? +convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, +conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, +conv_padding, conv_padding_r) : +convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, +conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, +conv_padding_r); auto engine = +mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + auto convW_prim_desc = +convolution_backward_weights::primitive_desc(convW_desc, engine, +conv_prim_desc); + + auto userW_src_memory = dnnl::memory(user_src_md, engine, +const_cast(input)->buffer()); auto userW_weights_memory = +dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); auto +userW_dst_memory = dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); auto convW_src_memory = userW_src_memory; if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) { convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine); - reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, convW_src_memory); + reorder(userW_src_memory, convW_src_memory).execute(stream, +userW_src_memory, convW_src_memory); } auto convW_weights_memory = userW_weights_memory; - if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { - convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine); + if (convW_prim_desc.diff_weights_desc() != +userW_weights_memory.get_desc()) { convW_weights_memory = +dnnl::memory(convW_prim_desc.diff_weights_desc(), engine); } auto convW_dst_memory = userW_dst_memory; if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) { - convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine); - reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory); + convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), +engine); reorder(userW_dst_memory, convW_dst_memory).execute(stream, +userW_dst_memory, convW_dst_memory); } if (gradB != nullptr) { - auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer()); + auto convW_bias_memory = +dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer()); convolution_backward_weights(convW_prim_desc).execute(stream, - {{DNNL_ARG_SRC, convW_src_memory}, - {DNNL_ARG_DIFF_DST, convW_dst_memory}, - {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}, - {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); + {{DNNL_ARG_SRC, +convW_src_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, +convW_weights_memory}, {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); } else { convolution_backward_weights(convW_prim_desc).execute(stream, - {{DNNL_ARG_SRC, convW_src_memory}, - {DNNL_ARG_DIFF_DST, convW_dst_memory}, - {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}}); + {{DNNL_ARG_SRC, +convW_src_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, +convW_weights_memory}}); } - if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) - reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, userW_weights_memory); + if (convW_prim_desc.diff_weights_desc() != +userW_weights_memory.get_desc()) reorder(convW_weights_memory, +userW_weights_memory).execute(stream, convW_weights_memory, +userW_weights_memory); stream.wait(); } if (gradI != nullptr) { - auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); - - auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc); - auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); - auto userI_weights_memory = dnnl::memory(user_weights_md, engine, const_cast(weights)->buffer()); - auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); + auto convI_desc = +convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, +conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, +conv_padding_r); + + auto convI_prim_desc = +convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc); + auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, +gradI->buffer()); auto userI_weights_memory = dnnl::memory(user_weights_md, +engine, const_cast(weights)->buffer()); auto userI_dst_memory = +dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); auto convI_src_memory = userI_src_memory; if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) - convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine); + convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), +engine); auto convI_weights_memory = userI_weights_memory; if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) { - convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine); - reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory); + convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), +engine); reorder(userI_weights_memory, convI_weights_memory).execute(stream, +userI_weights_memory, convI_weights_memory); } auto convI_dst_memory = userI_dst_memory; if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) { - convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine); - reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory); + convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), +engine); reorder(userI_dst_memory, convI_dst_memory).execute(stream, +userI_dst_memory, convI_dst_memory); } convolution_backward_data(convI_prim_desc).execute(stream, - {{DNNL_ARG_DIFF_DST, convI_dst_memory}, - {DNNL_ARG_WEIGHTS, convI_weights_memory}, - {DNNL_ARG_DIFF_SRC, convI_src_memory}}); + {{DNNL_ARG_DIFF_DST, +convI_dst_memory}, {DNNL_ARG_WEIGHTS, convI_weights_memory}, + {DNNL_ARG_DIFF_SRC, +convI_src_memory}}); if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) - reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory); + reorder(convI_src_memory, userI_src_memory).execute(stream, +convI_src_memory, userI_src_memory); } } */ ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if (paddingMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM CONV3D MKLDNN OP: rank of input array must be equal to " + "5, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM CONV3D MKLDNN OP: rank of weights array must be equal " + "to 5, but got %i instead !", + weights->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], + // 2 - [oC, kD, kH, kW, iC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, + dD, dH, dW, paddingMode, isNCDHW, wFormat); + + return Status::OK(); } PLATFORM_CHECK(conv3dnew, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, weights, bias, output}); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + return block.isUseMKLDNN() && + sd::MKLDNNStream::isSupported({input, weights, bias, output}); } ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - int trueoD, trueoH, trueoW; // true output depth/height/width - ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_NULLIFIED( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal " + "to 5, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM CONV3D_BP MKLDNN OP: rank of weights array must be " + "equal to 5, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 5, 0, + "CUSTOM CONV3D_BP MKLDNN OP: rank of output gradients (next " + "epsilon) array must be equal to 5, but got %i instead !", + gradO->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], + // 2 - [oC, kD, kH, kW, iC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + int trueoD, trueoH, trueoW; // true output depth/height/width + ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, + iW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, + 0, indIOioC, indIOioD, + indIOioD + 1, indIOioD + 2}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is " + "%s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM CONV3D_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, + sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, + wFormat); + + return Status::OK(); } PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - return block.isUseMKLDNN() && - sd::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB}); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE( + 1); // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + return block.isUseMKLDNN() && + sd::MKLDNNStream::isSupported( + {input, weights, bias, gradO, gradI, gradW, gradB}); } - - -} -} -} \ No newline at end of file +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp index 2099e82933aa..1cb36d813694 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp @@ -18,463 +18,652 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, - const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const bool isNCHW, const int wFormat) { - - // mkl supports weights format [oC, iC, kH, kW] only - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - - dnnl::memory::dims strides = { sH, sW }; - dnnl::memory::dims padding = { pH, pW }; - dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; - dnnl::memory::dims dilation = { dH-1, dW-1 }; - - uint i0, i1, i2, i3; - if(0 == wFormat) { - i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] - } - else if(1 == wFormat) { - i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [iC, oC, kH, kW] -> [oC, iC, kH, kW] - } - else { - i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] - } - - // input type - dnnl::memory::data_type xType; - if(input->dataType() == DataType::FLOAT32) - xType = dnnl::memory::data_type::f32; - else if(input->dataType() == DataType::HALF) - xType = dnnl::memory::data_type::f16; - else if(input->dataType() == DataType::UINT8) - xType = dnnl::memory::data_type::u8; - else - xType = dnnl::memory::data_type::s8; - - // weights type - dnnl::memory::data_type wType = xType; - if(xType == dnnl::memory::data_type::u8) - wType = dnnl::memory::data_type::s8; - - // output and bias type (have the same types) - dnnl::memory::data_type zType; - if(output->dataType() == DataType::FLOAT32) - zType = dnnl::memory::data_type::f32; - else if(output->dataType() == DataType::HALF) - zType = dnnl::memory::data_type::f16; - else if(output->dataType() == DataType::UINT8) - zType = dnnl::memory::data_type::u8; - else if(output->dataType() == DataType::INT8) - zType = dnnl::memory::data_type::s8; - else - zType = dnnl::memory::data_type::s32; - - dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; - - dnnl::memory::dims xDims = {bS, iC, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oH, oW}; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); - - // bias - dnnl::memory::desc b_mkl_md; - if(bias != nullptr) - b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x); - - // output - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl); - mkldnnUtils::setBlockStrides(output, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // operation primitive description - dnnl::deconvolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, - x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // bias - if(bias != nullptr) { - auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); - args[DNNL_ARG_BIAS] = b_mkl_mem; - } - - // output - auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; - - // run calculations - dnnl::deconvolution_forward(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); +static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, + const NDArray* bias, NDArray* output, const int kH, + const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, + const int dW, const int paddingMode, + const bool isNCHW, const int wFormat) { + // mkl supports weights format [oC, iC, kH, kW] only + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWoC, indWiC, indWkH, indOoH); + + dnnl::memory::dims strides = {sH, sW}; + dnnl::memory::dims padding = {pH, pW}; + dnnl::memory::dims padding_r = {(iH - 1) * sH - oH + kH - pH, + (iW - 1) * sW - oW + kW - pW}; + dnnl::memory::dims dilation = {dH - 1, dW - 1}; + + uint i0, i1, i2, i3; + if (0 == wFormat) { + i0 = 2; + i1 = 3; + i2 = 0; + i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] + } else if (1 == wFormat) { + i0 = 1; + i1 = 0; + i2 = 2; + i3 = 3; // [iC, oC, kH, kW] -> [oC, iC, kH, kW] + } else { + i0 = 3; + i1 = 0; + i2 = 1; + i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] + } + + // input type + dnnl::memory::data_type xType; + if (input->dataType() == DataType::FLOAT32) + xType = dnnl::memory::data_type::f32; + else if (input->dataType() == DataType::HALF) + xType = dnnl::memory::data_type::f16; + else if (input->dataType() == DataType::UINT8) + xType = dnnl::memory::data_type::u8; + else + xType = dnnl::memory::data_type::s8; + + // weights type + dnnl::memory::data_type wType = xType; + if (xType == dnnl::memory::data_type::u8) wType = dnnl::memory::data_type::s8; + + // output and bias type (have the same types) + dnnl::memory::data_type zType; + if (output->dataType() == DataType::FLOAT32) + zType = dnnl::memory::data_type::f32; + else if (output->dataType() == DataType::HALF) + zType = dnnl::memory::data_type::f16; + else if (output->dataType() == DataType::UINT8) + zType = dnnl::memory::data_type::u8; + else if (output->dataType() == DataType::INT8) + zType = dnnl::memory::data_type::s8; + else + zType = dnnl::memory::data_type::s32; + + dnnl::memory::format_tag xFormatMkl = + isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); + w_user_md.data.format_kind = dnnl_blocked; // overrides format + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + + // bias + dnnl::memory::desc b_mkl_md; + if (bias != nullptr) + b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x); + + // output + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl); + mkldnnUtils::setBlockStrides(output, z_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::deconvolution_forward::desc op_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, + x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, + padding_r); + dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, + op_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // bias + if (bias != nullptr) { + auto b_mkl_mem = + dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); + args[DNNL_ARG_BIAS] = b_mkl_mem; + } + + // output + auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = + zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; + + // run calculations + dnnl::deconvolution_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } ////////////////////////////////////////////////////////////////////////// -static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, - const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const bool isNCHW, const int wFormat) { - - // mkl supports weights/gradW in [oC, iC, kH, kW] format only - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - - dnnl::memory::dims strides = { sH, sW }; - dnnl::memory::dims padding = { pH, pW }; - dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; - dnnl::memory::dims dilation = { dH-1, dW-1 }; - - uint i0, i1, i2, i3; - if(0 == wFormat) { - i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] - } - else if(1 == wFormat) { - i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [iC, oC, kH, kW] -> [oC, iC, kH, kW] - } - else { - i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] - } - - // input type - dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // weights type - dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradO type - dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradI type - dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradW type - dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradB type - dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; - - dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; - - dnnl::memory::dims xDims = {bS, iC, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oH, oW}; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); - - // gradO - dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl); - mkldnnUtils::setBlockStrides(gradO, gradO_user_md); - - // gradI - dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl); - mkldnnUtils::setBlockStrides(gradI, gradI_user_md); - - // gradW - dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl); - gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); - gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); - - // gradB - dnnl::memory::desc gradB_mkl_md; - if(gradB != nullptr) - gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // forward primitive description - dnnl::deconvolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backward data primitive description - dnnl::deconvolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); - - // backward weights primitive description - dnnl::deconvolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // gradO - auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); - const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - if (gradOReorderW) - dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW); - if (gradOReorderD) - dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD); - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; - - // gradI - auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); - const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; - - // gradW - auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer()); - const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); - auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; - args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; - - // gradB - if(gradB != nullptr) { - auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); - args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; - } - - // run backward data calculations - dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); - - if(gradOReorderW || gradOReorderD) - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; - - // run backward weights calculations - dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); - - // reorder gradI if necessary - if (gradIReorder) - dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); - if (gradWReorder) - dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); +static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, + const NDArray* gradO, NDArray* gradI, + NDArray* gradW, NDArray* gradB, const int kH, + const int kW, const int sH, const int sW, + const int pH, const int pW, const int dH, + const int dW, const int paddingMode, + const bool isNCHW, const int wFormat) { + // mkl supports weights/gradW in [oC, iC, kH, kW] format only + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWoC, indWiC, indWkH, indOoH); + + dnnl::memory::dims strides = {sH, sW}; + dnnl::memory::dims padding = {pH, pW}; + dnnl::memory::dims padding_r = {(iH - 1) * sH - oH + kH - pH, + (iW - 1) * sW - oW + kW - pW}; + dnnl::memory::dims dilation = {dH - 1, dW - 1}; + + uint i0, i1, i2, i3; + if (0 == wFormat) { + i0 = 2; + i1 = 3; + i2 = 0; + i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] + } else if (1 == wFormat) { + i0 = 1; + i1 = 0; + i2 = 2; + i3 = 3; // [iC, oC, kH, kW] -> [oC, iC, kH, kW] + } else { + i0 = 3; + i1 = 0; + i2 = 1; + i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] + } + + // input type + dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // weights type + dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradO type + dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradI type + dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradW type + dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradB type + dnnl::memory::data_type gradBType = + gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16) + : dnnl::memory::data_type::f32; + + dnnl::memory::format_tag xFormatMkl = + isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); + w_user_md.data.format_kind = dnnl_blocked; // overrides format + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + + // gradO + dnnl::memory::desc gradO_mkl_md = + dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = + dnnl::memory::desc(zDims, gradOType, xFormatMkl); + mkldnnUtils::setBlockStrides(gradO, gradO_user_md); + + // gradI + dnnl::memory::desc gradI_mkl_md = + dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = + dnnl::memory::desc(xDims, gradIType, xFormatMkl); + mkldnnUtils::setBlockStrides(gradI, gradI_user_md); + + // gradW + dnnl::memory::desc gradW_mkl_md = + dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradW_user_md = + dnnl::memory::desc(wDims, gradWType, wFormatMkl); + gradW_user_md.data.format_kind = dnnl_blocked; // overrides format + gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); + gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); + + // gradB + dnnl::memory::desc gradB_mkl_md; + if (gradB != nullptr) + gradB_mkl_md = + dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // forward primitive description + dnnl::deconvolution_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, + x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, + padding, padding_r); + dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, + engine); + + // backward data primitive description + dnnl::deconvolution_backward_data::desc op_data_bp_desc( + dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, + gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc( + op_data_bp_desc, engine, op_ff_prim_desc); + + // backward weights primitive description + dnnl::deconvolution_backward_weights::desc op_weights_bp_desc( + dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, + gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc( + op_weights_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, + op_weights_bp_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, + op_data_bp_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // gradO + auto gradO_user_mem = + dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); + const bool gradOReorderW = + op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + const bool gradOReorderD = + op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_memW = + gradOReorderW + ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + auto gradO_mkl_memD = + gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + if (gradOReorderW) + dnnl::reorder(gradO_user_mem, gradO_mkl_memW) + .execute(stream, gradO_user_mem, gradO_mkl_memW); + if (gradOReorderD) + dnnl::reorder(gradO_user_mem, gradO_mkl_memD) + .execute(stream, gradO_user_mem, gradO_mkl_memD); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; + + // gradI + auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); + const bool gradIReorder = + op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); + auto gradI_mkl_mem = + gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) + : gradI_user_mem; + args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + + // gradW + auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer()); + const bool gradWReorder = + op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); + auto gradW_mkl_mem = + gradWReorder + ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) + : gradW_user_mem; + args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; + + // gradB + if (gradB != nullptr) { + auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); + args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; + } + + // run backward data calculations + dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + + if (gradOReorderW || gradOReorderD) args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; + + // run backward weights calculations + dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc) + .execute(stream, args); + + // reorder gradI if necessary + if (gradIReorder) + dnnl::reorder(gradI_mkl_mem, gradI_user_mem) + .execute(stream, gradI_mkl_mem, gradI_user_mem); + if (gradWReorder) + dnnl::reorder(gradW_mkl_mem, gradW_user_mem) + .execute(stream, gradW_mkl_mem, gradW_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(deconv2d, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(paddingMode){ // SAME - //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); - } - - deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = + OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM DECONV2D_MKLDNN OP: rank of input array must be equal " + "to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM DECONV2D_MKLDNN OP: rank of weights array must be equal " + "to 4, but got %i instead !", + weights->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, + // oC, kH, kW], 2 - [iC, kH, kW, oC] + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWoC, indWiC, indWkH, indOoH); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV2D_MKLDNN OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV2D_MKLDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (paddingMode) { // SAME + // Note: we're intentionally swapping iH and oH, to calculated the padding + // for a"normal" conv (not deconv) forward pass + ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, + dW); + } + + deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, + paddingMode, isNCHW, wFormat); + + return Status::OK(); } PLATFORM_CHECK(deconv2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; - - auto output = INPUT_VARIABLE(0); - - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - - const DataType xType = input->dataType(); - const DataType wType = weights->dataType(); - const DataType zType = output->dataType(); - const DataType bType = bias != nullptr ? bias->dataType() : zType; - - return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !paddingMode) && - ( - (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) || - ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType) - ); + auto input = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; + + auto output = INPUT_VARIABLE(0); + + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + + const DataType xType = input->dataType(); + const DataType wType = weights->dataType(); + const DataType zType = output->dataType(); + const DataType bType = bias != nullptr ? bias->dataType() : zType; + + return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !paddingMode) && + ((xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && + bType == DataType::FLOAT32 && zType == DataType::FLOAT32) || + ((xType == DataType::UINT8 || xType == DataType::INT8) && + wType == DataType::INT8 && + (zType == DataType::UINT8 || zType == DataType::INT8 || + zType == DataType::INT32 || zType == DataType::FLOAT32) && + bType == zType)); } - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of weights array must be equal to 4 , but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(paddingMode){ // SAME - //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); - } - - deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI + auto gradW = OUTPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM DECONV2D_MKLDNN_BP OP: rank of input array must be " + "equal to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM DECONV2D_MKLDNN_BP OP: rank of weights array must be " + "equal to 4 , but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "CUSTOM DECONV2D_MKLDNN_BP OP: rank of output gradients (next " + "epsilon) array must be equal to 4, but got %i instead !", + gradO->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, oC, iC], 1 - [iC, + // oC, kH, kW], 2 - [iC, kH, kW, oC] + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWoC, indWiC, indWkH, indOoH); + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, oC, iC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE( + bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (paddingMode) { // SAME + // Note: we're intentionally swapping iH and oH, to calculated the padding + // for a"normal" conv (not deconv) forward pass + ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, + dW); + } + + deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, + pH, pW, dH, dW, paddingMode, isNCHW, wFormat); + + return Status::OK(); } PLATFORM_CHECK(deconv2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - - const DataType xType = input->dataType(); - const DataType wType = weights->dataType(); - const DataType gradOType = gradO->dataType(); - - const DataType gradIType = gradI->dataType(); - const DataType gradWType = gradW->dataType(); - const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32; - - return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !paddingMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) ); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI + auto gradW = OUTPUT_VARIABLE( + 1); // [kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + + const DataType xType = input->dataType(); + const DataType wType = weights->dataType(); + const DataType gradOType = gradO->dataType(); + + const DataType gradIType = gradI->dataType(); + const DataType gradWType = gradW->dataType(); + const DataType gradBType = + gradB != nullptr ? gradB->dataType() : DataType::FLOAT32; + + return block.isUseMKLDNN() && (dH <= 1 && dW <= 1 && !paddingMode) && + ((xType == DataType::FLOAT32 || xType == DataType::BFLOAT16) && + (wType == DataType::FLOAT32 || wType == DataType::BFLOAT16) && + (gradOType == DataType::FLOAT32 || gradOType == DataType::BFLOAT16) && + (gradIType == DataType::FLOAT32 || gradIType == DataType::BFLOAT16) && + (gradWType == DataType::FLOAT32 || gradWType == DataType::BFLOAT16) && + (gradBType == DataType::FLOAT32 || gradBType == DataType::BFLOAT16)); } - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp index 45809cb5af0e..fea01c50a9a5 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp @@ -18,204 +18,280 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI, - const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW, - const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const bool isNCHW, const int wFormat) { - - // gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format - // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC] - // gradO [bS, oH, oW, oC] - - dnnl::memory::dims strides = { sH, sW }; - dnnl::memory::dims dilation = { dH - 1, dW - 1 }; - dnnl::memory::dims padding = { pH, pW }; - dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; - - // weights type - dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradO type - dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradI type - dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - - dnnl::memory::format_tag xFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; - - dnnl::memory::dims xDims = {bS, iC, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oH, oW}; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, gradOType, dnnl::memory::format_tag::any); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); - - // gradO - dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl); - mkldnnUtils::setBlockStrides(gradO, gradO_user_md); - - // gradI - dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl); - mkldnnUtils::setBlockStrides(gradI, gradI_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // forward primitive description - dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backward data primitive description - dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // gradO - mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); - - // gradI - auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); - const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; - - // run backward data calculations - dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); - - // reorder gradI if necessary - if (gradIReorder) - dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); +static void deconv2TFdBackPropMKLDNN(const NDArray* weights, + const NDArray* gradO, NDArray* gradI, + const int bS, const int iC, const int iH, + const int iW, const int oC, const int oH, + const int oW, const int kH, const int kW, + const int sH, const int sW, const int pH, + const int pW, const int dH, const int dW, + const bool isNCHW, const int wFormat) { + // gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format + // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, + // kW, iC, oC] gradO [bS, oH, oW, oC] + + dnnl::memory::dims strides = {sH, sW}; + dnnl::memory::dims dilation = {dH - 1, dW - 1}; + dnnl::memory::dims padding = {pH, pW}; + dnnl::memory::dims padding_r = {(oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW}; + + // weights type + dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradO type + dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradI type + dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + + dnnl::memory::format_tag xFormatMkl = + isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, gradOType, dnnl::memory::format_tag::any); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); + w_user_md.data.format_kind = dnnl_blocked; // overrides format + w_user_md.data.format_desc.blocking.strides[0] = + weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); + + // gradO + dnnl::memory::desc gradO_mkl_md = + dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = + dnnl::memory::desc(zDims, gradOType, xFormatMkl); + mkldnnUtils::setBlockStrides(gradO, gradO_user_md); + + // gradI + dnnl::memory::desc gradI_mkl_md = + dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = + dnnl::memory::desc(xDims, gradIType, xFormatMkl); + mkldnnUtils::setBlockStrides(gradI, gradI_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // forward primitive description + dnnl::convolution_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, + x_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward data primitive description + dnnl::convolution_backward_data::desc op_data_bp_desc( + dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, + strides, dilation, padding, padding_r); + dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc( + op_data_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // weights + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, + op_data_bp_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // gradO + mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, + op_data_bp_prim_desc.diff_dst_desc(), + args[DNNL_ARG_DIFF_DST]); + + // gradI + auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); + const bool gradIReorder = + op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); + auto gradI_mkl_mem = + gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) + : gradI_user_mem; + args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + + // run backward data calculations + dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + + // reorder gradI if necessary + if (gradIReorder) + dnnl::reorder(gradI_mkl_mem, gradI_user_mem) + .execute(stream, gradI_mkl_mem, gradI_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { - - auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI) - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - [oC, kH, kW, iC] - - const int rank = gradO->rankOf(); - - REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM DECONV2D_TF MKLDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM DECONV2D_TF MKLDNN OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf()); - REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF MKLDNN OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf()); - - int indIOioC, indIiH, indWoC(3), indOoH; - if(!isNCHW) { - indIOioC = 3; indIiH = 1; indOoH = 1; - } - else { - indIOioC = 1; indIiH = 2; indOoH = 2; - } - - std::vector gradIShapeVector = gradIShape->template asVectorT(); - - const int bS = gradIShapeVector[0]; // batch size - const int iH = gradIShapeVector[indIiH]; // input height - const int iW = gradIShapeVector[indIiH+1]; // input width - const int iC = gradIShapeVector[indIOioC]; // input channels - const int oC = weights->sizeAt(indWoC); // output channels - const int oH = gradO->sizeAt(indOoH); // input height - const int oW = gradO->sizeAt(indOoH); // input width - - int trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = {kH, kW, iC, oC}; - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - // // mkl supports only [oC, iC, kH, kW] for weights - // weights = new NDArray(weights->permute({3,2,0,1})); // [kH, kW, iC, oC] -> [oC, iC, kH, kW] - - // // mkl supports NCHW format only - // if(!isNCHW) { - // gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - // gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - // } - - deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat); - - // delete weights; - - // if(!isNCHW) { - // delete gradI; - // delete gradO; - // } - - // ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); - - return Status::OK(); + auto gradO = INPUT_VARIABLE( + 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradIShape = INPUT_VARIABLE( + 0); // [4] - shape of input of conv2d (that is shape of gradI) + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, + // iC, kH, kW], 2 - [oC, kH, kW, iC] + + const int rank = gradO->rankOf(); + + REQUIRE_TRUE(weights->rankOf() == rank, 0, + "CUSTOM DECONV2D_TF MKLDNN OP: rank of weights array must be " + "equal to 4, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, + "CUSTOM DECONV2D_TF MKLDNN OP: rank of array with output shape " + "must be equal to 1, but got %i instead !", + gradIShape->rankOf()); + REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, + "CUSTOM DECONV2D_TF MKLDNN OP: length of array with output " + "shape must be equal to 4, but got %i instead !", + gradIShape->lengthOf()); + + int indIOioC, indIiH, indWoC(3), indOoH; + if (!isNCHW) { + indIOioC = 3; + indIiH = 1; + indOoH = 1; + } else { + indIOioC = 1; + indIiH = 2; + indOoH = 2; + } + + std::vector gradIShapeVector = + gradIShape->template asVectorT(); + + const int bS = gradIShapeVector[0]; // batch size + const int iH = gradIShapeVector[indIiH]; // input height + const int iW = gradIShapeVector[indIiH + 1]; // input width + const int iC = gradIShapeVector[indIOioC]; // input channels + const int oC = weights->sizeAt(indWoC); // output channels + const int oH = gradO->sizeAt(indOoH); // input height + const int oW = gradO->sizeAt(indOoH); // input width + + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = {kH, kW, iC, oC}; + REQUIRE_TRUE( + gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of input array, basing on " + "array with output shape expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + // // mkl supports only [oC, iC, kH, kW] for weights + // weights = new NDArray(weights->permute({3,2,0,1})); // [kH, kW, iC, + // oC] -> [oC, iC, kH, kW] + + // // mkl supports NCHW format only + // if(!isNCHW) { + // gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] + // -> [bS, iC, iH, iW] gradO = new NDArray(gradO->permute({0,3,1,2})); // + // [bS, oH, oW, oC] -> [bS, oC, oH, oW] + // } + + deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, + kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat); + + // delete weights; + + // if(!isNCHW) { + // delete gradI; + // delete gradO; + // } + + // ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, + // nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); + + return Status::OK(); } PLATFORM_CHECK(deconv2d_tf, ENGINE_CPU) { - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always - auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI - - - const DataType wType = weights->dataType(); - const DataType gradOType = gradO->dataType(); - const DataType gradIType = gradI->dataType(); - - return block.isUseMKLDNN() && ((wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16)); + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto gradO = INPUT_VARIABLE( + 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI + + const DataType wType = weights->dataType(); + const DataType gradOType = gradO->dataType(); + const DataType gradIType = gradI->dataType(); + + return block.isUseMKLDNN() && + ((wType == DataType::FLOAT32 || wType == DataType::BFLOAT16) && + (gradOType == DataType::FLOAT32 || gradOType == DataType::BFLOAT16) && + (gradIType == DataType::FLOAT32 || gradIType == DataType::BFLOAT16)); } -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp index eed031cf3f63..4267424db401 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp @@ -18,479 +18,684 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include - -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, - const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, - const bool isNCDHW, const int wFormat) { - - // mkl supports weights in [oC, iC, kD, kH, kW] only - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - - dnnl::memory::dims strides = { sD, sH, sW }; - dnnl::memory::dims padding = { pD, pH, pW }; - dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; - dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 }; - - uint i0, i1, i2, i3, i4; - if(0 == wFormat) { - i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] - } - else if(1 == wFormat) { - i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW] - } - else { - i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW] - } - - // input type - dnnl::memory::data_type xType; - if(input->dataType() == DataType::FLOAT32) - xType = dnnl::memory::data_type::f32; - else if(input->dataType() == DataType::HALF) - xType = dnnl::memory::data_type::f16; - else if(input->dataType() == DataType::UINT8) - xType = dnnl::memory::data_type::u8; - else - xType = dnnl::memory::data_type::s8; - - // weights type - dnnl::memory::data_type wType = xType; - if(xType == dnnl::memory::data_type::u8) - wType = dnnl::memory::data_type::s8; - - // output and bias type (have the same types) - dnnl::memory::data_type zType; - if(output->dataType() == DataType::FLOAT32) - zType = dnnl::memory::data_type::f32; - else if(output->dataType() == DataType::HALF) - zType = dnnl::memory::data_type::f16; - else if(output->dataType() == DataType::UINT8) - zType = dnnl::memory::data_type::u8; - else if(output->dataType() == DataType::INT8) - zType = dnnl::memory::data_type::s8; - else - zType = dnnl::memory::data_type::s32; - - dnnl::memory::format_tag xFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; - - dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); - - // bias - dnnl::memory::desc b_mkl_md; - if(bias != nullptr) - b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x); - - // output - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl); - mkldnnUtils::setBlockStrides(output, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // operation primitive description - dnnl::deconvolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, - x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // bias - if(bias != nullptr) { - auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); - args[DNNL_ARG_BIAS] = b_mkl_mem; - } - - // output - auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; - - // run calculations - dnnl::deconvolution_forward(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); +static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, + const NDArray* bias, NDArray* output, const int kD, + const int kH, const int kW, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW, const bool isNCDHW, + const int wFormat) { + // mkl supports weights in [oC, iC, kD, kH, kW] only + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWoC, indWiC, indWkD); + + dnnl::memory::dims strides = {sD, sH, sW}; + dnnl::memory::dims padding = {pD, pH, pW}; + dnnl::memory::dims padding_r = {(iD - 1) * sD - oD + kD - pD, + (iH - 1) * sH - oH + kH - pH, + (iW - 1) * sW - oW + kW - pW}; + dnnl::memory::dims dilation = {dD - 1, dH - 1, dW - 1}; + + uint i0, i1, i2, i3, i4; + if (0 == wFormat) { + i0 = 3; + i1 = 4; + i2 = 0; + i3 = 1; + i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] + } else if (1 == wFormat) { + i0 = 1; + i1 = 0; + i2 = 2; + i3 = 3; + i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW] + } else { + i0 = 4; + i1 = 0; + i2 = 1; + i3 = 2; + i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW] + } + + // input type + dnnl::memory::data_type xType; + if (input->dataType() == DataType::FLOAT32) + xType = dnnl::memory::data_type::f32; + else if (input->dataType() == DataType::HALF) + xType = dnnl::memory::data_type::f16; + else if (input->dataType() == DataType::UINT8) + xType = dnnl::memory::data_type::u8; + else + xType = dnnl::memory::data_type::s8; + + // weights type + dnnl::memory::data_type wType = xType; + if (xType == dnnl::memory::data_type::u8) wType = dnnl::memory::data_type::s8; + + // output and bias type (have the same types) + dnnl::memory::data_type zType; + if (output->dataType() == DataType::FLOAT32) + zType = dnnl::memory::data_type::f32; + else if (output->dataType() == DataType::HALF) + zType = dnnl::memory::data_type::f16; + else if (output->dataType() == DataType::UINT8) + zType = dnnl::memory::data_type::u8; + else if (output->dataType() == DataType::INT8) + zType = dnnl::memory::data_type::s8; + else + zType = dnnl::memory::data_type::s32; + + dnnl::memory::format_tag xFormatMkl = isNCDHW + ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; + + dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); + w_user_md.data.format_kind = dnnl_blocked; // overrides format + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); + + // bias + dnnl::memory::desc b_mkl_md; + if (bias != nullptr) + b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x); + + // output + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl); + mkldnnUtils::setBlockStrides(output, z_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::deconvolution_forward::desc op_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, + x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, + padding_r); + dnnl::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, + op_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // bias + if (bias != nullptr) { + auto b_mkl_mem = + dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); + args[DNNL_ARG_BIAS] = b_mkl_mem; + } + + // output + auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = + zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; + + // run calculations + dnnl::deconvolution_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } ////////////////////////////////////////////////////////////////////////// -static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW, const int wFormat) { - - // mkl supports weights/gradW in [oC, iC, kD, kH, kW] format only - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - - dnnl::memory::dims strides = { sD, sH, sW }; - dnnl::memory::dims padding = { pD, pH, pW }; - dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; - dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 }; - - uint i0, i1, i2, i3, i4; - if(0 == wFormat) { - i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] - } - else if(1 == wFormat) { - i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW] - } - else { - i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW] - } - - // input type - dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // weights type - dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradO type - dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradI type - dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradW type - dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradB type - dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; - - dnnl::memory::format_tag xFormatMkl = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; - - dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; - dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); - - // gradO - dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl); - mkldnnUtils::setBlockStrides(gradO, gradO_user_md); - - // gradI - dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl); - mkldnnUtils::setBlockStrides(gradI, gradI_user_md); - - // gradW - dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl); - gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); - gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); - gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4); - - // gradB - dnnl::memory::desc gradB_mkl_md; - if(gradB != nullptr) - gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); - - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // forward primitive description - dnnl::deconvolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backward data primitive description - dnnl::deconvolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); - - // backward weights primitive description - dnnl::deconvolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // gradO - auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); - const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - if (gradOReorderW) - dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW); - if (gradOReorderD) - dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD); - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; - - // gradI - auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); - const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; - - // gradW - auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer()); - const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); - auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; - args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; - - // gradB - if(gradB != nullptr) { - auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); - args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; - } - - // run backward data calculations - dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); - - if(gradOReorderW || gradOReorderD) - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; - - // run backward weights calculations - dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); - - // reorder gradI if necessary - if (gradIReorder) - dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); - if (gradWReorder) - dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); +static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, + const NDArray* gradO, NDArray* gradI, + NDArray* gradW, NDArray* gradB, const int kD, + const int kH, const int kW, const int sD, + const int sH, const int sW, const int pD, + const int pH, const int pW, const int dD, + const int dH, const int dW, + const bool isNCDHW, const int wFormat) { + // mkl supports weights/gradW in [oC, iC, kD, kH, kW] format only + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWoC, indWiC, indWkD); + + dnnl::memory::dims strides = {sD, sH, sW}; + dnnl::memory::dims padding = {pD, pH, pW}; + dnnl::memory::dims padding_r = {(iD - 1) * sD - oD + kD - pD, + (iH - 1) * sH - oH + kH - pH, + (iW - 1) * sW - oW + kW - pW}; + dnnl::memory::dims dilation = {dD - 1, dH - 1, dW - 1}; + + uint i0, i1, i2, i3, i4; + if (0 == wFormat) { + i0 = 3; + i1 = 4; + i2 = 0; + i3 = 1; + i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] + } else if (1 == wFormat) { + i0 = 1; + i1 = 0; + i2 = 2; + i3 = 3; + i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW] + } else { + i0 = 4; + i1 = 0; + i2 = 1; + i3 = 2; + i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW] + } + + // input type + dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // weights type + dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradO type + dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradI type + dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradW type + dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradB type + dnnl::memory::data_type gradBType = + gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16) + : dnnl::memory::data_type::f32; + + dnnl::memory::format_tag xFormatMkl = isNCDHW + ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::oidhw; + + dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); + w_user_md.data.format_kind = dnnl_blocked; // overrides format + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); + + // gradO + dnnl::memory::desc gradO_mkl_md = + dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = + dnnl::memory::desc(zDims, gradOType, xFormatMkl); + mkldnnUtils::setBlockStrides(gradO, gradO_user_md); + + // gradI + dnnl::memory::desc gradI_mkl_md = + dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = + dnnl::memory::desc(xDims, gradIType, xFormatMkl); + mkldnnUtils::setBlockStrides(gradI, gradI_user_md); + + // gradW + dnnl::memory::desc gradW_mkl_md = + dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradW_user_md = + dnnl::memory::desc(wDims, gradWType, wFormatMkl); + gradW_user_md.data.format_kind = dnnl_blocked; // overrides format + gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); + gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); + gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4); + + // gradB + dnnl::memory::desc gradB_mkl_md; + if (gradB != nullptr) + gradB_mkl_md = + dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // forward primitive description + dnnl::deconvolution_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, + x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, + padding, padding_r); + dnnl::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, + engine); + + // backward data primitive description + dnnl::deconvolution_backward_data::desc op_data_bp_desc( + dnnl::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, + gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc( + op_data_bp_desc, engine, op_ff_prim_desc); + + // backward weights primitive description + dnnl::deconvolution_backward_weights::desc op_weights_bp_desc( + dnnl::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, + gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc( + op_weights_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, + op_weights_bp_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, + op_data_bp_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // gradO + auto gradO_user_mem = + dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); + const bool gradOReorderW = + op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + const bool gradOReorderD = + op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_memW = + gradOReorderW + ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + auto gradO_mkl_memD = + gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + if (gradOReorderW) + dnnl::reorder(gradO_user_mem, gradO_mkl_memW) + .execute(stream, gradO_user_mem, gradO_mkl_memW); + if (gradOReorderD) + dnnl::reorder(gradO_user_mem, gradO_mkl_memD) + .execute(stream, gradO_user_mem, gradO_mkl_memD); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; + + // gradI + auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); + const bool gradIReorder = + op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); + auto gradI_mkl_mem = + gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) + : gradI_user_mem; + args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + + // gradW + auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer()); + const bool gradWReorder = + op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); + auto gradW_mkl_mem = + gradWReorder + ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) + : gradW_user_mem; + args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; + + // gradB + if (gradB != nullptr) { + auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); + args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; + } + + // run backward data calculations + dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + + if (gradOReorderW || gradOReorderD) args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; + + // run backward weights calculations + dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc) + .execute(stream, args); + + // reorder gradI if necessary + if (gradIReorder) + dnnl::reorder(gradI_mkl_mem, gradI_user_mem) + .execute(stream, gradI_mkl_mem, gradI_user_mem); + if (gradWReorder) + dnnl::reorder(gradW_mkl_mem, gradW_user_mem) + .execute(stream, gradW_mkl_mem, gradW_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(deconv3d, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2)); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(isSameMode){ // SAME - //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - } - - deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM DECONV3D_MKLDNN OP: rank of input array must be equal " + "to 5, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM DECONV3D_MKLDNN OP: rank of weights array must be equal " + "to 5, but got %i instead !", + weights->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], + // 2 - [iC, kD, kH, kW, oC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWoC, indWiC, indWkD); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV3D_MKLDNN OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV3D_MKLDNN OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (isSameMode) { // SAME + // Note: we're intentionally swapping iH and oH, to calculated the padding + // for a"normal" conv (not deconv) forward pass + ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + } + + deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, + pW, dD, dH, dW, isNCDHW, wFormat); + + return Status::OK(); } PLATFORM_CHECK(deconv3d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; - - auto output = INPUT_VARIABLE(0); - - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - - const DataType xType = input->dataType(); - const DataType wType = weights->dataType(); - const DataType zType = output->dataType(); - const DataType bType = bias != nullptr ? bias->dataType() : zType; - - return block.isUseMKLDNN() && (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) && - ( - (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) || - ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType) - ); + auto input = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; + + auto output = INPUT_VARIABLE(0); + + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + + const DataType xType = input->dataType(); + const DataType wType = weights->dataType(); + const DataType zType = output->dataType(); + const DataType bType = bias != nullptr ? bias->dataType() : zType; + + return block.isUseMKLDNN() && + (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) && + ((xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && + bType == DataType::FLOAT32 && zType == DataType::FLOAT32) || + ((xType == DataType::UINT8 || xType == DataType::INT8) && + wType == DataType::INT8 && + (zType == DataType::UINT8 || zType == DataType::INT8 || + zType == DataType::INT32 || zType == DataType::FLOAT32) && + bType == zType)); } - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of weights array must be equal to 5 , but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf()); - - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); - - int trueoD, trueoH, trueoW; // true output height, width - ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass - ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), gradI + auto gradW = OUTPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "CUSTOM DECONV3D_MKLDNN_BP OP: rank of input array must be " + "equal to 5, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, + "CUSTOM DECONV3D_MKLDNN_BP OP: rank of weights array must be " + "equal to 5 , but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 5, 0, + "CUSTOM DECONV3D_MKLDNN_BP OP: rank of output gradients (next " + "epsilon) array must be equal to 5, but got %i instead !", + gradO->rankOf()); + + int kD = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) depth + int kH = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) height + int kW = INT_ARG(2) > 0 + ? INT_ARG(2) + : static_cast(weights->sizeAt(2)); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = + block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + int wFormat = block.numI() > 14 + ? INT_ARG(14) + : 0; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], + // 2 - [iC, kD, kH, kW, oC] + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWoC, indWiC, indWkD); + + int trueoD, trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, + sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, + iW, isSameMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, + 0, indIOioC, indIOioD, + indIOioD + 1, indIOioD + 2}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, oC, iC); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of output gradients " + "(next epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE( + bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + if (isSameMode) // Note: we're intentionally swapping iH and oH, to + // calculated the padding for a"normal" conv (not deconv) + // forward pass + ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, + sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, wFormat); + + return Status::OK(); } - PLATFORM_CHECK(deconv3d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iD, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI - auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - - const DataType xType = input->dataType(); - const DataType wType = weights->dataType(); - const DataType gradOType = gradO->dataType(); - - const DataType gradIType = gradI->dataType(); - const DataType gradWType = gradW->dataType(); - const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32; - - return block.isUseMKLDNN() && (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) ); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iD, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = + block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NHWC) or [bS, oC, oD, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI + auto gradW = OUTPUT_VARIABLE( + 1); // [kD, kH, kW, oC, iC], [iC, oC, kD, kH, kW], [iC, kD, kH, kW, oC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID + + const DataType xType = input->dataType(); + const DataType wType = weights->dataType(); + const DataType gradOType = gradO->dataType(); + + const DataType gradIType = gradI->dataType(); + const DataType gradWType = gradW->dataType(); + const DataType gradBType = + gradB != nullptr ? gradB->dataType() : DataType::FLOAT32; + + return block.isUseMKLDNN() && + (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) && + ((xType == DataType::FLOAT32 || xType == DataType::BFLOAT16) && + (wType == DataType::FLOAT32 || wType == DataType::BFLOAT16) && + (gradOType == DataType::FLOAT32 || gradOType == DataType::BFLOAT16) && + (gradIType == DataType::FLOAT32 || gradIType == DataType::BFLOAT16) && + (gradWType == DataType::FLOAT32 || gradWType == DataType::BFLOAT16) && + (gradBType == DataType::FLOAT32 || gradBType == DataType::BFLOAT16)); } -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp index d66e6fbc8543..c969c227150d 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp @@ -19,475 +19,671 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include -#include -#include #include +#include +#include #include +#include + #include "mkldnnUtils.h" using namespace dnnl; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, - const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const bool isNCHW, const int wFormat) { - - // mkl supports only following case: mC = 1, oC = iC - - // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given - // weights {iC, mC, 1, kH, kW} - // bias [oC], may be nullptr - // output [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc - // oC = iC*mC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d - - dnnl::memory::dims strides = { sH, sW }; - dnnl::memory::dims padding = { pH, pW }; - dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - dnnl::memory::dims dilation = { dH-1, dW-1}; - - uint i0, i1, i2, i3; - if(0 == wFormat) { - i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW] - } - else if(1 == wFormat) { - i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW] - } - else { - i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW] - } - - // input type - dnnl::memory::data_type xType; - if(input->dataType() == DataType::FLOAT32) - xType = dnnl::memory::data_type::f32; - else if(input->dataType() == DataType::HALF) - xType = dnnl::memory::data_type::f16; - else if(input->dataType() == DataType::UINT8) - xType = dnnl::memory::data_type::u8; - else - xType = dnnl::memory::data_type::s8; - - // weights type - dnnl::memory::data_type wType = xType; - if(xType == dnnl::memory::data_type::u8) - wType = dnnl::memory::data_type::s8; - - // output and bias type (have the same types) - dnnl::memory::data_type zType; - if(output->dataType() == DataType::FLOAT32) - zType = dnnl::memory::data_type::f32; - else if(output->dataType() == DataType::HALF) - zType = dnnl::memory::data_type::f16; - else if(output->dataType() == DataType::UINT8) - zType = dnnl::memory::data_type::u8; - else if(output->dataType() == DataType::INT8) - zType = dnnl::memory::data_type::s8; - else - zType = dnnl::memory::data_type::s32; - - dnnl::memory::format_tag xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::goihw; - - dnnl::memory::dims xDims = {bS, iC, iH, iW}; - dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oH, oW}; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); // permute - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = 0; - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3); - - // bias - dnnl::memory::desc b_mkl_md; - if(bias != nullptr) - b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x); - - // output - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFormatMkl); - mkldnnUtils::setBlockStrides(output, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // operation primitive description - dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, - x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // bias - if(bias != nullptr) { - auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); - args[DNNL_ARG_BIAS] = b_mkl_mem; - } - - // output - auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; - - // run calculations - dnnl::convolution_forward(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); - - stream.wait(); - // shape::printArray(z_mkl_mem.map_data(),8); +static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, + const NDArray* bias, NDArray* output, + const int kH, const int kW, const int sH, + const int sW, const int pH, const int pW, + const int dH, const int dW, + const int paddingMode, const bool isNCHW, + const int wFormat) { + // mkl supports only following case: mC = 1, oC = iC + + // input [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't + // support nhwc format we'll permute when nhwc is given weights {iC, mC, 1, + // kH, kW} bias [oC], may be nullptr output [bS, oC, oH, oW] nchw or [bS, oH, + // oW, oC] nhwc oC = iC*mC + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + const int pWSame = (paddingMode == 2 && dW > 1) + ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 + : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = {sH, sW}; + dnnl::memory::dims padding = {pH, pW}; + dnnl::memory::dims padding_r = {(oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pWSame}; + dnnl::memory::dims dilation = {dH - 1, dW - 1}; + + uint i0, i1, i2, i3; + if (0 == wFormat) { + i0 = 2; + i1 = 3; + i2 = 0; + i3 = 1; // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW] + } else if (1 == wFormat) { + i0 = 1; + i1 = 0; + i2 = 2; + i3 = 3; // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW] + } else { + i0 = 3; + i1 = 0; + i2 = 1; + i3 = 2; // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW] + } + + // input type + dnnl::memory::data_type xType; + if (input->dataType() == DataType::FLOAT32) + xType = dnnl::memory::data_type::f32; + else if (input->dataType() == DataType::HALF) + xType = dnnl::memory::data_type::f16; + else if (input->dataType() == DataType::UINT8) + xType = dnnl::memory::data_type::u8; + else + xType = dnnl::memory::data_type::s8; + + // weights type + dnnl::memory::data_type wType = xType; + if (xType == dnnl::memory::data_type::u8) wType = dnnl::memory::data_type::s8; + + // output and bias type (have the same types) + dnnl::memory::data_type zType; + if (output->dataType() == DataType::FLOAT32) + zType = dnnl::memory::data_type::f32; + else if (output->dataType() == DataType::HALF) + zType = dnnl::memory::data_type::f16; + else if (output->dataType() == DataType::UINT8) + zType = dnnl::memory::data_type::u8; + else if (output->dataType() == DataType::INT8) + zType = dnnl::memory::data_type::s8; + else + zType = dnnl::memory::data_type::s32; + + dnnl::memory::format_tag xzFormatMkl = + isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::goihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); + w_user_md.data.format_kind = dnnl_blocked; // overrides format + w_user_md.data.format_desc.blocking.strides[0] = + weights->strideAt(i0); // permute + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = 0; + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3); + + // bias + dnnl::memory::desc b_mkl_md; + if (bias != nullptr) + b_mkl_md = dnnl::memory::desc({oC}, zType, dnnl::memory::format_tag::x); + + // output + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFormatMkl); + mkldnnUtils::setBlockStrides(output, z_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::convolution_forward::desc op_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, + x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, + padding_r); + dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, + op_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // bias + if (bias != nullptr) { + auto b_mkl_mem = + dnnl::memory(b_mkl_md, engine, const_cast(bias->buffer())); + args[DNNL_ARG_BIAS] = b_mkl_mem; + } + + // output + auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = + zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; + + // run calculations + dnnl::convolution_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); + // shape::printArray(z_mkl_mem.map_data(),8); } ////////////////////////////////////////////////////////////////////////// -static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, - const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode, const bool isNCHW, const int wFormat) { - - // mkl supports only following case: mC = 1, oC = iC - - // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl doesn't support nhwc format we'll permute when nhwc is given - // weights/gradW {iC, mC, 1, kH, kW} - // gradB [oC], may be nullptr - // gradO [bS, oC, oH, oW] nchw or [bS, oH, oW, oC] nhwc - // oC = iC*mC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); - - const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d - - dnnl::memory::dims strides = { sH, sW }; - dnnl::memory::dims padding = { pH, pW }; - dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - dnnl::memory::dims dilation = { dH-1, dW-1}; - - uint i0, i1, i2, i3; - if(0 == wFormat) { - i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW] - } - else if(1 == wFormat) { - i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW] - } - else { - i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW] - } - - // input type - dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // weights type - dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradO type - dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradI type - dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradW type - dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - // gradB type - dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; - - dnnl::memory::format_tag xzFormatMkl = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::goihw; - - dnnl::memory::dims xDims = {bS, iC, iH, iW}; - dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; - dnnl::memory::dims zDims = {bS, oC, oH, oW}; - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); - - // weights - dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); // permute - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = 0; - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3); - - // gradO - dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFormatMkl); - mkldnnUtils::setBlockStrides(gradO, gradO_user_md); - - // gradI - dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFormatMkl); - mkldnnUtils::setBlockStrides(gradI, gradI_user_md); - - // gradW - dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl); - gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); // permute - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); - gradW_user_md.data.format_desc.blocking.strides[2] = 0; - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i2); - gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i3); - - // gradB - dnnl::memory::desc gradB_mkl_md; - if(gradB != nullptr) - gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // forward primitive description - dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backward data primitive description - dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); - - // backward weights primitive description - dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); - dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // gradO - auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); - const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - if (gradOReorderW) - dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW); - if (gradOReorderD) - dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD); - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; - - // gradI - auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); - const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; - - // gradW - auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer()); - const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); - auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; - args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; - - // gradB - if(gradB != nullptr) { - auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); - args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; - } - - // run backward data calculations - dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); - - if(gradOReorderW || gradOReorderD) - args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; - - // run backward weights calculations - dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); - - // reorder gradI if necessary - if (gradIReorder) - dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); - if (gradWReorder) - dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); - - stream.wait(); - - // shape::printArray(z_mkl_mem.map_data(),8); +static void depthwiseConv2dNackPropMKLDNN( + const NDArray* input, const NDArray* weights, const NDArray* gradO, + NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, + const int sH, const int sW, const int pH, const int pW, const int dH, + const int dW, const int paddingMode, const bool isNCHW, const int wFormat) { + // mkl supports only following case: mC = 1, oC = iC + + // input, gradI [bS, iC, iH, iW] nchw or [bS, iH, iW, iC] nhwc, since mkl + // doesn't support nhwc format we'll permute when nhwc is given weights/gradW + // {iC, mC, 1, kH, kW} gradB [oC], may be nullptr gradO [bS, oC, oH, oW] nchw + // or [bS, oH, oW, oC] nhwc oC = iC*mC + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); + + const int pWSame = (paddingMode == 2 && dW > 1) + ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 + : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = {sH, sW}; + dnnl::memory::dims padding = {pH, pW}; + dnnl::memory::dims padding_r = {(oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pWSame}; + dnnl::memory::dims dilation = {dH - 1, dW - 1}; + + uint i0, i1, i2, i3; + if (0 == wFormat) { + i0 = 2; + i1 = 3; + i2 = 0; + i3 = 1; // [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW] + } else if (1 == wFormat) { + i0 = 1; + i1 = 0; + i2 = 2; + i3 = 3; // [mC, iC, kH, kW] -> [iC, mC, 1, kH, kW] + } else { + i0 = 3; + i1 = 0; + i2 = 1; + i3 = 2; // [mC, kH, kW, iC] -> [iC, mC, 1, kH, kW] + } + + // input type + dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // weights type + dnnl::memory::data_type wType = weights->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradO type + dnnl::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradI type + dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradW type + dnnl::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16; + // gradB type + dnnl::memory::data_type gradBType = + gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 + ? dnnl::memory::data_type::f32 + : dnnl::memory::data_type::bf16) + : dnnl::memory::data_type::f32; + + dnnl::memory::format_tag xzFormatMkl = + isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormatMkl = dnnl::memory::format_tag::goihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {iC, mC, 1, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl); + mkldnnUtils::setBlockStrides(input, x_user_md); + + // weights + dnnl::memory::desc w_mkl_md = + dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); + w_user_md.data.format_kind = dnnl_blocked; // overrides format + w_user_md.data.format_desc.blocking.strides[0] = + weights->strideAt(i0); // permute + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); + w_user_md.data.format_desc.blocking.strides[2] = 0; + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i2); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i3); + + // gradO + dnnl::memory::desc gradO_mkl_md = + dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = + dnnl::memory::desc(zDims, gradOType, xzFormatMkl); + mkldnnUtils::setBlockStrides(gradO, gradO_user_md); + + // gradI + dnnl::memory::desc gradI_mkl_md = + dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = + dnnl::memory::desc(xDims, gradIType, xzFormatMkl); + mkldnnUtils::setBlockStrides(gradI, gradI_user_md); + + // gradW + dnnl::memory::desc gradW_mkl_md = + dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); + dnnl::memory::desc gradW_user_md = + dnnl::memory::desc(wDims, gradWType, wFormatMkl); + gradW_user_md.data.format_kind = dnnl_blocked; // overrides format + gradW_user_md.data.format_desc.blocking.strides[0] = + gradW->strideAt(i0); // permute + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); + gradW_user_md.data.format_desc.blocking.strides[2] = 0; + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i2); + gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i3); + + // gradB + dnnl::memory::desc gradB_mkl_md; + if (gradB != nullptr) + gradB_mkl_md = + dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // forward primitive description + dnnl::convolution_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, + x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, + padding, padding_r); + dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward data primitive description + dnnl::convolution_backward_data::desc op_data_bp_desc( + dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, + strides, dilation, padding, padding_r); + dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc( + op_data_bp_desc, engine, op_ff_prim_desc); + + // backward weights primitive description + dnnl::convolution_backward_weights::desc op_weights_bp_desc( + dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, + gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc( + op_weights_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, + op_weights_bp_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, + op_data_bp_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // gradO + auto gradO_user_mem = + dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); + const bool gradOReorderW = + op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + const bool gradOReorderD = + op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_memW = + gradOReorderW + ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + auto gradO_mkl_memD = + gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) + : gradO_user_mem; + if (gradOReorderW) + dnnl::reorder(gradO_user_mem, gradO_mkl_memW) + .execute(stream, gradO_user_mem, gradO_mkl_memW); + if (gradOReorderD) + dnnl::reorder(gradO_user_mem, gradO_mkl_memD) + .execute(stream, gradO_user_mem, gradO_mkl_memD); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; + + // gradI + auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); + const bool gradIReorder = + op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); + auto gradI_mkl_mem = + gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) + : gradI_user_mem; + args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + + // gradW + auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer()); + const bool gradWReorder = + op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); + auto gradW_mkl_mem = + gradWReorder + ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) + : gradW_user_mem; + args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; + + // gradB + if (gradB != nullptr) { + auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->buffer()); + args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; + } + + // run backward data calculations + dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + + if (gradOReorderW || gradOReorderD) args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; + + // run backward weights calculations + dnnl::convolution_backward_weights(op_weights_bp_prim_desc) + .execute(stream, args); + + // reorder gradI if necessary + if (gradIReorder) + dnnl::reorder(gradI_mkl_mem, gradI_user_mem) + .execute(stream, gradI_mkl_mem, gradI_user_mem); + if (gradWReorder) + dnnl::reorder(gradW_mkl_mem, gradW_user_mem) + .execute(stream, gradW_mkl_mem, gradW_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); } - ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC - - auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - REQUIRE_TRUE(output->sizeAt(indIOioC) == iC*mC, 0, "CUSTOM DEPTHWISECONV2D MKL OP: the output_channels must be equal to input_channels * channels_multiplier = %i !", iC*mC); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - depthwiseConv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC + + auto output = OUTPUT_VARIABLE( + 0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of weights array, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + REQUIRE_TRUE(output->sizeAt(indIOioC) == iC * mC, 0, + "CUSTOM DEPTHWISECONV2D MKL OP: the output_channels must be " + "equal to input_channels * channels_multiplier = %i !", + iC * mC); + if (bias) + REQUIRE_TRUE( + bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DEPTHWISECONV2D MKL OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + depthwiseConv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, + dH, dW, paddingMode, isNCHW, wFormat); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); - auto weights = INPUT_VARIABLE(1); - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; - - auto output = INPUT_VARIABLE(0); - - const DataType xType = input->dataType(); - const DataType wType = weights->dataType(); - const DataType zType = output->dataType(); - const DataType bType = bias != nullptr ? bias->dataType() : zType; - - const int mC = weights->sizeAt(3); - - return block.isUseMKLDNN() && mC == 1 && - ( - (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) || - (xType==DataType::BFLOAT16 && wType==DataType::BFLOAT16 && bType==DataType::BFLOAT16 && zType==DataType::BFLOAT16) || - ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType) - ); + auto input = INPUT_VARIABLE(0); + auto weights = INPUT_VARIABLE(1); + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; + + auto output = INPUT_VARIABLE(0); + + const DataType xType = input->dataType(); + const DataType wType = weights->dataType(); + const DataType zType = output->dataType(); + const DataType bType = bias != nullptr ? bias->dataType() : zType; + + const int mC = weights->sizeAt(3); + + return block.isUseMKLDNN() && mC == 1 && + ((xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && + bType == DataType::FLOAT32 && zType == DataType::FLOAT32) || + (xType == DataType::BFLOAT16 && wType == DataType::BFLOAT16 && + bType == DataType::BFLOAT16 && zType == DataType::BFLOAT16) || + ((xType == DataType::UINT8 || xType == DataType::INT8) && + wType == DataType::INT8 && + (zType == DataType::UINT8 || zType == DataType::INT8 || + zType == DataType::INT32 || zType == DataType::FLOAT32) && + bType == zType)); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf()); - - int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) height - int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - [mC, kH, kW, iC] - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - int trueoH, trueoW; // correct output height, width - ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); - std::vector expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if(bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - - depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto bias = + block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_NULLIFIED( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_NULLIFIED( + 1); // [kH, kW, iC, mC], [mC, iC, kH, kW], [mC, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be " + "equal to 4, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of weights array must " + "be equal to 4, but got %i instead !", + weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0, + "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of output gradients " + "(next epsilon) array must be equal to 4, but got %i instead !", + gradO->rankOf()); + + int kH = INT_ARG(0) > 0 + ? INT_ARG(0) + : static_cast(weights->sizeAt(0)); // filter(kernel) height + int kW = INT_ARG(1) > 0 + ? INT_ARG(1) + : static_cast(weights->sizeAt(1)); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = + block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW + int wFormat = + block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, + // iC, kH, kW], 2 - [mC, kH, kW, iC] + + int bS, iC, iH, iW, mC, oC, oH, + oW; // batch size, input channels, input height/width, channels + // multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, wFormat, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + int trueoH, trueoW; // correct output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, + dH, dW, iH, iW, paddingMode); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW, paddingMode); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, oC, trueoH, trueoW, 0, indIOioC, indOoH, indOoH + 1}); + std::vector expectedWeightsShape = + ConvolutionUtils::expectWeightsShape(wFormat, kH, kW, iC, mC); + REQUIRE_TRUE( + gradO->isSameShape(expectedGradOShape), 0, + "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of output gradients (next " + "epsilon) array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, + "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of weights " + "array, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), + ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE( + bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, + "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, " + "expected rank, length: <=2, %i, but got %i, %i instead !", + oC, bias->rankOf(), bias->lengthOf()); + + depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, + kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, + wFormat); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - const DataType xType = input->dataType(); - const DataType wType = weights->dataType(); - const DataType gradOType = gradO->dataType(); - - const DataType gradIType = gradI->dataType(); - const DataType gradWType = gradW->dataType(); - const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32; - - const int mC = weights->sizeAt(3); - - return block.isUseMKLDNN() && mC == 1 && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) ); + auto input = INPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto bias = + block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] + auto gradO = block.width() > 3 + ? INPUT_VARIABLE(3) + : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, + // oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE( + 1); // [kH, kW, iC, oC], [oC, iC, kH, kW], [oC, kH, kW, iC] + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + const DataType xType = input->dataType(); + const DataType wType = weights->dataType(); + const DataType gradOType = gradO->dataType(); + + const DataType gradIType = gradI->dataType(); + const DataType gradWType = gradW->dataType(); + const DataType gradBType = + gradB != nullptr ? gradB->dataType() : DataType::FLOAT32; + + const int mC = weights->sizeAt(3); + + return block.isUseMKLDNN() && mC == 1 && + ((xType == DataType::FLOAT32 || xType == DataType::BFLOAT16) && + (wType == DataType::FLOAT32 || wType == DataType::BFLOAT16) && + (gradOType == DataType::FLOAT32 || gradOType == DataType::BFLOAT16) && + (gradIType == DataType::FLOAT32 || gradIType == DataType::BFLOAT16) && + (gradWType == DataType::FLOAT32 || gradWType == DataType::BFLOAT16) && + (gradBType == DataType::FLOAT32 || gradBType == DataType::BFLOAT16)); } -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp index 583ab08528c0..c99c5c93dbba 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lrn.cpp @@ -19,75 +19,83 @@ // @author raver119@gmail.com // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include using namespace dnnl; namespace sd { - namespace ops { - namespace platforms { - PLATFORM_IMPL(lrn, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "lrn: Input rank of 4 expected, but got %i instead", - input->rankOf()); - - double alpha = T_ARG(1); - double beta = T_ARG(2); - double bias = T_ARG(0); - int depth = INT_ARG(0); - - dnnl_memory_desc_t empty; - dnnl::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty); - - mkldnnUtils::getMKLDNNMemoryDescLrn(input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md, - &user_src_md, nullptr, &user_dst_md, input->rankOf() - 1); - - auto lrn_desc = lrn_forward::desc(prop_kind::forward_inference, algorithm::lrn_across_channels, - lrn_src_md, (2 * depth + 1), alpha * (2 * depth + 1), beta, bias); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine); - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); - - auto lrn_src_memory = user_src_memory; - if (lrn_prim_desc.src_desc() != user_src_memory.get_desc()) { - lrn_src_memory = dnnl::memory(lrn_prim_desc.src_desc(), engine); - reorder(user_src_memory, lrn_src_memory).execute(stream, user_src_memory, lrn_src_memory); - } - - auto lrn_dst_memory = user_dst_memory; - if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - lrn_dst_memory = dnnl::memory(lrn_prim_desc.dst_desc(), engine); - } - - lrn_forward(lrn_prim_desc).execute(stream, {{DNNL_ARG_SRC, lrn_src_memory}, - {DNNL_ARG_DST, lrn_dst_memory}}); - - if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(lrn_dst_memory, user_dst_memory).execute(stream, lrn_dst_memory, user_dst_memory); - } - - stream.wait(); - - return Status::OK(); - }; - - PLATFORM_CHECK(lrn, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); - } - } - } -} \ No newline at end of file +namespace ops { +namespace platforms { +PLATFORM_IMPL(lrn, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "lrn: Input rank of 4 expected, but got %i instead", + input->rankOf()); + + double alpha = T_ARG(1); + double beta = T_ARG(2); + double bias = T_ARG(0); + int depth = INT_ARG(0); + + dnnl_memory_desc_t empty; + dnnl::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), + user_dst_md(empty); + + mkldnnUtils::getMKLDNNMemoryDescLrn( + input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md, &user_src_md, + nullptr, &user_dst_md, input->rankOf() - 1); + + auto lrn_desc = lrn_forward::desc( + prop_kind::forward_inference, algorithm::lrn_across_channels, lrn_src_md, + (2 * depth + 1), alpha * (2 * depth + 1), beta, bias); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + dnnl::stream stream(engine); + auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine); + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); + + auto lrn_src_memory = user_src_memory; + if (lrn_prim_desc.src_desc() != user_src_memory.get_desc()) { + lrn_src_memory = dnnl::memory(lrn_prim_desc.src_desc(), engine); + reorder(user_src_memory, lrn_src_memory) + .execute(stream, user_src_memory, lrn_src_memory); + } + + auto lrn_dst_memory = user_dst_memory; + if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + lrn_dst_memory = dnnl::memory(lrn_prim_desc.dst_desc(), engine); + } + + lrn_forward(lrn_prim_desc) + .execute(stream, {{DNNL_ARG_SRC, lrn_src_memory}, + {DNNL_ARG_DST, lrn_dst_memory}}); + + if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + reorder(lrn_dst_memory, user_dst_memory) + .execute(stream, lrn_dst_memory, user_dst_memory); + } + + stream.wait(); + + return Status::OK(); +}; + +PLATFORM_CHECK(lrn, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); +} +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index 60c61ea5f729..25d2fa1d0961 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -19,509 +19,695 @@ // #include + #include "mkldnnUtils.h" using namespace dnnl; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { -static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* Wr, - const NDArray* b, const NDArray* hI, const NDArray* cI, - const std::vector& params, - NDArray* h, NDArray* hL, NDArray* cL) { - - // equations (no peephole connections) - // it = σ(Wxi * xt + Wri * ht-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't - // ot = σ(Wxo * xt + Wro * ht-1 + bo) - // ht = ot ◦ tanh(ct) - - // notations: - // bS - batch size - // sL - sequence length, number of time steps - // nIn - input size - // nOut - output size (hidden size) - - // INPUTS: - - // ******* - // input x: - // 1) [sL, bS, nIn] when dataFormat == 0 - - // ******* - // input weights Wx: - // 1) [1, 1, nIn, 4*nOut] when directionMode < 2 - // 2) [1, 2, nIn, 4*nOut] when directionMode >= 2 - - // ******* - // recurrent weights Wr: - // 1) [1, 1, nOut, 4*nOut] when directionMode < 2 - // 2) [1, 2, nOut, 4*nOut] when directionMode >= 2 - - // ******* - // biases b: - // 1) [1, 1, 4*nOut] when directionMode < 2 - // 2) [1, 2, 4*nOut] when directionMode >= 2 - - // ******* - // initial output hI: - // 1) [1, 1, bS, nOut] when directionMode < 2 - // 2) [1, 2, bS, nOut] when directionMode >= 2 - - // ******* - // initial cell state cI (same shape as in hI): - // 1) [1, 1, bS, nOut] when directionMode < 2 - // 2) [1, 2, bS, nOut] when directionMode >= 2 - - - // OUTPUTS: - - // ******* - // output h: - // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 - // 2) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 - - // ******* - // output at last step hL: - // 1) [1, 1, bS, nOut] when directionMode < 2 - // 2) [1, 2, bS, nOut] when directionMode >= 2 - - // ******* - // cell state at last step cL (same shape as in hL): - // 1) [1, 1, bS, nOut] when directionMode < 2 - // 2) [1, 2, bS, nOut] when directionMode >= 2 - - // !!! dimension 4*nOut implies order it, ft, c't, ot - // !!! dimension 3*nOut implies order it, ft, ot - - // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; - - // dataFormat: 0 = [sL, bS, nIn] - // directionMode: 0 = forward, 1 = backward, 2 = bidirectional sum, 3 = bidirectional concat - - const int dataFormat = params[0]; - const int directionMode = params[1]; - - const int sL = x->sizeAt(0); // dataFormat == 0 ? x->sizeAt(0) : x->sizeAt(1); - const int bS = x->sizeAt(1); // dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0); - const int nIn = x->sizeAt(-1); - const int nOut = Wx->sizeAt(-1); - - const int dirDim = directionMode < 2 ? 1 : 2; // number of dimensionss, 1 unidirectional, 2 for bidirectional - const int hDirDim = directionMode <= 2 ? 1 : 2; // for h array, take into account bidirectional_sum mode (directionMode == 2) - - // evaluate direction - rnn_direction direction; - switch (directionMode) { - case 0: - direction = rnn_direction::unidirectional_left2right; - break; - case 1: - direction = rnn_direction::unidirectional_right2left; - break; - case 2: - direction = rnn_direction::bidirectional_sum; - break; - default: - direction = rnn_direction::bidirectional_concat; - } - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - dnnl::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md, - x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md; - - // input type - dnnl::memory::data_type xType; - if(x->dataType() == DataType::FLOAT32) - xType = dnnl::memory::data_type::f32; - else if(x->dataType() == DataType::HALF) - xType = dnnl::memory::data_type::f16; - else - xType = dnnl::memory::data_type::u8; - - // weights type - dnnl::memory::data_type wType = xType; - if(xType == dnnl::memory::data_type::u8) - wType = dnnl::memory::data_type::s8; - - // bias type - dnnl::memory::data_type bType = xType; - if(xType == dnnl::memory::data_type::u8) - bType = dnnl::memory::data_type::f32; - - // output type - dnnl::memory::data_type hType; - if(h->dataType() == DataType::FLOAT32) - hType = dnnl::memory::data_type::f32; - else if(h->dataType() == DataType::HALF) - hType = dnnl::memory::data_type::f16; - else - hType = dnnl::memory::data_type::u8; - - - // memory descriptors for arrays - // x - x_lstm_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::any); - // x_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, nIn}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, nIn}, type, dnnl::memory::format_tag::ntc); - x_user_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::tnc); - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; - x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; - x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2]; - - // wx - wx_lstm_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::any); - wx_user_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::ldigo); - wx_user_md.data.format_kind = dnnl_blocked; // overrides format - wx_user_md.data.format_desc.blocking.strides[0] = Wx->stridesOf()[0]; - wx_user_md.data.format_desc.blocking.strides[1] = Wx->stridesOf()[1]; - wx_user_md.data.format_desc.blocking.strides[2] = Wx->stridesOf()[2]; - wx_user_md.data.format_desc.blocking.strides[3] = Wx->stridesOf()[3]; - wx_user_md.data.format_desc.blocking.strides[4] = Wx->stridesOf()[4]; - - // wr - wr_lstm_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::any); - wr_user_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::ldigo); - wr_user_md.data.format_kind = dnnl_blocked; // overrides format - wr_user_md.data.format_desc.blocking.strides[0] = Wr->stridesOf()[0]; - wr_user_md.data.format_desc.blocking.strides[1] = Wr->stridesOf()[1]; - wr_user_md.data.format_desc.blocking.strides[2] = Wr->stridesOf()[2]; - wr_user_md.data.format_desc.blocking.strides[3] = Wr->stridesOf()[3]; - wr_user_md.data.format_desc.blocking.strides[4] = Wr->stridesOf()[4]; - - // h - h_lstm_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::any); - // h_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, hDirDim*nOut}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, hDirDim*nOut}, type, dnnl::memory::format_tag::ntc); - h_user_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::tnc); - h_user_md.data.format_kind = dnnl_blocked; // overrides format - h_user_md.data.format_desc.blocking.strides[0] = h->stridesOf()[0]; - h_user_md.data.format_desc.blocking.strides[1] = h->stridesOf()[1]; - h_user_md.data.format_desc.blocking.strides[2] = h->stridesOf()[2]; - - // b - if(b) { - b_lstm_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::any); - b_user_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::ldgo); - b_user_md.data.format_kind = dnnl_blocked; // overrides format - b_user_md.data.format_desc.blocking.strides[0] = b->stridesOf()[0]; - b_user_md.data.format_desc.blocking.strides[1] = b->stridesOf()[1]; - b_user_md.data.format_desc.blocking.strides[2] = b->stridesOf()[2]; - b_user_md.data.format_desc.blocking.strides[3] = b->stridesOf()[3]; - } - - // hI - if(hI) { - hI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any); - hI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc); - hI_user_md.data.format_kind = dnnl_blocked; // overrides format - hI_user_md.data.format_desc.blocking.strides[0] = hI->stridesOf()[0]; - hI_user_md.data.format_desc.blocking.strides[1] = hI->stridesOf()[1]; - hI_user_md.data.format_desc.blocking.strides[2] = hI->stridesOf()[2]; - hI_user_md.data.format_desc.blocking.strides[3] = hI->stridesOf()[3]; - } - - // cI - if(cI) { - cI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any); - cI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc); - cI_user_md.data.format_kind = dnnl_blocked; // overrides format - cI_user_md.data.format_desc.blocking.strides[0] = cI->stridesOf()[0]; - cI_user_md.data.format_desc.blocking.strides[1] = cI->stridesOf()[1]; - cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[2]; - cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[3]; - } - - // hL - if(hL) { - hL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::any); - hL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc); - hL_user_md.data.format_kind = dnnl_blocked; // overrides format - hL_user_md.data.format_desc.blocking.strides[0] = hL->stridesOf()[0]; - hL_user_md.data.format_desc.blocking.strides[1] = hL->stridesOf()[1]; - hL_user_md.data.format_desc.blocking.strides[2] = hL->stridesOf()[2]; - hL_user_md.data.format_desc.blocking.strides[3] = hL->stridesOf()[3]; - } - - if(cL) { - cL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc); - cL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc); - cL_user_md.data.format_kind = dnnl_blocked; // overrides format - cL_user_md.data.format_desc.blocking.strides[0] = cL->stridesOf()[0]; - cL_user_md.data.format_desc.blocking.strides[1] = cL->stridesOf()[1]; - cL_user_md.data.format_desc.blocking.strides[2] = cL->stridesOf()[2]; - cL_user_md.data.format_desc.blocking.strides[3] = cL->stridesOf()[3]; - } - - // lstm memory description - lstm_forward::desc lstm_desc(prop_kind::forward_inference, direction, - x_lstm_md, hI_lstm_md, cI_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, - h_lstm_md, hL_lstm_md, cL_lstm_md); - - dnnl::stream stream(engine); - - // lstm primitive description - lstm_forward::primitive_desc lstm_prim_desc(lstm_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - // provide memory and check whether reorder is required - // x - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]); - - // wx - mkldnnUtils::loadDataToMklStream(Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]); - - // wr - mkldnnUtils::loadDataToMklStream(Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]); - - // h - auto h_user_mem = dnnl::memory(h_user_md, engine, h->buffer()); - const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc(); - auto h_lstm_mem = hReorder ? dnnl::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem; - args[DNNL_ARG_DST_LAYER] = h_lstm_mem; - - // b - if(b) { - mkldnnUtils::loadDataToMklStream(b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]); - } - - // hI - if(hI) { - mkldnnUtils::loadDataToMklStream(hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]); - } - - // cI - if(cI) { - mkldnnUtils::loadDataToMklStream(cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]); - } - - bool hLReorder(false), cLReorder(false); - dnnl::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem; - - // hL - if(hL) { - hL_user_mem = dnnl::memory(hL_user_md, engine, hL->buffer()); - hLReorder = lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc(); - hL_lstm_mem = hLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem; - args[DNNL_ARG_DST_ITER] = hL_lstm_mem; - } - - // cL - if(cL) { - cL_user_mem = dnnl::memory(cL_user_md, engine, cL->buffer()); - cLReorder = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc(); - cL_lstm_mem = cLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem; - args[DNNL_ARG_DST_ITER_C] = cL_lstm_mem; - } - - // run calculations - lstm_forward(lstm_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (hReorder) - reorder(h_lstm_mem, h_user_mem).execute(stream, h_lstm_mem, h_user_mem); - if(hLReorder) - reorder(hL_lstm_mem, hL_user_mem).execute(stream, hL_lstm_mem, hL_user_mem); - if(cLReorder) - reorder(cL_lstm_mem, cL_user_mem).execute(stream, cL_lstm_mem, cL_user_mem); - - stream.wait(); +static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, + const NDArray* Wr, const NDArray* b, + const NDArray* hI, const NDArray* cI, + const std::vector& params, NDArray* h, + NDArray* hL, NDArray* cL) { + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // sL - sequence length, number of time steps + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + + // ******* + // input x: + // 1) [sL, bS, nIn] when dataFormat == 0 + + // ******* + // input weights Wx: + // 1) [1, 1, nIn, 4*nOut] when directionMode < 2 + // 2) [1, 2, nIn, 4*nOut] when directionMode >= 2 + + // ******* + // recurrent weights Wr: + // 1) [1, 1, nOut, 4*nOut] when directionMode < 2 + // 2) [1, 2, nOut, 4*nOut] when directionMode >= 2 + + // ******* + // biases b: + // 1) [1, 1, 4*nOut] when directionMode < 2 + // 2) [1, 2, 4*nOut] when directionMode >= 2 + + // ******* + // initial output hI: + // 1) [1, 1, bS, nOut] when directionMode < 2 + // 2) [1, 2, bS, nOut] when directionMode >= 2 + + // ******* + // initial cell state cI (same shape as in hI): + // 1) [1, 1, bS, nOut] when directionMode < 2 + // 2) [1, 2, bS, nOut] when directionMode >= 2 + + // OUTPUTS: + + // ******* + // output h: + // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 + // 2) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 + + // ******* + // output at last step hL: + // 1) [1, 1, bS, nOut] when directionMode < 2 + // 2) [1, 2, bS, nOut] when directionMode >= 2 + + // ******* + // cell state at last step cL (same shape as in hL): + // 1) [1, 1, bS, nOut] when directionMode < 2 + // 2) [1, 2, bS, nOut] when directionMode >= 2 + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, + // gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + + // dataFormat: 0 = [sL, bS, nIn] + // directionMode: 0 = forward, 1 = backward, 2 = bidirectional sum, 3 = + // bidirectional concat + + const int dataFormat = params[0]; + const int directionMode = params[1]; + + const int sL = + x->sizeAt(0); // dataFormat == 0 ? x->sizeAt(0) : x->sizeAt(1); + const int bS = + x->sizeAt(1); // dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0); + const int nIn = x->sizeAt(-1); + const int nOut = Wx->sizeAt(-1); + + const int dirDim = + directionMode < 2 + ? 1 + : 2; // number of dimensionss, 1 unidirectional, 2 for bidirectional + const int hDirDim = directionMode <= 2 + ? 1 + : 2; // for h array, take into account + // bidirectional_sum mode (directionMode == 2) + + // evaluate direction + rnn_direction direction; + switch (directionMode) { + case 0: + direction = rnn_direction::unidirectional_left2right; + break; + case 1: + direction = rnn_direction::unidirectional_right2left; + break; + case 2: + direction = rnn_direction::bidirectional_sum; + break; + default: + direction = rnn_direction::bidirectional_concat; + } + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + dnnl::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, + cI_user_md, h_user_md, hL_user_md, cL_user_md, x_lstm_md, wx_lstm_md, + wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, + cL_lstm_md; + + // input type + dnnl::memory::data_type xType; + if (x->dataType() == DataType::FLOAT32) + xType = dnnl::memory::data_type::f32; + else if (x->dataType() == DataType::HALF) + xType = dnnl::memory::data_type::f16; + else + xType = dnnl::memory::data_type::u8; + + // weights type + dnnl::memory::data_type wType = xType; + if (xType == dnnl::memory::data_type::u8) wType = dnnl::memory::data_type::s8; + + // bias type + dnnl::memory::data_type bType = xType; + if (xType == dnnl::memory::data_type::u8) + bType = dnnl::memory::data_type::f32; + + // output type + dnnl::memory::data_type hType; + if (h->dataType() == DataType::FLOAT32) + hType = dnnl::memory::data_type::f32; + else if (h->dataType() == DataType::HALF) + hType = dnnl::memory::data_type::f16; + else + hType = dnnl::memory::data_type::u8; + + // memory descriptors for arrays + // x + x_lstm_md = + dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::any); + // x_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, nIn}, type, + // dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, nIn}, type, + // dnnl::memory::format_tag::ntc); + x_user_md = + dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::tnc); + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; + x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; + x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2]; + + // wx + wx_lstm_md = dnnl::memory::desc({1, dirDim, nIn, 4, nOut}, wType, + dnnl::memory::format_tag::any); + wx_user_md = dnnl::memory::desc({1, dirDim, nIn, 4, nOut}, wType, + dnnl::memory::format_tag::ldigo); + wx_user_md.data.format_kind = dnnl_blocked; // overrides format + wx_user_md.data.format_desc.blocking.strides[0] = Wx->stridesOf()[0]; + wx_user_md.data.format_desc.blocking.strides[1] = Wx->stridesOf()[1]; + wx_user_md.data.format_desc.blocking.strides[2] = Wx->stridesOf()[2]; + wx_user_md.data.format_desc.blocking.strides[3] = Wx->stridesOf()[3]; + wx_user_md.data.format_desc.blocking.strides[4] = Wx->stridesOf()[4]; + + // wr + wr_lstm_md = dnnl::memory::desc({1, dirDim, nOut, 4, nOut}, wType, + dnnl::memory::format_tag::any); + wr_user_md = dnnl::memory::desc({1, dirDim, nOut, 4, nOut}, wType, + dnnl::memory::format_tag::ldigo); + wr_user_md.data.format_kind = dnnl_blocked; // overrides format + wr_user_md.data.format_desc.blocking.strides[0] = Wr->stridesOf()[0]; + wr_user_md.data.format_desc.blocking.strides[1] = Wr->stridesOf()[1]; + wr_user_md.data.format_desc.blocking.strides[2] = Wr->stridesOf()[2]; + wr_user_md.data.format_desc.blocking.strides[3] = Wr->stridesOf()[3]; + wr_user_md.data.format_desc.blocking.strides[4] = Wr->stridesOf()[4]; + + // h + h_lstm_md = dnnl::memory::desc({sL, bS, hDirDim * nOut}, hType, + dnnl::memory::format_tag::any); + // h_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, hDirDim*nOut}, + // type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, + // hDirDim*nOut}, type, dnnl::memory::format_tag::ntc); + h_user_md = dnnl::memory::desc({sL, bS, hDirDim * nOut}, hType, + dnnl::memory::format_tag::tnc); + h_user_md.data.format_kind = dnnl_blocked; // overrides format + h_user_md.data.format_desc.blocking.strides[0] = h->stridesOf()[0]; + h_user_md.data.format_desc.blocking.strides[1] = h->stridesOf()[1]; + h_user_md.data.format_desc.blocking.strides[2] = h->stridesOf()[2]; + + // b + if (b) { + b_lstm_md = dnnl::memory::desc({1, dirDim, 4, nOut}, bType, + dnnl::memory::format_tag::any); + b_user_md = dnnl::memory::desc({1, dirDim, 4, nOut}, bType, + dnnl::memory::format_tag::ldgo); + b_user_md.data.format_kind = dnnl_blocked; // overrides format + b_user_md.data.format_desc.blocking.strides[0] = b->stridesOf()[0]; + b_user_md.data.format_desc.blocking.strides[1] = b->stridesOf()[1]; + b_user_md.data.format_desc.blocking.strides[2] = b->stridesOf()[2]; + b_user_md.data.format_desc.blocking.strides[3] = b->stridesOf()[3]; + } + + // hI + if (hI) { + hI_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, + dnnl::memory::format_tag::any); + hI_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, + dnnl::memory::format_tag::ldnc); + hI_user_md.data.format_kind = dnnl_blocked; // overrides format + hI_user_md.data.format_desc.blocking.strides[0] = hI->stridesOf()[0]; + hI_user_md.data.format_desc.blocking.strides[1] = hI->stridesOf()[1]; + hI_user_md.data.format_desc.blocking.strides[2] = hI->stridesOf()[2]; + hI_user_md.data.format_desc.blocking.strides[3] = hI->stridesOf()[3]; + } + + // cI + if (cI) { + cI_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, + dnnl::memory::format_tag::any); + cI_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, + dnnl::memory::format_tag::ldnc); + cI_user_md.data.format_kind = dnnl_blocked; // overrides format + cI_user_md.data.format_desc.blocking.strides[0] = cI->stridesOf()[0]; + cI_user_md.data.format_desc.blocking.strides[1] = cI->stridesOf()[1]; + cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[2]; + cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[3]; + } + + // hL + if (hL) { + hL_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, + dnnl::memory::format_tag::any); + hL_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, + dnnl::memory::format_tag::ldnc); + hL_user_md.data.format_kind = dnnl_blocked; // overrides format + hL_user_md.data.format_desc.blocking.strides[0] = hL->stridesOf()[0]; + hL_user_md.data.format_desc.blocking.strides[1] = hL->stridesOf()[1]; + hL_user_md.data.format_desc.blocking.strides[2] = hL->stridesOf()[2]; + hL_user_md.data.format_desc.blocking.strides[3] = hL->stridesOf()[3]; + } + + if (cL) { + cL_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, + dnnl::memory::format_tag::ldnc); + cL_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, + dnnl::memory::format_tag::ldnc); + cL_user_md.data.format_kind = dnnl_blocked; // overrides format + cL_user_md.data.format_desc.blocking.strides[0] = cL->stridesOf()[0]; + cL_user_md.data.format_desc.blocking.strides[1] = cL->stridesOf()[1]; + cL_user_md.data.format_desc.blocking.strides[2] = cL->stridesOf()[2]; + cL_user_md.data.format_desc.blocking.strides[3] = cL->stridesOf()[3]; + } + + // lstm memory description + lstm_forward::desc lstm_desc(prop_kind::forward_inference, direction, + x_lstm_md, hI_lstm_md, cI_lstm_md, wx_lstm_md, + wr_lstm_md, b_lstm_md, h_lstm_md, hL_lstm_md, + cL_lstm_md); + + dnnl::stream stream(engine); + + // lstm primitive description + lstm_forward::primitive_desc lstm_prim_desc(lstm_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + // provide memory and check whether reorder is required + // x + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, + lstm_prim_desc.src_layer_desc(), + args[DNNL_ARG_SRC_LAYER]); + + // wx + mkldnnUtils::loadDataToMklStream(Wx, engine, stream, wx_user_md, + lstm_prim_desc.weights_layer_desc(), + args[DNNL_ARG_WEIGHTS_LAYER]); + + // wr + mkldnnUtils::loadDataToMklStream(Wr, engine, stream, wr_user_md, + lstm_prim_desc.weights_iter_desc(), + args[DNNL_ARG_WEIGHTS_ITER]); + + // h + auto h_user_mem = dnnl::memory(h_user_md, engine, h->buffer()); + const bool hReorder = + lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc(); + auto h_lstm_mem = hReorder + ? dnnl::memory(lstm_prim_desc.dst_layer_desc(), engine) + : h_user_mem; + args[DNNL_ARG_DST_LAYER] = h_lstm_mem; + + // b + if (b) { + mkldnnUtils::loadDataToMklStream(b, engine, stream, b_user_md, + lstm_prim_desc.bias_desc(), + args[DNNL_ARG_BIAS]); + } + + // hI + if (hI) { + mkldnnUtils::loadDataToMklStream(hI, engine, stream, hI_user_md, + lstm_prim_desc.src_iter_desc(), + args[DNNL_ARG_SRC_ITER]); + } + + // cI + if (cI) { + mkldnnUtils::loadDataToMklStream(cI, engine, stream, cI_user_md, + lstm_prim_desc.src_iter_c_desc(), + args[DNNL_ARG_SRC_ITER_C]); + } + + bool hLReorder(false), cLReorder(false); + dnnl::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem; + + // hL + if (hL) { + hL_user_mem = dnnl::memory(hL_user_md, engine, hL->buffer()); + hLReorder = lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc(); + hL_lstm_mem = hLReorder + ? dnnl::memory(lstm_prim_desc.dst_iter_desc(), engine) + : hL_user_mem; + args[DNNL_ARG_DST_ITER] = hL_lstm_mem; + } + + // cL + if (cL) { + cL_user_mem = dnnl::memory(cL_user_md, engine, cL->buffer()); + cLReorder = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc(); + cL_lstm_mem = cLReorder + ? dnnl::memory(lstm_prim_desc.dst_iter_c_desc(), engine) + : cL_user_mem; + args[DNNL_ARG_DST_ITER_C] = cL_lstm_mem; + } + + // run calculations + lstm_forward(lstm_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (hReorder) + reorder(h_lstm_mem, h_user_mem).execute(stream, h_lstm_mem, h_user_mem); + if (hLReorder) + reorder(hL_lstm_mem, hL_user_mem).execute(stream, hL_lstm_mem, hL_user_mem); + if (cLReorder) + reorder(cL_lstm_mem, cL_user_mem).execute(stream, cL_lstm_mem, cL_user_mem); + + stream.wait(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(lstmLayer, ENGINE_CPU) { - - const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) - const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) - - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided - const auto hasInitH = B_ARG(2); // indicates whether initial output is provided - const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided - const auto hasPH = B_ARG(4); // indicates whether peephole connections are present - const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} - const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) - const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) - - const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping - - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - - int count = 3; - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector - const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output - const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state - const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights - - REQUIRE_TRUE(cellClip == 0 , 0, "LSTM_LAYER_MKLDNN operation: cell clipping is not supported currently !"); - REQUIRE_TRUE(retFullSeq, 0, "LSTM_LAYER_MKLDNN operation: option to calculate full time sequence output h should be always true in case of mkl dnn library !"); - REQUIRE_TRUE(hasPH == false , 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support peephole connections !"); - REQUIRE_TRUE(hasSeqLen == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support array specifying max time step per each example in batch !"); - REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!"); - REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !"); - REQUIRE_TRUE(retLastH == retLastC, 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !"); - REQUIRE_TRUE(hasInitH == hasInitC, 0, "LSTM_LAYER_MKLDNN operation: either both of or neither of initial C and initial H must be provided"); - - count = 0; - auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output - auto hL = retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step - auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step - - // evaluate dimensions - const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); - const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2); - const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1); - const Nd4jLong nOut = Wx->sizeAt(-1) / 4; - - // inputs validations - if(directionMode < 2) { // no bidirectional - - // Wx validation - if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - // initial output validation - if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - // initial cell validation - if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - } - else { // bidirectional - // Wx validation - if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); - // Wr validation - if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); - // biases validation - if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); - // initial output validation - if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); - // initial cell validation - if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); - } - - std::vector params = {static_cast(dataFormat), static_cast(directionMode), static_cast(cellClip)}; - - const int dirDim = directionMode < 2 ? 1 : 2; // number of dimensions, 1 unidirectional, 2 for bidirectional - - // permut x and h to tnc format if they have ntc format - NDArray* xP(const_cast(x)), *hP(h); - if(dataFormat == 1) { - xP = new NDArray(x->permute({1,0,2})); // [bS, sL, nIn] -> [sL, bS, nIn] - hP = new NDArray(h->permute({1,0,2})); // [bS, sL, dirDim*nOn] -> [sL, bS, dirDim*nOn] - } - - // reshape arrays in accordance to mkl allowed formats - NDArray *WxR(nullptr), *WrR(nullptr), *bR(nullptr), *hIR(nullptr), *cIR(nullptr), *hLR(nullptr), *cLR(nullptr); - - WxR = new NDArray(Wx->reshape(Wx->ordering(), {1,dirDim,nIn,4,nOut})); - WrR = new NDArray(Wr->reshape(Wr->ordering(), {1,dirDim,nOut,4,nOut})); - if(b) - bR = new NDArray(b->reshape(b->ordering(), {1,dirDim,4,nOut})); - if(hI) - hIR = new NDArray(hI->reshape(hI->ordering(), {1,dirDim,bS,nOut})); - if(cI) - cIR = new NDArray(cI->reshape(cI->ordering(), {1,dirDim,bS,nOut})); - if(hL) - hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}, false)); - if(cL) - cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}, false)); - - lstmLayerMKLDNN(xP, WxR, WrR, bR, hIR, cIR, params, hP, hLR, cLR); - - delete WxR; - delete WrR; - delete bR; - delete hIR; - delete cIR; - delete hLR; - delete cLR; - - if(dataFormat == 1) { - delete xP; - delete hP; - } - - return Status::OK(); + const auto dataFormat = INT_ARG( + 0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, + // nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = + INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = + // bidirectional concat, 4 = bidirectional extra output dim + // (in conjunction with format dataFormat = 3) + + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = + B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = + B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = + B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = + B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = B_ARG(5); // indicates whether to return whole time + // sequence h {h_0, h_1, ... , h_sL-1} + const auto retLastH = + B_ARG(6); // indicates whether to return output at last time step only, + // in this case shape would be [bS, nOut] (exact shape depends + // on dataFormat argument) + const auto retLastC = + B_ARG(7); // indicates whether to return cells state at last time step + // only, in this case shape would be [bS, nOut] (exact shape + // depends on dataFormat argument) + + const auto cellClip = + T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + int count = 3; + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = + hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = + hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = + hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = + hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + + REQUIRE_TRUE(cellClip == 0, 0, + "LSTM_LAYER_MKLDNN operation: cell clipping is not supported " + "currently !"); + REQUIRE_TRUE( + retFullSeq, 0, + "LSTM_LAYER_MKLDNN operation: option to calculate full time sequence " + "output h should be always true in case of mkl dnn library !"); + REQUIRE_TRUE(hasPH == false, 0, + "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support " + "peephole connections !"); + REQUIRE_TRUE(hasSeqLen == false, 0, + "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support " + "array specifying max time step per each example in batch !"); + REQUIRE_TRUE( + dataFormat < 2, 0, + "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are " + "allowed for input/output tensors in mkl dnn library: TNC and NTC!"); + REQUIRE_TRUE(directionMode < 4, 0, + "LSTM_LAYER_MKLDNN operation: option for bidirectional extra " + "output dimension is not valid in mkl dnn library !"); + REQUIRE_TRUE(retLastH == retLastC, 0, + "LSTM_LAYER_MKLDNN operation: only two options are present: 1) " + "calculate both output at last time and cell state at last " + "time; 2) do not calculate both !"); + REQUIRE_TRUE(hasInitH == hasInitC, 0, + "LSTM_LAYER_MKLDNN operation: either both of or neither of " + "initial C and initial H must be provided"); + + count = 0; + auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output + auto hL = + retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step + auto cL = + retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step + + // evaluate dimensions + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const Nd4jLong bS = + dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2); + const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + if (directionMode < 2) { // no bidirectional + + // Wx validation + if (Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nIn, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4 * nOut) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent " + "weights, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({nOut, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wr).c_str()); + // biases validation + if (b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_MKLDNN operation: wrong shape of biases, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({4 * nOut}).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + // initial output validation + if (hI != nullptr && + (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_MKLDNN operation: wrong shape of initial " + "output, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS, nOut}).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + // initial cell validation + if (cI != nullptr && + (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell " + "state, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({bS, nOut}).c_str(), + ShapeUtils::shapeAsString(cI).c_str()); + } else { // bidirectional + // Wx validation + if (Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, nIn, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if (Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || + Wr->sizeAt(2) != 4 * nOut) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent " + "weights, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, nOut, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(Wr).c_str()); + // biases validation + if (b != nullptr && + (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4 * nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_MKLDNN operation: wrong shape of biases, " + "expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, 4 * nOut}).c_str(), + ShapeUtils::shapeAsString(b).c_str()); + // initial output validation + if (hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || + hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_MKLDNN operation: wrong shape of initial " + "output, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), + ShapeUtils::shapeAsString(hI).c_str()); + // initial cell validation + if (cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || + cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, + "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell " + "state, expected is %s, but got %s instead !", + ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), + ShapeUtils::shapeAsString(cI).c_str()); + } + + std::vector params = {static_cast(dataFormat), + static_cast(directionMode), + static_cast(cellClip)}; + + const int dirDim = + directionMode < 2 + ? 1 + : 2; // number of dimensions, 1 unidirectional, 2 for bidirectional + + // permut x and h to tnc format if they have ntc format + NDArray *xP(const_cast(x)), *hP(h); + if (dataFormat == 1) { + xP = new NDArray(x->permute({1, 0, 2})); // [bS, sL, nIn] -> [sL, bS, nIn] + hP = new NDArray( + h->permute({1, 0, 2})); // [bS, sL, dirDim*nOn] -> [sL, bS, dirDim*nOn] + } + + // reshape arrays in accordance to mkl allowed formats + NDArray *WxR(nullptr), *WrR(nullptr), *bR(nullptr), *hIR(nullptr), + *cIR(nullptr), *hLR(nullptr), *cLR(nullptr); + + WxR = new NDArray(Wx->reshape(Wx->ordering(), {1, dirDim, nIn, 4, nOut})); + WrR = new NDArray(Wr->reshape(Wr->ordering(), {1, dirDim, nOut, 4, nOut})); + if (b) bR = new NDArray(b->reshape(b->ordering(), {1, dirDim, 4, nOut})); + if (hI) hIR = new NDArray(hI->reshape(hI->ordering(), {1, dirDim, bS, nOut})); + if (cI) cIR = new NDArray(cI->reshape(cI->ordering(), {1, dirDim, bS, nOut})); + if (hL) + hLR = + new NDArray(hL->reshape(hL->ordering(), {1, dirDim, bS, nOut}, false)); + if (cL) + cLR = + new NDArray(cL->reshape(cL->ordering(), {1, dirDim, bS, nOut}, false)); + + lstmLayerMKLDNN(xP, WxR, WrR, bR, hIR, cIR, params, hP, hLR, cLR); + + delete WxR; + delete WrR; + delete bR; + delete hIR; + delete cIR; + delete hLR; + delete cLR; + + if (dataFormat == 1) { + delete xP; + delete hP; + } + + return Status::OK(); } PLATFORM_CHECK(lstmLayer, ENGINE_CPU) { - - const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) - const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) - - const auto hasBiases = B_ARG(0); // indicates whether biases array is provided - const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided - const auto hasInitH = B_ARG(2); // indicates whether initial output is provided - const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided - const auto hasPH = B_ARG(4); // indicates whether peephole connections are present - const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} - const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) - const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) - - const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping - - const auto x = INPUT_VARIABLE(0); // input - const auto Wx = INPUT_VARIABLE(1); // input weights - const auto Wr = INPUT_VARIABLE(2); // recurrent weights - - int count = 3; - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output - const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state - - count = 0; - auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output - auto hL = retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step - auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step - - DataType xType = x->dataType(); - DataType WxType = Wx->dataType(); - DataType WrType = Wr->dataType(); - DataType bType = b != nullptr ? b->dataType() : (xType == DataType::HALF ? xType : DataType::FLOAT32); - DataType hIType = hI != nullptr ? hI->dataType() : xType; - DataType cIType = cI != nullptr ? cI->dataType() : xType; - DataType hType = h != nullptr ? h->dataType() : xType; - DataType hLType = hL != nullptr ? hL->dataType() : xType; - DataType cLType = cL != nullptr ? cL->dataType() : xType; - - auto featuresSupported = (cellClip == 0) //Cell clipping not supported - && retFullSeq //Always return full sequence in case of MKL DNN - && !hasPH //Peephole connections not supported in MKL DNN - && !hasSeqLen //Sequence length array not supported in MKL DNN - && dataFormat < 2 //Data format - only 0 and 1 supported in MKL DNN- 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn] - && directionMode < 4 //Direction mode - only 0-3 supported in MKL DNN (no extra dim option) - 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat - && retLastH == retLastC //Return both lastH and lastC, or return neither (not just 1 or other) - && hasInitH == hasInitC; //Need both or neither initial H and C - - return block.isUseMKLDNN() && featuresSupported && ( - (xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) || - (xType==DataType::HALF && WxType==DataType::HALF && WrType==DataType::HALF && bType==DataType::HALF && hIType==DataType::HALF && cIType==DataType::HALF && hType==DataType::HALF && hLType==DataType::HALF && cLType==DataType::HALF) || - (xType==DataType::UINT8 && WxType==DataType::INT8 && WrType==DataType::INT8 && bType==DataType::FLOAT32 && hIType==DataType::UINT8 && cIType==DataType::UINT8 && (hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32 || hType==DataType::UINT8 && hLType==DataType::UINT8 && cLType==DataType::UINT8)) - ); + const auto dataFormat = INT_ARG( + 0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, + // nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = + INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = + // bidirectional concat, 4 = bidirectional extra output dim + // (in conjunction with format dataFormat = 3) + + const auto hasBiases = + B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = + B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = + B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = + B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = + B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = B_ARG(5); // indicates whether to return whole time + // sequence h {h_0, h_1, ... , h_sL-1} + const auto retLastH = + B_ARG(6); // indicates whether to return output at last time step only, + // in this case shape would be [bS, nOut] (exact shape depends + // on dataFormat argument) + const auto retLastC = + B_ARG(7); // indicates whether to return cells state at last time step + // only, in this case shape would be [bS, nOut] (exact shape + // depends on dataFormat argument) + + const auto cellClip = + T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + int count = 3; + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto hI = + hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = + hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + + count = 0; + auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output + auto hL = + retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step + auto cL = + retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step + + DataType xType = x->dataType(); + DataType WxType = Wx->dataType(); + DataType WrType = Wr->dataType(); + DataType bType = b != nullptr + ? b->dataType() + : (xType == DataType::HALF ? xType : DataType::FLOAT32); + DataType hIType = hI != nullptr ? hI->dataType() : xType; + DataType cIType = cI != nullptr ? cI->dataType() : xType; + DataType hType = h != nullptr ? h->dataType() : xType; + DataType hLType = hL != nullptr ? hL->dataType() : xType; + DataType cLType = cL != nullptr ? cL->dataType() : xType; + + auto featuresSupported = + (cellClip == 0) // Cell clipping not supported + && retFullSeq // Always return full sequence in case of MKL DNN + && !hasPH // Peephole connections not supported in MKL DNN + && !hasSeqLen // Sequence length array not supported in MKL DNN + && dataFormat < 2 // Data format - only 0 and 1 supported in MKL DNN- 0 = + // [sL, bS, nIn], 1 = [bS, sL ,nIn] + && directionMode < 4 // Direction mode - only 0-3 supported in MKL DNN + // (no extra dim option) - 0 = fwd, 1 = bwd, 2 = + // bidirectional sum, 3 = bidirectional concat + && retLastH == retLastC // Return both lastH and lastC, or return neither + // (not just 1 or other) + && hasInitH == hasInitC; // Need both or neither initial H and C + + return block.isUseMKLDNN() && featuresSupported && + ((xType == DataType::FLOAT32 && WxType == DataType::FLOAT32 && + WrType == DataType::FLOAT32 && bType == DataType::FLOAT32 && + hIType == DataType::FLOAT32 && cIType == DataType::FLOAT32 && + hType == DataType::FLOAT32 && hLType == DataType::FLOAT32 && + cLType == DataType::FLOAT32) || + (xType == DataType::HALF && WxType == DataType::HALF && + WrType == DataType::HALF && bType == DataType::HALF && + hIType == DataType::HALF && cIType == DataType::HALF && + hType == DataType::HALF && hLType == DataType::HALF && + cLType == DataType::HALF) || + (xType == DataType::UINT8 && WxType == DataType::INT8 && + WrType == DataType::INT8 && bType == DataType::FLOAT32 && + hIType == DataType::UINT8 && cIType == DataType::UINT8 && + (hType == DataType::FLOAT32 && hLType == DataType::FLOAT32 && + cLType == DataType::FLOAT32 || + hType == DataType::UINT8 && hLType == DataType::UINT8 && + cLType == DataType::UINT8))); } - - -} -} -} +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp index ec58af9430db..c0ea75ffc64e 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp @@ -18,312 +18,386 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include #include -#include -#include "mkldnnUtils.h" #include +#include "mkldnnUtils.h" -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - dnnl::memory::format_tag get_format_tag(const sd::NDArray &array) { - switch (array.rankOf()) { - case 1: - return dnnl::memory::format_tag::ab; - case 2: - return array.ordering() == 'c' ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba; - case 3: - return array.ordering() == 'c' ? dnnl::memory::format_tag::abc : dnnl::memory::format_tag::cba; - default: - throw std::runtime_error("MKLDNN matmul only supports 2D/3D arrays"); - } - } - +dnnl::memory::format_tag get_format_tag(const sd::NDArray& array) { + switch (array.rankOf()) { + case 1: + return dnnl::memory::format_tag::ab; + case 2: + return array.ordering() == 'c' ? dnnl::memory::format_tag::ab + : dnnl::memory::format_tag::ba; + case 3: + return array.ordering() == 'c' ? dnnl::memory::format_tag::abc + : dnnl::memory::format_tag::cba; + default: + throw std::runtime_error("MKLDNN matmul only supports 2D/3D arrays"); + } +} ////////////////////////////////////////////////////////////////////////// -static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY, float alpha = 1.f, float beta = 0.f) { - - // mkl works with following - // [M,K] x [K,N] = [M,N] - // [bS, M,K] x [bS, K,N] = [bS, M,N] - - // possible input cases not supported by mkl, however we'll perform permut/reshape procedures in order to fit requirements - // [4] x [4] = [1] --> [1,4] x [4,1] = [1,1] - // [4] x [4,5] = [5] --> [1,4] x [4,5] = [1,5] - // [4,5] x [5] = [4] --> [4,5] x [5,1] = [4,1] - // [2,3, 4,5] x [2,3, 5,4] = [2,3, 4,4] --> [6, 4,5] x [6, 5,4] = [6, 4,4] - // [2,2,3, 4,5] x [2,2,3, 5,4] = [2,2,3, 4,4] --> [12, 4,5] x [12, 5,4] = [12, 4,4] - - const auto xRank = x->rankOf(); - const auto yRank = y->rankOf(); - const auto zRank = z->rankOf(); - - std::vector permut; - - // fill permutation vector appropriately if transposition is required - if((transX && xRank > 1) || (transY && yRank > 1)) { - - const int rank = xRank >= yRank ? xRank : yRank; - permut.resize(rank); - std::iota(std::begin(permut), std::end(permut), 0); - permut[rank-2] = rank - 1; - permut[rank-1] = rank - 2; - } - - const NDArray* xT = (transX && xRank > 1) ? new NDArray(x->permute(permut)) : x; - const NDArray* yT = (transY && yRank > 1) ? new NDArray(y->permute(permut)) : y; - - const NDArray* xTR = xRank <= 3 ? xT : new NDArray(xT->reshape(xT->ordering(), {xT->lengthOf() / (xT->sizeAt(-2) * xT->sizeAt(-1)), xT->sizeAt(-2), xT->sizeAt(-1)})); - const NDArray* yTR = xRank <= 3 ? yT : new NDArray(yT->reshape(yT->ordering(), {yT->lengthOf() / (yT->sizeAt(-2) * yT->sizeAt(-1)), yT->sizeAt(-2), yT->sizeAt(-1)})); - NDArray* zR = xRank <= 3 ? z : new NDArray(z->reshape(z->ordering(), {z->lengthOf() / (z->sizeAt(-2) * z->sizeAt(-1)), z->sizeAt(-2), z->sizeAt(-1)})/*, false*/); - - // [M,K] x [K,N] = [M,N] - const int64_t M = (xRank > 1) ? xTR->sizeAt(-2) : 1; - const int64_t K = (xRank > 1) ? xTR->sizeAt(-1) : xTR->lengthOf(); - const int64_t N = (yRank > 1) ? yTR->sizeAt(-1) : 1; - const int64_t bS = (xRank > 2) ? xTR->sizeAt(0) : 1; // [bS, M,K] x [bS, K,N] = [bS, M,N] - - dnnl::memory::dims xShape = xRank < 3 ? dnnl::memory::dims({M, K}) : dnnl::memory::dims({bS, M, K}); - dnnl::memory::dims yShape = xRank < 3 ? dnnl::memory::dims({K, N}) : dnnl::memory::dims({bS, K, N}); - dnnl::memory::dims zShape = xRank < 3 ? dnnl::memory::dims({M, N}) : dnnl::memory::dims({bS, M, N}); - - // x type - dnnl::memory::data_type xType; - if(x->dataType() == DataType::FLOAT32) - xType = dnnl::memory::data_type::f32; - else if(x->dataType() == DataType::HALF) - xType = dnnl::memory::data_type::f16; - else if(x->dataType() == DataType::BFLOAT16) - xType = dnnl::memory::data_type::bf16; - else if(x->dataType() == DataType::UINT8) - xType = dnnl::memory::data_type::u8; - else - xType = dnnl::memory::data_type::s8; - - // y type - dnnl::memory::data_type yType = xType; - if(y->dataType() == DataType::UINT8) - yType = dnnl::memory::data_type::u8; - else if(y->dataType() == DataType::INT8) - yType = dnnl::memory::data_type::s8; - - // z type - dnnl::memory::data_type zType = xType; - if(z->dataType() == DataType::FLOAT32) - zType = dnnl::memory::data_type::f32; - else if(z->dataType() == DataType::INT32) - zType = dnnl::memory::data_type::s32; - else if(z->dataType() == DataType::UINT8) - zType = dnnl::memory::data_type::u8; - else if(z->dataType() == DataType::INT8) - zType = dnnl::memory::data_type::s8; - - // memory descriptors for arrays - - // x - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, get_format_tag(*xTR)); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, get_format_tag(*xTR)); - if(xTR->ews() != 1) { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = xRank == 1 ? 1 : xTR->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = xRank == 1 ? xTR->strideAt(0) : xTR->strideAt(1); - if(xRank > 2) - x_user_md.data.format_desc.blocking.strides[2] = xTR->strideAt(2); - } - - // y - dnnl::memory::desc y_mkl_md = dnnl::memory::desc(yShape, yType, get_format_tag(*yTR)); - dnnl::memory::desc y_user_md = dnnl::memory::desc(yShape, yType, get_format_tag(*yTR)); - if(yTR->ews() != 1) { - y_user_md.data.format_kind = dnnl_blocked; // overrides format - y_user_md.data.format_desc.blocking.strides[0] = yRank == 1 ? 1 : yTR->strideAt(0); - y_user_md.data.format_desc.blocking.strides[1] = yRank == 1 ? yTR->strideAt(0) : yTR->strideAt(1); - if(yRank > 2) - y_user_md.data.format_desc.blocking.strides[2] = yTR->strideAt(2); - } - - // z - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, get_format_tag(*zR)); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, get_format_tag(*zR)); - if(zR->ews() != 1) { - z_user_md.data.format_kind = dnnl_blocked; // overrides format - z_user_md.data.format_desc.blocking.strides[0] = zRank == 1 ? 1 : zR->strideAt(0); - z_user_md.data.format_desc.blocking.strides[1] = zRank == 1 ? zR->strideAt(0) : zR->strideAt(1); - if(zRank > 2) - z_user_md.data.format_desc.blocking.strides[2] = zR->strideAt(2); - } - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // Create attributes (to handle alpha and beta if necessary) - dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0) - if (alpha != 1.f) attr.set_output_scales(0, {alpha}); - if (beta != 0.f) { - dnnl::post_ops po; - po.append_sum(beta); - attr.set_post_ops(po); - } - - // operation primitive description - dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md); - dnnl::matmul::primitive_desc op_prim_desc(op_desc, attr, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(xTR, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - /* - auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->buffer()); - const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; +static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, + const bool transX, const bool transY, + float alpha = 1.f, float beta = 0.f) { + // mkl works with following + // [M,K] x [K,N] = [M,N] + // [bS, M,K] x [bS, K,N] = [bS, M,N] + + // possible input cases not supported by mkl, however we'll perform + // permut/reshape procedures in order to fit requirements [4] x [4] + // = [1] --> [1,4] x [4,1] = [1,1] [4] x [4,5] = [5] + // --> [1,4] x [4,5] = [1,5] [4,5] x [5] = [4] --> + // [4,5] x [5,1] = [4,1] [2,3, 4,5] x [2,3, 5,4] = [2,3, 4,4] --> + // [6, 4,5] x [6, 5,4] = [6, 4,4] [2,2,3, 4,5] x [2,2,3, 5,4] = [2,2,3, 4,4] + // --> [12, 4,5] x [12, 5,4] = [12, 4,4] + + const auto xRank = x->rankOf(); + const auto yRank = y->rankOf(); + const auto zRank = z->rankOf(); + + std::vector permut; + + // fill permutation vector appropriately if transposition is required + if ((transX && xRank > 1) || (transY && yRank > 1)) { + const int rank = xRank >= yRank ? xRank : yRank; + permut.resize(rank); + std::iota(std::begin(permut), std::end(permut), 0); + permut[rank - 2] = rank - 1; + permut[rank - 1] = rank - 2; + } + + const NDArray* xT = + (transX && xRank > 1) ? new NDArray(x->permute(permut)) : x; + const NDArray* yT = + (transY && yRank > 1) ? new NDArray(y->permute(permut)) : y; + + const NDArray* xTR = + xRank <= 3 ? xT + : new NDArray(xT->reshape( + xT->ordering(), + {xT->lengthOf() / (xT->sizeAt(-2) * xT->sizeAt(-1)), + xT->sizeAt(-2), xT->sizeAt(-1)})); + const NDArray* yTR = + xRank <= 3 ? yT + : new NDArray(yT->reshape( + yT->ordering(), + {yT->lengthOf() / (yT->sizeAt(-2) * yT->sizeAt(-1)), + yT->sizeAt(-2), yT->sizeAt(-1)})); + NDArray* zR = + xRank <= 3 + ? z + : new NDArray(z->reshape( + z->ordering(), {z->lengthOf() / (z->sizeAt(-2) * z->sizeAt(-1)), + z->sizeAt(-2), z->sizeAt(-1)}) /*, false*/); + + // [M,K] x [K,N] = [M,N] + const int64_t M = (xRank > 1) ? xTR->sizeAt(-2) : 1; + const int64_t K = (xRank > 1) ? xTR->sizeAt(-1) : xTR->lengthOf(); + const int64_t N = (yRank > 1) ? yTR->sizeAt(-1) : 1; + const int64_t bS = + (xRank > 2) ? xTR->sizeAt(0) : 1; // [bS, M,K] x [bS, K,N] = [bS, M,N] + + dnnl::memory::dims xShape = + xRank < 3 ? dnnl::memory::dims({M, K}) : dnnl::memory::dims({bS, M, K}); + dnnl::memory::dims yShape = + xRank < 3 ? dnnl::memory::dims({K, N}) : dnnl::memory::dims({bS, K, N}); + dnnl::memory::dims zShape = + xRank < 3 ? dnnl::memory::dims({M, N}) : dnnl::memory::dims({bS, M, N}); + + // x type + dnnl::memory::data_type xType; + if (x->dataType() == DataType::FLOAT32) + xType = dnnl::memory::data_type::f32; + else if (x->dataType() == DataType::HALF) + xType = dnnl::memory::data_type::f16; + else if (x->dataType() == DataType::BFLOAT16) + xType = dnnl::memory::data_type::bf16; + else if (x->dataType() == DataType::UINT8) + xType = dnnl::memory::data_type::u8; + else + xType = dnnl::memory::data_type::s8; + + // y type + dnnl::memory::data_type yType = xType; + if (y->dataType() == DataType::UINT8) + yType = dnnl::memory::data_type::u8; + else if (y->dataType() == DataType::INT8) + yType = dnnl::memory::data_type::s8; + + // z type + dnnl::memory::data_type zType = xType; + if (z->dataType() == DataType::FLOAT32) + zType = dnnl::memory::data_type::f32; + else if (z->dataType() == DataType::INT32) + zType = dnnl::memory::data_type::s32; + else if (z->dataType() == DataType::UINT8) + zType = dnnl::memory::data_type::u8; + else if (z->dataType() == DataType::INT8) + zType = dnnl::memory::data_type::s8; + + // memory descriptors for arrays + + // x + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xShape, xType, get_format_tag(*xTR)); + dnnl::memory::desc x_user_md = + dnnl::memory::desc(xShape, xType, get_format_tag(*xTR)); + if (xTR->ews() != 1) { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = + xRank == 1 ? 1 : xTR->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = + xRank == 1 ? xTR->strideAt(0) : xTR->strideAt(1); + if (xRank > 2) + x_user_md.data.format_desc.blocking.strides[2] = xTR->strideAt(2); + } + + // y + dnnl::memory::desc y_mkl_md = + dnnl::memory::desc(yShape, yType, get_format_tag(*yTR)); + dnnl::memory::desc y_user_md = + dnnl::memory::desc(yShape, yType, get_format_tag(*yTR)); + if (yTR->ews() != 1) { + y_user_md.data.format_kind = dnnl_blocked; // overrides format + y_user_md.data.format_desc.blocking.strides[0] = + yRank == 1 ? 1 : yTR->strideAt(0); + y_user_md.data.format_desc.blocking.strides[1] = + yRank == 1 ? yTR->strideAt(0) : yTR->strideAt(1); + if (yRank > 2) + y_user_md.data.format_desc.blocking.strides[2] = yTR->strideAt(2); + } + + // z + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zShape, zType, get_format_tag(*zR)); + dnnl::memory::desc z_user_md = + dnnl::memory::desc(zShape, zType, get_format_tag(*zR)); + if (zR->ews() != 1) { + z_user_md.data.format_kind = dnnl_blocked; // overrides format + z_user_md.data.format_desc.blocking.strides[0] = + zRank == 1 ? 1 : zR->strideAt(0); + z_user_md.data.format_desc.blocking.strides[1] = + zRank == 1 ? zR->strideAt(0) : zR->strideAt(1); + if (zRank > 2) + z_user_md.data.format_desc.blocking.strides[2] = zR->strideAt(2); + } + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // Create attributes (to handle alpha and beta if necessary) + dnnl::primitive_attr attr; // it is empty since we have usual values for + // alpha (=1) and beta (=0) + if (alpha != 1.f) attr.set_output_scales(0, {alpha}); + if (beta != 0.f) { + dnnl::post_ops po; + po.append_sum(beta); + attr.set_post_ops(po); + } + + // operation primitive description + dnnl::matmul::desc op_desc(x_mkl_md, y_mkl_md, z_mkl_md); + dnnl::matmul::primitive_desc op_prim_desc(op_desc, attr, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(xTR, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + /* + auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->buffer()); + const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); + auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : + x_user_mem; if (xReorder) dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, + x_user_mem, x_mkl_mem); args[DNNL_ARG_SRC] = x_mkl_mem; */ - // y - mkldnnUtils::loadDataToMklStream(yTR, engine, stream, y_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - /* - auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->buffer()); - const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc(); - auto y_mkl_mem = yReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : y_user_mem; - if (yReorder) - dnnl::reorder(y_user_mem, y_mkl_mem).execute(stream, y_user_mem, y_mkl_mem); - args[DNNL_ARG_WEIGHTS] = y_mkl_mem; + // y + mkldnnUtils::loadDataToMklStream(yTR, engine, stream, y_user_md, + op_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + /* + auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->buffer()); + const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc(); + auto y_mkl_mem = yReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) + : y_user_mem; if (yReorder) dnnl::reorder(y_user_mem, + y_mkl_mem).execute(stream, y_user_mem, y_mkl_mem); args[DNNL_ARG_WEIGHTS] = + y_mkl_mem; */ - // z - auto z_user_mem = dnnl::memory(z_user_md, engine, zR->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; - - // run calculations - dnnl::matmul(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); - - stream.wait(); - - if(zR->buffer() != z->buffer()) - z->assign(zR); - - if(zR != z) - delete zR; - if(xTR != xT) - delete xTR; - if(xT != x) - delete xT; - if(yTR != yT) - delete yTR; - if(yT != y) - delete yT; - - // shape::printArray(z_mkl_mem.map_data(),8); -} - -////////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(matmul, ENGINE_CPU) { - - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); - - if(x->isEmpty() || y->isEmpty()) - return Status::OK(); - - int iSize = (int) block.numI(); - int transX = iSize > 0 ? INT_ARG(0) : 0; - int transY = iSize > 1 ? INT_ARG(1) : 0; - const int transZ = iSize > 2 ? INT_ARG(2) : 0; - - // optional use alpha nad beta - iSize = (int) block.numT(); - float alpha = iSize > 0 ? T_ARG(0) : 1.0; - float beta = iSize > 1 ? T_ARG(1) : 0.0; - - const int xRank = x->rankOf(); - const int yRank = y->rankOf(); - const int zRank = z->rankOf(); - - if (transZ) { - x = INPUT_VARIABLE(1); - y = INPUT_VARIABLE(0); - bool temp = transX; - transX = !transY; - transY = !temp; - } - - const int xLastDim = transX ? -2 : -1; - const int yLastDim = transY ? -2 : -1; - const int xLastButOneDim = transX ? -1 : -2; - const int yLastButOneDim = transY ? -1 : -2; - - // ******* input validation ******* // - REQUIRE_TRUE(xRank > 0 && yRank > 0, 0, "MATMUL MKLDNN OP: input arrays must have rank bigger than 0 (should not be scalars), but got instead: x rank = %i, y rank = %i !", xRank, yRank); - - if (xRank == 1 && yRank == 1) { // dot case, output is scalar (or vector with length = 1) - REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0,"MATMUL MKLDNN OP: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !",x->lengthOf(), y->lengthOf()); - } else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector - REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, "MATMUL MKLDNN OP: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - } else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector - REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, "MATMUL MKLDNN OP: input arrays have inconsistent shapes for matrix-vector product: x %s, y %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str()); - } else { - REQUIRE_TRUE(xRank == yRank && yRank == zRank, 0, "MATMUL MKLDNN OP: input and output arrays must have the same rank, but got instead: x rank = %i, y rank = %i, z rank = %i !", xRank, yRank, zRank); - REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && y->sizeAt(yLastDim) == z->sizeAt(-1), 0, "MATMUL MKLDNN OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str()); - - if (xRank > 2) // outer dims must be the same - for (int i = 0; i < xRank - 2; ++i) - REQUIRE_TRUE(x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, "MATMUL MKLDNN OP: input/output arrays have inconsistent shapes for matrix product: x %s, y %s, z %s !", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str(), ShapeUtils::shapeAsString(z).c_str()); - } - // ******* end of input validation ******* // - - matmulMKLDNN(x, y, z, transX, transY, alpha, beta); - - return Status::OK(); -} - -////////////////////////////////////////////////////////////////////////// -PLATFORM_CHECK(matmul, ENGINE_CPU) { + // z + auto z_user_mem = dnnl::memory(z_user_md, engine, zR->buffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = + zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; - auto x = INPUT_VARIABLE(0); - auto y = INPUT_VARIABLE(1); + // run calculations + dnnl::matmul(op_prim_desc).execute(stream, args); - auto z = OUTPUT_VARIABLE(0); + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); - const auto xType = x->dataType(); - const auto yType = y->dataType(); - const auto zType = z->dataType(); + stream.wait(); - float alpha = block.numT() > 0 ? T_ARG(0) : 1.0f; - float beta = block.numT() > 1 ? T_ARG(1) : 0.0f; + if (zR->buffer() != z->buffer()) z->assign(zR); - // we're skipping if result order is F or arrays are not continuous - bool skip2D = z->rankOf() == 2 && (z->ordering() == 'f' || x->ews() != 1 || y->ews() != 1 || z->ews() != 1); + if (zR != z) delete zR; + if (xTR != xT) delete xTR; + if (xT != x) delete xT; + if (yTR != yT) delete yTR; + if (yT != y) delete yT; - // we're skipping 3D cases if they are not C continuoys - bool skip3D = z->rankOf() == 3 && (x->ordering() == 'f' || y->ordering() == 'f' || z->ordering() == 'f' || x->ews() != 1 || y->ews() != 1 || z->ews() != 1); - - return !skip2D && !skip3D && block.isUseMKLDNN() && x->rankOf() < 3 && - ( - (xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) || - (xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) || - (xType==DataType::BFLOAT16 && yType==DataType::BFLOAT16 && zType==DataType::BFLOAT16) || - ((xType==DataType::UINT8 || xType==DataType::INT8) && (yType==DataType::UINT8 || yType==DataType::INT8) && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32)) - ); + // shape::printArray(z_mkl_mem.map_data(),8); } - -} +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(matmul, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); + + if (x->isEmpty() || y->isEmpty()) return Status::OK(); + + int iSize = (int)block.numI(); + int transX = iSize > 0 ? INT_ARG(0) : 0; + int transY = iSize > 1 ? INT_ARG(1) : 0; + const int transZ = iSize > 2 ? INT_ARG(2) : 0; + + // optional use alpha nad beta + iSize = (int)block.numT(); + float alpha = iSize > 0 ? T_ARG(0) : 1.0; + float beta = iSize > 1 ? T_ARG(1) : 0.0; + + const int xRank = x->rankOf(); + const int yRank = y->rankOf(); + const int zRank = z->rankOf(); + + if (transZ) { + x = INPUT_VARIABLE(1); + y = INPUT_VARIABLE(0); + bool temp = transX; + transX = !transY; + transY = !temp; + } + + const int xLastDim = transX ? -2 : -1; + const int yLastDim = transY ? -2 : -1; + const int xLastButOneDim = transX ? -1 : -2; + const int yLastButOneDim = transY ? -1 : -2; + + // ******* input validation ******* // + REQUIRE_TRUE( + xRank > 0 && yRank > 0, 0, + "MATMUL MKLDNN OP: input arrays must have rank bigger than 0 (should not " + "be scalars), but got instead: x rank = %i, y rank = %i !", + xRank, yRank); + + if (xRank == 1 && + yRank == 1) { // dot case, output is scalar (or vector with length = 1) + REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, + "MATMUL MKLDNN OP: since input arrays are vectors they must " + "have the same length, but got x length = %i, y length = %i !", + x->lengthOf(), y->lengthOf()); + } else if (xRank == 1 && yRank == 2) { // vector x matrix, i.e. [4] x [4,5] = + // [5], output is vector + REQUIRE_TRUE(x->lengthOf() == y->sizeAt(yLastButOneDim), 0, + "MATMUL MKLDNN OP: input arrays have inconsistent shapes for " + "vector-matrix product: x %s, y %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + } else if (xRank == 2 && yRank == 1) { // matrix x vector , i.e. [4,5] x [5] + // = [4], output is vector + REQUIRE_TRUE(x->sizeAt(xLastDim) == y->lengthOf(), 0, + "MATMUL MKLDNN OP: input arrays have inconsistent shapes for " + "matrix-vector product: x %s, y %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str()); + } else { + REQUIRE_TRUE( + xRank == yRank && yRank == zRank, 0, + "MATMUL MKLDNN OP: input and output arrays must have the same rank, " + "but got instead: x rank = %i, y rank = %i, z rank = %i !", + xRank, yRank, zRank); + REQUIRE_TRUE(x->sizeAt(xLastDim) == y->sizeAt(yLastButOneDim) && + x->sizeAt(xLastButOneDim) == z->sizeAt(-2) && + y->sizeAt(yLastDim) == z->sizeAt(-1), + 0, + "MATMUL MKLDNN OP: input/output arrays have inconsistent " + "shapes for matrix product: x %s, y %s, z %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str(), + ShapeUtils::shapeAsString(z).c_str()); + + if (xRank > 2) // outer dims must be the same + for (int i = 0; i < xRank - 2; ++i) + REQUIRE_TRUE( + x->sizeAt(i) == y->sizeAt(i) && y->sizeAt(i) == z->sizeAt(i), 0, + "MATMUL MKLDNN OP: input/output arrays have inconsistent shapes " + "for matrix product: x %s, y %s, z %s !", + ShapeUtils::shapeAsString(x).c_str(), + ShapeUtils::shapeAsString(y).c_str(), + ShapeUtils::shapeAsString(z).c_str()); + } + // ******* end of input validation ******* // + + matmulMKLDNN(x, y, z, transX, transY, alpha, beta); + + return Status::OK(); } + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(matmul, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + + auto z = OUTPUT_VARIABLE(0); + + const auto xType = x->dataType(); + const auto yType = y->dataType(); + const auto zType = z->dataType(); + + float alpha = block.numT() > 0 ? T_ARG(0) : 1.0f; + float beta = block.numT() > 1 ? T_ARG(1) : 0.0f; + + // we're skipping if result order is F or arrays are not continuous + bool skip2D = z->rankOf() == 2 && (z->ordering() == 'f' || x->ews() != 1 || + y->ews() != 1 || z->ews() != 1); + + // we're skipping 3D cases if they are not C continuoys + bool skip3D = + z->rankOf() == 3 && + (x->ordering() == 'f' || y->ordering() == 'f' || z->ordering() == 'f' || + x->ews() != 1 || y->ews() != 1 || z->ews() != 1); + + return !skip2D && !skip3D && block.isUseMKLDNN() && x->rankOf() < 3 && + ((xType == DataType::FLOAT32 && yType == DataType::FLOAT32 && + zType == DataType::FLOAT32) || + (xType == DataType::HALF && yType == DataType::HALF && + zType == DataType::FLOAT32) || + (xType == DataType::BFLOAT16 && yType == DataType::BFLOAT16 && + zType == DataType::BFLOAT16) || + ((xType == DataType::UINT8 || xType == DataType::INT8) && + (yType == DataType::UINT8 || yType == DataType::INT8) && + (zType == DataType::UINT8 || zType == DataType::INT8 || + zType == DataType::INT32 || zType == DataType::FLOAT32))); } + +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp index cab94ffa8a76..b91b1b8f599a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp @@ -20,108 +20,143 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include using namespace dnnl; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { - ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool2d, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D MKLDNN OP: input array should have rank of 4, but got %i instead", input->rankOf()); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - const int kH = INT_ARG(0); - const int kW = INT_ARG(1); - const int sH = INT_ARG(2); - const int sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); - const int dH = INT_ARG(6); - const int dW = INT_ARG(7); - const int paddingMode = INT_ARG(8); - // const int extraParam0 = INT_ARG(9); - const int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW - - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - if (paddingMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - mkldnnUtils::poolingMKLDNN(input, output, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, algorithm::pooling_max); - - return Status::OK(); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "MAXPOOL2D MKLDNN OP: input array should have rank of 4, but " + "got %i instead", + input->rankOf()); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + const int kH = INT_ARG(0); + const int kW = INT_ARG(1); + const int sH = INT_ARG(2); + const int sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + const int dH = INT_ARG(6); + const int dW = INT_ARG(7); + const int paddingMode = INT_ARG(8); + // const int extraParam0 = INT_ARG(9); + const int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW + + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "MAXPOOL2D MKLDNN op: dilation must not be zero, but got " + "instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + if (paddingMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + mkldnnUtils::poolingMKLDNN(input, output, 0, kH, kW, 0, sH, sW, 0, pH, pW, + isNCHW, algorithm::pooling_max); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(maxpool2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - // int extraParam0 = INT_ARG(9); - int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - - if (paddingMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, algorithm::pooling_max); - - return Status::OK(); + auto input = + INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + // int extraParam0 = INT_ARG(9); + int isNCHW = + block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE( + input->rankOf() == 4, 0, + "MAXPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", + input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "MAXPOOL2D_BP MKLDNN op: dilation must not be zero, but got " + "instead {%i, %i}", + dH, dW); + + int bS, iC, iH, iW, oC, oH, + oW; // batch size, input channels, input height/width, output channels, + // output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, + indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "MAXPOOL2D_BP MKLDNN op: wrong shape of output's gradients " + "array (next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, + dW); + + mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0, kH, kW, 0, sH, sW, 0, pH, + pW, isNCHW, algorithm::pooling_max); + + return Status::OK(); } PLATFORM_CHECK(maxpool2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } -} -} -} \ No newline at end of file +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp index 59f6d19499d8..c1a0b35af33e 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp @@ -19,115 +19,150 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // -#include +#include #include +#include +#include #include -#include #include "mkldnnUtils.h" -#include using namespace dnnl; -namespace sd { -namespace ops { +namespace sd { +namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - - REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - if(paddingMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - mkldnnUtils::poolingMKLDNN(input, output, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, algorithm::pooling_max); - - return Status::OK(); - + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); // + // unnecessary for max case, required only for avg and pnorm cases + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "MAXPOOL3DNEW MKLDNN OP: rank of input array must be equal to " + "5, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW MKLDNN op: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + mkldnnUtils::poolingMKLDNN(input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, + isNCDHW, algorithm::pooling_max); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const int kD = INT_ARG(0); // filter(kernel) depth - const int kH = INT_ARG(1); // filter(kernel) height - const int kW = INT_ARG(2); // filter(kernel) width - const int sD = INT_ARG(3); // strides depth - const int sH = INT_ARG(4); // strides height - const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - const int dD = INT_ARG(9); // dilations depth - const int dH = INT_ARG(10); // dilations height - const int dW = INT_ARG(11); // dilations width - const int paddngMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - - REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - - if(paddngMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, algorithm::pooling_max); - - return Status::OK(); + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, + // oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, + // iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int paddngMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); // + // unnecessary for max case, required only for avg and pnorm cases + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "MAXPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but " + "got %i instead", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got " + "instead {%i, %i, %i}", + dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, + oW; // batch size, input channels, input depth/height/width, output + // channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d( + isNCDHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, + indIOioD, indIOioD + 1, + indIOioD + 2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, + "MAXPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients " + "array (next epsilon), expected is %s, but got %s instead !", + ShapeUtils::shapeAsString(expectedGradOShape).c_str(), + ShapeUtils::shapeAsString(gradO).c_str()); + + if (paddngMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, + kW, sD, sH, sW, dD, dH, dW); + + mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, + pH, pW, isNCDHW, algorithm::pooling_max); + + return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); + return block.isUseMKLDNN() && sd::MKLDNNStream::isSupported({input, output}); } -} -} -} \ No newline at end of file +} // namespace platforms +} // namespace ops +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp index bc79e6169c17..e31c425a88c6 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp @@ -19,369 +19,441 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // +#include "mkldnnUtils.h" + #include #include -#include "mkldnnUtils.h" using namespace dnnl; -namespace sd { +namespace sd { namespace mkldnnUtils { ////////////////////////////////////////////////////////////////////// -void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims){ - - std::vector vDims(rank); - for (auto i = 0; i < rank; i++) { - vDims[i] = array->sizeAt(i); - } - mklDims = dnnl::memory::dims(vDims); +void getDims(const NDArray* array, const int rank, + dnnl::memory::dims& mklDims) { + std::vector vDims(rank); + for (auto i = 0; i < rank; i++) { + vDims[i] = array->sizeAt(i); + } + mklDims = dnnl::memory::dims(vDims); } ////////////////////////////////////////////////////////////////////// -dnnl::memory::format_tag getFormat(const int rank){ - if (2 == rank) { - return dnnl::memory::format_tag::ab; - } - else if (3 == rank) { - return dnnl::memory::format_tag::abc; - } - else if (4 == rank) { - return dnnl::memory::format_tag::abcd; - } - else if (5 == rank) { - return dnnl::memory::format_tag::abcde; - } - else if (6 == rank) { - return dnnl::memory::format_tag::abcdef; - } - return dnnl::memory::format_tag::a; // 1 == dataSetRank +dnnl::memory::format_tag getFormat(const int rank) { + if (2 == rank) { + return dnnl::memory::format_tag::ab; + } else if (3 == rank) { + return dnnl::memory::format_tag::abc; + } else if (4 == rank) { + return dnnl::memory::format_tag::abcd; + } else if (5 == rank) { + return dnnl::memory::format_tag::abcde; + } else if (6 == rank) { + return dnnl::memory::format_tag::abcdef; + } + return dnnl::memory::format_tag::a; // 1 == dataSetRank } ////////////////////////////////////////////////////////////////////// -void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd){ - - if (array->ews() != 1 || array->ordering() != 'c') { - mklMd.data.format_kind = dnnl_blocked; // overrides format - for (auto i = 0; i < array->rankOf(); ++i) { - mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i); - } +void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd) { + if (array->ews() != 1 || array->ordering() != 'c') { + mklMd.data.format_kind = dnnl_blocked; // overrides format + for (auto i = 0; i < array->rankOf(); ++i) { + mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i); } + } } //////////////////////////////////////////////////////////////////////////////////////////////// -void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md, +void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, + const dnnl::stream& stream, + const dnnl::memory::desc& user_md, + const dnnl::memory::desc& primitive_md, dnnl::memory& arg) { - - auto user_mem = dnnl::memory(user_md, engine,const_cast(array->buffer())); - const bool bReorder = primitive_md != user_mem.get_desc(); - auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem; - if (bReorder) - dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem); - arg = mkl_mem; + auto user_mem = + dnnl::memory(user_md, engine, const_cast(array->buffer())); + const bool bReorder = primitive_md != user_mem.get_desc(); + auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem; + if (bReorder) + dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem); + arg = mkl_mem; } ////////////////////////////////////////////////////////////////////// -void poolingMKLDNN(const NDArray *input, NDArray *output, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int isNCHW, const dnnl::algorithm mode) { - - // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for input - const int rank = input->rankOf(); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; - dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims; - dnnl::memory::format_tag xzFrmat; - - const auto type = dnnl::memory::data_type::f32; - - if(rank == 4) { // 2d - - ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - strides = { sH, sW }; - kernel = { kH, kW }; - padding = { pH, pW }; - padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; - xDims = {bS, iC, iH, iW}; - zDims = {bS, oC, oH, oW}; - - xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - } - else { // 3d - - ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); - - strides = { sD, sH, sW }; - kernel = { kD, kH, kW }; - padding = { pD, pH, pW }; - padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; - xDims = {bS, iC, iD, iH, iW}; - zDims = {bS, oC, oD, oH, oW}; - - xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - } - - // memory descriptors for arrays - - // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - if(input->ews() != 1 || input->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 :-1); - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2); - if(rank == 5) - x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3); - } - - // output - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); - if(output->ews() != 1 || output->ordering() != 'c') { - z_user_md.data.format_kind = dnnl_blocked; // overrides format - z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); - z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(isNCHW ? 1 :-1); - z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(isNCHW ? 2 : 1); - z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(isNCHW ? 3 : 2); - if(rank == 5) - z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(isNCHW ? 4 : 3); - } - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // operation primitive description - dnnl::pooling_forward::desc op_desc(dnnl::prop_kind::forward_inference, mode, x_mkl_md, z_mkl_md, strides, kernel, padding, padding_r); - dnnl::pooling_forward::primitive_desc op_prim_desc(op_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // output - auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; - - // run calculations - dnnl::pooling_forward(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); - - stream.wait(); +void poolingMKLDNN(const NDArray* input, NDArray* output, const int kD, + const int kH, const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, const int pW, + const int isNCHW, const dnnl::algorithm mode) { + // unfortunately mkl dnn doesn't support any format + // (dnnl::memory::format_tag::any) for input + const int rank = input->rankOf(); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, + indWkH, indOoH; + dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims; + dnnl::memory::format_tag xzFrmat; + + const auto type = dnnl::memory::data_type::f32; + + if (rank == 4) { // 2d + + ops::ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + strides = {sH, sW}; + kernel = {kH, kW}; + padding = {pH, pW}; + padding_r = {(oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; + xDims = {bS, iC, iH, iW}; + zDims = {bS, oC, oH, oW}; + + xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw + : dnnl::memory::format_tag::nhwc; + } else { // 3d + + ops::ConvolutionUtils::getSizesAndIndexesConv3d( + isNCHW, 0, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIiH, indWiC, indWoC, indWkH); + + strides = {sD, sH, sW}; + kernel = {kD, kH, kW}; + padding = {pD, pH, pW}; + padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW}; + xDims = {bS, iC, iD, iH, iW}; + zDims = {bS, oC, oD, oH, oW}; + + xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + } + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + if (input->ews() != 1 || input->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = + input->strideAt(isNCHW ? 1 : -1); + x_user_md.data.format_desc.blocking.strides[2] = + input->strideAt(isNCHW ? 2 : 1); + x_user_md.data.format_desc.blocking.strides[3] = + input->strideAt(isNCHW ? 3 : 2); + if (rank == 5) + x_user_md.data.format_desc.blocking.strides[4] = + input->strideAt(isNCHW ? 4 : 3); + } + + // output + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); + if (output->ews() != 1 || output->ordering() != 'c') { + z_user_md.data.format_kind = dnnl_blocked; // overrides format + z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); + z_user_md.data.format_desc.blocking.strides[1] = + output->strideAt(isNCHW ? 1 : -1); + z_user_md.data.format_desc.blocking.strides[2] = + output->strideAt(isNCHW ? 2 : 1); + z_user_md.data.format_desc.blocking.strides[3] = + output->strideAt(isNCHW ? 3 : 2); + if (rank == 5) + z_user_md.data.format_desc.blocking.strides[4] = + output->strideAt(isNCHW ? 4 : 3); + } + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::pooling_forward::desc op_desc(dnnl::prop_kind::forward_inference, mode, + x_mkl_md, z_mkl_md, strides, kernel, + padding, padding_r); + dnnl::pooling_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // output + auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = + zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; + + // run calculations + dnnl::pooling_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); } ////////////////////////////////////////////////////////////////////// -void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int isNCHW, const dnnl::algorithm mode) { - - // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for input - - const int rank = input->rankOf(); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; - dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims; - dnnl::memory::format_tag xzFrmat; - - const auto type = dnnl::memory::data_type::f32; - - if(rank == 4) { // 2d - - ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - - strides = { sH, sW }; - kernel = { kH, kW }; - padding = { pH, pW }; - padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; - xDims = {bS, iC, iH, iW}; - zDims = {bS, oC, oH, oW}; - - xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - } - else { // 3d - - ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); - - strides = { sD, sH, sW }; - kernel = { kD, kH, kW }; - padding = { pD, pH, pW }; - padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; - xDims = {bS, iC, iD, iH, iW}; - zDims = {bS, oC, oD, oH, oW}; - - xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - } - - // memory descriptors for arrays - +void poolingBpMKLDNN(const NDArray* input, const NDArray* gradO, NDArray* gradI, + const int kD, const int kH, const int kW, const int sD, + const int sH, const int sW, const int pD, const int pH, + const int pW, const int isNCHW, + const dnnl::algorithm mode) { + // unfortunately mkl dnn doesn't support any format + // (dnnl::memory::format_tag::any) for input + + const int rank = input->rankOf(); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, + indWkH, indOoH; + dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims; + dnnl::memory::format_tag xzFrmat; + + const auto type = dnnl::memory::data_type::f32; + + if (rank == 4) { // 2d + + ops::ConvolutionUtils::getSizesAndIndexesConv2d( + isNCHW, 0, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, + indWiC, indWoC, indWkH, indOoH); + + strides = {sH, sW}; + kernel = {kH, kW}; + padding = {pH, pW}; + padding_r = {(oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; + xDims = {bS, iC, iH, iW}; + zDims = {bS, oC, oH, oW}; + + xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw + : dnnl::memory::format_tag::nhwc; + } else { // 3d + + ops::ConvolutionUtils::getSizesAndIndexesConv3d( + isNCHW, 0, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH); + + strides = {sD, sH, sW}; + kernel = {kD, kH, kW}; + padding = {pD, pH, pW}; + padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW}; + xDims = {bS, iC, iD, iH, iW}; + zDims = {bS, oC, oD, oH, oW}; + + xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw + : dnnl::memory::format_tag::ndhwc; + } + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + if (input->ews() != 1 || input->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = + input->strideAt(isNCHW ? 1 : -1); + x_user_md.data.format_desc.blocking.strides[2] = + input->strideAt(isNCHW ? 2 : 1); + x_user_md.data.format_desc.blocking.strides[3] = + input->strideAt(isNCHW ? 3 : 2); + if (rank == 5) + x_user_md.data.format_desc.blocking.strides[4] = + input->strideAt(isNCHW ? 4 : 3); + } + + // gradO + dnnl::memory::desc gradO_mkl_md = + dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); + if (gradO->ews() != 1 || gradO->ordering() != 'c') { + gradO_user_md.data.format_kind = dnnl_blocked; // overrides format + gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); + gradO_user_md.data.format_desc.blocking.strides[1] = + gradO->strideAt(isNCHW ? 1 : -1); + gradO_user_md.data.format_desc.blocking.strides[2] = + gradO->strideAt(isNCHW ? 2 : 1); + gradO_user_md.data.format_desc.blocking.strides[3] = + gradO->strideAt(isNCHW ? 3 : 2); + if (rank == 5) + gradO_user_md.data.format_desc.blocking.strides[4] = + gradO->strideAt(isNCHW ? 4 : 3); + } + + // gradI + dnnl::memory::desc gradI_mkl_md = + dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + if (gradI->ews() != 1 || gradI->ordering() != 'c') { + gradI_user_md.data.format_kind = dnnl_blocked; // overrides format + gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); + gradI_user_md.data.format_desc.blocking.strides[1] = + gradI->strideAt(isNCHW ? 1 : -1); + gradI_user_md.data.format_desc.blocking.strides[2] = + gradI->strideAt(isNCHW ? 2 : 1); + gradI_user_md.data.format_desc.blocking.strides[3] = + gradI->strideAt(isNCHW ? 3 : 2); + if (rank == 5) + gradI_user_md.data.format_desc.blocking.strides[4] = + gradI->strideAt(isNCHW ? 4 : 3); + } + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + dnnl::stream stream(engine); + + // forward primitive description + dnnl::pooling_forward::desc op_ff_desc(dnnl::prop_kind::forward, mode, + x_mkl_md, gradO_mkl_md, strides, + kernel, padding, padding_r); + dnnl::pooling_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward primitive description + dnnl::pooling_backward::desc op_bp_desc(mode, gradI_mkl_md, gradO_mkl_md, + strides, kernel, padding, padding_r); + dnnl::pooling_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, + op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + // gradO + mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, + op_bp_prim_desc.diff_dst_desc(), + args[DNNL_ARG_DIFF_DST]); + + // gradI + auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); + const bool gradIReorder = + op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); + auto gradI_mkl_mem = + gradIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) + : gradI_user_mem; + args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + + if (mode == algorithm::pooling_max) { // input - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - if(input->ews() != 1 || input->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 :-1); - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2); - if(rank == 5) - x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3); - } - - // gradO - dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); - if(gradO->ews() != 1 || gradO->ordering() != 'c') { - gradO_user_md.data.format_kind = dnnl_blocked; // overrides format - gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); - gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(isNCHW ? 1 :-1); - gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(isNCHW ? 2 : 1); - gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(isNCHW ? 3 : 2); - if(rank == 5) - gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(isNCHW ? 4 : 3); - } - - // gradI - dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - if(gradI->ews() != 1 || gradI->ordering() != 'c') { - gradI_user_md.data.format_kind = dnnl_blocked; // overrides format - gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); - gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(isNCHW ? 1 :-1); - gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(isNCHW ? 2 : 1); - gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(isNCHW ? 3 : 2); - if(rank == 5) - gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(isNCHW ? 4 : 3); - } - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - - // forward primitive description - dnnl::pooling_forward::desc op_ff_desc(dnnl::prop_kind::forward, mode, x_mkl_md, gradO_mkl_md, strides, kernel, padding, padding_r); - dnnl::pooling_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backward primitive description - dnnl::pooling_backward::desc op_bp_desc(mode, gradI_mkl_md, gradO_mkl_md, strides, kernel, padding, padding_r); - dnnl::pooling_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - // gradO - mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); - - // gradI - auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); - const bool gradIReorder = op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; - - if(mode == algorithm::pooling_max) { + mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, + op_ff_prim_desc.src_desc(), + args[DNNL_ARG_SRC]); - // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // z - auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine); - args[DNNL_ARG_DST] = z_mkl_mem; - - // auxiliary memory allocation - auto workspace = dnnl::memory(op_ff_prim_desc.workspace_desc(), engine); - args[DNNL_ARG_WORKSPACE] = workspace; + // z + auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine); + args[DNNL_ARG_DST] = z_mkl_mem; - // run forward calculations - dnnl::pooling_forward(op_ff_prim_desc).execute(stream, args); - } + // auxiliary memory allocation + auto workspace = dnnl::memory(op_ff_prim_desc.workspace_desc(), engine); + args[DNNL_ARG_WORKSPACE] = workspace; - // run backward calculations - dnnl::pooling_backward(op_bp_prim_desc).execute(stream, args); + // run forward calculations + dnnl::pooling_forward(op_ff_prim_desc).execute(stream, args); + } + // run backward calculations + dnnl::pooling_backward(op_bp_prim_desc).execute(stream, args); - // reorder gradI if necessary - if (gradIReorder) - dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); + // reorder gradI if necessary + if (gradIReorder) + dnnl::reorder(gradI_mkl_mem, gradI_user_mem) + .execute(stream, gradI_mkl_mem, gradI_user_mem); - stream.wait(); + stream.wait(); } ////////////////////////////////////////////////////////////////////////// -void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { - const Nd4jLong* shape = src->shapeInfo(); - long rank = shape[0]; - long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one - long dim2 = axis >= 2 ? 1 : 2; - long dim3 = axis >= 3 ? 2 : 3; - dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; - - auto type = dnnl::memory::data_type::f32; - auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - auto supposed_to_be_any_format = format; // doesn't work with "any" - - if (src != nullptr && src->buffer() != nullptr && lrn_src_md != nullptr) { - *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); - *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; - user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; - user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; - } - - if (diff_src != nullptr && diff_src->buffer() != nullptr && lrn_diff_src_md != nullptr) { - *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; - user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; - user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; - } - - if (dst != nullptr && dst->buffer() != nullptr && lrn_dst_md != nullptr) { - *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); - *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; - user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; - user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; - } +void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, + const NDArray* dst, dnnl::memory::desc* lrn_src_md, + dnnl::memory::desc* lrn_diff_src_md, + dnnl::memory::desc* lrn_dst_md, + dnnl::memory::desc* user_src_md, + dnnl::memory::desc* user_diff_src_md, + dnnl::memory::desc* user_dst_md, int axis) { + const Nd4jLong* shape = src->shapeInfo(); + long rank = shape[0]; + long dim1 = + axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one + long dim2 = axis >= 2 ? 1 : 2; + long dim3 = axis >= 3 ? 2 : 3; + dnnl::memory::dims lrn_src_tz = {(int)shape[1], (int)shape[dim1 + 1], + rank > 2 ? (int)shape[dim2 + 1] : 1, + rank > 3 ? (int)shape[dim3 + 1] : 1}; + + auto type = dnnl::memory::data_type::f32; + auto format = axis == 1 ? dnnl::memory::format_tag::nchw + : dnnl::memory::format_tag::nhwc; + auto supposed_to_be_any_format = format; // doesn't work with "any" + + if (src != nullptr && src->buffer() != nullptr && lrn_src_md != nullptr) { + *lrn_src_md = + dnnl::memory::desc({lrn_src_tz}, type, supposed_to_be_any_format); + *user_src_md = dnnl::memory::desc({lrn_src_tz}, type, format); + user_src_md->data.format_kind = dnnl_blocked; + user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; + user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; + user_src_md->data.format_desc.blocking.strides[2] = + rank > 2 ? src->stridesOf()[dim2] : 1; + user_src_md->data.format_desc.blocking.strides[3] = + rank > 3 ? src->stridesOf()[dim3] : 1; + } + + if (diff_src != nullptr && diff_src->buffer() != nullptr && + lrn_diff_src_md != nullptr) { + *lrn_diff_src_md = + dnnl::memory::desc({lrn_src_tz}, type, supposed_to_be_any_format); + *user_diff_src_md = dnnl::memory::desc({lrn_src_tz}, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; + user_diff_src_md->data.format_desc.blocking.strides[0] = + diff_src->stridesOf()[0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = + diff_src->stridesOf()[dim1]; + user_diff_src_md->data.format_desc.blocking.strides[2] = + rank > 2 ? diff_src->stridesOf()[dim2] : 1; + user_diff_src_md->data.format_desc.blocking.strides[3] = + rank > 3 ? diff_src->stridesOf()[dim3] : 1; + } + + if (dst != nullptr && dst->buffer() != nullptr && lrn_dst_md != nullptr) { + *lrn_dst_md = + dnnl::memory::desc({lrn_src_tz}, type, supposed_to_be_any_format); + *user_dst_md = dnnl::memory::desc({lrn_src_tz}, type, format); + user_dst_md->data.format_kind = dnnl_blocked; + user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; + user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; + user_dst_md->data.format_desc.blocking.strides[2] = + rank > 2 ? dst->stridesOf()[dim2] : 1; + user_dst_md->data.format_desc.blocking.strides[3] = + rank > 3 ? dst->stridesOf()[dim3] : 1; + } } ////////////////////////////////////////////////////////////////////////// -dnnl::engine& getEngine(void *ptr) { - auto eng = reinterpret_cast(ptr); - return *eng; +dnnl::engine& getEngine(void* ptr) { + auto eng = reinterpret_cast(ptr); + return *eng; } - /* ////////////////////////////////////////////////////////////////////////// void getMKLDNNMemoryDescPool2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, - int bS, int iC, int iH, int iW, int oC, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, - dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { - dnnl::memory::dims pool_src_tz = { bS, iC, iH, iW }; - dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW }; + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int +poolingMode, int extraParam0, bool isNCHW, int bS, int iC, int iH, int iW, int +oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, const NDArray* +dst, dnnl::algorithm& algorithm, dnnl::memory::desc* pool_src_md, +dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, +dnnl::memory::desc* user_dst_md, dnnl::memory::dims& pool_strides, +dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, +dnnl::memory::dims& pool_padding_r) { dnnl::memory::dims pool_src_tz = { bS, iC, +iH, iW }; dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW }; pool_strides = { sH, sW }; pool_kernel = { kH, kW }; @@ -390,51 +462,70 @@ void getMKLDNNMemoryDescPool2d( (oW - 1) * sW - iW + kW - pW }; algorithm = poolingMode == 0 ? algorithm::pooling_max - : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding - : algorithm::pooling_avg_include_padding; + : extraParam0 == 0 ? +algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; auto type = dnnl::memory::data_type::f32; - auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" + auto format = isNCHW ? dnnl::memory::format_tag::nchw : +dnnl::memory::format_tag::nhwc; auto supposed_to_be_any_format = +dnnl::memory::format_tag::nChw8c; // doesn't work with "any" if (src != nullptr && src->buffer() != nullptr && pool_src_md != nullptr) { - *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2]; - } - - if (diff_src != nullptr && diff_src->buffer() != nullptr && pool_diff_src_md != nullptr) { - *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2]; + *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, +supposed_to_be_any_format); *user_src_md = dnnl::memory::desc({ pool_src_tz }, +type, format); user_src_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCHW ? nchw : nhwc" + user_src_md->data.format_desc.blocking.strides[0] = +src->stridesOf()[isNCHW ? 0 : 0]; + user_src_md->data.format_desc.blocking.strides[1] = +src->stridesOf()[isNCHW ? 1 : 3]; + user_src_md->data.format_desc.blocking.strides[2] = +src->stridesOf()[isNCHW ? 2 : 1]; + user_src_md->data.format_desc.blocking.strides[3] = +src->stridesOf()[isNCHW ? 3 : 2]; + } + + if (diff_src != nullptr && diff_src->buffer() != nullptr && pool_diff_src_md +!= nullptr) { *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, +supposed_to_be_any_format); *user_diff_src_md = dnnl::memory::desc({ pool_src_tz +}, type, format); user_diff_src_md->data.format_kind = dnnl_blocked; // +overrides "format = isNCHW ? nchw : nhwc" + user_diff_src_md->data.format_desc.blocking.strides[0] = +diff_src->stridesOf()[isNCHW ? 0 : 0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = +diff_src->stridesOf()[isNCHW ? 1 : 3]; + user_diff_src_md->data.format_desc.blocking.strides[2] = +diff_src->stridesOf()[isNCHW ? 2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = +diff_src->stridesOf()[isNCHW ? 3 : 2]; } if (dst != nullptr && dst->buffer() != nullptr && pool_dst_md != nullptr) { - *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); - *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2]; + *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, +supposed_to_be_any_format); *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, +type, format); user_dst_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCHW ? nchw : nhwc" + user_dst_md->data.format_desc.blocking.strides[0] = +dst->stridesOf()[isNCHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = +dst->stridesOf()[isNCHW ? 1 : 3]; + user_dst_md->data.format_desc.blocking.strides[2] = +dst->stridesOf()[isNCHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = +dst->stridesOf()[isNCHW ? 3 : 2]; } }; ////////////////////////////////////////////////////////////////////////// void getMKLDNNMemoryDescPool3d( - int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, - int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, - dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { + int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, +int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, int bS, +int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* +src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, + dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, +dnnl::memory::desc* pool_dst_md, dnnl::memory::desc* user_src_md, +dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, +dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW }; dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW }; @@ -446,258 +537,346 @@ void getMKLDNNMemoryDescPool3d( (oW - 1) * sW - iW + kW - pW }; algorithm = poolingMode == 0 ? algorithm::pooling_max - : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding - : algorithm::pooling_avg_include_padding; + : extraParam0 == 0 ? +algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; auto type = dnnl::memory::data_type::f32; - auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any" + auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : +dnnl::memory::format_tag::ndhwc; auto supposed_to_be_any_format = +dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any" if (src != nullptr && src->buffer() != nullptr && pool_src_md != nullptr) { - *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2]; - user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3]; - } - - if (diff_src != nullptr && diff_src->buffer() != nullptr && pool_diff_src_md != nullptr) { - *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2]; - user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3]; + *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, +supposed_to_be_any_format); *user_src_md = dnnl::memory::desc({ pool_src_tz }, +type, format); user_src_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCDHW ? ncdhw : ndhwc" + user_src_md->data.format_desc.blocking.strides[0] = +src->stridesOf()[isNCDHW ? 0 : 0]; + user_src_md->data.format_desc.blocking.strides[1] = +src->stridesOf()[isNCDHW ? 1 : 4]; + user_src_md->data.format_desc.blocking.strides[2] = +src->stridesOf()[isNCDHW ? 2 : 1]; + user_src_md->data.format_desc.blocking.strides[3] = +src->stridesOf()[isNCDHW ? 3 : 2]; + user_src_md->data.format_desc.blocking.strides[4] = +src->stridesOf()[isNCDHW ? 4 : 3]; + } + + if (diff_src != nullptr && diff_src->buffer() != nullptr && pool_diff_src_md +!= nullptr) { *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, +supposed_to_be_any_format); *user_diff_src_md = dnnl::memory::desc({ pool_src_tz +}, type, format); user_diff_src_md->data.format_kind = dnnl_blocked; // +overrides "format = isNCDHW ? ncdhw : ndhwc" + user_diff_src_md->data.format_desc.blocking.strides[0] = +diff_src->stridesOf()[isNCDHW ? 0 : 0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = +diff_src->stridesOf()[isNCDHW ? 1 : 4]; + user_diff_src_md->data.format_desc.blocking.strides[2] = +diff_src->stridesOf()[isNCDHW ? 2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = +diff_src->stridesOf()[isNCDHW ? 3 : 2]; + user_diff_src_md->data.format_desc.blocking.strides[4] = +diff_src->stridesOf()[isNCDHW ? 4 : 3]; } if (dst != nullptr && dst->buffer() != nullptr && pool_dst_md != nullptr) { - *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); - *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2]; - user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3]; + *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, +supposed_to_be_any_format); *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, +type, format); user_dst_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCDHW ? ncdhw : ndhwc" + user_dst_md->data.format_desc.blocking.strides[0] = +dst->stridesOf()[isNCDHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = +dst->stridesOf()[isNCDHW ? 1 : 4]; + user_dst_md->data.format_desc.blocking.strides[2] = +dst->stridesOf()[isNCDHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = +dst->stridesOf()[isNCDHW ? 3 : 2]; + user_dst_md->data.format_desc.blocking.strides[4] = +dst->stridesOf()[isNCDHW ? 4 : 3]; } }; ////////////////////////////////////////////////////////////////////////// void getMKLDNNMemoryDescConv2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW, - int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, - const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, - dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, - dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const +int paddingMode, bool isNCHW, int bS, int iC, int iH, int iW, int oC, int oH, +int oW, const NDArray* src, const NDArray* diff_src, const NDArray* weights, +const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, + dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, +dnnl::memory::desc* conv_weights_md, dnnl::memory::desc* conv_diff_weights_md, +dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, +dnnl::memory::desc* user_weights_md, dnnl::memory::desc* user_diff_weights_md, +dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, +dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { dnnl::memory::dims conv_src_tz = { bS, iC, iH, iW }; dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW }; dnnl::memory::dims conv_bias_tz = { oC }; dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW }; - const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) +* dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d conv_strides = { sH, sW }; conv_padding = { pH, pW }; - conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - conv_dilation = { dH-1, dW-1}; + conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - +pWSame }; conv_dilation = { dH-1, dW-1}; auto type = dnnl::memory::data_type::f32; - auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - auto formatw = dnnl::memory::format_tag::hwio; + auto format = isNCHW ? dnnl::memory::format_tag::nchw : +dnnl::memory::format_tag::nhwc; auto formatw = dnnl::memory::format_tag::hwio; if (src != nullptr && conv_src_md != nullptr) { - *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2]; + *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, +dnnl::memory::format_tag::any); *user_src_md = dnnl::memory::desc({ conv_src_tz +}, type, format); user_src_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCHW ? nchw : nhwc" + user_src_md->data.format_desc.blocking.strides[0] = +src->stridesOf()[isNCHW ? 0 : 0]; + user_src_md->data.format_desc.blocking.strides[1] = +src->stridesOf()[isNCHW ? 1 : 3]; + user_src_md->data.format_desc.blocking.strides[2] = +src->stridesOf()[isNCHW ? 2 : 1]; + user_src_md->data.format_desc.blocking.strides[3] = +src->stridesOf()[isNCHW ? 3 : 2]; } if (diff_src != nullptr && conv_diff_src_md != nullptr) { - *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2]; + *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, +dnnl::memory::format_tag::any); *user_diff_src_md = dnnl::memory::desc({ +conv_src_tz }, type, format); user_diff_src_md->data.format_kind = dnnl_blocked; +// overrides "format = isNCHW ? nchw : nhwc" + user_diff_src_md->data.format_desc.blocking.strides[0] = +diff_src->stridesOf()[isNCHW ? 0 : 0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = +diff_src->stridesOf()[isNCHW ? 1 : 3]; + user_diff_src_md->data.format_desc.blocking.strides[2] = +diff_src->stridesOf()[isNCHW ? 2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = +diff_src->stridesOf()[isNCHW ? 3 : 2]; } if (weights != nullptr && conv_weights_md != nullptr) { - *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio" - user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3]; - user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2]; - user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; - user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1]; + *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, +dnnl::memory::format_tag::any); *user_weights_md = dnnl::memory::desc({ +conv_weights_tz }, type, formatw); user_weights_md->data.format_kind = +dnnl_blocked; // overrides "formatw = hwio" + user_weights_md->data.format_desc.blocking.strides[0] = +weights->stridesOf()[3]; user_weights_md->data.format_desc.blocking.strides[1] = +weights->stridesOf()[2]; user_weights_md->data.format_desc.blocking.strides[2] = +weights->stridesOf()[0]; user_weights_md->data.format_desc.blocking.strides[3] = +weights->stridesOf()[1]; } if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { - *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio" - user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3]; - user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2]; - user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; - user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1]; + *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, +dnnl::memory::format_tag::any); *user_diff_weights_md = dnnl::memory::desc({ +conv_weights_tz }, type, formatw); user_diff_weights_md->data.format_kind = +dnnl_blocked; // overrides "formatw = hwio" + user_diff_weights_md->data.format_desc.blocking.strides[0] = +diff_weights->stridesOf()[3]; + user_diff_weights_md->data.format_desc.blocking.strides[1] = +diff_weights->stridesOf()[2]; + user_diff_weights_md->data.format_desc.blocking.strides[2] = +diff_weights->stridesOf()[0]; + user_diff_weights_md->data.format_desc.blocking.strides[3] = +diff_weights->stridesOf()[1]; } if (bias != nullptr && conv_bias_md != nullptr) { - *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any); - *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x); + *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, +dnnl::memory::format_tag::any); *user_bias_md = dnnl::memory::desc({ +conv_bias_tz }, type, dnnl::memory::format_tag::x); } if (dst != nullptr && conv_dst_md != nullptr) { - *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any); - *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2]; + *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, +dnnl::memory::format_tag::any); *user_dst_md = dnnl::memory::desc({ conv_dst_tz +}, type, format); user_dst_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCHW ? nchw : nhwc" + user_dst_md->data.format_desc.blocking.strides[0] = +dst->stridesOf()[isNCHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = +dst->stridesOf()[isNCHW ? 1 : 3]; + user_dst_md->data.format_desc.blocking.strides[2] = +dst->stridesOf()[isNCHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = +dst->stridesOf()[isNCHW ? 3 : 2]; } } ////////////////////////////////////////////////////////////////////////// void getMKLDNNMemoryDescConv3d( - int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW, - int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src, - const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, - dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, - dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { - dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW }; - dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW }; + int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, +int dD, int dH, int dW, bool paddingMode, bool isNCDHW, int bS, int iC, int iD, +int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const +NDArray* diff_src, const NDArray* weights, const NDArray* diff_weights, const +NDArray* bias, const NDArray* dst, dnnl::memory::desc* conv_src_md, +dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, + dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* +conv_bias_md, dnnl::memory::desc* conv_dst_md, dnnl::memory::desc* user_src_md, +dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, + dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* +user_bias_md, dnnl::memory::desc* user_dst_md, dnnl::memory::dims& conv_strides, +dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, +dnnl::memory::dims& conv_dilation) { dnnl::memory::dims conv_src_tz = { bS, iC, +iD, iH, iW }; dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW }; dnnl::memory::dims conv_bias_tz = { oC }; dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW }; conv_strides = { sD, sH, sW }; conv_padding = { pD, pH, pW }; - conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; - conv_dilation = { dD-1, dH-1, dW-1}; + conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - +pH, (oW - 1) * sW - iW + kW - pW }; conv_dilation = { dD-1, dH-1, dW-1}; auto type = dnnl::memory::data_type::f32; - auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - auto formatw = dnnl::memory::format_tag::dhwio; + auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : +dnnl::memory::format_tag::ndhwc; auto formatw = dnnl::memory::format_tag::dhwio; if (src != nullptr && conv_src_md != nullptr) { - *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2]; - user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3]; + *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, +dnnl::memory::format_tag::any); *user_src_md = dnnl::memory::desc({ conv_src_tz +}, type, format); user_src_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCDHW ? ncdhw : ndhwc" + user_src_md->data.format_desc.blocking.strides[0] = +src->stridesOf()[isNCDHW ? 0 : 0]; + user_src_md->data.format_desc.blocking.strides[1] = +src->stridesOf()[isNCDHW ? 1 : 4]; + user_src_md->data.format_desc.blocking.strides[2] = +src->stridesOf()[isNCDHW ? 2 : 1]; + user_src_md->data.format_desc.blocking.strides[3] = +src->stridesOf()[isNCDHW ? 3 : 2]; + user_src_md->data.format_desc.blocking.strides[4] = +src->stridesOf()[isNCDHW ? 4 : 3]; } if (diff_src != nullptr && conv_diff_src_md != nullptr) { - *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2]; - user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3]; + *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, +dnnl::memory::format_tag::any); *user_diff_src_md = dnnl::memory::desc({ +conv_src_tz }, type, format); user_diff_src_md->data.format_kind = dnnl_blocked; +// overrides "format = isNCDHW ? ncdhw : ndhwc" + user_diff_src_md->data.format_desc.blocking.strides[0] = +diff_src->stridesOf()[isNCDHW ? 0 : 0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = +diff_src->stridesOf()[isNCDHW ? 1 : 4]; + user_diff_src_md->data.format_desc.blocking.strides[2] = +diff_src->stridesOf()[isNCDHW ? 2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = +diff_src->stridesOf()[isNCDHW ? 3 : 2]; + user_diff_src_md->data.format_desc.blocking.strides[4] = +diff_src->stridesOf()[isNCDHW ? 4 : 3]; } if (weights != nullptr && conv_weights_md != nullptr) { - *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio" - user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4]; - user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3]; - user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; - user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1]; - user_weights_md->data.format_desc.blocking.strides[4] = weights->stridesOf()[2]; + *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, +dnnl::memory::format_tag::any); *user_weights_md = dnnl::memory::desc({ +conv_weights_tz }, type, formatw); user_weights_md->data.format_kind = +dnnl_blocked; // overrides "formatw = dhwio" + user_weights_md->data.format_desc.blocking.strides[0] = +weights->stridesOf()[4]; user_weights_md->data.format_desc.blocking.strides[1] = +weights->stridesOf()[3]; user_weights_md->data.format_desc.blocking.strides[2] = +weights->stridesOf()[0]; user_weights_md->data.format_desc.blocking.strides[3] = +weights->stridesOf()[1]; user_weights_md->data.format_desc.blocking.strides[4] = +weights->stridesOf()[2]; } if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { - *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio" - user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4]; - user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3]; - user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; - user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1]; - user_diff_weights_md->data.format_desc.blocking.strides[4] = diff_weights->stridesOf()[2]; + *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, +dnnl::memory::format_tag::any); *user_diff_weights_md = dnnl::memory::desc({ +conv_weights_tz }, type, formatw); user_diff_weights_md->data.format_kind = +dnnl_blocked; // overrides "formatw = dhwio" + user_diff_weights_md->data.format_desc.blocking.strides[0] = +diff_weights->stridesOf()[4]; + user_diff_weights_md->data.format_desc.blocking.strides[1] = +diff_weights->stridesOf()[3]; + user_diff_weights_md->data.format_desc.blocking.strides[2] = +diff_weights->stridesOf()[0]; + user_diff_weights_md->data.format_desc.blocking.strides[3] = +diff_weights->stridesOf()[1]; + user_diff_weights_md->data.format_desc.blocking.strides[4] = +diff_weights->stridesOf()[2]; } if (bias != nullptr && conv_bias_md != nullptr) { - *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any); - *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x); + *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, +dnnl::memory::format_tag::any); *user_bias_md = dnnl::memory::desc({ +conv_bias_tz }, type, dnnl::memory::format_tag::x); } if (dst != nullptr && conv_dst_md != nullptr) { - *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any); - *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2]; - user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3]; + *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, +dnnl::memory::format_tag::any); *user_dst_md = dnnl::memory::desc({ conv_dst_tz +}, type, format); user_dst_md->data.format_kind = dnnl_blocked; // overrides +"format = isNCDHW ? ncdhw : ndhwc" + user_dst_md->data.format_desc.blocking.strides[0] = +dst->stridesOf()[isNCDHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = +dst->stridesOf()[isNCDHW ? 1 : 4]; + user_dst_md->data.format_desc.blocking.strides[2] = +dst->stridesOf()[isNCDHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = +dst->stridesOf()[isNCDHW ? 3 : 2]; + user_dst_md->data.format_desc.blocking.strides[4] = +dst->stridesOf()[isNCDHW ? 4 : 3]; } }; -void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { - const Nd4jLong* shape = src->shapeInfo(); - Nd4jLong rank = shape[0]; - Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one - Nd4jLong dim2 = axis >= 2 ? 1 : 2; - Nd4jLong dim3 = axis >= 3 ? 2 : 3; - dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; +void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, +const NDArray* dst, dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* +batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, dnnl::memory::desc* +user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* +user_dst_md, int axis) { const Nd4jLong* shape = src->shapeInfo(); Nd4jLong rank += shape[0]; Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to +be the "channel" one Nd4jLong dim2 = axis >= 2 ? 1 : 2; Nd4jLong dim3 = axis >= +3 ? 2 : 3; dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], +(int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? +(int)shape[dim3 + 1] : 1}; auto type = dnnl::memory::data_type::f32; auto format = dnnl::memory::format_tag::nchw; - auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" - - if (src != nullptr && src->buffer() != nullptr && batchnorm_src_md != nullptr) { - *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides format - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; - user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; - user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; - } - - if (diff_src != nullptr && diff_src->buffer() != nullptr && batchnorm_diff_src_md != nullptr) { - *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; - user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; - user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; - } - - if (dst != nullptr && dst->buffer() != nullptr && batchnorm_dst_md != nullptr) { - *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides format - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; - user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; - user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; + auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // +doesn't work with "any" + + if (src != nullptr && src->buffer() != nullptr && batchnorm_src_md != +nullptr) { *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, +supposed_to_be_any_format); *user_src_md = dnnl::memory::desc({ batchnorm_src_tz +}, type, format); user_src_md->data.format_kind = dnnl_blocked; // overrides +format user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; + user_src_md->data.format_desc.blocking.strides[1] = +src->stridesOf()[dim1]; user_src_md->data.format_desc.blocking.strides[2] = rank +> 2 ? src->stridesOf()[dim2] : 1; + user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? +src->stridesOf()[dim3] : 1; + } + + if (diff_src != nullptr && diff_src->buffer() != nullptr && +batchnorm_diff_src_md != nullptr) { *batchnorm_diff_src_md = +dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); + *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, +format); user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format + user_diff_src_md->data.format_desc.blocking.strides[0] = +diff_src->stridesOf()[0]; user_diff_src_md->data.format_desc.blocking.strides[1] += diff_src->stridesOf()[dim1]; + user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? +diff_src->stridesOf()[dim2] : 1; + user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? +diff_src->stridesOf()[dim3] : 1; + } + + if (dst != nullptr && dst->buffer() != nullptr && batchnorm_dst_md != +nullptr) { *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, +supposed_to_be_any_format); *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz +}, type, format); user_dst_md->data.format_kind = dnnl_blocked; // overrides +format user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; + user_dst_md->data.format_desc.blocking.strides[1] = +dst->stridesOf()[dim1]; user_dst_md->data.format_desc.blocking.strides[2] = rank +> 2 ? dst->stridesOf()[dim2] : 1; + user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? +dst->stridesOf()[dim3] : 1; } }; */ -} -} \ No newline at end of file +} // namespace mkldnnUtils +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index d83d15576fe2..525b7f6642fd 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -14,188 +14,219 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author saudet - // @author Yurii Shyrma (iuriish@yahoo.com) - // +// +// @author saudet +// @author Yurii Shyrma (iuriish@yahoo.com) +// #ifndef SD_MKLDNNUTILS_H #define SD_MKLDNNUTILS_H - -#include #include -#include -#include #include +#include +#include #include #include -using namespace samediff; +#include +using namespace samediff; namespace sd { - namespace ops { - namespace platforms { - /** - * Here we actually declare our platform helpers - */ - DECLARE_PLATFORM(conv2d, ENGINE_CPU); - - DECLARE_PLATFORM(conv2d_bp, ENGINE_CPU); - - DECLARE_PLATFORM(avgpool2d, ENGINE_CPU); - - DECLARE_PLATFORM(avgpool2d_bp, ENGINE_CPU); - - DECLARE_PLATFORM(maxpool2d, ENGINE_CPU); - - DECLARE_PLATFORM(maxpool2d_bp, ENGINE_CPU); - - DECLARE_PLATFORM(conv3dnew, ENGINE_CPU); - - DECLARE_PLATFORM(conv3dnew_bp, ENGINE_CPU); - - DECLARE_PLATFORM(maxpool3dnew, ENGINE_CPU); - - DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CPU); - - DECLARE_PLATFORM(avgpool3dnew, ENGINE_CPU); - - DECLARE_PLATFORM(avgpool3dnew_bp, ENGINE_CPU); - - DECLARE_PLATFORM(lrn, ENGINE_CPU); +namespace ops { +namespace platforms { +/** + * Here we actually declare our platform helpers + */ +DECLARE_PLATFORM(conv2d, ENGINE_CPU); - DECLARE_PLATFORM(batchnorm, ENGINE_CPU); +DECLARE_PLATFORM(conv2d_bp, ENGINE_CPU); - DECLARE_PLATFORM(batchnorm_bp, ENGINE_CPU); +DECLARE_PLATFORM(avgpool2d, ENGINE_CPU); - DECLARE_PLATFORM(lstmLayer, ENGINE_CPU); +DECLARE_PLATFORM(avgpool2d_bp, ENGINE_CPU); - DECLARE_PLATFORM(deconv2d, ENGINE_CPU); +DECLARE_PLATFORM(maxpool2d, ENGINE_CPU); - DECLARE_PLATFORM(deconv2d_tf, ENGINE_CPU); +DECLARE_PLATFORM(maxpool2d_bp, ENGINE_CPU); - DECLARE_PLATFORM(deconv3d, ENGINE_CPU); +DECLARE_PLATFORM(conv3dnew, ENGINE_CPU); - DECLARE_PLATFORM(deconv2d_bp, ENGINE_CPU); +DECLARE_PLATFORM(conv3dnew_bp, ENGINE_CPU); - DECLARE_PLATFORM(deconv3d_bp, ENGINE_CPU); +DECLARE_PLATFORM(maxpool3dnew, ENGINE_CPU); - DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CPU); +DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CPU); - DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU); +DECLARE_PLATFORM(avgpool3dnew, ENGINE_CPU); - DECLARE_PLATFORM(matmul, ENGINE_CPU); +DECLARE_PLATFORM(avgpool3dnew_bp, ENGINE_CPU); - DECLARE_PLATFORM(softmax, ENGINE_CPU); +DECLARE_PLATFORM(lrn, ENGINE_CPU); - DECLARE_PLATFORM(softmax_bp, ENGINE_CPU); +DECLARE_PLATFORM(batchnorm, ENGINE_CPU); - DECLARE_PLATFORM(tanh, ENGINE_CPU); +DECLARE_PLATFORM(batchnorm_bp, ENGINE_CPU); - DECLARE_PLATFORM(tanh_bp, ENGINE_CPU); +DECLARE_PLATFORM(lstmLayer, ENGINE_CPU); - DECLARE_PLATFORM(xw_plus_b, ENGINE_CPU); +DECLARE_PLATFORM(deconv2d, ENGINE_CPU); - DECLARE_PLATFORM(xw_plus_b_bp, ENGINE_CPU); +DECLARE_PLATFORM(deconv2d_tf, ENGINE_CPU); - } - } +DECLARE_PLATFORM(deconv3d, ENGINE_CPU); - namespace mkldnnUtils { +DECLARE_PLATFORM(deconv2d_bp, ENGINE_CPU); - void poolingMKLDNN(const NDArray* input, NDArray* output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode); +DECLARE_PLATFORM(deconv3d_bp, ENGINE_CPU); - void poolingBpMKLDNN(const NDArray* input, const NDArray* gradO, NDArray* gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode); +DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CPU); - void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis); +DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU); - dnnl::engine& getEngine(void* ptr); +DECLARE_PLATFORM(matmul, ENGINE_CPU); - /** - * This function creates memory dimentions - * @param const pointer to array - * @param const array rank - * @param reference to memory dimentions - */ - void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims); - /** - * This function generate memory format tag based on rank - * @param const array rank - * @return memory format - */ - dnnl::memory::format_tag getFormat(const int rank); - /** - * This function generate memory format tag based on rank - * @param const pointer to dataset - * @param const dataset rank - * @param reference to memory descriptor - * @return memory format - */ - void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd); - ////////////////////////////////////////////////////////////////////// - /** - * This function load and reorder user memory to mkl - * @param const pointer to dataset - * @param reference to mkl engine - * @param reference to mkl stream - * @param reference to args container for dnnl - * @param reference to user memory description - * @param primitive memory descriptor - * @param dnnl arg activation enumerator - */ - void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md, - dnnl::memory& arg); +DECLARE_PLATFORM(softmax, ENGINE_CPU); - /** - * Utility methods for MKLDNN - */ - /* void getMKLDNNMemoryDescConv2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW, - int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, - const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, - dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, - dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation); - - void getMKLDNNMemoryDescConv3d( - int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW, - int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src, - const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, - dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, - dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation); - - void getMKLDNNMemoryDescPool2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, - int bS, int iC, int iH, int iW, int oC, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, - dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r); - - void getMKLDNNMemoryDescPool3d( - int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, - int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, - dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r); - - void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis); - */ - } -} - - - -#endif //SD_MKLDNNUTILS_H +DECLARE_PLATFORM(softmax_bp, ENGINE_CPU); + +DECLARE_PLATFORM(tanh, ENGINE_CPU); + +DECLARE_PLATFORM(tanh_bp, ENGINE_CPU); + +DECLARE_PLATFORM(xw_plus_b, ENGINE_CPU); + +DECLARE_PLATFORM(xw_plus_b_bp, ENGINE_CPU); + +} // namespace platforms +} // namespace ops + +namespace mkldnnUtils { + +void poolingMKLDNN(const NDArray* input, NDArray* output, const int kD, + const int kH, const int kW, const int sD, const int sH, + const int sW, const int pD, const int pH, const int pW, + const int isNCHW, const dnnl::algorithm mode); + +void poolingBpMKLDNN(const NDArray* input, const NDArray* gradO, NDArray* gradI, + const int kD, const int kH, const int kW, const int sD, + const int sH, const int sW, const int pD, const int pH, + const int pW, const int isNCHW, + const dnnl::algorithm mode); + +void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, + const NDArray* dst, dnnl::memory::desc* lrn_src_md, + dnnl::memory::desc* lrn_diff_src_md, + dnnl::memory::desc* lrn_dst_md, + dnnl::memory::desc* user_src_md, + dnnl::memory::desc* user_diff_src_md, + dnnl::memory::desc* user_dst_md, int axis); + +dnnl::engine& getEngine(void* ptr); + +/** + * This function creates memory dimentions + * @param const pointer to array + * @param const array rank + * @param reference to memory dimentions + */ +void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims); +/** + * This function generate memory format tag based on rank + * @param const array rank + * @return memory format + */ +dnnl::memory::format_tag getFormat(const int rank); +/** + * This function generate memory format tag based on rank + * @param const pointer to dataset + * @param const dataset rank + * @param reference to memory descriptor + * @return memory format + */ +void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd); +////////////////////////////////////////////////////////////////////// +/** + * This function load and reorder user memory to mkl + * @param const pointer to dataset + * @param reference to mkl engine + * @param reference to mkl stream + * @param reference to args container for dnnl + * @param reference to user memory description + * @param primitive memory descriptor + * @param dnnl arg activation enumerator + */ +void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, + const dnnl::stream& stream, + const dnnl::memory::desc& user_md, + const dnnl::memory::desc& primitive_md, + dnnl::memory& arg); + +/** + * Utility methods for MKLDNN + */ +/* void getMKLDNNMemoryDescConv2d( + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, + const int paddingMode, bool isNCHW, int bS, int iC, int iH, int iW, int oC, + int oH, int oW, const NDArray* src, const NDArray* diff_src, const NDArray* + weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* + dst, dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, + dnnl::memory::desc* conv_weights_md, dnnl::memory::desc* + conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* + conv_dst_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* + user_diff_src_md, dnnl::memory::desc* user_weights_md, dnnl::memory::desc* + user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* + user_dst_md, dnnl::memory::dims& conv_strides, dnnl::memory::dims& + conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& + conv_dilation); + + void getMKLDNNMemoryDescConv3d( + int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, + int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW, int bS, int + iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* + src, const NDArray* diff_src, const NDArray* weights, const NDArray* + diff_weights, const NDArray* bias, const NDArray* dst, dnnl::memory::desc* + conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* + conv_weights_md, dnnl::memory::desc* conv_diff_weights_md, + dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* + user_diff_src_md, dnnl::memory::desc* user_weights_md, dnnl::memory::desc* + user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* + user_dst_md, dnnl::memory::dims& conv_strides, dnnl::memory::dims& + conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& + conv_dilation); + + void getMKLDNNMemoryDescPool2d( + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, + int poolingMode, int extraParam0, bool isNCHW, int bS, int iC, int iH, int + iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, + const NDArray* dst, dnnl::algorithm& algorithm, dnnl::memory::desc* + pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* + pool_dst_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* + user_diff_src_md, dnnl::memory::desc* user_dst_md, dnnl::memory::dims& + pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& + pool_padding, dnnl::memory::dims& pool_padding_r); + + void getMKLDNNMemoryDescPool3d( + int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, + int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool + isNCDHW, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int + oW, const NDArray* src, const NDArray* diff_src, const NDArray* dst, + dnnl::algorithm& algorithm, dnnl::memory::desc* pool_src_md, + dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* + user_diff_src_md, dnnl::memory::desc* user_dst_md, dnnl::memory::dims& + pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& + pool_padding, dnnl::memory::dims& pool_padding_r); + + void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* + diff_src, const NDArray* dst, dnnl::memory::desc* batchnorm_src_md, + dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* + batchnorm_dst_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* + user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis); +*/ +} // namespace mkldnnUtils +} // namespace sd + +#endif // SD_MKLDNNUTILS_H diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp index e69639580644..96756a07aae9 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp @@ -14,261 +14,292 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleg Semeniv - // - // +// +// @author Oleg Semeniv +// +// -#include +#include #include +#include #include -#include + #include "mkldnnUtils.h" using namespace dnnl; namespace sd { - namespace ops { - namespace platforms { - - - ////////////////////////////////////////////////////////////////////// - static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) { - - const auto xRank = x->rankOf(); - dnnl::memory::dims xShape, zShape; - - mkldnnUtils::getDims(x, xRank, xShape); - mkldnnUtils::getDims(z, xRank, zShape); - - - dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); - // optimized cases - if (2 == xRank && 0 == axis) { - format = dnnl::memory::format_tag::ba; - } - else if (4 == xRank && 1 == axis && (x->sizeAt(2) * x->sizeAt(3)) > 1) { - format = dnnl::memory::format_tag::acdb; - } - - dnnl::memory::data_type xType = dnnl::memory::data_type::f32; - - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); - mkldnnUtils::setBlockStrides(x, x_user_md); - - // z - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, xType, format); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format); - mkldnnUtils::setBlockStrides(z, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // Create attributes (to handle alpha and beta if necessary) - dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0) - - // operation primitive description - dnnl::softmax_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis); - - dnnl::softmax_forward::primitive_desc op_prim_desc(op_desc, attr, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // z - auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; +namespace ops { +namespace platforms { - // run calculations - dnnl::softmax_forward(op_prim_desc).execute(stream, args); +////////////////////////////////////////////////////////////////////// +static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) { + const auto xRank = x->rankOf(); + dnnl::memory::dims xShape, zShape; - // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + mkldnnUtils::getDims(x, xRank, xShape); + mkldnnUtils::getDims(z, xRank, zShape); - stream.wait(); - } + dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); + // optimized cases + if (2 == xRank && 0 == axis) { + format = dnnl::memory::format_tag::ba; + } else if (4 == xRank && 1 == axis && (x->sizeAt(2) * x->sizeAt(3)) > 1) { + format = dnnl::memory::format_tag::acdb; + } + dnnl::memory::data_type xType = dnnl::memory::data_type::f32; - PLATFORM_IMPL(softmax, ENGINE_CPU) { + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); + mkldnnUtils::setBlockStrides(x, x_user_md); - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + // z + dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, xType, format); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format); + mkldnnUtils::setBlockStrides(z, z_user_md); - const int rank = input->rankOf(); - int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - if (dim < 0) { - dim += rank; - } + // Create attributes (to handle alpha and beta if necessary) + dnnl::primitive_attr attr; // it is empty since we have usual values for + // alpha (=1) and beta (=0) - REQUIRE_TRUE(dim < rank && dim >= 0, 0, "SOFTMAX_MKLDNN OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); + // operation primitive description + dnnl::softmax_forward::desc op_desc(dnnl::prop_kind::forward_inference, + x_mkl_md, axis); - REQUIRE_TRUE(rank <= 6, 0, "SOFTMAX_MKLDNN OP: the rank of input must be less or qual 6, but got rank = %i instead !", rank); + dnnl::softmax_forward::primitive_desc op_prim_desc(op_desc, attr, engine); - // mkldnnSoftMax - softmaxMKLDNN(input, output, dim); + // arguments (memory buffers) necessary for calculations + std::unordered_map args; - return Status::OK(); - } + dnnl::stream stream(engine); - PLATFORM_CHECK(softmax, ENGINE_CPU) { + // provide memory buffers and check whether reorder is required - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + // input + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - const DataType xType = x->dataType(); - const DataType zType = z->dataType(); + // z + auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = + zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; - const int xRank = x->rankOf(); - bool bSupportedRanks = (xRank > 2 && xRank < 7); - /* - Source Destination - f32 f32 - */ - return !x->isEmpty() && block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32); + // run calculations + dnnl::softmax_forward(op_prim_desc).execute(stream, args); - } + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); - ////////////////////////////////////////////////////////////////////// - static void softmaxBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx, const int axis) { - - const auto xRank = x->rankOf(); - const auto dLdzRank = dLdz->rankOf(); - - dnnl::memory::dims xShape, dLdxShape, dLdzShape; - - mkldnnUtils::getDims(x, xRank, xShape); - mkldnnUtils::getDims(dLdx, xRank, dLdxShape); - mkldnnUtils::getDims(dLdz, dLdzRank, dLdzShape); - - dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); - - // x - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(x, x_user_md); - - // dLdx - dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md); - // todo if mkl does not support broadcast we can remove this - format = mkldnnUtils::getFormat(dLdzRank); - - // dLdz - dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // operation primitive description - // forward description - dnnl::softmax_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis); - dnnl::softmax_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backward description - dnnl::softmax_backward::desc op_bp_desc(dLdz_mkl_md, dLdx_mkl_md, axis); - dnnl::softmax_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map argsbp, argsff; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required for forward - // input - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_SRC]); - - // dLdx - auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer()); - const bool dLdxReorder = op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc(); - auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : dLdx_user_mem; - argsff[DNNL_ARG_DST] = dLdx_mkl_mem; - - // check and arg set for backprob - argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; - argsbp[DNNL_ARG_DST] = dLdx_mkl_mem; - // dLdz - mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), argsbp[DNNL_ARG_DIFF_DST]); + stream.wait(); +} - // run calculations forward - dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff); +PLATFORM_IMPL(softmax, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); - // run calculations backward - dnnl::softmax_backward(op_bp_prim_desc).execute(stream, argsbp); + const int rank = input->rankOf(); + int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; - // reorder outputs if necessary - if (dLdxReorder) - dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem); + if (dim < 0) { + dim += rank; + } - stream.wait(); - } + REQUIRE_TRUE( + dim < rank && dim >= 0, 0, + "SOFTMAX_MKLDNN OP: the value of input integer parameter (dimension) " + "must be less than input array rank %i, but got dimension = %i instead !", + rank, dim); + REQUIRE_TRUE(rank <= 6, 0, + "SOFTMAX_MKLDNN OP: the rank of input must be less or qual 6, " + "but got rank = %i instead !", + rank); - PLATFORM_IMPL(softmax_bp, ENGINE_CPU) { + // mkldnnSoftMax + softmaxMKLDNN(input, output, dim); - auto input = INPUT_VARIABLE(0); - auto dLdz = INPUT_VARIABLE(1); - auto dLdx = OUTPUT_VARIABLE(0); + return Status::OK(); +} - const int rank = input->rankOf(); - const int dLdzRank = dLdz->rankOf(); - int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; +PLATFORM_CHECK(softmax, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + const DataType xType = x->dataType(); + const DataType zType = z->dataType(); + + const int xRank = x->rankOf(); + bool bSupportedRanks = (xRank > 2 && xRank < 7); + /* + Source Destination + f32 f32 + */ + return !x->isEmpty() && block.isUseMKLDNN() && bSupportedRanks && + (xType == DataType::FLOAT32 && zType == DataType::FLOAT32); +} - if (dim < 0) { - dim += rank; - } +////////////////////////////////////////////////////////////////////// +static void softmaxBpMKLDNN(const NDArray* x, const NDArray* dLdz, + NDArray* dLdx, const int axis) { + const auto xRank = x->rankOf(); + const auto dLdzRank = dLdz->rankOf(); + + dnnl::memory::dims xShape, dLdxShape, dLdzShape; + + mkldnnUtils::getDims(x, xRank, xShape); + mkldnnUtils::getDims(dLdx, xRank, dLdxShape); + mkldnnUtils::getDims(dLdz, dLdzRank, dLdzShape); + + dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); + + // x + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + dnnl::memory::desc x_user_md = + dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + mkldnnUtils::setBlockStrides(x, x_user_md); + + // dLdx + dnnl::memory::desc dLdx_mkl_md = + dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); + dnnl::memory::desc dLdx_user_md = + dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); + mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md); + // todo if mkl does not support broadcast we can remove this + format = mkldnnUtils::getFormat(dLdzRank); + + // dLdz + dnnl::memory::desc dLdz_mkl_md = + dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); + dnnl::memory::desc dLdz_user_md = + dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); + mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + // forward description + dnnl::softmax_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, + x_mkl_md, axis); + dnnl::softmax_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward description + dnnl::softmax_backward::desc op_bp_desc(dLdz_mkl_md, dLdx_mkl_md, axis); + dnnl::softmax_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, + op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map argsbp, argsff; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required for forward + // input + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, + op_ff_prim_desc.src_desc(), + argsff[DNNL_ARG_SRC]); + + // dLdx + auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer()); + const bool dLdxReorder = + op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc(); + auto dLdx_mkl_mem = dLdxReorder + ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) + : dLdx_user_mem; + argsff[DNNL_ARG_DST] = dLdx_mkl_mem; + + // check and arg set for backprob + argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; + argsbp[DNNL_ARG_DST] = dLdx_mkl_mem; + // dLdz + mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, + op_bp_prim_desc.diff_dst_desc(), + argsbp[DNNL_ARG_DIFF_DST]); + + // run calculations forward + dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff); + + // run calculations backward + dnnl::softmax_backward(op_bp_prim_desc).execute(stream, argsbp); + + // reorder outputs if necessary + if (dLdxReorder) + dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem) + .execute(stream, dLdx_mkl_mem, dLdx_user_mem); + + stream.wait(); +} - REQUIRE_TRUE(dim < rank && dim >= 0, 0, "SOFTMAX_MKLDNN_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); +PLATFORM_IMPL(softmax_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto dLdz = INPUT_VARIABLE(1); + auto dLdx = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(rank <= 6 && dLdzRank <= 6, 0, "SOFTMAX_MKLDNN_BP OP: the rank of input and dLdz must be less or qual 6, but got input rank = %i and dLdz rank rank = %i instead !", rank, dLdzRank); + const int rank = input->rankOf(); + const int dLdzRank = dLdz->rankOf(); + int dim = block.numI() > 0 ? INT_ARG(0) : rank - 1; - // mkldnnSoftMax - softmaxBpMKLDNN(input, dLdz, dLdx, dim); + if (dim < 0) { + dim += rank; + } - return Status::OK(); - } + REQUIRE_TRUE( + dim < rank && dim >= 0, 0, + "SOFTMAX_MKLDNN_BP OP: the value of input integer parameter (dimension) " + "must be less than input array rank %i, but got dimension = %i instead !", + rank, dim); - PLATFORM_CHECK(softmax_bp, ENGINE_CPU) { + REQUIRE_TRUE( + rank <= 6 && dLdzRank <= 6, 0, + "SOFTMAX_MKLDNN_BP OP: the rank of input and dLdz must be less or qual " + "6, but got input rank = %i and dLdz rank rank = %i instead !", + rank, dLdzRank); - auto x = INPUT_VARIABLE(0); - auto dLdz = INPUT_VARIABLE(1); - auto dLdx = OUTPUT_VARIABLE(0); + // mkldnnSoftMax + softmaxBpMKLDNN(input, dLdz, dLdx, dim); - const DataType xType = x->dataType(); - const DataType dLdzType = dLdz->dataType(); - const DataType dLdxType = dLdx->dataType(); + return Status::OK(); +} - const int xRank = x->rankOf(); - const int dLdzRank = dLdz->rankOf(); +PLATFORM_CHECK(softmax_bp, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto dLdz = INPUT_VARIABLE(1); + auto dLdx = OUTPUT_VARIABLE(0); - bool bSupportedRanks = xRank < 7 && dLdzRank == xRank && (!x->isEmpty() && !dLdz->isEmpty()); + const DataType xType = x->dataType(); + const DataType dLdzType = dLdz->dataType(); + const DataType dLdxType = dLdx->dataType(); - if (bSupportedRanks) { - for (int i = 0; i < xRank; i++) { - if (x->sizeAt(i) != dLdz->sizeAt(i)) { - bSupportedRanks = false; - break; - } - } - } + const int xRank = x->rankOf(); + const int dLdzRank = dLdz->rankOf(); - //Source Destination - //f32 f32 - return block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32); - } + bool bSupportedRanks = + xRank < 7 && dLdzRank == xRank && (!x->isEmpty() && !dLdz->isEmpty()); - } + if (bSupportedRanks) { + for (int i = 0; i < xRank; i++) { + if (x->sizeAt(i) != dLdz->sizeAt(i)) { + bSupportedRanks = false; + break; + } } + } + + // Source Destination + // f32 f32 + return block.isUseMKLDNN() && bSupportedRanks && + (xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && + dLdxType == DataType::FLOAT32); } + +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp index 53d75d0a97e5..cea1b1b76842 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp @@ -14,223 +14,254 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleg Semeniv - // - // +// +// @author Oleg Semeniv +// +// -#include +#include #include +#include #include -#include + #include "mkldnnUtils.h" using namespace dnnl; namespace sd { - namespace ops { - namespace platforms { - - ////////////////////////////////////////////////////////////////////// - static void tanhMKLDNN(const NDArray* x, NDArray* z) { - - const auto xRank = x->rankOf(); - dnnl::memory::dims xShape, zShape; - - mkldnnUtils::getDims(x, xRank, xShape); - mkldnnUtils::getDims(z, xRank, zShape); - - dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); - - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(x, x_user_md); - - // z - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(z, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // Create attributes (to handle alpha and beta if necessary) - dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0) - - // operation primitive description - dnnl::eltwise_forward::desc op_desc(dnnl::prop_kind::forward_inference, algorithm::eltwise_tanh, x_mkl_md, 0, 0); - - dnnl::eltwise_forward::primitive_desc op_prim_desc(op_desc, attr, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - // input - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // z - auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; - - // run calculations - dnnl::eltwise_forward(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); - - stream.wait(); - } - +namespace ops { +namespace platforms { - PLATFORM_IMPL(tanh, ENGINE_CPU) { +////////////////////////////////////////////////////////////////////// +static void tanhMKLDNN(const NDArray* x, NDArray* z) { + const auto xRank = x->rankOf(); + dnnl::memory::dims xShape, zShape; - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - const int rank = input->rankOf(); - REQUIRE_TRUE(rank <= 6, 0, "TANH_MKLDNN OP: the rank of input must be less or qual 6, but got rank = %i instead !", rank); + mkldnnUtils::getDims(x, xRank, xShape); + mkldnnUtils::getDims(z, xRank, zShape); - // mkldnnTanh - tanhMKLDNN(input, output); + dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); - return Status::OK(); - } + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + dnnl::memory::desc x_user_md = + dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + mkldnnUtils::setBlockStrides(x, x_user_md); - PLATFORM_CHECK(tanh, ENGINE_CPU) { + // z + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); + dnnl::memory::desc z_user_md = + dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); + mkldnnUtils::setBlockStrides(z, z_user_md); - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - const DataType xType = x->dataType(); - const DataType zType = z->dataType(); + // Create attributes (to handle alpha and beta if necessary) + dnnl::primitive_attr attr; // it is empty since we have usual values for + // alpha (=1) and beta (=0) - const int xRank = x->rankOf(); - bool bSupportedRanks = !x->isEmpty() && xRank < 7 && xRank > 0 && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32); - /* - Source Destination - f32 f32 - */ - return block.isUseMKLDNN() && bSupportedRanks; - } + // operation primitive description + dnnl::eltwise_forward::desc op_desc(dnnl::prop_kind::forward_inference, + algorithm::eltwise_tanh, x_mkl_md, 0, 0); + dnnl::eltwise_forward::primitive_desc op_prim_desc(op_desc, attr, engine); - ////////////////////////////////////////////////////////////////////// - static void tanhBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx) { + // arguments (memory buffers) necessary for calculations + std::unordered_map args; - const auto xRank = x->rankOf(); - dnnl::memory::dims xShape, dLdzShape, dLdxShape; + dnnl::stream stream(engine); - mkldnnUtils::getDims(x, xRank, xShape); - mkldnnUtils::getDims(dLdz, xRank, dLdzShape); - mkldnnUtils::getDims(dLdx, xRank, dLdxShape); + // provide memory buffers and check whether reorder is required + // input + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); + // z + auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = + zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(x, x_user_md); + // run calculations + dnnl::eltwise_forward(op_prim_desc).execute(stream, args); - // dLdz - dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md); + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); - // dLdx - dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // operation primitive description - // forward - dnnl::eltwise_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, algorithm::eltwise_tanh, x_mkl_md, 0, 0); - dnnl::eltwise_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backward description - dnnl::eltwise_backward::desc op_desc(algorithm::eltwise_tanh, dLdz_mkl_md, x_mkl_md, 0, 0); - dnnl::eltwise_backward::primitive_desc op_prim_desc(op_desc, engine, op_ff_prim_desc); - - // provide memory buffers and check whether reorder is required for forward - // input - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // dLdz - mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); - - // dLdx - auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer()); - const bool dLdxReorder = op_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc(); - auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_prim_desc.diff_src_desc(), engine) : dLdx_user_mem; - args[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; - - // run calculations backward - dnnl::eltwise_backward(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (dLdxReorder) - dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem); - - stream.wait(); - } - - - PLATFORM_IMPL(tanh_bp, ENGINE_CPU) { - - auto input = INPUT_VARIABLE(0); - auto dLdz = INPUT_VARIABLE(1); - auto dLdx = OUTPUT_VARIABLE(0); - - const int rank = input->rankOf(); - const int dLdzRank = dLdz->rankOf(); + stream.wait(); +} - REQUIRE_TRUE(rank <= 6 && dLdzRank <= 6, 0, "TANH_BP_MKLDNN OP: the rank of input and dLdz must be less or qual 6, but got input rank = %i and dLdz rank rank = %i instead !", rank, dLdzRank); +PLATFORM_IMPL(tanh, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + const int rank = input->rankOf(); + REQUIRE_TRUE(rank <= 6, 0, + "TANH_MKLDNN OP: the rank of input must be less or qual 6, but " + "got rank = %i instead !", + rank); - // mkldnnSoftMax - tanhBpMKLDNN(input, dLdz, dLdx); + // mkldnnTanh + tanhMKLDNN(input, output); - return Status::OK(); - } + return Status::OK(); +} - PLATFORM_CHECK(tanh_bp, ENGINE_CPU) { +PLATFORM_CHECK(tanh, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + const DataType xType = x->dataType(); + const DataType zType = z->dataType(); + + const int xRank = x->rankOf(); + bool bSupportedRanks = + !x->isEmpty() && xRank < 7 && xRank > 0 && + (xType == DataType::FLOAT32 && zType == DataType::FLOAT32); + /* + Source Destination + f32 f32 + */ + return block.isUseMKLDNN() && bSupportedRanks; +} - auto x = INPUT_VARIABLE(0); - auto dLdz = INPUT_VARIABLE(1); - auto dLdx = OUTPUT_VARIABLE(0); +////////////////////////////////////////////////////////////////////// +static void tanhBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx) { + const auto xRank = x->rankOf(); + dnnl::memory::dims xShape, dLdzShape, dLdxShape; + + mkldnnUtils::getDims(x, xRank, xShape); + mkldnnUtils::getDims(dLdz, xRank, dLdzShape); + mkldnnUtils::getDims(dLdx, xRank, dLdxShape); + + dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); + + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + dnnl::memory::desc x_user_md = + dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + mkldnnUtils::setBlockStrides(x, x_user_md); + + // dLdz + dnnl::memory::desc dLdz_mkl_md = + dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + dnnl::memory::desc dLdz_user_md = + dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md); + + // dLdx + dnnl::memory::desc dLdx_mkl_md = + dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + dnnl::memory::desc dLdx_user_md = + dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); + mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // operation primitive description + // forward + dnnl::eltwise_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, + algorithm::eltwise_tanh, x_mkl_md, 0, + 0); + dnnl::eltwise_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward description + dnnl::eltwise_backward::desc op_desc(algorithm::eltwise_tanh, dLdz_mkl_md, + x_mkl_md, 0, 0); + dnnl::eltwise_backward::primitive_desc op_prim_desc(op_desc, engine, + op_ff_prim_desc); + + // provide memory buffers and check whether reorder is required for forward + // input + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // dLdz + mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, + op_prim_desc.diff_dst_desc(), + args[DNNL_ARG_DIFF_DST]); + + // dLdx + auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer()); + const bool dLdxReorder = + op_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc(); + auto dLdx_mkl_mem = dLdxReorder + ? dnnl::memory(op_prim_desc.diff_src_desc(), engine) + : dLdx_user_mem; + args[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; + + // run calculations backward + dnnl::eltwise_backward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (dLdxReorder) + dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem) + .execute(stream, dLdx_mkl_mem, dLdx_user_mem); + + stream.wait(); +} - const DataType xType = x->dataType(); - const DataType dLdzType = dLdz->dataType(); - const DataType dLdxType = dLdx->dataType(); +PLATFORM_IMPL(tanh_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto dLdz = INPUT_VARIABLE(1); + auto dLdx = OUTPUT_VARIABLE(0); - const int xRank = x->rankOf(); - const int dLdzRank = dLdz->rankOf(); + const int rank = input->rankOf(); + const int dLdzRank = dLdz->rankOf(); - bool bSupportedRanks = xRank < 7 && xRank > 0 && dLdzRank == xRank && (!x->isEmpty() && !dLdz->isEmpty()); - bSupportedRanks &= (xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32); + REQUIRE_TRUE( + rank <= 6 && dLdzRank <= 6, 0, + "TANH_BP_MKLDNN OP: the rank of input and dLdz must be less or qual 6, " + "but got input rank = %i and dLdz rank rank = %i instead !", + rank, dLdzRank); - if (bSupportedRanks) { - for (int i = 0; i < xRank; i++) { - if (x->sizeAt(i) != dLdz->sizeAt(i)) { - bSupportedRanks = false; - break; - } - } - } + // mkldnnSoftMax + tanhBpMKLDNN(input, dLdz, dLdx); - //Source Destination - //f32 f32 - return block.isUseMKLDNN() && bSupportedRanks; - } + return Status::OK(); +} - } +PLATFORM_CHECK(tanh_bp, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto dLdz = INPUT_VARIABLE(1); + auto dLdx = OUTPUT_VARIABLE(0); + + const DataType xType = x->dataType(); + const DataType dLdzType = dLdz->dataType(); + const DataType dLdxType = dLdx->dataType(); + + const int xRank = x->rankOf(); + const int dLdzRank = dLdz->rankOf(); + + bool bSupportedRanks = xRank < 7 && xRank > 0 && dLdzRank == xRank && + (!x->isEmpty() && !dLdz->isEmpty()); + bSupportedRanks &= + (xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && + dLdxType == DataType::FLOAT32); + + if (bSupportedRanks) { + for (int i = 0; i < xRank; i++) { + if (x->sizeAt(i) != dLdz->sizeAt(i)) { + bSupportedRanks = false; + break; + } } + } + + // Source Destination + // f32 f32 + return block.isUseMKLDNN() && bSupportedRanks; } + +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp index 8a8ddd3efc75..bebf0fd418d5 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp @@ -14,413 +14,501 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Oleg Semeniv - // - // +// +// @author Oleg Semeniv +// +// -#include +#include #include +#include #include -#include + #include "mkldnnUtils.h" using namespace dnnl; namespace sd { - namespace ops { - namespace platforms { - - ////////////////////////////////////////////////////////////////////// - static void xwPlusBiasMKLDNN(const NDArray* x, const NDArray* weights, const NDArray* bias, NDArray* z, const bool bShouldTransp) { - - // mkl works with following - // [M,K] x [N,K]^T + [N] = [M,N] - const auto xRank = x->rankOf(); - - // [M,K] x [K,N] = [M,N] - const int M = x->sizeAt(0); - const int K = x->sizeAt(1); // K == wK - const int N = z->sizeAt(1); - - dnnl::memory::dims xShape = dnnl::memory::dims({ M, K }); - dnnl::memory::dims wShape = dnnl::memory::dims({ N, K }); - dnnl::memory::dims zShape = dnnl::memory::dims({ M, N }); - dnnl::memory::dims bShape = dnnl::memory::dims({ N }); - - dnnl::memory::format_tag format = dnnl::memory::format_tag::ab; - - // x type - dnnl::memory::data_type xType = dnnl::memory::data_type::f32; - if (x->dataType() == DataType::UINT8) - xType = dnnl::memory::data_type::u8; - else if (x->dataType() == DataType::INT8) - xType = dnnl::memory::data_type::s8; - - // weights type - dnnl::memory::data_type wType = (weights->dataType() == DataType::FLOAT32) ? - wType = dnnl::memory::data_type::f32 : wType = dnnl::memory::data_type::s8; - - // bias type need add description for bias - dnnl::memory::data_type bType = dnnl::memory::data_type::f32; - if (bias->dataType() == DataType::INT32) - bType = dnnl::memory::data_type::s32; - else if (bias->dataType() == DataType::UINT8) - bType = dnnl::memory::data_type::u8; - else if (bias->dataType() == DataType::INT8) - bType = dnnl::memory::data_type::s8; - - // z type - dnnl::memory::data_type zType = dnnl::memory::data_type::f32; - if (z->dataType() == DataType::INT32) - zType = dnnl::memory::data_type::s32; - else if (z->dataType() == DataType::UINT8) - zType = dnnl::memory::data_type::u8; - else if (z->dataType() == DataType::INT8) - zType = dnnl::memory::data_type::s8; - - // memory descriptors for arrays - // x - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); - mkldnnUtils::setBlockStrides(x, x_user_md); - - // weights - dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, wType, format); - if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) { - - weights_user_md.data.format_kind = dnnl_blocked; // overrides format - if (bShouldTransp) { - weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1); - weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0); - } - else { - weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0); - weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1); - } - } - // bias - dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x); - dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x); - mkldnnUtils::setBlockStrides(bias, bias_user_md); - - // z - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format); - mkldnnUtils::setBlockStrides(z, z_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - - // operation primitive description - dnnl::inner_product_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, z_mkl_md); - - dnnl::inner_product_forward::primitive_desc op_prim_desc(op_desc, engine); - - // arguments (memory buffers) necessary for calculations - std::unordered_map args; - - dnnl::stream stream(engine); - - // provide memory buffers and check whether reorder is required - - // input - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - - // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - - // bias - auto bias_mkl_mem = dnnl::memory(bias_mkl_md, engine, const_cast(bias->buffer())); - args[DNNL_ARG_BIAS] = bias_mkl_mem; - - // z - auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; - - // run calculations - dnnl::inner_product_forward(op_prim_desc).execute(stream, args); - - // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); - - stream.wait(); - } - - ////////////////////////////////////////////////////////////////////// - static void xwPlusBiasBp(const NDArray* x, const NDArray* weights, const NDArray* bias, const NDArray* dLdz, - NDArray* dLdx, NDArray* dLdw, NDArray* dLdb, const bool bShouldTransp) { - - // mkl works with following - // [M,K] x [N,K]^T + [N] = [M,N] - const auto xRank = x->rankOf(); - - // [M,K] x [K,N] = [M,N] - const int M = x->sizeAt(0); - const int K = x->sizeAt(1); // K == wK - const int N = dLdz->sizeAt(1); - // input dims - dnnl::memory::dims xShape = dnnl::memory::dims({ M, K }); - dnnl::memory::dims wShape = dnnl::memory::dims({ N, K }); - dnnl::memory::dims dLdzShape = dnnl::memory::dims({ M, N }); - - dnnl::memory::dims bShape = dnnl::memory::dims({ N }); - // output dims - dnnl::memory::dims dLdxShape = xShape; - dnnl::memory::dims dLdwShape = wShape; - - dnnl::memory::format_tag format = dnnl::memory::format_tag::ab; - dnnl::memory::data_type dataType = dnnl::memory::data_type::f32; - - // memory descriptors for arrays - // x - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, format); - mkldnnUtils::setBlockStrides(x, x_user_md); - - // weights - dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any); - dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, dataType, format); - if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) { - - weights_user_md.data.format_kind = dnnl_blocked; // overrides format - if (bShouldTransp) { - weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1); - weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0); - } - else { - weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0); - weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1); - } - } - // bias - dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); - dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); - mkldnnUtils::setBlockStrides(bias, bias_user_md); - - // dLdz - dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dataType, dnnl::memory::format_tag::any); - dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dataType, format); - mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md); - - // dLdw - dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, format); - dnnl::memory::desc dLdw_user_md = dnnl::memory::desc(wShape, dataType, format); - if (dLdw->ews() != 1 || dLdw->ordering() != 'c' || bShouldTransp) { - - dLdw_user_md.data.format_kind = dnnl_blocked; // overrides format - if (bShouldTransp) { - dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(1); - dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(0); - } - else { - dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(0); - dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(1); - } - } - - // dLdb - dnnl::memory::desc dLdb_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); - dnnl::memory::desc dLdb_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); - mkldnnUtils::setBlockStrides(dLdb, dLdb_user_md); - - // dLdx - dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any); - dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dataType, format); - mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - // forward - // operation primitive description - dnnl::inner_product_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, dLdz_mkl_md); - dnnl::inner_product_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); - - // backprob - // dLdw - auto op_bpdw_desc = inner_product_backward_weights::desc(x_mkl_md, dLdw_mkl_md, dLdb_mkl_md, dLdz_mkl_md); - auto op_bpdw_prim_desc = inner_product_backward_weights::primitive_desc(op_bpdw_desc, engine, op_ff_prim_desc); - - // backprob - // dLdx - auto op_bpdx_desc = inner_product_backward_data::desc(dLdx_mkl_md, weights_mkl_md, dLdz_mkl_md); - auto op_bpdx_prim_desc = inner_product_backward_data::primitive_desc(op_bpdx_desc, engine, op_ff_prim_desc); - - // arguments (memory buffers) necessary for calculations - std::unordered_map argsDw, argsDx; - - dnnl::stream stream(engine); - - // dLdz dw - mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdw_prim_desc.diff_dst_desc(), argsDw[DNNL_ARG_DIFF_DST]); - - // dLdz - dx - mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdx_prim_desc.diff_dst_desc(), argsDx[DNNL_ARG_DIFF_DST]); - - // input x for dw - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bpdw_prim_desc.src_desc(), argsDw[DNNL_ARG_SRC]); - - // weights - dx - mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_bpdx_prim_desc.weights_desc(), argsDx[DNNL_ARG_WEIGHTS]); - - // dLdw - auto dLdw_user_mem = dnnl::memory(dLdw_user_md, engine, dLdw->buffer()); - const bool dLdwReorder = op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc(); - auto dLdw_mkl_mem = dLdwReorder ? dnnl::memory(op_bpdw_prim_desc.diff_weights_desc(), engine) : dLdw_user_mem; - argsDw[DNNL_ARG_DIFF_WEIGHTS] = dLdw_mkl_mem; - - // dLdx - auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer()); - const bool dLdxReorder = op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc(); - auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_bpdx_prim_desc.diff_src_desc(), engine) : dLdx_user_mem; - argsDx[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; - - // dLdb - auto dLdb_user_mem = dnnl::memory(dLdb_user_md, engine, dLdb->buffer()); - const bool dLdbReorder = op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc(); - auto dLdb_mkl_mem = dLdbReorder ? dnnl::memory(op_bpdw_prim_desc.diff_bias_desc(), engine) : dLdb_user_mem; - argsDw[DNNL_ARG_DIFF_BIAS] = dLdb_mkl_mem; - - // run calculations dw - dnnl::inner_product_backward_weights(op_bpdw_prim_desc).execute(stream, argsDw); - // run calculations dx - dnnl::inner_product_backward_data(op_bpdx_prim_desc).execute(stream, argsDx); - - // reorder outputs if necessary - if (dLdxReorder) - dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem); - - if (dLdwReorder) - dnnl::reorder(dLdw_mkl_mem, dLdw_user_mem).execute(stream, dLdw_mkl_mem, dLdw_user_mem); - - if (dLdbReorder) - dnnl::reorder(dLdb_mkl_mem, dLdb_user_mem).execute(stream, dLdb_mkl_mem, dLdb_user_mem); - - stream.wait(); - } - - PLATFORM_IMPL(xw_plus_b, ENGINE_CPU) { - - auto x = INPUT_VARIABLE(0); - auto w = INPUT_VARIABLE(1); - auto b = INPUT_VARIABLE(2); - auto z = OUTPUT_VARIABLE(0); - - if (x->isEmpty() || w->isEmpty() || b->isEmpty()) - return Status::OK(); - - const int xRank = x->rankOf(); - const int wRank = w->rankOf(); - const int zRank = z->rankOf(); - - const bool bShouldTransp = block.numI() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] - - REQUIRE_TRUE(xRank == 2, 0, "xw_plus_b MKL: Input x array should have rank equal 2, but got instead %i!", xRank); - REQUIRE_TRUE(wRank == 2, 0, "xw_plus_b MKL: Input weights array should have rank equal 2, but got instead %i!", wRank); - REQUIRE_TRUE(zRank == 2, 0, "xw_plus_b MKL: Output array should have rank equal 2, but got instead %i!", zRank); - - REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, "xw_plus_b MKL: Input bias vector should be 1D and have proper dimension 1x%i." - " But got rank %i, and got length %i instead %i.", z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1)); - - // mkldnnInerPorductss - xwPlusBiasMKLDNN(x, w, b, z, bShouldTransp); - - return Status::OK(); - } - - PLATFORM_CHECK(xw_plus_b, ENGINE_CPU) { - - auto x = INPUT_VARIABLE(0); - auto w = INPUT_VARIABLE(1); - auto b = INPUT_VARIABLE(2); - auto z = OUTPUT_VARIABLE(0); - - const DataType xType = x->dataType(); - const DataType wType = w->dataType(); - const DataType bType = b->dataType(); - const DataType zType = z->dataType(); - - /* - Source Weights Destination Bias - f32 f32 f32 f32 - u8, s8 s8 u8, s8, s32, f32 u8, s8, s32, f32 - */ - return block.isUseMKLDNN() && - ((xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && bType == DataType::FLOAT32 && zType == DataType::FLOAT32) || - ( // x - (xType == DataType::UINT8 || xType == DataType::INT8) && - // w - (wType == DataType::UINT8 || wType == DataType::INT8) && - // b - (bType == DataType::UINT8 || bType == DataType::INT8 || bType == DataType::INT32 || bType == DataType::FLOAT32) && - // z - (zType == DataType::UINT8 || zType == DataType::INT8 || zType == DataType::INT32 || zType == DataType::FLOAT32) - )); - } - - PLATFORM_IMPL(xw_plus_b_bp, ENGINE_CPU) { - - auto x = INPUT_VARIABLE(0); - auto w = INPUT_VARIABLE(1); - auto b = INPUT_VARIABLE(2); - auto dLdz = INPUT_VARIABLE(3); - - auto dLdx = OUTPUT_VARIABLE(0); - auto dLdw = OUTPUT_VARIABLE(1); - auto dLdb = OUTPUT_VARIABLE(2); - - if (x->isEmpty() || w->isEmpty() || b->isEmpty() || dLdz->isEmpty()) - return Status::OK(); - - const int xRank = x->rankOf(); - const int wRank = w->rankOf(); - const int dLdzRank = dLdz->rankOf(); - - const bool bShouldTransp = block.numI() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] - - REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP MKL: Input x array should have rank equal 2, but got instead %i!", x->rankOf()); - REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP MKL: Input weights array should have rank equal 2, but got instead %i!", w->rankOf()); - REQUIRE_TRUE(dLdz->rankOf() == 2, 0, "xw_plus_b BP MKL: Output array should have rank equal 2, but got instead %i!", dLdz->rankOf()); - REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(1), 0, "xw_plus_b BP MKL: Input bias vector should be 1D and have proper dimension 1x%i." - " But got rank %i, and got length %i instead %i.", dLdz->sizeAt(1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(1)); - - xwPlusBiasBp(x, w, b, dLdz, dLdx, dLdw, dLdb, bShouldTransp); - - return Status::OK(); - } - - PLATFORM_CHECK(xw_plus_b_bp, ENGINE_CPU) { - - auto x = INPUT_VARIABLE(0); - auto w = INPUT_VARIABLE(1); - auto b = INPUT_VARIABLE(2); - auto dLdz = INPUT_VARIABLE(3); - - auto dLdx = OUTPUT_VARIABLE(0); - auto dLdw = OUTPUT_VARIABLE(1); - auto dLdb = OUTPUT_VARIABLE(2); - - const DataType xType = x->dataType(); - const DataType wType = w->dataType(); - const DataType bType = b->dataType(); - const DataType dLdzType = dLdz->dataType(); - const DataType dLdxType = dLdx->dataType(); - const DataType dLdwType = dLdw->dataType(); - const DataType dLdbType = dLdb->dataType(); - - /* - Source Weights Destination Bias - f32 f32 f32 f32 - */ - return block.isUseMKLDNN() && - (xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && - bType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && - dLdbType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32 && - dLdwType == DataType::FLOAT32); - } - - } +namespace ops { +namespace platforms { + +////////////////////////////////////////////////////////////////////// +static void xwPlusBiasMKLDNN(const NDArray* x, const NDArray* weights, + const NDArray* bias, NDArray* z, + const bool bShouldTransp) { + // mkl works with following + // [M,K] x [N,K]^T + [N] = [M,N] + const auto xRank = x->rankOf(); + + // [M,K] x [K,N] = [M,N] + const int M = x->sizeAt(0); + const int K = x->sizeAt(1); // K == wK + const int N = z->sizeAt(1); + + dnnl::memory::dims xShape = dnnl::memory::dims({M, K}); + dnnl::memory::dims wShape = dnnl::memory::dims({N, K}); + dnnl::memory::dims zShape = dnnl::memory::dims({M, N}); + dnnl::memory::dims bShape = dnnl::memory::dims({N}); + + dnnl::memory::format_tag format = dnnl::memory::format_tag::ab; + + // x type + dnnl::memory::data_type xType = dnnl::memory::data_type::f32; + if (x->dataType() == DataType::UINT8) + xType = dnnl::memory::data_type::u8; + else if (x->dataType() == DataType::INT8) + xType = dnnl::memory::data_type::s8; + + // weights type + dnnl::memory::data_type wType = (weights->dataType() == DataType::FLOAT32) + ? wType = dnnl::memory::data_type::f32 + : wType = dnnl::memory::data_type::s8; + + // bias type need add description for bias + dnnl::memory::data_type bType = dnnl::memory::data_type::f32; + if (bias->dataType() == DataType::INT32) + bType = dnnl::memory::data_type::s32; + else if (bias->dataType() == DataType::UINT8) + bType = dnnl::memory::data_type::u8; + else if (bias->dataType() == DataType::INT8) + bType = dnnl::memory::data_type::s8; + + // z type + dnnl::memory::data_type zType = dnnl::memory::data_type::f32; + if (z->dataType() == DataType::INT32) + zType = dnnl::memory::data_type::s32; + else if (z->dataType() == DataType::UINT8) + zType = dnnl::memory::data_type::u8; + else if (z->dataType() == DataType::INT8) + zType = dnnl::memory::data_type::s8; + + // memory descriptors for arrays + // x + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); + mkldnnUtils::setBlockStrides(x, x_user_md); + + // weights + dnnl::memory::desc weights_mkl_md = + dnnl::memory::desc(wShape, wType, dnnl::memory::format_tag::any); + dnnl::memory::desc weights_user_md = + dnnl::memory::desc(wShape, wType, format); + if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) { + weights_user_md.data.format_kind = dnnl_blocked; // overrides format + if (bShouldTransp) { + weights_user_md.data.format_desc.blocking.strides[0] = + weights->strideAt(1); + weights_user_md.data.format_desc.blocking.strides[1] = + weights->strideAt(0); + } else { + weights_user_md.data.format_desc.blocking.strides[0] = + weights->strideAt(0); + weights_user_md.data.format_desc.blocking.strides[1] = + weights->strideAt(1); + } + } + // bias + dnnl::memory::desc bias_mkl_md = + dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x); + dnnl::memory::desc bias_user_md = + dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x); + mkldnnUtils::setBlockStrides(bias, bias_user_md); + + // z + dnnl::memory::desc z_mkl_md = + dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format); + mkldnnUtils::setBlockStrides(z, z_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::inner_product_forward::desc op_desc(dnnl::prop_kind::forward_inference, + x_mkl_md, weights_mkl_md, + bias_mkl_md, z_mkl_md); + + dnnl::inner_product_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, + op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + + // weights + mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, + op_prim_desc.weights_desc(), + args[DNNL_ARG_WEIGHTS]); + + // bias + auto bias_mkl_mem = + dnnl::memory(bias_mkl_md, engine, const_cast(bias->buffer())); + args[DNNL_ARG_BIAS] = bias_mkl_mem; + + // z + auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = + zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; + + // run calculations + dnnl::inner_product_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); +} + +////////////////////////////////////////////////////////////////////// +static void xwPlusBiasBp(const NDArray* x, const NDArray* weights, + const NDArray* bias, const NDArray* dLdz, + NDArray* dLdx, NDArray* dLdw, NDArray* dLdb, + const bool bShouldTransp) { + // mkl works with following + // [M,K] x [N,K]^T + [N] = [M,N] + const auto xRank = x->rankOf(); + + // [M,K] x [K,N] = [M,N] + const int M = x->sizeAt(0); + const int K = x->sizeAt(1); // K == wK + const int N = dLdz->sizeAt(1); + // input dims + dnnl::memory::dims xShape = dnnl::memory::dims({M, K}); + dnnl::memory::dims wShape = dnnl::memory::dims({N, K}); + dnnl::memory::dims dLdzShape = dnnl::memory::dims({M, N}); + + dnnl::memory::dims bShape = dnnl::memory::dims({N}); + // output dims + dnnl::memory::dims dLdxShape = xShape; + dnnl::memory::dims dLdwShape = wShape; + + dnnl::memory::format_tag format = dnnl::memory::format_tag::ab; + dnnl::memory::data_type dataType = dnnl::memory::data_type::f32; + + // memory descriptors for arrays + // x + dnnl::memory::desc x_mkl_md = + dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, format); + mkldnnUtils::setBlockStrides(x, x_user_md); + + // weights + dnnl::memory::desc weights_mkl_md = + dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc weights_user_md = + dnnl::memory::desc(wShape, dataType, format); + if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) { + weights_user_md.data.format_kind = dnnl_blocked; // overrides format + if (bShouldTransp) { + weights_user_md.data.format_desc.blocking.strides[0] = + weights->strideAt(1); + weights_user_md.data.format_desc.blocking.strides[1] = + weights->strideAt(0); + } else { + weights_user_md.data.format_desc.blocking.strides[0] = + weights->strideAt(0); + weights_user_md.data.format_desc.blocking.strides[1] = + weights->strideAt(1); + } + } + // bias + dnnl::memory::desc bias_mkl_md = + dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); + dnnl::memory::desc bias_user_md = + dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); + mkldnnUtils::setBlockStrides(bias, bias_user_md); + + // dLdz + dnnl::memory::desc dLdz_mkl_md = + dnnl::memory::desc(dLdzShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdz_user_md = + dnnl::memory::desc(dLdzShape, dataType, format); + mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md); + + // dLdw + dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, format); + dnnl::memory::desc dLdw_user_md = + dnnl::memory::desc(wShape, dataType, format); + if (dLdw->ews() != 1 || dLdw->ordering() != 'c' || bShouldTransp) { + dLdw_user_md.data.format_kind = dnnl_blocked; // overrides format + if (bShouldTransp) { + dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(1); + dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(0); + } else { + dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(0); + dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(1); } + } + + // dLdb + dnnl::memory::desc dLdb_mkl_md = + dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); + dnnl::memory::desc dLdb_user_md = + dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); + mkldnnUtils::setBlockStrides(dLdb, dLdb_user_md); + + // dLdx + dnnl::memory::desc dLdx_mkl_md = + dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdx_user_md = + dnnl::memory::desc(xShape, dataType, format); + mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md); + + auto engine = + mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + // forward + // operation primitive description + dnnl::inner_product_forward::desc op_ff_desc( + dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, + dLdz_mkl_md); + dnnl::inner_product_forward::primitive_desc op_ff_prim_desc(op_ff_desc, + engine); + + // backprob + // dLdw + auto op_bpdw_desc = inner_product_backward_weights::desc( + x_mkl_md, dLdw_mkl_md, dLdb_mkl_md, dLdz_mkl_md); + auto op_bpdw_prim_desc = inner_product_backward_weights::primitive_desc( + op_bpdw_desc, engine, op_ff_prim_desc); + + // backprob + // dLdx + auto op_bpdx_desc = inner_product_backward_data::desc( + dLdx_mkl_md, weights_mkl_md, dLdz_mkl_md); + auto op_bpdx_prim_desc = inner_product_backward_data::primitive_desc( + op_bpdx_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map argsDw, argsDx; + + dnnl::stream stream(engine); + + // dLdz dw + mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, + op_bpdw_prim_desc.diff_dst_desc(), + argsDw[DNNL_ARG_DIFF_DST]); + + // dLdz - dx + mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, + op_bpdx_prim_desc.diff_dst_desc(), + argsDx[DNNL_ARG_DIFF_DST]); + + // input x for dw + mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, + op_bpdw_prim_desc.src_desc(), + argsDw[DNNL_ARG_SRC]); + + // weights - dx + mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, + op_bpdx_prim_desc.weights_desc(), + argsDx[DNNL_ARG_WEIGHTS]); + + // dLdw + auto dLdw_user_mem = dnnl::memory(dLdw_user_md, engine, dLdw->buffer()); + const bool dLdwReorder = + op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc(); + auto dLdw_mkl_mem = + dLdwReorder ? dnnl::memory(op_bpdw_prim_desc.diff_weights_desc(), engine) + : dLdw_user_mem; + argsDw[DNNL_ARG_DIFF_WEIGHTS] = dLdw_mkl_mem; + + // dLdx + auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer()); + const bool dLdxReorder = + op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc(); + auto dLdx_mkl_mem = + dLdxReorder ? dnnl::memory(op_bpdx_prim_desc.diff_src_desc(), engine) + : dLdx_user_mem; + argsDx[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; + + // dLdb + auto dLdb_user_mem = dnnl::memory(dLdb_user_md, engine, dLdb->buffer()); + const bool dLdbReorder = + op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc(); + auto dLdb_mkl_mem = + dLdbReorder ? dnnl::memory(op_bpdw_prim_desc.diff_bias_desc(), engine) + : dLdb_user_mem; + argsDw[DNNL_ARG_DIFF_BIAS] = dLdb_mkl_mem; + + // run calculations dw + dnnl::inner_product_backward_weights(op_bpdw_prim_desc) + .execute(stream, argsDw); + // run calculations dx + dnnl::inner_product_backward_data(op_bpdx_prim_desc).execute(stream, argsDx); + + // reorder outputs if necessary + if (dLdxReorder) + dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem) + .execute(stream, dLdx_mkl_mem, dLdx_user_mem); + + if (dLdwReorder) + dnnl::reorder(dLdw_mkl_mem, dLdw_user_mem) + .execute(stream, dLdw_mkl_mem, dLdw_user_mem); + + if (dLdbReorder) + dnnl::reorder(dLdb_mkl_mem, dLdb_user_mem) + .execute(stream, dLdb_mkl_mem, dLdb_user_mem); + + stream.wait(); } + +PLATFORM_IMPL(xw_plus_b, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto z = OUTPUT_VARIABLE(0); + + if (x->isEmpty() || w->isEmpty() || b->isEmpty()) return Status::OK(); + + const int xRank = x->rankOf(); + const int wRank = w->rankOf(); + const int zRank = z->rankOf(); + + const bool bShouldTransp = + block.numI() > 0 + ? (1 != INT_ARG(0)) + : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] + + REQUIRE_TRUE(xRank == 2, 0, + "xw_plus_b MKL: Input x array should have rank equal 2, but got " + "instead %i!", + xRank); + REQUIRE_TRUE(wRank == 2, 0, + "xw_plus_b MKL: Input weights array should have rank equal 2, " + "but got instead %i!", + wRank); + REQUIRE_TRUE(zRank == 2, 0, + "xw_plus_b MKL: Output array should have rank equal 2, but got " + "instead %i!", + zRank); + + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == z->sizeAt(-1), 0, + "xw_plus_b MKL: Input bias vector should be 1D and have proper " + "dimension 1x%i." + " But got rank %i, and got length %i instead %i.", + z->sizeAt(-1), b->rankOf(), b->lengthOf(), z->sizeAt(-1)); + + // mkldnnInerPorductss + xwPlusBiasMKLDNN(x, w, b, z, bShouldTransp); + + return Status::OK(); +} + +PLATFORM_CHECK(xw_plus_b, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto z = OUTPUT_VARIABLE(0); + + const DataType xType = x->dataType(); + const DataType wType = w->dataType(); + const DataType bType = b->dataType(); + const DataType zType = z->dataType(); + + /* + Source Weights Destination Bias + f32 f32 f32 f32 + u8, s8 s8 u8, s8, s32, f32 u8, s8, s32, f32 + */ + return block.isUseMKLDNN() && + ((xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && + bType == DataType::FLOAT32 && zType == DataType::FLOAT32) || + ( // x + (xType == DataType::UINT8 || xType == DataType::INT8) && + // w + (wType == DataType::UINT8 || wType == DataType::INT8) && + // b + (bType == DataType::UINT8 || bType == DataType::INT8 || + bType == DataType::INT32 || bType == DataType::FLOAT32) && + // z + (zType == DataType::UINT8 || zType == DataType::INT8 || + zType == DataType::INT32 || zType == DataType::FLOAT32))); +} + +PLATFORM_IMPL(xw_plus_b_bp, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto dLdz = INPUT_VARIABLE(3); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdw = OUTPUT_VARIABLE(1); + auto dLdb = OUTPUT_VARIABLE(2); + + if (x->isEmpty() || w->isEmpty() || b->isEmpty() || dLdz->isEmpty()) + return Status::OK(); + + const int xRank = x->rankOf(); + const int wRank = w->rankOf(); + const int dLdzRank = dLdz->rankOf(); + + const bool bShouldTransp = + block.numI() > 0 + ? (1 != INT_ARG(0)) + : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] + + REQUIRE_TRUE(x->rankOf() == 2, 0, + "xw_plus_b BP MKL: Input x array should have rank equal 2, but " + "got instead %i!", + x->rankOf()); + REQUIRE_TRUE(w->rankOf() == 2, 0, + "xw_plus_b BP MKL: Input weights array should have rank equal " + "2, but got instead %i!", + w->rankOf()); + REQUIRE_TRUE(dLdz->rankOf() == 2, 0, + "xw_plus_b BP MKL: Output array should have rank equal 2, but " + "got instead %i!", + dLdz->rankOf()); + REQUIRE_TRUE(1 == b->rankOf() && b->lengthOf() == dLdz->sizeAt(1), 0, + "xw_plus_b BP MKL: Input bias vector should be 1D and have " + "proper dimension 1x%i." + " But got rank %i, and got length %i instead %i.", + dLdz->sizeAt(1), b->rankOf(), b->lengthOf(), dLdz->sizeAt(1)); + + xwPlusBiasBp(x, w, b, dLdz, dLdx, dLdw, dLdb, bShouldTransp); + + return Status::OK(); +} + +PLATFORM_CHECK(xw_plus_b_bp, ENGINE_CPU) { + auto x = INPUT_VARIABLE(0); + auto w = INPUT_VARIABLE(1); + auto b = INPUT_VARIABLE(2); + auto dLdz = INPUT_VARIABLE(3); + + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdw = OUTPUT_VARIABLE(1); + auto dLdb = OUTPUT_VARIABLE(2); + + const DataType xType = x->dataType(); + const DataType wType = w->dataType(); + const DataType bType = b->dataType(); + const DataType dLdzType = dLdz->dataType(); + const DataType dLdxType = dLdx->dataType(); + const DataType dLdwType = dLdw->dataType(); + const DataType dLdbType = dLdb->dataType(); + + /* + Source Weights Destination Bias + f32 f32 f32 f32 + */ + return block.isUseMKLDNN() && + (xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && + bType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && + dLdbType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32 && + dLdwType == DataType::FLOAT32); +} + +} // namespace platforms +} // namespace ops +} // namespace sd diff --git a/libnd4j/include/ops/gemm.h b/libnd4j/include/ops/gemm.h index 23f1636a2cfd..cf99babd93c0 100644 --- a/libnd4j/include/ops/gemm.h +++ b/libnd4j/include/ops/gemm.h @@ -25,38 +25,40 @@ #include #include - namespace sd { - namespace blas { - template - static void * transpose(int orderSource, int orderTarget, int rows, int cols, void *source); - - static inline int linearIndexC(int rows, int cols, int r, int c); - static inline int linearIndexF(int rows, int cols, int r, int c); - - template - class GEMM { - protected: - public: - static void op(int Order, int TransA, int TransB, int M, int N, int K, double alpha, void *A, int lda, void *B, int ldb, double beta, void *C, int ldc); - }; +namespace blas { +template +static void *transpose(int orderSource, int orderTarget, int rows, int cols, + void *source); - template - class GEMV : public sd::blas::GEMM{ - public: - static void op(int TRANS, int M, int N, double alpha, void* vA, int lda, void* vX, int incx, double beta, void* vY, int incy ); - }; +static inline int linearIndexC(int rows, int cols, int r, int c); +static inline int linearIndexF(int rows, int cols, int r, int c); +template +class GEMM { + protected: + public: + static void op(int Order, int TransA, int TransB, int M, int N, int K, + double alpha, void *A, int lda, void *B, int ldb, double beta, + void *C, int ldc); +}; - int FORCEINLINE linearIndexC(int rows, int cols, int r, int c) { - return (r * cols + c); - } +template +class GEMV : public sd::blas::GEMM { + public: + static void op(int TRANS, int M, int N, double alpha, void *vA, int lda, + void *vX, int incx, double beta, void *vY, int incy); +}; - int FORCEINLINE linearIndexF(int rows, int cols, int r, int c) { - return (c * rows + r); - } +int FORCEINLINE linearIndexC(int rows, int cols, int r, int c) { + return (r * cols + c); +} - } +int FORCEINLINE linearIndexF(int rows, int cols, int r, int c) { + return (c * rows + r); } -#endif //LIBND4J_GEMM_H +} // namespace blas +} // namespace sd + +#endif // LIBND4J_GEMM_H diff --git a/libnd4j/include/ops/impl/BroadcastBoolOpsTuple.cpp b/libnd4j/include/ops/impl/BroadcastBoolOpsTuple.cpp index 7e903346b817..3b2fb741b61f 100644 --- a/libnd4j/include/ops/impl/BroadcastBoolOpsTuple.cpp +++ b/libnd4j/include/ops/impl/BroadcastBoolOpsTuple.cpp @@ -20,8 +20,10 @@ #include namespace sd { - BroadcastBoolOpsTuple BroadcastBoolOpsTuple::custom(sd::scalar::BoolOps scalar, sd::pairwise::BoolOps pairwise, sd::broadcast::BoolOps broadcast) { - BroadcastBoolOpsTuple t(scalar, pairwise, broadcast); - return t; - } +BroadcastBoolOpsTuple BroadcastBoolOpsTuple::custom( + sd::scalar::BoolOps scalar, sd::pairwise::BoolOps pairwise, + sd::broadcast::BoolOps broadcast) { + BroadcastBoolOpsTuple t(scalar, pairwise, broadcast); + return t; } +} // namespace sd diff --git a/libnd4j/include/ops/impl/BroadcastIntOpsTuple.cpp b/libnd4j/include/ops/impl/BroadcastIntOpsTuple.cpp index 5680b80569dd..d42887bc394d 100644 --- a/libnd4j/include/ops/impl/BroadcastIntOpsTuple.cpp +++ b/libnd4j/include/ops/impl/BroadcastIntOpsTuple.cpp @@ -20,8 +20,10 @@ #include namespace sd { - BroadcastIntOpsTuple BroadcastIntOpsTuple::custom(sd::scalar::IntOps scalar, sd::pairwise::IntOps pairwise, sd::broadcast::IntOps broadcast) { - BroadcastIntOpsTuple t(scalar, pairwise, broadcast); - return t; - } +BroadcastIntOpsTuple BroadcastIntOpsTuple::custom( + sd::scalar::IntOps scalar, sd::pairwise::IntOps pairwise, + sd::broadcast::IntOps broadcast) { + BroadcastIntOpsTuple t(scalar, pairwise, broadcast); + return t; } +} // namespace sd diff --git a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp index 71afe82605f2..4fbee9db4f3c 100644 --- a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp +++ b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp @@ -20,47 +20,56 @@ #include namespace sd { - BroadcastOpsTuple BroadcastOpsTuple::custom(sd::scalar::Ops scalar, sd::pairwise::Ops pairwise, sd::broadcast::Ops broadcast) { - BroadcastOpsTuple t(scalar, pairwise, broadcast); - return t; - } - - BroadcastOpsTuple BroadcastOpsTuple::Add() { - return custom(sd::scalar::Add, sd::pairwise::Add, sd::broadcast::Add); - } - - BroadcastOpsTuple BroadcastOpsTuple::Assign() { - return custom(sd::scalar::CopyPws, sd::pairwise::CopyPws, sd::broadcast::CopyPws); - } +BroadcastOpsTuple BroadcastOpsTuple::custom(sd::scalar::Ops scalar, + sd::pairwise::Ops pairwise, + sd::broadcast::Ops broadcast) { + BroadcastOpsTuple t(scalar, pairwise, broadcast); + return t; +} - BroadcastOpsTuple BroadcastOpsTuple::Divide() { - return custom(sd::scalar::Divide, sd::pairwise::Divide, sd::broadcast::Divide); - } +BroadcastOpsTuple BroadcastOpsTuple::Add() { + return custom(sd::scalar::Add, sd::pairwise::Add, sd::broadcast::Add); +} - BroadcastOpsTuple BroadcastOpsTuple::DivideNoNan() { - return custom(sd::scalar::DivideNoNan, sd::pairwise::DivideNoNan, sd::broadcast::DivideNoNan); - } +BroadcastOpsTuple BroadcastOpsTuple::Assign() { + return custom(sd::scalar::CopyPws, sd::pairwise::CopyPws, + sd::broadcast::CopyPws); +} - BroadcastOpsTuple BroadcastOpsTuple::Multiply() { - return custom(sd::scalar::Multiply, sd::pairwise::Multiply, sd::broadcast::Multiply); - } +BroadcastOpsTuple BroadcastOpsTuple::Divide() { + return custom(sd::scalar::Divide, sd::pairwise::Divide, + sd::broadcast::Divide); +} - BroadcastOpsTuple BroadcastOpsTuple::Subtract() { - return custom(sd::scalar::Subtract, sd::pairwise::Subtract, sd::broadcast::Subtract); - } - BroadcastOpsTuple BroadcastOpsTuple::IGamma() { - return custom(sd::scalar::IGamma, sd::pairwise::IGamma, sd::broadcast::IGamma); - } - BroadcastOpsTuple BroadcastOpsTuple::IGammac() { - return custom(sd::scalar::IGammac, sd::pairwise::IGammac, sd::broadcast::IGammac); - } +BroadcastOpsTuple BroadcastOpsTuple::DivideNoNan() { + return custom(sd::scalar::DivideNoNan, sd::pairwise::DivideNoNan, + sd::broadcast::DivideNoNan); +} +BroadcastOpsTuple BroadcastOpsTuple::Multiply() { + return custom(sd::scalar::Multiply, sd::pairwise::Multiply, + sd::broadcast::Multiply); +} - BroadcastOpsTuple BroadcastOpsTuple::Pow() { - return custom(sd::scalar::Pow, sd::pairwise::Pow, sd::broadcast::Pow); - } - BroadcastOpsTuple BroadcastOpsTuple::PowDerivative() { - return custom(sd::scalar::PowDerivative, sd::pairwise::PowDerivative, sd::broadcast::PowDerivative); - } +BroadcastOpsTuple BroadcastOpsTuple::Subtract() { + return custom(sd::scalar::Subtract, sd::pairwise::Subtract, + sd::broadcast::Subtract); +} +BroadcastOpsTuple BroadcastOpsTuple::IGamma() { + return custom(sd::scalar::IGamma, sd::pairwise::IGamma, + sd::broadcast::IGamma); +} +BroadcastOpsTuple BroadcastOpsTuple::IGammac() { + return custom(sd::scalar::IGammac, sd::pairwise::IGammac, + sd::broadcast::IGammac); +} +BroadcastOpsTuple BroadcastOpsTuple::Pow() { + return custom(sd::scalar::Pow, sd::pairwise::Pow, sd::broadcast::Pow); +} +BroadcastOpsTuple BroadcastOpsTuple::PowDerivative() { + return custom(sd::scalar::PowDerivative, sd::pairwise::PowDerivative, + sd::broadcast::PowDerivative); } + +} // namespace sd diff --git a/libnd4j/include/ops/impl/compilation_units/specials_double_0.cpp b/libnd4j/include/ops/impl/compilation_units/specials_double_0.cpp index e9d262f58944..36d551ab6180 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_double_0.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_double_0.cpp @@ -22,7 +22,10 @@ #include "../specials_double.hpp" namespace sd { - BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, LIBND4J_TYPES_0); +BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, + LIBND4J_TYPES_0); - BUILD_DOUBLE_TEMPLATE(template void SpecialTypeConverter::convertGeneric, (Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz), LIBND4J_TYPES, LIBND4J_TYPES); -} \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void SpecialTypeConverter::convertGeneric, + (Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz), + LIBND4J_TYPES, LIBND4J_TYPES); +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_double_1.cpp b/libnd4j/include/ops/impl/compilation_units/specials_double_1.cpp index a61a98870525..875c1a6106fc 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_double_1.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_double_1.cpp @@ -22,5 +22,6 @@ #include "../specials_double.hpp" namespace sd { - BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, LIBND4J_TYPES_1); +BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, + LIBND4J_TYPES_1); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_double_2.cpp b/libnd4j/include/ops/impl/compilation_units/specials_double_2.cpp index 89deb3d9c3cc..d928d6926bb9 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_double_2.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_double_2.cpp @@ -22,5 +22,6 @@ #include "../specials_double.hpp" namespace sd { - BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, LIBND4J_TYPES_2); +BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, + LIBND4J_TYPES_2); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_double_3.cpp b/libnd4j/include/ops/impl/compilation_units/specials_double_3.cpp index 7690749bf4ad..e6b2529368f1 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_double_3.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_double_3.cpp @@ -22,5 +22,6 @@ #include "../specials_double.hpp" namespace sd { - BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, LIBND4J_TYPES_3); +BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, + LIBND4J_TYPES_3); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_double_4.cpp b/libnd4j/include/ops/impl/compilation_units/specials_double_4.cpp index 505ea99214e5..c54518b9b677 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_double_4.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_double_4.cpp @@ -22,5 +22,6 @@ #include "../specials_double.hpp" namespace sd { - BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, LIBND4J_TYPES_4); +BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, + LIBND4J_TYPES_4); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_double_5.cpp b/libnd4j/include/ops/impl/compilation_units/specials_double_5.cpp index caa9d2dfa955..4763d9e31cbc 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_double_5.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_double_5.cpp @@ -22,5 +22,6 @@ #include "../specials_double.hpp" namespace sd { - BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, LIBND4J_TYPES_5); +BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, + LIBND4J_TYPES_5); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_double_6.cpp b/libnd4j/include/ops/impl/compilation_units/specials_double_6.cpp index 9646534a9799..c025ff0295ef 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_double_6.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_double_6.cpp @@ -22,5 +22,6 @@ #include "../specials_double.hpp" namespace sd { - BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, LIBND4J_TYPES_6); +BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, + LIBND4J_TYPES_6); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_double_7.cpp b/libnd4j/include/ops/impl/compilation_units/specials_double_7.cpp index 3230c1fbc63d..be9bfaea33a8 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_double_7.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_double_7.cpp @@ -22,5 +22,6 @@ #include "../specials_double.hpp" namespace sd { - BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, LIBND4J_TYPES_7); +BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, + LIBND4J_TYPES_7); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_double_8.cpp b/libnd4j/include/ops/impl/compilation_units/specials_double_8.cpp index a56b335b62a2..01201b49752d 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_double_8.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_double_8.cpp @@ -22,5 +22,6 @@ #include "../specials_double.hpp" namespace sd { - BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, LIBND4J_TYPES_8); +BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, + LIBND4J_TYPES_8); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_double_9.cpp b/libnd4j/include/ops/impl/compilation_units/specials_double_9.cpp index bb13c0415b46..c06c4106d770 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_double_9.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_double_9.cpp @@ -22,5 +22,6 @@ #include "../specials_double.hpp" namespace sd { - BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, LIBND4J_TYPES_9); +BUILD_DOUBLE_TEMPLATE(template class DoubleMethods, , LIBND4J_TYPES, + LIBND4J_TYPES_9); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_single_0.cpp b/libnd4j/include/ops/impl/compilation_units/specials_single_0.cpp index f74717f05f44..fe5caa77e489 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_single_0.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_single_0.cpp @@ -22,5 +22,5 @@ #include "../specials_single.hpp" namespace sd { - BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_0); +BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_0); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_single_1.cpp b/libnd4j/include/ops/impl/compilation_units/specials_single_1.cpp index cbacbb60ece1..e2b169700c77 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_single_1.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_single_1.cpp @@ -22,5 +22,5 @@ #include "../specials_single.hpp" namespace sd { - BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_1); +BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_1); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_single_2.cpp b/libnd4j/include/ops/impl/compilation_units/specials_single_2.cpp index b1c7c0db6501..11194acda55a 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_single_2.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_single_2.cpp @@ -22,5 +22,5 @@ #include "../specials_single.hpp" namespace sd { - BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_2); +BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_2); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_single_3.cpp b/libnd4j/include/ops/impl/compilation_units/specials_single_3.cpp index d340500e548f..9270211983cd 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_single_3.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_single_3.cpp @@ -22,5 +22,5 @@ #include "../specials_single.hpp" namespace sd { - BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_3); +BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_3); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_single_4.cpp b/libnd4j/include/ops/impl/compilation_units/specials_single_4.cpp index b8ea2a93357d..80570febd0f3 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_single_4.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_single_4.cpp @@ -22,5 +22,5 @@ #include "../specials_single.hpp" namespace sd { - BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_4); +BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_4); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_single_5.cpp b/libnd4j/include/ops/impl/compilation_units/specials_single_5.cpp index cc3fe3f0bbae..357fd39e8a23 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_single_5.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_single_5.cpp @@ -22,5 +22,5 @@ #include "../specials_single.hpp" namespace sd { - BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_5); +BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_5); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_single_6.cpp b/libnd4j/include/ops/impl/compilation_units/specials_single_6.cpp index 4e0b96a82e8b..fa9f037e6f35 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_single_6.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_single_6.cpp @@ -22,5 +22,5 @@ #include "../specials_single.hpp" namespace sd { - BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_6); +BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_6); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_single_7.cpp b/libnd4j/include/ops/impl/compilation_units/specials_single_7.cpp index e8bd8d950363..722272273a61 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_single_7.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_single_7.cpp @@ -22,5 +22,5 @@ #include "../specials_single.hpp" namespace sd { - BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_7); +BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_7); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_single_8.cpp b/libnd4j/include/ops/impl/compilation_units/specials_single_8.cpp index b2581352ef28..f18604f925cb 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_single_8.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_single_8.cpp @@ -22,5 +22,5 @@ #include "../specials_single.hpp" namespace sd { - BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_8); +BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_8); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/compilation_units/specials_single_9.cpp b/libnd4j/include/ops/impl/compilation_units/specials_single_9.cpp index 5105affa87ba..999ef9fee566 100644 --- a/libnd4j/include/ops/impl/compilation_units/specials_single_9.cpp +++ b/libnd4j/include/ops/impl/compilation_units/specials_single_9.cpp @@ -22,5 +22,5 @@ #include "../specials_single.hpp" namespace sd { - BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_9); +BUILD_SINGLE_TEMPLATE(template class SpecialMethods, , LIBND4J_TYPES_9); } \ No newline at end of file diff --git a/libnd4j/include/ops/impl/gemm.cpp b/libnd4j/include/ops/impl/gemm.cpp index 0c4ab167ce23..f8cb31a95a45 100644 --- a/libnd4j/include/ops/impl/gemm.cpp +++ b/libnd4j/include/ops/impl/gemm.cpp @@ -19,132 +19,128 @@ // Modified by GS on 3/9/2018 // +#include #include -#include #include -#include +#include namespace sd { - namespace blas { - - template - void* transpose(int orderSource, int orderTarget, int rows, int cols, void *vsource) { - auto ret = new T[rows * cols]; - auto source = reinterpret_cast(vsource); - - // handle transpose in parallel - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) { - for (int c = 0; c < cols; c++) { - int zIdx = orderTarget == CblasRowMajor ? linearIndexC(rows, cols, r, c) : linearIndexF(rows, cols, r, c); - int xIdx = orderSource == CblasColMajor ? linearIndexF(rows, cols, r, c) : linearIndexC(rows, cols, r, c); +namespace blas { + +template +void *transpose(int orderSource, int orderTarget, int rows, int cols, + void *vsource) { + auto ret = new T[rows * cols]; + auto source = reinterpret_cast(vsource); + + // handle transpose in parallel + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) { + for (int c = 0; c < cols; c++) { + int zIdx = orderTarget == CblasRowMajor + ? linearIndexC(rows, cols, r, c) + : linearIndexF(rows, cols, r, c); + int xIdx = orderSource == CblasColMajor + ? linearIndexF(rows, cols, r, c) + : linearIndexC(rows, cols, r, c); + + ret[zIdx] = source[xIdx]; + } + } + }; - ret[zIdx] = source[xIdx]; - } - } - }; + samediff::Threads::parallel_for(func, 0, rows); - samediff::Threads::parallel_for(func, 0, rows); + return ret; +} - return ret; +template +void GEMM::op(int Order, int TransA, int TransB, int M, int N, int K, + double alpha, void *vA, int lda, void *vB, int ldb, + double beta, void *vC, int ldc) { + auto A = reinterpret_cast(vA); + auto B = reinterpret_cast(vB); + auto C = reinterpret_cast(vC); + + bool transAFlag = TransA == CblasTrans; + bool transBFlag = TransB == CblasTrans; + + if (beta == 0.0) { + Z z = 0.f; + int length = M * N; + if (length <= Environment::getInstance()->elementwiseThreshold()) { + for (int r = 0; r < length; r++) C[r] = z; + } else { + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) C[r] = z; + }; + samediff::Threads::parallel_for(func, 0, length); + } + } + + auto func = PRAGMA_THREADS_FOR_2D { + for (auto r = start_x; r < stop_x; r += inc_x) { + for (auto c = start_y; c < stop_y; c += inc_y) { + int zIdx = linearIndexF(M, N, r, c); + + Z dot = static_cast(0.0f); + + if (alpha != 0.0) { + int bIdx; // = linearIndexF(K, N, 0, c); + int aIdx; + + for (int k = 0; k < K; k++) { + aIdx = (transAFlag ? linearIndexC(M, K, r, k) + : linearIndexF(M, K, r, k)); + bIdx = (transBFlag ? linearIndexC(K, N, k, c) + : linearIndexF(K, N, k, c)); + dot += static_cast(alpha) * static_cast(A[aIdx]) * + static_cast(B[bIdx]); // A[aIdx]sd::math::nd4j_dot(aX, + // bX, K) * alpha; + } } - template - void GEMM::op(int Order, int TransA, int TransB, - int M, int N, int K, - double alpha, - void *vA, int lda, - void *vB, int ldb, - double beta, - void *vC, int ldc) { - - auto A = reinterpret_cast(vA); - auto B = reinterpret_cast(vB); - auto C = reinterpret_cast(vC); - - bool transAFlag = TransA == CblasTrans; - bool transBFlag = TransB == CblasTrans; - - if (beta == 0.0) { - Z z = 0.f; - int length = M*N; - if (length <= Environment::getInstance()->elementwiseThreshold()) { - for (int r = 0; r < length; r++) - C[r] = z; - } else { - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) - C[r] = z; - }; - samediff::Threads::parallel_for(func, 0, length); - } - } - - - auto func = PRAGMA_THREADS_FOR_2D { - for (auto r = start_x; r < stop_x; r += inc_x) { - for (auto c = start_y; c < stop_y; c += inc_y) { - int zIdx = linearIndexF(M, N, r, c); - - Z dot = static_cast(0.0f); - - if (alpha != 0.0) { - int bIdx; // = linearIndexF(K, N, 0, c); - int aIdx; - - for (int k = 0; k < K; k++) { - aIdx = (transAFlag ? linearIndexC(M, K, r, k) : linearIndexF(M, K, r, k)); - bIdx = (transBFlag ? linearIndexC(K, N, k, c) : linearIndexF(K, N, k, c)); - dot += static_cast(alpha) * static_cast(A[aIdx]) * static_cast(B[bIdx]);//A[aIdx]sd::math::nd4j_dot(aX, bX, K) * alpha; - } - } - - if (beta != 0.0) { - C[zIdx] = static_cast(dot + static_cast(beta) * C[zIdx]); - } else { - C[zIdx] = static_cast(dot); - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, M, 1, 0, N, 1); + if (beta != 0.0) { + C[zIdx] = static_cast(dot + static_cast(beta) * C[zIdx]); + } else { + C[zIdx] = static_cast(dot); } + } + } + }; + samediff::Threads::parallel_for(func, 0, M, 1, 0, N, 1); +} - template - void GEMV::op(int TRANS, int M, int N, - double alpha, - void * vX, - int lda, - void* vY, - int incx, - double beta, - void* vZ, - int incy ) { - - auto x = reinterpret_cast(vX); - auto y = reinterpret_cast(vY); - auto z = reinterpret_cast(vZ); - - auto aT = TRANS == CblasTrans ? reinterpret_cast(sd::blas::transpose(CblasColMajor, CblasRowMajor, M, N, reinterpret_cast(x))) : x; - - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) { - int aIdx = linearIndexC(M, N, r, 0); - auto aX = aT + aIdx; - - auto dot = sd::math::nd4j_dot(aX, y, lda) * static_cast(alpha); - z[r] = beta == 0.0f ? dot : dot + static_cast(beta) * z[r]; - } - }; - samediff::Threads::parallel_for(func, 0, M); - - if (TRANS == CblasTrans) - delete[] aT; - } - - //BUILD_TRIPLE_TEMPLATE(template class GEMV, , LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); - //BUILD_TRIPLE_TEMPLATE(template class GEMM, , LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +template +void GEMV::op(int TRANS, int M, int N, double alpha, void *vX, int lda, + void *vY, int incx, double beta, void *vZ, int incy) { + auto x = reinterpret_cast(vX); + auto y = reinterpret_cast(vY); + auto z = reinterpret_cast(vZ); + + auto aT = TRANS == CblasTrans ? reinterpret_cast(sd::blas::transpose( + CblasColMajor, CblasRowMajor, M, N, + reinterpret_cast(x))) + : x; + + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) { + int aIdx = linearIndexC(M, N, r, 0); + auto aX = aT + aIdx; + + auto dot = + sd::math::nd4j_dot(aX, y, lda) * static_cast(alpha); + z[r] = beta == 0.0f ? dot : dot + static_cast(beta) * z[r]; } + }; + samediff::Threads::parallel_for(func, 0, M); + + if (TRANS == CblasTrans) delete[] aT; } + +// BUILD_TRIPLE_TEMPLATE(template class GEMV, , LIBND4J_TYPES, FLOAT_TYPES, +// FLOAT_TYPES); BUILD_TRIPLE_TEMPLATE(template class GEMM, , LIBND4J_TYPES, +// FLOAT_TYPES, FLOAT_TYPES); +} // namespace blas +} // namespace sd diff --git a/libnd4j/include/ops/impl/specials_double.hpp b/libnd4j/include/ops/impl/specials_double.hpp index 1eaf3fbc0df7..9aa8cd65f3fd 100644 --- a/libnd4j/include/ops/impl/specials_double.hpp +++ b/libnd4j/include/ops/impl/specials_double.hpp @@ -19,252 +19,296 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // - -#include -#include +#include +#include #include +#include +#include #include #include -#include -#include +#include #include -#include namespace sd { +template +void SpecialTypeConverter::convertGeneric(Nd4jPointer *extras, void *dx, + Nd4jLong N, void *dz) { + auto x = reinterpret_cast(dx); + auto z = reinterpret_cast(dz); - template - void SpecialTypeConverter::convertGeneric(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz) { - auto x = reinterpret_cast(dx); - auto z = reinterpret_cast(dz); - - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - z[i] = static_cast(x[i]); - } - }; - - samediff::Threads::parallel_for(func, 0, N); - }; - - - template - void quickSort_parallel_internal_key(X* key, Nd4jLong const* xShapeInfo, Y* values, Nd4jLong const* yShapeInfo, int left, int right, int cutoff, bool descending) { - int i = left, j = right; - X ktmp; - X pivot = key[shape::getIndexOffset((left + right) / 2, xShapeInfo)]; - - Y vtmp; - - { - /* PARTITION PART */ - while (i <= j) { - if (descending) { - while (key[shape::getIndexOffset(i, xShapeInfo)] > pivot) - i++; - while (key[shape::getIndexOffset(j, xShapeInfo)] < pivot) - j--; - if (i <= j) { - ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; - key[shape::getIndexOffset(i, xShapeInfo)] = key[shape::getIndexOffset(j, xShapeInfo)]; - key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; - - vtmp = values[shape::getIndexOffset(i, yShapeInfo)]; - values[shape::getIndexOffset(i, yShapeInfo)] = values[shape::getIndexOffset(j, yShapeInfo)]; - values[shape::getIndexOffset(j, yShapeInfo)] = vtmp; - - i++; - j--; - } - } else { - while (key[shape::getIndexOffset(i, xShapeInfo)] < pivot) - i++; - while (key[shape::getIndexOffset(j, xShapeInfo)] > pivot) - j--; - if (i <= j) { - ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; - key[shape::getIndexOffset(i, xShapeInfo)] = key[shape::getIndexOffset(j, xShapeInfo)]; - key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; - - vtmp = values[shape::getIndexOffset(i, yShapeInfo)]; - values[shape::getIndexOffset(i, yShapeInfo)] = values[shape::getIndexOffset(j, yShapeInfo)]; - values[shape::getIndexOffset(j, yShapeInfo)] = vtmp; - - i++; - j--; - } - } - } - + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + z[i] = static_cast(x[i]); + } + }; + + samediff::Threads::parallel_for(func, 0, N); +}; + +template +void quickSort_parallel_internal_key(X *key, Nd4jLong const *xShapeInfo, + Y *values, Nd4jLong const *yShapeInfo, + int left, int right, int cutoff, + bool descending) { + int i = left, j = right; + X ktmp; + X pivot = key[shape::getIndexOffset((left + right) / 2, xShapeInfo)]; + + Y vtmp; + + { + /* PARTITION PART */ + while (i <= j) { + if (descending) { + while (key[shape::getIndexOffset(i, xShapeInfo)] > pivot) i++; + while (key[shape::getIndexOffset(j, xShapeInfo)] < pivot) j--; + if (i <= j) { + ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; + key[shape::getIndexOffset(i, xShapeInfo)] = + key[shape::getIndexOffset(j, xShapeInfo)]; + key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; + + vtmp = values[shape::getIndexOffset(i, yShapeInfo)]; + values[shape::getIndexOffset(i, yShapeInfo)] = + values[shape::getIndexOffset(j, yShapeInfo)]; + values[shape::getIndexOffset(j, yShapeInfo)] = vtmp; + + i++; + j--; } - - // - - if ( ((right-left) pivot) j--; + if (i <= j) { + ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; + key[shape::getIndexOffset(i, xShapeInfo)] = + key[shape::getIndexOffset(j, xShapeInfo)]; + key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; + + vtmp = values[shape::getIndexOffset(i, yShapeInfo)]; + values[shape::getIndexOffset(i, yShapeInfo)] = + values[shape::getIndexOffset(j, yShapeInfo)]; + values[shape::getIndexOffset(j, yShapeInfo)] = vtmp; + + i++; + j--; } + } } + } + // - template - void quickSort_parallel_internal_value(X* key, Nd4jLong const* xShapeInfo, Y* value, Nd4jLong const* yShapeInfo, int left, int right, int cutoff, bool descending) { - int i = left, j = right; - X ktmp; - Y pivot = value[shape::getIndexOffset((left + right) / 2, yShapeInfo)]; - - Y vtmp; - - { - /* PARTITION PART */ - while (i <= j) { - if (descending) { - while (value[shape::getIndexOffset(i, yShapeInfo)] > pivot) - i++; - while (value[shape::getIndexOffset(j, yShapeInfo)] < pivot) - j--; - if (i <= j) { - ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; - key[shape::getIndexOffset(i, xShapeInfo)] = key[shape::getIndexOffset(j, xShapeInfo)]; - key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; - - vtmp = value[shape::getIndexOffset(i, yShapeInfo)]; - value[shape::getIndexOffset(i, yShapeInfo)] = value[shape::getIndexOffset(j, yShapeInfo)]; - value[shape::getIndexOffset(j, yShapeInfo)] = vtmp; - - i++; - j--; - } - } else { - while (value[shape::getIndexOffset(i, yShapeInfo)] < pivot) - i++; - while (value[shape::getIndexOffset(j, yShapeInfo)] > pivot) - j--; - if (i <= j) { - ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; - key[shape::getIndexOffset(i, xShapeInfo)] = key[shape::getIndexOffset(j, xShapeInfo)]; - key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; - - vtmp = value[shape::getIndexOffset(i, yShapeInfo)]; - value[shape::getIndexOffset(i, yShapeInfo)] = value[shape::getIndexOffset(j, yShapeInfo)]; - value[shape::getIndexOffset(j, yShapeInfo)] = vtmp; - - i++; - j--; - } - } - } - - } - - // + if (((right - left) < cutoff)) { + if (left < j) { + quickSort_parallel_internal_key(key, xShapeInfo, values, yShapeInfo, left, + j, cutoff, descending); + } + if (i < right) { + quickSort_parallel_internal_key(key, xShapeInfo, values, yShapeInfo, i, + right, cutoff, descending); + } - if ( ((right-left) +void quickSort_parallel_internal_value(X *key, Nd4jLong const *xShapeInfo, + Y *value, Nd4jLong const *yShapeInfo, + int left, int right, int cutoff, + bool descending) { + int i = left, j = right; + X ktmp; + Y pivot = value[shape::getIndexOffset((left + right) / 2, yShapeInfo)]; + + Y vtmp; + + { + /* PARTITION PART */ + while (i <= j) { + if (descending) { + while (value[shape::getIndexOffset(i, yShapeInfo)] > pivot) i++; + while (value[shape::getIndexOffset(j, yShapeInfo)] < pivot) j--; + if (i <= j) { + ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; + key[shape::getIndexOffset(i, xShapeInfo)] = + key[shape::getIndexOffset(j, xShapeInfo)]; + key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; + + vtmp = value[shape::getIndexOffset(i, yShapeInfo)]; + value[shape::getIndexOffset(i, yShapeInfo)] = + value[shape::getIndexOffset(j, yShapeInfo)]; + value[shape::getIndexOffset(j, yShapeInfo)] = vtmp; + + i++; + j--; + } + } else { + while (value[shape::getIndexOffset(i, yShapeInfo)] < pivot) i++; + while (value[shape::getIndexOffset(j, yShapeInfo)] > pivot) j--; + if (i <= j) { + ktmp = key[shape::getIndexOffset(i, xShapeInfo)]; + key[shape::getIndexOffset(i, xShapeInfo)] = + key[shape::getIndexOffset(j, xShapeInfo)]; + key[shape::getIndexOffset(j, xShapeInfo)] = ktmp; + + vtmp = value[shape::getIndexOffset(i, yShapeInfo)]; + value[shape::getIndexOffset(i, yShapeInfo)] = + value[shape::getIndexOffset(j, yShapeInfo)]; + value[shape::getIndexOffset(j, yShapeInfo)] = vtmp; + + i++; + j--; } + } } + } + // - template - static void quickSort_parallel_key(void *varray, Nd4jLong const* xShapeInfo, void *yarray, Nd4jLong const* yShapeInfo, Nd4jLong lenArray, int numThreads, bool descending){ - auto array = reinterpret_cast(varray); - auto values = reinterpret_cast(yarray); - int cutoff = 1000; - - PRAGMA_OMP_PARALLEL_THREADS(numThreads) - { -PRAGMA_OMP_SINGLE_ARGS(nowait) - { - quickSort_parallel_internal_key(array, xShapeInfo, values, yShapeInfo, 0, lenArray-1, cutoff, descending); - } - } + if (((right - left) < cutoff)) { + if (left < j) { + quickSort_parallel_internal_value(key, xShapeInfo, value, yShapeInfo, + left, j, cutoff, descending); } - - template - static void quickSort_parallel_value(void *varray, Nd4jLong const* xShapeInfo, void *yarray, Nd4jLong const* yShapeInfo, Nd4jLong lenArray, int numThreads, bool descending){ - auto array = reinterpret_cast(varray); - auto values = reinterpret_cast(yarray); - int cutoff = 1000; - - PRAGMA_OMP_PARALLEL_THREADS(numThreads) - { -PRAGMA_OMP_SINGLE_ARGS(nowait) - { - quickSort_parallel_internal_value(array, xShapeInfo, values, yShapeInfo, 0, lenArray-1, cutoff, descending); - } - } + if (i < right) { + quickSort_parallel_internal_value(key, xShapeInfo, value, yShapeInfo, i, + right, cutoff, descending); } - template - void DoubleMethods::sortByKey(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, bool descending) { - quickSort_parallel_key(vx, xShapeInfo, vy, yShapeInfo, shape::length(xShapeInfo), omp_get_max_threads(), descending); + } else { + PRAGMA_OMP_TASK { + quickSort_parallel_internal_value(key, xShapeInfo, value, yShapeInfo, + left, j, cutoff, descending); } - - template - void DoubleMethods::sortByValue(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, bool descending) { - quickSort_parallel_value(vx, xShapeInfo, vy, yShapeInfo, shape::length(xShapeInfo), omp_get_max_threads(), descending); + PRAGMA_OMP_TASK { + quickSort_parallel_internal_value(key, xShapeInfo, value, yShapeInfo, i, + right, cutoff, descending); } + } +} - template - void DoubleMethods::sortTadByKey(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int *dimension, int dimensionLength, bool descending) { - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - auto packY = ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength); - - auto xLength = shape::length(xShapeInfo); - auto xTadLength = shape::length(packX.primaryShapeInfo()); - auto numTads = packX.numberOfTads(); - - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) { - auto dx = x + packX.primaryOffsets()[r]; - auto dy = y + packY.primaryOffsets()[r]; - - quickSort_parallel_key(dx, packX.primaryShapeInfo(), dy, packY.primaryShapeInfo(), xTadLength, 1, descending); - } - }; - - samediff::Threads::parallel_tad(func, 0, numTads); +template +static void quickSort_parallel_key(void *varray, Nd4jLong const *xShapeInfo, + void *yarray, Nd4jLong const *yShapeInfo, + Nd4jLong lenArray, int numThreads, + bool descending) { + auto array = reinterpret_cast(varray); + auto values = reinterpret_cast(yarray); + int cutoff = 1000; + + PRAGMA_OMP_PARALLEL_THREADS(numThreads) { + PRAGMA_OMP_SINGLE_ARGS(nowait) { + quickSort_parallel_internal_key(array, xShapeInfo, values, yShapeInfo, 0, + lenArray - 1, cutoff, descending); } + } +} - template - void DoubleMethods::sortTadByValue(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int *dimension, int dimensionLength, bool descending) { - auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); +template +static void quickSort_parallel_value(void *varray, Nd4jLong const *xShapeInfo, + void *yarray, Nd4jLong const *yShapeInfo, + Nd4jLong lenArray, int numThreads, + bool descending) { + auto array = reinterpret_cast(varray); + auto values = reinterpret_cast(yarray); + int cutoff = 1000; + + PRAGMA_OMP_PARALLEL_THREADS(numThreads) { + PRAGMA_OMP_SINGLE_ARGS(nowait) { + quickSort_parallel_internal_value(array, xShapeInfo, values, yShapeInfo, + 0, lenArray - 1, cutoff, descending); + } + } +} - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - auto packY = ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength); +template +void DoubleMethods::sortByKey(void *vx, Nd4jLong const *xShapeInfo, + void *vy, Nd4jLong const *yShapeInfo, + bool descending) { + quickSort_parallel_key(vx, xShapeInfo, vy, yShapeInfo, + shape::length(xShapeInfo), omp_get_max_threads(), + descending); +} - auto xLength = shape::length(xShapeInfo); - auto xTadLength = shape::length(packX.primaryShapeInfo()); - auto numTads = packX.numberOfTads(); +template +void DoubleMethods::sortByValue(void *vx, Nd4jLong const *xShapeInfo, + void *vy, Nd4jLong const *yShapeInfo, + bool descending) { + quickSort_parallel_value(vx, xShapeInfo, vy, yShapeInfo, + shape::length(xShapeInfo), + omp_get_max_threads(), descending); +} - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) { - auto dx = x + packX.primaryOffsets()[r]; - auto dy = y + packY.primaryOffsets()[r]; +template +void DoubleMethods::sortTadByKey(void *vx, Nd4jLong const *xShapeInfo, + void *vy, Nd4jLong const *yShapeInfo, + int *dimension, int dimensionLength, + bool descending) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + + auto packX = ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dimension, dimensionLength); + auto packY = ConstantTadHelper::getInstance()->tadForDimensions( + yShapeInfo, dimension, dimensionLength); + + auto xLength = shape::length(xShapeInfo); + auto xTadLength = shape::length(packX.primaryShapeInfo()); + auto numTads = packX.numberOfTads(); + + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) { + auto dx = x + packX.primaryOffsets()[r]; + auto dy = y + packY.primaryOffsets()[r]; + + quickSort_parallel_key(dx, packX.primaryShapeInfo(), dy, + packY.primaryShapeInfo(), xTadLength, 1, + descending); + } + }; - quickSort_parallel_value(dx, packX.primaryShapeInfo(), dy, packY.primaryShapeInfo(), xTadLength, 1, descending); - } - }; + samediff::Threads::parallel_tad(func, 0, numTads); +} - samediff::Threads::parallel_tad(func, 0, numTads); +template +void DoubleMethods::sortTadByValue(void *vx, Nd4jLong const *xShapeInfo, + void *vy, Nd4jLong const *yShapeInfo, + int *dimension, int dimensionLength, + bool descending) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + + auto packX = ConstantTadHelper::getInstance()->tadForDimensions( + xShapeInfo, dimension, dimensionLength); + auto packY = ConstantTadHelper::getInstance()->tadForDimensions( + yShapeInfo, dimension, dimensionLength); + + auto xLength = shape::length(xShapeInfo); + auto xTadLength = shape::length(packX.primaryShapeInfo()); + auto numTads = packX.numberOfTads(); + + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) { + auto dx = x + packX.primaryOffsets()[r]; + auto dy = y + packY.primaryOffsets()[r]; + + quickSort_parallel_value(dx, packX.primaryShapeInfo(), dy, + packY.primaryShapeInfo(), xTadLength, 1, + descending); } -} + }; + samediff::Threads::parallel_tad(func, 0, numTads); +} +} // namespace sd diff --git a/libnd4j/include/ops/impl/specials_single.hpp b/libnd4j/include/ops/impl/specials_single.hpp index 9a700251c555..892ffce48bf4 100644 --- a/libnd4j/include/ops/impl/specials_single.hpp +++ b/libnd4j/include/ops/impl/specials_single.hpp @@ -19,24 +19,24 @@ // @author Yurii Shyrma (iuriish@yahoo.com) // - -#include -#include +#include +#include #include +#include +#include #include #include -#include -#include +#include #include -#include namespace sd { /** -* Concatneate multi array of the same shape together -* along a particular dimension -*/ + * Concatneate multi array of the same shape together + * along a particular dimension + */ // template -// void SpecialMethods::concatCpuGeneric(const std::vector& inArrs, NDArray& output, const int axis) { +// void SpecialMethods::concatCpuGeneric(const std::vector& inArrs, +// NDArray& output, const int axis) { // const uint numOfArrs = inArrs.size(); // int outDim; @@ -45,17 +45,20 @@ namespace sd { // if(isOutputVector || (axis == 0 && output.ordering() == 'c')) { // bool allVectorsOrScalars = true; -// const uint outEws = isOutputVector ? output.stridesOf()[outDim] : output.ews(); +// const uint outEws = isOutputVector ? output.stridesOf()[outDim] : +// output.ews(); // std::vector nonUnityDim(numOfArrs); // std::vector zOffset(numOfArrs); // for(int i = 0; i < numOfArrs; i++) { -// allVectorsOrScalars &= (inArrs[i]->lengthOf() == 1 || inArrs[i]->isCommonVector(nonUnityDim[i])); +// allVectorsOrScalars &= (inArrs[i]->lengthOf() == 1 || +// inArrs[i]->isCommonVector(nonUnityDim[i])); // if(!allVectorsOrScalars) // break; // if(i == 0) zOffset[0] = 0; -// else zOffset[i] = zOffset[i - 1] + outEws * inArrs[i - 1]->lengthOf(); +// else zOffset[i] = zOffset[i - 1] + outEws * inArrs[i - +// 1]->lengthOf(); // } // if(allVectorsOrScalars) { @@ -65,7 +68,8 @@ namespace sd { // auto func = PRAGMA_THREADS_FOR { // for (auto r = start; r < stop; r += increment) { // const Nd4jLong arrLen = inArrs[r]->lengthOf(); -// const uint xEws = (arrLen == 1) ? 1 : inArrs[r]->stridesOf()[nonUnityDim[r]]; +// const uint xEws = (arrLen == 1) ? 1 : +// inArrs[r]->stridesOf()[nonUnityDim[r]]; // T *z = outBuff + zOffset[r]; // T *x = inArrs[r]->bufferAsT(); @@ -86,21 +90,26 @@ namespace sd { // const int rank = inArrs[0]->rankOf(); // const int rank2 = 2*rank; -// std::vector> indices(numOfArrs, std::vector(rank2,0)); +// std::vector> indices(numOfArrs, +// std::vector(rank2,0)); // // take into account indices for first array // indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis); // // loop through the rest of input arrays // for(int i = 1; i < numOfArrs; ++i) { -// indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from -// indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis); // index end with (excluding) +// indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index +// start from indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] +// + inArrs[i]->sizeAt(axis); // index end with (excluding) // } // auto func = PRAGMA_THREADS_FOR { // for (auto i = start; i < stop; i += increment) { // auto temp = output(indices[i], true); -// sd::TransformLoops::template loopTransform>( inArrs[i]->bufferAsT(), inArrs[i]->shapeInfo(), temp.bufferAsT(), temp.shapeInfo(), nullptr, 0, 1); +// sd::TransformLoops::template +// loopTransform>( +// inArrs[i]->bufferAsT(), inArrs[i]->shapeInfo(), +// temp.bufferAsT(), temp.shapeInfo(), nullptr, 0, 1); // } // }; @@ -108,227 +117,238 @@ namespace sd { // } template -void SpecialMethods::concatCpuGeneric(const std::vector& inArrs, NDArray& output, const int axis) { - - const int numOfInArrs = inArrs.size(); - const auto sizeofT = output.sizeOfT(); - - T* zBuff = output.bufferAsT(); - - bool luckCase1 = ((axis == 0 && output.ordering() == 'c') || (axis == output.rankOf() - 1 && output.ordering() == 'f')) && output.ews() == 1; - - if(luckCase1) { - for (uint i = 0; i < numOfInArrs; ++i) { - luckCase1 &= inArrs[i]->ordering() == output.ordering() && inArrs[i]->ews() == 1; - if(!luckCase1) - break; - } +void SpecialMethods::concatCpuGeneric( + const std::vector &inArrs, NDArray &output, + const int axis) { + const int numOfInArrs = inArrs.size(); + const auto sizeofT = output.sizeOfT(); + + T *zBuff = output.bufferAsT(); + + bool luckCase1 = + ((axis == 0 && output.ordering() == 'c') || + (axis == output.rankOf() - 1 && output.ordering() == 'f')) && + output.ews() == 1; + + if (luckCase1) { + for (uint i = 0; i < numOfInArrs; ++i) { + luckCase1 &= + inArrs[i]->ordering() == output.ordering() && inArrs[i]->ews() == 1; + if (!luckCase1) break; } + } - if(luckCase1) { // for example {1,10} + {2,10} + {3,10} = {6, 10} order c; or {10,1} + {10,2} + {10,3} = {10, 6} order f + if (luckCase1) { // for example {1,10} + {2,10} + {3,10} = {6, 10} order c; + // or {10,1} + {10,2} + {10,3} = {10, 6} order f - T* z = zBuff; - for (uint i = 0; i < numOfInArrs; ++i) { - const auto memAmountToCopy = inArrs[i]->lengthOf(); - memcpy(z, inArrs[i]->bufferAsT(), memAmountToCopy * sizeofT); - z += memAmountToCopy; - } - return; + T *z = zBuff; + for (uint i = 0; i < numOfInArrs; ++i) { + const auto memAmountToCopy = inArrs[i]->lengthOf(); + memcpy(z, inArrs[i]->bufferAsT(), memAmountToCopy * sizeofT); + z += memAmountToCopy; } + return; + } - // const bool isZcontin = output.strideAt(axis) == 1; - // bool areInputsContin = true; - // bool allSameOrder = true; - // std::vector strideOfContigStride(numOfInArrs); - - // if(isZcontin) { + // const bool isZcontin = output.strideAt(axis) == 1; + // bool areInputsContin = true; + // bool allSameOrder = true; + // std::vector strideOfContigStride(numOfInArrs); - // for (uint i = 0; i < numOfInArrs; ++i) { + // if(isZcontin) { - // areInputsContin &= inArrs[i]->strideAt(axis) == 1; - // allSameOrder &= inArrs[i]->ordering() == output.ordering(); - // if(!areInputsContin || !allSameOrder) - // break; + // for (uint i = 0; i < numOfInArrs; ++i) { - // strideOfContigStride[i] = shape::strideOverContigAxis(axis, inArrs[i]->shapeInfo()); - // } - // } + // areInputsContin &= inArrs[i]->strideAt(axis) == 1; + // allSameOrder &= inArrs[i]->ordering() == output.ordering(); + // if(!areInputsContin || !allSameOrder) + // break; - // const bool luckCase2 = isZcontin && areInputsContin && allSameOrder; + // strideOfContigStride[i] = shape::strideOverContigAxis(axis, + // inArrs[i]->shapeInfo()); + // } + // } - // if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 shoud have stride = 1 for all inputs arrays and output array + // const bool luckCase2 = isZcontin && areInputsContin && allSameOrder; - // const auto zStep = shape::strideOverContigAxis(axis, output.shapeInfo()); + // if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, + // here axis 1 shoud have stride = 1 for all inputs arrays and output array - // for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) { + // const auto zStep = shape::strideOverContigAxis(axis, + // output.shapeInfo()); - // T* z = zBuff + zStep * i; + // for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) { - // for (uint j = 0; j < inArrs.size(); ++j) { - // const auto xDim = inArrs[j]->sizeAt(axis); - // const T* x = inArrs[j]->bufferAsT() + strideOfContigStride[j] * i; - // memcpy(z, x, xDim * sizeofT); - // z += xDim; - // } - // } + // T* z = zBuff + zStep * i; - // return; - // } - - // general case - auto func = PRAGMA_THREADS_FOR { + // for (uint j = 0; j < inArrs.size(); ++j) { + // const auto xDim = inArrs[j]->sizeAt(axis); + // const T* x = inArrs[j]->bufferAsT() + + // strideOfContigStride[j] * i; memcpy(z, x, xDim * sizeofT); z += + // xDim; + // } + // } - int coords[MAX_RANK], temp; + // return; + // } - for (auto i = start; i < stop; i += increment) { + // general case + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK], temp; - shape::index2coordsCPU(start, i, output.shapeInfo(), coords); + for (auto i = start; i < stop; i += increment) { + shape::index2coordsCPU(start, i, output.shapeInfo(), coords); - const auto zOffset = shape::getOffset(output.shapeInfo(), coords); + const auto zOffset = shape::getOffset(output.shapeInfo(), coords); - uint inArrIdx = 0; - uint xDim = inArrs[inArrIdx]->sizeAt(axis); + uint inArrIdx = 0; + uint xDim = inArrs[inArrIdx]->sizeAt(axis); - temp = coords[axis]; - while (coords[axis] >= xDim) { - coords[axis] -= xDim; - xDim = inArrs[++inArrIdx]->sizeAt(axis); - } + temp = coords[axis]; + while (coords[axis] >= xDim) { + coords[axis] -= xDim; + xDim = inArrs[++inArrIdx]->sizeAt(axis); + } - const T* x = inArrs[inArrIdx]->bufferAsT(); - const auto xOffset = shape::getOffset(inArrs[inArrIdx]->shapeInfo(), coords); + const T *x = inArrs[inArrIdx]->bufferAsT(); + const auto xOffset = + shape::getOffset(inArrs[inArrIdx]->shapeInfo(), coords); - zBuff[zOffset] = x[xOffset]; + zBuff[zOffset] = x[xOffset]; - coords[axis] = temp; - } - }; + coords[axis] = temp; + } + }; - samediff::Threads::parallel_for(func, 0, output.lengthOf()); + samediff::Threads::parallel_for(func, 0, output.lengthOf()); } /** -* Concatneate multi array of the same shape together -* along a particular dimension -*/ + * Concatneate multi array of the same shape together + * along a particular dimension + */ template -void SpecialMethods::concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *vresult, Nd4jLong const* resultShapeInfo) { - auto result = reinterpret_cast(vresult); - std::vector inputs(numArrays); +void SpecialMethods::concatCpuGeneric(int dimension, int numArrays, + Nd4jPointer *data, + Nd4jPointer *inputShapeInfo, + void *vresult, + Nd4jLong const *resultShapeInfo) { + auto result = reinterpret_cast(vresult); + std::vector inputs(numArrays); - NDArray output(static_cast(result), resultShapeInfo); + NDArray output(static_cast(result), resultShapeInfo); - for(int i = 0; i < numArrays; ++i) - inputs[i] = new NDArray(static_cast(data[i]), static_cast(inputShapeInfo[i])); + for (int i = 0; i < numArrays; ++i) + inputs[i] = new NDArray(static_cast(data[i]), + static_cast(inputShapeInfo[i])); - sd::SpecialMethods::concatCpuGeneric(inputs, output, dimension); + sd::SpecialMethods::concatCpuGeneric(inputs, output, dimension); - for(int i = 0; i < numArrays; ++i) - delete inputs[i]; + for (int i = 0; i < numArrays; ++i) delete inputs[i]; } - template -void SpecialMethods::splitCpuGeneric(const NDArray& input, const std::vector& outArrs, const int axis) { - - int numSplits = outArrs.size(); +void SpecialMethods::splitCpuGeneric(const NDArray &input, + const std::vector &outArrs, + const int axis) { + int numSplits = outArrs.size(); - const auto sizeofT = input.sizeOfT(); + const auto sizeofT = input.sizeOfT(); - auto xBuff = input.bufferAsT(); + auto xBuff = input.bufferAsT(); - bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || (axis == input.rankOf() - 1 && input.ordering() == 'f')) && input.ews() == 1; + bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || + (axis == input.rankOf() - 1 && input.ordering() == 'f')) && + input.ews() == 1; - if (luckCase1) { - for (uint i = 0; i < numSplits; ++i) { - luckCase1 &= outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1; - if (!luckCase1) - break; - } + if (luckCase1) { + for (uint i = 0; i < numSplits; ++i) { + luckCase1 &= + outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1; + if (!luckCase1) break; } - - if (luckCase1) { - - T* x = const_cast(xBuff); - for (uint i = 0; i < numSplits; ++i) { - const auto memAmountToCopy = outArrs[i]->lengthOf(); - memcpy(outArrs[i]->bufferAsT(), x, memAmountToCopy * sizeofT); - x += memAmountToCopy; - } - return; + } + + if (luckCase1) { + T *x = const_cast(xBuff); + for (uint i = 0; i < numSplits; ++i) { + const auto memAmountToCopy = outArrs[i]->lengthOf(); + memcpy(outArrs[i]->bufferAsT(), x, memAmountToCopy * sizeofT); + x += memAmountToCopy; } + return; + } - // const bool isXcontin = input.strideAt(axis) == 1; - // bool areOutsContin = true; - // bool allSameOrder = true; - // std::vector strideOfContigStride(numSplits); - - // if (isXcontin) { + // const bool isXcontin = input.strideAt(axis) == 1; + // bool areOutsContin = true; + // bool allSameOrder = true; + // std::vector strideOfContigStride(numSplits); - // for (uint i = 0; i < numSplits; ++i) { + // if (isXcontin) { - // areOutsContin &= outArrs[i]->strideAt(axis) == 1; - // allSameOrder &= outArrs[i]->ordering() == input.ordering(); - // if (!areOutsContin || !allSameOrder) - // break; + // for (uint i = 0; i < numSplits; ++i) { - // strideOfContigStride[i] = shape::strideOverContigAxis(axis, outArrs[i]->shapeInfo()); - // } - // } + // areOutsContin &= outArrs[i]->strideAt(axis) == 1; + // allSameOrder &= outArrs[i]->ordering() == input.ordering(); + // if (!areOutsContin || !allSameOrder) + // break; - // const bool luckCase2 = isXcontin && areOutsContin && allSameOrder; + // strideOfContigStride[i] = shape::strideOverContigAxis(axis, + // outArrs[i]->shapeInfo()); + // } + // } - // if (luckCase2) { + // const bool luckCase2 = isXcontin && areOutsContin && allSameOrder; - // const auto xStep = shape::strideOverContigAxis(axis, input.shapeInfo()); + // if (luckCase2) { - // for (uint i = 0; i < input.lengthOf() / input.sizeAt(axis); ++i) { + // const auto xStep = shape::strideOverContigAxis(axis, + // input.shapeInfo()); - // T* x = xBuff + xStep * i; + // for (uint i = 0; i < input.lengthOf() / input.sizeAt(axis); ++i) { - // for (uint j = 0; j < numSplits; ++j) { - // const auto zDim = outArrs[j]->sizeAt(axis); - // T* z = outArrs[j]->bufferAsT() + strideOfContigStride[j] * i; - // memcpy(z, x, zDim * sizeofT); - // x += zDim; - // } - // } + // T* x = xBuff + xStep * i; - // return; - // } + // for (uint j = 0; j < numSplits; ++j) { + // const auto zDim = outArrs[j]->sizeAt(axis); + // T* z = outArrs[j]->bufferAsT() + strideOfContigStride[j] * + // i; memcpy(z, x, zDim * sizeofT); x += zDim; + // } + // } - uint zDim = outArrs[0]->sizeAt(axis); - // general case + // return; + // } - auto func = PRAGMA_THREADS_FOR{ + uint zDim = outArrs[0]->sizeAt(axis); + // general case - int coords[MAX_RANK], temp; + auto func = PRAGMA_THREADS_FOR { + int coords[MAX_RANK], temp; - for (auto i = start; i < stop; i += increment) { + for (auto i = start; i < stop; i += increment) { + shape::index2coordsCPU(start, i, input.shapeInfo(), coords); + const auto xOffset = shape::getOffset(input.shapeInfo(), coords); - shape::index2coordsCPU(start, i, input.shapeInfo(), coords); - const auto xOffset = shape::getOffset(input.shapeInfo(), coords); + uint outArrIdx = 0; + temp = coords[axis]; - uint outArrIdx = 0; - temp = coords[axis]; + while (coords[axis] >= zDim) { + coords[axis] -= zDim; + ++outArrIdx; + } - while (coords[axis] >= zDim) { - coords[axis] -= zDim; - ++outArrIdx; - } + T *z = outArrs[outArrIdx]->bufferAsT(); + const auto zOffset = + shape::getOffset(outArrs[outArrIdx]->shapeInfo(), coords); + z[zOffset] = xBuff[xOffset]; - T* z = outArrs[outArrIdx]->bufferAsT(); - const auto zOffset = shape::getOffset(outArrs[outArrIdx]->shapeInfo(), coords); - z[zOffset] = xBuff[xOffset]; - - coords[axis] = temp; - } - }; + coords[axis] = temp; + } + }; - samediff::Threads::parallel_for(func, 0, input.lengthOf()); + samediff::Threads::parallel_for(func, 0, input.lengthOf()); } - /** * This kernel accumulates X arrays, and stores result into Z * @@ -338,22 +358,23 @@ void SpecialMethods::splitCpuGeneric(const NDArray& input, const std::vector< * @param n * @param length */ - template - void SpecialMethods::accumulateGeneric(void **vx, void *vz, Nd4jLong const* zShapeInfo, int n, const Nd4jLong length) { - auto z = reinterpret_cast(vz); - auto x = reinterpret_cast(vx); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - for (auto ar = 0L; ar < n; ar++) { - z[i] += x[ar][i]; - } - } - }; - - samediff::Threads::parallel_for(func, 0, length); +template +void SpecialMethods::accumulateGeneric(void **vx, void *vz, + Nd4jLong const *zShapeInfo, int n, + const Nd4jLong length) { + auto z = reinterpret_cast(vz); + auto x = reinterpret_cast(vx); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + for (auto ar = 0L; ar < n; ar++) { + z[i] += x[ar][i]; + } } + }; + samediff::Threads::parallel_for(func, 0, length); +} /** * This kernel averages X input arrays, and stores result to Z @@ -365,277 +386,296 @@ void SpecialMethods::splitCpuGeneric(const NDArray& input, const std::vector< * @param length * @param propagate */ - template - void SpecialMethods::averageGeneric(void **vx, void *vz, Nd4jLong const* zShapeInfo, int n, const Nd4jLong length, bool propagate) { - auto z = reinterpret_cast(vz); - auto x = reinterpret_cast(vx); - - if (z == nullptr) { - //code branch for absent Z - z = x[0]; - - PRAGMA_OMP_SIMD - for (uint64_t i = 0; i < length; i++) { - z[i] /= static_cast(n); - } - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - for (Nd4jLong ar = 1; ar < n; ar++) { - z[i] += x[ar][i] / static_cast(n); - } - } - }; - samediff::Threads::parallel_for(func, 0, length); - - // instead of doing element-wise propagation, we just issue memcpy to propagate data - for (Nd4jLong ar = 1; ar < n; ar++) { - memcpy(x[ar], z, length * sizeof(T)); - } - } else { - // code branch for existing Z - - // memset before propagation - memset(z, 0, length * sizeof(T)); - - // aggregation step - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - for (Nd4jLong ar = 0; ar < n; ar++) { - z[i] += x[ar][i] / static_cast(n); - } - } - }; - samediff::Threads::parallel_for(func, 0, length); - - // instead of doing element-wise propagation, we just issue memcpy to propagate data - for (Nd4jLong ar = 0; ar < n; ar++) { - memcpy(x[ar], z, length * sizeof(T)); - } - } +template +void SpecialMethods::averageGeneric(void **vx, void *vz, + Nd4jLong const *zShapeInfo, int n, + const Nd4jLong length, bool propagate) { + auto z = reinterpret_cast(vz); + auto x = reinterpret_cast(vx); + + if (z == nullptr) { + // code branch for absent Z + z = x[0]; + + PRAGMA_OMP_SIMD + for (uint64_t i = 0; i < length; i++) { + z[i] /= static_cast(n); } - template - Nd4jLong SpecialMethods::getPosition(Nd4jLong const* xShapeInfo, Nd4jLong index) { - auto xEWS = shape::elementWiseStride(xShapeInfo); + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + for (Nd4jLong ar = 1; ar < n; ar++) { + z[i] += x[ar][i] / static_cast(n); + } + } + }; + samediff::Threads::parallel_for(func, 0, length); - if (xEWS == 1) - return index; - else if (xEWS > 1) - return index * xEWS; - else - return shape::getIndexOffset(index, xShapeInfo); + // instead of doing element-wise propagation, we just issue memcpy to + // propagate data + for (Nd4jLong ar = 1; ar < n; ar++) { + memcpy(x[ar], z, length * sizeof(T)); } + } else { + // code branch for existing Z - template - void SpecialMethods::quickSort_parallel_internal(T* array, Nd4jLong const* xShapeInfo, int left, int right, int cutoff, bool descending) { - - int i = left, j = right; - T tmp; - T pivot = array[getPosition(xShapeInfo, (left + right) / 2)]; - - - { - /* PARTITION PART */ - while (i <= j) { - if (descending) { - while (array[getPosition(xShapeInfo, i)] > pivot) - i++; - while (array[getPosition(xShapeInfo, j)] < pivot) - j--; - if (i <= j) { - tmp = array[getPosition(xShapeInfo, i)]; - array[getPosition(xShapeInfo, i)] = array[getPosition(xShapeInfo, j)]; - array[getPosition(xShapeInfo, j)] = tmp; - i++; - j--; - } - } else { - while (array[getPosition(xShapeInfo, i)] < pivot) - i++; - while (array[getPosition(xShapeInfo, j)] > pivot) - j--; - if (i <= j) { - tmp = array[getPosition(xShapeInfo, i)]; - array[getPosition(xShapeInfo, i)] = array[getPosition(xShapeInfo, j)]; - array[getPosition(xShapeInfo, j)] = tmp; - i++; - j--; - } - } - } + // memset before propagation + memset(z, 0, length * sizeof(T)); + // aggregation step + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + for (Nd4jLong ar = 0; ar < n; ar++) { + z[i] += x[ar][i] / static_cast(n); } + } + }; + samediff::Threads::parallel_for(func, 0, length); - // + // instead of doing element-wise propagation, we just issue memcpy to + // propagate data + for (Nd4jLong ar = 0; ar < n; ar++) { + memcpy(x[ar], z, length * sizeof(T)); + } + } +} - if ( ((right-left) +Nd4jLong SpecialMethods::getPosition(Nd4jLong const *xShapeInfo, + Nd4jLong index) { + auto xEWS = shape::elementWiseStride(xShapeInfo); + + if (xEWS == 1) + return index; + else if (xEWS > 1) + return index * xEWS; + else + return shape::getIndexOffset(index, xShapeInfo); +} - }else{ -PRAGMA_OMP_TASK - { quickSort_parallel_internal(array, xShapeInfo, left, j, cutoff, descending); } -PRAGMA_OMP_TASK - { quickSort_parallel_internal(array, xShapeInfo, i, right, cutoff, descending); } +template +void SpecialMethods::quickSort_parallel_internal(T *array, + Nd4jLong const *xShapeInfo, + int left, int right, + int cutoff, + bool descending) { + int i = left, j = right; + T tmp; + T pivot = array[getPosition(xShapeInfo, (left + right) / 2)]; + + { + /* PARTITION PART */ + while (i <= j) { + if (descending) { + while (array[getPosition(xShapeInfo, i)] > pivot) i++; + while (array[getPosition(xShapeInfo, j)] < pivot) j--; + if (i <= j) { + tmp = array[getPosition(xShapeInfo, i)]; + array[getPosition(xShapeInfo, i)] = array[getPosition(xShapeInfo, j)]; + array[getPosition(xShapeInfo, j)] = tmp; + i++; + j--; } - } - - template - void SpecialMethods::quickSort_parallel(void *varray, Nd4jLong const* xShapeInfo, Nd4jLong lenArray, int numThreads, bool descending){ - auto array = reinterpret_cast(varray); - int cutoff = 1000; - - PRAGMA_OMP_PARALLEL_THREADS(numThreads) - { -PRAGMA_OMP_SINGLE_ARGS(nowait) - { - quickSort_parallel_internal(array, xShapeInfo, 0, lenArray-1, cutoff, descending); - } + } else { + while (array[getPosition(xShapeInfo, i)] < pivot) i++; + while (array[getPosition(xShapeInfo, j)] > pivot) j--; + if (i <= j) { + tmp = array[getPosition(xShapeInfo, i)]; + array[getPosition(xShapeInfo, i)] = array[getPosition(xShapeInfo, j)]; + array[getPosition(xShapeInfo, j)] = tmp; + i++; + j--; } - + } } + } + // - - template - int SpecialMethods::nextPowerOf2(int number) { - int pos = 0; - - while (number > 0) { - pos++; - number = number >> 1; - } - return (int) pow(2, pos); + if (((right - left) < cutoff)) { + if (left < j) { + quickSort_parallel_internal(array, xShapeInfo, left, j, cutoff, + descending); + } + if (i < right) { + quickSort_parallel_internal(array, xShapeInfo, i, right, cutoff, + descending); } - template - int SpecialMethods::lastPowerOf2(int number) { - int p = 1; - while (p <= number) - p <<= 1; - - p >>= 1; - return p; + } else { + PRAGMA_OMP_TASK { + quickSort_parallel_internal(array, xShapeInfo, left, j, cutoff, + descending); + } + PRAGMA_OMP_TASK { + quickSort_parallel_internal(array, xShapeInfo, i, right, cutoff, + descending); } + } +} +template +void SpecialMethods::quickSort_parallel(void *varray, + Nd4jLong const *xShapeInfo, + Nd4jLong lenArray, int numThreads, + bool descending) { + auto array = reinterpret_cast(varray); + int cutoff = 1000; + + PRAGMA_OMP_PARALLEL_THREADS(numThreads) { + PRAGMA_OMP_SINGLE_ARGS(nowait) { + quickSort_parallel_internal(array, xShapeInfo, 0, lenArray - 1, cutoff, + descending); + } + } +} - template - void SpecialMethods::sortGeneric(void *vx, Nd4jLong const* xShapeInfo, bool descending) { - auto x = reinterpret_cast(vx); +template +int SpecialMethods::nextPowerOf2(int number) { + int pos = 0; + + while (number > 0) { + pos++; + number = number >> 1; + } + return (int)pow(2, pos); +} - quickSort_parallel(x, xShapeInfo, shape::length(xShapeInfo), omp_get_max_threads(), descending); - } +template +int SpecialMethods::lastPowerOf2(int number) { + int p = 1; + while (p <= number) p <<= 1; - template - void SpecialMethods::sortTadGeneric(void *vx, Nd4jLong const* xShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool descending) { - auto x = reinterpret_cast(vx); + p >>= 1; + return p; +} - //quickSort_parallel(x, xShapeInfo, shape::length(xShapeInfo), omp_get_max_threads(), descending); - Nd4jLong xLength = shape::length(xShapeInfo); - Nd4jLong xTadLength = shape::tadLength(xShapeInfo, dimension, dimensionLength); - int numTads = xLength / xTadLength; +template +void SpecialMethods::sortGeneric(void *vx, Nd4jLong const *xShapeInfo, + bool descending) { + auto x = reinterpret_cast(vx); - auto func = PRAGMA_THREADS_FOR { - for (auto r = start; r < stop; r++) { - T *dx = x + tadOffsets[r]; + quickSort_parallel(x, xShapeInfo, shape::length(xShapeInfo), + omp_get_max_threads(), descending); +} - quickSort_parallel(dx, tadShapeInfo, xTadLength, 1, descending); - } - }; - samediff::Threads::parallel_tad(func, 0, numTads); +template +void SpecialMethods::sortTadGeneric(void *vx, Nd4jLong const *xShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, + bool descending) { + auto x = reinterpret_cast(vx); + + // quickSort_parallel(x, xShapeInfo, shape::length(xShapeInfo), + // omp_get_max_threads(), descending); + Nd4jLong xLength = shape::length(xShapeInfo); + Nd4jLong xTadLength = + shape::tadLength(xShapeInfo, dimension, dimensionLength); + int numTads = xLength / xTadLength; + + auto func = PRAGMA_THREADS_FOR { + for (auto r = start; r < stop; r++) { + T *dx = x + tadOffsets[r]; + + quickSort_parallel(dx, tadShapeInfo, xTadLength, 1, descending); } + }; + samediff::Threads::parallel_tad(func, 0, numTads); +} +template +void SpecialMethods::decodeBitmapGeneric(const void *dx, Nd4jLong N, + void *vz, + Nd4jLong const *zShapeInfo) { + auto dz = reinterpret_cast(vz); + auto x = reinterpret_cast(dx); + Nd4jLong lim = N / 16 + 5; + + FloatBits2 fb; + fb.i_ = x[2]; + float threshold = fb.f_; + + auto pPos = -1; + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + const auto v = x[e]; + for (int bitId = 0; bitId < 16; bitId++) { + bool hasBit = (v & 1 << (bitId)) != 0; + bool hasSign = (v & 1 << (bitId + 16)) != 0; + auto cPos = (e - 4) * 16 + bitId; + + if (hasBit) { + if (hasSign) + dz[cPos] -= static_cast(threshold); + else + dz[cPos] += static_cast(threshold); + } else if (hasSign) { + dz[cPos] -= static_cast(threshold / 2); + } - template - void SpecialMethods::decodeBitmapGeneric(const void *dx, Nd4jLong N, void *vz, Nd4jLong const* zShapeInfo) { - auto dz = reinterpret_cast(vz); - auto x = reinterpret_cast(dx); - Nd4jLong lim = N / 16 + 5; - - FloatBits2 fb; - fb.i_ = x[2]; - float threshold = fb.f_; - - auto pPos = -1; - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - const auto v = x[e]; - for (int bitId = 0; bitId < 16; bitId++) { - bool hasBit = (v & 1 << (bitId)) != 0; - bool hasSign = (v & 1 << (bitId + 16)) != 0; - auto cPos = (e - 4) * 16 + bitId; - - if (hasBit) { - if (hasSign) - dz[cPos] -= static_cast(threshold); - else - dz[cPos] += static_cast(threshold); - } else if (hasSign) { - dz[cPos] -= static_cast(threshold / 2); - } - - pPos = cPos; - } - } - }; - - samediff::Threads::parallel_for(func, 4, lim); + pPos = cPos; + } } + }; - template - Nd4jLong SpecialMethods::encodeBitmapGeneric(void *vx, Nd4jLong const* xShapeInfo, Nd4jLong N, int *dz, float threshold) { - auto dx = reinterpret_cast(vx); - const T two(2.0f); - const T zero(0.0f); - const T t(threshold); - const T thalf = t / two; + samediff::Threads::parallel_for(func, 4, lim); +} + +template +Nd4jLong SpecialMethods::encodeBitmapGeneric(void *vx, + Nd4jLong const *xShapeInfo, + Nd4jLong N, int *dz, + float threshold) { + auto dx = reinterpret_cast(vx); + const T two(2.0f); + const T zero(0.0f); + const T t(threshold); + const T thalf = t / two; - //auto func = PRAGMA_REDUCE_LONG { - Nd4jLong retVal = 0L; + // auto func = PRAGMA_REDUCE_LONG { + Nd4jLong retVal = 0L; - PRAGMA_OMP_PARALLEL_FOR_REDUCTION(+:retVal) - for (auto x = 0; x < N; x += 16) { - int byte = 0; - int byteId = x / 16 + 4; + PRAGMA_OMP_PARALLEL_FOR_REDUCTION(+ : retVal) + for (auto x = 0; x < N; x += 16) { + int byte = 0; + int byteId = x / 16 + 4; - for (int f = 0; f < 16; f++) { - auto e = x + f; + for (int f = 0; f < 16; f++) { + auto e = x + f; - if (e >= N) - continue; + if (e >= N) continue; - T val = dx[e]; - T abs = sd::math::nd4j_abs(val); + T val = dx[e]; + T abs = sd::math::nd4j_abs(val); - int bitId = e % 16; + int bitId = e % 16; - if (abs >= t) { - byte |= 1 << (bitId); - retVal++; + if (abs >= t) { + byte |= 1 << (bitId); + retVal++; - if (val < zero) { - byte |= 1 << (bitId + 16); - dx[e] += t; - } else { - dx[e] -= t; - } - } else if (abs >= thalf && val < zero) { - byte |= 1 << (bitId + 16); - dx[e] += thalf; + if (val < zero) { + byte |= 1 << (bitId + 16); + dx[e] += t; + } else { + dx[e] -= t; + } + } else if (abs >= thalf && val < zero) { + byte |= 1 << (bitId + 16); + dx[e] += thalf; - retVal++; - } - } + retVal++; + } + } - dz[byteId] = byte; - } + dz[byteId] = byte; + } - return retVal; - //}; + return retVal; + //}; - //return samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, N, 16); - } + // return samediff::Threads::parallel_long(func, LAMBDA_SUML, 0, N, 16); } - +} // namespace sd diff --git a/libnd4j/include/ops/impl/specials_sparse.cpp b/libnd4j/include/ops/impl/specials_sparse.cpp index 798e3f93a6cd..7de2d954a4e4 100644 --- a/libnd4j/include/ops/impl/specials_sparse.cpp +++ b/libnd4j/include/ops/impl/specials_sparse.cpp @@ -19,10 +19,10 @@ // #include -#include -#include #include #include +#include +#include #ifdef _OPENMP #include #endif @@ -30,193 +30,203 @@ #include namespace sd { - namespace sparse { - - template - void SparseUtils::printIndex(Nd4jLong *indices, int rank, int x) { - printf(" ["); - for (int e = 0; e < rank; e++) { - if (e > 0) - printf(", "); - - printf("%lld", (long long) indices[x * rank + e]); - } - printf("] "); - } - - template - bool SparseUtils::ltIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y) { - for (int e = 0; e < rank; e++) { - Nd4jLong idxX = indices[x * rank + e]; - Nd4jLong idxY = indices[y * rank + e]; - // we're comparing indices one by one, starting from outer dimension - if (idxX < idxY) { - return true; - } else if (idxX == idxY) { - // do nothing, continue to next dimension - } else - return false; - } - - return false; - } +namespace sparse { - template - bool SparseUtils::gtIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y) { - for (int e = 0; e < rank; e++) { - // we're comparing indices one by one, starting from outer dimension - Nd4jLong idxX = indices[x * rank + e]; - Nd4jLong idxY = indices[y * rank + e]; - if ( idxX > idxY) { - return true; - } else if (idxX == idxY) { - // do nothing, continue to next dimension - } else - return false; - } - return false; - } +template +void SparseUtils::printIndex(Nd4jLong *indices, int rank, int x) { + printf(" ["); + for (int e = 0; e < rank; e++) { + if (e > 0) printf(", "); - template - void SparseUtils::swapEverything(Nd4jLong *indices, T *array, int rank, Nd4jLong x, Nd4jLong y) { - // swap indices - for (int e = 0; e < rank; e++) { - Nd4jLong tmp = indices[x * rank + e]; - indices[x * rank + e] = indices[y * rank + e]; - indices[y * rank + e] = tmp; - } - - // swap values - T tmp = array[x]; - array[x] = array[y]; - array[y] = tmp; - } + printf("%lld", (long long)indices[x * rank + e]); + } + printf("] "); +} - template - Nd4jLong SparseUtils::coo_quickSort_findPivot(Nd4jLong *indices, T *array, Nd4jLong left, Nd4jLong right, - int rank) { - Nd4jLong mid = (left + right) / 2; +template +bool SparseUtils::ltIndices(Nd4jLong *indices, int rank, Nd4jLong x, + Nd4jLong y) { + for (int e = 0; e < rank; e++) { + Nd4jLong idxX = indices[x * rank + e]; + Nd4jLong idxY = indices[y * rank + e]; + // we're comparing indices one by one, starting from outer dimension + if (idxX < idxY) { + return true; + } else if (idxX == idxY) { + // do nothing, continue to next dimension + } else + return false; + } + + return false; +} - // ensure left < mid - if (ltIndices(indices, rank, mid, left)) { // ensure lo < mid - swapEverything(indices, array, rank, mid, left); - } +template +bool SparseUtils::gtIndices(Nd4jLong *indices, int rank, Nd4jLong x, + Nd4jLong y) { + for (int e = 0; e < rank; e++) { + // we're comparing indices one by one, starting from outer dimension + Nd4jLong idxX = indices[x * rank + e]; + Nd4jLong idxY = indices[y * rank + e]; + if (idxX > idxY) { + return true; + } else if (idxX == idxY) { + // do nothing, continue to next dimension + } else + return false; + } + return false; +} - // ensure left < right - if (ltIndices(indices, rank, right, left)) { - swapEverything(indices, array, rank, right, left); - } +template +void SparseUtils::swapEverything(Nd4jLong *indices, T *array, int rank, + Nd4jLong x, Nd4jLong y) { + // swap indices + for (int e = 0; e < rank; e++) { + Nd4jLong tmp = indices[x * rank + e]; + indices[x * rank + e] = indices[y * rank + e]; + indices[y * rank + e] = tmp; + } + + // swap values + T tmp = array[x]; + array[x] = array[y]; + array[y] = tmp; +} - // ensure mid < right - if (ltIndices(indices, rank, right, mid)) { - swapEverything(indices, array, rank, right, mid); - } +template +Nd4jLong SparseUtils::coo_quickSort_findPivot(Nd4jLong *indices, T *array, + Nd4jLong left, Nd4jLong right, + int rank) { + Nd4jLong mid = (left + right) / 2; + + // ensure left < mid + if (ltIndices(indices, rank, mid, left)) { // ensure lo < mid + swapEverything(indices, array, rank, mid, left); + } + + // ensure left < right + if (ltIndices(indices, rank, right, left)) { + swapEverything(indices, array, rank, right, left); + } + + // ensure mid < right + if (ltIndices(indices, rank, right, mid)) { + swapEverything(indices, array, rank, right, mid); + } + + // mid is the median of the 3, and is the optimal pivot point + return mid; +} - // mid is the median of the 3, and is the optimal pivot point - return mid; +template +void SparseUtils::coo_quickSort_parallel_internal(Nd4jLong *indices, + T *array, Nd4jLong left, + Nd4jLong right, int cutoff, + int rank) { + Nd4jLong span = right - left; // elements to be partitioned - 1 + + if (span == 1) { + // only 2 elements to partition. swap if needed and return directly without + // further sorting. + if (ltIndices(indices, rank, right, left)) { + swapEverything(indices, array, rank, left, right); } + return; + } + + // find optimal pivot and sort left < right < right + Nd4jLong pvt = coo_quickSort_findPivot(indices, array, left, right, rank); + + if (span == 2) { + // only 3 elements to partition. findPivot has already sorted them. no + // further sorting is needed. + return; + } + + // index that is greater than pivot - leftmost element is already partitioned + // because of findPivot. + Nd4jLong i = left + 1; + + // index that is smaller than pivot - rightmost element is already partitioned + // because of findPivot. + Nd4jLong j = right - 1; + + { + // flag that indicates that pivot index lies between i and j and *could* be + // swapped. + bool checkPivot = true; + /* PARTITION PART */ + while (i <= j) { + while (ltIndices(indices, rank, i, pvt)) i++; + + while (gtIndices(indices, rank, j, pvt)) j--; + + if (i <= j) { + if (i != j) { // swap can be fairly expensive. don't swap i -> i + swapEverything(indices, array, rank, i, j); + } - template - void SparseUtils::coo_quickSort_parallel_internal(Nd4jLong *indices, T* array, Nd4jLong left, Nd4jLong right, int cutoff, int rank) { - Nd4jLong span = right - left; // elements to be partitioned - 1 - - if (span == 1){ - // only 2 elements to partition. swap if needed and return directly without further sorting. - if (ltIndices(indices, rank, right, left)){ - swapEverything(indices, array, rank, left, right); - } - return; - } - - - // find optimal pivot and sort left < right < right - Nd4jLong pvt = coo_quickSort_findPivot(indices, array, left, right, rank); - - if (span == 2){ - // only 3 elements to partition. findPivot has already sorted them. no further sorting is needed. - return; - } - - // index that is greater than pivot - leftmost element is already partitioned because of findPivot. - Nd4jLong i = left + 1; - - // index that is smaller than pivot - rightmost element is already partitioned because of findPivot. - Nd4jLong j = right - 1; - - - { - // flag that indicates that pivot index lies between i and j and *could* be swapped. - bool checkPivot = true; - /* PARTITION PART */ - while (i <= j) { - while (ltIndices(indices, rank, i, pvt)) - i++; - - while (gtIndices(indices, rank, j, pvt)) - j--; - - - if (i <= j) { - if(i != j) { // swap can be fairly expensive. don't swap i -> i - swapEverything(indices, array, rank, i, j); - } - - // only check pivot if it hasn't already been swapped. - if (checkPivot) { - // check if we moved the pivot, if so, change pivot index accordingly - if (pvt == j) { - pvt = i; - checkPivot = false; - } else if (pvt == i) { - pvt = j; - checkPivot = false; - } - } - - i++; - j--; - } - } - - } - - if ( (span < cutoff) ){ - if (left < j){ coo_quickSort_parallel_internal(indices, array, left, j, cutoff, rank); } - if (i < right){ coo_quickSort_parallel_internal(indices, array, i, right, cutoff, rank); } - - }else{ -PRAGMA_OMP_TASK - { coo_quickSort_parallel_internal(indices, array, left, j, cutoff, rank); } -PRAGMA_OMP_TASK - { coo_quickSort_parallel_internal(indices, array, i, right, cutoff, rank); } - } - + // only check pivot if it hasn't already been swapped. + if (checkPivot) { + // check if we moved the pivot, if so, change pivot index accordingly + if (pvt == j) { + pvt = i; + checkPivot = false; + } else if (pvt == i) { + pvt = j; + checkPivot = false; + } } - template - void SparseUtils::coo_quickSort_parallel(Nd4jLong *indices, T* array, Nd4jLong lenArray, int numThreads, int rank){ + i++; + j--; + } + } + } - int cutoff = 1000; + if ((span < cutoff)) { + if (left < j) { + coo_quickSort_parallel_internal(indices, array, left, j, cutoff, rank); + } + if (i < right) { + coo_quickSort_parallel_internal(indices, array, i, right, cutoff, rank); + } - PRAGMA_OMP_PARALLEL_THREADS(numThreads) - { -PRAGMA_OMP_SINGLE_ARGS(nowait) - { - coo_quickSort_parallel_internal(indices, array, 0, lenArray-1, cutoff, rank); - } - } + } else { + PRAGMA_OMP_TASK { + coo_quickSort_parallel_internal(indices, array, left, j, cutoff, rank); + } + PRAGMA_OMP_TASK { + coo_quickSort_parallel_internal(indices, array, i, right, cutoff, rank); + } + } +} - } +template +void SparseUtils::coo_quickSort_parallel(Nd4jLong *indices, T *array, + Nd4jLong lenArray, int numThreads, + int rank) { + int cutoff = 1000; - template - void SparseUtils::sortCooIndicesGeneric(Nd4jLong *indices, T *values, Nd4jLong length, int rank) { + PRAGMA_OMP_PARALLEL_THREADS(numThreads) { + PRAGMA_OMP_SINGLE_ARGS(nowait) { + coo_quickSort_parallel_internal(indices, array, 0, lenArray - 1, cutoff, + rank); + } + } +} + +template +void SparseUtils::sortCooIndicesGeneric(Nd4jLong *indices, T *values, + Nd4jLong length, int rank) { #ifdef _OPENMP - coo_quickSort_parallel(indices, values, length, omp_get_max_threads(), rank); + coo_quickSort_parallel(indices, values, length, omp_get_max_threads(), rank); #else - coo_quickSort_parallel(indices, values, length, 1, rank); + coo_quickSort_parallel(indices, values, length, 1, rank); #endif - } - - BUILD_SINGLE_TEMPLATE(template class SD_EXPORT SparseUtils, , LIBND4J_TYPES); - } } + +BUILD_SINGLE_TEMPLATE(template class SD_EXPORT SparseUtils, , LIBND4J_TYPES); +} // namespace sparse +} // namespace sd diff --git a/libnd4j/include/ops/meta_ops.h b/libnd4j/include/ops/meta_ops.h index a2120477d72e..defa1d0853d3 100644 --- a/libnd4j/include/ops/meta_ops.h +++ b/libnd4j/include/ops/meta_ops.h @@ -18,157 +18,157 @@ #ifndef FUSED_OPS_H_ #define FUSED_OPS_H_ -#include -#include - #include +#include +#include namespace metaOps { - /** - * InvertedMetaOp shares the same idea as MetaOp, but op being applied to op.Y in pairwise/broadcast ops +/** + * InvertedMetaOp shares the same idea as MetaOp, but op being applied to op.Y + * in pairwise/broadcast ops + */ +template +class InvertedMetaOp { + public: + no_op_exec_special no_op_exec_special_cuda + + /* + * PREDICATE + */ + + // scalar, transform, reduce, indexreduce entry + op_def static T + op(T d1, T *params) { + /* + * We assume, that this method won't be EVER called + */ + printf("You should NEVER see this message in output\n"); + return (T)0.0f; + } + + // PWT, broadcast entry. Predicate can be only scalar, transform + op_def static T op(T d1, T d2, T *params) { + Nd4jPointer *wrap = reinterpret_cast(params); + T *paramsA = reinterpret_cast(wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::op(OpTypeA::op(d1, d2, paramsA), paramsB); + } + + /* + * POSTULATE + */ + + // will be called for reduce, reduce3 + op_def static T postProcess(T reduction, Nd4jLong n, T *params) { + /* + * We assume, that this method won't be EVER called + */ + printf("You should NEVER EVER see this message in output\n"); + + return (T)0.0f; + } +}; + +/** + * Special case here: MetaOp which consist of 2 operations. + * + * Predicate can be either scalar or transform, to process data before actual op + * call Postulate will be the scalar/transform, but will be applied to result of + * broadcast/reduce/reduce3 + */ +template +class MetaOp { + public: + no_op_exec_special no_op_exec_special_cuda + + /* + * PREDICATE + */ + + meta_def static T + startingValue(const T *input) { + return (T)0.0f; + } + + // scalar, transform, reduce, indexreduce entry + meta_def static T op(T d1, T *params) { + /* + * We assume, that params for MetaOp is a set of pointers to actual op A & B + * extraArgs */ - template - class InvertedMetaOp { - public: - no_op_exec_special - no_op_exec_special_cuda - - /* - * PREDICATE - */ - - // scalar, transform, reduce, indexreduce entry - op_def static T op(T d1, T *params) { - /* - * We assume, that this method won't be EVER called - */ - printf("You should NEVER see this message in output\n"); - return (T) 0.0f; - } - - // PWT, broadcast entry. Predicate can be only scalar, transform - op_def static T op(T d1, T d2, T *params) { - Nd4jPointer *wrap = reinterpret_cast (params); - T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::op(OpTypeA::op(d1, d2, paramsA), paramsB); - } - - /* - * POSTULATE - */ - - // will be called for reduce, reduce3 - op_def static T postProcess(T reduction, Nd4jLong n, T *params) { - /* - * We assume, that this method won't be EVER called - */ - printf("You should NEVER EVER see this message in output\n"); - - return (T) 0.0f; - } - }; - - - /** - * Special case here: MetaOp which consist of 2 operations. - * - * Predicate can be either scalar or transform, to process data before actual op call - * Postulate will be the scalar/transform, but will be applied to result of broadcast/reduce/reduce3 - */ - template - class MetaOp { - public: - no_op_exec_special - no_op_exec_special_cuda - - /* - * PREDICATE - */ - - meta_def static T startingValue(const T *input) { - return (T) 0.0f; - } - - // scalar, transform, reduce, indexreduce entry - meta_def static T op(T d1, T *params) { - /* - * We assume, that params for MetaOp is a set of pointers to actual op A & B extraArgs - */ - Nd4jPointer *wrap = reinterpret_cast (params); - T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::op(OpTypeA::op(d1, paramsA), paramsB); - } - - // PWT, broadcast entry. Predicate can be only scalar, transform - meta_def static T op(T d1, T d2, T *params) { - Nd4jPointer *wrap = reinterpret_cast (params); - T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::op(OpTypeA::op(d1, paramsA), d2, paramsB); - } - - /* - * POSTULATE - */ - - // will be called for reduce, reduce3 - meta_def static T postProcess(T reduction, Nd4jLong n, T *params) { - Nd4jPointer *wrap = reinterpret_cast (params); - T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::op(OpTypeA::postProcess(reduction, n, paramsA), paramsB); - } - }; - - - template - class ReduceMetaOp { - public: - no_op_exec_special - no_op_exec_special_cuda - - meta_def static T startingValue(const T *input) { - return OpTypeB::startingValue(input); - } - - meta_def static T merge(T old, T opOutput, T *params) { - Nd4jPointer *wrap = reinterpret_cast (params); -// T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::merge(old, opOutput, paramsB); - } - - meta_def static T update(T old, T opOutput, T *params) { - Nd4jPointer *wrap = reinterpret_cast (params); - //T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::update(old, opOutput, paramsB); - } - - meta_def static T op(T d1, T *params) { - Nd4jPointer *wrap = reinterpret_cast (params); - T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::op(OpTypeA::op(d1, paramsA), paramsB); - } - - meta_def static T postProcess(T reduction, Nd4jLong n, T *params) { - Nd4jPointer *wrap = reinterpret_cast (params); -// T *paramsA = reinterpret_cast (wrap[0]); - T *paramsB = reinterpret_cast (wrap[1]); - - return OpTypeB::postProcess(reduction, n, paramsB); - } - }; -} + Nd4jPointer *wrap = reinterpret_cast(params); + T *paramsA = reinterpret_cast(wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::op(OpTypeA::op(d1, paramsA), paramsB); + } + + // PWT, broadcast entry. Predicate can be only scalar, transform + meta_def static T op(T d1, T d2, T *params) { + Nd4jPointer *wrap = reinterpret_cast(params); + T *paramsA = reinterpret_cast(wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::op(OpTypeA::op(d1, paramsA), d2, paramsB); + } + + /* + * POSTULATE + */ + + // will be called for reduce, reduce3 + meta_def static T postProcess(T reduction, Nd4jLong n, T *params) { + Nd4jPointer *wrap = reinterpret_cast(params); + T *paramsA = reinterpret_cast(wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::op(OpTypeA::postProcess(reduction, n, paramsA), paramsB); + } +}; + +template +class ReduceMetaOp { + public: + no_op_exec_special no_op_exec_special_cuda + + meta_def static T + startingValue(const T *input) { + return OpTypeB::startingValue(input); + } + + meta_def static T merge(T old, T opOutput, T *params) { + Nd4jPointer *wrap = reinterpret_cast(params); + // T *paramsA = reinterpret_cast (wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::merge(old, opOutput, paramsB); + } + + meta_def static T update(T old, T opOutput, T *params) { + Nd4jPointer *wrap = reinterpret_cast(params); + // T *paramsA = reinterpret_cast (wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::update(old, opOutput, paramsB); + } + + meta_def static T op(T d1, T *params) { + Nd4jPointer *wrap = reinterpret_cast(params); + T *paramsA = reinterpret_cast(wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::op(OpTypeA::op(d1, paramsA), paramsB); + } + + meta_def static T postProcess(T reduction, Nd4jLong n, T *params) { + Nd4jPointer *wrap = reinterpret_cast(params); + // T *paramsA = reinterpret_cast (wrap[0]); + T *paramsB = reinterpret_cast(wrap[1]); + + return OpTypeB::postProcess(reduction, n, paramsB); + } +}; +} // namespace metaOps #endif \ No newline at end of file diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index 21cd07c40d0d..065534883a95 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -18,13 +18,14 @@ #ifndef OPS_H_ #define OPS_H_ -#include #include #include -#include -#include -#include #include +#include +#include +#include + +#include #define MIN_V 1e-12 #define MAX_FLOAT 1e37 @@ -37,21 +38,94 @@ #define DOUBLE_PI_T T(2.0 * 3.14159265358979323846) #define DOUBLE_PI_X X(2.0 * 3.14159265358979323846) -#define no_op_exec_special_any static const bool requiresSpecial = false; static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, Z *result, const Nd4jLong *resultShapeBuffer, X *extraParams, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_bool static const bool requiresSpecial = false; static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, Z *result, const Nd4jLong *resultShapeBuffer, X *extraParams, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_same static const bool requiresSpecial = false; static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, X *result, const Nd4jLong *resultShapeBuffer, X *extraParams, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special static const bool requiresSpecial = false; static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, Z *result, const Nd4jLong *resultShapeBuffer, Z *extraParams, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_accumulation static const bool requiresSpecialAccumulation = false; static void execSpecial(const X *x, const Nd4jLong *xShapeInfo, Z *extraParams, Z *result, const Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset){} -#define no_op_exec_special_accumulation_long static const bool requiresSpecialAccumulation = false; static void execSpecial(const X *x, const Nd4jLong *xShapeInfo, X *extraParams, Z *result, const Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset){} -#define no_op_exec_special_accumulation_same static const bool requiresSpecialAccumulation = false; static void execSpecial(const X *x, const Nd4jLong *xShapeInfo, X *extraParams, X *result, const Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset){} +#define no_op_exec_special_any \ + static const bool requiresSpecial = false; \ + static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, \ + Z *result, const Nd4jLong *resultShapeBuffer, \ + X *extraParams, const Nd4jLong *tadShapeInfo, \ + const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_bool \ + static const bool requiresSpecial = false; \ + static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, \ + Z *result, const Nd4jLong *resultShapeBuffer, \ + X *extraParams, const Nd4jLong *tadShapeInfo, \ + const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_same \ + static const bool requiresSpecial = false; \ + static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, \ + X *result, const Nd4jLong *resultShapeBuffer, \ + X *extraParams, const Nd4jLong *tadShapeInfo, \ + const Nd4jLong *tadOffsets) {} +#define no_op_exec_special \ + static const bool requiresSpecial = false; \ + static void execSpecial(const X *dx, const Nd4jLong *xShapeBuffer, \ + Z *result, const Nd4jLong *resultShapeBuffer, \ + Z *extraParams, const Nd4jLong *tadShapeInfo, \ + const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_accumulation \ + static const bool requiresSpecialAccumulation = false; \ + static void execSpecial( \ + const X *x, const Nd4jLong *xShapeInfo, Z *extraParams, Z *result, \ + const Nd4jLong *resultShapeInfoBuffer, int *dimension, \ + int dimensionLength, const Nd4jLong *tadShapeInfo, \ + const Nd4jLong *tadOffset) {} +#define no_op_exec_special_accumulation_long \ + static const bool requiresSpecialAccumulation = false; \ + static void execSpecial( \ + const X *x, const Nd4jLong *xShapeInfo, X *extraParams, Z *result, \ + const Nd4jLong *resultShapeInfoBuffer, int *dimension, \ + int dimensionLength, const Nd4jLong *tadShapeInfo, \ + const Nd4jLong *tadOffset) {} +#define no_op_exec_special_accumulation_same \ + static const bool requiresSpecialAccumulation = false; \ + static void execSpecial( \ + const X *x, const Nd4jLong *xShapeInfo, X *extraParams, X *result, \ + const Nd4jLong *resultShapeInfoBuffer, int *dimension, \ + int dimensionLength, const Nd4jLong *tadShapeInfo, \ + const Nd4jLong *tadOffset) {} #ifdef __CUDACC__ -#define no_op_exec_special_any_cuda static __device__ void execSpecialCuda(const X *dx, const Nd4jLong *xShapeBuffer, Z *result, const Nd4jLong *resultShapeBuffer, X *extraParams, int *allocationPointer, Z *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_bool_cuda static __device__ void execSpecialCuda(const X *dx, const Nd4jLong *xShapeBuffer, Z *result, const Nd4jLong *resultShapeBuffer, X *extraParams, int *allocationPointer, Z *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_same_cuda static __device__ void execSpecialCuda(const X *dx, const Nd4jLong *xShapeBuffer, X *result, const Nd4jLong *resultShapeBuffer, X *extraParams, int *allocationPointer, X *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_cuda static __device__ void execSpecialCuda(const X *dx, const Nd4jLong *xShapeBuffer,Z *result, const Nd4jLong *resultShapeBuffer,Z *extraParams, int *allocationPointer, Z *reductionPointer, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_accumulation_same_cuda static inline __device__ void execSpecialCuda(const X *dx, const Nd4jLong *xShapeInfo, X *extraParams, X *result, const Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, X *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_accumulation_long_cuda static inline __device__ void execSpecialCuda(const X *dx, const Nd4jLong *xShapeInfo, X *extraParams, Z *result, const Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Z *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) {} -#define no_op_exec_special_accumulation_cuda static inline __device__ void execSpecialCuda(const X *dx, const Nd4jLong *xShapeInfo, Z *extraParams, Z *result, const Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Z *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_any_cuda \ + static __device__ void execSpecialCuda( \ + const X *dx, const Nd4jLong *xShapeBuffer, Z *result, \ + const Nd4jLong *resultShapeBuffer, X *extraParams, \ + int *allocationPointer, Z *reductionPointer, \ + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_bool_cuda \ + static __device__ void execSpecialCuda( \ + const X *dx, const Nd4jLong *xShapeBuffer, Z *result, \ + const Nd4jLong *resultShapeBuffer, X *extraParams, \ + int *allocationPointer, Z *reductionPointer, \ + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_same_cuda \ + static __device__ void execSpecialCuda( \ + const X *dx, const Nd4jLong *xShapeBuffer, X *result, \ + const Nd4jLong *resultShapeBuffer, X *extraParams, \ + int *allocationPointer, X *reductionPointer, \ + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_cuda \ + static __device__ void execSpecialCuda( \ + const X *dx, const Nd4jLong *xShapeBuffer, Z *result, \ + const Nd4jLong *resultShapeBuffer, Z *extraParams, \ + int *allocationPointer, Z *reductionPointer, \ + const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_accumulation_same_cuda \ + static inline __device__ void execSpecialCuda( \ + const X *dx, const Nd4jLong *xShapeInfo, X *extraParams, X *result, \ + const Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, \ + X *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, \ + const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_accumulation_long_cuda \ + static inline __device__ void execSpecialCuda( \ + const X *dx, const Nd4jLong *xShapeInfo, X *extraParams, Z *result, \ + const Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, \ + Z *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, \ + const Nd4jLong *tadOffsets) {} +#define no_op_exec_special_accumulation_cuda \ + static inline __device__ void execSpecialCuda( \ + const X *dx, const Nd4jLong *xShapeInfo, Z *extraParams, Z *result, \ + const Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, \ + Z *reductionBuffer, const Nd4jLong *tadOnlyShapeInfo, \ + const Nd4jLong *tadOffsets) {} #else // hacky fix for isnan/being being out of scope @@ -73,4610 +147,4250 @@ #define no_op_exec_special_accumulation_same_cuda #endif - #define SELU_ALPHA 1.6732632423543772848170429916717 #define SELU_LAMBDA 1.0507009873554804934193349852946 - namespace functions { - namespace indexreduce { - template - struct IndexValue { - T value; - Nd4jLong index; - _CUDA_HD IndexValue() = default; - _CUDA_HD IndexValue(const T val, const Nd4jLong ind): index(ind), value(val) {} - }; - } - - namespace summarystats { - template - class SummaryStatsData; - } +namespace indexreduce { +template +struct IndexValue { + T value; + Nd4jLong index; + _CUDA_HD IndexValue() = default; + _CUDA_HD IndexValue(const T val, const Nd4jLong ind) + : index(ind), value(val) {} +}; +} // namespace indexreduce + +namespace summarystats { +template +class SummaryStatsData; } +} // namespace functions namespace simdOps { - template - class Add { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d1 + d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d1 + d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(d1 + params[0]); - } - - op_def static X startingValue() { - return static_cast(0.f); - } - }; - - template - class NewAdd { - public: - op_def static X op(X d1, Y d2, X *params) { - return d1 + d2; - } - }; - - template - class Subtract { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d1 - d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d1 - d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(d1 - params[0]); - } - - }; - - template - class SquaredSubtract { - public: - op_def static Z op(X d1, Y d2) { - auto d = static_cast(d1 - d2); - return d * d; - } - - op_def static Z op(X d1, Y d2, Z *params) { - auto d = static_cast(d1 - d2); - return d * d; - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - auto d = static_cast(d1 - params[0]); - return d * d; - } - }; - - template - class SquaredReverseSubtract { - public: - op_def static Z op(X d1, Y d2) { - auto d = static_cast(d2 - d1); - return d * d; - } - - op_def static Z op(X d1, Y d2, Z *params) { - auto d = static_cast(d2 - d1); - return d * d; - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - auto d = static_cast(params[0] - d1); - return d * d; - } - }; - - template - class ReverseSubtract { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d2 - d1); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d2 - d1); - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(params[0] - d1); - } - }; - - - template - class LogPoissonLossFull { - - public: - op_def static Z op(X z, Y c) { - auto zz = static_cast(z); - auto zc = static_cast(c); - return (sd::math::nd4j_exp(c) - zz * zc + (zz * sd::math::nd4j_log(z) - zz + static_cast(0.5f) * sd::math::nd4j_log(static_cast(DOUBLE_PI_X) * zz))); - } - - op_def static Z op(X z, Y c, Z *params) { - auto zz = static_cast(z); - auto zc = static_cast(c); - return (sd::math::nd4j_exp(c) - zz * zc + (zz * sd::math::nd4j_log(z) - zz + static_cast(0.5f) * sd::math::nd4j_log(static_cast(DOUBLE_PI_X) * zz))); - } - - op_def static Z op(X z) { - auto zz = static_cast(z); - return (zz * sd::math::nd4j_log(z) - zz + static_cast(0.5f) * sd::math::nd4j_log(static_cast(DOUBLE_PI_X) * zz)); - } - - // op for MetaOps - op_def static X op(X z, Y *params) { - return (sd::math::nd4j_exp(params[0]) - z * params[0] + (z * sd::math::nd4j_log(z) - z + static_cast(0.5f) * sd::math::nd4j_log(DOUBLE_PI_X * z))); - } - }; - - template - class LogPoissonLoss { - - public: - op_def static Z op(X z, Y c) { - auto zz = static_cast(z); - auto zc = static_cast(c); - return (sd::math::nd4j_exp(c) - zz * zc); - } - - op_def static Z op(X z, Y c, Z *params) { - auto zz = static_cast(z); - auto zc = static_cast(c); - return (sd::math::nd4j_exp(c) - zz * zc); - } - - op_def static Z op(X z) { - return static_cast(z); - } - - // op for MetaOps - op_def static Z op(X z, Y *params) { - return (sd::math::nd4j_exp(params[0]) - static_cast(z) * static_cast(params[0])); - } - }; - - template - class Multiply { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d1 * d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d1 * d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(d1 * params[0]); - } - - op_def static X startingValue() { - return static_cast(1.f); - } - }; - - template - class Divide { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d1 / d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d1 / d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(d1 / params[0]); - } - - op_def static X startingValue() { - return static_cast(1); - } - }; - - template - class DivideNoNan { - public: - op_def static Z op(X d1, Y d2) { - if (d2 == (Y)0) return (Z)0; - return static_cast(d1 / d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - if (d2 == (Y)0) return (Z)0; - return static_cast(d1 / d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - if (params[0] == (Y)0) return (Z)0; - return static_cast(d1 / params[0]); - } - - op_def static X startingValue() { - return static_cast(1); - } - }; - - template - class SafeDivide { - public: - op_def static Z op(X d1, Y d2) { - if(d2 == static_cast(0)) - return static_cast(0); - return static_cast(d1 / d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - if(d2 == static_cast(0)) - return static_cast(0); - return static_cast(d1 / d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - if(params[0] == static_cast(0)) - return static_cast(0); - return static_cast(d1 / params[0]); - } - }; - - template - class FloorDiv { - public: - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_floor(static_cast(d1 / d2)); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_floor(static_cast(d1 / d2)); - } - - op_def static Z op(X d1) { - return sd::math::nd4j_floor(static_cast(d1)); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return sd::math::nd4j_floor(static_cast(d1 / params[0])); - } - }; - - template - class TruncateDiv { - public: - op_def static Z op(X d1, Y d2) { - auto i1 = static_cast(d1); - auto i2 = static_cast(d2); - return static_cast(i1 / i2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - auto i1 = static_cast(d1); - auto i2 = static_cast(d2); - return static_cast(i1 / i2); - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - auto i1 = static_cast(d1); - auto i2 = static_cast(params[0]); - return static_cast(i1 / i2); - } - }; - - template - class TruncateMod { - public: - op_def static Z op(X d1, Y d2) { - auto i1 = static_cast(d1); - auto i2 = static_cast(d2); - return static_cast(i1 % i2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - auto i1 = static_cast(d1); - auto i2 = static_cast(d2); - return static_cast(i1 % i2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - auto i1 = static_cast(d1); - auto i2 = static_cast(params[0]); - return static_cast(i1 % i2); - } - }; - - template - class Remainder { - public: - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_remainder(d1, d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_remainder(d1, d2); - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return sd::math::nd4j_remainder(d1, params[0]); - } - }; - - template - class FMod { - public: - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_fmod(d1, d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_fmod(d1, d2); - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return sd::math::nd4j_fmod(d1, params[0]); - } - }; - - template - class FloorMod { - public: - op_def static Z op(X d1, Y d2) { - auto m = sd::math::nd4j_fmod(d1, d2); - return (d1 < static_cast(0)) == (d2 < static_cast(0)) ? m : sd::math::nd4j_fmod(m + static_cast(d2), d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - auto m = sd::math::nd4j_fmod(d1, d2); - return (d1 < static_cast(0.0f)) == (d2 < static_cast(0)) ? m : sd::math::nd4j_fmod(m + static_cast(d2), d2); - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return op(d1, params[0]); - } - }; - - template - class ReverseDivide { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d2 / d1); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d2 / d1); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(params[0] / d1); - } - }; - - template - class CopyPws { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - op_def static Z op(X d1, Y *params) { - return static_cast(d1); - } - }; - - template - class Copy { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1; - } - }; - - - template - class Copy2 { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - - op_def static Z op(X d1, Y *params) { - return static_cast(d1); - } - }; - - template - class Axpy { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d2 + d1); - } - - op_def static Z op(X d1, Y d2, Z *params) { - auto alpha = params[0]; - return alpha * static_cast(d1) + static_cast(d2); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - }; - - template - class Assign { - public: - no_op_exec_special_any - no_op_exec_special_any_cuda - - op_def static Z op(X d1, X *params) { - return static_cast(d1); - } - }; - - template - class And { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - op_def static Z op(X d1, X d2) { - return d2 + d1; - } - - op_def static Z op(X d1, X d2, X *params) { - if (params != nullptr) { - auto comp = params[0]; - return d1 != comp && d2 != comp ? static_cast(1) : static_cast(0); - } else { - auto b1 = static_cast(d1); - auto b2 = static_cast(d2); - - return (b1 && b2) ? static_cast(1) : static_cast(0); - } - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, X *params) { - return static_cast(119); - } - }; - - template - class IntOr { - public: - - op_def static X op(X d1, X d2) { - return d2 | d1; - } - - op_def static X op(X d1, X d2, X *params) { - return op(d1, d2); - } - }; - - template - class IntAnd { - public: - - op_def static X op(X d1, X d2) { - return d2 & d1; - } - - op_def static X op(X d1, X d2, X *params) { - return op(d1, d2); - } - }; - - template - class IntXor { - public: - - op_def static X op(X d1, X d2) { - return d2 ^ d1; - } - - op_def static X op(X d1, X d2, X *params) { - return op(d1, d2); - } - }; - - template - class ShiftLeft { - public: - - op_def static X op(X d1, X d2) { - return d1 << d2; - } - - op_def static X op(X d1, X d2, X *params) { - return op(d1, d2); - } - }; - - template - class ShiftRight { - public: - - op_def static X op(X d1, X d2) { - return d1 >> d2; - } - - op_def static X op(X d1, X d2, X *params) { - return op(d1, d2); - } - }; - - template - class CyclicShiftLeft { - public: - - op_def static X op(X d1, X d2) { - return sd::math::nd4j_rotl(d1, d2); - } - - op_def static X op(X d1, X d2, X *params) { - return op(d1, d2); - } - }; - - template - class CyclicShiftRight { - public: - - op_def static X op(X d1, X d2) { - return sd::math::nd4j_rotr(d1, d2); - } - - op_def static X op(X d1, X d2, X *params) { - return op(d1, d2); - } - }; - - - template - class Or { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - op_def static Z op(X d1, X d2) { - return d2 + d1; - } - - op_def static Z op(X d1, X d2, X *params) { - if (params != nullptr) { - auto comp = params[0]; - - return d1 != comp || d2 != comp ? static_cast(1) : static_cast(0); - } else { - auto b1 = static_cast(d1); - auto b2 = static_cast(d2); - - return b1 || b2 ? static_cast(1) : static_cast(0); - } - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, X *params) { - return static_cast(119); - } - }; - - template - class Xor { - public: - - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - op_def static Z op(X d1, X d2) { - return d2 + d1; - } - - op_def static Z op(X d1, X d2, X *params) { - if (params != nullptr) { - auto comp = params[0]; - - return ((d1 == comp && d2 != comp) || (d1 != comp && d2 == comp)) ? static_cast(1) : static_cast(0); - } else { - auto b1 = static_cast(d1); - auto b2 = static_cast(d2); - - return (!b1 && b2 )||(b1 && !b2) ? static_cast(1) : static_cast(0); - } - } - - op_def static Z op(X d1) { - return d1; - } - }; - - - template - class Not { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - op_def static Z op(X d1, X d2) { - return static_cast(0); - } - - op_def static Z op(X d1, X d2, X *params) { - return d1 != d2 ? static_cast(1) : static_cast(0); - } - - // this transform op should run only on boolean input - op_def static Z op(X d1, X *params) { - auto b1 = static_cast(d1); - return !b1; - } - }; - - template - class LogicalNot { - public: - op_def static Z op(X d1, Y d2) { - return !((int) d1 && (int) d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(!(static_cast(d1) && static_cast(d2))); - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(119); - } - }; - - template - class LogicalXor { - public: - op_def static Z op(X d1, Y d2) { - auto i1 = static_cast(d1); - auto i2 = static_cast(d2); - - return (i1 | i2) &~ (i1 & i2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(119); - } - }; - - template - class LogicalAnd { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d1) & static_cast(d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - op_def static Z op(Y d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(119); - } - }; - - template - class LogicalOr { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d1) | static_cast(d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - op_def static Z op(X d1) { - return d1; - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return static_cast(119); - } - }; - - - template - class Mod { - public: - /* - - // just a optional note, feel free to remove later - - op_def static half op(half d1, half d2, half *params) { - return __float2half(simdOps::Mod::op(__half2float(d1), __half2float(d2), nullptr)); - } - */ - - op_def static Z op(X d1, Y d2) { - return static_cast(d1) % static_cast(d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - // op for MetaOp - op_def static Z op(X d1, Y *params) { - return op(d1, params[0]); - } - }; - - template - class ReverseMod { - public: - op_def static Z op(X d1, Y d2) { - return static_cast(d2) % static_cast(d1); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - // op for MetaOp - op_def static Z op(X d1, Y *params) { - return op(d1, params[0]); - } - }; - - /** - * Whether 2 elements in an array - * are epsilion equal - */ - template - class Epsilon { - public: - - op_def static Z op(X d1, X d2) { - X diff = d1 - d2; - X absDiff = sd::math::nd4j_abs(diff); - if (absDiff <= static_cast(MIN_V)) - return static_cast(1); - return static_cast(0); - } - - op_def static Z op(X d1, X d2, X *params) { - return op(d1, d2); - } - - op_def static Z op(X d1, X *params) { - return d1; - } - }; - - - template - class EqualTo { - public: - op_def static Z op(X d1, X d2) { - return d1 == d2; - } - - op_def static Z op(X d1, X d2, X *params) { - return op(d1, d2); - } - - op_def static Z op(X d1, X *params) { - return d1; - } - }; - - - template - class NotEqualTo { - public: - op_def static Z op(X d1, X d2) { - return d1 != d2; - } - - op_def static Z op(X d1, X d2, X *params) { - return op(d1, d2); - } - - op_def static Z op(X d1, X *params) { - return d1; - } - }; - - - - template - class GreaterThanOrEqual { - public: - op_def static Z op(X d1, X d2) { - return d1 >= d2; - } - - op_def static Z op(X d1, X d2, X *params) { - return op(d1, d2); - } - - // FIXME: this signature clashes with MetaOp stuff - op_def static Z op(X d1, X *params) { - return d1; - } - }; - - - template - class GreaterThan { - public: - op_def static Z op(X d1, X d2) { - return d1 > d2; - } - - op_def static Z op(X d1, X d2, X *params) { - return op(d1, d2); - } - - // FIXME: this signature clashes with MetaOp stuff - op_def static Z op(X d1, X *params) { - return d1; - } - - }; - - - template - class LessThan { - public: - op_def static Z op(X d1, X d2) { - return d1 < d2; - } - - op_def static Z op(X d1, X d2, X *params) { - return op(d1, d2); - } - - op_def static Z op(X d1, X *params) { - return d1; - } - - }; - - - template - class LessThanOrEqual { - public: - op_def static Z op(X d1, X d2) { - return d1 <= d2; - } - - op_def static Z op(X d1, X d2, X *params) { - return op(d1, d2); - } - - op_def static Z op(X d1, X *params) { - return d1; - } - - }; - - - template - class Abs { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_abs(d1); - } - }; - - - template - class Ceiling { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_ceil(d1); - } - }; - - - template - class Cosine { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_cos(d1); - } - }; - - - template - class Exp { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_exp(d1); - } - }; - - - template - class HardTanhDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return ((d1 >= static_cast(-1.f) && d1 <= static_cast(1.f)) ? static_cast(1.f) : static_cast(0.f)); - } - }; - - - template - class HardTanh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - if (d1 < static_cast(-1)) - return static_cast(-1); - else if (d1 > static_cast(1)) - return static_cast(1); - else - return d1; - - } - }; - - - template - class Floor { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_floor(d1); - } - }; - - - template - class Log { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_log(d1); - } - }; - - template - class Log1p { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_log(1 + d1); - } - }; - - template - class LogX { - public: - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_log(d1) / sd::math::nd4j_log(d2) ; - } - }; - - template - class StabilizeFP16 { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - if (d1 <= static_cast(0)) - return static_cast(sd::DataTypeUtils::min()); - else return d1; - } - }; - - template - class StabilizeX { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - if (d1 <= static_cast(0)) - return sd::DataTypeUtils::min(); - else return d1; - } - }; - - template - class SpecialDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 * (static_cast(1.f) - d1); - } - }; - - - template - class Neg { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return -d1; - } - }; - - template - class Erf { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_erf(d1); - } - }; - - - template - class Erfc { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_erfc(d1); - } - }; - - template - class Reciprocal { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda -// op_def static T op(T d1) { -// return (T(1.0f) / d1); -// } - // op for MetaOps - op_def static X op(X d1, X *params) { - return (static_cast(1) / d1); - } - }; - - template - class Sqr { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return sd::math::nd4j_pow(d1, static_cast(2)); - } - - op_def static Z op(X d1) { - return sd::math::nd4j_pow(d1, static_cast(2)); - } - }; - - - template - class RelativeError { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_re(d1, d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - op_def static Z op(X d1) { - return static_cast(0); - } - }; - - template - class BinaryRelativeError { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - X threshold = params[0]; - return sd::math::nd4j_re(d1, d2) > threshold ? static_cast(1) : static_cast(0); - } - - op_def static Z op(X d1) { - return static_cast(0); - } - }; - - template - class BinaryMinimumAbsoluteRelativeError { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, X *params) { - X d2 = params[0]; - X thresholdRelative = params[1]; - X thresholdAbsolute = params[2]; - return sd::math::nd4j_re(d1, d2) > thresholdRelative ? (sd::math::nd4j_abs(d1 - static_cast(d2)) < thresholdAbsolute ? static_cast(0) : static_cast(1)) : static_cast(0); - } - - op_def static Z op(X d1, Y d2, Z *params) { - X thresholdRelative = params[0]; - X thresholdAbsolute = params[1]; - return sd::math::nd4j_re(d1, d2) > thresholdRelative ? (sd::math::nd4j_abs(d1 - static_cast(d2)) < thresholdAbsolute ? static_cast(0) : static_cast(1)) : static_cast(0); - } - - op_def static Z op(X d1) { - return static_cast(0); - } - }; - - template - class ReversePow { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return sd::math::nd4j_pow(params[0], d1); - } - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_pow(d2, d1); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_pow(d2, d1); - } - - op_def static Z op(X d1) { - return d1; - } - }; - - template - class Pow { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return sd::math::nd4j_pow(d1, params[0]); - } - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_pow(d1, d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_pow(d1, d2); - } - - op_def static Z op(X d1) { - return d1; - } - }; - - - template - class PowDerivative { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return params[0] * sd::math::nd4j_pow(d1, static_cast(params[0]) - static_cast(1.f)); - } - - op_def static Z op(X d1, Y d2) { - return static_cast(d2) * sd::math::nd4j_pow(d1, static_cast(d2) - static_cast(1.f)); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return static_cast(d2) * sd::math::nd4j_pow(d1, static_cast(d2) - static_cast(1.f)); - } - - op_def static Z op(X d1) { - return static_cast(d1); - } - }; - - - template - class IGamma { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return sd::math::nd4j_igamma(d1, params[0]); - } - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_igamma(d1, d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_igamma(d1, d2); - } - - op_def static Z op(X d1) { - return d1; - } - }; - - template - class IGammac { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return sd::math::nd4j_igammac(d1, params[0]); - } - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_igammac(d1, d2); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_igammac(d1, d2); - } - - op_def static Z op(X d1) { - return d1; - } - }; - - template - class Round { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_round(d1); - } - }; - - template - class IsNan { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - op_def static Z op(X d1, X *params) { - return sd::math::nd4j_isnan(d1) ? static_cast(1) : static_cast(0); - } - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; - - - template - class Expm1 { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_exp(d1) - static_cast(1); - } - }; - - template - class IsPositive { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda +template +class Add { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d1 + d2); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d1 + d2); } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(d1 + params[0]); } + + op_def static X startingValue() { return static_cast(0.f); } +}; + +template +class NewAdd { + public: + op_def static X op(X d1, Y d2, X *params) { return d1 + d2; } +}; + +template +class Subtract { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d1 - d2); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d1 - d2); } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(d1 - params[0]); } +}; + +template +class SquaredSubtract { + public: + op_def static Z op(X d1, Y d2) { + auto d = static_cast(d1 - d2); + return d * d; + } + + op_def static Z op(X d1, Y d2, Z *params) { + auto d = static_cast(d1 - d2); + return d * d; + } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + auto d = static_cast(d1 - params[0]); + return d * d; + } +}; + +template +class SquaredReverseSubtract { + public: + op_def static Z op(X d1, Y d2) { + auto d = static_cast(d2 - d1); + return d * d; + } + + op_def static Z op(X d1, Y d2, Z *params) { + auto d = static_cast(d2 - d1); + return d * d; + } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + auto d = static_cast(params[0] - d1); + return d * d; + } +}; + +template +class ReverseSubtract { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d2 - d1); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d2 - d1); } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(params[0] - d1); } +}; + +template +class LogPoissonLossFull { + public: + op_def static Z op(X z, Y c) { + auto zz = static_cast(z); + auto zc = static_cast(c); + return (sd::math::nd4j_exp(c) - zz * zc + + (zz * sd::math::nd4j_log(z) - zz + + static_cast(0.5f) * + sd::math::nd4j_log(static_cast(DOUBLE_PI_X) * zz))); + } + + op_def static Z op(X z, Y c, Z *params) { + auto zz = static_cast(z); + auto zc = static_cast(c); + return (sd::math::nd4j_exp(c) - zz * zc + + (zz * sd::math::nd4j_log(z) - zz + + static_cast(0.5f) * + sd::math::nd4j_log(static_cast(DOUBLE_PI_X) * zz))); + } + + op_def static Z op(X z) { + auto zz = static_cast(z); + return (zz * sd::math::nd4j_log(z) - zz + + static_cast(0.5f) * + sd::math::nd4j_log(static_cast(DOUBLE_PI_X) * zz)); + } + + // op for MetaOps + op_def static X op(X z, Y *params) { + return (sd::math::nd4j_exp(params[0]) - z * params[0] + + (z * sd::math::nd4j_log(z) - z + + static_cast(0.5f) * sd::math::nd4j_log(DOUBLE_PI_X * z))); + } +}; + +template +class LogPoissonLoss { + public: + op_def static Z op(X z, Y c) { + auto zz = static_cast(z); + auto zc = static_cast(c); + return (sd::math::nd4j_exp(c) - zz * zc); + } + + op_def static Z op(X z, Y c, Z *params) { + auto zz = static_cast(z); + auto zc = static_cast(c); + return (sd::math::nd4j_exp(c) - zz * zc); + } + + op_def static Z op(X z) { return static_cast(z); } + + // op for MetaOps + op_def static Z op(X z, Y *params) { + return (sd::math::nd4j_exp(params[0]) - + static_cast(z) * static_cast(params[0])); + } +}; + +template +class Multiply { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d1 * d2); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d1 * d2); } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(d1 * params[0]); } + + op_def static X startingValue() { return static_cast(1.f); } +}; + +template +class Divide { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d1 / d2); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d1 / d2); } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(d1 / params[0]); } + + op_def static X startingValue() { return static_cast(1); } +}; + +template +class DivideNoNan { + public: + op_def static Z op(X d1, Y d2) { + if (d2 == (Y)0) return (Z)0; + return static_cast(d1 / d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + if (d2 == (Y)0) return (Z)0; + return static_cast(d1 / d2); + } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + if (params[0] == (Y)0) return (Z)0; + return static_cast(d1 / params[0]); + } + + op_def static X startingValue() { return static_cast(1); } +}; + +template +class SafeDivide { + public: + op_def static Z op(X d1, Y d2) { + if (d2 == static_cast(0)) return static_cast(0); + return static_cast(d1 / d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + if (d2 == static_cast(0)) return static_cast(0); + return static_cast(d1 / d2); + } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + if (params[0] == static_cast(0)) return static_cast(0); + return static_cast(d1 / params[0]); + } +}; + +template +class FloorDiv { + public: + op_def static Z op(X d1, Y d2) { + return sd::math::nd4j_floor(static_cast(d1 / d2)); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_floor(static_cast(d1 / d2)); + } + + op_def static Z op(X d1) { + return sd::math::nd4j_floor(static_cast(d1)); + } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + return sd::math::nd4j_floor(static_cast(d1 / params[0])); + } +}; + +template +class TruncateDiv { + public: + op_def static Z op(X d1, Y d2) { + auto i1 = static_cast(d1); + auto i2 = static_cast(d2); + return static_cast(i1 / i2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + auto i1 = static_cast(d1); + auto i2 = static_cast(d2); + return static_cast(i1 / i2); + } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + auto i1 = static_cast(d1); + auto i2 = static_cast(params[0]); + return static_cast(i1 / i2); + } +}; + +template +class TruncateMod { + public: + op_def static Z op(X d1, Y d2) { + auto i1 = static_cast(d1); + auto i2 = static_cast(d2); + return static_cast(i1 % i2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + auto i1 = static_cast(d1); + auto i2 = static_cast(d2); + return static_cast(i1 % i2); + } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + auto i1 = static_cast(d1); + auto i2 = static_cast(params[0]); + return static_cast(i1 % i2); + } +}; + +template +class Remainder { + public: + op_def static Z op(X d1, Y d2) { + return sd::math::nd4j_remainder(d1, d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_remainder(d1, d2); + } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + return sd::math::nd4j_remainder(d1, params[0]); + } +}; + +template +class FMod { + public: + op_def static Z op(X d1, Y d2) { + return sd::math::nd4j_fmod(d1, d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_fmod(d1, d2); + } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { + return sd::math::nd4j_fmod(d1, params[0]); + } +}; + +template +class FloorMod { + public: + op_def static Z op(X d1, Y d2) { + auto m = sd::math::nd4j_fmod(d1, d2); + return (d1 < static_cast(0)) == (d2 < static_cast(0)) + ? m + : sd::math::nd4j_fmod(m + static_cast(d2), d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + auto m = sd::math::nd4j_fmod(d1, d2); + return (d1 < static_cast(0.0f)) == (d2 < static_cast(0)) + ? m + : sd::math::nd4j_fmod(m + static_cast(d2), d2); + } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return op(d1, params[0]); } +}; + +template +class ReverseDivide { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d2 / d1); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d2 / d1); } + + op_def static Z op(X d1) { return static_cast(d1); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(params[0] / d1); } +}; + +template +class CopyPws { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d2); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d2); } + + op_def static Z op(X d1) { return static_cast(d1); } + + op_def static Z op(X d1, Y *params) { return static_cast(d1); } +}; + +template +class Copy { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1; + } +}; + +template +class Copy2 { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d2); } + + op_def static Z op(X d1, Y d2, Z *params) { return static_cast(d2); } + + op_def static Z op(X d1) { return static_cast(d1); } + + op_def static Z op(X d1, Y *params) { return static_cast(d1); } +}; + +template +class Axpy { + public: + op_def static Z op(X d1, Y d2) { return static_cast(d2 + d1); } + + op_def static Z op(X d1, Y d2, Z *params) { + auto alpha = params[0]; + return alpha * static_cast(d1) + static_cast(d2); + } + + op_def static Z op(X d1) { return static_cast(d1); } +}; + +template +class Assign { + public: + no_op_exec_special_any no_op_exec_special_any_cuda + + op_def static Z + op(X d1, X *params) { + return static_cast(d1); + } +}; + +template +class And { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + op_def static Z + op(X d1, X d2) { + return d2 + d1; + } + + op_def static Z op(X d1, X d2, X *params) { + if (params != nullptr) { + auto comp = params[0]; + return d1 != comp && d2 != comp ? static_cast(1) : static_cast(0); + } else { + auto b1 = static_cast(d1); + auto b2 = static_cast(d2); - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + return (b1 && b2) ? static_cast(1) : static_cast(0); + } + } - op_def static Z op(X d1, X *params) { - return d1 > (X)0.f; - } + op_def static Z op(X d1) { return d1; } - op_def static X startingValue(const X *input) { - return static_cast(0); - } + // op for MetaOps + op_def static Z op(X d1, X *params) { return static_cast(119); } +}; - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput + old; - } +template +class IntOr { + public: + op_def static X op(X d1, X d2) { return d2 | d1; } + + op_def static X op(X d1, X d2, X *params) { return op(d1, d2); } +}; + +template +class IntAnd { + public: + op_def static X op(X d1, X d2) { return d2 & d1; } + + op_def static X op(X d1, X d2, X *params) { return op(d1, d2); } +}; + +template +class IntXor { + public: + op_def static X op(X d1, X d2) { return d2 ^ d1; } + + op_def static X op(X d1, X d2, X *params) { return op(d1, d2); } +}; + +template +class ShiftLeft { + public: + op_def static X op(X d1, X d2) { return d1 << d2; } + + op_def static X op(X d1, X d2, X *params) { return op(d1, d2); } +}; + +template +class ShiftRight { + public: + op_def static X op(X d1, X d2) { return d1 >> d2; } + + op_def static X op(X d1, X d2, X *params) { return op(d1, d2); } +}; + +template +class CyclicShiftLeft { + public: + op_def static X op(X d1, X d2) { return sd::math::nd4j_rotl(d1, d2); } + + op_def static X op(X d1, X d2, X *params) { return op(d1, d2); } +}; + +template +class CyclicShiftRight { + public: + op_def static X op(X d1, X d2) { return sd::math::nd4j_rotr(d1, d2); } + + op_def static X op(X d1, X d2, X *params) { return op(d1, d2); } +}; + +template +class Or { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + op_def static Z + op(X d1, X d2) { + return d2 + d1; + } + + op_def static Z op(X d1, X d2, X *params) { + if (params != nullptr) { + auto comp = params[0]; + + return d1 != comp || d2 != comp ? static_cast(1) : static_cast(0); + } else { + auto b1 = static_cast(d1); + auto b2 = static_cast(d2); + + return b1 || b2 ? static_cast(1) : static_cast(0); + } + } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, X *params) { return static_cast(119); } +}; + +template +class Xor { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + op_def static Z + op(X d1, X d2) { + return d2 + d1; + } + + op_def static Z op(X d1, X d2, X *params) { + if (params != nullptr) { + auto comp = params[0]; + + return ((d1 == comp && d2 != comp) || (d1 != comp && d2 == comp)) + ? static_cast(1) + : static_cast(0); + } else { + auto b1 = static_cast(d1); + auto b2 = static_cast(d2); + + return (!b1 && b2) || (b1 && !b2) ? static_cast(1) : static_cast(0); + } + } + + op_def static Z op(X d1) { return d1; } +}; + +template +class Not { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + op_def static Z + op(X d1, X d2) { + return static_cast(0); + } + + op_def static Z op(X d1, X d2, X *params) { + return d1 != d2 ? static_cast(1) : static_cast(0); + } + + // this transform op should run only on boolean input + op_def static Z op(X d1, X *params) { + auto b1 = static_cast(d1); + return !b1; + } +}; + +template +class LogicalNot { + public: + op_def static Z op(X d1, Y d2) { return !((int)d1 && (int)d2); } + + op_def static Z op(X d1, Y d2, Z *params) { + return static_cast(!(static_cast(d1) && static_cast(d2))); + } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(119); } +}; + +template +class LogicalXor { + public: + op_def static Z op(X d1, Y d2) { + auto i1 = static_cast(d1); + auto i2 = static_cast(d2); + + return (i1 | i2) & ~(i1 & i2); + } + + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(119); } +}; + +template +class LogicalAnd { + public: + op_def static Z op(X d1, Y d2) { + return static_cast(d1) & static_cast(d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } + + op_def static Z op(Y d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(119); } +}; + +template +class LogicalOr { + public: + op_def static Z op(X d1, Y d2) { + return static_cast(d1) | static_cast(d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } + + op_def static Z op(X d1) { return d1; } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return static_cast(119); } +}; + +template +class Mod { + public: + /* - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput + old; - } + // just a optional note, feel free to remove later - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; + op_def static half op(half d1, half d2, half *params) { + return __float2half(simdOps::Mod::op(__half2float(d1), + __half2float(d2), nullptr)); + } + */ - template - class IsNegative { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda + op_def static Z op(X d1, Y d2) { + return static_cast(d1) % static_cast(d2); + } - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } - op_def static Z op(X d1, X *params) { - return d1 < (X)0.f; - } + // op for MetaOp + op_def static Z op(X d1, Y *params) { return op(d1, params[0]); } +}; - op_def static X startingValue(const X *input) { - return static_cast(0); - } +template +class ReverseMod { + public: + op_def static Z op(X d1, Y d2) { + return static_cast(d2) % static_cast(d1); + } - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput + old; - } + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput + old; - } + // op for MetaOp + op_def static Z op(X d1, Y *params) { return op(d1, params[0]); } +}; - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; +/** + * Whether 2 elements in an array + * are epsilion equal + */ +template +class Epsilon { + public: + op_def static Z op(X d1, X d2) { + X diff = d1 - d2; + X absDiff = sd::math::nd4j_abs(diff); + if (absDiff <= static_cast(MIN_V)) return static_cast(1); + return static_cast(0); + } - template - class IsInf { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda + op_def static Z op(X d1, X d2, X *params) { return op(d1, d2); } - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + op_def static Z op(X d1, X *params) { return d1; } +}; - op_def static Z op(X d1, X *params) { - return sd::math::nd4j_isinf(d1) ? static_cast(1) : static_cast(0); - } +template +class EqualTo { + public: + op_def static Z op(X d1, X d2) { return d1 == d2; } + + op_def static Z op(X d1, X d2, X *params) { return op(d1, d2); } - op_def static X startingValue(const X *input) { - return static_cast(0); - } + op_def static Z op(X d1, X *params) { return d1; } +}; + +template +class NotEqualTo { + public: + op_def static Z op(X d1, X d2) { return d1 != d2; } + + op_def static Z op(X d1, X d2, X *params) { return op(d1, d2); } + + op_def static Z op(X d1, X *params) { return d1; } +}; + +template +class GreaterThanOrEqual { + public: + op_def static Z op(X d1, X d2) { return d1 >= d2; } + + op_def static Z op(X d1, X d2, X *params) { return op(d1, d2); } + + // FIXME: this signature clashes with MetaOp stuff + op_def static Z op(X d1, X *params) { return d1; } +}; + +template +class GreaterThan { + public: + op_def static Z op(X d1, X d2) { return d1 > d2; } + + op_def static Z op(X d1, X d2, X *params) { return op(d1, d2); } + + // FIXME: this signature clashes with MetaOp stuff + op_def static Z op(X d1, X *params) { return d1; } +}; + +template +class LessThan { + public: + op_def static Z op(X d1, X d2) { return d1 < d2; } + + op_def static Z op(X d1, X d2, X *params) { return op(d1, d2); } + + op_def static Z op(X d1, X *params) { return d1; } +}; + +template +class LessThanOrEqual { + public: + op_def static Z op(X d1, X d2) { return d1 <= d2; } + + op_def static Z op(X d1, X d2, X *params) { return op(d1, d2); } + + op_def static Z op(X d1, X *params) { return d1; } +}; + +template +class Abs { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_abs(d1); + } +}; + +template +class Ceiling { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_ceil(d1); + } +}; + +template +class Cosine { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_cos(d1); + } +}; + +template +class Exp { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_exp(d1); + } +}; + +template +class HardTanhDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return ((d1 >= static_cast(-1.f) && d1 <= static_cast(1.f)) + ? static_cast(1.f) + : static_cast(0.f)); + } +}; + +template +class HardTanh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + if (d1 < static_cast(-1)) + return static_cast(-1); + else if (d1 > static_cast(1)) + return static_cast(1); + else + return d1; + } +}; + +template +class Floor { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_floor(d1); + } +}; + +template +class Log { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_log(d1); + } +}; + +template +class Log1p { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_log(1 + d1); + } +}; + +template +class LogX { + public: + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_log(d1) / sd::math::nd4j_log(d2); + } +}; + +template +class StabilizeFP16 { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + if (d1 <= static_cast(0)) + return static_cast(sd::DataTypeUtils::min()); + else + return d1; + } +}; + +template +class StabilizeX { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + if (d1 <= static_cast(0)) + return sd::DataTypeUtils::min(); + else + return d1; + } +}; + +template +class SpecialDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 * (static_cast(1.f) - d1); + } +}; + +template +class Neg { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return -d1; + } +}; + +template +class Erf { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_erf(d1); + } +}; + +template +class Erfc { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_erfc(d1); + } +}; + +template +class Reciprocal { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + // op_def static T op(T d1) { + // return (T(1.0f) / d1); + // } + // op for MetaOps + op_def static X + op(X d1, X *params) { + return (static_cast(1) / d1); + } +}; + +template +class Sqr { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return sd::math::nd4j_pow(d1, static_cast(2)); + } + + op_def static Z op(X d1) { + return sd::math::nd4j_pow(d1, static_cast(2)); + } +}; + +template +class RelativeError { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Y d2) { + return sd::math::nd4j_re(d1, d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } + + op_def static Z op(X d1) { return static_cast(0); } +}; + +template +class BinaryRelativeError { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + X threshold = params[0]; + return sd::math::nd4j_re(d1, d2) > threshold ? static_cast(1) + : static_cast(0); + } + + op_def static Z op(X d1) { return static_cast(0); } +}; + +template +class BinaryMinimumAbsoluteRelativeError { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, X *params) { + X d2 = params[0]; + X thresholdRelative = params[1]; + X thresholdAbsolute = params[2]; + return sd::math::nd4j_re(d1, d2) > thresholdRelative + ? (sd::math::nd4j_abs(d1 - static_cast(d2)) < + thresholdAbsolute + ? static_cast(0) + : static_cast(1)) + : static_cast(0); + } + + op_def static Z op(X d1, Y d2, Z *params) { + X thresholdRelative = params[0]; + X thresholdAbsolute = params[1]; + return sd::math::nd4j_re(d1, d2) > thresholdRelative + ? (sd::math::nd4j_abs(d1 - static_cast(d2)) < + thresholdAbsolute + ? static_cast(0) + : static_cast(1)) + : static_cast(0); + } + + op_def static Z op(X d1) { return static_cast(0); } +}; + +template +class ReversePow { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return sd::math::nd4j_pow(params[0], d1); + } + + op_def static Z op(X d1, Y d2) { return sd::math::nd4j_pow(d2, d1); } + + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_pow(d2, d1); + } + + op_def static Z op(X d1) { return d1; } +}; + +template +class Pow { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return sd::math::nd4j_pow(d1, params[0]); + } + + op_def static Z op(X d1, Y d2) { return sd::math::nd4j_pow(d1, d2); } + + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_pow(d1, d2); + } + + op_def static Z op(X d1) { return d1; } +}; + +template +class PowDerivative { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return params[0] * sd::math::nd4j_pow( + d1, static_cast(params[0]) - static_cast(1.f)); + } + + op_def static Z op(X d1, Y d2) { + return static_cast(d2) * + sd::math::nd4j_pow( + d1, static_cast(d2) - static_cast(1.f)); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return static_cast(d2) * + sd::math::nd4j_pow( + d1, static_cast(d2) - static_cast(1.f)); + } + + op_def static Z op(X d1) { return static_cast(d1); } +}; + +template +class IGamma { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return sd::math::nd4j_igamma(d1, params[0]); + } + + op_def static Z op(X d1, Y d2) { + return sd::math::nd4j_igamma(d1, d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_igamma(d1, d2); + } + + op_def static Z op(X d1) { return d1; } +}; + +template +class IGammac { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return sd::math::nd4j_igammac(d1, params[0]); + } + + op_def static Z op(X d1, Y d2) { + return sd::math::nd4j_igammac(d1, d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_igammac(d1, d2); + } + + op_def static Z op(X d1) { return d1; } +}; + +template +class Round { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_round(d1); + } +}; + +template +class IsNan { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + op_def static Z + op(X d1, X *params) { + return sd::math::nd4j_isnan(d1) ? static_cast(1) : static_cast(0); + } + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; + +template +class Expm1 { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_exp(d1) - static_cast(1); + } +}; + +template +class IsPositive { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + op_def static Z + op(X d1, X *params) { + return d1 > (X)0.f; + } + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput + old; - } +template +class IsNegative { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + op_def static Z + op(X d1, X *params) { + return d1 < (X)0.f; + } + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; + +template +class IsInf { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + op_def static Z + op(X d1, X *params) { + return sd::math::nd4j_isinf(d1) ? static_cast(1) : static_cast(0); + } + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; + +template +class IsInfOrNan { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + op_def static Z + op(X d1, X *params) { + return sd::math::nd4j_isfin(d1) ? static_cast(0) : static_cast(1); + } + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput == static_cast(0) && old == static_cast(0) + ? static_cast(0) + : static_cast(1); + } + + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput == static_cast(0) && old == static_cast(0) + ? static_cast(0) + : static_cast(1); + } + + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction != static_cast(0); + } +}; + +template +class IsFinite { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + op_def static Z + op(X d1, X *params) { + return sd::math::nd4j_isfin(d1) ? static_cast(1) : static_cast(0); + } + + op_def static X startingValue(const X *input) { return static_cast(1); } + + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput == static_cast(0) || old == static_cast(0) + ? static_cast(0) + : static_cast(1); + } + + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput == static_cast(0) || old == static_cast(0) + ? static_cast(0) + : static_cast(1); + } + + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction != static_cast(0); + } +}; + +template +class ClipByValue { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + if (d1 > params[1]) return params[1]; + if (d1 < params[0]) return params[0]; + return d1; + } +}; + +template +class LstmClip { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + X _v = (X)d2; + if (d1 > _v) + return _v; + else if (d1 < -_v) + return -_v; + else + return d1; + } +}; + +template +class Swish { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 * sd::math::nd4j_sigmoid(d1); + } +}; + +template +class Mish { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 * sd::math::nd4j_tanh(sd::math::nd4j_softplus(d1)); + } +}; + +template +class MishDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + auto ex = sd::math::nd4j_exp(d1); + auto e2x = ex * ex; + auto e3x = ex * ex * ex; + + return (ex * (4 * (d1 + 1) + 4 * e2x + e3x + ex * (4 * d1 + 6))) / + sd::math::nd4j_pow((2 * ex + e2x + 2), (X)2.f); + } +}; + +template +class GELU { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 * sd::math::nd4j_sigmoid(static_cast(1.702f) * d1); + } +}; + +template +class PreciseGELU { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + auto sp = + sd::math::nd4j_sqrt(static_cast(2) / static_cast(M_PI)); + auto xp = d1 + sd::math::nd4j_pow(static_cast(0.044715) * d1, + static_cast(3)); + return (d1 / static_cast(2)) * + (static_cast(1) + sd::math::nd4j_tanh(sp * xp)); + } +}; + +template +class GELUDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + auto x17 = static_cast(1.702f) * d1; + auto ep = sd::math::nd4j_pow(static_cast(M_E), x17); + // (E^(1.702 x) (1. + E^(1.702 x) + 1.702 x))/(1. + E^(1.702 x))^2 + return (ep * (static_cast(1.f) + ep + x17)) / + sd::math::nd4j_pow((static_cast(1.f) + ep), 2); + } +}; + +template +class PreciseGELUDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + auto x79 = static_cast(0.797885) * d1; + auto x03 = sd::math::nd4j_pow(static_cast(0.0356774) * d1, 3); + auto x39 = static_cast(0.398942) * d1; + auto x05 = sd::math::nd4j_pow(static_cast(0.0535161) * d1, 3); + auto scz = sd::math::nd4j_sech(x79 + x03); + // 0.5 + (0.398942 x + 0.0535161 x^3) Sech[0.797885 x + 0.0356774 x^3]^2 + + // 0.5 Tanh[0.797885 x + 0.0356774 x^3] + return static_cast(0.5) + (x39 + x05) * (scz * scz) + + static_cast(0.5) * sd::math::nd4j_tanh(x79 + x03); + } +}; + +template +class SwishDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + X ex = sd::math::nd4j_pow(static_cast(M_E), d1); + return (ex * (d1 + ex + static_cast(1.f))) / + sd::math::nd4j_pow((ex + static_cast(1.f)), + static_cast(2.f)); + } +}; + +template +class LogSigmoid { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_log(sd::math::nd4j_sigmoid(d1)); + } +}; + +template +class LogSigmoidDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + X ex = sd::math::nd4j_pow(M_E, d1); + return static_cast(1.f) / (ex + static_cast(1.f)); + } +}; + +template +class Sigmoid { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_sigmoid(d1); + } +}; + +template +class Affine { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return params[0] * d1 + params[1]; + } +}; + +template +class SigmoidDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_sigmoidderivative(d1); + } +}; + +template +class HardSigmoid { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_min( + static_cast(1), + sd::math::nd4j_max(static_cast(0), (static_cast(0.2f)) * d1 + + static_cast(0.5f))); + } +}; + +template +class HardSigmoidDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 < static_cast(-2.5f) || d1 > static_cast(2.5f) + ? static_cast(0.f) + : static_cast(0.2f); + } +}; + +/** + * Scale to be between a min and max + */ +template +class SetRange { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + auto min = params[0]; + auto max = params[1]; + if (static_cast(d1) >= min && static_cast(d1) <= max) return d1; + if (min == static_cast(0) && max == static_cast(1)) { + auto val = static_cast(1) / + (static_cast(1) + sd::math::nd4j_exp(-d1)); + return (sd::math::nd4j_floor(val * (max - min)) + min); + } + + return (sd::math::nd4j_floor(d1 * (max - min)) + min); + } +}; + +template +class Sin { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_sin(d1); + } +}; + +template +class Square { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 * d1; + } +}; + +template +class Sqrt { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return sd::math::nd4j_sqrt(d1); + } +}; + +template +class RSqrt { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Z *params) { + return static_cast(1) / sd::math::nd4j_sqrt(d1); + } +}; + +template +class Rint { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_rint(d1); + } +}; + +template +class SoftPlus { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_softplus(d1); + } +}; + +template +class Sign { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return (d1 > static_cast(0)) - (d1 < static_cast(0)); + } +}; + +template +class TimesOneMinus { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 * (static_cast(1) - d1); + } +}; + +template +class RationalTanh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + // keep 2/3 as runtime variable, to match precision + auto dis = (static_cast(2) / static_cast(3)) * d1; + + auto tanh = + sd::math::nd4j_sgn(dis) * + (static_cast(1) - + (static_cast(1) / + (static_cast(1) + static_cast(sd::math::nd4j_abs(dis)) + + sd::math::nd4j_pow(dis, static_cast(2)) + + static_cast(1.41645f) * + sd::math::nd4j_pow(dis, static_cast(4))))); + return static_cast(1.7159f) * tanh; + } +}; + +template +class RationalTanhDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + auto dis = (static_cast(2.f) / static_cast(3.f)) * d1; + + auto a = static_cast(1.f) + sd::math::nd4j_abs(dis) + + sd::math::nd4j_pow(dis, static_cast(2.f)) + + static_cast(1.41645f) * + sd::math::nd4j_pow(dis, static_cast(4)); + + auto tDeriv = + (static_cast(1.f) + + sd::math::nd4j_sign(dis) * + (static_cast(2.f) * dis + + static_cast(4.f) * static_cast(1.41645f) * + sd::math::nd4j_pow(dis, static_cast(3)))) / + (a * a); + + return static_cast(1.7159f) * + (static_cast(2.f) / static_cast(3.f)) * tDeriv; + } +}; + +template +class Tanh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_tanh(d1); + } +}; + +template +class ScaledTanh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return params[0] * sd::math::nd4j_tanh(params[1] * d1); + } +}; + +template +class RectifiedTanh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_max(static_cast(0), + sd::math::nd4j_tanh(d1)); + } +}; + +template +class RectifiedTanhDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 > static_cast(0.f) ? sd::math::nd4j_tanhderivative(d1) + : static_cast(0.f); + } +}; + +template +class ATanh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_atanh(d1); + } +}; + +template +class TanhDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_tanhderivative(d1); + } +}; + +template +class Cube { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 * d1 * d1; + } +}; + +template +class CubeDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return static_cast(3) * d1 * d1; + } +}; + +template +class ACos { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_acos(d1); + } +}; + +template +class ASinh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_asinh(d1); + } +}; + +template +class ASinhDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return static_cast(1.f) / + (sd::math::nd4j_sqrt( + sd::math::nd4j_pow(d1, static_cast(2.f)) + + static_cast(1.f))); + } +}; + +template +class ACosh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_acosh(d1); + } +}; + +template +class ACoshDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return static_cast(1.f) / + (sd::math::nd4j_sqrt(d1 - static_cast(1.f)) * + sd::math::nd4j_sqrt(d1 + static_cast(1.f))); + } +}; + +template +class Ones { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return static_cast(1.0f); + } +}; + +template +class SoftSign { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_softsign(d1); + } +}; + +template +class SoftSignDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_softsignderivative(d1); + } +}; + +template +class MatchConditionBool { + public: + no_op_exec_special_bool no_op_exec_special_bool_cuda + + // this op return 1.0 if condition met, 0.0 otherwise + op_def static Z + op(X d1, X *extraParams) { + X compare = extraParams[0]; + X eps = extraParams[1]; + + auto mode = static_cast(extraParams[2]); + // nd4j_printf("value: %f; comp: %f; eps: %f; mode: %i;\n", d1, compare, + // eps, mode); + + switch (mode) { + case 0: // equals + return sd::math::nd4j_abs(d1 - compare) <= eps ? true : false; + case 1: // not equals + return sd::math::nd4j_abs(d1 - compare) > eps ? true : false; + case 2: // less_than + return d1 < compare ? true : false; + case 3: // greater_than + return d1 > compare ? true : false; + case 4: // less_or_equals_than + return d1 <= compare ? true : false; + case 5: // greater_or_equals_than + return d1 >= compare ? true : false; + case 6: // abs_less_than + return sd::math::nd4j_abs(d1) < compare ? true : false; + case 7: // abs_greater_than + return sd::math::nd4j_abs(d1) > compare ? true : false; + case 8: // is inf + return sd::math::nd4j_isinf(d1) ? true : false; + case 9: // is nan + return sd::math::nd4j_isnan(d1) ? true : false; + case 10: + return (d1 == compare) ? true : false; + case 11: + return (d1 != compare) ? true : false; + case 12: // abs_greater_or_equals_than + return sd::math::nd4j_abs(d1) >= compare ? true : false; + case 13: // abs_less_or_equals_than + return sd::math::nd4j_abs(d1) <= compare ? true : false; + case 14: + // isFinite + return !(sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1)); + case 15: + // isInfinite + return sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1); + default: + printf("Undefined match condition: [%i]\n", mode); + } + + return d1; + } +}; + +template +class MatchCondition { + public: + no_op_exec_special no_op_exec_special_cuda + + no_op_exec_special_accumulation_long no_op_exec_special_accumulation_cuda + + op_def static Z + startingValue(const X *input) { + return static_cast(0); + } + + op_def static Z merge(Z old, Z opOutput, X *extraParams) { + return old + opOutput; + } + + op_def static Z update(Z old, Z opOutput, X *extraParams) { + return old + opOutput; + } + + op_def static Z op(X d1, X compare, X eps, int mode) { + switch (mode) { + case 0: // equals + return sd::math::nd4j_abs(d1 - compare) <= eps ? 1 : 0; + case 1: // not equals + return sd::math::nd4j_abs(d1 - compare) > eps ? 1 : 0; + case 2: // less_than + return d1 < compare ? 1 : 0; + case 3: // greater_than + return d1 > compare ? 1 : 0; + case 4: // less_or_equals_than + return d1 <= compare ? 1 : 0; + case 5: // greater_or_equals_than + return d1 >= compare ? 1 : 0; + case 6: // abs_less_than + return sd::math::nd4j_abs(d1) < compare ? 1 : 0; + case 7: // abs_greater_than + return sd::math::nd4j_abs(d1) > compare ? 1 : 0; + case 8: // is inf + return sd::math::nd4j_isinf(d1) ? 1 : 0; + case 9: // is nan + return sd::math::nd4j_isnan(d1) ? 1 : 0; + case 10: + return (d1 == compare) ? 1 : 0; + case 11: + return (d1 != compare) ? 1 : 0; + case 12: // abs_greater_or_equals_than + return sd::math::nd4j_abs(d1) >= compare ? 1 : 0; + case 13: // abs_less_or_equals_than + return sd::math::nd4j_abs(d1) <= compare ? 1 : 0; + case 14: + // isFinite + return !(sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1)) ? 1 : 0; + case 15: + // isInfinite + return sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1) ? 1 : 0; + default: + printf("Undefined match condition: [%i]\n", mode); + } + + return d1; + } + + // this op return 1.0 if condition met, 0.0 otherwise + op_def static Z op(X d1, X compare, X *extraParams) { + X eps = extraParams[1]; + + auto mode = static_cast(extraParams[0]); + + return op(d1, compare, eps, mode); + } + + // this op return 1.0 if condition met, 0.0 otherwise + op_def static Z op(X d1, X *extraParams) { + X compare = extraParams[0]; + X eps = extraParams[1]; + + auto mode = static_cast(extraParams[2]); + + return op(d1, compare, eps, mode); + } + + op_def static Z postProcess(Z reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; + +template +class ELU { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + return sd::math::nd4j_elu(d1, static_cast(d2)); + } +}; + +template +class ELUDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + return sd::math::nd4j_eluderivative(d1, static_cast(d2)); + } +}; + +template +class RELU { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + auto xt = static_cast(d1); + auto xf = static_cast(d2); + return xt < xf ? xf : xt; + } +}; + +template +class RELUDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + auto xt = static_cast(d1); + auto xf = static_cast(d2); + return xt > xf ? static_cast(1.f) : static_cast(0.f); + } +}; + +template +class SXELogitsSmoother { + public: + op_def static Z op(X d1, Y d2, Z *params) { + return d1 * ((X)1.f - (X)d2) + (X)(0.5f) * (X)d2; + } +}; + +template +class RELU6 { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + auto relu = simdOps::RELU::op(d1, d2, params); + return relu < static_cast(6) ? relu : static_cast(6); + } +}; + +template +class LeakyRELU { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + auto val = static_cast(d1); + auto alpha = static_cast(d2); + return val < 0.0f ? alpha * val : val; + } +}; + +template +class SELU { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 > static_cast(0.0f) + ? static_cast(SELU_LAMBDA) * static_cast(d1) + : static_cast(SELU_LAMBDA) * + (static_cast(SELU_ALPHA) * + sd::math::nd4j_exp(d1) - + static_cast(SELU_ALPHA)); + } +}; + +template +class SELUDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1 > static_cast(0.f) + ? static_cast(SELU_LAMBDA) + : static_cast(SELU_ALPHA) * static_cast(SELU_LAMBDA) * + sd::math::nd4j_exp(d1); + } +}; + +template +class LeakyRELUDerivative { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + if (d1 >= static_cast(0)) + return static_cast(1); + else + return static_cast(d2); + } +}; + +template +class ASin { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_asin(d1); + } +}; + +template +class Sinh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_sinh(d1); + } +}; + +template +class SinhDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_cosh(d1); + } +}; + +template +class Cosh { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_cosh(d1); + } +}; + +template +class Tan { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_tan(d1); + } +}; + +template +class TanDerivative { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return static_cast(1.f) / + sd::math::nd4j_pow(sd::math::nd4j_cos(d1), + static_cast(2.0f)); + } +}; + +template +class ATan { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return sd::math::nd4j_atan(d1); + } +}; + +template +class Atan2 { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Y d2) { + return sd::math::nd4j_atan2(d2, d1); + } + + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } + + // op for MetaOps + op_def static Z op(X d1, Y *params) { return op(d1, params[0]); } +}; + +template +class Identity { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return d1; + } +}; + +template +class Stabilize { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + X k = params[0]; + if (d1 * k > static_cast(-MIN_CUTFOFF)) + return static_cast(-MIN_CUTFOFF) / k; + else if (d1 * k < static_cast(MIN_CUTFOFF)) + return static_cast(MIN_CUTFOFF) / k; + return d1; + } +}; + +template +class Step { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + return (d1 > static_cast(d2) ? static_cast(1) : static_cast(0)); + } +}; + +template +class OneMinus { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + op_def static X + op(X d1, X *params) { + return static_cast(1) - d1; + } +}; + +template +class Sum { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda + + op_def static X + startingValue(const X *input) { + return static_cast(0.0f); + } + + op_def static X merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static X update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static X op(X d1, X *extraParams) { return d1; } + + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; + +template +class ReduceSameBenchmarkOp { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + + op_def static X startingValue(const X *input) { return static_cast(0.0f); } + + op_def static X merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static X update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static X op(X d1, X *extraParams) { + auto f1 = static_cast(d1); + return static_cast( + sd::math::nd4j_pow(f1, 3) + + sd::math::nd4j_log(f1) * + sd::math::nd4j_sin(f1) / + sd::math::nd4j_tanh(static_cast(M_E) * + static_cast(M_PI) * f1) * + sd::math::nd4j_sqrt(static_cast(M_PI) / f1) - + sd::math::nd4j_atan(static_cast(M_E) / f1)); + } + + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; + +template +class ShannonEntropy { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } + + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } + + op_def static Z op(X d1, Z *extraParams) { + auto p = d1 * d1; + return static_cast(p) * sd::math::nd4j_log(p); + } + + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return -reduction; + } +}; + +template +class LogEntropy { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } + + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } + op_def static Z op(X d1, Z *extraParams) { + return static_cast(d1) * sd::math::nd4j_log(d1); + } - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; - - - template - class IsInfOrNan{ - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - op_def static Z op(X d1, X *params) { - return sd::math::nd4j_isfin(d1) ? static_cast(0) : static_cast(1); - } - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput == static_cast(0) && old == static_cast(0) ? static_cast(0) : static_cast(1); - } - - - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput == static_cast(0) && old == static_cast(0) ? static_cast(0) : static_cast(1); - } - - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction != static_cast(0); - } - }; - - - - template - class IsFinite { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - op_def static Z op(X d1, X *params) { - return sd::math::nd4j_isfin(d1) ? static_cast(1) : static_cast(0); - } - - op_def static X startingValue(const X *input) { - return static_cast(1); - } - - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput == static_cast(0) || old == static_cast(0) ? static_cast(0) : static_cast(1); - } - - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput == static_cast(0) || old == static_cast(0) ? static_cast(0) : static_cast(1); - } - - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction != static_cast(0); - } - }; - - - template - class ClipByValue { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - if (d1 > params[1]) - return params[1]; - if (d1 < params[0]) - return params[0]; - return d1; - } - }; - - template - class LstmClip { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - X _v = (X) d2; - if (d1 > _v) - return _v; - else if (d1 < -_v) - return -_v; - else return d1; - } - }; - - template - class Swish { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 * sd::math::nd4j_sigmoid(d1); - } - }; - - template - class Mish { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 * sd::math::nd4j_tanh(sd::math::nd4j_softplus(d1)); - } - }; - - template - class MishDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - auto ex = sd::math::nd4j_exp(d1); - auto e2x = ex * ex; - auto e3x = ex * ex * ex; - - return (ex * (4 * (d1 + 1) + 4 * e2x + e3x + ex *(4 * d1 + 6))) / sd::math::nd4j_pow((2 * ex + e2x + 2), (X) 2.f); - } - }; - - template - class GELU { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 * sd::math::nd4j_sigmoid(static_cast(1.702f) * d1); - } - }; - - template - class PreciseGELU { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - auto sp = sd::math::nd4j_sqrt(static_cast(2) / static_cast(M_PI)); - auto xp = d1 + sd::math::nd4j_pow(static_cast(0.044715) * d1, static_cast(3)); - return (d1 / static_cast(2)) * (static_cast(1) + sd::math::nd4j_tanh(sp * xp)); - } - }; - - template - class GELUDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - auto x17 = static_cast(1.702f) * d1; - auto ep = sd::math::nd4j_pow(static_cast(M_E), x17); - // (E^(1.702 x) (1. + E^(1.702 x) + 1.702 x))/(1. + E^(1.702 x))^2 - return (ep * (static_cast(1.f) + ep + x17)) / sd::math::nd4j_pow((static_cast(1.f) + ep), 2); - } - }; - - template - class PreciseGELUDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - auto x79 = static_cast(0.797885) * d1; - auto x03 = sd::math::nd4j_pow(static_cast(0.0356774) * d1, 3); - auto x39 = static_cast(0.398942) * d1; - auto x05 = sd::math::nd4j_pow(static_cast(0.0535161) * d1, 3); - auto scz = sd::math::nd4j_sech(x79 + x03); - // 0.5 + (0.398942 x + 0.0535161 x^3) Sech[0.797885 x + 0.0356774 x^3]^2 + 0.5 Tanh[0.797885 x + 0.0356774 x^3] - return static_cast(0.5) + (x39 + x05) * (scz * scz) + static_cast(0.5) * sd::math::nd4j_tanh(x79 + x03); - } - }; - - - template - class SwishDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - X ex = sd::math::nd4j_pow(static_cast(M_E), d1); - return (ex * (d1 + ex + static_cast(1.f))) / sd::math::nd4j_pow((ex + static_cast(1.f)) , static_cast(2.f)); - } - }; - - - template - class LogSigmoid { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_log(sd::math::nd4j_sigmoid(d1)); - } - }; - - template - class LogSigmoidDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - X ex = sd::math::nd4j_pow(M_E, d1); - return static_cast(1.f) / (ex + static_cast(1.f)); - } - }; - - template - class Sigmoid { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_sigmoid(d1); - } - }; - - template - class Affine { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return params[0] * d1 + params[1]; - } - }; - - template - class SigmoidDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_sigmoidderivative(d1); - } - }; - - - template - class HardSigmoid { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_min(static_cast(1), sd::math::nd4j_max(static_cast(0), (static_cast(0.2f)) * d1 + static_cast(0.5f))); - } - }; - - template - class HardSigmoidDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 < static_cast(-2.5f) || d1 > static_cast(2.5f) ? static_cast(0.f) : static_cast(0.2f); - } - }; - - - /** - * Scale to be between a min and max - */ - template - class SetRange { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - auto min = params[0]; - auto max = params[1]; - if (static_cast(d1) >= min && static_cast(d1) <= max) - return d1; - if (min == static_cast(0) && max == static_cast(1)) { - auto val = static_cast(1) / (static_cast(1) + sd::math::nd4j_exp(-d1)); - return (sd::math::nd4j_floor(val * (max - min)) + min); - } - - return (sd::math::nd4j_floor(d1 * (max - min)) + min); - } - }; - - - template - class Sin { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_sin(d1); - } - }; - - template - class Square { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 * d1; - } - }; - - template - class Sqrt { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return sd::math::nd4j_sqrt(d1); - } - }; - - template - class RSqrt { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Z *params) { - return static_cast(1) / sd::math::nd4j_sqrt(d1); - } - }; - - template - class Rint { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_rint(d1); - } - }; - - - template - class SoftPlus { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_softplus(d1); - } - }; - - - template - class Sign { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return (d1 > static_cast(0)) - (d1 < static_cast(0)); - } - }; - - - template - class TimesOneMinus { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 * (static_cast(1) - d1); - } - }; - - - template - class RationalTanh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - // keep 2/3 as runtime variable, to match precision - auto dis = (static_cast(2) / static_cast(3)) * d1; - - auto tanh = sd::math::nd4j_sgn(dis) * (static_cast(1) - (static_cast(1) / (static_cast(1) + static_cast(sd::math::nd4j_abs(dis)) + sd::math::nd4j_pow(dis, static_cast(2)) + static_cast(1.41645f) * sd::math::nd4j_pow(dis, static_cast(4)) ))); - return static_cast(1.7159f) * tanh; - } - }; - - template - class RationalTanhDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - auto dis = (static_cast(2.f) / static_cast(3.f)) * d1; - - auto a = static_cast(1.f) + sd::math::nd4j_abs(dis) + sd::math::nd4j_pow(dis, static_cast(2.f)) + static_cast(1.41645f) * sd::math::nd4j_pow(dis, static_cast(4)); - - auto tDeriv = (static_cast(1.f) + sd::math::nd4j_sign(dis) * (static_cast(2.f) * dis + static_cast(4.f) * static_cast(1.41645f) * sd::math::nd4j_pow(dis, static_cast(3)))) / (a * a); - - return static_cast(1.7159f) * (static_cast(2.f) / static_cast(3.f)) * tDeriv; - } - }; - - template - class Tanh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_tanh(d1); - } - }; - - template - class ScaledTanh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return params[0] * sd::math::nd4j_tanh(params[1] * d1); - } - }; - - template - class RectifiedTanh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_max(static_cast(0), sd::math::nd4j_tanh(d1)); - } - }; - - template - class RectifiedTanhDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 > static_cast(0.f) ? sd::math::nd4j_tanhderivative(d1) : static_cast(0.f); - } - }; - - template - class ATanh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_atanh(d1); - } - }; - - template - class TanhDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_tanhderivative(d1); - } - }; - - template - class Cube { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 * d1 * d1; - } - }; - - - template - class CubeDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return static_cast(3) * d1 * d1; - } - }; - - template - class ACos { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_acos(d1); - } - }; - - template - class ASinh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_asinh(d1); - } - }; - - template - class ASinhDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return static_cast(1.f) / (sd::math::nd4j_sqrt(sd::math::nd4j_pow(d1, static_cast(2.f)) + static_cast(1.f))); - } - }; - - template - class ACosh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_acosh(d1); - } - }; - - - template - class ACoshDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return static_cast(1.f) / (sd::math::nd4j_sqrt(d1 - static_cast(1.f)) * sd::math::nd4j_sqrt(d1 + static_cast(1.f))); - } - }; - - - - template - class Ones { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return static_cast(1.0f); - } - }; - - - - template - class SoftSign { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_softsign(d1); - } - }; - - - template - class SoftSignDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_softsignderivative(d1); - } - }; - - template - class MatchConditionBool { - public: - no_op_exec_special_bool - no_op_exec_special_bool_cuda - - // this op return 1.0 if condition met, 0.0 otherwise - op_def static Z op(X d1, X *extraParams) { - X compare = extraParams[0]; - X eps = extraParams[1]; - - auto mode = static_cast(extraParams[2]); - //nd4j_printf("value: %f; comp: %f; eps: %f; mode: %i;\n", d1, compare, eps, mode); - - switch (mode) { - case 0: // equals - return sd::math::nd4j_abs(d1 - compare) <= eps ? true : false; - case 1: // not equals - return sd::math::nd4j_abs(d1 - compare) > eps ? true : false; - case 2: // less_than - return d1 < compare ? true : false; - case 3: // greater_than - return d1 > compare ? true : false; - case 4: // less_or_equals_than - return d1 <= compare ? true : false; - case 5: // greater_or_equals_than - return d1 >= compare ? true : false; - case 6: // abs_less_than - return sd::math::nd4j_abs(d1) < compare ? true : false; - case 7: // abs_greater_than - return sd::math::nd4j_abs(d1) > compare ? true : false; - case 8: // is inf - return sd::math::nd4j_isinf(d1) ? true : false; - case 9: // is nan - return sd::math::nd4j_isnan(d1) ? true : false; - case 10: - return (d1 == compare) ? true : false; - case 11: - return (d1 != compare) ? true : false; - case 12: // abs_greater_or_equals_than - return sd::math::nd4j_abs(d1) >= compare ? true : false; - case 13: // abs_less_or_equals_than - return sd::math::nd4j_abs(d1) <= compare ? true : false; - case 14: - // isFinite - return !(sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1)); - case 15: - // isInfinite - return sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1); - default: - printf("Undefined match condition: [%i]\n", mode); - } - - return d1; - } - }; - - template - class MatchCondition { - public: - no_op_exec_special - no_op_exec_special_cuda - - no_op_exec_special_accumulation_long - no_op_exec_special_accumulation_cuda - - op_def static Z startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(Z old, Z opOutput, X *extraParams) { - return old + opOutput; - } - - op_def static Z update(Z old, Z opOutput, X *extraParams) { - return old + opOutput; - } - - op_def static Z op(X d1, X compare, X eps, int mode) { - switch (mode) { - case 0: // equals - return sd::math::nd4j_abs(d1 - compare) <= eps ? 1 : 0; - case 1: // not equals - return sd::math::nd4j_abs(d1 - compare) > eps ? 1 : 0; - case 2: // less_than - return d1 < compare ? 1 : 0; - case 3: // greater_than - return d1 > compare ? 1 : 0; - case 4: // less_or_equals_than - return d1 <= compare ? 1 : 0; - case 5: // greater_or_equals_than - return d1 >= compare ? 1 : 0; - case 6: // abs_less_than - return sd::math::nd4j_abs(d1) < compare ? 1 : 0; - case 7: // abs_greater_than - return sd::math::nd4j_abs(d1) > compare ? 1 : 0; - case 8: // is inf - return sd::math::nd4j_isinf(d1) ? 1 : 0; - case 9: // is nan - return sd::math::nd4j_isnan(d1) ? 1 : 0; - case 10: - return (d1 == compare) ? 1 : 0; - case 11: - return (d1 != compare) ? 1 : 0; - case 12: // abs_greater_or_equals_than - return sd::math::nd4j_abs(d1) >= compare ? 1 : 0; - case 13: // abs_less_or_equals_than - return sd::math::nd4j_abs(d1) <= compare ? 1 : 0; - case 14: - // isFinite - return !(sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1)) ? 1 : 0; - case 15: - // isInfinite - return sd::math::nd4j_isinf(d1) || sd::math::nd4j_isnan(d1) ? 1 : 0; - default: - printf("Undefined match condition: [%i]\n", mode); - } - - return d1; - } - - // this op return 1.0 if condition met, 0.0 otherwise - op_def static Z op(X d1, X compare, X *extraParams) { - X eps = extraParams[1]; - - auto mode = static_cast(extraParams[0]); - - return op(d1, compare, eps, mode); - } - - // this op return 1.0 if condition met, 0.0 otherwise - op_def static Z op(X d1, X *extraParams) { - X compare = extraParams[0]; - X eps = extraParams[1]; - - auto mode = static_cast(extraParams[2]); - - return op(d1, compare, eps, mode); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; - - template - class ELU { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_elu(d1, static_cast(d2)); - } - }; - - - template - class ELUDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_eluderivative(d1, static_cast(d2)); - } - }; - - - template - class RELU { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - auto xt = static_cast(d1); - auto xf = static_cast(d2); - return xt < xf ? xf : xt; - } - }; - - template - class RELUDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - auto xt = static_cast(d1); - auto xf = static_cast(d2); - return xt > xf ? static_cast(1.f) : static_cast(0.f); - } - }; - - template - class SXELogitsSmoother { - public: - op_def static Z op(X d1, Y d2, Z *params) { - return d1 * ((X)1.f - (X) d2) + (X)(0.5f) * (X) d2; - } - }; - - template - class RELU6 { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - auto relu = simdOps::RELU::op(d1, d2, params); - return relu < static_cast(6) ? relu : static_cast(6); - } - }; - - template - class LeakyRELU { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - auto val = static_cast(d1); - auto alpha = static_cast(d2); - return val < 0.0f ? alpha * val : val; - } - }; - - template - class SELU { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 > static_cast(0.0f) ? static_cast(SELU_LAMBDA) * static_cast(d1) : static_cast(SELU_LAMBDA) * (static_cast(SELU_ALPHA) * sd::math::nd4j_exp(d1) - static_cast(SELU_ALPHA)); - } - }; - - template - class SELUDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1 > static_cast(0.f) ? static_cast(SELU_LAMBDA) : static_cast(SELU_ALPHA) * static_cast(SELU_LAMBDA) * sd::math::nd4j_exp(d1); - } - }; - - template - class LeakyRELUDerivative { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - if (d1 >= static_cast(0)) - return static_cast(1); - else - return static_cast(d2); - } - }; - - - template - class ASin { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_asin(d1); - } - }; - - template - class Sinh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_sinh(d1); - } - }; - - template - class SinhDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_cosh(d1); - } - }; - - template - class Cosh { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_cosh(d1); - } - }; - - - template - class Tan { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_tan(d1); - } - }; - - template - class TanDerivative { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return static_cast(1.f) / sd::math::nd4j_pow(sd::math::nd4j_cos(d1), static_cast(2.0f)); - } - }; - - template - class ATan { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return sd::math::nd4j_atan(d1); - } - }; - - template - class Atan2 { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_atan2(d2, d1); - } - - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - // op for MetaOps - op_def static Z op(X d1, Y *params) { - return op(d1, params[0]); - } - }; - - - template - class Identity { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return d1; - } - }; - - - template - class Stabilize { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - X k = params[0]; - if (d1 * k > static_cast(- MIN_CUTFOFF)) - return static_cast(- MIN_CUTFOFF) / k; - else if (d1 * k < static_cast(MIN_CUTFOFF)) - return static_cast(MIN_CUTFOFF) / k; - return d1; - } - }; - - - - template - class Step { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - return (d1 > static_cast(d2) ? static_cast(1) : static_cast(0)); - } - }; - - - - template - class OneMinus { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - op_def static X op(X d1, X *params) { - return static_cast(1) - d1; - } - }; - - template - class Sum { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda - - op_def static X startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static X merge(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static X update(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static X op(X d1, X *extraParams) { - return d1; - } - - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; - - template - class ReduceSameBenchmarkOp { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static X merge(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static X update(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static X op(X d1, X *extraParams) { - auto f1 = static_cast(d1); - return static_cast(sd::math::nd4j_pow(f1, 3) - + sd::math::nd4j_log(f1) * sd::math::nd4j_sin(f1) - / sd::math::nd4j_tanh(static_cast(M_E) * static_cast(M_PI) * f1) - * sd::math::nd4j_sqrt(static_cast(M_PI) / f1) - - sd::math::nd4j_atan(static_cast(M_E) / f1)); - } - - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; - - - template - class ShannonEntropy { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + // entropy is -sum(p(x) * log(p(x))); log entropy is log of this + return sd::math::nd4j_log(-reduction); + } +}; - op_def static X startingValue(const X *input) { - return static_cast(0); - } +template +class Entropy { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static Z op(X d1, Z *extraParams) { - auto p = d1 * d1; - return static_cast(p) * sd::math::nd4j_log(p); - } + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return -reduction; - } - }; - - - template - class LogEntropy { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + op_def static Z op(X d1, Z *extraParams) { + return static_cast(d1) * sd::math::nd4j_log(d1); + } - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0); - } + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return static_cast(-reduction); // entropy is -sum(p(x) * log(p(x))) + } +}; - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } +template +class ASum { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + const static functions::ReduceType reduceType = + functions::ReduceType::ASUM; - op_def static Z op(X d1, Z *extraParams) { - return static_cast(d1) * sd::math::nd4j_log(d1); - } + op_def static X startingValue(const X *input) { return static_cast(0); } - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - //entropy is -sum(p(x) * log(p(x))); log entropy is log of this - return sd::math::nd4j_log(-reduction); - } - }; + op_def static X merge(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_abs(opOutput) + sd::math::nd4j_abs(old); + } - template - class Entropy { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + op_def static X update(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_abs(opOutput) + sd::math::nd4j_abs(old); + } - const static functions::ReduceType reduceType = functions::ReduceType::SUM; + op_def static X op(X d1, X *extraParams) { return sd::math::nd4j_abs(d1); } - op_def static X startingValue(const X *input) { - return static_cast(0); - } + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return sd::math::nd4j_abs(reduction); + } +}; - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } +template +class CountNonZero { + public: + no_op_exec_special_accumulation_long no_op_exec_special_accumulation_cuda - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + const static functions::ReduceType reduceType = + functions::ReduceType::ASUM; - op_def static Z op(X d1, Z *extraParams) { - return static_cast(d1) * sd::math::nd4j_log(d1); - } + op_def static Z startingValue(const X *input) { return static_cast(0); } - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return static_cast(-reduction); //entropy is -sum(p(x) * log(p(x))) - } - }; + op_def static Z merge(Z old, Z opOutput, X *extraParams) { + return opOutput + old; + } + op_def static Z update(Z old, Z opOutput, X *extraParams) { + return opOutput + old; + } - template - class ASum { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda + op_def static Z op(X d1, X *extraParams) { + return d1 == static_cast(0.0f) ? static_cast(0.0f) + : static_cast(1.0f); + } - const static functions::ReduceType reduceType = functions::ReduceType::ASUM; + op_def static Z postProcess(Z reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; - op_def static X startingValue(const X *input) { - return static_cast(0); - } +template +class CountZero { + public: + no_op_exec_special_accumulation_long no_op_exec_special_accumulation_cuda - op_def static X merge(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_abs(opOutput) + sd::math::nd4j_abs(old); - } + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - op_def static X update(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_abs(opOutput) + sd::math::nd4j_abs(old); - } + op_def static Z startingValue(const X *input) { return static_cast(0.0f); } - op_def static X op(X d1, X *extraParams) { - return sd::math::nd4j_abs(d1); - } + op_def static Z merge(Z old, Z opOutput, X *extraParams) { + return opOutput + old; + } - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return sd::math::nd4j_abs(reduction); - } - }; + op_def static Z update(Z old, Z opOutput, X *extraParams) { + return opOutput + old; + } + op_def static Z op(X d1, X *extraParams) { + return d1 == static_cast(0) ? static_cast(1) : static_cast(0); + } - template - class CountNonZero { - public: - no_op_exec_special_accumulation_long - no_op_exec_special_accumulation_cuda + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return static_cast(reduction); + } +}; - const static functions::ReduceType reduceType = functions::ReduceType::ASUM; +template +class Prod { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda - op_def static Z startingValue(const X *input) { - return static_cast(0); - } + const static functions::ReduceType reduceType = + functions::ReduceType::PRODUCT; - op_def static Z merge(Z old, Z opOutput, X *extraParams) { - return opOutput + old; - } + op_def static X startingValue(const X *input) { return static_cast(1); } - op_def static Z update(Z old, Z opOutput, X *extraParams) { - return opOutput + old; - } + op_def static X merge(X old, X opOutput, X *extraParams) { + return opOutput * old; + } - op_def static Z op(X d1, X *extraParams) { - return d1 == static_cast(0.0f) ? static_cast(0.0f) : static_cast(1.0f); - } + op_def static X update(X old, X opOutput, X *extraParams) { + return opOutput * old; + } - op_def static Z postProcess(Z reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; + op_def static X op(X d1, X *extraParams) { return d1; } + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; - template - class CountZero { - public: - no_op_exec_special_accumulation_long - no_op_exec_special_accumulation_cuda +template +class Any { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda - const static functions::ReduceType reduceType = functions::ReduceType::SUM; + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - op_def static Z startingValue(const X *input) { - return static_cast(0.0f); - } + op_def static X startingValue(const X *input) { return static_cast(0.0f); } - op_def static Z merge(Z old, Z opOutput, X *extraParams) { - return opOutput + old; - } + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } - op_def static Z update(Z old, Z opOutput, X *extraParams) { - return opOutput + old; - } + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } - op_def static Z op(X d1, X *extraParams) { - return d1 == static_cast(0) ? static_cast(1) : static_cast(0); - } + op_def static Z op(X d1, X *extraParams) { return d1; } - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return static_cast(reduction); - } - }; + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction > static_cast(0) ? static_cast(1) + : static_cast(0); + } +}; - template - class Prod { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda +template +class All { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda - const static functions::ReduceType reduceType = functions::ReduceType::PRODUCT; + const static functions::ReduceType reduceType = + functions::ReduceType::PRODUCT; - op_def static X startingValue(const X *input) { - return static_cast(1); - } + op_def static X startingValue(const X *input) { return static_cast(1); } - op_def static X merge(X old, X opOutput, X *extraParams) { - return opOutput * old; - } + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput * old; + } - op_def static X update(X old, X opOutput, X *extraParams) { - return opOutput * old; - } + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput * old; + } - op_def static X op(X d1, X *extraParams) { - return d1; - } + op_def static Z op(X d1, X *extraParams) { return d1; } - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction > static_cast(0) ? static_cast(1) + : static_cast(0); + } +}; +template +class Mean { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda - template - class Any { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - const static functions::ReduceType reduceType = functions::ReduceType::SUM; + op_def static X startingValue(const X *input) { return static_cast(0); } - op_def static X startingValue(const X *input) { - return static_cast(0.0f); - } + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput + old; - } + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput + old; - } - - op_def static Z op(X d1, X *extraParams) { - return d1; - } + op_def static Z op(X d1, Z *extraParams) { return d1; } - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction > static_cast(0) ? static_cast(1) : static_cast(0) ; - } - }; + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return reduction / (Z)n; + } +}; +template +class ReduceFloatBenchmarkOp { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda - template - class All { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - const static functions::ReduceType reduceType = functions::ReduceType::PRODUCT; + op_def static X startingValue(const X *input) { return static_cast(0); } - op_def static X startingValue(const X *input) { - return static_cast(1); - } - - op_def static Z merge(X old, X opOutput, X *extraParams) { - return opOutput * old; - } + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static Z update(X old, X opOutput, X *extraParams) { - return opOutput * old; - } - - op_def static Z op(X d1, X *extraParams) { - return d1; - } - - op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction > static_cast(0) ? static_cast(1) : static_cast(0); - } - }; - - template - class Mean { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + op_def static Z op(X d1, Z *extraParams) { + auto f1 = static_cast(d1); + return static_cast( + sd::math::nd4j_pow(f1, 3) + + sd::math::nd4j_log(f1) * + sd::math::nd4j_sin(f1) / + sd::math::nd4j_tanh(static_cast(M_E) * + static_cast(M_PI) * f1) * + sd::math::nd4j_sqrt(static_cast(M_PI) / f1) - + sd::math::nd4j_atan(static_cast(M_E) / f1)); + } + + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return (Z)reduction / (Z)n; + } +}; + +template +class AMean { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return sd::math::nd4j_abs(opOutput) + sd::math::nd4j_abs(old); + } + + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } + + op_def static Z op(X d1, Z *extraParams) { return sd::math::nd4j_abs(d1); } + + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return sd::math::nd4j_abs(reduction) / static_cast(n); + } +}; + +template +class Max { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::MAX; + + op_def static X startingValue(const X *input) { + return -sd::DataTypeUtils::infOrMax(); + } + + op_def static X merge(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_max(old, opOutput); + } + + op_def static X update(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_max(opOutput, old); + } + + op_def static X op(X d1, X d2, X *params) { + return sd::math::nd4j_max(d1, d2); + } + + op_def static X op(X d1, X d2) { return sd::math::nd4j_max(d1, d2); } + + // FIXME: this signature overlaps with MetaOp + op_def static X op(X d1, X *extraParams) { return d1; } + + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; + +template +class AMaxPairwise { + public: + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } + + op_def static Z op(X d1, Y d2) { + auto z1 = static_cast(d1); + auto z2 = static_cast(d2); + + if (sd::math::nd4j_abs(z1) > sd::math::nd4j_abs(z2)) + return z1; + else + return z2; + } +}; + +template +class AMinPairwise { + public: + op_def static Z op(X d1, Y d2, Z *params) { return op(d1, d2); } + + op_def static Z op(X d1, Y d2) { + auto z1 = static_cast(d1); + auto z2 = static_cast(d2); + + if (sd::math::nd4j_abs(z1) < sd::math::nd4j_abs(z2)) + return z1; + else + return z2; + } +}; + +template +class MaxPairwise { + public: + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_max(static_cast(d1), static_cast(d2)); + } + + op_def static Z op(X d1, Y d2) { + return sd::math::nd4j_max(static_cast(d1), static_cast(d2)); + } +}; + +template +class MinPairwise { + public: + op_def static Z op(X d1, Y d2, Z *params) { + return sd::math::nd4j_min(static_cast(d1), static_cast(d2)); + } + + op_def static Z op(X d1, Y d2) { + return sd::math::nd4j_min(static_cast(d1), static_cast(d2)); + } +}; + +template +class AMax { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::AMAX; + + op_def static X startingValue(const X *input) { return input[0]; } + + op_def static X merge(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_max(sd::math::nd4j_abs(old), + sd::math::nd4j_abs(opOutput)); + } + + op_def static X update(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_max(sd::math::nd4j_abs(opOutput), + sd::math::nd4j_abs(old)); + } + + op_def static X op(X d1, X d2, X *params) { + return sd::math::nd4j_max(sd::math::nd4j_abs(d1), + sd::math::nd4j_abs(d2)); + } + + op_def static X op(X d1, X d2) { + return sd::math::nd4j_abs(d1) > sd::math::nd4j_abs(d2) ? d1 : d2; + } + + // FIXME: this signature overlaps with MetaOp + op_def static X op(X d1, X *extraParams) { return sd::math::nd4j_abs(d1); } + + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return sd::math::nd4j_abs(reduction); + } +}; + +template +class AMin { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::AMIN; + + op_def static X startingValue(const X *input) { return input[0]; } + + op_def static X merge(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_min(sd::math::nd4j_abs(old), + sd::math::nd4j_abs(opOutput)); + } + + op_def static X update(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_min(sd::math::nd4j_abs(opOutput), + sd::math::nd4j_abs(old)); + } + + op_def static X op(X d1, X d2, X *params) { + return sd::math::nd4j_min(sd::math::nd4j_abs(d1), + sd::math::nd4j_abs(d2)); + } + + op_def static X op(X d1, X d2) { + return sd::math::nd4j_min(sd::math::nd4j_abs(d1), + sd::math::nd4j_abs(d2)); + } + + // FIXME: this signature overlaps with MetaOp + op_def static X op(X d1, X *extraParams) { return sd::math::nd4j_abs(d1); } + + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return sd::math::nd4j_abs(reduction); + } +}; - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z op(X d1, Z *extraParams) { - return d1; - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return reduction / (Z) n; - } - }; - - template - class ReduceFloatBenchmarkOp { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda +template +class Min { + public: + no_op_exec_special_accumulation_same no_op_exec_special_accumulation_same_cuda - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z op(X d1, Z *extraParams) { - auto f1 = static_cast(d1); - return static_cast(sd::math::nd4j_pow(f1, 3) - + sd::math::nd4j_log(f1) * sd::math::nd4j_sin(f1) - / sd::math::nd4j_tanh(static_cast(M_E) * static_cast(M_PI) * f1) - * sd::math::nd4j_sqrt(static_cast(M_PI) / f1) - - sd::math::nd4j_atan(static_cast(M_E) / f1)); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return (Z) reduction / (Z) n; - } - }; - - - template - class AMean { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return sd::math::nd4j_abs(opOutput) + sd::math::nd4j_abs(old); - } - - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z op(X d1, Z *extraParams) { - return sd::math::nd4j_abs(d1); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return sd::math::nd4j_abs(reduction) / static_cast(n); - } - }; - - template - class Max { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::MAX; - - op_def static X startingValue(const X *input) { - return -sd::DataTypeUtils::infOrMax(); - } - - op_def static X merge(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_max(old, opOutput); - } - - op_def static X update(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_max(opOutput, old); - } - - op_def static X op(X d1, X d2, X *params) { - return sd::math::nd4j_max(d1, d2); - } - - op_def static X op(X d1, X d2) { - return sd::math::nd4j_max(d1, d2); - } - - // FIXME: this signature overlaps with MetaOp - op_def static X op(X d1, X *extraParams) { - return d1; - } - - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; - - - template - class AMaxPairwise { - public: - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - op_def static Z op(X d1, Y d2) { - auto z1 = static_cast(d1); - auto z2 = static_cast(d2); - - if (sd::math::nd4j_abs(z1) > sd::math::nd4j_abs(z2)) - return z1; - else - return z2; - } - }; - - - template - class AMinPairwise { - public: - op_def static Z op(X d1, Y d2, Z *params) { - return op(d1, d2); - } - - op_def static Z op(X d1, Y d2) { - auto z1 = static_cast(d1); - auto z2 = static_cast(d2); - - if (sd::math::nd4j_abs(z1) < sd::math::nd4j_abs(z2)) - return z1; - else - return z2; - } - }; - - template - class MaxPairwise { - public: - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_max(static_cast(d1), static_cast(d2)); - } - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_max(static_cast(d1), static_cast(d2)); - } - }; - - - template - class MinPairwise { - public: - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_min(static_cast(d1), static_cast(d2)); - } - - op_def static Z op(X d1, Y d2) { - return sd::math::nd4j_min(static_cast(d1), static_cast(d2)); - } - }; - - template - class AMax { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::AMAX; - - op_def static X startingValue(const X *input) { - return input[0]; - } - - op_def static X merge(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_max(sd::math::nd4j_abs(old), sd::math::nd4j_abs(opOutput)); - } - - op_def static X update(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_max(sd::math::nd4j_abs(opOutput), sd::math::nd4j_abs(old)); - } - - op_def static X op(X d1, X d2, X *params) { - return sd::math::nd4j_max(sd::math::nd4j_abs(d1), sd::math::nd4j_abs(d2)); - } - - op_def static X op(X d1, X d2) { - return sd::math::nd4j_abs(d1) > sd::math::nd4j_abs(d2) ? d1 : d2; - } - - // FIXME: this signature overlaps with MetaOp - op_def static X op(X d1, X *extraParams) { - return sd::math::nd4j_abs(d1); - } - - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return sd::math::nd4j_abs(reduction); - } - }; - - - template - class AMin { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::AMIN; - - op_def static X startingValue(const X *input) { - return input[0]; - } - - op_def static X merge(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_min(sd::math::nd4j_abs(old), sd::math::nd4j_abs(opOutput)); - } + const static functions::ReduceType reduceType = + functions::ReduceType::MIN; - op_def static X update(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_min(sd::math::nd4j_abs(opOutput), sd::math::nd4j_abs(old)); - } - - op_def static X op(X d1, X d2, X *params) { - return sd::math::nd4j_min(sd::math::nd4j_abs(d1), sd::math::nd4j_abs(d2)); - } + op_def static X startingValue(const X *input) { + return sd::DataTypeUtils::infOrMax(); + } - op_def static X op(X d1, X d2) { - return sd::math::nd4j_min(sd::math::nd4j_abs(d1), sd::math::nd4j_abs(d2)); - } + op_def static X merge(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_min(old, opOutput); + } - // FIXME: this signature overlaps with MetaOp - op_def static X op(X d1, X *extraParams) { - return sd::math::nd4j_abs(d1); - } + op_def static X update(X old, X opOutput, X *extraParams) { + return sd::math::nd4j_min(opOutput, old); + } - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return sd::math::nd4j_abs(reduction); - } - }; + op_def static X op(X d1, X d2, X *params) { + return sd::math::nd4j_min(d1, d2); + } - template - class Min { - public: - no_op_exec_special_accumulation_same - no_op_exec_special_accumulation_same_cuda + op_def static X op(X d1, X d2) { return sd::math::nd4j_min(d1, d2); } - const static functions::ReduceType reduceType = functions::ReduceType::MIN; + // FIXME: this signature overlaps with MetaOp + op_def static X op(X d1, X *extraParams) { return d1; } - op_def static X startingValue(const X *input) { - return sd::DataTypeUtils::infOrMax(); - } + op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } +}; - op_def static X merge(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_min(old, opOutput); - } +template +class Norm1 { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda - op_def static X update(X old, X opOutput, X *extraParams) { - return sd::math::nd4j_min(opOutput, old); - } + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - op_def static X op(X d1, X d2, X *params) { - return sd::math::nd4j_min(d1, d2); - } + op_def static X startingValue(const X *input) { return static_cast(0); } - op_def static X op(X d1, X d2) { - return sd::math::nd4j_min(d1, d2); - } + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - // FIXME: this signature overlaps with MetaOp - op_def static X op(X d1, X *extraParams) { - return d1; - } + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static X postProcess(X reduction, Nd4jLong n, X *extraParams) { - return reduction; - } - }; + op_def static Z op(X d1, Z *extraParams) { + return static_cast(sd::math::nd4j_abs(d1)); + } + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return reduction; + } +}; - template - class Norm1 { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda +template +class Norm2 { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda - const static functions::ReduceType reduceType = functions::ReduceType::SUM; + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; - op_def static X startingValue(const X *input) { - return static_cast(0); - } + op_def static X startingValue(const X *input) { return static_cast(0); } - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - } + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return sd::math::nd4j_sqrt(reduction); + } - } + op_def static Z op(X d1, Z *extraParams) { return static_cast(d1 * d1); } +}; - op_def static Z op(X d1, Z *extraParams) { - return static_cast(sd::math::nd4j_abs(d1)); - } +template +class SquaredNorm { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return reduction; - } - }; + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + op_def static X startingValue(const X *input) { return static_cast(0); } - template - class Norm2 { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - const static functions::ReduceType reduceType = functions::ReduceType::SUM; + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static X startingValue(const X *input) { - return static_cast(0); - } + op_def static Z op(X d1, Z *extraParams) { return static_cast(d1 * d1); } - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return reduction; + } +}; +template +class NormFrobenius { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + op_def static X startingValue(const X *input) { return static_cast(0); } - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return sd::math::nd4j_sqrt(reduction); - } + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - op_def static Z op(X d1, Z *extraParams) { - return static_cast(d1 * d1); - } - }; + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } - template - class SquaredNorm { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda + op_def static Z op(X d1, Z *extraParams) { + X v = sd::math::nd4j_abs(d1); + return static_cast(v * v); + } - const static functions::ReduceType reduceType = functions::ReduceType::SUM; + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return sd::math::nd4j_sqrt(reduction); + } +}; - op_def static X startingValue(const X *input) { - return static_cast(0); - } +template +class NormP { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + op_def static X startingValue(const X *input) { return static_cast(0); } - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z op(X d1, Z *extraParams) { - return static_cast(d1 * d1); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return reduction; - } - }; - - template - class NormFrobenius { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z op(X d1, Z *extraParams) { - X v = sd::math::nd4j_abs(d1); - return static_cast(v * v); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return sd::math::nd4j_sqrt(reduction); - } - }; - - template - class NormP { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - } - - op_def static Z op(X d1, Z *extraParams) { - return sd::math::nd4j_pow(sd::math::nd4j_abs(d1), extraParams[0]); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return sd::math::nd4j_pow(reduction, static_cast(1.0f) / extraParams[0]); - } - }; - - template - class NormMax { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0); - } - - op_def static Z merge(Z old, Z opOutput, Z *extraParams) { - return opOutput + old; - - } - - op_def static Z update(Z old, Z opOutput, Z *extraParams) { - return sd::math::nd4j_max(sd::math::nd4j_abs(old), - sd::math::nd4j_abs(opOutput)); - } - - op_def static Z op(X d1, Z *extraParams) { - return static_cast(d1); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { - return sd::math::nd4j_max(sd::math::nd4j_abs(reduction), sd::math::nd4j_abs(reduction)); - } - }; - - template - class Variance { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static Z merge(X old, X opOutput, Z *extraParams) { - return old + opOutput; - } - - op_def static Z update(X old, X opOutput, Z *extraParams) { - return old + opOutput; - - } - - op_def static X op(X d1, Z *extraParams) { - X mean = static_cast(extraParams[0]); - X ret = d1 - mean; - return ret * ret; - } - - op_def static Z postProcess(X reduction, Nd4jLong n, Z *extraParams) { - // T bias = extraParams[1]; - // return (reduction - (sd::math::nd4j_pow(bias, static_cast(2.0f)) / static_cast(n))) / (n - 1) - return static_cast(reduction) / static_cast(n - 1); - } - }; - - /** - * Standard deviation of a buffer - */ - template - class StandardDeviation { - public: - no_op_exec_special_accumulation - no_op_exec_special_accumulation_cuda - - const static functions::ReduceType reduceType = functions::ReduceType::SUM; - - op_def static X startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static Z merge(X old, X opOutput, Z *extraParams) { - return old + opOutput; - } - - op_def static Z update(X old, X opOutput, Z *extraParams) { - return old + opOutput; - - } - - op_def static Z op(X d1, Z *extraParams) { - X mean = extraParams[0]; - X ret = d1 - mean; - return ret * ret; - } - - op_def static Z postProcess(X reduction, Nd4jLong n, Z *extraParams) { - Z ret = Variance::postProcess(reduction, n, extraParams); - Z sqrtRet = sd::math::nd4j_sqrt(ret); - return sqrtRet; - } - }; - - template - class CosineSimilarity { - public: - static const int extraParamsLen = 2; - - op_def static X *generateExtraParams() { - //T *extraParams = new T[2]; - return nullptr; - } - - op_def static void finalizeExtraParams(X *extraParams) { - //delete[] extraParams; - } - - op_def static Y startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { - return reduction / (sd::math::nd4j_sqrt(extraParams[0]) * sd::math::nd4j_sqrt(extraParams[1])); - } - - op_def static Y op(X d1, X d2, Y *extraParams) { - extraParams[0] += static_cast(d1 * d1); - extraParams[1] += static_cast(d2 * d2); - return static_cast(d1 * d2); - } - - op_def static void aggregateExtraParams(Y *extraParamsTotal, Y *extraParamsLocal) { - extraParamsTotal[0] += extraParamsLocal[0]; - extraParamsTotal[1] += extraParamsLocal[1]; - } + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } + + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } + + op_def static Z op(X d1, Z *extraParams) { + return sd::math::nd4j_pow(sd::math::nd4j_abs(d1), + extraParams[0]); + } + + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return sd::math::nd4j_pow(reduction, + static_cast(1.0f) / extraParams[0]); + } +}; + +template +class NormMax { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + + op_def static X startingValue(const X *input) { return static_cast(0); } + + op_def static Z merge(Z old, Z opOutput, Z *extraParams) { + return opOutput + old; + } + + op_def static Z update(Z old, Z opOutput, Z *extraParams) { + return sd::math::nd4j_max(sd::math::nd4j_abs(old), + sd::math::nd4j_abs(opOutput)); + } + + op_def static Z op(X d1, Z *extraParams) { return static_cast(d1); } + + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParams) { + return sd::math::nd4j_max(sd::math::nd4j_abs(reduction), + sd::math::nd4j_abs(reduction)); + } +}; + +template +class Variance { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + + op_def static X startingValue(const X *input) { return static_cast(0.0f); } + + op_def static Z merge(X old, X opOutput, Z *extraParams) { + return old + opOutput; + } + + op_def static Z update(X old, X opOutput, Z *extraParams) { + return old + opOutput; + } + + op_def static X op(X d1, Z *extraParams) { + X mean = static_cast(extraParams[0]); + X ret = d1 - mean; + return ret * ret; + } + + op_def static Z postProcess(X reduction, Nd4jLong n, Z *extraParams) { + // T bias = extraParams[1]; + // return (reduction - (sd::math::nd4j_pow(bias, static_cast(2.0f)) / + // static_cast(n))) / (n - 1) + return static_cast(reduction) / static_cast(n - 1); + } +}; + +/** + * Standard deviation of a buffer + */ +template +class StandardDeviation { + public: + no_op_exec_special_accumulation no_op_exec_special_accumulation_cuda + + const static functions::ReduceType reduceType = + functions::ReduceType::SUM; + + op_def static X startingValue(const X *input) { return static_cast(0.0f); } + + op_def static Z merge(X old, X opOutput, Z *extraParams) { + return old + opOutput; + } + + op_def static Z update(X old, X opOutput, Z *extraParams) { + return old + opOutput; + } + + op_def static Z op(X d1, Z *extraParams) { + X mean = extraParams[0]; + X ret = d1 - mean; + return ret * ret; + } + + op_def static Z postProcess(X reduction, Nd4jLong n, Z *extraParams) { + Z ret = Variance::postProcess(reduction, n, extraParams); + Z sqrtRet = sd::math::nd4j_sqrt(ret); + return sqrtRet; + } +}; + +template +class CosineSimilarity { + public: + static const int extraParamsLen = 2; + + op_def static X *generateExtraParams() { + // T *extraParams = new T[2]; + return nullptr; + } + + op_def static void finalizeExtraParams(X *extraParams) { + // delete[] extraParams; + } + + op_def static Y startingValue(const X *input) { return static_cast(0.0f); } + + op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { + return reduction / (sd::math::nd4j_sqrt(extraParams[0]) * + sd::math::nd4j_sqrt(extraParams[1])); + } + + op_def static Y op(X d1, X d2, Y *extraParams) { + extraParams[0] += static_cast(d1 * d1); + extraParams[1] += static_cast(d2 * d2); + return static_cast(d1 * d2); + } + + op_def static void aggregateExtraParams(Y *extraParamsTotal, + Y *extraParamsLocal) { + extraParamsTotal[0] += extraParamsLocal[0]; + extraParamsTotal[1] += extraParamsLocal[1]; + } #ifdef __CUDACC__ - static _CUDA_D inline Y opAtomic(X d1, X d2, Y *extraParams) { - sd::math::atomics::nd4j_atomicAdd(&extraParams[0],static_cast(d1 * d1)); - sd::math::atomics::nd4j_atomicAdd(&extraParams[1],static_cast(d2 * d2)); + static _CUDA_D inline Y opAtomic(X d1, X d2, Y *extraParams) { + sd::math::atomics::nd4j_atomicAdd(&extraParams[0], static_cast(d1 * d1)); + sd::math::atomics::nd4j_atomicAdd(&extraParams[1], static_cast(d2 * d2)); - return static_cast(d1 * d2); - } + return static_cast(d1 * d2); + } #endif - op_def static Y update(Y old, Y opOutput, Y *extraParams) { - return old + opOutput; - } - - - op_def static Y merge(Y old, Y opOutput, Y *extraParams) { - return update(old, opOutput, extraParams); - } - }; + op_def static Y update(Y old, Y opOutput, Y *extraParams) { + return old + opOutput; + } + op_def static Y merge(Y old, Y opOutput, Y *extraParams) { + return update(old, opOutput, extraParams); + } +}; - template - class JaccardDistance { - public: - static const int extraParamsLen = 2; +template +class JaccardDistance { + public: + static const int extraParamsLen = 2; - op_def static X *generateExtraParams() { - //T *extraParams = new T[2]; - return nullptr; - } + op_def static X *generateExtraParams() { + // T *extraParams = new T[2]; + return nullptr; + } - op_def static void finalizeExtraParams(X *extraParams) { - //delete[] extraParams; - } + op_def static void finalizeExtraParams(X *extraParams) { + // delete[] extraParams; + } - op_def static Y startingValue(const X *input) { - return static_cast(0.0f); - } + op_def static Y startingValue(const X *input) { return static_cast(0.0f); } - op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { - // num / denom - return (static_cast(1.0f)) - (extraParams[0] / extraParams[1]); - } + op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { + // num / denom + return (static_cast(1.0f)) - (extraParams[0] / extraParams[1]); + } - op_def static Y num(X d1, X d2) { - return sd::math::nd4j_min(d1, d2); - } + op_def static Y num(X d1, X d2) { return sd::math::nd4j_min(d1, d2); } - op_def static Y denom(X d1, X d2) { - return sd::math::nd4j_max(d1, d2); - } + op_def static Y denom(X d1, X d2) { return sd::math::nd4j_max(d1, d2); } - op_def static Y op(X d1, X d2, Y *extraParams) { - extraParams[0] += static_cast(num(d1, d2)); - extraParams[1] += static_cast(denom(d1, d2)); - return static_cast(0.0f); - } + op_def static Y op(X d1, X d2, Y *extraParams) { + extraParams[0] += static_cast(num(d1, d2)); + extraParams[1] += static_cast(denom(d1, d2)); + return static_cast(0.0f); + } - op_def static void aggregateExtraParams(Y *extraParamsTotal, Y *extraParamsLocal) { - extraParamsTotal[0] += extraParamsLocal[0]; - extraParamsTotal[1] += extraParamsLocal[1]; - } + op_def static void aggregateExtraParams(Y *extraParamsTotal, + Y *extraParamsLocal) { + extraParamsTotal[0] += extraParamsLocal[0]; + extraParamsTotal[1] += extraParamsLocal[1]; + } #ifdef __CUDACC__ - __device__ - static inline Y opAtomic(X d1, X d2, Y *extraParams) { - sd::math::atomics::nd4j_atomicAdd(&extraParams[0],num(d1, d2)); - sd::math::atomics::nd4j_atomicAdd(&extraParams[1], denom(d1, d2)); + __device__ static inline Y opAtomic(X d1, X d2, Y *extraParams) { + sd::math::atomics::nd4j_atomicAdd(&extraParams[0], num(d1, d2)); + sd::math::atomics::nd4j_atomicAdd(&extraParams[1], denom(d1, d2)); - return static_cast(0.0f); - } + return static_cast(0.0f); + } #endif - op_def static Y update(Y old, Y opOutput, Y *extraParams) { - return old + opOutput; - } + op_def static Y update(Y old, Y opOutput, Y *extraParams) { + return old + opOutput; + } + op_def static Y merge(Y old, Y opOutput, Y *extraParams) { + return update(old, opOutput, extraParams); + } +}; - op_def static Y merge(Y old, Y opOutput, Y *extraParams) { - return update(old, opOutput, extraParams); - } - }; +template +class SimpleHammingDistance { + public: + static const int extraParamsLen = 0; + op_def static X *generateExtraParams() { + // T *extraParams = new T[2]; + return nullptr; + } - template - class SimpleHammingDistance { - public: - static const int extraParamsLen = 0; + op_def static void finalizeExtraParams(X *extraParams) { + // delete[] extraParams; + } - op_def static X *generateExtraParams() { - //T *extraParams = new T[2]; - return nullptr; - } + op_def static Y startingValue(const X *input) { return static_cast(0.0f); } - op_def static void finalizeExtraParams(X *extraParams) { - //delete[] extraParams; - } + op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { + return static_cast(reduction / n); + } - op_def static Y startingValue(const X *input) { - return static_cast(0.0f); - } + op_def static Y op(X d1, X d2, Y *extraParams) { + return (d1 == d2) ? static_cast(0.0f) : static_cast(1.0f); + } - op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { - return static_cast(reduction / n); - } - - op_def static Y op(X d1, X d2, Y *extraParams) { - return (d1 == d2) ? static_cast(0.0f) : static_cast(1.0f); - } - - op_def static void aggregateExtraParams(Y *extraParamsTotal, Y *extraParamsLocal) { - - } + op_def static void aggregateExtraParams(Y *extraParamsTotal, + Y *extraParamsLocal) {} #ifdef __CUDACC__ - __device__ - static inline Y opAtomic(X d1, X d2, Y *extraParams) { - return op(d1, d2, extraParams); - } + __device__ static inline Y opAtomic(X d1, X d2, Y *extraParams) { + return op(d1, d2, extraParams); + } #endif - op_def static Y update(Y old, Y opOutput, Y *extraParams) { - return old + opOutput; - } - - - op_def static Y merge(Y old, Y opOutput, Y *extraParams) { - return update(old, opOutput, extraParams); - } - }; - - template - class CosineDistance { - public: - static const int extraParamsLen = 2; - - op_def static X *generateExtraParams() { - //T *extraParams = new T[2]; - return nullptr; - } - - op_def static void finalizeExtraParams(X *extraParams) { - //delete[] extraParams; - } - - op_def static Y startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { - return (static_cast(1.0f)) - (reduction / (sd::math::nd4j_sqrt(extraParams[0]) * sd::math::nd4j_sqrt(extraParams[1]))); - } - - op_def static Y op(X d1, X d2, Y *extraParams) { - extraParams[0] += static_cast(sd::math::nd4j_abs(d1) * sd::math::nd4j_abs(d1)); - extraParams[1] += static_cast(sd::math::nd4j_abs(d2) * sd::math::nd4j_abs(d2)); - return (d1 * d2); - } - - op_def static void aggregateExtraParams(Y *extraParamsTotal, Y *extraParamsLocal) { - extraParamsTotal[0] += extraParamsLocal[0]; - extraParamsTotal[1] += extraParamsLocal[1]; - } + op_def static Y update(Y old, Y opOutput, Y *extraParams) { + return old + opOutput; + } + + op_def static Y merge(Y old, Y opOutput, Y *extraParams) { + return update(old, opOutput, extraParams); + } +}; + +template +class CosineDistance { + public: + static const int extraParamsLen = 2; + + op_def static X *generateExtraParams() { + // T *extraParams = new T[2]; + return nullptr; + } + + op_def static void finalizeExtraParams(X *extraParams) { + // delete[] extraParams; + } + + op_def static Y startingValue(const X *input) { return static_cast(0.0f); } + + op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParams) { + return (static_cast(1.0f)) - + (reduction / (sd::math::nd4j_sqrt(extraParams[0]) * + sd::math::nd4j_sqrt(extraParams[1]))); + } + + op_def static Y op(X d1, X d2, Y *extraParams) { + extraParams[0] += + static_cast(sd::math::nd4j_abs(d1) * sd::math::nd4j_abs(d1)); + extraParams[1] += + static_cast(sd::math::nd4j_abs(d2) * sd::math::nd4j_abs(d2)); + return (d1 * d2); + } + + op_def static void aggregateExtraParams(Y *extraParamsTotal, + Y *extraParamsLocal) { + extraParamsTotal[0] += extraParamsLocal[0]; + extraParamsTotal[1] += extraParamsLocal[1]; + } #ifdef __CUDACC__ - static _CUDA_D inline Y opAtomic(X d1, X d2, Y *extraParams) { - sd::math::atomics::nd4j_atomicAdd(&extraParams[0], sd::math::nd4j_abs(d1) * sd::math::nd4j_abs(d1)); - sd::math::atomics::nd4j_atomicAdd(&extraParams[1], sd::math::nd4j_abs(d2) * sd::math::nd4j_abs(d2)); - - return (d1 * d2); - } + static _CUDA_D inline Y opAtomic(X d1, X d2, Y *extraParams) { + sd::math::atomics::nd4j_atomicAdd( + &extraParams[0], sd::math::nd4j_abs(d1) * sd::math::nd4j_abs(d1)); + sd::math::atomics::nd4j_atomicAdd( + &extraParams[1], sd::math::nd4j_abs(d2) * sd::math::nd4j_abs(d2)); + + return (d1 * d2); + } #endif - op_def static Y update(Y old, Y opOutput, Y *extraParams) { - return old + opOutput; - } + op_def static Y update(Y old, Y opOutput, Y *extraParams) { + return old + opOutput; + } + op_def static Y merge(Y old, Y opOutput, Y *extraParams) { + return update(old, opOutput, extraParams); + } +}; - op_def static Y merge(Y old, Y opOutput, Y *extraParams) { - return update(old, opOutput, extraParams); - } - }; +/** + * Dot product between 2 arrays + */ +template +class Dot { + public: + static const int extraParamsLen = 0; + op_def static X *generateExtraParams() { return nullptr; } - /** - * Dot product between 2 arrays - */ - template - class Dot { - public: - static const int extraParamsLen = 0; + op_def static void finalizeExtraParams(X *extraParamsRef) { + // no-op + // delete[] * extraParamsRef; + } - op_def static X * generateExtraParams() { - return nullptr; - } + op_def static Y startingValue(const X *input) { return static_cast(0.0f); } - op_def static void finalizeExtraParams(X *extraParamsRef) { - //no-op - //delete[] * extraParamsRef; - } - - op_def static Y startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParamsRef) { - return reduction; - } - - op_def static Y op(X d1, X d2, Y *extraParamsRef) { - return static_cast(d1 * d2); - } + op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParamsRef) { + return reduction; + } + op_def static Y op(X d1, X d2, Y *extraParamsRef) { + return static_cast(d1 * d2); + } #ifdef __CUDACC__ - __device__ - static inline Y opAtomic(X d1, X d2, Y *extraParamsRef) { - return op(d1, d2, extraParamsRef); - } + __device__ static inline Y opAtomic(X d1, X d2, Y *extraParamsRef) { + return op(d1, d2, extraParamsRef); + } #endif - op_def static Y update(Y old, Y opOutput, Y *extraParamsRef) { - return opOutput + old; - } + op_def static Y update(Y old, Y opOutput, Y *extraParamsRef) { + return opOutput + old; + } - op_def static Y merge(Y old, Y opOutput, Y *extraParamsRef) { - return update(old, opOutput, extraParamsRef); - } + op_def static Y merge(Y old, Y opOutput, Y *extraParamsRef) { + return update(old, opOutput, extraParamsRef); + } - op_def static void aggregateExtraParams(Y *extraParamsTotal, Y *extraParamsLocal) {} - }; + op_def static void aggregateExtraParams(Y *extraParamsTotal, + Y *extraParamsLocal) {} +}; +/** + * Op to check equality within arrays + */ +template +class EqualsWithEps { + public: + static const int extraParamsLen = 0; - /** - * Op to check equality within arrays - */ - template - class EqualsWithEps { - public: - static const int extraParamsLen = 0; + op_def static X *generateExtraParams() { return nullptr; } - op_def static X * generateExtraParams() { - return nullptr; - } + op_def static void finalizeExtraParams(X *extraParamsRef) { + // no-op + } - op_def static void finalizeExtraParams(X *extraParamsRef) { - //no-op - } + op_def static Z startingValue(const X *input) { return static_cast(0.0f); } - op_def static Z startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParamsRef) { - return reduction; - } - - op_def static Z op(X d1, X d2, Z *extraParamsRef) { - double eps = sd::math::nd4j_abs(extraParamsRef[2]); - return static_cast(!sd::math::nd4j_eq(d1, d2, eps)); - } + op_def static Z postProcess(Z reduction, Nd4jLong n, Z *extraParamsRef) { + return reduction; + } + op_def static Z op(X d1, X d2, Z *extraParamsRef) { + double eps = sd::math::nd4j_abs(extraParamsRef[2]); + return static_cast(!sd::math::nd4j_eq(d1, d2, eps)); + } #ifdef __CUDACC__ - __device__ - static inline Z opAtomic(X d1, X d2, Z *extraParamsRef) { - return op(d1, d2, extraParamsRef); - } + __device__ static inline Z opAtomic(X d1, X d2, Z *extraParamsRef) { + return op(d1, d2, extraParamsRef); + } #endif - op_def static Z update(Z old, Z opOutput, Z *extraParamsRef) { - return opOutput + old; - } - - op_def static Z merge(X old, Z opOutput, Z *extraParamsRef) { - return update(old, opOutput, extraParamsRef); - } + op_def static Z update(Z old, Z opOutput, Z *extraParamsRef) { + return opOutput + old; + } - op_def static void aggregateExtraParams(Z *extraParamsTotal, Z *extraParamsLocal) {} - }; + op_def static Z merge(X old, Z opOutput, Z *extraParamsRef) { + return update(old, opOutput, extraParamsRef); + } + op_def static void aggregateExtraParams(Z *extraParamsTotal, + Z *extraParamsLocal) {} +}; +template +class EuclideanDistance { + public: + static const int extraParamsLen = 0; - template - class EuclideanDistance { - public: - static const int extraParamsLen = 0; + op_def static X *generateExtraParams() { return nullptr; } - op_def static X * generateExtraParams() { - return nullptr; - } + op_def static void finalizeExtraParams(X *extraParamsRef) { + // no-op + } - op_def static void finalizeExtraParams(X *extraParamsRef) { - //no-op - } + op_def static Y startingValue(const X *input) { return static_cast(0.0f); } - op_def static Y startingValue(const X *input) { - return static_cast(0.0f); - } - - op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParamsRef) { - return sd::math::nd4j_sqrt(reduction); - } - - op_def static Y op(X d1, X d2, Y *extraParamsRef) { - X ret = d1 - d2; - return static_cast(ret * ret); - } + op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParamsRef) { + return sd::math::nd4j_sqrt(reduction); + } + op_def static Y op(X d1, X d2, Y *extraParamsRef) { + X ret = d1 - d2; + return static_cast(ret * ret); + } #ifdef __CUDACC__ - __device__ - static inline Y opAtomic(X d1, X d2, Y *extraParamsRef) { - return op(d1, d2, extraParamsRef); - } + __device__ static inline Y opAtomic(X d1, X d2, Y *extraParamsRef) { + return op(d1, d2, extraParamsRef); + } #endif - op_def static Y update(Y old, Y opOutput, Y *extraParamsRef) { - return opOutput + old; - } - - op_def static Y merge(Y old, Y opOutput, Y *extraParamsRef) { - return update(old, opOutput, extraParamsRef); - } - op_def static void aggregateExtraParams(Y *extraParamsTotal, Y *extraParamsLocal) {} - - }; + op_def static Y update(Y old, Y opOutput, Y *extraParamsRef) { + return opOutput + old; + } + op_def static Y merge(Y old, Y opOutput, Y *extraParamsRef) { + return update(old, opOutput, extraParamsRef); + } + op_def static void aggregateExtraParams(Y *extraParamsTotal, + Y *extraParamsLocal) {} +}; - template - class ManhattanDistance { - public: - static const int extraParamsLen = 0; +template +class ManhattanDistance { + public: + static const int extraParamsLen = 0; - op_def static X * generateExtraParams() { - return nullptr; - } + op_def static X *generateExtraParams() { return nullptr; } - op_def static void finalizeExtraParams(X *extraParamsRef) { - //no-op - } + op_def static void finalizeExtraParams(X *extraParamsRef) { + // no-op + } - op_def static Y startingValue(const X *input) { - return static_cast(0.0f); - } + op_def static Y startingValue(const X *input) { return static_cast(0.0f); } - op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParamsRef) { - return reduction; - } + op_def static Y postProcess(Y reduction, Nd4jLong n, Y *extraParamsRef) { + return reduction; + } - op_def static Y op(X d1, X d2, Y *extraParamsRef) { - return sd::math::nd4j_abs(d1 - d2); - } + op_def static Y op(X d1, X d2, Y *extraParamsRef) { + return sd::math::nd4j_abs(d1 - d2); + } - op_def static Y update(Y old, Y opOutput, Y *extraParamsRef) { - return old + opOutput; - } - - op_def static void aggregateExtraParams(Y *extraParamsTotal, Y *extraParamsLocal) { - - } + op_def static Y update(Y old, Y opOutput, Y *extraParamsRef) { + return old + opOutput; + } + op_def static void aggregateExtraParams(Y *extraParamsTotal, + Y *extraParamsLocal) {} #ifdef __CUDACC__ - __device__ - static inline Y opAtomic(X d1, X d2, Y *extraParamsRef) { - return op(d1, d2, extraParamsRef); - } + __device__ static inline Y opAtomic(X d1, X d2, Y *extraParamsRef) { + return op(d1, d2, extraParamsRef); + } #endif #ifndef __clang__ #pragma omp declare simd uniform(extraParamsRef) #endif - op_def static Y merge(X old, X opOutput, X *extraParamsRef) { - return update(old, opOutput, extraParamsRef); - } - }; - - - template - class IndexAbsoluteMax { - public: - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue val, X *extraParams) { - return sd::math::nd4j_abs(val); - } - - static _CUDA_HD inline functions::indexreduce::IndexValue update(functions::indexreduce::IndexValue &old, functions::indexreduce::IndexValue &opOutput, X *extraParams) { - opOutput.value = sd::math::nd4j_abs(opOutput.value); - old.value = sd::math::nd4j_abs(old.value); - if (opOutput.value > old.value) - return opOutput; + op_def static Y merge(X old, X opOutput, X *extraParamsRef) { + return update(old, opOutput, extraParamsRef); + } +}; + +template +class IndexAbsoluteMax { + public: + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue val, X *extraParams) { + return sd::math::nd4j_abs(val); + } + + static _CUDA_HD inline functions::indexreduce::IndexValue update( + functions::indexreduce::IndexValue &old, + functions::indexreduce::IndexValue &opOutput, X *extraParams) { + opOutput.value = sd::math::nd4j_abs(opOutput.value); + old.value = sd::math::nd4j_abs(old.value); + if (opOutput.value > old.value) return opOutput; #ifdef __CUDACC__ - // workaround for cuda race condition at merge phase - else if (opOutput.value == old.value && opOutput.index < old.index) - return opOutput; + // workaround for cuda race condition at merge phase + else if (opOutput.value == old.value && opOutput.index < old.index) + return opOutput; #elif defined(__GNUC__) #endif - return old; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue merge( - functions::indexreduce::IndexValue f1, - functions::indexreduce::IndexValue f2, X *extraParams) { - if (sd::math::nd4j_abs(f1.value) > sd::math::nd4j_abs(f2.value)) - return f2; - return f1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( - functions::indexreduce::IndexValue reduction, int n, int xOffset, - X *dx, int incx, X *extraParams, X *result) { - return reduction; - } - - static _CUDA_HD inline X startingValue(const X *input) { - return 0; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(const X *input) { - functions::indexreduce::IndexValue local; - local.value = startingValue(input); - local.index = 0; - return local; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue d1, - functions::indexreduce::IndexValue d2, X *extraParams) { - return d1; - } - }; - - template - class FirstIndex { - public: - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue val, X *extraParams) { - return val; - } - - static _CUDA_HD functions::indexreduce::IndexValue update(functions::indexreduce::IndexValue &old, functions::indexreduce::IndexValue &opOutput, X *extraParams) { - + return old; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue merge( + functions::indexreduce::IndexValue f1, + functions::indexreduce::IndexValue f2, X *extraParams) { + if (sd::math::nd4j_abs(f1.value) > sd::math::nd4j_abs(f2.value)) + return f2; + return f1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( + functions::indexreduce::IndexValue reduction, int n, int xOffset, + X *dx, int incx, X *extraParams, X *result) { + return reduction; + } + + static _CUDA_HD inline X startingValue(const X *input) { return 0; } + + static _CUDA_HD inline functions::indexreduce::IndexValue + startingIndexValue(const X *input) { + functions::indexreduce::IndexValue local; + local.value = startingValue(input); + local.index = 0; + return local; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue d1, + functions::indexreduce::IndexValue d2, X *extraParams) { + return d1; + } +}; + +template +class FirstIndex { + public: + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue val, X *extraParams) { + return val; + } + + static _CUDA_HD functions::indexreduce::IndexValue update( + functions::indexreduce::IndexValue &old, + functions::indexreduce::IndexValue &opOutput, X *extraParams) { #ifdef __CUDACC__ - if (opOutput.index < 0) - return old; + if (opOutput.index < 0) return old; #endif - auto res = simdOps::MatchCondition::op(opOutput.value, extraParams); - - //printf("res: %f; oldIdx: %i; newIdx: %i\n", res, old.index, opOutput.index); - - if (res == static_cast(0)) - return old; - - if (old.index < 0) - return opOutput; - - if (old.index > opOutput.index) - return opOutput; - - return old; - } - - static _CUDA_HD inline X startingValue(const X *input) { - return -sd::DataTypeUtils::infOrMax(); - } - - static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(const X *input) { - functions::indexreduce::IndexValue local; - local.value = startingValue(input); - local.index = -1; - return local; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue d1, - functions::indexreduce::IndexValue d2, X *extraParams) { - return d1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue merge( - functions::indexreduce::IndexValue f1, - functions::indexreduce::IndexValue f2, X *extraParams) { - if (f1.index > f2.index) - return f2; - return f1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( - functions::indexreduce::IndexValue reduction, int n, int xOffset, - X *dx, int incx, X *extraParams, X *result) { - return reduction; - } - }; - - - template - class LastIndex { - public: - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue val, X *extraParams) { - return val; - } - - static _CUDA_HD functions::indexreduce::IndexValue update(functions::indexreduce::IndexValue &old, functions::indexreduce::IndexValue &opOutput, X *extraParams) { + auto res = simdOps::MatchCondition::op(opOutput.value, extraParams); + + // printf("res: %f; oldIdx: %i; newIdx: %i\n", res, old.index, + // opOutput.index); + + if (res == static_cast(0)) return old; + + if (old.index < 0) return opOutput; + + if (old.index > opOutput.index) return opOutput; + + return old; + } + + static _CUDA_HD inline X startingValue(const X *input) { + return -sd::DataTypeUtils::infOrMax(); + } + + static _CUDA_HD inline functions::indexreduce::IndexValue + startingIndexValue(const X *input) { + functions::indexreduce::IndexValue local; + local.value = startingValue(input); + local.index = -1; + return local; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue d1, + functions::indexreduce::IndexValue d2, X *extraParams) { + return d1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue merge( + functions::indexreduce::IndexValue f1, + functions::indexreduce::IndexValue f2, X *extraParams) { + if (f1.index > f2.index) return f2; + return f1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( + functions::indexreduce::IndexValue reduction, int n, int xOffset, + X *dx, int incx, X *extraParams, X *result) { + return reduction; + } +}; + +template +class LastIndex { + public: + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue val, X *extraParams) { + return val; + } + + static _CUDA_HD functions::indexreduce::IndexValue update( + functions::indexreduce::IndexValue &old, + functions::indexreduce::IndexValue &opOutput, X *extraParams) { #ifdef __CUDACC__ - if (opOutput.index < 0) - return old; + if (opOutput.index < 0) return old; #endif - auto res = simdOps::MatchCondition::op(opOutput.value, extraParams); - - if (res == static_cast(0)) - return old; - - if (old.index < 0) - return opOutput; - - if (old.index < opOutput.index) - return opOutput; - - return old; - } - - static _CUDA_HD inline X startingValue(const X *input) { - return -sd::DataTypeUtils::infOrMax(); - } - - static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(const X *input) { - functions::indexreduce::IndexValue local; - local.value = startingValue(input); - local.index = -1; - return local; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue d1, - functions::indexreduce::IndexValue d2, X *extraParams) { - return d1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue merge( - functions::indexreduce::IndexValue f1, - functions::indexreduce::IndexValue f2, X *extraParams) { - if (f1.index < f2.index) - return f2; - return f1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( - functions::indexreduce::IndexValue reduction, int n, int xOffset, - X *dx, int incx, X *extraParams, X *result) { - return reduction; - } - }; - - - template - class IndexMax { - public: - - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue val, X *extraParams) { - return val; - } - - static _CUDA_HD functions::indexreduce::IndexValue update(functions::indexreduce::IndexValue &old, functions::indexreduce::IndexValue &opOutput, X *extraParams) { - if (opOutput.value > old.value) { - return opOutput; - } + auto res = simdOps::MatchCondition::op(opOutput.value, extraParams); + + if (res == static_cast(0)) return old; + + if (old.index < 0) return opOutput; + + if (old.index < opOutput.index) return opOutput; + + return old; + } + + static _CUDA_HD inline X startingValue(const X *input) { + return -sd::DataTypeUtils::infOrMax(); + } + + static _CUDA_HD inline functions::indexreduce::IndexValue + startingIndexValue(const X *input) { + functions::indexreduce::IndexValue local; + local.value = startingValue(input); + local.index = -1; + return local; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue d1, + functions::indexreduce::IndexValue d2, X *extraParams) { + return d1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue merge( + functions::indexreduce::IndexValue f1, + functions::indexreduce::IndexValue f2, X *extraParams) { + if (f1.index < f2.index) return f2; + return f1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( + functions::indexreduce::IndexValue reduction, int n, int xOffset, + X *dx, int incx, X *extraParams, X *result) { + return reduction; + } +}; + +template +class IndexMax { + public: + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue val, X *extraParams) { + return val; + } + + static _CUDA_HD functions::indexreduce::IndexValue update( + functions::indexreduce::IndexValue &old, + functions::indexreduce::IndexValue &opOutput, X *extraParams) { + if (opOutput.value > old.value) { + return opOutput; + } #ifdef __CUDACC__ - // workaround for cuda race condition at merge phase - else if (opOutput.value == old.value && opOutput.index < old.index) - return opOutput; + // workaround for cuda race condition at merge phase + else if (opOutput.value == old.value && opOutput.index < old.index) + return opOutput; #elif defined(__GNUC__) #endif - return old; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue merge( - functions::indexreduce::IndexValue f1, - functions::indexreduce::IndexValue f2, X *extraParams) { - if (f1.value > f2.value) - return f2; - return f1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( - functions::indexreduce::IndexValue reduction, int n, int xOffset, - X *dx, int incx, X *extraParams, X *result) { - return reduction; - } - - static _CUDA_HD inline X startingValue(const X *input) { - return -sd::DataTypeUtils::infOrMax(); - } - - static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(const X *input) { - functions::indexreduce::IndexValue local; - local.value = startingValue(input); - local.index = 0; - return local; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue d1, - functions::indexreduce::IndexValue d2, X *extraParams) { - return d1; - } - }; - - - template - class IndexAbsoluteMin { - public: - static _CUDA_HD inline functions::indexreduce::IndexValue op( - functions::indexreduce::IndexValue val, X *extraParams) { - return val; - } - - static _CUDA_HD inline X startingValue(const X *input) { - return sd::DataTypeUtils::infOrMax(); - } - - static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(const X *input) { - functions::indexreduce::IndexValue local; - local.value = startingValue(input); - local.index = 0; - return local; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue update(functions::indexreduce::IndexValue &old, functions::indexreduce::IndexValue &opOutput, X *extraParams) { - opOutput.value = sd::math::nd4j_abs(opOutput.value); - old.value = sd::math::nd4j_abs(old.value); - if (opOutput.value < old.value) - return opOutput; + return old; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue merge( + functions::indexreduce::IndexValue f1, + functions::indexreduce::IndexValue f2, X *extraParams) { + if (f1.value > f2.value) return f2; + return f1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( + functions::indexreduce::IndexValue reduction, int n, int xOffset, + X *dx, int incx, X *extraParams, X *result) { + return reduction; + } + + static _CUDA_HD inline X startingValue(const X *input) { + return -sd::DataTypeUtils::infOrMax(); + } + + static _CUDA_HD inline functions::indexreduce::IndexValue + startingIndexValue(const X *input) { + functions::indexreduce::IndexValue local; + local.value = startingValue(input); + local.index = 0; + return local; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue d1, + functions::indexreduce::IndexValue d2, X *extraParams) { + return d1; + } +}; + +template +class IndexAbsoluteMin { + public: + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue val, X *extraParams) { + return val; + } + + static _CUDA_HD inline X startingValue(const X *input) { + return sd::DataTypeUtils::infOrMax(); + } + + static _CUDA_HD inline functions::indexreduce::IndexValue + startingIndexValue(const X *input) { + functions::indexreduce::IndexValue local; + local.value = startingValue(input); + local.index = 0; + return local; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue update( + functions::indexreduce::IndexValue &old, + functions::indexreduce::IndexValue &opOutput, X *extraParams) { + opOutput.value = sd::math::nd4j_abs(opOutput.value); + old.value = sd::math::nd4j_abs(old.value); + if (opOutput.value < old.value) return opOutput; #ifdef __CUDACC__ - // workaround for cuda race condition at merge phase - else if (opOutput.value == old.value && opOutput.index < old.index) - return opOutput; + // workaround for cuda race condition at merge phase + else if (opOutput.value == old.value && opOutput.index < old.index) + return opOutput; #elif defined(__GNUC__) #endif - return old; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue merge( - functions::indexreduce::IndexValue f1, - functions::indexreduce::IndexValue f2, X *extraParams) { - if (sd::math::nd4j_abs(f1.value) < sd::math::nd4j_abs(f2.value)) - return f2; - return f1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( - functions::indexreduce::IndexValue reduction, int n, int xOffset, - X *dx, int incx, X *extraParams, X *result) { - return reduction; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue d1, - functions::indexreduce::IndexValue d2, X *extraParams) { - return d1; - } - }; - - - template - class IndexMin { - public: - static _CUDA_HD inline functions::indexreduce::IndexValue op( - functions::indexreduce::IndexValue val, X *extraParams) { - return val; - } - - static _CUDA_HD inline X startingValue(const X *input) { - return sd::DataTypeUtils::infOrMax(); - } - - static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(const X *input) { - functions::indexreduce::IndexValue local; - local.value = startingValue(input); - local.index = 0; - return local; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue update(functions::indexreduce::IndexValue &old, functions::indexreduce::IndexValue &opOutput, X *extraParams) { - if (opOutput.value < old.value) - return opOutput; + return old; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue merge( + functions::indexreduce::IndexValue f1, + functions::indexreduce::IndexValue f2, X *extraParams) { + if (sd::math::nd4j_abs(f1.value) < sd::math::nd4j_abs(f2.value)) + return f2; + return f1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( + functions::indexreduce::IndexValue reduction, int n, int xOffset, + X *dx, int incx, X *extraParams, X *result) { + return reduction; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue d1, + functions::indexreduce::IndexValue d2, X *extraParams) { + return d1; + } +}; + +template +class IndexMin { + public: + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue val, X *extraParams) { + return val; + } + + static _CUDA_HD inline X startingValue(const X *input) { + return sd::DataTypeUtils::infOrMax(); + } + + static _CUDA_HD inline functions::indexreduce::IndexValue + startingIndexValue(const X *input) { + functions::indexreduce::IndexValue local; + local.value = startingValue(input); + local.index = 0; + return local; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue update( + functions::indexreduce::IndexValue &old, + functions::indexreduce::IndexValue &opOutput, X *extraParams) { + if (opOutput.value < old.value) return opOutput; #ifdef __CUDACC__ - // workaround for cuda race condition at merge phase - else if (opOutput.value == old.value && opOutput.index < old.index) - return opOutput; + // workaround for cuda race condition at merge phase + else if (opOutput.value == old.value && opOutput.index < old.index) + return opOutput; #elif defined(__GNUC__) #endif - return old; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue merge( - functions::indexreduce::IndexValue f1, - functions::indexreduce::IndexValue f2, X *extraParams) { - if (f1.value < f2.value) - return f2; - return f1; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( - functions::indexreduce::IndexValue reduction, int n, int xOffset, - X *dx, int incx, X *extraParams, X *result) { - return reduction; - } - - static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue d1, - functions::indexreduce::IndexValue d2, X *extraParams) { - return d1; - } - }; - - template - class SummaryStatsVariance { - public: - - static _CUDA_HD inline Z getValue(const bool biasCorrected, functions::summarystats::SummaryStatsData val) { - if (biasCorrected) { - Z ret = static_cast(val.varianceBiasCorrected()); - if (ret < static_cast(0.0f)) - return static_cast(val.variance()); - return ret; - } - return static_cast(val.variance()); - } - - static _CUDA_HD inline functions::summarystats::SummaryStatsData op(functions::summarystats::SummaryStatsData d1, Z *extraParams) { - return d1; - } - }; - - template - class SummaryStatsStandardDeviation { - public: - - static _CUDA_HD inline Z getValue(const bool biasCorrected, functions::summarystats::SummaryStatsData val) { - if (biasCorrected) { - auto ret = static_cast(val.varianceBiasCorrected()); - if (ret < static_cast(0.0f)) - return sd::math::nd4j_sqrt(val.variance()); - else - return sd::math::nd4j_sqrt(ret); - } - return sd::math::nd4j_sqrt(val.variance()); - } - - static _CUDA_HD inline functions::summarystats::SummaryStatsData op(functions::summarystats::SummaryStatsData d1, Z *extraParams) { - return d1; - } - }; - - template - class DropOut { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - inline _CUDA_D static X op(X d1, X *params) { - X prob = params[0]; + return old; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue merge( + functions::indexreduce::IndexValue f1, + functions::indexreduce::IndexValue f2, X *extraParams) { + if (f1.value < f2.value) return f2; + return f1; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue postProcess( + functions::indexreduce::IndexValue reduction, int n, int xOffset, + X *dx, int incx, X *extraParams, X *result) { + return reduction; + } + + static _CUDA_HD inline functions::indexreduce::IndexValue op( + functions::indexreduce::IndexValue d1, + functions::indexreduce::IndexValue d2, X *extraParams) { + return d1; + } +}; + +template +class SummaryStatsVariance { + public: + static _CUDA_HD inline Z getValue( + const bool biasCorrected, + functions::summarystats::SummaryStatsData val) { + if (biasCorrected) { + Z ret = static_cast(val.varianceBiasCorrected()); + if (ret < static_cast(0.0f)) return static_cast(val.variance()); + return ret; + } + return static_cast(val.variance()); + } + + static _CUDA_HD inline functions::summarystats::SummaryStatsData op( + functions::summarystats::SummaryStatsData d1, Z *extraParams) { + return d1; + } +}; + +template +class SummaryStatsStandardDeviation { + public: + static _CUDA_HD inline Z getValue( + const bool biasCorrected, + functions::summarystats::SummaryStatsData val) { + if (biasCorrected) { + auto ret = static_cast(val.varianceBiasCorrected()); + if (ret < static_cast(0.0f)) + return sd::math::nd4j_sqrt(val.variance()); + else + return sd::math::nd4j_sqrt(ret); + } + return sd::math::nd4j_sqrt(val.variance()); + } + + static _CUDA_HD inline functions::summarystats::SummaryStatsData op( + functions::summarystats::SummaryStatsData d1, Z *extraParams) { + return d1; + } +}; + +template +class DropOut { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + inline _CUDA_D static X + op(X d1, X *params) { + X prob = params[0]; #ifdef __CUDACC__ - X length = params[1]; - X tid = blockIdx.x * blockDim.x + threadIdx.x; - X rnd = sd::math::nd4j_abs(sd::math::nd4j_cos(static_cast(clock64()) * static_cast(tid) + static_cast(length) * static_cast(tid))); + X length = params[1]; + X tid = blockIdx.x * blockDim.x + threadIdx.x; + X rnd = sd::math::nd4j_abs( + sd::math::nd4j_cos(static_cast(clock64()) * static_cast(tid) + + static_cast(length) * static_cast(tid))); #else - X rnd = static_cast(rand() / RAND_MAX); + X rnd = static_cast(rand() / RAND_MAX); #endif - return rnd >= prob ? static_cast(0.0f) : d1; - } - }; - - template - class DropOutInverted { - public: - no_op_exec_special - no_op_exec_special_cuda - + return rnd >= prob ? static_cast(0.0f) : d1; + } +}; + +template +class DropOutInverted { + public: + no_op_exec_special no_op_exec_special_cuda #ifdef __CUDACC__ - __device__ + __device__ #endif - inline static Z op(X d1, Y d2, Z *params) { - Y prob = d2; + inline static Z + op(X d1, Y d2, Z *params) { + Y prob = d2; #ifdef __CUDACC__ - X length = params[1]; - X tid = blockIdx.x * blockDim.x + threadIdx.x; - X rnd = sd::math::nd4j_abs(sd::math::nd4j_cos(static_cast(clock64()) * static_cast(tid) + static_cast(length) * static_cast(tid))); + X length = params[1]; + X tid = blockIdx.x * blockDim.x + threadIdx.x; + X rnd = sd::math::nd4j_abs( + sd::math::nd4j_cos(static_cast(clock64()) * static_cast(tid) + + static_cast(length) * static_cast(tid))); #else - X rnd = static_cast(rand() / RAND_MAX); + X rnd = static_cast(rand() / RAND_MAX); #endif - return rnd >= static_cast(prob) ? static_cast(0.0f) : reinterpret_cast(d1 / static_cast(prob)); - } - }; - - - template - class ReplaceNans { - public: - no_op_exec_special - no_op_exec_special_cuda - - op_def static Z op(X d1, Y d2, Z *params) { - return sd::math::nd4j_isnan(d1) ? static_cast(d2) : static_cast(d1) ; - } - }; - - // this op is used for conditional pairwise transforms only - template - class CompareAndReplace{ - public: - // op definition for PairWise Transform - op_def static Z op(X d1, Y d2, Z *params) { - auto zd1 = static_cast(d1); - auto zd2 = static_cast(d2); - auto compare = params[0]; - auto eps = params[2]; - int mode = (int) params[3]; - if (mode == 0) // equals - if (sd::math::nd4j_abs(zd1 - compare) <= eps) - return zd2; - else - return zd1; - else if (mode == 1) // not equals eps - if (sd::math::nd4j_abs(zd1 - compare) > eps) - return zd2; - else - return zd1; - else if (mode == 2) // less_than eps - if (zd1 < compare) - return zd2; - else - return zd1; - else if (mode ==3) // greater_than - if (zd1 > compare) - return zd2; - else - return zd1; - else if (mode == 4) // less_or_equals_than - if (zd1 <= compare) - return zd2; - else - return zd1; - else if (mode == 5) // greater_or_equals_than - if (zd1 >= compare) - return zd2; - else - return zd1; - else if (mode == 6) // abs_less_than - if (sd::math::nd4j_abs(zd1) < compare) - return zd2; - else - return zd1; - else if (mode == 7) // abs_greater_than - if (sd::math::nd4j_abs(zd1) > compare) - return zd2; - else - return zd1; - else if (mode == 8) // is inf - if (sd::math::nd4j_isinf(zd1)) - return zd2; - else - return zd1; - else if (mode == 9) // is nan - if (sd::math::nd4j_isnan(zd1)) - return zd2; - else - return zd1; - else if (mode == 10) - if (zd1 == compare) - return zd2; - else - return zd1; - else if (mode == 11) - if (zd1 != compare) - return zd2; - else - return zd1; - else if (mode == 12) // abs_greater_or_equals_than - if (sd::math::nd4j_abs(zd1) >= compare) - return zd2; - else - return zd1; - else if (mode == 13) // abs_less_or_equals_than - if (sd::math::nd4j_abs(zd1) <= compare) - return zd2; - else - return zd1; - else - printf("Undefined boolean operation: [%i]\n", mode); - return zd1; - } - }; - - template - class CompareAndSet { - public: - - - // op definition for PairWise Transform - op_def static Z op(X dX, Y dY, Z *params) { - auto d1 = static_cast(dX); - auto d2 = static_cast(dY); - auto compare = params[0]; - auto eps = params[2]; - auto mode = static_cast(params[3]); - if (mode == 0) // equals - if (sd::math::nd4j_abs(d2 - compare) <= eps) - return d2; - else - return d1; - else if (mode == 1) // not equals - if (sd::math::nd4j_abs(d2 - compare) > eps) - return d2; - else - return d1; - else if (mode == 2) // less_than - if (d2 < compare) - return d2; - else - return d1; - else if (mode ==3) // greater_than - if (d2 > compare) - return d2; - else - return d1; - else if (mode == 4) // less_or_equals_than - if (d2 <= compare) - return d2; - else - return d1; - else if (mode == 5) // greater_or_equals_than - if (d2 >= compare) - return d2; - else - return d1; - else if (mode == 6) // abs_less_than - if (sd::math::nd4j_abs(d2) < compare) - return d2; - else - return d1; - else if (mode == 7) // abs_greater_than - if (sd::math::nd4j_abs(d2) > compare) - return d2; - else - return d1; - else if (mode == 8) // is inf - if (sd::math::nd4j_isinf(d2)) - return d2; - else - return d1; - else if (mode == 9) // is nan - if (sd::math::nd4j_isnan(d2)) - return d2; - else - return d1; - else if (mode == 10) - if (d2 == compare) - return d2; - else - return d1; - else if (mode == 11) - if (d2 != compare) - return d2; - else - return d1; - else if (mode == 12) // abs_greater_or_equals_than - if (sd::math::nd4j_abs(d1) >= compare) - return d2; - else - return d1; - else if (mode == 13) // abs_less_or_equals_than - if (sd::math::nd4j_abs(d1) <= compare) - return d2; - else - return d1; - else - printf("Undefined boolean operation: [%i]\n", mode); - return d1; - } - }; - - template - class CompareAndSetTransform { - public: - no_op_exec_special_same - no_op_exec_special_same_cuda - - - // op definition for Transform - op_def static X op(X d1, X *params) { - auto compare = params[0]; - auto set = params[1]; - auto eps = params[2]; - - // with mode == 0 we do set if d1 equals to compare, and with mode == 1 - we go otherwise - int mode = (int) params[3]; - if (mode == 0) // equals - if (sd::math::nd4j_abs(d1 - compare) <= eps) - return set; - else - return d1; - //return sd::math::nd4j_abs(d1 - compare) <= eps ? set : d1; - else if (mode == 1) // not equals - if (sd::math::nd4j_abs(d1 - compare) > eps) - return set; - else - return d1; - //return sd::math::nd4j_abs(d1 - compare) > eps ? set : d1; - else if (mode == 2) // less_than - if (d1 < compare) - return set; - else - return d1; - else if (mode ==3) // greater_than - if (d1 > compare) - return set; - else - return d1; - else if (mode == 4) // less_or_equals_than - if (d1 <= compare) - return set; - else - return d1; - else if (mode == 5) // greater_or_equals_than - if (d1 >= compare) - return set; - else - return d1; - else if (mode == 6) // abs_less_than - if (sd::math::nd4j_abs(d1) < compare) - return set; - else - return d1; - else if (mode == 7) // abs_greater_than - if (sd::math::nd4j_abs(d1) > compare) - return set; - else - return d1; - else if (mode == 8) // is inf - if (sd::math::nd4j_isinf(d1)) - return set; - else - return d1; - else if (mode == 9) // is nan - if (sd::math::nd4j_isnan(d1)) - return set; - else - return d1; - else if (mode == 10) - if (d1 == compare) - return set; - else - return d1; - else if (mode == 11) - if (d1 != compare) - return set; - else - return d1; - else if (mode == 12) // abs_greater_or_equals_than - if (sd::math::nd4j_abs(d1) >= compare) - return set; - else - return d1; - else if (mode == 13) // abs_less_or_equals_than - if (sd::math::nd4j_abs(d1) <= compare) - return set; - else - return d1; - else - printf("Undefined boolean operation: [%i]\n", mode); - return d1; - } - }; - - -} + return rnd >= static_cast(prob) + ? static_cast(0.0f) + : reinterpret_cast(d1 / static_cast(prob)); + } +}; + +template +class ReplaceNans { + public: + no_op_exec_special no_op_exec_special_cuda + + op_def static Z + op(X d1, Y d2, Z *params) { + return sd::math::nd4j_isnan(d1) ? static_cast(d2) : static_cast(d1); + } +}; + +// this op is used for conditional pairwise transforms only +template +class CompareAndReplace { + public: + // op definition for PairWise Transform + op_def static Z op(X d1, Y d2, Z *params) { + auto zd1 = static_cast(d1); + auto zd2 = static_cast(d2); + auto compare = params[0]; + auto eps = params[2]; + int mode = (int)params[3]; + if (mode == 0) // equals + if (sd::math::nd4j_abs(zd1 - compare) <= eps) + return zd2; + else + return zd1; + else if (mode == 1) // not equals eps + if (sd::math::nd4j_abs(zd1 - compare) > eps) + return zd2; + else + return zd1; + else if (mode == 2) // less_than eps + if (zd1 < compare) + return zd2; + else + return zd1; + else if (mode == 3) // greater_than + if (zd1 > compare) + return zd2; + else + return zd1; + else if (mode == 4) // less_or_equals_than + if (zd1 <= compare) + return zd2; + else + return zd1; + else if (mode == 5) // greater_or_equals_than + if (zd1 >= compare) + return zd2; + else + return zd1; + else if (mode == 6) // abs_less_than + if (sd::math::nd4j_abs(zd1) < compare) + return zd2; + else + return zd1; + else if (mode == 7) // abs_greater_than + if (sd::math::nd4j_abs(zd1) > compare) + return zd2; + else + return zd1; + else if (mode == 8) // is inf + if (sd::math::nd4j_isinf(zd1)) + return zd2; + else + return zd1; + else if (mode == 9) // is nan + if (sd::math::nd4j_isnan(zd1)) + return zd2; + else + return zd1; + else if (mode == 10) + if (zd1 == compare) + return zd2; + else + return zd1; + else if (mode == 11) + if (zd1 != compare) + return zd2; + else + return zd1; + else if (mode == 12) // abs_greater_or_equals_than + if (sd::math::nd4j_abs(zd1) >= compare) + return zd2; + else + return zd1; + else if (mode == 13) // abs_less_or_equals_than + if (sd::math::nd4j_abs(zd1) <= compare) + return zd2; + else + return zd1; + else + printf("Undefined boolean operation: [%i]\n", mode); + return zd1; + } +}; + +template +class CompareAndSet { + public: + // op definition for PairWise Transform + op_def static Z op(X dX, Y dY, Z *params) { + auto d1 = static_cast(dX); + auto d2 = static_cast(dY); + auto compare = params[0]; + auto eps = params[2]; + auto mode = static_cast(params[3]); + if (mode == 0) // equals + if (sd::math::nd4j_abs(d2 - compare) <= eps) + return d2; + else + return d1; + else if (mode == 1) // not equals + if (sd::math::nd4j_abs(d2 - compare) > eps) + return d2; + else + return d1; + else if (mode == 2) // less_than + if (d2 < compare) + return d2; + else + return d1; + else if (mode == 3) // greater_than + if (d2 > compare) + return d2; + else + return d1; + else if (mode == 4) // less_or_equals_than + if (d2 <= compare) + return d2; + else + return d1; + else if (mode == 5) // greater_or_equals_than + if (d2 >= compare) + return d2; + else + return d1; + else if (mode == 6) // abs_less_than + if (sd::math::nd4j_abs(d2) < compare) + return d2; + else + return d1; + else if (mode == 7) // abs_greater_than + if (sd::math::nd4j_abs(d2) > compare) + return d2; + else + return d1; + else if (mode == 8) // is inf + if (sd::math::nd4j_isinf(d2)) + return d2; + else + return d1; + else if (mode == 9) // is nan + if (sd::math::nd4j_isnan(d2)) + return d2; + else + return d1; + else if (mode == 10) + if (d2 == compare) + return d2; + else + return d1; + else if (mode == 11) + if (d2 != compare) + return d2; + else + return d1; + else if (mode == 12) // abs_greater_or_equals_than + if (sd::math::nd4j_abs(d1) >= compare) + return d2; + else + return d1; + else if (mode == 13) // abs_less_or_equals_than + if (sd::math::nd4j_abs(d1) <= compare) + return d2; + else + return d1; + else + printf("Undefined boolean operation: [%i]\n", mode); + return d1; + } +}; + +template +class CompareAndSetTransform { + public: + no_op_exec_special_same no_op_exec_special_same_cuda + + // op definition for Transform + op_def static X + op(X d1, X *params) { + auto compare = params[0]; + auto set = params[1]; + auto eps = params[2]; + + // with mode == 0 we do set if d1 equals to compare, and with mode == 1 - we + // go otherwise + int mode = (int)params[3]; + if (mode == 0) // equals + if (sd::math::nd4j_abs(d1 - compare) <= eps) + return set; + else + return d1; + // return sd::math::nd4j_abs(d1 - compare) <= eps ? set : d1; + else if (mode == 1) // not equals + if (sd::math::nd4j_abs(d1 - compare) > eps) + return set; + else + return d1; + // return sd::math::nd4j_abs(d1 - compare) > eps ? set : d1; + else if (mode == 2) // less_than + if (d1 < compare) + return set; + else + return d1; + else if (mode == 3) // greater_than + if (d1 > compare) + return set; + else + return d1; + else if (mode == 4) // less_or_equals_than + if (d1 <= compare) + return set; + else + return d1; + else if (mode == 5) // greater_or_equals_than + if (d1 >= compare) + return set; + else + return d1; + else if (mode == 6) // abs_less_than + if (sd::math::nd4j_abs(d1) < compare) + return set; + else + return d1; + else if (mode == 7) // abs_greater_than + if (sd::math::nd4j_abs(d1) > compare) + return set; + else + return d1; + else if (mode == 8) // is inf + if (sd::math::nd4j_isinf(d1)) + return set; + else + return d1; + else if (mode == 9) // is nan + if (sd::math::nd4j_isnan(d1)) + return set; + else + return d1; + else if (mode == 10) + if (d1 == compare) + return set; + else + return d1; + else if (mode == 11) + if (d1 != compare) + return set; + else + return d1; + else if (mode == 12) // abs_greater_or_equals_than + if (sd::math::nd4j_abs(d1) >= compare) + return set; + else + return d1; + else if (mode == 13) // abs_less_or_equals_than + if (sd::math::nd4j_abs(d1) <= compare) + return set; + else + return d1; + else + printf("Undefined boolean operation: [%i]\n", mode); + return d1; + } +}; + +} // namespace simdOps #endif - diff --git a/libnd4j/include/ops/random_ops.h b/libnd4j/include/ops/random_ops.h index d738589a7c08..c77dbcd8f8de 100644 --- a/libnd4j/include/ops/random_ops.h +++ b/libnd4j/include/ops/random_ops.h @@ -27,259 +27,305 @@ #define random_def inline static #endif -// since we can't inherit/overwrite static methods - we just define default impls -#define method_idx random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator* rng, T *extraParams) { return -1.0f; } -#define method_X random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator* rng, T *extraParams) { return -2.0f; } -#define method_XY random_def T op(T valueX, T valueY, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator* rng, T *extraParams) { return -3.0f; } - -#define no_exec_special static const bool requiresSpecial = false; static inline void specialOp(Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, T *extraArguments) { } +// since we can't inherit/overwrite static methods - we just define default +// impls +#define method_idx \ + random_def T op(Nd4jLong idx, Nd4jLong length, \ + sd::graph::RandomGenerator *rng, T *extraParams) { \ + return -1.0f; \ + } +#define method_X \ + random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, \ + sd::graph::RandomGenerator *rng, T *extraParams) { \ + return -2.0f; \ + } +#define method_XY \ + random_def T op(T valueX, T valueY, Nd4jLong idx, Nd4jLong length, \ + sd::graph::RandomGenerator *rng, T *extraParams) { \ + return -3.0f; \ + } + +#define no_exec_special \ + static const bool requiresSpecial = false; \ + static inline void specialOp( \ + Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, \ + const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, \ + T *extraArguments) {} #ifdef __CUDACC__ -#define no_exec_special_cuda __device__ static inline void specialOpCuda(Nd4jPointer state, T const* x, Nd4jLong const* xShapeBuffer, T const* y, Nd4jLong const* yShapeBuffer, T *z, Nd4jLong const* zShapeBuffer, T *extraArguments) { } +#define no_exec_special_cuda \ + __device__ static inline void specialOpCuda( \ + Nd4jPointer state, T const *x, Nd4jLong const *xShapeBuffer, T const *y, \ + Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const *zShapeBuffer, \ + T *extraArguments) {} #else #define no_exec_special_cuda #endif -#include -#include #include +#include +#include namespace randomOps { - /** - * This Op merges two arrays per-element, if probability meets threshold - */ - template - class ProbablisticMerge { - public: - - no_exec_special - no_exec_special_cuda - - method_idx - method_X - - random_def T op(T valueX, T valueY, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T threshold = extraParams[0]; - T randVal = helper->relativeT(idx); - - return randVal <= threshold ? valueY : valueX; - } - }; - - /** - * This Op produces random values within specified boundaries. Disribution is uniform - */ - template - class UniformDistribution { - public: - - no_exec_special - no_exec_special_cuda - - method_XY - method_X - - random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - return helper->relativeT(idx, extraParams[0], extraParams[1]); - } - }; - - /** - * This op produces single bernoulli trial - */ - template - class BernoulliDistribution { - public: - no_exec_special - no_exec_special_cuda - - method_XY - - random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - return extraParams[0] >= helper->relativeT(idx) ? (T) 1.0f : (T) 0.0f; - } - - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - return valueX >= helper->relativeT(idx) ? (T) 1.0f : (T) 0.0f; - } - }; - - - /** - * This op produces single bernoulli trial - */ - template - class ExponentialDistribution { - public: - no_exec_special - no_exec_special_cuda - - method_XY - - random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T lambda = extraParams[0]; - T x = helper->relativeT(idx, sd::DataTypeUtils::min(), T(1.f) - sd::DataTypeUtils::template min()); // x from (0, 1) without bounds - T xVal = -sd::math::nd4j_log(x); - - return xVal <= (T)0.f ? (T)0.f : xVal / lambda; //pow((T) M_E, -(lambda * x)); - } - - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T lambda = extraParams[0]; - return valueX <= (T)0.f ? (T)0.f : (T)(valueX/lambda); //1.f - sd::math::nd4j_exp(-lambda * valueX); //pow((T) M_E, -(lambda * valueX)); - } - }; - - template - class PoissonDistribution { - public: - no_exec_special - no_exec_special_cuda - - method_XY - - random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T lambda = extraParams[0]; - T x = helper->relativeT(idx, -sd::DataTypeUtils::template max() / 10 , sd::DataTypeUtils::template max() / 10); - return x <= (T)0.f ? (T)0.f : sd::math::nd4j_igammac(sd::math::nd4j_floor(x), lambda); - } - - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T lambda = extraParams[0]; - return valueX <= (T)0.f ? (T)0.f : (T)sd::math::nd4j_igammac(sd::math::nd4j_floor(valueX), lambda); - } - }; - - template - class GammaDistribution { - public: - no_exec_special - no_exec_special_cuda - - method_XY - - random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T alpha = extraParams[0]; - T beta = extraParams[1]; - T x = helper->relativeT(idx, -sd::DataTypeUtils::template max() / 10 , sd::DataTypeUtils::template max() / 10); - return x <= (T)0.f ? (T)0.f : sd::math::nd4j_igamma(alpha, x * beta); - } - - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T alpha = extraParams[0]; - T beta = extraParams[1]; - return valueX <= (T)0.f ? (T)0.f : sd::math::nd4j_igamma(alpha, beta * valueX); - } - }; - - /** - * Basic DropOut/DropConnect Op - */ - template - class DropOut { - public: - - no_exec_special - no_exec_special_cuda - - method_idx - method_XY - - // please note: prob is chance to retain original value - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T randVal = helper->relativeT(idx); - return randVal >= extraParams[0] ? (T) 0.0f : valueX; - } - }; - - template - class AlphaDropOut { - public: - - no_exec_special - no_exec_special_cuda - - method_idx - method_XY - - // please note: prob is chance to retain original value - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T randVal = helper->relativeT(idx); - // extraParams[0] == p - // [1] = a - // [2] = b - // [3] = alphaPrime - return randVal >= extraParams[0] ? (T) extraParams[1] * extraParams[3] + extraParams[2] : extraParams[1] * valueX + extraParams[2]; - } - }; - - /** - * Inverted DropOut implementation, used in DL4j - */ - template - class DropOutInverted { - public: - - no_exec_special - no_exec_special_cuda - - method_idx - method_XY - - // please note: prob is chance to retain original value - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T prob = extraParams[0]; - T randVal = helper->relativeT(idx); - return randVal >= prob ? (T) 0.0f : valueX / prob; - } - }; - - - template - class Linspace { - public: - - no_exec_special - no_exec_special_cuda - - method_X - method_XY - - random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T from = extraParams[0]; - T to = extraParams[1]; - T step = extraParams[2]; - - if (step == static_cast(0.0f)) { - step = (T) idx / ((T)length - (T) 1.0f); - return from * ((T) 1.0f - step) + step * to; - } - return from + (idx * step); - - } - }; - - template - class ExponentialDistributionInv { // inverse exponential distribution - public: - no_exec_special - no_exec_special_cuda - - method_XY - - random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T lambda = extraParams[0]; - T x = helper->relativeT(idx, sd::DataTypeUtils::template min(), (T)1.f - sd::DataTypeUtils::template min()); - return -sd::math::nd4j_log((T)1.f - x) / lambda; - } - - random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { - T lambda = extraParams[0]; - return -sd::math::nd4j_log((T)1.f - valueX) / lambda; // valueX must be within (0, 1] - } - }; - -} - -#endif //LIBND4J_RANDOM_OPS_H +/** + * This Op merges two arrays per-element, if probability meets threshold + */ +template +class ProbablisticMerge { + public: + no_exec_special no_exec_special_cuda + + method_idx method_X + + random_def T + op(T valueX, T valueY, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T threshold = extraParams[0]; + T randVal = helper->relativeT(idx); + + return randVal <= threshold ? valueY : valueX; + } +}; + +/** + * This Op produces random values within specified boundaries. Disribution is + * uniform + */ +template +class UniformDistribution { + public: + no_exec_special no_exec_special_cuda + + method_XY method_X + + random_def T + op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, + T *extraParams) { + return helper->relativeT(idx, extraParams[0], extraParams[1]); + } +}; + +/** + * This op produces single bernoulli trial + */ +template +class BernoulliDistribution { + public: + no_exec_special no_exec_special_cuda + + method_XY + + random_def T + op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, + T *extraParams) { + return extraParams[0] >= helper->relativeT(idx) ? (T)1.0f : (T)0.0f; + } + + random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + return valueX >= helper->relativeT(idx) ? (T)1.0f : (T)0.0f; + } +}; + +/** + * This op produces single bernoulli trial + */ +template +class ExponentialDistribution { + public: + no_exec_special no_exec_special_cuda + + method_XY + + random_def T + op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, + T *extraParams) { + T lambda = extraParams[0]; + T x = helper->relativeT( + idx, sd::DataTypeUtils::min(), + T(1.f) - sd::DataTypeUtils::template min()); // x from (0, 1) + // without bounds + T xVal = -sd::math::nd4j_log(x); + + return xVal <= (T)0.f + ? (T)0.f + : xVal / lambda; // pow((T) M_E, -(lambda * x)); + } + + random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T lambda = extraParams[0]; + return valueX <= (T)0.f + ? (T)0.f + : (T)(valueX / lambda); // 1.f - sd::math::nd4j_exp(-lambda + // * valueX); //pow((T) M_E, + // -(lambda * valueX)); + } +}; + +template +class PoissonDistribution { + public: + no_exec_special no_exec_special_cuda + + method_XY + + random_def T + op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, + T *extraParams) { + T lambda = extraParams[0]; + T x = helper->relativeT(idx, -sd::DataTypeUtils::template max() / 10, + sd::DataTypeUtils::template max() / 10); + return x <= (T)0.f ? (T)0.f + : sd::math::nd4j_igammac( + sd::math::nd4j_floor(x), lambda); + } + + random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T lambda = extraParams[0]; + return valueX <= (T)0.f ? (T)0.f + : (T)sd::math::nd4j_igammac( + sd::math::nd4j_floor(valueX), lambda); + } +}; + +template +class GammaDistribution { + public: + no_exec_special no_exec_special_cuda + + method_XY + + random_def T + op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, + T *extraParams) { + T alpha = extraParams[0]; + T beta = extraParams[1]; + T x = helper->relativeT(idx, -sd::DataTypeUtils::template max() / 10, + sd::DataTypeUtils::template max() / 10); + return x <= (T)0.f ? (T)0.f + : sd::math::nd4j_igamma(alpha, x * beta); + } + + random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T alpha = extraParams[0]; + T beta = extraParams[1]; + return valueX <= (T)0.f + ? (T)0.f + : sd::math::nd4j_igamma(alpha, beta * valueX); + } +}; + +/** + * Basic DropOut/DropConnect Op + */ +template +class DropOut { + public: + no_exec_special no_exec_special_cuda + + method_idx method_XY + + // please note: prob is chance to retain original value + random_def T + op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T randVal = helper->relativeT(idx); + return randVal >= extraParams[0] ? (T)0.0f : valueX; + } +}; + +template +class AlphaDropOut { + public: + no_exec_special no_exec_special_cuda + + method_idx method_XY + + // please note: prob is chance to retain original value + random_def T + op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T randVal = helper->relativeT(idx); + // extraParams[0] == p + // [1] = a + // [2] = b + // [3] = alphaPrime + return randVal >= extraParams[0] + ? (T)extraParams[1] * extraParams[3] + extraParams[2] + : extraParams[1] * valueX + extraParams[2]; + } +}; + +/** + * Inverted DropOut implementation, used in DL4j + */ +template +class DropOutInverted { + public: + no_exec_special no_exec_special_cuda + + method_idx method_XY + + // please note: prob is chance to retain original value + random_def T + op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T prob = extraParams[0]; + T randVal = helper->relativeT(idx); + return randVal >= prob ? (T)0.0f : valueX / prob; + } +}; + +template +class Linspace { + public: + no_exec_special no_exec_special_cuda + + method_X method_XY + + random_def T + op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, + T *extraParams) { + T from = extraParams[0]; + T to = extraParams[1]; + T step = extraParams[2]; + + if (step == static_cast(0.0f)) { + step = (T)idx / ((T)length - (T)1.0f); + return from * ((T)1.0f - step) + step * to; + } + return from + (idx * step); + } +}; + +template +class ExponentialDistributionInv { // inverse exponential distribution + public: + no_exec_special no_exec_special_cuda + + method_XY + + random_def T + op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, + T *extraParams) { + T lambda = extraParams[0]; + T x = helper->relativeT(idx, sd::DataTypeUtils::template min(), + (T)1.f - sd::DataTypeUtils::template min()); + return -sd::math::nd4j_log((T)1.f - x) / lambda; + } + + random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, + sd::graph::RandomGenerator *helper, T *extraParams) { + T lambda = extraParams[0]; + return -sd::math::nd4j_log((T)1.f - valueX) / + lambda; // valueX must be within (0, 1] + } +}; + +} // namespace randomOps + +#endif // LIBND4J_RANDOM_OPS_H diff --git a/libnd4j/include/ops/special_random_ops.h b/libnd4j/include/ops/special_random_ops.h index 08808e67c9a2..936d8a7d2ac1 100644 --- a/libnd4j/include/ops/special_random_ops.h +++ b/libnd4j/include/ops/special_random_ops.h @@ -21,824 +21,892 @@ #ifndef LIBND4J_SPECIAL_RANDOM_OPS_H #define LIBND4J_SPECIAL_RANDOM_OPS_H -#include -#include +#include #include +#include +#include #include -#include namespace randomOps { ////////////////////////////////////////////////////////////////////// - template - class Choice { - public: - - method_idx - method_X - method_XY - - static const bool requiresSpecial = true; +template +class Choice { + public: + method_idx method_X method_XY + static const bool requiresSpecial = true; #ifdef __CUDACC__ - __device__ static inline void specialOpCuda(Nd4jPointer state, T const* x, Nd4jLong const* xShapeBuffer, T const* y, Nd4jLong const* yShapeBuffer, T *z, Nd4jLong const* zShapeBuffer, T *extraArguments) { - /** - * X holds data, - * Y holds probabilities - * Z will hold results - */ - - // TODO: we probably might want to skip this sum, and state that probabilities array should be real probabilities, i.e. should sum to 1.0 - //T probSum = extraArguments[0]; - - __shared__ Nd4jLong xLength; - __shared__ Nd4jLong yLength; - __shared__ Nd4jLong zLength; - - __shared__ Nd4jLong xEWS; - __shared__ Nd4jLong yEWS; - __shared__ Nd4jLong zEWS; - __shared__ char xOrder; - __shared__ char yOrder; - __shared__ char zOrder; - - __shared__ sd::graph::RandomGenerator *rng; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator *devRng; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - rng = (sd::graph::RandomGenerator*) shmem; - cB = shmem; - devRng = reinterpret_cast (state); - dB = reinterpret_cast (state); - - xLength = shape::length(xShapeBuffer); - yLength = shape::length(yShapeBuffer); - zLength = shape::length(zShapeBuffer); - - xEWS = shape::elementWiseStride(xShapeBuffer); - yEWS = shape::elementWiseStride(yShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - xOrder = shape::order(xShapeBuffer); - yOrder = shape::order(yShapeBuffer); - zOrder = shape::order(zShapeBuffer); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - if (zEWS >= 1 && xEWS >= 1 && yEWS >= 1 && xOrder == yOrder && xOrder == zOrder) { - for (Nd4jLong e = tid; e < zLength; e+=blockDim.x * gridDim.x) { - T prob = rng->relativeT(e); - T cumProb = (T) 0.0f; - for (Nd4jLong f = 0; f < yLength; f++) { - T relProb = y[f * yEWS]; - cumProb += relProb; - - if (prob <= cumProb || f == yLength - 1) { - z[e * zEWS] = x[f * xEWS]; - f += yLength; - } -// __syncthreads(); // Eliminated due RTX20xx specific - } -// __syncthreads(); // Eliminated due RTX20xx specific - } - } - else { - - for (Nd4jLong i = tid; i < zLength; i+=blockDim.x * gridDim.x) { - - auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); - T prob = rng->relativeT(i); - T cumProb = (T) 0.0f; - - for (Nd4jLong f = 0; f < yLength; f++) { - - auto yOffset2 = shape::getIndexOffset(f, yShapeBuffer); - T relProb = y[yOffset2]; - cumProb += relProb; - - if (prob <= cumProb || f == yLength - 1) { - - auto xOffset2 = shape::getIndexOffset(f, xShapeBuffer); - z[zOffset2] = x[xOffset2]; - f += yLength; - } -// __syncthreads(); // Eliminated due RTX20xx specific - } -// __syncthreads(); // Eliminated due RTX20xx specific - } - } + __device__ static inline void specialOpCuda( + Nd4jPointer state, T const *x, Nd4jLong const *xShapeBuffer, T const *y, + Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const *zShapeBuffer, + T *extraArguments) { + /** + * X holds data, + * Y holds probabilities + * Z will hold results + */ + + // TODO: we probably might want to skip this sum, and state that + // probabilities array should be real probabilities, i.e. should sum to 1.0 + // T probSum = extraArguments[0]; + + __shared__ Nd4jLong xLength; + __shared__ Nd4jLong yLength; + __shared__ Nd4jLong zLength; + + __shared__ Nd4jLong xEWS; + __shared__ Nd4jLong yEWS; + __shared__ Nd4jLong zEWS; + __shared__ char xOrder; + __shared__ char yOrder; + __shared__ char zOrder; + + __shared__ sd::graph::RandomGenerator *rng; + __shared__ unsigned char *cB; + __shared__ unsigned char *dB; + __shared__ sd::graph::RandomGenerator *devRng; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + rng = (sd::graph::RandomGenerator *)shmem; + cB = shmem; + devRng = reinterpret_cast(state); + dB = reinterpret_cast(state); + + xLength = shape::length(xShapeBuffer); + yLength = shape::length(yShapeBuffer); + zLength = shape::length(zShapeBuffer); + + xEWS = shape::elementWiseStride(xShapeBuffer); + yEWS = shape::elementWiseStride(yShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + xOrder = shape::order(xShapeBuffer); + yOrder = shape::order(yShapeBuffer); + zOrder = shape::order(zShapeBuffer); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (zEWS >= 1 && xEWS >= 1 && yEWS >= 1 && xOrder == yOrder && + xOrder == zOrder) { + for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) { + T prob = rng->relativeT(e); + T cumProb = (T)0.0f; + for (Nd4jLong f = 0; f < yLength; f++) { + T relProb = y[f * yEWS]; + cumProb += relProb; + + if (prob <= cumProb || f == yLength - 1) { + z[e * zEWS] = x[f * xEWS]; + f += yLength; + } + // __syncthreads(); // Eliminated due RTX20xx + // specific } + // __syncthreads(); // Eliminated due RTX20xx + // specific + } + } else { + for (Nd4jLong i = tid; i < zLength; i += blockDim.x * gridDim.x) { + auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); + T prob = rng->relativeT(i); + T cumProb = (T)0.0f; + + for (Nd4jLong f = 0; f < yLength; f++) { + auto yOffset2 = shape::getIndexOffset(f, yShapeBuffer); + T relProb = y[yOffset2]; + cumProb += relProb; + + if (prob <= cumProb || f == yLength - 1) { + auto xOffset2 = shape::getIndexOffset(f, xShapeBuffer); + z[zOffset2] = x[xOffset2]; + f += yLength; + } + // __syncthreads(); // Eliminated due RTX20xx + // specific + } + // __syncthreads(); // Eliminated due RTX20xx + // specific + } + } + } #endif - static inline void specialOp(Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, T *extraArguments) { - /** - * X holds data, - * Y holds probabilities - * Z will hold results - */ - - //sd::random::RandomBuffer *buffer = reinterpret_cast (state); - sd::graph::RandomGenerator* rng = reinterpret_cast(state); - // TODO: we probably might want to skip this sum, and state that probabilities array should be real probabilities, i.e. should sum to 1.0 - //T probSum = extraArguments[0]; - - auto xLength = shape::length(xShapeBuffer); - auto yLength = shape::length(yShapeBuffer); - auto zLength = shape::length(zShapeBuffer); - - auto xEWS = shape::elementWiseStride(xShapeBuffer); - auto yEWS = shape::elementWiseStride(yShapeBuffer); - auto zEWS = shape::elementWiseStride(zShapeBuffer); - - int elementsPerThread = zLength / TAD_THRESHOLD; - int _threads = sd::math::nd4j_max(1, elementsPerThread); - _threads = sd::math::nd4j_min(_threads, sd::Environment::getInstance()->maxThreads()); - - if (zEWS >= 1 && xEWS >= 1 && yEWS >= 1) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - T prob = rng->relativeT(e); - T cumProb = (T) 0.0f; - for (Nd4jLong f = 0; f < yLength; f++) { - T relProb = y[f * yEWS]; - cumProb += relProb; - - if (prob <= cumProb || f == yLength - 1) { - z[e * zEWS] = x[f * xEWS]; - break; - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + static inline void specialOp(Nd4jPointer state, const T *x, + const Nd4jLong *xShapeBuffer, const T *y, + const Nd4jLong *yShapeBuffer, T *z, + const Nd4jLong *zShapeBuffer, + T *extraArguments) { + /** + * X holds data, + * Y holds probabilities + * Z will hold results + */ + + // sd::random::RandomBuffer *buffer = + // reinterpret_cast (state); + sd::graph::RandomGenerator *rng = + reinterpret_cast(state); + // TODO: we probably might want to skip this sum, and state that + // probabilities array should be real probabilities, i.e. should sum to 1.0 + // T probSum = extraArguments[0]; + + auto xLength = shape::length(xShapeBuffer); + auto yLength = shape::length(yShapeBuffer); + auto zLength = shape::length(zShapeBuffer); + + auto xEWS = shape::elementWiseStride(xShapeBuffer); + auto yEWS = shape::elementWiseStride(yShapeBuffer); + auto zEWS = shape::elementWiseStride(zShapeBuffer); + + int elementsPerThread = zLength / TAD_THRESHOLD; + int _threads = sd::math::nd4j_max(1, elementsPerThread); + _threads = sd::math::nd4j_min( + _threads, sd::Environment::getInstance()->maxThreads()); + + if (zEWS >= 1 && xEWS >= 1 && yEWS >= 1) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + T prob = rng->relativeT(e); + T cumProb = (T)0.0f; + for (Nd4jLong f = 0; f < yLength; f++) { + T relProb = y[f * yEWS]; + cumProb += relProb; + + if (prob <= cumProb || f == yLength - 1) { + z[e * zEWS] = x[f * xEWS]; + break; } - else { - - auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < zLength; i++) { - - auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); - T prob = rng->relativeT(i); - T cumProb = (T) 0.0f; - - for (Nd4jLong f = 0; f < yLength; f++) { - - auto yOffset2 = shape::getIndexOffset(f, yShapeBuffer); - T relProb = y[yOffset2]; - cumProb += relProb; - - if (prob <= cumProb || f == yLength - 1) { - - auto xOffset2 = shape::getIndexOffset(f, xShapeBuffer); - z[zOffset2] = x[xOffset2]; - break; - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + } + } + }; + + samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + } else { + auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < zLength; i++) { + auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer); + T prob = rng->relativeT(i); + T cumProb = (T)0.0f; + + for (Nd4jLong f = 0; f < yLength; f++) { + auto yOffset2 = shape::getIndexOffset(f, yShapeBuffer); + T relProb = y[yOffset2]; + cumProb += relProb; + + if (prob <= cumProb || f == yLength - 1) { + auto xOffset2 = shape::getIndexOffset(f, xShapeBuffer); + z[zOffset2] = x[xOffset2]; + break; } + } } - }; + }; + samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + } + } +}; ////////////////////////////////////////////////////////////////////// - /** - * This Op produces random values within specified boundaries. Distribuion is Gaussian - */ - template - class GaussianDistribution { - public: +/** + * This Op produces random values within specified boundaries. Distribuion is + * Gaussian + */ +template +class GaussianDistribution { + public: + method_XY method_X method_idx - method_XY - method_X - method_idx - - static const bool requiresSpecial = true; + static const bool requiresSpecial = true; #ifdef __CUDACC__ - __device__ static inline void specialOpCuda(Nd4jPointer state, T const* x, Nd4jLong const* xShapeBuffer, T const* y, Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const* zShapeBuffer, T *extraArguments) { - - __shared__ T epsilon; - __shared__ T two_pi; - - __shared__ Nd4jLong zLength; - __shared__ Nd4jLong zEWS; - __shared__ Nd4jLong yEWS; - __shared__ T mean; - __shared__ T stddev; - __shared__ int step; - - __shared__ T *tZ; - - __shared__ sd::graph::RandomGenerator* rng; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator *devRng; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - rng = reinterpret_cast(shmem); - cB = shmem; - devRng = reinterpret_cast (state); - dB = reinterpret_cast (state); - - tZ = reinterpret_cast(shmem + sizeof(sd::graph::RandomGenerator)); - - zLength = shape::length(zShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - yEWS = shape::elementWiseStride(yShapeBuffer); - - - epsilon = static_cast(1e-5); - two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); - - mean = extraArguments[0]; - stddev = extraArguments[1]; - - step = (blockDim.x * gridDim.x); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - int middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1; - T t(-2.0f); - - for (int e = tid; e < middle; e += step) { - auto epm = e + middle; - // we need to get random values - T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); - T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); - - T realMean0 = y == z ? mean : y[e * yEWS]; - - z[e * zEWS] = (sd::math::nd4j_sqrt(t * sd::math::nd4j_log(r0)) * sd::math::nd4j_cos(two_pi * r1)) * stddev + realMean0; - - if (epm < zLength) { - T realMean1 = y == z ? mean : y[epm * yEWS]; - z[epm * zEWS] = (sd::math::nd4j_sqrt(t * sd::math::nd4j_log(r0)) * sd::math::nd4j_sin(two_pi * r1)) * stddev + realMean1; - } - } - } + __device__ static inline void specialOpCuda( + Nd4jPointer state, T const *x, Nd4jLong const *xShapeBuffer, T const *y, + Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const *zShapeBuffer, + T *extraArguments) { + __shared__ T epsilon; + __shared__ T two_pi; + + __shared__ Nd4jLong zLength; + __shared__ Nd4jLong zEWS; + __shared__ Nd4jLong yEWS; + __shared__ T mean; + __shared__ T stddev; + __shared__ int step; + + __shared__ T *tZ; + + __shared__ sd::graph::RandomGenerator *rng; + __shared__ unsigned char *cB; + __shared__ unsigned char *dB; + __shared__ sd::graph::RandomGenerator *devRng; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + rng = reinterpret_cast(shmem); + cB = shmem; + devRng = reinterpret_cast(state); + dB = reinterpret_cast(state); + + tZ = reinterpret_cast(shmem + sizeof(sd::graph::RandomGenerator)); + + zLength = shape::length(zShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + yEWS = shape::elementWiseStride(yShapeBuffer); + + epsilon = static_cast(1e-5); + two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); + + mean = extraArguments[0]; + stddev = extraArguments[1]; + + step = (blockDim.x * gridDim.x); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + int middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1; + T t(-2.0f); + + for (int e = tid; e < middle; e += step) { + auto epm = e + middle; + // we need to get random values + T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); + T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); + + T realMean0 = y == z ? mean : y[e * yEWS]; + + z[e * zEWS] = + (sd::math::nd4j_sqrt(t * sd::math::nd4j_log(r0)) * + sd::math::nd4j_cos(two_pi * r1)) * + stddev + + realMean0; + + if (epm < zLength) { + T realMean1 = y == z ? mean : y[epm * yEWS]; + z[epm * zEWS] = + (sd::math::nd4j_sqrt(t * sd::math::nd4j_log(r0)) * + sd::math::nd4j_sin(two_pi * r1)) * + stddev + + realMean1; + } + } + } #endif - - static inline void - specialOp(Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, T *extraArguments) { - const T two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); - - auto zLength = shape::length(zShapeBuffer); - auto yEWS = shape::elementWiseStride(yShapeBuffer); - auto zEWS = shape::elementWiseStride(zShapeBuffer); - - auto middle = zLength % 2 + zLength / 2; - - int elementsPerThread = middle / TAD_THRESHOLD; - int _threads = sd::math::nd4j_max(1, elementsPerThread); - _threads = sd::math::nd4j_min(_threads, sd::Environment::getInstance()->maxThreads()); - - int span = (middle / _threads) + 8; - - // we're enforcing even chunks, since it's mandatory for this algorithm - span -= span % 2; - - //sd::random::RandomBuffer *buffer = reinterpret_cast (state); - sd::graph::RandomGenerator* rng = reinterpret_cast(state); - const T mean = extraArguments[0]; - const T stddev = extraArguments[1]; - - const T epsilon = static_cast(1e-5); - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto epm = e + middle; - - // we need to get random values - T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); - T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); - - T realMean0 = y == z ? mean : y[e * yEWS]; - - auto z0 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * - sd::math::nd4j_cos(two_pi * r1)) * stddev + realMean0; - z[e * zEWS] = z0; - - if (epm < zLength) { - T realMean1 = y == z ? mean : y[epm * yEWS]; - auto z1 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * - sd::math::nd4j_sin(two_pi * r1)) * stddev + realMean1; - z[epm * zEWS] = z1; - } - } - }; - - samediff::Threads::parallel_for(func, 0, middle, 1, _threads); + static inline void specialOp(Nd4jPointer state, const T *x, + const Nd4jLong *xShapeBuffer, const T *y, + const Nd4jLong *yShapeBuffer, T *z, + const Nd4jLong *zShapeBuffer, + T *extraArguments) { + const T two_pi = + static_cast(2.0f) * static_cast(3.14159265358979323846); + + auto zLength = shape::length(zShapeBuffer); + auto yEWS = shape::elementWiseStride(yShapeBuffer); + auto zEWS = shape::elementWiseStride(zShapeBuffer); + + auto middle = zLength % 2 + zLength / 2; + + int elementsPerThread = middle / TAD_THRESHOLD; + int _threads = sd::math::nd4j_max(1, elementsPerThread); + _threads = sd::math::nd4j_min( + _threads, sd::Environment::getInstance()->maxThreads()); + + int span = (middle / _threads) + 8; + + // we're enforcing even chunks, since it's mandatory for this algorithm + span -= span % 2; + + // sd::random::RandomBuffer *buffer = + // reinterpret_cast (state); + sd::graph::RandomGenerator *rng = + reinterpret_cast(state); + const T mean = extraArguments[0]; + const T stddev = extraArguments[1]; + + const T epsilon = static_cast(1e-5); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto epm = e + middle; + + // we need to get random values + T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); + T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); + + T realMean0 = y == z ? mean : y[e * yEWS]; + + auto z0 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_cos(two_pi * r1)) * + stddev + + realMean0; + z[e * zEWS] = z0; + + if (epm < zLength) { + T realMean1 = y == z ? mean : y[epm * yEWS]; + auto z1 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_sin(two_pi * r1)) * + stddev + + realMean1; + z[epm * zEWS] = z1; } + } }; + samediff::Threads::parallel_for(func, 0, middle, 1, _threads); + } +}; ////////////////////////////////////////////////////////////////////// - /** - * This Op produces random values within [0..N], Distribuion is binomial - */ - template - class BinomialDistribution { - public: - - - method_XY - method_X - method_idx +/** + * This Op produces random values within [0..N], Distribuion is binomial + */ +template +class BinomialDistribution { + public: + method_XY method_X method_idx - static const bool requiresSpecial = true; + static const bool requiresSpecial = true; #ifdef __CUDACC__ - __device__ static inline void specialOpCuda(Nd4jPointer state, T const* x, Nd4jLong const* xShapeBuffer, T const* y, Nd4jLong const* yShapeBuffer, T *z, Nd4jLong const* zShapeBuffer, T *extraArguments) { - int trials = (int) extraArguments[0]; - T prob = extraArguments[1]; - - __shared__ Nd4jLong zLength; - __shared__ int yEWS; - __shared__ int zEWS; - - __shared__ sd::graph::RandomGenerator* rng; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator *devRng; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - rng = reinterpret_cast(shmem); - cB = shmem; - devRng = reinterpret_cast(state); - dB = reinterpret_cast (state); - - zLength = shape::length(zShapeBuffer); - yEWS = shape::elementWiseStride(yShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) { - int success = 0; - for (int t = 1; t <= trials; t++) { - T randVal = rng->relativeT((e+1) * t); - if (y != z) { - // we're using external probs - prob = y[(t-1) * yEWS]; - } - - if (randVal < prob) - success++; - } - // if trials is set to 0, effectively we just have successful memset - z[e * zEWS] = static_cast(success); - } + __device__ static inline void specialOpCuda( + Nd4jPointer state, T const *x, Nd4jLong const *xShapeBuffer, T const *y, + Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const *zShapeBuffer, + T *extraArguments) { + int trials = (int)extraArguments[0]; + T prob = extraArguments[1]; + + __shared__ Nd4jLong zLength; + __shared__ int yEWS; + __shared__ int zEWS; + + __shared__ sd::graph::RandomGenerator *rng; + __shared__ unsigned char *cB; + __shared__ unsigned char *dB; + __shared__ sd::graph::RandomGenerator *devRng; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + rng = reinterpret_cast(shmem); + cB = shmem; + devRng = reinterpret_cast(state); + dB = reinterpret_cast(state); + + zLength = shape::length(zShapeBuffer); + yEWS = shape::elementWiseStride(yShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) { + int success = 0; + for (int t = 1; t <= trials; t++) { + T randVal = rng->relativeT((e + 1) * t); + if (y != z) { + // we're using external probs + prob = y[(t - 1) * yEWS]; } -#endif - - static inline void specialOp(Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, T *extraArguments) { - int trials = (int) extraArguments[0]; - - Nd4jLong zLength = shape::length(zShapeBuffer); - - auto yEWS = shape::elementWiseStride(yShapeBuffer); - auto zEWS = shape::elementWiseStride(zShapeBuffer); - - int elementsPerThread = zLength / TAD_THRESHOLD; - int _threads = sd::math::nd4j_max(1, elementsPerThread); - _threads = sd::math::nd4j_min(_threads, sd::Environment::getInstance()->maxThreads()); - - T prob = extraArguments[1]; - - sd::graph::RandomGenerator* rng = reinterpret_cast(state); - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - int success = 0; - for (int t = 1; t <= trials; t++) { - T randVal = rng->relativeT((e+1) * t); - if (y != z) { - // we're using external probs - prob = y[(t-1) * yEWS]; - } - - if (randVal < prob) - success++; - } - - // if trials is set to 0, effectively we just have successful memset - z[e * zEWS] = static_cast(success); - } - }; + if (randVal < prob) success++; + } + // if trials is set to 0, effectively we just have successful memset + z[e * zEWS] = static_cast(success); + } + } +#endif - samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + static inline void specialOp(Nd4jPointer state, const T *x, + const Nd4jLong *xShapeBuffer, const T *y, + const Nd4jLong *yShapeBuffer, T *z, + const Nd4jLong *zShapeBuffer, + T *extraArguments) { + int trials = (int)extraArguments[0]; + + Nd4jLong zLength = shape::length(zShapeBuffer); + + auto yEWS = shape::elementWiseStride(yShapeBuffer); + auto zEWS = shape::elementWiseStride(zShapeBuffer); + + int elementsPerThread = zLength / TAD_THRESHOLD; + int _threads = sd::math::nd4j_max(1, elementsPerThread); + _threads = sd::math::nd4j_min( + _threads, sd::Environment::getInstance()->maxThreads()); + + T prob = extraArguments[1]; + + sd::graph::RandomGenerator *rng = + reinterpret_cast(state); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + int success = 0; + for (int t = 1; t <= trials; t++) { + T randVal = rng->relativeT((e + 1) * t); + if (y != z) { + // we're using external probs + prob = y[(t - 1) * yEWS]; + } + + if (randVal < prob) success++; } + + // if trials is set to 0, effectively we just have successful memset + z[e * zEWS] = static_cast(success); + } }; + samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + } +}; ////////////////////////////////////////////////////////////////////// - /** - * This Op produces random values within [0..N], Distribuion is binomial - */ - template - class BinomialDistributionEx { - public: +/** + * This Op produces random values within [0..N], Distribuion is binomial + */ +template +class BinomialDistributionEx { + public: + method_XY method_X method_idx - - method_XY - method_X - method_idx - - static const bool requiresSpecial = true; + static const bool requiresSpecial = true; #ifdef __CUDACC__ - __device__ static inline void specialOpCuda(Nd4jPointer state, T const* x, Nd4jLong const* xShapeBuffer, T const* y, Nd4jLong const* yShapeBuffer, T *z, Nd4jLong const* zShapeBuffer, T *extraArguments) { - int trials = (int) extraArguments[0]; - T prob = extraArguments[1]; - - __shared__ Nd4jLong zLength; - __shared__ int yEWS; - __shared__ int zEWS; - - __shared__ sd::graph::RandomGenerator* rng; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator *devRng; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - rng = (sd::graph::RandomGenerator*) shmem; - cB = shmem; - devRng = reinterpret_cast (state); - dB = reinterpret_cast (state); - - zLength = shape::length(zShapeBuffer); - yEWS = shape::elementWiseStride(yShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) { - int success = 0; - for (int t = 1; t <= trials; t++) { - T randVal = rng->relativeT((e+1) * t); - if (y != z) { - // we're using external probs - prob = y[e * yEWS]; - } - - if (randVal < prob) - success++; - } - - // if trials is set to 0, effectively we just have successful memset - z[e * zEWS] = (T) success; - } + __device__ static inline void specialOpCuda( + Nd4jPointer state, T const *x, Nd4jLong const *xShapeBuffer, T const *y, + Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const *zShapeBuffer, + T *extraArguments) { + int trials = (int)extraArguments[0]; + T prob = extraArguments[1]; + + __shared__ Nd4jLong zLength; + __shared__ int yEWS; + __shared__ int zEWS; + + __shared__ sd::graph::RandomGenerator *rng; + __shared__ unsigned char *cB; + __shared__ unsigned char *dB; + __shared__ sd::graph::RandomGenerator *devRng; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + rng = (sd::graph::RandomGenerator *)shmem; + cB = shmem; + devRng = reinterpret_cast(state); + dB = reinterpret_cast(state); + + zLength = shape::length(zShapeBuffer); + yEWS = shape::elementWiseStride(yShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) { + int success = 0; + for (int t = 1; t <= trials; t++) { + T randVal = rng->relativeT((e + 1) * t); + if (y != z) { + // we're using external probs + prob = y[e * yEWS]; } -#endif - - static inline void specialOp(Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, T *extraArguments) { - int trials = (int) extraArguments[0]; - - Nd4jLong zLength = shape::length(zShapeBuffer); - - auto yEWS = shape::elementWiseStride(yShapeBuffer); - auto zEWS = shape::elementWiseStride(zShapeBuffer); - - int elementsPerThread = zLength / TAD_THRESHOLD; - int _threads = sd::math::nd4j_max(1, elementsPerThread); - _threads = sd::math::nd4j_min(_threads, sd::Environment::getInstance()->maxThreads()); - T prob = extraArguments[1]; + if (randVal < prob) success++; + } - auto rng = reinterpret_cast(state); - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - - int success = 0; - for (int t = 1; t <= trials; t++) { - T randVal = rng->relativeT((e+1) * t); - if (y != z) { - // we're using external probs - prob = y[e * yEWS]; - } - - if (randVal < prob) - success++; - } - - // if trials is set to 0, effectively we just have successful memset - z[e * zEWS] = static_cast(success); - } - }; + // if trials is set to 0, effectively we just have successful memset + z[e * zEWS] = (T)success; + } + } +#endif - samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + static inline void specialOp(Nd4jPointer state, const T *x, + const Nd4jLong *xShapeBuffer, const T *y, + const Nd4jLong *yShapeBuffer, T *z, + const Nd4jLong *zShapeBuffer, + T *extraArguments) { + int trials = (int)extraArguments[0]; + + Nd4jLong zLength = shape::length(zShapeBuffer); + + auto yEWS = shape::elementWiseStride(yShapeBuffer); + auto zEWS = shape::elementWiseStride(zShapeBuffer); + + int elementsPerThread = zLength / TAD_THRESHOLD; + int _threads = sd::math::nd4j_max(1, elementsPerThread); + _threads = sd::math::nd4j_min( + _threads, sd::Environment::getInstance()->maxThreads()); + + T prob = extraArguments[1]; + + auto rng = reinterpret_cast(state); + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + int success = 0; + for (int t = 1; t <= trials; t++) { + T randVal = rng->relativeT((e + 1) * t); + if (y != z) { + // we're using external probs + prob = y[e * yEWS]; + } + + if (randVal < prob) success++; } - }; -////////////////////////////////////////////////////////////////////// - // This Op produces random Gaussian values within [mean-2*stddev,mean+2*stddev] - template - class TruncatedNormalDistribution { - private: - static inline _CUDA_HD T step(sd::graph::RandomGenerator* rng, T mean, T stddev, Nd4jLong e, Nd4jLong middle, T& z) { - auto epm = e + middle; - const T two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); - const T epsilon = static_cast(1.e-5f); - // we need to get random values - T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); - T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); - - T realMean0 = mean; - - auto z0 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * sd::math::nd4j_cos(two_pi * r1)) * stddev + realMean0; - z = z0; - if (epm < middle) { - T realMean1 = mean; - auto z1 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * - sd::math::nd4j_sin(two_pi * r1)) * stddev + realMean1; - z = z1; - } - return z; - } - public: + // if trials is set to 0, effectively we just have successful memset + z[e * zEWS] = static_cast(success); + } + }; - method_XY - method_X - method_idx + samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + } +}; - static const bool requiresSpecial = true; +////////////////////////////////////////////////////////////////////// +// This Op produces random Gaussian values within [mean-2*stddev,mean+2*stddev] +template +class TruncatedNormalDistribution { + private: + static inline _CUDA_HD T step(sd::graph::RandomGenerator *rng, T mean, + T stddev, Nd4jLong e, Nd4jLong middle, T &z) { + auto epm = e + middle; + const T two_pi = + static_cast(2.0f) * static_cast(3.14159265358979323846); + const T epsilon = static_cast(1.e-5f); + // we need to get random values + T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); + T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); + + T realMean0 = mean; + + auto z0 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_cos(two_pi * r1)) * + stddev + + realMean0; + z = z0; + if (epm < middle) { + T realMean1 = mean; + auto z1 = (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_sin(two_pi * r1)) * + stddev + + realMean1; + z = z1; + } + return z; + } + + public: + method_XY method_X method_idx + + static const bool requiresSpecial = true; #ifdef __CUDACC__ - __device__ static inline void specialOpCuda(Nd4jPointer state, T const* x, Nd4jLong const* xShapeBuffer, T const* y, Nd4jLong const* yShapeBuffer, T *z, Nd4jLong const* zShapeBuffer, T *extraArguments) { - __shared__ T epsilon; - __shared__ T two_pi; - - __shared__ Nd4jLong zLength; - __shared__ Nd4jLong zEWS; - __shared__ Nd4jLong yEWS; - __shared__ T mean; - __shared__ T stddev; - __shared__ int step; - - __shared__ T *tZ; - - __shared__ sd::graph::RandomGenerator* rng; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator* devRng; - __shared__ Nd4jLong middle; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - rng = reinterpret_cast(shmem); - cB = shmem; - devRng = reinterpret_cast (state); - dB = reinterpret_cast (state); - - tZ = reinterpret_cast(shmem + sizeof(sd::graph::RandomGenerator)); - - zLength = shape::length(zShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - yEWS = shape::elementWiseStride(yShapeBuffer); - - epsilon = static_cast(1e-6f); - two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); - - mean = extraArguments[0]; - stddev = extraArguments[1]; - - step = (blockDim.x * gridDim.x); - middle = zLength / 2 + (zLength % 2); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - GaussianDistribution::specialOpCuda(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments); - __syncthreads(); - - T ds = sd::math::nd4j_abs(stddev) * static_cast(2.0f); - for (Nd4jLong e = tid; e < zLength; e += step) { - if (z[e] > mean + ds || z[e] < mean - ds) { - z[e] = TruncatedNormalDistribution::step(rng, mean, stddev, e, middle, z[e]); - - if (z[e] > mean + ds || z[e] < mean - ds) - z[e] = mean + sd::DataTypeUtils::min(); - } - } - } + __device__ static inline void specialOpCuda( + Nd4jPointer state, T const *x, Nd4jLong const *xShapeBuffer, T const *y, + Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const *zShapeBuffer, + T *extraArguments) { + __shared__ T epsilon; + __shared__ T two_pi; + + __shared__ Nd4jLong zLength; + __shared__ Nd4jLong zEWS; + __shared__ Nd4jLong yEWS; + __shared__ T mean; + __shared__ T stddev; + __shared__ int step; + + __shared__ T *tZ; + + __shared__ sd::graph::RandomGenerator *rng; + __shared__ unsigned char *cB; + __shared__ unsigned char *dB; + __shared__ sd::graph::RandomGenerator *devRng; + __shared__ Nd4jLong middle; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + rng = reinterpret_cast(shmem); + cB = shmem; + devRng = reinterpret_cast(state); + dB = reinterpret_cast(state); + + tZ = reinterpret_cast(shmem + sizeof(sd::graph::RandomGenerator)); + + zLength = shape::length(zShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + yEWS = shape::elementWiseStride(yShapeBuffer); + + epsilon = static_cast(1e-6f); + two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); + + mean = extraArguments[0]; + stddev = extraArguments[1]; + + step = (blockDim.x * gridDim.x); + middle = zLength / 2 + (zLength % 2); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + GaussianDistribution::specialOpCuda(state, x, xShapeBuffer, y, + yShapeBuffer, z, zShapeBuffer, + extraArguments); + __syncthreads(); + + T ds = sd::math::nd4j_abs(stddev) * static_cast(2.0f); + for (Nd4jLong e = tid; e < zLength; e += step) { + if (z[e] > mean + ds || z[e] < mean - ds) { + z[e] = TruncatedNormalDistribution::step(rng, mean, stddev, e, + middle, z[e]); + + if (z[e] > mean + ds || z[e] < mean - ds) + z[e] = mean + sd::DataTypeUtils::min(); + } + } + } #endif - static inline void - specialOp(Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, T *extraArguments) { - GaussianDistribution::specialOp(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments); - Nd4jLong zLength = shape::length(zShapeBuffer); - //auto yEWS = shape::elementWiseStride(yShapeBuffer); - //auto zEWS = shape::elementWiseStride(zShapeBuffer); - auto rng = reinterpret_cast(state); - T mean = extraArguments[0]; - T stddev = extraArguments[1]; - T ds = sd::math::nd4j_abs(stddev) * (T) 2.0f; - Nd4jLong middle = zLength / 2 + (zLength % 2); - int elementsPerThread = middle / TAD_THRESHOLD; - int _threads = sd::math::nd4j_max(1, elementsPerThread); - _threads = sd::math::nd4j_min(_threads, sd::Environment::getInstance()->maxThreads()); - - const T epsilon = static_cast(1e-5); - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - if (z[e] > mean + ds || z[e] < mean - ds) { - z[e] = step(rng, mean, stddev, e, middle, z[e]); - - if (z[e] > mean + ds || z[e] < mean - ds) - z[e] = mean + sd::DataTypeUtils::min(); - } - } - }; - - samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + static inline void specialOp(Nd4jPointer state, const T *x, + const Nd4jLong *xShapeBuffer, const T *y, + const Nd4jLong *yShapeBuffer, T *z, + const Nd4jLong *zShapeBuffer, + T *extraArguments) { + GaussianDistribution::specialOp(state, x, xShapeBuffer, y, yShapeBuffer, + z, zShapeBuffer, extraArguments); + Nd4jLong zLength = shape::length(zShapeBuffer); + // auto yEWS = shape::elementWiseStride(yShapeBuffer); + // auto zEWS = shape::elementWiseStride(zShapeBuffer); + auto rng = reinterpret_cast(state); + T mean = extraArguments[0]; + T stddev = extraArguments[1]; + T ds = sd::math::nd4j_abs(stddev) * (T)2.0f; + Nd4jLong middle = zLength / 2 + (zLength % 2); + int elementsPerThread = middle / TAD_THRESHOLD; + int _threads = sd::math::nd4j_max(1, elementsPerThread); + _threads = sd::math::nd4j_min( + _threads, sd::Environment::getInstance()->maxThreads()); + + const T epsilon = static_cast(1e-5); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + if (z[e] > mean + ds || z[e] < mean - ds) { + z[e] = step(rng, mean, stddev, e, middle, z[e]); + + if (z[e] > mean + ds || z[e] < mean - ds) + z[e] = mean + sd::DataTypeUtils::min(); } + } }; + samediff::Threads::parallel_for(func, 0, zLength, 1, _threads); + } +}; + ////////////////////////////////////////////////////////////////////// // This Op produces random Log-normal distribution - template - class LogNormalDistribution { - public: - - method_XY - method_X - method_idx - - static const bool requiresSpecial = true; +template +class LogNormalDistribution { + public: + method_XY method_X method_idx + static const bool requiresSpecial = true; #ifdef __CUDACC__ - __device__ static inline void specialOpCuda(Nd4jPointer state, T const* x, Nd4jLong const* xShapeBuffer, T const* y, Nd4jLong const* yShapeBuffer, T *z, Nd4jLong const* zShapeBuffer, T *extraArguments) { - __shared__ T epsilon; - __shared__ T two_pi; - - __shared__ Nd4jLong zLength; - __shared__ Nd4jLong zEWS; - __shared__ Nd4jLong yEWS; - __shared__ T mean; - __shared__ T stddev; - __shared__ int step; - - __shared__ T *tZ; - - __shared__ sd::graph::RandomGenerator* rng; - __shared__ unsigned char *cB; - __shared__ unsigned char *dB; - __shared__ sd::graph::RandomGenerator* devRng; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - rng = reinterpret_cast(state); - cB = shmem; - devRng = reinterpret_cast(state); - - dB = reinterpret_cast (state); - - tZ = reinterpret_cast(shmem + sizeof(sd::graph::RandomGenerator)); - - zLength = shape::length(zShapeBuffer); - zEWS = shape::elementWiseStride(zShapeBuffer); - yEWS = shape::elementWiseStride(yShapeBuffer); - - - epsilon = static_cast(1e-5); - two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); - - mean = extraArguments[0]; - stddev = extraArguments[1]; - - step = (blockDim.x * gridDim.x); - } - __syncthreads(); - - // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); e+= blockDim.x) - cB[e] = dB[e]; - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - int middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1; - - for (Nd4jLong e = tid; e < middle; e += step) { - auto epm = e + middle; - - // we need to get random values - T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); - T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); - - T realMean = y == z ? mean : y[e * yEWS]; - - z[e *zEWS] = sd::math::nd4j_exp((sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * sd::math::nd4j_cos(two_pi * r1)) * stddev + realMean); - - if (epm < zLength) { - realMean = y == z ? mean : y[epm * yEWS]; - z[epm *zEWS] = sd::math::nd4j_exp((sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * sd::math::nd4j_sin(two_pi * r1)) * stddev + realMean); - } - } - } + __device__ static inline void specialOpCuda( + Nd4jPointer state, T const *x, Nd4jLong const *xShapeBuffer, T const *y, + Nd4jLong const *yShapeBuffer, T *z, Nd4jLong const *zShapeBuffer, + T *extraArguments) { + __shared__ T epsilon; + __shared__ T two_pi; + + __shared__ Nd4jLong zLength; + __shared__ Nd4jLong zEWS; + __shared__ Nd4jLong yEWS; + __shared__ T mean; + __shared__ T stddev; + __shared__ int step; + + __shared__ T *tZ; + + __shared__ sd::graph::RandomGenerator *rng; + __shared__ unsigned char *cB; + __shared__ unsigned char *dB; + __shared__ sd::graph::RandomGenerator *devRng; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + rng = reinterpret_cast(state); + cB = shmem; + devRng = reinterpret_cast(state); + + dB = reinterpret_cast(state); + + tZ = reinterpret_cast(shmem + sizeof(sd::graph::RandomGenerator)); + + zLength = shape::length(zShapeBuffer); + zEWS = shape::elementWiseStride(zShapeBuffer); + yEWS = shape::elementWiseStride(yShapeBuffer); + + epsilon = static_cast(1e-5); + two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); + + mean = extraArguments[0]; + stddev = extraArguments[1]; + + step = (blockDim.x * gridDim.x); + } + __syncthreads(); + + // using this loop instead of memcpy + for (int e = threadIdx.x; e < sizeof(sd::graph::RandomGenerator); + e += blockDim.x) + cB[e] = dB[e]; + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + int middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1; + + for (Nd4jLong e = tid; e < middle; e += step) { + auto epm = e + middle; + + // we need to get random values + T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); + T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); + + T realMean = y == z ? mean : y[e * yEWS]; + + z[e * zEWS] = sd::math::nd4j_exp( + (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_cos(two_pi * r1)) * + stddev + + realMean); + + if (epm < zLength) { + realMean = y == z ? mean : y[epm * yEWS]; + z[epm * zEWS] = sd::math::nd4j_exp( + (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_sin(two_pi * r1)) * + stddev + + realMean); + } + } + } #endif - static inline void - specialOp(Nd4jPointer state, const T *x, const Nd4jLong *xShapeBuffer, const T *y, const Nd4jLong *yShapeBuffer, T *z, const Nd4jLong *zShapeBuffer, T *extraArguments) { - const T two_pi = static_cast(2.0f) * static_cast(3.14159265358979323846); - - Nd4jLong zLength = shape::length(zShapeBuffer); - auto yEWS = shape::elementWiseStride(yShapeBuffer); - auto zEWS = shape::elementWiseStride(zShapeBuffer); - - auto middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1; - - int elementsPerThread = middle / TAD_THRESHOLD; - int _threads = sd::math::nd4j_max(1, elementsPerThread); - _threads = sd::math::nd4j_min(_threads, sd::Environment::getInstance()->maxThreads()); - - int span = (zLength / _threads) + 8; - - // we're enforcing even chunks, since it's mandatory for this algorithm - span -= span % 2; - - auto rng = reinterpret_cast(state); - - const T mean = extraArguments[0]; - const T stddev = extraArguments[1]; - const T epsilon = static_cast(1e-5); - - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto e = start; e < stop; e++) { - auto epm = e + middle; - - // we need to get random values - T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); - T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); - - T realMean = y == z ? mean : y[e * yEWS]; - - z[e * zEWS] = sd::math::nd4j_exp((sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * sd::math::nd4j_cos(two_pi * r1)) * stddev + realMean); - - if (epm < zLength) { - realMean = y == z ? mean : y[epm * yEWS]; - z[epm * zEWS] = sd::math::nd4j_exp((sd::math::nd4j_sqrt(static_cast(-2.0f) * sd::math::nd4j_log(r0)) * sd::math::nd4j_sin(two_pi * r1)) * stddev + realMean); - } - } - }; - - samediff::Threads::parallel_for(func, 0, middle, 1, _threads); + static inline void specialOp(Nd4jPointer state, const T *x, + const Nd4jLong *xShapeBuffer, const T *y, + const Nd4jLong *yShapeBuffer, T *z, + const Nd4jLong *zShapeBuffer, + T *extraArguments) { + const T two_pi = + static_cast(2.0f) * static_cast(3.14159265358979323846); + + Nd4jLong zLength = shape::length(zShapeBuffer); + auto yEWS = shape::elementWiseStride(yShapeBuffer); + auto zEWS = shape::elementWiseStride(zShapeBuffer); + + auto middle = zLength % 2 == 0 ? zLength / 2 : zLength / 2 + 1; + + int elementsPerThread = middle / TAD_THRESHOLD; + int _threads = sd::math::nd4j_max(1, elementsPerThread); + _threads = sd::math::nd4j_min( + _threads, sd::Environment::getInstance()->maxThreads()); + + int span = (zLength / _threads) + 8; + + // we're enforcing even chunks, since it's mandatory for this algorithm + span -= span % 2; + + auto rng = reinterpret_cast(state); + + const T mean = extraArguments[0]; + const T stddev = extraArguments[1]; + const T epsilon = static_cast(1e-5); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto e = start; e < stop; e++) { + auto epm = e + middle; + + // we need to get random values + T r0 = rng->relativeT(e, epsilon, static_cast(1.0f)); + T r1 = rng->relativeT(epm, epsilon, static_cast(1.0f)); + + T realMean = y == z ? mean : y[e * yEWS]; + + z[e * zEWS] = sd::math::nd4j_exp( + (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_cos(two_pi * r1)) * + stddev + + realMean); + + if (epm < zLength) { + realMean = y == z ? mean : y[epm * yEWS]; + z[epm * zEWS] = sd::math::nd4j_exp( + (sd::math::nd4j_sqrt(static_cast(-2.0f) * + sd::math::nd4j_log(r0)) * + sd::math::nd4j_sin(two_pi * r1)) * + stddev + + realMean); } + } }; + samediff::Threads::parallel_for(func, 0, middle, 1, _threads); + } +}; -} +} // namespace randomOps -#endif //LIBND4J_SPECIAL_RANDOM_OPS_H +#endif // LIBND4J_SPECIAL_RANDOM_OPS_H diff --git a/libnd4j/include/ops/specials.h b/libnd4j/include/ops/specials.h index 63184206ac10..a4e8fa2d3613 100644 --- a/libnd4j/include/ops/specials.h +++ b/libnd4j/include/ops/specials.h @@ -21,67 +21,84 @@ #ifndef LIBND4J_SPECIALS_H #define LIBND4J_SPECIALS_H - #ifdef __CUDACC__ #define ELEMENT_THRESHOLD 8192 #define TAD_THRESHOLD 2 #endif #include + #include namespace sd { - class NDArray; - - //FIXME: get rid of this redefinition - typedef union - { - float f_; - int i_; - } FloatBits2; - - - class SD_EXPORT SpecialTypeConverter { - public: - template - static void convertGeneric(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - }; - - template - class SD_EXPORT SpecialMethods { - public: - static void concatCpuGeneric(const std::vector& inArrs, NDArray& output, int axis); - static void concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *result, Nd4jLong const* resultShapeInfo); - static void splitCpuGeneric(const NDArray& input, const std::vector& outArrs, int axis); - static void accumulateGeneric(void **x, void *z, const Nd4jLong *zShapeInfo, int n, Nd4jLong length); - static void averageGeneric(void **x, void *z, const Nd4jLong *zShapeInfo, int n, Nd4jLong length, bool propagate); - - static Nd4jLong getPosition(const Nd4jLong *xShapeInfo, Nd4jLong index); - static void quickSort_parallel_internal(T* array, const Nd4jLong *xShapeInfo, int left, int right, int cutoff, bool descending); - static void quickSort_parallel(void* array, const Nd4jLong *xShapeInfo, Nd4jLong lenArray, int numThreads, bool descending); - - static int nextPowerOf2(int number); - static int lastPowerOf2(int number); - - static void sortGeneric(void *x, const Nd4jLong *xShapeInfo, bool descending); - static void sortTadGeneric(void *x, const Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets, bool descending); - - static void decodeBitmapGeneric(const void *dx, Nd4jLong N, void *dz, const Nd4jLong *zShapeInfo); - static Nd4jLong encodeBitmapGeneric(void *dx, const Nd4jLong *zShapeInfo, Nd4jLong N, int *dz, float threshold); - - }; - - template - class SD_EXPORT DoubleMethods{ - public: - static void sortByKey(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, bool descending); - static void sortByValue(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, bool descending); - - - static void sortTadByKey(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int *dimension, int dimensionLength, bool descending); - static void sortTadByValue(void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int *dimension, int dimensionLength, bool descending); - }; -} - - -#endif //LIBND4J_SPECIALS_H +class NDArray; + +// FIXME: get rid of this redefinition +typedef union { + float f_; + int i_; +} FloatBits2; + +class SD_EXPORT SpecialTypeConverter { + public: + template + static void convertGeneric(Nd4jPointer *extras, void *dx, Nd4jLong N, + void *dz); +}; + +template +class SD_EXPORT SpecialMethods { + public: + static void concatCpuGeneric(const std::vector &inArrs, + NDArray &output, int axis); + static void concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, + Nd4jPointer *inputShapeInfo, void *result, + Nd4jLong const *resultShapeInfo); + static void splitCpuGeneric(const NDArray &input, + const std::vector &outArrs, int axis); + static void accumulateGeneric(void **x, void *z, const Nd4jLong *zShapeInfo, + int n, Nd4jLong length); + static void averageGeneric(void **x, void *z, const Nd4jLong *zShapeInfo, + int n, Nd4jLong length, bool propagate); + + static Nd4jLong getPosition(const Nd4jLong *xShapeInfo, Nd4jLong index); + static void quickSort_parallel_internal(T *array, const Nd4jLong *xShapeInfo, + int left, int right, int cutoff, + bool descending); + static void quickSort_parallel(void *array, const Nd4jLong *xShapeInfo, + Nd4jLong lenArray, int numThreads, + bool descending); + + static int nextPowerOf2(int number); + static int lastPowerOf2(int number); + + static void sortGeneric(void *x, const Nd4jLong *xShapeInfo, bool descending); + static void sortTadGeneric(void *x, const Nd4jLong *xShapeInfo, + int *dimension, int dimensionLength, + const Nd4jLong *tadShapeInfo, + const Nd4jLong *tadOffsets, bool descending); + + static void decodeBitmapGeneric(const void *dx, Nd4jLong N, void *dz, + const Nd4jLong *zShapeInfo); + static Nd4jLong encodeBitmapGeneric(void *dx, const Nd4jLong *zShapeInfo, + Nd4jLong N, int *dz, float threshold); +}; + +template +class SD_EXPORT DoubleMethods { + public: + static void sortByKey(void *vx, Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, bool descending); + static void sortByValue(void *vx, Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, bool descending); + + static void sortTadByKey(void *vx, Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int *dimension, + int dimensionLength, bool descending); + static void sortTadByValue(void *vx, Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int *dimension, + int dimensionLength, bool descending); +}; +} // namespace sd + +#endif // LIBND4J_SPECIALS_H diff --git a/libnd4j/include/ops/specials_cuda.h b/libnd4j/include/ops/specials_cuda.h index a12fd302f648..17cf26e30af7 100644 --- a/libnd4j/include/ops/specials_cuda.h +++ b/libnd4j/include/ops/specials_cuda.h @@ -21,85 +21,115 @@ #ifndef PROJECT_SPECIALS_CUDA_H #define PROJECT_SPECIALS_CUDA_H -#include #include +#include #ifdef __CUDACC__ //////////////////////////////////////////////////////////////////////// -template -__host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int j, int k, int length, bool descending); +template +__host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, + void *vx, Nd4jLong const *xShapeInfo, + int j, int k, int length, bool descending); //////////////////////////////////////////////////////////////////////// -template -__host__ void bitonicArbitraryStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int window, int length, int reverse, bool descending); +template +__host__ void bitonicArbitraryStepGeneric(dim3 &launchDims, + cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, + int window, int length, int reverse, + bool descending); //////////////////////////////////////////////////////////////////////// template -__host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int j, int k, int length, bool descending); +__host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream, + void *vx, Nd4jLong const *xShapeInfo, + void *vy, Nd4jLong const *yShapeInfo, + int j, int k, int length, + bool descending); //////////////////////////////////////////////////////////////////////// template -__host__ void bitonicArbitraryStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int window, int length, int reverse, bool descending); +__host__ void bitonicArbitraryStepGenericKey( + dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, Nd4jLong const *yShapeInfo, + int window, int length, int reverse, bool descending); //////////////////////////////////////////////////////////////////////// template -__host__ void bitonicSortStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int j, int k, int length, bool descending); +__host__ void bitonicSortStepGenericValue(dim3 &launchDims, + cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int j, + int k, int length, bool descending); //////////////////////////////////////////////////////////////////////// template -__host__ void bitonicArbitraryStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int window, int length, int reverse, bool descending); - - +__host__ void bitonicArbitraryStepGenericValue( + dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, Nd4jLong const *yShapeInfo, + int window, int length, int reverse, bool descending); //////////////////////////////////////////////////////////////////////// -template -__host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool descending); +template +__host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, int *dimension, + int dimensionLength, Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending); template -__host__ void oesTadGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool descending); +__host__ void oesTadGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, + Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int *dimension, + int dimensionLength, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending); template -__host__ void oesTadGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong const* xShapeInfo, void *vy, Nd4jLong const* yShapeInfo, int *dimension, int dimensionLength, Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets, bool descending); +__host__ void oesTadGenericValue(dim3 &launchDims, cudaStream_t *stream, + void *vx, Nd4jLong const *xShapeInfo, void *vy, + Nd4jLong const *yShapeInfo, int *dimension, + int dimensionLength, + Nd4jLong const *tadShapeInfo, + Nd4jLong const *tadOffsets, bool descending); //////////////////////////////////////////////////////////////////////// -template -__global__ void printCudaGlobal(void* pointer, const int len) { - - for(int i = 0; i < len; ++i) - printf("%f, ", (double)reinterpret_cast(pointer)[i] ); - printf("\n"); +template +__global__ void printCudaGlobal(void *pointer, const int len) { + for (int i = 0; i < len; ++i) + printf("%f, ", (double)reinterpret_cast(pointer)[i]); + printf("\n"); } //////////////////////////////////////////////////////////////////////// -template -__device__ void printCudaDevice(void* pointer, const int len, const int tid = 0) { - - if(blockIdx.x * blockDim.x + threadIdx.x != tid) return; - for(int i = 0; i < len; ++i) - printf("%f, ", (double)reinterpret_cast(pointer)[i] ); - printf("\n"); +template +__device__ void printCudaDevice(void *pointer, const int len, + const int tid = 0) { + if (blockIdx.x * blockDim.x + threadIdx.x != tid) return; + for (int i = 0; i < len; ++i) + printf("%f, ", (double)reinterpret_cast(pointer)[i]); + printf("\n"); } //////////////////////////////////////////////////////////////////////// -template -__host__ void printCudaHost(void* pointer, const int len, cudaStream_t& stream) { - - void* ptr = malloc(sizeof(T)*len); - - cudaMemcpyAsync(ptr, pointer, sizeof(T)*len, cudaMemcpyDeviceToHost, stream); - cudaError_t cudaResult = cudaStreamSynchronize(stream); - if(cudaResult != 0) - throw std::runtime_error("printCudaHost:: cudaStreamSynchronize failed!"); - - for(int i = 0; i < len; ++i) - printf("%f, ", (double)reinterpret_cast(ptr)[i]); - printf("\n"); - - free(ptr); +template +__host__ void printCudaHost(void *pointer, const int len, + cudaStream_t &stream) { + void *ptr = malloc(sizeof(T) * len); + + cudaMemcpyAsync(ptr, pointer, sizeof(T) * len, cudaMemcpyDeviceToHost, + stream); + cudaError_t cudaResult = cudaStreamSynchronize(stream); + if (cudaResult != 0) + throw std::runtime_error("printCudaHost:: cudaStreamSynchronize failed!"); + + for (int i = 0; i < len; ++i) + printf("%f, ", (double)reinterpret_cast(ptr)[i]); + printf("\n"); + + free(ptr); } - #endif -#endif //PROJECT_SPECIALS_CUDA_H +#endif // PROJECT_SPECIALS_CUDA_H diff --git a/libnd4j/include/ops/specials_sparse.h b/libnd4j/include/ops/specials_sparse.h index cd0e2f6b529b..32715509d7ba 100644 --- a/libnd4j/include/ops/specials_sparse.h +++ b/libnd4j/include/ops/specials_sparse.h @@ -26,44 +26,50 @@ #include namespace sd { - namespace sparse { +namespace sparse { - template - class SparseUtils { - public: - /** - * Just simple helper for debugging :) - * - * @param indices - * @param rank - * @param x - */ - static void printIndex(Nd4jLong *indices, int rank, int x); - static bool ltIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y); +template +class SparseUtils { + public: + /** + * Just simple helper for debugging :) + * + * @param indices + * @param rank + * @param x + */ + static void printIndex(Nd4jLong *indices, int rank, int x); + static bool ltIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y); - /** - * Returns true, if x > y, false otherwise - * @param indices - * @param rank - * @param x - * @param y - * @return - */ - static bool gtIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y); + /** + * Returns true, if x > y, false otherwise + * @param indices + * @param rank + * @param x + * @param y + * @return + */ + static bool gtIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y); - static void swapEverything(Nd4jLong *indices, T *array, int rank, Nd4jLong x, Nd4jLong y); + static void swapEverything(Nd4jLong *indices, T *array, int rank, Nd4jLong x, + Nd4jLong y); - static void coo_quickSort_parallel_internal(Nd4jLong *indices, T* array, Nd4jLong left, Nd4jLong right, int cutoff, int rank); + static void coo_quickSort_parallel_internal(Nd4jLong *indices, T *array, + Nd4jLong left, Nd4jLong right, + int cutoff, int rank); - static void coo_quickSort_parallel(Nd4jLong *indices, T* array, Nd4jLong lenArray, int numThreads, int rank); + static void coo_quickSort_parallel(Nd4jLong *indices, T *array, + Nd4jLong lenArray, int numThreads, + int rank); - static Nd4jLong coo_quickSort_findPivot(Nd4jLong *indices, T *array, Nd4jLong left, Nd4jLong right, - int rank); + static Nd4jLong coo_quickSort_findPivot(Nd4jLong *indices, T *array, + Nd4jLong left, Nd4jLong right, + int rank); - static void sortCooIndicesGeneric(Nd4jLong *indices, T *values, Nd4jLong length, int rank); - }; - } -} + static void sortCooIndicesGeneric(Nd4jLong *indices, T *values, + Nd4jLong length, int rank); +}; +} // namespace sparse +} // namespace sd - -#endif //LIBND4J_SPECIALS_SPARSE_H +#endif // LIBND4J_SPECIALS_SPARSE_H diff --git a/libnd4j/include/performance/benchmarking/BenchmarkSuit.h b/libnd4j/include/performance/benchmarking/BenchmarkSuit.h index ab150ddbbb52..1d5de9c518d9 100644 --- a/libnd4j/include/performance/benchmarking/BenchmarkSuit.h +++ b/libnd4j/include/performance/benchmarking/BenchmarkSuit.h @@ -21,21 +21,21 @@ #ifndef LIBND4J_BENCHMARKSUIT_H #define LIBND4J_BENCHMARKSUIT_H -#include -#include -#include -#include #include +#include +#include +#include -namespace sd { - class SD_EXPORT BenchmarkSuit { - public: - BenchmarkSuit() = default; - ~BenchmarkSuit() = default; +#include - virtual std::string runSuit() = 0; - }; -} +namespace sd { +class SD_EXPORT BenchmarkSuit { + public: + BenchmarkSuit() = default; + ~BenchmarkSuit() = default; + virtual std::string runSuit() = 0; +}; +} // namespace sd -#endif //SD_BENCHMARKSUIT_H +#endif // SD_BENCHMARKSUIT_H diff --git a/libnd4j/include/performance/benchmarking/FullBenchmarkSuit.h b/libnd4j/include/performance/benchmarking/FullBenchmarkSuit.h index d5a653649f99..14fe0cf78695 100644 --- a/libnd4j/include/performance/benchmarking/FullBenchmarkSuit.h +++ b/libnd4j/include/performance/benchmarking/FullBenchmarkSuit.h @@ -24,11 +24,10 @@ #include namespace sd { - class FullBenchmarkSuit : public BenchmarkSuit { - public: - std::string runSuit() override; - }; -} +class FullBenchmarkSuit : public BenchmarkSuit { + public: + std::string runSuit() override; +}; +} // namespace sd - -#endif //SD_FULLBENCHMARKSUIT_H +#endif // SD_FULLBENCHMARKSUIT_H diff --git a/libnd4j/include/performance/benchmarking/LightBenchmarkSuit.h b/libnd4j/include/performance/benchmarking/LightBenchmarkSuit.h index 1822a6f98588..846d29759941 100644 --- a/libnd4j/include/performance/benchmarking/LightBenchmarkSuit.h +++ b/libnd4j/include/performance/benchmarking/LightBenchmarkSuit.h @@ -24,11 +24,10 @@ #include namespace sd { - class LightBenchmarkSuit : public BenchmarkSuit { - public: - std::string runSuit() override; - }; -} +class LightBenchmarkSuit : public BenchmarkSuit { + public: + std::string runSuit() override; +}; +} // namespace sd - -#endif //SD_LIGHTBENCHMARKSUIT_H +#endif // SD_LIGHTBENCHMARKSUIT_H diff --git a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp index f1bccad73017..3b0cba7a5251 100644 --- a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp +++ b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp @@ -19,1898 +19,2044 @@ // #include -#include #include +#include + #include #ifdef RELEASE_BUILD - int wIterations = 4; - int rIterations = 20; - int gemmRegularUpperPow = 11; - int scalarBenchmarkPowLimit = 26; - int transformBenchmarkPowLimit = 26; - int intermediateTransformPowLimit = 22; - int intermediateTransformPowLimit2 = 18; - int pairwisePowLimit = 26; - int heavyPowLimit = 22; - int nonEwsPowLimit = 10; - int reduceScalarPowLimit = 26; - int stridedReductionPowLimit = 20; - int mismatchedAssignPowLimit = 26; - int gatherOpPowLimit = 18; - int gatherOpPowLimit2 = 16; - int gatherOpPowLimit3 = 12; - int broadcastMatrixRankLimit = 5; - int limit30 = 30; - int limit26 = 26; - int limit24 = 24; - int limit22 = 22; - int limit20 = 20; - int limit18 = 18; - int limit10 = 10; - int limit5 = 5; - int limit3 = 3; +int wIterations = 4; +int rIterations = 20; +int gemmRegularUpperPow = 11; +int scalarBenchmarkPowLimit = 26; +int transformBenchmarkPowLimit = 26; +int intermediateTransformPowLimit = 22; +int intermediateTransformPowLimit2 = 18; +int pairwisePowLimit = 26; +int heavyPowLimit = 22; +int nonEwsPowLimit = 10; +int reduceScalarPowLimit = 26; +int stridedReductionPowLimit = 20; +int mismatchedAssignPowLimit = 26; +int gatherOpPowLimit = 18; +int gatherOpPowLimit2 = 16; +int gatherOpPowLimit3 = 12; +int broadcastMatrixRankLimit = 5; +int limit30 = 30; +int limit26 = 26; +int limit24 = 24; +int limit22 = 22; +int limit20 = 20; +int limit18 = 18; +int limit10 = 10; +int limit5 = 5; +int limit3 = 3; #else - int wIterations = 0; - int rIterations = 1; - int gemmRegularUpperPow = 7; - int scalarBenchmarkPowLimit = 10; - int transformBenchmarkPowLimit = 10; - int intermediateTransformPowLimit = 10; - int intermediateTransformPowLimit2 = 10; - int pairwisePowLimit = 10; - int heavyPowLimit = 10; - int nonEwsPowLimit = 6; - int reduceScalarPowLimit = 10; - int stridedReductionPowLimit = 12; - int mismatchedAssignPowLimit = 2; - int gatherOpPowLimit = 10; - int gatherOpPowLimit2 = 8; - int gatherOpPowLimit3 = 8; - int broadcastMatrixRankLimit = 3; - int limit26 = 8; - int limit24 = 8; - int limit22 = 8; - int limit20 = 8; - int limit18 = 8; - int limit10 = 4; - int limit5 = 3; - int limit3 = 1; +int wIterations = 0; +int rIterations = 1; +int gemmRegularUpperPow = 7; +int scalarBenchmarkPowLimit = 10; +int transformBenchmarkPowLimit = 10; +int intermediateTransformPowLimit = 10; +int intermediateTransformPowLimit2 = 10; +int pairwisePowLimit = 10; +int heavyPowLimit = 10; +int nonEwsPowLimit = 6; +int reduceScalarPowLimit = 10; +int stridedReductionPowLimit = 12; +int mismatchedAssignPowLimit = 2; +int gatherOpPowLimit = 10; +int gatherOpPowLimit2 = 8; +int gatherOpPowLimit3 = 8; +int broadcastMatrixRankLimit = 3; +int limit26 = 8; +int limit24 = 8; +int limit22 = 8; +int limit20 = 8; +int limit18 = 8; +int limit10 = 4; +int limit5 = 3; +int limit3 = 1; #endif namespace sd { - static std::string layerNormBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); +static std::string layerNormBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); - BoolParameters nhwc("nhwc"); //0 = nchw + BoolParameters nhwc("nhwc"); // 0 = nchw #ifdef _RELEASE - int c = 32; - int hw = 64; + int c = 32; + int hw = 64; #else - int c = 3; - int hw = 8; + int c = 3; + int hw = 8; #endif - ParametersBatch batch({&nhwc}); - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int n = p.getIntParam("nhwc"); - - int axis; - if (n == 0) { - //nchw - auto input = NDArrayFactory::create('c', {16, c, hw, hw}); - auto output = NDArrayFactory::create('c', {16, c, hw, hw}); - ctx->setInputArray(0, input); - ctx->setOutputArray(0, output); - axis = 1; - } else { - auto input = NDArrayFactory::create('c', {32, hw, hw, c}); - auto output = NDArrayFactory::create('c', {32, hw, hw, c}); - ctx->setInputArray(0, input); - ctx->setOutputArray(0, output); - axis = 3; - } - - auto bias = NDArrayFactory::create('c', {c}); - ctx->setInputArray(1, bias); - auto iargs = new Nd4jLong[1]; - iargs[0] = axis; - ctx->setIArguments(iargs, 1); - delete[] iargs; - - return ctx; - }; + ParametersBatch batch({&nhwc}); + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int n = p.getIntParam("nhwc"); + + int axis; + if (n == 0) { + // nchw + auto input = NDArrayFactory::create('c', {16, c, hw, hw}); + auto output = NDArrayFactory::create('c', {16, c, hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + axis = 1; + } else { + auto input = NDArrayFactory::create('c', {32, hw, hw, c}); + auto output = NDArrayFactory::create('c', {32, hw, hw, c}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + axis = 3; + } - sd::ops::layer_norm layerNorm; - DeclarableBenchmark benchmark(layerNorm, "layer norm"); - output += helper.runOperationSuit(&benchmark, generator, batch, "Layer Norm"); + auto bias = NDArrayFactory::create('c', {c}); + ctx->setInputArray(1, bias); + auto iargs = new Nd4jLong[1]; + iargs[0] = axis; + ctx->setIArguments(iargs, 1); + delete[] iargs; - return output; - } + return ctx; + }; + + sd::ops::layer_norm layerNorm; + DeclarableBenchmark benchmark(layerNorm, "layer norm"); + output += helper.runOperationSuit(&benchmark, generator, batch, "Layer Norm"); + return output; +} - static std::string maxPool3DBenchmark(){ - std::string output; - BenchmarkHelper helper(wIterations, rIterations); +static std::string maxPool3DBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); - BoolParameters ncdhw("ncdhw"); //1 = ndhwc - ParametersBatch batch({&ncdhw}); + BoolParameters ncdhw("ncdhw"); // 1 = ndhwc + ParametersBatch batch({&ncdhw}); - sd::ops::maxpool3dnew maxpool3Dnew; - DeclarableBenchmark benchmark(maxpool3Dnew, "maxPool3d"); + sd::ops::maxpool3dnew maxpool3Dnew; + DeclarableBenchmark benchmark(maxpool3Dnew, "maxPool3d"); #ifdef _RELEASE - int mb = 16; - int chIn = 16; - int chOut = 16; - int dhw = 64; + int mb = 16; + int chIn = 16; + int chOut = 16; + int dhw = 64; #else - int mb = 1; - int chIn = 3; - int chOut = 3; - int dhw = 16; + int mb = 1; + int chIn = 3; + int chOut = 3; + int dhw = 16; #endif - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int format = p.getIntParam("ncdhw"); - - //Set inputs and outputs - //Same mode + stride 1: output is same shape as input - if(format == 1) { - //NDHWC - ctx->setInputArray(0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); - ctx->setOutputArray(0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); - } else { - //NCDHW - ctx->setInputArray(0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); - ctx->setOutputArray(0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); - } - - auto iargs = new Nd4jLong[15]; - //Kernel, strides, padding, dilation - x3 each - iargs[0] = 3; //Kernel - iargs[1] = 3; - iargs[2] = 3; - iargs[3] = 1; //Stride - iargs[4] = 1; - iargs[5] = 1; - iargs[6] = 0; //Padding - iargs[7] = 0; - iargs[8] = 0; - iargs[9] = 1; //Dilation - iargs[10] = 1; - iargs[11] = 1; - iargs[12] = 1; //Same mode - iargs[13] = 0; //Unused for max - iargs[14] = format; //0 = ncdhw - ctx->setIArguments(iargs, 14); - delete[] iargs; - - return ctx; - }; - - output += helper.runOperationSuit(&benchmark, generator, batch, "maxPool3d"); - return output; + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int format = p.getIntParam("ncdhw"); + + // Set inputs and outputs + // Same mode + stride 1: output is same shape as input + if (format == 1) { + // NDHWC + ctx->setInputArray( + 0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); + } else { + // NCDHW + ctx->setInputArray( + 0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); } - - static std::string conv3dBenchmark(){ - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - BoolParameters ncdhw("ncdhw"); //1 = ndhwc - ParametersBatch batch({&ncdhw}); - - sd::ops::conv3dnew conv3Dnew; - DeclarableBenchmark benchmark(conv3Dnew, "conv3d"); + auto iargs = new Nd4jLong[15]; + // Kernel, strides, padding, dilation - x3 each + iargs[0] = 3; // Kernel + iargs[1] = 3; + iargs[2] = 3; + iargs[3] = 1; // Stride + iargs[4] = 1; + iargs[5] = 1; + iargs[6] = 0; // Padding + iargs[7] = 0; + iargs[8] = 0; + iargs[9] = 1; // Dilation + iargs[10] = 1; + iargs[11] = 1; + iargs[12] = 1; // Same mode + iargs[13] = 0; // Unused for max + iargs[14] = format; // 0 = ncdhw + ctx->setIArguments(iargs, 14); + delete[] iargs; + + return ctx; + }; + + output += helper.runOperationSuit(&benchmark, generator, batch, "maxPool3d"); + return output; +} + +static std::string conv3dBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + BoolParameters ncdhw("ncdhw"); // 1 = ndhwc + ParametersBatch batch({&ncdhw}); + + sd::ops::conv3dnew conv3Dnew; + DeclarableBenchmark benchmark(conv3Dnew, "conv3d"); #ifdef _RELEASE - int mb = 16; - int chIn = 16; - int chOut = 16; - int dhw = 64; + int mb = 16; + int chIn = 16; + int chOut = 16; + int dhw = 64; #else - int mb = 1; - int chIn = 3; - int chOut = 3; - int dhw = 16; + int mb = 1; + int chIn = 3; + int chOut = 3; + int dhw = 16; #endif - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int format = p.getIntParam("ncdhw"); - - //Set inputs and outputs - //Same mode + stride 1: output is same shape as input - if(format == 1) { - //NDHWC - ctx->setInputArray(0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); - ctx->setOutputArray(0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); - } else { - //NCDHW - ctx->setInputArray(0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); - ctx->setOutputArray(0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); - } - - //Weights and bias: - ctx->setInputArray(1, NDArrayFactory::create('c', {3, 3, 3, chIn, chOut})); - ctx->setInputArray(2, NDArrayFactory::create('c', {chOut})); - - - auto iargs = new Nd4jLong[14]; - //Kernel, strides, padding, dilation - x3 each - iargs[0] = 3; //Kernel - iargs[1] = 3; - iargs[2] = 3; - iargs[3] = 1; //Stride - iargs[4] = 1; - iargs[5] = 1; - iargs[6] = 0; //Padding - iargs[7] = 0; - iargs[8] = 0; - iargs[9] = 1; //Dilation - iargs[10] = 1; - iargs[11] = 1; - iargs[12] = 1; //Same mode - iargs[13] = format; //0 = ncdhw - ctx->setIArguments(iargs, 14); - delete[] iargs; - - return ctx; - }; - - output += helper.runOperationSuit(&benchmark, generator, batch, "CNN3D"); - return output; + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int format = p.getIntParam("ncdhw"); + + // Set inputs and outputs + // Same mode + stride 1: output is same shape as input + if (format == 1) { + // NDHWC + ctx->setInputArray( + 0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {mb, dhw, dhw, dhw, chIn})); + } else { + // NCDHW + ctx->setInputArray( + 0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {mb, chIn, dhw, dhw, dhw})); } - - static std::string lstmBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - BoolParameters format("format"); //0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen] + // Weights and bias: + ctx->setInputArray( + 1, NDArrayFactory::create('c', {3, 3, 3, chIn, chOut})); + ctx->setInputArray(2, NDArrayFactory::create('c', {chOut})); + + auto iargs = new Nd4jLong[14]; + // Kernel, strides, padding, dilation - x3 each + iargs[0] = 3; // Kernel + iargs[1] = 3; + iargs[2] = 3; + iargs[3] = 1; // Stride + iargs[4] = 1; + iargs[5] = 1; + iargs[6] = 0; // Padding + iargs[7] = 0; + iargs[8] = 0; + iargs[9] = 1; // Dilation + iargs[10] = 1; + iargs[11] = 1; + iargs[12] = 1; // Same mode + iargs[13] = format; // 0 = ncdhw + ctx->setIArguments(iargs, 14); + delete[] iargs; + + return ctx; + }; + + output += helper.runOperationSuit(&benchmark, generator, batch, "CNN3D"); + return output; +} + +static std::string lstmBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + BoolParameters format( + "format"); // 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen] #ifdef _RELEASE - PredefinedParameters mb("mb", {1, 8, 64}); - PredefinedParameters nInOut("nInOut", {32, 256, 1024}); + PredefinedParameters mb("mb", {1, 8, 64}); + PredefinedParameters nInOut("nInOut", {32, 256, 1024}); #else - PredefinedParameters mb("mb", {1}); - PredefinedParameters nInOut("nInOut", {32}); + PredefinedParameters mb("mb", {1}); + PredefinedParameters nInOut("nInOut", {32}); #endif - ParametersBatch batch({&format, &mb, &nInOut}); - sd::ops::lstmBlock lstmBlock; - DeclarableBenchmark benchmark(lstmBlock, "lstm"); - - int seqLength = 32; - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int f = p.getIntParam("format"); - int m = p.getIntParam("mb"); - int n = p.getIntParam("nInOut"); - - Nd4jLong l = 0; - ctx->setInputArray(0, NDArrayFactory::create(l)); //Max TS length (unused) - - - if (f == 0) { - //TNS format - ctx->setInputArray(1, NDArrayFactory::create('c', {seqLength, m, n})); //x - ctx->setOutputArray(0, NDArrayFactory::create('c', {seqLength, m, n})); //i - ctx->setOutputArray(1, NDArrayFactory::create('c', {seqLength, m, n})); //c - ctx->setOutputArray(2, NDArrayFactory::create('c', {seqLength, m, n})); //f - ctx->setOutputArray(3, NDArrayFactory::create('c', {seqLength, m, n})); //o - ctx->setOutputArray(4, NDArrayFactory::create('c', {seqLength, m, n})); //z - ctx->setOutputArray(5, NDArrayFactory::create('c', {seqLength, m, n})); //h - ctx->setOutputArray(6, NDArrayFactory::create('c', {seqLength, m, n})); //y - } else { - //NST format - ctx->setInputArray(1, NDArrayFactory::create('f', {m, n, seqLength})); //x - ctx->setOutputArray(0, NDArrayFactory::create('f', {m, n, seqLength})); //i - ctx->setOutputArray(1, NDArrayFactory::create('f', {m, n, seqLength})); //c - ctx->setOutputArray(2, NDArrayFactory::create('f', {m, n, seqLength})); //f - ctx->setOutputArray(3, NDArrayFactory::create('f', {m, n, seqLength})); //o - ctx->setOutputArray(4, NDArrayFactory::create('f', {m, n, seqLength})); //z - ctx->setOutputArray(5, NDArrayFactory::create('f', {m, n, seqLength})); //h - ctx->setOutputArray(6, NDArrayFactory::create('f', {m, n, seqLength})); //y - } - - auto cLast = NDArrayFactory::create('c', {m, n}); - auto yLast = NDArrayFactory::create('c', {m, n}); - auto W = NDArrayFactory::create('c', {2 * n, 4 * n}); - auto Wci = NDArrayFactory::create('c', {n}); - auto Wcf = NDArrayFactory::create('c', {n}); - auto Wco = NDArrayFactory::create('c', {n}); - auto b = NDArrayFactory::create('c', {4 * n}); - - ctx->setInputArray(2, cLast); - ctx->setInputArray(3, yLast); - ctx->setInputArray(4, W); - ctx->setInputArray(5, Wci); - ctx->setInputArray(6, Wcf); - ctx->setInputArray(7, Wco); - ctx->setInputArray(8, b); - - auto iargs = new Nd4jLong[2]; - iargs[0] = 0; //No peephole - iargs[1] = f; - ctx->setIArguments(iargs, 2); - delete[] iargs; - - auto targs = new double[2]; - targs[0] = 1.0; //forget bias - targs[1] = 0.0; //cell clipping value - ctx->setTArguments(targs, 2); - delete[] targs; - return ctx; - }; - - output += helper.runOperationSuit(&benchmark, generator, batch, "LSTMBlock"); - return output; + ParametersBatch batch({&format, &mb, &nInOut}); + sd::ops::lstmBlock lstmBlock; + DeclarableBenchmark benchmark(lstmBlock, "lstm"); + + int seqLength = 32; + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int f = p.getIntParam("format"); + int m = p.getIntParam("mb"); + int n = p.getIntParam("nInOut"); + + Nd4jLong l = 0; + ctx->setInputArray( + 0, NDArrayFactory::create(l)); // Max TS length (unused) + + if (f == 0) { + // TNS format + ctx->setInputArray( + 1, NDArrayFactory::create('c', {seqLength, m, n})); // x + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {seqLength, m, n})); // i + ctx->setOutputArray( + 1, NDArrayFactory::create('c', {seqLength, m, n})); // c + ctx->setOutputArray( + 2, NDArrayFactory::create('c', {seqLength, m, n})); // f + ctx->setOutputArray( + 3, NDArrayFactory::create('c', {seqLength, m, n})); // o + ctx->setOutputArray( + 4, NDArrayFactory::create('c', {seqLength, m, n})); // z + ctx->setOutputArray( + 5, NDArrayFactory::create('c', {seqLength, m, n})); // h + ctx->setOutputArray( + 6, NDArrayFactory::create('c', {seqLength, m, n})); // y + } else { + // NST format + ctx->setInputArray( + 1, NDArrayFactory::create('f', {m, n, seqLength})); // x + ctx->setOutputArray( + 0, NDArrayFactory::create('f', {m, n, seqLength})); // i + ctx->setOutputArray( + 1, NDArrayFactory::create('f', {m, n, seqLength})); // c + ctx->setOutputArray( + 2, NDArrayFactory::create('f', {m, n, seqLength})); // f + ctx->setOutputArray( + 3, NDArrayFactory::create('f', {m, n, seqLength})); // o + ctx->setOutputArray( + 4, NDArrayFactory::create('f', {m, n, seqLength})); // z + ctx->setOutputArray( + 5, NDArrayFactory::create('f', {m, n, seqLength})); // h + ctx->setOutputArray( + 6, NDArrayFactory::create('f', {m, n, seqLength})); // y } - static std::string batchnormBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - //Convolution2D op - BoolParameters nhwc("nhwc"); + auto cLast = NDArrayFactory::create('c', {m, n}); + auto yLast = NDArrayFactory::create('c', {m, n}); + auto W = NDArrayFactory::create('c', {2 * n, 4 * n}); + auto Wci = NDArrayFactory::create('c', {n}); + auto Wcf = NDArrayFactory::create('c', {n}); + auto Wco = NDArrayFactory::create('c', {n}); + auto b = NDArrayFactory::create('c', {4 * n}); + + ctx->setInputArray(2, cLast); + ctx->setInputArray(3, yLast); + ctx->setInputArray(4, W); + ctx->setInputArray(5, Wci); + ctx->setInputArray(6, Wcf); + ctx->setInputArray(7, Wco); + ctx->setInputArray(8, b); + + auto iargs = new Nd4jLong[2]; + iargs[0] = 0; // No peephole + iargs[1] = f; + ctx->setIArguments(iargs, 2); + delete[] iargs; + + auto targs = new double[2]; + targs[0] = 1.0; // forget bias + targs[1] = 0.0; // cell clipping value + ctx->setTArguments(targs, 2); + delete[] targs; + return ctx; + }; + + output += helper.runOperationSuit(&benchmark, generator, batch, "LSTMBlock"); + return output; +} + +static std::string batchnormBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + // Convolution2D op + BoolParameters nhwc("nhwc"); #ifdef _RELEASE - PredefinedParameters c("c", {3, 32, 128}); - PredefinedParameters hw("hw", {32, 128}); + PredefinedParameters c("c", {3, 32, 128}); + PredefinedParameters hw("hw", {32, 128}); #else - PredefinedParameters c("c", {3}); - PredefinedParameters hw("hw", {16}); + PredefinedParameters c("c", {3}); + PredefinedParameters hw("hw", {16}); #endif - ParametersBatch batch({&nhwc, &c, &hw}); - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int n = p.getIntParam("nhwc"); - int hw = p.getIntParam("hw"); - int ch = p.getIntParam("c"); - - auto args = new Nd4jLong[3]; - args[0] = args[1] = 1; //apply scale and offset - if (n == 0) { - auto input = NDArrayFactory::create('c', {32, ch, hw, hw}); - auto output = NDArrayFactory::create('c', {32, ch, hw, hw}); - ctx->setInputArray(0, input); - ctx->setOutputArray(0, output); - args[2] = 1; //axis - } else { - auto input = NDArrayFactory::create('c', {32, hw, hw, ch}); - auto output = NDArrayFactory::create('c', {32, hw, hw, ch}); - ctx->setInputArray(0, input); - ctx->setOutputArray(0, output); - args[2] = 3; //axis - } - ctx->setIArguments(args, 3); - delete[] args; - - ctx->setInputArray(1, NDArrayFactory::create('c', {ch})); //mean - auto v = NDArrayFactory::create('c', {ch}); - v.assign(1.0f); - ctx->setInputArray(2, v); //variance - auto g = NDArrayFactory::create('c', {ch}); - g.assign(1.0); - ctx->setInputArray(3, g); //gamma - auto b = NDArrayFactory::create('c', {ch}); - b.assign(1.0); - ctx->setInputArray(4, b); //beta - - auto targs = new double[1]; - targs[0] = 1e-5; - ctx->setTArguments(targs, 1); - delete[] targs; - - return ctx; - }; - - sd::ops::batchnorm batchnorm; - DeclarableBenchmark benchmark(batchnorm, "batchnorm"); - output += helper.runOperationSuit(&benchmark, generator, batch, "Batch Normalization"); - - return output; + ParametersBatch batch({&nhwc, &c, &hw}); + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int n = p.getIntParam("nhwc"); + int hw = p.getIntParam("hw"); + int ch = p.getIntParam("c"); + + auto args = new Nd4jLong[3]; + args[0] = args[1] = 1; // apply scale and offset + if (n == 0) { + auto input = NDArrayFactory::create('c', {32, ch, hw, hw}); + auto output = NDArrayFactory::create('c', {32, ch, hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + args[2] = 1; // axis + } else { + auto input = NDArrayFactory::create('c', {32, hw, hw, ch}); + auto output = NDArrayFactory::create('c', {32, hw, hw, ch}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + args[2] = 3; // axis } - - static std::string pool2dBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - //Convolution2D op - BoolParameters nhwc("nhwc"); + ctx->setIArguments(args, 3); + delete[] args; + + ctx->setInputArray(1, NDArrayFactory::create('c', {ch})); // mean + auto v = NDArrayFactory::create('c', {ch}); + v.assign(1.0f); + ctx->setInputArray(2, v); // variance + auto g = NDArrayFactory::create('c', {ch}); + g.assign(1.0); + ctx->setInputArray(3, g); // gamma + auto b = NDArrayFactory::create('c', {ch}); + b.assign(1.0); + ctx->setInputArray(4, b); // beta + + auto targs = new double[1]; + targs[0] = 1e-5; + ctx->setTArguments(targs, 1); + delete[] targs; + + return ctx; + }; + + sd::ops::batchnorm batchnorm; + DeclarableBenchmark benchmark(batchnorm, "batchnorm"); + output += helper.runOperationSuit(&benchmark, generator, batch, + "Batch Normalization"); + + return output; +} + +static std::string pool2dBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + // Convolution2D op + BoolParameters nhwc("nhwc"); #ifdef _RELEASE - PredefinedParameters k("k", {2, 3, 5}); - PredefinedParameters c("c", {3, 32, 128}); - PredefinedParameters hw("hw", {32, 128}); + PredefinedParameters k("k", {2, 3, 5}); + PredefinedParameters c("c", {3, 32, 128}); + PredefinedParameters hw("hw", {32, 128}); #else - PredefinedParameters k("k", {2}); - PredefinedParameters c("c", {3}); - PredefinedParameters hw("hw", {8}); + PredefinedParameters k("k", {2}); + PredefinedParameters c("c", {3}); + PredefinedParameters hw("hw", {8}); #endif - ParametersBatch batch({&nhwc, &k, &c, &hw}); - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int n = p.getIntParam("nhwc"); - int hw = p.getIntParam("hw"); - int khw = p.getIntParam("k"); - - if (n == 0) { - auto input = NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); - auto output = NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); - ctx->setInputArray(0, input); - ctx->setOutputArray(0, output); - } else { - auto input = NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); - auto output = NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); - ctx->setInputArray(0, input); - ctx->setOutputArray(0, output); - } - - auto args = new Nd4jLong[11]; - args[0] = args[1] = khw; //Kernel - args[2] = args[3] = 1;//Stride - args[4] = args[5] = 0; //Pad - args[6] = args[7] = 1; //Dilation - args[8] = 1; //SAME - args[9] = 0; //Divisor mode - 0 = exclude padding in divisor - args[10] = n;//0-nchw, 1=nhwc - ctx->setIArguments(args, 11); - delete[] args; - - return ctx; - }; - - sd::ops::avgpool2d avgpool2d; - DeclarableBenchmark benchmark1(avgpool2d, "avgpool"); - output += helper.runOperationSuit(&benchmark1, generator, batch, "Average Pooling 2d Operation"); - - sd::ops::maxpool2d maxpool2d; - DeclarableBenchmark benchmark2(maxpool2d, "maxpool"); - output += helper.runOperationSuit(&benchmark2, generator, batch, "Max Pooling 2d Operation"); - return output; + ParametersBatch batch({&nhwc, &k, &c, &hw}); + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int n = p.getIntParam("nhwc"); + int hw = p.getIntParam("hw"); + int khw = p.getIntParam("k"); + + if (n == 0) { + auto input = + NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); + auto output = + NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + } else { + auto input = + NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); + auto output = + NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } - static std::string conv2dBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - //Convolution2D op - BoolParameters nhwc("nhwc"); + auto args = new Nd4jLong[11]; + args[0] = args[1] = khw; // Kernel + args[2] = args[3] = 1; // Stride + args[4] = args[5] = 0; // Pad + args[6] = args[7] = 1; // Dilation + args[8] = 1; // SAME + args[9] = 0; // Divisor mode - 0 = exclude padding in divisor + args[10] = n; // 0-nchw, 1=nhwc + ctx->setIArguments(args, 11); + delete[] args; + + return ctx; + }; + + sd::ops::avgpool2d avgpool2d; + DeclarableBenchmark benchmark1(avgpool2d, "avgpool"); + output += helper.runOperationSuit(&benchmark1, generator, batch, + "Average Pooling 2d Operation"); + + sd::ops::maxpool2d maxpool2d; + DeclarableBenchmark benchmark2(maxpool2d, "maxpool"); + output += helper.runOperationSuit(&benchmark2, generator, batch, + "Max Pooling 2d Operation"); + return output; +} + +static std::string conv2dBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + // Convolution2D op + BoolParameters nhwc("nhwc"); #ifdef _RELEASE - PredefinedParameters k("k", {2, 3, 5}); - PredefinedParameters c("c", {3, 32, 128}); - PredefinedParameters hw("hw", {32, 128}); + PredefinedParameters k("k", {2, 3, 5}); + PredefinedParameters c("c", {3, 32, 128}); + PredefinedParameters hw("hw", {32, 128}); #else - PredefinedParameters k("k", {2}); - PredefinedParameters c("c", {3}); - PredefinedParameters hw("hw", {8}); + PredefinedParameters k("k", {2}); + PredefinedParameters c("c", {3}); + PredefinedParameters hw("hw", {8}); #endif - ParametersBatch batch({&nhwc, &k, &c, &hw}); - sd::ops::conv2d conv2d; - DeclarableBenchmark benchmark(conv2d, "conv2d"); - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int n = p.getIntParam("nhwc"); - int hw = p.getIntParam("hw"); - int khw = p.getIntParam("k"); - - if (n == 0) { - auto input = NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); - auto output = NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); - ctx->setInputArray(0, input); - ctx->setOutputArray(0, output); - } else { - auto input = NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); - auto output = NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); - ctx->setInputArray(0, input); - ctx->setOutputArray(0, output); - } - - auto b = NDArrayFactory::create('c', {p.getIntParam("c")}); - auto w = NDArrayFactory::create('c', {khw, khw, p.getIntParam("c"), p.getIntParam("c")}); // [kH, kW, iC, oC] always - - ctx->setInputArray(1, w); - ctx->setInputArray(2, b); - - auto args = new Nd4jLong[10]; - args[0] = args[1] = khw; //Kernel - args[2] = args[3] = 1;//Stride - args[4] = args[5] = 0; //Pad - args[6] = args[7] = 1; //Dilation - args[8] = 1; //SAME - args[9] = n;//0-nchw, 1=nhwc - ctx->setIArguments(args, 10); - delete[] args; - - return ctx; - }; - - output += helper.runOperationSuit(&benchmark, generator, batch, "Conv2d Operation"); - return output; - } - - static std::string rngBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - //Uniform, gaussian and bernoulli RNG generation - - IntPowerParameters length("length", 2, 4, scalarBenchmarkPowLimit, 3); //2^8 to 2^30 in steps of 3 - - ParametersBatch batch({&length}); - - auto gen01 = PARAMETRIC_D() { - auto ctx = new Context(1); - ctx->setInputArray(0, NDArrayFactory::create('c', {2},{1, p.getIntParam("length")})); //Shape as NDArray - ctx->setOutputArray(0, NDArrayFactory::create('c', {1, p.getIntParam("length")})); - auto d = new double[2]; - d[0] = 0.0; - d[1] = 1.0; - ctx->setTArguments(d, 2); - delete[] d; - return ctx; - }; - - auto gen05 = PARAMETRIC_D() { - auto ctx = new Context(1); - ctx->setInputArray(0, NDArrayFactory::create('c', {2},{1, p.getIntParam("length")})); //Shape as NDArray - ctx->setOutputArray(0, NDArrayFactory::create('c', {1, p.getIntParam("length")})); - auto d = new double[1]; - d[0] = 0.5; - ctx->setTArguments(d, 1); - delete[] d; - return ctx; - }; - - sd::ops::LegacyRandomOp unif(random::UniformDistribution); - DeclarableBenchmark dbU(unif, "uniform"); - output += helper.runOperationSuit(&dbU, gen01, batch, "Uniform Distribution"); - - sd::ops::LegacyRandomOp gaussian(random::GaussianDistribution); - DeclarableBenchmark dbG(gaussian, "gaussian"); - output += helper.runOperationSuit(&dbG, gen01, batch, "Gaussian Distribution"); - - sd::ops::LegacyRandomOp trunc(random::TruncatedNormalDistribution); - DeclarableBenchmark dbTU(unif, "trunc.norm"); - output += helper.runOperationSuit(&dbTU, gen01, batch, "Truncated Normal Distribution"); - - sd::ops::LegacyRandomOp ln(random::LogNormalDistribution); - DeclarableBenchmark dbLN(ln, "uniform"); - output += helper.runOperationSuit(&dbLN, gen01, batch, "Log Normal Distribution"); - - sd::ops::LegacyRandomOp bernoulli(random::BernoulliDistribution); - DeclarableBenchmark dbB(bernoulli, "bernoulli"); - output += helper.runOperationSuit(&dbB, gen05, batch, "Bernoulli Distribution"); - - sd::ops::LegacyRandomOp dropout(random::BernoulliDistribution); - DeclarableBenchmark dbD(dropout, "dropout"); - output += helper.runOperationSuit(&dbD, gen05, batch, "Dropout"); - - return output; + ParametersBatch batch({&nhwc, &k, &c, &hw}); + sd::ops::conv2d conv2d; + DeclarableBenchmark benchmark(conv2d, "conv2d"); + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int n = p.getIntParam("nhwc"); + int hw = p.getIntParam("hw"); + int khw = p.getIntParam("k"); + + if (n == 0) { + auto input = + NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); + auto output = + NDArrayFactory::create('c', {32, p.getIntParam("c"), hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + } else { + auto input = + NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); + auto output = + NDArrayFactory::create('c', {32, hw, hw, p.getIntParam("c")}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } - static std::string gemmIrregularBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - //Basically the same as above, but with irregular shapes (not multiples of 8, etc) + auto b = NDArrayFactory::create('c', {p.getIntParam("c")}); + auto w = NDArrayFactory::create( + 'c', {khw, khw, p.getIntParam("c"), + p.getIntParam("c")}); // [kH, kW, iC, oC] always + + ctx->setInputArray(1, w); + ctx->setInputArray(2, b); + + auto args = new Nd4jLong[10]; + args[0] = args[1] = khw; // Kernel + args[2] = args[3] = 1; // Stride + args[4] = args[5] = 0; // Pad + args[6] = args[7] = 1; // Dilation + args[8] = 1; // SAME + args[9] = n; // 0-nchw, 1=nhwc + ctx->setIArguments(args, 10); + delete[] args; + + return ctx; + }; + + output += + helper.runOperationSuit(&benchmark, generator, batch, "Conv2d Operation"); + return output; +} + +static std::string rngBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + // Uniform, gaussian and bernoulli RNG generation + + IntPowerParameters length("length", 2, 4, scalarBenchmarkPowLimit, + 3); // 2^8 to 2^30 in steps of 3 + + ParametersBatch batch({&length}); + + auto gen01 = PARAMETRIC_D() { + auto ctx = new Context(1); + ctx->setInputArray( + 0, NDArrayFactory::create( + 'c', {2}, {1, p.getIntParam("length")})); // Shape as NDArray + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {1, p.getIntParam("length")})); + auto d = new double[2]; + d[0] = 0.0; + d[1] = 1.0; + ctx->setTArguments(d, 2); + delete[] d; + return ctx; + }; + + auto gen05 = PARAMETRIC_D() { + auto ctx = new Context(1); + ctx->setInputArray( + 0, NDArrayFactory::create( + 'c', {2}, {1, p.getIntParam("length")})); // Shape as NDArray + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {1, p.getIntParam("length")})); + auto d = new double[1]; + d[0] = 0.5; + ctx->setTArguments(d, 1); + delete[] d; + return ctx; + }; + + sd::ops::LegacyRandomOp unif(random::UniformDistribution); + DeclarableBenchmark dbU(unif, "uniform"); + output += helper.runOperationSuit(&dbU, gen01, batch, "Uniform Distribution"); + + sd::ops::LegacyRandomOp gaussian(random::GaussianDistribution); + DeclarableBenchmark dbG(gaussian, "gaussian"); + output += + helper.runOperationSuit(&dbG, gen01, batch, "Gaussian Distribution"); + + sd::ops::LegacyRandomOp trunc(random::TruncatedNormalDistribution); + DeclarableBenchmark dbTU(unif, "trunc.norm"); + output += helper.runOperationSuit(&dbTU, gen01, batch, + "Truncated Normal Distribution"); + + sd::ops::LegacyRandomOp ln(random::LogNormalDistribution); + DeclarableBenchmark dbLN(ln, "uniform"); + output += + helper.runOperationSuit(&dbLN, gen01, batch, "Log Normal Distribution"); + + sd::ops::LegacyRandomOp bernoulli(random::BernoulliDistribution); + DeclarableBenchmark dbB(bernoulli, "bernoulli"); + output += + helper.runOperationSuit(&dbB, gen05, batch, "Bernoulli Distribution"); + + sd::ops::LegacyRandomOp dropout(random::BernoulliDistribution); + DeclarableBenchmark dbD(dropout, "dropout"); + output += helper.runOperationSuit(&dbD, gen05, batch, "Dropout"); + + return output; +} + +static std::string gemmIrregularBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + // Basically the same as above, but with irregular shapes (not multiples of 8, + // etc) #ifdef _RELEASE - int tAMax = 1; - int tBMax = 1; - int b = 1024; - int c = 1024; + int tAMax = 1; + int tBMax = 1; + int b = 1024; + int c = 1024; #else - int tAMax = 1; - int tBMax = 1; - int b = 32; - int c = 32; + int tAMax = 1; + int tBMax = 1; + int b = 32; + int c = 32; #endif - for (int tA = 0; tA <= tAMax; tA++) { - for (int tB = 0; tB <= tBMax; tB++) { - IntParameters d("d", 1020, 1028, 1); //1020, 1021, ..., 1028 - ParametersBatch dim({&d}); - - //Vary A.rows: - auto generator = PARAMETRIC_XYZ() { - auto a = p.getIntParam("d"); - std::vector shapeA; - std::vector shapeB; - if (tA) { - shapeA = {b, a}; - } else { - shapeA = {a, b}; - } - if (tB) { - shapeB = {c, b}; - } else { - shapeB = {b, c}; - } - auto A = NDArrayFactory::create('c', shapeA); - auto B = NDArrayFactory::create('c', shapeB); - auto C = NDArrayFactory::create('f', {a, c}); - - x.push_back(A); - y.push_back(B); - z.push_back(C); - }; - - std::string n; - n += "Gemm (a.rows) - tA="; - n += std::to_string(tA); - n += ", tB="; - n += std::to_string(tB); - - MatrixBenchmark mb(1.0, 0.0, tA, tB, n); - - output += helper.runOperationSuit(&mb, generator, dim, n.c_str()); - - //Vary A.columns / B.rows - auto generator2 = PARAMETRIC_XYZ() { - auto a = 1024; - auto b = p.getIntParam("d"); - auto c = 1024; - std::vector shapeA; - std::vector shapeB; - if (tA) { - shapeA = {b, a}; - } else { - shapeA = {a, b}; - } - if (tB) { - shapeB = {c, b}; - } else { - shapeB = {b, c}; - } - auto A = NDArrayFactory::create('c', shapeA); - auto B = NDArrayFactory::create('c', shapeB); - auto C = NDArrayFactory::create('f', {a, c}); - - x.push_back(A); - y.push_back(B); - z.push_back(C); - }; - - std::string n2; - n2 += "Gemm (a.columns) - tA="; - n2 += std::to_string(tA); - n2 += ", tB="; - n2 += std::to_string(tB); - - MatrixBenchmark mb2(1.0, 0.0, tA, tB, n2); - - output += helper.runOperationSuit(&mb2, generator2, dim, n2.c_str()); - - //Vary A.columns / B.rows - auto generator3 = PARAMETRIC_XYZ() { - auto a = 1024; - auto b = 1024; - auto c = p.getIntParam("d"); - std::vector shapeA; - std::vector shapeB; - if (tA) { - shapeA = {b, a}; - } else { - shapeA = {a, b}; - } - if (tB) { - shapeB = {c, b}; - } else { - shapeB = {b, c}; - } - auto A = NDArrayFactory::create('c', shapeA); - auto B = NDArrayFactory::create('c', shapeB); - auto C = NDArrayFactory::create('f', {a, c}); - - x.push_back(A); - y.push_back(B); - z.push_back(C); - }; - - std::string n3; - n3 += "Gemm (b.columns) - tA="; - n3 += std::to_string(tA); - n3 += ", tB="; - n3 += std::to_string(tB); - - MatrixBenchmark mb3(1.0, 0.0, tA, tB, n); - - output += helper.runOperationSuit(&mb3, generator3, dim, n3.c_str()); - } + for (int tA = 0; tA <= tAMax; tA++) { + for (int tB = 0; tB <= tBMax; tB++) { + IntParameters d("d", 1020, 1028, 1); // 1020, 1021, ..., 1028 + ParametersBatch dim({&d}); + + // Vary A.rows: + auto generator = PARAMETRIC_XYZ() { + auto a = p.getIntParam("d"); + std::vector shapeA; + std::vector shapeB; + if (tA) { + shapeA = {b, a}; + } else { + shapeA = {a, b}; } + if (tB) { + shapeB = {c, b}; + } else { + shapeB = {b, c}; + } + auto A = NDArrayFactory::create('c', shapeA); + auto B = NDArrayFactory::create('c', shapeB); + auto C = NDArrayFactory::create('f', {a, c}); + + x.push_back(A); + y.push_back(B); + z.push_back(C); + }; + + std::string n; + n += "Gemm (a.rows) - tA="; + n += std::to_string(tA); + n += ", tB="; + n += std::to_string(tB); + + MatrixBenchmark mb(1.0, 0.0, tA, tB, n); + + output += helper.runOperationSuit(&mb, generator, dim, n.c_str()); + + // Vary A.columns / B.rows + auto generator2 = PARAMETRIC_XYZ() { + auto a = 1024; + auto b = p.getIntParam("d"); + auto c = 1024; + std::vector shapeA; + std::vector shapeB; + if (tA) { + shapeA = {b, a}; + } else { + shapeA = {a, b}; + } + if (tB) { + shapeB = {c, b}; + } else { + shapeB = {b, c}; + } + auto A = NDArrayFactory::create('c', shapeA); + auto B = NDArrayFactory::create('c', shapeB); + auto C = NDArrayFactory::create('f', {a, c}); + + x.push_back(A); + y.push_back(B); + z.push_back(C); + }; + + std::string n2; + n2 += "Gemm (a.columns) - tA="; + n2 += std::to_string(tA); + n2 += ", tB="; + n2 += std::to_string(tB); + + MatrixBenchmark mb2(1.0, 0.0, tA, tB, n2); + + output += helper.runOperationSuit(&mb2, generator2, dim, n2.c_str()); + + // Vary A.columns / B.rows + auto generator3 = PARAMETRIC_XYZ() { + auto a = 1024; + auto b = 1024; + auto c = p.getIntParam("d"); + std::vector shapeA; + std::vector shapeB; + if (tA) { + shapeA = {b, a}; + } else { + shapeA = {a, b}; + } + if (tB) { + shapeB = {c, b}; + } else { + shapeB = {b, c}; + } + auto A = NDArrayFactory::create('c', shapeA); + auto B = NDArrayFactory::create('c', shapeB); + auto C = NDArrayFactory::create('f', {a, c}); - return output; - } - - static std::string batchGemmBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - //Rank 3 - [32,1024,1024]x[32,1024,1024] - //Rank 4 - [4,8,1024,1024]x[4,8,1024,1024] - - IntParameters rank("rank", 3, 4, 1); - - ParametersBatch b({&rank}); - - auto generator = PARAMETRIC_D() { - auto rank = p.getIntParam("rank"); - std::vector shapeA; - std::vector shapeB; - auto ctx = new Context(1); - - if(rank == 3){ - ctx->setInputArray(0, NDArrayFactory::create('c', {32, 1024, 1024})); - ctx->setInputArray(1, NDArrayFactory::create('c', {32, 1024, 1024})); - ctx->setOutputArray(0, NDArrayFactory::create('c', {32, 1024, 1024})); - } else { - ctx->setInputArray(0, NDArrayFactory::create('c', {4, 8, 1024, 1024})); - ctx->setInputArray(1, NDArrayFactory::create('c', {4, 8, 1024, 1024})); - ctx->setOutputArray(0, NDArrayFactory::create('c', {4, 8, 1024, 1024})); - } + x.push_back(A); + y.push_back(B); + z.push_back(C); + }; - return ctx; - }; + std::string n3; + n3 += "Gemm (b.columns) - tA="; + n3 += std::to_string(tA); + n3 += ", tB="; + n3 += std::to_string(tB); - sd::ops::matmul mmul; - DeclarableBenchmark benchmark(mmul, "mmul (batch)"); - output += helper.runOperationSuit(&benchmark, generator, b, "MMul (batch)"); + MatrixBenchmark mb3(1.0, 0.0, tA, tB, n); - return output; + output += helper.runOperationSuit(&mb3, generator3, dim, n3.c_str()); + } + } + + return output; +} + +static std::string batchGemmBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + // Rank 3 - [32,1024,1024]x[32,1024,1024] + // Rank 4 - [4,8,1024,1024]x[4,8,1024,1024] + + IntParameters rank("rank", 3, 4, 1); + + ParametersBatch b({&rank}); + + auto generator = PARAMETRIC_D() { + auto rank = p.getIntParam("rank"); + std::vector shapeA; + std::vector shapeB; + auto ctx = new Context(1); + + if (rank == 3) { + ctx->setInputArray(0, + NDArrayFactory::create('c', {32, 1024, 1024})); + ctx->setInputArray(1, + NDArrayFactory::create('c', {32, 1024, 1024})); + ctx->setOutputArray(0, + NDArrayFactory::create('c', {32, 1024, 1024})); + } else { + ctx->setInputArray( + 0, NDArrayFactory::create('c', {4, 8, 1024, 1024})); + ctx->setInputArray( + 1, NDArrayFactory::create('c', {4, 8, 1024, 1024})); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {4, 8, 1024, 1024})); } - static std::string gemmRegularBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - for (int o = 0; o <= 1; o++) { - char resultOrder = (o == 0 ? 'f' : 'c'); - for (int tA = 0; tA <= 1; tA++) { - for (int tB = 0; tB <= 1; tB++) { - - - IntPowerParameters pa("sz", 2, 7, gemmRegularUpperPow, 2); //2^7=128, 2^9=512, 2^11=2048 - - ParametersBatch b({&pa}); - - auto generator = PARAMETRIC_XYZ() { - auto s = p.getIntParam("sz"); - auto A = NDArrayFactory::create('c', {s, s}); - auto B = NDArrayFactory::create('c', {s, s}); - auto C = NDArrayFactory::create(resultOrder, {s, s}); + return ctx; + }; - x.push_back(A); - y.push_back(B); - z.push_back(C); - }; + sd::ops::matmul mmul; + DeclarableBenchmark benchmark(mmul, "mmul (batch)"); + output += helper.runOperationSuit(&benchmark, generator, b, "MMul (batch)"); - std::string n; - n += "Gemm - tA="; - n += std::to_string(tA); - n += ", tB="; - n += std::to_string(tB); - n += ", cOrder="; - n += resultOrder; + return output; +} - MatrixBenchmark mb(1.0, 0.0, tA == 0 ? false : true, tB == 0 ? false : true, n); +static std::string gemmRegularBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); - output += helper.runOperationSuit(&mb, generator, b, n.c_str()); - } - } - } + for (int o = 0; o <= 1; o++) { + char resultOrder = (o == 0 ? 'f' : 'c'); + for (int tA = 0; tA <= 1; tA++) { + for (int tB = 0; tB <= 1; tB++) { + IntPowerParameters pa("sz", 2, 7, gemmRegularUpperPow, + 2); // 2^7=128, 2^9=512, 2^11=2048 - return output; - } + ParametersBatch b({&pa}); - static std::string scatterOpBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters length("length", 2, 10, gatherOpPowLimit, 4); //2^10 to 2^26 in steps of 4 - ParametersBatch batch({&length}); - - //Gather 1D tests - 1d ref, 1d indices, 1d updates -> 1d output - sd::ops::scatter_upd scatter_update1; - DeclarableBenchmark sa1d(scatter_update1, "scatter_update1d"); - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int length = p.getIntParam("length"); - auto in = NDArrayFactory::create('c', {length}); - auto indices = NDArrayFactory::create('c', {length}); - auto updates = NDArrayFactory::create('c', {length}); - - int* a = new int[length]; - for( int i=0; isetInputArray(0, in); - ctx->setInputArray(1, indices); - ctx->setInputArray(2, updates); - ctx->setOutputArray(0, in); //Needs to be inplace to avoid copy! - ctx->markInplace(true); - return ctx; + auto generator = PARAMETRIC_XYZ() { + auto s = p.getIntParam("sz"); + auto A = NDArrayFactory::create('c', {s, s}); + auto B = NDArrayFactory::create('c', {s, s}); + auto C = NDArrayFactory::create(resultOrder, {s, s}); + + x.push_back(A); + y.push_back(B); + z.push_back(C); }; - output += helper.runOperationSuit(&sa1d, generator, batch, "Scatter Update - 1d"); - - //Gather 2D tests - 2d input, 1d indices, 2d updates -> 2d output - IntPowerParameters rows("rows", 2, 8, gatherOpPowLimit2, 4); //2^10 to 2^16 in steps of 2: 2^10, ..., 2^20 - PredefinedParameters cols("cols", {32}); - ParametersBatch batch2({&rows, &cols}); - sd::ops::scatter_upd scatter_update2; - DeclarableBenchmark sa2d(scatter_update2, "scatter_update2d"); - auto generator2 = PARAMETRIC_D() { - auto ctx = new Context(1); - int rows = p.getIntParam("rows"); - int cols = p.getIntParam("cols"); - auto in = NDArrayFactory::create('c', {rows, cols}); - auto indices = NDArrayFactory::create('c', {rows}); - auto updates = NDArrayFactory::create('c', {rows, cols}); - - int* a = new int[rows]; - for( int i=0; isetInputArray(0, in); - ctx->setInputArray(1, indices); - ctx->setInputArray(2, updates); - ctx->setOutputArray(0, in); //Needs to be inplace to avoid copy! - ctx->markInplace(true); - return ctx; - }; + std::string n; + n += "Gemm - tA="; + n += std::to_string(tA); + n += ", tB="; + n += std::to_string(tB); + n += ", cOrder="; + n += resultOrder; - output += helper.runOperationSuit(&sa2d, generator2, batch2, "Scatter Update - 2d"); - - //Gather 3D tests - 3d input, 1d indices -> 3d output - IntPowerParameters sz0("sz0", 2, 8, gatherOpPowLimit3, 4); - PredefinedParameters sz1("sz1", {32}); - ParametersBatch batch3({&sz0, &sz1}); - sd::ops::scatter_upd scatter_update3; - DeclarableBenchmark sa3d(scatter_update3, "scatter3d"); - auto generator3 = PARAMETRIC_D() { - auto ctx = new Context(1); - int sz0 = p.getIntParam("sz0"); - int sz1 = p.getIntParam("sz1"); - auto in = NDArrayFactory::create('c', {sz0, sz1, 512/sz1}); - auto indices = NDArrayFactory::create('c', {sz0}); - auto updates = NDArrayFactory::create('c', {sz0, sz1, 512/sz1}); - - int* a = new int[sz0]; - for( int i=0; isetInputArray(0, in); - ctx->setInputArray(1, indices); - ctx->setInputArray(2, updates); - ctx->setOutputArray(0, in); //Needs to be inplace to avoid copy! - ctx->markInplace(true); - return ctx; - }; + MatrixBenchmark mb(1.0, 0.0, tA == 0 ? false : true, + tB == 0 ? false : true, n); - output += helper.runOperationSuit(&sa3d, generator3, batch3, "Scatter Update - 3d"); - return output; + output += helper.runOperationSuit(&mb, generator, b, n.c_str()); + } } - - static std::string gatherOpBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters length("length", 2, 10, gatherOpPowLimit, 4); //2^10 to 2^22 in steps of 4 - ParametersBatch batch({&length}); - - //Gather 1D tests - 1d input, 1d indices -> 1d output - sd::ops::gather gather1; - DeclarableBenchmark gather1d(gather1, "gather1d"); - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int length = p.getIntParam("length"); - auto in = NDArrayFactory::create('c', {length}); - auto indices = NDArrayFactory::create('c', {length}); - int* a = new int[length]; - for( int i=0; isetInputArray(0, in); - ctx->setInputArray(1, indices); - ctx->setOutputArray(0, NDArrayFactory::create('c', {length})); - return ctx; - }; - - output += helper.runOperationSuit(&gather1d, generator, batch, "Gather - 1d"); - - //Gather 2D tests - 2d input, 1d indices -> 2d output - IntPowerParameters rows("rows", 2, 8, gatherOpPowLimit2, 4); //2^10 to 2^20 in steps of 2: 2^10, ..., 2^20 - PredefinedParameters cols("cols", {32}); - ParametersBatch batch2({&rows, &cols}); - sd::ops::gather gather2; - DeclarableBenchmark gather2d(gather2, "gather2d"); - auto generator2 = PARAMETRIC_D() { - auto ctx = new Context(1); - int rows = p.getIntParam("rows"); - int cols = p.getIntParam("cols"); - auto in = NDArrayFactory::create('c', {rows, cols}); - auto indices = NDArrayFactory::create('c', {rows}); - - int* a = new int[rows]; - for( int i=0; isetInputArray(0, in); - ctx->setInputArray(1, indices); - ctx->setOutputArray(0, NDArrayFactory::create('c', {rows, cols})); - return ctx; - }; - - output += helper.runOperationSuit(&gather2d, generator2, batch2, "Gather - 2d"); - - //Gather 3D tests - 3d input, 1d indices -> 3d output - IntPowerParameters sz0("sz0", 2, 8, gatherOpPowLimit3, 4); //2^8 to 2^16 in steps of 4 - PredefinedParameters sz1("sz1", {32}); - ParametersBatch batch3({&sz0, &sz1}); - sd::ops::gather gather3; - DeclarableBenchmark gather3d(gather3, "gather3d"); - auto generator3 = PARAMETRIC_D() { - auto ctx = new Context(1); - int sz0 = p.getIntParam("sz0"); - int sz1 = p.getIntParam("sz1"); - auto in = NDArrayFactory::create('c', {sz0, sz1, 512/sz1}); - auto indices = NDArrayFactory::create('c', {sz0}); - - int* a = new int[sz0]; - for( int i=0; isetInputArray(0, in); - ctx->setInputArray(1, indices); - ctx->setOutputArray(0, NDArrayFactory::create('c', {sz0, sz1, 512/sz1})); - return ctx; - }; - - output += helper.runOperationSuit(&gather3d, generator3, batch3, "Gather - 3d"); - - return output; + } + + return output; +} + +static std::string scatterOpBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + IntPowerParameters length("length", 2, 10, gatherOpPowLimit, + 4); // 2^10 to 2^26 in steps of 4 + ParametersBatch batch({&length}); + + // Gather 1D tests - 1d ref, 1d indices, 1d updates -> 1d output + sd::ops::scatter_upd scatter_update1; + DeclarableBenchmark sa1d(scatter_update1, "scatter_update1d"); + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int length = p.getIntParam("length"); + auto in = NDArrayFactory::create('c', {length}); + auto indices = NDArrayFactory::create('c', {length}); + auto updates = NDArrayFactory::create('c', {length}); + + int *a = new int[length]; + for (int i = 0; i < length; i++) { + a[i] = i; } - - static std::string mismatchedOrdersAssignBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters rows("rows", 2, 2, mismatchedAssignPowLimit, 4); //2^2 to 2^26 in steps of 2 - 2^1=2, ..., 2^26=67108864 - BoolParameters cf("cf"); - - ParametersBatch batch({&rows, &cf}); - - auto generator = PARAMETRIC_XZ() { - int numElements = 67108864; //2^26 - int rows = p.getIntParam("rows"); - int cols = numElements / rows; - bool c = p.getIntParam("cf"); - - auto arr = NDArrayFactory::create(c ? 'c' : 'f', {rows, cols}); - auto arr2 = NDArrayFactory::create(c ? 'f' : 'c', {rows, cols}); - x.push_back(arr); - z.push_back(arr2); - }; - - TransformBenchmark tb(transform::AnyOps::Assign, "assign"); - output += helper.runOperationSuit(&tb, generator, batch, "C->F and F->C Assign"); - - //Also test: NCHW to NHWC and back - BoolParameters nchw("nchw"); - ParametersBatch batch2({&nchw}); - auto generator2 = PARAMETRIC_XZ() { - bool nchw = p.getIntParam("nchw"); - - if(nchw) { - auto orig = NDArrayFactory::create('c', {16, 32, 64, 64}); - orig.permutei({0,2,3,1}); - x.push_back(orig); - z.push_back(NDArrayFactory::create('c', {16, 64, 64, 32})); - } else { - auto orig = NDArrayFactory::create('c', {16, 64, 64, 32}); - orig.permutei({0,3,1,2}); - x.push_back(orig); - z.push_back(NDArrayFactory::create('c', {16, 32, 64, 64})); - } - }; - - TransformBenchmark tb2(transform::AnyOps::Assign, "assign_nchw"); - output += helper.runOperationSuit(&tb2, generator2, batch2, "nchw->nhwc and nhwc->nchw Assign"); - return output; + srand(12345); + std::random_shuffle(a, (a + length - 1)); + for (int i = 0; i < length; i++) { + indices.p(i, a[i]); } - - static std::string broadcastOpsMatrixBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - //Broadcast ops: matrices for rank 3, 4, 5 - for( int rank=3; rank <= broadcastMatrixRankLimit; rank++ ){ - int numAxisTests = -1; - if(rank == 3){ - numAxisTests = 3; - } else if(rank == 4){ - numAxisTests = 6; - } else if(rank == 5){ - numAxisTests = 10; - } - - IntParameters testNum("testNum", 0,numAxisTests-1,1); - ParametersBatch b({&testNum}); - - auto generator = PARAMETRIC_D(){ - int n = p.getIntParam("testNum"); - std::vector axis({}); - switch(n){ - //rank 3+ - case 0: - axis = std::vector({0,1}); - break; - case 1: - axis = std::vector({0,2}); - break; - case 2: - axis = std::vector({1,2}); - break; - //rank 4+ - case 3: - axis = std::vector({0,3}); - break; - case 4: - axis = std::vector({1,3}); - break; - case 5: - axis = std::vector({2,3}); - break; - //Rank 5 - case 6: - axis = std::vector({0,4}); - break; - case 7: - axis = std::vector({1,4}); - break; - case 8: - axis = std::vector({2,4}); - break; - case 9: - axis = std::vector({3,4}); - break; - } - - - std::vector shape({}); - std::vector toBcShape({}); - int vectorLength; - if(rank == 3){ - shape = std::vector({64,64,64}); - toBcShape = std::vector({64,64,64}); - vectorLength = 64; - } else if(rank == 4){ - shape = std::vector({32,32,32,32}); - toBcShape = std::vector({32,32,32,32}); - vectorLength = 32; - } else if(rank == 5){ - shape = std::vector({16,16,16,16,16}); - toBcShape = std::vector({16,16,16,16,16}); - vectorLength = 16; - } - - for( int i=0; isetInputArray(0, NDArrayFactory::create('c', shape)); - ctx->setInputArray(1, NDArrayFactory::create('c', toBcShape)); - ctx->setOutputArray(0, NDArrayFactory::create('c', shape)); - return ctx; - }; - - std::string name; - name += "Broadcast Matrix Add (Custom) - Rank"; - name += std::to_string(rank); - - sd::ops::add op; - DeclarableBenchmark benchmark(op, "add"); - output += helper.runOperationSuit(&benchmark, generator, b, name.c_str()); - } - - return output; + delete[] a; + + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setInputArray(2, updates); + ctx->setOutputArray(0, in); // Needs to be inplace to avoid copy! + ctx->markInplace(true); + return ctx; + }; + + output += + helper.runOperationSuit(&sa1d, generator, batch, "Scatter Update - 1d"); + + // Gather 2D tests - 2d input, 1d indices, 2d updates -> 2d output + IntPowerParameters rows("rows", 2, 8, gatherOpPowLimit2, + 4); // 2^10 to 2^16 in steps of 2: 2^10, ..., 2^20 + PredefinedParameters cols("cols", {32}); + ParametersBatch batch2({&rows, &cols}); + sd::ops::scatter_upd scatter_update2; + DeclarableBenchmark sa2d(scatter_update2, "scatter_update2d"); + auto generator2 = PARAMETRIC_D() { + auto ctx = new Context(1); + int rows = p.getIntParam("rows"); + int cols = p.getIntParam("cols"); + auto in = NDArrayFactory::create('c', {rows, cols}); + auto indices = NDArrayFactory::create('c', {rows}); + auto updates = NDArrayFactory::create('c', {rows, cols}); + + int *a = new int[rows]; + for (int i = 0; i < rows; i++) { + a[i] = i; } - - - static std::string broadcast2dBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - PredefinedParameters rows("rows", {65536}); - IntPowerParameters cols("cols", 2, 2, limit10, 4); //2^2, 2^6, 2^10 - BoolParameters axis("axis"); - BoolParameters inplace("inplace"); - - ParametersBatch batch({&rows, &cols, &axis, &inplace}); - - auto generator = PARAMETRIC_D() { - auto a = p.getIntParam("axis"); - auto arr = NDArrayFactory::create('c', {p.getIntParam("rows"), p.getIntParam("cols")}); - - auto ctx = new Context(1); - ctx->setInputArray(0, arr); - if(a == 0){ - ctx->setInputArray(1, NDArrayFactory::create('c', {p.getIntParam("rows"), 1})); - } else { - ctx->setInputArray(1, NDArrayFactory::create('c', {1, p.getIntParam("cols")})); - } - if (p.getIntParam("inplace") == 1) { - ctx->setOutputArray(0, arr); - ctx->markInplace(true); - } else { - ctx->setOutputArray(0, NDArrayFactory::create('c', {p.getIntParam("rows"), p.getIntParam("cols")})); - } - return ctx; - }; - - std::string s("add"); - sd::ops::add op; - DeclarableBenchmark benchmark(op, "add"); - output += helper.runOperationSuit(&benchmark, generator, batch, "Broadcast (Custom) Add - 2d"); - return output; + srand(12345); + std::random_shuffle(a, (a + rows - 1)); + for (int i = 0; i < rows; i++) { + indices.p(i, a[i]); } - - static std::string broadcastBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - //Broadcast ops: vectors for rank 2, 3, 4, 5 - for( int axis=0; axis<=1; axis++ ){ - PredefinedParameters rows("rows", {65536}); - IntPowerParameters cols("cols", 2, 2, limit10, 4); //2^1 to 2^10 in steps of 2 - 2^1=2, ..., 2^10=1024 - BoolParameters inplace("inplace"); - - ParametersBatch batch({&rows, &cols, &inplace}); - - auto generator = PARAMETRIC_XYZ() { - auto arr = NDArrayFactory::create('c', {p.getIntParam("rows"), p.getIntParam("cols")}); - x.push_back(arr); - if(axis == 0){ - y.push_back(NDArrayFactory::create('c', {p.getIntParam("rows")})); - } else { - y.push_back(NDArrayFactory::create('c', {p.getIntParam("cols")})); - } - if (p.getIntParam("inplace") == 1) { - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create('c', {p.getIntParam("rows"), p.getIntParam("cols")})); - } - }; - - std::string s("bAdd"); s += std::to_string(axis); s += "r2"; - BroadcastBenchmark bAdd(broadcast::Add, s, {axis}); - output += helper.runOperationSuit(&bAdd, generator, batch, "Broadcast Add - Rank 2"); - } - - for( int rank=3; rank<=5; rank++ ){ - for( int axis=1; axis shape({}); - int vectorLength; - if(rank == 3){ - shape = std::vector({32,128,128}); - vectorLength = 128; - } else if(rank == 4){ - shape = std::vector({16,64,64,64}); - vectorLength = 64; - } else if(rank == 5){ - shape = std::vector({16,48,48,48,48}); - vectorLength = 48; - } - - ParametersBatch batch({}); - - //Note: always inplace here - auto generator = PARAMETRIC_XYZ() { - auto arr = NDArrayFactory::create('c', shape); - x.push_back(arr); - y.push_back(NDArrayFactory::create('c', {vectorLength})); - z.push_back(arr); - }; - - std::string name("bArr-r"); name += std::to_string(rank); name += "a"; name += std::to_string(axis); - BroadcastBenchmark bAdd(broadcast::Add, name, {axis}); - std::string n2("Broadcast Add - Rank"); n2 += std::to_string(rank); n2 += " - axis="; n2 += std::to_string(axis); - output += helper.runOperationSuit(&bAdd, generator, batch, n2.c_str()); - } - } - - return output; + delete[] a; + + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setInputArray(2, updates); + ctx->setOutputArray(0, in); // Needs to be inplace to avoid copy! + ctx->markInplace(true); + return ctx; + }; + + output += + helper.runOperationSuit(&sa2d, generator2, batch2, "Scatter Update - 2d"); + + // Gather 3D tests - 3d input, 1d indices -> 3d output + IntPowerParameters sz0("sz0", 2, 8, gatherOpPowLimit3, 4); + PredefinedParameters sz1("sz1", {32}); + ParametersBatch batch3({&sz0, &sz1}); + sd::ops::scatter_upd scatter_update3; + DeclarableBenchmark sa3d(scatter_update3, "scatter3d"); + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + int sz0 = p.getIntParam("sz0"); + int sz1 = p.getIntParam("sz1"); + auto in = NDArrayFactory::create('c', {sz0, sz1, 512 / sz1}); + auto indices = NDArrayFactory::create('c', {sz0}); + auto updates = NDArrayFactory::create('c', {sz0, sz1, 512 / sz1}); + + int *a = new int[sz0]; + for (int i = 0; i < sz0; i++) { + a[i] = i; } - - static std::string fastStridedReductionNonEws() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters stride("stride", 2, 0, 10, 2); //2^0=1, ..., 2^10=1024 - - ParametersBatch batch({&stride}); - - //This is an edge case: technically an EWS *should* be available here - auto generator1 = PARAMETRIC_XYZ() { - auto stride = p.getIntParam("stride"); - auto arr = NDArrayFactory::create('c', {131072 + (stride == 1 ? 0 : 1), stride}); - - NDArray strided; - if(stride == 1){ - strided = arr; - } else { - IndicesList indices({NDIndex::interval(0,131072), NDIndex::interval(0,1)}); - strided = arr.subarray(indices); //All rows, first column - } - - strided.assign(1.0); - x.push_back(strided); - y.push_back(NDArray()); - z.push_back(NDArrayFactory::create(0.0f)); - }; - - ReductionBenchmark rbSum(reduce::SameOps::Sum, "stridedSum"); - output += helper.runOperationSuit(&rbSum, (const std::function)(generator1), batch, "Strided Sum - No EWS Test 1"); - - - //No EWS defined for this case - auto generator2 = PARAMETRIC_XYZ() { - auto stride = p.getIntParam("stride"); - auto arr = NDArrayFactory::create('c', {(stride == 1 ? 1 : 2) * 1024, 1024, stride}); - - NDArray strided; - if(stride == 1){ - strided = arr; - } else { - IndicesList indices({NDIndex::interval(0,2*1024,2), NDIndex::all(), NDIndex::interval(0,1)}); - strided = arr.subarray(indices); - } - - strided.assign(1.0); - x.push_back(strided); - y.push_back(NDArray()); - z.push_back(NDArrayFactory::create(0.0f)); - }; - - ReductionBenchmark rbSum2(reduce::SameOps::Sum, "stridedSumNoEWS"); - output += helper.runOperationSuit(&rbSum2, (const std::function)(generator2), batch, "Strided Sum - No EWS Test 2"); - - return output; + srand(12345); + std::random_shuffle(a, (a + sz0 - 1)); + for (int i = 0; i < sz0; i++) { + indices.p(i, a[i]); } - - static std::string fastStridedReductionIrregular() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters length("length", 2, 12, stridedReductionPowLimit, 4); //2^12 to 2^20 in steps of 4 - PredefinedParameters stride("stride", {26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, - 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, - 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028}); - - ParametersBatch batch({&length, &stride}); - - auto generator = PARAMETRIC_XYZ() { - auto stride = p.getIntParam("stride"); - auto arr = NDArrayFactory::create('c', {p.getIntParam("length"), stride}); - - NDArray strided; - if(stride == 1){ - strided = arr; - } else { - IndicesList indices({NDIndex::all(), NDIndex::interval(0,1)}); - strided = arr.subarray(indices); //All rows, first column - } - - strided.assign(1.0); - x.push_back(strided); - y.push_back(NDArray()); - z.push_back(NDArrayFactory::create(0.0f)); - }; - - ReductionBenchmark rbSum(reduce::SameOps::Sum, "stridedSum"); - - output += helper.runOperationSuit(&rbSum, (const std::function)(generator), batch, "Strided Sum - Irregular Strides"); - - return output; + delete[] a; + + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setInputArray(2, updates); + ctx->setOutputArray(0, in); // Needs to be inplace to avoid copy! + ctx->markInplace(true); + return ctx; + }; + + output += + helper.runOperationSuit(&sa3d, generator3, batch3, "Scatter Update - 3d"); + return output; +} + +static std::string gatherOpBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + IntPowerParameters length("length", 2, 10, gatherOpPowLimit, + 4); // 2^10 to 2^22 in steps of 4 + ParametersBatch batch({&length}); + + // Gather 1D tests - 1d input, 1d indices -> 1d output + sd::ops::gather gather1; + DeclarableBenchmark gather1d(gather1, "gather1d"); + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int length = p.getIntParam("length"); + auto in = NDArrayFactory::create('c', {length}); + auto indices = NDArrayFactory::create('c', {length}); + int *a = new int[length]; + for (int i = 0; i < length; i++) { + a[i] = i; } - - static std::string fastStridedReductionsRegular() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters length("length", 2, 12, stridedReductionPowLimit, 4); //2^12 to 2^20 in steps of 4 - IntPowerParameters stride("stride", 2, 0, 10); //2^0=1, ..., 2^10=1024 - - ParametersBatch batch({&length, &stride}); - - auto generator = PARAMETRIC_XYZ() { - auto stride = p.getIntParam("stride"); - auto arr = NDArrayFactory::create('c', {p.getIntParam("length"), stride}); - - NDArray strided; - if(stride == 1){ - strided = arr; - } else { - IndicesList indices({NDIndex::all(), NDIndex::point(0)}); - strided = arr.subarray(indices); //All rows, first column - } - - strided.assign(1.0); - x.push_back(strided); - y.push_back(NDArray()); -// z.push_back(NDArrayFactory::create(0.0f)); - z.push_back(NDArrayFactory::create('c', {1})); - }; - - ReductionBenchmark rbSum(reduce::SameOps::Sum, "Strided Sum"); - - output += helper.runOperationSuit(&rbSum, (const std::function)(generator), batch, "Strided Sum - Regular Strides (powers of 2)"); - - auto generator3 = PARAMETRIC_D(){ - auto ctx = new Context(1); - auto stride = p.getIntParam("stride"); - auto arr = NDArrayFactory::create('c', {p.getIntParam("length"), stride}); - - NDArray strided; - if(stride == 1){ - strided = arr; - } else { - IndicesList indices({NDIndex::all(), NDIndex::point(0)}); - strided = arr.subarray(indices); //All rows, first column - } - - strided.assign(1.0); - ctx->setInputArray(0, strided); - ctx->setOutputArray(0, NDArrayFactory::create('c', {1})); - auto iargs = new Nd4jLong[1]; - iargs[0] = 0; - ctx->setIArguments(iargs, 1); - delete[] iargs; - return ctx; - }; - - sd::ops::argmax opArgmax; - DeclarableBenchmark dbArgmax(opArgmax, "stridedArgmax"); - output += helper.runOperationSuit(&dbArgmax, generator3, batch, "Strided Argmax"); - return output; + srand(12345); + std::random_shuffle(a, (a + length - 1)); + for (int i = 0; i < length; i++) { + indices.p(i, a[i]); + } + delete[] a; + + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setOutputArray(0, NDArrayFactory::create('c', {length})); + return ctx; + }; + + output += helper.runOperationSuit(&gather1d, generator, batch, "Gather - 1d"); + + // Gather 2D tests - 2d input, 1d indices -> 2d output + IntPowerParameters rows("rows", 2, 8, gatherOpPowLimit2, + 4); // 2^10 to 2^20 in steps of 2: 2^10, ..., 2^20 + PredefinedParameters cols("cols", {32}); + ParametersBatch batch2({&rows, &cols}); + sd::ops::gather gather2; + DeclarableBenchmark gather2d(gather2, "gather2d"); + auto generator2 = PARAMETRIC_D() { + auto ctx = new Context(1); + int rows = p.getIntParam("rows"); + int cols = p.getIntParam("cols"); + auto in = NDArrayFactory::create('c', {rows, cols}); + auto indices = NDArrayFactory::create('c', {rows}); + + int *a = new int[rows]; + for (int i = 0; i < rows; i++) { + a[i] = i; + } + srand(12345); + std::random_shuffle(a, (a + rows - 1)); + for (int i = 0; i < rows; i++) { + indices.p(i, a[i]); + } + delete[] a; + + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setOutputArray(0, NDArrayFactory::create('c', {rows, cols})); + return ctx; + }; + + output += + helper.runOperationSuit(&gather2d, generator2, batch2, "Gather - 2d"); + + // Gather 3D tests - 3d input, 1d indices -> 3d output + IntPowerParameters sz0("sz0", 2, 8, gatherOpPowLimit3, + 4); // 2^8 to 2^16 in steps of 4 + PredefinedParameters sz1("sz1", {32}); + ParametersBatch batch3({&sz0, &sz1}); + sd::ops::gather gather3; + DeclarableBenchmark gather3d(gather3, "gather3d"); + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + int sz0 = p.getIntParam("sz0"); + int sz1 = p.getIntParam("sz1"); + auto in = NDArrayFactory::create('c', {sz0, sz1, 512 / sz1}); + auto indices = NDArrayFactory::create('c', {sz0}); + + int *a = new int[sz0]; + for (int i = 0; i < sz0; i++) { + a[i] = i; + } + srand(12345); + std::random_shuffle(a, (a + sz0 - 1)); + for (int i = 0; i < sz0; i++) { + indices.p(i, a[i]); + } + delete[] a; + + ctx->setInputArray(0, in); + ctx->setInputArray(1, indices); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {sz0, sz1, 512 / sz1})); + return ctx; + }; + + output += + helper.runOperationSuit(&gather3d, generator3, batch3, "Gather - 3d"); + + return output; +} + +static std::string mismatchedOrdersAssignBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + IntPowerParameters rows( + "rows", 2, 2, mismatchedAssignPowLimit, + 4); // 2^2 to 2^26 in steps of 2 - 2^1=2, ..., 2^26=67108864 + BoolParameters cf("cf"); + + ParametersBatch batch({&rows, &cf}); + + auto generator = PARAMETRIC_XZ() { + int numElements = 67108864; // 2^26 + int rows = p.getIntParam("rows"); + int cols = numElements / rows; + bool c = p.getIntParam("cf"); + + auto arr = NDArrayFactory::create(c ? 'c' : 'f', {rows, cols}); + auto arr2 = NDArrayFactory::create(c ? 'f' : 'c', {rows, cols}); + x.push_back(arr); + z.push_back(arr2); + }; + + TransformBenchmark tb(transform::AnyOps::Assign, "assign"); + output += + helper.runOperationSuit(&tb, generator, batch, "C->F and F->C Assign"); + + // Also test: NCHW to NHWC and back + BoolParameters nchw("nchw"); + ParametersBatch batch2({&nchw}); + auto generator2 = PARAMETRIC_XZ() { + bool nchw = p.getIntParam("nchw"); + + if (nchw) { + auto orig = NDArrayFactory::create('c', {16, 32, 64, 64}); + orig.permutei({0, 2, 3, 1}); + x.push_back(orig); + z.push_back(NDArrayFactory::create('c', {16, 64, 64, 32})); + } else { + auto orig = NDArrayFactory::create('c', {16, 64, 64, 32}); + orig.permutei({0, 3, 1, 2}); + x.push_back(orig); + z.push_back(NDArrayFactory::create('c', {16, 32, 64, 64})); + } + }; + + TransformBenchmark tb2(transform::AnyOps::Assign, "assign_nchw"); + output += helper.runOperationSuit(&tb2, generator2, batch2, + "nchw->nhwc and nhwc->nchw Assign"); + return output; +} + +static std::string broadcastOpsMatrixBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + // Broadcast ops: matrices for rank 3, 4, 5 + for (int rank = 3; rank <= broadcastMatrixRankLimit; rank++) { + int numAxisTests = -1; + if (rank == 3) { + numAxisTests = 3; + } else if (rank == 4) { + numAxisTests = 6; + } else if (rank == 5) { + numAxisTests = 10; } - static std::string fastReduceAlongDimBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - int length[] = {1024*1024, 64*1024*1024}; - int powLimit[] = {10, 20, 26}; - int powStep[] = {2, 2, 4}; - - for( int i=0; i < limit3; i++ ){ - IntPowerParameters rows("rows", 2, 0, powLimit[i], powStep[i]); - BoolParameters dim("dim"); - - - ParametersBatch batch({&rows, &dim}); - - auto generator = PARAMETRIC_XYZ() { - int rows = p.getIntParam("rows"); - int cols = length[i] / rows; - int dim = p.getIntParam("dim"); - auto arr = NDArrayFactory::create('c', {rows, cols}); - - - x.push_back(arr); - y.push_back(NDArrayFactory::create(dim)); - - NDArray result; - if(dim == 0){ - result = NDArrayFactory::create('c', {cols}); - } else { - result = NDArrayFactory::create('c', {rows}); - } - z.push_back(result); - }; - - ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); - ReductionBenchmark rbMax(reduce::SameOps::Max, "max"); - - std::string s1("Sum Along Dimension - "); - s1 += std::to_string(length[i]); - - output += helper.runOperationSuit(&rbSum, (const std::function)(generator), batch, s1.c_str()); - - - auto generator3 = PARAMETRIC_D(){ - auto ctx = new Context(1); - int rows = p.getIntParam("rows"); - int cols = length[i] / rows; - int dim = p.getIntParam("dim"); - auto arr = NDArrayFactory::create('c', {rows, cols}); - - Nd4jLong* dimArg = new Nd4jLong[1]; - dimArg[0] = dim; - ctx->setIArguments(dimArg, 1); - delete[] dimArg; - - ctx->setInputArray(0, arr); - - NDArray result; - if(dim == 0){ - result = NDArrayFactory::create('c', {cols}); - } else { - result = NDArrayFactory::create('c', {rows}); - } - ctx->setOutputArray(0, result); - return ctx; - }; - - std::string s5("Argmax Along Dimension - "); - s5 += std::to_string(length[i]); - - sd::ops::argmax opArgmax; - DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); - output += helper.runOperationSuit(&dbArgmax, generator3, batch, s5.c_str()); + IntParameters testNum("testNum", 0, numAxisTests - 1, 1); + ParametersBatch b({&testNum}); + + auto generator = PARAMETRIC_D() { + int n = p.getIntParam("testNum"); + std::vector axis({}); + switch (n) { + // rank 3+ + case 0: + axis = std::vector({0, 1}); + break; + case 1: + axis = std::vector({0, 2}); + break; + case 2: + axis = std::vector({1, 2}); + break; + // rank 4+ + case 3: + axis = std::vector({0, 3}); + break; + case 4: + axis = std::vector({1, 3}); + break; + case 5: + axis = std::vector({2, 3}); + break; + // Rank 5 + case 6: + axis = std::vector({0, 4}); + break; + case 7: + axis = std::vector({1, 4}); + break; + case 8: + axis = std::vector({2, 4}); + break; + case 9: + axis = std::vector({3, 4}); + break; + } + + std::vector shape({}); + std::vector toBcShape({}); + int vectorLength; + if (rank == 3) { + shape = std::vector({64, 64, 64}); + toBcShape = std::vector({64, 64, 64}); + vectorLength = 64; + } else if (rank == 4) { + shape = std::vector({32, 32, 32, 32}); + toBcShape = std::vector({32, 32, 32, 32}); + vectorLength = 32; + } else if (rank == 5) { + shape = std::vector({16, 16, 16, 16, 16}); + toBcShape = std::vector({16, 16, 16, 16, 16}); + vectorLength = 16; + } + + for (int i = 0; i < rank; i++) { + if (axis[0] == i || axis[1] == i) { + continue; } - - return output; + toBcShape[i] = 1; + } + + auto ctx = new Context(1); + ctx->setInputArray(0, NDArrayFactory::create('c', shape)); + ctx->setInputArray(1, NDArrayFactory::create('c', toBcShape)); + ctx->setOutputArray(0, NDArrayFactory::create('c', shape)); + return ctx; + }; + + std::string name; + name += "Broadcast Matrix Add (Custom) - Rank"; + name += std::to_string(rank); + + sd::ops::add op; + DeclarableBenchmark benchmark(op, "add"); + output += helper.runOperationSuit(&benchmark, generator, b, name.c_str()); + } + + return output; +} + +static std::string broadcast2dBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + PredefinedParameters rows("rows", {65536}); + IntPowerParameters cols("cols", 2, 2, limit10, 4); // 2^2, 2^6, 2^10 + BoolParameters axis("axis"); + BoolParameters inplace("inplace"); + + ParametersBatch batch({&rows, &cols, &axis, &inplace}); + + auto generator = PARAMETRIC_D() { + auto a = p.getIntParam("axis"); + auto arr = NDArrayFactory::create( + 'c', {p.getIntParam("rows"), p.getIntParam("cols")}); + + auto ctx = new Context(1); + ctx->setInputArray(0, arr); + if (a == 0) { + ctx->setInputArray( + 1, NDArrayFactory::create('c', {p.getIntParam("rows"), 1})); + } else { + ctx->setInputArray( + 1, NDArrayFactory::create('c', {1, p.getIntParam("cols")})); } - - static std::string fastReduceToScalarBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters length("length", 2, 10, reduceScalarPowLimit, 4); //2^10 to 2^26 in steps of 4 - - ParametersBatch batch({&length}); - - auto generator = PARAMETRIC_XYZ() { - auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); - - x.push_back(arr); - y.push_back(NDArray()); - z.push_back(NDArrayFactory::create(0.0f)); - }; - - ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); - - output += helper.runOperationSuit(&rbSum, (const std::function)(generator), batch, "Sum - Full Array Reduction"); - - //Index reduction - sd::ops::argmax opArgmax; - DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); - auto generator3 = PARAMETRIC_D(){ - auto ctx = new Context(1); - - ctx->setInputArray(0, NDArrayFactory::create('c', {p.getIntParam("length")})); - ctx->setInputArray(1, NDArrayFactory::create((Nd4jLong)0)); - ctx->setOutputArray(0, NDArrayFactory::create(0)); - - return ctx; - }; - output += helper.runOperationSuit(&dbArgmax, generator3, batch, "Argmax Full Array Reduction"); - - return output; + if (p.getIntParam("inplace") == 1) { + ctx->setOutputArray(0, arr); + ctx->markInplace(true); + } else { + ctx->setOutputArray( + 0, NDArrayFactory::create( + 'c', {p.getIntParam("rows"), p.getIntParam("cols")})); } - - static std::string fastNonEwsTransformBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - IntPowerParameters rowcol("rowcol", 2, 2, nonEwsPowLimit, 4); //2^2 to 2^14 in steps of 4 -> non-inplace case: 2x 2^10 x 2^10 = 128mb - BoolParameters inplace("inplace"); - - ParametersBatch batch({&rowcol, &inplace}); - - auto generator = PARAMETRIC_XZ() { - int r = p.getIntParam("rowcol"); - auto arr = NDArrayFactory::create('c', {r, r+1}); - IndicesList indices({NDIndex::all(), NDIndex::interval(0,r-1)}); - auto view = arr.subarray(indices); - //nd4j_printf("VIEW ARRAY: rows=%lld, columns=%lld", view->sizeAt(0), view->sizeAt(1)); - x.push_back(view); - if(p.getIntParam("inplace") == 1){ - z.push_back(view); - } else { - z.push_back(NDArrayFactory::create('c', {view.sizeAt(0),view.sizeAt(1)})); - } - }; - - ScalarBenchmark sbLRelu(scalar::Ops::LeakyRELU, "LeakyRELU_View"); - sbLRelu.setY(NDArrayFactory::create(0.0)); - - TransformBenchmark tbExp(transform::StrictOps::Exp, "exp view"); - - output += helper.runOperationSuit(&sbLRelu, generator, batch, "LeakyRELU View"); - output += helper.runOperationSuit(&tbExp, generator, batch, "Exp View"); - - return output; + return ctx; + }; + + std::string s("add"); + sd::ops::add op; + DeclarableBenchmark benchmark(op, "add"); + output += helper.runOperationSuit(&benchmark, generator, batch, + "Broadcast (Custom) Add - 2d"); + return output; +} + +static std::string broadcastBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + // Broadcast ops: vectors for rank 2, 3, 4, 5 + for (int axis = 0; axis <= 1; axis++) { + PredefinedParameters rows("rows", {65536}); + IntPowerParameters cols( + "cols", 2, 2, limit10, + 4); // 2^1 to 2^10 in steps of 2 - 2^1=2, ..., 2^10=1024 + BoolParameters inplace("inplace"); + + ParametersBatch batch({&rows, &cols, &inplace}); + + auto generator = PARAMETRIC_XYZ() { + auto arr = NDArrayFactory::create( + 'c', {p.getIntParam("rows"), p.getIntParam("cols")}); + x.push_back(arr); + if (axis == 0) { + y.push_back( + NDArrayFactory::create('c', {p.getIntParam("rows")})); + } else { + y.push_back( + NDArrayFactory::create('c', {p.getIntParam("cols")})); + } + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back(NDArrayFactory::create( + 'c', {p.getIntParam("rows"), p.getIntParam("cols")})); + } + }; + + std::string s("bAdd"); + s += std::to_string(axis); + s += "r2"; + BroadcastBenchmark bAdd(broadcast::Add, s, {axis}); + output += helper.runOperationSuit(&bAdd, generator, batch, + "Broadcast Add - Rank 2"); + } + + for (int rank = 3; rank <= 5; rank++) { + for (int axis = 1; axis < rank; axis++) { + std::vector shape({}); + int vectorLength; + if (rank == 3) { + shape = std::vector({32, 128, 128}); + vectorLength = 128; + } else if (rank == 4) { + shape = std::vector({16, 64, 64, 64}); + vectorLength = 64; + } else if (rank == 5) { + shape = std::vector({16, 48, 48, 48, 48}); + vectorLength = 48; + } + + ParametersBatch batch({}); + + // Note: always inplace here + auto generator = PARAMETRIC_XYZ() { + auto arr = NDArrayFactory::create('c', shape); + x.push_back(arr); + y.push_back(NDArrayFactory::create('c', {vectorLength})); + z.push_back(arr); + }; + + std::string name("bArr-r"); + name += std::to_string(rank); + name += "a"; + name += std::to_string(axis); + BroadcastBenchmark bAdd(broadcast::Add, name, {axis}); + std::string n2("Broadcast Add - Rank"); + n2 += std::to_string(rank); + n2 += " - axis="; + n2 += std::to_string(axis); + output += helper.runOperationSuit(&bAdd, generator, batch, n2.c_str()); } + } - static std::string fastPairwiseBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - IntPowerParameters length("length", 2, 10, pairwisePowLimit, 4); //2^10 to 2^26 in steps of 4 -> max is 512mb - BoolParameters inplace("inplace"); + return output; +} - ParametersBatch batch({&length, &inplace}); +static std::string fastStridedReductionNonEws() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); - auto generator = PARAMETRIC_XYZ() { - auto arr1 = NDArrayFactory::create('c', {p.getIntParam("length")}); - auto arr2 = NDArrayFactory::create('c', {p.getIntParam("length")}); - x.push_back(arr1); - y.push_back(arr2); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr1); - } else { - z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); - } - }; + IntPowerParameters stride("stride", 2, 0, 10, 2); // 2^0=1, ..., 2^10=1024 - PairwiseBenchmark pb1(pairwise::Ops::Add, "Add"); - output += helper.runOperationSuit(&pb1, generator, batch, "Pairwise Add"); + ParametersBatch batch({&stride}); - PairwiseBenchmark pb2(pairwise::Ops::Add, "Multiply"); - output += helper.runOperationSuit(&pb2, generator, batch, "Pairwise Multiply"); + // This is an edge case: technically an EWS *should* be available here + auto generator1 = PARAMETRIC_XYZ() { + auto stride = p.getIntParam("stride"); + auto arr = NDArrayFactory::create( + 'c', {131072 + (stride == 1 ? 0 : 1), stride}); - return output; + NDArray strided; + if (stride == 1) { + strided = arr; + } else { + IndicesList indices( + {NDIndex::interval(0, 131072), NDIndex::interval(0, 1)}); + strided = arr.subarray(indices); // All rows, first column } - static std::string heavyTransformsBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - IntPowerParameters length("length", 2, 10, heavyPowLimit, 4); //2^10 to 2^22, steps of 4 - BoolParameters inplace("inplace"); - - ParametersBatch batch({&length, &inplace}); - - auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); - arr.assign(1.0); - x.push_back(arr); - if (p.getIntParam("inplace") == 1) { - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); - } - }; - - //Ops to test: erf (transform), betainc (custom), polygamma, synthetic ops? - TransformBenchmark erf(transform::StrictOps::Erf, "Erf"); - output += helper.runOperationSuit(&erf, generator, batch, "Error Function (Erf)"); - - ParametersBatch batch2({&length}); - sd::ops::polygamma op1; - DeclarableBenchmark pg(op1, "polygamma"); - auto generator2 = PARAMETRIC_D() { - auto ctx = new Context(1); - auto in0 = NDArrayFactory::create('c', {p.getIntParam("length")}); - in0.assign(0.25); - auto in1 = NDArrayFactory::create('c', {p.getIntParam("length")}); - in1.assign(0.5); - ctx->setInputArray(0, in0); - ctx->setInputArray(1, in1); - ctx->setOutputArray(0, NDArrayFactory::create('c', {p.getIntParam("length")})); - return ctx; - }; - - - IntPowerParameters lengthBetaInc("length", 2, 10, heavyPowLimit, 4); //2^10 to 2^22 in steps of 4 - ParametersBatch batch3({&lengthBetaInc}); - sd::ops::betainc op2; - DeclarableBenchmark binc(op2, "betainc"); - auto generator3 = PARAMETRIC_D() { - auto ctx = new Context(1); - auto in0 = NDArrayFactory::create('c', {p.getIntParam("length")}); - in0.assign(0.25); - auto in1 = NDArrayFactory::create('c', {p.getIntParam("length")}); - in1.assign(0.5); - auto in2 = NDArrayFactory::create('c', {p.getIntParam("length")}); - in2.assign(0.75); - ctx->setInputArray(0, in0); - ctx->setInputArray(1, in1); - ctx->setInputArray(2, in2); - ctx->setOutputArray(0, NDArrayFactory::create('c', {p.getIntParam("length")})); - return ctx; - }; - - output += helper.runOperationSuit(&pg, generator2, batch2, "PolyGamma Function"); - output += helper.runOperationSuit(&binc, generator3, batch3, "Incomplete Beta Function (BetaInc)"); + strided.assign(1.0); + x.push_back(strided); + y.push_back(NDArray()); + z.push_back(NDArrayFactory::create(0.0f)); + }; + + ReductionBenchmark rbSum(reduce::SameOps::Sum, "stridedSum"); + output += helper.runOperationSuit( + &rbSum, + (const std::function)(generator1), + batch, "Strided Sum - No EWS Test 1"); + + // No EWS defined for this case + auto generator2 = PARAMETRIC_XYZ() { + auto stride = p.getIntParam("stride"); + auto arr = NDArrayFactory::create( + 'c', {(stride == 1 ? 1 : 2) * 1024, 1024, stride}); + + NDArray strided; + if (stride == 1) { + strided = arr; + } else { + IndicesList indices({NDIndex::interval(0, 2 * 1024, 2), NDIndex::all(), + NDIndex::interval(0, 1)}); + strided = arr.subarray(indices); + } - return output; + strided.assign(1.0); + x.push_back(strided); + y.push_back(NDArray()); + z.push_back(NDArrayFactory::create(0.0f)); + }; + + ReductionBenchmark rbSum2(reduce::SameOps::Sum, "stridedSumNoEWS"); + output += helper.runOperationSuit( + &rbSum2, + (const std::function)(generator2), + batch, "Strided Sum - No EWS Test 2"); + + return output; +} + +static std::string fastStridedReductionIrregular() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + IntPowerParameters length("length", 2, 12, stridedReductionPowLimit, + 4); // 2^12 to 2^20 in steps of 4 + PredefinedParameters stride( + "stride", + {26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, + 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028}); + + ParametersBatch batch({&length, &stride}); + + auto generator = PARAMETRIC_XYZ() { + auto stride = p.getIntParam("stride"); + auto arr = + NDArrayFactory::create('c', {p.getIntParam("length"), stride}); + + NDArray strided; + if (stride == 1) { + strided = arr; + } else { + IndicesList indices({NDIndex::all(), NDIndex::interval(0, 1)}); + strided = arr.subarray(indices); // All rows, first column } - static std::string intermediateTransformsBenchmark() { - std::string output; - - //Non-inplace: 2x 2^26 elements FP32 -> 512MB - BenchmarkHelper helper(wIterations, rIterations); - IntPowerParameters length("length", 2, 10, intermediateTransformPowLimit, 4); //2^20 to 2^22 in steps of 4 - BoolParameters inplace("inplace"); - - ParametersBatch batch({&length, &inplace}); - - auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); - arr.assign(1.0); - x.push_back(arr); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); - } - }; + strided.assign(1.0); + x.push_back(strided); + y.push_back(NDArray()); + z.push_back(NDArrayFactory::create(0.0f)); + }; - TransformBenchmark tbTanh(transform::StrictOps::Tanh, "tanh"); - TransformBenchmark tbGelu(transform::StrictOps::GELU, "gelu"); + ReductionBenchmark rbSum(reduce::SameOps::Sum, "stridedSum"); - output += helper.runOperationSuit(&tbTanh, generator, batch, "Tanh"); - output += helper.runOperationSuit(&tbGelu, generator, batch, "gelu"); + output += helper.runOperationSuit( + &rbSum, + (const std::function)(generator), + batch, "Strided Sum - Irregular Strides"); + return output; +} - //2x 1024 cols x 2^18 = 2GB - IntPowerParameters rows("rows", 2, 10, intermediateTransformPowLimit2, 4); - PredefinedParameters cols("cols", {4, 128, 1024}); +static std::string fastStridedReductionsRegular() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); - ParametersBatch batch2({&rows, &cols, &inplace}); + IntPowerParameters length("length", 2, 12, stridedReductionPowLimit, + 4); // 2^12 to 2^20 in steps of 4 + IntPowerParameters stride("stride", 2, 0, 10); // 2^0=1, ..., 2^10=1024 - auto generator2 = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create('c', {p.getIntParam("rows"), p.getIntParam("cols")}); - arr.assign(1.0); - x.push_back(arr); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create('c', {p.getIntParam("rows"), p.getIntParam("cols")})); - } - }; + ParametersBatch batch({&length, &stride}); - //TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax"); + auto generator = PARAMETRIC_XYZ() { + auto stride = p.getIntParam("stride"); + auto arr = + NDArrayFactory::create('c', {p.getIntParam("length"), stride}); - //output += helper.runOperationSuit(&tbSoftmax, generator2, batch2, "Softmax"); - - return output; + NDArray strided; + if (stride == 1) { + strided = arr; + } else { + IndicesList indices({NDIndex::all(), NDIndex::point(0)}); + strided = arr.subarray(indices); // All rows, first column } - static std::string fastTransformsBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - IntPowerParameters length("length", 2, 10, transformBenchmarkPowLimit, 4); //2^10 to 2^30 in steps of 4 - 2^10, 2^14, ..., 2^26 - BoolParameters inplace("inplace"); - - ParametersBatch batch({&length, &inplace}); - - auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); - arr.assign(1.0); - x.push_back(arr); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); - } - }; - - ScalarBenchmark sbLRelu(scalar::Ops::LeakyRELU, "LeakyRELU"); - sbLRelu.setY(NDArrayFactory::create(0.0)); - - TransformBenchmark tbAbs(transform::SameOps::Abs, "abs"); - TransformBenchmark tbExp(transform::StrictOps::Exp, "exp"); - - output += helper.runOperationSuit(&sbLRelu, generator, batch, "LeakyRELU"); - output += helper.runOperationSuit(&tbAbs, generator, batch, "Abs"); - output += helper.runOperationSuit(&tbExp, generator, batch, "Exp"); - - return output; + strided.assign(1.0); + x.push_back(strided); + y.push_back(NDArray()); + // z.push_back(NDArrayFactory::create(0.0f)); + z.push_back(NDArrayFactory::create('c', {1})); + }; + + ReductionBenchmark rbSum(reduce::SameOps::Sum, "Strided Sum"); + + output += helper.runOperationSuit( + &rbSum, + (const std::function)(generator), + batch, "Strided Sum - Regular Strides (powers of 2)"); + + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + auto stride = p.getIntParam("stride"); + auto arr = + NDArrayFactory::create('c', {p.getIntParam("length"), stride}); + + NDArray strided; + if (stride == 1) { + strided = arr; + } else { + IndicesList indices({NDIndex::all(), NDIndex::point(0)}); + strided = arr.subarray(indices); // All rows, first column } - static std::string fastScalarBenchmark() { - std::string output; - BenchmarkHelper helper(wIterations, rIterations); - - IntPowerParameters length("length", 2, 10, scalarBenchmarkPowLimit, 4); //2^10 to 2^30 in steps of 4 - 2^10, 2^14, ..., 2^26 - BoolParameters inplace("inplace"); - - ParametersBatch batch({&length, &inplace}); - - auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); - arr.assign(1.0); - x.push_back(arr); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); - } - }; - - ScalarBenchmark sbAdd(scalar::Ops::Add, "sAdd"); - ScalarBenchmark sbDiv(scalar::Ops::Divide, "sDiv"); - ScalarBenchmark sbPow(scalar::Ops::Pow, "sPow"); - - - sbAdd.setY(NDArrayFactory::create(3.14159265359)); - sbDiv.setY(NDArrayFactory::create(3.14159265359)); - sbPow.setY(NDArrayFactory::create(3.14159265359)); - - - output += helper.runOperationSuit(&sbAdd, generator, batch, "Scalar Addition - x.add(3.14159265359) - F32"); - output += helper.runOperationSuit(&sbDiv, generator, batch, "Scalar Division - x.div(3.14159265359) - F32"); - output += helper.runOperationSuit(&sbPow, generator, batch, "Scalar Power - x.pow(3.14159265359) - F32"); - - return output; + strided.assign(1.0); + ctx->setInputArray(0, strided); + ctx->setOutputArray(0, NDArrayFactory::create('c', {1})); + auto iargs = new Nd4jLong[1]; + iargs[0] = 0; + ctx->setIArguments(iargs, 1); + delete[] iargs; + return ctx; + }; + + sd::ops::argmax opArgmax; + DeclarableBenchmark dbArgmax(opArgmax, "stridedArgmax"); + output += + helper.runOperationSuit(&dbArgmax, generator3, batch, "Strided Argmax"); + return output; +} + +static std::string fastReduceAlongDimBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + int length[] = {1024 * 1024, 64 * 1024 * 1024}; + int powLimit[] = {10, 20, 26}; + int powStep[] = {2, 2, 4}; + + for (int i = 0; i < limit3; i++) { + IntPowerParameters rows("rows", 2, 0, powLimit[i], powStep[i]); + BoolParameters dim("dim"); + + ParametersBatch batch({&rows, &dim}); + + auto generator = PARAMETRIC_XYZ() { + int rows = p.getIntParam("rows"); + int cols = length[i] / rows; + int dim = p.getIntParam("dim"); + auto arr = NDArrayFactory::create('c', {rows, cols}); + + x.push_back(arr); + y.push_back(NDArrayFactory::create(dim)); + + NDArray result; + if (dim == 0) { + result = NDArrayFactory::create('c', {cols}); + } else { + result = NDArrayFactory::create('c', {rows}); + } + z.push_back(result); + }; + + ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); + ReductionBenchmark rbMax(reduce::SameOps::Max, "max"); + + std::string s1("Sum Along Dimension - "); + s1 += std::to_string(length[i]); + + output += helper.runOperationSuit( + &rbSum, + (const std::function)(generator), + batch, s1.c_str()); + + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + int rows = p.getIntParam("rows"); + int cols = length[i] / rows; + int dim = p.getIntParam("dim"); + auto arr = NDArrayFactory::create('c', {rows, cols}); + + Nd4jLong *dimArg = new Nd4jLong[1]; + dimArg[0] = dim; + ctx->setIArguments(dimArg, 1); + delete[] dimArg; + + ctx->setInputArray(0, arr); + + NDArray result; + if (dim == 0) { + result = NDArrayFactory::create('c', {cols}); + } else { + result = NDArrayFactory::create('c', {rows}); + } + ctx->setOutputArray(0, result); + return ctx; + }; + + std::string s5("Argmax Along Dimension - "); + s5 += std::to_string(length[i]); + + sd::ops::argmax opArgmax; + DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); + output += helper.runOperationSuit(&dbArgmax, generator3, batch, s5.c_str()); + } + + return output; +} + +static std::string fastReduceToScalarBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + + IntPowerParameters length("length", 2, 10, reduceScalarPowLimit, + 4); // 2^10 to 2^26 in steps of 4 + + ParametersBatch batch({&length}); + + auto generator = PARAMETRIC_XYZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + + x.push_back(arr); + y.push_back(NDArray()); + z.push_back(NDArrayFactory::create(0.0f)); + }; + + ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); + + output += helper.runOperationSuit( + &rbSum, + (const std::function)(generator), + batch, "Sum - Full Array Reduction"); + + // Index reduction + sd::ops::argmax opArgmax; + DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + + ctx->setInputArray( + 0, NDArrayFactory::create('c', {p.getIntParam("length")})); + ctx->setInputArray(1, NDArrayFactory::create((Nd4jLong)0)); + ctx->setOutputArray(0, NDArrayFactory::create(0)); + + return ctx; + }; + output += helper.runOperationSuit(&dbArgmax, generator3, batch, + "Argmax Full Array Reduction"); + + return output; +} + +static std::string fastNonEwsTransformBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + IntPowerParameters rowcol("rowcol", 2, 2, nonEwsPowLimit, + 4); // 2^2 to 2^14 in steps of 4 -> non-inplace + // case: 2x 2^10 x 2^10 = 128mb + BoolParameters inplace("inplace"); + + ParametersBatch batch({&rowcol, &inplace}); + + auto generator = PARAMETRIC_XZ() { + int r = p.getIntParam("rowcol"); + auto arr = NDArrayFactory::create('c', {r, r + 1}); + IndicesList indices({NDIndex::all(), NDIndex::interval(0, r - 1)}); + auto view = arr.subarray(indices); + // nd4j_printf("VIEW ARRAY: rows=%lld, columns=%lld", view->sizeAt(0), + // view->sizeAt(1)); + x.push_back(view); + if (p.getIntParam("inplace") == 1) { + z.push_back(view); + } else { + z.push_back( + NDArrayFactory::create('c', {view.sizeAt(0), view.sizeAt(1)})); + } + }; + + ScalarBenchmark sbLRelu(scalar::Ops::LeakyRELU, "LeakyRELU_View"); + sbLRelu.setY(NDArrayFactory::create(0.0)); + + TransformBenchmark tbExp(transform::StrictOps::Exp, "exp view"); + + output += + helper.runOperationSuit(&sbLRelu, generator, batch, "LeakyRELU View"); + output += helper.runOperationSuit(&tbExp, generator, batch, "Exp View"); + + return output; +} + +static std::string fastPairwiseBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + IntPowerParameters length("length", 2, 10, pairwisePowLimit, + 4); // 2^10 to 2^26 in steps of 4 -> max is 512mb + BoolParameters inplace("inplace"); + + ParametersBatch batch({&length, &inplace}); + + auto generator = PARAMETRIC_XYZ() { + auto arr1 = NDArrayFactory::create('c', {p.getIntParam("length")}); + auto arr2 = NDArrayFactory::create('c', {p.getIntParam("length")}); + x.push_back(arr1); + y.push_back(arr2); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr1); + } else { + z.push_back( + NDArrayFactory::create('c', {p.getIntParam("length")})); + } + }; + + PairwiseBenchmark pb1(pairwise::Ops::Add, "Add"); + output += helper.runOperationSuit(&pb1, generator, batch, "Pairwise Add"); + + PairwiseBenchmark pb2(pairwise::Ops::Add, "Multiply"); + output += + helper.runOperationSuit(&pb2, generator, batch, "Pairwise Multiply"); + + return output; +} + +static std::string heavyTransformsBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + IntPowerParameters length("length", 2, 10, heavyPowLimit, + 4); // 2^10 to 2^22, steps of 4 + BoolParameters inplace("inplace"); + + ParametersBatch batch({&length, &inplace}); + + auto generator = PARAMETRIC_XZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); + x.push_back(arr); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back( + NDArrayFactory::create('c', {p.getIntParam("length")})); + } + }; + + // Ops to test: erf (transform), betainc (custom), polygamma, synthetic ops? + TransformBenchmark erf(transform::StrictOps::Erf, "Erf"); + output += + helper.runOperationSuit(&erf, generator, batch, "Error Function (Erf)"); + + ParametersBatch batch2({&length}); + sd::ops::polygamma op1; + DeclarableBenchmark pg(op1, "polygamma"); + auto generator2 = PARAMETRIC_D() { + auto ctx = new Context(1); + auto in0 = NDArrayFactory::create('c', {p.getIntParam("length")}); + in0.assign(0.25); + auto in1 = NDArrayFactory::create('c', {p.getIntParam("length")}); + in1.assign(0.5); + ctx->setInputArray(0, in0); + ctx->setInputArray(1, in1); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {p.getIntParam("length")})); + return ctx; + }; + + IntPowerParameters lengthBetaInc("length", 2, 10, heavyPowLimit, + 4); // 2^10 to 2^22 in steps of 4 + ParametersBatch batch3({&lengthBetaInc}); + sd::ops::betainc op2; + DeclarableBenchmark binc(op2, "betainc"); + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + auto in0 = NDArrayFactory::create('c', {p.getIntParam("length")}); + in0.assign(0.25); + auto in1 = NDArrayFactory::create('c', {p.getIntParam("length")}); + in1.assign(0.5); + auto in2 = NDArrayFactory::create('c', {p.getIntParam("length")}); + in2.assign(0.75); + ctx->setInputArray(0, in0); + ctx->setInputArray(1, in1); + ctx->setInputArray(2, in2); + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {p.getIntParam("length")})); + return ctx; + }; + + output += + helper.runOperationSuit(&pg, generator2, batch2, "PolyGamma Function"); + output += helper.runOperationSuit(&binc, generator3, batch3, + "Incomplete Beta Function (BetaInc)"); + + return output; +} + +static std::string intermediateTransformsBenchmark() { + std::string output; + + // Non-inplace: 2x 2^26 elements FP32 -> 512MB + BenchmarkHelper helper(wIterations, rIterations); + IntPowerParameters length("length", 2, 10, intermediateTransformPowLimit, + 4); // 2^20 to 2^22 in steps of 4 + BoolParameters inplace("inplace"); + + ParametersBatch batch({&length, &inplace}); + + auto generator = PARAMETRIC_XZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); + x.push_back(arr); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back( + NDArrayFactory::create('c', {p.getIntParam("length")})); + } + }; + + TransformBenchmark tbTanh(transform::StrictOps::Tanh, "tanh"); + TransformBenchmark tbGelu(transform::StrictOps::GELU, "gelu"); + + output += helper.runOperationSuit(&tbTanh, generator, batch, "Tanh"); + output += helper.runOperationSuit(&tbGelu, generator, batch, "gelu"); + + // 2x 1024 cols x 2^18 = 2GB + IntPowerParameters rows("rows", 2, 10, intermediateTransformPowLimit2, 4); + PredefinedParameters cols("cols", {4, 128, 1024}); + + ParametersBatch batch2({&rows, &cols, &inplace}); + + auto generator2 = PARAMETRIC_XZ() { + auto arr = NDArrayFactory::create( + 'c', {p.getIntParam("rows"), p.getIntParam("cols")}); + arr.assign(1.0); + x.push_back(arr); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back(NDArrayFactory::create( + 'c', {p.getIntParam("rows"), p.getIntParam("cols")})); } + }; + + // TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax"); + + // output += helper.runOperationSuit(&tbSoftmax, generator2, batch2, + // "Softmax"); + + return output; +} + +static std::string fastTransformsBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); + IntPowerParameters length( + "length", 2, 10, transformBenchmarkPowLimit, + 4); // 2^10 to 2^30 in steps of 4 - 2^10, 2^14, ..., 2^26 + BoolParameters inplace("inplace"); + + ParametersBatch batch({&length, &inplace}); + + auto generator = PARAMETRIC_XZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); + x.push_back(arr); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back( + NDArrayFactory::create('c', {p.getIntParam("length")})); + } + }; + ScalarBenchmark sbLRelu(scalar::Ops::LeakyRELU, "LeakyRELU"); + sbLRelu.setY(NDArrayFactory::create(0.0)); - static long nowMs(){ - auto s = std::chrono::system_clock::now().time_since_epoch(); - auto v = std::chrono::duration_cast(s).count(); - return v; - } + TransformBenchmark tbAbs(transform::SameOps::Abs, "abs"); + TransformBenchmark tbExp(transform::StrictOps::Exp, "exp"); - static long duration(long start){ - return nowMs() - start; - } + output += helper.runOperationSuit(&sbLRelu, generator, batch, "LeakyRELU"); + output += helper.runOperationSuit(&tbAbs, generator, batch, "Abs"); + output += helper.runOperationSuit(&tbExp, generator, batch, "Exp"); - static long done(long start){ - long dur = duration(start); - nd4j_printf("Done: %i ms\n", dur); - return nowMs(); - } + return output; +} +static std::string fastScalarBenchmark() { + std::string output; + BenchmarkHelper helper(wIterations, rIterations); - std::string FullBenchmarkSuit::runSuit() { - std::string result; - - long start = nowMs(); - - // set 1 - nd4j_printf("Running FullBenchmarkSuite.fastScalarBenchmark\n", ""); - result += fastScalarBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.fastTransformsBenchmark\n", ""); - result += fastTransformsBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.intermediateTransformsBenchmark\n", ""); - result += intermediateTransformsBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.fastPairwiseBenchmark\n", ""); - result += fastPairwiseBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.heavyTransformsBenchmark\n", ""); - result += heavyTransformsBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.fastNonEwsTransformBenchmark\n", ""); - result += fastNonEwsTransformBenchmark(); - start = done(start); - - // set 2 - nd4j_printf("Running FullBenchmarkSuite.fastReduceToScalarBenchmark\n", ""); - result += fastReduceToScalarBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.fastReduceAlongDimBenchmark\n", ""); - result += fastReduceAlongDimBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.fastStridedReductionsRegular\n", ""); - result += fastStridedReductionsRegular(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.fastStridedReductionIrregular\n", ""); - result += fastStridedReductionIrregular(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.fastStridedReductionNonEws\n", ""); - result += fastStridedReductionNonEws(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.broadcastBenchmark\n", ""); - result += broadcastBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.broadcast2dBenchmark\n", ""); - result += broadcast2dBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.broadcastOpsMatrixBenchmark\n", ""); - result += broadcastOpsMatrixBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.mismatchedOrdersAssignBenchmark\n", ""); - result += mismatchedOrdersAssignBenchmark(); - start = done(start); - - - // set 3 - nd4j_printf("Running FullBenchmarkSuite.gatherOpBenchmark\n", ""); - result += gatherOpBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.scatterOpBenchmark\n", ""); - result += scatterOpBenchmark(); - start = done(start); - - // set 4 - nd4j_printf("Running FullBenchmarkSuite.gemmRegularBenchmark\n", ""); - result += gemmRegularBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.gemmIrregularBenchmark\n", ""); - result += gemmIrregularBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.rngBenchmark\n", ""); - result += rngBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.conv2dBenchmark\n", ""); - result += conv2dBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.pool2dBenchmark\n", ""); - result += pool2dBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.batchnormBenchmark\n", ""); - result += batchnormBenchmark(); - start = done(start); - - nd4j_printf("Running FullBenchmarkSuite.lstmBenchmark\n", ""); - result += lstmBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.conv3dBenchmark\n", ""); - result += conv3dBenchmark(); - start = done(start); - nd4j_printf("Running FullBenchmarkSuite.maxPool3DBenchmark\n", ""); - result += maxPool3DBenchmark(); - start = done(start); -// nd4j_printf("Running FullBenchmarkSuite.layerNormBenchmark\n", ""); -// result += layerNormBenchmark(); -// start = done(start); - - return result; - } + IntPowerParameters length( + "length", 2, 10, scalarBenchmarkPowLimit, + 4); // 2^10 to 2^30 in steps of 4 - 2^10, 2^14, ..., 2^26 + BoolParameters inplace("inplace"); + ParametersBatch batch({&length, &inplace}); -} \ No newline at end of file + auto generator = PARAMETRIC_XZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); + x.push_back(arr); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back( + NDArrayFactory::create('c', {p.getIntParam("length")})); + } + }; + + ScalarBenchmark sbAdd(scalar::Ops::Add, "sAdd"); + ScalarBenchmark sbDiv(scalar::Ops::Divide, "sDiv"); + ScalarBenchmark sbPow(scalar::Ops::Pow, "sPow"); + + sbAdd.setY(NDArrayFactory::create(3.14159265359)); + sbDiv.setY(NDArrayFactory::create(3.14159265359)); + sbPow.setY(NDArrayFactory::create(3.14159265359)); + + output += helper.runOperationSuit( + &sbAdd, generator, batch, "Scalar Addition - x.add(3.14159265359) - F32"); + output += helper.runOperationSuit( + &sbDiv, generator, batch, "Scalar Division - x.div(3.14159265359) - F32"); + output += helper.runOperationSuit( + &sbPow, generator, batch, "Scalar Power - x.pow(3.14159265359) - F32"); + + return output; +} + +static long nowMs() { + auto s = std::chrono::system_clock::now().time_since_epoch(); + auto v = std::chrono::duration_cast(s).count(); + return v; +} + +static long duration(long start) { return nowMs() - start; } + +static long done(long start) { + long dur = duration(start); + nd4j_printf("Done: %i ms\n", dur); + return nowMs(); +} + +std::string FullBenchmarkSuit::runSuit() { + std::string result; + + long start = nowMs(); + + // set 1 + nd4j_printf("Running FullBenchmarkSuite.fastScalarBenchmark\n", ""); + result += fastScalarBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.fastTransformsBenchmark\n", ""); + result += fastTransformsBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.intermediateTransformsBenchmark\n", + ""); + result += intermediateTransformsBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.fastPairwiseBenchmark\n", ""); + result += fastPairwiseBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.heavyTransformsBenchmark\n", ""); + result += heavyTransformsBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.fastNonEwsTransformBenchmark\n", ""); + result += fastNonEwsTransformBenchmark(); + start = done(start); + + // set 2 + nd4j_printf("Running FullBenchmarkSuite.fastReduceToScalarBenchmark\n", ""); + result += fastReduceToScalarBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.fastReduceAlongDimBenchmark\n", ""); + result += fastReduceAlongDimBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.fastStridedReductionsRegular\n", ""); + result += fastStridedReductionsRegular(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.fastStridedReductionIrregular\n", ""); + result += fastStridedReductionIrregular(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.fastStridedReductionNonEws\n", ""); + result += fastStridedReductionNonEws(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.broadcastBenchmark\n", ""); + result += broadcastBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.broadcast2dBenchmark\n", ""); + result += broadcast2dBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.broadcastOpsMatrixBenchmark\n", ""); + result += broadcastOpsMatrixBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.mismatchedOrdersAssignBenchmark\n", + ""); + result += mismatchedOrdersAssignBenchmark(); + start = done(start); + + // set 3 + nd4j_printf("Running FullBenchmarkSuite.gatherOpBenchmark\n", ""); + result += gatherOpBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.scatterOpBenchmark\n", ""); + result += scatterOpBenchmark(); + start = done(start); + + // set 4 + nd4j_printf("Running FullBenchmarkSuite.gemmRegularBenchmark\n", ""); + result += gemmRegularBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.gemmIrregularBenchmark\n", ""); + result += gemmIrregularBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.rngBenchmark\n", ""); + result += rngBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.conv2dBenchmark\n", ""); + result += conv2dBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.pool2dBenchmark\n", ""); + result += pool2dBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.batchnormBenchmark\n", ""); + result += batchnormBenchmark(); + start = done(start); + + nd4j_printf("Running FullBenchmarkSuite.lstmBenchmark\n", ""); + result += lstmBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.conv3dBenchmark\n", ""); + result += conv3dBenchmark(); + start = done(start); + nd4j_printf("Running FullBenchmarkSuite.maxPool3DBenchmark\n", ""); + result += maxPool3DBenchmark(); + start = done(start); + // nd4j_printf("Running FullBenchmarkSuite.layerNormBenchmark\n", ""); + // result += layerNormBenchmark(); + // start = done(start); + + return result; +} + +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp b/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp index 158131bbec54..2ee5bfead650 100644 --- a/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp +++ b/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp @@ -18,9 +18,10 @@ // @author raver119@gmail.com // -#include #include "performance/benchmarking/LightBenchmarkSuit.h" +#include + #ifdef RELEASE_BUILD #define WARMUP 5 #define NUM_ITER 100 @@ -34,607 +35,670 @@ namespace sd { - template - static std::string transformBenchmark() { - std::string output; - output += "transformBenchmark " + DataTypeUtils::asString(DataTypeUtils::fromT()); - - BenchmarkHelper helper(WARMUP, NUM_ITER); - IntPowerParameters length("length", 2, 8, 20, 4); //2^8, 2^12, 2^16, 2^20 - 4MB - BoolParameters inplace("inplace"); - - ParametersBatch batch({&length, &inplace}); - - auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); - arr.assign(1.0); - x.push_back(arr); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); - } - }; - - ScalarBenchmark sbRelu(scalar::Ops::RELU, "RELU"); - sbRelu.setY(NDArrayFactory::create(0.0)); - - TransformBenchmark tbSigmoid(transform::StrictOps::Sigmoid, "sigmoid"); - //TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax"); - - output += helper.runOperationSuit(&sbRelu, generator, batch, "RELU"); - output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Sigmoid"); - //output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Softmax"); - - return output; - } - - template - static std::string scalarBenchmark() { - std::string output; - output += "scalarBenchmark " + DataTypeUtils::asString(DataTypeUtils::fromT()); - - BenchmarkHelper helper(WARMUP, NUM_ITER); - - IntPowerParameters length("length", 2, 8, 20, 4); //2^8, 2^12, 2^16, 2^20 - BoolParameters inplace("inplace"); - - ParametersBatch batch({&length, &inplace}); - - auto generator = PARAMETRIC_XZ() { - auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); - arr.assign(1.0); - x.push_back(arr); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr); - } else { - z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); - } - }; - - ScalarBenchmark sbAdd(scalar::Ops::Add, "sAdd"); - ScalarBenchmark sbDiv(scalar::Ops::Divide, "sDiv"); - ScalarBenchmark sbPow(scalar::Ops::Pow, "sPow"); - - - sbAdd.setY(NDArrayFactory::create(3.14159265359)); - sbDiv.setY(NDArrayFactory::create(3.14159265359)); - sbPow.setY(NDArrayFactory::create(3.14159265359)); - - - output += helper.runOperationSuit(&sbAdd, generator, batch, "Scalar Addition - x.add(3.14159265359)"); - output += helper.runOperationSuit(&sbDiv, generator, batch, "Scalar Division - x.div(3.14159265359)"); - output += helper.runOperationSuit(&sbPow, generator, batch, "Scalar Power - x.pow(3.14159265359)"); - - return output; +template +static std::string transformBenchmark() { + std::string output; + output += "transformBenchmark " + + DataTypeUtils::asString(DataTypeUtils::fromT()); + + BenchmarkHelper helper(WARMUP, NUM_ITER); + IntPowerParameters length("length", 2, 8, 20, + 4); // 2^8, 2^12, 2^16, 2^20 - 4MB + BoolParameters inplace("inplace"); + + ParametersBatch batch({&length, &inplace}); + + auto generator = PARAMETRIC_XZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); + x.push_back(arr); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); } + }; + ScalarBenchmark sbRelu(scalar::Ops::RELU, "RELU"); + sbRelu.setY(NDArrayFactory::create(0.0)); - template - static std::string pairwiseBenchmark() { - std::string output; - output += "pairwiseBenchmark " + DataTypeUtils::asString(DataTypeUtils::fromT()); + TransformBenchmark tbSigmoid(transform::StrictOps::Sigmoid, "sigmoid"); + // TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax"); - BenchmarkHelper helper(WARMUP, NUM_ITER); - IntPowerParameters length("length", 2, 8, 20, 4); //2^4 to 2^20 in steps of 4 - 2^4, 2^8, 2^16, 2^20 - BoolParameters inplace("inplace"); + output += helper.runOperationSuit(&sbRelu, generator, batch, "RELU"); + output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Sigmoid"); + // output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Softmax"); - ParametersBatch batch({&length, &inplace}); + return output; +} - auto generator = PARAMETRIC_XYZ() { - auto arr1 = NDArrayFactory::create('c', {p.getIntParam("length")}); - auto arr2 = NDArrayFactory::create('c', {p.getIntParam("length")}); - x.push_back(arr1); - y.push_back(arr2); - if(p.getIntParam("inplace") == 1){ - z.push_back(arr1); - } else { - z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); - } - }; +template +static std::string scalarBenchmark() { + std::string output; + output += + "scalarBenchmark " + DataTypeUtils::asString(DataTypeUtils::fromT()); - PairwiseBenchmark pb1(pairwise::Ops::Add, "Add"); - output += helper.runOperationSuit(&pb1, generator, batch, "Pairwise Add"); + BenchmarkHelper helper(WARMUP, NUM_ITER); - PairwiseBenchmark pb2(pairwise::Ops::Divide, "Divide"); - output += helper.runOperationSuit(&pb2, generator, batch, "Pairwise Divide"); + IntPowerParameters length("length", 2, 8, 20, 4); // 2^8, 2^12, 2^16, 2^20 + BoolParameters inplace("inplace"); - return output; - } + ParametersBatch batch({&length, &inplace}); - static std::string mismatchedOrderAssign() { - std::string output; - BenchmarkHelper helper(WARMUP, NUM_ITER); - - IntPowerParameters rows("rows", 2, 8, 20, 4); //2^8, 2^12, 2^16, 2^20 - BoolParameters cf("cf"); - - ParametersBatch batch({&rows, &cf}); - - auto generator = PARAMETRIC_XZ() { - int numElements = 4194304; //2^24 - int rows = p.getIntParam("rows"); - int cols = numElements / rows; - bool c = p.getIntParam("cf"); - - auto arr = NDArrayFactory::create(c ? 'c' : 'f', {rows, cols}); - auto arr2 = NDArrayFactory::create(c ? 'f' : 'c', {rows, cols}); - x.push_back(arr); - z.push_back(arr2); - }; - - TransformBenchmark tb(transform::AnyOps::Assign, "assign"); - output += helper.runOperationSuit(&tb, generator, batch, "C->F and F->C Assign F32"); - - //Also test: NCHW to NHWC and back - BoolParameters nchw("nchw"); - int mb = 8; - int hw = 64; - int c = 3; - ParametersBatch batch2({&nchw}); - auto generator2 = PARAMETRIC_XZ() { - bool nchw = p.getIntParam("nchw"); - - if(nchw) { - auto orig = NDArrayFactory::create('c', {mb, c, hw, hw}); - orig.permutei({0,2,3,1}); - x.push_back(orig); - z.push_back(NDArrayFactory::create('c', {mb, hw, hw, c})); - } else { - auto orig = NDArrayFactory::create('c', {mb, hw, hw, c}); - orig.permutei({0,3,1,2}); - x.push_back(orig); - z.push_back(NDArrayFactory::create('c', {mb, c, hw, hw})); - } - }; - - TransformBenchmark tb2(transform::AnyOps::Assign, "assign_nchw"); - output += helper.runOperationSuit(&tb2, generator2, batch2, "nchw->nhwc and nhwc->nchw Assign FP32"); - return output; + auto generator = PARAMETRIC_XZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + arr.assign(1.0); + x.push_back(arr); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr); + } else { + z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); } - - template - static std::string gemmBenchmark() { - std::string output; - output += "gemm " + DataTypeUtils::asString(DataTypeUtils::fromT()); - BenchmarkHelper helper(WARMUP, NUM_ITER); - - for (int o = 0; o <= 1; o++) { - char resultOrder = (o == 0 ? 'f' : 'c'); - IntPowerParameters sz("sz", 2, 4, 10, 2); //2^4=16, ..., 2^10=1024 -> 4 elements - - ParametersBatch b({&sz}); - - auto generator = PARAMETRIC_XYZ() { - auto a = p.getIntParam("sz"); - auto b = p.getIntParam("sz"); - auto c = p.getIntParam("sz"); - std::vector shapeA; - std::vector shapeB; - shapeA = {a, b}; - shapeB = {b, c}; - auto A = NDArrayFactory::create('c', shapeA); - auto B = NDArrayFactory::create('c', shapeB); - auto C = NDArrayFactory::create(resultOrder, {a, c}); - - x.push_back(A); - y.push_back(B); - z.push_back(C); - }; - - std::string n; - n += "Gemm - cOrder="; - n += resultOrder; - - MatrixBenchmark mb(1.0, 0.0, false, false, n); - - output += helper.runOperationSuit(&mb, generator, b, n.c_str()); - } - - return output; + }; + + ScalarBenchmark sbAdd(scalar::Ops::Add, "sAdd"); + ScalarBenchmark sbDiv(scalar::Ops::Divide, "sDiv"); + ScalarBenchmark sbPow(scalar::Ops::Pow, "sPow"); + + sbAdd.setY(NDArrayFactory::create(3.14159265359)); + sbDiv.setY(NDArrayFactory::create(3.14159265359)); + sbPow.setY(NDArrayFactory::create(3.14159265359)); + + output += helper.runOperationSuit(&sbAdd, generator, batch, + "Scalar Addition - x.add(3.14159265359)"); + output += helper.runOperationSuit(&sbDiv, generator, batch, + "Scalar Division - x.div(3.14159265359)"); + output += helper.runOperationSuit(&sbPow, generator, batch, + "Scalar Power - x.pow(3.14159265359)"); + + return output; +} + +template +static std::string pairwiseBenchmark() { + std::string output; + output += + "pairwiseBenchmark " + DataTypeUtils::asString(DataTypeUtils::fromT()); + + BenchmarkHelper helper(WARMUP, NUM_ITER); + IntPowerParameters length( + "length", 2, 8, 20, + 4); // 2^4 to 2^20 in steps of 4 - 2^4, 2^8, 2^16, 2^20 + BoolParameters inplace("inplace"); + + ParametersBatch batch({&length, &inplace}); + + auto generator = PARAMETRIC_XYZ() { + auto arr1 = NDArrayFactory::create('c', {p.getIntParam("length")}); + auto arr2 = NDArrayFactory::create('c', {p.getIntParam("length")}); + x.push_back(arr1); + y.push_back(arr2); + if (p.getIntParam("inplace") == 1) { + z.push_back(arr1); + } else { + z.push_back(NDArrayFactory::create('c', {p.getIntParam("length")})); } - - template - static std::string reduceFullBenchmark() { - std::string output; - output += "reduceFullBenchmark " + DataTypeUtils::asString(DataTypeUtils::fromT()); - - BenchmarkHelper helper(WARMUP, NUM_ITER); - - IntPowerParameters length("length", 2, 8, 20, 4); //2^8, 2^12, 2^16, 2^20 - - ParametersBatch batch({&length}); - - auto generator = PARAMETRIC_XYZ() { - auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); - - x.push_back(arr); - y.push_back(NDArray()); - z.push_back(NDArrayFactory::create(0.0f)); - }; - - ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); - ReductionBenchmark rbProd(reduce::SameOps::Prod, "prod"); - ReductionBenchmark rbMax(reduce::SameOps::Max, "max"); - - output += helper.runOperationSuit(&rbSum, (const std::function)(generator), batch, "Sum - Full Array Reduction"); - output += helper.runOperationSuit(&rbProd, (const std::function)(generator), batch, "Product - Full Array Reduction"); - output += helper.runOperationSuit(&rbMax, (const std::function)(generator), batch, "Maximum - Full Array Reduction"); - - //Index reduction - sd::ops::argmax opArgmax; - DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); - auto generator3 = PARAMETRIC_D(){ - auto ctx = new Context(1); - - ctx->setInputArray(0, NDArrayFactory::create('c', {p.getIntParam("length")})); - ctx->setInputArray(1, NDArrayFactory::create((Nd4jLong)0)); - ctx->setOutputArray(0, NDArrayFactory::create(0)); - - return ctx; - }; - output += helper.runOperationSuit(&dbArgmax, generator3, batch, "Argmax Full Array Reduction"); - return output; + }; + + PairwiseBenchmark pb1(pairwise::Ops::Add, "Add"); + output += helper.runOperationSuit(&pb1, generator, batch, "Pairwise Add"); + + PairwiseBenchmark pb2(pairwise::Ops::Divide, "Divide"); + output += helper.runOperationSuit(&pb2, generator, batch, "Pairwise Divide"); + + return output; +} + +static std::string mismatchedOrderAssign() { + std::string output; + BenchmarkHelper helper(WARMUP, NUM_ITER); + + IntPowerParameters rows("rows", 2, 8, 20, 4); // 2^8, 2^12, 2^16, 2^20 + BoolParameters cf("cf"); + + ParametersBatch batch({&rows, &cf}); + + auto generator = PARAMETRIC_XZ() { + int numElements = 4194304; // 2^24 + int rows = p.getIntParam("rows"); + int cols = numElements / rows; + bool c = p.getIntParam("cf"); + + auto arr = NDArrayFactory::create(c ? 'c' : 'f', {rows, cols}); + auto arr2 = NDArrayFactory::create(c ? 'f' : 'c', {rows, cols}); + x.push_back(arr); + z.push_back(arr2); + }; + + TransformBenchmark tb(transform::AnyOps::Assign, "assign"); + output += helper.runOperationSuit(&tb, generator, batch, + "C->F and F->C Assign F32"); + + // Also test: NCHW to NHWC and back + BoolParameters nchw("nchw"); + int mb = 8; + int hw = 64; + int c = 3; + ParametersBatch batch2({&nchw}); + auto generator2 = PARAMETRIC_XZ() { + bool nchw = p.getIntParam("nchw"); + + if (nchw) { + auto orig = NDArrayFactory::create('c', {mb, c, hw, hw}); + orig.permutei({0, 2, 3, 1}); + x.push_back(orig); + z.push_back(NDArrayFactory::create('c', {mb, hw, hw, c})); + } else { + auto orig = NDArrayFactory::create('c', {mb, hw, hw, c}); + orig.permutei({0, 3, 1, 2}); + x.push_back(orig); + z.push_back(NDArrayFactory::create('c', {mb, c, hw, hw})); } - - template - static std::string reduceDimBenchmark(){ - std::string output; - output += "reduceDimBenchmark " + DataTypeUtils::asString(DataTypeUtils::fromT()); - - BenchmarkHelper helper(WARMUP, NUM_ITER); - - int length[] = {1024*1024}; - int pow[] = {10}; - - for( int i=0; i<1; i++ ){ - IntPowerParameters rows("rows", 2, 0, pow[i], 2); - BoolParameters dim("dim"); - - - ParametersBatch batch({&rows, &dim}); - - auto generator = PARAMETRIC_XYZ() { - int rows = p.getIntParam("rows"); - int cols = length[i] / rows; - int dim = p.getIntParam("dim"); - auto arr = NDArrayFactory::create('c', {rows, cols}); - - - x.push_back(arr); - y.push_back(NDArrayFactory::create(dim)); - - NDArray result; - if(dim == 0){ - result = NDArrayFactory::create('c', {cols}); - } else { - result = NDArrayFactory::create('c', {rows}); - } - z.push_back(result); - }; - - ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); - ReductionBenchmark rbMax(reduce::SameOps::Max, "max"); - - std::string s1("Sum Along Dimension - "); - s1 += std::to_string(length[i]); - std::string s3("Maximum Along Dimension - "); - s3 += std::to_string(length[i]); - - output += helper.runOperationSuit(&rbSum, (const std::function)(generator), batch, s1.c_str()); - output += helper.runOperationSuit(&rbMax, (const std::function)(generator), batch, s3.c_str()); - - - - auto generator3 = PARAMETRIC_D(){ - auto ctx = new Context(1); - int rows = p.getIntParam("rows"); - int cols = length[i] / rows; - int dim = p.getIntParam("dim"); - auto arr = NDArrayFactory::create('c', {rows, cols}); - - auto dimArg = new Nd4jLong[1]; - dimArg[0] = dim; - ctx->setIArguments(dimArg, 1); - delete[] dimArg; - - ctx->setInputArray(0, arr); - - NDArray result; - if(dim == 0){ - result = NDArrayFactory::create('c', {cols}); - } else { - result = NDArrayFactory::create('c', {rows}); - } - ctx->setOutputArray(0, result); - return ctx; - }; - - std::string s5("Argmax Along Dimension - "); - s5 += std::to_string(length[i]); - - sd::ops::argmax opArgmax; - DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); - output += helper.runOperationSuit(&dbArgmax, generator3, batch, s5.c_str()); - } - return output; + }; + + TransformBenchmark tb2(transform::AnyOps::Assign, "assign_nchw"); + output += helper.runOperationSuit(&tb2, generator2, batch2, + "nchw->nhwc and nhwc->nchw Assign FP32"); + return output; +} + +template +static std::string gemmBenchmark() { + std::string output; + output += "gemm " + DataTypeUtils::asString(DataTypeUtils::fromT()); + BenchmarkHelper helper(WARMUP, NUM_ITER); + + for (int o = 0; o <= 1; o++) { + char resultOrder = (o == 0 ? 'f' : 'c'); + IntPowerParameters sz("sz", 2, 4, 10, + 2); // 2^4=16, ..., 2^10=1024 -> 4 elements + + ParametersBatch b({&sz}); + + auto generator = PARAMETRIC_XYZ() { + auto a = p.getIntParam("sz"); + auto b = p.getIntParam("sz"); + auto c = p.getIntParam("sz"); + std::vector shapeA; + std::vector shapeB; + shapeA = {a, b}; + shapeB = {b, c}; + auto A = NDArrayFactory::create('c', shapeA); + auto B = NDArrayFactory::create('c', shapeB); + auto C = NDArrayFactory::create(resultOrder, {a, c}); + + x.push_back(A); + y.push_back(B); + z.push_back(C); + }; + + std::string n; + n += "Gemm - cOrder="; + n += resultOrder; + + MatrixBenchmark mb(1.0, 0.0, false, false, n); + + output += helper.runOperationSuit(&mb, generator, b, n.c_str()); + } + + return output; +} + +template +static std::string reduceFullBenchmark() { + std::string output; + output += "reduceFullBenchmark " + + DataTypeUtils::asString(DataTypeUtils::fromT()); + + BenchmarkHelper helper(WARMUP, NUM_ITER); + + IntPowerParameters length("length", 2, 8, 20, 4); // 2^8, 2^12, 2^16, 2^20 + + ParametersBatch batch({&length}); + + auto generator = PARAMETRIC_XYZ() { + auto arr = NDArrayFactory::create('c', {p.getIntParam("length")}); + + x.push_back(arr); + y.push_back(NDArray()); + z.push_back(NDArrayFactory::create(0.0f)); + }; + + ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); + ReductionBenchmark rbProd(reduce::SameOps::Prod, "prod"); + ReductionBenchmark rbMax(reduce::SameOps::Max, "max"); + + output += helper.runOperationSuit( + &rbSum, + (const std::function)(generator), + batch, "Sum - Full Array Reduction"); + output += helper.runOperationSuit( + &rbProd, + (const std::function)(generator), + batch, "Product - Full Array Reduction"); + output += helper.runOperationSuit( + &rbMax, + (const std::function)(generator), + batch, "Maximum - Full Array Reduction"); + + // Index reduction + sd::ops::argmax opArgmax; + DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + + ctx->setInputArray( + 0, NDArrayFactory::create('c', {p.getIntParam("length")})); + ctx->setInputArray(1, NDArrayFactory::create((Nd4jLong)0)); + ctx->setOutputArray(0, NDArrayFactory::create(0)); + + return ctx; + }; + output += helper.runOperationSuit(&dbArgmax, generator3, batch, + "Argmax Full Array Reduction"); + return output; +} + +template +static std::string reduceDimBenchmark() { + std::string output; + output += "reduceDimBenchmark " + + DataTypeUtils::asString(DataTypeUtils::fromT()); + + BenchmarkHelper helper(WARMUP, NUM_ITER); + + int length[] = {1024 * 1024}; + int pow[] = {10}; + + for (int i = 0; i < 1; i++) { + IntPowerParameters rows("rows", 2, 0, pow[i], 2); + BoolParameters dim("dim"); + + ParametersBatch batch({&rows, &dim}); + + auto generator = PARAMETRIC_XYZ() { + int rows = p.getIntParam("rows"); + int cols = length[i] / rows; + int dim = p.getIntParam("dim"); + auto arr = NDArrayFactory::create('c', {rows, cols}); + + x.push_back(arr); + y.push_back(NDArrayFactory::create(dim)); + + NDArray result; + if (dim == 0) { + result = NDArrayFactory::create('c', {cols}); + } else { + result = NDArrayFactory::create('c', {rows}); + } + z.push_back(result); + }; + + ReductionBenchmark rbSum(reduce::SameOps::Sum, "sum"); + ReductionBenchmark rbMax(reduce::SameOps::Max, "max"); + + std::string s1("Sum Along Dimension - "); + s1 += std::to_string(length[i]); + std::string s3("Maximum Along Dimension - "); + s3 += std::to_string(length[i]); + + output += helper.runOperationSuit( + &rbSum, + (const std::function)(generator), + batch, s1.c_str()); + output += helper.runOperationSuit( + &rbMax, + (const std::function)(generator), + batch, s3.c_str()); + + auto generator3 = PARAMETRIC_D() { + auto ctx = new Context(1); + int rows = p.getIntParam("rows"); + int cols = length[i] / rows; + int dim = p.getIntParam("dim"); + auto arr = NDArrayFactory::create('c', {rows, cols}); + + auto dimArg = new Nd4jLong[1]; + dimArg[0] = dim; + ctx->setIArguments(dimArg, 1); + delete[] dimArg; + + ctx->setInputArray(0, arr); + + NDArray result; + if (dim == 0) { + result = NDArrayFactory::create('c', {cols}); + } else { + result = NDArrayFactory::create('c', {rows}); + } + ctx->setOutputArray(0, result); + return ctx; + }; + + std::string s5("Argmax Along Dimension - "); + s5 += std::to_string(length[i]); + + sd::ops::argmax opArgmax; + DeclarableBenchmark dbArgmax(opArgmax, "Argmax"); + output += helper.runOperationSuit(&dbArgmax, generator3, batch, s5.c_str()); + } + return output; +} + +template +static std::string conv2d() { + std::string output; + output += "conv2d " + DataTypeUtils::asString(DataTypeUtils::fromT()); + BenchmarkHelper helper(WARMUP, NUM_ITER); + + // Convolution2D op + BoolParameters nhwc("nhwc"); + PredefinedParameters k("k", {2, 3}); + + ParametersBatch batch({&nhwc, &k}); + sd::ops::conv2d conv2d; + DeclarableBenchmark benchmark(conv2d, "conv2d"); + + int hw = 64; + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int n = p.getIntParam("nhwc"); + int khw = p.getIntParam("k"); + + if (n == 0) { + auto input = NDArrayFactory::create('c', {8, 3, hw, hw}); + auto output = NDArrayFactory::create('c', {8, 3, hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + } else { + auto input = NDArrayFactory::create('c', {8, hw, hw, 3}); + auto output = NDArrayFactory::create('c', {8, hw, hw, 3}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } - template - static std::string conv2d(){ - std::string output; - output += "conv2d " + DataTypeUtils::asString(DataTypeUtils::fromT()); - BenchmarkHelper helper(WARMUP, NUM_ITER); - - //Convolution2D op - BoolParameters nhwc("nhwc"); - PredefinedParameters k("k", {2, 3}); - - ParametersBatch batch({&nhwc, &k}); - sd::ops::conv2d conv2d; - DeclarableBenchmark benchmark(conv2d, "conv2d"); - - int hw = 64; - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int n = p.getIntParam("nhwc"); - int khw = p.getIntParam("k"); - - if (n == 0) { - auto input = NDArrayFactory::create('c', {8, 3, hw, hw}); - auto output = NDArrayFactory::create('c', {8, 3, hw, hw}); - ctx->setInputArray(0, input); - ctx->setOutputArray(0, output); - } else { - auto input = NDArrayFactory::create('c', {8, hw, hw, 3}); - auto output = NDArrayFactory::create('c', {8, hw, hw, 3}); - ctx->setInputArray(0, input); - ctx->setOutputArray(0, output); - } - - auto b = NDArrayFactory::create('c', {3}); - auto w = NDArrayFactory::create('c', {khw, khw, 3, 3}); // [kH, kW, iC, oC] always - - ctx->setInputArray(1, w); - ctx->setInputArray(2, b); - - auto args = new Nd4jLong[10]; - args[0] = args[1] = khw; //Kernel - args[2] = args[3] = 1;//Stride - args[4] = args[5] = 0; //Pad - args[6] = args[7] = 1; //Dilation - args[8] = 1; //SAME - args[9] = n;//0-nchw, 1=nhwc - ctx->setIArguments(args, 10); - delete[] args; - - return ctx; - }; - - output += helper.runOperationSuit(&benchmark, generator, batch, "Conv2d"); - return output; + auto b = NDArrayFactory::create('c', {3}); + auto w = NDArrayFactory::create( + 'c', {khw, khw, 3, 3}); // [kH, kW, iC, oC] always + + ctx->setInputArray(1, w); + ctx->setInputArray(2, b); + + auto args = new Nd4jLong[10]; + args[0] = args[1] = khw; // Kernel + args[2] = args[3] = 1; // Stride + args[4] = args[5] = 0; // Pad + args[6] = args[7] = 1; // Dilation + args[8] = 1; // SAME + args[9] = n; // 0-nchw, 1=nhwc + ctx->setIArguments(args, 10); + delete[] args; + + return ctx; + }; + + output += helper.runOperationSuit(&benchmark, generator, batch, "Conv2d"); + return output; +} + +template +static std::string pool2d() { + std::string output; + output += "pool2d " + DataTypeUtils::asString(DataTypeUtils::fromT()); + BenchmarkHelper helper(WARMUP, NUM_ITER); + + // Convolution2D op + BoolParameters nhwc("nhwc"); + PredefinedParameters k("k", {2, 3}); + + ParametersBatch batch({&nhwc, &k}); + + int c = 3; + int hw = 64; + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int n = p.getIntParam("nhwc"); + int khw = p.getIntParam("k"); + + if (n == 0) { + auto input = NDArrayFactory::create('c', {8, c, hw, hw}); + auto output = NDArrayFactory::create('c', {8, c, hw, hw}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); + } else { + auto input = NDArrayFactory::create('c', {8, hw, hw, c}); + auto output = NDArrayFactory::create('c', {8, hw, hw, c}); + ctx->setInputArray(0, input); + ctx->setOutputArray(0, output); } - template - static std::string pool2d() { - std::string output; - output += "pool2d " + DataTypeUtils::asString(DataTypeUtils::fromT()); - BenchmarkHelper helper(WARMUP, NUM_ITER); - - //Convolution2D op - BoolParameters nhwc("nhwc"); - PredefinedParameters k("k", {2, 3}); - - ParametersBatch batch({&nhwc, &k}); - - int c = 3; - int hw = 64; - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int n = p.getIntParam("nhwc"); - int khw = p.getIntParam("k"); - - if (n == 0) { - auto input = NDArrayFactory::create('c', {8, c, hw, hw}); - auto output = NDArrayFactory::create('c', {8, c, hw, hw}); - ctx->setInputArray(0, input); - ctx->setOutputArray(0, output); - } else { - auto input = NDArrayFactory::create('c', {8, hw, hw, c}); - auto output = NDArrayFactory::create('c', {8, hw, hw, c}); - ctx->setInputArray(0, input); - ctx->setOutputArray(0, output); - } - - auto args = new Nd4jLong[11]; - args[0] = args[1] = khw; //Kernel - args[2] = args[3] = 1;//Stride - args[4] = args[5] = 0; //Pad - args[6] = args[7] = 1; //Dilation - args[8] = 1; //SAME - args[9] = 0; //Divisor mode - 0 = exclude padding in divisor - args[10] = n;//0-nchw, 1=nhwc - ctx->setIArguments(args, 11); - delete[] args; - - return ctx; - }; - - sd::ops::avgpool2d avgpool2d; - DeclarableBenchmark benchmark1(avgpool2d, "avgpool"); - output += helper.runOperationSuit(&benchmark1, generator, batch, "Average Pool 2d"); - - sd::ops::maxpool2d maxpool2d; - DeclarableBenchmark benchmark2(maxpool2d, "maxpool"); - output += helper.runOperationSuit(&benchmark2, generator, batch, "Max Pool 2d"); - return output; + auto args = new Nd4jLong[11]; + args[0] = args[1] = khw; // Kernel + args[2] = args[3] = 1; // Stride + args[4] = args[5] = 0; // Pad + args[6] = args[7] = 1; // Dilation + args[8] = 1; // SAME + args[9] = 0; // Divisor mode - 0 = exclude padding in divisor + args[10] = n; // 0-nchw, 1=nhwc + ctx->setIArguments(args, 11); + delete[] args; + + return ctx; + }; + + sd::ops::avgpool2d avgpool2d; + DeclarableBenchmark benchmark1(avgpool2d, "avgpool"); + output += + helper.runOperationSuit(&benchmark1, generator, batch, "Average Pool 2d"); + + sd::ops::maxpool2d maxpool2d; + DeclarableBenchmark benchmark2(maxpool2d, "maxpool"); + output += + helper.runOperationSuit(&benchmark2, generator, batch, "Max Pool 2d"); + return output; +} + +template +static std::string lstmBenchmark() { + std::string output; + output += "lstm " + DataTypeUtils::asString(DataTypeUtils::fromT()); + BenchmarkHelper helper(WARMUP, NUM_ITER); + + BoolParameters format( + "format"); // 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen] + PredefinedParameters mb("mb", {1, 8}); + int n = 128; + + ParametersBatch batch({&format, &mb}); + sd::ops::lstmBlock lstmBlock; + DeclarableBenchmark benchmark(lstmBlock, "lstm"); + + int seqLength = 8; + + auto generator = PARAMETRIC_D() { + auto ctx = new Context(1); + int f = p.getIntParam("format"); + int m = p.getIntParam("mb"); + + Nd4jLong l = 0; + ctx->setInputArray( + 0, NDArrayFactory::create(l)); // Max TS length (unused) + + if (f == 0) { + // TNS format + ctx->setInputArray( + 1, NDArrayFactory::create('c', {seqLength, m, n})); // x + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {seqLength, m, n})); // i + ctx->setOutputArray( + 1, NDArrayFactory::create('c', {seqLength, m, n})); // c + ctx->setOutputArray( + 2, NDArrayFactory::create('c', {seqLength, m, n})); // f + ctx->setOutputArray( + 3, NDArrayFactory::create('c', {seqLength, m, n})); // o + ctx->setOutputArray( + 4, NDArrayFactory::create('c', {seqLength, m, n})); // z + ctx->setOutputArray( + 5, NDArrayFactory::create('c', {seqLength, m, n})); // h + ctx->setOutputArray( + 6, NDArrayFactory::create('c', {seqLength, m, n})); // y + } else { + // NST format + ctx->setInputArray( + 1, NDArrayFactory::create('f', {m, n, seqLength})); // x + ctx->setOutputArray( + 0, NDArrayFactory::create('f', {m, n, seqLength})); // i + ctx->setOutputArray( + 1, NDArrayFactory::create('f', {m, n, seqLength})); // c + ctx->setOutputArray( + 2, NDArrayFactory::create('f', {m, n, seqLength})); // f + ctx->setOutputArray( + 3, NDArrayFactory::create('f', {m, n, seqLength})); // o + ctx->setOutputArray( + 4, NDArrayFactory::create('f', {m, n, seqLength})); // z + ctx->setOutputArray( + 5, NDArrayFactory::create('f', {m, n, seqLength})); // h + ctx->setOutputArray( + 6, NDArrayFactory::create('f', {m, n, seqLength})); // y } - template - static std::string lstmBenchmark() { - std::string output; - output += "lstm " + DataTypeUtils::asString(DataTypeUtils::fromT()); - BenchmarkHelper helper(WARMUP, NUM_ITER); - - BoolParameters format("format"); //0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen] - PredefinedParameters mb("mb", {1, 8}); - int n = 128; - - ParametersBatch batch({&format, &mb}); - sd::ops::lstmBlock lstmBlock; - DeclarableBenchmark benchmark(lstmBlock, "lstm"); - - int seqLength = 8; - - auto generator = PARAMETRIC_D() { - auto ctx = new Context(1); - int f = p.getIntParam("format"); - int m = p.getIntParam("mb"); - - Nd4jLong l = 0; - ctx->setInputArray(0, NDArrayFactory::create(l)); //Max TS length (unused) - - - if (f == 0) { - //TNS format - ctx->setInputArray(1, NDArrayFactory::create('c', {seqLength, m, n})); //x - ctx->setOutputArray(0, NDArrayFactory::create('c', {seqLength, m, n})); //i - ctx->setOutputArray(1, NDArrayFactory::create('c', {seqLength, m, n})); //c - ctx->setOutputArray(2, NDArrayFactory::create('c', {seqLength, m, n})); //f - ctx->setOutputArray(3, NDArrayFactory::create('c', {seqLength, m, n})); //o - ctx->setOutputArray(4, NDArrayFactory::create('c', {seqLength, m, n})); //z - ctx->setOutputArray(5, NDArrayFactory::create('c', {seqLength, m, n})); //h - ctx->setOutputArray(6, NDArrayFactory::create('c', {seqLength, m, n})); //y - } else { - //NST format - ctx->setInputArray(1, NDArrayFactory::create('f', {m, n, seqLength})); //x - ctx->setOutputArray(0, NDArrayFactory::create('f', {m, n, seqLength})); //i - ctx->setOutputArray(1, NDArrayFactory::create('f', {m, n, seqLength})); //c - ctx->setOutputArray(2, NDArrayFactory::create('f', {m, n, seqLength})); //f - ctx->setOutputArray(3, NDArrayFactory::create('f', {m, n, seqLength})); //o - ctx->setOutputArray(4, NDArrayFactory::create('f', {m, n, seqLength})); //z - ctx->setOutputArray(5, NDArrayFactory::create('f', {m, n, seqLength})); //h - ctx->setOutputArray(6, NDArrayFactory::create('f', {m, n, seqLength})); //y - } - - auto cLast = NDArrayFactory::create('c', {m, n}); - auto yLast = NDArrayFactory::create('c', {m, n}); - auto W = NDArrayFactory::create('c', {2 * n, 4 * n}); - auto Wci = NDArrayFactory::create('c', {n}); - auto Wcf = NDArrayFactory::create('c', {n}); - auto Wco = NDArrayFactory::create('c', {n}); - auto b = NDArrayFactory::create('c', {4 * n}); - - ctx->setInputArray(2, cLast); - ctx->setInputArray(3, yLast); - ctx->setInputArray(4, W); - ctx->setInputArray(5, Wci); - ctx->setInputArray(6, Wcf); - ctx->setInputArray(7, Wco); - ctx->setInputArray(8, b); - - auto iargs = new Nd4jLong[2]; - iargs[0] = 0; //No peephole - iargs[1] = f; - ctx->setIArguments(iargs, 2); - delete[] iargs; - - auto targs = new double[2]; - targs[0] = 1.0; //forget bias - targs[1] = 0.0; //cell clipping value - ctx->setTArguments(targs, 2); - delete[] targs; - return ctx; - }; - - output += helper.runOperationSuit(&benchmark, generator, batch, "LSTMBlock"); - return output; + auto cLast = NDArrayFactory::create('c', {m, n}); + auto yLast = NDArrayFactory::create('c', {m, n}); + auto W = NDArrayFactory::create('c', {2 * n, 4 * n}); + auto Wci = NDArrayFactory::create('c', {n}); + auto Wcf = NDArrayFactory::create('c', {n}); + auto Wco = NDArrayFactory::create('c', {n}); + auto b = NDArrayFactory::create('c', {4 * n}); + + ctx->setInputArray(2, cLast); + ctx->setInputArray(3, yLast); + ctx->setInputArray(4, W); + ctx->setInputArray(5, Wci); + ctx->setInputArray(6, Wcf); + ctx->setInputArray(7, Wco); + ctx->setInputArray(8, b); + + auto iargs = new Nd4jLong[2]; + iargs[0] = 0; // No peephole + iargs[1] = f; + ctx->setIArguments(iargs, 2); + delete[] iargs; + + auto targs = new double[2]; + targs[0] = 1.0; // forget bias + targs[1] = 0.0; // cell clipping value + ctx->setTArguments(targs, 2); + delete[] targs; + return ctx; + }; + + output += helper.runOperationSuit(&benchmark, generator, batch, "LSTMBlock"); + return output; +} + +static std::string broadcast2d() { + std::string output; + BenchmarkHelper helper(WARMUP, NUM_ITER); + + int rows = 65536; + IntPowerParameters cols( + "cols", 2, 2, 12, + 4); // 2^2 to 2^12 in steps of 2 - 2^1=2, ..., 2^10=1024 + BoolParameters axis("axis"); + BoolParameters inplace("inplace"); + + ParametersBatch batch({&cols, &axis, &inplace}); + + auto generator = PARAMETRIC_D() { + auto a = p.getIntParam("axis"); + auto arr = + NDArrayFactory::create('c', {rows, p.getIntParam("cols")}); + + auto ctx = new Context(1); + ctx->setInputArray(0, arr); + if (a == 0) { + ctx->setInputArray(1, NDArrayFactory::create('c', {rows, 1})); + } else { + ctx->setInputArray( + 1, NDArrayFactory::create('c', {1, p.getIntParam("cols")})); } - - static std::string broadcast2d() { - std::string output; - BenchmarkHelper helper(WARMUP, NUM_ITER); - - int rows = 65536; - IntPowerParameters cols("cols", 2, 2, 12, 4); //2^2 to 2^12 in steps of 2 - 2^1=2, ..., 2^10=1024 - BoolParameters axis("axis"); - BoolParameters inplace("inplace"); - - ParametersBatch batch({&cols, &axis, &inplace}); - - auto generator = PARAMETRIC_D() { - auto a = p.getIntParam("axis"); - auto arr = NDArrayFactory::create('c', {rows, p.getIntParam("cols")}); - - auto ctx = new Context(1); - ctx->setInputArray(0, arr); - if(a == 0){ - ctx->setInputArray(1, NDArrayFactory::create('c', {rows, 1})); - } else { - ctx->setInputArray(1, NDArrayFactory::create('c', {1, p.getIntParam("cols")})); - } - if (p.getIntParam("inplace") == 1) { - ctx->setOutputArray(0, arr); - ctx->markInplace(true); - } else { - ctx->setOutputArray(0, NDArrayFactory::create('c', {rows, p.getIntParam("cols")})); - } - return ctx; - }; - - std::string s("add"); - sd::ops::add op; - DeclarableBenchmark benchmark(op, "add"); - output += helper.runOperationSuit(&benchmark, generator, batch, "Broadcast (Custom) Add - 2d"); - return output; + if (p.getIntParam("inplace") == 1) { + ctx->setOutputArray(0, arr); + ctx->markInplace(true); + } else { + ctx->setOutputArray( + 0, NDArrayFactory::create('c', {rows, p.getIntParam("cols")})); } - - std::string LightBenchmarkSuit::runSuit() { + return ctx; + }; + + std::string s("add"); + sd::ops::add op; + DeclarableBenchmark benchmark(op, "add"); + output += helper.runOperationSuit(&benchmark, generator, batch, + "Broadcast (Custom) Add - 2d"); + return output; +} + +std::string LightBenchmarkSuit::runSuit() { #ifdef RELEASE_BUILD - std::vector dtypes({sd::DataType::FLOAT32, sd::DataType::HALF}); + std::vector dtypes({sd::DataType::FLOAT32, sd::DataType::HALF}); #else - std::vector dtypes({sd::DataType::FLOAT32}); + std::vector dtypes({sd::DataType::FLOAT32}); #endif - std::string result; - - for (auto t:dtypes) { - nd4j_printf("Running LightBenchmarkSuite.transformBenchmark [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += transformBenchmark, (), LIBND4J_TYPES); + std::string result; - nd4j_printf("Running LightBenchmarkSuite.scalarBenchmark [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += scalarBenchmark, (), LIBND4J_TYPES); + for (auto t : dtypes) { + nd4j_printf("Running LightBenchmarkSuite.transformBenchmark [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += transformBenchmark, (), LIBND4J_TYPES); - nd4j_printf("Running LightBenchmarkSuite.pairwiseBenchmark [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += pairwiseBenchmark, (), LIBND4J_TYPES); + nd4j_printf("Running LightBenchmarkSuite.scalarBenchmark [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += scalarBenchmark, (), LIBND4J_TYPES); - nd4j_printf("Running LightBenchmarkSuite.reduceFullBenchmark [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += reduceFullBenchmark, (), LIBND4J_TYPES); + nd4j_printf("Running LightBenchmarkSuite.pairwiseBenchmark [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += pairwiseBenchmark, (), LIBND4J_TYPES); - nd4j_printf("Running LightBenchmarkSuite.reduceDimBenchmark [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += reduceDimBenchmark, (), LIBND4J_TYPES); + nd4j_printf("Running LightBenchmarkSuite.reduceFullBenchmark [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += reduceFullBenchmark, (), LIBND4J_TYPES); - nd4j_printf("Running LightBenchmarkSuite.gemmBenchmark [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += gemmBenchmark, (), LIBND4J_TYPES); + nd4j_printf("Running LightBenchmarkSuite.reduceDimBenchmark [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += reduceDimBenchmark, (), LIBND4J_TYPES); - nd4j_printf("Running LightBenchmarkSuite.conv2d [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += conv2d, (), LIBND4J_TYPES); + nd4j_printf("Running LightBenchmarkSuite.gemmBenchmark [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += gemmBenchmark, (), LIBND4J_TYPES); - nd4j_printf("Running LightBenchmarkSuite.pool2d [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += pool2d, (), LIBND4J_TYPES); + nd4j_printf("Running LightBenchmarkSuite.conv2d [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += conv2d, (), LIBND4J_TYPES); - nd4j_printf("Running LightBenchmarkSuite.lstmBenchmark [%s]\n", DataTypeUtils::asString(t).c_str()); - BUILD_SINGLE_SELECTOR(t, result += lstmBenchmark, (), LIBND4J_TYPES); + nd4j_printf("Running LightBenchmarkSuite.pool2d [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += pool2d, (), LIBND4J_TYPES); - } + nd4j_printf("Running LightBenchmarkSuite.lstmBenchmark [%s]\n", + DataTypeUtils::asString(t).c_str()); + BUILD_SINGLE_SELECTOR(t, result += lstmBenchmark, (), LIBND4J_TYPES); + } - nd4j_printf("Running LightBenchmarkSuite.broadcast2d\n", ""); - result += broadcast2d(); - nd4j_printf("Running LightBenchmarkSuite.mismatchedOrderAssign\n", ""); - result += mismatchedOrderAssign(); + nd4j_printf("Running LightBenchmarkSuite.broadcast2d\n", ""); + result += broadcast2d(); + nd4j_printf("Running LightBenchmarkSuite.mismatchedOrderAssign\n", ""); + result += mismatchedOrderAssign(); - return result; - } -} \ No newline at end of file + return result; +} +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/samediff.h b/libnd4j/include/samediff.h index 4907c980256f..71a058f45d88 100644 --- a/libnd4j/include/samediff.h +++ b/libnd4j/include/samediff.h @@ -30,10 +30,10 @@ #include // basic Graph-related includes -#include #include +#include // ML ops includes #include -#endif //S_SAMEDIFF_H +#endif // S_SAMEDIFF_H diff --git a/libnd4j/include/system/BlasVersionHelper.h b/libnd4j/include/system/BlasVersionHelper.h index cee95d26bf98..1878db24a370 100644 --- a/libnd4j/include/system/BlasVersionHelper.h +++ b/libnd4j/include/system/BlasVersionHelper.h @@ -21,20 +21,20 @@ #ifndef SAMEDIFF_BLASVERSIONHELPER_H #define SAMEDIFF_BLASVERSIONHELPER_H -#include #include #include +#include namespace sd { - class SD_EXPORT BlasVersionHelper { - public: - int _blasMajorVersion = 0; - int _blasMinorVersion = 0; - int _blasPatchVersion = 0; +class SD_EXPORT BlasVersionHelper { + public: + int _blasMajorVersion = 0; + int _blasMinorVersion = 0; + int _blasPatchVersion = 0; - BlasVersionHelper(); - ~BlasVersionHelper() = default; - }; -} + BlasVersionHelper(); + ~BlasVersionHelper() = default; +}; +} // namespace sd -#endif //SD_BLASVERSIONHELPER_H +#endif // SD_BLASVERSIONHELPER_H diff --git a/libnd4j/include/system/Environment.h b/libnd4j/include/system/Environment.h index 1369f99506b6..d2e174de77ef 100644 --- a/libnd4j/include/system/Environment.h +++ b/libnd4j/include/system/Environment.h @@ -21,131 +21,132 @@ #ifndef LIBND4J_ENVIRONMENT_H #define LIBND4J_ENVIRONMENT_H -#include -#include -#include -#include #include -#include +#include #include +#include + +#include +#include +#include -namespace sd{ - class SD_EXPORT Environment { - private: - std::atomic _tadThreshold; - std::atomic _elementThreshold; - std::atomic _verbose; - std::atomic _debug; - std::atomic _leaks; - std::atomic _profile; - std::atomic _dataType; - std::atomic _precBoost; - std::atomic _useMKLDNN{true}; - std::atomic _allowHelpers{true}; - - std::atomic _maxThreads; - std::atomic _maxMasterThreads; - - // these fields hold defaults - std::atomic _maxTotalPrimaryMemory{-1}; - std::atomic _maxTotalSpecialMemory{-1}; - std::atomic _maxDeviceMemory{-1}; - - bool _blasFallback = false; +namespace sd { +class SD_EXPORT Environment { + private: + std::atomic _tadThreshold; + std::atomic _elementThreshold; + std::atomic _verbose; + std::atomic _debug; + std::atomic _leaks; + std::atomic _profile; + std::atomic _dataType; + std::atomic _precBoost; + std::atomic _useMKLDNN{true}; + std::atomic _allowHelpers{true}; + + std::atomic _maxThreads; + std::atomic _maxMasterThreads; + + // these fields hold defaults + std::atomic _maxTotalPrimaryMemory{-1}; + std::atomic _maxTotalSpecialMemory{-1}; + std::atomic _maxDeviceMemory{-1}; + + bool _blasFallback = false; #ifdef __ND4J_EXPERIMENTAL__ - const bool _experimental = true; + const bool _experimental = true; #else - const bool _experimental = false; + const bool _experimental = false; #endif - // device compute capability for CUDA - std::vector _capabilities; + // device compute capability for CUDA + std::vector _capabilities; + + static Environment* _instance; - static Environment* _instance; + Environment(); + ~Environment(); - Environment(); - ~Environment(); - public: - /** - * These 3 fields are mostly for CUDA/cuBLAS version tracking - */ - int _blasMajorVersion = 0; - int _blasMinorVersion = 0; - int _blasPatchVersion = 0; + public: + /** + * These 3 fields are mostly for CUDA/cuBLAS version tracking + */ + int _blasMajorVersion = 0; + int _blasMinorVersion = 0; + int _blasPatchVersion = 0; - static Environment* getInstance(); + static Environment* getInstance(); - bool isVerbose(); - void setVerbose(bool reallyVerbose); - bool isDebug(); - bool isProfiling(); - bool isDetectingLeaks(); - bool isDebugAndVerbose(); - void setDebug(bool reallyDebug); - void setProfiling(bool reallyProfile); - void setLeaksDetector(bool reallyDetect); - bool helpersAllowed(); - void allowHelpers(bool reallyAllow); + bool isVerbose(); + void setVerbose(bool reallyVerbose); + bool isDebug(); + bool isProfiling(); + bool isDetectingLeaks(); + bool isDebugAndVerbose(); + void setDebug(bool reallyDebug); + void setProfiling(bool reallyProfile); + void setLeaksDetector(bool reallyDetect); + bool helpersAllowed(); + void allowHelpers(bool reallyAllow); - bool blasFallback(); - - int tadThreshold(); - void setTadThreshold(int threshold); + bool blasFallback(); - int elementwiseThreshold(); - void setElementwiseThreshold(int threshold); + int tadThreshold(); + void setTadThreshold(int threshold); - int maxThreads(); - void setMaxThreads(int max); + int elementwiseThreshold(); + void setElementwiseThreshold(int threshold); - int maxMasterThreads(); - void setMaxMasterThreads(int max); + int maxThreads(); + void setMaxThreads(int max); - /* - * Legacy memory limits API, still used in new API as simplified version - */ - void setMaxPrimaryMemory(uint64_t maxBytes); - void setMaxSpecialyMemory(uint64_t maxBytes); - void setMaxDeviceMemory(uint64_t maxBytes); + int maxMasterThreads(); + void setMaxMasterThreads(int max); - uint64_t maxPrimaryMemory(); - uint64_t maxSpecialMemory(); - //////////////////////// + /* + * Legacy memory limits API, still used in new API as simplified version + */ + void setMaxPrimaryMemory(uint64_t maxBytes); + void setMaxSpecialyMemory(uint64_t maxBytes); + void setMaxDeviceMemory(uint64_t maxBytes); - /* - * Methods for memory limits/counters - */ - void setGroupLimit(int group, Nd4jLong numBytes); - void setDeviceLimit(int deviceId, Nd4jLong numBytes); + uint64_t maxPrimaryMemory(); + uint64_t maxSpecialMemory(); + //////////////////////// - Nd4jLong getGroupLimit(int group); - Nd4jLong getDeviceLimit(int deviceId); + /* + * Methods for memory limits/counters + */ + void setGroupLimit(int group, Nd4jLong numBytes); + void setDeviceLimit(int deviceId, Nd4jLong numBytes); - Nd4jLong getGroupCounter(int group); - Nd4jLong getDeviceCounter(int deviceId); - //////////////////////// + Nd4jLong getGroupLimit(int group); + Nd4jLong getDeviceLimit(int deviceId); - bool isUseMKLDNN() { return _useMKLDNN.load(); } - void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN.store(useMKLDNN); } + Nd4jLong getGroupCounter(int group); + Nd4jLong getDeviceCounter(int deviceId); + //////////////////////// - sd::DataType defaultFloatDataType(); - void setDefaultFloatDataType(sd::DataType dtype); + bool isUseMKLDNN() { return _useMKLDNN.load(); } + void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN.store(useMKLDNN); } - bool precisionBoostAllowed(); - void allowPrecisionBoost(bool reallyAllow); + sd::DataType defaultFloatDataType(); + void setDefaultFloatDataType(sd::DataType dtype); - bool isExperimentalBuild(); + bool precisionBoostAllowed(); + void allowPrecisionBoost(bool reallyAllow); - bool isCPU(); + bool isExperimentalBuild(); - int blasMajorVersion(); - int blasMinorVersion(); - int blasPatchVersion(); + bool isCPU(); - std::vector& capabilities(); - }; -} + int blasMajorVersion(); + int blasMinorVersion(); + int blasPatchVersion(); + std::vector& capabilities(); +}; +} // namespace sd -#endif //LIBND4J_ENVIRONMENT_H +#endif // LIBND4J_ENVIRONMENT_H diff --git a/libnd4j/include/system/buffer.h b/libnd4j/include/system/buffer.h old mode 100755 new mode 100644 index 5072965ca9bd..d54ebe441343 --- a/libnd4j/include/system/buffer.h +++ b/libnd4j/include/system/buffer.h @@ -28,62 +28,56 @@ #include #include #endif -#include - #include -#include #include +#include #include - //Question: Should the indexes here really be int? Isn't size_t or Nd4jLong more appropriate? +// Question: Should the indexes here really be int? Isn't size_t or Nd4jLong +// more appropriate? namespace sd { - namespace buffer { +namespace buffer { /** * Represents both a cpu and gpu * buffer - mainly used for testing */ - template - struct Buffer { - int length = 0; - int allocatedOnGpu = 0; - T *data = nullptr; - T *gData = nullptr; - T one, two; - public: - ~Buffer() { - delete []data; - delete []gData; - } - - void assign(T *val) { - data = val; - } - - T &operator=(T x) { - one = x; - return x; - } - - class Proxy { - Buffer &a; - int idx; - public: - Proxy(Buffer &a, int idx) : - a(a), idx(idx) { - } - - T &operator=(T x) { - a.two = x; - a.data[idx] = x; - return a.data[idx]; - } - }; - - - Proxy operator[](int index) { - return Proxy(*this, index); - } - }; +template +struct Buffer { + int length = 0; + int allocatedOnGpu = 0; + T *data = nullptr; + T *gData = nullptr; + T one, two; + + public: + ~Buffer() { + delete[] data; + delete[] gData; + } + + void assign(T *val) { data = val; } + + T &operator=(T x) { + one = x; + return x; + } + + class Proxy { + Buffer &a; + int idx; + + public: + Proxy(Buffer &a, int idx) : a(a), idx(idx) {} + + T &operator=(T x) { + a.two = x; + a.data[idx] = x; + return a.data[idx]; + } + }; + + Proxy operator[](int index) { return Proxy(*this, index); } +}; /** * Returns the size of the buffer @@ -91,13 +85,14 @@ namespace sd { * @param buffer the buffer to get the size of * @return the size of the buffer in bytes */ - template +template #ifdef __CUDACC__ - __host__ __device__ +__host__ __device__ #endif - int bufferSize(Buffer *buffer); + int + bufferSize(Buffer *buffer); /** * Copies data to the gpu @@ -105,45 +100,41 @@ namespace sd { */ #ifdef __CUDACC__ - template - __host__ - void copyDataToGpu(Buffer **buffer, cudaStream_t stream); +template +__host__ void copyDataToGpu(Buffer **buffer, cudaStream_t stream); #endif - - /** * Copies data from the gpu * @param buffer the buffer to copy */ #ifdef __CUDACC__ - template - __host__ - void copyDataFromGpu(Buffer **buffer, cudaStream_t stream); +template +__host__ void copyDataFromGpu(Buffer **buffer, cudaStream_t stream); #endif - - /** * Allocate buffer of the given * length on the cpu and gpu. */ - template +template #ifdef __CUDACC__ - __host__ +__host__ #endif - void allocBuffer(Buffer **buffer, int length); + void + allocBuffer(Buffer **buffer, int length); /** * Frees the given buffer * (gpu and cpu */ - template +template #ifdef __CUDACC__ - __host__ +__host__ #endif - void freeBuffer(Buffer **buffer); + void + freeBuffer(Buffer **buffer); /** * Creates a buffer @@ -151,60 +142,65 @@ namespace sd { * and also synchronizes * the data on the gpu. */ - template +template #ifdef __CUDACC__ - __host__ +__host__ #endif - Buffer - * - createBuffer(T *data, int length); + Buffer + *createBuffer(T *data, int length); /** * Print the buffer on the host * @param buff */ - template +template #ifdef __CUDACC__ - __host__ +__host__ #endif - void printArr(Buffer *buff); + void + printArr(Buffer *buff); /** * * @param buffer * @return */ - template +template #ifdef __CUDACC__ - __host__ __device__ +__host__ __device__ #endif - int bufferSize(Buffer *buffer) { - return sizeof(T) * buffer->length; - } + int + bufferSize(Buffer *buffer) { + return sizeof(T) * buffer->length; +} #ifdef __CUDACC__ - /** +/** * * @param buffer */ -template -__host__ void copyDataToGpu(Buffer **buffer, cudaStream_t stream) { - Buffer *bufferRef = *buffer; - checkCudaErrors(cudaMemcpyAsync(bufferRef->gData, bufferRef->data, bufferSize(bufferRef), cudaMemcpyHostToDevice, stream)); - checkCudaErrors(cudaStreamSynchronize(stream)); +template +__host__ void copyDataToGpu(Buffer **buffer, cudaStream_t stream) { + Buffer *bufferRef = *buffer; + checkCudaErrors(cudaMemcpyAsync(bufferRef->gData, bufferRef->data, + bufferSize(bufferRef), cudaMemcpyHostToDevice, + stream)); + checkCudaErrors(cudaStreamSynchronize(stream)); } /** * * @param buffer */ -template -__host__ void copyDataFromGpu(Buffer **buffer, cudaStream_t stream) { - Buffer *bufferRef = *buffer; - int bufferTotalSize = bufferSize(bufferRef); - checkCudaErrors(cudaMemcpyAsync(bufferRef->data, bufferRef->gData, bufferTotalSize, cudaMemcpyDeviceToHost, stream)); - checkCudaErrors(cudaStreamSynchronize(stream)); +template +__host__ void copyDataFromGpu(Buffer **buffer, cudaStream_t stream) { + Buffer *bufferRef = *buffer; + int bufferTotalSize = bufferSize(bufferRef); + checkCudaErrors(cudaMemcpyAsync(bufferRef->data, bufferRef->gData, + bufferTotalSize, cudaMemcpyDeviceToHost, + stream)); + checkCudaErrors(cudaStreamSynchronize(stream)); } #endif @@ -212,38 +208,40 @@ __host__ void copyDataFromGpu(Buffer **buffer, cudaStream_t stream) { * Allocate buffer of the given * length on the cpu and gpu. */ - template +template #ifdef __CUDACC__ - __host__ +__host__ #endif - void allocBuffer(Buffer **buffer, int length) { - Buffer *bufferRef = *buffer; - bufferRef->length = length; - bufferRef->data = reinterpret_cast(malloc(sizeof(T) * length)); - - CHECK_ALLOC(bufferRef->data, "Failed to allocate new buffer", sizeof(T) * length); + void + allocBuffer(Buffer **buffer, int length) { + Buffer *bufferRef = *buffer; + bufferRef->length = length; + bufferRef->data = reinterpret_cast(malloc(sizeof(T) * length)); + + CHECK_ALLOC(bufferRef->data, "Failed to allocate new buffer", + sizeof(T) * length); #ifdef __CUDACC__ - checkCudaErrors(cudaMalloc(&bufferRef->gData, sizeof(T) * length)); + checkCudaErrors(cudaMalloc(&bufferRef->gData, sizeof(T) * length)); #endif - } +} /** * Frees the given buffer * (gpu and cpu */ - template +template #ifdef __CUDACC__ - __host__ +__host__ #endif - void freeBuffer(Buffer *buffer) { + void + freeBuffer(Buffer *buffer) { #ifdef __CUDACC__ - if(buffer->gData != nullptr) - checkCudaErrors(cudaFree(buffer->gData)); + if (buffer->gData != nullptr) checkCudaErrors(cudaFree(buffer->gData)); #endif - delete buffer; - } + delete buffer; +} /** * Creates a buffer @@ -251,47 +249,44 @@ __host__ void copyDataFromGpu(Buffer **buffer, cudaStream_t stream) { * and also synchronizes * the data on the gpu. */ - template +template #ifdef __CUDACC__ - __host__ +__host__ #endif - Buffer *createBuffer(T *data, int length) { - Buffer *ret = new Buffer; - T *buffData = new T[length]; - for(int i = 0; i < length; i++) - buffData[i] = data[i]; - ret->data = buffData; - ret->length = length; - return ret; - } - - + Buffer + *createBuffer(T *data, int length) { + Buffer *ret = new Buffer; + T *buffData = new T[length]; + for (int i = 0; i < length; i++) buffData[i] = data[i]; + ret->data = buffData; + ret->length = length; + return ret; +} #ifdef __CUDACC__ - template - __host__ - Buffer *createBuffer(T *data, int length, cudaStream_t stream) { - Buffer *ret = createBuffer(data, length); - - T *gData; - T **gDataRef = &(gData); - checkCudaErrors(cudaMalloc(reinterpret_cast(gDataRef), sizeof(T) * length)); - ret->gData = gData; - checkCudaErrors(cudaMemcpyAsync(ret->gData, ret->data, sizeof(T) * length, cudaMemcpyHostToDevice, stream)); - return ret; - } -#endif - } +template +__host__ Buffer *createBuffer(T *data, int length, cudaStream_t stream) { + Buffer *ret = createBuffer(data, length); + + T *gData; + T **gDataRef = &(gData); + checkCudaErrors( + cudaMalloc(reinterpret_cast(gDataRef), sizeof(T) * length)); + ret->gData = gData; + checkCudaErrors(cudaMemcpyAsync(ret->gData, ret->data, sizeof(T) * length, + cudaMemcpyHostToDevice, stream)); + return ret; } - - +#endif +} // namespace buffer +} // namespace sd #ifdef __CUDACC__ -template -__host__ void printArr(sd::buffer::Buffer *buff) { - for (int i = 0; i < buff->length; i++) { - printf("Buffer[%d] was %f\n", i, buff->data[i]); - } +template +__host__ void printArr(sd::buffer::Buffer *buff) { + for (int i = 0; i < buff->length; i++) { + printf("Buffer[%d] was %f\n", i, buff->data[i]); + } } #endif diff --git a/libnd4j/include/system/dll.h b/libnd4j/include/system/dll.h index 63d7c39c4bc1..99dc203a0664 100644 --- a/libnd4j/include/system/dll.h +++ b/libnd4j/include/system/dll.h @@ -25,8 +25,8 @@ #ifdef _WIN32 //#include -# define SD_EXPORT __declspec(dllexport) +#define SD_EXPORT __declspec(dllexport) #else -# define SD_EXPORT +#define SD_EXPORT #endif -#endif //NATIVEOPERATIONS_DLL_H +#endif // NATIVEOPERATIONS_DLL_H diff --git a/libnd4j/include/system/enum_boilerplate.h b/libnd4j/include/system/enum_boilerplate.h index 2acb1536aa37..5b198b7005c1 100644 --- a/libnd4j/include/system/enum_boilerplate.h +++ b/libnd4j/include/system/enum_boilerplate.h @@ -23,113 +23,133 @@ #include - #define EN_1(WHAT, OP_PAIR) WHAT(OP_PAIR) -#define EN_2(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_1(WHAT, __VA_ARGS__)) -#define EN_3(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_2(WHAT, __VA_ARGS__)) -#define EN_4(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_3(WHAT, __VA_ARGS__)) -#define EN_5(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_4(WHAT, __VA_ARGS__)) -#define EN_6(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_5(WHAT, __VA_ARGS__)) -#define EN_7(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_6(WHAT, __VA_ARGS__)) -#define EN_8(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_7(WHAT, __VA_ARGS__)) -#define EN_9(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_8(WHAT, __VA_ARGS__)) -#define EN_10(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_9(WHAT, __VA_ARGS__)) -#define EN_11(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_10(WHAT, __VA_ARGS__)) -#define EN_12(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_11(WHAT, __VA_ARGS__)) -#define EN_13(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_12(WHAT, __VA_ARGS__)) -#define EN_14(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_13(WHAT, __VA_ARGS__)) -#define EN_15(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_14(WHAT, __VA_ARGS__)) -#define EN_16(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_15(WHAT, __VA_ARGS__)) -#define EN_17(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_16(WHAT, __VA_ARGS__)) -#define EN_18(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_17(WHAT, __VA_ARGS__)) -#define EN_19(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_18(WHAT, __VA_ARGS__)) -#define EN_20(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_19(WHAT, __VA_ARGS__)) -#define EN_21(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_20(WHAT, __VA_ARGS__)) -#define EN_22(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_21(WHAT, __VA_ARGS__)) -#define EN_23(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_22(WHAT, __VA_ARGS__)) -#define EN_24(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_23(WHAT, __VA_ARGS__)) -#define EN_25(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_24(WHAT, __VA_ARGS__)) -#define EN_26(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_25(WHAT, __VA_ARGS__)) -#define EN_27(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_26(WHAT, __VA_ARGS__)) -#define EN_28(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_27(WHAT, __VA_ARGS__)) -#define EN_29(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_28(WHAT, __VA_ARGS__)) -#define EN_30(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_29(WHAT, __VA_ARGS__)) -#define EN_31(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_30(WHAT, __VA_ARGS__)) -#define EN_32(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_31(WHAT, __VA_ARGS__)) -#define EN_33(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_32(WHAT, __VA_ARGS__)) -#define EN_34(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_33(WHAT, __VA_ARGS__)) -#define EN_35(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_34(WHAT, __VA_ARGS__)) -#define EN_36(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_35(WHAT, __VA_ARGS__)) -#define EN_37(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_36(WHAT, __VA_ARGS__)) -#define EN_38(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_37(WHAT, __VA_ARGS__)) -#define EN_39(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_38(WHAT, __VA_ARGS__)) -#define EN_40(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_39(WHAT, __VA_ARGS__)) -#define EN_41(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_40(WHAT, __VA_ARGS__)) -#define EN_42(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_41(WHAT, __VA_ARGS__)) -#define EN_43(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_42(WHAT, __VA_ARGS__)) -#define EN_44(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_43(WHAT, __VA_ARGS__)) -#define EN_45(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_44(WHAT, __VA_ARGS__)) -#define EN_46(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_45(WHAT, __VA_ARGS__)) -#define EN_47(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_46(WHAT, __VA_ARGS__)) -#define EN_48(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_47(WHAT, __VA_ARGS__)) -#define EN_49(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_48(WHAT, __VA_ARGS__)) -#define EN_50(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_49(WHAT, __VA_ARGS__)) -#define EN_51(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_50(WHAT, __VA_ARGS__)) -#define EN_52(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_51(WHAT, __VA_ARGS__)) -#define EN_53(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_52(WHAT, __VA_ARGS__)) -#define EN_54(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_53(WHAT, __VA_ARGS__)) -#define EN_55(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_54(WHAT, __VA_ARGS__)) -#define EN_56(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_55(WHAT, __VA_ARGS__)) -#define EN_57(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_56(WHAT, __VA_ARGS__)) -#define EN_58(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_57(WHAT, __VA_ARGS__)) -#define EN_59(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_58(WHAT, __VA_ARGS__)) -#define EN_60(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_59(WHAT, __VA_ARGS__)) -#define EN_61(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_60(WHAT, __VA_ARGS__)) -#define EN_62(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_61(WHAT, __VA_ARGS__)) -#define EN_63(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_62(WHAT, __VA_ARGS__)) -#define EN_64(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_63(WHAT, __VA_ARGS__)) -#define EN_65(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_64(WHAT, __VA_ARGS__)) -#define EN_66(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_65(WHAT, __VA_ARGS__)) -#define EN_67(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_66(WHAT, __VA_ARGS__)) -#define EN_68(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_67(WHAT, __VA_ARGS__)) -#define EN_69(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_68(WHAT, __VA_ARGS__)) -#define EN_70(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_69(WHAT, __VA_ARGS__)) -#define EN_71(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_70(WHAT, __VA_ARGS__)) -#define EN_72(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_71(WHAT, __VA_ARGS__)) -#define EN_73(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_72(WHAT, __VA_ARGS__)) -#define EN_74(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_73(WHAT, __VA_ARGS__)) -#define EN_75(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_74(WHAT, __VA_ARGS__)) -#define EN_76(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_75(WHAT, __VA_ARGS__)) -#define EN_77(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_76(WHAT, __VA_ARGS__)) -#define EN_78(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_77(WHAT, __VA_ARGS__)) -#define EN_79(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_78(WHAT, __VA_ARGS__)) -#define EN_80(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_79(WHAT, __VA_ARGS__)) -#define EN_81(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_80(WHAT, __VA_ARGS__)) -#define EN_82(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_81(WHAT, __VA_ARGS__)) -#define EN_83(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_82(WHAT, __VA_ARGS__)) -#define EN_84(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_83(WHAT, __VA_ARGS__)) -#define EN_85(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_84(WHAT, __VA_ARGS__)) -#define EN_86(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_85(WHAT, __VA_ARGS__)) -#define EN_87(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_86(WHAT, __VA_ARGS__)) -#define EN_88(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_87(WHAT, __VA_ARGS__)) -#define EN_89(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_88(WHAT, __VA_ARGS__)) -#define EN_90(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_89(WHAT, __VA_ARGS__)) -#define EN_91(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_90(WHAT, __VA_ARGS__)) -#define EN_92(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_91(WHAT, __VA_ARGS__)) -#define EN_93(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_92(WHAT, __VA_ARGS__)) -#define EN_94(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_93(WHAT, __VA_ARGS__)) -#define EN_95(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_94(WHAT, __VA_ARGS__)) -#define EN_96(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_95(WHAT, __VA_ARGS__)) -#define EN_97(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_96(WHAT, __VA_ARGS__)) -#define EN_98(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_97(WHAT, __VA_ARGS__)) -#define EN_99(WHAT, OP_PAIR, ...) WHAT(OP_PAIR)EVAL(EN_98(WHAT, __VA_ARGS__)) +#define EN_2(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_1(WHAT, __VA_ARGS__)) +#define EN_3(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_2(WHAT, __VA_ARGS__)) +#define EN_4(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_3(WHAT, __VA_ARGS__)) +#define EN_5(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_4(WHAT, __VA_ARGS__)) +#define EN_6(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_5(WHAT, __VA_ARGS__)) +#define EN_7(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_6(WHAT, __VA_ARGS__)) +#define EN_8(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_7(WHAT, __VA_ARGS__)) +#define EN_9(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_8(WHAT, __VA_ARGS__)) +#define EN_10(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_9(WHAT, __VA_ARGS__)) +#define EN_11(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_10(WHAT, __VA_ARGS__)) +#define EN_12(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_11(WHAT, __VA_ARGS__)) +#define EN_13(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_12(WHAT, __VA_ARGS__)) +#define EN_14(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_13(WHAT, __VA_ARGS__)) +#define EN_15(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_14(WHAT, __VA_ARGS__)) +#define EN_16(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_15(WHAT, __VA_ARGS__)) +#define EN_17(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_16(WHAT, __VA_ARGS__)) +#define EN_18(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_17(WHAT, __VA_ARGS__)) +#define EN_19(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_18(WHAT, __VA_ARGS__)) +#define EN_20(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_19(WHAT, __VA_ARGS__)) +#define EN_21(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_20(WHAT, __VA_ARGS__)) +#define EN_22(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_21(WHAT, __VA_ARGS__)) +#define EN_23(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_22(WHAT, __VA_ARGS__)) +#define EN_24(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_23(WHAT, __VA_ARGS__)) +#define EN_25(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_24(WHAT, __VA_ARGS__)) +#define EN_26(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_25(WHAT, __VA_ARGS__)) +#define EN_27(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_26(WHAT, __VA_ARGS__)) +#define EN_28(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_27(WHAT, __VA_ARGS__)) +#define EN_29(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_28(WHAT, __VA_ARGS__)) +#define EN_30(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_29(WHAT, __VA_ARGS__)) +#define EN_31(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_30(WHAT, __VA_ARGS__)) +#define EN_32(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_31(WHAT, __VA_ARGS__)) +#define EN_33(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_32(WHAT, __VA_ARGS__)) +#define EN_34(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_33(WHAT, __VA_ARGS__)) +#define EN_35(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_34(WHAT, __VA_ARGS__)) +#define EN_36(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_35(WHAT, __VA_ARGS__)) +#define EN_37(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_36(WHAT, __VA_ARGS__)) +#define EN_38(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_37(WHAT, __VA_ARGS__)) +#define EN_39(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_38(WHAT, __VA_ARGS__)) +#define EN_40(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_39(WHAT, __VA_ARGS__)) +#define EN_41(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_40(WHAT, __VA_ARGS__)) +#define EN_42(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_41(WHAT, __VA_ARGS__)) +#define EN_43(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_42(WHAT, __VA_ARGS__)) +#define EN_44(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_43(WHAT, __VA_ARGS__)) +#define EN_45(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_44(WHAT, __VA_ARGS__)) +#define EN_46(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_45(WHAT, __VA_ARGS__)) +#define EN_47(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_46(WHAT, __VA_ARGS__)) +#define EN_48(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_47(WHAT, __VA_ARGS__)) +#define EN_49(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_48(WHAT, __VA_ARGS__)) +#define EN_50(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_49(WHAT, __VA_ARGS__)) +#define EN_51(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_50(WHAT, __VA_ARGS__)) +#define EN_52(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_51(WHAT, __VA_ARGS__)) +#define EN_53(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_52(WHAT, __VA_ARGS__)) +#define EN_54(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_53(WHAT, __VA_ARGS__)) +#define EN_55(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_54(WHAT, __VA_ARGS__)) +#define EN_56(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_55(WHAT, __VA_ARGS__)) +#define EN_57(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_56(WHAT, __VA_ARGS__)) +#define EN_58(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_57(WHAT, __VA_ARGS__)) +#define EN_59(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_58(WHAT, __VA_ARGS__)) +#define EN_60(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_59(WHAT, __VA_ARGS__)) +#define EN_61(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_60(WHAT, __VA_ARGS__)) +#define EN_62(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_61(WHAT, __VA_ARGS__)) +#define EN_63(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_62(WHAT, __VA_ARGS__)) +#define EN_64(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_63(WHAT, __VA_ARGS__)) +#define EN_65(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_64(WHAT, __VA_ARGS__)) +#define EN_66(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_65(WHAT, __VA_ARGS__)) +#define EN_67(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_66(WHAT, __VA_ARGS__)) +#define EN_68(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_67(WHAT, __VA_ARGS__)) +#define EN_69(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_68(WHAT, __VA_ARGS__)) +#define EN_70(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_69(WHAT, __VA_ARGS__)) +#define EN_71(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_70(WHAT, __VA_ARGS__)) +#define EN_72(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_71(WHAT, __VA_ARGS__)) +#define EN_73(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_72(WHAT, __VA_ARGS__)) +#define EN_74(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_73(WHAT, __VA_ARGS__)) +#define EN_75(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_74(WHAT, __VA_ARGS__)) +#define EN_76(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_75(WHAT, __VA_ARGS__)) +#define EN_77(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_76(WHAT, __VA_ARGS__)) +#define EN_78(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_77(WHAT, __VA_ARGS__)) +#define EN_79(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_78(WHAT, __VA_ARGS__)) +#define EN_80(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_79(WHAT, __VA_ARGS__)) +#define EN_81(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_80(WHAT, __VA_ARGS__)) +#define EN_82(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_81(WHAT, __VA_ARGS__)) +#define EN_83(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_82(WHAT, __VA_ARGS__)) +#define EN_84(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_83(WHAT, __VA_ARGS__)) +#define EN_85(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_84(WHAT, __VA_ARGS__)) +#define EN_86(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_85(WHAT, __VA_ARGS__)) +#define EN_87(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_86(WHAT, __VA_ARGS__)) +#define EN_88(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_87(WHAT, __VA_ARGS__)) +#define EN_89(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_88(WHAT, __VA_ARGS__)) +#define EN_90(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_89(WHAT, __VA_ARGS__)) +#define EN_91(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_90(WHAT, __VA_ARGS__)) +#define EN_92(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_91(WHAT, __VA_ARGS__)) +#define EN_93(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_92(WHAT, __VA_ARGS__)) +#define EN_94(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_93(WHAT, __VA_ARGS__)) +#define EN_95(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_94(WHAT, __VA_ARGS__)) +#define EN_96(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_95(WHAT, __VA_ARGS__)) +#define EN_97(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_96(WHAT, __VA_ARGS__)) +#define EN_98(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_97(WHAT, __VA_ARGS__)) +#define EN_99(WHAT, OP_PAIR, ...) WHAT(OP_PAIR) EVAL(EN_98(WHAT, __VA_ARGS__)) #define __EXPAND_ENUM(NUM, CLASS) CLASS = NUM, -#define _EXPAND_ENUM(OP_PAIR) EVALUATING_PASTE(__EXPAND, _ENUM(UNPAREN(OP_PAIR))) +#define _EXPAND_ENUM(OP_PAIR) \ + EVALUATING_PASTE(__EXPAND, _ENUM(UNPAREN(OP_PAIR))) -#define GET_MACROS_ENUM(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, _65, _66, _67, _68, _69, _70, _71, _72, _73, _74, _75, _76, _77, _78, _79, _80, _81, _82, _83, _84, _85, _86, _87, _88, _89, _90, _91, _92, _93, _94, _95, _96, _97, _98, _99, NAME,...) NAME +#define GET_MACROS_ENUM( \ + _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, \ + _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, \ + _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, \ + _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, \ + _62, _63, _64, _65, _66, _67, _68, _69, _70, _71, _72, _73, _74, _75, _76, \ + _77, _78, _79, _80, _81, _82, _83, _84, _85, _86, _87, _88, _89, _90, _91, \ + _92, _93, _94, _95, _96, _97, _98, _99, NAME, ...) \ + NAME -#define FOR_EACH_ENUM(WHAT, ...) EXPAND(GET_MACROS_ENUM(__VA_ARGS__, EN_99, EN_98, EN_97, EN_96, EN_95, EN_94, EN_93, EN_92, EN_91, EN_90, EN_89, EN_88, EN_87, EN_86, EN_85, EN_84, EN_83, EN_82, EN_81, EN_80, EN_79, EN_78, EN_77, EN_76, EN_75, EN_74, EN_73, EN_72, EN_71, EN_70, EN_69, EN_68, EN_67, EN_66, EN_65, EN_64, EN_63, EN_62, EN_61, EN_60, EN_59, EN_58, EN_57, EN_56, EN_55, EN_54, EN_53, EN_52, EN_51, EN_50, EN_49, EN_48, EN_47, EN_46, EN_45, EN_44, EN_43, EN_42, EN_41, EN_40, EN_39, EN_38, EN_37, EN_36, EN_35, EN_34, EN_33, EN_32, EN_31, EN_30, EN_29, EN_28, EN_27, EN_26, EN_25, EN_24, EN_23, EN_22, EN_21, EN_20, EN_19, EN_18, EN_17, EN_16, EN_15, EN_14, EN_13, EN_12, EN_11, EN_10, EN_9, EN_8, EN_7, EN_6, EN_5, EN_4, EN_3, EN_2, EN_1)(WHAT, __VA_ARGS__)) +#define FOR_EACH_ENUM(WHAT, ...) \ + EXPAND(GET_MACROS_ENUM( \ + __VA_ARGS__, EN_99, EN_98, EN_97, EN_96, EN_95, EN_94, EN_93, EN_92, \ + EN_91, EN_90, EN_89, EN_88, EN_87, EN_86, EN_85, EN_84, EN_83, EN_82, \ + EN_81, EN_80, EN_79, EN_78, EN_77, EN_76, EN_75, EN_74, EN_73, EN_72, \ + EN_71, EN_70, EN_69, EN_68, EN_67, EN_66, EN_65, EN_64, EN_63, EN_62, \ + EN_61, EN_60, EN_59, EN_58, EN_57, EN_56, EN_55, EN_54, EN_53, EN_52, \ + EN_51, EN_50, EN_49, EN_48, EN_47, EN_46, EN_45, EN_44, EN_43, EN_42, \ + EN_41, EN_40, EN_39, EN_38, EN_37, EN_36, EN_35, EN_34, EN_33, EN_32, \ + EN_31, EN_30, EN_29, EN_28, EN_27, EN_26, EN_25, EN_24, EN_23, EN_22, \ + EN_21, EN_20, EN_19, EN_18, EN_17, EN_16, EN_15, EN_14, EN_13, EN_12, \ + EN_11, EN_10, EN_9, EN_8, EN_7, EN_6, EN_5, EN_4, EN_3, EN_2, \ + EN_1)(WHAT, __VA_ARGS__)) #define _EXEC_ENUM(WHAT, ...) EVAL(FOR_EACH_ENUM(WHAT, __VA_ARGS__)) #define EXEC_ENUMERATOR(...) EXPAND(_EXEC_ENUM(_EXPAND_ENUM, __VA_ARGS__)) diff --git a/libnd4j/include/system/msvc.h b/libnd4j/include/system/msvc.h index 4708d97c6874..f14423a8572c 100644 --- a/libnd4j/include/system/msvc.h +++ b/libnd4j/include/system/msvc.h @@ -23,17 +23,17 @@ #if defined(_MSC_VER) -#pragma warning( disable : 4244 ) -#pragma warning( disable : 4267 ) -#pragma warning( disable : 4251 ) -#pragma warning( disable : 4101 ) -#pragma warning( disable : 4305 ) -#pragma warning( disable : 4309 ) -#pragma warning( disable : 4333 ) -#pragma warning( disable : 4146 ) -#pragma warning( disable : 4018 ) -#pragma warning( disable : 4297 ) +#pragma warning(disable : 4244) +#pragma warning(disable : 4267) +#pragma warning(disable : 4251) +#pragma warning(disable : 4101) +#pragma warning(disable : 4305) +#pragma warning(disable : 4309) +#pragma warning(disable : 4333) +#pragma warning(disable : 4146) +#pragma warning(disable : 4018) +#pragma warning(disable : 4297) #endif -#endif //SD_MSVC_H +#endif // SD_MSVC_H diff --git a/libnd4j/include/system/nd4jmalloc.h b/libnd4j/include/system/nd4jmalloc.h index c808aad090f7..1ee40a021302 100644 --- a/libnd4j/include/system/nd4jmalloc.h +++ b/libnd4j/include/system/nd4jmalloc.h @@ -21,4 +21,4 @@ #ifndef NATIVEOPERATIONS_ND4JMALLOC_H #define NATIVEOPERATIONS_ND4JMALLOC_H #include -#endif //NATIVEOPERATIONS_ND4JMALLOC_H +#endif // NATIVEOPERATIONS_ND4JMALLOC_H diff --git a/libnd4j/include/system/nd4jmemset.h b/libnd4j/include/system/nd4jmemset.h index 6482dcb8a2f7..7df35d5a0fef 100644 --- a/libnd4j/include/system/nd4jmemset.h +++ b/libnd4j/include/system/nd4jmemset.h @@ -21,4 +21,4 @@ #ifndef NATIVEOPERATIONS_ND4JSTRING_H #define NATIVEOPERATIONS_ND4JSTRING_H #include -#endif //NATIVEOPERATIONS_ND4JSTRING_H +#endif // NATIVEOPERATIONS_ND4JSTRING_H diff --git a/libnd4j/include/system/op_enums.h b/libnd4j/include/system/op_enums.h index ad16d281e56b..5a1bb3217473 100644 --- a/libnd4j/include/system/op_enums.h +++ b/libnd4j/include/system/op_enums.h @@ -18,136 +18,91 @@ // @author raver119@gmail.com // - #ifndef LIBND4J_OP_ENUMS_H #define LIBND4J_OP_ENUMS_H #include -#include #include +#include namespace sd { - namespace random { - enum Ops { - BUILD_ENUMERATION(RANDOM_OPS) - }; - } - - namespace transform { - enum FloatOps { - BUILD_ENUMERATION(TRANSFORM_FLOAT_OPS) - }; - - enum SameOps { - BUILD_ENUMERATION(TRANSFORM_SAME_OPS) - }; - - enum BoolOps { - BUILD_ENUMERATION(TRANSFORM_BOOL_OPS) - }; - - enum AnyOps { - BUILD_ENUMERATION(TRANSFORM_ANY_OPS) - }; - - enum StrictOps { - BUILD_ENUMERATION(TRANSFORM_STRICT_OPS) - }; - } - - namespace pairwise { - enum Ops { - BUILD_ENUMERATION(PAIRWISE_TRANSFORM_OPS) - }; - - enum BoolOps { - BUILD_ENUMERATION(PAIRWISE_BOOL_OPS) - }; - - enum IntOps { - BUILD_ENUMERATION(PAIRWISE_INT_OPS) - }; - } - - namespace scalar { - enum Ops { - BUILD_ENUMERATION(SCALAR_OPS) - }; - - enum BoolOps { - BUILD_ENUMERATION(SCALAR_BOOL_OPS) - }; - - enum IntOps { - BUILD_ENUMERATION(SCALAR_INT_OPS) - }; - } - - namespace reduce { - enum FloatOps { - BUILD_ENUMERATION(REDUCE_FLOAT_OPS) - }; - - enum SameOps { - BUILD_ENUMERATION(REDUCE_SAME_OPS) - }; - - enum BoolOps { - BUILD_ENUMERATION(REDUCE_BOOL_OPS) - }; - - enum LongOps { - BUILD_ENUMERATION(REDUCE_LONG_OPS) - }; - } - - namespace reduce3 { - enum Ops { - BUILD_ENUMERATION(REDUCE3_OPS) - }; - } - - namespace indexreduce { - enum Ops { - BUILD_ENUMERATION(INDEX_REDUCE_OPS) - }; - } - - namespace broadcast { - enum Ops { - BUILD_ENUMERATION(BROADCAST_OPS) - }; - - enum BoolOps { - BUILD_ENUMERATION(BROADCAST_BOOL_OPS) - }; - - enum IntOps { - BUILD_ENUMERATION(BROADCAST_INT_OPS) - }; - } - - namespace variance { - enum Ops { - BUILD_ENUMERATION(SUMMARY_STATS_OPS) - }; - } - - namespace logic { - enum Ops { - While = 0, - Scope = 10, - Conditional = 20, - Switch = 30, - Return = 40, - Expose = 50, - Merge = 60, - LoopCond = 70, - NextIteration = 80, - Exit = 90, - Enter = 100, - }; - } +namespace random { +enum Ops { BUILD_ENUMERATION(RANDOM_OPS) }; +} + +namespace transform { +enum FloatOps { BUILD_ENUMERATION(TRANSFORM_FLOAT_OPS) }; + +enum SameOps { BUILD_ENUMERATION(TRANSFORM_SAME_OPS) }; + +enum BoolOps { BUILD_ENUMERATION(TRANSFORM_BOOL_OPS) }; + +enum AnyOps { BUILD_ENUMERATION(TRANSFORM_ANY_OPS) }; + +enum StrictOps { BUILD_ENUMERATION(TRANSFORM_STRICT_OPS) }; +} // namespace transform + +namespace pairwise { +enum Ops { BUILD_ENUMERATION(PAIRWISE_TRANSFORM_OPS) }; + +enum BoolOps { BUILD_ENUMERATION(PAIRWISE_BOOL_OPS) }; + +enum IntOps { BUILD_ENUMERATION(PAIRWISE_INT_OPS) }; +} // namespace pairwise + +namespace scalar { +enum Ops { BUILD_ENUMERATION(SCALAR_OPS) }; + +enum BoolOps { BUILD_ENUMERATION(SCALAR_BOOL_OPS) }; + +enum IntOps { BUILD_ENUMERATION(SCALAR_INT_OPS) }; +} // namespace scalar + +namespace reduce { +enum FloatOps { BUILD_ENUMERATION(REDUCE_FLOAT_OPS) }; + +enum SameOps { BUILD_ENUMERATION(REDUCE_SAME_OPS) }; + +enum BoolOps { BUILD_ENUMERATION(REDUCE_BOOL_OPS) }; + +enum LongOps { BUILD_ENUMERATION(REDUCE_LONG_OPS) }; +} // namespace reduce + +namespace reduce3 { +enum Ops { BUILD_ENUMERATION(REDUCE3_OPS) }; +} + +namespace indexreduce { +enum Ops { BUILD_ENUMERATION(INDEX_REDUCE_OPS) }; +} + +namespace broadcast { +enum Ops { BUILD_ENUMERATION(BROADCAST_OPS) }; + +enum BoolOps { BUILD_ENUMERATION(BROADCAST_BOOL_OPS) }; + +enum IntOps { BUILD_ENUMERATION(BROADCAST_INT_OPS) }; +} // namespace broadcast + +namespace variance { +enum Ops { BUILD_ENUMERATION(SUMMARY_STATS_OPS) }; +} + +namespace logic { +enum Ops { + While = 0, + Scope = 10, + Conditional = 20, + Switch = 30, + Return = 40, + Expose = 50, + Merge = 60, + LoopCond = 70, + NextIteration = 80, + Exit = 90, + Enter = 100, +}; } +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/system/openmp_pragmas.h b/libnd4j/include/system/openmp_pragmas.h index 0259ed75474f..14a42241e6c6 100644 --- a/libnd4j/include/system/openmp_pragmas.h +++ b/libnd4j/include/system/openmp_pragmas.h @@ -61,9 +61,8 @@ #else - #define OMP_STRINGIFY(args) #args -#define OMP_IF(args) if(args) +#define OMP_IF(args) if (args) #define OMP_SCHEDULE(args) schedule(args) #define OMP_MAXT maxT #define OMP_SUMT sumT @@ -73,13 +72,21 @@ #define PRAGMA_OMP_CRITICAL _Pragma(OMP_STRINGIFY(omp critical)) #define PRAGMA_OMP_SIMD _Pragma(OMP_STRINGIFY(omp simd)) #define PRAGMA_OMP_SIMD_ARGS(args) _Pragma(OMP_STRINGIFY(omp simd args)) -#define PRAGMA_OMP_SIMD_SUM(args) _Pragma(OMP_STRINGIFY(omp simd reduction(sumT:args))) -#define PRAGMA_OMP_SIMD_MAX(args) _Pragma(OMP_STRINGIFY(omp simd reduction(maxTF:args))) +#define PRAGMA_OMP_SIMD_SUM(args) \ + _Pragma(OMP_STRINGIFY(omp simd reduction(sumT : args))) +#define PRAGMA_OMP_SIMD_MAX(args) \ + _Pragma(OMP_STRINGIFY(omp simd reduction(maxTF : args))) #define PRAGMA_OMP_PARALLEL _Pragma(OMP_STRINGIFY(omp parallel default(shared))) -#define PRAGMA_OMP_PARALLEL_REDUCTION(args) _Pragma(OMP_STRINGIFY(omp parallel reduction(args) default(shared))) -#define PRAGMA_OMP_PARALLEL_ARGS(args) _Pragma(OMP_STRINGIFY(omp parallel args default(shared))) -#define PRAGMA_OMP_PARALLEL_THREADS(args) _Pragma(OMP_STRINGIFY(omp parallel num_threads(args) if(args > 1) default(shared))) -#define PRAGMA_OMP_PARALLEL_THREADS_IF(threads, condition) _Pragma(OMP_STRINGIFY(omp parallel num_threads(threads) if(condition) default(shared))) +#define PRAGMA_OMP_PARALLEL_REDUCTION(args) \ + _Pragma(OMP_STRINGIFY(omp parallel reduction(args) default(shared))) +#define PRAGMA_OMP_PARALLEL_ARGS(args) \ + _Pragma(OMP_STRINGIFY(omp parallel args default(shared))) +#define PRAGMA_OMP_PARALLEL_THREADS(args) \ + _Pragma(OMP_STRINGIFY( \ + omp parallel num_threads(args) if (args > 1) default(shared))) +#define PRAGMA_OMP_PARALLEL_THREADS_IF(threads, condition) \ + _Pragma(OMP_STRINGIFY( \ + omp parallel num_threads(threads) if (condition) default(shared))) #define PRAGMA_OMP_PARALLEL_FOR _Pragma(OMP_STRINGIFY(omp parallel for default(shared))) #define PRAGMA_OMP_PARALLEL_FOR_REDUCTION(args) _Pragma(OMP_STRINGIFY(omp parallel for reduction(args) default(shared))) #define PRAGMA_OMP_PARALLEL_FOR_ARGS(args) _Pragma(OMP_STRINGIFY(omp parallel for args default(shared))) @@ -92,7 +99,8 @@ #define PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(loops) _Pragma(OMP_STRINGIFY(omp parallel for simd default(shared) collapse(loops))) #define PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(args) _Pragma(OMP_STRINGIFY(omp parallel for simd reduction(args) default(shared))) #define PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(args) _Pragma(OMP_STRINGIFY(omp parallel for simd num_threads(args) if(args > 1) default(shared))) -#define PRAGMA_OMP_PARALLEL_SECTIONS _Pragma(OMP_STRINGIFY(omp parallel sections)) +#define PRAGMA_OMP_PARALLEL_SECTIONS \ + _Pragma(OMP_STRINGIFY(omp parallel sections)) #define PRAGMA_OMP_SECTION _Pragma(OMP_STRINGIFY(omp section)) #define PRAGMA_OMP_SINGLE _Pragma(OMP_STRINGIFY(omp single)) #define PRAGMA_OMP_SINGLE_ARGS(args) _Pragma(OMP_STRINGIFY(omp single args)) @@ -113,26 +121,43 @@ // parallel_for block #define FUNC_1D std::function -#define FUNC_2D std::function -#define FUNC_3D std::function +#define FUNC_2D \ + std::function +#define FUNC_3D \ + std::function // aggregation lambda #define LAMBDA_AL [&](int64_t _old, int64_t _new) -> int64_t #define LAMBDA_AD [&](double _old, double _new) -> double -#define LAMBDA_SUML LAMBDA_AL {return _old + _new; } -#define LAMBDA_SUMD LAMBDA_AD {return _old + _new; } +#define LAMBDA_SUML \ + LAMBDA_AL { return _old + _new; } +#define LAMBDA_SUMD \ + LAMBDA_AD { return _old + _new; } // reduction lambda -#define PRAGMA_REDUCE_LONG [&] (uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) mutable -> int64_t -#define PRAGMA_REDUCE_DOUBLE [&] (uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) mutable -> double +#define PRAGMA_REDUCE_LONG \ + [&](uint64_t thread_id, int64_t start, int64_t stop, \ + int64_t increment) mutable -> int64_t +#define PRAGMA_REDUCE_DOUBLE \ + [&](uint64_t thread_id, int64_t start, int64_t stop, \ + int64_t increment) mutable -> double // paralllel block lambda -#define PRAGMA_THREADS_DO [&](uint64_t thread_id, uint64_t numThreads) -> void +#define PRAGMA_THREADS_DO [&](uint64_t thread_id, uint64_t numThreads) -> void // paralllel_for lambdas -#define PRAGMA_THREADS_FOR [&](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void -#define PRAGMA_THREADS_FOR_2D [&](uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y) -> void -#define PRAGMA_THREADS_FOR_3D [&](uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z) -> void - -#endif //SD_OPENMP_PRAGMAS_H +#define PRAGMA_THREADS_FOR \ + [&](uint64_t thread_id, int64_t start, int64_t stop, \ + int64_t increment) -> void +#define PRAGMA_THREADS_FOR_2D \ + [&](uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t inc_x, \ + int64_t start_y, int64_t stop_y, int64_t inc_y) -> void +#define PRAGMA_THREADS_FOR_3D \ + [&](uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t inc_x, \ + int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, \ + int64_t stop_z, int64_t inc_z) -> void + +#endif // SD_OPENMP_PRAGMAS_H diff --git a/libnd4j/include/system/optype.h b/libnd4j/include/system/optype.h index 70323f104c8b..23008db8de25 100644 --- a/libnd4j/include/system/optype.h +++ b/libnd4j/include/system/optype.h @@ -16,7 +16,7 @@ #pragma once -//TODO convert this into an enum class +// TODO convert this into an enum class // might break JNI though... typedef int OpType; @@ -25,9 +25,7 @@ typedef int OpType; #define constexpr #endif -namespace op_type -{ - constexpr OpType Variance = 0; - constexpr OpType StandardDeviation = 1; -} - +namespace op_type { +constexpr OpType Variance = 0; +constexpr OpType StandardDeviation = 1; +} // namespace op_type diff --git a/libnd4j/include/system/pairwise_util.h b/libnd4j/include/system/pairwise_util.h old mode 100755 new mode 100644 index d9e0965c89d0..b494690259a5 --- a/libnd4j/include/system/pairwise_util.h +++ b/libnd4j/include/system/pairwise_util.h @@ -33,16 +33,17 @@ #endif #include -#include -#include -#include #include #include +#include +#include + +#include #ifdef _OPENMP #include #endif -//Loops adapted from: -//https://github.com/numpy/numpy/blob/009b17a85a22707e63ac9ea1896413992bbf9ce5/numpy/core/src/private/lowlevel_strided_loops.h#L401-L401 +// Loops adapted from: +// https://github.com/numpy/numpy/blob/009b17a85a22707e63ac9ea1896413992bbf9ce5/numpy/core/src/private/lowlevel_strided_loops.h#L401-L401 /* namespace shape { @@ -60,16 +61,15 @@ namespace shape { ************************************************************/ typedef struct { - int perm, stride; + int perm, stride; } StridePermutation; - /** * Credit to: * http://alienryderflex.com/quicksort/ * - * The non recursive implementation is important for being able to run on cuda as well - * as host. + * The non recursive implementation is important for being able to run on cuda + * as well as host. * * In practice the work loads intended for * this won't be so hard @@ -81,108 +81,97 @@ typedef struct { #ifdef __CUDACC__ __host__ __device__ #endif -void quickSort(StridePermutation *arr, int elements); - + void + quickSort(StridePermutation *arr, int elements); /* Start raw iteration */ #define ND4J_RAW_ITER_START(idim, ndim, coord, shape) \ - memset((coord), 0, (ndim) * sizeof(coord[0])); \ - do { - + memset((coord), 0, (ndim) * sizeof(coord[0])); \ + do { /* Increment to the next n-dimensional coordinate for one raw array */ #define ND4J_RAW_ITER_ONE_NEXT(idim, ndim, coord, shape, data, strides) \ - for ((idim) = 0; (idim) < (ndim); (idim)++) { \ - if (++(coord)[idim] == (shape)[idim]) { \ - (coord)[idim] = 0; \ - (data) -= ((shape)[idim] - 1) * (strides)[idim]; \ - } \ - else { \ - (data) += (strides)[idim]; \ - break; \ - } \ - } \ - } while ((idim) < (ndim)) + for ((idim) = 0; (idim) < (ndim); (idim)++) { \ + if (++(coord)[idim] == (shape)[idim]) { \ + (coord)[idim] = 0; \ + (data) -= ((shape)[idim] - 1) * (strides)[idim]; \ + } else { \ + (data) += (strides)[idim]; \ + break; \ + } \ + } \ + } \ + while ((idim) < (ndim)) #define ND4J_RAW_ITER_ONE_NEXTF(idim, ndim, coord, shape, data, strides) \ - for ((idim) = ndim - 1; (idim) >= (0); (idim)--) { \ - if (++(coord)[idim] == (shape)[idim]) { \ - (coord)[idim] = 0; \ - (data) -= ((shape)[idim] - 1) * (strides)[idim]; \ - } \ - else { \ - (data) += (strides)[idim]; \ - break; \ - } \ - } \ - } while ((idim) >= (0)) - + for ((idim) = ndim - 1; (idim) >= (0); (idim)--) { \ + if (++(coord)[idim] == (shape)[idim]) { \ + (coord)[idim] = 0; \ + (data) -= ((shape)[idim] - 1) * (strides)[idim]; \ + } else { \ + (data) += (strides)[idim]; \ + break; \ + } \ + } \ + } \ + while ((idim) >= (0)) /* Increment to the next n-dimensional coordinate for two raw arrays */ -#define ND4J_RAW_ITER_TWO_NEXT(idim, ndim, coord, shape, \ - dataA, stridesA, dataB, stridesB) \ - for ((idim) = 0; (idim) < (ndim); (idim)++) { \ - if (++(coord)[idim] == (shape)[idim]) { \ - (coord)[idim] = 0; \ - (dataA) -= ((shape)[idim] - 1) * (stridesA)[idim]; \ - (dataB) -= ((shape)[idim] - 1) * (stridesB)[idim]; \ - } \ - else { \ - (dataA) += (stridesA)[idim]; \ - (dataB) += (stridesB)[idim]; \ - break; \ - } \ - } \ - } while ((idim) < (ndim)) +#define ND4J_RAW_ITER_TWO_NEXT(idim, ndim, coord, shape, dataA, stridesA, \ + dataB, stridesB) \ + for ((idim) = 0; (idim) < (ndim); (idim)++) { \ + if (++(coord)[idim] == (shape)[idim]) { \ + (coord)[idim] = 0; \ + (dataA) -= ((shape)[idim] - 1) * (stridesA)[idim]; \ + (dataB) -= ((shape)[idim] - 1) * (stridesB)[idim]; \ + } else { \ + (dataA) += (stridesA)[idim]; \ + (dataB) += (stridesB)[idim]; \ + break; \ + } \ + } \ + } \ + while ((idim) < (ndim)) /* Increment to the next n-dimensional coordinate for three raw arrays */ -#define ND4J_RAW_ITER_THREE_NEXT(idim, ndim, coord, shape, \ - dataA, stridesA, \ - dataB, stridesB, \ - dataC, stridesC) \ - for ((idim) = 0; (idim) < (ndim); (idim)++) { \ - if (++(coord)[idim] == (shape)[idim]) { \ - (coord)[idim] = 0; \ - (dataA) -= ((shape)[idim] - 1) * (stridesA)[idim]; \ - (dataB) -= ((shape)[idim] - 1) * (stridesB)[idim]; \ - (dataC) -= ((shape)[idim] - 1) * (stridesC)[idim]; \ - } \ - else { \ - (dataA) += (stridesA)[idim]; \ - (dataB) += (stridesB)[idim]; \ - (dataC) += (stridesC)[idim]; \ - break; \ - } \ - } \ - } while ((idim) < (ndim)) +#define ND4J_RAW_ITER_THREE_NEXT(idim, ndim, coord, shape, dataA, stridesA, \ + dataB, stridesB, dataC, stridesC) \ + for ((idim) = 0; (idim) < (ndim); (idim)++) { \ + if (++(coord)[idim] == (shape)[idim]) { \ + (coord)[idim] = 0; \ + (dataA) -= ((shape)[idim] - 1) * (stridesA)[idim]; \ + (dataB) -= ((shape)[idim] - 1) * (stridesB)[idim]; \ + (dataC) -= ((shape)[idim] - 1) * (stridesC)[idim]; \ + } else { \ + (dataA) += (stridesA)[idim]; \ + (dataB) += (stridesB)[idim]; \ + (dataC) += (stridesC)[idim]; \ + break; \ + } \ + } \ + } \ + while ((idim) < (ndim)) /* Increment to the next n-dimensional coordinate for four raw arrays */ -#define ND4J_RAW_ITER_FOUR_NEXT(idim, ndim, coord, shape, \ - dataA, stridesA, \ - dataB, stridesB, \ - dataC, stridesC, \ - dataD, stridesD) \ - for ((idim) = 0; (idim) < (ndim); (idim)++) { \ - if (++(coord)[idim] == (shape)[idim]) { \ - (coord)[idim] = 0; \ - (dataA) -= ((shape)[idim] - 1) * (stridesA)[idim]; \ - (dataB) -= ((shape)[idim] - 1) * (stridesB)[idim]; \ - (dataC) -= ((shape)[idim] - 1) * (stridesC)[idim]; \ - (dataD) -= ((shape)[idim] - 1) * (stridesD)[idim]; \ - } \ - else { \ - (dataA) += (stridesA)[idim]; \ - (dataB) += (stridesB)[idim]; \ - (dataC) += (stridesC)[idim]; \ - (dataD) += (stridesD)[idim]; \ - break; \ - } \ - } \ - } while ((idim) < (ndim)) - - - - - +#define ND4J_RAW_ITER_FOUR_NEXT(idim, ndim, coord, shape, dataA, stridesA, \ + dataB, stridesB, dataC, stridesC, dataD, \ + stridesD) \ + for ((idim) = 0; (idim) < (ndim); (idim)++) { \ + if (++(coord)[idim] == (shape)[idim]) { \ + (coord)[idim] = 0; \ + (dataA) -= ((shape)[idim] - 1) * (stridesA)[idim]; \ + (dataB) -= ((shape)[idim] - 1) * (stridesB)[idim]; \ + (dataC) -= ((shape)[idim] - 1) * (stridesC)[idim]; \ + (dataD) -= ((shape)[idim] - 1) * (stridesD)[idim]; \ + } else { \ + (dataA) += (stridesA)[idim]; \ + (dataB) += (stridesB)[idim]; \ + (dataC) += (stridesC)[idim]; \ + (dataD) += (stridesD)[idim]; \ + break; \ + } \ + } \ + } \ + while ((idim) < (ndim)) /*NUMPY_API * @@ -194,18 +183,18 @@ void quickSort(StridePermutation *arr, int elements); #ifdef __CUDACC__ __host__ __device__ #endif -inline void SortStrideArray(int ndim, int strides[], - StridePermutation *out_strideperm) { - - /* Set up the strideperm values */ - for (int i = 0; i < ndim; i++) { - out_strideperm[i].perm = i; - out_strideperm[i].stride = strides[i]; - } - - /* Sort them */ - quickSort(out_strideperm,ndim); - + inline void + SortStrideArray(int ndim, int strides[], + StridePermutation *out_strideperm) { + + /* Set up the strideperm values */ + for (int i = 0; i < ndim; i++) { + out_strideperm[i].perm = i; + out_strideperm[i].stride = strides[i]; + } + + /* Sort them */ + quickSort(out_strideperm, ndim); } /* @@ -224,21 +213,19 @@ inline void SortStrideArray(int ndim, int strides[], * Returns 0 on success, -1 on failure. */ template - #ifdef __CUDACC__ __host__ __device__ #endif -inline int PrepareOneRawArrayIter(int ndim, Nd4jLong shape[], - T data[], Nd4jLong strides[], - int *out_ndim, Nd4jLong outShape[], - T **out_data, Nd4jLong *outStrides) { - - for (int i = 0; i < ndim; i++) { - outShape[i] = shape[i]; - outStrides[i] = strides[i]; - } - + inline int + PrepareOneRawArrayIter(int ndim, Nd4jLong shape[], T data[], + Nd4jLong strides[], int *out_ndim, + Nd4jLong outShape[], T **out_data, + Nd4jLong *outStrides) { + for (int i = 0; i < ndim; i++) { + outShape[i] = shape[i]; + outStrides[i] = strides[i]; + } #if 0 /* DEBUG */ @@ -257,51 +244,47 @@ inline int PrepareOneRawArrayIter(int ndim, Nd4jLong shape[], } #endif - *out_data = data; - *out_ndim = ndim; - return 0; + *out_data = data; + *out_ndim = ndim; + return 0; } - class BlockInformation { -public: - Nd4jLong items; - int threads; - Nd4jLong chunks; - Nd4jLong modulo; - Nd4jLong remainder; - - BlockInformation(Nd4jLong length, int threshold) { - + public: + Nd4jLong items; + int threads; + Nd4jLong chunks; + Nd4jLong modulo; + Nd4jLong remainder; + + BlockInformation(Nd4jLong length, int threshold) { threads = length / threshold; - threads = (1 < threads)?threads:1;//sd::math::nd4j_max(1, threads); - threads = (threads < omp_get_max_threads())?threads:omp_get_max_threads();//sd::math::nd4j_min(threads, omp_get_max_threads()); + threads = + (1 < threads) ? threads : 1; // sd::math::nd4j_max(1, threads); + threads = (threads < omp_get_max_threads()) + ? threads + : omp_get_max_threads(); // sd::math::nd4j_min(threads, + // omp_get_max_threads()); items = length / threads; remainder = length % threads; - if(items < 1) - items = 1; + if (items < 1) items = 1; chunks = length / items; modulo = length % items; - //one left over chunk - if(modulo > 0) - chunks++; - } -}; - - -class CudaBlockInformation { - + // one left over chunk + if (modulo > 0) chunks++; + } }; +class CudaBlockInformation {}; /** * Credit to: * http://alienryderflex.com/quicksort/ * - * The non recursive implementation is important for being able to run on cuda as well - * as host. + * The non recursive implementation is important for being able to run on cuda + * as well as host. * * In practice the work loads intended for * this won't be so hard @@ -313,46 +296,42 @@ class CudaBlockInformation { #ifdef __CUDACC__ __host__ __device__ #endif -inline void quickSort(StridePermutation *arr, int elements) { -#define MAX_LEVELS 300 - - int beg[MAX_LEVELS], end[MAX_LEVELS], i= 0, L, R, swap ; - StridePermutation piv; - beg[0] = 0; - end[0] = elements; - while (i >= 0) { - L = beg[i]; - R= end[i] - 1; - if (L < R) { - piv = arr[L]; - while (L < R) { - while (arr[R].stride >= piv.stride && L < R) - R--; - if (L < R) - arr[L++] = arr[R]; - while (arr[L].stride <= piv.stride && L < R) - L++; - if (L end[i - 1] - beg[i - 1]) { - swap = beg[i]; - beg[i]= beg[i - 1]; - beg[i - 1] = swap; - swap = end[i]; - end[i] = end[i - 1]; - end[i - 1] = swap; - } - } - else { - i--; - } + inline void + quickSort(StridePermutation *arr, int elements) { +#define MAX_LEVELS 300 + + int beg[MAX_LEVELS], end[MAX_LEVELS], i = 0, L, R, swap; + StridePermutation piv; + beg[0] = 0; + end[0] = elements; + while (i >= 0) { + L = beg[i]; + R = end[i] - 1; + if (L < R) { + piv = arr[L]; + while (L < R) { + while (arr[R].stride >= piv.stride && L < R) R--; + if (L < R) arr[L++] = arr[R]; + while (arr[L].stride <= piv.stride && L < R) L++; + if (L < R) arr[R--] = arr[L]; + } + + arr[L] = piv; + beg[i + 1] = L + 1; + end[i + 1] = end[i]; + end[i++] = L; + if (end[i] - beg[i] > end[i - 1] - beg[i - 1]) { + swap = beg[i]; + beg[i] = beg[i - 1]; + beg[i - 1] = swap; + swap = end[i]; + end[i] = end[i - 1]; + end[i - 1] = swap; + } + } else { + i--; } + } } /** @@ -372,50 +351,48 @@ inline void quickSort(StridePermutation *arr, int elements) { * Returns 0 on success, -1 on failure. */ template -int _CUDA_HD PrepareTwoRawArrayIter(int ndim, Nd4jLong *shape, - X *dataA, Nd4jLong *stridesA, - Y *dataB, Nd4jLong *stridesB, - int *out_ndim, Nd4jLong *outShape, - X **out_dataA, Nd4jLong *outStridesA, - Y **out_dataB, Nd4jLong *outStridesB) { - int i; - -/* Sort the axes based on the destination strides */ - for (i = 0; i < ndim; ++i) { - outShape[i] = shape[i]; - outStridesA[i] = stridesA[i]; - outStridesB[i] = stridesB[i]; +int _CUDA_HD PrepareTwoRawArrayIter(int ndim, Nd4jLong *shape, X *dataA, + Nd4jLong *stridesA, Y *dataB, + Nd4jLong *stridesB, int *out_ndim, + Nd4jLong *outShape, X **out_dataA, + Nd4jLong *outStridesA, Y **out_dataB, + Nd4jLong *outStridesB) { + int i; + + /* Sort the axes based on the destination strides */ + for (i = 0; i < ndim; ++i) { + outShape[i] = shape[i]; + outStridesA[i] = stridesA[i]; + outStridesB[i] = stridesB[i]; + } + + /* Reverse any negative strides of operand A */ + for (i = 0; i < ndim; i++) { + Nd4jLong stride_entryA = outStridesA[i]; + Nd4jLong stride_entryB = outStridesB[i]; + Nd4jLong shape_entry = outShape[i]; + + if (stride_entryA < 0) { + dataA += stride_entryA * (shape_entry - 1); + dataB += stride_entryB * (shape_entry - 1); + outStridesA[i] = -stride_entryA; + outStridesB[i] = -stride_entryB; } - -/* Reverse any negative strides of operand A */ - for (i = 0; i < ndim; i++) { - Nd4jLong stride_entryA = outStridesA[i]; - Nd4jLong stride_entryB = outStridesB[i]; - Nd4jLong shape_entry = outShape[i]; - - if (stride_entryA < 0) { - dataA += stride_entryA * (shape_entry - 1); - dataB += stride_entryB * (shape_entry - 1); - outStridesA[i] = -stride_entryA; - outStridesB[i] = -stride_entryB; - } -/* Detect 0-size arrays here */ - if (shape_entry == 0) { - *out_ndim = 1; - *out_dataA = dataA; - *out_dataB = dataB; - outShape[0] = 0; - outStridesA[0] = 0; - outStridesB[0] = 0; - return 0; - } + /* Detect 0-size arrays here */ + if (shape_entry == 0) { + *out_ndim = 1; + *out_dataA = dataA; + *out_dataB = dataB; + outShape[0] = 0; + outStridesA[0] = 0; + outStridesB[0] = 0; + return 0; } + } - - *out_dataA = dataA; - *out_dataB = dataB; - *out_ndim = ndim; - + *out_dataA = dataA; + *out_dataB = dataB; + *out_ndim = ndim; #if 0 /* DEBUG */ @@ -443,7 +420,7 @@ int _CUDA_HD PrepareTwoRawArrayIter(int ndim, Nd4jLong *shape, } #endif - return 0; + return 0; } /** @@ -466,97 +443,93 @@ template #ifdef __CUDACC__ __host__ __device__ #endif -int PrepareThreeRawArrayIter(int ndim, Nd4jLong shape[], - X *dataA, Nd4jLong *stridesA, - Y *dataB, Nd4jLong *stridesB, - Z *dataC, Nd4jLong *stridesC, - int &out_ndim, Nd4jLong *outShape, - X **out_dataA, Nd4jLong outStridesA[], - Y **out_dataB, Nd4jLong outStridesB[], - Z **out_dataC, Nd4jLong outStridesC[]) -{ - - /* Special case 0 and 1 dimensions */ - if (ndim == 0) { - out_ndim = 1; - *out_dataA = dataA; - *out_dataB = dataB; - *out_dataC = dataC; - outShape[0] = 1; - outStridesA[0] = 0; - outStridesB[0] = 0; - outStridesC[0] = 0; - return 0; - } - else if (ndim == 1) { - auto stride_entryA = stridesA[0]; - auto stride_entryB = stridesB[0]; - auto stride_entryC = stridesC[0]; - auto shape_entry = shape[0]; - out_ndim = 1; - outShape[0] = shape[0]; - /* Always make a positive stride for the first operand */ - if (stride_entryA >= 0) { - *out_dataA = dataA; - *out_dataB = dataB; - *out_dataC = dataC; - outStridesA[0] = stride_entryA; - outStridesB[0] = stride_entryB; - outStridesC[0] = stride_entryC; - } - else { - *out_dataA = dataA + stride_entryA * (shape_entry - 1); - *out_dataB = dataB + stride_entryB * (shape_entry - 1); - *out_dataC = dataC + stride_entryC * (shape_entry - 1); - outStridesA[0] = -stride_entryA; - outStridesB[0] = -stride_entryB; - outStridesC[0] = -stride_entryC; - } - return 0; - } - - for (int i = 0; i < ndim; ++i) { - outShape[i] = shape[i]; - outStridesA[i] = stridesA[i]; - outStridesB[i] = stridesB[i]; - outStridesC[i] = stridesC[i]; - } - - /* Reverse any negative strides of operand A */ - for (int i = 0; i < ndim; ++i) { - auto stride_entryA = outStridesA[i]; - auto stride_entryB = outStridesB[i]; - auto stride_entryC = outStridesC[i]; - auto shape_entry = outShape[i]; - - if (stride_entryA < 0) { - dataA += stride_entryA * (shape_entry - 1); - dataB += stride_entryB * (shape_entry - 1); - dataC += stride_entryC * (shape_entry - 1); - outStridesA[i] = -stride_entryA; - outStridesB[i] = -stride_entryB; - outStridesC[i] = -stride_entryC; - } - /* Detect 0-size arrays here */ - if (shape_entry == 0) { - out_ndim = 1; - *out_dataA = dataA; - *out_dataB = dataB; - *out_dataC = dataC; - outShape[0] = 0; - outStridesA[0] = 0; - outStridesB[0] = 0; - outStridesC[0] = 0; - return 0; - } - } - - + int + PrepareThreeRawArrayIter(int ndim, Nd4jLong shape[], X *dataA, + Nd4jLong *stridesA, Y *dataB, Nd4jLong *stridesB, + Z *dataC, Nd4jLong *stridesC, int &out_ndim, + Nd4jLong *outShape, X **out_dataA, + Nd4jLong outStridesA[], Y **out_dataB, + Nd4jLong outStridesB[], Z **out_dataC, + Nd4jLong outStridesC[]) { + + /* Special case 0 and 1 dimensions */ + if (ndim == 0) { + out_ndim = 1; *out_dataA = dataA; *out_dataB = dataB; *out_dataC = dataC; - out_ndim = ndim; + outShape[0] = 1; + outStridesA[0] = 0; + outStridesB[0] = 0; + outStridesC[0] = 0; + return 0; + } else if (ndim == 1) { + auto stride_entryA = stridesA[0]; + auto stride_entryB = stridesB[0]; + auto stride_entryC = stridesC[0]; + auto shape_entry = shape[0]; + out_ndim = 1; + outShape[0] = shape[0]; + /* Always make a positive stride for the first operand */ + if (stride_entryA >= 0) { + *out_dataA = dataA; + *out_dataB = dataB; + *out_dataC = dataC; + outStridesA[0] = stride_entryA; + outStridesB[0] = stride_entryB; + outStridesC[0] = stride_entryC; + } else { + *out_dataA = dataA + stride_entryA * (shape_entry - 1); + *out_dataB = dataB + stride_entryB * (shape_entry - 1); + *out_dataC = dataC + stride_entryC * (shape_entry - 1); + outStridesA[0] = -stride_entryA; + outStridesB[0] = -stride_entryB; + outStridesC[0] = -stride_entryC; + } return 0; + } + + for (int i = 0; i < ndim; ++i) { + outShape[i] = shape[i]; + outStridesA[i] = stridesA[i]; + outStridesB[i] = stridesB[i]; + outStridesC[i] = stridesC[i]; + } + + /* Reverse any negative strides of operand A */ + for (int i = 0; i < ndim; ++i) { + auto stride_entryA = outStridesA[i]; + auto stride_entryB = outStridesB[i]; + auto stride_entryC = outStridesC[i]; + auto shape_entry = outShape[i]; + + if (stride_entryA < 0) { + dataA += stride_entryA * (shape_entry - 1); + dataB += stride_entryB * (shape_entry - 1); + dataC += stride_entryC * (shape_entry - 1); + outStridesA[i] = -stride_entryA; + outStridesB[i] = -stride_entryB; + outStridesC[i] = -stride_entryC; + } + /* Detect 0-size arrays here */ + if (shape_entry == 0) { + out_ndim = 1; + *out_dataA = dataA; + *out_dataB = dataB; + *out_dataC = dataC; + outShape[0] = 0; + outStridesA[0] = 0; + outStridesB[0] = 0; + outStridesC[0] = 0; + return 0; + } + } + + *out_dataA = dataA; + *out_dataB = dataB; + *out_dataC = dataC; + out_ndim = ndim; + return 0; } -#endif //NATIVEOPERATIONS_PAIRWISE_UTIL_H +#endif // NATIVEOPERATIONS_PAIRWISE_UTIL_H diff --git a/libnd4j/include/system/platform_boilerplate.h b/libnd4j/include/system/platform_boilerplate.h index f8ee0eb55652..753c17bb6d72 100644 --- a/libnd4j/include/system/platform_boilerplate.h +++ b/libnd4j/include/system/platform_boilerplate.h @@ -23,35 +23,35 @@ #include - - -#define CONCATP(A,B) A ##_##B - - -#define DECLARE_PLATFORM_F(NAME, ENGINE, CNAME) class SD_EXPORT PLATFORM_##CNAME : public PlatformHelper {\ - public: \ - PLATFORM_##CNAME() : PlatformHelper(#NAME, samediff::Engine::ENGINE) { } \ - bool isUsable(graph::Context &context) override; \ - Nd4jStatus invokeHelper(graph::Context &context) override; \ - }; - -#define DECLARE_PLATFORM(NAME, ENGINE) DECLARE_PLATFORM_F(NAME, ENGINE, NAME ##_## ENGINE) - -#define PLATFORM_IMPL_F(NAME, ENGINE, CNAME) struct SD_EXPORT __registratorPlatformHelper_##CNAME { \ - __registratorPlatformHelper_##CNAME() { \ - auto helper = new PLATFORM_##CNAME(); \ - OpRegistrator::getInstance()->registerHelper(helper); \ - } \ - }; \ - static __registratorPlatformHelper_##CNAME platformHelper_##CNAME; \ - Nd4jStatus PLATFORM_##CNAME::invokeHelper(sd::graph::Context &block) - - -#define PLATFORM_IMPL(NAME, ENGINE) PLATFORM_IMPL_F(NAME, ENGINE, NAME ##_## ENGINE) - - -#define PLATFORM_CHECK_F(NAME, ENGINE, CNAME) bool PLATFORM_##CNAME::isUsable(graph::Context &block) -#define PLATFORM_CHECK(NAME, ENGINE) PLATFORM_CHECK_F(NAME, ENGINE, NAME ##_## ENGINE) - - -#endif //SD_PLATFORM_BOILERPLATE_H +#define CONCATP(A, B) A##_##B + +#define DECLARE_PLATFORM_F(NAME, ENGINE, CNAME) \ + class SD_EXPORT PLATFORM_##CNAME : public PlatformHelper { \ + public: \ + PLATFORM_##CNAME() : PlatformHelper(#NAME, samediff::Engine::ENGINE) {} \ + bool isUsable(graph::Context &context) override; \ + Nd4jStatus invokeHelper(graph::Context &context) override; \ + }; + +#define DECLARE_PLATFORM(NAME, ENGINE) \ + DECLARE_PLATFORM_F(NAME, ENGINE, NAME##_##ENGINE) + +#define PLATFORM_IMPL_F(NAME, ENGINE, CNAME) \ + struct SD_EXPORT __registratorPlatformHelper_##CNAME { \ + __registratorPlatformHelper_##CNAME() { \ + auto helper = new PLATFORM_##CNAME(); \ + OpRegistrator::getInstance()->registerHelper(helper); \ + } \ + }; \ + static __registratorPlatformHelper_##CNAME platformHelper_##CNAME; \ + Nd4jStatus PLATFORM_##CNAME::invokeHelper(sd::graph::Context &block) + +#define PLATFORM_IMPL(NAME, ENGINE) \ + PLATFORM_IMPL_F(NAME, ENGINE, NAME##_##ENGINE) + +#define PLATFORM_CHECK_F(NAME, ENGINE, CNAME) \ + bool PLATFORM_##CNAME::isUsable(graph::Context &block) +#define PLATFORM_CHECK(NAME, ENGINE) \ + PLATFORM_CHECK_F(NAME, ENGINE, NAME##_##ENGINE) + +#endif // SD_PLATFORM_BOILERPLATE_H diff --git a/libnd4j/include/system/play.h b/libnd4j/include/system/play.h index 9e121d88bd92..fc525fd0a26a 100644 --- a/libnd4j/include/system/play.h +++ b/libnd4j/include/system/play.h @@ -30,7 +30,7 @@ #define Y_TYPES \ (DATA_INT8, int8_t) ,\ - (DATA_INT16, int16_t) + (DATA_INT16, int16_t) #define Z_TYPES \ (DATA_UINT8, uint8_t) ,\ @@ -38,7 +38,7 @@ #define PWT_LIST \ (float, long, float),\ - (float, long, long) + (float, long, long) BUILD_SINGLE_TEMPLATE_TWICE(template class functionName, , DATA_TYPES) @@ -46,23 +46,27 @@ BUILD_SINGLE_TEMPLATE_TWICE(template class functionName, , DATA_TYPES) DECLARE_PLATFORM(conv2d, ENGINE_CPU) -//BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functionName, (signature), DATA_TYPES, Y_TYPES); +// BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functionName, (signature), +// DATA_TYPES, Y_TYPES); -//BUILD_SINGLE_UNCHAINED_TEMPLATE(functionName , (signature), Y_TYPES); +// BUILD_SINGLE_UNCHAINED_TEMPLATE(functionName , (signature), Y_TYPES); -//BUILD_TRIPLE_SELECTOR(xType, yType, zType, functionName, (signature), DATA_TYPES, Y_TYPES, Z_TYPES) +// BUILD_TRIPLE_SELECTOR(xType, yType, zType, functionName, (signature), +// DATA_TYPES, Y_TYPES, Z_TYPES) +// BUILD_TRIPLE_TEMPLATE(functionName, (signature), DATA_TYPES, Y_TYPES, +// Z_TYPES) -//BUILD_TRIPLE_TEMPLATE(functionName, (signature), DATA_TYPES, Y_TYPES, Z_TYPES) +// BUILD_ENUMERATION(DATA_TYPES) -//BUILD_ENUMERATION(DATA_TYPES) +// BUILD_SINGLE_SELECTOR(xType, functions::IndexReduce, ::op(a, b, c, d, e), +// DATA_TYPES) BUILD_DOUBLE_SELECTOR(xType, yType, functions::IndexReduce, +// ::op(a, b, c, d, e), DATA_TYPES, DATA_TYPES) -//BUILD_SINGLE_SELECTOR(xType, functions::IndexReduce, ::op(a, b, c, d, e), DATA_TYPES) -//BUILD_DOUBLE_SELECTOR(xType, yType, functions::IndexReduce, ::op(a, b, c, d, e), DATA_TYPES, DATA_TYPES) +// BUILD_SINGLE_TEMPLATE(template class Alpha, (signature), DATA_TYPES); -//BUILD_SINGLE_TEMPLATE(template class Alpha, (signature), DATA_TYPES); - -//BUILD_DOUBLE_TEMPLATE(template class Alpha, (signature) , DATA_TYPES, DATA_TYPES); +// BUILD_DOUBLE_TEMPLATE(template class Alpha, (signature) , DATA_TYPES, +// DATA_TYPES); /* #define SCALAR_OPS \ @@ -111,60 +115,94 @@ DECLARE_PLATFORM(conv2d, ENGINE_CPU) EXECUTE_NOE((x, y, extras), OPS_A(PAIRWISE_TRANSFORM_OPS)) */ +// EXECUTE_NOE((x, extras), OPS_A(SCALAR_OPS)) -//EXECUTE_NOE((x, extras), OPS_A(SCALAR_OPS)) - -//BUILD_CALL_1(template void sd::NDArray::applyTransform, float16, (NDArray* a, float16* b), TRANSFORM_OPS) +// BUILD_CALL_1(template void sd::NDArray::applyTransform, float16, +// (NDArray* a, float16* b), TRANSFORM_OPS) -//BUILD_CALL_1(template void sd::NDArray::applyPairwiseTransform, float16, (NDArray* other, float16* extraParams), PAIRWISE_TRANSFORM_OPS) -//BUILD_TRACKER(TRANSFORM, ACTIVATIONS) +// BUILD_CALL_1(template void sd::NDArray::applyPairwiseTransform, +// float16, (NDArray* other, float16* extraParams), +// PAIRWISE_TRANSFORM_OPS) BUILD_TRACKER(TRANSFORM, ACTIVATIONS) -//BUILD_CALL_1(template void sd::NDArray::applyScalar, float16, (float16 scalar, NDArray* target, float16 *extraParams) , ACTIVATIONS); +// BUILD_CALL_1(template void sd::NDArray::applyScalar, float16, +// (float16 scalar, NDArray* target, float16 *extraParams) , +// ACTIVATIONS); /* -#define DECLARE_OP(NAME, NIN, NOUT) DECLARE_OP_UNIQ(__COUNTER__, NAME, NIN, NOUT) +#define DECLARE_OP(NAME, NIN, NOUT) DECLARE_OP_UNIQ(__COUNTER__, NAME, NIN, +NOUT) #define DECLARE_OP_UNIQ(CTR, NAME, NIN, NOUT) template \ - class NAME: public sd::ops::DeclarableOp { \ + class NAME: public +sd::ops::DeclarableOp { \ public:\ - NAME() : sd::ops::DeclarableOp(NIN, NOUT, #NAME) { } \ + NAME() : +sd::ops::DeclarableOp(NIN, NOUT, #NAME) { } \ protected: \ - Nd4jStatus validateAndExecute(Block& block); \ + Nd4jStatus +validateAndExecute(Block& block); \ };\ template \ - Nd4jStatus sd::ops::NAME::validateAndExecute(Block& block) + Nd4jStatus +sd::ops::NAME::validateAndExecute(Block& block) */ -//#define END_OP(NAME) }; static sd::ops::__registrator> register_op##Name; +//#define END_OP(NAME) }; static sd::ops::__registrator> +//register_op##Name; //#DECLARE_OP(Concat, -1, 1) -//END_OP(Concat) - - -//BUILD_LAYERS_FACTORY(float, OPS_A(NATIVE_LAYERS), OPS_B(ACTIVATIONS)) +// END_OP(Concat) +// BUILD_LAYERS_FACTORY(float, OPS_A(NATIVE_LAYERS), OPS_B(ACTIVATIONS)) -//DISPATCH_SIMPLE(scalarAlongDimension_, float, PARAMS(x, xShapeInfo, extraParamx, z, zShapeInfo, scalars, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), OPS_A(SCALAR_OPS)) +// DISPATCH_SIMPLE(scalarAlongDimension_, float, PARAMS(x, xShapeInfo, +// extraParamx, z, zShapeInfo, scalars, tadShapeInfo, tadOffsets, tadShapeInfoZ, +// tadOffsetsZ), OPS_A(SCALAR_OPS)) -//_EXEC_KERNEL_F(scalarAlongDimension_, scalarAlongDimensionGeneric, float, (float inputA, float inputB), (paramA, paramB), (10, SCALAR::Add), (11, SCALAR::Subtract), (12, SCALAR::Multiply)) +//_EXEC_KERNEL_F(scalarAlongDimension_, scalarAlongDimensionGeneric, float, +//(float inputA, float inputB), (paramA, paramB), (10, SCALAR::Add), (11, +//SCALAR::Subtract), (12, SCALAR::Multiply)) -//DISPATCH_KERNEL_SIMPLE(scalarAlongDimension_, scalarAlongDimensionGeneric, float, INPUT(float inputA, float inputB), PARAMS(paramA, paramB), OPS_A(SCALAR_OPS)) +// DISPATCH_KERNEL_SIMPLE(scalarAlongDimension_, scalarAlongDimensionGeneric, +// float, INPUT(float inputA, float inputB), PARAMS(paramA, paramB), +// OPS_A(SCALAR_OPS)) // original version -// DISPATCH_METAOP(functions::pairwise_transforms::PairWiseTransform::template transformCuda, PARAMS(N, dx, dy, xStride, yStride, paramsPtr, dz, zStride, nullptr, nullptr, nullptr), InvertedMetaOp, OPS_A(SCALAR_OPS), OPS_B(PAIRWISE_TRANSFORM_OPS)) +// DISPATCH_METAOP(functions::pairwise_transforms::PairWiseTransform::template +// transformCuda, PARAMS(N, dx, dy, xStride, yStride, paramsPtr, dz, zStride, +// nullptr, nullptr, nullptr), InvertedMetaOp, OPS_A(SCALAR_OPS), +// OPS_B(PAIRWISE_TRANSFORM_OPS)) /* -DISPATCH_METAOP(invertedMetaPairwiseShaped_Pairwise_Scalar, PARAMS(opTypeA, opTypeB, N, x, xShape, y, yShape, z, zShape, extrasA, extrasB, scalarA, scalarB), float, OPS_A(PAIRWISE_TRANSFORM_OPS), OPS_B(SCALAR_OPS));*/ - -//DISPATCH_KERNEL_META(invertedMetaPairwiseShaped_Pairwise_Scalar_, invertedMetaPairwiseShapedGeneric, float, simdOps::InvertedMetaOp, INPUT(const int opTypeA, const int opTypeB, long N, float *dx, int *xShapeInfo, float *dy, int *yShapeInfo, float *dz, int *zShapeInfo, float *extraA, float *extraB, float scalarA, float scalarB), PARAMS(opTypeA, opTypeB, N, dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo, extraA, extraB, scalarA, scalarB), OPS_A(PAIRWISE_TRANSFORM_OPS), OPS_B(SCALAR_OPS)) - -//_EXPAND_KERNEL_CALL(invertedMetaPairwiseShaped_Pairwise_Scalar_, invertedMetaPairwiseShapedGeneric, float, simdOps::InvertedMetaOp, INPUT(const int opTypeA, const int opTypeB, long N, float *dx, int *xShapeInfo, float *dy, int *yShapeInfo, float *dz, int *zShapeInfo, float *extraA, float *extraB, float scalarA, float scalarB), PARAMS(N, dx, dy, xStride, yStride, paramsPtr, dz, zStride, nullptr, nullptr, nullptr), 66, simdOps::SomeOpA, 99, simdOps::SomeOpB) +DISPATCH_METAOP(invertedMetaPairwiseShaped_Pairwise_Scalar, PARAMS(opTypeA, +opTypeB, N, x, xShape, y, yShape, z, zShape, extrasA, extrasB, scalarA, +scalarB), float, OPS_A(PAIRWISE_TRANSFORM_OPS), OPS_B(SCALAR_OPS));*/ + +// DISPATCH_KERNEL_META(invertedMetaPairwiseShaped_Pairwise_Scalar_, +// invertedMetaPairwiseShapedGeneric, float, simdOps::InvertedMetaOp, INPUT(const +// int opTypeA, const int opTypeB, long N, float *dx, int *xShapeInfo, float *dy, +// int *yShapeInfo, float *dz, int *zShapeInfo, float *extraA, float *extraB, +// float scalarA, float scalarB), PARAMS(opTypeA, opTypeB, N, dx, xShapeInfo, dy, +// yShapeInfo, dz, zShapeInfo, extraA, extraB, scalarA, scalarB), +// OPS_A(PAIRWISE_TRANSFORM_OPS), OPS_B(SCALAR_OPS)) + +//_EXPAND_KERNEL_CALL(invertedMetaPairwiseShaped_Pairwise_Scalar_, +//invertedMetaPairwiseShapedGeneric, float, simdOps::InvertedMetaOp, INPUT(const +//int opTypeA, const int opTypeB, long N, float *dx, int *xShapeInfo, float *dy, +//int *yShapeInfo, float *dz, int *zShapeInfo, float *extraA, float *extraB, +//float scalarA, float scalarB), PARAMS(N, dx, dy, xStride, yStride, paramsPtr, +//dz, zStride, nullptr, nullptr, nullptr), 66, simdOps::SomeOpA, 99, +//simdOps::SomeOpB) /* - extern "C" __global__ void invertedMetaOpKernel_Pairwise_Scalar_16_1_float(const int opTypeA, const int opTypeB, long N, float *dx, int *xShapeInfo, float *dy, int *yShapeInfo, float *dz, int *zShapeInfo, float *extraA, float *extraB, float scalarA, float scalarB) { - invertedMetaPairwiseShapedGeneric, simdOps::Multiply>>(opTypeA, opTypeB, N, dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo, extraA, extraB, scalarA, scalarB); + extern "C" __global__ void + invertedMetaOpKernel_Pairwise_Scalar_16_1_float(const int opTypeA, const int + opTypeB, long N, float *dx, int *xShapeInfo, float *dy, int *yShapeInfo, float + *dz, int *zShapeInfo, float *extraA, float *extraB, float scalarA, float + scalarB) { invertedMetaPairwiseShapedGeneric, + simdOps::Multiply>>(opTypeA, opTypeB, N, dx, xShapeInfo, dy, yShapeInfo, + dz, zShapeInfo, extraA, extraB, scalarA, scalarB); } */ - - -#endif //LIBND4J_PLAY_H +#endif // LIBND4J_PLAY_H diff --git a/libnd4j/include/system/pointercast.h b/libnd4j/include/system/pointercast.h index 2c64d608eda2..aedd9afbe9e5 100644 --- a/libnd4j/include/system/pointercast.h +++ b/libnd4j/include/system/pointercast.h @@ -21,44 +21,41 @@ #ifndef NATIVEOPERATIONS_POINTERCAST_H #define NATIVEOPERATIONS_POINTERCAST_H -#include #include +#include typedef void* Nd4jPointer; typedef long long Nd4jLong; typedef uint64_t Nd4jULong; typedef int Nd4jStatus; -#define ND4J_STATUS_OK 0 -#define ND4J_STATUS_BAD_INPUT 1 -#define ND4J_STATUS_BAD_SHAPE 2 -#define ND4J_STATUS_BAD_RANK 3 -#define ND4J_STATUS_BAD_PARAMS 4 -#define ND4J_STATUS_BAD_OUTPUT 5 -#define ND4J_STATUS_BAD_RNG 6 -#define ND4J_STATUS_BAD_EPSILON 7 +#define ND4J_STATUS_OK 0 +#define ND4J_STATUS_BAD_INPUT 1 +#define ND4J_STATUS_BAD_SHAPE 2 +#define ND4J_STATUS_BAD_RANK 3 +#define ND4J_STATUS_BAD_PARAMS 4 +#define ND4J_STATUS_BAD_OUTPUT 5 +#define ND4J_STATUS_BAD_RNG 6 +#define ND4J_STATUS_BAD_EPSILON 7 #define ND4J_STATUS_BAD_GRADIENTS 8 -#define ND4J_STATUS_BAD_BIAS 9 - -#define ND4J_STATUS_VALIDATION 20 - -#define ND4J_STATUS_BAD_GRAPH 30 -#define ND4J_STATUS_BAD_LENGTH 31 -#define ND4J_STATUS_BAD_DIMENSIONS 32 -#define ND4J_STATUS_BAD_ORDER 33 -#define ND4J_STATUS_BAD_ARGUMENTS 34 +#define ND4J_STATUS_BAD_BIAS 9 -#define ND4J_STATUS_DOUBLE_WRITE 40 -#define ND4J_STATUS_DOUBLE_READ 45 +#define ND4J_STATUS_VALIDATION 20 +#define ND4J_STATUS_BAD_GRAPH 30 +#define ND4J_STATUS_BAD_LENGTH 31 +#define ND4J_STATUS_BAD_DIMENSIONS 32 +#define ND4J_STATUS_BAD_ORDER 33 +#define ND4J_STATUS_BAD_ARGUMENTS 34 -#define ND4J_STATUS_KERNEL_FAILURE 50 +#define ND4J_STATUS_DOUBLE_WRITE 40 +#define ND4J_STATUS_DOUBLE_READ 45 +#define ND4J_STATUS_KERNEL_FAILURE 50 -#define ND4J_STATUS_TRUE 100 -#define ND4J_STATUS_FALSE 101 -#define ND4J_STATUS_MAYBE 119 - +#define ND4J_STATUS_TRUE 100 +#define ND4J_STATUS_FALSE 101 +#define ND4J_STATUS_MAYBE 119 #ifdef _MSC_VER @@ -82,6 +79,4 @@ typedef int Nd4jStatus; #endif - - -#endif //NATIVEOPERATIONS_POINTERCAST_H +#endif // NATIVEOPERATIONS_POINTERCAST_H diff --git a/libnd4j/include/system/type_boilerplate.h b/libnd4j/include/system/type_boilerplate.h index 997fcab22dee..dcc51b739d48 100644 --- a/libnd4j/include/system/type_boilerplate.h +++ b/libnd4j/include/system/type_boilerplate.h @@ -26,539 +26,1513 @@ #define EXPAND3(...) __VA_ARGS__ #define EXTRACT(...) EXTRACT __VA_ARGS__ #define NOTHING_EXTRACT -#define PASTE(x, ...) x ## __VA_ARGS__ -#define PASTE2(x, ...) x ## __VA_ARGS__ -#define PASTE3(x, ...) x ## __VA_ARGS__ +#define PASTE(x, ...) x##__VA_ARGS__ +#define PASTE2(x, ...) x##__VA_ARGS__ +#define PASTE3(x, ...) x##__VA_ARGS__ #define EVALUATING_PASTE(x, ...) PASTE(x, __VA_ARGS__) #define EVALUATING_PASTE2(x, ...) PASTE2(x, __VA_ARGS__) #define EVALUATING_PASTE3(x, ...) PASTE3(x, __VA_ARGS__) #define UNPAREN(x) EVALUATING_PASTE(NOTHING_, EXTRACT x) #define UNPAREN2(x) EVALUATING_PASTE2(NOTHING_, EXTRACT x) #define UNPAREN3(x) EVALUATING_PASTE3(NOTHING_, EXTRACT x) -#define EVAL( x ) x -#define EVALX( x ) x -#define EVAL0(...) EVAL1(EVAL1(EVAL1(__VA_ARGS__))) +#define EVAL(x) x +#define EVALX(x) x +#define EVAL0(...) EVAL1(EVAL1(EVAL1(__VA_ARGS__))) #define EVAL1(...) EVAL2(EVAL2(EVAL2(__VA_ARGS__))) #define EVAL2(...) EVAL3(EVAL3(EVAL3(__VA_ARGS__))) #define EVAL3(...) EVAL4(EVAL4(EVAL4(__VA_ARGS__))) #define EVAL4(...) EVAL5(EVAL5(EVAL5(__VA_ARGS__))) #define EVAL5(...) __VA_ARGS__ - #define SEL_T_1(WHAT, NAME, SIGNATURE, TYPE_A) WHAT(NAME, SIGNATURE, TYPE_A) -#define SEL_T_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define SEL_T_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - - -#define SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) -#define SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT1_20(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - - -#define SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) -#define SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P1_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - -#define SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) -#define SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_P2_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - - - - -#define SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) -#define SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define SEL_TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - +#define SEL_T_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define SEL_T_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) \ + EVAL(SEL_T_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +#define SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +#define SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT1_20(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) + +#define SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +#define SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P1_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) + +#define SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +#define SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define SEL_P2_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + ...) \ + WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) + +#define SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +#define SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define SEL_TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) #define DS_1(WHAT, NAME, SIGNATURE, TYPE_A) WHAT(NAME, SIGNATURE, TYPE_A) -#define DS_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DS_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - +#define DS_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DS_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) #define DP_1(WHAT, NAME, SIGNATURE, TYPE_A) WHAT(NAME, SIGNATURE, TYPE_A) -#define DP_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_21(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_20(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_22(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_21(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_23(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_22(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_24(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_23(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_25(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_24(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_26(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_25(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_27(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_26(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_28(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_27(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_29(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_28(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_30(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_29(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_31(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_30(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_32(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_31(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_33(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_32(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_34(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_33(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_35(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_34(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_36(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_35(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_37(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_36(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_38(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_37(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_39(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_38(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_40(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_39(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_41(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_40(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_42(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_41(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_43(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_42(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_44(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_43(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_45(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_44(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_46(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_45(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_47(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_46(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_48(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_47(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_49(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_48(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_50(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_49(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_51(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_50(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_52(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_51(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_53(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_52(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_54(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_53(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_55(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_54(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_56(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_55(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_57(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_56(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_58(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_57(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_59(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_58(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_60(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_59(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_61(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_60(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_62(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_61(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_63(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_62(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_64(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_63(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_65(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_64(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_66(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_65(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_67(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_66(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_68(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_67(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_69(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_68(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_70(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_69(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_71(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_70(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_72(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_71(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_73(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_72(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_74(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_73(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_75(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_74(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_76(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_75(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_77(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_76(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_78(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_77(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_79(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_78(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_80(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_79(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_81(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_80(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_82(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_81(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_83(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_82(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_84(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_83(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_85(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_84(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_86(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_85(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_87(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_86(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_88(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_87(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_89(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_88(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_90(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_89(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_91(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_90(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_92(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_91(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_93(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_92(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_94(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_93(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_95(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_94(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_96(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_95(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_97(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_96(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_98(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_97(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_99(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_98(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_100(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_99(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_101(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_100(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_102(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_101(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_103(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_102(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_104(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_103(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_105(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_104(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_106(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_105(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_107(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_106(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_108(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_107(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_109(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_108(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_110(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_109(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_111(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_110(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_112(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_111(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_113(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_112(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_114(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_113(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_115(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_114(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_116(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_115(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_117(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_116(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_118(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_117(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_119(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_118(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_120(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_119(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_121(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_120(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_122(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_121(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_123(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_122(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_124(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_123(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define DP_125(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_124(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - - -#define DT_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) -#define DT_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - -#define DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) -#define DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define DT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - -#define TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -#define TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT1_20(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - -#define TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -#define TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT2_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - - -#define TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -#define TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TTT3_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - - - -#define TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -#define TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT1_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - -#define TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -#define TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - -#define TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -#define TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -#define TT3_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - - -#define GET_MACRO_SEL_T(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_SEL_P1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_SEL_P2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_SEL_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_SEL_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_DS(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_DT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_DP(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, _65, _66, _67, _68, _69, _70, _71, _72, _73, _74, _75, _76, _77, _78, _79, _80, _81, _82, _83, _84, _85, _86, _87, _88, _89, _90, _91, _92, _93, _94, _95, _96, _97, _98, _99, _100, _101, _102, _103, _104, _105, _106, _107, _108, _109, _110, _111, _112, _113, _114, _115, _116, _117, _118, _119, _120, _121, _122, _123, _124, _125, NAME,...) NAME -#define GET_MACRO_DT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME - - -#define GET_MACRO_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_TT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME - -#define GET_MACRO_TTT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_TTT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -#define GET_MACRO_TTT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME - -#define FOR_EACH_S1(WHAT, NAME, SIGNATURE, ...) EXPAND(GET_MACRO_SEL_T(__VA_ARGS__, SEL_T_20, SEL_T_19, SEL_T_18, SEL_T_17, SEL_T_16, SEL_T_15, SEL_T_14, SEL_T_13, SEL_T_12, SEL_T_11, SEL_T_10, SEL_T_9, SEL_T_8, SEL_T_7, SEL_T_6, SEL_T_5, SEL_T_4, SEL_T_3, SEL_T_2, SEL_T_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) EXPAND(GET_MACRO_SEL_TT1(__VA_ARGS__, SEL_TT1_20, SEL_TT1_19, SEL_TT1_18, SEL_TT1_17, SEL_TT1_16, SEL_TT1_15, SEL_TT1_14, SEL_TT1_13, SEL_TT1_12, SEL_TT1_11, SEL_TT1_10, SEL_TT1_9, SEL_TT1_8, SEL_TT1_7, SEL_TT1_6, SEL_TT1_5, SEL_TT1_4, SEL_TT1_3, SEL_TT1_2, SEL_TT1_1)(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -#define FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) EXPAND(GET_MACRO_SEL_P1(__VA_ARGS__, SEL_P1_20, SEL_P1_19, SEL_P1_18, SEL_P1_17, SEL_P1_16, SEL_P1_15, SEL_P1_14, SEL_P1_13, SEL_P1_12, SEL_P1_11, SEL_P1_10, SEL_P1_9, SEL_P1_8, SEL_P1_7, SEL_P1_6, SEL_P1_5, SEL_P1_4, SEL_P1_3, SEL_P1_2, SEL_P1_1)(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -#define FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) EXPAND2(GET_MACRO_SEL_P2(__VA_ARGS__, SEL_P2_20, SEL_P2_19, SEL_P2_18, SEL_P2_17, SEL_P2_16, SEL_P2_15, SEL_P2_14, SEL_P2_13, SEL_P2_12, SEL_P2_11, SEL_P2_10, SEL_P2_9, SEL_P2_8, SEL_P2_7, SEL_P2_6, SEL_P2_5, SEL_P2_4, SEL_P2_3, SEL_P2_2, SEL_P2_1)(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -#define FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, ...) EXPAND(GET_MACRO_SEL_TT2(__VA_ARGS__, SEL_TT2_20, SEL_TT2_19, SEL_TT2_18, SEL_TT2_17, SEL_TT2_16, SEL_TT2_15, SEL_TT2_14, SEL_TT2_13, SEL_TT2_12, SEL_TT2_11, SEL_TT2_10, SEL_TT2_9, SEL_TT2_8, SEL_TT2_7, SEL_TT2_6, SEL_TT2_5, SEL_TT2_4, SEL_TT2_3, SEL_TT2_2, SEL_TT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define FOR_EACH_DS(WHAT, NAME, SIGNATURE, ...) EXPAND(GET_MACRO_DS(__VA_ARGS__, DS_20, DS_19, DS_18, DS_17, DS_16, DS_15, DS_14, DS_13, DS_12, DS_11, DS_10, DS_9, DS_8, DS_7, DS_6, DS_5, DS_4, DS_3, DS_2, DS_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define FOR_EACH_DT(WHAT, NAME, SIGNATURE, TYPES_A, ...) EXPAND(GET_MACRO_DT(__VA_ARGS__, DT_20, DT_19, DT_18, DT_17, DT_16, DT_15, DT_14, DT_13, DT_12, DT_11, DT_10, DT_9, DT_8, DT_7, DT_6, DT_5, DT_4, DT_3, DT_2, DT_1)(WHAT, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -#define FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, ...) EXPAND(GET_MACRO_DT2(__VA_ARGS__, DT2_20, DT2_19, DT2_18, DT2_17, DT2_16, DT2_15, DT2_14, DT2_13, DT2_12, DT2_11, DT2_10, DT2_9, DT2_8, DT2_7, DT2_6, DT2_5, DT2_4, DT2_3, DT2_2, DT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define FOR_EACH_DP(WHAT, NAME, SIGNATURE, ...) EXPAND(GET_MACRO_DP(__VA_ARGS__, DP_125, DP_124, DP_123, DP_122, DP_121, DP_120, DP_119, DP_118, DP_117, DP_116, DP_115, DP_114, DP_113, DP_112, DP_111, DP_110, DP_109, DP_108, DP_107, DP_106, DP_105, DP_104, DP_103, DP_102, DP_101, DP_100, DP_99, DP_98, DP_97, DP_96, DP_95, DP_94, DP_93, DP_92, DP_91, DP_90, DP_89, DP_88, DP_87, DP_86, DP_85, DP_84, DP_83, DP_82, DP_81, DP_80, DP_79, DP_78, DP_77, DP_76, DP_75, DP_74, DP_73, DP_72, DP_71, DP_70, DP_69, DP_68, DP_67, DP_66, DP_65, DP_64, DP_63, DP_62, DP_61, DP_60, DP_59, DP_58, DP_57, DP_56, DP_55, DP_54, DP_53, DP_52, DP_51, DP_50, DP_49, DP_48, DP_47, DP_46, DP_45, DP_44, DP_43, DP_42, DP_41, DP_40, DP_39, DP_38, DP_37, DP_36, DP_35, DP_34, DP_33, DP_32, DP_31, DP_30, DP_29, DP_28, DP_27, DP_26, DP_25, DP_24, DP_23, DP_22, DP_21, DP_20, DP_19, DP_18, DP_17, DP_16, DP_15, DP_14, DP_13, DP_12, DP_11, DP_10, DP_9, DP_8, DP_7, DP_6, DP_5, DP_4, DP_3, DP_2, DP_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - - -#define FOR_EACH_TT1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) EXPAND(GET_MACRO_TT1(__VA_ARGS__, TT1_20, TT1_19, TT1_18, TT1_17, TT1_16, TT1_15, TT1_14, TT1_13, TT1_12, TT1_11, TT1_10, TT1_9, TT1_8, TT1_7, TT1_6, TT1_5, TT1_4, TT1_3, TT1_2, TT1_1)(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, __VA_ARGS__)) -#define FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) EXPAND(GET_MACRO_TT2(__VA_ARGS__, TT2_20, TT2_19, TT2_18, TT2_17, TT2_16, TT2_15, TT2_14, TT2_13, TT2_12, TT2_11, TT2_10, TT2_9, TT2_8, TT2_7, TT2_6, TT2_5, TT2_4, TT2_3, TT2_2, TT2_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, __VA_ARGS__)) -#define FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) EXPAND(GET_MACRO_TT3(__VA_ARGS__, TT3_20, TT3_19, TT3_18, TT3_17, TT3_16, TT3_15, TT3_14, TT3_13, TT3_12, TT3_11, TT3_10, TT3_9, TT3_8, TT3_7, TT3_6, TT3_5, TT3_4, TT3_3, TT3_2, TT3_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, __VA_ARGS__)) - -#define FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, ...) EXPAND(GET_MACRO_TTT1(__VA_ARGS__, TTT1_20, TTT1_19, TTT1_18, TTT1_17, TTT1_16, TTT1_15, TTT1_14, TTT1_13, TTT1_12, TTT1_11, TTT1_10, TTT1_9, TTT1_8, TTT1_7, TTT1_6, TTT1_5, TTT1_4, TTT1_3, TTT1_2, TTT1_1)(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, __VA_ARGS__)) -#define FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) EXPAND(GET_MACRO_TTT2(__VA_ARGS__, TTT2_20, TTT2_19, TTT2_18, TTT2_17, TTT2_16, TTT2_15, TTT2_14, TTT2_13, TTT2_12, TTT2_11, TTT2_10, TTT2_9, TTT2_8, TTT2_7, TTT2_6, TTT2_5, TTT2_4, TTT2_3, TTT2_2, TTT2_1)(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) -#define FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) EXPAND(GET_MACRO_TTT3(__VA_ARGS__, TTT3_20, TTT3_19, TTT3_18, TTT3_17, TTT3_16, TTT3_15, TTT3_14, TTT3_13, TTT3_12, TTT3_11, TTT3_10, TTT3_9, TTT3_8, TTT3_7, TTT3_6, TTT3_5, TTT3_4, TTT3_3, TTT3_2, TTT3_1)(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) - -#define _EXEC_SELECTOR_T(WHAT, NAME, SIGNATURE, ...) EVAL(FOR_EACH_S1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define _EXEC_SELECTOR_P_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) EVAL(FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -#define _EXEC_SELECTOR_P_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, ...) EVAL(FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define _EXEC_SELECTOR_TT_1(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) EVAL(FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -#define _EXEC_SELECTOR_TT_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) EVAL(FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define _EXEC_SINGLE_T(WHAT, NAME, SIGNATURE, ...) EVAL(FOR_EACH_DS(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -#define _EXEC_DOUBLE_T(WHAT, NAME, SIGNATURE, TYPES_A, ...) EVAL(FOR_EACH_DT(WHAT, NAME, SIGNATURE, LIST(TYPES_A), __VA_ARGS__)) -#define _EXEC_DOUBLE_T2(WHAT, NAME, SIGNATURE, TYPE_A, ...) EVAL(FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -#define _EXEC_DOUBLE_P(WHAT, NAME, SIGNATURE, ...) EVAL(FOR_EACH_DP(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - -#define _EXEC_SELECTOR_TTT_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, ...) EVAL(FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, __VA_ARGS__)) -#define _EXEC_SELECTOR_TTT_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) EVAL(FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) -#define _EXEC_SELECTOR_TTT_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) EVAL(FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) - -#define _EXEC_TRIPLE_T3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) EVAL(FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, __VA_ARGS__)) -#define _EXEC_TRIPLE_T2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) EVAL(FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, LIST(TYPES_X), __VA_ARGS__)) -#define _EXEC_TRIPLE_T1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) EVAL(FOR_EACH_TT1(WHAT, NAME, SIGNATURE, LIST(TYPES_X), LIST(TYPES_Y), __VA_ARGS__)) - -#define DISPATCH_PAIRWISE(NAME, SIGNATURE, TYPE, TYPES_B) EVAL(_EXEC_DOUBLE_T2(RANDOMPAIRWISE2, NAME, SIGNATURE, TYPE, TYPES_B)) -#define DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE, ...) EVAL(_EXEC_SELECTOR_P_2(SELECTOR_PAIRWISE_2, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE, __VA_ARGS__)) - -#define DISPATCH_DTYPES(NAME, SIGNATURE, TYPE, TYPES_B) EVAL(_EXEC_DOUBLE_T2(RANDOMDOUBLE2, NAME, SIGNATURE, TYPE, TYPES_B)) -#define DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE, ...) EVAL(_EXEC_SELECTOR_TT_2(SELECTOR_DOUBLE_2, NAME, SIGNATURE, TYPE, __VA_ARGS__)) - -#define DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) EVAL(_EXEC_SELECTOR_TTT_2(SELECTOR_TRIPLE_2, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) -#define DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) EVAL(_EXEC_SELECTOR_TTT_3(SELECTOR_TRIPLE_3, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) - +#define DP_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_21(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_20(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_22(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_21(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_23(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_22(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_24(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_23(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_25(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_24(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_26(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_25(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_27(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_26(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_28(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_27(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_29(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_28(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_30(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_29(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_31(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_30(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_32(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_31(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_33(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_32(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_34(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_33(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_35(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_34(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_36(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_35(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_37(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_36(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_38(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_37(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_39(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_38(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_40(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_39(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_41(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_40(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_42(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_41(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_43(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_42(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_44(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_43(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_45(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_44(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_46(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_45(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_47(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_46(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_48(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_47(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_49(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_48(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_50(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_49(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_51(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_50(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_52(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_51(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_53(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_52(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_54(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_53(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_55(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_54(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_56(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_55(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_57(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_56(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_58(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_57(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_59(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_58(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_60(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_59(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_61(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_60(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_62(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_61(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_63(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_62(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_64(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_63(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_65(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_64(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_66(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_65(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_67(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_66(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_68(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_67(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_69(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_68(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_70(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_69(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_71(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_70(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_72(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_71(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_73(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_72(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_74(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_73(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_75(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_74(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_76(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_75(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_77(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_76(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_78(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_77(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_79(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_78(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_80(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_79(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_81(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_80(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_82(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_81(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_83(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_82(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_84(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_83(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_85(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_84(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_86(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_85(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_87(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_86(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_88(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_87(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_89(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_88(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_90(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_89(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_91(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_90(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_92(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_91(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_93(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_92(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_94(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_93(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_95(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_94(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_96(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_95(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_97(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_96(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_98(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_97(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_99(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_98(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_100(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_99(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_101(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_100(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_102(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_101(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_103(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_102(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_104(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_103(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_105(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_104(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_106(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_105(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_107(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_106(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_108(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_107(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_109(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_108(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_110(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_109(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_111(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_110(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_112(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_111(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_113(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_112(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_114(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_113(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_115(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_114(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_116(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_115(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_117(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_116(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_118(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_117(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_119(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_118(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_120(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_119(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_121(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_120(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_122(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_121(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_123(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_122(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_124(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_123(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define DP_125(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_124(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +#define DT_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +#define DT_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) + +#define DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +#define DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define DT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVAL(DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) + +#define TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +#define TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) +#define TTT1_20(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, \ + ...) \ + WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, \ + __VA_ARGS__)) + +#define TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +#define TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT2_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +#define TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +#define TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TTT3_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +#define TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +#define TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT1_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +#define TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +#define TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +#define TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +#define TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +#define TT3_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) \ + WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) \ + EVAL(TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +#define GET_MACRO_SEL_T(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, \ + _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_SEL_P1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, \ + _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_SEL_P2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, \ + _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_SEL_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, \ + _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_SEL_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, \ + _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_DS(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_DT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_DP( \ + _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, \ + _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, \ + _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, \ + _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, \ + _62, _63, _64, _65, _66, _67, _68, _69, _70, _71, _72, _73, _74, _75, _76, \ + _77, _78, _79, _80, _81, _82, _83, _84, _85, _86, _87, _88, _89, _90, _91, \ + _92, _93, _94, _95, _96, _97, _98, _99, _100, _101, _102, _103, _104, \ + _105, _106, _107, _108, _109, _110, _111, _112, _113, _114, _115, _116, \ + _117, _118, _119, _120, _121, _122, _123, _124, _125, NAME, ...) \ + NAME +#define GET_MACRO_DT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME + +#define GET_MACRO_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_TT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME + +#define GET_MACRO_TTT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_TTT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME +#define GET_MACRO_TTT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, NAME, ...) \ + NAME + +#define FOR_EACH_S1(WHAT, NAME, SIGNATURE, ...) \ + EXPAND(GET_MACRO_SEL_T(__VA_ARGS__, SEL_T_20, SEL_T_19, SEL_T_18, SEL_T_17, \ + SEL_T_16, SEL_T_15, SEL_T_14, SEL_T_13, SEL_T_12, \ + SEL_T_11, SEL_T_10, SEL_T_9, SEL_T_8, SEL_T_7, \ + SEL_T_6, SEL_T_5, SEL_T_4, SEL_T_3, SEL_T_2, \ + SEL_T_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) \ + EXPAND(GET_MACRO_SEL_TT1( \ + __VA_ARGS__, SEL_TT1_20, SEL_TT1_19, SEL_TT1_18, SEL_TT1_17, SEL_TT1_16, \ + SEL_TT1_15, SEL_TT1_14, SEL_TT1_13, SEL_TT1_12, SEL_TT1_11, SEL_TT1_10, \ + SEL_TT1_9, SEL_TT1_8, SEL_TT1_7, SEL_TT1_6, SEL_TT1_5, SEL_TT1_4, \ + SEL_TT1_3, SEL_TT1_2, \ + SEL_TT1_1)(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +#define FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) \ + EXPAND(GET_MACRO_SEL_P1(__VA_ARGS__, SEL_P1_20, SEL_P1_19, SEL_P1_18, \ + SEL_P1_17, SEL_P1_16, SEL_P1_15, SEL_P1_14, \ + SEL_P1_13, SEL_P1_12, SEL_P1_11, SEL_P1_10, \ + SEL_P1_9, SEL_P1_8, SEL_P1_7, SEL_P1_6, SEL_P1_5, \ + SEL_P1_4, SEL_P1_3, SEL_P1_2, SEL_P1_1)( \ + WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +#define FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) \ + EXPAND2(GET_MACRO_SEL_P2(__VA_ARGS__, SEL_P2_20, SEL_P2_19, SEL_P2_18, \ + SEL_P2_17, SEL_P2_16, SEL_P2_15, SEL_P2_14, \ + SEL_P2_13, SEL_P2_12, SEL_P2_11, SEL_P2_10, \ + SEL_P2_9, SEL_P2_8, SEL_P2_7, SEL_P2_6, SEL_P2_5, \ + SEL_P2_4, SEL_P2_3, SEL_P2_2, SEL_P2_1)( \ + WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +#define FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + EXPAND(GET_MACRO_SEL_TT2( \ + __VA_ARGS__, SEL_TT2_20, SEL_TT2_19, SEL_TT2_18, SEL_TT2_17, SEL_TT2_16, \ + SEL_TT2_15, SEL_TT2_14, SEL_TT2_13, SEL_TT2_12, SEL_TT2_11, SEL_TT2_10, \ + SEL_TT2_9, SEL_TT2_8, SEL_TT2_7, SEL_TT2_6, SEL_TT2_5, SEL_TT2_4, \ + SEL_TT2_3, SEL_TT2_2, \ + SEL_TT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define FOR_EACH_DS(WHAT, NAME, SIGNATURE, ...) \ + EXPAND(GET_MACRO_DS(__VA_ARGS__, DS_20, DS_19, DS_18, DS_17, DS_16, DS_15, \ + DS_14, DS_13, DS_12, DS_11, DS_10, DS_9, DS_8, DS_7, \ + DS_6, DS_5, DS_4, DS_3, DS_2, \ + DS_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define FOR_EACH_DT(WHAT, NAME, SIGNATURE, TYPES_A, ...) \ + EXPAND(GET_MACRO_DT(__VA_ARGS__, DT_20, DT_19, DT_18, DT_17, DT_16, DT_15, \ + DT_14, DT_13, DT_12, DT_11, DT_10, DT_9, DT_8, DT_7, \ + DT_6, DT_5, DT_4, DT_3, DT_2, \ + DT_1)(WHAT, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +#define FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + EXPAND(GET_MACRO_DT2(__VA_ARGS__, DT2_20, DT2_19, DT2_18, DT2_17, DT2_16, \ + DT2_15, DT2_14, DT2_13, DT2_12, DT2_11, DT2_10, DT2_9, \ + DT2_8, DT2_7, DT2_6, DT2_5, DT2_4, DT2_3, DT2_2, \ + DT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define FOR_EACH_DP(WHAT, NAME, SIGNATURE, ...) \ + EXPAND(GET_MACRO_DP( \ + __VA_ARGS__, DP_125, DP_124, DP_123, DP_122, DP_121, DP_120, DP_119, \ + DP_118, DP_117, DP_116, DP_115, DP_114, DP_113, DP_112, DP_111, DP_110, \ + DP_109, DP_108, DP_107, DP_106, DP_105, DP_104, DP_103, DP_102, DP_101, \ + DP_100, DP_99, DP_98, DP_97, DP_96, DP_95, DP_94, DP_93, DP_92, DP_91, \ + DP_90, DP_89, DP_88, DP_87, DP_86, DP_85, DP_84, DP_83, DP_82, DP_81, \ + DP_80, DP_79, DP_78, DP_77, DP_76, DP_75, DP_74, DP_73, DP_72, DP_71, \ + DP_70, DP_69, DP_68, DP_67, DP_66, DP_65, DP_64, DP_63, DP_62, DP_61, \ + DP_60, DP_59, DP_58, DP_57, DP_56, DP_55, DP_54, DP_53, DP_52, DP_51, \ + DP_50, DP_49, DP_48, DP_47, DP_46, DP_45, DP_44, DP_43, DP_42, DP_41, \ + DP_40, DP_39, DP_38, DP_37, DP_36, DP_35, DP_34, DP_33, DP_32, DP_31, \ + DP_30, DP_29, DP_28, DP_27, DP_26, DP_25, DP_24, DP_23, DP_22, DP_21, \ + DP_20, DP_19, DP_18, DP_17, DP_16, DP_15, DP_14, DP_13, DP_12, DP_11, \ + DP_10, DP_9, DP_8, DP_7, DP_6, DP_5, DP_4, DP_3, DP_2, \ + DP_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +#define FOR_EACH_TT1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) \ + EXPAND(GET_MACRO_TT1(__VA_ARGS__, TT1_20, TT1_19, TT1_18, TT1_17, TT1_16, \ + TT1_15, TT1_14, TT1_13, TT1_12, TT1_11, TT1_10, TT1_9, \ + TT1_8, TT1_7, TT1_6, TT1_5, TT1_4, TT1_3, TT1_2, \ + TT1_1)(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, \ + __VA_ARGS__)) +#define FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) \ + EXPAND(GET_MACRO_TT2(__VA_ARGS__, TT2_20, TT2_19, TT2_18, TT2_17, TT2_16, \ + TT2_15, TT2_14, TT2_13, TT2_12, TT2_11, TT2_10, TT2_9, \ + TT2_8, TT2_7, TT2_6, TT2_5, TT2_4, TT2_3, TT2_2, \ + TT2_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, \ + __VA_ARGS__)) +#define FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) \ + EXPAND(GET_MACRO_TT3(__VA_ARGS__, TT3_20, TT3_19, TT3_18, TT3_17, TT3_16, \ + TT3_15, TT3_14, TT3_13, TT3_12, TT3_11, TT3_10, TT3_9, \ + TT3_8, TT3_7, TT3_6, TT3_5, TT3_4, TT3_3, TT3_2, \ + TT3_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, \ + __VA_ARGS__)) + +#define FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, \ + ...) \ + EXPAND(GET_MACRO_TTT1(__VA_ARGS__, TTT1_20, TTT1_19, TTT1_18, TTT1_17, \ + TTT1_16, TTT1_15, TTT1_14, TTT1_13, TTT1_12, TTT1_11, \ + TTT1_10, TTT1_9, TTT1_8, TTT1_7, TTT1_6, TTT1_5, \ + TTT1_4, TTT1_3, TTT1_2, TTT1_1)( \ + WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, __VA_ARGS__)) +#define FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) \ + EXPAND(GET_MACRO_TTT2(__VA_ARGS__, TTT2_20, TTT2_19, TTT2_18, TTT2_17, \ + TTT2_16, TTT2_15, TTT2_14, TTT2_13, TTT2_12, TTT2_11, \ + TTT2_10, TTT2_9, TTT2_8, TTT2_7, TTT2_6, TTT2_5, \ + TTT2_4, TTT2_3, TTT2_2, TTT2_1)( \ + WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) +#define FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) \ + EXPAND(GET_MACRO_TTT3(__VA_ARGS__, TTT3_20, TTT3_19, TTT3_18, TTT3_17, \ + TTT3_16, TTT3_15, TTT3_14, TTT3_13, TTT3_12, TTT3_11, \ + TTT3_10, TTT3_9, TTT3_8, TTT3_7, TTT3_6, TTT3_5, \ + TTT3_4, TTT3_3, TTT3_2, TTT3_1)( \ + WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) + +#define _EXEC_SELECTOR_T(WHAT, NAME, SIGNATURE, ...) \ + EVAL(FOR_EACH_S1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define _EXEC_SELECTOR_P_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, \ + TYPES_A, ...) \ + EVAL(FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, \ + __VA_ARGS__)) +#define _EXEC_SELECTOR_P_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + ...) \ + EVAL(FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)) +#define _EXEC_SELECTOR_TT_1(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) \ + EVAL(FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +#define _EXEC_SELECTOR_TT_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + EVAL(FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define _EXEC_SINGLE_T(WHAT, NAME, SIGNATURE, ...) \ + EVAL(FOR_EACH_DS(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +#define _EXEC_DOUBLE_T(WHAT, NAME, SIGNATURE, TYPES_A, ...) \ + EVAL(FOR_EACH_DT(WHAT, NAME, SIGNATURE, LIST(TYPES_A), __VA_ARGS__)) +#define _EXEC_DOUBLE_T2(WHAT, NAME, SIGNATURE, TYPE_A, ...) \ + EVAL(FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +#define _EXEC_DOUBLE_P(WHAT, NAME, SIGNATURE, ...) \ + EVAL(FOR_EACH_DP(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +#define _EXEC_SELECTOR_TTT_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, \ + TYPES_Y, ...) \ + EVAL(FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, \ + __VA_ARGS__)) +#define _EXEC_SELECTOR_TTT_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, \ + ...) \ + EVAL(FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, \ + __VA_ARGS__)) +#define _EXEC_SELECTOR_TTT_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, \ + ...) \ + EVAL(FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) + +#define _EXEC_TRIPLE_T3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) \ + EVAL(FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, __VA_ARGS__)) +#define _EXEC_TRIPLE_T2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) \ + EVAL(FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, LIST(TYPES_X), __VA_ARGS__)) +#define _EXEC_TRIPLE_T1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) \ + EVAL(FOR_EACH_TT1(WHAT, NAME, SIGNATURE, LIST(TYPES_X), LIST(TYPES_Y), \ + __VA_ARGS__)) + +#define DISPATCH_PAIRWISE(NAME, SIGNATURE, TYPE, TYPES_B) \ + EVAL(_EXEC_DOUBLE_T2(RANDOMPAIRWISE2, NAME, SIGNATURE, TYPE, TYPES_B)) +#define DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE, ...) \ + EVAL(_EXEC_SELECTOR_P_2(SELECTOR_PAIRWISE_2, XTYPE, YTYPE, ZTYPE, NAME, \ + SIGNATURE, TYPE, __VA_ARGS__)) + +#define DISPATCH_DTYPES(NAME, SIGNATURE, TYPE, TYPES_B) \ + EVAL(_EXEC_DOUBLE_T2(RANDOMDOUBLE2, NAME, SIGNATURE, TYPE, TYPES_B)) +#define DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE, ...) \ + EVAL(_EXEC_SELECTOR_TT_2(SELECTOR_DOUBLE_2, NAME, SIGNATURE, TYPE, \ + __VA_ARGS__)) + +#define DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) \ + EVAL(_EXEC_SELECTOR_TTT_2(SELECTOR_TRIPLE_2, ZTYPE, NAME, SIGNATURE, TYPE_X, \ + TYPES_Z, __VA_ARGS__)) +#define DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) \ + EVAL(_EXEC_SELECTOR_TTT_3(SELECTOR_TRIPLE_3, ZTYPE, NAME, SIGNATURE, TYPE_X, \ + TYPE_Y, __VA_ARGS__)) #ifndef __CLION_IDE__ -#define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLEU, NAME, (SIGNATURE), TYPES)) -#define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLE, NAME, (SIGNATURE), TYPES)) -#define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SELECTOR_T(TEMPLATE_SINGLE_TWICE, NAME, SIGNATURE, TYPES)) -#define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) EVAL(_EXEC_DOUBLE_T(RANDOMDOUBLE, NAME, (SIGNATURE), (TYPES_A), TYPES_B)) -#define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} -#define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_TWICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} -#define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_THRICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} - - -#define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_PARTIAL_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); }} -#define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TT_1(SELECTOR_DOUBLE, YTYPE, NAME, (SIGNATURE), (TYPES_B), TYPES_A)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} -#define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TTT_1(SELECTOR_TRIPLE, YTYPE, ZTYPE, NAME, SIGNATURE, (TYPES_Z), (TYPES_Y), TYPES_X)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); } } -#define BUILD_TRIPLE_TEMPLATE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) EVAL(_EXEC_TRIPLE_T1(RANDOMTRIPLE, NAME, (SIGNATURE), (TYPES_X), (TYPES_Y), TYPES_Z)) -#define BUILD_PAIRWISE_TEMPLATE(NAME, SIGNATURE, TYPES_A) EVAL(_EXEC_DOUBLE_P(RANDOMPAIRWISE, NAME, (SIGNATURE), TYPES_A)) -#define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) switch(XTYPE) { EVAL(_EXEC_SELECTOR_P_1(SELECTOR_PAIRWISE, XTYPE, YTYPE, ZTYPE, NAME, (SIGNATURE), (TYPES_B), TYPES_A)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); }} +#define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) \ + EVAL(_EXEC_SINGLE_T(RANDOMSINGLEU, NAME, (SIGNATURE), TYPES)) +#define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) \ + EVAL(_EXEC_SINGLE_T(RANDOMSINGLE, NAME, (SIGNATURE), TYPES)) +#define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) \ + EVAL(_EXEC_SELECTOR_T(TEMPLATE_SINGLE_TWICE, NAME, SIGNATURE, TYPES)) +#define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) \ + EVAL(_EXEC_DOUBLE_T(RANDOMDOUBLE, NAME, (SIGNATURE), (TYPES_A), TYPES_B)) +#define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) \ + switch (XTYPE) { \ + EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE, NAME, SIGNATURE, TYPES)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("bad data type"); \ + } \ + } +#define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) \ + switch (XTYPE) { \ + EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_TWICE, NAME, SIGNATURE, TYPES)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("bad data type"); \ + } \ + } +#define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) \ + switch (XTYPE) { \ + EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_THRICE, NAME, SIGNATURE, TYPES)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("bad data type"); \ + } \ + } + +#define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) \ + switch (XTYPE) { \ + EVAL(_EXEC_SELECTOR_T(SELECTOR_PARTIAL_SINGLE, NAME, SIGNATURE, TYPES)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("bad data type"); \ + } \ + } +#define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) \ + switch (XTYPE) { \ + EVAL(_EXEC_SELECTOR_TT_1(SELECTOR_DOUBLE, YTYPE, NAME, (SIGNATURE), \ + (TYPES_B), TYPES_A)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("bad data type"); \ + } \ + } +#define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, \ + TYPES_Y, TYPES_Z) \ + switch (XTYPE) { \ + EVAL(_EXEC_SELECTOR_TTT_1(SELECTOR_TRIPLE, YTYPE, ZTYPE, NAME, SIGNATURE, \ + (TYPES_Z), (TYPES_Y), TYPES_X)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("bad data type"); \ + } \ + } +#define BUILD_TRIPLE_TEMPLATE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) \ + EVAL(_EXEC_TRIPLE_T1(RANDOMTRIPLE, NAME, (SIGNATURE), (TYPES_X), (TYPES_Y), \ + TYPES_Z)) +#define BUILD_PAIRWISE_TEMPLATE(NAME, SIGNATURE, TYPES_A) \ + EVAL(_EXEC_DOUBLE_P(RANDOMPAIRWISE, NAME, (SIGNATURE), TYPES_A)) +#define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, \ + TYPES_B) \ + switch (XTYPE) { \ + EVAL(_EXEC_SELECTOR_P_1(SELECTOR_PAIRWISE, XTYPE, YTYPE, ZTYPE, NAME, \ + (SIGNATURE), (TYPES_B), TYPES_A)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("bad data type"); \ + } \ + } #else #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) @@ -569,74 +1543,210 @@ #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) -#define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) +#define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, \ + TYPES_Y, TYPES_Z) #define BUILD_TRIPLE_TEMPLATE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) #define BUILD_PAIRWISE_TEMPLATE(NAME, SIGNATURE, TYPES_A) -#define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) +#define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, \ + TYPES_B) #endif #define LIST(...) __VA_ARGS__ -#define _SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, ENUM, TYPE_B) case ENUM: { NAME SIGNATURE; break; }; -#define SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, TYPE_B) EVALUATING_PASTE2(_SELECT, OR_DOUBLE_2(NAME, UNPAREN3(SIGNATURE), TYPE_A, UNPAREN3(TYPE_B))) - -#define _SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, ENUM, TYPE_A, ...) case ENUM: { switch(YTYPE) { EXPAND(DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE_A, __VA_ARGS__)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); fflush(stdout);}}; break; }; -#define SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, TYPES_B, TYPE_A) EVALUATING_PASTE(_SELECTOR, _DOUBLE(YTYPE, NAME, SIGNATURE, UNPAREN(TYPE_A), UNPAREN(TYPES_B))) - -#define _SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, ENUM, TYPE_B) case ENUM: { if (ZTYPE == YTYPE) {NAME SIGNATURE;} else if (XTYPE == ZTYPE ){NAME SIGNATURE;} else {printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("Unknown Z operand");}; break; }; -#define SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) EVALUATING_PASTE2(_SELECT, OR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, UNPAREN3(SIGNATURE), TYPE_A, UNPAREN3(TYPE_B))) -#define _SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, ENUM, TYPE_A, ...) case ENUM: { switch(YTYPE) { EXPAND(DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); fflush(stdout);}}; break; }; -#define SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_B, TYPE_A) EVALUATING_PASTE(_SELECTOR, _PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, UNPAREN(TYPE_A), UNPAREN(TYPES_B))) - -#define _SELECTOR_TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, ENUM_Z, TYPE_Z) case ENUM_Z: { NAMESIGNATURE;}; break; -#define SELECTOR_TRIPLE_3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, TYPE_Z) EVALUATING_PASTE3(_SELECTOR, _TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, UNPAREN3(TYPE_Z))) -#define _SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, ENUM_Y, TYPE_Y, TYPES_Z) case ENUM_Y: { switch (ZTYPE) { EXPAND2(DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, UNPAREN3(TYPES_Z))); default: {printf("[ERROR] Unknown dtypeZ=%d on %s:%d\n", ZTYPE, __FILE__, __LINE__); ; fflush(stdout);} } break; }; -#define SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, TYPE_Y) EVALUATING_PASTE2(_SELECTOR, _TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, UNPAREN2(TYPE_Y), TYPES_Z)) -#define _SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, ENUM_X, TYPE_X, TYPES_Z, ...) case ENUM_X: { switch (YTYPE) { EXPAND(DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__ )); default: {printf("[ERROR] Unknown dtypeY=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); ; fflush(stdout);} } break; }; -#define SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, TYPE_X) EVALUATING_PASTE(_SELECTOR, _TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, UNPAREN(TYPE_X), TYPES_Z, UNPAREN(TYPES_Y))) - -#define _SELECTOR_SINGLE(A, B, C, D) case C: {AB; break;}; -#define SELECTOR_SINGLE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE(A, B, UNPAREN(C))) - -#define _SELECTOR_SINGLE_THRICE(A, B, C, D) case C: {AB; break;}; -#define SELECTOR_SINGLE_THRICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_THRICE(A, B, UNPAREN(C))) - -#define _SELECTOR_SINGLE_TWICE(A, B, C, D) case C: {AB; break;}; -#define SELECTOR_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_TWICE(A, B, UNPAREN(C))) - -#define _TEMPLATE_SINGLE_TWICE(A, B, C, D) AB; -#define TEMPLATE_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_TEM, PLATE_SINGLE_TWICE(A, B, UNPAREN(C))) - -#define _SELECTOR_PARTIAL_SINGLE(A, B, C, D) case C: {A D, UNPAREN2(B); break;}; -#define SELECTOR_PARTIAL_SINGLE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_PARTIAL_SINGLE(A, B, UNPAREN(C))) - -#define _RANDOMSINGLE(A, B, C, D) AB; +#define _SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, ENUM, TYPE_B) \ + case ENUM: { \ + NAME SIGNATURE; \ + break; \ + }; +#define SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, TYPE_B) \ + EVALUATING_PASTE2(_SELECT, OR_DOUBLE_2(NAME, UNPAREN3(SIGNATURE), TYPE_A, \ + UNPAREN3(TYPE_B))) + +#define _SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, ENUM, TYPE_A, ...) \ + case ENUM: { \ + switch (YTYPE) { \ + EXPAND(DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE_A, __VA_ARGS__)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, \ + __LINE__); \ + fflush(stdout); \ + } \ + }; \ + break; \ + }; +#define SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, TYPES_B, TYPE_A) \ + EVALUATING_PASTE(_SELECTOR, _DOUBLE(YTYPE, NAME, SIGNATURE, UNPAREN(TYPE_A), \ + UNPAREN(TYPES_B))) + +#define _SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + ENUM, TYPE_B) \ + case ENUM: { \ + if (ZTYPE == YTYPE) { \ + NAME SIGNATURE; \ + } else if (XTYPE == ZTYPE) { \ + NAME SIGNATURE; \ + } else { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, \ + __LINE__); \ + fflush(stdout); \ + throw std::runtime_error("Unknown Z operand"); \ + }; \ + break; \ + }; +#define SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + TYPE_B) \ + EVALUATING_PASTE2( \ + _SELECT, OR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, UNPAREN3(SIGNATURE), \ + TYPE_A, UNPAREN3(TYPE_B))) +#define _SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, ENUM, TYPE_A, \ + ...) \ + case ENUM: { \ + switch (YTYPE) { \ + EXPAND(DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, \ + __VA_ARGS__)); \ + default: { \ + printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, \ + __LINE__); \ + fflush(stdout); \ + } \ + }; \ + break; \ + }; +#define SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_B, \ + TYPE_A) \ + EVALUATING_PASTE(_SELECTOR, _PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, \ + UNPAREN(TYPE_A), UNPAREN(TYPES_B))) + +#define _SELECTOR_TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, ENUM_Z, TYPE_Z) \ + case ENUM_Z: { \ + NAME SIGNATURE; \ + }; break; +#define SELECTOR_TRIPLE_3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, TYPE_Z) \ + EVALUATING_PASTE3( \ + _SELECTOR, _TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, UNPAREN3(TYPE_Z))) +#define _SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, ENUM_Y, TYPE_Y, \ + TYPES_Z) \ + case ENUM_Y: { \ + switch (ZTYPE) { \ + EXPAND2(DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, \ + UNPAREN3(TYPES_Z))); \ + default: { \ + printf("[ERROR] Unknown dtypeZ=%d on %s:%d\n", ZTYPE, __FILE__, \ + __LINE__); \ + ; \ + fflush(stdout); \ + } \ + } \ + break; \ + }; +#define SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, TYPE_Y) \ + EVALUATING_PASTE2(_SELECTOR, _TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, \ + UNPAREN2(TYPE_Y), TYPES_Z)) +#define _SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, ENUM_X, TYPE_X, \ + TYPES_Z, ...) \ + case ENUM_X: { \ + switch (YTYPE) { \ + EXPAND(DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, \ + __VA_ARGS__)); \ + default: { \ + printf("[ERROR] Unknown dtypeY=%d on %s:%d\n", YTYPE, __FILE__, \ + __LINE__); \ + ; \ + fflush(stdout); \ + } \ + } \ + break; \ + }; +#define SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, \ + TYPE_X) \ + EVALUATING_PASTE(_SELECTOR, \ + _TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, UNPAREN(TYPE_X), \ + TYPES_Z, UNPAREN(TYPES_Y))) + +#define _SELECTOR_SINGLE(A, B, C, D) \ + case C: { \ + A B; \ + break; \ + }; +#define SELECTOR_SINGLE(A, B, C) \ + EVALUATING_PASTE(_SEL, ECTOR_SINGLE(A, B, UNPAREN(C))) + +#define _SELECTOR_SINGLE_THRICE(A, B, C, D) \ + case C: { \ + A B; \ + break; \ + }; +#define SELECTOR_SINGLE_THRICE(A, B, C) \ + EVALUATING_PASTE(_SEL, ECTOR_SINGLE_THRICE(A, B, UNPAREN(C))) + +#define _SELECTOR_SINGLE_TWICE(A, B, C, D) \ + case C: { \ + A B; \ + break; \ + }; +#define SELECTOR_SINGLE_TWICE(A, B, C) \ + EVALUATING_PASTE(_SEL, ECTOR_SINGLE_TWICE(A, B, UNPAREN(C))) + +#define _TEMPLATE_SINGLE_TWICE(A, B, C, D) A B; +#define TEMPLATE_SINGLE_TWICE(A, B, C) \ + EVALUATING_PASTE(_TEM, PLATE_SINGLE_TWICE(A, B, UNPAREN(C))) + +#define _SELECTOR_PARTIAL_SINGLE(A, B, C, D) \ + case C: { \ + A D, UNPAREN2(B); \ + break; \ + }; +#define SELECTOR_PARTIAL_SINGLE(A, B, C) \ + EVALUATING_PASTE(_SEL, ECTOR_PARTIAL_SINGLE(A, B, UNPAREN(C))) + +#define _RANDOMSINGLE(A, B, C, D) A B; #define _RANDOMSINGLEU(A, B, C, D) A D B; -#define RANDOMSINGLE(A, B, C) EVALUATING_PASTE(_RAND, OMSINGLE(A, UNPAREN(B), UNPAREN(C))) -#define RANDOMSINGLEU(A, B, C) EVALUATING_PASTE(_RAND, OMSINGLEU(A, UNPAREN(B), UNPAREN(C))) -#define RANDOMDOUBLE(A, B, C, D) EXPAND(DISPATCH_DTYPES(A, UNPAREN(B), D, UNPAREN(C))) - -#define _RANDOMDOUBLE2(A, B, C, D, E, F) AB; -#define RANDOMDOUBLE2(A, B, C, D) EVALUATING_PASTE(_RAND, OMDOUBLE2(A, B, UNPAREN(C), UNPAREN(D))) - -#define _RANDOMPAIRWISE2(A, B, C, D, E) AE; -#define RANDOMPAIRWISE(A, B, C) EVALUATING_PASTE(_RANDOM, PAIRWISE2(A, UNPAREN(C), UNPAREN(B))) - -#define _RANDOMTRIPLE3(A, B, ZN, ZT, YN, YT, XN, XT) AB; -#define RANDOMTRIPLE3(A, B, Z, Y, X) EVALUATING_PASTE(_RANDOM, TRIPLE3(A, UNPAREN(B), UNPAREN(Z), UNPAREN(Y), UNPAREN(X))) - -#define _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) EVALX(_EXEC_TRIPLE_T3(RANDOMTRIPLE3, NAME, SIGNATURE, TYPE_Z, TYPE_Y, UNPAREN(TYPES_X))) -#define RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPE_Y) _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) -#define _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) EVAL(_EXEC_TRIPLE_T2(RANDOMTRIPLE2, NAME, SIGNATURE, TYPE_Z, TYPES_X, UNPAREN(TYPES_Y))) -#define RANDOMTRIPLE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPE_Z) _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) - - -#define BROADCAST(NAME) sd::BroadcastOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, sd::broadcast::NAME) -#define BROADCAST_BOOL(NAME) sd::BroadcastBoolOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, sd::broadcast::NAME) +#define RANDOMSINGLE(A, B, C) \ + EVALUATING_PASTE(_RAND, OMSINGLE(A, UNPAREN(B), UNPAREN(C))) +#define RANDOMSINGLEU(A, B, C) \ + EVALUATING_PASTE(_RAND, OMSINGLEU(A, UNPAREN(B), UNPAREN(C))) +#define RANDOMDOUBLE(A, B, C, D) \ + EXPAND(DISPATCH_DTYPES(A, UNPAREN(B), D, UNPAREN(C))) + +#define _RANDOMDOUBLE2(A, B, C, D, E, F) A B; +#define RANDOMDOUBLE2(A, B, C, D) \ + EVALUATING_PASTE(_RAND, OMDOUBLE2(A, B, UNPAREN(C), UNPAREN(D))) + +#define _RANDOMPAIRWISE2(A, B, C, D, E) A E; +#define RANDOMPAIRWISE(A, B, C) \ + EVALUATING_PASTE(_RANDOM, PAIRWISE2(A, UNPAREN(C), UNPAREN(B))) + +#define _RANDOMTRIPLE3(A, B, ZN, ZT, YN, YT, XN, XT) A B; +#define RANDOMTRIPLE3(A, B, Z, Y, X) \ + EVALUATING_PASTE(_RANDOM, \ + TRIPLE3(A, UNPAREN(B), UNPAREN(Z), UNPAREN(Y), UNPAREN(X))) + +#define _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) \ + EVALX(_EXEC_TRIPLE_T3(RANDOMTRIPLE3, NAME, SIGNATURE, TYPE_Z, TYPE_Y, \ + UNPAREN(TYPES_X))) +#define RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPE_Y) \ + _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) +#define _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) \ + EVAL(_EXEC_TRIPLE_T2(RANDOMTRIPLE2, NAME, SIGNATURE, TYPE_Z, TYPES_X, \ + UNPAREN(TYPES_Y))) +#define RANDOMTRIPLE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPE_Z) \ + _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) + +#define BROADCAST(NAME) \ + sd::BroadcastOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, \ + sd::broadcast::NAME) +#define BROADCAST_BOOL(NAME) \ + sd::BroadcastBoolOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, \ + sd::broadcast::NAME) #define ALL_STRINGS sd::DataType::UTF8, sd::DataType::UTF16, sd::DataType::UTF32 #define ALL_INDICES sd::DataType::INT32, sd::DataType::INT64 -#define ALL_INTS sd::DataType::INT8, sd::DataType::UINT8, sd::DataType::INT16, sd::DataType::UINT16, sd::DataType::INT32, sd::DataType::UINT32, sd::DataType::INT64, sd::DataType::UINT64 -#define ALL_FLOATS sd::DataType::HALF, sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::BFLOAT16 - -#endif //TESTS_CPU_TYPE_BOILERPLATE_H +#define ALL_INTS \ + sd::DataType::INT8, sd::DataType::UINT8, sd::DataType::INT16, \ + sd::DataType::UINT16, sd::DataType::INT32, sd::DataType::UINT32, \ + sd::DataType::INT64, sd::DataType::UINT64 +#define ALL_FLOATS \ + sd::DataType::HALF, sd::DataType::FLOAT32, sd::DataType::DOUBLE, \ + sd::DataType::BFLOAT16 + +#endif // TESTS_CPU_TYPE_BOILERPLATE_H diff --git a/libnd4j/include/system/util.h b/libnd4j/include/system/util.h index aa2055606fcc..fbd3214c4c35 100644 --- a/libnd4j/include/system/util.h +++ b/libnd4j/include/system/util.h @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -/* +/* * File: util.h * Author: saudet * @@ -34,14 +34,14 @@ static inline Nd4jLong microTime() { #ifdef WIN32 - LARGE_INTEGER freq, count; - QueryPerformanceFrequency(&freq); - QueryPerformanceCounter(&count); - return (Nd4jLong)count.QuadPart/freq.QuadPart; + LARGE_INTEGER freq, count; + QueryPerformanceFrequency(&freq); + QueryPerformanceCounter(&count); + return (Nd4jLong)count.QuadPart / freq.QuadPart; #else - timeval tv; - gettimeofday(&tv, NULL); - return (Nd4jLong)tv.tv_sec*1000000 + tv.tv_usec; + timeval tv; + gettimeofday(&tv, NULL); + return (Nd4jLong)tv.tv_sec * 1000000 + tv.tv_usec; #endif } diff --git a/libnd4j/include/types/bfloat16.h b/libnd4j/include/types/bfloat16.h index a0590981605f..2be0f4e3c140 100644 --- a/libnd4j/include/types/bfloat16.h +++ b/libnd4j/include/types/bfloat16.h @@ -16,7 +16,8 @@ /* - Intel bfloat16 data type, based on https://software.intel.com/sites/default/files/managed/40/8b/bf16-hardware-numerics-definition-white-paper.pdf + Intel bfloat16 data type, based on + https://software.intel.com/sites/default/files/managed/40/8b/bf16-hardware-numerics-definition-white-paper.pdf */ @@ -32,7 +33,6 @@ #include #endif - #ifdef __CUDACC__ #define local_def inline __host__ __device__ #elif _MSC_VER @@ -43,212 +43,352 @@ #define local_def inline #endif -//namespace sd -//{ - struct bfloat16 - { - private: - template - struct isNumericType { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; - // struct isNumericType { static bool const value = std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::value;; }; - - public: - int16_t _data; - - local_def bfloat16() { - _data = 0; - } - - template ::value>::type> - local_def bfloat16(const T& rhs) { - *this = rhs; - } - - local_def operator float() const { - int32_t temp = this->_data << 16; //((sign << 31) | (exponent << 23) | mantissa); - return *reinterpret_cast(&temp); - } - - local_def explicit operator bool() const { - return this->_data == 0 ? false : true; - } - - template ::value>::type> - local_def explicit operator T() const { - return static_cast(static_cast(*this)); - } - - local_def bfloat16& operator=(const bool rhs) { - *this = (float)rhs ? 1.f: 0.f; - return *this; - } - - local_def bfloat16& operator=(const float& rhs) { - #ifdef __CUDACC__ - if(::isnan(rhs)) { - _data = bfloat16::nan(); - return *this; - } - #endif - auto x = *reinterpret_cast(& const_cast(rhs)); - uint32_t lsb = (x >> 16) & 1; - uint32_t rounding_bias = 0x7fff + lsb; - x += rounding_bias; - this->_data = static_cast(x >> 16); - - return *this; - } - - local_def bfloat16& operator=(const bfloat16& rhs) { - _data = rhs._data; - return *this; - } - - template ::value>::type> - local_def bfloat16& operator=(const T& rhs) { - *this = (float)rhs; - return *this; - } - - local_def friend bool operator==(const bfloat16& a, const bfloat16& b) { return (a._data == b._data); } - local_def friend bool operator!=(const bfloat16& a, const bfloat16& b) { return !(a == b); } - local_def friend bool operator<(const bfloat16& a, const bfloat16& b) { return (float)a < (float)b; } - local_def friend bool operator>(const bfloat16& a, const bfloat16& b) { return (float)a > (float)b; } - local_def friend bool operator<=(const bfloat16& a, const bfloat16& b) { return (float)a <= (float)b; } - local_def friend bool operator>=(const bfloat16& a, const bfloat16& b) { return (float)a >= (float)b; } - - local_def friend bfloat16 operator+(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a + (float)b); } - local_def friend bfloat16 operator-(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a - (float)b); } - local_def friend bfloat16 operator*(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a * (float)b); } - local_def friend bfloat16 operator/(const bfloat16& a, const bfloat16& b) { return bfloat16((float)a / (float)b); } - - template ::value>::type> - local_def friend bfloat16 operator+(const bfloat16& a, const T& b) { return a + static_cast(b); } - template ::value>::type> - local_def friend bfloat16 operator+(const T& a, const bfloat16& b) { return static_cast(a) + b; } - - template ::value>::type> - local_def friend bfloat16 operator-(const bfloat16& a, const T& b) { return a - static_cast(b); } - template ::value>::type> - local_def friend bfloat16 operator-(const T& a, const bfloat16& b) { return static_cast(a) - b; } - - template ::value>::type> - local_def friend bfloat16 operator*(const bfloat16& a, const T& b) { return a * static_cast(b); } - template ::value>::type> - local_def friend bfloat16 operator*(const T& a, const bfloat16& b) { return static_cast(a) * b; } - - template ::value>::type> - local_def friend bfloat16 operator/(const bfloat16& a, const T& b) { return a / static_cast(b); } - template ::value>::type> - local_def friend bfloat16 operator/(const T& a, const bfloat16& b) { return static_cast(a) / b; } - - template ::value>::type> - local_def friend bool operator==(const bfloat16& a, const T& b) { return a == static_cast(b); } - template ::value>::type> - local_def friend bool operator==(const T& a, const bfloat16& b) { return static_cast(a) == b; } - - template ::value>::type> - local_def friend bool operator!=(const bfloat16& a, const T& b) { return a != static_cast(b); } - template ::value>::type> - local_def friend bool operator!=(const T& a, const bfloat16& b) { return static_cast(a) != b; } - - template ::value>::type> - local_def friend bool operator<(const bfloat16& a, const T& b) { return a < static_cast(b); } - template ::value>::type> - local_def friend bool operator<(const T& a, const bfloat16& b) { return static_cast(a) < b; } - - template ::value>::type> - local_def friend bool operator>(const bfloat16& a, const T& b) { return a > static_cast(b); } - template ::value>::type> - local_def friend bool operator>(const T& a, const bfloat16& b) { return static_cast(a) > b; } - - template ::value>::type> - local_def friend bool operator<=(const bfloat16& a, const T& b) { return a <= static_cast(b); } - template ::value>::type> - local_def friend bool operator<=(const T& a, const bfloat16& b) { return static_cast(a) <= b; } - - template ::value>::type> - local_def friend bool operator>=(const bfloat16& a, const T& b) { return a >= static_cast(b); } - template ::value>::type> - local_def friend bool operator>=(const T& a, const bfloat16& b) { return static_cast(a) >= b; } - - local_def bfloat16& operator+=(bfloat16 rhs) { *this = (float)(*this) + (float)rhs; return *this; } - - local_def bfloat16& operator-=(bfloat16 rhs) { *this = (float)(*this) - (float)rhs; return *this; } - - local_def bfloat16& operator*=(bfloat16 rhs) { *this = (float)(*this) * (float)rhs; return *this; } - - local_def bfloat16& operator/=(bfloat16 rhs) { *this = (float)(*this) / (float)rhs; return *this; } - - template ::value>::type> - local_def bfloat16& operator+=(const T& rhs) { *this = *this + rhs; return *this; } - - template ::value>::type> - local_def bfloat16& operator-=(const T& rhs) { *this = *this - rhs; return *this; } - - template ::value>::type> - local_def bfloat16& operator*=(const T& rhs) { *this = *this * rhs; return *this; } - - template ::value>::type> - local_def bfloat16& operator/=(const T& rhs) { *this = *this / rhs; return *this; } - - local_def bfloat16& operator++() { *this = (float)*this + (float)1.f; return *this; } - - local_def bfloat16& operator--() { *this = (float)*this - (float)1.f; return *this; } - - local_def bfloat16 operator++(int) { *this = (float)*this + (float)1.f; return *this; } - - local_def bfloat16 operator--(int) { *this = (float)*this - (float)1.f; return *this; } - - local_def bfloat16 operator-() const { - return 0.f - (float)*this; - } - - - - // local_def std::ostream& operator<<(std::ostream& os) { - // os << static_cast(*this); - // return os; - // } - local_def static bfloat16 min() { - bfloat16 res; - res._data = 0xFF7F; - return res; - } - local_def static bfloat16 max() { - bfloat16 res; - res._data = 0x7F7F; - return res; - - } - local_def static bfloat16 eps() { - bfloat16 res; - res._data = 0x3C00; - return res; - } +struct float16; - local_def static bfloat16 inf() { - bfloat16 res; - res._data = 0x3C00; - return res; - } - - local_def static bfloat16 nan() { - bfloat16 res; - res._data = 0x7FC0; - return res; - } +// namespace sd +//{ +struct bfloat16 { + private: + template + struct isNumericType { + static bool const value = + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value; + }; + // struct isNumericType { static bool const value = std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::value;; }; + + public: + int16_t _data; + + local_def bfloat16() { _data = 0; } + + template ::value>::type> + local_def bfloat16(const T& rhs) { + *this = rhs; + } + + local_def operator float() const { + int32_t temp = this->_data + << 16; //((sign << 31) | (exponent << 23) | mantissa); + return *reinterpret_cast(&temp); + } + + local_def explicit operator bool() const { + return this->_data == 0 ? false : true; + } + + template ::value>::type> + local_def explicit operator T() const { + return static_cast(static_cast(*this)); + } + + local_def bfloat16& operator=(const bool rhs) { + *this = (float)rhs ? 1.f : 0.f; + return *this; + } + + local_def bfloat16& operator=(const float& rhs) { +#ifdef __CUDACC__ + if (::isnan(rhs)) { + _data = bfloat16::nan(); + return *this; + } +#endif + auto x = *reinterpret_cast(&const_cast(rhs)); + uint32_t lsb = (x >> 16) & 1; + uint32_t rounding_bias = 0x7fff + lsb; + x += rounding_bias; + this->_data = static_cast(x >> 16); + + return *this; + } + + local_def bfloat16& operator=(const bfloat16& rhs) { + _data = rhs._data; + return *this; + } + + template ::value>::type> + local_def bfloat16& operator=(const T& rhs) { + *this = (float)rhs; + return *this; + } + + local_def friend bool operator==(const bfloat16& a, const bfloat16& b) { + return (a._data == b._data); + } + local_def friend bool operator!=(const bfloat16& a, const bfloat16& b) { + return !(a == b); + } + local_def friend bool operator<(const bfloat16& a, const bfloat16& b) { + return (float)a < (float)b; + } + local_def friend bool operator>(const bfloat16& a, const bfloat16& b) { + return (float)a > (float)b; + } + local_def friend bool operator<=(const bfloat16& a, const bfloat16& b) { + return (float)a <= (float)b; + } + local_def friend bool operator>=(const bfloat16& a, const bfloat16& b) { + return (float)a >= (float)b; + } + + local_def friend bfloat16 operator+(const bfloat16& a, const bfloat16& b) { + return bfloat16((float)a + (float)b); + } + local_def friend bfloat16 operator-(const bfloat16& a, const bfloat16& b) { + return bfloat16((float)a - (float)b); + } + local_def friend bfloat16 operator*(const bfloat16& a, const bfloat16& b) { + return bfloat16((float)a * (float)b); + } + local_def friend bfloat16 operator/(const bfloat16& a, const bfloat16& b) { + return bfloat16((float)a / (float)b); + } + + template ::value>::type> + local_def friend bfloat16 operator+(const bfloat16& a, const T& b) { + return a + static_cast(b); + } + template ::value>::type> + local_def friend bfloat16 operator+(const T& a, const bfloat16& b) { + return static_cast(a) + b; + } + + template ::value>::type> + local_def friend bfloat16 operator-(const bfloat16& a, const T& b) { + return a - static_cast(b); + } + template ::value>::type> + local_def friend bfloat16 operator-(const T& a, const bfloat16& b) { + return static_cast(a) - b; + } + + template ::value>::type> + local_def friend bfloat16 operator*(const bfloat16& a, const T& b) { + return a * static_cast(b); + } + template ::value>::type> + local_def friend bfloat16 operator*(const T& a, const bfloat16& b) { + return static_cast(a) * b; + } + + template ::value>::type> + local_def friend bfloat16 operator/(const bfloat16& a, const T& b) { + return a / static_cast(b); + } + template ::value>::type> + local_def friend bfloat16 operator/(const T& a, const bfloat16& b) { + return static_cast(a) / b; + } + + template ::value>::type> + local_def friend bool operator==(const bfloat16& a, const T& b) { + return a == static_cast(b); + } + template ::value>::type> + local_def friend bool operator==(const T& a, const bfloat16& b) { + return static_cast(a) == b; + } + + template ::value>::type> + local_def friend bool operator!=(const bfloat16& a, const T& b) { + return a != static_cast(b); + } + template ::value>::type> + local_def friend bool operator!=(const T& a, const bfloat16& b) { + return static_cast(a) != b; + } + + template ::value>::type> + local_def friend bool operator<(const bfloat16& a, const T& b) { + return a < static_cast(b); + } + template ::value>::type> + local_def friend bool operator<(const T& a, const bfloat16& b) { + return static_cast(a) < b; + } + + template ::value>::type> + local_def friend bool operator>(const bfloat16& a, const T& b) { + return a > static_cast(b); + } + template ::value>::type> + local_def friend bool operator>(const T& a, const bfloat16& b) { + return static_cast(a) > b; + } + + template ::value>::type> + local_def friend bool operator<=(const bfloat16& a, const T& b) { + return a <= static_cast(b); + } + template ::value>::type> + local_def friend bool operator<=(const T& a, const bfloat16& b) { + return static_cast(a) <= b; + } + + template ::value>::type> + local_def friend bool operator>=(const bfloat16& a, const T& b) { + return a >= static_cast(b); + } + template ::value>::type> + local_def friend bool operator>=(const T& a, const bfloat16& b) { + return static_cast(a) >= b; + } + + local_def bfloat16& operator+=(bfloat16 rhs) { + *this = (float)(*this) + (float)rhs; + return *this; + } + + local_def bfloat16& operator-=(bfloat16 rhs) { + *this = (float)(*this) - (float)rhs; + return *this; + } + + local_def bfloat16& operator*=(bfloat16 rhs) { + *this = (float)(*this) * (float)rhs; + return *this; + } + + local_def bfloat16& operator/=(bfloat16 rhs) { + *this = (float)(*this) / (float)rhs; + return *this; + } + + template ::value>::type> + local_def bfloat16& operator+=(const T& rhs) { + *this = *this + rhs; + return *this; + } + + template ::value>::type> + local_def bfloat16& operator-=(const T& rhs) { + *this = *this - rhs; + return *this; + } + + template ::value>::type> + local_def bfloat16& operator*=(const T& rhs) { + *this = *this * rhs; + return *this; + } + + template ::value>::type> + local_def bfloat16& operator/=(const T& rhs) { + *this = *this / rhs; + return *this; + } + + local_def bfloat16& operator++() { + *this = (float)*this + (float)1.f; + return *this; + } + + local_def bfloat16& operator--() { + *this = (float)*this - (float)1.f; + return *this; + } + + local_def bfloat16 operator++(int) { + *this = (float)*this + (float)1.f; + return *this; + } + + local_def bfloat16 operator--(int) { + *this = (float)*this - (float)1.f; + return *this; + } + + local_def bfloat16 operator-() const { return 0.f - (float)*this; } + + // local_def std::ostream& operator<<(std::ostream& os) { + // os << static_cast(*this); + // return os; + // } + local_def static bfloat16 min() { + bfloat16 res; + res._data = 0xFF7F; + return res; + } + local_def static bfloat16 max() { + bfloat16 res; + res._data = 0x7F7F; + return res; + } + local_def static bfloat16 eps() { + bfloat16 res; + res._data = 0x3C00; + return res; + } + + local_def static bfloat16 inf() { + bfloat16 res; + res._data = 0x3C00; + return res; + } + + local_def static bfloat16 nan() { + bfloat16 res; + res._data = 0x7FC0; + return res; + } }; - - // local_def std::ostream& operator<<(std::ostream &os, const bfloat16 &f) { // os << static_cast(f); // return os; // } - -// local_def bfloat16 /* constexpr */ operator+(const bfloat16& h) { return h; } +// local_def bfloat16 /* constexpr */ operator+(const bfloat16& h) { return h; +// } // local_def bfloat16 operator - (const bfloat16& h) { // auto temp = h._data; @@ -258,8 +398,8 @@ // return t; // } -// WARNING: this implementation only for avoid cyclic references between float16 and bfloat16 types. -// local_def void float16::assign(const bfloat16& rhs) { +// WARNING: this implementation only for avoid cyclic references between float16 +// and bfloat16 types. local_def void float16::assign(const bfloat16& rhs) { // assign((float)rhs); // } diff --git a/libnd4j/include/types/float16.h b/libnd4j/include/types/float16.h index 761e66f1b2b5..3eecf3f32744 100644 --- a/libnd4j/include/types/float16.h +++ b/libnd4j/include/types/float16.h @@ -17,12 +17,13 @@ #ifndef LIBND4J_FLOAT16_H #define LIBND4J_FLOAT16_H +#include + #include #include #include -#include #if defined(__INTEL_COMPILER) || defined(SD_F16C) - #include +#include #endif struct bfloat16; @@ -34,63 +35,50 @@ struct bfloat16; // CUDA_9 and above struct ihalf : public __half { - public: - __host__ __device__ ihalf() : half() { - // - } - - inline __host__ __device__ unsigned short * getXP() { - return &this->__x; - } - - inline __host__ __device__ unsigned short getX() const { - return this->__x; - } - - inline __host__ __device__ void assign(const half f) { - this->__x = ((__half_raw *) &f)->x; - } + public: + __host__ __device__ ihalf() : half() { + // + } + + inline __host__ __device__ unsigned short* getXP() { return &this->__x; } + + inline __host__ __device__ unsigned short getX() const { return this->__x; } + + inline __host__ __device__ void assign(const half f) { + this->__x = ((__half_raw*)&f)->x; + } }; #else struct ihalf : public __half { - public: - __host__ __device__ ihalf() : half() { - // - } - - inline __host__ __device__ unsigned short * getXP() { - return &this->x; - } - - inline __host__ __device__ unsigned short getX() const { - return this->x; - } - - inline __host__ __device__ void assign(const half f) { - this->x = ((__half *) &f)->x; - } + public: + __host__ __device__ ihalf() : half() { + // + } + + inline __host__ __device__ unsigned short* getXP() { return &this->x; } + + inline __host__ __device__ unsigned short getX() const { return this->x; } + + inline __host__ __device__ void assign(const half f) { + this->x = ((__half*)&f)->x; + } }; -#endif // CUDA_8 +#endif // CUDA_8 #else struct __half { -public: - unsigned short x; - inline unsigned short * getXP() { - return &this->x; - } + public: + unsigned short x; + inline unsigned short* getXP() { return &this->x; } - inline unsigned short getX() const { - return this->x; - } + inline unsigned short getX() const { return this->x; } }; typedef __half half; typedef __half ihalf; - -#endif // CUDA +#endif // CUDA #ifdef __CUDACC__ #define local_def inline __host__ __device__ @@ -102,386 +90,545 @@ typedef __half ihalf; #define local_def inline #endif - static local_def int ishnan_(unsigned short h) { - return (h & 0x7c00U) == 0x7c00U && (h & 0x03ffU) != 0; + return (h & 0x7c00U) == 0x7c00U && (h & 0x03ffU) != 0; } static local_def int ishinf_(unsigned short h) { - return (h & 0x7c00U) == 0x7c00U && (h & 0x03ffU) == 0; + return (h & 0x7c00U) == 0x7c00U && (h & 0x03ffU) == 0; } static local_def int ishequ_(unsigned short x, unsigned short y) { - return ishnan_(x) == 0 && ishnan_(y) == 0 && x == y; + return ishnan_(x) == 0 && ishnan_(y) == 0 && x == y; } static local_def unsigned short hneg(unsigned short h) { - h ^= 0x8000U; - return h; + h ^= 0x8000U; + return h; } - #if defined(__INTEL_COMPILER) || defined(SD_F16C) //_Pragma("omp declare simd") inline -local_def float cpu_ihalf2float(ihalf h) { - return _cvtsh_ss(h.getX()); -} +local_def float cpu_ihalf2float(ihalf h) { return _cvtsh_ss(h.getX()); } #else local_def float cpu_ihalf2float(ihalf h) { - unsigned sign = ((h.getX() >> 15) & 1); - unsigned exponent = ((h.getX() >> 10) & 0x1f); - unsigned mantissa = ((h.getX() & 0x3ff) << 13); - - if (exponent == 0x1f) { /* NaN or Inf */ - mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); - exponent = 0xff; - } else if (!exponent) { /* Denorm or Zero */ - if (mantissa) { - unsigned int msb; - exponent = 0x71; - do { - msb = (mantissa & 0x400000); - mantissa <<= 1; /* normalize */ - --exponent; - } while (!msb); - mantissa &= 0x7fffff; /* 1.mantissa is implicit */ - } - } else { - exponent += 0x70; + unsigned sign = ((h.getX() >> 15) & 1); + unsigned exponent = ((h.getX() >> 10) & 0x1f); + unsigned mantissa = ((h.getX() & 0x3ff) << 13); + + if (exponent == 0x1f) { /* NaN or Inf */ + mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); + exponent = 0xff; + } else if (!exponent) { /* Denorm or Zero */ + if (mantissa) { + unsigned int msb; + exponent = 0x71; + do { + msb = (mantissa & 0x400000); + mantissa <<= 1; /* normalize */ + --exponent; + } while (!msb); + mantissa &= 0x7fffff; /* 1.mantissa is implicit */ } + } else { + exponent += 0x70; + } - int temp = ((sign << 31) | (exponent << 23) | mantissa); + int temp = ((sign << 31) | (exponent << 23) | mantissa); - return *((float*)((void*)&temp)); + return *((float*)((void*)&temp)); } #endif #if defined(__INTEL_COMPILER) || defined(SD_F16C) //_Pragma("omp declare simd") inline local_def ihalf cpu_float2ihalf_rn(float f) { - ihalf ret; - ret.x = _cvtss_sh(f, 0); - return ret; + ihalf ret; + ret.x = _cvtss_sh(f, 0); + return ret; } #else -local_def ihalf cpu_float2ihalf_rn(float f) -{ - ihalf ret; - - unsigned x = *((int*)(void*)(&f)); - unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; - unsigned sign, exponent, mantissa; - - // Get rid of +NaN/-NaN case first. - if (u > 0x7f800000) { - *ret.getXP() = 0x7fffU; - return ret; - } +local_def ihalf cpu_float2ihalf_rn(float f) { + ihalf ret; - sign = ((x >> 16) & 0x8000); + unsigned x = *((int*)(void*)(&f)); + unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; + unsigned sign, exponent, mantissa; - // Get rid of +Inf/-Inf, +0/-0. - if (u > 0x477fefff) { - *ret.getXP() = sign | 0x7c00U; - return ret; - } - if (u < 0x33000001) { - *ret.getXP() = (sign | 0x0000); - return ret; - } + // Get rid of +NaN/-NaN case first. + if (u > 0x7f800000) { + *ret.getXP() = 0x7fffU; + return ret; + } - exponent = ((u >> 23) & 0xff); - mantissa = (u & 0x7fffff); + sign = ((x >> 16) & 0x8000); - if (exponent > 0x70) { - shift = 13; - exponent -= 0x70; - } else { - shift = 0x7e - exponent; - exponent = 0; - mantissa |= 0x800000; - } - lsb = (1 << shift); - lsb_s1 = (lsb >> 1); - lsb_m1 = (lsb - 1); - - // Round to nearest even. - remainder = (mantissa & lsb_m1); - mantissa >>= shift; - if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { - ++mantissa; - if (!(mantissa & 0x3ff)) { - ++exponent; - mantissa = 0; - } + // Get rid of +Inf/-Inf, +0/-0. + if (u > 0x477fefff) { + *ret.getXP() = sign | 0x7c00U; + return ret; + } + if (u < 0x33000001) { + *ret.getXP() = (sign | 0x0000); + return ret; + } + + exponent = ((u >> 23) & 0xff); + mantissa = (u & 0x7fffff); + + if (exponent > 0x70) { + shift = 13; + exponent -= 0x70; + } else { + shift = 0x7e - exponent; + exponent = 0; + mantissa |= 0x800000; + } + lsb = (1 << shift); + lsb_s1 = (lsb >> 1); + lsb_m1 = (lsb - 1); + + // Round to nearest even. + remainder = (mantissa & lsb_m1); + mantissa >>= shift; + if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { + ++mantissa; + if (!(mantissa & 0x3ff)) { + ++exponent; + mantissa = 0; } + } - *ret.getXP() = (sign | (exponent << 10) | mantissa); + *ret.getXP() = (sign | (exponent << 10) | mantissa); - return ret; + return ret; } #endif struct float16 { + private: + template + struct isNumericType { + static bool const value = + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value; + }; // || std::is_same::value; }; + // struct isNumericType { static bool const value = std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::value;; }; + + public: + ihalf data; + local_def float16() { *data.getXP() = 0; } + + template ::value || + std::is_same::value>::type> + local_def float16(const T& rhs) { + *this = rhs; + } + + local_def float16(const half& rhs) { +#ifdef __CUDACC__ + data.assign(rhs); +#endif + } + + local_def operator float() const { +#ifdef __CUDA_ARCH__ + return __half2float(data); +#else + return cpu_ihalf2float(data); +#endif + } + + local_def explicit operator bool() const { + return static_cast(*this) != 0.0f; + } + + local_def explicit operator half() const { return data; } + + template ::value>::type> + local_def explicit operator T() const { + return static_cast(static_cast(*this)); + } - private: - template - struct isNumericType { static bool const value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; };// || std::is_same::value; }; - // struct isNumericType { static bool const value = std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::type>::value || std::is_same::value;; }; - - public: - ihalf data; - local_def float16() { *data.getXP() = 0; } - - template ::value || std::is_same::value>::type> - local_def float16(const T& rhs) { - *this = rhs; - } - - local_def float16(const half& rhs) { - #ifdef __CUDACC__ - data.assign(rhs); - #endif - } - - local_def operator float() const { - #ifdef __CUDA_ARCH__ - return __half2float(data); - #else - return cpu_ihalf2float(data); - #endif - } - - local_def explicit operator bool() const { - return static_cast(*this) != 0.0f; - } - - local_def explicit operator half() const { - return data; - } - - template ::value>::type> - local_def explicit operator T() const { - return static_cast(static_cast(*this)); - } - - local_def float16& operator=(const float& rhs) { - #ifdef __CUDA_ARCH__ - auto t = __float2half_rn(rhs); - auto b = *(data.getXP()); - - #ifdef CUDA_8 - *(data.getXP()) = t; - #else - data.assign(t); - #endif - - #else - data = cpu_float2ihalf_rn(rhs); - #endif - - return *this; - } - - local_def float16& operator=(const unsigned short rhs) { - *data.getXP() = rhs; - return *this; - } - - local_def float16& operator=(const bool rhs) { - *this = (float)rhs ? 1.f: 0.f; - return *this; - } - - local_def float16& operator=(const ihalf& rhs) { - *data.getXP() = ((ihalf) rhs).getX(); - return *this; - } - - #ifdef __CUDACC__ - local_def float16& operator=(const half& rhs) { - data.assign(rhs); - return *this; - } - #endif - - local_def float16& operator=(const float16& rhs) { - data = rhs.data; - return *this; - } - - template ::value || std::is_same::value>::type> - local_def float16& operator=(const T& rhs) { - *this = (float)rhs; - return *this; - } - - #ifdef NATIVE_HALFS - local_def friend bool operator==(const float16& a, const float16& b) { return __hequ(a.data, b.data); } - #else - local_def friend bool operator==(const float16& a, const float16& b) { return ishequ_(((ihalf) a.data).getX(), ((ihalf)b.data).getX()); } - #endif - - #ifdef NATIVE_HALFS - local_def friend bool operator!=(const float16& a, const float16& b) { return !(__hequ(a.data, b.data)); } - #else - local_def friend bool operator!=(const float16& a, const float16& b) { return !(a == b); } - #endif - - #ifdef NATIVE_HALFS - local_def friend bool operator<(const float16& a, const float16& b) { return __hlt(a.data, b.data); } - #else - local_def friend bool operator<(const float16& a, const float16& b) { return (float)a < (float)b; } - #endif - - #ifdef NATIVE_HALFS - local_def friend bool operator>(const float16& a, const float16& b) { return __hgt(a.data, b.data); } - #else - local_def friend bool operator>(const float16& a, const float16& b) { return (float)a > (float)b; } - #endif - - #ifdef NATIVE_HALFS - local_def friend bool operator<=(const float16& a, const float16& b) { return __hle(a.data, b.data); } - #else - local_def friend bool operator<=(const float16& a, const float16& b) { return (float)a <= (float)b; } - #endif - - #ifdef NATIVE_HALFS - local_def friend bool operator>=(const float16& a, const float16& b) { return __hge(a.data, b.data); } - #else - local_def friend bool operator>=(const float16& a, const float16& b) { return (float)a >= (float)b; } - #endif - - #ifdef NATIVE_HALFS - local_def friend float16 operator+(const float16& a, const float16& b) { return __hadd(a.data, b.data); } - - local_def friend float16 operator-(const float16& a, const float16& b) { return __hsub(a.data, b.data); } - - local_def friend float16 operator*(const float16& a, const float16& b) { return __hmul(a.data, b.data); } - - local_def friend float16 operator/(const float16& a, const float16& b) { - #ifdef CUDA_8 - return hdiv(a.data, b.data); - #else - return __hdiv(a.data, b.data); - #endif - } - #else - local_def friend float16 operator+(const float16& a, const float16& b) { return float16((float)a + (float)b); } - local_def friend float16 operator-(const float16& a, const float16& b) { return float16((float)a - (float)b); } - local_def friend float16 operator*(const float16& a, const float16& b) { return float16((float)a * (float)b); } - local_def friend float16 operator/(const float16& a, const float16& b) { return float16((float)a / (float)b); } - #endif - - template ::value>::type> - local_def friend float16 operator+(const float16& a, const T& b) { return a + static_cast(b); } - template ::value>::type> - local_def friend float16 operator+(const T& a, const float16& b) { return static_cast(a) + b; } - - template ::value>::type> - local_def friend float16 operator-(const float16& a, const T& b) { return a - static_cast(b); } - template ::value>::type> - local_def friend float16 operator-(const T& a, const float16& b) { return static_cast(a) - b; } - - template ::value>::type> - local_def friend float16 operator*(const float16& a, const T& b) { return a * static_cast(b); } - template ::value>::type> - local_def friend float16 operator*(const T& a, const float16& b) { return static_cast(a) * b; } - - template ::value>::type> - local_def friend float16 operator/(const float16& a, const T& b) { return a / static_cast(b); } - template ::value>::type> - local_def friend float16 operator/(const T& a, const float16& b) { return static_cast(a) / b; } - - template ::value>::type> - local_def friend bool operator==(const float16& a, const T& b) { return a == static_cast(b); } - template ::value>::type> - local_def friend bool operator==(const T& a, const float16& b) { return static_cast(a) == b; } - - template ::value>::type> - local_def friend bool operator!=(const float16& a, const T& b) { return a != static_cast(b); } - template ::value>::type> - local_def friend bool operator!=(const T& a, const float16& b) { return static_cast(a) != b; } - - template ::value>::type> - local_def friend bool operator<(const float16& a, const T& b) { return a < static_cast(b); } - template ::value>::type> - local_def friend bool operator<(const T& a, const float16& b) { return static_cast(a) < b; } - - template ::value>::type> - local_def friend bool operator>(const float16& a, const T& b) { return a > static_cast(b); } - template ::value>::type> - local_def friend bool operator>(const T& a, const float16& b) { return static_cast(a) > b; } - - template ::value>::type> - local_def friend bool operator<=(const float16& a, const T& b) { return a <= static_cast(b); } - template ::value>::type> - local_def friend bool operator<=(const T& a, const float16& b) { return static_cast(a) <= b; } - - template ::value>::type> - local_def friend bool operator>=(const float16& a, const T& b) { return a >= static_cast(b); } - template ::value>::type> - local_def friend bool operator>=(const T& a, const float16& b) { return static_cast(a) >= b; } - - local_def float16& operator+=(float16 rhs) { *this = (float)*this + (float)rhs; return *this; } - - local_def float16& operator-=(float16 rhs) { *this = (float)*this - (float)rhs; return *this; } - - local_def float16& operator*=(float16 rhs) { *this = (float)*this * (float)rhs; return *this; } - - local_def float16& operator/=(float16 rhs) { *this = (float)*this / (float)rhs; return *this; } - - template ::value>::type> - local_def float16& operator+=(const T& rhs) { *this = *this + rhs; return *this; } + local_def float16& operator=(const float& rhs) { +#ifdef __CUDA_ARCH__ + auto t = __float2half_rn(rhs); + auto b = *(data.getXP()); - template ::value>::type> - local_def float16& operator-=(const T& rhs) { *this = *this - rhs; return *this; } +#ifdef CUDA_8 + *(data.getXP()) = t; +#else + data.assign(t); +#endif - template ::value>::type> - local_def float16& operator*=(const T& rhs) { *this = *this * rhs; return *this; } +#else + data = cpu_float2ihalf_rn(rhs); +#endif - template ::value>::type> - local_def float16& operator/=(const T& rhs) { *this = *this / rhs; return *this; } + return *this; + } - local_def float16& operator++() { *this = *this + (float16)1.f; return *this; } + local_def float16& operator=(const unsigned short rhs) { + *data.getXP() = rhs; + return *this; + } - local_def float16& operator--() { *this = *this - (float16)1.f; return *this; } + local_def float16& operator=(const bool rhs) { + *this = (float)rhs ? 1.f : 0.f; + return *this; + } - local_def float16 operator++(int) { *this = *this + (float16)1.f; return *this; } + local_def float16& operator=(const ihalf& rhs) { + *data.getXP() = ((ihalf)rhs).getX(); + return *this; + } - local_def float16 operator--(int) { *this = *this - (float16)1.f; return *this; } +#ifdef __CUDACC__ + local_def float16& operator=(const half& rhs) { + data.assign(rhs); + return *this; + } +#endif - local_def float16 operator-() const { - return 0.f - (float)*this; - } + local_def float16& operator=(const float16& rhs) { + data = rhs.data; + return *this; + } + + template ::value || + std::is_same::value>::type> + local_def float16& operator=(const T& rhs) { + *this = (float)rhs; + return *this; + } + +#ifdef NATIVE_HALFS + local_def friend bool operator==(const float16& a, const float16& b) { + return __hequ(a.data, b.data); + } +#else + local_def friend bool operator==(const float16& a, const float16& b) { + return ishequ_(((ihalf)a.data).getX(), ((ihalf)b.data).getX()); + } +#endif - // local_def std::ostream& operator<<(std::ostream& os) { - // os << static_cast(*this); - // return os; - // } -}; +#ifdef NATIVE_HALFS + local_def friend bool operator!=(const float16& a, const float16& b) { + return !(__hequ(a.data, b.data)); + } +#else + local_def friend bool operator!=(const float16& a, const float16& b) { + return !(a == b); + } +#endif +#ifdef NATIVE_HALFS + local_def friend bool operator<(const float16& a, const float16& b) { + return __hlt(a.data, b.data); + } +#else + local_def friend bool operator<(const float16& a, const float16& b) { + return (float)a < (float)b; + } +#endif +#ifdef NATIVE_HALFS + local_def friend bool operator>(const float16& a, const float16& b) { + return __hgt(a.data, b.data); + } +#else + local_def friend bool operator>(const float16& a, const float16& b) { + return (float)a > (float)b; + } +#endif - // local_def std::ostream& operator<<(std::ostream &os, const float16 &f) { - // os << static_cast(f); - // return os; - // } +#ifdef NATIVE_HALFS + local_def friend bool operator<=(const float16& a, const float16& b) { + return __hle(a.data, b.data); + } +#else + local_def friend bool operator<=(const float16& a, const float16& b) { + return (float)a <= (float)b; + } +#endif - // local_def float16 operator+(const float16& h) { return h; } +#ifdef NATIVE_HALFS + local_def friend bool operator>=(const float16& a, const float16& b) { + return __hge(a.data, b.data); + } +#else + local_def friend bool operator>=(const float16& a, const float16& b) { + return (float)a >= (float)b; + } +#endif + +#ifdef NATIVE_HALFS + local_def friend float16 operator+(const float16& a, const float16& b) { + return __hadd(a.data, b.data); + } + + local_def friend float16 operator-(const float16& a, const float16& b) { + return __hsub(a.data, b.data); + } + + local_def friend float16 operator*(const float16& a, const float16& b) { + return __hmul(a.data, b.data); + } + + local_def friend float16 operator/(const float16& a, const float16& b) { +#ifdef CUDA_8 + return hdiv(a.data, b.data); +#else + return __hdiv(a.data, b.data); +#endif + } +#else + local_def friend float16 operator+(const float16& a, const float16& b) { + return float16((float)a + (float)b); + } + local_def friend float16 operator-(const float16& a, const float16& b) { + return float16((float)a - (float)b); + } + local_def friend float16 operator*(const float16& a, const float16& b) { + return float16((float)a * (float)b); + } + local_def friend float16 operator/(const float16& a, const float16& b) { + return float16((float)a / (float)b); + } +#endif + + template ::value>::type> + local_def friend float16 operator+(const float16& a, const T& b) { + return a + static_cast(b); + } + template ::value>::type> + local_def friend float16 operator+(const T& a, const float16& b) { + return static_cast(a) + b; + } + + template ::value>::type> + local_def friend float16 operator-(const float16& a, const T& b) { + return a - static_cast(b); + } + template ::value>::type> + local_def friend float16 operator-(const T& a, const float16& b) { + return static_cast(a) - b; + } + + template ::value>::type> + local_def friend float16 operator*(const float16& a, const T& b) { + return a * static_cast(b); + } + template ::value>::type> + local_def friend float16 operator*(const T& a, const float16& b) { + return static_cast(a) * b; + } + + template ::value>::type> + local_def friend float16 operator/(const float16& a, const T& b) { + return a / static_cast(b); + } + template ::value>::type> + local_def friend float16 operator/(const T& a, const float16& b) { + return static_cast(a) / b; + } + + template ::value>::type> + local_def friend bool operator==(const float16& a, const T& b) { + return a == static_cast(b); + } + template ::value>::type> + local_def friend bool operator==(const T& a, const float16& b) { + return static_cast(a) == b; + } + + template ::value>::type> + local_def friend bool operator!=(const float16& a, const T& b) { + return a != static_cast(b); + } + template ::value>::type> + local_def friend bool operator!=(const T& a, const float16& b) { + return static_cast(a) != b; + } + + template ::value>::type> + local_def friend bool operator<(const float16& a, const T& b) { + return a < static_cast(b); + } + template ::value>::type> + local_def friend bool operator<(const T& a, const float16& b) { + return static_cast(a) < b; + } + + template ::value>::type> + local_def friend bool operator>(const float16& a, const T& b) { + return a > static_cast(b); + } + template ::value>::type> + local_def friend bool operator>(const T& a, const float16& b) { + return static_cast(a) > b; + } + + template ::value>::type> + local_def friend bool operator<=(const float16& a, const T& b) { + return a <= static_cast(b); + } + template ::value>::type> + local_def friend bool operator<=(const T& a, const float16& b) { + return static_cast(a) <= b; + } + + template ::value>::type> + local_def friend bool operator>=(const float16& a, const T& b) { + return a >= static_cast(b); + } + template ::value>::type> + local_def friend bool operator>=(const T& a, const float16& b) { + return static_cast(a) >= b; + } + + local_def float16& operator+=(float16 rhs) { + *this = (float)*this + (float)rhs; + return *this; + } + + local_def float16& operator-=(float16 rhs) { + *this = (float)*this - (float)rhs; + return *this; + } + + local_def float16& operator*=(float16 rhs) { + *this = (float)*this * (float)rhs; + return *this; + } + + local_def float16& operator/=(float16 rhs) { + *this = (float)*this / (float)rhs; + return *this; + } + + template ::value>::type> + local_def float16& operator+=(const T& rhs) { + *this = *this + rhs; + return *this; + } + + template ::value>::type> + local_def float16& operator-=(const T& rhs) { + *this = *this - rhs; + return *this; + } + + template ::value>::type> + local_def float16& operator*=(const T& rhs) { + *this = *this * rhs; + return *this; + } + + template ::value>::type> + local_def float16& operator/=(const T& rhs) { + *this = *this / rhs; + return *this; + } + + local_def float16& operator++() { + *this = *this + (float16)1.f; + return *this; + } + + local_def float16& operator--() { + *this = *this - (float16)1.f; + return *this; + } + + local_def float16 operator++(int) { + *this = *this + (float16)1.f; + return *this; + } + + local_def float16 operator--(int) { + *this = *this - (float16)1.f; + return *this; + } + + local_def float16 operator-() const { return 0.f - (float)*this; } + + // local_def std::ostream& operator<<(std::ostream& os) { + // os << static_cast(*this); + // return os; + // } +}; - // local_def float16 operator - (const float16& h) { - // const ihalf * tmp = &h.data; - // return float16(hneg(tmp->getX())); - // } +// local_def std::ostream& operator<<(std::ostream &os, const float16 &f) { +// os << static_cast(f); +// return os; +// } + +// local_def float16 operator+(const float16& h) { return h; } + +// local_def float16 operator - (const float16& h) { +// const ihalf * tmp = &h.data; +// return float16(hneg(tmp->getX())); +// } #ifdef __CUDACC__ - local_def int isnan(const float16& h) { return ishnan_(((ihalf)h.data).getX()); } +local_def int isnan(const float16& h) { + return ishnan_(((ihalf)h.data).getX()); +} - local_def int isinf(const float16& h) { return ishinf_(((ihalf)h.data).getX()); } +local_def int isinf(const float16& h) { + return ishinf_(((ihalf)h.data).getX()); +} #endif - // std::ostream& operator << (std::ostream& s, const float16&); +// std::ostream& operator << (std::ostream& s, const float16&); #endif diff --git a/libnd4j/include/types/float8.h b/libnd4j/include/types/float8.h index 6dc03bba45fc..4060bb80ad27 100644 --- a/libnd4j/include/types/float8.h +++ b/libnd4j/include/types/float8.h @@ -35,151 +35,137 @@ #include - namespace sd { - typedef struct { - unsigned char x; - } __quarter; - - typedef __quarter quarter; - - quarter _CUDA_HD FORCEINLINE cpu_float2quarter_rn(float f); - float _CUDA_HD FORCEINLINE cpu_quarter2float(quarter b); - - struct float8 { - quarter data; +typedef struct { + unsigned char x; +} __quarter; - _CUDA_HD FORCEINLINE float8(); +typedef __quarter quarter; - template - _CUDA_HD FORCEINLINE float8(const T& rhs); +quarter _CUDA_HD FORCEINLINE cpu_float2quarter_rn(float f); +float _CUDA_HD FORCEINLINE cpu_quarter2float(quarter b); - template - _CUDA_HD FORCEINLINE float8& operator=(const T& rhs); +struct float8 { + quarter data; - _CUDA_HD FORCEINLINE operator float() const; + _CUDA_HD FORCEINLINE float8(); - _CUDA_HD FORCEINLINE void assign(double rhs); + template + _CUDA_HD FORCEINLINE float8(const T& rhs); - _CUDA_HD FORCEINLINE void assign(float rhs); - }; + template + _CUDA_HD FORCEINLINE float8& operator=(const T& rhs); + _CUDA_HD FORCEINLINE operator float() const; - float cpu_quarter2float(quarter b) { - unsigned sign = ((b.x >> 7) & 1); - unsigned exponent = ((b.x >> 4) & 0x7); - unsigned mantissa = ((b.x & 0xf) << 19); + _CUDA_HD FORCEINLINE void assign(double rhs); - if (exponent == 0x7) { /* NaN or Inf */ - mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); - exponent = 0xff; - } else if (!exponent) { /* Denorm or Zero */ - if (mantissa) { - unsigned int msb; - exponent = 0x7d; - do { - msb = (mantissa & 0x400000); - mantissa <<= 1; /* normalize */ - --exponent; - } while (!msb); - mantissa &= 0x7fffff; /* 1.mantissa is implicit */ - } - } else { - exponent += 0x7C; - } + _CUDA_HD FORCEINLINE void assign(float rhs); +}; - int temp = ((sign << 31) | (exponent << 23) | mantissa); +float cpu_quarter2float(quarter b) { + unsigned sign = ((b.x >> 7) & 1); + unsigned exponent = ((b.x >> 4) & 0x7); + unsigned mantissa = ((b.x & 0xf) << 19); - return *((float*)((void*)&temp)); + if (exponent == 0x7) { /* NaN or Inf */ + mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); + exponent = 0xff; + } else if (!exponent) { /* Denorm or Zero */ + if (mantissa) { + unsigned int msb; + exponent = 0x7d; + do { + msb = (mantissa & 0x400000); + mantissa <<= 1; /* normalize */ + --exponent; + } while (!msb); + mantissa &= 0x7fffff; /* 1.mantissa is implicit */ } + } else { + exponent += 0x7C; + } + int temp = ((sign << 31) | (exponent << 23) | mantissa); + return *((float*)((void*)&temp)); +} - quarter cpu_float2quarter_rn(float f) - { - quarter ret; - - unsigned x = *((int*)(void*)(&f)); - unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; - unsigned sign, exponent, mantissa; - - // Get rid of +NaN/-NaN case first. - if (u > 0x7f800000) { - ret.x = 0x7fU; - return ret; - } - - sign = ((x >> 24) & 0x80); - - // Get rid of +Inf/-Inf, +0/-0. - if (u > 0x477fefff) { - ret.x = sign | 0x70U; - return ret; - } - if (u < 0x33000001) { - ret.x = (sign | 0x00); - return ret; - } - - exponent = ((u >> 23) & 0xff); - mantissa = (u & 0x7fffff); - - if (exponent > 0x7C) { - shift = 19; - exponent -= 0x7C; - } else { - shift = 0x90 - exponent; - exponent = 0; - mantissa |= 0x800000; - } - lsb = (1 << shift); - lsb_s1 = (lsb >> 1); - lsb_m1 = (lsb - 1); - - // Round to nearest even. - remainder = (mantissa & lsb_m1); - mantissa >>= shift; - if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { - ++mantissa; - if (!(mantissa & 0xf)) { - ++exponent; - mantissa = 0; - } - } - - ret.x = (sign | (exponent << 4) | mantissa); - - return ret; +quarter cpu_float2quarter_rn(float f) { + quarter ret; + + unsigned x = *((int*)(void*)(&f)); + unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; + unsigned sign, exponent, mantissa; + + // Get rid of +NaN/-NaN case first. + if (u > 0x7f800000) { + ret.x = 0x7fU; + return ret; + } + + sign = ((x >> 24) & 0x80); + + // Get rid of +Inf/-Inf, +0/-0. + if (u > 0x477fefff) { + ret.x = sign | 0x70U; + return ret; + } + if (u < 0x33000001) { + ret.x = (sign | 0x00); + return ret; + } + + exponent = ((u >> 23) & 0xff); + mantissa = (u & 0x7fffff); + + if (exponent > 0x7C) { + shift = 19; + exponent -= 0x7C; + } else { + shift = 0x90 - exponent; + exponent = 0; + mantissa |= 0x800000; + } + lsb = (1 << shift); + lsb_s1 = (lsb >> 1); + lsb_m1 = (lsb - 1); + + // Round to nearest even. + remainder = (mantissa & lsb_m1); + mantissa >>= shift; + if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { + ++mantissa; + if (!(mantissa & 0xf)) { + ++exponent; + mantissa = 0; } + } + ret.x = (sign | (exponent << 4) | mantissa); - float8::float8() { - data = cpu_float2quarter_rn(0.0f); - } + return ret; +} - template - float8::float8(const T& rhs) { - assign(rhs); - } +float8::float8() { data = cpu_float2quarter_rn(0.0f); } - template - float8& float8::operator=(const T& rhs) { - assign(rhs); return *this; - } +template +float8::float8(const T& rhs) { + assign(rhs); +} +template +float8& float8::operator=(const T& rhs) { + assign(rhs); + return *this; +} - float8::operator float() const { - return cpu_quarter2float(data); - } +float8::operator float() const { return cpu_quarter2float(data); } - void float8::assign(double rhs) { - assign((float)rhs); - } +void float8::assign(double rhs) { assign((float)rhs); } - void float8::assign(float rhs) { - data = cpu_float2quarter_rn(rhs); - } -} +void float8::assign(float rhs) { data = cpu_float2quarter_rn(rhs); } +} // namespace sd -#endif //LIBND4J_FLOAT8_H +#endif // LIBND4J_FLOAT8_H diff --git a/libnd4j/include/types/impl/int16.cpp b/libnd4j/include/types/impl/int16.cpp index 67f90e9d8724..7ab8290557e4 100644 --- a/libnd4j/include/types/impl/int16.cpp +++ b/libnd4j/include/types/impl/int16.cpp @@ -22,11 +22,11 @@ namespace sd { - /* - template int16::int16(const float& rhs); - template int16::int16(const double& rhs); +/* +template int16::int16(const float& rhs); +template int16::int16(const double& rhs); - template int16& int16::operator=(const float& rhs); - template int16& int16::operator=(const double& rhs); - */ +template int16& int16::operator=(const float& rhs); +template int16& int16::operator=(const double& rhs); + */ } \ No newline at end of file diff --git a/libnd4j/include/types/impl/pair.cpp b/libnd4j/include/types/impl/pair.cpp index 767bfa63028f..b12fd767cdfe 100644 --- a/libnd4j/include/types/impl/pair.cpp +++ b/libnd4j/include/types/impl/pair.cpp @@ -21,16 +21,12 @@ #include namespace sd { - Pair::Pair(int first, int second) { - _first = first; - _second = second; - } +Pair::Pair(int first, int second) { + _first = first; + _second = second; +} - int Pair::first() const { - return _first; - } +int Pair::first() const { return _first; } - int Pair::second() const { - return _second; - }; -} +int Pair::second() const { return _second; }; +} // namespace sd diff --git a/libnd4j/include/types/impl/triple.cpp b/libnd4j/include/types/impl/triple.cpp index 0b39d4bac15e..9e03c24cd5b5 100644 --- a/libnd4j/include/types/impl/triple.cpp +++ b/libnd4j/include/types/impl/triple.cpp @@ -21,21 +21,15 @@ #include namespace sd { - int Triple::first() const { - return _first; - } +int Triple::first() const { return _first; } - int Triple::second() const { - return _second; - } +int Triple::second() const { return _second; } - int Triple::third() const { - return _third; - } +int Triple::third() const { return _third; } - Triple::Triple(int first, int second, int third) { - _first = first; - _second = second; - _third = third; - } +Triple::Triple(int first, int second, int third) { + _first = first; + _second = second; + _third = third; } +} // namespace sd diff --git a/libnd4j/include/types/impl/uint16.cpp b/libnd4j/include/types/impl/uint16.cpp index 5b858222da3b..234fcaa4a815 100644 --- a/libnd4j/include/types/impl/uint16.cpp +++ b/libnd4j/include/types/impl/uint16.cpp @@ -23,12 +23,11 @@ namespace sd { +/* +template uint16::uint16(const float& rhs); +template uint16::uint16(const double& rhs); - /* - template uint16::uint16(const float& rhs); - template uint16::uint16(const double& rhs); - - template uint16& uint16::operator=(const double& rhs); - template uint16& uint16::operator=(const float& rhs); - */ +template uint16& uint16::operator=(const double& rhs); +template uint16& uint16::operator=(const float& rhs); + */ } \ No newline at end of file diff --git a/libnd4j/include/types/impl/uint8.cpp b/libnd4j/include/types/impl/uint8.cpp index a6d25c9d3780..0ab6861ca59a 100644 --- a/libnd4j/include/types/impl/uint8.cpp +++ b/libnd4j/include/types/impl/uint8.cpp @@ -22,11 +22,11 @@ namespace sd { - /* - template uint8::uint8(const float& rhs); - template uint8::uint8(const double& rhs); +/* +template uint8::uint8(const float& rhs); +template uint8::uint8(const double& rhs); - template uint8& uint8::operator=(const float& rhs); - template uint8& uint8::operator=(const double& rhs); - */ +template uint8& uint8::operator=(const float& rhs); +template uint8& uint8::operator=(const double& rhs); + */ } \ No newline at end of file diff --git a/libnd4j/include/types/impl/utf8string.cpp b/libnd4j/include/types/impl/utf8string.cpp index a7df7cc28ddd..7cc6dc8a20ea 100644 --- a/libnd4j/include/types/impl/utf8string.cpp +++ b/libnd4j/include/types/impl/utf8string.cpp @@ -19,57 +19,57 @@ // #include + #include namespace sd { - utf8string::~utf8string() { - if (_allocated) - delete[] _buffer; - } +utf8string::~utf8string() { + if (_allocated) delete[] _buffer; +} - utf8string::utf8string() { - _allocated = false; - _length = 0; - _buffer = nullptr; - } +utf8string::utf8string() { + _allocated = false; + _length = 0; + _buffer = nullptr; +} - utf8string::utf8string(const char *string, int length) { - _length = length; - _buffer = new char[_length]; - _allocated = true; - std::memset(_buffer, 0, _length + 1); - std::memcpy(_buffer, string, _length); - } +utf8string::utf8string(const char *string, int length) { + _length = length; + _buffer = new char[_length]; + _allocated = true; + std::memset(_buffer, 0, _length + 1); + std::memcpy(_buffer, string, _length); +} - utf8string::utf8string(const std::string &str) { - _length = str.length(); - _buffer = new char[_length + 1]; - _allocated = true; - std::memset(_buffer, 0, _length + 1); - std::memcpy(_buffer, str.data(), _length); - _buffer[_length] = 0; - } +utf8string::utf8string(const std::string &str) { + _length = str.length(); + _buffer = new char[_length + 1]; + _allocated = true; + std::memset(_buffer, 0, _length + 1); + std::memcpy(_buffer, str.data(), _length); + _buffer[_length] = 0; +} - utf8string::utf8string(const utf8string &other) { - _length = other._length; - _buffer = new char[_length+1]; - _allocated = true; - std::memset(_buffer, 0, _length + 1); - std::memcpy(_buffer, other._buffer, _length); - _buffer[_length] = 0; - } +utf8string::utf8string(const utf8string &other) { + _length = other._length; + _buffer = new char[_length + 1]; + _allocated = true; + std::memset(_buffer, 0, _length + 1); + std::memcpy(_buffer, other._buffer, _length); + _buffer[_length] = 0; +} - void utf8string::Swap(utf8string &other) { - std::swap(_length, other._length); - std::swap(_buffer, other._buffer);// = new char[_length+1]; - std::swap(_allocated, other._allocated); // = true; - } +void utf8string::Swap(utf8string &other) { + std::swap(_length, other._length); + std::swap(_buffer, other._buffer); // = new char[_length+1]; + std::swap(_allocated, other._allocated); // = true; +} - utf8string& utf8string::operator=(const utf8string &other) { - if (this != &other) { - utf8string temp(other); - Swap(temp); - } - return *this; - } +utf8string &utf8string::operator=(const utf8string &other) { + if (this != &other) { + utf8string temp(other); + Swap(temp); + } + return *this; } +} // namespace sd diff --git a/libnd4j/include/types/int16.h b/libnd4j/include/types/int16.h index 25a77138151d..ab5ac4eb75a3 100644 --- a/libnd4j/include/types/int16.h +++ b/libnd4j/include/types/int16.h @@ -24,75 +24,61 @@ #include #include - namespace sd { - float _CUDA_HD FORCEINLINE cpu_int162float(int16_t data); - int16_t _CUDA_HD FORCEINLINE cpu_float2int16(float data); - - struct int16 { - int16_t data; - - _CUDA_HD FORCEINLINE int16(); - _CUDA_HD FORCEINLINE ~int16() = default; - - template - _CUDA_HD FORCEINLINE int16(const T& rhs); - - template - _CUDA_HD FORCEINLINE int16& operator=(const T& rhs); +float _CUDA_HD FORCEINLINE cpu_int162float(int16_t data); +int16_t _CUDA_HD FORCEINLINE cpu_float2int16(float data); +struct int16 { + int16_t data; - _CUDA_HD FORCEINLINE operator float() const; + _CUDA_HD FORCEINLINE int16(); + _CUDA_HD FORCEINLINE ~int16() = default; - _CUDA_HD FORCEINLINE void assign(double rhs); + template + _CUDA_HD FORCEINLINE int16(const T& rhs); - _CUDA_HD FORCEINLINE void assign(float rhs); - }; + template + _CUDA_HD FORCEINLINE int16& operator=(const T& rhs); + _CUDA_HD FORCEINLINE operator float() const; - ////////////////////////////// + _CUDA_HD FORCEINLINE void assign(double rhs); - float cpu_int162float(int16_t data) { - return (float) ((int) data); - } + _CUDA_HD FORCEINLINE void assign(float rhs); +}; - int16_t cpu_float2int16(float data) { - auto t = static_cast(data); - if (t > 32767 ) t = 32767; - if (t < -32768) t = -32768; +////////////////////////////// - return static_cast(t); - } +float cpu_int162float(int16_t data) { return (float)((int)data); } +int16_t cpu_float2int16(float data) { + auto t = static_cast(data); + if (t > 32767) t = 32767; + if (t < -32768) t = -32768; - int16::int16() { - data = cpu_float2int16(0.0f); - } + return static_cast(t); +} - template - int16::int16(const T& rhs) { - assign(rhs); - } +int16::int16() { data = cpu_float2int16(0.0f); } - template - int16& int16::operator=(const T& rhs) { - assign(rhs); return *this; - } +template +int16::int16(const T& rhs) { + assign(rhs); +} +template +int16& int16::operator=(const T& rhs) { + assign(rhs); + return *this; +} - int16::operator float() const { - return cpu_int162float(data); - } +int16::operator float() const { return cpu_int162float(data); } - void int16::assign(double rhs) { - assign(static_cast(rhs)); - } +void int16::assign(double rhs) { assign(static_cast(rhs)); } - void int16::assign(float rhs) { - data = cpu_float2int16(rhs); - } +void int16::assign(float rhs) { data = cpu_float2int16(rhs); } -} +} // namespace sd -#endif //LIBND4J_INT16_H +#endif // LIBND4J_INT16_H diff --git a/libnd4j/include/types/int8.h b/libnd4j/include/types/int8.h index 19e1b91e16a4..1e1daaeba8d6 100644 --- a/libnd4j/include/types/int8.h +++ b/libnd4j/include/types/int8.h @@ -24,71 +24,58 @@ #include #include - namespace sd { - float _CUDA_HD FORCEINLINE cpu_int82float(int8_t data); - int8_t _CUDA_HD FORCEINLINE cpu_float2int8(float data); - - struct int8 { - int8_t data; - - _CUDA_HD FORCEINLINE int8(); - _CUDA_HD FORCEINLINE ~int8() = default; - - template - _CUDA_HD FORCEINLINE int8(const T& rhs); +float _CUDA_HD FORCEINLINE cpu_int82float(int8_t data); +int8_t _CUDA_HD FORCEINLINE cpu_float2int8(float data); - template - _CUDA_HD FORCEINLINE int8& operator=(const T& rhs); +struct int8 { + int8_t data; + _CUDA_HD FORCEINLINE int8(); + _CUDA_HD FORCEINLINE ~int8() = default; - _CUDA_HD FORCEINLINE operator float() const; + template + _CUDA_HD FORCEINLINE int8(const T& rhs); - _CUDA_HD FORCEINLINE void assign(double rhs); + template + _CUDA_HD FORCEINLINE int8& operator=(const T& rhs); - _CUDA_HD FORCEINLINE void assign(float rhs); - }; + _CUDA_HD FORCEINLINE operator float() const; + _CUDA_HD FORCEINLINE void assign(double rhs); - float cpu_int82float(int8_t data) { - return (float) ((int) data); - } + _CUDA_HD FORCEINLINE void assign(float rhs); +}; - int8_t cpu_float2int8(float data) { - int t = (int) data; - if (t > 127) t = 127; - if (t < -128) t = -128; +float cpu_int82float(int8_t data) { return (float)((int)data); } - return (int8_t) t; - } +int8_t cpu_float2int8(float data) { + int t = (int)data; + if (t > 127) t = 127; + if (t < -128) t = -128; - int8::int8() { - data = cpu_float2int8(0.0f); - } + return (int8_t)t; +} - template - int8::int8(const T& rhs) { - assign(rhs); - } +int8::int8() { data = cpu_float2int8(0.0f); } - template - int8& int8::operator=(const T& rhs) { - assign(rhs); return *this; - } +template +int8::int8(const T& rhs) { + assign(rhs); +} +template +int8& int8::operator=(const T& rhs) { + assign(rhs); + return *this; +} - int8::operator float() const { - return cpu_int82float(data); - } +int8::operator float() const { return cpu_int82float(data); } - void int8::assign(double rhs) { - assign((float)rhs); - } +void int8::assign(double rhs) { assign((float)rhs); } - void int8::assign(float rhs) { - data = cpu_float2int8(rhs); - } -} +void int8::assign(float rhs) { data = cpu_float2int8(rhs); } +} // namespace sd -#endif //LIBND4J_INT8_H +#endif // LIBND4J_INT8_H diff --git a/libnd4j/include/types/pair.h b/libnd4j/include/types/pair.h index 7067e888242c..2a4142042c91 100644 --- a/libnd4j/include/types/pair.h +++ b/libnd4j/include/types/pair.h @@ -24,19 +24,18 @@ #include namespace sd { - class SD_EXPORT Pair { - protected: - int _first = 0; - int _second = 0; +class SD_EXPORT Pair { + protected: + int _first = 0; + int _second = 0; - public: - Pair(int first = 0, int second = 0); - ~Pair() = default; + public: + Pair(int first = 0, int second = 0); + ~Pair() = default; - int first() const; - int second() const; - }; -} + int first() const; + int second() const; +}; +} // namespace sd - -#endif //LIBND4J_PAIR_H +#endif // LIBND4J_PAIR_H diff --git a/libnd4j/include/types/triple.h b/libnd4j/include/types/triple.h index 520e24d569ab..e7912084f535 100644 --- a/libnd4j/include/types/triple.h +++ b/libnd4j/include/types/triple.h @@ -21,24 +21,23 @@ #ifndef LIBND4J_TRIPLE_H #define LIBND4J_TRIPLE_H - #include namespace sd { - class SD_EXPORT Triple { - protected: - int _first = 0; - int _second = 0; - int _third = 0; - - public: - Triple(int first = 0, int second = 0, int third = 0); - ~Triple() = default; - - int first() const; - int second() const; - int third() const; - }; -} - -#endif //LIBND4J_TRIPLE_H +class SD_EXPORT Triple { + protected: + int _first = 0; + int _second = 0; + int _third = 0; + + public: + Triple(int first = 0, int second = 0, int third = 0); + ~Triple() = default; + + int first() const; + int second() const; + int third() const; +}; +} // namespace sd + +#endif // LIBND4J_TRIPLE_H diff --git a/libnd4j/include/types/types.h b/libnd4j/include/types/types.h index 7717c801908d..a6fe4064d4bb 100644 --- a/libnd4j/include/types/types.h +++ b/libnd4j/include/types/types.h @@ -22,478 +22,289 @@ #define LIBND4J_TYPES_H #include -#include +#include +#include #include -#include +#include #include -#include +#include #include +#include #include -#include -#include - -#define LIBND4J_STRINGTYPES \ - (sd::DataType::UTF8, std::string),\ - (sd::DataType::UTF16, std::u16string), \ - (sd::DataType::UTF32, std::u32string) - -#define LIBND4J_TYPES \ - (sd::DataType::BFLOAT16, bfloat16),\ - (sd::DataType::HALF, float16), \ - (sd::DataType::FLOAT32, float), \ - (sd::DataType::DOUBLE, double), \ - (sd::DataType::BOOL, bool), \ - (sd::DataType::INT8, int8_t), \ - (sd::DataType::UINT8, uint8_t), \ - (sd::DataType::UINT16, uint16_t), \ - (sd::DataType::UINT32, uint32_t), \ - (sd::DataType::UINT64, uint64_t), \ - (sd::DataType::INT16, int16_t), \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::INT64, Nd4jLong) - -#define LIBND4J_TYPES_EXTENDED \ - (sd::DataType::HALF, float16), \ - (sd::DataType::FLOAT32, float), \ - (sd::DataType::DOUBLE, double), \ - (sd::DataType::BOOL, bool), \ - (sd::DataType::INT8, int8_t), \ - (sd::DataType::UINT8, uint8_t), \ - (sd::DataType::INT16, int16_t), \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::INT64, Nd4jLong), \ - (sd::DataType::UINT16, uint16_t), \ - (sd::DataType::UINT64, Nd4jULong), \ - (sd::DataType::UINT32, uint32_t), \ - (sd::DataType::BFLOAT16, bfloat16) - -#define BOOL_TYPES \ - (sd::DataType::BOOL, bool) +#define LIBND4J_STRINGTYPES \ + (sd::DataType::UTF8, std::string), (sd::DataType::UTF16, std::u16string), \ + (sd::DataType::UTF32, std::u32string) + +#define LIBND4J_TYPES \ + (sd::DataType::BFLOAT16, bfloat16), (sd::DataType::HALF, float16), \ + (sd::DataType::FLOAT32, float), (sd::DataType::DOUBLE, double), \ + (sd::DataType::BOOL, bool), (sd::DataType::INT8, int8_t), \ + (sd::DataType::UINT8, uint8_t), (sd::DataType::UINT16, uint16_t), \ + (sd::DataType::UINT32, uint32_t), (sd::DataType::UINT64, uint64_t), \ + (sd::DataType::INT16, int16_t), (sd::DataType::INT32, int32_t), \ + (sd::DataType::INT64, Nd4jLong) + +#define LIBND4J_TYPES_EXTENDED \ + (sd::DataType::HALF, float16), (sd::DataType::FLOAT32, float), \ + (sd::DataType::DOUBLE, double), (sd::DataType::BOOL, bool), \ + (sd::DataType::INT8, int8_t), (sd::DataType::UINT8, uint8_t), \ + (sd::DataType::INT16, int16_t), (sd::DataType::INT32, int32_t), \ + (sd::DataType::INT64, Nd4jLong), (sd::DataType::UINT16, uint16_t), \ + (sd::DataType::UINT64, Nd4jULong), (sd::DataType::UINT32, uint32_t), \ + (sd::DataType::BFLOAT16, bfloat16) + +#define BOOL_TYPES (sd::DataType::BOOL, bool) #define LONG_TYPES \ - (sd::DataType::INT64, Nd4jLong),\ - (sd::DataType::UINT64, uint64_t) + (sd::DataType::INT64, Nd4jLong), (sd::DataType::UINT64, uint64_t) -#define FLOAT_TYPES \ - (sd::DataType::BFLOAT16, bfloat16) ,\ - (sd::DataType::HALF, float16), \ - (sd::DataType::FLOAT32, float), \ - (sd::DataType::DOUBLE, double) +#define FLOAT_TYPES \ + (sd::DataType::BFLOAT16, bfloat16), (sd::DataType::HALF, float16), \ + (sd::DataType::FLOAT32, float), (sd::DataType::DOUBLE, double) #define INDEXING_TYPES \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::INT64, Nd4jLong) + (sd::DataType::INT32, int32_t), (sd::DataType::INT64, Nd4jLong) #define FLOAT_NATIVE \ - (sd::DataType::FLOAT32, float), \ - (sd::DataType::DOUBLE, double) + (sd::DataType::FLOAT32, float), (sd::DataType::DOUBLE, double) -#define FLOAT_TYPES_0 \ - (sd::DataType::HALF, float16) +#define FLOAT_TYPES_0 (sd::DataType::HALF, float16) -#define FLOAT_TYPES_1 \ - (sd::DataType::FLOAT32, float) +#define FLOAT_TYPES_1 (sd::DataType::FLOAT32, float) -#define FLOAT_TYPES_2 \ - (sd::DataType::DOUBLE, double) +#define FLOAT_TYPES_2 (sd::DataType::DOUBLE, double) -#define FLOAT_TYPES_3 \ - (sd::DataType::BFLOAT16, bfloat16) +#define FLOAT_TYPES_3 (sd::DataType::BFLOAT16, bfloat16) -#define LIBND4J_TYPES_0 \ - (sd::DataType::HALF, float16) +#define LIBND4J_TYPES_0 (sd::DataType::HALF, float16) -#define LIBND4J_TYPES_1 \ - (sd::DataType::FLOAT32, float) +#define LIBND4J_TYPES_1 (sd::DataType::FLOAT32, float) -#define LIBND4J_TYPES_2 \ - (sd::DataType::DOUBLE, double) +#define LIBND4J_TYPES_2 (sd::DataType::DOUBLE, double) -#define LIBND4J_TYPES_3 \ - (sd::DataType::BOOL, bool) +#define LIBND4J_TYPES_3 (sd::DataType::BOOL, bool) -#define LIBND4J_TYPES_4 \ - (sd::DataType::INT8, int8_t) +#define LIBND4J_TYPES_4 (sd::DataType::INT8, int8_t) -#define LIBND4J_TYPES_5 \ - (sd::DataType::UINT8, uint8_t) +#define LIBND4J_TYPES_5 (sd::DataType::UINT8, uint8_t) #define LIBND4J_TYPES_6 \ - (sd::DataType::INT16, int16_t),\ - (sd::DataType::UINT16, uint16_t) + (sd::DataType::INT16, int16_t), (sd::DataType::UINT16, uint16_t) #define LIBND4J_TYPES_7 \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::UINT32, uint32_t) + (sd::DataType::INT32, int32_t), (sd::DataType::UINT32, uint32_t) #define LIBND4J_TYPES_8 \ - (sd::DataType::INT64, Nd4jLong),\ - (sd::DataType::UINT64, uint64_t) + (sd::DataType::INT64, Nd4jLong), (sd::DataType::UINT64, uint64_t) -#define LIBND4J_TYPES_9 \ - (sd::DataType::BFLOAT16, bfloat16) +#define LIBND4J_TYPES_9 (sd::DataType::BFLOAT16, bfloat16) -#define INTEGER_TYPES \ - (sd::DataType::INT8, int8_t), \ - (sd::DataType::UINT8, uint8_t), \ - (sd::DataType::UINT16, uint16_t), \ - (sd::DataType::UINT32, uint32_t), \ - (sd::DataType::UINT64, uint64_t), \ - (sd::DataType::INT16, int16_t), \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::INT64, Nd4jLong) +#define INTEGER_TYPES \ + (sd::DataType::INT8, int8_t), (sd::DataType::UINT8, uint8_t), \ + (sd::DataType::UINT16, uint16_t), (sd::DataType::UINT32, uint32_t), \ + (sd::DataType::UINT64, uint64_t), (sd::DataType::INT16, int16_t), \ + (sd::DataType::INT32, int32_t), (sd::DataType::INT64, Nd4jLong) -#define INTEGER_TYPES_0 \ - (sd::DataType::INT8, int8_t) +#define INTEGER_TYPES_0 (sd::DataType::INT8, int8_t) -#define INTEGER_TYPES_1 \ - (sd::DataType::UINT8, uint8_t) +#define INTEGER_TYPES_1 (sd::DataType::UINT8, uint8_t) -#define INTEGER_TYPES_2 \ - (sd::DataType::UINT16, uint16_t) +#define INTEGER_TYPES_2 (sd::DataType::UINT16, uint16_t) -#define INTEGER_TYPES_3 \ - (sd::DataType::UINT32, uint32_t) +#define INTEGER_TYPES_3 (sd::DataType::UINT32, uint32_t) -#define INTEGER_TYPES_4 \ - (sd::DataType::UINT64, uint64_t) +#define INTEGER_TYPES_4 (sd::DataType::UINT64, uint64_t) -#define INTEGER_TYPES_5 \ - (sd::DataType::INT16, int16_t) +#define INTEGER_TYPES_5 (sd::DataType::INT16, int16_t) -#define INTEGER_TYPES_6 \ - (sd::DataType::INT32, int32_t) +#define INTEGER_TYPES_6 (sd::DataType::INT32, int32_t) -#define INTEGER_TYPES_7 \ - (sd::DataType::INT64, Nd4jLong) +#define INTEGER_TYPES_7 (sd::DataType::INT64, Nd4jLong) +#define NUMERIC_TYPES \ + (sd::DataType::HALF, float16), (sd::DataType::FLOAT32, float), \ + (sd::DataType::DOUBLE, double), (sd::DataType::INT8, int8_t), \ + (sd::DataType::UINT8, uint8_t), (sd::DataType::UINT16, uint16_t), \ + (sd::DataType::UINT32, uint32_t), (sd::DataType::UINT64, uint64_t), \ + (sd::DataType::INT16, int16_t), (sd::DataType::INT32, int32_t), \ + (sd::DataType::INT64, Nd4jLong), (sd::DataType::BFLOAT16, bfloat16) -#define NUMERIC_TYPES \ - (sd::DataType::HALF, float16), \ - (sd::DataType::FLOAT32, float), \ - (sd::DataType::DOUBLE, double), \ - (sd::DataType::INT8, int8_t), \ - (sd::DataType::UINT8, uint8_t), \ - (sd::DataType::UINT16, uint16_t), \ - (sd::DataType::UINT32, uint32_t), \ - (sd::DataType::UINT64, uint64_t), \ - (sd::DataType::INT16, int16_t), \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::INT64, Nd4jLong), \ - (sd::DataType::BFLOAT16, bfloat16) +#define NUMERIC_TYPES_0 (sd::DataType::HALF, float16) -#define NUMERIC_TYPES_0 \ - (sd::DataType::HALF, float16) +#define NUMERIC_TYPES_1 (sd::DataType::FLOAT32, float) -#define NUMERIC_TYPES_1 \ - (sd::DataType::FLOAT32, float) - -#define NUMERIC_TYPES_2 \ - (sd::DataType::DOUBLE, double) +#define NUMERIC_TYPES_2 (sd::DataType::DOUBLE, double) #define NUMERIC_TYPES_3 \ - (sd::DataType::INT8, int8_t), \ - (sd::DataType::BFLOAT16, bfloat16) + (sd::DataType::INT8, int8_t), (sd::DataType::BFLOAT16, bfloat16) -#define NUMERIC_TYPES_4 \ - (sd::DataType::UINT8, uint8_t) +#define NUMERIC_TYPES_4 (sd::DataType::UINT8, uint8_t) -#define NUMERIC_TYPES_5 \ - (sd::DataType::UINT16, uint16_t) +#define NUMERIC_TYPES_5 (sd::DataType::UINT16, uint16_t) -#define NUMERIC_TYPES_6 \ - (sd::DataType::UINT32, uint32_t) +#define NUMERIC_TYPES_6 (sd::DataType::UINT32, uint32_t) -#define NUMERIC_TYPES_7 \ - (sd::DataType::UINT64, uint64_t) +#define NUMERIC_TYPES_7 (sd::DataType::UINT64, uint64_t) -#define NUMERIC_TYPES_8 \ - (sd::DataType::INT16, int16_t) +#define NUMERIC_TYPES_8 (sd::DataType::INT16, int16_t) #define NUMERIC_TYPES_9 \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::INT64, Nd4jLong) - - -#define GENERIC_NUMERIC_TYPES \ - (sd::DataType::HALF, float16), \ - (sd::DataType::FLOAT32, float), \ - (sd::DataType::DOUBLE, double), \ - (sd::DataType::INT32, int32_t), \ - (sd::DataType::INT64, Nd4jLong), \ - (sd::DataType::BFLOAT16, bfloat16) + (sd::DataType::INT32, int32_t), (sd::DataType::INT64, Nd4jLong) +#define GENERIC_NUMERIC_TYPES \ + (sd::DataType::HALF, float16), (sd::DataType::FLOAT32, float), \ + (sd::DataType::DOUBLE, double), (sd::DataType::INT32, int32_t), \ + (sd::DataType::INT64, Nd4jLong), (sd::DataType::BFLOAT16, bfloat16) #ifdef __ND4J_EXPERIMENTAL__ -#define PAIRWISE_TYPES_0 \ - (double, double, double), \ - (double, uint8_t, double), \ - (double, uint8_t, uint8_t), \ - (double, float, double), \ - (double, float, float), \ - (double, bfloat16, double), \ - (double, bfloat16, bfloat16), \ - (double, Nd4jLong, double), \ - (double, Nd4jLong, Nd4jLong), \ - (double, int32_t, double), \ - (double, int32_t, int32_t) , \ - (bool, bool, bool), \ - (bool, int8_t, bool), \ - (int8_t, bool, bool), \ - (int8_t, int8_t, int8_t), \ - (int16_t, bool, int16_t), \ - (float16, int8_t, float16), \ - (bfloat16, bool, bool), \ - (double, int8_t, double) - -#define PAIRWISE_TYPES_9 \ - (double, int16_t, double), \ - (double, int16_t, int16_t), \ - (double, float16, double), \ - (double, float16, float16), \ - (double, bool, double), \ - (double, bool, bool), \ - (int8_t, double, int8_t), \ - (int8_t, double, double), \ - (int8_t, uint8_t, int8_t), \ - (int8_t, uint8_t, uint8_t), \ - (int8_t, float, int8_t), \ - (double, int8_t, int8_t) , \ - (bool, int8_t, int8_t) ,\ - (float16, int32_t, float16), \ - (float16, int32_t, int32_t), \ - (float16, int16_t, float16), \ - (float16, int16_t, int16_t), \ - (float16, float16, float16), \ - (float16, bool, float16), \ - (float16, bool, bool), \ - (float16, int8_t, int8_t) - -#define PAIRWISE_TYPES_1 \ - (uint8_t, double, uint8_t), \ - (uint8_t, double, double), \ - (uint8_t, uint8_t, uint8_t), \ - (uint8_t, float, uint8_t), \ - (uint8_t, float, float), \ - (uint8_t, bfloat16, uint8_t), \ - (uint8_t, bfloat16, bfloat16), \ - (uint8_t, Nd4jLong, uint8_t) , \ - (uint8_t, Nd4jLong, Nd4jLong), \ - (uint8_t, int32_t, uint8_t), \ - (uint8_t, int32_t, int32_t), \ - (uint8_t, int16_t, uint8_t), \ - (uint8_t, int16_t, int16_t), \ - (uint8_t, float16, uint8_t), \ - (uint8_t, float16, float16), \ - (uint8_t, bool, uint8_t), \ - (uint8_t, bool, bool), \ - (uint8_t, int8_t, uint8_t), \ - (uint8_t, int8_t, int8_t) - -#define PAIRWISE_TYPES_2 \ - (float, double, float), \ - (float, double, double), \ - (float, uint8_t, float), \ - (float, uint8_t, uint8_t), \ - (float, float, float), \ - (float, bfloat16, float), \ - (float, bfloat16, bfloat16), \ - (float, Nd4jLong, float), \ - (float, Nd4jLong, Nd4jLong) , \ - (float, int32_t, float), \ - (int8_t, int32_t, int8_t), \ - (int8_t, int32_t, int32_t), \ - (float, int32_t, int32_t), \ - (float, int16_t, float), \ - (float, int16_t, int16_t), \ - (float, float16, float), \ - (float, float16, float16), \ - (float, bool, float) - -#define PAIRWISE_TYPES_3 \ - (bfloat16, double, bfloat16), \ - (bfloat16, double, double), \ - (bfloat16, uint8_t, bfloat16), \ - (bfloat16, uint8_t, uint8_t), \ - (bfloat16, float, bfloat16), \ - (bfloat16, float, float), \ - (bfloat16, bfloat16, bfloat16), \ - (bfloat16, Nd4jLong, bfloat16), \ - (bfloat16, Nd4jLong, Nd4jLong), \ - (bfloat16, int32_t, bfloat16) , \ - (float, bool, bool), \ - (bfloat16, int32_t, int32_t), \ - (float, int8_t, float), \ - (int8_t, float, float), \ - (int8_t, bfloat16, int8_t), \ - (int8_t, bfloat16, bfloat16), \ - (int8_t, Nd4jLong, int8_t), \ - (int8_t, Nd4jLong, Nd4jLong), \ - (float, int8_t, int8_t) - -#define PAIRWISE_TYPES_4 \ - (Nd4jLong, double, Nd4jLong), \ - (Nd4jLong, double, double), \ - (Nd4jLong, uint8_t, Nd4jLong), \ - (Nd4jLong, uint8_t, uint8_t), \ - (Nd4jLong, float, Nd4jLong), \ - (Nd4jLong, float, float), \ - (Nd4jLong, bfloat16, Nd4jLong), \ - (Nd4jLong, bfloat16, bfloat16), \ - (Nd4jLong, Nd4jLong, Nd4jLong), \ - (Nd4jLong, int32_t, Nd4jLong), \ - (bfloat16, int16_t, bfloat16), \ - (bfloat16, int16_t, int16_t), \ - (bfloat16, float16, bfloat16), \ - (bfloat16, float16, float16), \ - (bfloat16, bool, bfloat16), \ - (int8_t, int16_t, int8_t), \ - (int8_t, int16_t, int16_t), \ - (bfloat16, int8_t, bfloat16), \ - (bfloat16, int8_t, int8_t) - -#define PAIRWISE_TYPES_5 \ - (int32_t, double, int32_t), \ - (int32_t, double, double), \ - (int32_t, uint8_t, int32_t), \ - (int32_t, uint8_t, uint8_t), \ - (int32_t, float, int32_t), \ - (int32_t, float, float), \ - (int32_t, bfloat16, int32_t), \ - (int32_t, bfloat16, bfloat16), \ - (int32_t, Nd4jLong, int32_t), \ - (Nd4jLong, int32_t, int32_t), \ - (Nd4jLong, int16_t, Nd4jLong), \ - (Nd4jLong, int16_t, int16_t), \ - (Nd4jLong, float16, Nd4jLong), \ - (Nd4jLong, float16, float16), \ - (Nd4jLong, bool, Nd4jLong), \ - (Nd4jLong, bool, bool), \ - (Nd4jLong, int8_t, Nd4jLong), \ - (Nd4jLong, int8_t, int8_t) - - -#define PAIRWISE_TYPES_6 \ - (int16_t, double, int16_t), \ - (int16_t, double, double), \ - (int16_t, uint8_t, int16_t), \ - (int16_t, uint8_t, uint8_t), \ - (int16_t, float, int16_t), \ - (int16_t, float, float), \ - (int16_t, bfloat16, int16_t), \ - (int16_t, bfloat16, bfloat16), \ - (int16_t, Nd4jLong, int16_t), \ - (float16, bfloat16, bfloat16), \ - (int16_t, Nd4jLong, Nd4jLong), \ - (int32_t, Nd4jLong, Nd4jLong), \ - (int32_t, int32_t, int32_t), \ - (int32_t, int16_t, int32_t), \ - (int32_t, int16_t, int16_t), \ - (int32_t, float16, int32_t), \ - (int32_t, float16, float16), \ - (int32_t, bool, int32_t), \ - (int32_t, bool, bool), \ - (int32_t, int8_t, int32_t), \ - (int32_t, int8_t, int8_t) - - -#define PAIRWISE_TYPES_7 \ - (float16, double, float16), \ - (float16, double, double), \ - (float16, uint8_t, float16), \ - (float16, uint8_t, uint8_t), \ - (float16, float, float16), \ - (float16, float, float), \ - (float16, bfloat16, float16), \ - (float16, Nd4jLong, float16), \ - (float16, Nd4jLong, Nd4jLong), \ - (int16_t, int32_t, int16_t), \ - (int16_t, int32_t, int32_t), \ - (int16_t, int16_t, int16_t), \ - (int16_t, float16, int16_t), \ - (int16_t, float16, float16), \ - (int8_t, float16, int8_t), \ - (int8_t, float16, float16), \ - (int8_t, bool, int8_t), \ - (int16_t, bool, bool), \ - (int16_t, int8_t, int16_t), \ - (int16_t, int8_t, int8_t) - - -#define PAIRWISE_TYPES_8 \ - (bool, double, bool), \ - (bool, double, double), \ - (bool, uint8_t, bool), \ - (bool, uint8_t, uint8_t), \ - (bool, float, bool), \ - (bool, float, float), \ - (bool, bfloat16, bool) ,\ - (bool, bfloat16, bfloat16), \ - (bool, Nd4jLong, bool), \ - (bool, Nd4jLong, Nd4jLong), \ - (bool, int32_t, bool), \ - (bool, int32_t, int32_t), \ - (bool, int16_t, bool), \ - (bool, int16_t, int16_t), \ - (bool, float16, bool), \ - (bool, float16, float16) - +#define PAIRWISE_TYPES_0 \ + (double, double, double), (double, uint8_t, double), \ + (double, uint8_t, uint8_t), (double, float, double), \ + (double, float, float), (double, bfloat16, double), \ + (double, bfloat16, bfloat16), (double, Nd4jLong, double), \ + (double, Nd4jLong, Nd4jLong), (double, int32_t, double), \ + (double, int32_t, int32_t), (bool, bool, bool), (bool, int8_t, bool), \ + (int8_t, bool, bool), (int8_t, int8_t, int8_t), \ + (int16_t, bool, int16_t), (float16, int8_t, float16), \ + (bfloat16, bool, bool), (double, int8_t, double) + +#define PAIRWISE_TYPES_9 \ + (double, int16_t, double), (double, int16_t, int16_t), \ + (double, float16, double), (double, float16, float16), \ + (double, bool, double), (double, bool, bool), (int8_t, double, int8_t), \ + (int8_t, double, double), (int8_t, uint8_t, int8_t), \ + (int8_t, uint8_t, uint8_t), (int8_t, float, int8_t), \ + (double, int8_t, int8_t), (bool, int8_t, int8_t), \ + (float16, int32_t, float16), (float16, int32_t, int32_t), \ + (float16, int16_t, float16), (float16, int16_t, int16_t), \ + (float16, float16, float16), (float16, bool, float16), \ + (float16, bool, bool), (float16, int8_t, int8_t) + +#define PAIRWISE_TYPES_1 \ + (uint8_t, double, uint8_t), (uint8_t, double, double), \ + (uint8_t, uint8_t, uint8_t), (uint8_t, float, uint8_t), \ + (uint8_t, float, float), (uint8_t, bfloat16, uint8_t), \ + (uint8_t, bfloat16, bfloat16), (uint8_t, Nd4jLong, uint8_t), \ + (uint8_t, Nd4jLong, Nd4jLong), (uint8_t, int32_t, uint8_t), \ + (uint8_t, int32_t, int32_t), (uint8_t, int16_t, uint8_t), \ + (uint8_t, int16_t, int16_t), (uint8_t, float16, uint8_t), \ + (uint8_t, float16, float16), (uint8_t, bool, uint8_t), \ + (uint8_t, bool, bool), (uint8_t, int8_t, uint8_t), \ + (uint8_t, int8_t, int8_t) + +#define PAIRWISE_TYPES_2 \ + (float, double, float), (float, double, double), (float, uint8_t, float), \ + (float, uint8_t, uint8_t), (float, float, float), \ + (float, bfloat16, float), (float, bfloat16, bfloat16), \ + (float, Nd4jLong, float), (float, Nd4jLong, Nd4jLong), \ + (float, int32_t, float), (int8_t, int32_t, int8_t), \ + (int8_t, int32_t, int32_t), (float, int32_t, int32_t), \ + (float, int16_t, float), (float, int16_t, int16_t), \ + (float, float16, float), (float, float16, float16), (float, bool, float) + +#define PAIRWISE_TYPES_3 \ + (bfloat16, double, bfloat16), (bfloat16, double, double), \ + (bfloat16, uint8_t, bfloat16), (bfloat16, uint8_t, uint8_t), \ + (bfloat16, float, bfloat16), (bfloat16, float, float), \ + (bfloat16, bfloat16, bfloat16), (bfloat16, Nd4jLong, bfloat16), \ + (bfloat16, Nd4jLong, Nd4jLong), (bfloat16, int32_t, bfloat16), \ + (float, bool, bool), (bfloat16, int32_t, int32_t), \ + (float, int8_t, float), (int8_t, float, float), \ + (int8_t, bfloat16, int8_t), (int8_t, bfloat16, bfloat16), \ + (int8_t, Nd4jLong, int8_t), (int8_t, Nd4jLong, Nd4jLong), \ + (float, int8_t, int8_t) + +#define PAIRWISE_TYPES_4 \ + (Nd4jLong, double, Nd4jLong), (Nd4jLong, double, double), \ + (Nd4jLong, uint8_t, Nd4jLong), (Nd4jLong, uint8_t, uint8_t), \ + (Nd4jLong, float, Nd4jLong), (Nd4jLong, float, float), \ + (Nd4jLong, bfloat16, Nd4jLong), (Nd4jLong, bfloat16, bfloat16), \ + (Nd4jLong, Nd4jLong, Nd4jLong), (Nd4jLong, int32_t, Nd4jLong), \ + (bfloat16, int16_t, bfloat16), (bfloat16, int16_t, int16_t), \ + (bfloat16, float16, bfloat16), (bfloat16, float16, float16), \ + (bfloat16, bool, bfloat16), (int8_t, int16_t, int8_t), \ + (int8_t, int16_t, int16_t), (bfloat16, int8_t, bfloat16), \ + (bfloat16, int8_t, int8_t) + +#define PAIRWISE_TYPES_5 \ + (int32_t, double, int32_t), (int32_t, double, double), \ + (int32_t, uint8_t, int32_t), (int32_t, uint8_t, uint8_t), \ + (int32_t, float, int32_t), (int32_t, float, float), \ + (int32_t, bfloat16, int32_t), (int32_t, bfloat16, bfloat16), \ + (int32_t, Nd4jLong, int32_t), (Nd4jLong, int32_t, int32_t), \ + (Nd4jLong, int16_t, Nd4jLong), (Nd4jLong, int16_t, int16_t), \ + (Nd4jLong, float16, Nd4jLong), (Nd4jLong, float16, float16), \ + (Nd4jLong, bool, Nd4jLong), (Nd4jLong, bool, bool), \ + (Nd4jLong, int8_t, Nd4jLong), (Nd4jLong, int8_t, int8_t) + +#define PAIRWISE_TYPES_6 \ + (int16_t, double, int16_t), (int16_t, double, double), \ + (int16_t, uint8_t, int16_t), (int16_t, uint8_t, uint8_t), \ + (int16_t, float, int16_t), (int16_t, float, float), \ + (int16_t, bfloat16, int16_t), (int16_t, bfloat16, bfloat16), \ + (int16_t, Nd4jLong, int16_t), (float16, bfloat16, bfloat16), \ + (int16_t, Nd4jLong, Nd4jLong), (int32_t, Nd4jLong, Nd4jLong), \ + (int32_t, int32_t, int32_t), (int32_t, int16_t, int32_t), \ + (int32_t, int16_t, int16_t), (int32_t, float16, int32_t), \ + (int32_t, float16, float16), (int32_t, bool, int32_t), \ + (int32_t, bool, bool), (int32_t, int8_t, int32_t), \ + (int32_t, int8_t, int8_t) + +#define PAIRWISE_TYPES_7 \ + (float16, double, float16), (float16, double, double), \ + (float16, uint8_t, float16), (float16, uint8_t, uint8_t), \ + (float16, float, float16), (float16, float, float), \ + (float16, bfloat16, float16), (float16, Nd4jLong, float16), \ + (float16, Nd4jLong, Nd4jLong), (int16_t, int32_t, int16_t), \ + (int16_t, int32_t, int32_t), (int16_t, int16_t, int16_t), \ + (int16_t, float16, int16_t), (int16_t, float16, float16), \ + (int8_t, float16, int8_t), (int8_t, float16, float16), \ + (int8_t, bool, int8_t), (int16_t, bool, bool), \ + (int16_t, int8_t, int16_t), (int16_t, int8_t, int8_t) + +#define PAIRWISE_TYPES_8 \ + (bool, double, bool), (bool, double, double), (bool, uint8_t, bool), \ + (bool, uint8_t, uint8_t), (bool, float, bool), (bool, float, float), \ + (bool, bfloat16, bool), (bool, bfloat16, bfloat16), \ + (bool, Nd4jLong, bool), (bool, Nd4jLong, Nd4jLong), \ + (bool, int32_t, bool), (bool, int32_t, int32_t), (bool, int16_t, bool), \ + (bool, int16_t, int16_t), (bool, float16, bool), \ + (bool, float16, float16) #else -#define PAIRWISE_TYPES_0 \ -(float16, float16, float16) , \ -(float16, bool, float16) +#define PAIRWISE_TYPES_0 (float16, float16, float16), (float16, bool, float16) -#define PAIRWISE_TYPES_1 \ -(float, float, float) , \ -(float, bool, float) +#define PAIRWISE_TYPES_1 (float, float, float), (float, bool, float) -#define PAIRWISE_TYPES_2 \ -(double, double, double) , \ -(double, bool, double) +#define PAIRWISE_TYPES_2 (double, double, double), (double, bool, double) -#define PAIRWISE_TYPES_3 \ -(int8_t, int8_t, int8_t) , \ -(int8_t, bool, int8_t) +#define PAIRWISE_TYPES_3 (int8_t, int8_t, int8_t), (int8_t, bool, int8_t) -#define PAIRWISE_TYPES_4 \ -(int16_t, int16_t, int16_t) , \ -(int16_t, bool, int16_t) +#define PAIRWISE_TYPES_4 (int16_t, int16_t, int16_t), (int16_t, bool, int16_t) -#define PAIRWISE_TYPES_5 \ -(uint8_t, uint8_t, uint8_t) , \ -(uint8_t, bool, uint8_t) +#define PAIRWISE_TYPES_5 (uint8_t, uint8_t, uint8_t), (uint8_t, bool, uint8_t) -#define PAIRWISE_TYPES_6 \ -(int, int, int) ,\ -(int, bool, int) +#define PAIRWISE_TYPES_6 (int, int, int), (int, bool, int) -#define PAIRWISE_TYPES_7 \ -(bool, bool, bool) +#define PAIRWISE_TYPES_7 (bool, bool, bool) #define PAIRWISE_TYPES_8 \ -(Nd4jLong, Nd4jLong, Nd4jLong) ,\ -(Nd4jLong, bool, Nd4jLong) + (Nd4jLong, Nd4jLong, Nd4jLong), (Nd4jLong, bool, Nd4jLong) #define PAIRWISE_TYPES_9 \ -(bfloat16, bfloat16, bfloat16) , \ -(bfloat16, bool, bfloat16) + (bfloat16, bfloat16, bfloat16), (bfloat16, bool, bfloat16) #define PAIRWISE_TYPES_10 \ -(uint64_t, uint64_t, uint64_t) ,\ -(uint64_t, bool, uint64_t) + (uint64_t, uint64_t, uint64_t), (uint64_t, bool, uint64_t) #define PAIRWISE_TYPES_11 \ -(uint32_t, uint32_t, uint32_t) ,\ -(uint32_t, bool, uint32_t) + (uint32_t, uint32_t, uint32_t), (uint32_t, bool, uint32_t) #define PAIRWISE_TYPES_12 \ -(uint16_t, uint16_t, uint16_t) ,\ -(uint16_t, bool, uint16_t) + (uint16_t, uint16_t, uint16_t), (uint16_t, bool, uint16_t) #endif -#endif //LIBND4J_TYPES_H - +#endif // LIBND4J_TYPES_H diff --git a/libnd4j/include/types/u64.h b/libnd4j/include/types/u64.h index 908a9ba1cd63..ef9537ce3eeb 100644 --- a/libnd4j/include/types/u64.h +++ b/libnd4j/include/types/u64.h @@ -20,45 +20,43 @@ #ifndef LIBND4J_U64_H #define LIBND4J_U64_H -#include #include #include +#include namespace sd { - typedef struct { - int16_t _v0; - int16_t _v1; - int16_t _v2; - int16_t _v3; - } di16; - - typedef struct { - int _v0; - int _v1; - } di32; - - typedef struct { - uint32_t _v0; - uint32_t _v1; - } du32; - - union u64 { - bool _bool; - int8_t _char; - int16_t _short; - int32_t _int; - //float16 _half = 0.0f; - float _float; - double _double; - Nd4jLong _long; - uint64_t _ulong; - di32 _di32; - du32 _du32; - u64() { - _long = 0; - } - }; -} +typedef struct { + int16_t _v0; + int16_t _v1; + int16_t _v2; + int16_t _v3; +} di16; + +typedef struct { + int _v0; + int _v1; +} di32; + +typedef struct { + uint32_t _v0; + uint32_t _v1; +} du32; + +union u64 { + bool _bool; + int8_t _char; + int16_t _short; + int32_t _int; + // float16 _half = 0.0f; + float _float; + double _double; + Nd4jLong _long; + uint64_t _ulong; + di32 _di32; + du32 _du32; + u64() { _long = 0; } +}; +} // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/types/uint16.h b/libnd4j/include/types/uint16.h index 5fee50e7aa30..d45454336538 100644 --- a/libnd4j/include/types/uint16.h +++ b/libnd4j/include/types/uint16.h @@ -24,75 +24,66 @@ #include #include - namespace sd { - uint16_t _CUDA_HD FORCEINLINE cpu_float2uint16(float data); - float _CUDA_HD FORCEINLINE cpu_uint162float(uint16_t data); +uint16_t _CUDA_HD FORCEINLINE cpu_float2uint16(float data); +float _CUDA_HD FORCEINLINE cpu_uint162float(uint16_t data); - struct uint16 { - uint16_t data; +struct uint16 { + uint16_t data; - _CUDA_HD FORCEINLINE uint16(); - _CUDA_HD FORCEINLINE ~uint16(); + _CUDA_HD FORCEINLINE uint16(); + _CUDA_HD FORCEINLINE ~uint16(); - template - _CUDA_HD FORCEINLINE uint16(const T& rhs); + template + _CUDA_HD FORCEINLINE uint16(const T& rhs); - template - _CUDA_HD FORCEINLINE uint16& operator=(const T& rhs); + template + _CUDA_HD FORCEINLINE uint16& operator=(const T& rhs); - _CUDA_HD FORCEINLINE operator float() const; + _CUDA_HD FORCEINLINE operator float() const; - _CUDA_HD FORCEINLINE void assign(double rhs); + _CUDA_HD FORCEINLINE void assign(double rhs); - _CUDA_HD FORCEINLINE void assign(float rhs); - }; + _CUDA_HD FORCEINLINE void assign(float rhs); +}; //////////////////// IMPLEMENTATIONS - float _CUDA_HD cpu_uint162float(uint16_t data) { - return static_cast(data); - } +float _CUDA_HD cpu_uint162float(uint16_t data) { + return static_cast(data); +} - uint16_t _CUDA_HD cpu_float2uint16(float data) { - auto t = static_cast(data); - if (t > 65536 ) t = 65536; - if (t < 0) t = 0; +uint16_t _CUDA_HD cpu_float2uint16(float data) { + auto t = static_cast(data); + if (t > 65536) t = 65536; + if (t < 0) t = 0; - return static_cast(t); - } + return static_cast(t); +} - _CUDA_HD uint16::uint16() { - data = cpu_float2uint16(0.0f); - } +_CUDA_HD uint16::uint16() { data = cpu_float2uint16(0.0f); } - _CUDA_HD uint16::~uint16() { - // - } +_CUDA_HD uint16::~uint16() { + // +} - template - _CUDA_HD uint16::uint16(const T& rhs) { - assign(rhs); - } +template +_CUDA_HD uint16::uint16(const T& rhs) { + assign(rhs); +} - template - _CUDA_HD uint16& uint16::operator=(const T& rhs) { - assign(rhs); - return *this; - } +template +_CUDA_HD uint16& uint16::operator=(const T& rhs) { + assign(rhs); + return *this; +} - _CUDA_HD uint16::operator float() const { - return cpu_uint162float(data); - } +_CUDA_HD uint16::operator float() const { return cpu_uint162float(data); } - _CUDA_HD void uint16::assign(float rhs) { - data = cpu_float2uint16(rhs); - } +_CUDA_HD void uint16::assign(float rhs) { data = cpu_float2uint16(rhs); } - _CUDA_HD void uint16::assign(double rhs) { - assign((float)rhs); - } -} +_CUDA_HD void uint16::assign(double rhs) { assign((float)rhs); } +} // namespace sd -#endif //LIBND4J_UINT16_H +#endif // LIBND4J_UINT16_H diff --git a/libnd4j/include/types/uint8.h b/libnd4j/include/types/uint8.h index a2505c9ab64c..b3c8de86f327 100644 --- a/libnd4j/include/types/uint8.h +++ b/libnd4j/include/types/uint8.h @@ -24,71 +24,62 @@ #include #include - namespace sd { - float _CUDA_HD FORCEINLINE cpu_uint82float(uint8_t data); - uint8_t _CUDA_HD FORCEINLINE cpu_float2uint8(float data); - - struct uint8 { - uint8_t data; - - _CUDA_HD FORCEINLINE uint8(); - _CUDA_HD FORCEINLINE ~uint8() = default; - - template - _CUDA_HD FORCEINLINE uint8(const T& rhs); - - template - _CUDA_HD FORCEINLINE uint8& operator=(const T& rhs); - +float _CUDA_HD FORCEINLINE cpu_uint82float(uint8_t data); +uint8_t _CUDA_HD FORCEINLINE cpu_float2uint8(float data); - _CUDA_HD FORCEINLINE operator float() const; +struct uint8 { + uint8_t data; - _CUDA_HD FORCEINLINE void assign(double rhs); + _CUDA_HD FORCEINLINE uint8(); + _CUDA_HD FORCEINLINE ~uint8() = default; - _CUDA_HD FORCEINLINE void assign(float rhs); - }; + template + _CUDA_HD FORCEINLINE uint8(const T& rhs); + template + _CUDA_HD FORCEINLINE uint8& operator=(const T& rhs); + _CUDA_HD FORCEINLINE operator float() const; - /////////////////////////// + _CUDA_HD FORCEINLINE void assign(double rhs); + _CUDA_HD FORCEINLINE void assign(float rhs); +}; - float cpu_uint82float(uint8_t data) { - return static_cast(static_cast(data)); - } +/////////////////////////// - uint8_t cpu_float2uint8(float data) { - auto t = static_cast(data); - if (t > 255) t = 255; - if (t < 0) t = 0; +float cpu_uint82float(uint8_t data) { + return static_cast(static_cast(data)); +} - return static_cast(t); - } +uint8_t cpu_float2uint8(float data) { + auto t = static_cast(data); + if (t > 255) t = 255; + if (t < 0) t = 0; - uint8::uint8() { data = cpu_float2uint8(0.0f); } + return static_cast(t); +} - template - uint8::uint8(const T& rhs) { - assign(rhs); - } +uint8::uint8() { data = cpu_float2uint8(0.0f); } - template - uint8& uint8::operator=(const T& rhs) { assign(rhs); return *this; } +template +uint8::uint8(const T& rhs) { + assign(rhs); +} +template +uint8& uint8::operator=(const T& rhs) { + assign(rhs); + return *this; +} - uint8::operator float() const { - return cpu_uint82float(data); - } +uint8::operator float() const { return cpu_uint82float(data); } - void uint8::assign(double rhs) { - assign(static_cast(rhs)); - } +void uint8::assign(double rhs) { assign(static_cast(rhs)); } - void uint8::assign(float rhs) { - data = cpu_float2uint8(rhs); - } -} +void uint8::assign(float rhs) { data = cpu_float2uint8(rhs); } +} // namespace sd -#endif //LIBND4J_UINT8_H +#endif // LIBND4J_UINT8_H diff --git a/libnd4j/include/types/utf8string.h b/libnd4j/include/types/utf8string.h index efab2c027c8d..8a580d1ee364 100644 --- a/libnd4j/include/types/utf8string.h +++ b/libnd4j/include/types/utf8string.h @@ -21,29 +21,30 @@ #ifndef SD_UTF8STRING_H #define SD_UTF8STRING_H -#include #include +#include + namespace sd { - struct SD_EXPORT utf8string { - private: - bool _allocated = false; - public: - char *_buffer = nullptr; - unsigned int _length = 0; +struct SD_EXPORT utf8string { + private: + bool _allocated = false; - utf8string(); - ~utf8string(); + public: + char *_buffer = nullptr; + unsigned int _length = 0; - utf8string(const char *string, int length); - utf8string(const std::string &string); - utf8string(const utf8string &other); - utf8string& operator=(const utf8string &other); + utf8string(); + ~utf8string(); - protected: - void Swap(utf8string &other); - }; -} + utf8string(const char *string, int length); + utf8string(const std::string &string); + utf8string(const utf8string &other); + utf8string &operator=(const utf8string &other); + protected: + void Swap(utf8string &other); +}; +} // namespace sd -#endif //SD_UTF8STRING_H +#endif // SD_UTF8STRING_H diff --git a/libnd4j/minifier/graphopt.cpp b/libnd4j/minifier/graphopt.cpp index 50321aacc66c..28fa29b121b2 100644 --- a/libnd4j/minifier/graphopt.cpp +++ b/libnd4j/minifier/graphopt.cpp @@ -21,113 +21,112 @@ * */ +#include "graphopt.h" + #include #include -#include "graphopt.h" - -std::ostream& -operator<< (std::ostream& out, GraphOpt const& opts) { - if (opts._files.empty() && opts._opts.empty()) { - out << "Empty options" << std::endl; - return out; - } - out << "==================================================" << std::endl; - out << "Files:" << std::endl; - int index = 1; - for (auto file: opts._files) { - out << "File " << index++ << ": " << file << std::endl; - } - out << "Options:" << std::endl; - for (char opt: opts._opts) { - out << "Option: " << opt; - if (opts._args.find(opt) != opts._args.end()) { - out << " with arg: " << opts._args.at(opt) << std::endl; - } - else { - out << std::endl; - } - } - out << "=================================================="; +std::ostream& operator<<(std::ostream& out, GraphOpt const& opts) { + if (opts._files.empty() && opts._opts.empty()) { + out << "Empty options" << std::endl; return out; + } + out << "==================================================" << std::endl; + out << "Files:" << std::endl; + int index = 1; + for (auto file : opts._files) { + out << "File " << index++ << ": " << file << std::endl; + } + out << "Options:" << std::endl; + for (char opt : opts._opts) { + out << "Option: " << opt; + if (opts._args.find(opt) != opts._args.end()) { + out << " with arg: " << opts._args.at(opt) << std::endl; + } else { + out << std::endl; + } + } + out << "=================================================="; + return out; } //////////////////////////////////////////////////////////////////////////////// -int -GraphOpt::optionsWithArgs(int argc, char* argv[], GraphOpt& res) { - char* optArg = nullptr; - int optIndex = 1; - - char const* optionStr = "lxa:o:e"; - std::string const defaultOutputName("nd4jlib_mini"); - - for (optIndex = 1; (optIndex < argc) && (argv[optIndex][0] == '-') && - (argv[optIndex][0]); optIndex++) { - - int opt = argv[optIndex][1]; - - if (opt == '?' || opt == 'h') { - res.help(argv[0], std::cout); - res.reset(); - return 1; - } - - char const* p = strchr(optionStr, opt); - - if (p == nullptr) - { - std::cerr << "opt " << (char)opt << " not found with " << optionStr << std::endl; - res._opts.push_back('?'); - res.reset(); - return -1; - } - else { - res._opts.push_back(opt); - - if (p[1] == ':') // processing param with - { - optIndex++; - if (optIndex >= argc) - { - std::cerr << "optIndex " << optIndex << " is out of bounds " << argc << std::endl; - res.reset(); - res._opts.push_back('?'); - return -2; - } - res._args[opt] = std::string(argv[optIndex]); - } - } - } - - if ( !res.hasParam('l') && !res.hasParam('x') ) { - std::cerr << "No -l or -x params are provided. At least one of them should be used." << std::endl; - res.reset(); - res._opts.push_back('?'); - return -3; +int GraphOpt::optionsWithArgs(int argc, char* argv[], GraphOpt& res) { + char* optArg = nullptr; + int optIndex = 1; + + char const* optionStr = "lxa:o:e"; + std::string const defaultOutputName("nd4jlib_mini"); + + for (optIndex = 1; + (optIndex < argc) && (argv[optIndex][0] == '-') && (argv[optIndex][0]); + optIndex++) { + int opt = argv[optIndex][1]; + + if (opt == '?' || opt == 'h') { + res.help(argv[0], std::cout); + res.reset(); + return 1; } - if (res._args.empty()) - res._args['o'] = defaultOutputName; - - for ( ; optIndex < argc; optIndex++) { - res._files.push_back(std::string(argv[optIndex])); + char const* p = strchr(optionStr, opt); + + if (p == nullptr) { + std::cerr << "opt " << (char)opt << " not found with " << optionStr + << std::endl; + res._opts.push_back('?'); + res.reset(); + return -1; + } else { + res._opts.push_back(opt); + + if (p[1] == ':') // processing param with + { + optIndex++; + if (optIndex >= argc) { + std::cerr << "optIndex " << optIndex << " is out of bounds " << argc + << std::endl; + res.reset(); + res._opts.push_back('?'); + return -2; + } + res._args[opt] = std::string(argv[optIndex]); + } } - return 0; + } + + if (!res.hasParam('l') && !res.hasParam('x')) { + std::cerr << "No -l or -x params are provided. At least one of them should " + "be used." + << std::endl; + res.reset(); + res._opts.push_back('?'); + return -3; + } + + if (res._args.empty()) res._args['o'] = defaultOutputName; + + for (; optIndex < argc; optIndex++) { + res._files.push_back(std::string(argv[optIndex])); + } + return 0; } //////////////////////////////////////////////////////////////////////////////// -std::ostream& -GraphOpt::help(std::string app, std::ostream& out) { - out << "Usage: \n" << app << " [-lxe] [-o outname] filename1 " - "[filename2 filename3 ... filenameN]" << std::endl; - out << "Parameters:" << std::endl; - out << "\t-l\t Generate library" << std::endl; - out << "\t-x\t Generate executable" << std::endl; - out << "\t-e\t Embed the Graph(s) into executable as resource" << std::endl; - out << "\t-o Set up output name (for library, executable or both)" << std::endl; - out << "\t-a target CPU architecture" << std::endl; - out << "\t-h\t This help" << std::endl; - - return out; +std::ostream& GraphOpt::help(std::string app, std::ostream& out) { + out << "Usage: \n" + << app + << " [-lxe] [-o outname] filename1 " + "[filename2 filename3 ... filenameN]" + << std::endl; + out << "Parameters:" << std::endl; + out << "\t-l\t Generate library" << std::endl; + out << "\t-x\t Generate executable" << std::endl; + out << "\t-e\t Embed the Graph(s) into executable as resource" << std::endl; + out << "\t-o Set up output name (for library, executable or both)" + << std::endl; + out << "\t-a target CPU architecture" << std::endl; + out << "\t-h\t This help" << std::endl; + + return out; } - diff --git a/libnd4j/minifier/graphopt.h b/libnd4j/minifier/graphopt.h index 329fc22d6a7d..d236e008c5d1 100644 --- a/libnd4j/minifier/graphopt.h +++ b/libnd4j/minifier/graphopt.h @@ -17,8 +17,8 @@ /* * GraphOpt class declarations * - * GraphOpt class used for parsing command line arguments - * + * GraphOpt class used for parsing command line arguments + * * * Created by GS 3/2/2018 * @@ -27,49 +27,51 @@ #ifndef __H__GRAPH_OPTIONS__ #define __H__GRAPH_OPTIONS__ -#include +#include +#include #include +#include #include -#include -#include class GraphOpt { -public: - typedef std::list FileList; - typedef std::list OptionList; - typedef std::unordered_map ArgumentDict; -public: - GraphOpt() - {} + public: + typedef std::list FileList; + typedef std::list OptionList; + typedef std::unordered_map ArgumentDict; - static int optionsWithArgs(int argc, char* argv[], GraphOpt& options); + public: + GraphOpt() {} - FileList& files() { return _files; } - FileList const& files() const { return _files; } - OptionList const& options() const { return _opts; } - std::string outputName() const { return _args.at('o'); } - std::string arch() const { - if (_args.count('a') < 1) { - printf("No Arg!!!\n"); - fflush(stdout); - } - return _args.at('a'); - }; - std::ostream& help(std::string app, std::ostream& out); - bool hasParam(int param) const { return std::find(_opts.begin(), _opts.end(), param) != _opts.end(); } - - friend std::ostream& operator<< (std::ostream& out, GraphOpt const& opts); + static int optionsWithArgs(int argc, char* argv[], GraphOpt& options); - void reset() { - _files.clear(); - _opts.clear(); - _args.clear(); + FileList& files() { return _files; } + FileList const& files() const { return _files; } + OptionList const& options() const { return _opts; } + std::string outputName() const { return _args.at('o'); } + std::string arch() const { + if (_args.count('a') < 1) { + printf("No Arg!!!\n"); + fflush(stdout); } + return _args.at('a'); + }; + std::ostream& help(std::string app, std::ostream& out); + bool hasParam(int param) const { + return std::find(_opts.begin(), _opts.end(), param) != _opts.end(); + } + + friend std::ostream& operator<<(std::ostream& out, GraphOpt const& opts); + + void reset() { + _files.clear(); + _opts.clear(); + _args.clear(); + } -private: - FileList _files; - OptionList _opts; - ArgumentDict _args; + private: + FileList _files; + OptionList _opts; + ArgumentDict _args; }; #endif diff --git a/libnd4j/minifier/minifier.cpp b/libnd4j/minifier/minifier.cpp index 3d047211deac..4867678d66cf 100644 --- a/libnd4j/minifier/minifier.cpp +++ b/libnd4j/minifier/minifier.cpp @@ -17,140 +17,142 @@ #include #ifdef _WIN32 - #include +#include #else - #include +#include #endif -#include -#include "graphopt.h" +#include #include #include -#include - -using namespace sd::ops; -using namespace sd::graph; - -int -main(int argc, char *argv[]) { - // this string will contain list of operations - std::string opts_arg; - // this string will contain optional name for output binary file - std::string name_arg; +#include - // this string will contain binary compilation mode: shared/static/executable - std::string build_arg; +#include "graphopt.h" - // this string will contain target arch/optimization mode - std::string arch_arg; +using namespace sd::ops; +using namespace sd::graph; - GraphOpt opt; - int err = GraphOpt::optionsWithArgs(argc, argv, opt); - - //std::cout << opt << std::endl; - if (err > 0) { - // only help message - return err; +int main(int argc, char *argv[]) { + // this string will contain list of operations + std::string opts_arg; + + // this string will contain optional name for output binary file + std::string name_arg; + + // this string will contain binary compilation mode: shared/static/executable + std::string build_arg; + + // this string will contain target arch/optimization mode + std::string arch_arg; + + GraphOpt opt; + int err = GraphOpt::optionsWithArgs(argc, argv, opt); + + // std::cout << opt << std::endl; + if (err > 0) { + // only help message + return err; + } + + if (err < 0) { + std::cerr << "Wrong parameter list" << std::endl; + opt.help(argv[0], std::cerr); + return err; + } + + for (int option : opt.options()) { + std::cout << "Option \'" << (char)option << "\': "; + switch (option) { + case 'l': + std::cout << "Build library" << std::endl; + break; + case 'x': + std::cout << "Build executable" << std::endl; + break; + case 'e': + std::cout << "Link the Graph to executable as Resource" << std::endl; + break; + case 'o': + std::cout << "Output file name is " << opt.outputName() << std::endl; + break; + case 'a': + std::cout << "Target arch: " << opt.arch() << std::endl; + break; + default: + std::cerr << "Wrong parameter " << (char)option << std::endl; } + } - if (err < 0) { - std::cerr << "Wrong parameter list" << std::endl; - opt.help(argv[0], std::cerr); - return err; - } - - for (int option: opt.options()) { - std::cout << "Option \'" << (char)option <<"\': "; - switch (option) { - case 'l': - std::cout << "Build library" << std::endl; - break; - case 'x': - std::cout << "Build executable" << std::endl; - break; - case 'e': - std::cout << "Link the Graph to executable as Resource" << std::endl; - break; - case 'o': - std::cout << "Output file name is " << opt.outputName() << std::endl; - break; - case 'a': - std::cout << "Target arch: " << opt.arch() << std::endl; - break; - default: - std::cerr << "Wrong parameter " << (char)option << std::endl; - } - } - - if (!opt.hasParam('o')) { - std::cout << "Ouput file name is " << opt.outputName() << std::endl; - } + if (!opt.hasParam('o')) { + std::cout << "Ouput file name is " << opt.outputName() << std::endl; + } + + name_arg = " --name \'" + opt.outputName() + "\' "; - name_arg = " --name \'" + opt.outputName() + "\' "; + if (opt.hasParam('a')) arch_arg = opt.arch(); - if (opt.hasParam('a')) - arch_arg = opt.arch(); - - std::vector descriptors; - nd4j_printf("Total available operations: %i\n", OpRegistrator::getInstance()->numberOfOperations()); + std::vector descriptors; + nd4j_printf("Total available operations: %i\n", + OpRegistrator::getInstance()->numberOfOperations()); - for (auto file: opt.files()) { - // all files will be checked for accessibility & size + for (auto file : opt.files()) { + // all files will be checked for accessibility & size #ifdef _WIN32 - if (_access(file.c_str(), 1) != -1) { + if (_access(file.c_str(), 1) != -1) { #else - if (access(file.c_str(), F_OK | R_OK) != -1) { + if (access(file.c_str(), F_OK | R_OK) != -1) { #endif #ifdef _WIN32 - struct _stat st; - _stat(file.c_str(), &st); + struct _stat st; + _stat(file.c_str(), &st); #else - struct stat st; - stat(file.c_str(), &st); -#endif - if (st.st_size != 0) { - //std::cout << "File " << file << " exists and can be read" << std::endl; - auto graph = Graph::fromFlatBuffers(file.c_str()); - auto ops = graph->getOperations(); - - for (auto &v:ops) { - descriptors.emplace_back(v); - } - } else { - std::cerr << "File " << file << " exists, but has zero size" << std::endl; - return 2; - } - } - else { - std::cerr << "File " << file << " does not exists " << std::endl; - return 10; + struct stat st; + stat(file.c_str(), &st); +#endif + if (st.st_size != 0) { + // std::cout << "File " << file << " exists and can be read" << + // std::endl; + auto graph = Graph::fromFlatBuffers(file.c_str()); + auto ops = graph->getOperations(); + + for (auto &v : ops) { + descriptors.emplace_back(v); } + } else { + std::cerr << "File " << file << " exists, but has zero size" + << std::endl; + return 2; + } + } else { + std::cerr << "File " << file << " does not exists " << std::endl; + return 10; } + } - if (!descriptors.empty()) { - GraphUtils::filterOperations(descriptors); - - nd4j_printf("Operations found so far:\n",""); - for (auto &v: descriptors) { - nd4j_printf("%s\n", v.getOpName()->c_str()); - } + if (!descriptors.empty()) { + GraphUtils::filterOperations(descriptors); - // building list of operations - opts_arg = GraphUtils::makeCommandLine(descriptors); + nd4j_printf("Operations found so far:\n", ""); + for (auto &v : descriptors) { + nd4j_printf("%s\n", v.getOpName()->c_str()); } - nd4j_printf("\n",""); - std::string output(opt.outputName()); + // building list of operations + opts_arg = GraphUtils::makeCommandLine(descriptors); + } + nd4j_printf("\n", ""); - std::string input("../include/ops/declarable/CustomOperations.h"); + std::string output(opt.outputName()); - if (0 == GraphUtils::runPreprocessor(input.c_str(), output.c_str())) { - nd4j_printf("All done successfully.\n", ""); - } + std::string input("../include/ops/declarable/CustomOperations.h"); + + if (0 == GraphUtils::runPreprocessor(input.c_str(), output.c_str())) { + nd4j_printf("All done successfully.\n", ""); + } - //nd4j_printf("Command line: %s\n", cmdline.c_str()); - // FIXME: do this in cross-platform way - nd4j_printf("Building minified library...\n", ""); + // nd4j_printf("Command line: %s\n", cmdline.c_str()); + // FIXME: do this in cross-platform way + nd4j_printf("Building minified library...\n", ""); - return EXIT_SUCCESS; + return EXIT_SUCCESS; } diff --git a/libnd4j/server/GraphServer.cpp b/libnd4j/server/GraphServer.cpp index a9e8c3ddc853..2a34c3ccc37f 100644 --- a/libnd4j/server/GraphServer.cpp +++ b/libnd4j/server/GraphServer.cpp @@ -19,121 +19,132 @@ // #include "GraphServer.h" -#include + #include +#include +#include +#include +#include +#include #include #include + #include #include -#include -#include -#include -#include +namespace sd { +namespace graph { +grpc::Status GraphInferenceServerImpl::RegisterGraph( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg) { + auto flat_graph = request_msg->GetRoot(); + + try { + // building our graph + auto graph = new Graph(flat_graph); + + // single data type for now + GraphHolder::getInstance()->registerGraph(flat_graph->id(), graph); + + // sending out OK response + auto response_offset = CreateFlatResponse(mb_, 0); + mb_.Finish(response_offset); + *response_msg = mb_.ReleaseMessage(); + assert(response_msg->Verify()); + + return grpc::Status::OK; + } catch (std::runtime_error &e) { + grpc::string gmsg("Caught runtime_error exception"); + return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg); + } +} +grpc::Status GraphInferenceServerImpl::ReplaceGraph( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg) { + auto flat_graph = request_msg->GetRoot(); + + try { + // building our graph + auto graph = new Graph(flat_graph); + + // single data type for now + GraphHolder::getInstance()->replaceGraph(flat_graph->id(), graph); + + // sending out OK response + auto response_offset = CreateFlatResponse(mb_, 0); + mb_.Finish(response_offset); + *response_msg = mb_.ReleaseMessage(); + assert(response_msg->Verify()); + + return grpc::Status::OK; + } catch (sd::graph::unknown_graph_exception &e) { + grpc::string gmsg(e.message()); + return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg); + } catch (std::runtime_error &e) { + grpc::string gmsg("Caught runtime_error exception"); + return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg); + } +} +grpc::Status GraphInferenceServerImpl::ForgetGraph( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg) { + try { + // getting drop request + auto request = request_msg->GetRoot(); + + // dropping out graph (any datatype) + GraphHolder::getInstance()->dropGraphAny(request->id()); + + // sending out OK response + auto response_offset = CreateFlatResponse(mb_, 0); + mb_.Finish(response_offset); + *response_msg = mb_.ReleaseMessage(); + assert(response_msg->Verify()); + + return grpc::Status::OK; + } catch (sd::graph::unknown_graph_exception &e) { + grpc::string gmsg(e.message()); + return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg); + } +} -namespace sd { - namespace graph { - grpc::Status GraphInferenceServerImpl::RegisterGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg) { - auto flat_graph = request_msg->GetRoot(); - - try { - // building our graph - auto graph = new Graph(flat_graph); - - // single data type for now - GraphHolder::getInstance()->registerGraph(flat_graph->id(), graph); - - // sending out OK response - auto response_offset = CreateFlatResponse(mb_, 0); - mb_.Finish(response_offset); - *response_msg = mb_.ReleaseMessage(); - assert(response_msg->Verify()); - - return grpc::Status::OK; - } catch (std::runtime_error &e) { - grpc::string gmsg("Caught runtime_error exception"); - return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg); - } - } - - grpc::Status GraphInferenceServerImpl::ReplaceGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg) { - auto flat_graph = request_msg->GetRoot(); - - try { - // building our graph - auto graph = new Graph(flat_graph); - - // single data type for now - GraphHolder::getInstance()->replaceGraph(flat_graph->id(), graph); - - // sending out OK response - auto response_offset = CreateFlatResponse(mb_, 0); - mb_.Finish(response_offset); - *response_msg = mb_.ReleaseMessage(); - assert(response_msg->Verify()); - - return grpc::Status::OK; - } catch (sd::graph::unknown_graph_exception &e) { - grpc::string gmsg(e.message()); - return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg); - } catch (std::runtime_error &e) { - grpc::string gmsg("Caught runtime_error exception"); - return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg); - } - } - - grpc::Status GraphInferenceServerImpl::ForgetGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg) { - try { - - // getting drop request - auto request = request_msg->GetRoot(); - - // dropping out graph (any datatype) - GraphHolder::getInstance()->dropGraphAny(request->id()); - - // sending out OK response - auto response_offset = CreateFlatResponse(mb_, 0); - mb_.Finish(response_offset); - *response_msg = mb_.ReleaseMessage(); - assert(response_msg->Verify()); - - return grpc::Status::OK; - } catch (sd::graph::unknown_graph_exception &e) { - grpc::string gmsg(e.message()); - return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg); - } - } - - grpc::Status GraphInferenceServerImpl::InferenceRequest( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg) { - auto request = request_msg->GetRoot(); - - try { - // GraphHolder - auto response_offset = GraphHolder::getInstance()->execute(request->id(), mb_, request); - - mb_.Finish(response_offset); - *response_msg = mb_.ReleaseMessage(); - assert(response_msg->Verify()); - - return grpc::Status::OK; - } catch (sd::graph::no_results_exception &e) { - grpc::string gmsg(e.message()); - return grpc::Status(grpc::StatusCode::INTERNAL, gmsg); - } catch (sd::graph::unknown_graph_exception &e) { - grpc::string gmsg(e.message()); - return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg); - } catch (sd::graph::graph_execution_exception &e) { - grpc::string gmsg(e.message()); - return grpc::Status(grpc::StatusCode::INTERNAL, gmsg); - } catch (std::runtime_error &e) { - grpc::string gmsg("Caught runtime_error exception"); - return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg); - } - } - } +grpc::Status GraphInferenceServerImpl::InferenceRequest( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg) { + auto request = request_msg->GetRoot(); + + try { + // GraphHolder + auto response_offset = + GraphHolder::getInstance()->execute(request->id(), mb_, request); + + mb_.Finish(response_offset); + *response_msg = mb_.ReleaseMessage(); + assert(response_msg->Verify()); + + return grpc::Status::OK; + } catch (sd::graph::no_results_exception &e) { + grpc::string gmsg(e.message()); + return grpc::Status(grpc::StatusCode::INTERNAL, gmsg); + } catch (sd::graph::unknown_graph_exception &e) { + grpc::string gmsg(e.message()); + return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg); + } catch (sd::graph::graph_execution_exception &e) { + grpc::string gmsg(e.message()); + return grpc::Status(grpc::StatusCode::INTERNAL, gmsg); + } catch (std::runtime_error &e) { + grpc::string gmsg("Caught runtime_error exception"); + return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg); + } } +} // namespace graph +} // namespace sd void RunServer(int port) { assert(port > 0 && port < 65535); @@ -148,43 +159,44 @@ void RunServer(int port) { builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); builder.RegisterService(&service); std::unique_ptr server(builder.BuildAndStart()); - std::cerr << "Server listening on: [" << server_address << "]; Number of operations: [" << registrator->numberOfOperations() << "]"<< std::endl; + std::cerr << "Server listening on: [" << server_address + << "]; Number of operations: [" << registrator->numberOfOperations() + << "]" << std::endl; server->Wait(); } -char* getCmdOption(char **begin, char **end, const std::string & option) { - auto itr = std::find(begin, end, option); - if (itr != end && ++itr != end) - return *itr; +char *getCmdOption(char **begin, char **end, const std::string &option) { + auto itr = std::find(begin, end, option); + if (itr != end && ++itr != end) return *itr; - return 0; + return 0; } -bool cmdOptionExists(char** begin, char** end, const std::string& option) { - return std::find(begin, end, option) != end; +bool cmdOptionExists(char **begin, char **end, const std::string &option) { + return std::find(begin, end, option) != end; } int main(int argc, char *argv[]) { - /** - * basically we only care about few things here: - * 1) port number - * 2) if we should use gprc, json, or both - * 3) if there's any graph(s) provided at startup - */ - int port = 40123; - if(cmdOptionExists(argv, argv+argc, "-p")) { - auto sPort = getCmdOption(argv, argv + argc, "-p"); - port = atoi(sPort); - } - - if(cmdOptionExists(argv, argv+argc, "-f")) { - auto file = getCmdOption(argv, argv + argc, "-f"); - auto graph = GraphExecutioner::importFromFlatBuffers(file); - sd::graph::GraphHolder::getInstance()->registerGraph(0L, graph); - } - - RunServer(port); - - return 0; + /** + * basically we only care about few things here: + * 1) port number + * 2) if we should use gprc, json, or both + * 3) if there's any graph(s) provided at startup + */ + int port = 40123; + if (cmdOptionExists(argv, argv + argc, "-p")) { + auto sPort = getCmdOption(argv, argv + argc, "-p"); + port = atoi(sPort); + } + + if (cmdOptionExists(argv, argv + argc, "-f")) { + auto file = getCmdOption(argv, argv + argc, "-f"); + auto graph = GraphExecutioner::importFromFlatBuffers(file); + sd::graph::GraphHolder::getInstance()->registerGraph(0L, graph); + } + + RunServer(port); + + return 0; } \ No newline at end of file diff --git a/libnd4j/server/GraphServer.h b/libnd4j/server/GraphServer.h index 0dceacf25a33..4be077443e61 100644 --- a/libnd4j/server/GraphServer.h +++ b/libnd4j/server/GraphServer.h @@ -18,27 +18,38 @@ // @author raver119@gmail.com // - -#include #include #include -#include - #include +#include +#include namespace sd { - namespace graph { - class GraphInferenceServerImpl final : public GraphInferenceServer::Service { - private: - flatbuffers::grpc::MessageBuilder mb_; - public: - virtual grpc::Status RegisterGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg); - - virtual grpc::Status ForgetGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg); - - virtual grpc::Status ReplaceGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg); - - virtual grpc::Status InferenceRequest( grpc::ServerContext *context, const flatbuffers::grpc::Message *request_msg, flatbuffers::grpc::Message *response_msg); - }; - } -} \ No newline at end of file +namespace graph { +class GraphInferenceServerImpl final : public GraphInferenceServer::Service { + private: + flatbuffers::grpc::MessageBuilder mb_; + + public: + virtual grpc::Status RegisterGraph( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg); + + virtual grpc::Status ForgetGraph( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg); + + virtual grpc::Status ReplaceGraph( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg); + + virtual grpc::Status InferenceRequest( + grpc::ServerContext *context, + const flatbuffers::grpc::Message *request_msg, + flatbuffers::grpc::Message *response_msg); +}; +} // namespace graph +} // namespace sd \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/AllTests.cpp b/libnd4j/tests_cpu/layers_tests/AllTests.cpp index 669a5b1d0bed..32c84a4704c9 100644 --- a/libnd4j/tests_cpu/layers_tests/AllTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/AllTests.cpp @@ -20,19 +20,19 @@ // #include "testlayers.h" /* +#include "ConvolutionTests.cpp" +#include "DeclarableOpsTests.cpp" #include "DenseLayerTests.cpp" +#include "FlatBuffersTests.cpp" +#include "GraphTests.cpp" +#include "HashUtilsTests.cpp" #include "NDArrayTests.cpp" +#include "SessionLocalTests.cpp" +#include "StashTests.cpp" +#include "TadTests.cpp" #include "VariableSpaceTests.cpp" #include "VariableTests.cpp" -#include "DeclarableOpsTests.cpp" -#include "HashUtilsTests.cpp" #include "WorkspaceTests.cpp" -#include "ConvolutionTests.cpp" -#include "TadTests.cpp" -#include "StashTests.cpp" -#include "SessionLocalTests.cpp" -#include "GraphTests.cpp" -#include "FlatBuffersTests.cpp" */ /////// @@ -40,6 +40,6 @@ // #include "ProtoBufTests.cpp" int main(int argc, char **argv) { - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp b/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp index 97893ca5bc8d..507d5f744ea2 100644 --- a/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ArrayOptionsTests.cpp @@ -18,92 +18,89 @@ // Created by raver119 on 13.01.2018. // -#include "testlayers.h" #include #include -using namespace sd; +#include "testlayers.h" +using namespace sd; class ArrayOptionsTests : public testing::Test { -public: - Nd4jLong shape[8] = {2, 5, 5, 5, 1, 0, 1, 99}; + public: + Nd4jLong shape[8] = {2, 5, 5, 5, 1, 0, 1, 99}; }; TEST_F(ArrayOptionsTests, TestShape_Basic_0) { - shape[5] = 1; + shape[5] = 1; - ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); - ASSERT_FALSE(ArrayOptions::isSparseArray(shape)); + ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); + ASSERT_FALSE(ArrayOptions::isSparseArray(shape)); } - TEST_F(ArrayOptionsTests, TestShape_Basic_1) { - shape[5] = 2; - + shape[5] = 2; - ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); - ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); + ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); + ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_2) { - shape[5] = 258; + shape[5] = 258; - ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); + ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); - ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); - ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape)); + ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); + ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_3) { - ASSERT_EQ(0, shape::extra(shape)); + ASSERT_EQ(0, shape::extra(shape)); - ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape)); + ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_4) { + ArrayOptions::setPropertyBits(shape, {ARRAY_HALF, ARRAY_QUANTIZED}); - ArrayOptions::setPropertyBits(shape, {ARRAY_HALF, ARRAY_QUANTIZED}); - - auto dtype = ArrayOptions::dataType(shape); + auto dtype = ArrayOptions::dataType(shape); - ASSERT_FALSE(ArrayOptions::isSparseArray(shape)); - ASSERT_TRUE(sd::DataType::HALF == ArrayOptions::dataType(shape)); - ASSERT_EQ(sd::ArrayType::DENSE, ArrayOptions::arrayType(shape)); - ASSERT_EQ(sd::SpaceType::QUANTIZED, ArrayOptions::spaceType(shape)); + ASSERT_FALSE(ArrayOptions::isSparseArray(shape)); + ASSERT_TRUE(sd::DataType::HALF == ArrayOptions::dataType(shape)); + ASSERT_EQ(sd::ArrayType::DENSE, ArrayOptions::arrayType(shape)); + ASSERT_EQ(sd::SpaceType::QUANTIZED, ArrayOptions::spaceType(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_5) { - ArrayOptions::setPropertyBits(shape, {ARRAY_SPARSE, ARRAY_INT, ARRAY_CSC}); + ArrayOptions::setPropertyBits(shape, {ARRAY_SPARSE, ARRAY_INT, ARRAY_CSC}); - ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); - ASSERT_TRUE(sd::DataType::INT32 == ArrayOptions::dataType(shape)); - ASSERT_EQ(sd::SparseType::CSC, ArrayOptions::sparseType(shape)); + ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); + ASSERT_TRUE(sd::DataType::INT32 == ArrayOptions::dataType(shape)); + ASSERT_EQ(sd::SparseType::CSC, ArrayOptions::sparseType(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_6) { - ArrayOptions::setPropertyBits(shape, {ARRAY_EMPTY, ARRAY_INT, ARRAY_CSC}); + ArrayOptions::setPropertyBits(shape, {ARRAY_EMPTY, ARRAY_INT, ARRAY_CSC}); - ASSERT_EQ(sd::ArrayType::EMPTY, ArrayOptions::arrayType(shape)); + ASSERT_EQ(sd::ArrayType::EMPTY, ArrayOptions::arrayType(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_7) { - ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); - ASSERT_EQ(sd::DataType::FLOAT32, ArrayOptions::dataType(shape)); + ASSERT_EQ(sd::DataType::FLOAT32, ArrayOptions::dataType(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_8) { - ArrayOptions::setDataType(shape, sd::DataType::DOUBLE); - ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape, sd::DataType::DOUBLE); + ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); - ASSERT_EQ(sd::DataType::FLOAT32, ArrayOptions::dataType(shape)); + ASSERT_EQ(sd::DataType::FLOAT32, ArrayOptions::dataType(shape)); } TEST_F(ArrayOptionsTests, TestShape_Basic_9) { - ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape, sd::DataType::DOUBLE); + ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape, sd::DataType::DOUBLE); - ASSERT_EQ(sd::DataType::DOUBLE, ArrayOptions::dataType(shape)); + ASSERT_EQ(sd::DataType::DOUBLE, ArrayOptions::dataType(shape)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/AtomicTests.cu b/libnd4j/tests_cpu/layers_tests/AtomicTests.cu index f8248e7ea52f..1030edd88678 100644 --- a/libnd4j/tests_cpu/layers_tests/AtomicTests.cu +++ b/libnd4j/tests_cpu/layers_tests/AtomicTests.cu @@ -18,224 +18,253 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include +#include #include #include -#include +#include +#include +#include "testlayers.h" using namespace sd; - class AtomicTests : public testing::Test { -public: - AtomicTests() { - // - } + public: + AtomicTests() { + // + } }; template -static _CUDA_G void multiplyKernel(void *vbuffer, uint64_t length, void *vresult) { - auto buffer = reinterpret_cast(vbuffer); - auto result = reinterpret_cast(vresult); +static _CUDA_G void multiplyKernel(void *vbuffer, uint64_t length, + void *vresult) { + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); - auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { - auto rem = e % 4; - auto i = (e - rem) / 4; + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; - sd::math::atomics::nd4j_atomicMul(&result[i], buffer[e]); - } + sd::math::atomics::nd4j_atomicMul(&result[i], buffer[e]); + } } template static void multiplyLauncher(void *vbuffer, uint64_t length, void *vresult) { - multiplyKernel<<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); - auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream()); - if (err != 0) - throw sd::cuda_exception::build("multiply failed", err); + multiplyKernel<<<256, 256, 1024, + *sd::LaunchContext::defaultContext()->getCudaStream()>>>( + vbuffer, length, vresult); + auto err = cudaStreamSynchronize( + *sd::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) throw sd::cuda_exception::build("multiply failed", err); } template static _CUDA_G void sumKernel(void *vbuffer, uint64_t length, void *vresult) { - auto buffer = reinterpret_cast(vbuffer); - auto result = reinterpret_cast(vresult); + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); - auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { - auto rem = e % 4; - auto i = (e - rem) / 4; + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; - sd::math::atomics::nd4j_atomicAdd(&result[i], buffer[e]); - } + sd::math::atomics::nd4j_atomicAdd(&result[i], buffer[e]); + } } template static void sumLauncher(void *vbuffer, uint64_t length, void *vresult) { - sumKernel<<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); - auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream()); - if (err != 0) - throw sd::cuda_exception::build("sum failed", err); + sumKernel<<<256, 256, 1024, + *sd::LaunchContext::defaultContext()->getCudaStream()>>>( + vbuffer, length, vresult); + auto err = cudaStreamSynchronize( + *sd::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) throw sd::cuda_exception::build("sum failed", err); } template static _CUDA_G void subKernel(void *vbuffer, uint64_t length, void *vresult) { - auto buffer = reinterpret_cast(vbuffer); - auto result = reinterpret_cast(vresult); + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); - auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { - auto rem = e % 4; - auto i = (e - rem) / 4; + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; - sd::math::atomics::nd4j_atomicSub(&result[i], buffer[e]); - } + sd::math::atomics::nd4j_atomicSub(&result[i], buffer[e]); + } } template static void subLauncher(void *vbuffer, uint64_t length, void *vresult) { - subKernel<<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); - auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream()); - if (err != 0) - throw sd::cuda_exception::build("sub failed", err); + subKernel<<<256, 256, 1024, + *sd::LaunchContext::defaultContext()->getCudaStream()>>>( + vbuffer, length, vresult); + auto err = cudaStreamSynchronize( + *sd::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) throw sd::cuda_exception::build("sub failed", err); } template static _CUDA_G void divKernel(void *vbuffer, uint64_t length, void *vresult) { - auto buffer = reinterpret_cast(vbuffer); - auto result = reinterpret_cast(vresult); + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); - auto tid = blockIdx.x * blockDim.x + threadIdx.x; + auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { - auto rem = e % 4; - auto i = (e - rem) / 4; + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; - sd::math::atomics::nd4j_atomicDiv(&result[i], buffer[e]); - } + sd::math::atomics::nd4j_atomicDiv(&result[i], buffer[e]); + } } template static void divLauncher(void *vbuffer, uint64_t length, void *vresult) { - divKernel<<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); - auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream()); - if (err != 0) - throw sd::cuda_exception::build("div failed", err); + divKernel<<<256, 256, 1024, + *sd::LaunchContext::defaultContext()->getCudaStream()>>>( + vbuffer, length, vresult); + auto err = cudaStreamSynchronize( + *sd::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) throw sd::cuda_exception::build("div failed", err); } static void multiplyHost(NDArray &input, NDArray &output) { - BUILD_SINGLE_SELECTOR(input.dataType(), multiplyLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES); + BUILD_SINGLE_SELECTOR( + input.dataType(), multiplyLauncher, + (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), + NUMERIC_TYPES); } static void sumHost(NDArray &input, NDArray &output) { - BUILD_SINGLE_SELECTOR(input.dataType(), sumLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES); + BUILD_SINGLE_SELECTOR( + input.dataType(), sumLauncher, + (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), + NUMERIC_TYPES); } static void subHost(NDArray &input, NDArray &output) { - BUILD_SINGLE_SELECTOR(input.dataType(), subLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR( + input.dataType(), subLauncher, + (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), + FLOAT_TYPES); } static void divHost(NDArray &input, NDArray &output) { - BUILD_SINGLE_SELECTOR(input.dataType(), divLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR( + input.dataType(), divLauncher, + (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), + FLOAT_TYPES); } TEST_F(AtomicTests, test_multiply) { - std::vector dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::INT16, sd::DataType::HALF}; - - for (auto t:dtypes) { - nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); - NDArray input('c', {4, 25}, t); - NDArray output('c', {input.lengthOf() / 4}, t); - NDArray exp = output.ulike(); - - input.assign(2); - output.assign(2); - exp.assign(32); - - multiplyHost(input, output); - ASSERT_EQ(exp, output); - } + std::vector dtypes = {sd::DataType::FLOAT32, + sd::DataType::DOUBLE, sd::DataType::INT16, + sd::DataType::HALF}; + + for (auto t : dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(2); + output.assign(2); + exp.assign(32); + + multiplyHost(input, output); + ASSERT_EQ(exp, output); + } } TEST_F(AtomicTests, test_multiply_2) { - std::vector dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::HALF, sd::DataType::BFLOAT16}; - - for (auto t:dtypes) { - nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); - NDArray input('c', {4, 25}, t); - NDArray output('c', {input.lengthOf() / 4}, t); - NDArray exp = output.ulike(); - - input.assign(1.5); - output.assign(2); - exp.assign(10.125); - - multiplyHost(input, output); -// output.printBuffer("multiply 2"); - ASSERT_EQ(exp, output); - } + std::vector dtypes = {sd::DataType::FLOAT32, + sd::DataType::DOUBLE, sd::DataType::HALF, + sd::DataType::BFLOAT16}; + + for (auto t : dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(1.5); + output.assign(2); + exp.assign(10.125); + + multiplyHost(input, output); + // output.printBuffer("multiply 2"); + ASSERT_EQ(exp, output); + } } TEST_F(AtomicTests, test_sum) { - std::vector dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::BFLOAT16, sd::DataType::HALF, sd::DataType::INT16}; - - for (auto t:dtypes) { - nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); - NDArray input('c', {4, 25}, t); - NDArray output('c', {input.lengthOf() / 4}, t); - NDArray exp = output.ulike(); - - input.assign(1); - output.assign(1); - exp.assign(5); - - sumHost(input, output); -// output.printIndexedBuffer("Sum"); - ASSERT_EQ(exp, output); - } + std::vector dtypes = { + sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::BFLOAT16, + sd::DataType::HALF, sd::DataType::INT16}; + + for (auto t : dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(1); + output.assign(1); + exp.assign(5); + + sumHost(input, output); + // output.printIndexedBuffer("Sum"); + ASSERT_EQ(exp, output); + } } TEST_F(AtomicTests, test_sub) { - std::vector dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::HALF}; + std::vector dtypes = {sd::DataType::FLOAT32, + sd::DataType::DOUBLE, sd::DataType::HALF}; - for (auto t:dtypes) { - nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); - NDArray input('c', {4, 25}, t); - NDArray output('c', {input.lengthOf() / 4}, t); - NDArray exp = output.ulike(); + for (auto t : dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); - input.assign(1); - output.assign(5); - exp.assign(1); + input.assign(1); + output.assign(5); + exp.assign(1); - subHost(input, output); -// output.printBuffer("Sub"); + subHost(input, output); + // output.printBuffer("Sub"); - ASSERT_EQ(exp, output); - } + ASSERT_EQ(exp, output); + } } TEST_F(AtomicTests, test_div) { - std::vector dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::BFLOAT16, sd::DataType::HALF}; - - for (auto t:dtypes) { - nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); - NDArray input('c', {4, 25}, t); - NDArray output('c', {input.lengthOf() / 4}, t); - NDArray exp = output.ulike(); - - input.assign(2); - output.assign(32); - exp.assign(2); - - divHost(input, output); -// output.printBuffer("Div"); - ASSERT_EQ(exp, output); - } + std::vector dtypes = { + sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::BFLOAT16, + sd::DataType::HALF}; + + for (auto t : dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(2); + output.assign(32); + exp.assign(2); + + divHost(input, output); + // output.printBuffer("Div"); + ASSERT_EQ(exp, output); + } } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/AttentionTests.cpp b/libnd4j/tests_cpu/layers_tests/AttentionTests.cpp index ab6d50b53895..7b7066ed5e3b 100644 --- a/libnd4j/tests_cpu/layers_tests/AttentionTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/AttentionTests.cpp @@ -18,37 +18,37 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include #include +#include +#include +#include "testlayers.h" using namespace sd; - class AttentionTests : public testing::Test { -public: - AttentionTests() { - printf("\n"); - fflush(stdout); - } + public: + AttentionTests() { + printf("\n"); + fflush(stdout); + } }; TEST_F(AttentionTests, basic_dot_product_attention) { - auto keys = NDArrayFactory::create('c', {10, 4, 3}); - auto values = NDArrayFactory::create('c', {10, 4, 3}); - auto queries = NDArrayFactory::create('c', {10, 4, 1}); + auto keys = NDArrayFactory::create('c', {10, 4, 3}); + auto values = NDArrayFactory::create('c', {10, 4, 3}); + auto queries = NDArrayFactory::create('c', {10, 4, 1}); - sd::ops::dot_product_attention op; - auto result = op.evaluate({&queries, &keys, &values}, {1, 0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::dot_product_attention op; + auto result = op.evaluate({&queries, &keys, &values}, {1, 0}); + ASSERT_EQ(Status::OK(), result.status()); } /* -//Ignored: AB 2019/05/21 - Segmentation fault on on linux-ppc64le-cpu - https://github.com/deeplearning4j/deeplearning4j/issues/7657 +//Ignored: AB 2019/05/21 - Segmentation fault on on linux-ppc64le-cpu - +https://github.com/deeplearning4j/deeplearning4j/issues/7657 TEST_F(AttentionTests, basic_dot_product_attention_bp) { auto keys = NDArrayFactory::create('c', {10, 4, 3}); auto values = NDArrayFactory::create('c', {10, 4, 3}); @@ -64,25 +64,25 @@ TEST_F(AttentionTests, basic_dot_product_attention_bp) { */ TEST_F(AttentionTests, basic_dot_product_attention_with_weights) { - auto keys = NDArrayFactory::create('c', {10, 4, 3}); - auto values = NDArrayFactory::create('c', {10, 4, 3}); - auto queries = NDArrayFactory::create('c', {10, 4, 1}); + auto keys = NDArrayFactory::create('c', {10, 4, 3}); + auto values = NDArrayFactory::create('c', {10, 4, 3}); + auto queries = NDArrayFactory::create('c', {10, 4, 1}); - sd::ops::dot_product_attention op; - auto result = op.evaluate({&queries, &keys, &values}, {1, 1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::dot_product_attention op; + auto result = op.evaluate({&queries, &keys, &values}, {1, 1}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(AttentionTests, basic_dot_product_attention_with_mask) { - auto keys = NDArrayFactory::create('c', {10, 4, 3}); - auto values = NDArrayFactory::create('c', {10, 4, 3}); - auto queries = NDArrayFactory::create('c', {10, 4, 1}); - auto mask = NDArrayFactory::create('c', {10, 3}); - mask.assign(1.); - - sd::ops::dot_product_attention op; - auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0}); - ASSERT_EQ(Status::OK(), result.status()); + auto keys = NDArrayFactory::create('c', {10, 4, 3}); + auto values = NDArrayFactory::create('c', {10, 4, 3}); + auto queries = NDArrayFactory::create('c', {10, 4, 1}); + auto mask = NDArrayFactory::create('c', {10, 3}); + mask.assign(1.); + + sd::ops::dot_product_attention op; + auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0}); + ASSERT_EQ(Status::OK(), result.status()); } /* @@ -96,23 +96,23 @@ TEST_F(AttentionTests, basic_dot_product_attention_bp_with_mask) { mask.assign(1.); sd::ops::dot_product_attention_bp op; - auto result = op.execute({&queries, &keys, &values, &eps, &mask}, {}, {1, 0}, {}); - ASSERT_EQ(Status::OK(), result->status()); + auto result = op.execute({&queries, &keys, &values, &eps, &mask}, {}, {1, +0}, {}); ASSERT_EQ(Status::OK(), result->status()); delete result; } */ TEST_F(AttentionTests, multi_head_input_dot_product_attention_with_mask) { - auto keys = NDArrayFactory::create('c', {2, 5, 4, 3}); - auto values = NDArrayFactory::create('c', {2, 5, 4, 3}); - auto queries = NDArrayFactory::create('c', {2, 5, 4, 1}); - auto mask = NDArrayFactory::create('c', {2, 3}); - mask.assign(1.); - - sd::ops::dot_product_attention op; - auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0}); - ASSERT_EQ(Status::OK(), result.status()); + auto keys = NDArrayFactory::create('c', {2, 5, 4, 3}); + auto values = NDArrayFactory::create('c', {2, 5, 4, 3}); + auto queries = NDArrayFactory::create('c', {2, 5, 4, 1}); + auto mask = NDArrayFactory::create('c', {2, 3}); + mask.assign(1.); + + sd::ops::dot_product_attention op; + auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0}); + ASSERT_EQ(Status::OK(), result.status()); } /* @@ -126,35 +126,36 @@ TEST_F(AttentionTests, multi_head_input_dot_product_attention_bp_with_mask) { mask.assign(1.); sd::ops::dot_product_attention_bp op; - auto result = op.execute({&queries, &keys, &values, &eps, &mask}, {}, {1, 0}, {}); - ASSERT_EQ(Status::OK(), result->status()); + auto result = op.execute({&queries, &keys, &values, &eps, &mask}, {}, {1, +0}, {}); ASSERT_EQ(Status::OK(), result->status()); delete result; } */ - TEST_F(AttentionTests, basic_multi_head_dot_product_attention) { - auto keys = NDArrayFactory::create('c', {10, 4, 5}); - auto values = NDArrayFactory::create('c', {10, 4, 5}); - auto queries = NDArrayFactory::create('c', {10, 4, 2}); - - auto Wk = NDArrayFactory::create('c', {2, 3, 4}); - auto Wv = NDArrayFactory::create('c', {2, 3, 4}); - auto Wq = NDArrayFactory::create('c', {2, 3, 4}); - auto Wo = NDArrayFactory::create('c', {2* 3, 4}); - - sd::ops::multi_head_dot_product_attention op; - auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {1, 0}); - ASSERT_EQ(Status::OK(), result.status()); + auto keys = NDArrayFactory::create('c', {10, 4, 5}); + auto values = NDArrayFactory::create('c', {10, 4, 5}); + auto queries = NDArrayFactory::create('c', {10, 4, 2}); + + auto Wk = NDArrayFactory::create('c', {2, 3, 4}); + auto Wv = NDArrayFactory::create('c', {2, 3, 4}); + auto Wq = NDArrayFactory::create('c', {2, 3, 4}); + auto Wo = NDArrayFactory::create('c', {2 * 3, 4}); + + sd::ops::multi_head_dot_product_attention op; + auto result = + op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {1, 0}); + ASSERT_EQ(Status::OK(), result.status()); } /* -//AB 2019/05/30 - Other attention BP tests are segfaulting on ppc64le - disabling this pre-emptively - See issue #7657 -TEST_F(AttentionTests, basic_multi_head_dot_product_bp_attention) { - auto keys = NDArrayFactory::create('c', {10, 4, 5}); - auto values = NDArrayFactory::create('c', {10, 4, 5}); - auto queries = NDArrayFactory::create('c', {10, 4, 2}); +//AB 2019/05/30 - Other attention BP tests are segfaulting on ppc64le - +disabling this pre-emptively - See issue #7657 TEST_F(AttentionTests, +basic_multi_head_dot_product_bp_attention) { auto keys = +NDArrayFactory::create('c', {10, 4, 5}); auto values = +NDArrayFactory::create('c', {10, 4, 5}); auto queries = +NDArrayFactory::create('c', {10, 4, 2}); auto Wk = NDArrayFactory::create('c', {2, 3, 4}); auto Wv = NDArrayFactory::create('c', {2, 3, 4}); @@ -165,38 +166,39 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_bp_attention) { sd::ops::multi_head_dot_product_attention_bp op; - auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &eps}, {}, {1, 0}, {}); - ASSERT_EQ(Status::OK(), result->status()); + auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, +&eps}, {}, {1, 0}, {}); ASSERT_EQ(Status::OK(), result->status()); delete result; } */ TEST_F(AttentionTests, basic_multi_head_dot_product_attention_with_mask) { - auto keys = NDArrayFactory::create('c', {10, 4, 5}); - auto values = NDArrayFactory::create('c', {10, 4, 5}); - auto queries = NDArrayFactory::create('c', {10, 4, 2}); - - auto Wk = NDArrayFactory::create('c', {2, 3, 4}); - auto Wv = NDArrayFactory::create('c', {2, 3, 4}); - auto Wq = NDArrayFactory::create('c', {2, 3, 4}); - auto Wo = NDArrayFactory::create('c', {2* 3, 4}); - - auto mask = NDArrayFactory::create('c', {10, 5}); - mask.assign(1.); - - - sd::ops::multi_head_dot_product_attention op; - auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {1, 0}); - ASSERT_EQ(Status::OK(), result.status()); + auto keys = NDArrayFactory::create('c', {10, 4, 5}); + auto values = NDArrayFactory::create('c', {10, 4, 5}); + auto queries = NDArrayFactory::create('c', {10, 4, 2}); + + auto Wk = NDArrayFactory::create('c', {2, 3, 4}); + auto Wv = NDArrayFactory::create('c', {2, 3, 4}); + auto Wq = NDArrayFactory::create('c', {2, 3, 4}); + auto Wo = NDArrayFactory::create('c', {2 * 3, 4}); + + auto mask = NDArrayFactory::create('c', {10, 5}); + mask.assign(1.); + + sd::ops::multi_head_dot_product_attention op; + auto result = op.evaluate( + {&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {1, 0}); + ASSERT_EQ(Status::OK(), result.status()); } /* -//AB 2019/05/30 - Other attention BP tests are segfaulting on ppc64le - disabling this pre-emptively - See issue #7657 -TEST_F(AttentionTests, basic_multi_head_dot_product_bp_attention_with_mask) { - auto keys = NDArrayFactory::create('c', {10, 4, 5}); - auto values = NDArrayFactory::create('c', {10, 4, 5}); - auto queries = NDArrayFactory::create('c', {10, 4, 2}); +//AB 2019/05/30 - Other attention BP tests are segfaulting on ppc64le - +disabling this pre-emptively - See issue #7657 TEST_F(AttentionTests, +basic_multi_head_dot_product_bp_attention_with_mask) { auto keys = +NDArrayFactory::create('c', {10, 4, 5}); auto values = +NDArrayFactory::create('c', {10, 4, 5}); auto queries = +NDArrayFactory::create('c', {10, 4, 2}); auto Wk = NDArrayFactory::create('c', {2, 3, 4}); auto Wv = NDArrayFactory::create('c', {2, 3, 4}); @@ -210,8 +212,8 @@ TEST_F(AttentionTests, basic_multi_head_dot_product_bp_attention_with_mask) { sd::ops::multi_head_dot_product_attention_bp op; - auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &eps, &mask}, {}, {1, 0}, {}); - ASSERT_EQ(Status::OK(), result->status()); + auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, +&eps, &mask}, {}, {1, 0}, {}); ASSERT_EQ(Status::OK(), result->status()); delete result; } diff --git a/libnd4j/tests_cpu/layers_tests/BackpropTests.cpp b/libnd4j/tests_cpu/layers_tests/BackpropTests.cpp index 5c528f072597..d2bf722a49e4 100644 --- a/libnd4j/tests_cpu/layers_tests/BackpropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BackpropTests.cpp @@ -18,32 +18,31 @@ // Created by raver119 on 13.01.2018. // -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::ops; using namespace sd::graph; class BackpropTests : public testing::Test { -public: - + public: }; TEST_F(BackpropTests, Test_Add_1) { + NDArray x('c', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray y('c', {3, 4}, sd::DataType::FLOAT32); + NDArray e('c', {2, 3, 4}, sd::DataType::FLOAT32); - NDArray x('c', {2, 3, 4}, sd::DataType::FLOAT32); - NDArray y('c', {3, 4}, sd::DataType::FLOAT32); - NDArray e('c', {2, 3, 4}, sd::DataType::FLOAT32); - - sd::ops::add_bp op; - auto result = op.evaluate({&x, &y, &e}); + sd::ops::add_bp op; + auto result = op.evaluate({&x, &y, &e}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto eps = result.at(0); - auto grad = result.at(1); + auto eps = result.at(0); + auto grad = result.at(1); - ASSERT_TRUE(x.isSameShape(eps)); - ASSERT_TRUE(y.isSameShape(grad)); + ASSERT_TRUE(x.isSameShape(eps)); + ASSERT_TRUE(y.isSameShape(grad)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/BitwiseUtilsTests.cpp b/libnd4j/tests_cpu/layers_tests/BitwiseUtilsTests.cpp index 4174637e201a..c05f152b321c 100644 --- a/libnd4j/tests_cpu/layers_tests/BitwiseUtilsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BitwiseUtilsTests.cpp @@ -18,59 +18,57 @@ // Created by raver119 on 10.11.2017. // -#include "testlayers.h" -#include #include -#include #include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class BitwiseUtilsTests : public testing::Test { -public: - + public: }; // oviously, this test will fail on big-endian machines, but who cares TEST_F(BitwiseUtilsTests, Test_Runtime_Endianess_1) { - bool isBE = BitwiseUtils::isBE(); + bool isBE = BitwiseUtils::isBE(); - ASSERT_FALSE(isBE); + ASSERT_FALSE(isBE); } TEST_F(BitwiseUtilsTests, Test_ValueBit_1) { - int idx = BitwiseUtils::valueBit(1); + int idx = BitwiseUtils::valueBit(1); - ASSERT_EQ(0, idx); + ASSERT_EQ(0, idx); } TEST_F(BitwiseUtilsTests, Test_ValueBit_2) { - int idx = BitwiseUtils::valueBit(2); + int idx = BitwiseUtils::valueBit(2); - ASSERT_EQ(1, idx); + ASSERT_EQ(1, idx); } TEST_F(BitwiseUtilsTests, Test_ValueBits_1) { - std::vector expected({1, 1}); - while (expected.size() < 32) - expected.push_back(0); + std::vector expected({1, 1}); + while (expected.size() < 32) expected.push_back(0); - std::vector result = BitwiseUtils::valueBits(3); + std::vector result = BitwiseUtils::valueBits(3); - ASSERT_EQ(32, result.size()); - ASSERT_EQ(expected, result); + ASSERT_EQ(32, result.size()); + ASSERT_EQ(expected, result); } TEST_F(BitwiseUtilsTests, Test_ValueBits_2) { - int value = 48; - int flipped = BitwiseUtils::flip_bits(value); + int value = 48; + int flipped = BitwiseUtils::flip_bits(value); - ASSERT_NE(value, flipped); + ASSERT_NE(value, flipped); - auto o = BitwiseUtils::valueBits(value); - auto f = BitwiseUtils::valueBits(flipped); + auto o = BitwiseUtils::valueBits(value); + auto f = BitwiseUtils::valueBits(flipped); - for (int e = 0; e < o.size(); e++) - ASSERT_NE(o.at(e), f.at(e)); + for (int e = 0; e < o.size(); e++) ASSERT_NE(o.at(e), f.at(e)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp index 74d4cdb44b38..3bcb90dd8ee8 100644 --- a/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BooleanOpsTests.cpp @@ -18,131 +18,123 @@ // Created by raver119 on 13.10.2017. // -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::ops; using namespace sd::graph; class BooleanOpsTests : public testing::Test { -public: - + public: }; - TEST_F(BooleanOpsTests, LtTest_1) { - auto x = NDArrayFactory::create_(1.0f); - auto y = NDArrayFactory::create_(2.0f); - - sd::ops::lt_scalar op; + auto x = NDArrayFactory::create_(1.0f); + auto y = NDArrayFactory::create_(2.0f); + sd::ops::lt_scalar op; - ASSERT_TRUE(op.verify({x, y})); + ASSERT_TRUE(op.verify({x, y})); - delete x; - delete y; + delete x; + delete y; } TEST_F(BooleanOpsTests, LtTest_2) { - auto x = NDArrayFactory::create_(2.0f); - auto y = NDArrayFactory::create_(1.0f); + auto x = NDArrayFactory::create_(2.0f); + auto y = NDArrayFactory::create_(1.0f); - sd::ops::lt_scalar op; + sd::ops::lt_scalar op; + ASSERT_FALSE(op.verify({x, y})); - ASSERT_FALSE(op.verify({x, y})); - - delete x; - delete y; + delete x; + delete y; } TEST_F(BooleanOpsTests, Is_non_decreasing_1) { - auto x = NDArrayFactory::create('c', {2 , 2}, {1, 2, 4, 4}); - - sd::ops::is_non_decreasing op; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 4, 4}); - ASSERT_TRUE(op.verify({&x})); + sd::ops::is_non_decreasing op; + ASSERT_TRUE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_non_decreasing_2) { - auto x = NDArrayFactory::create('c', {2 , 2}, {1, 2, 4, 3}); - - sd::ops::is_non_decreasing op; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 4, 3}); - ASSERT_FALSE(op.verify({&x})); + sd::ops::is_non_decreasing op; + ASSERT_FALSE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_strictly_increasing_1) { - auto x = NDArrayFactory::create('c', {2 , 2}, {1, 2, 4, 5}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 4, 5}); - sd::ops::is_strictly_increasing op; - - ASSERT_TRUE(op.verify({&x})); + sd::ops::is_strictly_increasing op; + ASSERT_TRUE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_strictly_increasing_2) { - auto x = NDArrayFactory::create('c', {2 , 2}, {1, 2, 3, 3}); - - sd::ops::is_strictly_increasing op; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 3}); - ASSERT_FALSE(op.verify({&x})); + sd::ops::is_strictly_increasing op; + ASSERT_FALSE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_strictly_increasing_3) { - auto x = NDArrayFactory::create('c', {2 , 2}, {1, 2, 4, 3}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 4, 3}); - sd::ops::is_strictly_increasing op; + sd::ops::is_strictly_increasing op; - ASSERT_FALSE(op.verify({&x})); + ASSERT_FALSE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_strictly_increasing_5) { - auto x = NDArrayFactory::create('c', {64, 512}); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {64, 512}); + x.linspace(1.0); - sd::ops::is_strictly_increasing op; + sd::ops::is_strictly_increasing op; - ASSERT_TRUE(op.verify({&x})); + ASSERT_TRUE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_strictly_increasing_6) { - auto x = NDArrayFactory::create('c', {64, 512}); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {64, 512}); + x.linspace(1.0); - x.p(18, 1000323.f); + x.p(18, 1000323.f); - sd::ops::is_strictly_increasing op; + sd::ops::is_strictly_increasing op; - ASSERT_FALSE(op.verify({&x})); + ASSERT_FALSE(op.verify({&x})); } TEST_F(BooleanOpsTests, Is_numeric_tensor_1) { - auto x = NDArrayFactory::create('c', {2 , 2}, {1.f, 2.f, 4.f, 3.f}); + auto x = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 4.f, 3.f}); - sd::ops::is_numeric_tensor op; + sd::ops::is_numeric_tensor op; - ASSERT_TRUE(op.verify({&x})); + ASSERT_TRUE(op.verify({&x})); } TEST_F(BooleanOpsTests, test_where_1) { - auto x = NDArrayFactory::create('c', {6}, { 1, -3, 4, 8, -2, 5 }); - auto y = NDArrayFactory::create('c', {6}, { 2, -3, 1, 1, -2, 1 }); - auto e = NDArrayFactory::create('c', {3}, { 4, 8, 5 }); + auto x = NDArrayFactory::create('c', {6}, {1, -3, 4, 8, -2, 5}); + auto y = NDArrayFactory::create('c', {6}, {2, -3, 1, 1, -2, 1}); + auto e = NDArrayFactory::create('c', {3}, {4, 8, 5}); - sd::ops::choose op; + sd::ops::choose op; - auto result = op.evaluate({&x, &y}, {3}); - ASSERT_EQ(Status::OK(), result.status()); + auto result = op.evaluate({&x, &y}, {3}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - //z->printIndexedBuffer("z"); + // z->printIndexedBuffer("z"); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } - diff --git a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp index 439c5c5ac056..fa72c879d2e6 100644 --- a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp @@ -18,838 +18,799 @@ // Created by raver119 on 23.11.17. // - -#include "testlayers.h" #include #include #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class BroadcastableOpsTests : public testing::Test { -public: - + public: }; TEST_F(BroadcastableOpsTests, Test_Add_1) { + NDArray x('c', {5, 5}, sd::DataType::FLOAT32); + NDArray y('c', {1, 5}, sd::DataType::FLOAT32); + NDArray exp('c', {5, 5}, sd::DataType::FLOAT32); + x.linspace(1); + y.linspace(1); + exp.linspace(1); - NDArray x('c', {5, 5}, sd::DataType::FLOAT32); - NDArray y('c', {1, 5}, sd::DataType::FLOAT32); - NDArray exp('c', {5, 5}, sd::DataType::FLOAT32); - x.linspace(1); - y.linspace(1); - exp.linspace(1); - - //exp.printIndexedBuffer("E B"); - - exp.applyBroadcast(broadcast::Add, {1}, y, exp); + // exp.printIndexedBuffer("E B"); - sd::ops::add op; - auto result = op.evaluate({&x, &y}); + exp.applyBroadcast(broadcast::Add, {1}, y, exp); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::add op; + auto result = op.evaluate({&x, &y}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - //exp.printIndexedBuffer("E A"); - //z->printIndexedBuffer("Z"); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + // exp.printIndexedBuffer("E A"); + // z->printIndexedBuffer("Z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_Multiply_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {1, 5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - x.linspace(1); - y.linspace(1); - exp.linspace(1); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {1, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1); + y.linspace(1); + exp.linspace(1); - exp.applyBroadcast(broadcast::Multiply, {1}, y, exp); + exp.applyBroadcast(broadcast::Multiply, {1}, y, exp); - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}); + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_SquaredSubtract_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {1, 5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - x.linspace(1); - y.linspace(1); - exp.linspace(1); - - exp.applyBroadcast(broadcast::SquaredSubtract, {1}, y, exp); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {1, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1); + y.linspace(1); + exp.linspace(1); + exp.applyBroadcast(broadcast::SquaredSubtract, {1}, y, exp); - sd::ops::squaredsubtract op; - auto result = op.evaluate({&x, &y}); + sd::ops::squaredsubtract op; + auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_1) { - auto x = NDArrayFactory::create('c', {1, 1}, {1}); - auto y = NDArrayFactory::create('c', {1, 3}, {0, 1, 2}); - auto exp = NDArrayFactory::create('c', {1,3}, {1, 0, -1}); + auto x = NDArrayFactory::create('c', {1, 1}, {1}); + auto y = NDArrayFactory::create('c', {1, 3}, {0, 1, 2}); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 0, -1}); - sd::ops::subtract op; - auto result = op.evaluate({&x, &y}); + sd::ops::subtract op; + auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_2) { - auto x = NDArrayFactory::create('c', {1, 1}, {1}); - auto y = NDArrayFactory::create('c', {1, 3}, {0, 1, 2}); - auto exp = NDArrayFactory::create('c', {1,3}, {1, 2, 3}); + auto x = NDArrayFactory::create('c', {1, 1}, {1}); + auto y = NDArrayFactory::create('c', {1, 3}, {0, 1, 2}); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - sd::ops::add op; - auto result = op.evaluate({&x, &y}); + sd::ops::add op; + auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_Maximum_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 2, 3, 2}); - auto row = NDArrayFactory::create('c', {1, 3}, {2, 2, 2}); - auto exp = NDArrayFactory::create('c', {2, 3}, {2, 2, 2, 2, 3, 2}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 2, 3, 2}); + auto row = NDArrayFactory::create('c', {1, 3}, {2, 2, 2}); + auto exp = NDArrayFactory::create('c', {2, 3}, {2, 2, 2, 2, 3, 2}); - sd::ops::maximum op; - auto result = op.evaluate({&x, &row}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::maximum op; + auto result = op.evaluate({&x, &row}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_Minimum_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 2, 3, 2}); - auto col = NDArrayFactory::create('c', {2, 1}, {2, 1}); - auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 1, 1, 1}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 2, 3, 2}); + auto col = NDArrayFactory::create('c', {2, 1}, {2, 1}); + auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 1, 1, 1}); - sd::ops::minimum op; - auto result = op.evaluate({&x, &col}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::minimum op; + auto result = op.evaluate({&x, &col}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_Shape_1) { - sd::ops::minimum op; + sd::ops::minimum op; - Nd4jLong shapeX[] = {2, 2, 5, 5, 1, 8192, 1, 99}; - Nd4jLong shapeY[] = {2, 2, 5, 5, 1, 8192, 1, 99}; - ShapeList inputShape({shapeX, shapeY}); - VariableSpace vs; - Context ctx(1, &vs, false); + Nd4jLong shapeX[] = {2, 2, 5, 5, 1, 8192, 1, 99}; + Nd4jLong shapeY[] = {2, 2, 5, 5, 1, 8192, 1, 99}; + ShapeList inputShape({shapeX, shapeY}); + VariableSpace vs; + Context ctx(1, &vs, false); - auto shapes = op.calculateOutputShape(&inputShape, ctx); + auto shapes = op.calculateOutputShape(&inputShape, ctx); - auto shapeZ = shapes->at(0); - ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ)); + auto shapeZ = shapes->at(0); + ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ)); - delete shapes; + delete shapes; } TEST_F(BroadcastableOpsTests, Test_Shape_2) { - sd::ops::minimum op; + sd::ops::minimum op; - const Nd4jLong shapeX[] = {2, 1, 1, 1, 1, 8192, 1, 99}; - const Nd4jLong shapeY[] = {2, 2, 5, 5, 1, 8192, 1, 99}; - ShapeList inputShape({shapeX, shapeY}); - VariableSpace vs; - Context ctx(1, &vs, false); + const Nd4jLong shapeX[] = {2, 1, 1, 1, 1, 8192, 1, 99}; + const Nd4jLong shapeY[] = {2, 2, 5, 5, 1, 8192, 1, 99}; + ShapeList inputShape({shapeX, shapeY}); + VariableSpace vs; + Context ctx(1, &vs, false); - auto shapes = op.calculateOutputShape(&inputShape, ctx); + auto shapes = op.calculateOutputShape(&inputShape, ctx); - auto shapeZ = shapes->at(0); - ASSERT_TRUE(shape::shapeEquals(shapeY, shapeZ)); + auto shapeZ = shapes->at(0); + ASSERT_TRUE(shape::shapeEquals(shapeY, shapeZ)); - delete shapes; + delete shapes; } - TEST_F(BroadcastableOpsTests, Test_Shape_3) { - sd::ops::minimum op; + sd::ops::minimum op; - const Nd4jLong shapeX[] = {2, 5, 3, 1, 1, 8192, 1, 99}; - const Nd4jLong shapeY[] = {2, 1, 3, 3, 1, 8192, 1, 99}; - ShapeList inputShape({shapeX, shapeY}); - VariableSpace vs; - Context ctx(1, &vs, false); + const Nd4jLong shapeX[] = {2, 5, 3, 1, 1, 8192, 1, 99}; + const Nd4jLong shapeY[] = {2, 1, 3, 3, 1, 8192, 1, 99}; + ShapeList inputShape({shapeX, shapeY}); + VariableSpace vs; + Context ctx(1, &vs, false); - auto shapes = op.calculateOutputShape(&inputShape, ctx); + auto shapes = op.calculateOutputShape(&inputShape, ctx); - auto shapeZ = shapes->at(0); - ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ)); + auto shapeZ = shapes->at(0); + ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ)); - delete shapes; + delete shapes; } - TEST_F(BroadcastableOpsTests, Test_Shape_4) { - sd::ops::minimum op; + sd::ops::minimum op; - const Nd4jLong shapeX[] = {2, 5, 3, 1, 1, 8192, 1, 99}; - const Nd4jLong shapeY[] = {2, 5, 1, 1, 1, 8192, 1, 99}; - ShapeList inputShape({shapeX, shapeY}); - VariableSpace vs; - Context ctx(1, &vs, false); + const Nd4jLong shapeX[] = {2, 5, 3, 1, 1, 8192, 1, 99}; + const Nd4jLong shapeY[] = {2, 5, 1, 1, 1, 8192, 1, 99}; + ShapeList inputShape({shapeX, shapeY}); + VariableSpace vs; + Context ctx(1, &vs, false); - auto shapes = op.calculateOutputShape(&inputShape, ctx); + auto shapes = op.calculateOutputShape(&inputShape, ctx); - auto shapeZ = shapes->at(0); - ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ)); + auto shapeZ = shapes->at(0); + ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ)); - delete shapes; + delete shapes; } // (2,1,3) + (4,3) = (2,4,3) TEST_F(BroadcastableOpsTests, Test_Shape_5) { - sd::ops::minimum op; + sd::ops::minimum op; - const Nd4jLong shapeX[] = {3, 2, 1, 3, 3, 3, 1, 8192, 1, 99}; - const Nd4jLong shapeY[] = {2, 4, 3, 3, 1, 8192, 1, 99}; - const Nd4jLong shapeE[] = {3, 2, 4, 3, 12, 3, 1, 8192, 1, 99}; - ShapeList inputShape({shapeX, shapeY}); - VariableSpace vs; - Context ctx(1, &vs, false); + const Nd4jLong shapeX[] = {3, 2, 1, 3, 3, 3, 1, 8192, 1, 99}; + const Nd4jLong shapeY[] = {2, 4, 3, 3, 1, 8192, 1, 99}; + const Nd4jLong shapeE[] = {3, 2, 4, 3, 12, 3, 1, 8192, 1, 99}; + ShapeList inputShape({shapeX, shapeY}); + VariableSpace vs; + Context ctx(1, &vs, false); - auto shapes = op.calculateOutputShape(&inputShape, ctx); + auto shapes = op.calculateOutputShape(&inputShape, ctx); - auto shapeZ = shapes->at(0); - ASSERT_TRUE(shape::shapeEquals(shapeE, shapeZ)); + auto shapeZ = shapes->at(0); + ASSERT_TRUE(shape::shapeEquals(shapeE, shapeZ)); - delete shapes; + delete shapes; } TEST_F(BroadcastableOpsTests, Test_Scalar_Add_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create('c', {2, 2}, {3, 4, 5, 6}); - - sd::ops::add op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {2, 2}, {3, 4, 5, 6}); - auto z = result.at(0); + sd::ops::add op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, Test_Inplace_Output_1) { - auto x = NDArrayFactory::create('c', {2, 1, 3}); - auto y = NDArrayFactory::create('c', {4, 3}); - auto o = NDArrayFactory::create('c', {2, 4, 3}); - auto e = NDArrayFactory::create('c', {2, 4, 3}); - auto buffO1 = reinterpret_cast(o.buffer()); - y.assign(1.0f); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {2, 1, 3}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto o = NDArrayFactory::create('c', {2, 4, 3}); + auto e = NDArrayFactory::create('c', {2, 4, 3}); + auto buffO1 = reinterpret_cast(o.buffer()); + y.assign(1.0f); + e.assign(1.0f); - sd::ops::add op; - auto result = op.execute({&x, &y}, {&o}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); + sd::ops::add op; + auto result = op.execute({&x, &y}, {&o}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); - auto buffO2 = reinterpret_cast(o.buffer()); + auto buffO2 = reinterpret_cast(o.buffer()); - ASSERT_TRUE(e.isSameShape(o)); - ASSERT_TRUE(e.equalsTo(o)); + ASSERT_TRUE(e.isSameShape(o)); + ASSERT_TRUE(e.equalsTo(o)); - ASSERT_TRUE(buffO1 == buffO2); + ASSERT_TRUE(buffO1 == buffO2); } TEST_F(BroadcastableOpsTests, Test_Subtract_1) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); - auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); - - auto z = x - y; + auto z = x - y; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, Test_Subtract_2) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); - auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); - sd::ops::subtract op; - auto result = op.evaluate({&x, &y}); - auto z = result.at(0); + sd::ops::subtract op; + auto result = op.evaluate({&x, &y}); + auto z = result.at(0); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, Test_Subtract_3) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); - auto z = NDArrayFactory::create('c', {2}, {0.0f, 0.0f}); - auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto z = NDArrayFactory::create('c', {2}, {0.0f, 0.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); - sd::ops::subtract op; - auto result = op.execute({&x, &y}, {&z}, {}, {}, {}); + sd::ops::subtract op; + auto result = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_EQ(Status::OK(), result); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, Test_Subtract_4) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); - auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); - auto z = x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y); + auto z = x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Subtract_5) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); - auto e = NDArrayFactory::create('c', {2}, {-1., 0.}); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {-1., 0.}); - auto z = y - x; + auto z = y - x; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Subtract_6) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create(4.f); - auto e = NDArrayFactory::create(3.f); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(3.f); - auto z = y - x; + auto z = y - x; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Subtract_7) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create(4.f); - auto e = NDArrayFactory::create(-3.f); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(-3.f); - auto z = x - y; + auto z = x - y; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Add_2) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.f, 2.f}); - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); - auto e = NDArrayFactory::create('c', {2}, {1.f, 2.f}); - - auto z = x + y; + auto z = x + y; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Add_3) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.f, 2.f}); - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); - auto e = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + auto z = y + x; - auto z = y + x; - - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Add_4) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(5.f); - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create(4.f); - auto e = NDArrayFactory::create(5.f); - - auto z = x + y; + auto z = x + y; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Add_5) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(5.f); - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create(4.f); - auto e = NDArrayFactory::create(5.f); - - auto z = y + x; + auto z = y + x; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Multiply_2) { + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create('c', {2}, {3.f, 4.f}); + auto e = NDArrayFactory::create('c', {2}, {6.f, 8.f}); - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create('c', {2}, {3.f, 4.f}); - auto e = NDArrayFactory::create('c', {2}, {6.f, 8.f}); + auto z = y * x; - auto z = y * x; - - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Multiply_3) { + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create('c', {2}, {3.f, 4.f}); + auto e = NDArrayFactory::create('c', {2}, {6.f, 8.f}); - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create('c', {2}, {3.f, 4.f}); - auto e = NDArrayFactory::create('c', {2}, {6.f, 8.f}); + auto z = x * y; - auto z = x * y; - - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Multiply_4) { + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(8.f); - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create(4.f); - auto e = NDArrayFactory::create(8.f); - - auto z = y * x; + auto z = y * x; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, Test_Multiply_5) { + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(8.f); - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create(4.f); - auto e = NDArrayFactory::create(8.f); + auto z = x * y; - auto z = x * y; - - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, Test_Multiply_6) { - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create('c', {1}, {4.f}); - auto e = NDArrayFactory::create('c', {1}, {8.f}); + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create('c', {1}, {4.f}); + auto e = NDArrayFactory::create('c', {1}, {8.f}); - auto z = x * y; + auto z = x * y; - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, Test_Multiply_7) { - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create('c', {1}, {4.f}); - auto e = NDArrayFactory::create('c', {1}, {8.f}); - - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create('c', {1}, {4.f}); + auto e = NDArrayFactory::create('c', {1}, {8.f}); - auto z = result.at(0); + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(e.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, Test_Multiply_8) { - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create('c', {1, 1}, {4.f}); - auto e = NDArrayFactory::create('c', {1, 1}, {8.f}); + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create('c', {1, 1}, {4.f}); + auto e = NDArrayFactory::create('c', {1, 1}, {8.f}); - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, broadcast_add_1) { + NDArray x('c', {4}, {1, 1, 1, 1}); + NDArray y('c', {1, 4}, {1, 2, 3, 4}); + NDArray z('c', {1, 4}, sd::DataType::DOUBLE); + NDArray exp('c', {1, 4}, {2, 3, 4, 5}, sd::DataType::DOUBLE); - NDArray x('c', {4}, {1,1,1,1}); - NDArray y('c', {1,4}, {1,2,3,4}); - NDArray z('c', {1,4}, sd::DataType::DOUBLE); - NDArray exp('c', {1,4}, {2,3,4,5}, sd::DataType::DOUBLE); - - sd::ops::add op; - auto status = op.execute({&x, &y}, {&z}); + sd::ops::add op; + auto status = op.execute({&x, &y}, {&z}); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(z.equalsTo(exp)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(z.equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, broadcast_equals_1) { + NDArray x('c', {1, 4}, {1, 2, 3, 4}); + NDArray y('c', {3, 4}, {0, 0, 0, 0, 1, 2, 3, 4, 1, 2, 3, 4}); + NDArray z('c', {3, 4}, sd::DataType::BOOL); + NDArray exp('c', {3, 4}, {0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1}, + sd::DataType::BOOL); - NDArray x('c', {1,4}, {1,2,3,4}); - NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4}); - NDArray z('c', {3,4}, sd::DataType::BOOL); - NDArray exp('c', {3,4}, {0,0,0,0, 1,1,1,1, 1,1,1,1}, sd::DataType::BOOL); + sd::ops::equals op; + auto status = op.execute({&x, &y}, {&z}); + // z.printIndexedBuffer(); - sd::ops::equals op; - auto status = op.execute({&x, &y}, {&z}); - // z.printIndexedBuffer(); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(z.equalsTo(exp)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(z.equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(BroadcastableOpsTests, broadcast_empty_1) { + NDArray y('c', {3, 4}, {0, 0, 0, 0, 1, 2, 3, 4, 1, 2, 3, 4}); + NDArray x(sd::DataType::DOUBLE, y.getContext(), false); + NDArray z(sd::DataType::DOUBLE, y.getContext(), false); + NDArray zExp(sd::DataType::DOUBLE, y.getContext(), false); - NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4}); - NDArray x(sd::DataType::DOUBLE, y.getContext(), false); - NDArray z(sd::DataType::DOUBLE, y.getContext(), false); - NDArray zExp(sd::DataType::DOUBLE, y.getContext(), false); - - sd::ops::multiply op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + sd::ops::multiply op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(z.isSameShape(zExp)); - ASSERT_TRUE(z.equalsTo(zExp)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(z.isSameShape(zExp)); + ASSERT_TRUE(z.equalsTo(zExp)); } TEST_F(BroadcastableOpsTests, broadcast_empty_2) { + NDArray y('c', {1, 4}, {1, 2, 3, 4}); + NDArray x = NDArrayFactory::create('c', {0, 4}); + NDArray e = NDArrayFactory::create('c', {0, 4}); + ; - NDArray y('c', {1,4}, {1,2,3,4}); - NDArray x = NDArrayFactory::create('c', {0, 4}); - NDArray e = NDArrayFactory::create('c', {0, 4});; - - sd::ops::multiply op; - auto status = op.execute({&x, &y}, {&x}, {}, {}, {}); + sd::ops::multiply op; + auto status = op.execute({&x, &y}, {&x}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(e.isSameShape(x)); - ASSERT_TRUE(e.equalsTo(x)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(e.isSameShape(x)); + ASSERT_TRUE(e.equalsTo(x)); } TEST_F(BroadcastableOpsTests, broadcast_empty_3) { + NDArray x = NDArrayFactory::create('c', {1, 0, 2}); + NDArray y('c', {}, std::vector{0.1}, sd::DataType::FLOAT32); + NDArray e = NDArrayFactory::create('c', {1, 0, 2}); + ; - NDArray x = NDArrayFactory::create('c', {1, 0, 2}); - NDArray y('c', {}, std::vector{0.1}, sd::DataType::FLOAT32); - NDArray e = NDArrayFactory::create('c', {1, 0, 2});; + sd::ops::maximum op; + auto result = op.evaluate({&x, &y}); - sd::ops::maximum op; - auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, broadcast_empty_4) { + NDArray x = NDArrayFactory::create('c', {1, 0, 1}); + NDArray y = NDArrayFactory::create('c', {1, 0, 2}); + NDArray e = NDArrayFactory::create('c', {1, 0, 2}); + ; - NDArray x = NDArrayFactory::create('c', {1, 0, 1}); - NDArray y = NDArrayFactory::create('c', {1, 0, 2}); - NDArray e = NDArrayFactory::create('c', {1, 0, 2});; - - sd::ops::maximum op; - auto result = op.evaluate({&x, &y}); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::maximum op; + auto result = op.evaluate({&x, &y}); - auto z = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, broadcast_empty_5) { + NDArray x = NDArrayFactory::create('c', {1, 0, 1}); + NDArray y = NDArrayFactory::create('c', {1, 0, 2}); + NDArray e = NDArrayFactory::create('c', {1, 0, 2}); + ; - NDArray x = NDArrayFactory::create('c', {1, 0, 1}); - NDArray y = NDArrayFactory::create('c', {1, 0, 2}); - NDArray e = NDArrayFactory::create('c', {1, 0, 2});; + sd::ops::realdiv op; + auto result = op.evaluate({&x, &y}); - sd::ops::realdiv op; - auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, broadcast_empty_6) { + NDArray x = NDArrayFactory::create('c', {1, 0, 1}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {2, 2}); + NDArray e = NDArrayFactory::create('c', {1, 0, 2}); + ; - NDArray x = NDArrayFactory::create('c', {1, 0, 1}); - NDArray y = NDArrayFactory::create('c', {1, 2}, {2, 2}); - NDArray e = NDArrayFactory::create('c', {1, 0, 2});; - - sd::ops::realdiv op; - auto result = op.evaluate({&x, &y}); + sd::ops::realdiv op; + auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, broadcast_empty_7) { + NDArray x = NDArrayFactory::create('c', {1, 0, 2, 1}); + NDArray y = NDArrayFactory::create('c', {1, 2, 0}); + NDArray e = NDArrayFactory::create('c', {1, 0, 2, 0}); + ; - NDArray x = NDArrayFactory::create('c', {1, 0, 2, 1}); - NDArray y = NDArrayFactory::create('c', {1, 2, 0}); - NDArray e = NDArrayFactory::create('c', {1, 0, 2, 0});; - - sd::ops::realdiv op; - auto result = op.evaluate({&x, &y}); + sd::ops::realdiv op; + auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } - TEST_F(BroadcastableOpsTests, broadcast_bool_empty_1) { + NDArray y('c', {3, 4}, {0, 0, 0, 0, 1, 2, 3, 4, 1, 2, 3, 4}); + NDArray x(sd::DataType::DOUBLE, y.getContext(), false); + NDArray z(sd::DataType::BOOL, y.getContext(), false); + NDArray zExp(sd::DataType::BOOL, y.getContext(), false); - NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4}); - NDArray x(sd::DataType::DOUBLE, y.getContext(), false); - NDArray z(sd::DataType::BOOL, y.getContext(), false); - NDArray zExp(sd::DataType::BOOL, y.getContext(), false); + sd::ops::greater op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - sd::ops::greater op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(z.isSameShape(zExp)); - ASSERT_TRUE(z.equalsTo(zExp)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(z.isSameShape(zExp)); + ASSERT_TRUE(z.equalsTo(zExp)); } TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) { + NDArray y('c', {1, 4}, {1, 2, 3, 4}); + NDArray x = NDArrayFactory::create('c', {0, 4}); + NDArray e = NDArrayFactory::create('c', {0, 4}); + ; - NDArray y('c', {1,4}, {1,2,3,4}); - NDArray x = NDArrayFactory::create('c', {0, 4}); - NDArray e = NDArrayFactory::create('c', {0, 4});; - - - sd::ops::greater op; - auto result = op.evaluate({&x, &y}); + sd::ops::greater op; + auto result = op.evaluate({&x, &y}); - auto z = result.at(0); + auto z = result.at(0); - // z->printShapeInfo("z"); + // z->printShapeInfo("z"); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(BroadcastableOpsTests, broadcast_bool_1) { + NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2}, sd::DataType::FLOAT32); + NDArray z('c', {3, 2, 2}, sd::DataType::BOOL); + NDArray e('c', {3, 2, 2}, sd::DataType::BOOL); - NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); - NDArray y('c', {2, 2}, sd::DataType::FLOAT32); - NDArray z('c', {3, 2, 2}, sd::DataType::BOOL); - NDArray e('c', {3, 2, 2}, sd::DataType::BOOL); + x.assign(4.f); + y.assign(2.f); + e.assign(true); - x.assign(4.f); - y.assign(2.f); - e.assign(true); + sd::ops::greater op; - sd::ops::greater op; + auto status = op.execute({&x, &y}, {&z}); - auto status = op.execute({&x, &y}, {&z}); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_EQ(ND4J_STATUS_OK, status); + // z.printIndexedBuffer("Z"); - // z.printIndexedBuffer("Z"); - - ASSERT_TRUE(z.isSameShape(e)); - ASSERT_TRUE(z.equalsTo(e)); + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); } TEST_F(BroadcastableOpsTests, broadcast_bool_2) { + NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2}, sd::DataType::FLOAT32); + NDArray z('c', {3, 2, 2}, sd::DataType::BOOL); + NDArray e('c', {3, 2, 2}, sd::DataType::BOOL); - NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); - NDArray y('c', {2, 2}, sd::DataType::FLOAT32); - NDArray z('c', {3, 2, 2}, sd::DataType::BOOL); - NDArray e('c', {3, 2, 2}, sd::DataType::BOOL); - - x.assign(1.f); - y.assign(2.f); - e.assign(false); + x.assign(1.f); + y.assign(2.f); + e.assign(false); - sd::ops::equals op; + sd::ops::equals op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - // z.printIndexedBuffer("Z"); + // z.printIndexedBuffer("Z"); - ASSERT_TRUE(z.isSameShape(e)); - ASSERT_TRUE(z.equalsTo(e)); + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); } TEST_F(BroadcastableOpsTests, broadcast_bool_3) { + auto x = NDArrayFactory::create(0); + auto y = NDArrayFactory::create('c', {3}, {2, 1, 2}); + NDArray z('c', {3}, sd::DataType::BOOL); + NDArray e('c', {3}, sd::DataType::BOOL); - auto x = NDArrayFactory::create(0); - auto y = NDArrayFactory::create('c', {3}, {2, 1, 2}); - NDArray z('c', {3}, sd::DataType::BOOL); - NDArray e('c', {3}, sd::DataType::BOOL); - - e.assign(true); + e.assign(true); - sd::ops::less op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + sd::ops::less op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - // z.printIndexedBuffer("Z"); + // z.printIndexedBuffer("Z"); - ASSERT_TRUE(z.isSameShape(e)); - ASSERT_TRUE(z.equalsTo(e)); + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); } TEST_F(BroadcastableOpsTests, broadcast_2) { - NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); - NDArray y('c', {2, 2}, sd::DataType::FLOAT32); - NDArray z('c', {3, 2, 2}, sd::DataType::FLOAT32); - NDArray e('c', {3, 2, 2}, sd::DataType::FLOAT32); + NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2}, sd::DataType::FLOAT32); + NDArray z('c', {3, 2, 2}, sd::DataType::FLOAT32); + NDArray e('c', {3, 2, 2}, sd::DataType::FLOAT32); - x = 4.f; - y = 2.f; - e = -2.f; + x = 4.f; + y = 2.f; + e = -2.f; - sd::ops::reversesubtract op; // z = y - x; + sd::ops::reversesubtract op; // z = y - x; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - // z.printIndexedBuffer("Z"); + // z.printIndexedBuffer("Z"); - ASSERT_TRUE(z.isSameShape(e)); - ASSERT_TRUE(z.equalsTo(e)); + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); } TEST_F(BroadcastableOpsTests, broadcast_3) { - auto x = NDArrayFactory::create(0); - auto y = NDArrayFactory::create('c', {3}, {2, 1, 2}); - NDArray z('c', {3}, sd::DataType::INT32); - auto e = NDArrayFactory::create('c', {3}, {2, 1, 2}); + auto x = NDArrayFactory::create(0); + auto y = NDArrayFactory::create('c', {3}, {2, 1, 2}); + NDArray z('c', {3}, sd::DataType::INT32); + auto e = NDArrayFactory::create('c', {3}, {2, 1, 2}); - sd::ops::add op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + sd::ops::add op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - // z.printIndexedBuffer("Z"); + // z.printIndexedBuffer("Z"); - ASSERT_TRUE(z.isSameShape(e)); - ASSERT_TRUE(z.equalsTo(e)); + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); } TEST_F(BroadcastableOpsTests, test_bert_multiply_1) { - auto x = NDArrayFactory::create('c', {4, 128, 1}); - auto y = NDArrayFactory::create('c', {4, 1, 128}); - auto z = NDArrayFactory::create('c', {4, 128, 128}); - auto e = NDArrayFactory::create('c', {4, 128, 128}); + auto x = NDArrayFactory::create('c', {4, 128, 1}); + auto y = NDArrayFactory::create('c', {4, 1, 128}); + auto z = NDArrayFactory::create('c', {4, 128, 128}); + auto e = NDArrayFactory::create('c', {4, 128, 128}); - x.assign(0.f); - y.assign(1.f); - z.assign(119.f); - e.assign(0.f); -/* - Context ctx(1); - ctx.setInputArray(0, &x); - ctx.setInputArray(1, &y); - ctx.setOutputArray(0, &z); + x.assign(0.f); + y.assign(1.f); + z.assign(119.f); + e.assign(0.f); + /* + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setInputArray(1, &y); + ctx.setOutputArray(0, &z); - sd::ops::multiply op; - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + sd::ops::multiply op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); - z.printIndexedBuffer(); -*/ + z.printIndexedBuffer(); + */ - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); - //z.printIndexedBuffer(); + // z.printIndexedBuffer(); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(BroadcastableOpsTests, test_bert_multiply_2) { - auto x = NDArrayFactory::create('c', {4, 128, 1}); - auto y = NDArrayFactory::create('c', {768}); - auto z = NDArrayFactory::create('c', {4, 128, 768}); - auto e = NDArrayFactory::create('c', {4, 128, 768}); + auto x = NDArrayFactory::create('c', {4, 128, 1}); + auto y = NDArrayFactory::create('c', {768}); + auto z = NDArrayFactory::create('c', {4, 128, 768}); + auto e = NDArrayFactory::create('c', {4, 128, 768}); - x.assign(1.f); - y.assign(2.f); - z.assign(119.f); - e.assign(2.f); + x.assign(1.f); + y.assign(2.f); + z.assign(119.f); + e.assign(2.f); - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } diff --git a/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp b/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp index 34d0132bb221..135fb0c4ba1a 100644 --- a/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp @@ -18,49 +18,56 @@ // Created by agibsonccc on 1/19/17. // -#include "testinclude.h" #include +#include "testinclude.h" + class BroadcastMultiDimTest : public testing::Test { -public: - int dimensions[2] = {0,2}; - Nd4jLong inputShapeBuffer[10] = {3,2,3,5,15,5,1,8192,1,99}; - float inputData[30] = {1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,16.0,17.0,18.0,19.0,20.0,21.0,22.0,23.0,24.0,25.0,26.0,27.0,28.0,29.0,30.0}; - float dataAssertion[30] = {1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,16.0,17.0,18.0,0.0,0.0,21.0,22.0,23.0,0.0,0.0,26.0,27.0,28.0,0.0,0.0}; - float result[30] = {0.0}; - float broadcastData[10] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0}; - Nd4jLong broadcastShapeInfo[8] = {2,2,5,5,1,8192,1,99}; - int opNum = 2; - int dimensionLength = 2; + public: + int dimensions[2] = {0, 2}; + Nd4jLong inputShapeBuffer[10] = {3, 2, 3, 5, 15, 5, 1, 8192, 1, 99}; + float inputData[30] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, + 25.0, 26.0, 27.0, 28.0, 29.0, 30.0}; + float dataAssertion[30] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 0.0, 0.0, 21.0, 22.0, 23.0, 0.0, + 0.0, 26.0, 27.0, 28.0, 0.0, 0.0}; + float result[30] = {0.0}; + float broadcastData[10] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0}; + Nd4jLong broadcastShapeInfo[8] = {2, 2, 5, 5, 1, 8192, 1, 99}; + int opNum = 2; + int dimensionLength = 2; }; #ifndef __CUDABLAS__ -TEST_F(BroadcastMultiDimTest,MultimDimTest) { - auto tad = new shape::TAD(); - tad->init(inputShapeBuffer,dimensions,dimensionLength); - tad->createTadOnlyShapeInfo(); - tad-> createOffsets(); - functions::broadcast::Broadcast::exec( - opNum, - inputData, //x - inputShapeBuffer, //xShapeInfo - broadcastData, //y - broadcastShapeInfo, //yShapeInfo - result, //result - inputShapeBuffer, //resultShapeInfo - dimensions, //dimension - dimensionLength, //dimensionLength - tad->tadOnlyShapeInfo, //tadShapeInfo - tad->tadOffsets, //tadOffset - tad->tadOnlyShapeInfo, //tadShapeInfoZ - tad->tadOffsets, sd::LoopKind::COMMON, 0, tad->numTads); //tadOffsetZ +TEST_F(BroadcastMultiDimTest, MultimDimTest) { + auto tad = new shape::TAD(); + tad->init(inputShapeBuffer, dimensions, dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + functions::broadcast::Broadcast::exec( + opNum, + inputData, // x + inputShapeBuffer, // xShapeInfo + broadcastData, // y + broadcastShapeInfo, // yShapeInfo + result, // result + inputShapeBuffer, // resultShapeInfo + dimensions, // dimension + dimensionLength, // dimensionLength + tad->tadOnlyShapeInfo, // tadShapeInfo + tad->tadOffsets, // tadOffset + tad->tadOnlyShapeInfo, // tadShapeInfoZ + tad->tadOffsets, sd::LoopKind::COMMON, 0, tad->numTads); // tadOffsetZ - for(int i = 0; i < 30; i++) { - ASSERT_EQ(dataAssertion[i],result[i]); - } + for (int i = 0; i < 30; i++) { + ASSERT_EQ(dataAssertion[i], result[i]); + } - delete tad; + delete tad; } #endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/CnpyTests.cpp b/libnd4j/tests_cpu/layers_tests/CnpyTests.cpp index ea8025592c82..256be6b110b8 100644 --- a/libnd4j/tests_cpu/layers_tests/CnpyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/CnpyTests.cpp @@ -18,53 +18,50 @@ // Created by agibsonccc on 3/30/17. // -#include "testinclude.h" -#include #include -class FileTest : public testing::Test { - -}; +#include -class LoadFromStringTest : public testing::Test { +#include "testinclude.h" -}; +class FileTest : public testing::Test {}; -class HeaderTest : public testing::Test { +class LoadFromStringTest : public testing::Test {}; -}; +class HeaderTest : public testing::Test {}; TEST_F(HeaderTest, test_dataTypes_1) { - std::string header("0NUMPY6789{'descr': '>f4"); + std::string header("0NUMPY6789{'descr': '>f4"); - - ASSERT_EQ(sd::DataType::FLOAT32, dataTypeFromNpyHeader(const_cast(header.data()))); + ASSERT_EQ(sd::DataType::FLOAT32, + dataTypeFromNpyHeader(const_cast(header.data()))); } TEST_F(HeaderTest, test_dataTypes_2) { - std::string header("0NUMPY6789{'descr': '>f8"); - + std::string header("0NUMPY6789{'descr': '>f8"); - ASSERT_EQ(sd::DataType::DOUBLE, dataTypeFromNpyHeader(const_cast(header.data()))); + ASSERT_EQ(sd::DataType::DOUBLE, + dataTypeFromNpyHeader(const_cast(header.data()))); } TEST_F(HeaderTest, test_dataTypes_3) { - std::string header("0NUMPY6789{'descr': '(header.data()))); + ASSERT_EQ(sd::DataType::INT32, + dataTypeFromNpyHeader(const_cast(header.data()))); } TEST_F(HeaderTest, test_dataTypes_4) { - std::string header("0NUMPY6789{'descr': '>u2"); - + std::string header("0NUMPY6789{'descr': '>u2"); - ASSERT_EQ(sd::DataType::UINT16, dataTypeFromNpyHeader(const_cast(header.data()))); + ASSERT_EQ(sd::DataType::UINT16, + dataTypeFromNpyHeader(const_cast(header.data()))); } /* TEST_F(FileTest,T) { - cnpy::NpyArray npy = cnpy::npyLoad(std::string("/home/agibsonccc/code/libnd4j/test.npy")); + cnpy::NpyArray npy = +cnpy::npyLoad(std::string("/home/agibsonccc/code/libnd4j/test.npy")); ASSERT_FALSE(npy.fortranOrder); ASSERT_EQ(2,npy.shape[0]); diff --git a/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp b/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp index 5b747ab5bea5..4f9b506ba888 100644 --- a/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp @@ -18,226 +18,237 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include -#include -#include #include +#include +#include #include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; using namespace sd::graph; class ConstantShapeHelperTests : public testing::Test { -public: - + public: }; class ConstantHelperTests : public testing::Test { -public: - + public: }; class ConstantTadHelperTests : public testing::Test { -public: - + public: }; TEST_F(ConstantShapeHelperTests, test_cachedAmount_1) { - auto ttlBefore = ConstantShapeHelper::getInstance()->totalCachedEntries(); + auto ttlBefore = ConstantShapeHelper::getInstance()->totalCachedEntries(); - auto arrayA = NDArrayFactory::create('c', {7, 11, 17, 23, 31, 43}); + auto arrayA = NDArrayFactory::create('c', {7, 11, 17, 23, 31, 43}); - auto ttlMiddle = ConstantShapeHelper::getInstance()->totalCachedEntries(); + auto ttlMiddle = ConstantShapeHelper::getInstance()->totalCachedEntries(); - auto arrayB = NDArrayFactory::create('c', {7, 11, 17, 23, 31, 43}); + auto arrayB = NDArrayFactory::create('c', {7, 11, 17, 23, 31, 43}); - auto ttlAfter = ConstantShapeHelper::getInstance()->totalCachedEntries(); + auto ttlAfter = ConstantShapeHelper::getInstance()->totalCachedEntries(); - ASSERT_TRUE(ttlBefore <= ttlMiddle); - ASSERT_EQ(ttlMiddle, ttlAfter); + ASSERT_TRUE(ttlBefore <= ttlMiddle); + ASSERT_EQ(ttlMiddle, ttlAfter); } TEST_F(ConstantTadHelperTests, test_cachedAmount_1) { - auto arrayA = NDArrayFactory::create('c', {7, 11, 17, 23, 31, 43}); - auto ttlBefore = ConstantTadHelper::getInstance()->totalCachedEntries(); + auto arrayA = NDArrayFactory::create('c', {7, 11, 17, 23, 31, 43}); + auto ttlBefore = ConstantTadHelper::getInstance()->totalCachedEntries(); - auto packAA = ConstantTadHelper::getInstance()->tadForDimensions(arrayA.shapeInfo(), {3, 4}); + auto packAA = ConstantTadHelper::getInstance()->tadForDimensions( + arrayA.shapeInfo(), {3, 4}); - auto ttlMiddle = ConstantTadHelper::getInstance()->totalCachedEntries(); + auto ttlMiddle = ConstantTadHelper::getInstance()->totalCachedEntries(); - auto packAB = ConstantTadHelper::getInstance()->tadForDimensions(arrayA.shapeInfo(), {3, 4}); + auto packAB = ConstantTadHelper::getInstance()->tadForDimensions( + arrayA.shapeInfo(), {3, 4}); - auto ttlAfter = ConstantTadHelper::getInstance()->totalCachedEntries(); + auto ttlAfter = ConstantTadHelper::getInstance()->totalCachedEntries(); - ASSERT_TRUE(ttlBefore <= ttlMiddle); - ASSERT_EQ(ttlMiddle, ttlAfter); + ASSERT_TRUE(ttlBefore <= ttlMiddle); + ASSERT_EQ(ttlMiddle, ttlAfter); } TEST_F(ConstantShapeHelperTests, basic_test_1) { - auto ptr = ShapeBuilders::createShapeInfo(sd::DataType::BFLOAT16, 'f', {5, 10, 15}); - ShapeDescriptor descriptor(ptr); - ShapeDescriptor descriptor2(ptr); - - ASSERT_EQ(descriptor, descriptor2); + auto ptr = + ShapeBuilders::createShapeInfo(sd::DataType::BFLOAT16, 'f', {5, 10, 15}); + ShapeDescriptor descriptor(ptr); + ShapeDescriptor descriptor2(ptr); - ASSERT_EQ(1, descriptor.ews()); - ASSERT_EQ(3, descriptor.rank()); - ASSERT_EQ('f', descriptor.order()); - ASSERT_EQ(sd::DataType::BFLOAT16, descriptor.dataType()); - ASSERT_FALSE(descriptor.isEmpty()); + ASSERT_EQ(descriptor, descriptor2); - ASSERT_FALSE(ConstantShapeHelper::getInstance()->checkBufferExistenceForShapeInfo(descriptor)); + ASSERT_EQ(1, descriptor.ews()); + ASSERT_EQ(3, descriptor.rank()); + ASSERT_EQ('f', descriptor.order()); + ASSERT_EQ(sd::DataType::BFLOAT16, descriptor.dataType()); + ASSERT_FALSE(descriptor.isEmpty()); - auto buffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor); + ASSERT_FALSE( + ConstantShapeHelper::getInstance()->checkBufferExistenceForShapeInfo( + descriptor)); - ASSERT_TRUE(ConstantShapeHelper::getInstance()->checkBufferExistenceForShapeInfo(descriptor)); + auto buffer = + ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor); - auto buffer2 = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor2); + ASSERT_TRUE( + ConstantShapeHelper::getInstance()->checkBufferExistenceForShapeInfo( + descriptor)); + auto buffer2 = + ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor2); - ASSERT_TRUE(buffer.primary() != nullptr); - ASSERT_TRUE(buffer.primary() == buffer2.primary()); - ASSERT_TRUE(buffer.special() == buffer2.special()); + ASSERT_TRUE(buffer.primary() != nullptr); + ASSERT_TRUE(buffer.primary() == buffer2.primary()); + ASSERT_TRUE(buffer.special() == buffer2.special()); - delete []ptr; + delete[] ptr; } TEST_F(ConstantShapeHelperTests, stress_test_1) { - - for (auto x = 0; x < 1000; x++) { - auto ptr = ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', {5, x + 10, x + 1}); - ShapeDescriptor descriptor(ptr); - ConstantShapeHelper::getInstance()->createShapeInfo(descriptor); - delete [] ptr; - } - ShapeDescriptor aShape(sd::DataType::FLOAT32, 'c', {(Nd4jLong)5, (Nd4jLong)382, (Nd4jLong)373}); -// nd4j_printf("%d\n", ConstantShapeHelper::getInstance()->cachedEntriesForDevice(0)); - - auto timeStart = std::chrono::system_clock::now(); - ASSERT_TRUE(ConstantShapeHelper::getInstance()->checkBufferExistenceForShapeInfo(aShape)); - auto timeEnd = std::chrono::system_clock::now(); - - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - nd4j_printf("Total time (us) %lld\n", outerTime); + for (auto x = 0; x < 1000; x++) { + auto ptr = ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', + {5, x + 10, x + 1}); + ShapeDescriptor descriptor(ptr); + ConstantShapeHelper::getInstance()->createShapeInfo(descriptor); + delete[] ptr; + } + ShapeDescriptor aShape(sd::DataType::FLOAT32, 'c', + {(Nd4jLong)5, (Nd4jLong)382, (Nd4jLong)373}); + // nd4j_printf("%d\n", + // ConstantShapeHelper::getInstance()->cachedEntriesForDevice(0)); + + auto timeStart = std::chrono::system_clock::now(); + ASSERT_TRUE( + ConstantShapeHelper::getInstance()->checkBufferExistenceForShapeInfo( + aShape)); + auto timeEnd = std::chrono::system_clock::now(); + + auto outerTime = + std::chrono::duration_cast(timeEnd - timeStart) + .count(); + nd4j_printf("Total time (us) %lld\n", outerTime); } TEST_F(ConstantShapeHelperTests, basic_test_3) { - auto array = NDArrayFactory::create_('c', {128}); + auto array = NDArrayFactory::create_('c', {128}); - ASSERT_TRUE(array->shapeInfo() != nullptr); + ASSERT_TRUE(array->shapeInfo() != nullptr); #ifdef __CUDABLAS__ - ASSERT_TRUE(array->specialShapeInfo() != nullptr); + ASSERT_TRUE(array->specialShapeInfo() != nullptr); #endif - delete array; + delete array; } - TEST_F(ConstantShapeHelperTests, basic_test_4) { - auto array = NDArrayFactory::create_('c', {128, 256}); + auto array = NDArrayFactory::create_('c', {128, 256}); - auto dup = new NDArray(array->dup('f')); + auto dup = new NDArray(array->dup('f')); - ASSERT_TRUE(dup->shapeInfo() != nullptr); + ASSERT_TRUE(dup->shapeInfo() != nullptr); #ifdef __CUDABLAS__ - ASSERT_TRUE(dup->specialShapeInfo() != nullptr); - PointersManager manager(sd::LaunchContext ::defaultContext(), "test"); - // manager.printDevContentOnDev(dup->specialShapeInfo(), shape::shapeInfoLength(2), 0); + ASSERT_TRUE(dup->specialShapeInfo() != nullptr); + PointersManager manager(sd::LaunchContext ::defaultContext(), "test"); + // manager.printDevContentOnDev(dup->specialShapeInfo(), + // shape::shapeInfoLength(2), 0); #endif - delete array; - delete dup; + delete array; + delete dup; } - TEST_F(ConstantShapeHelperTests, basic_test_5) { + auto arrayA = NDArrayFactory::create(1); + auto arrayB = NDArrayFactory::create_('c', {128, 256}); - auto arrayA = NDArrayFactory::create(1); - auto arrayB = NDArrayFactory::create_('c', {128, 256}); - - //arrayA.printShapeInfo("A"); - //arrayB->printShapeInfo("B"); - ASSERT_EQ(0, arrayA.rankOf()); - ASSERT_EQ(2, arrayB->rankOf()); - ASSERT_NE(arrayA.dataType(), arrayB->dataType()); + // arrayA.printShapeInfo("A"); + // arrayB->printShapeInfo("B"); + ASSERT_EQ(0, arrayA.rankOf()); + ASSERT_EQ(2, arrayB->rankOf()); + ASSERT_NE(arrayA.dataType(), arrayB->dataType()); - delete arrayB; + delete arrayB; } TEST_F(ConstantShapeHelperTests, basic_test_6) { - ShapeDescriptor descriptorA(sd::DataType::INT32, 'c', {}); - ShapeDescriptor descriptorB(sd::DataType::FLOAT32, 'c', {10, 10}); + ShapeDescriptor descriptorA(sd::DataType::INT32, 'c', {}); + ShapeDescriptor descriptorB(sd::DataType::FLOAT32, 'c', {10, 10}); - // ASSERT_FALSE(descriptorA < descriptorB); - // ASSERT_TRUE(descriptorB < descriptorA); + // ASSERT_FALSE(descriptorA < descriptorB); + // ASSERT_TRUE(descriptorB < descriptorA); - ASSERT_TRUE(descriptorA < descriptorB); - ASSERT_FALSE(descriptorB < descriptorA); + ASSERT_TRUE(descriptorA < descriptorB); + ASSERT_FALSE(descriptorB < descriptorA); } TEST_F(ConstantShapeHelperTests, basic_test_7) { - auto array = NDArrayFactory::create_('c', {32, 256}); + auto array = NDArrayFactory::create_('c', {32, 256}); - IndicesList indices({NDIndex::all(), NDIndex::interval(0,1)}); - auto strided = array->subarray(indices); - strided.assign(1.0f); + IndicesList indices({NDIndex::all(), NDIndex::interval(0, 1)}); + auto strided = array->subarray(indices); + strided.assign(1.0f); - //strided->printIndexedBuffer("column"); + // strided->printIndexedBuffer("column"); - delete array; + delete array; } TEST_F(ConstantHelperTests, basic_test_1) { + ConstantDescriptor descriptor({1, 2, 3}); - ConstantDescriptor descriptor({1, 2, 3}); - - ConstantDataBuffer* fBuffer = ConstantHelper::getInstance()->constantBuffer(descriptor, sd::DataType::FLOAT32); - auto fPtr = fBuffer->primaryAsT(); + ConstantDataBuffer* fBuffer = ConstantHelper::getInstance()->constantBuffer( + descriptor, sd::DataType::FLOAT32); + auto fPtr = fBuffer->primaryAsT(); - ASSERT_NEAR(1.f, fPtr[0], 1e-5); - ASSERT_NEAR(2.f, fPtr[1], 1e-5); - ASSERT_NEAR(3.f, fPtr[2], 1e-5); + ASSERT_NEAR(1.f, fPtr[0], 1e-5); + ASSERT_NEAR(2.f, fPtr[1], 1e-5); + ASSERT_NEAR(3.f, fPtr[2], 1e-5); - auto iBuffer = ConstantHelper::getInstance()->constantBuffer(descriptor, sd::DataType::INT32); - auto iPtr = iBuffer->primaryAsT(); + auto iBuffer = ConstantHelper::getInstance()->constantBuffer( + descriptor, sd::DataType::INT32); + auto iPtr = iBuffer->primaryAsT(); - ASSERT_EQ(1, iPtr[0]); - ASSERT_EQ(2, iPtr[1]); - ASSERT_EQ(3, iPtr[2]); + ASSERT_EQ(1, iPtr[0]); + ASSERT_EQ(2, iPtr[1]); + ASSERT_EQ(3, iPtr[2]); } TEST_F(ConstantHelperTests, basic_test_2) { + double array[] = {1., 2., 3.}; + ConstantDescriptor descriptor(array, 3); - double array[] = {1., 2., 3.}; - ConstantDescriptor descriptor(array, 3); + ConstantDataBuffer* fBuffer = ConstantHelper::getInstance()->constantBuffer( + descriptor, sd::DataType::FLOAT32); + auto fPtr = fBuffer->primaryAsT(); - ConstantDataBuffer* fBuffer = ConstantHelper::getInstance()->constantBuffer(descriptor, sd::DataType::FLOAT32); - auto fPtr = fBuffer->primaryAsT(); + ASSERT_NEAR(1.f, fPtr[0], 1e-5); + ASSERT_NEAR(2.f, fPtr[1], 1e-5); + ASSERT_NEAR(3.f, fPtr[2], 1e-5); - ASSERT_NEAR(1.f, fPtr[0], 1e-5); - ASSERT_NEAR(2.f, fPtr[1], 1e-5); - ASSERT_NEAR(3.f, fPtr[2], 1e-5); + auto iBuffer = ConstantHelper::getInstance()->constantBuffer( + descriptor, sd::DataType::INT32); + auto iPtr = iBuffer->primaryAsT(); - auto iBuffer = ConstantHelper::getInstance()->constantBuffer(descriptor, sd::DataType::INT32); - auto iPtr = iBuffer->primaryAsT(); - - ASSERT_EQ(1, iPtr[0]); - ASSERT_EQ(2, iPtr[1]); - ASSERT_EQ(3, iPtr[2]); + ASSERT_EQ(1, iPtr[0]); + ASSERT_EQ(2, iPtr[1]); + ASSERT_EQ(3, iPtr[2]); } ////////////////////////////////////////////////////////////////////// TEST_F(ConstantShapeHelperTests, ShapeDescriptor_1) { + Nd4jLong shapeInfo1[] = {4, 2, 5, 5, 2, 25, 5, 1, 50, 8192, 0, 99}; + Nd4jLong shapeInfo2[] = {4, 2, 5, 5, 2, 50, 10, 2, 1, 8192, 1, 99}; - Nd4jLong shapeInfo1[] = {4, 2, 5, 5, 2, 25, 5, 1, 50, 8192, 0, 99}; - Nd4jLong shapeInfo2[] = {4, 2, 5, 5, 2, 50, 10, 2, 1, 8192, 1, 99}; - - ShapeDescriptor descr1(shapeInfo1); - ShapeDescriptor descr2(shapeInfo2); + ShapeDescriptor descr1(shapeInfo1); + ShapeDescriptor descr2(shapeInfo2); - ASSERT_FALSE(descr1 == descr2); + ASSERT_FALSE(descr1 == descr2); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp index 9176b4f77452..460de23abf23 100644 --- a/libnd4j/tests_cpu/layers_tests/ContextTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ContextTests.cpp @@ -18,343 +18,352 @@ // Created by raver119 on 30.10.2017. // -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::ops; using namespace sd::graph; class ContextTests : public testing::Test { -public: - + public: }; - TEST_F(ContextTests, Basic_Test_1) { - VariableSpace variableSpace; + VariableSpace variableSpace; - auto _20 = NDArrayFactory::create('c', {2, 2}); - auto _21 = NDArrayFactory::create('c', {2, 2}); + auto _20 = NDArrayFactory::create('c', {2, 2}); + auto _21 = NDArrayFactory::create('c', {2, 2}); - _20.assign(1.0f); - _21.assign(2.0f); + _20.assign(1.0f); + _21.assign(2.0f); - variableSpace.putVariable(2, 0, _20); - variableSpace.putVariable(2, 1, _21); + variableSpace.putVariable(2, 0, _20); + variableSpace.putVariable(2, 1, _21); - Context block(1, &variableSpace); + Context block(1, &variableSpace); - block.pickInput(2, 0); - block.pickInput(2, 1); + block.pickInput(2, 0); + block.pickInput(2, 1); - ASSERT_EQ(2, block.inputs().size()); - ASSERT_EQ(2, block.width()); + ASSERT_EQ(2, block.inputs().size()); + ASSERT_EQ(2, block.width()); - ASSERT_TRUE(variableSpace.hasVariable(2, 0)); - ASSERT_TRUE(variableSpace.hasVariable(2, 1)); + ASSERT_TRUE(variableSpace.hasVariable(2, 0)); + ASSERT_TRUE(variableSpace.hasVariable(2, 1)); - ASSERT_NEAR(1.0f, block.variable(0)->getNDArray()->meanNumber().e(0), 1e-5); - ASSERT_NEAR(2.0f, block.variable(1)->getNDArray()->meanNumber().e(0), 1e-5); + ASSERT_NEAR(1.0f, block.variable(0)->getNDArray()->meanNumber().e(0), + 1e-5); + ASSERT_NEAR(2.0f, block.variable(1)->getNDArray()->meanNumber().e(0), + 1e-5); } - TEST_F(ContextTests, Basic_Test_2) { - VariableSpace variableSpace; + VariableSpace variableSpace; - auto _20 = NDArrayFactory::create('c', {2, 2}); - auto _21 = NDArrayFactory::create('c', {2, 2}); + auto _20 = NDArrayFactory::create('c', {2, 2}); + auto _21 = NDArrayFactory::create('c', {2, 2}); - _20.assign(1.0f); - _21.assign(2.0f); + _20.assign(1.0f); + _21.assign(2.0f); - variableSpace.putVariable(-1, _20); - variableSpace.putVariable(-2, _21); + variableSpace.putVariable(-1, _20); + variableSpace.putVariable(-2, _21); - Context block(1, &variableSpace); + Context block(1, &variableSpace); - block.pickInput(-1); - block.pickInput(-2); + block.pickInput(-1); + block.pickInput(-2); - ASSERT_EQ(2, block.inputs().size()); - ASSERT_EQ(2, block.width()); + ASSERT_EQ(2, block.inputs().size()); + ASSERT_EQ(2, block.width()); - ASSERT_TRUE(variableSpace.hasVariable(-1)); - ASSERT_TRUE(variableSpace.hasVariable(-2)); + ASSERT_TRUE(variableSpace.hasVariable(-1)); + ASSERT_TRUE(variableSpace.hasVariable(-2)); - ASSERT_NEAR(1.0f, block.variable(0)->getNDArray()->meanNumber().e(0), 1e-5); - ASSERT_NEAR(2.0f, block.variable(1)->getNDArray()->meanNumber().e(0), 1e-5); + ASSERT_NEAR(1.0f, block.variable(0)->getNDArray()->meanNumber().e(0), + 1e-5); + ASSERT_NEAR(2.0f, block.variable(1)->getNDArray()->meanNumber().e(0), + 1e-5); } - TEST_F(ContextTests, Basic_Test_3) { - VariableSpace variableSpace; + VariableSpace variableSpace; - Context ctx(1, &variableSpace); + Context ctx(1, &variableSpace); - auto _20 = NDArrayFactory::create('c', {2, 2}); + auto _20 = NDArrayFactory::create('c', {2, 2}); - ctx.pushNDArrayToVariableSpace(1, 1, _20); + ctx.pushNDArrayToVariableSpace(1, 1, _20); - ASSERT_TRUE(variableSpace.hasVariable(1, 1)); + ASSERT_TRUE(variableSpace.hasVariable(1, 1)); } - TEST_F(ContextTests, Basic_Test_4) { - VariableSpace variableSpace; + VariableSpace variableSpace; - Context ctx(1, &variableSpace); + Context ctx(1, &variableSpace); - auto _20 = NDArrayFactory::create('c', {2, 2}); - _20.linspace(1); + auto _20 = NDArrayFactory::create('c', {2, 2}); + _20.linspace(1); - auto _21 = NDArrayFactory::create('c', {2, 2}); - _21.linspace(10); + auto _21 = NDArrayFactory::create('c', {2, 2}); + _21.linspace(10); - ctx.pushNDArrayToVariableSpace(1, 1, _20); + ctx.pushNDArrayToVariableSpace(1, 1, _20); - ASSERT_TRUE(variableSpace.hasVariable(1, 1)); + ASSERT_TRUE(variableSpace.hasVariable(1, 1)); - ctx.pushNDArrayToVariableSpace(1, 1, _21); + ctx.pushNDArrayToVariableSpace(1, 1, _21); - auto vA = ctx.variable(1, 1); + auto vA = ctx.variable(1, 1); - ASSERT_TRUE(vA->getNDArray()->equalsTo(_21)); + ASSERT_TRUE(vA->getNDArray()->equalsTo(_21)); } TEST_F(ContextTests, Basic_Test_5) { - VariableSpace variableSpace; + VariableSpace variableSpace; - Context ctx(1, &variableSpace); + Context ctx(1, &variableSpace); - auto _20 = NDArrayFactory::create('c', {2, 2}); - _20.linspace(1); + auto _20 = NDArrayFactory::create('c', {2, 2}); + _20.linspace(1); - auto exp = _20.dup(); + auto exp = _20.dup(); - ctx.pushNDArrayToVariableSpace(1, 1, _20); + ctx.pushNDArrayToVariableSpace(1, 1, _20); - ASSERT_TRUE(variableSpace.hasVariable(1, 1)); + ASSERT_TRUE(variableSpace.hasVariable(1, 1)); - ctx.pushNDArrayToVariableSpace(1, 1, _20); + ctx.pushNDArrayToVariableSpace(1, 1, _20); - auto vA = ctx.variable(1, 1); + auto vA = ctx.variable(1, 1); - ASSERT_TRUE(vA->getNDArray()->equalsTo(exp)); + ASSERT_TRUE(vA->getNDArray()->equalsTo(exp)); } - TEST_F(ContextTests, Basic_Test_6) { - VariableSpace variableSpace; + VariableSpace variableSpace; - Context ctx(1, &variableSpace); + Context ctx(1, &variableSpace); - auto v0 = ctx.ensureVariable("", 1, 0); - auto v1 = ctx.ensureVariable("", 1, 1); + auto v0 = ctx.ensureVariable("", 1, 0); + auto v1 = ctx.ensureVariable("", 1, 1); - ASSERT_TRUE(variableSpace.hasVariable(1, 0)); - ASSERT_TRUE(variableSpace.hasVariable(1, 1)); + ASSERT_TRUE(variableSpace.hasVariable(1, 0)); + ASSERT_TRUE(variableSpace.hasVariable(1, 1)); - auto var0 = variableSpace.getVariable(1, 0); - auto var1 = variableSpace.getVariable(1, 1); + auto var0 = variableSpace.getVariable(1, 0); + auto var1 = variableSpace.getVariable(1, 1); - ASSERT_TRUE(v0 == var0); - ASSERT_TRUE(v1 == var1); + ASSERT_TRUE(v0 == var0); + ASSERT_TRUE(v1 == var1); } - TEST_F(ContextTests, Basic_Test_7) { - VariableSpace variableSpace; - - Context ctx(1, &variableSpace); + VariableSpace variableSpace; - auto v0 = ctx.ensureVariable("", 1, 0); - auto v1 = ctx.ensureVariable("", 1, 1); + Context ctx(1, &variableSpace); - ASSERT_TRUE(variableSpace.hasVariable(1, 0)); - ASSERT_TRUE(variableSpace.hasVariable(1, 1)); + auto v0 = ctx.ensureVariable("", 1, 0); + auto v1 = ctx.ensureVariable("", 1, 1); - auto var0 = variableSpace.getVariable(1, 0); - auto var1 = variableSpace.getVariable(1, 1); + ASSERT_TRUE(variableSpace.hasVariable(1, 0)); + ASSERT_TRUE(variableSpace.hasVariable(1, 1)); - ASSERT_TRUE(v0 == var0); - ASSERT_TRUE(v1 == var1); + auto var0 = variableSpace.getVariable(1, 0); + auto var1 = variableSpace.getVariable(1, 1); + ASSERT_TRUE(v0 == var0); + ASSERT_TRUE(v1 == var1); - auto _10 = NDArrayFactory::create('c', {2, 2}); - _10.linspace(1); + auto _10 = NDArrayFactory::create('c', {2, 2}); + _10.linspace(1); - auto _11 = NDArrayFactory::create('c', {2, 2}); - _11.linspace(10); + auto _11 = NDArrayFactory::create('c', {2, 2}); + _11.linspace(10); - ctx.pushNDArrayToVariableSpace(1, 0, _10); - ctx.pushNDArrayToVariableSpace(1, 1, _11); + ctx.pushNDArrayToVariableSpace(1, 0, _10); + ctx.pushNDArrayToVariableSpace(1, 1, _11); - auto z0 = variableSpace.getVariable(1, 0); - auto z1 = variableSpace.getVariable(1, 1); + auto z0 = variableSpace.getVariable(1, 0); + auto z1 = variableSpace.getVariable(1, 1); - ASSERT_TRUE(v0 == z0); - ASSERT_TRUE(v1 == z1); + ASSERT_TRUE(v0 == z0); + ASSERT_TRUE(v1 == z1); } TEST_F(ContextTests, Basic_Test_8) { - VariableSpace variableSpace; + VariableSpace variableSpace; - Context ctx(1, &variableSpace); + Context ctx(1, &variableSpace); - auto _10 = NDArrayFactory::create('c', {2, 2}); - _10.linspace(1); + auto _10 = NDArrayFactory::create('c', {2, 2}); + _10.linspace(1); - auto _11 = NDArrayFactory::create('c', {2, 2}); - _11.linspace(10); + auto _11 = NDArrayFactory::create('c', {2, 2}); + _11.linspace(10); - ctx.pushNDArrayToVariableSpace(1, 0, _10); - ctx.pushNDArrayToVariableSpace(1, 1, _11); + ctx.pushNDArrayToVariableSpace(1, 0, _10); + ctx.pushNDArrayToVariableSpace(1, 1, _11); - auto z0 = variableSpace.getVariable(1, 0); - auto z1 = variableSpace.getVariable(1, 1); + auto z0 = variableSpace.getVariable(1, 0); + auto z1 = variableSpace.getVariable(1, 1); - auto v0 = ctx.ensureVariable("", 1, 0); - auto v1 = ctx.ensureVariable("", 1, 1); + auto v0 = ctx.ensureVariable("", 1, 0); + auto v1 = ctx.ensureVariable("", 1, 1); - ASSERT_TRUE(v0 == z0); - ASSERT_TRUE(v1 == z1); + ASSERT_TRUE(v0 == z0); + ASSERT_TRUE(v1 == z1); } - TEST_F(ContextTests, Basic_Test_9) { - VariableSpace variableSpace; + VariableSpace variableSpace; - auto in = NDArrayFactory::create('c', {5, 5}); + auto in = NDArrayFactory::create('c', {5, 5}); - Context ctx(1, &variableSpace, true); - ctx.pushNDArrayToVariableSpace(1, 1, in); + Context ctx(1, &variableSpace, true); + ctx.pushNDArrayToVariableSpace(1, 1, in); } TEST_F(ContextTests, Basic_Test_10) { - VariableSpace variableSpace; + VariableSpace variableSpace; - Context ctx(119, &variableSpace); + Context ctx(119, &variableSpace); } - TEST_F(ContextTests, Prototype_Test_1) { - ContextPrototype prototype(nullptr, 119, true); - prototype.pickInput(12, 3); - prototype.pickInput(12, 4); + ContextPrototype prototype(nullptr, 119, true); + prototype.pickInput(12, 3); + prototype.pickInput(12, 4); - prototype.appendT(2.0); - prototype.appendT(-2.0); + prototype.appendT(2.0); + prototype.appendT(-2.0); - prototype.appendI(17); - prototype.appendI(119); + prototype.appendI(17); + prototype.appendI(119); - Context ctx(prototype, nullptr); + Context ctx(prototype, nullptr); - ASSERT_EQ(ctx.nodeId(), prototype.nodeId()); - ASSERT_EQ(ctx.isInplace(), prototype.isInplace()); + ASSERT_EQ(ctx.nodeId(), prototype.nodeId()); + ASSERT_EQ(ctx.isInplace(), prototype.isInplace()); - ASSERT_EQ(2, ctx.inputs().size()); - ASSERT_EQ(2, ctx.getTArguments().size()); - ASSERT_EQ(2, ctx.getIArguments().size()); + ASSERT_EQ(2, ctx.inputs().size()); + ASSERT_EQ(2, ctx.getTArguments().size()); + ASSERT_EQ(2, ctx.getIArguments().size()); - ASSERT_EQ(2.0, ctx.getTArguments().at(0)); - ASSERT_EQ(-2.0, ctx.getTArguments().at(1)); + ASSERT_EQ(2.0, ctx.getTArguments().at(0)); + ASSERT_EQ(-2.0, ctx.getTArguments().at(1)); - ASSERT_EQ(17, ctx.getIArguments().at(0)); - ASSERT_EQ(119, ctx.getIArguments().at(1)); + ASSERT_EQ(17, ctx.getIArguments().at(0)); + ASSERT_EQ(119, ctx.getIArguments().at(1)); } - TEST_F(ContextTests, Prototype_Test_2) { - ContextPrototype prototype(nullptr, 119, false); - prototype.setOpNum(179); + ContextPrototype prototype(nullptr, 119, false); + prototype.setOpNum(179); - Context ctx(prototype, nullptr); + Context ctx(prototype, nullptr); - ASSERT_EQ(ctx.isInplace(), prototype.isInplace()); - ASSERT_EQ(ctx.opNum(), prototype.opNum()); + ASSERT_EQ(ctx.isInplace(), prototype.isInplace()); + ASSERT_EQ(ctx.opNum(), prototype.opNum()); - ASSERT_EQ(0, ctx.inputs().size()); - ASSERT_EQ(0, ctx.getTArguments().size()); - ASSERT_EQ(0, ctx.getIArguments().size()); + ASSERT_EQ(0, ctx.inputs().size()); + ASSERT_EQ(0, ctx.getTArguments().size()); + ASSERT_EQ(0, ctx.getIArguments().size()); } TEST_F(ContextTests, test_short_context_1) { - auto array0 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto array1 = NDArrayFactory::create('c', {3, 2}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}); - Context ctx(1); + auto array0 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array1 = NDArrayFactory::create( + 'c', {3, 2}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}); + Context ctx(1); - ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo()); - ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo()); + ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), + array0.specialBuffer(), array0.specialShapeInfo()); + ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), + array1.specialBuffer(), array1.specialShapeInfo()); - ASSERT_EQ(2, ctx.width()); + ASSERT_EQ(2, ctx.width()); - auto input0 = ctx.array(0); - ASSERT_TRUE(input0 != nullptr); + auto input0 = ctx.array(0); + ASSERT_TRUE(input0 != nullptr); - auto input1 = ctx.array(1); - ASSERT_TRUE(input1 != nullptr); + auto input1 = ctx.array(1); + ASSERT_TRUE(input1 != nullptr); - ASSERT_TRUE(input0->buffer() == array0.buffer()); - ASSERT_TRUE(input0->shapeInfo() == array0.shapeInfo()); + ASSERT_TRUE(input0->buffer() == array0.buffer()); + ASSERT_TRUE(input0->shapeInfo() == array0.shapeInfo()); - ASSERT_TRUE(input0->specialBuffer() == array0.specialBuffer()); - ASSERT_TRUE(input0->specialShapeInfo() == array0.specialShapeInfo()); + ASSERT_TRUE(input0->specialBuffer() == array0.specialBuffer()); + ASSERT_TRUE(input0->specialShapeInfo() == array0.specialShapeInfo()); - ASSERT_TRUE(input1->buffer() == array1.buffer()); - ASSERT_TRUE(input1->shapeInfo() == array1.shapeInfo()); + ASSERT_TRUE(input1->buffer() == array1.buffer()); + ASSERT_TRUE(input1->shapeInfo() == array1.shapeInfo()); - ASSERT_TRUE(input1->specialBuffer() == array1.specialBuffer()); - ASSERT_TRUE(input1->specialShapeInfo() == array1.specialShapeInfo()); + ASSERT_TRUE(input1->specialBuffer() == array1.specialBuffer()); + ASSERT_TRUE(input1->specialShapeInfo() == array1.specialShapeInfo()); } TEST_F(ContextTests, test_short_context_2) { - auto array0 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto array1 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {3, 2}); + auto array0 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array1 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {3, 2}); - auto exp = NDArrayFactory::create('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); - Context ctx(1); + auto exp = NDArrayFactory::create('c', {3, 2}, + {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); + Context ctx(1); - ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo()); - ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), + array0.specialBuffer(), array0.specialShapeInfo()); + ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), + array1.specialBuffer(), array1.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); - ASSERT_EQ(2, ctx.width()); + ASSERT_EQ(2, ctx.width()); - sd::ops::add op; - op.execute(&ctx); + sd::ops::add op; + op.execute(&ctx); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(ContextTests, test_short_context_3) { - auto array0 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto array1 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array0 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array1 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto exp = NDArrayFactory::create('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); - Context ctx(1); + auto exp = NDArrayFactory::create('c', {3, 2}, + {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); + Context ctx(1); - ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo()); - ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo()); + ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), + array0.specialBuffer(), array0.specialShapeInfo()); + ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), + array1.specialBuffer(), array1.specialShapeInfo()); - ASSERT_EQ(2, ctx.width()); + ASSERT_EQ(2, ctx.width()); - sd::ops::add op; - op.execute(&ctx); + sd::ops::add op; + op.execute(&ctx); - ASSERT_EQ(1, ctx.fastpath_out().size()); + ASSERT_EQ(1, ctx.fastpath_out().size()); - auto z = ctx.fastpath_out()[0]; + auto z = ctx.fastpath_out()[0]; - ASSERT_EQ(exp, *z); + ASSERT_EQ(exp, *z); } TEST_F(ContextTests, test_copy_1) { - ContextPrototype prototype(nullptr, 12); + ContextPrototype prototype(nullptr, 12); - auto copy = prototype; + auto copy = prototype; - ASSERT_EQ(prototype.nodeId(), copy.nodeId()); + ASSERT_EQ(prototype.nodeId(), copy.nodeId()); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index 0d7d6288d5e8..aa4582a511cf 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -21,17 +21,18 @@ #ifndef LIBND4J_CONVOLUTIONTESTS1_H #define LIBND4J_CONVOLUTIONTESTS1_H -#include "testlayers.h" #include #include #include #include #include +#include +#include #include -#include #include -#include -#include +#include + +#include "testlayers.h" #ifdef HAVE_MKLDNN #include @@ -41,14 +42,12 @@ using namespace sd; using namespace sd::graph; class ConvolutionTests1 : public testing::Test { -public: - + public: }; template class TypedConvolutionTests1 : public testing::Test { -public: - + public: }; typedef ::testing::Types TestingTypes; @@ -56,2803 +55,4840 @@ TYPED_TEST_CASE(TypedConvolutionTests1, TestingTypes); ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_1) { - - int bS=1, iH=5,iW=4, iC=2,oC=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - TypeParam _expB[]{664.0, 700.0, 736.0, 344.0, 808.0, 844.0, 880.0, 408.0, 952.0, 988.0, 1024.0, 472.0, 1096.0, 1132.0, 1168.0, 536.0, 466.0, 480.0, 494.0, 220.0, 1528.0, 1628.0, 1728.0, 856.0, 1928.0, 2028.0, 2128.0, 1048.0, 2328.0, 2428.0, 2528.0, 1240.0, 2728.0, 2828.0, 2928.0, 1432.0, 1346.0, 1392.0, 1438.0, 700.0, 2392.0, 2556.0, 2720.0, 1368.0, 3048.0, 3212.0, 3376.0, 1688.0, 3704.0, 3868.0, 4032.0, 2008.0, 4360.0, 4524.0, 4688.0, 2328.0, 2226.0, 2304.0, 2382.0, 1180.0}; - Nd4jLong _expS[]{4, 1, 3, 5, 4, 60, 20, 4, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); - for (int e = 0; e < input.lengthOf(); e++) - input.p(e, e + 1); - - for (int e = 0; e < weights.lengthOf(); e++) - weights.p(e, e + 1); - weights.permutei({2,3,1,0}); - - // weights->printShapeInfo("weights"); - - ArrayOptions::setDataType(_expS, input.dataType()); - auto exp = new NDArray(_expB, _expS); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, weights); - - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({-1, -2}); - // 5,5 kernel - block->appendI(kH); - block->appendI(kW); - - // 1,1 stride - block->appendI(sH); - block->appendI(sW); - - // 0,0 padding - block->appendI(pH); - block->appendI(pW); - - // 1,1 dilation - block->appendI(dH); - block->appendI(dW); - - // same mode - block->appendI(1); - - // is NHWC - block->appendI(0); - - sd::ops::conv2d op; - - Nd4jStatus status = op.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto res = variableSpace->getVariable(1)->getNDArray(); - - - // checking output shape - ASSERT_EQ(1, res->sizeAt(0)); - ASSERT_EQ(3, res->sizeAt(1)); - ASSERT_EQ(5, res->sizeAt(2)); - ASSERT_EQ(4, res->sizeAt(3)); - - // basically the same as above - ASSERT_TRUE(res->isSameShape(exp)); - // just for visual validation - // exp->printIndexedBuffer("Expected"); - // res->printIndexedBuffer("Actual "); - // res->printShapeInfo("Result shape"); - // final check - ASSERT_TRUE(res->equalsTo(exp)); - - delete block; - delete variableSpace; - delete exp; + int bS = 1, iH = 5, iW = 4, iC = 2, oC = 3, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + TypeParam _expB[]{ + 664.0, 700.0, 736.0, 344.0, 808.0, 844.0, 880.0, 408.0, 952.0, + 988.0, 1024.0, 472.0, 1096.0, 1132.0, 1168.0, 536.0, 466.0, 480.0, + 494.0, 220.0, 1528.0, 1628.0, 1728.0, 856.0, 1928.0, 2028.0, 2128.0, + 1048.0, 2328.0, 2428.0, 2528.0, 1240.0, 2728.0, 2828.0, 2928.0, 1432.0, + 1346.0, 1392.0, 1438.0, 700.0, 2392.0, 2556.0, 2720.0, 1368.0, 3048.0, + 3212.0, 3376.0, 1688.0, 3704.0, 3868.0, 4032.0, 2008.0, 4360.0, 4524.0, + 4688.0, 2328.0, 2226.0, 2304.0, 2382.0, 1180.0}; + Nd4jLong _expS[]{ + 4, 1, 3, 5, 4, + 60, 20, 4, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, + 1, 99}; + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); + for (int e = 0; e < input.lengthOf(); e++) input.p(e, e + 1); + + for (int e = 0; e < weights.lengthOf(); e++) weights.p(e, e + 1); + weights.permutei({2, 3, 1, 0}); + + // weights->printShapeInfo("weights"); + + ArrayOptions::setDataType(_expS, input.dataType()); + auto exp = new NDArray(_expB, _expS); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, weights); + + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1, -2}); + // 5,5 kernel + block->appendI(kH); + block->appendI(kW); + + // 1,1 stride + block->appendI(sH); + block->appendI(sW); + + // 0,0 padding + block->appendI(pH); + block->appendI(pW); + + // 1,1 dilation + block->appendI(dH); + block->appendI(dW); + + // same mode + block->appendI(1); + + // is NHWC + block->appendI(0); + + sd::ops::conv2d op; + + Nd4jStatus status = op.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto res = variableSpace->getVariable(1)->getNDArray(); + + // checking output shape + ASSERT_EQ(1, res->sizeAt(0)); + ASSERT_EQ(3, res->sizeAt(1)); + ASSERT_EQ(5, res->sizeAt(2)); + ASSERT_EQ(4, res->sizeAt(3)); + + // basically the same as above + ASSERT_TRUE(res->isSameShape(exp)); + // just for visual validation + // exp->printIndexedBuffer("Expected"); + // res->printIndexedBuffer("Actual "); + // res->printShapeInfo("Result shape"); + // final check + ASSERT_TRUE(res->equalsTo(exp)); + + delete block; + delete variableSpace; + delete exp; } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_2) { - auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f}); + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 4, 1, 4}, + {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, + 6.f, 8.f}); - weights.assign(2.0); - input.linspace(1); + weights.assign(2.0); + input.linspace(1); - sd::ops::conv2d op; - auto result = op.evaluate({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::conv2d op; + auto result = + op.evaluate({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_3) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}, {1.f, 2.f, 3.f}); - - - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{ 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, - 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, - 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, - 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); - input = 2.; - weights.linspace(0.1, 0.1); - - sd::ops::conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - - + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1.f, 2.f, 3.f}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, oH, oW, oC}, + {152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, + 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_4) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{ 170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f}); - - input = 2.; - weights.linspace(0.1, 0.1); - - sd::ops::conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, oH, oW, oC}, + {170.4f, 175.20001f, 180.f, 170.4f, 175.20001f, 180.f, + 170.4f, 175.20001f, 180.f, 170.4f, 175.20001f, 180.f, + 170.4f, 175.20001f, 180.f, 170.4f, 175.20001f, 180.f, + 170.4f, 175.20001f, 180.f, 170.4f, 175.20001f, 180.f}); + + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_5) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - - auto expOutput = NDArrayFactory::create('c', {bS, oC, oH, oW}, {61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f, 61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f}); - - input = 2.; - weights.linspace(0.1, 0.1); - weights.permutei({2,3,1,0}); - - sd::ops::conv2d op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - // output->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, oC, oH, oW}, + {61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, + 293.4f, 293.4f, 293.4f, 293.4f, 61.f, 61.f, 61.f, 61.f, + 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f}); + + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 1, 0}); + + sd::ops::conv2d op; + auto results = + op.evaluate({&input, &weights, &bias}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_6) { - auto input = NDArrayFactory::create('c', {54, 1, 12, 12}); - auto weights = NDArrayFactory::create('c', {1, 2, 12, 2}); + auto input = NDArrayFactory::create('c', {54, 1, 12, 12}); + auto weights = NDArrayFactory::create('c', {1, 2, 12, 2}); - sd::ops::conv2d op; - auto result = op.evaluate({&input, &weights}, {}, {-1,-1, 1,1, 0,0, 1,1, 1,1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::conv2d op; + auto result = + op.evaluate({&input, &weights}, {}, {-1, -1, 1, 1, 0, 0, 1, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_7) { + int bS = 1, iH = 256, iW = 256, iC = 1, oC = 1, kH = 4, kW = 3, sH = 1, + sW = 1, pH = 0, pW = 0, dH = 1, dW = 1; + // int oH=256,oW=256; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW - int bS=1, iH=256,iW=256, iC=1,oC=1, kH=4,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - // int oH=256,oW=256; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - - input = 5.; - weights = 3.; + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - sd::ops::conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); + input = 5.; + weights = 3.; + sd::ops::conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, conv2d_8) { - - int bS=1, iH=6,iW=8, iC=2,oC=2, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=6,oW=8; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, - 0.798564, 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, 0.328703, - 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, 0.318416, 0.068546, 0.284533, - 0.232720, 0.352142, 0.058909, 0.711221, 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, - 0.569819, 0.445863, 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, - 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, 0.737939, 0.490079, 0.608414, 0.956500, 0.390098}); - - NDArray weights('c', {kH, kW, iC, oC}, {0.07581716775894165, 0.8706002235412598, 0.29345420002937317, 0.5281786322593689, 0.10540834069252014, 0.3663792014122009, 0.17209206521511078, 0.6257694959640503}); - NDArray bias('c', {1, oC}, {0.7414038777351379, 0.8980839848518372}); - - NDArray expOutput('c', {bS, oC, oH, oW}, {1.112878, 1.106691, 0.914598, 1.127438, 0.988108, 1.070572, 1.040759, 0.962728, 0.927537, 1.109045, 0.893301, 1.101278, 1.080314, - 1.112327, 1.030041, 0.955914, 0.779137, 1.110499, 0.944709, 1.195986, 0.997814, 1.083822, 1.090898, 0.889572, 0.964781, 1.071012, 1.111928, 1.291319, 1.085454, 0.977661, - 1.149068, 1.077099, 1.068283, 1.064290, 1.177125, 1.212480, 0.932593, 0.939493, 1.118576, 1.056927, 0.780314, 0.845707, 0.996308, 0.963152, 0.906792, 0.937590, 1.048791, - 0.860346, 2.264212, 2.071576, 1.916629, 2.030785, 2.169075, 2.039786, 1.935480, 2.177816, 1.524273, 1.933327, 1.630923, 2.406983, 1.770406, 2.413284, 1.790349, 1.476586, - 1.179925, 1.909109, 2.009143, 2.299778, 1.957207, 1.779718, 2.480604, 1.529086, 1.748063, 1.952856, 2.029487, 2.699131, 1.879842, 1.471205, 2.150177, 2.039078, 1.933456, - 1.764169, 2.584944, 2.521004, 1.744296, 1.707578, 2.237938, 2.325231, 0.984485, 1.766936, 1.590640, 1.347524, 1.404648, 1.422042, 1.709862, 1.155412}); - - sd::ops::conv2d op; - auto results = op.evaluate({&input, &weights, &bias}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - // output->printBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + int bS = 1, iH = 6, iW = 8, iC = 2, oC = 2, kH = 2, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 6, oW = 8; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input( + 'c', {bS, iC, iH, iW}, + {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, + 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, 0.798564, + 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, + 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, + 0.328703, 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, + 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, + 0.318416, 0.068546, 0.284533, 0.232720, 0.352142, 0.058909, 0.711221, + 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, + 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, 0.569819, 0.445863, + 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, + 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, + 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, + 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, + 0.737939, 0.490079, 0.608414, 0.956500, 0.390098}); + + NDArray weights('c', {kH, kW, iC, oC}, + {0.07581716775894165, 0.8706002235412598, 0.29345420002937317, + 0.5281786322593689, 0.10540834069252014, 0.3663792014122009, + 0.17209206521511078, 0.6257694959640503}); + NDArray bias('c', {1, oC}, {0.7414038777351379, 0.8980839848518372}); + + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {1.112878, 1.106691, 0.914598, 1.127438, 0.988108, 1.070572, 1.040759, + 0.962728, 0.927537, 1.109045, 0.893301, 1.101278, 1.080314, 1.112327, + 1.030041, 0.955914, 0.779137, 1.110499, 0.944709, 1.195986, 0.997814, + 1.083822, 1.090898, 0.889572, 0.964781, 1.071012, 1.111928, 1.291319, + 1.085454, 0.977661, 1.149068, 1.077099, 1.068283, 1.064290, 1.177125, + 1.212480, 0.932593, 0.939493, 1.118576, 1.056927, 0.780314, 0.845707, + 0.996308, 0.963152, 0.906792, 0.937590, 1.048791, 0.860346, 2.264212, + 2.071576, 1.916629, 2.030785, 2.169075, 2.039786, 1.935480, 2.177816, + 1.524273, 1.933327, 1.630923, 2.406983, 1.770406, 2.413284, 1.790349, + 1.476586, 1.179925, 1.909109, 2.009143, 2.299778, 1.957207, 1.779718, + 2.480604, 1.529086, 1.748063, 1.952856, 2.029487, 2.699131, 1.879842, + 1.471205, 2.150177, 2.039078, 1.933456, 1.764169, 2.584944, 2.521004, + 1.744296, 1.707578, 2.237938, 2.325231, 0.984485, 1.766936, 1.590640, + 1.347524, 1.404648, 1.422042, 1.709862, 1.155412}); + + sd::ops::conv2d op; + auto results = + op.evaluate({&input, &weights, &bias}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + // output->printBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, conv2d_9) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, iC, kH, kW}, {-3., -1.8, -0.6, 0.6, 1.8, 3., -2.7, -1.5, -0.3, 0.9, 2.1, 3.3, -2.4, -1.2, 0., 1.2, 2.4, 3.6, -2.1, -0.9, 0.3, 1.5, - 2.7, 3.9, -2.9, -1.7, -0.5, 0.7, 1.9, 3.1, -2.6, -1.4, -0.2, 1., 2.2, 3.4, -2.3, -1.1, 0.1, 1.3, 2.5, 3.7, -2., -0.8, 0.4, 1.6, - 2.8, 4., -2.8, -1.6, -0.4, 0.8, 2., 3.2, -2.5, -1.3, -0.1, 1.1, 2.3, 3.5, -2.2, -1., 0.2, 1.4, 2.6, 3.8, -1.9, -0.7, 0.5, 1.7, 2.9, 4.1}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oC, oH, oW}, {37.699997, 32.300041, 21.499989, 16.100004, 74.900024, 68.300003, 55.100006, 48.499969, 107.599983, 99.799988, - 84.200005, 76.400009, -221.5, -226.899994, -237.699997, -243.099991, -241.899994, -248.5, -261.700012, -268.299988, - -266.799988, -274.600006, -290.200012, -298.}, sd::DataType::FLOAT32); - - input.linspace(25,-0.5); - - sd::ops::conv2d op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = + 1; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, iC, kH, kW}, + {-3., -1.8, -0.6, 0.6, 1.8, 3., -2.7, -1.5, -0.3, 0.9, 2.1, 3.3, + -2.4, -1.2, 0., 1.2, 2.4, 3.6, -2.1, -0.9, 0.3, 1.5, 2.7, 3.9, + -2.9, -1.7, -0.5, 0.7, 1.9, 3.1, -2.6, -1.4, -0.2, 1., 2.2, 3.4, + -2.3, -1.1, 0.1, 1.3, 2.5, 3.7, -2., -0.8, 0.4, 1.6, 2.8, 4., + -2.8, -1.6, -0.4, 0.8, 2., 3.2, -2.5, -1.3, -0.1, 1.1, 2.3, 3.5, + -2.2, -1., 0.2, 1.4, 2.6, 3.8, -1.9, -0.7, 0.5, 1.7, 2.9, 4.1}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1, 2, 0.5}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {37.699997, 32.300041, 21.499989, 16.100004, 74.900024, + 68.300003, 55.100006, 48.499969, 107.599983, 99.799988, + 84.200005, 76.400009, -221.5, -226.899994, -237.699997, + -243.099991, -241.899994, -248.5, -261.700012, -268.299988, + -266.799988, -274.600006, -290.200012, -298.}, + sd::DataType::FLOAT32); + + input.linspace(25, -0.5); + + sd::ops::conv2d op; + auto results = op.evaluate( + {&input, &weights, &bias}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, conv2d_10) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, kH, kW, iC}, {-3., -2.7, -2.4, -2.1, -1.8, -1.5, -1.2, -0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9, 1.2, 1.5, 1.8, 2.1, 2.4, 2.7, 3., 3.3, - 3.6, 3.9, -2.9, -2.6, -2.3, -2., -1.7, -1.4, -1.1, -0.8, -0.5, -0.2, 0.1, 0.4, 0.7, 1., 1.3, 1.6, 1.9, 2.2, 2.5, 2.8, - 3.1, 3.4, 3.7, 4., -2.8, -2.5, -2.2, -1.9, -1.6, -1.3, -1., -0.7, -0.4, -0.1, 0.2, 0.5, 0.8, 1.1, 1.4, 1.7, 2., 2.3, 2.6, - 2.9, 3.2, 3.5, 3.8, 4.1}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oH, oW, oC}, {463.400055, 498.800018, 529.700012, 410.600006, 442.799988, 470.500031, 113.600006, 130.400009, 142.699982, - -63.999958, -19.600082, 20.300007, -85.600052, -45.999939, -10.899940, -144.100021, -124., -108.399994, -128.799988, -98.799973, -73.300011, - -150.400009, -125.200012, -104.500008, -133.300003, -120.399994, -112.000008, -170.199997, -154., -142.299988, -146.200012, -133.199997, -124.699997, - -88.000008, -80.800003, -78.099991, -170.200012, -173.199997, -180.699982, -223., -229.199997, -239.900009, -88., -90.400002, -97.300003, -323.200012, - -336.399994, -354.100037, -344.800018, -362.799988, -385.299957, -100.900002, -109.600006, -122.800003, -388.000031, -415.599976, -447.700012, -409.599976, - -442., -478.900024, -90.099991, -105.999992, -126.399994, 117.800003, 95.599991, 68.899994, 141.799988, 116.399994, 86.5, 171.200012, 159.200012, 142.699997}, sd::DataType::FLOAT32); - - input.linspace(25,-0.5); - - sd::ops::conv2d op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = + 2; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, kH, kW, iC}, + {-3., -2.7, -2.4, -2.1, -1.8, -1.5, -1.2, -0.9, -0.6, -0.3, 0., 0.3, + 0.6, 0.9, 1.2, 1.5, 1.8, 2.1, 2.4, 2.7, 3., 3.3, 3.6, 3.9, + -2.9, -2.6, -2.3, -2., -1.7, -1.4, -1.1, -0.8, -0.5, -0.2, 0.1, 0.4, + 0.7, 1., 1.3, 1.6, 1.9, 2.2, 2.5, 2.8, 3.1, 3.4, 3.7, 4., + -2.8, -2.5, -2.2, -1.9, -1.6, -1.3, -1., -0.7, -0.4, -0.1, 0.2, 0.5, + 0.8, 1.1, 1.4, 1.7, 2., 2.3, 2.6, 2.9, 3.2, 3.5, 3.8, 4.1}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1, 2, 0.5}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oH, oW, oC}, + {463.400055, 498.800018, 529.700012, 410.600006, 442.799988, + 470.500031, 113.600006, 130.400009, 142.699982, -63.999958, + -19.600082, 20.300007, -85.600052, -45.999939, -10.899940, + -144.100021, -124., -108.399994, -128.799988, -98.799973, + -73.300011, -150.400009, -125.200012, -104.500008, -133.300003, + -120.399994, -112.000008, -170.199997, -154., -142.299988, + -146.200012, -133.199997, -124.699997, -88.000008, -80.800003, + -78.099991, -170.200012, -173.199997, -180.699982, -223., + -229.199997, -239.900009, -88., -90.400002, -97.300003, + -323.200012, -336.399994, -354.100037, -344.800018, -362.799988, + -385.299957, -100.900002, -109.600006, -122.800003, -388.000031, + -415.599976, -447.700012, -409.599976, -442., -478.900024, + -90.099991, -105.999992, -126.399994, 117.800003, 95.599991, + 68.899994, 141.799988, 116.399994, 86.5, 171.200012, + 159.200012, 142.699997}, + sd::DataType::FLOAT32); + + input.linspace(25, -0.5); + + sd::ops::conv2d op; + auto results = op.evaluate( + {&input, &weights, &bias}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, sconv2d_1) { - float _expB[] = {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, 38775.0f, 40350.0f, 41925.0f, 43500.0f, 45075.0f, 46650.0f, 54525.0f, 56100.0f, 57675.0f, 59250.0f, 60825.0f, 62400.0f, 70275.0f, 71850.0f, 73425.0f, 75000.0f, 76575.0f, 78150.0f, 86025.0f, 87600.0f, 89175.0f, 90750.0f, 92325.0f, 93900.0f, 101775.0f, 103350.0f, 104925.0f, 106500.0f, 108075.0f, 109650.0f, 117525.0f, 119100.0f, 120675.0f, 122250.0f, 123825.0f, 125400.0f, 67525.0f, 70350.0f, 73175.0f, 76000.0f, 78825.0f, 81650.0f, 95775.0f, 98600.0f, 101425.0f, 104250.0f, 107075.0f, 109900.0f, 124025.0f, 126850.0f, 129675.0f, 132500.0f, 135325.0f, 138150.0f, 152275.0f, 155100.0f, 157925.0f, 160750.0f, 163575.0f, 166400.0f, 180525.0f, 183350.0f, 186175.0f, 189000.0f, 191825.0f, 194650.0f, 208775.0f, 211600.0f, 214425.0f, 217250.0f, 220075.0f, 222900.0f, 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, 273150.0f, 275350.0f, 277550.0f, 279750.0f, 281950.0f, 284150.0f, 295150.0f, 297350.0f, 299550.0f, 301750.0f, 303950.0f, 306150.0f, 317150.0f, 319350.0f, 321550.0f, 323750.0f, 325950.0f, 328150.0f, 339150.0f, 341350.0f, 343550.0f, 345750.0f, 347950.0f, 350150.0f, 361150.0f, 363350.0f, 365550.0f, 367750.0f, 369950.0f, 372150.0f, 383150.0f, 385350.0f, 387550.0f, 389750.0f, 391950.0f, 394150.0f, 426900.0f, 430350.0f, 433800.0f, 437250.0f, 440700.0f, 444150.0f, 461400.0f, 464850.0f, 468300.0f, 471750.0f, 475200.0f, 478650.0f, 495900.0f, 499350.0f, 502800.0f, 506250.0f, 509700.0f, 513150.0f, 530400.0f, 533850.0f, 537300.0f, 540750.0f, 544200.0f, 547650.0f, 564900.0f, 568350.0f, 571800.0f, 575250.0f, 578700.0f, 582150.0f, 599400.0f, 602850.0f, 606300.0f, 609750.0f, 613200.0f, 616650.0f, 75025.0f, 75350.0f, 75675.0f, 76000.0f, 76325.0f, 76650.0f, 78275.0f, 78600.0f, 78925.0f, 79250.0f, 79575.0f, 79900.0f, 81525.0f, 81850.0f, 82175.0f, 82500.0f, 82825.0f, 83150.0f, 84775.0f, 85100.0f, 85425.0f, 85750.0f, 86075.0f, 86400.0f, 88025.0f, 88350.0f, 88675.0f, 89000.0f, 89325.0f, 89650.0f, 91275.0f, 91600.0f, 91925.0f, 92250.0f, 92575.0f, 92900.0f, 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, 632525.0f, 635350.0f, 638175.0f, 641000.0f, 643825.0f, 646650.0f, 660775.0f, 663600.0f, 666425.0f, 669250.0f, 672075.0f, 674900.0f, 689025.0f, 691850.0f, 694675.0f, 697500.0f, 700325.0f, 703150.0f, 717275.0f, 720100.0f, 722925.0f, 725750.0f, 728575.0f, 731400.0f, 745525.0f, 748350.0f, 751175.0f, 754000.0f, 756825.0f, 759650.0f, 773775.0f, 776600.0f, 779425.0f, 782250.0f, 785075.0f, 787900.0f, 309400.0f, 310350.0f, 311300.0f, 312250.0f, 313200.0f, 314150.0f, 318900.0f, 319850.0f, 320800.0f, 321750.0f, 322700.0f, 323650.0f, 328400.0f, 329350.0f, 330300.0f, 331250.0f, 332200.0f, 333150.0f, 337900.0f, 338850.0f, 339800.0f, 340750.0f, 341700.0f, 342650.0f, 347400.0f, 348350.0f, 349300.0f, 350250.0f, 351200.0f, 352150.0f, 356900.0f, 357850.0f, 358800.0f, 359750.0f, 360700.0f, 361650.0f, 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, 1116900.0f, 1120350.0f, 1123800.0f, 1127250.0f, 1130700.0f, 1134150.0f, 1151400.0f, 1154850.0f, 1158300.0f, 1161750.0f, 1165200.0f, 1168650.0f, 1185900.0f, 1189350.0f, 1192800.0f, 1196250.0f, 1199700.0f, 1203150.0f, 1220400.0f, 1223850.0f, 1227300.0f, 1230750.0f, 1234200.0f, 1237650.0f, 1254900.0f, 1258350.0f, 1261800.0f, 1265250.0f, 1268700.0f, 1272150.0f, 1289400.0f, 1292850.0f, 1296300.0f, 1299750.0f, 1303200.0f, 1306650.0f,}; - Nd4jLong _expS[] = {4, 2, 6, 6, 6, 144, 36, 6, 1, 8192, 1, 99}; - NDArray exp(_expB, _expS); - - int sY = 1; - int sX = 1; - int pY = 0; - int pX = 0; - int iC = 2; - int oC = 3; - int kY = 5; - int kX = 5; - int iY = 10; - int iX = 10; - int B = 2; - - auto input = NDArrayFactory::create('c', {B, iC, iY, iX}); - for (int e = 0; e < input.lengthOf(); e++) - input.p(e, e+1); - - auto weights = NDArrayFactory::create('c', {oC, iC, kY, kX}); - for (int e = 0; e < weights.lengthOf(); e++) - weights.p(e, e+1); - weights.permutei({2,3,1,0}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, weights); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1, -2}); - - block->appendI(kY); - block->appendI(kX); - - block->appendI(sY); - block->appendI(sX); - - block->appendI(pY); - block->appendI(pX); - - // dilation - block->appendI(1); - block->appendI(1); - - // NOT same mode - block->appendI(0); - - sd::ops::sconv2d op; - - Nd4jStatus status = op.execute(block); - - ASSERT_EQ(ND4J_STATUS_OK, status); - auto output = variableSpace->getVariable(1)->getNDArray(); - - //exp.printShapeInfo("Expected shape"); - //output->printShapeInfo("Result shape"); - ASSERT_TRUE(exp.isSameShape(*output)); - - //exp.printBuffer("Expctd buffer"); - //output->printBuffer("Result buffer"); - ASSERT_TRUE(exp.equalsTo(*output)); - - delete block; - delete variableSpace; + float _expB[] = { + 10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, + 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, + 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, + 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, + 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, + 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, + 38775.0f, 40350.0f, 41925.0f, 43500.0f, 45075.0f, 46650.0f, + 54525.0f, 56100.0f, 57675.0f, 59250.0f, 60825.0f, 62400.0f, + 70275.0f, 71850.0f, 73425.0f, 75000.0f, 76575.0f, 78150.0f, + 86025.0f, 87600.0f, 89175.0f, 90750.0f, 92325.0f, 93900.0f, + 101775.0f, 103350.0f, 104925.0f, 106500.0f, 108075.0f, 109650.0f, + 117525.0f, 119100.0f, 120675.0f, 122250.0f, 123825.0f, 125400.0f, + 67525.0f, 70350.0f, 73175.0f, 76000.0f, 78825.0f, 81650.0f, + 95775.0f, 98600.0f, 101425.0f, 104250.0f, 107075.0f, 109900.0f, + 124025.0f, 126850.0f, 129675.0f, 132500.0f, 135325.0f, 138150.0f, + 152275.0f, 155100.0f, 157925.0f, 160750.0f, 163575.0f, 166400.0f, + 180525.0f, 183350.0f, 186175.0f, 189000.0f, 191825.0f, 194650.0f, + 208775.0f, 211600.0f, 214425.0f, 217250.0f, 220075.0f, 222900.0f, + 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, + 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, + 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, + 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, + 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, + 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, + 273150.0f, 275350.0f, 277550.0f, 279750.0f, 281950.0f, 284150.0f, + 295150.0f, 297350.0f, 299550.0f, 301750.0f, 303950.0f, 306150.0f, + 317150.0f, 319350.0f, 321550.0f, 323750.0f, 325950.0f, 328150.0f, + 339150.0f, 341350.0f, 343550.0f, 345750.0f, 347950.0f, 350150.0f, + 361150.0f, 363350.0f, 365550.0f, 367750.0f, 369950.0f, 372150.0f, + 383150.0f, 385350.0f, 387550.0f, 389750.0f, 391950.0f, 394150.0f, + 426900.0f, 430350.0f, 433800.0f, 437250.0f, 440700.0f, 444150.0f, + 461400.0f, 464850.0f, 468300.0f, 471750.0f, 475200.0f, 478650.0f, + 495900.0f, 499350.0f, 502800.0f, 506250.0f, 509700.0f, 513150.0f, + 530400.0f, 533850.0f, 537300.0f, 540750.0f, 544200.0f, 547650.0f, + 564900.0f, 568350.0f, 571800.0f, 575250.0f, 578700.0f, 582150.0f, + 599400.0f, 602850.0f, 606300.0f, 609750.0f, 613200.0f, 616650.0f, + 75025.0f, 75350.0f, 75675.0f, 76000.0f, 76325.0f, 76650.0f, + 78275.0f, 78600.0f, 78925.0f, 79250.0f, 79575.0f, 79900.0f, + 81525.0f, 81850.0f, 82175.0f, 82500.0f, 82825.0f, 83150.0f, + 84775.0f, 85100.0f, 85425.0f, 85750.0f, 86075.0f, 86400.0f, + 88025.0f, 88350.0f, 88675.0f, 89000.0f, 89325.0f, 89650.0f, + 91275.0f, 91600.0f, 91925.0f, 92250.0f, 92575.0f, 92900.0f, + 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, + 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, + 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, + 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, + 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, + 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, + 632525.0f, 635350.0f, 638175.0f, 641000.0f, 643825.0f, 646650.0f, + 660775.0f, 663600.0f, 666425.0f, 669250.0f, 672075.0f, 674900.0f, + 689025.0f, 691850.0f, 694675.0f, 697500.0f, 700325.0f, 703150.0f, + 717275.0f, 720100.0f, 722925.0f, 725750.0f, 728575.0f, 731400.0f, + 745525.0f, 748350.0f, 751175.0f, 754000.0f, 756825.0f, 759650.0f, + 773775.0f, 776600.0f, 779425.0f, 782250.0f, 785075.0f, 787900.0f, + 309400.0f, 310350.0f, 311300.0f, 312250.0f, 313200.0f, 314150.0f, + 318900.0f, 319850.0f, 320800.0f, 321750.0f, 322700.0f, 323650.0f, + 328400.0f, 329350.0f, 330300.0f, 331250.0f, 332200.0f, 333150.0f, + 337900.0f, 338850.0f, 339800.0f, 340750.0f, 341700.0f, 342650.0f, + 347400.0f, 348350.0f, 349300.0f, 350250.0f, 351200.0f, 352150.0f, + 356900.0f, 357850.0f, 358800.0f, 359750.0f, 360700.0f, 361650.0f, + 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, + 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, + 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, + 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, + 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, + 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, + 1116900.0f, 1120350.0f, 1123800.0f, 1127250.0f, 1130700.0f, 1134150.0f, + 1151400.0f, 1154850.0f, 1158300.0f, 1161750.0f, 1165200.0f, 1168650.0f, + 1185900.0f, 1189350.0f, 1192800.0f, 1196250.0f, 1199700.0f, 1203150.0f, + 1220400.0f, 1223850.0f, 1227300.0f, 1230750.0f, 1234200.0f, 1237650.0f, + 1254900.0f, 1258350.0f, 1261800.0f, 1265250.0f, 1268700.0f, 1272150.0f, + 1289400.0f, 1292850.0f, 1296300.0f, 1299750.0f, 1303200.0f, 1306650.0f, + }; + Nd4jLong _expS[] = {4, 2, 6, 6, 6, 144, 36, 6, 1, 8192, 1, 99}; + NDArray exp(_expB, _expS); + + int sY = 1; + int sX = 1; + int pY = 0; + int pX = 0; + int iC = 2; + int oC = 3; + int kY = 5; + int kX = 5; + int iY = 10; + int iX = 10; + int B = 2; + + auto input = NDArrayFactory::create('c', {B, iC, iY, iX}); + for (int e = 0; e < input.lengthOf(); e++) input.p(e, e + 1); + + auto weights = NDArrayFactory::create('c', {oC, iC, kY, kX}); + for (int e = 0; e < weights.lengthOf(); e++) weights.p(e, e + 1); + weights.permutei({2, 3, 1, 0}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, weights); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); + + block->appendI(kY); + block->appendI(kX); + + block->appendI(sY); + block->appendI(sX); + + block->appendI(pY); + block->appendI(pX); + + // dilation + block->appendI(1); + block->appendI(1); + + // NOT same mode + block->appendI(0); + + sd::ops::sconv2d op; + + Nd4jStatus status = op.execute(block); + + ASSERT_EQ(ND4J_STATUS_OK, status); + auto output = variableSpace->getVariable(1)->getNDArray(); + + // exp.printShapeInfo("Expected shape"); + // output->printShapeInfo("Result shape"); + ASSERT_TRUE(exp.isSameShape(*output)); + + // exp.printBuffer("Expctd buffer"); + // output->printBuffer("Result buffer"); + ASSERT_TRUE(exp.equalsTo(*output)); + + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, sconv2d_2) { - TypeParam _expBFF[] = {108.9405008f, 109.5920008f, 110.2435008f, 110.8950008f, 111.5465008f, 112.1980008f, 115.4555008f, 116.1070008f, 116.7585008f, 117.410000f, 118.061500f, 118.7130009f, 121.9705009f, 122.6220009f, 123.2735009f, 123.9250009f, 124.5765009f, 125.2280009f, 128.4855009f, 129.1370009f, 129.7885009f, 130.4400009f, 131.09150f, 131.74300f, 135.0005010f, 135.6520010f, 136.3035010f, 136.9550010f, 137.6065010f, 138.2580010f, 141.5155010f, 142.1670010f, 142.8185010f, 143.4700010f, 144.1215010f, 144.7730010f, 248.9617514f, 250.670751f, 252.3797515f, 254.0887515f, 255.7977515f, 257.5067515f, 266.0517515f, 267.7607515f, 269.469751f, 271.1787516f, 272.8877516f, 274.5967516f, 283.1417516f, 284.8507516f, - 286.5597516f, 288.268751f, 289.9777517f, 291.6867517f, 300.2317517f, 301.9407517f, 303.6497517f, 305.3587517f, 307.067751f, 308.7767518f, 317.3217518f, 319.0307518f, 320.7397518f, 322.4487518f, 324.157751f, 325.866751f, 334.4117519f, 336.1207519f, 337.8297519f, 339.5387519f, 341.2477519f, 342.95675f, 388.9829964f, 391.7494964f, 394.5159964f, 397.2824964f, 400.048996f, 402.8154963f, 416.647996f, 419.4144962f, 422.1809962f, 424.9474962f, 427.7139962f, 430.4804962f, 444.3129961f, 447.0794961f, 449.8459961f, 452.6124960f, 455.3789960f, 458.1454960f, 471.9779959f, 474.7444959f, 477.5109959f, 480.2774959f, 483.0439959f, 485.8104958f, 499.6429958f, 502.4094957f, 505.1759957f, 507.9424957f, - 510.7089957f, 513.4754957f, 527.3079956f, 530.0744956f, 532.8409956f, 535.607495f, 538.3739955f, 541.1404955f, 529.0042487f, 532.8282487f, 536.6522487f, 540.4762487f, 544.3002487f, 548.1242487f, 567.2442487f, 571.068248f, 574.892248f, 578.716248f, 582.540248f, 586.3642486f, 605.4842486f, 609.3082486f, 613.1322486f, 616.9562486f, 620.7802486f, 624.6042486f, 643.7242486f, 647.5482486f, 651.3722486f, 655.1962486f, 659.0202486f, 662.8442486f, 681.9642486f, 685.7882486f, 689.6122486f, 693.4362486f, 697.2602486f, 701.0842486f, 720.2042486f, 724.0282486f, 727.852248f, 731.676248f, 735.500248f, 739.324248f, 669.0255044f, 673.9070044f, 678.7885044f, 683.6700044f, 688.5515044f, 693.4330044f, - 717.8405044f, 722.7220044f, 727.6035044f, 732.4850044f, 737.3665044f, 742.2480044f, 766.6555043f, 771.5370043f, 776.4185043f, 781.3000043f, 786.1815043f, 791.0630043f, 815.4705043f, 820.3520043f, 825.2335043f, 830.1150043f, 834.9965043f, 839.8780043f, 864.2855042f, 869.1670042f, 874.0485042f, 878.9300042f, 883.8115042f, 888.6930042f, 913.1005042f, 917.9820042f, 922.8635042f, 927.7450042f, 932.6265042f, 937.5080042f, 809.0467424f, 814.9857424f, 820.9247424f, 826.8637423f, 832.8027423f, 838.7417423f, 868.4367421f, 874.3757421f, 880.3147420f, 886.2537420f, 892.1927420f, 898.13174f, 927.8267418f, 933.7657418f, 939.7047417f, 945.6437417f, 951.5827417f, 957.5217416f, 987.2167415f, 993.155741f, - 999.0947414f, 1005.0337414f, 1010.972741f, 1016.9117413f, 1046.6067412f, 1052.5457411f, 1058.4847411f, 1064.4237411f, 1070.3627410f, 1076.3017410f, 1105.996740f, 1111.9357408f, 1117.8747408f, 1123.8137408f, 1129.7527407f, 1135.6917407f, 949.0679815f, 956.0644814f, 963.060981f, 970.0574813f, 977.0539812f, 984.0504811f, 1019.0329807f, 1026.0294807f, 1033.0259806f, 1040.0224805f, 1047.0189804f, 1054.0154804f, 1088.9979800f, 1095.9944799f, 1102.9909798f, 1109.987479f, 1116.9839797f, 1123.9804796f, 1158.9629792f, 1165.9594791f, 1172.9559791f, 1179.9524790f, 1186.9489789f, 1193.9454788f, 1228.9279785f, 1235.9244784f, 1242.9209783f, 1249.9174782f, 1256.913978f, 1263.9104781f, 1298.8929777f, 1305.8894776f, 1312.8859775f, 1319.8824775f, 1326.8789774f, 1333.8754773f, 1089.0892560f, 1097.1432561f, 1105.1972562f, 1113.251256f, 1121.3052563f, 1129.3592564f, 1169.6292568f, 1177.6832568f, 1185.7372569f, 1193.7912570f, 1201.845257f, 1209.8992571f, 1250.1692575f, 1258.2232576f, 1266.2772576f, 1274.3312577f, 1282.3852578f, 1290.4392579f, 1330.7092582f, 1338.7632583f, 1346.8172584f, 1354.8712584f, 1362.9252585f, 1370.9792586f, 1411.24925f, 1419.3032590f, 1427.3572591f, 1435.4112592f, 1443.465259f, 1451.5192593f, 1491.7892597f, 1499.8432598f, 1507.8972598f, 1515.9512599f, 1524.0052600f, 1532.059260f, 1229.1105073f, 1238.2220073f, 1247.3335073f, 1256.4450073f, 1265.5565073f, 1274.668007f, 1320.2255074f, 1329.3370074f, 1338.4485074f, 1347.5600075f, 1356.6715075f, 1365.7830075f, 1411.340507f, 1420.4520076f, 1429.5635076f, 1438.6750076f, 1447.7865076f, 1456.8980076f, 1502.4555077f, 1511.5670077f, 1520.6785077f, 1529.7900077f, 1538.9015077f, 1548.013007f, 1593.5705078f, 1602.6820078f, 1611.793507f, 1620.9050079f, 1630.0165079f, 1639.1280079f, 1684.6855080f, 1693.7970080f, 1702.9085080f, 1712.0200080f, 1721.1315080f, 1730.2430080f, 1369.1317613f, 1379.3007614f, 1389.4697614f, 1399.6387615f, 1409.8077615f, 1419.976761f, 1470.8217618f, 1480.9907618f, 1491.159761f, 1501.3287619f, 1511.4977619f, 1521.6667620f, 1572.5117622f, 1582.6807622f, 1592.8497623f, 1603.0187623f, 1613.1877624f, 1623.3567624f, 1674.2017626f, 1684.3707627f, 1694.5397627f, 1704.7087628f, 1714.8777628f, 1725.046762f, 1775.8917631f, 1786.0607631f, 1796.229763f, 1806.3987632f, 1816.5677632f, 1826.7367633f, 1877.5817635f, 1887.7507635f, 1897.9197636f, 1908.0887636f, 1918.2577637f, 1928.4267637f, 304.3905022f, 305.0420022f, 305.6935022f, 306.3450022f, 306.9965022f, 307.6480022f, 310.9055022f, 311.5570022f, 312.208502f, 312.860002f, 313.5115023f, 314.1630023f, 317.4205023f, 318.0720023f, 318.7235023f, 319.3750023f, 320.0265023f, 320.6780023f, 323.9355023f, 324.5870023f, 325.2385023f, 325.8900023f, 326.541502f, 327.193002f, 330.4505024f, 331.1020024f, 331.7535024f, 332.4050024f, 333.0565024f, 333.7080024f, 336.9655024f, 337.6170024f, 338.2685024f, 338.9200024f, 339.5715024f, 340.223002f, 761.6617542f, 763.3707542f, 765.0797542f, 766.7887542f, 768.4977542f, 770.206754f, 778.7517543f, 780.4607543f, 782.1697543f, 783.8787543f, 785.5877543f, 787.2967543f, 795.8417544f, 797.5507544f, 799.2597544f, 800.9687544f, 802.6777544f, 804.3867544f, 812.9317545f, 814.6407545f, 816.3497545f, 818.0587545f, 819.7677545f, 821.4767545f, 830.0217546f, 831.7307546f, 833.4397546f, 835.1487546f, 836.8577546f, 838.5667546f, 847.1117547f, 848.8207547f, 850.5297547f, 852.2387547f, 853.9477547f, 855.6567547f, 1218.9329915f, 1221.6994915f, 1224.4659915f, 1227.232491f, 1229.9989914f, 1232.7654914f, 1246.5979913f, 1249.3644913f, 1252.1309913f, 1254.8974913f, 1257.6639913f, 1260.430491f, 1274.2629912f, 1277.029491f, 1279.7959911f, 1282.5624911f, 1285.3289911f, 1288.0954911f, 1301.9279910f, 1304.6944910f, 1307.4609910f, 1310.22749f, 1312.9939909f, 1315.7604909f, 1329.5929908f, 1332.3594908f, 1335.1259908f, 1337.8924908f, 1340.6589908f, 1343.4254908f, 1357.2579907f, - 1360.0244907f, 1362.7909906f, 1365.5574906f, 1368.3239906f, 1371.0904906f, 1676.2042479f, 1680.0282479f, 1683.8522479f, 1687.6762479f, 1691.5002479f, 1695.3242479f, 1714.4442479f, 1718.2682479f, 1722.0922479f, 1725.9162479f, 1729.7402479f, 1733.5642479f, 1752.6842479f, 1756.5082479f, 1760.3322479f, 1764.1562479f, 1767.9802479f, 1771.8042479f, 1790.9242479f, 1794.7482479f, 1798.5722479f, 1802.3962479f, 1806.2202479f, 1810.044247f, 1829.1642478f, 1832.9882478f, 1836.8122478f, 1840.6362478f, 1844.4602478f, 1848.2842478f, 1867.4042478f, 1871.2282478f, 1875.0522478f, 1878.8762478f, 1882.7002478f, 1886.5242478f, 2133.4755029f, 2138.3570029f, 2143.2385029f, 2148.1200029f, 2153.0015029f, 2157.8830029f, 2182.2905028f, 2187.1720028f, 2192.0535028f, 2196.9350028f, 2201.8165028f, 2206.6980028f, 2231.1055028f, 2235.9870028f, 2240.8685028f, 2245.7500028f, 2250.6315028f, 2255.5130028f, 2279.9205027f, 2284.8020027f, 2289.6835027f, 2294.5650027f, 2299.4465027f, 2304.3280027f, 2328.7355027f, 2333.6170027f, 2338.4985027f, 2343.3800027f, 2348.2615027f, 2353.1430027f, 2377.5505026f, 2382.4320026f, 2387.3135026f, 2392.1950026f, 2397.0765026f, 2401.9580026f, 2590.7467330f, 2596.6857330f, 2602.6247329f, 2608.5637329f, 2614.5027329f, 2620.441732f, 2650.1367327f, 2656.0757327f, 2662.0147326f, 2667.9537326f, 2673.8927326f, 2679.8317325f, 2709.5267324f, 2715.465732f, 2721.4047323f, 2727.3437323f, 2733.282732f, 2739.2217322f, 2768.9167321f, 2774.8557320f, 2780.7947320f, 2786.7337320f, 2792.6727319f, 2798.6117319f, 2828.306731f, 2834.2457317f, 2840.1847317f, 2846.1237317f, 2852.0627316f, 2858.0017316f, 2887.6967314f, 2893.6357314f, 2899.5747314f, 2905.5137313f, 2911.4527313f, 2917.3917313f, 3048.0179587f, 3055.0144586f, 3062.0109585f, 3069.0074584f, 3076.0039584f, 3083.0004583f, 3117.9829579f, 3124.9794578f, 3131.9759578f, 3138.9724577f, 3145.9689576f, 3152.9654575f, 3187.947957f, 3194.9444571f, 3201.9409570f, 3208.9374569f, 3215.933956f, 3222.9304568f, 3257.9129564f, 3264.9094563f, 3271.9059562f, 3278.9024562f, 3285.8989561f, - 3292.8954560f, 3327.8779556f, 3334.874455f, 3341.8709555f, 3348.8674554f, 3355.8639553f, 3362.860455f, 3397.8429549f, 3404.8394548f, 3411.8359547f, 3418.8324546f, 3425.8289546f, 3432.8254545f, 3505.28927f, 3513.3432780f, 3521.3972781f, 3529.4512782f, 3537.5052782f, 3545.5592783f, 3585.8292787f, 3593.8832788f, 3601.9372788f, 3609.9912789f, 3618.0452790f, 3626.099279f, - 3666.3692794f, 3674.4232795f, 3682.4772796f, 3690.5312796f, 3698.5852797f, 3706.6392798f, 3746.9092801f, 3754.9632802f, 3763.0172803f, 3771.0712804f, 3779.1252804f, 3787.1792805f, 3827.4492809f, 3835.50328f, 3843.5572810f, 3851.6112811f, 3859.6652812f, 3867.7192812f, 3907.9892816f, 3916.0432817f, 3924.097281f, - 3932.1512818f, 3940.2052819f, 3948.2592820f, 3962.5605113f, 3971.6720113f, 3980.783511f, 3989.8950114f, 3999.0065114f, 4008.1180114f, 4053.6755115f, 4062.7870115f, 4071.8985115f, 4081.0100115f, 4090.1215115f, 4099.2330115f, 4144.7905116f, 4153.9020116f, 4163.0135116f, 4172.1250116f, - 4181.236511f, 4190.3480117f, 4235.9055117f, 4245.0170117f, 4254.128511f, 4263.2400118f, 4272.3515118f, 4281.4630118f, 4327.0205119f, 4336.1320119f, 4345.2435119f, 4354.3550119f, 4363.4665119f, 4372.5780119f, 4418.1355120f, 4427.2470120f, 4436.3585120f, 4445.4700120f, 4454.581512f, 4463.6930121f, 4419.8317743f, 4430.0007744f, 4440.1697744f, 4450.338774f, 4460.5077745f, 4470.6767745f, 4521.521774f, 4531.6907748f, - 4541.8597748f, 4552.0287749f, 4562.1977749f, 4572.3667750f, 4623.2117752f, 4633.3807752f, 4643.5497753f, 4653.7187753f, 4663.8877754f, 4674.0567754f, 4724.9017756f, 4735.0707757f, 4745.2397757f, 4755.4087757f, 4765.5777758f, 4775.7467758f, 4826.591776f, 4836.7607761f, 4846.9297761f, 4857.0987762f, 4867.2677762f, 4877.4367763f, 4928.2817765f, 4938.4507765f, 4948.6197766f, 4958.7887766f, 4968.957776f, 4979.12677675f}; - Nd4jLong _expSFF[] = {4, 2, 10, 6, 6, 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99,}; - NDArray expFF(_expBFF, _expSFF); - - auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); - auto weightsD = NDArrayFactory::create('c', {5, 3, 5, 5}); - auto weightsP = NDArrayFactory::create('c', {10, 15, 1, 1}); - - input.linspace(1); - weightsD.linspace(1); - weightsP.linspace(1); - weightsD.permutei({2,3,1,0}); - weightsP.permutei({2,3,1,0}); - - input.applyScalar(scalar::Divide, 100.0, input); - weightsD.applyScalar(scalar::Divide, 100.0, weightsD); - weightsP.applyScalar(scalar::Divide, 100.0, weightsP); - - sd::ops::sconv2d op; - - auto resultFF = op.evaluate({&input, &weightsD, &weightsP}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); - - auto z = resultFF.at(0); - //z->printShapeInfo("FF shape"); - - - ASSERT_TRUE(z.isSameShape(&expFF)); - - //expFF.printBuffer("e"); - //z->printBuffer("z"); - ASSERT_TRUE(z.equalsTo(&expFF, 1e-3)); + TypeParam _expBFF[] = { + 108.9405008f, 109.5920008f, 110.2435008f, 110.8950008f, + 111.5465008f, 112.1980008f, 115.4555008f, 116.1070008f, + 116.7585008f, 117.410000f, 118.061500f, 118.7130009f, + 121.9705009f, 122.6220009f, 123.2735009f, 123.9250009f, + 124.5765009f, 125.2280009f, 128.4855009f, 129.1370009f, + 129.7885009f, 130.4400009f, 131.09150f, 131.74300f, + 135.0005010f, 135.6520010f, 136.3035010f, 136.9550010f, + 137.6065010f, 138.2580010f, 141.5155010f, 142.1670010f, + 142.8185010f, 143.4700010f, 144.1215010f, 144.7730010f, + 248.9617514f, 250.670751f, 252.3797515f, 254.0887515f, + 255.7977515f, 257.5067515f, 266.0517515f, 267.7607515f, + 269.469751f, 271.1787516f, 272.8877516f, 274.5967516f, + 283.1417516f, 284.8507516f, 286.5597516f, 288.268751f, + 289.9777517f, 291.6867517f, 300.2317517f, 301.9407517f, + 303.6497517f, 305.3587517f, 307.067751f, 308.7767518f, + 317.3217518f, 319.0307518f, 320.7397518f, 322.4487518f, + 324.157751f, 325.866751f, 334.4117519f, 336.1207519f, + 337.8297519f, 339.5387519f, 341.2477519f, 342.95675f, + 388.9829964f, 391.7494964f, 394.5159964f, 397.2824964f, + 400.048996f, 402.8154963f, 416.647996f, 419.4144962f, + 422.1809962f, 424.9474962f, 427.7139962f, 430.4804962f, + 444.3129961f, 447.0794961f, 449.8459961f, 452.6124960f, + 455.3789960f, 458.1454960f, 471.9779959f, 474.7444959f, + 477.5109959f, 480.2774959f, 483.0439959f, 485.8104958f, + 499.6429958f, 502.4094957f, 505.1759957f, 507.9424957f, + 510.7089957f, 513.4754957f, 527.3079956f, 530.0744956f, + 532.8409956f, 535.607495f, 538.3739955f, 541.1404955f, + 529.0042487f, 532.8282487f, 536.6522487f, 540.4762487f, + 544.3002487f, 548.1242487f, 567.2442487f, 571.068248f, + 574.892248f, 578.716248f, 582.540248f, 586.3642486f, + 605.4842486f, 609.3082486f, 613.1322486f, 616.9562486f, + 620.7802486f, 624.6042486f, 643.7242486f, 647.5482486f, + 651.3722486f, 655.1962486f, 659.0202486f, 662.8442486f, + 681.9642486f, 685.7882486f, 689.6122486f, 693.4362486f, + 697.2602486f, 701.0842486f, 720.2042486f, 724.0282486f, + 727.852248f, 731.676248f, 735.500248f, 739.324248f, + 669.0255044f, 673.9070044f, 678.7885044f, 683.6700044f, + 688.5515044f, 693.4330044f, 717.8405044f, 722.7220044f, + 727.6035044f, 732.4850044f, 737.3665044f, 742.2480044f, + 766.6555043f, 771.5370043f, 776.4185043f, 781.3000043f, + 786.1815043f, 791.0630043f, 815.4705043f, 820.3520043f, + 825.2335043f, 830.1150043f, 834.9965043f, 839.8780043f, + 864.2855042f, 869.1670042f, 874.0485042f, 878.9300042f, + 883.8115042f, 888.6930042f, 913.1005042f, 917.9820042f, + 922.8635042f, 927.7450042f, 932.6265042f, 937.5080042f, + 809.0467424f, 814.9857424f, 820.9247424f, 826.8637423f, + 832.8027423f, 838.7417423f, 868.4367421f, 874.3757421f, + 880.3147420f, 886.2537420f, 892.1927420f, 898.13174f, + 927.8267418f, 933.7657418f, 939.7047417f, 945.6437417f, + 951.5827417f, 957.5217416f, 987.2167415f, 993.155741f, + 999.0947414f, 1005.0337414f, 1010.972741f, 1016.9117413f, + 1046.6067412f, 1052.5457411f, 1058.4847411f, 1064.4237411f, + 1070.3627410f, 1076.3017410f, 1105.996740f, 1111.9357408f, + 1117.8747408f, 1123.8137408f, 1129.7527407f, 1135.6917407f, + 949.0679815f, 956.0644814f, 963.060981f, 970.0574813f, + 977.0539812f, 984.0504811f, 1019.0329807f, 1026.0294807f, + 1033.0259806f, 1040.0224805f, 1047.0189804f, 1054.0154804f, + 1088.9979800f, 1095.9944799f, 1102.9909798f, 1109.987479f, + 1116.9839797f, 1123.9804796f, 1158.9629792f, 1165.9594791f, + 1172.9559791f, 1179.9524790f, 1186.9489789f, 1193.9454788f, + 1228.9279785f, 1235.9244784f, 1242.9209783f, 1249.9174782f, + 1256.913978f, 1263.9104781f, 1298.8929777f, 1305.8894776f, + 1312.8859775f, 1319.8824775f, 1326.8789774f, 1333.8754773f, + 1089.0892560f, 1097.1432561f, 1105.1972562f, 1113.251256f, + 1121.3052563f, 1129.3592564f, 1169.6292568f, 1177.6832568f, + 1185.7372569f, 1193.7912570f, 1201.845257f, 1209.8992571f, + 1250.1692575f, 1258.2232576f, 1266.2772576f, 1274.3312577f, + 1282.3852578f, 1290.4392579f, 1330.7092582f, 1338.7632583f, + 1346.8172584f, 1354.8712584f, 1362.9252585f, 1370.9792586f, + 1411.24925f, 1419.3032590f, 1427.3572591f, 1435.4112592f, + 1443.465259f, 1451.5192593f, 1491.7892597f, 1499.8432598f, + 1507.8972598f, 1515.9512599f, 1524.0052600f, 1532.059260f, + 1229.1105073f, 1238.2220073f, 1247.3335073f, 1256.4450073f, + 1265.5565073f, 1274.668007f, 1320.2255074f, 1329.3370074f, + 1338.4485074f, 1347.5600075f, 1356.6715075f, 1365.7830075f, + 1411.340507f, 1420.4520076f, 1429.5635076f, 1438.6750076f, + 1447.7865076f, 1456.8980076f, 1502.4555077f, 1511.5670077f, + 1520.6785077f, 1529.7900077f, 1538.9015077f, 1548.013007f, + 1593.5705078f, 1602.6820078f, 1611.793507f, 1620.9050079f, + 1630.0165079f, 1639.1280079f, 1684.6855080f, 1693.7970080f, + 1702.9085080f, 1712.0200080f, 1721.1315080f, 1730.2430080f, + 1369.1317613f, 1379.3007614f, 1389.4697614f, 1399.6387615f, + 1409.8077615f, 1419.976761f, 1470.8217618f, 1480.9907618f, + 1491.159761f, 1501.3287619f, 1511.4977619f, 1521.6667620f, + 1572.5117622f, 1582.6807622f, 1592.8497623f, 1603.0187623f, + 1613.1877624f, 1623.3567624f, 1674.2017626f, 1684.3707627f, + 1694.5397627f, 1704.7087628f, 1714.8777628f, 1725.046762f, + 1775.8917631f, 1786.0607631f, 1796.229763f, 1806.3987632f, + 1816.5677632f, 1826.7367633f, 1877.5817635f, 1887.7507635f, + 1897.9197636f, 1908.0887636f, 1918.2577637f, 1928.4267637f, + 304.3905022f, 305.0420022f, 305.6935022f, 306.3450022f, + 306.9965022f, 307.6480022f, 310.9055022f, 311.5570022f, + 312.208502f, 312.860002f, 313.5115023f, 314.1630023f, + 317.4205023f, 318.0720023f, 318.7235023f, 319.3750023f, + 320.0265023f, 320.6780023f, 323.9355023f, 324.5870023f, + 325.2385023f, 325.8900023f, 326.541502f, 327.193002f, + 330.4505024f, 331.1020024f, 331.7535024f, 332.4050024f, + 333.0565024f, 333.7080024f, 336.9655024f, 337.6170024f, + 338.2685024f, 338.9200024f, 339.5715024f, 340.223002f, + 761.6617542f, 763.3707542f, 765.0797542f, 766.7887542f, + 768.4977542f, 770.206754f, 778.7517543f, 780.4607543f, + 782.1697543f, 783.8787543f, 785.5877543f, 787.2967543f, + 795.8417544f, 797.5507544f, 799.2597544f, 800.9687544f, + 802.6777544f, 804.3867544f, 812.9317545f, 814.6407545f, + 816.3497545f, 818.0587545f, 819.7677545f, 821.4767545f, + 830.0217546f, 831.7307546f, 833.4397546f, 835.1487546f, + 836.8577546f, 838.5667546f, 847.1117547f, 848.8207547f, + 850.5297547f, 852.2387547f, 853.9477547f, 855.6567547f, + 1218.9329915f, 1221.6994915f, 1224.4659915f, 1227.232491f, + 1229.9989914f, 1232.7654914f, 1246.5979913f, 1249.3644913f, + 1252.1309913f, 1254.8974913f, 1257.6639913f, 1260.430491f, + 1274.2629912f, 1277.029491f, 1279.7959911f, 1282.5624911f, + 1285.3289911f, 1288.0954911f, 1301.9279910f, 1304.6944910f, + 1307.4609910f, 1310.22749f, 1312.9939909f, 1315.7604909f, + 1329.5929908f, 1332.3594908f, 1335.1259908f, 1337.8924908f, + 1340.6589908f, 1343.4254908f, 1357.2579907f, 1360.0244907f, + 1362.7909906f, 1365.5574906f, 1368.3239906f, 1371.0904906f, + 1676.2042479f, 1680.0282479f, 1683.8522479f, 1687.6762479f, + 1691.5002479f, 1695.3242479f, 1714.4442479f, 1718.2682479f, + 1722.0922479f, 1725.9162479f, 1729.7402479f, 1733.5642479f, + 1752.6842479f, 1756.5082479f, 1760.3322479f, 1764.1562479f, + 1767.9802479f, 1771.8042479f, 1790.9242479f, 1794.7482479f, + 1798.5722479f, 1802.3962479f, 1806.2202479f, 1810.044247f, + 1829.1642478f, 1832.9882478f, 1836.8122478f, 1840.6362478f, + 1844.4602478f, 1848.2842478f, 1867.4042478f, 1871.2282478f, + 1875.0522478f, 1878.8762478f, 1882.7002478f, 1886.5242478f, + 2133.4755029f, 2138.3570029f, 2143.2385029f, 2148.1200029f, + 2153.0015029f, 2157.8830029f, 2182.2905028f, 2187.1720028f, + 2192.0535028f, 2196.9350028f, 2201.8165028f, 2206.6980028f, + 2231.1055028f, 2235.9870028f, 2240.8685028f, 2245.7500028f, + 2250.6315028f, 2255.5130028f, 2279.9205027f, 2284.8020027f, + 2289.6835027f, 2294.5650027f, 2299.4465027f, 2304.3280027f, + 2328.7355027f, 2333.6170027f, 2338.4985027f, 2343.3800027f, + 2348.2615027f, 2353.1430027f, 2377.5505026f, 2382.4320026f, + 2387.3135026f, 2392.1950026f, 2397.0765026f, 2401.9580026f, + 2590.7467330f, 2596.6857330f, 2602.6247329f, 2608.5637329f, + 2614.5027329f, 2620.441732f, 2650.1367327f, 2656.0757327f, + 2662.0147326f, 2667.9537326f, 2673.8927326f, 2679.8317325f, + 2709.5267324f, 2715.465732f, 2721.4047323f, 2727.3437323f, + 2733.282732f, 2739.2217322f, 2768.9167321f, 2774.8557320f, + 2780.7947320f, 2786.7337320f, 2792.6727319f, 2798.6117319f, + 2828.306731f, 2834.2457317f, 2840.1847317f, 2846.1237317f, + 2852.0627316f, 2858.0017316f, 2887.6967314f, 2893.6357314f, + 2899.5747314f, 2905.5137313f, 2911.4527313f, 2917.3917313f, + 3048.0179587f, 3055.0144586f, 3062.0109585f, 3069.0074584f, + 3076.0039584f, 3083.0004583f, 3117.9829579f, 3124.9794578f, + 3131.9759578f, 3138.9724577f, 3145.9689576f, 3152.9654575f, + 3187.947957f, 3194.9444571f, 3201.9409570f, 3208.9374569f, + 3215.933956f, 3222.9304568f, 3257.9129564f, 3264.9094563f, + 3271.9059562f, 3278.9024562f, 3285.8989561f, 3292.8954560f, + 3327.8779556f, 3334.874455f, 3341.8709555f, 3348.8674554f, + 3355.8639553f, 3362.860455f, 3397.8429549f, 3404.8394548f, + 3411.8359547f, 3418.8324546f, 3425.8289546f, 3432.8254545f, + 3505.28927f, 3513.3432780f, 3521.3972781f, 3529.4512782f, + 3537.5052782f, 3545.5592783f, 3585.8292787f, 3593.8832788f, + 3601.9372788f, 3609.9912789f, 3618.0452790f, 3626.099279f, + 3666.3692794f, 3674.4232795f, 3682.4772796f, 3690.5312796f, + 3698.5852797f, 3706.6392798f, 3746.9092801f, 3754.9632802f, + 3763.0172803f, 3771.0712804f, 3779.1252804f, 3787.1792805f, + 3827.4492809f, 3835.50328f, 3843.5572810f, 3851.6112811f, + 3859.6652812f, 3867.7192812f, 3907.9892816f, 3916.0432817f, + 3924.097281f, 3932.1512818f, 3940.2052819f, 3948.2592820f, + 3962.5605113f, 3971.6720113f, 3980.783511f, 3989.8950114f, + 3999.0065114f, 4008.1180114f, 4053.6755115f, 4062.7870115f, + 4071.8985115f, 4081.0100115f, 4090.1215115f, 4099.2330115f, + 4144.7905116f, 4153.9020116f, 4163.0135116f, 4172.1250116f, + 4181.236511f, 4190.3480117f, 4235.9055117f, 4245.0170117f, + 4254.128511f, 4263.2400118f, 4272.3515118f, 4281.4630118f, + 4327.0205119f, 4336.1320119f, 4345.2435119f, 4354.3550119f, + 4363.4665119f, 4372.5780119f, 4418.1355120f, 4427.2470120f, + 4436.3585120f, 4445.4700120f, 4454.581512f, 4463.6930121f, + 4419.8317743f, 4430.0007744f, 4440.1697744f, 4450.338774f, + 4460.5077745f, 4470.6767745f, 4521.521774f, 4531.6907748f, + 4541.8597748f, 4552.0287749f, 4562.1977749f, 4572.3667750f, + 4623.2117752f, 4633.3807752f, 4643.5497753f, 4653.7187753f, + 4663.8877754f, 4674.0567754f, 4724.9017756f, 4735.0707757f, + 4745.2397757f, 4755.4087757f, 4765.5777758f, 4775.7467758f, + 4826.591776f, 4836.7607761f, 4846.9297761f, 4857.0987762f, + 4867.2677762f, 4877.4367763f, 4928.2817765f, 4938.4507765f, + 4948.6197766f, 4958.7887766f, 4968.957776f, 4979.12677675f}; + Nd4jLong _expSFF[] = { + 4, 2, 10, 6, 6, + 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, + 1, 99, + }; + NDArray expFF(_expBFF, _expSFF); + + auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); + auto weightsD = NDArrayFactory::create('c', {5, 3, 5, 5}); + auto weightsP = NDArrayFactory::create('c', {10, 15, 1, 1}); + + input.linspace(1); + weightsD.linspace(1); + weightsP.linspace(1); + weightsD.permutei({2, 3, 1, 0}); + weightsP.permutei({2, 3, 1, 0}); + + input.applyScalar(scalar::Divide, 100.0, input); + weightsD.applyScalar(scalar::Divide, 100.0, weightsD); + weightsP.applyScalar(scalar::Divide, 100.0, weightsP); + + sd::ops::sconv2d op; + + auto resultFF = op.evaluate({&input, &weightsD, &weightsP}, + {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); + + auto z = resultFF.at(0); + // z->printShapeInfo("FF shape"); + + ASSERT_TRUE(z.isSameShape(&expFF)); + + // expFF.printBuffer("e"); + // z->printBuffer("z"); + ASSERT_TRUE(z.equalsTo(&expFF, 1e-3)); } TYPED_TEST(TypedConvolutionTests1, sconv2d_3) { - auto input = NDArrayFactory::create('c', {3, 3, 8, 8}); - auto weightsD = NDArrayFactory::create('c', {1, 3, 1, 1}); - auto weightsP = NDArrayFactory::create('c', {2, 3, 1, 1}); - auto bias = NDArrayFactory::create('c', {2}); - auto output = NDArrayFactory::create('c', {3, 2, 8, 8}); - output.assign(0.0); - - input.linspace(1); - weightsD.linspace(1); - weightsP.linspace(1); - bias.linspace(1); - weightsD.permutei({2,3,1,0}); - weightsP.permutei({2,3,1,0}); - - auto expOutput = NDArrayFactory::create('c', {3, 2, 8, 8}); - - sd::ops::sconv2d op; - Nd4jStatus status = op.execute({&input, &weightsD, &weightsP, &bias}, {&output}, {1, 1, 1, 1, 0, 0, 1, 1, 0}); - auto result = op.evaluate({&input, &weightsD, &weightsP, &bias}, {1, 1, 1, 1, 0, 0, 1, 1, 0}); - - auto z = result.at(0); - - //printf("\n"); - //output.printBuffer("output"); - //z->printBuffer("z"); - - - //ASSERT_TRUE(expOutput.isSameShape(z)); + auto input = NDArrayFactory::create('c', {3, 3, 8, 8}); + auto weightsD = NDArrayFactory::create('c', {1, 3, 1, 1}); + auto weightsP = NDArrayFactory::create('c', {2, 3, 1, 1}); + auto bias = NDArrayFactory::create('c', {2}); + auto output = NDArrayFactory::create('c', {3, 2, 8, 8}); + output.assign(0.0); + + input.linspace(1); + weightsD.linspace(1); + weightsP.linspace(1); + bias.linspace(1); + weightsD.permutei({2, 3, 1, 0}); + weightsP.permutei({2, 3, 1, 0}); + + auto expOutput = NDArrayFactory::create('c', {3, 2, 8, 8}); + + sd::ops::sconv2d op; + Nd4jStatus status = op.execute({&input, &weightsD, &weightsP, &bias}, + {&output}, {1, 1, 1, 1, 0, 0, 1, 1, 0}); + auto result = op.evaluate({&input, &weightsD, &weightsP, &bias}, + {1, 1, 1, 1, 0, 0, 1, 1, 0}); + + auto z = result.at(0); + + // printf("\n"); + // output.printBuffer("output"); + // z->printBuffer("z"); + + // ASSERT_TRUE(expOutput.isSameShape(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, sconv2d_4) { - - int bS=1, iH=6,iW=6, iC=3,oC=2,mC=3, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=6,oW=6; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, - 0.798564, 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, 0.328703, - 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, 0.318416, 0.068546, 0.284533, - 0.232720, 0.352142, 0.058909, 0.711221, 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, - 0.569819, 0.445863, 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, - 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, 0.737939, 0.490079, - 0.608414, 0.956500, 0.390098, 0.147305, 0.850645, 0.497650, 0.071866, 0.082150, 0.035314, 0.732041, 0.369934, 0.840666, 0.273894, 0.431796, 0.133231}); - NDArray weightsD('c', {kH, kW, iC, mC}, {0.5340641736984253, 0.8257383108139038, 0.3279532492160797, 0.27217748761177063, 0.05432872101664543, 0.31322699785232544, 0.6599581837654114, 0.35526034235954285, 0.5765137672424316}); - NDArray weightsP('c', {1, 1, iC*mC, oC}, {0.4442146420478821, 0.3362849950790405, 0.5215804576873779, 0.5305071473121643, 0.7323054075241089, 0.5168435573577881, 0.8601323962211609, 0.2587810158729553, 0.9473239779472351, 0.39540114998817444, 0.04835261031985283, 0.8724213242530823, 0.8607604503631592, 0.8382210731506348, 0.8573186993598938, 0.6496091485023499, 0.8864102959632874, 0.14267340302467346}); - NDArray biases('c', {1,oC}, {0.8807470202445984, 0.6262521147727966}); - - NDArray expOutput('c', {bS, oC, oH, oW}, {1.643804, 2.135067, 2.494167, 2.628944, 2.700440, 2.257452, 2.562539, 2.293667, 2.493985, 2.014933, 2.301736, 2.939066, 1.492952, - 2.026476, 1.771098, 2.013162, 1.315507, 1.289951, 2.831223, 2.196924, 2.028261, 2.024326, 2.983223, 1.809527, 1.434322, 2.513157, 1.826834, 1.608869, 1.297912, 1.212318, - 2.295934, 1.844615, 2.591148, 1.597267, 2.317755, 1.755642, 1.324064, 1.542060, 1.892052, 1.939339, 1.922781, 1.720199, 1.833396, 1.728024, 1.757968, 1.410675, 1.661960, - 2.096277, 1.178815, 1.637460, 1.254187, 1.491076, 0.968625, 0.986342, 2.116042, 1.536920, 1.504321, 1.490398, 2.136795, 1.351860, 1.148578, 1.817408, 1.327139, 1.288620, - 0.962232, 0.980667, 1.623775, 1.417320, 1.845710, 1.237095, 1.762792, 1.352515}); - - sd::ops::sconv2d op; - auto results = op.evaluate({&input, &weightsD, &weightsP, &biases}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -TYPED_TEST(TypedConvolutionTests1, conv2D_BP_Bias_1) { - TypeParam _expWGradB[] = {9312.0, 12580.0, 9528.0, 13168.0, 17712.0, 13360.0, 9960.0, 13348.0, 10032.0, 13344.0, 18148.0, 13848.0, 19312.0, 26160.0, 19888.0, 15144.0, 20452.0, 15504.0}; - Nd4jLong _expWGradS[] = {4, 2, 1, 3, 3, 9, 9, 3, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - NDArray expWGrad(_expWGradB, _expWGradS); - expWGrad.permutei({2,3,1,0}); - - TypeParam _expBGradB[] = {784.0, 1296.0}; - Nd4jLong _expBGradS[] = {2, 2, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - - NDArray expBGrad(_expBGradB, _expBGradS); - - auto input = NDArrayFactory::create('c', {2, 1, 4, 4}); - auto weights = NDArrayFactory::create('c', {2, 1, 3, 3}); - auto bias = NDArrayFactory::create('c', {2, 1}); - auto epsilonNext = NDArrayFactory::create('c', {2, 2, 4, 4}); - - - TypeParam _expEpsB[] = {952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0}; - NDArray expEps(_expEpsB, input.shapeInfo()); - - input.linspace(1); - weights.linspace(1); - epsilonNext.linspace(1); - weights.permutei({2,3,1,0}); - - sd::ops::conv2d_bp op; - - auto results = op.evaluate({&input, &weights, &bias, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); - - ASSERT_TRUE(results.size() == 3); - - auto epsilon = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_TRUE(expWGrad.isSameShape(gradW)); - - //expWGrad.printBuffer("Expctd buffer"); - // gradW->printBuffer("Result buffer"); - ASSERT_TRUE(expWGrad.equalsTo(gradW)); - - - ASSERT_TRUE(input.isSameShape(epsilon)); - - // expEps.printBuffer("Expctd buffer"); - //epsilon->printBuffer("Result buffer"); - ASSERT_TRUE(expEps.equalsTo(epsilon)); - - ASSERT_TRUE(expBGrad.isSameShape(gradB)); - - ASSERT_TRUE(expBGrad.equalsTo(gradB)); - - -} - - -TYPED_TEST(TypedConvolutionTests1, conv2D_BP_NoBias_1) { - TypeParam _expWGradB[] = {9312.0, 12580.0, 9528.0, 13168.0, 17712.0, 13360.0, 9960.0, 13348.0, 10032.0, 13344.0, 18148.0, 13848.0, 19312.0, 26160.0, 19888.0, 15144.0, 20452.0, 15504.0}; - Nd4jLong _expWGradS[] = {4, 2, 1, 3, 3, 9, 9, 3, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - NDArray expWGrad(_expWGradB, _expWGradS); - expWGrad.permutei({2,3,1,0}); - - auto input = NDArrayFactory::create('c', {2, 1, 4, 4}); - auto weights = NDArrayFactory::create('c', {2, 1, 3, 3}); - auto epsilonNext = NDArrayFactory::create('c', {2, 2, 4, 4}); - - - TypeParam _expEpsB[] = {952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0}; - NDArray expEps(_expEpsB, input.shapeInfo()); - - input.linspace(1); - weights.linspace(1); - epsilonNext.linspace(1); - weights.permutei({2,3,1,0}); - - sd::ops::conv2d_bp op; - - auto results = op.evaluate({&input, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); - - ASSERT_TRUE(results.size() == 2); - - auto epsilon = results.at(0); - auto gradW = results.at(1); - - ASSERT_TRUE(expWGrad.isSameShape(gradW)); - - //expWGrad.printBuffer("Expctd buffer"); - // gradW->printBuffer("Result buffer"); - ASSERT_TRUE(expWGrad.equalsTo(gradW)); - - - ASSERT_TRUE(input.isSameShape(epsilon)); - - // expEps.printBuffer("Expctd buffer"); - //epsilon->printBuffer("Result buffer"); - ASSERT_TRUE(expEps.equalsTo(epsilon)); - - + int bS = 1, iH = 6, iW = 6, iC = 3, oC = 2, mC = 3, kH = 1, kW = 1, sH = 1, + sW = 1, pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 6, oW = 6; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input( + 'c', {bS, iC, iH, iW}, + {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, + 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, 0.798564, + 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, + 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, + 0.328703, 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, + 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, + 0.318416, 0.068546, 0.284533, 0.232720, 0.352142, 0.058909, 0.711221, + 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, + 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, 0.569819, 0.445863, + 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, + 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, + 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, + 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, + 0.737939, 0.490079, 0.608414, 0.956500, 0.390098, 0.147305, 0.850645, + 0.497650, 0.071866, 0.082150, 0.035314, 0.732041, 0.369934, 0.840666, + 0.273894, 0.431796, 0.133231}); + NDArray weightsD( + 'c', {kH, kW, iC, mC}, + {0.5340641736984253, 0.8257383108139038, 0.3279532492160797, + 0.27217748761177063, 0.05432872101664543, 0.31322699785232544, + 0.6599581837654114, 0.35526034235954285, 0.5765137672424316}); + NDArray weightsP( + 'c', {1, 1, iC * mC, oC}, + {0.4442146420478821, 0.3362849950790405, 0.5215804576873779, + 0.5305071473121643, 0.7323054075241089, 0.5168435573577881, + 0.8601323962211609, 0.2587810158729553, 0.9473239779472351, + 0.39540114998817444, 0.04835261031985283, 0.8724213242530823, + 0.8607604503631592, 0.8382210731506348, 0.8573186993598938, + 0.6496091485023499, 0.8864102959632874, 0.14267340302467346}); + NDArray biases('c', {1, oC}, {0.8807470202445984, 0.6262521147727966}); + + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {1.643804, 2.135067, 2.494167, 2.628944, 2.700440, 2.257452, 2.562539, + 2.293667, 2.493985, 2.014933, 2.301736, 2.939066, 1.492952, 2.026476, + 1.771098, 2.013162, 1.315507, 1.289951, 2.831223, 2.196924, 2.028261, + 2.024326, 2.983223, 1.809527, 1.434322, 2.513157, 1.826834, 1.608869, + 1.297912, 1.212318, 2.295934, 1.844615, 2.591148, 1.597267, 2.317755, + 1.755642, 1.324064, 1.542060, 1.892052, 1.939339, 1.922781, 1.720199, + 1.833396, 1.728024, 1.757968, 1.410675, 1.661960, 2.096277, 1.178815, + 1.637460, 1.254187, 1.491076, 0.968625, 0.986342, 2.116042, 1.536920, + 1.504321, 1.490398, 2.136795, 1.351860, 1.148578, 1.817408, 1.327139, + 1.288620, 0.962232, 0.980667, 1.623775, 1.417320, 1.845710, 1.237095, + 1.762792, 1.352515}); + + sd::ops::sconv2d op; + auto results = + op.evaluate({&input, &weightsD, &weightsP, &biases}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } -TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) { - - auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); - auto weightsD = NDArrayFactory::create('c', {5, 5, 3, 2}, {1.f, 76.f, 26.f, 101.f, 51.f, 126.f, 2.f, 77.f, 27.f, 102.f, 52.f, 127.f, 3.f, 78.f, 28.f, 103.f, 53.f, 128.f, 4.f, 79.f, 29.f, 104.f, 54.f, 129.f, 5.f, 80.f, 30.f, 105.f, 55.f, 130.f, - 6.f, 81.f, 31.f, 106.f, 56.f, 131.f, 7.f, 82.f, 32.f, 107.f, 57.f, 132.f, 8.f, 83.f, 33.f, 108.f, 58.f, 133.f, 9.f, 84.f, 34.f, 109.f, 59.f, 134.f, 10.f, 85.f, 35.f, 110.f, 60.f, 135.f, - 11.f, 86.f, 36.f, 111.f, 61.f, 136.f, 12.f, 87.f, 37.f, 112.f, 62.f, 137.f, 13.f, 88.f, 38.f, 113.f, 63.f, 138.f, 14.f, 89.f, 39.f, 114.f, 64.f, 139.f, 15.f, 90.f, 40.f, 115.f, 65.f, 140.f, - 16.f, 91.f, 41.f, 116.f, 66.f, 141.f, 17.f, 92.f, 42.f, 117.f, 67.f, 142.f, 18.f, 93.f, 43.f, 118.f, 68.f, 143.f, 19.f, 94.f, 44.f, 119.f, 69.f, 144.f, 20.f, 95.f, 45.f, 120.f, 70.f, 145.f, - 21.f, 96.f, 46.f, 121.f, 71.f, 146.f, 22.f, 97.f, 47.f, 122.f, 72.f, 147.f, 23.f, 98.f, 48.f, 123.f, 73.f, 148.f, 24.f, 99.f, 49.f, 124.f, 74.f, 149.f, 25.f, 100.f, 50.f, 125.f, 75.f, 150.f}); - auto weightsP = NDArrayFactory::create('c', {1, 1, 6, 10}, {0.0001f, 0.0007f, 0.0013f, 0.0019f, 0.0025f, 0.0031f, 0.0037f, 0.0043f, 0.0049f, 0.0055f,0.0002f, 0.0008f, 0.0014f, 0.0020f, 0.0026f, 0.0032f, 0.0038f, 0.0044f, 0.0050f, 0.0056f, - 0.0003f, 0.0009f, 0.0015f, 0.0021f, 0.0027f, 0.0033f, 0.0039f, 0.0045f, 0.0051f, 0.0057f,0.0004f, 0.0010f, 0.0016f, 0.0022f, 0.0028f, 0.0034f, 0.0040f, 0.0046f, 0.0052f, 0.0058f, - 0.0005f, 0.0011f, 0.0017f, 0.0023f, 0.0029f, 0.0035f, 0.0041f, 0.0047f, 0.0053f, 0.0059f,0.0006f, 0.0012f, 0.0018f, 0.0024f, 0.0030f, 0.0036f, 0.0042f, 0.0048f, 0.0054f, 0.0060f}); - - auto expFF = NDArrayFactory::create('c', {2, 6, 6, 6}, {10025.0f,10350.0f,10675.0f,11000.0f,11325.0f,11650.0f,13275.0f,13600.0f,13925.0f,14250.0f,14575.0f,14900.0f,16525.0f,16850.0f, - 17175.0f,17500.0f,17825.0f,18150.0f,19775.0f,20100.0f,20425.0f,20750.0f,21075.0f,21400.0f,23025.0f,23350.0f,23675.0f,24000.0f, - 24325.0f,24650.0f,26275.0f,26600.0f,26925.0f,27250.0f,27575.0f,27900.0f,53150.0f,55350.0f,57550.0f,59750.0f,61950.0f,64150.0f, - 75150.0f,77350.0f,79550.0f,81750.0f,83950.0f,86150.0f,97150.0f,99350.0f,101550.0f,103750.0f,105950.0f,108150.0f,119150.0f, - 121350.0f,123550.0f,125750.0f,127950.0f,130150.0f,141150.0f,143350.0f,145550.0f,147750.0f,149950.0f,152150.0f,163150.0f, - 165350.0f,167550.0f,169750.0f,171950.0f,174150.0f,119400.0f,120350.0f,121300.0f,122250.0f,123200.0f,124150.0f,128900.0f, - 129850.0f,130800.0f,131750.0f,132700.0f,133650.0f,138400.0f,139350.0f,140300.0f,141250.0f,142200.0f,143150.0f,147900.0f, - 148850.0f,149800.0f,150750.0f,151700.0f,152650.0f,157400.0f,158350.0f,159300.0f,160250.0f,161200.0f,162150.0f,166900.0f, - 167850.0f,168800.0f,169750.0f,170700.0f,171650.0f,350025.0f,352850.0f,355675.0f,358500.0f,361325.0f,364150.0f,378275.0f, - 381100.0f,383925.0f,386750.0f,389575.0f,392400.0f,406525.0f,409350.0f,412175.0f,415000.0f,417825.0f,420650.0f,434775.0f, - 437600.0f,440425.0f,443250.0f,446075.0f,448900.0f,463025.0f,465850.0f,468675.0f,471500.0f,474325.0f,477150.0f,491275.0f, - 494100.0f,496925.0f,499750.0f,502575.0f,505400.0f,353775.0f,355350.0f,356925.0f,358500.0f,360075.0f,361650.0f,369525.0f, - 371100.0f,372675.0f,374250.0f,375825.0f,377400.0f,385275.0f,386850.0f,388425.0f,390000.0f,391575.0f,393150.0f,401025.0f, - 402600.0f,404175.0f,405750.0f,407325.0f,408900.0f,416775.0f,418350.0f,419925.0f,421500.0f,423075.0f,424650.0f,432525.0f, - 434100.0f,435675.0f,437250.0f,438825.0f,440400.0f,771900.0f,775350.0f,778800.0f,782250.0f,785700.0f,789150.0f,806400.0f, - 809850.0f,813300.0f,816750.0f,820200.0f,823650.0f,840900.0f,844350.0f,847800.0f,851250.0f,854700.0f,858150.0f,875400.0f, - 878850.0f,882300.0f,885750.0f,889200.0f,892650.0f,909900.0f,913350.0f,916800.0f,920250.0f,923700.0f,927150.0f,944400.0f, - 947850.0f,951300.0f,954750.0f,958200.0f,961650.0f,107525.0f,107850.0f,108175.0f,108500.0f,108825.0f,109150.0f,110775.0f, - 111100.0f,111425.0f,111750.0f,112075.0f,112400.0f,114025.0f,114350.0f,114675.0f,115000.0f,115325.0f,115650.0f,117275.0f, - 117600.0f,117925.0f,118250.0f,118575.0f,118900.0f,120525.0f,120850.0f,121175.0f,121500.0f,121825.0f,122150.0f,123775.0f, - 124100.0f,124425.0f,124750.0f,125075.0f,125400.0f,713150.0f,715350.0f,717550.0f,719750.0f,721950.0f,724150.0f,735150.0f, - 737350.0f,739550.0f,741750.0f,743950.0f,746150.0f,757150.0f,759350.0f,761550.0f,763750.0f,765950.0f,768150.0f,779150.0f, - 781350.0f,783550.0f,785750.0f,787950.0f,790150.0f,801150.0f,803350.0f,805550.0f,807750.0f,809950.0f,812150.0f,823150.0f, - 825350.0f,827550.0f,829750.0f,831950.0f,834150.0f,404400.0f,405350.0f,406300.0f,407250.0f,408200.0f,409150.0f,413900.0f, - 414850.0f,415800.0f,416750.0f,417700.0f,418650.0f,423400.0f,424350.0f,425300.0f,426250.0f,427200.0f,428150.0f,432900.0f,433850.0f,434800.0f,435750.0f,436700.0f,437650.0f,442400.0f,443350.0f,444300.0f,445250.0f,446200.0f,447150.0f,451900.0f,452850.0f,453800.0f,454750.0f,455700.0f,456650.0f,1197525.0f,1200350.0f,1203175.0f,1206000.0f,1208825.0f,1211650.0f,1225775.0f,1228600.0f,1231425.0f,1234250.0f,1237075.0f,1239900.0f,1254025.0f,1256850.0f,1259675.0f,1262500.0f,1265325.0f,1268150.0f,1282275.0f,1285100.0f,1287925.0f,1290750.0f,1293575.0f,1296400.0f,1310525.0f,1313350.0f,1316175.0f,1319000.0f,1321825.0f,1324650.0f,1338775.0f,1341600.0f,1344425.0f,1347250.0f,1350075.0f,1352900.0f,826275.0f,827850.0f,829425.0f,831000.0f,832575.0f,834150.0f,842025.0f,843600.0f,845175.0f,846750.0f,848325.0f,849900.0f,857775.0f,859350.0f,860925.0f,862500.0f,864075.0f,865650.0f,873525.0f,875100.0f,876675.0f,878250.0f,879825.0f,881400.0f,889275.0f,890850.0f,892425.0f,894000.0f,895575.0f,897150.0f,905025.0f,906600.0f,908175.0f,909750.0f,911325.0f,912900.0f,1806900.0f,1810350.0f,1813800.0f,1817250.0f,1820700.0f,1824150.0f,1841400.0f,1844850.0f,1848300.0f,1851750.0f,1855200.0f,1858650.0f,1875900.0f,1879350.0f,1882800.0f,1886250.0f,1889700.0f,1893150.0f,1910400.0f,1913850.0f,1917300.0f,1920750.0f,1924200.0f,1927650.0f,1944900.0f,1948350.0f,1951800.0f,1955250.0f,1958700.0f,1962150.0f,1979400.0f,1982850.0f,1986300.0f,1989750.0f,1993200.0f,1996650.f}); - auto exp2FF = NDArrayFactory::create('c', {2, 10, 6, 6}, {827.4900282f,832.2350283f,836.9800284f,841.725028f,846.4700287f,851.2150288f,874.9400293f,879.6850294f,884.4300295f,889.1750296f,893.9200297f,898.665029f, - 922.3900304f,927.1350305f,931.8800306f,936.6250307f,941.3700308f,946.1150309f,969.8400315f,974.5850316f,979.3300317f,984.0750318f,988.8200319f,993.5650320f, - 1017.2900326f,1022.0350327f,1026.7800328f,1031.5250329f,1036.2700330f,1041.0150331f,1064.7400337f,1069.4850338f,1074.2300339f,1078.9750340f,1083.7200341f, - 1088.4650342f,1822.4550553f,1833.995055f,1845.5350558f,1857.075056f,1868.6150563f,1880.1550566f,1937.8550578f,1949.3950581f,1960.9350583f,1972.4750586f, - 1984.015058f,1995.5550591f,2053.2550604f,2064.7950606f,2076.3350609f,2087.8750611f,2099.4150614f,2110.955061f,2168.6550629f,2180.1950632f,2191.7350634f, - 2203.2750637f,2214.8150639f,2226.3550642f,2284.0550655f,2295.5950657f,2307.1350660f,2318.6750662f,2330.2150665f,2341.7550667f,2399.4550680f,2410.9950683f, - 2422.5350685f,2434.0750688f,2445.6150690f,2457.1550693f,2817.419968f,2835.7549686f,2854.0899683f,2872.4249680f,2890.7599677f,2909.0949674f,3000.7699660f, - 3019.104965f,3037.4399655f,3055.7749652f,3074.1099649f,3092.4449646f,3184.1199632f,3202.4549629f,3220.789962f,3239.1249624f,3257.4599621f,3275.7949618f, - 3367.4699604f,3385.8049601f,3404.1399598f,3422.474959f,3440.8099593f,3459.1449590f,3550.8199576f,3569.1549573f,3587.4899570f,3605.8249567f,3624.1599565f, - 3642.4949562f,3734.1699548f,3752.5049545f,3770.8399542f,3789.1749539f,3807.5099536f,3825.8449534f,3812.385098f,3837.5150988f,3862.6450994f,3887.7751000f, - 3912.9051006f,3938.0351012f,4063.6851041f,4088.8151047f,4113.9451053f,4139.0751059f,4164.2051065f,4189.3351071f,4314.9851100f,4340.1151106f,4365.2451112f, - 4390.3751118f,4415.5051124f,4440.6351130f,4566.2851159f,4591.4151165f,4616.5451171f,4641.6751177f,4666.805118f,4691.9351188f,4817.5851218f,4842.7151224f, - 4867.8451230f,4892.975123f,4918.1051241f,4943.2351247f,5068.8851277f,5094.0151283f,5119.1451288f,5144.2751294f,5169.4051300f,5194.5351306f,4807.3499803f, - 4839.2749801f,4871.1999799f,4903.1249797f,4935.0499795f,4966.9749793f,5126.5999784f,5158.5249782f,5190.4499780f,5222.3749778f,5254.2999777f,5286.2249775f, - 5445.8499765f,5477.774976f,5509.6999762f,5541.6249760f,5573.5499758f,5605.4749756f,5765.0999747f,5797.0249745f,5828.9499743f,5860.8749741f,5892.7999739f, - 5924.724973f,6084.3499728f,6116.2749726f,6148.1999724f,6180.1249723f,6212.0499721f,6243.9749719f,6403.59997f,6435.5249708f,6467.4499706f,6499.3749704f, - 6531.2999702f,6563.2249700f,5802.3150007f,5841.0350006f,5879.7550005f,5918.4750004f,5957.195000f,5995.9150003f,6189.5149999f,6228.2349998f,6266.9549997f, - 6305.6749996f,6344.3949995f,6383.114999f,6576.7149990f,6615.4349990f,6654.1549989f,6692.8749988f,6731.5949987f,6770.3149986f,6963.9149982f,7002.6349981f, - 7041.3549981f,7080.0749980f,7118.7949979f,7157.5149978f,7351.1149974f,7389.8349973f,7428.5549972f,7467.2749972f,7505.9949971f,7544.7149970f,7738.3149966f,7777.0349965f,7815.7549964f,7854.4749963f,7893.1949963f,7931.9149962f,6797.2799488f,6842.794948f,6888.3099489f,6933.8249490f,6979.3399491f,7024.8549492f,7252.4299497f,7297.9449498f,7343.4599499f,7388.9749500f,7434.489950f,7480.0049501f,7707.5799506f,7753.0949507f,7798.6099508f,7844.1249509f,7889.6399510f,7935.1549511f,8162.7299515f,8208.2449516f,8253.7599517f,8299.2749518f,8344.7899519f,8390.3049520f,8617.8799525f,8663.394952f,8708.9099526f,8754.4249527f,8799.9399528f,8845.4549529f,9073.0299534f,9118.5449535f,9164.0599536f,9209.5749537f,9255.089953f,9300.604953f,7792.2451647f,7844.5551655f,7896.8651663f,7949.1751671f,8001.4851679f,8053.7951686f,8315.3451725f,8367.6551733f,8419.9651741f,8472.2751749f,8524.585175f,8576.8951764f,8838.4451803f,8890.7551811f,8943.0651819f,8995.3751827f,9047.6851834f,9099.9951842f,9361.5451881f,9413.8551889f,9466.1651897f,9518.475190f,9570.7851912f,9623.0951920f,9884.6451959f,9936.9551967f,9989.2651975f,10041.5751982f,10093.8851990f,10146.1951998f,10407.7452037f,10460.0552045f,10512.3652053f,10564.6752060f,10616.9852068f,10669.2952076f,8787.210074f,8846.3150748f,8905.4200750f,8964.5250752f,9023.6300755f,9082.7350757f,9378.2600768f,9437.3650770f,9496.4700773f,9555.5750775f,9614.6800777f,9673.7850779f,9969.3100791f,10028.4150793f,10087.5200795f,10146.625079f,10205.7300800f,10264.8350802f,10560.3600813f,10619.465081f,10678.5700818f,10737.6750820f,10796.7800822f,10855.8850825f,11151.4100836f,11210.5150838f,11269.6200840f,11328.7250843f,11387.8300845f,11446.9350847f,11742.4600858f,11801.5650861f,11860.6700863f,11919.7750865f,11978.880086f,12037.9850870f,9782.1750935f,9848.0750935f,9913.9750934f,9979.8750934f,10045.7750934f,10111.6750933f,10441.1750931f,10507.0750931f,10572.9750931f,10638.8750930f,10704.7750930f,10770.6750930f,11100.1750928f,11166.0750927f,11231.9750927f,11297.8750927f,11363.7750926f,11429.6750926f,11759.1750924f,11825.0750924f,11890.9750923f,11956.8750923f,12022.7750923f,12088.6750922f,12418.175092f,12484.0750920f,12549.9750920f,12615.8750919f,12681.7750919f,12747.6750919f,13077.1750917f,13143.0750916f,13208.9750916f,13274.8750916f,13340.7750915f,13406.6750915f,2250.990060f,2255.7350610f,2260.4800611f,2265.2250612f,2269.9700613f,2274.7150614f,2298.4400619f,2303.185062f,2307.9300622f,2312.6750623f,2317.4200624f,2322.1650625f,2345.8900630f,2350.6350631f,2355.380063f,2360.1250634f,2364.8700635f,2369.6150636f,2393.3400641f,2398.0850642f,2402.8300643f,2407.5750644f,2412.320064f,2417.0650647f,2440.7900652f,2445.5350653f,2450.2800654f,2455.0250655f,2459.7700656f,2464.515065f,2488.2400663f,2492.9850664f,2497.7300665f,2502.4750666f,2507.2200667f,2511.9650668f,5284.4551315f,5295.9951318f,5307.535132f,5319.0751323f,5330.6151326f,5342.1551328f,5399.8551341f,5411.3951343f,5422.9351346f,5434.475134f,5446.0151351f,5457.5551354f,5515.2551366f,5526.7951369f,5538.3351371f,5549.8751374f,5561.4151376f,5572.9551379f,5630.6551392f,5642.1951394f,5653.7351397f,5665.2751399f,5676.8151402f,5688.3551404f,5746.0551417f,5757.5951420f,5769.1351422f,5780.6751425f,5792.2151427f,5803.7551430f,5861.455144f,5872.9951445f,5884.5351448f,5896.0751450f,5907.6151453f,5919.1551455f,8317.919884f,8336.2548841f,8354.5898838f,8372.9248835f,8391.2598832f,8409.59488f,8501.2698815f,8519.6048813f,8537.9398810f,8556.2748807f,8574.6098804f,8592.9448801f,8684.6198787f,8702.9548784f,8721.2898782f,8739.6248779f,8757.9598776f,8776.2948773f,8867.9698759f,8886.3048756f,8904.6398753f,8922.9748751f,8941.3098748f,8959.6448745f,9051.3198731f,9069.6548728f,9087.9898725f,9106.3248722f,9124.6598720f,9142.9948717f,9234.6698703f,9253.0048700f,9271.3398697f,9289.6748694f,9308.0098691f,9326.3448689f,11351.3852747f,11376.5152753f,11401.6452759f,11426.7752765f,11451.9052771f,11477.0352777f,11602.6852806f,11627.8152812f,11652.9452818f,11678.0752824f,11703.2052830f,11728.335283f,11853.9852865f,11879.1152871f,11904.2452877f,11929.3752883f,11954.505288f,11979.6352894f,12105.2852924f,12130.4152930f,12155.545293f,12180.6752941f,12205.8052947f,12230.9352953f,12356.5852983f,12381.715298f,12406.8452994f,12431.9753000f,12457.1053006f,12482.2353012f,12607.8853041f,12633.0153047f,12658.1453053f,12683.2753059f,12708.4053065f,12733.5353071f,14384.8499244f,14416.7749242f,14448.6999240f,14480.6249238f,14512.549923f,14544.4749235f,14704.0999225f,14736.024922f,14767.9499222f,14799.8749220f,14831.7999218f,14863.7249216f,15023.3499207f,15055.2749205f,15087.1999203f,15119.1249201f,15151.0499199f,15182.9749197f,15342.5999188f,15374.5249186f,15406.4499184f,15438.374918f,15470.2999181f,15502.2249179f,15661.84991f,15693.7749168f,15725.6999166f,15757.6249164f,15789.5499162f,15821.4749160f,15981.0999151f,16013.0249149f,16044.9499147f,16076.8749145f,16108.7999143f,16140.7249142f,17418.314976f,17457.0349761f,17495.7549760f,17534.4749759f,17573.1949758f,17611.9149757f,17805.5149753f,17844.234975f,17882.9549752f,17921.6749751f,17960.3949750f,17999.1149749f,18192.7149745f,18231.4349744f,18270.154974f,18308.8749743f,18347.5949742f,18386.3149741f,18579.9149737f,18618.6349736f,18657.3549735f,18696.074973f,18734.7949734f,18773.5149733f,18967.1149729f,19005.8349728f,19044.5549727f,19083.2749726f,19121.994972f,19160.7149725f,19354.3149721f,19393.0349720f,19431.7549719f,19470.4749718f,19509.1949717f,19547.914971f,20451.7799765f,20497.2949766f,20542.8099767f,20588.3249768f,20633.8399769f,20679.3549770f,20906.929977f,20952.4449775f,20997.9599776f,21043.4749777f,21088.9899778f,21134.5049779f,21362.0799784f,21407.5949785f,21453.1099786f,21498.624978f,21544.139978f,21589.6549788f,21817.2299793f,21862.7449794f,21908.2599795f,21953.7749796f,21999.2899797f,22044.8049798f,22272.3799802f,22317.8949803f,22363.4099804f,22408.9249805f,22454.4399806f,22499.9549807f,22727.529981f,22773.044981f,22818.5599813f,22864.0749814f,22909.5899815f,22955.1049816f,23485.2453985f,23537.555399f,23589.8654000f,23642.1754008f,23694.4854016f,23746.7954024f,24008.3454063f,24060.655407f,24112.9654078f,24165.2754086f,24217.5854094f,24269.8954102f,24531.4454141f,24583.7554148f,24636.0654156f,24688.3754164f,24740.6854172f,24792.99541f,25054.545421f,25106.8554226f,25159.1654234f,25211.4754242f,25263.7854250f,25316.0954257f,25577.6454296f,25629.9554304f,25682.2654312f,25734.5754320f,25786.8854328f,25839.1954335f,26100.7454374f,26153.0554382f,26205.3654390f,26257.6754398f,26309.985440f,26362.2954413f,26518.7101423f,26577.8151425f,26636.920142f,26696.0251430f,26755.1301432f,26814.2351434f,27109.7601446f,27168.8651448f,27227.9701450f,27287.0751452f,27346.1801455f,27405.2851457f,27700.8101468f,27759.9151470f,27819.0201473f,27878.1251475f,27937.2301477f,27996.33514f,28291.8601491f,28350.9651493f,28410.0701495f,28469.175149f,28528.2801500f,28587.3851502f,28882.9101513f,28942.0151516f,29001.1201518f,29060.2251520f,29119.3301522f,29178.4351525f,29473.9601536f,29533.0651538f,29592.1701540f,29651.2751543f,29710.3801545f,29769.4851547f,29552.1750826f,29618.0750825f,29683.9750825f,29749.8750825f,29815.7750824f,29881.6750824f,30211.1750822f,30277.0750822f,30342.9750821f,30408.8750821f,30474.7750821f,30540.6750820f,30870.175081f,30936.0750818f,31001.9750818f,31067.8750817f,31133.7750817f,31199.6750817f,31529.1750815f,31595.075081f,31660.9750814f,31726.8750814f,31792.7750813f,31858.6750813f,32188.1750811f,32254.0750811f,32319.975081f,32385.8750810f,32451.7750810f,32517.6750809f,32847.1750808f,32913.0750807f,32978.9750807f,33044.875080f,33110.7750806f,33176.67508062f}); - - input.linspace(1); - - sd::ops::sconv2d op; - auto resultFF = op.evaluate({&input, &weightsD}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); - - auto z = resultFF.at(0); - - ASSERT_TRUE(z.isSameShape(&expFF)); - ASSERT_TRUE(z.equalsTo(&expFF, 1)); - - - sd::ops::conv2d op2d; - // weightsP.printShapeInfo(); - auto result2D = op2d.evaluate({&z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); - - auto z2d = result2D.at(0); - // z2d->printBuffer(); - - ASSERT_TRUE(z2d.isSameShape(&exp2FF)); - ASSERT_TRUE(z2d.equalsTo(&exp2FF)); -} - -TEST_F(ConvolutionTests1, deconv2d_bp_1) { - - int bS=3, iH=4,iW=4, iC=3,oC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=4; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, sd::DataType::FLOAT32); - NDArray weights('c',{kH,kW,oC,iC}, {1,3,5,2,4,6}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oH, oW},sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iC, iH, iW}, {35.f, 38.f, 41.f, 44.f, 47.f, 50.f, 53.f, 56.f, 59.f, 62.f, 65.f, 68.f, 71.f, 74.f, - 77.f, 80.f, 71.f, 78.f, 85.f, 92.f, 99.f, 106.f, 113.f, 120.f, 127.f, 134.f, 141.f, 148.f, 155.f, 162.f, 169.f, - 176.f, 107.f, 118.f, 129.f, 140.f, 151.f, 162.f, 173.f, 184.f, 195.f, 206.f, 217.f, 228.f, 239.f, 250.f, 261.f, 272.f, - 131.f, 134.f, 137.f, 140.f, 143.f, 146.f, 149.f, 152.f, 155.f, 158.f, 161.f, 164.f, 167.f, 170.f, 173.f, 176.f, 295.f, - 302.f, 309.f, 316.f, 323.f, 330.f, 337.f, 344.f, 351.f, 358.f, 365.f, 372.f, 379.f, 386.f, 393.f, 400.f, 459.f, 470.f, - 481.f, 492.f, 503.f, 514.f, 525.f, 536.f, 547.f, 558.f, 569.f, 580.f, 591.f, 602.f, 613.f, 624.f, 227.f, 230.f, 233.f, - 236.f, 239.f, 242.f, 245.f, 248.f, 251.f, 254.f, 257.f, 260.f, 263.f, 266.f, 269.f, 272.f, 519.f, 526.f, 533.f, 540.f, - 547.f, 554.f, 561.f, 568.f, 575.f, 582.f, 589.f, 596.f, 603.f, 610.f, 617.f, 624.f, 811.f, 822.f, 833.f, 844.f, 855.f, - 866.f, 877.f, 888.f, 899.f, 910.f, 921.f, 932.f, 943.f, 954.f, 965.f, 976.f}, sd::DataType::FLOAT32); - NDArray expGradW('c', {kH, kW, oC, iC}, {160008., 191112., 222216., 203400., 246792., 290184.f}, sd::DataType::FLOAT32); - NDArray expGradB('c', {oC}, {1944.f, 2712.f}, sd::DataType::FLOAT32); - - input.linspace(1); - bias.linspace(1); - gradO.linspace(1); - - - sd::ops::deconv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, deconv2d_bp_2) { - - int bS=3, iH=4,iW=4, iC=3,oC=2, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=4; // 5,4 - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); - NDArray weights('c',{iC, oC, kH, kW}, {1., 7., 2., 10., 3., 8., 4., 11., 5., 9., 6., 12.}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oH, oW},sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iC, iH, iW}, {-77.400002, -77.199997, -77., -76.800003, -76.599998, -76.400002, -76.200005, -76., -75.800003, -75.599998, -75.399994, - -75.199997, -11.32, -11.29, -11.26, -11.23, -100.839996, -100.580002, -100.32, -100.059998, -99.800003, -99.540001, -99.279999, -99.019997, -98.760002, -98.50, - -98.240005, -97.979996, -26.52, -26.450001, -26.380001, -26.309999, -124.279999, -123.959991, -123.639999, -123.32, -123., -122.68, -122.360001, -122.040001, - -121.720001, -121.400009, -121.080002, -120.759995, -41.720001, -41.610001, -41.50, -41.389999, -71., -70.800003, -70.599998, -70.399994, -70.199997, -70., -69.800003, -69.600006, -69.400002, -69.199997, -69., -68.799995, -10.360001, -10.33, -10.30, -10.27, -92.519997, -92.260002, -92., -91.740005, -91.479996, -91.220001, -90.960007, -90.700005, -90.440002, -90.18, -89.919998, -89.660004, -24.280001, -24.209999, -24.139999, -24.07, -114.040001, -113.720001, -113.400009, -113.080002, -112.759995, -112.440002, -112.120003, -111.800003, -111.480003, -111.159996, -110.839996, -110.520004, -38.200001, -38.09, -37.980003, -37.869999, -64.599998, -64.400002, -64.199997, -64., -63.799995, -63.599998, -63.400002, -63.199997, -63., -62.799995, -62.599998, -62.400002, -9.40, -9.37, -9.34, -9.309999, -84.200005, -83.940002, -83.68, -83.419998, -83.160004, -82.900002, -82.639999, -82.379997, -82.119995, -81.860001, -81.600006, -81.339996, -22.040001, -21.970001, -21.90, -21.83, -103.800003, -103.480003, -103.159996, -102.839996, -102.520004, -102.200005, -101.879997, -101.559998, -101.239998, -100.919998, -100.599998, -100.279999, -34.68, -34.57, -34.459999, -34.349998}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {iC, oC, kH, kW}, {-3010.799805, -2502.420410, -2899.439209, -2407.380615, -242.159332, -437.460510, -253.680466, -434.580048, 2526.479980, 1627.500000, 2392.079834, 1538.220093}, sd::DataType::FLOAT32); - NDArray expGradB('c', {oC}, {-173.040009, -165.360016}, sd::DataType::FLOAT32); - - input.linspace(70., -1); - gradO.linspace(-4, 0.01); - - sd::ops::deconv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, deconv2d_bp_3) { - - int bS=3, iH=4,iW=4, iC=3,oC=2, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=5,oW=4; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); - NDArray weights('c',{iC, kH, kW, oC}, {1., 4., 7., 10., 2., 5., 8., 11., 3., 6., 9., 12.}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iH, iW, iC}, {-86.5, -102.320007, -118.139999, -86.060005, -101.800003, -117.540001, -85.619995, -101.279999, -116.940002, -85.18, - -100.759995, -116.339996, -84.740005, -100.239998, -115.739998, -84.300003, -99.720001, -115.139999, -83.860001, -99.199997, -114.539993, -83.419998, -98.68, - -113.939995, -82.979996, -98.160004, -113.339996, -82.539993, -97.639999, -112.739998, -82.099998, -97.120003, -112.139999, -81.660004, -96.600006, -111.539993, - -81.220001, -96.080002, -110.939995, -80.779999, -95.559998, -110.340012, -80.340004, -95.040001, -109.740005, -79.900002, -94.519997, -109.139992, -77.699997, - -91.919998, -106.139999, -77.260002, -91.400002, -105.540001, -76.820007, -90.880005, -104.940002, -76.380005, -90.360001, -104.339996, -75.940002, -89.839996, -103.740005, -75.5, -89.320007, -103.139999, -75.060005, -88.800003, -102.540001, -74.619995, -88.279999, -101.940002, -74.18, -87.759995, -101.339996, -73.740005, -87.239998, -100.739998, -73.300003, -86.720001, -100.139999, -72.860001, -86.199997, -99.539993, -72.419998, -85.68, -98.939995, -71.979996, -85.160004, -98.339996, -71.539993, -84.639999, -97.740005, -71.099998, -84.120003, -97.139999, -68.899994, -81.519997, -94.139999, -68.459999, -81.00, -93.539993, -68.019997, -80.479996, -92.940002, -67.580002, -79.959999, -92.339996, -67.139999, -79.440002, -91.740005, -66.699997, -78.919998, -91.139999, -66.260002, -78.399994, -90.540001, -65.820007, -77.880005, -89.940002, -65.380005, -77.360001, -89.339996, -64.940002, -76.839996, -88.740005, -64.5, -76.320007, -88.139999, -64.060005, -75.800003, -87.540001, -63.619995, -75.279999, -86.940002, -63.18, -74.759995, -86.339996, -62.739998, -74.239998, -85.739998, -62.299999, -73.720001, -85.139999}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {iC, kH, kW, oC}, {-592.800110, -593.039917, -594.719116, -594.960266, -427.199890, -427.919617, -432.959900, -433.679993, -261.600281, -262.799591, -271.200317, -272.399536}, sd::DataType::FLOAT32); - NDArray expGradB('c', {oC}, {-204.600006, -204.}, sd::DataType::FLOAT32); - - input.linspace(70., -1); - gradO.linspace(-4, 0.01); - - sd::ops::deconv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); -} - -TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) { - auto input = NDArrayFactory::create('c', {2, 2, 6}); - auto weights = NDArrayFactory::create('c', {2, 2, 3}, {1,5,9,3,7,11,2,6,10,4,8,12}); - auto bias = NDArrayFactory::create('c', {3}); - auto expFF = NDArrayFactory::create('c', {2, 3, 5}, {59.0f, 69.0f, 79.0f, 89.0f, 99.0f, 132.0f, 158.0f, 184.0f, 210.0f, 236.0f, 205.0f, 247.0f, 289.0f, 331.0f, 373.0f, 179.0f, 189.0f, 199.0f, 209.0f, 219.0f, 444.0f, 470.0f, 496.0f, 522.0f, 548.0f, 709.0f, 751.0f, 793.0f, 835.0f, 877.0f}); - auto expEps = NDArrayFactory::create('c', {2, 2, 6}, {130.0f, 293.0f, 326.0f, 359.0f, 392.0f, 220.0f, 166.0f, 371.0f, 416.0f, 461.0f, 506.0f, 280.0f, 355.0f, 788.0f, 821.0f, 854.0f, 887.0f, 490.0f, 481.0f, 1046.0f, 1091.0f, 1136.0f, 1181.0f, 640.0f}); - auto expGW = NDArrayFactory::create('c', {3, 2, 2}, {1415.0f, 1520.0f, 2045.0f, 2150.0f, 1865.0f, 2020.0f, 2795.0f, 2950.0f, 2315.0f, 2520.0f, 3545.0f, 3750.0f}); - auto expGB = NDArrayFactory::create('c', {3}, {105.0f, 155.0f, 205.0f}); - - expGW.permutei({2,1,0}); - input.linspace(1); - bias.linspace(1); - - sd::ops::conv1d op; - auto result_FF = op.evaluate({&input, &weights, &bias}, {}, {2, 1, 0, 1, 0, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, result_FF.status()); - - auto z = result_FF.at(0); - - ASSERT_TRUE(expFF.isSameShape(z)); - ASSERT_TRUE(expFF.equalsTo(z)); - - sd::ops::conv1d_bp op_bp; - - auto epsilonNxt = new NDArray(z.dup()); - epsilonNxt->linspace(1); - - auto result_BP = op_bp.evaluate({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 1, 0, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result_BP.status()); - - auto eps = result_BP.at(0); - auto gradW = result_BP.at(1); - auto gradB = result_BP.at(2); - - ASSERT_TRUE(expEps.isSameShape(eps)); - ASSERT_TRUE(expGW.isSameShape(gradW)); - ASSERT_TRUE(expGB.isSameShape(gradB)); - - ASSERT_TRUE(expEps.equalsTo(eps)); - ASSERT_TRUE(expGW.equalsTo(gradW)); - ASSERT_TRUE(expGB.equalsTo(gradB)); - - delete epsilonNxt; -} - - -TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) { - auto input = NDArrayFactory::create('c', {2, 2, 6}); - auto weights = NDArrayFactory::create('c', {2, 2, 3}, {1.f, 5.f, 9.f, 3.f, 7.f, 11.f, 2.f, 6.f, 10.f, 4.f, 8.f, 12.f}); - - input.linspace(1); - - sd::ops::conv1d op; - auto result = op.evaluate({&input, &weights}, {}, {2, 1, 0, 1, 1,0}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_1) { - - int bS=2, iW=3, iC=4,oC=3, kW=2, sW=1, pW=0, dW=1; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iW, iC}); - NDArray weights('c', {kW, iC, oC}); - NDArray bias('c', {oC}, {-1,-2,-3}); - - NDArray expOutput('c', {bS, oW, oC}, {18. , 18. , 18. , 53. , 55.6, 58.2, 89.8, 95.6, 101.4, 102. , 106.8, 111.6, 163.4, 175.6, 187.8, 200.2, 215.6, 231.}); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_2) { - - int bS=2, iW=16, iC=3,oC=4, kW=2, sW=2, pW=0, dW=1; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iW, iC}); - NDArray weights('c', {kW, iC, oC}); - NDArray bias('c', {oC}, {-1,-2,-3,-4}); - - NDArray expOutput('c', {bS, oW, oC}, { 10. , 9.6, 9.2, 8.8, 48.9, 51.8, 54.7, 57.6, 88.5, 95. , 101.5, 108. , 128.1, 138.2, 148.3, 158.4, - 167.7, 181.4, 195.1, 208.8, 207.3, 224.6, 241.9, 259.2, 246.9, 267.8, 288.7, 309.6, 286.5, 311. , 335.5, 360. , - 254.8, 268.8, 282.8, 296.8, 365.7, 397.4, 429.1, 460.8, 405.3, 440.6, 475.9, 511.2, 444.9, 483.8, 522.7, 561.6, - 484.5, 527. , 569.5, 612. , 524.1, 570.2, 616.3, 662.4, 563.7, 613.4, 663.1, 712.8, 603.3, 656.6, 709.9, 763.2}); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_3) { - - int bS=2, iW=16, iC=3,oC=4, kW=3, sW=3, pW=0, dW=1; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iW, iC}); - NDArray weights('c', {kW, iC, oC}); - NDArray bias('c', {oC}, {-1,-2,-3,-4}); - - NDArray expOutput('c', {bS, oW, oC}, {17.2, 16.8, 16.4, 16.,145.4, 151.6, 157.8, 164.,283.1, 297.4, 311.7, 326., 420.8, 443.2, 465.6, 488., - 558.5, 589., 619.5, 650.,696.2001, 734.8, 773.4, 812., 434.8, 448.8, 462.8, 476.8, 879.8, 929.2, 978.6, 1028., - 1017.5, 1075., 1132.5, 1190.,1155.2001, 1220.8, 1286.4, 1352.,1292.8999, 1366.6, 1440.3, 1514., 1430.6001, 1512.4, 1594.2, 1676.}); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_4) { - - int bS=2, iW=8, iC=3,oC=4, kW=3, sW=1, pW=0, dW=3; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iW, iC}); - NDArray weights('c', {kW, iC, oC}); - NDArray bias('c', {oC}, {-1,-2,-3,-4}); - - NDArray expOutput('c', {bS, oW, oC}, {17.2, 16.8, 16.4, 16. ,43.3, 43.8, 44.3, 44.8,69.4, 70.8, 72.2, 73.6,106.5, 109.4, 112.3, 115.2,147.9, 152.6, 157.3, 162. ,189.3, 195.8, 202.3, - 208.8,234.5, 243.4, 252.3, 261.2,280.4, 292. , 303.6, 315.2, 226. , 232.8, 239.6, 246.4, 252.1, 259.8, 267.5, 275.2,278.2, 286.8, 295.4, 304. ,437.7, - 455. , 472.3, 489.6,479.1, 498.2, 517.3, 536.4,520.5, 541.4, 562.3, 583.2, 601.7, 632.2, 662.7, 693.2, 647.6, 680.8, 714. , 747.2}); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_5) { - - int bS=2, iW=8, iC=3,oC=4, kW=3, sW=1, pW=0, dW=3; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iW}); - NDArray weights('c', {kW, iC, oC}); - NDArray bias('c', {oC}, {-1,-2,-3,-4}); - - NDArray expOutput('c', {bS, oC, oW}, { 83.7, 92.4, 101.1, 162.1, 175.9, 189.7, 223.4, 238.7,85.4, 94.4, 103.4, 167.4, 181.8, 196.2, 233.2, 249.4,87.1, 96.4, 105.7, 172.7, 187.7, 202.7, 243. , 260.1, - 88.8, 98.4, 108. , 178. , 193.6, 209.2, 252.8, 270.8, 292.5, 301.2, 309.9, 493.3, 507.1, 520.9, 590.6, 605.9, 301.4, 310.4, 319.4, 513. , 527.4, 541.8, 622. , 638.2, - 310.3, 319.6, 328.9, 532.7, 547.7, 562.7, 653.4, 670.5, 319.2, 328.8, 338.4, 552.4, 568. , 583.6, 684.8, 702.8}); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_6) { - - int bS=2, iW=16, iC=3,oC=4, kW=3, sW=3, pW=0, dW=1; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iW}); - NDArray weights('c', {kW, iC, oC}); - NDArray bias('c', {oC}, {-1,-2,-3,-4}); - - NDArray expOutput('c', {bS, oC, oW}, {159.7,335.3,381.2,427.1,473. ,518.9,163.8,351.4,400. ,448.6,497.2,545.8,167.9,367.5,418.8,470.1,521.4,572.7,172. ,383.6,437.6,491.6,545.6,599.6, - 577.3, 1069.7, 1115.6, 1161.5, 1207.4, 1253.3,595.8, 1129. , 1177.6, 1226.2, 1274.8, 1323.4,614.3, 1188.3, 1239.6, 1290.9, 1342.2, 1393.5, - 632.8, 1247.6, 1301.6, 1355.6, 1409.6, 1463.6}); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_7) { - - int bS=2, iW=8, iC=3,oC=4, kW=2, sW=1, pW=0, dW=1; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {kW, iC, oC}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 30.099998, 32.200001, 34.299999, 36.400002, 49.899998, 53.800003, 57.699997, - 61.599998, 69.699997, 75.400002, 81.099998, 86.800003, 89.500000, 97.000000, 104.500000, 112.000000, 109.300003, 118.600006, 127.899994, 137.199997, 129.100006, - 140.199997, 151.300003, 162.399994, 148.899994, 161.800003, 174.699997, 187.600006, 133.399994, 141.200012, 149.000000, 156.800003, 188.500000, 205.000000, - 221.500000, 238.000000, 208.299988, 226.600006, 244.899994, 263.200012, 228.100006, 248.200012, 268.299988, 288.399994, 247.899994, 269.799988, 291.700012, - 313.600006, 267.700012, 291.399994, 315.100006, 338.799988, 287.500000, 313.000000, 338.500000, 364.000000, 307.299988, 334.600006, 361.899994, 389.200012}, sd::DataType::FLOAT32); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_8) { - - int bS=2, iW=8, iC=3,oC=4, kW=2, sW=1, pW=0, dW=2; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {kW, iC, oC}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 26.299999, 27.799999, 29.299999, 30.799999, 45.399998, 48.399998, - 51.400002, 54.400005, 65.199997, 70.000000, 74.800003, 79.600006, 85.000000, 91.600006, 98.199997, 104.800003, 104.799995, 113.199997, 121.600006, - 130.000000, 124.599998, 134.800003, 145.000000, 155.200012, 144.399994, 156.399994, 168.399994, 180.400009, 133.400009, 141.199997, 149.000000, - 156.800003, 148.699997, 157.400009, 166.099991, 174.800003, 203.800003, 221.200012, 238.599991, 256.000000, 223.599991, 242.799988, 262.000000, - 281.200012, 243.399994, 264.399994, 285.399994, 306.399994, 263.199982, 286.000000, 308.799988, 331.600006, 283.000000, 307.600006, 332.200012, - 356.800018, 302.799988, 329.199982, 355.600006, 382.000000}, sd::DataType::FLOAT32); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - - sd::ops::conv1d op; - auto results = op.evaluate({&input, &weights}, {kW, sW, pW, dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv1d_causal_bp_1) { - - int bS=2, iW=3, iC=4,oC=3, kW=2, sW=1, pW=0, dW=1; - int oW = (iW-1)/sW + 1; - int paddingMode = 2; // CAUSAL - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iW, iC}); - NDArray weights('c', {kW, iC, oC}); - NDArray bias('c', {oC}, {-1,-2,-3}); - NDArray gradO('c', {bS, oW, oC}); - - input.linspace(1., 1.); - weights.linspace(0.1, 0.1); - gradO.linspace(-1.5, 0.1); - - const OpArgsHolder argsHolderFF({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); - const OpArgsHolder argsHolderBP({&input, &weights, &bias, &gradO}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); - - sd::ops::conv1d opFF; - sd::ops::conv1d_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); -} - -TEST_F(ConvolutionTests1, Test_Dilation2D_1) { - auto input = NDArrayFactory::create('c', {2, 6, 6, 3}); - auto weights = NDArrayFactory::create('c', {3, 2, 3}); - auto exp = NDArrayFactory::create('c', {2, 3, 3, 3}, {77, 79, 81, 83, 85, 87, 80, 82, 84, 113, 115, 117, 119, 121, 123, 116, 118, 120, 107, 109, 111, 113, 115, 117, 110, 112, 114, 185, 187, 189, 191, 193, 195, 188, 190, 192, 221, 223, 225, 227, 229, 231, 224, 226, 228, 215, 217, 219, 221, 223, 225, 218, 220, 222,}); - - input.linspace(1); - weights.linspace(1); - - sd::ops::dilation2d op; - auto result = op.evaluate({&input, &weights}, {1, 1,2,2,1, 1,2,2,1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); -} - -TEST_F(ConvolutionTests1, Test_Dilation2D_2) { - auto input = NDArrayFactory::create('c', {2, 6, 6, 3}); - auto weights = NDArrayFactory::create('c', {3, 2, 3}); - auto exp = NDArrayFactory::create('c', {2, 1, 2, 3}, {95, 97, 99, 101, 103, 105, 203, 205, 207, 209, 211, 213}); - - input.linspace(1); - weights.linspace(1); - - sd::ops::dilation2d op; - auto result = op.evaluate({&input, &weights}, {0, 1,2,2,1, 1,2,2,1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - -} - -////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test1) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); - - auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC},{ 0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, - 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f,11.37f, 12.693f, 14.016f, 15.339f, 5.266f, 5.707f, 6.148f, 6.589f,12.98f, 13.916f, 14.852f, 15.788f,14.564f, 15.608f, 16.652f, 17.696f, - 3.25f, 4.015f, 4.78f, 5.545f, 9.812f, 11.396f, 12.98f, 14.564f,10.532f, 12.224f, 13.916f, 15.608f, 9.708f, 10.977f, 12.246f, 13.515f,25.194f, 27.813f, 30.432f, 33.051f,26.922f, 29.703f, 32.484f, 35.265f, - 11.814f, 13.326f, 14.838f, 16.35f,30.378f, 33.483f, 36.588f, 39.693f,32.106f, 35.373f, 38.64f, 41.907f,13.474f, 14.563f, 15.652f, 16.741f,31.988f, 34.22f, 36.452f, 38.684f,33.572f, 35.912f, 38.252f, 40.592f}); - - auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, oC},{14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, - 17.04f, 17.52f, 18.f,17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f, - 11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f}); - // auto expGradB('c', {oC},{}); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - -} - -////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test2) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); - - auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f,0.118f,0.181f, 0.244f, 0.307f,0.212f,0.257f, 0.302f, 0.347f,0.208f,0.298f, 0.388f, 0.478f,1.028f,1.262f, 1.496f, 1.73f,1.036f,1.18f, 1.324f, 1.468f, - 0.928f,1.018f, 1.108f, 1.198f,2.9f,3.134f, 3.368f, 3.602f,2.188f,2.332f, 2.476f, 2.62f, 1.202f,1.274f, 1.346f, 1.418f,3.142f,3.313f, 3.484f, 3.655f,2.048f,2.147f, 2.246f, 2.345f, - 0.086f,0.212f, 0.338f, 0.464f,0.694f,0.973f, 1.252f, 1.531f,0.716f,0.869f, 1.022f, 1.175f,1.216f,1.522f, 1.828f, 2.134f,3.908f,4.574f, 5.24f, 5.906f,2.908f,3.268f, 3.628f, 3.988f, - 3.664f,3.97f, 4.276f, 4.582f,9.236f,9.902f,10.568f,11.234f,5.788f,6.148f, 6.508f, 6.868f,3.002f,3.182f, 3.362f, 3.542f,7.174f,7.561f, 7.948f, 8.335f,4.28f,4.487f, 4.694f, 4.901f}); - - auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, oC},{1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f, - 1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f, - 1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f}); - // auto expGradB('c', {oC},{}); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); -} - -////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) { - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); - - auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW},{ 0.567f, 1.224f, 0.66f, 1.314f, 2.82f, 1.512f, 1.386f, 2.976f, 1.596f, 0.801f, 1.71f, 0.912f, 0.657f, 1.422f, 0.768f, 1.53f, 3.288f, 1.764f, 1.602f, 3.444f, 1.848f, 0.927f, 1.98f, 1.056f, - 0.747f, 1.62f, 0.876f, 1.746f, 3.756f, 2.016f, 1.818f, 3.912f, 2.1f, 1.053f, 2.25f, 1.2f, 0.837f, 1.818f, 0.984f, 1.962f, 4.224f, 2.268f, 2.034f, 4.38f, 2.352f, 1.179f, 2.52f, 1.344f, - 1.467f, 3.06f, 1.596f, 3.186f, 6.636f, 3.456f, 3.402f, 7.08f, 3.684f, 1.845f, 3.834f, 1.992f, 1.773f, 3.69f, 1.92f, 3.834f, 7.968f, 4.14f, 4.05f, 8.412f, 4.368f, 2.187f, 4.536f, 2.352f, - 2.079f, 4.32f, 2.244f, 4.482f, 9.3f, 4.824f, 4.698f, 9.744f, 5.052f, 2.529f, 5.238f, 2.712f, 2.385f, 4.95f, 2.568f, 5.13f, 10.632f, 5.508f, 5.346f, 11.076f, 5.736f, 2.871f, 5.94f, 3.072f}); - - auto expGradW = NDArrayFactory::create('c', {oC, iC, kH, kW},{1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, - 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, - 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, - 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f}); - auto expGradB = NDArrayFactory::create('c', {oC},{0.68f, 1.f, 1.32f}); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - weights.permutei({2,3,1,0}); - expGradW.permutei({2,3,1,0}); - - sd::ops::conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv2d_bp_4) { - - int bS=1, iH=7,iW=1, iC=2,oC=3, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=7,oW=1; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {1,2,3}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); - - NDArray gradI('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray gradW('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); - NDArray gradB('c', {oC}, sd::DataType::FLOAT32); - - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::conv2d_bp op; - auto status = op.execute({&input, &weights, &bias, &gradO}, {&gradI, &gradW, &gradB}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); - - ASSERT_EQ(Status::OK(), status); -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv2d_bp_5) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, iC, kH, kW}, {3.6, 2.4, 1.2, 0.0, -1.2, -2.4, 3.3, 2.1, 0.9, -0.3, -1.5, -2.7, 3.0, 1.8, 0.6, -0.6, -1.8, -3.0, 2.7, 1.5, 0.3, -0.9, -2.1, -3.3, 3.5, 2.3, 1.1, -0.1, -1.3, -2.5, 3.2, 2.0, 0.8, -0.4, -1.6, -2.8, 2.9, 1.7, 0.5, -0.7, -1.9, -3.1, 2.6, 1.4, 0.2, -1.0, -2.2, -3.4, 3.4, 2.2, 1.0, -0.2, -1.4, -2.6, 3.1, 1.9, 0.7, -0.5, -1.7, -2.9, 2.8, 1.6, 0.4, -0.8, -2.0, -3.2, 2.5, 1.3, 0.1, -1.1, -2.3, -3.5}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iC, iH, iW},{0.517, 0.959, 0.406, 0.884, 1.474, 0.518, 0.020, -0.398, -0.490, -0.281, -0.853, -0.608, 0.472, 0.860, 0.352, 0.776, 1.240, - 0.392, -0.088, -0.632, -0.616, -0.344, -0.988, -0.680, 0.427, 0.761, 0.298, 0.668, 1.006, 0.266, -0.196, -0.866, -0.742, -0.407, -1.123, -0.752, 0.382, 0.662, - 0.244, 0.560, 0.772, 0.140, -0.304, -1.100, -0.868, -0.470, -1.258, -0.824, 1.777, 3.047, 1.234, 2.540, 3.922, 1.310, -0.052, -1.406, -1.426, -0.749, -2.221, - -1.508, 1.624, 2.732, 1.072, 2.216, 3.256, 0.968, -0.376, -2.072, -1.768, -0.920, -2.572, -1.688, 1.471, 2.417, 0.910, 1.892, 2.590, 0.626, -0.700, -2.738, -2.110, - -1.091, -2.923, -1.868, 1.318, 2.102, 0.748, 1.568, 1.924, 0.284, -1.024, -3.404, -2.452, -1.262, -3.274, -2.048}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {oC, iC, kH, kW},{-3.3, -2.62, -1.26, -0.58, 0.78, 1.46, 4.86, 5.54, 6.9, 7.58, 8.940001, 9.619999, 13.02, 13.700001, 15.06, 15.74, 17.1, - 17.780001, 21.18, 21.860001, 23.219999, 23.900002, 25.259998, 25.940001, -10.340001, -9.34, -7.339999, -6.34, -4.339999, -3.339999, 1.66, 2.66, 4.660001, - 5.660001, 7.66, 8.66, 13.66, 14.660001, 16.66, 17.66, 19.66, 20.66, 25.66, 26.66, 28.66, 29.66, 31.66, 32.66, -17.380001, -16.059999, -13.420003, -12.099999, - -9.46, -8.139999, -1.540001, -0.219999, 2.419999, 3.739999, 6.379999, 7.7, 14.299999, 15.62, 18.26, 19.58, 22.219999, 23.539999, 30.139999, 31.459999, 34.099998, - 35.419998, 38.060001, 39.380001}, sd::DataType::FLOAT32); - - NDArray expGradB('c', {oC}, {0.68, 1., 1.32}, sd::DataType::FLOAT32); - - input.linspace(-48, 1); - // weights.linspace(3.6, -0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); -} - -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv2d_bp_6) { - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, kH, kW, iC}, {3.6, 0.0, 3.3, -0.3, 3.0, -0.6, 2.7, -0.9, 3.5, -0.1, 3.2, -0.4, 2.9, -0.7, 2.6, -1.0, 3.4, -0.2, 3.1, -0.5, 2.8, -0.8, 2.5, -1.1, 2.4, -1.2, 2.1, -1.5, 1.8, -1.8, 1.5, -2.1, 2.3, -1.3, 2.0, -1.6, 1.7, -1.9, 1.4, -2.2, 2.2, -1.4, 1.9, -1.7, 1.6, -2.0, 1.3, -2.3, 1.2, -2.4, 0.9, -2.7, 0.6, -3.0, 0.3, -3.3, 1.1, -2.5, 0.8, -2.8, 0.5, -3.1, 0.2, -3.4, 1.0, -2.6, 0.7, -2.9, 0.4, -3.2, 0.1, -3.5}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iH, iW, iC}, {0.882, -0.522, 0.765, -0.639, 1.953, -1.503, 1.665, -1.791, 2.691, -2.061, 2.295, -2.457, 2.259, -1.305, 1.962, -1.602, 4.545, - -3.555, 3.870, -4.230, 5.625, -4.419, 4.788, -5.256001, 4.122, -2.358, 3.582, -2.898, 7.785, -6.147, 6.624, -7.308, 8.865, -7.011, 7.541999, -8.334, 3.273, -2.019, - 2.832, -2.460, 6.069, -5.163, 5.133, -6.099, 6.771, -5.757, 5.727, -6.801, 5.958, -3.222, 5.193, -3.987, 10.809, -8.198999, 9.225, -9.783, 11.547, -8.757, 9.855, - -10.448999, 9.711, -5.517, 8.441999, -6.786, 17.505001, -13.922999, 14.886, -16.542, 18.585001, -14.787001, 15.804001, -17.568001, 11.574, -6.570, 10.062, -8.082, - 20.745001, -16.514999, 17.639999, -19.619999, 21.825001, -17.379002, 18.558001, -20.646, 8.133, -4.935, 7.044, -6.024, 14.492998, -12.291, 12.261, -14.523001, 15.195001, -12.885, 12.855, -15.225}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {oC, kH, kW, iC},{34.559998, 41.760010, 48.959999, 56.160004, 33.119999, 37.739998, 42.360001, 46.979996, 120.960007, 129.480011, 138.0, 146.519989, - 91.200005, 96.639999, 102.079994, 107.520004, 114.479996, 120.059998, 125.639999, 131.220001, 82.080002, 85.620003, 89.160004, 92.699997, 33.120003, 40.499996, - 47.879993, 55.260002, 32.399998, 37.139996, 41.880001, 46.620003, 120.479988, 129.240005, 137.999985, 146.759995, 91.199997, 96.799995, 102.399994, 108.0, 115.199989, - 120.959999, 126.720001, 132.479996, 82.799995, 86.460007, 90.119995, 93.779999, 31.679998, 39.239994, 46.800003, 54.359997, 31.680000, 36.540001, 41.400002, 46.260002, - 120.0, 129.0, 138.0, 147.0, 91.200005, 96.960007, 102.720001, 108.480003, 115.919998, 121.860001, 127.799988, 133.740005, 83.520004, 87.300003, 91.080002, 94.860001}, sd::DataType::FLOAT32); - - NDArray expGradB('c', {oC}, {8.520, 8.760, 9.}, sd::DataType::FLOAT32); - - input.linspace(-48, 1); - gradO.linspace(0.01, 0.01); - - sd::ops::conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); -} - -//////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) { - - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - - auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f, 11.37f, 12.693f, 14.016f, 15.339f, - 5.266f, 5.707f, 6.148f, 6.589f, 12.98f, 13.916f, 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, 6.284f, 7.166f, 8.048f, 8.93f, 17.896f, 19.768f, 21.64f, 23.512f, 21.928f, 24.016f, 26.104f, 28.192f, 18.12f, 19.686f, 21.252f, 22.818f, 45.852f, 49.146f, 52.44f, 55.734f, 53.196f, 56.814f, 60.432f, 64.05f, - 28.164f, 30.216f, 32.268f, 34.32f, 67.884f, 72.15f, 76.416f, 80.682f, 75.228f, 79.818f, 84.408f, 88.998f, 29.324f, 30.854f, 32.384f, 33.914f, 67.432f, 70.6f, 73.768f, 76.936f, 73.192f, 76.576f, 79.96f, 83.344f, 27.884f, 30.062f, 32.24f, 34.418f, 66.28f, 70.744f, 75.208f, 79.672f, 70.312f, 74.992f, 79.672f, 84.352f, - 58.296f, 61.806f, 65.316f, 68.826f, 133.98f, 141.162f, 148.344f, 155.526f, 141.324f, 148.83f, 156.336f, 163.842f, 68.34f, 72.336f, 76.332f, 80.328f, 156.012f, 164.166f, 172.32f, 180.474f, 163.356f, 171.834f, 180.312f, 188.79f, 61.292f, 64.118f, 66.944f, 69.77f, 136.552f, 142.312f, 148.072f, 153.832f, 142.312f, 148.288f, 154.264f, 160.24f, - 9.298f, 11.359f, 13.42f, 15.481f, 27.092f, 31.268f, 35.444f, 39.62f, 27.812f, 32.096f, 36.38f, 40.664f, 26.556f, 29.769f, 32.982f, 36.195f, 66.666f, 73.173f, 79.68f, 86.187f, 68.394f, 75.063f, 81.732f, 88.401f, 28.662f, 32.118f, 35.574f, 39.03f, 71.85f, 78.843f, 85.836f, 92.829f, 73.578f, 80.733f, 87.888f, 95.043f, - 29.89f, 32.275f, 34.66f, 37.045f, 70.004f, 74.828f, 79.652f, 84.476f, 71.588f, 76.52f, 81.452f, 86.384f, 71.084f, 75.854f, 80.624f, 85.394f, 163.048f, 172.696f, 182.344f, 191.992f, 167.08f, 176.944f, 186.808f, 196.672f, 138.648f, 146.046f, 153.444f, 160.842f, 310.236f, 325.194f, 340.152f, 355.11f, 317.58f, 332.862f, 348.144f, 363.426f, - 148.692f, 156.576f, 164.46f, 172.344f, 332.268f, 348.198f, 364.128f, 380.058f, 339.612f, 355.866f, 372.12f, 388.374f, 125.228f, 130.646f, 136.064f, 141.482f, 274.792f, 285.736f, 296.68f, 307.624f, 280.552f, 291.712f, 302.872f, 314.032f, 92.684f, 98.75f, 104.816f, 110.882f, 211.432f, 223.672f, 235.912f, 248.152f, 215.464f, 227.92f, 240.376f, 252.832f, - 178.824f, 188.166f, 197.508f, 206.85f, 398.364f, 417.21f, 436.056f, 454.902f, 405.708f, 424.878f, 444.048f, 463.218f, 188.868f, 198.696f, 208.524f, 218.352f, 420.396f, 440.214f, 460.032f, 479.85f, 427.74f, 447.882f, 468.024f, 488.166f, 157.196f, 163.91f, 170.624f, 177.338f, 343.912f, 357.448f, 370.984f, 384.52f, 349.672f, 363.424f, 377.176f, 390.928f}); - - auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, - 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, - 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, - 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, - 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, - 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f}); - // auto expGradB('c', {oC},{}); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::conv3dnew_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - -} - - -//////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) { - - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - - auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, 0.212f, 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, 1.028f, 1.262f, 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, 0.928f, 1.018f, 1.108f, 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, 2.188f, 2.332f, 2.476f, 2.62f, - 1.202f, 1.274f, 1.346f, 1.418f, 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, 2.147f, 2.246f, 2.345f, 0.532f, 0.676f, 0.82f, 0.964f, 2.324f, 2.666f, 3.008f, 3.35f, 2.008f, 2.206f, 2.404f, 2.602f, 3.584f, 3.98f, 4.376f, 4.772f, 10.552f, 11.452f, 12.352f, 13.252f, 7.4f, 7.904f, 8.408f, 8.912f, - 6.752f, 7.148f, 7.544f, 7.94f, 17.752f, 18.652f, 19.552f, 20.452f, 11.432f, 11.936f, 12.44f, 12.944f, 5.932f, 6.184f, 6.436f, 6.688f, 14.42f, 14.978f, 15.536f, 16.094f, 8.704f, 9.01f, 9.316f, 9.622f, 3.11f, 3.236f, 3.362f, 3.488f, 7.39f, 7.669f, 7.948f, 8.227f, 4.388f, 4.541f, 4.694f, 4.847f, - 8.56f, 8.866f, 9.172f, 9.478f, 19.892f, 20.558f, 21.224f, 21.89f, 11.548f, 11.908f, 12.268f, 12.628f, 11.008f, 11.314f, 11.62f, 11.926f, 25.22f, 25.886f, 26.552f, 27.218f, 14.428f, 14.788f, 15.148f, 15.508f, 7.322f, 7.502f, 7.682f, 7.862f, 16.462f, 16.849f, 17.236f, 17.623f, 9.248f, 9.455f, 9.662f, 9.869f, - 0.158f, 0.392f, 0.626f, 0.86f, 1.27f, 1.765f, 2.26f, 2.755f, 1.22f, 1.481f, 1.742f, 2.003f, 2.224f, 2.746f, 3.268f, 3.79f, 6.788f, 7.886f, 8.984f, 10.082f, 4.78f, 5.356f, 5.932f, 6.508f, 6.4f, 6.922f, 7.444f, 7.966f, 15.572f, 16.67f, 17.768f, 18.866f, 9.388f, 9.964f, 10.54f, 11.116f, - 4.802f, 5.09f, 5.378f, 5.666f, 11.206f, 11.809f, 12.412f, 13.015f, 6.512f, 6.827f, 7.142f, 7.457f, 6.004f, 6.58f, 7.156f, 7.732f, 14.996f, 16.202f, 17.408f, 18.614f, 9.208f, 9.838f, 10.468f, 11.098f, 17.984f, 19.244f, 20.504f, 21.764f, 42.808f, 45.436f, 48.064f, 50.692f, 25.256f, 26.624f, 27.992f, 29.36f, - 28.064f, 29.324f, 30.584f, 31.844f, 63.832f, 66.46f, 69.088f, 71.716f, 36.2f, 37.568f, 38.936f, 40.304f, 18.316f, 19.f, 19.684f, 20.368f, 40.916f, 42.338f, 43.76f, 45.182f, 22.816f, 23.554f, 24.292f, 25.03f, 8.438f, 8.78f, 9.122f, 9.464f, 18.91f, 19.621f, 20.332f, 21.043f, 10.58f, 10.949f, 11.318f, 11.687f, - 20.944f, 21.682f, 22.42f, 23.158f, 46.388f, 47.918f, 49.448f, 50.978f, 25.66f, 26.452f, 27.244f, 28.036f, 26.848f, 27.586f, 28.324f, 29.062f, 58.628f, 60.158f, 61.688f, 63.218f, 31.996f, 32.788f, 33.58f, 34.372f, 16.106f, 16.502f, 16.898f, 17.294f, 34.894f, 35.713f, 36.532f, 37.351f, 18.896f, 19.319f, 19.742f, 20.165f}); - - auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, - 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, - 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, - 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f}); - // auto expGradB('c', {oC},{}); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::conv3dnew_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - +TYPED_TEST(TypedConvolutionTests1, conv2D_BP_Bias_1) { + TypeParam _expWGradB[] = {9312.0, 12580.0, 9528.0, 13168.0, 17712.0, + 13360.0, 9960.0, 13348.0, 10032.0, 13344.0, + 18148.0, 13848.0, 19312.0, 26160.0, 19888.0, + 15144.0, 20452.0, 15504.0}; + Nd4jLong _expWGradS[] = { + 4, 2, 1, 3, 3, + 9, 9, 3, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, + 1, 99}; + NDArray expWGrad(_expWGradB, _expWGradS); + expWGrad.permutei({2, 3, 1, 0}); -} + TypeParam _expBGradB[] = {784.0, 1296.0}; + Nd4jLong _expBGradS[] = { + 2, 2, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; -//////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) { + NDArray expBGrad(_expBGradB, _expBGradS); - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW + auto input = NDArrayFactory::create('c', {2, 1, 4, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 3, 3}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto epsilonNext = NDArrayFactory::create('c', {2, 2, 4, 4}); - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - auto gradO = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + TypeParam _expEpsB[] = { + 952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, + 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, + 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, + 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0}; + NDArray expEps(_expEpsB, input.shapeInfo()); - auto expGradI = NDArrayFactory::create('c', {bS, iC, iD, iH, iW},{2.091f, 4.356f, 2.268f, 4.53f, 9.42f, 4.896f, 4.65f, 9.672f, 5.028f, 2.517f, 5.226f, 2.712f, 4.932f, 10.242f, 5.316f, 10.62f, 22.02f, 11.412f, 10.908f, 22.62f, 11.724f, 5.868f, 12.15f, 6.288f, 2.913f, 6.03f, 3.12f, 6.234f, 12.888f, 6.66f, 6.402f, 13.236f, 6.84f, 3.423f, 7.068f, 3.648f, - 2.415f, 5.04f, 2.628f, 5.25f, 10.932f, 5.688f, 5.37f, 11.184f, 5.82f, 2.913f, 6.054f, 3.144f, 5.724f, 11.898f, 6.18f, 12.348f, 25.62f, 13.284f, 12.636f, 26.22f, 13.596f, 6.804f, 14.094f, 7.296f, 3.381f, 7.002f, 3.624f, 7.242f, 14.976f, 7.74f, 7.41f, 15.324f, 7.92f, 3.963f, 8.184f, 4.224f, - 2.739f, 5.724f, 2.988f, 5.97f, 12.444f, 6.48f, 6.09f, 12.696f, 6.612f, 3.309f, 6.882f, 3.576f, 6.516f, 13.554f, 7.044f, 14.076f, 29.22f, 15.156f, 14.364f, 29.82f, 15.468f, 7.74f, 16.038f, 8.304f, 3.849f, 7.974f, 4.128f, 8.25f, 17.064f, 8.82f, 8.418f, 17.412f, 9.f, 4.503f, 9.3f, 4.8f, - 3.063f, 6.408f, 3.348f, 6.69f, 13.956f, 7.272f, 6.81f, 14.208f, 7.404f, 3.705f, 7.71f, 4.008f, 7.308f, 15.21f, 7.908f, 15.804f, 32.82f, 17.028f, 16.092f, 33.42f, 17.34f, 8.676f, 17.982f, 9.312f, 4.317f, 8.946f, 4.632f, 9.258f, 19.152f, 9.9f, 9.426f, 19.5f, 10.08f, 5.043f, 10.416f, 5.376f, - 5.619f, 11.484f, 5.868f, 11.73f, 23.964f, 12.24f, 12.138f, 24.792f, 12.66f, 6.333f, 12.93f, 6.6f, 12.42f, 25.362f, 12.948f, 25.884f, 52.836f, 26.964f, 26.748f, 54.588f, 27.852f, 13.932f, 28.422f, 14.496f, 6.873f, 14.022f, 7.152f, 14.298f, 29.16f, 14.868f, 14.754f, 30.084f, 15.336f, 7.671f, 15.636f, 7.968f, - 6.807f, 13.896f, 7.092f, 14.178f, 28.932f, 14.76f, 14.586f, 29.76f, 15.18f, 7.593f, 15.486f, 7.896f, 14.94f, 30.474f, 15.54f, 31.068f, 63.348f, 32.292f, 31.932f, 65.1f, 33.18f, 16.596f, 33.822f, 17.232f, 8.205f, 16.722f, 8.52f, 17.034f, 34.704f, 17.676f, 17.49f, 35.628f, 18.144f, 9.075f, 18.48f, 9.408f, - 7.995f, 16.308f, 8.316f, 16.626f, 33.9f, 17.28f, 17.034f, 34.728f, 17.7f, 8.853f, 18.042f, 9.192f, 17.46f, 35.586f, 18.132f, 36.252f, 73.86f, 37.62f, 37.116f, 75.612f, 38.508f, 19.26f, 39.222f, 19.968f, 9.537f, 19.422f, 9.888f, 19.77f, 40.248f, 20.484f, 20.226f, 41.172f, 20.952f, 10.479f, 21.324f, 10.848f, - 9.183f, 18.72f, 9.54f, 19.074f, 38.868f, 19.8f, 19.482f, 39.696f, 20.22f, 10.113f, 20.598f, 10.488f, 19.98f, 40.698f, 20.724f, 41.436f, 84.372f, 42.948f, 42.3f, 86.124f, 43.836f, 21.924f, 44.622f, 22.704f, 10.869f, 22.122f, 11.256f, 22.506f, 45.792f, 23.292f, 22.962f, 46.716f, 23.76f, 11.883f, 24.168f, 12.288f}); + input.linspace(1); + weights.linspace(1); + epsilonNext.linspace(1); + weights.permutei({2, 3, 1, 0}); - auto expGradW = NDArrayFactory::create('c', {oC, iC, kD, kH, kW},{5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, - 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, - 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, - 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, - 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, - 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f}); + sd::ops::conv2d_bp op; - auto expGradB = NDArrayFactory::create('c', {oC},{2.64f, 3.92f, 5.2f}); + auto results = op.evaluate({&input, &weights, &bias, &epsilonNext}, {}, + {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - weights.permutei({2, 3, 4, 1, 0}); - expGradW.permutei({2, 3, 4, 1, 0}); + ASSERT_TRUE(results.size() == 3); - sd::ops::conv3dnew_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); + auto epsilon = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); - ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expWGrad.isSameShape(gradW)); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); + // expWGrad.printBuffer("Expctd buffer"); + // gradW->printBuffer("Result buffer"); + ASSERT_TRUE(expWGrad.equalsTo(gradW)); - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); + ASSERT_TRUE(input.isSameShape(epsilon)); - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); + // expEps.printBuffer("Expctd buffer"); + // epsilon->printBuffer("Result buffer"); + ASSERT_TRUE(expEps.equalsTo(epsilon)); + ASSERT_TRUE(expBGrad.isSameShape(gradB)); + ASSERT_TRUE(expBGrad.equalsTo(gradB)); } -////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv3d_bp_test4) { - - int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] +TYPED_TEST(TypedConvolutionTests1, conv2D_BP_NoBias_1) { + TypeParam _expWGradB[] = {9312.0, 12580.0, 9528.0, 13168.0, 17712.0, + 13360.0, 9960.0, 13348.0, 10032.0, 13344.0, + 18148.0, 13848.0, 19312.0, 26160.0, 19888.0, + 15144.0, 20452.0, 15504.0}; + Nd4jLong _expWGradS[] = { + 4, 2, 1, 3, 3, + 9, 9, 3, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, + 1, 99}; + NDArray expWGrad(_expWGradB, _expWGradS); + expWGrad.permutei({2, 3, 1, 0}); + + auto input = NDArrayFactory::create('c', {2, 1, 4, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 3, 3}); + auto epsilonNext = NDArrayFactory::create('c', {2, 2, 4, 4}); + + TypeParam _expEpsB[] = { + 952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, + 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, + 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, + 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0}; + NDArray expEps(_expEpsB, input.shapeInfo()); + + input.linspace(1); + weights.linspace(1); + epsilonNext.linspace(1); + weights.permutei({2, 3, 1, 0}); + + sd::ops::conv2d_bp op; + + auto results = op.evaluate({&input, &weights, &epsilonNext}, {}, + {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); + + ASSERT_TRUE(results.size() == 2); + + auto epsilon = results.at(0); + auto gradW = results.at(1); + + ASSERT_TRUE(expWGrad.isSameShape(gradW)); + + // expWGrad.printBuffer("Expctd buffer"); + // gradW->printBuffer("Result buffer"); + ASSERT_TRUE(expWGrad.equalsTo(gradW)); + + ASSERT_TRUE(input.isSameShape(epsilon)); + + // expEps.printBuffer("Expctd buffer"); + // epsilon->printBuffer("Result buffer"); + ASSERT_TRUE(expEps.equalsTo(epsilon)); +} - NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, iC, kD, kH, kW}, {7., 5.8, 4.6, 3.4, 2.2, 1., -0.2, -1.4, -2.6, -3.8, -5., -6.2, 6.7, 5.5, 4.3, 3.1, 1.9, 0.7, -0.5, -1.7, -2.9, -4.1, - -5.3, -6.5, 6.4, 5.2, 4., 2.8, 1.6, 0.4, -0.8, -2., -3.2, -4.4, -5.6, -6.8, 6.1, 4.9, 3.7, 2.5, 1.3, 0.1, -1.1, -2.3, -3.5, -4.7, -5.9, -7.1, 6.9, 5.7, 4.5, - 3.3, 2.1, 0.9, -0.3, -1.5, -2.7, -3.9, -5.1, -6.3, 6.6, 5.4, 4.2, 3., 1.8, 0.6, -0.6, -1.8, -3., -4.2, -5.4, -6.6, 6.3, 5.1, 3.9, 2.7, 1.5, 0.3, -0.9, -2.1, - -3.3, -4.5, -5.7, -6.9, 6., 4.8, 3.6, 2.4, 1.2, 0., -1.2, -2.4, -3.6, -4.8, -6., -7.2, 6.8, 5.6, 4.4, 3.2, 2., 0.8, -0.4, -1.6, -2.8, -4., -5.2, -6.4, 6.5, 5.3, 4.1, 2.9, 1.7, 0.5, -0.7, -1.9, -3.1, -4.3, -5.5, -6.7, 6.2, 5., 3.8, 2.6, 1.4, 0.2, -1., -2.2, -3.4, -4.6, -5.8, -7., 5.9, 4.7, 3.5, 2.3, 1.1, -0.1, -1.3, -2.5, -3.7, -4.9, -6.1, -7.3}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oD, oH, oW}, sd::DataType::FLOAT32); +TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) { + auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); + auto weightsD = NDArrayFactory::create( + 'c', {5, 5, 3, 2}, + {1.f, 76.f, 26.f, 101.f, 51.f, 126.f, 2.f, 77.f, 27.f, 102.f, + 52.f, 127.f, 3.f, 78.f, 28.f, 103.f, 53.f, 128.f, 4.f, 79.f, + 29.f, 104.f, 54.f, 129.f, 5.f, 80.f, 30.f, 105.f, 55.f, 130.f, + 6.f, 81.f, 31.f, 106.f, 56.f, 131.f, 7.f, 82.f, 32.f, 107.f, + 57.f, 132.f, 8.f, 83.f, 33.f, 108.f, 58.f, 133.f, 9.f, 84.f, + 34.f, 109.f, 59.f, 134.f, 10.f, 85.f, 35.f, 110.f, 60.f, 135.f, + 11.f, 86.f, 36.f, 111.f, 61.f, 136.f, 12.f, 87.f, 37.f, 112.f, + 62.f, 137.f, 13.f, 88.f, 38.f, 113.f, 63.f, 138.f, 14.f, 89.f, + 39.f, 114.f, 64.f, 139.f, 15.f, 90.f, 40.f, 115.f, 65.f, 140.f, + 16.f, 91.f, 41.f, 116.f, 66.f, 141.f, 17.f, 92.f, 42.f, 117.f, + 67.f, 142.f, 18.f, 93.f, 43.f, 118.f, 68.f, 143.f, 19.f, 94.f, + 44.f, 119.f, 69.f, 144.f, 20.f, 95.f, 45.f, 120.f, 70.f, 145.f, + 21.f, 96.f, 46.f, 121.f, 71.f, 146.f, 22.f, 97.f, 47.f, 122.f, + 72.f, 147.f, 23.f, 98.f, 48.f, 123.f, 73.f, 148.f, 24.f, 99.f, + 49.f, 124.f, 74.f, 149.f, 25.f, 100.f, 50.f, 125.f, 75.f, 150.f}); + auto weightsP = NDArrayFactory::create( + 'c', {1, 1, 6, 10}, + {0.0001f, 0.0007f, 0.0013f, 0.0019f, 0.0025f, 0.0031f, 0.0037f, 0.0043f, + 0.0049f, 0.0055f, 0.0002f, 0.0008f, 0.0014f, 0.0020f, 0.0026f, 0.0032f, + 0.0038f, 0.0044f, 0.0050f, 0.0056f, 0.0003f, 0.0009f, 0.0015f, 0.0021f, + 0.0027f, 0.0033f, 0.0039f, 0.0045f, 0.0051f, 0.0057f, 0.0004f, 0.0010f, + 0.0016f, 0.0022f, 0.0028f, 0.0034f, 0.0040f, 0.0046f, 0.0052f, 0.0058f, + 0.0005f, 0.0011f, 0.0017f, 0.0023f, 0.0029f, 0.0035f, 0.0041f, 0.0047f, + 0.0053f, 0.0059f, 0.0006f, 0.0012f, 0.0018f, 0.0024f, 0.0030f, 0.0036f, + 0.0042f, 0.0048f, 0.0054f, 0.0060f}); + + auto expFF = NDArrayFactory::create( + 'c', {2, 6, 6, 6}, + {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, + 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, + 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, + 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, + 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, + 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, + 53150.0f, 55350.0f, 57550.0f, 59750.0f, 61950.0f, 64150.0f, + 75150.0f, 77350.0f, 79550.0f, 81750.0f, 83950.0f, 86150.0f, + 97150.0f, 99350.0f, 101550.0f, 103750.0f, 105950.0f, 108150.0f, + 119150.0f, 121350.0f, 123550.0f, 125750.0f, 127950.0f, 130150.0f, + 141150.0f, 143350.0f, 145550.0f, 147750.0f, 149950.0f, 152150.0f, + 163150.0f, 165350.0f, 167550.0f, 169750.0f, 171950.0f, 174150.0f, + 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, + 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, + 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, + 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, + 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, + 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, + 350025.0f, 352850.0f, 355675.0f, 358500.0f, 361325.0f, 364150.0f, + 378275.0f, 381100.0f, 383925.0f, 386750.0f, 389575.0f, 392400.0f, + 406525.0f, 409350.0f, 412175.0f, 415000.0f, 417825.0f, 420650.0f, + 434775.0f, 437600.0f, 440425.0f, 443250.0f, 446075.0f, 448900.0f, + 463025.0f, 465850.0f, 468675.0f, 471500.0f, 474325.0f, 477150.0f, + 491275.0f, 494100.0f, 496925.0f, 499750.0f, 502575.0f, 505400.0f, + 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, + 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, + 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, + 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, + 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, + 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, + 771900.0f, 775350.0f, 778800.0f, 782250.0f, 785700.0f, 789150.0f, + 806400.0f, 809850.0f, 813300.0f, 816750.0f, 820200.0f, 823650.0f, + 840900.0f, 844350.0f, 847800.0f, 851250.0f, 854700.0f, 858150.0f, + 875400.0f, 878850.0f, 882300.0f, 885750.0f, 889200.0f, 892650.0f, + 909900.0f, 913350.0f, 916800.0f, 920250.0f, 923700.0f, 927150.0f, + 944400.0f, 947850.0f, 951300.0f, 954750.0f, 958200.0f, 961650.0f, + 107525.0f, 107850.0f, 108175.0f, 108500.0f, 108825.0f, 109150.0f, + 110775.0f, 111100.0f, 111425.0f, 111750.0f, 112075.0f, 112400.0f, + 114025.0f, 114350.0f, 114675.0f, 115000.0f, 115325.0f, 115650.0f, + 117275.0f, 117600.0f, 117925.0f, 118250.0f, 118575.0f, 118900.0f, + 120525.0f, 120850.0f, 121175.0f, 121500.0f, 121825.0f, 122150.0f, + 123775.0f, 124100.0f, 124425.0f, 124750.0f, 125075.0f, 125400.0f, + 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, + 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, + 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, + 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, + 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, + 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, + 404400.0f, 405350.0f, 406300.0f, 407250.0f, 408200.0f, 409150.0f, + 413900.0f, 414850.0f, 415800.0f, 416750.0f, 417700.0f, 418650.0f, + 423400.0f, 424350.0f, 425300.0f, 426250.0f, 427200.0f, 428150.0f, + 432900.0f, 433850.0f, 434800.0f, 435750.0f, 436700.0f, 437650.0f, + 442400.0f, 443350.0f, 444300.0f, 445250.0f, 446200.0f, 447150.0f, + 451900.0f, 452850.0f, 453800.0f, 454750.0f, 455700.0f, 456650.0f, + 1197525.0f, 1200350.0f, 1203175.0f, 1206000.0f, 1208825.0f, 1211650.0f, + 1225775.0f, 1228600.0f, 1231425.0f, 1234250.0f, 1237075.0f, 1239900.0f, + 1254025.0f, 1256850.0f, 1259675.0f, 1262500.0f, 1265325.0f, 1268150.0f, + 1282275.0f, 1285100.0f, 1287925.0f, 1290750.0f, 1293575.0f, 1296400.0f, + 1310525.0f, 1313350.0f, 1316175.0f, 1319000.0f, 1321825.0f, 1324650.0f, + 1338775.0f, 1341600.0f, 1344425.0f, 1347250.0f, 1350075.0f, 1352900.0f, + 826275.0f, 827850.0f, 829425.0f, 831000.0f, 832575.0f, 834150.0f, + 842025.0f, 843600.0f, 845175.0f, 846750.0f, 848325.0f, 849900.0f, + 857775.0f, 859350.0f, 860925.0f, 862500.0f, 864075.0f, 865650.0f, + 873525.0f, 875100.0f, 876675.0f, 878250.0f, 879825.0f, 881400.0f, + 889275.0f, 890850.0f, 892425.0f, 894000.0f, 895575.0f, 897150.0f, + 905025.0f, 906600.0f, 908175.0f, 909750.0f, 911325.0f, 912900.0f, + 1806900.0f, 1810350.0f, 1813800.0f, 1817250.0f, 1820700.0f, 1824150.0f, + 1841400.0f, 1844850.0f, 1848300.0f, 1851750.0f, 1855200.0f, 1858650.0f, + 1875900.0f, 1879350.0f, 1882800.0f, 1886250.0f, 1889700.0f, 1893150.0f, + 1910400.0f, 1913850.0f, 1917300.0f, 1920750.0f, 1924200.0f, 1927650.0f, + 1944900.0f, 1948350.0f, 1951800.0f, 1955250.0f, 1958700.0f, 1962150.0f, + 1979400.0f, 1982850.0f, 1986300.0f, 1989750.0f, 1993200.0f, 1996650.f}); + auto exp2FF = NDArrayFactory::create( + 'c', {2, 10, 6, 6}, + {827.4900282f, 832.2350283f, 836.9800284f, 841.725028f, + 846.4700287f, 851.2150288f, 874.9400293f, 879.6850294f, + 884.4300295f, 889.1750296f, 893.9200297f, 898.665029f, + 922.3900304f, 927.1350305f, 931.8800306f, 936.6250307f, + 941.3700308f, 946.1150309f, 969.8400315f, 974.5850316f, + 979.3300317f, 984.0750318f, 988.8200319f, 993.5650320f, + 1017.2900326f, 1022.0350327f, 1026.7800328f, 1031.5250329f, + 1036.2700330f, 1041.0150331f, 1064.7400337f, 1069.4850338f, + 1074.2300339f, 1078.9750340f, 1083.7200341f, 1088.4650342f, + 1822.4550553f, 1833.995055f, 1845.5350558f, 1857.075056f, + 1868.6150563f, 1880.1550566f, 1937.8550578f, 1949.3950581f, + 1960.9350583f, 1972.4750586f, 1984.015058f, 1995.5550591f, + 2053.2550604f, 2064.7950606f, 2076.3350609f, 2087.8750611f, + 2099.4150614f, 2110.955061f, 2168.6550629f, 2180.1950632f, + 2191.7350634f, 2203.2750637f, 2214.8150639f, 2226.3550642f, + 2284.0550655f, 2295.5950657f, 2307.1350660f, 2318.6750662f, + 2330.2150665f, 2341.7550667f, 2399.4550680f, 2410.9950683f, + 2422.5350685f, 2434.0750688f, 2445.6150690f, 2457.1550693f, + 2817.419968f, 2835.7549686f, 2854.0899683f, 2872.4249680f, + 2890.7599677f, 2909.0949674f, 3000.7699660f, 3019.104965f, + 3037.4399655f, 3055.7749652f, 3074.1099649f, 3092.4449646f, + 3184.1199632f, 3202.4549629f, 3220.789962f, 3239.1249624f, + 3257.4599621f, 3275.7949618f, 3367.4699604f, 3385.8049601f, + 3404.1399598f, 3422.474959f, 3440.8099593f, 3459.1449590f, + 3550.8199576f, 3569.1549573f, 3587.4899570f, 3605.8249567f, + 3624.1599565f, 3642.4949562f, 3734.1699548f, 3752.5049545f, + 3770.8399542f, 3789.1749539f, 3807.5099536f, 3825.8449534f, + 3812.385098f, 3837.5150988f, 3862.6450994f, 3887.7751000f, + 3912.9051006f, 3938.0351012f, 4063.6851041f, 4088.8151047f, + 4113.9451053f, 4139.0751059f, 4164.2051065f, 4189.3351071f, + 4314.9851100f, 4340.1151106f, 4365.2451112f, 4390.3751118f, + 4415.5051124f, 4440.6351130f, 4566.2851159f, 4591.4151165f, + 4616.5451171f, 4641.6751177f, 4666.805118f, 4691.9351188f, + 4817.5851218f, 4842.7151224f, 4867.8451230f, 4892.975123f, + 4918.1051241f, 4943.2351247f, 5068.8851277f, 5094.0151283f, + 5119.1451288f, 5144.2751294f, 5169.4051300f, 5194.5351306f, + 4807.3499803f, 4839.2749801f, 4871.1999799f, 4903.1249797f, + 4935.0499795f, 4966.9749793f, 5126.5999784f, 5158.5249782f, + 5190.4499780f, 5222.3749778f, 5254.2999777f, 5286.2249775f, + 5445.8499765f, 5477.774976f, 5509.6999762f, 5541.6249760f, + 5573.5499758f, 5605.4749756f, 5765.0999747f, 5797.0249745f, + 5828.9499743f, 5860.8749741f, 5892.7999739f, 5924.724973f, + 6084.3499728f, 6116.2749726f, 6148.1999724f, 6180.1249723f, + 6212.0499721f, 6243.9749719f, 6403.59997f, 6435.5249708f, + 6467.4499706f, 6499.3749704f, 6531.2999702f, 6563.2249700f, + 5802.3150007f, 5841.0350006f, 5879.7550005f, 5918.4750004f, + 5957.195000f, 5995.9150003f, 6189.5149999f, 6228.2349998f, + 6266.9549997f, 6305.6749996f, 6344.3949995f, 6383.114999f, + 6576.7149990f, 6615.4349990f, 6654.1549989f, 6692.8749988f, + 6731.5949987f, 6770.3149986f, 6963.9149982f, 7002.6349981f, + 7041.3549981f, 7080.0749980f, 7118.7949979f, 7157.5149978f, + 7351.1149974f, 7389.8349973f, 7428.5549972f, 7467.2749972f, + 7505.9949971f, 7544.7149970f, 7738.3149966f, 7777.0349965f, + 7815.7549964f, 7854.4749963f, 7893.1949963f, 7931.9149962f, + 6797.2799488f, 6842.794948f, 6888.3099489f, 6933.8249490f, + 6979.3399491f, 7024.8549492f, 7252.4299497f, 7297.9449498f, + 7343.4599499f, 7388.9749500f, 7434.489950f, 7480.0049501f, + 7707.5799506f, 7753.0949507f, 7798.6099508f, 7844.1249509f, + 7889.6399510f, 7935.1549511f, 8162.7299515f, 8208.2449516f, + 8253.7599517f, 8299.2749518f, 8344.7899519f, 8390.3049520f, + 8617.8799525f, 8663.394952f, 8708.9099526f, 8754.4249527f, + 8799.9399528f, 8845.4549529f, 9073.0299534f, 9118.5449535f, + 9164.0599536f, 9209.5749537f, 9255.089953f, 9300.604953f, + 7792.2451647f, 7844.5551655f, 7896.8651663f, 7949.1751671f, + 8001.4851679f, 8053.7951686f, 8315.3451725f, 8367.6551733f, + 8419.9651741f, 8472.2751749f, 8524.585175f, 8576.8951764f, + 8838.4451803f, 8890.7551811f, 8943.0651819f, 8995.3751827f, + 9047.6851834f, 9099.9951842f, 9361.5451881f, 9413.8551889f, + 9466.1651897f, 9518.475190f, 9570.7851912f, 9623.0951920f, + 9884.6451959f, 9936.9551967f, 9989.2651975f, 10041.5751982f, + 10093.8851990f, 10146.1951998f, 10407.7452037f, 10460.0552045f, + 10512.3652053f, 10564.6752060f, 10616.9852068f, 10669.2952076f, + 8787.210074f, 8846.3150748f, 8905.4200750f, 8964.5250752f, + 9023.6300755f, 9082.7350757f, 9378.2600768f, 9437.3650770f, + 9496.4700773f, 9555.5750775f, 9614.6800777f, 9673.7850779f, + 9969.3100791f, 10028.4150793f, 10087.5200795f, 10146.625079f, + 10205.7300800f, 10264.8350802f, 10560.3600813f, 10619.465081f, + 10678.5700818f, 10737.6750820f, 10796.7800822f, 10855.8850825f, + 11151.4100836f, 11210.5150838f, 11269.6200840f, 11328.7250843f, + 11387.8300845f, 11446.9350847f, 11742.4600858f, 11801.5650861f, + 11860.6700863f, 11919.7750865f, 11978.880086f, 12037.9850870f, + 9782.1750935f, 9848.0750935f, 9913.9750934f, 9979.8750934f, + 10045.7750934f, 10111.6750933f, 10441.1750931f, 10507.0750931f, + 10572.9750931f, 10638.8750930f, 10704.7750930f, 10770.6750930f, + 11100.1750928f, 11166.0750927f, 11231.9750927f, 11297.8750927f, + 11363.7750926f, 11429.6750926f, 11759.1750924f, 11825.0750924f, + 11890.9750923f, 11956.8750923f, 12022.7750923f, 12088.6750922f, + 12418.175092f, 12484.0750920f, 12549.9750920f, 12615.8750919f, + 12681.7750919f, 12747.6750919f, 13077.1750917f, 13143.0750916f, + 13208.9750916f, 13274.8750916f, 13340.7750915f, 13406.6750915f, + 2250.990060f, 2255.7350610f, 2260.4800611f, 2265.2250612f, + 2269.9700613f, 2274.7150614f, 2298.4400619f, 2303.185062f, + 2307.9300622f, 2312.6750623f, 2317.4200624f, 2322.1650625f, + 2345.8900630f, 2350.6350631f, 2355.380063f, 2360.1250634f, + 2364.8700635f, 2369.6150636f, 2393.3400641f, 2398.0850642f, + 2402.8300643f, 2407.5750644f, 2412.320064f, 2417.0650647f, + 2440.7900652f, 2445.5350653f, 2450.2800654f, 2455.0250655f, + 2459.7700656f, 2464.515065f, 2488.2400663f, 2492.9850664f, + 2497.7300665f, 2502.4750666f, 2507.2200667f, 2511.9650668f, + 5284.4551315f, 5295.9951318f, 5307.535132f, 5319.0751323f, + 5330.6151326f, 5342.1551328f, 5399.8551341f, 5411.3951343f, + 5422.9351346f, 5434.475134f, 5446.0151351f, 5457.5551354f, + 5515.2551366f, 5526.7951369f, 5538.3351371f, 5549.8751374f, + 5561.4151376f, 5572.9551379f, 5630.6551392f, 5642.1951394f, + 5653.7351397f, 5665.2751399f, 5676.8151402f, 5688.3551404f, + 5746.0551417f, 5757.5951420f, 5769.1351422f, 5780.6751425f, + 5792.2151427f, 5803.7551430f, 5861.455144f, 5872.9951445f, + 5884.5351448f, 5896.0751450f, 5907.6151453f, 5919.1551455f, + 8317.919884f, 8336.2548841f, 8354.5898838f, 8372.9248835f, + 8391.2598832f, 8409.59488f, 8501.2698815f, 8519.6048813f, + 8537.9398810f, 8556.2748807f, 8574.6098804f, 8592.9448801f, + 8684.6198787f, 8702.9548784f, 8721.2898782f, 8739.6248779f, + 8757.9598776f, 8776.2948773f, 8867.9698759f, 8886.3048756f, + 8904.6398753f, 8922.9748751f, 8941.3098748f, 8959.6448745f, + 9051.3198731f, 9069.6548728f, 9087.9898725f, 9106.3248722f, + 9124.6598720f, 9142.9948717f, 9234.6698703f, 9253.0048700f, + 9271.3398697f, 9289.6748694f, 9308.0098691f, 9326.3448689f, + 11351.3852747f, 11376.5152753f, 11401.6452759f, 11426.7752765f, + 11451.9052771f, 11477.0352777f, 11602.6852806f, 11627.8152812f, + 11652.9452818f, 11678.0752824f, 11703.2052830f, 11728.335283f, + 11853.9852865f, 11879.1152871f, 11904.2452877f, 11929.3752883f, + 11954.505288f, 11979.6352894f, 12105.2852924f, 12130.4152930f, + 12155.545293f, 12180.6752941f, 12205.8052947f, 12230.9352953f, + 12356.5852983f, 12381.715298f, 12406.8452994f, 12431.9753000f, + 12457.1053006f, 12482.2353012f, 12607.8853041f, 12633.0153047f, + 12658.1453053f, 12683.2753059f, 12708.4053065f, 12733.5353071f, + 14384.8499244f, 14416.7749242f, 14448.6999240f, 14480.6249238f, + 14512.549923f, 14544.4749235f, 14704.0999225f, 14736.024922f, + 14767.9499222f, 14799.8749220f, 14831.7999218f, 14863.7249216f, + 15023.3499207f, 15055.2749205f, 15087.1999203f, 15119.1249201f, + 15151.0499199f, 15182.9749197f, 15342.5999188f, 15374.5249186f, + 15406.4499184f, 15438.374918f, 15470.2999181f, 15502.2249179f, + 15661.84991f, 15693.7749168f, 15725.6999166f, 15757.6249164f, + 15789.5499162f, 15821.4749160f, 15981.0999151f, 16013.0249149f, + 16044.9499147f, 16076.8749145f, 16108.7999143f, 16140.7249142f, + 17418.314976f, 17457.0349761f, 17495.7549760f, 17534.4749759f, + 17573.1949758f, 17611.9149757f, 17805.5149753f, 17844.234975f, + 17882.9549752f, 17921.6749751f, 17960.3949750f, 17999.1149749f, + 18192.7149745f, 18231.4349744f, 18270.154974f, 18308.8749743f, + 18347.5949742f, 18386.3149741f, 18579.9149737f, 18618.6349736f, + 18657.3549735f, 18696.074973f, 18734.7949734f, 18773.5149733f, + 18967.1149729f, 19005.8349728f, 19044.5549727f, 19083.2749726f, + 19121.994972f, 19160.7149725f, 19354.3149721f, 19393.0349720f, + 19431.7549719f, 19470.4749718f, 19509.1949717f, 19547.914971f, + 20451.7799765f, 20497.2949766f, 20542.8099767f, 20588.3249768f, + 20633.8399769f, 20679.3549770f, 20906.929977f, 20952.4449775f, + 20997.9599776f, 21043.4749777f, 21088.9899778f, 21134.5049779f, + 21362.0799784f, 21407.5949785f, 21453.1099786f, 21498.624978f, + 21544.139978f, 21589.6549788f, 21817.2299793f, 21862.7449794f, + 21908.2599795f, 21953.7749796f, 21999.2899797f, 22044.8049798f, + 22272.3799802f, 22317.8949803f, 22363.4099804f, 22408.9249805f, + 22454.4399806f, 22499.9549807f, 22727.529981f, 22773.044981f, + 22818.5599813f, 22864.0749814f, 22909.5899815f, 22955.1049816f, + 23485.2453985f, 23537.555399f, 23589.8654000f, 23642.1754008f, + 23694.4854016f, 23746.7954024f, 24008.3454063f, 24060.655407f, + 24112.9654078f, 24165.2754086f, 24217.5854094f, 24269.8954102f, + 24531.4454141f, 24583.7554148f, 24636.0654156f, 24688.3754164f, + 24740.6854172f, 24792.99541f, 25054.545421f, 25106.8554226f, + 25159.1654234f, 25211.4754242f, 25263.7854250f, 25316.0954257f, + 25577.6454296f, 25629.9554304f, 25682.2654312f, 25734.5754320f, + 25786.8854328f, 25839.1954335f, 26100.7454374f, 26153.0554382f, + 26205.3654390f, 26257.6754398f, 26309.985440f, 26362.2954413f, + 26518.7101423f, 26577.8151425f, 26636.920142f, 26696.0251430f, + 26755.1301432f, 26814.2351434f, 27109.7601446f, 27168.8651448f, + 27227.9701450f, 27287.0751452f, 27346.1801455f, 27405.2851457f, + 27700.8101468f, 27759.9151470f, 27819.0201473f, 27878.1251475f, + 27937.2301477f, 27996.33514f, 28291.8601491f, 28350.9651493f, + 28410.0701495f, 28469.175149f, 28528.2801500f, 28587.3851502f, + 28882.9101513f, 28942.0151516f, 29001.1201518f, 29060.2251520f, + 29119.3301522f, 29178.4351525f, 29473.9601536f, 29533.0651538f, + 29592.1701540f, 29651.2751543f, 29710.3801545f, 29769.4851547f, + 29552.1750826f, 29618.0750825f, 29683.9750825f, 29749.8750825f, + 29815.7750824f, 29881.6750824f, 30211.1750822f, 30277.0750822f, + 30342.9750821f, 30408.8750821f, 30474.7750821f, 30540.6750820f, + 30870.175081f, 30936.0750818f, 31001.9750818f, 31067.8750817f, + 31133.7750817f, 31199.6750817f, 31529.1750815f, 31595.075081f, + 31660.9750814f, 31726.8750814f, 31792.7750813f, 31858.6750813f, + 32188.1750811f, 32254.0750811f, 32319.975081f, 32385.8750810f, + 32451.7750810f, 32517.6750809f, 32847.1750808f, 32913.0750807f, + 32978.9750807f, 33044.875080f, 33110.7750806f, 33176.67508062f}); + + input.linspace(1); + + sd::ops::sconv2d op; + auto resultFF = + op.evaluate({&input, &weightsD}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); + + auto z = resultFF.at(0); + + ASSERT_TRUE(z.isSameShape(&expFF)); + ASSERT_TRUE(z.equalsTo(&expFF, 1)); + + sd::ops::conv2d op2d; + // weightsP.printShapeInfo(); + auto result2D = + op2d.evaluate({&z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); + + auto z2d = result2D.at(0); + // z2d->printBuffer(); + + ASSERT_TRUE(z2d.isSameShape(&exp2FF)); + ASSERT_TRUE(z2d.equalsTo(&exp2FF)); +} - NDArray expGradI('c', {bS, iC, iD, iH, iW},{1.847, 3.577, 1.694, 3.460, 6.542, 3.010, 1.469, 2.677, 1.172, 3.226, 5.929999, 2.632, 5.408, 9.483999, 3.932, 1.894, - 2.978, 1.012, 0.058, -0.694, -0.824, -1.504, -4.916, -3.556, -1.850, -4.798, -3.020, -1.069, -2.687, -1.654, -3.236, -7.714, -4.550, -2.311, -5.315, -3.040, - 1.766, 3.406, 1.604, 3.280, 6.164, 2.812, 1.370, 2.470, 1.064, 3.028, 5.516, 2.416, 4.976, 8.584001, 3.464, 1.660, 2.492, 0.760, -0.140, -1.108, -1.040, -1.936, - -5.816, -4.024, -2.084, -5.284, -3.272, -1.186, -2.930, -1.780, -3.488, -8.236, -4.820, -2.446, -5.594, -3.184, 1.685, 3.235, 1.514, 3.100, 5.786, 2.614, 1.271, - 2.263, 0.956, 2.830, 5.102, 2.200, 4.544001, 7.683999, 2.996, 1.426, 2.006, 0.508, -0.338, -1.522, -1.256, -2.368, -6.716, -4.492, -2.318, -5.770, -3.524, -1.303, - -3.173, -1.906, -3.740, -8.757999, -5.090, -2.581, -5.873, -3.328, 1.604, 3.064, 1.424, 2.920, 5.408, 2.416, 1.172, 2.056, 0.848, 2.632, 4.688, 1.984, 4.112, 6.784, 2.528, 1.192, 1.520, 0.256, -0.536, -1.936, -1.472, -2.800, -7.616, -4.960, -2.552, -6.256, -3.776, -1.420, -3.416, -2.032, -3.992, -9.280001, -5.360, -2.716, -6.152, -3.472, 6.815001, 12.649, 5.798, 11.668, 21.230, 9.490, 4.709, 8.292999, 3.548, 9.706, 17.162001, 7.384, 14.912001, 25.036001, 9.980001, 4.918, 7.298, 2.308, -0.374, -3.286, -2.984, -5.824, -17.012001, -11.332001, -5.738, -14.302, -8.636, -3.013, -7.439, -4.462, -8.852, -20.674, -11.894, -5.983, -13.523, -7.576, 6.518, 12.046, 5.492, 11.056, 19.988001, 8.860001, 4.394, 7.654, 3.224, 9.075999, 15.883999, 6.736001, 13.616, 22.407999, 8.648, 4.252, 5.947999, 1.624, -1.004, -4.564, -3.632, -7.120, -19.639999, -12.664001, -6.404, -15.652, -9.320, -3.346, -8.114, -4.804, -9.536, -22.059999, -12.596, -6.334, -14.233999, -7.936, 6.221, 11.443, 5.186, 10.444, 18.746, 8.230, 4.079, 7.015, 2.900, 8.446, 14.606001, 6.088, 12.320, 19.779999, 7.316, 3.586, 4.598001, 0.940, -1.634, -5.842, -4.280, -8.416, -22.268002, -13.996, -7.070001, -17.001999, -10.004001, -3.679, -8.789, -5.146, -10.220, -23.445999, -13.298, -6.684999, -14.945, -8.296, 5.924, 10.840, 4.880, 9.832001, 17.504, 7.600, 3.764, 6.376, 2.576, 7.816, 13.328, 5.440001, 11.024, 17.152, 5.983999, 2.920, 3.247999, 0.256, -2.264, -7.120, -4.928, -9.712, -24.896, -15.328, -7.736, -18.352001, -10.688, -4.012, -9.464, -5.488, -10.903999, -24.832001, -14.000, -7.035999, -15.656, -8.655999}, sd::DataType::FLOAT32); +TEST_F(ConvolutionTests1, deconv2d_bp_1) { + int bS = 3, iH = 4, iW = 4, iC = 3, oC = 2, kH = 1, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, oC, iC}, {1, 3, 5, 2, 4, 6}, + sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iH, iW}, + {35.f, 38.f, 41.f, 44.f, 47.f, 50.f, 53.f, 56.f, 59.f, 62.f, + 65.f, 68.f, 71.f, 74.f, 77.f, 80.f, 71.f, 78.f, 85.f, 92.f, + 99.f, 106.f, 113.f, 120.f, 127.f, 134.f, 141.f, 148.f, 155.f, 162.f, + 169.f, 176.f, 107.f, 118.f, 129.f, 140.f, 151.f, 162.f, 173.f, 184.f, + 195.f, 206.f, 217.f, 228.f, 239.f, 250.f, 261.f, 272.f, 131.f, 134.f, + 137.f, 140.f, 143.f, 146.f, 149.f, 152.f, 155.f, 158.f, 161.f, 164.f, + 167.f, 170.f, 173.f, 176.f, 295.f, 302.f, 309.f, 316.f, 323.f, 330.f, + 337.f, 344.f, 351.f, 358.f, 365.f, 372.f, 379.f, 386.f, 393.f, 400.f, + 459.f, 470.f, 481.f, 492.f, 503.f, 514.f, 525.f, 536.f, 547.f, 558.f, + 569.f, 580.f, 591.f, 602.f, 613.f, 624.f, 227.f, 230.f, 233.f, 236.f, + 239.f, 242.f, 245.f, 248.f, 251.f, 254.f, 257.f, 260.f, 263.f, 266.f, + 269.f, 272.f, 519.f, 526.f, 533.f, 540.f, 547.f, 554.f, 561.f, 568.f, + 575.f, 582.f, 589.f, 596.f, 603.f, 610.f, 617.f, 624.f, 811.f, 822.f, + 833.f, 844.f, 855.f, 866.f, 877.f, 888.f, 899.f, 910.f, 921.f, 932.f, + 943.f, 954.f, 965.f, 976.f}, + sd::DataType::FLOAT32); + NDArray expGradW('c', {kH, kW, oC, iC}, + {160008., 191112., 222216., 203400., 246792., 290184.f}, + sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {1944.f, 2712.f}, sd::DataType::FLOAT32); + + input.linspace(1); + bias.linspace(1); + gradO.linspace(1); + + sd::ops::deconv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - NDArray expGradW('c', {oC, iC, kD, kH, kW},{-24.399998, -23.080000, -20.440001, -19.119999, -12.519999, -11.199998, -8.560001, -7.240002, -0.639999, 0.679999, - 3.320001, 4.640001, 23.119999, 24.439999, 27.080002, 28.400002, 35.000000, 36.320000, 38.959999, 40.279999, 46.879997, 48.200005, 50.839996, 52.160004, - 70.639999, 71.959999, 74.599998, 75.919998, 82.520004, 83.840004, 86.479996, 87.800003, 94.399994, 95.719994, 98.360001, 99.680008, 118.160004, 119.479996, - 122.120003, 123.440010, 130.040009, 131.360001, 134.000000, 135.319992, 141.919998, 143.239990, 145.879990, 147.200012, -70.159996, -68.200005, -64.279999, - -62.319996, -52.519993, -50.559994, -46.640003, -44.680000, -34.880001, -32.919998, -29.000002, -27.040005, 0.400004, 2.359996, 6.279998, 8.240004, 18.040001, - 20.000000, 23.920002, 25.879999, 35.680000, 37.639996, 41.560001, 43.520000, 70.959999, 72.919998, 76.840004, 78.799995, 88.599998, 90.560005, 94.479996, 96.440002, 106.240005, 108.199997, 112.120003, 114.080002, 141.519989, 143.479996, 147.400009, 149.360001, 159.159988, 161.119995, 165.040009, 167.000000, 176.800003, 178.760010, 182.679993, 184.639999, -115.920006, -113.320000, -108.120003, -105.520012, -92.520004, -89.919991, -84.720001, -82.119995, -69.120010, -66.520004, -61.320000, -58.719994, -22.320000, -19.719999, -14.520001, -11.920001, 1.079997, 3.679997, 8.879997, 11.480003, 24.480001, 27.079998, 32.280003, 34.880001, 71.279999, 73.880005, 79.080002, 81.680000, 94.679993, 97.280006, 102.479996, 105.080002, 118.080002, 120.679993, 125.879997, 128.479996, 164.880005, 167.479996, 172.679993, 175.279999, 188.279984, 190.880005, 196.080002, 198.679993, 211.680008, 214.280014, 219.479996, 222.079987}, sd::DataType::FLOAT32); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_bp_2) { + int bS = 3, iH = 4, iW = 4, iC = 3, oC = 2, kH = 2, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 4; // 5,4 + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = + 1; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights('c', {iC, oC, kH, kW}, + {1., 7., 2., 10., 3., 8., 4., 11., 5., 9., 6., 12.}, + sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iH, iW}, + {-77.400002, -77.199997, -77., -76.800003, -76.599998, + -76.400002, -76.200005, -76., -75.800003, -75.599998, + -75.399994, -75.199997, -11.32, -11.29, -11.26, + -11.23, -100.839996, -100.580002, -100.32, -100.059998, + -99.800003, -99.540001, -99.279999, -99.019997, -98.760002, + -98.50, -98.240005, -97.979996, -26.52, -26.450001, + -26.380001, -26.309999, -124.279999, -123.959991, -123.639999, + -123.32, -123., -122.68, -122.360001, -122.040001, + -121.720001, -121.400009, -121.080002, -120.759995, -41.720001, + -41.610001, -41.50, -41.389999, -71., -70.800003, + -70.599998, -70.399994, -70.199997, -70., -69.800003, + -69.600006, -69.400002, -69.199997, -69., -68.799995, + -10.360001, -10.33, -10.30, -10.27, -92.519997, + -92.260002, -92., -91.740005, -91.479996, -91.220001, + -90.960007, -90.700005, -90.440002, -90.18, -89.919998, + -89.660004, -24.280001, -24.209999, -24.139999, -24.07, + -114.040001, -113.720001, -113.400009, -113.080002, -112.759995, + -112.440002, -112.120003, -111.800003, -111.480003, -111.159996, + -110.839996, -110.520004, -38.200001, -38.09, -37.980003, + -37.869999, -64.599998, -64.400002, -64.199997, -64., + -63.799995, -63.599998, -63.400002, -63.199997, -63., + -62.799995, -62.599998, -62.400002, -9.40, -9.37, + -9.34, -9.309999, -84.200005, -83.940002, -83.68, + -83.419998, -83.160004, -82.900002, -82.639999, -82.379997, + -82.119995, -81.860001, -81.600006, -81.339996, -22.040001, + -21.970001, -21.90, -21.83, -103.800003, -103.480003, + -103.159996, -102.839996, -102.520004, -102.200005, -101.879997, + -101.559998, -101.239998, -100.919998, -100.599998, -100.279999, + -34.68, -34.57, -34.459999, -34.349998}, + sd::DataType::FLOAT32); + + NDArray expGradW('c', {iC, oC, kH, kW}, + {-3010.799805, -2502.420410, -2899.439209, -2407.380615, + -242.159332, -437.460510, -253.680466, -434.580048, + 2526.479980, 1627.500000, 2392.079834, 1538.220093}, + sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-173.040009, -165.360016}, + sd::DataType::FLOAT32); + + input.linspace(70., -1); + gradO.linspace(-4, 0.01); + + sd::ops::deconv2d_bp op; + auto results = op.evaluate( + {&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - NDArray expGradB('c', {oC}, {2.64, 3.92, 5.2}, sd::DataType::FLOAT32); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_bp_3) { + int bS = 3, iH = 4, iW = 4, iC = 3, oC = 2, kH = 2, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 5, oW = 4; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = + 2; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights('c', {iC, kH, kW, oC}, + {1., 4., 7., 10., 2., 5., 8., 11., 3., 6., 9., 12.}, + sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iH, iW, iC}, + {-86.5, -102.320007, -118.139999, -86.060005, -101.800003, + -117.540001, -85.619995, -101.279999, -116.940002, -85.18, + -100.759995, -116.339996, -84.740005, -100.239998, -115.739998, + -84.300003, -99.720001, -115.139999, -83.860001, -99.199997, + -114.539993, -83.419998, -98.68, -113.939995, -82.979996, + -98.160004, -113.339996, -82.539993, -97.639999, -112.739998, + -82.099998, -97.120003, -112.139999, -81.660004, -96.600006, + -111.539993, -81.220001, -96.080002, -110.939995, -80.779999, + -95.559998, -110.340012, -80.340004, -95.040001, -109.740005, + -79.900002, -94.519997, -109.139992, -77.699997, -91.919998, + -106.139999, -77.260002, -91.400002, -105.540001, -76.820007, + -90.880005, -104.940002, -76.380005, -90.360001, -104.339996, + -75.940002, -89.839996, -103.740005, -75.5, -89.320007, + -103.139999, -75.060005, -88.800003, -102.540001, -74.619995, + -88.279999, -101.940002, -74.18, -87.759995, -101.339996, + -73.740005, -87.239998, -100.739998, -73.300003, -86.720001, + -100.139999, -72.860001, -86.199997, -99.539993, -72.419998, + -85.68, -98.939995, -71.979996, -85.160004, -98.339996, + -71.539993, -84.639999, -97.740005, -71.099998, -84.120003, + -97.139999, -68.899994, -81.519997, -94.139999, -68.459999, + -81.00, -93.539993, -68.019997, -80.479996, -92.940002, + -67.580002, -79.959999, -92.339996, -67.139999, -79.440002, + -91.740005, -66.699997, -78.919998, -91.139999, -66.260002, + -78.399994, -90.540001, -65.820007, -77.880005, -89.940002, + -65.380005, -77.360001, -89.339996, -64.940002, -76.839996, + -88.740005, -64.5, -76.320007, -88.139999, -64.060005, + -75.800003, -87.540001, -63.619995, -75.279999, -86.940002, + -63.18, -74.759995, -86.339996, -62.739998, -74.239998, + -85.739998, -62.299999, -73.720001, -85.139999}, + sd::DataType::FLOAT32); + + NDArray expGradW('c', {iC, kH, kW, oC}, + {-592.800110, -593.039917, -594.719116, -594.960266, + -427.199890, -427.919617, -432.959900, -433.679993, + -261.600281, -262.799591, -271.200317, -272.399536}, + sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-204.600006, -204.}, sd::DataType::FLOAT32); + + input.linspace(70., -1); + gradO.linspace(-4, 0.01); + + sd::ops::deconv2d_bp op; + auto results = op.evaluate( + {&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - input.linspace(-75, 0.5); - gradO.linspace(0.01, 0.01); +TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) { + auto input = NDArrayFactory::create('c', {2, 2, 6}); + auto weights = NDArrayFactory::create( + 'c', {2, 2, 3}, {1, 5, 9, 3, 7, 11, 2, 6, 10, 4, 8, 12}); + auto bias = NDArrayFactory::create('c', {3}); + auto expFF = NDArrayFactory::create( + 'c', {2, 3, 5}, + {59.0f, 69.0f, 79.0f, 89.0f, 99.0f, 132.0f, 158.0f, 184.0f, + 210.0f, 236.0f, 205.0f, 247.0f, 289.0f, 331.0f, 373.0f, 179.0f, + 189.0f, 199.0f, 209.0f, 219.0f, 444.0f, 470.0f, 496.0f, 522.0f, + 548.0f, 709.0f, 751.0f, 793.0f, 835.0f, 877.0f}); + auto expEps = NDArrayFactory::create( + 'c', {2, 2, 6}, + {130.0f, 293.0f, 326.0f, 359.0f, 392.0f, 220.0f, 166.0f, 371.0f, + 416.0f, 461.0f, 506.0f, 280.0f, 355.0f, 788.0f, 821.0f, 854.0f, + 887.0f, 490.0f, 481.0f, 1046.0f, 1091.0f, 1136.0f, 1181.0f, 640.0f}); + auto expGW = NDArrayFactory::create( + 'c', {3, 2, 2}, + {1415.0f, 1520.0f, 2045.0f, 2150.0f, 1865.0f, 2020.0f, 2795.0f, 2950.0f, + 2315.0f, 2520.0f, 3545.0f, 3750.0f}); + auto expGB = + NDArrayFactory::create('c', {3}, {105.0f, 155.0f, 205.0f}); + + expGW.permutei({2, 1, 0}); + input.linspace(1); + bias.linspace(1); + + sd::ops::conv1d op; + auto result_FF = + op.evaluate({&input, &weights, &bias}, {}, {2, 1, 0, 1, 0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result_FF.status()); + + auto z = result_FF.at(0); + + ASSERT_TRUE(expFF.isSameShape(z)); + ASSERT_TRUE(expFF.equalsTo(z)); + + sd::ops::conv1d_bp op_bp; + + auto epsilonNxt = new NDArray(z.dup()); + epsilonNxt->linspace(1); + + auto result_BP = op_bp.evaluate({&input, &weights, &bias, epsilonNxt}, {}, + {2, 1, 0, 1, 0, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result_BP.status()); + + auto eps = result_BP.at(0); + auto gradW = result_BP.at(1); + auto gradB = result_BP.at(2); + + ASSERT_TRUE(expEps.isSameShape(eps)); + ASSERT_TRUE(expGW.isSameShape(gradW)); + ASSERT_TRUE(expGB.isSameShape(gradB)); + + ASSERT_TRUE(expEps.equalsTo(eps)); + ASSERT_TRUE(expGW.equalsTo(gradW)); + ASSERT_TRUE(expGB.equalsTo(gradB)); + + delete epsilonNxt; +} - sd::ops::conv3dnew_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); +TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) { + auto input = NDArrayFactory::create('c', {2, 2, 6}); + auto weights = NDArrayFactory::create( + 'c', {2, 2, 3}, + {1.f, 5.f, 9.f, 3.f, 7.f, 11.f, 2.f, 6.f, 10.f, 4.f, 8.f, 12.f}); - ASSERT_EQ(Status::OK(), results.status()); + input.linspace(1); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); + sd::ops::conv1d op; + auto result = op.evaluate({&input, &weights}, {}, {2, 1, 0, 1, 1, 0}); - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); + auto z = result.at(0); } ////////////////////////////////////////////////////////////////////// -TEST_F(ConvolutionTests1, conv3d_bp_test5) { +TEST_F(ConvolutionTests1, conv1d_causal_1) { + int bS = 2, iW = 3, iC = 4, oC = 3, kW = 2, sW = 1, pW = 0, dW = 1; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW - int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=4,oH=3,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1, -2, -3}); - NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, kD, kH, kW, iC}, {15., 14.7, 14.4, 14.1, 13.8, 13.5, 13.2, 12.9, 12.6, 12.3, 12., 11.7, 11.4, 11.1, 10.8, 10.5, 10.2, 9.9, 9.6, 9.3, 9., - 8.7, 8.4, 8.1, 7.8, 7.5, 7.2, 6.9, 6.6, 6.3, 6., 5.7, 5.4, 5.1, 4.8, 4.5, 4.2, 3.9, 3.6, 3.3, 3., 2.7, 2.4, 2.1, 1.8, 1.5, 1.2, 0.9, 14.9, 14.6, 14.3, 14., - 13.7, 13.4, 13.1, 12.8, 12.5, 12.2, 11.9, 11.6, 11.3, 11., 10.7, 10.4, 10.1, 9.8, 9.5, 9.2, 8.9, 8.6, 8.3, 8., 7.7, 7.4, 7.1, 6.8, 6.5, 6.2, 5.9, 5.6, 5.3, 5., - 4.7, 4.4, 4.1, 3.8, 3.5, 3.2, 2.9, 2.6, 2.3, 2., 1.7, 1.4, 1.1, 0.8, 14.8, 14.5, 14.2, 13.9, 13.6, 13.3, 13., 12.7, 12.4, 12.1, 11.8, 11.5, 11.2, 10.9, 10.6, - 10.3, 10., 9.7, 9.4, 9.1, 8.8, 8.5, 8.2, 7.9, 7.6, 7.3, 7., 6.7, 6.4, 6.1, 5.8, 5.5, 5.2, 4.9, 4.6, 4.3, 4., 3.7, 3.4, 3.1, 2.8, 2.5, 2.2, 1.9, 1.6, 1.3, 1., 0.7}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oD, oH, oW, oC}, sd::DataType::FLOAT32); + NDArray expOutput('c', {bS, oW, oC}, + {18., 18., 18., 53., 55.6, 58.2, 89.8, 95.6, 101.4, 102., + 106.8, 111.6, 163.4, 175.6, 187.8, 200.2, 215.6, 231.}); - NDArray expGradI('c', {bS, iD, iH, iW, iC}, {13.565001, 13.286001, 13.007000, 12.728001, 28.264000, 27.652000, 27.040001, 26.427999, 32.547997, 31.827999, 31.108002, - 30.388000, 31.647999, 30.927998, 30.208000, 29.487999, 64.484001, 62.935997, 61.387997, 59.839996, 72.188004, 70.424004, 68.660004, 66.896004, 43.852001, 42.807999, - 41.764000, 40.719997, 87.596001, 85.400002, 83.204002, 81.007996, 95.299988, 92.887993, 90.475998, 88.063995, 34.130997, 33.348000, 32.564999, 31.782001, 67.856995, - 66.210007, 64.563004, 62.916000, 72.987000, 71.178001, 69.369003, 67.559998, 70.179001, 68.369995, 66.561005, 64.751999, 137.927994, 134.147995, 130.367996, 126.587997, - 146.891998, 142.787994, 138.683990, 134.580017, 84.597000, 82.302002, 80.007004, 77.711998, 164.820007, 160.067993, 155.316010, 150.563995, 173.783997, 168.707993, - 163.631989, 158.556000, 58.674000, 57.162003, 55.649994, 54.138000, 114.027008, 110.921997, 107.816994, 104.711990, 119.156998, 115.889999, 112.623001, 109.355995, 113.433006, 110.166000, 106.899002, 103.632004, 218.603989, 211.908020, 205.211975, 198.515991, 227.568008, 220.547974, 213.528015, 206.507996, 127.850998, 124.098000, 120.345001, 116.591995, 245.496002, 237.828018, 230.159988, 222.492004, 254.459991, 246.468002, 238.475998, 230.483994, 34.049000, 32.797997, 31.547001, 30.295998, 64.479996, 61.924000, 59.368004, 56.812000, 67.035995, 64.372002, 61.707996, 59.044003, 62.248001, 59.584003, 56.919998, 54.256001, 116.180000, 110.744003, 105.307999, 99.872002, 120.428001, 114.776001, 109.124001, 103.472000, 69.268005, 66.279999, 63.292000, 60.304001, 128.923996, 122.839996, 116.755997, 110.671997, 133.171997, 126.872002, 120.571991, 114.271996, 94.565002, 92.342010, 90.118996, 87.896004, 182.488007, 177.988007, 173.488007, 168.988007, 186.772003, 182.164001, 177.556000, 172.947998, 178.095993, 173.488007, 168.880005, 164.272003, 341.828003, 332.504028, 323.180023, 313.856018, 349.532013, 339.992004, 330.451996, 320.911987, 190.299988, 185.368011, 180.436005, 175.503998, 364.940002, 354.967987, 344.996002, 335.024017, 372.644012, 362.455994, 352.268005, 342.080017, 132.303009, 128.604004, 124.904999, 121.206001, 252.536987, 245.057999, 237.578979, 230.100006, 257.666992, 250.026001, 242.385010, 234.744019, 243.195007, 235.554001, 227.912994, 220.272003, 460.631958, 445.188019, 429.744019, 414.299988, 469.595947, 453.827972, 438.059998, 422.291992, 257.613007, 249.486008, 241.358994, 233.232010, 487.523987, 471.108032, 454.691986, 438.276001, 496.488037, 479.748016, 463.007996, 446.268005, 156.846008, 152.417999, 147.989990, 143.561996, 298.707001, 289.769989, 280.833008, 271.895996, 303.837006, 294.737976, 285.638977, 276.540009, 286.449005, 277.350006, 268.250977, 259.151978, 541.307983, 522.947998, 504.587982, 486.227997, 550.271973, 531.588013, 512.903992, 494.220032, 300.867004, 291.281982, 281.696991, 272.112000, 568.200012, 548.868042, 529.535950, 510.204010, 577.164062, 557.507935, 537.851990, 518.196045, 83.944992, 80.750000, 77.555000, 74.360001, 156.496002, 150.052002, 143.608002, 137.164001, 159.052002, 152.500000, 145.947998, 139.395996, 146.488007, 139.936005, 133.384003, 126.832001, 269.107971, 255.895996, 242.684006, 229.471985, 273.356018, 259.927979, 246.500000, 233.071991, 153.507996, 146.632004, 139.755997, 132.880005, 281.851990, 267.992004, 254.132004, 240.272003, 286.100006, 272.023987, 257.947998, 243.872009}, sd::DataType::FLOAT32); + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); - NDArray expGradW('c', {oC, kD, kH, kW, iC}, {396.899872, 429.570007, 462.240234, 494.910156, 313.739960, 335.250000, 356.760071, 378.270020, 403.379944, 424.350006, - 445.320007, 466.289978, 299.520020, 313.319977, 327.119995, 340.920013, 1556.280029, 1594.979980, 1633.679932, 1672.379883, 1090.080078, 1115.520020, 1140.959961, - 1166.400024, 1183.679932, 1208.400024, 1233.119995, 1257.840088, 821.279907, 837.519897, 853.760010, 870.000000, 1500.119873, 1525.500122, 1550.880005, 1576.260010, - 1029.780029, 1046.429932, 1063.080078, 1079.729980, 1080.539917, 1096.650024, 1112.760010, 1128.869995, 738.000000, 748.560059, 759.119995, 769.679993, 389.880005, - 422.819946, 455.759979, 488.699951, 309.420013, 331.109985, 352.799988, 374.490051, 399.780029, 420.930023, 442.080017, 463.230011, 297.359985, 311.280029, 325.200012, 339.120056, 1553.400146, 1592.459961, 1631.520020, 1670.579956, 1088.640015, 1114.320068, 1140.000000, 1165.679932, 1183.199951, 1208.160034, 1233.119995, 1258.079956, 821.280029, 837.680054, 854.079956, 870.479980, 1502.819946, 1528.469971, 1554.119995, 1579.770020, 1031.939941, 1048.770020, 1065.599976, 1082.429932, 1083.420044, 1099.709961, 1116.000000, 1132.290039, 740.159973, 750.840027, 761.519958, 772.199951, 382.859924, 416.070099, 449.279968, 482.489990, 305.099976, 326.970062, 348.840027, 370.709991, 396.179962, 417.510010, 438.839966, 460.169952, 295.200012, 309.239990, 323.279968, 337.320007, 1550.519775, 1589.939941, 1629.359985, 1668.779907, 1087.200073, 1113.119995, 1139.039917, 1164.959961, 1182.719971, 1207.920044, 1233.119995, 1258.320190, 821.279968, 837.840027, 854.400024, 870.959961, 1505.520142, 1531.439819, 1557.359985, 1583.279907, 1034.100098, 1051.110107, 1068.120117, 1085.130005, 1086.299927, 1102.770020, 1119.239990, 1135.710083, 742.319946, 753.119995, 763.919983, 774.720032}, sd::DataType::FLOAT32); + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); - NDArray expGradB('c', {oC}, {77.400002, 78.119995, 78.840004}, sd::DataType::FLOAT32); + ASSERT_EQ(Status::OK(), results.status()); - input.linspace(-75, 0.5); - gradO.linspace(0.01, 0.01); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_2) { + int bS = 2, iW = 16, iC = 3, oC = 4, kW = 2, sW = 2, pW = 0, dW = 1; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1, -2, -3, -4}); + + NDArray expOutput( + 'c', {bS, oW, oC}, + {10., 9.6, 9.2, 8.8, 48.9, 51.8, 54.7, 57.6, 88.5, 95., + 101.5, 108., 128.1, 138.2, 148.3, 158.4, 167.7, 181.4, 195.1, 208.8, + 207.3, 224.6, 241.9, 259.2, 246.9, 267.8, 288.7, 309.6, 286.5, 311., + 335.5, 360., 254.8, 268.8, 282.8, 296.8, 365.7, 397.4, 429.1, 460.8, + 405.3, 440.6, 475.9, 511.2, 444.9, 483.8, 522.7, 561.6, 484.5, 527., + 569.5, 612., 524.1, 570.2, 616.3, 662.4, 563.7, 613.4, 663.1, 712.8, + 603.3, 656.6, 709.9, 763.2}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_3) { + int bS = 2, iW = 16, iC = 3, oC = 4, kW = 3, sW = 3, pW = 0, dW = 1; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1, -2, -3, -4}); + + NDArray expOutput( + 'c', {bS, oW, oC}, + {17.2, 16.8, 16.4, 16., 145.4, 151.6, 157.8, 164., + 283.1, 297.4, 311.7, 326., 420.8, 443.2, 465.6, 488., + 558.5, 589., 619.5, 650., 696.2001, 734.8, 773.4, 812., + 434.8, 448.8, 462.8, 476.8, 879.8, 929.2, 978.6, 1028., + 1017.5, 1075., 1132.5, 1190., 1155.2001, 1220.8, 1286.4, 1352., + 1292.8999, 1366.6, 1440.3, 1514., 1430.6001, 1512.4, 1594.2, 1676.}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} - sd::ops::conv3dnew_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_4) { + int bS = 2, iW = 8, iC = 3, oC = 4, kW = 3, sW = 1, pW = 0, dW = 3; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1, -2, -3, -4}); + + NDArray expOutput( + 'c', {bS, oW, oC}, + {17.2, 16.8, 16.4, 16., 43.3, 43.8, 44.3, 44.8, 69.4, 70.8, + 72.2, 73.6, 106.5, 109.4, 112.3, 115.2, 147.9, 152.6, 157.3, 162., + 189.3, 195.8, 202.3, 208.8, 234.5, 243.4, 252.3, 261.2, 280.4, 292., + 303.6, 315.2, 226., 232.8, 239.6, 246.4, 252.1, 259.8, 267.5, 275.2, + 278.2, 286.8, 295.4, 304., 437.7, 455., 472.3, 489.6, 479.1, 498.2, + 517.3, 536.4, 520.5, 541.4, 562.3, 583.2, 601.7, 632.2, 662.7, 693.2, + 647.6, 680.8, 714., 747.2}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} - ASSERT_EQ(Status::OK(), results.status()); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_5) { + int bS = 2, iW = 8, iC = 3, oC = 4, kW = 3, sW = 1, pW = 0, dW = 3; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iW}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1, -2, -3, -4}); + + NDArray expOutput( + 'c', {bS, oC, oW}, + {83.7, 92.4, 101.1, 162.1, 175.9, 189.7, 223.4, 238.7, 85.4, 94.4, + 103.4, 167.4, 181.8, 196.2, 233.2, 249.4, 87.1, 96.4, 105.7, 172.7, + 187.7, 202.7, 243., 260.1, 88.8, 98.4, 108., 178., 193.6, 209.2, + 252.8, 270.8, 292.5, 301.2, 309.9, 493.3, 507.1, 520.9, 590.6, 605.9, + 301.4, 310.4, 319.4, 513., 527.4, 541.8, 622., 638.2, 310.3, 319.6, + 328.9, 532.7, 547.7, 562.7, 653.4, 670.5, 319.2, 328.8, 338.4, 552.4, + 568., 583.6, 684.8, 702.8}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_6) { + int bS = 2, iW = 16, iC = 3, oC = 4, kW = 3, sW = 3, pW = 0, dW = 1; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iW}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1, -2, -3, -4}); + + NDArray expOutput( + 'c', {bS, oC, oW}, + {159.7, 335.3, 381.2, 427.1, 473., 518.9, 163.8, 351.4, + 400., 448.6, 497.2, 545.8, 167.9, 367.5, 418.8, 470.1, + 521.4, 572.7, 172., 383.6, 437.6, 491.6, 545.6, 599.6, + 577.3, 1069.7, 1115.6, 1161.5, 1207.4, 1253.3, 595.8, 1129., + 1177.6, 1226.2, 1274.8, 1323.4, 614.3, 1188.3, 1239.6, 1290.9, + 1342.2, 1393.5, 632.8, 1247.6, 1301.6, 1355.6, 1409.6, 1463.6}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_7) { + int bS = 2, iW = 8, iC = 3, oC = 4, kW = 2, sW = 1, pW = 0, dW = 1; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kW, iC, oC}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oW, oC}, + {11.000000, 11.600000, 12.200000, 12.800000, 30.099998, 32.200001, + 34.299999, 36.400002, 49.899998, 53.800003, 57.699997, 61.599998, + 69.699997, 75.400002, 81.099998, 86.800003, 89.500000, 97.000000, + 104.500000, 112.000000, 109.300003, 118.600006, 127.899994, 137.199997, + 129.100006, 140.199997, 151.300003, 162.399994, 148.899994, 161.800003, + 174.699997, 187.600006, 133.399994, 141.200012, 149.000000, 156.800003, + 188.500000, 205.000000, 221.500000, 238.000000, 208.299988, 226.600006, + 244.899994, 263.200012, 228.100006, 248.200012, 268.299988, 288.399994, + 247.899994, 269.799988, 291.700012, 313.600006, 267.700012, 291.399994, + 315.100006, 338.799988, 287.500000, 313.000000, 338.500000, 364.000000, + 307.299988, 334.600006, 361.899994, 389.200012}, + sd::DataType::FLOAT32); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_8) { + int bS = 2, iW = 8, iC = 3, oC = 4, kW = 2, sW = 1, pW = 0, dW = 2; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kW, iC, oC}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oW, oC}, + {11.000000, 11.600000, 12.200000, 12.800000, 26.299999, 27.799999, + 29.299999, 30.799999, 45.399998, 48.399998, 51.400002, 54.400005, + 65.199997, 70.000000, 74.800003, 79.600006, 85.000000, 91.600006, + 98.199997, 104.800003, 104.799995, 113.199997, 121.600006, 130.000000, + 124.599998, 134.800003, 145.000000, 155.200012, 144.399994, 156.399994, + 168.399994, 180.400009, 133.400009, 141.199997, 149.000000, 156.800003, + 148.699997, 157.400009, 166.099991, 174.800003, 203.800003, 221.200012, + 238.599991, 256.000000, 223.599991, 242.799988, 262.000000, 281.200012, + 243.399994, 264.399994, 285.399994, 306.399994, 263.199982, 286.000000, + 308.799988, 331.600006, 283.000000, 307.600006, 332.200012, 356.800018, + 302.799988, 329.199982, 355.600006, 382.000000}, + sd::DataType::FLOAT32); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test1) { +TEST_F(ConvolutionTests1, conv1d_causal_bp_1) { + int bS = 2, iW = 3, iC = 4, oC = 3, kW = 2, sW = 1, pW = 0, dW = 1; + int oW = (iW - 1) / sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1, -2, -3}); + NDArray gradO('c', {bS, oW, oC}); - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + gradO.linspace(-1.5, 0.1); - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, - 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, - 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, - 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); - input = 2.; - weights.linspace(0.1, 0.1); + const OpArgsHolder argsHolderFF({&input, &weights, &bias}, {}, + {kW, sW, pW, dW, paddingMode, dataFormat}); + const OpArgsHolder argsHolderBP({&input, &weights, &bias, &gradO}, {}, + {kW, sW, pW, dW, paddingMode, dataFormat}); - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); + sd::ops::conv1d opFF; + sd::ops::conv1d_bp opBP; - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + ASSERT_TRUE(isGradCorrect); } -////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test2) { +TEST_F(ConvolutionTests1, Test_Dilation2D_1) { + auto input = NDArrayFactory::create('c', {2, 6, 6, 3}); + auto weights = NDArrayFactory::create('c', {3, 2, 3}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3, 3}, + { + 77, 79, 81, 83, 85, 87, 80, 82, 84, 113, 115, 117, 119, 121, + 123, 116, 118, 120, 107, 109, 111, 113, 115, 117, 110, 112, 114, 185, + 187, 189, 191, 193, 195, 188, 190, 192, 221, 223, 225, 227, 229, 231, + 224, 226, 228, 215, 217, 219, 221, 223, 225, 218, 220, 222, + }); + + input.linspace(1); + weights.linspace(1); + + sd::ops::dilation2d op; + auto result = op.evaluate({&input, &weights}, {1, 1, 2, 2, 1, 1, 2, 2, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW +TEST_F(ConvolutionTests1, Test_Dilation2D_2) { + auto input = NDArrayFactory::create('c', {2, 6, 6, 3}); + auto weights = NDArrayFactory::create('c', {3, 2, 3}); + auto exp = NDArrayFactory::create( + 'c', {2, 1, 2, 3}, + {95, 97, 99, 101, 103, 105, 203, 205, 207, 209, 211, 213}); - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 2, 2, 2, 3}, {686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f}); - input = 2.; - weights.linspace(0.1, 0.1); + input.linspace(1); + weights.linspace(1); - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); + sd::ops::dilation2d op; + auto result = op.evaluate({&input, &weights}, {0, 1, 2, 2, 1, 1, 2, 2, 1}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test1) { + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, + 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, + 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, + 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f, + 11.37f, 12.693f, 14.016f, 15.339f, 5.266f, 5.707f, 6.148f, 6.589f, + 12.98f, 13.916f, 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, + 3.25f, 4.015f, 4.78f, 5.545f, 9.812f, 11.396f, 12.98f, 14.564f, + 10.532f, 12.224f, 13.916f, 15.608f, 9.708f, 10.977f, 12.246f, 13.515f, + 25.194f, 27.813f, 30.432f, 33.051f, 26.922f, 29.703f, 32.484f, 35.265f, + 11.814f, 13.326f, 14.838f, 16.35f, 30.378f, 33.483f, 36.588f, 39.693f, + 32.106f, 35.373f, 38.64f, 41.907f, 13.474f, 14.563f, 15.652f, 16.741f, + 31.988f, 34.22f, 36.452f, 38.684f, 33.572f, 35.912f, 38.252f, 40.592f}); + + auto expGradW = NDArrayFactory::create( + 'c', {kH, kW, iC, oC}, + {14.4f, 14.76f, 15.12f, 14.4f, 14.76f, 15.12f, 14.4f, 14.76f, 15.12f, + 14.4f, 14.76f, 15.12f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, + 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, 17.04f, 17.52f, 18.f, + 17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f, + 10.88f, 11.2f, 11.52f, 10.88f, 11.2f, 11.52f, 10.88f, 11.2f, 11.52f, + 10.88f, 11.2f, 11.52f, 11.16f, 11.52f, 11.88f, 11.16f, 11.52f, 11.88f, + 11.16f, 11.52f, 11.88f, 11.16f, 11.52f, 11.88f, 7.08f, 7.32f, 7.56f, + 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f}); + // auto expGradB('c', {oC},{}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); +} ////////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test3) { +TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test2) { + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, 0.212f, + 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, 1.028f, 1.262f, + 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, 0.928f, 1.018f, 1.108f, + 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, 2.188f, 2.332f, 2.476f, 2.62f, + 1.202f, 1.274f, 1.346f, 1.418f, 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, + 2.147f, 2.246f, 2.345f, 0.086f, 0.212f, 0.338f, 0.464f, 0.694f, 0.973f, + 1.252f, 1.531f, 0.716f, 0.869f, 1.022f, 1.175f, 1.216f, 1.522f, 1.828f, + 2.134f, 3.908f, 4.574f, 5.24f, 5.906f, 2.908f, 3.268f, 3.628f, 3.988f, + 3.664f, 3.97f, 4.276f, 4.582f, 9.236f, 9.902f, 10.568f, 11.234f, 5.788f, + 6.148f, 6.508f, 6.868f, 3.002f, 3.182f, 3.362f, 3.542f, 7.174f, 7.561f, + 7.948f, 8.335f, 4.28f, 4.487f, 4.694f, 4.901f}); + + auto expGradW = NDArrayFactory::create( + 'c', {kH, kW, iC, oC}, + {1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f, 1.84f, 2.f, 2.16f}); + // auto expGradB('c', {oC},{}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); +} - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) { + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.567f, 1.224f, 0.66f, 1.314f, 2.82f, 1.512f, 1.386f, 2.976f, 1.596f, + 0.801f, 1.71f, 0.912f, 0.657f, 1.422f, 0.768f, 1.53f, 3.288f, 1.764f, + 1.602f, 3.444f, 1.848f, 0.927f, 1.98f, 1.056f, 0.747f, 1.62f, 0.876f, + 1.746f, 3.756f, 2.016f, 1.818f, 3.912f, 2.1f, 1.053f, 2.25f, 1.2f, + 0.837f, 1.818f, 0.984f, 1.962f, 4.224f, 2.268f, 2.034f, 4.38f, 2.352f, + 1.179f, 2.52f, 1.344f, 1.467f, 3.06f, 1.596f, 3.186f, 6.636f, 3.456f, + 3.402f, 7.08f, 3.684f, 1.845f, 3.834f, 1.992f, 1.773f, 3.69f, 1.92f, + 3.834f, 7.968f, 4.14f, 4.05f, 8.412f, 4.368f, 2.187f, 4.536f, 2.352f, + 2.079f, 4.32f, 2.244f, 4.482f, 9.3f, 4.824f, 4.698f, 9.744f, 5.052f, + 2.529f, 5.238f, 2.712f, 2.385f, 4.95f, 2.568f, 5.13f, 10.632f, 5.508f, + 5.346f, 11.076f, 5.736f, 2.871f, 5.94f, 3.072f}); + + auto expGradW = NDArrayFactory::create( + 'c', {oC, iC, kH, kW}, + {1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, + 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, + 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, + 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, + 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 2.0000e+00f, + 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, + 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, + 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, + 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, + 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.6400e+00f, 2.6400e+00f, + 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, + 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, + 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, + 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, + 2.6400e+00f, 2.6400e+00f}); + auto expGradB = + NDArrayFactory::create('c', {oC}, {0.68f, 1.f, 1.32f}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + weights.permutei({2, 3, 1, 0}); + expGradW.permutei({2, 3, 1, 0}); + + sd::ops::conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2}); - input = 2.; - weights = 0.5; - expected = 48.; +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_bp_4) { + int bS = 1, iH = 7, iW = 1, iC = 2, oC = 3, kH = 2, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 7, oW = 1; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1, 2, 3}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + + NDArray gradI('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray gradW('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); + NDArray gradB('c', {oC}, sd::DataType::FLOAT32); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto status = op.execute( + {&input, &weights, &bias, &gradO}, {&gradI, &gradW, &gradB}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}, {}); + + ASSERT_EQ(Status::OK(), status); +} - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_bp_5) { + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = + 1; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, iC, kH, kW}, + {3.6, 2.4, 1.2, 0.0, -1.2, -2.4, 3.3, 2.1, 0.9, -0.3, -1.5, -2.7, + 3.0, 1.8, 0.6, -0.6, -1.8, -3.0, 2.7, 1.5, 0.3, -0.9, -2.1, -3.3, + 3.5, 2.3, 1.1, -0.1, -1.3, -2.5, 3.2, 2.0, 0.8, -0.4, -1.6, -2.8, + 2.9, 1.7, 0.5, -0.7, -1.9, -3.1, 2.6, 1.4, 0.2, -1.0, -2.2, -3.4, + 3.4, 2.2, 1.0, -0.2, -1.4, -2.6, 3.1, 1.9, 0.7, -0.5, -1.7, -2.9, + 2.8, 1.6, 0.4, -0.8, -2.0, -3.2, 2.5, 1.3, 0.1, -1.1, -2.3, -3.5}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1, -0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iH, iW}, + {0.517, 0.959, 0.406, 0.884, 1.474, 0.518, 0.020, -0.398, -0.490, + -0.281, -0.853, -0.608, 0.472, 0.860, 0.352, 0.776, 1.240, 0.392, + -0.088, -0.632, -0.616, -0.344, -0.988, -0.680, 0.427, 0.761, 0.298, + 0.668, 1.006, 0.266, -0.196, -0.866, -0.742, -0.407, -1.123, -0.752, + 0.382, 0.662, 0.244, 0.560, 0.772, 0.140, -0.304, -1.100, -0.868, + -0.470, -1.258, -0.824, 1.777, 3.047, 1.234, 2.540, 3.922, 1.310, + -0.052, -1.406, -1.426, -0.749, -2.221, -1.508, 1.624, 2.732, 1.072, + 2.216, 3.256, 0.968, -0.376, -2.072, -1.768, -0.920, -2.572, -1.688, + 1.471, 2.417, 0.910, 1.892, 2.590, 0.626, -0.700, -2.738, -2.110, + -1.091, -2.923, -1.868, 1.318, 2.102, 0.748, 1.568, 1.924, 0.284, + -1.024, -3.404, -2.452, -1.262, -3.274, -2.048}, + sd::DataType::FLOAT32); + + NDArray expGradW( + 'c', {oC, iC, kH, kW}, + {-3.3, -2.62, -1.26, -0.58, 0.78, 1.46, + 4.86, 5.54, 6.9, 7.58, 8.940001, 9.619999, + 13.02, 13.700001, 15.06, 15.74, 17.1, 17.780001, + 21.18, 21.860001, 23.219999, 23.900002, 25.259998, 25.940001, + -10.340001, -9.34, -7.339999, -6.34, -4.339999, -3.339999, + 1.66, 2.66, 4.660001, 5.660001, 7.66, 8.66, + 13.66, 14.660001, 16.66, 17.66, 19.66, 20.66, + 25.66, 26.66, 28.66, 29.66, 31.66, 32.66, + -17.380001, -16.059999, -13.420003, -12.099999, -9.46, -8.139999, + -1.540001, -0.219999, 2.419999, 3.739999, 6.379999, 7.7, + 14.299999, 15.62, 18.26, 19.58, 22.219999, 23.539999, + 30.139999, 31.459999, 34.099998, 35.419998, 38.060001, 39.380001}, + sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {0.68, 1., 1.32}, sd::DataType::FLOAT32); + + input.linspace(-48, 1); + // weights.linspace(3.6, -0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto results = op.evaluate( + {&input, &weights, &bias, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_bp_6) { + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = + 2; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, kH, kW, iC}, + {3.6, 0.0, 3.3, -0.3, 3.0, -0.6, 2.7, -0.9, 3.5, -0.1, 3.2, -0.4, + 2.9, -0.7, 2.6, -1.0, 3.4, -0.2, 3.1, -0.5, 2.8, -0.8, 2.5, -1.1, + 2.4, -1.2, 2.1, -1.5, 1.8, -1.8, 1.5, -2.1, 2.3, -1.3, 2.0, -1.6, + 1.7, -1.9, 1.4, -2.2, 2.2, -1.4, 1.9, -1.7, 1.6, -2.0, 1.3, -2.3, + 1.2, -2.4, 0.9, -2.7, 0.6, -3.0, 0.3, -3.3, 1.1, -2.5, 0.8, -2.8, + 0.5, -3.1, 0.2, -3.4, 1.0, -2.6, 0.7, -2.9, 0.4, -3.2, 0.1, -3.5}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1, -0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iH, iW, iC}, + {0.882, -0.522, 0.765, -0.639, 1.953, -1.503, + 1.665, -1.791, 2.691, -2.061, 2.295, -2.457, + 2.259, -1.305, 1.962, -1.602, 4.545, -3.555, + 3.870, -4.230, 5.625, -4.419, 4.788, -5.256001, + 4.122, -2.358, 3.582, -2.898, 7.785, -6.147, + 6.624, -7.308, 8.865, -7.011, 7.541999, -8.334, + 3.273, -2.019, 2.832, -2.460, 6.069, -5.163, + 5.133, -6.099, 6.771, -5.757, 5.727, -6.801, + 5.958, -3.222, 5.193, -3.987, 10.809, -8.198999, + 9.225, -9.783, 11.547, -8.757, 9.855, -10.448999, + 9.711, -5.517, 8.441999, -6.786, 17.505001, -13.922999, + 14.886, -16.542, 18.585001, -14.787001, 15.804001, -17.568001, + 11.574, -6.570, 10.062, -8.082, 20.745001, -16.514999, + 17.639999, -19.619999, 21.825001, -17.379002, 18.558001, -20.646, + 8.133, -4.935, 7.044, -6.024, 14.492998, -12.291, + 12.261, -14.523001, 15.195001, -12.885, 12.855, -15.225}, + sd::DataType::FLOAT32); + + NDArray expGradW( + 'c', {oC, kH, kW, iC}, + {34.559998, 41.760010, 48.959999, 56.160004, 33.119999, 37.739998, + 42.360001, 46.979996, 120.960007, 129.480011, 138.0, 146.519989, + 91.200005, 96.639999, 102.079994, 107.520004, 114.479996, 120.059998, + 125.639999, 131.220001, 82.080002, 85.620003, 89.160004, 92.699997, + 33.120003, 40.499996, 47.879993, 55.260002, 32.399998, 37.139996, + 41.880001, 46.620003, 120.479988, 129.240005, 137.999985, 146.759995, + 91.199997, 96.799995, 102.399994, 108.0, 115.199989, 120.959999, + 126.720001, 132.479996, 82.799995, 86.460007, 90.119995, 93.779999, + 31.679998, 39.239994, 46.800003, 54.359997, 31.680000, 36.540001, + 41.400002, 46.260002, 120.0, 129.0, 138.0, 147.0, + 91.200005, 96.960007, 102.720001, 108.480003, 115.919998, 121.860001, + 127.799988, 133.740005, 83.520004, 87.300003, 91.080002, 94.860001}, + sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {8.520, 8.760, 9.}, sd::DataType::FLOAT32); + + input.linspace(-48, 1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto results = op.evaluate( + {&input, &weights, &bias, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) { + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, + 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, + 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, + 7.023f, 7.86f, 8.697f, 3.39f, 3.93f, 4.47f, 5.01f, + 9.642f, 10.803f, 11.964f, 13.125f, 11.37f, 12.693f, 14.016f, + 15.339f, 5.266f, 5.707f, 6.148f, 6.589f, 12.98f, 13.916f, + 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, 6.284f, + 7.166f, 8.048f, 8.93f, 17.896f, 19.768f, 21.64f, 23.512f, + 21.928f, 24.016f, 26.104f, 28.192f, 18.12f, 19.686f, 21.252f, + 22.818f, 45.852f, 49.146f, 52.44f, 55.734f, 53.196f, 56.814f, + 60.432f, 64.05f, 28.164f, 30.216f, 32.268f, 34.32f, 67.884f, + 72.15f, 76.416f, 80.682f, 75.228f, 79.818f, 84.408f, 88.998f, + 29.324f, 30.854f, 32.384f, 33.914f, 67.432f, 70.6f, 73.768f, + 76.936f, 73.192f, 76.576f, 79.96f, 83.344f, 27.884f, 30.062f, + 32.24f, 34.418f, 66.28f, 70.744f, 75.208f, 79.672f, 70.312f, + 74.992f, 79.672f, 84.352f, 58.296f, 61.806f, 65.316f, 68.826f, + 133.98f, 141.162f, 148.344f, 155.526f, 141.324f, 148.83f, 156.336f, + 163.842f, 68.34f, 72.336f, 76.332f, 80.328f, 156.012f, 164.166f, + 172.32f, 180.474f, 163.356f, 171.834f, 180.312f, 188.79f, 61.292f, + 64.118f, 66.944f, 69.77f, 136.552f, 142.312f, 148.072f, 153.832f, + 142.312f, 148.288f, 154.264f, 160.24f, 9.298f, 11.359f, 13.42f, + 15.481f, 27.092f, 31.268f, 35.444f, 39.62f, 27.812f, 32.096f, + 36.38f, 40.664f, 26.556f, 29.769f, 32.982f, 36.195f, 66.666f, + 73.173f, 79.68f, 86.187f, 68.394f, 75.063f, 81.732f, 88.401f, + 28.662f, 32.118f, 35.574f, 39.03f, 71.85f, 78.843f, 85.836f, + 92.829f, 73.578f, 80.733f, 87.888f, 95.043f, 29.89f, 32.275f, + 34.66f, 37.045f, 70.004f, 74.828f, 79.652f, 84.476f, 71.588f, + 76.52f, 81.452f, 86.384f, 71.084f, 75.854f, 80.624f, 85.394f, + 163.048f, 172.696f, 182.344f, 191.992f, 167.08f, 176.944f, 186.808f, + 196.672f, 138.648f, 146.046f, 153.444f, 160.842f, 310.236f, 325.194f, + 340.152f, 355.11f, 317.58f, 332.862f, 348.144f, 363.426f, 148.692f, + 156.576f, 164.46f, 172.344f, 332.268f, 348.198f, 364.128f, 380.058f, + 339.612f, 355.866f, 372.12f, 388.374f, 125.228f, 130.646f, 136.064f, + 141.482f, 274.792f, 285.736f, 296.68f, 307.624f, 280.552f, 291.712f, + 302.872f, 314.032f, 92.684f, 98.75f, 104.816f, 110.882f, 211.432f, + 223.672f, 235.912f, 248.152f, 215.464f, 227.92f, 240.376f, 252.832f, + 178.824f, 188.166f, 197.508f, 206.85f, 398.364f, 417.21f, 436.056f, + 454.902f, 405.708f, 424.878f, 444.048f, 463.218f, 188.868f, 198.696f, + 208.524f, 218.352f, 420.396f, 440.214f, 460.032f, 479.85f, 427.74f, + 447.882f, 468.024f, 488.166f, 157.196f, 163.91f, 170.624f, 177.338f, + 343.912f, 357.448f, 370.984f, 384.52f, 349.672f, 363.424f, 377.176f, + 390.928f}); + + auto expGradW = NDArrayFactory::create( + 'c', {kD, kH, kW, iC, oC}, + {120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, + 123.12f, 120.96f, 122.04f, 123.12f, 79.56f, 80.28f, 81.f, 79.56f, + 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, + 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, + 157.68f, 154.8f, 156.24f, 157.68f, 101.76f, 102.72f, 103.68f, 101.76f, + 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, + 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, + 113.4f, 111.24f, 112.32f, 113.4f, 73.08f, 73.8f, 74.52f, 73.08f, + 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, + 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, + 69.12f, 67.68f, 68.4f, 69.12f, 44.4f, 44.88f, 45.36f, 44.4f, + 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, + 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, + 87.84f, 85.92f, 86.88f, 87.84f, 56.32f, 56.96f, 57.6f, 56.32f, + 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, + 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, + 62.64f, 61.2f, 61.92f, 62.64f, 40.08f, 40.56f, 41.04f, 40.08f, + 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f}); + // auto expGradB('c', {oC},{}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } //////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedConvolutionTests1, conv3d_test4) { +TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) { + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, + 0.212f, 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, + 1.028f, 1.262f, 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, + 0.928f, 1.018f, 1.108f, 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, + 2.188f, 2.332f, 2.476f, 2.62f, 1.202f, 1.274f, 1.346f, 1.418f, + 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, 2.147f, 2.246f, 2.345f, + 0.532f, 0.676f, 0.82f, 0.964f, 2.324f, 2.666f, 3.008f, 3.35f, + 2.008f, 2.206f, 2.404f, 2.602f, 3.584f, 3.98f, 4.376f, 4.772f, + 10.552f, 11.452f, 12.352f, 13.252f, 7.4f, 7.904f, 8.408f, 8.912f, + 6.752f, 7.148f, 7.544f, 7.94f, 17.752f, 18.652f, 19.552f, 20.452f, + 11.432f, 11.936f, 12.44f, 12.944f, 5.932f, 6.184f, 6.436f, 6.688f, + 14.42f, 14.978f, 15.536f, 16.094f, 8.704f, 9.01f, 9.316f, 9.622f, + 3.11f, 3.236f, 3.362f, 3.488f, 7.39f, 7.669f, 7.948f, 8.227f, + 4.388f, 4.541f, 4.694f, 4.847f, 8.56f, 8.866f, 9.172f, 9.478f, + 19.892f, 20.558f, 21.224f, 21.89f, 11.548f, 11.908f, 12.268f, 12.628f, + 11.008f, 11.314f, 11.62f, 11.926f, 25.22f, 25.886f, 26.552f, 27.218f, + 14.428f, 14.788f, 15.148f, 15.508f, 7.322f, 7.502f, 7.682f, 7.862f, + 16.462f, 16.849f, 17.236f, 17.623f, 9.248f, 9.455f, 9.662f, 9.869f, + 0.158f, 0.392f, 0.626f, 0.86f, 1.27f, 1.765f, 2.26f, 2.755f, + 1.22f, 1.481f, 1.742f, 2.003f, 2.224f, 2.746f, 3.268f, 3.79f, + 6.788f, 7.886f, 8.984f, 10.082f, 4.78f, 5.356f, 5.932f, 6.508f, + 6.4f, 6.922f, 7.444f, 7.966f, 15.572f, 16.67f, 17.768f, 18.866f, + 9.388f, 9.964f, 10.54f, 11.116f, 4.802f, 5.09f, 5.378f, 5.666f, + 11.206f, 11.809f, 12.412f, 13.015f, 6.512f, 6.827f, 7.142f, 7.457f, + 6.004f, 6.58f, 7.156f, 7.732f, 14.996f, 16.202f, 17.408f, 18.614f, + 9.208f, 9.838f, 10.468f, 11.098f, 17.984f, 19.244f, 20.504f, 21.764f, + 42.808f, 45.436f, 48.064f, 50.692f, 25.256f, 26.624f, 27.992f, 29.36f, + 28.064f, 29.324f, 30.584f, 31.844f, 63.832f, 66.46f, 69.088f, 71.716f, + 36.2f, 37.568f, 38.936f, 40.304f, 18.316f, 19.f, 19.684f, 20.368f, + 40.916f, 42.338f, 43.76f, 45.182f, 22.816f, 23.554f, 24.292f, 25.03f, + 8.438f, 8.78f, 9.122f, 9.464f, 18.91f, 19.621f, 20.332f, 21.043f, + 10.58f, 10.949f, 11.318f, 11.687f, 20.944f, 21.682f, 22.42f, 23.158f, + 46.388f, 47.918f, 49.448f, 50.978f, 25.66f, 26.452f, 27.244f, 28.036f, + 26.848f, 27.586f, 28.324f, 29.062f, 58.628f, 60.158f, 61.688f, 63.218f, + 31.996f, 32.788f, 33.58f, 34.372f, 16.106f, 16.502f, 16.898f, 17.294f, + 34.894f, 35.713f, 36.532f, 37.351f, 18.896f, 19.319f, 19.742f, 20.165f}); + + auto expGradW = NDArrayFactory::create( + 'c', {kD, kH, kW, iC, oC}, + {7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, + 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, + 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, + 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, + 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, + 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, + 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, + 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, + 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, + 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, + 8.16f, 7.52f, 7.84f, 8.16f}); + // auto expGradB('c', {oC},{}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); +} - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) { + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + auto gradO = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iC, iD, iH, iW}, + {2.091f, 4.356f, 2.268f, 4.53f, 9.42f, 4.896f, 4.65f, 9.672f, + 5.028f, 2.517f, 5.226f, 2.712f, 4.932f, 10.242f, 5.316f, 10.62f, + 22.02f, 11.412f, 10.908f, 22.62f, 11.724f, 5.868f, 12.15f, 6.288f, + 2.913f, 6.03f, 3.12f, 6.234f, 12.888f, 6.66f, 6.402f, 13.236f, + 6.84f, 3.423f, 7.068f, 3.648f, 2.415f, 5.04f, 2.628f, 5.25f, + 10.932f, 5.688f, 5.37f, 11.184f, 5.82f, 2.913f, 6.054f, 3.144f, + 5.724f, 11.898f, 6.18f, 12.348f, 25.62f, 13.284f, 12.636f, 26.22f, + 13.596f, 6.804f, 14.094f, 7.296f, 3.381f, 7.002f, 3.624f, 7.242f, + 14.976f, 7.74f, 7.41f, 15.324f, 7.92f, 3.963f, 8.184f, 4.224f, + 2.739f, 5.724f, 2.988f, 5.97f, 12.444f, 6.48f, 6.09f, 12.696f, + 6.612f, 3.309f, 6.882f, 3.576f, 6.516f, 13.554f, 7.044f, 14.076f, + 29.22f, 15.156f, 14.364f, 29.82f, 15.468f, 7.74f, 16.038f, 8.304f, + 3.849f, 7.974f, 4.128f, 8.25f, 17.064f, 8.82f, 8.418f, 17.412f, + 9.f, 4.503f, 9.3f, 4.8f, 3.063f, 6.408f, 3.348f, 6.69f, + 13.956f, 7.272f, 6.81f, 14.208f, 7.404f, 3.705f, 7.71f, 4.008f, + 7.308f, 15.21f, 7.908f, 15.804f, 32.82f, 17.028f, 16.092f, 33.42f, + 17.34f, 8.676f, 17.982f, 9.312f, 4.317f, 8.946f, 4.632f, 9.258f, + 19.152f, 9.9f, 9.426f, 19.5f, 10.08f, 5.043f, 10.416f, 5.376f, + 5.619f, 11.484f, 5.868f, 11.73f, 23.964f, 12.24f, 12.138f, 24.792f, + 12.66f, 6.333f, 12.93f, 6.6f, 12.42f, 25.362f, 12.948f, 25.884f, + 52.836f, 26.964f, 26.748f, 54.588f, 27.852f, 13.932f, 28.422f, 14.496f, + 6.873f, 14.022f, 7.152f, 14.298f, 29.16f, 14.868f, 14.754f, 30.084f, + 15.336f, 7.671f, 15.636f, 7.968f, 6.807f, 13.896f, 7.092f, 14.178f, + 28.932f, 14.76f, 14.586f, 29.76f, 15.18f, 7.593f, 15.486f, 7.896f, + 14.94f, 30.474f, 15.54f, 31.068f, 63.348f, 32.292f, 31.932f, 65.1f, + 33.18f, 16.596f, 33.822f, 17.232f, 8.205f, 16.722f, 8.52f, 17.034f, + 34.704f, 17.676f, 17.49f, 35.628f, 18.144f, 9.075f, 18.48f, 9.408f, + 7.995f, 16.308f, 8.316f, 16.626f, 33.9f, 17.28f, 17.034f, 34.728f, + 17.7f, 8.853f, 18.042f, 9.192f, 17.46f, 35.586f, 18.132f, 36.252f, + 73.86f, 37.62f, 37.116f, 75.612f, 38.508f, 19.26f, 39.222f, 19.968f, + 9.537f, 19.422f, 9.888f, 19.77f, 40.248f, 20.484f, 20.226f, 41.172f, + 20.952f, 10.479f, 21.324f, 10.848f, 9.183f, 18.72f, 9.54f, 19.074f, + 38.868f, 19.8f, 19.482f, 39.696f, 20.22f, 10.113f, 20.598f, 10.488f, + 19.98f, 40.698f, 20.724f, 41.436f, 84.372f, 42.948f, 42.3f, 86.124f, + 43.836f, 21.924f, 44.622f, 22.704f, 10.869f, 22.122f, 11.256f, 22.506f, + 45.792f, 23.292f, 22.962f, 46.716f, 23.76f, 11.883f, 24.168f, 12.288f}); + + auto expGradW = NDArrayFactory::create( + 'c', {oC, iC, kD, kH, kW}, + {5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 7.84f, 7.84f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 10.4f, 10.4f, 10.4f, 10.4f, + 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, + 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, + 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, + 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, + 10.4f, 10.4f, 10.4f, 10.4f}); + + auto expGradB = + NDArrayFactory::create('c', {oC}, {2.64f, 3.92f, 5.2f}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + weights.permutei({2, 3, 4, 1, 0}); + expGradW.permutei({2, 3, 4, 1, 0}); + + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2}); +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv3d_bp_test4) { + int bS = 2, iD = 4, iH = 3, iW = 3, iC = 4, oC = 3, kD = 3, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, + // kD, kH, kW, iC] + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, iC, kD, kH, kW}, + {7., 5.8, 4.6, 3.4, 2.2, 1., -0.2, -1.4, -2.6, -3.8, -5., -6.2, + 6.7, 5.5, 4.3, 3.1, 1.9, 0.7, -0.5, -1.7, -2.9, -4.1, -5.3, -6.5, + 6.4, 5.2, 4., 2.8, 1.6, 0.4, -0.8, -2., -3.2, -4.4, -5.6, -6.8, + 6.1, 4.9, 3.7, 2.5, 1.3, 0.1, -1.1, -2.3, -3.5, -4.7, -5.9, -7.1, + 6.9, 5.7, 4.5, 3.3, 2.1, 0.9, -0.3, -1.5, -2.7, -3.9, -5.1, -6.3, + 6.6, 5.4, 4.2, 3., 1.8, 0.6, -0.6, -1.8, -3., -4.2, -5.4, -6.6, + 6.3, 5.1, 3.9, 2.7, 1.5, 0.3, -0.9, -2.1, -3.3, -4.5, -5.7, -6.9, + 6., 4.8, 3.6, 2.4, 1.2, 0., -1.2, -2.4, -3.6, -4.8, -6., -7.2, + 6.8, 5.6, 4.4, 3.2, 2., 0.8, -0.4, -1.6, -2.8, -4., -5.2, -6.4, + 6.5, 5.3, 4.1, 2.9, 1.7, 0.5, -0.7, -1.9, -3.1, -4.3, -5.5, -6.7, + 6.2, 5., 3.8, 2.6, 1.4, 0.2, -1., -2.2, -3.4, -4.6, -5.8, -7., + 5.9, 4.7, 3.5, 2.3, 1.1, -0.1, -1.3, -2.5, -3.7, -4.9, -6.1, -7.3}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1, -0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oD, oH, oW}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iD, iH, iW}, + {1.847, 3.577, 1.694, 3.460, 6.542, 3.010, + 1.469, 2.677, 1.172, 3.226, 5.929999, 2.632, + 5.408, 9.483999, 3.932, 1.894, 2.978, 1.012, + 0.058, -0.694, -0.824, -1.504, -4.916, -3.556, + -1.850, -4.798, -3.020, -1.069, -2.687, -1.654, + -3.236, -7.714, -4.550, -2.311, -5.315, -3.040, + 1.766, 3.406, 1.604, 3.280, 6.164, 2.812, + 1.370, 2.470, 1.064, 3.028, 5.516, 2.416, + 4.976, 8.584001, 3.464, 1.660, 2.492, 0.760, + -0.140, -1.108, -1.040, -1.936, -5.816, -4.024, + -2.084, -5.284, -3.272, -1.186, -2.930, -1.780, + -3.488, -8.236, -4.820, -2.446, -5.594, -3.184, + 1.685, 3.235, 1.514, 3.100, 5.786, 2.614, + 1.271, 2.263, 0.956, 2.830, 5.102, 2.200, + 4.544001, 7.683999, 2.996, 1.426, 2.006, 0.508, + -0.338, -1.522, -1.256, -2.368, -6.716, -4.492, + -2.318, -5.770, -3.524, -1.303, -3.173, -1.906, + -3.740, -8.757999, -5.090, -2.581, -5.873, -3.328, + 1.604, 3.064, 1.424, 2.920, 5.408, 2.416, + 1.172, 2.056, 0.848, 2.632, 4.688, 1.984, + 4.112, 6.784, 2.528, 1.192, 1.520, 0.256, + -0.536, -1.936, -1.472, -2.800, -7.616, -4.960, + -2.552, -6.256, -3.776, -1.420, -3.416, -2.032, + -3.992, -9.280001, -5.360, -2.716, -6.152, -3.472, + 6.815001, 12.649, 5.798, 11.668, 21.230, 9.490, + 4.709, 8.292999, 3.548, 9.706, 17.162001, 7.384, + 14.912001, 25.036001, 9.980001, 4.918, 7.298, 2.308, + -0.374, -3.286, -2.984, -5.824, -17.012001, -11.332001, + -5.738, -14.302, -8.636, -3.013, -7.439, -4.462, + -8.852, -20.674, -11.894, -5.983, -13.523, -7.576, + 6.518, 12.046, 5.492, 11.056, 19.988001, 8.860001, + 4.394, 7.654, 3.224, 9.075999, 15.883999, 6.736001, + 13.616, 22.407999, 8.648, 4.252, 5.947999, 1.624, + -1.004, -4.564, -3.632, -7.120, -19.639999, -12.664001, + -6.404, -15.652, -9.320, -3.346, -8.114, -4.804, + -9.536, -22.059999, -12.596, -6.334, -14.233999, -7.936, + 6.221, 11.443, 5.186, 10.444, 18.746, 8.230, + 4.079, 7.015, 2.900, 8.446, 14.606001, 6.088, + 12.320, 19.779999, 7.316, 3.586, 4.598001, 0.940, + -1.634, -5.842, -4.280, -8.416, -22.268002, -13.996, + -7.070001, -17.001999, -10.004001, -3.679, -8.789, -5.146, + -10.220, -23.445999, -13.298, -6.684999, -14.945, -8.296, + 5.924, 10.840, 4.880, 9.832001, 17.504, 7.600, + 3.764, 6.376, 2.576, 7.816, 13.328, 5.440001, + 11.024, 17.152, 5.983999, 2.920, 3.247999, 0.256, + -2.264, -7.120, -4.928, -9.712, -24.896, -15.328, + -7.736, -18.352001, -10.688, -4.012, -9.464, -5.488, + -10.903999, -24.832001, -14.000, -7.035999, -15.656, -8.655999}, + sd::DataType::FLOAT32); + + NDArray expGradW( + 'c', {oC, iC, kD, kH, kW}, + {-24.399998, -23.080000, -20.440001, -19.119999, -12.519999, + -11.199998, -8.560001, -7.240002, -0.639999, 0.679999, + 3.320001, 4.640001, 23.119999, 24.439999, 27.080002, + 28.400002, 35.000000, 36.320000, 38.959999, 40.279999, + 46.879997, 48.200005, 50.839996, 52.160004, 70.639999, + 71.959999, 74.599998, 75.919998, 82.520004, 83.840004, + 86.479996, 87.800003, 94.399994, 95.719994, 98.360001, + 99.680008, 118.160004, 119.479996, 122.120003, 123.440010, + 130.040009, 131.360001, 134.000000, 135.319992, 141.919998, + 143.239990, 145.879990, 147.200012, -70.159996, -68.200005, + -64.279999, -62.319996, -52.519993, -50.559994, -46.640003, + -44.680000, -34.880001, -32.919998, -29.000002, -27.040005, + 0.400004, 2.359996, 6.279998, 8.240004, 18.040001, + 20.000000, 23.920002, 25.879999, 35.680000, 37.639996, + 41.560001, 43.520000, 70.959999, 72.919998, 76.840004, + 78.799995, 88.599998, 90.560005, 94.479996, 96.440002, + 106.240005, 108.199997, 112.120003, 114.080002, 141.519989, + 143.479996, 147.400009, 149.360001, 159.159988, 161.119995, + 165.040009, 167.000000, 176.800003, 178.760010, 182.679993, + 184.639999, -115.920006, -113.320000, -108.120003, -105.520012, + -92.520004, -89.919991, -84.720001, -82.119995, -69.120010, + -66.520004, -61.320000, -58.719994, -22.320000, -19.719999, + -14.520001, -11.920001, 1.079997, 3.679997, 8.879997, + 11.480003, 24.480001, 27.079998, 32.280003, 34.880001, + 71.279999, 73.880005, 79.080002, 81.680000, 94.679993, + 97.280006, 102.479996, 105.080002, 118.080002, 120.679993, + 125.879997, 128.479996, 164.880005, 167.479996, 172.679993, + 175.279999, 188.279984, 190.880005, 196.080002, 198.679993, + 211.680008, 214.280014, 219.479996, 222.079987}, + sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {2.64, 3.92, 5.2}, sd::DataType::FLOAT32); + + input.linspace(-75, 0.5); + gradO.linspace(0.01, 0.01); + + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - input = 2.; - weights = 0.5; - expected = 49.; - bias = 1.; +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv3d_bp_test5) { + int bS = 2, iD = 4, iH = 3, iW = 3, iC = 4, oC = 3, kD = 3, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 3, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, + // kD, kH, kW, iC] + + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, kD, kH, kW, iC}, + {15., 14.7, 14.4, 14.1, 13.8, 13.5, 13.2, 12.9, 12.6, 12.3, 12., 11.7, + 11.4, 11.1, 10.8, 10.5, 10.2, 9.9, 9.6, 9.3, 9., 8.7, 8.4, 8.1, + 7.8, 7.5, 7.2, 6.9, 6.6, 6.3, 6., 5.7, 5.4, 5.1, 4.8, 4.5, + 4.2, 3.9, 3.6, 3.3, 3., 2.7, 2.4, 2.1, 1.8, 1.5, 1.2, 0.9, + 14.9, 14.6, 14.3, 14., 13.7, 13.4, 13.1, 12.8, 12.5, 12.2, 11.9, 11.6, + 11.3, 11., 10.7, 10.4, 10.1, 9.8, 9.5, 9.2, 8.9, 8.6, 8.3, 8., + 7.7, 7.4, 7.1, 6.8, 6.5, 6.2, 5.9, 5.6, 5.3, 5., 4.7, 4.4, + 4.1, 3.8, 3.5, 3.2, 2.9, 2.6, 2.3, 2., 1.7, 1.4, 1.1, 0.8, + 14.8, 14.5, 14.2, 13.9, 13.6, 13.3, 13., 12.7, 12.4, 12.1, 11.8, 11.5, + 11.2, 10.9, 10.6, 10.3, 10., 9.7, 9.4, 9.1, 8.8, 8.5, 8.2, 7.9, + 7.6, 7.3, 7., 6.7, 6.4, 6.1, 5.8, 5.5, 5.2, 4.9, 4.6, 4.3, + 4., 3.7, 3.4, 3.1, 2.8, 2.5, 2.2, 1.9, 1.6, 1.3, 1., 0.7}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1, -0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oD, oH, oW, oC}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iD, iH, iW, iC}, + {13.565001, 13.286001, 13.007000, 12.728001, 28.264000, 27.652000, + 27.040001, 26.427999, 32.547997, 31.827999, 31.108002, 30.388000, + 31.647999, 30.927998, 30.208000, 29.487999, 64.484001, 62.935997, + 61.387997, 59.839996, 72.188004, 70.424004, 68.660004, 66.896004, + 43.852001, 42.807999, 41.764000, 40.719997, 87.596001, 85.400002, + 83.204002, 81.007996, 95.299988, 92.887993, 90.475998, 88.063995, + 34.130997, 33.348000, 32.564999, 31.782001, 67.856995, 66.210007, + 64.563004, 62.916000, 72.987000, 71.178001, 69.369003, 67.559998, + 70.179001, 68.369995, 66.561005, 64.751999, 137.927994, 134.147995, + 130.367996, 126.587997, 146.891998, 142.787994, 138.683990, 134.580017, + 84.597000, 82.302002, 80.007004, 77.711998, 164.820007, 160.067993, + 155.316010, 150.563995, 173.783997, 168.707993, 163.631989, 158.556000, + 58.674000, 57.162003, 55.649994, 54.138000, 114.027008, 110.921997, + 107.816994, 104.711990, 119.156998, 115.889999, 112.623001, 109.355995, + 113.433006, 110.166000, 106.899002, 103.632004, 218.603989, 211.908020, + 205.211975, 198.515991, 227.568008, 220.547974, 213.528015, 206.507996, + 127.850998, 124.098000, 120.345001, 116.591995, 245.496002, 237.828018, + 230.159988, 222.492004, 254.459991, 246.468002, 238.475998, 230.483994, + 34.049000, 32.797997, 31.547001, 30.295998, 64.479996, 61.924000, + 59.368004, 56.812000, 67.035995, 64.372002, 61.707996, 59.044003, + 62.248001, 59.584003, 56.919998, 54.256001, 116.180000, 110.744003, + 105.307999, 99.872002, 120.428001, 114.776001, 109.124001, 103.472000, + 69.268005, 66.279999, 63.292000, 60.304001, 128.923996, 122.839996, + 116.755997, 110.671997, 133.171997, 126.872002, 120.571991, 114.271996, + 94.565002, 92.342010, 90.118996, 87.896004, 182.488007, 177.988007, + 173.488007, 168.988007, 186.772003, 182.164001, 177.556000, 172.947998, + 178.095993, 173.488007, 168.880005, 164.272003, 341.828003, 332.504028, + 323.180023, 313.856018, 349.532013, 339.992004, 330.451996, 320.911987, + 190.299988, 185.368011, 180.436005, 175.503998, 364.940002, 354.967987, + 344.996002, 335.024017, 372.644012, 362.455994, 352.268005, 342.080017, + 132.303009, 128.604004, 124.904999, 121.206001, 252.536987, 245.057999, + 237.578979, 230.100006, 257.666992, 250.026001, 242.385010, 234.744019, + 243.195007, 235.554001, 227.912994, 220.272003, 460.631958, 445.188019, + 429.744019, 414.299988, 469.595947, 453.827972, 438.059998, 422.291992, + 257.613007, 249.486008, 241.358994, 233.232010, 487.523987, 471.108032, + 454.691986, 438.276001, 496.488037, 479.748016, 463.007996, 446.268005, + 156.846008, 152.417999, 147.989990, 143.561996, 298.707001, 289.769989, + 280.833008, 271.895996, 303.837006, 294.737976, 285.638977, 276.540009, + 286.449005, 277.350006, 268.250977, 259.151978, 541.307983, 522.947998, + 504.587982, 486.227997, 550.271973, 531.588013, 512.903992, 494.220032, + 300.867004, 291.281982, 281.696991, 272.112000, 568.200012, 548.868042, + 529.535950, 510.204010, 577.164062, 557.507935, 537.851990, 518.196045, + 83.944992, 80.750000, 77.555000, 74.360001, 156.496002, 150.052002, + 143.608002, 137.164001, 159.052002, 152.500000, 145.947998, 139.395996, + 146.488007, 139.936005, 133.384003, 126.832001, 269.107971, 255.895996, + 242.684006, 229.471985, 273.356018, 259.927979, 246.500000, 233.071991, + 153.507996, 146.632004, 139.755997, 132.880005, 281.851990, 267.992004, + 254.132004, 240.272003, 286.100006, 272.023987, 257.947998, 243.872009}, + sd::DataType::FLOAT32); + + NDArray expGradW( + 'c', {oC, kD, kH, kW, iC}, + {396.899872, 429.570007, 462.240234, 494.910156, 313.739960, + 335.250000, 356.760071, 378.270020, 403.379944, 424.350006, + 445.320007, 466.289978, 299.520020, 313.319977, 327.119995, + 340.920013, 1556.280029, 1594.979980, 1633.679932, 1672.379883, + 1090.080078, 1115.520020, 1140.959961, 1166.400024, 1183.679932, + 1208.400024, 1233.119995, 1257.840088, 821.279907, 837.519897, + 853.760010, 870.000000, 1500.119873, 1525.500122, 1550.880005, + 1576.260010, 1029.780029, 1046.429932, 1063.080078, 1079.729980, + 1080.539917, 1096.650024, 1112.760010, 1128.869995, 738.000000, + 748.560059, 759.119995, 769.679993, 389.880005, 422.819946, + 455.759979, 488.699951, 309.420013, 331.109985, 352.799988, + 374.490051, 399.780029, 420.930023, 442.080017, 463.230011, + 297.359985, 311.280029, 325.200012, 339.120056, 1553.400146, + 1592.459961, 1631.520020, 1670.579956, 1088.640015, 1114.320068, + 1140.000000, 1165.679932, 1183.199951, 1208.160034, 1233.119995, + 1258.079956, 821.280029, 837.680054, 854.079956, 870.479980, + 1502.819946, 1528.469971, 1554.119995, 1579.770020, 1031.939941, + 1048.770020, 1065.599976, 1082.429932, 1083.420044, 1099.709961, + 1116.000000, 1132.290039, 740.159973, 750.840027, 761.519958, + 772.199951, 382.859924, 416.070099, 449.279968, 482.489990, + 305.099976, 326.970062, 348.840027, 370.709991, 396.179962, + 417.510010, 438.839966, 460.169952, 295.200012, 309.239990, + 323.279968, 337.320007, 1550.519775, 1589.939941, 1629.359985, + 1668.779907, 1087.200073, 1113.119995, 1139.039917, 1164.959961, + 1182.719971, 1207.920044, 1233.119995, 1258.320190, 821.279968, + 837.840027, 854.400024, 870.959961, 1505.520142, 1531.439819, + 1557.359985, 1583.279907, 1034.100098, 1051.110107, 1068.120117, + 1085.130005, 1086.299927, 1102.770020, 1119.239990, 1135.710083, + 742.319946, 753.119995, 763.919983, 774.720032}, + sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {77.400002, 78.119995, 78.840004}, + sd::DataType::FLOAT32); + + input.linspace(-75, 0.5); + gradO.linspace(0.01, 0.01); + + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test1) { + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4, 3, 3}, + {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, + 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, + 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, + 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, + 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, + 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} - // output->printIndexedBuffer(); +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test2) { + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto expected = NDArrayFactory::create( + 'c', {2, 2, 2, 2, 3}, + {686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, + 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, + 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, + 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, + 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f}); + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test3) { + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2}); + input = 2.; + weights = 0.5; + expected = 48.; + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test4) { + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}); + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2}); + + input = 2.; + weights = 0.5; + expected = 49.; + bias = 1.; + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv3d_test5) { - - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC},{1.f, 2.f, 3.f}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, - 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, - 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f}); - input = 2.; - weights = 0.5; - - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - // output->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1.f, 2.f, 3.f}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2, 2, 2}, + {49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, + 50.f, 50.f, 50.f, 50.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, + 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, + 50.f, 50.f, 50.f, 50.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f}); + input = 2.; + weights = 0.5; + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv3d_test6) { - - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); - auto bias = NDArrayFactory::create('c', {oC},{1.f, 2.f, 3.f}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, - 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, - 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, - 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f}); - input = 2.; - weights.linspace(0.1, 0.1); - weights.permutei({2, 3, 4, 1, 0}); - - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - // output->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC}, {1.f, 2.f, 3.f}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2, 2, 2}, + {236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, + 698.f, 698.f, 698.f, 698.f, 698.f, 698.f, 698.f, 698.f, + 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, + 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, + 698.f, 698.f, 698.f, 698.f, 698.f, 698.f, 698.f, 698.f, + 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f}); + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 4, 1, 0}); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv3d_test7) { - - int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, - 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, - 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f}); - input = 2.; - weights.linspace(0.1, 0.1); - weights.permutei({2, 3, 4, 1, 0}); - - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 4, oC = 3, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2, 2, 2}, + {235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, + 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, + 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, + 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, + 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, + 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f}); + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 4, 1, 0}); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv3d_test8) { - auto x = NDArrayFactory::create('c', {4, 2, 28, 28, 3}); - auto y = NDArrayFactory::create('c', {2, 5, 5, 3, 4}); - auto e = NDArrayFactory::create('c', {4, 1, 7, 10, 4}); + auto x = NDArrayFactory::create('c', {4, 2, 28, 28, 3}); + auto y = NDArrayFactory::create('c', {2, 5, 5, 3, 4}); + auto e = NDArrayFactory::create('c', {4, 1, 7, 10, 4}); - sd::ops::conv3dnew op; - auto result = op.evaluate({&x, &y}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::conv3dnew op; + auto result = + op.evaluate({&x, &y}, {}, {2, 5, 5, 5, 4, 3, 0, 0, 0, 1, 1, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.isSameShape(z)); } TYPED_TEST(TypedConvolutionTests1, conv3d_test9) { - auto x = NDArrayFactory::create('c', {4, 2, 28, 28, 3}); - auto w = NDArrayFactory::create('c', {2, 5, 5, 3, 4}); - auto exp = NDArrayFactory::create('c', {4, 1, 7, 10, 4}); + auto x = NDArrayFactory::create('c', {4, 2, 28, 28, 3}); + auto w = NDArrayFactory::create('c', {2, 5, 5, 3, 4}); + auto exp = NDArrayFactory::create('c', {4, 1, 7, 10, 4}); - sd::ops::conv3dnew op; - auto result = op.evaluate({&x, &w}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::conv3dnew op; + auto result = + op.evaluate({&x, &w}, {}, {2, 5, 5, 5, 4, 3, 0, 0, 0, 1, 1, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); - ShapeList shapeList({x.shapeInfo(), w.shapeInfo()}); - ContextPrototype proto; - Context ctx(1); - ctx.appendI(2); - ctx.appendI(5); - ctx.appendI(5); + ShapeList shapeList({x.shapeInfo(), w.shapeInfo()}); + ContextPrototype proto; + Context ctx(1); + ctx.appendI(2); + ctx.appendI(5); + ctx.appendI(5); - ctx.appendI(5); - ctx.appendI(4); - ctx.appendI(3); + ctx.appendI(5); + ctx.appendI(4); + ctx.appendI(3); - ctx.appendI(0); - ctx.appendI(0); - ctx.appendI(0); + ctx.appendI(0); + ctx.appendI(0); + ctx.appendI(0); - ctx.appendI(1); - ctx.appendI(1); - ctx.appendI(1); + ctx.appendI(1); + ctx.appendI(1); + ctx.appendI(1); - ctx.appendI(0); - ctx.appendI(1); // previous variant was "ctx.appendI(0)" and this caused fail + ctx.appendI(0); + ctx.appendI(1); // previous variant was "ctx.appendI(0)" and this caused fail - auto shapes = op.calculateOutputShape(&shapeList, ctx); - ASSERT_EQ(1, shapes->size()); + auto shapes = op.calculateOutputShape(&shapeList, ctx); + ASSERT_EQ(1, shapes->size()); - auto s = shapes->at(0); + auto s = shapes->at(0); - auto z = result.at(0); - // z->printShapeInfo("z shape"); + auto z = result.at(0); + // z->printShapeInfo("z shape"); - ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(z)); - shapes->destroy(); - delete shapes; + shapes->destroy(); + delete shapes; } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv3d_test10) { + int bS = 1, iD = 2, iH = 2, iW = 2, iC = 1, oC = 1, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW - int bS=1, iD=2,iH=2,iW=2, iC=1,oC=1, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - input = 2.; - weights = 1.; + input = 2.; + weights = 1.; - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv3d_test11) { - - int bS=5, iD=4,iH=14,iW=14, iC=1,oC=1, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=13,oW=13; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - - input = 2.; - weights = 1.; - - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(output.isSameShape(&expected)); - + int bS = 5, iD = 4, iH = 14, iW = 14, iC = 1, oC = 1, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 13, oW = 13; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto expected = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + + input = 2.; + weights = 1.; + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output.isSameShape(&expected)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, conv3d_test12) { - - int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, iC, kD, kH, kW}, {-14.4, -13.2, -12.0, -10.8, -9.6, -8.4, -7.2, -6.0, -4.8, -3.6, -2.4, -1.2, -14.1, -12.9, -11.7, -10.5, -9.3, -8.1, - -6.9, -5.7, -4.5, -3.3, -2.1, -0.9, -13.8, -12.6, -11.4, -10.2, -9.0, -7.8, -6.6, -5.4, -4.2, -3.0, -1.8, -0.6, -13.5, -12.3, -11.1, -9.9, -8.7, -7.5, -6.3, - -5.1, -3.9, -2.7, -1.5, -0.3, -14.3, -13.1, -11.9, -10.7, -9.5, -8.3, -7.1, -5.9, -4.7, -3.5, -2.3, -1.1, -14.0, -12.8, -11.6, -10.4, -9.2, -8.0, -6.8, -5.6, - -4.4, -3.2, -2.0, -0.8, -13.7, -12.5, -11.3, -10.1, -8.9, -7.7, -6.5, -5.3, -4.1, -2.9, -1.7, -0.5, -13.4, -12.2, -11.0, -9.8, -8.6, -7.4, -6.2, -5.0, -3.8, -2.6, -1.4, -0.2, -14.2, -13.0, -11.8, -10.6, -9.4, -8.2, -7.0, -5.8, -4.6, -3.4, -2.2, -1.0, -13.9, -12.7, -11.5, -10.3, -9.1, -7.9, -6.7, -5.5, -4.3, -3.1, -1.9, -0.7, -13.6, -12.4, -11.2, -10.0, -8.8, -7.6, -6.4, -5.2, -4.0, -2.8, -1.6, -0.4, -13.3, -12.1, -10.9, -9.7, -8.5, -7.3, -6.1, -4.9, -3.7, -2.5, -1.3, -0.1}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oC, oD, oH, oW}, {-42520.597656, -42344.199219, -41991.402344, -41814.996094, -40932.992188, -40756.597656, -40403.800781, -40227.406250, - -41953.601562, -41779.601562, -41431.597656, -41257.601562, -40387.601562, -40213.597656, -39865.601562, -39691.597656, -41391.105469, -41219.492188, - -40876.300781, -40704.699219, -39846.707031, -39675.097656, -39331.898438, -39160.300781, -17119.001953, -16942.599609, -16589.798828, -16413.400391, - -15531.399414, -15355.000000, -15002.199219, -14825.800781, -16897.597656, -16723.597656, -16375.599609, -16201.599609, -15331.599609, -15157.600586, - -14809.601562, -14635.598633, -16680.703125, -16509.099609, -16165.900391, -15994.300781, -15136.300781, -14964.700195, -14621.500000, -14449.900391}, sd::DataType::FLOAT32); - - input.linspace(150,-0.5); - - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, iD = 4, iH = 3, iW = 3, iC = 4, oC = 3, kD = 3, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, + // kD, kH, kW, iC] + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, iC, kD, kH, kW}, + {-14.4, -13.2, -12.0, -10.8, -9.6, -8.4, -7.2, -6.0, -4.8, -3.6, + -2.4, -1.2, -14.1, -12.9, -11.7, -10.5, -9.3, -8.1, -6.9, -5.7, + -4.5, -3.3, -2.1, -0.9, -13.8, -12.6, -11.4, -10.2, -9.0, -7.8, + -6.6, -5.4, -4.2, -3.0, -1.8, -0.6, -13.5, -12.3, -11.1, -9.9, + -8.7, -7.5, -6.3, -5.1, -3.9, -2.7, -1.5, -0.3, -14.3, -13.1, + -11.9, -10.7, -9.5, -8.3, -7.1, -5.9, -4.7, -3.5, -2.3, -1.1, + -14.0, -12.8, -11.6, -10.4, -9.2, -8.0, -6.8, -5.6, -4.4, -3.2, + -2.0, -0.8, -13.7, -12.5, -11.3, -10.1, -8.9, -7.7, -6.5, -5.3, + -4.1, -2.9, -1.7, -0.5, -13.4, -12.2, -11.0, -9.8, -8.6, -7.4, + -6.2, -5.0, -3.8, -2.6, -1.4, -0.2, -14.2, -13.0, -11.8, -10.6, + -9.4, -8.2, -7.0, -5.8, -4.6, -3.4, -2.2, -1.0, -13.9, -12.7, + -11.5, -10.3, -9.1, -7.9, -6.7, -5.5, -4.3, -3.1, -1.9, -0.7, + -13.6, -12.4, -11.2, -10.0, -8.8, -7.6, -6.4, -5.2, -4.0, -2.8, + -1.6, -0.4, -13.3, -12.1, -10.9, -9.7, -8.5, -7.3, -6.1, -4.9, + -3.7, -2.5, -1.3, -0.1}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1, 2, 0.5}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oC, oD, oH, oW}, + {-42520.597656, -42344.199219, -41991.402344, -41814.996094, + -40932.992188, -40756.597656, -40403.800781, -40227.406250, + -41953.601562, -41779.601562, -41431.597656, -41257.601562, + -40387.601562, -40213.597656, -39865.601562, -39691.597656, + -41391.105469, -41219.492188, -40876.300781, -40704.699219, + -39846.707031, -39675.097656, -39331.898438, -39160.300781, + -17119.001953, -16942.599609, -16589.798828, -16413.400391, + -15531.399414, -15355.000000, -15002.199219, -14825.800781, + -16897.597656, -16723.597656, -16375.599609, -16201.599609, + -15331.599609, -15157.600586, -14809.601562, -14635.598633, + -16680.703125, -16509.099609, -16165.900391, -15994.300781, + -15136.300781, -14964.700195, -14621.500000, -14449.900391}, + sd::DataType::FLOAT32); + + input.linspace(150, -0.5); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, conv3d_test13) { - - int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=4,oH=3,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] - - NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {oC, kD, kH, kW, iC}, {-7., -6.7, -6.4, -6.1, -5.8, -5.5, -5.2, -4.9, -4.6, -4.3, -4., -3.7, -3.4, -3.1, -2.8, -2.5, -2.2, -1.9, -1.6, -1.3, - -1., -0.7, -0.4, -0.1, 0.2, 0.5, 0.8, 1.1, 1.4, 1.7, 2., 2.3, 2.6, 2.9, 3.2, 3.5, 3.8, 4.1, 4.4, 4.7, 5., 5.3, 5.6, 5.9, 6.2, 6.5, 6.8, 7.1, -6.9, -6.6, -6.3, - -6., -5.7, -5.4, -5.1, -4.8, -4.5, -4.2, -3.9, -3.6, -3.3, -3., -2.7, -2.4, -2.1, -1.8, -1.5, -1.2, -0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9, 1.2, 1.5, 1.8, 2.1, - 2.4, 2.7, 3., 3.3, 3.6, 3.9, 4.2, 4.5, 4.8, 5.1, 5.4, 5.7, 6., 6.3, 6.6, 6.9, 7.2, -6.8, -6.5, -6.2, -5.9, -5.6, -5.3, -5., -4.7, -4.4, -4.1, -3.8, -3.5, -3.2, - -2.9, -2.6, -2.3, -2., -1.7, -1.4, -1.1, -0.8, -0.5, -0.2, 0.1, 0.4, 0.7, 1., 1.3, 1.6, 1.9, 2.2, 2.5, 2.8, 3.1, 3.4, 3.7, 4., 4.3, 4.6, 4.9, 5.2, 5.5, 5.8, 6.1, 6.4, 6.7, 7., 7.3}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oD, oH, oW, oC}, {3969.399658, 4168.399902, 4362.899414, 3812.600586, 4005.200195, 4193.299805, 1317.000000, 1413.199829, 1504.899902, - 3498.999756, 3678.800049, 3854.100098, 3342.200195, 3515.599854, 3684.500244, 1139.400024, 1226.000000, 1308.099976, 685.799927, 772.400024, 854.500000, - 645.800049, 729.200073, 808.099976, 80.799995, 123.200012, 161.100006, -2851.000732, -2597.199707, -2347.899414, -2855.799805, -2611.600098, -2371.900879, - -2124.399414, -2003.199951, -1886.500244, -2865.399902, -2640.400146, -2419.899902, -2870.199951, -2654.800049, -2443.899902, -2045.200073, -1938.399902, - -1836.100220, -2596.000244, -2489.199707, -2386.900146, -2540.799561, -2438.800049, -2341.300049, -1539.699951, -1488.400024, -1441.599854, -2894.200195, - -2726.800049, -2563.899902, -2899.000488, -2741.199707, -2587.899658, -1886.800171, -1808.800049, -1735.300171, -2908.599121, -2770.000488, -2635.900146, -2913.400146, -2784.399658, -2659.899902, -1807.599976, -1743.999878, -1684.900146, -2099.199951, -2035.599976, -1976.500366, -2044.000244, -1985.199707, -1930.900024, -1161.699951, -1132.000122, -1106.800171, -2731.399902, -2647.599609, -2568.300293, -2580.999756, -2503.600098, -2430.699951, -1457.400024, -1418.800049, -1384.700073, -2280.200195, -2215.600098, -2155.500732, -2129.799561, -2071.600098, -2017.899780, -1174.200073, -1145.200195, -1120.699829, -1282.200073, -1253.199951, -1228.699951, -1168.599976, -1142.799927, -1121.500122, -615.199951, -601.600037, -592.500000, -1675.399658, -1706.800049, -1742.700073, -1832.200073, -1870.000000, -1912.299561, -814.199951, -833.200012, -856.699951, -2145.800049, -2196.399902, -2251.500244, -2302.600342, -2359.599854, -2421.100098, -991.800049, -1020.400024, -1053.500000, -754.199951, -782.800049, -815.900085, -794.199951, -825.999939, -862.299988, -293.600006, -308.800018, -328.500000, -3023.800293, -3115.600098, -3211.900391, -3028.599121, -3130.000244, -3235.899902, -1173.999878, -1225.600098, -1281.699951, -3038.200195, -3158.799805, -3283.899902, -3043.000000, -3173.199707, -3307.900391, -1094.800049, -1160.800049, -1231.300049, -608.799988, -674.799988, -745.300049, -553.599976, -624.400024, -699.700012, -27.700012, -62.799988, -102.400009, -3066.999512, -3245.199707, -3427.900391, -3071.800293, -3259.599854, -3451.900146, -936.400085, -1031.199951, -1130.500000, -3081.400146, -3288.400635, -3499.899414, -3086.200439, -3302.799805, -3523.899902, -857.199951, -966.400024, -1080.099976, -111.999969, -221.199936, -334.900024, -56.800079, -170.799988, -289.299927, 350.299927, 293.600037, 232.399979, 2683.000244, 2536.400146, 2385.300049, 2833.399658, 2680.400391, 2522.900391, 1940.999878, 1864.399902, 1783.300049, 3134.200195, 2968.399414, 2798.100098, 3284.600098, 3112.400391, 2935.699707, 2224.199707, 2138.000244, 2047.300049, 2807.399658, 2721.200195, 2630.500000, 2921.000000, 2831.599854, 2737.699707, 1775.200195, 1731.199951, 1682.699829}, sd::DataType::FLOAT32); - - input.linspace(75,-0.5); - - sd::ops::conv3dnew op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, iD = 4, iH = 3, iW = 3, iC = 4, oC = 3, kD = 3, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 3, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, + // kD, kH, kW, iC] + + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {oC, kD, kH, kW, iC}, + {-7., -6.7, -6.4, -6.1, -5.8, -5.5, -5.2, -4.9, -4.6, -4.3, -4., -3.7, + -3.4, -3.1, -2.8, -2.5, -2.2, -1.9, -1.6, -1.3, -1., -0.7, -0.4, -0.1, + 0.2, 0.5, 0.8, 1.1, 1.4, 1.7, 2., 2.3, 2.6, 2.9, 3.2, 3.5, + 3.8, 4.1, 4.4, 4.7, 5., 5.3, 5.6, 5.9, 6.2, 6.5, 6.8, 7.1, + -6.9, -6.6, -6.3, -6., -5.7, -5.4, -5.1, -4.8, -4.5, -4.2, -3.9, -3.6, + -3.3, -3., -2.7, -2.4, -2.1, -1.8, -1.5, -1.2, -0.9, -0.6, -0.3, 0., + 0.3, 0.6, 0.9, 1.2, 1.5, 1.8, 2.1, 2.4, 2.7, 3., 3.3, 3.6, + 3.9, 4.2, 4.5, 4.8, 5.1, 5.4, 5.7, 6., 6.3, 6.6, 6.9, 7.2, + -6.8, -6.5, -6.2, -5.9, -5.6, -5.3, -5., -4.7, -4.4, -4.1, -3.8, -3.5, + -3.2, -2.9, -2.6, -2.3, -2., -1.7, -1.4, -1.1, -0.8, -0.5, -0.2, 0.1, + 0.4, 0.7, 1., 1.3, 1.6, 1.9, 2.2, 2.5, 2.8, 3.1, 3.4, 3.7, + 4., 4.3, 4.6, 4.9, 5.2, 5.5, 5.8, 6.1, 6.4, 6.7, 7., 7.3}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1, 2, 0.5}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oD, oH, oW, oC}, + {3969.399658, 4168.399902, 4362.899414, 3812.600586, 4005.200195, + 4193.299805, 1317.000000, 1413.199829, 1504.899902, 3498.999756, + 3678.800049, 3854.100098, 3342.200195, 3515.599854, 3684.500244, + 1139.400024, 1226.000000, 1308.099976, 685.799927, 772.400024, + 854.500000, 645.800049, 729.200073, 808.099976, 80.799995, + 123.200012, 161.100006, -2851.000732, -2597.199707, -2347.899414, + -2855.799805, -2611.600098, -2371.900879, -2124.399414, -2003.199951, + -1886.500244, -2865.399902, -2640.400146, -2419.899902, -2870.199951, + -2654.800049, -2443.899902, -2045.200073, -1938.399902, -1836.100220, + -2596.000244, -2489.199707, -2386.900146, -2540.799561, -2438.800049, + -2341.300049, -1539.699951, -1488.400024, -1441.599854, -2894.200195, + -2726.800049, -2563.899902, -2899.000488, -2741.199707, -2587.899658, + -1886.800171, -1808.800049, -1735.300171, -2908.599121, -2770.000488, + -2635.900146, -2913.400146, -2784.399658, -2659.899902, -1807.599976, + -1743.999878, -1684.900146, -2099.199951, -2035.599976, -1976.500366, + -2044.000244, -1985.199707, -1930.900024, -1161.699951, -1132.000122, + -1106.800171, -2731.399902, -2647.599609, -2568.300293, -2580.999756, + -2503.600098, -2430.699951, -1457.400024, -1418.800049, -1384.700073, + -2280.200195, -2215.600098, -2155.500732, -2129.799561, -2071.600098, + -2017.899780, -1174.200073, -1145.200195, -1120.699829, -1282.200073, + -1253.199951, -1228.699951, -1168.599976, -1142.799927, -1121.500122, + -615.199951, -601.600037, -592.500000, -1675.399658, -1706.800049, + -1742.700073, -1832.200073, -1870.000000, -1912.299561, -814.199951, + -833.200012, -856.699951, -2145.800049, -2196.399902, -2251.500244, + -2302.600342, -2359.599854, -2421.100098, -991.800049, -1020.400024, + -1053.500000, -754.199951, -782.800049, -815.900085, -794.199951, + -825.999939, -862.299988, -293.600006, -308.800018, -328.500000, + -3023.800293, -3115.600098, -3211.900391, -3028.599121, -3130.000244, + -3235.899902, -1173.999878, -1225.600098, -1281.699951, -3038.200195, + -3158.799805, -3283.899902, -3043.000000, -3173.199707, -3307.900391, + -1094.800049, -1160.800049, -1231.300049, -608.799988, -674.799988, + -745.300049, -553.599976, -624.400024, -699.700012, -27.700012, + -62.799988, -102.400009, -3066.999512, -3245.199707, -3427.900391, + -3071.800293, -3259.599854, -3451.900146, -936.400085, -1031.199951, + -1130.500000, -3081.400146, -3288.400635, -3499.899414, -3086.200439, + -3302.799805, -3523.899902, -857.199951, -966.400024, -1080.099976, + -111.999969, -221.199936, -334.900024, -56.800079, -170.799988, + -289.299927, 350.299927, 293.600037, 232.399979, 2683.000244, + 2536.400146, 2385.300049, 2833.399658, 2680.400391, 2522.900391, + 1940.999878, 1864.399902, 1783.300049, 3134.200195, 2968.399414, + 2798.100098, 3284.600098, 3112.400391, 2935.699707, 2224.199707, + 2138.000244, 2047.300049, 2807.399658, 2721.200195, 2630.500000, + 2921.000000, 2831.599854, 2737.699707, 1775.200195, 1731.199951, + 1682.699829}, + sd::DataType::FLOAT32); + + input.linspace(75, -0.5); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) { - - int bS=2, iH=4,iW=3, iC=4,oC=3; - - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {1, 1, iC, oC}); - auto bias = NDArrayFactory::create('c', {oC}); - - - auto expOutput = NDArrayFactory::create('c', {bS, iH, iW, oC},{ 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, - 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, - 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, - 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f}); - input = 2.; - weights.linspace(0.1, 0.1); - bias = 1.; - - sd::ops::pointwise_conv2d op; - auto results = op.evaluate({&input, &weights, &bias}, {}, {dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3; + + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {1, 1, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, iH, iW, oC}, + {5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, + 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, + 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, + 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, + 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, + 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f}); + input = 2.; + weights.linspace(0.1, 0.1); + bias = 1.; + + sd::ops::pointwise_conv2d op; + auto results = op.evaluate({&input, &weights, &bias}, {}, {dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, vol2col_test1) { - - int bS=2, iD=2,iH=3,iW=2, iC=3,oC=2, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=3,oW=2; - - NDArray volume('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); - NDArray columns('c', {bS, iC, kD, kH, kW, oD, oH, oW}, sd::DataType::FLOAT32); - - columns = -1.; - volume.linspace(1); - - NDArray columnsExpected('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 2., 0., 4., 0., 6.,0., 8., 0., 10., 0., 12., 0., 3., 4., 5., 6., 0., 0., 9., 10., 11., 12., 0., 0., 4., 0., 6., 0., 0., 0., 10., 0., 12., 0., 0., 0., 5., 6., -0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 6., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 7., 8., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 8., 0., 10., 0., 12., 0., 0., 0., 0., 0., 0., 0., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 10., 0., 12., 0., 0., 0., 0., 0., -0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 13., 14., 15., 16., 17.,18., 19., 20., 21., 22., 23., 24., 14., 0., 16., 0., 18., 0., 20., 0., 22., 0., 24., 0., 15., 16., 17., 18., 0., 0., 21., 22., 23., 24., 0., -0., 16., 0., 18., 0., 0., 0., 22., 0., 24., 0., 0., 0., 17., 18., 0., 0., 0., 0., 23., 24., 0., 0., 0., 0., 18., 0., 0., 0., 0., 0., 24., 0., 0., 0., 0., 0., 19., 20., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 20., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0., 0., 21., 22., 23., -24., 0., 0., 0., 0., 0., 0., 0., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 23., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 24.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 26., 0., 28., 0., 30., 0., 32., 0., -34., 0., 36., 0., 27., 28., 29., 30., 0., 0., 33., 34., 35., 36., 0., 0., 28., 0., 30., 0., 0., 0., 34., 0., 36., 0., 0., 0., 29., 30., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 30., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 31., 32., 33., 34., 35., 36., 0., 0., 0., 0., 0., -0., 32., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 33., 34., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 37., 38., 39., 40., -41., 42., 43., 44., 45., 46., 47., 48., 38., 0., 40., 0., 42., 0., 44., 0., 46., 0., 48., 0., 39., 40., 41., 42., 0., 0., 45., 46., 47., 48., 0., 0., 40., 0., 42., 0., 0., 0., 46., 0., 48., 0., 0., 0., 41., 42., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 42., 0., 0., 0., 0., -0., 48., 0., 0., 0., 0., 0., 43., 44., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 44., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., -0., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 50., 0., 52., 0., 54.,0., 56., 0., 58., 0., 60., 0., 51., 52., 53., 54., 0., 0., 57., 58., 59., 60., 0., 0., 52., 0., 54., 0., 0., 0., 58., 0., 60., 0., 0., 0., -53., 54., 0., 0., 0., 0., 59., 60., 0., 0., 0., 0., 54., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 55., 56., 57., 58., 59., 60., 0., 0.,0., 0., 0., 0., 56., 0., 58., 0., 60., 0., 0., 0., 0., 0., 0., 0., 57., 58., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 58., 0., 60., 0., -0., 0., 0., 0., 0., 0., 0., 0., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 62., 0., 64., 0., 66., 0., 68., 0., 70., 0., 72., 0., 63., 64., 65., 66., 0., 0., 69., -70., 71., 72., 0., 0., 64., 0., 66., 0., 0., 0., 70., 0., 72., 0., 0., 0., 65., 66., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 66., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 67., 68., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 68., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., -0., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, sd::DataType::FLOAT32); - - graph::Context context(1); - sd::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); - // columns.printBuffer(); - - ASSERT_TRUE(columns.equalsTo(columnsExpected)); + int bS = 2, iD = 2, iH = 3, iW = 2, iC = 3, oC = 2, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 3, oW = 2; + + NDArray volume('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray columns('c', {bS, iC, kD, kH, kW, oD, oH, oW}, sd::DataType::FLOAT32); + + columns = -1.; + volume.linspace(1); + + NDArray columnsExpected( + 'c', {bS, iC, kD, kH, kW, oD, oH, oW}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 2., 0., + 4., 0., 6., 0., 8., 0., 10., 0., 12., 0., 3., 4., 5., 6., + 0., 0., 9., 10., 11., 12., 0., 0., 4., 0., 6., 0., 0., 0., + 10., 0., 12., 0., 0., 0., 5., 6., 0., 0., 0., 0., 11., 12., + 0., 0., 0., 0., 6., 0., 0., 0., 0., 0., 12., 0., 0., 0., + 0., 0., 7., 8., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., + 8., 0., 10., 0., 12., 0., 0., 0., 0., 0., 0., 0., 9., 10., + 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 10., 0., 12., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 11., 12., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., + 23., 24., 14., 0., 16., 0., 18., 0., 20., 0., 22., 0., 24., 0., + 15., 16., 17., 18., 0., 0., 21., 22., 23., 24., 0., 0., 16., 0., + 18., 0., 0., 0., 22., 0., 24., 0., 0., 0., 17., 18., 0., 0., + 0., 0., 23., 24., 0., 0., 0., 0., 18., 0., 0., 0., 0., 0., + 24., 0., 0., 0., 0., 0., 19., 20., 21., 22., 23., 24., 0., 0., + 0., 0., 0., 0., 20., 0., 22., 0., 24., 0., 0., 0., 0., 0., + 0., 0., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 0., 0., + 22., 0., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 23., 24., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 24., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 25., 26., 27., 28., 29., 30., + 31., 32., 33., 34., 35., 36., 26., 0., 28., 0., 30., 0., 32., 0., + 34., 0., 36., 0., 27., 28., 29., 30., 0., 0., 33., 34., 35., 36., + 0., 0., 28., 0., 30., 0., 0., 0., 34., 0., 36., 0., 0., 0., + 29., 30., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 30., 0., + 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 31., 32., 33., 34., + 35., 36., 0., 0., 0., 0., 0., 0., 32., 0., 34., 0., 36., 0., + 0., 0., 0., 0., 0., 0., 33., 34., 35., 36., 0., 0., 0., 0., + 0., 0., 0., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 37., 38., + 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 38., 0., 40., 0., + 42., 0., 44., 0., 46., 0., 48., 0., 39., 40., 41., 42., 0., 0., + 45., 46., 47., 48., 0., 0., 40., 0., 42., 0., 0., 0., 46., 0., + 48., 0., 0., 0., 41., 42., 0., 0., 0., 0., 47., 48., 0., 0., + 0., 0., 42., 0., 0., 0., 0., 0., 48., 0., 0., 0., 0., 0., + 43., 44., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 44., 0., + 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 45., 46., 47., 48., + 0., 0., 0., 0., 0., 0., 0., 0., 46., 0., 48., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., + 50., 0., 52., 0., 54., 0., 56., 0., 58., 0., 60., 0., 51., 52., + 53., 54., 0., 0., 57., 58., 59., 60., 0., 0., 52., 0., 54., 0., + 0., 0., 58., 0., 60., 0., 0., 0., 53., 54., 0., 0., 0., 0., + 59., 60., 0., 0., 0., 0., 54., 0., 0., 0., 0., 0., 60., 0., + 0., 0., 0., 0., 55., 56., 57., 58., 59., 60., 0., 0., 0., 0., + 0., 0., 56., 0., 58., 0., 60., 0., 0., 0., 0., 0., 0., 0., + 57., 58., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 58., 0., + 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 59., 60., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 61., 62., 63., 64., 65., 66., 67., 68., + 69., 70., 71., 72., 62., 0., 64., 0., 66., 0., 68., 0., 70., 0., + 72., 0., 63., 64., 65., 66., 0., 0., 69., 70., 71., 72., 0., 0., + 64., 0., 66., 0., 0., 0., 70., 0., 72., 0., 0., 0., 65., 66., + 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 66., 0., 0., 0., + 0., 0., 72., 0., 0., 0., 0., 0., 67., 68., 69., 70., 71., 72., + 0., 0., 0., 0., 0., 0., 68., 0., 70., 0., 72., 0., 0., 0., + 0., 0., 0., 0., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., + 0., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 72., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, + sd::DataType::FLOAT32); + + graph::Context context(1); + sd::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, + pH, pW, dD, dH, dW); + // columns.printBuffer(); + + ASSERT_TRUE(columns.equalsTo(columnsExpected)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, vol2col_test2) { - - int bS=2, iD=2,iH=3,iW=2, iC=3,oC=2, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=3,oW=2; - - auto volume = NDArrayFactory::create('c', {iD, bS, iH, iC, iW}); - volume.permutei({1, 3, 0, 2, 4}); - volume.linspace(1); - - auto columns = NDArrayFactory::create('c', {kD, iC, kH, oW, kW, bS, oD, oH}); - columns.permutei({5, 1, 0, 2, 4, 6, 7, 3}); - columns = -1.; - auto columnsExpected = NDArrayFactory::create('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, -10.f, 11.f, 12.f, 2.f, 0.f, 4.f, 0.f, 6.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 4.f, 0.f, 6.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f, -9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 14.f, 0.f, 16.f, 0.f, 18.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 16.f, 0.f, 18.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, -23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 26.f, 0.f, 28.f, 0.f, 30.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 33.f, 34.f, 35.f, 36.f, -0.f, 0.f, 28.f, 0.f, 30.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 33.f, -34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 38.f, 0.f, 40.f, -0.f, 42.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 39.f, 40.f, 41.f, 42.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 40.f, 0.f, 42.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 41.f, 42.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 42.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 43.f, 44.f, 45.f, 46.f, 47.f, -48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 50.f, 0.f, 52.f, 0.f, 54.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 51.f, 52.f, 53.f, 54.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 52.f, 0.f, 54.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 53.f, 54.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, 0.f, 0.f, -0.f, 0.f, 54.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 62.f, 0.f, 64.f, 0.f, 66.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 63.f, 64.f, 65.f, 66.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 64.f, 0.f, 66.f, -0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 65.f, 66.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 66.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - - graph::Context context(1); - sd::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); - // columns.printBuffer(); - - ASSERT_TRUE(columns.equalsTo(columnsExpected)); + int bS = 2, iD = 2, iH = 3, iW = 2, iC = 3, oC = 2, kD = 2, kH = 3, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 3, oW = 2; + + auto volume = NDArrayFactory::create('c', {iD, bS, iH, iC, iW}); + volume.permutei({1, 3, 0, 2, 4}); + volume.linspace(1); + + auto columns = + NDArrayFactory::create('c', {kD, iC, kH, oW, kW, bS, oD, oH}); + columns.permutei({5, 1, 0, 2, 4, 6, 7, 3}); + columns = -1.; + auto columnsExpected = NDArrayFactory::create( + 'c', {bS, iC, kD, kH, kW, oD, oH, oW}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 2.f, 0.f, 4.f, 0.f, 6.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, + 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, + 4.f, 0.f, 6.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, + 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, + 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 14.f, 0.f, 16.f, 0.f, 18.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, + 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, + 16.f, 0.f, 18.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, + 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, + 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, + 26.f, 0.f, 28.f, 0.f, 30.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, + 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, + 28.f, 0.f, 30.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, + 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, + 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, + 38.f, 0.f, 40.f, 0.f, 42.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, + 39.f, 40.f, 41.f, 42.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, + 40.f, 0.f, 42.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, + 41.f, 42.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, + 42.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, + 50.f, 0.f, 52.f, 0.f, 54.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, + 51.f, 52.f, 53.f, 54.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, + 52.f, 0.f, 54.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, + 53.f, 54.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, + 54.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, + 62.f, 0.f, 64.f, 0.f, 66.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, + 63.f, 64.f, 65.f, 66.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, + 64.f, 0.f, 66.f, 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, + 65.f, 66.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, + 66.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + + graph::Context context(1); + sd::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, + pH, pW, dD, dH, dW); + // columns.printBuffer(); + + ASSERT_TRUE(columns.equalsTo(columnsExpected)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, col2im_test1) { + int bS = 2, iH = 2, iW = 2, iC = 2, kH = 2, kW = 2, sD = 1, sH = 1, sW = 1, + pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oH = 2, oW = 2; - int bS=2, iH=2,iW=2, iC=2, kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oH=2,oW=2; - - auto image = NDArrayFactory::create('c', {bS, iC, iH, iW}); - image = -2.; - - auto columns = NDArrayFactory::create('c', {bS, iC, kH, kW, oH, oW}); - columns.linspace(1); + auto image = NDArrayFactory::create('c', {bS, iC, iH, iW}); + image = -2.; - auto imageExpected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {1.f, 7.f, 12.f, 34.f, 17.f, 39.f, 44.f, 98.f, 33.f, 71.f, 76.f, 162.f, 49.f, 103.f, 108.f, 226.f}); + auto columns = NDArrayFactory::create('c', {bS, iC, kH, kW, oH, oW}); + columns.linspace(1); + auto imageExpected = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {1.f, 7.f, 12.f, 34.f, 17.f, 39.f, 44.f, 98.f, 33.f, 71.f, 76.f, 162.f, + 49.f, 103.f, 108.f, 226.f}); - sd::ops::col2im op; - auto status = op.execute({&columns}, {&image}, {sH, sW, pH, pW, iH, iW, dH, dW, 0}); - ASSERT_EQ(Status::OK(), status); + sd::ops::col2im op; + auto status = + op.execute({&columns}, {&image}, {sH, sW, pH, pW, iH, iW, dH, dW, 0}); + ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(image.equalsTo(imageExpected)); + ASSERT_TRUE(image.equalsTo(imageExpected)); } - ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, upsampling2d_test1) { - - const int bS=3, iH=2,iW=2, iC=3; - const int factorH=2, factorW=3; - const int isNCHW = 0; // data format, default is NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - input.linspace(1); - - auto expOutput = NDArrayFactory::create('c', {bS, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, - 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, - 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, - 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f}); - - sd::ops::upsampling2d op; - auto results = op.evaluate({&input}, {factorH, factorW, isNCHW}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + const int bS = 3, iH = 2, iW = 2, iC = 3; + const int factorH = 2, factorW = 3; + const int isNCHW = 0; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + input.linspace(1); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, iH * factorH, iW * factorW, iC}, + {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, + 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, + 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, + 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, + 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f}); + + sd::ops::upsampling2d op; + auto results = op.evaluate({&input}, {factorH, factorW, isNCHW}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, upsampling2d_test2) { - - const int bS=3, iH=2,iW=2, iC=3; - const int factorH=2, factorW=3; - const int isNCHW = 1; // data format, default is NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - input.linspace(1); - - auto expOutput = NDArrayFactory::create('c', {bS, iC, iH*factorH, iW*factorW}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, - 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, - 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, - 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, - 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, - 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f}); - - sd::ops::upsampling2d op; - auto results = op.evaluate({&input}, {factorH, factorW, isNCHW}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + const int bS = 3, iH = 2, iW = 2, iC = 3; + const int factorH = 2, factorW = 3; + const int isNCHW = 1; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + input.linspace(1); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, iC, iH * factorH, iW * factorW}, + {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, + 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, + 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, + 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, + 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, + 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, + 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, + 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, + 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, + 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, + 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, + 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, + 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, + 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, + 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, + 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, + 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, + 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f}); + + sd::ops::upsampling2d op; + auto results = op.evaluate({&input}, {factorH, factorW, isNCHW}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, upsampling3d_test1) { - - const int bS=3, iD=2,iH=2,iW=2, iC=3; - const int factorD=2,factorH=3,factorW=2; - const int isNCDHW = 0; // data format, default is NCHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - input.linspace(1); - - auto expOutput = NDArrayFactory::create('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, - 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, - 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, - 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, - 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, - 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, - 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, - 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, - 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, - 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, - 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, - 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, - 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f}); - - sd::ops::upsampling3d op; - auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + const int bS = 3, iD = 2, iH = 2, iW = 2, iC = 3; + const int factorD = 2, factorH = 3, factorW = 2; + const int isNCDHW = 0; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + input.linspace(1); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, iD * factorD, iH * factorH, iW * factorW, iC}, + {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f}); + + sd::ops::upsampling3d op; + auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, upsampling3d_test2) { - - const int bS=3, iD=2,iH=2,iW=2, iC=3; - const int factorD=2,factorH=3,factorW=2; - const int isNCDHW = 1; // data format, default is NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - input.linspace(1); - - auto expOutput = NDArrayFactory::create('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, - 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, - 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, - 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, - 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, - 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, - 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, - 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, - 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, - 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, - 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, - 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f}); - - sd::ops::upsampling3d op; - auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + const int bS = 3, iD = 2, iH = 2, iW = 2, iC = 3; + const int factorD = 2, factorH = 3, factorW = 2; + const int isNCDHW = 1; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + input.linspace(1); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, iC, iD * factorD, iH * factorH, iW * factorW}, + {1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, + 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, + 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, + 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, + 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, + 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, + 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, + 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, + 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, + 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, + 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, + 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, + 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, + 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, + 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, + 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, + 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, + 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, + 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, + 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, + 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, + 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, + 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, + 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, + 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, + 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, + 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, + 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, + 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, + 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, + 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, + 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, + 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, + 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, + 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, + 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, + 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, + 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, + 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, + 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, + 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, + 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, + 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, + 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, + 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, + 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, + 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, + 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, + 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, + 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, + 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, + 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, + 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, + 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, + 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, + 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, + 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, + 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, + 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, + 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, + 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, + 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, + 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, + 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, + 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, + 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, + 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, + 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, + 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, + 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, + 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, + 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f}); + + sd::ops::upsampling3d op; + auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, upsampling3d_bp_test1) { + const int bS = 1, iD = 2, iH = 2, iW = 2, iC = 1; + const int factorD = 2, factorH = 2, factorW = 2; + const int isNCDHW = 1; // data format, default is NCHW - const int bS=1, iD=2,iH=2,iW=2, iC=1; - const int factorD=2, factorH=2, factorW=2; - const int isNCDHW = 1; // data format, default is NCHW + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto gradO = NDArrayFactory::create( + 'c', {bS, iC, iD * factorD, iH * factorH, iW * factorW}); + gradO = 1.; - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}); - gradO = 1.; + auto expGradI = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + expGradI = 8.; - auto expGradI = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - expGradI = 8.; + sd::ops::upsampling3d_bp op; + auto results = op.evaluate({&input, &gradO}, {isNCDHW}); + auto gradI = results.at(0); - sd::ops::upsampling3d_bp op; - auto results = op.evaluate({&input, &gradO}, {isNCDHW}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); } TYPED_TEST(TypedConvolutionTests1, conv2D_input_BP_test1) { + auto inputShape = NDArrayFactory::create('c', {4}, {2, 1, 4, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 3, 3}); + auto epsilonNext = NDArrayFactory::create('c', {2, 2, 4, 4}); + auto shapeArr = NDArrayFactory::create('c', {2, 1, 4, 4}); - auto inputShape = NDArrayFactory::create('c', {4}, {2, 1, 4, 4}); - auto weights = NDArrayFactory::create('c', {2, 1, 3, 3}); - auto epsilonNext = NDArrayFactory::create('c', {2, 2, 4, 4}); - auto shapeArr = NDArrayFactory::create('c', {2, 1, 4, 4}); - - - TypeParam _expEpsB[] = {952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0}; - NDArray expEps(_expEpsB, shapeArr.shapeInfo()); - - weights.linspace(1); - epsilonNext.linspace(1); - weights.permutei({2,3,1,0}); + TypeParam _expEpsB[] = { + 952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, + 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, + 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, + 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0}; + NDArray expEps(_expEpsB, shapeArr.shapeInfo()); - sd::ops::conv2d_input_bp op; + weights.linspace(1); + epsilonNext.linspace(1); + weights.permutei({2, 3, 1, 0}); - auto results = op.evaluate({&inputShape, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}); + sd::ops::conv2d_input_bp op; - ASSERT_TRUE(results.size() == 1); + auto results = op.evaluate({&inputShape, &weights, &epsilonNext}, {}, + {3, 3, 1, 1, 0, 0, 1, 1, 1}); - auto epsilon = results.at(0); + ASSERT_TRUE(results.size() == 1); - ASSERT_TRUE(shapeArr.isSameShape(epsilon)); - ASSERT_TRUE(expEps.equalsTo(epsilon)); + auto epsilon = results.at(0); + ASSERT_TRUE(shapeArr.isSameShape(epsilon)); + ASSERT_TRUE(expEps.equalsTo(epsilon)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, upsampling3d_bp_test3) { - - const int bS=1, iD=3,iH=3,iW=3, iC=2; - const int factorD=2, factorH=2, factorW=2; - const int isNCDHW = 1; // data format, default is NCHW - - NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, 0.31069338, - 0.44793984, 0.93800974, 0.32667395, 0.15187258, 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, 0.14696825, 0.26089668, - 0.13505761, 0.7562093, 0.27545404, 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, 0.31279507, 0.13591796, 0.5175439, - 0.32870287, 0.061735712, 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, 0.7215636, 0.40449402, 0.29908907, 0.4038839, - 0.9883108, 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, 0.19694561, - 0.6994972, 0.0743224, 0.42042503, 0.5842631, 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, 0.8759956, 0.5698191, 0.4458631, - 0.5277549, 0.016646361, 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, 0.3326449, 0.11739397, 0.017710684, 0.60847557, 0.52515227, - 0.9171938, 0.84989065, 0.5894228, 0.85227835, 0.39063585, 0.88968325, 0.6694452, 0.698873, 0.96147966, 0.15740126, 0.15736352, 0.49352047, - 0.5699365, 0.12683152, 0.11572781, 0.7863682, 0.737939, 0.49007934, 0.6084143, 0.9564999, 0.3900982, 0.14730452, 0.8506447, 0.49765033, - 0.07186628, 0.08214969, 0.035314173, 0.7320408, 0.36993408, 0.8406658, 0.27389422, 0.43179566, 0.13323106, 0.19297548, 0.24689731, 0.38641843, - 0.51154125, 0.19903564, 0.1416313, 0.69769853, 0.25363067, 0.78221816, 0.9300991, 0.3355119, 0.5588076, 0.6643576, 0.018850708, 0.63755876, - 0.2904297, 0.43490165, 0.84251267, 0.46609768, 0.38139546, 0.52318525, 0.9901826, 0.9257676, 0.6434591, 0.016828254, 0.9187561, 0.22897908, - 0.0063138064, 0.66597503, 0.19036093, 0.59552056, 0.69888055, 0.22146936, 0.9124342, 0.8708221, 0.7273687, 0.52397245, 0.66288394, 0.2188415, - 0.3354802, 0.03566524, 0.5101009, 0.5017283, 0.75122046, 0.1884508, 0.7407126, 0.6253045, 0.47145858, 0.5369367, 0.19884548, 0.99008304, - 0.08256686, 0.91884845, 0.02360027, 0.98895234, 0.3751719, 0.91783875, 0.4338776, 0.6783008, 0.6667967, 0.46720362, 0.7508773, 0.52304846, - 0.76631916, 0.4187526, 0.7653719, 0.5159193, 0.42730415, 0.49462363, 0.2731735, 0.8862948, 0.043214794, 0.3197591, 0.040378205, 0.5427239, - 0.9228089, 0.045940384, 0.70047987, 0.8419288, 0.53966296, 0.009444186, 0.038044546, 0.03158029, 0.43485752, 0.9204235, 0.5478789, 0.8290083, - 0.11868837, 0.0229866, 0.6639305, 0.8757367, 0.8279557, 0.76270294, 0.43242732, 0.4713431, 0.2569212, 0.30575937, 0.44395888, 0.99384075, - 0.6127142, 0.44844577, 0.6347944, 0.098358564, 0.34233716, 0.9329664, 0.65776783, 0.108565055, 0.2052629, 0.46441218, 0.041791342, 0.89369565, - 0.7000381, 0.2106213, 0.51152664, 0.44200692, 0.8293282, 0.20901772, 0.6387249, 0.8016979, 0.11178707, 0.109545894, 0.19654618, 0.060582615, - 0.08239174, 0.64630795, 0.32862368, 0.60225064, 0.8328141, 0.5484566, 0.8120276, 0.38822946, 0.6742381, 0.34913155, 0.42887798, 0.45344824, - 0.73956585, 0.9714739, 0.42937812, 0.45185348, 0.84535813, 0.046436775, 0.8802151, 0.8676222, 0.42625394, 0.4985318, 0.42399272, 0.122144565, - 0.0060101906, 0.47253844, 0.18123977, 0.86316174, 0.5863874, 0.3852012, 0.9785553, 0.0054711984, 0.88500834, 0.020897374, 0.27467912, 0.3852802, - 0.0766939, 0.94622654, 0.38687763, 0.3308602, 0.7770494, 0.9052543, 0.22258204, 0.42207044, 0.18050623, 0.21057767, 0.012561422, 0.7977821, - 0.61251044, 0.7203693, 0.6028265, 0.6036933, 0.1446382, 0.6712341, 0.76634467, 0.4854034, 0.26634562, 0.76523924, 0.16348523, 0.2663676, - 0.96846986, 0.8273284, 0.10700377, 0.7600526, 0.6771002, 0.47963092, 0.21264452, 0.56934077, 0.5514792, 0.85725874, 0.99090636, 0.54562527, - 0.93597686, 0.21142527, 0.4628326, 0.35011524, 0.31464386, 0.31164807, 0.65928996, 0.94418925, 0.39666295, 0.9496393, 0.103756346, 0.482158, - 0.49171793, 0.4108867, 0.22594318, 0.97093135, 0.5974685, 0.34632966, 0.54835194, 0.10499302, 0.9767778, 0.55008715, 0.54379046, 0.3583731, - 0.33369112, 0.04279039, 0.24939054, 0.23943715, 0.06775989, 0.7750291, 0.24329625, 0.4327169, 0.86916673, 0.80322117, 0.049972698, 0.47177452, - 0.37419558, 0.15303156, 0.121425234, 0.75884604, 0.8191354, 0.48554084, 0.053899214, 0.7858246, 0.39219773, 0.77579063, 0.34507045, 0.46070176, - 0.14496958, 0.47706795, 0.50678796, 0.64902323, 0.3277943, 0.0017530271, 0.6536156, 0.8582253, 0.95703506, 0.9963951, 0.8239163, 0.305142, - 0.012419582, 0.9498972, 0.1595827, 0.47947606, 0.5071124, 0.78227425, 0.2066719, 0.5217094, 0.7841406, 0.5260441, 0.49798164, 0.10975622, - 0.8633349, 0.76298475, 0.14295428, 0.6131504, 0.43794408, 0.50339264, 0.4504877, 0.19235311, 0.6678411, 0.80769485, 0.67495126, 0.96461457, - 0.10535406, 0.66438645, 0.4372345, 0.93851465, 0.8635335, 0.3405871, 0.45652762, 0.3636232, 0.52931345, 0.20154329, 0.07698499, 0.6125804, - 0.3583082, 0.3894796, 0.32601944, 0.5237369, 0.66683626, 0.08541841, 0.4815708, 0.11897489, 0.97555137, 0.3602705, 0.9620871, 0.6361821, - 0.71167386, 0.5134439, 0.57761437, 0.58598644, 0.39387667, 0.6966405, 0.46841687, 0.85788506, 0.9957087, 0.051309288, 0.24846801, 0.55938333, - 0.10230542, 0.9370694, 0.57527155, 0.54656035, 0.28896323, 0.51303476, 0.8865, 0.38641605, 0.9836358}, sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iC, iD, iH, iW}, {3.510932, 3.4310975, 3.538762, 4.148549, 2.8380678, 2.5431657, 3.3928843, 3.228055, 3.1467278, - 3.2603023, 5.611751, 4.334653, 3.3697734, 4.603307, 4.4357986, 4.32991, 3.0532732, 3.1370173, 4.181534, 2.9965065, 2.8553872, 5.2719016, - 4.5671935, 3.7027276, 3.3517184, 5.2544537, 3.5107024, 4.1496124, 3.9333878, 3.1798909, 3.1446428, 3.0932689, 3.9730802, 3.0466917, - 4.9675374, 4.769673, 3.766952, 3.6375027, 3.6492167, 4.9440994, 3.8379507, 3.467589, 4.719474, 3.1295977, 4.5177174, 4.2760015, 2.8443856, - 4.225355, 4.377341, 4.4398847, 4.710785, 4.4199953, 3.928307, 4.8769503}, sd::DataType::FLOAT32); - - sd::ops::upsampling3d_bp op; - auto results = op.evaluate({&input, &gradO}, {isNCDHW}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - + const int bS = 1, iD = 3, iH = 3, iW = 3, iC = 2; + const int factorD = 2, factorH = 2, factorW = 2; + const int isNCDHW = 1; // data format, default is NCHW + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray gradO( + 'c', {bS, iC, iD * factorD, iH * factorH, iW * factorW}, + {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, + 0.31069338, 0.44793984, 0.93800974, 0.32667395, 0.15187258, + 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, + 0.14696825, 0.26089668, 0.13505761, 0.7562093, 0.27545404, + 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, + 0.31279507, 0.13591796, 0.5175439, 0.32870287, 0.061735712, + 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, + 0.7215636, 0.40449402, 0.29908907, 0.4038839, 0.9883108, + 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, + 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, + 0.19694561, 0.6994972, 0.0743224, 0.42042503, 0.5842631, + 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, + 0.8759956, 0.5698191, 0.4458631, 0.5277549, 0.016646361, + 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, + 0.3326449, 0.11739397, 0.017710684, 0.60847557, 0.52515227, + 0.9171938, 0.84989065, 0.5894228, 0.85227835, 0.39063585, + 0.88968325, 0.6694452, 0.698873, 0.96147966, 0.15740126, + 0.15736352, 0.49352047, 0.5699365, 0.12683152, 0.11572781, + 0.7863682, 0.737939, 0.49007934, 0.6084143, 0.9564999, + 0.3900982, 0.14730452, 0.8506447, 0.49765033, 0.07186628, + 0.08214969, 0.035314173, 0.7320408, 0.36993408, 0.8406658, + 0.27389422, 0.43179566, 0.13323106, 0.19297548, 0.24689731, + 0.38641843, 0.51154125, 0.19903564, 0.1416313, 0.69769853, + 0.25363067, 0.78221816, 0.9300991, 0.3355119, 0.5588076, + 0.6643576, 0.018850708, 0.63755876, 0.2904297, 0.43490165, + 0.84251267, 0.46609768, 0.38139546, 0.52318525, 0.9901826, + 0.9257676, 0.6434591, 0.016828254, 0.9187561, 0.22897908, + 0.0063138064, 0.66597503, 0.19036093, 0.59552056, 0.69888055, + 0.22146936, 0.9124342, 0.8708221, 0.7273687, 0.52397245, + 0.66288394, 0.2188415, 0.3354802, 0.03566524, 0.5101009, + 0.5017283, 0.75122046, 0.1884508, 0.7407126, 0.6253045, + 0.47145858, 0.5369367, 0.19884548, 0.99008304, 0.08256686, + 0.91884845, 0.02360027, 0.98895234, 0.3751719, 0.91783875, + 0.4338776, 0.6783008, 0.6667967, 0.46720362, 0.7508773, + 0.52304846, 0.76631916, 0.4187526, 0.7653719, 0.5159193, + 0.42730415, 0.49462363, 0.2731735, 0.8862948, 0.043214794, + 0.3197591, 0.040378205, 0.5427239, 0.9228089, 0.045940384, + 0.70047987, 0.8419288, 0.53966296, 0.009444186, 0.038044546, + 0.03158029, 0.43485752, 0.9204235, 0.5478789, 0.8290083, + 0.11868837, 0.0229866, 0.6639305, 0.8757367, 0.8279557, + 0.76270294, 0.43242732, 0.4713431, 0.2569212, 0.30575937, + 0.44395888, 0.99384075, 0.6127142, 0.44844577, 0.6347944, + 0.098358564, 0.34233716, 0.9329664, 0.65776783, 0.108565055, + 0.2052629, 0.46441218, 0.041791342, 0.89369565, 0.7000381, + 0.2106213, 0.51152664, 0.44200692, 0.8293282, 0.20901772, + 0.6387249, 0.8016979, 0.11178707, 0.109545894, 0.19654618, + 0.060582615, 0.08239174, 0.64630795, 0.32862368, 0.60225064, + 0.8328141, 0.5484566, 0.8120276, 0.38822946, 0.6742381, + 0.34913155, 0.42887798, 0.45344824, 0.73956585, 0.9714739, + 0.42937812, 0.45185348, 0.84535813, 0.046436775, 0.8802151, + 0.8676222, 0.42625394, 0.4985318, 0.42399272, 0.122144565, + 0.0060101906, 0.47253844, 0.18123977, 0.86316174, 0.5863874, + 0.3852012, 0.9785553, 0.0054711984, 0.88500834, 0.020897374, + 0.27467912, 0.3852802, 0.0766939, 0.94622654, 0.38687763, + 0.3308602, 0.7770494, 0.9052543, 0.22258204, 0.42207044, + 0.18050623, 0.21057767, 0.012561422, 0.7977821, 0.61251044, + 0.7203693, 0.6028265, 0.6036933, 0.1446382, 0.6712341, + 0.76634467, 0.4854034, 0.26634562, 0.76523924, 0.16348523, + 0.2663676, 0.96846986, 0.8273284, 0.10700377, 0.7600526, + 0.6771002, 0.47963092, 0.21264452, 0.56934077, 0.5514792, + 0.85725874, 0.99090636, 0.54562527, 0.93597686, 0.21142527, + 0.4628326, 0.35011524, 0.31464386, 0.31164807, 0.65928996, + 0.94418925, 0.39666295, 0.9496393, 0.103756346, 0.482158, + 0.49171793, 0.4108867, 0.22594318, 0.97093135, 0.5974685, + 0.34632966, 0.54835194, 0.10499302, 0.9767778, 0.55008715, + 0.54379046, 0.3583731, 0.33369112, 0.04279039, 0.24939054, + 0.23943715, 0.06775989, 0.7750291, 0.24329625, 0.4327169, + 0.86916673, 0.80322117, 0.049972698, 0.47177452, 0.37419558, + 0.15303156, 0.121425234, 0.75884604, 0.8191354, 0.48554084, + 0.053899214, 0.7858246, 0.39219773, 0.77579063, 0.34507045, + 0.46070176, 0.14496958, 0.47706795, 0.50678796, 0.64902323, + 0.3277943, 0.0017530271, 0.6536156, 0.8582253, 0.95703506, + 0.9963951, 0.8239163, 0.305142, 0.012419582, 0.9498972, + 0.1595827, 0.47947606, 0.5071124, 0.78227425, 0.2066719, + 0.5217094, 0.7841406, 0.5260441, 0.49798164, 0.10975622, + 0.8633349, 0.76298475, 0.14295428, 0.6131504, 0.43794408, + 0.50339264, 0.4504877, 0.19235311, 0.6678411, 0.80769485, + 0.67495126, 0.96461457, 0.10535406, 0.66438645, 0.4372345, + 0.93851465, 0.8635335, 0.3405871, 0.45652762, 0.3636232, + 0.52931345, 0.20154329, 0.07698499, 0.6125804, 0.3583082, + 0.3894796, 0.32601944, 0.5237369, 0.66683626, 0.08541841, + 0.4815708, 0.11897489, 0.97555137, 0.3602705, 0.9620871, + 0.6361821, 0.71167386, 0.5134439, 0.57761437, 0.58598644, + 0.39387667, 0.6966405, 0.46841687, 0.85788506, 0.9957087, + 0.051309288, 0.24846801, 0.55938333, 0.10230542, 0.9370694, + 0.57527155, 0.54656035, 0.28896323, 0.51303476, 0.8865, + 0.38641605, 0.9836358}, + sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iD, iH, iW}, + {3.510932, 3.4310975, 3.538762, 4.148549, 2.8380678, 2.5431657, + 3.3928843, 3.228055, 3.1467278, 3.2603023, 5.611751, 4.334653, + 3.3697734, 4.603307, 4.4357986, 4.32991, 3.0532732, 3.1370173, + 4.181534, 2.9965065, 2.8553872, 5.2719016, 4.5671935, 3.7027276, + 3.3517184, 5.2544537, 3.5107024, 4.1496124, 3.9333878, 3.1798909, + 3.1446428, 3.0932689, 3.9730802, 3.0466917, 4.9675374, 4.769673, + 3.766952, 3.6375027, 3.6492167, 4.9440994, 3.8379507, 3.467589, + 4.719474, 3.1295977, 4.5177174, 4.2760015, 2.8443856, 4.225355, + 4.377341, 4.4398847, 4.710785, 4.4199953, 3.928307, 4.8769503}, + sd::DataType::FLOAT32); + + sd::ops::upsampling3d_bp op; + auto results = op.evaluate({&input, &gradO}, {isNCDHW}); + auto gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); } - ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test1) { - - int bS=2, oH=4,oW=4, oC=5,iC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int iH=3,iW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}); - auto exp = NDArrayFactory::create('c', {bS, oH, oW, oC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); - input = 0.5; - weights.linspace(0.1, 0.1); - - sd::ops::deconv2d op; - auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + int bS = 2, oH = 4, oW = 4, oC = 5, iC = 10, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int iH = 3, iW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}); + auto exp = NDArrayFactory::create( + 'c', {bS, oH, oW, oC}, + {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, + 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, + 42.75f, 47.75f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, + 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, + 115.5f, 125.5f, 135.5f, 145.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, + 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, + 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, 52.75f, 57.75f, 62.75f, + 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, + 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, 2.75f, + 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, + 47.75f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, + 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, + 125.5f, 135.5f, 145.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, + 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, 52.75f, 57.75f, 62.75f, 67.75f, + 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, + 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH, kW, sH, sW, pH, pW, dH, + dW, paddingMode, dataFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test2) { - - int bS=2, iH=4,iW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=4; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f }); - input = 0.5; - weights.linspace(0.1, 0.1); - - sd::ops::deconv2d op; - auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + int bS = 2, iH = 4, iW = 4, iC = 5, oC = 10, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto exp = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, + 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, + 60.5f, 70.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, + 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, + 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, + 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, + 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, + 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, + 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 2.75f, + 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, + 70.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, + 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, + 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, + 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 161.f, 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, + 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, + 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH, kW, sH, sW, pH, pW, dH, + dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test3) { - - int bS=1, oH=5,oW=5, oC=3,iC=2, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=2,dW=2; - int iH=3,iW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}); - auto bias = NDArrayFactory::create('c', {oC}); - - auto exp = NDArrayFactory::create('c', {bS, oH, oW, oC}, {-2.9f, -6.8f, -10.7f, -2.6f, -6.1f, -9.6f, -16.9f, -23.9f, -30.9f, -13.1f, -16.6f, -20.1f, -11.6f, -14.7f, -17.8f, -2.0f, -4.7f, -7.4f, -1.7f, -4.0f, -6.3f, -11.5f, -16.1f, - -20.7f, -8.6f, -10.9f, -13.2f, -7.1f, -9.0f, -10.9f, -27.4f, -32.8f, -38.2f, -24.4f, -29.0f, -33.6f, -65.0f, -74.2f, -83.4f, -38.2f, -42.8f, -47.4f, - -32.8f, -36.6f, -40.4f, -18.2f, -20.9f, -23.6f, -15.5f, -17.8f, -20.1f, -39.1f, -43.7f, -48.3f, -22.4f, -24.7f, -27.0f, -18.5f, -20.4f, -22.3f, -10.1f, -11.6f, -13.1f, - -7.4f, -8.5f, -9.6f, -19.3f, -21.5f, -23.7f, -10.7f, -11.8f, -12.9f, -6.8f, -7.5f, -8.2f}); - - input.linspace(-10, 0.5); - weights.linspace(0.1, 0.1); - bias = 0.2; - - sd::ops::deconv2d op; - auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 1, oH = 5, oW = 5, oC = 3, iC = 2, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 2, dW = 2; + int iH = 3, iW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}); + auto bias = NDArrayFactory::create('c', {oC}); + + auto exp = NDArrayFactory::create( + 'c', {bS, oH, oW, oC}, + {-2.9f, -6.8f, -10.7f, -2.6f, -6.1f, -9.6f, -16.9f, -23.9f, -30.9f, + -13.1f, -16.6f, -20.1f, -11.6f, -14.7f, -17.8f, -2.0f, -4.7f, -7.4f, + -1.7f, -4.0f, -6.3f, -11.5f, -16.1f, -20.7f, -8.6f, -10.9f, -13.2f, + -7.1f, -9.0f, -10.9f, -27.4f, -32.8f, -38.2f, -24.4f, -29.0f, -33.6f, + -65.0f, -74.2f, -83.4f, -38.2f, -42.8f, -47.4f, -32.8f, -36.6f, -40.4f, + -18.2f, -20.9f, -23.6f, -15.5f, -17.8f, -20.1f, -39.1f, -43.7f, -48.3f, + -22.4f, -24.7f, -27.0f, -18.5f, -20.4f, -22.3f, -10.1f, -11.6f, -13.1f, + -7.4f, -8.5f, -9.6f, -19.3f, -21.5f, -23.7f, -10.7f, -11.8f, -12.9f, + -6.8f, -7.5f, -8.2f}); + + input.linspace(-10, 0.5); + weights.linspace(0.1, 0.1); + bias = 0.2; + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH, kW, sH, sW, pH, pW, dH, + dW, paddingMode, dataFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test4) { - - NDArray input('c', {2, 3, 4, 4}, sd::DataType::FLOAT32); - NDArray weights('c', {3, 3, 5, 5}, sd::DataType::FLOAT32); - NDArray exp('c', {2,3,8,8}, {6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0, - 100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0, - 84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0, - 54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0, - 90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0, - 8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0, - 144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0, - 118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0, - 115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0, - 268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,94701.0,143574.0,193470.0,195306.0,148047.0,99747.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0, - 52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0, - 78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0, - 89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0}, sd::DataType::FLOAT32); - - input.linspace(1); - weights.linspace(1); - weights.permutei({2,3,1,0}); - - sd::ops::deconv2d op; - auto result = op.evaluate({&input, &weights}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); - - auto z = result.at(0); - // z->printShapeInfo(); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + NDArray input('c', {2, 3, 4, 4}, sd::DataType::FLOAT32); + NDArray weights('c', {3, 3, 5, 5}, sd::DataType::FLOAT32); + NDArray exp( + 'c', {2, 3, 8, 8}, + {6276.0, 12831.0, 19668.0, 26790.0, 27012.0, 20703.0, 14100.0, + 7200.0, 13719.0, 28023.0, 42918.0, 58410.0, 58902.0, 45105.0, + 30693.0, 15660.0, 22389.0, 45696.0, 69930.0, 95100.0, 95910.0, + 73386.0, 49899.0, 25440.0, 32346.0, 65970.0, 100884.0, 137100.0, + 138276.0, 105726.0, 71838.0, 36600.0, 33726.0, 68790.0, 105204.0, + 142980.0, 144156.0, 110226.0, 74898.0, 38160.0, 27555.0, 56154.0, + 85806.0, 116520.0, 117474.0, 89748.0, 60933.0, 31020.0, 19917.0, + 40557.0, 61926.0, 84030.0, 84714.0, 64671.0, 43875.0, 22320.0, + 10752.0, 21879.0, 33384.0, 45270.0, 45636.0, 34815.0, 23604.0, + 12000.0, 7551.0, 15456.0, 23718.0, 32340.0, 32562.0, 24978.0, + 17025.0, 8700.0, 16569.0, 33873.0, 51918.0, 70710.0, 71202.0, + 54555.0, 37143.0, 18960.0, 27114.0, 55371.0, 84780.0, 115350.0, + 116160.0, 88911.0, 60474.0, 30840.0, 39246.0, 80070.0, 122484.0, + 166500.0, 167676.0, 128226.0, 87138.0, 44400.0, 40626.0, 82890.0, + 126804.0, 172380.0, 173556.0, 132726.0, 90198.0, 45960.0, 33180.0, + 67629.0, 103356.0, 140370.0, 141324.0, 107973.0, 73308.0, 37320.0, + 23967.0, 48807.0, 74526.0, 101130.0, 101814.0, 77721.0, 52725.0, + 26820.0, 12927.0, 26304.0, 40134.0, 54420.0, 54786.0, 41790.0, + 28329.0, 14400.0, 8826.0, 18081.0, 27768.0, 37890.0, 38112.0, + 29253.0, 19950.0, 10200.0, 19419.0, 39723.0, 60918.0, 83010.0, + 83502.0, 64005.0, 43593.0, 22260.0, 31839.0, 65046.0, 99630.0, + 135600.0, 136410.0, 104436.0, 71049.0, 36240.0, 46146.0, 94170.0, + 144084.0, 195900.0, 197076.0, 150726.0, 102438.0, 52200.0, 47526.0, + 96990.0, 148404.0, 201780.0, 202956.0, 155226.0, 105498.0, 53760.0, + 38805.0, 79104.0, 120906.0, 164220.0, 165174.0, 126198.0, 85683.0, + 43620.0, 28017.0, 57057.0, 87126.0, 118230.0, 118914.0, 90771.0, + 61575.0, 31320.0, 15102.0, 30729.0, 46884.0, 63570.0, 63936.0, + 48765.0, 33054.0, 16800.0, 17220.0, 34863.0, 52932.0, 71430.0, + 72228.0, 54831.0, 36996.0, 18720.0, 36327.0, 73527.0, 111606.0, + 150570.0, 152214.0, 115521.0, 77925.0, 39420.0, 57381.0, 116112.0, + 176202.0, 237660.0, 240198.0, 182250.0, 122907.0, 62160.0, 80442.0, + 162738.0, 246900.0, 332940.0, 336420.0, 255198.0, 172062.0, 87000.0, + 84702.0, 171318.0, 259860.0, 350340.0, 353820.0, 268338.0, 180882.0, + 91440.0, 66867.0, 135210.0, 205038.0, 276360.0, 279042.0, 211572.0, + 142581.0, 72060.0, 46845.0, 94701.0, 143574.0, 193470.0, 195306.0, + 148047.0, 99747.0, 50400.0, 24576.0, 49671.0, 75288.0, 101430.0, + 102372.0, 77583.0, 52260.0, 26400.0, 22095.0, 44688.0, 67782.0, + 91380.0, 92178.0, 69906.0, 47121.0, 23820.0, 46377.0, 93777.0, + 142206.0, 191670.0, 193314.0, 146571.0, 98775.0, 49920.0, 72906.0, + 147387.0, 223452.0, 301110.0, 303648.0, 230175.0, 155082.0, 78360.0, + 101742.0, 205638.0, 311700.0, 419940.0, 423420.0, 320898.0, 216162.0, + 109200.0, 106002.0, 214218.0, 324660.0, 437340.0, 440820.0, 334038.0, + 224982.0, 113640.0, 83292.0, 168285.0, 254988.0, 343410.0, 346092.0, + 262197.0, 176556.0, 89160.0, 58095.0, 117351.0, 177774.0, 239370.0, + 241206.0, 182697.0, 122997.0, 62100.0, 30351.0, 61296.0, 92838.0, + 124980.0, 125922.0, 95358.0, 64185.0, 32400.0, 26970.0, 54513.0, + 82632.0, 111330.0, 112128.0, 84981.0, 57246.0, 28920.0, 56427.0, + 114027.0, 172806.0, 232770.0, 234414.0, 177621.0, 119625.0, 60420.0, + 88431.0, 178662.0, 270702.0, 364560.0, 367098.0, 278100.0, 187257.0, + 94560.0, 123042.0, 248538.0, 376500.0, 506940.0, 510420.0, 386598.0, + 260262.0, 131400.0, 127302.0, 257118.0, 389460.0, 524340.0, 527820.0, + 399738.0, 269082.0, 135840.0, 99717.0, 201360.0, 304938.0, 410460.0, + 413142.0, 312822.0, 210531.0, 106260.0, 69345.0, 140001.0, 211974.0, + 285270.0, 287106.0, 217347.0, 146247.0, 73800.0, 36126.0, 72921.0, + 110388.0, 148530.0, 149472.0, 113133.0, 76110.0, 38400.0}, + sd::DataType::FLOAT32); + + input.linspace(1); + weights.linspace(1); + weights.permutei({2, 3, 1, 0}); + + sd::ops::deconv2d op; + auto result = op.evaluate({&input, &weights}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); + + auto z = result.at(0); + // z->printShapeInfo(); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test5) { - Nd4jLong _expS[] = {4, 2, 3, 8, 8, 192, 64, 8, 1, 16384, 1, 99}; - double _expB[] = {6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0,100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0,84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0,54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0,90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0,8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0,144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0,118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0,115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0,268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,94701.0,143574.0,193470.0,195306.0,148047.0,99747.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0,52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0,78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0,89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0,}; - NDArray exp(_expB, _expS); - - auto input = NDArrayFactory::create('c', {2, 3, 4, 4}); - auto weights = NDArrayFactory::create('c', {3, 3, 5, 5}); - auto z = NDArrayFactory::create('c', {2, 3, 8, 8}); - - input.linspace(1); - weights.linspace(1); - weights.permutei({2,3,1,0}); - - sd::ops::deconv2d op; - auto result = op.execute({&input, &weights}, {&z}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, result); - - ASSERT_TRUE(exp.isSameShape(&z)); - ASSERT_TRUE(exp.equalsTo(&z)); + Nd4jLong _expS[] = {4, 2, 3, 8, 8, 192, 64, 8, 1, 16384, 1, 99}; + double _expB[] = { + 6276.0, 12831.0, 19668.0, 26790.0, 27012.0, 20703.0, 14100.0, + 7200.0, 13719.0, 28023.0, 42918.0, 58410.0, 58902.0, 45105.0, + 30693.0, 15660.0, 22389.0, 45696.0, 69930.0, 95100.0, 95910.0, + 73386.0, 49899.0, 25440.0, 32346.0, 65970.0, 100884.0, 137100.0, + 138276.0, 105726.0, 71838.0, 36600.0, 33726.0, 68790.0, 105204.0, + 142980.0, 144156.0, 110226.0, 74898.0, 38160.0, 27555.0, 56154.0, + 85806.0, 116520.0, 117474.0, 89748.0, 60933.0, 31020.0, 19917.0, + 40557.0, 61926.0, 84030.0, 84714.0, 64671.0, 43875.0, 22320.0, + 10752.0, 21879.0, 33384.0, 45270.0, 45636.0, 34815.0, 23604.0, + 12000.0, 7551.0, 15456.0, 23718.0, 32340.0, 32562.0, 24978.0, + 17025.0, 8700.0, 16569.0, 33873.0, 51918.0, 70710.0, 71202.0, + 54555.0, 37143.0, 18960.0, 27114.0, 55371.0, 84780.0, 115350.0, + 116160.0, 88911.0, 60474.0, 30840.0, 39246.0, 80070.0, 122484.0, + 166500.0, 167676.0, 128226.0, 87138.0, 44400.0, 40626.0, 82890.0, + 126804.0, 172380.0, 173556.0, 132726.0, 90198.0, 45960.0, 33180.0, + 67629.0, 103356.0, 140370.0, 141324.0, 107973.0, 73308.0, 37320.0, + 23967.0, 48807.0, 74526.0, 101130.0, 101814.0, 77721.0, 52725.0, + 26820.0, 12927.0, 26304.0, 40134.0, 54420.0, 54786.0, 41790.0, + 28329.0, 14400.0, 8826.0, 18081.0, 27768.0, 37890.0, 38112.0, + 29253.0, 19950.0, 10200.0, 19419.0, 39723.0, 60918.0, 83010.0, + 83502.0, 64005.0, 43593.0, 22260.0, 31839.0, 65046.0, 99630.0, + 135600.0, 136410.0, 104436.0, 71049.0, 36240.0, 46146.0, 94170.0, + 144084.0, 195900.0, 197076.0, 150726.0, 102438.0, 52200.0, 47526.0, + 96990.0, 148404.0, 201780.0, 202956.0, 155226.0, 105498.0, 53760.0, + 38805.0, 79104.0, 120906.0, 164220.0, 165174.0, 126198.0, 85683.0, + 43620.0, 28017.0, 57057.0, 87126.0, 118230.0, 118914.0, 90771.0, + 61575.0, 31320.0, 15102.0, 30729.0, 46884.0, 63570.0, 63936.0, + 48765.0, 33054.0, 16800.0, 17220.0, 34863.0, 52932.0, 71430.0, + 72228.0, 54831.0, 36996.0, 18720.0, 36327.0, 73527.0, 111606.0, + 150570.0, 152214.0, 115521.0, 77925.0, 39420.0, 57381.0, 116112.0, + 176202.0, 237660.0, 240198.0, 182250.0, 122907.0, 62160.0, 80442.0, + 162738.0, 246900.0, 332940.0, 336420.0, 255198.0, 172062.0, 87000.0, + 84702.0, 171318.0, 259860.0, 350340.0, 353820.0, 268338.0, 180882.0, + 91440.0, 66867.0, 135210.0, 205038.0, 276360.0, 279042.0, 211572.0, + 142581.0, 72060.0, 46845.0, 94701.0, 143574.0, 193470.0, 195306.0, + 148047.0, 99747.0, 50400.0, 24576.0, 49671.0, 75288.0, 101430.0, + 102372.0, 77583.0, 52260.0, 26400.0, 22095.0, 44688.0, 67782.0, + 91380.0, 92178.0, 69906.0, 47121.0, 23820.0, 46377.0, 93777.0, + 142206.0, 191670.0, 193314.0, 146571.0, 98775.0, 49920.0, 72906.0, + 147387.0, 223452.0, 301110.0, 303648.0, 230175.0, 155082.0, 78360.0, + 101742.0, 205638.0, 311700.0, 419940.0, 423420.0, 320898.0, 216162.0, + 109200.0, 106002.0, 214218.0, 324660.0, 437340.0, 440820.0, 334038.0, + 224982.0, 113640.0, 83292.0, 168285.0, 254988.0, 343410.0, 346092.0, + 262197.0, 176556.0, 89160.0, 58095.0, 117351.0, 177774.0, 239370.0, + 241206.0, 182697.0, 122997.0, 62100.0, 30351.0, 61296.0, 92838.0, + 124980.0, 125922.0, 95358.0, 64185.0, 32400.0, 26970.0, 54513.0, + 82632.0, 111330.0, 112128.0, 84981.0, 57246.0, 28920.0, 56427.0, + 114027.0, 172806.0, 232770.0, 234414.0, 177621.0, 119625.0, 60420.0, + 88431.0, 178662.0, 270702.0, 364560.0, 367098.0, 278100.0, 187257.0, + 94560.0, 123042.0, 248538.0, 376500.0, 506940.0, 510420.0, 386598.0, + 260262.0, 131400.0, 127302.0, 257118.0, 389460.0, 524340.0, 527820.0, + 399738.0, 269082.0, 135840.0, 99717.0, 201360.0, 304938.0, 410460.0, + 413142.0, 312822.0, 210531.0, 106260.0, 69345.0, 140001.0, 211974.0, + 285270.0, 287106.0, 217347.0, 146247.0, 73800.0, 36126.0, 72921.0, + 110388.0, 148530.0, 149472.0, 113133.0, 76110.0, 38400.0, + }; + NDArray exp(_expB, _expS); + + auto input = NDArrayFactory::create('c', {2, 3, 4, 4}); + auto weights = NDArrayFactory::create('c', {3, 3, 5, 5}); + auto z = NDArrayFactory::create('c', {2, 3, 8, 8}); + + input.linspace(1); + weights.linspace(1); + weights.permutei({2, 3, 1, 0}); + + sd::ops::deconv2d op; + auto result = + op.execute({&input, &weights}, {&z}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result); + + ASSERT_TRUE(exp.isSameShape(&z)); + ASSERT_TRUE(exp.equalsTo(&z)); } TYPED_TEST(TypedConvolutionTests1, deconv2d_test6) { - - int bS=2, iH=4,iW=4, iC=3,oC=3, kH=5,kW=5, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=8,oW=8; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}, {1.f, 76.f, 151.f, 26.f, 101.f, 176.f, 51.f, 126.f, 201.f, 2.f, 77.f, 152.f, 27.f, 102.f, 177.f, 52.f, 127.f, 202.f, 3.f, 78.f, 153.f, 28.f, 103.f, 178.f, 53.f, 128.f, 203.f, - 4.f, 79.f, 154.f, 29.f, 104.f, 179.f, 54.f, 129.f, 204.f, 5.f, 80.f, 155.f, 30.f, 105.f, 180.f, 55.f, 130.f, 205.f, 6.f, 81.f, 156.f, 31.f, 106.f, 181.f, 56.f, 131.f, 206.f, - 7.f, 82.f, 157.f, 32.f, 107.f, 182.f, 57.f, 132.f, 207.f, 8.f, 83.f, 158.f, 33.f, 108.f, 183.f, 58.f, 133.f, 208.f, 9.f, 84.f, 159.f, 34.f, 109.f, 184.f, 59.f, 134.f, 209.f, - 10.f, 85.f, 160.f, 35.f, 110.f, 185.f, 60.f, 135.f, 210.f, 11.f, 86.f, 161.f, 36.f, 111.f, 186.f, 61.f, 136.f, 211.f, 12.f, 87.f, 162.f, 37.f, 112.f, 187.f, 62.f, 137.f, 212.f, - 13.f, 88.f, 163.f, 38.f, 113.f, 188.f, 63.f, 138.f, 213.f, 14.f, 89.f, 164.f, 39.f, 114.f, 189.f, 64.f, 139.f, 214.f, 15.f, 90.f, 165.f, 40.f, 115.f, 190.f, 65.f, 140.f, 215.f, - 16.f, 91.f, 166.f, 41.f, 116.f, 191.f, 66.f, 141.f, 216.f, 17.f, 92.f, 167.f, 42.f, 117.f, 192.f, 67.f, 142.f, 217.f, 18.f, 93.f, 168.f, 43.f, 118.f, 193.f, 68.f, 143.f, 218.f, - 19.f, 94.f, 169.f, 44.f, 119.f, 194.f, 69.f, 144.f, 219.f, 20.f, 95.f, 170.f, 45.f, 120.f, 195.f, 70.f, 145.f, 220.f, 21.f, 96.f, 171.f, 46.f, 121.f, 196.f, 71.f, 146.f, 221.f, - 22.f, 97.f, 172.f, 47.f, 122.f, 197.f, 72.f, 147.f, 222.f, 23.f, 98.f, 173.f, 48.f, 123.f, 198.f, 73.f, 148.f, 223.f, 24.f, 99.f, 174.f, 49.f, 124.f, 199.f, 74.f, 149.f, 224.f, - 25.f, 100.f, 175.f,50.f, 125.f, 200.f,75.f, 150.f, 225.f}); - - auto exp = NDArrayFactory::create('c', {bS, oC, oH, oW}, {6276.0f, 12831.0f, 19668.0f, 26790.0f, 27012.0f, 20703.0f, 14100.0f, 7200.0f, 13719.0f, 28023.0f, 42918.0f, 58410.0f, 58902.0f, 45105.0f, 30693.0f, 15660.0f, 22389.0f, 45696.0f, 69930.0f, 95100.0f, 95910.0f, 73386.0f, 49899.0f, 25440.0f, 32346.0f, 65970.0f, 100884.0f, 137100.0f, 138276.0f, 105726.0f, 71838.0f, 36600.0f, 33726.0f, 68790.0f, 105204.0f, 142980.0f, 144156.0f, 110226.0f, 74898.0f, 38160.0f, 27555.0f, 56154.0f, 85806.0f, 116520.0f, 117474.0f, 89748.0f, 60933.0f, 31020.0f, 19917.0f, 40557.0f, 61926.0f, 84030.0f, 84714.0f, 64671.0f, 43875.0f, 22320.0f, 10752.0f, 21879.0f, 33384.0f, 45270.0f, 45636.0f, 34815.0f, 23604.0f, 12000.0f, 7551.0f, 15456.0f, 23718.0f, 32340.0f, 32562.0f, 24978.0f, 17025.0f, 8700.0f, 16569.0f, 33873.0f, 51918.0f, 70710.0f, 71202.0f, 54555.0f, 37143.0f, 18960.0f, 27114.0f, 55371.0f, 84780.0f, 115350.0f, 116160.0f, 88911.0f, 60474.0f, 30840.0f, 39246.0f, 80070.0f, 122484.0f, 166500.0f, 167676.0f, 128226.0f, 87138.0f, 44400.0f, 40626.0f, 82890.0f, 126804.0f, 172380.0f, 173556.0f, 132726.0f, 90198.0f, 45960.0f, 33180.0f, 67629.0f, 103356.0f, 140370.0f, 141324.0f, 107973.0f, 73308.0f, 37320.0f, 23967.0f, 48807.0f, 74526.0f, 101130.0f, 101814.0f, 77721.0f, 52725.0f, 26820.0f, 12927.0f, 26304.0f, 40134.0f, 54420.0f, 54786.0f, 41790.0f, 28329.0f, 14400.0f, 8826.0f, 18081.0f, 27768.0f, 37890.0f, 38112.0f, 29253.0f, 19950.0f, 10200.0f, 19419.0f, 39723.0f, 60918.0f, 83010.0f, 83502.0f, 64005.0f, 43593.0f, 22260.0f, 31839.0f, 65046.0f, 99630.0f, 135600.0f, 136410.0f, 104436.0f, 71049.0f, 36240.0f, 46146.0f, 94170.0f, 144084.0f, 195900.0f, 197076.0f, 150726.0f, 102438.0f, 52200.0f, 47526.0f, 96990.0f, 148404.0f, 201780.0f, 202956.0f, 155226.0f, 105498.0f, 53760.0f, 38805.0f, 79104.0f, 120906.0f, 164220.0f, 165174.0f, 126198.0f, 85683.0f, 43620.0f, 28017.0f, 57057.0f, 87126.0f, 118230.0f, 118914.0f, 90771.0f, 61575.0f, 31320.0f, 15102.0f, 30729.0f, 46884.0f, 63570.0f, 63936.0f, 48765.0f, 33054.0f, 16800.0f, 17220.0f, 34863.0f, 52932.0f, 71430.0f, 72228.0f, 54831.0f, 36996.0f, 18720.0f, 36327.0f, 73527.0f, 111606.0f, 150570.0f, 152214.0f, 115521.0f, 77925.0f, 39420.0f, 57381.0f, 116112.0f, 176202.0f, 237660.0f, 240198.0f, 182250.0f, 122907.0f, 62160.0f, 80442.0f, 162738.0f, 246900.0f, 332940.0f, 336420.0f, 255198.0f, 172062.0f, 87000.0f, 84702.0f, 171318.0f, 259860.0f, 350340.0f, 353820.0f, 268338.0f, 180882.0f, 91440.0f, 66867.0f, 135210.0f, 205038.0f, 276360.0f, 279042.0f, 211572.0f, 142581.0f, 72060.0f, 46845.0f, 94701.0f, 143574.0f, 193470.0f, 195306.0f, 148047.0f, 99747.0f, 50400.0f, 24576.0f, 49671.0f, 75288.0f, 101430.0f, 102372.0f, 77583.0f, 52260.0f, 26400.0f, 22095.0f, 44688.0f, 67782.0f, 91380.0f, 92178.0f, 69906.0f, 47121.0f, 23820.0f, 46377.0f, 93777.0f, 142206.0f, 191670.0f, 193314.0f, 146571.0f, 98775.0f, 49920.0f, 72906.0f, 147387.0f, 223452.0f, 301110.0f, 303648.0f, 230175.0f, 155082.0f, 78360.0f, 101742.0f, 205638.0f, 311700.0f, 419940.0f, 423420.0f, 320898.0f, 216162.0f, 109200.0f, 106002.0f, 214218.0f, 324660.0f, 437340.0f, 440820.0f, 334038.0f, 224982.0f, 113640.0f, 83292.0f, 168285.0f, 254988.0f, 343410.0f, 346092.0f, 262197.0f, 176556.0f, 89160.0f, 58095.0f, 117351.0f, 177774.0f, 239370.0f, 241206.0f, 182697.0f, 122997.0f, 62100.0f, 30351.0f, 61296.0f, 92838.0f, 124980.0f, 125922.0f, 95358.0f, 64185.0f, 32400.0f, 26970.0f, 54513.0f, 82632.0f, 111330.0f, 112128.0f, 84981.0f, 57246.0f, 28920.0f, 56427.0f, 114027.0f, 172806.0f, 232770.0f, 234414.0f, 177621.0f, 119625.0f, 60420.0f, 88431.0f, 178662.0f, 270702.0f, 364560.0f, 367098.0f, 278100.0f, 187257.0f, 94560.0f, 123042.0f, 248538.0f, 376500.0f, 506940.0f, 510420.0f, 386598.0f, 260262.0f, 131400.0f, 127302.0f, 257118.0f, 389460.0f, 524340.0f, 527820.0f, 399738.0f, 269082.0f, 135840.0f, 99717.0f, 201360.0f, 304938.0f, 410460.0f, 413142.0f, 312822.0f, 210531.0f, 106260.0f, 69345.0f, 140001.0f, 211974.0f, 285270.0f, 287106.0f, 217347.0f, 146247.0f, 73800.0f, 36126.0f, 72921.0f, 110388.0f, 148530.0f, 149472.0f, 113133.0f, 76110.0f, 38400.0f}); - - input.linspace(1); - - sd::ops::deconv2d op; - auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 2, iH = 4, iW = 4, iC = 3, oC = 3, kH = 5, kW = 5, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 8, oW = 8; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create( + 'c', {kH, kW, oC, iC}, + {1.f, 76.f, 151.f, 26.f, 101.f, 176.f, 51.f, 126.f, 201.f, 2.f, + 77.f, 152.f, 27.f, 102.f, 177.f, 52.f, 127.f, 202.f, 3.f, 78.f, + 153.f, 28.f, 103.f, 178.f, 53.f, 128.f, 203.f, 4.f, 79.f, 154.f, + 29.f, 104.f, 179.f, 54.f, 129.f, 204.f, 5.f, 80.f, 155.f, 30.f, + 105.f, 180.f, 55.f, 130.f, 205.f, 6.f, 81.f, 156.f, 31.f, 106.f, + 181.f, 56.f, 131.f, 206.f, 7.f, 82.f, 157.f, 32.f, 107.f, 182.f, + 57.f, 132.f, 207.f, 8.f, 83.f, 158.f, 33.f, 108.f, 183.f, 58.f, + 133.f, 208.f, 9.f, 84.f, 159.f, 34.f, 109.f, 184.f, 59.f, 134.f, + 209.f, 10.f, 85.f, 160.f, 35.f, 110.f, 185.f, 60.f, 135.f, 210.f, + 11.f, 86.f, 161.f, 36.f, 111.f, 186.f, 61.f, 136.f, 211.f, 12.f, + 87.f, 162.f, 37.f, 112.f, 187.f, 62.f, 137.f, 212.f, 13.f, 88.f, + 163.f, 38.f, 113.f, 188.f, 63.f, 138.f, 213.f, 14.f, 89.f, 164.f, + 39.f, 114.f, 189.f, 64.f, 139.f, 214.f, 15.f, 90.f, 165.f, 40.f, + 115.f, 190.f, 65.f, 140.f, 215.f, 16.f, 91.f, 166.f, 41.f, 116.f, + 191.f, 66.f, 141.f, 216.f, 17.f, 92.f, 167.f, 42.f, 117.f, 192.f, + 67.f, 142.f, 217.f, 18.f, 93.f, 168.f, 43.f, 118.f, 193.f, 68.f, + 143.f, 218.f, 19.f, 94.f, 169.f, 44.f, 119.f, 194.f, 69.f, 144.f, + 219.f, 20.f, 95.f, 170.f, 45.f, 120.f, 195.f, 70.f, 145.f, 220.f, + 21.f, 96.f, 171.f, 46.f, 121.f, 196.f, 71.f, 146.f, 221.f, 22.f, + 97.f, 172.f, 47.f, 122.f, 197.f, 72.f, 147.f, 222.f, 23.f, 98.f, + 173.f, 48.f, 123.f, 198.f, 73.f, 148.f, 223.f, 24.f, 99.f, 174.f, + 49.f, 124.f, 199.f, 74.f, 149.f, 224.f, 25.f, 100.f, 175.f, 50.f, + 125.f, 200.f, 75.f, 150.f, 225.f}); + + auto exp = NDArrayFactory::create( + 'c', {bS, oC, oH, oW}, + {6276.0f, 12831.0f, 19668.0f, 26790.0f, 27012.0f, 20703.0f, + 14100.0f, 7200.0f, 13719.0f, 28023.0f, 42918.0f, 58410.0f, + 58902.0f, 45105.0f, 30693.0f, 15660.0f, 22389.0f, 45696.0f, + 69930.0f, 95100.0f, 95910.0f, 73386.0f, 49899.0f, 25440.0f, + 32346.0f, 65970.0f, 100884.0f, 137100.0f, 138276.0f, 105726.0f, + 71838.0f, 36600.0f, 33726.0f, 68790.0f, 105204.0f, 142980.0f, + 144156.0f, 110226.0f, 74898.0f, 38160.0f, 27555.0f, 56154.0f, + 85806.0f, 116520.0f, 117474.0f, 89748.0f, 60933.0f, 31020.0f, + 19917.0f, 40557.0f, 61926.0f, 84030.0f, 84714.0f, 64671.0f, + 43875.0f, 22320.0f, 10752.0f, 21879.0f, 33384.0f, 45270.0f, + 45636.0f, 34815.0f, 23604.0f, 12000.0f, 7551.0f, 15456.0f, + 23718.0f, 32340.0f, 32562.0f, 24978.0f, 17025.0f, 8700.0f, + 16569.0f, 33873.0f, 51918.0f, 70710.0f, 71202.0f, 54555.0f, + 37143.0f, 18960.0f, 27114.0f, 55371.0f, 84780.0f, 115350.0f, + 116160.0f, 88911.0f, 60474.0f, 30840.0f, 39246.0f, 80070.0f, + 122484.0f, 166500.0f, 167676.0f, 128226.0f, 87138.0f, 44400.0f, + 40626.0f, 82890.0f, 126804.0f, 172380.0f, 173556.0f, 132726.0f, + 90198.0f, 45960.0f, 33180.0f, 67629.0f, 103356.0f, 140370.0f, + 141324.0f, 107973.0f, 73308.0f, 37320.0f, 23967.0f, 48807.0f, + 74526.0f, 101130.0f, 101814.0f, 77721.0f, 52725.0f, 26820.0f, + 12927.0f, 26304.0f, 40134.0f, 54420.0f, 54786.0f, 41790.0f, + 28329.0f, 14400.0f, 8826.0f, 18081.0f, 27768.0f, 37890.0f, + 38112.0f, 29253.0f, 19950.0f, 10200.0f, 19419.0f, 39723.0f, + 60918.0f, 83010.0f, 83502.0f, 64005.0f, 43593.0f, 22260.0f, + 31839.0f, 65046.0f, 99630.0f, 135600.0f, 136410.0f, 104436.0f, + 71049.0f, 36240.0f, 46146.0f, 94170.0f, 144084.0f, 195900.0f, + 197076.0f, 150726.0f, 102438.0f, 52200.0f, 47526.0f, 96990.0f, + 148404.0f, 201780.0f, 202956.0f, 155226.0f, 105498.0f, 53760.0f, + 38805.0f, 79104.0f, 120906.0f, 164220.0f, 165174.0f, 126198.0f, + 85683.0f, 43620.0f, 28017.0f, 57057.0f, 87126.0f, 118230.0f, + 118914.0f, 90771.0f, 61575.0f, 31320.0f, 15102.0f, 30729.0f, + 46884.0f, 63570.0f, 63936.0f, 48765.0f, 33054.0f, 16800.0f, + 17220.0f, 34863.0f, 52932.0f, 71430.0f, 72228.0f, 54831.0f, + 36996.0f, 18720.0f, 36327.0f, 73527.0f, 111606.0f, 150570.0f, + 152214.0f, 115521.0f, 77925.0f, 39420.0f, 57381.0f, 116112.0f, + 176202.0f, 237660.0f, 240198.0f, 182250.0f, 122907.0f, 62160.0f, + 80442.0f, 162738.0f, 246900.0f, 332940.0f, 336420.0f, 255198.0f, + 172062.0f, 87000.0f, 84702.0f, 171318.0f, 259860.0f, 350340.0f, + 353820.0f, 268338.0f, 180882.0f, 91440.0f, 66867.0f, 135210.0f, + 205038.0f, 276360.0f, 279042.0f, 211572.0f, 142581.0f, 72060.0f, + 46845.0f, 94701.0f, 143574.0f, 193470.0f, 195306.0f, 148047.0f, + 99747.0f, 50400.0f, 24576.0f, 49671.0f, 75288.0f, 101430.0f, + 102372.0f, 77583.0f, 52260.0f, 26400.0f, 22095.0f, 44688.0f, + 67782.0f, 91380.0f, 92178.0f, 69906.0f, 47121.0f, 23820.0f, + 46377.0f, 93777.0f, 142206.0f, 191670.0f, 193314.0f, 146571.0f, + 98775.0f, 49920.0f, 72906.0f, 147387.0f, 223452.0f, 301110.0f, + 303648.0f, 230175.0f, 155082.0f, 78360.0f, 101742.0f, 205638.0f, + 311700.0f, 419940.0f, 423420.0f, 320898.0f, 216162.0f, 109200.0f, + 106002.0f, 214218.0f, 324660.0f, 437340.0f, 440820.0f, 334038.0f, + 224982.0f, 113640.0f, 83292.0f, 168285.0f, 254988.0f, 343410.0f, + 346092.0f, 262197.0f, 176556.0f, 89160.0f, 58095.0f, 117351.0f, + 177774.0f, 239370.0f, 241206.0f, 182697.0f, 122997.0f, 62100.0f, + 30351.0f, 61296.0f, 92838.0f, 124980.0f, 125922.0f, 95358.0f, + 64185.0f, 32400.0f, 26970.0f, 54513.0f, 82632.0f, 111330.0f, + 112128.0f, 84981.0f, 57246.0f, 28920.0f, 56427.0f, 114027.0f, + 172806.0f, 232770.0f, 234414.0f, 177621.0f, 119625.0f, 60420.0f, + 88431.0f, 178662.0f, 270702.0f, 364560.0f, 367098.0f, 278100.0f, + 187257.0f, 94560.0f, 123042.0f, 248538.0f, 376500.0f, 506940.0f, + 510420.0f, 386598.0f, 260262.0f, 131400.0f, 127302.0f, 257118.0f, + 389460.0f, 524340.0f, 527820.0f, 399738.0f, 269082.0f, 135840.0f, + 99717.0f, 201360.0f, 304938.0f, 410460.0f, 413142.0f, 312822.0f, + 210531.0f, 106260.0f, 69345.0f, 140001.0f, 211974.0f, 285270.0f, + 287106.0f, 217347.0f, 146247.0f, 73800.0f, 36126.0f, 72921.0f, + 110388.0f, 148530.0f, 149472.0f, 113133.0f, 76110.0f, 38400.0f}); + + input.linspace(1); + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH, kW, sH, sW, pH, pW, dH, + dW, paddingMode, dataFormat}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(ConvolutionTests1, deconv2d_test7) { - - NDArray exp('c', {3, 2, 4, 4}, {218., 227., 236., 245., 254., 263., 272., 281., 290., 299., 308., 317., 326., 335., 344., 353., 270., 282., 294., 306., 318., 330., 342., 354., 366., 378., 390., 402., 414., 426., 438., 450., 650., 659., 668., 677., 686., 695., 704., 713., 722., 731., 740., 749., 758., 767., 776., 785., 846., 858., 870., 882., 894., 906., 918., 930., 942., 954., 966., 978., 990., 1002., 1014., 1026., 1082., 1091., 1100., 1109., 1118., 1127., 1136., 1145., 1154., 1163., 1172., 1181., 1190., 1199., 1208., 1217., 1422., 1434., 1446., 1458., 1470., 1482., 1494., 1506., 1518., 1530., 1542., 1554., 1566., 1578., 1590., 1602.}); - - auto input = NDArrayFactory::create('c', {3, 3, 4, 4}); - auto weights = NDArrayFactory::create('c',{1, 1, 2, 3}, {1,3,5,2,4,6}); - auto bias = NDArrayFactory::create('c', {2}); - - input.linspace(1); - bias.linspace(1); - - sd::ops::deconv2d op; - - auto result = op.evaluate({&input, &weights, &bias}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + NDArray exp( + 'c', {3, 2, 4, 4}, + {218., 227., 236., 245., 254., 263., 272., 281., 290., 299., + 308., 317., 326., 335., 344., 353., 270., 282., 294., 306., + 318., 330., 342., 354., 366., 378., 390., 402., 414., 426., + 438., 450., 650., 659., 668., 677., 686., 695., 704., 713., + 722., 731., 740., 749., 758., 767., 776., 785., 846., 858., + 870., 882., 894., 906., 918., 930., 942., 954., 966., 978., + 990., 1002., 1014., 1026., 1082., 1091., 1100., 1109., 1118., 1127., + 1136., 1145., 1154., 1163., 1172., 1181., 1190., 1199., 1208., 1217., + 1422., 1434., 1446., 1458., 1470., 1482., 1494., 1506., 1518., 1530., + 1542., 1554., 1566., 1578., 1590., 1602.}); + + auto input = NDArrayFactory::create('c', {3, 3, 4, 4}); + auto weights = + NDArrayFactory::create('c', {1, 1, 2, 3}, {1, 3, 5, 2, 4, 6}); + auto bias = NDArrayFactory::create('c', {2}); + + input.linspace(1); + bias.linspace(1); + + sd::ops::deconv2d op; + + auto result = + op.evaluate({&input, &weights, &bias}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test8) { - - int bS=1, iH=7,iW=7, iC=3,oC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=7,oW=7; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, - 0.798564, 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, 0.328703, - 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, 0.318416, 0.068546, 0.284533, - 0.232720, 0.352142, 0.058909, 0.711221, 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, - 0.569819, 0.445863, 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, - 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, 0.737939, 0.490079, - 0.608414, 0.956500, 0.390098, 0.147305, 0.850645, 0.497650, 0.071866, 0.082150, 0.035314, 0.732041, 0.369934, 0.840666, 0.273894, 0.431796, 0.133231, 0.192975, - 0.246897, 0.386418, 0.511541, 0.199036, 0.141631, 0.697699, 0.253631, 0.782218, 0.930099, 0.335512, 0.558808, 0.664358, 0.018851, 0.637559, 0.290430, 0.434902, - 0.842513, 0.466098, 0.381395, 0.523185, 0.990183, 0.925768, 0.643459, 0.016828, 0.918756, 0.228979, 0.006314, 0.665975, 0.190361, 0.595521, 0.698881, 0.221469, - 0.912434, 0.870822, 0.727369, 0.523972, 0.662884, 0.218841}); - - NDArray weights('c', {kH, kW, oC, iC}, {0.4195024073123932, 0.22738978266716003, 0.10093523561954498, 0.25008103251457214, 0.3183899223804474, 0.5976081490516663}); - NDArray bias('c', {1, oC}, {0.3596062958240509, 0.6866418123245239}); - - NDArray exp('c', {bS, oC, oH, oW}, {0.848190, 0.560603, 0.880509, 0.464103, 0.823376, 0.660138, 0.666382, 0.882257, 0.704650, 0.451427, 0.649734, 0.911822, 0.611581, - 0.847623, 0.568191, 0.439341, 0.710854, 0.473843, 0.927273, 0.605861, 0.724540, 0.530591, 0.804268, 0.478136, 0.602198, 0.639553, 0.669082, 0.855013, 0.678572, - 0.617800, 0.667545, 0.765899, 0.835564, 0.631733, 0.921562, 0.790830, 0.588187, 0.597934, 0.725855, 0.822259, 0.455384, 0.998167, 0.683336, 0.591897, 0.705213, - 0.748148, 0.648922, 0.484723, 0.873482, 1.368675, 0.881096, 1.169214, 0.781504, 1.433406, 1.171439, 1.348675, 1.227033, 1.256600, 0.824772, 1.051633, 1.308692, - 1.148711, 1.334007, 1.014448, 0.813336, 1.408801, 0.916766, 1.583323, 1.362920, 1.226212, 1.149715, 1.330235, 0.770671, 1.285158, 1.105632, 1.272558, 1.590159, - 1.235054, 1.201363, 1.222816, 1.623673, 1.590317, 1.322463, 1.206481, 1.466262, 0.974741, 0.922343, 1.367100, 1.087943, 1.084952, 1.586691, 1.133576, 1.405098, - 1.471922, 1.484062, 1.212039, 1.144419, 1.266123}); - - sd::ops::deconv2d op; - auto results = op.evaluate({&input, &weights, &bias}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 1, iH = 7, iW = 7, iC = 3, oC = 2, kH = 1, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 7, oW = 7; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input( + 'c', {bS, iC, iH, iW}, + {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, + 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, 0.798564, + 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, + 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, + 0.328703, 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, + 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, + 0.318416, 0.068546, 0.284533, 0.232720, 0.352142, 0.058909, 0.711221, + 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, + 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, 0.569819, 0.445863, + 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, + 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, + 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, + 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, + 0.737939, 0.490079, 0.608414, 0.956500, 0.390098, 0.147305, 0.850645, + 0.497650, 0.071866, 0.082150, 0.035314, 0.732041, 0.369934, 0.840666, + 0.273894, 0.431796, 0.133231, 0.192975, 0.246897, 0.386418, 0.511541, + 0.199036, 0.141631, 0.697699, 0.253631, 0.782218, 0.930099, 0.335512, + 0.558808, 0.664358, 0.018851, 0.637559, 0.290430, 0.434902, 0.842513, + 0.466098, 0.381395, 0.523185, 0.990183, 0.925768, 0.643459, 0.016828, + 0.918756, 0.228979, 0.006314, 0.665975, 0.190361, 0.595521, 0.698881, + 0.221469, 0.912434, 0.870822, 0.727369, 0.523972, 0.662884, 0.218841}); + + NDArray weights( + 'c', {kH, kW, oC, iC}, + {0.4195024073123932, 0.22738978266716003, 0.10093523561954498, + 0.25008103251457214, 0.3183899223804474, 0.5976081490516663}); + NDArray bias('c', {1, oC}, {0.3596062958240509, 0.6866418123245239}); + + NDArray exp( + 'c', {bS, oC, oH, oW}, + {0.848190, 0.560603, 0.880509, 0.464103, 0.823376, 0.660138, 0.666382, + 0.882257, 0.704650, 0.451427, 0.649734, 0.911822, 0.611581, 0.847623, + 0.568191, 0.439341, 0.710854, 0.473843, 0.927273, 0.605861, 0.724540, + 0.530591, 0.804268, 0.478136, 0.602198, 0.639553, 0.669082, 0.855013, + 0.678572, 0.617800, 0.667545, 0.765899, 0.835564, 0.631733, 0.921562, + 0.790830, 0.588187, 0.597934, 0.725855, 0.822259, 0.455384, 0.998167, + 0.683336, 0.591897, 0.705213, 0.748148, 0.648922, 0.484723, 0.873482, + 1.368675, 0.881096, 1.169214, 0.781504, 1.433406, 1.171439, 1.348675, + 1.227033, 1.256600, 0.824772, 1.051633, 1.308692, 1.148711, 1.334007, + 1.014448, 0.813336, 1.408801, 0.916766, 1.583323, 1.362920, 1.226212, + 1.149715, 1.330235, 0.770671, 1.285158, 1.105632, 1.272558, 1.590159, + 1.235054, 1.201363, 1.222816, 1.623673, 1.590317, 1.322463, 1.206481, + 1.466262, 0.974741, 0.922343, 1.367100, 1.087943, 1.084952, 1.586691, + 1.133576, 1.405098, 1.471922, 1.484062, 1.212039, 1.144419, 1.266123}); + + sd::ops::deconv2d op; + auto results = + op.evaluate({&input, &weights, &bias}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test9) { - - int bS=2, oH=4,oW=4, oC=5,iC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int iH=3,iW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0-[kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {iC, oC, kH, kW}, {100.000000, 75.000000, 50.000000, 25.000000, 95.000000, 70.000000, 45.000000, 20.000000, 90.000000, 65.000000, 40.000000, - 15.000000, 85.000000, 60.000000, 35.000000, 10.000000, 80.000000, 55.000000, 30.000000, 5.000000, 99.500000, 74.500000, 49.500000, 24.500000, 94.500000, 69.500000, - 44.500000, 19.500000, 89.500000, 64.500000, 39.500000, 14.500000, 84.500000, 59.500000, 34.500000, 9.500000, 79.500000, 54.500000, 29.500000, 4.500000, 99.000000, - 74.000000, 49.000000, 24.000000, 94.000000, 69.000000, 44.000000, 19.000000, 89.000000, 64.000000, 39.000000, 14.000000, 84.000000, 59.000000, 34.000000, 9.000000, - 79.000000, 54.000000, 29.000000, 4.000000, 98.500000, 73.500000, 48.500000, 23.500000, 93.500000, 68.500000, 43.500000, 18.500000, 88.500000, 63.500000, 38.500000, - 13.500000, 83.500000, 58.500000, 33.500000, 8.500000, 78.500000, 53.500000, 28.500000, 3.500000, 98.000000, 73.000000, 48.000000, 23.000000, 93.000000, 68.000000, - 43.000000, 18.000000, 88.000000, 63.000000, 38.000000, 13.000000, 83.000000, 58.000000, 33.000000, 8.000000, 78.000000, 53.000000, 28.000000, 3.000000, 97.500000, 72.500000, 47.500000, 22.500000, 92.500000, 67.500000, 42.500000, 17.500000, 87.500000, 62.500000, 37.500000, 12.500000, 82.500000, 57.500000, 32.500000, 7.500000, 77.500000, 52.500000, 27.500000, 2.500000, 97.000000, 72.000000, 47.000000, 22.000000, 92.000000, 67.000000, 42.000000, 17.000000, 87.000000, 62.000000, 37.000000, 12.000000, 82.000000, 57.000000, 32.000000, 7.000000, 77.000000, 52.000000, 27.000000, 2.000000, 96.500000, 71.500000, 46.500000, 21.500000, 91.500000, 66.500000, 41.500000, 16.500000, 86.500000, 61.500000, 36.500000, 11.500000, 81.500000, 56.500000, 31.500000, 6.500000, 76.500000, 51.500000, 26.500000, 1.500000, 96.000000, 71.000000, 46.000000, 21.000000, 91.000000, 66.000000, 41.000000, 16.000000, 86.000000, 61.000000, 36.000000, 11.000000, 81.000000, 56.000000, 31.000000, 6.000000, 76.000000, 51.000000, 26.000000, 1.000000, 95.500000, 70.500000, 45.500000, 20.500000, 90.500000, 65.500000, 40.500000, 15.500000, 85.500000, 60.500000, 35.500000, 10.500000, 80.500000, 55.500000, 30.500000, 5.500000, 75.500000, 50.500000, 25.500000, 0.500000}, sd::DataType::FLOAT32); - NDArray expOutput('c', {bS, oH, oW, oC}, {-30844.250000, -29266.750000, -27689.250000, -26111.750000, -24534.250000, -52823.500000, -49718.500000, -46613.500000, -43508.500000, -40403.500000, -51118.500000, - -48113.500000, -45108.500000, -42103.500000, -39098.500000, -21501.750000, -20024.250000, -18546.750000, -17069.250000, -15591.750000, -42981.000000, -39976.000000, -36971.000000, -33966.000000, -30961.000000, - -69482.000000, -63572.000000, -57662.000000, -51752.000000, -45842.000000, -67072.000000, -61362.000000, -55652.000000, -49942.000000, -44232.000000, -26046.000000, -23241.000000, -20436.000000, -17631.000000, - -14826.000000, -38616.000000, -35911.000000, -33206.000000, -30501.000000, -27796.000000, -62252.000000, -56942.000000, -51632.000000, -46322.000000, -41012.000000, -59842.000000, -54732.000000, -49622.000000, - -44512.000000, -39402.000000, -23181.000000, -20676.000000, -18171.000000, -15666.000000, -13161.000000, -12204.250000, -10926.750000, -9649.250000, -8371.750000, -7094.250000, -17543.500000, -15038.500000, - -12533.500000, -10028.500000, -7523.500000, -16838.500000, -14433.499023, -12028.500000, -9623.500000, -7218.500000, -5361.750000, -4184.250000, -3006.750000, -1829.250000, -651.750000, -22046.750000, -20919.250000, - -19791.750000, -18664.250000, -17536.750000, -37478.500000, -35273.500000, -33068.500000, -30863.500000, -28658.500000, -35773.500000, -33668.500000, -31563.500000, -29458.500000, -27353.500000, -14954.250000, - -13926.750000, -12899.250000, -11871.750000, -10844.250000, -29886.000000, -27781.000000, -25676.000000, -23571.000000, -21466.000000, -47792.000000, -43682.000000, -39572.000000, -35462.000000, -31352.000000, - -45382.000000, -41472.000000, -37562.000000, -33652.000000, -29742.000000, -17451.000000, -15546.000000, -13641.000000, -11736.000000, -9831.000000, -25521.000000, -23716.000000, -21911.000000, -20106.000000, -18301.000000, -40562.000000, -37052.000000, -33542.000000, -30032.000000, -26522.000000, -38152.000000, -34842.000000, -31532.000000, -28222.000000, -24912.000000, -14586.000000, -12981.000000, -11376.000000, -9771.000000, -8166.000000, -7906.750000, -7079.250000, -6251.750000, -5424.250000, -4596.750000, -11198.500000, -9593.500000, -7988.500000, -6383.500000, -4778.500000, -10493.500000, -8988.500000, -7483.500000, -5978.500000, -4473.500000, -3314.250000, -2586.750000, -1859.250000, -1131.750000, -404.250000}, sd::DataType::FLOAT32); - - input.linspace(-32, 0.1); - - sd::ops::deconv2d op; - auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, oH = 4, oW = 4, oC = 5, iC = 10, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int iH = 3, iW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {iC, oC, kH, kW}, + {100.000000, 75.000000, 50.000000, 25.000000, 95.000000, 70.000000, + 45.000000, 20.000000, 90.000000, 65.000000, 40.000000, 15.000000, + 85.000000, 60.000000, 35.000000, 10.000000, 80.000000, 55.000000, + 30.000000, 5.000000, 99.500000, 74.500000, 49.500000, 24.500000, + 94.500000, 69.500000, 44.500000, 19.500000, 89.500000, 64.500000, + 39.500000, 14.500000, 84.500000, 59.500000, 34.500000, 9.500000, + 79.500000, 54.500000, 29.500000, 4.500000, 99.000000, 74.000000, + 49.000000, 24.000000, 94.000000, 69.000000, 44.000000, 19.000000, + 89.000000, 64.000000, 39.000000, 14.000000, 84.000000, 59.000000, + 34.000000, 9.000000, 79.000000, 54.000000, 29.000000, 4.000000, + 98.500000, 73.500000, 48.500000, 23.500000, 93.500000, 68.500000, + 43.500000, 18.500000, 88.500000, 63.500000, 38.500000, 13.500000, + 83.500000, 58.500000, 33.500000, 8.500000, 78.500000, 53.500000, + 28.500000, 3.500000, 98.000000, 73.000000, 48.000000, 23.000000, + 93.000000, 68.000000, 43.000000, 18.000000, 88.000000, 63.000000, + 38.000000, 13.000000, 83.000000, 58.000000, 33.000000, 8.000000, + 78.000000, 53.000000, 28.000000, 3.000000, 97.500000, 72.500000, + 47.500000, 22.500000, 92.500000, 67.500000, 42.500000, 17.500000, + 87.500000, 62.500000, 37.500000, 12.500000, 82.500000, 57.500000, + 32.500000, 7.500000, 77.500000, 52.500000, 27.500000, 2.500000, + 97.000000, 72.000000, 47.000000, 22.000000, 92.000000, 67.000000, + 42.000000, 17.000000, 87.000000, 62.000000, 37.000000, 12.000000, + 82.000000, 57.000000, 32.000000, 7.000000, 77.000000, 52.000000, + 27.000000, 2.000000, 96.500000, 71.500000, 46.500000, 21.500000, + 91.500000, 66.500000, 41.500000, 16.500000, 86.500000, 61.500000, + 36.500000, 11.500000, 81.500000, 56.500000, 31.500000, 6.500000, + 76.500000, 51.500000, 26.500000, 1.500000, 96.000000, 71.000000, + 46.000000, 21.000000, 91.000000, 66.000000, 41.000000, 16.000000, + 86.000000, 61.000000, 36.000000, 11.000000, 81.000000, 56.000000, + 31.000000, 6.000000, 76.000000, 51.000000, 26.000000, 1.000000, + 95.500000, 70.500000, 45.500000, 20.500000, 90.500000, 65.500000, + 40.500000, 15.500000, 85.500000, 60.500000, 35.500000, 10.500000, + 80.500000, 55.500000, 30.500000, 5.500000, 75.500000, 50.500000, + 25.500000, 0.500000}, + sd::DataType::FLOAT32); + NDArray expOutput('c', {bS, oH, oW, oC}, + {-30844.250000, -29266.750000, -27689.250000, -26111.750000, + -24534.250000, -52823.500000, -49718.500000, -46613.500000, + -43508.500000, -40403.500000, -51118.500000, -48113.500000, + -45108.500000, -42103.500000, -39098.500000, -21501.750000, + -20024.250000, -18546.750000, -17069.250000, -15591.750000, + -42981.000000, -39976.000000, -36971.000000, -33966.000000, + -30961.000000, -69482.000000, -63572.000000, -57662.000000, + -51752.000000, -45842.000000, -67072.000000, -61362.000000, + -55652.000000, -49942.000000, -44232.000000, -26046.000000, + -23241.000000, -20436.000000, -17631.000000, -14826.000000, + -38616.000000, -35911.000000, -33206.000000, -30501.000000, + -27796.000000, -62252.000000, -56942.000000, -51632.000000, + -46322.000000, -41012.000000, -59842.000000, -54732.000000, + -49622.000000, -44512.000000, -39402.000000, -23181.000000, + -20676.000000, -18171.000000, -15666.000000, -13161.000000, + -12204.250000, -10926.750000, -9649.250000, -8371.750000, + -7094.250000, -17543.500000, -15038.500000, -12533.500000, + -10028.500000, -7523.500000, -16838.500000, -14433.499023, + -12028.500000, -9623.500000, -7218.500000, -5361.750000, + -4184.250000, -3006.750000, -1829.250000, -651.750000, + -22046.750000, -20919.250000, -19791.750000, -18664.250000, + -17536.750000, -37478.500000, -35273.500000, -33068.500000, + -30863.500000, -28658.500000, -35773.500000, -33668.500000, + -31563.500000, -29458.500000, -27353.500000, -14954.250000, + -13926.750000, -12899.250000, -11871.750000, -10844.250000, + -29886.000000, -27781.000000, -25676.000000, -23571.000000, + -21466.000000, -47792.000000, -43682.000000, -39572.000000, + -35462.000000, -31352.000000, -45382.000000, -41472.000000, + -37562.000000, -33652.000000, -29742.000000, -17451.000000, + -15546.000000, -13641.000000, -11736.000000, -9831.000000, + -25521.000000, -23716.000000, -21911.000000, -20106.000000, + -18301.000000, -40562.000000, -37052.000000, -33542.000000, + -30032.000000, -26522.000000, -38152.000000, -34842.000000, + -31532.000000, -28222.000000, -24912.000000, -14586.000000, + -12981.000000, -11376.000000, -9771.000000, -8166.000000, + -7906.750000, -7079.250000, -6251.750000, -5424.250000, + -4596.750000, -11198.500000, -9593.500000, -7988.500000, + -6383.500000, -4778.500000, -10493.500000, -8988.500000, + -7483.500000, -5978.500000, -4473.500000, -3314.250000, + -2586.750000, -1859.250000, -1131.750000, -404.250000}, + sd::DataType::FLOAT32); + + input.linspace(-32, 0.1); + + sd::ops::deconv2d op; + auto results = op.evaluate( + {&input, &weights}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, deconv2d_test10) { - - int bS=2, oH=4,oW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int iH=4,iW=4; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0-[kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {iC, kH, kW, oC}, {100., 95., 90., 85., 80., 75., 70., 65., 60., 55., 50., 45., 40., 35., 30., 25., 20., 15., 10., 5., 0., -5., -10., -15., - -20., -25., -30., -35., -40., -45., -50., -55., -60., -65., -70., -75., -80., -85., -90., -95., 99., 94., 89., 84., 79., 74., 69., 64., 59., 54., 49., 44., - 39., 34., 29., 24., 19., 14., 9., 4., -1., -6., -11., -16., -21., -26., -31., -36., -41., -46., -51., -56., -61., -66., -71., -76., -81., -86., -91., -96., - 98., 93., 88., 83., 78., 73., 68., 63., 58., 53., 48., 43., 38., 33., 28., 23., 18., 13., 8., 3., -2., -7., -12., -17., -22., -27., -32., -37., -42., -47., - -52., -57., -62., -67., -72., -77., -82., -87., -92., -97., 97., 92., 87., 82., 77., 72., 67., 62., 57., 52., 47., 42., 37., 32., 27., 22., 17., 12., 7., 2., - -3., -8., -13., -18., -23., -28., -33., -38., -43., -48., -53., -58., -63., -68., -73., -78., -83., -88., -93., -98., 96., 91., 86., 81., 76., 71., 66., 61., - 56., 51., 46., 41., 36., 31., 26., 21., 16., 11., 6., 1., -4., -9., -14., -19., -24., -29., -34., -39., -44., -49., -54., -59., -64., -69., -74., -79., -84., -89., -94., -99.}, sd::DataType::FLOAT32); - NDArray expOutput('c', {bS, oC, oH, oW}, {-14128., -21007., -20934., -20861., -13660., -12972., -12926.000977, -12880., -13468., -12788., -12742., -12696.000977, - -13276., -12604., -12558., -12512., -13408., -19569.5, -19501.5, -19433.5, -12230., -10117., -10081.000977, -10045., -12058., -9973., -9937., -9901.000977, - -11886., -9829., -9793., -9757., -12688., -18132., -18069., -18006., -10800., -7262., -7236., -7210., -10648., -7157.999512, -7132., -7106., -10496., -7054., - -7027.999512, -7002., -11968., -16694.5, -16636.5, -16578.5, -9370., -4406.999023, -4391., -4375., -9238., -4343., -4326.999023, -4311., -9106., -4279., -4263., - -4246.999023, -11247.999023, -15257., -15204., -15151., -7940., -1551.999023, -1546., -1540., -7828., -1528.000977, -1521.999023, -1516., -7716., -1504., - -1498.000977, -1491.999023, -10527.999023, -13819.5, -13771.5, -13723.5, -6510., 1303.000977, 1299., 1295., -6418., 1286.999023, 1283.000977, 1279., -6326., - 1271., 1266.999023, 1263.000977, -9807.999023, -12382., -12339., -12296., -5080., 4158.000977, 4144., 4130., -5008., 4101.999023, 4088., 4074., -4936., 4046., 4031.999023, 4018., -9088., -10944.5, -10906.5, -10868.5, -3650., 7013., 6989., 6965., -3598., 6917., 6893., 6869., -3546., 6821., 6797., 6773., -8368., -9507., -9474., -9441., -2220., 9868., 9834., 9800., -2187.999512, 9732., 9698., 9664., -2156., 9596., 9562., 9528., -7648., -8069.5, -8041.5, -8013.499512, -790.000488, 12723., 12679., 12635., -777.999512, 12547., 12503., 12459., -766., 12371., 12327., 12283., -10208., -15167., -15094., -15021., -9820., -9292., -9246., -9200., -9628., -9108., -9062., -9016., -9436., -8924., -8878., -8832., -9687.999023, -14129.5, -14061.5, -13993.5, -8790., -7236.999023, -7201., -7164.999512, -8618., -7093., -7057., -7021., -8446., -6949., -6913., -6877., -9168., -13092., -13029., -12966., -7760., -5182., -5156., -5129.999512, -7608., -5078., -5052., -5026., -7456., -4974., -4948., -4922., -8648., -12054.5, -11996.5, -11938.5, -6730., -3127., -3111., -3095., -6598., -3063., -3047., -3031., -6465.999512, -2999., -2983.000488, -2967., -8128., -11017., -10964., -10911., -5700.000488, -1072., -1066., -1060., -5587.999512, -1048.000488, -1042., -1036., -5476., -1023.999512, -1018.000488, -1012., -7608., -9979.5, -9931.5, -9883.5, -4670.000488, 983., 979., 975., -4577.999512, 966.999512, 963., 959., -4486., 951.000488, 946.999512, 943., -7088., -8942., -8899., -8856., -3640.000488, 3038., 3024., 3010., -3567.999512, 2981.999512, 2968., 2954., -3496., 2926.000488, 2911.999512, 2898., -6568., -7904.5, -7866.5, -7828.499512, -2610.000488, 5093., 5069., 5045., -2557.999512, 4996.999512, 4973., 4949., -2506., 4901.000488, 4877., 4853., -6048., -6867., -6834., -6800.999512, -1580., 7148., 7114., 7080., -1547.999512, 7012., 6978., 6944., -1516., 6876.000488, 6842., 6808., -5528., -5829.5, -5801.5, -5773.499512, -550., 9203., 9159., 9115., -537.999512, 9027., 8983., 8939., -526., 8851., 8807., 8763.}, sd::DataType::FLOAT32); - - input.linspace(-32, 0.1); - - sd::ops::deconv2d op; - auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, oH = 4, oW = 4, iC = 5, oC = 10, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int iH = 4, iW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {iC, kH, kW, oC}, + {100., 95., 90., 85., 80., 75., 70., 65., 60., 55., 50., 45., + 40., 35., 30., 25., 20., 15., 10., 5., 0., -5., -10., -15., + -20., -25., -30., -35., -40., -45., -50., -55., -60., -65., -70., -75., + -80., -85., -90., -95., 99., 94., 89., 84., 79., 74., 69., 64., + 59., 54., 49., 44., 39., 34., 29., 24., 19., 14., 9., 4., + -1., -6., -11., -16., -21., -26., -31., -36., -41., -46., -51., -56., + -61., -66., -71., -76., -81., -86., -91., -96., 98., 93., 88., 83., + 78., 73., 68., 63., 58., 53., 48., 43., 38., 33., 28., 23., + 18., 13., 8., 3., -2., -7., -12., -17., -22., -27., -32., -37., + -42., -47., -52., -57., -62., -67., -72., -77., -82., -87., -92., -97., + 97., 92., 87., 82., 77., 72., 67., 62., 57., 52., 47., 42., + 37., 32., 27., 22., 17., 12., 7., 2., -3., -8., -13., -18., + -23., -28., -33., -38., -43., -48., -53., -58., -63., -68., -73., -78., + -83., -88., -93., -98., 96., 91., 86., 81., 76., 71., 66., 61., + 56., 51., 46., 41., 36., 31., 26., 21., 16., 11., 6., 1., + -4., -9., -14., -19., -24., -29., -34., -39., -44., -49., -54., -59., + -64., -69., -74., -79., -84., -89., -94., -99.}, + sd::DataType::FLOAT32); + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {-14128., -21007., -20934., -20861., -13660., + -12972., -12926.000977, -12880., -13468., -12788., + -12742., -12696.000977, -13276., -12604., -12558., + -12512., -13408., -19569.5, -19501.5, -19433.5, + -12230., -10117., -10081.000977, -10045., -12058., + -9973., -9937., -9901.000977, -11886., -9829., + -9793., -9757., -12688., -18132., -18069., + -18006., -10800., -7262., -7236., -7210., + -10648., -7157.999512, -7132., -7106., -10496., + -7054., -7027.999512, -7002., -11968., -16694.5, + -16636.5, -16578.5, -9370., -4406.999023, -4391., + -4375., -9238., -4343., -4326.999023, -4311., + -9106., -4279., -4263., -4246.999023, -11247.999023, + -15257., -15204., -15151., -7940., -1551.999023, + -1546., -1540., -7828., -1528.000977, -1521.999023, + -1516., -7716., -1504., -1498.000977, -1491.999023, + -10527.999023, -13819.5, -13771.5, -13723.5, -6510., + 1303.000977, 1299., 1295., -6418., 1286.999023, + 1283.000977, 1279., -6326., 1271., 1266.999023, + 1263.000977, -9807.999023, -12382., -12339., -12296., + -5080., 4158.000977, 4144., 4130., -5008., + 4101.999023, 4088., 4074., -4936., 4046., + 4031.999023, 4018., -9088., -10944.5, -10906.5, + -10868.5, -3650., 7013., 6989., 6965., + -3598., 6917., 6893., 6869., -3546., + 6821., 6797., 6773., -8368., -9507., + -9474., -9441., -2220., 9868., 9834., + 9800., -2187.999512, 9732., 9698., 9664., + -2156., 9596., 9562., 9528., -7648., + -8069.5, -8041.5, -8013.499512, -790.000488, 12723., + 12679., 12635., -777.999512, 12547., 12503., + 12459., -766., 12371., 12327., 12283., + -10208., -15167., -15094., -15021., -9820., + -9292., -9246., -9200., -9628., -9108., + -9062., -9016., -9436., -8924., -8878., + -8832., -9687.999023, -14129.5, -14061.5, -13993.5, + -8790., -7236.999023, -7201., -7164.999512, -8618., + -7093., -7057., -7021., -8446., -6949., + -6913., -6877., -9168., -13092., -13029., + -12966., -7760., -5182., -5156., -5129.999512, + -7608., -5078., -5052., -5026., -7456., + -4974., -4948., -4922., -8648., -12054.5, + -11996.5, -11938.5, -6730., -3127., -3111., + -3095., -6598., -3063., -3047., -3031., + -6465.999512, -2999., -2983.000488, -2967., -8128., + -11017., -10964., -10911., -5700.000488, -1072., + -1066., -1060., -5587.999512, -1048.000488, -1042., + -1036., -5476., -1023.999512, -1018.000488, -1012., + -7608., -9979.5, -9931.5, -9883.5, -4670.000488, + 983., 979., 975., -4577.999512, 966.999512, + 963., 959., -4486., 951.000488, 946.999512, + 943., -7088., -8942., -8899., -8856., + -3640.000488, 3038., 3024., 3010., -3567.999512, + 2981.999512, 2968., 2954., -3496., 2926.000488, + 2911.999512, 2898., -6568., -7904.5, -7866.5, + -7828.499512, -2610.000488, 5093., 5069., 5045., + -2557.999512, 4996.999512, 4973., 4949., -2506., + 4901.000488, 4877., 4853., -6048., -6867., + -6834., -6800.999512, -1580., 7148., 7114., + 7080., -1547.999512, 7012., 6978., 6944., + -1516., 6876.000488, 6842., 6808., -5528., + -5829.5, -5801.5, -5773.499512, -550., 9203., + 9159., 9115., -537.999512, 9027., 8983., + 8939., -526., 8851., 8807., 8763.}, + sd::DataType::FLOAT32); + + input.linspace(-32, 0.1); + + sd::ops::deconv2d op; + auto results = op.evaluate( + {&input, &weights}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) { - - int bS=2, iH=4,iW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto outShape = NDArrayFactory::create('c', {4}, {static_cast(bS), static_cast(iH), static_cast(iW), static_cast(iC)}); - auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); - input = 0.5; - weights.linspace(0.1, 0.1); - - sd::ops::deconv2d_tf op; - auto results = op.evaluate({&outShape, &weights, &input}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 2, iH = 4, iW = 4, iC = 5, oC = 10, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto outShape = NDArrayFactory::create( + 'c', {4}, + {static_cast(bS), static_cast(iH), + static_cast(iW), static_cast(iC)}); + auto exp = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, + 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, + 42.75f, 47.75f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, + 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, + 115.5f, 125.5f, 135.5f, 145.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, + 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, + 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, 52.75f, 57.75f, 62.75f, + 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, + 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, 2.75f, + 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, + 47.75f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, + 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, + 125.5f, 135.5f, 145.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, + 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, 52.75f, 57.75f, 62.75f, 67.75f, + 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, + 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv2d_tf op; + auto results = + op.evaluate({&outShape, &weights, &input}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } - -#endif //LIBND4J_CONVOLUTIONTESTS1_H - +#endif // LIBND4J_CONVOLUTIONTESTS1_H diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index ceb9e68e8958..87fe7e16d003 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -22,77 +22,84 @@ #ifndef LIBND4J_CONVOLUTIONTESTS2_H #define LIBND4J_CONVOLUTIONTESTS2_H -#include "testlayers.h" #include #include #include #include #include +#include +#include #include -#include #include -#include -#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class ConvolutionTests2 : public testing::Test { -public: - - const int bS = 2; // batch size - const int iD = 1; // input depth (number of picture channels, for example rgb=3) - const int iH = 28; // picture height in pixels - const int iW = 28; // picture width in pixels - const int oD = 3; // output depth (= N for dense layer) - const int kH = 5; // kernel height in pixels - const int kW = 5; // kernel width in pixels - const int sH = 1; // stride step in horizontal direction - const int sW = 1; // stride step in vertical direction - const int pH = 0; // padding height - const int pW = 0; // padding width - const int dH = 2; // dilation height - const int dW = 2; // dilation width - const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height - const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width - + public: + const int bS = 2; // batch size + const int iD = + 1; // input depth (number of picture channels, for example rgb=3) + const int iH = 28; // picture height in pixels + const int iW = 28; // picture width in pixels + const int oD = 3; // output depth (= N for dense layer) + const int kH = 5; // kernel height in pixels + const int kW = 5; // kernel width in pixels + const int sH = 1; // stride step in horizontal direction + const int sW = 1; // stride step in vertical direction + const int pH = 0; // padding height + const int pW = 0; // padding width + const int dH = 2; // dilation height + const int dW = 2; // dilation width + const int oH = + (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = + (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width }; ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, im2col_1) { - - int bS=2, iH=4,iW=3, iC=4, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; // VALID - int oW = (iW - (kW + (kW-1)*(dW-1)) + 2*pW)/sW + 1; // VALID - - int paddingMode = 0; // 1-SAME, 0-VALID; - - NDArray image('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE); - NDArray expected('c', {bS, iC, kH, kW, oH, oW}, {1, 2, 4, 5, 2, 3, 5, 6, 4, 5, 7, 8, 5, 6, 8, 9, 7, 8, 10, 11, 8, 9, 11, 12, 13, 14, 16, 17, 14, - 15, 17, 18, 16, 17, 19, 20, 17, 18, 20, 21, 19, 20, 22, 23, 20, 21, 23, 24, 25, 26, 28, 29, 26, 27, 29, 30, - 28, 29, 31, 32, 29, 30, 32, 33, 31, 32, 34, 35, 32, 33, 35, 36, 37, 38, 40, 41, 38, 39, 41, 42, 40, 41, 43, - 44, 41, 42, 44, 45, 43, 44, 46, 47, 44, 45, 47, 48, 49, 50, 52, 53, 50, 51, 53, 54, 52, 53, 55, 56, 53, 54, - 56, 57, 55, 56, 58, 59, 56, 57, 59, 60, 61, 62, 64, 65, 62, 63, 65, 66, 64, 65, 67, 68, 65, 66, 68, 69, 67, - 68, 70, 71, 68, 69, 71, 72, 73, 74, 76, 77, 74, 75, 77, 78, 76, 77, 79, 80, 77, 78, 80, 81, 79, 80, 82, 83, - 80, 81, 83, 84, 85, 86, 88, 89, 86, 87, 89, 90, 88, 89, 91, 92, 89, 90, 92, 93, 91, 92, 94, 95, 92, 93, 95, 96}); - - image.linspace(1, 1); - - sd::ops::im2col op; - auto results = op.evaluate({&image}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode}); - auto column = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(column)); - ASSERT_TRUE(expected.equalsTo(column)); - + int bS = 2, iH = 4, iW = 3, iC = 4, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = (iH - (kH + (kH - 1) * (dH - 1)) + 2 * pH) / sH + 1; // VALID + int oW = (iW - (kW + (kW - 1) * (dW - 1)) + 2 * pW) / sW + 1; // VALID + + int paddingMode = 0; // 1-SAME, 0-VALID; + + NDArray image('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE); + NDArray expected( + 'c', {bS, iC, kH, kW, oH, oW}, + {1, 2, 4, 5, 2, 3, 5, 6, 4, 5, 7, 8, 5, 6, 8, 9, 7, 8, + 10, 11, 8, 9, 11, 12, 13, 14, 16, 17, 14, 15, 17, 18, 16, 17, 19, 20, + 17, 18, 20, 21, 19, 20, 22, 23, 20, 21, 23, 24, 25, 26, 28, 29, 26, 27, + 29, 30, 28, 29, 31, 32, 29, 30, 32, 33, 31, 32, 34, 35, 32, 33, 35, 36, + 37, 38, 40, 41, 38, 39, 41, 42, 40, 41, 43, 44, 41, 42, 44, 45, 43, 44, + 46, 47, 44, 45, 47, 48, 49, 50, 52, 53, 50, 51, 53, 54, 52, 53, 55, 56, + 53, 54, 56, 57, 55, 56, 58, 59, 56, 57, 59, 60, 61, 62, 64, 65, 62, 63, + 65, 66, 64, 65, 67, 68, 65, 66, 68, 69, 67, 68, 70, 71, 68, 69, 71, 72, + 73, 74, 76, 77, 74, 75, 77, 78, 76, 77, 79, 80, 77, 78, 80, 81, 79, 80, + 82, 83, 80, 81, 83, 84, 85, 86, 88, 89, 86, 87, 89, 90, 88, 89, 91, 92, + 89, 90, 92, 93, 91, 92, 94, 95, 92, 93, 95, 96}); + + image.linspace(1, 1); + + sd::ops::im2col op; + auto results = + op.evaluate({&image}, {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode}); + auto column = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(column)); + ASSERT_TRUE(expected.equalsTo(column)); } template class TypedConvolutionTests2 : public testing::Test { -public: - + public: }; typedef ::testing::Types TestingTypes; @@ -100,2738 +107,6567 @@ TYPED_TEST_CASE(TypedConvolutionTests2, TestingTypes); ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, deconv2d_tf_test2) { - - int bS=2, iH=4,iW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=4; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto outShape = NDArrayFactory::create('c', {4}, {static_cast(bS), static_cast(iH), static_cast(iW), static_cast(iC)}); - auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f}); - input = 0.5; - weights.linspace(0.1, 0.1); - - sd::ops::deconv2d_tf op; - auto results = op.evaluate({&outShape, &weights, &input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 2, iH = 4, iW = 4, iC = 5, oC = 10, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto outShape = NDArrayFactory::create( + 'c', {4}, + {static_cast(bS), static_cast(iH), + static_cast(iW), static_cast(iC)}); + auto exp = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, + 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, + 60.5f, 70.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, + 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, + 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, + 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, + 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, + 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, + 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 2.75f, + 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, + 70.5f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, + 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, + 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, + 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 161.f, 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, + 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, + 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv2d_tf op; + auto results = + op.evaluate({&outShape, &weights, &input}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_1) { - auto input0 = NDArrayFactory::create('c', {4}, {12.f, 5.f, 5.f, 32.f}); - auto input1 = NDArrayFactory::create('c', {2, 2, 32, 16}); - auto input2 = NDArrayFactory::create('c', {12, 4, 4, 16}); - auto exp = NDArrayFactory::create('c', {12, 5, 5, 32}); - - sd::ops::deconv2d_tf op; - auto result = op.evaluate({&input0, &input1, &input2}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_EQ(exp, result.at(0)); - + auto input0 = + NDArrayFactory::create('c', {4}, {12.f, 5.f, 5.f, 32.f}); + auto input1 = NDArrayFactory::create('c', {2, 2, 32, 16}); + auto input2 = NDArrayFactory::create('c', {12, 4, 4, 16}); + auto exp = NDArrayFactory::create('c', {12, 5, 5, 32}); + + sd::ops::deconv2d_tf op; + auto result = op.evaluate({&input0, &input1, &input2}, {}, + {2, 2, 1, 1, 0, 0, 1, 1, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(exp, result.at(0)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_2) { - auto input0 = NDArrayFactory::create('c', {4}, {3.f, 8.f, 8.f, 16.f}); - - auto input1 = NDArrayFactory::create('c', {7, 7, 16, 5}, {1.05293429f, -0.89349967f, 0.31027254f, 1.22991478f, -0.62926656f, 0.56918693f, --1.60992694f, 1.10167944f, -0.80843484f, 0.07521993f, -1.15994942f, 0.76016301f, -0.40056285f, -1.16872537f, -0.91384381f, -0.36700436f, 1.82389200f, -1.18200207f, 0.51612782f, -0.92479187f, -0.09307563f, -0.55122334f, 1.23532486f, -1.11124146f, -0.05812126f, 0.68159896f, 0.69125599f, -0.77127314f, -0.10874277f, 0.86469102f, --1.31614351f, 0.33354419f, -1.71750402f, 0.17197680f, -1.03965557f, 1.10570908f, -1.19115615f, 1.05115080f, 0.18277600f, 1.08820546f, -0.72191417f, -0.10999311f, 1.56521320f, -0.35433730f, -1.11799145f, 0.34499285f, 0.64998639f, -1.64371550f, 0.92592359f, -0.47659501f, 0.49101439f, -0.15613313f, 1.47486567f, 0.43576995f, -2.19538260f, -0.83567709f, -1.21846950f, 0.80400819f, 1.14637423f, -1.01503456f, -0.61992753f, -0.47378838f, 0.86503726f, 0.27147385f, 0.37073180f, -0.19951358f, 0.79167330f, -0.33982825f, 0.18631981f, -1.54715073f, 0.39967480f, 0.95067030f, 1.12508667f, -0.86676019f, -1.10341156f, 2.33141375f, 1.10972047f, 0.71407092f, -1.70640314f, 1.80666339f, 0.59465605f, -0.39653218f, -2.61163163f, -1.15013492f, -1.19908321f, 0.41783467f, -0.22730024f, 0.31425011f, -0.58562893f, -0.10131568f, -0.85047537f, -2.59974790f, 1.22072542f, -2.08812046f, -0.19363593f, -1.27664304f, -0.02703438f, 1.08477545f, -0.65506506f, 0.46040919f, -0.13715318f, --0.74945593f, -0.69006950f, -1.29617655f, -0.15865716f, 1.38956285f, 0.90216327f, -1.31185400f, -0.15067385f, -0.63093358f, -0.05895613f, 0.26545224f, 0.29332840f, 0.42852548f, 0.72409540f, 0.12879130f, 1.43038857f, 0.68647617f, 2.19654775f, 0.51878077f, -0.03769343f, 0.52877223f, -0.21733910f, 1.13710785f, -0.59003806f, -1.54624867f, -0.64997369f, -1.03239334f, 0.19708300f, 0.68658423f, 0.71048903f, -1.55250466f, -1.38636279f, 0.32385820f, 0.81226677f, 0.19209047f, -0.23002781f, -0.63631231f, 1.02101684f, 0.65428704f, -0.17206922f, 1.09488952f, 1.03022420f, -0.95567745f, -0.07595373f, -1.48606372f, 2.57174873f, -1.75366247f, 1.12913883f, -0.97053039f, -0.28552356f, 0.56511772f, -0.79568213f, 0.07561764f, -1.02085686f, 1.05770981f, -1.25715709f, 0.42046708f, -2.57390857f, 0.96947151f, 1.05215812f, 0.65624017f, -1.29019403f, 0.64157075f, -0.40509227f, -0.65354455f, 0.42348680f, -1.34107757f, 0.05931387f, -0.54337227f, 0.95460182f, 1.59319806f, -0.44433126f, --0.33717924f, 0.79566282f, 0.50112695f, -0.22244534f, 1.76904583f, -0.89817202f, 1.82985342f, 0.17671813f, 0.80720717f, 1.32469308f, 0.39417782f, -0.23720963f, 0.96796370f, -1.02348757f, -0.86615551f, -1.58120525f, -0.37634999f, 0.00905940f, 0.01880967f, 1.75771821f, -0.64372772f, 0.36687651f, 0.15854552f, -0.67599791f, -0.53726906f, -1.20158446f, -1.78549063f, 0.96476388f, -0.66158366f, -0.41681561f, -0.97541636f, 2.35928202f, 0.32130197f, 1.06886065f, 1.38736427f, -0.73718959f, 0.11215294f, 2.12865782f, -0.37927702f, 0.55621815f, -1.10108411f, -0.02032263f, 0.29595461f, 1.58737493f, 1.24001300f, -0.66748160f, 0.80729002f, -0.10575818f, --1.03175950f, 1.80755460f, 0.10825710f, 2.20666361f, 1.33633149f, 1.39290452f, 0.45211342f, -0.07837920f, 2.08304930f, -0.28387162f, -0.70775616f, 0.43626297f, 0.53556961f, 0.06201901f, -0.59255266f, -0.11854446f, 2.10024118f, 0.37638292f, -0.56178707f, -0.25220188f, -1.23731256f, -1.30002999f, 0.34283713f, 0.30502397f, --1.09233856f, 1.12430644f, 0.52273953f, -0.68507338f, -0.69913578f, 0.88440478f, -0.76959240f, 1.07093310f, -0.34802195f, 0.35683727f, -0.76079178f, -1.92807376f, 0.84499562f, 1.39131641f, 0.44825050f, 0.34567752f, 0.44607711f, -1.00986362f, -0.50038189f, -0.09060892f, -2.55645394f, 0.56416476f, -0.83058155f, -0.65931624f, --0.73649710f, 0.59814465f, -0.86736494f, -0.32200798f, -1.28087902f, -0.76818323f, 0.86848933f, -0.98678392f, -1.30813944f, -0.20255326f, 0.26557815f, -0.31090519f, -1.46331608f, -0.62782109f, 0.59034890f, 1.63147473f, -0.17727259f, -0.37636510f, 1.27368402f, 0.19096918f, -0.29936951f, -1.99038267f, 0.54831523f, 0.48849005f, -2.55680346f, -0.63126534f, 1.21715927f, 1.22841084f, -0.67416084f, 0.02927168f, -0.36693662f, 0.63204330f, 0.13721083f, 0.28742912f, 0.19470036f, 0.74873924f, -1.47602463f, 0.86264688f, -0.23730527f, -0.99978864f, -1.17048764f, -0.34996086f, 1.43019187f, 0.26224539f, 0.60689932f, -0.75002515f, -0.79823422f, -1.37300086f, -0.19951135f, -0.12150808f, -0.75272322f, 0.23755015f, 0.31270382f, 1.66539109f, -1.04104745f, 0.79540199f, -0.54042423f, -0.54150617f, 0.43871084f, 0.24163951f, -0.24517761f, -0.66178995f, -1.13064528f, -0.84426326f, 0.56437236f, 0.09088907f, -0.82823074f, 0.81753862f, -1.74096012f, -1.80599844f, -0.60943592f, 1.36094582f, -1.47762752f, 0.15931177f, 1.05569172f, 0.36751524f, 0.06497604f, 0.13536447f, -1.57156146f, 0.22783801f, -0.96910107f, -1.24294984f, -1.47147155f, -1.04790676f, 0.64629447f, -0.32266054f, -0.55675793f, -0.95612079f, -0.23005411f, -0.75229394f, 0.03050950f, -1.72484553f, -2.06055546f, 0.19892083f, -0.13597751f, 0.65180075f, 0.27096850f, 0.08977254f, 0.57564765f, -0.43227410f, 0.09541437f, -0.00358280f, 0.65680492f, 0.04006556f, 0.57160908f, 0.43821687f, 1.96118212f, 0.42602235f, -0.36731303f, 0.67200917f, -0.56667900f, 0.44014785f, 0.06970236f, -1.34415269f, -1.13301528f, -0.08848868f, 0.35615012f, -0.06426942f, -0.81406075f, 0.94097465f, -0.54560357f, -0.65877116f, -1.29646838f, -1.13109028f, -1.64186084f, -2.12723470f, 1.86027610f, 1.22621441f, 0.26098135f, -0.05608099f, 0.21143445f, -0.87244326f, 0.79408187f, 1.24279130f, 0.14458629f, 0.25532281f, -1.24023473f, 2.42278886f, 0.00405578f, -1.00119174f, 1.19856644f, -1.37395728f, -0.16656208f, 0.46858498f, -0.00678801f, -0.34960639f, 0.16614936f, 2.41560221f, -0.53880709f, 0.91618651f, -1.77009308f, 0.32911557f, 0.30216452f, 0.02881077f, 0.77705866f, 0.27061903f, -0.07440855f, -1.14010465f, 1.25383139f, -1.58615100f, 1.04185510f, 0.15140508f, -0.88059032f, -0.33872122f, -0.42526904f, 2.17365575f, 0.29308075f, -2.24234557f, -1.03164542f, -0.09263755f, 0.08050421f, -0.74946511f, -0.64589006f, -1.13416314f, -0.64989561f, 0.16502371f, -0.33831969f, 0.22832428f, -0.08389475f, -0.28009200f, 1.34536922f, -0.19075738f, 0.36238208f, 0.83690089f, 0.26144615f, 0.04457319f, -2.55585861f, -0.01807522f, 1.68334866f, -0.05795629f, -0.21315987f, -1.84039557f, 0.06512877f, -1.77318645f, -0.27637982f, 0.20439345f, 0.67558700f, -0.77179354f, -0.17902173f, 0.70381826f, -0.40395790f, -0.96492916f, 0.84138173f, 2.43879008f, -0.32297835f, -1.74370265f, -0.10330839f, -1.07465363f, 1.85030377f, -0.59153467f, 0.99667048f, -0.56753993f, 0.57383025f, -1.90630126f, 1.24299097f, 0.22797665f, 0.30468231f, -0.07360230f, 1.64654350f, 0.57195550f, 0.03227921f, 1.11005175f, 0.00088721f, 1.19266295f, 0.61323351f, 0.13754399f, 0.59900171f, -0.75831634f, 1.11500823f, 0.99747783f, -1.36923385f, 1.26563418f, 0.01253266f, 0.35483193f, 1.95143735f, -2.02703261f, -1.38265920f, -0.02404256f, 2.02788448f, -0.75144875f, -0.58445263f, 0.26129767f, 0.60691077f, -1.84661067f, 0.65872228f, -0.58298993f, 0.33067298f, -0.09431327f, 0.43333948f, -1.52616286f, -0.25961858f, -1.65459549f, -0.72950101f, -0.89906919f, -0.80081612f, -1.32189929f, -1.36574399f, -0.35809481f, 0.36385000f, 0.31480747f, -0.35797358f, -1.04066050f, 0.07971872f, -0.21176252f, -0.76559299f, -0.10352154f, 0.29248312f, -1.75030553f, 0.68219930f, 0.56189102f, -1.11212170f, 0.06501702f, -0.07131009f, 1.23410738f, 0.29311740f, -1.02052307f, 1.40220940f, -1.00995779f, 0.57955760f, 0.22640309f, 0.74853230f, -0.02586563f, -0.33427954f, 1.70311153f, -0.53405988f, 0.90975094f, -0.46450076f, 0.19904344f, 0.28559047f, 0.23167793f, -0.69065529f, -0.17176504f, -0.29301846f, -0.85477978f, -0.00267053f, -0.28529504f, -0.64201307f, 1.03479636f, 1.03805065f, 0.83270210f, -0.09405448f, 2.50615931f, 0.62019676f, 0.31354564f, -1.51599669f, 0.42848015f, 0.66263914f, 0.74651009f, -1.13042867f, -0.58933645f, -0.35146511f, 0.06223279f, 0.28065836f, 0.66506970f, 0.16942430f, -0.23316263f, -0.87481076f, 1.21992230f, 1.48536301f, -0.79667616f, -0.75519305f, 1.40999961f, -0.42802793f, -0.20252463f, 0.30573779f, -0.23319976f, 1.77525878f, -1.80704832f, 2.71519923f, -0.67500192f, 0.12268137f, -0.13014549f, -0.07479453f, -1.51065743f, 1.04198146f, 0.96205556f, -2.00525570f, -0.37911776f, 0.89329720f, -0.39495832f, -0.03683375f, -0.90928614f, -1.56263304f, 0.45038295f, -2.62184358f, -0.45686841f, -0.52536523f, 1.05351484f, 0.89982438f, -0.63724512f, 3.21004057f, -0.08608918f, 1.55209303f, 0.62688643f, -0.59702635f, 1.85774517f, 0.38172096f, -1.25640929f, -2.59278178f, 0.85050315f, -1.10080361f, -1.26422560f, -1.80045366f, -0.34494889f, 0.68448657f, 1.25671864f, -1.26594126f, 0.32244179f, -0.51956522f, -0.56212711f, -0.95574015f, 0.71973872f, 0.46736258f, -0.11772985f, -1.52736545f, 0.19571695f, 0.73147154f, 0.87724912f, -0.26265728f, -2.60267401f, 0.19263546f, 0.18320183f, 0.11485019f, -0.82999659f, 0.13582672f, -0.08040185f, 0.28152901f, -0.51421624f, -2.32467175f, 0.19923948f, 0.64616692f, 0.29718629f, 0.32785949f, -0.62266952f, -0.98174316f, 1.23276305f, 0.58563638f, 1.28528512f, -2.13718534f, 0.28842899f, 0.12676710f, -1.72105229f, 0.15053287f, 2.19496536f, 1.28683448f, -0.96318281f, 0.17043279f, -0.05245409f, -0.38710704f, -0.30441490f, -0.08249986f, 0.28423953f, 0.72963721f, -1.49658203f, 0.99077344f, -0.78913772f, -1.12661564f, -1.26294816f, 0.16517465f, 0.10124251f, -0.77198768f, -0.16342169f, 0.08615876f, 0.49711797f, -0.66083062f, 0.76648003f, 1.04756033f, 1.46122825f, -0.42798752f, -2.29203916f, 0.30444992f, 0.58697921f, 1.22166932f, 0.09022947f, -0.03920181f, 0.10444995f, 0.10361757f, 1.18224072f, -0.76641631f, 0.90802073f, 1.41639423f, 1.55682337f, 1.28101575f, -0.35396016f, 1.11443567f, 1.18218529f, -0.06048089f, 0.85024464f, -1.01789165f, -0.69154263f, 0.06663221f, 0.68429029f, 0.12560424f, 0.37915874f, -0.66829866f, -0.64524972f, -0.05568011f, 0.12230454f, -0.35041061f, 0.62027830f, -0.16739209f, -0.72145337f, 0.46263054f, -1.67837834f, 0.69413221f, -0.57243419f, 0.37638462f, -0.21446526f, -0.89821470f, 0.60078722f, -1.06706369f, -1.26132309f, 0.35714921f, 2.39221811f, -0.09376130f, 0.30760849f, 0.59180892f, 0.55815399f, -0.32628775f, 1.28890121f, -2.53237987f, -0.98241091f, 1.10520673f, -1.74751687f, -0.90837651f, -0.25220659f, -0.56625104f, -0.30691949f, 0.16058689f, 0.44309673f, -1.09874964f, -0.76747823f, -0.33679363f, -0.02535496f, 0.00990100f, 1.35318136f, -0.70140815f, 0.50937581f, 0.55386209f, -1.21721983f, 0.71376961f, -0.18079315f, -0.11077732f, 0.09292522f, -0.57235324f, 0.62748206f, 0.42587611f, 0.64860481f, -1.10635614f, 1.66414368f, 0.47505483f, 1.48602211f, -0.59611166f, -0.41932896f, -0.96542233f, -0.41756630f, -1.02963889f, -0.70070386f, 1.65803933f, 0.20138647f, 0.05895034f, -1.46152759f, -0.37278318f, 1.05535650f, 0.34437978f, -1.13257408f, 0.17635690f, 0.09386671f, 0.37079874f, 1.47695887f, -1.58420062f, -0.26100200f, 0.44847637f, 0.88847303f, -0.13877590f, -0.64620668f, -0.38019657f, 1.01608157f, 0.13357787f, 0.05137976f, 0.93498152f, -0.62226880f, 0.80461699f, -0.71682596f, -0.88756353f, 0.40933055f, -1.52167451f, 0.79756850f, -0.17307425f, 0.62368619f, -0.22466940f, -1.72802913f, 0.59047443f, -0.58020931f, 0.09096476f, -0.07317388f, 0.44522321f, -0.64880705f, 0.15684015f, 0.08708375f, -0.41556796f, 1.11579072f, -0.81733495f, 0.11643656f, -0.73995101f, 0.93685871f, 1.57971406f, 0.67606360f, 0.70509088f, -0.25283816f, -0.00010609f, -0.61884147f, -0.86409342f, 0.95383751f, -0.05895388f, -1.45261180f, 0.45166013f, -1.01434863f, 0.18496066f, 1.06517637f, 1.81127059f, 0.89470667f, -0.13232610f, 0.46958798f, 0.13884509f, 0.57117194f, 0.29575035f, -0.97884250f, 0.83291447f, -0.59255791f, -0.04354135f, -0.19431923f, 0.30071029f, -0.95421529f, 0.76359886f, -0.47799742f, 0.68254346f, 1.19368529f, -0.48935115f, 0.30357337f, -0.50225669f, -0.23370270f, 1.96702433f, 1.46558523f, 2.68482018f, 0.41622332f, 0.73697484f, 1.43430734f, 0.15387188f, 0.20875402f, -2.49335337f, -1.39674246f, -0.22125854f, -0.00424605f, 0.91416460f, 0.33384630f, 0.44703746f, 0.25610185f, 0.38966551f, -0.01784045f, 1.66148460f, 0.36005461f, 0.95716912f, -0.18246566f, -0.15480693f, 0.38775176f, -0.56969136f, -0.29644895f, -1.04565966f, -1.00455630f, 0.30897698f, -1.46885884f, 0.03657720f, -0.49302089f, 1.34134722f, 0.01673754f, 1.22725964f, 0.55256772f, 0.63803208f, -0.29041430f, 1.11455286f, 0.76329172f, 0.27073982f, 0.77173829f, -1.79884446f, -0.11889492f, -1.92040312f, -0.46382675f, 0.20078070f, -0.98889589f, 1.46711135f, -1.68280172f, -0.52852470f, 0.66245162f, 0.29575166f, 1.34826505f, -0.22362417f, -0.14345661f, -2.34815073f, 1.26572001f, 0.66505629f, 1.01141500f, 1.08030057f, 0.17036134f, 0.00168786f, -0.37282917f, 0.69206375f, 1.07367527f, -0.49708191f, 1.49504781f, 0.58224988f, 0.96593714f, -1.07661915f, 0.25202179f, 0.25531644f, 0.42357162f, -0.31236249f, 0.48383278f, -0.06361829f, 0.24131298f, -0.95695931f, -0.12589653f, 0.36134180f, 3.20266032f, -0.40879184f, -0.66985190f, 1.51674330f, 0.34072638f, 1.15076303f, -0.40199137f, 0.46223637f, -0.48608047f, 0.99119538f, -0.22506073f, 0.30968750f, 0.64210880f, 0.54640514f, 0.18607031f, 1.26293361f, -0.77960914f, 0.79572529f, 1.01936150f, 2.27160740f, -1.48034489f, 0.74466604f, 0.14863680f, 0.31102443f, -1.15673816f, -0.38609681f, -2.65026069f, -0.45524642f, -0.74022961f, 2.74991131f, 0.00103815f, -3.03303242f, -0.41556966f, -0.87103498f, 0.78306234f, -0.88195556f, -0.77297026f, 1.21203196f, -1.09754920f, -0.03556008f, -0.31546223f, 0.72954375f, 0.25251788f, 0.11378583f, 0.50921023f, 0.30301905f, -1.60631680f, 0.27152416f, 1.17342317f, -0.70891970f, -0.08392961f, 0.92137378f, -0.10568139f, -0.31653777f, -0.28878728f, 1.22166574f, 1.12693942f, -0.21325994f, 0.94010323f, 1.21796405f, -0.68866694f, 2.30724216f, 0.28141466f, 0.83481526f, -0.04885862f, 0.01675143f, 1.04355800f, -0.81050140f, 1.51300573f, 0.53429186f, -0.56439877f, 0.38572624f, -0.05620475f, 0.67644542f, 0.72528905f, 0.05937041f, -1.06315899f, -0.51393986f, 0.46937627f, -0.34699562f, -0.64765716f, -1.45512629f, 0.47739139f, -0.88228017f, -2.00791359f, 1.29929042f, 0.05482405f, -0.66725296f, -0.54735124f, 0.09972951f, 0.76675093f, 0.98748523f, 0.08900899f, -0.78854066f, 1.47970486f, -0.61667502f, 0.45625573f, -0.21766303f, -0.46250847f, -0.07130960f, 0.64414692f, 0.12784545f, 0.26393634f, 1.07720757f, -1.23938286f, 0.62483376f, -0.55001754f, -0.05358591f, 0.07322436f, 1.12003291f, -1.00830650f, -0.20486419f, 0.76664752f, 0.28850746f, -0.04464776f, -0.40146068f, 0.73262817f, -1.12827921f, -0.19989438f, -1.15999687f, 1.37973154f, 0.78881019f, -0.34762639f, 1.22088552f, -1.64088547f, 0.63218033f, 0.45736769f, 0.05502866f, 2.22683382f, -1.78935897f, -1.49635041f, 0.83450896f, 1.67770112f, 1.33909333f, 1.51158953f, 0.28595078f, -0.08593627f, 0.45812801f, -0.15193029f, 1.14770603f, -0.88920450f, -1.96352005f, -1.49894583f, 0.49629962f, 1.59872091f, 0.00903497f, 2.15563583f, 2.25149560f, -2.01200557f, 2.56229877f, -1.38850498f, 0.73552012f, -0.39378855f, 0.52616280f, -0.03685786f, 0.87403935f, 0.12163408f, 0.74297994f, -0.30697080f, 0.38139752f, 0.49113834f, -0.95485127f, -0.99908817f, 0.71716321f, 0.04000283f, -2.09645271f, 1.38789880f, 1.37198520f, 0.82493287f, 0.17114936f, 0.53696346f, -0.19516060f, -0.50377476f, -0.91730285f, -0.70113552f, -0.02406530f, 0.84943396f, -0.17428185f, -1.09140801f, -0.68156958f, 1.70756388f, -1.00399911f, 0.03023832f, -0.39023280f, -1.89737976f, 1.14469039f, -0.58337289f, -0.60037899f, -1.17490256f, -1.56342828f, 0.48714057f, 0.62266618f, -0.15967095f, 1.32789338f, -1.25700688f, -0.55633998f, -0.83128709f, -0.49346271f, 1.59561753f, -0.24675299f, 0.38012561f, 0.91796309f, -0.38522810f, -0.65509188f, 0.94100451f, -0.57324487f, 2.19070768f, 1.24058700f, -0.75978851f, -0.40460554f, 0.79189235f, 0.70192885f, 1.93569362f, -0.03070199f, 0.77010989f, 0.58794290f, 0.51087004f, 0.22892070f, 0.35007235f, 1.56023848f, -0.67453802f, -0.18485607f, 0.64349502f, -0.31489357f, -1.95834625f, 0.06560058f, 2.30394220f, 1.18194163f, -0.88034087f, -1.05000436f, -1.05471325f, -0.98481798f, 0.49904808f, 0.16438948f, -1.10297823f, -1.39736509f, 0.01306054f, -1.85160267f, -0.87292641f, -0.15418227f, 0.43412164f, 1.16518164f, 0.06273691f, 0.24659210f, -0.08267246f, 1.28885782f, 0.73575675f, -0.01019809f, -0.08753663f, -0.61827368f, -0.40863234f, 2.12599611f, -0.53620332f, 0.53789747f, -0.66386080f, -1.70461988f, 0.86608189f, -1.11151052f, 0.14120635f, 1.18858743f, -0.31760478f, -0.73533046f, 0.20978074f, -0.84074509f, 0.16523147f, -1.03362834f, 0.59721231f, 0.21318658f, 0.23671274f, 1.75115061f, 0.25363782f, -1.32541454f, 1.13056135f, 0.24652456f, 0.60381413f, 0.21478581f, 0.75044096f, -0.63125616f, -1.69889998f, -0.02116571f, 1.46165359f, 1.03068244f, 0.63693464f, 0.67795700f, 1.20033514f, -1.39205134f, -0.61743122f, 0.56549704f, 0.65182322f, -0.74250507f, -1.61939359f, 1.14054918f, -0.45725963f, 1.74519682f, -0.66251940f, -0.94811529f, -1.60865819f, -0.59968346f, 0.86309159f, -1.91936195f, -1.02646923f, -1.50352538f, 0.58292735f, 0.05320299f, 1.53582895f, 0.01069612f, 0.15226212f, -0.71840125f, -1.36896348f, 2.14600968f, 0.96626586f, -0.52014917f, 0.41001406f, 0.59478027f, 0.15282436f, 0.27790198f, 0.76614654f, -0.38971323f, -0.01839927f, -1.57882118f, 0.61391610f, -0.62133092f, -0.03968323f, -0.88467252f, -1.24041140f, 2.07306671f, -0.41776338f, 0.14537935f, -0.91069067f, 1.67362070f, 4.72630215f, -0.07395106f, 0.46280116f, -0.40843824f, 0.70683080f, -0.27510864f, -0.63465804f, -0.83630908f, -0.44419941f, 0.60405648f, -0.65039170f, -1.02413189f, 1.05983019f, 1.73366308f, 0.73343736f, -0.00895882f, -1.00826013f, 0.17323074f, 0.73995626f, 0.24128854f, 0.94510227f, 0.25557515f, 0.02244723f, -0.95197725f, -0.16297856f, -0.38497585f, 1.17993331f, 1.20282137f, -1.31491220f, 0.44229278f, -0.24349044f, -0.01230415f, 1.37944865f, 0.48554277f, -0.54510897f, -0.10793537f, 0.41121426f, -0.12889031f, 0.26434359f, 1.27966082f, 0.64518744f, -0.15577169f, -0.99864733f, -0.61746484f, 2.01614976f, 1.56254935f, 1.86473298f, -0.54662132f, -0.22047071f, -0.06118120f, 0.84799510f, 0.17009684f, -1.30523121f, 0.64000309f, 0.36299205f, -0.59620583f, 1.36372304f, -0.05389515f, -0.93849313f, 0.98043185f, -0.39373067f, -0.84898937f, 1.32077873f, 1.05988657f, -1.35339200f, 0.23259017f, 0.63816410f, -0.80297333f, 0.60017115f, 1.25715804f, 1.18894124f, -0.62473553f, 1.05611980f, 0.02335166f, 1.07509828f, 0.25873449f, -1.68341100f, 0.54547334f, 0.79288185f, -0.93678916f, 0.19202201f, -1.48575914f, 1.08649087f, 0.50851744f, -0.45758674f, -0.39734635f, 0.35637981f, -1.63079453f, -0.75910008f, 0.92640859f, -0.55599529f, -0.40276715f, 0.31307653f, 0.39907026f, -1.18830419f, 0.71051043f, 0.14157933f, -0.39581308f, -1.64361024f, -0.06161860f, -0.25312796f, 1.10018682f, 0.56500763f, 0.80385065f, 0.35395023f, 0.81813669f, 0.27644628f, 0.65563256f, 1.73197234f, 0.68178749f, 0.76769936f, 0.44597456f, 0.67761195f, 0.67635447f, -0.32315412f, 0.19330767f, -0.25557944f, 1.91693723f, 0.38335562f, 0.07107610f, -0.57384586f, 0.79184365f, 1.87835479f, 0.60902315f, -0.94220877f, 0.79479855f, -0.25656971f, 0.08739131f, 0.53384244f, 1.22159266f, -0.39152125f, -1.46373534f, -0.02458516f, 1.62825716f, -1.26112676f, 0.19967082f, -0.71114451f, 0.27929229f, 0.65001321f, -0.11868202f, -0.55587751f, 0.78069001f, 0.57969242f, -0.60274386f, 0.31650013f, 0.90339553f, 0.09453616f, -0.37119162f, -1.00320566f, 0.33299938f, -0.48636708f, 0.26342997f, -0.91914523f, 0.28682709f, -1.24780893f, -1.59254742f, 0.97176319f, 0.14744301f, -0.53056234f, -1.73221612f, -0.67645556f, 0.98705006f, 0.79895812f, -2.04333115f, -0.60132772f, -0.91653955f, -0.28094748f, 0.47943443f, 0.38157779f, -0.67648011f, 1.09093642f, 1.66012859f, -0.29358891f, -1.26773024f, 0.36747769f, -1.10141146f, 0.82383633f, -0.89772314f, -0.47145563f, 0.63939518f, -0.64430422f, -0.48889321f, -0.37680882f, -1.06962025f, -1.28689516f, 1.28365147f, 0.61859220f, -0.84676331f, 1.38404000f, 1.21053445f, -0.14871351f, 1.06349385f, 1.45878971f, -0.47362664f, 1.40707004f, 1.25224137f, 0.87364739f, 0.92858213f, 0.00157326f, 1.45661485f, -0.27318576f, 0.15482858f, -1.07058907f, -0.06903186f, -0.74147576f, -1.64111829f, -0.67226541f, -1.13458407f, 1.28511488f, -0.41041154f, 2.09085560f, 0.45243183f, -0.67437285f, 0.84960121f, -1.49300814f, -0.42961186f, -2.35021853f, 0.57255560f, -0.73903763f, 1.37607956f, -2.44575167f, 1.25105727f, 1.38575912f, -1.16299784f, -0.13719854f, -1.11507034f, 0.35796806f, -0.64511567f, -0.87903833f, 0.32833642f, -0.87696886f, 0.02714214f, 0.30224666f, -0.69118696f, -1.23500824f, 0.76678628f, -3.20508122f, -0.24704689f, 0.49019828f, -1.20862615f, -0.03778638f, -0.07273687f, -0.11517122f, -1.75857520f, -1.64188445f, 1.21574795f, 0.57325113f, 1.14370298f, -1.07824504f, 1.70653832f, -0.03700557f, -0.47645858f, 0.11065386f, -1.03143036f, -2.18094873f, -0.94403434f, -0.09335683f, -0.44817665f, 1.39707148f, -1.21947956f, 0.56575936f, -0.69612634f, -1.12361753f, -0.17105591f, 1.15422392f, 0.02840637f, 0.09469353f, -0.52859986f, -2.08487725f, 1.28789508f, -0.03740775f, 0.61196613f, 1.23405397f, 1.56595814f, -0.65800631f, 2.02985072f, -0.69446486f, -0.88443804f, -0.23448054f, -0.43628734f, -0.45888957f, -0.21943338f, 1.78258693f, 1.75214970f, 0.71804136f, 0.49782532f, 0.37886053f, -1.59176385f, -1.74758542f, -0.02820176f, 0.75398153f, 1.00119829f, 0.80881971f, -0.53365272f, -0.22720885f, 0.37476870f, 0.01005529f, -1.23421800f, -0.13431595f, -1.01843679f, 1.87386346f, -1.68539488f, -1.04942071f, -0.77322137f, 0.53964764f, 0.29278332f, -0.58299130f, -1.56022692f, -0.79441273f, 0.49289709f, 0.44112054f, 1.07305002f, 0.54899335f, 1.13781393f, 0.77809113f, 0.81795985f, 0.16576190f, 0.32552773f, -0.20250474f, 1.46543837f, 0.12731771f, 0.21013761f, -1.34241438f, 0.44267517f, 0.93246883f, 0.08808212f, 0.92653406f, -1.21083558f, 0.17247954f, -0.70557106f, 0.04630012f, 0.48834828f, 0.89634645f, 0.46683592f, -0.29553145f, 0.46363977f, -0.48971879f, -0.88603491f, -0.12333342f, 0.37073737f, 0.92061806f, 0.54675460f, -0.14716248f, 0.75578392f, -0.98173791f, -1.15983224f, -0.58713156f, 0.07950903f, -0.59016788f, 0.41622928f, -0.32474482f, 0.42086437f, 0.23061797f, 0.62596649f, -0.22615278f, -2.14721417f, 1.01685894f, -0.25976995f, 0.00739352f, -1.31597066f, 0.39005190f, -1.09549701f, 1.68375242f, 0.43331525f, -0.37124026f, 0.22255214f, 0.59654880f, -0.73840386f, -1.20048976f, 0.12226126f, 0.12997478f, 1.04826224f, 0.03894836f, -0.36289826f, 1.14466560f, -1.18198848f, -0.03713558f, 0.67677927f, -0.42329931f, -0.89409167f, -0.77874780f, 0.58438253f, -0.35176343f, -1.53329861f, -0.02995299f, -0.40145162f, -1.51052392f, 0.09194464f, -1.13275242f, -0.61983156f, -0.40004560f, -0.19893464f, 0.22134103f, -0.03903082f, 1.14894116f, -0.03476744f, 0.22520730f, -0.55851930f, 0.76650429f, -0.57863152f, -1.34161711f, -0.31498179f, -1.19411755f, 1.70044947f, -0.17428267f, -0.35983825f, -0.42613637f, 0.58165723f, -0.77866900f, -1.59727287f, -0.61723864f, 1.51078022f, 0.32971445f, -0.86441469f, 0.60552609f, 0.00208178f, -0.47096625f, -1.10479307f, -1.21652532f, -0.08211990f, -1.43739200f, -1.31684434f, 0.43312529f, -0.76822090f, 1.88128507f, -0.02179282f, 1.04971325f, -1.55004108f, 1.25337446f, 0.11203052f, -1.16048300f, 1.59467411f, -1.29469275f, 1.14019871f, 1.20021439f, 1.84098923f, 0.05004879f, 0.73529941f, 2.05272865f, -0.13080600f, -0.08436690f, -1.17919350f, -0.66256678f, -0.36727047f, 0.73840511f, 1.22293818f, -0.00206342f, -0.29839504f, -0.00618613f, 1.04213119f, 1.21176076f, -0.62886089f, -0.02589060f, 0.96009409f, -0.64478731f, -1.16516542f, 0.57528079f, 1.04294407f, -0.09774588f, 0.45935291f, 1.03263175f, 1.00633478f, -1.82209253f, -0.18035053f, -0.28302726f, -0.83813244f, 0.57593471f, -0.03807700f, 1.60498738f, 0.16530658f, -1.43083501f, 2.10824299f, 0.30279446f, -0.03961089f, -0.38900724f, 1.31272805f, -0.56575215f, 0.57970244f, -0.48305038f, 1.34114623f, 0.21859215f, 0.66399640f, -1.52087069f, -1.30717897f, 0.14394683f, 0.97648209f, -0.71372712f, -1.22574198f, -0.27702177f, 0.04041927f, 0.02442212f, 2.19617033f, -0.48566443f, 0.81463927f, 0.20383844f, 1.17562282f, -0.33829874f, -0.42141283f, -0.96415234f, -2.39141965f, -1.04285860f, -0.23004992f, 0.41186509f, 0.03811268f, 0.36818987f, -0.71099734f, -0.56749570f, 0.18486284f, -0.44530040f, 2.14008284f, -0.27467576f, 1.70690107f, -1.40462613f, 0.24697532f, -1.31629777f, -2.20674944f, -0.67868507f, -1.15767133f, -0.64391804f, -1.79037917f, 0.58749497f, -1.58303332f, -0.69021022f, 1.64376318f, -0.95393223f, 1.98415601f, -0.10991055f, 0.02474386f, 0.23683345f, -0.63420391f, -0.57991928f, 0.83028817f, -0.40033704f, 0.19212338f, 0.74640590f, 1.10264432f, -1.65286255f, 0.92683482f, -1.42252541f, -0.74605089f, 2.14535880f, 0.12971123f, -0.47971717f, 1.67546797f, 0.42268261f, 0.22648531f, -0.42369929f, 0.77403021f, -1.31818616f, -0.67143595f, -0.04311426f, 1.64128351f, 0.34776631f, -0.39353722f, -0.42765084f, 0.16170517f, -0.54488391f, -0.38428506f, 0.42097485f, -0.55982012f, -1.74543798f, 1.53704774f, 0.43562424f, -0.30395737f, 0.31846946f, 0.39205357f, 0.57386035f, -1.11912560f, -1.39164317f, -1.04337609f, 0.31629622f, 1.51927638f, 0.88745505f, -0.40445471f, 0.25783861f, 1.88646257f, 0.36509129f, -1.13266826f, -0.45394278f, -0.48400903f, -1.22332740f, 0.38626808f, -1.10049105f, 0.84138852f, 1.27863181f, 0.53942156f, -0.67743856f, -0.03896645f, 1.70393491f, 0.60997570f, 0.43368068f, -0.13338457f, -0.18920666f, -0.29583672f, -1.40738738f, 1.03876019f, 1.71253765f, 2.12821221f, -0.96092403f, 0.93841934f, -0.79030478f, 1.36427641f, -1.39196694f, 0.08514920f, 0.16223004f, 0.71259701f, 0.20150672f, 0.25068361f, -0.99952722f, 1.80129099f, -1.28586197f, -0.64957166f, -0.94813949f, -0.40161121f, 0.31977695f, 0.54932386f, -0.67757767f, 1.88086259f, 0.92337233f, -1.64887333f, 0.44333732f, -0.19468001f, 0.12977587f, 0.21171951f, 0.27679422f, 0.49134475f, -1.44429457f, 1.25617445f, 0.39978400f, 0.99869555f, -1.61617446f, 1.61177349f, 0.70243025f, -0.95748568f, -0.61795151f, -0.77302909f, 0.72967088f, 0.81964350f, -0.71813750f, 0.90140164f, -1.45950246f, -0.79972702f, 0.40875742f, 0.00152073f, -1.74491429f, 1.53776145f, 0.75769204f, -0.22075878f, -0.58385569f, 2.18884754f, 0.33597681f, -1.66265559f, 1.03805876f, -1.55245185f, -0.03582226f, -1.94542754f, -0.76081425f, -0.50471377f, 1.35763168f, -0.39631784f, -0.17134467f, -0.82220149f, -0.41021580f, -0.00940776f, -0.80176353f, -0.19816744f, 1.22061026f, -0.14486519f, -0.71727395f, -0.65721530f, 0.47020102f, -0.70403302f, -0.94795334f, 1.79884899f, 0.07779162f, -1.50615680f, 0.04140327f, -0.22001404f, 0.63735324f, 0.79237640f, -2.25412822f, -0.52519119f, -0.87280381f, -0.07100742f, -0.94734806f, -0.12286110f, -0.13623615f, -0.42595413f, 0.17547913f, -0.81707209f, 0.36855817f, -1.68186557f, 0.19312963f, -0.66249490f, -0.98283452f, -0.33314428f, 0.40918943f, 0.88268638f, -0.05390308f, -0.22440539f, -0.15879378f, -0.34859571f, -0.01013108f, -0.30005428f, -1.19408464f, 0.21789688f, -1.07769871f, 0.81475031f, -0.69555300f, 2.35201311f, -0.40362412f, 0.93497628f, 1.13343573f, 0.92343372f, 0.26987928f, 0.46123627f, 0.22577702f, 1.26289701f, -0.45956740f, 0.55994868f, -0.58410591f, 0.13304594f, -0.25806463f, 0.49044946f, -0.82065403f, -3.06672239f, -0.27774641f, 0.68504512f, -0.21386372f, 1.11427057f, -0.73201770f, 0.51655543f, 1.77261138f, 0.72081727f, 0.11116749f, 0.16637769f, -0.74987584f, 0.66579849f, -0.75808716f, 0.20678560f, -0.67698354f, -0.82141948f, 0.61008269f, 0.66520184f, 0.44894725f, 0.73015076f, -1.52517414f, 0.11714164f, 1.90452611f, -1.30355322f, 0.12144456f, 1.18547559f, -0.07349755f, -2.28061509f, 0.83522540f, 0.78438890f, 2.19334102f, 0.90305614f, -0.59345531f, 0.77925014f, 1.32338643f, 0.14068902f, 1.19032264f, 0.20666829f, -0.76595837f, 0.74967057f, 2.86965609f, 0.55690205f, -1.72530472f, -0.83317834f, -0.85842621f, -0.29678273f, 1.80955839f, -0.70496303f, 1.19106734f, -0.92985237f, -1.00617313f, -0.56049556f, -0.29382578f, -2.04022193f, -1.95356870f, -0.42553005f, -0.33369407f, 1.02115977f, -1.45769477f, -0.67720300f, 0.53819913f, 1.57643425f, -0.47015440f, -1.47861958f, -0.00545934f, -0.97836047f, 0.42680529f, 1.56110144f, -1.49487829f, -0.65198445f, 0.22720462f, 1.83036661f, -0.47099793f, -0.09915133f, 0.14923312f, -1.16313052f, 0.67798084f, -1.63665557f, -0.38220280f, 0.01719763f, 0.30041245f, 0.43148938f, -0.44021657f, -1.25734651f, 0.02465564f, -1.00845659f, -0.28574651f, 0.01367745f, 0.77253437f, -0.99399441f, 0.61445391f, 0.18343423f, -0.50997210f, 0.41359940f, 0.77279282f, 0.83511519f, 0.27929801f, 0.70800692f, -0.20278299f, 1.57884383f, 0.22650529f, 0.43347472f, 0.74003208f, -0.71401161f, -0.69829476f, -1.56766701f, -0.99254119f, 1.27301061f, 2.73726511f, 0.66089469f, -1.95778012f, -1.24642098f, -0.63579029f, -1.63168180f, -0.66980726f, 0.81933254f, 0.61866677f, 1.40594471f, 0.05158535f, 0.00196500f, -0.24592508f, -0.50780547f, -0.83905292f, -0.10748957f, 0.04490763f, 0.27769178f, -0.23227681f, 0.82108080f, 0.03562285f, 0.95483875f, -1.49897683f, 0.67809856f, 0.35497451f, -0.44021592f, -1.67361462f, -0.88895375f, 1.44293678f, -0.85046643f, -0.46437624f, -1.87252641f, 0.26775804f, -0.24535774f, 0.73365933f, 0.52253938f, 0.27947086f, -0.58796054f, 0.59045380f, 1.93476331f, -0.46775359f, 0.25238225f, -1.26601815f, -0.13324316f, -0.71454948f, -0.21610366f, -1.49586582f, 1.04903507f, 0.22208478f, 0.25512528f, -0.46157327f, -0.41319233f, -0.63846964f, -0.25100923f, 0.81277549f, -0.26959971f, 0.88737756f, 1.24578953f, -0.91121447f, -1.05756927f, 0.44390878f, 0.16672316f, -1.22941923f, 0.89547867f, -1.50212002f, -1.69620168f, 0.53339505f, -0.23656729f, -1.69879091f, 0.01510374f, 0.08315694f, -0.73196459f, -1.60263407f, -1.07601058f, -0.76389569f, -1.65307498f, -0.61484390f, -0.43546933f, 0.71318507f, -0.16273083f, 0.64122051f, -0.15406294f, 1.17673671f, -0.91240519f, 0.71091145f, 2.40497613f, 1.26343656f, 0.71469337f, 0.20705548f, 0.81776261f, 0.36253929f, -1.92106628f, -0.09300470f, -0.36648872f, 1.27732766f, -0.39180157f, -0.61186749f, -1.03455031f, -0.25079829f, -0.61479062f, -1.07094336f, 0.82218504f, 0.89934880f, 0.41308978f, -0.59968555f, 0.37682834f, -1.77388155f, 0.00294951f, -0.66145372f, -0.50789726f, -0.85123241f, -0.89909405f, -1.89454281f, -0.56692821f, 1.52272677f, -0.11961794f, 0.27843913f, -0.60582250f, 1.01871169f, -0.36098275f, -0.12242325f, -0.67375034f, -0.11204147f, -2.62773919f, -0.95901299f, 0.14040214f, 1.32364666f, -1.35099924f, -0.11077739f, -0.79319423f, 0.75949597f, -0.25485823f, -0.90959758f, -0.42373934f, -1.29850340f, 0.85699379f, -1.11882365f, 0.63470817f, 0.49696380f, -0.07983235f, -0.23903450f, -0.22618714f, -0.12117998f, -0.09442677f, 1.55589819f, -0.11996678f, -1.72700179f, 0.54683149f, -0.40804827f, -0.50099218f, 0.34596699f, -1.81841791f, 0.06385052f, 0.84428120f, 0.69901514f, 1.94559097f, 0.43251973f, 0.16794942f, 1.82829034f, 1.70959795f, 0.36130908f, -0.94608402f, -0.53498030f, 0.47781768f, -0.24203247f, 1.25065851f, 0.51788396f, -2.09381890f, 0.72973937f, 0.03281829f, 0.58632666f, 1.85737121f, -0.49569523f, 0.45921183f, 1.87173629f, 0.22803484f, 1.66433418f, -1.05872321f, -1.13663685f, 0.12397861f, -0.65112090f, 0.98152941f, 0.83739656f, -0.18783289f, 1.84249437f, -0.90706986f, -0.80824369f, -1.23854923f, -0.86488134f, -1.02627063f, 0.10976455f, -0.61403006f, 1.27554715f, 0.14653525f, -0.03953953f, -0.08512071f, -1.30043304f, -0.02566035f, 0.12054887f, 0.00282162f, 0.48921332f, -1.74398839f, 1.44554436f, -1.35854721f, 0.69256759f, 0.34101671f, 2.50045252f, 0.49121150f, -0.27115449f, 0.93974596f, 0.26258010f, 0.27151433f, -0.87214381f, -0.92580765f, -1.03269923f, 0.20615758f, -0.37822601f, 0.58983004f, 0.16426525f, 0.68218285f, 1.98158526f, 0.47492698f, 0.54224718f, 1.28722692f, -1.76915324f, -1.11240053f, 0.77428484f, 0.27184650f, 2.22473478f, -0.05574624f, 0.39976570f, -0.43911108f, 0.52805597f, 0.17340177f, 1.36057591f, -0.35004014f, 1.72787797f, 0.68357420f, 1.25532615f, -0.56752264f, 0.51840127f, -0.21237844f, -0.58821255f, -0.85278064f, 1.90179110f, -0.67447448f, -0.36831430f, -0.22930753f, 0.98231596f, -0.07011599f, -0.08560387f, 0.05998110f, -0.02481356f, -0.57335132f, -0.44288307f, -0.24468307f, 0.53321087f, 1.19609559f, 0.10664973f, 0.24379487f, 0.93687552f, 0.93615580f, 1.74319768f, -0.68310338f, 1.32163060f, 0.61918712f, -0.76501870f, -0.54549301f, 1.74077415f, -0.69977754f, -0.66880983f, -1.15981388f, 0.81571609f, 0.53788543f, 0.47898352f, -0.02484704f, -1.64646924f, -0.69822907f, 0.27020717f, 0.05027051f, 1.75149667f, 0.01548872f, 0.32615909f, 2.55151844f, -1.29172051f, -0.36133784f, 0.98637396f, 0.14009331f, -0.50038946f, -0.92230296f, 0.17307127f, 1.05361068f, -1.46784890f, 2.38960409f, 1.19413340f, -1.33349669f, 1.59141159f, -0.71811068f, 1.22429430f, 1.26947939f, 1.08177102f, -1.18138707f, -0.72775704f, 0.17282635f, -0.40554270f, -0.40341887f, 0.46564049f, -1.02069795f, -0.07653128f, -0.13979210f, -0.31195050f, -1.72042310f, 1.37131393f, 0.63849634f, 0.75561279f, 1.81152904f, 0.26686314f, 1.32796574f, 0.56100166f, 0.70058894f, -0.88962644f, -0.04360984f, -0.88249093f, 0.24311203f, 0.50410056f, -2.22567797f, 0.94520348f, -2.12467694f, 0.47282359f, -0.71379906f, -0.09857135f, 0.62374717f, 1.37182784f, 0.73380554f, 0.59745449f, 2.80427694f, 0.67253572f, 1.65335357f, 1.69891667f, 1.34585941f, -0.79989213f, 1.44980943f, -0.52013642f, -0.46971673f, -1.50070012f, -0.25687039f, -0.56916732f, 0.71065760f, -1.31996286f, 0.96031237f, 0.13929774f, 1.49679291f, -0.05966444f, -0.58674580f, -0.08278833f, -0.93390942f, 0.42415768f, -1.77889526f, 0.75336021f, -0.72699982f, -0.82880586f, 0.63955617f, 0.42771208f, -0.42366457f, -0.91581815f, 0.94750947f, 0.43123913f, -0.99053741f, 0.70470595f, -1.16662264f, 1.14847183f, -0.83885664f, 0.46714026f, -2.27748466f, -1.23656678f, 0.14695056f, -0.33159894f, -0.52553117f, -0.04391259f, -0.29630372f, 0.25949728f, 0.96991086f, -0.37714824f, -0.28251833f, 0.16106486f, 1.38844633f, -0.18713553f, -1.30708838f, 0.48490265f, 0.29553881f, -0.45505449f, 0.83341682f, 0.87346369f, -0.63516861f, 0.66063565f, 0.93892503f, -2.73996735f, -0.81515318f, -0.91458052f, 0.00978268f, 0.43472794f, -0.08090764f, 1.37249672f, 0.76722521f, -1.19154143f, 0.22046764f, 0.34916410f, 0.51383299f, -0.56379753f, -2.49949312f, -0.74207872f, -0.68400806f, -0.09663232f, -0.07199454f, -1.05562651f, -0.75028551f, -0.87253797f, 0.69039482f, 0.45923674f, -1.27515161f, -0.04555376f, -1.41501272f, -0.83773375f, -0.74807298f, 1.36646152f, 0.06317432f, -1.32559633f, 1.89092779f, 1.24883330f, -1.03608561f, 1.08677161f, -0.99629849f, -0.69947034f, -0.85716367f, -0.07947286f, -0.25485426f, -0.19732477f, 1.64581251f, 1.04618108f, 1.87186897f, -0.18198362f, -0.83807969f, 0.70462501f, -3.18930101f, 0.74610996f, -0.60935193f, -0.49383929f, -2.88986492f, 0.51707613f, 1.04620326f, 1.09837818f, -1.19840038f, -0.10391295f, -0.20789115f, -1.51052022f, -0.31087330f, 0.22411564f, -1.30506921f, -1.52000105f, -1.51593041f, 1.04321992f, 0.97611690f, 0.90424490f, 1.83324766f, -0.08682299f, 0.47035542f, 1.70865905f, -0.31108001f, 0.04115159f, -1.36352801f, -0.90797836f, 0.32128647f, 0.66191489f, 0.08681208f, 0.14993365f, 0.47110486f, -0.31522670f, -0.38906571f, -0.08876022f, -0.13106902f, 2.25685239f, -0.62211353f, -1.68553007f, -0.23707703f, 0.69236159f, -0.46686995f, -0.27520603f, 0.26619941f, 1.48525345f, 1.61278927f, 0.49452963f, 1.20846486f, -1.11853909f, -0.30010033f, -0.75471467f, -1.69959772f, -0.52042168f, -0.43881389f, -1.45240712f, 1.02122891f, 1.73639011f, -0.03813924f, -0.22239220f, 0.15797073f, -0.64418089f, -0.60228932f, -0.83248150f, -0.02042520f, 0.38137484f, 0.86056453f, 0.06410559f, -0.62785137f, -0.49916875f, -2.53796315f, -0.79168582f, -0.69197005f, -0.77175534f, -0.28669405f, -0.79764080f, 0.97218460f, -0.10351621f, -0.52759898f, 1.02840185f, 1.16363287f, 0.08351815f, -0.61088538f, 0.59944046f, 1.54409397f, -1.39842033f, 0.27917057f, -0.27146137f, 1.46310735f, 0.03626106f, 0.15038440f, -0.07894899f, -1.42527366f, 1.69641745f, 1.48384345f, -0.43328866f, -0.54252565f, -0.94416499f, 1.54436302f, -0.81367069f, -1.67925239f, -0.17525831f, 0.27891046f, -0.69066733f, 0.89911050f, 0.11606655f, 0.67450327f, 0.41538724f, 0.90886223f, 1.19786549f, 0.85810721f, 1.32862210f, -0.83469814f, -1.09682298f, 0.88092703f, -0.97478902f, -0.11664717f, -0.07929394f, -0.69581884f, -0.16928329f, -0.70731819f, -0.40485084f, -0.28954300f, 0.52882415f, 0.38769314f, -1.38704026f, 1.15099049f, -0.43566978f, 0.34459323f, 0.49520254f, 1.11130333f, 0.28783718f, -0.53783375f, -1.63577271f, 1.02222812f, 0.86302060f, 0.48346213f, 0.46627176f, -1.30133855f, -1.48477137f, 0.31219670f, -1.21498191f, 0.89838904f, 0.87186617f, -0.39968935f, 0.34930915f, -0.32909471f, -1.39364409f, 2.13006306f, 0.33270469f, 0.00215986f, 0.97776711f, 0.24908836f, 1.56164885f, 0.45157790f, -1.55970144f, 0.27677536f, 0.07662498f, -0.08262251f, -0.17658773f, 0.65820259f, 2.01052690f, -1.71946216f, 0.84686053f, -1.23594892f, 1.40792072f, -1.47772563f, -0.36132276f, -0.50405115f, 0.09009213f, 0.81659186f, 1.85574234f, -0.64974433f, 0.63352364f, 1.01766217f, -1.54804432f, -0.42570522f, -0.24763709f, 0.72822112f, -0.93733686f, 0.68087620f, -1.40644944f, 0.48672482f, 0.09725539f, -0.64416331f, -0.95747960f, 0.36771363f, 0.39155054f, -0.71790671f, -2.17222738f, -0.08655047f, -0.97842115f, -0.22991380f, 0.52029115f, -1.42072022f, 0.29576331f, 0.32391560f, -1.00823236f, 1.67909145f, 1.16841447f, -0.32307062f, 0.15756166f, -0.97590631f, -0.39429301f, -0.03583352f, 0.17554663f, 0.57961231f, -0.46873134f, -0.23343173f, -0.85060924f, 1.71745574f, -0.04658702f, 0.63088381f, -0.67581934f, -1.53171062f, -1.58800113f, -1.17987096f, -1.16737640f, -0.87544650f, -1.17138922f, 0.38979119f, -2.39369726f, -1.34747124f, 0.58450359f, 0.87791806f, -0.04459394f, 0.97995293f, -0.10354915f, 0.65324986f, -0.17833626f, -0.85849386f, -0.42063358f, 0.19708554f, 0.10255250f, -0.59539181f, 0.86194044f, 1.68610668f, 0.55275291f, -0.43127069f, -0.04218780f, -0.08466262f, 0.31236625f, -0.92824298f, -0.09879152f, 0.32358822f, 1.04045570f, 0.35617545f, 0.09059231f, 1.19069445f, 1.96978688f, 0.63561743f, 0.15030998f, -0.29879019f, 0.22774190f, -1.01608860f, 1.03605175f, 0.47804731f, -0.30450734f, -0.61382371f, 0.45390254f, -1.93547988f, 2.01267338f, 0.52447683f, 0.18379784f, 1.11913633f, -1.24273467f, 0.15803322f, 1.72184098f, -0.79349059f, 0.10258614f, -1.53445125f, 0.02630571f, 0.81649125f, 0.91089755f, -1.12968338f, 1.04016411f, 0.28999722f, 0.74863863f, -0.61388236f, 0.01665530f, 1.43592548f, 0.68138391f, 0.11963340f, -1.26123953f, 1.36340797f, 0.25696915f, -0.58877039f, 1.42209792f, 0.55563360f, -1.33329606f, 1.84695840f, 0.88433737f, 1.04359078f, 0.18906727f, -0.03448994f, 1.17944050f, 0.86783957f, 0.44934425f, -0.77892244f, -1.76232874f, -1.01689589f, 0.78943914f, 0.92141974f, -1.00187087f, -0.13809921f, -0.90222073f, 1.10094714f, -0.13657950f, -0.44349849f, -1.61441302f, 1.05724919f, 1.50337231f, -0.05785890f, -0.76958144f, -0.51498759f, 0.69227600f, -0.37975949f, 1.31949317f, 0.82049531f, 0.32868597f, -0.31557772f, -0.75534385f, 1.27303052f, 0.43453619f, 0.11296938f, 1.18182182f, 2.23387384f, -0.86412978f, -0.01599468f, -0.70869064f, -0.09221385f, -1.23729551f, 0.79490280f, 0.03522846f, -0.95069039f, -1.73461652f, 0.72329187f, 1.40385795f, -0.11585230f, -0.78033113f, 0.07491048f, -1.12873089f, 0.18476245f, 0.57568848f, -0.28792691f, 1.35411644f, -0.76956165f, 0.29571572f, 1.03178787f, -0.38780826f, 0.31680650f, 0.69368076f, -1.23856580f, -0.49848995f, 0.14766994f, 1.02625990f, 3.03858209f, -0.51030380f, 0.96796870f, 1.35078156f, -1.07729447f, 0.84322494f, 0.54886484f, 1.31453705f, -0.45792100f, 0.31196272f, -0.15701357f, 0.83586836f, -0.74952888f, -1.17432022f, -0.31002575f, -1.02149463f, -0.36117774f, -1.22079086f, 0.03532525f, 0.00555908f, -0.45891216f, 0.29636297f, -0.68272704f, 0.41257843f, 0.37988129f, 0.01747893f, 0.82739186f, 1.52292180f, -0.79456621f, 2.20275712f, 2.13212132f, -0.81393015f, -1.15712392f, 0.22488308f, 0.62776327f, -0.85444915f, 0.44017896f, 0.05863331f, -0.83198178f, 0.93063420f, -0.16121253f, 0.12382501f, -0.37826315f, 0.93118382f, 0.19507533f, -0.58595538f, 1.46994352f, 0.13170272f, -0.70031989f, -0.12820166f, 0.30487457f, 0.84148771f, -0.68807501f, 0.21187615f, -0.67030680f, -1.79136002f, 0.70810199f, -1.20959783f, -0.08468831f, -0.06317700f, 1.35527098f, -0.47018668f, -0.91693246f, 0.14818805f, -0.05405350f, 1.16875637f, -0.17363262f, -1.61833882f, -0.32934523f, -0.38346377f, -0.62702698f, 0.34135151f, 0.48015586f, -0.65263331f, -0.04689486f, 0.01156854f, 0.37580970f, -0.16174591f, 0.59627324f, 0.24351901f, -0.87983090f, 1.57049024f, 1.25836349f, -0.41464049f, -0.62279183f, 0.09693756f, -0.23850618f, -0.49007827f, 0.22298151f, 0.10914832f, -0.35192192f, -1.27221346f, 1.10203624f, -0.86399704f, -0.47319838f, -0.77105570f, -1.68624854f, 0.81198281f, 0.82534081f, 0.75654501f, 1.47631240f, -0.61000234f, -0.58933264f, 0.54822850f, -1.22829592f, 0.11107657f, 0.56449169f, 1.50693524f, -0.59280968f, -0.64286685f, -0.20120731f, 0.27184448f, 1.55500400f, -0.48919386f, 1.04044867f, -0.87048137f, -0.40569979f, 0.21908638f, -0.51829034f, -1.48748124f, 0.02990401f, 1.83462536f, 0.29885170f, 1.32370698f, -1.30129600f, 2.43271399f, 0.22967771f, -1.13014007f, 0.95529765f, -0.83325785f, 0.43633386f, 0.85774118f, 0.78160155f, 0.58583075f, 1.18906367f, -1.54354560f, -0.68320692f, 0.01900371f, -0.79777133f, 0.12851712f, 1.10176420f, 0.79418170f, -1.41154039f, 0.36929929f, 1.12176800f, 1.23849642f, -0.89377707f, 1.01390159f, -0.50889206f, -1.12554002f, 0.17932732f, 0.48949540f, -0.54235244f, -0.28146735f, -1.39125514f, 0.13309635f, -1.12864995f, -1.29901242f, -0.04266220f, -1.98028529f, -1.34869373f, 0.00038156f, -0.92473024f, 1.48010647f, -0.02754467f, -0.26030368f, 0.93083733f, 0.27946711f, 0.64052200f, -0.04220961f, 1.25002527f, -1.07923257f, 0.19048618f, 0.08900311f, -0.40813437f, -0.73068553f, 0.52122378f, 0.68990833f, -0.38749605f, -1.09269309f, -1.63480806f, 1.01789618f, -0.61596102f, 0.81049860f, 1.30838764f, -1.49213874f, -0.77916288f, -0.72660202f, -0.92013240f, -1.61726642f, -0.11527207f, 0.35143322f, -1.11646879f, -1.45525432f, -0.82892823f, 0.15512508f, 1.01891017f, 1.40162635f, 1.02494884f, 0.33882582f, -0.78747398f, -0.26009330f, -0.38519114f, 0.79247451f, 0.02065756f, -0.48030257f, 1.01167107f, -1.74057114f, -0.84549171f, -0.15337363f, -1.92544484f, 1.01270044f, 0.00762185f, -0.16405612f, 1.61778915f, 0.93316060f, -0.68960994f, -1.13214970f, -0.94695878f, -0.28418848f, 0.17102109f, -0.08787476f, -1.83799696f, -0.13761258f, -0.18652774f, 1.46456254f, 0.34169790f, -0.40697145f, 1.49663997f, -0.99555492f, -0.67775637f, -0.51951116f, 1.35157657f, -0.27099034f, -0.46987835f, 2.28101230f, 0.59104478f, 0.75010139f, 1.01472175f, 0.25741309f, -0.56074983f, 1.12267506f, 0.35336846f, 0.61733276f, -1.63976014f, -0.17700450f, -0.25093642f, -0.75599891f, 2.10956192f, 0.95155340f, 0.72049862f, 0.50492924f, 0.62067389f, 2.08688402f, -0.73604703f, 0.63383341f, -0.53528428f, -2.11538506f, -0.98173052f, 0.59560484f, -0.26205051f, -0.91948050f, 0.00593397f, -0.11734286f, -1.41261208f, -0.83611172f, -0.27682739f, -0.20619918f, -0.36557615f, 0.77194935f, 1.67695415f, -1.39265156f, 0.04892010f, -0.37773246f, 0.16124558f, -0.18348448f, -1.38248885f, 0.58459854f, 0.65064198f, 1.11349559f, 0.36708066f, -0.15471332f, 0.14208725f, -2.06860566f, 0.29629150f, 0.93084633f, -0.47215626f, 0.60208917f, 0.95415461f, 1.03390312f, -0.03639749f, -0.23988228f, 1.27037442f, 0.95133096f, 0.33187470f, -0.34527761f, 0.22134073f, 1.01799667f, -0.81475645f, -1.18869019f, 0.23314142f, 0.25180560f, -1.23762786f, 1.25283313f, 0.16980635f, 0.40740708f, 0.59256923f, 0.16274920f, -0.69713289f, -0.16444311f, -2.41602516f, 0.37952334f, -0.05604568f, -0.23772651f, 0.20581599f, -0.54303211f, 1.71877348f, 0.83602583f, -0.32586128f, 0.73609394f, -1.73640239f, 0.07249248f, 0.31248692f, 1.77627432f, 0.97660398f, -0.42095289f, -0.18750280f, -0.84246057f, 0.29762223f, 1.87054563f, -1.46980762f, -0.45306337f, 1.52366042f, 1.39061129f, -0.04980387f, -0.55382830f, -0.96987218f, -0.06910808f, -0.41276473f, -0.83891344f, -0.92597574f, 0.60252470f, 0.21938549f, -0.04451685f, -1.00330937f, -0.36955237f, -1.52876902f, 0.27296364f, -1.96721256f, 0.05291027f, -0.91540521f, 0.48990685f, -1.99560380f, -0.68551093f, -0.14532298f, -1.56881595f, -0.08319287f, 0.31003201f, -1.42829597f, -0.61810297f, -0.03581250f, 0.77747720f, 1.25297558f, -1.36239243f, -1.13274276f, -0.35045877f, -2.34157228f, 0.04515179f, -0.83044821f, 1.81353962f, -1.36855912f, 0.39704823f, 0.16665934f, -0.16654585f, 1.17806077f, 1.00086153f, -1.25474250f, -1.46876431f, 1.18021631f, -0.32257929f, 2.12062597f, 0.86819613f, -1.18048275f, -1.69747460f, -0.74092305f, 0.05086798f, 1.15339577f, 1.32972670f, 0.27247882f, 0.98499072f, 2.35597157f, 0.30179837f, -0.66633248f, 0.13794266f, -0.22753908f, -0.22868259f, -1.81792033f, 0.50151759f, -0.79408127f, -1.05343878f, 0.45727381f, 0.84800923f, -1.73605800f, -0.02032863f, 1.82778001f, 1.41025102f, -0.81715560f, 0.25888795f, -0.25075480f, 0.66256499f, 0.11993053f, 1.81336939f, -0.06345166f, -1.49658346f, 0.07531686f, 0.96972889f, 0.87405980f, 0.75830793f, -0.13497087f, -2.45855975f, -0.65984958f, 0.93919373f, -0.97305542f, 0.73477978f, 1.04337513f, -1.22712576f, -0.46385625f, -1.20876372f, -0.82760453f, 0.01455977f, -1.05089867f, -0.02801843f, 0.60899758f, -0.82052249f, -1.48932517f, -0.98073828f, -0.19311285f, -0.25602359f, 0.50351876f, -1.24557400f, -0.82138073f, -1.45966852f, 0.44991320f, -0.75550151f, -0.98550314f, -1.21418869f, -1.15771639f, -1.72192061f, -0.39616469f, -0.55566746f, -1.31880891f, -0.08843257f, 1.00422776f, 0.35846478f, 0.46060917f, 0.77326930f, 1.60129988f, -1.85124147f, -0.30582917f, 1.30227256f, 1.81890345f, -0.44084981f, 0.25315762f, 0.70259613f, -0.94882858f, 1.97040296f, 0.71473581f, -0.68193883f, -0.36290962f, 1.16348684f, 0.15418798f, 1.07806778f, 0.40554729f, 0.10280909f, -1.06474805f, 0.64398485f, -0.63568884f, -0.06108581f, -1.03290677f, 1.02834034f, 1.15284693f, 0.14046004f, 1.86630619f, 0.46804786f, -0.68397558f, 1.60733378f, -1.64890087f, -1.03819239f, -1.19212389f, -0.78382361f, 0.03925850f, 1.52259934f, 0.09540676f, -0.21220762f, 0.55955195f, -0.39845437f, -2.14541650f, 0.49337825f, -0.68574250f, 0.74040270f, 0.50783634f, -1.60461199f, -1.26806450f, -0.12652303f, -0.83992827f, -0.15524681f, 0.40098447f, 0.23392735f, -0.23262636f, 0.06525709f, -0.35994548f, -1.08432877f, -0.21395946f, -0.78357452f, -0.57157278f, 0.71407390f, 0.86596155f, -1.13723528f, 0.13460183f, -1.20881450f, 0.71018457f, 0.68943661f, -0.70428050f, 0.64600736f, 0.01990297f, -0.10575775f, -0.80263519f, 0.10618331f, 0.08865548f, 1.51651669f, 0.60851854f, 1.15161908f, 1.04919207f, 1.18359745f, -0.04352076f, -0.83643389f, -0.07922365f, 0.10597949f, -1.34984851f, -1.91319740f, 0.71585363f, -2.10845160f, 0.64385056f, -0.54551518f, -1.02039802f, -1.62510490f, 1.65401149f, -0.42711899f, 0.07970079f, -0.21404363f, 0.30498922f, 1.07942021f, 0.63995659f, -1.82114816f, 0.56396323f, 1.07084870f, -2.00350380f, 0.53339815f, 0.18500003f, 1.15034151f, -0.21436051f, -0.99986565f, -0.58812016f, -0.07247020f, 0.78910017f, 0.48839527f, 0.98795873f, 0.10357288f, -0.05604928f, 0.38977858f, 0.73745090f, 1.40838420f, 0.25967824f, 0.23588051f, -0.03451392f, 1.04897523f, -1.77121758f, 2.35625434f, -0.67086869f, -0.84005541f, -0.85940343f, -1.04449213f, -0.65917015f, -0.78713167f, -0.95910054f, 0.38597879f, -0.31879017f, -0.86260867f, -1.08593106f, 0.02802678f, 0.99484950f, -0.55113328f, 2.60936737f, -0.03388772f, -0.47583574f, -0.14021793f, 0.99019170f, -1.22431207f, 0.78734446f, -1.77037835f, 0.15018673f, 0.36423206f, 1.36447549f, -1.61007094f, 0.51875496f, -1.60788095f, -1.73557448f, -0.41414359f, -0.93710536f, 0.38715765f, 0.04243837f, -1.59682858f, -1.10728157f, 1.88292623f, -1.01428258f, 0.01074958f, -1.88169158f, -0.31616244f, 0.45334938f, 1.12449574f, -1.16699445f, -1.59505820f, 0.04126552f, -0.89016622f, 0.45838884f, 0.71463561f, 0.14563711f, 0.30694655f, 0.67193079f, 0.61429602f, 1.00201404f, -0.49295208f, 0.05997690f, 0.99491668f, -0.73801446f, -1.17185295f, 0.94778723f, 0.36106884f, -0.43561545f, 0.04102699f, 0.52626407f, 0.08442099f, -1.57626402f, 1.56855237f, -1.65396678f, 1.74014664f, -0.38219589f, 0.39305371f, -0.31705827f, -1.15742850f, 0.11669596f, 0.54043210f, -0.52270615f, -0.13375773f, 0.68094701f, -1.84134769f, -1.49383473f, 0.14632171f, -0.54607725f, -1.20867658f, -1.28439069f, -1.81734920f, 1.54257309f, 0.78347659f, -0.24049839f, 1.69973648f, 0.99825776f, 0.99971974f, -0.26055810f, 0.34143049f, -0.44862366f, 0.11253342f, -0.60932243f, 0.70383030f, -1.87318194f, 0.21953633f, 0.82791799f, 1.64545465f, -0.42693698f, -0.64897031f, -0.97996652f, -1.06616282f, 0.52939081f, -0.12541170f, -0.57480675f, 0.73600835f, 0.35711968f, -0.03528263f, 0.79997194f, 0.55742902f, -0.28909785f, 0.64331138f, -1.79893720f, 1.01572442f, 0.27111965f, -0.51778597f, 0.12906317f, 0.76148927f, 1.51315522f, 0.41101140f, 0.38008851f, 0.66759896f, -0.13804778f, 0.64854795f, 1.73474562f, 0.75999504f, -0.73411214f, -0.05406699f, 1.35664344f, -0.25298578f, -0.12696666f, -0.42628938f, 0.61129904f, 1.55259824f, -0.05820796f, -0.38598019f, -0.87325627f, -0.55066222f, -1.24557889f, -0.26509118f, -0.32103062f, 1.14031804f, -0.75985742f, 0.70659167f, -1.15016067f, 1.24906838f, 0.90396994f, -0.16241251f, 0.43682271f, -1.42695689f, 0.47134697f, -1.66143429f, 0.08698819f, -1.00775325f, -2.24129725f, -1.04226267f, -0.98537570f, -0.89938259f, -1.80710697f, -1.22866321f, 0.78125423f, 1.55150509f, 0.46235040f, 0.18444096f, 0.19313288f, -2.20686269f, -0.40341458f, 0.50321484f, 0.47339424f, -0.81383848f, -0.21972439f, 0.66612029f, 0.60239881f, 1.20443010f, 0.70015103f, 0.30632916f, 0.01489905f, 0.68129027f, -0.89645082f, -2.68969011f, -0.96684915f, 1.66421318f, 0.74333072f, -0.78321886f, 1.60063362f, -1.27524030f, -1.95856726f, 0.47504124f, 0.15398432f, -0.20796098f, -0.13449343f, 0.93458968f, 1.60390890f, 0.21798505f, -0.27035928f, -1.23248971f, -1.25361061f, 1.34666133f, 1.07233441f, 0.88799530f, -1.23687923f, -0.40781614f, -0.11916534f, -0.88050151f, -0.66422415f, -2.61471510f, 0.78276747f, 2.42323995f, -1.70715427f, 0.71550035f, -0.60298312f, 0.70491880f, 0.46175584f, 0.80827898f, -0.45108104f, -0.98219043f, -1.72823501f, 1.73190725f, 0.53906441f, -1.50445580f, -0.59250867f, -0.07239901f, 0.44743437f, -0.13740127f, 1.69935930f, -1.00480616f, -0.58191377f, 0.39853972f, -0.60960841f, -0.45473522f, -0.76396072f, -0.31872150f, 1.74509728f, -0.59950751f, 0.89810580f, -0.81400329f, 1.14280319f, 1.11165059f, -1.31295311f, -1.60784578f, -0.87506992f, -1.13461006f, -2.09486437f, -0.16449419f, -0.37728927f, 0.47595578f, -0.55342919f, -0.17574213f, 2.21499181f, 1.14331865f, -0.14938518f, 0.18935619f, -0.33802557f, 0.52538890f, 0.82673949f, 1.16562462f, 1.24713838f, 0.98890215f, -0.64991701f, 1.49886703f, 1.97769642f, 0.08059916f, -1.60925281f, -1.23822486f, -1.40829837f, 0.51331180f, -0.29928651f, -1.04348791f, -0.39911583f, 0.69380492f, 1.54516888f, 1.22791195f, 2.25008130f, 1.33348894f, -0.21775827f, -0.71937007f, 0.54982573f, 1.70691478f, 0.32459491f, -0.57187974f, -0.21614684f, 1.08274269f, 0.41384646f, 0.24497485f, -1.43703413f, 0.89616930f, 0.82032162f, -0.24598582f, 0.84271127f, -0.81894702f, -0.01828136f, 1.70397091f, 0.39505738f, -0.51221430f, -0.87979966f, 0.10795479f, 0.45194778f, -0.76008922f, 1.23394477f, -0.56798172f, 1.06459570f, -0.44333413f, -2.40399075f, -0.37267187f, 1.42946172f, 0.95734519f, 1.86127949f, -0.15217264f, 1.68742633f, 1.97638428f, -0.44211119f, -0.98393327f, -0.54173928f, -1.72017395f, 0.74697793f, -1.77827263f, -1.92299354f, -0.17189410f, -0.48633271f, -2.21230388f, -0.45906609f, -0.53493047f, 0.37253976f, -0.56951141f, 0.07728028f, 0.03530006f, -1.18123293f, 1.94158125f, -1.55930352f, 0.69334733f, -1.95163214f, -0.95800400f, -0.01804711f, -0.56747472f, -0.99099451f, -1.52853060f, -0.98279524f, -1.67307866f, 0.96121490f, 0.35654056f, 1.74034202f, -1.44633865f, -0.27781928f, 1.79457986f, -0.41029963f, -0.76871634f, 0.36555341f, -0.77664107f, 0.19535238f, -0.76185411f, -0.19828433f, -0.88820636f, 0.63885397f, 0.11346363f, -2.50265074f, 0.16319332f, -1.01288569f, 1.86605489f, 0.89761645f, 1.11795115f, -0.00714116f, -0.89034635f, -0.76447034f, -0.18822117f, -0.48340848f, -0.99788517f, 1.02172959f, -0.39395007f, 0.72566581f, -0.81438208f, -0.71715081f, 0.96243578f, -1.36424279f, -1.13870537f, 1.17602491f, 0.16320205f, 0.71959788f, 1.66669416f, 0.55690295f, -0.28912008f, -1.19219172f, 0.23308393f, -0.37963116f, 0.45347008f, -0.42606446f, 1.30938649f, 1.25128853f, 0.57649273f, 0.34440875f, -0.23893952f, -1.06604803f, 0.31336102f, 0.75727910f, 0.46772480f, -0.37650385f, -0.06036821f, 1.03686309f, 0.46158856f, -1.81028461f, 1.43393028f, 0.85494965f, -2.34685564f, -0.17571987f, -0.45592231f, -1.31190526f, 1.73194158f, -0.11856517f, 0.07041293f, 0.25689471f, -0.56000596f, 2.06649089f, 0.38954756f, 1.36627376f, 0.13905638f, 0.77370811f, 0.43944249f, -0.08798827f, 0.07245751f, -1.30234015f, 0.29710820f, 0.74389762f, 0.11971968f, -0.07381748f, 1.32652700f, 1.34079397f}); - - auto input2 = NDArrayFactory::create('c', {3, 4, 4, 5}, {0.98114507f, 0.96400015f, 0.58669623f, 0.60073098f, 0.75425418f, 0.44258752f, 0.76373084f, 0.96593234f, 0.34067846f, 0.57962620f, 0.77517051f, 0.97472977f, 0.79237527f, 0.68690428f, 0.21719366f, 0.79959206f, 0.84814187f, 0.22496814f, 0.08646965f, 0.31110474f, 0.79813162f, 0.19661444f, 0.57760099f, 0.72138960f, 0.15244268f, 0.87687051f, 0.11130344f, 0.01087698f, 0.34817841f, 0.54992017f, 0.23443850f, 0.31725614f, 0.59755220f, 0.20364695f, 0.00531392f, 0.23403114f, 0.07442912f, 0.83707647f, 0.89291743f, 0.09044587f, 0.69041462f, 0.29904183f, 0.61904680f, 0.85306847f, 0.34467042f, 0.95839152f, 0.54517124f, 0.29640937f, 0.94855959f, 0.95970016f, 0.94045145f, 0.95510301f, 0.34666505f, 0.34717010f, 0.69245678f, 0.71669175f, 0.59043738f, 0.64924132f, 0.06033522f, 0.60185199f, 0.04690073f, 0.59241154f, 0.40229547f, 0.23002481f, 0.45161195f, 0.73743778f, 0.93209113f, 0.37294358f, 0.50177744f, 0.15072501f, 0.26146917f, 0.05252146f, 0.04758931f, 0.76448288f, 0.85149045f, 0.08840467f, 0.07692576f, 0.33180160f, 0.27241259f, 0.74834620f, 0.56453640f, 0.23057286f, 0.68429752f, 0.11961551f, 0.39045977f, 0.44356094f, 0.77018807f, 0.07984410f, 0.47926806f, 0.26165759f, 0.18606064f, 0.89972877f, 0.17962874f, 0.47273120f, 0.64641705f, 0.61890443f, 0.58730015f, 0.25937832f, 0.35231561f, 0.10243882f, 0.17459193f, 0.95906995f, 0.09227025f, 0.30003223f, 0.41601210f, 0.38269713f, 0.84799751f, 0.59295173f, 0.76277990f, 0.68910424f, 0.37672606f, 0.40675461f, 0.94346058f, 0.91438505f, 0.84728183f, 0.64367667f, 0.74899979f, 0.60570691f, 0.16417363f, 0.68852426f, 0.85486889f, 0.22585792f, 0.86953176f, 0.07465519f, 0.93096301f, 0.38008822f, 0.38752587f, 0.44004038f, 0.13170612f, 0.94541045f, 0.89349973f, 0.69245307f, 0.94978877f, 0.98776658f, 0.79445884f, 0.30607409f, 0.58264961f, 0.37980538f, 0.41810784f, 0.48903038f, 0.51615888f, 0.57682794f, 0.82481897f, 0.78341080f, 0.48446465f, 0.17447931f, 0.71125424f, 0.30263851f, 0.70675352f, 0.03215584f, 0.92381065f, 0.22343694f, 0.08851149f, 0.91402490f, 0.70074717f, 0.30912192f, 0.37723206f, 0.97579397f, 0.23554587f, 0.95939133f, 0.41565709f, 0.01741416f, 0.58362787f, 0.22106662f, 0.89065537f, 0.31900249f, 0.41280911f, 0.67947610f, 0.04545590f, 0.15352812f, 0.85412524f, 0.84933222f, 0.80000225f, 0.93147073f, 0.70094105f, 0.69269875f, 0.95282194f, 0.65913582f, 0.79186874f, 0.59855248f, 0.39707430f, 0.95126239f, 0.15618217f, 0.33446689f, 0.98123758f, 0.84770758f, 0.98081012f, 0.54427413f, 0.18728519f, 0.89792955f, 0.53360126f, 0.72812986f, 0.13307744f, 0.51217443f, 0.66708084f, 0.29416915f, 0.31298995f, 0.39155037f, 0.29288291f, 0.87063305f, 0.61759154f, 0.73723332f, 0.37167635f, 0.82122716f, 0.22937430f, 0.76570536f, 0.47911792f, 0.02826214f, 0.94277323f, 0.59945469f, 0.19042060f, 0.68173155f, 0.82771295f, 0.95649538f, 0.40833101f, 0.90838542f, 0.55245881f, 0.49011012f, 0.36773444f, 0.34513527f, 0.42050683f, 0.16113964f, 0.30969388f, 0.27174174f, 0.12117655f, 0.35270175f, 0.81967867f, 0.63723136f, 0.84309389f, 0.71822576f, 0.84883484f, 0.32306117f, 0.08176457f, 0.56175486f, 0.34892198f, 0.09306929f, 0.85437582f, 0.13925577f, 0.48629188f, 0.29923539f}); - auto exp = NDArrayFactory::create('c', {3, 8, 8, 16}, {5.98743296f, -2.83037376f, -0.87943113f, 1.41339970f, 1.32433391f, -1.20299149f, -0.02893090f, 2.05326009f, 1.19417048f, 5.58212376f, 3.28139353f, 1.19237995f, -1.09431255f, -2.55264497f, 3.11014652f, 6.81296825f, -2.09029293f, -4.32068443f, -0.52808392f, -1.97968531f, -0.18673831f, 0.84605980f, 4.55825520f, 2.71503139f, 0.15210046f, 0.85310984f, -3.82062817f, 2.76470995f, 3.69004202f, -1.45017099f, -2.59361267f, -1.35094655f, 7.24145126f, -5.25432396f, 0.19920218f, -4.30596399f, 1.35318923f, -3.88142037f, 3.67493343f, 2.25931478f, 2.87630725f, 1.66349852f, 6.21347952f, 0.94105923f, -1.61742055f, -2.35699606f, 0.12850338f, 1.79141688f, -2.09535933f, -6.35418081f, -0.06303531f, -4.38615131f, 0.48237842f, 0.26528549f, 3.38231516f, 3.76315165f, -0.40254810f, -0.23716694f, -6.13381910f, -0.41950428f, -0.89680839f, -1.46491277f, -1.98541689f, -0.99357355f, 5.58237648f, -2.38937521f, -0.00872564f, -2.37138414f, 4.91117287f, -4.51916361f, 0.97943687f, 2.91052818f, -2.50362611f, 1.70252812f, 5.04137802f, 3.57108784f, -1.87532270f, -3.66677809f, -2.38861251f, 5.55765152f, -7.27571774f, -1.68887305f, -0.72266489f, -4.42809057f, -0.92118186f, 1.02381468f, 4.44284725f, 5.17150497f, -0.42438728f, 2.02693963f, -1.36484981f, -1.47912180f, 0.26649538f, -0.02091765f, -2.86906910f, -3.03046989f, 1.35122132f, -3.21707630f, 2.21112418f, 0.24121630f, 3.96940088f, -7.66105747f, 2.76352382f, -0.99061489f, -2.16720009f, -1.63170409f, 1.12701774f, -1.02415371f, -0.90435314f, -1.51372027f, -0.76884907f, 0.39066136f, -0.89562428f, -2.03204703f, 1.28074932f, -2.14551091f, -2.36843777f, 0.46580017f, 0.75451565f, -0.00336730f, -1.06597757f, 3.27195978f, -0.41307712f, -0.10376054f, -1.34102952f, -2.22901654f, 2.31929803f, 1.40851438f, -2.23774385f, 0.20417206f, -1.12153268f, -0.13188094f, -3.96649432f, 2.10269976f, 0.49845099f, 6.18937683f, -0.51783508f, -0.48048639f, -1.92970264f, 3.16670656f, 1.13355756f, -0.07890664f, 1.31536257f, -0.43924797f, -0.04562932f, -0.87974954f, 0.75411212f, -2.39745235f, -3.97132111f, 0.37202546f, -2.40399146f, -1.50796390f, -3.08302689f, 0.23075986f, -0.94316757f, 1.34948587f, 0.58591264f, 2.18529797f, 7.97652435f, 2.32798409f, -4.09404373f, 0.89634895f, 0.77697754f, -0.65091681f, -7.05506849f, 5.86194515f, 2.51394033f, 4.69959354f, 0.20835471f, 3.18049693f, -1.29682434f, 3.70832396f, -0.48123091f, -1.67904007f, -1.35418940f, 1.58435583f, -1.13851106f, -1.19225955f, 0.59713769f, -5.80462933f, -7.45143986f, -1.08658695f, 1.03244078f, -1.75307107f, -7.07100582f, 3.85825157f, 1.62127817f, 2.32572675f, 0.56171900f, -0.80591971f, 3.98835945f, 0.15742642f, -2.97832179f, 0.13821673f, -0.72556758f, -0.84936106f, -7.28444147f, 3.94134307f, 0.80779338f, 7.47784615f, 8.23335075f, 4.80595016f, -4.89574575f, 4.03362942f, -6.67522192f, -4.55204487f, 2.12511182f, -2.70781207f, -1.57226098f, -3.08408356f, -0.30812448f, -5.32870674f, -5.13238287f, 0.49605465f, -0.55042171f, 0.46324944f, -3.83545256f, -0.12562510f, -0.20978995f, -0.13068712f, -1.92144060f, -1.68787408f, 5.45581436f, -0.79583496f, -2.38866687f, -3.90546346f, -0.47028148f, -0.14319679f, -3.37016582f, 2.00905991f, -1.21345615f, 1.81376505f, 7.73004007f, 0.74310112f, -4.64536428f, 3.78111577f, -9.05182457f, -0.10674095f, 1.53476238f, 0.63345337f, -0.40907967f, -1.44729769f, -1.87145400f, -2.46623540f, 1.07472968f, 0.77390999f, -3.93438888f, 4.49174690f, -0.96686655f, 1.92278123f, 0.30049133f, -0.02388665f, -1.99777114f, -3.23885751f, 5.87784004f, 2.13776040f, 3.56758308f, -3.37774134f, -3.67526293f, 1.63700044f, -1.69959962f, -0.99112594f, 6.03103638f, 1.67399430f, -1.28699589f, 7.16759014f, 12.63490295f, 3.62937450f, -4.75982571f, 2.17861104f, -2.03065681f, 4.30207729f, -0.46797156f, -2.96022511f, -6.02702332f, 3.09229851f, -1.39771092f, -0.03471333f, 3.22175527f, 5.63565636f, 1.78195477f, -0.63545251f, -3.99497652f, 1.46043062f, 4.60050488f, -2.96651959f, -2.03159475f, -1.52386189f, -0.15129802f, -3.90390921f, -0.63852370f, 0.79210538f, 2.35288715f, -5.55609035f, 5.36427498f, -0.60248077f, -0.26181316f, 5.04884720f, 8.53192806f, 5.05080223f, -6.56371737f, 1.52260923f, -7.13623667f, 6.49414349f, 2.33445597f, -4.11490965f, -6.44347477f, -0.47079402f, -0.63467920f, 2.60399365f, 1.05958164f, 3.66901422f, -1.05657935f, 1.88611507f, -6.37475634f, 2.01480770f, 3.36020517f, -5.11001921f, -0.46132171f, 2.16525555f, 4.21938848f, -2.08346295f, 2.86168146f, 1.26987600f, 6.76066971f, -7.84916353f, 4.11700916f, 0.47985530f, -4.60113716f, 7.42062473f, 6.37472820f, 4.37820530f, -7.12197018f, 0.01357239f, -7.90392113f, 8.32131577f, -0.87593079f, -0.16994858f, -5.86345863f, -0.20697471f, -1.37845206f, 1.63819647f, 1.59720242f, -0.74357712f, -1.88725603f, -1.98357940f, -8.57950306f, -4.10104513f, 3.57231879f, -2.89855957f, -0.11263305f, 2.78033924f, 1.53078973f, -2.93089223f, 0.73189604f, 3.20563078f, 3.92601013f, -5.21916151f, 0.89163935f, -0.42978728f, -6.70888853f, 4.56477976f, 1.20105875f, 3.83393812f, -6.27205181f, 4.05993128f, -7.35513067f, 1.60660768f, -1.21052051f, 1.58191252f, -1.37899971f, -1.20117283f, 2.93301678f, 1.06302834f, 1.38993621f, -1.66884089f, -3.34452581f, 1.04498529f, -4.10412455f, -4.03310585f, 1.61513603f, -1.09388447f, 2.11451387f, -0.94192362f, -0.23287666f, 5.88265705f, -0.83010495f, -2.15317154f, -0.60276151f, -1.49265075f, 3.93397975f, 5.45194483f, 1.45161700f, -2.57401872f, -5.59288931f, 4.29170895f, 1.87151814f, 0.08362055f, -0.28767288f, 1.17675185f, 0.85266006f, 1.30549634f, -5.60830832f, 0.19398519f, -0.83982587f, 1.75940764f, -5.46077394f, 1.64495635f, 0.17102760f, -0.54459631f, -2.21975255f, -0.37443402f, -2.08474159f, 1.85959935f, 11.19680309f, -0.18611598f, -2.59765387f, 3.06330776f, -1.52183700f, -4.88415241f, -0.75097847f, 2.58201051f, 7.40885210f, 3.58994508f, 1.62457407f, 3.12514591f, -4.36833286f, 1.39830995f, 3.61003447f, -0.63837433f, -3.62661815f, 3.78898096f, 2.92802262f, 5.87374496f, -4.38554621f, -2.53411579f, -2.87311554f, -1.31391978f, -4.26736879f, 3.45099425f, 1.58769250f, 1.73341393f, -1.08842182f, 2.27120280f, -1.78938174f, -2.29940319f, 7.07046986f, 0.51426595f, -6.22928905f, 5.28968811f, 2.31827855f, -4.20915890f, -1.27249205f, 5.92120600f, 3.19458675f, 7.09252501f, 3.96577907f, 6.41484213f, -4.66009521f, 10.00181389f, 0.51108456f, -4.62243366f, -5.18351841f, 2.12961674f, 5.10694027f, 7.29412317f, 0.15912467f, -3.38902974f, -4.01918602f, -2.17383957f, 0.13118666f, 0.27872476f, -0.92317247f, 3.51440644f, 1.84171486f, 1.03378081f, 1.30569839f, -2.09583759f, 9.03952980f, -0.55187917f, -2.04549074f, 1.08294606f, -2.65263700f, -2.93977118f, 1.88909876f, 0.96043622f, 1.76579499f, 3.14314699f, 5.86394691f, 7.36944389f, -7.04524136f, 6.68673229f, -5.52591467f, -2.19745898f, -4.32036924f, 0.52971321f, 2.26268244f, 6.91575766f, -0.94590527f, -3.98923349f, -0.12266219f, 0.24294075f, -1.07783222f, 1.87989080f, -3.57109427f, 1.61553633f, 0.42486978f, 0.75852054f, -6.19481468f, -3.80570698f, 2.39946675f, -1.93851781f, -5.42234039f, -6.34092760f, -2.52374983f, -1.85044456f, 3.92693520f, 0.40042299f, 4.69742584f, 5.40483189f, -1.02398944f, 8.89605045f, 0.64680403f, 0.89943957f, 0.76993859f, -1.88244629f, 1.90714884f, 3.10836840f, -0.17064989f, 0.84892416f, -6.94988108f, 1.92141032f, -1.36458397f, 6.39284658f, 0.45201308f, 2.58823442f, 6.33375788f, -4.76916075f, -8.45738983f, -0.48962492f, 2.40652561f, 4.56602001f, -3.34420681f, 1.86862195f, -7.01420689f, -6.94657421f, -2.47419310f, -4.61693668f, -0.18822384f, -0.36949772f, 2.01374269f, 4.11018658f, -5.11564064f, 8.04294395f, 2.88567662f, -2.87645102f, -1.23238611f, -5.91409397f, -0.62205851f, 1.38689423f, -0.01120412f, 5.25955677f, -1.98474956f, -3.72012186f, 3.00445986f, 4.99141550f, 2.97457719f, 2.70827627f, 6.04544449f, -0.20756161f, -10.87035751f, 0.80454814f, 0.33568168f, -2.48132324f, -2.84452009f, 2.63126230f, -3.99351716f, -7.39294338f, 3.62798953f, -8.65815926f, 2.65992808f, -6.98126554f, 3.09881067f, 0.67735767f, -1.15946686f, 5.63180256f, -0.17694545f, -8.59651184f, 3.75297594f, -2.35913754f, -0.20330384f, 5.49958467f, 1.00861740f, 1.42849684f, 0.00062013f, -0.11073381f, 2.15207863f, 4.07368469f, 1.14344299f, -1.27953362f, 6.64699316f, -0.73672432f, -8.55606937f, -0.19439441f, -4.14319754f, -4.69964647f, -5.86446047f, 2.87106085f, -3.42714882f, -5.00668287f, 6.22464132f, -7.72335291f, 4.05667686f, -5.72637177f, 6.35073948f, -1.29593158f, 0.00813985f, 3.63368607f, -1.05764008f, -7.88486052f, 3.73919106f, 1.41835213f, -1.04935634f, 0.65119827f, 0.03547254f, 1.88996327f, 1.58701086f, -0.56215239f, -0.80187100f, 4.55604362f, -0.67249978f, 1.41084409f, 7.86281586f, -2.38301182f, -8.50535774f, -3.82098866f, -2.40856767f, -5.33439016f, -3.34747362f, 2.69389009f, -1.64118791f, 4.52447939f, 0.04468334f, -1.48768258f, -0.69848812f, -0.71123981f, 3.66259432f, 6.10314512f, 1.37305343f, -0.62758982f, -2.99383426f, 4.20510864f, 1.48497128f, -0.08954811f, 2.43872309f, -0.59880185f, 0.37431365f, 2.45458341f, -3.28401661f, -1.94629693f, -1.93975246f, -0.26385683f, -0.45814323f, -0.18108580f, -3.74811840f, -0.29739976f, -2.24116230f, -0.28150487f, -2.24421668f, 3.46930790f, 8.35415077f, 0.05562943f, -2.81079793f, 1.10388446f, -2.82245207f, -2.98102283f, -1.08132946f, 1.19089699f, 8.00183105f, 6.35385323f, 3.72591257f, 4.59467506f, -5.74890900f, 4.42238331f, -3.36533451f, 0.18350232f, 3.05606651f, 1.18788099f, 2.87450886f, 0.27472210f, -2.80111074f, -0.66314960f, -1.96376896f, 0.75167024f, -4.72056293f, 1.10629988f, -5.00775242f, 1.48246133f, -3.91681528f, -1.86573625f, -6.17714882f, -0.67820001f, 5.69730282f, 1.04399037f, -4.93794823f, 3.09619617f, 2.18692017f, -5.54232264f, -3.10046840f, -0.68972743f, 2.81824327f, 3.04334164f, 6.13203907f, 4.14081764f, 1.02573645f, 5.71970081f, -6.01574707f, -2.07346702f, 0.99554527f, 1.69641590f, 0.66776669f, -0.80132431f, -2.03513098f, -3.42513680f, -0.06704485f, -1.87195873f, -5.42428589f, -0.20748445f, -1.52408111f, 0.97084987f, -0.48799962f, -0.45379883f, -0.26652339f, -1.20720732f, 3.94169855f, -3.18480229f, -1.87440264f, -1.18028760f, 0.52011997f, -2.13437462f, -4.52583313f, 1.69722807f, -0.89371562f, 3.37972403f, 6.38838720f, 6.98663378f, -4.05421400f, 6.89512825f, -5.09085655f, -2.16257906f, -3.33272719f, -3.01246452f, 0.37613097f, 1.80455804f, -0.36456174f, -5.32273912f, -1.29978943f, -0.53685790f, -2.12896323f, 2.55506587f, -2.57999182f, 3.40891910f, 1.36033249f, 0.83864629f, -2.88629293f, -7.36048365f, 5.61314154f, 1.32668555f, -2.58041072f, -3.71943092f, 1.60647738f, -2.74816346f, 2.47269106f, 0.85507953f, 8.39183426f, 3.42624784f, -0.01519036f, 5.68412066f, 2.51771593f, 1.03045523f, -2.08733034f, -2.44337177f, 0.81668580f, 1.30275154f, 2.99679208f, -2.91957355f, -1.71337795f, 3.34979844f, 1.51825011f, 5.20375061f, 2.27888370f, 1.38787699f, 4.23474550f, -4.05878592f, -4.85074377f, -0.22794735f, 4.64402294f, 1.24391258f, -2.04935098f, 1.26285601f, -7.51862240f, 0.62138438f, -1.95792389f, -0.96587181f, 0.85141110f, 0.79354531f, 7.93766356f, 6.07677746f, 2.05947518f, 6.55480623f, 1.44032848f, -0.70615625f, -0.07896036f, -5.08359432f, -0.01047915f, -1.89632201f, 2.57555676f, 3.83779287f, 0.42850614f, 1.80754125f, -0.06942326f, 6.35997963f, 6.06101418f, -0.97032297f, 5.71477222f, -6.06671238f, -3.46607208f, -4.98306370f, 2.84659123f, -2.11025190f, -0.04609144f, 5.26831341f, -9.56940651f, -3.67193556f, -1.71143103f, -1.35221267f, -4.26226807f, -6.89146233f, 8.21761799f, 5.69823503f, 2.28137946f, 1.88911343f, -1.44562483f, -1.60295713f, -0.52568185f, -3.31892347f, -2.81997776f, 0.35287106f, 2.98202395f, -1.39432132f, -2.70001364f, -4.14169264f, 3.50194883f, 4.12610435f, 5.52755260f, 2.65859175f, 3.61353087f, -0.83027136f, -5.10652542f, -4.48625374f, 2.06585884f, -2.76383352f, -0.64300913f, 8.19686604f, 0.96106279f, 2.45952058f, 2.47275925f, -1.03288829f, -0.64897656f, -3.77937531f, 4.27940083f, 2.58320260f, -0.57665241f, 1.87247813f, -3.81604433f, -0.24543774f, -1.62118483f, -0.73075479f, -0.48533297f, 2.05016756f, 0.45561486f, 0.03316188f, 0.77791005f, -1.56283605f, 2.36616826f, 5.58082104f, -1.30925488f, -1.06329608f, 2.17189479f, -3.43008828f, -4.71520567f, -2.56184673f, 0.17508316f, -3.25817418f, -0.41749167f, 0.18119079f, -0.73181152f, 3.99792433f, -3.08002281f, -0.99143314f, -1.83520067f, 1.18565679f, 2.98040128f, 5.67814350f, 2.35128760f, 1.41600966f, 4.02718067f, -0.08193968f, 0.64636409f, 1.35931289f, 2.37125754f, 1.75978124f, 3.90977740f, 1.50662971f, -2.84089065f, 1.29824126f, -3.38730979f, -1.61005294f, 0.58292413f, -0.03019404f, -1.57986510f, -0.56102908f, -3.03128719f, 0.51644313f, -2.01147819f, 0.98400700f, 3.00028515f, 0.74579155f, -3.37098312f, 0.93339360f, -1.29018497f, -2.14695001f, 1.30411184f, 0.71501279f, 7.47793055f, 4.06516457f, 3.50772929f, 3.52762985f, 0.55643129f, 0.32272506f, -4.30955982f, 2.49414706f, 2.07820845f, -0.34377906f, 4.39805031f, 2.77561307f, -3.91292810f, 2.43981409f, 0.18861845f, -2.76658440f, -4.97148752f, 3.25273705f, -0.08929539f, 0.19818619f, -5.83767605f, -0.97381884f, -5.68745661f, -5.42433214f, 3.98769903f, -0.40394354f, -1.83387578f, -0.80109525f, 1.47454357f, -3.14899540f, 0.80130816f, -2.26348829f, 4.06121159f, 6.13077354f, 5.31226397f, 2.94966197f, -3.65217376f, -1.08136678f, -7.14119816f, -0.85269439f, -0.70365787f, -0.81598872f, 3.62807679f, 3.08123684f, -7.82739496f, 4.07951784f, -0.14204243f, -0.66969109f, -5.07225513f, 2.88492823f, 0.47202343f, 0.72683257f, -6.84280777f, 0.41807127f, -5.09785986f, -3.74514675f, 2.03936672f, -1.06096244f, -1.52409148f, -0.97046643f, 2.27491093f, -1.55597985f, -1.29215479f, -0.79737484f, -0.01979581f, 7.65407991f, 5.54527044f, 4.04147148f, -2.64274883f, -1.89246953f, -3.89547634f, -1.06029689f, -2.85982800f, -1.41247237f, 1.55836034f, 3.38194537f, -2.97655582f, 0.87510300f, 1.26282072f, -1.77029657f, -3.57144690f, -4.19456863f, 0.53179169f, -1.42221975f, -3.09144497f, -0.84294832f, -5.02758694f, -2.68011904f, 0.89156240f, -0.34783912f, 4.64484835f, -2.34453487f, -1.28573155f, 0.09990287f, 0.01828218f, -1.79960847f, -1.06579173f, 1.08763921f, 0.43687880f, 3.24747229f, 3.83097172f, 1.07253766f, -1.33810723f, 0.76530832f, 1.58660865f, 5.60743904f, -3.54124737f, -0.89264417f, -3.83942485f, -1.03707337f, -1.61659896f, 1.65349591f, 1.72698796f, 4.96013832f, 0.78927267f, -0.35563886f, -3.48121166f, 3.79677629f, 2.59023166f, 2.74940348f, -2.17589283f, -5.91757107f, 2.43766379f, -4.15906048f, -1.74731481f, -2.49113035f, -0.57349741f, -4.04455185f, -1.46939647f, 2.21418452f, 0.09153593f, 2.23016739f, 7.91880608f, 4.04464149f, 0.07706618f, -2.41892862f, -2.19280314f, 7.61760712f, -5.89153862f, 0.33551922f, -1.70855618f, -0.30561331f, -0.14341974f, -2.48878574f, 1.31269515f, 3.45388412f, -0.02453184f, -0.12132037f, -4.27916241f, 1.25179088f, 4.09455204f, -1.83801770f, -1.86743176f, -4.02864933f, 3.44515228f, -4.39244986f, -0.56988084f, -1.69426417f, 2.18254852f, -4.78135824f, 1.73193693f, -2.27968478f, -1.49523509f, 2.51696730f, 4.03677559f, -2.03679037f, 1.32167840f, -2.22570705f, -2.74843621f, 6.29655170f, -3.67230225f, -1.86765468f, -0.14842367f, -1.21552539f, -0.92038238f, -0.51692355f, 1.08433771f, -0.01929832f, 0.15660909f, 2.31432915f, -3.86507082f, -0.69797570f, 0.13505173f, -1.50951028f, -0.69980979f, -1.51297045f, 3.63725281f, 0.13388813f, 2.73131752f, -0.96528149f, 4.92000961f, -5.92699385f, 1.69444644f, -1.17121375f, -2.33710480f, 1.35302818f, 1.39608085f, 1.68293881f, 0.94960749f, 1.89011908f, -4.08865070f, 0.13722643f, -1.62849212f, -0.19044125f, 1.37906075f, -3.92504406f, -1.45033538f, -0.42085981f, 3.38237071f, -3.06508875f, -1.39420545f, 1.13067436f, 0.92206454f, 0.49917889f, -2.74508023f, -2.19221997f, 1.77914095f, 0.10854459f, -2.62178278f, 2.35042715f, -0.15322030f, -0.67014873f, -1.75627899f, 2.64074945f, 2.76339936f, 2.67275214f, -0.62736398f, 0.58251178f, -4.64895678f, 5.50419283f, 2.53566456f, -2.44196153f, -0.07845879f, -2.80389643f, -0.64810950f, -0.05813205f, 1.67155504f, -2.69673729f, -1.72486305f, -0.53888649f, 1.86805439f, -1.37128329f, -5.37923479f, -2.08133769f, 0.58187997f, -1.39498150f, 0.21874082f, 4.33726025f, 6.29673958f, 0.72312093f, -3.32683516f, 1.73482585f, -0.00766110f, -2.63785434f, -0.13511759f, 4.07195950f, 0.94139838f, 3.15717316f, 1.53720927f, 1.87664819f, -2.33655119f, 6.18176556f, -2.73912525f, -2.45279956f, 2.20392370f, -0.56854641f, 0.98915887f, -2.64472580f, 2.40633702f, -4.93327999f, -1.28942823f, 0.98247659f, 1.31774998f, 0.07669818f, -5.91169453f, -0.43135011f, 1.27404964f, -0.59787154f, -0.22716975f, 0.74409103f, 10.27316475f, -2.29192710f, -2.19403267f, 3.78925133f, 3.19553399f, -4.42490482f, -0.80781460f, 2.16568565f, -2.54165983f, 2.54885101f, 4.18779039f, 1.73079813f, -1.48891807f, 11.60153770f, -0.98686743f, -2.88813901f, 2.32898521f, -0.36101711f, 2.34522438f, 0.29057693f, 1.39800644f, -4.31848240f, -3.21217132f, 0.11740226f, -1.21613467f, 0.57248503f, -4.44853830f, 1.54665899f, 3.14459944f, 1.76809108f, 0.26693153f, 0.86913753f, 9.47121620f, -2.07677889f, 2.08578467f, 1.30181742f, 1.58683562f, -3.52757788f, -1.32763624f, 0.79821301f, -2.19358301f, 1.17707348f, 6.01983643f, 4.11209440f, -2.04209709f, 7.00413418f, -1.84904683f, -1.32542288f, -0.01298118f, 0.70377320f, 0.27815005f, 2.07879829f, -0.71606725f, -4.94399881f, -2.11898828f, -0.39051518f, -2.21034360f, 3.05337906f, -1.56889665f, 1.97065282f, 2.61320901f, -0.34063196f, -0.57001418f, -2.13183641f, 3.48879004f, -0.12067288f, 0.48568326f, -1.81424558f, 2.28868723f, 1.44802380f, 1.25918829f, -1.76415455f, 5.35742331f, 3.50682044f, 4.71371317f, 5.89110756f, 8.51241302f, 4.07391453f, -0.05887252f, -0.18202400f, 2.27119660f, 6.78274727f, -2.87470293f, -5.14336634f, 0.76443815f, 2.04625130f, -0.43199503f, -1.01353514f, 2.42951298f, 2.35641170f, 0.32345510f, -4.04195738f, -4.77967072f, 0.26564783f, 6.11455107f, -2.53868008f, -3.11839914f, -1.04203856f, 5.17195654f, -4.15338612f, -3.84149241f, 0.48130888f, 3.09706950f, -4.18423653f, 5.26233864f, 3.55831861f, 3.75122595f, 8.14969349f, 6.80038738f, 4.68907356f, -1.40135396f, -3.19287133f, -3.15895939f, 8.77363205f, -4.48793411f, -3.80537176f, -2.40145254f, -2.74341679f, -2.02862644f, 5.33402443f, 9.25365734f, 2.50246119f, 0.32847846f, -1.50564361f, -4.26163197f, -1.40994716f, 2.50708485f, 0.44500345f, -0.62516934f, 4.09846306f, 5.29355669f, -4.02224922f, 0.73442125f, 0.46648952f, 0.67028689f, -6.30715466f, 6.56297970f, 3.80854273f, -5.19078207f, 4.98839283f, 7.59161472f, 0.46010983f, -2.10227895f, 0.29324162f, -2.67019558f, 4.57838106f, -3.02338457f, -3.08647728f, -2.00112700f, -3.81710315f, -0.08346784f, 1.69288683f, 5.68807268f, 3.29351830f, 0.54618967f, 1.83540761f, -5.38810253f, 0.51326782f, 4.40081882f, -4.03805828f, 0.49482727f, -1.36024392f, 2.91845679f, -2.00959015f, 2.47489738f, -1.43354976f, 1.92024410f, -6.55897284f, 1.79488957f, -0.89570928f, -6.13094234f, -0.45504010f, 2.35239482f, 1.29039919f, -4.78849840f, -1.52545333f, -6.50420475f, 2.99257326f, -0.55620033f, 0.26807702f, -2.52090979f, -4.59419632f, 0.57965040f, 2.19423151f, 2.04760551f, -0.57048106f, -2.20812702f, -0.04777686f, 1.38053393f, -2.71448946f, -1.06219673f, -3.62008905f, 1.85719645f, 1.28355026f, -2.76315832f, 1.65295160f, -4.01645803f, -3.10454416f, -0.65713316f, 1.22384977f, -0.70416176f, 4.45064926f, 1.31602776f, 2.06907344f, 2.48872757f, 4.25775290f, 3.50504255f, -0.68262041f, 1.29799378f, -1.01969171f, 2.98593879f, 0.12607655f, 0.37219539f, -0.84196299f, -3.80019331f, -1.82315290f, -0.38489276f, -1.45200360f, -4.00882292f, 0.61042011f, -0.16738498f, 1.33787775f, -2.26938057f, 1.03656030f, 8.89089870f, -1.60370600f, -5.38691807f, 5.72182989f, 2.72854710f, -6.18535757f, -3.13408709f, 2.79175353f, 5.18425512f, 9.46434212f, 2.40110517f, 1.11330092f, -3.57366538f, 4.80967665f, 0.40691876f, -3.65484858f, 0.92398167f, 2.53852940f, 3.17747331f, 2.14199781f, -1.69107199f, -1.91864693f, -3.18452644f, -2.42408276f, -2.14332366f, -1.35526609f, -4.50732136f, 0.58234072f, -1.81547785f, 0.57311213f, 1.10584176f, -0.97226644f, 11.73174381f, -2.00559855f, -1.81175601f, 2.33131361f, 0.49264961f, -0.42245382f, -1.37528467f, 1.55768061f, 0.21152198f, 13.08896351f, 10.33674145f, 5.77929306f, -6.19886398f, 5.67007637f, -6.61288071f, -2.58029866f, -4.05192375f, 1.77221894f, 0.29821560f, 5.23508501f, -5.09560966f, -0.97536200f, -5.17957878f, 1.02876794f, -4.52072096f, 2.22126532f, -4.81708670f, 0.44538212f, -2.30738068f, 3.15900373f, -4.99227905f, 0.82632786f, 9.65415478f, -0.63819492f, -3.25479436f, -0.13276935f, 0.21337092f, -2.22116399f, -3.04922724f, 0.65568435f, -0.10706246f, 4.58047390f, 7.80782652f, 5.49080181f, -3.97114491f, 6.43327618f, -6.54772758f, -2.10962629f, -0.79831678f, -0.08316499f, 2.48658133f, 4.14070511f, -0.59806836f, -4.58636141f, -0.31166920f, 0.31757897f, -3.92562199f, 0.65357721f, 0.55871534f, 1.71843934f, 1.62395024f, 0.00695819f, -4.56716251f, -3.76420808f, 4.24979544f, -0.86128616f, 0.23126510f, -6.32968998f, 1.83346081f, 3.81335950f, 2.98407745f, -1.80454743f, 6.61764765f, -1.39372075f, -0.86780751f, 7.24317265f, 2.24205112f, 1.05702817f, 0.55431479f, -1.54557061f, 3.36389136f, 4.70898724f, 1.11327887f, -3.78462076f, -3.63381767f, 2.86510396f, 0.74203897f, 0.81488025f, 3.54250598f, 3.24824381f, 3.19000244f, -0.58995843f, -7.05670738f, 3.18306041f, 3.95191574f, 0.81820154f, -1.91068232f, -2.05426741f, -1.05589008f, -3.18377590f, -1.86278260f, -8.80374908f, 0.93416154f, -4.60517359f, 8.38999462f, 5.26356745f, -8.89992714f, 8.95298958f, 4.22590351f, 1.00351548f, -6.90151119f, -8.07641125f, -4.82450199f, 8.02293015f, 4.11661243f, 0.95457208f, -7.07843113f, -4.30524826f, 5.02697992f, 5.21011686f, 0.80132771f, 3.23420191f, 3.82452774f, -2.13171721f, -7.88879967f, 1.31062031f, 1.90848613f, -3.51572514f, -3.75684500f, 3.62577081f, -5.76075602f, -2.79389215f, 0.32598805f, -4.28981733f, 4.21048594f, -3.84532523f, 3.19815183f, -0.40756655f, -2.19974327f, 6.25655174f, 3.42396951f, -1.88986623f, -1.92803884f, -2.97344875f, -0.09756154f, 5.24342251f, -0.72513700f, 1.06113195f, -1.30720282f, 4.69107103f, 0.58984971f, 2.33985567f, 1.46385121f, 3.16576266f, 6.77769995f, -5.92685127f, -12.61141014f, -2.83663774f, 4.90253258f, -6.32688522f, -3.00096869f, 2.38634992f, -7.21459866f, -5.89208746f, 2.84085894f, -1.21792030f, 6.70161343f, -4.00450230f, 5.29881001f, -1.45574808f, 0.77542424f, 1.38336325f, -0.21572059f, -3.38088870f, 2.33249640f, 0.68824625f, -3.68440270f, 0.33481622f, -0.39239681f, 0.14560902f, 1.61039007f, -3.11967754f, 2.49372435f, 2.68783092f, -1.17559779f, 0.95257235f, 4.35451412f, -0.56818569f, -7.32110357f, -7.58534050f, -2.10573673f, -3.34446383f, -0.32183546f, -0.78525496f, -1.76974547f, 5.19060802f, -2.11319876f, -3.41755080f, -0.36864156f, 1.32680905f, 0.45004874f, 6.17223930f, -1.60707474f, 0.46096295f, -3.88852644f, 1.84729624f, -0.03412050f, 0.99224162f, -2.05553341f, 3.47793245f, -0.06305170f, 0.51314175f, -2.91650558f, -1.78121483f, -2.85465693f, 0.24649808f, -2.70376635f, 0.42334458f, -1.13862336f, -0.98409218f, -0.96593523f, 2.22128963f, 0.53402066f, 3.33979344f, 8.57430458f, 2.34217858f, -2.40062976f, 5.81624222f, 1.13290989f, -5.06850052f, -4.72865725f, 1.82859278f, 6.78569555f, 8.56885242f, 2.76462936f, 0.33891773f, -2.81092787f, 0.79498398f, -2.27208567f, 1.55182552f, 2.17166376f, 6.12517643f, 3.56859684f, 0.27685475f, -1.38408327f, -1.03533340f, -3.46618199f, 0.79240030f, -3.89390516f, -0.55852515f, -1.16367757f, -0.07008934f, -2.20105195f, 3.81210446f, -0.66834474f, 0.43603873f, 10.92334938f, 2.48571420f, -6.34997845f, 4.23135757f, 0.45045292f, -4.13489866f, -3.92324209f, 1.88537407f, 2.57159734f, 9.90973091f, 4.37453461f, 7.34546280f, -2.51120615f, 11.12575245f, -3.23452854f, -2.49947500f, 1.39819741f, -3.78950691f, 2.40617585f, 5.10036278f, -3.55743456f, -6.42888737f, -2.51929998f, -1.90880990f, -1.81618094f, 1.60946512f, -4.09737110f, 1.96408439f, -1.90115595f, 2.44444203f, -2.31254292f, -4.01332951f, 8.65541840f, -0.58626485f, -4.02226830f, 0.43893200f, -3.78272748f, -5.46277428f, 0.01306701f, 0.61185312f, 0.24469066f, 1.30214953f, 5.87789631f, 8.75197792f, -5.31634712f, 3.43556309f, -5.90755081f, 0.54375106f, -2.48162293f, -3.51843548f, 2.55853295f, 5.06387186f, -2.09662485f, -3.00377345f, -3.21781397f, -0.14537808f, -4.65453672f, 1.92747557f, 0.41553855f, 4.09379959f, 0.83387995f, 1.50868511f, -6.54959488f, -8.38881016f, 5.50689125f, -2.88616610f, -1.21597648f, -0.23817590f, 1.50816703f, -2.26873541f, 2.29862142f, -1.61143053f, 5.97371244f, 4.71440220f, -0.20635787f, 8.85926723f, 0.56064367f, -1.04103339f, -4.47060108f, -2.63824081f, 3.06782055f, -2.07702565f, 3.38269401f, -1.59988797f, -3.80122590f, 2.35341501f, 2.69095278f, 3.87612104f, 1.89984226f, 0.95496917f, 3.14841127f, -5.84543085f, -7.24945450f, -2.65708590f, 2.87417006f, 0.97556210f, -3.75203967f, 1.55287778f, -7.43401051f, -1.29005826f, -3.40252638f, -4.01049423f, 2.82721639f, -1.21479535f, 8.54563904f, 7.39749908f, -0.61361837f, 7.60177565f, 1.65812778f, -0.83008504f, -3.60961151f, -7.69062138f, -1.26275063f, -4.17071676f, 5.28448200f, 4.04685593f, -1.18231702f, 1.15276611f, 1.58620787f, 6.75060844f, 3.29332161f, -0.67640316f, 5.78984785f, -3.14913464f, -6.41867924f, -2.58316016f, -2.04366302f, 2.01089478f, -3.81723452f, 3.63843751f, -5.13238430f, -3.79432917f, 4.86581373f, -1.06922054f, 3.95978498f, -0.78166616f, 8.35650539f, 5.35834265f, 0.35594034f, 9.41657066f, -0.84108615f, -6.54425859f, -3.44328952f, -6.55536795f, -0.08963367f, -1.53906262f, 0.17658240f, -0.13108420f, -0.44371247f, -0.78411150f, 2.64754868f, 9.66306782f, 1.70506203f, -0.31588936f, 4.31715870f, -6.16665173f, -10.43371868f, -3.72962189f, 4.35245228f, -1.75867891f, -4.20046234f, 8.62637043f, 1.45946813f, -3.30153608f, 0.85179043f, -2.66643381f, 3.01863337f, -2.52916121f, 8.35405540f, -0.37298933f, -0.89473486f, 6.88681793f, -4.46370125f, -7.50776386f, 3.80255938f, -3.55003357f, 1.43528831f, -2.20383263f, 2.34999895f, 2.03803205f, 1.94830751f, -1.85976326f, 0.97718471f, 5.53710842f, -0.80560827f, 0.23925614f, 5.98795223f, -2.03578377f, -7.77835321f, -2.79955530f, -1.88185954f, -2.49112058f, -0.76095992f, 2.71161270f, -0.55918610f, 0.83789903f, -1.42063200f, -0.61528748f, -4.18273115f, 1.76384258f, 4.21265936f, 5.50964785f, -0.93324339f, 3.83215356f, 1.52210593f, -0.91594946f, 1.31148386f, 3.20160103f, 1.24493563f, -0.72693497f, 1.84716725f, 3.09897518f, -1.34605026f, -1.17511916f, -1.05526352f, -1.08590937f, -1.41319299f, -3.75052118f, -2.67095542f, -0.76179552f, -3.32081509f, -1.04692316f, -1.30194843f, -1.98795474f, 5.01223469f, 0.21895903f, -1.85535169f, 3.12362719f, 0.16198632f, -3.86784005f, -2.03062248f, -0.15415624f, 8.22020721f, 4.83055592f, 4.50315666f, 4.19443417f, 0.42727345f, -4.67786789f, -5.18739986f, 2.53988838f, 3.19683266f, 1.80313504f, 1.94664574f, 0.59795094f, -4.21626759f, 0.50492239f, -0.41232634f, -0.99224532f, -3.94929314f, 1.74060190f, -0.92474866f, -1.00664830f, -6.17397356f, -1.33146775f, -3.78111315f, -4.91876888f, 2.50303864f, -0.34890354f, -1.25013232f, 0.38168997f, -1.84135628f, -4.46107960f, -4.05920792f, -2.61709857f, 0.71046209f, 9.80566883f, 6.34086990f, 2.73394704f, -2.03342366f, -2.21424174f, -5.56514263f, -4.74755144f, -2.20672894f, 0.09010231f, 1.70423889f, 3.19200158f, -6.99027634f, 1.14216340f, 0.05824995f, -0.76996505f, -6.51575899f, -0.41109252f, 0.78229940f, 1.36170781f, -5.65170193f, 1.12221193f, -4.60430050f, -4.40174437f, 4.01805925f, 0.10774946f, -2.77991009f, -0.18023163f, 0.02151692f, -1.77023101f, -1.86639869f, -0.69443607f, 4.92290831f, 6.83520412f, 4.27372265f, 6.54272366f, -7.59249687f, -1.40776849f, -3.52368808f, 1.01398587f, -3.58802676f, -0.35658866f, 1.14716864f, 3.75847244f, -2.30159235f, -0.72130895f, -0.24564353f, -1.77531350f, -3.08677864f, -0.73486501f, -1.20357263f, 0.60789430f, -3.46990204f, -0.20668676f, -5.46096087f, -5.22016764f, 0.98259866f, 1.81012678f, 3.92534304f, -2.94997001f, 1.65154219f, 2.27040243f, 0.99095678f, 0.09144652f, -0.99103236f, -1.11210847f, 0.78181303f, 2.38706732f, 2.96695375f, -0.17279971f, 0.31143007f, 1.35465562f, 2.03586054f, 6.19515753f, -3.14652419f, -2.89027119f, -3.26665854f, -1.93043876f, -0.46601450f, 1.07655203f, 1.74946189f, 4.02148342f, 0.69275337f, 0.50094581f, -4.07613230f, 2.98369169f, 4.24537849f, 0.49480581f, -2.02408123f, -2.02068973f, 6.54505825f, -5.19377470f, -0.12596917f, -0.70204186f, -0.98308045f, -3.19708824f, 1.63609934f, 1.35475993f, 0.16313422f, 4.13918924f, 7.69187021f, 3.72601676f, -1.97790039f, -1.16739464f, -3.31835508f, 8.14553452f, -1.78718984f, 1.21505618f, -3.84255409f, -3.21992350f, 0.07376552f, -0.81223297f, 3.57002878f, 1.48521733f, -0.45995998f, 0.30551746f, -3.33944130f, 1.39538884f, 1.84758544f, -0.21494150f, -2.27316713f, -4.37771225f, 6.48841667f, -5.00251961f, -0.45162797f, -5.01056004f, 0.70199943f, -4.60057783f, -2.22394514f, 0.07777429f, -1.49820781f, 3.47308421f, 6.13231564f, 1.18605387f, -4.78924608f, -3.49548388f, -2.73382568f, 6.24617863f, -2.74291611f, -1.03833354f, -2.20752788f, -2.33219409f, 1.48633552f, 1.65796840f, 4.95045471f, 2.58479190f, -0.90922785f, 0.71312457f, -4.44465590f, 1.37020862f, 2.37683725f, 0.18805164f, -3.28422308f, -1.64939332f, 3.64181972f, -3.75277281f, 3.67203593f, -0.11204052f, 2.24140930f, -3.90657187f, 2.56883717f, -1.44016707f, -2.83842611f, -0.29104578f, 2.17757058f, -0.71431804f, 1.36911654f, 0.85083604f, -1.60110259f, -1.97247636f, -1.61163378f, -0.81236130f, -0.38993555f, -3.03631902f, -0.38213277f, 0.06394482f, 3.19348621f, 0.36771113f, 1.36763072f, 2.49159527f, -0.39599860f, -2.69996762f, -0.97561121f, -2.97563028f, -0.49662948f, -0.17564940f, -2.79042959f, 0.72395414f, 2.07260203f, -0.99439794f, -2.20248008f, -0.07389921f, 0.65536159f, 4.73054695f, -0.63917702f, 0.58788192f, -3.60156059f, 6.59609890f, 3.88419437f, -3.38469863f, -3.56237841f, -2.03295064f, 0.07279694f, 3.71804547f, 0.79928309f, -2.13411403f, -1.13909864f, -0.34193408f, -1.00338125f, -1.44231665f, -5.39835978f, -0.45086145f, 1.16064668f, 2.58335257f, 2.10072684f, 4.64244223f, 7.10090065f, 1.01974952f, -4.44687223f, 2.99792576f, 1.10303724f, -1.22736573f, -3.91514421f, 3.07458854f, 2.18765211f, 3.34481716f, 2.46166849f, 2.99648619f, -0.94046807f, 5.55028200f, 0.92199719f, -0.83934361f, -0.72042274f, 0.84869325f, 1.46914721f, 0.85937387f, 4.77306223f, -4.06436539f, -2.59847593f, 2.44828081f, 0.50484699f, -2.71092367f, -6.39010477f, 0.91778028f, 3.25469685f, 1.30310678f, 1.35258150f, 3.56171441f, 7.82435083f, -2.51527429f, -4.24328852f, 2.36876059f, 1.94595242f, -2.59290171f, -6.62389565f, 3.32567835f, 2.13659120f, 4.09299326f, 3.48293996f, 2.64965177f, -3.19157362f, 13.37204266f, -0.50297594f, -4.57448196f, 3.95582604f, -0.69038916f, 0.10098404f, 1.18737555f, 3.65761185f, -5.69623756f, -2.03357077f, 1.02868807f, -1.38448596f, -0.05690211f, -8.48874187f, 0.56755424f, 1.45485961f, 0.66273880f, 0.06495565f, 1.79539490f, 8.46864319f, -1.22696662f, -1.87585378f, -0.99768794f, 2.72801924f, -0.66980243f, -2.31924677f, 0.33271110f, 0.11666083f, 1.86980045f, 5.95332909f, 7.38583708f, -2.80956483f, 6.79227638f, -6.78070831f, 1.21884382f, -1.40695429f, 0.90236962f, -1.13695288f, 0.50760663f, 1.00955284f, -5.39029121f, 0.24987072f, 2.24283314f, -4.02145576f, 2.18057394f, -3.35627747f, 1.26061773f, 1.30342579f, 0.11311233f, -1.11199212f, -4.06509686f, 5.82649660f, -1.24059582f, 5.51652861f, -1.90937877f, 1.10658336f, -0.47065550f, -2.39167786f, -1.95931304f, 4.12717247f, 1.15396059f, 1.26015663f, 7.97836876f, 7.33633423f, 2.27785325f, -2.83802366f, -2.74850106f, 0.86126029f, 6.18781090f, -1.43707538f, -6.97134876f, -3.25486469f, -1.95214593f, 0.91066706f, 0.89637989f, 1.06481194f, 6.25791073f, 0.81779671f, -1.08384395f, -3.21191931f, 2.04216075f, 4.76030350f, -2.37217665f, -1.42571259f, -6.35876131f, 4.62536526f, -5.40060568f, -3.14868999f, -1.00587153f, 1.80662942f, -7.03201485f, 6.08373499f, 0.99862772f, 2.21717811f, 4.06814623f, 6.02428913f, 5.33422756f, -0.87013257f, -2.22477579f, -2.51505303f, 5.82925224f, -0.82854009f, -4.30698347f, -1.75007713f, 2.08352375f, -2.25235629f, 1.17517352f, 5.77717733f, 2.27472878f, 2.72778273f, -1.95411634f, -4.52602863f, 1.13983536f, 1.16340065f, -2.02740526f, -3.11290503f, -1.94906235f, 1.54855204f, -4.52984142f, 1.97465122f, -1.79415476f, 4.03510094f, -8.45349979f, 10.87430096f, 2.19863629f, -5.39083815f, 5.86213875f, 6.25744534f, 6.52600002f, -4.72149038f, -1.75254321f, -5.51459169f, 7.03155518f, -2.01889277f, -4.58441257f, -3.61226106f, 0.42395937f, -0.93263882f, 2.28703761f, 2.80611467f, 2.59498215f, 0.65989012f, -1.51268566f, -4.49465561f, -4.70453882f, 5.44696808f, -4.37603617f, 0.46670085f, 2.82488608f, 2.18854523f, -2.04817152f, 1.19557285f, 1.53618634f, 4.44758606f, -7.31593513f, 7.43966007f, -3.55480957f, -5.29834652f, 2.14622784f, 1.65194583f, 2.71262598f, -4.86145496f, 0.79726243f, -8.88541985f, 1.19627261f, 0.79660845f, -1.98016644f, 1.03741014f, -3.93128228f, 1.05535269f, 2.01378822f, -0.46086323f, -0.77754641f, -1.43942690f, 0.49809402f, -2.27861357f, -3.29815221f, 0.38201320f, -3.98481083f, 4.88261318f, -0.44555628f, -2.57224536f, 2.35001850f, -2.65835261f, -2.43422794f, -2.97889376f, 1.07349825f, 1.88157082f, 4.74075413f, 0.60376728f, -0.48894715f, -1.15800071f, 4.68110943f, -0.86976886f, 1.49192941f, 0.62665290f, 0.20652676f, 0.53916287f, -1.45706177f, 0.66133004f, 1.34405875f, -4.27689552f, -0.20838106f, -5.14266443f, -1.29718637f, -1.74506426f, -0.86022055f, -3.57553625f, 0.46880072f, -1.25287139f, 3.28596354f, 11.33191013f, 1.23942876f, -3.87616491f, 7.57880497f, -0.22940339f, -5.68512678f, -1.94969654f, 5.85449600f, 3.75705457f, 4.24395847f, 1.60086083f, 2.62553668f, -0.93964291f, 5.84753895f, -0.79931092f, 0.48274064f, 2.07170033f, 3.02243996f, 2.63509989f, -0.76043403f, -1.64048159f, -6.17683458f, -3.09974527f, -2.12773156f, -0.89379883f, 2.82242465f, -1.99981332f, -0.08763933f, 0.01921120f, -1.94142103f, 2.48067307f, 0.41083777f, 8.24922180f, -1.84516132f, -1.39224625f, 5.03956223f, 0.49562740f, -5.28296328f, -0.20005548f, 3.13672113f, 0.51187158f, 7.11563921f, 6.43059587f, 3.48430967f, -5.37095928f, 8.03863049f, -5.53923941f, -2.16421175f, -3.77641368f, 3.29633045f, 5.04030085f, 2.25945377f, -3.04169011f, -2.16198015f, -2.49559617f, -0.26252726f, -6.99201345f, 2.87374353f, -0.12568980f, 0.23314142f, -1.32087135f, 4.39030552f, -0.24638844f, -4.37242651f, 14.09276772f, 1.23987353f, -1.72249663f, 0.31124914f, -2.13725138f, -3.74915648f, -1.87147236f, 0.47318631f, 1.13337576f, 3.00416899f, 8.82548523f, 4.80538750f, -5.28486395f, 5.51870108f, -5.15801477f, 0.95712411f, -1.50416136f, 2.34657240f, 4.20726633f, 5.56757259f, -3.30645251f, -3.39945269f, -2.68488026f, -2.53525281f, -3.15145874f, 2.74529529f, -0.96283442f, 2.87778258f, 0.22186530f, 1.24905694f, -7.07941198f, -5.45916176f, 3.46988297f, 0.92430985f, -0.98330998f, -2.23672342f, -3.03262734f, 0.73941302f, 0.98004431f, 0.83219361f, 7.17411804f, 4.27849865f, 0.14765590f, 8.61269569f, 9.04497051f, 1.53991723f, -2.08305025f, -4.34939337f, 0.63786775f, 2.60098696f, 0.02432060f, -1.48516297f, -4.06825686f, 5.12420368f, -0.75312757f, 1.96927559f, 4.91575956f, 3.41533065f, 3.62557888f, -4.35002136f, -5.91343403f, 0.45026422f, 4.93286371f, 3.45830250f, -4.39032364f, -0.51697755f, -7.41543341f, -3.06703568f, 1.01196158f, 2.47106576f, 5.54014874f, -4.65312243f, 8.61000633f, 8.25905323f, -1.41497111f, 8.69221878f, 0.40090930f, 1.11325574f, -1.67089832f, -4.01080132f, 1.07925677f, 2.68086481f, -0.73093414f, -1.35081220f, -7.85765076f, -5.98989439f, -0.04651213f, 4.63693142f, 2.07757711f, -0.22652936f, 3.45525455f, -0.69198442f, -10.39761639f, -2.02106953f, 4.77755499f, -2.67665577f, -1.72481167f, 4.49634743f, -2.55717134f, -4.55044937f, 0.46377492f, -3.08933020f, 3.86891365f, -2.79104614f, 8.36974335f, 0.86471701f, -5.39342690f, 12.54906940f, -0.41536295f, -5.29502535f, -3.94430566f, -5.67391300f, -4.65079165f, 2.22505951f, -0.30000746f, 2.27855444f, -4.81604433f, -1.73440599f, 4.68784523f, 5.00208044f, 0.18863934f, -1.74989462f, 3.17923450f, -1.59773099f, -12.59962940f, -1.54495025f, -0.00576371f, 1.79913878f, -2.43449807f, 1.49516344f, -3.90507102f, 1.68647158f, 4.50177765f, -5.32286358f, 3.47539330f, -2.90529680f, 1.61576962f, 0.83679676f, -5.55615807f, 3.78939056f, -4.46644831f, -5.95550919f, 0.37808037f, 0.51334500f, 1.74658906f, -0.82085419f, -0.65387219f, 3.67790437f, 0.03758264f, -2.42622781f, 1.83335185f, 4.73835945f, -0.83536482f, -0.03993917f, 3.78230667f, -4.81265640f, -8.26869011f, -1.30363441f, -2.09106350f, -3.96769738f, -1.89037073f, 0.38682747f, 0.05434489f, 5.72213697f, 0.55685395f, -3.47729349f, -1.11535001f, 2.09416127f, 5.08877802f, 5.72183466f, 1.29632664f, 0.16822398f, -2.43180108f, 3.49967623f, 2.15753818f, -0.26548505f, 3.24446392f, -0.00599277f, 1.08215356f, -0.23225522f, -2.40723038f, 0.18496060f, -3.70608735f, -0.19918591f, -1.64028871f, 0.80792952f, -0.85334057f, -2.52314138f, -3.12099195f, 0.17949918f, -0.82650864f, 2.32224989f, 9.56476116f, -0.20134282f, -0.48428559f, 2.86784410f, 0.07289505f, -3.92880869f, -2.11887884f, 0.59164631f, 6.31267452f, 7.49149418f, 2.88749456f, 2.40504885f, -3.57608175f, -1.48019314f, -0.69410253f, 0.90275228f, -0.34111357f, 2.19190216f, 3.39090061f, 3.39631820f, -5.19105434f, 2.67546582f, -2.56549048f, -0.59797800f, -4.21802664f, 0.63918972f, -0.69969130f, 0.47496963f, -4.30976725f, 0.16531238f, -3.59595251f, -0.76877379f, 11.79971790f, -0.93276632f, -1.48630571f, 8.04754066f, 2.09168458f, -3.77018499f, -4.19337654f, 0.26171905f, 1.99359691f, 8.96759701f, 8.39609814f, 6.19231987f, -5.36037970f, 4.69818354f, -4.22453928f, -4.61665344f, -2.52073431f, 1.34026706f, 2.80182385f, 2.56681514f, -4.04676390f, -3.01466990f, -4.10480118f, 0.38737059f, -0.37146521f, -2.26529670f, -1.72867084f, 0.93472683f, -2.47562981f, 0.89871657f, -1.67618203f, -0.28950238f, 5.30124855f, -0.14731219f, -0.81319761f, -1.11265934f, 0.11356127f, -2.52802444f, -1.93826056f, 1.06187987f, 1.48062325f, 4.28070498f, 5.69893932f, 9.26904392f, -4.23773003f, 5.78582096f, -6.18445301f, -2.85200453f, -5.30461454f, -4.16009140f, -0.07239690f, 4.11531162f, -1.12266588f, -1.50265646f, 0.47661865f, -1.90043914f, -6.48978710f, 1.71005368f, 0.18256521f, -0.88272136f, -0.51324779f, -0.78045660f, -5.21036625f, -4.11805344f, 3.99454761f, -1.04999924f, -6.99629354f, -5.02737141f, 0.94748145f, -2.35882139f, 4.13982439f, -1.41835535f, 7.56763077f, 3.97024012f, -4.08156776f, 6.90305424f, 0.53571963f, -2.22625160f, -2.09144926f, -4.98530245f, -0.15102190f, 0.59995949f, 3.28562784f, 0.77991986f, -3.08389306f, 3.34046674f, 0.41394949f, 5.10031366f, 2.99692893f, 0.17706826f, 2.85998058f, -6.68330860f, -6.72653008f, -0.04071128f, 3.71085787f, 3.17834806f, -4.88019037f, 6.74075413f, -7.41782188f, -5.22026348f, -1.94595623f, -3.61318684f, 1.85610664f, 1.08613706f, 6.41580677f, 1.46376514f, -4.11524010f, 9.59146214f, -2.92772651f, -1.70753336f, -1.51594138f, -4.88185692f, 1.47331417f, -2.23893595f, 4.98459148f, 1.29359996f, -2.29221845f, -0.99594390f, 3.05759239f, 6.86030054f, 2.40487719f, 3.28339863f, 7.72739315f, -3.60563445f, -9.73502827f, -1.51672328f, -0.08473521f, -2.43673515f, -3.26616001f, 3.63767886f, -11.25394535f, -5.17597103f, -1.27523947f, -7.82669783f, 0.67929745f, -4.50530529f, 5.49323797f, 6.78993320f, -2.28033876f, 4.61412525f, 2.55109429f, -12.38607693f, -0.63024014f, -3.45992327f, -0.84092742f, -0.03252453f, 4.58635283f, 5.28213978f, -1.28417206f, -1.71185923f, -0.26850975f, 8.28257561f, 4.47432184f, 2.72818279f, 8.42217731f, -4.22216320f, -8.95128918f, -1.57179546f, 1.34253705f, -5.47035217f, -5.50866985f, 4.64156532f, -6.11207914f, -5.46734476f, 3.54298997f, -2.79237103f, -0.70766860f, -3.62739944f, 3.22660995f, -2.02262759f, 0.11224222f, 2.63832402f, -0.91955596f, -4.65958309f, -0.29729855f, -1.78957534f, -0.40749407f, 0.51688713f, 0.83725226f, 0.30945438f, 1.20769620f, -1.75219965f, 2.59689760f, 5.01501608f, -1.59034789f, 0.58155286f, 3.75831509f, -5.26110506f, -8.65382767f, -6.19066620f, -0.61932850f, -2.71863723f, -0.87443137f, 3.40582991f, -1.27868056f, 3.51236677f, -2.07806540f, -0.85076392f, -1.14599180f, 1.16361260f, 1.86411846f, 5.86179352f, 0.69029891f, -0.06060839f, 1.54649436f, -0.60351688f, 1.51970077f, 0.04187265f, 1.64540339f, 2.75502157f, 2.46308279f, 1.69071770f, -3.23827076f, 0.92096543f, -3.09458661f, -1.23823690f, 0.24035048f, -0.74456501f, -1.85476089f, -0.32914662f, -2.10325241f, 1.19795251f, -2.05372071f, 1.02114081f, 2.56286955f, 0.42165697f, -1.65826249f, 4.00724554f, -2.18727994f, -1.05848944f, -0.52338278f, -0.28714985f, 8.08780861f, 5.04444599f, 3.51866961f, 3.37445784f, -1.96067202f, -1.21509445f, -3.96595931f, -0.80801201f, 0.76944816f, 1.80147493f, 4.14419460f, -0.12201095f, -2.77788162f, 1.13284469f, -2.05441403f, -0.61129224f, -2.69690657f, 1.91634214f, -2.17146754f, -0.22308528f, -6.02561045f, 0.49161875f, -6.74280357f, -4.62689781f, 2.47910833f, 1.86534905f, -3.24152899f, -1.39898300f, 0.29427958f, -2.16338181f, 0.90073711f, 1.75551236f, 4.42651892f, 8.34437466f, 5.50070190f, 5.68162251f, 1.65345454f, -2.72315669f, -5.43411493f, -0.29380533f, 1.07508349f, -1.73533511f, 2.56912184f, 3.62010550f, -6.30422783f, 1.74158525f, -1.22070909f, -0.80982518f, -4.14757967f, 4.29217434f, 0.70600843f, -2.09282112f, -5.09018898f, -0.11623126f, -5.99775553f, -4.66743088f, 1.61512172f, -1.30276895f, -3.17103505f, -0.26310229f, -1.00843918f, -0.77664804f, -2.05240250f, 0.04728425f, 1.15720487f, 4.01001406f, 7.24615860f, 2.55452180f, -5.76347876f, 0.34683830f, -6.05540276f, -4.70677900f, -0.93182588f, -4.37759733f, 2.93209839f, 1.63947964f, -2.43563962f, 1.35213876f, 0.00670356f, -0.02742785f, -2.16460943f, 1.39449501f, 0.23929763f, 2.37476778f, -4.17733765f, -0.81475425f, -6.15027046f, -5.74441719f, 3.53978682f, 0.66798484f}); - - sd::ops::deconv2d_tf op; - auto result = op.evaluate({&input0, &input1, &input2}, {}, {7,7, 2,2, 0,0, 1,1, 1,1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + auto input0 = + NDArrayFactory::create('c', {4}, {3.f, 8.f, 8.f, 16.f}); + + auto input1 = NDArrayFactory::create( + 'c', {7, 7, 16, 5}, + {1.05293429f, -0.89349967f, 0.31027254f, 1.22991478f, -0.62926656f, + 0.56918693f, -1.60992694f, 1.10167944f, -0.80843484f, 0.07521993f, + -1.15994942f, 0.76016301f, -0.40056285f, -1.16872537f, -0.91384381f, + -0.36700436f, 1.82389200f, -1.18200207f, 0.51612782f, -0.92479187f, + -0.09307563f, -0.55122334f, 1.23532486f, -1.11124146f, -0.05812126f, + 0.68159896f, 0.69125599f, -0.77127314f, -0.10874277f, 0.86469102f, + -1.31614351f, 0.33354419f, -1.71750402f, 0.17197680f, -1.03965557f, + 1.10570908f, -1.19115615f, 1.05115080f, 0.18277600f, 1.08820546f, + -0.72191417f, -0.10999311f, 1.56521320f, -0.35433730f, -1.11799145f, + 0.34499285f, 0.64998639f, -1.64371550f, 0.92592359f, -0.47659501f, + 0.49101439f, -0.15613313f, 1.47486567f, 0.43576995f, 2.19538260f, + -0.83567709f, -1.21846950f, 0.80400819f, 1.14637423f, -1.01503456f, + -0.61992753f, -0.47378838f, 0.86503726f, 0.27147385f, 0.37073180f, + -0.19951358f, 0.79167330f, -0.33982825f, 0.18631981f, -1.54715073f, + 0.39967480f, 0.95067030f, 1.12508667f, -0.86676019f, -1.10341156f, + 2.33141375f, 1.10972047f, 0.71407092f, 1.70640314f, 1.80666339f, + 0.59465605f, -0.39653218f, -2.61163163f, -1.15013492f, -1.19908321f, + 0.41783467f, -0.22730024f, 0.31425011f, -0.58562893f, -0.10131568f, + -0.85047537f, -2.59974790f, 1.22072542f, -2.08812046f, -0.19363593f, + -1.27664304f, -0.02703438f, 1.08477545f, -0.65506506f, 0.46040919f, + -0.13715318f, -0.74945593f, -0.69006950f, -1.29617655f, -0.15865716f, + 1.38956285f, 0.90216327f, -1.31185400f, -0.15067385f, -0.63093358f, + -0.05895613f, 0.26545224f, 0.29332840f, 0.42852548f, 0.72409540f, + 0.12879130f, 1.43038857f, 0.68647617f, 2.19654775f, 0.51878077f, + -0.03769343f, 0.52877223f, -0.21733910f, 1.13710785f, -0.59003806f, + 1.54624867f, -0.64997369f, -1.03239334f, 0.19708300f, 0.68658423f, + 0.71048903f, -1.55250466f, -1.38636279f, 0.32385820f, 0.81226677f, + 0.19209047f, -0.23002781f, -0.63631231f, 1.02101684f, 0.65428704f, + -0.17206922f, 1.09488952f, 1.03022420f, -0.95567745f, -0.07595373f, + -1.48606372f, 2.57174873f, -1.75366247f, 1.12913883f, 0.97053039f, + -0.28552356f, 0.56511772f, -0.79568213f, 0.07561764f, -1.02085686f, + 1.05770981f, -1.25715709f, 0.42046708f, -2.57390857f, 0.96947151f, + 1.05215812f, 0.65624017f, -1.29019403f, 0.64157075f, -0.40509227f, + -0.65354455f, 0.42348680f, -1.34107757f, 0.05931387f, -0.54337227f, + 0.95460182f, 1.59319806f, -0.44433126f, -0.33717924f, 0.79566282f, + 0.50112695f, -0.22244534f, 1.76904583f, -0.89817202f, 1.82985342f, + 0.17671813f, 0.80720717f, 1.32469308f, 0.39417782f, -0.23720963f, + 0.96796370f, -1.02348757f, -0.86615551f, -1.58120525f, -0.37634999f, + 0.00905940f, 0.01880967f, 1.75771821f, -0.64372772f, 0.36687651f, + 0.15854552f, -0.67599791f, 0.53726906f, -1.20158446f, -1.78549063f, + 0.96476388f, -0.66158366f, -0.41681561f, -0.97541636f, 2.35928202f, + 0.32130197f, 1.06886065f, 1.38736427f, -0.73718959f, 0.11215294f, + 2.12865782f, -0.37927702f, 0.55621815f, -1.10108411f, -0.02032263f, + 0.29595461f, 1.58737493f, 1.24001300f, -0.66748160f, 0.80729002f, + -0.10575818f, -1.03175950f, 1.80755460f, 0.10825710f, 2.20666361f, + 1.33633149f, 1.39290452f, 0.45211342f, -0.07837920f, 2.08304930f, + -0.28387162f, -0.70775616f, 0.43626297f, 0.53556961f, 0.06201901f, + -0.59255266f, -0.11854446f, 2.10024118f, 0.37638292f, -0.56178707f, + -0.25220188f, -1.23731256f, -1.30002999f, 0.34283713f, 0.30502397f, + -1.09233856f, 1.12430644f, 0.52273953f, -0.68507338f, -0.69913578f, + 0.88440478f, -0.76959240f, 1.07093310f, -0.34802195f, 0.35683727f, + -0.76079178f, -1.92807376f, 0.84499562f, 1.39131641f, 0.44825050f, + 0.34567752f, 0.44607711f, -1.00986362f, -0.50038189f, -0.09060892f, + -2.55645394f, 0.56416476f, -0.83058155f, -0.65931624f, -0.73649710f, + 0.59814465f, -0.86736494f, -0.32200798f, -1.28087902f, -0.76818323f, + 0.86848933f, -0.98678392f, -1.30813944f, -0.20255326f, 0.26557815f, + -0.31090519f, -1.46331608f, -0.62782109f, 0.59034890f, 1.63147473f, + -0.17727259f, -0.37636510f, 1.27368402f, 0.19096918f, -0.29936951f, + -1.99038267f, 0.54831523f, 0.48849005f, -2.55680346f, -0.63126534f, + 1.21715927f, 1.22841084f, -0.67416084f, 0.02927168f, -0.36693662f, + 0.63204330f, 0.13721083f, 0.28742912f, 0.19470036f, 0.74873924f, + -1.47602463f, 0.86264688f, -0.23730527f, -0.99978864f, -1.17048764f, + -0.34996086f, 1.43019187f, 0.26224539f, 0.60689932f, -0.75002515f, + -0.79823422f, -1.37300086f, -0.19951135f, -0.12150808f, -0.75272322f, + 0.23755015f, 0.31270382f, 1.66539109f, -1.04104745f, 0.79540199f, + -0.54042423f, -0.54150617f, 0.43871084f, 0.24163951f, -0.24517761f, + -0.66178995f, -1.13064528f, -0.84426326f, 0.56437236f, 0.09088907f, + -0.82823074f, 0.81753862f, -1.74096012f, -1.80599844f, -0.60943592f, + 1.36094582f, -1.47762752f, 0.15931177f, 1.05569172f, 0.36751524f, + 0.06497604f, 0.13536447f, -1.57156146f, 0.22783801f, -0.96910107f, + -1.24294984f, -1.47147155f, -1.04790676f, 0.64629447f, -0.32266054f, + -0.55675793f, -0.95612079f, -0.23005411f, -0.75229394f, 0.03050950f, + -1.72484553f, -2.06055546f, 0.19892083f, -0.13597751f, 0.65180075f, + 0.27096850f, 0.08977254f, 0.57564765f, -0.43227410f, 0.09541437f, + -0.00358280f, 0.65680492f, 0.04006556f, 0.57160908f, 0.43821687f, + 1.96118212f, 0.42602235f, -0.36731303f, 0.67200917f, -0.56667900f, + 0.44014785f, 0.06970236f, -1.34415269f, -1.13301528f, -0.08848868f, + 0.35615012f, -0.06426942f, -0.81406075f, 0.94097465f, -0.54560357f, + -0.65877116f, -1.29646838f, -1.13109028f, -1.64186084f, -2.12723470f, + 1.86027610f, 1.22621441f, 0.26098135f, -0.05608099f, 0.21143445f, + -0.87244326f, 0.79408187f, 1.24279130f, 0.14458629f, 0.25532281f, + -1.24023473f, 2.42278886f, 0.00405578f, -1.00119174f, 1.19856644f, + -1.37395728f, -0.16656208f, 0.46858498f, -0.00678801f, -0.34960639f, + 0.16614936f, 2.41560221f, -0.53880709f, 0.91618651f, -1.77009308f, + 0.32911557f, 0.30216452f, 0.02881077f, 0.77705866f, 0.27061903f, + -0.07440855f, -1.14010465f, 1.25383139f, -1.58615100f, 1.04185510f, + 0.15140508f, -0.88059032f, -0.33872122f, -0.42526904f, 2.17365575f, + 0.29308075f, -2.24234557f, -1.03164542f, -0.09263755f, 0.08050421f, + -0.74946511f, -0.64589006f, -1.13416314f, -0.64989561f, 0.16502371f, + -0.33831969f, 0.22832428f, -0.08389475f, -0.28009200f, 1.34536922f, + -0.19075738f, 0.36238208f, 0.83690089f, 0.26144615f, 0.04457319f, + -2.55585861f, -0.01807522f, 1.68334866f, -0.05795629f, -0.21315987f, + -1.84039557f, 0.06512877f, -1.77318645f, -0.27637982f, 0.20439345f, + 0.67558700f, -0.77179354f, -0.17902173f, 0.70381826f, -0.40395790f, + -0.96492916f, 0.84138173f, 2.43879008f, -0.32297835f, -1.74370265f, + -0.10330839f, -1.07465363f, 1.85030377f, -0.59153467f, 0.99667048f, + -0.56753993f, 0.57383025f, -1.90630126f, 1.24299097f, 0.22797665f, + 0.30468231f, -0.07360230f, 1.64654350f, 0.57195550f, 0.03227921f, + 1.11005175f, 0.00088721f, 1.19266295f, 0.61323351f, 0.13754399f, + 0.59900171f, -0.75831634f, 1.11500823f, 0.99747783f, -1.36923385f, + 1.26563418f, 0.01253266f, 0.35483193f, 1.95143735f, -2.02703261f, + -1.38265920f, -0.02404256f, 2.02788448f, -0.75144875f, -0.58445263f, + 0.26129767f, 0.60691077f, -1.84661067f, 0.65872228f, -0.58298993f, + 0.33067298f, -0.09431327f, 0.43333948f, -1.52616286f, -0.25961858f, + -1.65459549f, -0.72950101f, -0.89906919f, -0.80081612f, -1.32189929f, + -1.36574399f, -0.35809481f, 0.36385000f, 0.31480747f, -0.35797358f, + -1.04066050f, 0.07971872f, -0.21176252f, -0.76559299f, -0.10352154f, + 0.29248312f, -1.75030553f, 0.68219930f, 0.56189102f, -1.11212170f, + 0.06501702f, -0.07131009f, 1.23410738f, 0.29311740f, -1.02052307f, + 1.40220940f, -1.00995779f, 0.57955760f, 0.22640309f, 0.74853230f, + -0.02586563f, -0.33427954f, 1.70311153f, -0.53405988f, 0.90975094f, + -0.46450076f, 0.19904344f, 0.28559047f, 0.23167793f, -0.69065529f, + -0.17176504f, -0.29301846f, -0.85477978f, -0.00267053f, -0.28529504f, + -0.64201307f, 1.03479636f, 1.03805065f, 0.83270210f, -0.09405448f, + 2.50615931f, 0.62019676f, 0.31354564f, -1.51599669f, 0.42848015f, + 0.66263914f, 0.74651009f, -1.13042867f, -0.58933645f, -0.35146511f, + 0.06223279f, 0.28065836f, 0.66506970f, 0.16942430f, -0.23316263f, + -0.87481076f, 1.21992230f, 1.48536301f, -0.79667616f, -0.75519305f, + 1.40999961f, -0.42802793f, -0.20252463f, 0.30573779f, -0.23319976f, + 1.77525878f, -1.80704832f, 2.71519923f, -0.67500192f, 0.12268137f, + -0.13014549f, -0.07479453f, -1.51065743f, 1.04198146f, 0.96205556f, + -2.00525570f, -0.37911776f, 0.89329720f, -0.39495832f, -0.03683375f, + -0.90928614f, -1.56263304f, 0.45038295f, -2.62184358f, -0.45686841f, + -0.52536523f, 1.05351484f, 0.89982438f, -0.63724512f, 3.21004057f, + -0.08608918f, 1.55209303f, 0.62688643f, -0.59702635f, 1.85774517f, + 0.38172096f, -1.25640929f, -2.59278178f, 0.85050315f, -1.10080361f, + -1.26422560f, -1.80045366f, -0.34494889f, 0.68448657f, 1.25671864f, + -1.26594126f, 0.32244179f, -0.51956522f, -0.56212711f, -0.95574015f, + 0.71973872f, 0.46736258f, -0.11772985f, -1.52736545f, 0.19571695f, + 0.73147154f, 0.87724912f, -0.26265728f, -2.60267401f, 0.19263546f, + 0.18320183f, 0.11485019f, -0.82999659f, 0.13582672f, -0.08040185f, + 0.28152901f, -0.51421624f, -2.32467175f, 0.19923948f, 0.64616692f, + 0.29718629f, 0.32785949f, -0.62266952f, -0.98174316f, 1.23276305f, + 0.58563638f, 1.28528512f, -2.13718534f, 0.28842899f, 0.12676710f, + -1.72105229f, 0.15053287f, 2.19496536f, 1.28683448f, -0.96318281f, + 0.17043279f, -0.05245409f, -0.38710704f, -0.30441490f, -0.08249986f, + 0.28423953f, 0.72963721f, -1.49658203f, 0.99077344f, -0.78913772f, + -1.12661564f, -1.26294816f, 0.16517465f, 0.10124251f, -0.77198768f, + -0.16342169f, 0.08615876f, 0.49711797f, -0.66083062f, 0.76648003f, + 1.04756033f, 1.46122825f, -0.42798752f, -2.29203916f, 0.30444992f, + 0.58697921f, 1.22166932f, 0.09022947f, -0.03920181f, 0.10444995f, + 0.10361757f, 1.18224072f, -0.76641631f, 0.90802073f, 1.41639423f, + 1.55682337f, 1.28101575f, -0.35396016f, 1.11443567f, 1.18218529f, + -0.06048089f, 0.85024464f, -1.01789165f, -0.69154263f, 0.06663221f, + 0.68429029f, 0.12560424f, 0.37915874f, -0.66829866f, -0.64524972f, + -0.05568011f, 0.12230454f, -0.35041061f, 0.62027830f, -0.16739209f, + -0.72145337f, 0.46263054f, -1.67837834f, 0.69413221f, -0.57243419f, + 0.37638462f, -0.21446526f, -0.89821470f, 0.60078722f, -1.06706369f, + -1.26132309f, 0.35714921f, 2.39221811f, -0.09376130f, 0.30760849f, + 0.59180892f, 0.55815399f, -0.32628775f, 1.28890121f, -2.53237987f, + -0.98241091f, 1.10520673f, -1.74751687f, -0.90837651f, -0.25220659f, + -0.56625104f, -0.30691949f, 0.16058689f, 0.44309673f, -1.09874964f, + -0.76747823f, -0.33679363f, -0.02535496f, 0.00990100f, 1.35318136f, + -0.70140815f, 0.50937581f, 0.55386209f, -1.21721983f, 0.71376961f, + -0.18079315f, -0.11077732f, 0.09292522f, -0.57235324f, 0.62748206f, + 0.42587611f, 0.64860481f, -1.10635614f, 1.66414368f, 0.47505483f, + 1.48602211f, -0.59611166f, -0.41932896f, -0.96542233f, -0.41756630f, + -1.02963889f, -0.70070386f, 1.65803933f, 0.20138647f, 0.05895034f, + -1.46152759f, -0.37278318f, 1.05535650f, 0.34437978f, -1.13257408f, + 0.17635690f, 0.09386671f, 0.37079874f, 1.47695887f, -1.58420062f, + -0.26100200f, 0.44847637f, 0.88847303f, -0.13877590f, -0.64620668f, + -0.38019657f, 1.01608157f, 0.13357787f, 0.05137976f, 0.93498152f, + -0.62226880f, 0.80461699f, -0.71682596f, -0.88756353f, 0.40933055f, + -1.52167451f, 0.79756850f, -0.17307425f, 0.62368619f, -0.22466940f, + -1.72802913f, 0.59047443f, -0.58020931f, 0.09096476f, -0.07317388f, + 0.44522321f, -0.64880705f, 0.15684015f, 0.08708375f, -0.41556796f, + 1.11579072f, -0.81733495f, 0.11643656f, -0.73995101f, 0.93685871f, + 1.57971406f, 0.67606360f, 0.70509088f, -0.25283816f, -0.00010609f, + -0.61884147f, -0.86409342f, 0.95383751f, -0.05895388f, -1.45261180f, + 0.45166013f, -1.01434863f, 0.18496066f, 1.06517637f, 1.81127059f, + 0.89470667f, -0.13232610f, 0.46958798f, 0.13884509f, 0.57117194f, + 0.29575035f, -0.97884250f, 0.83291447f, -0.59255791f, -0.04354135f, + -0.19431923f, 0.30071029f, -0.95421529f, 0.76359886f, -0.47799742f, + 0.68254346f, 1.19368529f, -0.48935115f, 0.30357337f, -0.50225669f, + -0.23370270f, 1.96702433f, 1.46558523f, 2.68482018f, 0.41622332f, + 0.73697484f, 1.43430734f, 0.15387188f, 0.20875402f, -2.49335337f, + -1.39674246f, -0.22125854f, -0.00424605f, 0.91416460f, 0.33384630f, + 0.44703746f, 0.25610185f, 0.38966551f, -0.01784045f, 1.66148460f, + 0.36005461f, 0.95716912f, -0.18246566f, -0.15480693f, 0.38775176f, + -0.56969136f, -0.29644895f, -1.04565966f, -1.00455630f, 0.30897698f, + -1.46885884f, 0.03657720f, -0.49302089f, 1.34134722f, 0.01673754f, + 1.22725964f, 0.55256772f, 0.63803208f, -0.29041430f, 1.11455286f, + 0.76329172f, 0.27073982f, 0.77173829f, -1.79884446f, -0.11889492f, + -1.92040312f, -0.46382675f, 0.20078070f, -0.98889589f, 1.46711135f, + -1.68280172f, -0.52852470f, 0.66245162f, 0.29575166f, 1.34826505f, + -0.22362417f, -0.14345661f, -2.34815073f, 1.26572001f, 0.66505629f, + 1.01141500f, 1.08030057f, 0.17036134f, 0.00168786f, -0.37282917f, + 0.69206375f, 1.07367527f, -0.49708191f, 1.49504781f, 0.58224988f, + 0.96593714f, -1.07661915f, 0.25202179f, 0.25531644f, 0.42357162f, + -0.31236249f, 0.48383278f, -0.06361829f, 0.24131298f, -0.95695931f, + -0.12589653f, 0.36134180f, 3.20266032f, -0.40879184f, -0.66985190f, + 1.51674330f, 0.34072638f, 1.15076303f, -0.40199137f, 0.46223637f, + -0.48608047f, 0.99119538f, -0.22506073f, 0.30968750f, 0.64210880f, + 0.54640514f, 0.18607031f, 1.26293361f, -0.77960914f, 0.79572529f, + 1.01936150f, 2.27160740f, -1.48034489f, 0.74466604f, 0.14863680f, + 0.31102443f, -1.15673816f, -0.38609681f, -2.65026069f, -0.45524642f, + -0.74022961f, 2.74991131f, 0.00103815f, -3.03303242f, -0.41556966f, + -0.87103498f, 0.78306234f, -0.88195556f, -0.77297026f, 1.21203196f, + -1.09754920f, -0.03556008f, -0.31546223f, 0.72954375f, 0.25251788f, + 0.11378583f, 0.50921023f, 0.30301905f, -1.60631680f, 0.27152416f, + 1.17342317f, -0.70891970f, -0.08392961f, 0.92137378f, -0.10568139f, + -0.31653777f, -0.28878728f, 1.22166574f, 1.12693942f, -0.21325994f, + 0.94010323f, 1.21796405f, -0.68866694f, 2.30724216f, 0.28141466f, + 0.83481526f, -0.04885862f, 0.01675143f, 1.04355800f, -0.81050140f, + 1.51300573f, 0.53429186f, -0.56439877f, 0.38572624f, -0.05620475f, + 0.67644542f, 0.72528905f, 0.05937041f, -1.06315899f, -0.51393986f, + 0.46937627f, -0.34699562f, -0.64765716f, -1.45512629f, 0.47739139f, + -0.88228017f, -2.00791359f, 1.29929042f, 0.05482405f, -0.66725296f, + -0.54735124f, 0.09972951f, 0.76675093f, 0.98748523f, 0.08900899f, + -0.78854066f, 1.47970486f, -0.61667502f, 0.45625573f, -0.21766303f, + -0.46250847f, -0.07130960f, 0.64414692f, 0.12784545f, 0.26393634f, + 1.07720757f, -1.23938286f, 0.62483376f, -0.55001754f, -0.05358591f, + 0.07322436f, 1.12003291f, -1.00830650f, -0.20486419f, 0.76664752f, + 0.28850746f, -0.04464776f, -0.40146068f, 0.73262817f, -1.12827921f, + -0.19989438f, -1.15999687f, 1.37973154f, 0.78881019f, -0.34762639f, + 1.22088552f, -1.64088547f, 0.63218033f, 0.45736769f, 0.05502866f, + 2.22683382f, -1.78935897f, -1.49635041f, 0.83450896f, 1.67770112f, + 1.33909333f, 1.51158953f, 0.28595078f, -0.08593627f, 0.45812801f, + -0.15193029f, 1.14770603f, -0.88920450f, -1.96352005f, -1.49894583f, + 0.49629962f, 1.59872091f, 0.00903497f, 2.15563583f, 2.25149560f, + -2.01200557f, 2.56229877f, -1.38850498f, 0.73552012f, -0.39378855f, + 0.52616280f, -0.03685786f, 0.87403935f, 0.12163408f, 0.74297994f, + -0.30697080f, 0.38139752f, 0.49113834f, -0.95485127f, -0.99908817f, + 0.71716321f, 0.04000283f, -2.09645271f, 1.38789880f, 1.37198520f, + 0.82493287f, 0.17114936f, 0.53696346f, -0.19516060f, -0.50377476f, + -0.91730285f, -0.70113552f, -0.02406530f, 0.84943396f, -0.17428185f, + -1.09140801f, -0.68156958f, 1.70756388f, -1.00399911f, 0.03023832f, + -0.39023280f, -1.89737976f, 1.14469039f, -0.58337289f, -0.60037899f, + -1.17490256f, -1.56342828f, 0.48714057f, 0.62266618f, -0.15967095f, + 1.32789338f, -1.25700688f, -0.55633998f, -0.83128709f, -0.49346271f, + 1.59561753f, -0.24675299f, 0.38012561f, 0.91796309f, -0.38522810f, + -0.65509188f, 0.94100451f, -0.57324487f, 2.19070768f, 1.24058700f, + -0.75978851f, -0.40460554f, 0.79189235f, 0.70192885f, 1.93569362f, + -0.03070199f, 0.77010989f, 0.58794290f, 0.51087004f, 0.22892070f, + 0.35007235f, 1.56023848f, -0.67453802f, -0.18485607f, 0.64349502f, + -0.31489357f, -1.95834625f, 0.06560058f, 2.30394220f, 1.18194163f, + -0.88034087f, -1.05000436f, -1.05471325f, -0.98481798f, 0.49904808f, + 0.16438948f, -1.10297823f, -1.39736509f, 0.01306054f, -1.85160267f, + -0.87292641f, -0.15418227f, 0.43412164f, 1.16518164f, 0.06273691f, + 0.24659210f, -0.08267246f, 1.28885782f, 0.73575675f, -0.01019809f, + -0.08753663f, -0.61827368f, -0.40863234f, 2.12599611f, -0.53620332f, + 0.53789747f, -0.66386080f, -1.70461988f, 0.86608189f, -1.11151052f, + 0.14120635f, 1.18858743f, -0.31760478f, -0.73533046f, 0.20978074f, + -0.84074509f, 0.16523147f, -1.03362834f, 0.59721231f, 0.21318658f, + 0.23671274f, 1.75115061f, 0.25363782f, -1.32541454f, 1.13056135f, + 0.24652456f, 0.60381413f, 0.21478581f, 0.75044096f, -0.63125616f, + -1.69889998f, -0.02116571f, 1.46165359f, 1.03068244f, 0.63693464f, + 0.67795700f, 1.20033514f, -1.39205134f, -0.61743122f, 0.56549704f, + 0.65182322f, -0.74250507f, -1.61939359f, 1.14054918f, -0.45725963f, + 1.74519682f, -0.66251940f, -0.94811529f, -1.60865819f, -0.59968346f, + 0.86309159f, -1.91936195f, -1.02646923f, -1.50352538f, 0.58292735f, + 0.05320299f, 1.53582895f, 0.01069612f, 0.15226212f, -0.71840125f, + -1.36896348f, 2.14600968f, 0.96626586f, -0.52014917f, 0.41001406f, + 0.59478027f, 0.15282436f, 0.27790198f, 0.76614654f, -0.38971323f, + -0.01839927f, -1.57882118f, 0.61391610f, -0.62133092f, -0.03968323f, + -0.88467252f, -1.24041140f, 2.07306671f, -0.41776338f, 0.14537935f, + -0.91069067f, 1.67362070f, 4.72630215f, -0.07395106f, 0.46280116f, + -0.40843824f, 0.70683080f, -0.27510864f, -0.63465804f, -0.83630908f, + -0.44419941f, 0.60405648f, -0.65039170f, -1.02413189f, 1.05983019f, + 1.73366308f, 0.73343736f, -0.00895882f, -1.00826013f, 0.17323074f, + 0.73995626f, 0.24128854f, 0.94510227f, 0.25557515f, 0.02244723f, + -0.95197725f, -0.16297856f, -0.38497585f, 1.17993331f, 1.20282137f, + -1.31491220f, 0.44229278f, -0.24349044f, -0.01230415f, 1.37944865f, + 0.48554277f, -0.54510897f, -0.10793537f, 0.41121426f, -0.12889031f, + 0.26434359f, 1.27966082f, 0.64518744f, -0.15577169f, -0.99864733f, + -0.61746484f, 2.01614976f, 1.56254935f, 1.86473298f, -0.54662132f, + -0.22047071f, -0.06118120f, 0.84799510f, 0.17009684f, -1.30523121f, + 0.64000309f, 0.36299205f, -0.59620583f, 1.36372304f, -0.05389515f, + -0.93849313f, 0.98043185f, -0.39373067f, -0.84898937f, 1.32077873f, + 1.05988657f, -1.35339200f, 0.23259017f, 0.63816410f, -0.80297333f, + 0.60017115f, 1.25715804f, 1.18894124f, -0.62473553f, 1.05611980f, + 0.02335166f, 1.07509828f, 0.25873449f, -1.68341100f, 0.54547334f, + 0.79288185f, -0.93678916f, 0.19202201f, -1.48575914f, 1.08649087f, + 0.50851744f, -0.45758674f, -0.39734635f, 0.35637981f, -1.63079453f, + -0.75910008f, 0.92640859f, -0.55599529f, -0.40276715f, 0.31307653f, + 0.39907026f, -1.18830419f, 0.71051043f, 0.14157933f, -0.39581308f, + -1.64361024f, -0.06161860f, -0.25312796f, 1.10018682f, 0.56500763f, + 0.80385065f, 0.35395023f, 0.81813669f, 0.27644628f, 0.65563256f, + 1.73197234f, 0.68178749f, 0.76769936f, 0.44597456f, 0.67761195f, + 0.67635447f, -0.32315412f, 0.19330767f, -0.25557944f, 1.91693723f, + 0.38335562f, 0.07107610f, -0.57384586f, 0.79184365f, 1.87835479f, + 0.60902315f, -0.94220877f, 0.79479855f, -0.25656971f, 0.08739131f, + 0.53384244f, 1.22159266f, -0.39152125f, -1.46373534f, -0.02458516f, + 1.62825716f, -1.26112676f, 0.19967082f, -0.71114451f, 0.27929229f, + 0.65001321f, -0.11868202f, -0.55587751f, 0.78069001f, 0.57969242f, + -0.60274386f, 0.31650013f, 0.90339553f, 0.09453616f, -0.37119162f, + -1.00320566f, 0.33299938f, -0.48636708f, 0.26342997f, -0.91914523f, + 0.28682709f, -1.24780893f, -1.59254742f, 0.97176319f, 0.14744301f, + -0.53056234f, -1.73221612f, -0.67645556f, 0.98705006f, 0.79895812f, + -2.04333115f, -0.60132772f, -0.91653955f, -0.28094748f, 0.47943443f, + 0.38157779f, -0.67648011f, 1.09093642f, 1.66012859f, -0.29358891f, + -1.26773024f, 0.36747769f, -1.10141146f, 0.82383633f, -0.89772314f, + -0.47145563f, 0.63939518f, -0.64430422f, -0.48889321f, -0.37680882f, + -1.06962025f, -1.28689516f, 1.28365147f, 0.61859220f, -0.84676331f, + 1.38404000f, 1.21053445f, -0.14871351f, 1.06349385f, 1.45878971f, + -0.47362664f, 1.40707004f, 1.25224137f, 0.87364739f, 0.92858213f, + 0.00157326f, 1.45661485f, -0.27318576f, 0.15482858f, -1.07058907f, + -0.06903186f, -0.74147576f, -1.64111829f, -0.67226541f, -1.13458407f, + 1.28511488f, -0.41041154f, 2.09085560f, 0.45243183f, -0.67437285f, + 0.84960121f, -1.49300814f, -0.42961186f, -2.35021853f, 0.57255560f, + -0.73903763f, 1.37607956f, -2.44575167f, 1.25105727f, 1.38575912f, + -1.16299784f, -0.13719854f, -1.11507034f, 0.35796806f, -0.64511567f, + -0.87903833f, 0.32833642f, -0.87696886f, 0.02714214f, 0.30224666f, + -0.69118696f, -1.23500824f, 0.76678628f, -3.20508122f, -0.24704689f, + 0.49019828f, -1.20862615f, -0.03778638f, -0.07273687f, -0.11517122f, + -1.75857520f, -1.64188445f, 1.21574795f, 0.57325113f, 1.14370298f, + -1.07824504f, 1.70653832f, -0.03700557f, -0.47645858f, 0.11065386f, + -1.03143036f, -2.18094873f, -0.94403434f, -0.09335683f, -0.44817665f, + 1.39707148f, -1.21947956f, 0.56575936f, -0.69612634f, -1.12361753f, + -0.17105591f, 1.15422392f, 0.02840637f, 0.09469353f, -0.52859986f, + -2.08487725f, 1.28789508f, -0.03740775f, 0.61196613f, 1.23405397f, + 1.56595814f, -0.65800631f, 2.02985072f, -0.69446486f, -0.88443804f, + -0.23448054f, -0.43628734f, -0.45888957f, -0.21943338f, 1.78258693f, + 1.75214970f, 0.71804136f, 0.49782532f, 0.37886053f, -1.59176385f, + -1.74758542f, -0.02820176f, 0.75398153f, 1.00119829f, 0.80881971f, + -0.53365272f, -0.22720885f, 0.37476870f, 0.01005529f, -1.23421800f, + -0.13431595f, -1.01843679f, 1.87386346f, -1.68539488f, -1.04942071f, + -0.77322137f, 0.53964764f, 0.29278332f, -0.58299130f, -1.56022692f, + -0.79441273f, 0.49289709f, 0.44112054f, 1.07305002f, 0.54899335f, + 1.13781393f, 0.77809113f, 0.81795985f, 0.16576190f, 0.32552773f, + -0.20250474f, 1.46543837f, 0.12731771f, 0.21013761f, -1.34241438f, + 0.44267517f, 0.93246883f, 0.08808212f, 0.92653406f, -1.21083558f, + 0.17247954f, -0.70557106f, 0.04630012f, 0.48834828f, 0.89634645f, + 0.46683592f, -0.29553145f, 0.46363977f, -0.48971879f, -0.88603491f, + -0.12333342f, 0.37073737f, 0.92061806f, 0.54675460f, -0.14716248f, + 0.75578392f, -0.98173791f, -1.15983224f, -0.58713156f, 0.07950903f, + -0.59016788f, 0.41622928f, -0.32474482f, 0.42086437f, 0.23061797f, + 0.62596649f, -0.22615278f, -2.14721417f, 1.01685894f, -0.25976995f, + 0.00739352f, -1.31597066f, 0.39005190f, -1.09549701f, 1.68375242f, + 0.43331525f, -0.37124026f, 0.22255214f, 0.59654880f, -0.73840386f, + -1.20048976f, 0.12226126f, 0.12997478f, 1.04826224f, 0.03894836f, + -0.36289826f, 1.14466560f, -1.18198848f, -0.03713558f, 0.67677927f, + -0.42329931f, -0.89409167f, -0.77874780f, 0.58438253f, -0.35176343f, + -1.53329861f, -0.02995299f, -0.40145162f, -1.51052392f, 0.09194464f, + -1.13275242f, -0.61983156f, -0.40004560f, -0.19893464f, 0.22134103f, + -0.03903082f, 1.14894116f, -0.03476744f, 0.22520730f, -0.55851930f, + 0.76650429f, -0.57863152f, -1.34161711f, -0.31498179f, -1.19411755f, + 1.70044947f, -0.17428267f, -0.35983825f, -0.42613637f, 0.58165723f, + -0.77866900f, -1.59727287f, -0.61723864f, 1.51078022f, 0.32971445f, + -0.86441469f, 0.60552609f, 0.00208178f, -0.47096625f, -1.10479307f, + -1.21652532f, -0.08211990f, -1.43739200f, -1.31684434f, 0.43312529f, + -0.76822090f, 1.88128507f, -0.02179282f, 1.04971325f, -1.55004108f, + 1.25337446f, 0.11203052f, -1.16048300f, 1.59467411f, -1.29469275f, + 1.14019871f, 1.20021439f, 1.84098923f, 0.05004879f, 0.73529941f, + 2.05272865f, -0.13080600f, -0.08436690f, -1.17919350f, -0.66256678f, + -0.36727047f, 0.73840511f, 1.22293818f, -0.00206342f, -0.29839504f, + -0.00618613f, 1.04213119f, 1.21176076f, -0.62886089f, -0.02589060f, + 0.96009409f, -0.64478731f, -1.16516542f, 0.57528079f, 1.04294407f, + -0.09774588f, 0.45935291f, 1.03263175f, 1.00633478f, -1.82209253f, + -0.18035053f, -0.28302726f, -0.83813244f, 0.57593471f, -0.03807700f, + 1.60498738f, 0.16530658f, -1.43083501f, 2.10824299f, 0.30279446f, + -0.03961089f, -0.38900724f, 1.31272805f, -0.56575215f, 0.57970244f, + -0.48305038f, 1.34114623f, 0.21859215f, 0.66399640f, -1.52087069f, + -1.30717897f, 0.14394683f, 0.97648209f, -0.71372712f, -1.22574198f, + -0.27702177f, 0.04041927f, 0.02442212f, 2.19617033f, -0.48566443f, + 0.81463927f, 0.20383844f, 1.17562282f, -0.33829874f, -0.42141283f, + -0.96415234f, -2.39141965f, -1.04285860f, -0.23004992f, 0.41186509f, + 0.03811268f, 0.36818987f, -0.71099734f, -0.56749570f, 0.18486284f, + -0.44530040f, 2.14008284f, -0.27467576f, 1.70690107f, -1.40462613f, + 0.24697532f, -1.31629777f, -2.20674944f, -0.67868507f, -1.15767133f, + -0.64391804f, -1.79037917f, 0.58749497f, -1.58303332f, -0.69021022f, + 1.64376318f, -0.95393223f, 1.98415601f, -0.10991055f, 0.02474386f, + 0.23683345f, -0.63420391f, -0.57991928f, 0.83028817f, -0.40033704f, + 0.19212338f, 0.74640590f, 1.10264432f, -1.65286255f, 0.92683482f, + -1.42252541f, -0.74605089f, 2.14535880f, 0.12971123f, -0.47971717f, + 1.67546797f, 0.42268261f, 0.22648531f, -0.42369929f, 0.77403021f, + -1.31818616f, -0.67143595f, -0.04311426f, 1.64128351f, 0.34776631f, + -0.39353722f, -0.42765084f, 0.16170517f, -0.54488391f, -0.38428506f, + 0.42097485f, -0.55982012f, -1.74543798f, 1.53704774f, 0.43562424f, + -0.30395737f, 0.31846946f, 0.39205357f, 0.57386035f, -1.11912560f, + -1.39164317f, -1.04337609f, 0.31629622f, 1.51927638f, 0.88745505f, + -0.40445471f, 0.25783861f, 1.88646257f, 0.36509129f, -1.13266826f, + -0.45394278f, -0.48400903f, -1.22332740f, 0.38626808f, -1.10049105f, + 0.84138852f, 1.27863181f, 0.53942156f, -0.67743856f, -0.03896645f, + 1.70393491f, 0.60997570f, 0.43368068f, -0.13338457f, -0.18920666f, + -0.29583672f, -1.40738738f, 1.03876019f, 1.71253765f, 2.12821221f, + -0.96092403f, 0.93841934f, -0.79030478f, 1.36427641f, -1.39196694f, + 0.08514920f, 0.16223004f, 0.71259701f, 0.20150672f, 0.25068361f, + -0.99952722f, 1.80129099f, -1.28586197f, -0.64957166f, -0.94813949f, + -0.40161121f, 0.31977695f, 0.54932386f, -0.67757767f, 1.88086259f, + 0.92337233f, -1.64887333f, 0.44333732f, -0.19468001f, 0.12977587f, + 0.21171951f, 0.27679422f, 0.49134475f, -1.44429457f, 1.25617445f, + 0.39978400f, 0.99869555f, -1.61617446f, 1.61177349f, 0.70243025f, + -0.95748568f, -0.61795151f, -0.77302909f, 0.72967088f, 0.81964350f, + -0.71813750f, 0.90140164f, -1.45950246f, -0.79972702f, 0.40875742f, + 0.00152073f, -1.74491429f, 1.53776145f, 0.75769204f, -0.22075878f, + -0.58385569f, 2.18884754f, 0.33597681f, -1.66265559f, 1.03805876f, + -1.55245185f, -0.03582226f, -1.94542754f, -0.76081425f, -0.50471377f, + 1.35763168f, -0.39631784f, -0.17134467f, -0.82220149f, -0.41021580f, + -0.00940776f, -0.80176353f, -0.19816744f, 1.22061026f, -0.14486519f, + -0.71727395f, -0.65721530f, 0.47020102f, -0.70403302f, -0.94795334f, + 1.79884899f, 0.07779162f, -1.50615680f, 0.04140327f, -0.22001404f, + 0.63735324f, 0.79237640f, -2.25412822f, -0.52519119f, -0.87280381f, + -0.07100742f, -0.94734806f, -0.12286110f, -0.13623615f, -0.42595413f, + 0.17547913f, -0.81707209f, 0.36855817f, -1.68186557f, 0.19312963f, + -0.66249490f, -0.98283452f, -0.33314428f, 0.40918943f, 0.88268638f, + -0.05390308f, -0.22440539f, -0.15879378f, -0.34859571f, -0.01013108f, + -0.30005428f, -1.19408464f, 0.21789688f, -1.07769871f, 0.81475031f, + -0.69555300f, 2.35201311f, -0.40362412f, 0.93497628f, 1.13343573f, + 0.92343372f, 0.26987928f, 0.46123627f, 0.22577702f, 1.26289701f, + -0.45956740f, 0.55994868f, -0.58410591f, 0.13304594f, -0.25806463f, + 0.49044946f, -0.82065403f, -3.06672239f, -0.27774641f, 0.68504512f, + -0.21386372f, 1.11427057f, -0.73201770f, 0.51655543f, 1.77261138f, + 0.72081727f, 0.11116749f, 0.16637769f, -0.74987584f, 0.66579849f, + -0.75808716f, 0.20678560f, -0.67698354f, -0.82141948f, 0.61008269f, + 0.66520184f, 0.44894725f, 0.73015076f, -1.52517414f, 0.11714164f, + 1.90452611f, -1.30355322f, 0.12144456f, 1.18547559f, -0.07349755f, + -2.28061509f, 0.83522540f, 0.78438890f, 2.19334102f, 0.90305614f, + -0.59345531f, 0.77925014f, 1.32338643f, 0.14068902f, 1.19032264f, + 0.20666829f, -0.76595837f, 0.74967057f, 2.86965609f, 0.55690205f, + -1.72530472f, -0.83317834f, -0.85842621f, -0.29678273f, 1.80955839f, + -0.70496303f, 1.19106734f, -0.92985237f, -1.00617313f, -0.56049556f, + -0.29382578f, -2.04022193f, -1.95356870f, -0.42553005f, -0.33369407f, + 1.02115977f, -1.45769477f, -0.67720300f, 0.53819913f, 1.57643425f, + -0.47015440f, -1.47861958f, -0.00545934f, -0.97836047f, 0.42680529f, + 1.56110144f, -1.49487829f, -0.65198445f, 0.22720462f, 1.83036661f, + -0.47099793f, -0.09915133f, 0.14923312f, -1.16313052f, 0.67798084f, + -1.63665557f, -0.38220280f, 0.01719763f, 0.30041245f, 0.43148938f, + -0.44021657f, -1.25734651f, 0.02465564f, -1.00845659f, -0.28574651f, + 0.01367745f, 0.77253437f, -0.99399441f, 0.61445391f, 0.18343423f, + -0.50997210f, 0.41359940f, 0.77279282f, 0.83511519f, 0.27929801f, + 0.70800692f, -0.20278299f, 1.57884383f, 0.22650529f, 0.43347472f, + 0.74003208f, -0.71401161f, -0.69829476f, -1.56766701f, -0.99254119f, + 1.27301061f, 2.73726511f, 0.66089469f, -1.95778012f, -1.24642098f, + -0.63579029f, -1.63168180f, -0.66980726f, 0.81933254f, 0.61866677f, + 1.40594471f, 0.05158535f, 0.00196500f, -0.24592508f, -0.50780547f, + -0.83905292f, -0.10748957f, 0.04490763f, 0.27769178f, -0.23227681f, + 0.82108080f, 0.03562285f, 0.95483875f, -1.49897683f, 0.67809856f, + 0.35497451f, -0.44021592f, -1.67361462f, -0.88895375f, 1.44293678f, + -0.85046643f, -0.46437624f, -1.87252641f, 0.26775804f, -0.24535774f, + 0.73365933f, 0.52253938f, 0.27947086f, -0.58796054f, 0.59045380f, + 1.93476331f, -0.46775359f, 0.25238225f, -1.26601815f, -0.13324316f, + -0.71454948f, -0.21610366f, -1.49586582f, 1.04903507f, 0.22208478f, + 0.25512528f, -0.46157327f, -0.41319233f, -0.63846964f, -0.25100923f, + 0.81277549f, -0.26959971f, 0.88737756f, 1.24578953f, -0.91121447f, + -1.05756927f, 0.44390878f, 0.16672316f, -1.22941923f, 0.89547867f, + -1.50212002f, -1.69620168f, 0.53339505f, -0.23656729f, -1.69879091f, + 0.01510374f, 0.08315694f, -0.73196459f, -1.60263407f, -1.07601058f, + -0.76389569f, -1.65307498f, -0.61484390f, -0.43546933f, 0.71318507f, + -0.16273083f, 0.64122051f, -0.15406294f, 1.17673671f, -0.91240519f, + 0.71091145f, 2.40497613f, 1.26343656f, 0.71469337f, 0.20705548f, + 0.81776261f, 0.36253929f, -1.92106628f, -0.09300470f, -0.36648872f, + 1.27732766f, -0.39180157f, -0.61186749f, -1.03455031f, -0.25079829f, + -0.61479062f, -1.07094336f, 0.82218504f, 0.89934880f, 0.41308978f, + -0.59968555f, 0.37682834f, -1.77388155f, 0.00294951f, -0.66145372f, + -0.50789726f, -0.85123241f, -0.89909405f, -1.89454281f, -0.56692821f, + 1.52272677f, -0.11961794f, 0.27843913f, -0.60582250f, 1.01871169f, + -0.36098275f, -0.12242325f, -0.67375034f, -0.11204147f, -2.62773919f, + -0.95901299f, 0.14040214f, 1.32364666f, -1.35099924f, -0.11077739f, + -0.79319423f, 0.75949597f, -0.25485823f, -0.90959758f, -0.42373934f, + -1.29850340f, 0.85699379f, -1.11882365f, 0.63470817f, 0.49696380f, + -0.07983235f, -0.23903450f, -0.22618714f, -0.12117998f, -0.09442677f, + 1.55589819f, -0.11996678f, -1.72700179f, 0.54683149f, -0.40804827f, + -0.50099218f, 0.34596699f, -1.81841791f, 0.06385052f, 0.84428120f, + 0.69901514f, 1.94559097f, 0.43251973f, 0.16794942f, 1.82829034f, + 1.70959795f, 0.36130908f, -0.94608402f, -0.53498030f, 0.47781768f, + -0.24203247f, 1.25065851f, 0.51788396f, -2.09381890f, 0.72973937f, + 0.03281829f, 0.58632666f, 1.85737121f, -0.49569523f, 0.45921183f, + 1.87173629f, 0.22803484f, 1.66433418f, -1.05872321f, -1.13663685f, + 0.12397861f, -0.65112090f, 0.98152941f, 0.83739656f, -0.18783289f, + 1.84249437f, -0.90706986f, -0.80824369f, -1.23854923f, -0.86488134f, + -1.02627063f, 0.10976455f, -0.61403006f, 1.27554715f, 0.14653525f, + -0.03953953f, -0.08512071f, -1.30043304f, -0.02566035f, 0.12054887f, + 0.00282162f, 0.48921332f, -1.74398839f, 1.44554436f, -1.35854721f, + 0.69256759f, 0.34101671f, 2.50045252f, 0.49121150f, -0.27115449f, + 0.93974596f, 0.26258010f, 0.27151433f, -0.87214381f, -0.92580765f, + -1.03269923f, 0.20615758f, -0.37822601f, 0.58983004f, 0.16426525f, + 0.68218285f, 1.98158526f, 0.47492698f, 0.54224718f, 1.28722692f, + -1.76915324f, -1.11240053f, 0.77428484f, 0.27184650f, 2.22473478f, + -0.05574624f, 0.39976570f, -0.43911108f, 0.52805597f, 0.17340177f, + 1.36057591f, -0.35004014f, 1.72787797f, 0.68357420f, 1.25532615f, + -0.56752264f, 0.51840127f, -0.21237844f, -0.58821255f, -0.85278064f, + 1.90179110f, -0.67447448f, -0.36831430f, -0.22930753f, 0.98231596f, + -0.07011599f, -0.08560387f, 0.05998110f, -0.02481356f, -0.57335132f, + -0.44288307f, -0.24468307f, 0.53321087f, 1.19609559f, 0.10664973f, + 0.24379487f, 0.93687552f, 0.93615580f, 1.74319768f, -0.68310338f, + 1.32163060f, 0.61918712f, -0.76501870f, -0.54549301f, 1.74077415f, + -0.69977754f, -0.66880983f, -1.15981388f, 0.81571609f, 0.53788543f, + 0.47898352f, -0.02484704f, -1.64646924f, -0.69822907f, 0.27020717f, + 0.05027051f, 1.75149667f, 0.01548872f, 0.32615909f, 2.55151844f, + -1.29172051f, -0.36133784f, 0.98637396f, 0.14009331f, -0.50038946f, + -0.92230296f, 0.17307127f, 1.05361068f, -1.46784890f, 2.38960409f, + 1.19413340f, -1.33349669f, 1.59141159f, -0.71811068f, 1.22429430f, + 1.26947939f, 1.08177102f, -1.18138707f, -0.72775704f, 0.17282635f, + -0.40554270f, -0.40341887f, 0.46564049f, -1.02069795f, -0.07653128f, + -0.13979210f, -0.31195050f, -1.72042310f, 1.37131393f, 0.63849634f, + 0.75561279f, 1.81152904f, 0.26686314f, 1.32796574f, 0.56100166f, + 0.70058894f, -0.88962644f, -0.04360984f, -0.88249093f, 0.24311203f, + 0.50410056f, -2.22567797f, 0.94520348f, -2.12467694f, 0.47282359f, + -0.71379906f, -0.09857135f, 0.62374717f, 1.37182784f, 0.73380554f, + 0.59745449f, 2.80427694f, 0.67253572f, 1.65335357f, 1.69891667f, + 1.34585941f, -0.79989213f, 1.44980943f, -0.52013642f, -0.46971673f, + -1.50070012f, -0.25687039f, -0.56916732f, 0.71065760f, -1.31996286f, + 0.96031237f, 0.13929774f, 1.49679291f, -0.05966444f, -0.58674580f, + -0.08278833f, -0.93390942f, 0.42415768f, -1.77889526f, 0.75336021f, + -0.72699982f, -0.82880586f, 0.63955617f, 0.42771208f, -0.42366457f, + -0.91581815f, 0.94750947f, 0.43123913f, -0.99053741f, 0.70470595f, + -1.16662264f, 1.14847183f, -0.83885664f, 0.46714026f, -2.27748466f, + -1.23656678f, 0.14695056f, -0.33159894f, -0.52553117f, -0.04391259f, + -0.29630372f, 0.25949728f, 0.96991086f, -0.37714824f, -0.28251833f, + 0.16106486f, 1.38844633f, -0.18713553f, -1.30708838f, 0.48490265f, + 0.29553881f, -0.45505449f, 0.83341682f, 0.87346369f, -0.63516861f, + 0.66063565f, 0.93892503f, -2.73996735f, -0.81515318f, -0.91458052f, + 0.00978268f, 0.43472794f, -0.08090764f, 1.37249672f, 0.76722521f, + -1.19154143f, 0.22046764f, 0.34916410f, 0.51383299f, -0.56379753f, + -2.49949312f, -0.74207872f, -0.68400806f, -0.09663232f, -0.07199454f, + -1.05562651f, -0.75028551f, -0.87253797f, 0.69039482f, 0.45923674f, + -1.27515161f, -0.04555376f, -1.41501272f, -0.83773375f, -0.74807298f, + 1.36646152f, 0.06317432f, -1.32559633f, 1.89092779f, 1.24883330f, + -1.03608561f, 1.08677161f, -0.99629849f, -0.69947034f, -0.85716367f, + -0.07947286f, -0.25485426f, -0.19732477f, 1.64581251f, 1.04618108f, + 1.87186897f, -0.18198362f, -0.83807969f, 0.70462501f, -3.18930101f, + 0.74610996f, -0.60935193f, -0.49383929f, -2.88986492f, 0.51707613f, + 1.04620326f, 1.09837818f, -1.19840038f, -0.10391295f, -0.20789115f, + -1.51052022f, -0.31087330f, 0.22411564f, -1.30506921f, -1.52000105f, + -1.51593041f, 1.04321992f, 0.97611690f, 0.90424490f, 1.83324766f, + -0.08682299f, 0.47035542f, 1.70865905f, -0.31108001f, 0.04115159f, + -1.36352801f, -0.90797836f, 0.32128647f, 0.66191489f, 0.08681208f, + 0.14993365f, 0.47110486f, -0.31522670f, -0.38906571f, -0.08876022f, + -0.13106902f, 2.25685239f, -0.62211353f, -1.68553007f, -0.23707703f, + 0.69236159f, -0.46686995f, -0.27520603f, 0.26619941f, 1.48525345f, + 1.61278927f, 0.49452963f, 1.20846486f, -1.11853909f, -0.30010033f, + -0.75471467f, -1.69959772f, -0.52042168f, -0.43881389f, -1.45240712f, + 1.02122891f, 1.73639011f, -0.03813924f, -0.22239220f, 0.15797073f, + -0.64418089f, -0.60228932f, -0.83248150f, -0.02042520f, 0.38137484f, + 0.86056453f, 0.06410559f, -0.62785137f, -0.49916875f, -2.53796315f, + -0.79168582f, -0.69197005f, -0.77175534f, -0.28669405f, -0.79764080f, + 0.97218460f, -0.10351621f, -0.52759898f, 1.02840185f, 1.16363287f, + 0.08351815f, -0.61088538f, 0.59944046f, 1.54409397f, -1.39842033f, + 0.27917057f, -0.27146137f, 1.46310735f, 0.03626106f, 0.15038440f, + -0.07894899f, -1.42527366f, 1.69641745f, 1.48384345f, -0.43328866f, + -0.54252565f, -0.94416499f, 1.54436302f, -0.81367069f, -1.67925239f, + -0.17525831f, 0.27891046f, -0.69066733f, 0.89911050f, 0.11606655f, + 0.67450327f, 0.41538724f, 0.90886223f, 1.19786549f, 0.85810721f, + 1.32862210f, -0.83469814f, -1.09682298f, 0.88092703f, -0.97478902f, + -0.11664717f, -0.07929394f, -0.69581884f, -0.16928329f, -0.70731819f, + -0.40485084f, -0.28954300f, 0.52882415f, 0.38769314f, -1.38704026f, + 1.15099049f, -0.43566978f, 0.34459323f, 0.49520254f, 1.11130333f, + 0.28783718f, -0.53783375f, -1.63577271f, 1.02222812f, 0.86302060f, + 0.48346213f, 0.46627176f, -1.30133855f, -1.48477137f, 0.31219670f, + -1.21498191f, 0.89838904f, 0.87186617f, -0.39968935f, 0.34930915f, + -0.32909471f, -1.39364409f, 2.13006306f, 0.33270469f, 0.00215986f, + 0.97776711f, 0.24908836f, 1.56164885f, 0.45157790f, -1.55970144f, + 0.27677536f, 0.07662498f, -0.08262251f, -0.17658773f, 0.65820259f, + 2.01052690f, -1.71946216f, 0.84686053f, -1.23594892f, 1.40792072f, + -1.47772563f, -0.36132276f, -0.50405115f, 0.09009213f, 0.81659186f, + 1.85574234f, -0.64974433f, 0.63352364f, 1.01766217f, -1.54804432f, + -0.42570522f, -0.24763709f, 0.72822112f, -0.93733686f, 0.68087620f, + -1.40644944f, 0.48672482f, 0.09725539f, -0.64416331f, -0.95747960f, + 0.36771363f, 0.39155054f, -0.71790671f, -2.17222738f, -0.08655047f, + -0.97842115f, -0.22991380f, 0.52029115f, -1.42072022f, 0.29576331f, + 0.32391560f, -1.00823236f, 1.67909145f, 1.16841447f, -0.32307062f, + 0.15756166f, -0.97590631f, -0.39429301f, -0.03583352f, 0.17554663f, + 0.57961231f, -0.46873134f, -0.23343173f, -0.85060924f, 1.71745574f, + -0.04658702f, 0.63088381f, -0.67581934f, -1.53171062f, -1.58800113f, + -1.17987096f, -1.16737640f, -0.87544650f, -1.17138922f, 0.38979119f, + -2.39369726f, -1.34747124f, 0.58450359f, 0.87791806f, -0.04459394f, + 0.97995293f, -0.10354915f, 0.65324986f, -0.17833626f, -0.85849386f, + -0.42063358f, 0.19708554f, 0.10255250f, -0.59539181f, 0.86194044f, + 1.68610668f, 0.55275291f, -0.43127069f, -0.04218780f, -0.08466262f, + 0.31236625f, -0.92824298f, -0.09879152f, 0.32358822f, 1.04045570f, + 0.35617545f, 0.09059231f, 1.19069445f, 1.96978688f, 0.63561743f, + 0.15030998f, -0.29879019f, 0.22774190f, -1.01608860f, 1.03605175f, + 0.47804731f, -0.30450734f, -0.61382371f, 0.45390254f, -1.93547988f, + 2.01267338f, 0.52447683f, 0.18379784f, 1.11913633f, -1.24273467f, + 0.15803322f, 1.72184098f, -0.79349059f, 0.10258614f, -1.53445125f, + 0.02630571f, 0.81649125f, 0.91089755f, -1.12968338f, 1.04016411f, + 0.28999722f, 0.74863863f, -0.61388236f, 0.01665530f, 1.43592548f, + 0.68138391f, 0.11963340f, -1.26123953f, 1.36340797f, 0.25696915f, + -0.58877039f, 1.42209792f, 0.55563360f, -1.33329606f, 1.84695840f, + 0.88433737f, 1.04359078f, 0.18906727f, -0.03448994f, 1.17944050f, + 0.86783957f, 0.44934425f, -0.77892244f, -1.76232874f, -1.01689589f, + 0.78943914f, 0.92141974f, -1.00187087f, -0.13809921f, -0.90222073f, + 1.10094714f, -0.13657950f, -0.44349849f, -1.61441302f, 1.05724919f, + 1.50337231f, -0.05785890f, -0.76958144f, -0.51498759f, 0.69227600f, + -0.37975949f, 1.31949317f, 0.82049531f, 0.32868597f, -0.31557772f, + -0.75534385f, 1.27303052f, 0.43453619f, 0.11296938f, 1.18182182f, + 2.23387384f, -0.86412978f, -0.01599468f, -0.70869064f, -0.09221385f, + -1.23729551f, 0.79490280f, 0.03522846f, -0.95069039f, -1.73461652f, + 0.72329187f, 1.40385795f, -0.11585230f, -0.78033113f, 0.07491048f, + -1.12873089f, 0.18476245f, 0.57568848f, -0.28792691f, 1.35411644f, + -0.76956165f, 0.29571572f, 1.03178787f, -0.38780826f, 0.31680650f, + 0.69368076f, -1.23856580f, -0.49848995f, 0.14766994f, 1.02625990f, + 3.03858209f, -0.51030380f, 0.96796870f, 1.35078156f, -1.07729447f, + 0.84322494f, 0.54886484f, 1.31453705f, -0.45792100f, 0.31196272f, + -0.15701357f, 0.83586836f, -0.74952888f, -1.17432022f, -0.31002575f, + -1.02149463f, -0.36117774f, -1.22079086f, 0.03532525f, 0.00555908f, + -0.45891216f, 0.29636297f, -0.68272704f, 0.41257843f, 0.37988129f, + 0.01747893f, 0.82739186f, 1.52292180f, -0.79456621f, 2.20275712f, + 2.13212132f, -0.81393015f, -1.15712392f, 0.22488308f, 0.62776327f, + -0.85444915f, 0.44017896f, 0.05863331f, -0.83198178f, 0.93063420f, + -0.16121253f, 0.12382501f, -0.37826315f, 0.93118382f, 0.19507533f, + -0.58595538f, 1.46994352f, 0.13170272f, -0.70031989f, -0.12820166f, + 0.30487457f, 0.84148771f, -0.68807501f, 0.21187615f, -0.67030680f, + -1.79136002f, 0.70810199f, -1.20959783f, -0.08468831f, -0.06317700f, + 1.35527098f, -0.47018668f, -0.91693246f, 0.14818805f, -0.05405350f, + 1.16875637f, -0.17363262f, -1.61833882f, -0.32934523f, -0.38346377f, + -0.62702698f, 0.34135151f, 0.48015586f, -0.65263331f, -0.04689486f, + 0.01156854f, 0.37580970f, -0.16174591f, 0.59627324f, 0.24351901f, + -0.87983090f, 1.57049024f, 1.25836349f, -0.41464049f, -0.62279183f, + 0.09693756f, -0.23850618f, -0.49007827f, 0.22298151f, 0.10914832f, + -0.35192192f, -1.27221346f, 1.10203624f, -0.86399704f, -0.47319838f, + -0.77105570f, -1.68624854f, 0.81198281f, 0.82534081f, 0.75654501f, + 1.47631240f, -0.61000234f, -0.58933264f, 0.54822850f, -1.22829592f, + 0.11107657f, 0.56449169f, 1.50693524f, -0.59280968f, -0.64286685f, + -0.20120731f, 0.27184448f, 1.55500400f, -0.48919386f, 1.04044867f, + -0.87048137f, -0.40569979f, 0.21908638f, -0.51829034f, -1.48748124f, + 0.02990401f, 1.83462536f, 0.29885170f, 1.32370698f, -1.30129600f, + 2.43271399f, 0.22967771f, -1.13014007f, 0.95529765f, -0.83325785f, + 0.43633386f, 0.85774118f, 0.78160155f, 0.58583075f, 1.18906367f, + -1.54354560f, -0.68320692f, 0.01900371f, -0.79777133f, 0.12851712f, + 1.10176420f, 0.79418170f, -1.41154039f, 0.36929929f, 1.12176800f, + 1.23849642f, -0.89377707f, 1.01390159f, -0.50889206f, -1.12554002f, + 0.17932732f, 0.48949540f, -0.54235244f, -0.28146735f, -1.39125514f, + 0.13309635f, -1.12864995f, -1.29901242f, -0.04266220f, -1.98028529f, + -1.34869373f, 0.00038156f, -0.92473024f, 1.48010647f, -0.02754467f, + -0.26030368f, 0.93083733f, 0.27946711f, 0.64052200f, -0.04220961f, + 1.25002527f, -1.07923257f, 0.19048618f, 0.08900311f, -0.40813437f, + -0.73068553f, 0.52122378f, 0.68990833f, -0.38749605f, -1.09269309f, + -1.63480806f, 1.01789618f, -0.61596102f, 0.81049860f, 1.30838764f, + -1.49213874f, -0.77916288f, -0.72660202f, -0.92013240f, -1.61726642f, + -0.11527207f, 0.35143322f, -1.11646879f, -1.45525432f, -0.82892823f, + 0.15512508f, 1.01891017f, 1.40162635f, 1.02494884f, 0.33882582f, + -0.78747398f, -0.26009330f, -0.38519114f, 0.79247451f, 0.02065756f, + -0.48030257f, 1.01167107f, -1.74057114f, -0.84549171f, -0.15337363f, + -1.92544484f, 1.01270044f, 0.00762185f, -0.16405612f, 1.61778915f, + 0.93316060f, -0.68960994f, -1.13214970f, -0.94695878f, -0.28418848f, + 0.17102109f, -0.08787476f, -1.83799696f, -0.13761258f, -0.18652774f, + 1.46456254f, 0.34169790f, -0.40697145f, 1.49663997f, -0.99555492f, + -0.67775637f, -0.51951116f, 1.35157657f, -0.27099034f, -0.46987835f, + 2.28101230f, 0.59104478f, 0.75010139f, 1.01472175f, 0.25741309f, + -0.56074983f, 1.12267506f, 0.35336846f, 0.61733276f, -1.63976014f, + -0.17700450f, -0.25093642f, -0.75599891f, 2.10956192f, 0.95155340f, + 0.72049862f, 0.50492924f, 0.62067389f, 2.08688402f, -0.73604703f, + 0.63383341f, -0.53528428f, -2.11538506f, -0.98173052f, 0.59560484f, + -0.26205051f, -0.91948050f, 0.00593397f, -0.11734286f, -1.41261208f, + -0.83611172f, -0.27682739f, -0.20619918f, -0.36557615f, 0.77194935f, + 1.67695415f, -1.39265156f, 0.04892010f, -0.37773246f, 0.16124558f, + -0.18348448f, -1.38248885f, 0.58459854f, 0.65064198f, 1.11349559f, + 0.36708066f, -0.15471332f, 0.14208725f, -2.06860566f, 0.29629150f, + 0.93084633f, -0.47215626f, 0.60208917f, 0.95415461f, 1.03390312f, + -0.03639749f, -0.23988228f, 1.27037442f, 0.95133096f, 0.33187470f, + -0.34527761f, 0.22134073f, 1.01799667f, -0.81475645f, -1.18869019f, + 0.23314142f, 0.25180560f, -1.23762786f, 1.25283313f, 0.16980635f, + 0.40740708f, 0.59256923f, 0.16274920f, -0.69713289f, -0.16444311f, + -2.41602516f, 0.37952334f, -0.05604568f, -0.23772651f, 0.20581599f, + -0.54303211f, 1.71877348f, 0.83602583f, -0.32586128f, 0.73609394f, + -1.73640239f, 0.07249248f, 0.31248692f, 1.77627432f, 0.97660398f, + -0.42095289f, -0.18750280f, -0.84246057f, 0.29762223f, 1.87054563f, + -1.46980762f, -0.45306337f, 1.52366042f, 1.39061129f, -0.04980387f, + -0.55382830f, -0.96987218f, -0.06910808f, -0.41276473f, -0.83891344f, + -0.92597574f, 0.60252470f, 0.21938549f, -0.04451685f, -1.00330937f, + -0.36955237f, -1.52876902f, 0.27296364f, -1.96721256f, 0.05291027f, + -0.91540521f, 0.48990685f, -1.99560380f, -0.68551093f, -0.14532298f, + -1.56881595f, -0.08319287f, 0.31003201f, -1.42829597f, -0.61810297f, + -0.03581250f, 0.77747720f, 1.25297558f, -1.36239243f, -1.13274276f, + -0.35045877f, -2.34157228f, 0.04515179f, -0.83044821f, 1.81353962f, + -1.36855912f, 0.39704823f, 0.16665934f, -0.16654585f, 1.17806077f, + 1.00086153f, -1.25474250f, -1.46876431f, 1.18021631f, -0.32257929f, + 2.12062597f, 0.86819613f, -1.18048275f, -1.69747460f, -0.74092305f, + 0.05086798f, 1.15339577f, 1.32972670f, 0.27247882f, 0.98499072f, + 2.35597157f, 0.30179837f, -0.66633248f, 0.13794266f, -0.22753908f, + -0.22868259f, -1.81792033f, 0.50151759f, -0.79408127f, -1.05343878f, + 0.45727381f, 0.84800923f, -1.73605800f, -0.02032863f, 1.82778001f, + 1.41025102f, -0.81715560f, 0.25888795f, -0.25075480f, 0.66256499f, + 0.11993053f, 1.81336939f, -0.06345166f, -1.49658346f, 0.07531686f, + 0.96972889f, 0.87405980f, 0.75830793f, -0.13497087f, -2.45855975f, + -0.65984958f, 0.93919373f, -0.97305542f, 0.73477978f, 1.04337513f, + -1.22712576f, -0.46385625f, -1.20876372f, -0.82760453f, 0.01455977f, + -1.05089867f, -0.02801843f, 0.60899758f, -0.82052249f, -1.48932517f, + -0.98073828f, -0.19311285f, -0.25602359f, 0.50351876f, -1.24557400f, + -0.82138073f, -1.45966852f, 0.44991320f, -0.75550151f, -0.98550314f, + -1.21418869f, -1.15771639f, -1.72192061f, -0.39616469f, -0.55566746f, + -1.31880891f, -0.08843257f, 1.00422776f, 0.35846478f, 0.46060917f, + 0.77326930f, 1.60129988f, -1.85124147f, -0.30582917f, 1.30227256f, + 1.81890345f, -0.44084981f, 0.25315762f, 0.70259613f, -0.94882858f, + 1.97040296f, 0.71473581f, -0.68193883f, -0.36290962f, 1.16348684f, + 0.15418798f, 1.07806778f, 0.40554729f, 0.10280909f, -1.06474805f, + 0.64398485f, -0.63568884f, -0.06108581f, -1.03290677f, 1.02834034f, + 1.15284693f, 0.14046004f, 1.86630619f, 0.46804786f, -0.68397558f, + 1.60733378f, -1.64890087f, -1.03819239f, -1.19212389f, -0.78382361f, + 0.03925850f, 1.52259934f, 0.09540676f, -0.21220762f, 0.55955195f, + -0.39845437f, -2.14541650f, 0.49337825f, -0.68574250f, 0.74040270f, + 0.50783634f, -1.60461199f, -1.26806450f, -0.12652303f, -0.83992827f, + -0.15524681f, 0.40098447f, 0.23392735f, -0.23262636f, 0.06525709f, + -0.35994548f, -1.08432877f, -0.21395946f, -0.78357452f, -0.57157278f, + 0.71407390f, 0.86596155f, -1.13723528f, 0.13460183f, -1.20881450f, + 0.71018457f, 0.68943661f, -0.70428050f, 0.64600736f, 0.01990297f, + -0.10575775f, -0.80263519f, 0.10618331f, 0.08865548f, 1.51651669f, + 0.60851854f, 1.15161908f, 1.04919207f, 1.18359745f, -0.04352076f, + -0.83643389f, -0.07922365f, 0.10597949f, -1.34984851f, -1.91319740f, + 0.71585363f, -2.10845160f, 0.64385056f, -0.54551518f, -1.02039802f, + -1.62510490f, 1.65401149f, -0.42711899f, 0.07970079f, -0.21404363f, + 0.30498922f, 1.07942021f, 0.63995659f, -1.82114816f, 0.56396323f, + 1.07084870f, -2.00350380f, 0.53339815f, 0.18500003f, 1.15034151f, + -0.21436051f, -0.99986565f, -0.58812016f, -0.07247020f, 0.78910017f, + 0.48839527f, 0.98795873f, 0.10357288f, -0.05604928f, 0.38977858f, + 0.73745090f, 1.40838420f, 0.25967824f, 0.23588051f, -0.03451392f, + 1.04897523f, -1.77121758f, 2.35625434f, -0.67086869f, -0.84005541f, + -0.85940343f, -1.04449213f, -0.65917015f, -0.78713167f, -0.95910054f, + 0.38597879f, -0.31879017f, -0.86260867f, -1.08593106f, 0.02802678f, + 0.99484950f, -0.55113328f, 2.60936737f, -0.03388772f, -0.47583574f, + -0.14021793f, 0.99019170f, -1.22431207f, 0.78734446f, -1.77037835f, + 0.15018673f, 0.36423206f, 1.36447549f, -1.61007094f, 0.51875496f, + -1.60788095f, -1.73557448f, -0.41414359f, -0.93710536f, 0.38715765f, + 0.04243837f, -1.59682858f, -1.10728157f, 1.88292623f, -1.01428258f, + 0.01074958f, -1.88169158f, -0.31616244f, 0.45334938f, 1.12449574f, + -1.16699445f, -1.59505820f, 0.04126552f, -0.89016622f, 0.45838884f, + 0.71463561f, 0.14563711f, 0.30694655f, 0.67193079f, 0.61429602f, + 1.00201404f, -0.49295208f, 0.05997690f, 0.99491668f, -0.73801446f, + -1.17185295f, 0.94778723f, 0.36106884f, -0.43561545f, 0.04102699f, + 0.52626407f, 0.08442099f, -1.57626402f, 1.56855237f, -1.65396678f, + 1.74014664f, -0.38219589f, 0.39305371f, -0.31705827f, -1.15742850f, + 0.11669596f, 0.54043210f, -0.52270615f, -0.13375773f, 0.68094701f, + -1.84134769f, -1.49383473f, 0.14632171f, -0.54607725f, -1.20867658f, + -1.28439069f, -1.81734920f, 1.54257309f, 0.78347659f, -0.24049839f, + 1.69973648f, 0.99825776f, 0.99971974f, -0.26055810f, 0.34143049f, + -0.44862366f, 0.11253342f, -0.60932243f, 0.70383030f, -1.87318194f, + 0.21953633f, 0.82791799f, 1.64545465f, -0.42693698f, -0.64897031f, + -0.97996652f, -1.06616282f, 0.52939081f, -0.12541170f, -0.57480675f, + 0.73600835f, 0.35711968f, -0.03528263f, 0.79997194f, 0.55742902f, + -0.28909785f, 0.64331138f, -1.79893720f, 1.01572442f, 0.27111965f, + -0.51778597f, 0.12906317f, 0.76148927f, 1.51315522f, 0.41101140f, + 0.38008851f, 0.66759896f, -0.13804778f, 0.64854795f, 1.73474562f, + 0.75999504f, -0.73411214f, -0.05406699f, 1.35664344f, -0.25298578f, + -0.12696666f, -0.42628938f, 0.61129904f, 1.55259824f, -0.05820796f, + -0.38598019f, -0.87325627f, -0.55066222f, -1.24557889f, -0.26509118f, + -0.32103062f, 1.14031804f, -0.75985742f, 0.70659167f, -1.15016067f, + 1.24906838f, 0.90396994f, -0.16241251f, 0.43682271f, -1.42695689f, + 0.47134697f, -1.66143429f, 0.08698819f, -1.00775325f, -2.24129725f, + -1.04226267f, -0.98537570f, -0.89938259f, -1.80710697f, -1.22866321f, + 0.78125423f, 1.55150509f, 0.46235040f, 0.18444096f, 0.19313288f, + -2.20686269f, -0.40341458f, 0.50321484f, 0.47339424f, -0.81383848f, + -0.21972439f, 0.66612029f, 0.60239881f, 1.20443010f, 0.70015103f, + 0.30632916f, 0.01489905f, 0.68129027f, -0.89645082f, -2.68969011f, + -0.96684915f, 1.66421318f, 0.74333072f, -0.78321886f, 1.60063362f, + -1.27524030f, -1.95856726f, 0.47504124f, 0.15398432f, -0.20796098f, + -0.13449343f, 0.93458968f, 1.60390890f, 0.21798505f, -0.27035928f, + -1.23248971f, -1.25361061f, 1.34666133f, 1.07233441f, 0.88799530f, + -1.23687923f, -0.40781614f, -0.11916534f, -0.88050151f, -0.66422415f, + -2.61471510f, 0.78276747f, 2.42323995f, -1.70715427f, 0.71550035f, + -0.60298312f, 0.70491880f, 0.46175584f, 0.80827898f, -0.45108104f, + -0.98219043f, -1.72823501f, 1.73190725f, 0.53906441f, -1.50445580f, + -0.59250867f, -0.07239901f, 0.44743437f, -0.13740127f, 1.69935930f, + -1.00480616f, -0.58191377f, 0.39853972f, -0.60960841f, -0.45473522f, + -0.76396072f, -0.31872150f, 1.74509728f, -0.59950751f, 0.89810580f, + -0.81400329f, 1.14280319f, 1.11165059f, -1.31295311f, -1.60784578f, + -0.87506992f, -1.13461006f, -2.09486437f, -0.16449419f, -0.37728927f, + 0.47595578f, -0.55342919f, -0.17574213f, 2.21499181f, 1.14331865f, + -0.14938518f, 0.18935619f, -0.33802557f, 0.52538890f, 0.82673949f, + 1.16562462f, 1.24713838f, 0.98890215f, -0.64991701f, 1.49886703f, + 1.97769642f, 0.08059916f, -1.60925281f, -1.23822486f, -1.40829837f, + 0.51331180f, -0.29928651f, -1.04348791f, -0.39911583f, 0.69380492f, + 1.54516888f, 1.22791195f, 2.25008130f, 1.33348894f, -0.21775827f, + -0.71937007f, 0.54982573f, 1.70691478f, 0.32459491f, -0.57187974f, + -0.21614684f, 1.08274269f, 0.41384646f, 0.24497485f, -1.43703413f, + 0.89616930f, 0.82032162f, -0.24598582f, 0.84271127f, -0.81894702f, + -0.01828136f, 1.70397091f, 0.39505738f, -0.51221430f, -0.87979966f, + 0.10795479f, 0.45194778f, -0.76008922f, 1.23394477f, -0.56798172f, + 1.06459570f, -0.44333413f, -2.40399075f, -0.37267187f, 1.42946172f, + 0.95734519f, 1.86127949f, -0.15217264f, 1.68742633f, 1.97638428f, + -0.44211119f, -0.98393327f, -0.54173928f, -1.72017395f, 0.74697793f, + -1.77827263f, -1.92299354f, -0.17189410f, -0.48633271f, -2.21230388f, + -0.45906609f, -0.53493047f, 0.37253976f, -0.56951141f, 0.07728028f, + 0.03530006f, -1.18123293f, 1.94158125f, -1.55930352f, 0.69334733f, + -1.95163214f, -0.95800400f, -0.01804711f, -0.56747472f, -0.99099451f, + -1.52853060f, -0.98279524f, -1.67307866f, 0.96121490f, 0.35654056f, + 1.74034202f, -1.44633865f, -0.27781928f, 1.79457986f, -0.41029963f, + -0.76871634f, 0.36555341f, -0.77664107f, 0.19535238f, -0.76185411f, + -0.19828433f, -0.88820636f, 0.63885397f, 0.11346363f, -2.50265074f, + 0.16319332f, -1.01288569f, 1.86605489f, 0.89761645f, 1.11795115f, + -0.00714116f, -0.89034635f, -0.76447034f, -0.18822117f, -0.48340848f, + -0.99788517f, 1.02172959f, -0.39395007f, 0.72566581f, -0.81438208f, + -0.71715081f, 0.96243578f, -1.36424279f, -1.13870537f, 1.17602491f, + 0.16320205f, 0.71959788f, 1.66669416f, 0.55690295f, -0.28912008f, + -1.19219172f, 0.23308393f, -0.37963116f, 0.45347008f, -0.42606446f, + 1.30938649f, 1.25128853f, 0.57649273f, 0.34440875f, -0.23893952f, + -1.06604803f, 0.31336102f, 0.75727910f, 0.46772480f, -0.37650385f, + -0.06036821f, 1.03686309f, 0.46158856f, -1.81028461f, 1.43393028f, + 0.85494965f, -2.34685564f, -0.17571987f, -0.45592231f, -1.31190526f, + 1.73194158f, -0.11856517f, 0.07041293f, 0.25689471f, -0.56000596f, + 2.06649089f, 0.38954756f, 1.36627376f, 0.13905638f, 0.77370811f, + 0.43944249f, -0.08798827f, 0.07245751f, -1.30234015f, 0.29710820f, + 0.74389762f, 0.11971968f, -0.07381748f, 1.32652700f, 1.34079397f}); + + auto input2 = NDArrayFactory::create( + 'c', {3, 4, 4, 5}, + {0.98114507f, 0.96400015f, 0.58669623f, 0.60073098f, 0.75425418f, + 0.44258752f, 0.76373084f, 0.96593234f, 0.34067846f, 0.57962620f, + 0.77517051f, 0.97472977f, 0.79237527f, 0.68690428f, 0.21719366f, + 0.79959206f, 0.84814187f, 0.22496814f, 0.08646965f, 0.31110474f, + 0.79813162f, 0.19661444f, 0.57760099f, 0.72138960f, 0.15244268f, + 0.87687051f, 0.11130344f, 0.01087698f, 0.34817841f, 0.54992017f, + 0.23443850f, 0.31725614f, 0.59755220f, 0.20364695f, 0.00531392f, + 0.23403114f, 0.07442912f, 0.83707647f, 0.89291743f, 0.09044587f, + 0.69041462f, 0.29904183f, 0.61904680f, 0.85306847f, 0.34467042f, + 0.95839152f, 0.54517124f, 0.29640937f, 0.94855959f, 0.95970016f, + 0.94045145f, 0.95510301f, 0.34666505f, 0.34717010f, 0.69245678f, + 0.71669175f, 0.59043738f, 0.64924132f, 0.06033522f, 0.60185199f, + 0.04690073f, 0.59241154f, 0.40229547f, 0.23002481f, 0.45161195f, + 0.73743778f, 0.93209113f, 0.37294358f, 0.50177744f, 0.15072501f, + 0.26146917f, 0.05252146f, 0.04758931f, 0.76448288f, 0.85149045f, + 0.08840467f, 0.07692576f, 0.33180160f, 0.27241259f, 0.74834620f, + 0.56453640f, 0.23057286f, 0.68429752f, 0.11961551f, 0.39045977f, + 0.44356094f, 0.77018807f, 0.07984410f, 0.47926806f, 0.26165759f, + 0.18606064f, 0.89972877f, 0.17962874f, 0.47273120f, 0.64641705f, + 0.61890443f, 0.58730015f, 0.25937832f, 0.35231561f, 0.10243882f, + 0.17459193f, 0.95906995f, 0.09227025f, 0.30003223f, 0.41601210f, + 0.38269713f, 0.84799751f, 0.59295173f, 0.76277990f, 0.68910424f, + 0.37672606f, 0.40675461f, 0.94346058f, 0.91438505f, 0.84728183f, + 0.64367667f, 0.74899979f, 0.60570691f, 0.16417363f, 0.68852426f, + 0.85486889f, 0.22585792f, 0.86953176f, 0.07465519f, 0.93096301f, + 0.38008822f, 0.38752587f, 0.44004038f, 0.13170612f, 0.94541045f, + 0.89349973f, 0.69245307f, 0.94978877f, 0.98776658f, 0.79445884f, + 0.30607409f, 0.58264961f, 0.37980538f, 0.41810784f, 0.48903038f, + 0.51615888f, 0.57682794f, 0.82481897f, 0.78341080f, 0.48446465f, + 0.17447931f, 0.71125424f, 0.30263851f, 0.70675352f, 0.03215584f, + 0.92381065f, 0.22343694f, 0.08851149f, 0.91402490f, 0.70074717f, + 0.30912192f, 0.37723206f, 0.97579397f, 0.23554587f, 0.95939133f, + 0.41565709f, 0.01741416f, 0.58362787f, 0.22106662f, 0.89065537f, + 0.31900249f, 0.41280911f, 0.67947610f, 0.04545590f, 0.15352812f, + 0.85412524f, 0.84933222f, 0.80000225f, 0.93147073f, 0.70094105f, + 0.69269875f, 0.95282194f, 0.65913582f, 0.79186874f, 0.59855248f, + 0.39707430f, 0.95126239f, 0.15618217f, 0.33446689f, 0.98123758f, + 0.84770758f, 0.98081012f, 0.54427413f, 0.18728519f, 0.89792955f, + 0.53360126f, 0.72812986f, 0.13307744f, 0.51217443f, 0.66708084f, + 0.29416915f, 0.31298995f, 0.39155037f, 0.29288291f, 0.87063305f, + 0.61759154f, 0.73723332f, 0.37167635f, 0.82122716f, 0.22937430f, + 0.76570536f, 0.47911792f, 0.02826214f, 0.94277323f, 0.59945469f, + 0.19042060f, 0.68173155f, 0.82771295f, 0.95649538f, 0.40833101f, + 0.90838542f, 0.55245881f, 0.49011012f, 0.36773444f, 0.34513527f, + 0.42050683f, 0.16113964f, 0.30969388f, 0.27174174f, 0.12117655f, + 0.35270175f, 0.81967867f, 0.63723136f, 0.84309389f, 0.71822576f, + 0.84883484f, 0.32306117f, 0.08176457f, 0.56175486f, 0.34892198f, + 0.09306929f, 0.85437582f, 0.13925577f, 0.48629188f, 0.29923539f}); + auto exp = NDArrayFactory::create( + 'c', {3, 8, 8, 16}, + {5.98743296f, -2.83037376f, -0.87943113f, 1.41339970f, 1.32433391f, + -1.20299149f, -0.02893090f, 2.05326009f, 1.19417048f, 5.58212376f, + 3.28139353f, 1.19237995f, -1.09431255f, -2.55264497f, 3.11014652f, + 6.81296825f, -2.09029293f, -4.32068443f, -0.52808392f, -1.97968531f, + -0.18673831f, 0.84605980f, 4.55825520f, 2.71503139f, 0.15210046f, + 0.85310984f, -3.82062817f, 2.76470995f, 3.69004202f, -1.45017099f, + -2.59361267f, -1.35094655f, 7.24145126f, -5.25432396f, 0.19920218f, + -4.30596399f, 1.35318923f, -3.88142037f, 3.67493343f, 2.25931478f, + 2.87630725f, 1.66349852f, 6.21347952f, 0.94105923f, -1.61742055f, + -2.35699606f, 0.12850338f, 1.79141688f, -2.09535933f, -6.35418081f, + -0.06303531f, -4.38615131f, 0.48237842f, 0.26528549f, 3.38231516f, + 3.76315165f, -0.40254810f, -0.23716694f, -6.13381910f, -0.41950428f, + -0.89680839f, -1.46491277f, -1.98541689f, -0.99357355f, 5.58237648f, + -2.38937521f, -0.00872564f, -2.37138414f, 4.91117287f, -4.51916361f, + 0.97943687f, 2.91052818f, -2.50362611f, 1.70252812f, 5.04137802f, + 3.57108784f, -1.87532270f, -3.66677809f, -2.38861251f, 5.55765152f, + -7.27571774f, -1.68887305f, -0.72266489f, -4.42809057f, -0.92118186f, + 1.02381468f, 4.44284725f, 5.17150497f, -0.42438728f, 2.02693963f, + -1.36484981f, -1.47912180f, 0.26649538f, -0.02091765f, -2.86906910f, + -3.03046989f, 1.35122132f, -3.21707630f, 2.21112418f, 0.24121630f, + 3.96940088f, -7.66105747f, 2.76352382f, -0.99061489f, -2.16720009f, + -1.63170409f, 1.12701774f, -1.02415371f, -0.90435314f, -1.51372027f, + -0.76884907f, 0.39066136f, -0.89562428f, -2.03204703f, 1.28074932f, + -2.14551091f, -2.36843777f, 0.46580017f, 0.75451565f, -0.00336730f, + -1.06597757f, 3.27195978f, -0.41307712f, -0.10376054f, -1.34102952f, + -2.22901654f, 2.31929803f, 1.40851438f, -2.23774385f, 0.20417206f, + -1.12153268f, -0.13188094f, -3.96649432f, 2.10269976f, 0.49845099f, + 6.18937683f, -0.51783508f, -0.48048639f, -1.92970264f, 3.16670656f, + 1.13355756f, -0.07890664f, 1.31536257f, -0.43924797f, -0.04562932f, + -0.87974954f, 0.75411212f, -2.39745235f, -3.97132111f, 0.37202546f, + -2.40399146f, -1.50796390f, -3.08302689f, 0.23075986f, -0.94316757f, + 1.34948587f, 0.58591264f, 2.18529797f, 7.97652435f, 2.32798409f, + -4.09404373f, 0.89634895f, 0.77697754f, -0.65091681f, -7.05506849f, + 5.86194515f, 2.51394033f, 4.69959354f, 0.20835471f, 3.18049693f, + -1.29682434f, 3.70832396f, -0.48123091f, -1.67904007f, -1.35418940f, + 1.58435583f, -1.13851106f, -1.19225955f, 0.59713769f, -5.80462933f, + -7.45143986f, -1.08658695f, 1.03244078f, -1.75307107f, -7.07100582f, + 3.85825157f, 1.62127817f, 2.32572675f, 0.56171900f, -0.80591971f, + 3.98835945f, 0.15742642f, -2.97832179f, 0.13821673f, -0.72556758f, + -0.84936106f, -7.28444147f, 3.94134307f, 0.80779338f, 7.47784615f, + 8.23335075f, 4.80595016f, -4.89574575f, 4.03362942f, -6.67522192f, + -4.55204487f, 2.12511182f, -2.70781207f, -1.57226098f, -3.08408356f, + -0.30812448f, -5.32870674f, -5.13238287f, 0.49605465f, -0.55042171f, + 0.46324944f, -3.83545256f, -0.12562510f, -0.20978995f, -0.13068712f, + -1.92144060f, -1.68787408f, 5.45581436f, -0.79583496f, -2.38866687f, + -3.90546346f, -0.47028148f, -0.14319679f, -3.37016582f, 2.00905991f, + -1.21345615f, 1.81376505f, 7.73004007f, 0.74310112f, -4.64536428f, + 3.78111577f, -9.05182457f, -0.10674095f, 1.53476238f, 0.63345337f, + -0.40907967f, -1.44729769f, -1.87145400f, -2.46623540f, 1.07472968f, + 0.77390999f, -3.93438888f, 4.49174690f, -0.96686655f, 1.92278123f, + 0.30049133f, -0.02388665f, -1.99777114f, -3.23885751f, 5.87784004f, + 2.13776040f, 3.56758308f, -3.37774134f, -3.67526293f, 1.63700044f, + -1.69959962f, -0.99112594f, 6.03103638f, 1.67399430f, -1.28699589f, + 7.16759014f, 12.63490295f, 3.62937450f, -4.75982571f, 2.17861104f, + -2.03065681f, 4.30207729f, -0.46797156f, -2.96022511f, -6.02702332f, + 3.09229851f, -1.39771092f, -0.03471333f, 3.22175527f, 5.63565636f, + 1.78195477f, -0.63545251f, -3.99497652f, 1.46043062f, 4.60050488f, + -2.96651959f, -2.03159475f, -1.52386189f, -0.15129802f, -3.90390921f, + -0.63852370f, 0.79210538f, 2.35288715f, -5.55609035f, 5.36427498f, + -0.60248077f, -0.26181316f, 5.04884720f, 8.53192806f, 5.05080223f, + -6.56371737f, 1.52260923f, -7.13623667f, 6.49414349f, 2.33445597f, + -4.11490965f, -6.44347477f, -0.47079402f, -0.63467920f, 2.60399365f, + 1.05958164f, 3.66901422f, -1.05657935f, 1.88611507f, -6.37475634f, + 2.01480770f, 3.36020517f, -5.11001921f, -0.46132171f, 2.16525555f, + 4.21938848f, -2.08346295f, 2.86168146f, 1.26987600f, 6.76066971f, + -7.84916353f, 4.11700916f, 0.47985530f, -4.60113716f, 7.42062473f, + 6.37472820f, 4.37820530f, -7.12197018f, 0.01357239f, -7.90392113f, + 8.32131577f, -0.87593079f, -0.16994858f, -5.86345863f, -0.20697471f, + -1.37845206f, 1.63819647f, 1.59720242f, -0.74357712f, -1.88725603f, + -1.98357940f, -8.57950306f, -4.10104513f, 3.57231879f, -2.89855957f, + -0.11263305f, 2.78033924f, 1.53078973f, -2.93089223f, 0.73189604f, + 3.20563078f, 3.92601013f, -5.21916151f, 0.89163935f, -0.42978728f, + -6.70888853f, 4.56477976f, 1.20105875f, 3.83393812f, -6.27205181f, + 4.05993128f, -7.35513067f, 1.60660768f, -1.21052051f, 1.58191252f, + -1.37899971f, -1.20117283f, 2.93301678f, 1.06302834f, 1.38993621f, + -1.66884089f, -3.34452581f, 1.04498529f, -4.10412455f, -4.03310585f, + 1.61513603f, -1.09388447f, 2.11451387f, -0.94192362f, -0.23287666f, + 5.88265705f, -0.83010495f, -2.15317154f, -0.60276151f, -1.49265075f, + 3.93397975f, 5.45194483f, 1.45161700f, -2.57401872f, -5.59288931f, + 4.29170895f, 1.87151814f, 0.08362055f, -0.28767288f, 1.17675185f, + 0.85266006f, 1.30549634f, -5.60830832f, 0.19398519f, -0.83982587f, + 1.75940764f, -5.46077394f, 1.64495635f, 0.17102760f, -0.54459631f, + -2.21975255f, -0.37443402f, -2.08474159f, 1.85959935f, 11.19680309f, + -0.18611598f, -2.59765387f, 3.06330776f, -1.52183700f, -4.88415241f, + -0.75097847f, 2.58201051f, 7.40885210f, 3.58994508f, 1.62457407f, + 3.12514591f, -4.36833286f, 1.39830995f, 3.61003447f, -0.63837433f, + -3.62661815f, 3.78898096f, 2.92802262f, 5.87374496f, -4.38554621f, + -2.53411579f, -2.87311554f, -1.31391978f, -4.26736879f, 3.45099425f, + 1.58769250f, 1.73341393f, -1.08842182f, 2.27120280f, -1.78938174f, + -2.29940319f, 7.07046986f, 0.51426595f, -6.22928905f, 5.28968811f, + 2.31827855f, -4.20915890f, -1.27249205f, 5.92120600f, 3.19458675f, + 7.09252501f, 3.96577907f, 6.41484213f, -4.66009521f, 10.00181389f, + 0.51108456f, -4.62243366f, -5.18351841f, 2.12961674f, 5.10694027f, + 7.29412317f, 0.15912467f, -3.38902974f, -4.01918602f, -2.17383957f, + 0.13118666f, 0.27872476f, -0.92317247f, 3.51440644f, 1.84171486f, + 1.03378081f, 1.30569839f, -2.09583759f, 9.03952980f, -0.55187917f, + -2.04549074f, 1.08294606f, -2.65263700f, -2.93977118f, 1.88909876f, + 0.96043622f, 1.76579499f, 3.14314699f, 5.86394691f, 7.36944389f, + -7.04524136f, 6.68673229f, -5.52591467f, -2.19745898f, -4.32036924f, + 0.52971321f, 2.26268244f, 6.91575766f, -0.94590527f, -3.98923349f, + -0.12266219f, 0.24294075f, -1.07783222f, 1.87989080f, -3.57109427f, + 1.61553633f, 0.42486978f, 0.75852054f, -6.19481468f, -3.80570698f, + 2.39946675f, -1.93851781f, -5.42234039f, -6.34092760f, -2.52374983f, + -1.85044456f, 3.92693520f, 0.40042299f, 4.69742584f, 5.40483189f, + -1.02398944f, 8.89605045f, 0.64680403f, 0.89943957f, 0.76993859f, + -1.88244629f, 1.90714884f, 3.10836840f, -0.17064989f, 0.84892416f, + -6.94988108f, 1.92141032f, -1.36458397f, 6.39284658f, 0.45201308f, + 2.58823442f, 6.33375788f, -4.76916075f, -8.45738983f, -0.48962492f, + 2.40652561f, 4.56602001f, -3.34420681f, 1.86862195f, -7.01420689f, + -6.94657421f, -2.47419310f, -4.61693668f, -0.18822384f, -0.36949772f, + 2.01374269f, 4.11018658f, -5.11564064f, 8.04294395f, 2.88567662f, + -2.87645102f, -1.23238611f, -5.91409397f, -0.62205851f, 1.38689423f, + -0.01120412f, 5.25955677f, -1.98474956f, -3.72012186f, 3.00445986f, + 4.99141550f, 2.97457719f, 2.70827627f, 6.04544449f, -0.20756161f, + -10.87035751f, 0.80454814f, 0.33568168f, -2.48132324f, -2.84452009f, + 2.63126230f, -3.99351716f, -7.39294338f, 3.62798953f, -8.65815926f, + 2.65992808f, -6.98126554f, 3.09881067f, 0.67735767f, -1.15946686f, + 5.63180256f, -0.17694545f, -8.59651184f, 3.75297594f, -2.35913754f, + -0.20330384f, 5.49958467f, 1.00861740f, 1.42849684f, 0.00062013f, + -0.11073381f, 2.15207863f, 4.07368469f, 1.14344299f, -1.27953362f, + 6.64699316f, -0.73672432f, -8.55606937f, -0.19439441f, -4.14319754f, + -4.69964647f, -5.86446047f, 2.87106085f, -3.42714882f, -5.00668287f, + 6.22464132f, -7.72335291f, 4.05667686f, -5.72637177f, 6.35073948f, + -1.29593158f, 0.00813985f, 3.63368607f, -1.05764008f, -7.88486052f, + 3.73919106f, 1.41835213f, -1.04935634f, 0.65119827f, 0.03547254f, + 1.88996327f, 1.58701086f, -0.56215239f, -0.80187100f, 4.55604362f, + -0.67249978f, 1.41084409f, 7.86281586f, -2.38301182f, -8.50535774f, + -3.82098866f, -2.40856767f, -5.33439016f, -3.34747362f, 2.69389009f, + -1.64118791f, 4.52447939f, 0.04468334f, -1.48768258f, -0.69848812f, + -0.71123981f, 3.66259432f, 6.10314512f, 1.37305343f, -0.62758982f, + -2.99383426f, 4.20510864f, 1.48497128f, -0.08954811f, 2.43872309f, + -0.59880185f, 0.37431365f, 2.45458341f, -3.28401661f, -1.94629693f, + -1.93975246f, -0.26385683f, -0.45814323f, -0.18108580f, -3.74811840f, + -0.29739976f, -2.24116230f, -0.28150487f, -2.24421668f, 3.46930790f, + 8.35415077f, 0.05562943f, -2.81079793f, 1.10388446f, -2.82245207f, + -2.98102283f, -1.08132946f, 1.19089699f, 8.00183105f, 6.35385323f, + 3.72591257f, 4.59467506f, -5.74890900f, 4.42238331f, -3.36533451f, + 0.18350232f, 3.05606651f, 1.18788099f, 2.87450886f, 0.27472210f, + -2.80111074f, -0.66314960f, -1.96376896f, 0.75167024f, -4.72056293f, + 1.10629988f, -5.00775242f, 1.48246133f, -3.91681528f, -1.86573625f, + -6.17714882f, -0.67820001f, 5.69730282f, 1.04399037f, -4.93794823f, + 3.09619617f, 2.18692017f, -5.54232264f, -3.10046840f, -0.68972743f, + 2.81824327f, 3.04334164f, 6.13203907f, 4.14081764f, 1.02573645f, + 5.71970081f, -6.01574707f, -2.07346702f, 0.99554527f, 1.69641590f, + 0.66776669f, -0.80132431f, -2.03513098f, -3.42513680f, -0.06704485f, + -1.87195873f, -5.42428589f, -0.20748445f, -1.52408111f, 0.97084987f, + -0.48799962f, -0.45379883f, -0.26652339f, -1.20720732f, 3.94169855f, + -3.18480229f, -1.87440264f, -1.18028760f, 0.52011997f, -2.13437462f, + -4.52583313f, 1.69722807f, -0.89371562f, 3.37972403f, 6.38838720f, + 6.98663378f, -4.05421400f, 6.89512825f, -5.09085655f, -2.16257906f, + -3.33272719f, -3.01246452f, 0.37613097f, 1.80455804f, -0.36456174f, + -5.32273912f, -1.29978943f, -0.53685790f, -2.12896323f, 2.55506587f, + -2.57999182f, 3.40891910f, 1.36033249f, 0.83864629f, -2.88629293f, + -7.36048365f, 5.61314154f, 1.32668555f, -2.58041072f, -3.71943092f, + 1.60647738f, -2.74816346f, 2.47269106f, 0.85507953f, 8.39183426f, + 3.42624784f, -0.01519036f, 5.68412066f, 2.51771593f, 1.03045523f, + -2.08733034f, -2.44337177f, 0.81668580f, 1.30275154f, 2.99679208f, + -2.91957355f, -1.71337795f, 3.34979844f, 1.51825011f, 5.20375061f, + 2.27888370f, 1.38787699f, 4.23474550f, -4.05878592f, -4.85074377f, + -0.22794735f, 4.64402294f, 1.24391258f, -2.04935098f, 1.26285601f, + -7.51862240f, 0.62138438f, -1.95792389f, -0.96587181f, 0.85141110f, + 0.79354531f, 7.93766356f, 6.07677746f, 2.05947518f, 6.55480623f, + 1.44032848f, -0.70615625f, -0.07896036f, -5.08359432f, -0.01047915f, + -1.89632201f, 2.57555676f, 3.83779287f, 0.42850614f, 1.80754125f, + -0.06942326f, 6.35997963f, 6.06101418f, -0.97032297f, 5.71477222f, + -6.06671238f, -3.46607208f, -4.98306370f, 2.84659123f, -2.11025190f, + -0.04609144f, 5.26831341f, -9.56940651f, -3.67193556f, -1.71143103f, + -1.35221267f, -4.26226807f, -6.89146233f, 8.21761799f, 5.69823503f, + 2.28137946f, 1.88911343f, -1.44562483f, -1.60295713f, -0.52568185f, + -3.31892347f, -2.81997776f, 0.35287106f, 2.98202395f, -1.39432132f, + -2.70001364f, -4.14169264f, 3.50194883f, 4.12610435f, 5.52755260f, + 2.65859175f, 3.61353087f, -0.83027136f, -5.10652542f, -4.48625374f, + 2.06585884f, -2.76383352f, -0.64300913f, 8.19686604f, 0.96106279f, + 2.45952058f, 2.47275925f, -1.03288829f, -0.64897656f, -3.77937531f, + 4.27940083f, 2.58320260f, -0.57665241f, 1.87247813f, -3.81604433f, + -0.24543774f, -1.62118483f, -0.73075479f, -0.48533297f, 2.05016756f, + 0.45561486f, 0.03316188f, 0.77791005f, -1.56283605f, 2.36616826f, + 5.58082104f, -1.30925488f, -1.06329608f, 2.17189479f, -3.43008828f, + -4.71520567f, -2.56184673f, 0.17508316f, -3.25817418f, -0.41749167f, + 0.18119079f, -0.73181152f, 3.99792433f, -3.08002281f, -0.99143314f, + -1.83520067f, 1.18565679f, 2.98040128f, 5.67814350f, 2.35128760f, + 1.41600966f, 4.02718067f, -0.08193968f, 0.64636409f, 1.35931289f, + 2.37125754f, 1.75978124f, 3.90977740f, 1.50662971f, -2.84089065f, + 1.29824126f, -3.38730979f, -1.61005294f, 0.58292413f, -0.03019404f, + -1.57986510f, -0.56102908f, -3.03128719f, 0.51644313f, -2.01147819f, + 0.98400700f, 3.00028515f, 0.74579155f, -3.37098312f, 0.93339360f, + -1.29018497f, -2.14695001f, 1.30411184f, 0.71501279f, 7.47793055f, + 4.06516457f, 3.50772929f, 3.52762985f, 0.55643129f, 0.32272506f, + -4.30955982f, 2.49414706f, 2.07820845f, -0.34377906f, 4.39805031f, + 2.77561307f, -3.91292810f, 2.43981409f, 0.18861845f, -2.76658440f, + -4.97148752f, 3.25273705f, -0.08929539f, 0.19818619f, -5.83767605f, + -0.97381884f, -5.68745661f, -5.42433214f, 3.98769903f, -0.40394354f, + -1.83387578f, -0.80109525f, 1.47454357f, -3.14899540f, 0.80130816f, + -2.26348829f, 4.06121159f, 6.13077354f, 5.31226397f, 2.94966197f, + -3.65217376f, -1.08136678f, -7.14119816f, -0.85269439f, -0.70365787f, + -0.81598872f, 3.62807679f, 3.08123684f, -7.82739496f, 4.07951784f, + -0.14204243f, -0.66969109f, -5.07225513f, 2.88492823f, 0.47202343f, + 0.72683257f, -6.84280777f, 0.41807127f, -5.09785986f, -3.74514675f, + 2.03936672f, -1.06096244f, -1.52409148f, -0.97046643f, 2.27491093f, + -1.55597985f, -1.29215479f, -0.79737484f, -0.01979581f, 7.65407991f, + 5.54527044f, 4.04147148f, -2.64274883f, -1.89246953f, -3.89547634f, + -1.06029689f, -2.85982800f, -1.41247237f, 1.55836034f, 3.38194537f, + -2.97655582f, 0.87510300f, 1.26282072f, -1.77029657f, -3.57144690f, + -4.19456863f, 0.53179169f, -1.42221975f, -3.09144497f, -0.84294832f, + -5.02758694f, -2.68011904f, 0.89156240f, -0.34783912f, 4.64484835f, + -2.34453487f, -1.28573155f, 0.09990287f, 0.01828218f, -1.79960847f, + -1.06579173f, 1.08763921f, 0.43687880f, 3.24747229f, 3.83097172f, + 1.07253766f, -1.33810723f, 0.76530832f, 1.58660865f, 5.60743904f, + -3.54124737f, -0.89264417f, -3.83942485f, -1.03707337f, -1.61659896f, + 1.65349591f, 1.72698796f, 4.96013832f, 0.78927267f, -0.35563886f, + -3.48121166f, 3.79677629f, 2.59023166f, 2.74940348f, -2.17589283f, + -5.91757107f, 2.43766379f, -4.15906048f, -1.74731481f, -2.49113035f, + -0.57349741f, -4.04455185f, -1.46939647f, 2.21418452f, 0.09153593f, + 2.23016739f, 7.91880608f, 4.04464149f, 0.07706618f, -2.41892862f, + -2.19280314f, 7.61760712f, -5.89153862f, 0.33551922f, -1.70855618f, + -0.30561331f, -0.14341974f, -2.48878574f, 1.31269515f, 3.45388412f, + -0.02453184f, -0.12132037f, -4.27916241f, 1.25179088f, 4.09455204f, + -1.83801770f, -1.86743176f, -4.02864933f, 3.44515228f, -4.39244986f, + -0.56988084f, -1.69426417f, 2.18254852f, -4.78135824f, 1.73193693f, + -2.27968478f, -1.49523509f, 2.51696730f, 4.03677559f, -2.03679037f, + 1.32167840f, -2.22570705f, -2.74843621f, 6.29655170f, -3.67230225f, + -1.86765468f, -0.14842367f, -1.21552539f, -0.92038238f, -0.51692355f, + 1.08433771f, -0.01929832f, 0.15660909f, 2.31432915f, -3.86507082f, + -0.69797570f, 0.13505173f, -1.50951028f, -0.69980979f, -1.51297045f, + 3.63725281f, 0.13388813f, 2.73131752f, -0.96528149f, 4.92000961f, + -5.92699385f, 1.69444644f, -1.17121375f, -2.33710480f, 1.35302818f, + 1.39608085f, 1.68293881f, 0.94960749f, 1.89011908f, -4.08865070f, + 0.13722643f, -1.62849212f, -0.19044125f, 1.37906075f, -3.92504406f, + -1.45033538f, -0.42085981f, 3.38237071f, -3.06508875f, -1.39420545f, + 1.13067436f, 0.92206454f, 0.49917889f, -2.74508023f, -2.19221997f, + 1.77914095f, 0.10854459f, -2.62178278f, 2.35042715f, -0.15322030f, + -0.67014873f, -1.75627899f, 2.64074945f, 2.76339936f, 2.67275214f, + -0.62736398f, 0.58251178f, -4.64895678f, 5.50419283f, 2.53566456f, + -2.44196153f, -0.07845879f, -2.80389643f, -0.64810950f, -0.05813205f, + 1.67155504f, -2.69673729f, -1.72486305f, -0.53888649f, 1.86805439f, + -1.37128329f, -5.37923479f, -2.08133769f, 0.58187997f, -1.39498150f, + 0.21874082f, 4.33726025f, 6.29673958f, 0.72312093f, -3.32683516f, + 1.73482585f, -0.00766110f, -2.63785434f, -0.13511759f, 4.07195950f, + 0.94139838f, 3.15717316f, 1.53720927f, 1.87664819f, -2.33655119f, + 6.18176556f, -2.73912525f, -2.45279956f, 2.20392370f, -0.56854641f, + 0.98915887f, -2.64472580f, 2.40633702f, -4.93327999f, -1.28942823f, + 0.98247659f, 1.31774998f, 0.07669818f, -5.91169453f, -0.43135011f, + 1.27404964f, -0.59787154f, -0.22716975f, 0.74409103f, 10.27316475f, + -2.29192710f, -2.19403267f, 3.78925133f, 3.19553399f, -4.42490482f, + -0.80781460f, 2.16568565f, -2.54165983f, 2.54885101f, 4.18779039f, + 1.73079813f, -1.48891807f, 11.60153770f, -0.98686743f, -2.88813901f, + 2.32898521f, -0.36101711f, 2.34522438f, 0.29057693f, 1.39800644f, + -4.31848240f, -3.21217132f, 0.11740226f, -1.21613467f, 0.57248503f, + -4.44853830f, 1.54665899f, 3.14459944f, 1.76809108f, 0.26693153f, + 0.86913753f, 9.47121620f, -2.07677889f, 2.08578467f, 1.30181742f, + 1.58683562f, -3.52757788f, -1.32763624f, 0.79821301f, -2.19358301f, + 1.17707348f, 6.01983643f, 4.11209440f, -2.04209709f, 7.00413418f, + -1.84904683f, -1.32542288f, -0.01298118f, 0.70377320f, 0.27815005f, + 2.07879829f, -0.71606725f, -4.94399881f, -2.11898828f, -0.39051518f, + -2.21034360f, 3.05337906f, -1.56889665f, 1.97065282f, 2.61320901f, + -0.34063196f, -0.57001418f, -2.13183641f, 3.48879004f, -0.12067288f, + 0.48568326f, -1.81424558f, 2.28868723f, 1.44802380f, 1.25918829f, + -1.76415455f, 5.35742331f, 3.50682044f, 4.71371317f, 5.89110756f, + 8.51241302f, 4.07391453f, -0.05887252f, -0.18202400f, 2.27119660f, + 6.78274727f, -2.87470293f, -5.14336634f, 0.76443815f, 2.04625130f, + -0.43199503f, -1.01353514f, 2.42951298f, 2.35641170f, 0.32345510f, + -4.04195738f, -4.77967072f, 0.26564783f, 6.11455107f, -2.53868008f, + -3.11839914f, -1.04203856f, 5.17195654f, -4.15338612f, -3.84149241f, + 0.48130888f, 3.09706950f, -4.18423653f, 5.26233864f, 3.55831861f, + 3.75122595f, 8.14969349f, 6.80038738f, 4.68907356f, -1.40135396f, + -3.19287133f, -3.15895939f, 8.77363205f, -4.48793411f, -3.80537176f, + -2.40145254f, -2.74341679f, -2.02862644f, 5.33402443f, 9.25365734f, + 2.50246119f, 0.32847846f, -1.50564361f, -4.26163197f, -1.40994716f, + 2.50708485f, 0.44500345f, -0.62516934f, 4.09846306f, 5.29355669f, + -4.02224922f, 0.73442125f, 0.46648952f, 0.67028689f, -6.30715466f, + 6.56297970f, 3.80854273f, -5.19078207f, 4.98839283f, 7.59161472f, + 0.46010983f, -2.10227895f, 0.29324162f, -2.67019558f, 4.57838106f, + -3.02338457f, -3.08647728f, -2.00112700f, -3.81710315f, -0.08346784f, + 1.69288683f, 5.68807268f, 3.29351830f, 0.54618967f, 1.83540761f, + -5.38810253f, 0.51326782f, 4.40081882f, -4.03805828f, 0.49482727f, + -1.36024392f, 2.91845679f, -2.00959015f, 2.47489738f, -1.43354976f, + 1.92024410f, -6.55897284f, 1.79488957f, -0.89570928f, -6.13094234f, + -0.45504010f, 2.35239482f, 1.29039919f, -4.78849840f, -1.52545333f, + -6.50420475f, 2.99257326f, -0.55620033f, 0.26807702f, -2.52090979f, + -4.59419632f, 0.57965040f, 2.19423151f, 2.04760551f, -0.57048106f, + -2.20812702f, -0.04777686f, 1.38053393f, -2.71448946f, -1.06219673f, + -3.62008905f, 1.85719645f, 1.28355026f, -2.76315832f, 1.65295160f, + -4.01645803f, -3.10454416f, -0.65713316f, 1.22384977f, -0.70416176f, + 4.45064926f, 1.31602776f, 2.06907344f, 2.48872757f, 4.25775290f, + 3.50504255f, -0.68262041f, 1.29799378f, -1.01969171f, 2.98593879f, + 0.12607655f, 0.37219539f, -0.84196299f, -3.80019331f, -1.82315290f, + -0.38489276f, -1.45200360f, -4.00882292f, 0.61042011f, -0.16738498f, + 1.33787775f, -2.26938057f, 1.03656030f, 8.89089870f, -1.60370600f, + -5.38691807f, 5.72182989f, 2.72854710f, -6.18535757f, -3.13408709f, + 2.79175353f, 5.18425512f, 9.46434212f, 2.40110517f, 1.11330092f, + -3.57366538f, 4.80967665f, 0.40691876f, -3.65484858f, 0.92398167f, + 2.53852940f, 3.17747331f, 2.14199781f, -1.69107199f, -1.91864693f, + -3.18452644f, -2.42408276f, -2.14332366f, -1.35526609f, -4.50732136f, + 0.58234072f, -1.81547785f, 0.57311213f, 1.10584176f, -0.97226644f, + 11.73174381f, -2.00559855f, -1.81175601f, 2.33131361f, 0.49264961f, + -0.42245382f, -1.37528467f, 1.55768061f, 0.21152198f, 13.08896351f, + 10.33674145f, 5.77929306f, -6.19886398f, 5.67007637f, -6.61288071f, + -2.58029866f, -4.05192375f, 1.77221894f, 0.29821560f, 5.23508501f, + -5.09560966f, -0.97536200f, -5.17957878f, 1.02876794f, -4.52072096f, + 2.22126532f, -4.81708670f, 0.44538212f, -2.30738068f, 3.15900373f, + -4.99227905f, 0.82632786f, 9.65415478f, -0.63819492f, -3.25479436f, + -0.13276935f, 0.21337092f, -2.22116399f, -3.04922724f, 0.65568435f, + -0.10706246f, 4.58047390f, 7.80782652f, 5.49080181f, -3.97114491f, + 6.43327618f, -6.54772758f, -2.10962629f, -0.79831678f, -0.08316499f, + 2.48658133f, 4.14070511f, -0.59806836f, -4.58636141f, -0.31166920f, + 0.31757897f, -3.92562199f, 0.65357721f, 0.55871534f, 1.71843934f, + 1.62395024f, 0.00695819f, -4.56716251f, -3.76420808f, 4.24979544f, + -0.86128616f, 0.23126510f, -6.32968998f, 1.83346081f, 3.81335950f, + 2.98407745f, -1.80454743f, 6.61764765f, -1.39372075f, -0.86780751f, + 7.24317265f, 2.24205112f, 1.05702817f, 0.55431479f, -1.54557061f, + 3.36389136f, 4.70898724f, 1.11327887f, -3.78462076f, -3.63381767f, + 2.86510396f, 0.74203897f, 0.81488025f, 3.54250598f, 3.24824381f, + 3.19000244f, -0.58995843f, -7.05670738f, 3.18306041f, 3.95191574f, + 0.81820154f, -1.91068232f, -2.05426741f, -1.05589008f, -3.18377590f, + -1.86278260f, -8.80374908f, 0.93416154f, -4.60517359f, 8.38999462f, + 5.26356745f, -8.89992714f, 8.95298958f, 4.22590351f, 1.00351548f, + -6.90151119f, -8.07641125f, -4.82450199f, 8.02293015f, 4.11661243f, + 0.95457208f, -7.07843113f, -4.30524826f, 5.02697992f, 5.21011686f, + 0.80132771f, 3.23420191f, 3.82452774f, -2.13171721f, -7.88879967f, + 1.31062031f, 1.90848613f, -3.51572514f, -3.75684500f, 3.62577081f, + -5.76075602f, -2.79389215f, 0.32598805f, -4.28981733f, 4.21048594f, + -3.84532523f, 3.19815183f, -0.40756655f, -2.19974327f, 6.25655174f, + 3.42396951f, -1.88986623f, -1.92803884f, -2.97344875f, -0.09756154f, + 5.24342251f, -0.72513700f, 1.06113195f, -1.30720282f, 4.69107103f, + 0.58984971f, 2.33985567f, 1.46385121f, 3.16576266f, 6.77769995f, + -5.92685127f, -12.61141014f, -2.83663774f, 4.90253258f, -6.32688522f, + -3.00096869f, 2.38634992f, -7.21459866f, -5.89208746f, 2.84085894f, + -1.21792030f, 6.70161343f, -4.00450230f, 5.29881001f, -1.45574808f, + 0.77542424f, 1.38336325f, -0.21572059f, -3.38088870f, 2.33249640f, + 0.68824625f, -3.68440270f, 0.33481622f, -0.39239681f, 0.14560902f, + 1.61039007f, -3.11967754f, 2.49372435f, 2.68783092f, -1.17559779f, + 0.95257235f, 4.35451412f, -0.56818569f, -7.32110357f, -7.58534050f, + -2.10573673f, -3.34446383f, -0.32183546f, -0.78525496f, -1.76974547f, + 5.19060802f, -2.11319876f, -3.41755080f, -0.36864156f, 1.32680905f, + 0.45004874f, 6.17223930f, -1.60707474f, 0.46096295f, -3.88852644f, + 1.84729624f, -0.03412050f, 0.99224162f, -2.05553341f, 3.47793245f, + -0.06305170f, 0.51314175f, -2.91650558f, -1.78121483f, -2.85465693f, + 0.24649808f, -2.70376635f, 0.42334458f, -1.13862336f, -0.98409218f, + -0.96593523f, 2.22128963f, 0.53402066f, 3.33979344f, 8.57430458f, + 2.34217858f, -2.40062976f, 5.81624222f, 1.13290989f, -5.06850052f, + -4.72865725f, 1.82859278f, 6.78569555f, 8.56885242f, 2.76462936f, + 0.33891773f, -2.81092787f, 0.79498398f, -2.27208567f, 1.55182552f, + 2.17166376f, 6.12517643f, 3.56859684f, 0.27685475f, -1.38408327f, + -1.03533340f, -3.46618199f, 0.79240030f, -3.89390516f, -0.55852515f, + -1.16367757f, -0.07008934f, -2.20105195f, 3.81210446f, -0.66834474f, + 0.43603873f, 10.92334938f, 2.48571420f, -6.34997845f, 4.23135757f, + 0.45045292f, -4.13489866f, -3.92324209f, 1.88537407f, 2.57159734f, + 9.90973091f, 4.37453461f, 7.34546280f, -2.51120615f, 11.12575245f, + -3.23452854f, -2.49947500f, 1.39819741f, -3.78950691f, 2.40617585f, + 5.10036278f, -3.55743456f, -6.42888737f, -2.51929998f, -1.90880990f, + -1.81618094f, 1.60946512f, -4.09737110f, 1.96408439f, -1.90115595f, + 2.44444203f, -2.31254292f, -4.01332951f, 8.65541840f, -0.58626485f, + -4.02226830f, 0.43893200f, -3.78272748f, -5.46277428f, 0.01306701f, + 0.61185312f, 0.24469066f, 1.30214953f, 5.87789631f, 8.75197792f, + -5.31634712f, 3.43556309f, -5.90755081f, 0.54375106f, -2.48162293f, + -3.51843548f, 2.55853295f, 5.06387186f, -2.09662485f, -3.00377345f, + -3.21781397f, -0.14537808f, -4.65453672f, 1.92747557f, 0.41553855f, + 4.09379959f, 0.83387995f, 1.50868511f, -6.54959488f, -8.38881016f, + 5.50689125f, -2.88616610f, -1.21597648f, -0.23817590f, 1.50816703f, + -2.26873541f, 2.29862142f, -1.61143053f, 5.97371244f, 4.71440220f, + -0.20635787f, 8.85926723f, 0.56064367f, -1.04103339f, -4.47060108f, + -2.63824081f, 3.06782055f, -2.07702565f, 3.38269401f, -1.59988797f, + -3.80122590f, 2.35341501f, 2.69095278f, 3.87612104f, 1.89984226f, + 0.95496917f, 3.14841127f, -5.84543085f, -7.24945450f, -2.65708590f, + 2.87417006f, 0.97556210f, -3.75203967f, 1.55287778f, -7.43401051f, + -1.29005826f, -3.40252638f, -4.01049423f, 2.82721639f, -1.21479535f, + 8.54563904f, 7.39749908f, -0.61361837f, 7.60177565f, 1.65812778f, + -0.83008504f, -3.60961151f, -7.69062138f, -1.26275063f, -4.17071676f, + 5.28448200f, 4.04685593f, -1.18231702f, 1.15276611f, 1.58620787f, + 6.75060844f, 3.29332161f, -0.67640316f, 5.78984785f, -3.14913464f, + -6.41867924f, -2.58316016f, -2.04366302f, 2.01089478f, -3.81723452f, + 3.63843751f, -5.13238430f, -3.79432917f, 4.86581373f, -1.06922054f, + 3.95978498f, -0.78166616f, 8.35650539f, 5.35834265f, 0.35594034f, + 9.41657066f, -0.84108615f, -6.54425859f, -3.44328952f, -6.55536795f, + -0.08963367f, -1.53906262f, 0.17658240f, -0.13108420f, -0.44371247f, + -0.78411150f, 2.64754868f, 9.66306782f, 1.70506203f, -0.31588936f, + 4.31715870f, -6.16665173f, -10.43371868f, -3.72962189f, 4.35245228f, + -1.75867891f, -4.20046234f, 8.62637043f, 1.45946813f, -3.30153608f, + 0.85179043f, -2.66643381f, 3.01863337f, -2.52916121f, 8.35405540f, + -0.37298933f, -0.89473486f, 6.88681793f, -4.46370125f, -7.50776386f, + 3.80255938f, -3.55003357f, 1.43528831f, -2.20383263f, 2.34999895f, + 2.03803205f, 1.94830751f, -1.85976326f, 0.97718471f, 5.53710842f, + -0.80560827f, 0.23925614f, 5.98795223f, -2.03578377f, -7.77835321f, + -2.79955530f, -1.88185954f, -2.49112058f, -0.76095992f, 2.71161270f, + -0.55918610f, 0.83789903f, -1.42063200f, -0.61528748f, -4.18273115f, + 1.76384258f, 4.21265936f, 5.50964785f, -0.93324339f, 3.83215356f, + 1.52210593f, -0.91594946f, 1.31148386f, 3.20160103f, 1.24493563f, + -0.72693497f, 1.84716725f, 3.09897518f, -1.34605026f, -1.17511916f, + -1.05526352f, -1.08590937f, -1.41319299f, -3.75052118f, -2.67095542f, + -0.76179552f, -3.32081509f, -1.04692316f, -1.30194843f, -1.98795474f, + 5.01223469f, 0.21895903f, -1.85535169f, 3.12362719f, 0.16198632f, + -3.86784005f, -2.03062248f, -0.15415624f, 8.22020721f, 4.83055592f, + 4.50315666f, 4.19443417f, 0.42727345f, -4.67786789f, -5.18739986f, + 2.53988838f, 3.19683266f, 1.80313504f, 1.94664574f, 0.59795094f, + -4.21626759f, 0.50492239f, -0.41232634f, -0.99224532f, -3.94929314f, + 1.74060190f, -0.92474866f, -1.00664830f, -6.17397356f, -1.33146775f, + -3.78111315f, -4.91876888f, 2.50303864f, -0.34890354f, -1.25013232f, + 0.38168997f, -1.84135628f, -4.46107960f, -4.05920792f, -2.61709857f, + 0.71046209f, 9.80566883f, 6.34086990f, 2.73394704f, -2.03342366f, + -2.21424174f, -5.56514263f, -4.74755144f, -2.20672894f, 0.09010231f, + 1.70423889f, 3.19200158f, -6.99027634f, 1.14216340f, 0.05824995f, + -0.76996505f, -6.51575899f, -0.41109252f, 0.78229940f, 1.36170781f, + -5.65170193f, 1.12221193f, -4.60430050f, -4.40174437f, 4.01805925f, + 0.10774946f, -2.77991009f, -0.18023163f, 0.02151692f, -1.77023101f, + -1.86639869f, -0.69443607f, 4.92290831f, 6.83520412f, 4.27372265f, + 6.54272366f, -7.59249687f, -1.40776849f, -3.52368808f, 1.01398587f, + -3.58802676f, -0.35658866f, 1.14716864f, 3.75847244f, -2.30159235f, + -0.72130895f, -0.24564353f, -1.77531350f, -3.08677864f, -0.73486501f, + -1.20357263f, 0.60789430f, -3.46990204f, -0.20668676f, -5.46096087f, + -5.22016764f, 0.98259866f, 1.81012678f, 3.92534304f, -2.94997001f, + 1.65154219f, 2.27040243f, 0.99095678f, 0.09144652f, -0.99103236f, + -1.11210847f, 0.78181303f, 2.38706732f, 2.96695375f, -0.17279971f, + 0.31143007f, 1.35465562f, 2.03586054f, 6.19515753f, -3.14652419f, + -2.89027119f, -3.26665854f, -1.93043876f, -0.46601450f, 1.07655203f, + 1.74946189f, 4.02148342f, 0.69275337f, 0.50094581f, -4.07613230f, + 2.98369169f, 4.24537849f, 0.49480581f, -2.02408123f, -2.02068973f, + 6.54505825f, -5.19377470f, -0.12596917f, -0.70204186f, -0.98308045f, + -3.19708824f, 1.63609934f, 1.35475993f, 0.16313422f, 4.13918924f, + 7.69187021f, 3.72601676f, -1.97790039f, -1.16739464f, -3.31835508f, + 8.14553452f, -1.78718984f, 1.21505618f, -3.84255409f, -3.21992350f, + 0.07376552f, -0.81223297f, 3.57002878f, 1.48521733f, -0.45995998f, + 0.30551746f, -3.33944130f, 1.39538884f, 1.84758544f, -0.21494150f, + -2.27316713f, -4.37771225f, 6.48841667f, -5.00251961f, -0.45162797f, + -5.01056004f, 0.70199943f, -4.60057783f, -2.22394514f, 0.07777429f, + -1.49820781f, 3.47308421f, 6.13231564f, 1.18605387f, -4.78924608f, + -3.49548388f, -2.73382568f, 6.24617863f, -2.74291611f, -1.03833354f, + -2.20752788f, -2.33219409f, 1.48633552f, 1.65796840f, 4.95045471f, + 2.58479190f, -0.90922785f, 0.71312457f, -4.44465590f, 1.37020862f, + 2.37683725f, 0.18805164f, -3.28422308f, -1.64939332f, 3.64181972f, + -3.75277281f, 3.67203593f, -0.11204052f, 2.24140930f, -3.90657187f, + 2.56883717f, -1.44016707f, -2.83842611f, -0.29104578f, 2.17757058f, + -0.71431804f, 1.36911654f, 0.85083604f, -1.60110259f, -1.97247636f, + -1.61163378f, -0.81236130f, -0.38993555f, -3.03631902f, -0.38213277f, + 0.06394482f, 3.19348621f, 0.36771113f, 1.36763072f, 2.49159527f, + -0.39599860f, -2.69996762f, -0.97561121f, -2.97563028f, -0.49662948f, + -0.17564940f, -2.79042959f, 0.72395414f, 2.07260203f, -0.99439794f, + -2.20248008f, -0.07389921f, 0.65536159f, 4.73054695f, -0.63917702f, + 0.58788192f, -3.60156059f, 6.59609890f, 3.88419437f, -3.38469863f, + -3.56237841f, -2.03295064f, 0.07279694f, 3.71804547f, 0.79928309f, + -2.13411403f, -1.13909864f, -0.34193408f, -1.00338125f, -1.44231665f, + -5.39835978f, -0.45086145f, 1.16064668f, 2.58335257f, 2.10072684f, + 4.64244223f, 7.10090065f, 1.01974952f, -4.44687223f, 2.99792576f, + 1.10303724f, -1.22736573f, -3.91514421f, 3.07458854f, 2.18765211f, + 3.34481716f, 2.46166849f, 2.99648619f, -0.94046807f, 5.55028200f, + 0.92199719f, -0.83934361f, -0.72042274f, 0.84869325f, 1.46914721f, + 0.85937387f, 4.77306223f, -4.06436539f, -2.59847593f, 2.44828081f, + 0.50484699f, -2.71092367f, -6.39010477f, 0.91778028f, 3.25469685f, + 1.30310678f, 1.35258150f, 3.56171441f, 7.82435083f, -2.51527429f, + -4.24328852f, 2.36876059f, 1.94595242f, -2.59290171f, -6.62389565f, + 3.32567835f, 2.13659120f, 4.09299326f, 3.48293996f, 2.64965177f, + -3.19157362f, 13.37204266f, -0.50297594f, -4.57448196f, 3.95582604f, + -0.69038916f, 0.10098404f, 1.18737555f, 3.65761185f, -5.69623756f, + -2.03357077f, 1.02868807f, -1.38448596f, -0.05690211f, -8.48874187f, + 0.56755424f, 1.45485961f, 0.66273880f, 0.06495565f, 1.79539490f, + 8.46864319f, -1.22696662f, -1.87585378f, -0.99768794f, 2.72801924f, + -0.66980243f, -2.31924677f, 0.33271110f, 0.11666083f, 1.86980045f, + 5.95332909f, 7.38583708f, -2.80956483f, 6.79227638f, -6.78070831f, + 1.21884382f, -1.40695429f, 0.90236962f, -1.13695288f, 0.50760663f, + 1.00955284f, -5.39029121f, 0.24987072f, 2.24283314f, -4.02145576f, + 2.18057394f, -3.35627747f, 1.26061773f, 1.30342579f, 0.11311233f, + -1.11199212f, -4.06509686f, 5.82649660f, -1.24059582f, 5.51652861f, + -1.90937877f, 1.10658336f, -0.47065550f, -2.39167786f, -1.95931304f, + 4.12717247f, 1.15396059f, 1.26015663f, 7.97836876f, 7.33633423f, + 2.27785325f, -2.83802366f, -2.74850106f, 0.86126029f, 6.18781090f, + -1.43707538f, -6.97134876f, -3.25486469f, -1.95214593f, 0.91066706f, + 0.89637989f, 1.06481194f, 6.25791073f, 0.81779671f, -1.08384395f, + -3.21191931f, 2.04216075f, 4.76030350f, -2.37217665f, -1.42571259f, + -6.35876131f, 4.62536526f, -5.40060568f, -3.14868999f, -1.00587153f, + 1.80662942f, -7.03201485f, 6.08373499f, 0.99862772f, 2.21717811f, + 4.06814623f, 6.02428913f, 5.33422756f, -0.87013257f, -2.22477579f, + -2.51505303f, 5.82925224f, -0.82854009f, -4.30698347f, -1.75007713f, + 2.08352375f, -2.25235629f, 1.17517352f, 5.77717733f, 2.27472878f, + 2.72778273f, -1.95411634f, -4.52602863f, 1.13983536f, 1.16340065f, + -2.02740526f, -3.11290503f, -1.94906235f, 1.54855204f, -4.52984142f, + 1.97465122f, -1.79415476f, 4.03510094f, -8.45349979f, 10.87430096f, + 2.19863629f, -5.39083815f, 5.86213875f, 6.25744534f, 6.52600002f, + -4.72149038f, -1.75254321f, -5.51459169f, 7.03155518f, -2.01889277f, + -4.58441257f, -3.61226106f, 0.42395937f, -0.93263882f, 2.28703761f, + 2.80611467f, 2.59498215f, 0.65989012f, -1.51268566f, -4.49465561f, + -4.70453882f, 5.44696808f, -4.37603617f, 0.46670085f, 2.82488608f, + 2.18854523f, -2.04817152f, 1.19557285f, 1.53618634f, 4.44758606f, + -7.31593513f, 7.43966007f, -3.55480957f, -5.29834652f, 2.14622784f, + 1.65194583f, 2.71262598f, -4.86145496f, 0.79726243f, -8.88541985f, + 1.19627261f, 0.79660845f, -1.98016644f, 1.03741014f, -3.93128228f, + 1.05535269f, 2.01378822f, -0.46086323f, -0.77754641f, -1.43942690f, + 0.49809402f, -2.27861357f, -3.29815221f, 0.38201320f, -3.98481083f, + 4.88261318f, -0.44555628f, -2.57224536f, 2.35001850f, -2.65835261f, + -2.43422794f, -2.97889376f, 1.07349825f, 1.88157082f, 4.74075413f, + 0.60376728f, -0.48894715f, -1.15800071f, 4.68110943f, -0.86976886f, + 1.49192941f, 0.62665290f, 0.20652676f, 0.53916287f, -1.45706177f, + 0.66133004f, 1.34405875f, -4.27689552f, -0.20838106f, -5.14266443f, + -1.29718637f, -1.74506426f, -0.86022055f, -3.57553625f, 0.46880072f, + -1.25287139f, 3.28596354f, 11.33191013f, 1.23942876f, -3.87616491f, + 7.57880497f, -0.22940339f, -5.68512678f, -1.94969654f, 5.85449600f, + 3.75705457f, 4.24395847f, 1.60086083f, 2.62553668f, -0.93964291f, + 5.84753895f, -0.79931092f, 0.48274064f, 2.07170033f, 3.02243996f, + 2.63509989f, -0.76043403f, -1.64048159f, -6.17683458f, -3.09974527f, + -2.12773156f, -0.89379883f, 2.82242465f, -1.99981332f, -0.08763933f, + 0.01921120f, -1.94142103f, 2.48067307f, 0.41083777f, 8.24922180f, + -1.84516132f, -1.39224625f, 5.03956223f, 0.49562740f, -5.28296328f, + -0.20005548f, 3.13672113f, 0.51187158f, 7.11563921f, 6.43059587f, + 3.48430967f, -5.37095928f, 8.03863049f, -5.53923941f, -2.16421175f, + -3.77641368f, 3.29633045f, 5.04030085f, 2.25945377f, -3.04169011f, + -2.16198015f, -2.49559617f, -0.26252726f, -6.99201345f, 2.87374353f, + -0.12568980f, 0.23314142f, -1.32087135f, 4.39030552f, -0.24638844f, + -4.37242651f, 14.09276772f, 1.23987353f, -1.72249663f, 0.31124914f, + -2.13725138f, -3.74915648f, -1.87147236f, 0.47318631f, 1.13337576f, + 3.00416899f, 8.82548523f, 4.80538750f, -5.28486395f, 5.51870108f, + -5.15801477f, 0.95712411f, -1.50416136f, 2.34657240f, 4.20726633f, + 5.56757259f, -3.30645251f, -3.39945269f, -2.68488026f, -2.53525281f, + -3.15145874f, 2.74529529f, -0.96283442f, 2.87778258f, 0.22186530f, + 1.24905694f, -7.07941198f, -5.45916176f, 3.46988297f, 0.92430985f, + -0.98330998f, -2.23672342f, -3.03262734f, 0.73941302f, 0.98004431f, + 0.83219361f, 7.17411804f, 4.27849865f, 0.14765590f, 8.61269569f, + 9.04497051f, 1.53991723f, -2.08305025f, -4.34939337f, 0.63786775f, + 2.60098696f, 0.02432060f, -1.48516297f, -4.06825686f, 5.12420368f, + -0.75312757f, 1.96927559f, 4.91575956f, 3.41533065f, 3.62557888f, + -4.35002136f, -5.91343403f, 0.45026422f, 4.93286371f, 3.45830250f, + -4.39032364f, -0.51697755f, -7.41543341f, -3.06703568f, 1.01196158f, + 2.47106576f, 5.54014874f, -4.65312243f, 8.61000633f, 8.25905323f, + -1.41497111f, 8.69221878f, 0.40090930f, 1.11325574f, -1.67089832f, + -4.01080132f, 1.07925677f, 2.68086481f, -0.73093414f, -1.35081220f, + -7.85765076f, -5.98989439f, -0.04651213f, 4.63693142f, 2.07757711f, + -0.22652936f, 3.45525455f, -0.69198442f, -10.39761639f, -2.02106953f, + 4.77755499f, -2.67665577f, -1.72481167f, 4.49634743f, -2.55717134f, + -4.55044937f, 0.46377492f, -3.08933020f, 3.86891365f, -2.79104614f, + 8.36974335f, 0.86471701f, -5.39342690f, 12.54906940f, -0.41536295f, + -5.29502535f, -3.94430566f, -5.67391300f, -4.65079165f, 2.22505951f, + -0.30000746f, 2.27855444f, -4.81604433f, -1.73440599f, 4.68784523f, + 5.00208044f, 0.18863934f, -1.74989462f, 3.17923450f, -1.59773099f, + -12.59962940f, -1.54495025f, -0.00576371f, 1.79913878f, -2.43449807f, + 1.49516344f, -3.90507102f, 1.68647158f, 4.50177765f, -5.32286358f, + 3.47539330f, -2.90529680f, 1.61576962f, 0.83679676f, -5.55615807f, + 3.78939056f, -4.46644831f, -5.95550919f, 0.37808037f, 0.51334500f, + 1.74658906f, -0.82085419f, -0.65387219f, 3.67790437f, 0.03758264f, + -2.42622781f, 1.83335185f, 4.73835945f, -0.83536482f, -0.03993917f, + 3.78230667f, -4.81265640f, -8.26869011f, -1.30363441f, -2.09106350f, + -3.96769738f, -1.89037073f, 0.38682747f, 0.05434489f, 5.72213697f, + 0.55685395f, -3.47729349f, -1.11535001f, 2.09416127f, 5.08877802f, + 5.72183466f, 1.29632664f, 0.16822398f, -2.43180108f, 3.49967623f, + 2.15753818f, -0.26548505f, 3.24446392f, -0.00599277f, 1.08215356f, + -0.23225522f, -2.40723038f, 0.18496060f, -3.70608735f, -0.19918591f, + -1.64028871f, 0.80792952f, -0.85334057f, -2.52314138f, -3.12099195f, + 0.17949918f, -0.82650864f, 2.32224989f, 9.56476116f, -0.20134282f, + -0.48428559f, 2.86784410f, 0.07289505f, -3.92880869f, -2.11887884f, + 0.59164631f, 6.31267452f, 7.49149418f, 2.88749456f, 2.40504885f, + -3.57608175f, -1.48019314f, -0.69410253f, 0.90275228f, -0.34111357f, + 2.19190216f, 3.39090061f, 3.39631820f, -5.19105434f, 2.67546582f, + -2.56549048f, -0.59797800f, -4.21802664f, 0.63918972f, -0.69969130f, + 0.47496963f, -4.30976725f, 0.16531238f, -3.59595251f, -0.76877379f, + 11.79971790f, -0.93276632f, -1.48630571f, 8.04754066f, 2.09168458f, + -3.77018499f, -4.19337654f, 0.26171905f, 1.99359691f, 8.96759701f, + 8.39609814f, 6.19231987f, -5.36037970f, 4.69818354f, -4.22453928f, + -4.61665344f, -2.52073431f, 1.34026706f, 2.80182385f, 2.56681514f, + -4.04676390f, -3.01466990f, -4.10480118f, 0.38737059f, -0.37146521f, + -2.26529670f, -1.72867084f, 0.93472683f, -2.47562981f, 0.89871657f, + -1.67618203f, -0.28950238f, 5.30124855f, -0.14731219f, -0.81319761f, + -1.11265934f, 0.11356127f, -2.52802444f, -1.93826056f, 1.06187987f, + 1.48062325f, 4.28070498f, 5.69893932f, 9.26904392f, -4.23773003f, + 5.78582096f, -6.18445301f, -2.85200453f, -5.30461454f, -4.16009140f, + -0.07239690f, 4.11531162f, -1.12266588f, -1.50265646f, 0.47661865f, + -1.90043914f, -6.48978710f, 1.71005368f, 0.18256521f, -0.88272136f, + -0.51324779f, -0.78045660f, -5.21036625f, -4.11805344f, 3.99454761f, + -1.04999924f, -6.99629354f, -5.02737141f, 0.94748145f, -2.35882139f, + 4.13982439f, -1.41835535f, 7.56763077f, 3.97024012f, -4.08156776f, + 6.90305424f, 0.53571963f, -2.22625160f, -2.09144926f, -4.98530245f, + -0.15102190f, 0.59995949f, 3.28562784f, 0.77991986f, -3.08389306f, + 3.34046674f, 0.41394949f, 5.10031366f, 2.99692893f, 0.17706826f, + 2.85998058f, -6.68330860f, -6.72653008f, -0.04071128f, 3.71085787f, + 3.17834806f, -4.88019037f, 6.74075413f, -7.41782188f, -5.22026348f, + -1.94595623f, -3.61318684f, 1.85610664f, 1.08613706f, 6.41580677f, + 1.46376514f, -4.11524010f, 9.59146214f, -2.92772651f, -1.70753336f, + -1.51594138f, -4.88185692f, 1.47331417f, -2.23893595f, 4.98459148f, + 1.29359996f, -2.29221845f, -0.99594390f, 3.05759239f, 6.86030054f, + 2.40487719f, 3.28339863f, 7.72739315f, -3.60563445f, -9.73502827f, + -1.51672328f, -0.08473521f, -2.43673515f, -3.26616001f, 3.63767886f, + -11.25394535f, -5.17597103f, -1.27523947f, -7.82669783f, 0.67929745f, + -4.50530529f, 5.49323797f, 6.78993320f, -2.28033876f, 4.61412525f, + 2.55109429f, -12.38607693f, -0.63024014f, -3.45992327f, -0.84092742f, + -0.03252453f, 4.58635283f, 5.28213978f, -1.28417206f, -1.71185923f, + -0.26850975f, 8.28257561f, 4.47432184f, 2.72818279f, 8.42217731f, + -4.22216320f, -8.95128918f, -1.57179546f, 1.34253705f, -5.47035217f, + -5.50866985f, 4.64156532f, -6.11207914f, -5.46734476f, 3.54298997f, + -2.79237103f, -0.70766860f, -3.62739944f, 3.22660995f, -2.02262759f, + 0.11224222f, 2.63832402f, -0.91955596f, -4.65958309f, -0.29729855f, + -1.78957534f, -0.40749407f, 0.51688713f, 0.83725226f, 0.30945438f, + 1.20769620f, -1.75219965f, 2.59689760f, 5.01501608f, -1.59034789f, + 0.58155286f, 3.75831509f, -5.26110506f, -8.65382767f, -6.19066620f, + -0.61932850f, -2.71863723f, -0.87443137f, 3.40582991f, -1.27868056f, + 3.51236677f, -2.07806540f, -0.85076392f, -1.14599180f, 1.16361260f, + 1.86411846f, 5.86179352f, 0.69029891f, -0.06060839f, 1.54649436f, + -0.60351688f, 1.51970077f, 0.04187265f, 1.64540339f, 2.75502157f, + 2.46308279f, 1.69071770f, -3.23827076f, 0.92096543f, -3.09458661f, + -1.23823690f, 0.24035048f, -0.74456501f, -1.85476089f, -0.32914662f, + -2.10325241f, 1.19795251f, -2.05372071f, 1.02114081f, 2.56286955f, + 0.42165697f, -1.65826249f, 4.00724554f, -2.18727994f, -1.05848944f, + -0.52338278f, -0.28714985f, 8.08780861f, 5.04444599f, 3.51866961f, + 3.37445784f, -1.96067202f, -1.21509445f, -3.96595931f, -0.80801201f, + 0.76944816f, 1.80147493f, 4.14419460f, -0.12201095f, -2.77788162f, + 1.13284469f, -2.05441403f, -0.61129224f, -2.69690657f, 1.91634214f, + -2.17146754f, -0.22308528f, -6.02561045f, 0.49161875f, -6.74280357f, + -4.62689781f, 2.47910833f, 1.86534905f, -3.24152899f, -1.39898300f, + 0.29427958f, -2.16338181f, 0.90073711f, 1.75551236f, 4.42651892f, + 8.34437466f, 5.50070190f, 5.68162251f, 1.65345454f, -2.72315669f, + -5.43411493f, -0.29380533f, 1.07508349f, -1.73533511f, 2.56912184f, + 3.62010550f, -6.30422783f, 1.74158525f, -1.22070909f, -0.80982518f, + -4.14757967f, 4.29217434f, 0.70600843f, -2.09282112f, -5.09018898f, + -0.11623126f, -5.99775553f, -4.66743088f, 1.61512172f, -1.30276895f, + -3.17103505f, -0.26310229f, -1.00843918f, -0.77664804f, -2.05240250f, + 0.04728425f, 1.15720487f, 4.01001406f, 7.24615860f, 2.55452180f, + -5.76347876f, 0.34683830f, -6.05540276f, -4.70677900f, -0.93182588f, + -4.37759733f, 2.93209839f, 1.63947964f, -2.43563962f, 1.35213876f, + 0.00670356f, -0.02742785f, -2.16460943f, 1.39449501f, 0.23929763f, + 2.37476778f, -4.17733765f, -0.81475425f, -6.15027046f, -5.74441719f, + 3.53978682f, 0.66798484f}); + + sd::ops::deconv2d_tf op; + auto result = op.evaluate({&input0, &input1, &input2}, {}, + {7, 7, 2, 2, 0, 0, 1, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, Test_Dilation2D_Again_1) { - auto x = NDArrayFactory::create('c', {4, 128, 128, 4}); - auto w = NDArrayFactory::create('c', {4, 5, 4}); - auto exp = NDArrayFactory::create('c', {4, 64, 43, 4}); - + auto x = NDArrayFactory::create('c', {4, 128, 128, 4}); + auto w = NDArrayFactory::create('c', {4, 5, 4}); + auto exp = NDArrayFactory::create('c', {4, 64, 43, 4}); - sd::ops::dilation2d op; - auto result = op.evaluate({&x, &w}, {}, {1, 1,5,7,1, 1,2,3,1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::dilation2d op; + auto result = op.evaluate({&x, &w}, {}, {1, 1, 5, 7, 1, 1, 2, 3, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, Test_Dilation2D_Again_2) { - auto x = NDArrayFactory::create('c', {4, 26, 19, 4}); - auto w = NDArrayFactory::create('c', {11, 7, 4}); - - sd::ops::dilation2d op; - auto result = op.evaluate({&x, &w}, {}, {0, 1,2,3,1, 1,3,2,1}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {4, 26, 19, 4}); + auto w = NDArrayFactory::create('c', {11, 7, 4}); + sd::ops::dilation2d op; + auto result = op.evaluate({&x, &w}, {}, {0, 1, 2, 3, 1, 1, 3, 2, 1}); + ASSERT_EQ(Status::OK(), result.status()); } TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_1) { - TypeParam _expGradWpB[] = {1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139f}; - Nd4jLong _expGradWpS[] {4, 10, 6, 1, 1, 6, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - NDArray expGWP(_expGradWpB, _expGradWpS); - expGWP.permutei({2,3,1,0}); - - TypeParam _expGradWdB[] = {2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, 4767.35292f, 4777.20144f, 4787.04996f, 4796.89848f, 4806.747f}; - Nd4jLong _expGradWdS[] = {4, 2, 3, 5, 5, 75, 25, 5, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - NDArray expGWD(_expGradWdB, _expGradWdS); - expGWD.permutei({2,3,1,0}); - - TypeParam _expEB[] = {5.0103f, 10.17147f, 15.48408f, 20.9487f, 26.5659f, 26.6832f, 21.65628f, 16.47507f, 11.139f, 5.6475f, 10.79727f, 21.90255f, 33.31698f, 45.0417f, 57.07785f, 57.3267f, 46.49334f, 35.34513f, 23.88093f, 12.0996f, 17.37801f, 35.22744f, 53.55f, 72.3474f, 91.62135f, 92.016f, 74.57958f, 56.66148f, 38.25999f, 19.3734f, 24.76962f, 50.18034f, 76.23444f, 102.9342f, 130.2819f, 130.8366f, 105.9834f, 80.47542f, 54.31038f, 27.486f, 32.9892f, 66.79545f, 101.4216f, 136.8705f, 173.145f, 173.874f, 140.7732f, 106.83825f, 72.0663f, 36.4545f, 33.8298f, 68.49375f, 103.9947f, 140.3355f, 177.519f, 178.248f, 144.3066f, 109.51395f, 73.8672f, 37.3635f, 28.85658f, 58.39302f, 88.6116f, 119.5146f, 151.1043f, 151.716f, 122.76444f, 93.11934f, 62.77842f, 31.7394f, 23.00409f, 46.52748f, 70.57188f, 95.139f, 120.23055f, 120.7107f, 97.6311f, 74.02194f, 49.88151f, 25.2081f, 16.25523f, 32.86293f, 49.82424f, 67.1403f, 84.81225f, 85.1466f, 68.83818f, 52.17045f, 35.14227f, 17.7525f, 8.5929f, 17.36517f, 26.31738f, 35.4501f, 44.7639f, 44.9382f, 36.31728f, 27.51357f, 18.5265f, 9.3555f, 8.63807f, 17.45032f, 26.43736f, 35.5998f, 44.93825f, 45.1399f, 36.46882f, 27.6199f, 18.59253f, 9.3861f, 18.18615f, 36.72737f, 55.62488f, 74.8799f, 94.49365f, 94.9122f, 76.65698f, 58.03937f, 39.05815f, 19.7121f, 28.66254f, 57.86775f, 87.61746f, 117.9135f, 148.7577f, 149.4084f, 120.63768f, 91.31331f, 61.43346f, 30.9963f, 40.08554f, 80.90806f, 122.47f, 164.7738f, 207.8219f, 208.72f, 168.48412f, 127.49662f, 85.75506f, 43.257f, 52.47345f, 105.8849f, 160.2374f, 215.534f, 271.77775f, 272.9385f, 220.2695f, 166.6442f, 112.05955f, 56.5125f, 53.82975f, 108.6158f, 164.3612f, 221.069f, 278.74225f, 279.903f, 225.8777f, 170.8778f, 114.90025f, 57.942f, 45.14002f, 91.0585f, 137.75788f, 185.2406f, 233.5091f, 234.4682f, 189.16564f, 143.06998f, 96.17878f, 48.4896f, 35.43048f, 71.45487f, 108.075f, 145.2927f, 183.1098f, 183.852f, 148.29504f, 112.13319f, 75.36462f, 37.9875f, 24.68283f, 49.76831f, 75.25766f, 101.1521f, 127.45285f, 127.9629f, 103.1927f, 78.01253f, 52.42117f, 26.4174f, 12.87877f, 25.96222f, 39.25096f, 52.7456f, 66.44675f, 66.7094f, 53.78542f, 40.6531f, 27.31183f, 13.761f, 12.59184f, 25.38317f, 38.37464f, 51.5669f, 64.9606f, 65.2566f, 52.61336f, 39.76673f, 26.71606f, 13.4607f, 26.23903f, 52.88419f, 79.93678f, 107.3981f, 135.26945f, 135.8777f, 109.53262f, 82.77361f, 55.59937f, 28.0086f, 40.96107f, 82.54206f, 124.74492f, 167.5716f, 211.02405f, 211.9608f, 170.83578f, 129.07914f, 86.68893f, 43.6632f, 56.77746f, 114.39578f, 172.85756f, 232.1654f, 292.3219f, 293.6034f, 236.60084f, 178.74182f, 120.02374f, 60.444f, 73.7077f, 148.48435f, 224.3332f, 301.2575f, 379.2605f, 380.903f, 306.9058f, 231.82015f, 155.6428f, 78.3705f, 75.6397f, 152.36785f, 230.1877f, 309.1025f, 389.1155f, 390.758f, 314.8288f, 237.79165f, 159.6433f, 80.3805f, 62.89546f, 126.67598f, 191.34416f, 256.9026f, 323.3539f, 324.7004f, 261.56684f, 197.53262f, 132.59514f, 66.7518f, 48.97887f, 98.63226f, 148.96212f, 199.9704f, 251.65905f, 252.6933f, 203.53098f, 153.68244f, 103.14573f, 51.9189f, 33.87043f, 68.19769f, 102.98308f, 138.2279f, 173.93345f, 174.6392f, 140.64322f, 106.18261f, 71.25607f, 35.8623f, 17.55064f, 35.33327f, 53.34854f, 71.5971f, 90.0796f, 90.4406f, 72.82556f, 54.97463f, 36.88716f, 18.5625f, 13.0455f, 26.44707f, 40.20528f, 54.3207f, 68.7939f, 68.9112f, 55.84908f, 42.42747f, 28.6458f, 14.5035f, 27.89367f, 56.50575f, 85.83738f, 115.8897f, 146.66385f, 146.9127f, 118.98294f, 90.32793f, 60.94653f, 30.8376f, 44.56161f, 90.21024f, 136.9476f, 184.7754f, 233.69535f, 234.09f, 189.46998f, 143.75268f, 96.93639f, 49.0194f, 63.06642f, 127.59474f, 193.58724f, 261.0462f, 329.9739f, 330.5286f, 267.3786f, 202.75302f, 136.64958f, 69.066f, 83.4252f, 168.69345f, 255.8076f, 344.7705f, 435.585f, 436.314f, 352.7772f, 267.38025f, 180.1203f, 90.9945f, 84.2658f, 170.39175f, 258.3807f, 348.2355f, 439.959f, 440.688f, 356.3106f, 270.05595f, 181.9212f, 91.9035f, 71.25738f, 144.01542f, 218.2764f, 294.0426f, 371.3163f, 371.928f, 300.57564f, 227.70894f, 153.32562f, 77.4234f, 56.34369f, 113.82228f, 172.43748f, 232.191f, 293.08455f, 293.5647f, 237.1455f, 179.58114f, 120.86991f, 61.0101f, 39.50763f, 79.77813f, 120.81264f, 162.6123f, 205.17825f, 205.5126f, 165.95178f, 125.62125f, 84.51987f, 42.6465f, 20.7321f, 41.84877f, 63.35058f, 85.2381f, 107.5119f, 107.6862f, 86.92608f, 65.77797f, 44.2413f, 22.3155f, 22.71767f, 45.82912f, 69.33496f, 93.2358f, 117.53225f, 117.7339f, 94.98322f, 71.8351f, 48.28893f, 24.3441f, 47.44335f, 95.68097f, 144.71408f, 194.5439f, 245.17165f, 245.5902f, 198.07778f, 149.76377f, 100.64695f, 50.7261f, 74.19534f, 149.59215f, 226.19226f, 303.9975f, 383.0097f, 383.6604f, 309.35688f, 233.84091f, 157.11066f, 79.1643f, 102.99194f, 207.59926f, 313.8244f, 421.6698f, 531.1379f, 532.036f, 428.89372f, 324.12142f, 217.71666f, 109.677f, 133.85145f, 269.7389f, 407.6654f, 547.634f, 689.64775f, 690.8085f, 556.7615f, 420.6602f, 282.50155f, 142.2825f, 135.20775f, 272.4698f, 411.7892f, 553.169f, 696.61225f, 697.773f, 562.3697f, 424.8938f, 285.34225f, 143.712f, 112.43842f, 226.5337f, 342.28828f, 459.7046f, 578.7851f, 579.7442f, 467.14324f, 352.87078f, 236.92438f, 119.3016f, 87.55128f, 176.35527f, 266.4138f, 357.7287f, 450.3018f, 451.044f, 363.36624f, 274.42479f, 184.21782f, 92.7435f, 60.52803f, 121.89791f, 184.11086f, 247.1681f, 311.07085f, 311.5809f, 250.9655f, 189.50093f, 127.18597f, 64.0194f, 31.35037f, 63.12502f, 95.32456f, 127.9496f, 161.00075f, 161.2634f, 129.86782f, 98.0443f, 65.79223f, 33.111f, 33.43584f, 67.30517f, 101.60864f, 136.3469f, 171.5206f, 171.8166f, 138.32936f, 104.40473f, 70.04206f, 35.2407f, 69.09703f, 139.06819f, 209.91478f, 281.6381f, 354.23945f, 354.8477f, 285.64462f, 215.55961f, 144.59137f, 72.7386f, 107.00307f, 215.32806f, 324.97692f, 435.9516f, 548.25405f, 549.1908f, 442.02378f, 333.52314f, 223.68693f, 112.5132f, 147.17346f, 296.12378f, 446.85356f, 599.3654f, 753.6619f, 754.9434f, 607.54484f, 458.35382f, 307.36774f, 154.584f, 189.6277f, 381.49435f, 575.6032f, 771.9575f, 970.5605f, 972.203f, 782.2858f, 590.11015f, 395.6728f, 198.9705f, 191.5597f, 385.37785f, 581.4577f, 779.8025f, 980.4155f, 982.058f, 790.2088f, 596.08165f, 399.6733f, 200.9805f, 157.97146f, 317.76398f, 479.38016f, 642.8226f, 808.0939f, 809.4404f, 651.23084f, 491.18462f, 329.29914f, 165.5718f, 122.04087f, 245.45826f, 370.25412f, 496.4304f, 623.98905f, 625.0233f, 502.79898f, 379.18644f, 254.18373f, 127.7889f, 83.74843f, 168.42169f, 254.02108f, 340.5479f, 428.00345f, 428.7092f, 344.83522f, 260.02861f, 174.28807f, 87.6123f, 43.07464f, 86.61527f, 130.62254f, 175.0971f, 220.0396f, 220.4006f, 177.26156f, 133.65263f, 89.57316f, 45.0225f }; - Nd4jLong _expES[] = {4, 2, 3, 10, 10, 300, 100, 10, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; - NDArray expE(_expEB, _expES); - - auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); - auto weightsD = NDArrayFactory::create('c', {2, 3, 5, 5}); - auto weightsP = NDArrayFactory::create('c', {10, 6, 1, 1}); - - auto epsilon = NDArrayFactory::create('c', {2, 3, 10, 10}); - auto epsilonNext = NDArrayFactory::create('c', {2, 10, 6, 6}); - - input.linspace(1); - weightsD.linspace(1); - weightsP.linspace(1); - epsilonNext.linspace(1); - weightsD.permutei({2,3,1,0}); - weightsP.permutei({2,3,1,0}); - - input.applyScalar(scalar::Divide, 100.0, input); - weightsD.applyScalar(scalar::Divide, 100.0, weightsD); - weightsP.applyScalar(scalar::Divide, 100.0, weightsP); - epsilonNext.applyScalar(scalar::Divide, 100.0, epsilonNext); - - sd::ops::sconv2d_bp op; - auto resultBP = op.evaluate({&input, &epsilonNext, &weightsD, &weightsP },{}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); - - ASSERT_EQ(3, resultBP.size()); - - auto _epsilon = resultBP.at(0); - auto _gradWD = resultBP.at(1); - auto _gradWP = resultBP.at(2); - - //_gradWP->printBuffer("gradWP"); - - ASSERT_TRUE(_gradWP.isSameShape(&expGWP)); - ASSERT_TRUE(_gradWP.isSameShape(&weightsP)); - - ASSERT_TRUE(_gradWP.equalsTo(&expGWP)); - - //_gradWD->printShapeInfo("gradWD shape"); - - ASSERT_TRUE(_gradWD.isSameShape(&expGWD)); - ASSERT_TRUE(_gradWD.isSameShape(&weightsD)); -// _gradWD->printIndexedBuffer(); - ASSERT_TRUE(_gradWD.equalsTo(&expGWD)); - - ASSERT_TRUE(_epsilon.isSameShape(&input)); - ASSERT_TRUE(_epsilon.isSameShape(&expE)); - - ASSERT_TRUE(_epsilon.equalsTo(&expE)); - + TypeParam _expGradWpB[] = { + 1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, + 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, + 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, + 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, + 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, + 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, + 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, + 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, + 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, + 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, + 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, + 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, + 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, + 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, + 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139f}; + Nd4jLong _expGradWpS[]{ + 4, 10, 6, 1, 1, + 6, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, + 1, 99}; + NDArray expGWP(_expGradWpB, _expGradWpS); + expGWP.permutei({2, 3, 1, 0}); + + TypeParam _expGradWdB[] = { + 2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, + 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, + 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, + 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, + 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, + 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, + 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, + 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, + 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, + 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, + 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, + 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, + 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, + 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, + 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, + 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, + 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, + 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, + 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, + 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, + 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, + 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, + 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, + 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, + 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, + 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, + 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, + 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, + 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, + 4767.35292f, 4777.20144f, 4787.04996f, 4796.89848f, 4806.747f}; + Nd4jLong _expGradWdS[] = { + 4, 2, 3, 5, 5, + 75, 25, 5, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, + 1, 99}; + NDArray expGWD(_expGradWdB, _expGradWdS); + expGWD.permutei({2, 3, 1, 0}); + + TypeParam _expEB[] = { + 5.0103f, 10.17147f, 15.48408f, 20.9487f, 26.5659f, 26.6832f, + 21.65628f, 16.47507f, 11.139f, 5.6475f, 10.79727f, 21.90255f, + 33.31698f, 45.0417f, 57.07785f, 57.3267f, 46.49334f, 35.34513f, + 23.88093f, 12.0996f, 17.37801f, 35.22744f, 53.55f, 72.3474f, + 91.62135f, 92.016f, 74.57958f, 56.66148f, 38.25999f, 19.3734f, + 24.76962f, 50.18034f, 76.23444f, 102.9342f, 130.2819f, 130.8366f, + 105.9834f, 80.47542f, 54.31038f, 27.486f, 32.9892f, 66.79545f, + 101.4216f, 136.8705f, 173.145f, 173.874f, 140.7732f, 106.83825f, + 72.0663f, 36.4545f, 33.8298f, 68.49375f, 103.9947f, 140.3355f, + 177.519f, 178.248f, 144.3066f, 109.51395f, 73.8672f, 37.3635f, + 28.85658f, 58.39302f, 88.6116f, 119.5146f, 151.1043f, 151.716f, + 122.76444f, 93.11934f, 62.77842f, 31.7394f, 23.00409f, 46.52748f, + 70.57188f, 95.139f, 120.23055f, 120.7107f, 97.6311f, 74.02194f, + 49.88151f, 25.2081f, 16.25523f, 32.86293f, 49.82424f, 67.1403f, + 84.81225f, 85.1466f, 68.83818f, 52.17045f, 35.14227f, 17.7525f, + 8.5929f, 17.36517f, 26.31738f, 35.4501f, 44.7639f, 44.9382f, + 36.31728f, 27.51357f, 18.5265f, 9.3555f, 8.63807f, 17.45032f, + 26.43736f, 35.5998f, 44.93825f, 45.1399f, 36.46882f, 27.6199f, + 18.59253f, 9.3861f, 18.18615f, 36.72737f, 55.62488f, 74.8799f, + 94.49365f, 94.9122f, 76.65698f, 58.03937f, 39.05815f, 19.7121f, + 28.66254f, 57.86775f, 87.61746f, 117.9135f, 148.7577f, 149.4084f, + 120.63768f, 91.31331f, 61.43346f, 30.9963f, 40.08554f, 80.90806f, + 122.47f, 164.7738f, 207.8219f, 208.72f, 168.48412f, 127.49662f, + 85.75506f, 43.257f, 52.47345f, 105.8849f, 160.2374f, 215.534f, + 271.77775f, 272.9385f, 220.2695f, 166.6442f, 112.05955f, 56.5125f, + 53.82975f, 108.6158f, 164.3612f, 221.069f, 278.74225f, 279.903f, + 225.8777f, 170.8778f, 114.90025f, 57.942f, 45.14002f, 91.0585f, + 137.75788f, 185.2406f, 233.5091f, 234.4682f, 189.16564f, 143.06998f, + 96.17878f, 48.4896f, 35.43048f, 71.45487f, 108.075f, 145.2927f, + 183.1098f, 183.852f, 148.29504f, 112.13319f, 75.36462f, 37.9875f, + 24.68283f, 49.76831f, 75.25766f, 101.1521f, 127.45285f, 127.9629f, + 103.1927f, 78.01253f, 52.42117f, 26.4174f, 12.87877f, 25.96222f, + 39.25096f, 52.7456f, 66.44675f, 66.7094f, 53.78542f, 40.6531f, + 27.31183f, 13.761f, 12.59184f, 25.38317f, 38.37464f, 51.5669f, + 64.9606f, 65.2566f, 52.61336f, 39.76673f, 26.71606f, 13.4607f, + 26.23903f, 52.88419f, 79.93678f, 107.3981f, 135.26945f, 135.8777f, + 109.53262f, 82.77361f, 55.59937f, 28.0086f, 40.96107f, 82.54206f, + 124.74492f, 167.5716f, 211.02405f, 211.9608f, 170.83578f, 129.07914f, + 86.68893f, 43.6632f, 56.77746f, 114.39578f, 172.85756f, 232.1654f, + 292.3219f, 293.6034f, 236.60084f, 178.74182f, 120.02374f, 60.444f, + 73.7077f, 148.48435f, 224.3332f, 301.2575f, 379.2605f, 380.903f, + 306.9058f, 231.82015f, 155.6428f, 78.3705f, 75.6397f, 152.36785f, + 230.1877f, 309.1025f, 389.1155f, 390.758f, 314.8288f, 237.79165f, + 159.6433f, 80.3805f, 62.89546f, 126.67598f, 191.34416f, 256.9026f, + 323.3539f, 324.7004f, 261.56684f, 197.53262f, 132.59514f, 66.7518f, + 48.97887f, 98.63226f, 148.96212f, 199.9704f, 251.65905f, 252.6933f, + 203.53098f, 153.68244f, 103.14573f, 51.9189f, 33.87043f, 68.19769f, + 102.98308f, 138.2279f, 173.93345f, 174.6392f, 140.64322f, 106.18261f, + 71.25607f, 35.8623f, 17.55064f, 35.33327f, 53.34854f, 71.5971f, + 90.0796f, 90.4406f, 72.82556f, 54.97463f, 36.88716f, 18.5625f, + 13.0455f, 26.44707f, 40.20528f, 54.3207f, 68.7939f, 68.9112f, + 55.84908f, 42.42747f, 28.6458f, 14.5035f, 27.89367f, 56.50575f, + 85.83738f, 115.8897f, 146.66385f, 146.9127f, 118.98294f, 90.32793f, + 60.94653f, 30.8376f, 44.56161f, 90.21024f, 136.9476f, 184.7754f, + 233.69535f, 234.09f, 189.46998f, 143.75268f, 96.93639f, 49.0194f, + 63.06642f, 127.59474f, 193.58724f, 261.0462f, 329.9739f, 330.5286f, + 267.3786f, 202.75302f, 136.64958f, 69.066f, 83.4252f, 168.69345f, + 255.8076f, 344.7705f, 435.585f, 436.314f, 352.7772f, 267.38025f, + 180.1203f, 90.9945f, 84.2658f, 170.39175f, 258.3807f, 348.2355f, + 439.959f, 440.688f, 356.3106f, 270.05595f, 181.9212f, 91.9035f, + 71.25738f, 144.01542f, 218.2764f, 294.0426f, 371.3163f, 371.928f, + 300.57564f, 227.70894f, 153.32562f, 77.4234f, 56.34369f, 113.82228f, + 172.43748f, 232.191f, 293.08455f, 293.5647f, 237.1455f, 179.58114f, + 120.86991f, 61.0101f, 39.50763f, 79.77813f, 120.81264f, 162.6123f, + 205.17825f, 205.5126f, 165.95178f, 125.62125f, 84.51987f, 42.6465f, + 20.7321f, 41.84877f, 63.35058f, 85.2381f, 107.5119f, 107.6862f, + 86.92608f, 65.77797f, 44.2413f, 22.3155f, 22.71767f, 45.82912f, + 69.33496f, 93.2358f, 117.53225f, 117.7339f, 94.98322f, 71.8351f, + 48.28893f, 24.3441f, 47.44335f, 95.68097f, 144.71408f, 194.5439f, + 245.17165f, 245.5902f, 198.07778f, 149.76377f, 100.64695f, 50.7261f, + 74.19534f, 149.59215f, 226.19226f, 303.9975f, 383.0097f, 383.6604f, + 309.35688f, 233.84091f, 157.11066f, 79.1643f, 102.99194f, 207.59926f, + 313.8244f, 421.6698f, 531.1379f, 532.036f, 428.89372f, 324.12142f, + 217.71666f, 109.677f, 133.85145f, 269.7389f, 407.6654f, 547.634f, + 689.64775f, 690.8085f, 556.7615f, 420.6602f, 282.50155f, 142.2825f, + 135.20775f, 272.4698f, 411.7892f, 553.169f, 696.61225f, 697.773f, + 562.3697f, 424.8938f, 285.34225f, 143.712f, 112.43842f, 226.5337f, + 342.28828f, 459.7046f, 578.7851f, 579.7442f, 467.14324f, 352.87078f, + 236.92438f, 119.3016f, 87.55128f, 176.35527f, 266.4138f, 357.7287f, + 450.3018f, 451.044f, 363.36624f, 274.42479f, 184.21782f, 92.7435f, + 60.52803f, 121.89791f, 184.11086f, 247.1681f, 311.07085f, 311.5809f, + 250.9655f, 189.50093f, 127.18597f, 64.0194f, 31.35037f, 63.12502f, + 95.32456f, 127.9496f, 161.00075f, 161.2634f, 129.86782f, 98.0443f, + 65.79223f, 33.111f, 33.43584f, 67.30517f, 101.60864f, 136.3469f, + 171.5206f, 171.8166f, 138.32936f, 104.40473f, 70.04206f, 35.2407f, + 69.09703f, 139.06819f, 209.91478f, 281.6381f, 354.23945f, 354.8477f, + 285.64462f, 215.55961f, 144.59137f, 72.7386f, 107.00307f, 215.32806f, + 324.97692f, 435.9516f, 548.25405f, 549.1908f, 442.02378f, 333.52314f, + 223.68693f, 112.5132f, 147.17346f, 296.12378f, 446.85356f, 599.3654f, + 753.6619f, 754.9434f, 607.54484f, 458.35382f, 307.36774f, 154.584f, + 189.6277f, 381.49435f, 575.6032f, 771.9575f, 970.5605f, 972.203f, + 782.2858f, 590.11015f, 395.6728f, 198.9705f, 191.5597f, 385.37785f, + 581.4577f, 779.8025f, 980.4155f, 982.058f, 790.2088f, 596.08165f, + 399.6733f, 200.9805f, 157.97146f, 317.76398f, 479.38016f, 642.8226f, + 808.0939f, 809.4404f, 651.23084f, 491.18462f, 329.29914f, 165.5718f, + 122.04087f, 245.45826f, 370.25412f, 496.4304f, 623.98905f, 625.0233f, + 502.79898f, 379.18644f, 254.18373f, 127.7889f, 83.74843f, 168.42169f, + 254.02108f, 340.5479f, 428.00345f, 428.7092f, 344.83522f, 260.02861f, + 174.28807f, 87.6123f, 43.07464f, 86.61527f, 130.62254f, 175.0971f, + 220.0396f, 220.4006f, 177.26156f, 133.65263f, 89.57316f, 45.0225f}; + Nd4jLong _expES[] = { + 4, 2, 3, 10, 10, + 300, 100, 10, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, + 1, 99}; + NDArray expE(_expEB, _expES); + + auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); + auto weightsD = NDArrayFactory::create('c', {2, 3, 5, 5}); + auto weightsP = NDArrayFactory::create('c', {10, 6, 1, 1}); + + auto epsilon = NDArrayFactory::create('c', {2, 3, 10, 10}); + auto epsilonNext = NDArrayFactory::create('c', {2, 10, 6, 6}); + + input.linspace(1); + weightsD.linspace(1); + weightsP.linspace(1); + epsilonNext.linspace(1); + weightsD.permutei({2, 3, 1, 0}); + weightsP.permutei({2, 3, 1, 0}); + + input.applyScalar(scalar::Divide, 100.0, input); + weightsD.applyScalar(scalar::Divide, 100.0, weightsD); + weightsP.applyScalar(scalar::Divide, 100.0, weightsP); + epsilonNext.applyScalar(scalar::Divide, 100.0, epsilonNext); + + sd::ops::sconv2d_bp op; + auto resultBP = op.evaluate({&input, &epsilonNext, &weightsD, &weightsP}, {}, + {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); + + ASSERT_EQ(3, resultBP.size()); + + auto _epsilon = resultBP.at(0); + auto _gradWD = resultBP.at(1); + auto _gradWP = resultBP.at(2); + + //_gradWP->printBuffer("gradWP"); + + ASSERT_TRUE(_gradWP.isSameShape(&expGWP)); + ASSERT_TRUE(_gradWP.isSameShape(&weightsP)); + + ASSERT_TRUE(_gradWP.equalsTo(&expGWP)); + + //_gradWD->printShapeInfo("gradWD shape"); + + ASSERT_TRUE(_gradWD.isSameShape(&expGWD)); + ASSERT_TRUE(_gradWD.isSameShape(&weightsD)); + // _gradWD->printIndexedBuffer(); + ASSERT_TRUE(_gradWD.equalsTo(&expGWD)); + + ASSERT_TRUE(_epsilon.isSameShape(&input)); + ASSERT_TRUE(_epsilon.isSameShape(&expE)); + + ASSERT_TRUE(_epsilon.equalsTo(&expE)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_2) { - - int bS=3, iH=16,iW=16, iC=3,mC=3, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=2,dW=2; - int oH=16,oW=16; - int oC=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 : sd::DataType::DOUBLE); - NDArray gradO('c', {bS, oC, oH, oW}, typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 : sd::DataType::DOUBLE); - NDArray weightsDepth('c', {kH, kW, iC, mC}, typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 : sd::DataType::DOUBLE); - NDArray weightsPoint('f', {1, 1, iC*mC, oC}, typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 : sd::DataType::DOUBLE); - NDArray bias('c', {1,oC}, {0.5, 0.5}, typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 : sd::DataType::DOUBLE); - - NDArray gradI(&input); - NDArray gradWD(&weightsDepth); - NDArray gradWP(&weightsPoint); - NDArray gradB(&bias); - - input = 2.; - weightsDepth.linspace(0.1, 0.1); - weightsPoint.linspace(0.15, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::sconv2d_bp op; - Nd4jStatus status = op.execute({&input, &gradO, &weightsDepth, & weightsPoint, &bias}, - {&gradI, &gradWD, &gradWP, &gradB}, - {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); - + int bS = 3, iH = 16, iW = 16, iC = 3, mC = 3, kH = 1, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 2, dW = 2; + int oH = 16, oW = 16; + int oC = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, + typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 + : sd::DataType::DOUBLE); + NDArray gradO('c', {bS, oC, oH, oW}, + typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 + : sd::DataType::DOUBLE); + NDArray weightsDepth('c', {kH, kW, iC, mC}, + typeid(TypeParam) == typeid(float) + ? sd::DataType::FLOAT32 + : sd::DataType::DOUBLE); + NDArray weightsPoint('f', {1, 1, iC * mC, oC}, + typeid(TypeParam) == typeid(float) + ? sd::DataType::FLOAT32 + : sd::DataType::DOUBLE); + NDArray bias('c', {1, oC}, {0.5, 0.5}, + typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 + : sd::DataType::DOUBLE); + + NDArray gradI(&input); + NDArray gradWD(&weightsDepth); + NDArray gradWP(&weightsPoint); + NDArray gradB(&bias); + + input = 2.; + weightsDepth.linspace(0.1, 0.1); + weightsPoint.linspace(0.15, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::sconv2d_bp op; + Nd4jStatus status = + op.execute({&input, &gradO, &weightsDepth, &weightsPoint, &bias}, + {&gradI, &gradWD, &gradWP, &gradB}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}, {}); + + ASSERT_EQ(Status::OK(), status); + + NDArray expGradI = gradI; + NDArray expGradWD = gradWD; + NDArray expGradWP = gradWP; + NDArray expGradB = gradB; + + for (int i = 0; i < 10; i++) { + Nd4jStatus status = op.execute( + {&input, &gradO, &weightsDepth, &weightsPoint, &bias}, + {&gradI, &gradWD, &gradWP, &gradB}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}, {}); ASSERT_EQ(Status::OK(), status); - NDArray expGradI = gradI; - NDArray expGradWD = gradWD; - NDArray expGradWP = gradWP; - NDArray expGradB = gradB; - - for( int i=0; i<10; i++ ) { - Nd4jStatus status = op.execute({&input, &gradO, &weightsDepth, & weightsPoint, &bias}, - {&gradI, &gradWD, &gradWP, &gradB}, - {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(expGradI.equalsTo(gradI)); - ASSERT_TRUE(expGradWD.equalsTo(gradWD)); - ASSERT_TRUE(expGradWP.equalsTo(gradWP)); - ASSERT_TRUE(expGradB.equalsTo(expGradB)); - } + ASSERT_TRUE(expGradI.equalsTo(gradI)); + ASSERT_TRUE(expGradWD.equalsTo(gradWD)); + ASSERT_TRUE(expGradWP.equalsTo(gradWP)); + ASSERT_TRUE(expGradB.equalsTo(expGradB)); + } } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_3) { + auto input = NDArrayFactory::create('c', {3, 3, 16, 16}); + auto weightsD = NDArrayFactory::create('c', {1, 3, 2, 2}); + auto weightsP = NDArrayFactory::create('c', {2, 3, 1, 1}); + auto bias = NDArrayFactory::create('c', {1, 2}); - auto input = NDArrayFactory::create('c', {3, 3, 16, 16}); - auto weightsD = NDArrayFactory::create('c', {1, 3, 2, 2}); - auto weightsP = NDArrayFactory::create('c', {2, 3, 1, 1}); - auto bias = NDArrayFactory::create('c', {1, 2}); - - weightsD.permutei({2,3,1,0}); - weightsP.permutei({2,3,1,0}); + weightsD.permutei({2, 3, 1, 0}); + weightsP.permutei({2, 3, 1, 0}); - auto epsilonNext = NDArrayFactory::create('c', {3, 2, 14, 14}); + auto epsilonNext = NDArrayFactory::create('c', {3, 2, 14, 14}); - auto epsilon = NDArrayFactory::create('c', {3, 3, 16, 16}); + auto epsilon = NDArrayFactory::create('c', {3, 3, 16, 16}); - sd::ops::sconv2d_bp op; - auto result = op.evaluate({&input, &epsilonNext, &weightsD, &weightsP}, {}, {2, 2, 1, 1, 0, 0, 2, 2, 0}); + sd::ops::sconv2d_bp op; + auto result = op.evaluate({&input, &epsilonNext, &weightsD, &weightsP}, {}, + {2, 2, 1, 1, 0, 0, 2, 2, 0}); - auto eps = result.at(0); - auto gWD = result.at(1); - auto gWP = result.at(2); + auto eps = result.at(0); + auto gWD = result.at(1); + auto gWP = result.at(2); - - ASSERT_TRUE(epsilon.isSameShape(eps)); + ASSERT_TRUE(epsilon.isSameShape(eps)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_4) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int oC=iC*mC; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weightsDepth = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); - - auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC},{0.07f, 0.19f, 0.348f, 0.652f, 0.588f, 0.956f, 0.387f, 0.687f, 1.326f, 2.022f, 1.878f, 2.67f, 1.071f, 1.515f, 2.982f, 3.966f, 3.534f, 4.614f, 1.606f, 1.982f, 3.932f, 4.748f, 4.428f, 5.308f, - 1.126f, 1.63f, 3.228f, 4.3f, 3.468f, 4.604f, 3.123f, 3.999f, 7.95f, 9.798f, 8.502f, 10.446f, 3.807f, 4.827f, 9.606f, 11.742f,10.158f, 12.39f, 4.198f, 4.958f, 9.884f, 11.468f,10.38f, 12.028f}); - - auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, mC},{19.08f, 19.44f, 19.8f, 20.16f, 12.24f, 12.48f, 12.72f, 12.96f, 22.56f, 23.04f, 23.52f, 24.f, 14.4f, 14.72f, 15.04f, 15.36f, 14.76f, 15.12f, 15.48f, 15.84f, 9.36f, 9.6f, 9.84f, 10.08f}); - - input = 2.; - weightsDepth.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::sconv2d_bp op; - auto results = op.evaluate({&input, &gradO, &weightsDepth, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradWD = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradWD)); - ASSERT_TRUE(expGradW.equalsTo(gradWD)); - + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 2, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int oC = iC * mC; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weightsDepth = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3, 4}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {0.07f, 0.19f, 0.348f, 0.652f, 0.588f, 0.956f, 0.387f, 0.687f, + 1.326f, 2.022f, 1.878f, 2.67f, 1.071f, 1.515f, 2.982f, 3.966f, + 3.534f, 4.614f, 1.606f, 1.982f, 3.932f, 4.748f, 4.428f, 5.308f, + 1.126f, 1.63f, 3.228f, 4.3f, 3.468f, 4.604f, 3.123f, 3.999f, + 7.95f, 9.798f, 8.502f, 10.446f, 3.807f, 4.827f, 9.606f, 11.742f, + 10.158f, 12.39f, 4.198f, 4.958f, 9.884f, 11.468f, 10.38f, 12.028f}); + + auto expGradW = NDArrayFactory::create( + 'c', {kH, kW, iC, mC}, + {19.08f, 19.44f, 19.8f, 20.16f, 12.24f, 12.48f, 12.72f, 12.96f, + 22.56f, 23.04f, 23.52f, 24.f, 14.4f, 14.72f, 15.04f, 15.36f, + 14.76f, 15.12f, 15.48f, 15.84f, 9.36f, 9.6f, 9.84f, 10.08f}); + + input = 2.; + weightsDepth.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::sconv2d_bp op; + auto results = + op.evaluate({&input, &gradO, &weightsDepth, &bias}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradWD = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradWD)); + ASSERT_TRUE(expGradW.equalsTo(gradWD)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, sconv2d_bp_5) { - - int bS=1, iH=8,iW=8, iC=3,mC=3, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=8,oW=8; - int oC=2; // iC*mC if weightsPoint = nullptr - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); - auto weightsDepth = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto weightsPoint = NDArrayFactory::create('c', {1, 1, iC*mC, oC}); - auto bias = NDArrayFactory::create('c', {1,oC}, {1,2}); - - auto gradI = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradWD = NDArrayFactory::create('f', {kH, kW, iC, mC}); - auto gradWP = NDArrayFactory::create('c', {1, 1, iC*mC, oC}); - auto gradB = NDArrayFactory::create('c', {1,oC}, {1,2}); - - input = 2.; - weightsDepth.linspace(0.1, 0.1); - weightsDepth.linspace(-0.5, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::sconv2d_bp op; - auto status = op.execute({&input, &gradO, &weightsDepth, &weightsPoint, &bias}, {&gradI, &gradWD, &gradWP, &gradB}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); - ASSERT_EQ(Status::OK(), status); + int bS = 1, iH = 8, iW = 8, iC = 3, mC = 3, kH = 1, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 8, oW = 8; + int oC = 2; // iC*mC if weightsPoint = nullptr + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); + auto weightsDepth = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto weightsPoint = NDArrayFactory::create('c', {1, 1, iC * mC, oC}); + auto bias = NDArrayFactory::create('c', {1, oC}, {1, 2}); + + auto gradI = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradWD = NDArrayFactory::create('f', {kH, kW, iC, mC}); + auto gradWP = NDArrayFactory::create('c', {1, 1, iC * mC, oC}); + auto gradB = NDArrayFactory::create('c', {1, oC}, {1, 2}); + + input = 2.; + weightsDepth.linspace(0.1, 0.1); + weightsDepth.linspace(-0.5, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::sconv2d_bp op; + auto status = + op.execute({&input, &gradO, &weightsDepth, &weightsPoint, &bias}, + {&gradI, &gradWD, &gradWP, &gradB}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}, {}); + ASSERT_EQ(Status::OK(), status); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, im2col_bp_1) { + int bS = 3, iH = 12, iW = 12, iC = 6, oC = 3, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 12, oW = 12; - int bS=3, iH=12,iW=12, iC=6,oC=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=12,oW=12; + // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE); + NDArray gradO('c', {bS, iC, kH, kW, oH, oW}, sd::DataType::DOUBLE); + NDArray gradI('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE); // output - // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE); - NDArray gradO('c', {bS, iC, kH, kW, oH, oW}, sd::DataType::DOUBLE); - NDArray gradI('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE); // output - - sd::ops::im2col_bp op; - Nd4jStatus status = op.execute({&input, &gradO}, {&gradI}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, 1}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); + sd::ops::im2col_bp op; + Nd4jStatus status = op.execute({&input, &gradO}, {&gradI}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, 1}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_test1) { - - int bS=2, iD=4,iH=4,iW=4, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto exp = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.2 , 1.65, 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.1 , 2.55, 5.1 , 6. , 5.1 , 6. , 3. , 3.45, - 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 , - 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 , - 3.9 , 4.35, 8.7 , 9.6 , 8.7 , 9.6 , 4.8 , 5.25, 9.6 ,10.5 ,21. ,22.8 ,21. ,22.8 ,11.4 ,12.3 , 9.6 ,10.5 ,21. ,22.8 ,21. ,22.8 ,11.4 ,12.3 , 5.7 , 6.15,12.3 ,13.2 ,12.3 ,13.2 , 6.6 , 7.05, - 0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.2 , 1.65, 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.1 , 2.55, 5.1 , 6. , 5.1 , 6. , 3. , 3.45, - 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 , - 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 , - 3.9 , 4.35, 8.7 , 9.6 , 8.7 , 9.6 , 4.8 , 5.25, 9.6 ,10.5 ,21. ,22.8 ,21. ,22.8 ,11.4 ,12.3 , 9.6 ,10.5 ,21. ,22.8 ,21. ,22.8 ,11.4 ,12.3 , 5.7 , 6.15,12.3 ,13.2 ,12.3 ,13.2 , 6.6 , 7.05}); - input = 0.5; - weights.linspace(0.1, 0.1); - - sd::ops::deconv3d op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - auto output = results.at(0); - - // output->printBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 2, iD = 4, iH = 4, iW = 4, iC = 2, oC = 3, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto exp = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.3, 0.75, 1.5, 2.4, 1.5, 2.4, 1.2, 1.65, 2.4, 3.3, 6.6, 8.4, + 6.6, 8.4, 4.2, 5.1, 2.4, 3.3, 6.6, 8.4, 6.6, 8.4, 4.2, 5.1, + 2.1, 2.55, 5.1, 6., 5.1, 6., 3., 3.45, 4.2, 5.1, 10.2, 12., + 10.2, 12., 6., 6.9, 12., 13.8, 27.6, 31.2, 27.6, 31.2, 15.6, 17.4, + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 15.6, 17.4, 7.8, 8.7, 17.4, 19.2, + 17.4, 19.2, 9.6, 10.5, 4.2, 5.1, 10.2, 12., 10.2, 12., 6., 6.9, + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 15.6, 17.4, 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 15.6, 17.4, 7.8, 8.7, 17.4, 19.2, 17.4, 19.2, 9.6, 10.5, + 3.9, 4.35, 8.7, 9.6, 8.7, 9.6, 4.8, 5.25, 9.6, 10.5, 21., 22.8, + 21., 22.8, 11.4, 12.3, 9.6, 10.5, 21., 22.8, 21., 22.8, 11.4, 12.3, + 5.7, 6.15, 12.3, 13.2, 12.3, 13.2, 6.6, 7.05, 0.3, 0.75, 1.5, 2.4, + 1.5, 2.4, 1.2, 1.65, 2.4, 3.3, 6.6, 8.4, 6.6, 8.4, 4.2, 5.1, + 2.4, 3.3, 6.6, 8.4, 6.6, 8.4, 4.2, 5.1, 2.1, 2.55, 5.1, 6., + 5.1, 6., 3., 3.45, 4.2, 5.1, 10.2, 12., 10.2, 12., 6., 6.9, + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 15.6, 17.4, 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 15.6, 17.4, 7.8, 8.7, 17.4, 19.2, 17.4, 19.2, 9.6, 10.5, + 4.2, 5.1, 10.2, 12., 10.2, 12., 6., 6.9, 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 15.6, 17.4, 12., 13.8, 27.6, 31.2, 27.6, 31.2, 15.6, 17.4, + 7.8, 8.7, 17.4, 19.2, 17.4, 19.2, 9.6, 10.5, 3.9, 4.35, 8.7, 9.6, + 8.7, 9.6, 4.8, 5.25, 9.6, 10.5, 21., 22.8, 21., 22.8, 11.4, 12.3, + 9.6, 10.5, 21., 22.8, 21., 22.8, 11.4, 12.3, 5.7, 6.15, 12.3, 13.2, + 12.3, 13.2, 6.6, 7.05}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv3d op; + auto results = op.evaluate( + {&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + auto output = results.at(0); + + // output->printBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_test2) { - - int bS=2, iD=4,iH=4,iW=4, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=4,oH=4,oW=4; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto exp = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.5 , 2.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , - 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 , - 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 , - 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 , - 0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.5 , 2.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , - 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 , - 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 , - 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 }); - input = 0.5; - weights.linspace(0.1, 0.1); - - sd::ops::deconv3d op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 2, iD = 4, iH = 4, iW = 4, iC = 2, oC = 3, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 4, oW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto exp = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.3, 0.75, 1.5, 2.4, 1.5, 2.4, 1.5, 2.4, 2.4, 3.3, 6.6, 8.4, + 6.6, 8.4, 6.6, 8.4, 2.4, 3.3, 6.6, 8.4, 6.6, 8.4, 6.6, 8.4, + 2.4, 3.3, 6.6, 8.4, 6.6, 8.4, 6.6, 8.4, 4.2, 5.1, 10.2, 12., + 10.2, 12., 10.2, 12., 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 27.6, 31.2, 4.2, 5.1, 10.2, 12., 10.2, 12., 10.2, 12., + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, + 4.2, 5.1, 10.2, 12., 10.2, 12., 10.2, 12., 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, 0.3, 0.75, 1.5, 2.4, + 1.5, 2.4, 1.5, 2.4, 2.4, 3.3, 6.6, 8.4, 6.6, 8.4, 6.6, 8.4, + 2.4, 3.3, 6.6, 8.4, 6.6, 8.4, 6.6, 8.4, 2.4, 3.3, 6.6, 8.4, + 6.6, 8.4, 6.6, 8.4, 4.2, 5.1, 10.2, 12., 10.2, 12., 10.2, 12., + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, + 4.2, 5.1, 10.2, 12., 10.2, 12., 10.2, 12., 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, 4.2, 5.1, 10.2, 12., + 10.2, 12., 10.2, 12., 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, + 12., 13.8, 27.6, 31.2, 27.6, 31.2, 27.6, 31.2, 12., 13.8, 27.6, 31.2, + 27.6, 31.2, 27.6, 31.2}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv3d op; + auto results = op.evaluate( + {&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_test3) { - - int bS=2, iD=4,iH=4,iW=4, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); - auto exp = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {2.55, 5.25, 5.25, 2.7, 5.4 , 11.1 , 11.1 , 5.7, 5.4 , 11.1 , 11.1 , 5.7, 2.85, 5.85, 5.85, 3. , 5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. , 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6, - 5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. , 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6, 3.15, 6.45, 6.45, 3.3, 6.6 , 13.5 , 13.5 , 6.9, 6.6 , 13.5 , 13.5 , 6.9, 3.45, 7.05, 7.05, 3.6, - 3.75, 7.65, 7.65, 3.9, 7.8 , 15.9 , 15.9 , 8.1, 7.8 , 15.9 , 15.9 , 8.1, 4.05, 8.25, 8.25, 4.2, 8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. , - 8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. , 4.35, 8.85, 8.85, 4.5, 9. , 18.3 , 18.3 , 9.3, 9. , 18.3 , 18.3 , 9.3, 4.65, 9.45, 9.45, 4.8, - 2.55, 5.25, 5.25, 2.7, 5.4 , 11.1 , 11.1 , 5.7, 5.4 , 11.1 , 11.1 , 5.7, 2.85, 5.85, 5.85, 3. , 5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. , 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6, - 5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. , 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6, 3.15, 6.45, 6.45, 3.3, 6.6 , 13.5 , 13.5 , 6.9, 6.6 , 13.5 , 13.5 , 6.9, 3.45, 7.05, 7.05, 3.6, - 3.75, 7.65, 7.65, 3.9, 7.8 , 15.9 , 15.9 , 8.1, 7.8 , 15.9 , 15.9 , 8.1, 4.05, 8.25, 8.25, 4.2, 8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. , - 8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. , 4.35, 8.85, 8.85, 4.5, 9. , 18.3 , 18.3 , 9.3, 9. , 18.3 , 18.3 , 9.3, 4.65, 9.45, 9.45, 4.8}); - input = 0.5; - weights.linspace(0.1, 0.1); - weights.permutei({2, 3, 4, 1, 0}); - - sd::ops::deconv3d op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 2, iD = 4, iH = 4, iW = 4, iC = 2, oC = 3, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); + auto exp = NDArrayFactory::create( + 'c', {bS, iC, iD, iH, iW}, + {2.55, 5.25, 5.25, 2.7, 5.4, 11.1, 11.1, 5.7, 5.4, 11.1, 11.1, 5.7, + 2.85, 5.85, 5.85, 3., 5.7, 11.7, 11.7, 6., 12., 24.6, 24.6, 12.6, + 12., 24.6, 24.6, 12.6, 6.3, 12.9, 12.9, 6.6, 5.7, 11.7, 11.7, 6., + 12., 24.6, 24.6, 12.6, 12., 24.6, 24.6, 12.6, 6.3, 12.9, 12.9, 6.6, + 3.15, 6.45, 6.45, 3.3, 6.6, 13.5, 13.5, 6.9, 6.6, 13.5, 13.5, 6.9, + 3.45, 7.05, 7.05, 3.6, 3.75, 7.65, 7.65, 3.9, 7.8, 15.9, 15.9, 8.1, + 7.8, 15.9, 15.9, 8.1, 4.05, 8.25, 8.25, 4.2, 8.1, 16.5, 16.5, 8.4, + 16.8, 34.2, 34.2, 17.4, 16.8, 34.2, 34.2, 17.4, 8.7, 17.7, 17.7, 9., + 8.1, 16.5, 16.5, 8.4, 16.8, 34.2, 34.2, 17.4, 16.8, 34.2, 34.2, 17.4, + 8.7, 17.7, 17.7, 9., 4.35, 8.85, 8.85, 4.5, 9., 18.3, 18.3, 9.3, + 9., 18.3, 18.3, 9.3, 4.65, 9.45, 9.45, 4.8, 2.55, 5.25, 5.25, 2.7, + 5.4, 11.1, 11.1, 5.7, 5.4, 11.1, 11.1, 5.7, 2.85, 5.85, 5.85, 3., + 5.7, 11.7, 11.7, 6., 12., 24.6, 24.6, 12.6, 12., 24.6, 24.6, 12.6, + 6.3, 12.9, 12.9, 6.6, 5.7, 11.7, 11.7, 6., 12., 24.6, 24.6, 12.6, + 12., 24.6, 24.6, 12.6, 6.3, 12.9, 12.9, 6.6, 3.15, 6.45, 6.45, 3.3, + 6.6, 13.5, 13.5, 6.9, 6.6, 13.5, 13.5, 6.9, 3.45, 7.05, 7.05, 3.6, + 3.75, 7.65, 7.65, 3.9, 7.8, 15.9, 15.9, 8.1, 7.8, 15.9, 15.9, 8.1, + 4.05, 8.25, 8.25, 4.2, 8.1, 16.5, 16.5, 8.4, 16.8, 34.2, 34.2, 17.4, + 16.8, 34.2, 34.2, 17.4, 8.7, 17.7, 17.7, 9., 8.1, 16.5, 16.5, 8.4, + 16.8, 34.2, 34.2, 17.4, 16.8, 34.2, 34.2, 17.4, 8.7, 17.7, 17.7, 9., + 4.35, 8.85, 8.85, 4.5, 9., 18.3, 18.3, 9.3, 9., 18.3, 18.3, 9.3, + 4.65, 9.45, 9.45, 4.8}); + input = 0.5; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 4, 1, 0}); + + sd::ops::deconv3d op; + auto results = op.evaluate( + {&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_test4) { - - int bS=2, iD=2,iH=2,iW=2, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; - int oD=3,oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); - auto exp = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {24.6, 24.6,24.6, 24.6,24.6, 24.6,24.6, 24.6,34.2, 34.2,34.2, 34.2,34.2, 34.2,34.2, 34.2,24.6, 24.6,24.6, 24.6, - 24.6, 24.6,24.6, 24.6,34.2, 34.2,34.2, 34.2,34.2, 34.2,34.2, 34.2}); - input = 0.5; - weights.linspace(0.1, 0.1); - weights.permutei({2, 3, 4, 1, 0}); - - sd::ops::deconv3d op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 2, iD = 2, iH = 2, iW = 2, iC = 2, oC = 3, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 1, pH = 1, pW = 1, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); + auto exp = NDArrayFactory::create( + 'c', {bS, iC, iD, iH, iW}, + {24.6, 24.6, 24.6, 24.6, 24.6, 24.6, 24.6, 24.6, 34.2, 34.2, 34.2, + 34.2, 34.2, 34.2, 34.2, 34.2, 24.6, 24.6, 24.6, 24.6, 24.6, 24.6, + 24.6, 24.6, 34.2, 34.2, 34.2, 34.2, 34.2, 34.2, 34.2, 34.2}); + input = 0.5; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 4, 1, 0}); + + sd::ops::deconv3d op; + auto results = op.evaluate( + {&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_test5) { - int bS=1, oD=5,oH=5,oW=5, oC=3,iC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2; - int iD=3,iH=3,iW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, oC, iC}); - auto bias = NDArrayFactory::create('c', {oC}); - - auto exp = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}, {-2.9f, -6.8f, -10.7f, -2.6f, -6.1f, -9.6f, -16.9f, -23.9f, -30.9f, -13.1f, -16.6f, -20.1f, -11.6f, -14.7f, -17.8f, -2.0f, -4.7f, -7.4f, -1.7f, -4.0f, -6.3f, -11.5f, - -16.1f, -20.7f, -8.6f, -10.9f, -13.2f, -7.1f, -9.0f, -10.9f, -27.4f, -32.8f, -38.2f, -24.4f, -29.0f, -33.6f, -65.0f, -74.2f, -83.4f, -38.2f, -42.8f, -47.4f, -32.8f, - -36.6f, -40.4f, -18.2f, -20.9f, -23.6f, -15.5f, -17.8f, -20.1f, -39.1f, -43.7f, -48.3f, -22.4f, -24.7f, -27.0f, -18.5f, -20.4f, -22.3f, -10.1f, -11.6f, -13.1f, -7.4f, - -8.5f, -9.6f, -19.3f, -21.5f, -23.7f, -10.7f, -11.8f, -12.9f, -6.8f, -7.5f, -8.2f, -0.2f, -0.5f, -0.8f, 0.1f, 0.2f, 0.3f, -0.7f, -0.5f, -0.3f, 0.4f, 0.5f, 0.6f, 1.9f, 2.4f, - 2.9f, 0.7f, 1.6f, 2.5f, 1.0f, 2.3f, 3.6f, 4.7f, 7.3f, 9.9f, 4.9f, 6.2f, 7.5f, 6.4f, 8.1f, 9.8f, -0.4f, 1.4f, 3.2f, 2.6f, 5.2f, 7.8f, 10.6f, 15.8f, 21.0f, 10.4f, 13.0f, 15.6f, - 15.8f, 19.2f, 22.6f, 6.1f, 7.0f, 7.9f, 8.8f, 10.1f, 11.4f, 20.3f, 22.9f, 25.5f, 12.7f, 14.0f, 15.3f, 16.6f, 18.3f, 20.0f, 14.2f, 16.3f, 18.4f, 16.9f, 19.4f, 21.9f, 40.1f, - 45.1f, 50.1f, 24.4f, 26.9f, 29.4f, 28.3f, 31.2f, 34.1f, -47.2f, -47.8f, -48.4f, -41.8f, -41.6f, -41.4f, -85.4f, -85.f, -84.6f, -41.2f, -41.0f, -40.8f, -33.4f, -32.4f, -31.4f, - -31.f, -29.2f, -27.4f, -25.6f, -23.0f, -20.4f, -45.8f, -40.6f, -35.4f, -17.8f, -15.2f, -12.6f, -10.0f, -6.6f, -3.2f, -65.6f, -62.0f, -58.4f, -50.0f, -44.8f, -39.6f, -89.2f, - -78.8f, -68.4f, -34.4f, -29.2f, -24.f, -14.0f, -7.2f, -0.4f, -20.2f, -18.4f, -16.6f, -10.f, -7.4f, -4.8f, -14.6f, -9.4f, -4.2f, -2.2f, 0.4f, 3.0f, 10.4f, 13.8f, 17.2f, 10.4f, - 14.6f, 18.8f, 20.6f, 25.6f, 30.6f, 53.8f, 63.8f, 73.8f, 35.6f, 40.6f, 45.6f, 48.2f, 54.0f, 59.8f, -3.8f, -4.1f, -4.4f, 1.3f, 1.4f, 1.5f, 1.7f, 1.9f, 2.1f, 1.6f, 1.7f, 1.8f, 7.9f, - 8.4f, 8.9f, 11.5f, 12.4f, 13.3f, 16.6f, 17.9f, 19.2f, 35.9f, 38.5f, 41.1f, 20.5f, 21.8f, 23.1f, 26.8f, 28.5f, 30.2f, 21.2f, 23.0f, 24.8f, 33.8f, 36.4f, 39.0f, 73.0f, 78.2f, - 83.4f, 41.6f, 44.2f, 46.8f, 56.6f, 60.0f, 63.4f, 16.9f, 17.8f, 18.7f, 24.4f, 25.7f, 27.f, 51.5f, 54.1f, 56.7f, 28.3f, 29.6f, 30.9f, 37.0f, 38.7f, 40.4f, 39.4f, 41.5f, - 43.6f, 46.9f, 49.4f, 51.9f, 100.1f, 105.1f, 110.1f, 54.4f, 56.9f, 59.4f, 63.1f, 66.0f, 68.9f, 42.1f, 45.4f, 48.7f, 47.2f, 50.9f, 54.6f, 104.3f, 111.7f, - 119.1f, 58.3f, 62.0f, 65.7f, 64.6f, 68.7f, 72.8f, 57.4f, 61.9f, 66.4f, 62.5f, 67.4f, 72.3f, 138.5f, 148.3f, 158.1f, 77.2f, 82.1f, 87.0f, 83.5f, 88.8f, 94.1f, - 134.6f, 143.6f, 152.6f, 147.2f, 157.0f, 166.8f, 321.4f, 341.0f, 360.6f, 176.6f, 186.4f, 196.2f, 191.6f, 202.2f, 212.8f, 84.4f, 88.9f, - 93.4f, 91.9f, 96.8f, 101.7f, 197.3f, 207.1f, 216.9f, 106.6f, 111.5f, 116.4f, 115.3f, 120.6f, 125.9f, 106.9f, 112.6f, 118.3f, 114.4f, 120.5f, 126.6f, 245.9f, 258.1f, 270.3f, 132.7f, 138.8f, 144.9f, 141.4f, 147.9f, 154.4f}); - - input.linspace(-10, 0.5); - weights.linspace(0.1, 0.1); - bias = 0.2; - - sd::ops::deconv3d op; - auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int bS = 1, oD = 5, oH = 5, oW = 5, oC = 3, iC = 2, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 2, dH = 2, dW = 2; + int iD = 3, iH = 3, iW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, oC, iC}); + auto bias = NDArrayFactory::create('c', {oC}); + + auto exp = NDArrayFactory::create( + 'c', {bS, oD, oH, oW, oC}, + {-2.9f, -6.8f, -10.7f, -2.6f, -6.1f, -9.6f, -16.9f, -23.9f, -30.9f, + -13.1f, -16.6f, -20.1f, -11.6f, -14.7f, -17.8f, -2.0f, -4.7f, -7.4f, + -1.7f, -4.0f, -6.3f, -11.5f, -16.1f, -20.7f, -8.6f, -10.9f, -13.2f, + -7.1f, -9.0f, -10.9f, -27.4f, -32.8f, -38.2f, -24.4f, -29.0f, -33.6f, + -65.0f, -74.2f, -83.4f, -38.2f, -42.8f, -47.4f, -32.8f, -36.6f, -40.4f, + -18.2f, -20.9f, -23.6f, -15.5f, -17.8f, -20.1f, -39.1f, -43.7f, -48.3f, + -22.4f, -24.7f, -27.0f, -18.5f, -20.4f, -22.3f, -10.1f, -11.6f, -13.1f, + -7.4f, -8.5f, -9.6f, -19.3f, -21.5f, -23.7f, -10.7f, -11.8f, -12.9f, + -6.8f, -7.5f, -8.2f, -0.2f, -0.5f, -0.8f, 0.1f, 0.2f, 0.3f, + -0.7f, -0.5f, -0.3f, 0.4f, 0.5f, 0.6f, 1.9f, 2.4f, 2.9f, + 0.7f, 1.6f, 2.5f, 1.0f, 2.3f, 3.6f, 4.7f, 7.3f, 9.9f, + 4.9f, 6.2f, 7.5f, 6.4f, 8.1f, 9.8f, -0.4f, 1.4f, 3.2f, + 2.6f, 5.2f, 7.8f, 10.6f, 15.8f, 21.0f, 10.4f, 13.0f, 15.6f, + 15.8f, 19.2f, 22.6f, 6.1f, 7.0f, 7.9f, 8.8f, 10.1f, 11.4f, + 20.3f, 22.9f, 25.5f, 12.7f, 14.0f, 15.3f, 16.6f, 18.3f, 20.0f, + 14.2f, 16.3f, 18.4f, 16.9f, 19.4f, 21.9f, 40.1f, 45.1f, 50.1f, + 24.4f, 26.9f, 29.4f, 28.3f, 31.2f, 34.1f, -47.2f, -47.8f, -48.4f, + -41.8f, -41.6f, -41.4f, -85.4f, -85.f, -84.6f, -41.2f, -41.0f, -40.8f, + -33.4f, -32.4f, -31.4f, -31.f, -29.2f, -27.4f, -25.6f, -23.0f, -20.4f, + -45.8f, -40.6f, -35.4f, -17.8f, -15.2f, -12.6f, -10.0f, -6.6f, -3.2f, + -65.6f, -62.0f, -58.4f, -50.0f, -44.8f, -39.6f, -89.2f, -78.8f, -68.4f, + -34.4f, -29.2f, -24.f, -14.0f, -7.2f, -0.4f, -20.2f, -18.4f, -16.6f, + -10.f, -7.4f, -4.8f, -14.6f, -9.4f, -4.2f, -2.2f, 0.4f, 3.0f, + 10.4f, 13.8f, 17.2f, 10.4f, 14.6f, 18.8f, 20.6f, 25.6f, 30.6f, + 53.8f, 63.8f, 73.8f, 35.6f, 40.6f, 45.6f, 48.2f, 54.0f, 59.8f, + -3.8f, -4.1f, -4.4f, 1.3f, 1.4f, 1.5f, 1.7f, 1.9f, 2.1f, + 1.6f, 1.7f, 1.8f, 7.9f, 8.4f, 8.9f, 11.5f, 12.4f, 13.3f, + 16.6f, 17.9f, 19.2f, 35.9f, 38.5f, 41.1f, 20.5f, 21.8f, 23.1f, + 26.8f, 28.5f, 30.2f, 21.2f, 23.0f, 24.8f, 33.8f, 36.4f, 39.0f, + 73.0f, 78.2f, 83.4f, 41.6f, 44.2f, 46.8f, 56.6f, 60.0f, 63.4f, + 16.9f, 17.8f, 18.7f, 24.4f, 25.7f, 27.f, 51.5f, 54.1f, 56.7f, + 28.3f, 29.6f, 30.9f, 37.0f, 38.7f, 40.4f, 39.4f, 41.5f, 43.6f, + 46.9f, 49.4f, 51.9f, 100.1f, 105.1f, 110.1f, 54.4f, 56.9f, 59.4f, + 63.1f, 66.0f, 68.9f, 42.1f, 45.4f, 48.7f, 47.2f, 50.9f, 54.6f, + 104.3f, 111.7f, 119.1f, 58.3f, 62.0f, 65.7f, 64.6f, 68.7f, 72.8f, + 57.4f, 61.9f, 66.4f, 62.5f, 67.4f, 72.3f, 138.5f, 148.3f, 158.1f, + 77.2f, 82.1f, 87.0f, 83.5f, 88.8f, 94.1f, 134.6f, 143.6f, 152.6f, + 147.2f, 157.0f, 166.8f, 321.4f, 341.0f, 360.6f, 176.6f, 186.4f, 196.2f, + 191.6f, 202.2f, 212.8f, 84.4f, 88.9f, 93.4f, 91.9f, 96.8f, 101.7f, + 197.3f, 207.1f, 216.9f, 106.6f, 111.5f, 116.4f, 115.3f, 120.6f, 125.9f, + 106.9f, 112.6f, 118.3f, 114.4f, 120.5f, 126.6f, 245.9f, 258.1f, 270.3f, + 132.7f, 138.8f, 144.9f, 141.4f, 147.9f, 154.4f}); + + input.linspace(-10, 0.5); + weights.linspace(0.1, 0.1); + bias = 0.2; + + sd::ops::deconv3d op; + auto results = op.evaluate({&input, &weights}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_test6) { - - int bS=2, oD=4,oH=4,oW=4, oC=5,iC=10, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int iD=3,iH=3,iW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {iC, oC, kD, kH, kW}, {20., 15., 10., 5., 0., -5., -10., -15., 19., 14., 9., 4., -1., -6., -11., -16., 18., 13., 8., 3., -2., -7., -12., -17., - 17., 12., 7., 2., -3., -8., -13., -18., 16., 11., 6., 1., -4., -9., -14., -19., 19.9, 14.9, 9.9, 4.9, -0.1, -5.1, -10.1, -15.1, 18.9, 13.9, 8.9, 3.9, -1.1, -6.1, - -11.1, -16.1, 17.9, 12.9, 7.9, 2.9, -2.1, -7.1, -12.1, -17.1, 16.9, 11.9, 6.9, 1.9, -3.1, -8.1, -13.1, -18.1, 15.9, 10.9, 5.9, 0.9, -4.1, -9.1, -14.1, -19.1, - 19.799999, 14.8, 9.8, 4.8, -0.2, -5.2, -10.2, -15.2, 18.799999, 13.8, 8.8, 3.8, -1.2, -6.2, -11.2, -16.200001, 17.799999, 12.8, 7.8, 2.8, -2.2, -7.2, -12.2, - -17.200001, 16.799999, 11.8, 6.8, 1.8, -3.2, -8.2, -13.2, -18.200001, 15.8, 10.8, 5.8, 0.8, -4.2, -9.2, -14.2, -19.200001, 19.700001, 14.7, 9.7, 4.7, -0.3, -5.3, -10.3, -15.3, 18.700001, 13.7, 8.7, 3.7, -1.3, -6.3, -11.3, -16.299999, 17.700001, 12.7, 7.7, 2.7, -2.3, -7.3, -12.3, -17.299999, 16.700001, 11.7, 6.7, 1.7, -3.3, -8.3, -13.3, -18.299999, 15.7, 10.7, 5.7, 0.7, -4.3, -9.3, -14.3, -19.299999, 19.6, 14.6, 9.6, 4.6, -0.4, -5.4, -10.4, -15.4, 18.6, 13.6, 8.6, 3.6, -1.4, -6.4, -11.4, -16.4, 17.6, 12.6, 7.6, 2.6, -2.4, -7.4, -12.4, -17.4, 16.6, 11.6, 6.6, 1.6, -3.4, -8.4, -13.4, -18.4, 15.6, 10.6, 5.6, 0.6, -4.4, -9.4, -14.4, -19.4, 19.5, 14.5, 9.5, 4.5, -0.5, -5.5, -10.5, -15.5, 18.5, 13.5, 8.5, 3.5, -1.5, -6.5, -11.5, -16.5, 17.5, 12.5, 7.5, 2.5, -2.5, -7.5, -12.5, -17.5, 16.5, 11.5, 6.5, 1.5, -3.5, -8.5, -13.5, -18.5, 15.5, 10.5, 5.5, 0.5, -4.5, -9.5, -14.5, -19.5, 19.4, 14.4, 9.4, 4.4, -0.6, -5.6, -10.6, -15.6, 18.4, 13.4, 8.4, 3.4, -1.6, -6.6, -11.6, -16.6, 17.4, 12.4, 7.4, 2.4, -2.6, -7.6, -12.6, -17.6, 16.4, 11.4, 6.4, 1.4, -3.6, -8.6, -13.6, -18.6, 15.4, 10.4, 5.4, 0.4, -4.6, -9.6, -14.6, -19.6, 19.299999, 14.3, 9.3, 4.3, -0.7, -5.7, -10.7, -15.7, 18.299999, 13.3, 8.3, 3.3, -1.7, -6.7, -11.7, -16.700001, 17.299999, 12.3, 7.3, 2.3, -2.7, -7.7, -12.7, -17.700001, 16.299999, 11.3, 6.3, 1.3, -3.7, -8.7, -13.7, -18.700001, 15.3, 10.3, 5.3, 0.3, -4.7, -9.7, -14.7, -19.700001, 19.200001, 14.2, 9.2, 4.2, -0.8, -5.8, -10.8, -15.8, 18.200001, 13.2, 8.2, 3.2, -1.8, -6.8, -11.8, -16.799999, 17.200001, 12.2, 7.2, 2.2, -2.8, -7.8, -12.8, -17.799999, 16.200001, 11.2, 6.2, 1.2, -3.8, -8.8, -13.8, -18.799999, 15.2, 10.2, 5.2, 0.2, -4.8, -9.8, -14.8, -19.799999, 19.1, 14.1, 9.1, 4.1, -0.9, -5.9, -10.9, -15.9, 18.1, 13.1, 8.1, 3.1, -1.9, -6.9, -11.9, -16.9, 17.1, 12.1, 7.1, 2.1, -2.9, -7.9, -12.9, -17.9, 16.1, 11.1, 6.1, 1.1, -3.9, -8.9, -13.9, -18.9, 15.1, 10.1, 5.1, 0.1, -4.9, -9.9, -14.9, -19.9}, sd::DataType::FLOAT32); - NDArray expOutput('c', {bS, oD, oH, oW, oC}, {-5191.349609, -4925.850098, -4660.350098, -4394.850098, -4129.349609, -8859.700195, -8338.700195, -7817.700195, - -7296.700195, -6775.700195, -8518.700195, -8017.700195, -7516.700195, -7015.700195, -6514.700195, -3572.850098, -3327.349854, -3081.850098, -2836.350098, - -2590.850098, -7141.200195, -6640.200195, -6139.199707, -5638.200195, -5137.200195, -11486.400391, -10504.400391, -9522.400391, -8540.400391, -7558.399902, - -11004.400391, -10062.400391, -9120.400391, -8178.399414, -7236.399414, -4254.200195, -3793.200195, -3332.200195, -2871.199951, -2410.200195, -6268.200195, - -5827.200195, -5386.200195, -4945.200195, -4504.200195, -10040.400391, -9178.400391, -8316.400391, -7454.400391, -6592.399902, -9558.400391, -8736.400391, - -7914.400391, -7092.399902, -6270.400391, -3681.199707, -3280.200195, -2879.200195, -2478.200195, -2077.200195, -1963.350098, -1757.850098, -1552.349854, -1346.849976, -1141.349976, -2803.700195, -2402.699951, -2001.699951, -1600.699951, -1199.699951, -2662.699951, -2281.699951, -1900.699951, -1519.699951, -1138.700073, -844.850037, -659.349976, -473.850006, -288.350006, -102.849998, -3313.200195, -2872.199951, -2431.200195, -1990.200195, -1549.199829, -4230.399902, -3368.400391, -2506.400391, -1644.400146, -782.400146, -3948.400146, -3126.400391, -2304.399902, -1482.400146, -660.400269, -926.200195, -525.199951, -124.199951, 276.799927, 677.799805, -1643.400269, -821.400146, 0.599609, 822.600098, 1644.599609, 1005.199951, 2609.199707, 4213.200195, 5817.200195, 7421.200684, 1169.199463, 2693.200195, 4217.199707, 5741.201172, 7265.203125, 2430.599609, 3172.600098, 3914.600098, 4656.599609, 5398.599609, -1097.400391, -395.400269, 306.599609, 1008.599854, 1710.599731, 1497.199219, 2861.199219, 4225.201172, 5589.200684, 6953.200684, 1661.199219, 2945.199463, 4229.199707, 5513.201172, 6797.200684, 2376.599609, 2998.599854, 3620.599609, 4242.600098, 4864.600098, 1042.799927, 1363.799927, 1684.800171, 2005.799805, 2326.799805, 3681.599609, 4303.599609, 4925.599609, 5547.600098, 6169.599609, 3563.599609, 4145.599609, 4727.600098, 5309.600098, 5891.599609, 2429.800293, 2710.800293, 2991.799805, 3272.799805, 3553.799805, -1594.199829, -1333.199951, -1072.200073, -811.200012, -550.200134, -1692.400024, -1190.399902, -688.400024, -186.400269, 315.600098, -1410.399902, -948.399902, -486.399902, -24.399780, 437.599731, -107.199890, 113.799988, 334.799988, 555.799988, 776.800049, -5.400024, 456.599731, 918.600281, 1380.599731, 1842.599976, 2481.199219, 3365.199219, 4249.199219, 5133.199219, 6017.199219, 2645.199219, 3449.199219, 4253.199707, 5057.199219, 5861.199707, 2268.600098, 2650.599609, 3032.600098, 3414.600098, 3796.599609, 540.599976, 882.600220, 1224.599854, 1566.599854, 1908.600220, 2973.200195, 3617.199707, 4261.199219, 4905.199219, 5549.199219, 3137.199707, 3701.199219, 4265.199707, 4829.199219, 5393.199219, 2214.599609, 2476.600098, 2738.599609, 3000.599854, 3262.599854, 961.800049, 1102.800049, 1243.799927, 1384.800171, 1525.799927, 2619.599609, 2881.599854, 3143.599854, 3405.599609, 3667.599609, 2501.599854, 2723.599609, 2945.599854, 3167.599609, 3389.600098, 1448.799927, 1549.800049, 1650.799927, 1751.800049, 1852.799927, 37.650002, 123.150009, 208.650009, 294.149994, 379.650024, 498.300018, 659.300049, 820.300049, 981.299927, 1142.299927, 439.300018, 580.299988, 721.299927, 862.300049, 1003.300049, 356.149963, 421.649994, 487.150024, 552.649963, 618.150024, 916.799988, 1057.800049, 1198.800171, 1339.800049, 1480.800171, 2429.600098, 2691.600098, 2953.599609, 3215.599609, 3477.599609, 2111.599854, 2333.599854, 2555.600098, 2777.599609, 2999.600098, 1203.800049, 1304.800049, 1405.799927, 1506.800049, 1607.800049, 589.799927, 670.800049, 751.800049, 832.800049, 913.800049, 1475.599976, 1617.600098, 1759.600098, 1901.600098, 2043.600098, 1157.600098, 1259.600098, 1361.600098, 1463.600098, 1565.599976, 576.799988, 617.800049, 658.799988, 699.799927, 740.800049, 265.649994, 291.149994, 316.650024, 342.150024, 367.649994, 554.300049, 595.299988, 636.299927, 677.299988, 718.299988, 295.300018, 316.300018, 337.299988, 358.299988, 379.300018, 84.149994, 89.650002, 95.150002, 100.650009, 106.150009, 87.150002, 82.650002, 78.150002, 73.650002, 69.150002, 347.299988, 328.300018, 309.300018, 290.299988, 271.299988, 688.300049, 649.299927, 610.299988, 571.300049, 532.300049, 355.650024, 331.149963, 306.649994, 282.149994, 257.649994, 715.800049, 676.800049, 637.799988, 598.800049, 559.800049, 1527.600098, 1429.599976, 1331.599976, 1233.600098, 1135.600098, 2009.600098, 1871.600098, 1733.599976, 1595.600098, 1457.600098, 902.799988, 823.799927, 744.800049, 665.800049, 586.800049, 1588.800049, 1489.800049, 1390.800049, 1291.800049, 1192.799927, 2973.600098, 2755.600098, 2537.600098, 2319.600098, 2101.600098, 3455.600098, 3197.600098, 2939.600098, 2681.600098, 2423.600098, 1475.800049, 1336.800049, 1197.800049, 1058.799927, 919.800049, 615.150024, 550.650024, 486.149994, 421.649994, 357.150024, 1003.300049, 864.300049, 725.299988, 586.300049, 447.300018, 1144.300049, 985.299988, 826.300049, 667.299988, 508.299988, 383.649994, 299.149994, 214.649994, 130.149994, 45.649998, 1843.799927, 1744.799927, 1645.800049, 1546.799927, 1447.800049, 3383.600098, 3165.600098, 2947.600098, 2729.599854, 2511.600098, 3665.599854, 3407.600098, 3149.599854, 2891.599854, 2633.599854, 1530.800171, 1391.800049, 1252.800049, 1113.800049, 974.800171, 3270.599609, 3012.599854, 2754.600098, 2496.599854, 2238.600098, 5433.199707, 4877.200195, 4321.200195, 3765.199707, 3209.199951, 5597.200195, 4961.199707, 4325.200195, 3689.199707, 3053.199951, 1944.600098, 1606.599854, 1268.600098, 930.599976, 592.600098, 3816.599854, 3438.600342, 3060.599854, 2682.600098, 2304.600098, 5925.200195, 5129.200684, 4333.200195, 3537.199951, 2741.199707, 6089.200684, 5213.200195, 4337.200195, 3461.199707, 2585.200195, 1890.599609, 1432.600220, 974.599976, 516.599976, 58.599976, 799.799927, 580.800171, 361.800110, 142.800110, -76.200073, 495.599976, 37.599976, -420.399902, -878.399902, -1336.400024, 377.599854, -120.399902, -618.399902, -1116.400391, -1614.399902, -513.199951, -772.200012, -1031.199951, -1290.199829, -1549.200073, 3562.800049, 3283.799805, 3004.799805, 2725.800293, 2446.800293, 5921.599609, 5343.599609, 4765.600098, 4187.599609, 3609.599854, 6203.599609, 5585.600098, 4967.600098, 4349.599609, 3731.600098, 2349.799805, 2030.800171, 1711.800293, 1392.800171, 1073.799927, 4908.600098, 4290.599609, 3672.600098, 3054.600098, 2436.600098, 6909.199219, 5633.200684, 4357.200195, 3081.199219, 1805.199463, 7073.200684, 5717.199707, 4361.199219, 3005.199463, 1649.199951, 1782.600464, 1084.599609, 386.599609, -311.400146, -1009.400635, 5454.600098, 4716.599609, 3978.599854, 3240.600098, 2502.600098, 7401.199219, 5885.199219, 4369.200195, 2853.200195, 1337.199219, 7565.199219, 5969.200195, 4373.200195, 2777.199219, 1181.199219, 1728.599854, 910.600098, 92.600098, -725.400391, -1543.400391, 718.799927, 319.800049, -79.200073, -478.200073, -877.200073, -566.400391, -1384.400391, -2202.400391, -3020.400391, -3838.400391, -684.400146, -1542.400391, -2400.400391, -3258.400391, -4116.400391, -1494.200073, -1933.200073, -2372.199707, -2811.200195, -3250.199951, -83.850006, -268.350006, -452.849945, -637.350037, -821.849976, -1094.699951, -1473.699951, -1852.700073, -2231.699707, -2610.699951, -1153.700073, -1552.699829, -1951.699829, -2350.700195, -2749.700195, -1115.350098, -1319.849854, -1524.350098, -1728.849976, -1933.350098, -2026.200073, -2425.200195, -2824.200195, -3223.199707, -3622.200195, -6156.400391, -6974.400391, -7792.400391, -8610.400391, -9428.399414, -6474.400391, -7332.400391, -8190.400391, -9048.399414, -9906.399414, -4439.200195, -4878.199707, -5317.200195, -5756.200195, -6195.200195, -2353.199951, -2812.200195, -3271.200195, -3730.200195, -4189.200195, -7110.400391, -8048.400391, -8986.399414, -9924.400391, -10862.400391, -7428.400391, -8406.399414, -9384.399414, -10362.400391, -11340.400391, -5066.200195, -5565.200195, -6064.200195, -6563.200195, -7062.200195, -2555.849854, -2800.349854, -3044.849854, -3289.350098, -3533.850098, -6438.700195, -6937.700195, -7436.700195, -7935.700195, -8434.699219, -6697.700195, -7216.700195, -7735.700195, -8254.699219, -8773.700195, -4087.349854, -4351.850098, -4616.349609, -4880.850098, -5145.350098}, sd::DataType::FLOAT32); - - input.linspace(-27, 0.1); - - sd::ops::deconv3d op; - auto results = op.evaluate({&input, &weights}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output, 1e-3)); + int bS = 2, oD = 4, oH = 4, oW = 4, oC = 5, iC = 10, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int iD = 3, iH = 3, iW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - + // [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {iC, oC, kD, kH, kW}, + {20., 15., 10., 5., 0., -5., -10., -15., + 19., 14., 9., 4., -1., -6., -11., -16., + 18., 13., 8., 3., -2., -7., -12., -17., + 17., 12., 7., 2., -3., -8., -13., -18., + 16., 11., 6., 1., -4., -9., -14., -19., + 19.9, 14.9, 9.9, 4.9, -0.1, -5.1, -10.1, -15.1, + 18.9, 13.9, 8.9, 3.9, -1.1, -6.1, -11.1, -16.1, + 17.9, 12.9, 7.9, 2.9, -2.1, -7.1, -12.1, -17.1, + 16.9, 11.9, 6.9, 1.9, -3.1, -8.1, -13.1, -18.1, + 15.9, 10.9, 5.9, 0.9, -4.1, -9.1, -14.1, -19.1, + 19.799999, 14.8, 9.8, 4.8, -0.2, -5.2, -10.2, -15.2, + 18.799999, 13.8, 8.8, 3.8, -1.2, -6.2, -11.2, -16.200001, + 17.799999, 12.8, 7.8, 2.8, -2.2, -7.2, -12.2, -17.200001, + 16.799999, 11.8, 6.8, 1.8, -3.2, -8.2, -13.2, -18.200001, + 15.8, 10.8, 5.8, 0.8, -4.2, -9.2, -14.2, -19.200001, + 19.700001, 14.7, 9.7, 4.7, -0.3, -5.3, -10.3, -15.3, + 18.700001, 13.7, 8.7, 3.7, -1.3, -6.3, -11.3, -16.299999, + 17.700001, 12.7, 7.7, 2.7, -2.3, -7.3, -12.3, -17.299999, + 16.700001, 11.7, 6.7, 1.7, -3.3, -8.3, -13.3, -18.299999, + 15.7, 10.7, 5.7, 0.7, -4.3, -9.3, -14.3, -19.299999, + 19.6, 14.6, 9.6, 4.6, -0.4, -5.4, -10.4, -15.4, + 18.6, 13.6, 8.6, 3.6, -1.4, -6.4, -11.4, -16.4, + 17.6, 12.6, 7.6, 2.6, -2.4, -7.4, -12.4, -17.4, + 16.6, 11.6, 6.6, 1.6, -3.4, -8.4, -13.4, -18.4, + 15.6, 10.6, 5.6, 0.6, -4.4, -9.4, -14.4, -19.4, + 19.5, 14.5, 9.5, 4.5, -0.5, -5.5, -10.5, -15.5, + 18.5, 13.5, 8.5, 3.5, -1.5, -6.5, -11.5, -16.5, + 17.5, 12.5, 7.5, 2.5, -2.5, -7.5, -12.5, -17.5, + 16.5, 11.5, 6.5, 1.5, -3.5, -8.5, -13.5, -18.5, + 15.5, 10.5, 5.5, 0.5, -4.5, -9.5, -14.5, -19.5, + 19.4, 14.4, 9.4, 4.4, -0.6, -5.6, -10.6, -15.6, + 18.4, 13.4, 8.4, 3.4, -1.6, -6.6, -11.6, -16.6, + 17.4, 12.4, 7.4, 2.4, -2.6, -7.6, -12.6, -17.6, + 16.4, 11.4, 6.4, 1.4, -3.6, -8.6, -13.6, -18.6, + 15.4, 10.4, 5.4, 0.4, -4.6, -9.6, -14.6, -19.6, + 19.299999, 14.3, 9.3, 4.3, -0.7, -5.7, -10.7, -15.7, + 18.299999, 13.3, 8.3, 3.3, -1.7, -6.7, -11.7, -16.700001, + 17.299999, 12.3, 7.3, 2.3, -2.7, -7.7, -12.7, -17.700001, + 16.299999, 11.3, 6.3, 1.3, -3.7, -8.7, -13.7, -18.700001, + 15.3, 10.3, 5.3, 0.3, -4.7, -9.7, -14.7, -19.700001, + 19.200001, 14.2, 9.2, 4.2, -0.8, -5.8, -10.8, -15.8, + 18.200001, 13.2, 8.2, 3.2, -1.8, -6.8, -11.8, -16.799999, + 17.200001, 12.2, 7.2, 2.2, -2.8, -7.8, -12.8, -17.799999, + 16.200001, 11.2, 6.2, 1.2, -3.8, -8.8, -13.8, -18.799999, + 15.2, 10.2, 5.2, 0.2, -4.8, -9.8, -14.8, -19.799999, + 19.1, 14.1, 9.1, 4.1, -0.9, -5.9, -10.9, -15.9, + 18.1, 13.1, 8.1, 3.1, -1.9, -6.9, -11.9, -16.9, + 17.1, 12.1, 7.1, 2.1, -2.9, -7.9, -12.9, -17.9, + 16.1, 11.1, 6.1, 1.1, -3.9, -8.9, -13.9, -18.9, + 15.1, 10.1, 5.1, 0.1, -4.9, -9.9, -14.9, -19.9}, + sd::DataType::FLOAT32); + NDArray expOutput( + 'c', {bS, oD, oH, oW, oC}, + {-5191.349609, -4925.850098, -4660.350098, -4394.850098, -4129.349609, + -8859.700195, -8338.700195, -7817.700195, -7296.700195, -6775.700195, + -8518.700195, -8017.700195, -7516.700195, -7015.700195, -6514.700195, + -3572.850098, -3327.349854, -3081.850098, -2836.350098, -2590.850098, + -7141.200195, -6640.200195, -6139.199707, -5638.200195, -5137.200195, + -11486.400391, -10504.400391, -9522.400391, -8540.400391, -7558.399902, + -11004.400391, -10062.400391, -9120.400391, -8178.399414, -7236.399414, + -4254.200195, -3793.200195, -3332.200195, -2871.199951, -2410.200195, + -6268.200195, -5827.200195, -5386.200195, -4945.200195, -4504.200195, + -10040.400391, -9178.400391, -8316.400391, -7454.400391, -6592.399902, + -9558.400391, -8736.400391, -7914.400391, -7092.399902, -6270.400391, + -3681.199707, -3280.200195, -2879.200195, -2478.200195, -2077.200195, + -1963.350098, -1757.850098, -1552.349854, -1346.849976, -1141.349976, + -2803.700195, -2402.699951, -2001.699951, -1600.699951, -1199.699951, + -2662.699951, -2281.699951, -1900.699951, -1519.699951, -1138.700073, + -844.850037, -659.349976, -473.850006, -288.350006, -102.849998, + -3313.200195, -2872.199951, -2431.200195, -1990.200195, -1549.199829, + -4230.399902, -3368.400391, -2506.400391, -1644.400146, -782.400146, + -3948.400146, -3126.400391, -2304.399902, -1482.400146, -660.400269, + -926.200195, -525.199951, -124.199951, 276.799927, 677.799805, + -1643.400269, -821.400146, 0.599609, 822.600098, 1644.599609, + 1005.199951, 2609.199707, 4213.200195, 5817.200195, 7421.200684, + 1169.199463, 2693.200195, 4217.199707, 5741.201172, 7265.203125, + 2430.599609, 3172.600098, 3914.600098, 4656.599609, 5398.599609, + -1097.400391, -395.400269, 306.599609, 1008.599854, 1710.599731, + 1497.199219, 2861.199219, 4225.201172, 5589.200684, 6953.200684, + 1661.199219, 2945.199463, 4229.199707, 5513.201172, 6797.200684, + 2376.599609, 2998.599854, 3620.599609, 4242.600098, 4864.600098, + 1042.799927, 1363.799927, 1684.800171, 2005.799805, 2326.799805, + 3681.599609, 4303.599609, 4925.599609, 5547.600098, 6169.599609, + 3563.599609, 4145.599609, 4727.600098, 5309.600098, 5891.599609, + 2429.800293, 2710.800293, 2991.799805, 3272.799805, 3553.799805, + -1594.199829, -1333.199951, -1072.200073, -811.200012, -550.200134, + -1692.400024, -1190.399902, -688.400024, -186.400269, 315.600098, + -1410.399902, -948.399902, -486.399902, -24.399780, 437.599731, + -107.199890, 113.799988, 334.799988, 555.799988, 776.800049, + -5.400024, 456.599731, 918.600281, 1380.599731, 1842.599976, + 2481.199219, 3365.199219, 4249.199219, 5133.199219, 6017.199219, + 2645.199219, 3449.199219, 4253.199707, 5057.199219, 5861.199707, + 2268.600098, 2650.599609, 3032.600098, 3414.600098, 3796.599609, + 540.599976, 882.600220, 1224.599854, 1566.599854, 1908.600220, + 2973.200195, 3617.199707, 4261.199219, 4905.199219, 5549.199219, + 3137.199707, 3701.199219, 4265.199707, 4829.199219, 5393.199219, + 2214.599609, 2476.600098, 2738.599609, 3000.599854, 3262.599854, + 961.800049, 1102.800049, 1243.799927, 1384.800171, 1525.799927, + 2619.599609, 2881.599854, 3143.599854, 3405.599609, 3667.599609, + 2501.599854, 2723.599609, 2945.599854, 3167.599609, 3389.600098, + 1448.799927, 1549.800049, 1650.799927, 1751.800049, 1852.799927, + 37.650002, 123.150009, 208.650009, 294.149994, 379.650024, + 498.300018, 659.300049, 820.300049, 981.299927, 1142.299927, + 439.300018, 580.299988, 721.299927, 862.300049, 1003.300049, + 356.149963, 421.649994, 487.150024, 552.649963, 618.150024, + 916.799988, 1057.800049, 1198.800171, 1339.800049, 1480.800171, + 2429.600098, 2691.600098, 2953.599609, 3215.599609, 3477.599609, + 2111.599854, 2333.599854, 2555.600098, 2777.599609, 2999.600098, + 1203.800049, 1304.800049, 1405.799927, 1506.800049, 1607.800049, + 589.799927, 670.800049, 751.800049, 832.800049, 913.800049, + 1475.599976, 1617.600098, 1759.600098, 1901.600098, 2043.600098, + 1157.600098, 1259.600098, 1361.600098, 1463.600098, 1565.599976, + 576.799988, 617.800049, 658.799988, 699.799927, 740.800049, + 265.649994, 291.149994, 316.650024, 342.150024, 367.649994, + 554.300049, 595.299988, 636.299927, 677.299988, 718.299988, + 295.300018, 316.300018, 337.299988, 358.299988, 379.300018, + 84.149994, 89.650002, 95.150002, 100.650009, 106.150009, + 87.150002, 82.650002, 78.150002, 73.650002, 69.150002, + 347.299988, 328.300018, 309.300018, 290.299988, 271.299988, + 688.300049, 649.299927, 610.299988, 571.300049, 532.300049, + 355.650024, 331.149963, 306.649994, 282.149994, 257.649994, + 715.800049, 676.800049, 637.799988, 598.800049, 559.800049, + 1527.600098, 1429.599976, 1331.599976, 1233.600098, 1135.600098, + 2009.600098, 1871.600098, 1733.599976, 1595.600098, 1457.600098, + 902.799988, 823.799927, 744.800049, 665.800049, 586.800049, + 1588.800049, 1489.800049, 1390.800049, 1291.800049, 1192.799927, + 2973.600098, 2755.600098, 2537.600098, 2319.600098, 2101.600098, + 3455.600098, 3197.600098, 2939.600098, 2681.600098, 2423.600098, + 1475.800049, 1336.800049, 1197.800049, 1058.799927, 919.800049, + 615.150024, 550.650024, 486.149994, 421.649994, 357.150024, + 1003.300049, 864.300049, 725.299988, 586.300049, 447.300018, + 1144.300049, 985.299988, 826.300049, 667.299988, 508.299988, + 383.649994, 299.149994, 214.649994, 130.149994, 45.649998, + 1843.799927, 1744.799927, 1645.800049, 1546.799927, 1447.800049, + 3383.600098, 3165.600098, 2947.600098, 2729.599854, 2511.600098, + 3665.599854, 3407.600098, 3149.599854, 2891.599854, 2633.599854, + 1530.800171, 1391.800049, 1252.800049, 1113.800049, 974.800171, + 3270.599609, 3012.599854, 2754.600098, 2496.599854, 2238.600098, + 5433.199707, 4877.200195, 4321.200195, 3765.199707, 3209.199951, + 5597.200195, 4961.199707, 4325.200195, 3689.199707, 3053.199951, + 1944.600098, 1606.599854, 1268.600098, 930.599976, 592.600098, + 3816.599854, 3438.600342, 3060.599854, 2682.600098, 2304.600098, + 5925.200195, 5129.200684, 4333.200195, 3537.199951, 2741.199707, + 6089.200684, 5213.200195, 4337.200195, 3461.199707, 2585.200195, + 1890.599609, 1432.600220, 974.599976, 516.599976, 58.599976, + 799.799927, 580.800171, 361.800110, 142.800110, -76.200073, + 495.599976, 37.599976, -420.399902, -878.399902, -1336.400024, + 377.599854, -120.399902, -618.399902, -1116.400391, -1614.399902, + -513.199951, -772.200012, -1031.199951, -1290.199829, -1549.200073, + 3562.800049, 3283.799805, 3004.799805, 2725.800293, 2446.800293, + 5921.599609, 5343.599609, 4765.600098, 4187.599609, 3609.599854, + 6203.599609, 5585.600098, 4967.600098, 4349.599609, 3731.600098, + 2349.799805, 2030.800171, 1711.800293, 1392.800171, 1073.799927, + 4908.600098, 4290.599609, 3672.600098, 3054.600098, 2436.600098, + 6909.199219, 5633.200684, 4357.200195, 3081.199219, 1805.199463, + 7073.200684, 5717.199707, 4361.199219, 3005.199463, 1649.199951, + 1782.600464, 1084.599609, 386.599609, -311.400146, -1009.400635, + 5454.600098, 4716.599609, 3978.599854, 3240.600098, 2502.600098, + 7401.199219, 5885.199219, 4369.200195, 2853.200195, 1337.199219, + 7565.199219, 5969.200195, 4373.200195, 2777.199219, 1181.199219, + 1728.599854, 910.600098, 92.600098, -725.400391, -1543.400391, + 718.799927, 319.800049, -79.200073, -478.200073, -877.200073, + -566.400391, -1384.400391, -2202.400391, -3020.400391, -3838.400391, + -684.400146, -1542.400391, -2400.400391, -3258.400391, -4116.400391, + -1494.200073, -1933.200073, -2372.199707, -2811.200195, -3250.199951, + -83.850006, -268.350006, -452.849945, -637.350037, -821.849976, + -1094.699951, -1473.699951, -1852.700073, -2231.699707, -2610.699951, + -1153.700073, -1552.699829, -1951.699829, -2350.700195, -2749.700195, + -1115.350098, -1319.849854, -1524.350098, -1728.849976, -1933.350098, + -2026.200073, -2425.200195, -2824.200195, -3223.199707, -3622.200195, + -6156.400391, -6974.400391, -7792.400391, -8610.400391, -9428.399414, + -6474.400391, -7332.400391, -8190.400391, -9048.399414, -9906.399414, + -4439.200195, -4878.199707, -5317.200195, -5756.200195, -6195.200195, + -2353.199951, -2812.200195, -3271.200195, -3730.200195, -4189.200195, + -7110.400391, -8048.400391, -8986.399414, -9924.400391, -10862.400391, + -7428.400391, -8406.399414, -9384.399414, -10362.400391, -11340.400391, + -5066.200195, -5565.200195, -6064.200195, -6563.200195, -7062.200195, + -2555.849854, -2800.349854, -3044.849854, -3289.350098, -3533.850098, + -6438.700195, -6937.700195, -7436.700195, -7935.700195, -8434.699219, + -6697.700195, -7216.700195, -7735.700195, -8254.699219, -8773.700195, + -4087.349854, -4351.850098, -4616.349609, -4880.850098, -5145.350098}, + sd::DataType::FLOAT32); + + input.linspace(-27, 0.1); + + sd::ops::deconv3d op; + auto results = op.evaluate({&input, &weights}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output, 1e-3)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_test7) { - - int bS=2, oD=4,oH=4,oW=4, iC=5,oC=10, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=1,pH=0,pW=0, dD=1,dH=1,dW=1; - int iD=4,iH=4,iW=4; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {iC, kD, kH, kW, oC}, {20., 19.5, 19., 18.5, 18., 17.5, 17., 16.5, 16., 15.5, 15., 14.5, 14., 13.5, 13., 12.5, 12., 11.5, 11., 10.5, 10., - 9.5, 9., 8.5, 8., 7.5, 7., 6.5, 6., 5.5, 5., 4.5, 4., 3.5, 3., 2.5, 2., 1.5, 1., 0.5, 0., -0.5, -1., -1.5, -2., -2.5, -3., -3.5, -4., -4.5, -5., -5.5, -6., - -6.5, -7., -7.5, -8., -8.5, -9., -9.5, -10., -10.5, -11., -11.5, -12., -12.5, -13., -13.5, -14., -14.5, -15., -15.5, -16., -16.5, -17., -17.5, -18., -18.5, - -19., -19.5, 19.9, 19.4, 18.9, 18.4, 17.9, 17.4, 16.9, 16.4, 15.9, 15.4, 14.9, 14.4, 13.9, 13.4, 12.9, 12.4, 11.9, 11.4, 10.9, 10.4, 9.9, 9.4, 8.9, 8.4, 7.9, - 7.4, 6.9, 6.4, 5.9, 5.4, 4.9, 4.4, 3.9, 3.4, 2.9, 2.4, 1.9, 1.4, 0.9, 0.4, -0.1, -0.6, -1.1, -1.6, -2.1, -2.6, -3.1, -3.6, -4.1, -4.6, -5.1, -5.6, -6.1, -6.6, -7.1, -7.6, -8.1, -8.6, -9.1, -9.6, -10.1, -10.6, -11.1, -11.6, -12.1, -12.6, -13.1, -13.6, -14.1, -14.6, -15.1, -15.6, -16.1, -16.6, -17.1, -17.6, -18.1, -18.6, -19.1, -19.6, 19.799999, 19.299999, 18.799999, 18.299999, 17.799999, 17.299999, 16.799999, 16.299999, 15.8, 15.3, 14.8, 14.3, 13.8, 13.3, 12.8, 12.3, 11.8, 11.3, 10.8, 10.3, 9.8, 9.3, 8.8, 8.3, 7.8, 7.3, 6.8, 6.3, 5.8, 5.3, 4.8, 4.3, 3.8, 3.3, 2.8, 2.3, 1.8, 1.3, 0.8, 0.3, -0.2, -0.7, -1.2, -1.7, -2.2, -2.7, -3.2, -3.7, -4.2, -4.7, -5.2, -5.7, -6.2, -6.7, -7.2, -7.7, -8.2, -8.7, -9.2, -9.7, -10.2, -10.7, -11.2, -11.7, -12.2, -12.7, -13.2, -13.7, -14.2, -14.7, -15.2, -15.7, -16.200001, -16.700001, -17.200001, -17.700001, -18.200001, -18.700001, -19.200001, -19.700001, 19.700001, 19.200001, 18.700001, 18.200001, 17.700001, 17.200001, 16.700001, 16.200001, 15.7, 15.2, 14.7, 14.2, 13.7, 13.2, 12.7, 12.2, 11.7, 11.2, 10.7, 10.2, 9.7, 9.2, 8.7, 8.2, 7.7, 7.2, 6.7, 6.2, 5.7, 5.2, 4.7, 4.2, 3.7, 3.2, 2.7, 2.2, 1.7, 1.2, 0.7, 0.2, -0.3, -0.8, -1.3, -1.8, -2.3, -2.8, -3.3, -3.8, -4.3, -4.8, -5.3, -5.8, -6.3, -6.8, -7.3, -7.8, -8.3, -8.8, -9.3, -9.8, -10.3, -10.8, -11.3, -11.8, -12.3, -12.8, -13.3, -13.8, -14.3, -14.8, -15.3, -15.8, -16.299999, -16.799999, -17.299999, -17.799999, -18.299999, -18.799999, -19.299999, -19.799999, 19.6, 19.1, 18.6, 18.1, 17.6, 17.1, 16.6, 16.1, 15.6, 15.1, 14.6, 14.1, 13.6, 13.1, 12.6, 12.1, 11.6, 11.1, 10.6, 10.1, 9.6, 9.1, 8.6, 8.1, 7.6, 7.1, 6.6, 6.1, 5.6, 5.1, 4.6, 4.1, 3.6, 3.1, 2.6, 2.1, 1.6, 1.1, 0.6, 0.1, -0.4, -0.9, -1.4, -1.9, -2.4, -2.9, -3.4, -3.9, -4.4, -4.9, -5.4, -5.9, -6.4, -6.9, -7.4, -7.9, -8.4, -8.9, -9.4, -9.9, -10.4, -10.9, -11.4, -11.9, -12.4, -12.9, -13.4, -13.9, -14.4, -14.9, -15.4, -15.9, -16.4, -16.9, -17.4, -17.9, -18.4, -18.9, -19.4, -19.9}, sd::DataType::FLOAT32); - NDArray expOutput('c', {bS, oC, oD, oH, oW}, {-1907.199951, -3324.499756, -3307.199707, -3289.899902, -2814.799805, -4664.800293, -4640.199707, -4615.600098, - -2755.599854, -4566.400391, -4541.800293, -4517.199707, -2696.400146, -4468., -4443.400391, -4418.799805, -1735.999878, -2542.199951, -2527.600098, -2513., - -1592.800049, -1355.999756, -1346.799805, -1337.599854, -1554.400024, -1319.199829, -1310.000122, -1300.800049, -1516., -1282.400024, -1273.200195, -1263.999878, - -1579.200073, -2308.599854, -2294., -2279.400146, -1439.199951, -1208.799683, -1199.599976, -1190.399902, -1400.800049, -1172., -1162.800049, -1153.600098, - -1362.399902, -1135.199951, -1126., -1116.799805, -1422.400024, -2075., -2060.399902, -2045.799683, -1285.599976, -1061.599854, -1052.399902, -1043.200195, - -1247.199951, -1024.800049, -1015.599976, -1006.400146, -1208.799927, -988.000122, -978.799683, -969.599976, -1859.199951, -3228.75, -3211.949951, -3195.150146, -2719.800049, -4475.299805, -4451.699707, -4428.100098, -2662.600098, -4380.899902, -4357.300293, -4333.699707, -2605.399902, -4286.5, -4262.899902, -4239.300293, -1643.999878, -2358.700195, -2345.099854, -2331.5, -1410.800049, -992.999756, -985.799438, -978.600098, -1376.400024, -964.199707, -957., -949.800049, -1342., -935.399902, -928.199951, -921.000122, -1495.200073, -2141.099854, -2127.5, -2113.900391, -1273.199951, -877.799683, -870.599976, -863.39978, -1238.800049, -849., -841.800171, -834.599976, -1204.400024, -820.199707, -813., -805.799438, -1346.400146, -1923.500122, -1909.899902, -1896.299927, -1135.599976, -762.599976, -755.399658, -748.200195, -1101.199951, -733.800049, -726.599854, -719.400024, -1066.800049, -705., -697.800171, -690.599976, -1811.199951, -3133., -3116.699951, -3100.399902, -2624.799805, -4285.799805, -4263.199707, -4240.600098, -2569.600098, -4195.399902, -4172.800293, -4150.199707, -2514.399902, -4105., -4082.400146, -4059.800293, -1552., -2175.200195, -2162.599854, -2150., -1228.800049, -630., -624.799561, -619.599854, -1198.400024, -609.199463, -603.999756, -598.800049, -1167.999878, -588.400391, -583.199951, -578., -1411.200073, -1973.599854, -1961.000122, -1948.400146, -1107.199829, -546.800171, -541.599976, -536.400269, -1076.800049, -525.999756, -520.800049, -515.599976, -1046.400146, -505.199829, -500., -494.799683, -1270.399902, -1772., -1759.400146, -1746.799927, -985.599976, -463.600098, -458.399902, -453.199951, -955.199951, -442.799927, -437.599976, -432.400269, -924.799988, -422.000122, -416.800171, -411.599976, -1763.199951, -3037.25, -3021.449951, -3005.649902, -2529.800293, -4096.299805, -4074.699951, -4053.100098, -2476.600098, -4009.900146, -3988.300049, -3966.699951, -2423.399902, -3923.5, -3901.899902, -3880.299805, -1459.999878, -1991.699951, -1980.099854, -1968.500122, -1046.800049, -266.999878, -263.799805, -260.599854, -1020.400146, -254.199829, -251., -247.799927, -994., -241.400269, -238.200073, -234.999878, -1327.200073, -1806.099854, -1794.500122, -1782.900146, -941.199951, -215.799927, -212.600098, -209.399902, -914.799988, -203.000122, -199.799683, -196.599976, -888.400024, -190.200317, -186.999878, -183.799805, -1194.399902, -1620.500122, -1608.899902, -1597.299927, -835.599915, -164.599976, -161.400269, -158.200195, -809.200073, -151.799927, -148.599976, -145.400024, -782.799927, -139., -135.799805, -132.599976, -1715.200073, -2941.5, -2926.199951, -2910.899902, -2434.800049, -3906.799805, -3886.199951, -3865.599609, -2383.600098, -3824.400391, -3803.800049, -3783.199951, -2332.400146, -3742., -3721.400146, -3700.799805, -1367.999878, -1808.199707, -1797.599854, -1786.999878, -864.800049, 95.999878, 97.200073, 98.400024, -842.39978, 100.799927, 102.000244, 103.200439, -820., 105.599609, 106.800171, 108., -1243.199951, -1638.599854, -1628.000122, -1617.400146, -775.199829, 115.200195, 116.400146, 117.60022, -752.799805, 120., 121.200073, 122.400024, -730.399841, 124.799927, 125.999878, 127.199951, -1118.400024, -1468.999878, -1458.400146, -1447.799927, -685.599915, 134.400146, 135.60022, 136.800171, -663.199951, 139.200073, 140.399902, 141.599731, -640.799988, 144., 145.200195, 146.400146, -1667.199951, -2845.749756, -2830.949707, -2816.149902, -2339.799805, -3717.300049, -3697.699951, -3678.100098, -2290.600098, -3638.900146, -3619.300049, -3599.699951, -2241.399902, -3560.5, -3540.899902, -3521.299805, -1276., -1624.699951, -1615.100098, -1605.499878, -682.799927, 459.000122, 458.199951, 457.400146, -664.400024, 455.800049, 454.999878, 454.200439, -646.000122, 452.599976, 451.799805, 451.000122, -1159.200073, -1471.099854, -1461.5, -1451.900146, -609.199829, 446.200195, 445.400024, 444.600098, -590.799927, 443., 442.200073, 441.399658, -572.39978, 439.799927, 439.000122, 438.200073, -1042.399902, -1317.499756, -1307.900146, -1298.299683, -535.599976, 433.399963, 432.600098, 431.799744, -517.200012, 430.200195, 429.400024, 428.599976, -498.799927, 427.000061, 426.200256, 425.400024, -1619.199951, -2750., -2735.699951, -2721.399902, -2244.799805, -3527.799805, -3509.199951, -3490.600098, -2197.600098, -3453.400146, -3434.800049, -3416.199951, -2150.399902, -3379., -3360.400146, -3341.800049, -1184., -1441.199951, -1432.599854, -1424., -500.799927, 822.000122, 819.200195, 816.400146, -486.400024, 810.799927, 808.000244, 805.200073, -472., 799.60022, 796.799683, 794.000122, -1075.199951, -1303.599854, -1295.000122, -1286.400024, -443.199951, 777.200073, 774.400024, 771.599854, -428.799927, 766., 763.200317, 760.400024, -414.400146, 754.800049, 752.000244, 749.200195, -966.400146, -1166.000122, -1157.400146, -1148.799927, -385.600098, 732.400024, 729.599976, 726.799927, -371.200134, 721.200012, 718.400146, 715.599792, -356.799988, 710.000183, 707.199951, 704.400024, -1571.199951, -2654.25, -2640.449951, -2626.649902, -2149.800049, -3338.299805, -3320.699951, -3303.100098, -2104.600098, -3267.900146, -3250.299805, -3232.699951, -2059.399902, -3197.5, -3179.900146, -3162.300049, -1092., -1257.699951, -1250.099854, -1242.499878, -318.799927, 1185.000122, 1180.200439, 1175.400146, -308.399902, 1165.800293, 1161.000122, 1156.200073, -298., 1146.599731, 1141.800049, 1137.000122, -991.199951, -1136.099976, -1128.500122, -1120.899902, -277.199951, 1108.199829, 1103.400146, 1098.599976, -266.799927, 1089.000366, 1084.199951, 1079.400024, -256.399902, 1069.799927, 1065.000122, 1060.200317, -890.400024, -1014.5, -1006.900024, -999.299988, -235.599976, 1031.399902, 1026.599854, 1021.800049, -225.199951, 1012.200195, 1007.400024, 1002.599854, -214.799805, 992.999878, 988.199707, 983.400146, -1523.199951, -2558.5, -2545.199951, -2531.899902, -2054.800049, -3148.800049, -3132.199951, -3115.599854, -2011.599976, -3082.400146, -3065.800049, -3049.199951, -1968.400024, -3016., -2999.400146, -2982.799805, -1000.000061, -1074.199951, -1067.599976, -1061.000244, -136.799805, 1548.000244, 1541.200195, 1534.400269, -130.400146, 1520.800171, 1514.000122, 1507.200073, -124., 1493.600098, 1486.799805, 1480.000244, -907.200073, -968.599976, -962.000122, -955.400085, -111.199951, 1439.200073, 1432.399902, 1425.599854, -104.800049, 1412.000122, 1405.200195, 1398.400024, -98.400024, 1384.799927, 1378.000366, 1371.200195, -814.400024, -862.999939, -856.399902, -849.799927, -85.599976, 1330.400024, 1323.599854, 1316.799927, -79.200073, 1303.200073, 1296.399902, 1289.599731, -72.799927, 1276., 1269.200195, 1262.400024, -1475.200073, -2462.75, -2449.949951, -2437.149902, -1959.800049, -2959.299805, -2943.699951, -2928.099854, -1918.599976, -2896.900146, -2881.300049, -2865.699951, -1877.399902, -2834.5, -2818.900146, -2803.300049, -907.999939, -890.700012, -885.099915, -879.499878, 45.199829, 1911., 1902.200073, 1893.400024, 47.599976, 1875.800293, 1867.000244, 1858.200073, 49.999878, 1840.599976, 1831.800171, 1823.000244, -823.200073, -801.100098, -795.500061, -789.900024, 54.799927, 1770.199951, 1761.400269, 1752.599976, 57.200073, 1735., 1726.200073, 1717.400269, 59.599976, 1699.799805, 1691., 1682.200073, -738.400024, -711.499817, -705.900085, -700.299927, 64.400146, 1629.399902, 1620.599976, 1611.800171, 66.800049, 1594.200195, 1585.39978, 1576.599976, 69.200073, 1559.000122, 1550.199829, 1541.400146, 1260.800049, 2211.5, 2228.800049, 2246.100098, 1921.200073, 3207.200195, 3231.800049, 3256.399902, 1980.400024, 3305.599854, 3330.200195, 3354.800049, 2039.599854, 3404., 3428.599854, 3453.200195, 1400., 2129.800049, 2144.400146, 2159., 1479.199951, 1588.000244, 1597.200073, 1606.400024, 1517.599976, 1624.800171, 1634., 1643.199951, 1556., 1661.600098, 1670.800171, 1679.999878, 1556.799927, 2363.400146, 2378., 2392.600098, 1632.799805, 1735.199951, 1744.400146, 1753.600098, 1671.199829, 1771.999878, 1781.200073, 1790.400024, 1709.60022, 1808.800171, 1818.000244, 1827.200073, 1713.599976, 2597., 2611.599854, 2626.199951, 1786.400024, 1882.400024, 1891.600098, 1900.800171, 1824.799805, 1919.200195, 1928.400146, 1937.600098, 1863.199951, 1956., 1965.199951, 1974.400391, 1228.800049, 2147.25, 2164.049805, 2180.850098, 1856.199951, 3076.700195, 3100.300049, 3123.899902, 1913.400024, 3171.099854, 3194.700195, 3218.300049, 1970.599976, 3265.5, 3289.099854, 3312.699951, 1332., 1993.300049, 2006.900146, 2020.499878, 1341.199951, 1310.999878, 1318.199951, 1325.400146, 1375.60022, 1339.800171, 1347., 1354.199951, 1410., 1368.600098, 1375.800171, 1383., 1480.800049, 2210.900146, 2224.5, 2238.100098, 1478.799805, 1426.200073, 1433.400146, 1440.599609, 1513.199951, 1455., 1462.199951, 1469.400024, 1547.60022, 1483.799927, 1490.999878, 1498.199951, 1629.599976, 2428.500244, 2442.100098, 2455.699951, 1616.399902, 1541.400146, 1548.600098, 1555.799683, 1650.800049, 1570.200073, 1577.400024, 1584.600098, 1685.199951, 1598.99939, 1606.200317, 1613.400024, 1196.800049, 2083., 2099.300049, 2115.600098, 1791.200073, 2946.200195, 2968.800049, 2991.400146, 1846.400024, 3036.599854, 3059.200195, 3081.800049, 1901.599976, 3127., 3149.599854, 3172.200195, 1264., 1856.800049, 1869.400146, 1881.999878, 1203.200073, 1034., 1039.200073, 1044.400146, 1233.599976, 1054.799927, 1059.999878, 1065.199951, 1263.999878, 1075.599609, 1080.800171, 1086., 1404.799927, 2058.400146, 2071., 2083.599854, 1324.799927, 1117.199951, 1122.400146, 1127.599609, 1355.199951, 1138., 1143.200439, 1148.400146, 1385.599976, 1158.800171, 1164.000244, 1169.200073, 1545.599976, 2260., 2272.600098, 2285.199951, 1446.400024, 1200.400146, 1205.600098, 1210.800171, 1476.799805, 1221.199951, 1226.400024, 1231.600098, 1507.199951, 1242.000244, 1247.200073, 1252.400146, 1164.800049, 2018.75, 2034.549927, 2050.350098, 1726.200073, 2815.700195, 2837.300049, 2858.900146, 1779.400024, 2902.099854, 2923.700195, 2945.300049, 1832.599976, 2988.5, 3010.099854, 3031.700195, 1196.000122, 1720.300049, 1731.900146, 1743.499878, 1065.200073, 757.000122, 760.200073, 763.400024, 1091.599976, 769.800171, 773., 776.199951, 1118., 782.599976, 785.800049, 789., 1328.800049, 1905.900146, 1917.499878, 1929.100098, 1170.799805, 808.200073, 811.400024, 814.60022, 1197.199951, 821., 824.199951, 827.400024, 1223.599976, 833.799927, 837.000244, 840.199951, 1461.599976, 2091.5, 2103.100098, 2114.700195, 1276.400146, 859.400024, 862.600098, 865.800293, 1302.799927, 872.200073, 875.400146, 878.599854, 1329.199951, 885., 888.199951, 891.400024, 1132.800049, 1954.500122, 1969.799927, 1985.099976, 1661.199951, 2685.200195, 2705.800049, 2726.399902, 1712.399902, 2767.599854, 2788.200195, 2808.800049, 1763.599976, 2850., 2870.599854, 2891.199951, 1128., 1583.800049, 1594.400146, 1605., 927.200012, 480., 481.199951, 482.400146, 949.599976, 484.800171, 486., 487.200073, 971.999878, 489.599731, 490.800171, 492.000122, 1252.799927, 1753.400146, 1763.999878, 1774.600098, 1016.799805, 499.200195, 500.400024, 501.60022, 1039.199951, 504., 505.199951, 506.400146, 1061.599976, 508.799927, 510., 511.200195, 1377.599976, 1923.000122, 1933.600098, 1944.200073, 1106.400024, 518.400024, 519.60022, 520.800171, 1128.799927, 523.199829, 524.400024, 525.600098, 1151.199829, 528., 529.199829, 530.400146, 1100.800049, 1890.25, 1905.050049, 1919.849976, 1596.199951, 2554.700195, 2574.300049, 2593.900146, 1645.399902, 2633.099854, 2652.700195, 2672.300049, 1694.599976, 2711.5, 2731.099854, 2750.700195, 1060., 1447.299805, 1456.900146, 1466.499878, 789.200012, 203.000122, 202.200195, 201.400146, 807.600098, 199.800171, 199., 198.200195, 826., 196.599731, 195.800049, 195., 1176.799927, 1600.900146, 1610.500244, 1620.099854, 862.80011, 190.200317, 189.400146, 188.60022, 881.199951, 187., 186.199829, 185.400024, 899.60022, 183.800171, 183., 182.200073, 1293.599976, 1754.499878, 1764.099854, 1773.700073, 936.400024, 177.400146, 176.60022, 175.800049, 954.799805, 174.199951, 173.400024, 172.599854, 973.200073, 171., 170.200073, 169.400146, 1068.800049, 1826., 1840.299927, 1854.599976, 1531.199951, 2424.200195, 2442.800049, 2461.399902, 1578.399902, 2498.599854, 2517.199951, 2535.800049, 1625.599976, 2573., 2591.599854, 2610.200195, 991.999939, 1310.800049, 1319.400146, 1328., 651.199951, -74., -76.799805, -79.599854, 665.600098, -85.199829, -87.999756, -90.799805, 680., -96.400024, -99.199829, -102., 1100.800049, 1448.400146, 1456.999878, 1465.600098, 708.800049, -118.799805, -121.599976, -124.400269, 723.199829, -130., -132.800171, -135.599976, 737.599976, -141.200073, -144., -146.799805, 1209.599976, 1586., 1594.600098, 1603.200073, 766.400146, -163.599976, -166.39978, -169.200073, 780.800049, -174.799927, -177.599976, -180.400146, 795.199951, -185.999878, -188.800171, -191.599854, 1036.800049, 1761.75, 1775.550049, 1789.349976, 1466.200073, 2293.700195, 2311.300049, 2328.900146, 1511.399902, 2364.099854, 2381.700195, 2399.300049, 1556.599976, 2434.5, 2452.099854, 2469.700195, 923.999939, 1174.300049, 1181.899902, 1189.5, 513.200073, -350.999756, -355.799805, -360.599854, 523.599976, -370.199951, -374.999939, -379.799805, 534., -389.400146, -394.19989, -398.999817, 1024.800049, 1295.900146, 1303.5, 1311.10022, 554.799927, -427.800171, -432.599854, -437.400146, 565.199951, -446.999878, -451.799805, -456.599854, 575.599976, -466.200317, -470.999756, -475.799805, 1125.599976, 1417.499878, 1425.100098, 1432.700073, 596.400024, -504.599854, -509.400269, -514.199951, 606.800049, -523.800171, -528.599609, -533.400146, 617.200073, -542.999878, -547.800171, -552.599854, 1004.800049, 1697.5, 1710.799927, 1724.099976, 1401.199951, 2163.200195, 2179.800049, 2196.400146, 1444.400024, 2229.599854, 2246.200195, 2262.800049, 1487.599976, 2296., 2312.599854, 2329.200195, 855.999939, 1037.800049, 1044.400146, 1051., 375.199951, -627.999756, -634.800171, -641.599976, 381.599976, -655.199829, -661.999878, -668.80011, 388.000061, -682.400146, -689.199951, -695.999756, 948.799988, 1143.400146, 1149.999878, 1156.60022, 400.799805, -736.799927, -743.599976, -750.399902, 407.200073, -763.999878, -770.799805, -777.599731, 413.599976, -791.200073, -797.999756, -804.800171, 1041.599976, 1248.999878, 1255.60022, 1262.200073, 426.399902, -845.599854, -852.400146, -859.200073, 432.799927, -872.799805, -879.599854, -886.400024, 439.200073, -899.999878, -906.799927, -913.599976, 972.800049, 1633.25, 1646.049927, 1658.850098, 1336.200073, 2032.700195, 2048.300049, 2063.900146, 1377.400024, 2095.099854, 2110.700195, 2126.300049, 1418.599976, 2157.5, 2173.099854, 2188.700195, 787.999939, 901.299988, 906.899963, 912.500061, 237.200012, -904.999817, -913.799866, -922.599792, 239.599976, -940.199707, -948.999817, -957.800171, 242., -975.400146, -984.199829, -992.999756, 872.799988, 990.899963, 996.499878, 1002.10022, 246.800049, -1045.799927, -1054.599854, -1063.400024, 249.200073, -1080.999878, -1089.799805, -1098.599854, 251.600098, -1116.199951, -1124.999878, -1133.799683, 957.599976, 1080.499878, 1086.10022, 1091.700073, 256.400024, -1186.599854, -1195.400146, -1204.199829, 258.799927, -1221.800171, -1230.599976, -1239.400269, 261.199951, -1257., -1265.799927, -1274.600098}, sd::DataType::FLOAT32); - - input.linspace(-32, 0.1); - - sd::ops::deconv3d op; - auto results = op.evaluate({&input, &weights}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, oD = 4, oH = 4, oW = 4, iC = 5, oC = 10, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 1, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int iD = 4, iH = 4, iW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - + // [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {iC, kD, kH, kW, oC}, + {20., 19.5, 19., 18.5, 18., 17.5, + 17., 16.5, 16., 15.5, 15., 14.5, + 14., 13.5, 13., 12.5, 12., 11.5, + 11., 10.5, 10., 9.5, 9., 8.5, + 8., 7.5, 7., 6.5, 6., 5.5, + 5., 4.5, 4., 3.5, 3., 2.5, + 2., 1.5, 1., 0.5, 0., -0.5, + -1., -1.5, -2., -2.5, -3., -3.5, + -4., -4.5, -5., -5.5, -6., -6.5, + -7., -7.5, -8., -8.5, -9., -9.5, + -10., -10.5, -11., -11.5, -12., -12.5, + -13., -13.5, -14., -14.5, -15., -15.5, + -16., -16.5, -17., -17.5, -18., -18.5, + -19., -19.5, 19.9, 19.4, 18.9, 18.4, + 17.9, 17.4, 16.9, 16.4, 15.9, 15.4, + 14.9, 14.4, 13.9, 13.4, 12.9, 12.4, + 11.9, 11.4, 10.9, 10.4, 9.9, 9.4, + 8.9, 8.4, 7.9, 7.4, 6.9, 6.4, + 5.9, 5.4, 4.9, 4.4, 3.9, 3.4, + 2.9, 2.4, 1.9, 1.4, 0.9, 0.4, + -0.1, -0.6, -1.1, -1.6, -2.1, -2.6, + -3.1, -3.6, -4.1, -4.6, -5.1, -5.6, + -6.1, -6.6, -7.1, -7.6, -8.1, -8.6, + -9.1, -9.6, -10.1, -10.6, -11.1, -11.6, + -12.1, -12.6, -13.1, -13.6, -14.1, -14.6, + -15.1, -15.6, -16.1, -16.6, -17.1, -17.6, + -18.1, -18.6, -19.1, -19.6, 19.799999, 19.299999, + 18.799999, 18.299999, 17.799999, 17.299999, 16.799999, 16.299999, + 15.8, 15.3, 14.8, 14.3, 13.8, 13.3, + 12.8, 12.3, 11.8, 11.3, 10.8, 10.3, + 9.8, 9.3, 8.8, 8.3, 7.8, 7.3, + 6.8, 6.3, 5.8, 5.3, 4.8, 4.3, + 3.8, 3.3, 2.8, 2.3, 1.8, 1.3, + 0.8, 0.3, -0.2, -0.7, -1.2, -1.7, + -2.2, -2.7, -3.2, -3.7, -4.2, -4.7, + -5.2, -5.7, -6.2, -6.7, -7.2, -7.7, + -8.2, -8.7, -9.2, -9.7, -10.2, -10.7, + -11.2, -11.7, -12.2, -12.7, -13.2, -13.7, + -14.2, -14.7, -15.2, -15.7, -16.200001, -16.700001, + -17.200001, -17.700001, -18.200001, -18.700001, -19.200001, -19.700001, + 19.700001, 19.200001, 18.700001, 18.200001, 17.700001, 17.200001, + 16.700001, 16.200001, 15.7, 15.2, 14.7, 14.2, + 13.7, 13.2, 12.7, 12.2, 11.7, 11.2, + 10.7, 10.2, 9.7, 9.2, 8.7, 8.2, + 7.7, 7.2, 6.7, 6.2, 5.7, 5.2, + 4.7, 4.2, 3.7, 3.2, 2.7, 2.2, + 1.7, 1.2, 0.7, 0.2, -0.3, -0.8, + -1.3, -1.8, -2.3, -2.8, -3.3, -3.8, + -4.3, -4.8, -5.3, -5.8, -6.3, -6.8, + -7.3, -7.8, -8.3, -8.8, -9.3, -9.8, + -10.3, -10.8, -11.3, -11.8, -12.3, -12.8, + -13.3, -13.8, -14.3, -14.8, -15.3, -15.8, + -16.299999, -16.799999, -17.299999, -17.799999, -18.299999, -18.799999, + -19.299999, -19.799999, 19.6, 19.1, 18.6, 18.1, + 17.6, 17.1, 16.6, 16.1, 15.6, 15.1, + 14.6, 14.1, 13.6, 13.1, 12.6, 12.1, + 11.6, 11.1, 10.6, 10.1, 9.6, 9.1, + 8.6, 8.1, 7.6, 7.1, 6.6, 6.1, + 5.6, 5.1, 4.6, 4.1, 3.6, 3.1, + 2.6, 2.1, 1.6, 1.1, 0.6, 0.1, + -0.4, -0.9, -1.4, -1.9, -2.4, -2.9, + -3.4, -3.9, -4.4, -4.9, -5.4, -5.9, + -6.4, -6.9, -7.4, -7.9, -8.4, -8.9, + -9.4, -9.9, -10.4, -10.9, -11.4, -11.9, + -12.4, -12.9, -13.4, -13.9, -14.4, -14.9, + -15.4, -15.9, -16.4, -16.9, -17.4, -17.9, + -18.4, -18.9, -19.4, -19.9}, + sd::DataType::FLOAT32); + NDArray expOutput( + 'c', {bS, oC, oD, oH, oW}, + {-1907.199951, -3324.499756, -3307.199707, -3289.899902, -2814.799805, + -4664.800293, -4640.199707, -4615.600098, -2755.599854, -4566.400391, + -4541.800293, -4517.199707, -2696.400146, -4468., -4443.400391, + -4418.799805, -1735.999878, -2542.199951, -2527.600098, -2513., + -1592.800049, -1355.999756, -1346.799805, -1337.599854, -1554.400024, + -1319.199829, -1310.000122, -1300.800049, -1516., -1282.400024, + -1273.200195, -1263.999878, -1579.200073, -2308.599854, -2294., + -2279.400146, -1439.199951, -1208.799683, -1199.599976, -1190.399902, + -1400.800049, -1172., -1162.800049, -1153.600098, -1362.399902, + -1135.199951, -1126., -1116.799805, -1422.400024, -2075., + -2060.399902, -2045.799683, -1285.599976, -1061.599854, -1052.399902, + -1043.200195, -1247.199951, -1024.800049, -1015.599976, -1006.400146, + -1208.799927, -988.000122, -978.799683, -969.599976, -1859.199951, + -3228.75, -3211.949951, -3195.150146, -2719.800049, -4475.299805, + -4451.699707, -4428.100098, -2662.600098, -4380.899902, -4357.300293, + -4333.699707, -2605.399902, -4286.5, -4262.899902, -4239.300293, + -1643.999878, -2358.700195, -2345.099854, -2331.5, -1410.800049, + -992.999756, -985.799438, -978.600098, -1376.400024, -964.199707, + -957., -949.800049, -1342., -935.399902, -928.199951, + -921.000122, -1495.200073, -2141.099854, -2127.5, -2113.900391, + -1273.199951, -877.799683, -870.599976, -863.39978, -1238.800049, + -849., -841.800171, -834.599976, -1204.400024, -820.199707, + -813., -805.799438, -1346.400146, -1923.500122, -1909.899902, + -1896.299927, -1135.599976, -762.599976, -755.399658, -748.200195, + -1101.199951, -733.800049, -726.599854, -719.400024, -1066.800049, + -705., -697.800171, -690.599976, -1811.199951, -3133., + -3116.699951, -3100.399902, -2624.799805, -4285.799805, -4263.199707, + -4240.600098, -2569.600098, -4195.399902, -4172.800293, -4150.199707, + -2514.399902, -4105., -4082.400146, -4059.800293, -1552., + -2175.200195, -2162.599854, -2150., -1228.800049, -630., + -624.799561, -619.599854, -1198.400024, -609.199463, -603.999756, + -598.800049, -1167.999878, -588.400391, -583.199951, -578., + -1411.200073, -1973.599854, -1961.000122, -1948.400146, -1107.199829, + -546.800171, -541.599976, -536.400269, -1076.800049, -525.999756, + -520.800049, -515.599976, -1046.400146, -505.199829, -500., + -494.799683, -1270.399902, -1772., -1759.400146, -1746.799927, + -985.599976, -463.600098, -458.399902, -453.199951, -955.199951, + -442.799927, -437.599976, -432.400269, -924.799988, -422.000122, + -416.800171, -411.599976, -1763.199951, -3037.25, -3021.449951, + -3005.649902, -2529.800293, -4096.299805, -4074.699951, -4053.100098, + -2476.600098, -4009.900146, -3988.300049, -3966.699951, -2423.399902, + -3923.5, -3901.899902, -3880.299805, -1459.999878, -1991.699951, + -1980.099854, -1968.500122, -1046.800049, -266.999878, -263.799805, + -260.599854, -1020.400146, -254.199829, -251., -247.799927, + -994., -241.400269, -238.200073, -234.999878, -1327.200073, + -1806.099854, -1794.500122, -1782.900146, -941.199951, -215.799927, + -212.600098, -209.399902, -914.799988, -203.000122, -199.799683, + -196.599976, -888.400024, -190.200317, -186.999878, -183.799805, + -1194.399902, -1620.500122, -1608.899902, -1597.299927, -835.599915, + -164.599976, -161.400269, -158.200195, -809.200073, -151.799927, + -148.599976, -145.400024, -782.799927, -139., -135.799805, + -132.599976, -1715.200073, -2941.5, -2926.199951, -2910.899902, + -2434.800049, -3906.799805, -3886.199951, -3865.599609, -2383.600098, + -3824.400391, -3803.800049, -3783.199951, -2332.400146, -3742., + -3721.400146, -3700.799805, -1367.999878, -1808.199707, -1797.599854, + -1786.999878, -864.800049, 95.999878, 97.200073, 98.400024, + -842.39978, 100.799927, 102.000244, 103.200439, -820., + 105.599609, 106.800171, 108., -1243.199951, -1638.599854, + -1628.000122, -1617.400146, -775.199829, 115.200195, 116.400146, + 117.60022, -752.799805, 120., 121.200073, 122.400024, + -730.399841, 124.799927, 125.999878, 127.199951, -1118.400024, + -1468.999878, -1458.400146, -1447.799927, -685.599915, 134.400146, + 135.60022, 136.800171, -663.199951, 139.200073, 140.399902, + 141.599731, -640.799988, 144., 145.200195, 146.400146, + -1667.199951, -2845.749756, -2830.949707, -2816.149902, -2339.799805, + -3717.300049, -3697.699951, -3678.100098, -2290.600098, -3638.900146, + -3619.300049, -3599.699951, -2241.399902, -3560.5, -3540.899902, + -3521.299805, -1276., -1624.699951, -1615.100098, -1605.499878, + -682.799927, 459.000122, 458.199951, 457.400146, -664.400024, + 455.800049, 454.999878, 454.200439, -646.000122, 452.599976, + 451.799805, 451.000122, -1159.200073, -1471.099854, -1461.5, + -1451.900146, -609.199829, 446.200195, 445.400024, 444.600098, + -590.799927, 443., 442.200073, 441.399658, -572.39978, + 439.799927, 439.000122, 438.200073, -1042.399902, -1317.499756, + -1307.900146, -1298.299683, -535.599976, 433.399963, 432.600098, + 431.799744, -517.200012, 430.200195, 429.400024, 428.599976, + -498.799927, 427.000061, 426.200256, 425.400024, -1619.199951, + -2750., -2735.699951, -2721.399902, -2244.799805, -3527.799805, + -3509.199951, -3490.600098, -2197.600098, -3453.400146, -3434.800049, + -3416.199951, -2150.399902, -3379., -3360.400146, -3341.800049, + -1184., -1441.199951, -1432.599854, -1424., -500.799927, + 822.000122, 819.200195, 816.400146, -486.400024, 810.799927, + 808.000244, 805.200073, -472., 799.60022, 796.799683, + 794.000122, -1075.199951, -1303.599854, -1295.000122, -1286.400024, + -443.199951, 777.200073, 774.400024, 771.599854, -428.799927, + 766., 763.200317, 760.400024, -414.400146, 754.800049, + 752.000244, 749.200195, -966.400146, -1166.000122, -1157.400146, + -1148.799927, -385.600098, 732.400024, 729.599976, 726.799927, + -371.200134, 721.200012, 718.400146, 715.599792, -356.799988, + 710.000183, 707.199951, 704.400024, -1571.199951, -2654.25, + -2640.449951, -2626.649902, -2149.800049, -3338.299805, -3320.699951, + -3303.100098, -2104.600098, -3267.900146, -3250.299805, -3232.699951, + -2059.399902, -3197.5, -3179.900146, -3162.300049, -1092., + -1257.699951, -1250.099854, -1242.499878, -318.799927, 1185.000122, + 1180.200439, 1175.400146, -308.399902, 1165.800293, 1161.000122, + 1156.200073, -298., 1146.599731, 1141.800049, 1137.000122, + -991.199951, -1136.099976, -1128.500122, -1120.899902, -277.199951, + 1108.199829, 1103.400146, 1098.599976, -266.799927, 1089.000366, + 1084.199951, 1079.400024, -256.399902, 1069.799927, 1065.000122, + 1060.200317, -890.400024, -1014.5, -1006.900024, -999.299988, + -235.599976, 1031.399902, 1026.599854, 1021.800049, -225.199951, + 1012.200195, 1007.400024, 1002.599854, -214.799805, 992.999878, + 988.199707, 983.400146, -1523.199951, -2558.5, -2545.199951, + -2531.899902, -2054.800049, -3148.800049, -3132.199951, -3115.599854, + -2011.599976, -3082.400146, -3065.800049, -3049.199951, -1968.400024, + -3016., -2999.400146, -2982.799805, -1000.000061, -1074.199951, + -1067.599976, -1061.000244, -136.799805, 1548.000244, 1541.200195, + 1534.400269, -130.400146, 1520.800171, 1514.000122, 1507.200073, + -124., 1493.600098, 1486.799805, 1480.000244, -907.200073, + -968.599976, -962.000122, -955.400085, -111.199951, 1439.200073, + 1432.399902, 1425.599854, -104.800049, 1412.000122, 1405.200195, + 1398.400024, -98.400024, 1384.799927, 1378.000366, 1371.200195, + -814.400024, -862.999939, -856.399902, -849.799927, -85.599976, + 1330.400024, 1323.599854, 1316.799927, -79.200073, 1303.200073, + 1296.399902, 1289.599731, -72.799927, 1276., 1269.200195, + 1262.400024, -1475.200073, -2462.75, -2449.949951, -2437.149902, + -1959.800049, -2959.299805, -2943.699951, -2928.099854, -1918.599976, + -2896.900146, -2881.300049, -2865.699951, -1877.399902, -2834.5, + -2818.900146, -2803.300049, -907.999939, -890.700012, -885.099915, + -879.499878, 45.199829, 1911., 1902.200073, 1893.400024, + 47.599976, 1875.800293, 1867.000244, 1858.200073, 49.999878, + 1840.599976, 1831.800171, 1823.000244, -823.200073, -801.100098, + -795.500061, -789.900024, 54.799927, 1770.199951, 1761.400269, + 1752.599976, 57.200073, 1735., 1726.200073, 1717.400269, + 59.599976, 1699.799805, 1691., 1682.200073, -738.400024, + -711.499817, -705.900085, -700.299927, 64.400146, 1629.399902, + 1620.599976, 1611.800171, 66.800049, 1594.200195, 1585.39978, + 1576.599976, 69.200073, 1559.000122, 1550.199829, 1541.400146, + 1260.800049, 2211.5, 2228.800049, 2246.100098, 1921.200073, + 3207.200195, 3231.800049, 3256.399902, 1980.400024, 3305.599854, + 3330.200195, 3354.800049, 2039.599854, 3404., 3428.599854, + 3453.200195, 1400., 2129.800049, 2144.400146, 2159., + 1479.199951, 1588.000244, 1597.200073, 1606.400024, 1517.599976, + 1624.800171, 1634., 1643.199951, 1556., 1661.600098, + 1670.800171, 1679.999878, 1556.799927, 2363.400146, 2378., + 2392.600098, 1632.799805, 1735.199951, 1744.400146, 1753.600098, + 1671.199829, 1771.999878, 1781.200073, 1790.400024, 1709.60022, + 1808.800171, 1818.000244, 1827.200073, 1713.599976, 2597., + 2611.599854, 2626.199951, 1786.400024, 1882.400024, 1891.600098, + 1900.800171, 1824.799805, 1919.200195, 1928.400146, 1937.600098, + 1863.199951, 1956., 1965.199951, 1974.400391, 1228.800049, + 2147.25, 2164.049805, 2180.850098, 1856.199951, 3076.700195, + 3100.300049, 3123.899902, 1913.400024, 3171.099854, 3194.700195, + 3218.300049, 1970.599976, 3265.5, 3289.099854, 3312.699951, + 1332., 1993.300049, 2006.900146, 2020.499878, 1341.199951, + 1310.999878, 1318.199951, 1325.400146, 1375.60022, 1339.800171, + 1347., 1354.199951, 1410., 1368.600098, 1375.800171, + 1383., 1480.800049, 2210.900146, 2224.5, 2238.100098, + 1478.799805, 1426.200073, 1433.400146, 1440.599609, 1513.199951, + 1455., 1462.199951, 1469.400024, 1547.60022, 1483.799927, + 1490.999878, 1498.199951, 1629.599976, 2428.500244, 2442.100098, + 2455.699951, 1616.399902, 1541.400146, 1548.600098, 1555.799683, + 1650.800049, 1570.200073, 1577.400024, 1584.600098, 1685.199951, + 1598.99939, 1606.200317, 1613.400024, 1196.800049, 2083., + 2099.300049, 2115.600098, 1791.200073, 2946.200195, 2968.800049, + 2991.400146, 1846.400024, 3036.599854, 3059.200195, 3081.800049, + 1901.599976, 3127., 3149.599854, 3172.200195, 1264., + 1856.800049, 1869.400146, 1881.999878, 1203.200073, 1034., + 1039.200073, 1044.400146, 1233.599976, 1054.799927, 1059.999878, + 1065.199951, 1263.999878, 1075.599609, 1080.800171, 1086., + 1404.799927, 2058.400146, 2071., 2083.599854, 1324.799927, + 1117.199951, 1122.400146, 1127.599609, 1355.199951, 1138., + 1143.200439, 1148.400146, 1385.599976, 1158.800171, 1164.000244, + 1169.200073, 1545.599976, 2260., 2272.600098, 2285.199951, + 1446.400024, 1200.400146, 1205.600098, 1210.800171, 1476.799805, + 1221.199951, 1226.400024, 1231.600098, 1507.199951, 1242.000244, + 1247.200073, 1252.400146, 1164.800049, 2018.75, 2034.549927, + 2050.350098, 1726.200073, 2815.700195, 2837.300049, 2858.900146, + 1779.400024, 2902.099854, 2923.700195, 2945.300049, 1832.599976, + 2988.5, 3010.099854, 3031.700195, 1196.000122, 1720.300049, + 1731.900146, 1743.499878, 1065.200073, 757.000122, 760.200073, + 763.400024, 1091.599976, 769.800171, 773., 776.199951, + 1118., 782.599976, 785.800049, 789., 1328.800049, + 1905.900146, 1917.499878, 1929.100098, 1170.799805, 808.200073, + 811.400024, 814.60022, 1197.199951, 821., 824.199951, + 827.400024, 1223.599976, 833.799927, 837.000244, 840.199951, + 1461.599976, 2091.5, 2103.100098, 2114.700195, 1276.400146, + 859.400024, 862.600098, 865.800293, 1302.799927, 872.200073, + 875.400146, 878.599854, 1329.199951, 885., 888.199951, + 891.400024, 1132.800049, 1954.500122, 1969.799927, 1985.099976, + 1661.199951, 2685.200195, 2705.800049, 2726.399902, 1712.399902, + 2767.599854, 2788.200195, 2808.800049, 1763.599976, 2850., + 2870.599854, 2891.199951, 1128., 1583.800049, 1594.400146, + 1605., 927.200012, 480., 481.199951, 482.400146, + 949.599976, 484.800171, 486., 487.200073, 971.999878, + 489.599731, 490.800171, 492.000122, 1252.799927, 1753.400146, + 1763.999878, 1774.600098, 1016.799805, 499.200195, 500.400024, + 501.60022, 1039.199951, 504., 505.199951, 506.400146, + 1061.599976, 508.799927, 510., 511.200195, 1377.599976, + 1923.000122, 1933.600098, 1944.200073, 1106.400024, 518.400024, + 519.60022, 520.800171, 1128.799927, 523.199829, 524.400024, + 525.600098, 1151.199829, 528., 529.199829, 530.400146, + 1100.800049, 1890.25, 1905.050049, 1919.849976, 1596.199951, + 2554.700195, 2574.300049, 2593.900146, 1645.399902, 2633.099854, + 2652.700195, 2672.300049, 1694.599976, 2711.5, 2731.099854, + 2750.700195, 1060., 1447.299805, 1456.900146, 1466.499878, + 789.200012, 203.000122, 202.200195, 201.400146, 807.600098, + 199.800171, 199., 198.200195, 826., 196.599731, + 195.800049, 195., 1176.799927, 1600.900146, 1610.500244, + 1620.099854, 862.80011, 190.200317, 189.400146, 188.60022, + 881.199951, 187., 186.199829, 185.400024, 899.60022, + 183.800171, 183., 182.200073, 1293.599976, 1754.499878, + 1764.099854, 1773.700073, 936.400024, 177.400146, 176.60022, + 175.800049, 954.799805, 174.199951, 173.400024, 172.599854, + 973.200073, 171., 170.200073, 169.400146, 1068.800049, + 1826., 1840.299927, 1854.599976, 1531.199951, 2424.200195, + 2442.800049, 2461.399902, 1578.399902, 2498.599854, 2517.199951, + 2535.800049, 1625.599976, 2573., 2591.599854, 2610.200195, + 991.999939, 1310.800049, 1319.400146, 1328., 651.199951, + -74., -76.799805, -79.599854, 665.600098, -85.199829, + -87.999756, -90.799805, 680., -96.400024, -99.199829, + -102., 1100.800049, 1448.400146, 1456.999878, 1465.600098, + 708.800049, -118.799805, -121.599976, -124.400269, 723.199829, + -130., -132.800171, -135.599976, 737.599976, -141.200073, + -144., -146.799805, 1209.599976, 1586., 1594.600098, + 1603.200073, 766.400146, -163.599976, -166.39978, -169.200073, + 780.800049, -174.799927, -177.599976, -180.400146, 795.199951, + -185.999878, -188.800171, -191.599854, 1036.800049, 1761.75, + 1775.550049, 1789.349976, 1466.200073, 2293.700195, 2311.300049, + 2328.900146, 1511.399902, 2364.099854, 2381.700195, 2399.300049, + 1556.599976, 2434.5, 2452.099854, 2469.700195, 923.999939, + 1174.300049, 1181.899902, 1189.5, 513.200073, -350.999756, + -355.799805, -360.599854, 523.599976, -370.199951, -374.999939, + -379.799805, 534., -389.400146, -394.19989, -398.999817, + 1024.800049, 1295.900146, 1303.5, 1311.10022, 554.799927, + -427.800171, -432.599854, -437.400146, 565.199951, -446.999878, + -451.799805, -456.599854, 575.599976, -466.200317, -470.999756, + -475.799805, 1125.599976, 1417.499878, 1425.100098, 1432.700073, + 596.400024, -504.599854, -509.400269, -514.199951, 606.800049, + -523.800171, -528.599609, -533.400146, 617.200073, -542.999878, + -547.800171, -552.599854, 1004.800049, 1697.5, 1710.799927, + 1724.099976, 1401.199951, 2163.200195, 2179.800049, 2196.400146, + 1444.400024, 2229.599854, 2246.200195, 2262.800049, 1487.599976, + 2296., 2312.599854, 2329.200195, 855.999939, 1037.800049, + 1044.400146, 1051., 375.199951, -627.999756, -634.800171, + -641.599976, 381.599976, -655.199829, -661.999878, -668.80011, + 388.000061, -682.400146, -689.199951, -695.999756, 948.799988, + 1143.400146, 1149.999878, 1156.60022, 400.799805, -736.799927, + -743.599976, -750.399902, 407.200073, -763.999878, -770.799805, + -777.599731, 413.599976, -791.200073, -797.999756, -804.800171, + 1041.599976, 1248.999878, 1255.60022, 1262.200073, 426.399902, + -845.599854, -852.400146, -859.200073, 432.799927, -872.799805, + -879.599854, -886.400024, 439.200073, -899.999878, -906.799927, + -913.599976, 972.800049, 1633.25, 1646.049927, 1658.850098, + 1336.200073, 2032.700195, 2048.300049, 2063.900146, 1377.400024, + 2095.099854, 2110.700195, 2126.300049, 1418.599976, 2157.5, + 2173.099854, 2188.700195, 787.999939, 901.299988, 906.899963, + 912.500061, 237.200012, -904.999817, -913.799866, -922.599792, + 239.599976, -940.199707, -948.999817, -957.800171, 242., + -975.400146, -984.199829, -992.999756, 872.799988, 990.899963, + 996.499878, 1002.10022, 246.800049, -1045.799927, -1054.599854, + -1063.400024, 249.200073, -1080.999878, -1089.799805, -1098.599854, + 251.600098, -1116.199951, -1124.999878, -1133.799683, 957.599976, + 1080.499878, 1086.10022, 1091.700073, 256.400024, -1186.599854, + -1195.400146, -1204.199829, 258.799927, -1221.800171, -1230.599976, + -1239.400269, 261.199951, -1257., -1265.799927, -1274.600098}, + sd::DataType::FLOAT32); + + input.linspace(-32, 0.1); + + sd::ops::deconv3d op; + auto results = op.evaluate({&input, &weights}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_bp_test1) { - - int bS=1, iD=3,iH=3,iW=3, iC=1,oC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto bias = NDArrayFactory::create('c', {iC}); - auto gradO = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - - NDArray expGradI('c', {bS, oD, oH, oW, oC}, {62., 67.6, 68.4, 74.8, 81.2, 89.2, 87.6, 96.4, 119.6, 132.4, 126., 139.6, 138.8, 154., 145.2, 161.2}, sd::DataType::FLOAT32); - NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28., 32., 32., 40., 40., 44., 44., 64, 64., 68., 68., 76., 76., 80., 80.}, sd::DataType::FLOAT32); - NDArray expGradB('c', {iC}, std::vector{364.5}, sd::DataType::FLOAT32); - - input = 0.5; - weights.linspace(0.1, 0.1); - gradO.linspace(0.5); - - sd::ops::deconv3d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); - - } + int bS = 1, iD = 3, iH = 3, iW = 3, iC = 1, oC = 2, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {iC}); + auto gradO = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + + NDArray expGradI('c', {bS, oD, oH, oW, oC}, + {62., 67.6, 68.4, 74.8, 81.2, 89.2, 87.6, 96.4, 119.6, 132.4, + 126., 139.6, 138.8, 154., 145.2, 161.2}, + sd::DataType::FLOAT32); + NDArray expGradW('c', {kD, kH, kW, iC, oC}, + {28., 28., 32., 32., 40., 40., 44., 44., 64, 64., 68., 68., + 76., 76., 80., 80.}, + sd::DataType::FLOAT32); + NDArray expGradB('c', {iC}, std::vector{364.5}, + sd::DataType::FLOAT32); + + input = 0.5; + weights.linspace(0.1, 0.1); + gradO.linspace(0.5); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate( + {&input, &weights, &bias, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_bp_test2) { - - int bS=1, iD=2,iH=2,iW=2, iC=1,oC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto gradO = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - - NDArray expGradI('c', {bS, oD, oH, oW, oC}, {34, 37.2, 16.6, 18.4, 15.4, 17.4, 7.1, 8.2, 10.6, 13., 4.3, 5.6, 2.9, 4.3, 0.75, 1.5}, sd::DataType::FLOAT32); - NDArray expGradW('c', {kD, kH, kW, iC, oC}, {16, 16, 9, 9, 10, 10, 5.5, 5.5, 12, 12, 6.5, 6.5, 7, 7, 3.75, 3.75}, sd::DataType::FLOAT32); - - input = 0.5; - weights.linspace(0.1, 0.1); - gradO.linspace(0.5); - - sd::ops::deconv3d_bp op; - auto results = op.evaluate({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - + int bS = 1, iD = 2, iH = 2, iW = 2, iC = 1, oC = 2, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto gradO = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + + NDArray expGradI('c', {bS, oD, oH, oW, oC}, + {34, 37.2, 16.6, 18.4, 15.4, 17.4, 7.1, 8.2, 10.6, 13., 4.3, + 5.6, 2.9, 4.3, 0.75, 1.5}, + sd::DataType::FLOAT32); + NDArray expGradW( + 'c', {kD, kH, kW, iC, oC}, + {16, 16, 9, 9, 10, 10, 5.5, 5.5, 12, 12, 6.5, 6.5, 7, 7, 3.75, 3.75}, + sd::DataType::FLOAT32); + + input = 0.5; + weights.linspace(0.1, 0.1); + gradO.linspace(0.5); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate( + {&input, &weights, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_bp_test3) { - - int bS=1, iD=3,iH=3,iW=3, iC=1,oC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}, {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, 0.7f, 1.5f, 0.8f, 1.6f}); - auto gradO = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - - NDArray expGradI('c', {bS, oD, oH, oW, oC}, {33.8, 37.4, 44.6, 48.2, 66.2, 69.8, 77., 80.6, 77.25, 86.35, 104.55, 113.65, 159.15, 168.25, 186.45, 195.55}, sd::DataType::FLOAT32); - NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28, 32, 32, 40, 40, 44, 44, 64, 64, 68, 68, 76, 76, 80, 80.}, sd::DataType::FLOAT32); - - input = 0.5; - gradO.linspace(0.5); - - sd::ops::deconv3d_bp op; - auto results = op.evaluate({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - + int bS = 1, iD = 3, iH = 3, iW = 3, iC = 1, oC = 2, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + auto weights = NDArrayFactory::create( + 'c', {kD, kH, kW, iC, oC}, + {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, + 0.7f, 1.5f, 0.8f, 1.6f}); + auto gradO = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + + NDArray expGradI('c', {bS, oD, oH, oW, oC}, + {33.8, 37.4, 44.6, 48.2, 66.2, 69.8, 77., 80.6, 77.25, 86.35, + 104.55, 113.65, 159.15, 168.25, 186.45, 195.55}, + sd::DataType::FLOAT32); + NDArray expGradW( + 'c', {kD, kH, kW, iC, oC}, + {28., 28, 32, 32, 40, 40, 44, 44, 64, 64, 68, 68, 76, 76, 80, 80.}, + sd::DataType::FLOAT32); + + input = 0.5; + gradO.linspace(0.5); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate( + {&input, &weights, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_bp_test4) { - - int bS=1, iD=2,iH=2,iW=2, iC=1,oC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; - int oD=3,oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}, {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, 0.7f, 1.5f, 0.8f, 1.6f}); - auto gradO = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - - NDArray expGradI('c', {bS, oC, oD, oH, oW}, {0.4, 1.55, 1.05, 2.3, 5.7, 3.2, 1.5, 3.35, 1.75, 3.8, 8.3, 4.3, 9.0, 18.6, 9.2, 4.4, 8.7, 4.1, 1.8, 3.55, 1.65, 3.5, 6.5, 2.8, 1.3, 2.15, 0.75, 0.8, 3.15, 2.25, 4.7, 12.1, 7.2, 3.5, 8.15, 4.55, 7.8, 17.9, 9.9, 19.75, 42.85, 23.6, 9.35, 21.55, 12.9, 5.4, 11.55, 6.05, 8.25, 20.75, 13.2, 0.65, 6.6, 6.75}, sd::DataType::FLOAT32); - NDArray expGradW('c', {kD, kH, kW, iC, oC}, {16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.}, sd::DataType::FLOAT32); - - input = 0.5; - gradO.linspace(0.5); - - sd::ops::deconv3d_bp op; - auto results = op.evaluate({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); - - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); + int bS = 1, iD = 2, iH = 2, iW = 2, iC = 1, oC = 2, kD = 2, kH = 2, kW = 2, + sD = 1, sH = 1, sW = 1, pD = 1, pH = 1, pW = 1, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + auto weights = NDArrayFactory::create( + 'c', {kD, kH, kW, iC, oC}, + {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, + 0.7f, 1.5f, 0.8f, 1.6f}); + auto gradO = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + + NDArray expGradI( + 'c', {bS, oC, oD, oH, oW}, + {0.4, 1.55, 1.05, 2.3, 5.7, 3.2, 1.5, 3.35, 1.75, 3.8, 8.3, + 4.3, 9.0, 18.6, 9.2, 4.4, 8.7, 4.1, 1.8, 3.55, 1.65, 3.5, + 6.5, 2.8, 1.3, 2.15, 0.75, 0.8, 3.15, 2.25, 4.7, 12.1, 7.2, + 3.5, 8.15, 4.55, 7.8, 17.9, 9.9, 19.75, 42.85, 23.6, 9.35, 21.55, + 12.9, 5.4, 11.55, 6.05, 8.25, 20.75, 13.2, 0.65, 6.6, 6.75}, + sd::DataType::FLOAT32); + NDArray expGradW('c', {kD, kH, kW, iC, oC}, + {16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, + 16.0, 16.0, 16.0, 16.0, 16.0, 16.}, + sd::DataType::FLOAT32); + + input = 0.5; + gradO.linspace(0.5); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate( + {&input, &weights, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, dataFormat}, + {}); + + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_bp_test5) { - - int bS=2, iD=4,iH=4,iW=4, iC=3,oC=2, kD=2,kH=1,kW=1, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=4,oH=4,oW=4; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); - NDArray weights('c',{iC, oC, kD, kH, kW}, {-0.6, 0., -0.3, 0.3, -0.5, 0.1, -0.2, 0.4, -0.4, 0.2, -0.1, 0.5}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oD, oH, oW},sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iC, iD, iH, iW}, {9.696001, 9.684001, 9.672001, 9.66, 9.648001, 9.636, 9.624001, 9.612, 9.600001, 9.587999, 9.576, 9.564001, 9.552, - 9.540001, 9.528, 9.516, 9.504001, 9.492, 9.480001, 9.468, 9.455999, 9.444, 9.432001, 9.420001, 9.408001, 9.396, 9.384001, 9.372001, 9.36, 9.348001, 9.335999, - 9.324001, 9.312, 9.300001, 9.288001, 9.276001, 9.264, 9.252001, 9.24, 9.228001, 9.216, 9.204, 9.191999, 9.18, 9.168001, 9.156, 9.144001, 9.132, 13.152, 13.134001, - 13.116, 13.098, 13.080001, 13.062, 13.044001, 13.026001, 13.008001, 12.990001, 12.972, 12.954, 12.936001, 12.918, 12.900002, 12.882, 3.616001, 3.612, 3.608, 3.604, - 3.6, 3.596, 3.592, 3.588, 3.584001, 3.579999, 3.576001, 3.571999, 3.568, 3.564, 3.56, 3.556, 3.552, 3.548, 3.544, 3.539999, 3.536001, 3.532001, 3.527999, 3.524001, 3.52, 3.516, 3.512, 3.508, 3.504, 3.5, 3.496, 3.492, 3.487999, 3.484001, 3.48, 3.476, 3.472, 3.468, 3.464, 3.46, 3.456, 3.452, 3.447999, 3.444001, 3.439999, 3.436, 3.432001, 3.428, 10.272, 10.258, 10.244, 10.23, 10.216, 10.202, 10.188, 10.174, 10.16, 10.146, 10.132, 10.118, 10.104, 10.09, 10.076, 10.062, -2.464, -2.460001, -2.455999, -2.452, -2.448, -2.444, -2.44, -2.436, -2.432, -2.428, -2.424, -2.42, -2.415999, -2.412, -2.408, -2.404, -2.4, -2.396, -2.392, -2.388, -2.384, -2.38, -2.376, -2.372, -2.368, -2.363999, -2.36, -2.356, -2.352, -2.348, -2.344, -2.34, -2.336, -2.332, -2.328001, -2.323999, -2.32, -2.316, -2.312, -2.308, -2.304, -2.3, -2.296, -2.292, -2.288, -2.283999, -2.28, -2.276, 7.392, 7.382, 7.372, 7.362, 7.352, 7.342, 7.332, 7.322, 7.312, 7.302, 7.292, 7.282, 7.272, 7.262, 7.252, 7.242, 8.16, 8.148001, 8.136001, 8.124001, 8.112, 8.1, 8.087999, 8.076, 8.063999, 8.052, 8.04, 8.028001, 8.016, 8.004001, 7.992001, 7.98, 7.968, 7.956, 7.944, 7.932001, 7.92, 7.908, 7.896, 7.884, 7.872001, 7.86, 7.848001, 7.835999, 7.824, 7.812, 7.800001, 7.788, 7.776, 7.764, 7.752, 7.740001, 7.728, 7.716001, 7.704, 7.692, 7.68, 7.668, 7.656, 7.644001, 7.632001, 7.62, 7.608001, 7.596001, 10.848, 10.830001, 10.812, 10.794001, 10.776, 10.758, 10.74, 10.722, 10.704, 10.686001, 10.668, 10.650001, 10.632, 10.614, 10.596001, 10.578001, 3.104, 3.1, 3.096, 3.092, 3.088, 3.084, 3.079999, 3.076001, 3.072, 3.068, 3.064, 3.06, 3.056, 3.052, 3.048, 3.044, 3.039999, 3.036001, 3.032, 3.028, 3.024001, 3.02, 3.016, 3.012, 3.008, 3.004, 3., 2.996, 2.992, 2.987999, 2.984001, 2.98, 2.976, 2.972, 2.968, 2.964, 2.96, 2.956, 2.952, 2.947999, 2.944001, 2.94, 2.936, 2.932001, 2.928, 2.924, 2.92, 2.916, 8.48, 8.466, 8.452, 8.438, 8.424, 8.41, 8.396, 8.382, 8.368, 8.354, 8.34, 8.326, 8.312, 8.298, 8.284, 8.27, -1.952, -1.948, -1.944, -1.94, -1.936, -1.932, -1.928, -1.924, -1.92, -1.916, -1.912, -1.908, -1.904, -1.9, -1.896, -1.892, -1.888, -1.884, -1.88, -1.876, -1.872, -1.868, -1.863999, -1.86, -1.856, -1.852, -1.848, -1.844, -1.84, -1.836, -1.832, -1.828, -1.823999, -1.82, -1.816, -1.812, -1.808, -1.804, -1.8, -1.796, -1.792, -1.788, -1.784, -1.78, -1.776, -1.771999, -1.768, -1.764, 6.112, 6.102, 6.092, 6.082, 6.072, 6.062, 6.052, 6.042, 6.032, 6.022, 6.012, 6.002, 5.992, 5.982, 5.972, 5.962}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {iC, oC, kD, kH, kW}, {-73678.695312, -59907.972656, -67739.515625, -54962.082031, -15966.075195, -17115.042969, -15269.777344, -16101.275391, 41746.566406, 25677.917969, 37200.003906, 22759.517578}, sd::DataType::FLOAT32); - NDArray expGradB('c', {oC}, {-1803.520020, -1639.679932}, sd::DataType::FLOAT32); - - input.linspace(100., -0.5); - gradO.linspace(-16, 0.02); - - sd::ops::deconv3d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); + int bS = 2, iD = 4, iH = 4, iW = 4, iC = 3, oC = 2, kD = 2, kH = 1, kW = 1, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 4, oW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - + // [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {iC, oC, kD, kH, kW}, + {-0.6, 0., -0.3, 0.3, -0.5, 0.1, -0.2, 0.4, -0.4, 0.2, -0.1, 0.5}, + sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oD, oH, oW}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iD, iH, iW}, + {9.696001, 9.684001, 9.672001, 9.66, 9.648001, 9.636, + 9.624001, 9.612, 9.600001, 9.587999, 9.576, 9.564001, + 9.552, 9.540001, 9.528, 9.516, 9.504001, 9.492, + 9.480001, 9.468, 9.455999, 9.444, 9.432001, 9.420001, + 9.408001, 9.396, 9.384001, 9.372001, 9.36, 9.348001, + 9.335999, 9.324001, 9.312, 9.300001, 9.288001, 9.276001, + 9.264, 9.252001, 9.24, 9.228001, 9.216, 9.204, + 9.191999, 9.18, 9.168001, 9.156, 9.144001, 9.132, + 13.152, 13.134001, 13.116, 13.098, 13.080001, 13.062, + 13.044001, 13.026001, 13.008001, 12.990001, 12.972, 12.954, + 12.936001, 12.918, 12.900002, 12.882, 3.616001, 3.612, + 3.608, 3.604, 3.6, 3.596, 3.592, 3.588, + 3.584001, 3.579999, 3.576001, 3.571999, 3.568, 3.564, + 3.56, 3.556, 3.552, 3.548, 3.544, 3.539999, + 3.536001, 3.532001, 3.527999, 3.524001, 3.52, 3.516, + 3.512, 3.508, 3.504, 3.5, 3.496, 3.492, + 3.487999, 3.484001, 3.48, 3.476, 3.472, 3.468, + 3.464, 3.46, 3.456, 3.452, 3.447999, 3.444001, + 3.439999, 3.436, 3.432001, 3.428, 10.272, 10.258, + 10.244, 10.23, 10.216, 10.202, 10.188, 10.174, + 10.16, 10.146, 10.132, 10.118, 10.104, 10.09, + 10.076, 10.062, -2.464, -2.460001, -2.455999, -2.452, + -2.448, -2.444, -2.44, -2.436, -2.432, -2.428, + -2.424, -2.42, -2.415999, -2.412, -2.408, -2.404, + -2.4, -2.396, -2.392, -2.388, -2.384, -2.38, + -2.376, -2.372, -2.368, -2.363999, -2.36, -2.356, + -2.352, -2.348, -2.344, -2.34, -2.336, -2.332, + -2.328001, -2.323999, -2.32, -2.316, -2.312, -2.308, + -2.304, -2.3, -2.296, -2.292, -2.288, -2.283999, + -2.28, -2.276, 7.392, 7.382, 7.372, 7.362, + 7.352, 7.342, 7.332, 7.322, 7.312, 7.302, + 7.292, 7.282, 7.272, 7.262, 7.252, 7.242, + 8.16, 8.148001, 8.136001, 8.124001, 8.112, 8.1, + 8.087999, 8.076, 8.063999, 8.052, 8.04, 8.028001, + 8.016, 8.004001, 7.992001, 7.98, 7.968, 7.956, + 7.944, 7.932001, 7.92, 7.908, 7.896, 7.884, + 7.872001, 7.86, 7.848001, 7.835999, 7.824, 7.812, + 7.800001, 7.788, 7.776, 7.764, 7.752, 7.740001, + 7.728, 7.716001, 7.704, 7.692, 7.68, 7.668, + 7.656, 7.644001, 7.632001, 7.62, 7.608001, 7.596001, + 10.848, 10.830001, 10.812, 10.794001, 10.776, 10.758, + 10.74, 10.722, 10.704, 10.686001, 10.668, 10.650001, + 10.632, 10.614, 10.596001, 10.578001, 3.104, 3.1, + 3.096, 3.092, 3.088, 3.084, 3.079999, 3.076001, + 3.072, 3.068, 3.064, 3.06, 3.056, 3.052, + 3.048, 3.044, 3.039999, 3.036001, 3.032, 3.028, + 3.024001, 3.02, 3.016, 3.012, 3.008, 3.004, + 3., 2.996, 2.992, 2.987999, 2.984001, 2.98, + 2.976, 2.972, 2.968, 2.964, 2.96, 2.956, + 2.952, 2.947999, 2.944001, 2.94, 2.936, 2.932001, + 2.928, 2.924, 2.92, 2.916, 8.48, 8.466, + 8.452, 8.438, 8.424, 8.41, 8.396, 8.382, + 8.368, 8.354, 8.34, 8.326, 8.312, 8.298, + 8.284, 8.27, -1.952, -1.948, -1.944, -1.94, + -1.936, -1.932, -1.928, -1.924, -1.92, -1.916, + -1.912, -1.908, -1.904, -1.9, -1.896, -1.892, + -1.888, -1.884, -1.88, -1.876, -1.872, -1.868, + -1.863999, -1.86, -1.856, -1.852, -1.848, -1.844, + -1.84, -1.836, -1.832, -1.828, -1.823999, -1.82, + -1.816, -1.812, -1.808, -1.804, -1.8, -1.796, + -1.792, -1.788, -1.784, -1.78, -1.776, -1.771999, + -1.768, -1.764, 6.112, 6.102, 6.092, 6.082, + 6.072, 6.062, 6.052, 6.042, 6.032, 6.022, + 6.012, 6.002, 5.992, 5.982, 5.972, 5.962}, + sd::DataType::FLOAT32); + + NDArray expGradW('c', {iC, oC, kD, kH, kW}, + {-73678.695312, -59907.972656, -67739.515625, -54962.082031, + -15966.075195, -17115.042969, -15269.777344, -16101.275391, + 41746.566406, 25677.917969, 37200.003906, 22759.517578}, + sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-1803.520020, -1639.679932}, + sd::DataType::FLOAT32); + + input.linspace(100., -0.5); + gradO.linspace(-16, 0.02); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, deconv3d_bp_test6) { - - int bS=2, iD=4,iH=4,iW=4, iC=3,oC=2, kD=2,kH=1,kW=1, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=5,oH=4,oW=4; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] - - NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); - NDArray weights('c',{iC, kD, kH, kW, oC}, {-0.6, -0.3, 0., 0.3, -0.5, -0.2, 0.1, 0.4, -0.4, -0.1, 0.2, 0.5}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oD, oH, oW, oC}, sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iD, iH, iW, iC}, {1.056, 0.482, -0.092, 1.044, 0.478, -0.088, 1.032, 0.474, -0.084, 1.02, 0.47, -0.08, 1.008, 0.466, -0.076, 0.996, - 0.462, -0.072, 0.984, 0.458, -0.068, 0.972, 0.454, -0.064, 0.96, 0.45, -0.06, 0.948, 0.446, -0.056, 0.936, 0.442, -0.052, 0.924, 0.438, -0.048, 0.912, 0.434, - -0.044, 0.9, 0.43, -0.04, 0.888, 0.426, -0.036, 0.876, 0.422, -0.032, 0.864, 0.418, -0.028, 0.852, 0.414, -0.024, 0.84, 0.41, -0.02, 0.828, 0.406, -0.016, - 0.816, 0.402, -0.012, 0.804, 0.398, -0.008, 0.792, 0.394, -0.004, 0.78, 0.39, 0., 0.768, 0.386, 0.004, 0.756, 0.382, 0.008, 0.744, 0.378, 0.012, 0.732, 0.374, - 0.016, 0.72, 0.37, 0.02, 0.708, 0.366, 0.024, 0.696, 0.362, 0.028, 0.684, 0.358, 0.032, 0.672, 0.354, 0.036, 0.66, 0.35, 0.04, 0.648, 0.346, 0.044, 0.636, 0.342, 0.048, 0.624, 0.338, 0.052, 0.612, 0.334, 0.056, 0.6, 0.33, 0.06, 0.588, 0.326, 0.064, 0.576, 0.322, 0.068, 0.564, 0.318, 0.072, 0.552, 0.314, 0.076, 0.54, 0.31, 0.08, 0.528, 0.306, 0.084, 0.516, 0.302, 0.088, 0.504, 0.298, 0.092, 0.492, 0.294, 0.096, 0.48, 0.29, 0.1, 0.468, 0.286, 0.104, 0.456, 0.282, 0.108, 0.444, 0.278, 0.112, 0.432, 0.274, 0.116, 0.42, 0.27, 0.12, 0.408, 0.266, 0.124, 0.396, 0.262, 0.128, 0.384, 0.258, 0.132, 0.372, 0.254, 0.136, 0.36, 0.25, 0.14, 0.348, 0.246, 0.144, 0.336, 0.242, 0.148, 0.324, 0.238, 0.152, 0.312, 0.234, 0.156, 0.3, 0.23, 0.16, 0.096, 0.162, 0.228, 0.084, 0.158, 0.232, 0.072, 0.154, 0.236, 0.06, 0.15, 0.24, 0.048, 0.146, 0.244, 0.036, 0.142, 0.248, 0.024, 0.138, 0.252, 0.012, 0.134, 0.256, 0., 0.13, 0.26, -0.012, 0.126, 0.264, -0.024, 0.122, 0.268, -0.036, 0.118, 0.272, -0.048, 0.114, 0.276, -0.06, 0.11, 0.28, -0.072, 0.106, 0.284, -0.084, 0.102, 0.288, -0.096, 0.098, 0.292, -0.108, 0.094, 0.296, -0.12, 0.09, 0.3, -0.132, 0.086, 0.304, -0.144, 0.082, 0.308, -0.156, 0.078, 0.312, -0.168, 0.074, 0.316, -0.18, 0.07, 0.32, -0.192, 0.066, 0.324, -0.204, 0.062, 0.328, -0.216, 0.058, 0.332, -0.228, 0.054, 0.336, -0.24, 0.05, 0.34, -0.252, 0.046, 0.344, -0.264, 0.042, 0.348, -0.276, 0.038, 0.352, -0.288, 0.034, 0.356, -0.3, 0.03, 0.36, -0.312, 0.026, 0.364, -0.324, 0.022, 0.368, -0.336, 0.018, 0.372, -0.348, 0.014, 0.376, -0.36, 0.01, 0.38, -0.372, 0.006, 0.384, -0.384, 0.002, 0.388, -0.396, -0.002, 0.392, -0.408, -0.006, 0.396, -0.42, -0.01, 0.4, -0.432, -0.014, 0.404, -0.444, -0.018, 0.408, -0.456, -0.022, 0.412, -0.468, -0.026, 0.416, -0.48, -0.03, 0.42, -0.492, -0.034, 0.424, -0.504, -0.038, 0.428, -0.516, -0.042, 0.432, -0.528, -0.046, 0.436, -0.54, -0.05, 0.44, -0.552, -0.054, 0.444, -0.564, -0.058, 0.448, -0.576, -0.062, 0.452, -0.588, -0.066, 0.456, -0.6, -0.07, 0.46, -0.612, -0.074, 0.464, -0.624, -0.078, 0.468, -0.636, -0.082, 0.472, -0.648, -0.086, 0.476, -0.66, -0.09, 0.48}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {iC, kD, kH, kW, oC}, {-6328.958984, -6322.880371, -6134.400879, -6128.319824, -6318.079590, -6312.640137, -6144.000000, -6138.560547, -6307.202637, -6302.399414, -6153.599609, -6148.799316}, sd::DataType::FLOAT32); - NDArray expGradB('c', {oC}, {-1.599994, 0.000001}, sd::DataType::FLOAT32); - - input.linspace(100., -0.5); - gradO.linspace(-1.6, 0.01); - - sd::ops::deconv3d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); + int bS = 2, iD = 4, iH = 4, iW = 4, iC = 3, oC = 2, kD = 2, kH = 1, kW = 1, + sD = 1, sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 5, oH = 4, oW = 4; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - + // [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {iC, kD, kH, kW, oC}, + {-0.6, -0.3, 0., 0.3, -0.5, -0.2, 0.1, 0.4, -0.4, -0.1, 0.2, 0.5}, + sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oD, oH, oW, oC}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iD, iH, iW, iC}, + {1.056, 0.482, -0.092, 1.044, 0.478, -0.088, 1.032, 0.474, -0.084, + 1.02, 0.47, -0.08, 1.008, 0.466, -0.076, 0.996, 0.462, -0.072, + 0.984, 0.458, -0.068, 0.972, 0.454, -0.064, 0.96, 0.45, -0.06, + 0.948, 0.446, -0.056, 0.936, 0.442, -0.052, 0.924, 0.438, -0.048, + 0.912, 0.434, -0.044, 0.9, 0.43, -0.04, 0.888, 0.426, -0.036, + 0.876, 0.422, -0.032, 0.864, 0.418, -0.028, 0.852, 0.414, -0.024, + 0.84, 0.41, -0.02, 0.828, 0.406, -0.016, 0.816, 0.402, -0.012, + 0.804, 0.398, -0.008, 0.792, 0.394, -0.004, 0.78, 0.39, 0., + 0.768, 0.386, 0.004, 0.756, 0.382, 0.008, 0.744, 0.378, 0.012, + 0.732, 0.374, 0.016, 0.72, 0.37, 0.02, 0.708, 0.366, 0.024, + 0.696, 0.362, 0.028, 0.684, 0.358, 0.032, 0.672, 0.354, 0.036, + 0.66, 0.35, 0.04, 0.648, 0.346, 0.044, 0.636, 0.342, 0.048, + 0.624, 0.338, 0.052, 0.612, 0.334, 0.056, 0.6, 0.33, 0.06, + 0.588, 0.326, 0.064, 0.576, 0.322, 0.068, 0.564, 0.318, 0.072, + 0.552, 0.314, 0.076, 0.54, 0.31, 0.08, 0.528, 0.306, 0.084, + 0.516, 0.302, 0.088, 0.504, 0.298, 0.092, 0.492, 0.294, 0.096, + 0.48, 0.29, 0.1, 0.468, 0.286, 0.104, 0.456, 0.282, 0.108, + 0.444, 0.278, 0.112, 0.432, 0.274, 0.116, 0.42, 0.27, 0.12, + 0.408, 0.266, 0.124, 0.396, 0.262, 0.128, 0.384, 0.258, 0.132, + 0.372, 0.254, 0.136, 0.36, 0.25, 0.14, 0.348, 0.246, 0.144, + 0.336, 0.242, 0.148, 0.324, 0.238, 0.152, 0.312, 0.234, 0.156, + 0.3, 0.23, 0.16, 0.096, 0.162, 0.228, 0.084, 0.158, 0.232, + 0.072, 0.154, 0.236, 0.06, 0.15, 0.24, 0.048, 0.146, 0.244, + 0.036, 0.142, 0.248, 0.024, 0.138, 0.252, 0.012, 0.134, 0.256, + 0., 0.13, 0.26, -0.012, 0.126, 0.264, -0.024, 0.122, 0.268, + -0.036, 0.118, 0.272, -0.048, 0.114, 0.276, -0.06, 0.11, 0.28, + -0.072, 0.106, 0.284, -0.084, 0.102, 0.288, -0.096, 0.098, 0.292, + -0.108, 0.094, 0.296, -0.12, 0.09, 0.3, -0.132, 0.086, 0.304, + -0.144, 0.082, 0.308, -0.156, 0.078, 0.312, -0.168, 0.074, 0.316, + -0.18, 0.07, 0.32, -0.192, 0.066, 0.324, -0.204, 0.062, 0.328, + -0.216, 0.058, 0.332, -0.228, 0.054, 0.336, -0.24, 0.05, 0.34, + -0.252, 0.046, 0.344, -0.264, 0.042, 0.348, -0.276, 0.038, 0.352, + -0.288, 0.034, 0.356, -0.3, 0.03, 0.36, -0.312, 0.026, 0.364, + -0.324, 0.022, 0.368, -0.336, 0.018, 0.372, -0.348, 0.014, 0.376, + -0.36, 0.01, 0.38, -0.372, 0.006, 0.384, -0.384, 0.002, 0.388, + -0.396, -0.002, 0.392, -0.408, -0.006, 0.396, -0.42, -0.01, 0.4, + -0.432, -0.014, 0.404, -0.444, -0.018, 0.408, -0.456, -0.022, 0.412, + -0.468, -0.026, 0.416, -0.48, -0.03, 0.42, -0.492, -0.034, 0.424, + -0.504, -0.038, 0.428, -0.516, -0.042, 0.432, -0.528, -0.046, 0.436, + -0.54, -0.05, 0.44, -0.552, -0.054, 0.444, -0.564, -0.058, 0.448, + -0.576, -0.062, 0.452, -0.588, -0.066, 0.456, -0.6, -0.07, 0.46, + -0.612, -0.074, 0.464, -0.624, -0.078, 0.468, -0.636, -0.082, 0.472, + -0.648, -0.086, 0.476, -0.66, -0.09, 0.48}, + sd::DataType::FLOAT32); + + NDArray expGradW('c', {iC, kD, kH, kW, oC}, + {-6328.958984, -6322.880371, -6134.400879, -6128.319824, + -6318.079590, -6312.640137, -6144.000000, -6138.560547, + -6307.202637, -6302.399414, -6153.599609, -6148.799316}, + sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-1.599994, 0.000001}, sd::DataType::FLOAT32); + + input.linspace(100., -0.5); + gradO.linspace(-1.6, 0.01); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, dataFormat, wFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_1) { - - auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - block->appendI({kH,kW, sH,sW, pH,pW, dH,dW, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::maxpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // result.printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(*result)); - - delete variableSpace; - delete block; + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dH, dW, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result.printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_2) { - - const int bS = 2; - const int iD = 1; - const int iH = 28; - const int iW = 28; - const int kH = 5; - const int kW = 5; - const int sH = 1; - const int sW = 1; - const int pH = 0; - const int pW = 0; - const int dH = 1; - const int dW = 1; - const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height - const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width - - - auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - VariableSpace variableSpace; - variableSpace.putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - Context block(1, &variableSpace, false); - block.setName("alpha"); - block.fillInputs({-1}); - block.appendI({kH,kW, sH,sW, pH,pW, dH,dW, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::maxpool2d pooling; - Nd4jStatus status = pooling.execute(&block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - ASSERT_TRUE(variableSpace.hasVariable(block.nodeId(), 0)); - ASSERT_TRUE(variableSpace.hasVariable("alpha")); - ASSERT_TRUE(variableSpace.hasVariable(block.nodeId())); - auto result = variableSpace.getVariable(block.nodeId())->getNDArray(); - // result.printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(*result)); + const int bS = 2; + const int iD = 1; + const int iH = 28; + const int iW = 28; + const int kH = 5; + const int kW = 5; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = + (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = + (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width + + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + VariableSpace variableSpace; + variableSpace.putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + Context block(1, &variableSpace, false); + block.setName("alpha"); + block.fillInputs({-1}); + block.appendI( + {kH, kW, sH, sW, pH, pW, dH, dW, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d pooling; + Nd4jStatus status = pooling.execute(&block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(variableSpace.hasVariable(block.nodeId(), 0)); + ASSERT_TRUE(variableSpace.hasVariable("alpha")); + ASSERT_TRUE(variableSpace.hasVariable(block.nodeId())); + auto result = variableSpace.getVariable(block.nodeId())->getNDArray(); + // result.printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(*result)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_3) { - - const int bS = 2; - const int iD = 1; - const int iH = 28; - const int iW = 28; - const int kH = 5; - const int kW = 5; - const int sH = 1; - const int sW = 1; - const int pH = 0; - const int pW = 0; - const int dH = 1; - const int dW = 1; - const int oH = (int) sd::math::nd4j_ceil(iH * 1.f / sH); - const int oW = (int) sd::math::nd4j_ceil(iW * 1.f / sW); - - - auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - block->appendI({kH,kW, sH,sW, pH,pW, dH,dW, 1}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::maxpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // result.printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(*result)); - - delete variableSpace; - delete block; + const int bS = 2; + const int iD = 1; + const int iH = 28; + const int iW = 28; + const int kH = 5; + const int kW = 5; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = (int)sd::math::nd4j_ceil(iH * 1.f / sH); + const int oW = (int)sd::math::nd4j_ceil(iW * 1.f / sW); + + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dH, dW, + 1}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result.printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_4) { - - const int bS = 2; - const int iD = 1; - const int iH = 24; - const int iW = 24; - const int kH = 3; - const int kW = 3; - const int sH = 1; - const int sW = 1; - const int pH = 0; - const int pW = 0; - const int dH = 1; - const int dW = 1; - const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height - const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width - - - auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - block->appendI({kH,kW, sH,sW, pH,pW, dH,dW, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::maxpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // result.printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(*result)); - - delete variableSpace; - delete block; + const int bS = 2; + const int iD = 1; + const int iH = 24; + const int iW = 24; + const int kH = 3; + const int kW = 3; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = + (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = + (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width + + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dH, dW, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result.printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_5) { - - const int bS = 2; - const int iD = 1; - const int iH = 24; - const int iW = 24; - const int kH = 3; - const int kW = 3; - const int sH = 1; - const int sW = 1; - const int pH = 0; - const int pW = 0; - const int dH = 1; - const int dW = 1; - const int oH = (int) sd::math::nd4j_ceil(iH * 1.f / sH); - const int oW = (int) sd::math::nd4j_ceil(iW * 1.f / sW); - - - auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - block->appendI({kH,kW, sH,sW, pH,pW, dH,dW, 1}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::maxpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // result.printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(*result)); - - delete variableSpace; - delete block; + const int bS = 2; + const int iD = 1; + const int iH = 24; + const int iW = 24; + const int kH = 3; + const int kW = 3; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = (int)sd::math::nd4j_ceil(iH * 1.f / sH); + const int oW = (int)sd::math::nd4j_ceil(iW * 1.f / sW); + + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dH, dW, + 1}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result.printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) { - auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, 59.f, 60.f, 63.f, 64.f}); + auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, + 59.f, 60.f, 63.f, 64.f}); - x.linspace(1); + x.linspace(1); - sd::ops::maxpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1}); + sd::ops::maxpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) { - auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, 59.f, 60.f, 63.f, 64.f}); - - x.linspace(1); + auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, + 59.f, 60.f, 63.f, 64.f}); - sd::ops::maxpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::maxpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) { - auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {7.f, 9.f, 17.f, 19.f, 32.f, 34.f, 42.f, 44.f, 57.f, 59.f, 67.f, 69.f, 82.f, 84.f, 92.f, 94.f}); + auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {7.f, 9.f, 17.f, 19.f, 32.f, 34.f, 42.f, 44.f, 57.f, 59.f, 67.f, 69.f, + 82.f, 84.f, 92.f, 94.f}); - x.linspace(1); + x.linspace(1); - sd::ops::maxpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 0}); + sd::ops::maxpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_9) { + int bS = 3; // batch size (number of samples) + int iC = 3; // input channels + int iH = 28, iW = 28; // input height/width + int kH = 2, kW = 2; // kernel (filter) height/width + int sH = 1, sW = 1; // stride height/width + int pH = 0, pW = 0; // padding height/width + int dH = 1, dW = 1; // dilation height/width - int bS = 3; // batch size (number of samples) - int iC = 3; // input channels - int iH = 28, iW = 28; // input height/width - int kH = 2, kW = 2; // kernel (filter) height/width - int sH = 1, sW = 1; // stride height/width - int pH = 0, pW = 0; // padding height/width - int dH = 1, dW = 1; // dilation height/width - - int oH = 27, oW = 27; // output height/width - - int isSameMode = 0; // 1-SAME, 0-VALID + int oH = 27, oW = 27; // output height/width - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + int isSameMode = 0; // 1-SAME, 0-VALID - sd::ops::maxpool2d op; - auto results = op.evaluate({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, 1, 0}); - auto output = results.at(0); + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(output.isSameShape({bS, iC, oH, oW})); + sd::ops::maxpool2d op; + auto results = op.evaluate( + {&input}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, 1, 0}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output.isSameShape({bS, iC, oH, oW})); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_10) { - - int bS=1, iH=4,iW=4, iC=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.27620894f, 0.21801452f, 0.062078513f, 7.348895E-4f, 0.24149609f, 0.4948205f, 0.93483436f, 0.52035654f, 0.30292067f, 0.3289706f, 0.7977864f, - 0.03180518f, 0.1455722f, 0.90352905f, 0.9405744f, 0.0048329555f, 0.44062102f, 0.111197524f, 0.31742015f, 0.1933705f, 0.23825112f, 0.35076278f, 0.7135856f, 0.28229436f, 0.18310733f, - 0.9613717f, 0.56823575f, 0.78289545f, 0.62195826f, 0.5244586f, 0.5040889f, 0.025349546f, 0.41400263f, 0.28420195f, 0.8536445f, 0.3044107f, 0.7997134f, 0.45762005f, 0.7653578f, - 0.07198584f, 0.5304998f, 0.7334402f, 0.85019743f, 0.031957153f, 0.37088063f, 0.85722464f, 0.06376881f, 0.39791203f}); - - auto expOutput = NDArrayFactory::create('c', {bS, iC, oH, oW}, {0.4948205f, 0.93483436f, 0.93483436f, 0.4948205f, 0.93483436f, 0.93483436f, 0.90352905f, 0.9405744f, 0.9405744f, 0.44062102f, 0.7135856f, - 0.7135856f, 0.9613717f, 0.9613717f, 0.78289545f, 0.9613717f, 0.9613717f, 0.78289545f, 0.7997134f, 0.8536445f, 0.8536445f, 0.7997134f, 0.85019743f, 0.85019743f, - 0.85722464f, 0.85722464f, 0.85019743f}); - - sd::ops::maxpool2d op; - auto results = op.evaluate({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 1, iH = 4, iW = 4, iC = 3, kH = 2, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.27620894f, 0.21801452f, 0.062078513f, 7.348895E-4f, 0.24149609f, + 0.4948205f, 0.93483436f, 0.52035654f, 0.30292067f, 0.3289706f, + 0.7977864f, 0.03180518f, 0.1455722f, 0.90352905f, 0.9405744f, + 0.0048329555f, 0.44062102f, 0.111197524f, 0.31742015f, 0.1933705f, + 0.23825112f, 0.35076278f, 0.7135856f, 0.28229436f, 0.18310733f, + 0.9613717f, 0.56823575f, 0.78289545f, 0.62195826f, 0.5244586f, + 0.5040889f, 0.025349546f, 0.41400263f, 0.28420195f, 0.8536445f, + 0.3044107f, 0.7997134f, 0.45762005f, 0.7653578f, 0.07198584f, + 0.5304998f, 0.7334402f, 0.85019743f, 0.031957153f, 0.37088063f, + 0.85722464f, 0.06376881f, 0.39791203f}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, iC, oH, oW}, + {0.4948205f, 0.93483436f, 0.93483436f, 0.4948205f, 0.93483436f, + 0.93483436f, 0.90352905f, 0.9405744f, 0.9405744f, 0.44062102f, + 0.7135856f, 0.7135856f, 0.9613717f, 0.9613717f, 0.78289545f, + 0.9613717f, 0.9613717f, 0.78289545f, 0.7997134f, 0.8536445f, + 0.8536445f, 0.7997134f, 0.85019743f, 0.85019743f, 0.85722464f, + 0.85722464f, 0.85019743f}); + + sd::ops::maxpool2d op; + auto results = + op.evaluate({&input}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_11) { + NDArray input('c', {1, 1, 4, 5}, sd::DataType::FLOAT32); + NDArray z('c', {1, 1, 4, 5}, sd::DataType::FLOAT32); - NDArray input('c', {1,1,4,5}, sd::DataType::FLOAT32); - NDArray z('c', {1,1,4,5}, sd::DataType::FLOAT32); - - input.linspace(1.); - - sd::ops::maxpool2d op; - auto results = op.evaluate({&input}, {}, {2,2, 1,1, 1,1, 2,2, 1,0,0}); + input.linspace(1.); - ASSERT_EQ(Status::OK(), results.status()); + sd::ops::maxpool2d op; + auto results = op.evaluate({&input}, {}, {2, 2, 1, 1, 1, 1, 2, 2, 1, 0, 0}); + ASSERT_EQ(Status::OK(), results.status()); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_test1) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}, {10.5f, 11.5f, 13.5f, 14.5f, 22.5f, 23.5f, 25.5f, 26.5f, 46.5f, 47.5f, 49.5f, 50.5f, 58.5f, 59.5f, 61.5f, 62.5f, - 82.5f, 83.5f, 85.5f, 86.5f, 94.5f, 95.5f, 97.5f, 98.5f,118.5f,119.5f,121.5f,122.5f,130.5f,131.5f,133.5f,134.5f, - 154.5f,155.5f,157.5f,158.5f,166.5f,167.5f,169.5f,170.5f,190.5f,191.5f,193.5f,194.5f,202.5f,203.5f,205.5f,206.5f}); - input.linspace(1.); - - sd::ops::avgpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, oD, oH, oW}, + {10.5f, 11.5f, 13.5f, 14.5f, 22.5f, 23.5f, 25.5f, 26.5f, + 46.5f, 47.5f, 49.5f, 50.5f, 58.5f, 59.5f, 61.5f, 62.5f, + 82.5f, 83.5f, 85.5f, 86.5f, 94.5f, 95.5f, 97.5f, 98.5f, + 118.5f, 119.5f, 121.5f, 122.5f, 130.5f, 131.5f, 133.5f, 134.5f, + 154.5f, 155.5f, 157.5f, 158.5f, 166.5f, 167.5f, 169.5f, 170.5f, + 190.5f, 191.5f, 193.5f, 194.5f, 202.5f, 203.5f, 205.5f, 206.5f}); + input.linspace(1.); + + sd::ops::avgpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_test2) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 29.5f, 30.5f, 31.5f, 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 34.f, 35.f, 36.f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 47.5f, 48.5f, 49.5f, - 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 65.5f, 66.5f, 67.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, 70.f, 71.f, 72.f, 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 79.f, 80.f, 81.f, 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, - 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, 83.5f, 84.5f, 85.5f, 86.5f, 87.5f, 88.5f, 88.f, 89.f, 90.f, 92.5f, 93.5f, 94.5f, 95.5f, 96.5f, 97.5f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 100.f, 101.f, 102.f, 101.5f, 102.5f, 103.5f, - 133.f, 134.f, 135.f, 136.f, 137.f, 138.f, 137.5f, 138.5f, 139.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 142.f, 143.f, 144.f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, 151.f, 152.f, 153.f, 151.f, 152.f, 153.f, 154.f, 155.f, 156.f, 155.5f, 156.5f, 157.5f, - 169.f, 170.f, 171.f, 172.f, 173.f, 174.f, 173.5f, 174.5f, 175.5f, 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 178.f, 179.f, 180.f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f, 187.f, 188.f, 189.f, 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, - 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, 191.5f, 192.5f, 193.5f, 194.5f, 195.5f, 196.5f, 196.f, 197.f, 198.f, 200.5f, 201.5f, 202.5f, 203.5f, 204.5f, 205.5f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 208.f, 209.f, 210.f, 209.5f, 210.5f, 211.5f}); - input.linspace(1.); - - sd::ops::avgpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, oD, oH, oW, iC}, + {25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 29.5f, 30.5f, 31.5f, + 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 34.f, 35.f, 36.f, + 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 43.f, 44.f, 45.f, + 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 47.5f, 48.5f, 49.5f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 65.5f, 66.5f, 67.5f, + 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, 70.f, 71.f, 72.f, + 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 79.f, 80.f, 81.f, + 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, + 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, + 83.5f, 84.5f, 85.5f, 86.5f, 87.5f, 88.5f, 88.f, 89.f, 90.f, + 92.5f, 93.5f, 94.5f, 95.5f, 96.5f, 97.5f, 97.f, 98.f, 99.f, + 97.f, 98.f, 99.f, 100.f, 101.f, 102.f, 101.5f, 102.5f, 103.5f, + 133.f, 134.f, 135.f, 136.f, 137.f, 138.f, 137.5f, 138.5f, 139.5f, + 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 142.f, 143.f, 144.f, + 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, 151.f, 152.f, 153.f, + 151.f, 152.f, 153.f, 154.f, 155.f, 156.f, 155.5f, 156.5f, 157.5f, + 169.f, 170.f, 171.f, 172.f, 173.f, 174.f, 173.5f, 174.5f, 175.5f, + 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 178.f, 179.f, 180.f, + 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f, 187.f, 188.f, 189.f, + 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, + 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, + 191.5f, 192.5f, 193.5f, 194.5f, 195.5f, 196.5f, 196.f, 197.f, 198.f, + 200.5f, 201.5f, 202.5f, 203.5f, 204.5f, 205.5f, 205.f, 206.f, 207.f, + 205.f, 206.f, 207.f, 208.f, 209.f, 210.f, 209.5f, 210.5f, 211.5f}); + input.linspace(1.); + + sd::ops::avgpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 0, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_test3) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, - 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, - 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f}); - input.linspace(1.); - - sd::ops::avgpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, oD, oH, oW, iC}, + {29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, + 40.5f, 41.5f, 42.5f, 43.5f, 65.5f, 66.5f, 67.5f, 68.5f, + 69.5f, 70.5f, 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, + 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 146.5f, 147.5f, + 148.5f, 149.5f, 150.5f, 151.5f, 173.5f, 174.5f, 175.5f, 176.5f, + 177.5f, 178.5f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f}); + input.linspace(1.); + + sd::ops::avgpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_test4) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; - int oD=4,oH=4,oW=4; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{0.416667f, 1.00f, 1.333333f, 0.75f, 1.00f, 2.25f, 2.75f, 1.50f, 1.75f, 3.75f, 4.25f, 2.25f, 1.416667f, 3.00f, 3.333333f, 1.75f, 2.833333f, 6.00f, 6.666667f, 3.50f, 5.00f, 10.50f, 11.50f, 6.00f, 6.50f, - 13.50f, 14.50f, 7.50f, 4.833333f, 10.00f, 10.666667f, 5.50f, 6.833333f, 14.00f, 14.666667f, 7.50f, 11.00f, 22.50f, 23.50f, 12.00f, 12.50f, 25.50f, 26.50f, 13.50f, 8.833333f, 18.00f, 18.666666f, 9.50f, - 4.416667f, 9.00f, 9.333333f, 4.75f, 7.00f, 14.25f, 14.75f, 7.50f, 7.75f, 15.75f, 16.25f, 8.25f, 5.416667f, 11.00f, 11.333333f, 5.75f, 6.416667f, 13.00f, 13.333333f, 6.75f, 10.00f, 20.25f, 20.75f, - 10.50f, 10.75f, 21.75f, 22.25f, 11.25f, 7.416667f, 15.00f, 15.333333f, 7.75f, 14.833333f, 30.00f, 30.666666f, 15.50f, 23.00f, 46.50f, 47.50f, 24.00f, 24.50f, 49.50f, 50.50f, 25.50f, 16.833334f, - 34.00f, 34.666668f, 17.50f, 18.833334f, 38.00f, 38.666668f, 19.50f, 29.00f, 58.50f, 59.50f, 30.00f, 30.50f, 61.50f, 62.50f, 31.50f, 20.833334f, 42.00f, 42.666668f, 21.50f, 10.416667f, 21.00f, - 21.333334f, 10.75f, 16.00f, 32.25f, 32.75f, 16.50f, 16.75f, 33.75f, 34.25f, 17.25f, 11.416667f, 23.00f, 23.333334f, 11.75f, 12.416667f, 25.00f, 25.333334f, 12.75f, 19.00f, 38.25f, 38.75f, 19.50f, - 19.75f, 39.75f, 40.25f, 20.25f, 13.416667f, 27.00f, 27.333334f, 13.75f, 26.833334f, 54.00f, 54.666668f, 27.50f, 41.00f, 82.50f, 83.50f, 42.00f, 42.50f, 85.50f, 86.50f, 43.50f, 28.833334f, 58.00f, - 58.666668f, 29.50f, 30.833334f, 62.00f, 62.666668f, 31.50f, 47.00f, 94.50f, 95.50f, 48.00f, 48.50f, 97.50f, 98.50f, 49.50f, 32.833332f, 66.00f, 66.666664f, 33.50f, 16.416666f, 33.00f, 33.333332f, - 16.75f, 25.00f, 50.25f, 50.75f, 25.50f, 25.75f, 51.75f, 52.25f, 26.25f, 17.416666f, 35.00f, 35.333332f, 17.75f, 18.416666f, 37.00f, 37.333332f, 18.75f, 28.00f, 56.25f, 56.75f, 28.50f, 28.75f, - 57.75f, 58.25f, 29.25f, 19.416666f, 39.00f, 39.333332f, 19.75f, 38.833332f, 78.00f, 78.666664f, 39.50f, 59.00f, 118.50f, 119.50f, 60.00f, 60.50f, 121.50f, 122.50f, 61.50f, 40.833332f, 82.00f, - 82.666664f, 41.50f, 42.833332f, 86.00f, 86.666664f, 43.50f, 65.00f, 130.50f, 131.50f, 66.00f, 66.50f, 133.50f, 134.50f, 67.50f, 44.833332f, 90.00f, 90.666664f, 45.50f, 22.416666f, 45.00f, - 45.333332f, 22.75f, 34.00f, 68.25f, 68.75f, 34.50f, 34.75f, 69.75f, 70.25f, 35.25f, 23.416666f, 47.00f, 47.333332f, 23.75f, 24.416666f, 49.00f, 49.333332f, 24.75f, 37.00f, 74.25f, 74.75f, - 37.50f, 37.75f, 75.75f, 76.25f, 38.25f, 25.416666f, 51.00f, 51.333332f, 25.75f, 50.833332f, 102.00f, 102.666664f, 51.50f, 77.00f, 154.50f, 155.50f, 78.00f, 78.50f, 157.50f, 158.50f, 79.50f, - 52.833332f, 106.00f, 106.666664f, 53.50f, 54.833332f, 110.00f, 110.666664f, 55.50f, 83.00f, 166.50f, 167.50f, 84.00f, 84.50f, 169.50f, 170.50f, 85.50f, 56.833332f, 114.00f, 114.666664f, - 57.50f, 28.416666f, 57.00f, 57.333332f, 28.75f, 43.00f, 86.25f, 86.75f, 43.50f, 43.75f, 87.75f, 88.25f, 44.25f, 29.416666f, 59.00f, 59.333332f, 29.75f, 30.416666f, 61.00f, 61.333332f, 30.75f, - 46.00f, 92.25f, 92.75f, 46.50f, 46.75f, 93.75f, 94.25f, 47.25f, 31.416666f, 63.00f, 63.333332f, 31.75f, 62.833332f, 126.00f, 126.666664f, 63.50f, 95.00f, 190.50f, 191.50f, 96.00f, 96.50f, - 193.50f, 194.50f, 97.50f, 64.833336f, 130.00f, 130.666672f, 65.50f, 66.833336f, 134.00f, 134.666672f, 67.50f, 101.00f, 202.50f, 203.50f, 102.00f, 102.50f, 205.50f, 206.50f, 103.50f, - 68.833336f, 138.00f, 138.666672f, 69.50f, 34.416668f, 69.00f, 69.333336f, 34.75f, 52.00f, 104.25f, 104.75f, 52.50f, 52.75f, 105.75f, 106.25f, 53.25f, 35.416668f, 71.00f, 71.333336f, 35.75f}); - input.linspace(1.); - - sd::ops::avgpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 1, pH = 1, pW = 1, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 4, oW = 4; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, oD, oH, oW}, + {0.416667f, 1.00f, 1.333333f, 0.75f, 1.00f, 2.25f, + 2.75f, 1.50f, 1.75f, 3.75f, 4.25f, 2.25f, + 1.416667f, 3.00f, 3.333333f, 1.75f, 2.833333f, 6.00f, + 6.666667f, 3.50f, 5.00f, 10.50f, 11.50f, 6.00f, + 6.50f, 13.50f, 14.50f, 7.50f, 4.833333f, 10.00f, + 10.666667f, 5.50f, 6.833333f, 14.00f, 14.666667f, 7.50f, + 11.00f, 22.50f, 23.50f, 12.00f, 12.50f, 25.50f, + 26.50f, 13.50f, 8.833333f, 18.00f, 18.666666f, 9.50f, + 4.416667f, 9.00f, 9.333333f, 4.75f, 7.00f, 14.25f, + 14.75f, 7.50f, 7.75f, 15.75f, 16.25f, 8.25f, + 5.416667f, 11.00f, 11.333333f, 5.75f, 6.416667f, 13.00f, + 13.333333f, 6.75f, 10.00f, 20.25f, 20.75f, 10.50f, + 10.75f, 21.75f, 22.25f, 11.25f, 7.416667f, 15.00f, + 15.333333f, 7.75f, 14.833333f, 30.00f, 30.666666f, 15.50f, + 23.00f, 46.50f, 47.50f, 24.00f, 24.50f, 49.50f, + 50.50f, 25.50f, 16.833334f, 34.00f, 34.666668f, 17.50f, + 18.833334f, 38.00f, 38.666668f, 19.50f, 29.00f, 58.50f, + 59.50f, 30.00f, 30.50f, 61.50f, 62.50f, 31.50f, + 20.833334f, 42.00f, 42.666668f, 21.50f, 10.416667f, 21.00f, + 21.333334f, 10.75f, 16.00f, 32.25f, 32.75f, 16.50f, + 16.75f, 33.75f, 34.25f, 17.25f, 11.416667f, 23.00f, + 23.333334f, 11.75f, 12.416667f, 25.00f, 25.333334f, 12.75f, + 19.00f, 38.25f, 38.75f, 19.50f, 19.75f, 39.75f, + 40.25f, 20.25f, 13.416667f, 27.00f, 27.333334f, 13.75f, + 26.833334f, 54.00f, 54.666668f, 27.50f, 41.00f, 82.50f, + 83.50f, 42.00f, 42.50f, 85.50f, 86.50f, 43.50f, + 28.833334f, 58.00f, 58.666668f, 29.50f, 30.833334f, 62.00f, + 62.666668f, 31.50f, 47.00f, 94.50f, 95.50f, 48.00f, + 48.50f, 97.50f, 98.50f, 49.50f, 32.833332f, 66.00f, + 66.666664f, 33.50f, 16.416666f, 33.00f, 33.333332f, 16.75f, + 25.00f, 50.25f, 50.75f, 25.50f, 25.75f, 51.75f, + 52.25f, 26.25f, 17.416666f, 35.00f, 35.333332f, 17.75f, + 18.416666f, 37.00f, 37.333332f, 18.75f, 28.00f, 56.25f, + 56.75f, 28.50f, 28.75f, 57.75f, 58.25f, 29.25f, + 19.416666f, 39.00f, 39.333332f, 19.75f, 38.833332f, 78.00f, + 78.666664f, 39.50f, 59.00f, 118.50f, 119.50f, 60.00f, + 60.50f, 121.50f, 122.50f, 61.50f, 40.833332f, 82.00f, + 82.666664f, 41.50f, 42.833332f, 86.00f, 86.666664f, 43.50f, + 65.00f, 130.50f, 131.50f, 66.00f, 66.50f, 133.50f, + 134.50f, 67.50f, 44.833332f, 90.00f, 90.666664f, 45.50f, + 22.416666f, 45.00f, 45.333332f, 22.75f, 34.00f, 68.25f, + 68.75f, 34.50f, 34.75f, 69.75f, 70.25f, 35.25f, + 23.416666f, 47.00f, 47.333332f, 23.75f, 24.416666f, 49.00f, + 49.333332f, 24.75f, 37.00f, 74.25f, 74.75f, 37.50f, + 37.75f, 75.75f, 76.25f, 38.25f, 25.416666f, 51.00f, + 51.333332f, 25.75f, 50.833332f, 102.00f, 102.666664f, 51.50f, + 77.00f, 154.50f, 155.50f, 78.00f, 78.50f, 157.50f, + 158.50f, 79.50f, 52.833332f, 106.00f, 106.666664f, 53.50f, + 54.833332f, 110.00f, 110.666664f, 55.50f, 83.00f, 166.50f, + 167.50f, 84.00f, 84.50f, 169.50f, 170.50f, 85.50f, + 56.833332f, 114.00f, 114.666664f, 57.50f, 28.416666f, 57.00f, + 57.333332f, 28.75f, 43.00f, 86.25f, 86.75f, 43.50f, + 43.75f, 87.75f, 88.25f, 44.25f, 29.416666f, 59.00f, + 59.333332f, 29.75f, 30.416666f, 61.00f, 61.333332f, 30.75f, + 46.00f, 92.25f, 92.75f, 46.50f, 46.75f, 93.75f, + 94.25f, 47.25f, 31.416666f, 63.00f, 63.333332f, 31.75f, + 62.833332f, 126.00f, 126.666664f, 63.50f, 95.00f, 190.50f, + 191.50f, 96.00f, 96.50f, 193.50f, 194.50f, 97.50f, + 64.833336f, 130.00f, 130.666672f, 65.50f, 66.833336f, 134.00f, + 134.666672f, 67.50f, 101.00f, 202.50f, 203.50f, 102.00f, + 102.50f, 205.50f, 206.50f, 103.50f, 68.833336f, 138.00f, + 138.666672f, 69.50f, 34.416668f, 69.00f, 69.333336f, 34.75f, + 52.00f, 104.25f, 104.75f, 52.50f, 52.75f, 105.75f, + 106.25f, 53.25f, 35.416668f, 71.00f, 71.333336f, 35.75f}); + input.linspace(1.); + + sd::ops::avgpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_test1) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}, {20.f, 21.f, 23.f, 24.f, 32.f, 33.f, 35.f, 36.f, 56.f, 57.f, 59.f, 60.f, 68.f, 69.f, 71.f, 72.f, 92.f, 93.f, 95.f, 96.f, 104.f, 105.f, 107.f, 108.f, - 128.f, 129.f, 131.f, 132.f, 140.f, 141.f, 143.f, 144.f, 164.f, 165.f, 167.f, 168.f, 176.f, 177.f, 179.f, 180.f, 200.f, 201.f, 203.f, 204.f, 212.f, 213.f, 215.f, 216.f}); - input.linspace(1.); - - sd::ops::maxpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, oD, oH, oW}, + {20.f, 21.f, 23.f, 24.f, 32.f, 33.f, 35.f, 36.f, 56.f, 57.f, + 59.f, 60.f, 68.f, 69.f, 71.f, 72.f, 92.f, 93.f, 95.f, 96.f, + 104.f, 105.f, 107.f, 108.f, 128.f, 129.f, 131.f, 132.f, 140.f, 141.f, + 143.f, 144.f, 164.f, 165.f, 167.f, 168.f, 176.f, 177.f, 179.f, 180.f, + 200.f, 201.f, 203.f, 204.f, 212.f, 213.f, 215.f, 216.f}); + input.linspace(1.); + + sd::ops::maxpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_test2) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, - 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, - 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, - 157.f, 158.f, 159.f, 160.f, 161.f, 162.f, 160.f, 161.f, 162.f, 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, - 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, - 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f}); - input.linspace(1.); - - sd::ops::maxpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); -} + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, oD, oH, oW, iC}, + {49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 58.f, + 59.f, 60.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 67.f, 68.f, + 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, + 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 85.f, 86.f, 87.f, 88.f, + 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, + 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, + 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, + 107.f, 108.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, + 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, + 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, + 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 157.f, 158.f, + 159.f, 160.f, 161.f, 162.f, 160.f, 161.f, 162.f, 166.f, 167.f, 168.f, + 169.f, 170.f, 171.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, + 179.f, 180.f, 178.f, 179.f, 180.f, 175.f, 176.f, 177.f, 178.f, 179.f, + 180.f, 178.f, 179.f, 180.f, 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, + 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, + 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, + 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, + 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, + 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, + 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, + 214.f, 215.f, 216.f, 214.f, 215.f, 216.f}); + input.linspace(1.); + + sd::ops::maxpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_test3) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, {58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, - 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f}); - input.linspace(1.); - - sd::ops::maxpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, oD, oH, oW, iC}, + {58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, + 71.f, 72.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 103.f, 104.f, + 105.f, 106.f, 107.f, 108.f, 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, + 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 202.f, 203.f, 204.f, 205.f, + 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f}); + input.linspace(1.); + + sd::ops::maxpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_test4) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; - int oD=4,oH=4,oW=4; - int paddingMode = 0; // -SAME, 0-VALID - int dataFormat = 0; // -NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{ 4.f, 5.f, 6.f, 6.f, 7.f, 8.f, 9.f, 9.f, 10.f, 11.f, 12.f, 12.f, 10.f, 11.f, 12.f, 12.f, 16.f, 17.f, 18.f, 18.f, 19.f, 20.f, 21.f, 21.f, 22.f, 23.f, 24.f, 24.f, 22.f, 23.f, 24.f, 24.f, 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, - 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, 40.f, 41.f, 42.f, 42.f, 43.f, 44.f, 45.f, 45.f, 46.f, 47.f, 48.f, 48.f, 46.f, 47.f, 48.f, 48.f, 52.f, 53.f, 54.f, 54.f, 55.f, 56.f, 57.f, 57.f, 58.f, 59.f, 60.f, 60.f, 58.f, 59.f, 60.f, 60.f, - 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 76.f, 77.f, 78.f, 78.f, 79.f, 80.f, 81.f, 81.f, 82.f, 83.f, 84.f, 84.f, 82.f, 83.f, 84.f, 84.f, - 88.f, 89.f, 90.f, 90.f, 91.f, 92.f, 93.f, 93.f, 94.f, 95.f, 96.f, 96.f, 94.f, 95.f, 96.f, 96.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, - 112.f, 113.f, 114.f, 114.f, 115.f, 116.f, 117.f, 117.f, 118.f, 119.f, 120.f, 120.f, 118.f, 119.f, 120.f, 120.f, 124.f, 125.f, 126.f, 126.f, 127.f, 128.f, 129.f, 129.f, 130.f, 131.f, 132.f, 132.f, 130.f, 131.f, 132.f, 132.f, 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, - 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, 148.f, 149.f, 150.f, 150.f, 151.f, 152.f, 153.f, 153.f, 154.f, 155.f, 156.f, 156.f, 154.f, 155.f, 156.f, 156.f, 160.f, 161.f, 162.f, 162.f, 163.f, 164.f, 165.f, 165.f, 166.f, 167.f, 168.f, 168.f, 166.f, 167.f, 168.f, 168.f, - 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 184.f, 185.f, 186.f, 186.f, 187.f, 188.f, 189.f, 189.f, 190.f, 191.f, 192.f, 192.f, 190.f, 191.f, 192.f, 192.f, - 196.f, 197.f, 198.f, 198.f, 199.f, 200.f, 201.f, 201.f, 202.f, 203.f, 204.f, 204.f, 202.f, 203.f, 204.f, 204.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f}); - input.linspace(1.); - - sd::ops::maxpool3dnew op; - auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 1, pH = 1, pW = 1, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 4, oW = 4; + int paddingMode = 0; // -SAME, 0-VALID + int dataFormat = 0; // -NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, oD, oH, oW}, + {4.f, 5.f, 6.f, 6.f, 7.f, 8.f, 9.f, 9.f, 10.f, 11.f, + 12.f, 12.f, 10.f, 11.f, 12.f, 12.f, 16.f, 17.f, 18.f, 18.f, + 19.f, 20.f, 21.f, 21.f, 22.f, 23.f, 24.f, 24.f, 22.f, 23.f, + 24.f, 24.f, 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, + 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, 28.f, 29.f, + 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, + 34.f, 35.f, 36.f, 36.f, 40.f, 41.f, 42.f, 42.f, 43.f, 44.f, + 45.f, 45.f, 46.f, 47.f, 48.f, 48.f, 46.f, 47.f, 48.f, 48.f, + 52.f, 53.f, 54.f, 54.f, 55.f, 56.f, 57.f, 57.f, 58.f, 59.f, + 60.f, 60.f, 58.f, 59.f, 60.f, 60.f, 64.f, 65.f, 66.f, 66.f, + 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, + 72.f, 72.f, 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, + 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 76.f, 77.f, + 78.f, 78.f, 79.f, 80.f, 81.f, 81.f, 82.f, 83.f, 84.f, 84.f, + 82.f, 83.f, 84.f, 84.f, 88.f, 89.f, 90.f, 90.f, 91.f, 92.f, + 93.f, 93.f, 94.f, 95.f, 96.f, 96.f, 94.f, 95.f, 96.f, 96.f, + 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, + 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, 100.f, 101.f, 102.f, 102.f, + 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, + 108.f, 108.f, 112.f, 113.f, 114.f, 114.f, 115.f, 116.f, 117.f, 117.f, + 118.f, 119.f, 120.f, 120.f, 118.f, 119.f, 120.f, 120.f, 124.f, 125.f, + 126.f, 126.f, 127.f, 128.f, 129.f, 129.f, 130.f, 131.f, 132.f, 132.f, + 130.f, 131.f, 132.f, 132.f, 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, + 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, + 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, + 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, 148.f, 149.f, 150.f, 150.f, + 151.f, 152.f, 153.f, 153.f, 154.f, 155.f, 156.f, 156.f, 154.f, 155.f, + 156.f, 156.f, 160.f, 161.f, 162.f, 162.f, 163.f, 164.f, 165.f, 165.f, + 166.f, 167.f, 168.f, 168.f, 166.f, 167.f, 168.f, 168.f, 172.f, 173.f, + 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, + 178.f, 179.f, 180.f, 180.f, 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, + 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, + 184.f, 185.f, 186.f, 186.f, 187.f, 188.f, 189.f, 189.f, 190.f, 191.f, + 192.f, 192.f, 190.f, 191.f, 192.f, 192.f, 196.f, 197.f, 198.f, 198.f, + 199.f, 200.f, 201.f, 201.f, 202.f, 203.f, 204.f, 204.f, 202.f, 203.f, + 204.f, 204.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, + 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f, 208.f, 209.f, + 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, + 214.f, 215.f, 216.f, 216.f}); + input.linspace(1.); + + sd::ops::maxpool3dnew op; + auto results = op.evaluate({&input}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test1) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, - 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, - 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f}); - input.linspace(1.); - gradO = 2.; - - sd::ops::avgpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iD, iH, iW}, + {0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, + 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, + 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, + 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, + 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, + 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, + 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, + 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f}); + input.linspace(1.); + gradO = 2.; + + sd::ops::avgpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test2) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; - int oD=4,oH=4,oW=4; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, - 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f}); - input.linspace(1.); - gradO = 2.; - - sd::ops::avgpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - // output->printBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 1, pH = 1, pW = 1, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 4, oW = 4; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iD, iH, iW}, + {1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, + 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f}); + input.linspace(1.); + gradO = 2.; + + sd::ops::avgpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + // output->printBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test3) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, - 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, - 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, - 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f}); - input.linspace(1.); - gradO = 2.; - - sd::ops::avgpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, + 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, + 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, + 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.41667f, + 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, + 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, + 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, + 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.16667f, 1.16667f, + 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, + 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, + 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, + 2.5f, 3.75f, 3.75f, 3.75f, 1.75f, 1.75f, 1.75f, + 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, + 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, + 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, + 3.75f, 3.75f, 3.75f, 0.41667f, 0.41667f, 0.41667f, 0.83333f, + 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, + 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, + 1.75f, 1.75f, 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, + 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, + 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, + 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, + 3.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, + 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, + 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, + 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, + 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, + 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, + 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, + 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f}); + input.linspace(1.); + gradO = 2.; + + sd::ops::avgpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 0, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test4) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=4,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, - 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f, - 0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, - 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, - 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, - 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f}); - input.linspace(1.); - gradO = 2.; - - sd::ops::avgpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 4, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, + 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, + 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, + 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.91667f, + 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, + 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, + 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, + 1.33333f, 1.33333f, 2.f, 2.f, 2.f, 1.16667f, 1.16667f, + 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, + 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, + 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, + 1.f, 1.5f, 1.5f, 1.5f, 1.f, 1.f, 1.f, + 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, + 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, + 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, + 8.25f, 8.25f, 8.25f, 0.16667f, 0.16667f, 0.16667f, 0.33333f, + 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, + 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, + 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, + 1.75f, 1.75f, 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, + 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, + 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, + 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, + 2.f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, + 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, + 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, + 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, + 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, + 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, + 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, + 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f}); + input.linspace(1.); + gradO = 2.; + + sd::ops::avgpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 0, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test1) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=2,oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f}); - - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 2, oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iD, iH, iW}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f}); + + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test2) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; - int oD=4,oH=4,oW=4; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.000e+00f, 0.000e+00f, 0.000e+00f, 1.000e-01f, 2.000e-01f, 7.000e-01f, 5.000e-01f, 6.000e-01f, 1.500e+00f, 2.200e+00f, 2.400e+00f, 5.400e+00f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.700e+00f, 1.800e+00f, 3.900e+00f, 2.100e+00f, 2.200e+00f, 4.700e+00f, 5.400e+00f, 5.600e+00f, 1.180e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.200e+00f, 8.400e+00f, 1.740e+01f, 9.000e+00f, 9.200e+00f, 1.900e+01f, 2.040e+01f, 2.080e+01f, 4.280e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 6.500e+00f, 6.600e+00f, 1.350e+01f, 6.900e+00f, 7.000e+00f, 1.430e+01f, 1.500e+01f, 1.520e+01f, 3.100e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.100e+00f, 8.200e+00f, 1.670e+01f, 8.500e+00f, 8.600e+00f, 1.750e+01f, 1.820e+01f, 1.840e+01f, 3.740e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.100e+01f, 2.120e+01f, 4.300e+01f, 2.180e+01f, 2.200e+01f, 4.460e+01f, 4.600e+01f, 4.640e+01f, 9.400e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.290e+01f, 1.300e+01f, 2.630e+01f, 1.330e+01f, 1.340e+01f, 2.710e+01f, 2.780e+01f, 2.800e+01f, 5.660e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.450e+01f, 1.460e+01f, 2.950e+01f, 1.490e+01f, 1.500e+01f, 3.030e+01f, 3.100e+01f, 3.120e+01f, 6.300e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.380e+01f, 3.400e+01f, 6.860e+01f, 3.460e+01f, 3.480e+01f, 7.020e+01f, 7.160e+01f, 7.200e+01f, 1.452e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.930e+01f, 1.940e+01f, 3.910e+01f, 1.970e+01f, 1.980e+01f, 3.990e+01f, 4.060e+01f, 4.080e+01f, 8.220e+01f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.090e+01f, 2.100e+01f, 4.230e+01f, 2.130e+01f, 2.140e+01f, 4.310e+01f, 4.380e+01f, 4.400e+01f, 8.860e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 4.660e+01f, 4.680e+01f, 9.420e+01f, 4.740e+01f, 4.760e+01f, 9.580e+01f, 9.720e+01f, 9.760e+01f, 1.964e+02f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.570e+01f, 2.580e+01f, 5.190e+01f, 2.610e+01f, 2.620e+01f, 5.270e+01f, 5.340e+01f, 5.360e+01f, 1.078e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.730e+01f, 2.740e+01f, 5.510e+01f, 2.770e+01f, 2.780e+01f, 5.590e+01f, 5.660e+01f, 5.680e+01f, 1.142e+02f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 5.940e+01f, 5.960e+01f, 1.198e+02f, 6.020e+01f, 6.040e+01f, 1.214e+02f, 1.228e+02f, 1.232e+02f, 2.476e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.210e+01f, 3.220e+01f, 6.470e+01f, 3.250e+01f, 3.260e+01f, 6.550e+01f, 6.620e+01f, 6.640e+01f, 1.334e+02f, - 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.370e+01f, 3.380e+01f, 6.790e+01f, 3.410e+01f, 3.420e+01f, 6.870e+01f, 6.940e+01f, 6.960e+01f, 1.398e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 7.220e+01f, 7.240e+01f, 1.454e+02f, 7.300e+01f, 7.320e+01f, 1.470e+02f, 1.484e+02f, 1.488e+02f, 2.988e+02f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 1, pH = 1, pW = 1, dD = 1, dH = 1, dW = 1; + int oD = 4, oH = 4, oW = 4; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iD, iH, iW}, + {0.000e+00f, 0.000e+00f, 0.000e+00f, 1.000e-01f, 2.000e-01f, 7.000e-01f, + 5.000e-01f, 6.000e-01f, 1.500e+00f, 2.200e+00f, 2.400e+00f, 5.400e+00f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.700e+00f, 1.800e+00f, 3.900e+00f, + 2.100e+00f, 2.200e+00f, 4.700e+00f, 5.400e+00f, 5.600e+00f, 1.180e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.200e+00f, 8.400e+00f, 1.740e+01f, + 9.000e+00f, 9.200e+00f, 1.900e+01f, 2.040e+01f, 2.080e+01f, 4.280e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 6.500e+00f, 6.600e+00f, 1.350e+01f, + 6.900e+00f, 7.000e+00f, 1.430e+01f, 1.500e+01f, 1.520e+01f, 3.100e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.100e+00f, 8.200e+00f, 1.670e+01f, + 8.500e+00f, 8.600e+00f, 1.750e+01f, 1.820e+01f, 1.840e+01f, 3.740e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.100e+01f, 2.120e+01f, 4.300e+01f, + 2.180e+01f, 2.200e+01f, 4.460e+01f, 4.600e+01f, 4.640e+01f, 9.400e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.290e+01f, 1.300e+01f, 2.630e+01f, + 1.330e+01f, 1.340e+01f, 2.710e+01f, 2.780e+01f, 2.800e+01f, 5.660e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.450e+01f, 1.460e+01f, 2.950e+01f, + 1.490e+01f, 1.500e+01f, 3.030e+01f, 3.100e+01f, 3.120e+01f, 6.300e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.380e+01f, 3.400e+01f, 6.860e+01f, + 3.460e+01f, 3.480e+01f, 7.020e+01f, 7.160e+01f, 7.200e+01f, 1.452e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.930e+01f, 1.940e+01f, 3.910e+01f, + 1.970e+01f, 1.980e+01f, 3.990e+01f, 4.060e+01f, 4.080e+01f, 8.220e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.090e+01f, 2.100e+01f, 4.230e+01f, + 2.130e+01f, 2.140e+01f, 4.310e+01f, 4.380e+01f, 4.400e+01f, 8.860e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 4.660e+01f, 4.680e+01f, 9.420e+01f, + 4.740e+01f, 4.760e+01f, 9.580e+01f, 9.720e+01f, 9.760e+01f, 1.964e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.570e+01f, 2.580e+01f, 5.190e+01f, + 2.610e+01f, 2.620e+01f, 5.270e+01f, 5.340e+01f, 5.360e+01f, 1.078e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.730e+01f, 2.740e+01f, 5.510e+01f, + 2.770e+01f, 2.780e+01f, 5.590e+01f, 5.660e+01f, 5.680e+01f, 1.142e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 5.940e+01f, 5.960e+01f, 1.198e+02f, + 6.020e+01f, 6.040e+01f, 1.214e+02f, 1.228e+02f, 1.232e+02f, 2.476e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.210e+01f, 3.220e+01f, 6.470e+01f, + 3.250e+01f, 3.260e+01f, 6.550e+01f, 6.620e+01f, 6.640e+01f, 1.334e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.370e+01f, 3.380e+01f, 6.790e+01f, + 3.410e+01f, 3.420e+01f, 6.870e+01f, 6.940e+01f, 6.960e+01f, 1.398e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 7.220e+01f, 7.240e+01f, 1.454e+02f, + 7.300e+01f, 7.320e+01f, 1.470e+02f, 1.484e+02f, 1.488e+02f, 2.988e+02f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test3) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, { 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, - 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, 24.6f, 0.f, 0.f, 0.f, 12.8f, 13.f, 13.2f, 27.4f, 27.8f, 28.2f, 0.f, 0.f, 0.f, 31.f, 31.4f, 31.8f, 65.6f, 66.39999f, 67.2f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, - 0.f, 0.f, 0.f, 11.8f, 11.9f, 12.f, 24.5f, 24.7f, 24.9f, 0.f, 0.f, 0.f, 26.3f, 26.5f, 26.7f, 54.4f, 54.8f, 55.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 34.4f, 34.6f, 34.8f, 70.6f, 71.f, 71.4f, 0.f, 0.f, 0.f, 74.2f, 74.6f, 75.f, 152.f, 152.8f, 153.6f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, + 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, + 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, 24.6f, + 0.f, 0.f, 0.f, 12.8f, 13.f, 13.2f, 27.4f, 27.8f, 28.2f, + 0.f, 0.f, 0.f, 31.f, 31.4f, 31.8f, 65.6f, 66.39999f, 67.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, + 0.f, 0.f, 0.f, 11.8f, 11.9f, 12.f, 24.5f, 24.7f, 24.9f, + 0.f, 0.f, 0.f, 26.3f, 26.5f, 26.7f, 54.4f, 54.8f, 55.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, + 0.f, 0.f, 0.f, 34.4f, 34.6f, 34.8f, 70.6f, 71.f, 71.4f, + 0.f, 0.f, 0.f, 74.2f, 74.6f, 75.f, 152.f, 152.8f, 153.6f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test4) { - - int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; - int oD=3,oH=4,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 5.7f, 6.f, 6.3f, - 14.1f, 14.7f, 15.3f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, - 24.6f, 0.f, 0.f, 0.f, 43.8f, 44.4f, 45.f, 93.f, 94.2f, 95.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, 0.f, 0.f, 0.f, 38.1f, 38.4f, 38.7f, 78.9f, 79.5f, 80.1f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 108.6f, 109.2f, 109.8f, 222.6f, 223.8f, 225.f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool3dnew_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 3, iH = 4, iW = 3, iC = 3, kD = 2, kH = 3, kW = 2, sD = 1, + sH = 1, sW = 1, pD = 0, pH = 0, pW = 0, dD = 1, dH = 1, dW = 1; + int oD = 3, oH = 4, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iD, iH, iW, iC}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, + 0.f, 0.f, 0.f, 5.7f, 6.f, 6.3f, 14.1f, 14.7f, 15.3f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, 24.6f, + 0.f, 0.f, 0.f, 43.8f, 44.4f, 45.f, 93.f, 94.2f, 95.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, + 0.f, 0.f, 0.f, 38.1f, 38.4f, 38.7f, 78.9f, 79.5f, 80.1f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, + 0.f, 0.f, 0.f, 108.6f, 109.2f, 109.8f, 222.6f, 223.8f, 225.f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, + {kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, + paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_bp_1) { - - auto input = NDArrayFactory::create('c', {bS,iD,iH,iW}); - auto epsilon = NDArrayFactory::create('c', {bS,iD,oH,oW}); - auto exp = NDArrayFactory::create('c', {bS,iD,iH,iW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, epsilon); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - block->fillInputs({-2}); - block->appendI({kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::maxpool2d_bp bp; - Nd4jStatus status = bp.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(*result)); - - delete variableSpace; - delete block; + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto epsilon = NDArrayFactory::create('c', {bS, iD, oH, oW}); + auto exp = NDArrayFactory::create('c', {bS, iD, iH, iW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, epsilon); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->fillInputs({-2}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dW, dH, 0, 0, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d_bp bp; + Nd4jStatus status = bp.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_bp_2) { - - int bS=2, iD=1, iH=4,iW=4, oD=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; - int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; - - // TypeParam epsilonBuff[] = {6., 7., 8., 10., 11., 12., 14., 15., 16., 22., 23., 24., 26., 27., 28., 30., 31., 32.}; - // TypeParam expectedBuff[] = {0., 0., 0., 0.,0., 6., 7., 8.,0.,10.,11.,12.,0.,14.,15.,16.,0., 0., 0., 0.,0.,22.,23.,24.,0.,26.,27.,28.,0.,30.,31.,32.}; - - NDArray input('c', {bS,iD,iH,iW}); - NDArray epsilon('c', {bS,iD,oH,oW}, {6., 7., 8., 10., 11., 12., 14., 15., 16., 22., 23., 24., 26., 27., 28., 30., 31., 32.}); - NDArray expected('c', {bS,iD,iH,iW}, {0., 0., 0., 0.,0., 6., 7., 8.,0.,10.,11.,12.,0.,14.,15.,16.,0., 0., 0., 0.,0.,22.,23.,24.,0.,26.,27.,28.,0.,30.,31.,32.}); - - - input.linspace(1.); - - std::initializer_list argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::maxpool2d_bp op; - auto results = op.evaluate({&input, &epsilon}, {}, argI); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 1, iH = 4, iW = 4, oD = 3, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; + int oW = (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; + + // TypeParam epsilonBuff[] = + // {6., 7., 8., 10., 11., 12., 14., 15., 16., 22., 23., 24., 26., 27., 28., 30., + // 31., 32.}; TypeParam expectedBuff[] = {0., 0., 0., + // 0.,0., 6., 7., 8.,0.,10.,11.,12.,0.,14.,15.,16.,0., 0., 0., + // 0.,0.,22.,23.,24.,0.,26.,27.,28.,0.,30.,31.,32.}; + + NDArray input('c', {bS, iD, iH, iW}); + NDArray epsilon('c', {bS, iD, oH, oW}, + {6., 7., 8., 10., 11., 12., 14., 15., 16., 22., 23., 24., 26., + 27., 28., 30., 31., 32.}); + NDArray expected('c', {bS, iD, iH, iW}, + {0., 0., 0., 0., 0., 6., 7., 8., 0., 10., 11., + 12., 0., 14., 15., 16., 0., 0., 0., 0., 0., 22., + 23., 24., 0., 26., 27., 28., 0., 30., 31., 32.}); + + input.linspace(1.); + + std::initializer_list argI = { + kH, kW, sH, sW, pH, pW, + dW, dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride + // Height/Width; 4,5 - pad Height/Width; 6,7 - + // dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d_bp op; + auto results = op.evaluate({&input, &epsilon}, {}, argI); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_3) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_4) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=1,pW=1, dH=1,dW=1; - int oH=4,oW=4; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.1f, 0.2f, 0.7f, 0.5f, 0.6f, 1.5f, 2.2f, 2.4f, 5.4f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 3.9f, 2.1f, 2.2f, 4.7f, 5.4f, 5.6f, 11.8f, - 0.f, 0.f, 0.f, 3.3f, 3.4f, 7.1f, 3.7f, 3.8f, 7.9f, 8.6f, 8.8f, 18.2f, 0.f, 0.f, 0.f, 4.9f, 5.f, 10.3f, 5.3f, 5.4f, 11.1f, 11.8f, 12.f, 24.6f, - 0.f, 0.f, 0.f, 6.5f, 6.6f, 13.5f, 6.9f, 7.f, 14.3f, 15.f, 15.2f, 31.f, 0.f, 0.f, 0.f, 8.1f, 8.2f, 16.7f, 8.5f, 8.6f, 17.5f, 18.2f, 18.4f, 37.4f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 1, + pW = 1, dH = 1, dW = 1; + int oH = 4, oW = 4; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.f, 0.f, 0.f, 0.1f, 0.2f, 0.7f, 0.5f, 0.6f, 1.5f, 2.2f, 2.4f, + 5.4f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 3.9f, 2.1f, 2.2f, 4.7f, 5.4f, + 5.6f, 11.8f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 7.1f, 3.7f, 3.8f, 7.9f, + 8.6f, 8.8f, 18.2f, 0.f, 0.f, 0.f, 4.9f, 5.f, 10.3f, 5.3f, 5.4f, + 11.1f, 11.8f, 12.f, 24.6f, 0.f, 0.f, 0.f, 6.5f, 6.6f, 13.5f, 6.9f, + 7.f, 14.3f, 15.f, 15.2f, 31.f, 0.f, 0.f, 0.f, 8.1f, 8.2f, 16.7f, + 8.5f, 8.6f, 17.5f, 18.2f, 18.4f, 37.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_5) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, - 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 3.9f, 8.3f, 8.5f, 8.7f, - 0.f, 0.f, 0.f, 4.6f, 4.7f, 4.8f, 10.1f, 10.3f, 10.5f, 0.f, 0.f, 0.f, 11.9f, 12.1f, 12.3f, 25.6f, 26.f, 26.4f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 1.f, + 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, + 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 3.9f, 8.3f, 8.5f, 8.7f, 0.f, + 0.f, 0.f, 4.6f, 4.7f, 4.8f, 10.1f, 10.3f, 10.5f, 0.f, 0.f, 0.f, + 11.9f, 12.1f, 12.3f, 25.6f, 26.f, 26.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_6) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, - 0.f, 0.f, 0.f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 0.f, 0.f, 0.f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f, 0.f, 0.f, 0.f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 1.5f, + 1.6f, 1.7f, 1.8f, 0.f, 0.f, 0.f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, maxpool2d_bp_7) { - - int bS=2, iH=56,iW=56, iC=3, kH=2,kW=2, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1; - int oH=28,oW=28; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::maxpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - // auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - // ASSERT_TRUE(expected.isSameShape(output)); - // ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 56, iW = 56, iC = 3, kH = 2, kW = 2, sH = 2, sW = 2, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 28, oW = 28; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + // auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(expected.isSameShape(output)); + // ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, avgpool2d_bp_1) { - - auto input = NDArrayFactory::create('c', {bS,iD,iH,iW}); - auto epsilon = NDArrayFactory::create('c', {bS,iD,oH,oW}); - auto exp = NDArrayFactory::create('c', {bS,iD,iH,iW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, epsilon); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - block->fillInputs({-2}); - block->appendI({kH,kW, sH,sW, pH,pW, dW,dH, 0, 1, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode, 9 - extraParam0 (unnecessary for avg mode), 10 - data format - - sd::ops::avgpool2d_bp bp; - Nd4jStatus status = bp.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(*result)); - - delete variableSpace; - delete block; + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto epsilon = NDArrayFactory::create('c', {bS, iD, oH, oW}); + auto exp = NDArrayFactory::create('c', {bS, iD, iH, iW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, epsilon); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->fillInputs({-2}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dW, dH, 0, 1, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode, 9 - + // extraParam0 (unnecessary for avg mode), 10 - data format + + sd::ops::avgpool2d_bp bp; + Nd4jStatus status = bp.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_2) { - - int bS=2, iD=1, iH=4,iW=4, oD=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; - int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; - - // TypeParam epsilonBuff[] = {3.5 , 4.5 , 5.5, 7.5 , 8.5 , 9.5, 11.5, 12.5, 13.5, 19.5, 20.5, 21.5, 23.5, 24.5, 25.5, 27.5, 28.5, 29.5}; - // TypeParam expectedBuff[] = {0.875, 2., 2.5,1.375, 2.75 , 6., 7., 3.75, 4.75 ,10., 11., 5.75, 2.875, 6., 6.5, 3.375, 4.875, 10.,10.5, 5.375, 10.75, 22.,23., 11.75, 12.75, 26.,27., 13.75, 6.875, 14.,14.5, 7.375}; - - auto input = NDArrayFactory::create('c', {bS,iD,iH,iW}); - auto epsilon = NDArrayFactory::create('c', {bS,iD,oH,oW}, {3.5f, 4.5f, 5.5f, 7.5f, 8.5f, 9.5f, 11.5f, 12.5f, 13.5f, 19.5f, 20.5f, 21.5f, 23.5f, 24.5f, 25.5f, 27.5f, 28.5f, 29.5f}); - auto expected = NDArrayFactory::create('c', {bS,iD,iH,iW}, {0.875f, 2.f, 2.5f, 1.375f, 2.75f, 6.f, 7.f, 3.75f, 4.75f, 10.f, 11.f, 5.75f, 2.875f, 6.f, 6.5f, 3.375f, 4.875f, 10.f, 10.5f, 5.375f, 10.75f, 22.f, 23.f, 11.75f, 12.75f, 26.f, 27.f, 13.75f, 6.875f, 14.f, 14.5f, 7.375f}); - - input.linspace(1.); - - std::initializer_list argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 1, 0}; - - sd::ops::avgpool2d_bp op; - auto results = op.evaluate({&input, &epsilon}, {}, argI); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iD = 1, iH = 4, iW = 4, oD = 3, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; + int oW = (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; + + // TypeParam epsilonBuff[] = {3.5 , 4.5 , 5.5, 7.5 , 8.5 + // , 9.5, 11.5, 12.5, 13.5, 19.5, 20.5, 21.5, 23.5, 24.5, 25.5, 27.5, 28.5, 29.5}; + // TypeParam expectedBuff[] = {0.875, 2., 2.5,1.375, 2.75 + // , 6., 7., 3.75, 4.75 + // ,10., 11., 5.75, 2.875, 6., 6.5, 3.375, 4.875, 10.,10.5, 5.375, 10.75, 22.,23., + // 11.75, 12.75, 26.,27., 13.75, 6.875, 14.,14.5, 7.375}; + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto epsilon = NDArrayFactory::create( + 'c', {bS, iD, oH, oW}, + {3.5f, 4.5f, 5.5f, 7.5f, 8.5f, 9.5f, 11.5f, 12.5f, 13.5f, 19.5f, 20.5f, + 21.5f, 23.5f, 24.5f, 25.5f, 27.5f, 28.5f, 29.5f}); + auto expected = NDArrayFactory::create( + 'c', {bS, iD, iH, iW}, + {0.875f, 2.f, 2.5f, 1.375f, 2.75f, 6.f, 7.f, 3.75f, + 4.75f, 10.f, 11.f, 5.75f, 2.875f, 6.f, 6.5f, 3.375f, + 4.875f, 10.f, 10.5f, 5.375f, 10.75f, 22.f, 23.f, 11.75f, + 12.75f, 26.f, 27.f, 13.75f, 6.875f, 14.f, 14.5f, 7.375f}); + + input.linspace(1.); + + std::initializer_list argI = {kH, kW, sH, sW, pH, pW, + dW, dH, 1, 1, 0}; + + sd::ops::avgpool2d_bp op; + auto results = op.evaluate({&input, &epsilon}, {}, argI); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_3) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.016667f, 0.05f, 0.033333f, 0.066667f, 0.166667f, 0.1f, 0.066667f, 0.166667f, 0.1f, 0.05f, 0.116667f, 0.066667f, - 0.083333f, 0.183333f, 0.1f, 0.2f, 0.433333f, 0.233333f, 0.2f, 0.433333f, 0.233333f, 0.116667f, 0.25f, 0.133333f, - 0.15f, 0.316667f, 0.166667f, 0.333333f, 0.7f, 0.366667f, 0.333333f, 0.7f, 0.366667f, 0.183333f, 0.383333f, 0.2f, - 0.216667f, 0.45f, 0.233333f, 0.466667f, 0.966667f, 0.5f, 0.466667f, 0.966667f, 0.5f, 0.25f, 0.516667f, 0.266667f, - 0.283333f, 0.583333f, 0.3f, 0.6f, 1.233333f, 0.633333f, 0.6f, 1.233333f, 0.633333f, 0.316667f, 0.65f, 0.333333f, - 0.35f, 0.716667f, 0.366667f, 0.733333f, 1.5f, 0.766667f, 0.733333f, 1.5f, 0.766667f, 0.383333f, 0.783333f, 0.4f }); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::avgpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.016667f, 0.05f, 0.033333f, 0.066667f, 0.166667f, 0.1f, + 0.066667f, 0.166667f, 0.1f, 0.05f, 0.116667f, 0.066667f, + 0.083333f, 0.183333f, 0.1f, 0.2f, 0.433333f, 0.233333f, + 0.2f, 0.433333f, 0.233333f, 0.116667f, 0.25f, 0.133333f, + 0.15f, 0.316667f, 0.166667f, 0.333333f, 0.7f, 0.366667f, + 0.333333f, 0.7f, 0.366667f, 0.183333f, 0.383333f, 0.2f, + 0.216667f, 0.45f, 0.233333f, 0.466667f, 0.966667f, 0.5f, + 0.466667f, 0.966667f, 0.5f, 0.25f, 0.516667f, 0.266667f, + 0.283333f, 0.583333f, 0.3f, 0.6f, 1.233333f, 0.633333f, + 0.6f, 1.233333f, 0.633333f, 0.316667f, 0.65f, 0.333333f, + 0.35f, 0.716667f, 0.366667f, 0.733333f, 1.5f, 0.766667f, + 0.733333f, 1.5f, 0.766667f, 0.383333f, 0.783333f, 0.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::avgpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_4) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=1,pW=1, dH=1,dW=1; - int oH=4,oW=4; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.233333f, 0.3f, 0.366667f, 0.55f, 0.65f, 0.75f, 0.95f, 1.05f, 1.15f, 0.766667f, 0.833333f, 0.9f, - 1.3f, 1.366667f, 1.433333f, 2.15f, 2.25f, 2.35f, 2.55f, 2.65f, 2.75f, 1.833333f, 1.9f, 1.966667f, - 2.366667f, 2.433333f, 2.5f, 3.75f, 3.85f, 3.95f, 4.15f, 4.25f, 4.35f, 2.9f, 2.966667f, 3.033333f, - 3.433333f, 3.5f, 3.566667f, 5.35f, 5.45f, 5.55f, 5.75f, 5.85f, 5.95f, 3.966667f, 4.033333f, 4.1f, - 4.5f, 4.566667f, 4.633333f, 6.95f, 7.05f, 7.15f, 7.35f, 7.45f, 7.55f, 5.033333f, 5.1f, 5.166667f, - 5.566667f, 5.633333f, 5.7f, 8.549999f, 8.65f, 8.75f, 8.95f, 9.05f, 9.150001f, 6.1f, 6.166667f, 6.233334f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::avgpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 1, + pW = 1, dH = 1, dW = 1; + int oH = 4, oW = 4; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.233333f, 0.3f, 0.366667f, 0.55f, 0.65f, 0.75f, + 0.95f, 1.05f, 1.15f, 0.766667f, 0.833333f, 0.9f, + 1.3f, 1.366667f, 1.433333f, 2.15f, 2.25f, 2.35f, + 2.55f, 2.65f, 2.75f, 1.833333f, 1.9f, 1.966667f, + 2.366667f, 2.433333f, 2.5f, 3.75f, 3.85f, 3.95f, + 4.15f, 4.25f, 4.35f, 2.9f, 2.966667f, 3.033333f, + 3.433333f, 3.5f, 3.566667f, 5.35f, 5.45f, 5.55f, + 5.75f, 5.85f, 5.95f, 3.966667f, 4.033333f, 4.1f, + 4.5f, 4.566667f, 4.633333f, 6.95f, 7.05f, 7.15f, + 7.35f, 7.45f, 7.55f, 5.033333f, 5.1f, 5.166667f, + 5.566667f, 5.633333f, 5.7f, 8.549999f, 8.65f, 8.75f, + 8.95f, 9.05f, 9.150001f, 6.1f, 6.166667f, 6.233334f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::avgpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_5) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.19167f, 0.23333f, 0.275f, 0.50833f, 0.59167f, 0.675f, 1.2f, 1.325f, 1.45f, 0.50833f, 0.56667f, 0.625f, 1.19167f, 1.30833f, 1.425f, 2.4f, 2.575f, 2.75f, - 1.18333f, 1.24167f, 1.3f, 2.54167f, 2.65833f, 2.775f, 4.425f, 4.6f, 4.775f, 1.01667f, 1.05833f, 1.1f, 2.15833f, 2.24167f, 2.325f, 3.675f, 3.8f, 3.925f, - 1.69167f, 1.73333f, 1.775f, 3.50833f, 3.59167f, 3.675f, 5.7f, 5.825f, 5.95f, 2.60833f, 2.66667f, 2.725f, 5.39167f, 5.50833f, 5.625f, 8.7f, 8.875f, 9.05f, - 3.28333f, 3.34167f, 3.4f, 6.74167f, 6.85833f, 6.975f, 10.725f, 10.9f, 11.075f, 2.51667f, 2.55833f, 2.6f, 5.15833f, 5.24167f, 5.325f, 8.175f, 8.3f, 8.425f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::avgpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {0.19167f, 0.23333f, 0.275f, 0.50833f, 0.59167f, 0.675f, 1.2f, + 1.325f, 1.45f, 0.50833f, 0.56667f, 0.625f, 1.19167f, 1.30833f, + 1.425f, 2.4f, 2.575f, 2.75f, 1.18333f, 1.24167f, 1.3f, + 2.54167f, 2.65833f, 2.775f, 4.425f, 4.6f, 4.775f, 1.01667f, + 1.05833f, 1.1f, 2.15833f, 2.24167f, 2.325f, 3.675f, 3.8f, + 3.925f, 1.69167f, 1.73333f, 1.775f, 3.50833f, 3.59167f, 3.675f, + 5.7f, 5.825f, 5.95f, 2.60833f, 2.66667f, 2.725f, 5.39167f, + 5.50833f, 5.625f, 8.7f, 8.875f, 9.05f, 3.28333f, 3.34167f, + 3.4f, 6.74167f, 6.85833f, 6.975f, 10.725f, 10.9f, 11.075f, + 2.51667f, 2.55833f, 2.6f, 5.15833f, 5.24167f, 5.325f, 8.175f, + 8.3f, 8.425f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::avgpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 0, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_6) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); - auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.01667f, 0.03333f, 0.05f, 0.08333f, 0.11667f, 0.15f, 0.06667f, 0.08333f, 0.1f, 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, - 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, 0.11667f, 0.13333f, 0.15f, 0.28333f, 0.31667f, 0.35f, 0.16667f, 0.18333f, 0.2f, - 0.21667f, 0.23333f, 0.25f, 0.48333f, 0.51667f, 0.55f, 0.26667f, 0.28333f, 0.3f, 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, - 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, 0.31667f, 0.33333f, 0.35f, 0.68333f, 0.71667f, 0.75f, 0.36667f, 0.38333f, 0.4f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::avgpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, iH, iW, iC}, + {0.01667f, 0.03333f, 0.05f, 0.08333f, 0.11667f, 0.15f, 0.06667f, + 0.08333f, 0.1f, 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, + 0.5f, 0.23333f, 0.26667f, 0.3f, 0.13333f, 0.16667f, 0.2f, + 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, 0.11667f, + 0.13333f, 0.15f, 0.28333f, 0.31667f, 0.35f, 0.16667f, 0.18333f, + 0.2f, 0.21667f, 0.23333f, 0.25f, 0.48333f, 0.51667f, 0.55f, + 0.26667f, 0.28333f, 0.3f, 0.53333f, 0.56667f, 0.6f, 1.16667f, + 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, 0.53333f, 0.56667f, + 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, + 0.31667f, 0.33333f, 0.35f, 0.68333f, 0.71667f, 0.75f, 0.36667f, + 0.38333f, 0.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::avgpool2d_bp op; + auto results = + op.evaluate({&input, &gradO}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, pnormpool2d_bp_1) { - - auto input = NDArrayFactory::create('c', {bS,iD,iH,iW}); - auto epsilon = NDArrayFactory::create('c', {bS,iD,oH,oW}); - auto exp = NDArrayFactory::create('c', {bS,iD,iH,iW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, epsilon); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - block->fillInputs({-2}); - block->appendI({kH,kW, sH,sW, pH,pW, dW,dH, 0, 3}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; 9 - divisor - block->appendT(0.000001); - - sd::ops::pnormpool2d_bp bp; - Nd4jStatus status = bp.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(*result)); - - delete variableSpace; - delete block; + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto epsilon = NDArrayFactory::create('c', {bS, iD, oH, oW}); + auto exp = NDArrayFactory::create('c', {bS, iD, iH, iW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, epsilon); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->fillInputs({-2}); + block->appendI({kH, kW, sH, sW, pH, pW, dW, dH, 0, + 3}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; + // 4,5 - pad Height/Width; 6,7 - dilation Height/Width; + // 8 - same mode; 9 - divisor + block->appendT(0.000001); + + sd::ops::pnormpool2d_bp bp; + Nd4jStatus status = bp.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_2) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int pnorm = 3; - double eps = 0.; - - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {9.661570e-04f, 9.671602e-03f, 1.306569e-02f, 3.679184e-02f, 1.297220e-01f, 1.040181e-01f, 1.126750e-01f, 3.320884e-01f, 2.340406e-01f, 1.333333e-01f, 3.352886e-01f, 2.070211e-01f, - 8.991618e-02f, 2.160601e-01f, 1.283173e-01f, 2.744226e-01f, 6.364498e-01f, 3.662123e-01f, 3.869788e-01f, 8.808994e-01f, 4.984556e-01f, 2.613189e-01f, 5.818475e-01f, 3.225517e-01f, - 2.065654e-01f, 4.553546e-01f, 2.501175e-01f, 5.190718e-01f, 1.131343e+00f, 6.148388e-01f, 6.362602e-01f, 1.377521e+00f, 7.439550e-01f, 3.833026e-01f, 8.227519e-01f, 4.407146e-01f, - 3.261206e-01f, 6.969233e-01f, 3.717564e-01f, 7.627507e-01f, 1.620991e+00f, 8.600952e-01f, 8.814538e-01f, 1.866888e+00f, 9.873542e-01f, 5.046682e-01f, 1.064004e+00f, 5.602558e-01f, - 4.464697e-01f, 9.389536e-01f, 4.932274e-01f, 1.005908e+00f, 2.108550e+00f, 1.104095e+00f, 1.125322e+00f, 2.354009e+00f, 1.230180e+00f, 6.258913e-01f, 1.305581e+00f, 6.804127e-01f, - 5.671396e-01f, 1.181128e+00f, 6.145977e-01f, 1.248783e+00f, 2.595083e+00f, 1.347494e+00f, 1.368600e+00f, 2.840157e+00f, 1.472778e+00f, 7.470673e-01f, 1.547362e+00f, 8.008900e-01f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::pnormpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {eps}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, pnorm, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int pnorm = 3; + double eps = 0.; + + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {9.661570e-04f, 9.671602e-03f, 1.306569e-02f, 3.679184e-02f, + 1.297220e-01f, 1.040181e-01f, 1.126750e-01f, 3.320884e-01f, + 2.340406e-01f, 1.333333e-01f, 3.352886e-01f, 2.070211e-01f, + 8.991618e-02f, 2.160601e-01f, 1.283173e-01f, 2.744226e-01f, + 6.364498e-01f, 3.662123e-01f, 3.869788e-01f, 8.808994e-01f, + 4.984556e-01f, 2.613189e-01f, 5.818475e-01f, 3.225517e-01f, + 2.065654e-01f, 4.553546e-01f, 2.501175e-01f, 5.190718e-01f, + 1.131343e+00f, 6.148388e-01f, 6.362602e-01f, 1.377521e+00f, + 7.439550e-01f, 3.833026e-01f, 8.227519e-01f, 4.407146e-01f, + 3.261206e-01f, 6.969233e-01f, 3.717564e-01f, 7.627507e-01f, + 1.620991e+00f, 8.600952e-01f, 8.814538e-01f, 1.866888e+00f, + 9.873542e-01f, 5.046682e-01f, 1.064004e+00f, 5.602558e-01f, + 4.464697e-01f, 9.389536e-01f, 4.932274e-01f, 1.005908e+00f, + 2.108550e+00f, 1.104095e+00f, 1.125322e+00f, 2.354009e+00f, + 1.230180e+00f, 6.258913e-01f, 1.305581e+00f, 6.804127e-01f, + 5.671396e-01f, 1.181128e+00f, 6.145977e-01f, 1.248783e+00f, + 2.595083e+00f, 1.347494e+00f, 1.368600e+00f, 2.840157e+00f, + 1.472778e+00f, 7.470673e-01f, 1.547362e+00f, 8.008900e-01f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::pnormpool2d_bp op; + auto results = op.evaluate( + {&input, &gradO}, {eps}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, pnorm, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_3) { - - int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int pnorm = 2; - double eps = 0.; - - int paddingMode = 0; // 1-SAME, 0-VALID - int dataFormat = 0; // 1-NDHWC, 0-NCDHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); - auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.007931f, 0.042891f, 0.040544f, 0.09369f, 0.276841f, 0.191675f, 0.163957f, 0.442946f, 0.287512f, 0.154919f, 0.373153f, 0.221172f, - 0.15901f, 0.365232f, 0.207846f, 0.428282f, 0.959455f, 0.534076f, 0.508585f, 1.128771f, 0.623089f, 0.319794f, 0.698063f, 0.379547f, - 0.321068f, 0.692438f, 0.372316f, 0.757521f, 1.620323f, 0.864566f, 0.838684f, 1.787943f, 0.951023f, 0.483194f, 1.023434f, 0.541058f, - 0.483937f, 1.019414f, 0.536145f, 1.085348f, 2.276996f, 1.192917f, 1.166749f, 2.443606f, 1.278126f, 0.646499f, 1.349361f, 0.703463f, - 0.647021f, 1.346249f, 0.699745f, 1.412654f, 2.932174f, 1.520512f, 1.494153f, 3.098146f, 1.604985f, 0.809791f, 1.675544f, 0.866229f, - 0.810192f, 1.673009f, 0.863237f, 1.739711f, 3.58665f, 1.847753f, 1.82126f, 3.752188f, 1.931741f, 0.973081f, 2.001861f, 1.029173f}); - input.linspace(1.); - gradO.linspace(0.1, 0.1); - - sd::ops::pnormpool2d_bp op; - auto results = op.evaluate({&input, &gradO}, {eps}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, pnorm, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 3, kH = 3, kW = 2, sH = 1, sW = 1, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int pnorm = 2; + double eps = 0.; + + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.007931f, 0.042891f, 0.040544f, 0.09369f, 0.276841f, 0.191675f, + 0.163957f, 0.442946f, 0.287512f, 0.154919f, 0.373153f, 0.221172f, + 0.15901f, 0.365232f, 0.207846f, 0.428282f, 0.959455f, 0.534076f, + 0.508585f, 1.128771f, 0.623089f, 0.319794f, 0.698063f, 0.379547f, + 0.321068f, 0.692438f, 0.372316f, 0.757521f, 1.620323f, 0.864566f, + 0.838684f, 1.787943f, 0.951023f, 0.483194f, 1.023434f, 0.541058f, + 0.483937f, 1.019414f, 0.536145f, 1.085348f, 2.276996f, 1.192917f, + 1.166749f, 2.443606f, 1.278126f, 0.646499f, 1.349361f, 0.703463f, + 0.647021f, 1.346249f, 0.699745f, 1.412654f, 2.932174f, 1.520512f, + 1.494153f, 3.098146f, 1.604985f, 0.809791f, 1.675544f, 0.866229f, + 0.810192f, 1.673009f, 0.863237f, 1.739711f, 3.58665f, 1.847753f, + 1.82126f, 3.752188f, 1.931741f, 0.973081f, 2.001861f, 1.029173f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::pnormpool2d_bp op; + auto results = op.evaluate( + {&input, &gradO}, {eps}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, pnorm, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, upsampling2d_bp_1) { + const int bS = 1, iH = 2, iW = 2, iC = 1; + const int factorH = 2, factorW = 2; + const int isNCHW = 1; // data format, default is NCHW - const int bS=1, iH=2,iW=2, iC=1; - const int factorH=2, factorW=2; - const int isNCHW = 1; // data format, default is NCHW + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = + NDArrayFactory::create('c', {bS, iC, iH * factorH, iW * factorW}); + gradO = 1.; - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto gradO = NDArrayFactory::create('c', {bS, iC, iH*factorH, iW*factorW}); - gradO = 1.; + auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW}); + expGradI = 4.; - auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW}); - expGradI = 4.; - - sd::ops::upsampling2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); + sd::ops::upsampling2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, upsampling2d_bp_2) { + const int bS = 1, iH = 2, iW = 2, iC = 1; + const int factorH = 2, factorW = 2; + const int isNCHW = 0; // data format, default is NCHW - const int bS=1, iH=2,iW=2, iC=1; - const int factorH=2, factorW=2; - const int isNCHW = 0; // data format, default is NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto gradO = NDArrayFactory::create('c', {bS, iH*factorH, iW*factorW, iC}); - gradO = 1.; + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = + NDArrayFactory::create('c', {bS, iH * factorH, iW * factorW, iC}); + gradO = 1.; - auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC}); - expGradI = 4.; + auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC}); + expGradI = 4.; - sd::ops::upsampling2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); + sd::ops::upsampling2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, upsampling2d_bp_3) { - - const int bS=1, iH=3,iW=3, iC=2; - const int factorH=2, factorW=2; - const int isNCHW = 1; // data format, default is NCHW - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - - NDArray gradO('c', {bS, iC, iH*factorH, iW*factorW}, {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, 0.31069338, 0.44793984, - 0.93800974, 0.32667395, 0.15187258, 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, 0.14696825, 0.26089668, 0.13505761, - 0.7562093, 0.27545404, 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, 0.31279507, 0.13591796, 0.5175439, 0.32870287, - 0.061735712, 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, 0.7215636, 0.40449402, 0.29908907, 0.4038839, 0.9883108, - 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, 0.19694561, 0.6994972, - 0.0743224, 0.42042503, 0.5842631, 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, 0.8759956, 0.5698191, 0.4458631, 0.5277549, - 0.016646361, 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, 0.3326449, 0.11739397}, sd::DataType::FLOAT32); - - NDArray expGradI('c', {bS, iC, iH, iW}, {2.4203868, 1.5216494, 2.1776323, 2.0290341, 0.772146, 1.5008594, 1.0523045, 1.3174672, 1.9263644, - 1.090545, 1.9094483, 1.3611296, 2.1195147, 2.0659215, 1.0423062, 2.3405795, 1.9105877, 1.2203633}, sd::DataType::FLOAT32); - - sd::ops::upsampling2d_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - + const int bS = 1, iH = 3, iW = 3, iC = 2; + const int factorH = 2, factorW = 2; + const int isNCHW = 1; // data format, default is NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + + NDArray gradO('c', {bS, iC, iH * factorH, iW * factorW}, + {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, + 0.31069338, 0.44793984, 0.93800974, 0.32667395, 0.15187258, + 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, + 0.14696825, 0.26089668, 0.13505761, 0.7562093, 0.27545404, + 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, + 0.31279507, 0.13591796, 0.5175439, 0.32870287, 0.061735712, + 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, + 0.7215636, 0.40449402, 0.29908907, 0.4038839, 0.9883108, + 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, + 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, + 0.19694561, 0.6994972, 0.0743224, 0.42042503, 0.5842631, + 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, + 0.8759956, 0.5698191, 0.4458631, 0.5277549, 0.016646361, + 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, + 0.3326449, 0.11739397}, + sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iH, iW}, + {2.4203868, 1.5216494, 2.1776323, 2.0290341, 0.772146, 1.5008594, + 1.0523045, 1.3174672, 1.9263644, 1.090545, 1.9094483, 1.3611296, + 2.1195147, 2.0659215, 1.0423062, 2.3405795, 1.9105877, 1.2203633}, + sd::DataType::FLOAT32); + + sd::ops::upsampling2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); + auto gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, depthwise_conv2d_1) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=4,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - - - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, - 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f, - 12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, - 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f}); - input = 2.; - weights.linspace(0.1, 0.1); - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 2, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 4, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, oH, oW, oC}, + {12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, + 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, + 5.4f, 6.f, 6.6f, 7.2f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, + 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, + 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f, 12.f, 12.8f, + 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, + 6.6f, 7.2f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, + 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, + 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f}); + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_2) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - - - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, - 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f}); - input = 2.; - weights.linspace(0.1, 0.1); - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 2, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, oH, oW, oC}, + {13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f}); + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_3) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {mC, iC, kH, kW}); - auto biases = NDArrayFactory::create('c', {iC*mC}, {1.f,2.f,3.f,4.f}); - - NDArray expOutput('c', {bS, oC, oH, oW},{5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8, 5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8}, sd::DataType::FLOAT32); - - input = 2.; - weights.linspace(0.1, 0.1); - weights.permutei({2,3,1,0}); - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 2, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {mC, iC, kH, kW}); + auto biases = + NDArrayFactory::create('c', {iC * mC}, {1.f, 2.f, 3.f, 4.f}); + + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {5.2, 5.2, 5.2, 5.2, 20.6, 20.6, 20.6, 20.6, 14.4, 14.4, 14.4, + 14.4, 29.8, 29.8, 29.8, 29.8, 5.2, 5.2, 5.2, 5.2, 20.6, 20.6, + 20.6, 20.6, 14.4, 14.4, 14.4, 14.4, 29.8, 29.8, 29.8, 29.8}, + sd::DataType::FLOAT32); + + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 1, 0}); + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights, &biases}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_4) { + int bS = 1, iH = 111, iW = 111, iC = 32, mC = 1, kH = 7, kW = 7, sH = 2, + sW = 2, pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 56, oW = 56; - int bS=1, iH=111,iW=111, iC=32,mC=1, kH=7,kW=7, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=56,oW=56; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW + const float unique = -1000000; - const float unique = -1000000; + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + NDArray output('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); + input.linspace(0.1, 0.0001); + weights = 0.5; + output = unique; - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); - NDArray output('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); - input.linspace(0.1, 0.0001); - weights = 0.5; - output = unique; + sd::ops::depthwise_conv2d op; + Nd4jStatus status = + op.execute({&input, &weights}, {&output}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}, {}); - sd::ops::depthwise_conv2d op; - Nd4jStatus status = op.execute({&input, &weights}, {&output} , {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(Status::OK(), status); - - for(Nd4jLong i=output.lengthOf()/1.5; i < output.lengthOf(); ++i) - ASSERT_EQ(output.e(i) != unique, true); + for (Nd4jLong i = output.lengthOf() / 1.5; i < output.lengthOf(); ++i) + ASSERT_EQ(output.e(i) != unique, true); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_5) { - - int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=3,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - - NDArray expOutput('c', {bS, oH, oW, oC}, {10., 12., 14., 16., 8., 9., 22., 24., 26., 28., 14., 15., 14., 15., 16., 17., 8.5, 9.}, sd::DataType::FLOAT32); - - input.linspace(1.); - weights = 0.5; - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); - + int bS = 1, iH = 3, iW = 3, iC = 2, mC = 1, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 3, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + + NDArray expOutput('c', {bS, oH, oW, oC}, + {10., 12., 14., 16., 8., 9., 22., 24., 26., 28., 14., 15., + 14., 15., 16., 17., 8.5, 9.}, + sd::DataType::FLOAT32); + + input.linspace(1.); + weights = 0.5; + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_6) { - - int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=3,oW=3; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oH, oW, oC}, {20., 24.,28., 32.,16., 18.,44., 48.,52., 56.,28., 30.,28., 30.,32., 34.,17., 18.}, sd::DataType::FLOAT32); - input.linspace(1.); - weights = 1.; - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - // output.printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 1, iH = 3, iW = 3, iC = 2, mC = 1, kH = 2, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 3, oW = 3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oH, oW, oC}, + {20., 24., 28., 32., 16., 18., 44., 48., 52., 56., 28., 30., + 28., 30., 32., 34., 17., 18.}, + sd::DataType::FLOAT32); + input.linspace(1.); + weights = 1.; + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + // output.printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_7) { - - int bS=1, iH=3,iW=3, iC=2,mC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, {0.6793503761291504, 0.35508695244789124, 0.842789351940155, 0.20031332969665527, 0.7014986872673035, 0.3106933832168579, - 0.44793984293937683, 0.9380097389221191, 0.3266739547252655, 0.15187257528305054, 0.3833175301551819, 0.7821229696273804, - 0.19880719482898712, 0.7985635995864868, 0.16326339542865753, 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, {0.1308445781469345, 0.6442840099334717, 0.5698848366737366, 0.19896849989891052}, sd::DataType::FLOAT32); - NDArray biases('c', {1,iC*mC}, {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, 0.4270855486392975}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oC, oH, oW}, {0.7012459761288241, 0.6588178652487691, 0.722631079971582, 0.6385665758716108, 0.7041439625563628, 0.6530092074102978, - 0.670967162534851, 0.735090151337225, 0.6551001785478623, 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, 0.5054379267801892, 0.8283436386757472, - 0.5765540302788565, 0.6649797296980537, 0.9807239274294943, 0.586850056971322, 0.261199593183985, 0.3930965634902499, 0.6203697362284615, 0.28794692117826504, - 0.6297390019475202, 0.26769104886224415, 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, 0.4573034071191504, 0.5033536625992294, 0.5827033826425385, - 0.4666419179635315, 0.585974550122895, 0.4595698215161401, 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}, sd::DataType::FLOAT32); - - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 1, iH = 3, iW = 3, iC = 2, mC = 2, kH = 1, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, + {0.6793503761291504, 0.35508695244789124, 0.842789351940155, + 0.20031332969665527, 0.7014986872673035, 0.3106933832168579, + 0.44793984293937683, 0.9380097389221191, 0.3266739547252655, + 0.15187257528305054, 0.3833175301551819, 0.7821229696273804, + 0.19880719482898712, 0.7985635995864868, 0.16326339542865753, + 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}, + sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, + {0.1308445781469345, 0.6442840099334717, 0.5698848366737366, + 0.19896849989891052}, + sd::DataType::FLOAT32); + NDArray biases('c', {1, iC * mC}, + {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, + 0.4270855486392975}, + sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {0.7012459761288241, 0.6588178652487691, 0.722631079971582, + 0.6385665758716108, 0.7041439625563628, 0.6530092074102978, + 0.670967162534851, 0.735090151337225, 0.6551001785478623, + 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, + 0.5054379267801892, 0.8283436386757472, 0.5765540302788565, + 0.6649797296980537, 0.9807239274294943, 0.586850056971322, + 0.261199593183985, 0.3930965634902499, 0.6203697362284615, + 0.28794692117826504, 0.6297390019475202, 0.26769104886224415, + 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, + 0.4573034071191504, 0.5033536625992294, 0.5827033826425385, + 0.4666419179635315, 0.585974550122895, 0.4595698215161401, + 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}, + sd::DataType::FLOAT32); + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights, &biases}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_8) { - - int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=10,oW=10; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oH, oW, oC}, {-42.879997, -43.959999, -44.959999, -45.879997, -46.720005, -47.480003, -48.160000, -48.760002, -43.519997, -45.139999, -46.639996, -48.020000, -49.280003, -50.419998, -51.440006, -52.340000, -31.999998, -33.139999, -34.160000, -35.060001, -35.840004, -36.500004, -37.039997, -37.459999, -20.480000, - -21.139997, -21.680000, -22.100000, -22.399998, -22.579998, -22.639996, -22.580002, -8.960000, -9.139998, -9.200002, -9.140001, -8.960001, -8.660000, -8.240002, -7.700001, 2.560000, 2.860002, 3.279998, 3.820000, 4.480001, 5.260000, 6.160001, 7.180000, 14.080000, 14.860000, 15.759998, 16.779999, 17.920002, 19.180000, 20.560001, 22.059998, - 25.600000, 26.860001, 28.239998, 29.739998, 31.360001, 33.099998, 34.959999, 36.939999, 37.119999, 38.860001, 40.720001, 42.699997, 44.800003, 47.020000, 49.360001, 51.820000, 26.239998, 27.400002, 28.639999, 29.959999, 31.360001, 32.840000, 34.400002, 36.040001, 62.400002, 62.459999, 62.639999, 62.940002, 63.360001, 63.900002, 64.559998, - 65.340004, 106.080002, 106.169998, 106.440002, 106.889999, 107.519997, 108.330002, 109.320000, 110.490005, 114.720001, 115.529999, 116.520004, 117.690002, 119.040009, 120.570000, 122.279999, 124.169998, 123.359985, 124.889999, 126.599998, 128.490005, 130.559998, 132.809998, 135.240005, 137.850006, 132.000000, 134.250000, 136.679993, - 139.290009, 142.080002, 145.049988, 148.199997, 151.529999, 140.639999, 143.610001, 146.760010, 150.089996, 153.600006, 157.290009, 161.160004, 165.209991, 149.279999, 152.970001, 156.839996, 160.889999, 165.120010, 169.529999, 174.119995, 178.889999, 157.919998, 162.330002, 166.919983, 171.690002, 176.639999, 181.769989, 187.079987, - 192.570007, 166.559998, 171.690002, 177.000000, 182.489990, 188.160004, 194.010010, 200.040009, 206.250000, 100.799995, 104.220001, 107.760002, 111.419998, 115.200005, 119.099998, 123.120003, 127.260010, 139.200012, 144.059998, 149.040009, 154.139999, 159.360001, 164.699997, 170.160004, 175.739990, 192.479996, 199.770020, 207.239990, - 214.889999, 222.720001, 230.730011, 238.919998, 247.290009, 201.119995, 209.129990, 217.319992, 225.690002, 234.240005, 242.970001, 251.880005, 260.970001, 209.760010, 218.489990, 227.399994, 236.490005, 245.760010, 255.209991, 264.839996, 274.649994, 218.399994, 227.850006, 237.479996, 247.289993, 257.279999, 267.449982, 277.799988, - 288.330017, 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, 257.639984, 268.889984, 280.320007, 291.929993, 303.720001, 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, 265.290009, 277.799988, 290.489990, 303.359985, 316.410004, 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, 308.040009, 322.889984, 337.920013, 353.129974, 368.519989, - 384.090027, 287.520020, 302.730011, 318.119995, 333.690002, 349.440002, 365.369995, 381.479980, 397.770020, 296.160004, 312.089996, 328.199982, 344.489990, 360.960022, 377.609985, 394.440002, 411.449982, 304.799988, 321.450012, 338.280029, 355.289978, 372.480011, 389.850006, 407.399994, 425.130005, 313.440002, 330.809998, 348.359985, 366.089996, 384.000000, 402.090027, 420.359985, 438.809998, 322.079987, 340.169983, 358.440002, 376.889984, 395.520020, 414.329987, 433.320007, 452.489990, 330.720001, 349.530029, 368.520020, 387.690002, 407.039978, 426.570007, 446.279999, 466.170013, 339.360016, 358.890015, 378.599976, 398.490021, 418.559998, 438.809998, 459.239990, 479.849976, 177.600006, 190.619995, 203.759995, 217.020004, 230.399994, 243.899994, 257.519989, 271.260010, 292.799988, 307.260010, 321.839996, 336.539978, 351.360016, 366.299988, 381.359985, 396.540009, 365.279999, 386.970001, 408.839996, 430.889984, 453.120026, 475.529968, 498.119995, 520.890015, 373.920013, 396.329987, 418.919983, 441.690002, 464.640015, 487.769958, 511.079987, 534.570007, 382.559998, 405.690002, 429.000000, 452.489990, 476.160004, 500.010010, 524.039978, 548.250000, 391.200012, 415.049988, 439.080017, 463.290009, 487.679993, 512.250000, 537.000000, 561.930054, 399.839996, 424.409973, 449.160034, 474.089966, 499.200012, 524.489990, 549.959961, 575.609985, 408.479980, 433.770020, 459.239990, 484.889954, 510.720032, 536.729980, 562.919983, 589.290039, 417.119995, 443.130005, 469.319977, 495.690002, 522.239990, 548.969971, 575.880005, 602.969971, 425.760010, 452.489990, 479.399994, 506.489990, 533.760010, 561.209961, 588.839966, 616.650024, 216.000000, 233.819992, 251.760010, 269.820007, 288.000000, 306.299988, 324.719971, 343.260010, 369.600006, 388.859985, 408.239990, 427.739990, 447.360016, 467.100006, 486.959961, 506.940002, 451.679993, 480.570007, 509.639984, 538.890015, 568.320007, 597.929993, 627.719971, 657.690002, 460.320007, 489.929993, 519.719971, 549.690002, 579.840027, 610.170044, 640.680054, 671.369995, 468.960022, 499.289978, 529.799988, 560.489990, 591.359985, 622.409973, 653.640015, 685.049988, 477.599976, 508.650024, 539.880005, 571.289978, 602.880005, 634.650024, 666.599976, 698.729980, 486.239990, 518.010010, 549.960022, 582.089966, 614.400024, 646.890015, 679.559937, 712.410034, 494.879974, 527.369995, 560.039978, 592.890015, 625.920044, 659.130005, 692.520020, 726.089966, 503.519989, 536.729980, 570.119995, 603.689941, 637.440063, 671.369995, 705.480042, 739.770020, 512.160034, 546.089966, 580.199951, 614.489990, 648.960022, 683.609985, 718.440002, 753.449951, 254.400009, 277.020020, 299.760010, 322.619995, 345.600006, 368.700012, 391.919983, 415.260010, 446.399994, 470.459961, 494.640015, 518.940002, 543.360046, 567.900024, 592.559998, 617.340027, 538.080017, 574.170044, 610.440002, 646.890015, 683.520020, 720.329956, 757.320007, 794.489990, 546.719971, 583.530029, 620.520020, 657.690002, 695.040039, 732.570007, 770.279968, 808.169983, 555.359985, 592.889954, 630.599976, 668.489990, 706.559998, 744.809998, 783.239990, 821.849976, 564.000000, 602.250000, 640.679993, 679.289978, 718.080017, 757.050049, 796.199951, 835.530029, 572.640015, 611.609985, 650.760010, 690.089966, 729.600037, 769.289978, 809.160034, 849.210083, 581.279968, 620.970032, 660.839966, 700.889954, 741.119995, 781.529968, 822.119995, 862.890015, 589.919983, 630.330017, 670.919983, 711.690002, 752.640015, 793.770020, 835.079956, 876.570007, 598.559998, 639.690002, 681.000000, 722.490051, 764.160034, 806.010010, 848.039978, 890.250061, 292.799988, 320.220001, 347.760010, 375.419983, 403.200012, 431.100006, 459.119995, 487.260010, 523.199951, 552.059998, 581.040039, 610.139954, 639.360046, 668.699951, 698.159973, 727.739990, 624.479980, 667.770020, 711.239990, 754.890015, 798.719971, 842.729980, 886.919983, 931.290039, 633.119995, 677.130005, 721.319946, 765.690002, 810.239990, 854.969971, 899.880005, 944.969971, 641.760010, 686.489990, 731.400024, 776.489990, 821.760010, 867.209961, 912.839966, 958.650024, 650.400024, 695.849976, 741.479980, 787.290039, 833.279968, 879.449951, 925.799927, 972.330017, 659.040039, 705.210022, 751.559998, 798.089966, 844.800049, 891.690002, 938.760010, 986.010010, 667.679993, 714.569946, 761.640015, 808.890015, 856.320007, 903.929993, 951.719971, 999.690063, 676.320007, 723.929993, 771.719971, 819.690002, 867.839966, 916.169922, 964.679932, 1013.369995, 684.959961, 733.290039, 781.800049, 830.489990, 879.359985, 928.410034, 977.640015, 1027.050049, 331.199982, 363.419983, 395.760010, 428.220001, 460.799988, 493.500000, 526.320007, 559.260010, 600.000000, 633.660034, 667.440002, 701.339966, 735.359985, 769.500000, 803.759949, 838.140015, 710.880005, 761.369995, 812.039978, 862.889893, 913.919983, 965.130005, 1016.520020, 1068.090088, 719.520020, 770.729980, 822.119934, 873.689941, 925.440063, 977.369995, 1029.479980, 1081.770020, 728.160034, 780.090088, 832.199951, 884.489990, 936.960022, 989.610046, 1042.439941, 1095.449951, 736.799927, 789.449951, 842.280029, 895.290039, 948.480042, 1001.849976, 1055.399902, 1109.129883, 745.439941, 798.810059, 852.359985, 906.089966, 960.000000, 1014.089966, 1068.359985, 1122.810059, 754.080017, 808.170044, 862.440002, 916.890015, 971.520020, 1026.330078, 1081.319946, 1136.489990, 762.720032, 817.530029, 872.520020, 927.689941, 983.040039, 1038.569946, 1094.280029, 1150.169922, 771.359985, 826.890015, 882.599976, 938.489990, 994.559998, 1050.810059, 1107.239990, 1163.849976, 369.599976, 406.619995, 443.760010, 481.020020, 518.400024, 555.900024, 593.520020, 631.260010, 113.279999, 136.839996, 160.480011, 184.199982, 208.000015, 231.880005, 255.839996, 279.880005, 31.359985, 66.699989, 102.160004, 137.740005, 173.440002, 209.260010, 245.199982, 281.260010, 31.359993, 67.179993, 103.120003, 139.179993, 175.360016, 211.660004, 248.079987, 284.619995, 31.359993, 67.659996, 104.080009, 140.619995, 177.280014, 214.060013, 250.959991, 287.980011, 31.359993, 68.139999, 105.039993, 142.059982, 179.200027, 216.459991, 253.839996, 291.339996, 31.360008, 68.619995, 106.000000, 143.499985, 181.119995, 218.860001, 256.719971, 294.700012, 31.360001, 69.099991, 106.959984, 144.939987, 183.040009, 221.260010, 259.600006, 298.059998, 31.360008, 69.579971, 107.920006, 146.379990, 184.960007, 223.660004, 262.479980, 301.419983, 31.360001, 70.059975, 108.880020, 147.819977, 186.880020, 226.059998, 265.359985, 304.779999, -83.840004, -58.040001, -32.159988, -6.200012, 19.840012, 45.959984, 72.159996, 98.440010}, sd::DataType::FLOAT32); - - input.linspace(-10, 0.1); - weights.linspace(-2, 0.1); - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - // output->printBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 1, iH = 10, iW = 10, iC = 8, mC = 1, kH = 3, kW = 3, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 10, oW = 10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oH, oW, oC}, + {-42.879997, -43.959999, -44.959999, -45.879997, -46.720005, + -47.480003, -48.160000, -48.760002, -43.519997, -45.139999, + -46.639996, -48.020000, -49.280003, -50.419998, -51.440006, + -52.340000, -31.999998, -33.139999, -34.160000, -35.060001, + -35.840004, -36.500004, -37.039997, -37.459999, -20.480000, + -21.139997, -21.680000, -22.100000, -22.399998, -22.579998, + -22.639996, -22.580002, -8.960000, -9.139998, -9.200002, + -9.140001, -8.960001, -8.660000, -8.240002, -7.700001, + 2.560000, 2.860002, 3.279998, 3.820000, 4.480001, + 5.260000, 6.160001, 7.180000, 14.080000, 14.860000, + 15.759998, 16.779999, 17.920002, 19.180000, 20.560001, + 22.059998, 25.600000, 26.860001, 28.239998, 29.739998, + 31.360001, 33.099998, 34.959999, 36.939999, 37.119999, + 38.860001, 40.720001, 42.699997, 44.800003, 47.020000, + 49.360001, 51.820000, 26.239998, 27.400002, 28.639999, + 29.959999, 31.360001, 32.840000, 34.400002, 36.040001, + 62.400002, 62.459999, 62.639999, 62.940002, 63.360001, + 63.900002, 64.559998, 65.340004, 106.080002, 106.169998, + 106.440002, 106.889999, 107.519997, 108.330002, 109.320000, + 110.490005, 114.720001, 115.529999, 116.520004, 117.690002, + 119.040009, 120.570000, 122.279999, 124.169998, 123.359985, + 124.889999, 126.599998, 128.490005, 130.559998, 132.809998, + 135.240005, 137.850006, 132.000000, 134.250000, 136.679993, + 139.290009, 142.080002, 145.049988, 148.199997, 151.529999, + 140.639999, 143.610001, 146.760010, 150.089996, 153.600006, + 157.290009, 161.160004, 165.209991, 149.279999, 152.970001, + 156.839996, 160.889999, 165.120010, 169.529999, 174.119995, + 178.889999, 157.919998, 162.330002, 166.919983, 171.690002, + 176.639999, 181.769989, 187.079987, 192.570007, 166.559998, + 171.690002, 177.000000, 182.489990, 188.160004, 194.010010, + 200.040009, 206.250000, 100.799995, 104.220001, 107.760002, + 111.419998, 115.200005, 119.099998, 123.120003, 127.260010, + 139.200012, 144.059998, 149.040009, 154.139999, 159.360001, + 164.699997, 170.160004, 175.739990, 192.479996, 199.770020, + 207.239990, 214.889999, 222.720001, 230.730011, 238.919998, + 247.290009, 201.119995, 209.129990, 217.319992, 225.690002, + 234.240005, 242.970001, 251.880005, 260.970001, 209.760010, + 218.489990, 227.399994, 236.490005, 245.760010, 255.209991, + 264.839996, 274.649994, 218.399994, 227.850006, 237.479996, + 247.289993, 257.279999, 267.449982, 277.799988, 288.330017, + 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, + 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, + 257.639984, 268.889984, 280.320007, 291.929993, 303.720001, + 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, + 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, + 265.290009, 277.799988, 290.489990, 303.359985, 316.410004, + 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, + 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, + 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, + 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, + 308.040009, 322.889984, 337.920013, 353.129974, 368.519989, + 384.090027, 287.520020, 302.730011, 318.119995, 333.690002, + 349.440002, 365.369995, 381.479980, 397.770020, 296.160004, + 312.089996, 328.199982, 344.489990, 360.960022, 377.609985, + 394.440002, 411.449982, 304.799988, 321.450012, 338.280029, + 355.289978, 372.480011, 389.850006, 407.399994, 425.130005, + 313.440002, 330.809998, 348.359985, 366.089996, 384.000000, + 402.090027, 420.359985, 438.809998, 322.079987, 340.169983, + 358.440002, 376.889984, 395.520020, 414.329987, 433.320007, + 452.489990, 330.720001, 349.530029, 368.520020, 387.690002, + 407.039978, 426.570007, 446.279999, 466.170013, 339.360016, + 358.890015, 378.599976, 398.490021, 418.559998, 438.809998, + 459.239990, 479.849976, 177.600006, 190.619995, 203.759995, + 217.020004, 230.399994, 243.899994, 257.519989, 271.260010, + 292.799988, 307.260010, 321.839996, 336.539978, 351.360016, + 366.299988, 381.359985, 396.540009, 365.279999, 386.970001, + 408.839996, 430.889984, 453.120026, 475.529968, 498.119995, + 520.890015, 373.920013, 396.329987, 418.919983, 441.690002, + 464.640015, 487.769958, 511.079987, 534.570007, 382.559998, + 405.690002, 429.000000, 452.489990, 476.160004, 500.010010, + 524.039978, 548.250000, 391.200012, 415.049988, 439.080017, + 463.290009, 487.679993, 512.250000, 537.000000, 561.930054, + 399.839996, 424.409973, 449.160034, 474.089966, 499.200012, + 524.489990, 549.959961, 575.609985, 408.479980, 433.770020, + 459.239990, 484.889954, 510.720032, 536.729980, 562.919983, + 589.290039, 417.119995, 443.130005, 469.319977, 495.690002, + 522.239990, 548.969971, 575.880005, 602.969971, 425.760010, + 452.489990, 479.399994, 506.489990, 533.760010, 561.209961, + 588.839966, 616.650024, 216.000000, 233.819992, 251.760010, + 269.820007, 288.000000, 306.299988, 324.719971, 343.260010, + 369.600006, 388.859985, 408.239990, 427.739990, 447.360016, + 467.100006, 486.959961, 506.940002, 451.679993, 480.570007, + 509.639984, 538.890015, 568.320007, 597.929993, 627.719971, + 657.690002, 460.320007, 489.929993, 519.719971, 549.690002, + 579.840027, 610.170044, 640.680054, 671.369995, 468.960022, + 499.289978, 529.799988, 560.489990, 591.359985, 622.409973, + 653.640015, 685.049988, 477.599976, 508.650024, 539.880005, + 571.289978, 602.880005, 634.650024, 666.599976, 698.729980, + 486.239990, 518.010010, 549.960022, 582.089966, 614.400024, + 646.890015, 679.559937, 712.410034, 494.879974, 527.369995, + 560.039978, 592.890015, 625.920044, 659.130005, 692.520020, + 726.089966, 503.519989, 536.729980, 570.119995, 603.689941, + 637.440063, 671.369995, 705.480042, 739.770020, 512.160034, + 546.089966, 580.199951, 614.489990, 648.960022, 683.609985, + 718.440002, 753.449951, 254.400009, 277.020020, 299.760010, + 322.619995, 345.600006, 368.700012, 391.919983, 415.260010, + 446.399994, 470.459961, 494.640015, 518.940002, 543.360046, + 567.900024, 592.559998, 617.340027, 538.080017, 574.170044, + 610.440002, 646.890015, 683.520020, 720.329956, 757.320007, + 794.489990, 546.719971, 583.530029, 620.520020, 657.690002, + 695.040039, 732.570007, 770.279968, 808.169983, 555.359985, + 592.889954, 630.599976, 668.489990, 706.559998, 744.809998, + 783.239990, 821.849976, 564.000000, 602.250000, 640.679993, + 679.289978, 718.080017, 757.050049, 796.199951, 835.530029, + 572.640015, 611.609985, 650.760010, 690.089966, 729.600037, + 769.289978, 809.160034, 849.210083, 581.279968, 620.970032, + 660.839966, 700.889954, 741.119995, 781.529968, 822.119995, + 862.890015, 589.919983, 630.330017, 670.919983, 711.690002, + 752.640015, 793.770020, 835.079956, 876.570007, 598.559998, + 639.690002, 681.000000, 722.490051, 764.160034, 806.010010, + 848.039978, 890.250061, 292.799988, 320.220001, 347.760010, + 375.419983, 403.200012, 431.100006, 459.119995, 487.260010, + 523.199951, 552.059998, 581.040039, 610.139954, 639.360046, + 668.699951, 698.159973, 727.739990, 624.479980, 667.770020, + 711.239990, 754.890015, 798.719971, 842.729980, 886.919983, + 931.290039, 633.119995, 677.130005, 721.319946, 765.690002, + 810.239990, 854.969971, 899.880005, 944.969971, 641.760010, + 686.489990, 731.400024, 776.489990, 821.760010, 867.209961, + 912.839966, 958.650024, 650.400024, 695.849976, 741.479980, + 787.290039, 833.279968, 879.449951, 925.799927, 972.330017, + 659.040039, 705.210022, 751.559998, 798.089966, 844.800049, + 891.690002, 938.760010, 986.010010, 667.679993, 714.569946, + 761.640015, 808.890015, 856.320007, 903.929993, 951.719971, + 999.690063, 676.320007, 723.929993, 771.719971, 819.690002, + 867.839966, 916.169922, 964.679932, 1013.369995, 684.959961, + 733.290039, 781.800049, 830.489990, 879.359985, 928.410034, + 977.640015, 1027.050049, 331.199982, 363.419983, 395.760010, + 428.220001, 460.799988, 493.500000, 526.320007, 559.260010, + 600.000000, 633.660034, 667.440002, 701.339966, 735.359985, + 769.500000, 803.759949, 838.140015, 710.880005, 761.369995, + 812.039978, 862.889893, 913.919983, 965.130005, 1016.520020, + 1068.090088, 719.520020, 770.729980, 822.119934, 873.689941, + 925.440063, 977.369995, 1029.479980, 1081.770020, 728.160034, + 780.090088, 832.199951, 884.489990, 936.960022, 989.610046, + 1042.439941, 1095.449951, 736.799927, 789.449951, 842.280029, + 895.290039, 948.480042, 1001.849976, 1055.399902, 1109.129883, + 745.439941, 798.810059, 852.359985, 906.089966, 960.000000, + 1014.089966, 1068.359985, 1122.810059, 754.080017, 808.170044, + 862.440002, 916.890015, 971.520020, 1026.330078, 1081.319946, + 1136.489990, 762.720032, 817.530029, 872.520020, 927.689941, + 983.040039, 1038.569946, 1094.280029, 1150.169922, 771.359985, + 826.890015, 882.599976, 938.489990, 994.559998, 1050.810059, + 1107.239990, 1163.849976, 369.599976, 406.619995, 443.760010, + 481.020020, 518.400024, 555.900024, 593.520020, 631.260010, + 113.279999, 136.839996, 160.480011, 184.199982, 208.000015, + 231.880005, 255.839996, 279.880005, 31.359985, 66.699989, + 102.160004, 137.740005, 173.440002, 209.260010, 245.199982, + 281.260010, 31.359993, 67.179993, 103.120003, 139.179993, + 175.360016, 211.660004, 248.079987, 284.619995, 31.359993, + 67.659996, 104.080009, 140.619995, 177.280014, 214.060013, + 250.959991, 287.980011, 31.359993, 68.139999, 105.039993, + 142.059982, 179.200027, 216.459991, 253.839996, 291.339996, + 31.360008, 68.619995, 106.000000, 143.499985, 181.119995, + 218.860001, 256.719971, 294.700012, 31.360001, 69.099991, + 106.959984, 144.939987, 183.040009, 221.260010, 259.600006, + 298.059998, 31.360008, 69.579971, 107.920006, 146.379990, + 184.960007, 223.660004, 262.479980, 301.419983, 31.360001, + 70.059975, 108.880020, 147.819977, 186.880020, 226.059998, + 265.359985, 304.779999, -83.840004, -58.040001, -32.159988, + -6.200012, 19.840012, 45.959984, 72.159996, 98.440010}, + sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + // output->printBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_9) { - - int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=10,oW=10; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oC, oH, oW}, {-103.360001, -131.440002, -130.000000, -128.559998, -127.120003, -125.680000, -124.240005, -122.799995, -121.360001, -66.720001,-76.199997, -81.239998, -80.160004, -79.080002, -78.000000, -76.919998, -75.840004, -74.760002, -73.680000, -29.400002, -66.599998, -70.440002, -69.360001, -68.279999, - -67.199997, -66.120003, -65.040001, -63.959999, -62.879997, -24.599997, -57.000000, -59.639999, -58.560005, -57.479996, -56.399998, -55.320000, -54.240002, -53.159996, -52.080002, -19.799997, -47.400002, -48.840000, -47.760002, -46.680000, -45.599998, -44.520000, -43.440002, -42.360001, -41.279999, -15.000000, -37.799999, -38.040001, - -36.959999, -35.879997, -34.799999, -33.720001, -32.639999, -31.560001, -30.479996, -10.199999, -28.200001, -27.240002, -26.160000, -25.080002, -24.000000, -22.919998,-21.840000, -20.759998, -19.679998, -5.400000, -18.599998, -16.439999, -15.360001, -14.280001, -13.200001, -12.120001, -11.040000, -9.960001, -8.880000, -0.600000, - -9.000000, -5.639999, -4.560000, -3.480000, -2.400000, -1.320001, -0.240000, 0.840001, 1.920000, 4.200000, 0.160000, 3.920000, 3.920000, 3.920000, 3.920000, 3.920000,3.920001, 3.920000, 3.920000, 3.520000, 8.860001, 12.920000, 14.420000, 15.920000, 17.420000, 18.920000, 20.420000, 21.920000, 23.420000, 13.820000, 20.430000, 27.750000, - 28.919998, 30.090000, 31.260000, 32.430000, 33.600002, 34.770000, 35.939999, 19.709999, 30.630001, 39.450001, 40.619999, 41.790001, 42.960003, 44.129997, 45.299999, 46.470001, 47.639999, 25.110001, 40.829998, 51.150002, 52.320000, 53.489998, 54.660004, 55.829994, 57.000000, 58.169998, 59.340004, 30.510002, 51.029999, 62.849998, - 64.019997, 65.190002, 66.360001, 67.529999, 68.699997, 69.870003, 71.040001, 35.910000, 61.229996, 74.550003, 75.720001, 76.889999, 78.059998, 79.229996, 80.400002, 81.570000, 82.740005, 41.310001, 71.430000, 86.250000, 87.419998, 88.589996, 89.760002, 90.929993, 92.099991, 93.270004, 94.440002, 46.709999, 81.630005, 97.949997, - 99.120003, 100.290009, 101.459999, 102.630005, 103.800003, 104.970001, 106.139999, 52.110001, 91.830002, 109.649994, 110.820007, 111.990005, 113.159996, 114.330002, 115.500000, 116.669998, 117.839996, 57.509995, 19.580000, 9.079998, 9.139999, 9.199999, 9.259996, 9.320001, 9.379998, 9.440000, 9.500000, -8.740000, 129.080002, 169.279999, - 170.839996, 172.399994, 173.960007, 175.520004, 177.080002, 178.639999, 180.199982, 102.360001, 129.059998, 154.739990, 156.000000, 157.259995, 158.520004, 159.779999, 161.039993, 162.300003, 163.559998, 80.820000, 139.860001, 167.340012, 168.600006, 169.860001, 171.119995, 172.380005, 173.639999, 174.899994, 176.160004, 86.820000, - 150.660004, 179.940002, 181.200012, 182.459991, 183.720001, 184.980011, 186.239990, 187.500000, 188.759995, 92.820007, 161.459991, 192.540009, 193.799988, 195.059998, 196.319992, 197.579987, 198.839996, 200.100006, 201.360001, 98.820000, 172.259995, 205.139999, 206.399994, 207.660004, 208.919983, 210.179993, 211.440002, 212.700012, - 213.959991, 104.819992, 183.059998, 217.739990, 219.000000, 220.259995, 221.519989, 222.779999, 224.039993, 225.300018, 226.559998, 110.819992, 193.860016, 230.339996, 231.600006, 232.860001, 234.119995, 235.380005, 236.639999, 237.900009, 239.160004, 116.820000, 204.660004, 242.940002, 244.199982, 245.459991, 246.720001, 247.980011, - 249.239990, 250.500000, 251.759995, 122.819992, 47.000000, 26.240004, 26.360004, 26.479998, 26.600002, 26.720001, 26.840002, 26.959997, 27.080000, -12.999998, 257.299988, 337.640015, 339.260010, 340.879974, 342.499969, 344.119995, 345.740021, 347.359985, 348.979980, 198.899994, 249.690002, 299.729980, 301.079987, 302.429993, 303.779999, 305.130005, 306.480011, 307.829987, 309.179993, 153.929993, 261.089996, 313.230011, 314.580017, 315.929993, 317.279968, 318.630005, 319.979980, 321.329987, 322.679993, 160.529999, 272.489990, 326.729980, 328.079987, 329.429993, 330.779968, 332.130005, 333.479980, 334.829987, 336.179993, 167.130005, 283.889984, 340.230011, 341.580017, 342.929993, 344.279999, 345.630005, 346.980011, 348.330017, 349.679993, 173.729996, 295.289978, 353.729980, 355.079987, 356.429993, 357.779968, 359.130005, 360.480011, 361.829987, 363.179993, 180.329987, 306.690002, 367.230011, 368.580017, 369.929993, 371.279999, 372.630005, 373.980011, 375.330017, 376.679993, 186.929993, 318.089996, 380.729980, 382.080017, 383.429993, 384.779968, 386.130005, 387.479980, 388.829987, 390.179993, 193.529984, 329.489990, 394.229980, 395.579987, 396.929993, 398.279999, 399.630005, 400.980011, 402.330017, 403.679993, 200.130005, 82.419998, 55.400005, 55.580002, 55.759995, 55.939999, 56.120003, 56.299995, 56.479996, 56.659996, -9.260002, 393.520020, 518.000000, 519.679993, 521.359985, 523.040039, 524.720032, 526.400024, 528.080017, 529.760010, 303.440002, 382.320007, 462.720032, 464.160004, 465.600037, 467.040009, 468.479980, 469.919983, 471.359985, 472.800018, 239.040009, 394.320007, 477.119995, 478.559998, 480.000000, 481.440002, 482.880005, 484.320007, 485.760010, 487.200012, 246.240005, 406.320007, 491.520020, 492.960022, 494.400024, 495.839996, 497.280029, 498.720032, 500.160004, 501.600037, 253.440002, 418.320007, 505.919983, 507.359985, 508.800018, 510.240051, 511.680023, 513.119995, 514.559998, 516.000000, 260.640015, 430.319977, 520.320007, 521.760010, 523.200012, 524.640015, 526.079956, 527.520020, 528.960022, 530.400024, 267.839996, 442.320007, 534.720032, 536.160034, 537.600037, 539.040039, 540.479980, 541.919983, 543.359985, 544.800049, 275.040009, 454.320007, 549.119995, 550.559998, 552.000000, 553.440002, 554.880005, 556.320007, 557.760010, 559.200012, 282.239990, 466.320007, 563.520020, 564.960022, 566.400024, 567.839966, 569.280029, 570.720032, 572.160034, 573.600037, 289.440002, 125.839996, 96.559998, 96.799995, 97.040009, 97.280014, 97.520004, 97.759995, 98.000000, 98.240013, 2.480007, 537.739990, 710.359985, 712.099976, 713.840027, 715.579956, 717.319946, 719.059998, 720.799988, 722.539978, 415.980011, 526.950012, 643.710022, 645.240051, 646.770020, 648.300049, 649.829956, 651.359985, 652.890015, 654.419983, 336.149994, 539.549988, 659.010010, 660.539978, 662.070007, 663.600037, 665.130005, 666.660034, 668.190002, 669.720032, 343.950012, 552.150024, 674.309998, 675.839966, 677.369995, 678.900024, 680.429993, 681.960022, 683.490051, 685.020020, 351.750000, 564.750000, 689.609985, 691.140015, 692.669983, 694.200012, 695.729980, 697.260010, 698.789978, 700.320007, 359.549988, 577.349976, 704.910034, 706.440002, 707.970032, 709.500000, 711.029968, 712.559998, 714.089966, 715.619995, 367.350037, 589.950012, 720.210022, 721.740051, 723.270020, 724.800049, 726.329956, 727.859985, 729.390015, 730.919983, 375.149994, 602.549988, 735.510010, 737.039978, 738.570007, 740.100037, 741.630005, 743.160034, 744.690002, 746.220032, 382.950012, 615.150024, 750.809998, 752.339966, 753.869995, 755.399963, 756.929993, 758.460022, 759.990051, 761.520020, 390.750000, 177.260010, 149.720001, 150.020004, 150.319992, 150.619995, 150.919998, 151.220001, 151.520004, 151.819992, 22.220009, 689.959961, 914.720032, 916.519958, 918.319946, 920.119995, 921.919983, 923.719971, 925.520020, 927.320007, 536.519958, 683.579956, 842.699951, 844.319946, 845.940002, 847.559998, 849.179993, 850.799988, 852.419983, 854.039978, 445.260010, 696.779968, 858.900024, 860.520020, 862.140015, 863.760010, 865.380005, 867.000000, 868.619995, 870.239990, 453.659973, 709.979980, 875.099976, 876.719971, 878.339966, 879.959961, 881.579956, 883.199951, 884.819946, 886.440002, 462.059998, 723.179993, 891.299988, 892.919983, 894.539978, 896.159973, 897.779968, 899.400024, 901.020020, 902.640015, 470.459991, 736.380005, 907.500000, 909.119995, 910.739990, 912.359985, 913.979980, 915.599976, 917.219971, 918.839966, 478.859985, 749.579956, 923.699951, 925.319946, 926.940002, 928.559998, 930.179993, 931.799988, 933.419983, 935.039978, 487.260010, 762.779968, 939.900024, 941.520020, 943.140015, 944.760010, 946.380005, 948.000000, 949.619995, 951.239990, 495.659973, 775.979980, 956.099976, 957.719971, 959.339966, 960.959961, 962.579956, 964.199951, 965.819946, 967.440002, 504.059998, 236.679977, 214.880005, 215.239990, 215.599991, 215.959991, 216.319992, 216.679993, 217.040009, 217.399994, 49.959995, 850.180054, 1131.079956, 1132.939941, 1134.800049, 1136.660034, 1138.520020, 1140.380005, 1142.239990, 1144.100098, 665.060059, 852.209961, 1059.689941, 1061.399902, 1063.110107, 1064.820068, 1066.530029, 1068.239990, 1069.950073, 1071.660034, 566.370056, 866.010010, 1076.790039, 1078.500000, 1080.209961, 1081.920044, 1083.630005, 1085.339966, 1087.050049, 1088.760010, 575.369995, 879.809998, 1093.890015, 1095.599976, 1097.310059, 1099.020020, 1100.729980, 1102.439941, 1104.149902, 1105.859985, 584.369995, 893.609985, 1110.989990, 1112.699951, 1114.410034, 1116.120117, 1117.830078, 1119.540039, 1121.250000, 1122.959961, 593.370056, 907.410034, 1128.089966, 1129.800049, 1131.510010, 1133.220093, 1134.929932, 1136.639893, 1138.349976, 1140.060059, 602.369995, 921.209961, 1145.189941, 1146.900024, 1148.609985, 1150.320068, 1152.030029, 1153.739990, 1155.449951, 1157.160034, 611.370056, 935.010010, 1162.290039, 1164.000000, 1165.709961, 1167.420044, 1169.130005, 1170.839966, 1172.550049, 1174.260010, 620.369995, 948.809998, 1179.390015, 1181.099976, 1182.810059, 1184.520020, 1186.229980, 1187.939941, 1189.650024, 1191.359985, 629.370056, 304.099976, 292.039978, 292.460022, 292.880005, 293.300018, 293.720001, 294.140015, 294.559998, 294.980042, 85.700005}, sd::DataType::FLOAT32); - - input.linspace(-10, 0.1); - weights.linspace(-2, 0.1); - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output, 1e-4)); - + int bS = 1, iH = 10, iW = 10, iC = 8, mC = 1, kH = 3, kW = 3, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 10, oW = 10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {-103.360001, -131.440002, -130.000000, -128.559998, -127.120003, + -125.680000, -124.240005, -122.799995, -121.360001, -66.720001, + -76.199997, -81.239998, -80.160004, -79.080002, -78.000000, + -76.919998, -75.840004, -74.760002, -73.680000, -29.400002, + -66.599998, -70.440002, -69.360001, -68.279999, -67.199997, + -66.120003, -65.040001, -63.959999, -62.879997, -24.599997, + -57.000000, -59.639999, -58.560005, -57.479996, -56.399998, + -55.320000, -54.240002, -53.159996, -52.080002, -19.799997, + -47.400002, -48.840000, -47.760002, -46.680000, -45.599998, + -44.520000, -43.440002, -42.360001, -41.279999, -15.000000, + -37.799999, -38.040001, -36.959999, -35.879997, -34.799999, + -33.720001, -32.639999, -31.560001, -30.479996, -10.199999, + -28.200001, -27.240002, -26.160000, -25.080002, -24.000000, + -22.919998, -21.840000, -20.759998, -19.679998, -5.400000, + -18.599998, -16.439999, -15.360001, -14.280001, -13.200001, + -12.120001, -11.040000, -9.960001, -8.880000, -0.600000, + -9.000000, -5.639999, -4.560000, -3.480000, -2.400000, + -1.320001, -0.240000, 0.840001, 1.920000, 4.200000, + 0.160000, 3.920000, 3.920000, 3.920000, 3.920000, + 3.920000, 3.920001, 3.920000, 3.920000, 3.520000, + 8.860001, 12.920000, 14.420000, 15.920000, 17.420000, + 18.920000, 20.420000, 21.920000, 23.420000, 13.820000, + 20.430000, 27.750000, 28.919998, 30.090000, 31.260000, + 32.430000, 33.600002, 34.770000, 35.939999, 19.709999, + 30.630001, 39.450001, 40.619999, 41.790001, 42.960003, + 44.129997, 45.299999, 46.470001, 47.639999, 25.110001, + 40.829998, 51.150002, 52.320000, 53.489998, 54.660004, + 55.829994, 57.000000, 58.169998, 59.340004, 30.510002, + 51.029999, 62.849998, 64.019997, 65.190002, 66.360001, + 67.529999, 68.699997, 69.870003, 71.040001, 35.910000, + 61.229996, 74.550003, 75.720001, 76.889999, 78.059998, + 79.229996, 80.400002, 81.570000, 82.740005, 41.310001, + 71.430000, 86.250000, 87.419998, 88.589996, 89.760002, + 90.929993, 92.099991, 93.270004, 94.440002, 46.709999, + 81.630005, 97.949997, 99.120003, 100.290009, 101.459999, + 102.630005, 103.800003, 104.970001, 106.139999, 52.110001, + 91.830002, 109.649994, 110.820007, 111.990005, 113.159996, + 114.330002, 115.500000, 116.669998, 117.839996, 57.509995, + 19.580000, 9.079998, 9.139999, 9.199999, 9.259996, + 9.320001, 9.379998, 9.440000, 9.500000, -8.740000, + 129.080002, 169.279999, 170.839996, 172.399994, 173.960007, + 175.520004, 177.080002, 178.639999, 180.199982, 102.360001, + 129.059998, 154.739990, 156.000000, 157.259995, 158.520004, + 159.779999, 161.039993, 162.300003, 163.559998, 80.820000, + 139.860001, 167.340012, 168.600006, 169.860001, 171.119995, + 172.380005, 173.639999, 174.899994, 176.160004, 86.820000, + 150.660004, 179.940002, 181.200012, 182.459991, 183.720001, + 184.980011, 186.239990, 187.500000, 188.759995, 92.820007, + 161.459991, 192.540009, 193.799988, 195.059998, 196.319992, + 197.579987, 198.839996, 200.100006, 201.360001, 98.820000, + 172.259995, 205.139999, 206.399994, 207.660004, 208.919983, + 210.179993, 211.440002, 212.700012, 213.959991, 104.819992, + 183.059998, 217.739990, 219.000000, 220.259995, 221.519989, + 222.779999, 224.039993, 225.300018, 226.559998, 110.819992, + 193.860016, 230.339996, 231.600006, 232.860001, 234.119995, + 235.380005, 236.639999, 237.900009, 239.160004, 116.820000, + 204.660004, 242.940002, 244.199982, 245.459991, 246.720001, + 247.980011, 249.239990, 250.500000, 251.759995, 122.819992, + 47.000000, 26.240004, 26.360004, 26.479998, 26.600002, + 26.720001, 26.840002, 26.959997, 27.080000, -12.999998, + 257.299988, 337.640015, 339.260010, 340.879974, 342.499969, + 344.119995, 345.740021, 347.359985, 348.979980, 198.899994, + 249.690002, 299.729980, 301.079987, 302.429993, 303.779999, + 305.130005, 306.480011, 307.829987, 309.179993, 153.929993, + 261.089996, 313.230011, 314.580017, 315.929993, 317.279968, + 318.630005, 319.979980, 321.329987, 322.679993, 160.529999, + 272.489990, 326.729980, 328.079987, 329.429993, 330.779968, + 332.130005, 333.479980, 334.829987, 336.179993, 167.130005, + 283.889984, 340.230011, 341.580017, 342.929993, 344.279999, + 345.630005, 346.980011, 348.330017, 349.679993, 173.729996, + 295.289978, 353.729980, 355.079987, 356.429993, 357.779968, + 359.130005, 360.480011, 361.829987, 363.179993, 180.329987, + 306.690002, 367.230011, 368.580017, 369.929993, 371.279999, + 372.630005, 373.980011, 375.330017, 376.679993, 186.929993, + 318.089996, 380.729980, 382.080017, 383.429993, 384.779968, + 386.130005, 387.479980, 388.829987, 390.179993, 193.529984, + 329.489990, 394.229980, 395.579987, 396.929993, 398.279999, + 399.630005, 400.980011, 402.330017, 403.679993, 200.130005, + 82.419998, 55.400005, 55.580002, 55.759995, 55.939999, + 56.120003, 56.299995, 56.479996, 56.659996, -9.260002, + 393.520020, 518.000000, 519.679993, 521.359985, 523.040039, + 524.720032, 526.400024, 528.080017, 529.760010, 303.440002, + 382.320007, 462.720032, 464.160004, 465.600037, 467.040009, + 468.479980, 469.919983, 471.359985, 472.800018, 239.040009, + 394.320007, 477.119995, 478.559998, 480.000000, 481.440002, + 482.880005, 484.320007, 485.760010, 487.200012, 246.240005, + 406.320007, 491.520020, 492.960022, 494.400024, 495.839996, + 497.280029, 498.720032, 500.160004, 501.600037, 253.440002, + 418.320007, 505.919983, 507.359985, 508.800018, 510.240051, + 511.680023, 513.119995, 514.559998, 516.000000, 260.640015, + 430.319977, 520.320007, 521.760010, 523.200012, 524.640015, + 526.079956, 527.520020, 528.960022, 530.400024, 267.839996, + 442.320007, 534.720032, 536.160034, 537.600037, 539.040039, + 540.479980, 541.919983, 543.359985, 544.800049, 275.040009, + 454.320007, 549.119995, 550.559998, 552.000000, 553.440002, + 554.880005, 556.320007, 557.760010, 559.200012, 282.239990, + 466.320007, 563.520020, 564.960022, 566.400024, 567.839966, + 569.280029, 570.720032, 572.160034, 573.600037, 289.440002, + 125.839996, 96.559998, 96.799995, 97.040009, 97.280014, + 97.520004, 97.759995, 98.000000, 98.240013, 2.480007, + 537.739990, 710.359985, 712.099976, 713.840027, 715.579956, + 717.319946, 719.059998, 720.799988, 722.539978, 415.980011, + 526.950012, 643.710022, 645.240051, 646.770020, 648.300049, + 649.829956, 651.359985, 652.890015, 654.419983, 336.149994, + 539.549988, 659.010010, 660.539978, 662.070007, 663.600037, + 665.130005, 666.660034, 668.190002, 669.720032, 343.950012, + 552.150024, 674.309998, 675.839966, 677.369995, 678.900024, + 680.429993, 681.960022, 683.490051, 685.020020, 351.750000, + 564.750000, 689.609985, 691.140015, 692.669983, 694.200012, + 695.729980, 697.260010, 698.789978, 700.320007, 359.549988, + 577.349976, 704.910034, 706.440002, 707.970032, 709.500000, + 711.029968, 712.559998, 714.089966, 715.619995, 367.350037, + 589.950012, 720.210022, 721.740051, 723.270020, 724.800049, + 726.329956, 727.859985, 729.390015, 730.919983, 375.149994, + 602.549988, 735.510010, 737.039978, 738.570007, 740.100037, + 741.630005, 743.160034, 744.690002, 746.220032, 382.950012, + 615.150024, 750.809998, 752.339966, 753.869995, 755.399963, + 756.929993, 758.460022, 759.990051, 761.520020, 390.750000, + 177.260010, 149.720001, 150.020004, 150.319992, 150.619995, + 150.919998, 151.220001, 151.520004, 151.819992, 22.220009, + 689.959961, 914.720032, 916.519958, 918.319946, 920.119995, + 921.919983, 923.719971, 925.520020, 927.320007, 536.519958, + 683.579956, 842.699951, 844.319946, 845.940002, 847.559998, + 849.179993, 850.799988, 852.419983, 854.039978, 445.260010, + 696.779968, 858.900024, 860.520020, 862.140015, 863.760010, + 865.380005, 867.000000, 868.619995, 870.239990, 453.659973, + 709.979980, 875.099976, 876.719971, 878.339966, 879.959961, + 881.579956, 883.199951, 884.819946, 886.440002, 462.059998, + 723.179993, 891.299988, 892.919983, 894.539978, 896.159973, + 897.779968, 899.400024, 901.020020, 902.640015, 470.459991, + 736.380005, 907.500000, 909.119995, 910.739990, 912.359985, + 913.979980, 915.599976, 917.219971, 918.839966, 478.859985, + 749.579956, 923.699951, 925.319946, 926.940002, 928.559998, + 930.179993, 931.799988, 933.419983, 935.039978, 487.260010, + 762.779968, 939.900024, 941.520020, 943.140015, 944.760010, + 946.380005, 948.000000, 949.619995, 951.239990, 495.659973, + 775.979980, 956.099976, 957.719971, 959.339966, 960.959961, + 962.579956, 964.199951, 965.819946, 967.440002, 504.059998, + 236.679977, 214.880005, 215.239990, 215.599991, 215.959991, + 216.319992, 216.679993, 217.040009, 217.399994, 49.959995, + 850.180054, 1131.079956, 1132.939941, 1134.800049, 1136.660034, + 1138.520020, 1140.380005, 1142.239990, 1144.100098, 665.060059, + 852.209961, 1059.689941, 1061.399902, 1063.110107, 1064.820068, + 1066.530029, 1068.239990, 1069.950073, 1071.660034, 566.370056, + 866.010010, 1076.790039, 1078.500000, 1080.209961, 1081.920044, + 1083.630005, 1085.339966, 1087.050049, 1088.760010, 575.369995, + 879.809998, 1093.890015, 1095.599976, 1097.310059, 1099.020020, + 1100.729980, 1102.439941, 1104.149902, 1105.859985, 584.369995, + 893.609985, 1110.989990, 1112.699951, 1114.410034, 1116.120117, + 1117.830078, 1119.540039, 1121.250000, 1122.959961, 593.370056, + 907.410034, 1128.089966, 1129.800049, 1131.510010, 1133.220093, + 1134.929932, 1136.639893, 1138.349976, 1140.060059, 602.369995, + 921.209961, 1145.189941, 1146.900024, 1148.609985, 1150.320068, + 1152.030029, 1153.739990, 1155.449951, 1157.160034, 611.370056, + 935.010010, 1162.290039, 1164.000000, 1165.709961, 1167.420044, + 1169.130005, 1170.839966, 1172.550049, 1174.260010, 620.369995, + 948.809998, 1179.390015, 1181.099976, 1182.810059, 1184.520020, + 1186.229980, 1187.939941, 1189.650024, 1191.359985, 629.370056, + 304.099976, 292.039978, 292.460022, 292.880005, 293.300018, + 293.720001, 294.140015, 294.559998, 294.980042, 85.700005}, + sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + + sd::ops::depthwise_conv2d op; + auto results = + op.evaluate({&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output, 1e-4)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_10) { - - int bS=1, iH=3,iW=3, iC=2,mC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=3,oW=3; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] - - NDArray input('c', {bS, iC, iH, iW}, {0.6793503761291504, 0.35508695244789124, 0.842789351940155, 0.20031332969665527, 0.7014986872673035, 0.3106933832168579, - 0.44793984293937683, 0.9380097389221191, 0.3266739547252655, 0.15187257528305054, 0.3833175301551819, 0.7821229696273804, - 0.19880719482898712, 0.7985635995864868, 0.16326339542865753, 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}, sd::DataType::FLOAT32); - NDArray weights('c', {mC, iC, kH, kW}, {0.130845, 0.569885, 0.644284, 0.198968}, sd::DataType::FLOAT32); - NDArray biases('c', {iC*mC}, {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, 0.4270855486392975}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oC, oH, oW}, {0.7012459761288241, 0.6588178652487691, 0.722631079971582, 0.6385665758716108, 0.7041439625563628, 0.6530092074102978, - 0.670967162534851, 0.735090151337225, 0.6551001785478623, 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, 0.5054379267801892, 0.8283436386757472, - 0.5765540302788565, 0.6649797296980537, 0.9807239274294943, 0.586850056971322, 0.261199593183985, 0.3930965634902499, 0.6203697362284615, 0.28794692117826504, - 0.6297390019475202, 0.26769104886224415, 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, 0.4573034071191504, 0.5033536625992294, 0.5827033826425385, - 0.4666419179635315, 0.585974550122895, 0.4595698215161401, 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}, sd::DataType::FLOAT32); - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 1, iH = 3, iW = 3, iC = 2, mC = 2, kH = 1, kW = 1, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 3, oW = 3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = + 1; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, + {0.6793503761291504, 0.35508695244789124, 0.842789351940155, + 0.20031332969665527, 0.7014986872673035, 0.3106933832168579, + 0.44793984293937683, 0.9380097389221191, 0.3266739547252655, + 0.15187257528305054, 0.3833175301551819, 0.7821229696273804, + 0.19880719482898712, 0.7985635995864868, 0.16326339542865753, + 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}, + sd::DataType::FLOAT32); + NDArray weights('c', {mC, iC, kH, kW}, + {0.130845, 0.569885, 0.644284, 0.198968}, + sd::DataType::FLOAT32); + NDArray biases('c', {iC * mC}, + {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, + 0.4270855486392975}, + sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oC, oH, oW}, + {0.7012459761288241, 0.6588178652487691, 0.722631079971582, + 0.6385665758716108, 0.7041439625563628, 0.6530092074102978, + 0.670967162534851, 0.735090151337225, 0.6551001785478623, + 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, + 0.5054379267801892, 0.8283436386757472, 0.5765540302788565, + 0.6649797296980537, 0.9807239274294943, 0.586850056971322, + 0.261199593183985, 0.3930965634902499, 0.6203697362284615, + 0.28794692117826504, 0.6297390019475202, 0.26769104886224415, + 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, + 0.4573034071191504, 0.5033536625992294, 0.5827033826425385, + 0.4666419179635315, 0.585974550122895, 0.4595698215161401, + 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}, + sd::DataType::FLOAT32); + + sd::ops::depthwise_conv2d op; + auto results = op.evaluate( + {&input, &weights, &biases}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_11) { - - int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=10,oW=10; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - int wFormat = 2; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {mC, kH, kW, iC}, {-2., -1.9, -1.8, -1.7, -1.6, -1.5, -1.4, -1.3, -1.2, -1.1, -1., -0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, - 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., - 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 5., 5.1}, sd::DataType::FLOAT32); - - NDArray expOutput('c', {bS, oH, oW, oC}, {-42.879997, -43.959999, -44.959999, -45.879997, -46.720005, -47.480003, -48.160000, -48.760002, -43.519997, -45.139999, -46.639996, -48.020000, -49.280003, -50.419998, -51.440006, -52.340000, -31.999998, -33.139999, -34.160000, -35.060001, -35.840004, -36.500004, -37.039997, -37.459999, -20.480000, - -21.139997, -21.680000, -22.100000, -22.399998, -22.579998, -22.639996, -22.580002, -8.960000, -9.139998, -9.200002, -9.140001, -8.960001, -8.660000, -8.240002, -7.700001, 2.560000, 2.860002, 3.279998, 3.820000, 4.480001, 5.260000, 6.160001, 7.180000, 14.080000, 14.860000, 15.759998, 16.779999, 17.920002, 19.180000, 20.560001, 22.059998, - 25.600000, 26.860001, 28.239998, 29.739998, 31.360001, 33.099998, 34.959999, 36.939999, 37.119999, 38.860001, 40.720001, 42.699997, 44.800003, 47.020000, 49.360001, 51.820000, 26.239998, 27.400002, 28.639999, 29.959999, 31.360001, 32.840000, 34.400002, 36.040001, 62.400002, 62.459999, 62.639999, 62.940002, 63.360001, 63.900002, 64.559998, - 65.340004, 106.080002, 106.169998, 106.440002, 106.889999, 107.519997, 108.330002, 109.320000, 110.490005, 114.720001, 115.529999, 116.520004, 117.690002, 119.040009, 120.570000, 122.279999, 124.169998, 123.359985, 124.889999, 126.599998, 128.490005, 130.559998, 132.809998, 135.240005, 137.850006, 132.000000, 134.250000, 136.679993, - 139.290009, 142.080002, 145.049988, 148.199997, 151.529999, 140.639999, 143.610001, 146.760010, 150.089996, 153.600006, 157.290009, 161.160004, 165.209991, 149.279999, 152.970001, 156.839996, 160.889999, 165.120010, 169.529999, 174.119995, 178.889999, 157.919998, 162.330002, 166.919983, 171.690002, 176.639999, 181.769989, 187.079987, - 192.570007, 166.559998, 171.690002, 177.000000, 182.489990, 188.160004, 194.010010, 200.040009, 206.250000, 100.799995, 104.220001, 107.760002, 111.419998, 115.200005, 119.099998, 123.120003, 127.260010, 139.200012, 144.059998, 149.040009, 154.139999, 159.360001, 164.699997, 170.160004, 175.739990, 192.479996, 199.770020, 207.239990, - 214.889999, 222.720001, 230.730011, 238.919998, 247.290009, 201.119995, 209.129990, 217.319992, 225.690002, 234.240005, 242.970001, 251.880005, 260.970001, 209.760010, 218.489990, 227.399994, 236.490005, 245.760010, 255.209991, 264.839996, 274.649994, 218.399994, 227.850006, 237.479996, 247.289993, 257.279999, 267.449982, 277.799988, - 288.330017, 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, 257.639984, 268.889984, 280.320007, 291.929993, 303.720001, 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, 265.290009, 277.799988, - 290.489990, 303.359985, 316.410004, 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, 308.040009, 322.889984, 337.920013, 353.129974, 368.519989, - 384.090027, 287.520020, 302.730011, 318.119995, 333.690002, 349.440002, 365.369995, 381.479980, 397.770020, 296.160004, 312.089996, 328.199982, 344.489990, 360.960022, 377.609985, 394.440002, 411.449982, 304.799988, 321.450012, 338.280029, 355.289978, 372.480011, 389.850006, 407.399994, 425.130005, 313.440002, 330.809998, 348.359985, 366.089996, 384.000000, 402.090027, 420.359985, 438.809998, 322.079987, 340.169983, 358.440002, 376.889984, 395.520020, 414.329987, 433.320007, 452.489990, 330.720001, 349.530029, 368.520020, 387.690002, 407.039978, 426.570007, 446.279999, 466.170013, 339.360016, 358.890015, 378.599976, 398.490021, 418.559998, 438.809998, 459.239990, 479.849976, 177.600006, 190.619995, 203.759995, 217.020004, 230.399994, 243.899994, 257.519989, 271.260010, 292.799988, 307.260010, 321.839996, 336.539978, 351.360016, 366.299988, 381.359985, 396.540009, 365.279999, 386.970001, 408.839996, 430.889984, 453.120026, 475.529968, 498.119995, 520.890015, 373.920013, 396.329987, 418.919983, 441.690002, 464.640015, 487.769958, 511.079987, 534.570007, 382.559998, 405.690002, 429.000000, 452.489990, 476.160004, 500.010010, 524.039978, 548.250000, 391.200012, 415.049988, 439.080017, 463.290009, 487.679993, 512.250000, 537.000000, 561.930054, 399.839996, 424.409973, 449.160034, 474.089966, 499.200012, 524.489990, 549.959961, 575.609985, 408.479980, 433.770020, 459.239990, 484.889954, 510.720032, 536.729980, 562.919983, 589.290039, 417.119995, 443.130005, 469.319977, 495.690002, 522.239990, 548.969971, 575.880005, 602.969971, 425.760010, 452.489990, 479.399994, 506.489990, 533.760010, 561.209961, 588.839966, 616.650024, 216.000000, 233.819992, 251.760010, 269.820007, 288.000000, 306.299988, 324.719971, 343.260010, 369.600006, 388.859985, 408.239990, 427.739990, 447.360016, 467.100006, 486.959961, 506.940002, 451.679993, 480.570007, 509.639984, 538.890015, 568.320007, 597.929993, 627.719971, 657.690002, 460.320007, 489.929993, 519.719971, 549.690002, 579.840027, 610.170044, 640.680054, 671.369995, 468.960022, 499.289978, 529.799988, 560.489990, 591.359985, 622.409973, 653.640015, 685.049988, 477.599976, 508.650024, 539.880005, 571.289978, 602.880005, 634.650024, 666.599976, 698.729980, 486.239990, 518.010010, 549.960022, 582.089966, 614.400024, 646.890015, 679.559937, 712.410034, 494.879974, 527.369995, 560.039978, 592.890015, 625.920044, 659.130005, 692.520020, 726.089966, 503.519989, 536.729980, 570.119995, 603.689941, 637.440063, 671.369995, 705.480042, 739.770020, 512.160034, 546.089966, 580.199951, 614.489990, 648.960022, 683.609985, 718.440002, 753.449951, 254.400009, 277.020020, 299.760010, 322.619995, 345.600006, 368.700012, 391.919983, 415.260010, 446.399994, 470.459961, 494.640015, 518.940002, 543.360046, 567.900024, 592.559998, 617.340027, 538.080017, 574.170044, 610.440002, 646.890015, 683.520020, 720.329956, 757.320007, 794.489990, 546.719971, 583.530029, 620.520020, 657.690002, 695.040039, 732.570007, 770.279968, 808.169983, 555.359985, 592.889954, 630.599976, 668.489990, 706.559998, 744.809998, 783.239990, 821.849976, 564.000000, 602.250000, 640.679993, 679.289978, 718.080017, 757.050049, 796.199951, 835.530029, 572.640015, 611.609985, 650.760010, 690.089966, 729.600037, 769.289978, 809.160034, 849.210083, 581.279968, 620.970032, 660.839966, 700.889954, 741.119995, 781.529968, 822.119995, 862.890015, 589.919983, 630.330017, 670.919983, 711.690002, 752.640015, 793.770020, 835.079956, 876.570007, 598.559998, 639.690002, 681.000000, 722.490051, 764.160034, 806.010010, 848.039978, 890.250061, 292.799988, 320.220001, 347.760010, 375.419983, 403.200012, 431.100006, 459.119995, 487.260010, 523.199951, 552.059998, 581.040039, 610.139954, 639.360046, 668.699951, 698.159973, 727.739990, 624.479980, 667.770020, 711.239990, 754.890015, 798.719971, 842.729980, 886.919983, 931.290039, 633.119995, 677.130005, 721.319946, 765.690002, 810.239990, 854.969971, 899.880005, 944.969971, 641.760010, 686.489990, 731.400024, 776.489990, 821.760010, 867.209961, 912.839966, 958.650024, 650.400024, 695.849976, 741.479980, 787.290039, 833.279968, 879.449951, 925.799927, 972.330017, 659.040039, 705.210022, 751.559998, 798.089966, 844.800049, 891.690002, 938.760010, 986.010010, 667.679993, 714.569946, 761.640015, 808.890015, 856.320007, 903.929993, 951.719971, 999.690063, 676.320007, 723.929993, 771.719971, 819.690002, 867.839966, 916.169922, 964.679932, 1013.369995, 684.959961, 733.290039, 781.800049, 830.489990, 879.359985, 928.410034, 977.640015, 1027.050049, 331.199982, 363.419983, 395.760010, 428.220001, 460.799988, 493.500000, 526.320007, 559.260010, 600.000000, 633.660034, 667.440002, 701.339966, 735.359985, 769.500000, 803.759949, 838.140015, 710.880005, 761.369995, 812.039978, 862.889893, 913.919983, 965.130005, 1016.520020, 1068.090088, 719.520020, 770.729980, 822.119934, 873.689941, 925.440063, 977.369995, 1029.479980, 1081.770020, 728.160034, 780.090088, 832.199951, 884.489990, 936.960022, 989.610046, 1042.439941, 1095.449951, 736.799927, 789.449951, 842.280029, 895.290039, 948.480042, 1001.849976, 1055.399902, 1109.129883, 745.439941, 798.810059, 852.359985, 906.089966, 960.000000, 1014.089966, 1068.359985, 1122.810059, 754.080017, 808.170044, 862.440002, 916.890015, 971.520020, 1026.330078, 1081.319946, 1136.489990, 762.720032, 817.530029, 872.520020, 927.689941, 983.040039, 1038.569946, 1094.280029, 1150.169922, 771.359985, 826.890015, 882.599976, 938.489990, 994.559998, 1050.810059, 1107.239990, 1163.849976, 369.599976, 406.619995, 443.760010, 481.020020, 518.400024, 555.900024, 593.520020, 631.260010, 113.279999, 136.839996, 160.480011, 184.199982, 208.000015, 231.880005, 255.839996, 279.880005, 31.359985, 66.699989, 102.160004, 137.740005, 173.440002, 209.260010, 245.199982, 281.260010, 31.359993, 67.179993, 103.120003, 139.179993, 175.360016, 211.660004, 248.079987, 284.619995, 31.359993, 67.659996, 104.080009, 140.619995, 177.280014, 214.060013, 250.959991, 287.980011, 31.359993, 68.139999, 105.039993, 142.059982, 179.200027, 216.459991, 253.839996, 291.339996, 31.360008, 68.619995, 106.000000, 143.499985, 181.119995, 218.860001, 256.719971, 294.700012, 31.360001, 69.099991, 106.959984, 144.939987, 183.040009, 221.260010, 259.600006, 298.059998, 31.360008, 69.579971, 107.920006, 146.379990, 184.960007, 223.660004, 262.479980, 301.419983, 31.360001, 70.059975, 108.880020, 147.819977, 186.880020, 226.059998, 265.359985, 304.779999, -83.840004, -58.040001, -32.159988, -6.200012, 19.840012, 45.959984, 72.159996, 98.440010}, sd::DataType::FLOAT32); - - input.linspace(-10, 0.1); - weights.linspace(-2, 0.1); - - sd::ops::depthwise_conv2d op; - auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + int bS = 1, iH = 10, iW = 10, iC = 8, mC = 1, kH = 3, kW = 3, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 10, oW = 10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = + 2; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {mC, kH, kW, iC}, + {-2., -1.9, -1.8, -1.7, -1.6, -1.5, -1.4, -1.3, -1.2, -1.1, -1., -0.9, + -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0., 0.1, 0.2, 0.3, + 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2, 1.3, 1.4, 1.5, + 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, + 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, + 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 5., 5.1}, + sd::DataType::FLOAT32); + + NDArray expOutput( + 'c', {bS, oH, oW, oC}, + {-42.879997, -43.959999, -44.959999, -45.879997, -46.720005, + -47.480003, -48.160000, -48.760002, -43.519997, -45.139999, + -46.639996, -48.020000, -49.280003, -50.419998, -51.440006, + -52.340000, -31.999998, -33.139999, -34.160000, -35.060001, + -35.840004, -36.500004, -37.039997, -37.459999, -20.480000, + -21.139997, -21.680000, -22.100000, -22.399998, -22.579998, + -22.639996, -22.580002, -8.960000, -9.139998, -9.200002, + -9.140001, -8.960001, -8.660000, -8.240002, -7.700001, + 2.560000, 2.860002, 3.279998, 3.820000, 4.480001, + 5.260000, 6.160001, 7.180000, 14.080000, 14.860000, + 15.759998, 16.779999, 17.920002, 19.180000, 20.560001, + 22.059998, 25.600000, 26.860001, 28.239998, 29.739998, + 31.360001, 33.099998, 34.959999, 36.939999, 37.119999, + 38.860001, 40.720001, 42.699997, 44.800003, 47.020000, + 49.360001, 51.820000, 26.239998, 27.400002, 28.639999, + 29.959999, 31.360001, 32.840000, 34.400002, 36.040001, + 62.400002, 62.459999, 62.639999, 62.940002, 63.360001, + 63.900002, 64.559998, 65.340004, 106.080002, 106.169998, + 106.440002, 106.889999, 107.519997, 108.330002, 109.320000, + 110.490005, 114.720001, 115.529999, 116.520004, 117.690002, + 119.040009, 120.570000, 122.279999, 124.169998, 123.359985, + 124.889999, 126.599998, 128.490005, 130.559998, 132.809998, + 135.240005, 137.850006, 132.000000, 134.250000, 136.679993, + 139.290009, 142.080002, 145.049988, 148.199997, 151.529999, + 140.639999, 143.610001, 146.760010, 150.089996, 153.600006, + 157.290009, 161.160004, 165.209991, 149.279999, 152.970001, + 156.839996, 160.889999, 165.120010, 169.529999, 174.119995, + 178.889999, 157.919998, 162.330002, 166.919983, 171.690002, + 176.639999, 181.769989, 187.079987, 192.570007, 166.559998, + 171.690002, 177.000000, 182.489990, 188.160004, 194.010010, + 200.040009, 206.250000, 100.799995, 104.220001, 107.760002, + 111.419998, 115.200005, 119.099998, 123.120003, 127.260010, + 139.200012, 144.059998, 149.040009, 154.139999, 159.360001, + 164.699997, 170.160004, 175.739990, 192.479996, 199.770020, + 207.239990, 214.889999, 222.720001, 230.730011, 238.919998, + 247.290009, 201.119995, 209.129990, 217.319992, 225.690002, + 234.240005, 242.970001, 251.880005, 260.970001, 209.760010, + 218.489990, 227.399994, 236.490005, 245.760010, 255.209991, + 264.839996, 274.649994, 218.399994, 227.850006, 237.479996, + 247.289993, 257.279999, 267.449982, 277.799988, 288.330017, + 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, + 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, + 257.639984, 268.889984, 280.320007, 291.929993, 303.720001, + 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, + 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, + 265.290009, 277.799988, 290.489990, 303.359985, 316.410004, + 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, + 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, + 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, + 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, + 308.040009, 322.889984, 337.920013, 353.129974, 368.519989, + 384.090027, 287.520020, 302.730011, 318.119995, 333.690002, + 349.440002, 365.369995, 381.479980, 397.770020, 296.160004, + 312.089996, 328.199982, 344.489990, 360.960022, 377.609985, + 394.440002, 411.449982, 304.799988, 321.450012, 338.280029, + 355.289978, 372.480011, 389.850006, 407.399994, 425.130005, + 313.440002, 330.809998, 348.359985, 366.089996, 384.000000, + 402.090027, 420.359985, 438.809998, 322.079987, 340.169983, + 358.440002, 376.889984, 395.520020, 414.329987, 433.320007, + 452.489990, 330.720001, 349.530029, 368.520020, 387.690002, + 407.039978, 426.570007, 446.279999, 466.170013, 339.360016, + 358.890015, 378.599976, 398.490021, 418.559998, 438.809998, + 459.239990, 479.849976, 177.600006, 190.619995, 203.759995, + 217.020004, 230.399994, 243.899994, 257.519989, 271.260010, + 292.799988, 307.260010, 321.839996, 336.539978, 351.360016, + 366.299988, 381.359985, 396.540009, 365.279999, 386.970001, + 408.839996, 430.889984, 453.120026, 475.529968, 498.119995, + 520.890015, 373.920013, 396.329987, 418.919983, 441.690002, + 464.640015, 487.769958, 511.079987, 534.570007, 382.559998, + 405.690002, 429.000000, 452.489990, 476.160004, 500.010010, + 524.039978, 548.250000, 391.200012, 415.049988, 439.080017, + 463.290009, 487.679993, 512.250000, 537.000000, 561.930054, + 399.839996, 424.409973, 449.160034, 474.089966, 499.200012, + 524.489990, 549.959961, 575.609985, 408.479980, 433.770020, + 459.239990, 484.889954, 510.720032, 536.729980, 562.919983, + 589.290039, 417.119995, 443.130005, 469.319977, 495.690002, + 522.239990, 548.969971, 575.880005, 602.969971, 425.760010, + 452.489990, 479.399994, 506.489990, 533.760010, 561.209961, + 588.839966, 616.650024, 216.000000, 233.819992, 251.760010, + 269.820007, 288.000000, 306.299988, 324.719971, 343.260010, + 369.600006, 388.859985, 408.239990, 427.739990, 447.360016, + 467.100006, 486.959961, 506.940002, 451.679993, 480.570007, + 509.639984, 538.890015, 568.320007, 597.929993, 627.719971, + 657.690002, 460.320007, 489.929993, 519.719971, 549.690002, + 579.840027, 610.170044, 640.680054, 671.369995, 468.960022, + 499.289978, 529.799988, 560.489990, 591.359985, 622.409973, + 653.640015, 685.049988, 477.599976, 508.650024, 539.880005, + 571.289978, 602.880005, 634.650024, 666.599976, 698.729980, + 486.239990, 518.010010, 549.960022, 582.089966, 614.400024, + 646.890015, 679.559937, 712.410034, 494.879974, 527.369995, + 560.039978, 592.890015, 625.920044, 659.130005, 692.520020, + 726.089966, 503.519989, 536.729980, 570.119995, 603.689941, + 637.440063, 671.369995, 705.480042, 739.770020, 512.160034, + 546.089966, 580.199951, 614.489990, 648.960022, 683.609985, + 718.440002, 753.449951, 254.400009, 277.020020, 299.760010, + 322.619995, 345.600006, 368.700012, 391.919983, 415.260010, + 446.399994, 470.459961, 494.640015, 518.940002, 543.360046, + 567.900024, 592.559998, 617.340027, 538.080017, 574.170044, + 610.440002, 646.890015, 683.520020, 720.329956, 757.320007, + 794.489990, 546.719971, 583.530029, 620.520020, 657.690002, + 695.040039, 732.570007, 770.279968, 808.169983, 555.359985, + 592.889954, 630.599976, 668.489990, 706.559998, 744.809998, + 783.239990, 821.849976, 564.000000, 602.250000, 640.679993, + 679.289978, 718.080017, 757.050049, 796.199951, 835.530029, + 572.640015, 611.609985, 650.760010, 690.089966, 729.600037, + 769.289978, 809.160034, 849.210083, 581.279968, 620.970032, + 660.839966, 700.889954, 741.119995, 781.529968, 822.119995, + 862.890015, 589.919983, 630.330017, 670.919983, 711.690002, + 752.640015, 793.770020, 835.079956, 876.570007, 598.559998, + 639.690002, 681.000000, 722.490051, 764.160034, 806.010010, + 848.039978, 890.250061, 292.799988, 320.220001, 347.760010, + 375.419983, 403.200012, 431.100006, 459.119995, 487.260010, + 523.199951, 552.059998, 581.040039, 610.139954, 639.360046, + 668.699951, 698.159973, 727.739990, 624.479980, 667.770020, + 711.239990, 754.890015, 798.719971, 842.729980, 886.919983, + 931.290039, 633.119995, 677.130005, 721.319946, 765.690002, + 810.239990, 854.969971, 899.880005, 944.969971, 641.760010, + 686.489990, 731.400024, 776.489990, 821.760010, 867.209961, + 912.839966, 958.650024, 650.400024, 695.849976, 741.479980, + 787.290039, 833.279968, 879.449951, 925.799927, 972.330017, + 659.040039, 705.210022, 751.559998, 798.089966, 844.800049, + 891.690002, 938.760010, 986.010010, 667.679993, 714.569946, + 761.640015, 808.890015, 856.320007, 903.929993, 951.719971, + 999.690063, 676.320007, 723.929993, 771.719971, 819.690002, + 867.839966, 916.169922, 964.679932, 1013.369995, 684.959961, + 733.290039, 781.800049, 830.489990, 879.359985, 928.410034, + 977.640015, 1027.050049, 331.199982, 363.419983, 395.760010, + 428.220001, 460.799988, 493.500000, 526.320007, 559.260010, + 600.000000, 633.660034, 667.440002, 701.339966, 735.359985, + 769.500000, 803.759949, 838.140015, 710.880005, 761.369995, + 812.039978, 862.889893, 913.919983, 965.130005, 1016.520020, + 1068.090088, 719.520020, 770.729980, 822.119934, 873.689941, + 925.440063, 977.369995, 1029.479980, 1081.770020, 728.160034, + 780.090088, 832.199951, 884.489990, 936.960022, 989.610046, + 1042.439941, 1095.449951, 736.799927, 789.449951, 842.280029, + 895.290039, 948.480042, 1001.849976, 1055.399902, 1109.129883, + 745.439941, 798.810059, 852.359985, 906.089966, 960.000000, + 1014.089966, 1068.359985, 1122.810059, 754.080017, 808.170044, + 862.440002, 916.890015, 971.520020, 1026.330078, 1081.319946, + 1136.489990, 762.720032, 817.530029, 872.520020, 927.689941, + 983.040039, 1038.569946, 1094.280029, 1150.169922, 771.359985, + 826.890015, 882.599976, 938.489990, 994.559998, 1050.810059, + 1107.239990, 1163.849976, 369.599976, 406.619995, 443.760010, + 481.020020, 518.400024, 555.900024, 593.520020, 631.260010, + 113.279999, 136.839996, 160.480011, 184.199982, 208.000015, + 231.880005, 255.839996, 279.880005, 31.359985, 66.699989, + 102.160004, 137.740005, 173.440002, 209.260010, 245.199982, + 281.260010, 31.359993, 67.179993, 103.120003, 139.179993, + 175.360016, 211.660004, 248.079987, 284.619995, 31.359993, + 67.659996, 104.080009, 140.619995, 177.280014, 214.060013, + 250.959991, 287.980011, 31.359993, 68.139999, 105.039993, + 142.059982, 179.200027, 216.459991, 253.839996, 291.339996, + 31.360008, 68.619995, 106.000000, 143.499985, 181.119995, + 218.860001, 256.719971, 294.700012, 31.360001, 69.099991, + 106.959984, 144.939987, 183.040009, 221.260010, 259.600006, + 298.059998, 31.360008, 69.579971, 107.920006, 146.379990, + 184.960007, 223.660004, 262.479980, 301.419983, 31.360001, + 70.059975, 108.880020, 147.819977, 186.880020, 226.059998, + 265.359985, 304.779999, -83.840004, -58.040001, -32.159988, + -6.200012, 19.840012, 45.959984, 72.159996, 98.440010}, + sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + + sd::ops::depthwise_conv2d op; + auto results = op.evaluate( + {&input, &weights}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test1) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=4,oW=3; - int oC=iC*mC; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); - - NDArray expGradI('c', {bS, iH, iW, iC},{0.07 , 0.19 , 0.348, 0.652, 0.588, 0.956, 0.387, 0.687, 1.326, 2.022, 1.878, 2.67 , 1.071, 1.515, 2.982, 3.966, 3.534, 4.614, 1.606, 1.982, 3.932, 4.748, 4.428, 5.308, - 1.126, 1.63 , 3.228, 4.3 , 3.468, 4.604, 3.123, 3.999, 7.95 , 9.798, 8.502, 10.446, 3.807, 4.827, 9.606, 11.742,10.158, 12.39 , 4.198, 4.958, 9.884, 11.468,10.38 , 12.028}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {kH, kW, iC, mC},{19.08, 19.44,19.8 , 20.16,12.24, 12.48,12.72, 12.96,22.56, 23.04,23.52, 24. ,14.4 , 14.72,15.04, 15.36,14.76, 15.12,15.48, 15.84, 9.36, 9.6 , 9.84, 10.08}, sd::DataType::FLOAT32); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 2, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 3; + int oC = iC * mC; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3, 4}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + NDArray expGradI( + 'c', {bS, iH, iW, iC}, + {0.07, 0.19, 0.348, 0.652, 0.588, 0.956, 0.387, 0.687, 1.326, 2.022, + 1.878, 2.67, 1.071, 1.515, 2.982, 3.966, 3.534, 4.614, 1.606, 1.982, + 3.932, 4.748, 4.428, 5.308, 1.126, 1.63, 3.228, 4.3, 3.468, 4.604, + 3.123, 3.999, 7.95, 9.798, 8.502, 10.446, 3.807, 4.827, 9.606, 11.742, + 10.158, 12.39, 4.198, 4.958, 9.884, 11.468, 10.38, 12.028}, + sd::DataType::FLOAT32); + + NDArray expGradW('c', {kH, kW, iC, mC}, + {19.08, 19.44, 19.8, 20.16, 12.24, 12.48, 12.72, 12.96, + 22.56, 23.04, 23.52, 24., 14.4, 14.72, 15.04, 15.36, + 14.76, 15.12, 15.48, 15.84, 9.36, 9.6, 9.84, 10.08}, + sd::DataType::FLOAT32); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test2) { - - int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int oC=iC*mC; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); - auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); - - NDArray expGradI('c', {bS, iH, iW, iC},{0.005, 0.025,0.034, 0.106,0.061, 0.113,0.058, 0.162,0.292, 0.564,0.298, 0.466,0.234, 0.402,0.772, 1.172,0.602, 0.834,0.333, 0.449,0.882, 1.146,0.581, 0.729, - 0.053, 0.137,0.258, 0.458,0.237, 0.353,0.41 , 0.642,1.252, 1.78 ,0.906, 1.202,1.098, 1.394,2.756, 3.412,1.722, 2.082,0.893, 1.073,2.13 , 2.522,1.269, 1.481}, sd::DataType::FLOAT32); - NDArray expGradW('c', {kH, kW, iC, mC},{2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88}, sd::DataType::FLOAT32); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 2, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int oC = iC * mC; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3, 4}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + NDArray expGradI( + 'c', {bS, iH, iW, iC}, + {0.005, 0.025, 0.034, 0.106, 0.061, 0.113, 0.058, 0.162, 0.292, 0.564, + 0.298, 0.466, 0.234, 0.402, 0.772, 1.172, 0.602, 0.834, 0.333, 0.449, + 0.882, 1.146, 0.581, 0.729, 0.053, 0.137, 0.258, 0.458, 0.237, 0.353, + 0.41, 0.642, 1.252, 1.78, 0.906, 1.202, 1.098, 1.394, 2.756, 3.412, + 1.722, 2.082, 0.893, 1.073, 2.13, 2.522, 1.269, 1.481}, + sd::DataType::FLOAT32); + NDArray expGradW( + 'c', {kH, kW, iC, mC}, + {2.4, 2.56, 2.72, 2.88, 2.4, 2.56, 2.72, 2.88, 2.4, 2.56, 2.72, 2.88, + 2.4, 2.56, 2.72, 2.88, 2.4, 2.56, 2.72, 2.88, 2.4, 2.56, 2.72, 2.88}, + sd::DataType::FLOAT32); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test3) { - - auto in = NDArrayFactory::create('c', {4, 8, 64, 64}); - auto w = NDArrayFactory::create('c', {2, 2, 8, 2}); - auto b = NDArrayFactory::create('c', {1, 16}); - auto grad = NDArrayFactory::create('c', {4, 16, 64, 64}); - - auto gradI = in.like(); - auto gradW = w.like(); - auto gradB = b.like(); - - nd4j:ops::depthwise_conv2d_bp op; - auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}); - ASSERT_EQ(Status::OK(), status); + auto in = NDArrayFactory::create('c', {4, 8, 64, 64}); + auto w = NDArrayFactory::create('c', {2, 2, 8, 2}); + auto b = NDArrayFactory::create('c', {1, 16}); + auto grad = NDArrayFactory::create('c', {4, 16, 64, 64}); + + auto gradI = in.like(); + auto gradW = w.like(); + auto gradB = b.like(); + +nd4j: + ops::depthwise_conv2d_bp op; + auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, + {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}); + ASSERT_EQ(Status::OK(), status); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test4) { - - int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=10,oW=10; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, sd::DataType::FLOAT32); - - input.linspace(-10, 0.1); - weights.linspace(-2, 0.1); - gradO.linspace(10, -0.1); - - - NDArray expGradI('c', {bS, iH, iW, iC},{10.880001, 13.239998, 15.520001, 17.719997, 19.840000, 21.880001, 23.839998, 25.720001, 31.360004, 34.420002, 37.360001, 40.180004, 42.880005, 45.460003, 47.919994, 50.260002, 31.360001, 33.939999, 36.400002, 38.739998, 40.959999, 43.059998, 45.040001, 46.900005, 31.359997, 33.459999, 35.439999, 37.300003, 39.040001, 40.660000, 42.160000, 43.539997, 31.360001, 32.980000, 34.480000, 35.860001, 37.119999, 38.259998, 39.279999, 40.180000, 31.360001, 32.499996, 33.520000, 34.419998, 35.200001, 35.860001, 36.400002, 36.820000, 31.360001, 32.019997, 32.560001, 32.979996, 33.280003, 33.459999, 33.520000, 33.459999, 31.360001, 31.540001, 31.599998, 31.539999, 31.360001, 31.059999, 30.639999, 30.100000, 31.360001, 31.060001, 30.639999, 30.099998, 29.440002, 28.660000, 27.759998, 26.740000, 18.559999, 18.040001, 17.440001, 16.760000, 16.000000, 15.160000, 14.240001, 13.240000, 85.439995, 85.860001, 86.159996, 86.339996, 86.400002, 86.340012, 86.159996, 85.860008, 132.000000, 131.910004, 131.639999, 131.190002, 130.559998, 129.750000, 128.760010, 127.589996, 123.360001, 122.550003, 121.559998, 120.389999, 119.040009, 117.510002, 115.799988, 113.910004, 114.720001, 113.189995, 111.480003, 109.590004, 107.520004, 105.270004, 102.839996, 100.230011, 106.079994, 103.830002, 101.400009, 98.790009, 96.000008, - 93.030006, 89.879990, 86.549988, 97.439995, 94.469994, 91.319992, 87.990005, 84.479996, 80.789993, 76.919998, 72.870003, 88.800003, 85.110001, 81.239998, 77.190002, 72.960007, 68.550003, 63.959999, 59.190002, 80.160004, 75.750000, 71.160004, 66.389999, 61.440002, 56.309994, 51.000000, 45.510002, 71.519997, 66.389999, 61.079998, 55.590000, 49.919998, 44.070000, 38.040001, 31.830002, 31.680000, 27.780003, 23.760000, 19.619999, 15.360001, 10.980000, 6.480000, 1.859999, 47.040001, 42.660004, 38.160000, 33.540001, 28.799999, 23.939999, 18.960001, 13.860001, 45.599998, 38.310001, 30.840000, 23.190002, 15.360001, 7.349998, -0.840002, -9.210003, 36.959999, 28.950003, 20.759998, 12.390001, 3.839998, -4.889999, -13.799999, -22.890003, 28.320002, 19.589998, 10.680000, 1.590002, -7.680002, -17.129999, -26.759998, -36.570007, 19.680002, 10.230003, 0.599998, -9.210001, -19.199999, -29.370003, -39.720001, -50.250008, 11.039999, 0.869999, -9.480000, -20.010002, -30.719994, -41.610001, -52.679996, -63.930008, 2.400005, -8.489998, -19.560005, -30.809998, -42.239998, -53.849991, -65.639992, -77.610001, -6.239998, -17.849998, -29.639988, -41.609985, -53.760002, -66.090004, -78.599991, -91.290009, -14.879990, -27.209995, -39.720009, -52.410007, -65.279999, -78.330002, -91.559998, -104.969986, -45.119995, -53.820000, -62.639999, -71.580002, -80.640007, -89.819992, -99.119995, -108.540009, 8.639999, -0.540001, -9.839996, -19.259998, -28.799995, -38.459999, -48.240002, -58.140003, -40.799999, -55.289997, -69.960007, -84.810013, -99.840004, -115.050011, -130.440018, -146.010010, -49.439991, -64.650009, -80.040009, -95.610016, -111.360008, -127.290001, -143.399994, -159.690018, -58.080009, -74.009987, -90.119995, -106.409988, -122.880005, -139.530014, -156.360001, -173.369995, -66.720001, -83.369995, -100.199997, - -117.209999, -134.399994, -151.769989, -169.319992, -187.049988, -75.360008, -92.729996, -110.279991, -128.009979, -145.920013, -164.009995, -182.279984, -200.729996, -84.000000, -102.089996, -120.360016, -138.809967, -157.440002, -176.249969, -195.240005, -214.410019, -92.639999, -111.449997, -130.440018, -149.610016, -168.960007, -188.489990, -208.200012, -228.090012, -101.279976, -120.809982, -140.519989, -160.410004, -180.480011, -200.730011, -221.160034, -241.770020, -121.920006, -135.420013, -149.040009, -162.779999, -176.640015, -190.619995, -204.719986, -218.940002, -29.760002, -43.739998, -57.840000, -72.059998, -86.400009, -100.860001, -115.439995, -130.140015, -127.199997, -148.890015, -170.760010, -192.809998, -215.040024, -237.450012, -260.039978, -282.809998, -135.839996, -158.250000, -180.840012, -203.610046, -226.559982, -249.690002, -272.999969, -296.489990, -144.479980, -167.609985, -190.920013, -214.410019, -238.080032, -261.929993, -285.959991, -310.169983, -153.119995, -176.969986, -201.000031, -225.210022, -249.599976, -274.170013, -298.920013, -323.849976, -161.760040, -186.330017, -211.079987, -236.009995, -261.120026, -286.410034, -311.879974, -337.530029, -170.400009, -195.689987, -221.159973, -246.809998, -272.639954, -298.650024, -324.840057, -351.209991, -179.039963, -205.050018, -231.240021, -257.609985, -284.160004, -310.890015, -337.799988, -364.890015, -187.680023, -214.410004, -241.319977, -268.410004, -295.679993, -323.130005, -350.760010, -378.570038, -198.720016, -217.019989, -235.440002, -253.979980, -272.640045, -291.419983, -310.319977, -329.339996, -68.159981, -86.939987, -105.840012, -124.860001, -144.000000, -163.260010, -182.639984, -202.140015, -213.600021, -242.489990, -271.559937, -300.809998, -330.239990, -359.849976, -389.639984, - -419.610016, -222.240036, -251.849960, -281.640015, -311.609985, -341.760040, -372.089996, -402.600037, -433.290009, -230.880005, -261.210022, -291.719971, -322.410034, -353.280029, -384.329956, -415.559998, -446.970001, -239.519989, -270.570007, -301.800018, -333.209991, -364.800018, -396.570007, -428.520020, -460.650024, -248.160034, -279.929962, -311.880005, -344.010010, -376.320038, -408.809998, -441.479980, -474.330017, -256.799988, -289.289978, -321.960022, -354.809967, -387.839996, -421.050018, -454.440002, -488.009979, -265.440002, -298.650024, -332.040009, -365.609985, -399.360016, -433.290009, -467.399963, -501.689941, -274.080017, -308.009949, -342.119995, -376.409973, -410.880005, -445.530029, -480.359985, -515.369995, -275.520020, -298.619995, -321.839966, -345.179993, -368.640015, -392.220001, -415.919952, -439.740021, -106.560005, -130.140030, -153.840027, -177.659973, -201.599991, -225.660019, -249.840012, -274.140015, -300.000000, -336.090057, -372.360046, -408.809937, -445.440002, -482.250031, -519.240051, -556.410034, -308.640015, -345.450012, -382.440002, -419.609955, -456.959961, -494.489960, -532.200012, -570.089966, -317.280029, -354.809998, -392.520020, -430.410004, -468.480042, -506.729980, -545.159912, -583.770020, -325.920013, -364.169952, -402.600037, -441.210022, -480.000000, -518.970032, -558.119873, -597.449951, -334.559967, -373.529999, -412.679993, -452.009949, -491.519989, -531.209961, -571.080017, -611.129944, -343.200012, -382.889984, -422.760071, -462.809906, -503.039978, -543.449951, -584.039978, -624.809998, -351.839966, -392.250000, -432.839966, -473.609955, -514.560120, -555.689941, -596.999939, -638.489990, -360.480011, -401.610016, -442.920044, -484.409912, -526.080017, -567.929993, -609.959961, -652.169983, -352.320007, -380.220001, - -408.239990, -436.380005, -464.639984, -493.019989, -521.519958, -550.139954, -144.960022, -173.339996, -201.839996, -230.459976, -259.200043, -288.059998, -317.039978, -346.140015, -386.399963, -429.690002, -473.159912, -516.809937, -560.640076, -604.650024, -648.839966, -693.210022, -395.039978, -439.050018, -483.239929, -527.609985, -572.159973, -616.890015, -661.799988, -706.890015, -403.680023, -448.409973, -493.320007, -538.410034, -583.680054, -629.129944, -674.760010, -720.570068, -412.320007, -457.769897, -503.399963, -549.210083, -595.199951, -641.369995, -687.720093, -734.250000, -420.960052, -467.130035, -513.479980, -560.010010, -606.720093, -653.610046, -700.680054, -747.930115, -429.599976, -476.489990, -523.559998, -570.809937, -618.239990, -665.849976, -713.640015, -761.609985, -438.239990, -485.850037, -533.640015, -581.610046, -629.760010, -678.089966, -726.600037, -775.289917, -446.880035,-495.210052, -543.719971, -592.410034, -641.279968, -690.330017, -739.559937, -788.970093, -429.120026, -461.819946, -494.639984, -527.580017, -560.640015, -593.820007, -627.119995, -660.540039, -183.360016, -216.540009, -249.839996, -283.260040, -316.800018, -350.459961, -384.239990, -418.139984, -472.800049, -523.289917, -573.959961, -624.809998, -675.839966, -727.050049, -778.440063, -830.010010, -481.440002, -532.649963, -584.040100, -635.609985, -687.359924, -739.290039, -791.399963, -843.689941, -490.079987, -542.010010, -594.119995, -646.410034, -698.880005, -751.529968, -804.359985, -857.369995, -498.720032, -551.369995, -604.200012, -657.210022, -710.400024, -763.770081, -817.319946, -871.050049, -507.359955, -560.729919, -614.280029, -668.010010, -721.919983, -776.010010, -830.280029, -884.730042, -515.999939, -570.089966, -624.360046, -678.809937, -733.440002, - -788.250000, -843.239990, -898.410034, -524.639954, -579.449951, -634.440002, -689.609985, -744.960022, -800.489990, -856.200012, -912.090027, -533.280029, -588.810059, -644.520081, -700.409973, -756.480042, -812.730103, -869.159912, -925.769958, -505.920013, -543.420044, -581.040039, -618.780029, -656.640015, -694.620056, -732.719971, -770.940002, -447.359985, -471.559998, -495.840027, -520.200012, -544.640015, -569.159973, -593.760010, -618.440002, -815.359985, -852.140015, -889.040039, -926.059937, -963.200073, -1000.460022, -1037.839966, -1075.339966, -826.879944, -864.139954, -901.519958, -939.019958, -976.640076, -1014.379944, -1052.239990, -1090.219971, -838.400024, -876.140015, -913.999939, -951.979919, -990.080017, -1028.299927, -1066.640015, -1105.099976, -849.919983, -888.140015, -926.479980, -964.939941, -1003.520081, -1042.219971, -1081.040039, -1119.979980, -861.440063, -900.140015, -938.960022,-977.899963, -1016.960022, -1056.140015, -1095.440063, -1134.859985, -872.960022, -912.140015, -951.439941, -990.859985, -1030.400024, -1070.060059, -1109.839844, -1149.739990, -884.479980, -924.140015, -963.919922, -1003.819946, -1043.839966, -1083.979980, -1124.239990, -1164.619995, -896.000000, -936.140015, -976.399963, -1016.780029, -1057.280029, -1097.899902, -1138.640015, -1179.500122, -705.919983, -733.000000, -760.159912, -787.400024, -814.719971, -842.119995, -869.599976, -897.160034}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {kH, kW, iC, mC},{-104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875, - -107702.734375, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104824.789062, - -105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -126744.000000, -127277.710938, -127813.187500, - -128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -140944.000000, -141536.984375, -142131.984375, -142729.000000, -143328.000000, - -143929.015625, -144532.000000, -145137.000000, -126744.000000, -127277.710938, -127813.187500, -128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -104824.789062, -105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875, -107702.734375}, sd::DataType::FLOAT32); - - NDArray expGradB('c', {oC}, {-2960., -2970., -2980., -2990., -3000., -3010., -3020., -3030.}, sd::DataType::FLOAT32); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); - + int bS = 1, iH = 10, iW = 10, iC = 8, mC = 1, kH = 3, kW = 3, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 10, oW = 10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + gradO.linspace(10, -0.1); + + NDArray expGradI( + 'c', {bS, iH, iW, iC}, + {10.880001, 13.239998, 15.520001, 17.719997, 19.840000, + 21.880001, 23.839998, 25.720001, 31.360004, 34.420002, + 37.360001, 40.180004, 42.880005, 45.460003, 47.919994, + 50.260002, 31.360001, 33.939999, 36.400002, 38.739998, + 40.959999, 43.059998, 45.040001, 46.900005, 31.359997, + 33.459999, 35.439999, 37.300003, 39.040001, 40.660000, + 42.160000, 43.539997, 31.360001, 32.980000, 34.480000, + 35.860001, 37.119999, 38.259998, 39.279999, 40.180000, + 31.360001, 32.499996, 33.520000, 34.419998, 35.200001, + 35.860001, 36.400002, 36.820000, 31.360001, 32.019997, + 32.560001, 32.979996, 33.280003, 33.459999, 33.520000, + 33.459999, 31.360001, 31.540001, 31.599998, 31.539999, + 31.360001, 31.059999, 30.639999, 30.100000, 31.360001, + 31.060001, 30.639999, 30.099998, 29.440002, 28.660000, + 27.759998, 26.740000, 18.559999, 18.040001, 17.440001, + 16.760000, 16.000000, 15.160000, 14.240001, 13.240000, + 85.439995, 85.860001, 86.159996, 86.339996, 86.400002, + 86.340012, 86.159996, 85.860008, 132.000000, 131.910004, + 131.639999, 131.190002, 130.559998, 129.750000, 128.760010, + 127.589996, 123.360001, 122.550003, 121.559998, 120.389999, + 119.040009, 117.510002, 115.799988, 113.910004, 114.720001, + 113.189995, 111.480003, 109.590004, 107.520004, 105.270004, + 102.839996, 100.230011, 106.079994, 103.830002, 101.400009, + 98.790009, 96.000008, 93.030006, 89.879990, 86.549988, + 97.439995, 94.469994, 91.319992, 87.990005, 84.479996, + 80.789993, 76.919998, 72.870003, 88.800003, 85.110001, + 81.239998, 77.190002, 72.960007, 68.550003, 63.959999, + 59.190002, 80.160004, 75.750000, 71.160004, 66.389999, + 61.440002, 56.309994, 51.000000, 45.510002, 71.519997, + 66.389999, 61.079998, 55.590000, 49.919998, 44.070000, + 38.040001, 31.830002, 31.680000, 27.780003, 23.760000, + 19.619999, 15.360001, 10.980000, 6.480000, 1.859999, + 47.040001, 42.660004, 38.160000, 33.540001, 28.799999, + 23.939999, 18.960001, 13.860001, 45.599998, 38.310001, + 30.840000, 23.190002, 15.360001, 7.349998, -0.840002, + -9.210003, 36.959999, 28.950003, 20.759998, 12.390001, + 3.839998, -4.889999, -13.799999, -22.890003, 28.320002, + 19.589998, 10.680000, 1.590002, -7.680002, -17.129999, + -26.759998, -36.570007, 19.680002, 10.230003, 0.599998, + -9.210001, -19.199999, -29.370003, -39.720001, -50.250008, + 11.039999, 0.869999, -9.480000, -20.010002, -30.719994, + -41.610001, -52.679996, -63.930008, 2.400005, -8.489998, + -19.560005, -30.809998, -42.239998, -53.849991, -65.639992, + -77.610001, -6.239998, -17.849998, -29.639988, -41.609985, + -53.760002, -66.090004, -78.599991, -91.290009, -14.879990, + -27.209995, -39.720009, -52.410007, -65.279999, -78.330002, + -91.559998, -104.969986, -45.119995, -53.820000, -62.639999, + -71.580002, -80.640007, -89.819992, -99.119995, -108.540009, + 8.639999, -0.540001, -9.839996, -19.259998, -28.799995, + -38.459999, -48.240002, -58.140003, -40.799999, -55.289997, + -69.960007, -84.810013, -99.840004, -115.050011, -130.440018, + -146.010010, -49.439991, -64.650009, -80.040009, -95.610016, + -111.360008, -127.290001, -143.399994, -159.690018, -58.080009, + -74.009987, -90.119995, -106.409988, -122.880005, -139.530014, + -156.360001, -173.369995, -66.720001, -83.369995, -100.199997, + -117.209999, -134.399994, -151.769989, -169.319992, -187.049988, + -75.360008, -92.729996, -110.279991, -128.009979, -145.920013, + -164.009995, -182.279984, -200.729996, -84.000000, -102.089996, + -120.360016, -138.809967, -157.440002, -176.249969, -195.240005, + -214.410019, -92.639999, -111.449997, -130.440018, -149.610016, + -168.960007, -188.489990, -208.200012, -228.090012, -101.279976, + -120.809982, -140.519989, -160.410004, -180.480011, -200.730011, + -221.160034, -241.770020, -121.920006, -135.420013, -149.040009, + -162.779999, -176.640015, -190.619995, -204.719986, -218.940002, + -29.760002, -43.739998, -57.840000, -72.059998, -86.400009, + -100.860001, -115.439995, -130.140015, -127.199997, -148.890015, + -170.760010, -192.809998, -215.040024, -237.450012, -260.039978, + -282.809998, -135.839996, -158.250000, -180.840012, -203.610046, + -226.559982, -249.690002, -272.999969, -296.489990, -144.479980, + -167.609985, -190.920013, -214.410019, -238.080032, -261.929993, + -285.959991, -310.169983, -153.119995, -176.969986, -201.000031, + -225.210022, -249.599976, -274.170013, -298.920013, -323.849976, + -161.760040, -186.330017, -211.079987, -236.009995, -261.120026, + -286.410034, -311.879974, -337.530029, -170.400009, -195.689987, + -221.159973, -246.809998, -272.639954, -298.650024, -324.840057, + -351.209991, -179.039963, -205.050018, -231.240021, -257.609985, + -284.160004, -310.890015, -337.799988, -364.890015, -187.680023, + -214.410004, -241.319977, -268.410004, -295.679993, -323.130005, + -350.760010, -378.570038, -198.720016, -217.019989, -235.440002, + -253.979980, -272.640045, -291.419983, -310.319977, -329.339996, + -68.159981, -86.939987, -105.840012, -124.860001, -144.000000, + -163.260010, -182.639984, -202.140015, -213.600021, -242.489990, + -271.559937, -300.809998, -330.239990, -359.849976, -389.639984, + -419.610016, -222.240036, -251.849960, -281.640015, -311.609985, + -341.760040, -372.089996, -402.600037, -433.290009, -230.880005, + -261.210022, -291.719971, -322.410034, -353.280029, -384.329956, + -415.559998, -446.970001, -239.519989, -270.570007, -301.800018, + -333.209991, -364.800018, -396.570007, -428.520020, -460.650024, + -248.160034, -279.929962, -311.880005, -344.010010, -376.320038, + -408.809998, -441.479980, -474.330017, -256.799988, -289.289978, + -321.960022, -354.809967, -387.839996, -421.050018, -454.440002, + -488.009979, -265.440002, -298.650024, -332.040009, -365.609985, + -399.360016, -433.290009, -467.399963, -501.689941, -274.080017, + -308.009949, -342.119995, -376.409973, -410.880005, -445.530029, + -480.359985, -515.369995, -275.520020, -298.619995, -321.839966, + -345.179993, -368.640015, -392.220001, -415.919952, -439.740021, + -106.560005, -130.140030, -153.840027, -177.659973, -201.599991, + -225.660019, -249.840012, -274.140015, -300.000000, -336.090057, + -372.360046, -408.809937, -445.440002, -482.250031, -519.240051, + -556.410034, -308.640015, -345.450012, -382.440002, -419.609955, + -456.959961, -494.489960, -532.200012, -570.089966, -317.280029, + -354.809998, -392.520020, -430.410004, -468.480042, -506.729980, + -545.159912, -583.770020, -325.920013, -364.169952, -402.600037, + -441.210022, -480.000000, -518.970032, -558.119873, -597.449951, + -334.559967, -373.529999, -412.679993, -452.009949, -491.519989, + -531.209961, -571.080017, -611.129944, -343.200012, -382.889984, + -422.760071, -462.809906, -503.039978, -543.449951, -584.039978, + -624.809998, -351.839966, -392.250000, -432.839966, -473.609955, + -514.560120, -555.689941, -596.999939, -638.489990, -360.480011, + -401.610016, -442.920044, -484.409912, -526.080017, -567.929993, + -609.959961, -652.169983, -352.320007, -380.220001, -408.239990, + -436.380005, -464.639984, -493.019989, -521.519958, -550.139954, + -144.960022, -173.339996, -201.839996, -230.459976, -259.200043, + -288.059998, -317.039978, -346.140015, -386.399963, -429.690002, + -473.159912, -516.809937, -560.640076, -604.650024, -648.839966, + -693.210022, -395.039978, -439.050018, -483.239929, -527.609985, + -572.159973, -616.890015, -661.799988, -706.890015, -403.680023, + -448.409973, -493.320007, -538.410034, -583.680054, -629.129944, + -674.760010, -720.570068, -412.320007, -457.769897, -503.399963, + -549.210083, -595.199951, -641.369995, -687.720093, -734.250000, + -420.960052, -467.130035, -513.479980, -560.010010, -606.720093, + -653.610046, -700.680054, -747.930115, -429.599976, -476.489990, + -523.559998, -570.809937, -618.239990, -665.849976, -713.640015, + -761.609985, -438.239990, -485.850037, -533.640015, -581.610046, + -629.760010, -678.089966, -726.600037, -775.289917, -446.880035, + -495.210052, -543.719971, -592.410034, -641.279968, -690.330017, + -739.559937, -788.970093, -429.120026, -461.819946, -494.639984, + -527.580017, -560.640015, -593.820007, -627.119995, -660.540039, + -183.360016, -216.540009, -249.839996, -283.260040, -316.800018, + -350.459961, -384.239990, -418.139984, -472.800049, -523.289917, + -573.959961, -624.809998, -675.839966, -727.050049, -778.440063, + -830.010010, -481.440002, -532.649963, -584.040100, -635.609985, + -687.359924, -739.290039, -791.399963, -843.689941, -490.079987, + -542.010010, -594.119995, -646.410034, -698.880005, -751.529968, + -804.359985, -857.369995, -498.720032, -551.369995, -604.200012, + -657.210022, -710.400024, -763.770081, -817.319946, -871.050049, + -507.359955, -560.729919, -614.280029, -668.010010, -721.919983, + -776.010010, -830.280029, -884.730042, -515.999939, -570.089966, + -624.360046, -678.809937, -733.440002, -788.250000, -843.239990, + -898.410034, -524.639954, -579.449951, -634.440002, -689.609985, + -744.960022, -800.489990, -856.200012, -912.090027, -533.280029, + -588.810059, -644.520081, -700.409973, -756.480042, -812.730103, + -869.159912, -925.769958, -505.920013, -543.420044, -581.040039, + -618.780029, -656.640015, -694.620056, -732.719971, -770.940002, + -447.359985, -471.559998, -495.840027, -520.200012, -544.640015, + -569.159973, -593.760010, -618.440002, -815.359985, -852.140015, + -889.040039, -926.059937, -963.200073, -1000.460022, -1037.839966, + -1075.339966, -826.879944, -864.139954, -901.519958, -939.019958, + -976.640076, -1014.379944, -1052.239990, -1090.219971, -838.400024, + -876.140015, -913.999939, -951.979919, -990.080017, -1028.299927, + -1066.640015, -1105.099976, -849.919983, -888.140015, -926.479980, + -964.939941, -1003.520081, -1042.219971, -1081.040039, -1119.979980, + -861.440063, -900.140015, -938.960022, -977.899963, -1016.960022, + -1056.140015, -1095.440063, -1134.859985, -872.960022, -912.140015, + -951.439941, -990.859985, -1030.400024, -1070.060059, -1109.839844, + -1149.739990, -884.479980, -924.140015, -963.919922, -1003.819946, + -1043.839966, -1083.979980, -1124.239990, -1164.619995, -896.000000, + -936.140015, -976.399963, -1016.780029, -1057.280029, -1097.899902, + -1138.640015, -1179.500122, -705.919983, -733.000000, -760.159912, + -787.400024, -814.719971, -842.119995, -869.599976, -897.160034}, + sd::DataType::FLOAT32); + + NDArray expGradW( + 'c', {kH, kW, iC, mC}, + {-104306.421875, -104786.734375, -105268.687500, -105752.250000, + -106237.421875, -106724.242188, -107212.671875, -107702.734375, + -116289.593750, -116823.296875, -117358.781250, -117896.109375, + -118435.210938, -118976.109375, -119518.796875, -120063.296875, + -104824.789062, -105305.117188, -105787.070312, -106270.640625, + -106755.843750, -107242.640625, -107731.078125, -108221.117188, + -126744.000000, -127277.710938, -127813.187500, -128350.484375, + -128889.601562, -129430.515625, -129973.210938, -130517.703125, + -140944.000000, -141536.984375, -142131.984375, -142729.000000, + -143328.000000, -143929.015625, -144532.000000, -145137.000000, + -126744.000000, -127277.710938, -127813.187500, -128350.484375, + -128889.601562, -129430.515625, -129973.210938, -130517.703125, + -104824.789062, -105305.117188, -105787.070312, -106270.640625, + -106755.843750, -107242.640625, -107731.078125, -108221.117188, + -116289.593750, -116823.296875, -117358.781250, -117896.109375, + -118435.210938, -118976.109375, -119518.796875, -120063.296875, + -104306.421875, -104786.734375, -105268.687500, -105752.250000, + -106237.421875, -106724.242188, -107212.671875, -107702.734375}, + sd::DataType::FLOAT32); + + NDArray expGradB( + 'c', {oC}, + {-2960., -2970., -2980., -2990., -3000., -3010., -3020., -3030.}, + sd::DataType::FLOAT32); + + sd::ops::depthwise_conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test5) { - - int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oC=iC*mC; - int oH=10,oW=10; - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, sd::DataType::FLOAT32); - - input.linspace(-10, 0.1); - weights.linspace(-2, 0.1); - gradO.linspace(10, -0.1); - - - NDArray expGradI('c', {bS, iC, iH, iW}, {-12.639999, 3.920004, 3.920000, 3.920000, 3.920002, 3.920000, 3.920000, 3.919998, 3.919998, 16.319998, 52.680004, 111.000015, 109.919991, 108.840004, 107.760002, 106.680008, 105.600006, 104.519997, 103.440018, 87.960007, 47.880001, 100.200005, 99.119995, 98.040001, 96.959999, 95.879990, 94.799995, 93.720001, 92.639999, 78.360001, 43.079998, 89.399994, 88.320007, 87.240005, 86.159996, 85.079994, 84.000000, 82.919998, 81.840004, 68.759995, 38.279999, 78.600006, 77.519997, 76.440010, 75.360001, 74.279999, 73.200005, 72.120003, 71.040001, 59.160004, 33.480000, 67.799995, 66.720009, 65.639999, 64.559998, 63.480000, 62.399994, 61.320007, 60.240002, 49.559998, 28.680004, 57.000004, 55.919998, 54.839993, 53.759998, 52.680000, 51.600002, 50.519997, 49.440002, 39.959999, 23.880001, 46.200001, 45.120003, 44.039997, 42.959999, 41.880001, 40.799999, 39.719994, 38.639999, 30.360001, 19.079998, 35.400002, 34.320000, 33.239998, 32.159996, 31.080000, 29.999998, 28.919998, 27.840000, 20.759998, 14.079999, 24.080000, 22.639997, 21.200001, 19.759998, 18.320002, 16.880001, 15.440001, 14.000000, 9.759999, 3.140000, 3.560000, 3.500000, 3.440000, 3.380000, 3.320000, 3.260000, 3.200000, 3.140000, -0.220000, 4.050000, 2.010000, 0.840000, -0.330000, -1.499999, -2.670000, -3.840000, -5.010000, -6.179998, -9.150000, -1.350000, -9.690001, -10.859999, -12.029998, -13.200001, -14.370001, -15.539999, -16.710001, -17.879999, -19.349998, -6.750000, -21.389997, -22.560003, -23.730003, -24.900002, -26.069998, -27.239998, -28.410007, -29.580002, -29.550003, -12.150001, -33.089996, -34.260002, -35.430000, -36.600002, -37.770000, -38.939995, -40.110001, -41.280003, -39.749996, -17.550003, -44.790005, -45.959991, -47.129993, -48.300003, -49.470001, -50.640003, -51.809990, -52.979996, -49.950001, -22.949999, -56.490005, -57.660000, -58.829998, -60.000000, -61.170002, -62.340004, -63.510002, -64.680000, - -60.149994, -28.349998, -68.189987, -69.360001, -70.529999, -71.700005, -72.870010, -74.039993, -75.209999, -76.379990, -70.349998, -33.749996, -79.889999, -81.059990, -82.229988, -83.399994, -84.570007, -85.740005, -86.910004, -88.079994, -80.549995, -69.340004, -125.080002, -126.580002, -128.080002, -129.580002, -131.080002, -132.580002, -134.080002, -135.580002, -105.979996, 10.919998, -8.799997, -8.919998, -9.040003, -9.160004, -9.279999, -9.400002, -9.520002, -9.640003, -24.760000, -56.580009, -124.980003, -126.240005, -127.499992, -128.759995, -130.020020, -131.279999, -132.540009, -133.800003, -118.260002, -62.580009, -137.580002, -138.840012, -140.099991, -141.360001, -142.620010, -143.879974, -145.139999, -146.399994, -129.060013, -68.580002, -150.179993, -151.439987, -152.699997, -153.959991, -155.219986, -156.480011, -157.740005, -159.000000, -139.860001, -74.579994, -162.779999, -164.040024, -165.300003, -166.560028, -167.819977, -169.080002, -170.339996, -171.599991, -150.660004, -80.580002, -175.379990, -176.639999, -177.899994, -179.160019, -180.419998, -181.679993, -182.940002, -184.199997, -161.459991, -86.580002, -187.979996, -189.240005, -190.499985, -191.759995, -193.020020, -194.279999, -195.540024, -196.800018, -172.260010, -92.580002, -200.579987, -201.839981, -203.100006, -204.359970, -205.620010, -206.880005, -208.139999, -209.399994, -183.060013, -98.580002, -213.180023, -214.440002, -215.700012, -216.959991, -218.220001, -219.480011, -220.739975, -222.000000, -193.860001, -160.760010, -286.239990, -287.799988, -289.360016, -290.920013, -292.480011, -294.040009, -295.599976, -297.160004, -229.719986, 10.700003, -33.160004, -33.339996, -33.519993, -33.700001, - -33.879997, -34.059994, -34.239994, -34.419994, -57.299995, -129.209991, -269.969971, -271.319977, -272.670044, -274.019989, -275.369995, -276.720001, -278.070007, -279.420013, -239.369980, -135.809998, -283.470001, -284.820007, -286.169983, -287.520020, -288.869995, -290.220001, -291.570038, -292.919983, -250.770004, -142.410004, -296.969971, -298.320007, -299.669983, -301.020020, -302.369995, -303.719971, -305.070007, -306.419983, -262.169983, -149.009995, -310.470001, -311.820007, -313.170013, -314.519989, -315.869995, -317.220001, -318.570007, -319.919983, -273.570007, -155.610016, -323.969971, -325.320038, -326.669983, -328.020020, -329.369965, -330.719971, -332.070007, -333.419983, -284.970001, -162.209991, -337.469971, -338.820007, -340.169983, -341.519958, -342.869995, -344.220001, -345.570007, -346.920013, -296.369995, -168.809998, -350.970001, -352.320007, -353.669983, -355.019989, -356.369995, -357.719971, -359.070038, -360.419983, -307.769989, -175.410004, -364.469971, -365.820007, -367.169983, -368.520020, -369.869995, -371.219971, -372.570007, -373.919983, -319.169983, -260.179993, -459.399994, -461.019958, -462.639984, -464.260010, -465.880005, -467.500000, -469.119995, -470.739990, -361.459991, 2.480003, -69.520004, -69.760025, -70.000000, -70.239990, -70.479996, -70.720001, -70.960007, -71.200005, -97.839996, -213.840012, -432.960022, -434.400055, -435.840027, -437.279999, -438.720001, -440.160065, -441.599976, -443.040039, -372.480011, -221.040009, -447.360016, -448.800018, -450.239990, -451.679993, -453.119995, -454.559967, -456.000061, -457.440033, -384.480011, -228.239990, -461.759979, -463.200012, -464.639984, -466.079956, -467.520081, -468.960052, -470.399963, -471.839996, -396.479980, -235.440002, -476.159912, - -477.600006, -479.040039, -480.479980, -481.919952, -483.360046, -484.800079, -486.239990, -408.480042, -242.639999, -490.559967, -491.999969, -493.440063, -494.880035, -496.319946, -497.759979, -499.200012, -500.639984, -420.480011, -249.840012, -504.960052, -506.399963, -507.839996, -509.280029, -510.720001, -512.159973, -513.599976, -515.040039, -432.480011, -257.040009, -519.360046, -520.800049, -522.239990, -523.680054, -525.120056, -526.559998, -527.999939, -529.440002, -444.480011, -264.239990, -533.760010, -535.200012, -536.640015, -538.079956, -539.520020, -540.960022, -542.399963, -543.839966, -456.479980, -367.599976, -644.559998, -646.239929, -647.920044, -649.599976, -651.280029, -652.960022, -654.640076, -656.320007, -501.200043, -13.740002, -117.880005, -118.179993, -118.479996, -118.780014, -119.080002, -119.379990, -119.680008, -119.979996, -146.379990, -310.470001, -613.950012, -615.479980, -617.010071, -618.539978, -620.069946, -621.599976, -623.130005, -624.660034, -517.589966, -318.269958, -629.250000, -630.779968, -632.309937, -633.840027, -635.369995, -636.899902, -638.429993, -639.959961, -530.190063, -326.070038, -644.550049, -646.079956, -647.609985, -649.140015, -650.669922, -652.200012, -653.729980, -655.260010, -542.789978, -333.870026, -659.849976, -661.380005, -662.910034, -664.439941, -665.970093, -667.500000, -669.029968, -670.559937, -555.390015, -341.669983, -675.149902, -676.679993, -678.209961, -679.740051, -681.270020, -682.800049, -684.329956, -685.859985, -567.989990, -349.470001, -690.450012, -691.979980, -693.510010, -695.039978, -696.569946, -698.099976, -699.630005, -701.160034, -580.589966, -357.269958, -705.750000, -707.279968, -708.809937, -710.340027, -711.869995, -713.399902, -714.929993, -716.459961, -593.190002, -365.070038, -721.050049, -722.579956, -724.109985, -725.640015, -727.169922, -728.700012, - -730.229980, -731.760010, -605.789978, -483.019958, -841.719971, -843.460022, -845.200073, -846.939941, -848.680054, -850.419983, -852.159973, -853.899963, -648.940002, -37.960014, -178.240021, -178.599976, -178.959991, -179.320007, -179.679993, -180.039978, -180.399994, -180.759964, -202.919983, -419.099915, -812.939941, -814.559937, -816.179993, -817.800049, -819.419922, -821.040039, -822.660034, -824.279968, -674.699951, -427.500031, -829.140015, -830.759949, -832.380005, -833.999939, -835.619995, -837.240051, -838.859924, -840.479980, -687.899963, -435.899994, -845.339966, -846.959961, -848.579956, -850.200012, -851.819885, -853.439941, -855.059937, -856.679993, -701.100037, -444.299927, -861.540039, -863.160034, -864.779968, -866.399963, -868.020020, -869.640015, -871.259949, -872.880005, -714.299988, -452.700012, -877.740051, -879.359924, -880.979980, -882.599915, -884.219971, -885.839966, -887.459961, -889.079956, -727.500000, -461.099915, -893.939941, -895.559937, -897.179993, -898.800049, -900.419922, -902.040039, -903.660034, -905.279968, -740.700012, -469.499969, -910.140015, -911.759949, -913.380005, -914.999939, -916.620056, -918.239990, -919.860046, -921.479919, -753.899963, -477.899902, -926.339905, -927.959961, -929.579956, -931.200012, -932.819946, -934.439880, -936.059937, -937.679932, -767.100037, -606.439941, -1050.880005, -1052.680054, -1054.479980, -1056.280029, -1058.079956, -1059.880005, -1061.679932, -1063.479980, -804.679993, -70.180008, -250.600006, -251.019958, -251.440033, -251.860001, -252.280029, -252.700043, -253.120026, -253.540039, -267.459991, -539.730042, -1029.929932, -1031.640137, -1033.350098, -1035.060059, -1036.770020, -1038.479980, -1040.190063, -1041.900024, -843.809998, -548.729980, -1047.030029, -1048.740112, -1050.449829, -1052.160034, -1053.870117, -1055.580078, -1057.289917, -1059.000122, -857.609985, -557.729980, - -1064.130005, -1065.840088, -1067.550049, -1069.260010, -1070.969849, -1072.679932, -1074.390137, -1076.100098, -871.410034, -566.729980, -1081.229980, -1082.940063, -1084.650024, -1086.359985, -1088.069946, -1089.780029, -1091.489990, -1093.199951, -885.210022, -575.729980, -1098.329956, -1100.040039, -1101.750122, -1103.460205, -1105.170166, -1106.879883, -1108.589966, -1110.300049, -899.010071, -584.730042, -1115.429932, -1117.140137, -1118.850098, -1120.560059, -1122.270020, -1123.979980, -1125.689941, -1127.400024, -912.810059, -593.730042, -1132.530029, -1134.240234, -1135.949951, -1137.659912, -1139.370117, -1141.079956, -1142.790039, -1144.500122, -926.610046, -602.730042, -1149.629883, -1151.339966, -1153.050049, -1154.760132, -1156.469971, -1158.179810, -1159.890137, -1161.600098, -940.410034, -737.859985, -1272.040039, -1273.899902, -1275.760010, -1277.619995, -1279.479980, -1281.340088, -1283.200195, -1285.060059, -968.420044}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {kH, kW, iC, mC}, {-2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000, - -2880.149902, -2790.150146, -20700.152344, -56610.148438, -110520.156250, -182430.156250, -272340.156250, -380250.125000, -2594.701416, -2513.699951, - -18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, -342227.750000, -3043.501465, -2953.500244, -20863.500000, -56773.492188, - -110683.515625, -182593.515625, -272503.531250, -380413.562500, -3383.499756, -3283.500000, -23183.501953, -63083.500000, -122983.500000, -202883.515625, - -302783.531250, -422683.468750, -3043.501465, -2953.500244, -20863.500000, -56773.492188, -110683.515625, -182593.515625, -272503.531250, -380413.562500, - -2594.701416, -2513.699951, -18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, -342227.750000, -2880.149902, -2790.150146, -20700.152344, -56610.148438, -110520.156250, -182430.156250, -272340.156250, -380250.125000, -2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000}, sd::DataType::FLOAT32); - - NDArray expGradB('c', {oC}, {505., -495., -1495., -2495., -3495., -4494.999512, -5495., -6495.}, sd::DataType::FLOAT32); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - auto gradB = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - - ASSERT_TRUE(expGradB.isSameShape(gradB)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); + int bS = 1, iH = 10, iW = 10, iC = 8, mC = 1, kH = 3, kW = 3, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oC = iC * mC; + int oH = 10, oW = 10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + gradO.linspace(10, -0.1); + + NDArray expGradI( + 'c', {bS, iC, iH, iW}, + {-12.639999, 3.920004, 3.920000, 3.920000, 3.920002, + 3.920000, 3.920000, 3.919998, 3.919998, 16.319998, + 52.680004, 111.000015, 109.919991, 108.840004, 107.760002, + 106.680008, 105.600006, 104.519997, 103.440018, 87.960007, + 47.880001, 100.200005, 99.119995, 98.040001, 96.959999, + 95.879990, 94.799995, 93.720001, 92.639999, 78.360001, + 43.079998, 89.399994, 88.320007, 87.240005, 86.159996, + 85.079994, 84.000000, 82.919998, 81.840004, 68.759995, + 38.279999, 78.600006, 77.519997, 76.440010, 75.360001, + 74.279999, 73.200005, 72.120003, 71.040001, 59.160004, + 33.480000, 67.799995, 66.720009, 65.639999, 64.559998, + 63.480000, 62.399994, 61.320007, 60.240002, 49.559998, + 28.680004, 57.000004, 55.919998, 54.839993, 53.759998, + 52.680000, 51.600002, 50.519997, 49.440002, 39.959999, + 23.880001, 46.200001, 45.120003, 44.039997, 42.959999, + 41.880001, 40.799999, 39.719994, 38.639999, 30.360001, + 19.079998, 35.400002, 34.320000, 33.239998, 32.159996, + 31.080000, 29.999998, 28.919998, 27.840000, 20.759998, + 14.079999, 24.080000, 22.639997, 21.200001, 19.759998, + 18.320002, 16.880001, 15.440001, 14.000000, 9.759999, + 3.140000, 3.560000, 3.500000, 3.440000, 3.380000, + 3.320000, 3.260000, 3.200000, 3.140000, -0.220000, + 4.050000, 2.010000, 0.840000, -0.330000, -1.499999, + -2.670000, -3.840000, -5.010000, -6.179998, -9.150000, + -1.350000, -9.690001, -10.859999, -12.029998, -13.200001, + -14.370001, -15.539999, -16.710001, -17.879999, -19.349998, + -6.750000, -21.389997, -22.560003, -23.730003, -24.900002, + -26.069998, -27.239998, -28.410007, -29.580002, -29.550003, + -12.150001, -33.089996, -34.260002, -35.430000, -36.600002, + -37.770000, -38.939995, -40.110001, -41.280003, -39.749996, + -17.550003, -44.790005, -45.959991, -47.129993, -48.300003, + -49.470001, -50.640003, -51.809990, -52.979996, -49.950001, + -22.949999, -56.490005, -57.660000, -58.829998, -60.000000, + -61.170002, -62.340004, -63.510002, -64.680000, -60.149994, + -28.349998, -68.189987, -69.360001, -70.529999, -71.700005, + -72.870010, -74.039993, -75.209999, -76.379990, -70.349998, + -33.749996, -79.889999, -81.059990, -82.229988, -83.399994, + -84.570007, -85.740005, -86.910004, -88.079994, -80.549995, + -69.340004, -125.080002, -126.580002, -128.080002, -129.580002, + -131.080002, -132.580002, -134.080002, -135.580002, -105.979996, + 10.919998, -8.799997, -8.919998, -9.040003, -9.160004, + -9.279999, -9.400002, -9.520002, -9.640003, -24.760000, + -56.580009, -124.980003, -126.240005, -127.499992, -128.759995, + -130.020020, -131.279999, -132.540009, -133.800003, -118.260002, + -62.580009, -137.580002, -138.840012, -140.099991, -141.360001, + -142.620010, -143.879974, -145.139999, -146.399994, -129.060013, + -68.580002, -150.179993, -151.439987, -152.699997, -153.959991, + -155.219986, -156.480011, -157.740005, -159.000000, -139.860001, + -74.579994, -162.779999, -164.040024, -165.300003, -166.560028, + -167.819977, -169.080002, -170.339996, -171.599991, -150.660004, + -80.580002, -175.379990, -176.639999, -177.899994, -179.160019, + -180.419998, -181.679993, -182.940002, -184.199997, -161.459991, + -86.580002, -187.979996, -189.240005, -190.499985, -191.759995, + -193.020020, -194.279999, -195.540024, -196.800018, -172.260010, + -92.580002, -200.579987, -201.839981, -203.100006, -204.359970, + -205.620010, -206.880005, -208.139999, -209.399994, -183.060013, + -98.580002, -213.180023, -214.440002, -215.700012, -216.959991, + -218.220001, -219.480011, -220.739975, -222.000000, -193.860001, + -160.760010, -286.239990, -287.799988, -289.360016, -290.920013, + -292.480011, -294.040009, -295.599976, -297.160004, -229.719986, + 10.700003, -33.160004, -33.339996, -33.519993, -33.700001, + -33.879997, -34.059994, -34.239994, -34.419994, -57.299995, + -129.209991, -269.969971, -271.319977, -272.670044, -274.019989, + -275.369995, -276.720001, -278.070007, -279.420013, -239.369980, + -135.809998, -283.470001, -284.820007, -286.169983, -287.520020, + -288.869995, -290.220001, -291.570038, -292.919983, -250.770004, + -142.410004, -296.969971, -298.320007, -299.669983, -301.020020, + -302.369995, -303.719971, -305.070007, -306.419983, -262.169983, + -149.009995, -310.470001, -311.820007, -313.170013, -314.519989, + -315.869995, -317.220001, -318.570007, -319.919983, -273.570007, + -155.610016, -323.969971, -325.320038, -326.669983, -328.020020, + -329.369965, -330.719971, -332.070007, -333.419983, -284.970001, + -162.209991, -337.469971, -338.820007, -340.169983, -341.519958, + -342.869995, -344.220001, -345.570007, -346.920013, -296.369995, + -168.809998, -350.970001, -352.320007, -353.669983, -355.019989, + -356.369995, -357.719971, -359.070038, -360.419983, -307.769989, + -175.410004, -364.469971, -365.820007, -367.169983, -368.520020, + -369.869995, -371.219971, -372.570007, -373.919983, -319.169983, + -260.179993, -459.399994, -461.019958, -462.639984, -464.260010, + -465.880005, -467.500000, -469.119995, -470.739990, -361.459991, + 2.480003, -69.520004, -69.760025, -70.000000, -70.239990, + -70.479996, -70.720001, -70.960007, -71.200005, -97.839996, + -213.840012, -432.960022, -434.400055, -435.840027, -437.279999, + -438.720001, -440.160065, -441.599976, -443.040039, -372.480011, + -221.040009, -447.360016, -448.800018, -450.239990, -451.679993, + -453.119995, -454.559967, -456.000061, -457.440033, -384.480011, + -228.239990, -461.759979, -463.200012, -464.639984, -466.079956, + -467.520081, -468.960052, -470.399963, -471.839996, -396.479980, + -235.440002, -476.159912, -477.600006, -479.040039, -480.479980, + -481.919952, -483.360046, -484.800079, -486.239990, -408.480042, + -242.639999, -490.559967, -491.999969, -493.440063, -494.880035, + -496.319946, -497.759979, -499.200012, -500.639984, -420.480011, + -249.840012, -504.960052, -506.399963, -507.839996, -509.280029, + -510.720001, -512.159973, -513.599976, -515.040039, -432.480011, + -257.040009, -519.360046, -520.800049, -522.239990, -523.680054, + -525.120056, -526.559998, -527.999939, -529.440002, -444.480011, + -264.239990, -533.760010, -535.200012, -536.640015, -538.079956, + -539.520020, -540.960022, -542.399963, -543.839966, -456.479980, + -367.599976, -644.559998, -646.239929, -647.920044, -649.599976, + -651.280029, -652.960022, -654.640076, -656.320007, -501.200043, + -13.740002, -117.880005, -118.179993, -118.479996, -118.780014, + -119.080002, -119.379990, -119.680008, -119.979996, -146.379990, + -310.470001, -613.950012, -615.479980, -617.010071, -618.539978, + -620.069946, -621.599976, -623.130005, -624.660034, -517.589966, + -318.269958, -629.250000, -630.779968, -632.309937, -633.840027, + -635.369995, -636.899902, -638.429993, -639.959961, -530.190063, + -326.070038, -644.550049, -646.079956, -647.609985, -649.140015, + -650.669922, -652.200012, -653.729980, -655.260010, -542.789978, + -333.870026, -659.849976, -661.380005, -662.910034, -664.439941, + -665.970093, -667.500000, -669.029968, -670.559937, -555.390015, + -341.669983, -675.149902, -676.679993, -678.209961, -679.740051, + -681.270020, -682.800049, -684.329956, -685.859985, -567.989990, + -349.470001, -690.450012, -691.979980, -693.510010, -695.039978, + -696.569946, -698.099976, -699.630005, -701.160034, -580.589966, + -357.269958, -705.750000, -707.279968, -708.809937, -710.340027, + -711.869995, -713.399902, -714.929993, -716.459961, -593.190002, + -365.070038, -721.050049, -722.579956, -724.109985, -725.640015, + -727.169922, -728.700012, -730.229980, -731.760010, -605.789978, + -483.019958, -841.719971, -843.460022, -845.200073, -846.939941, + -848.680054, -850.419983, -852.159973, -853.899963, -648.940002, + -37.960014, -178.240021, -178.599976, -178.959991, -179.320007, + -179.679993, -180.039978, -180.399994, -180.759964, -202.919983, + -419.099915, -812.939941, -814.559937, -816.179993, -817.800049, + -819.419922, -821.040039, -822.660034, -824.279968, -674.699951, + -427.500031, -829.140015, -830.759949, -832.380005, -833.999939, + -835.619995, -837.240051, -838.859924, -840.479980, -687.899963, + -435.899994, -845.339966, -846.959961, -848.579956, -850.200012, + -851.819885, -853.439941, -855.059937, -856.679993, -701.100037, + -444.299927, -861.540039, -863.160034, -864.779968, -866.399963, + -868.020020, -869.640015, -871.259949, -872.880005, -714.299988, + -452.700012, -877.740051, -879.359924, -880.979980, -882.599915, + -884.219971, -885.839966, -887.459961, -889.079956, -727.500000, + -461.099915, -893.939941, -895.559937, -897.179993, -898.800049, + -900.419922, -902.040039, -903.660034, -905.279968, -740.700012, + -469.499969, -910.140015, -911.759949, -913.380005, -914.999939, + -916.620056, -918.239990, -919.860046, -921.479919, -753.899963, + -477.899902, -926.339905, -927.959961, -929.579956, -931.200012, + -932.819946, -934.439880, -936.059937, -937.679932, -767.100037, + -606.439941, -1050.880005, -1052.680054, -1054.479980, -1056.280029, + -1058.079956, -1059.880005, -1061.679932, -1063.479980, -804.679993, + -70.180008, -250.600006, -251.019958, -251.440033, -251.860001, + -252.280029, -252.700043, -253.120026, -253.540039, -267.459991, + -539.730042, -1029.929932, -1031.640137, -1033.350098, -1035.060059, + -1036.770020, -1038.479980, -1040.190063, -1041.900024, -843.809998, + -548.729980, -1047.030029, -1048.740112, -1050.449829, -1052.160034, + -1053.870117, -1055.580078, -1057.289917, -1059.000122, -857.609985, + -557.729980, -1064.130005, -1065.840088, -1067.550049, -1069.260010, + -1070.969849, -1072.679932, -1074.390137, -1076.100098, -871.410034, + -566.729980, -1081.229980, -1082.940063, -1084.650024, -1086.359985, + -1088.069946, -1089.780029, -1091.489990, -1093.199951, -885.210022, + -575.729980, -1098.329956, -1100.040039, -1101.750122, -1103.460205, + -1105.170166, -1106.879883, -1108.589966, -1110.300049, -899.010071, + -584.730042, -1115.429932, -1117.140137, -1118.850098, -1120.560059, + -1122.270020, -1123.979980, -1125.689941, -1127.400024, -912.810059, + -593.730042, -1132.530029, -1134.240234, -1135.949951, -1137.659912, + -1139.370117, -1141.079956, -1142.790039, -1144.500122, -926.610046, + -602.730042, -1149.629883, -1151.339966, -1153.050049, -1154.760132, + -1156.469971, -1158.179810, -1159.890137, -1161.600098, -940.410034, + -737.859985, -1272.040039, -1273.899902, -1275.760010, -1277.619995, + -1279.479980, -1281.340088, -1283.200195, -1285.060059, -968.420044}, + sd::DataType::FLOAT32); + + NDArray expGradW( + 'c', {kH, kW, iC, mC}, + {-2586.600586, -2505.600098, -18624.595703, -50943.605469, + -99462.601562, -164181.609375, -245100.609375, -342219.625000, + -2880.149902, -2790.150146, -20700.152344, -56610.148438, + -110520.156250, -182430.156250, -272340.156250, -380250.125000, + -2594.701416, -2513.699951, -18632.699219, -50951.695312, + -99470.695312, -164189.703125, -245108.687500, -342227.750000, + -3043.501465, -2953.500244, -20863.500000, -56773.492188, + -110683.515625, -182593.515625, -272503.531250, -380413.562500, + -3383.499756, -3283.500000, -23183.501953, -63083.500000, + -122983.500000, -202883.515625, -302783.531250, -422683.468750, + -3043.501465, -2953.500244, -20863.500000, -56773.492188, + -110683.515625, -182593.515625, -272503.531250, -380413.562500, + -2594.701416, -2513.699951, -18632.699219, -50951.695312, + -99470.695312, -164189.703125, -245108.687500, -342227.750000, + -2880.149902, -2790.150146, -20700.152344, -56610.148438, + -110520.156250, -182430.156250, -272340.156250, -380250.125000, + -2586.600586, -2505.600098, -18624.595703, -50943.605469, + -99462.601562, -164181.609375, -245100.609375, -342219.625000}, + sd::DataType::FLOAT32); + + NDArray expGradB( + 'c', {oC}, + {505., -495., -1495., -2495., -3495., -4494.999512, -5495., -6495.}, + sd::DataType::FLOAT32); + + sd::ops::depthwise_conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test6) { - - int bS=2, iH=4,iW=3, iC=2,mC=1, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int oC=iC*mC; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto bias = NDArrayFactory::create('c', {oC}, {3,4}); - auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); - - auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW},{0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, 0.069, 0.044, 0.01, - 0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, 0.136, - 0.326, 0.194, 0.099, 0.229, 0.132, 0.026, 0.08, 0.056, 0.108, 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192}); - - auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, mC}, {1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68}); - - input = 2.; - weights.linspace(0.1, 0.1); - gradO.linspace(0.01, 0.01); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 1, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int oC = iC * mC; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {3, 4}); + auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); + + auto expGradI = NDArrayFactory::create( + 'c', {bS, iC, iH, iW}, + {0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, + 0.069, 0.044, 0.01, 0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, + 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, + 0.136, 0.326, 0.194, 0.099, 0.229, 0.132, 0.026, 0.08, 0.056, 0.108, + 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192}); + + auto expGradW = NDArrayFactory::create( + 'c', {kH, kW, iC, mC}, + {1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = + op.evaluate({&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test7) { - - int bS=2, iH=4,iW=3, iC=2,mC=1, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int oC=iC*mC; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - int wFormat = 1; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] - - NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - NDArray weights('c', {mC, iC, kH, kW}, {0.10, 0.30, 0.50, 0.70, 0.90, 1.10, 0.20, 0.40, 0.60, 0.80, 1., 1.2}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, {3,4}, sd::DataType::FLOAT32); - NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); - - - NDArray expGradI('c', {bS, iC, iH, iW},{0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, 0.069, 0.044, 0.01, - 0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, 0.136, - 0.326, 0.194, 0.099, 0.229, 0.132, 0.026, 0.08, 0.056, 0.108, 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192}, sd::DataType::FLOAT32); - - NDArray expGradW('c', {mC, iC, kH, kW}, {1.04, 1.04, 1.04, 1.04, 1.04, 1.04, 1.68, 1.68, 1.68, 1.68, 1.68, 1.68}, sd::DataType::FLOAT32); - - input = 2.; - gradO.linspace(0.01, 0.01); - - sd::ops::depthwise_conv2d_bp op; - auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); - auto gradI = results.at(0); - auto gradW = results.at(1); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expGradI.isSameShape(gradI)); - ASSERT_TRUE(expGradI.equalsTo(gradI)); - - ASSERT_TRUE(expGradW.isSameShape(gradW)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); + int bS = 2, iH = 4, iW = 3, iC = 2, mC = 1, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int oC = iC * mC; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = + 1; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights( + 'c', {mC, iC, kH, kW}, + {0.10, 0.30, 0.50, 0.70, 0.90, 1.10, 0.20, 0.40, 0.60, 0.80, 1., 1.2}, + sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {3, 4}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + + NDArray expGradI( + 'c', {bS, iC, iH, iW}, + {0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, + 0.069, 0.044, 0.01, 0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, + 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, + 0.136, 0.326, 0.194, 0.099, 0.229, 0.132, 0.026, 0.08, 0.056, 0.108, + 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192}, + sd::DataType::FLOAT32); + + NDArray expGradW( + 'c', {mC, iC, kH, kW}, + {1.04, 1.04, 1.04, 1.04, 1.04, 1.04, 1.68, 1.68, 1.68, 1.68, 1.68, 1.68}, + sd::DataType::FLOAT32); + + input = 2.; + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = op.evaluate( + {&input, &weights, &bias, &gradO}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); } -#endif //LIBND4J_CONVOLUTIONTESTS2_H \ No newline at end of file +#endif // LIBND4J_CONVOLUTIONTESTS2_H \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu b/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu index 47892c973f2b..25729de21cdb 100644 --- a/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu +++ b/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu @@ -14,17 +14,18 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include -#include #include +#include +#include + +#include + +#include "testlayers.h" #ifdef HAVE_CUDNN @@ -35,114 +36,115 @@ using namespace sd; class CuDnnTests : public testing::Test { -public: - + public: }; -static void printer(std::initializer_list helpers) { - - for (auto v:helpers) { - nd4j_printf("Initialized [%s]\n", v->name().c_str()); - } +static void printer( + std::initializer_list helpers) { + for (auto v : helpers) { + nd4j_printf("Initialized [%s]\n", v->name().c_str()); + } } - TEST_F(CuDnnTests, helpers_includer) { - // we need this block, to make sure all helpers are still available within binary, and not optimized out by linker + // we need this block, to make sure all helpers are still available within + // binary, and not optimized out by linker #ifdef HAVE_CUDNN - sd::ops::platforms::PLATFORM_conv2d_ENGINE_CUDA conv2d; - sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CUDA conv2d_bp; - sd::ops::platforms::PLATFORM_conv3dnew_ENGINE_CUDA conv3dnew; - sd::ops::platforms::PLATFORM_conv3dnew_bp_ENGINE_CUDA conv3dnew_bp; - sd::ops::platforms::PLATFORM_depthwise_conv2d_ENGINE_CUDA depthwise_conv2d; - sd::ops::platforms::PLATFORM_depthwise_conv2d_bp_ENGINE_CUDA depthwise_conv2d_bp; - sd::ops::platforms::PLATFORM_batchnorm_ENGINE_CUDA batchnorm; - sd::ops::platforms::PLATFORM_batchnorm_bp_ENGINE_CUDA batchnorm_bp; - sd::ops::platforms::PLATFORM_avgpool2d_ENGINE_CUDA avgpool2d; - sd::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CUDA avgpool2d_bp; - sd::ops::platforms::PLATFORM_maxpool2d_ENGINE_CUDA maxpool2d; - sd::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CUDA maxpool2d_bp; - sd::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CUDA avgpool3dnew; - sd::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CUDA avgpool3dnew_bp; - sd::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CUDA maxpool3dnew; - sd::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CUDA maxpool3dnew_bp; - - - - printer({&conv2d}); - printer({&conv2d_bp}); - printer({&conv3dnew}); - printer({&conv3dnew_bp}); - printer({&depthwise_conv2d}); - printer({&depthwise_conv2d_bp}); - printer({&batchnorm}); - printer({&batchnorm_bp}); - printer({&avgpool2d}); - printer({&avgpool2d_bp}); - printer({&maxpool2d}); - printer({&maxpool2d_bp}); - printer({&avgpool3dnew}); - printer({&avgpool3dnew_bp}); - printer({&maxpool3dnew}); - printer({&maxpool3dnew_bp}); + sd::ops::platforms::PLATFORM_conv2d_ENGINE_CUDA conv2d; + sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CUDA conv2d_bp; + sd::ops::platforms::PLATFORM_conv3dnew_ENGINE_CUDA conv3dnew; + sd::ops::platforms::PLATFORM_conv3dnew_bp_ENGINE_CUDA conv3dnew_bp; + sd::ops::platforms::PLATFORM_depthwise_conv2d_ENGINE_CUDA depthwise_conv2d; + sd::ops::platforms::PLATFORM_depthwise_conv2d_bp_ENGINE_CUDA + depthwise_conv2d_bp; + sd::ops::platforms::PLATFORM_batchnorm_ENGINE_CUDA batchnorm; + sd::ops::platforms::PLATFORM_batchnorm_bp_ENGINE_CUDA batchnorm_bp; + sd::ops::platforms::PLATFORM_avgpool2d_ENGINE_CUDA avgpool2d; + sd::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CUDA avgpool2d_bp; + sd::ops::platforms::PLATFORM_maxpool2d_ENGINE_CUDA maxpool2d; + sd::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CUDA maxpool2d_bp; + sd::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CUDA avgpool3dnew; + sd::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CUDA avgpool3dnew_bp; + sd::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CUDA maxpool3dnew; + sd::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CUDA maxpool3dnew_bp; + + printer({&conv2d}); + printer({&conv2d_bp}); + printer({&conv3dnew}); + printer({&conv3dnew_bp}); + printer({&depthwise_conv2d}); + printer({&depthwise_conv2d_bp}); + printer({&batchnorm}); + printer({&batchnorm_bp}); + printer({&avgpool2d}); + printer({&avgpool2d_bp}); + printer({&maxpool2d}); + printer({&maxpool2d_bp}); + printer({&avgpool3dnew}); + printer({&avgpool3dnew_bp}); + printer({&maxpool3dnew}); + printer({&maxpool3dnew_bp}); #endif } - TEST_F(CuDnnTests, mixed_helpers_test_1) { -#if defined(HAVE_CUDNN) && defined (HAVE_MKLDNN) - nd4j_printf("Mixed platforms test\n", ""); - - - int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 0; // 1-SAME, 0-VALID; - int dataFormat = 0; // 1-NHWC, 0-NCHW - - auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); - auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); - auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); - - auto expOutput = NDArrayFactory::create('c', {bS, oC, oH, oW}, {61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f, 61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f}); - auto zCUDA = expOutput.like(); - auto zMKL = expOutput.like(); - - input = 2.; - weights.linspace(0.1, 0.1); - weights.permutei({2,3,1,0}); - - input.syncToHost(); - weights.syncToHost(); - bias.syncToHost(); - - sd::ops::conv2d op; - - // cuDNN part - Context cuda(1); - cuda.setTargetEngine(samediff::Engine::ENGINE_CUDA); - cuda.setInputArray(0, &input); - cuda.setInputArray(1, &weights); - cuda.setInputArray(2, &bias); - cuda.setOutputArray(0, &zCUDA); - cuda.setIArguments({kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto statusCUDA = op.execute(&cuda); - - ASSERT_EQ(Status::OK(), statusCUDA); - ASSERT_EQ(expOutput, zCUDA); - - // MKL-DNN part - Context mkl(1); - mkl.setTargetEngine(samediff::Engine::ENGINE_CPU); - mkl.setInputArray(0, &input); - mkl.setInputArray(1, &weights); - mkl.setInputArray(2, &bias); - mkl.setOutputArray(0, &zMKL); - mkl.setIArguments({kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto statusMKL = op.execute(&mkl); - - zMKL.tickWriteHost(); - - ASSERT_EQ(Status::OK(), statusMKL); - ASSERT_EQ(expOutput, zMKL); +#if defined(HAVE_CUDNN) && defined(HAVE_MKLDNN) + nd4j_printf("Mixed platforms test\n", ""); + + int bS = 2, iH = 4, iW = 3, iC = 4, oC = 3, kH = 3, kW = 2, sH = 1, sW = 1, + pH = 0, pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC}, {1, 2, 3}); + + auto expOutput = NDArrayFactory::create( + 'c', {bS, oC, oH, oW}, + {61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, + 293.4f, 293.4f, 293.4f, 293.4f, 61.f, 61.f, 61.f, 61.f, + 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f}); + auto zCUDA = expOutput.like(); + auto zMKL = expOutput.like(); + + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 1, 0}); + + input.syncToHost(); + weights.syncToHost(); + bias.syncToHost(); + + sd::ops::conv2d op; + + // cuDNN part + Context cuda(1); + cuda.setTargetEngine(samediff::Engine::ENGINE_CUDA); + cuda.setInputArray(0, &input); + cuda.setInputArray(1, &weights); + cuda.setInputArray(2, &bias); + cuda.setOutputArray(0, &zCUDA); + cuda.setIArguments({kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto statusCUDA = op.execute(&cuda); + + ASSERT_EQ(Status::OK(), statusCUDA); + ASSERT_EQ(expOutput, zCUDA); + + // MKL-DNN part + Context mkl(1); + mkl.setTargetEngine(samediff::Engine::ENGINE_CPU); + mkl.setInputArray(0, &input); + mkl.setInputArray(1, &weights); + mkl.setInputArray(2, &bias); + mkl.setOutputArray(0, &zMKL); + mkl.setIArguments({kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat}); + auto statusMKL = op.execute(&mkl); + + zMKL.tickWriteHost(); + + ASSERT_EQ(Status::OK(), statusMKL); + ASSERT_EQ(expOutput, zMKL); #endif } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu index 97243552372c..206345ed966c 100644 --- a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu +++ b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu @@ -14,3087 +14,3584 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author raver119@gmail.com - // +// +// @author raver119@gmail.com +// -#include "testlayers.h" +#include #include #include +#include +#include #include #include #include #include -#include -#include +#include +#include #include #include -#include #include -#include -#include -#include -#include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class CudaBasicsTests1 : public testing::Test { -public: - + public: }; - ////////////////////////////////////////////////////////////////////////// -static cudaError_t allocateDeviceMem(LaunchContext& lc, std::vector& devicePtrs, const std::vector>& hostData) { - - if(devicePtrs.size() != hostData.size()) - throw std::invalid_argument("prepareDataForCuda: two input sts::vectors should same sizes !"); - - cudaError_t cudaResult; - - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); if(cudaResult != 0) return cudaResult; - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); if(cudaResult != 0) return cudaResult; - - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - cudaStream_t stream = *lc.getCudaStream(); - - for(int i = 0; i < devicePtrs.size(); ++i) { - - cudaResult = cudaMalloc(reinterpret_cast(&devicePtrs[i]), hostData[i].second); if(cudaResult != 0) return cudaResult; - cudaMemcpyAsync(devicePtrs[i], hostData[i].first, hostData[i].second, cudaMemcpyHostToDevice, stream); - } - return cudaResult; +static cudaError_t allocateDeviceMem( + LaunchContext &lc, std::vector &devicePtrs, + const std::vector> &hostData) { + if (devicePtrs.size() != hostData.size()) + throw std::invalid_argument( + "prepareDataForCuda: two input sts::vectors should same sizes !"); + + cudaError_t cudaResult; + + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + if (cudaResult != 0) return cudaResult; + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + if (cudaResult != 0) return cudaResult; + + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + cudaStream_t stream = *lc.getCudaStream(); + + for (int i = 0; i < devicePtrs.size(); ++i) { + cudaResult = cudaMalloc(reinterpret_cast(&devicePtrs[i]), + hostData[i].second); + if (cudaResult != 0) return cudaResult; + cudaMemcpyAsync(devicePtrs[i], hostData[i].first, hostData[i].second, + cudaMemcpyHostToDevice, stream); + } + return cudaResult; } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, TestPairwise_1) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto z = NDArrayFactory::create('c', { 5 }, {0,0,0,0,0}); - - auto exp = NDArrayFactory::create('c', { 5 }, { 2, 4, 6, 8, 10 }); - - // making raw buffers - Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - ASSERT_EQ(0, res); - res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - ASSERT_EQ(0, res); - res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - ASSERT_EQ(0, res); - - Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); - CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", sizeof(cudaStream_t)); - cudaError_t dZ = cudaStreamCreate(reinterpret_cast(&nativeStream)); - auto stream = reinterpret_cast(&nativeStream); - - x.dataBuffer()->allocatePrimary(); - x.syncToHost(); - - cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), cudaMemcpyHostToDevice, *stream); - cudaMemcpyAsync(devShapePtrX, x.shapeInfo(), shape::shapeInfoByteLength(x.shapeInfo()), cudaMemcpyHostToDevice, *stream); - res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); - - LaunchContext lc(stream, nullptr, nullptr); - NativeOpExecutioner::execPairwiseTransform(&lc, pairwise::Add, nullptr, x.shapeInfo(), devBufferPtrX, reinterpret_cast(devShapePtrX), nullptr, x.shapeInfo(), devBufferPtrX, reinterpret_cast(devShapePtrX), nullptr, z.shapeInfo(), devBufferPtrZ, reinterpret_cast(devShapePtrX), nullptr); - res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); - - z.dataBuffer()->allocatePrimary(); - - cudaMemcpyAsync(z.buffer(), devBufferPtrZ, z.lengthOf() * x.sizeOfT(), cudaMemcpyDeviceToHost, *stream); - res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); - - cudaFree(devBufferPtrX); - cudaFree(devBufferPtrZ); - cudaFree(devShapePtrX); - - // needed due to memcpy - z.tickWriteHost(); - - for (int e = 0; e < z.lengthOf(); e++) { - nd4j_printf("step %i\n", e); - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto z = NDArrayFactory::create('c', {5}, {0, 0, 0, 0, 0}); + + auto exp = NDArrayFactory::create('c', {5}, {2, 4, 6, 8, 10}); + + // making raw buffers + Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + x.lengthOf() * x.sizeOfT()); + ASSERT_EQ(0, res); + res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), + x.lengthOf() * x.sizeOfT()); + ASSERT_EQ(0, res); + res = cudaMalloc(reinterpret_cast(&devShapePtrX), + shape::shapeInfoByteLength(x.shapeInfo())); + ASSERT_EQ(0, res); + + Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); + CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", + sizeof(cudaStream_t)); + cudaError_t dZ = + cudaStreamCreate(reinterpret_cast(&nativeStream)); + auto stream = reinterpret_cast(&nativeStream); + + x.dataBuffer()->allocatePrimary(); + x.syncToHost(); + + cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), + cudaMemcpyHostToDevice, *stream); + cudaMemcpyAsync(devShapePtrX, x.shapeInfo(), + shape::shapeInfoByteLength(x.shapeInfo()), + cudaMemcpyHostToDevice, *stream); + res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + + LaunchContext lc(stream, nullptr, nullptr); + NativeOpExecutioner::execPairwiseTransform( + &lc, pairwise::Add, nullptr, x.shapeInfo(), devBufferPtrX, + reinterpret_cast(devShapePtrX), nullptr, x.shapeInfo(), + devBufferPtrX, reinterpret_cast(devShapePtrX), nullptr, + z.shapeInfo(), devBufferPtrZ, reinterpret_cast(devShapePtrX), + nullptr); + res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + + z.dataBuffer()->allocatePrimary(); + + cudaMemcpyAsync(z.buffer(), devBufferPtrZ, z.lengthOf() * x.sizeOfT(), + cudaMemcpyDeviceToHost, *stream); + res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + + cudaFree(devBufferPtrX); + cudaFree(devBufferPtrZ); + cudaFree(devShapePtrX); + + // needed due to memcpy + z.tickWriteHost(); + + for (int e = 0; e < z.lengthOf(); e++) { + nd4j_printf("step %i\n", e); + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } } - //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execIndexReduceScalar_1) { + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray x2('c', {2, 2}, {0.5, 1.5, -4.5, 3.5}, sd::DataType::BFLOAT16); + NDArray x3('c', {2, 2}, {0, -1, 0, 1}, sd::DataType::BOOL); + + NDArray scalar('c', {}, std::vector{0}, sd::DataType::INT64); + + NDArray exp1('c', {}, std::vector{3}, sd::DataType::INT64); + NDArray exp2('c', {}, std::vector{2}, sd::DataType::INT64); + NDArray exp3('c', {}, std::vector{1}, sd::DataType::INT64); + + void *dX1, *dX2, *dX3, *dZ; + Nd4jLong *dX1ShapeInfo, *dX2ShapeInfo, *dX3ShapeInfo, *dZShapeInfo; + + cudaError_t cudaResult; + + cudaResult = + cudaMalloc(reinterpret_cast(&dX1), x1.lengthOf() * x1.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = + cudaMalloc(reinterpret_cast(&dX2), x2.lengthOf() * x2.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = + cudaMalloc(reinterpret_cast(&dX3), x3.lengthOf() * x3.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZ), + scalar.lengthOf() * scalar.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX1ShapeInfo), + shape::shapeInfoByteLength(x1.shapeInfo())); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX2ShapeInfo), + shape::shapeInfoByteLength(x2.shapeInfo())); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX3ShapeInfo), + shape::shapeInfoByteLength(x3.shapeInfo())); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZShapeInfo), + shape::shapeInfoByteLength(scalar.shapeInfo())); + ASSERT_EQ(0, cudaResult); + + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + + x1.syncToHost(); + x2.syncToHost(); + x3.syncToHost(); + scalar.syncToHost(); + + cudaMemcpyAsync(dX1, x1.buffer(), x1.lengthOf() * x1.sizeOfT(), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX2, x2.buffer(), x2.lengthOf() * x2.sizeOfT(), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX3, x3.buffer(), x3.lengthOf() * x3.sizeOfT(), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX1ShapeInfo, x1.shapeInfo(), + shape::shapeInfoByteLength(x1.shapeInfo()), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX2ShapeInfo, x2.shapeInfo(), + shape::shapeInfoByteLength(x2.shapeInfo()), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX3ShapeInfo, x3.shapeInfo(), + shape::shapeInfoByteLength(x3.shapeInfo()), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dZShapeInfo, scalar.shapeInfo(), + shape::shapeInfoByteLength(scalar.shapeInfo()), + cudaMemcpyHostToDevice, stream); + + void *reductionPointer = nullptr; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMemset(reductionPointer, 0, 1024 * 1024); + ASSERT_EQ(0, cudaResult); + + LaunchContext lc(&stream, + LaunchContext::defaultContext()->getReductionPointer(), + LaunchContext::defaultContext()->getScalarPointer(), + LaunchContext::defaultContext()->getAllocationPointer()); + + /***************************************/ + + NativeOpExecutioner::execIndexReduceScalar( + &lc, sd::indexreduce::IndexAbsoluteMax, x1.buffer(), x1.shapeInfo(), dX1, + dX1ShapeInfo, nullptr, scalar.buffer(), scalar.shapeInfo(), dZ, + dZShapeInfo); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + cudaMemcpyAsync(scalar.buffer(), dZ, scalar.lengthOf() * scalar.sizeOfT(), + cudaMemcpyDeviceToHost, stream); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + scalar.tickWriteHost(); + + ASSERT_NEAR(exp1.e(0), scalar.e(0), 1e-5); + + /***************************************/ + + NativeOpExecutioner::execIndexReduceScalar( + &lc, sd::indexreduce::IndexAbsoluteMax, nullptr, x2.shapeInfo(), dX2, + dX2ShapeInfo, nullptr, nullptr, scalar.shapeInfo(), dZ, dZShapeInfo); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + cudaMemcpyAsync(scalar.buffer(), dZ, scalar.lengthOf() * scalar.sizeOfT(), + cudaMemcpyDeviceToHost, stream); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + ASSERT_NEAR(exp2.e(0), scalar.e(0), 1e-5); + + // ************************************* + + NativeOpExecutioner::execIndexReduceScalar( + &lc, sd::indexreduce::IndexAbsoluteMax, nullptr, x3.shapeInfo(), dX3, + dX3ShapeInfo, nullptr, nullptr, scalar.shapeInfo(), dZ, dZShapeInfo); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + cudaMemcpyAsync(scalar.buffer(), dZ, scalar.lengthOf() * scalar.sizeOfT(), + cudaMemcpyDeviceToHost, stream); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + ASSERT_NEAR(exp3.e(0), scalar.e(0), 1e-5); + + /***************************************/ - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); - NDArray x2('c', {2,2}, {0.5, 1.5, -4.5, 3.5}, sd::DataType::BFLOAT16); - NDArray x3('c', {2,2}, {0, -1, 0, 1}, sd::DataType::BOOL); - - NDArray scalar('c', {}, std::vector{0}, sd::DataType::INT64); - - NDArray exp1('c', {}, std::vector{3}, sd::DataType::INT64); - NDArray exp2('c', {}, std::vector{2}, sd::DataType::INT64); - NDArray exp3('c', {}, std::vector{1}, sd::DataType::INT64); - - void *dX1, *dX2, *dX3, *dZ; - Nd4jLong *dX1ShapeInfo, *dX2ShapeInfo, *dX3ShapeInfo, *dZShapeInfo; - - cudaError_t cudaResult; - - cudaResult = cudaMalloc(reinterpret_cast(&dX1), x1.lengthOf() * x1.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX2), x2.lengthOf() * x2.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX3), x3.lengthOf() * x3.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dZ), scalar.lengthOf() * scalar.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX1ShapeInfo), shape::shapeInfoByteLength(x1.shapeInfo())); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX2ShapeInfo), shape::shapeInfoByteLength(x2.shapeInfo())); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX3ShapeInfo), shape::shapeInfoByteLength(x3.shapeInfo())); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dZShapeInfo), shape::shapeInfoByteLength(scalar.shapeInfo())); ASSERT_EQ(0, cudaResult); - - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); - ASSERT_EQ(0, cudaResult); - - x1.syncToHost(); - x2.syncToHost(); - x3.syncToHost(); - scalar.syncToHost(); - - cudaMemcpyAsync(dX1, x1.buffer(), x1.lengthOf() * x1.sizeOfT(), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX2, x2.buffer(), x2.lengthOf() * x2.sizeOfT(), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX3, x3.buffer(), x3.lengthOf() * x3.sizeOfT(), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX1ShapeInfo, x1.shapeInfo(), shape::shapeInfoByteLength(x1.shapeInfo()), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX2ShapeInfo, x2.shapeInfo(), shape::shapeInfoByteLength(x2.shapeInfo()), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX3ShapeInfo, x3.shapeInfo(), shape::shapeInfoByteLength(x3.shapeInfo()), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dZShapeInfo, scalar.shapeInfo(), shape::shapeInfoByteLength(scalar.shapeInfo()), cudaMemcpyHostToDevice, stream); - - void* reductionPointer = nullptr; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); - ASSERT_EQ(0, cudaResult); - cudaResult = cudaMemset(reductionPointer, 0, 1024 * 1024); - ASSERT_EQ(0, cudaResult); - - LaunchContext lc(&stream, LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getScalarPointer(), LaunchContext::defaultContext()->getAllocationPointer()); - - /***************************************/ - - NativeOpExecutioner::execIndexReduceScalar(&lc, - sd::indexreduce::IndexAbsoluteMax, - x1.buffer(), x1.shapeInfo(), - dX1, dX1ShapeInfo, - nullptr, - scalar.buffer(), scalar.shapeInfo(), - dZ, dZShapeInfo); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - cudaMemcpyAsync(scalar.buffer(), dZ, scalar.lengthOf() * scalar.sizeOfT(), cudaMemcpyDeviceToHost, stream); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - scalar.tickWriteHost(); - - ASSERT_NEAR(exp1.e(0), scalar.e(0), 1e-5); - - /***************************************/ - - NativeOpExecutioner::execIndexReduceScalar(&lc, - sd::indexreduce::IndexAbsoluteMax, - nullptr, x2.shapeInfo(), - dX2, dX2ShapeInfo, - nullptr, - nullptr, scalar.shapeInfo(), - dZ, dZShapeInfo); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - cudaMemcpyAsync(scalar.buffer(), dZ, scalar.lengthOf() * scalar.sizeOfT(), cudaMemcpyDeviceToHost, stream); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - ASSERT_NEAR(exp2.e(0), scalar.e(0), 1e-5); - - // ************************************* - - NativeOpExecutioner::execIndexReduceScalar(&lc, - sd::indexreduce::IndexAbsoluteMax, - nullptr, x3.shapeInfo(), - dX3, dX3ShapeInfo, - nullptr, - nullptr, scalar.shapeInfo(), - dZ, dZShapeInfo); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - cudaMemcpyAsync(scalar.buffer(), dZ, scalar.lengthOf() * scalar.sizeOfT(), cudaMemcpyDeviceToHost, stream); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - ASSERT_NEAR(exp3.e(0), scalar.e(0), 1e-5); - - /***************************************/ - - cudaFree(dX1); cudaFree(dX2); cudaFree(dX3); cudaFree(dZ); - cudaFree(dX1ShapeInfo); cudaFree(dX2ShapeInfo); cudaFree(dX3ShapeInfo); cudaFree(dZShapeInfo); - - /***************************************/ - - cudaResult = cudaStreamDestroy(stream); - ASSERT_EQ(0, cudaResult); - + cudaFree(dX1); + cudaFree(dX2); + cudaFree(dX3); + cudaFree(dZ); + cudaFree(dX1ShapeInfo); + cudaFree(dX2ShapeInfo); + cudaFree(dX3ShapeInfo); + cudaFree(dZShapeInfo); + + /***************************************/ + + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3Scalar_1) { - - if (!Environment::getInstance()->isExperimentalBuild()) - return; - - NDArray x1('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray x2('c', {2,2}, {-1,-2,-3,-4}, sd::DataType::INT32); - NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); - - NDArray exp1('c', {}, std::vector{-30.f}, sd::DataType::FLOAT32); - NDArray exp2('c', {}, std::vector{15.}, sd::DataType::DOUBLE); - - NDArray scalar1('c', {}, std::vector{100.f}, sd::DataType::FLOAT32); - NDArray scalar2('c', {}, std::vector{100.}, sd::DataType::DOUBLE); - - void *dX1, *dX2, *dX3, *dX4, *dZ1, *dZ2; - Nd4jLong *dX1ShapeInfo, *dX3ShapeInfo, *dZ1ShapeInfo, *dZ2ShapeInfo; - - cudaError_t cudaResult; - - cudaResult = cudaMalloc(reinterpret_cast(&dX1), x1.lengthOf() * x1.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX2), x2.lengthOf() * x2.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX3), x3.lengthOf() * x3.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX4), x4.lengthOf() * x4.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dZ1), scalar1.lengthOf() * scalar1.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dZ2), scalar2.lengthOf() * scalar2.sizeOfT()); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX1ShapeInfo), shape::shapeInfoByteLength(x1.shapeInfo())); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dX3ShapeInfo), shape::shapeInfoByteLength(x3.shapeInfo())); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dZ1ShapeInfo), shape::shapeInfoByteLength(scalar1.shapeInfo())); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&dZ2ShapeInfo), shape::shapeInfoByteLength(scalar2.shapeInfo())); ASSERT_EQ(0, cudaResult); - - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); - ASSERT_EQ(0, cudaResult); - - x1.syncToHost(); - x2.syncToHost(); - x3.syncToHost(); - x4.syncToHost(); - scalar1.syncToHost(); - scalar2.syncToHost(); - - cudaMemcpyAsync(dX1, x1.buffer(), x1.lengthOf() * x1.sizeOfT(), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX2, x2.buffer(), x2.lengthOf() * x2.sizeOfT(), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX3, x3.buffer(), x3.lengthOf() * x3.sizeOfT(), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX4, x4.buffer(), x4.lengthOf() * x4.sizeOfT(), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX1ShapeInfo, x1.shapeInfo(), shape::shapeInfoByteLength(x1.shapeInfo()), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dX3ShapeInfo, x3.shapeInfo(), shape::shapeInfoByteLength(x3.shapeInfo()), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dZ1ShapeInfo, scalar1.shapeInfo(), shape::shapeInfoByteLength(scalar1.shapeInfo()), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(dZ2ShapeInfo, scalar2.shapeInfo(), shape::shapeInfoByteLength(scalar2.shapeInfo()), cudaMemcpyHostToDevice, stream); - - /***************************************/ - - void* reductionPointer = nullptr; - int* allocationPointer = nullptr; - - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - - LaunchContext lc(&stream, reductionPointer, nullptr, allocationPointer); - - /***************************************/ - - NativeOpExecutioner::execReduce3Scalar(&lc, sd::reduce3::Dot,nullptr, x1.shapeInfo(),dX1, dX1ShapeInfo, nullptr, nullptr, x2.shapeInfo(),dX2, dX1ShapeInfo,nullptr, scalar1.shapeInfo(),dZ1, dZ1ShapeInfo); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - scalar1.tickWriteHost(); - scalar2.tickWriteHost(); - - cudaMemcpyAsync(scalar1.buffer(), dZ1, scalar1.lengthOf() * scalar1.sizeOfT(), cudaMemcpyDeviceToHost, stream); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - ASSERT_NEAR(exp1.e(0), scalar1.e(0), 1e-5); - - /***************************************/ - - NativeOpExecutioner::execReduce3Scalar(&lc, sd::reduce3::Dot,nullptr, x3.shapeInfo(),dX3, dX3ShapeInfo, nullptr, nullptr, x4.shapeInfo(),dX4, dX3ShapeInfo,nullptr, scalar2.shapeInfo(),dZ2, dZ2ShapeInfo); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - cudaMemcpyAsync(scalar2.buffer(), dZ2, scalar2.lengthOf() * scalar2.sizeOfT(), cudaMemcpyDeviceToHost, stream); - - cudaResult = cudaStreamSynchronize(stream); - ASSERT_EQ(0, cudaResult); - - ASSERT_NEAR(exp2.e(0), scalar2.e(0), 1e-5); - - /***************************************/ - - cudaFree(dX1); cudaFree(dX2); cudaFree(dX3); cudaFree(dX4); cudaFree(dZ1); cudaFree(dZ2); - cudaFree(dX1ShapeInfo); cudaFree(dX3ShapeInfo); cudaFree(dZ1ShapeInfo); cudaFree(dZ2ShapeInfo); - - /***************************************/ - - cudaResult = cudaStreamDestroy(stream); - ASSERT_EQ(0, cudaResult); + if (!Environment::getInstance()->isExperimentalBuild()) return; + + NDArray x1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray x2('c', {2, 2}, {-1, -2, -3, -4}, sd::DataType::INT32); + NDArray x3('c', {2, 2}, {1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::DOUBLE); + + NDArray exp1('c', {}, std::vector{-30.f}, sd::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{15.}, sd::DataType::DOUBLE); + + NDArray scalar1('c', {}, std::vector{100.f}, sd::DataType::FLOAT32); + NDArray scalar2('c', {}, std::vector{100.}, sd::DataType::DOUBLE); + + void *dX1, *dX2, *dX3, *dX4, *dZ1, *dZ2; + Nd4jLong *dX1ShapeInfo, *dX3ShapeInfo, *dZ1ShapeInfo, *dZ2ShapeInfo; + + cudaError_t cudaResult; + + cudaResult = + cudaMalloc(reinterpret_cast(&dX1), x1.lengthOf() * x1.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = + cudaMalloc(reinterpret_cast(&dX2), x2.lengthOf() * x2.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = + cudaMalloc(reinterpret_cast(&dX3), x3.lengthOf() * x3.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = + cudaMalloc(reinterpret_cast(&dX4), x4.lengthOf() * x4.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZ1), + scalar1.lengthOf() * scalar1.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZ2), + scalar2.lengthOf() * scalar2.sizeOfT()); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX1ShapeInfo), + shape::shapeInfoByteLength(x1.shapeInfo())); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX3ShapeInfo), + shape::shapeInfoByteLength(x3.shapeInfo())); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZ1ShapeInfo), + shape::shapeInfoByteLength(scalar1.shapeInfo())); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZ2ShapeInfo), + shape::shapeInfoByteLength(scalar2.shapeInfo())); + ASSERT_EQ(0, cudaResult); + + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + + x1.syncToHost(); + x2.syncToHost(); + x3.syncToHost(); + x4.syncToHost(); + scalar1.syncToHost(); + scalar2.syncToHost(); + + cudaMemcpyAsync(dX1, x1.buffer(), x1.lengthOf() * x1.sizeOfT(), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX2, x2.buffer(), x2.lengthOf() * x2.sizeOfT(), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX3, x3.buffer(), x3.lengthOf() * x3.sizeOfT(), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX4, x4.buffer(), x4.lengthOf() * x4.sizeOfT(), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX1ShapeInfo, x1.shapeInfo(), + shape::shapeInfoByteLength(x1.shapeInfo()), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX3ShapeInfo, x3.shapeInfo(), + shape::shapeInfoByteLength(x3.shapeInfo()), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dZ1ShapeInfo, scalar1.shapeInfo(), + shape::shapeInfoByteLength(scalar1.shapeInfo()), + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dZ2ShapeInfo, scalar2.shapeInfo(), + shape::shapeInfoByteLength(scalar2.shapeInfo()), + cudaMemcpyHostToDevice, stream); + + /***************************************/ + + void *reductionPointer = nullptr; + int *allocationPointer = nullptr; + + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + + LaunchContext lc(&stream, reductionPointer, nullptr, allocationPointer); + + /***************************************/ + + NativeOpExecutioner::execReduce3Scalar( + &lc, sd::reduce3::Dot, nullptr, x1.shapeInfo(), dX1, dX1ShapeInfo, + nullptr, nullptr, x2.shapeInfo(), dX2, dX1ShapeInfo, nullptr, + scalar1.shapeInfo(), dZ1, dZ1ShapeInfo); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + scalar1.tickWriteHost(); + scalar2.tickWriteHost(); + + cudaMemcpyAsync(scalar1.buffer(), dZ1, scalar1.lengthOf() * scalar1.sizeOfT(), + cudaMemcpyDeviceToHost, stream); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + ASSERT_NEAR(exp1.e(0), scalar1.e(0), 1e-5); + + /***************************************/ + + NativeOpExecutioner::execReduce3Scalar( + &lc, sd::reduce3::Dot, nullptr, x3.shapeInfo(), dX3, dX3ShapeInfo, + nullptr, nullptr, x4.shapeInfo(), dX4, dX3ShapeInfo, nullptr, + scalar2.shapeInfo(), dZ2, dZ2ShapeInfo); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + cudaMemcpyAsync(scalar2.buffer(), dZ2, scalar2.lengthOf() * scalar2.sizeOfT(), + cudaMemcpyDeviceToHost, stream); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + ASSERT_NEAR(exp2.e(0), scalar2.e(0), 1e-5); + + /***************************************/ + + cudaFree(dX1); + cudaFree(dX2); + cudaFree(dX3); + cudaFree(dX4); + cudaFree(dZ1); + cudaFree(dZ2); + cudaFree(dX1ShapeInfo); + cudaFree(dX3ShapeInfo); + cudaFree(dZ1ShapeInfo); + cudaFree(dZ2ShapeInfo); + + /***************************************/ + + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } - //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3_1) { - - NDArray x('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray y('c', {2,2}, {-1,-2,-3,-4}, sd::DataType::INT32); - - NDArray exp('c', {}, std::vector{-30.f}, sd::DataType::FLOAT32); - NDArray z('c', {}, std::vector{100.f}, sd::DataType::FLOAT32); - - std::vector dimensions = {0, 1}; - - x.syncToHost(); - y.syncToHost(); - z.syncToHost(); - - - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - std::vector devicePtrs(hostData.size(), nullptr); - - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - nullptr, nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray y('c', {2, 2}, {-1, -2, -3, -4}, sd::DataType::INT32); + + NDArray exp('c', {}, std::vector{-30.f}, sd::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100.f}, sd::DataType::FLOAT32); + + std::vector dimensions = {0, 1}; + + x.syncToHost(); + y.syncToHost(); + z.syncToHost(); + + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + std::vector devicePtrs(hostData.size(), nullptr); + + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), nullptr, + nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } - //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3_2) { - - NDArray x('c', {2,2}, {1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray y('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); - - NDArray exp('c', {}, std::vector{15.}, sd::DataType::DOUBLE); - NDArray z('c', {}, std::vector{100.}, sd::DataType::DOUBLE); - - std::vector dimensions = {0, 1}; - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - nullptr, nullptr, nullptr, nullptr); - - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2}, {1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray y('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::DOUBLE); + + NDArray exp('c', {}, std::vector{15.}, sd::DataType::DOUBLE); + NDArray z('c', {}, std::vector{100.}, sd::DataType::DOUBLE); + + std::vector dimensions = {0, 1}; + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), nullptr, + nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3_3) { - - NDArray x('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT32); - NDArray y('c', {2,3}, {-6,-5,-4,-3,-2,-1}, sd::DataType::INT32); - - NDArray exp('c', {3}, {-18,-20,-18}, sd::DataType::FLOAT32); - NDArray z('c', {3}, {100,100,100}, sd::DataType::FLOAT32); - - std::vector dimensions = {0}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // evaluate yTad data - shape::TAD yTad; - yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); - yTad.createTadOnlyShapeInfo(); - yTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - hostData.emplace_back(yTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo));// 3 -- yTadShapeInfo - hostData.emplace_back(yTad.tadOffsets, yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT32); + NDArray y('c', {2, 3}, {-6, -5, -4, -3, -2, -1}, sd::DataType::INT32); + + NDArray exp('c', {3}, {-18, -20, -18}, sd::DataType::FLOAT32); + NDArray z('c', {3}, {100, 100, 100}, sd::DataType::FLOAT32); + + std::vector dimensions = {0}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // evaluate yTad data + shape::TAD yTad; + yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); + yTad.createTadOnlyShapeInfo(); + yTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + hostData.emplace_back( + yTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo)); // 3 -- yTadShapeInfo + hostData.emplace_back(yTad.tadOffsets, + yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + (Nd4jLong *)devicePtrs[3], (Nd4jLong *)devicePtrs[4]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3_4) { - - NDArray x('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); - NDArray y('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - - NDArray exp('c', {2}, {9,22.5}, sd::DataType::DOUBLE); - NDArray z('c', {2}, {100,100}, sd::DataType::DOUBLE); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // evaluate yTad data - shape::TAD yTad; - yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); - yTad.createTadOnlyShapeInfo(); - yTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - hostData.emplace_back(yTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo));// 3 -- yTadShapeInfo - hostData.emplace_back(yTad.tadOffsets, yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray y('c', {2, 3}, {1.5, 1.5, 1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + + NDArray exp('c', {2}, {9, 22.5}, sd::DataType::DOUBLE); + NDArray z('c', {2}, {100, 100}, sd::DataType::DOUBLE); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // evaluate yTad data + shape::TAD yTad; + yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); + yTad.createTadOnlyShapeInfo(); + yTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + hostData.emplace_back( + yTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo)); // 3 -- yTadShapeInfo + hostData.emplace_back(yTad.tadOffsets, + yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + (Nd4jLong *)devicePtrs[3], (Nd4jLong *)devicePtrs[4]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3_5) { - - NDArray x('c', {2,2,3}, {1.5,1.5,1.5,1.5,1.5,1.5,1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::FLOAT32); - NDArray y('c', {2,2,3}, {1,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::FLOAT32); - - NDArray exp('c', {2,3}, {7.5, 10.5, 13.5, 25.5, 28.5, 31.5}, sd::DataType::FLOAT32); - NDArray z('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // evaluate yTad data - shape::TAD yTad; - yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); - yTad.createTadOnlyShapeInfo(); - yTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - hostData.emplace_back(yTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo));// 3 -- yTadShapeInfo - hostData.emplace_back(yTad.tadOffsets, yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 3}, + {1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5}, + sd::DataType::FLOAT32); + NDArray y('c', {2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + sd::DataType::FLOAT32); + + NDArray exp('c', {2, 3}, {7.5, 10.5, 13.5, 25.5, 28.5, 31.5}, + sd::DataType::FLOAT32); + NDArray z('c', {2, 3}, {100, 100, 100, 100, 100, 100}, sd::DataType::FLOAT32); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // evaluate yTad data + shape::TAD yTad; + yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); + yTad.createTadOnlyShapeInfo(); + yTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + hostData.emplace_back( + yTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo)); // 3 -- yTadShapeInfo + hostData.emplace_back(yTad.tadOffsets, + yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + (Nd4jLong *)devicePtrs[3], (Nd4jLong *)devicePtrs[4]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3All_1) { - - NDArray x('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray y('c', {2,3}, {-1,1,-1,1,-1,1}, sd::DataType::INT32); - - NDArray exp('c', {2,3}, {2,-2,2,2,-2,2}, sd::DataType::FLOAT32); - NDArray z('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); - - std::vector dimensions = {0}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // evaluate yTad data - shape::TAD yTad; - yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); - yTad.createTadOnlyShapeInfo(); - yTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - hostData.emplace_back(yTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo));// 3 -- yTadShapeInfo - hostData.emplace_back(yTad.tadOffsets, yTad.numTads * sizeof(Nd4jLong)); // 4 -- yTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3All(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray y('c', {2, 3}, {-1, 1, -1, 1, -1, 1}, sd::DataType::INT32); + + NDArray exp('c', {2, 3}, {2, -2, 2, 2, -2, 2}, sd::DataType::FLOAT32); + NDArray z('c', {2, 3}, {100, 100, 100, 100, 100, 100}, sd::DataType::FLOAT32); + + std::vector dimensions = {0}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // evaluate yTad data + shape::TAD yTad; + yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); + yTad.createTadOnlyShapeInfo(); + yTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + hostData.emplace_back( + yTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo)); // 3 -- yTadShapeInfo + hostData.emplace_back(yTad.tadOffsets, + yTad.numTads * sizeof(Nd4jLong)); // 4 -- yTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3All( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + (Nd4jLong *)devicePtrs[3], (Nd4jLong *)devicePtrs[4]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3All_2) { - - NDArray x('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); - NDArray y('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - - NDArray exp('c', {2,3}, {6,6,6,9,9,9}, sd::DataType::DOUBLE); - NDArray z('c', {2,3}, {100,100,100,100,100,100,},sd::DataType::DOUBLE); - - std::vector dimensions = {0}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // evaluate yTad data - shape::TAD yTad; - yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); - yTad.createTadOnlyShapeInfo(); - yTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - hostData.emplace_back(yTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo));// 3 -- yTadShapeInfo - hostData.emplace_back(yTad.tadOffsets, yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3All(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::DOUBLE); + NDArray y('c', {2, 3}, {1.5, 1.5, 1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + + NDArray exp('c', {2, 3}, {6, 6, 6, 9, 9, 9}, sd::DataType::DOUBLE); + NDArray z('c', {2, 3}, + { + 100, + 100, + 100, + 100, + 100, + 100, + }, + sd::DataType::DOUBLE); + + std::vector dimensions = {0}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // evaluate yTad data + shape::TAD yTad; + yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); + yTad.createTadOnlyShapeInfo(); + yTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + hostData.emplace_back( + yTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo)); // 3 -- yTadShapeInfo + hostData.emplace_back(yTad.tadOffsets, + yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3All( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + (Nd4jLong *)devicePtrs[3], (Nd4jLong *)devicePtrs[4]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execIndexReduce_1) { - - NDArray x('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::DOUBLE); - x.linspace(-2.); x.syncToDevice(); - NDArray exp('c', {2}, {2, 2}, sd::DataType::INT64); - NDArray z('c', {2}, {100,100}, sd::DataType::INT64); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execIndexReduce(&lc, sd::indexreduce::IndexMax, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {100, 100, 100, 100, 100, 100}, sd::DataType::DOUBLE); + x.linspace(-2.); + x.syncToDevice(); + NDArray exp('c', {2}, {2, 2}, sd::DataType::INT64); + NDArray z('c', {2}, {100, 100}, sd::DataType::INT64); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execIndexReduce( + &lc, sd::indexreduce::IndexMax, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execIndexReduce_2) { - - NDArray x('c', {2,3,4,5}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::FLOAT32); - x.linspace(-2.f); x.syncToDevice(); - NDArray exp('c', {2,5}, {11,11,11,11,11,11,11,11,11,11}, sd::DataType::INT64); - NDArray z('c', {2,5}, {100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT64); - - std::vector dimensions = {1,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execIndexReduce(&lc, sd::indexreduce::IndexMax, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x( + 'c', {2, 3, 4, 5}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::FLOAT32); + x.linspace(-2.f); + x.syncToDevice(); + NDArray exp('c', {2, 5}, {11, 11, 11, 11, 11, 11, 11, 11, 11, 11}, + sd::DataType::INT64); + NDArray z('c', {2, 5}, {100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT64); + + std::vector dimensions = {1, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execIndexReduce( + &lc, sd::indexreduce::IndexMax, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execIndexReduce_3) { - - NDArray x('c', {2,3,4,5}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, - 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); - x.linspace(-2.); x.syncToDevice(); - NDArray exp('c', {3}, {39, 39, 39}, sd::DataType::INT64); - NDArray z('c', {3}, {100,100,100}, sd::DataType::INT64); - - std::vector dimensions = {0,2,3}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execIndexReduce(&lc, sd::indexreduce::IndexMax, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x( + 'c', {2, 3, 4, 5}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::DOUBLE); + x.linspace(-2.); + x.syncToDevice(); + NDArray exp('c', {3}, {39, 39, 39}, sd::DataType::INT64); + NDArray z('c', {3}, {100, 100, 100}, sd::DataType::INT64); + + std::vector dimensions = {0, 2, 3}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execIndexReduce( + &lc, sd::indexreduce::IndexMax, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execScalar_1) { - - if (!Environment::getInstance()->isExperimentalBuild()) - return; - - NDArray x('c', {2,3}, {0,1,2,3,4,5}, sd::DataType::INT64); - NDArray exp('c',{2,3}, {0,0,1,1,2,2}, sd::DataType::INT64); - NDArray scalar('c',{}, std::vector{2.f}, sd::DataType::FLOAT32); - NDArray z('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::INT64); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execScalar(&lc, sd::scalar::Divide, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), - nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + if (!Environment::getInstance()->isExperimentalBuild()) return; + + NDArray x('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT64); + NDArray exp('c', {2, 3}, {0, 0, 1, 1, 2, 2}, sd::DataType::INT64); + NDArray scalar('c', {}, std::vector{2.f}, sd::DataType::FLOAT32); + NDArray z('c', {2, 3}, {100, 100, 100, 100, 100, 100}, sd::DataType::INT64); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execScalar( + &lc, sd::scalar::Divide, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, scalar.shapeInfo(), scalar.specialBuffer(), + scalar.specialShapeInfo(), nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execScalar_2) { - - if (!Environment::getInstance()->isExperimentalBuild()) - return; - - NDArray x('c', {2,3}, {-1,-2,-3,-4,-5,-6}, sd::DataType::INT64); - NDArray exp('c',{2,3}, {10,10,10,10,10,10}, sd::DataType::FLOAT32); - NDArray scalar('c',{}, std::vector{10.f}, sd::DataType::FLOAT32); - NDArray z('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execScalar(&lc, sd::scalar::CopyPws, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), - nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + if (!Environment::getInstance()->isExperimentalBuild()) return; + + NDArray x('c', {2, 3}, {-1, -2, -3, -4, -5, -6}, sd::DataType::INT64); + NDArray exp('c', {2, 3}, {10, 10, 10, 10, 10, 10}, sd::DataType::FLOAT32); + NDArray scalar('c', {}, std::vector{10.f}, sd::DataType::FLOAT32); + NDArray z('c', {2, 3}, {100, 100, 100, 100, 100, 100}, sd::DataType::FLOAT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execScalar( + &lc, sd::scalar::CopyPws, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, scalar.shapeInfo(), scalar.specialBuffer(), + scalar.specialShapeInfo(), nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execScalar_3) { - - if (!Environment::getInstance()->isExperimentalBuild()) - return; - - NDArray x('c', {2,3,2}, {0,1,2,3,4,5,6,7,8,9,10,11}, sd::DataType::INT64); - NDArray scalars('c',{2,2}, {1,2,3,4}, sd::DataType::FLOAT32); - NDArray exp('c', {2,3,2}, {0,0,2,1,4,2, 2,1,2,2,3,2}, sd::DataType::INT64); - NDArray z('c', {2,3,2}, {100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT64); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execScalar(&lc, sd::scalar::Divide, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, scalars.shapeInfo(), scalars.specialBuffer(), scalars.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + if (!Environment::getInstance()->isExperimentalBuild()) return; + + NDArray x('c', {2, 3, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + sd::DataType::INT64); + NDArray scalars('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray exp('c', {2, 3, 2}, {0, 0, 2, 1, 4, 2, 2, 1, 2, 2, 3, 2}, + sd::DataType::INT64); + NDArray z('c', {2, 3, 2}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT64); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execScalar( + &lc, sd::scalar::Divide, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, scalars.shapeInfo(), + scalars.specialBuffer(), scalars.specialShapeInfo(), (int *)devicePtrs[0], + dimensions.size(), (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execScalarBool_1) { - - NDArray x('c', {2,3}, {-1,-2,0,1,2,3}, sd::DataType::BFLOAT16); - NDArray scalar('c',{}, std::vector{0}, sd::DataType::BFLOAT16); - NDArray exp('c',{2,3}, {0,0,0,1,1,1}, sd::DataType::BOOL); - NDArray z('c', {2,3}, {100,100,100,100,100,100,}, sd::DataType::BOOL); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - // call cuda kernel which calculates result - NativeOpExecutioner::execScalarBool(&lc, sd::scalar::GreaterThan, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), - nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {-1, -2, 0, 1, 2, 3}, sd::DataType::BFLOAT16); + NDArray scalar('c', {}, std::vector{0}, sd::DataType::BFLOAT16); + NDArray exp('c', {2, 3}, {0, 0, 0, 1, 1, 1}, sd::DataType::BOOL); + NDArray z('c', {2, 3}, + { + 100, + 100, + 100, + 100, + 100, + 100, + }, + sd::DataType::BOOL); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + // call cuda kernel which calculates result + NativeOpExecutioner::execScalarBool( + &lc, sd::scalar::GreaterThan, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, scalar.shapeInfo(), scalar.specialBuffer(), + scalar.specialShapeInfo(), nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execScalarBool_2) { - - NDArray x('c', {2,3}, {0,1,2,3,4,5}, sd::DataType::FLOAT32); - NDArray scalars('c',{2}, {-1,4}, sd::DataType::FLOAT32); - NDArray exp('c', {2,3}, {1,1,1,0,0,1}, sd::DataType::BOOL); - NDArray z('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::BOOL); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execScalarBool(&lc, sd::scalar::GreaterThan, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, scalars.shapeInfo(), scalars.specialBuffer(), scalars.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray scalars('c', {2}, {-1, 4}, sd::DataType::FLOAT32); + NDArray exp('c', {2, 3}, {1, 1, 1, 0, 0, 1}, sd::DataType::BOOL); + NDArray z('c', {2, 3}, {100, 100, 100, 100, 100, 100}, sd::DataType::BOOL); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execScalarBool( + &lc, sd::scalar::GreaterThan, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, scalars.shapeInfo(), + scalars.specialBuffer(), scalars.specialShapeInfo(), (int *)devicePtrs[0], + dimensions.size(), (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execBroadcast_1) { - - if (!Environment::getInstance()->isExperimentalBuild()) - return; - - NDArray x('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); - NDArray y('c', {3}, {10, 20, 30}, sd::DataType::INT64); - NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); - NDArray exp('c', {2,3,4}, {10, 11, 12, 13,24, 25, 26, 27,38, 39, 40, 41,22, 23, 24, 25,36, 37, 38, 39,50, 51, 52, 53}, sd::DataType::INT32); - x.linspace(0); x.syncToDevice(); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execBroadcast(&lc, sd::broadcast::Add, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + if (!Environment::getInstance()->isExperimentalBuild()) return; + + NDArray x('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT32); + NDArray y('c', {3}, {10, 20, 30}, sd::DataType::INT64); + NDArray z('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT32); + NDArray exp('c', {2, 3, 4}, {10, 11, 12, 13, 24, 25, 26, 27, 38, 39, 40, 41, + 22, 23, 24, 25, 36, 37, 38, 39, 50, 51, 52, 53}, + sd::DataType::INT32); + x.linspace(0); + x.syncToDevice(); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcast( + &lc, sd::broadcast::Add, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execBroadcast_2) { - - if (!Environment::getInstance()->isExperimentalBuild()) - return; - - NDArray x('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); - NDArray y('c', {2,4}, {10,20,30,40,50,60,70,80}, sd::DataType::FLOAT32); - NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::FLOAT32); - NDArray exp('c', {2,3,4}, {10., 21., 32., 43., 14., 25., 36., 47., 18., 29., 40., 51., 62., 73., 84., 95., 66., 77., 88., 99., 70., 81., 92., 103}, sd::DataType::FLOAT32); - x.linspace(0); x.syncToDevice(); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execBroadcast(&lc, sd::broadcast::Add, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + if (!Environment::getInstance()->isExperimentalBuild()) return; + + NDArray x('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT32); + NDArray y('c', {2, 4}, {10, 20, 30, 40, 50, 60, 70, 80}, + sd::DataType::FLOAT32); + NDArray z('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::FLOAT32); + NDArray exp('c', {2, 3, 4}, + {10., 21., 32., 43., 14., 25., 36., 47., 18., 29., 40., 51., + 62., 73., 84., 95., 66., 77., 88., 99., 70., 81., 92., 103}, + sd::DataType::FLOAT32); + x.linspace(0); + x.syncToDevice(); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcast( + &lc, sd::broadcast::Add, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execBroadcastBool_1) { - - NDArray x('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); - NDArray y('c', {3}, {2, 12, 22}, sd::DataType::INT32); - NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,}, sd::DataType::BOOL); - NDArray exp('c', {2,3,4}, {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0}, sd::DataType::BOOL); - x.linspace(1); x.syncToDevice(); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execBroadcastBool(&lc, sd::broadcast::EqualTo, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT32); + NDArray y('c', {3}, {2, 12, 22}, sd::DataType::INT32); + NDArray z('c', {2, 3, 4}, + { + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + }, + sd::DataType::BOOL); + NDArray exp('c', {2, 3, 4}, {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0}, + sd::DataType::BOOL); + x.linspace(1); + x.syncToDevice(); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcastBool( + &lc, sd::broadcast::EqualTo, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execBroadcastBool_2) { - - NDArray x('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100},sd::DataType::FLOAT32); - NDArray y('c', {2,4}, {1,10,10,15,20,20,20,24}, sd::DataType::FLOAT32); - NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::BOOL); - NDArray exp('c', {2,3,4}, {1, 0, 0, 0,0, 0, 0, 0,0, 1, 0, 0,0, 0, 0, 0,0, 0, 0, 0,0, 0, 0, 1}, sd::DataType::BOOL); - x.linspace(1); x.syncToDevice(); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execBroadcastBool(&lc, sd::broadcast::EqualTo, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::FLOAT32); + NDArray y('c', {2, 4}, {1, 10, 10, 15, 20, 20, 20, 24}, + sd::DataType::FLOAT32); + NDArray z('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::BOOL); + NDArray exp('c', {2, 3, 4}, {1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, + sd::DataType::BOOL); + x.linspace(1); + x.syncToDevice(); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcastBool( + &lc, sd::broadcast::EqualTo, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execPairwiseTransform_1) { - - if (!Environment::getInstance()->isExperimentalBuild()) - return; - - NDArray x('c', {2,2,2}, {1,5,3,7,2,6,4,8}, sd::DataType::INT32); - NDArray y('c', {4,2}, {0.1,0.2,0.3,0.4,1.5,0.6,0.7,1.8}, sd::DataType::DOUBLE); - NDArray z('c', {8}, {100,100,100,100,100,100,100,100}, sd::DataType::INT32); - NDArray exp('c', {8}, {0,1,2,3,3,5,6,6}, sd::DataType::INT32); - x.permutei({2,1,0}); // -> {1,2,3,4,5,6,7,8} - x.syncShape(); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execPairwiseTransform(&lc, sd::pairwise::Subtract, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + if (!Environment::getInstance()->isExperimentalBuild()) return; + + NDArray x('c', {2, 2, 2}, {1, 5, 3, 7, 2, 6, 4, 8}, sd::DataType::INT32); + NDArray y('c', {4, 2}, {0.1, 0.2, 0.3, 0.4, 1.5, 0.6, 0.7, 1.8}, + sd::DataType::DOUBLE); + NDArray z('c', {8}, {100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT32); + NDArray exp('c', {8}, {0, 1, 2, 3, 3, 5, 6, 6}, sd::DataType::INT32); + x.permutei({2, 1, 0}); // -> {1,2,3,4,5,6,7,8} + x.syncShape(); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execPairwiseTransform( + &lc, sd::pairwise::Subtract, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execPairwiseBoolTransform_1) { - - NDArray x('c', {2,2,2}, {1,5,3,7,2,6,4,8}, sd::DataType::INT64); - NDArray y('c', {4,2}, {0,2,0,4,0,6,0,8}, sd::DataType::INT64); - NDArray z('c', {8}, {100,100,100,100,100,100,100,100}, sd::DataType::BOOL); - NDArray exp('c', {8}, {0,1,0,1,0,1,0,1}, sd::DataType::BOOL); - x.permutei({2,1,0}); // -> {1,2,3,4,5,6,7,8} - x.syncShape(); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execPairwiseBoolTransform(&lc, sd::pairwise::EqualTo, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 2}, {1, 5, 3, 7, 2, 6, 4, 8}, sd::DataType::INT64); + NDArray y('c', {4, 2}, {0, 2, 0, 4, 0, 6, 0, 8}, sd::DataType::INT64); + NDArray z('c', {8}, {100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::BOOL); + NDArray exp('c', {8}, {0, 1, 0, 1, 0, 1, 0, 1}, sd::DataType::BOOL); + x.permutei({2, 1, 0}); // -> {1,2,3,4,5,6,7,8} + x.syncShape(); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execPairwiseBoolTransform( + &lc, sd::pairwise::EqualTo, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } - //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformFloat_1) { - - NDArray x('c', {2,2}, {0, 6.25, 2.25, 12.25}, sd::DataType::DOUBLE); - NDArray z('c', {4}, {100,100,100,100}, sd::DataType::FLOAT32); - NDArray exp('c', {4}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - x.permutei({1,0}); - x.syncShape(); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformFloat(&lc, sd::transform::Sqrt, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2}, {0, 6.25, 2.25, 12.25}, sd::DataType::DOUBLE); + NDArray z('c', {4}, {100, 100, 100, 100}, sd::DataType::FLOAT32); + NDArray exp('c', {4}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + x.permutei({1, 0}); + x.syncShape(); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformFloat( + &lc, sd::transform::Sqrt, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformFloat_2) { - - NDArray x('c', {1,4}, {0, 4, 9, 16}, sd::DataType::INT64); - NDArray z('c', {2,2}, {100,100,100,100}, sd::DataType::DOUBLE); - NDArray exp('c', {2,2}, {0, 2, 3, 4}, sd::DataType::DOUBLE); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformFloat(&lc, sd::transform::Sqrt, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {1, 4}, {0, 4, 9, 16}, sd::DataType::INT64); + NDArray z('c', {2, 2}, {100, 100, 100, 100}, sd::DataType::DOUBLE); + NDArray exp('c', {2, 2}, {0, 2, 3, 4}, sd::DataType::DOUBLE); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformFloat( + &lc, sd::transform::Sqrt, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformAny_1) { - - NDArray x('c', {2,2}, {0, 6.25, 2.25, 12.25}, sd::DataType::DOUBLE); - NDArray z('c', {4,1}, {100,100,100,100}, sd::DataType::INT32); - NDArray exp('c', {4,1}, {0, 2, 6, 12}, sd::DataType::INT32); - x.permutei({1,0}); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformAny(&lc, sd::transform::Assign, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2}, {0, 6.25, 2.25, 12.25}, sd::DataType::DOUBLE); + NDArray z('c', {4, 1}, {100, 100, 100, 100}, sd::DataType::INT32); + NDArray exp('c', {4, 1}, {0, 2, 6, 12}, sd::DataType::INT32); + x.permutei({1, 0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformAny( + &lc, sd::transform::Assign, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformAny_2) { - - NDArray x('c', {1,4}, {0, 6.25, 2.25, 12.25}, sd::DataType::BFLOAT16); - NDArray z('c', {2,2}, {100,100,100,100}, sd::DataType::FLOAT32); - NDArray exp('c', {2,2}, {0, 6.25, 2.25, 12.25}, sd::DataType::FLOAT32); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformAny(&lc, sd::transform::Assign, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {1, 4}, {0, 6.25, 2.25, 12.25}, sd::DataType::BFLOAT16); + NDArray z('c', {2, 2}, {100, 100, 100, 100}, sd::DataType::FLOAT32); + NDArray exp('c', {2, 2}, {0, 6.25, 2.25, 12.25}, sd::DataType::FLOAT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformAny( + &lc, sd::transform::Assign, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformStrict_1) { - - NDArray x('c', {2,3}, {0,2,4,1,3,5}, sd::DataType::DOUBLE); - NDArray z('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::DOUBLE); - NDArray exp('c', {3,2}, {0, 3, 12, 27, 48, 75}, sd::DataType::DOUBLE); - x.permutei({1,0}); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformStrict(&lc, sd::transform::CubeDerivative, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {0, 2, 4, 1, 3, 5}, sd::DataType::DOUBLE); + NDArray z('c', {3, 2}, {100, 100, 100, 100, 100, 100}, sd::DataType::DOUBLE); + NDArray exp('c', {3, 2}, {0, 3, 12, 27, 48, 75}, sd::DataType::DOUBLE); + x.permutei({1, 0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformStrict( + &lc, sd::transform::CubeDerivative, nullptr, x.shapeInfo(), + x.specialBuffer(), x.specialShapeInfo(), nullptr, z.shapeInfo(), + z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformStrict_2) { - - NDArray x('c', {6}, {0,1,2,3,4,5}, sd::DataType::FLOAT32); - NDArray z('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); - NDArray exp('c', {3,2}, {0, 3, 12, 27, 48, 75}, sd::DataType::FLOAT32); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformStrict(&lc, sd::transform::CubeDerivative, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {6}, {0, 1, 2, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray z('c', {3, 2}, {100, 100, 100, 100, 100, 100}, sd::DataType::FLOAT32); + NDArray exp('c', {3, 2}, {0, 3, 12, 27, 48, 75}, sd::DataType::FLOAT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformStrict( + &lc, sd::transform::CubeDerivative, nullptr, x.shapeInfo(), + x.specialBuffer(), x.specialShapeInfo(), nullptr, z.shapeInfo(), + z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformSame_1) { - - NDArray x('c', {2,3}, {0,2.5,4.5,1.5,3.5,5.5}, sd::DataType::DOUBLE); - NDArray z('c', {1,6}, {100,100,100,100,100,100}, sd::DataType::DOUBLE); - NDArray exp('c', {1,6}, {0,2.25,6.25,12.25,20.25,30.25}, sd::DataType::DOUBLE); - x.permutei({1,0}); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformSame(&lc, sd::transform::Square, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {0, 2.5, 4.5, 1.5, 3.5, 5.5}, sd::DataType::DOUBLE); + NDArray z('c', {1, 6}, {100, 100, 100, 100, 100, 100}, sd::DataType::DOUBLE); + NDArray exp('c', {1, 6}, {0, 2.25, 6.25, 12.25, 20.25, 30.25}, + sd::DataType::DOUBLE); + x.permutei({1, 0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformSame( + &lc, sd::transform::Square, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformSame_2) { - - NDArray x('c', {6}, {0,1,2,3,4,5}, sd::DataType::INT32); - NDArray z('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::INT32); - NDArray exp('c', {3,2}, {0,1,4,9,16,25}, sd::DataType::INT32); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformSame(&lc, sd::transform::Square, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {6}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT32); + NDArray z('c', {3, 2}, {100, 100, 100, 100, 100, 100}, sd::DataType::INT32); + NDArray exp('c', {3, 2}, {0, 1, 4, 9, 16, 25}, sd::DataType::INT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformSame( + &lc, sd::transform::Square, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformBool_1) { - - NDArray x('c', {2,3}, {0,2,4,-1,-3,-5}, sd::DataType::DOUBLE); - NDArray z('c', {1,6}, {100,100,100,100,100,100}, sd::DataType::BOOL); - NDArray exp('c', {1,6}, {0,0,1,0,1,0}, sd::DataType::BOOL); - x.permutei({1,0}); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformBool(&lc, sd::transform::IsPositive, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3}, {0, 2, 4, -1, -3, -5}, sd::DataType::DOUBLE); + NDArray z('c', {1, 6}, {100, 100, 100, 100, 100, 100}, sd::DataType::BOOL); + NDArray exp('c', {1, 6}, {0, 0, 1, 0, 1, 0}, sd::DataType::BOOL); + x.permutei({1, 0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformBool( + &lc, sd::transform::IsPositive, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execTransformBool_2) { - - NDArray x('c', {6}, {0,-1,2,-3,4,-5}, sd::DataType::INT32); - NDArray z('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::BOOL); - NDArray exp('c', {3,2}, {0,0,1,0,1,0}, sd::DataType::BOOL); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // call cuda kernel which calculates result - NativeOpExecutioner::execTransformBool(&lc, sd::transform::IsPositive, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {6}, {0, -1, 2, -3, 4, -5}, sd::DataType::INT32); + NDArray z('c', {3, 2}, {100, 100, 100, 100, 100, 100}, sd::DataType::BOOL); + NDArray exp('c', {3, 2}, {0, 0, 1, 0, 1, 0}, sd::DataType::BOOL); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformBool( + &lc, sd::transform::IsPositive, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceFloat_1) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); - NDArray z('c', {3}, {100,100,100}, sd::DataType::FLOAT32); - NDArray exp('c', {3}, {2.5, 6.5, 10.5}, sd::DataType::FLOAT32); - x.permutei({2,1,0}); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceFloat(&lc, sd::reduce::Mean, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::INT32); + NDArray z('c', {3}, {100, 100, 100}, sd::DataType::FLOAT32); + NDArray exp('c', {3}, {2.5, 6.5, 10.5}, sd::DataType::FLOAT32); + x.permutei({2, 1, 0}); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceFloat( + &lc, sd::reduce::Mean, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceFloat_2) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); - NDArray z('c', {2,4}, {100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); - NDArray exp('c', {2,4}, {-1., 0., 1., 2.,11., 12., 13., 14.}, sd::DataType::DOUBLE); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceFloat(&lc, sd::reduce::Mean, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::INT32); + NDArray z('c', {2, 4}, {100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::DOUBLE); + NDArray exp('c', {2, 4}, {-1., 0., 1., 2., 11., 12., 13., 14.}, + sd::DataType::DOUBLE); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceFloat( + &lc, sd::reduce::Mean, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceSame_1) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); - NDArray z('c', {3}, {100,100,100}, sd::DataType::INT32); - NDArray exp('c', {3}, {20, 52, 84}, sd::DataType::INT32); - x.permutei({2,1,0}); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceSame(&lc, sd::reduce::Sum, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::INT32); + NDArray z('c', {3}, {100, 100, 100}, sd::DataType::INT32); + NDArray exp('c', {3}, {20, 52, 84}, sd::DataType::INT32); + x.permutei({2, 1, 0}); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceSame( + &lc, sd::reduce::Sum, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceSame_2) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::FLOAT32); - NDArray z('c', {2,4}, {100,100,100,100,100,100,100,100}, sd::DataType::FLOAT32); - NDArray exp('c', {2,4}, {-3., 0., 3., 6.,33., 36., 39., 42.}, sd::DataType::FLOAT32); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceSame(&lc, sd::reduce::Sum, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::FLOAT32); + NDArray z('c', {2, 4}, {100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::FLOAT32); + NDArray exp('c', {2, 4}, {-3., 0., 3., 6., 33., 36., 39., 42.}, + sd::DataType::FLOAT32); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceSame( + &lc, sd::reduce::Sum, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceBool_1) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, sd::DataType::INT32); - NDArray z('c', {3}, {100,100,100}, sd::DataType::BOOL); - NDArray exp('c', {3}, {0, 1, 1}, sd::DataType::BOOL); - x.permutei({2,1,0}); - - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceBool(&lc, sd::reduce::IsPositive, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, + {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18}, + sd::DataType::INT32); + NDArray z('c', {3}, {100, 100, 100}, sd::DataType::BOOL); + NDArray exp('c', {3}, {0, 1, 1}, sd::DataType::BOOL); + x.permutei({2, 1, 0}); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceBool( + &lc, sd::reduce::IsPositive, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceBool_2) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, sd::DataType::FLOAT32); - NDArray z('c', {2,4}, {100,100,100,100,100,100,100,100}, sd::DataType::BOOL); - NDArray exp('c', {2,4}, {1, 1, 1, 1, 0, 0, 0, 0}, sd::DataType::BOOL); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceBool(&lc, sd::reduce::IsPositive, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, + {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18}, + sd::DataType::FLOAT32); + NDArray z('c', {2, 4}, {100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::BOOL); + NDArray exp('c', {2, 4}, {1, 1, 1, 1, 0, 0, 0, 0}, sd::DataType::BOOL); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceBool( + &lc, sd::reduce::IsPositive, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceLong_1) { - - NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, sd::DataType::INT32); - NDArray z('c', {3}, {100,100,100}, sd::DataType::INT64); - NDArray exp('c', {3}, {5,6,6}, sd::DataType::INT64); - x.permutei({2,1,0}); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceLong(&lc, sd::reduce::CountNonZero, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, 0, -3, 0, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 0, 9, 10, 11, 0, 13, 14, 0, 16, 0, 18}, + sd::DataType::INT32); + NDArray z('c', {3}, {100, 100, 100}, sd::DataType::INT64); + NDArray exp('c', {3}, {5, 6, 6}, sd::DataType::INT64); + x.permutei({2, 1, 0}); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceLong( + &lc, sd::reduce::CountNonZero, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceLong_2) { - - NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, sd::DataType::FLOAT32); - NDArray z('c', {2,4}, {100,100,100,100,100,100,100,100}, sd::DataType::INT64); - NDArray exp('c', {2,4}, {3, 1, 3, 2, 2, 1, 2, 3}, sd::DataType::INT64); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceLong(&lc, sd::reduce::CountNonZero, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, 0, -3, 0, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 0, 9, 10, 11, 0, 13, 14, 0, 16, 0, 18}, + sd::DataType::FLOAT32); + NDArray z('c', {2, 4}, {100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT64); + NDArray exp('c', {2, 4}, {3, 1, 3, 2, 2, 1, 2, 3}, sd::DataType::INT64); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceLong( + &lc, sd::reduce::CountNonZero, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceFloatScalar_1) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); - NDArray exp('c', {}, std::vector{6.5}, sd::DataType::FLOAT32); - x.permutei({2,1,0}); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceFloatScalar(&lc, sd::reduce::Mean, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::INT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); + NDArray exp('c', {}, std::vector{6.5}, sd::DataType::FLOAT32); + x.permutei({2, 1, 0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceFloatScalar( + &lc, sd::reduce::Mean, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceFloatScalar_2) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::DOUBLE); - NDArray exp('c', {}, std::vector{6.5}, sd::DataType::DOUBLE); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceFloatScalar(&lc, sd::reduce::Mean, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::INT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::DOUBLE); + NDArray exp('c', {}, std::vector{6.5}, sd::DataType::DOUBLE); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceFloatScalar( + &lc, sd::reduce::Mean, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceSameScalar_1) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::INT32); - NDArray exp('c', {}, std::vector{156}, sd::DataType::INT32); - x.permutei({2,1,0}); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceSameScalar(&lc, sd::reduce::Sum, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::INT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::INT32); + NDArray exp('c', {}, std::vector{156}, sd::DataType::INT32); + x.permutei({2, 1, 0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceSameScalar( + &lc, sd::reduce::Sum, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceSameScalar_2) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::DOUBLE); - NDArray z('c', {}, std::vector{100}, sd::DataType::DOUBLE); - NDArray exp('c', {}, std::vector{156}, sd::DataType::DOUBLE); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceSameScalar(&lc, sd::reduce::Sum, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, + sd::DataType::DOUBLE); + NDArray z('c', {}, std::vector{100}, sd::DataType::DOUBLE); + NDArray exp('c', {}, std::vector{156}, sd::DataType::DOUBLE); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceSameScalar( + &lc, sd::reduce::Sum, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceBoolScalar_1) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, sd::DataType::INT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::BOOL); - NDArray exp('c', {}, std::vector{1}, sd::DataType::BOOL); - x.permutei({2,1,0}); - x.syncShape(); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceBoolScalar(&lc, sd::reduce::IsPositive, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, + {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18}, + sd::DataType::INT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::BOOL); + NDArray exp('c', {}, std::vector{1}, sd::DataType::BOOL); + x.permutei({2, 1, 0}); + x.syncShape(); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceBoolScalar( + &lc, sd::reduce::IsPositive, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceBoolScalar_2) { - - NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, sd::DataType::DOUBLE); - NDArray z('c', {}, std::vector{100}, sd::DataType::BOOL); - NDArray exp('c', {}, std::vector{1}, sd::DataType::BOOL); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceBoolScalar(&lc, sd::reduce::IsPositive, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, + {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, + -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18}, + sd::DataType::DOUBLE); + NDArray z('c', {}, std::vector{100}, sd::DataType::BOOL); + NDArray exp('c', {}, std::vector{1}, sd::DataType::BOOL); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceBoolScalar( + &lc, sd::reduce::IsPositive, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceLongScalar_1) { - - NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, sd::DataType::INT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::INT64); - NDArray exp('c', {}, std::vector{17}, sd::DataType::INT64); - x.permutei({2,1,0}); - x.syncShape(); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceLongScalar(&lc, sd::reduce::CountNonZero, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, 0, -3, 0, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 0, 9, 10, 11, 0, 13, 14, 0, 16, 0, 18}, + sd::DataType::INT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::INT64); + NDArray exp('c', {}, std::vector{17}, sd::DataType::INT64); + x.permutei({2, 1, 0}); + x.syncShape(); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceLongScalar( + &lc, sd::reduce::CountNonZero, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduceLongScalar_2) { - - NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, sd::DataType::DOUBLE); - NDArray z('c', {}, std::vector{100}, sd::DataType::INT64); - NDArray exp('c', {}, std::vector{17}, sd::DataType::INT64); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduceLongScalar(&lc, sd::reduce::CountNonZero, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 3, 4}, {-5, 0, -3, 0, -1, 0, 1, 2, 3, 4, 5, 6, + 7, 0, 9, 10, 11, 0, 13, 14, 0, 16, 0, 18}, + sd::DataType::DOUBLE); + NDArray z('c', {}, std::vector{100}, sd::DataType::INT64); + NDArray exp('c', {}, std::vector{17}, sd::DataType::INT64); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + int *allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceLongScalar( + &lc, sd::reduce::CountNonZero, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3TAD_1) { - - NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::FLOAT32); - NDArray y('c', {2,2}, {1,2,3,4}, sd::DataType::FLOAT32); - NDArray exp('c', {3}, {10,20,30}, sd::DataType::DOUBLE); - NDArray z('c', {3}, {100,100,100}, sd::DataType::DOUBLE); - - std::vector dimensions = {0,1}; - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions); - LaunchContext* context = x.getContext(); - - x.syncToDevice(); - y.syncToDevice(); - PointersManager pm(context, "execReduce3TAD_1"); - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3TAD(context, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, dimensions.size(), - packX.specialShapeInfo(), packX.specialOffsets(), nullptr, nullptr); - pm.synchronize(); -// cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); -// z.printIndexedBuffer("OutputReduce3TAD"); - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - + NDArray x('c', {2, 2, 3}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::FLOAT32); + NDArray y('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray exp('c', {3}, {10, 20, 30}, sd::DataType::DOUBLE); + NDArray z('c', {3}, {100, 100, 100}, sd::DataType::DOUBLE); + + std::vector dimensions = {0, 1}; + auto packX = ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), + dimensions); + LaunchContext *context = x.getContext(); + + x.syncToDevice(); + y.syncToDevice(); + PointersManager pm(context, "execReduce3TAD_1"); + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3TAD( + context, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr, dimensions.size(), + packX.specialShapeInfo(), packX.specialOffsets(), nullptr, nullptr); + pm.synchronize(); + // cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + // z.printIndexedBuffer("OutputReduce3TAD"); + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3TAD_2) { - - NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::INT64); - NDArray y('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT64); - NDArray exp('c', {2}, {10,73}, sd::DataType::FLOAT32); - NDArray z('c', {2}, {100,100}, sd::DataType::FLOAT32); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3TAD(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 3}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::INT64); + NDArray y('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT64); + NDArray exp('c', {2}, {10, 73}, sd::DataType::FLOAT32); + NDArray z('c', {2}, {100, 100}, sd::DataType::FLOAT32); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3TAD( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3TAD_3) { - - NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::INT64); - NDArray y('c', {3}, {1,2,3}, sd::DataType::INT64); - NDArray exp('c', {2,2}, {-22,-4,14,32}, sd::DataType::FLOAT32); - NDArray z('c', {2,2}, {100,100,100,100}, sd::DataType::FLOAT32); - - std::vector dimensions = {2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3TAD(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 3}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::INT64); + NDArray y('c', {3}, {1, 2, 3}, sd::DataType::INT64); + NDArray exp('c', {2, 2}, {-22, -4, 14, 32}, sd::DataType::FLOAT32); + NDArray z('c', {2, 2}, {100, 100, 100, 100}, sd::DataType::FLOAT32); + + std::vector dimensions = {2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3TAD( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execReduce3TAD_4) { - - NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::DOUBLE); - NDArray y('c', {2,2,3}, {10,20,30,40,50,60,70,80,90,100,110,120}, sd::DataType::DOUBLE); - NDArray exp('c', {}, std::vector{1820}, sd::DataType::FLOAT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); - - std::vector dimensions = {0,1,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execReduce3TAD(&lc, sd::reduce3::Dot, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 3}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::DOUBLE); + NDArray y('c', {2, 2, 3}, {10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120}, + sd::DataType::DOUBLE); + NDArray exp('c', {}, std::vector{1820}, sd::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); + + std::vector dimensions = {0, 1, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3TAD( + &lc, sd::reduce3::Dot, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int *)devicePtrs[0], dimensions.size(), + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execSummaryStats_1) { - - NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::INT64); - NDArray exp('c', {}, std::vector{3.605551}, sd::DataType::FLOAT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execSummaryStats(&lc, sd::variance::SummaryStatsStandardDeviation, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - true); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 3}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::INT64); + NDArray exp('c', {}, std::vector{3.605551}, sd::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execSummaryStats( + &lc, sd::variance::SummaryStatsStandardDeviation, nullptr, x.shapeInfo(), + x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), + z.specialBuffer(), z.specialShapeInfo(), true); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execSummaryStats_2) { - - NDArray x('c', {2,2,3}, {-5,-4,-3,-20,-1,0,1,2,3,4,5,6}, sd::DataType::DOUBLE); - NDArray exp('c', {2}, {3.405877, 9.715966}, sd::DataType::FLOAT32); - NDArray z('c', {2}, {100,100}, sd::DataType::FLOAT32); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execSummaryStats(&lc, sd::variance::SummaryStatsStandardDeviation, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - true); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 3}, {-5, -4, -3, -20, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::DOUBLE); + NDArray exp('c', {2}, {3.405877, 9.715966}, sd::DataType::FLOAT32); + NDArray z('c', {2}, {100, 100}, sd::DataType::FLOAT32); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execSummaryStats( + &lc, sd::variance::SummaryStatsStandardDeviation, nullptr, x.shapeInfo(), + x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), + z.specialBuffer(), z.specialShapeInfo(), (int *)devicePtrs[0], + dimensions.size(), (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + true); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execSummaryStats_3) { - - NDArray x('c', {2,2,3}, {-5,-4,-3,-20,-1,0,1,2,3,4,5,6}, sd::DataType::DOUBLE); - NDArray exp('c', {2}, {10.606602, 2.121320}, sd::DataType::FLOAT32); - NDArray z('c', {2}, {100,100}, sd::DataType::FLOAT32); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execSummaryStats(&lc, sd::variance::SummaryStatsStandardDeviation, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - true); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 3}, {-5, -4, -3, -20, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::DOUBLE); + NDArray exp('c', {2}, {10.606602, 2.121320}, sd::DataType::FLOAT32); + NDArray z('c', {2}, {100, 100}, sd::DataType::FLOAT32); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execSummaryStats( + &lc, sd::variance::SummaryStatsStandardDeviation, nullptr, x.shapeInfo(), + x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), + z.specialBuffer(), z.specialShapeInfo(), (int *)devicePtrs[0], + dimensions.size(), (Nd4jLong *)devicePtrs[1], (Nd4jLong *)devicePtrs[2], + true); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } //////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execSummaryStatsScalar_1) { - - NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::INT64); - NDArray exp('c', {}, std::vector{3.605551}, sd::DataType::FLOAT32); - NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); - lc.setReductionPointer(reductionPointer); - - // call cuda kernel which calculates result - NativeOpExecutioner::execSummaryStatsScalar(&lc, sd::variance::SummaryStatsStandardDeviation, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - true); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {2, 2, 3}, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6}, + sd::DataType::INT64); + NDArray exp('c', {}, std::vector{3.605551}, sd::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void *reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execSummaryStatsScalar( + &lc, sd::variance::SummaryStatsStandardDeviation, nullptr, x.shapeInfo(), + x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), + z.specialBuffer(), z.specialShapeInfo(), true); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execRandom_1) { - -// NDArray z('c', {10}, {100,0,0,0,0,0,0,0,0,0}, sd::DataType::DOUBLE); - NDArray z('c', {10}, {100,0,0,0,0,0,0,0,0,100}, sd::DataType::FLOAT32); - NDArray exp('c', {10}, {0.050942, -0.183229, -0.093921, 0.075469, 0.257166, -0.254838, 0.342227, -0.682188, -0.004345, 0.464633}, sd::DataType::FLOAT32); - - sd::graph::RandomGenerator gen(119,5); - - cudaError_t cudaResult; - NDArray* array = &z; - ExtraArguments arguments({0.f, 0.5f}); - auto context = z.getContext(); - PointersManager pm(context, "tests::execRandom_1"); -// z.printIndexedBuffer("Input data"); -// z.syncToDevice(); - NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &gen, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); - pm.synchronize(); - z.tickWriteDevice(); -// z.printIndexedBuffer("Output Gaussian"); -// RandomLauncher::fillGaussian(context, gen, &z, 0.f, 0.5f); -// pm.synchronize(); -// z.tickWriteDevice(); -// z.printIndexedBuffer("Output Gaussian"); - -// cudaStream_t stream; -// cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); -// LaunchContext lc(&stream); -// -// // ::execRandom(extraPointers, random::GaussianDistribution, &gen, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &extra); -// // call cuda kernel which calculates result -// NativeOpExecutioner::execRandom(&lc, sd::random::GaussianDistribution, -// &gen, -// nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), -// nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), -// nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), -// extraArguments.argumentsAsT(z.dataType())); -// -// cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); -// ASSERT_EQ(cudaResult, 0); -// z.tickWriteDevice(); -// z.syncToHost(); -// z.printIndexedBuffer("Random1"); - ASSERT_EQ(exp, z); -// // verify results -// for (int e = 0; e < z.lengthOf(); e++) -// ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); -// cudaFree(dExtraArgs); - // free allocated global device memory -// cudaFree(dGen); - // delete cuda stream -// cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + // NDArray z('c', {10}, {100,0,0,0,0,0,0,0,0,0}, sd::DataType::DOUBLE); + NDArray z('c', {10}, {100, 0, 0, 0, 0, 0, 0, 0, 0, 100}, + sd::DataType::FLOAT32); + NDArray exp('c', {10}, + {0.050942, -0.183229, -0.093921, 0.075469, 0.257166, -0.254838, + 0.342227, -0.682188, -0.004345, 0.464633}, + sd::DataType::FLOAT32); + + sd::graph::RandomGenerator gen(119, 5); + + cudaError_t cudaResult; + NDArray *array = &z; + ExtraArguments arguments({0.f, 0.5f}); + auto context = z.getContext(); + PointersManager pm(context, "tests::execRandom_1"); + // z.printIndexedBuffer("Input data"); + // z.syncToDevice(); + NativeOpExecutioner::execRandom( + context, random::GaussianDistribution, &gen, array->buffer(), + array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), + array->buffer(), array->shapeInfo(), array->specialBuffer(), + array->specialShapeInfo(), array->buffer(), array->shapeInfo(), + array->specialBuffer(), array->specialShapeInfo(), + arguments.argumentsAsT(array->dataType())); + pm.synchronize(); + z.tickWriteDevice(); + // z.printIndexedBuffer("Output Gaussian"); + // RandomLauncher::fillGaussian(context, gen, &z, 0.f, 0.5f); + // pm.synchronize(); + // z.tickWriteDevice(); + // z.printIndexedBuffer("Output Gaussian"); + + // cudaStream_t stream; + // cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + // LaunchContext lc(&stream); + // + // // ::execRandom(extraPointers, random::GaussianDistribution, &gen, + //z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &extra); + // // call cuda kernel which calculates result + // NativeOpExecutioner::execRandom(&lc, sd::random::GaussianDistribution, + // &gen, + // nullptr, z.shapeInfo(), z.specialBuffer(), + //z.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + //z.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + //z.specialShapeInfo(), extraArguments.argumentsAsT(z.dataType())); + // + // cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + // ASSERT_EQ(cudaResult, 0); + // z.tickWriteDevice(); + // z.syncToHost(); + // z.printIndexedBuffer("Random1"); + ASSERT_EQ(exp, z); + // // verify results + // for (int e = 0; e < z.lengthOf(); e++) + // ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + // cudaFree(dExtraArgs); + // free allocated global device memory + // cudaFree(dGen); + // delete cuda stream + // cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execRandom_2) { - - NDArray x('c', {10}, {0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1}, sd::DataType::DOUBLE); - NDArray z('c', {2,5}, {100,100,100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); - NDArray exp('c', {10}, {0., 0., 0.3, 0., 0.5, 0., 0.7, 0., 0., 1.}, sd::DataType::DOUBLE); - - ExtraArguments extraArguments({0.7}); - sd::graph::RandomGenerator gen(119,5); - -// // prepare input arrays for prepareDataForCuda function -// std::vector> hostData; -// hostData.emplace_back(extraArguments.data(), extraArguments.size() * sizeof(double)); // 0 -- dimensions -// std::vector devicePtrs(hostData.size(), nullptr); -// - // create cuda stream and LaunchContext - cudaError_t cudaResult; -// cudaStream_t stream; -// cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext* lc = x.getContext(); //(&stream); - - // allocate required amount of global device memory and copy host data to it -// cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execRandom(lc, sd::random::DropOut, - &gen, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - extraArguments.argumentsAsT(z.dataType())); - - cudaResult = cudaStreamSynchronize(*lc->getCudaStream()); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - z.syncToHost(); - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory -// for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream -// cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray x('c', {10}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, + sd::DataType::DOUBLE); + NDArray z('c', {2, 5}, {100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::DOUBLE); + NDArray exp('c', {10}, {0., 0., 0.3, 0., 0.5, 0., 0.7, 0., 0., 1.}, + sd::DataType::DOUBLE); + + ExtraArguments extraArguments({0.7}); + sd::graph::RandomGenerator gen(119, 5); + + // // prepare input arrays for prepareDataForCuda function + // std::vector> hostData; + // hostData.emplace_back(extraArguments.data(), extraArguments.size() * + //sizeof(double)); // 0 -- dimensions std::vector + //devicePtrs(hostData.size(), nullptr); + // + // create cuda stream and LaunchContext + cudaError_t cudaResult; + // cudaStream_t stream; + // cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext *lc = x.getContext(); //(&stream); + + // allocate required amount of global device memory and copy host data to it + // cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + //ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execRandom( + lc, sd::random::DropOut, &gen, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), extraArguments.argumentsAsT(z.dataType())); + + cudaResult = cudaStreamSynchronize(*lc->getCudaStream()); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + z.syncToHost(); + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + // for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + // cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execRandom_3) { - - NDArray z('c', {10}, {100,100,100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); - NDArray exp('c', {10}, {2.373649, 2.239791, 1.887353, 2.488636, 2.068904, 2.281399, 1.828228, 2.228222, 2.490847, 1.669537}, sd::DataType::DOUBLE); - - std::vector extraArguments = {1.5, 2.5}; - sd::graph::RandomGenerator gen(119,5); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(extraArguments.data(), extraArguments.size() * sizeof(double)); // 0 -- dimensions - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execRandom(&lc, sd::random::UniformDistribution, - &gen, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - devicePtrs[0]); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray z('c', {10}, {100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::DOUBLE); + NDArray exp('c', {10}, + {2.373649, 2.239791, 1.887353, 2.488636, 2.068904, 2.281399, + 1.828228, 2.228222, 2.490847, 1.669537}, + sd::DataType::DOUBLE); + + std::vector extraArguments = {1.5, 2.5}; + sd::graph::RandomGenerator gen(119, 5); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back( + extraArguments.data(), + extraArguments.size() * sizeof(double)); // 0 -- dimensions + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execRandom(&lc, sd::random::UniformDistribution, &gen, + nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), devicePtrs[0]); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests1, execRandom_4) { - - NDArray z('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}, sd::DataType::FLOAT32); - NDArray exp('c', {10}, {2.373649, 2.281399, 2.239791, 1.828228, 1.887353, 2.228222, 2.488636, 2.490847, 2.068904, 1.669537}, sd::DataType::FLOAT32); - z.permutei({1,0}); - - ExtraArguments extraArguments({1.5, 2.5}); - sd::graph::RandomGenerator gen(119,5); - -// // prepare input arrays for prepareDataForCuda function -// std::vector> hostData; -// hostData.emplace_back(extraArguments.data(), extraArguments.size() * sizeof(double)); // 0 -- dimensions -// std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext -// cudaError_t cudaResult; -// cudaStream_t stream; -// cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); -// LaunchContext lc(&stream); -// -// // allocate required amount of global device memory and copy host data to it -// cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - auto context = z.getContext(); - PointersManager pm(context, "execRandom4"); - // call cuda kernel which calculates result - NativeOpExecutioner::execRandom(context, sd::random::UniformDistribution, - &gen, - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - extraArguments.argumentsAsT(z.dataType())); - -// cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); -// z.printIndexedBuffer("Output Uniform4"); - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory -// for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); - - // delete cuda stream -// cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + NDArray z('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + sd::DataType::FLOAT32); + NDArray exp('c', {10}, + {2.373649, 2.281399, 2.239791, 1.828228, 1.887353, 2.228222, + 2.488636, 2.490847, 2.068904, 1.669537}, + sd::DataType::FLOAT32); + z.permutei({1, 0}); + + ExtraArguments extraArguments({1.5, 2.5}); + sd::graph::RandomGenerator gen(119, 5); + + // // prepare input arrays for prepareDataForCuda function + // std::vector> hostData; + // hostData.emplace_back(extraArguments.data(), extraArguments.size() * + //sizeof(double)); // 0 -- dimensions std::vector + //devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + // cudaError_t cudaResult; + // cudaStream_t stream; + // cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + // LaunchContext lc(&stream); + // + // // allocate required amount of global device memory and copy host data + //to it cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + //ASSERT_EQ(0, cudaResult); + auto context = z.getContext(); + PointersManager pm(context, "execRandom4"); + // call cuda kernel which calculates result + NativeOpExecutioner::execRandom(context, sd::random::UniformDistribution, + &gen, nullptr, z.shapeInfo(), + z.specialBuffer(), z.specialShapeInfo(), + extraArguments.argumentsAsT(z.dataType())); + + // cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + // z.printIndexedBuffer("Output Uniform4"); + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + // for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + // cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); } - diff --git a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu index b425ffcbb112..ec84ad32dd89 100644 --- a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu +++ b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu @@ -14,973 +14,1199 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author raver119@gmail.com - // +// +// @author raver119@gmail.com +// -#include "testlayers.h" #include #include +#include #include #include #include #include -#include -#include #include +#include +#include -#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class CudaBasicsTests2 : public testing::Test { -public: - + public: }; TEST_F(CudaBasicsTests2, test_devices_1) { - auto caps = Environment::getInstance()->capabilities(); - ASSERT_FALSE(caps.empty()); + auto caps = Environment::getInstance()->capabilities(); + ASSERT_FALSE(caps.empty()); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_1) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('f', {M,N}, {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - // c.printIndexedBuffer(); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::FLOAT32); + NDArray b('f', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::FLOAT32); + NDArray c('f', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M, N}, + {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, + 9.9, 10.1}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printIndexedBuffer(); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_2) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('f', {M,N}, sd::DataType::DOUBLE); - NDArray exp('f', {M,N}, {-1.6, -0.7, 0.2, -0.8, 0.1, 1., -0., 0.9, 1.8, 0.8, 1.7, 2.6, 1.6, 2.5, 3.4}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray b('f', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::DOUBLE); + NDArray c('f', {M, N}, sd::DataType::DOUBLE); + NDArray exp('f', {M, N}, + {-1.6, -0.7, 0.2, -0.8, 0.1, 1., -0., 0.9, 1.8, 0.8, 1.7, 2.6, + 1.6, 2.5, 3.4}, + sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_3) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('f', {M,N}, sd::DataType::DOUBLE); - - NDArray exp('f', {M,N}, {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, 0.5, 1.5, 2.5}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::DOUBLE); + NDArray c('f', {M, N}, sd::DataType::DOUBLE); + + NDArray exp('f', {M, N}, + {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, + 0.5, 1.5, 2.5}, + sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_4) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); - - NDArray exp('c', {M,N}, {0.1, 2.5, 4.9, 7.3, 9.7,0.3, 2.7, 5.1, 7.5, 9.9,0.5, 2.9, 5.3, 7.7, 10.1}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - ASSERT_TRUE(c.equalsTo(&exp)); - - - // NDArray* pA = a.permute({1,0}); - // NDArray* pB = b.permute({1,0}); - // NDArray* pC = c.permute({1,0}); - - // sd::MmulHelper::mmul(pB, pA, pC, 1., 0.); - // ASSERT_TRUE(c.equalsTo(&exp)); - - // delete pA; - // delete pB; - // delete pC; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray b('f', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::DOUBLE); + NDArray c('c', {M, N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M, N}, + {0.1, 2.5, 4.9, 7.3, 9.7, 0.3, 2.7, 5.1, 7.5, 9.9, 0.5, 2.9, 5.3, + 7.7, 10.1}, + sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + ASSERT_TRUE(c.equalsTo(&exp)); + + // NDArray* pA = a.permute({1,0}); + // NDArray* pB = b.permute({1,0}); + // NDArray* pC = c.permute({1,0}); + + // sd::MmulHelper::mmul(pB, pA, pC, 1., 0.); + // ASSERT_TRUE(c.equalsTo(&exp)); + + // delete pA; + // delete pB; + // delete pC; } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_5) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('f', {M,N}, sd::DataType::DOUBLE); - - NDArray exp('f', {M,N}, {-8.8, -4.3, 0.2, 8.6, 4.1, -0.4, -8.4, -3.9, 0.6, 8.2, 3.7, -0.8, -8.0, -3.5, 1.}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::DOUBLE); + NDArray c('f', {M, N}, sd::DataType::DOUBLE); + + NDArray exp('f', {M, N}, + {-8.8, -4.3, 0.2, 8.6, 4.1, -0.4, -8.4, -3.9, 0.6, 8.2, 3.7, -0.8, + -8.0, -3.5, 1.}, + sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_6) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); - - NDArray exp('c', {M,N}, {-1.6, -0.8, -0.0, 0.8, 1.6, -0.7, 0.1, 0.9, 1.7, 2.5, 0.2, 1.0, 1.8, 2.6, 3.4}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray b('f', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::DOUBLE); + NDArray c('c', {M, N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M, N}, + {-1.6, -0.8, -0.0, 0.8, 1.6, -0.7, 0.1, 0.9, 1.7, 2.5, 0.2, 1.0, + 1.8, 2.6, 3.4}, + sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_7) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); - - NDArray exp('c', {M,N}, {-1.9, 1.3, -0.7, 0.1, 0.5, -0.9, 0.3, 0.3, -0.9, 1.5, 0.1, -0.7, 1.3, -1.9, 2.5}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::DOUBLE); + NDArray c('c', {M, N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M, N}, + {-1.9, 1.3, -0.7, 0.1, 0.5, -0.9, 0.3, 0.3, -0.9, 1.5, 0.1, -0.7, + 1.3, -1.9, 2.5}, + sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_8) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); - - NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::DOUBLE); + NDArray c('c', {M, N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M, N}, + {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, + -0.4, 0.6, -0.8, 1.}, + sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_9) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('c', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::FLOAT32); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::FLOAT32); + NDArray c('c', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('c', {M, N}, + {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, + -0.4, 0.6, -0.8, 1.}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_10) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('f', {M,N}, {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - // c.printIndexedBuffer(); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::FLOAT32); + NDArray b('f', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::FLOAT32); + NDArray c('f', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M, N}, + {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, + 9.9, 10.1}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printIndexedBuffer(); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_11) { - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('f', {M,N}, {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, 0.5, 1.5, 2.5}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::FLOAT32); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::FLOAT32); + NDArray c('f', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M, N}, + {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, + 0.5, 1.5, 2.5}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_12) { - - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance()->capabilities()[devCnt].first() < 5) return; - - const Nd4jLong M = 4; - const Nd4jLong K = 4; - const Nd4jLong N = 4; - - NDArray a('f', {M,K}, {1.,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7.}, sd::DataType::INT8); - NDArray b('f', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-1,2,-2,3,-4,5,-6.}, sd::DataType::INT8); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('f', {M,N}, {-16., -22., -23., -25., 30., -12., -38., -70., 20., 16., 18., 18., 22., -8., -28., -52.}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - // c.printBuffer(); - - ASSERT_TRUE(c.equalsTo(&exp)); + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance()->capabilities()[devCnt].first() < 5) return; + + const Nd4jLong M = 4; + const Nd4jLong K = 4; + const Nd4jLong N = 4; + + NDArray a('f', {M, K}, {1., 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 2, 1, 0, 4, 7.}, + sd::DataType::INT8); + NDArray b('f', {K, N}, + {-2, -3, 0, 1, 5, -6, 7, -8, 9, -1, 2, -2, 3, -4, 5, -6.}, + sd::DataType::INT8); + NDArray c('f', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M, N}, + {-16., -22., -23., -25., 30., -12., -38., -70., 20., 16., 18., + 18., 22., -8., -28., -52.}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); + + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_13) { + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance()->capabilities()[devCnt].first() < 5) return; - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance()->capabilities()[devCnt].first() < 5) return; - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('f', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::INT8); - NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::INT8); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); + NDArray a('f', {M, K}, {1., 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + sd::DataType::INT8); + NDArray b('c', {K, N}, {-2, -3, 0, 1, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::INT8); + NDArray c('f', {M, N}, sd::DataType::FLOAT32); - NDArray exp('f', {M,N}, {-109., -122., -135., 111., 120., 129., -121., -134., -147., 129., 144., 159., -130., -140., -150.}, sd::DataType::FLOAT32); + NDArray exp('f', {M, N}, + {-109., -122., -135., 111., 120., 129., -121., -134., -147., 129., + 144., 159., -130., -140., -150.}, + sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - ASSERT_TRUE(c.equalsTo(&exp)); + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_14) { + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance()->capabilities()[devCnt].first() < 5) return; - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance()->capabilities()[devCnt].first() < 5) return; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + NDArray a('c', {M, K}, {1., 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + sd::DataType::INT8); + NDArray b('c', {K, N}, {-2, -3, 0, 1, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::INT8); + NDArray c('c', {M, N}, sd::DataType::FLOAT32); - NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::INT8); - NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::INT8); - NDArray c('c', {M,N}, sd::DataType::FLOAT32); + NDArray exp('c', {M, N}, + {-45., 43., -49., 53., -50., -97., 79., -101., 113., -90., -149., + 115., -153., 173., -130.}, + sd::DataType::FLOAT32); - NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., 113., -90., -149., 115., -153., 173., -130.}, sd::DataType::FLOAT32); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp)); + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_15) { - - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance()->capabilities()[devCnt].first() < 5) return; - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('f', {M,N}, {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - // c.printBuffer(); - - ASSERT_TRUE(c.equalsTo(&exp, 0.01)); + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance()->capabilities()[devCnt].first() < 5) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::HALF); + NDArray b('f', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::HALF); + NDArray c('f', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M, N}, + {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, + 9.9, 10.1}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); + + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_16) { - - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance()->capabilities()[devCnt].first() < 5) return; - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('f', {M,N}, {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, 0.5, 1.5, 2.5}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp, 0.01)); + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance()->capabilities()[devCnt].first() < 5) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::HALF); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::HALF); + NDArray c('f', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M, N}, + {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, + 0.5, 1.5, 2.5}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_17) { - - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance()->capabilities()[devCnt].first() < 5) return; - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('c', {M,N}, sd::DataType::FLOAT32); - - NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::FLOAT32); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp, 0.01)); + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance()->capabilities()[devCnt].first() < 5) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::HALF); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::HALF); + NDArray c('c', {M, N}, sd::DataType::FLOAT32); + + NDArray exp('c', {M, N}, + {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, + -0.4, 0.6, -0.8, 1.}, + sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_18) { - - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance()->capabilities()[devCnt].first() < 5.3) return; - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('f', {M,N}, sd::DataType::HALF); - - NDArray exp('f', {M,N}, {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, sd::DataType::HALF); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance()->capabilities()[devCnt].first() < 5.3) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::HALF); + NDArray b('f', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::HALF); + NDArray c('f', {M, N}, sd::DataType::HALF); + + NDArray exp('f', {M, N}, + {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, + 9.9, 10.1}, + sd::DataType::HALF); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_19) { - - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance()->capabilities()[devCnt].first() < 5.3) return; - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('f', {M,N}, sd::DataType::HALF); - - NDArray exp('f', {M,N}, {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, 0.5, 1.5, 2.5}, sd::DataType::HALF); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance()->capabilities()[devCnt].first() < 5.3) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::HALF); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::HALF); + NDArray c('f', {M, N}, sd::DataType::HALF); + + NDArray exp('f', {M, N}, + {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, + 0.5, 1.5, 2.5}, + sd::DataType::HALF); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_20) { - - int devCnt = 0; - cudaGetDevice(&devCnt); - if(Environment::getInstance()->capabilities()[devCnt].first() < 5.3) return; - - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; - - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('c', {M,N}, sd::DataType::HALF); - - NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::HALF); - - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - - ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); + int devCnt = 0; + cudaGetDevice(&devCnt); + if (Environment::getInstance()->capabilities()[devCnt].first() < 5.3) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M, K}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::HALF); + NDArray b('c', {K, N}, {1, -2, 3, -4, 5, -6, 7, -8, 9, -10, + 11, -12, 13, -14, 15, -16, 17, -18, 19, -20}, + sd::DataType::HALF); + NDArray c('c', {M, N}, sd::DataType::HALF); + + NDArray exp('c', {M, N}, + {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, + -0.4, 0.6, -0.8, 1.}, + sd::DataType::HALF); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); } /* ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_21) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::INT8); - NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); + NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, +sd::DataType::INT8); NDArray b('c', {K,N}, +{-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::FLOAT32); NDArray c('c', {M,N}, sd::DataType::DOUBLE); - NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., 113., -90., -149., 115., -153., 173., -130.}, sd::DataType::DOUBLE); + NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., +113., -90., -149., 115., -153., 173., -130.}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - ASSERT_TRUE(c.equalsTo(&exp)); + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_22) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::FLOAT32); - NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('c', {M,N}, sd::DataType::FLOAT32); + NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, +sd::DataType::FLOAT32); NDArray b('c', {K,N}, +{-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::HALF); NDArray c('c', {M,N}, sd::DataType::FLOAT32); - NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., 113., -90., -149., 115., -153., 173., -130.}, sd::DataType::FLOAT32); + NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., +113., -90., -149., 115., -153., 173., -130.}, sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - // c.printBuffer(); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); - ASSERT_TRUE(c.equalsTo(&exp)); + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_23) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::HALF); NDArray b('c', {K,N}, +{1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::HALF); NDArray c('c', {M,N}, sd::DataType::DOUBLE); - NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::DOUBLE); + NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, +-3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - ASSERT_TRUE(c.equalsTo(&exp, 0.01)); + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_24) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::HALF); NDArray b('c', {K,N}, +{1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::FLOAT32); NDArray c('c', {M,N}, sd::DataType::DOUBLE); - NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::DOUBLE); + NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, +-3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - ASSERT_TRUE(c.equalsTo(&exp, 0.01)); + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_25) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('c', {M,N}, sd::DataType::HALF); + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::DOUBLE); NDArray b('c', {K,N}, +{1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::FLOAT32); NDArray c('c', {M,N}, sd::DataType::HALF); - NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::HALF); + NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, +-3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::HALF); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - ASSERT_TRUE(c.equalsTo(&exp, 0.01)); + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_26) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - // 3x4 * 4x5 = 3x5 - NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::INT64); - NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); - NDArray c('c', {M,N}, sd::DataType::DOUBLE); + // 3x4 * 4x5 = 3x5 + NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, +sd::DataType::INT64); NDArray b('c', {K,N}, +{-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::FLOAT32); NDArray c('c', {M,N}, sd::DataType::DOUBLE); - NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., 113., -90., -149., 115., -153., 173., -130.}, sd::DataType::DOUBLE); + NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., +113., -90., -149., 115., -153., 173., -130.}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - // c.printBuffer(); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); - ASSERT_TRUE(c.equalsTo(&exp)); + ASSERT_TRUE(c.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_27) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); + NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::HALF); NDArray b('f', {K,N}, +{1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::DOUBLE); NDArray c('f', {M,N}, sd::DataType::FLOAT32); - NDArray exp('f', {M,N}, {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, sd::DataType::FLOAT32); + NDArray exp('f', {M,N}, {0.1, 0.3, +0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, +sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - // c.printBuffer(); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); - ASSERT_TRUE(c.equalsTo(&exp, 0.01)); + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_28) { - const Nd4jLong M = 3; - const Nd4jLong K = 4; - const Nd4jLong N = 5; + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; - NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); - NDArray c('f', {M,N}, sd::DataType::FLOAT32); + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); NDArray b('f', {K,N}, +{1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, +sd::DataType::DOUBLE); NDArray c('f', {M,N}, sd::DataType::FLOAT32); - NDArray exp('f', {M,N}, {-1.6, -0.7, 0.2, -0.8, 0.1, 1., -0., 0.9, 1.8, 0.8, 1.7, 2.6, 1.6, 2.5, 3.4}, sd::DataType::FLOAT32); + NDArray exp('f', {M,N}, {-1.6, -0.7, 0.2, -0.8, 0.1, 1., -0., 0.9, 1.8, +0.8, 1.7, 2.6, 1.6, 2.5, 3.4}, sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); - ASSERT_TRUE(c.equalsTo(&exp)); + ASSERT_TRUE(c.equalsTo(&exp)); } */ ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_1) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray a('f', {M, N}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray x('f', {N}, {1, -2, 3, -4}, sd::DataType::DOUBLE); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {0.1, 0.3, 0.5}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {0.1, 0.3, 0.5}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_2) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; + NDArray a('c', {M, N}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray x('f', {N}, {1, -2, 3, -4}, sd::DataType::DOUBLE); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {-1.6, -0.7, 0.2}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {-1.6, -0.7, 0.2}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_3) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray x('c', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray a('c', {M, N}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray x('c', {N}, {1, -2, 3, -4}, sd::DataType::DOUBLE); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {-1.6, -0.7, 0.2}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {-1.6, -0.7, 0.2}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_4) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; + NDArray a('c', {M, N}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray x('c', {N}, {1, -2, 3, -4}, sd::DataType::DOUBLE); + NDArray y('c', {M}, sd::DataType::DOUBLE); - NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray x('c', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('c', {M}, sd::DataType::DOUBLE); + NDArray exp('c', {M}, {-1.6, -0.7, 0.2}, sd::DataType::DOUBLE); - NDArray exp('c', {M}, {-1.6, -0.7, 0.2}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_5) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray x('c', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('c', {M}, sd::DataType::DOUBLE); + NDArray a('f', {M, N}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray x('c', {N}, {1, -2, 3, -4}, sd::DataType::DOUBLE); + NDArray y('c', {M}, sd::DataType::DOUBLE); - NDArray exp('c', {M}, {0.1, 0.3, 0.5}, sd::DataType::DOUBLE); + NDArray exp('c', {M}, {0.1, 0.3, 0.5}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_6) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(6, {0,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); - - NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {M, N}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray temp('f', {M, N, 5}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(6, {0, 2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_7) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(6, {0,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); - - NDArray exp('f', {M}, {5.1, 3.3, 1.5}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {M, N, 5}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(6, {0, 2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {5.1, 3.3, 1.5}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_8) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {N,M,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(4, {1,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); - - NDArray exp('f', {M}, {6.2, 4.5, 1.7}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {N, M, 5}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(4, {1, 2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {6.2, 4.5, 1.7}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_9) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(3, {0,1}); - NDArray y('f', {M}, sd::DataType::DOUBLE); - - NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(3, {0, 1}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_10) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(2, {0,1}); - NDArray y('f', {M}, sd::DataType::DOUBLE); - - NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(2, {0, 1}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_11) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(13, {0,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); - - NDArray exp('f', {M}, {-12.1, -10.9, -9.7}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('c', {5, N, M}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(13, {0, 2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {-12.1, -10.9, -9.7}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_12) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(10, {0,2}); - NDArray y('c', {M}, sd::DataType::DOUBLE); - - NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('c', {5, N, M}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(10, {0, 2}); + NDArray y('c', {M}, sd::DataType::DOUBLE); + + NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_13) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(2, {0,1}, true); - NDArray y('f', {M}, sd::DataType::DOUBLE); - - NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(2, {0, 1}, true); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_14) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(10, {0,2}, true); - NDArray y('c', {M}, sd::DataType::DOUBLE); - - NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('c', {5, N, M}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(10, {0, 2}, true); + NDArray y('c', {M}, sd::DataType::DOUBLE); + + NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_15) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(2, {0,1}); - NDArray y = temp(17, {0,2}); - - NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(2, {0, 1}); + NDArray y = temp(17, {0, 2}); + + NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_16) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray temp1('c', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(2, {0,1}); - NDArray y = temp1(17, {0,2}); - - NDArray exp('c', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray temp1('c', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(2, {0, 1}); + NDArray y = temp1(17, {0, 2}); + + NDArray exp('c', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_17) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(2, {0,1}); - NDArray y = temp(17, {0,2}, true); - // y.printShapeInfo(); - - NDArray exp('f', {1,M,1}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(2, {0, 1}); + NDArray y = temp(17, {0, 2}, true); + // y.printShapeInfo(); + + NDArray exp('f', {1, M, 1}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_18) { - - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray temp1('c', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(2, {0,1},true); - NDArray y = temp1(17, {0,2},true); - - NDArray exp('c', {1,M,1}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray temp1('c', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(2, {0, 1}, true); + NDArray y = temp1(17, {0, 2}, true); + + NDArray exp('c', {1, M, 1}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// /* TEST_F(CudaBasicsTests2, mmulMxV_19) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('f', {M}, sd::DataType::FLOAT32); + NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); + NDArray y('f', {M}, sd::DataType::FLOAT32); - NDArray exp('f', {M}, {0.1, 0.3, 0.5}, sd::DataType::FLOAT32); + NDArray exp('f', {M}, {0.1, 0.3, 0.5}, sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_20) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('f', {M}, sd::DataType::FLOAT32); - NDArray exp('f', {M}, {-1.6, -0.7, 0.2}, sd::DataType::FLOAT32); + NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); + NDArray y('f', {M}, sd::DataType::FLOAT32); + NDArray exp('f', {M}, {-1.6, -0.7, 0.2}, sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_21) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray x('c', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); - NDArray y('c', {M}, sd::DataType::FLOAT32); - NDArray exp('c', {M}, {-1.6, -0.7, 0.2}, sd::DataType::FLOAT32); + NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); NDArray x('c', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); + NDArray y('c', {M}, sd::DataType::FLOAT32); + NDArray exp('c', {M}, {-1.6, -0.7, 0.2}, sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_22) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(6, {0,2}); - NDArray y('f', {M}, sd::DataType::FLOAT32); + NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); NDArray temp('f', {M,N,5}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::DOUBLE); NDArray x = temp(6, {0,2}); NDArray y('f', {M}, +sd::DataType::FLOAT32); - NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::FLOAT32); + NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// @@ -989,11 +1215,11 @@ TEST_F(CudaBasicsTests2, mmulMxV_23) { const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(3, {0,1}); - NDArray y('f', {M}, sd::DataType::FLOAT32); + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); a.permutei({1,0}); NDArray temp('f', {5,M,N}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::DOUBLE); NDArray x = temp(3, {0,1}); NDArray y('f', {M}, +sd::DataType::FLOAT32); NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::FLOAT32); @@ -1004,18 +1230,19 @@ TEST_F(CudaBasicsTests2, mmulMxV_23) { ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_24) { - const Nd4jLong M = 3; - const Nd4jLong N = 4; + const Nd4jLong M = 3; + const Nd4jLong N = 4; - NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(6, {0,2},true); - NDArray y('f', {M}, sd::DataType::FLOAT32); + NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); NDArray temp('f', {M,N,5}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::DOUBLE); NDArray x = temp(6, {0,2},true); NDArray y('f', {M}, +sd::DataType::FLOAT32); - NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::FLOAT32); + NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::FLOAT32); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// @@ -1024,11 +1251,11 @@ TEST_F(CudaBasicsTests2, mmulMxV_25) { const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(3, {0,1}, true); - NDArray y('f', {M}, sd::DataType::FLOAT32); + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); a.permutei({1,0}); NDArray temp('f', {5,M,N}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::DOUBLE); NDArray x = temp(3, {0,1}, true); NDArray y('f', {M}, +sd::DataType::FLOAT32); NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::FLOAT32); @@ -1042,12 +1269,13 @@ TEST_F(CudaBasicsTests2, mmulMxV_26) { const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray temp1('c', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::FLOAT32); - NDArray x = temp(2, {0,1}); - NDArray y = temp1(17, {0,2}); + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); a.permutei({1,0}); NDArray temp('f', {5,M,N}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::DOUBLE); NDArray temp1('c', {5,M,N}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::FLOAT32); NDArray x = temp(2, {0,1}); NDArray y = temp1(17, +{0,2}); NDArray exp('c', {M}, {-0.3, 0.3, 0.9}, sd::DataType::FLOAT32); @@ -1061,12 +1289,13 @@ TEST_F(CudaBasicsTests2, mmulMxV_27) { const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray temp1('c', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::FLOAT32); - NDArray x = temp(2, {0,1},true); - NDArray y = temp1(17, {0,2},true); + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); a.permutei({1,0}); NDArray temp('f', {5,M,N}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::DOUBLE); NDArray temp1('c', {5,M,N}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::FLOAT32); NDArray x = temp(2, {0,1},true); NDArray y = temp1(17, +{0,2},true); NDArray exp('c', {1,M,1}, {-0.3, 0.3, 0.9}, sd::DataType::FLOAT32); @@ -1080,10 +1309,11 @@ TEST_F(CudaBasicsTests2, mmulMxV_28) { const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); - NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(6, {0,2}); - NDArray y('f', {M}, sd::DataType::FLOAT32); + NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, +sd::DataType::FLOAT32); NDArray temp('f', {M,N,5}, +{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, +sd::DataType::DOUBLE); NDArray x = temp(6, {0,2}); NDArray y('f', {M}, +sd::DataType::FLOAT32); NDArray exp('f', {M}, {5.1, 3.3, 1.5}, sd::DataType::FLOAT32); @@ -1127,10 +1357,9 @@ TEST_F(CudaBasicsTests2, mmulDot_3) { const Nd4jLong N = 4; NDArray xBig('c', {4,2}, {1, 0, 2, 0, 3, 0, 4, 0}, sd::DataType::INT32); - NDArray yBig('c', {4,3}, {0.1, 0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.4, 0,0}, sd::DataType::FLOAT32); - NDArray x = xBig(0, {1}, true); - NDArray y = yBig(0, {1}, true); - NDArray z(sd::DataType::DOUBLE); + NDArray yBig('c', {4,3}, {0.1, 0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.4, 0,0}, +sd::DataType::FLOAT32); NDArray x = xBig(0, {1}, true); NDArray y = yBig(0, {1}, +true); NDArray z(sd::DataType::DOUBLE); NDArray exp('c', {}, {3}, sd::DataType::DOUBLE); @@ -1144,10 +1373,9 @@ TEST_F(CudaBasicsTests2, mmulDot_4) { const Nd4jLong N = 4; NDArray xBig('f', {4,2}, {1, 2, 3, 4, 0, 0, 0, 0}, sd::DataType::INT32); - NDArray yBig('c', {4,3}, {0.1, 0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.4, 0,0}, sd::DataType::FLOAT32); - NDArray x = xBig(0, {1}, true); - NDArray y = yBig(0, {1}); - NDArray z(sd::DataType::DOUBLE); + NDArray yBig('c', {4,3}, {0.1, 0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.4, 0,0}, +sd::DataType::FLOAT32); NDArray x = xBig(0, {1}, true); NDArray y = yBig(0, +{1}); NDArray z(sd::DataType::DOUBLE); NDArray exp('c', {}, {3}, sd::DataType::DOUBLE); diff --git a/libnd4j/tests_cpu/layers_tests/CudaExtraArgumentsTests.cu b/libnd4j/tests_cpu/layers_tests/CudaExtraArgumentsTests.cu index 30d58946c51b..bd66220c07aa 100644 --- a/libnd4j/tests_cpu/layers_tests/CudaExtraArgumentsTests.cu +++ b/libnd4j/tests_cpu/layers_tests/CudaExtraArgumentsTests.cu @@ -18,57 +18,56 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include -#include #include #include +#include + +#include "testlayers.h" + using namespace sd; class CudaExtraArgumentsTests : public testing::Test { -public: - - CudaExtraArgumentsTests() { - printf("\n"); - fflush(stdout); - } + public: + CudaExtraArgumentsTests() { + printf("\n"); + fflush(stdout); + } }; TEST_F(CudaExtraArgumentsTests, Basic_Test_1) { - ExtraArguments args({1.0, 2.0, 3.0}); + ExtraArguments args({1.0, 2.0, 3.0}); - float ef[] = {1.f, 2.f, 3.f}; - double ed[] = {1., 2., 3.}; + float ef[] = {1.f, 2.f, 3.f}; + double ed[] = {1., 2., 3.}; - auto ptrFloat = reinterpret_cast(args.argumentsAsT()); - auto ptrDouble = reinterpret_cast(args.argumentsAsT()); - ASSERT_TRUE(ptrFloat != nullptr); - ASSERT_TRUE(ptrDouble != nullptr); + auto ptrFloat = reinterpret_cast(args.argumentsAsT()); + auto ptrDouble = reinterpret_cast(args.argumentsAsT()); + ASSERT_TRUE(ptrFloat != nullptr); + ASSERT_TRUE(ptrDouble != nullptr); - auto tmpFloat = new float[3]; - auto tmpDouble = new double[3]; + auto tmpFloat = new float[3]; + auto tmpDouble = new double[3]; - cudaMemcpy(tmpFloat, ptrFloat, 3 * sizeof(float), cudaMemcpyDeviceToHost); - cudaMemcpy(tmpDouble, ptrDouble, 3 * sizeof(double), cudaMemcpyDeviceToHost); + cudaMemcpy(tmpFloat, ptrFloat, 3 * sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(tmpDouble, ptrDouble, 3 * sizeof(double), cudaMemcpyDeviceToHost); - for (int e = 0; e < 3; e++) { - ASSERT_NEAR(ef[e], tmpFloat[e], 1e-5f); - } + for (int e = 0; e < 3; e++) { + ASSERT_NEAR(ef[e], tmpFloat[e], 1e-5f); + } - for (int e = 0; e < 3; e++) { - ASSERT_NEAR(ed[e], tmpDouble[e], 1e-5); - } + for (int e = 0; e < 3; e++) { + ASSERT_NEAR(ed[e], tmpDouble[e], 1e-5); + } - delete[] tmpFloat; - delete[] tmpDouble; + delete[] tmpFloat; + delete[] tmpDouble; } - TEST_F(CudaExtraArgumentsTests, Basic_Test_2) { - ExtraArguments args; + ExtraArguments args; - auto ptrInt = args.argumentsAsT(); - ASSERT_TRUE(ptrInt == nullptr); + auto ptrInt = args.argumentsAsT(); + ASSERT_TRUE(ptrInt == nullptr); } - diff --git a/libnd4j/tests_cpu/layers_tests/CudaLaunchHelperTests.cpp b/libnd4j/tests_cpu/layers_tests/CudaLaunchHelperTests.cpp index 66cc024bf98d..6e9979d33c5d 100644 --- a/libnd4j/tests_cpu/layers_tests/CudaLaunchHelperTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/CudaLaunchHelperTests.cpp @@ -18,29 +18,29 @@ // Created by raver on 11/26/2018. // -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class CudaLaunchHelperTests : public testing::Test { -public: - + public: }; TEST_F(CudaLaunchHelperTests, test_reduction_blocks_1) { - ASSERT_EQ(1, CudaLaunchHelper::getReductionBlocks(512)); + ASSERT_EQ(1, CudaLaunchHelper::getReductionBlocks(512)); } TEST_F(CudaLaunchHelperTests, test_reduction_blocks_2) { - ASSERT_EQ(1, CudaLaunchHelper::getReductionBlocks(121)); + ASSERT_EQ(1, CudaLaunchHelper::getReductionBlocks(121)); } TEST_F(CudaLaunchHelperTests, test_reduction_blocks_3) { - ASSERT_EQ(2, CudaLaunchHelper::getReductionBlocks(513)); + ASSERT_EQ(2, CudaLaunchHelper::getReductionBlocks(513)); } TEST_F(CudaLaunchHelperTests, test_reduction_blocks_4) { - ASSERT_EQ(3, CudaLaunchHelper::getReductionBlocks(1225)); + ASSERT_EQ(3, CudaLaunchHelper::getReductionBlocks(1225)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DataBufferTests.cpp b/libnd4j/tests_cpu/layers_tests/DataBufferTests.cpp index 42ab543b1a65..b8bd4d72e6ce 100644 --- a/libnd4j/tests_cpu/layers_tests/DataBufferTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/DataBufferTests.cpp @@ -18,61 +18,64 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include #include #include #include #include +#include #include -#include #include -#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; using namespace sd::memory; class DataBufferTests : public testing::Test { -public: - + public: }; TEST_F(DataBufferTests, test_alloc_limit_1) { - if (!Environment::getInstance()->isCPU()) - return; - - auto deviceId = AffinityManager::currentDeviceId(); - auto odLimit = MemoryCounter::getInstance()->deviceLimit(deviceId); - auto ogLimit = MemoryCounter::getInstance()->groupLimit(MemoryType::HOST); - auto odUse = MemoryCounter::getInstance()->allocatedDevice(deviceId); - auto ogUse = MemoryCounter::getInstance()->allocatedGroup(MemoryType::HOST); + if (!Environment::getInstance()->isCPU()) return; - auto limitSize = odUse + (150 * 1024 * 1024); - auto allocSize = 100000000; + auto deviceId = AffinityManager::currentDeviceId(); + auto odLimit = MemoryCounter::getInstance()->deviceLimit(deviceId); + auto ogLimit = MemoryCounter::getInstance()->groupLimit(MemoryType::HOST); + auto odUse = MemoryCounter::getInstance()->allocatedDevice(deviceId); + auto ogUse = MemoryCounter::getInstance()->allocatedGroup(MemoryType::HOST); - MemoryCounter::getInstance()->setDeviceLimit(deviceId, odLimit + limitSize); - MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, odLimit + limitSize); + auto limitSize = odUse + (150 * 1024 * 1024); + auto allocSize = 100000000; - DataBuffer buffer(allocSize, DataType::INT32); + MemoryCounter::getInstance()->setDeviceLimit(deviceId, odLimit + limitSize); + MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, + odLimit + limitSize); - // separately testing per-device limits and group limits - ASSERT_EQ(odUse + allocSize, MemoryCounter::getInstance()->allocatedDevice(deviceId)); - ASSERT_EQ(ogUse + allocSize, MemoryCounter::getInstance()->allocatedGroup(MemoryType::HOST)); + DataBuffer buffer(allocSize, DataType::INT32); + // separately testing per-device limits and group limits + ASSERT_EQ(odUse + allocSize, + MemoryCounter::getInstance()->allocatedDevice(deviceId)); + ASSERT_EQ(ogUse + allocSize, + MemoryCounter::getInstance()->allocatedGroup(MemoryType::HOST)); - // setting smaller limits, to make sure next allocation fails with OOM exception - MemoryCounter::getInstance()->setDeviceLimit(deviceId, allocSize - 100); - MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, allocSize - 100); + // setting smaller limits, to make sure next allocation fails with OOM + // exception + MemoryCounter::getInstance()->setDeviceLimit(deviceId, allocSize - 100); + MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, + allocSize - 100); - try { - DataBuffer bufferFailed(allocSize, DataType::INT32); - ASSERT_TRUE(false); - } catch (allocation_exception &e) { - // we expect exception here - } + try { + DataBuffer bufferFailed(allocSize, DataType::INT32); + ASSERT_TRUE(false); + } catch (allocation_exception &e) { + // we expect exception here + } - // restore original limits, so subsequent tests do not fail - MemoryCounter::getInstance()->setDeviceLimit(deviceId, odLimit); - MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, odLimit); + // restore original limits, so subsequent tests do not fail + MemoryCounter::getInstance()->setDeviceLimit(deviceId, odLimit); + MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, odLimit); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DataBufferTestsCuda.cu b/libnd4j/tests_cpu/layers_tests/DataBufferTestsCuda.cu index 730ade82452f..6fb0caf005a3 100644 --- a/libnd4j/tests_cpu/layers_tests/DataBufferTestsCuda.cu +++ b/libnd4j/tests_cpu/layers_tests/DataBufferTestsCuda.cu @@ -18,24 +18,24 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include #include #include #include #include +#include #include -#include #include -#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; using namespace sd::memory; class DataBufferTestsCuda : public testing::Test { -public: - + public: }; /* @@ -50,25 +50,30 @@ TEST_F(DataBufferTestsCuda, test_alloc_limit_1) { auto odUse = MemoryCounter::getInstance()->allocatedDevice(deviceId); auto opUse = MemoryCounter::getInstance()->allocatedGroup(MemoryType::HOST); - auto osUse = MemoryCounter::getInstance()->allocatedGroup(MemoryType::DEVICE); + auto osUse = +MemoryCounter::getInstance()->allocatedGroup(MemoryType::DEVICE); auto limitSize = odUse + 150000000; auto allocSize = 100000000; MemoryCounter::getInstance()->setDeviceLimit(deviceId, odLimit + limitSize); - MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, opLimit + limitSize); - MemoryCounter::getInstance()->setGroupLimit(MemoryType::DEVICE, osLimit + limitSize); + MemoryCounter::getInstance()->setGroupLimit(MemoryType::HOST, opLimit + +limitSize); MemoryCounter::getInstance()->setGroupLimit(MemoryType::DEVICE, +osLimit + limitSize); DataBuffer buffer(allocSize, DataType::INT32, nullptr, true); // separately testing per-device limits and group limits - ASSERT_EQ(odUse + allocSize, MemoryCounter::getInstance()->allocatedDevice(deviceId)); - ASSERT_EQ(opUse + allocSize, MemoryCounter::getInstance()->allocatedGroup(MemoryType::HOST)); - ASSERT_EQ(osUse + allocSize, MemoryCounter::getInstance()->allocatedGroup(MemoryType::DEVICE)); - - // setting smaller limits, to make sure next allocation fails with OOM exception - MemoryCounter::getInstance()->setDeviceLimit(deviceId, allocSize - 100); - MemoryCounter::getInstance()->setGroupLimit(MemoryType::DEVICE, allocSize - 100); + ASSERT_EQ(odUse + allocSize, +MemoryCounter::getInstance()->allocatedDevice(deviceId)); ASSERT_EQ(opUse + +allocSize, MemoryCounter::getInstance()->allocatedGroup(MemoryType::HOST)); + ASSERT_EQ(osUse + allocSize, +MemoryCounter::getInstance()->allocatedGroup(MemoryType::DEVICE)); + + // setting smaller limits, to make sure next allocation fails with OOM +exception MemoryCounter::getInstance()->setDeviceLimit(deviceId, allocSize - +100); MemoryCounter::getInstance()->setGroupLimit(MemoryType::DEVICE, allocSize +- 100); // this allocation should fail, since we're allocating too much diff --git a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp index be6440b707e4..ad4740a880f4 100644 --- a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp @@ -18,139 +18,150 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include #include #include #include #include +#include #include -#include #include -#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class DataTypesValidationTests : public testing::Test { -public: - + public: }; TEST_F(DataTypesValidationTests, Basic_Test_1) { - auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 4, 1, 4}, + {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); - weights.assign(2.0); - input.linspace(1); + weights.assign(2.0); + input.linspace(1); - sd::ops::conv2d op; - auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); + sd::ops::conv2d op; + auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); - ASSERT_EQ(ND4J_STATUS_VALIDATION, result.status()); + ASSERT_EQ(ND4J_STATUS_VALIDATION, result.status()); } TEST_F(DataTypesValidationTests, Basic_Test_2) { - auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 4, 1, 4}, + {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); - weights.assign(2.0); - input.linspace(1); + weights.assign(2.0); + input.linspace(1); - sd::ops::conv2d op; - auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::conv2d op; + auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DataTypesValidationTests, Basic_Test_3) { - auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); - auto out = NDArrayFactory::create('c', {1, 4, 1, 4}); - - weights.assign(2.0); - input.linspace(1); - - sd::ops::conv2d op; - auto result = op.execute({&input, &weights}, {&out}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); - ASSERT_EQ(Status::OK(), result); - - ASSERT_EQ(exp, out); + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 4, 1, 4}, + {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + auto out = NDArrayFactory::create('c', {1, 4, 1, 4}); + + weights.assign(2.0); + input.linspace(1); + + sd::ops::conv2d op; + auto result = op.execute({&input, &weights}, {&out}, {}, + {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); + ASSERT_EQ(Status::OK(), result); + + ASSERT_EQ(exp, out); } TEST_F(DataTypesValidationTests, Basic_Test_4) { - auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); - auto out = NDArrayFactory::create('c', {1, 4, 1, 4}); - - weights.assign(2.0); - input.linspace(1); - - sd::ops::conv2d op; - auto result = op.execute({&input, &weights}, {&out}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); - ASSERT_EQ(ND4J_STATUS_VALIDATION, result); + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 4, 1, 4}, + {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + auto out = NDArrayFactory::create('c', {1, 4, 1, 4}); + + weights.assign(2.0); + input.linspace(1); + + sd::ops::conv2d op; + auto result = op.execute({&input, &weights}, {&out}, {}, + {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); + ASSERT_EQ(ND4J_STATUS_VALIDATION, result); } TEST_F(DataTypesValidationTests, test_bfloat16_rand_1) { - auto x = NDArrayFactory::create('c', {5, 10}); - RandomGenerator gen(119, 120); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), gen, &x, 1, 6); + auto x = NDArrayFactory::create('c', {5, 10}); + RandomGenerator gen(119, 120); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), gen, &x, 1, 6); - ASSERT_TRUE(x.sumNumber().e(0) != 0.f); + ASSERT_TRUE(x.sumNumber().e(0) != 0.f); } TEST_F(DataTypesValidationTests, test_bfloat16_rand_2) { - auto x = NDArrayFactory::create('c', {5, 10}); - RandomGenerator gen(119, 120); - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), gen, &x, 0, 1); + auto x = NDArrayFactory::create('c', {5, 10}); + RandomGenerator gen(119, 120); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), gen, &x, 0, 1); - ASSERT_TRUE(x.sumNumber().e(0) != 0.f); + ASSERT_TRUE(x.sumNumber().e(0) != 0.f); } TEST_F(DataTypesValidationTests, cast_1) { + float16 x = static_cast(1.f); + float y = static_cast(x); - float16 x = static_cast(1.f); - float y = static_cast(x); - - ASSERT_TRUE(static_cast(1.f) == x); - ASSERT_TRUE(y == static_cast(x)); + ASSERT_TRUE(static_cast(1.f) == x); + ASSERT_TRUE(y == static_cast(x)); } TEST_F(DataTypesValidationTests, test_bits_hamming_distance_1) { - auto x = NDArrayFactory::create('c', {3}, {0b01011000, 0b01011111, 0b01111110}); - auto y = NDArrayFactory::create('c', {3}, {0b00010110, 0b01011000, 0b01011000}); - auto z = NDArrayFactory::create(0); - - Context ctx(1); - ctx.setInputArray(0, x); - ctx.setInputArray(1, y); - ctx.setOutputArray(0, z); - - sd::ops::bits_hamming_distance op; - auto status = op.execute(&ctx); - ASSERT_NE(Status::OK(), status); + auto x = NDArrayFactory::create('c', {3}, + {0b01011000, 0b01011111, 0b01111110}); + auto y = NDArrayFactory::create('c', {3}, + {0b00010110, 0b01011000, 0b01011000}); + auto z = NDArrayFactory::create(0); + + Context ctx(1); + ctx.setInputArray(0, x); + ctx.setInputArray(1, y); + ctx.setOutputArray(0, z); + + sd::ops::bits_hamming_distance op; + auto status = op.execute(&ctx); + ASSERT_NE(Status::OK(), status); } TEST_F(DataTypesValidationTests, test_bits_hamming_distance_2) { - auto x = NDArrayFactory::create('c', {3}, {0b01011000, 0b01011111, 0b01111110}); - auto y = NDArrayFactory::create('c', {3}, {0b00010110, 0b01011000, 0b01011000}); - auto z = NDArrayFactory::create(0); - - Context ctx(1); - ctx.setInputArray(0, x); - ctx.setInputArray(1, y); - ctx.setOutputArray(0, z); - - sd::ops::bits_hamming_distance op; - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + auto x = NDArrayFactory::create('c', {3}, + {0b01011000, 0b01011111, 0b01111110}); + auto y = NDArrayFactory::create('c', {3}, + {0b00010110, 0b01011000, 0b01011000}); + auto z = NDArrayFactory::create(0); + + Context ctx(1); + ctx.setInputArray(0, x); + ctx.setInputArray(1, y); + ctx.setOutputArray(0, z); + + sd::ops::bits_hamming_distance op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 73ac7e3bfa2c..7b4fc8bb0d75 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -14,78 +14,80 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author raver119@gmail.com - // +// +// @author raver119@gmail.com +// -#include "testlayers.h" +#include +#include #include -#include #include #include -#include -#include +#include #include -#include -#include #include +#include +#include #include -#include + +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class DeclarableOpsTests1 : public testing::Test { -public: - - const int bS = 2; // batch size - const int iD = 1; // input depth (number of picture channels, for example rgb=3) - const int iH = 28; // picture height in pixels - const int iW = 28; // picture width in pixels - const int oD = 3; // output depth (= N for dense layer) - const int kH = 5; // kernel height in pixels - const int kW = 5; // kernel width in pixels - const int sH = 1; // stride step in horizontal direction - const int sW = 1; // stride step in vertical direction - const int pH = 0; // padding height - const int pW = 0; // padding width - const int dH = 2; // dilation height - const int dW = 2; // dilation width - const int oH = (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height - const int oW = (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width - - DeclarableOpsTests1() { - sd::memory::MemoryTracker::getInstance()->reset(); - } - - ~DeclarableOpsTests1() { - sd::memory::MemoryTracker::getInstance()->summarize(); - } + public: + const int bS = 2; // batch size + const int iD = + 1; // input depth (number of picture channels, for example rgb=3) + const int iH = 28; // picture height in pixels + const int iW = 28; // picture width in pixels + const int oD = 3; // output depth (= N for dense layer) + const int kH = 5; // kernel height in pixels + const int kW = 5; // kernel width in pixels + const int sH = 1; // stride step in horizontal direction + const int sW = 1; // stride step in vertical direction + const int pH = 0; // padding height + const int pW = 0; // padding width + const int dH = 2; // dilation height + const int dW = 2; // dilation width + const int oH = + (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = + (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width + + DeclarableOpsTests1() { sd::memory::MemoryTracker::getInstance()->reset(); } + + ~DeclarableOpsTests1() { + sd::memory::MemoryTracker::getInstance()->summarize(); + } }; template class TypedDeclarableOpsTests1 : public testing::Test { -public: - - const int bS = 2; // batch size - const int iD = 1; // input depth (number of picture channels, for example rgb=3) - const int iH = 28; // picture height in pixels - const int iW = 28; // picture width in pixels - const int oD = 3; // output depth (= N for dense layer) - const int kH = 5; // kernel height in pixels - const int kW = 5; // kernel width in pixels - const int sH = 1; // stride step in horizontal direction - const int sW = 1; // stride step in vertical direction - const int pH = 0; // padding height - const int pW = 0; // padding width - const int dH = 2; // dilation height - const int dW = 2; // dilation width - const int oH = (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height - const int oW = (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width - - TypedDeclarableOpsTests1() { - printf("\n"); - } + public: + const int bS = 2; // batch size + const int iD = + 1; // input depth (number of picture channels, for example rgb=3) + const int iH = 28; // picture height in pixels + const int iW = 28; // picture width in pixels + const int oD = 3; // output depth (= N for dense layer) + const int kH = 5; // kernel height in pixels + const int kW = 5; // kernel width in pixels + const int sH = 1; // stride step in horizontal direction + const int sW = 1; // stride step in vertical direction + const int pH = 0; // padding height + const int pW = 0; // padding width + const int dH = 2; // dilation height + const int dW = 2; // dilation width + const int oH = + (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = + (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width + + TypedDeclarableOpsTests1() { printf("\n"); } }; typedef ::testing::Types TestingTypes; @@ -93,1607 +95,1651 @@ TYPED_TEST_CASE(TypedDeclarableOpsTests1, TestingTypes); ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, BasicInitialization1) { - auto concat = new sd::ops::concat(); - std::string expName("concat"); - ASSERT_EQ(expName, concat->getOpName()); - - auto x0 = NDArrayFactory::create('c', { 1, 5 }); - auto x1 = NDArrayFactory::create('c', { 1, 5 }); - auto x2 = NDArrayFactory::create('c', { 1, 5 }); - auto x3 = NDArrayFactory::create('c', { 1, 5 }); - auto x4 = NDArrayFactory::create('c', { 1, 5 }); - - x0.assign(1.0f); - x1.assign(1.0f); - x2.assign(1.0f); - x3.assign(1.0f); - x4.assign(1.0f); + auto concat = new sd::ops::concat(); + std::string expName("concat"); + ASSERT_EQ(expName, concat->getOpName()); - auto variableSpace = new VariableSpace(); + auto x0 = NDArrayFactory::create('c', {1, 5}); + auto x1 = NDArrayFactory::create('c', {1, 5}); + auto x2 = NDArrayFactory::create('c', {1, 5}); + auto x3 = NDArrayFactory::create('c', {1, 5}); + auto x4 = NDArrayFactory::create('c', {1, 5}); - variableSpace->putVariable(-1, x0); - variableSpace->putVariable(-2, x1); - variableSpace->putVariable(-3, x2); - variableSpace->putVariable(-4, x3); - variableSpace->putVariable(-5, x4); + x0.assign(1.0f); + x1.assign(1.0f); + x2.assign(1.0f); + x3.assign(1.0f); + x4.assign(1.0f); - auto nodeVar = std::make_shared(); + auto variableSpace = new VariableSpace(); - variableSpace->putVariable(1, nodeVar); + variableSpace->putVariable(-1, x0); + variableSpace->putVariable(-2, x1); + variableSpace->putVariable(-3, x2); + variableSpace->putVariable(-4, x3); + variableSpace->putVariable(-5, x4); - Context block(1, variableSpace); - block.appendI(1); - block.fillInputs({ -1, -2, -3, -4, -5 }); + auto nodeVar = std::make_shared(); - ASSERT_FALSE(nodeVar->hasNDArray()); + variableSpace->putVariable(1, nodeVar); - Nd4jStatus result = concat->execute(&block); + Context block(1, variableSpace); + block.appendI(1); + block.fillInputs({-1, -2, -3, -4, -5}); - ASSERT_TRUE(nodeVar->hasNDArray()); + ASSERT_FALSE(nodeVar->hasNDArray()); - ASSERT_EQ(25, nodeVar->getNDArray()->lengthOf()); + Nd4jStatus result = concat->execute(&block); - ASSERT_NEAR(25.0, nodeVar->getNDArray()->reduceNumber(reduce::Sum).e(0), 1e-5); + ASSERT_TRUE(nodeVar->hasNDArray()); - ASSERT_EQ(ND4J_STATUS_OK, result); + ASSERT_EQ(25, nodeVar->getNDArray()->lengthOf()); + ASSERT_NEAR(25.0, + nodeVar->getNDArray()->reduceNumber(reduce::Sum).e(0), + 1e-5); - delete variableSpace; - delete concat; + ASSERT_EQ(ND4J_STATUS_OK, result); + + delete variableSpace; + delete concat; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, BasicInitialization2) { - auto op = sd::ops::OpRegistrator::getInstance()->getOperation("concat"); + auto op = sd::ops::OpRegistrator::getInstance()->getOperation("concat"); - ASSERT_TRUE(op != nullptr); - std::string expName("concat"); - ASSERT_EQ(expName, op->getOpName()); + ASSERT_TRUE(op != nullptr); + std::string expName("concat"); + ASSERT_EQ(expName, op->getOpName()); - ASSERT_EQ(-1, op->getOpDescriptor()->getNumberOfInputs()); - ASSERT_EQ(1, op->getOpDescriptor()->getNumberOfOutputs()); + ASSERT_EQ(-1, op->getOpDescriptor()->getNumberOfInputs()); + ASSERT_EQ(1, op->getOpDescriptor()->getNumberOfOutputs()); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ApplyGradientDescent_1) { - auto x = NDArrayFactory::create('c', { 3,4 }, { 1,2,3,4,5,6,7,8,9,10,11,12 }); - auto y = NDArrayFactory::create('c', { 3,4 }, { 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0,1.1,1.2 }); - auto exp = NDArrayFactory::create('c', { 3,4 }); - exp.linspace(0.9, 0.9); - sd::ops::apply_sgd op; - auto result = op.evaluate({ &x, &y }, { 1. }, {}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto z = result.at(0); - - ASSERT_TRUE(z.equalsTo(exp)); + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto y = NDArrayFactory::create( + 'c', {3, 4}, + {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2}); + auto exp = NDArrayFactory::create('c', {3, 4}); + exp.linspace(0.9, 0.9); + sd::ops::apply_sgd op; + auto result = op.evaluate({&x, &y}, {1.}, {}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto z = result.at(0); + ASSERT_TRUE(z.equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AssignBroadcastTest_1) { - auto x = NDArrayFactory::create('c', { 3,4 }, { 1,2,3,4,5,6,7,8,9,10,11,12 }); - auto y = NDArrayFactory::create('c', { 1,4 }, { 0.1,0.2,0.3,0.4 }); - auto exp = NDArrayFactory::create('c', { 3,4 }, { 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4 }); - sd::ops::assign op; - auto result = op.evaluate({ &x, &y }); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto z = result.at(0); - - ASSERT_TRUE(z.equalsTo(exp)); + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto y = NDArrayFactory::create('c', {1, 4}, {0.1, 0.2, 0.3, 0.4}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4}); + sd::ops::assign op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto z = result.at(0); + ASSERT_TRUE(z.equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AssignBroadcastTest_2) { - auto x = NDArrayFactory::create('c', { 3,4 }, { 1,2,3,4,5,6,7,8,9,10,11,12 }); - auto y = NDArrayFactory::create('c', { 1,4 }, { 0.1,0.2,0.3,0.4 }); - auto eps = NDArrayFactory::create('c', { 3,4 }, { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }); - auto exp1 = NDArrayFactory::create('c', { 3,4 }); // zero - auto exp2 = NDArrayFactory::create('c', { 1,4 }, { 3, 6, 9, 12 }); - sd::ops::assign_bp op; - auto result = op.evaluate({ &x, &y, &eps }); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto z1 = result.at(0); - auto z2 = result.at(1); - - ASSERT_TRUE(z1.equalsTo(exp1)); - ASSERT_TRUE(z2.equalsTo(exp2)); + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto y = NDArrayFactory::create('c', {1, 4}, {0.1, 0.2, 0.3, 0.4}); + auto eps = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}); + auto exp1 = NDArrayFactory::create('c', {3, 4}); // zero + auto exp2 = NDArrayFactory::create('c', {1, 4}, {3, 6, 9, 12}); + sd::ops::assign_bp op; + auto result = op.evaluate({&x, &y, &eps}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto z1 = result.at(0); + auto z2 = result.at(1); + ASSERT_TRUE(z1.equalsTo(exp1)); + ASSERT_TRUE(z2.equalsTo(exp2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AXpY_Test_1) { - auto x = NDArrayFactory::create('c', { 3,4 }, { 1,2,3,4,5,6,7,8,9,10,11,12 }); - auto y = NDArrayFactory::create('c', { 3,4 }, { 1,2,3,4,5,6,7,8,9,10,11,12 }); - auto exp = NDArrayFactory::create('c', { 3,4 }); - exp.linspace(3, 3); - sd::ops::axpy op; - auto result = op.evaluate({ &x, &y }, { 2. }); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto z = result.at(0); - - ASSERT_TRUE(z.equalsTo(exp)); + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto y = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create('c', {3, 4}); + exp.linspace(3, 3); + sd::ops::axpy op; + auto result = op.evaluate({&x, &y}, {2.}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto z = result.at(0); + ASSERT_TRUE(z.equalsTo(exp)); } TEST_F(DeclarableOpsTests1, BasicInitialization3) { - auto op1 = sd::ops::OpRegistrator::getInstance()->getOperation("concat"); - std::string expName("concat"); - auto hash = sd::ops::HashHelper::getInstance()->getLongHash(expName); + auto op1 = sd::ops::OpRegistrator::getInstance()->getOperation("concat"); + std::string expName("concat"); + auto hash = sd::ops::HashHelper::getInstance()->getLongHash(expName); - auto op2 = sd::ops::OpRegistrator::getInstance()->getOperation(hash); + auto op2 = sd::ops::OpRegistrator::getInstance()->getOperation(hash); - ASSERT_TRUE(op1 == op2); + ASSERT_TRUE(op1 == op2); } - TEST_F(DeclarableOpsTests1, SynonymInitialization2) { - auto op = sd::ops::OpRegistrator::getInstance()->getOperation("Mul"); - auto op2 = sd::ops::OpRegistrator::getInstance()->getOperation("multiply"); + auto op = sd::ops::OpRegistrator::getInstance()->getOperation("Mul"); + auto op2 = sd::ops::OpRegistrator::getInstance()->getOperation("multiply"); - ASSERT_TRUE(op != nullptr); - std::string expName("multiply"); - ASSERT_EQ(expName, op->getOpName()); - ASSERT_TRUE(op == op2); + ASSERT_TRUE(op != nullptr); + std::string expName("multiply"); + ASSERT_EQ(expName, op->getOpName()); + ASSERT_TRUE(op == op2); } - TEST_F(DeclarableOpsTests1, TestTensorMmul1) { + NDArray x('c', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray y('c', {2, 3, 4}, sd::DataType::FLOAT32); - NDArray x('c', { 2, 3, 4 }, sd::DataType::FLOAT32); - NDArray y('c', { 2, 3, 4 }, sd::DataType::FLOAT32); - - x.linspace(1); - y.linspace(1); - - NDArray exp('c', { 2, 2 }, { 650.0, 1586.0, 1586.0, 4250.0 }, sd::DataType::FLOAT32); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,1,2,2,1,2 }); + x.linspace(1); + y.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + NDArray exp('c', {2, 2}, {650.0, 1586.0, 1586.0, 4250.0}, + sd::DataType::FLOAT32); - auto out = results.at(0); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 1, 2, 2, 1, 2}); - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto out = results.at(0); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } TEST_F(DeclarableOpsTests1, TestTensorDot2) { + NDArray x('f', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}, + sd::DataType::FLOAT32); + NDArray y('f', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}, + sd::DataType::FLOAT32); - NDArray x('f', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); - NDArray y('f', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); + NDArray exp('c', {2, 2}, {2300.0, 2444.0, 2444.0, 2600.0}, + sd::DataType::FLOAT32); - NDArray exp('c', { 2, 2 }, { 2300.0, 2444.0, 2444.0, 2600.0 }, sd::DataType::FLOAT32); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 1, 2, 2, 1, 2}); - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,1,2,2,1,2 }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto out = results.at(0); - auto out = results.at(0); - - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } TEST_F(DeclarableOpsTests1, TestTensorDot3) { + NDArray x('c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}, + sd::DataType::FLOAT32); + NDArray y('f', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}, + sd::DataType::FLOAT32); - NDArray x('c', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); - NDArray y('f', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); - - NDArray exp('f', { 2, 2 }, { 1090.0, 2818.0, 1168.0, 3040.0 }, sd::DataType::FLOAT32); + NDArray exp('f', {2, 2}, {1090.0, 2818.0, 1168.0, 3040.0}, + sd::DataType::FLOAT32); - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,1,2,2,1,2 }); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 1, 2, 2, 1, 2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto out = results.at(0); + auto out = results.at(0); - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } TEST_F(DeclarableOpsTests1, TestTensorDot4) { + NDArray x('f', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}, + sd::DataType::FLOAT32); + NDArray y('c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}, + sd::DataType::FLOAT32); - NDArray x('f', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); - NDArray y('c', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); - - NDArray exp('f', { 2, 2 }, { 1090.0, 1168.0, 2818.0, 3040.0 }, sd::DataType::FLOAT32); + NDArray exp('f', {2, 2}, {1090.0, 1168.0, 2818.0, 3040.0}, + sd::DataType::FLOAT32); - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,1,2,2,1,2 }); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 1, 2, 2, 1, 2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto out = results.at(0); - - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto out = results.at(0); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot5) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'c', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {2, 4, 2, 4}, + {44, 110, 160, 66, 132, 38, 88, 154, 68, 170, 224, 102, 204, + 82, 136, 238, 92, 230, 288, 138, 276, 126, 184, 322, 116, 290, + 352, 174, 348, 170, 232, 406, 76, 190, 160, 114, 228, 182, 152, + 266, 100, 250, 224, 150, 300, 226, 200, 350, 124, 310, 288, 186, + 372, 270, 248, 434, 148, 370, 352, 222, 444, 314, 296, 518}); - auto x = NDArrayFactory::create('c', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('c', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 2,4,2,4 }, { 44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 1,1,1,2 }); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1, 1, 2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot6) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'f', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {2, 4, 2, 4}, + {22, 66, 110, 154, 44, 88, 132, 176, 34, 102, 170, 238, 68, + 136, 204, 272, 46, 138, 230, 322, 92, 184, 276, 368, 58, 174, + 290, 406, 116, 232, 348, 464, 38, 114, 190, 266, 76, 152, 228, + 304, 50, 150, 250, 350, 100, 200, 300, 400, 62, 186, 310, 434, + 124, 248, 372, 496, 74, 222, 370, 518, 148, 296, 444, 592}); - auto x = NDArrayFactory::create('c', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('f', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 2,4,2,4 }, { 22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 1,1,1,2 }); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1, 1, 2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot7) { + auto x = NDArrayFactory::create( + 'f', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'c', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {2, 4, 2, 4}, + {76, 166, 112, 106, 196, 62, 136, 226, 60, 174, 208, 98, 212, + 230, 136, 250, 76, 214, 336, 122, 260, 174, 168, 306, 124, 286, + 240, 178, 340, 150, 232, 394, 100, 226, 176, 142, 268, 106, 184, + 310, 84, 234, 272, 134, 284, 274, 184, 334, 100, 274, 400, 158, + 332, 218, 216, 390, 148, 346, 304, 214, 412, 194, 280, 478}); - auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('c', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 2,4,2,4 }, { 76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 1,1,1,2 }); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1, 1, 2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot8) { + auto x = NDArrayFactory::create( + 'f', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'f', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {2, 4, 2, 4}, + {30, 90, 150, 210, 60, 120, 180, 240, 38, 114, 190, 266, 76, + 152, 228, 304, 46, 138, 230, 322, 92, 184, 276, 368, 54, 162, + 270, 378, 108, 216, 324, 432, 42, 126, 210, 294, 84, 168, 252, + 336, 50, 150, 250, 350, 100, 200, 300, 400, 58, 174, 290, 406, + 116, 232, 348, 464, 66, 198, 330, 462, 132, 264, 396, 528}); - auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('f', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 2,4,2,4 }, { 30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528 }); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1, 1, 2}); - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 1,1,1,2 }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot9) { - - // NDArray z('f',{2,2,3}, sd::DataType::DOUBLE); - // z.linspace(1); - // z.printShapeInfo(); - // z.printIndexedBuffer(); - // z.reshapei('c', {4,3}); - // z.printShapeInfo(); - // z.printIndexedBuffer(); - - auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('f', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 3,4,4,3 }, { 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 1,0,1,0 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + // NDArray z('f',{2,2,3}, sd::DataType::DOUBLE); + // z.linspace(1); + // z.printShapeInfo(); + // z.printIndexedBuffer(); + // z.reshapei('c', {4,3}); + // z.printShapeInfo(); + // z.printIndexedBuffer(); + + auto x = NDArrayFactory::create( + 'f', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'f', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {3, 4, 4, 3}, + {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, + 86, 198, 198, 198, 310, 310, 310, 422, 422, 422, 62, 62, 62, 142, + 142, 142, 222, 222, 222, 302, 302, 302, 38, 38, 38, 86, 86, 86, + 134, 134, 134, 182, 182, 182, 38, 38, 38, 86, 86, 86, 134, 134, + 134, 182, 182, 182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, + 62, 62, 86, 86, 86, 198, 198, 198, 310, 310, 310, 422, 422, 422, + 62, 62, 62, 142, 142, 142, 222, 222, 222, 302, 302, 302, 62, 62, + 62, 142, 142, 142, 222, 222, 222, 302, 302, 302, 38, 38, 38, 86, + 86, 86, 134, 134, 134, 182, 182, 182, 14, 14, 14, 30, 30, 30, + 46, 46, 46, 62, 62, 62, 86, 86, 86, 198, 198, 198, 310, 310, + 310, 422, 422, 422}); + + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0, 1, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot10) { + auto x = NDArrayFactory::create( + 'f', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'f', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = + NDArrayFactory::create('c', {4, 4}, + {114, 258, 402, 546, 138, 314, 490, 666, + 162, 370, 578, 786, 186, 426, 666, 906}); - auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('f', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 4,4 }, { 114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,0,1, 2,0,2 }); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 0, 1, 2, 0, 2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot11) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'f', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = + NDArrayFactory::create('c', {4, 4}, + {98, 218, 338, 458, 134, 302, 470, 638, + 170, 386, 602, 818, 206, 470, 734, 998}); - auto x = NDArrayFactory::create('c', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('f', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 4,4 }, { 98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,0,1, 2,0,2 }); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 0, 1, 2, 0, 2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot12) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'c', {2, 4, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = + NDArrayFactory::create('c', {4, 4}, + {272, 292, 312, 332, 368, 396, 424, 452, + 464, 500, 536, 572, 560, 604, 648, 692}); - auto x = NDArrayFactory::create('c', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('c', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 4,4 }, { 272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,0,1, 2,0,2 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 0, 1, 2, 0, 2}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot13) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'c', {4, 2, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {640, 560, 640, 576, 624, 576, 640, 560, 640}); - auto x = NDArrayFactory::create('c', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('c', { 4,2,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 3,3 }, { 640,560,640, 576,624,576, 640,560,640 }); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 0, 2, 2, 1, 0}); - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,0,2, 2,1,0 }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot14) { + auto x = NDArrayFactory::create( + 'f', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'c', {4, 2, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {648, 600, 520, 648, 536, 648, 520, 600, 648}); - auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('c', { 4,2,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 3,3 }, { 648,600,520, 648,536,648, 520,600,648 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,0,2, 2,1,0 }); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 0, 2, 2, 1, 0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot15) { + auto x = NDArrayFactory::create( + 'f', {2, 3, 4}, {1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, + 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15}); + auto y = NDArrayFactory::create( + 'f', {4, 2, 3}, {2, 4, 6, 8, 10, 12, 14, 16, 2, 4, 6, 8, + 10, 12, 14, 16, 2, 4, 6, 8, 10, 12, 14, 16}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {624, 624, 624, 656, 656, 656, 624, 624, 624}); - auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); - auto y = NDArrayFactory::create('f', { 4,2,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); - auto expected = NDArrayFactory::create('c', { 3,3 }, { 624,624,624, 656,656,656, 624,624,624 }); - - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 2,0,2, 2,1,0 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {2, 0, 2, 2, 1, 0}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot16) { + NDArray x('c', {1}, std::vector{2}, sd::DataType::FLOAT32); + NDArray y('c', {2, 1, 2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray exp('c', {2, 2}, {2, 4, 6, 8}, sd::DataType::FLOAT32); - NDArray x('c', { 1 }, std::vector{2}, sd::DataType::FLOAT32); - NDArray y('c', { 2,1,2 }, { 1,2,3,4 }, sd::DataType::FLOAT32); - NDArray exp('c', { 2,2 }, { 2,4,6,8 }, sd::DataType::FLOAT32); + sd::ops::tensormmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0, 1, 1}); - sd::ops::tensormmul op; - auto results = op.evaluate({ &x, &y }, {}, { 1,0, 1,1 }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestTensorDot17) { + NDArray x('f', {16, 16}, sd::DataType::FLOAT32); + NDArray y('f', {1000, 16}, sd::DataType::FLOAT32); + NDArray z('c', {16, 1000}, sd::DataType::FLOAT32); - NDArray x('f', { 16,16 }, sd::DataType::FLOAT32); - NDArray y('f', { 1000,16 }, sd::DataType::FLOAT32); - NDArray z('c', { 16,1000 }, sd::DataType::FLOAT32); + sd::ops::tensormmul op; + auto status = op.execute({&x, &y}, {&z}, {}, {1, 1, 1, 1}, {}); - sd::ops::tensormmul op; - auto status = op.execute({ &x, &y }, { &z }, {}, { 1,1, 1,1 }, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, DivergentCheck1) { - auto op = sd::ops::OpRegistrator::getInstance()->getOperation("switch"); + auto op = sd::ops::OpRegistrator::getInstance()->getOperation("switch"); - ASSERT_TRUE(op != nullptr); - std::string expName("Switch"); - ASSERT_EQ(expName, op->getOpName()); - ASSERT_TRUE(op->getOpDescriptor()->isDivergent()); - ASSERT_EQ(2, op->getOpDescriptor()->getNumberOfOutputs()); + ASSERT_TRUE(op != nullptr); + std::string expName("Switch"); + ASSERT_EQ(expName, op->getOpName()); + ASSERT_TRUE(op->getOpDescriptor()->isDivergent()); + ASSERT_EQ(2, op->getOpDescriptor()->getNumberOfOutputs()); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AddMatrices1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {5, 3}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(2); + y.assign(1); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 5, 3 }); - auto y = NDArrayFactory::create('c', { 5, 3 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x.assign(2); - y.assign(1); - exp.assign(3); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::add addOp; + sd::ops::add addOp; - addOp.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + addOp.execute(block); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AddVectorVector1) { + auto x = NDArrayFactory::create('c', {1, 15}); + auto y = NDArrayFactory::create('c', {1, 15}); + auto exp = NDArrayFactory::create('c', {1, 15}); + x.assign(2); + y.assign(1); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 1, 15 }); - auto y = NDArrayFactory::create('c', { 1, 15 }); - auto exp = NDArrayFactory::create('c', { 1, 15 }); - x.assign(2); - y.assign(1); - exp.assign(3); - - auto variableSpace = new VariableSpace(); + auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); - sd::ops::add addOp; + block->fillInputs({-1, -2}); + sd::ops::add addOp; - addOp.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + addOp.execute(block); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AddMatrixScalar1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(2); + y.assign(1); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 5, 3 }); - auto y = NDArrayFactory::create('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x.assign(2); - y.assign(1); - exp.assign(3); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::add addOp; + sd::ops::add addOp; - addOp.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + addOp.execute(block); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, AddScalarScalar1) { + auto x = NDArrayFactory::create('c', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}); + x.assign(2); + y.assign(1); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 1, 1 }); - auto y = NDArrayFactory::create('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 1, 1 }); - x.assign(2); - y.assign(1); - exp.assign(3); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::add addOp; + sd::ops::add addOp; - addOp.execute(block); + addOp.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractMatrices1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {5, 3}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(3); + y.assign(1); + exp.assign(2); - auto x = NDArrayFactory::create('c', { 5, 3 }); - auto y = NDArrayFactory::create('c', { 5, 3 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x.assign(3); - y.assign(1); - exp.assign(2); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::subtract subOp; + sd::ops::subtract subOp; - subOp.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + subOp.execute(block); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractTest_1) { + auto x = NDArrayFactory::create('c', {1, 6}); + auto y = NDArrayFactory::create('c', {1, 6}); + auto exp = NDArrayFactory::create('c', {1, 6}); + x.assign(3); + y.assign(1); + exp.assign(2); - auto x = NDArrayFactory::create('c', { 1, 6 }); - auto y = NDArrayFactory::create('c', { 1, 6 }); - auto exp = NDArrayFactory::create('c', { 1, 6 }); - x.assign(3); - y.assign(1); - exp.assign(2); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::subtract subOp; + sd::ops::subtract subOp; - subOp.execute(block); + subOp.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractTest_2) { + auto x = NDArrayFactory::create('c', {3, 4, 5, 1}); + auto y = NDArrayFactory::create('c', {1, 6}); + // auto y({6}, {1,1,1,1,1,1}); + auto exp = NDArrayFactory::create('c', {3, 4, 5, 6}); + x.assign(3); + y.assign(1); + exp.assign(2); - auto x = NDArrayFactory::create('c', { 3, 4, 5, 1 }); - auto y = NDArrayFactory::create('c', { 1, 6 }); - // auto y({6}, {1,1,1,1,1,1}); - auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); - x.assign(3); - y.assign(1); - exp.assign(2); - - sd::ops::subtract subOp; + sd::ops::subtract subOp; - auto res = subOp.evaluate({ &x, &y }); + auto res = subOp.evaluate({&x, &y}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0).equalsTo(&exp)); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } TEST_F(DeclarableOpsTests1, TestRng1) { - /* - Nd4jLong *buffer = new Nd4jLong[100000]; + /* + Nd4jLong *buffer = new Nd4jLong[100000]; - sd::random::RandomBuffer *rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer); + sd::random::RandomBuffer *rng = (sd::random::RandomBuffer *) + initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer); - if (rng == nullptr) - throw std::runtime_error("RNG initialization failed"); + if (rng == nullptr) + throw std::runtime_error("RNG initialization failed"); - auto x = NDArrayFactory::create_('c', {5, 3}); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - auto block = new Context(1, variableSpace, true); - block->fillInputs({-1}); - block->setRNG(rng); - block->getTArguments()->push_back(0.0f); - block->getTArguments()->push_back(1.0f); + auto x = NDArrayFactory::create_('c', {5, 3}); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + auto block = new Context(1, variableSpace, true); + block->fillInputs({-1}); + block->setRNG(rng); + block->getTArguments()->push_back(0.0f); + block->getTArguments()->push_back(1.0f); - sd::ops::randomuniform uniform; + sd::ops::randomuniform uniform; - Nd4jStatus status = uniform.execute(block); + Nd4jStatus status = uniform.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(x->sumNumber() > 0.0); + ASSERT_TRUE(x->sumNumber() > 0.0); - destroyRandom((Nd4jPointer) rng); - delete[] buffer; + destroyRandom((Nd4jPointer) rng); + delete[] buffer; - delete variableSpace; - delete block; - */ + delete variableSpace; + delete block; + */ } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MergeSumTest1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.assign(3); + y.assign(1); + z.assign(2); + exp.assign(6); - auto x = NDArrayFactory::create('c', { 5, 5 }); - auto y = NDArrayFactory::create('c', { 5, 5 }); - auto z = NDArrayFactory::create('c', { 5, 5 }); - auto exp = NDArrayFactory::create('c', { 5, 5 }); - x.assign(3); - y.assign(1); - z.assign(2); - exp.assign(6); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + variableSpace->putVariable(-3, z); + variableSpace->putVariable(1, NDArrayFactory::create('c', {5, 5})); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2, -3}); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - variableSpace->putVariable(-3, z); - variableSpace->putVariable(1, NDArrayFactory::create('c', { 5, 5 })); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2, -3 }); + sd::ops::mergeadd merge; - sd::ops::mergeadd merge; + merge.execute(block); - merge.execute(block); + auto res = variableSpace->getVariable(1)->getNDArray(); - auto res = variableSpace->getVariable(1)->getNDArray(); + ASSERT_TRUE(res->equalsTo(&exp)); - ASSERT_TRUE(res->equalsTo(&exp)); - - delete variableSpace; - delete block; + delete variableSpace; + delete block; } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ClipByValue1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.assign(4); + x.p(0, -1); + x.p(1, 2); + exp.assign(3); + exp.p(0, 0); + exp.p(1, 2); - auto x = NDArrayFactory::create('c', { 5, 5 }); - auto exp = NDArrayFactory::create('c', { 5, 5 }); - x.assign(4); - x.p(0, -1); - x.p(1, 2); - exp.assign(3); - exp.p(0, 0); - exp.p(1, 2); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - auto block = new Context(1, variableSpace, false); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + auto block = new Context(1, variableSpace, false); - block->appendT(0.0f); - block->appendT(3.0f); + block->appendT(0.0f); + block->appendT(3.0f); - block->fillInputs({ -1 }); + block->fillInputs({-1}); - sd::ops::clipbyvalue clip; + sd::ops::clipbyvalue clip; - clip.execute(block); + clip.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MergeAvgTest1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.assign(3); + y.assign(1); + z.assign(2); + exp.assign(2); - auto x = NDArrayFactory::create('c', { 5, 5 }); - auto y = NDArrayFactory::create('c', { 5, 5 }); - auto z = NDArrayFactory::create('c', { 5, 5 }); - auto exp = NDArrayFactory::create('c', { 5, 5 }); - x.assign(3); - y.assign(1); - z.assign(2); - exp.assign(2); - - auto zu = NDArrayFactory::create('c', { 5, 5 }); + auto zu = NDArrayFactory::create('c', {5, 5}); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - variableSpace->putVariable(-3, z); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2, -3 }); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + variableSpace->putVariable(-3, z); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2, -3}); - sd::ops::mergeavg merge; + sd::ops::mergeavg merge; - merge.execute(block); + merge.execute(block); - auto res = variableSpace->getVariable(1)->getNDArray(); + auto res = variableSpace->getVariable(1)->getNDArray(); - ASSERT_TRUE(res->equalsTo(exp)); + ASSERT_TRUE(res->equalsTo(exp)); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractVectorVector1) { + auto x = NDArrayFactory::create('c', {1, 15}); + auto y = NDArrayFactory::create('c', {1, 15}); + auto exp = NDArrayFactory::create('c', {1, 15}); + x.assign(3); + y.assign(1); + exp.assign(2); - auto x = NDArrayFactory::create('c', { 1, 15 }); - auto y = NDArrayFactory::create('c', { 1, 15 }); - auto exp = NDArrayFactory::create('c', { 1, 15 }); - x.assign(3); - y.assign(1); - exp.assign(2); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::subtract subOp; + sd::ops::subtract subOp; - subOp.execute(block); + subOp.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractMatrixScalar1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(3); + y.assign(1); + exp.assign(2); - auto x = NDArrayFactory::create('c', { 5, 3 }); - auto y = NDArrayFactory::create('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x.assign(3); - y.assign(1); - exp.assign(2); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::subtract subOp; + sd::ops::subtract subOp; - subOp.execute(block); + subOp.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, SubtractScalarScalar1) { + auto x = NDArrayFactory::create('c', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}); + x.assign(3); + y.assign(1); + exp.assign(2); - auto x = NDArrayFactory::create('c', { 1, 1 }); - auto y = NDArrayFactory::create('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 1, 1 }); - x.assign(3); - y.assign(1); - exp.assign(2); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::subtract subOp; + sd::ops::subtract subOp; - subOp.execute(block); + subOp.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractMatrices1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {5, 3}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(3.f); + y.assign(1.f); + exp.assign(-2.f); - auto x = NDArrayFactory::create('c', { 5, 3 }); - auto y = NDArrayFactory::create('c', { 5, 3 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x.assign(3.f); - y.assign(1.f); - exp.assign(-2.f); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversesubtract subOp; + sd::ops::reversesubtract subOp; - subOp.execute(block); + subOp.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractTest_1) { + auto x = NDArrayFactory::create('c', {1, 6}); + auto y = NDArrayFactory::create('c', {1, 6}); + auto exp = NDArrayFactory::create('c', {1, 6}); + x.assign(3.f); + y.assign(1.f); + exp.assign(-2.f); - auto x = NDArrayFactory::create('c', { 1, 6 }); - auto y = NDArrayFactory::create('c', { 1, 6 }); - auto exp = NDArrayFactory::create('c', { 1, 6 }); - x.assign(3.f); - y.assign(1.f); - exp.assign(-2.f); + sd::ops::reversesubtract subOp; - sd::ops::reversesubtract subOp; + auto res = subOp.evaluate({&x, &y}); - auto res = subOp.evaluate({ &x, &y }); - - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0).equalsTo(&exp)); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractTest_2) { + // auto x('c', {1, 6}); + auto x = NDArrayFactory::create('c', {1, 6}); + auto y = NDArrayFactory::create('c', {3, 4, 5, 1}); + auto exp = NDArrayFactory::create('c', {3, 4, 5, 6}); + auto z(exp); + x.assign(3.f); + y.assign(1.f); + exp.assign(-2.f); + x.applyTrueBroadcast(BROADCAST(ReverseSubtract), y, z, true); - // auto x('c', {1, 6}); - auto x = NDArrayFactory::create('c', { 1, 6 }); - auto y = NDArrayFactory::create('c', { 3, 4, 5, 1 }); - auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); - auto z(exp); - x.assign(3.f); - y.assign(1.f); - exp.assign(-2.f); - x.applyTrueBroadcast(BROADCAST(ReverseSubtract), y, z, true); - - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.equalsTo(z)); - sd::ops::reversesubtract subOp; + sd::ops::reversesubtract subOp; - auto res = subOp.evaluate({ &x, &y }); + auto res = subOp.evaluate({&x, &y}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0).equalsTo(&exp)); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractTest_3) { + // auto x('c', {1, 6}); + auto x = NDArrayFactory::create('c', {6}); + auto y = NDArrayFactory::create('c', {3, 4, 5, 1}); + auto exp = NDArrayFactory::create('c', {3, 4, 5, 6}); + auto z(exp); + x.assign(1); + y.assign(3); + exp.assign(2); + x.applyTrueBroadcast(BROADCAST(ReverseSubtract), y, z, true); + ASSERT_TRUE(z.equalsTo(&exp)); + sd::ops::reversesubtract subOp; - // auto x('c', {1, 6}); - auto x = NDArrayFactory::create('c', { 6 }); - auto y = NDArrayFactory::create('c', { 3, 4, 5, 1 }); - auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); - auto z(exp); - x.assign(1); - y.assign(3); - exp.assign(2); - x.applyTrueBroadcast(BROADCAST(ReverseSubtract), y, z, true); - ASSERT_TRUE(z.equalsTo(&exp)); - sd::ops::reversesubtract subOp; - - auto res = subOp.evaluate({ &x, &y }); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0).equalsTo(&exp)); + auto res = subOp.evaluate({&x, &y}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseModTest_1) { + // auto x('c', {1, 6}); + auto x = NDArrayFactory::create('c', {6}); + auto y = NDArrayFactory::create('c', {3, 4, 5, 1}); + auto exp = NDArrayFactory::create('c', {3, 4, 5, 6}); + auto z(exp); + x.assign(2.); + y.assign(9.f); + exp.assign(1.f); + y.applyTrueBroadcast(BROADCAST(Mod), x, z, true); + ASSERT_TRUE(exp.equalsTo(&z)); - // auto x('c', {1, 6}); - auto x = NDArrayFactory::create('c', { 6 }); - auto y = NDArrayFactory::create('c', { 3, 4, 5, 1 }); - auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); - auto z(exp); - x.assign(2.); - y.assign(9.f); - exp.assign(1.f); - y.applyTrueBroadcast(BROADCAST(Mod), x, z, true); - ASSERT_TRUE(exp.equalsTo(&z)); + x.applyTrueBroadcast(BROADCAST(ReverseMod), y, exp, true); + ASSERT_TRUE(exp.equalsTo(&z)); - x.applyTrueBroadcast(BROADCAST(ReverseMod), y, exp, true); - ASSERT_TRUE(exp.equalsTo(&z)); + sd::ops::reversemod subOp; - sd::ops::reversemod subOp; + auto res = subOp.evaluate({&x, &y}); - auto res = subOp.evaluate({ &x, &y }); - - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0).equalsTo(&exp)); - ASSERT_TRUE(exp.equalsTo(&z)); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); + ASSERT_TRUE(exp.equalsTo(&z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseModTest_2) { + // auto x('c', {1, 6}); + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto y = NDArrayFactory::create('c', {3, 4, 5}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}); + auto z(exp); + x.assign(2.f); + y.assign(9.f); + exp.assign(1.f); + x.applyTrueBroadcast(BROADCAST(ReverseMod), y, z, true); + ASSERT_TRUE(z.equalsTo(&exp)); + x.applyTrueBroadcast(BROADCAST(ReverseMod), y, exp, true); + ASSERT_TRUE(z.equalsTo(&exp)); - // auto x('c', {1, 6}); - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - auto y = NDArrayFactory::create('c', { 3, 4, 5 }); - auto exp = NDArrayFactory::create('c', { 3, 4, 5 }); - auto z(exp); - x.assign(2.f); - y.assign(9.f); - exp.assign(1.f); - x.applyTrueBroadcast(BROADCAST(ReverseMod), y, z, true); - ASSERT_TRUE(z.equalsTo(&exp)); - x.applyTrueBroadcast(BROADCAST(ReverseMod), y, exp, true); - ASSERT_TRUE(z.equalsTo(&exp)); - - sd::ops::reversemod subOp; + sd::ops::reversemod subOp; - auto res = subOp.evaluate({ &x, &y }); + auto res = subOp.evaluate({&x, &y}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0).equalsTo(&exp)); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractVectorVector1) { + auto x = NDArrayFactory::create('c', {1, 15}); + auto y = NDArrayFactory::create('c', {1, 15}); + auto exp = NDArrayFactory::create('c', {1, 15}); + x.assign(3); + y.assign(1); + exp.assign(-2); - auto x = NDArrayFactory::create('c', { 1, 15 }); - auto y = NDArrayFactory::create('c', { 1, 15 }); - auto exp = NDArrayFactory::create('c', { 1, 15 }); - x.assign(3); - y.assign(1); - exp.assign(-2); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversesubtract subOp; + sd::ops::reversesubtract subOp; - subOp.execute(block); + subOp.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractMatrixScalar1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(3); + y.assign(1); + exp.assign(-2); - auto x = NDArrayFactory::create('c', { 5, 3 }); - auto y = NDArrayFactory::create('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x.assign(3); - y.assign(1); - exp.assign(-2); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversesubtract subOp; + sd::ops::reversesubtract subOp; - subOp.execute(block); + subOp.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseSubtractScalarScalar1) { + auto x = NDArrayFactory::create('c', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}); + x.assign(3); + y.assign(1); + exp.assign(-2); - auto x = NDArrayFactory::create('c', { 1, 1 }); - auto y = NDArrayFactory::create('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 1, 1 }); - x.assign(3); - y.assign(1); - exp.assign(-2); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversesubtract subOp; + sd::ops::reversesubtract subOp; - subOp.execute(block); + subOp.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MultiplyMatrices1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {5, 3}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(2); + y.assign(3); + exp.assign(6); - auto x = NDArrayFactory::create('c', { 5, 3 }); - auto y = NDArrayFactory::create('c', { 5, 3 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x.assign(2); - y.assign(3); - exp.assign(6); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::multiply mul; + sd::ops::multiply mul; - mul.execute(block); + mul.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MultiplyVectorVector1) { + auto x = NDArrayFactory::create('c', {1, 15}); + auto y = NDArrayFactory::create('c', {1, 15}); + auto exp = NDArrayFactory::create('c', {1, 15}); + x.assign(2); + y.assign(3); + exp.assign(6); - auto x = NDArrayFactory::create('c', { 1, 15 }); - auto y = NDArrayFactory::create('c', { 1, 15 }); - auto exp = NDArrayFactory::create('c', { 1, 15 }); - x.assign(2); - y.assign(3); - exp.assign(6); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::multiply mul; + sd::ops::multiply mul; - mul.execute(block); + mul.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MultiplyMatrixScalar) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(2); + y.assign(3); + exp.assign(6); - auto x = NDArrayFactory::create('c', { 5, 3 }); - auto y = NDArrayFactory::create('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x.assign(2); - y.assign(3); - exp.assign(6); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::multiply mul; + sd::ops::multiply mul; - mul.execute(block); + mul.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, MultiplyScalarScalar1) { + auto x = NDArrayFactory::create('c', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}); + x.assign(2); + y.assign(3); + exp.assign(6); - auto x = NDArrayFactory::create('c', { 1, 1 }); - auto y = NDArrayFactory::create('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 1, 1 }); - x.assign(2); - y.assign(3); - exp.assign(6); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); - - sd::ops::multiply mul; + sd::ops::multiply mul; - mul.execute(block); + mul.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestSoftMax_bp_1) { + auto input = NDArrayFactory::create('c', {2, 2}); + for (int e = 0; e < input.lengthOf(); e++) input.p(e, e + 1); - auto input = NDArrayFactory::create('c', { 2, 2 }); - for (int e = 0; e < input.lengthOf(); e++) - input.p(e, e + 1); - - auto epsilon = NDArrayFactory::create('c', { 2, 2 }); - epsilon.p(0, 0.1f); - epsilon.p(1, 0.2f); - epsilon.p(2, 0.3f); - epsilon.p(3, 0.4f); + auto epsilon = NDArrayFactory::create('c', {2, 2}); + epsilon.p(0, 0.1f); + epsilon.p(1, 0.2f); + epsilon.p(2, 0.3f); + epsilon.p(3, 0.4f); - auto output = NDArrayFactory::create('c', { 2, 2 }); - output.assign(1.0f); + auto output = NDArrayFactory::create('c', {2, 2}); + output.assign(1.0f); - auto exp = NDArrayFactory::create('c', { 2, 2 }); - exp.p(0, -0.019661194f); - exp.p(1, 0.019661194f); - exp.p(2, -0.019661194f); - exp.p(3, 0.019661194f); + auto exp = NDArrayFactory::create('c', {2, 2}); + exp.p(0, -0.019661194f); + exp.p(1, 0.019661194f); + exp.p(2, -0.019661194f); + exp.p(3, 0.019661194f); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); - variableSpace->putVariable(-2, epsilon); - variableSpace->putVariable(1, output); - //variableSpace->putVariable(42, exp); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, epsilon); + variableSpace->putVariable(1, output); + // variableSpace->putVariable(42, exp); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::softmax_bp op; + sd::ops::softmax_bp op; - Nd4jStatus status = op.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); + Nd4jStatus status = op.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output.equalsTo(exp)); + ASSERT_TRUE(output.equalsTo(exp)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, BroadcastDivideTest_1) { + auto x = NDArrayFactory::create('c', {3, 4, 5, 1}); + auto y = NDArrayFactory::create('c', {1, 6}); + auto exp = NDArrayFactory::create('c', {3, 4, 5, 6}); + x.assign(6); + y.assign(2); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 3, 4, 5, 1 }); - auto y = NDArrayFactory::create('c', { 1, 6 }); - auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); - x.assign(6); - y.assign(2); - exp.assign(3); - - sd::ops::divide div; + sd::ops::divide div; - auto res = div.evaluate({ &x, &y }); + auto res = div.evaluate({&x, &y}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0).equalsTo(exp)); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) { + auto x = NDArrayFactory::create('c', {3, 4, 5, 1}); + auto y = NDArrayFactory::create('c', {1, 6}); + auto exp = NDArrayFactory::create('c', {3, 4, 5, 6}); + x.assign(6); + y.assign(2); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 3, 4, 5, 1 }); - auto y = NDArrayFactory::create('c', { 1, 6 }); - auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); - x.assign(6); - y.assign(2); - exp.assign(3); - - sd::ops::divide_no_nan div; - auto res = div.evaluate({ &x, &y }); + sd::ops::divide_no_nan div; + auto res = div.evaluate({&x, &y}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0).equalsTo(exp)); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) { + auto x = NDArrayFactory::create({6, 6, 6, 6, 6}); + auto y = NDArrayFactory::create({3, 3, 0, 3, 3}); + auto exp = NDArrayFactory::create({2, 2, 0, 2, 2}); - auto x = NDArrayFactory::create({ 6,6,6,6,6 }); - auto y = NDArrayFactory::create({ 3,3,0,3,3 }); - auto exp = NDArrayFactory::create({ 2, 2, 0, 2, 2 }); + sd::ops::divide_no_nan div; + auto res = div.evaluate({&x, &y}); - sd::ops::divide_no_nan div; - auto res = div.evaluate({ &x, &y }); - - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0).equalsTo(exp)); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) { + auto x = NDArrayFactory::create('c', {3, 4, 5, 1}); + auto y = NDArrayFactory::create('c', {1, 6}); + auto exp = NDArrayFactory::create('c', {3, 4, 5, 6}); + x.assign(3.f); + y.assign(6.f); + exp.assign(2.f); - auto x = NDArrayFactory::create('c', { 3, 4, 5, 1 }); - auto y = NDArrayFactory::create('c', { 1, 6 }); - auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); - x.assign(3.f); - y.assign(6.f); - exp.assign(2.f); - - sd::ops::reversedivide div; + sd::ops::reversedivide div; - auto res = div.evaluate({ &x, &y }); + auto res = div.evaluate({&x, &y}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(res.at(0).equalsTo(exp)); - auto z(exp); - x.applyTrueBroadcast(BROADCAST(ReverseDivide), y, z, true); - y.applyTrueBroadcast(BROADCAST(Divide), x, exp, true); + ASSERT_TRUE(res.at(0).equalsTo(exp)); + auto z(exp); + x.applyTrueBroadcast(BROADCAST(ReverseDivide), y, z, true); + y.applyTrueBroadcast(BROADCAST(Divide), x, exp, true); - ASSERT_TRUE(z.equalsTo(&exp)); + ASSERT_TRUE(z.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, DivideMatrices1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {5, 3}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(6); + y.assign(2); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 5, 3 }); - auto y = NDArrayFactory::create('c', { 5, 3 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x.assign(6); - y.assign(2); - exp.assign(3); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::divide div; + sd::ops::divide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(x.equalsTo(exp)); + ASSERT_TRUE(x.equalsTo(exp)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, DivideVectorVector1) { + auto x = NDArrayFactory::create('c', {1, 15}); + auto y = NDArrayFactory::create('c', {1, 15}); + auto exp = NDArrayFactory::create('c', {1, 15}); + x.assign(6); + y.assign(2); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 1, 15 }); - auto y = NDArrayFactory::create('c', { 1, 15 }); - auto exp = NDArrayFactory::create('c', { 1, 15 }); - x.assign(6); - y.assign(2); - exp.assign(3); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::divide div; + sd::ops::divide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, DivideMatrixScalar1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(6); + y.assign(2); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 5, 3 }); - auto y = NDArrayFactory::create('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x.assign(6); - y.assign(2); - exp.assign(3); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::divide div; + sd::ops::divide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, DivideScalarScalar1) { + auto x = NDArrayFactory::create('c', {5, 1}); + auto y = NDArrayFactory::create('c', {5, 1}); + auto exp = NDArrayFactory::create('c', {5, 1}); + x.assign(6); + y.assign(2); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 5, 1 }); - auto y = NDArrayFactory::create('c', { 5, 1 }); - auto exp = NDArrayFactory::create('c', { 5, 1 }); - x.assign(6); - y.assign(2); - exp.assign(3); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::divide div; + sd::ops::divide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseDivideMatrices1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {5, 3}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(2); + y.assign(6); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 5, 3 }); - auto y = NDArrayFactory::create('c', { 5, 3 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x.assign(2); - y.assign(6); - exp.assign(3); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversedivide div; + sd::ops::reversedivide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseDivideVectorVector1) { + auto x = NDArrayFactory::create('c', {1, 15}); + auto y = NDArrayFactory::create('c', {1, 15}); + auto exp = NDArrayFactory::create('c', {1, 15}); + x.assign(2); + y.assign(6); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 1, 15 }); - auto y = NDArrayFactory::create('c', { 1, 15 }); - auto exp = NDArrayFactory::create('c', { 1, 15 }); - x.assign(2); - y.assign(6); - exp.assign(3); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversedivide div; + sd::ops::reversedivide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseDivideMatrixScalar1) { + auto x = NDArrayFactory::create('c', {5, 3}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {5, 3}); + x.assign(2); + y.assign(6); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 5, 3 }); - auto y = NDArrayFactory::create('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 5, 3 }); - x.assign(2); - y.assign(6); - exp.assign(3); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversedivide div; + sd::ops::reversedivide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); - auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(block->getVariableSpace()->hasVariable(1)); + auto z = block->getVariableSpace()->getVariable(1)->getNDArray().get(); + ASSERT_TRUE(exp.equalsTo(z)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, ReverseDivideScalarScalar1) { + auto x = NDArrayFactory::create('c', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}); + x.assign(2); + y.assign(6); + exp.assign(3); - auto x = NDArrayFactory::create('c', { 1, 1 }); - auto y = NDArrayFactory::create('c', { 1, 1 }); - auto exp = NDArrayFactory::create('c', { 1, 1 }); - x.assign(2); - y.assign(6); - exp.assign(3); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, x); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1, -2 }); + variableSpace->putVariable(1, x); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - sd::ops::reversedivide div; + sd::ops::reversedivide div; - div.execute(block); + div.execute(block); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } TEST_F(DeclarableOpsTests1, Test_Cast_1) { - // TODO: right now there's no real cast implementation, but genera idea should be the same: arrays equality to be expected - auto x = NDArrayFactory::create('c', { 5, 5 }); - auto yExp = NDArrayFactory::create('c', { 5, 5 }); - x.linspace(1); - yExp.linspace(1); - sd::ops::cast op; + // TODO: right now there's no real cast implementation, but genera idea should + // be the same: arrays equality to be expected + auto x = NDArrayFactory::create('c', {5, 5}); + auto yExp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1); + yExp.linspace(1); + sd::ops::cast op; - auto result = op.evaluate({ &x }, {}, { 3 }); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&x}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - ASSERT_TRUE(yExp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(yExp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestRegistrator1) { - auto res = sd::ops::OpRegistrator::getInstance()->getAllCustomOperations(); + auto res = sd::ops::OpRegistrator::getInstance()->getAllCustomOperations(); } // ////////////////////////////////////////////////////////////////////// @@ -1711,7 +1757,8 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) { // z->assign(120.0f); // std::string opName("add"); -// auto hash = sd::ops::HashHelper::getInstance()->getInstance()->getLongHash(opName); +// auto hash = +// sd::ops::HashHelper::getInstance()->getInstance()->getLongHash(opName); // auto inputBuffers = new Nd4jPointer[2]; // auto inputShapes = new Nd4jPointer[2]; @@ -1728,13 +1775,14 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) { // outputBuffers[0] = (Nd4jPointer) z->buffer(); // outputShapes[0] = (Nd4jPointer) z->shapeInfo(); - -// //auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, false); -// auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); -// ASSERT_EQ(ND4J_STATUS_OK, status); -// ASSERT_NEAR(2.0f, y->meanNumber().e(0), 1e-5); -// ASSERT_NEAR(1.0f, x->meanNumber().e(0), 1e-5); -// ASSERT_NEAR(3.0f, z->meanNumber().e(0), 1e-5); +// //auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, +// outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, false); auto +// status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, +// outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, +// false); ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_NEAR(2.0f, +// y->meanNumber().e(0), 1e-5); ASSERT_NEAR(1.0f, +// x->meanNumber().e(0), 1e-5); ASSERT_NEAR(3.0f, +// z->meanNumber().e(0), 1e-5); // delete x; // delete y; @@ -1761,7 +1809,8 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) { // std::string opName("add"); -// auto hash = sd::ops::HashHelper::getInstance()->getInstance()->getLongHash(opName); +// auto hash = +// sd::ops::HashHelper::getInstance()->getInstance()->getLongHash(opName); // auto inputBuffers = new Nd4jPointer[2]; // auto inputShapes = new Nd4jPointer[2]; @@ -1775,12 +1824,12 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) { // auto outputBuffers = new Nd4jPointer[1]; // auto outputShapes = new Nd4jPointer[1]; -// execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, true); +// execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, +// outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, true); // ASSERT_NEAR(2.0, y->meanNumber().e(0), 1e-5); // ASSERT_NEAR(3.0, x->meanNumber().e(0), 1e-5); - // delete x; // delete y; // delete z; @@ -1794,214 +1843,212 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) { #ifndef __CUDABLAS__ ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestGemv1) { - /* - auto xBuffer = new float[15]{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}; - auto xShape = new Nd4jLong[8] {2, 5, 3, 3, 1, 0, 1, 99}; - ArrayOptions::setDataType(xShape, sd::DataType::FLOAT32); - auto x = new NDArray(xBuffer, xShape); + /* + auto xBuffer = new + float[15]{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, + 15.f}; auto xShape = new Nd4jLong[8] {2, 5, 3, 3, 1, 0, 1, 99}; + ArrayOptions::setDataType(xShape, sd::DataType::FLOAT32); + auto x = new NDArray(xBuffer, xShape); - auto yBuffer = new float[3]{2.f, 4.f, 6.f}; - auto yShape = new Nd4jLong[8] {2, 3, 1, 1, 1, 0, 1, 99}; - ArrayOptions::setDataType(yShape, sd::DataType::FLOAT32); + auto yBuffer = new float[3]{2.f, 4.f, 6.f}; + auto yShape = new Nd4jLong[8] {2, 3, 1, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(yShape, sd::DataType::FLOAT32); - auto y = new NDArray(yBuffer, yShape); + auto y = new NDArray(yBuffer, yShape); - auto z = NDArrayFactory::create_('f', {5, 1}); + auto z = NDArrayFactory::create_('f', {5, 1}); - auto expBuffer = new float[5]{28.00f,64.00f,100.00f,136.00f,172.00f}; - auto exp = new NDArray(expBuffer, z->shapeInfo()); + auto expBuffer = new float[5]{28.00f,64.00f,100.00f,136.00f,172.00f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); - sd::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->buffer(), y->rows(), y->buffer(), 1, 0.0, z->buffer(), 1); + sd::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, + x->buffer(), y->rows(), y->buffer(), 1, 0.0, z->buffer(), 1); - ASSERT_TRUE(z->equalsTo(exp)); + ASSERT_TRUE(z->equalsTo(exp)); - delete []xBuffer; delete []xShape; delete x; delete []yBuffer; delete []yShape; delete y; delete z; delete []expBuffer; delete exp; - */ + delete []xBuffer; delete []xShape; delete x; delete []yBuffer; delete + []yShape; delete y; delete z; delete []expBuffer; delete exp; + */ } #endif - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Transpose1) { - auto x = NDArrayFactory::create('c', { 3,5,2 }); - auto exp = NDArrayFactory::create('c', { 2,5,3 }); + auto x = NDArrayFactory::create('c', {3, 5, 2}); + auto exp = NDArrayFactory::create('c', {2, 5, 3}); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({ -1 }); - sd::ops::transpose transpose; + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); + sd::ops::transpose transpose; - Nd4jStatus status = transpose.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); + Nd4jStatus status = transpose.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(*result)); - ASSERT_TRUE(exp.dataType() == result->dataType()); - ASSERT_TRUE(exp.ordering() == result->ordering()); + ASSERT_TRUE(exp.isSameShape(*result)); + ASSERT_TRUE(exp.dataType() == result->dataType()); + ASSERT_TRUE(exp.ordering() == result->ordering()); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// // not-in-place TEST_F(DeclarableOpsTests1, Permute1) { + Nd4jLong shapeX[] = {3, 5, 10, 15, 150, 15, 1, 0, 1, 99}; + Nd4jLong shapeExp[] = {3, 15, 5, 10, 50, 10, 1, 0, 1, 99}; + const std::vector perm = {2, 0, 1}; - Nd4jLong shapeX[] = { 3, 5,10,15, 150,15,1, 0,1,99 }; - Nd4jLong shapeExp[] = { 3, 15,5,10, 50,10,1, 0,1,99 }; - const std::vector perm = { 2, 0, 1 }; - - ArrayOptions::setDataType(shapeX, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shapeExp, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shapeX, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shapeExp, sd::DataType::FLOAT32); - NDArray x(shapeX, true); - NDArray exp(shapeExp, true); + NDArray x(shapeX, true); + NDArray exp(shapeExp, true); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({ -1 }); + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); - block->appendI(perm); // set dimensions to be permuted + block->appendI(perm); // set dimensions to be permuted - sd::ops::permute permute; - Nd4jStatus status = permute.execute(block); - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + sd::ops::permute permute; + Nd4jStatus status = permute.execute(block); + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(result->isSameShapeStrict(exp)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(result->isSameShapeStrict(exp)); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestArgumentsValidation1) { - Nd4jLong shapeX[] = { 3, 5, 10, 15, 150, 15, 1, 0, 1, 99 }; - Nd4jLong shapeExp[] = { 3, 15, 5, 10, 1, 150, 15, 0, -1, 99 }; + Nd4jLong shapeX[] = {3, 5, 10, 15, 150, 15, 1, 0, 1, 99}; + Nd4jLong shapeExp[] = {3, 15, 5, 10, 1, 150, 15, 0, -1, 99}; - ArrayOptions::setDataType(shapeX, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shapeExp, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shapeX, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shapeExp, sd::DataType::FLOAT32); - const std::vector perm = { 2, 0, 1 }; - NDArray x(shapeX); - NDArray exp(shapeExp); + const std::vector perm = {2, 0, 1}; + NDArray x(shapeX); + NDArray exp(shapeExp); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({ -1 }); + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); - sd::ops::im2col permute; - Nd4jStatus status = permute.execute(block); + sd::ops::im2col permute; + Nd4jStatus status = permute.execute(block); - ASSERT_TRUE(status != 0); + ASSERT_TRUE(status != 0); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestReductionShape1) { - auto input = NDArrayFactory::create('c', { 4, 5, 5, 10, 10 }); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); + auto input = NDArrayFactory::create('c', {4, 5, 5, 10, 10}); - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({ -1 }); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); - // kernel params - block->appendI(MAX_INT); + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); - sd::ops::testreduction testop; + // kernel params + block->appendI(MAX_INT); - auto inP = new Nd4jLong[shape::shapeInfoLength(input.shapeInfo())]; - memcpy(inP, input.shapeInfo(), shape::shapeInfoByteLength(input.rankOf())); + sd::ops::testreduction testop; - auto inshape = new ShapeList(inP); + auto inP = new Nd4jLong[shape::shapeInfoLength(input.shapeInfo())]; + memcpy(inP, input.shapeInfo(), shape::shapeInfoByteLength(input.rankOf())); - auto shapes = testop.calculateOutputShape(inshape, *block); + auto inshape = new ShapeList(inP); - ASSERT_EQ(1, shapes->size()); - ASSERT_EQ(0, shapes->at(0)[0]); // scalar shape has rank 0 - ASSERT_EQ(8192, shapes->at(0)[1]); - ASSERT_EQ(1, shapes->at(0)[2]); + auto shapes = testop.calculateOutputShape(inshape, *block); - delete[] inP; - delete variableSpace; - delete block; - delete inshape; - delete shapes; + ASSERT_EQ(1, shapes->size()); + ASSERT_EQ(0, shapes->at(0)[0]); // scalar shape has rank 0 + ASSERT_EQ(8192, shapes->at(0)[1]); + ASSERT_EQ(1, shapes->at(0)[2]); + delete[] inP; + delete variableSpace; + delete block; + delete inshape; + delete shapes; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestReductionShape2) { - auto input = NDArrayFactory::create('c', { 4, 5, 5, 10, 10 }); + auto input = NDArrayFactory::create('c', {4, 5, 5, 10, 10}); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({ -1 }); + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); - // kernel params - //block->getIArguments()->push_back(4); - block->appendI(1); - block->appendI(2); - block->appendI(3); - block->appendI(4); + // kernel params + // block->getIArguments()->push_back(4); + block->appendI(1); + block->appendI(2); + block->appendI(3); + block->appendI(4); - sd::ops::testreduction testop; + sd::ops::testreduction testop; - auto inshapes = new ShapeList(input.shapeInfo()); - auto shapes = testop.calculateOutputShape(inshapes, *block); - ASSERT_EQ(1, shapes->size()); - ASSERT_EQ(1, shapes->at(0)[0]); - ASSERT_EQ(4, shapes->at(0)[1]); - ASSERT_EQ(1, shapes->at(0)[2]); + auto inshapes = new ShapeList(input.shapeInfo()); + auto shapes = testop.calculateOutputShape(inshapes, *block); + ASSERT_EQ(1, shapes->size()); + ASSERT_EQ(1, shapes->at(0)[0]); + ASSERT_EQ(4, shapes->at(0)[1]); + ASSERT_EQ(1, shapes->at(0)[2]); - delete variableSpace; - delete block; - delete shapes; - delete inshapes; + delete variableSpace; + delete block; + delete shapes; + delete inshapes; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, TestCustomShape1) { - auto input = NDArrayFactory::create('c', { 2, 3, 4 }); + auto input = NDArrayFactory::create('c', {2, 3, 4}); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, input); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({ -1 }); - - sd::ops::testcustom test; + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); - auto inshapes = new ShapeList(input.shapeInfo()); - auto shapes = test.calculateOutputShape(inshapes, *block); + sd::ops::testcustom test; + auto inshapes = new ShapeList(input.shapeInfo()); + auto shapes = test.calculateOutputShape(inshapes, *block); - ASSERT_EQ(input.shapeInfo()[0], shapes->at(0)[0]); - ASSERT_EQ(input.shapeInfo()[1] * 2, shapes->at(0)[1]); - ASSERT_EQ(input.shapeInfo()[2] * 2, shapes->at(0)[2]); - ASSERT_EQ(input.shapeInfo()[3] * 2, shapes->at(0)[3]); + ASSERT_EQ(input.shapeInfo()[0], shapes->at(0)[0]); + ASSERT_EQ(input.shapeInfo()[1] * 2, shapes->at(0)[1]); + ASSERT_EQ(input.shapeInfo()[2] * 2, shapes->at(0)[2]); + ASSERT_EQ(input.shapeInfo()[3] * 2, shapes->at(0)[3]); - delete variableSpace; - delete block; - delete shapes; - delete inshapes; + delete variableSpace; + delete block; + delete shapes; + delete inshapes; } - ////////////////////////////////////////////////////////////////////// /* TEST_F(DeclarableOpsTests1, Sum1) { @@ -2041,27 +2088,29 @@ TEST_F(DeclarableOpsTests1, Sum1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Pnormpool2d1) { + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); - auto x = NDArrayFactory::create('c', { bS,iD,iH,iW }); - auto exp = NDArrayFactory::create('c', { bS,iD,oH,oW }); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1 }); - block->appendI({ kH,kW, sH,sW, pH,pW, dW,dH, 0, 1, 0 }); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; 9 - extraParam0 for pnorm case; + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI({kH, kW, sH, sW, pH, pW, dW, dH, 0, 1, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; + // 4,5 - pad Height/Width; 6,7 - dilation Height/Width; + // 8 - same mode; 9 - extraParam0 for pnorm case; - sd::ops::pnormpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); + sd::ops::pnormpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(*result)); + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(*result)); - delete variableSpace; - delete block; + delete variableSpace; + delete block; } /*///////////////////////////////////////////////////////////////////// @@ -2082,7 +2131,8 @@ TEST_F(DeclarableOpsTests1, IsMax1) { block->fillInputs({-1}); std::vector* argI = block->getIArguments(); // *argI = {1}; // dimensions - argI->push_back(1); // = {1}; // dimensions + argI->push_back(1); // = {1}; // +dimensions sd::ops::ismax ismaxOp; Nd4jStatus status = ismaxOp.execute(block); @@ -2099,81 +2149,78 @@ TEST_F(DeclarableOpsTests1, IsMax1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, IsMax1) { - NDArray x('c', { 3, 3 }, sd::DataType::FLOAT32); - // NDArray exp('c', {3, 3}, sd::DataType::BOOL); - NDArray exp('c', { 3, 3 }, sd::DataType::FLOAT32); - x.linspace(1); - exp.p(0, 2, true); - exp.p(1, 2, true); - exp.p(2, 2, true); - - sd::ops::ismax ismaxOp; - auto result = ismaxOp.evaluate({ &x }, { 1 }); + NDArray x('c', {3, 3}, sd::DataType::FLOAT32); + // NDArray exp('c', {3, 3}, sd::DataType::BOOL); + NDArray exp('c', {3, 3}, sd::DataType::FLOAT32); + x.linspace(1); + exp.p(0, 2, true); + exp.p(1, 2, true); + exp.p(2, 2, true); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto res = result.at(0); - res.printIndexedBuffer("IS_MAX"); - ASSERT_TRUE(exp.equalsTo(res)); + sd::ops::ismax ismaxOp; + auto result = ismaxOp.evaluate({&x}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto res = result.at(0); + res.printIndexedBuffer("IS_MAX"); + ASSERT_TRUE(exp.equalsTo(res)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, IsMax2) { - NDArray x('c', { 3, 3 }, sd::DataType::FLOAT32); - // NDArray exp('c', {3, 3}, sd::DataType::BOOL); - NDArray exp('c', { 3, 3 }, sd::DataType::FLOAT32); - x.linspace(1); - //exp.p(0, 2, true); - //exp.p(1, 2, true); - exp.p(2, 2, true); - - sd::ops::ismax ismaxOp; - auto result = ismaxOp.evaluate({ &x }, {}, { 0, 1 }); + NDArray x('c', {3, 3}, sd::DataType::FLOAT32); + // NDArray exp('c', {3, 3}, sd::DataType::BOOL); + NDArray exp('c', {3, 3}, sd::DataType::FLOAT32); + x.linspace(1); + // exp.p(0, 2, true); + // exp.p(1, 2, true); + exp.p(2, 2, true); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto res = result.at(0); - //res->printIndexedBuffer("IS_MAX"); - ASSERT_TRUE(exp.equalsTo(res)); + sd::ops::ismax ismaxOp; + auto result = ismaxOp.evaluate({&x}, {}, {0, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto res = result.at(0); + // res->printIndexedBuffer("IS_MAX"); + ASSERT_TRUE(exp.equalsTo(res)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, IsMax3) { - NDArray x = NDArrayFactory::create(120.f); //('c', {3, 3}, sd::DataType::FLOAT32); -// NDArray exp('c', {3, 3}, sd::DataType::BOOL); - NDArray exp = NDArrayFactory::create(1.f);//, sd::DataType::FLOAT32); //'c', {3, 3}, sd::DataType::FLOAT32); - x.linspace(1); - //exp.p(0, 2, true); - //exp.p(1, 2, true); - //exp.p(2, 2, true); - - sd::ops::ismax ismaxOp; - auto result = ismaxOp.evaluate({ &x }, {}, { 0 }); + NDArray x = NDArrayFactory::create( + 120.f); //('c', {3, 3}, sd::DataType::FLOAT32); + // NDArray exp('c', {3, 3}, sd::DataType::BOOL); + NDArray exp = NDArrayFactory::create( + 1.f); //, sd::DataType::FLOAT32); //'c', {3, 3}, sd::DataType::FLOAT32); + x.linspace(1); + // exp.p(0, 2, true); + // exp.p(1, 2, true); + // exp.p(2, 2, true); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto res = result.at(0); - //res->printIndexedBuffer("IS_MAX"); - ASSERT_TRUE(exp.equalsTo(res)); + sd::ops::ismax ismaxOp; + auto result = ismaxOp.evaluate({&x}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto res = result.at(0); + // res->printIndexedBuffer("IS_MAX"); + ASSERT_TRUE(exp.equalsTo(res)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, IsMax4) { - auto x = NDArrayFactory::create('c', { 6 }, { 0, 0, 0, 2, 2, 0 }); - auto z = NDArrayFactory::create('c', { 6 }); - auto e = NDArrayFactory::create('c', { 6 }, { false, false, false, true, false, false }); + auto x = NDArrayFactory::create('c', {6}, {0, 0, 0, 2, 2, 0}); + auto z = NDArrayFactory::create('c', {6}); + auto e = NDArrayFactory::create( + 'c', {6}, {false, false, false, true, false, false}); - sd::ops::ismax op; - auto result = op.execute({ &x }, { &z }); - ASSERT_EQ(Status::OK(), result); + sd::ops::ismax op; + auto result = op.execute({&x}, {&z}); + ASSERT_EQ(Status::OK(), result); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////// @@ -2188,8 +2235,15 @@ TEST_F(DeclarableOpsTests1, IsMax4) { // NDArray bias('c', {1,2*K}, sd::DataType::DOUBLE); // NDArray init('c', {bS,K}, sd::DataType::DOUBLE); // NDArray mask('c', {bS,K}, sd::DataType::DOUBLE); -// NDArray expState('c', {bS,K,N}, {0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715}, sd::DataType::DOUBLE); -// NDArray expOut('c', {bS,K,N}, {1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656}, sd::DataType::DOUBLE); +// NDArray expState('c', {bS,K,N}, {0.847983, 0.874549, 0.896109, 0.913715, +// 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, +// 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, +// 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715}, +// sd::DataType::DOUBLE); NDArray expOut('c', {bS,K,N}, +// {1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, +// 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, +// 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656}, +// sd::DataType::DOUBLE); // input.assign(1.5); // weights.assign(0.5); @@ -2198,8 +2252,8 @@ TEST_F(DeclarableOpsTests1, IsMax4) { // mask.assign(1.); // sd::ops::sru_old op; -// auto results = op.execute({&input, &weights, &bias, &init, &mask}, {}, {}); -// ASSERT_TRUE(results.size() == 2); +// auto results = op.execute({&input, &weights, &bias, &init, &mask}, {}, +// {}); ASSERT_TRUE(results.size() == 2); // auto state = results.at(0); // auto output = results.at(1); @@ -2214,1064 +2268,1277 @@ TEST_F(DeclarableOpsTests1, IsMax4) { ////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, sru_test1) { - - const int bS = 2; - const int K = 3; - const int N = 4; - - NDArray input('c', { bS,K,N }, sd::DataType::DOUBLE); - NDArray weights('c', { 3 * K,K }, sd::DataType::DOUBLE); - NDArray bias('c', { 2 * K }, sd::DataType::DOUBLE); - NDArray init('c', { bS,K }, sd::DataType::DOUBLE); - NDArray mask('c', { bS,K }, sd::DataType::DOUBLE); - NDArray expState('c', { bS,K,N }, { 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656 }, sd::DataType::DOUBLE); - NDArray expOut('c', { bS,K,N }, { 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715 }, sd::DataType::DOUBLE); - - input.assign(1.5); - weights.assign(0.5); - bias.assign(0.3); - init.assign(1.); - mask.assign(1.); - - sd::ops::sru op; - auto results = op.evaluate({ &input, &weights, &bias, &init, &mask }); - ASSERT_TRUE(results.size() == 2); - - auto output = results.at(0); - auto state = results.at(1); - - ASSERT_TRUE(expState.equalsTo(state)); - ASSERT_TRUE(expOut.equalsTo(output)); - - + const int bS = 2; + const int K = 3; + const int N = 4; + + NDArray input('c', {bS, K, N}, sd::DataType::DOUBLE); + NDArray weights('c', {3 * K, K}, sd::DataType::DOUBLE); + NDArray bias('c', {2 * K}, sd::DataType::DOUBLE); + NDArray init('c', {bS, K}, sd::DataType::DOUBLE); + NDArray mask('c', {bS, K}, sd::DataType::DOUBLE); + NDArray expState('c', {bS, K, N}, + {1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, + 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, + 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, + 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656}, + sd::DataType::DOUBLE); + NDArray expOut('c', {bS, K, N}, + {0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, + 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, + 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, + 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715}, + sd::DataType::DOUBLE); + + input.assign(1.5); + weights.assign(0.5); + bias.assign(0.3); + init.assign(1.); + mask.assign(1.); + + sd::ops::sru op; + auto results = op.evaluate({&input, &weights, &bias, &init, &mask}); + ASSERT_TRUE(results.size() == 2); + + auto output = results.at(0); + auto state = results.at(1); + + ASSERT_TRUE(expState.equalsTo(state)); + ASSERT_TRUE(expOut.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, sru_bp) { - - const int bS = 2; - const int K = 3; - const int N = 4; - std::vector expGradXBuff = { -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165 }; - std::vector expGradWBuff = { 0.42526005,0.42526005,0.42526005, 0.42526005,0.42526005,0.42526005, 0.42526005,0.42526005,0.42526005, -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, 0.42526005,0.42526005,0.42526005, 0.42526005,0.42526005,0.42526005, 0.42526005,0.42526005,0.42526005, -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215 }; - std::vector expGradBBuff = { -0.7043748, -0.7043748, -0.7043748, -0.2128962, -0.2128962, -0.2128962 }; - std::vector expGradInitBuff = { 1.1421, 1.1421, 1.1421, 1.1421, 1.1421, 1.1421 }; - std::vector stateBuff = { 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715 }; - - auto input = NDArrayFactory::create('c', { bS,K,N }); - auto weights = NDArrayFactory::create('c', { 3 * K,K }); - auto bias = NDArrayFactory::create('c', { 1,2 * K }); - auto init = NDArrayFactory::create('c', { bS,K }); - auto mask = NDArrayFactory::create('c', { bS,K }); - auto state = NDArrayFactory::create('c', { bS,K,N }, stateBuff); - auto inGradCt = NDArrayFactory::create('c', { bS,K }); - auto inGradH = NDArrayFactory::create('c', { bS,K,N }); - - auto expGradX = NDArrayFactory::create('c', { bS,K,N }, expGradXBuff); - auto expGradW = NDArrayFactory::create('c', { bS,3 * K,K }, expGradWBuff); - auto expGradB = NDArrayFactory::create('c', { 1,2 * K }, expGradBBuff); - auto expGradInit = NDArrayFactory::create('c', { bS,K }, expGradInitBuff); - - input.assign(1.5); - weights.assign(0.5); - bias.assign(0.3); - mask.assign(1.); - init.assign(1.); - inGradCt.assign(0.5); - inGradH.assign(0.5); - - sd::ops::sru_bp bp; - auto resultsBP = bp.evaluate({ &input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask }, {}, {}); - ASSERT_TRUE(resultsBP.size() == 4); - - auto gradX = resultsBP.at(0); - auto gradW = resultsBP.at(1); - auto gradB = resultsBP.at(2); - auto gradInit = resultsBP.at(3); - // expGradX.printBuffer("Exp GRAD"); - // gradX->printBuffer("Res GRAD"); - ASSERT_TRUE(expGradX.equalsTo(gradX, 1e-4)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); - ASSERT_TRUE(expGradInit.equalsTo(gradInit)); + const int bS = 2; + const int K = 3; + const int N = 4; + std::vector expGradXBuff = { + -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, + -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, + -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, + -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165, + -0.0259303, -0.03869125, -0.0302272, -0.02299165}; + std::vector expGradWBuff = { + 0.42526005, 0.42526005, 0.42526005, 0.42526005, 0.42526005, + 0.42526005, 0.42526005, 0.42526005, 0.42526005, -0.5282811, + -0.5282811, -0.5282811, -0.5282811, -0.5282811, -0.5282811, + -0.5282811, -0.5282811, -0.5282811, -0.15967215, -0.15967215, + -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, + -0.15967215, -0.15967215, 0.42526005, 0.42526005, 0.42526005, + 0.42526005, 0.42526005, 0.42526005, 0.42526005, 0.42526005, + 0.42526005, -0.5282811, -0.5282811, -0.5282811, -0.5282811, + -0.5282811, -0.5282811, -0.5282811, -0.5282811, -0.5282811, + -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, + -0.15967215, -0.15967215, -0.15967215, -0.15967215}; + std::vector expGradBBuff = {-0.7043748, -0.7043748, -0.7043748, + -0.2128962, -0.2128962, -0.2128962}; + std::vector expGradInitBuff = {1.1421, 1.1421, 1.1421, + 1.1421, 1.1421, 1.1421}; + std::vector stateBuff = { + 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, + 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, + 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, + 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715}; + + auto input = NDArrayFactory::create('c', {bS, K, N}); + auto weights = NDArrayFactory::create('c', {3 * K, K}); + auto bias = NDArrayFactory::create('c', {1, 2 * K}); + auto init = NDArrayFactory::create('c', {bS, K}); + auto mask = NDArrayFactory::create('c', {bS, K}); + auto state = NDArrayFactory::create('c', {bS, K, N}, stateBuff); + auto inGradCt = NDArrayFactory::create('c', {bS, K}); + auto inGradH = NDArrayFactory::create('c', {bS, K, N}); + + auto expGradX = NDArrayFactory::create('c', {bS, K, N}, expGradXBuff); + auto expGradW = + NDArrayFactory::create('c', {bS, 3 * K, K}, expGradWBuff); + auto expGradB = NDArrayFactory::create('c', {1, 2 * K}, expGradBBuff); + auto expGradInit = + NDArrayFactory::create('c', {bS, K}, expGradInitBuff); + + input.assign(1.5); + weights.assign(0.5); + bias.assign(0.3); + mask.assign(1.); + init.assign(1.); + inGradCt.assign(0.5); + inGradH.assign(0.5); + + sd::ops::sru_bp bp; + auto resultsBP = bp.evaluate( + {&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, + {}); + ASSERT_TRUE(resultsBP.size() == 4); + + auto gradX = resultsBP.at(0); + auto gradW = resultsBP.at(1); + auto gradB = resultsBP.at(2); + auto gradInit = resultsBP.at(3); + // expGradX.printBuffer("Exp GRAD"); + // gradX->printBuffer("Res GRAD"); + ASSERT_TRUE(expGradX.equalsTo(gradX, 1e-4)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); + ASSERT_TRUE(expGradInit.equalsTo(gradInit)); } ////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, sru_bi_1) { - - const int bS = 2; - const int K = 3; - const int N = 4; - - NDArray input('c', { N,bS,2 * K }, sd::DataType::DOUBLE); - NDArray weights('c', { 2 * K,6 * K }, sd::DataType::DOUBLE); - NDArray bias('c', { 4 * K }, sd::DataType::DOUBLE); - NDArray init('c', { bS,2 * K }, sd::DataType::DOUBLE); - NDArray mask('c', { bS,2 * K }, sd::DataType::DOUBLE); - NDArray expState('c', { N,bS,2 * K }, { 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857 }); - NDArray expOut('c', { N,bS,2 * K }, { 0.779265, 0.779265, 0.779265, 0.810752, 0.810752, 0.810752, 0.779265, 0.779265, 0.779265, 0.810752, 0.810752, 0.810752, 0.790317, 0.790317, 0.790317, 0.800804, 0.800804, 0.800804, 0.790317, 0.790317, 0.790317, 0.800804, 0.800804, 0.800804, 0.800804, 0.800804, 0.800804, 0.790317, 0.790317, 0.790317, 0.800804, 0.800804, 0.800804, 0.790317, 0.790317, 0.790317, 0.810752, 0.810752, 0.810752, 0.779265, 0.779265, 0.779265, 0.810752, 0.810752, 0.810752, 0.779265, 0.779265, 0.779265 }); - - input.assign(1.5); - weights.assign(0.5); - bias.assign(0.3); - init.assign(1.); - mask.assign(1.); - - sd::ops::sru_bi op; - auto results = op.evaluate({ &input, &weights, &bias, &init, &mask }, {}, {}); - ASSERT_TRUE(results.size() == 2); - - auto output = results.at(0); - auto state = results.at(1); - // state->printBuffer(); - // output->printBuffer(); - - ASSERT_TRUE(expState.equalsTo(state)); - ASSERT_TRUE(expOut.equalsTo(output)); - - + const int bS = 2; + const int K = 3; + const int N = 4; + + NDArray input('c', {N, bS, 2 * K}, sd::DataType::DOUBLE); + NDArray weights('c', {2 * K, 6 * K}, sd::DataType::DOUBLE); + NDArray bias('c', {4 * K}, sd::DataType::DOUBLE); + NDArray init('c', {bS, 2 * K}, sd::DataType::DOUBLE); + NDArray mask('c', {bS, 2 * K}, sd::DataType::DOUBLE); + NDArray expState( + 'c', {N, bS, 2 * K}, + {1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, + 1.02857, 1.11288, 1.11288, 1.11288, 1.0569, 1.0569, 1.0569, 1.08501, + 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, + 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, + 1.08501, 1.0569, 1.0569, 1.0569, 1.11288, 1.11288, 1.11288, 1.02857, + 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857}); + NDArray expOut( + 'c', {N, bS, 2 * K}, + {0.779265, 0.779265, 0.779265, 0.810752, 0.810752, 0.810752, 0.779265, + 0.779265, 0.779265, 0.810752, 0.810752, 0.810752, 0.790317, 0.790317, + 0.790317, 0.800804, 0.800804, 0.800804, 0.790317, 0.790317, 0.790317, + 0.800804, 0.800804, 0.800804, 0.800804, 0.800804, 0.800804, 0.790317, + 0.790317, 0.790317, 0.800804, 0.800804, 0.800804, 0.790317, 0.790317, + 0.790317, 0.810752, 0.810752, 0.810752, 0.779265, 0.779265, 0.779265, + 0.810752, 0.810752, 0.810752, 0.779265, 0.779265, 0.779265}); + + input.assign(1.5); + weights.assign(0.5); + bias.assign(0.3); + init.assign(1.); + mask.assign(1.); + + sd::ops::sru_bi op; + auto results = op.evaluate({&input, &weights, &bias, &init, &mask}, {}, {}); + ASSERT_TRUE(results.size() == 2); + + auto output = results.at(0); + auto state = results.at(1); + // state->printBuffer(); + // output->printBuffer(); + + ASSERT_TRUE(expState.equalsTo(state)); + ASSERT_TRUE(expOut.equalsTo(output)); } TEST_F(DeclarableOpsTests1, sru_bi_bp_1) { - - const int bS = 2; - const int K = 3; - const int N = 3; - std::vector expGradXBuff = { 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129 }; - std::vector expGradInitBuff = { 1.05121, 1.05121, 1.05121, 1.02676, 1.02676, 1.02676, 1.05121, 1.05121, 1.05121, 1.02676, 1.02676, 1.02676 }; - std::vector expGradWBuff = { 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926 }; - std::vector expGradBBuff = { -0.0734389, -0.0734389, -0.0734389, -0.0717151, -0.0717151, -0.0717151, -0.0734389, -0.0734389, -0.0734389, -0.0717151, -0.0717151, -0.0717151, -0.00869156, -0.00869156, -0.00869156, -0.00856306, -0.00856306, -0.00856306, -0.00869156, -0.00869156, -0.00869156, -0.00856306, -0.00856306, -0.00856306 }; - std::vector stateBuff = { 1.028569, 1.028569, 1.028569, 1.112884, 1.112884, 1.112884, 1.028569, 1.028569, 1.028569, 1.112884,1.112884, 1.112884, 1.056905, 1.056905, 1.056905, 1.085009, 1.085009, 1.085009, 1.056905, 1.056905,1.056905, 1.085009, 1.085009, 1.085009, 1.085009, 1.085009, 1.085009, 1.056905, 1.056905, 1.056905,1.085009, 1.085009, 1.085009, 1.056905, 1.056905, 1.056905 }; - - auto input = NDArrayFactory::create('c', { N,bS,2 * K }); - auto weights = NDArrayFactory::create('c', { 2 * K,6 * K }); - auto bias = NDArrayFactory::create('c', { 4 * K }); - auto init = NDArrayFactory::create('c', { bS,2 * K }); - auto mask = NDArrayFactory::create('c', { bS,2 * K }); - NDArray state('c', { N,bS,2 * K }, stateBuff); - auto inGradCt = NDArrayFactory::create('c', { bS,2 * K }); - auto inGradH = NDArrayFactory::create('c', { N,bS,2 * K }); - - NDArray gradBias('c', { bS,4 * K }, expGradBBuff); - - NDArray expGradX('c', { N,bS,2 * K }, expGradXBuff); - NDArray expGradW('c', { N,2 * K,6 * K }, expGradWBuff); - auto expGradB = NDArrayFactory::create('c', { 4 * K }); - gradBias.reduceAlongDimension(reduce::Sum, expGradB, { 0 }); // [bS, 4K] -> [4K] - NDArray expGradInit('c', { bS,2 * K }, expGradInitBuff); - - input.assign(1.5); - weights.assign(0.5); - bias.assign(0.3); - mask.assign(1.); - init.assign(1.); - inGradCt.assign(0.5); - inGradH.assign(0.5); - - sd::ops::sru_bi_bp bp; - auto resultsBP = bp.evaluate({ &input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask }, {}, {}); - ASSERT_TRUE(resultsBP.size() == 4); - - auto gradX = resultsBP.at(0); - auto gradW = resultsBP.at(1); - auto gradB = resultsBP.at(2); - auto gradInit = resultsBP.at(3); - - ASSERT_TRUE(expGradX.equalsTo(gradX)); - ASSERT_TRUE(expGradW.equalsTo(gradW)); - ASSERT_TRUE(expGradB.equalsTo(gradB)); - ASSERT_TRUE(expGradInit.equalsTo(gradInit)); - - + const int bS = 2; + const int K = 3; + const int N = 3; + std::vector expGradXBuff = { + 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, + 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, + 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, + 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, + 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, + 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129}; + std::vector expGradInitBuff = {1.05121, 1.05121, 1.05121, 1.02676, + 1.02676, 1.02676, 1.05121, 1.05121, + 1.05121, 1.02676, 1.02676, 1.02676}; + std::vector expGradWBuff = { + 0.02595354, -0.090096, -0.00882456, 0.02595354, -0.090096, -0.0088245, + 0.02595354, -0.090096, -0.00882456, 0.01651665, -0.0559437, -0.0084390, + 0.01651665, -0.0559437, -0.00843906, 0.01651665, -0.0559437, -0.00843906, + 0.02595354, -0.090096, -0.00882456, 0.02595354, -0.090096, -0.0088245, + 0.02595354, -0.090096, -0.00882456, 0.01651665, -0.0559437, -0.0084390, + 0.01651665, -0.0559437, -0.00843906, 0.01651665, -0.0559437, -0.00843906, + 0.02595354, -0.090096, -0.00882456, 0.02595354, -0.090096, -0.0088245, + 0.02595354, -0.090096, -0.00882456, 0.01651665, -0.0559437, -0.0084390, + 0.01651665, -0.0559437, -0.00843906, 0.01651665, -0.0559437, -0.00843906, + 0.02595354, -0.090096, -0.00882456, 0.02595354, -0.090096, -0.0088245, + 0.02595354, -0.090096, -0.00882456, 0.01651665, -0.0559437, -0.0084390, + 0.01651665, -0.0559437, -0.00843906, 0.01651665, -0.0559437, -0.00843906, + 0.02595354, -0.090096, -0.00882456, 0.02595354, -0.090096, -0.0088245, + 0.02595354, -0.090096, -0.00882456, 0.01651665, -0.0559437, -0.0084390, + 0.01651665, -0.0559437, -0.00843906, 0.01651665, -0.0559437, -0.00843906, + 0.02595354, -0.090096, -0.00882456, 0.02595354, -0.090096, -0.0088245, + 0.02595354, -0.090096, -0.00882456, 0.01651665, -0.0559437, -0.0084390, + 0.01651665, -0.0559437, -0.00843906, 0.01651665, -0.0559437, -0.00843906, + 0.02124567, -0.0731508, -0.00868926, 0.02124567, -0.0731508, -0.0086892, + 0.02124567, -0.0731508, -0.00868926, 0.02084955, -0.0712011, -0.0085608, + 0.02084955, -0.0712011, -0.00856086, 0.02084955, -0.0712011, -0.00856086, + 0.02124567, -0.0731508, -0.00868926, 0.02124567, -0.0731508, -0.0086892, + 0.02124567, -0.0731508, -0.00868926, 0.02084955, -0.0712011, -0.0085608, + 0.02084955, -0.0712011, -0.00856086, 0.02084955, -0.0712011, -0.00856086, + 0.02124567, -0.0731508, -0.00868926, 0.02124567, -0.0731508, -0.0086892, + 0.02124567, -0.0731508, -0.00868926, 0.02084955, -0.0712011, -0.0085608, + 0.02084955, -0.0712011, -0.00856086, 0.02084955, -0.0712011, -0.00856086, + 0.02124567, -0.0731508, -0.00868926, 0.02124567, -0.0731508, -0.0086892, + 0.02124567, -0.0731508, -0.00868926, 0.02084955, -0.0712011, -0.0085608, + 0.02084955, -0.0712011, -0.00856086, 0.02084955, -0.0712011, -0.00856086, + 0.02124567, -0.0731508, -0.00868926, 0.02124567, -0.0731508, -0.0086892, + 0.02124567, -0.0731508, -0.00868926, 0.02084955, -0.0712011, -0.0085608, + 0.02084955, -0.0712011, -0.00856086, 0.02084955, -0.0712011, -0.00856086, + 0.02124567, -0.0731508, -0.00868926, 0.02124567, -0.0731508, -0.0086892, + 0.02124567, -0.0731508, -0.00868926, 0.02084955, -0.0712011, -0.0085608, + 0.02084955, -0.0712011, -0.00856086, 0.02084955, -0.0712011, -0.00856086, + 0.01671156, -0.0570699, -0.00856086, 0.01671156, -0.0570699, -0.0085608, + 0.01671156, -0.0570699, -0.00856086, 0.02534988, -0.0880002, -0.0086892, + 0.02534988, -0.0880002, -0.00868926, 0.02534988, -0.0880002, -0.00868926, + 0.01671156, -0.0570699, -0.00856086, 0.01671156, -0.0570699, -0.0085608, + 0.01671156, -0.0570699, -0.00856086, 0.02534988, -0.0880002, -0.0086892, + 0.02534988, -0.0880002, -0.00868926, 0.02534988, -0.0880002, -0.00868926, + 0.01671156, -0.0570699, -0.00856086, 0.01671156, -0.0570699, -0.0085608, + 0.01671156, -0.0570699, -0.00856086, 0.02534988, -0.0880002, -0.0086892, + 0.02534988, -0.0880002, -0.00868926, 0.02534988, -0.0880002, -0.00868926, + 0.01671156, -0.0570699, -0.00856086, 0.01671156, -0.0570699, -0.0085608, + 0.01671156, -0.0570699, -0.00856086, 0.02534988, -0.0880002, -0.0086892, + 0.02534988, -0.0880002, -0.00868926, 0.02534988, -0.0880002, -0.00868926, + 0.01671156, -0.0570699, -0.00856086, 0.01671156, -0.0570699, -0.0085608, + 0.01671156, -0.0570699, -0.00856086, 0.02534988, -0.0880002, -0.0086892, + 0.02534988, -0.0880002, -0.00868926, 0.02534988, -0.0880002, -0.00868926, + 0.01671156, -0.0570699, -0.00856086, 0.01671156, -0.0570699, -0.0085608, + 0.01671156, -0.0570699, -0.00856086, 0.02534988, -0.0880002, -0.0086892, + 0.02534988, -0.0880002, -0.00868926, 0.02534988, -0.0880002, -0.00868926}; + std::vector expGradBBuff = { + -0.0734389, -0.0734389, -0.0734389, -0.0717151, -0.0717151, + -0.0717151, -0.0734389, -0.0734389, -0.0734389, -0.0717151, + -0.0717151, -0.0717151, -0.00869156, -0.00869156, -0.00869156, + -0.00856306, -0.00856306, -0.00856306, -0.00869156, -0.00869156, + -0.00869156, -0.00856306, -0.00856306, -0.00856306}; + std::vector stateBuff = { + 1.028569, 1.028569, 1.028569, 1.112884, 1.112884, 1.112884, + 1.028569, 1.028569, 1.028569, 1.112884, 1.112884, 1.112884, + 1.056905, 1.056905, 1.056905, 1.085009, 1.085009, 1.085009, + 1.056905, 1.056905, 1.056905, 1.085009, 1.085009, 1.085009, + 1.085009, 1.085009, 1.085009, 1.056905, 1.056905, 1.056905, + 1.085009, 1.085009, 1.085009, 1.056905, 1.056905, 1.056905}; + + auto input = NDArrayFactory::create('c', {N, bS, 2 * K}); + auto weights = NDArrayFactory::create('c', {2 * K, 6 * K}); + auto bias = NDArrayFactory::create('c', {4 * K}); + auto init = NDArrayFactory::create('c', {bS, 2 * K}); + auto mask = NDArrayFactory::create('c', {bS, 2 * K}); + NDArray state('c', {N, bS, 2 * K}, stateBuff); + auto inGradCt = NDArrayFactory::create('c', {bS, 2 * K}); + auto inGradH = NDArrayFactory::create('c', {N, bS, 2 * K}); + + NDArray gradBias('c', {bS, 4 * K}, expGradBBuff); + + NDArray expGradX('c', {N, bS, 2 * K}, expGradXBuff); + NDArray expGradW('c', {N, 2 * K, 6 * K}, expGradWBuff); + auto expGradB = NDArrayFactory::create('c', {4 * K}); + gradBias.reduceAlongDimension(reduce::Sum, expGradB, + {0}); // [bS, 4K] -> [4K] + NDArray expGradInit('c', {bS, 2 * K}, expGradInitBuff); + + input.assign(1.5); + weights.assign(0.5); + bias.assign(0.3); + mask.assign(1.); + init.assign(1.); + inGradCt.assign(0.5); + inGradH.assign(0.5); + + sd::ops::sru_bi_bp bp; + auto resultsBP = bp.evaluate( + {&input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask}, {}, + {}); + ASSERT_TRUE(resultsBP.size() == 4); + + auto gradX = resultsBP.at(0); + auto gradW = resultsBP.at(1); + auto gradB = resultsBP.at(2); + auto gradInit = resultsBP.at(3); + + ASSERT_TRUE(expGradX.equalsTo(gradX)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); + ASSERT_TRUE(expGradInit.equalsTo(gradInit)); } TEST_F(DeclarableOpsTests1, ArgMax1) { - auto x = NDArrayFactory::create('c', { 3, 5 }); - x.linspace(1); - auto exp = NDArrayFactory::create('c', { 3 }); - exp.assign(4); + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {3}); + exp.assign(4); - sd::ops::argmax op; + sd::ops::argmax op; - auto result = op.evaluate({ &x }, {}, { 1 }); + auto result = op.evaluate({&x}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests1, ArgMax2) { - auto x = NDArrayFactory::create('c', { 3, 5 }); - x.linspace(1); - auto exp = NDArrayFactory::create('c', { 5 }); - exp.assign(2); - - sd::ops::argmax op; + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {5}); + exp.assign(2); - auto result = op.evaluate({ &x }, {}, { 0 }); + sd::ops::argmax op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&x}, {}, {0}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests1, ArgMax3) { - auto x = NDArrayFactory::create('c', { 3, 5 }); - auto dim = NDArrayFactory::create('c', { 1, 1 }, { 0. }); - x.linspace(1); - auto exp = NDArrayFactory::create('c', { 5 }); - exp.assign(2); + auto x = NDArrayFactory::create('c', {3, 5}); + auto dim = NDArrayFactory::create('c', {1, 1}, {0.}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {5}); + exp.assign(2); - sd::ops::argmax op; + sd::ops::argmax op; - auto result = op.evaluate({ &x, &dim }, {}, {}); + auto result = op.evaluate({&x, &dim}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, ArgMax4) { - auto x = NDArrayFactory::create('c', { 3, 5 }); - auto dim = NDArrayFactory::create('c', { 1, 1 }, { 1 }); - x.linspace(1); - auto exp = NDArrayFactory::create('c', { 3 }); - exp.assign(4); + auto x = NDArrayFactory::create('c', {3, 5}); + auto dim = NDArrayFactory::create('c', {1, 1}, {1}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {3}); + exp.assign(4); - sd::ops::argmax op; + sd::ops::argmax op; - auto result = op.evaluate({ &x, &dim }, {}, {}); + auto result = op.evaluate({&x, &dim}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests1, ArgMax5) { - auto x = NDArrayFactory::create('c', { 3, 5 }); - auto dim = NDArrayFactory::create('c', { 1, 2 }, { 0, 1 }); - x.linspace(1); - auto exp = NDArrayFactory::create(14); + auto x = NDArrayFactory::create('c', {3, 5}); + auto dim = NDArrayFactory::create('c', {1, 2}, {0, 1}); + x.linspace(1); + auto exp = NDArrayFactory::create(14); + sd::ops::argmax op; - sd::ops::argmax op; + auto result = op.evaluate({&x, &dim}, {}, {}); - auto result = op.evaluate({ &x, &dim }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, ArgMax6) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - auto dim = NDArrayFactory::create(-1.f); - x.linspace(1); - + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto dim = NDArrayFactory::create(-1.f); + x.linspace(1); - sd::ops::argmax op; + sd::ops::argmax op; - auto expected = op.evaluate({ &x }, {}, { 2 }); - ASSERT_EQ(Status::OK(), expected.status()); - auto exp = expected.at(0); + auto expected = op.evaluate({&x}, {}, {2}); + ASSERT_EQ(Status::OK(), expected.status()); + auto exp = expected.at(0); - auto result = op.evaluate({ &x, &dim }, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto result = op.evaluate({&x, &dim}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } - TEST_F(DeclarableOpsTests1, ArgMin1) { - auto x = NDArrayFactory::create('c', { 3, 5 }); - x.linspace(1); - // auto exp('c', {3, 1}); - auto exp = NDArrayFactory::create('c', { 3 }); - exp.assign(0.0f); + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); + // auto exp('c', {3, 1}); + auto exp = NDArrayFactory::create('c', {3}); + exp.assign(0.0f); - sd::ops::argmin op; + sd::ops::argmin op; - auto result = op.evaluate({ &x }, {}, { 1 }); + auto result = op.evaluate({&x}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests1, SquareTests1) { - auto x = NDArrayFactory::create('c', { 3, 5 }); - x.linspace(1); + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); - auto exp = NDArrayFactory::create('c', { 3, 5 }); - exp.linspace(1); - exp *= exp; + auto exp = NDArrayFactory::create('c', {3, 5}); + exp.linspace(1); + exp *= exp; - sd::ops::square op; + sd::ops::square op; - auto result = op.evaluate({ &x }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, OneHotTests_1) { + auto indices = + NDArrayFactory::create('c', {1, 4}, {0.0f, 2.0f, -1.0f, 1.0f}); - auto indices = NDArrayFactory::create('c', { 1, 4 }, { 0.0f, 2.0f, -1.0f, 1.0f }); - - auto exp = NDArrayFactory::create('c', { 1, 4, 3 }, { 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f }); + auto exp = NDArrayFactory::create( + 'c', {1, 4, 3}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); - sd::ops::onehot op; + sd::ops::onehot op; - auto result = op.evaluate({ &indices }, { 1.0f, 0.0f }, { -1, 3 }); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&indices}, {1.0f, 0.0f}, {-1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - // z->printBuffer(); + auto z = result.at(0); + // z->printBuffer(); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, OneHotTests_2) { - auto indices = NDArrayFactory::create('c', { 2, 2 }, { 0.f, 2.f, 1.f, -1.f }); + auto indices = + NDArrayFactory::create('c', {2, 2}, {0.f, 2.f, 1.f, -1.f}); - auto exp = NDArrayFactory::create('c', { 2, 2, 3 }, { 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f }); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 3}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f}); - sd::ops::onehot op; - auto result = op.evaluate({ &indices }, { 1.0f, 0.0f }, { -1, 3 }); + sd::ops::onehot op; + auto result = op.evaluate({&indices}, {1.0f, 0.0f}, {-1, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, OneHotTests_3) { - auto indices = NDArrayFactory::create('c', { 4 }, { 0.0f, 2.0f, -1.0f, 1.0f }); - - auto exp = NDArrayFactory::create('c', { 4, 3 }, { 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f }); + auto indices = + NDArrayFactory::create('c', {4}, {0.0f, 2.0f, -1.0f, 1.0f}); - sd::ops::onehot op; + auto exp = NDArrayFactory::create( + 'c', {4, 3}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); - auto result = op.evaluate({ &indices }, { 1.0f, 0.0f }, { -1, 3 }); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::onehot op; - auto z = result.at(0); + auto result = op.evaluate({&indices}, {1.0f, 0.0f}, {-1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - // z->printIndexedBuffer("z"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, OneHotTests_4) { - auto indices = NDArrayFactory::create('c', { 4 }, { 0.0f, 2.0f, -1.0f, 1.0f }); - auto depth = NDArrayFactory::create(3.0f); + auto indices = + NDArrayFactory::create('c', {4}, {0.0f, 2.0f, -1.0f, 1.0f}); + auto depth = NDArrayFactory::create(3.0f); - auto exp = NDArrayFactory::create('c', { 4, 3 }, { 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f }); + auto exp = NDArrayFactory::create( + 'c', {4, 3}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); - sd::ops::onehot op; + sd::ops::onehot op; - auto result = op.evaluate({ &indices, &depth }, { 1.0f, 0.0f }, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&indices, &depth}, {1.0f, 0.0f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, OneHotTests_5) { - auto indices = NDArrayFactory::create('c', { 4 }, { 0.0f, 2.0f, -1.0f, 1.0f }); - auto depth = NDArrayFactory::create(3.0f); - auto on = NDArrayFactory::create(1.0f); - auto off = NDArrayFactory::create(0.0f); + auto indices = + NDArrayFactory::create('c', {4}, {0.0f, 2.0f, -1.0f, 1.0f}); + auto depth = NDArrayFactory::create(3.0f); + auto on = NDArrayFactory::create(1.0f); + auto off = NDArrayFactory::create(0.0f); - auto exp = NDArrayFactory::create('c', { 4, 3 }, { 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f }); + auto exp = NDArrayFactory::create( + 'c', {4, 3}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); - sd::ops::onehot op; + sd::ops::onehot op; - auto result = op.evaluate({ &indices, &depth, &on, &off }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&indices, &depth, &on, &off}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests1, OneHotTests_6) { - auto indices = NDArrayFactory::create('c', { 3 }, { 0.f, 1.f, 2.f }); - auto e = NDArrayFactory::create('c', { 3, 3 }, { 1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f }); + auto indices = NDArrayFactory::create('c', {3}, {0.f, 1.f, 2.f}); + auto e = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f}); - sd::ops::onehot op; - auto result = op.evaluate({ &indices }, { 1.0, 0.0 }, { 0, 3 }); - auto z = result.at(0); + sd::ops::onehot op; + auto result = op.evaluate({&indices}, {1.0, 0.0}, {0, 3}); + auto z = result.at(0); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests1, OneHotTests_7) { - auto indices = NDArrayFactory::create('c', { 3 }, { 0, 1, 2 }); - auto e = NDArrayFactory::create('c', { 3, 3 }, { 1., 0., 0., 0., 1., 0., 0., 0., 1. }); + auto indices = NDArrayFactory::create('c', {3}, {0, 1, 2}); + auto e = NDArrayFactory::create( + 'c', {3, 3}, {1., 0., 0., 0., 1., 0., 0., 0., 1.}); - sd::ops::onehot op; - auto result = op.evaluate({ &indices }, { 1.0, 0.0 }, { 0, 3 }, {}, { sd::DataType::HALF }, false); - auto z = result.at(0); + sd::ops::onehot op; + auto result = op.evaluate({&indices}, {1.0, 0.0}, {0, 3}, {}, + {sd::DataType::HALF}, false); + auto z = result.at(0); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests1, FillAs_1) { - auto x = NDArrayFactory::create('c', { 2, 2 }); - x.assign(117); + auto x = NDArrayFactory::create('c', {2, 2}); + x.assign(117); - float scalar = 119.f; + float scalar = 119.f; - sd::ops::fill_as op; - auto result = op.evaluate({ &x }, { scalar }, {}); + sd::ops::fill_as op; + auto result = op.evaluate({&x}, {scalar}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(x.isSameShape(result.at(0))); + ASSERT_TRUE(x.isSameShape(result.at(0))); - ASSERT_NEAR(scalar, result.at(0).meanNumber().e(0), 1e-5f); + ASSERT_NEAR(scalar, result.at(0).meanNumber().e(0), 1e-5f); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, LRN1) { - sd::ops::lrn lrn; + sd::ops::lrn lrn; - lrn.getOpName(); + lrn.getOpName(); } TEST_F(DeclarableOpsTests1, Test_Range_Integer_1) { - auto exp = NDArrayFactory::create('c', { 4 }); - exp.linspace(1); + auto exp = NDArrayFactory::create('c', {4}); + exp.linspace(1); - sd::ops::range op; + sd::ops::range op; - auto result = op.evaluate({}, {}, { 1, 5, 1 }); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({}, {}, {1, 5, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(1, result.size()); + ASSERT_EQ(1, result.size()); - auto array = result.at(0); - // array->printIndexedBuffer("Range integer 1"); - ASSERT_TRUE(exp.isSameShape(array)); - ASSERT_TRUE(exp.equalsTo(array)); + auto array = result.at(0); + // array->printIndexedBuffer("Range integer 1"); + ASSERT_TRUE(exp.isSameShape(array)); + ASSERT_TRUE(exp.equalsTo(array)); } - TEST_F(DeclarableOpsTests1, Test_Range_Integer_2) { - auto exp = NDArrayFactory::create('c', { 4 }); - exp.linspace(1); + auto exp = NDArrayFactory::create('c', {4}); + exp.linspace(1); - auto start = NDArrayFactory::create('c', { 1, 1 }); - auto stop = NDArrayFactory::create('c', { 1, 1 }); - auto step = NDArrayFactory::create('c', { 1, 1 }); - start.p(0, 1.f); - stop.p(0, 5.f); - step.p(0, 1.f); + auto start = NDArrayFactory::create('c', {1, 1}); + auto stop = NDArrayFactory::create('c', {1, 1}); + auto step = NDArrayFactory::create('c', {1, 1}); + start.p(0, 1.f); + stop.p(0, 5.f); + step.p(0, 1.f); - sd::ops::range op; + sd::ops::range op; - auto result = op.evaluate({ &start, &stop, &step }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&start, &stop, &step}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(1, result.size()); + ASSERT_EQ(1, result.size()); - auto array = result.at(0); + auto array = result.at(0); - ASSERT_TRUE(exp.isSameShape(array)); - ASSERT_TRUE(exp.equalsTo(array)); + ASSERT_TRUE(exp.isSameShape(array)); + ASSERT_TRUE(exp.equalsTo(array)); } - TEST_F(DeclarableOpsTests1, Test_Range_Integer_3) { - auto exp = NDArrayFactory::create('c', { 4 }); - exp.linspace(1); + auto exp = NDArrayFactory::create('c', {4}); + exp.linspace(1); - sd::ops::range op; + sd::ops::range op; - auto result = op.evaluate({}, { 1.f, 5.f, 1.f }, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({}, {1.f, 5.f, 1.f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(1, result.size()); + ASSERT_EQ(1, result.size()); - auto array = result.at(0); + auto array = result.at(0); - ASSERT_TRUE(exp.isSameShape(array)); - ASSERT_TRUE(exp.equalsTo(array)); + ASSERT_TRUE(exp.isSameShape(array)); + ASSERT_TRUE(exp.equalsTo(array)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test1) { + NDArray input('c', {3, 3}, {-1.f, 1.f, -2.f, 2.f, -3.f, 3.f, -4.f, 4.f, 5.f}, + sd::DataType::FLOAT32); - NDArray input('c', { 3, 3 }, { -1.f, 1.f, -2.f, 2.f, -3.f, 3.f, -4.f, 4.f, 5.f }, sd::DataType::FLOAT32); + NDArray expOutput('c', {3, 3}, + {1.14195199e-01, 8.43794734e-01, 4.20100661e-02, + 2.68454951e-01, 1.80883523e-03, 7.29736214e-01, + 9.02116571e-05, 2.68917160e-01, 7.30992629e-01}, + sd::DataType::FLOAT32); - NDArray expOutput('c', { 3, 3 }, { 1.14195199e-01, 8.43794734e-01, 4.20100661e-02, 2.68454951e-01, 1.80883523e-03, 7.29736214e-01, 9.02116571e-05, 2.68917160e-01, 7.30992629e-01 }, sd::DataType::FLOAT32); + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {}, {}); + auto z = results.at(0); - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, {}, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test2) { - NDArray input('c', { 3, 3, 3 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 3, 3, 3 }, { 4.73142e-02,4.73847e-02,6.69062e-03, 9.50330e-01,8.67881e-04,9.92976e-01, 2.35563e-03,9.51747e-01,3.33106e-04, 4.74259e-02,2.26032e-06,4.74259e-02, 2.91395e-07,9.99998e-01,3.94360e-08, 9.52574e-01,1.12535e-07,9.52574e-01, 7.58256e-10,4.74259e-02,1.22325e-11, 1.00000e+00,1.32293e-11,1.19203e-01, 3.77513e-11,9.52574e-01,8.80797e-01 }, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 1 }, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + NDArray input('c', {3, 3, 3}, + {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, + -8, 8, -9, 9, -10, 10, -11, 11, -12, 12, -13, 13, 14}, + sd::DataType::FLOAT32); + NDArray expOutput( + 'c', {3, 3, 3}, + {4.73142e-02, 4.73847e-02, 6.69062e-03, 9.50330e-01, 8.67881e-04, + 9.92976e-01, 2.35563e-03, 9.51747e-01, 3.33106e-04, 4.74259e-02, + 2.26032e-06, 4.74259e-02, 2.91395e-07, 9.99998e-01, 3.94360e-08, + 9.52574e-01, 1.12535e-07, 9.52574e-01, 7.58256e-10, 4.74259e-02, + 1.22325e-11, 1.00000e+00, 1.32293e-11, 1.19203e-01, 3.77513e-11, + 9.52574e-01, 8.80797e-01}, + sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {1}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test3) { - NDArray input('c', { 3, 3, 3 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 3, 3, 3 }, { 2.47262e-03,1.23395e-04,3.35350e-04, 1.23395e-04,4.53979e-05,1.23395e-04, 6.14417e-06,1.23395e-04,5.56530e-09, 9.97527e-01,1.12521e-07,9.99665e-01, 1.52281e-08,9.99955e-01,2.06090e-09, 9.99994e-01,2.78912e-10,6.69285e-03, 3.05146e-07,9.99876e-01,4.13855e-08, 9.99877e-01,5.60254e-09,9.99877e-01, 7.58251e-10,9.99877e-01,9.93307e-01 }, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 0 }, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + NDArray input('c', {3, 3, 3}, + {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, + -8, 8, -9, 9, -10, 10, -11, 11, -12, 12, -13, 13, 14}, + sd::DataType::FLOAT32); + NDArray expOutput( + 'c', {3, 3, 3}, + {2.47262e-03, 1.23395e-04, 3.35350e-04, 1.23395e-04, 4.53979e-05, + 1.23395e-04, 6.14417e-06, 1.23395e-04, 5.56530e-09, 9.97527e-01, + 1.12521e-07, 9.99665e-01, 1.52281e-08, 9.99955e-01, 2.06090e-09, + 9.99994e-01, 2.78912e-10, 6.69285e-03, 3.05146e-07, 9.99876e-01, + 4.13855e-08, 9.99877e-01, 5.60254e-09, 9.99877e-01, 7.58251e-10, + 9.99877e-01, 9.93307e-01}, + sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {0}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test4) { - NDArray input('c', { 1, 5 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 1, 5 }, { 0.01198,0.08855,0.00441,0.24072,0.65434 }, sd::DataType::FLOAT32); + NDArray input('c', {1, 5}, {-1, 1, -2, 2, 3}, sd::DataType::FLOAT32); + NDArray expOutput('c', {1, 5}, {0.01198, 0.08855, 0.00441, 0.24072, 0.65434}, + sd::DataType::FLOAT32); - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 1 }, {}); - auto z = results.at(0); + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {1}, {}); + auto z = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test5) { - NDArray input('c', { 1, 5 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 1, 5 }, { 1,1,1,1,1 }, sd::DataType::FLOAT32); + NDArray input('c', {1, 5}, {-1, 1, -2, 2, 3}, sd::DataType::FLOAT32); + NDArray expOutput('c', {1, 5}, {1, 1, 1, 1, 1}, sd::DataType::FLOAT32); - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 0 }); - auto z = results.at(0); + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {0}); + auto z = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test6) { - NDArray input('c', { 5, 1 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 5, 1 }, { 0.01198,0.08855,0.00441,0.24072,0.65434 }, sd::DataType::FLOAT32); + NDArray input('c', {5, 1}, {-1, 1, -2, 2, 3}, sd::DataType::FLOAT32); + NDArray expOutput('c', {5, 1}, {0.01198, 0.08855, 0.00441, 0.24072, 0.65434}, + sd::DataType::FLOAT32); - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 0 }, {}); - auto z = results.at(0); + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {0}, {}); + auto z = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test7) { - NDArray input('c', { 5, 1 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 5, 1 }, { 1,1,1,1,1 }, sd::DataType::FLOAT32); + NDArray input('c', {5, 1}, {-1, 1, -2, 2, 3}, sd::DataType::FLOAT32); + NDArray expOutput('c', {5, 1}, {1, 1, 1, 1, 1}, sd::DataType::FLOAT32); - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 1 }, {}); - auto z = results.at(0); + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {1}, {}); + auto z = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test8) { - NDArray input('c', { 5 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 5 }, { 0.01198,0.08855,0.00441,0.24072,0.65434 }, sd::DataType::FLOAT32); + NDArray input('c', {5}, {-1, 1, -2, 2, 3}, sd::DataType::FLOAT32); + NDArray expOutput('c', {5}, {0.01198, 0.08855, 0.00441, 0.24072, 0.65434}, + sd::DataType::FLOAT32); - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, {}, {}); - auto z = results.at(0); + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {}, {}); + auto z = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test9) { - NDArray input('c', { 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 2, 2, 2, 2 }, { 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059 }, sd::DataType::FLOAT32); + NDArray input('c', {2, 2, 2, 2}, + {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, -8, 8}, + sd::DataType::FLOAT32); + NDArray expOutput('c', {2, 2, 2, 2}, + {0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, + 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, + 0.731059, 0.268941, 0.268941, 0.731059}, + sd::DataType::FLOAT32); - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 2 }, {}); - auto z = results.at(0); + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {2}, {}); + auto z = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test10) { - NDArray input('c', { 2, 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14, -14, 15, -15, 16,-16 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 2, 2, 2, 2, 2 }, { 0.119203, 0.880797, 0.017986, 0.982014, 0.002473, 0.997527, 0.000335, 0.999665, 0.000045, 0.999955, 0.000006, 0.999994, 0.000001, 0.999999, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.00000 }, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 4 }, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + NDArray input( + 'c', {2, 2, 2, 2, 2}, + {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, -8, 8, + -9, 9, -10, 10, -11, 11, -12, 12, -13, 13, 14, -14, 15, -15, 16, -16}, + sd::DataType::FLOAT32); + NDArray expOutput( + 'c', {2, 2, 2, 2, 2}, + {0.119203, 0.880797, 0.017986, 0.982014, 0.002473, 0.997527, 0.000335, + 0.999665, 0.000045, 0.999955, 0.000006, 0.999994, 0.000001, 0.999999, + 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, + 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 1.000000, 0.000000, + 1.000000, 0.000000, 1.000000, 0.00000}, + sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {4}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test11) { - NDArray input('c', { 2, 2, 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14, -14, 15, -15, 16,-16, -2.1, 2.1, -2.2, 2.2, -2.3, 2.3, -2.4, 2.4, -2.5,2.5 ,-2.6,2.6, -2.7,2.7, -2.8,2.8, -2.9,2.9, -3.0,3.0, -3.1,3.1, -3.2,3.2, -3.3,3.3, 3.4, -3.4, 3.5, -3.5, 3.6,-3.6 }, sd::DataType::FLOAT32); - NDArray expOutput('c', { 2, 2, 2, 2, 2, 2 }, { 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.000000, 1.000000, 1.000000, 0.000000, 0.268941, 0.731059, 0.731059, 0.268941, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.001229, 0.998771, 0.998771, 0.001229, 0.475021, 0.524979, 0.524979, 0.475021 }, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 4 }, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + NDArray input( + 'c', {2, 2, 2, 2, 2, 2}, + {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, + 6, -7, 7, -8, 8, -9, 9, -10, 10, -11, 11, + -12, 12, -13, 13, 14, -14, 15, -15, 16, -16, -2.1, + 2.1, -2.2, 2.2, -2.3, 2.3, -2.4, 2.4, -2.5, 2.5, -2.6, 2.6, + -2.7, 2.7, -2.8, 2.8, -2.9, 2.9, -3.0, 3.0, -3.1, 3.1, -3.2, + 3.2, -3.3, 3.3, 3.4, -3.4, 3.5, -3.5, 3.6, -3.6}, + sd::DataType::FLOAT32); + NDArray expOutput( + 'c', {2, 2, 2, 2, 2, 2}, + {0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, + 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, + 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, + 0.268941, 0.268941, 0.731059, 0.000000, 1.000000, 1.000000, 0.000000, + 0.268941, 0.731059, 0.731059, 0.268941, 0.524979, 0.475021, 0.475021, + 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, + 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, + 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, + 0.001229, 0.998771, 0.998771, 0.001229, 0.475021, 0.524979, 0.524979, + 0.475021}, + sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {4}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, softmax_test12) { - NDArray input('f', { 2, 2, 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14, -14, 15, -15, 16,-16, -2.1, 2.1, -2.2, 2.2, -2.3, 2.3, -2.4, 2.4, -2.5,2.5 ,-2.6,2.6, -2.7,2.7, -2.8,2.8, -2.9,2.9, -3.0,3.0, -3.1,3.1, -3.2,3.2, -3.3,3.3, 3.4, -3.4, 3.5, -3.5, 3.6,-3.6 }, sd::DataType::FLOAT32); - NDArray exp('c', { 2, 2, 2, 2, 2, 2 }, { 0.982014, 0.598688, 0.982014, 0.598688, 0.017986, 0.401312, 0.017986, 0.401312, 0.982014, 0.598688, 0.000000, 0.001359, 0.017986, 0.401312, 1.000000, 0.998641, 0.982014, 0.598688, 0.000000, 0.001659, 0.017986, 0.401312, 1.000000, 0.998341, 0.982014, 0.598688, 0.000000, 0.001113, 0.017986, 0.401312, 1.000000, 0.998887, 0.017986, 0.401312, 0.017986, 0.401312, 0.982014, 0.598688, 0.982014, 0.598688, 0.017986, 0.401312, 1.000000, 0.998641, 0.982014, 0.598688, 0.000000, 0.001359, 0.017986, 0.401312, 1.000000, 0.998341, 0.982014, 0.598688, 0.000000, 0.001659, 0.017986, 0.401312, 1.000000, 0.998887, 0.982014, 0.598688, 0.000000, 0.001113 }, sd::DataType::FLOAT32); - - auto expOutput = NDArray('f', { 2, 2, 2, 2, 2, 2 }, sd::DataType::FLOAT32); - expOutput.assign(exp); - - sd::ops::softmax op; - auto results = op.evaluate({ &input }, {}, { 3 }, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + NDArray input( + 'f', {2, 2, 2, 2, 2, 2}, + {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, + 6, -7, 7, -8, 8, -9, 9, -10, 10, -11, 11, + -12, 12, -13, 13, 14, -14, 15, -15, 16, -16, -2.1, + 2.1, -2.2, 2.2, -2.3, 2.3, -2.4, 2.4, -2.5, 2.5, -2.6, 2.6, + -2.7, 2.7, -2.8, 2.8, -2.9, 2.9, -3.0, 3.0, -3.1, 3.1, -3.2, + 3.2, -3.3, 3.3, 3.4, -3.4, 3.5, -3.5, 3.6, -3.6}, + sd::DataType::FLOAT32); + NDArray exp( + 'c', {2, 2, 2, 2, 2, 2}, + {0.982014, 0.598688, 0.982014, 0.598688, 0.017986, 0.401312, 0.017986, + 0.401312, 0.982014, 0.598688, 0.000000, 0.001359, 0.017986, 0.401312, + 1.000000, 0.998641, 0.982014, 0.598688, 0.000000, 0.001659, 0.017986, + 0.401312, 1.000000, 0.998341, 0.982014, 0.598688, 0.000000, 0.001113, + 0.017986, 0.401312, 1.000000, 0.998887, 0.017986, 0.401312, 0.017986, + 0.401312, 0.982014, 0.598688, 0.982014, 0.598688, 0.017986, 0.401312, + 1.000000, 0.998641, 0.982014, 0.598688, 0.000000, 0.001359, 0.017986, + 0.401312, 1.000000, 0.998341, 0.982014, 0.598688, 0.000000, 0.001659, + 0.017986, 0.401312, 1.000000, 0.998887, 0.982014, 0.598688, 0.000000, + 0.001113}, + sd::DataType::FLOAT32); + + auto expOutput = NDArray('f', {2, 2, 2, 2, 2, 2}, sd::DataType::FLOAT32); + expOutput.assign(exp); + + sd::ops::softmax op; + auto results = op.evaluate({&input}, {}, {3}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_1) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13., + 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1.}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13., 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1. }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {0, 1, 2}); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 0,1,2 }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_2) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, {}, {}, {}, true); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {}, {}, {}, true); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(input)); - ASSERT_TRUE(expected.equalsTo(&input)); + ASSERT_TRUE(expected.isSameShapeStrict(input)); + ASSERT_TRUE(expected.equalsTo(&input)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_3) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {12., 11., 10., 9., 8., 7., 6., 5., + 4., 3., 2., 1., 24., 23., 22., 21., + 20., 19., 18., 17., 16., 15., 14., 13.}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1., 24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13. }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 1,2 }); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {1, 2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result->printBuffer(); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_4) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = { + 16, 15, 14, 13, 20, 19, 18, 17, 24, 23, 22, 21, + 4, 3, 2, 1, 8, 7, 6, 5, 12, 11, 10, 9, + }; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 16,15,14,13,20,19,18,17,24,23,22,21,4,3,2,1,8,7,6,5,12,11,10,9, }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 0,2 }); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {0, 2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result->printBuffer(); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_5) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {21., 22., 23., 24., 17., 18., 19., 20., 13., 14., 15., 16., + 9., 10., 11., 12., 5., 6., 7., 8., 1., 2., 3., 4.}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 21., 22., 23., 24., 17., 18., 19., 20., 13., 14., 15., 16., 9., 10., 11., 12., 5., 6., 7., 8., 1., 2., 3., 4. }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 0,1 }); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_6) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {4., 3., 2., 1., 8., 7., 6., 5., + 12., 11., 10., 9., 16., 15., 14., 13., + 20., 19., 18., 17., 24., 23., 22., 21.}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 4., 3., 2., 1., 8., 7., 6., 5., 12., 11., 10., 9., 16., 15., 14., 13., 20., 19., 18., 17., 24., 23., 22., 21. }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 2 }, {}, {}, true); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {2}, {}, {}, true); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result->printBuffer(); + auto result = results.at(0); + // result->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(input)); - ASSERT_TRUE(expected.equalsTo(&input)); + ASSERT_TRUE(expected.isSameShapeStrict(input)); + ASSERT_TRUE(expected.equalsTo(&input)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_7) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {9., 10., 11., 12., 5., 6., 7., 8., + 1., 2., 3., 4., 21., 22., 23., 24., + 17., 18., 19., 20., 13., 14., 15., 16.}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 9., 10., 11., 12., 5., 6., 7., 8., 1., 2., 3., 4., 21., 22., 23., 24., 17., 18., 19., 20., 13., 14., 15., 16. }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 1 }); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - //expected.printIndexedBuffer("E"); - //result->printIndexedBuffer("R"); + auto result = results.at(0); + // expected.printIndexedBuffer("E"); + // result->printIndexedBuffer("R"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_8) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {12., 11., 10., 9., 8., 7., 6., 5., + 4., 3., 2., 1., 24., 23., 22., 21., + 20., 19., 18., 17., 16., 15., 14., 13.}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1., 24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13. }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 2,1 }); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {2, 1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result->printBuffer(); + auto result = results.at(0); + // result->printBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_9) { + float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {13., 14., 15., 16., 17., 18., 19., 20., + 21., 22., 23., 24., 1., 2., 3., 4., + 5., 6., 7., 8., 9., 10., 11., 12.}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); - float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; - float expBuff[] = { 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }; - Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; - ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); - NDArray input(inBuff, shapeInfo); - NDArray expected(expBuff, shapeInfo); - NDArray output(shapeInfo); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {0}); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 0 }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests1, Reverse_10) { - auto x = NDArrayFactory::create('c', { 4, 3 }, { 1.5375735, 0.1592365, 0.09966054, 0.677872, 1.144433, -1.0355669, 0.48456487, -0.67863184, 0.85020787, 0.13950661, 0.20998026, -1.1660044 }); - auto i = NDArrayFactory::create('c', { 1 }, { -1 }); - auto e = NDArrayFactory::create('c', { 4, 3 }, { 0.09966054, 0.1592365, 1.5375735, -1.0355669, 1.144433, 0.677872,0.85020787, -0.67863184, 0.48456487, -1.1660044, 0.20998026, 0.13950661 }); + auto x = NDArrayFactory::create( + 'c', {4, 3}, + {1.5375735, 0.1592365, 0.09966054, 0.677872, 1.144433, -1.0355669, + 0.48456487, -0.67863184, 0.85020787, 0.13950661, 0.20998026, + -1.1660044}); + auto i = NDArrayFactory::create('c', {1}, {-1}); + auto e = NDArrayFactory::create( + 'c', {4, 3}, + {0.09966054, 0.1592365, 1.5375735, -1.0355669, 1.144433, 0.677872, + 0.85020787, -0.67863184, 0.48456487, -1.1660044, 0.20998026, + 0.13950661}); - sd::ops::reverse op; - auto result = op.evaluate({ &x, &i }, {}, {}, {}); + sd::ops::reverse op; + auto result = op.evaluate({&x, &i}, {}, {}, {}); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_11) { + auto input = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {24.f, 23.f, 22.f, 21.f, 20.f, 19.f, 18.f, 17.f, 16.f, 15.f, 14.f, 13.f, + 12.f, 11.f, 10.f, 9.f, 8.f, 7.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f}); + input.linspace(1); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {0, 1, 2}); - auto input = NDArrayFactory::create('c', { 2,3,4 }); - auto expected = NDArrayFactory::create('c', { 2,3,4 }, { 24.f, 23.f, 22.f, 21.f, 20.f, 19.f, 18.f, 17.f, 16.f, - 15.f, 14.f, 13.f, 12.f, 11.f, 10.f, 9.f, 8.f, 7.f, - 6.f, 5.f, 4.f, 3.f, 2.f, 1.f }); - - input.linspace(1); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 0, 1, 2 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_12) { + auto input = NDArrayFactory::create({0.f, 1.f, 2.f, 3.f, 4.f}); + auto expected = NDArrayFactory::create({4.f, 3.f, 2.f, 1.f, 0.f}); + // input.linspace(1); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {0}); - auto input = NDArrayFactory::create({ 0.f, 1.f, 2.f, 3.f, 4.f }); - auto expected = NDArrayFactory::create({ 4.f, 3.f, 2.f, 1.f, 0.f }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - //input.linspace(1); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { 0 }); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - //result->printIndexedBuffer("Result reverse"); - //expected.printIndexedBuffer("Expected reverse"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result->printIndexedBuffer("Result reverse"); + // expected.printIndexedBuffer("Expected reverse"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_13) { + auto input = NDArrayFactory::create({0.f, 1.f, 2.f, 3.f, 4.f}); + auto expected = NDArrayFactory::create({4.f, 3.f, 2.f, 1.f, 0.f}); + // input.linspace(1); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {-1}); - auto input = NDArrayFactory::create({ 0.f, 1.f, 2.f, 3.f, 4.f }); - auto expected = NDArrayFactory::create({ 4.f, 3.f, 2.f, 1.f, 0.f }); - - //input.linspace(1); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, { -1 }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Reverse_14) { + auto input = NDArrayFactory::create({0.f, 1.f, 2.f, 3.f, 4.f}); + auto expected = NDArrayFactory::create({0.f, 1.f, 2.f, 3.f, 4.f}); + // input.linspace(1); + sd::ops::reverse op; + auto results = op.evaluate({&input}, {}, {}, {}); - auto input = NDArrayFactory::create({ 0.f, 1.f, 2.f, 3.f, 4.f }); - auto expected = NDArrayFactory::create({ 0.f, 1.f, 2.f, 3.f, 4.f }); - - //input.linspace(1); - sd::ops::reverse op; - auto results = op.evaluate({ &input }, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests1, Test_Expose_1) { - auto input0 = NDArrayFactory::create('c', { 2, 3 }, { 1, 2, 3, 6, 5, 4 }); - auto input1 = NDArrayFactory::create('c', { 2, 3 }, { 3, 2, 1, 4, 5, 6 }); + auto input0 = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 6, 5, 4}); + auto input1 = NDArrayFactory::create('c', {2, 3}, {3, 2, 1, 4, 5, 6}); - sd::ops::expose op; + sd::ops::expose op; - auto result = op.evaluate({ &input0, &input1 }); + auto result = op.evaluate({&input0, &input1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z0 = result.at(0); - auto z1 = result.at(1); + auto z0 = result.at(0); + auto z1 = result.at(1); - ASSERT_TRUE(input0.equalsTo(z0)); - ASSERT_TRUE(input1.equalsTo(z1)); + ASSERT_TRUE(input0.equalsTo(z0)); + ASSERT_TRUE(input1.equalsTo(z1)); } TEST_F(DeclarableOpsTests1, Test_Expose_2) { - if (1 > 0) - throw std::runtime_error("Test not implemented yet"); - - auto list = new NDArrayList(0, true); + if (1 > 0) throw std::runtime_error("Test not implemented yet"); - auto var = std::make_shared(NDArray(), "arraylist", -1, 0); - //var->setNDArrayList(list); + auto list = new NDArrayList(0, true); - VariableSpace variableSpace; - variableSpace.putVariable(-1, var); + auto var = std::make_shared(NDArray(), "arraylist", -1, 0); + // var->setNDArrayList(list); - Context block(1, &variableSpace); - block.pickInput(-1); + VariableSpace variableSpace; + variableSpace.putVariable(-1, var); - sd::ops::expose op; - auto result = op.execute(&block); + Context block(1, &variableSpace); + block.pickInput(-1); - ASSERT_EQ(ND4J_STATUS_OK, result); - ASSERT_TRUE(variableSpace.hasVariable(1)); + sd::ops::expose op; + auto result = op.execute(&block); - auto var1 = variableSpace.getVariable(1); + ASSERT_EQ(ND4J_STATUS_OK, result); + ASSERT_TRUE(variableSpace.hasVariable(1)); - ASSERT_EQ(var->variableType(), var1->variableType()); + auto var1 = variableSpace.getVariable(1); - auto list1 = var1->getNDArrayList(); + ASSERT_EQ(var->variableType(), var1->variableType()); - ASSERT_TRUE(list == list1.get()); + auto list1 = var1->getNDArrayList(); + ASSERT_TRUE(list == list1.get()); } TEST_F(DeclarableOpsTests1, Test_Release) { - auto x = NDArrayFactory::create('c', { 8, 8 }); - // x.printShapeInfo("x shape"); + auto x = NDArrayFactory::create('c', {8, 8}); + // x.printShapeInfo("x shape"); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index e2e3d17cc7a6..d28022f05932 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -15,3133 +15,3126 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // Created by raver on 8/4/2018. // -#include "testlayers.h" -#include #include -#include #include +#include +#include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTests10 : public testing::Test { -public: - - DeclarableOpsTests10() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests10() { + printf("\n"); + fflush(stdout); + } }; template class TypedDeclarableOpsTests10 : public testing::Test { -public: - - TypedDeclarableOpsTests10() { - printf("\n"); - fflush(stdout); - } + public: + TypedDeclarableOpsTests10() { + printf("\n"); + fflush(stdout); + } }; typedef ::testing::Types TestingTypes; TYPED_TEST_CASE(TypedDeclarableOpsTests10, TestingTypes); TEST_F(DeclarableOpsTests10, Test_ArgMax_1) { - auto x = NDArrayFactory::create('c', {3, 3}); - auto e = NDArrayFactory::create(8); - - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {3, 3}); + auto e = NDArrayFactory::create(8); + x.linspace(1.0); - sd::ops::argmax op; - auto result = op.evaluate({&x}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::argmax op; + auto result = op.evaluate({&x}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests10, Test_ArgMax_2) { - auto x = NDArrayFactory::create('c', {3, 3}); - auto y = NDArrayFactory::create('c', {1}, {1}); - auto e = NDArrayFactory::create('c', {3}, {2, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 3}); + auto y = NDArrayFactory::create('c', {1}, {1}); + auto e = NDArrayFactory::create('c', {3}, {2, 2, 2}); - x.linspace(1.0); + x.linspace(1.0); - sd::ops::argmax op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::argmax op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - //z.printIndexedBuffer("z"); - //z.printShapeInfo("z shape"); + // z.printIndexedBuffer("z"); + // z.printShapeInfo("z shape"); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests10, Test_And_1) { - auto x = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); - auto y = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); - auto e = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); + auto x = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); + auto y = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); + auto e = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); - sd::ops::boolean_and op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::boolean_and op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, result.at(0)); + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests10, Test_Or_1) { - auto x = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); - auto y = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); - auto e = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); + auto x = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); + auto y = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); + auto e = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); - sd::ops::boolean_or op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::boolean_or op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, result.at(0)); + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests10, Test_Not_1) { - auto x = NDArrayFactory::create('c', {4}, {true, true, false, true}); - auto y = NDArrayFactory::create('c', {4}, {false, false, false, true}); -// auto e = NDArrayFactory::create('c', {4}, {1, 1, 1, 0}); - auto e = NDArrayFactory::create('c', {4}, {false, false, true, false}); + auto x = NDArrayFactory::create('c', {4}, {true, true, false, true}); + auto y = NDArrayFactory::create('c', {4}, {false, false, false, true}); + // auto e = NDArrayFactory::create('c', {4}, {1, 1, 1, 0}); + auto e = NDArrayFactory::create('c', {4}, {false, false, true, false}); - sd::ops::boolean_not op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); - auto res = result.at(0); + sd::ops::boolean_not op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); - ASSERT_TRUE(e.equalsTo(res)); + ASSERT_TRUE(e.equalsTo(res)); } TEST_F(DeclarableOpsTests10, Test_Size_at_1) { - auto x = NDArrayFactory::create('c', {10, 20, 30}); - auto e = NDArrayFactory::create(20); - - sd::ops::size_at op; - auto result = op.evaluate({&x}, {1}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {10, 20, 30}); + auto e = NDArrayFactory::create(20); - ASSERT_EQ(e, result.at(0)); + sd::ops::size_at op; + auto result = op.evaluate({&x}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(e, result.at(0)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, MirrorPad_SGO_Test_1) { + auto in = NDArrayFactory::create({1., 2., 3., 4., 5.}); + // auto pad('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new + // long[]{1, 2}); + auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); + // auto value(10.0); - auto in = NDArrayFactory::create({1., 2., 3., 4., 5.}); -// auto pad('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); - auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); -// auto value(10.0); + auto exp = NDArrayFactory::create({2., 1., 2., 3., 4., 5., 4.}); - auto exp = NDArrayFactory::create({2., 1., 2., 3., 4., 5., 4.}); + sd::ops::mirror_pad op; - sd::ops::mirror_pad op; + auto res = op.evaluate({&in, &pad}, {10.0}, {0}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto res = op.evaluate({&in, &pad}, {10.0}, {0}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - - ASSERT_TRUE(exp.equalsTo(res.at(0))); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Unique_SGO_Test_1) { - auto input = NDArrayFactory::create({3., 4., 3., 1., 3., 0., 2., 4., 2., 4.}); - auto expIdx = NDArrayFactory::create({0, 1, 0, 2, 0, 3, 4, 1, 4, 1}); - auto exp = NDArrayFactory::create({3., 4., 1., 0., 2.}); - - sd::ops::unique op; - auto res = op.evaluate({&input}, {}, {}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto res1 = res.at(0); - auto res2 = res.at(1); + auto input = + NDArrayFactory::create({3., 4., 3., 1., 3., 0., 2., 4., 2., 4.}); + auto expIdx = + NDArrayFactory::create({0, 1, 0, 2, 0, 3, 4, 1, 4, 1}); + auto exp = NDArrayFactory::create({3., 4., 1., 0., 2.}); - ASSERT_TRUE(exp.equalsTo(res1)); - ASSERT_TRUE(expIdx.equalsTo(res2)); + sd::ops::unique op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto res1 = res.at(0); + auto res2 = res.at(1); + ASSERT_TRUE(exp.equalsTo(res1)); + ASSERT_TRUE(expIdx.equalsTo(res2)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Where_SGO_Test_1) { - auto input = NDArrayFactory::create('c', {3, 3}, {true, false, false, true, true, false, true, true, true}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp = NDArrayFactory::create('c', {6, 2}, {0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 2LL, 0LL, 2LL, 1LL, 2LL, 2LL}); + auto input = NDArrayFactory::create( + 'c', {3, 3}, {true, false, false, true, true, false, true, true, true}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = NDArrayFactory::create( + 'c', {6, 2}, + {0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 2LL, 0LL, 2LL, 1LL, 2LL, 2LL}); - sd::ops::Where op; - auto res = op.evaluate({&input}, {}, {}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); - - ASSERT_TRUE(exp.isSameShape(resA)); - ASSERT_TRUE(exp.equalsTo(resA)); -// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + sd::ops::Where op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + ASSERT_TRUE(exp.isSameShape(resA)); + ASSERT_TRUE(exp.equalsTo(resA)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Where_SGO_Test_02) { - auto input = NDArrayFactory::create('c', {2, 2, 2}, {true, false, false, true, true, true, true, false}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp = NDArrayFactory::create('c', {5, 3}, {0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 1LL, 0LL}); - - sd::ops::Where op; - auto res = op.evaluate({&input}, {}, {}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); + auto input = NDArrayFactory::create( + 'c', {2, 2, 2}, {true, false, false, true, true, true, true, false}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = + NDArrayFactory::create('c', {5, 3}, + {0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 0LL, + 0LL, 1LL, 0LL, 1LL, 1LL, 1LL, 0LL}); - ASSERT_TRUE(exp.equalsTo(resA)); - ASSERT_TRUE(exp.isSameShape(resA)); -// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + sd::ops::Where op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + ASSERT_TRUE(exp.equalsTo(resA)); + ASSERT_TRUE(exp.isSameShape(resA)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_1) { - auto cond3d = NDArrayFactory::create('c', {2, 2, 2}, {true, false, false, true, true, true, true, false}); -// auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp1 = NDArrayFactory::create({0, 0, 1, 1, 1}); - auto exp2 = NDArrayFactory::create({0, 1, 0, 0, 1}); - auto exp3 = NDArrayFactory::create({0, 1, 0, 1, 0}); - sd::ops::where_np op; - auto res = op.evaluate({&cond3d}, {}, {}); - ASSERT_TRUE(res.size() == 3); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto res1 = res.at(0); - auto res2 = res.at(1); - auto res3 = res.at(2); -// res1->printShapeInfo("Res1 shape"); res1->printBuffer("Res1"); -// res2->printShapeInfo("Res2 shape"); res2->printBuffer("Res2"); -// res3->printShapeInfo("Res3 shape"); res3->printBuffer("Res3"); - ASSERT_TRUE(exp1.equalsTo(res1)); - ASSERT_TRUE(exp2.equalsTo(res2)); - ASSERT_TRUE(exp3.equalsTo(res3)); - //ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + auto cond3d = NDArrayFactory::create( + 'c', {2, 2, 2}, {true, false, false, true, true, true, true, false}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp1 = NDArrayFactory::create({0, 0, 1, 1, 1}); + auto exp2 = NDArrayFactory::create({0, 1, 0, 0, 1}); + auto exp3 = NDArrayFactory::create({0, 1, 0, 1, 0}); + sd::ops::where_np op; + auto res = op.evaluate({&cond3d}, {}, {}); + ASSERT_TRUE(res.size() == 3); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto res1 = res.at(0); + auto res2 = res.at(1); + auto res3 = res.at(2); + // res1->printShapeInfo("Res1 shape"); res1->printBuffer("Res1"); + // res2->printShapeInfo("Res2 shape"); res2->printBuffer("Res2"); + // res3->printShapeInfo("Res3 shape"); res3->printBuffer("Res3"); + ASSERT_TRUE(exp1.equalsTo(res1)); + ASSERT_TRUE(exp2.equalsTo(res2)); + ASSERT_TRUE(exp3.equalsTo(res3)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) { - auto cond2d = NDArrayFactory::create('c', {3, 5}, {true, true, false, false, true, true, true, - true, true, true, false, true, true, true, true}); -// auto expIdx({0, 1, 0, 2, 0, 3, 4, 1, 4, 1}); - auto exp1 = NDArrayFactory::create({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2}); - auto exp2 = NDArrayFactory::create({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4}); - sd::ops::where_np op; - auto res = op.evaluate({&cond2d}, {}, {}); - ASSERT_TRUE(res.size() == 2); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - ASSERT_TRUE(exp1.equalsTo(res.at(0))); - ASSERT_TRUE(exp2.equalsTo(res.at(1))); - //ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + auto cond2d = NDArrayFactory::create( + 'c', {3, 5}, + {true, true, false, false, true, true, true, true, true, true, false, + true, true, true, true}); + // auto expIdx({0, 1, 0, 2, 0, 3, 4, 1, 4, 1}); + auto exp1 = + NDArrayFactory::create({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2}); + auto exp2 = + NDArrayFactory::create({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4}); + sd::ops::where_np op; + auto res = op.evaluate({&cond2d}, {}, {}); + ASSERT_TRUE(res.size() == 2); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(exp1.equalsTo(res.at(0))); + ASSERT_TRUE(exp2.equalsTo(res.at(1))); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Where_SGO_Test_2) { - auto input = NDArrayFactory::create({true, false, true, true, true}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp = NDArrayFactory::create('c', {4,1}, {0, 2, 3, 4}); - - sd::ops::Where op; - auto res = op.evaluate({&input}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); -// resA->printIndexedBuffer("Result A"); -// resA->printShapeInfo("ShapeA"); - ASSERT_TRUE(exp.equalsTo(resA)); - ASSERT_TRUE(exp.isSameShape(resA)); -// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + auto input = NDArrayFactory::create({true, false, true, true, true}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = NDArrayFactory::create('c', {4, 1}, {0, 2, 3, 4}); + sd::ops::Where op; + auto res = op.evaluate({&input}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + // resA->printIndexedBuffer("Result A"); + // resA->printShapeInfo("ShapeA"); + ASSERT_TRUE(exp.equalsTo(resA)); + ASSERT_TRUE(exp.isSameShape(resA)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Where_SGO_Test_3) { - auto input = NDArrayFactory::create('c', {5, 1}, {true, false, true, true, true}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp = NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); - - sd::ops::Where op; - auto res = op.evaluate({&input}, {}, {}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); - //resA->printIndexedBuffer("Result A"); - //resA->printShapeInfo("ShapeA"); - ASSERT_TRUE(exp.equalsTo(resA)); - ASSERT_TRUE(exp.isSameShape(resA)); -// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + auto input = NDArrayFactory::create('c', {5, 1}, + {true, false, true, true, true}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = + NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); + + sd::ops::Where op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + // resA->printIndexedBuffer("Result A"); + // resA->printShapeInfo("ShapeA"); + ASSERT_TRUE(exp.equalsTo(resA)); + ASSERT_TRUE(exp.isSameShape(resA)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Where_SGO_Test_4) { - auto input = NDArrayFactory::create('c', {5, 1}, {false, false, false, false, false}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp = NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); - - sd::ops::Where op; - auto res = op.evaluate({&input}, {}, {}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); - ASSERT_TRUE(resA.isEmpty()); - //resA->printIndexedBuffer("Result A"); - //resA->printShapeInfo("ShapeA"); - //ASSERT_TRUE(exp.equalsTo(resA)); - //ASSERT_TRUE(exp.isSameShape(resA)); -// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + auto input = NDArrayFactory::create( + 'c', {5, 1}, {false, false, false, false, false}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = + NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); + + sd::ops::Where op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + ASSERT_TRUE(resA.isEmpty()); + // resA->printIndexedBuffer("Result A"); + // resA->printShapeInfo("ShapeA"); + // ASSERT_TRUE(exp.equalsTo(resA)); + // ASSERT_TRUE(exp.isSameShape(resA)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Where_SGO_Test_5) { - auto input = NDArrayFactory::create('c', {5}, {1, 0, 0, 2, 3}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp = NDArrayFactory::create('c', {3, 1}, {0, 3, 4}); + auto input = NDArrayFactory::create('c', {5}, {1, 0, 0, 2, 3}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = NDArrayFactory::create('c', {3, 1}, {0, 3, 4}); - sd::ops::Where op; - auto res = op.evaluate({&input}, {}, {}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); - //ASSERT_TRUE(resA->isEmpty()); - - ASSERT_TRUE(exp.equalsTo(resA)); - ASSERT_TRUE(exp.isSameShape(resA)); -// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + sd::ops::Where op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + // ASSERT_TRUE(resA->isEmpty()); + ASSERT_TRUE(exp.equalsTo(resA)); + ASSERT_TRUE(exp.isSameShape(resA)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_4) { - auto input = NDArrayFactory::create('c', {5, 1}, {false, false, false, false, false}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto exp = NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); - - sd::ops::where_np op; - auto res = op.evaluate({&input}, {}, {}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); - ASSERT_TRUE(resA.isEmpty()); - //resA->printIndexedBuffer("Result A"); - //resA->printShapeInfo("ShapeA"); - //ASSERT_TRUE(exp.equalsTo(resA)); - //ASSERT_TRUE(exp.isSameShape(resA)); -// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + auto input = NDArrayFactory::create( + 'c', {5, 1}, {false, false, false, false, false}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = + NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); + + sd::ops::where_np op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + ASSERT_TRUE(resA.isEmpty()); + // resA->printIndexedBuffer("Result A"); + // resA->printShapeInfo("ShapeA"); + // ASSERT_TRUE(exp.equalsTo(resA)); + // ASSERT_TRUE(exp.isSameShape(resA)); + // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_1) { - auto labels = NDArrayFactory::create('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto predictions = NDArrayFactory::create('c', {2, 3}, {-0.3, -0.2, -0.1, 0, 0.1, 0.2}); - auto weights = NDArrayFactory::create('c', {2, 1}, {0., 1.}); - auto exp = NDArrayFactory::create(0.6); - - sd::ops::cosine_distance_loss op; - auto res = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); + auto labels = NDArrayFactory::create('c', {2, 3}, + {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto predictions = NDArrayFactory::create( + 'c', {2, 3}, {-0.3, -0.2, -0.1, 0, 0.1, 0.2}); + auto weights = NDArrayFactory::create('c', {2, 1}, {0., 1.}); + auto exp = NDArrayFactory::create(0.6); - ASSERT_TRUE(exp.equalsTo(resA)); + sd::ops::cosine_distance_loss op; + auto res = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + ASSERT_TRUE(exp.equalsTo(resA)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_2) { - auto labels = NDArrayFactory::create('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); - //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); - auto predictions = NDArrayFactory::create('c', {2, 3}, {-0.3, -0.2, -0.1, 0, 0.1, 0.2}); - auto weights = NDArrayFactory::create('c', {2, 1}, {0., 1.}); - auto exp = NDArrayFactory::create(0.6); + auto labels = NDArrayFactory::create('c', {2, 3}, + {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); + // auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto predictions = NDArrayFactory::create( + 'c', {2, 3}, {-0.3, -0.2, -0.1, 0, 0.1, 0.2}); + auto weights = NDArrayFactory::create('c', {2, 1}, {0., 1.}); + auto exp = NDArrayFactory::create(0.6); - sd::ops::cosine_distance_loss op; - auto res = op.evaluate({&predictions, &weights, &labels}, {}, {2, 1}); - ASSERT_TRUE(res.status() == ND4J_STATUS_OK); - auto resA = res.at(0); - - ASSERT_TRUE(exp.equalsTo(resA)); + sd::ops::cosine_distance_loss op; + auto res = op.evaluate({&predictions, &weights, &labels}, {}, {2, 1}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + ASSERT_TRUE(exp.equalsTo(resA)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, TestMarixBandPart_Test_1) { + auto x = NDArrayFactory::create('c', {2, 3, 3}); - auto x = NDArrayFactory::create('c', {2, 3, 3}); - - auto exp = NDArrayFactory::create('c', {2, 3, 3}); - x.linspace(1); - exp.linspace(1); - exp.p(0, 0, 2, 0.); - exp.p(1, 0, 2, 0.); - exp.p(0, 2, 0, 0.); - exp.p(1, 2, 0, 0.); + auto exp = NDArrayFactory::create('c', {2, 3, 3}); + x.linspace(1); + exp.linspace(1); + exp.p(0, 0, 2, 0.); + exp.p(1, 0, 2, 0.); + exp.p(0, 2, 0, 0.); + exp.p(1, 2, 0, 0.); - sd::ops::matrix_band_part op; - auto results = op.evaluate({&x}, {}, {1, 1}); + sd::ops::matrix_band_part op; + auto results = op.evaluate({&x}, {}, {1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - //results.at(0)->printIndexedBuffer("MBP Test1"); - //exp.printIndexedBuffer("MBP Expec"); - ASSERT_TRUE(exp.equalsTo(results.at(0))); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + // results.at(0)->printIndexedBuffer("MBP Test1"); + // exp.printIndexedBuffer("MBP Expec"); + ASSERT_TRUE(exp.equalsTo(results.at(0))); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, atan2_test1) { - - auto y = NDArrayFactory::create('c', {2, 3, 4}, {-1.001 ,-0.915 ,-0.829 ,-0.743 ,-0.657 ,-0.571 ,-0.485 ,-0.399 ,-0.313 ,-0.227 ,-0.141 ,-0.055 ,0.031 ,0.117 ,0.203 ,0.289 ,0.375 ,0.461 ,0.547 ,0.633 ,0.719 ,0.805 ,0.891 ,0.977}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-0.51, -0.46, -0.41, -0.36, -0.31, -0.26, -0.21, -0.16, -0.11, -0.06, -0.01, 0.04, 0.09, 0.14, 0.19, 0.24, 0.29, 0.34, 0.39, 0.44, 0.49, 0.54, 0.59, 0.61}); - - auto exp = NDArrayFactory::create('c', {2,3,4}, {-2.04201, -2.03663, -2.03009, -2.02199,-2.01166, -1.99808, -1.97941, -1.95217,-1.90875, -1.8292 , -1.6416 , -0.942 , - 0.33172, 0.69614, 0.81846, 0.87776, 0.91253, 0.93533, 0.95141, 0.96336, 0.97259, 0.97993, 0.98591, 1.01266,}); - - sd::ops::tf_atan2 op; - auto result = op.evaluate({&y, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto y = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-1.001, -0.915, -0.829, -0.743, -0.657, -0.571, -0.485, -0.399, + -0.313, -0.227, -0.141, -0.055, 0.031, 0.117, 0.203, 0.289, + 0.375, 0.461, 0.547, 0.633, 0.719, 0.805, 0.891, 0.977}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-0.51, -0.46, -0.41, -0.36, -0.31, -0.26, -0.21, -0.16, + -0.11, -0.06, -0.01, 0.04, 0.09, 0.14, 0.19, 0.24, + 0.29, 0.34, 0.39, 0.44, 0.49, 0.54, 0.59, 0.61}); + + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + { + -2.04201, -2.03663, -2.03009, -2.02199, -2.01166, -1.99808, + -1.97941, -1.95217, -1.90875, -1.8292, -1.6416, -0.942, + 0.33172, 0.69614, 0.81846, 0.87776, 0.91253, 0.93533, + 0.95141, 0.96336, 0.97259, 0.97993, 0.98591, 1.01266, + }); + + sd::ops::tf_atan2 op; + auto result = op.evaluate({&y, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, atan2_test2) { - - auto y = NDArrayFactory::create('c', {2, 3, 4}, {-1.001 ,-0.915 ,-0.829 ,-0.743 ,-0.657 ,-0.571 ,-0.485 ,-0.399 ,-0.313 ,-0.227 ,-0.141 ,-0.055 ,0.031 ,0.117 ,0.203 ,0.289 ,0.375 ,0.461 ,0.547 ,0.633 ,0.719 ,0.805 ,0.891 ,0.977}); - auto x = NDArrayFactory::create('c', { 3, 4}, {-1.05, -0.82, -0.639, -0.458, -0.277, -0.096, 0.085, 0.266, 0.447, 0.628, 0.809, 0.99}); - - auto exp = NDArrayFactory::create('c', {2,3,4}, {-2.38008, -2.30149, -2.22748, -2.1232 ,-1.96979, -1.73736, -1.3973 , -0.98279,-0.61088, -0.34685, -0.17256, -0.0555 , - 3.11208, 2.99987, 2.83399, 2.57869, 2.207 , 1.77611, 1.41664, 1.17298, 1.01458, 0.90829, 0.8336 , 0.77879}); - - sd::ops::tf_atan2 op; - auto result = op.evaluate({&y, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - // z->printIndexedBuffer(); - - // x.applyTrueBroadcast(sd::BroadcastOpsTuple::custom(scalar::Atan2, pairwise::Atan2, broadcast::Atan2), &y, &z, true); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto y = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-1.001, -0.915, -0.829, -0.743, -0.657, -0.571, -0.485, -0.399, + -0.313, -0.227, -0.141, -0.055, 0.031, 0.117, 0.203, 0.289, + 0.375, 0.461, 0.547, 0.633, 0.719, 0.805, 0.891, 0.977}); + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {-1.05, -0.82, -0.639, -0.458, -0.277, -0.096, 0.085, 0.266, 0.447, 0.628, + 0.809, 0.99}); + + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-2.38008, -2.30149, -2.22748, -2.1232, -1.96979, -1.73736, + -1.3973, -0.98279, -0.61088, -0.34685, -0.17256, -0.0555, + 3.11208, 2.99987, 2.83399, 2.57869, 2.207, 1.77611, + 1.41664, 1.17298, 1.01458, 0.90829, 0.8336, 0.77879}); + + sd::ops::tf_atan2 op; + auto result = op.evaluate({&y, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer(); + + // x.applyTrueBroadcast(sd::BroadcastOpsTuple::custom(scalar::Atan2, + // pairwise::Atan2, broadcast::Atan2), &y, &z, true); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, atan2_test3) { - - auto y = NDArrayFactory::create('c', {2, 3, 4}, {-1.001 ,-0.915 ,-0.829 ,-0.743 ,-0.657 ,-0.571 ,-0.485 ,-0.399 ,-0.313 ,-0.227 ,-0.141 ,-0.055 ,0.031 ,0.117 ,0.203 ,0.289 ,0.375 ,0.461 ,0.547 ,0.633 ,0.719 ,0.805 ,0.891 ,0.977}); - auto x = NDArrayFactory::create('c', { 3, 4}, {-1.05, -0.82, -0.639, -0.458, -0.277, -0.096, 0.085, 0.266, 0.447, 0.628, 0.809, 0.99}); - - auto exp = NDArrayFactory::create('c', {2,3,4}, {-2.33231, -2.41089, -2.48491, -2.58919,-2.74259, -2.97502, 2.9681 , 2.55359, 2.18167, 1.91765, 1.74335, 1.62629, - -1.54128, -1.42907, -1.2632 , -1.00789,-0.63621, -0.20531, 0.15416, 0.39782, 0.55622, 0.6625 , 0.7372 , 0.79201}); - - sd::ops::tf_atan2 op; - auto result = op.evaluate({&x, &y}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto y = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-1.001, -0.915, -0.829, -0.743, -0.657, -0.571, -0.485, -0.399, + -0.313, -0.227, -0.141, -0.055, 0.031, 0.117, 0.203, 0.289, + 0.375, 0.461, 0.547, 0.633, 0.719, 0.805, 0.891, 0.977}); + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {-1.05, -0.82, -0.639, -0.458, -0.277, -0.096, 0.085, 0.266, 0.447, 0.628, + 0.809, 0.99}); + + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-2.33231, -2.41089, -2.48491, -2.58919, -2.74259, -2.97502, + 2.9681, 2.55359, 2.18167, 1.91765, 1.74335, 1.62629, + -1.54128, -1.42907, -1.2632, -1.00789, -0.63621, -0.20531, + 0.15416, 0.39782, 0.55622, 0.6625, 0.7372, 0.79201}); + + sd::ops::tf_atan2 op; + auto result = op.evaluate({&x, &y}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, atan2_test4) { + auto y = NDArrayFactory::create( + 'c', {1, 3, 4}, + {-1.001, -0.829, -0.657, -0.485, -0.313, -0.141, 0.031, 0.203, 0.375, + 0.547, 0.719, 0.891}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 1}, {-0.82, -0.458, -0.096, 0.085, 0.447, 0.809}); - auto y = NDArrayFactory::create('c', {1, 3, 4}, {-1.001 ,-0.829 ,-0.657 ,-0.485 ,-0.313 ,-0.141 ,0.031 ,0.203 ,0.375 ,0.547 ,0.719 ,0.891}); - auto x = NDArrayFactory::create('c', {2, 3, 1}, {-0.82, -0.458, -0.096, 0.085, 0.447, 0.809}); - - auto exp = NDArrayFactory::create('c', {2,3,4}, {-2.45527, -2.36165, -2.24628, -2.10492,-2.1703 , -1.86945, -1.50321, -1.15359,-0.25062, -0.17373, -0.13273, -0.10733, - 3.05688, 3.03942, 3.01293, 2.9681 , 2.18167, 1.87635, 1.50156, 1.14451, 1.13674, 0.97626, 0.84423, 0.7372 }); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-2.45527, -2.36165, -2.24628, -2.10492, -2.1703, -1.86945, + -1.50321, -1.15359, -0.25062, -0.17373, -0.13273, -0.10733, + 3.05688, 3.03942, 3.01293, 2.9681, 2.18167, 1.87635, + 1.50156, 1.14451, 1.13674, 0.97626, 0.84423, 0.7372}); - sd::ops::tf_atan2 op; - auto result = op.evaluate({&x, &y}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + sd::ops::tf_atan2 op; + auto result = op.evaluate({&x, &y}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, atan2_test5) { + auto y = NDArrayFactory::create( + 'c', {1, 3, 4}, + {-1.001, -0.829, -0.657, -0.485, -0.313, -0.141, 0.031, 0.203, 0.375, + 0.547, 0.719, 0.891}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 1}, {-0.82, -0.458, -0.096, 0.085, 0.447, 0.809}); - auto y = NDArrayFactory::create('c', {1, 3, 4}, {-1.001 ,-0.829 ,-0.657 ,-0.485 ,-0.313 ,-0.141 ,0.031 ,0.203 ,0.375 ,0.547 ,0.719 ,0.891}); - auto x = NDArrayFactory::create('c', {2, 3, 1}, {-0.82, -0.458, -0.096, 0.085, 0.447, 0.809}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-2.25712, -2.35074, -2.46611, -2.60747, -2.54209, -2.84294, + 3.07401, 2.72438, 1.82141, 1.74453, 1.70353, 1.67813, + -1.48608, -1.46862, -1.44214, -1.3973, -0.61088, -0.30556, + 0.06924, 0.42629, 0.43405, 0.59453, 0.72657, 0.8336}); - auto exp = NDArrayFactory::create('c', {2,3,4}, {-2.25712, -2.35074, -2.46611, -2.60747,-2.54209, -2.84294, 3.07401, 2.72438, 1.82141, 1.74453, 1.70353, 1.67813, - -1.48608, -1.46862, -1.44214, -1.3973 ,-0.61088, -0.30556, 0.06924, 0.42629, 0.43405, 0.59453, 0.72657, 0.8336 }); + sd::ops::tf_atan2 op; + auto result = op.evaluate({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); - sd::ops::tf_atan2 op; - auto result = op.evaluate({&y, &x}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, atan2_test6) { + auto y = NDArrayFactory::create( + 'c', {1, 3, 4}, + {-1.001, -0.829, -0.657, -0.485, -0.313, -0.141, 0.031, 0.203, 0.375, + 0.547, 0.719, 0.891}); + auto x = + NDArrayFactory::create('c', {4}, {-0.82, -0.096, 0.085, 0.809}); - auto y = NDArrayFactory::create('c', {1, 3, 4}, {-1.001 ,-0.829 ,-0.657 ,-0.485 ,-0.313 ,-0.141 ,0.031 ,0.203 ,0.375 ,0.547 ,0.719 ,0.891}); - auto x = NDArrayFactory::create('c', { 4}, {-0.82, -0.096, 0.085, 0.809}); - - auto exp = NDArrayFactory::create('c', {1,3,4}, {-2.25712, -1.68608, -1.44214, -0.54006,-2.77695, -2.16855, 0.34972, 0.24585, 2.71267, 1.74453, 1.45312, 0.8336 }); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 4}, + {-2.25712, -1.68608, -1.44214, -0.54006, -2.77695, -2.16855, 0.34972, + 0.24585, 2.71267, 1.74453, 1.45312, 0.8336}); - sd::ops::tf_atan2 op; - auto result = op.evaluate({&y, &x}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + sd::ops::tf_atan2 op; + auto result = op.evaluate({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, IGamma_Test1) { - - auto y = NDArrayFactory::create('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 ,7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1}); - auto x = NDArrayFactory::create('c', { 4}, {1.2, 2.2, 3.2, 4.2}); - - auto exp = NDArrayFactory::create('c', {1,3,4}, { - 0.659917, 0.61757898, 0.59726304, 0.58478117, - 0.0066205109, 0.022211598, 0.040677428, 0.059117373, - 0.0000039433403, 0.000086064574, 0.000436067, 0.0012273735}); - - sd::ops::igamma op; - auto result = op.evaluate({&y, &x}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); -// z->printBuffer("OUtput"); -// exp.printBuffer("EXpect"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto y = NDArrayFactory::create( + 'c', {1, 3, 4}, + {1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1, 9.1, 10.1, 11.1, 12.1}); + auto x = NDArrayFactory::create('c', {4}, {1.2, 2.2, 3.2, 4.2}); + + auto exp = NDArrayFactory::create( + 'c', {1, 3, 4}, + {0.659917, 0.61757898, 0.59726304, 0.58478117, 0.0066205109, 0.022211598, + 0.040677428, 0.059117373, 0.0000039433403, 0.000086064574, 0.000436067, + 0.0012273735}); + + sd::ops::igamma op; + auto result = op.evaluate({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printBuffer("OUtput"); + // exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, IGamma_Test2) { - - auto y = NDArrayFactory::create('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 , - 7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1}); - auto x = NDArrayFactory::create('c', { 4}, {1.2, 2.2, 3.2, 4.2}); - auto exp = NDArrayFactory::create('c', {1,3,4}, {0.340083, 0.382421, 0.402737, 0.415221, - 0.993379, 0.977788, 0.959323, 0.940883, - 0.999996, 0.999914, 0.999564, 0.998773}); - - sd::ops::igammac op; - auto result = op.evaluate({&y, &x}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); -// z->printBuffer("OUtput"); -// exp.printBuffer("EXpect"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto y = NDArrayFactory::create( + 'c', {1, 3, 4}, + {1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1, 9.1, 10.1, 11.1, 12.1}); + auto x = NDArrayFactory::create('c', {4}, {1.2, 2.2, 3.2, 4.2}); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 4}, + {0.340083, 0.382421, 0.402737, 0.415221, 0.993379, 0.977788, 0.959323, + 0.940883, 0.999996, 0.999914, 0.999564, 0.998773}); + + sd::ops::igammac op; + auto result = op.evaluate({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printBuffer("OUtput"); + // exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, LGamma_Test1) { + auto x = NDArrayFactory::create( + 'c', {3, 3}, {0.1, 0.5, 0.7, 1.5, 1.7, 2.0, 2.5, 2.7, 3.}); - auto x = NDArrayFactory::create('c', {3, 3}, {0.1, 0.5, 0.7, 1.5, 1.7, 2.0, 2.5, 2.7, 3.}); - - auto exp = NDArrayFactory::create('c', {3,3}, { - 2.2527127 , 0.5723649 , 0.26086727, - -0.12078223, -0.09580769, 0., - 0.28468287, 0.4348206 , 0.6931472 - }); - - sd::ops::lgamma op; - auto result = op.evaluate({&x}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); -// z->printBuffer("OUtput"); -// exp.printBuffer("EXpect"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {2.2527127, 0.5723649, 0.26086727, -0.12078223, -0.09580769, 0., + 0.28468287, 0.4348206, 0.6931472}); + sd::ops::lgamma op; + auto result = op.evaluate({&x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printBuffer("OUtput"); + // exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, range_test10) { + auto limit = NDArrayFactory::create('c', {1, 3, 4}); + limit = 5.; + auto exp = NDArrayFactory::create('c', {5}, {0., 1., 2., 3., 4.}); - auto limit = NDArrayFactory::create('c', {1, 3, 4}); - limit = 5.; - auto exp = NDArrayFactory::create('c', {5}, {0.,1.,2.,3.,4.}); + sd::ops::range op; + auto result = op.evaluate({&limit}, {}, {}, {}); - sd::ops::range op; - auto result = op.evaluate({&limit}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, range_test11) { + auto limit = NDArrayFactory::create('c', {1, 3, 4}); + auto start = NDArrayFactory::create('c', {2, 4}); + limit = 5.; + start = 0.5; + auto exp = + NDArrayFactory::create('c', {5}, {0.5, 1.5, 2.5, 3.5, 4.5}); - auto limit = NDArrayFactory::create('c', {1, 3, 4}); - auto start = NDArrayFactory::create('c', {2, 4}); - limit = 5.; - start = 0.5; - auto exp = NDArrayFactory::create('c', {5}, {0.5,1.5,2.5,3.5,4.5}); - - sd::ops::range op; - auto result = op.evaluate({&start, &limit}, {}, {}, {}); + sd::ops::range op; + auto result = op.evaluate({&start, &limit}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, range_test12) { + auto exp = NDArrayFactory::create( + 'c', {9}, {0.5f, 1.f, 1.5f, 2.f, 2.5f, 3.f, 3.5f, 4.f, 4.5f}); - auto exp = NDArrayFactory::create('c', {9}, {0.5f, 1.f , 1.5f, 2.f , 2.5f, 3.f , 3.5f, 4.f , 4.5f}); - - sd::ops::range op; - auto result = op.evaluate({}, {0.5, 5, 0.5}, {}, {}); + sd::ops::range op; + auto result = op.evaluate({}, {0.5, 5, 0.5}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, top_k_permuted_test1) { + auto x = + NDArrayFactory::create({7., 3., 1., 2., 5., 0., 4., 6., 9., 8.}); + auto expUnsorted = + NDArrayFactory::create({7., 6., 9., 8.}); // Sorted = False + auto expSorted = + NDArrayFactory::create({9., 8., 7., 6., 5.}); // Sorted = False - auto x = NDArrayFactory::create({7., 3., 1., 2., 5., 0., 4., 6., 9., 8.}); - auto expUnsorted = NDArrayFactory::create({7., 6., 9., 8.}); // Sorted = False - auto expSorted = NDArrayFactory::create({9., 8., 7., 6., 5.}); // Sorted = False + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {4}, {false}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {4}, {false}); + auto z = result.at(0); + auto zI = result.at(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(expUnsorted.isSameShape(z)); + ASSERT_TRUE(expUnsorted.equalsTo(z)); - auto z = result.at(0); - auto zI = result.at(1); + auto result2 = op.evaluate({&x}, {}, {5}, {true}); - ASSERT_TRUE(expUnsorted.isSameShape(z)); - ASSERT_TRUE(expUnsorted.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result2.status()); - auto result2 = op.evaluate({&x}, {}, {5}, {true}); + z = result2.at(0); + zI = result2.at(1); - ASSERT_EQ(ND4J_STATUS_OK, result2.status()); - - z = result2.at(0); - zI = result2.at(1); - - ASSERT_TRUE(expSorted.isSameShape(z)); - ASSERT_TRUE(expSorted.equalsTo(z)); + ASSERT_TRUE(expSorted.isSameShape(z)); + ASSERT_TRUE(expSorted.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, top_k_permuted_test2) { + auto x = + NDArrayFactory::create({7., 3., 1., 2., 5., 0., 4., 6., 9., 8.}); + auto expUnsorted = + NDArrayFactory::create({7., 5., 6., 9., 8.}); // Sorted = False + auto expSorted = + NDArrayFactory::create({9., 8., 7., 6., 5.}); // Sorted = False - auto x = NDArrayFactory::create({7., 3., 1., 2., 5., 0., 4., 6., 9., 8.}); - auto expUnsorted = NDArrayFactory::create({7., 5., 6., 9., 8.}); // Sorted = False - auto expSorted = NDArrayFactory::create({9., 8., 7., 6., 5.}); // Sorted = False - + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {5}, {false}); - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {5}, {false}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + auto zI = result.at(1); - auto z = result.at(0); - auto zI = result.at(1); + ASSERT_TRUE(expUnsorted.isSameShape(z)); + ASSERT_TRUE(expUnsorted.equalsTo(z)); - ASSERT_TRUE(expUnsorted.isSameShape(z)); - ASSERT_TRUE(expUnsorted.equalsTo(z)); + auto result2 = op.evaluate({&x}, {}, {5}, {true}); - auto result2 = op.evaluate({&x}, {}, {5}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result2.status()); - ASSERT_EQ(ND4J_STATUS_OK, result2.status()); + z = result2.at(0); + zI = result2.at(1); - z = result2.at(0); - zI = result2.at(1); - - ASSERT_TRUE(expSorted.isSameShape(z)); - ASSERT_TRUE(expSorted.equalsTo(z)); + ASSERT_TRUE(expSorted.isSameShape(z)); + ASSERT_TRUE(expSorted.equalsTo(z)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test1) { - - auto labels = NDArrayFactory::create('c', {2,3},{3, 2, 1, 0, 1, 2}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3}, {1.24254, 1.34254, 1.44254, 1.54254, 1.44254, 1.34254}); +TEST_F(DeclarableOpsTests10, + sparse_softmax_cross_entropy_loss_with_logits_test1) { + auto labels = NDArrayFactory::create('c', {2, 3}, {3, 2, 1, 0, 1, 2}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3}, {1.24254, 1.34254, 1.44254, 1.54254, 1.44254, 1.34254}); - logits.linspace(0.1, 0.1); + logits.linspace(0.1, 0.1); - sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&labels, &logits}); + sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&labels, &logits}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test2) { - - auto labels = NDArrayFactory::create('c', {2},{1, 0}); - auto logits = NDArrayFactory::create('c', {2,3}); - auto expected = NDArrayFactory::create('c', {2}, {1.10194, 1.20194}); +TEST_F(DeclarableOpsTests10, + sparse_softmax_cross_entropy_loss_with_logits_test2) { + auto labels = NDArrayFactory::create('c', {2}, {1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3}); + auto expected = NDArrayFactory::create('c', {2}, {1.10194, 1.20194}); - logits.linspace(0.1, 0.1); + logits.linspace(0.1, 0.1); - sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&labels, &logits}); + sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&labels, &logits}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test3) { +TEST_F(DeclarableOpsTests10, + sparse_softmax_cross_entropy_loss_with_logits_test3) { + NDArray labels('c', {1}, std::vector{0}, sd::DataType::INT32); + auto logits = NDArrayFactory::create('c', {1, 3}); + auto expected = NDArrayFactory::create('c', {1}, {1.20194}); - NDArray labels('c', {1}, std::vector{0}, sd::DataType::INT32); - auto logits = NDArrayFactory::create('c', {1,3}); - auto expected = NDArrayFactory::create('c', {1}, {1.20194}); + logits.linspace(0.1, 0.1); - logits.linspace(0.1, 0.1); + sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&labels, &logits}); - sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&labels, &logits}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test4) { - - auto labels = NDArrayFactory::create('c', {2},{0, 0}); - auto logits = NDArrayFactory::create('c', {2,1}); - auto expected = NDArrayFactory::create('c', {2}, {0., 0.}); +TEST_F(DeclarableOpsTests10, + sparse_softmax_cross_entropy_loss_with_logits_test4) { + auto labels = NDArrayFactory::create('c', {2}, {0, 0}); + auto logits = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create('c', {2}, {0., 0.}); - logits.linspace(0.1, 0.1); + logits.linspace(0.1, 0.1); - sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&labels, &logits}); + sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&labels, &logits}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, histogram_fixed_width_test1) { + auto input = NDArrayFactory::create( + 'c', {2, 3}, {-1.f, 0.f, 1.5f, 2.f, 5.f, 15.f}); + auto range = NDArrayFactory::create('c', {2}, {0, 5}); + auto exp = NDArrayFactory::create('c', {5}, {2, 1, 1, 0, 2}); - auto input = NDArrayFactory::create('c', {2,3},{-1.f, 0.f, 1.5f, 2.f, 5.f, 15.f}); - auto range = NDArrayFactory::create('c', {2}, {0, 5}); - auto exp = NDArrayFactory::create('c', {5}, {2, 1, 1, 0, 2}); + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range}, {}, {5}, {}); - sd::ops::histogram_fixed_width op; - auto results = op.evaluate({&input, &range}, {}, {5}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto out = results.at(0); - auto out = results.at(0); - - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, histogram_fixed_width_test2) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.f, 5.f, 2.f, 1.f, -1.f, 2.f, 5.f, 3.f, 2.f, 3.f, -1.f, 5.f, + 3.f, 2.f, 1.f, 4.f, 2.f, 5.f, 5.f, 5.f, 6.f, 6.f, -1.f, 0.f}); + auto range = NDArrayFactory::create('c', {2}, {0, 5}); + auto exp = NDArrayFactory::create('c', {5}, {5, 2, 5, 3, 9}); - auto input = NDArrayFactory::create('c', {2,3,4},{0.f, 5.f, 2.f, 1.f, -1.f, 2.f, 5.f, 3.f, 2.f, 3.f, -1.f, 5.f, 3.f, 2.f, 1.f, 4.f, 2.f, 5.f, 5.f, 5.f, 6.f, 6.f, -1.f, 0.f}); - auto range = NDArrayFactory::create('c', {2}, {0, 5}); - auto exp = NDArrayFactory::create('c', {5}, {5, 2, 5, 3, 9}); - - sd::ops::histogram_fixed_width op; - auto results = op.evaluate({&input, &range}, {}, {5}, {}); + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range}, {}, {5}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto out = results.at(0); + auto out = results.at(0); - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, histogram_fixed_width_test3) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 1, 4, 1}, + {0.f, 5.f, 2.001f, 1.f, -1.f, 2.f, 5.f, 3.f, + 2.999f, 3.00001f, -1.f, 3.99999f, 3.f, 2.f, 1.f, 4.f, + 2.f, 5.f, 5.f, 5.f, 6.f, 6.f, -1.f, 0.00001f}); + auto range = NDArrayFactory::create('c', {1, 2, 1}, {0, 5}); + auto exp = NDArrayFactory::create('c', {5}, {5, 2, 5, 4, 8}); - auto input = NDArrayFactory::create('c', {2,3,1,4,1},{0.f, 5.f, 2.001f, 1.f, -1.f, 2.f, 5.f, 3.f, 2.999f, 3.00001f, -1.f, 3.99999f, 3.f, 2.f, 1.f, 4.f, 2.f, 5.f, 5.f, 5.f, 6.f, 6.f, -1.f, 0.00001f}); - auto range = NDArrayFactory::create('c', {1,2,1}, {0, 5}); - auto exp = NDArrayFactory::create('c', {5}, {5, 2, 5, 4, 8}); + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range}, {}, {5}, {}); - sd::ops::histogram_fixed_width op; - auto results = op.evaluate({&input, &range}, {}, {5}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto out = results.at(0); - auto out = results.at(0); - - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, histogram_fixed_width_test4) { - - auto input = NDArrayFactory::create('c', {20,5},{13.8387f,0.1509f,50.39f,30.403f,13.5174f,9.7351f,37.6652f,28.9215f,22.7011f,45.2834f,40.7628f,50.4995f,26.8003f,27.479f,44.633f,6.9109f,48.5004f, - 46.5971f,1.6203f,23.6381f,38.9661f,50.8146f,17.2482f,8.0429f,7.5666f,7.9709f,21.8403f,20.1694f,23.3004f,50.9151f,46.239f,38.7323f,29.6946f,32.9876f, - 23.0013f,39.7318f,19.4486f,37.6147f,-0.1506f,5.3246f,3.6173f,24.2573f,4.3941f,9.7105f,24.0364f,35.3681f,17.7805f,35.7681f,16.4144f,17.4362f,8.4987f, - 26.8108f,36.2937f,31.6442f,29.7221f,8.7445f,33.3301f,4.0939f,13.078f,45.1481f,29.0172f,21.6548f,35.408f,27.1861f,2.2576f,40.6804f,36.2201f,29.7352f, - 29.1244f,38.7444f,5.8721f,33.5983f,48.2694f,34.4161f,19.7148f,13.8085f,13.6075f,22.5042f,37.8002f,50.0543f,48.5314f,20.3694f,28.5042f,-0.4679f,4.4245f, - 18.9837f,40.7724f,2.7611f,44.0431f,37.186f,27.7361f,14.6001f,9.1721f,14.6087f,21.4072f,49.3344f,11.4668f,14.6171f,15.2502f,5.244f}); - auto range = NDArrayFactory::create('c', {1,2}, {0, 50}); - auto exp = NDArrayFactory::create('c', {5}, {22, 17, 24, 19, 18}); - - sd::ops::histogram_fixed_width op; - auto results = op.evaluate({&input, &range}, {}, {5}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto out = results.at(0); - - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + auto input = NDArrayFactory::create( + 'c', {20, 5}, + {13.8387f, 0.1509f, 50.39f, 30.403f, 13.5174f, 9.7351f, 37.6652f, + 28.9215f, 22.7011f, 45.2834f, 40.7628f, 50.4995f, 26.8003f, 27.479f, + 44.633f, 6.9109f, 48.5004f, 46.5971f, 1.6203f, 23.6381f, 38.9661f, + 50.8146f, 17.2482f, 8.0429f, 7.5666f, 7.9709f, 21.8403f, 20.1694f, + 23.3004f, 50.9151f, 46.239f, 38.7323f, 29.6946f, 32.9876f, 23.0013f, + 39.7318f, 19.4486f, 37.6147f, -0.1506f, 5.3246f, 3.6173f, 24.2573f, + 4.3941f, 9.7105f, 24.0364f, 35.3681f, 17.7805f, 35.7681f, 16.4144f, + 17.4362f, 8.4987f, 26.8108f, 36.2937f, 31.6442f, 29.7221f, 8.7445f, + 33.3301f, 4.0939f, 13.078f, 45.1481f, 29.0172f, 21.6548f, 35.408f, + 27.1861f, 2.2576f, 40.6804f, 36.2201f, 29.7352f, 29.1244f, 38.7444f, + 5.8721f, 33.5983f, 48.2694f, 34.4161f, 19.7148f, 13.8085f, 13.6075f, + 22.5042f, 37.8002f, 50.0543f, 48.5314f, 20.3694f, 28.5042f, -0.4679f, + 4.4245f, 18.9837f, 40.7724f, 2.7611f, 44.0431f, 37.186f, 27.7361f, + 14.6001f, 9.1721f, 14.6087f, 21.4072f, 49.3344f, 11.4668f, 14.6171f, + 15.2502f, 5.244f}); + auto range = NDArrayFactory::create('c', {1, 2}, {0, 50}); + auto exp = NDArrayFactory::create('c', {5}, {22, 17, 24, 19, 18}); + + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range}, {}, {5}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto out = results.at(0); + + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, histogram_fixed_width_test5) { - - auto input = NDArrayFactory::create('c', {5,20},{20.f, 0.f, 60.f, 40.f, 20.f, 0.f, 40.f, 0.f, 40.f, 40.f,40.f,60.f, 20.f, 20.f, 60.f, 0.f, 40.f, - 46.5971f,1.6203f,23.6381f,38.9661f,50.8146f,17.2482f,8.0429f,7.5666f,7.9709f,21.8403f,20.1694f,23.3004f,50.9151f,46.239f,38.7323f,29.6946f,32.9876f, - 23.0013f,39.7318f,19.4486f,37.6147f,-0.1506f,5.3246f,3.6173f,24.2573f,4.3941f,9.7105f,24.0364f,35.3681f,17.7805f,35.7681f,16.4144f,17.4362f,8.4987f, - 26.8108f,36.2937f,31.6442f,29.7221f,8.7445f,33.3301f,4.0939f,13.078f,45.1481f,29.0172f,21.6548f,35.408f,27.1861f,2.2576f,40.6804f,36.2201f,29.7352f, - 29.1244f,38.7444f,5.8721f,33.5983f,48.2694f,34.4161f,19.7148f,13.8085f,13.6075f,22.5042f,37.8002f,50.0543f,48.5314f,20.3694f,28.5042f,-0.4679f,4.4245f, - 18.9837f,40.7724f,2.7611f,44.0431f,37.186f,27.7361f,14.6001f,9.1721f,14.6087f,21.4072f,49.3344f,11.4668f,14.6171f,15.2502f,5.244f}); - auto range = NDArrayFactory::create('c', {1,2}, {0, 50}); -// auto exp = NDArrayFactory::create('c', {5}, {23, 19, 20, 23, 15}); // 23, 15, 24, 17, 21 - auto exp = NDArrayFactory::create('c', {5}, {23, 15, 24, 17, 21}); - - sd::ops::histogram_fixed_width op; - auto results = op.evaluate({&input, &range}, {}, {5}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto out = results.at(0); - - ASSERT_TRUE(exp.isSameShape(out)); - // out->printBuffer("5HIST"); - ASSERT_TRUE(exp.equalsTo(out)); + auto input = NDArrayFactory::create( + 'c', {5, 20}, + {20.f, 0.f, 60.f, 40.f, 20.f, 0.f, 40.f, + 0.f, 40.f, 40.f, 40.f, 60.f, 20.f, 20.f, + 60.f, 0.f, 40.f, 46.5971f, 1.6203f, 23.6381f, 38.9661f, + 50.8146f, 17.2482f, 8.0429f, 7.5666f, 7.9709f, 21.8403f, 20.1694f, + 23.3004f, 50.9151f, 46.239f, 38.7323f, 29.6946f, 32.9876f, 23.0013f, + 39.7318f, 19.4486f, 37.6147f, -0.1506f, 5.3246f, 3.6173f, 24.2573f, + 4.3941f, 9.7105f, 24.0364f, 35.3681f, 17.7805f, 35.7681f, 16.4144f, + 17.4362f, 8.4987f, 26.8108f, 36.2937f, 31.6442f, 29.7221f, 8.7445f, + 33.3301f, 4.0939f, 13.078f, 45.1481f, 29.0172f, 21.6548f, 35.408f, + 27.1861f, 2.2576f, 40.6804f, 36.2201f, 29.7352f, 29.1244f, 38.7444f, + 5.8721f, 33.5983f, 48.2694f, 34.4161f, 19.7148f, 13.8085f, 13.6075f, + 22.5042f, 37.8002f, 50.0543f, 48.5314f, 20.3694f, 28.5042f, -0.4679f, + 4.4245f, 18.9837f, 40.7724f, 2.7611f, 44.0431f, 37.186f, 27.7361f, + 14.6001f, 9.1721f, 14.6087f, 21.4072f, 49.3344f, 11.4668f, 14.6171f, + 15.2502f, 5.244f}); + auto range = NDArrayFactory::create('c', {1, 2}, {0, 50}); + // auto exp = NDArrayFactory::create('c', {5}, {23, 19, 20, 23, + // 15}); // 23, 15, 24, 17, 21 + auto exp = NDArrayFactory::create('c', {5}, {23, 15, 24, 17, 21}); + + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range}, {}, {5}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto out = results.at(0); + + ASSERT_TRUE(exp.isSameShape(out)); + // out->printBuffer("5HIST"); + ASSERT_TRUE(exp.equalsTo(out)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, histogram_fixed_width_test6) { + auto input = NDArrayFactory::create( + 'c', {7}, {0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9}); + auto range = NDArrayFactory::create('c', {2}, {0, 1}); + auto bins = NDArrayFactory::create(5); - auto input = NDArrayFactory::create('c', {7},{0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9}); - auto range = NDArrayFactory::create('c', {2}, {0, 1}); - auto bins = NDArrayFactory::create(5); + auto exp = NDArrayFactory::create('c', {5}, {3, 1, 2, 0, 1}); - auto exp = NDArrayFactory::create('c', {5}, {3, 1, 2, 0, 1}); + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range, &bins}, {}, {}, {}); - sd::ops::histogram_fixed_width op; - auto results = op.evaluate({&input, &range, &bins}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto out = results.at(0); + // out->printShapeInfo(); + // out->printIndexedBuffer(); - auto out = results.at(0); - // out->printShapeInfo(); - // out->printIndexedBuffer(); - - ASSERT_TRUE(exp.isSameShape(out)); - ASSERT_TRUE(exp.equalsTo(out)); + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_1) { + NDArray input = NDArrayFactory::create( + 'c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(4.f); + NDArray exp = NDArrayFactory::create(5.f); - NDArray input = NDArrayFactory::create('c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); - NDArray n = NDArrayFactory::create(4.f); - NDArray exp = NDArrayFactory::create(5.f); - - //input.linspace(1.f); + // input.linspace(1.f); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {}); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_2) { + NDArray input = NDArrayFactory::create( + 'c', {3, 4}, {10, 11, 9, 12, 8, 7, 6, 5, 1, 3, 2, 4}); + NDArray n = NDArrayFactory::create(3); + NDArray exp = NDArrayFactory::create({12.f, 8.f, 4.f}); - NDArray input = NDArrayFactory::create('c', {3, 4}, {10, 11, 9, 12, 8, 7, 6, 5, 1, 3, 2, 4}); - NDArray n = NDArrayFactory::create(3); - NDArray exp = NDArrayFactory::create({12.f, 8.f, 4.f}); + // input.linspace(1.f); -// input.linspace(1.f); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {}); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_3) { + NDArray input = NDArrayFactory::create( + 'c', {3, 4}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(3); + NDArray exp = NDArrayFactory::create({1.f, 5.f, 2.f}); - NDArray input = NDArrayFactory::create('c', {3,4}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); - NDArray n = NDArrayFactory::create(3); - NDArray exp = NDArrayFactory::create({1.f, 5.f, 2.f}); - - //input.linspace(1.f); + // input.linspace(1.f); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {1}); // with reverse = true + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {1}); // with reverse = true - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_4) { + NDArray input = NDArrayFactory::create( + 'c', {2, 2, 3}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(2); + NDArray exp = + NDArrayFactory::create('c', {2, 2}, {10.f, 11.f, 12.f, 4.f}); - NDArray input = NDArrayFactory::create('c', {2, 2, 3}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); - NDArray n = NDArrayFactory::create(2); - NDArray exp = NDArrayFactory::create('c', {2,2}, {10.f, 11.f, 12.f, 4.f}); - - //input.linspace(1.f); + // input.linspace(1.f); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {}); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_04) { + NDArray input = NDArrayFactory::create('c', {6, 15}); + NDArray n = NDArrayFactory::create(4); + NDArray exp = NDArrayFactory::create( + 'c', {6}, {5.f, 20.f, 35.f, 50.f, 65.f, 80.f}); - NDArray input = NDArrayFactory::create('c', {6, 15}); - NDArray n = NDArrayFactory::create(4); - NDArray exp = NDArrayFactory::create('c', {6}, {5.f, 20.f, 35.f, 50.f, 65.f, 80.f}); + input.linspace(1.f); - input.linspace(1.f); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {}); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_5) { + NDArray input = NDArrayFactory::create( + 'c', {2, 2, 3}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(2); + NDArray exp = + NDArrayFactory::create('c', {2, 2}, {1.f, 7.f, 5.f, 2.f}); - NDArray input = NDArrayFactory::create('c', {2, 2, 3}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); - NDArray n = NDArrayFactory::create(2); - NDArray exp = NDArrayFactory::create('c', {2,2}, {1.f, 7.f, 5.f, 2.f}); - -// input.linspace(1.f); + // input.linspace(1.f); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {1}); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_6) { + NDArray input = NDArrayFactory::create( + 'c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(0); + NDArray exp = NDArrayFactory::create(1.f); // NDArrayFactory::create('c', + // {2,2}, {1.f, 4.f, 7.f, 10.f}); - NDArray input = NDArrayFactory::create('c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); - NDArray n = NDArrayFactory::create(0); - NDArray exp = NDArrayFactory::create(1.f);//NDArrayFactory::create('c', {2,2}, {1.f, 4.f, 7.f, 10.f}); - -// input.linspace(1.f); + // input.linspace(1.f); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {0}); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto output = results.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_06) { + NDArray input = NDArrayFactory::create( + 'c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(4); + NDArray exp = NDArrayFactory::create(8.f); // NDArrayFactory::create('c', + // {2,2}, {1.f, 4.f, 7.f, 10.f}); - NDArray input = NDArrayFactory::create('c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); - NDArray n = NDArrayFactory::create(4); - NDArray exp = NDArrayFactory::create(8.f);//NDArrayFactory::create('c', {2,2}, {1.f, 4.f, 7.f, 10.f}); + // input.linspace(1.f); -// input.linspace(1.f); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {1}); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto output = results.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_7) { + NDArray input = NDArrayFactory::create( + 'c', {2, 3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, - NDArray input = NDArrayFactory::create('c', {2, 3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f, - 0.7271f, 0.1804f, 0.5056f, 0.8925f, - 0.5461f, 0.9234f, 0.0856f, 0.7938f, - - 0.6591f, 0.5555f, 0.1596f, 0.3087f, - 0.1548f, 0.4695f, 0.9939f, 0.6113f, - 0.6765f, 0.1800f, 0.6750f, 0.2246f}); - NDArray n = NDArrayFactory::create(2); - NDArray exp = NDArrayFactory::create('c', {2,3}, {0.7788f, 0.7271f, 0.7938f, 0.5555f, 0.6113f, 0.675f}); + 0.6591f, 0.5555f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, + 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f}); + NDArray n = NDArrayFactory::create(2); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3}, {0.7788f, 0.7271f, 0.7938f, 0.5555f, 0.6113f, 0.675f}); - //input.linspace(1.f); + // input.linspace(1.f); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {0}); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, NTH_Element_Test_8) { + NDArray input = NDArrayFactory::create( + 'c', {2, 3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, - NDArray input = NDArrayFactory::create('c', {2, 3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f, - 0.7271f, 0.1804f, 0.5056f, 0.8925f, - 0.5461f, 0.9234f, 0.0856f, 0.7938f, - - 0.6591f, 0.5555f, 0.1596f, 0.3087f, - 0.1548f, 0.4695f, 0.9939f, 0.6113f, - 0.6765f, 0.1800f, 0.6750f, 0.2246f}); - NDArray n = NDArrayFactory::create(2); - NDArray exp = NDArrayFactory::create('c', {2,3}, {0.7244f, 0.5056f, 0.5461f, 0.3087f, 0.4695f, 0.2246f}); + 0.6591f, 0.5555f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, + 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f}); + NDArray n = NDArrayFactory::create(2); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3}, {0.7244f, 0.5056f, 0.5461f, 0.3087f, 0.4695f, 0.2246f}); - //input.linspace(1.f); + // input.linspace(1.f); - sd::ops::nth_element op; - auto results = op.evaluate({&input, &n}, {}, {1}); + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test1) { + auto input = NDArrayFactory::create('c', {3}); + auto shape = NDArrayFactory::create('c', {2}, {3, 3}); + auto exp = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, 1, 2, 3, 1, 2, 3}); - auto input = NDArrayFactory::create('c', {3}); - auto shape = NDArrayFactory::create('c', {2}, {3, 3}); - auto exp = NDArrayFactory::create('c', {3,3}, {1, 2, 3,1, 2, 3, 1, 2, 3}); + input.linspace(1.f); - input.linspace(1.f); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test2) { + auto input = NDArrayFactory::create('c', {1, 3}); + auto shape = NDArrayFactory::create('c', {2}, {3.f, 3.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f}); - auto input = NDArrayFactory::create('c', {1,3}); - auto shape = NDArrayFactory::create('c', {2}, {3.f, 3.f}); - auto exp = NDArrayFactory::create('c', {3,3}, {1.f, 2.f, 3.f,1.f, 2.f, 3.f,1.f, 2.f, 3.f}); - - input.linspace(1.f); + input.linspace(1.f); - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test3) { + auto input = NDArrayFactory::create('c', {3, 1}); + auto shape = NDArrayFactory::create('c', {2}, {3.f, 3.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f}); - auto input = NDArrayFactory::create('c', {3,1}); - auto shape = NDArrayFactory::create('c', {2}, {3.f, 3.f}); - auto exp = NDArrayFactory::create('c', {3,3}, {1.f, 1.f, 1.f,2.f, 2.f, 2.f,3.f, 3.f, 3.f}); - - input.linspace(1.f); + input.linspace(1.f); - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test4) { + auto input = NDArrayFactory::create(10.); + auto shape = NDArrayFactory::create('c', {2}, {3.f, 3.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {10.f, 10.f, 10.f, 10.f, 10.f, 10.f, 10.f, 10.f, 10.f}); - auto input = NDArrayFactory::create(10.); - auto shape = NDArrayFactory::create('c', {2}, {3.f, 3.f}); - auto exp = NDArrayFactory::create('c', {3,3}, {10.f, 10.f, 10.f,10.f, 10.f, 10.f, 10.f, 10.f, 10.f}); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test5) { + auto input = NDArrayFactory::create(10.f); + auto shape = NDArrayFactory::create('c', {1}, {3.f}); + auto exp = NDArrayFactory::create('c', {3}, {10.f, 10.f, 10.f}); - auto input = NDArrayFactory::create(10.f); - auto shape = NDArrayFactory::create('c', {1}, {3.f}); - auto exp = NDArrayFactory::create('c', {3}, {10.f, 10.f, 10.f}); - - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test6) { + auto input = NDArrayFactory::create(10.f); + auto shape = NDArrayFactory::create(1.f); + auto exp = NDArrayFactory::create('c', {1}, {10.f}); - auto input = NDArrayFactory::create(10.f); - auto shape = NDArrayFactory::create(1.f); - auto exp = NDArrayFactory::create('c', {1}, {10.f}); - - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test7) { + auto input = NDArrayFactory::create(10.f); + auto shape = NDArrayFactory::create(1); + auto exp = NDArrayFactory::create('c', {1}, {10.}); - auto input = NDArrayFactory::create(10.f); - auto shape = NDArrayFactory::create(1); - auto exp = NDArrayFactory::create('c', {1}, {10.}); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test8) { + auto input = NDArrayFactory::create('c', {3}); + auto shape = NDArrayFactory::create('c', {3}, {1.f, 3.f, 3.f}); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 3}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f}); - auto input = NDArrayFactory::create('c', {3}); - auto shape = NDArrayFactory::create('c', {3}, {1.f, 3.f, 3.f}); - auto exp = NDArrayFactory::create('c', {1,3,3}, {1.f, 2.f, 3.f,1.f, 2.f, 3.f,1.f, 2.f, 3.f}); - - input.linspace(1.f); + input.linspace(1.f); - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test9) { + auto input = NDArrayFactory::create('c', {5, 1, 1}); + auto shape = + NDArrayFactory::create('c', {5}, {2.f, 1.f, 5.f, 1.f, 3.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 1, 5, 1, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, + 4.f, 4.f, 5.f, 5.f, 5.f, 1.f, 1.f, 1.f, 2.f, 2.f, + 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 5.f, 5.f, 5.f}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {5,1,1}); - auto shape = NDArrayFactory::create('c', {5}, {2.f,1.f,5.f,1.f,3.f}); - auto exp = NDArrayFactory::create('c', {2,1,5,1,3}, {1.f, 1.f, 1.f,2.f, 2.f, 2.f,3.f, 3.f, 3.f,4.f, 4.f, 4.f,5.f, 5.f, 5.f, - 1.f, 1.f, 1.f,2.f, 2.f, 2.f,3.f, 3.f, 3.f,4.f, 4.f, 4.f,5.f, 5.f, 5.f}); - input.linspace(1.f); - - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, broadcast_to_test10) { + auto input = NDArrayFactory::create('c', {5, 1, 3}); + auto shape = + NDArrayFactory::create('c', {5}, {2.f, 1.f, 5.f, 1.f, 3.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 1, 5, 1, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 1.f, 2.f, 3.f, 4.f, 5.f, + 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {5,1,3}); - auto shape = NDArrayFactory::create('c', {5}, {2.f,1.f,5.f,1.f,3.f}); - auto exp = NDArrayFactory::create('c', {2,1,5,1,3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f,10.f, 11.f, 12.f,13.f, 14.f, 15.f, - 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f,10.f, 11.f, 12.f,13.f, 14.f, 15.f}); - input.linspace(1.f); + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); - sd::ops::broadcast_to op; - auto results = op.evaluate({&input, &shape}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); - auto output = results.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) { - - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, - 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., - 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., - 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, - 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, - 9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4, - 11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8, - 7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4, - 10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16., - 13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8, - 8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6, - 11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, - 15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2, - 16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8, - 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6, - 18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16., - 14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6, - 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., - 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., - 20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24., - 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, - 15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8, - 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24., - 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., - 14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6, - 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., - 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., - 20.2,21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.}); - //input = 1.f; - input.linspace(1); - - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input}, {}, {10, 10}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - //result.printIndexedBuffer("Resized to 10x10"); - //expected.printIndexedBuffer("Expect for 10x10"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 10, 10, 4}, + {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, + 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., + 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., + 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, + 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, + 9.4, 10.4, 11.4, 12.4, 10.6, 11.6, 12.6, 13.6, 11.4, 12.4, 13.4, 14.4, + 11.4, 12.4, 13.4, 14.4, 11.4, 12.4, 13.4, 14.4, 5.8, 6.8, 7.8, 8.8, + 7., 8., 9., 10., 8.2, 9.2, 10.2, 11.2, 9.4, 10.4, 11.4, 12.4, + 10.6, 11.6, 12.6, 13.6, 11.8, 12.8, 13.8, 14.8, 13.0, 14.0, 15.0, 16., + 13.8, 14.8, 15.8, 16.8, 13.8, 14.8, 15.8, 16.8, 13.8, 14.8, 15.8, 16.8, + 8.2, 9.2, 10.2, 11.2, 9.4, 10.4, 11.4, 12.4, 10.6, 11.6, 12.6, 13.6, + 11.8, 12.8, 13.8, 14.8, 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4, 16.4, 17.4, 18.4, 16.2, 17.2, 18.2, 19.2, 16.2, 17.2, 18.2, 19.2, + 16.2, 17.2, 18.2, 19.2, 10.6, 11.6, 12.6, 13.6, 11.8, 12.8, 13.8, 14.8, + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 18.6, 19.6, 20.6, 21.6, + 18.6, 19.6, 20.6, 21.6, 18.6, 19.6, 20.6, 21.6, 13., 14., 15., 16., + 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, + 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., + 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, + 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., + 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., + 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, + 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., + 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24.}); + // input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {10, 10}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printIndexedBuffer("Resized to 10x10"); + // expected.printIndexedBuffer("Expect for 10x10"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_11) { + NDArray input = NDArrayFactory::create('c', {1, 1, 1, 256}); - NDArray input = NDArrayFactory::create('c', {1, 1, 1, 256}); + input.assign(0.8f); // linspace(1); + auto size = NDArrayFactory::create({65, 65}); + auto ex = NDArrayFactory::create('c', {1, 65, 65, 256}); + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input, &size}, {}, {}, {false}); - input.assign(0.8f); //linspace(1); - auto size = NDArrayFactory::create({65,65}); - auto ex = NDArrayFactory::create('c', {1,65,65,256}); - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input, &size}, {}, {}, {false}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - ASSERT_NE(result, ex); + auto result = results.at(0); + ASSERT_NE(result, ex); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_12) { + NDArray input = NDArrayFactory::create('c', {1, 1, 1, 256}); - NDArray input = NDArrayFactory::create('c', {1, 1, 1, 256}); - - input.assign(0.8f); //linspace(1); - auto size = NDArrayFactory::create({65,65}); - auto ex = NDArrayFactory::create('c', {1,65,65,256}); - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input, &size}, {}, {}, {true}); + input.assign(0.8f); // linspace(1); + auto size = NDArrayFactory::create({65, 65}); + auto ex = NDArrayFactory::create('c', {1, 65, 65, 256}); + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input, &size}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - ASSERT_NE(result, ex); + auto result = results.at(0); + ASSERT_NE(result, ex); } TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) { + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 4, 5, 4}, + {1., 2., 3., 4., 2.6, 3.6, 4.6, 5.6, 5., 6., + 7., 8., 7.4, 8.4, 9.4, 10.4, 9., 10., 11., 12., - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { - 1., 2., 3., 4., - 2.6, 3.6, 4.6, 5.6, - 5., 6., 7., 8., - 7.4, 8.4, 9.4, 10.4, - 9., 10., 11., 12., - - 4., 5., 6., 7., - 5.6, 6.6, 7.6, 8.6, - 8., 9., 10., 11., - 10.4, 11.4, 12.4, 13.4, - 12., 13., 14., 15., - - 10., 11., 12., 13., - 11.6, 12.6, 13.6, 14.6, - 14., 15., 16., 17., - 16.4, 17.4, 18.4, 19.4, - 18., 19., 20., 21., + 4., 5., 6., 7., 5.6, 6.6, 7.6, 8.6, 8., 9., + 10., 11., 10.4, 11.4, 12.4, 13.4, 12., 13., 14., 15., - 13., 14., 15., 16., - 14.6, 15.6, 16.6, 17.6, - 17., 18., 19., 20., - 19.4, 20.4, 21.4, 22.4, - 21., 22., 23., 24. - }); - //input = 1.f; - input.linspace(1); + 10., 11., 12., 13., 11.6, 12.6, 13.6, 14.6, 14., 15., + 16., 17., 16.4, 17.4, 18.4, 19.4, 18., 19., 20., 21., - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input}, {}, {4, 5}, {false, true}); + 13., 14., 15., 16., 14.6, 15.6, 16.6, 17.6, 17., 18., + 19., 20., 19.4, 20.4, 21.4, 22.4, 21., 22., 23., 24.}); + // input = 1.f; + input.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {4, 5}, {false, true}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); -// result.printIndexedBuffer("Resized to 4x5 bilinear with half pixels"); - //expected.printIndexedBuffer("Expect for 10x10"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printIndexedBuffer("Resized to 4x5 bilinear with half pixels"); + // expected.printIndexedBuffer("Expect for 10x10"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) { + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 4, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 2.6f, 3.6f, 4.6f, 5.6f, 5.f, 6.f, + 7.f, 8.f, 7.4f, 8.4f, 9.4f, 10.4f, 9.f, 10.f, 11.f, 12.f, - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { - 1.f, 2.f, 3.f, 4.f, - 2.6f, 3.6f, 4.6f, 5.6f, - 5.f, 6.f, 7.f, 8.f, - 7.4f, 8.4f, 9.4f, 10.4f, - 9.f, 10.f, 11.f, 12.f, + 4.f, 5.f, 6.f, 7.f, 5.6f, 6.6f, 7.6f, 8.6f, 8.f, 9.f, + 10.f, 11.f, 10.4f, 11.4f, 12.4f, 13.4f, 12.f, 13.f, 14.f, 15.f, - 4.f, 5.f, 6.f, 7.f, - 5.6f, 6.6f, 7.6f, 8.6f, - 8.f, 9.f, 10.f, 11.f, - 10.4f, 11.4f, 12.4f, 13.4f, - 12.f, 13.f, 14.f, 15.f, + 10.f, 11.f, 12.f, 13.f, 11.6f, 12.6f, 13.6f, 14.6f, 14.f, 15.f, + 16.f, 17.f, 16.4f, 17.4f, 18.4f, 19.4f, 18.f, 19.f, 20.f, 21.f, - 10.f, 11.f, 12.f, 13.f, - 11.6f, 12.6f, 13.6f, 14.6f, - 14.f, 15.f, 16.f, 17.f, - 16.4f, 17.4f, 18.4f, 19.4f, - 18.f, 19.f, 20.f, 21.f, + 13.f, 14.f, 15.f, 16.f, 14.6f, 15.6f, 16.6f, 17.6f, 17.f, 18.f, + 19.f, 20.f, 19.4f, 20.4f, 21.4f, 22.4f, 21.f, 22.f, 23.f, 24.f}); + // input = 1.f; + input.linspace(1); - 13.f, 14.f, 15.f, 16.f, - 14.6f, 15.6f, 16.6f, 17.6f, - 17.f, 18.f, 19.f, 20.f, - 19.4f, 20.4f, 21.4f, 22.4f, - 21.f, 22.f, 23.f, 24.f - }); - //input = 1.f; - input.linspace(1); + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {4, 5}, {false, true}); - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input}, {}, {4, 5}, {false, true}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Resized to 4x5"); -// expected.printBuffer("Expect for 4x5"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printBuffer("Resized to 4x5"); + // expected.printBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) { - - NDArray input = NDArrayFactory::create('c', {2,3,4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, - 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., - 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., - 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, - 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, - 9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4, - 11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8, - 7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4, - 10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16., - 13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8, - 8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6, - 11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, - 15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2, - 16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8, - 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6, - 18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16., - 14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6, - 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., - 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., - 20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24., - 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, - 15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8, - 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24., - 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., - 14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6, - 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., - 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., - 20.2,21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.}); - //input = 1.f; - input.linspace(1); - - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input}, {}, {10, 10}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - //result.printIndexedBuffer("Resized to 10x10"); - //expected.printIndexedBuffer("Expect for 10x10"); -// result.printShapeInfo("Output shape"); -// expected.printShapeInfo("Expect shape"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - + NDArray input = NDArrayFactory::create('c', {2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {10, 10, 4}, + {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, + 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., + 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., + 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, + 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, + 9.4, 10.4, 11.4, 12.4, 10.6, 11.6, 12.6, 13.6, 11.4, 12.4, 13.4, 14.4, + 11.4, 12.4, 13.4, 14.4, 11.4, 12.4, 13.4, 14.4, 5.8, 6.8, 7.8, 8.8, + 7., 8., 9., 10., 8.2, 9.2, 10.2, 11.2, 9.4, 10.4, 11.4, 12.4, + 10.6, 11.6, 12.6, 13.6, 11.8, 12.8, 13.8, 14.8, 13.0, 14.0, 15.0, 16., + 13.8, 14.8, 15.8, 16.8, 13.8, 14.8, 15.8, 16.8, 13.8, 14.8, 15.8, 16.8, + 8.2, 9.2, 10.2, 11.2, 9.4, 10.4, 11.4, 12.4, 10.6, 11.6, 12.6, 13.6, + 11.8, 12.8, 13.8, 14.8, 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4, 16.4, 17.4, 18.4, 16.2, 17.2, 18.2, 19.2, 16.2, 17.2, 18.2, 19.2, + 16.2, 17.2, 18.2, 19.2, 10.6, 11.6, 12.6, 13.6, 11.8, 12.8, 13.8, 14.8, + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 18.6, 19.6, 20.6, 21.6, + 18.6, 19.6, 20.6, 21.6, 18.6, 19.6, 20.6, 21.6, 13., 14., 15., 16., + 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, + 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., + 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, + 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., + 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., + 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, + 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., + 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24.}); + // input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {10, 10}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printIndexedBuffer("Resized to 10x10"); + // expected.printIndexedBuffer("Expect for 10x10"); + // result.printShapeInfo("Output shape"); + // expected.printShapeInfo("Expect shape"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) { - - NDArray input = NDArrayFactory::create('c', {2, 5,5,3}, { - 0.7788f, 0.8012f, 0.7244f, - 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f, - 0.9234f, 0.0856f, 0.7938f, - 0.6591f, 0.5555f, 0.1596f, - 0.3087f, 0.1548f, 0.4695f, - 0.9939f, 0.6113f, 0.6765f, - 0.1800f, 0.6750f, 0.2246f, - 0.0509f, 0.4601f, 0.8284f, - 0.2354f, 0.9752f, 0.8361f, - 0.2585f, 0.4189f, 0.7028f, - 0.7679f, 0.5373f, 0.7234f, - 0.2690f, 0.0062f, 0.0327f, - 0.0644f, 0.8428f, 0.7494f, - 0.0755f, 0.6245f, 0.3491f, - 0.5793f, 0.5730f, 0.1822f, - 0.6420f, 0.9143f, 0.3019f, - 0.3574f, 0.1704f, 0.8395f, - 0.5468f, 0.0744f, 0.9011f, - 0.6574f, 0.4124f, 0.2445f, - 0.4248f, 0.5219f, 0.6952f, - 0.4900f, 0.2158f, 0.9549f, - 0.1386f, 0.1544f, 0.5365f, - 0.0134f, 0.4163f, 0.1456f, - 0.4109f, 0.2484f, 0.3330f, - 0.2974f, 0.6636f, 0.3808f, - 0.8664f, 0.1896f, 0.7530f, - 0.7215f, 0.6612f, 0.7270f, - 0.5704f, 0.2666f, 0.7453f, - 0.0444f, 0.3024f, 0.4850f, - 0.7982f, 0.0965f, 0.7843f, - 0.5075f, 0.0844f, 0.8370f, - 0.6103f, 0.4604f, 0.6087f, - 0.8594f, 0.4599f, 0.6714f, - 0.2744f, 0.1981f, 0.4143f, - 0.7821f, 0.3505f, 0.5040f, - 0.1180f, 0.8307f, 0.1817f, - 0.8442f, 0.5074f, 0.4471f, - 0.5105f, 0.6666f, 0.2576f, - 0.2341f, 0.6801f, 0.2652f, - 0.5394f, 0.4690f, 0.6146f, - 0.1210f, 0.2576f, 0.0769f, - 0.4643f, 0.1628f, 0.2026f, - 0.3774f, 0.0506f, 0.3462f, - 0.5720f, 0.0838f, 0.4228f, - 0.0588f, 0.5362f, 0.4756f, - 0.2530f, 0.1778f, 0.0751f, - 0.8977f, 0.3648f, 0.3065f, - 0.4739f, 0.7014f, 0.4473f, - 0.5171f, 0.1744f, 0.3487f}); - - NDArray expected = NDArrayFactory::create('c', {2, 9, 9, 3}, { - 0.7788f, 0.8012f, 0.7244f, 0.4744111f, 0.7600333f, 0.42217776f, - 0.26142225f, 0.7454778f, 0.22103335f, 0.41403335f, 0.8373667f, 0.42420003f, - 0.59844446f, 0.71318877f, 0.6011445f, 0.83055556f, 0.264911f, 0.7387556f, - 0.83529997f, 0.2422334f, 0.5823999f, 0.6884666f, 0.5032889f, 0.23006654f, - 0.6591f, 0.5555f, 0.1596f, 0.5176333f, 0.44208887f , 0.5827889f, - 0.5938309f, 0.5646876f, 0.5123568f, 0.61811364f, 0.6748667f, 0.44617534f, - 0.43473703f, 0.7353667f, 0.3969963f, 0.35003704f, 0.6654419f, 0.46649635f, - 0.41335183f, 0.39988017f, 0.7140149f, 0.43368888f, 0.45865932f, 0.72049254f, - 0.42537406f, 0.73366547f, 0.5662765f, 0.42371112f, 0.78866667f, 0.53543335f, - 0.30312222f, 0.18414445f, 0.49542224f, 0.67293704f, 0.4168852f, 0.59891605f, - 0.8822444f, 0.60281235f, 0.62855184f, 0.4495222f, 0.6014852f, 0.36275554f, - 0.15933579f, 0.5788963f, 0.34024328f, 0.08295307f, 0.52441484f, 0.6826569f, - 0.10747781f, 0.64715934f, 0.80707777f, 0.19927411f, 0.8880544f, 0.7861703f, - 0.21763334f, 0.9362333f, 0.78198886f, 0.27523333f, 0.3308667f, 0.6250333f, - 0.5907889f, 0.45925558f, 0.6709963f, 0.7761333f, 0.5249852f, 0.63986665f, - 0.4406333f, 0.34007773f, 0.3003666f, 0.19945924f, 0.33715558f, 0.24757043f, - 0.09977405f, 0.60721123f, 0.6248297f, 0.08286668f, 0.7239556f, 0.6876333f, - 0.12114445f, 0.73849255f ,0.54079986f, 0.12879999f, 0.74139994f, 0.51143324f, - 0.32978892f, 0.45314446f, 0.58711106f, 0.5576408f, 0.5464408f, 0.6107901f, - 0.68978024f, 0.55681235f, 0.5833172f, 0.43907034f, 0.23548517f, 0.35123706f, - 0.26263458f, 0.18254575f, 0.33890504f, 0.1976099f, 0.5321877f, 0.65619516f, - 0.18267044f, 0.6404851f, 0.63069254f, 0.20112106f, 0.58788633f, 0.37666163f, - 0.20481117f, 0.57736665f, 0.32585555f, 0.50801116f, 0.5387556f, 0.29788882f, - 0.59799266f, 0.7008482f, 0.35215425f, 0.6330642f, 0.753121f, 0.42497158f, - 0.44849625f, 0.36611477f, 0.5719964f, 0.36038768f, 0.1586321f, 0.70625067f, - 0.416968f, 0.22043455f, 0.82134944f, 0.4690964f, 0.31661478f, 0.6675073f, - 0.5182569f, 0.4357136f, 0.33437145f, 0.528089f, 0.4595333f, 0.26774442f, - 0.52779996f, 0.5559667f, 0.35320008f, 0.5630963f, 0.62568885f, 0.44562602f, - 0.557237f, 0.62408876f, 0.5438927f, 0.3867555f, 0.3371999f, 0.6655223f, - 0.30325183f, 0.17024446f, 0.71867025f, 0.35021478f, 0.18318895f, 0.6690962f, - 0.4377444f, 0.24482228f, 0.5241777f, 0.5523185f, 0.33891484f, 0.3156962f, - 0.5752333f, 0.3577333f, 0.27400002f, 0.44196665f, 0.52757776f, 0.6382001f, - 0.47803456f, 0.3974851f, 0.7738359f, 0.4686691f, 0.27816284f, 0.8476581f, - 0.2775703f, 0.20192216f, 0.6742259f, 0.14285672f, 0.20554078f, 0.4944727f, - 0.0927209f, 0.32894826f, 0.30523813f, 0.19454071f, 0.3410815f, 0.26075178f, - 0.3976642f, 0.27903205f, 0.31276423f, 0.43828884f, 0.2666222f, 0.32316667f, - 0.4248f, 0.5219f, 0.6952f, 0.46102223f, 0.35184443f, 0.8394778f, - 0.45095554f, 0.20897777f, 0.9084111f, 0.2557333f, 0.17486666f, 0.6759666f, - 0.11077777f, 0.21260004f, 0.44963327f, 0.04122221f, 0.35810006f, 0.23246664f, - 0.14590007f, 0.36033332f, 0.2080667f, 0.3667334f, 0.2670555f, 0.31217784f, - 0.4109f, 0.2484f, 0.333f, 0.2974f, 0.6636f, 0.3808f, - 0.6135111f, 0.40026665f, 0.5875778f, 0.8503f, 0.24200003f, 0.7501111f, - 0.76979995f, 0.50400007f, 0.7356667f, 0.6879222f, 0.57351106f, 0.73106664f, - 0.60397774f, 0.35428885f, 0.74123335f, 0.39506656f, 0.27853334f, 0.6585333f, - 0.10284433f, 0.29842222f, 0.5139222f, 0.0444f, 0.3024f, 0.485f, - 0.5756222f, 0.34854442f, 0.6049667f, 0.6263938f, 0.22777282f, 0.71313334f, - 0.66620123f, 0.17765433f, 0.78429013f, 0.6621518f, 0.41014817f, 0.7074074f, - 0.67555183f, 0.51060987f, 0.6708259f, 0.7151259f, 0.41302344f, 0.6946963f, - 0.5446962f, 0.33081108f, 0.6180703f, 0.23426408f, 0.25884813f, 0.4744469f, - 0.17217779f, 0.24445555f, 0.44572222f, 0.7964111f, 0.12472223f, 0.7531556f, - 0.6118617f, 0.1483889f, 0.75928515f, 0.4833407f, 0.2004667f, 0.7449173f, - 0.57893336f, 0.3661889f, 0.6485592f, 0.6772543f, 0.46945432f, 0.5984506f, - 0.7796679f, 0.47903457f, 0.617716f, 0.63706285f, 0.40579626f, 0.54952586f, - 0.33111224f, 0.27734566f, 0.42303205f, 0.26992223f, 0.25165558f, 0.39773333f, - 0.7874667f, 0.26583335f, 0.5974333f, 0.4876703f, 0.44144446f, 0.48782218f, - 0.30543333f, 0.57191116f, 0.41133702f, 0.5934334f, 0.5218f, 0.46735552f, - 0.73524815f, 0.5152815f, 0.47753704f, 0.6577852f, 0.5741519f, 0.41896293f, - 0.50037766f, 0.57161117f, 0.3686555f, 0.28967398f, 0.5281297f, 0.3238592f, - 0.24753332f, 0.5194334f, 0.31489998f, 0.72816664f, 0.37683335f, 0.5285778f, - 0.3895555f, 0.5582283f, 0.32292962f, 0.18990126f, 0.6730641f, 0.18445063f, - 0.5460741f, 0.5216629f, 0.31464812f, 0.6978098f, 0.45279747f, 0.36710492f, - 0.5428901f, 0.5077358f, 0.30295062f, 0.42367774f, 0.53567034f, 0.28493333f, - 0.32827038f, 0.54560244f, 0.2976741f, 0.30918893f, 0.5475888f, 0.30022222f, - 0.5933333f, 0.44266668f, 0.59002227f, 0.3305555f, 0.4106049f, 0.31789258f, - 0.16793211f, 0.36878017f, 0.11760493f, 0.40592593f, 0.28790364f, 0.20468517f, - 0.5172234f, 0.22784683f, 0.27239504f, 0.4384765f, 0.19901967f, 0.3110494f, - 0.43695557f, 0.19709623f, 0.34693336f, 0.4869186f, 0.21310854f, 0.38097042f, - 0.49691117f, 0.21631104f, 0.3877778f, 0.37919992f, 0.4914f, 0.56826663f, - 0.26019996f, 0.34673333f, 0.29495183f, 0.21430746f, 0.23090371f, 0.09418149f, - 0.46084452f, 0.23042224f, 0.1835889f, 0.56450003f, 0.23844449f, 0.26893705f, - 0.45383334f, 0.2592223f, 0.34819633f, 0.45761114f, 0.21635559f, 0.38596666f, - 0.5376852f, 0.13105926f, 0.39607778f, 0.55370003f, 0.11400001f, 0.3981f, - 0.11219993f, 0.5287333f, 0.49104443f, 0.18227404f, 0.3386963f, 0.26007527f, - 0.30624574f, 0.20396544f, 0.09970618f, 0.6458075f, 0.2904593f, 0.22173704f, - 0.7636852f, 0.40607417f, 0.32631359f, 0.549037f, 0.5653705f, 0.40470868f, - 0.4831852f, 0.47417036f, 0.40968886f, 0.5165309f, 0.21597281f, 0.3657259f, - 0.5232f, 0.16433334f, 0.3569333f, 0.0588f, 0.5362f, 0.4756f, - 0.16668889f, 0.33708888f, 0.25309998f, 0.32463336f, 0.19857779f, 0.10081112f, - 0.68280005f, 0.3024667f, 0.22936666f, 0.80352217f, 0.43960005f, 0.33778888f, - 0.5680777f, 0.6266f, 0.41601112f, 0.4883f, 0.52573323f, 0.4144333f, - 0.5123f, 0.23295549f, 0.35965553f, 0.5171f, 0.1744f, 0.3487f - }); - //input.linspace(1); - - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input}, {}, {9, 9}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Resized to 9x9"); -// expected.printBuffer("Expect for 9x9"); -// result.printShapeInfo("Output shape"); -// expected.printShapeInfo("Expect shape"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - + NDArray input = NDArrayFactory::create( + 'c', {2, 5, 5, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f, 0.3087f, + 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f, + 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, + 0.7028f, 0.7679f, 0.5373f, 0.7234f, 0.2690f, 0.0062f, 0.0327f, 0.0644f, + 0.8428f, 0.7494f, 0.0755f, 0.6245f, 0.3491f, 0.5793f, 0.5730f, 0.1822f, + 0.6420f, 0.9143f, 0.3019f, 0.3574f, 0.1704f, 0.8395f, 0.5468f, 0.0744f, + 0.9011f, 0.6574f, 0.4124f, 0.2445f, 0.4248f, 0.5219f, 0.6952f, 0.4900f, + 0.2158f, 0.9549f, 0.1386f, 0.1544f, 0.5365f, 0.0134f, 0.4163f, 0.1456f, + 0.4109f, 0.2484f, 0.3330f, 0.2974f, 0.6636f, 0.3808f, 0.8664f, 0.1896f, + 0.7530f, 0.7215f, 0.6612f, 0.7270f, 0.5704f, 0.2666f, 0.7453f, 0.0444f, + 0.3024f, 0.4850f, 0.7982f, 0.0965f, 0.7843f, 0.5075f, 0.0844f, 0.8370f, + 0.6103f, 0.4604f, 0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, + 0.4143f, 0.7821f, 0.3505f, 0.5040f, 0.1180f, 0.8307f, 0.1817f, 0.8442f, + 0.5074f, 0.4471f, 0.5105f, 0.6666f, 0.2576f, 0.2341f, 0.6801f, 0.2652f, + 0.5394f, 0.4690f, 0.6146f, 0.1210f, 0.2576f, 0.0769f, 0.4643f, 0.1628f, + 0.2026f, 0.3774f, 0.0506f, 0.3462f, 0.5720f, 0.0838f, 0.4228f, 0.0588f, + 0.5362f, 0.4756f, 0.2530f, 0.1778f, 0.0751f, 0.8977f, 0.3648f, 0.3065f, + 0.4739f, 0.7014f, 0.4473f, 0.5171f, 0.1744f, 0.3487f}); + + NDArray expected = NDArrayFactory::create( + 'c', {2, 9, 9, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.4744111f, 0.7600333f, + 0.42217776f, 0.26142225f, 0.7454778f, 0.22103335f, 0.41403335f, + 0.8373667f, 0.42420003f, 0.59844446f, 0.71318877f, 0.6011445f, + 0.83055556f, 0.264911f, 0.7387556f, 0.83529997f, 0.2422334f, + 0.5823999f, 0.6884666f, 0.5032889f, 0.23006654f, 0.6591f, + 0.5555f, 0.1596f, 0.5176333f, 0.44208887f, 0.5827889f, + 0.5938309f, 0.5646876f, 0.5123568f, 0.61811364f, 0.6748667f, + 0.44617534f, 0.43473703f, 0.7353667f, 0.3969963f, 0.35003704f, + 0.6654419f, 0.46649635f, 0.41335183f, 0.39988017f, 0.7140149f, + 0.43368888f, 0.45865932f, 0.72049254f, 0.42537406f, 0.73366547f, + 0.5662765f, 0.42371112f, 0.78866667f, 0.53543335f, 0.30312222f, + 0.18414445f, 0.49542224f, 0.67293704f, 0.4168852f, 0.59891605f, + 0.8822444f, 0.60281235f, 0.62855184f, 0.4495222f, 0.6014852f, + 0.36275554f, 0.15933579f, 0.5788963f, 0.34024328f, 0.08295307f, + 0.52441484f, 0.6826569f, 0.10747781f, 0.64715934f, 0.80707777f, + 0.19927411f, 0.8880544f, 0.7861703f, 0.21763334f, 0.9362333f, + 0.78198886f, 0.27523333f, 0.3308667f, 0.6250333f, 0.5907889f, + 0.45925558f, 0.6709963f, 0.7761333f, 0.5249852f, 0.63986665f, + 0.4406333f, 0.34007773f, 0.3003666f, 0.19945924f, 0.33715558f, + 0.24757043f, 0.09977405f, 0.60721123f, 0.6248297f, 0.08286668f, + 0.7239556f, 0.6876333f, 0.12114445f, 0.73849255f, 0.54079986f, + 0.12879999f, 0.74139994f, 0.51143324f, 0.32978892f, 0.45314446f, + 0.58711106f, 0.5576408f, 0.5464408f, 0.6107901f, 0.68978024f, + 0.55681235f, 0.5833172f, 0.43907034f, 0.23548517f, 0.35123706f, + 0.26263458f, 0.18254575f, 0.33890504f, 0.1976099f, 0.5321877f, + 0.65619516f, 0.18267044f, 0.6404851f, 0.63069254f, 0.20112106f, + 0.58788633f, 0.37666163f, 0.20481117f, 0.57736665f, 0.32585555f, + 0.50801116f, 0.5387556f, 0.29788882f, 0.59799266f, 0.7008482f, + 0.35215425f, 0.6330642f, 0.753121f, 0.42497158f, 0.44849625f, + 0.36611477f, 0.5719964f, 0.36038768f, 0.1586321f, 0.70625067f, + 0.416968f, 0.22043455f, 0.82134944f, 0.4690964f, 0.31661478f, + 0.6675073f, 0.5182569f, 0.4357136f, 0.33437145f, 0.528089f, + 0.4595333f, 0.26774442f, 0.52779996f, 0.5559667f, 0.35320008f, + 0.5630963f, 0.62568885f, 0.44562602f, 0.557237f, 0.62408876f, + 0.5438927f, 0.3867555f, 0.3371999f, 0.6655223f, 0.30325183f, + 0.17024446f, 0.71867025f, 0.35021478f, 0.18318895f, 0.6690962f, + 0.4377444f, 0.24482228f, 0.5241777f, 0.5523185f, 0.33891484f, + 0.3156962f, 0.5752333f, 0.3577333f, 0.27400002f, 0.44196665f, + 0.52757776f, 0.6382001f, 0.47803456f, 0.3974851f, 0.7738359f, + 0.4686691f, 0.27816284f, 0.8476581f, 0.2775703f, 0.20192216f, + 0.6742259f, 0.14285672f, 0.20554078f, 0.4944727f, 0.0927209f, + 0.32894826f, 0.30523813f, 0.19454071f, 0.3410815f, 0.26075178f, + 0.3976642f, 0.27903205f, 0.31276423f, 0.43828884f, 0.2666222f, + 0.32316667f, 0.4248f, 0.5219f, 0.6952f, 0.46102223f, + 0.35184443f, 0.8394778f, 0.45095554f, 0.20897777f, 0.9084111f, + 0.2557333f, 0.17486666f, 0.6759666f, 0.11077777f, 0.21260004f, + 0.44963327f, 0.04122221f, 0.35810006f, 0.23246664f, 0.14590007f, + 0.36033332f, 0.2080667f, 0.3667334f, 0.2670555f, 0.31217784f, + 0.4109f, 0.2484f, 0.333f, 0.2974f, 0.6636f, + 0.3808f, 0.6135111f, 0.40026665f, 0.5875778f, 0.8503f, + 0.24200003f, 0.7501111f, 0.76979995f, 0.50400007f, 0.7356667f, + 0.6879222f, 0.57351106f, 0.73106664f, 0.60397774f, 0.35428885f, + 0.74123335f, 0.39506656f, 0.27853334f, 0.6585333f, 0.10284433f, + 0.29842222f, 0.5139222f, 0.0444f, 0.3024f, 0.485f, + 0.5756222f, 0.34854442f, 0.6049667f, 0.6263938f, 0.22777282f, + 0.71313334f, 0.66620123f, 0.17765433f, 0.78429013f, 0.6621518f, + 0.41014817f, 0.7074074f, 0.67555183f, 0.51060987f, 0.6708259f, + 0.7151259f, 0.41302344f, 0.6946963f, 0.5446962f, 0.33081108f, + 0.6180703f, 0.23426408f, 0.25884813f, 0.4744469f, 0.17217779f, + 0.24445555f, 0.44572222f, 0.7964111f, 0.12472223f, 0.7531556f, + 0.6118617f, 0.1483889f, 0.75928515f, 0.4833407f, 0.2004667f, + 0.7449173f, 0.57893336f, 0.3661889f, 0.6485592f, 0.6772543f, + 0.46945432f, 0.5984506f, 0.7796679f, 0.47903457f, 0.617716f, + 0.63706285f, 0.40579626f, 0.54952586f, 0.33111224f, 0.27734566f, + 0.42303205f, 0.26992223f, 0.25165558f, 0.39773333f, 0.7874667f, + 0.26583335f, 0.5974333f, 0.4876703f, 0.44144446f, 0.48782218f, + 0.30543333f, 0.57191116f, 0.41133702f, 0.5934334f, 0.5218f, + 0.46735552f, 0.73524815f, 0.5152815f, 0.47753704f, 0.6577852f, + 0.5741519f, 0.41896293f, 0.50037766f, 0.57161117f, 0.3686555f, + 0.28967398f, 0.5281297f, 0.3238592f, 0.24753332f, 0.5194334f, + 0.31489998f, 0.72816664f, 0.37683335f, 0.5285778f, 0.3895555f, + 0.5582283f, 0.32292962f, 0.18990126f, 0.6730641f, 0.18445063f, + 0.5460741f, 0.5216629f, 0.31464812f, 0.6978098f, 0.45279747f, + 0.36710492f, 0.5428901f, 0.5077358f, 0.30295062f, 0.42367774f, + 0.53567034f, 0.28493333f, 0.32827038f, 0.54560244f, 0.2976741f, + 0.30918893f, 0.5475888f, 0.30022222f, 0.5933333f, 0.44266668f, + 0.59002227f, 0.3305555f, 0.4106049f, 0.31789258f, 0.16793211f, + 0.36878017f, 0.11760493f, 0.40592593f, 0.28790364f, 0.20468517f, + 0.5172234f, 0.22784683f, 0.27239504f, 0.4384765f, 0.19901967f, + 0.3110494f, 0.43695557f, 0.19709623f, 0.34693336f, 0.4869186f, + 0.21310854f, 0.38097042f, 0.49691117f, 0.21631104f, 0.3877778f, + 0.37919992f, 0.4914f, 0.56826663f, 0.26019996f, 0.34673333f, + 0.29495183f, 0.21430746f, 0.23090371f, 0.09418149f, 0.46084452f, + 0.23042224f, 0.1835889f, 0.56450003f, 0.23844449f, 0.26893705f, + 0.45383334f, 0.2592223f, 0.34819633f, 0.45761114f, 0.21635559f, + 0.38596666f, 0.5376852f, 0.13105926f, 0.39607778f, 0.55370003f, + 0.11400001f, 0.3981f, 0.11219993f, 0.5287333f, 0.49104443f, + 0.18227404f, 0.3386963f, 0.26007527f, 0.30624574f, 0.20396544f, + 0.09970618f, 0.6458075f, 0.2904593f, 0.22173704f, 0.7636852f, + 0.40607417f, 0.32631359f, 0.549037f, 0.5653705f, 0.40470868f, + 0.4831852f, 0.47417036f, 0.40968886f, 0.5165309f, 0.21597281f, + 0.3657259f, 0.5232f, 0.16433334f, 0.3569333f, 0.0588f, + 0.5362f, 0.4756f, 0.16668889f, 0.33708888f, 0.25309998f, + 0.32463336f, 0.19857779f, 0.10081112f, 0.68280005f, 0.3024667f, + 0.22936666f, 0.80352217f, 0.43960005f, 0.33778888f, 0.5680777f, + 0.6266f, 0.41601112f, 0.4883f, 0.52573323f, 0.4144333f, + 0.5123f, 0.23295549f, 0.35965553f, 0.5171f, 0.1744f, + 0.3487f}); + // input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {9, 9}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Resized to 9x9"); + // expected.printBuffer("Expect for 9x9"); + // result.printShapeInfo("Output shape"); + // expected.printShapeInfo("Expect shape"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) { - - NDArray input = NDArrayFactory::create('c', {1, 2,3,4}); - NDArray size = NDArrayFactory::create({10, 10}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, - 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., - 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., - 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, - 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, - 9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4, - 11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8, - 7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4, - 10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16., - 13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8, - 8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6, - 11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, - 15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2, - 16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8, - 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6, - 18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16., - 14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6, - 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., - 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., - 20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24., - 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, - 15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8, - 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24., - 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., - 14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6, - 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., - 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4, - 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., - 20.2,21.2, 22.2, 23.2, - 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.}); - //input = 1.f; - input.linspace(1); - - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input, &size}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + NDArray size = NDArrayFactory::create({10, 10}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 10, 10, 4}, + {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, + 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., + 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., + 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, + 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, + 9.4, 10.4, 11.4, 12.4, 10.6, 11.6, 12.6, 13.6, 11.4, 12.4, 13.4, 14.4, + 11.4, 12.4, 13.4, 14.4, 11.4, 12.4, 13.4, 14.4, 5.8, 6.8, 7.8, 8.8, + 7., 8., 9., 10., 8.2, 9.2, 10.2, 11.2, 9.4, 10.4, 11.4, 12.4, + 10.6, 11.6, 12.6, 13.6, 11.8, 12.8, 13.8, 14.8, 13.0, 14.0, 15.0, 16., + 13.8, 14.8, 15.8, 16.8, 13.8, 14.8, 15.8, 16.8, 13.8, 14.8, 15.8, 16.8, + 8.2, 9.2, 10.2, 11.2, 9.4, 10.4, 11.4, 12.4, 10.6, 11.6, 12.6, 13.6, + 11.8, 12.8, 13.8, 14.8, 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4, 16.4, 17.4, 18.4, 16.2, 17.2, 18.2, 19.2, 16.2, 17.2, 18.2, 19.2, + 16.2, 17.2, 18.2, 19.2, 10.6, 11.6, 12.6, 13.6, 11.8, 12.8, 13.8, 14.8, + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 18.6, 19.6, 20.6, 21.6, + 18.6, 19.6, 20.6, 21.6, 18.6, 19.6, 20.6, 21.6, 13., 14., 15., 16., + 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, + 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., + 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, + 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., + 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., + 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, 16.6, 17.6, 18.6, 19.6, + 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 15.4, 16.4, 17.4, 18.4, + 16.6, 17.6, 18.6, 19.6, 17.8, 18.8, 19.8, 20.8, 19., 20., 21., 22., + 20.2, 21.2, 22.2, 23.2, 21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24.}); + // input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) { - - NDArray input = NDArrayFactory::create('c', {1, 2,3,4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, - { 1., 2., 3., 4. , - 1.8888888, 2.8888888, 3.8888888, 4.888889, - 2.7777777, 3.7777777, 4.7777777, 5.7777777, - 3.6666667, 4.666667 , 5.666667, 6.666667 , - 4.5555553, 5.5555553, 6.5555553, 7.5555553, - 5.4444447, 6.4444447, 7.4444447, 8.444445, - 6.3333335, 7.3333335, 8.333334, 9.333334, - 7.2222223, 8.222222, 9.222222, 10.222222, - 8.111111, 9.111111, 10.111111, 11.111111, - 9., 10., 11., 12., - - 2.3333335, 3.3333335, 4.3333335, 5.3333335, - 3.2222223, 4.2222223, 5.2222223, 6.2222223, - 4.111111, 5.111111, 6.111111, 7.111111, - 5., 6., 7., 8., - 5.888889, 6.888889, 7.888889, 8.888888, - 6.777778, 7.777778, 8.777778, 9.777778, - 7.666667, 8.666667, 9.666667, 10.666667, - 8.555555, 9.555555, 10.555555, 11.555555, - 9.444444, 10.444444, 11.444444, 12.444444, - 10.333333, 11.333333, 12.333333, 13.333333, - - 3.6666667, 4.666667, 5.666667, 6.666667, - 4.5555553, 5.5555553, 6.5555553, 7.5555553, - 5.4444447, 6.4444447, 7.4444447, 8.444445 , - 6.3333335, 7.3333335, 8.333334, 9.333334 , - 7.2222223, 8.222222, 9.222222, 10.222222 , - 8.111112, 9.111112, 10.111112, 11.111112 , - 9., 10., 11.000001, 12.000001 , - 9.888889, 10.888889, 11.888889, 12.888889 , - 10.777778, 11.777778, 12.777778, 13.777778 , - 11.666667, 12.666667, 13.666667, 14.666667, - - 5., 6., 7., 8., - 5.888889, 6.888889, 7.888889, 8.888889, - 6.7777777, 7.7777777, 8.777779, 9.777779, - 7.666667, 8.666667, 9.666667, 10.666667, - 8.555555, 9.555555, 10.555555, 11.555555, - 9.444445, 10.444445, 11.444445, 12.444445, - 10.333334, 11.333334, 12.333334, 13.333334, - 11.222222, 12.222222, 13.222222, 14.222222, - 12.111111, 13.111111, 14.111111, 15.111111, - 13., 14., 15., 16., - - 6.3333335, 7.3333335, 8.333334, 9.333334, - 7.2222223, 8.222222, 9.222222, 10.222222, - 8.111111, 9.111111, 10.111112, 11.111112, - 9., 10., 11., 12., - 9.888889, 10.888889, 11.888889, 12.888889, - 10.777779, 11.777779, 12.777779, 13.777779, - 11.666667, 12.666667, 13.666668, 14.666668, - 12.555555, 13.555555, 14.555555, 15.555555, - 13.444445, 14.444445, 15.444445, 16.444445, - 14.333334, 15.333334, 16.333334, 17.333334, - 7.666667, 8.666667, 9.666667, 10.666667, - 8.555555, 9.555555, 10.555555, 11.555555, - 9.444445, 10.444445, 11.444445, 12.444445, - 10.333334, 11.333334, 12.333334, 13.333334, - 11.222222, 12.222222, 13.222222, 14.222222, - 12.111112, 13.111112, 14.111112, 15.111112, - 13., 14., 15.0, 16., - 13.888889, 14.888889, 15.888889, 16.88889, - 14.777778, 15.777778, 16.777779, 17.777779, - 15.666667, 16.666668, 17.666668, 18.666668, - - 9., 10., 11., 12., - 9.888889, 10.888889, 11.888889, 12.888889, - 10.777778, 11.777778, 12.777779, 13.777779, - 11.666667, 12.666666, 13.666666, 14.666666, - 12.555555, 13.555555, 14.555555, 15.555555, - 13.444445, 14.444445, 15.444445, 16.444445, - 14.333334, 15.333334, 16.333334, 17.333334, - 15.222221, 16.222221, 17.222221, 18.222221, - 16.11111, 17.11111, 18.11111, 19.11111, - 17., 18., 19., 20., - - 10.333334, 11.333334, 12.333334, 13.333334, - 11.222223, 12.222223, 13.222223, 14.222223, - 12.111112, 13.111112, 14.111112, 15.111112, - 13.000001, 14., 15., 16., - 13.888889, 14.888889, 15.888889, 16.88889, - 14.777779, 15.777779, 16.777779, 17.777779, - 15.666668, 16.666668, 17.666668, 18.666668, - 16.555555, 17.555555, 18.555555, 19.555555, - 17.444445, 18.444445, 19.444445, 20.444445, - 18.333334, 19.333334, 20.333334, 21.333334, - 11.666667, 12.666667, 13.666667, 14.666667, - 12.555555, 13.555555, 14.555555, 15.555555, - 13.444445, 14.444445, 15.444446, 16.444447, - 14.333334, 15.333333, 16.333332, 17.333332, - 15.222222, 16.222221, 17.222221, 18.222221, - 16.11111, 17.11111, 18.11111, 19.11111, - 17., 18., 19., 20., - 17.88889, 18.88889, 19.88889, 20.88889, - 18.777779, 19.777779, 20.777779, 21.777779, - 19.666668, 20.666668, 21.666668, 22.666668, - - 13., 14., 15., 16., - 13.888889, 14.888889, 15.888889, 16.88889, - 14.777778, 15.777778, 16.777779, 17.777779, - 15.666667, 16.666666, 17.666666, 18.666666, - 16.555555, 17.555555, 18.555555, 19.555555, - 17.444445, 18.444445, 19.444445, 20.444445, - 18.333334, 19.333334, 20.333334, 21.333334, - 19.222221, 20.222221, 21.222221, 22.222221, - 20.11111, 21.11111, 22.11111, 23.11111, - 21., 22., 23., 24.}); - //input = 1.f; - input.linspace(1); - - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input}, {}, {10, 10}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 10, 10, 4}, + {1., 2., 3., 4., 1.8888888, 2.8888888, + 3.8888888, 4.888889, 2.7777777, 3.7777777, 4.7777777, 5.7777777, + 3.6666667, 4.666667, 5.666667, 6.666667, 4.5555553, 5.5555553, + 6.5555553, 7.5555553, 5.4444447, 6.4444447, 7.4444447, 8.444445, + 6.3333335, 7.3333335, 8.333334, 9.333334, 7.2222223, 8.222222, + 9.222222, 10.222222, 8.111111, 9.111111, 10.111111, 11.111111, + 9., 10., 11., 12., + + 2.3333335, 3.3333335, 4.3333335, 5.3333335, 3.2222223, 4.2222223, + 5.2222223, 6.2222223, 4.111111, 5.111111, 6.111111, 7.111111, + 5., 6., 7., 8., 5.888889, 6.888889, + 7.888889, 8.888888, 6.777778, 7.777778, 8.777778, 9.777778, + 7.666667, 8.666667, 9.666667, 10.666667, 8.555555, 9.555555, + 10.555555, 11.555555, 9.444444, 10.444444, 11.444444, 12.444444, + 10.333333, 11.333333, 12.333333, 13.333333, + + 3.6666667, 4.666667, 5.666667, 6.666667, 4.5555553, 5.5555553, + 6.5555553, 7.5555553, 5.4444447, 6.4444447, 7.4444447, 8.444445, + 6.3333335, 7.3333335, 8.333334, 9.333334, 7.2222223, 8.222222, + 9.222222, 10.222222, 8.111112, 9.111112, 10.111112, 11.111112, + 9., 10., 11.000001, 12.000001, 9.888889, 10.888889, + 11.888889, 12.888889, 10.777778, 11.777778, 12.777778, 13.777778, + 11.666667, 12.666667, 13.666667, 14.666667, + + 5., 6., 7., 8., 5.888889, 6.888889, + 7.888889, 8.888889, 6.7777777, 7.7777777, 8.777779, 9.777779, + 7.666667, 8.666667, 9.666667, 10.666667, 8.555555, 9.555555, + 10.555555, 11.555555, 9.444445, 10.444445, 11.444445, 12.444445, + 10.333334, 11.333334, 12.333334, 13.333334, 11.222222, 12.222222, + 13.222222, 14.222222, 12.111111, 13.111111, 14.111111, 15.111111, + 13., 14., 15., 16., + + 6.3333335, 7.3333335, 8.333334, 9.333334, 7.2222223, 8.222222, + 9.222222, 10.222222, 8.111111, 9.111111, 10.111112, 11.111112, + 9., 10., 11., 12., 9.888889, 10.888889, + 11.888889, 12.888889, 10.777779, 11.777779, 12.777779, 13.777779, + 11.666667, 12.666667, 13.666668, 14.666668, 12.555555, 13.555555, + 14.555555, 15.555555, 13.444445, 14.444445, 15.444445, 16.444445, + 14.333334, 15.333334, 16.333334, 17.333334, 7.666667, 8.666667, + 9.666667, 10.666667, 8.555555, 9.555555, 10.555555, 11.555555, + 9.444445, 10.444445, 11.444445, 12.444445, 10.333334, 11.333334, + 12.333334, 13.333334, 11.222222, 12.222222, 13.222222, 14.222222, + 12.111112, 13.111112, 14.111112, 15.111112, 13., 14., + 15.0, 16., 13.888889, 14.888889, 15.888889, 16.88889, + 14.777778, 15.777778, 16.777779, 17.777779, 15.666667, 16.666668, + 17.666668, 18.666668, + + 9., 10., 11., 12., 9.888889, 10.888889, + 11.888889, 12.888889, 10.777778, 11.777778, 12.777779, 13.777779, + 11.666667, 12.666666, 13.666666, 14.666666, 12.555555, 13.555555, + 14.555555, 15.555555, 13.444445, 14.444445, 15.444445, 16.444445, + 14.333334, 15.333334, 16.333334, 17.333334, 15.222221, 16.222221, + 17.222221, 18.222221, 16.11111, 17.11111, 18.11111, 19.11111, + 17., 18., 19., 20., + + 10.333334, 11.333334, 12.333334, 13.333334, 11.222223, 12.222223, + 13.222223, 14.222223, 12.111112, 13.111112, 14.111112, 15.111112, + 13.000001, 14., 15., 16., 13.888889, 14.888889, + 15.888889, 16.88889, 14.777779, 15.777779, 16.777779, 17.777779, + 15.666668, 16.666668, 17.666668, 18.666668, 16.555555, 17.555555, + 18.555555, 19.555555, 17.444445, 18.444445, 19.444445, 20.444445, + 18.333334, 19.333334, 20.333334, 21.333334, 11.666667, 12.666667, + 13.666667, 14.666667, 12.555555, 13.555555, 14.555555, 15.555555, + 13.444445, 14.444445, 15.444446, 16.444447, 14.333334, 15.333333, + 16.333332, 17.333332, 15.222222, 16.222221, 17.222221, 18.222221, + 16.11111, 17.11111, 18.11111, 19.11111, 17., 18., + 19., 20., 17.88889, 18.88889, 19.88889, 20.88889, + 18.777779, 19.777779, 20.777779, 21.777779, 19.666668, 20.666668, + 21.666668, 22.666668, + + 13., 14., 15., 16., 13.888889, 14.888889, + 15.888889, 16.88889, 14.777778, 15.777778, 16.777779, 17.777779, + 15.666667, 16.666666, 17.666666, 18.666666, 16.555555, 17.555555, + 18.555555, 19.555555, 17.444445, 18.444445, 19.444445, 20.444445, + 18.333334, 19.333334, 20.333334, 21.333334, 19.222221, 20.222221, + 21.222221, 22.222221, 20.11111, 21.11111, 22.11111, 23.11111, + 21., 22., 23., 24.}); + // input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {10, 10}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test4) { - - NDArray input = NDArrayFactory::create('c', {1, 2,3,4}); - NDArray size = NDArrayFactory::create({10, 10}); - NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, - { 1., 2., 3., 4. , - 1.8888888, 2.8888888, 3.8888888, 4.888889, - 2.7777777, 3.7777777, 4.7777777, 5.7777777, - 3.6666667, 4.666667 , 5.666667, 6.666667 , - 4.5555553, 5.5555553, 6.5555553, 7.5555553, - 5.4444447, 6.4444447, 7.4444447, 8.444445, - 6.3333335, 7.3333335, 8.333334, 9.333334, - 7.2222223, 8.222222, 9.222222, 10.222222, - 8.111111, 9.111111, 10.111111, 11.111111, - 9., 10., 11., 12., - - 2.3333335, 3.3333335, 4.3333335, 5.3333335, - 3.2222223, 4.2222223, 5.2222223, 6.2222223, - 4.111111, 5.111111, 6.111111, 7.111111, - 5., 6., 7., 8., - 5.888889, 6.888889, 7.888889, 8.888888, - 6.777778, 7.777778, 8.777778, 9.777778, - 7.666667, 8.666667, 9.666667, 10.666667, - 8.555555, 9.555555, 10.555555, 11.555555, - 9.444444, 10.444444, 11.444444, 12.444444, - 10.333333, 11.333333, 12.333333, 13.333333, - - 3.6666667, 4.666667, 5.666667, 6.666667, - 4.5555553, 5.5555553, 6.5555553, 7.5555553, - 5.4444447, 6.4444447, 7.4444447, 8.444445 , - 6.3333335, 7.3333335, 8.333334, 9.333334 , - 7.2222223, 8.222222, 9.222222, 10.222222 , - 8.111112, 9.111112, 10.111112, 11.111112 , - 9., 10., 11.000001, 12.000001 , - 9.888889, 10.888889, 11.888889, 12.888889 , - 10.777778, 11.777778, 12.777778, 13.777778 , - 11.666667, 12.666667, 13.666667, 14.666667, - - 5., 6., 7., 8., - 5.888889, 6.888889, 7.888889, 8.888889, - 6.7777777, 7.7777777, 8.777779, 9.777779, - 7.666667, 8.666667, 9.666667, 10.666667, - 8.555555, 9.555555, 10.555555, 11.555555, - 9.444445, 10.444445, 11.444445, 12.444445, - 10.333334, 11.333334, 12.333334, 13.333334, - 11.222222, 12.222222, 13.222222, 14.222222, - 12.111111, 13.111111, 14.111111, 15.111111, - 13., 14., 15., 16., - - 6.3333335, 7.3333335, 8.333334, 9.333334, - 7.2222223, 8.222222, 9.222222, 10.222222, - 8.111111, 9.111111, 10.111112, 11.111112, - 9., 10., 11., 12., - 9.888889, 10.888889, 11.888889, 12.888889, - 10.777779, 11.777779, 12.777779, 13.777779, - 11.666667, 12.666667, 13.666668, 14.666668, - 12.555555, 13.555555, 14.555555, 15.555555, - 13.444445, 14.444445, 15.444445, 16.444445, - 14.333334, 15.333334, 16.333334, 17.333334, - 7.666667, 8.666667, 9.666667, 10.666667, - 8.555555, 9.555555, 10.555555, 11.555555, - 9.444445, 10.444445, 11.444445, 12.444445, - 10.333334, 11.333334, 12.333334, 13.333334, - 11.222222, 12.222222, 13.222222, 14.222222, - 12.111112, 13.111112, 14.111112, 15.111112, - 13., 14., 15.0, 16., - 13.888889, 14.888889, 15.888889, 16.88889, - 14.777778, 15.777778, 16.777779, 17.777779, - 15.666667, 16.666668, 17.666668, 18.666668, - - 9., 10., 11., 12., - 9.888889, 10.888889, 11.888889, 12.888889, - 10.777778, 11.777778, 12.777779, 13.777779, - 11.666667, 12.666666, 13.666666, 14.666666, - 12.555555, 13.555555, 14.555555, 15.555555, - 13.444445, 14.444445, 15.444445, 16.444445, - 14.333334, 15.333334, 16.333334, 17.333334, - 15.222221, 16.222221, 17.222221, 18.222221, - 16.11111, 17.11111, 18.11111, 19.11111, - 17., 18., 19., 20., - - 10.333334, 11.333334, 12.333334, 13.333334, - 11.222223, 12.222223, 13.222223, 14.222223, - 12.111112, 13.111112, 14.111112, 15.111112, - 13.000001, 14., 15., 16., - 13.888889, 14.888889, 15.888889, 16.88889, - 14.777779, 15.777779, 16.777779, 17.777779, - 15.666668, 16.666668, 17.666668, 18.666668, - 16.555555, 17.555555, 18.555555, 19.555555, - 17.444445, 18.444445, 19.444445, 20.444445, - 18.333334, 19.333334, 20.333334, 21.333334, - 11.666667, 12.666667, 13.666667, 14.666667, - 12.555555, 13.555555, 14.555555, 15.555555, - 13.444445, 14.444445, 15.444446, 16.444447, - 14.333334, 15.333333, 16.333332, 17.333332, - 15.222222, 16.222221, 17.222221, 18.222221, - 16.11111, 17.11111, 18.11111, 19.11111, - 17., 18., 19., 20., - 17.88889, 18.88889, 19.88889, 20.88889, - 18.777779, 19.777779, 20.777779, 21.777779, - 19.666668, 20.666668, 21.666668, 22.666668, - - 13., 14., 15., 16., - 13.888889, 14.888889, 15.888889, 16.88889, - 14.777778, 15.777778, 16.777779, 17.777779, - 15.666667, 16.666666, 17.666666, 18.666666, - 16.555555, 17.555555, 18.555555, 19.555555, - 17.444445, 18.444445, 19.444445, 20.444445, - 18.333334, 19.333334, 20.333334, 21.333334, - 19.222221, 20.222221, 21.222221, 22.222221, - 20.11111, 21.11111, 22.11111, 23.11111, - 21., 22., 23., 24.}); - //input = 1.f; - input.linspace(1); - - sd::ops::resize_bilinear op; - auto results = op.evaluate({&input, &size}, {}, {}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printIndexedBuffer("Resized to 10x10"); -// expected.printIndexedBuffer("Expected of 10x10"); -// result.printShapeInfo("Resized to 10x10 shape"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + NDArray size = NDArrayFactory::create({10, 10}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 10, 10, 4}, + {1., 2., 3., 4., 1.8888888, 2.8888888, + 3.8888888, 4.888889, 2.7777777, 3.7777777, 4.7777777, 5.7777777, + 3.6666667, 4.666667, 5.666667, 6.666667, 4.5555553, 5.5555553, + 6.5555553, 7.5555553, 5.4444447, 6.4444447, 7.4444447, 8.444445, + 6.3333335, 7.3333335, 8.333334, 9.333334, 7.2222223, 8.222222, + 9.222222, 10.222222, 8.111111, 9.111111, 10.111111, 11.111111, + 9., 10., 11., 12., + + 2.3333335, 3.3333335, 4.3333335, 5.3333335, 3.2222223, 4.2222223, + 5.2222223, 6.2222223, 4.111111, 5.111111, 6.111111, 7.111111, + 5., 6., 7., 8., 5.888889, 6.888889, + 7.888889, 8.888888, 6.777778, 7.777778, 8.777778, 9.777778, + 7.666667, 8.666667, 9.666667, 10.666667, 8.555555, 9.555555, + 10.555555, 11.555555, 9.444444, 10.444444, 11.444444, 12.444444, + 10.333333, 11.333333, 12.333333, 13.333333, + + 3.6666667, 4.666667, 5.666667, 6.666667, 4.5555553, 5.5555553, + 6.5555553, 7.5555553, 5.4444447, 6.4444447, 7.4444447, 8.444445, + 6.3333335, 7.3333335, 8.333334, 9.333334, 7.2222223, 8.222222, + 9.222222, 10.222222, 8.111112, 9.111112, 10.111112, 11.111112, + 9., 10., 11.000001, 12.000001, 9.888889, 10.888889, + 11.888889, 12.888889, 10.777778, 11.777778, 12.777778, 13.777778, + 11.666667, 12.666667, 13.666667, 14.666667, + + 5., 6., 7., 8., 5.888889, 6.888889, + 7.888889, 8.888889, 6.7777777, 7.7777777, 8.777779, 9.777779, + 7.666667, 8.666667, 9.666667, 10.666667, 8.555555, 9.555555, + 10.555555, 11.555555, 9.444445, 10.444445, 11.444445, 12.444445, + 10.333334, 11.333334, 12.333334, 13.333334, 11.222222, 12.222222, + 13.222222, 14.222222, 12.111111, 13.111111, 14.111111, 15.111111, + 13., 14., 15., 16., + + 6.3333335, 7.3333335, 8.333334, 9.333334, 7.2222223, 8.222222, + 9.222222, 10.222222, 8.111111, 9.111111, 10.111112, 11.111112, + 9., 10., 11., 12., 9.888889, 10.888889, + 11.888889, 12.888889, 10.777779, 11.777779, 12.777779, 13.777779, + 11.666667, 12.666667, 13.666668, 14.666668, 12.555555, 13.555555, + 14.555555, 15.555555, 13.444445, 14.444445, 15.444445, 16.444445, + 14.333334, 15.333334, 16.333334, 17.333334, 7.666667, 8.666667, + 9.666667, 10.666667, 8.555555, 9.555555, 10.555555, 11.555555, + 9.444445, 10.444445, 11.444445, 12.444445, 10.333334, 11.333334, + 12.333334, 13.333334, 11.222222, 12.222222, 13.222222, 14.222222, + 12.111112, 13.111112, 14.111112, 15.111112, 13., 14., + 15.0, 16., 13.888889, 14.888889, 15.888889, 16.88889, + 14.777778, 15.777778, 16.777779, 17.777779, 15.666667, 16.666668, + 17.666668, 18.666668, + + 9., 10., 11., 12., 9.888889, 10.888889, + 11.888889, 12.888889, 10.777778, 11.777778, 12.777779, 13.777779, + 11.666667, 12.666666, 13.666666, 14.666666, 12.555555, 13.555555, + 14.555555, 15.555555, 13.444445, 14.444445, 15.444445, 16.444445, + 14.333334, 15.333334, 16.333334, 17.333334, 15.222221, 16.222221, + 17.222221, 18.222221, 16.11111, 17.11111, 18.11111, 19.11111, + 17., 18., 19., 20., + + 10.333334, 11.333334, 12.333334, 13.333334, 11.222223, 12.222223, + 13.222223, 14.222223, 12.111112, 13.111112, 14.111112, 15.111112, + 13.000001, 14., 15., 16., 13.888889, 14.888889, + 15.888889, 16.88889, 14.777779, 15.777779, 16.777779, 17.777779, + 15.666668, 16.666668, 17.666668, 18.666668, 16.555555, 17.555555, + 18.555555, 19.555555, 17.444445, 18.444445, 19.444445, 20.444445, + 18.333334, 19.333334, 20.333334, 21.333334, 11.666667, 12.666667, + 13.666667, 14.666667, 12.555555, 13.555555, 14.555555, 15.555555, + 13.444445, 14.444445, 15.444446, 16.444447, 14.333334, 15.333333, + 16.333332, 17.333332, 15.222222, 16.222221, 17.222221, 18.222221, + 16.11111, 17.11111, 18.11111, 19.11111, 17., 18., + 19., 20., 17.88889, 18.88889, 19.88889, 20.88889, + 18.777779, 19.777779, 20.777779, 21.777779, 19.666668, 20.666668, + 21.666668, 22.666668, + + 13., 14., 15., 16., 13.888889, 14.888889, + 15.888889, 16.88889, 14.777778, 15.777778, 16.777779, 17.777779, + 15.666667, 16.666666, 17.666666, 18.666666, 16.555555, 17.555555, + 18.555555, 19.555555, 17.444445, 18.444445, 19.444445, 20.444445, + 18.333334, 19.333334, 20.333334, 21.333334, 19.222221, 20.222221, + 21.222221, 22.222221, 20.11111, 21.11111, 22.11111, 23.11111, + 21., 22., 23., 24.}); + // input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input, &size}, {}, {}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printIndexedBuffer("Resized to 10x10"); + // expected.printIndexedBuffer("Expected of 10x10"); + // result.printShapeInfo("Resized to 10x10 shape"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, LinSpace_Test1) { + NDArray start = NDArrayFactory::create(1.); + NDArray finish = NDArrayFactory::create(12.); + NDArray num = NDArrayFactory::create(23); + NDArray expect = NDArrayFactory::create( + {1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, + 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}); - NDArray start = NDArrayFactory::create(1.); - NDArray finish = NDArrayFactory::create(12.); - NDArray num = NDArrayFactory::create(23); - NDArray expect = NDArrayFactory::create({1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, - 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}); - - sd::ops::lin_space op; - auto result = op.evaluate({&start, &finish, &num}, {}, {}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto res = result.at(0); - - ASSERT_TRUE(expect.equalsTo(res)); + sd::ops::lin_space op; + auto result = op.evaluate({&start, &finish, &num}, {}, {}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + ASSERT_TRUE(expect.equalsTo(res)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, LinSpace_Test2) { + NDArray expect = NDArrayFactory::create( + {1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, + 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}); - NDArray expect = NDArrayFactory::create({1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, - 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}); - - sd::ops::lin_space op; - auto result = op.evaluate({}, {1, 12}, {23}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto res = result.at(0); - ASSERT_EQ( res.dataType(), sd::DataType::FLOAT32 ); - ASSERT_TRUE(expect.equalsTo(res)); - + sd::ops::lin_space op; + auto result = op.evaluate({}, {1, 12}, {23}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + ASSERT_EQ(res.dataType(), sd::DataType::FLOAT32); + ASSERT_TRUE(expect.equalsTo(res)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, LinSpace_Test3) { + NDArray expect('c', {23}, + {1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, + 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}, + sd::DataType::DOUBLE); - NDArray expect('c', { 23 }, {1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}, sd::DataType::DOUBLE ); - - sd::ops::lin_space op; - auto result = op.evaluate({}, {1, 12}, {23}, {}, { sd::DOUBLE }); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto res = result.at(0); - - ASSERT_EQ( res.dataType(), expect.dataType()); - ASSERT_TRUE(expect.equalsTo(res)); + sd::ops::lin_space op; + auto result = op.evaluate({}, {1, 12}, {23}, {}, {sd::DOUBLE}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + ASSERT_EQ(res.dataType(), expect.dataType()); + ASSERT_TRUE(expect.equalsTo(res)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) { + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 4, 5, 4}, {1, 2, 3, 4, 1, 2, 3, 4, 5, 6, + 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { - 1, 2, 3, 4, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 6, 7, 8, - 9, 10, 11, 12, - - 1, 2, 3, 4, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 6, 7, 8, - 9, 10, 11, 12, - - 13, 14, 15, 16, - 13, 14, 15, 16, - 17, 18, 19, 20, - 17, 18, 19, 20, - 21, 22, 23, 24, - - 13, 14, 15, 16, - 13, 14, 15, 16, - 17, 18, 19, 20, - 17, 18, 19, 20, - 21, 22, 23, 24 - }); - //input = 1.f; - input.linspace(1); - - sd::ops::resize_nearest_neighbor op; - auto results = op.evaluate({&input}, {}, {4, 5}, {false, false}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printIndexedBuffer("Resized to 4x5"); -// expected.printIndexedBuffer("Expect for 4x5"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); -} + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, + 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, -TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) { + 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, + 19, 20, 17, 18, 19, 20, 21, 22, 23, 24, - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { - 1, 2, 3, 4, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 6, 7, 8, - 9, 10, 11, 12, - - 1, 2, 3, 4, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 6, 7, 8, - 9, 10, 11, 12, - - 13, 14, 15, 16, - 13, 14, 15, 16, - 17, 18, 19, 20, - 17, 18, 19, 20, - 21, 22, 23, 24, - - 13, 14, 15, 16, - 13, 14, 15, 16, - 17, 18, 19, 20, - 17, 18, 19, 20, - 21, 22, 23, 24 - }); - //input = 1.f; - input.linspace(1); - - sd::ops::resize_nearest_neighbor op; - auto results = op.evaluate({&input}, {}, {4, 5}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printIndexedBuffer("Resized to 4x5"); -// expected.printIndexedBuffer("Expect for 4x5"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); -} + 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, + 19, 20, 17, 18, 19, 20, 21, 22, 23, 24}); + // input = 1.f; + input.linspace(1); -TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1_1) { + sd::ops::resize_nearest_neighbor op; + auto results = op.evaluate({&input}, {}, {4, 5}, {false, false}); - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, - 9.f, 10.f, 11.f, 12.f, - - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, - 9.f, 10.f, 11.f, 12.f, - - 13.f, 14.f, 15.f, 16.f, - 13.f, 14.f, 15.f, 16.f, - 17.f, 18.f, 19.f, 20.f, - 21.f, 22.f, 23.f, 24.f, - 21.f, 22.f, 23.f, 24.f, - - 13.f, 14.f, 15.f, 16.f, - 13.f, 14.f, 15.f, 16.f, - 17.f, 18.f, 19.f, 20.f, - 21.f, 22.f, 23.f, 24.f, - 21.f, 22.f, 23.f, 24.f - }); - //input = 1.f; - input.linspace(1); - - sd::ops::resize_nearest_neighbor op; - auto results = op.evaluate({&input}, {}, {4,5}, {false, true}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printIndexedBuffer("Resized to 4x5"); -// expected.printBuffer("Expect for 4x5"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); -} + ASSERT_EQ(ND4J_STATUS_OK, results.status()); -TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) { + auto result = results.at(0); - NDArray input = NDArrayFactory::create('c', {2, 3, 4}); - //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); - //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {4, 5, 4}, { 1, 2, 3, 4, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 6, 7, 8, - 9, 10, 11, 12, - - 1, 2, 3, 4, - 1, 2, 3, 4, - 5, 6, 7, 8, - 5, 6, 7, 8, - 9, 10, 11, 12, - - 13, 14, 15, 16, - 13, 14, 15, 16, - 17, 18, 19, 20, - 17, 18, 19, 20, - 21, 22, 23, 24, - - 13, 14, 15, 16, - 13, 14, 15, 16, - 17, 18, 19, 20, - 17, 18, 19, 20, - 21, 22, 23, 24 - }); - //input = 1.f; - input.linspace(1); - - sd::ops::resize_nearest_neighbor op; - auto results = op.evaluate({&input}, {}, {4, 5}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - //result.printIndexedBuffer("Resized to 4x5"); - //expected.printIndexedBuffer("Expect for 4x5"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + // result.printIndexedBuffer("Resized to 4x5"); + // expected.printIndexedBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_1) { +TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) { + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 4, 5, 4}, {1, 2, 3, 4, 1, 2, 3, 4, 5, 6, + 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, - NDArray input = NDArrayFactory::create ('c', {3,3}, {0, 1, 0, 0, 1, 0, 0, 0, 0}); + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, + 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, - NDArray expected = NDArrayFactory::create(2.5206409f); + 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, + 19, 20, 17, 18, 19, 20, 21, 22, 23, 24, - sd::ops::reduce_logsumexp op; - auto results = op.evaluate({&input}, {}, {}); + 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, + 19, 20, 17, 18, 19, 20, 21, 22, 23, 24}); + // input = 1.f; + input.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::resize_nearest_neighbor op; + auto results = op.evaluate({&input}, {}, {4, 5}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printIndexedBuffer("Resized to 4x5"); + // expected.printIndexedBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_2) { +TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1_1) { + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 4, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 9.f, 10.f, 11.f, 12.f, - NDArray input = NDArrayFactory::create('c', {3,3}, {0, 1, 0, 0, 1, 0, 0, 0, 0}); + 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 9.f, 10.f, 11.f, 12.f, - NDArray expected = NDArrayFactory::create({1.0986123f, 1.8619947f, 1.0986123f}); + 13.f, 14.f, 15.f, 16.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 21.f, 22.f, 23.f, 24.f, - sd::ops::reduce_logsumexp op; - auto results = op.evaluate({&input}, {}, {0}); + 13.f, 14.f, 15.f, 16.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 21.f, 22.f, 23.f, 24.f}); + // input = 1.f; + input.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::resize_nearest_neighbor op; + auto results = op.evaluate({&input}, {}, {4, 5}, {false, true}); - auto result = results.at(0); -// result.printIndexedBuffer("REDUCE_LOGSUMEXP"); -// expected.printIndexedBuffer("LSE EXPECTED"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result.printIndexedBuffer("Resized to 4x5"); + // expected.printBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_3) { - - NDArray input = NDArrayFactory::create('c', {3,3}, {0, 1, 0, 0, 1, 0, 0, 0, 0}); - - NDArray expected = NDArrayFactory::create('c', {1,3}, {1.0986123f, 1.8619947f, 1.0986123f}); - sd::ops::reduce_logsumexp op; - auto results = op.evaluate({&input}, {1.f}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printIndexedBuffer("REDUCE_LOGSUMEXP"); -// expected.printIndexedBuffer("LSE EXPECTED"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); -} -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) { +TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) { + NDArray input = NDArrayFactory::create('c', {2, 3, 4}); + // NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + // NDArray expected('c', {2,4,4}, + // {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create( + 'c', {4, 5, 4}, {1, 2, 3, 4, 1, 2, 3, 4, 5, 6, + 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, - NDArray boxes = NDArrayFactory::create('c', {3,4}); - NDArray scores = NDArrayFactory::create('c', {3}, {1, 2, 3}); - NDArray expected = NDArrayFactory::create('c', {3}, {2, 1, 0}); - boxes.linspace(1.f); + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, + 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, - sd::ops::non_max_suppression op; - auto results = op.evaluate({&boxes, &scores}, {}, {3}); + 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, + 19, 20, 17, 18, 19, 20, 21, 22, 23, 24, - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + 13, 14, 15, 16, 13, 14, 15, 16, 17, 18, + 19, 20, 17, 18, 19, 20, 21, 22, 23, 24}); + // input = 1.f; + input.linspace(1); - auto result = results.at(0); - //result.printIndexedBuffer("OOOOUUUUTTT"); + sd::ops::resize_nearest_neighbor op; + auto results = op.evaluate({&input}, {}, {4, 5}); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result.printIndexedBuffer("Resized to 4x5"); + // expected.printIndexedBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) { +TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_1) { + NDArray input = + NDArrayFactory::create('c', {3, 3}, {0, 1, 0, 0, 1, 0, 0, 0, 0}); + + NDArray expected = NDArrayFactory::create(2.5206409f); - NDArray boxes = NDArrayFactory::create('c', {6,4}, {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1.f, 0.9f, - 0, 10, 1, 11, 0, 10.1f, 1.f, 11.1f, 0, 100, 1, 101}); - NDArray scales = NDArrayFactory::create('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //3, 0, 1, 2, 4, 5 - NDArray expected = NDArrayFactory::create('c', {3}, {3,0,5}); + sd::ops::reduce_logsumexp op; + auto results = op.evaluate({&input}, {}, {}); - sd::ops::non_max_suppression op; - auto results = op.evaluate({&boxes, &scales}, {0.5}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); -// result.printBuffer("NonMaxSuppression OUtput2"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_3) { +TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_2) { + NDArray input = + NDArrayFactory::create('c', {3, 3}, {0, 1, 0, 0, 1, 0, 0, 0, 0}); - NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f, - 0.7412f, 0.7607f, 0.1543f, 0.5479f, - 0.8223f, 0.2246f, 0.0049f, 0.6465f}); - NDArray scales = NDArrayFactory::create('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5 - NDArray expected = NDArrayFactory::create('c', {1}, {1}); + NDArray expected = + NDArrayFactory::create({1.0986123f, 1.8619947f, 1.0986123f}); - sd::ops::non_max_suppression op; - auto results = op.evaluate({&boxes, &scales}, {0.5, 0.5}, {2}); + sd::ops::reduce_logsumexp op; + auto results = op.evaluate({&input}, {}, {0}); - ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); -// result.printBuffer("NonMaxSuppression OUtput3"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printIndexedBuffer("REDUCE_LOGSUMEXP"); + // expected.printIndexedBuffer("LSE EXPECTED"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_3) { + NDArray input = + NDArrayFactory::create('c', {3, 3}, {0, 1, 0, 0, 1, 0, 0, 0, 0}); -TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_4) { + NDArray expected = NDArrayFactory::create( + 'c', {1, 3}, {1.0986123f, 1.8619947f, 1.0986123f}); - NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f, - 0.7412f, 0.7607f, 0.1543f, 0.5479f, - 0.8223f, 0.2246f, 0.0049f, 0.6465f}); - NDArray scales = NDArrayFactory::create('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5 - NDArray expected = NDArrayFactory::create('c', {1}, {1}); - NDArray maxSize = NDArrayFactory::create(2); - NDArray threshold = NDArrayFactory::create(0.5f); - NDArray scoreThreshold = NDArrayFactory::create(0.5); - sd::ops::non_max_suppression op; - auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + sd::ops::reduce_logsumexp op; + auto results = op.evaluate({&input}, {1.f}, {0}); - ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); -// result.printBuffer("NonMaxSuppression OUtput4"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printIndexedBuffer("REDUCE_LOGSUMEXP"); + // expected.printIndexedBuffer("LSE EXPECTED"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } -TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) { +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) { + NDArray boxes = NDArrayFactory::create('c', {3, 4}); + NDArray scores = NDArrayFactory::create('c', {3}, {1, 2, 3}); + NDArray expected = NDArrayFactory::create('c', {3}, {2, 1, 0}); + boxes.linspace(1.f); + + sd::ops::non_max_suppression op; + auto results = op.evaluate({&boxes, &scores}, {}, {3}); - NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f, - 0.7412f, 0.7607f, 0.1543f, 0.5479f, - 0.8223f, 0.2246f, 0.0049f, 0.6465f}); - NDArray scales = NDArrayFactory::create('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5 - NDArray expected = NDArrayFactory::create('c', {2}, {1, 2}); - NDArray maxSize = NDArrayFactory::create(2); - NDArray threshold = NDArrayFactory::create(0.5f); - NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax()); - sd::ops::non_max_suppression op; - auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(Status::OK(), results.status()); + auto result = results.at(0); + // result.printIndexedBuffer("OOOOUUUUTTT"); - auto result = results.at(0); -// result.printBuffer("NonMaxSuppression OUtput4"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } -TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_6) { +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) { + NDArray boxes = NDArrayFactory::create( + 'c', {6, 4}, {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1.f, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1.f, 11.1f, 0, 100, 1, 101}); + NDArray scales = NDArrayFactory::create( + 'c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); // 3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {3}, {3, 0, 5}); - NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f, - 0.7412f, 0.7607f, 0.1543f, 0.5479f, - 0.8223f, 0.2246f, 0.0049f, 0.6465f}); - NDArray scales = NDArrayFactory::create('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5 - NDArray expected = NDArrayFactory::create('c', {2}, {1,2}); - NDArray maxSize = NDArrayFactory::create(2); - NDArray threshold = NDArrayFactory::create(0.5f); - NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax()); - sd::ops::non_max_suppression_v3 op; - auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + sd::ops::non_max_suppression op; + auto results = op.evaluate({&boxes, &scales}, {0.5}, {3}); - ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); -// result.printBuffer("NonMaxSuppression OUtput6"); -// result.printShapeInfo("Ouput6 shape is"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printBuffer("NonMaxSuppression OUtput2"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } -TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_06) { +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_3) { + NDArray boxes = NDArrayFactory::create( + 'c', {3, 4}, + {0.8115f, 0.4121f, 0.0771f, 0.4863f, 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}); + NDArray scales = NDArrayFactory::create( + 'c', {3}, {0.0029f, 0.8135f, 0.4873f}); // 3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {1}, {1}); - NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f, - 0.7412f, 0.7607f, 0.1543f, 0.5479f, - 0.8223f, 0.2246f, 0.0049f, 0.6465f}); - NDArray scales = NDArrayFactory::create('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5 - NDArray expected = NDArrayFactory::create('c', {2}, {1,2}); - NDArray maxSize = NDArrayFactory::create(2); - NDArray threshold = NDArrayFactory::create(0.5f); - NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax()); - sd::ops::non_max_suppression_v3 op; - auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + sd::ops::non_max_suppression op; + auto results = op.evaluate({&boxes, &scales}, {0.5, 0.5}, {2}); - ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(Status::OK(), results.status()); - auto result = results.at(0); -// result.printBuffer("NonMaxSuppression OUtput06"); -// result.printShapeInfo("Ouput06 shape is"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printBuffer("NonMaxSuppression OUtput3"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } -TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_7) { +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_4) { + NDArray boxes = NDArrayFactory::create( + 'c', {3, 4}, + {0.8115f, 0.4121f, 0.0771f, 0.4863f, 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}); + NDArray scales = NDArrayFactory::create( + 'c', {3}, {0.0029f, 0.8135f, 0.4873f}); // 3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {1}, {1}); + NDArray maxSize = NDArrayFactory::create(2); + NDArray threshold = NDArrayFactory::create(0.5f); + NDArray scoreThreshold = NDArrayFactory::create(0.5); + sd::ops::non_max_suppression op; + auto results = op.evaluate( + {&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto result = results.at(0); + // result.printBuffer("NonMaxSuppression OUtput4"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) { + NDArray boxes = NDArrayFactory::create( + 'c', {3, 4}, + {0.8115f, 0.4121f, 0.0771f, 0.4863f, 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}); + NDArray scales = NDArrayFactory::create( + 'c', {3}, {0.0029f, 0.8135f, 0.4873f}); // 3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {2}, {1, 2}); + NDArray maxSize = NDArrayFactory::create(2); + NDArray threshold = NDArrayFactory::create(0.5f); + NDArray scoreThreshold = + NDArrayFactory::create(-DataTypeUtils::infOrMax()); + sd::ops::non_max_suppression op; + auto results = op.evaluate( + {&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto result = results.at(0); + // result.printBuffer("NonMaxSuppression OUtput4"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2329f, - 0.7271f, 0.1804f, 0.5056f, 0.8929f, - 0.5461f, 0.9234f, 0.0856f, 0.7938f}); - NDArray scales = NDArrayFactory::create('c', {3}, {0.7717f, 0.9281f, 0.9846f}); //3, 0, 1, 2, 4, 5 - NDArray maxSize = NDArrayFactory::create(0); - NDArray threshold = NDArrayFactory::create(0.5f); - NDArray scoreThreshold = NDArrayFactory::create(0.5f); - sd::ops::non_max_suppression_v3 op; - auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_6) { + NDArray boxes = NDArrayFactory::create( + 'c', {3, 4}, + {0.8115f, 0.4121f, 0.0771f, 0.4863f, 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}); + NDArray scales = NDArrayFactory::create( + 'c', {3}, {0.0029f, 0.8135f, 0.4873f}); // 3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {2}, {1, 2}); + NDArray maxSize = NDArrayFactory::create(2); + NDArray threshold = NDArrayFactory::create(0.5f); + NDArray scoreThreshold = + NDArrayFactory::create(-DataTypeUtils::infOrMax()); + sd::ops::non_max_suppression_v3 op; + auto results = op.evaluate( + {&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto result = results.at(0); + // result.printBuffer("NonMaxSuppression OUtput6"); + // result.printShapeInfo("Ouput6 shape is"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - ASSERT_EQ(Status::OK(), results.status()); +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_06) { + NDArray boxes = NDArrayFactory::create( + 'c', {3, 4}, + {0.8115f, 0.4121f, 0.0771f, 0.4863f, 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}); + NDArray scales = NDArrayFactory::create( + 'c', {3}, {0.0029f, 0.8135f, 0.4873f}); // 3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {2}, {1, 2}); + NDArray maxSize = NDArrayFactory::create(2); + NDArray threshold = NDArrayFactory::create(0.5f); + NDArray scoreThreshold = + NDArrayFactory::create(-DataTypeUtils::infOrMax()); + sd::ops::non_max_suppression_v3 op; + auto results = op.evaluate( + {&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto result = results.at(0); + // result.printBuffer("NonMaxSuppression OUtput06"); + // result.printShapeInfo("Ouput06 shape is"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - auto result = results.at(0); -// result.printBuffer("NonMaxSuppression OUtput7"); -// result.printShapeInfo("Ouput6 shape is"); - ASSERT_TRUE(result.isEmpty()); +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_7) { + NDArray boxes = NDArrayFactory::create( + 'c', {3, 4}, + {0.7788f, 0.8012f, 0.7244f, 0.2329f, 0.7271f, 0.1804f, 0.5056f, 0.8929f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f}); + NDArray scales = NDArrayFactory::create( + 'c', {3}, {0.7717f, 0.9281f, 0.9846f}); // 3, 0, 1, 2, 4, 5 + NDArray maxSize = NDArrayFactory::create(0); + NDArray threshold = NDArrayFactory::create(0.5f); + NDArray scoreThreshold = NDArrayFactory::create(0.5f); + sd::ops::non_max_suppression_v3 op; + auto results = op.evaluate( + {&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto result = results.at(0); + // result.printBuffer("NonMaxSuppression OUtput7"); + // result.printShapeInfo("Ouput6 shape is"); + ASSERT_TRUE(result.isEmpty()); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) { + NDArray boxes = NDArrayFactory::create( + 'c', {4, 4}, {0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, 0, 10, 1, 11}); + NDArray scores = + NDArrayFactory::create('c', {4}, {0.9, .75, .6, .95}); // 3 + NDArray max_num = NDArrayFactory::create(3); + NDArray expected = NDArrayFactory::create('c', + { + 1, + }, + {3}); - NDArray boxes = NDArrayFactory::create('c', {4,4}, { - 0, 0, 1, 1, - 0, 0.1, 1, 1.1, - 0, -0.1, 1, 0.9, - 0, 10, 1, 11}); - NDArray scores = NDArrayFactory::create('c', {4}, {0.9, .75, .6, .95}); //3 - NDArray max_num = NDArrayFactory::create(3); - NDArray expected = NDArrayFactory::create('c', {1,}, {3}); - - sd::ops::non_max_suppression_overlaps op; - auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); + sd::ops::non_max_suppression_overlaps op; + auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); -// result.printBuffer("NonMaxSuppressionOverlap1 Output"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printBuffer("NonMaxSuppressionOverlap1 Output"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_2) { + NDArray boxes = NDArrayFactory::create( + 'c', {4, 4}, {0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, 0, 10, 1, 11}); + NDArray scores = + NDArrayFactory::create('c', {4}, {0.9, .95, .6, .75}); // 3 + NDArray max_num = NDArrayFactory::create(3); + NDArray expected = NDArrayFactory::create('c', + { + 3, + }, + {1, 1, 1}); - NDArray boxes = NDArrayFactory::create('c', {4,4}, { - 0, 0, 1, 1, - 0, 0.1, 1, 1.1, - 0, -0.1, 1, 0.9, - 0, 10, 1, 11}); - NDArray scores = NDArrayFactory::create('c', {4}, {0.9, .95, .6, .75}); //3 - NDArray max_num = NDArrayFactory::create(3); - NDArray expected = NDArrayFactory::create('c', {3,}, {1,1,1}); + sd::ops::non_max_suppression_overlaps op; + auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); - sd::ops::non_max_suppression_overlaps op; - auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printBuffer("NonMaxSuppressionOverlap Output"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printBuffer("NonMaxSuppressionOverlap Output"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_3) { + NDArray boxes = NDArrayFactory::create( + 'c', {4, 4}, {0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, 0, 10, 1, 11}); + NDArray scores = + NDArrayFactory::create('c', {4}, {0.5, .95, -.6, .75}); // 3 + NDArray max_num = NDArrayFactory::create(5); + NDArray expected = NDArrayFactory::create('c', + { + 5, + }, + {1, 1, 1, 1, 1}); - NDArray boxes = NDArrayFactory::create('c', {4,4}, { - 0, 0, 1, 1, - 0, 0.1, 1, 1.1, - 0, -0.1, 1, 0.9, - 0, 10, 1, 11}); - NDArray scores = NDArrayFactory::create('c', {4}, {0.5, .95, -.6, .75}); //3 - NDArray max_num = NDArrayFactory::create(5); - NDArray expected = NDArrayFactory::create('c', {5,}, {1,1,1,1,1}); - - sd::ops::non_max_suppression_overlaps op; - auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); + sd::ops::non_max_suppression_overlaps op; + auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); -// result.printBuffer("NonMaxSuppressionOverlap Output"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printBuffer("NonMaxSuppressionOverlap Output"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) { - int axis = 0; - NDArray images = NDArrayFactory::create('c', {1,2,2,1}, {1,2,3,4}); - NDArray boxes = NDArrayFactory::create('c', {1,4}, {0,0,1,1}); - NDArray boxI = NDArrayFactory::create('c', {1}, {axis}); - NDArray cropSize = NDArrayFactory::create({1, 1}); + int axis = 0; + NDArray images = + NDArrayFactory::create('c', {1, 2, 2, 1}, {1, 2, 3, 4}); + NDArray boxes = NDArrayFactory::create('c', {1, 4}, {0, 0, 1, 1}); + NDArray boxI = NDArrayFactory::create('c', {1}, {axis}); + NDArray cropSize = NDArrayFactory::create({1, 1}); - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected = NDArrayFactory::create('c', {1,1,1,1}, {2.5f}); + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected = NDArrayFactory::create('c', {1, 1, 1, 1}, {2.5f}); - sd::ops::crop_and_resize op; - auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {}); + sd::ops::crop_and_resize op; + auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); -// result.printIndexedBuffer("Cropped and Resized"); + auto result = results.at(0); + // result.printIndexedBuffer("Cropped and Resized"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) { - int axis = 0; - NDArray images = NDArrayFactory::create('c', {1,2,2,1}, {1.f, 2.f, 3.f, 4.f}); - NDArray boxes = NDArrayFactory::create('c', {1,4}, {0.f, 0.f, 1.f, 1.f}); - NDArray boxI = NDArrayFactory::create('c', {1}, {axis}); - NDArray cropSize = NDArrayFactory::create({1, 1}); + int axis = 0; + NDArray images = + NDArrayFactory::create('c', {1, 2, 2, 1}, {1.f, 2.f, 3.f, 4.f}); + NDArray boxes = + NDArrayFactory::create('c', {1, 4}, {0.f, 0.f, 1.f, 1.f}); + NDArray boxI = NDArrayFactory::create('c', {1}, {axis}); + NDArray cropSize = NDArrayFactory::create({1, 1}); - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected = NDArrayFactory::create('c', {1,1,1,1}, {4.f}); + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected = NDArrayFactory::create('c', {1, 1, 1, 1}, {4.f}); - sd::ops::crop_and_resize op; - auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1}); + sd::ops::crop_and_resize op; + auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) { + NDArray images('c', {1, 2, 2, 1}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray boxes('c', {1, 4}, {0, 0, 1, 1}, sd::DataType::FLOAT32); + NDArray boxI('c', {1}, std::vector{0}, sd::DataType::INT64); + NDArray cropSize = NDArrayFactory::create({3, 3}); - NDArray images ('c', {1,2,2,1}, {1,2,3,4}, sd::DataType::FLOAT32); - NDArray boxes('c', {1,4}, {0,0,1,1}, sd::DataType::FLOAT32); - NDArray boxI('c', {1}, std::vector{0}, sd::DataType::INT64); - NDArray cropSize = NDArrayFactory::create({3, 3}); - - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected('c', {1,3,3,1}, {1.f, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, sd::DataType::FLOAT32); + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected('c', {1, 3, 3, 1}, + {1.f, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, + sd::DataType::FLOAT32); - sd::ops::crop_and_resize op; - auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {0}); + sd::ops::crop_and_resize op; + auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) { + NDArray images('c', {1, 2, 2, 1}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray boxes('c', {1, 4}, {0, 0, 1, 1}, sd::DataType::FLOAT32); + NDArray boxI('c', {1}, std::vector({0.}), sd::DataType::INT32); + NDArray cropSize = NDArrayFactory::create({3, 3}); - NDArray images('c', {1,2,2,1}, {1, 2, 3, 4}, sd::DataType::FLOAT32); - NDArray boxes('c', {1,4}, {0,0,1,1}, sd::DataType::FLOAT32); - NDArray boxI('c', {1}, std::vector({0.}), sd::DataType::INT32); - NDArray cropSize = NDArrayFactory::create({3, 3}); + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected('c', {1, 3, 3, 1}, + {1.f, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, + sd::DataType::FLOAT32); - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected('c', {1,3,3,1}, {1.f, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, sd::DataType::FLOAT32); + sd::ops::crop_and_resize op; + auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1}); - sd::ops::crop_and_resize op; - auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result.printIndexedBuffer("Cropped and Resized"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printIndexedBuffer("Cropped and Resized"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) { + NDArray images('c', {1, 100, 100, 3}, sd::DataType::FLOAT32); + NDArray boxes('c', {1, 4}, {0, 0, 1, 1}, sd::DataType::FLOAT32); + NDArray boxI('c', {2}, {1, 1}, sd::DataType::INT32); + NDArray cropSize = NDArrayFactory::create({10, 10}); - NDArray images('c', {1, 100, 100, 3}, sd::DataType::FLOAT32); - NDArray boxes('c', {1,4}, {0,0,1,1}, sd::DataType::FLOAT32); - NDArray boxI('c', {2}, {1,1}, sd::DataType::INT32); - NDArray cropSize = NDArrayFactory::create({10, 10}); - - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected('c', {1, 10, 10,3}, sd::DataType::FLOAT32); + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected('c', {1, 10, 10, 3}, sd::DataType::FLOAT32); - sd::ops::crop_and_resize op; - auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1}); + sd::ops::crop_and_resize op; + auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - //ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + // ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) { - NDArray images = NDArrayFactory::create('c', {2,4,5,3}); - NDArray boxes = NDArrayFactory::create('c', {2, 2, 4}, { - 0.f , 0.f , 1.f , 1.f , 0.1f, 0.2f, 0.9f, 0.8f, - 0.3f, 0.3f, 0.7f, 0.7f, 0.4f, 0.4f, 0.6f, 0.6f - }); - - NDArray colors = NDArrayFactory::create('c', {2, 3}, {201.f, 202.f, 203.f, 127.f, 128.f, 129.f}); - - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected = NDArrayFactory::create('c', {2,4,5,3}, { - 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, - 127.f, 128.f, 129.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, - 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, - 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, - - 61.f, 62.f, 63.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 70.f, 71.f, 72.f, 73.f, 74.f, 75.f, - 76.f, 77.f, 78.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, - 91.f, 92.f, 93.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 100.f, 101.f, 102.f, 103.f, 104.f, 105.f, - 106.f, 107.f, 108.f, 109.f, 110.f, 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f - }); - images.linspace(1.); - sd::ops::draw_bounding_boxes op; - auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - result.syncToHost(); -// result.printBuffer("Bounded boxes"); -// expected.printBuffer("Bounded expec"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray images = NDArrayFactory::create('c', {2, 4, 5, 3}); + NDArray boxes = NDArrayFactory::create( + 'c', {2, 2, 4}, + {0.f, 0.f, 1.f, 1.f, 0.1f, 0.2f, 0.9f, 0.8f, 0.3f, 0.3f, 0.7f, 0.7f, 0.4f, + 0.4f, 0.6f, 0.6f}); + + NDArray colors = NDArrayFactory::create( + 'c', {2, 3}, {201.f, 202.f, 203.f, 127.f, 128.f, 129.f}); + + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected = NDArrayFactory::create( + 'c', {2, 4, 5, 3}, + {127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, + 128.f, 129.f, 201.f, 202.f, 203.f, 127.f, 128.f, 129.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, + 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, + 128.f, 129.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, + 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, + + 61.f, 62.f, 63.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 70.f, + 71.f, 72.f, 73.f, 74.f, 75.f, 76.f, 77.f, 78.f, 127.f, 128.f, + 129.f, 127.f, 128.f, 129.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, + 91.f, 92.f, 93.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 100.f, + 101.f, 102.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 109.f, 110.f, + 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f}); + images.linspace(1.); + sd::ops::draw_bounding_boxes op; + auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + result.syncToHost(); + // result.printBuffer("Bounded boxes"); + // expected.printBuffer("Bounded expec"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) { - NDArray images = NDArrayFactory::create('c', {1,9,9,1}); - NDArray boxes = NDArrayFactory::create('c', {1, 1, 4}, {0.2f, 0.2f, 0.7f, 0.7f}); - NDArray colors = NDArrayFactory::create('c', {1, 1}, {0.95f}); - - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected = NDArrayFactory::create('c', {1,9,9,1}, { - 1.1f , 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f , 8.1f , 9.1f , - 10.1f , 0.95f, 0.95f, 0.95f, 0.95f, 0.95f, 16.1f , 17.1f , 18.1f , - 19.1f , 0.95f, 21.1f, 22.1f, 23.1f, 0.95f, 25.1f , 26.1f , 27.1f , - 28.1f , 0.95f, 30.1f, 31.1f, 32.1f, 0.95f, 34.1f , 35.1f , 36.1f , - 37.1f , 0.95f, 39.1f, 40.1f, 41.1f, 0.95f, 43.1f , 44.1f , 45.1f , - 46.1f , 0.95f, 0.95f, 0.95f, 0.95f, 0.95f, 52.1f , 53.1f , 54.1f , - 55.1f , 56.1f, 57.1f, 58.1f, 59.1f , 60.1f, 61.1f , 62.1f , 63.1f , - 64.1f , 65.1f, 66.1f, 67.1f, 68.1f , 69.1f, 70.1f , 71.1f , 72.1f , - 73.1f , 74.1f, 75.1f, 76.1f, 77.1f , 78.1f, 79.1f , 80.1f , 81.1f }); - images.linspace(1.1); - sd::ops::draw_bounding_boxes op; - auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.syncToHost(); -// result.printBuffer("Bounded boxes 2"); -// expected.printBuffer("Bounded expec 2"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray images = NDArrayFactory::create('c', {1, 9, 9, 1}); + NDArray boxes = + NDArrayFactory::create('c', {1, 1, 4}, {0.2f, 0.2f, 0.7f, 0.7f}); + NDArray colors = NDArrayFactory::create('c', {1, 1}, {0.95f}); + + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 9, 9, 1}, + {1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f, 8.1f, 9.1f, + 10.1f, 0.95f, 0.95f, 0.95f, 0.95f, 0.95f, 16.1f, 17.1f, 18.1f, + 19.1f, 0.95f, 21.1f, 22.1f, 23.1f, 0.95f, 25.1f, 26.1f, 27.1f, + 28.1f, 0.95f, 30.1f, 31.1f, 32.1f, 0.95f, 34.1f, 35.1f, 36.1f, + 37.1f, 0.95f, 39.1f, 40.1f, 41.1f, 0.95f, 43.1f, 44.1f, 45.1f, + 46.1f, 0.95f, 0.95f, 0.95f, 0.95f, 0.95f, 52.1f, 53.1f, 54.1f, + 55.1f, 56.1f, 57.1f, 58.1f, 59.1f, 60.1f, 61.1f, 62.1f, 63.1f, + 64.1f, 65.1f, 66.1f, 67.1f, 68.1f, 69.1f, 70.1f, 71.1f, 72.1f, + 73.1f, 74.1f, 75.1f, 76.1f, 77.1f, 78.1f, 79.1f, 80.1f, 81.1f}); + images.linspace(1.1); + sd::ops::draw_bounding_boxes op; + auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.syncToHost(); + // result.printBuffer("Bounded boxes 2"); + // expected.printBuffer("Bounded expec 2"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) { - NDArray images = NDArrayFactory::create('c', {2,5,5,1}, {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, - 0.6591f, 0.5555f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, - 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f, - 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, - 0.2585f, 0.4189f, 0.7028f, 0.7679f, 0.5373f, 0.7234f, - 0.2690f, 0.0062f, 0.0327f, 0.0644f, 0.8428f, 0.7494f, - 0.0755f, 0.6245f, 0.3491f, 0.5793f, 0.5730f, 0.1822f, - 0.6420f, 0.9143f}); - - NDArray boxes = NDArrayFactory::create('c', {2, 2, 4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f, - 0.6433f, 0.6041f, 0.6501f, 0.7612f, - 0.7605f, 0.3948f, 0.9493f, 0.8600f, - 0.7876f, 0.8945f, 0.4638f, 0.7157f}); - NDArray colors = NDArrayFactory::create('c', {1, 2}, {0.9441f, 0.5957f}); - - //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); -// NDArray expected = NDArrayFactory::create('c', {2,5,5,1}, { -// 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, -// 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.9441f, -// 0.9441f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, -// 0.1800f, 0.6750f, 0.2246f, 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, -// 0.2585f, 0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f, -// 0.8428f, 0.9441f,0.9441f,0.9441f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f }); - NDArray expected = NDArrayFactory::create('c', {2,5,5,1}, { - 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.9441f, 0.9441f, 0.1596f, - 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, - 0.6765f, 0.18f , 0.675f , 0.2246f, 0.0509f, - - 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, - 0.2585f, 0.4189f, 0.7028f, 0.7679f, 0.5373f, - 0.7234f, 0.269f , 0.0062f, 0.0327f, 0.0644f, - 0.8428f, 0.9441f, 0.9441f, 0.9441f, 0.3491f, - 0.5793f, 0.573f , 0.1822f, 0.642f , 0.9143f}); - sd::ops::draw_bounding_boxes op; - auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printBuffer("Boxes3 output"); -// expected.printBuffer("Boxes3 expect"); - -// result.syncToHost(); -// result.printBuffer("Bounded boxes 2"); -// expected.printBuffer("Bounded expec 2"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray images = NDArrayFactory::create( + 'c', {2, 5, 5, 1}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f, 0.3087f, + 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f, + 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, + 0.7028f, 0.7679f, 0.5373f, 0.7234f, 0.2690f, 0.0062f, 0.0327f, 0.0644f, + 0.8428f, 0.7494f, 0.0755f, 0.6245f, 0.3491f, 0.5793f, 0.5730f, 0.1822f, + 0.6420f, 0.9143f}); + + NDArray boxes = NDArrayFactory::create( + 'c', {2, 2, 4}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f, 0.3948f, 0.9493f, 0.8600f, 0.7876f, 0.8945f, 0.4638f, 0.7157f}); + NDArray colors = + NDArrayFactory::create('c', {1, 2}, {0.9441f, 0.5957f}); + + // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + // NDArray expected = NDArrayFactory::create('c', {2,5,5,1}, { + // 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, + // 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, + // 0.9441f, 0.9441f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, 0.9939f, + // 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f, 0.0509f, 0.4601f, + // 0.8284f, 0.2354f, 0.9752f, 0.8361f, 0.2585f, + // 0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f, + // 0.8428f, + // 0.9441f,0.9441f,0.9441f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f + // }); + NDArray expected = NDArrayFactory::create( + 'c', {2, 5, 5, 1}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, + 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.9441f, 0.9441f, + 0.1596f, 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, + 0.18f, 0.675f, 0.2246f, 0.0509f, + + 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, + 0.7028f, 0.7679f, 0.5373f, 0.7234f, 0.269f, 0.0062f, 0.0327f, + 0.0644f, 0.8428f, 0.9441f, 0.9441f, 0.9441f, 0.3491f, 0.5793f, + 0.573f, 0.1822f, 0.642f, 0.9143f}); + sd::ops::draw_bounding_boxes op; + auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printBuffer("Boxes3 output"); + // expected.printBuffer("Boxes3 expect"); + + // result.syncToHost(); + // result.printBuffer("Bounded boxes 2"); + // expected.printBuffer("Bounded expec 2"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) { + NDArray x('c', {2, 3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, + sd::DataType::FLOAT32); + NDArray exp('c', {2, 3}, {-63.75f, -63.75f, -63.75f, -63.5f, 0.f, 0.f}, + sd::DataType::FLOAT32); + NDArray min('c', {}, std::vector{-63.65f}, sd::DataType::FLOAT32); + NDArray max('c', {}, std::vector{0.1f}, sd::DataType::FLOAT32); - NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, sd::DataType::FLOAT32); - NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.5f, 0.f, 0.f}, sd::DataType::FLOAT32); - NDArray min('c', {}, std::vector{-63.65f}, sd::DataType::FLOAT32); - NDArray max('c', {}, std::vector{0.1f}, sd::DataType::FLOAT32); - - sd::ops::fake_quant_with_min_max_vars op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); + sd::ops::fake_quant_with_min_max_vars op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); -// result.printBuffer("Quantized"); -// exp.printBuffer("Expected"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + // result.printBuffer("Quantized"); + // exp.printBuffer("Expected"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) { + NDArray x = NDArrayFactory::create( + 'c', {2, 3}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1}); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3}, {-63.75, -63.75, -63.5, -63.5, 0., 0.}); + NDArray min = NDArrayFactory::create(-63.65); + NDArray max = NDArrayFactory::create(0.1); - NDArray x = NDArrayFactory::create('c', {2,3}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1}); - NDArray exp = NDArrayFactory::create('c', {2,3}, {-63.75, -63.75, -63.5 , -63.5 , 0. , 0. }); - NDArray min = NDArrayFactory::create(-63.65); - NDArray max = NDArrayFactory::create(0.1); + sd::ops::fake_quant_with_min_max_vars op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); - sd::ops::fake_quant_with_min_max_vars op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result.printIndexedBuffer("Quantized2"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + // result.printIndexedBuffer("Quantized2"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) { + NDArray x = NDArrayFactory::create( + 'c', {1, 2, 3, 1}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1}); + NDArray exp = NDArrayFactory::create( + 'c', {1, 2, 3, 1}, {-63.75, -63.75, -63.5, -63.5, 0., 0.}); + NDArray min = NDArrayFactory::create('c', {1}, {-63.65}); + NDArray max = NDArrayFactory::create('c', {1}, {0.1}); - NDArray x = NDArrayFactory::create('c', {1,2,3,1}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1}); - NDArray exp = NDArrayFactory::create('c', {1,2,3,1}, {-63.75, -63.75, -63.5 , -63.5 , 0. , 0. }); - NDArray min = NDArrayFactory::create('c', {1},{-63.65}); - NDArray max = NDArrayFactory::create('c', {1}, {0.1}); - - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result.printIndexedBuffer("Quantized2"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + // result.printIndexedBuffer("Quantized2"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03) { - NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); - NDArray exp = NDArrayFactory::create('c', {3,5}, { - 0.777002f, 0.596913f, 0.72314f, 0.231040f, 0.509824f, - 0.179308f, 0.505282f, 0.86846f, 0.349958f, 0.509824f, - 0.087355f, 0.596913f, 0.65740f, 0.349958f, 0.159745f}); - NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); - NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); - - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printIndexedBuffer("Quantized03"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create( + 'c', {3, 5}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create( + 'c', {3, 5}, + {0.777002f, 0.596913f, 0.72314f, 0.231040f, 0.509824f, 0.179308f, + 0.505282f, 0.86846f, 0.349958f, 0.509824f, 0.087355f, 0.596913f, + 0.65740f, 0.349958f, 0.159745f}); + NDArray min = NDArrayFactory::create( + {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create( + {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printIndexedBuffer("Quantized03"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_1) { - NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); - NDArray exp = NDArrayFactory::create('c', {3,5}, { - 0.780061f, 0.596635f, 0.725987f, 0.231950f, 0.508419f, - 0.180014f, 0.504643f, 0.868406f, 0.351335f, 0.508419f, - 0.087699f, 0.596635f, 0.659988f, 0.351335f, 0.160374f}); - NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); - NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); - - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {8}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printIndexedBuffer("Quantized03_1"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create( + 'c', {3, 5}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create( + 'c', {3, 5}, + {0.780061f, 0.596635f, 0.725987f, 0.231950f, 0.508419f, 0.180014f, + 0.504643f, 0.868406f, 0.351335f, 0.508419f, 0.087699f, 0.596635f, + 0.659988f, 0.351335f, 0.160374f}); + NDArray min = NDArrayFactory::create( + {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create( + {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {8}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printIndexedBuffer("Quantized03_1"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_2) { - NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); - NDArray exp = NDArrayFactory::create('c', {3,5}, { - 0.775297f, 0.592226f, 0.725763f, 0.237561f, 0.503245f, - 0.189097f, 0.506084f, 0.868069f, 0.349355f, 0.503245f, - 0.094548f, 0.592226f, 0.654610f, 0.349355f, 0.153769f}); - NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); - NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); - - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {6}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - //result.printIndexedBuffer("Quantized03_2"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create( + 'c', {3, 5}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create( + 'c', {3, 5}, + {0.775297f, 0.592226f, 0.725763f, 0.237561f, 0.503245f, 0.189097f, + 0.506084f, 0.868069f, 0.349355f, 0.503245f, 0.094548f, 0.592226f, + 0.654610f, 0.349355f, 0.153769f}); + NDArray min = NDArrayFactory::create( + {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create( + {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {6}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printIndexedBuffer("Quantized03_2"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_3) { - NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); - NDArray exp = NDArrayFactory::create('c', {3,5}, { - 0.781600f, 0.593422f, 0.728248f, 0.233790f, 0.509014f, 0.186095f, 0.508648f, 0.868295f, 0.343809f, - 0.509014f, 0.093048f, 0.593422f, 0.658224f, 0.343809f, 0.165086f}); - NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); - NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); - - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {6}, {false}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - //result->printIndexedBuffer("Quantized03_3"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create( + 'c', {3, 5}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create( + 'c', {3, 5}, + {0.781600f, 0.593422f, 0.728248f, 0.233790f, 0.509014f, 0.186095f, + 0.508648f, 0.868295f, 0.343809f, 0.509014f, 0.093048f, 0.593422f, + 0.658224f, 0.343809f, 0.165086f}); + NDArray min = NDArrayFactory::create( + {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create( + {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {6}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer("Quantized03_3"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) { #ifdef FFAST_MATH - if (1 > 0) - return; + if (1 > 0) return; #endif - NDArray x = NDArrayFactory::create('c', {2,4,5,3}); - NDArray exp = NDArrayFactory::create('c', {2,4,5,3},{ - 1.0588236f, 1.9607843f, 3.019608f, 4.0588236f, 5.098039f, 6.039216f, 7.0588236f, 8.039216f, 9.058824f, - 10.058824f, 10.980392f, 12.078432f, 13.058824f, 13.921569f, 15.09804f, 16.058825f, 17.058825f, 18.117647f, - 19.058825f, 20.f, 21.137257f, 22.058825f, 22.941177f, 23.882355f, 25.058825f, 26.078432f, 26.901962f, - 28.058825f, 29.019608f, 29.92157f, 31.058825f, 31.960785f, 32.941177f, 34.058823f, 35.09804f, 35.960785f, - 37.058823f, 38.039215f, 38.980392f, 40.058823f, 40.980392f, 42.000004f, 43.058826f, 43.92157f, 45.01961f, - 45.f, 47.058823f, 48.03922f, 45.f, 50.f, 51.058826f, 45.f, 50.f, 54.078434f, - 45.f, 50.f, 57.09804f, 45.f, 50.f, 60.11765f, 45.f, 50.f, 62.862747f, - 45.f, 50.f, 65.882355f, 45.f, 50.f, 68.90196f, 45.f, 50.f, 70.f, - 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, - 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, - 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, - 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, - 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, - 45.f, 50.f, 70.f}); - NDArray min = NDArrayFactory::create({20.f, 20.f, 20.f}); - NDArray max = NDArrayFactory::create({65.f, 70.f, 90.f}); - x.linspace(1.); - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printBuffer("Quantized per channels 4"); -// exp.printBuffer("Quantized per channest E"); -// auto diff = result - exp; -// diff.printIndexedBuffer("Difference"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create('c', {2, 4, 5, 3}); + NDArray exp = NDArrayFactory::create( + 'c', {2, 4, 5, 3}, + {1.0588236f, 1.9607843f, 3.019608f, 4.0588236f, 5.098039f, 6.039216f, + 7.0588236f, 8.039216f, 9.058824f, 10.058824f, 10.980392f, 12.078432f, + 13.058824f, 13.921569f, 15.09804f, 16.058825f, 17.058825f, 18.117647f, + 19.058825f, 20.f, 21.137257f, 22.058825f, 22.941177f, 23.882355f, + 25.058825f, 26.078432f, 26.901962f, 28.058825f, 29.019608f, 29.92157f, + 31.058825f, 31.960785f, 32.941177f, 34.058823f, 35.09804f, 35.960785f, + 37.058823f, 38.039215f, 38.980392f, 40.058823f, 40.980392f, 42.000004f, + 43.058826f, 43.92157f, 45.01961f, 45.f, 47.058823f, 48.03922f, + 45.f, 50.f, 51.058826f, 45.f, 50.f, 54.078434f, + 45.f, 50.f, 57.09804f, 45.f, 50.f, 60.11765f, + 45.f, 50.f, 62.862747f, 45.f, 50.f, 65.882355f, + 45.f, 50.f, 68.90196f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f}); + NDArray min = NDArrayFactory::create({20.f, 20.f, 20.f}); + NDArray max = NDArrayFactory::create({65.f, 70.f, 90.f}); + x.linspace(1.); + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printBuffer("Quantized per channels 4"); + // exp.printBuffer("Quantized per channest E"); + // auto diff = result - exp; + // diff.printIndexedBuffer("Difference"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) { - NDArray x = NDArrayFactory::create('c', {2, 3, 5, 4}); - NDArray exp = NDArrayFactory::create('c', {2, 3, 5, 4},{ - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -19.92157f, -18.980392f, -18.039217f, -16.941177f, - -16.f, -15.058824f, -13.960785f, -13.0196085f, - -11.92157f, -10.980392f, -10.039217f, -8.941177f, - -8.000001f, -7.0588236f, -5.960785f, -5.0196085f, - -3.9215698f, -2.9803925f, -2.039217f, -0.94117737f, - 0.f, 0.94117737f, 2.039215f, 2.9803925f, - 4.07843f, 5.0196075f, 5.960783f, 7.0588226f, - 8.f, 8.941177f, 10.039215f, 10.980392f, - 12.07843f, 13.019608f, 13.960783f, 15.058823f, - 16.f, 16.941177f, 18.039217f, 18.980392f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f, - 20.07843f, 21.019608f, 21.960783f, 23.058823f - }); - NDArray min = NDArrayFactory::create({-20.f, -19.f, -18.f, -17.f}); - NDArray max = NDArrayFactory::create({20.f, 21.f, 22.f, 23.f}); - x.linspace(-60.); - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printBuffer("Quantized per channels 5"); -// exp.printBuffer("Quantized per channest E"); -// auto diff = result - exp; -// diff.printIndexedBuffer("Difference"); - - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create('c', {2, 3, 5, 4}); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 5, 4}, + {-19.92157f, -18.980392f, -18.039217f, -16.941177f, -19.92157f, + -18.980392f, -18.039217f, -16.941177f, -19.92157f, -18.980392f, + -18.039217f, -16.941177f, -19.92157f, -18.980392f, -18.039217f, + -16.941177f, -19.92157f, -18.980392f, -18.039217f, -16.941177f, + -19.92157f, -18.980392f, -18.039217f, -16.941177f, -19.92157f, + -18.980392f, -18.039217f, -16.941177f, -19.92157f, -18.980392f, + -18.039217f, -16.941177f, -19.92157f, -18.980392f, -18.039217f, + -16.941177f, -19.92157f, -18.980392f, -18.039217f, -16.941177f, + -19.92157f, -18.980392f, -18.039217f, -16.941177f, -16.f, + -15.058824f, -13.960785f, -13.0196085f, -11.92157f, -10.980392f, + -10.039217f, -8.941177f, -8.000001f, -7.0588236f, -5.960785f, + -5.0196085f, -3.9215698f, -2.9803925f, -2.039217f, -0.94117737f, + 0.f, 0.94117737f, 2.039215f, 2.9803925f, 4.07843f, + 5.0196075f, 5.960783f, 7.0588226f, 8.f, 8.941177f, + 10.039215f, 10.980392f, 12.07843f, 13.019608f, 13.960783f, + 15.058823f, 16.f, 16.941177f, 18.039217f, 18.980392f, + 20.07843f, 21.019608f, 21.960783f, 23.058823f, 20.07843f, + 21.019608f, 21.960783f, 23.058823f, 20.07843f, 21.019608f, + 21.960783f, 23.058823f, 20.07843f, 21.019608f, 21.960783f, + 23.058823f, 20.07843f, 21.019608f, 21.960783f, 23.058823f, + 20.07843f, 21.019608f, 21.960783f, 23.058823f, 20.07843f, + 21.019608f, 21.960783f, 23.058823f, 20.07843f, 21.019608f, + 21.960783f, 23.058823f, 20.07843f, 21.019608f, 21.960783f, + 23.058823f, 20.07843f, 21.019608f, 21.960783f, 23.058823f}); + NDArray min = NDArrayFactory::create({-20.f, -19.f, -18.f, -17.f}); + NDArray max = NDArrayFactory::create({20.f, 21.f, 22.f, 23.f}); + x.linspace(-60.); + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printBuffer("Quantized per channels 5"); + // exp.printBuffer("Quantized per channest E"); + // auto diff = result - exp; + // diff.printIndexedBuffer("Difference"); + + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) { - NDArray x = NDArrayFactory::create('c', {3, 5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); -// NDArray exp = NDArrayFactory::create('c', {3, 5},{ -// 0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f, -// 0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f, -// 0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f -// }); - - NDArray exp = NDArrayFactory::create('c', {3,5}, { - 0.77700233f, 0.596913f, 0.72314f, 0.23104f, 0.50982356f, - 0.17930824f, 0.50528157f, 0.86846f, 0.34995764f, 0.50982356f, - 0.08735529f, 0.596913f, 0.6574f, 0.34995764f, 0.15974471f}); - NDArray min = NDArrayFactory::create('c', {5}, {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); - NDArray max = NDArrayFactory::create('c', {5}, {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); - // x.linspace(-60.); - sd::ops::fake_quant_with_min_max_vars_per_channel op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printBuffer("Quantized per channels 5"); -// exp.printBuffer("Quantized per channest E"); -// auto diff = result - exp; -// diff.printIndexedBuffer("Difference"); - - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create( + 'c', {3, 5}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + // NDArray exp = NDArrayFactory::create('c', {3, 5},{ + // 0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f, + // 0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f, + // 0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f + // }); + + NDArray exp = NDArrayFactory::create( + 'c', {3, 5}, + {0.77700233f, 0.596913f, 0.72314f, 0.23104f, 0.50982356f, 0.17930824f, + 0.50528157f, 0.86846f, 0.34995764f, 0.50982356f, 0.08735529f, 0.596913f, + 0.6574f, 0.34995764f, 0.15974471f}); + NDArray min = NDArrayFactory::create( + 'c', {5}, {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create( + 'c', {5}, {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + // x.linspace(-60.); + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printBuffer("Quantized per channels 5"); + // exp.printBuffer("Quantized per channest E"); + // auto diff = result - exp; + // diff.printIndexedBuffer("Difference"); + + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) { - - NDArray x = NDArrayFactory::create('c', {100}); - NDArray exp = NDArrayFactory::create('c', {100}, { - 0.f, 0.01176471f, 0.01960784f, 0.03137255f, 0.03921569f, - 0.0509804f, 0.05882353f, 0.07058824f, 0.07843138f, 0.09019608f, - 0.09803922f, 0.10980393f, 0.12156864f, 0.12941177f, 0.14117648f, - 0.14901961f, 0.16078432f, 0.16862746f, 0.18039216f, 0.18823531f, - 0.20000002f, 0.21176472f, 0.21960786f, 0.23137257f, 0.2392157f, - 0.2509804f, 0.25882354f, 0.27058825f, 0.2784314f, 0.2901961f, - 0.3019608f, 0.30980393f, 0.32156864f, 0.32941177f, 0.34117648f, - 0.34901962f, 0.36078432f, 0.36862746f, 0.3803922f, 0.38823533f, - 0.40000004f, 0.41176474f, 0.41960788f, 0.43137258f, 0.43921572f, - 0.45098042f, 0.45882356f, 0.47058827f, 0.4784314f, 0.4901961f, - 0.49803925f, 0.50980395f, 0.52156866f, 0.5294118f, 0.5411765f, - 0.54901963f, 0.56078434f, 0.5686275f, 0.5803922f, 0.5882353f, - 0.6f, 0.6117647f, 0.61960787f, 0.6313726f, 0.6392157f, - 0.6509804f, 0.65882355f, 0.67058825f, 0.6784314f, 0.6901961f, - 0.69803923f, 0.70980394f, 0.72156864f, 0.7294118f, 0.7411765f, - 0.7490196f, 0.7607844f, 0.7686275f, 0.7803922f, 0.78823537f, - 0.8000001f, 0.8117648f, 0.8196079f, 0.8313726f, 0.83921576f, - 0.85098046f, 0.8588236f, 0.8705883f, 0.87843144f, 0.89019614f, - 0.8980393f, 0.909804f, 0.9215687f, 0.9294118f, 0.94117653f, - 0.9490197f, 0.9607844f, 0.9686275f, 0.9803922f, 0.98823535f - }); - NDArray min = NDArrayFactory::create('c', {1},{0.0f}); - NDArray max = NDArrayFactory::create('c', {1}, {1.f}); - x.linspace(0., 0.01); - sd::ops::fake_quant_with_min_max_vars op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printBuffer("Quantized7"); -// exp.printBuffer("Expected 7"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create('c', {100}); + NDArray exp = NDArrayFactory::create( + 'c', {100}, + {0.f, 0.01176471f, 0.01960784f, 0.03137255f, 0.03921569f, + 0.0509804f, 0.05882353f, 0.07058824f, 0.07843138f, 0.09019608f, + 0.09803922f, 0.10980393f, 0.12156864f, 0.12941177f, 0.14117648f, + 0.14901961f, 0.16078432f, 0.16862746f, 0.18039216f, 0.18823531f, + 0.20000002f, 0.21176472f, 0.21960786f, 0.23137257f, 0.2392157f, + 0.2509804f, 0.25882354f, 0.27058825f, 0.2784314f, 0.2901961f, + 0.3019608f, 0.30980393f, 0.32156864f, 0.32941177f, 0.34117648f, + 0.34901962f, 0.36078432f, 0.36862746f, 0.3803922f, 0.38823533f, + 0.40000004f, 0.41176474f, 0.41960788f, 0.43137258f, 0.43921572f, + 0.45098042f, 0.45882356f, 0.47058827f, 0.4784314f, 0.4901961f, + 0.49803925f, 0.50980395f, 0.52156866f, 0.5294118f, 0.5411765f, + 0.54901963f, 0.56078434f, 0.5686275f, 0.5803922f, 0.5882353f, + 0.6f, 0.6117647f, 0.61960787f, 0.6313726f, 0.6392157f, + 0.6509804f, 0.65882355f, 0.67058825f, 0.6784314f, 0.6901961f, + 0.69803923f, 0.70980394f, 0.72156864f, 0.7294118f, 0.7411765f, + 0.7490196f, 0.7607844f, 0.7686275f, 0.7803922f, 0.78823537f, + 0.8000001f, 0.8117648f, 0.8196079f, 0.8313726f, 0.83921576f, + 0.85098046f, 0.8588236f, 0.8705883f, 0.87843144f, 0.89019614f, + 0.8980393f, 0.909804f, 0.9215687f, 0.9294118f, 0.94117653f, + 0.9490197f, 0.9607844f, 0.9686275f, 0.9803922f, 0.98823535f}); + NDArray min = NDArrayFactory::create('c', {1}, {0.0f}); + NDArray max = NDArrayFactory::create('c', {1}, {1.f}); + x.linspace(0., 0.01); + sd::ops::fake_quant_with_min_max_vars op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printBuffer("Quantized7"); + // exp.printBuffer("Expected 7"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) { - - NDArray x = NDArrayFactory::create('c', {10}); - NDArray exp = NDArrayFactory::create('c', {10}, { - 0.f, 0.09803922f, 0.20000002f, 0.3019608f, 0.40000004f, 0.49803925f, - 0.6f, 0.69803923f, 0.8000001f, 0.8980393f - }); - NDArray min = NDArrayFactory::create('c', {1},{0.0f}); - NDArray max = NDArrayFactory::create('c', {1}, {1.f}); - x.linspace(0., 0.1); - sd::ops::fake_quant_with_min_max_vars op; - auto results = op.evaluate({&x, &min, &max}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// x.printBuffer("SourInput8"); -// result.printBuffer("Quantized8"); -// exp.printBuffer("Expected 8"); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = NDArrayFactory::create('c', {10}); + NDArray exp = NDArrayFactory::create( + 'c', {10}, + {0.f, 0.09803922f, 0.20000002f, 0.3019608f, 0.40000004f, 0.49803925f, + 0.6f, 0.69803923f, 0.8000001f, 0.8980393f}); + NDArray min = NDArrayFactory::create('c', {1}, {0.0f}); + NDArray max = NDArrayFactory::create('c', {1}, {1.f}); + x.linspace(0., 0.1); + sd::ops::fake_quant_with_min_max_vars op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // x.printBuffer("SourInput8"); + // result.printBuffer("Quantized8"); + // exp.printBuffer("Expected 8"); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) { + NDArray arr1('c', {2, 2, 1}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray arr2('c', {2, 2}, {0, 1, 0, 4}, sd::DataType::INT32); - NDArray arr1('c', {2,2,1}, {1, 2, 3, 4}, sd::DataType::INT32); - NDArray arr2('c', { 2,2}, {0, 1, 0, 4}, sd::DataType::INT32); - - NDArray expd('c', {2,2,2}, {false, true, false, false, false, false, false, true}, sd::DataType::BOOL); + NDArray expd('c', {2, 2, 2}, + {false, true, false, false, false, false, false, true}, + sd::DataType::BOOL); - NDArray result('c', {2,2,2}, sd::DataType::BOOL); + NDArray result('c', {2, 2, 2}, sd::DataType::BOOL); - arr1.applyTrueBroadcast(sd::BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, broadcast::EqualTo), arr2, result, true); - // result.printIndexedBuffer(); - // expd.printIndexedBuffer(); + arr1.applyTrueBroadcast( + sd::BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, + broadcast::EqualTo), + arr2, result, true); + // result.printIndexedBuffer(); + // expd.printIndexedBuffer(); - ASSERT_TRUE(expd.isSameShape(result)); - ASSERT_TRUE(expd.equalsTo(result)); + ASSERT_TRUE(expd.isSameShape(result)); + ASSERT_TRUE(expd.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, printIndexedTest_1) { - - NDArray arr('c', {2,2,2,2}, {1, 2, 3, 4, 5, 6, 7, 8,9, 10, 11, 12, 13, 14, 15, 16}, sd::DataType::INT32); -// NDArray arr2('c', { 2,2}, {0, 1, 0, 4}, sd::DataType::INT32); - -// NDArray expd('c', {2,2,2}, {0,1,0,0, 0,0,0,1}, sd::DataType::BOOL); - -// NDArray result('c', {2,2,2}, sd::DataType::BOOL); - -// arr1.applyTrueBroadcast(sd::BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, broadcast::EqualTo), &arr2, &result, true, nullptr); - // result.printIndexedBuffer(); - // expd.printIndexedBuffer(); - -// ASSERT_TRUE(expd.isSameShape(result)); -// ASSERT_TRUE(expd.equalsTo(result)); - // arr.printIndexedBuffer("Test Print"); // output as [1, 2, 3, 4, 5, 6, 7, 8] -// -// we want output as -// [[[1 2] -// [3 4]] -// -// [[5 6] -// [7 8]]] -// - ResultSet lastDims = arr.allTensorsAlongDimension({3}); // last dim - size_t k = 0; // k from 0 to lastDims->size() - Nd4jLong rank = 4; // in this case + NDArray arr('c', {2, 2, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + sd::DataType::INT32); + // NDArray arr2('c', { 2,2}, {0, 1, 0, 4}, sd::DataType::INT32); + + // NDArray expd('c', {2,2,2}, {0,1,0,0, 0,0,0,1}, sd::DataType::BOOL); + + // NDArray result('c', {2,2,2}, sd::DataType::BOOL); + + // arr1.applyTrueBroadcast(sd::BroadcastBoolOpsTuple::custom(scalar::EqualTo, + // pairwise::EqualTo, broadcast::EqualTo), &arr2, &result, true, nullptr); + // result.printIndexedBuffer(); + // expd.printIndexedBuffer(); + + // ASSERT_TRUE(expd.isSameShape(result)); + // ASSERT_TRUE(expd.equalsTo(result)); + // arr.printIndexedBuffer("Test Print"); // output as [1, 2, 3, 4, 5, 6, 7, 8] + // + // we want output as + // [[[1 2] + // [3 4]] + // + // [[5 6] + // [7 8]]] + // + ResultSet lastDims = arr.allTensorsAlongDimension({3}); // last dim + size_t k = 0; // k from 0 to lastDims->size() + Nd4jLong rank = 4; // in this case + printf("["); + for (Nd4jLong i = 0; i < rank - 1; i++) { + for (Nd4jLong l = 0; l < i; ++l) printf("\n"); printf("["); - for (Nd4jLong i = 0; i < rank - 1; i++) { - - for (Nd4jLong l = 0; l < i; ++l) - printf("\n"); - printf("["); - for (Nd4jLong j = 0; j < arr.sizeAt(i); j++) { - // if (!i) - // printf("["); - // else - // printf(" "); - lastDims.at(k++).printBuffer(); - //if (k == arr.sizeAt(i)) - // printf("]\n"); - } - printf("]\n"); + for (Nd4jLong j = 0; j < arr.sizeAt(i); j++) { + // if (!i) + // printf("["); + // else + // printf(" "); + lastDims.at(k++).printBuffer(); + // if (k == arr.sizeAt(i)) + // printf("]\n"); } printf("]\n"); + } + printf("]\n"); } - diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index c7ce2530a26f..aa2f75cca781 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -14,3980 +14,4496 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // Created by raver on 8/4/2018. // -#include "testlayers.h" -#include #include -#include #include #include +#include +#include -using namespace sd; +#include "testlayers.h" +using namespace sd; class DeclarableOpsTests11 : public testing::Test { -public: - - DeclarableOpsTests11() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests11() { + printf("\n"); + fflush(stdout); + } }; - TEST_F(DeclarableOpsTests11, test_listdiff_1) { - auto x = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); - auto y = NDArrayFactory::create('c',{2}, {3, 1}); + auto x = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); + auto y = NDArrayFactory::create('c', {2}, {3, 1}); - sd::ops::listdiff op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - + sd::ops::listdiff op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test1) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692, - -24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911}); - NDArray dLdwExp('c', {2,3,4}, {3.21887, 4.96807, 6.10512, 6.80726, 7.15461, 7.19051, 6.93973, 6.41584, 5.62456, 4.56548, 3.2326 , 1.61444, - -0.30659, -2.55529, -5.16569, -8.18417,-11.67468,-15.72734,-20.47379,-26.11644,-32.9902 ,-41.71318,-53.64824,-73.05434}); - NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002, - -0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-12.49997, -13.04346, -13.63635, -14.28571, -14.99999, -15.78947, + -16.66666, -17.64705, -18.75, -20., -21.42857, -23.07692, + -24.99999, -27.27272, -29.99999, -33.33332, -37.49999, -42.85713, + -49.99998, -59.99998, -74.99995, -99.99992, -149.99986, -299.99911}); + NDArray dLdwExp( + 'c', {2, 3, 4}, + {3.21887, 4.96807, 6.10512, 6.80726, 7.15461, 7.19051, + 6.93973, 6.41584, 5.62456, 4.56548, 3.2326, 1.61444, + -0.30659, -2.55529, -5.16569, -8.18417, -11.67468, -15.72734, + -20.47379, -26.11644, -32.9902, -41.71318, -53.64824, -73.05434}); + NDArray dLdlExp('c', {2, 3, 4}, + {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, + 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002, + -0.04002, -0.12058, -0.20273, -0.28768, -0.37689, -0.47223, + -0.57634, -0.69315, -0.82911, -0.99621, -1.22117, -1.58903}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = + op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test2) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 1, 4}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {2,1,4}, {15.99805, 16.72406, 16.27746, 14.83754,-44.97147,-59.99582,-79.28771,-107.35497}); + NDArray dLdwExp('c', {2, 1, 4}, + {15.99805, 16.72406, 16.27746, 14.83754, -44.97147, -59.99582, + -79.28771, -107.35497}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdw = results.at(1); + auto dLdw = results.at(1); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test3) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692, - -24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911}); - NDArray dLdwExp('c', {}, std::vector{-227.77286}); - NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002, - -0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-12.49997, -13.04346, -13.63635, -14.28571, -14.99999, -15.78947, + -16.66666, -17.64705, -18.75, -20., -21.42857, -23.07692, + -24.99999, -27.27272, -29.99999, -33.33332, -37.49999, -42.85713, + -49.99998, -59.99998, -74.99995, -99.99992, -149.99986, -299.99911}); + NDArray dLdwExp('c', {}, std::vector{-227.77286}); + NDArray dLdlExp('c', {2, 3, 4}, + {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, + 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002, + -0.04002, -0.12058, -0.20273, -0.28768, -0.37689, -0.47223, + -0.57634, -0.69315, -0.82911, -0.99621, -1.22117, -1.58903}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test4) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,3,1}, {4.8876 , -46.29156, -186.36887}); + NDArray dLdwExp('c', {1, 3, 1}, {4.8876, -46.29156, -186.36887}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdw = results.at(1); - // dLdw->printIndexedBuffer(); - // dLdw->printShapeInfo(); + auto dLdw = results.at(1); + // dLdw->printIndexedBuffer(); + // dLdw->printShapeInfo(); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test5) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-1.04166,-1.08696, -1.13636, -1.19048,-1.25 ,-1.31579, -1.38889, -1.47059,-1.5625 ,-1.66667, -1.78571, -1.92308, - -2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993}); - NDArray dLdwExp('c', {2,3,4}, {1.05912, 1.20488, 1.29964, 1.35815, 1.3871 , 1.39009, 1.36919, 1.32553, 1.25959, 1.17133, 1.06026, 0.92541, - 0.76533, 0.57794, 0.3604 , 0.10886,-0.18201,-0.51973,-0.91527,-1.38549,-1.95831,-2.68522,-3.67981,-5.29698}); - NDArray dLdlExp('c', {2,3,4}, {0.13242, 0.10176, 0.08302, 0.06909, 0.05776, 0.04803, 0.03935, 0.03141, 0.02397, 0.01689, 0.01005, 0.00334, - -0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-1.04166, -1.08696, -1.13636, -1.19048, -1.25, -1.31579, + -1.38889, -1.47059, -1.5625, -1.66667, -1.78571, -1.92308, + -2.08333, -2.27273, -2.5, -2.77778, -3.125, -3.57143, + -4.16667, -5., -6.25, -8.33333, -12.49999, -24.99993}); + NDArray dLdwExp('c', {2, 3, 4}, + {1.05912, 1.20488, 1.29964, 1.35815, 1.3871, 1.39009, + 1.36919, 1.32553, 1.25959, 1.17133, 1.06026, 0.92541, + 0.76533, 0.57794, 0.3604, 0.10886, -0.18201, -0.51973, + -0.91527, -1.38549, -1.95831, -2.68522, -3.67981, -5.29698}); + NDArray dLdlExp('c', {2, 3, 4}, + {0.13242, 0.10176, 0.08302, 0.06909, 0.05776, 0.04803, + 0.03935, 0.03141, 0.02397, 0.01689, 0.01005, 0.00334, + -0.00334, -0.01005, -0.01689, -0.02397, -0.03141, -0.03935, + -0.04803, -0.05776, -0.06909, -0.08302, -0.10176, -0.13242}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test6) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + NDArray dLdwExp('c', {1, 3, 1}, {6.73432, 2.46939, -9.20372}); - NDArray dLdwExp('c', {1,3,1}, {6.73432, 2.46939,-9.20372}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdw = results.at(1); - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test7) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {}, std::vector{0.}); + NDArray dLdwExp('c', {}, std::vector{0.}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdw = results.at(1); + auto dLdw = results.at(1); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test8) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. ,-1.5 ,-1.57895, -1.66667, -1.76471,-1.875 ,-2. , -2.14286, -2.30769, - -2.5 ,-2.72727, -3. , -3.33333,-3.75 ,-4.28571, -5. , -6. ,-7.49999,-9.99999,-14.99999,-29.99991}); - NDArray dLdwExp('c', {2,3,4}, {1.56625, 1.74117, 1.85487, 1.92509, 1.95982, 1.96341, 1.93833, 1.88594, 1.80682, 1.70091, 1.56762, 1.4058 , - 1.2137 , 0.98883, 0.72779, 0.42594, 0.07689,-0.32837,-0.80302,-1.36728,-2.05466,-2.92696,-4.12046,-6.06107}); - NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0.06931, 0.05763, 0.04722, 0.03769, 0.02877, 0.02027, 0.01206, 0.004, - -0.004 ,-0.01206,-0.02027,-0.02877,-0.03769,-0.04722,-0.05763,-0.06931,-0.08291,-0.09962,-0.12212,-0.1589}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); - - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {0., 0., 0., 0., -1.5, -1.57895, + -1.66667, -1.76471, -1.875, -2., -2.14286, -2.30769, + -2.5, -2.72727, -3., -3.33333, -3.75, -4.28571, + -5., -6., -7.49999, -9.99999, -14.99999, -29.99991}); + NDArray dLdwExp('c', {2, 3, 4}, + {1.56625, 1.74117, 1.85487, 1.92509, 1.95982, 1.96341, + 1.93833, 1.88594, 1.80682, 1.70091, 1.56762, 1.4058, + 1.2137, 0.98883, 0.72779, 0.42594, 0.07689, -0.32837, + -0.80302, -1.36728, -2.05466, -2.92696, -4.12046, -6.06107}); + NDArray dLdlExp('c', {2, 3, 4}, + {0., 0., 0., 0., 0.06931, 0.05763, + 0.04722, 0.03769, 0.02877, 0.02027, 0.01206, 0.004, + -0.004, -0.01206, -0.02027, -0.02877, -0.03769, -0.04722, + -0.05763, -0.06931, -0.08291, -0.09962, -0.12212, -0.1589}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test9) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.52083,-0.54348,-0.56818, -0.59524,-0.625 ,-0.65789,-0.69444, -0.73529,-0.78125,-0.83333,-0.89286, -0.96154, - -1.04167,-1.13636,-1.25 , -1.38889,-1.5625 ,-1.78571,-2.08333, -2.5 ,-3.125 ,-4.16666,-6.24999,-12.49996}); - NDArray dLdwExp('c', {2,3,4}, {0.13412, 0.207 , 0.25438, 0.28364, 0.29811, 0.2996 , 0.28916, 0.26733, 0.23436, 0.19023, 0.13469, 0.06727, - -0.01277,-0.10647,-0.21524,-0.34101,-0.48645,-0.65531,-0.85307,-1.08819,-1.37459,-1.73805,-2.23534,-3.04393}); - NDArray dLdlExp('c', {2,3,4}, {0.06621, 0.05088, 0.04151, 0.03455, 0.02888, 0.02401, 0.01968, 0.0157 , 0.01199, 0.00845, 0.00502, 0.00167, - -0.00167,-0.00502,-0.00845,-0.01199,-0.0157 ,-0.01968,-0.02401,-0.02888,-0.03455,-0.04151,-0.05088,-0.06621}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.52083, -0.54348, -0.56818, -0.59524, -0.625, -0.65789, + -0.69444, -0.73529, -0.78125, -0.83333, -0.89286, -0.96154, + -1.04167, -1.13636, -1.25, -1.38889, -1.5625, -1.78571, + -2.08333, -2.5, -3.125, -4.16666, -6.24999, -12.49996}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.13412, 0.207, 0.25438, 0.28364, 0.29811, 0.2996, + 0.28916, 0.26733, 0.23436, 0.19023, 0.13469, 0.06727, + -0.01277, -0.10647, -0.21524, -0.34101, -0.48645, -0.65531, + -0.85307, -1.08819, -1.37459, -1.73805, -2.23534, -3.04393}); + NDArray dLdlExp('c', {2, 3, 4}, + {0.06621, 0.05088, 0.04151, 0.03455, 0.02888, 0.02401, + 0.01968, 0.0157, 0.01199, 0.00845, 0.00502, 0.00167, + -0.00167, -0.00502, -0.00845, -0.01199, -0.0157, -0.01968, + -0.02401, -0.02888, -0.03455, -0.04151, -0.05088, -0.06621}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test10) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,1}, std::vector{-9.49054}); + NDArray dLdwExp('c', {1, 1}, std::vector{-9.49054}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdw = results.at(1); + auto dLdw = results.at(1); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test11) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + NDArray dLdwExp('c', {1, 3, 1}, {0.20365, -1.92882, -7.76537}); - NDArray dLdwExp('c', {1,3,1}, {0.20365,-1.92882,-7.76537}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdw = results.at(1); - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test12) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, { 0. , 0. , 0. , 0. ,-0.75 ,-0.789473,-0.833333, -0.882353,-0.9375 ,-1. ,-1.071428, -1.153846, - -1.25 ,-1.363636,-1.5 , -1.666666,-1.875 ,-2.142857,-2.499999, -2.999999,-3.749997,-4.999997,-7.499993,-14.999956}); - NDArray dLdwExp('c', {2,3,4}, {0.16094, 0.2484 , 0.30526, 0.34036, 0.35773, 0.35953, 0.34699, 0.32079, 0.28123, 0.22827, 0.16163, 0.08072, - -0.01533,-0.12776,-0.25828,-0.40921,-0.58373,-0.78637,-1.02369,-1.30582,-1.64951,-2.08566,-2.68241,-3.65272}); - NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0.03466, 0.02882, 0.02361, 0.01884, 0.01438, 0.01014, 0.00603, 0.002 , - -0.002 ,-0.00603,-0.01014,-0.01438,-0.01884,-0.02361,-0.02882,-0.03466,-0.04146,-0.04981,-0.06106,-0.07945}); - - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.t(0) = 0.; - weights.t(1) = 0.; - weights.t(2) = 0.; - weights.t(3) = 0.; - - - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {0., 0., 0., 0., -0.75, -0.789473, + -0.833333, -0.882353, -0.9375, -1., -1.071428, -1.153846, + -1.25, -1.363636, -1.5, -1.666666, -1.875, -2.142857, + -2.499999, -2.999999, -3.749997, -4.999997, -7.499993, -14.999956}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.16094, 0.2484, 0.30526, 0.34036, 0.35773, 0.35953, + 0.34699, 0.32079, 0.28123, 0.22827, 0.16163, 0.08072, + -0.01533, -0.12776, -0.25828, -0.40921, -0.58373, -0.78637, + -1.02369, -1.30582, -1.64951, -2.08566, -2.68241, -3.65272}); + NDArray dLdlExp('c', {2, 3, 4}, + {0., 0., 0., 0., 0.03466, 0.02882, + 0.02361, 0.01884, 0.01438, 0.01014, 0.00603, 0.002, + -0.002, -0.00603, -0.01014, -0.01438, -0.01884, -0.02361, + -0.02882, -0.03466, -0.04146, -0.04981, -0.06106, -0.07945}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.t(0) = 0.; + weights.t(1) = 0.; + weights.t(2) = 0.; + weights.t(3) = 0.; + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, log_loss_grad_test13) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , - -2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993}); - NDArray dLdwExp('c', {2,3,1}, {1.75828, 2.30839, 1.25309, -1.35098, -6.16602,-16.78383}); - NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , - -0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.t(0) = 0.; - weights.t(1) = 0.; - weights.t(2) = 0.; - - sd::ops::log_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + -2.08333, -2.27273, -2.5, -2.77778, -3.125, -3.57143, + -4.16667, -5., -6.25, -8.33333, -12.49999, -24.99993}); + NDArray dLdwExp('c', {2, 3, 1}, + {1.75828, 2.30839, 1.25309, -1.35098, -6.16602, -16.78383}); + NDArray dLdlExp('c', {2, 3, 4}, + {0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + -0.00334, -0.01005, -0.01689, -0.02397, -0.03141, -0.03935, + -0.04803, -0.05776, -0.06909, -0.08302, -0.10176, -0.13242}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.t(0) = 0.; + weights.t(1) = 0.; + weights.t(2) = 0.; + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) { - - NDArray input = NDArrayFactory::create('c', {1, 7, 7, 1}, { - 1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f, - 8.f, 9.1f, 10.f, 11.f, 12.9f, 13.1f, 14.f, - 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, - 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, - 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, - 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, - 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f - }); - NDArray expected = NDArrayFactory::create('c', {1, 30, 30, 1}, { - 1.f, 1.1976162f, 1.4174359f, 1.6775769f, 1.9961575f, 2.3283265f, - 2.550918f, 2.7360606f, 2.9655411f, 3.2929654f, 3.5441515f, 3.7380352f, - 3.948995f, 4.248106f, 4.5073795f, 4.6843743f, 4.8572845f, 5.104302f, - 5.3869915f, 5.581401f, 5.7539616f, 5.974285f, 6.272836f, 6.5204263f, - 6.718899f, 6.8871036f, 7.039068f, 7.099216f, 7.0784245f, 7.0281887f, - 2.247592f, 2.446947f, 2.6694887f, 2.9312382f, 3.248216f, 3.5745337f, - 3.78931f, 3.9656973f, 4.186417f, 4.5046535f, 4.740569f, 4.9217057f, - 5.133866f, 5.459533f, 5.7744613f, 6.0197873f, 6.254011f, 6.535633f, - 6.8097296f, 6.9607787f, 7.0749416f, 7.241601f, 7.5094895f, 7.7499495f, - 7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f, - 3.6286845f, 3.830573f, 4.0569587f, 4.3211575f, 4.6364856f, 4.9556503f, - 5.160583f, 5.3258467f, 5.535462f, 5.84216f, 6.058749f, 6.223753f, - 6.437597f, 6.797369f, 7.1836042f, 7.5164022f, 7.8290343f, 8.154773f, - 8.417635f, 8.512958f, 8.5521f, 8.649708f, 8.87788f, 9.108794f, - 9.320926f, 9.509781f, 9.667375f, 9.72694f, 9.706349f, 9.656599f, - 5.276778f, 5.480438f, 5.709702f, 5.9754477f, 6.288551f, 6.6005697f, - 6.796207f, 6.9511423f, 7.1503997f, 7.4461427f, 7.644651f, 7.794562f, - 8.009684f, 8.400473f, 8.851847f, 9.26469f, 9.649218f, 10.015648f, - 10.268647f, 10.313368f, 10.2843275f, 10.319379f, 10.512033f, 10.734956f, - 10.954604f, 11.154507f, 11.315369f, 11.374779f, 11.354242f, 11.304622f, - 7.325373f, 7.5284843f, 7.757575f, 8.022221f, 8.331997f, 8.638187f, - 8.827649f, 8.976217f, 9.168955f, 9.45726f, 9.6442375f, 9.784517f, - 9.999621f, 10.407702f, 10.896234f, 11.355122f, 11.781423f, 12.172186f, - 12.420712f, 12.4374485f, 12.370511f, 12.371386f, 12.545973f, 12.766424f, - 12.992249f, 13.20012f, 13.364252f, 13.424109f, 13.40342f, 13.353425f, - 9.493208f, 9.692467f, 9.9169445f, 10.176801f, 10.482199f, 10.78547f, - 10.974367f, 11.123442f, 11.31637f, 11.603645f, 11.790616f, 11.930889f, - 12.144082f, 12.546447f, 13.024898f, 13.4723f, 13.889232f, 14.276275f, - 14.528972f, 14.555555f, 14.50145f, 14.515459f, 14.700572f, 14.927055f, - 15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.5728855f, 15.521847f, - 10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f, - 12.43254f, 12.588294f, 12.787534f, 13.079956f, 13.27752f, 13.426631f, - 13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f, - 15.81343f, 15.881828f, 15.883522f, 15.950411f, 16.16933f, 16.40794f, - 16.636436f, 16.842583f, 17.010887f, 17.07363f, 17.05194f, 16.999537f, - 12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f, - 13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f, - 14.921564f, 15.264454f, 15.622843f, 15.924977f, 16.213829f, 16.532364f, - 16.8099f, 16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f, - 17.892765f, 18.09207f, 18.261044f, 18.325531f, 18.303238f, 18.249378f, - 13.7663965f, 13.947391f, 14.148263f, 14.386917f, 14.681246f, 14.990087f, - 15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f, - 16.50487f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f, - 18.118288f, 18.296928f, 18.4461f, 18.651634f, 18.956806f, 19.22382f, - 19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f, - 15.9419365f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.14954f, - 17.361883f, 17.542162f, 17.764957f, 18.078188f, 18.315733f, 18.498205f, - 18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f, - 20.13878f, 20.35177f, 20.546844f, 20.795671f, 21.128067f, 21.404358f, - 21.626736f, 21.8155f, 21.98561f, 22.052843f, 22.029604f, 21.973448f, - 17.53522f, 17.71077f, 17.904636f, 18.13695f, 18.42784f, 18.738056f, - 18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f, - 20.296894f, 20.580765f, 20.819603f, 20.976887f, 21.137802f, 21.387535f, - 21.689209f, 21.911621f, 22.119276f, 22.37999f, 22.71991f, 22.998823f, - 23.22097f, 23.40876f, 23.57911f, 23.646685f, 23.623325f, 23.566887f, - 18.746353f, 18.922657f, 19.117487f, 19.350685f, 19.64207f, 19.952137f, - 20.164913f, 20.345781f, 20.569134f, 20.88284f, 21.12133f, 21.30459f, - 21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.37289f, 22.626648f, - 22.926834f, 23.143423f, 23.343302f, 23.596668f, 23.931936f, 24.209232f, - 24.431519f, 24.619913f, 24.79011f, 24.857473f, 24.83419f, 24.777927f, - 20.16656f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f, - 21.589132f, 21.768297f, 21.99003f, 22.302366f, 22.538124f, 22.719105f, - 22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.83589f, 24.096842f, - 24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.62813f, - 25.850672f, 26.04014f, 26.210072f, 26.277063f, 26.253906f, 26.197956f, - 22.363024f, 22.54125f, 22.738552f, 22.973991f, 23.266647f, 23.57634f, - 23.787327f, 23.96576f, 24.186796f, 24.498543f, 24.733124f, 24.913122f, - 25.114826f, 25.411213f, 25.675262f, 25.863028f, 26.050789f, 26.314838f, - 26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f, - 28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f, - 24.429443f, 24.60767f, 24.80497f, 25.04041f, 25.333065f, 25.642756f, - 25.853743f, 26.032173f, 26.25321f, 26.564959f, 26.79954f, 26.97954f, - 27.181242f, 27.47763f, 27.74168f, 27.929441f, 28.117207f, 28.381254f, - 28.677637f, 28.879343f, 29.059345f, 29.293922f, 29.617298f, 29.890451f, - 30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f, - 26.f, 26.178228f, 26.375526f, 26.61097f, 26.903624f, 27.213314f, - 27.424305f, 27.602734f, 27.823772f, 28.135519f, 28.3701f, 28.550098f, - 28.7518f, 29.04819f, 29.312237f, 29.5f, 29.687763f, 29.951813f, - 30.2482f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f, - 31.683659f, 31.873592f, 32.043407f, 32.11024f, 32.087135f, 32.03132f, - 27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f, - 28.994865f, 29.173294f, 29.39433f, 29.70608f, 29.940659f, 30.120655f, - 30.32236f, 30.618746f, 30.882797f, 31.070557f, 31.25832f, 31.522371f, - 31.818754f, 32.02046f, 32.20046f, 32.43504f, 32.758415f, 33.031567f, - 33.25422f, 33.44415f, 33.613964f, 33.680794f, 33.657696f, 33.60188f, - 29.636976f, 29.815207f, 30.0125f, 30.247944f, 30.5406f, 30.85029f, - 31.061283f, 31.239712f, 31.46075f, 31.7725f, 32.00708f, 32.187077f, - 32.38878f, 32.685165f, 32.949215f, 33.13698f, 33.32474f, 33.58879f, - 33.885178f, 34.086884f, 34.26688f, 34.501457f, 34.824837f, 35.09799f, - 35.320637f, 35.510574f, 35.68039f, 35.747215f, 35.724117f, 35.6683f, - 31.83344f, 32.011665f, 32.20897f, 32.444412f, 32.73707f, 33.046757f, - 33.257744f, 33.436176f, 33.657207f, 33.96896f, 34.203537f, 34.383537f, - 34.58524f, 34.88163f, 35.145676f, 35.33344f, 35.521206f, 35.785255f, - 36.081642f, 36.28334f, 36.46334f, 36.69792f, 37.021297f, 37.294453f, - 37.517097f, 37.707027f, 37.876846f, 37.94368f, 37.920578f, 37.864758f, - 33.253647f, 33.431873f, 33.62917f, 33.864613f, 34.15727f, 34.466957f, - 34.677948f, 34.856377f, 35.077415f, 35.38916f, 35.623745f, 35.803745f, - 36.005447f, 36.301834f, 36.565884f, 36.753647f, 36.941406f, 37.205456f, - 37.50184f, 37.703545f, 37.883545f, 38.118122f, 38.4415f, 38.714653f, - 38.9373f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.28496f, - 34.464783f, 34.64301f, 34.840305f, 35.075752f, 35.368404f, 35.6781f, - 35.889088f, 36.067516f, 36.28855f, 36.6003f, 36.834885f, 37.014877f, - 37.216583f, 37.51297f, 37.77702f, 37.964783f, 38.152546f, 38.416595f, - 38.71298f, 38.914684f, 39.094685f, 39.32926f, 39.652645f, 39.925793f, - 40.14844f, 40.338375f, 40.508194f, 40.575024f, 40.55192f, 40.496105f, - 36.058067f, 36.23629f, 36.43359f, 36.669033f, 36.961685f, 37.271378f, - 37.48237f, 37.6608f, 37.881836f, 38.19359f, 38.42817f, 38.608162f, - 38.809868f, 39.10625f, 39.3703f, 39.558064f, 39.74583f, 40.00988f, - 40.306267f, 40.50797f, 40.68797f, 40.92255f, 41.245926f, 41.519077f, - 41.741722f, 41.931652f, 42.101475f, 42.168304f, 42.145203f, 42.089386f, - 38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.52832f, - 39.739307f, 39.917736f, 40.138775f, 40.45052f, 40.685104f, 40.865097f, - 41.066803f, 41.36319f, 41.627243f, 41.815002f, 42.002766f, 42.26682f, - 42.5632f, 42.764908f, 42.944904f, 43.179485f, 43.50286f, 43.776016f, - 43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.34633f, - 40.22708f, 40.40531f, 40.602608f, 40.83805f, 41.130707f, 41.440395f, - 41.651382f, 41.82982f, 42.050854f, 42.3626f, 42.597183f, 42.77718f, - 42.97888f, 43.27527f, 43.53932f, 43.72708f, 43.914845f, 44.178894f, - 44.47528f, 44.676983f, 44.856983f, 45.09156f, 45.41494f, 45.68809f, - 45.91074f, 46.100674f, 46.270493f, 46.337322f, 46.31422f, 46.2584f, - 41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.68924f, 42.998936f, - 43.209923f, 43.388355f, 43.609394f, 43.921143f, 44.15572f, 44.335716f, - 44.53742f, 44.833805f, 45.09786f, 45.285614f, 45.473377f, 45.737427f, - 46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.24663f, - 47.469276f, 47.65921f, 47.82903f, 47.895855f, 47.872753f, 47.81694f, - 43.11514f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f, - 44.539444f, 44.717873f, 44.93891f, 45.25066f, 45.48524f, 45.665237f, - 45.86694f, 46.163326f, 46.427376f, 46.615143f, 46.802902f, 47.066956f, - 47.363342f, 47.56505f, 47.74505f, 47.979626f, 48.302998f, 48.576153f, - 48.798798f, 48.98873f, 49.158546f, 49.225376f, 49.202282f, 49.146458f, - 44.303867f, 44.482094f, 44.679394f, 44.914833f, 45.207493f, 45.51718f, - 45.72817f, 45.9066f, 46.12764f, 46.439384f, 46.673965f, 46.853966f, - 47.055668f, 47.352055f, 47.6161f, 47.803867f, 47.99163f, 48.25568f, - 48.552063f, 48.75377f, 48.933773f, 49.16835f, 49.491726f, 49.764877f, - 49.987526f, 50.17746f, 50.347275f, 50.4141f, 50.391006f, 50.335186f, - 44.771675f, 44.949905f, 45.1472f, 45.382645f, 45.6753f, 45.98499f, - 46.195976f, 46.374413f, 46.595448f, 46.907196f, 47.141773f, 47.321774f, - 47.523476f, 47.819862f, 48.08391f, 48.27168f, 48.459446f, 48.72349f, - 49.019882f, 49.22158f, 49.401585f, 49.63616f, 49.959538f, 50.232693f, - 50.455338f, 50.64527f, 50.81509f, 50.88192f, 50.858818f, 50.803f, - 44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.51359f, 45.82328f, - 46.03427f, 46.2127f, 46.433743f, 46.74549f, 46.98007f, 47.160065f, - 47.36177f, 47.658157f, 47.922207f, 48.10997f, 48.297733f, 48.561783f, - 48.858166f, 49.059875f, 49.239872f, 49.47445f, 49.79783f, 50.07098f, - 50.293625f, 50.48356f, 50.653378f, 50.720203f, 50.6971f, 50.64128f, - 44.219246f, 44.397472f, 44.594772f, 44.83021f, 45.122868f, 45.43256f, - 45.643543f, 45.82198f, 46.04302f, 46.354763f, 46.589344f, 46.76934f, - 46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.17105f, - 48.467438f, 48.66914f, 48.849144f, 49.08372f, 49.4071f, 49.680256f, - 49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.30638f, 50.25057f}); - - auto size = NDArrayFactory::create({30, 30}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - -// result.printBuffer("Resized to 30x30"); -// expected.printBuffer("Expect for 30x30"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create( + 'c', {1, 7, 7, 1}, + {1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f, 8.f, 9.1f, 10.f, + 11.f, 12.9f, 13.1f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 30.f, 31.f, + 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, + 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 30, 30, 1}, + {1.f, 1.1976162f, 1.4174359f, 1.6775769f, 1.9961575f, + 2.3283265f, 2.550918f, 2.7360606f, 2.9655411f, 3.2929654f, + 3.5441515f, 3.7380352f, 3.948995f, 4.248106f, 4.5073795f, + 4.6843743f, 4.8572845f, 5.104302f, 5.3869915f, 5.581401f, + 5.7539616f, 5.974285f, 6.272836f, 6.5204263f, 6.718899f, + 6.8871036f, 7.039068f, 7.099216f, 7.0784245f, 7.0281887f, + 2.247592f, 2.446947f, 2.6694887f, 2.9312382f, 3.248216f, + 3.5745337f, 3.78931f, 3.9656973f, 4.186417f, 4.5046535f, + 4.740569f, 4.9217057f, 5.133866f, 5.459533f, 5.7744613f, + 6.0197873f, 6.254011f, 6.535633f, 6.8097296f, 6.9607787f, + 7.0749416f, 7.241601f, 7.5094895f, 7.7499495f, 7.954571f, + 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f, + 3.6286845f, 3.830573f, 4.0569587f, 4.3211575f, 4.6364856f, + 4.9556503f, 5.160583f, 5.3258467f, 5.535462f, 5.84216f, + 6.058749f, 6.223753f, 6.437597f, 6.797369f, 7.1836042f, + 7.5164022f, 7.8290343f, 8.154773f, 8.417635f, 8.512958f, + 8.5521f, 8.649708f, 8.87788f, 9.108794f, 9.320926f, + 9.509781f, 9.667375f, 9.72694f, 9.706349f, 9.656599f, + 5.276778f, 5.480438f, 5.709702f, 5.9754477f, 6.288551f, + 6.6005697f, 6.796207f, 6.9511423f, 7.1503997f, 7.4461427f, + 7.644651f, 7.794562f, 8.009684f, 8.400473f, 8.851847f, + 9.26469f, 9.649218f, 10.015648f, 10.268647f, 10.313368f, + 10.2843275f, 10.319379f, 10.512033f, 10.734956f, 10.954604f, + 11.154507f, 11.315369f, 11.374779f, 11.354242f, 11.304622f, + 7.325373f, 7.5284843f, 7.757575f, 8.022221f, 8.331997f, + 8.638187f, 8.827649f, 8.976217f, 9.168955f, 9.45726f, + 9.6442375f, 9.784517f, 9.999621f, 10.407702f, 10.896234f, + 11.355122f, 11.781423f, 12.172186f, 12.420712f, 12.4374485f, + 12.370511f, 12.371386f, 12.545973f, 12.766424f, 12.992249f, + 13.20012f, 13.364252f, 13.424109f, 13.40342f, 13.353425f, + 9.493208f, 9.692467f, 9.9169445f, 10.176801f, 10.482199f, + 10.78547f, 10.974367f, 11.123442f, 11.31637f, 11.603645f, + 11.790616f, 11.930889f, 12.144082f, 12.546447f, 13.024898f, + 13.4723f, 13.889232f, 14.276275f, 14.528972f, 14.555555f, + 14.50145f, 14.515459f, 14.700572f, 14.927055f, 15.156046f, + 15.366046f, 15.532901f, 15.594008f, 15.5728855f, 15.521847f, + 10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, + 12.238887f, 12.43254f, 12.588294f, 12.787534f, 13.079956f, + 13.27752f, 13.426631f, 13.636713f, 14.013844f, 14.441672f, + 14.827978f, 15.191209f, 15.549808f, 15.81343f, 15.881828f, + 15.883522f, 15.950411f, 16.16933f, 16.40794f, 16.636436f, + 16.842583f, 17.010887f, 17.07363f, 17.05194f, 16.999537f, + 12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, + 13.464224f, 13.665207f, 13.830567f, 14.039036f, 14.339629f, + 14.552863f, 14.715049f, 14.921564f, 15.264454f, 15.622843f, + 15.924977f, 16.213829f, 16.532364f, 16.8099f, 16.934835f, + 17.012146f, 17.150164f, 17.413412f, 17.666712f, 17.892765f, + 18.09207f, 18.261044f, 18.325531f, 18.303238f, 18.249378f, + 13.7663965f, 13.947391f, 14.148263f, 14.386917f, 14.681246f, + 14.990087f, 15.198166f, 15.372728f, 15.590062f, 15.898583f, + 16.126892f, 16.301655f, 16.50487f, 16.815214f, 17.107498f, + 17.329458f, 17.547403f, 17.827654f, 18.118288f, 18.296928f, + 18.4461f, 18.651634f, 18.956806f, 19.22382f, 19.447308f, + 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f, + 15.9419365f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, + 17.14954f, 17.361883f, 17.542162f, 17.764957f, 18.078188f, + 18.315733f, 18.498205f, 18.699116f, 18.988684f, 19.238989f, + 19.410137f, 19.583265f, 19.839512f, 20.13878f, 20.35177f, + 20.546844f, 20.795671f, 21.128067f, 21.404358f, 21.626736f, + 21.8155f, 21.98561f, 22.052843f, 22.029604f, 21.973448f, + 17.53522f, 17.71077f, 17.904636f, 18.13695f, 18.42784f, + 18.738056f, 18.951529f, 19.133352f, 19.357613f, 19.672083f, + 19.912102f, 20.096638f, 20.296894f, 20.580765f, 20.819603f, + 20.976887f, 21.137802f, 21.387535f, 21.689209f, 21.911621f, + 22.119276f, 22.37999f, 22.71991f, 22.998823f, 23.22097f, + 23.40876f, 23.57911f, 23.646685f, 23.623325f, 23.566887f, + 18.746353f, 18.922657f, 19.117487f, 19.350685f, 19.64207f, + 19.952137f, 20.164913f, 20.345781f, 20.569134f, 20.88284f, + 21.12133f, 21.30459f, 21.505253f, 21.792645f, 22.038572f, + 22.204426f, 22.37289f, 22.626648f, 22.926834f, 23.143423f, + 23.343302f, 23.596668f, 23.931936f, 24.209232f, 24.431519f, + 24.619913f, 24.79011f, 24.857473f, 24.83419f, 24.777927f, + 20.16656f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, + 21.377607f, 21.589132f, 21.768297f, 21.99003f, 22.302366f, + 22.538124f, 22.719105f, 22.920494f, 23.214176f, 23.472767f, + 23.653934f, 23.83589f, 24.096842f, 24.394371f, 24.600555f, + 24.786541f, 25.026773f, 25.353731f, 25.62813f, 25.850672f, + 26.04014f, 26.210072f, 26.277063f, 26.253906f, 26.197956f, + 22.363024f, 22.54125f, 22.738552f, 22.973991f, 23.266647f, + 23.57634f, 23.787327f, 23.96576f, 24.186796f, 24.498543f, + 24.733124f, 24.913122f, 25.114826f, 25.411213f, 25.675262f, + 25.863028f, 26.050789f, 26.314838f, 26.611223f, 26.812925f, + 26.992926f, 27.227505f, 27.550882f, 27.824034f, 28.046684f, + 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f, + 24.429443f, 24.60767f, 24.80497f, 25.04041f, 25.333065f, + 25.642756f, 25.853743f, 26.032173f, 26.25321f, 26.564959f, + 26.79954f, 26.97954f, 27.181242f, 27.47763f, 27.74168f, + 27.929441f, 28.117207f, 28.381254f, 28.677637f, 28.879343f, + 29.059345f, 29.293922f, 29.617298f, 29.890451f, 30.113104f, + 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f, + 26.f, 26.178228f, 26.375526f, 26.61097f, 26.903624f, + 27.213314f, 27.424305f, 27.602734f, 27.823772f, 28.135519f, + 28.3701f, 28.550098f, 28.7518f, 29.04819f, 29.312237f, + 29.5f, 29.687763f, 29.951813f, 30.2482f, 30.449903f, + 30.629902f, 30.864483f, 31.187859f, 31.461012f, 31.683659f, + 31.873592f, 32.043407f, 32.11024f, 32.087135f, 32.03132f, + 27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, + 28.783876f, 28.994865f, 29.173294f, 29.39433f, 29.70608f, + 29.940659f, 30.120655f, 30.32236f, 30.618746f, 30.882797f, + 31.070557f, 31.25832f, 31.522371f, 31.818754f, 32.02046f, + 32.20046f, 32.43504f, 32.758415f, 33.031567f, 33.25422f, + 33.44415f, 33.613964f, 33.680794f, 33.657696f, 33.60188f, + 29.636976f, 29.815207f, 30.0125f, 30.247944f, 30.5406f, + 30.85029f, 31.061283f, 31.239712f, 31.46075f, 31.7725f, + 32.00708f, 32.187077f, 32.38878f, 32.685165f, 32.949215f, + 33.13698f, 33.32474f, 33.58879f, 33.885178f, 34.086884f, + 34.26688f, 34.501457f, 34.824837f, 35.09799f, 35.320637f, + 35.510574f, 35.68039f, 35.747215f, 35.724117f, 35.6683f, + 31.83344f, 32.011665f, 32.20897f, 32.444412f, 32.73707f, + 33.046757f, 33.257744f, 33.436176f, 33.657207f, 33.96896f, + 34.203537f, 34.383537f, 34.58524f, 34.88163f, 35.145676f, + 35.33344f, 35.521206f, 35.785255f, 36.081642f, 36.28334f, + 36.46334f, 36.69792f, 37.021297f, 37.294453f, 37.517097f, + 37.707027f, 37.876846f, 37.94368f, 37.920578f, 37.864758f, + 33.253647f, 33.431873f, 33.62917f, 33.864613f, 34.15727f, + 34.466957f, 34.677948f, 34.856377f, 35.077415f, 35.38916f, + 35.623745f, 35.803745f, 36.005447f, 36.301834f, 36.565884f, + 36.753647f, 36.941406f, 37.205456f, 37.50184f, 37.703545f, + 37.883545f, 38.118122f, 38.4415f, 38.714653f, 38.9373f, + 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.28496f, + 34.464783f, 34.64301f, 34.840305f, 35.075752f, 35.368404f, + 35.6781f, 35.889088f, 36.067516f, 36.28855f, 36.6003f, + 36.834885f, 37.014877f, 37.216583f, 37.51297f, 37.77702f, + 37.964783f, 38.152546f, 38.416595f, 38.71298f, 38.914684f, + 39.094685f, 39.32926f, 39.652645f, 39.925793f, 40.14844f, + 40.338375f, 40.508194f, 40.575024f, 40.55192f, 40.496105f, + 36.058067f, 36.23629f, 36.43359f, 36.669033f, 36.961685f, + 37.271378f, 37.48237f, 37.6608f, 37.881836f, 38.19359f, + 38.42817f, 38.608162f, 38.809868f, 39.10625f, 39.3703f, + 39.558064f, 39.74583f, 40.00988f, 40.306267f, 40.50797f, + 40.68797f, 40.92255f, 41.245926f, 41.519077f, 41.741722f, + 41.931652f, 42.101475f, 42.168304f, 42.145203f, 42.089386f, + 38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, + 39.52832f, 39.739307f, 39.917736f, 40.138775f, 40.45052f, + 40.685104f, 40.865097f, 41.066803f, 41.36319f, 41.627243f, + 41.815002f, 42.002766f, 42.26682f, 42.5632f, 42.764908f, + 42.944904f, 43.179485f, 43.50286f, 43.776016f, 43.998665f, + 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.34633f, + 40.22708f, 40.40531f, 40.602608f, 40.83805f, 41.130707f, + 41.440395f, 41.651382f, 41.82982f, 42.050854f, 42.3626f, + 42.597183f, 42.77718f, 42.97888f, 43.27527f, 43.53932f, + 43.72708f, 43.914845f, 44.178894f, 44.47528f, 44.676983f, + 44.856983f, 45.09156f, 45.41494f, 45.68809f, 45.91074f, + 46.100674f, 46.270493f, 46.337322f, 46.31422f, 46.2584f, + 41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.68924f, + 42.998936f, 43.209923f, 43.388355f, 43.609394f, 43.921143f, + 44.15572f, 44.335716f, 44.53742f, 44.833805f, 45.09786f, + 45.285614f, 45.473377f, 45.737427f, 46.033817f, 46.235523f, + 46.415524f, 46.650105f, 46.973476f, 47.24663f, 47.469276f, + 47.65921f, 47.82903f, 47.895855f, 47.872753f, 47.81694f, + 43.11514f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, + 44.328457f, 44.539444f, 44.717873f, 44.93891f, 45.25066f, + 45.48524f, 45.665237f, 45.86694f, 46.163326f, 46.427376f, + 46.615143f, 46.802902f, 47.066956f, 47.363342f, 47.56505f, + 47.74505f, 47.979626f, 48.302998f, 48.576153f, 48.798798f, + 48.98873f, 49.158546f, 49.225376f, 49.202282f, 49.146458f, + 44.303867f, 44.482094f, 44.679394f, 44.914833f, 45.207493f, + 45.51718f, 45.72817f, 45.9066f, 46.12764f, 46.439384f, + 46.673965f, 46.853966f, 47.055668f, 47.352055f, 47.6161f, + 47.803867f, 47.99163f, 48.25568f, 48.552063f, 48.75377f, + 48.933773f, 49.16835f, 49.491726f, 49.764877f, 49.987526f, + 50.17746f, 50.347275f, 50.4141f, 50.391006f, 50.335186f, + 44.771675f, 44.949905f, 45.1472f, 45.382645f, 45.6753f, + 45.98499f, 46.195976f, 46.374413f, 46.595448f, 46.907196f, + 47.141773f, 47.321774f, 47.523476f, 47.819862f, 48.08391f, + 48.27168f, 48.459446f, 48.72349f, 49.019882f, 49.22158f, + 49.401585f, 49.63616f, 49.959538f, 50.232693f, 50.455338f, + 50.64527f, 50.81509f, 50.88192f, 50.858818f, 50.803f, + 44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.51359f, + 45.82328f, 46.03427f, 46.2127f, 46.433743f, 46.74549f, + 46.98007f, 47.160065f, 47.36177f, 47.658157f, 47.922207f, + 48.10997f, 48.297733f, 48.561783f, 48.858166f, 49.059875f, + 49.239872f, 49.47445f, 49.79783f, 50.07098f, 50.293625f, + 50.48356f, 50.653378f, 50.720203f, 50.6971f, 50.64128f, + 44.219246f, 44.397472f, 44.594772f, 44.83021f, 45.122868f, + 45.43256f, 45.643543f, 45.82198f, 46.04302f, 46.354763f, + 46.589344f, 46.76934f, 46.971046f, 47.267433f, 47.531483f, + 47.719242f, 47.907005f, 48.17105f, 48.467438f, 48.66914f, + 48.849144f, 49.08372f, 49.4071f, 49.680256f, 49.902905f, + 50.092834f, 50.262653f, 50.329483f, 50.30638f, 50.25057f}); + + auto size = NDArrayFactory::create({30, 30}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + + // result.printBuffer("Resized to 30x30"); + // expected.printBuffer("Expect for 30x30"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) { - - NDArray input = NDArrayFactory::create('c', {2, 5, 4, 3}); - NDArray expected = NDArrayFactory::create('c', {2, 10, 8, 3}, { - 1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f, - 5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f, - 10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f, - 7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f, - 11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f, - 15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f, - 16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f, - 20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f, - 19.000000f, 20.000000f, 21.000000f, 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, 24.000000f, - 23.500000f, 24.500000f, 25.500000f, 25.000000f, 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f, - 28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, 30.281250f, 25.000000f, 26.000000f, 27.000000f, - 26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f, - 31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f, - 34.281250f, 35.281250f, 36.281250f, 31.000000f, 32.000000f, 33.000000f, 32.218750f, 33.218750f, 34.218750f, - 34.000000f, 35.000000f, 36.000000f, 35.500000f, 36.500000f, 37.500000f, 37.000000f, 38.000000f, 39.000000f, - 38.781250f, 39.781250f, 40.781250f, 40.000000f, 41.000000f, 42.000000f, 40.281250f, 41.281250f, 42.281250f, - 37.000000f, 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, 40.000000f, 41.000000f, 42.000000f, - 41.500000f, 42.500000f, 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, 45.781250f, 46.781250f, - 46.000000f, 47.000000f, 48.000000f, 46.281250f, 47.281250f, 48.281250f, 44.125000f, 45.125000f, 46.125000f, - 45.343750f, 46.343750f, 47.343750f, 47.125000f, 48.125000f, 49.125000f, 48.625000f, 49.625000f, 50.625000f, - 50.125000f, 51.125000f, 52.125000f, 51.906250f, 52.906250f, 53.906250f, 53.125000f, 54.125000f, 55.125000f, - 53.406250f, 54.406250f, 55.406250f, 49.000000f, 50.000000f, 51.000000f, 50.218750f, 51.218750f, 52.218750f, - 52.000000f, 53.000000f, 54.000000f, 53.500000f, 54.500000f, 55.500000f, 55.000000f, 56.000000f, 57.000000f, - 56.781250f, 57.781250f, 58.781250f, 58.000000f, 59.000000f, 60.000000f, 58.281250f, 59.281250f, 60.281250f, - 50.125000f, 51.125000f, 52.125000f, 51.343750f, 52.343750f, 53.343750f, 53.125000f, 54.125000f, 55.125000f, - 54.625000f, 55.625000f, 56.625000f, 56.125000f, 57.125000f, 58.125000f, 57.906250f, 58.906250f, 59.906250f, - 59.125000f, 60.125000f, 61.125000f, 59.406250f, 60.406250f, 61.406250f, 61.000000f, 62.000000f, 63.000000f, - 62.218750f, 63.218750f, 64.218750f, 64.000000f, 65.000000f, 66.000000f, 65.500000f, 66.500000f, 67.500000f, - 67.000000f, 68.000000f, 69.000000f, 68.781250f, 69.781250f, 70.781250f, 70.000000f, 71.000000f, 72.000000f, - 70.281250f, 71.281250f, 72.281250f, 65.875000f, 66.875000f, 67.875000f, 67.093750f, 68.093750f, 69.093750f, - 68.875000f, 69.875000f, 70.875000f, 70.375000f, 71.375000f, 72.375000f, 71.875000f, 72.875000f, 73.875000f, - 73.656250f, 74.656250f, 75.656250f, 74.875000f, 75.875000f, 76.875000f, 75.156250f, 76.156250f, 77.156250f, - 73.000000f, 74.000000f, 75.000000f, 74.218750f, 75.218750f, 76.218750f, 76.000000f, 77.000000f, 78.000000f, - 77.500000f, 78.500000f, 79.500000f, 79.000000f, 80.000000f, 81.000000f, 80.781250f, 81.781250f, 82.781250f, - 82.000000f, 83.000000f, 84.000000f, 82.281250f, 83.281250f, 84.281250f, 79.000000f, 80.000000f, 81.000000f, - 80.218750f, 81.218750f, 82.218750f, 82.000000f, 83.000000f, 84.000000f, 83.500000f, 84.500000f, 85.500000f, - 85.000000f, 86.000000f, 87.000000f, 86.781250f, 87.781250f, 88.781250f, 88.000000f, 89.000000f, 90.000000f, - 88.281250f, 89.281250f, 90.281250f, 85.000000f, 86.000000f, 87.000000f, 86.218750f, 87.218750f, 88.218750f, - 88.000000f, 89.000000f, 90.000000f, 89.500000f, 90.500000f, 91.500000f, 91.000000f, 92.000000f, 93.000000f, - 92.781250f, 93.781250f, 94.781250f, 94.000000f, 95.000000f, 96.000000f, 94.281250f, 95.281250f, 96.281250f, - 91.000000f, 92.000000f, 93.000000f, 92.218750f, 93.218750f, 94.218750f, 94.000000f, 95.000000f, 96.000000f, - 95.500000f, 96.500000f, 97.500000f, 97.000000f, 98.000000f, 99.000000f, 98.781250f, 99.781250f, 100.781250f, - 100.000000f, 101.000000f, 102.000000f, 100.281250f, 101.281250f, 102.281250f, 97.000000f, 98.000000f, - 99.000000f, 98.218750f, 99.218750f, 100.218750f, 100.000000f, 101.000000f, 102.000000f, 101.500000f, - 102.500000f, 103.500000f, 103.000000f, 104.000000f, 105.000000f, 104.781250f, 105.781250f, 106.781250f, - 106.000000f, 107.000000f, 108.000000f, 106.281250f, 107.281250f, 108.281250f, 104.125000f, 105.125000f, - 106.125000f, 105.343750f, 106.343750f, 107.343750f, 107.125000f, 108.125000f, 109.125000f, 108.625000f, - 109.625000f, 110.625000f, 110.125000f, 111.125000f, 112.125000f, 111.906250f, 112.906250f, 113.906250f, - 113.125000f, 114.125000f, 115.125000f, 113.406250f, 114.406250f, 115.406250f, 109.000000f, 110.000000f, - 111.000000f, 110.218750f, 111.218750f, 112.218750f, 112.000000f, 113.000000f, 114.000000f, 113.500000f, - 114.500000f, 115.500000f, 115.000000f, 116.000000f, 117.000000f, 116.781250f, 117.781250f, 118.781250f, - 118.000000f, 119.000000f, 120.000000f, 118.281250f, 119.281250f, 120.281250f, 110.125000f, 111.125000f, - 112.125000f, 111.343750f, 112.343750f, 113.343750f, 113.125000f, 114.125000f, 115.125000f, 114.625000f, - 115.625000f, 116.625000f, 116.125000f, 117.125000f, 118.125000f, 117.906250f, 118.906250f, 119.906250f, - 119.125000f, 120.125000f, 121.125000f, 119.406250f, 120.406250f, 121.406250f - }); //input = 1.f; - input.linspace(1); - auto size = NDArrayFactory::create({10, 8}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Resized to 10x8"); -// expected.printBuffer("Expect for 10x8"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create('c', {2, 5, 4, 3}); + NDArray expected = NDArrayFactory::create( + 'c', {2, 10, 8, 3}, + {1.000000f, 2.000000f, 3.000000f, 2.218750f, + 3.218750f, 4.218750f, 4.000000f, 5.000000f, + 6.000000f, 5.500000f, 6.500000f, 7.500000f, + 7.000000f, 8.000000f, 9.000000f, 8.781250f, + 9.781250f, 10.781250f, 10.000000f, 11.000000f, + 12.000000f, 10.281250f, 11.281250f, 12.281250f, + 5.875000f, 6.875000f, 7.875000f, 7.093750f, + 8.093750f, 9.093750f, 8.875000f, 9.875000f, + 10.875000f, 10.375000f, 11.375000f, 12.375000f, + 11.875000f, 12.875000f, 13.875000f, 13.656250f, + 14.656250f, 15.656250f, 14.875000f, 15.875000f, + 16.875000f, 15.156250f, 16.156250f, 17.156250f, + 13.000000f, 14.000000f, 15.000000f, 14.218750f, + 15.218750f, 16.218750f, 16.000000f, 17.000000f, + 18.000000f, 17.500000f, 18.500000f, 19.500000f, + 19.000000f, 20.000000f, 21.000000f, 20.781250f, + 21.781250f, 22.781250f, 22.000000f, 23.000000f, + 24.000000f, 22.281250f, 23.281250f, 24.281250f, + 19.000000f, 20.000000f, 21.000000f, 20.218750f, + 21.218750f, 22.218750f, 22.000000f, 23.000000f, + 24.000000f, 23.500000f, 24.500000f, 25.500000f, + 25.000000f, 26.000000f, 27.000000f, 26.781250f, + 27.781250f, 28.781250f, 28.000000f, 29.000000f, + 30.000000f, 28.281250f, 29.281250f, 30.281250f, + 25.000000f, 26.000000f, 27.000000f, 26.218750f, + 27.218750f, 28.218750f, 28.000000f, 29.000000f, + 30.000000f, 29.500000f, 30.500000f, 31.500000f, + 31.000000f, 32.000000f, 33.000000f, 32.781250f, + 33.781250f, 34.781250f, 34.000000f, 35.000000f, + 36.000000f, 34.281250f, 35.281250f, 36.281250f, + 31.000000f, 32.000000f, 33.000000f, 32.218750f, + 33.218750f, 34.218750f, 34.000000f, 35.000000f, + 36.000000f, 35.500000f, 36.500000f, 37.500000f, + 37.000000f, 38.000000f, 39.000000f, 38.781250f, + 39.781250f, 40.781250f, 40.000000f, 41.000000f, + 42.000000f, 40.281250f, 41.281250f, 42.281250f, + 37.000000f, 38.000000f, 39.000000f, 38.218750f, + 39.218750f, 40.218750f, 40.000000f, 41.000000f, + 42.000000f, 41.500000f, 42.500000f, 43.500000f, + 43.000000f, 44.000000f, 45.000000f, 44.781250f, + 45.781250f, 46.781250f, 46.000000f, 47.000000f, + 48.000000f, 46.281250f, 47.281250f, 48.281250f, + 44.125000f, 45.125000f, 46.125000f, 45.343750f, + 46.343750f, 47.343750f, 47.125000f, 48.125000f, + 49.125000f, 48.625000f, 49.625000f, 50.625000f, + 50.125000f, 51.125000f, 52.125000f, 51.906250f, + 52.906250f, 53.906250f, 53.125000f, 54.125000f, + 55.125000f, 53.406250f, 54.406250f, 55.406250f, + 49.000000f, 50.000000f, 51.000000f, 50.218750f, + 51.218750f, 52.218750f, 52.000000f, 53.000000f, + 54.000000f, 53.500000f, 54.500000f, 55.500000f, + 55.000000f, 56.000000f, 57.000000f, 56.781250f, + 57.781250f, 58.781250f, 58.000000f, 59.000000f, + 60.000000f, 58.281250f, 59.281250f, 60.281250f, + 50.125000f, 51.125000f, 52.125000f, 51.343750f, + 52.343750f, 53.343750f, 53.125000f, 54.125000f, + 55.125000f, 54.625000f, 55.625000f, 56.625000f, + 56.125000f, 57.125000f, 58.125000f, 57.906250f, + 58.906250f, 59.906250f, 59.125000f, 60.125000f, + 61.125000f, 59.406250f, 60.406250f, 61.406250f, + 61.000000f, 62.000000f, 63.000000f, 62.218750f, + 63.218750f, 64.218750f, 64.000000f, 65.000000f, + 66.000000f, 65.500000f, 66.500000f, 67.500000f, + 67.000000f, 68.000000f, 69.000000f, 68.781250f, + 69.781250f, 70.781250f, 70.000000f, 71.000000f, + 72.000000f, 70.281250f, 71.281250f, 72.281250f, + 65.875000f, 66.875000f, 67.875000f, 67.093750f, + 68.093750f, 69.093750f, 68.875000f, 69.875000f, + 70.875000f, 70.375000f, 71.375000f, 72.375000f, + 71.875000f, 72.875000f, 73.875000f, 73.656250f, + 74.656250f, 75.656250f, 74.875000f, 75.875000f, + 76.875000f, 75.156250f, 76.156250f, 77.156250f, + 73.000000f, 74.000000f, 75.000000f, 74.218750f, + 75.218750f, 76.218750f, 76.000000f, 77.000000f, + 78.000000f, 77.500000f, 78.500000f, 79.500000f, + 79.000000f, 80.000000f, 81.000000f, 80.781250f, + 81.781250f, 82.781250f, 82.000000f, 83.000000f, + 84.000000f, 82.281250f, 83.281250f, 84.281250f, + 79.000000f, 80.000000f, 81.000000f, 80.218750f, + 81.218750f, 82.218750f, 82.000000f, 83.000000f, + 84.000000f, 83.500000f, 84.500000f, 85.500000f, + 85.000000f, 86.000000f, 87.000000f, 86.781250f, + 87.781250f, 88.781250f, 88.000000f, 89.000000f, + 90.000000f, 88.281250f, 89.281250f, 90.281250f, + 85.000000f, 86.000000f, 87.000000f, 86.218750f, + 87.218750f, 88.218750f, 88.000000f, 89.000000f, + 90.000000f, 89.500000f, 90.500000f, 91.500000f, + 91.000000f, 92.000000f, 93.000000f, 92.781250f, + 93.781250f, 94.781250f, 94.000000f, 95.000000f, + 96.000000f, 94.281250f, 95.281250f, 96.281250f, + 91.000000f, 92.000000f, 93.000000f, 92.218750f, + 93.218750f, 94.218750f, 94.000000f, 95.000000f, + 96.000000f, 95.500000f, 96.500000f, 97.500000f, + 97.000000f, 98.000000f, 99.000000f, 98.781250f, + 99.781250f, 100.781250f, 100.000000f, 101.000000f, + 102.000000f, 100.281250f, 101.281250f, 102.281250f, + 97.000000f, 98.000000f, 99.000000f, 98.218750f, + 99.218750f, 100.218750f, 100.000000f, 101.000000f, + 102.000000f, 101.500000f, 102.500000f, 103.500000f, + 103.000000f, 104.000000f, 105.000000f, 104.781250f, + 105.781250f, 106.781250f, 106.000000f, 107.000000f, + 108.000000f, 106.281250f, 107.281250f, 108.281250f, + 104.125000f, 105.125000f, 106.125000f, 105.343750f, + 106.343750f, 107.343750f, 107.125000f, 108.125000f, + 109.125000f, 108.625000f, 109.625000f, 110.625000f, + 110.125000f, 111.125000f, 112.125000f, 111.906250f, + 112.906250f, 113.906250f, 113.125000f, 114.125000f, + 115.125000f, 113.406250f, 114.406250f, 115.406250f, + 109.000000f, 110.000000f, 111.000000f, 110.218750f, + 111.218750f, 112.218750f, 112.000000f, 113.000000f, + 114.000000f, 113.500000f, 114.500000f, 115.500000f, + 115.000000f, 116.000000f, 117.000000f, 116.781250f, + 117.781250f, 118.781250f, 118.000000f, 119.000000f, + 120.000000f, 118.281250f, 119.281250f, 120.281250f, + 110.125000f, 111.125000f, 112.125000f, 111.343750f, + 112.343750f, 113.343750f, 113.125000f, 114.125000f, + 115.125000f, 114.625000f, 115.625000f, 116.625000f, + 116.125000f, 117.125000f, 118.125000f, 117.906250f, + 118.906250f, 119.906250f, 119.125000f, 120.125000f, + 121.125000f, 119.406250f, 120.406250f, 121.406250f}); // input = 1.f; + input.linspace(1); + auto size = NDArrayFactory::create({10, 8}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Resized to 10x8"); + // expected.printBuffer("Expect for 10x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) { - - NDArray input = NDArrayFactory::create('c', {1, 3, 3, 4}); - NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 4}, { - 1.000000f, 2.000000f, 3.000000f, 4.000000f, 2.625000f, 3.625000f, 4.625000f, 5.625000f, 5.000000f, - 6.000000f, 7.000000f, 8.000000f, 7.375000f, 8.375000f, 9.375000f, 10.375000f, 9.000000f, 10.000000f, - 11.000000f, 12.000000f, 9.375000f, 10.375000f, 11.375000f, 12.375000f, 5.875000f, 6.875000f, 7.875000f, - 8.875000f, 7.500000f, 8.500000f, 9.500000f, 10.500000f, 9.875000f, 10.875000f, 11.875000f, 12.875000f, - 12.250000f, 13.250000f, 14.250000f, 15.250000f, 13.875000f, 14.875000f, 15.875000f, 16.875000f, 14.250000f, - 15.250000f, 16.250000f, 17.250000f, 13.000000f, 14.000000f, 15.000000f, 16.000000f, 14.625000f, 15.625000f, - 16.625000f, 17.625000f, 17.000000f, 18.000000f, 19.000000f, 20.000000f, 19.375000f, 20.375000f, 21.375000f, - 22.375000f, 21.000000f, 22.000000f, 23.000000f, 24.000000f, 21.375000f, 22.375000f, 23.375000f, 24.375000f, - 20.125000f, 21.125000f, 22.125000f, 23.125000f, 21.750000f, 22.750000f, 23.750000f, 24.750000f, 24.125000f, - 25.125000f, 26.125000f, 27.125000f, 26.500000f, 27.500000f, 28.500000f, 29.500000f, 28.125000f, 29.125000f, - 30.125000f, 31.125000f, 28.500000f, 29.500000f, 30.500000f, 31.500000f, 25.000000f, 26.000000f, 27.000000f, - 28.000000f, 26.625000f, 27.625000f, 28.625000f, 29.625000f, 29.000000f, 30.000000f, 31.000000f, 32.000000f, - 31.375000f, 32.375000f, 33.375000f, 34.375000f, 33.000000f, 34.000000f, 35.000000f, 36.000000f, 33.375000f, - 34.375000f, 35.375000f, 36.375000f, 26.125000f, 27.125000f, 28.125000f, 29.125000f, 27.750000f, 28.750000f, - 29.750000f, 30.750000f, 30.125000f, 31.125000f, 32.125000f, 33.125000f, 32.500000f, 33.500000f, 34.500000f, - 35.500000f, 34.125000f, 35.125000f, 36.125000f, 37.125000f, 34.500000f, 35.500000f, 36.500000f, 37.500000f - }); - input.linspace(1); - auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Resized to 6x6"); -// expected.printBuffer("Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 4}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 6, 6, 4}, + {1.000000f, 2.000000f, 3.000000f, 4.000000f, 2.625000f, 3.625000f, + 4.625000f, 5.625000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 7.375000f, 8.375000f, 9.375000f, 10.375000f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 9.375000f, 10.375000f, 11.375000f, 12.375000f, + 5.875000f, 6.875000f, 7.875000f, 8.875000f, 7.500000f, 8.500000f, + 9.500000f, 10.500000f, 9.875000f, 10.875000f, 11.875000f, 12.875000f, + 12.250000f, 13.250000f, 14.250000f, 15.250000f, 13.875000f, 14.875000f, + 15.875000f, 16.875000f, 14.250000f, 15.250000f, 16.250000f, 17.250000f, + 13.000000f, 14.000000f, 15.000000f, 16.000000f, 14.625000f, 15.625000f, + 16.625000f, 17.625000f, 17.000000f, 18.000000f, 19.000000f, 20.000000f, + 19.375000f, 20.375000f, 21.375000f, 22.375000f, 21.000000f, 22.000000f, + 23.000000f, 24.000000f, 21.375000f, 22.375000f, 23.375000f, 24.375000f, + 20.125000f, 21.125000f, 22.125000f, 23.125000f, 21.750000f, 22.750000f, + 23.750000f, 24.750000f, 24.125000f, 25.125000f, 26.125000f, 27.125000f, + 26.500000f, 27.500000f, 28.500000f, 29.500000f, 28.125000f, 29.125000f, + 30.125000f, 31.125000f, 28.500000f, 29.500000f, 30.500000f, 31.500000f, + 25.000000f, 26.000000f, 27.000000f, 28.000000f, 26.625000f, 27.625000f, + 28.625000f, 29.625000f, 29.000000f, 30.000000f, 31.000000f, 32.000000f, + 31.375000f, 32.375000f, 33.375000f, 34.375000f, 33.000000f, 34.000000f, + 35.000000f, 36.000000f, 33.375000f, 34.375000f, 35.375000f, 36.375000f, + 26.125000f, 27.125000f, 28.125000f, 29.125000f, 27.750000f, 28.750000f, + 29.750000f, 30.750000f, 30.125000f, 31.125000f, 32.125000f, 33.125000f, + 32.500000f, 33.500000f, 34.500000f, 35.500000f, 34.125000f, 35.125000f, + 36.125000f, 37.125000f, 34.500000f, 35.500000f, 36.500000f, 37.500000f}); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Resized to 6x6"); + // expected.printBuffer("Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) { - - NDArray input = NDArrayFactory::create('c', {1, 3, 4, 3}); - NDArray expected = NDArrayFactory::create('c', {1, 6, 8, 3}, { - 1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f, - 5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f, - 10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f, - 7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f, - 11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f, - 15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f, - 16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f, - 20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f, - 20.125000f, 21.125000f, 22.125000f, 21.343750f, 22.343750f, 23.343750f, 23.125000f, 24.125000f, 25.125000f, - 24.625000f, 25.625000f, 26.625000f, 26.125000f, 27.125000f, 28.125000f, 27.906250f, 28.906250f, 29.906250f, - 29.125000f, 30.125000f, 31.125000f, 29.406250f, 30.406250f, 31.406250f, 25.000000f, 26.000000f, 27.000000f, - 26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f, - 31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f, - 34.281250f, 35.281250f, 36.281250f, 26.125000f, 27.125000f, 28.125000f, 27.343750f, 28.343750f, 29.343750f, - 29.125000f, 30.125000f, 31.125000f, 30.625000f, 31.625000f, 32.625000f, 32.125000f, 33.125000f, 34.125000f, - 33.906250f, 34.906250f, 35.906250f, 35.125000f, 36.125000f, 37.125000f, 35.406250f, 36.406250f, 37.406250f - }); - input.linspace(1); - auto size = NDArrayFactory::create({6, 8}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Resized to 6x8"); -// expected.printBuffer("Expect for 6x8"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create('c', {1, 3, 4, 3}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 6, 8, 3}, + {1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, + 4.000000f, 5.000000f, 6.000000f, 5.500000f, 6.500000f, 7.500000f, + 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f, + 10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, + 5.875000f, 6.875000f, 7.875000f, 7.093750f, 8.093750f, 9.093750f, + 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f, + 11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, + 14.875000f, 15.875000f, 16.875000f, 15.156250f, 16.156250f, 17.156250f, + 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f, + 16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, + 19.000000f, 20.000000f, 21.000000f, 20.781250f, 21.781250f, 22.781250f, + 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f, + 20.125000f, 21.125000f, 22.125000f, 21.343750f, 22.343750f, 23.343750f, + 23.125000f, 24.125000f, 25.125000f, 24.625000f, 25.625000f, 26.625000f, + 26.125000f, 27.125000f, 28.125000f, 27.906250f, 28.906250f, 29.906250f, + 29.125000f, 30.125000f, 31.125000f, 29.406250f, 30.406250f, 31.406250f, + 25.000000f, 26.000000f, 27.000000f, 26.218750f, 27.218750f, 28.218750f, + 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f, + 31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, + 34.000000f, 35.000000f, 36.000000f, 34.281250f, 35.281250f, 36.281250f, + 26.125000f, 27.125000f, 28.125000f, 27.343750f, 28.343750f, 29.343750f, + 29.125000f, 30.125000f, 31.125000f, 30.625000f, 31.625000f, 32.625000f, + 32.125000f, 33.125000f, 34.125000f, 33.906250f, 34.906250f, 35.906250f, + 35.125000f, 36.125000f, 37.125000f, 35.406250f, 36.406250f, 37.406250f}); + input.linspace(1); + auto size = NDArrayFactory::create({6, 8}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Resized to 6x8"); + // expected.printBuffer("Expect for 6x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) { - - NDArray input = NDArrayFactory::create('c', {1, 4, 4, 3}); - NDArray expected = NDArrayFactory::create('c', {1, 8, 8, 3}, { - 1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f, - 5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f, - 10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f, - 7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f, - 11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f, - 15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f, - 16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f, - 20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f, - 19.000000f, 20.000000f, 21.000000f, 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, 24.000000f, - 23.500000f, 24.500000f, 25.500000f, 25.000000f, 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f, - 28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, 30.281250f, 25.000000f, 26.000000f, 27.000000f, - 26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f, - 31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f, - 34.281250f, 35.281250f, 36.281250f, 32.125000f, 33.125000f, 34.125000f, 33.343750f, 34.343750f, 35.343750f, - 35.125000f, 36.125000f, 37.125000f, 36.625000f, 37.625000f, 38.625000f, 38.125000f, 39.125000f, 40.125000f, - 39.906250f, 40.906250f, 41.906250f, 41.125000f, 42.125000f, 43.125000f, 41.406250f, 42.406250f, 43.406250f, - 37.000000f, 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, 40.000000f, 41.000000f, 42.000000f, - 41.500000f, 42.500000f, 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, 45.781250f, 46.781250f, - 46.000000f, 47.000000f, 48.000000f, 46.281250f, 47.281250f, 48.281250f, 38.125000f, 39.125000f, 40.125000f, - 39.343750f, 40.343750f, 41.343750f, 41.125000f, 42.125000f, 43.125000f, 42.625000f, 43.625000f, 44.625000f, - 44.125000f, 45.125000f, 46.125000f, 45.906250f, 46.906250f, 47.906250f, 47.125000f, 48.125000f, 49.125000f, - 47.406250f, 48.406250f, 49.406250f, - }); - input.linspace(1); - auto size = NDArrayFactory::create({8, 8}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Resized to 8x8"); -// expected.printBuffer("Expect for 8x8"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create('c', {1, 4, 4, 3}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 8, 8, 3}, + { + 1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, + 4.218750f, 4.000000f, 5.000000f, 6.000000f, 5.500000f, + 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, + 8.781250f, 9.781250f, 10.781250f, 10.000000f, 11.000000f, + 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, + 6.875000f, 7.875000f, 7.093750f, 8.093750f, 9.093750f, + 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, + 12.375000f, 11.875000f, 12.875000f, 13.875000f, 13.656250f, + 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f, + 15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, + 15.000000f, 14.218750f, 15.218750f, 16.218750f, 16.000000f, + 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, + 19.000000f, 20.000000f, 21.000000f, 20.781250f, 21.781250f, + 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, + 23.281250f, 24.281250f, 19.000000f, 20.000000f, 21.000000f, + 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, + 24.000000f, 23.500000f, 24.500000f, 25.500000f, 25.000000f, + 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f, + 28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, + 30.281250f, 25.000000f, 26.000000f, 27.000000f, 26.218750f, + 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, + 29.500000f, 30.500000f, 31.500000f, 31.000000f, 32.000000f, + 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, + 35.000000f, 36.000000f, 34.281250f, 35.281250f, 36.281250f, + 32.125000f, 33.125000f, 34.125000f, 33.343750f, 34.343750f, + 35.343750f, 35.125000f, 36.125000f, 37.125000f, 36.625000f, + 37.625000f, 38.625000f, 38.125000f, 39.125000f, 40.125000f, + 39.906250f, 40.906250f, 41.906250f, 41.125000f, 42.125000f, + 43.125000f, 41.406250f, 42.406250f, 43.406250f, 37.000000f, + 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, + 40.000000f, 41.000000f, 42.000000f, 41.500000f, 42.500000f, + 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, + 45.781250f, 46.781250f, 46.000000f, 47.000000f, 48.000000f, + 46.281250f, 47.281250f, 48.281250f, 38.125000f, 39.125000f, + 40.125000f, 39.343750f, 40.343750f, 41.343750f, 41.125000f, + 42.125000f, 43.125000f, 42.625000f, 43.625000f, 44.625000f, + 44.125000f, 45.125000f, 46.125000f, 45.906250f, 46.906250f, + 47.906250f, 47.125000f, 48.125000f, 49.125000f, 47.406250f, + 48.406250f, 49.406250f, + }); + input.linspace(1); + auto size = NDArrayFactory::create({8, 8}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Resized to 8x8"); + // expected.printBuffer("Expect for 8x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) { - - NDArray input = NDArrayFactory::create('c', {7, 7, 1}, { - 1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f, - 8.f, 9.1f, 10.f, 11.f, 12.9f, 13.1f, 14.f, - 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, - 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, - 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, - 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, - 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f - }); - - NDArray expected = NDArrayFactory::create('c', {30, 30, 1}, { - 1.000000f, 1.197616f, 1.417436f, 1.677577f, 1.996158f, 2.328327f, 2.550918f, 2.736061f, 2.965541f, - 3.292965f, 3.544151f, 3.738035f, 3.948995f, 4.248106f, 4.507379f, 4.684374f, 4.857284f, 5.104302f, - 5.386991f, 5.581401f, 5.753962f, 5.974285f, 6.272836f, 6.520426f, 6.718899f, 6.887104f, 7.039068f, - 7.099216f, 7.078424f, 7.028189f, 2.247592f, 2.446947f, 2.669489f, 2.931238f, 3.248216f, 3.574534f, - 3.789310f, 3.965697f, 4.186417f, 4.504653f, 4.740569f, 4.921706f, 5.133866f, 5.459533f, 5.774461f, - 6.019787f, 6.254011f, 6.535633f, 6.809730f, 6.960779f, 7.074942f, 7.241601f, 7.509489f, 7.749949f, - 7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f, 3.628684f, 3.830573f, 4.056959f, - 4.321157f, 4.636486f, 4.955650f, 5.160583f, 5.325847f, 5.535462f, 5.842160f, 6.058749f, 6.223753f, - 6.437597f, 6.797369f, 7.183604f, 7.516402f, 7.829034f, 8.154773f, 8.417635f, 8.512958f, 8.552100f, - 8.649708f, 8.877880f, 9.108794f, 9.320926f, 9.509781f, 9.667375f, 9.726940f, 9.706349f, 9.656599f, - 5.276778f, 5.480438f, 5.709702f, 5.975448f, 6.288551f, 6.600570f, 6.796207f, 6.951142f, 7.150400f, - 7.446143f, 7.644651f, 7.794562f, 8.009684f, 8.400473f, 8.851847f, 9.264690f, 9.649218f, 10.015648f, - 10.268647f, 10.313368f, 10.284327f, 10.319379f, 10.512033f, 10.734956f, 10.954604f, 11.154507f, 11.315369f, - 11.374779f, 11.354242f, 11.304622f, 7.325373f, 7.528484f, 7.757575f, 8.022221f, 8.331997f, 8.638187f, - 8.827649f, 8.976217f, 9.168955f, 9.457260f, 9.644237f, 9.784517f, 9.999621f, 10.407702f, 10.896234f, - 11.355122f, 11.781423f, 12.172186f, 12.420712f, 12.437449f, 12.370511f, 12.371386f, 12.545973f, 12.766424f, - 12.992249f, 13.200120f, 13.364252f, 13.424109f, 13.403420f, 13.353425f, 9.493208f, 9.692467f, 9.916944f, - 10.176801f, 10.482199f, 10.785470f, 10.974367f, 11.123442f, 11.316370f, 11.603645f, 11.790616f, 11.930889f, - 12.144082f, 12.546447f, 13.024898f, 13.472300f, 13.889232f, 14.276275f, 14.528972f, 14.555555f, 14.501450f, - 14.515459f, 14.700572f, 14.927055f, 15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.572885f, 15.521847f, - 10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f, 12.432540f, 12.588294f, 12.787534f, - 13.079956f, 13.277520f, 13.426631f, 13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f, - 15.813430f, 15.881828f, 15.883522f, 15.950411f, 16.169330f, 16.407940f, 16.636436f, 16.842583f, 17.010887f, - 17.073630f, 17.051940f, 16.999537f, 12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f, - 13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f, 14.921564f, 15.264454f, 15.622843f, - 15.924977f, 16.213829f, 16.532364f, 16.809900f, 16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f, - 17.892765f, 18.092070f, 18.261044f, 18.325531f, 18.303238f, 18.249378f, 13.766397f, 13.947391f, 14.148263f, - 14.386917f, 14.681246f, 14.990087f, 15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f, - 16.504870f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f, 18.118288f, 18.296928f, 18.446100f, - 18.651634f, 18.956806f, 19.223820f, 19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f, - 15.941937f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.149540f, 17.361883f, 17.542162f, 17.764957f, - 18.078188f, 18.315733f, 18.498205f, 18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f, - 20.138780f, 20.351770f, 20.546844f, 20.795671f, 21.128067f, 21.404358f, 21.626736f, 21.815500f, 21.985610f, - 22.052843f, 22.029604f, 21.973448f, 17.535220f, 17.710770f, 17.904636f, 18.136950f, 18.427840f, 18.738056f, - 18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f, 20.296894f, 20.580765f, 20.819603f, - 20.976887f, 21.137802f, 21.387535f, 21.689209f, 21.911621f, 22.119276f, 22.379990f, 22.719910f, 22.998823f, - 23.220970f, 23.408760f, 23.579110f, 23.646685f, 23.623325f, 23.566887f, 18.746353f, 18.922657f, 19.117487f, - 19.350685f, 19.642070f, 19.952137f, 20.164913f, 20.345781f, 20.569134f, 20.882840f, 21.121330f, 21.304590f, - 21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.372890f, 22.626648f, 22.926834f, 23.143423f, 23.343302f, - 23.596668f, 23.931936f, 24.209232f, 24.431519f, 24.619913f, 24.790110f, 24.857473f, 24.834190f, 24.777927f, - 20.166560f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f, 21.589132f, 21.768297f, 21.990030f, - 22.302366f, 22.538124f, 22.719105f, 22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.835890f, 24.096842f, - 24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.628130f, 25.850672f, 26.040140f, 26.210072f, - 26.277063f, 26.253906f, 26.197956f, 22.363024f, 22.541250f, 22.738552f, 22.973991f, 23.266647f, 23.576340f, - 23.787327f, 23.965760f, 24.186796f, 24.498543f, 24.733124f, 24.913122f, 25.114826f, 25.411213f, 25.675262f, - 25.863028f, 26.050789f, 26.314838f, 26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f, - 28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f, 24.429443f, 24.607670f, 24.804970f, - 25.040410f, 25.333065f, 25.642756f, 25.853743f, 26.032173f, 26.253210f, 26.564959f, 26.799540f, 26.979540f, - 27.181242f, 27.477630f, 27.741680f, 27.929441f, 28.117207f, 28.381254f, 28.677637f, 28.879343f, 29.059345f, - 29.293922f, 29.617298f, 29.890451f, 30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f, - 26.000000f, 26.178228f, 26.375526f, 26.610970f, 26.903624f, 27.213314f, 27.424305f, 27.602734f, 27.823772f, - 28.135519f, 28.370100f, 28.550098f, 28.751800f, 29.048190f, 29.312237f, 29.500000f, 29.687763f, 29.951813f, - 30.248200f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f, 31.683659f, 31.873592f, 32.043407f, - 32.110240f, 32.087135f, 32.031320f, 27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f, - 28.994865f, 29.173294f, 29.394330f, 29.706080f, 29.940659f, 30.120655f, 30.322360f, 30.618746f, 30.882797f, - 31.070557f, 31.258320f, 31.522371f, 31.818754f, 32.020460f, 32.200460f, 32.435040f, 32.758415f, 33.031567f, - 33.254220f, 33.444150f, 33.613964f, 33.680794f, 33.657696f, 33.601880f, 29.636976f, 29.815207f, 30.012500f, - 30.247944f, 30.540600f, 30.850290f, 31.061283f, 31.239712f, 31.460750f, 31.772500f, 32.007080f, 32.187077f, - 32.388780f, 32.685165f, 32.949215f, 33.136980f, 33.324740f, 33.588790f, 33.885178f, 34.086884f, 34.266880f, - 34.501457f, 34.824837f, 35.097990f, 35.320637f, 35.510574f, 35.680390f, 35.747215f, 35.724117f, 35.668300f, - 31.833440f, 32.011665f, 32.208970f, 32.444412f, 32.737070f, 33.046757f, 33.257744f, 33.436176f, 33.657207f, - 33.968960f, 34.203537f, 34.383537f, 34.585240f, 34.881630f, 35.145676f, 35.333440f, 35.521206f, 35.785255f, - 36.081642f, 36.283340f, 36.463340f, 36.697920f, 37.021297f, 37.294453f, 37.517097f, 37.707027f, 37.876846f, - 37.943680f, 37.920578f, 37.864758f, 33.253647f, 33.431873f, 33.629170f, 33.864613f, 34.157270f, 34.466957f, - 34.677948f, 34.856377f, 35.077415f, 35.389160f, 35.623745f, 35.803745f, 36.005447f, 36.301834f, 36.565884f, - 36.753647f, 36.941406f, 37.205456f, 37.501840f, 37.703545f, 37.883545f, 38.118122f, 38.441500f, 38.714653f, - 38.937300f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.284960f, 34.464783f, 34.643010f, 34.840305f, - 35.075752f, 35.368404f, 35.678100f, 35.889088f, 36.067516f, 36.288550f, 36.600300f, 36.834885f, 37.014877f, - 37.216583f, 37.512970f, 37.777020f, 37.964783f, 38.152546f, 38.416595f, 38.712980f, 38.914684f, 39.094685f, - 39.329260f, 39.652645f, 39.925793f, 40.148440f, 40.338375f, 40.508194f, 40.575024f, 40.551920f, 40.496105f, - 36.058067f, 36.236290f, 36.433590f, 36.669033f, 36.961685f, 37.271378f, 37.482370f, 37.660800f, 37.881836f, - 38.193590f, 38.428170f, 38.608162f, 38.809868f, 39.106250f, 39.370300f, 39.558064f, 39.745830f, 40.009880f, - 40.306267f, 40.507970f, 40.687970f, 40.922550f, 41.245926f, 41.519077f, 41.741722f, 41.931652f, 42.101475f, - 42.168304f, 42.145203f, 42.089386f, 38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.528320f, - 39.739307f, 39.917736f, 40.138775f, 40.450520f, 40.685104f, 40.865097f, 41.066803f, 41.363190f, 41.627243f, - 41.815002f, 42.002766f, 42.266820f, 42.563200f, 42.764908f, 42.944904f, 43.179485f, 43.502860f, 43.776016f, - 43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.346330f, 40.227080f, 40.405310f, 40.602608f, - 40.838050f, 41.130707f, 41.440395f, 41.651382f, 41.829820f, 42.050854f, 42.362600f, 42.597183f, 42.777180f, - 42.978880f, 43.275270f, 43.539320f, 43.727080f, 43.914845f, 44.178894f, 44.475280f, 44.676983f, 44.856983f, - 45.091560f, 45.414940f, 45.688090f, 45.910740f, 46.100674f, 46.270493f, 46.337322f, 46.314220f, 46.258400f, - 41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.689240f, 42.998936f, 43.209923f, 43.388355f, 43.609394f, - 43.921143f, 44.155720f, 44.335716f, 44.537420f, 44.833805f, 45.097860f, 45.285614f, 45.473377f, 45.737427f, - 46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.246630f, 47.469276f, 47.659210f, 47.829030f, - 47.895855f, 47.872753f, 47.816940f, 43.115140f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f, - 44.539444f, 44.717873f, 44.938910f, 45.250660f, 45.485240f, 45.665237f, 45.866940f, 46.163326f, 46.427376f, - 46.615143f, 46.802902f, 47.066956f, 47.363342f, 47.565050f, 47.745050f, 47.979626f, 48.302998f, 48.576153f, - 48.798798f, 48.988730f, 49.158546f, 49.225376f, 49.202282f, 49.146458f, 44.303867f, 44.482094f, 44.679394f, - 44.914833f, 45.207493f, 45.517180f, 45.728170f, 45.906600f, 46.127640f, 46.439384f, 46.673965f, 46.853966f, - 47.055668f, 47.352055f, 47.616100f, 47.803867f, 47.991630f, 48.255680f, 48.552063f, 48.753770f, 48.933773f, - 49.168350f, 49.491726f, 49.764877f, 49.987526f, 50.177460f, 50.347275f, 50.414100f, 50.391006f, 50.335186f, - 44.771675f, 44.949905f, 45.147200f, 45.382645f, 45.675300f, 45.984990f, 46.195976f, 46.374413f, 46.595448f, - 46.907196f, 47.141773f, 47.321774f, 47.523476f, 47.819862f, 48.083910f, 48.271680f, 48.459446f, 48.723490f, - 49.019882f, 49.221580f, 49.401585f, 49.636160f, 49.959538f, 50.232693f, 50.455338f, 50.645270f, 50.815090f, - 50.881920f, 50.858818f, 50.803000f, 44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.513590f, 45.823280f, - 46.034270f, 46.212700f, 46.433743f, 46.745490f, 46.980070f, 47.160065f, 47.361770f, 47.658157f, 47.922207f, - 48.109970f, 48.297733f, 48.561783f, 48.858166f, 49.059875f, 49.239872f, 49.474450f, 49.797830f, 50.070980f, - 50.293625f, 50.483560f, 50.653378f, 50.720203f, 50.697100f, 50.641280f, 44.219246f, 44.397472f, 44.594772f, - 44.830210f, 45.122868f, 45.432560f, 45.643543f, 45.821980f, 46.043020f, 46.354763f, 46.589344f, 46.769340f, - 46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.171050f, 48.467438f, 48.669140f, 48.849144f, - 49.083720f, 49.407100f, 49.680256f, 49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.306380f, 50.250570f - }); - - auto size = NDArrayFactory::create({30, 30}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - -// result.printBuffer("Resized to 30x30"); -// expected.printBuffer("Expect for 30x30"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create( + 'c', {7, 7, 1}, + {1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f, 8.f, 9.1f, 10.f, + 11.f, 12.9f, 13.1f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 30.f, 31.f, + 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, + 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f}); + + NDArray expected = NDArrayFactory::create( + 'c', {30, 30, 1}, + {1.000000f, 1.197616f, 1.417436f, 1.677577f, 1.996158f, 2.328327f, + 2.550918f, 2.736061f, 2.965541f, 3.292965f, 3.544151f, 3.738035f, + 3.948995f, 4.248106f, 4.507379f, 4.684374f, 4.857284f, 5.104302f, + 5.386991f, 5.581401f, 5.753962f, 5.974285f, 6.272836f, 6.520426f, + 6.718899f, 6.887104f, 7.039068f, 7.099216f, 7.078424f, 7.028189f, + 2.247592f, 2.446947f, 2.669489f, 2.931238f, 3.248216f, 3.574534f, + 3.789310f, 3.965697f, 4.186417f, 4.504653f, 4.740569f, 4.921706f, + 5.133866f, 5.459533f, 5.774461f, 6.019787f, 6.254011f, 6.535633f, + 6.809730f, 6.960779f, 7.074942f, 7.241601f, 7.509489f, 7.749949f, + 7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f, + 3.628684f, 3.830573f, 4.056959f, 4.321157f, 4.636486f, 4.955650f, + 5.160583f, 5.325847f, 5.535462f, 5.842160f, 6.058749f, 6.223753f, + 6.437597f, 6.797369f, 7.183604f, 7.516402f, 7.829034f, 8.154773f, + 8.417635f, 8.512958f, 8.552100f, 8.649708f, 8.877880f, 9.108794f, + 9.320926f, 9.509781f, 9.667375f, 9.726940f, 9.706349f, 9.656599f, + 5.276778f, 5.480438f, 5.709702f, 5.975448f, 6.288551f, 6.600570f, + 6.796207f, 6.951142f, 7.150400f, 7.446143f, 7.644651f, 7.794562f, + 8.009684f, 8.400473f, 8.851847f, 9.264690f, 9.649218f, 10.015648f, + 10.268647f, 10.313368f, 10.284327f, 10.319379f, 10.512033f, 10.734956f, + 10.954604f, 11.154507f, 11.315369f, 11.374779f, 11.354242f, 11.304622f, + 7.325373f, 7.528484f, 7.757575f, 8.022221f, 8.331997f, 8.638187f, + 8.827649f, 8.976217f, 9.168955f, 9.457260f, 9.644237f, 9.784517f, + 9.999621f, 10.407702f, 10.896234f, 11.355122f, 11.781423f, 12.172186f, + 12.420712f, 12.437449f, 12.370511f, 12.371386f, 12.545973f, 12.766424f, + 12.992249f, 13.200120f, 13.364252f, 13.424109f, 13.403420f, 13.353425f, + 9.493208f, 9.692467f, 9.916944f, 10.176801f, 10.482199f, 10.785470f, + 10.974367f, 11.123442f, 11.316370f, 11.603645f, 11.790616f, 11.930889f, + 12.144082f, 12.546447f, 13.024898f, 13.472300f, 13.889232f, 14.276275f, + 14.528972f, 14.555555f, 14.501450f, 14.515459f, 14.700572f, 14.927055f, + 15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.572885f, 15.521847f, + 10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f, + 12.432540f, 12.588294f, 12.787534f, 13.079956f, 13.277520f, 13.426631f, + 13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f, + 15.813430f, 15.881828f, 15.883522f, 15.950411f, 16.169330f, 16.407940f, + 16.636436f, 16.842583f, 17.010887f, 17.073630f, 17.051940f, 16.999537f, + 12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f, + 13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f, + 14.921564f, 15.264454f, 15.622843f, 15.924977f, 16.213829f, 16.532364f, + 16.809900f, 16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f, + 17.892765f, 18.092070f, 18.261044f, 18.325531f, 18.303238f, 18.249378f, + 13.766397f, 13.947391f, 14.148263f, 14.386917f, 14.681246f, 14.990087f, + 15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f, + 16.504870f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f, + 18.118288f, 18.296928f, 18.446100f, 18.651634f, 18.956806f, 19.223820f, + 19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f, + 15.941937f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.149540f, + 17.361883f, 17.542162f, 17.764957f, 18.078188f, 18.315733f, 18.498205f, + 18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f, + 20.138780f, 20.351770f, 20.546844f, 20.795671f, 21.128067f, 21.404358f, + 21.626736f, 21.815500f, 21.985610f, 22.052843f, 22.029604f, 21.973448f, + 17.535220f, 17.710770f, 17.904636f, 18.136950f, 18.427840f, 18.738056f, + 18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f, + 20.296894f, 20.580765f, 20.819603f, 20.976887f, 21.137802f, 21.387535f, + 21.689209f, 21.911621f, 22.119276f, 22.379990f, 22.719910f, 22.998823f, + 23.220970f, 23.408760f, 23.579110f, 23.646685f, 23.623325f, 23.566887f, + 18.746353f, 18.922657f, 19.117487f, 19.350685f, 19.642070f, 19.952137f, + 20.164913f, 20.345781f, 20.569134f, 20.882840f, 21.121330f, 21.304590f, + 21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.372890f, 22.626648f, + 22.926834f, 23.143423f, 23.343302f, 23.596668f, 23.931936f, 24.209232f, + 24.431519f, 24.619913f, 24.790110f, 24.857473f, 24.834190f, 24.777927f, + 20.166560f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f, + 21.589132f, 21.768297f, 21.990030f, 22.302366f, 22.538124f, 22.719105f, + 22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.835890f, 24.096842f, + 24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.628130f, + 25.850672f, 26.040140f, 26.210072f, 26.277063f, 26.253906f, 26.197956f, + 22.363024f, 22.541250f, 22.738552f, 22.973991f, 23.266647f, 23.576340f, + 23.787327f, 23.965760f, 24.186796f, 24.498543f, 24.733124f, 24.913122f, + 25.114826f, 25.411213f, 25.675262f, 25.863028f, 26.050789f, 26.314838f, + 26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f, + 28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f, + 24.429443f, 24.607670f, 24.804970f, 25.040410f, 25.333065f, 25.642756f, + 25.853743f, 26.032173f, 26.253210f, 26.564959f, 26.799540f, 26.979540f, + 27.181242f, 27.477630f, 27.741680f, 27.929441f, 28.117207f, 28.381254f, + 28.677637f, 28.879343f, 29.059345f, 29.293922f, 29.617298f, 29.890451f, + 30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f, + 26.000000f, 26.178228f, 26.375526f, 26.610970f, 26.903624f, 27.213314f, + 27.424305f, 27.602734f, 27.823772f, 28.135519f, 28.370100f, 28.550098f, + 28.751800f, 29.048190f, 29.312237f, 29.500000f, 29.687763f, 29.951813f, + 30.248200f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f, + 31.683659f, 31.873592f, 32.043407f, 32.110240f, 32.087135f, 32.031320f, + 27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f, + 28.994865f, 29.173294f, 29.394330f, 29.706080f, 29.940659f, 30.120655f, + 30.322360f, 30.618746f, 30.882797f, 31.070557f, 31.258320f, 31.522371f, + 31.818754f, 32.020460f, 32.200460f, 32.435040f, 32.758415f, 33.031567f, + 33.254220f, 33.444150f, 33.613964f, 33.680794f, 33.657696f, 33.601880f, + 29.636976f, 29.815207f, 30.012500f, 30.247944f, 30.540600f, 30.850290f, + 31.061283f, 31.239712f, 31.460750f, 31.772500f, 32.007080f, 32.187077f, + 32.388780f, 32.685165f, 32.949215f, 33.136980f, 33.324740f, 33.588790f, + 33.885178f, 34.086884f, 34.266880f, 34.501457f, 34.824837f, 35.097990f, + 35.320637f, 35.510574f, 35.680390f, 35.747215f, 35.724117f, 35.668300f, + 31.833440f, 32.011665f, 32.208970f, 32.444412f, 32.737070f, 33.046757f, + 33.257744f, 33.436176f, 33.657207f, 33.968960f, 34.203537f, 34.383537f, + 34.585240f, 34.881630f, 35.145676f, 35.333440f, 35.521206f, 35.785255f, + 36.081642f, 36.283340f, 36.463340f, 36.697920f, 37.021297f, 37.294453f, + 37.517097f, 37.707027f, 37.876846f, 37.943680f, 37.920578f, 37.864758f, + 33.253647f, 33.431873f, 33.629170f, 33.864613f, 34.157270f, 34.466957f, + 34.677948f, 34.856377f, 35.077415f, 35.389160f, 35.623745f, 35.803745f, + 36.005447f, 36.301834f, 36.565884f, 36.753647f, 36.941406f, 37.205456f, + 37.501840f, 37.703545f, 37.883545f, 38.118122f, 38.441500f, 38.714653f, + 38.937300f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.284960f, + 34.464783f, 34.643010f, 34.840305f, 35.075752f, 35.368404f, 35.678100f, + 35.889088f, 36.067516f, 36.288550f, 36.600300f, 36.834885f, 37.014877f, + 37.216583f, 37.512970f, 37.777020f, 37.964783f, 38.152546f, 38.416595f, + 38.712980f, 38.914684f, 39.094685f, 39.329260f, 39.652645f, 39.925793f, + 40.148440f, 40.338375f, 40.508194f, 40.575024f, 40.551920f, 40.496105f, + 36.058067f, 36.236290f, 36.433590f, 36.669033f, 36.961685f, 37.271378f, + 37.482370f, 37.660800f, 37.881836f, 38.193590f, 38.428170f, 38.608162f, + 38.809868f, 39.106250f, 39.370300f, 39.558064f, 39.745830f, 40.009880f, + 40.306267f, 40.507970f, 40.687970f, 40.922550f, 41.245926f, 41.519077f, + 41.741722f, 41.931652f, 42.101475f, 42.168304f, 42.145203f, 42.089386f, + 38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.528320f, + 39.739307f, 39.917736f, 40.138775f, 40.450520f, 40.685104f, 40.865097f, + 41.066803f, 41.363190f, 41.627243f, 41.815002f, 42.002766f, 42.266820f, + 42.563200f, 42.764908f, 42.944904f, 43.179485f, 43.502860f, 43.776016f, + 43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.346330f, + 40.227080f, 40.405310f, 40.602608f, 40.838050f, 41.130707f, 41.440395f, + 41.651382f, 41.829820f, 42.050854f, 42.362600f, 42.597183f, 42.777180f, + 42.978880f, 43.275270f, 43.539320f, 43.727080f, 43.914845f, 44.178894f, + 44.475280f, 44.676983f, 44.856983f, 45.091560f, 45.414940f, 45.688090f, + 45.910740f, 46.100674f, 46.270493f, 46.337322f, 46.314220f, 46.258400f, + 41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.689240f, 42.998936f, + 43.209923f, 43.388355f, 43.609394f, 43.921143f, 44.155720f, 44.335716f, + 44.537420f, 44.833805f, 45.097860f, 45.285614f, 45.473377f, 45.737427f, + 46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.246630f, + 47.469276f, 47.659210f, 47.829030f, 47.895855f, 47.872753f, 47.816940f, + 43.115140f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f, + 44.539444f, 44.717873f, 44.938910f, 45.250660f, 45.485240f, 45.665237f, + 45.866940f, 46.163326f, 46.427376f, 46.615143f, 46.802902f, 47.066956f, + 47.363342f, 47.565050f, 47.745050f, 47.979626f, 48.302998f, 48.576153f, + 48.798798f, 48.988730f, 49.158546f, 49.225376f, 49.202282f, 49.146458f, + 44.303867f, 44.482094f, 44.679394f, 44.914833f, 45.207493f, 45.517180f, + 45.728170f, 45.906600f, 46.127640f, 46.439384f, 46.673965f, 46.853966f, + 47.055668f, 47.352055f, 47.616100f, 47.803867f, 47.991630f, 48.255680f, + 48.552063f, 48.753770f, 48.933773f, 49.168350f, 49.491726f, 49.764877f, + 49.987526f, 50.177460f, 50.347275f, 50.414100f, 50.391006f, 50.335186f, + 44.771675f, 44.949905f, 45.147200f, 45.382645f, 45.675300f, 45.984990f, + 46.195976f, 46.374413f, 46.595448f, 46.907196f, 47.141773f, 47.321774f, + 47.523476f, 47.819862f, 48.083910f, 48.271680f, 48.459446f, 48.723490f, + 49.019882f, 49.221580f, 49.401585f, 49.636160f, 49.959538f, 50.232693f, + 50.455338f, 50.645270f, 50.815090f, 50.881920f, 50.858818f, 50.803000f, + 44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.513590f, 45.823280f, + 46.034270f, 46.212700f, 46.433743f, 46.745490f, 46.980070f, 47.160065f, + 47.361770f, 47.658157f, 47.922207f, 48.109970f, 48.297733f, 48.561783f, + 48.858166f, 49.059875f, 49.239872f, 49.474450f, 49.797830f, 50.070980f, + 50.293625f, 50.483560f, 50.653378f, 50.720203f, 50.697100f, 50.641280f, + 44.219246f, 44.397472f, 44.594772f, 44.830210f, 45.122868f, 45.432560f, + 45.643543f, 45.821980f, 46.043020f, 46.354763f, 46.589344f, 46.769340f, + 46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.171050f, + 48.467438f, 48.669140f, 48.849144f, 49.083720f, 49.407100f, 49.680256f, + 49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.306380f, 50.250570f}); + + auto size = NDArrayFactory::create({30, 30}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + + // result.printBuffer("Resized to 30x30"); + // expected.printBuffer("Expect for 30x30"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test7) { - - NDArray input = NDArrayFactory::create('c', {2, 5, 5, 1}, { - 0.2303, 0.7950, 0.8171, 0.0451, 0.3690, 0.6846, 0.2727, 0.2770, 0.2381, 0.9511, - 0.4116, 0.3997, 0.4075, 0.6275, 0.8018, 0.0678, 0.6221, 0.2982, 0.1524, 0.2613, - 0.7425, 0.6036, 0.7926, 0.5838, 0.1361, 0.4154, 0.3634, 0.3741, 0.2088, 0.2989, - 0.3982, 0.5618, 0.7266, 0.1089, 0.2922, 0.3306, 0.2869, 0.6638, 0.3091, 0.9312, - 0.0240, 0.2893, 0.5632, 0.9625, 0.4189, 0.3854, 0.2743, 0.6754, 0.8820, 0.8699}); - - NDArray expected = NDArrayFactory::create('c', {2, 9, 9, 1}, { - 0.2303f, 0.54569f, 0.840649f, 0.92725444f, 0.65660673f, - 0.16641647f, 0.06117659f, 0.33279106f, 0.4023279f, 0.5139505f, - 0.49821317f, 0.4906872f, 0.537642f, 0.4070102f, 0.13030615f, - 0.258801f, 0.65352744f, 0.773368f, 0.69225276f, 0.44177493f, - 0.21910316f, 0.22368976f, 0.24221404f, 0.21399781f, 0.5114972f, - 0.9169859f, 1.0511527f, 0.5608501f, 0.41315168f, 0.2913824f, - 0.2966933f, 0.38585684f, 0.48849702f, 0.71013063f, 0.9086001f, - 0.9794303f, 0.29625386f, 0.39427578f, 0.45971435f, 0.39693952f, - 0.40860707f, 0.51061106f, 0.6181093f, 0.67309624f, 0.69564015f, - 0.06012487f, 0.3863805f, 0.58993465f, 0.40679216f, 0.22607432f, - 0.20093678f, 0.25901243f, 0.3615362f, 0.39371052f, 0.24176767f, - 0.4868709f, 0.650651f, 0.5493148f, 0.3825456f, 0.27788478f, - 0.18927254f, 0.16692996f, 0.15432167f, 0.677519f, 0.6236242f, - 0.61700624f, 0.7214321f, 0.7307374f, 0.6251454f, 0.3924176f, - 0.17802659f, 0.10231908f, 0.81192374f, 0.66878575f, 0.6118803f, - 0.7797006f, 0.8396968f, 0.72889954f, 0.44547448f, 0.16794783f, - 0.07125802f, 0.4154f, 0.38504714f, 0.3623221f, 0.3862173f, - 0.3397379f, 0.23285517f, 0.21876639f, 0.2892362f, 0.30817088f, - 0.41268015f, 0.45587808f, 0.51991886f, 0.60977113f, 0.49489656f, - 0.21313031f, 0.11297428f, 0.2167207f, 0.23940037f, 0.39337245f, - 0.46112412f, 0.583034f, 0.76207364f, 0.6326203f, 0.22189438f, - 0.12071565f, 0.3275853f, 0.3794855f, 0.38497013f, 0.35049653f, - 0.41895086f, 0.671095f, 0.62119365f, 0.22362521f, 0.30189657f, - 0.72530353f, 0.85048175f, 0.2524255f, 0.2182264f, 0.2964637f, - 0.5361996f, 0.6255393f, 0.46424767f, 0.5741281f, 0.8408146f, - 0.92403257f, 0.04648584f, 0.14959256f, 0.32215607f, 0.46194845f, - 0.6642166f, 0.83560026f, 0.7663391f, 0.5284251f, 0.4573109f, - 0.10357999f, 0.17442937f, 0.32116935f, 0.45530772f, 0.7163773f, - 0.9856574f, 0.8976148f, 0.5538923f, 0.45173654f, 0.34958175f, - 0.2680429f, 0.30470955f, 0.51233786f, 0.75128907f, 0.86736864f, - 0.8982046f, 0.83254474f, 0.8168574f, 0.4225865f, 0.2956836f, - 0.29948136f, 0.5276342f, 0.76461166f, 0.8442875f, 0.907862f, - 0.9139262f, 0.92068815f - }); - auto size = NDArrayFactory::create({9, 9}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Resized to 9x9"); -// expected.printBuffer("Expect for 9x9"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create( + 'c', {2, 5, 5, 1}, + {0.2303, 0.7950, 0.8171, 0.0451, 0.3690, 0.6846, 0.2727, 0.2770, 0.2381, + 0.9511, 0.4116, 0.3997, 0.4075, 0.6275, 0.8018, 0.0678, 0.6221, 0.2982, + 0.1524, 0.2613, 0.7425, 0.6036, 0.7926, 0.5838, 0.1361, 0.4154, 0.3634, + 0.3741, 0.2088, 0.2989, 0.3982, 0.5618, 0.7266, 0.1089, 0.2922, 0.3306, + 0.2869, 0.6638, 0.3091, 0.9312, 0.0240, 0.2893, 0.5632, 0.9625, 0.4189, + 0.3854, 0.2743, 0.6754, 0.8820, 0.8699}); + + NDArray expected = NDArrayFactory::create( + 'c', {2, 9, 9, 1}, + {0.2303f, 0.54569f, 0.840649f, 0.92725444f, 0.65660673f, + 0.16641647f, 0.06117659f, 0.33279106f, 0.4023279f, 0.5139505f, + 0.49821317f, 0.4906872f, 0.537642f, 0.4070102f, 0.13030615f, + 0.258801f, 0.65352744f, 0.773368f, 0.69225276f, 0.44177493f, + 0.21910316f, 0.22368976f, 0.24221404f, 0.21399781f, 0.5114972f, + 0.9169859f, 1.0511527f, 0.5608501f, 0.41315168f, 0.2913824f, + 0.2966933f, 0.38585684f, 0.48849702f, 0.71013063f, 0.9086001f, + 0.9794303f, 0.29625386f, 0.39427578f, 0.45971435f, 0.39693952f, + 0.40860707f, 0.51061106f, 0.6181093f, 0.67309624f, 0.69564015f, + 0.06012487f, 0.3863805f, 0.58993465f, 0.40679216f, 0.22607432f, + 0.20093678f, 0.25901243f, 0.3615362f, 0.39371052f, 0.24176767f, + 0.4868709f, 0.650651f, 0.5493148f, 0.3825456f, 0.27788478f, + 0.18927254f, 0.16692996f, 0.15432167f, 0.677519f, 0.6236242f, + 0.61700624f, 0.7214321f, 0.7307374f, 0.6251454f, 0.3924176f, + 0.17802659f, 0.10231908f, 0.81192374f, 0.66878575f, 0.6118803f, + 0.7797006f, 0.8396968f, 0.72889954f, 0.44547448f, 0.16794783f, + 0.07125802f, 0.4154f, 0.38504714f, 0.3623221f, 0.3862173f, + 0.3397379f, 0.23285517f, 0.21876639f, 0.2892362f, 0.30817088f, + 0.41268015f, 0.45587808f, 0.51991886f, 0.60977113f, 0.49489656f, + 0.21313031f, 0.11297428f, 0.2167207f, 0.23940037f, 0.39337245f, + 0.46112412f, 0.583034f, 0.76207364f, 0.6326203f, 0.22189438f, + 0.12071565f, 0.3275853f, 0.3794855f, 0.38497013f, 0.35049653f, + 0.41895086f, 0.671095f, 0.62119365f, 0.22362521f, 0.30189657f, + 0.72530353f, 0.85048175f, 0.2524255f, 0.2182264f, 0.2964637f, + 0.5361996f, 0.6255393f, 0.46424767f, 0.5741281f, 0.8408146f, + 0.92403257f, 0.04648584f, 0.14959256f, 0.32215607f, 0.46194845f, + 0.6642166f, 0.83560026f, 0.7663391f, 0.5284251f, 0.4573109f, + 0.10357999f, 0.17442937f, 0.32116935f, 0.45530772f, 0.7163773f, + 0.9856574f, 0.8976148f, 0.5538923f, 0.45173654f, 0.34958175f, + 0.2680429f, 0.30470955f, 0.51233786f, 0.75128907f, 0.86736864f, + 0.8982046f, 0.83254474f, 0.8168574f, 0.4225865f, 0.2956836f, + 0.29948136f, 0.5276342f, 0.76461166f, 0.8442875f, 0.907862f, + 0.9139262f, 0.92068815f}); + auto size = NDArrayFactory::create({9, 9}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Resized to 9x9"); + // expected.printBuffer("Expect for 9x9"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) { - - NDArray input = NDArrayFactory::create('c', {2, 5, 5, 1}, { - 0.23028551377579154, 0.7949972231516509, 0.8171307820461517, 0.04507309923418412, 0.3689673597428338, - 0.6845757584903018, 0.27268547668219667, 0.2770196372806053, 0.2381478370531429, 0.9511201914609859, - 0.41160882670429033, 0.3997152563642703, 0.4074505147711718, 0.6274595060113246, 0.8017922711300232, - 0.06782045852179475, 0.6220772280691722, 0.2982335327629251, 0.1523603480424196, 0.2612986044295986, - 0.7424762244324299, 0.6036156464824591, 0.7926371071102005, 0.5838270656432538, 0.13607200219168547, - 0.4154002170215956, 0.36340617544852116, 0.37405031188276827, 0.20880251686544882, 0.298919946410666, - 0.39820758164277126, 0.5617728968896589, 0.72660225993937, 0.10888245916813699, 0.29215797784445496, - 0.3305531351746034, 0.28693451964931715, 0.6637635348315494, 0.30913418229827583, 0.9312186188801752, - 0.0239594182399363, 0.2892942758780874, 0.5631691110629038, 0.9625499752246309, 0.4189439089689968, - 0.3854304088214935, 0.27426304203925045, 0.6754051704648238, 0.8820362490795286, 0.8699337744328859}); - - - auto testData = NDArrayFactory::create('c', {2,9,9,1}, { - 0.230286514f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f, - 0.483021289f, 0.501266003f, 0.521932304f, 0.572325349f, 0.534847379f, 0.267853439f, 0.105112493f, 0.349290252f, 0.674043298f, - 0.684575737f, 0.478224277f, 0.272685468f, 0.239882097f, 0.27701965f, 0.191148892f, 0.23814784f, 0.590989769f, 0.951120198f, - 0.622912169f, 0.441326082f, 0.266387194f, 0.232538164f, 0.301838756f, 0.356378645f, 0.495445013f, 0.756725252f, 0.981704295f, - 0.411608815f, 0.40493685f, 0.399715245f, 0.381842017f, 0.407450527f, 0.501836538f, 0.627459526f, 0.735251725f, 0.801792264f, - 0.150875032f, 0.357000858f, 0.524536073f, 0.450354964f, 0.318719596f, 0.319606483f, 0.385957927f, 0.46392554f, 0.529285908f, - 0.06782046f, 0.375309169f, 0.622077227f, 0.525792599f, 0.298233539f, 0.184723631f, 0.15236035f, 0.193153858f, 0.261298597f, - - 0.372918189f, 0.512539625f, 0.63369292f, 0.628733814f, 0.535196245f, 0.436597466f, 0.323553175f, 0.215942055f, 0.148014024f, - 0.742476225f, 0.655325174f, 0.603615642f, 0.704684138f, 0.79263711f, 0.747929871f, 0.583827078f, 0.340373576f, 0.136071995f, - 0.415400207f, 0.388405323f, 0.363406181f, 0.379345775f, 0.374050319f, 0.28397581f, 0.208802521f, 0.238369256f, 0.298919946f, - 0.413146496f, 0.444389015f, 0.488355637f, 0.568351328f, 0.556217432f, 0.345546633f, 0.140068889f, 0.148834035f, 0.23562704f, - 0.398207575f, 0.464537472f, 0.561772883f, 0.717433035f, 0.726602256f, 0.416013002f, 0.108882457f, 0.142608985f, 0.292157978f, - 0.391511708f, 0.389470309f, 0.442729384f, 0.651181757f, 0.737665415f, 0.41685915f, 0.138383076f, 0.342548877f, 0.659080088f, - - 0.330553144f, 0.273416102f, 0.286934525f, 0.50450629f, 0.663763523f, 0.463456154f, 0.309134185f, 0.586929917f, 0.931218624f, - 0.137025774f, 0.169145152f, 0.263757467f, 0.436182201f, 0.597053051f, 0.657990932f, 0.662163854f, 0.68354249f, 0.692712903f, - 0.023959421f, 0.130951077f, 0.289294273f, 0.413664877f, 0.563169122f, 0.839498401f, 0.962549984f, 0.728188932f, 0.418943912f, - 0.175951749f, 0.198239252f, 0.281999886f, 0.420836329f, 0.609856486f, 0.863734365f, 0.983550847f, 0.825015843f, 0.596413136f, - 0.385430396f, 0.292239636f, 0.274263054f, 0.445040524f, 0.675405145f, 0.817462444f, 0.882036269f, 0.895356655f, 0.869933784f - }); - - auto size = NDArrayFactory::create({9, 9}); - sd::ops::resize_bicubic op; - auto results = op.evaluate({&input, &size}, {}, {}, {true, false}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Resized to 9x9"); -// testData.printBuffer("Expect for 9x9"); - ASSERT_TRUE(testData.isSameShape(result)); - ASSERT_TRUE(testData.equalsTo(result)); + NDArray input = NDArrayFactory::create( + 'c', {2, 5, 5, 1}, + {0.23028551377579154, 0.7949972231516509, 0.8171307820461517, + 0.04507309923418412, 0.3689673597428338, 0.6845757584903018, + 0.27268547668219667, 0.2770196372806053, 0.2381478370531429, + 0.9511201914609859, 0.41160882670429033, 0.3997152563642703, + 0.4074505147711718, 0.6274595060113246, 0.8017922711300232, + 0.06782045852179475, 0.6220772280691722, 0.2982335327629251, + 0.1523603480424196, 0.2612986044295986, 0.7424762244324299, + 0.6036156464824591, 0.7926371071102005, 0.5838270656432538, + 0.13607200219168547, 0.4154002170215956, 0.36340617544852116, + 0.37405031188276827, 0.20880251686544882, 0.298919946410666, + 0.39820758164277126, 0.5617728968896589, 0.72660225993937, + 0.10888245916813699, 0.29215797784445496, 0.3305531351746034, + 0.28693451964931715, 0.6637635348315494, 0.30913418229827583, + 0.9312186188801752, 0.0239594182399363, 0.2892942758780874, + 0.5631691110629038, 0.9625499752246309, 0.4189439089689968, + 0.3854304088214935, 0.27426304203925045, 0.6754051704648238, + 0.8820362490795286, 0.8699337744328859}); + + auto testData = NDArrayFactory::create( + 'c', {2, 9, 9, 1}, + {0.230286514f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, + 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f, 0.483021289f, + 0.501266003f, 0.521932304f, 0.572325349f, 0.534847379f, 0.267853439f, + 0.105112493f, 0.349290252f, 0.674043298f, 0.684575737f, 0.478224277f, + 0.272685468f, 0.239882097f, 0.27701965f, 0.191148892f, 0.23814784f, + 0.590989769f, 0.951120198f, 0.622912169f, 0.441326082f, 0.266387194f, + 0.232538164f, 0.301838756f, 0.356378645f, 0.495445013f, 0.756725252f, + 0.981704295f, 0.411608815f, 0.40493685f, 0.399715245f, 0.381842017f, + 0.407450527f, 0.501836538f, 0.627459526f, 0.735251725f, 0.801792264f, + 0.150875032f, 0.357000858f, 0.524536073f, 0.450354964f, 0.318719596f, + 0.319606483f, 0.385957927f, 0.46392554f, 0.529285908f, 0.06782046f, + 0.375309169f, 0.622077227f, 0.525792599f, 0.298233539f, 0.184723631f, + 0.15236035f, 0.193153858f, 0.261298597f, + + 0.372918189f, 0.512539625f, 0.63369292f, 0.628733814f, 0.535196245f, + 0.436597466f, 0.323553175f, 0.215942055f, 0.148014024f, 0.742476225f, + 0.655325174f, 0.603615642f, 0.704684138f, 0.79263711f, 0.747929871f, + 0.583827078f, 0.340373576f, 0.136071995f, 0.415400207f, 0.388405323f, + 0.363406181f, 0.379345775f, 0.374050319f, 0.28397581f, 0.208802521f, + 0.238369256f, 0.298919946f, 0.413146496f, 0.444389015f, 0.488355637f, + 0.568351328f, 0.556217432f, 0.345546633f, 0.140068889f, 0.148834035f, + 0.23562704f, 0.398207575f, 0.464537472f, 0.561772883f, 0.717433035f, + 0.726602256f, 0.416013002f, 0.108882457f, 0.142608985f, 0.292157978f, + 0.391511708f, 0.389470309f, 0.442729384f, 0.651181757f, 0.737665415f, + 0.41685915f, 0.138383076f, 0.342548877f, 0.659080088f, + + 0.330553144f, 0.273416102f, 0.286934525f, 0.50450629f, 0.663763523f, + 0.463456154f, 0.309134185f, 0.586929917f, 0.931218624f, 0.137025774f, + 0.169145152f, 0.263757467f, 0.436182201f, 0.597053051f, 0.657990932f, + 0.662163854f, 0.68354249f, 0.692712903f, 0.023959421f, 0.130951077f, + 0.289294273f, 0.413664877f, 0.563169122f, 0.839498401f, 0.962549984f, + 0.728188932f, 0.418943912f, 0.175951749f, 0.198239252f, 0.281999886f, + 0.420836329f, 0.609856486f, 0.863734365f, 0.983550847f, 0.825015843f, + 0.596413136f, 0.385430396f, 0.292239636f, 0.274263054f, 0.445040524f, + 0.675405145f, 0.817462444f, 0.882036269f, 0.895356655f, 0.869933784f}); + + auto size = NDArrayFactory::create({9, 9}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}, {true, false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Resized to 9x9"); + // testData.printBuffer("Expect for 9x9"); + ASSERT_TRUE(testData.isSameShape(result)); + ASSERT_TRUE(testData.equalsTo(result)); } TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) { + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 4}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 6, 6, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 9.f, 10.f, 11.f, 12.f, - NDArray input = NDArrayFactory::create('c', {1, 3, 3, 4}); - NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 4}, { - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, - 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, - 9.f, 10.f, 11.f, 12.f, - - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, - 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, - 9.f, 10.f, 11.f, 12.f, - - 13.f, 14.f, 15.f, 16.f, - 13.f, 14.f, 15.f, 16.f, - 17.f, 18.f, 19.f, 20.f, - 17.f, 18.f, 19.f, 20.f, - 21.f, 22.f, 23.f, 24.f, - 21.f, 22.f, 23.f, 24.f, - - 13.f, 14.f, 15.f, 16.f, - 13.f, 14.f, 15.f, 16.f, - 17.f, 18.f, 19.f, 20.f, - 17.f, 18.f, 19.f, 20.f, - 21.f, 22.f, 23.f, 24.f, - 21.f, 22.f, 23.f, 24.f, - - 25.f, 26.f, 27.f, 28.f, - 25.f, 26.f, 27.f, 28.f, - 29.f, 30.f, 31.f, 32.f, - 29.f, 30.f, 31.f, 32.f, - 33.f, 34.f, 35.f, 36.f, - 33.f, 34.f, 35.f, 36.f, - - 25.f, 26.f, 27.f, 28.f, - 25.f, 26.f, 27.f, 28.f, - 29.f, 30.f, 31.f, 32.f, - 29.f, 30.f, 31.f, 32.f, - 33.f, 34.f, 35.f, 36.f, - 33.f, 34.f, 35.f, 36.f }); - input.linspace(1); - auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); -} + 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 9.f, 10.f, 11.f, 12.f, -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test2) { + 13.f, 14.f, 15.f, 16.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 21.f, 22.f, 23.f, 24.f, - NDArray input = NDArrayFactory::create('c', {1, 3, 3, 1}); - NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 1}, { - 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, - 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, - 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, - 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, - 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, - 7.f, 7.f, 8.f, 8.f, 9.f, 9.f - }); - input.linspace(1); - auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {}); + 13.f, 14.f, 15.f, 16.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 21.f, 22.f, 23.f, 24.f, - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + 25.f, 26.f, 27.f, 28.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, + 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 33.f, 34.f, 35.f, 36.f, - auto result = results.at(0); + 25.f, 26.f, 27.f, 28.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, + 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 33.f, 34.f, 35.f, 36.f}); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); -} - - -TEST_F(DeclarableOpsTests11, ImageResizeArea_Test3) { + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray input = NDArrayFactory::create('c', {1, 3, 3, 3}); - NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 3}, { - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f - }); - input.linspace(1); - auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {}); + auto result = results.at(0); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} - auto result = results.at(0); +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test2) { + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 1}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 6, 6, 1}, + {1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, + 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, + 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f}); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test3) { + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 3}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 6, 6, 3}, + {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, + 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f}); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests11, ImageResizeArea_Test4) { - - NDArray input = NDArrayFactory::create('c', {2, 3, 3, 3}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27 - }); - - NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 3}, { - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, - - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f - }); - //input.linspace(1); - auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create( + 'c', {2, 3, 3, 3}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27}); + + NDArray expected = NDArrayFactory::create( + 'c', {2, 6, 6, 3}, + {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, + 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, + 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f}); + // input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests11, ImageResizeArea_Test5) { - - NDArray input = NDArrayFactory::create('c', {2, 3, 3, 3}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27 - }); - - NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 3}, { - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, - - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f - }); - //input.linspace(1); - auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create( + 'c', {2, 3, 3, 3}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27}); + + NDArray expected = NDArrayFactory::create( + 'c', {2, 6, 6, 3}, + {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, + 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, + 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f}); + // input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests11, ImageResizeArea_Test6) { - - NDArray input = NDArrayFactory::create('c', {2, 3, 3, 1}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, - 1, 2, 3, 4, 5, 6, 7, 8, 9 - }); - - NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 1}, { - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f, - - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f - }); - //input.linspace(1); - auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create( + 'c', {2, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + + NDArray expected = NDArrayFactory::create( + 'c', {2, 6, 6, 1}, + {1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f, + + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f}); + // input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests11, ImageResizeArea_Test7) { - - NDArray input = NDArrayFactory::create('c', {2, 3, 3, 1}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, - 1, 2, 3, 4, 5, 6, 7, 8, 9 - }); - - NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 1}, { - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f, - - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f - }); - //input.linspace(1); -// auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input}, {}, {6, 6}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create( + 'c', {2, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + + NDArray expected = NDArrayFactory::create( + 'c', {2, 6, 6, 1}, + {1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f, + + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f}); + // input.linspace(1); + // auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {6, 6}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests11, ImageResizeArea_Test8) { - - NDArray input = NDArrayFactory::create('c', {1, 3, 3, 1}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9 - }); - - NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 1}, { - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, - 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, - 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f - }); - //input.linspace(1); -// auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input}, {}, {6, 6}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Area Resized to 6x6"); -// expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}); + + NDArray expected = NDArrayFactory::create( + 'c', {1, 6, 6, 1}, + {1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f}); + // input.linspace(1); + // auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {6, 6}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x6"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, ImageResizeArea_Test9) { - - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 - }); - - NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, { - 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 15.666661f, 16.666662f, 17.666660f, 18.666660f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 18.333334f, 19.333332f, 20.333334f, 21.333332f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999989f, 21.999989f, 22.999987f, 23.999987f - - }); - //input.linspace(1); - auto size = NDArrayFactory::create({10, 10}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Area Resized to 10x10"); - // expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create( + 'c', {1, 2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + + NDArray expected = NDArrayFactory::create( + 'c', {1, 10, 10, 4}, + {1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, + 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, + 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, + 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, + 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, + 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, + 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, + 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, + 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, + 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, + 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, + 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, + 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, + 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, + 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, + 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, + 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, + 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, + 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, + 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, + 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, + 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, + 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, + 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, + 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, + 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, + 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, + 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, + 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, + 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, + 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, + 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, + 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, + 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, + 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, + 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, + 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, + 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, + 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, + 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, + 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, + 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, + 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, + 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, + 15.666661f, 16.666662f, 17.666660f, 18.666660f, 16.999994f, 17.999994f, + 18.999992f, 19.999992f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, + 18.333334f, 19.333332f, 20.333334f, 21.333332f, 20.999992f, 21.999992f, + 22.999990f, 23.999992f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, + 20.999989f, 21.999989f, 22.999987f, 23.999987f + + }); + // input.linspace(1); + auto size = NDArrayFactory::create({10, 10}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 10x10"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, ImageResizeArea_Test10) { - - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 - }); - - NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, { - 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 15.666661f, 16.666662f, 17.666660f, 18.666660f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 18.333334f, 19.333332f, 20.333334f, 21.333332f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999989f, 21.999989f, 22.999987f, 23.999987f - - }); - //input.linspace(1); - //auto size = NDArrayFactory::create({10, 10}); - sd::ops::resize_area op; - auto results = op.evaluate({&input}, {}, {10, 10}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Area Resized to 10x10"); - // expected.printBuffer("Area Expect for 6x6"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create( + 'c', {1, 2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + + NDArray expected = NDArrayFactory::create( + 'c', {1, 10, 10, 4}, + {1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, + 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, + 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, + 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, + 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, + 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, + 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, + 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, + 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, + 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, + 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, + 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, + 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, + 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, + 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, + 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, + 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, + 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, + 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, + 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, + 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, + 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, + 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, + 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, + 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, + 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, + 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, + 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, + 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, + 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, + 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, + 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, + 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, + 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, + 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, + 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, + 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, + 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, + 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, + 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, + 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, + 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, + 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, + 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, + 15.666661f, 16.666662f, 17.666660f, 18.666660f, 16.999994f, 17.999994f, + 18.999992f, 19.999992f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, + 18.333334f, 19.333332f, 20.333334f, 21.333332f, 20.999992f, 21.999992f, + 22.999990f, 23.999992f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, + 20.999989f, 21.999989f, 22.999987f, 23.999987f + + }); + // input.linspace(1); + // auto size = NDArrayFactory::create({10, 10}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {10, 10}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 10x10"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, ImageResizeArea_Test11) { - - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 - }); - -// NDArray expected = NDArrayFactory::create('c', {1, 6, 9, 4}, { -// 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 -// -// }); - //input.linspace(1); - //auto size = NDArrayFactory::create({10, 10}); - sd::ops::resize_area op; - auto results = op.evaluate({&input}, {}, {6, 9}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Area Resized to 6x9"); - // expected.printBuffer("Area Expect for 6x6"); -// ASSERT_TRUE(expected.isSameShape(result)); -// ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create( + 'c', {1, 2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + + // NDArray expected = NDArrayFactory::create('c', {1, 6, 9, 4}, { + // 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, + // 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, + // 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, + // 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, + // 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, + // 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, + // 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, + // 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, + // 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, + // 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, + // 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, + // 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, + // 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, + // 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, + // 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, + // 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, + // 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, + // 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, + // 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, + // 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, + // 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, + // 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, + // 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, + // 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, + // 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, + // 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, + // 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, + // 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, + // 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, + // 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, + // 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, + // 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, + // 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, + // 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, + // 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, + // 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, + // 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, + // 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, + // 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, + // 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, + // 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, + // 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, + // 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, + // 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, + // 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, + // 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, + // 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, + // 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, + // 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, + // 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, + // 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, + // 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, + // 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, + // 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, + // 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, + // 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, + // 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, + // 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, + // 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, + // 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, + // 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, + // 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 + // + // }); + // input.linspace(1); + // auto size = NDArrayFactory::create({10, 10}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {6, 9}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x9"); + // expected.printBuffer("Area Expect for 6x6"); + // ASSERT_TRUE(expected.isSameShape(result)); + // ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, ImageResizeArea_Test12) { - - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 - }); - -// NDArray expected = NDArrayFactory::create('c', {1, 6, 9, 4}, { -// 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 -// -// }); - //input.linspace(1); - //auto size = NDArrayFactory::create({10, 10}); - sd::ops::resize_area op; - auto results = op.evaluate({&input}, {}, {10, 15}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Area Resized to 6x9"); - // expected.printBuffer("Area Expect for 6x6"); -// ASSERT_TRUE(expected.isSameShape(result)); -// ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create( + 'c', {1, 2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + + // NDArray expected = NDArrayFactory::create('c', {1, 6, 9, 4}, { + // 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, + // 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, + // 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, + // 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, + // 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, + // 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, + // 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, + // 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, + // 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, + // 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, + // 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, + // 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, + // 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, + // 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, + // 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, + // 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, + // 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, + // 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, + // 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, + // 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, + // 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, + // 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, + // 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, + // 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, + // 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, + // 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, + // 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, + // 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, + // 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, + // 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, + // 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, + // 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, + // 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, + // 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, + // 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, + // 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, + // 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, + // 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, + // 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, + // 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, + // 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, + // 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, + // 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, + // 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, + // 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, + // 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, + // 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, + // 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, + // 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, + // 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, + // 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, + // 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, + // 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, + // 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, + // 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, + // 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, + // 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, + // 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, + // 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, + // 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, + // 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, + // 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 + // + // }); + // input.linspace(1); + // auto size = NDArrayFactory::create({10, 10}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {10, 15}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 6x9"); + // expected.printBuffer("Area Expect for 6x6"); + // ASSERT_TRUE(expected.isSameShape(result)); + // ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) { - - NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 - }); - -// NDArray expected = NDArrayFactory::create('c', {1, 8, 8, 4}, { -// 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 -// -// }); - //input.linspace(1); - //auto size = NDArrayFactory::create({10, 10}); - sd::ops::resize_area op; - auto results = op.evaluate({&input}, {}, {9, 9}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - -// result.printBuffer("Area Resized to 8x8"); - // expected.printBuffer("Area Expect for 6x6"); -// ASSERT_TRUE(expected.isSameShape(result)); -// ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create( + 'c', {1, 2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + + // NDArray expected = NDArrayFactory::create('c', {1, 8, 8, 4}, { + // 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, + // 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, + // 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, + // 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, + // 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, + // 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, + // 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, + // 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, + // 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, + // 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, + // 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, + // 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, + // 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, + // 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, + // 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, + // 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, + // 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, + // 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, + // 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, + // 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, + // 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, + // 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, + // 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, + // 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, + // 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, + // 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, + // 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, + // 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, + // 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, + // 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, + // 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, + // 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, + // 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, + // 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, + // 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, + // 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, + // 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, + // 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, + // 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, + // 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, + // 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, + // 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, + // 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, + // 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, + // 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, + // 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, + // 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, + // 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, + // 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, + // 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, + // 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, + // 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, + // 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, + // 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, + // 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, + // 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, + // 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, + // 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, + // 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, + // 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, + // 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, + // 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 + // + // }); + // input.linspace(1); + // auto size = NDArrayFactory::create({10, 10}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {9, 9}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + // result.printBuffer("Area Resized to 8x8"); + // expected.printBuffer("Area Expect for 6x6"); + // ASSERT_TRUE(expected.isSameShape(result)); + // ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests11, ImageResizeArea_Test14) { - - NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 - }); - auto size = NDArrayFactory::create({8, 7}); - NDArray expected = NDArrayFactory::create('c', {1, 8, 7, 1}, { - 1.f, 1.6f , 2.1999993f, 2.9999995f , 3.8f , 4.399997f, 5.f , 2.9999995f , 3.5999997f , 4.199999f, - 4.9999995f, 5.8f , 6.3999963f , 7.f , 5.999999f , 6.6f, 7.1999984f , 7.9999995f , 8.8f, - 9.399994f, 10.f , 10.f, 10.6f , 11.199998f, 12.f, 12.8f, 13.399992f , 14.f, 12.f , 12.599999f, - 13.199998f , 13.999998f , 14.800002f , 15.399991f , 16.f , 15.999999f , 16.599998f , 17.199995f, - 18.f , 18.800003f , 19.399986f , 20.000002f , 19.f , 19.599998f , 20.199997f , - 20.999998f , 21.800003f , 22.399984f , 23.000002f , 20.999998f , - 21.599998f , 22.199995f , 22.999998f , 23.800001f , 24.399984f , - 25.f - }); //input.linspace(1); -// auto size = NDArrayFactory::create({6, 6}); - sd::ops::resize_area op; - auto results = op.evaluate({&input, &size}, {}, {false}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printBuffer("Area Resized to 8x7"); -// expected.printBuffer("Area Expect for 8x7"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create( + 'c', {1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}); + auto size = NDArrayFactory::create({8, 7}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 8, 7, 1}, + {1.f, 1.6f, 2.1999993f, 2.9999995f, 3.8f, 4.399997f, + 5.f, 2.9999995f, 3.5999997f, 4.199999f, 4.9999995f, 5.8f, + 6.3999963f, 7.f, 5.999999f, 6.6f, 7.1999984f, 7.9999995f, + 8.8f, 9.399994f, 10.f, 10.f, 10.6f, 11.199998f, + 12.f, 12.8f, 13.399992f, 14.f, 12.f, 12.599999f, + 13.199998f, 13.999998f, 14.800002f, 15.399991f, 16.f, 15.999999f, + 16.599998f, 17.199995f, 18.f, 18.800003f, 19.399986f, 20.000002f, + 19.f, 19.599998f, 20.199997f, 20.999998f, 21.800003f, 22.399984f, + 23.000002f, 20.999998f, 21.599998f, 22.199995f, 22.999998f, 23.800001f, + 24.399984f, 25.f}); // input.linspace(1); + // auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printBuffer("Area Resized to 8x7"); + // expected.printBuffer("Area Expect for 8x7"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) { - - NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 - }); - //auto size = NDArrayFactory::create({8, 7}); - NDArray expected = NDArrayFactory::create('c', {1, 8, 7, 1}, { - 1.f, 1.6f , 2.1999993f, 2.9999995f , 3.8f , 4.399997f, 5.f , 2.9999995f , 3.5999997f , 4.199999f, - 4.9999995f, 5.8f , 6.3999963f , 7.f , 5.999999f , 6.6f, 7.1999984f , 7.9999995f , 8.8f, - 9.399994f, 10.f , 10.f, 10.6f , 11.199998f, 12.f, 12.8f, 13.399992f , 14.f, 12.f , 12.599999f, - 13.199998f , 13.999998f , 14.800002f , 15.399991f , 16.f , 15.999999f , 16.599998f , 17.199995f, - 18.f , 18.800003f , 19.399986f , 20.000002f , 19.f , 19.599998f , 20.199997f , - 20.999998f , 21.800003f , 22.399984f , 23.000002f , 20.999998f , 21.599998f , 22.199995f , - 22.999998f , 23.800001f , 24.399984f , 25.f - }); - - sd::ops::resize_area op; - auto results = op.evaluate({&input}, {}, {8, 7}, {false}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printBuffer("Area Resized to 8x7"); -// expected.printBuffer("Area Expect for 8x7"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + NDArray input = NDArrayFactory::create( + 'c', {1, 5, 5, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}); + // auto size = NDArrayFactory::create({8, 7}); + NDArray expected = NDArrayFactory::create( + 'c', {1, 8, 7, 1}, + {1.f, 1.6f, 2.1999993f, 2.9999995f, 3.8f, 4.399997f, + 5.f, 2.9999995f, 3.5999997f, 4.199999f, 4.9999995f, 5.8f, + 6.3999963f, 7.f, 5.999999f, 6.6f, 7.1999984f, 7.9999995f, + 8.8f, 9.399994f, 10.f, 10.f, 10.6f, 11.199998f, + 12.f, 12.8f, 13.399992f, 14.f, 12.f, 12.599999f, + 13.199998f, 13.999998f, 14.800002f, 15.399991f, 16.f, 15.999999f, + 16.599998f, 17.199995f, 18.f, 18.800003f, 19.399986f, 20.000002f, + 19.f, 19.599998f, 20.199997f, 20.999998f, 21.800003f, 22.399984f, + 23.000002f, 20.999998f, 21.599998f, 22.199995f, 22.999998f, 23.800001f, + 24.399984f, 25.f}); + + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {8, 7}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printBuffer("Area Resized to 8x7"); + // expected.printBuffer("Area Expect for 8x7"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, summaryStatsData_test1) { + functions::summarystats::SummaryStatsData var1; + functions::summarystats::SummaryStatsData var2; + var2.n = var2.mean = var2.M2 = var2.M3 = var2.M4 = var2.bias = 5; - functions::summarystats::SummaryStatsData var1; - functions::summarystats::SummaryStatsData var2; - var2.n = var2.mean = var2.M2 = var2.M3 = var2.M4 = var2.bias = 5; + functions::summarystats::SummaryStatsData* arr = + new functions::summarystats::SummaryStatsData[2]; + arr[0] = var1; + arr[1] = var2; + arr[0] = arr[1]; - functions::summarystats::SummaryStatsData* arr = new functions::summarystats::SummaryStatsData[2]; - arr[0] = var1; - arr[1] = var2; - arr[0] = arr[1]; + functions::summarystats::SummaryStatsData var3(var1); - functions::summarystats::SummaryStatsData var3(var1); + ASSERT_TRUE(arr[0].n == arr[0].mean && arr[0].M2 == arr[0].M3 && + arr[0].n == 5); + ASSERT_TRUE(arr[1].n == arr[1].mean && arr[1].M2 == arr[1].M3 && + arr[1].n == 5); + ASSERT_TRUE(var3.n == var3.mean && var3.M2 == var3.M3 && var3.n == 0); - ASSERT_TRUE(arr[0].n == arr[0].mean && arr[0].M2 == arr[0].M3 && arr[0].n == 5); - ASSERT_TRUE(arr[1].n == arr[1].mean && arr[1].M2 == arr[1].M3 && arr[1].n == 5); - ASSERT_TRUE(var3.n == var3.mean && var3.M2 == var3.M3 && var3.n == 0); - - delete []arr; + delete[] arr; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Solve_Test_1) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, {2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f}); - auto a = NDArrayFactory::create('c', {3, 3}, { - 2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f - }); - - auto b = NDArrayFactory::create('c', {3, 1}, { - 2.f, 4.f, 3.f - }); + auto b = NDArrayFactory::create('c', {3, 1}, {2.f, 4.f, 3.f}); - auto exp = NDArrayFactory::create('c', {3, 1}, { - 7.625f, 3.25f, 5.f - }); + auto exp = NDArrayFactory::create('c', {3, 1}, {7.625f, 3.25f, 5.f}); - sd::ops::solve op; + sd::ops::solve op; - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// z->printIndexedBuffer("Solve of 3x3"); + // z->printIndexedBuffer("Solve of 3x3"); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Solve_Test_2) { - - auto a = NDArrayFactory::create('c', {4, 4}, { - 1.f, 1.f, 1.f, 1.f, - 0.f, 1.f, 1.f, 0.f, - 0.f, 0.f, 2.f, 1.f, - 0.f, 0.f, 0.f, 3.f, - }); - - auto b = NDArrayFactory::create('c', {4, 1}, { - 2.f, 4.f, 2.f, 4.f - }); - - auto exp = NDArrayFactory::create('c', {4, 1}, { - -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f - }); - - sd::ops::solve op; - - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printIndexedBuffer("Solve 4x4"); - - ASSERT_TRUE(exp.equalsTo(z)); - + auto a = NDArrayFactory::create('c', {4, 4}, + { + 1.f, + 1.f, + 1.f, + 1.f, + 0.f, + 1.f, + 1.f, + 0.f, + 0.f, + 0.f, + 2.f, + 1.f, + 0.f, + 0.f, + 0.f, + 3.f, + }); + + auto b = NDArrayFactory::create('c', {4, 1}, {2.f, 4.f, 2.f, 4.f}); + + auto exp = NDArrayFactory::create( + 'c', {4, 1}, {-3.3333333f, 3.6666666f, 0.333333f, 1.3333333f}); + + sd::ops::solve op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + + // z->printIndexedBuffer("Solve 4x4"); + + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Solve_Test_3) { + auto a = NDArrayFactory::create( + 'c', {2, 4, 4}, + {1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 0.f, 0.f, + 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 3.f, - auto a = NDArrayFactory::create('c', {2, 4, 4}, { - 1.f, 1.f, 1.f, 1.f, - 0.f, 1.f, 1.f, 0.f, - 0.f, 0.f, 2.f, 1.f, - 0.f, 0.f, 0.f, 3.f, + 3.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 1.f, + 0.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f - 3.f, 0.f, 0.f, 0.f, - 2.f, 1.f, 0.f, 0.f, - 1.f, 0.f, 1.f, 0.f, - 1.f, 1.f, 1.f, 1.f + }); - }); + auto b = NDArrayFactory::create( + 'c', {2, 4, 1}, {2.f, 4.f, 2.f, 4.f, 4.f, 2.f, 4.f, 2.f}); - auto b = NDArrayFactory::create('c', {2, 4, 1}, { - 2.f, 4.f, 2.f, 4.f, - 4.f, 2.f, 4.f, 2.f - }); + auto exp = NDArrayFactory::create( + 'c', {2, 4, 1}, + {-3.3333333f, 3.6666666f, 0.333333f, 1.3333333f, 1.333333f, -0.6666667f, + 2.6666667f, -1.3333333f}); - auto exp = NDArrayFactory::create('c', {2, 4, 1}, { - -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f, - 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f - }); + sd::ops::solve op; - sd::ops::solve op; + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + // z->printIndexedBuffer("Solve 4x4"); -// z->printIndexedBuffer("Solve 4x4"); - - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Solve_Test_4) { + auto a = NDArrayFactory::create( + 'c', {2, 2, 2}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f}); - auto a = NDArrayFactory::create('c', {2, 2, 2}, { - 0.7788f, 0.8012f, 0.7244f, 0.2309f, - 0.7271f, 0.1804f, 0.5056f, 0.8925f - }); - - auto b = NDArrayFactory::create('c', {2, 2, 2}, { - 0.7717f, 0.9281f, 0.9846f, 0.4838f, - 0.6433f, 0.6041f, 0.6501f, 0.7612f - }); + auto b = NDArrayFactory::create( + 'c', {2, 2, 2}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, { -// 1.524494767f, 0.432706356f,-0.518630624f, 0.737760842f, -// 0.819143713f, 0.720401764f, 0.264349997f, 0.444699198f - 1.5245394f, 0.4326952f, -0.51873577f, 0.7377896f, - 0.81915987f, 0.72049433f, 0.2643504f, 0.44472617f - }); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2}, + {// 1.524494767f, 0.432706356f,-0.518630624f, 0.737760842f, + // 0.819143713f, 0.720401764f, 0.264349997f, 0.444699198f + 1.5245394f, 0.4326952f, -0.51873577f, 0.7377896f, 0.81915987f, + 0.72049433f, 0.2643504f, 0.44472617f}); - sd::ops::solve op; + sd::ops::solve op; - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// z->printBuffer("4 Solve 4x4"); -// exp.printBuffer("4 Expec 4x4"); + // z->printBuffer("4 Solve 4x4"); + // exp.printBuffer("4 Expec 4x4"); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Solve_Test_4_1) { + auto a = NDArrayFactory::create( + 'c', {2, 2, 2}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f}); - auto a = NDArrayFactory::create('c', {2, 2, 2}, { - 0.7788f, 0.8012f, 0.7244f, 0.2309f, - 0.7271f, 0.1804f, 0.5056f, 0.8925f - }); + auto b = NDArrayFactory::create( + 'c', {2, 2, 2}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f}); - auto b = NDArrayFactory::create('c', {2, 2, 2}, { - 0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f - }); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2}, + {1.3357621f, 0.3399364f, -0.37077796f, 0.91573375f, 0.4400987f, + 0.2766527f, 0.6394467f, 0.79696566f}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, { - 1.3357621f, 0.3399364f, -0.37077796f, 0.91573375f, - 0.4400987f, 0.2766527f, 0.6394467f, 0.79696566f - }); + sd::ops::solve op; - sd::ops::solve op; + auto res = op.evaluate({&a, &b}, {true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - auto res = op.evaluate({&a, &b}, {true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + // z->printBuffer("4 Solve 4x4"); + // exp.printBuffer("4 Expec 4x4"); -// z->printBuffer("4 Solve 4x4"); -// exp.printBuffer("4 Expec 4x4"); - - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Solve_Test_4_2) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f}); - auto a = NDArrayFactory::create('c', {3, 3}, { - 0.7788f, 0.8012f, 0.7244f, - 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f - }); - - auto b = NDArrayFactory::create('c', {3, 3}, { - 0.7717f, 0.9281f, 0.9846f, - 0.4838f, 0.6433f, 0.6041f, - 0.6501f, 0.7612f, 0.7605f - }); + auto b = NDArrayFactory::create( + 'c', {3, 3}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f}); - auto exp = NDArrayFactory::create('c', {3, 3}, { - 0.99088347f, 1.1917052f, 1.2642528f, - 0.35071516f, 0.50630623f, 0.42935497f, - -0.30013534f, -0.53690606f, -0.47959247f - }); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {0.99088347f, 1.1917052f, 1.2642528f, 0.35071516f, 0.50630623f, + 0.42935497f, -0.30013534f, -0.53690606f, -0.47959247f}); - sd::ops::triangular_solve op; + sd::ops::triangular_solve op; - auto res = op.evaluate({&a, &b}, {true, false}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}, {true, false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// z->printBuffer("4_2 Triangular_Solve 3x3"); -// exp.printBuffer("4_2 Triangular_Expec 3x3"); + // z->printBuffer("4_2 Triangular_Solve 3x3"); + // exp.printBuffer("4_2 Triangular_Expec 3x3"); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Solve_Test_4_3) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f}); - auto a = NDArrayFactory::create('c', {3, 3}, { - 0.7788f, 0.8012f, 0.7244f, - 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f - }); - - auto b = NDArrayFactory::create('c', {3, 3}, { - 0.7717f, 0.9281f, 0.9846f, - 0.4838f, 0.6433f, 0.6041f, - 0.6501f, 0.7612f, 0.7605f - }); + auto b = NDArrayFactory::create( + 'c', {3, 3}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f}); - auto exp = NDArrayFactory::create('c', {3, 3}, { - 0.45400196f, 0.53174824f, 0.62064564f, - -0.79585856f, -0.82621557f, -0.87855506f, - 1.1904413f, 1.3938838f, 1.3926021f - }); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {0.45400196f, 0.53174824f, 0.62064564f, -0.79585856f, -0.82621557f, + -0.87855506f, 1.1904413f, 1.3938838f, 1.3926021f}); - sd::ops::triangular_solve op; + sd::ops::triangular_solve op; - auto res = op.evaluate({&a, &b}, {true, true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}, {true, true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// z->printBuffer("4_3 Triangular_Solve 3x3"); -// exp.printBuffer("4_3 Triangular_Expec 3x3"); + // z->printBuffer("4_3 Triangular_Solve 3x3"); + // exp.printBuffer("4_3 Triangular_Expec 3x3"); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Solve_Test_4_4) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f}); - auto a = NDArrayFactory::create('c', {3, 3}, { - 0.7788f, 0.8012f, 0.7244f, - 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f - }); + auto b = NDArrayFactory::create( + 'c', {3, 3}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f}); - auto b = NDArrayFactory::create('c', {3, 3}, { - 0.7717f, 0.9281f, 0.9846f, - 0.4838f, 0.6433f, 0.6041f, - 0.6501f, 0.7612f, 0.7605f - }); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {0.8959121f, 1.6109066f, 1.7501404f, 0.49000582f, 0.66842675f, 0.5577021f, + -0.4398522f, -1.1899745f, -1.1392052f}); - auto exp = NDArrayFactory::create('c', {3, 3}, { - 0.8959121f, 1.6109066f, 1.7501404f, - 0.49000582f, 0.66842675f, 0.5577021f, - -0.4398522f, -1.1899745f, -1.1392052f - }); + sd::ops::solve op; - sd::ops::solve op; + auto res = op.evaluate({&a, &b}, {false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - auto res = op.evaluate({&a, &b}, {false}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + // z->printBuffer("4_4 Solve 3x3"); + // exp.printBuffer("4_4 Expec 3x3"); -// z->printBuffer("4_4 Solve 3x3"); -// exp.printBuffer("4_4 Expec 3x3"); - - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Solve_Test_4_5) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f}); - auto a = NDArrayFactory::create('c', {3, 3}, { - 0.7788f, 0.8012f, 0.7244f, - 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f - }); - - auto b = NDArrayFactory::create('c', {3, 3}, { - 0.7717f, 0.9281f, 0.9846f, - 0.4838f, 0.6433f, 0.6041f, - 0.6501f, 0.7612f, 0.7605f - }); + auto b = NDArrayFactory::create( + 'c', {3, 3}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f}); - auto exp = NDArrayFactory::create('c', {3, 3}, { - 1.5504692f, 1.8953944f, 2.2765768f, - 0.03399149f, 0.2883001f, 0.5377323f, - -0.8774802f, -1.2155888f, -1.8049058f - }); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {1.5504692f, 1.8953944f, 2.2765768f, 0.03399149f, 0.2883001f, 0.5377323f, + -0.8774802f, -1.2155888f, -1.8049058f}); - sd::ops::solve op; + sd::ops::solve op; - auto res = op.evaluate({&a, &b}, {true, true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}, {true, true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// z->printBuffer("4_5 Solve 3x3"); -// exp.printBuffer("4_5 Expec 3x3"); + // z->printBuffer("4_5 Solve 3x3"); + // exp.printBuffer("4_5 Expec 3x3"); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Solve_Test_4_6) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f}); - auto a = NDArrayFactory::create('c', {3, 3}, { - 0.7788f, 0.8012f, 0.7244f, - 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f - }); + auto b = NDArrayFactory::create( + 'c', {3, 3}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f}); - auto b = NDArrayFactory::create('c', {3, 3}, { - 0.7717f, 0.9281f, 0.9846f, - 0.4838f, 0.6433f, 0.6041f, - 0.6501f, 0.7612f, 0.7605f - }); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {0.99088347f, 1.1917052f, 1.2642528f, -0.426483f, -0.42840624f, + -0.5622601f, 0.01692283f, -0.04538865f, -0.09868701f}); - auto exp = NDArrayFactory::create('c', {3, 3}, { - 0.99088347f, 1.1917052f, 1.2642528f, - -0.426483f, -0.42840624f, -0.5622601f, - 0.01692283f, -0.04538865f, -0.09868701f - }); + sd::ops::triangular_solve op; - sd::ops::triangular_solve op; + auto res = op.evaluate({&a, &b}, {false, true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - auto res = op.evaluate({&a, &b}, {false, true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + // z->printBuffer("4_6 Solve 3x3"); + // exp.printBuffer("4_6 Expec 3x3"); -// z->printBuffer("4_6 Solve 3x3"); -// exp.printBuffer("4_6 Expec 3x3"); - - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Solve_Test_4_7) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, + {// 0.7788f, 0.2309f, 0.5056f, + // 0.8012f, 0.7271f, 0.8925f, + // 0.7244f, 0.1804f, 0.5461f - auto a = NDArrayFactory::create('c', {3, 3}, { -// 0.7788f, 0.2309f, 0.5056f, -// 0.8012f, 0.7271f, 0.8925f, -// 0.7244f, 0.1804f, 0.5461f - - 0.7788f, 0.2309f, 0.5056f, - 0.8012f, 0.7271f, 0.8925f, - 0.7244f, 0.1804f, 0.5461f - }); + 0.7788f, 0.2309f, 0.5056f, 0.8012f, 0.7271f, 0.8925f, 0.7244f, 0.1804f, + 0.5461f}); - auto b = NDArrayFactory::create('c', {3, 3}, { - 0.7717f, 0.9281f, 0.9846f, - 0.4838f, 0.6433f, 0.6041f, - 0.6501f, 0.7612f, 0.7605f - }); + auto b = NDArrayFactory::create( + 'c', {3, 3}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f}); - auto exp = NDArrayFactory::create('c', {3, 3}, { - 0.99088347f, 1.1917052f, 1.2642528f, - -0.426483f, -0.42840624f, -0.5622601f, - 0.01692283f, -0.04538865f, -0.09868701f - }); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {0.99088347f, 1.1917052f, 1.2642528f, -0.426483f, -0.42840624f, + -0.5622601f, 0.01692283f, -0.04538865f, -0.09868701f}); - sd::ops::triangular_solve op; + sd::ops::triangular_solve op; - auto res = op.evaluate({&a, &b}, {true, false}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}, {true, false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// z->printBuffer("4_7 Solve 3x3"); -// exp.printBuffer("4_7 Expec 3x3"); + // z->printBuffer("4_7 Solve 3x3"); + // exp.printBuffer("4_7 Expec 3x3"); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Solve_Test_5) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f}); - auto a = NDArrayFactory::create('c', {3, 3}, { - 0.7788f, 0.8012f, 0.7244f, - 0.2309f, 0.7271f, 0.1804f, - 0.5056f, 0.8925f, 0.5461f - }); + auto b = NDArrayFactory::create( + 'c', {3, 3}, + {0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f}); - auto b = NDArrayFactory::create('c', {3, 3}, { - 0.7717f, 0.9281f, 0.9846f, - 0.4838f, 0.6433f, 0.6041f, - 0.6501f, 0.7612f, 0.7605f - }); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {1.5504692f, 1.8953944f, 2.2765768f, 0.03399149f, 0.2883001f, 0.5377323f, + -0.8774802f, -1.2155888f, -1.8049058f}); - auto exp = NDArrayFactory::create('c', {3, 3}, { - 1.5504692f, 1.8953944f, 2.2765768f, - 0.03399149f, 0.2883001f, 0.5377323f, - -0.8774802f, -1.2155888f, -1.8049058f - }); + sd::ops::solve op; - sd::ops::solve op; + auto res = op.evaluate({&a, &b}, {true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - auto res = op.evaluate({&a, &b}, {true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + // z->printBuffer("4 Solve 4x4"); + // exp.printBuffer("4 Expec 4x4"); -// z->printBuffer("4 Solve 4x4"); -// exp.printBuffer("4 Expec 4x4"); - - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, SolveLS_Test_1) { + auto a = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto a = NDArrayFactory::create('c', {2,2, 2}, { - 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f - }); - - auto b = NDArrayFactory::create('c', {2, 2, 1}, { - 3.f, 7.f, 11.f, 15.f - }); + auto b = + NDArrayFactory::create('c', {2, 2, 1}, {3.f, 7.f, 11.f, 15.f}); - auto exp = NDArrayFactory::create('c', {2, 2, 1}, { - 0.8311695f, 1.0909086f, 0.9205573f, 1.0630057f - }); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 1}, {0.8311695f, 1.0909086f, 0.9205573f, 1.0630057f}); - sd::ops::lstsq op; + sd::ops::lstsq op; - auto res = op.evaluate({&a, &b}, {0.5}, {}, {true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}, {0.5}, {}, {true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// z->printIndexedBuffer("LS Solve 2x2"); -// exp.printIndexedBuffer("LS Expec 2x2"); + // z->printIndexedBuffer("LS Solve 2x2"); + // exp.printIndexedBuffer("LS Expec 2x2"); - ASSERT_TRUE(exp.equalsTo(z, 1.e-4)); + ASSERT_TRUE(exp.equalsTo(z, 1.e-4)); } TEST_F(DeclarableOpsTests11, SolveLS_Test_2) { + auto a = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto a = NDArrayFactory::create('c', {2,2, 2}, { - 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f - }); - - auto b = NDArrayFactory::create('c', {2, 2, 1}, { - 3.f, 7.f, 11.f, 15.f - }); + auto b = + NDArrayFactory::create('c', {2, 2, 1}, {3.f, 7.f, 11.f, 15.f}); - auto exp = NDArrayFactory::create('c', {2, 2, 1}, { - 0.8311695f, 1.0909086f, 0.9205573f, 1.0630057f - }); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 1}, {0.8311695f, 1.0909086f, 0.9205573f, 1.0630057f}); - sd::ops::lstsq op; + sd::ops::lstsq op; - auto res = op.evaluate({&a, &b}, {0.5}, {}, {true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}, {0.5}, {}, {true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// z->printIndexedBuffer("2LS Solve 2x2"); -// exp.printIndexedBuffer("2LS Expec 2x2"); + // z->printIndexedBuffer("2LS Solve 2x2"); + // exp.printIndexedBuffer("2LS Expec 2x2"); - ASSERT_TRUE(exp.equalsTo(z, 1.e-4)); + ASSERT_TRUE(exp.equalsTo(z, 1.e-4)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2) { + auto a = NDArrayFactory::create('c', {2, 2, 2}, + {10.f, 14.f, 14.f, 20.f, - auto a = NDArrayFactory::create('c', {2,2, 2}, { - 10.f, 14.f, - 14.f, 20.f, + 74.f, 86.f, 86.f, 100.f}); - 74.f, 86.f, - 86.f, 100.f - }); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, + {3.1622777f, 0.f, 4.427189f, 0.6324552f, + 8.602325f, 0.f, 9.997296f, 0.23252854f}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, { - 3.1622777f, 0.f, 4.427189f, 0.6324552f, - 8.602325f, 0.f, 9.997296f, 0.23252854f - }); + sd::ops::cholesky op; - sd::ops::cholesky op; + auto res = op.evaluate({&a}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - auto res = op.evaluate({&a}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - - //z->printIndexedBuffer("L matrix is"); - //exp.printIndexedBuffer("L expected is"); - - ASSERT_TRUE(exp.equalsTo(z)); + // z->printIndexedBuffer("L matrix is"); + // exp.printIndexedBuffer("L expected is"); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2_2) { + auto a = NDArrayFactory::create('c', {2, 2, 2}, + {10.5f, 14.f, 14.f, 20.5f, - auto a = NDArrayFactory::create('c', {2,2, 2}, { - 10.5f, 14.f, - 14.f, 20.5f, + 74.5f, 86.f, 86.f, 100.5f}); - 74.5f, 86.f, - 86.f, 100.5f - }); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, + {3.2403703f, 0.f, 4.3204937f, 1.3540066f, + 8.631338f, 0.f, 9.963693f, 1.1067207f}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, { - 3.2403703f, 0.f, 4.3204937f, 1.3540066f, - 8.631338f, 0.f, 9.963693f, 1.1067207f - }); + sd::ops::cholesky op; - sd::ops::cholesky op; + auto res = op.evaluate({&a}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - auto res = op.evaluate({&a}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printIndexedBuffer("L matrix is"); -// exp.printIndexedBuffer("L expected is"); - MmulHelper::matmul(&z, &z, &exp, false, true); - ASSERT_TRUE(exp.equalsTo(a)); + // z->printIndexedBuffer("L matrix is"); + // exp.printIndexedBuffer("L expected is"); + MmulHelper::matmul(&z, &z, &exp, false, true); + ASSERT_TRUE(exp.equalsTo(a)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52, - -12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04}); - NDArray dLdwExp('c', {2,3,4}, {0.9216 , 3.6864 , 8.2944 , 14.7456 , 23.04 , 33.1776 , 45.1584 , 58.9824 , 74.6496 , 92.16 ,111.51361,132.7104 , - 155.75038,180.63359,207.35999,235.9296 ,266.34238,298.59842,332.6976 ,368.64001,406.4256 ,446.05444,487.5264 ,530.84161}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.96, -1.92, -2.88, -3.84, -4.8, -5.76, -6.72, -7.68, + -8.64, -9.6, -10.56, -11.52, -12.48, -13.44, -14.4, -15.36, + -16.32, -17.28, -18.24, -19.2, -20.16, -21.12, -22.08, -23.04}); + NDArray dLdwExp( + 'c', {2, 3, 4}, + {0.9216, 3.6864, 8.2944, 14.7456, 23.04, 33.1776, + 45.1584, 58.9824, 74.6496, 92.16, 111.51361, 132.7104, + 155.75038, 180.63359, 207.35999, 235.9296, 266.34238, 298.59842, + 332.6976, 368.64001, 406.4256, 446.05444, 487.5264, 530.84161}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test2) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 1, 4}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {2,1,4}, {98.61121,129.024 , 164.9664 , 206.4384 , 828.51837,925.28644,1027.58398,1135.41113}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + NDArray dLdwExp('c', {2, 1, 4}, + {98.61121, 129.024, 164.9664, 206.4384, 828.51837, 925.28644, + 1027.58398, 1135.41113}); - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - auto dLdw = results.at(1); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52, - -12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04}); - NDArray dLdwExp('c', {}, std::vector{4515.84}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.96, -1.92, -2.88, -3.84, -4.8, -5.76, -6.72, -7.68, + -8.64, -9.6, -10.56, -11.52, -12.48, -13.44, -14.4, -15.36, + -16.32, -17.28, -18.24, -19.2, -20.16, -21.12, -22.08, -23.04}); + NDArray dLdwExp('c', {}, std::vector{4515.84}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test4) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,3,1}, {807.32153, 1426.63684, 2281.88159}); + NDArray dLdwExp('c', {1, 3, 1}, {807.32153, 1426.63684, 2281.88159}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test5) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.08,-0.16,-0.24,-0.32,-0.4 ,-0.48,-0.56,-0.64,-0.72,-0.8 ,-0.88,-0.96, - -1.04,-1.12,-1.2 ,-1.28,-1.36,-1.44,-1.52,-1.6 ,-1.68,-1.76,-1.84,-1.92}); - NDArray dLdwExp('c', {2,3,4}, {-15.6032,-15.3728,-14.9888,-14.4512,-13.76 ,-12.9152,-11.9168,-10.7648, -9.4592, -8. , -6.3872, -4.6208, - -2.7008, -0.6272, 1.6 , 3.9808, 6.5152, 9.2032, 12.0448, 15.04 , 18.1888, 21.4912, 24.9472, 28.5568}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.08, -0.16, -0.24, -0.32, -0.4, -0.48, -0.56, -0.64, + -0.72, -0.8, -0.88, -0.96, -1.04, -1.12, -1.2, -1.28, + -1.36, -1.44, -1.52, -1.6, -1.68, -1.76, -1.84, -1.92}); + NDArray dLdwExp('c', {2, 3, 4}, + {-15.6032, -15.3728, -14.9888, -14.4512, -13.76, -12.9152, + -11.9168, -10.7648, -9.4592, -8., -6.3872, -4.6208, + -2.7008, -0.6272, 1.6, 3.9808, 6.5152, 9.2032, + 12.0448, 15.04, 18.1888, 21.4912, 24.9472, 28.5568}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test6) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + NDArray dLdwExp('c', {1, 3, 1}, {-58.16319, -6.5536, 64.71682}); - NDArray dLdwExp('c', {1,3,1}, {-58.16319, -6.5536 , 64.71682}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {}, std::vector{0.}); + NDArray dLdwExp('c', {}, std::vector{0.}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test8) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0. ,0. ,0. ,0. ,-0.48 ,-0.576,-0.672,-0.768,-0.864,-0.96 ,-1.056,-1.152, - -1.248,-1.344,-1.44 ,-1.536,-1.632,-1.728,-1.824,-1.92 ,-2.016,-2.112,-2.208,-2.304}); - NDArray dLdwExp('c', {2,3,4}, {-22.3488 ,-22.07232,-21.61152,-20.9664 ,-20.13696,-19.1232 ,-17.92512,-16.54272,-14.976 ,-13.22496,-11.2896 , -9.16992, - -6.86592, -4.3776 , -1.70496, 1.152 , 4.19328, 7.41888, 10.8288 , 14.42304, 18.2016 , 22.16449, 26.31168, 30.6432 }); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {0., 0., 0., 0., -0.48, -0.576, -0.672, -0.768, + -0.864, -0.96, -1.056, -1.152, -1.248, -1.344, -1.44, -1.536, + -1.632, -1.728, -1.824, -1.92, -2.016, -2.112, -2.208, -2.304}); + NDArray dLdwExp( + 'c', {2, 3, 4}, + {-22.3488, -22.07232, -21.61152, -20.9664, -20.13696, -19.1232, + -17.92512, -16.54272, -14.976, -13.22496, -11.2896, -9.16992, + -6.86592, -4.3776, -1.70496, 1.152, 4.19328, 7.41888, + 10.8288, 14.42304, 18.2016, 22.16449, 26.31168, 30.6432}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test9) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.04,-0.08,-0.12,-0.16,-0.2 ,-0.24,-0.28,-0.32,-0.36,-0.4 ,-0.44,-0.48, - -0.52,-0.56,-0.6 ,-0.64,-0.68,-0.72,-0.76,-0.8 ,-0.84,-0.88,-0.92,-0.96}); - NDArray dLdwExp('c', {2,3,4}, {0.0384, 0.1536, 0.3456, 0.6144, 0.96 , 1.3824, 1.8816, 2.4576, 3.1104, 3.84 , 4.6464, 5.5296, - 6.4896, 7.5264, 8.64 , 9.8304,11.0976,12.4416,13.8624,15.36 ,16.9344,18.5856,20.3136,22.1184}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.04, -0.08, -0.12, -0.16, -0.2, -0.24, -0.28, -0.32, + -0.36, -0.4, -0.44, -0.48, -0.52, -0.56, -0.6, -0.64, + -0.68, -0.72, -0.76, -0.8, -0.84, -0.88, -0.92, -0.96}); + NDArray dLdwExp( + 'c', {2, 3, 4}, + {0.0384, 0.1536, 0.3456, 0.6144, 0.96, 1.3824, 1.8816, 2.4576, + 3.1104, 3.84, 4.6464, 5.5296, 6.4896, 7.5264, 8.64, 9.8304, + 11.0976, 12.4416, 13.8624, 15.36, 16.9344, 18.5856, 20.3136, 22.1184}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,1}, sd::DataType::DOUBLE); + NDArray dLdwExp('c', {1, 1}, std::vector{188.16}); - NDArray dLdwExp('c', {1,1}, std::vector{188.16}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test11) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,3,1}, {33.6384 ,59.4432 ,95.07841}); + NDArray dLdwExp('c', {1, 3, 1}, {33.6384, 59.4432, 95.07841}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test12) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0.,0.,0.,0., -0.24 ,-0.288,-0.336,-0.384,-0.432,-0.48 ,-0.528,-0.576, - -0.624,-0.672,-0.72 ,-0.768,-0.816,-0.864,-0.912,-0.96 ,-1.008,-1.056,-1.104,-1.152}); - NDArray dLdwExp('c', {2,3,4}, {0.04608, 0.18432, 0.41472, 0.73728, 1.152 , 1.65888, 2.25792, 2.94912, 3.73248, 4.608 , 5.57568, 6.63552, - 7.78752, 9.03168,10.368 ,11.79648,13.31712,14.92992,16.63488,18.432 ,20.32128,22.30272,24.37632,26.54208}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.t(0) = 0.; - weights.t(1) = 0.; - weights.t(2) = 0.; - weights.t(3) = 0.; - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {0., 0., 0., 0., -0.24, -0.288, -0.336, -0.384, + -0.432, -0.48, -0.528, -0.576, -0.624, -0.672, -0.72, -0.768, + -0.816, -0.864, -0.912, -0.96, -1.008, -1.056, -1.104, -1.152}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.04608, 0.18432, 0.41472, 0.73728, 1.152, 1.65888, + 2.25792, 2.94912, 3.73248, 4.608, 5.57568, 6.63552, + 7.78752, 9.03168, 10.368, 11.79648, 13.31712, 14.92992, + 16.63488, 18.432, 20.32128, 22.30272, 24.37632, 26.54208}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.t(0) = 0.; + weights.t(1) = 0.; + weights.t(2) = 0.; + weights.t(3) = 0.; + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test13) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - -1.04,-1.12,-1.2 ,-1.28,-1.36,-1.44,-1.52,-1.6 ,-1.68,-1.76,-1.84,-1.92}); - NDArray dLdwExp('c', {2,3,1}, {2.304 , 13.3632 , 34.2528 , 64.97279,105.5232 ,155.90401}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.t(0) = 0.; - weights.t(1) = 0.; - weights.t(2) = 0.; - - sd::ops::mean_sqerr_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., -1.04, -1.12, -1.2, -1.28, + -1.36, -1.44, -1.52, -1.6, -1.68, -1.76, -1.84, -1.92}); + NDArray dLdwExp('c', {2, 3, 1}, + {2.304, 13.3632, 34.2528, 64.97279, 105.5232, 155.90401}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.t(0) = 0.; + weights.t(1) = 0.; + weights.t(2) = 0.; + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test1) { - auto x = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); - auto y = NDArrayFactory::create('c',{4}, {3, 2, 1, 0}); - auto exp = NDArrayFactory::create('c', {4}, {9, 1,1, 9}); - sd::ops::squaredsubtract op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); + auto y = NDArrayFactory::create('c', {4}, {3, 2, 1, 0}); + auto exp = NDArrayFactory::create('c', {4}, {9, 1, 1, 9}); + sd::ops::squaredsubtract op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test2) { - auto x = NDArrayFactory::create('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3}); - auto y = NDArrayFactory::create('c',{4}, {3, 2, 1, 0}); - auto exp = NDArrayFactory::create('c', {2, 4}, {9, 1,1, 9, 9, 1, 1, 9}); - sd::ops::squaredsubtract op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - + auto x = NDArrayFactory::create('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3}); + auto y = NDArrayFactory::create('c', {4}, {3, 2, 1, 0}); + auto exp = + NDArrayFactory::create('c', {2, 4}, {9, 1, 1, 9, 9, 1, 1, 9}); + sd::ops::squaredsubtract op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test3) { - auto x = NDArrayFactory::create('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3}); - auto y = NDArrayFactory::create('c',{4}, {3, 2, 1, 0}); - auto exp = NDArrayFactory::create('c', {2, 4}, {-6, -4, 6, 24, -30, -12, 14, 48}); - auto eps = NDArrayFactory::create('c', {2, 4}, {1,2,3,4,5,6,7,8}); - sd::ops::squaredsubtract_bp op; - auto result = op.evaluate({&x, &y, &eps}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - + auto x = NDArrayFactory::create('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3}); + auto y = NDArrayFactory::create('c', {4}, {3, 2, 1, 0}); + auto exp = NDArrayFactory::create('c', {2, 4}, + {-6, -4, 6, 24, -30, -12, 14, 48}); + auto eps = + NDArrayFactory::create('c', {2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}); + sd::ops::squaredsubtract_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test1) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5, - -0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5}); - NDArray dLdwExp('c', {2,3,4}, {0.96, 1.92, 2.88, 3.84, 4.8 , 5.76, 6.72, 7.68, 8.64, 9.6 ,10.56,11.52, - 12.48,13.44,14.4 ,15.36,16.32,17.28,18.24,19.2 ,20.16,21.12,22.08,23.04}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, + -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.96, 1.92, 2.88, 3.84, 4.8, 5.76, 6.72, 7.68, + 8.64, 9.6, 10.56, 11.52, 12.48, 13.44, 14.4, 15.36, + 16.32, 17.28, 18.24, 19.2, 20.16, 21.12, 22.08, 23.04}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test2) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 1, 4}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {2,1,4}, {14.4 , 17.28, 20.16, 23.04, 48.96, 51.84, 54.72, 57.6}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + NDArray dLdwExp('c', {2, 1, 4}, + {14.4, 17.28, 20.16, 23.04, 48.96, 51.84, 54.72, 57.6}); - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - auto dLdw = results.at(1); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5, - -0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5}); - NDArray dLdwExp('c', {}, std::vector{288.}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, + -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5}); + NDArray dLdwExp('c', {}, std::vector{288.}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test4) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,3,1}, {65.28, 96., 126.72001}); + NDArray dLdwExp('c', {1, 3, 1}, {65.28, 96., 126.72001}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test5) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167, - -0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167}); - NDArray dLdwExp('c', {2,3,4}, {-0.92,-0.84,-0.76,-0.68,-0.6 ,-0.52,-0.44,-0.36,-0.28,-0.2 ,-0.12,-0.04, - 0.04, 0.12, 0.2 , 0.28, 0.36, 0.44, 0.52, 0.6 , 0.68, 0.76, 0.84, 0.92}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.04167, -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, + -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, + -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, + -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, -0.04167}); + NDArray dLdwExp('c', {2, 3, 4}, + {-0.92, -0.84, -0.76, -0.68, -0.6, -0.52, -0.44, -0.36, + -0.28, -0.2, -0.12, -0.04, 0.04, 0.12, 0.2, 0.28, + 0.36, 0.44, 0.52, 0.6, 0.68, 0.76, 0.84, 0.92}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test6) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + NDArray dLdwExp('c', {1, 3, 1}, {-2.56, 0., 2.56}); - NDArray dLdwExp('c', {1,3,1}, {-2.56, 0., 2.56}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {}, std::vector{0.}); + NDArray dLdwExp('c', {}, std::vector{0.}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test8) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0. ,-0. ,-0. ,-0. ,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05, - -0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05}); - NDArray dLdwExp('c', {2,3,4}, {-1.296,-1.2 ,-1.104,-1.008,-0.912,-0.816,-0.72 ,-0.624,-0.528,-0.432,-0.336,-0.24 , - -0.144,-0.048, 0.048, 0.144, 0.24 , 0.336, 0.432, 0.528, 0.624, 0.72 , 0.816, 0.912}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); - - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0., -0., -0., -0., -0.05, -0.05, -0.05, -0.05, + -0.05, -0.05, -0.05, -0.05, -0.05, -0.05, -0.05, -0.05, + -0.05, -0.05, -0.05, -0.05, -0.05, -0.05, -0.05, -0.05}); + NDArray dLdwExp( + 'c', {2, 3, 4}, + {-1.296, -1.2, -1.104, -1.008, -0.912, -0.816, -0.72, -0.624, + -0.528, -0.432, -0.336, -0.24, -0.144, -0.048, 0.048, 0.144, + 0.24, 0.336, 0.432, 0.528, 0.624, 0.72, 0.816, 0.912}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test9) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083, - -0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083}); - NDArray dLdwExp('c', {2,3,4}, {0.04, 0.08, 0.12, 0.16, 0.2 , 0.24, 0.28, 0.32,0.36, 0.4 , 0.44, 0.48, - 0.52, 0.56, 0.6 , 0.64,0.68, 0.72, 0.76, 0.8 ,0.84, 0.88, 0.92, 0.96}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.02083, -0.02083, -0.02083, -0.02083, -0.02083, -0.02083, + -0.02083, -0.02083, -0.02083, -0.02083, -0.02083, -0.02083, + -0.02083, -0.02083, -0.02083, -0.02083, -0.02083, -0.02083, + -0.02083, -0.02083, -0.02083, -0.02083, -0.02083, -0.02083}); + NDArray dLdwExp( + 'c', {2, 3, 4}, + {0.04, 0.08, 0.12, 0.16, 0.2, 0.24, 0.28, 0.32, 0.36, 0.4, 0.44, 0.48, + 0.52, 0.56, 0.6, 0.64, 0.68, 0.72, 0.76, 0.8, 0.84, 0.88, 0.92, 0.96}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,1}, sd::DataType::DOUBLE); + NDArray dLdwExp('c', {1, 1}, std::vector{12.}); - NDArray dLdwExp('c', {1,1}, std::vector{12.}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test11) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,3,1}, {2.72, 4., 5.28}); + NDArray dLdwExp('c', {1, 3, 1}, {2.72, 4., 5.28}); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test12) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., -0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025, - -0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025}); - NDArray dLdwExp('c', {2,3,4}, {0.048, 0.096, 0.144, 0.192,0.24 , 0.288, 0.336, 0.384,0.432, 0.48 , 0.528, 0.576, - 0.624, 0.672, 0.72 , 0.768,0.816, 0.864, 0.912, 0.96 ,1.008, 1.056, 1.104, 1.152}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.t(0) = 0.; - weights.t(1) = 0.; - weights.t(2) = 0.; - weights.t(3) = 0.; - - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {0., 0., 0., 0., -0.025, -0.025, -0.025, -0.025, + -0.025, -0.025, -0.025, -0.025, -0.025, -0.025, -0.025, -0.025, + -0.025, -0.025, -0.025, -0.025, -0.025, -0.025, -0.025, -0.025}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.048, 0.096, 0.144, 0.192, 0.24, 0.288, 0.336, 0.384, + 0.432, 0.48, 0.528, 0.576, 0.624, 0.672, 0.72, 0.768, + 0.816, 0.864, 0.912, 0.96, 1.008, 1.056, 1.104, 1.152}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.t(0) = 0.; + weights.t(1) = 0.; + weights.t(2) = 0.; + weights.t(3) = 0.; + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test13) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., - -0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167}); - NDArray dLdwExp('c', {2,3,1}, {0.8 ,2.08,3.36,4.64,5.92,7.2 }); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.t(0) = 0.; - weights.t(1) = 0.; - weights.t(2) = 0.; - - sd::ops::absolute_difference_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); - ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, + -0.04167, -0.04167, -0.04167, -0.04167, -0.04167, -0.04167}); + NDArray dLdwExp('c', {2, 3, 1}, {0.8, 2.08, 3.36, 4.64, 5.92, 7.2}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.t(0) = 0.; + weights.t(1) = 0.; + weights.t(2) = 0.; + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, BFloat16_Test_1) { - - NDArray x = NDArrayFactory::create('c', {2,3,4}); - NDArray y = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - NDArray exp = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - - x.linspace(1); - y.linspace(1); - exp.linspace(2,2); - sd::ops::add op; - auto results = op.evaluate({&x, &y}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto res = results.at(0); - ASSERT_TRUE(res.equalsTo(exp)); - + NDArray x = NDArrayFactory::create('c', {2, 3, 4}); + NDArray y = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + + x.linspace(1); + y.linspace(1); + exp.linspace(2, 2); + sd::ops::add op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res.equalsTo(exp)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, BFloat16_Test_2) { - - NDArray x = NDArrayFactory::create('c', {2,3,4}); - NDArray y = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - NDArray exp = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - - x.linspace(1); - y.linspace(1); - exp.linspace(2,2); - sd::ops::add op; - auto results = op.evaluate({&x, &y}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto res = results.at(0); - ASSERT_TRUE(res.equalsTo(exp)); - + NDArray x = NDArrayFactory::create('c', {2, 3, 4}); + NDArray y = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + + x.linspace(1); + y.linspace(1); + exp.linspace(2, 2); + sd::ops::add op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res.equalsTo(exp)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, BFloat16_Test_3) { + NDArray x('c', {2, 3, 4}, sd::DataType::BFLOAT16); + NDArray y('c', {2, 3, 4}, sd::DataType::BFLOAT16); + NDArray exp('c', {2, 3, 4}, sd::DataType::BFLOAT16); - NDArray x('c', {2,3,4}, sd::DataType::BFLOAT16); - NDArray y('c', {2,3,4}, sd::DataType::BFLOAT16); - NDArray exp('c', {2,3,4}, sd::DataType::BFLOAT16); - - x.linspace(1); - y.linspace(1); - exp.linspace(2,2); - sd::ops::add op; - auto results = op.evaluate({&x, &y}, {}, {}); + x.linspace(1); + y.linspace(1); + exp.linspace(2, 2); + sd::ops::add op; + auto results = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto res = results.at(0); - ASSERT_TRUE(res.equalsTo(exp)); + auto res = results.at(0); + ASSERT_TRUE(res.equalsTo(exp)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test1) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.25999, -0.755 , -1.25 , -1.745 , -2.24001, -2.73502, -3.23004, -3.72508, -4.22014, -4.71523, -5.21034, -5.70548, - -6.20066, -6.69587, -7.19113, -7.68643, -8.18177, -8.67717, -9.17262, -9.66813,-10.1637 ,-10.65932,-11.15501,-11.65077}); - NDArray dLdwExp('c', {2,3,4}, {0.73395, 0.75335, 0.69315, 0.55335, 0.33395, 0.03495, -0.34366, -0.80186, -1.33967, -1.95708, -2.65411, -3.43074, - -4.28698, -5.22285, -6.23833, -7.33343, -8.50815, -9.76251,-11.0965 ,-12.51013,-14.00341,-15.57633,-17.2289 ,-18.96113}); - NDArray dLdlExp('c', {2,3,4}, {0.04, 0.02,-0. ,-0.02,-0.04,-0.06,-0.08,-0.1 ,-0.12,-0.14,-0.16,-0.18, - -0.2 ,-0.22,-0.24,-0.26,-0.28,-0.3 ,-0.32,-0.34,-0.36,-0.38,-0.4 ,-0.42}); - - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.25999, -0.755, -1.25, -1.745, -2.24001, -2.73502, + -3.23004, -3.72508, -4.22014, -4.71523, -5.21034, -5.70548, + -6.20066, -6.69587, -7.19113, -7.68643, -8.18177, -8.67717, + -9.17262, -9.66813, -10.1637, -10.65932, -11.15501, -11.65077}); + NDArray dLdwExp( + 'c', {2, 3, 4}, + {0.73395, 0.75335, 0.69315, 0.55335, 0.33395, 0.03495, + -0.34366, -0.80186, -1.33967, -1.95708, -2.65411, -3.43074, + -4.28698, -5.22285, -6.23833, -7.33343, -8.50815, -9.76251, + -11.0965, -12.51013, -14.00341, -15.57633, -17.2289, -18.96113}); + NDArray dLdlExp('c', {2, 3, 4}, + {0.04, 0.02, -0., -0.02, -0.04, -0.06, -0.08, -0.1, + -0.12, -0.14, -0.16, -0.18, -0.2, -0.22, -0.24, -0.26, + -0.28, -0.3, -0.32, -0.34, -0.36, -0.38, -0.4, -0.42}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test2) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048, - -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577}); - NDArray dLdwExp('c', {2,1,4}, {0.43622, -0.19079, -0.98462, -1.94525,-18.09855,-20.72768,-23.52373,-26.48669}); - NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0. , -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126, - -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294}); - - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 1, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.18499, -0.53, -0.875, -1.22, -1.56501, -1.91002, + -2.25504, -2.60008, -2.94514, -3.29023, -3.63534, -3.98048, + -4.32566, -4.67087, -5.01613, -5.36143, -5.70677, -6.05217, + -6.39762, -6.74313, -7.0887, -7.43432, -7.78001, -8.12577}); + NDArray dLdwExp('c', {2, 1, 4}, + {0.43622, -0.19079, -0.98462, -1.94525, -18.09855, -20.72768, + -23.52373, -26.48669}); + NDArray dLdlExp( + 'c', {2, 3, 4}, + {0.028, 0.014, -0., -0.014, -0.028, -0.042, -0.056, -0.07, + -0.084, -0.098, -0.112, -0.126, -0.14, -0.154, -0.168, -0.182, + -0.196, -0.21, -0.224, -0.238, -0.252, -0.266, -0.28, -0.294}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048, - -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577}); - NDArray dLdwExp('c', {}, std::vector{-91.52109}); - NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126, - -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294}); - - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.18499, -0.53, -0.875, -1.22, -1.56501, -1.91002, + -2.25504, -2.60008, -2.94514, -3.29023, -3.63534, -3.98048, + -4.32566, -4.67087, -5.01613, -5.36143, -5.70677, -6.05217, + -6.39762, -6.74313, -7.0887, -7.43432, -7.78001, -8.12577}); + NDArray dLdwExp('c', {}, std::vector{-91.52109}); + NDArray dLdlExp( + 'c', {2, 3, 4}, + {0.028, 0.014, -0., -0.014, -0.028, -0.042, -0.056, -0.07, + -0.084, -0.098, -0.112, -0.126, -0.14, -0.154, -0.168, -0.182, + -0.196, -0.21, -0.224, -0.238, -0.252, -0.266, -0.28, -0.294}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test4) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + NDArray dLdwExp('c', {1, 3, 1}, {-12.54779, -28.13393, -50.83936}); - NDArray dLdwExp('c', {1,3,1}, {-12.54779,-28.13393,-50.83936}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1}); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test5) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.01542,-0.04417,-0.07292,-0.10167,-0.13042,-0.15917,-0.18792,-0.21667,-0.24543,-0.27419,-0.30294,-0.33171, - -0.36047,-0.38924,-0.41801,-0.44679,-0.47556,-0.50435,-0.53314,-0.56193,-0.59072,-0.61953,-0.64833,-0.67715}); - NDArray dLdwExp('c', {2,3,4}, {0.37794, 0.37906, 0.37554, 0.36739, 0.35461, 0.33719, 0.31514, 0.28846, 0.25714, 0.22119, 0.18061, 0.13539, - 0.08553, 0.03104,-0.02808,-0.09184,-0.16023,-0.23326,-0.31093,-0.39323,-0.48017,-0.57175,-0.66796,-0.76881}); - NDArray dLdlExp('c', {2,3,4}, {0.00233, 0.00117,-0.,-0.00117,-0.00233,-0.0035 ,-0.00467,-0.00583,-0.007 ,-0.00817,-0.00933,-0.0105, - -0.01167,-0.01283,-0.014 ,-0.01517,-0.01633,-0.0175 ,-0.01867,-0.01983,-0.021 ,-0.02217,-0.02333,-0.0245}); - - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.01542, -0.04417, -0.07292, -0.10167, -0.13042, -0.15917, + -0.18792, -0.21667, -0.24543, -0.27419, -0.30294, -0.33171, + -0.36047, -0.38924, -0.41801, -0.44679, -0.47556, -0.50435, + -0.53314, -0.56193, -0.59072, -0.61953, -0.64833, -0.67715}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.37794, 0.37906, 0.37554, 0.36739, 0.35461, 0.33719, + 0.31514, 0.28846, 0.25714, 0.22119, 0.18061, 0.13539, + 0.08553, 0.03104, -0.02808, -0.09184, -0.16023, -0.23326, + -0.31093, -0.39323, -0.48017, -0.57175, -0.66796, -0.76881}); + NDArray dLdlExp('c', {2, 3, 4}, + {0.00233, 0.00117, -0., -0.00117, -0.00233, -0.0035, + -0.00467, -0.00583, -0.007, -0.00817, -0.00933, -0.0105, + -0.01167, -0.01283, -0.014, -0.01517, -0.01633, -0.0175, + -0.01867, -0.01983, -0.021, -0.02217, -0.02333, -0.0245}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test6) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + NDArray dLdwExp('c', {1, 3, 1}, {1.4966, 0.19776, -1.69436}); - NDArray dLdwExp('c', {1,3,1}, {1.4966 , 0.19776,-1.69436}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights(sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {}, std::vector{0.}); - - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); + NDArray dLdwExp('c', {}, std::vector{0.}); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); - auto dLdw = results.at(1); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test8) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, { 0. , 0. , 0. , 0. ,-0.1565 ,-0.191 ,-0.2255 ,-0.26001,-0.29451,-0.32902,-0.36353,-0.39805, - -0.43257,-0.46709,-0.50161,-0.53614,-0.57068,-0.60522,-0.63976,-0.67431,-0.70887,-0.74343,-0.778 ,-0.81258}); - NDArray dLdwExp('c', {2,3,4}, {0.54353, 0.54487, 0.54065, 0.53087, 0.51553, 0.49463, 0.46817, 0.43615, 0.39857, 0.35543, 0.30672, 0.25246, - 0.19264, 0.12725, 0.0563 ,-0.02021,-0.10228,-0.18992,-0.28312,-0.38188,-0.48621,-0.5961 ,-0.71156,-0.83258}); - NDArray dLdlExp('c', {2,3,4}, {-0. ,-0. , 0. , 0. ,-0.0028,-0.0042,-0.0056,-0.007 ,-0.0084,-0.0098,-0.0112,-0.0126, - -0.014 ,-0.0154,-0.0168,-0.0182,-0.0196,-0.021 ,-0.0224,-0.0238,-0.0252,-0.0266,-0.028 ,-0.0294}); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); - - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {0., 0., 0., 0., -0.1565, -0.191, + -0.2255, -0.26001, -0.29451, -0.32902, -0.36353, -0.39805, + -0.43257, -0.46709, -0.50161, -0.53614, -0.57068, -0.60522, + -0.63976, -0.67431, -0.70887, -0.74343, -0.778, -0.81258}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.54353, 0.54487, 0.54065, 0.53087, 0.51553, 0.49463, + 0.46817, 0.43615, 0.39857, 0.35543, 0.30672, 0.25246, + 0.19264, 0.12725, 0.0563, -0.02021, -0.10228, -0.18992, + -0.28312, -0.38188, -0.48621, -0.5961, -0.71156, -0.83258}); + NDArray dLdlExp( + 'c', {2, 3, 4}, + {-0., -0., 0., 0., -0.0028, -0.0042, -0.0056, -0.007, + -0.0084, -0.0098, -0.0112, -0.0126, -0.014, -0.0154, -0.0168, -0.0182, + -0.0196, -0.021, -0.0224, -0.0238, -0.0252, -0.0266, -0.028, -0.0294}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test9) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.00771, -0.02208, -0.03646, -0.05083,-0.06521, -0.07958, -0.09396, -0.10834,-0.12271, -0.13709, -0.15147, -0.16585, - -0.18024, -0.19462, -0.20901, -0.22339,-0.23778, -0.25217, -0.26657, -0.28096,-0.29536, -0.30976, -0.32417, -0.33857}); - NDArray dLdwExp('c', {2,3,4}, {0.03008, 0.03064, 0.02888, 0.02481, 0.01841, 0.00971, -0.00132, -0.01466,-0.03032, -0.0483 , -0.06859, -0.0912 , - -0.11612, -0.14337, -0.17293, -0.20481,-0.23901, -0.27552, -0.31435, -0.35551,-0.39898, -0.44476, -0.49287, -0.5433 }); - NDArray dLdlExp('c', {2,3,4}, {0.00117, 0.00058, -0. , -0.00058,-0.00117, -0.00175, -0.00233, -0.00292,-0.0035 , -0.00408, -0.00467, -0.00525, - -0.00583, -0.00642, -0.007 , -0.00758,-0.00817, -0.00875, -0.00933, -0.00992,-0.0105 , -0.01108, -0.01167, -0.01225}); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {-0.00771, -0.02208, -0.03646, -0.05083, -0.06521, -0.07958, + -0.09396, -0.10834, -0.12271, -0.13709, -0.15147, -0.16585, + -0.18024, -0.19462, -0.20901, -0.22339, -0.23778, -0.25217, + -0.26657, -0.28096, -0.29536, -0.30976, -0.32417, -0.33857}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.03008, 0.03064, 0.02888, 0.02481, 0.01841, 0.00971, + -0.00132, -0.01466, -0.03032, -0.0483, -0.06859, -0.0912, + -0.11612, -0.14337, -0.17293, -0.20481, -0.23901, -0.27552, + -0.31435, -0.35551, -0.39898, -0.44476, -0.49287, -0.5433}); + NDArray dLdlExp('c', {2, 3, 4}, + {0.00117, 0.00058, -0., -0.00058, -0.00117, -0.00175, + -0.00233, -0.00292, -0.0035, -0.00408, -0.00467, -0.00525, + -0.00583, -0.00642, -0.007, -0.00758, -0.00817, -0.00875, + -0.00933, -0.00992, -0.0105, -0.01108, -0.01167, -0.01225}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,1}, sd::DataType::DOUBLE); + NDArray dLdwExp('c', {1, 1}, std::vector{-3.81338}); - NDArray dLdwExp('c', {1,1}, std::vector{-3.81338}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test11) { + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdwExp('c', {1,3,1}, {-0.52282,-1.17225,-2.11831}); + NDArray dLdwExp('c', {1, 3, 1}, {-0.52282, -1.17225, -2.11831}); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdw = results.at(1); - - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdw = results.at(1); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test12) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. ,-0.07825, -0.0955 , -0.11275, -0.13 ,-0.14726, -0.16451, -0.18177, -0.19902, - -0.21628, -0.23354, -0.25081, -0.26807,-0.28534, -0.30261, -0.31988, -0.33716,-0.35443, -0.37172, -0.389 , -0.40629}); - NDArray dLdwExp('c', {2,3,4}, {0.0361 , 0.03677, 0.03466, 0.02977, 0.0221 , 0.01165, -0.00158, -0.01759,-0.03638, -0.05795, -0.08231, -0.10944, - -0.13935, -0.17204, -0.20752, -0.24577,-0.28681, -0.33063, -0.37723, -0.42661,-0.47877, -0.53372, -0.59144, -0.65196}); - NDArray dLdlExp('c', {2,3,4}, {-0. , -0. , 0. , 0. ,-0.0014, -0.0021, -0.0028, -0.0035,-0.0042, -0.0049, -0.0056, -0.0063, - -0.007 , -0.0077, -0.0084, -0.0091,-0.0098, -0.0105, -0.0112, -0.0119,-0.0126, -0.0133, -0.014 , -0.0147}); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.t(0) = 0.; - weights.t(1) = 0.; - weights.t(2) = 0.; - weights.t(3) = 0.; - - - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {0., 0., 0., 0., -0.07825, -0.0955, + -0.11275, -0.13, -0.14726, -0.16451, -0.18177, -0.19902, + -0.21628, -0.23354, -0.25081, -0.26807, -0.28534, -0.30261, + -0.31988, -0.33716, -0.35443, -0.37172, -0.389, -0.40629}); + NDArray dLdwExp('c', {2, 3, 4}, + {0.0361, 0.03677, 0.03466, 0.02977, 0.0221, 0.01165, + -0.00158, -0.01759, -0.03638, -0.05795, -0.08231, -0.10944, + -0.13935, -0.17204, -0.20752, -0.24577, -0.28681, -0.33063, + -0.37723, -0.42661, -0.47877, -0.53372, -0.59144, -0.65196}); + NDArray dLdlExp( + 'c', {2, 3, 4}, + {-0., -0., 0., 0., -0.0014, -0.0021, -0.0028, -0.0035, + -0.0042, -0.0049, -0.0056, -0.0063, -0.007, -0.0077, -0.0084, -0.0091, + -0.0098, -0.0105, -0.0112, -0.0119, -0.0126, -0.0133, -0.014, -0.0147}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.t(0) = 0.; + weights.t(1) = 0.; + weights.t(2) = 0.; + weights.t(3) = 0.; + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test13) { - - NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , - -0.36047, -0.38924, -0.41801, -0.44679,-0.47556, -0.50435, -0.53314, -0.56193,-0.59072, -0.61953, -0.64833, -0.67715}); - NDArray dLdwExp('c', {2,3,1}, {0.22882, 0.02428,-0.4768 ,-1.27447,-2.36878,-3.75981,}); - NDArray dLdlExp('c', {2,3,4}, {-0. , -0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0., - -0.01167, -0.01283, -0.014 , -0.01517,-0.01633, -0.0175 , -0.01867, -0.01983,-0.021 , -0.02217, -0.02333, -0.0245}); - logits.linspace(-0.08, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.t(0) = 0.; - weights.t(1) = 0.; - weights.t(2) = 0.; - - sd::ops::sigm_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - + NDArray labels('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2, 3, 4}, + {0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + -0.36047, -0.38924, -0.41801, -0.44679, -0.47556, -0.50435, + -0.53314, -0.56193, -0.59072, -0.61953, -0.64833, -0.67715}); + NDArray dLdwExp('c', {2, 3, 1}, + { + 0.22882, + 0.02428, + -0.4768, + -1.27447, + -2.36878, + -3.75981, + }); + NDArray dLdlExp('c', {2, 3, 4}, + {-0., -0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + -0.01167, -0.01283, -0.014, -0.01517, -0.01633, -0.0175, + -0.01867, -0.01983, -0.021, -0.02217, -0.02333, -0.0245}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.t(0) = 0.; + weights.t(1) = 0.; + weights.t(2) = 0.; + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, BFloat16_Test_4) { - - NDArray x = NDArrayFactory::create('c', {2,3,4}); - NDArray y = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - NDArray exp = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - - x.linspace(1); - y.linspace(1); - exp.linspace(2,2); - sd::ops::add op; - auto results = op.evaluate({&x, &y}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto res = results.at(0); - ASSERT_TRUE(res.equalsTo(exp)); + NDArray x = NDArrayFactory::create('c', {2, 3, 4}); + NDArray y = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + + x.linspace(1); + y.linspace(1); + exp.linspace(2, 2); + sd::ops::add op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res.equalsTo(exp)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, BFloat16_Test_5) { - - NDArray x = NDArrayFactory::create('c', {2,3,4}); - NDArray y = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - NDArray exp = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - - x.linspace(2, 2); - y.linspace(1); - exp.linspace(1); - sd::ops::subtract op; - auto results = op.evaluate({&x, &y}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto res = results.at(0); - ASSERT_TRUE(res.equalsTo(exp)); - + NDArray x = NDArrayFactory::create('c', {2, 3, 4}); + NDArray y = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + + x.linspace(2, 2); + y.linspace(1); + exp.linspace(1); + sd::ops::subtract op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res.equalsTo(exp)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, BFloat16_Test_6) { - - NDArray x = NDArrayFactory::create('c', {2,3,4}); - NDArray y = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - NDArray exp = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); - - x.linspace(2, 2); - y.linspace(1); - exp.linspace(1); - sd::ops::subtract op; - auto results = op.evaluate({&x, &y}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto res = results.at(0); - ASSERT_TRUE(res.equalsTo(exp)); + NDArray x = NDArrayFactory::create('c', {2, 3, 4}); + NDArray y = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 4}); //('c', {2,3,4}, sd::DataType::BFLOAT16); + + x.linspace(2, 2); + y.linspace(1); + exp.linspace(1); + sd::ops::subtract op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res.equalsTo(exp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test1) { + NDArray labels('c', {2, 4}, {0, 0, 1, 0, 0, 1, 0, 0}, sd::DataType::INT32); + NDArray logits('c', {2, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2}, sd::DataType::DOUBLE); - NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, sd::DataType::INT32); - NDArray logits('c', {2,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2}, sd::DataType::DOUBLE); + NDArray dLdpExp( + 'c', {2, 4}, + {0.1176, 0.1224, -0.3726, 0.1326, 0.1176, -0.3776, 0.1274, 0.1326}); + NDArray dLdwExp('c', {2}, {1.36729, 1.40729}); - NDArray dLdpExp('c', {2,4}, {0.1176, 0.1224, -0.3726, 0.1326, 0.1176, -0.3776, 0.1274, 0.1326}); - NDArray dLdwExp('c', {2}, {1.36729, 1.40729}); + logits.linspace(-0.08, 0.04); + weights.assign(0.5); - logits.linspace(-0.08, 0.04); - weights.assign(0.5); + sd::ops::softmax_cross_entropy_loss_grad op; - sd::ops::softmax_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test2) { + NDArray labels('c', {4}, {0, 0, 1, 0}, sd::DataType::INT32); + NDArray logits('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {1}, sd::DataType::DOUBLE); - NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32); - NDArray logits('c', {4}, sd::DataType::DOUBLE); - NDArray weights('c', {1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); - NDArray dLdwExp('c', {1}, std::vector{1.38629}); + NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); + NDArray dLdwExp('c', {1}, std::vector{1.38629}); - logits = 2.; - weights.assign(0.5); + logits = 2.; + weights.assign(0.5); - sd::ops::softmax_cross_entropy_loss_grad op; + sd::ops::softmax_cross_entropy_loss_grad op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) { + NDArray labels('c', {4}, {0, 0, 1, 0}, sd::DataType::INT32); + NDArray logits('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {}, std::vector{0}, sd::DataType::DOUBLE); - NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32); - NDArray logits('c', {4}, sd::DataType::DOUBLE); - NDArray weights('c', {}, std::vector{0}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); + NDArray dLdwExp('c', {}, std::vector{1.38629}); - NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); - NDArray dLdwExp('c', {}, std::vector{1.38629}); + logits = 2.; + weights.assign(0.5); - logits = 2.; - weights.assign(0.5); + sd::ops::softmax_cross_entropy_loss_grad op; - sd::ops::softmax_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) { + NDArray labels('c', {4}, {0, 0, 1, 0}, sd::DataType::INT32); + NDArray logits('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {}, std::vector{0}, sd::DataType::DOUBLE); - NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32); - NDArray logits('c', {4}, sd::DataType::DOUBLE); - NDArray weights('c', {}, std::vector{0}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519}); - NDArray dLdwExp('c', {}, std::vector{0.}); - - logits.linspace(-0.08, 0.04); - weights = 0.5; + NDArray dLdpExp('c', {4}, {0.23521, 0.2448, -0.7452, 0.26519}); + NDArray dLdwExp('c', {}, std::vector{0.}); - sd::ops::softmax_cross_entropy_loss_grad op; + logits.linspace(-0.08, 0.04); + weights = 0.5; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); + sd::ops::softmax_cross_entropy_loss_grad op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test5) { + NDArray labels('c', {4}, {0, 0, 1, 0}, sd::DataType::INT32); + NDArray logits('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {1}, sd::DataType::DOUBLE); - NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32); - NDArray logits('c', {4}, sd::DataType::DOUBLE); - NDArray weights('c', {1}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {4}, {0.1176, 0.1224, -0.3726, 0.1326}); + NDArray dLdwExp('c', {1}, std::vector{1.36729}); - NDArray dLdpExp('c', {4}, {0.1176, 0.1224, -0.3726, 0.1326}); - NDArray dLdwExp('c', {1}, std::vector{1.36729}); + logits.linspace(-0.08, 0.04); + weights = 0.5; - logits.linspace(-0.08, 0.04); - weights = 0.5; + sd::ops::softmax_cross_entropy_loss_grad op; - sd::ops::softmax_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test6) { + NDArray labels('c', {2, 4}, {0, 0, 1, 0, 0, 1, 0, 0}, sd::DataType::INT32); + NDArray logits('c', {2, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2}, sd::DataType::DOUBLE); - NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, sd::DataType::INT32); - NDArray logits('c', {2,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,4}, {0.0801, 0.0849, -0.2601, 0.0951, 0.0801, -0.2651, 0.0899, 0.0951}); - NDArray dLdwExp('c', {2}, {-0.014000, 0.014000}); - - logits.linspace(-0.08, 0.04); - weights.assign(0.5); + NDArray dLdpExp( + 'c', {2, 4}, + {0.0801, 0.0849, -0.2601, 0.0951, 0.0801, -0.2651, 0.0899, 0.0951}); + NDArray dLdwExp('c', {2}, {-0.014000, 0.014000}); - sd::ops::softmax_cross_entropy_loss_grad op; + logits.linspace(-0.08, 0.04); + weights.assign(0.5); - auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + sd::ops::softmax_cross_entropy_loss_grad op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test7) { + NDArray labels('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, + 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0}, + sd::DataType::INT32); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3}, {0.5, 0., 1.5}); - NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0}, sd::DataType::INT32); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3}, {0.5, 0., 1.5}); + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.0956, 0.0306, 0.03185, 0.03315, 0., -0., 0., 0., + 0.0882, 0.0918, -0.27945, 0.09945, 0.0294, 0.0306, 0.03185, -0.09185, + -0., 0., 0., 0., 0.0882, -0.2832, 0.09555, 0.09945}); + NDArray dLdwExp('c', {1, 3}, {0.69365, 0.71365, 0.69365}); - NDArray dLdpExp('c', {2,3,4}, {-0.0956 , 0.0306 , 0.03185, 0.03315, 0.,-0., 0., 0., 0.0882 , 0.0918 ,-0.27945, 0.09945, - 0.0294 , 0.0306 , 0.03185,-0.09185,-0., 0., 0., 0., 0.0882 ,-0.2832 , 0.09555, 0.09945}); - NDArray dLdwExp('c', {1,3}, {0.69365, 0.71365, 0.69365}); + logits.linspace(-0.08, 0.04); - logits.linspace(-0.08, 0.04); + sd::ops::softmax_cross_entropy_loss_grad op; - sd::ops::softmax_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test8) { - - NDArray labels('c', {2,3,4,5}, {1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0, - 0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1, - 0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0}, sd::DataType::INT32); - - NDArray logits('c', {2,3,4,5}, sd::DataType::DOUBLE); - NDArray weights('c', {1,1,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4,5}, {-0.03399, 0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, - 0.00866, 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, - 0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, - 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, - 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, - 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832, - 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768, - 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832, 0.00866, - 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901}); - - NDArray dLdwExp('c', {1,1,4}, {0.005, 0.00167, -0.00167, -0.005}); - logits.linspace(-0.08, 0.04); - weights.assign(0.5); - - sd::ops::softmax_cross_entropy_loss_grad op; - - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - // dLdp->printIndexedBuffer(); - - // ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - // ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - + NDArray labels( + 'c', {2, 3, 4, 5}, + {1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, + 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, + 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, + 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0}, + sd::DataType::INT32); + + NDArray logits('c', {2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 1, 4}, sd::DataType::DOUBLE); + + NDArray dLdpExp( + 'c', {2, 3, 4, 5}, + {-0.03399, 0.00799, 0.00832, 0.00866, 0.00901, 0.00768, -0.03367, + 0.00832, 0.00866, 0.00901, 0.00768, 0.00799, -0.03335, 0.00866, + 0.00901, 0.00768, 0.00799, 0.00832, -0.03301, 0.00901, 0.00768, + 0.00799, 0.00832, 0.00866, -0.03265, -0.03399, 0.00799, 0.00832, + 0.00866, 0.00901, 0.00768, -0.03367, 0.00832, 0.00866, 0.00901, + 0.00768, 0.00799, -0.03335, 0.00866, 0.00901, 0.00768, 0.00799, + 0.00832, -0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866, + -0.03265, -0.03399, 0.00799, 0.00832, 0.00866, 0.00901, 0.00768, + -0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799, -0.03335, + 0.00866, 0.00901, 0.00768, 0.00799, 0.00832, -0.03301, 0.00901, + 0.00768, 0.00799, 0.00832, 0.00866, -0.03265, -0.03399, 0.00799, + 0.00832, 0.00866, 0.00901, 0.00768, -0.03367, 0.00832, 0.00866, + 0.00901, 0.00768, 0.00799, -0.03335, 0.00866, 0.00901, 0.00768, + 0.00799, 0.00832, -0.03301, 0.00901, 0.00768, 0.00799, 0.00832, + 0.00866, -0.03265, -0.03399, 0.00799, 0.00832, 0.00866, 0.00901, + 0.00768, -0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799, + -0.03335, 0.00866, 0.00901, 0.00768, 0.00799, 0.00832, -0.03301, + 0.00901}); + + NDArray dLdwExp('c', {1, 1, 4}, {0.005, 0.00167, -0.00167, -0.005}); + logits.linspace(-0.08, 0.04); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss_grad op; + + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + // dLdp->printIndexedBuffer(); + + // ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + // ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, SafeDivideMixed_Test1) { + NDArray labels('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); + auto sumDiff = labels.reduceAlongDimension(reduce::Sum, {1}, true); - NDArray labels('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); - auto sumDiff = labels.reduceAlongDimension(reduce::Sum, {1}, true); - - NDArray numOfNonZero(sumDiff.shapeInfo(), sd::DataType::INT64, false); - numOfNonZero.assign(1); - sumDiff.applyPairwiseTransform(pairwise::SafeDivide, numOfNonZero, sumDiff); + NDArray numOfNonZero(sumDiff.shapeInfo(), sd::DataType::INT64, false); + numOfNonZero.assign(1); + sumDiff.applyPairwiseTransform(pairwise::SafeDivide, numOfNonZero, sumDiff); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test1) { + NDArray labels('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, + 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0}); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0}); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.76479, 0.2448, 0.2548, 0.26519, 0.23521,-0.7552, 0.2548, 0.26519, 0.23521, 0.2448,-0.7452, 0.26519, - 0.23521, 0.2448, 0.2548,-0.73481,-0.76479, 0.2448, 0.2548, 0.26519, 0.23521,-0.7552, 0.2548, 0.26519}); - logits.linspace(-0.08, 0.04); - - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + NDArray dLdpExp( + 'c', {2, 3, 4}, + {-0.76479, 0.2448, 0.2548, 0.26519, 0.23521, -0.7552, 0.2548, 0.26519, + 0.23521, 0.2448, -0.7452, 0.26519, 0.23521, 0.2448, 0.2548, -0.73481, + -0.76479, 0.2448, 0.2548, 0.26519, 0.23521, -0.7552, 0.2548, 0.26519}); + logits.linspace(-0.08, 0.04); - auto results = op.evaluate({&logits, &labels}, {}, {}); + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&logits, &labels}, {}, {}); - auto dLdp = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test2) { + NDArray labels('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, + 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0}); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,1, 0,0,1,0, 0,0,0,1, 1,0,1,0, 0,1,0,0}); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {2, 3, 4}, + {-0.71836, 0.28164, 0.28164, 0.28164, 0.33051, -0.66949, + 0.33051, -0.66949, 0.38785, 0.38785, -0.61215, 0.38785, + 0.28164, 0.28164, 0.28164, -0.71836, -0.66949, 0.33051, + -0.66949, 0.33051, 0.38785, -0.61215, 0.38785, 0.38785}); + logits.linspace(-0.08, 0.04); - NDArray dLdpExp('c', {2,3,4}, {-0.71836, 0.28164, 0.28164, 0.28164, 0.33051, -0.66949, 0.33051, -0.66949, 0.38785, 0.38785, -0.61215, 0.38785, - 0.28164, 0.28164, 0.28164, -0.71836,-0.66949, 0.33051, -0.66949, 0.33051, 0.38785, -0.61215, 0.38785, 0.38785}); - logits.linspace(-0.08, 0.04); + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + auto results = op.evaluate({&logits, &labels}, {}, {1}); - auto results = op.evaluate({&logits, &labels}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test3) { + NDArray labels('c', {2, 3}, {1, 0, 0, 0, 1, 1}); + NDArray logits('c', {2, 3}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3}, {1,0,0, 0,1,1}); - NDArray logits('c', {2,3}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3}, {-0.52996, 0.47004, 0.47004, 0.52996, -0.47004, -0.47004}); - logits.linspace(-0.08, 0.04); - - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + NDArray dLdpExp('c', {2, 3}, + {-0.52996, 0.47004, 0.47004, 0.52996, -0.47004, -0.47004}); + logits.linspace(-0.08, 0.04); - auto results = op.evaluate({&logits, &labels}, {}, {0}); + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&logits, &labels}, {}, {0}); - auto dLdp = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test4) { + NDArray labels('c', {2, 1}, {1, 1}); + NDArray logits('c', {2, 1}, {-0.04, 0.04}); - NDArray labels('c', {2,1}, {1,1}); - NDArray logits('c', {2,1}, {-0.04, 0.04}); + NDArray dLdpExp('c', {2, 1}, {0., 0.}); - NDArray dLdpExp('c', {2,1}, {0., 0.}); + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + auto results = op.evaluate({&logits, &labels}, {}, {1}); - auto results = op.evaluate({&logits, &labels}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) { + NDArray labels('c', {2, 1}, std::vector{1, 0}); + NDArray logits('c', {2, 1}, {-0.04, 0.04}); - NDArray labels('c', {2,1}, std::vector{1,0}); - NDArray logits('c', {2,1}, {-0.04, 0.04}); - - NDArray dLdpExp('c', {2,1}, {-0.51999, 0.51999}); + NDArray dLdpExp('c', {2, 1}, {-0.51999, 0.51999}); - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.evaluate({&logits, &labels}, {}, {0}); + auto results = op.evaluate({&logits, &labels}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdp = results.at(0); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) { + NDArray labels('c', {1, 2}, {1, 1.}); + NDArray logits('c', {1, 2}, {-0.04, 0.04}); - NDArray labels('c', {1,2}, {1,1.}); - NDArray logits('c', {1,2}, {-0.04, 0.04}); - - NDArray dLdpExp('c', {1,2}, {0, 0.}); + NDArray dLdpExp('c', {1, 2}, {0, 0.}); - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.evaluate({&logits, &labels}, {}, {0}); + auto results = op.evaluate({&logits, &labels}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdp = results.at(0); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test7) { + NDArray labels('c', {2}, {0, 1}); + NDArray logits('c', {2}, {-0.04, 0.04}); - NDArray labels('c', {2}, {0,1}); - NDArray logits('c', {2}, {-0.04, 0.04}); - - NDArray dLdpExp('c', {2}, {0.48001, -0.48001}); + NDArray dLdpExp('c', {2}, {0.48001, -0.48001}); - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.evaluate({&logits, &labels}, {}, {0}); + auto results = op.evaluate({&logits, &labels}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdp = results.at(0); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) { + NDArray labels('c', {1}, std::vector{1}); + NDArray logits('c', {1}, std::vector{0.04}); - NDArray labels('c', {1}, std::vector{1}); - NDArray logits('c', {1}, std::vector{0.04}); - - NDArray dLdpExp('c', {1}, std::vector{0}); - - sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + NDArray dLdpExp('c', {1}, std::vector{0}); - auto results = op.evaluate({&logits, &labels}, {}, {0}); + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&logits, &labels}, {}, {0}); - auto dLdp = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Multiply_BP_Test1) { + NDArray x('c', {3, 4, 5}, sd::DataType::DOUBLE); + NDArray y('c', {1, 1, 1}, sd::DataType::DOUBLE); - NDArray x('c', {3,4,5}, sd::DataType::DOUBLE); - NDArray y('c', {1,1,1}, sd::DataType::DOUBLE); + NDArray dLdp('c', {3, 4, 5}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {3, 4, 5}, sd::DataType::DOUBLE); - NDArray dLdp('c', {3,4,5}, sd::DataType::DOUBLE); - NDArray dLdpExp('c', {3,4,5}, sd::DataType::DOUBLE); + x.assign(1.0); // linspace(0.1, 0.1); + y.assign(1.0); + dLdp.assign(1.0); + dLdpExp.assign(1.0); + sd::ops::multiply_bp op; - x.assign(1.0);//linspace(0.1, 0.1); - y.assign(1.0); - dLdp.assign(1.0); - dLdpExp.assign(1.0); - sd::ops::multiply_bp op; + auto results = op.evaluate({&x, &y, &dLdp}, {}, {}); - auto results = op.evaluate({&x, &y, &dLdp}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdo = results.at(0); - ASSERT_TRUE(dLdpExp.isSameShape(dLdo)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdo)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdo = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdo)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdo)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test1) { + NDArray labels('c', {2}, {2, 1}, sd::DataType::INT64); + NDArray logits('c', {2, 3}, sd::DataType::DOUBLE); - NDArray labels('c', {2}, {2,1}, sd::DataType::INT64); - NDArray logits('c', {2,3}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3}, {0.30061, 0.33222, -0.63283, 0.30061, -0.66778, 0.36717}); - - logits.linspace(0.1, 0.1); + NDArray dLdpExp('c', {2, 3}, + {0.30061, 0.33222, -0.63283, 0.30061, -0.66778, 0.36717}); - sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; + logits.linspace(0.1, 0.1); - auto results = op.evaluate({&labels, &logits}, {}, {}); + sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&labels, &logits}, {}, {}); - auto dLdp = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) { + NDArray labels('c', {2}, {0, 1}, sd::DataType::INT64); + NDArray logits('c', {2, 3}, sd::DataType::DOUBLE); - NDArray labels('c', {2}, {0,1}, sd::DataType::INT64); - NDArray logits('c', {2,3}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {2, 3}, + {-0.69939, 0.33222, 0.36717, 0.30061, -0.66778, 0.36717}); - NDArray dLdpExp('c', {2,3}, {-0.69939, 0.33222, 0.36717, 0.30061, -0.66778, 0.36717}); + logits.linspace(-0.1, 0.1); - logits.linspace(-0.1, 0.1); + sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; - sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; + auto results = op.evaluate({&labels, &logits}, {}, {}); - auto results = op.evaluate({&labels, &logits}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) { + NDArray labels('c', {}, std::vector{1}, sd::DataType::INT64); + NDArray logits('c', {2}, {-0.2, 0.3}); - NDArray labels('c', {}, std::vector{1}, sd::DataType::INT64); - NDArray logits('c', {2}, {-0.2, 0.3}); - - NDArray dLdpExp('c', {2}, {0.37754, -0.37754}); + NDArray dLdpExp('c', {2}, {0.37754, -0.37754}); - sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; + sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; - auto results = op.evaluate({&labels, &logits}, {}, {}); + auto results = op.evaluate({&labels, &logits}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdp = results.at(0); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test4) { + NDArray labels('c', {2, 3}, {0, 1, 1, 3, 3, 2}, sd::DataType::INT64); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3}, {0,1,1, 3,3,2}, sd::DataType::INT64); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {-0.78616, 0.23633, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865, - 0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, -0.73882, 0.28865}); - logits.linspace(-0.5, 0.1); - - sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; + NDArray dLdpExp('c', {2, 3, 4}, + {-0.78616, 0.23633, 0.26118, 0.28865, 0.21384, -0.76367, + 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865, + 0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, + 0.26118, -0.71135, 0.21384, 0.23633, -0.73882, 0.28865}); + logits.linspace(-0.5, 0.1); - auto results = op.evaluate({&labels, &logits}, {}, {}); + sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&labels, &logits}, {}, {}); - auto dLdp = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) { + NDArray labels('c', {1, 1}, std::vector({0}), sd::DataType::INT64); + NDArray logits('c', {1, 1, 2}, {-0.3, 0.2}); - NDArray labels('c', {1,1}, std::vector({0}), sd::DataType::INT64); - NDArray logits('c', {1,1,2}, {-0.3,0.2}); + NDArray dLdpExp('c', {1, 1, 2}, {-0.62246, 0.62246}); - NDArray dLdpExp('c', {1,1,2}, {-0.62246, 0.62246}); + sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; - sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; + auto results = op.evaluate({&labels, &logits}, {}, {}); - auto results = op.evaluate({&labels, &logits}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdp = results.at(0); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + auto dLdp = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); } - diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index a1cbb4319715..ecf544d603b3 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -14,3125 +14,3429 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // Created by raver on 8/4/2018. // -#include "testlayers.h" -#include #include -#include -#include #include -#include +#include #include +#include +#include +#include -using namespace sd; +#include "testlayers.h" +using namespace sd; class DeclarableOpsTests12 : public testing::Test { -public: - - DeclarableOpsTests12() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests12() { + printf("\n"); + fflush(stdout); + } }; TEST_F(DeclarableOpsTests12, test_any_validation_1) { - auto x = NDArrayFactory::create('c', {2, 1}, {1.0, 2.0}); - auto y = NDArrayFactory::create('c', {2}, {1, 0}); + auto x = NDArrayFactory::create('c', {2, 1}, {1.0, 2.0}); + auto y = NDArrayFactory::create('c', {2}, {1, 0}); - sd::ops::transpose op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::transpose op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - ASSERT_EQ(x.dataType(), z.dataType()); - - + auto z = result.at(0); + ASSERT_EQ(x.dataType(), z.dataType()); } - ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test1) { + NDArray labels('c', {2, 4}, {0, 1, 1, 0, 1, 0, 1, 0}); + NDArray predictions('c', {2, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,4}, {0,1,1,0,1,0,1,0}); - NDArray predictions('c', {2,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,1}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {2, 4}, {-0., -0.5, -0.5, -0., -0.5, -0., -0.5, -0.}); + NDArray dLdwExp('c', {2, 1}, {1.2, -0.2}); - NDArray dLdpExp('c', {2,4}, {-0. , -0.5, -0.5, -0., -0.5, -0. , -0.5, -0.}); - NDArray dLdwExp('c', {2,1}, {1.2, -0.2}); + predictions.linspace(-0.4, 0.2); + weights.assign(0.5); - predictions.linspace(-0.4, 0.2); - weights.assign(0.5); + sd::ops::cosine_distance_loss_grad op; - sd::ops::cosine_distance_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, -1}); - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test2) { + NDArray labels('c', {2, 4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2}); + NDArray predictions('c', {2, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 4}, sd::DataType::DOUBLE); - NDArray labels('c', {2,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2}); - NDArray predictions('c', {2,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,4}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,4}, {0.05, -0.15, -1. , 0.7 ,-1.25, 1.5 , -0.6 , -1.1 }); - NDArray dLdwExp('c', {1,4}, {-0.04, 2.86, 0.04, -0.92}); - NDArray dLdlExp('c', {2,4}, {0.2, 0.1, 0. , -0.1, -0.2, -0.3, -0.4, -0.5}); + NDArray dLdpExp('c', {2, 4}, {0.05, -0.15, -1., 0.7, -1.25, 1.5, -0.6, -1.1}); + NDArray dLdwExp('c', {1, 4}, {-0.04, 2.86, 0.04, -0.92}); + NDArray dLdlExp('c', {2, 4}, {0.2, 0.1, 0., -0.1, -0.2, -0.3, -0.4, -0.5}); - predictions.linspace(-0.4, 0.2); - weights.assign(0.5); + predictions.linspace(-0.4, 0.2); + weights.assign(0.5); - sd::ops::cosine_distance_loss_grad op; + sd::ops::cosine_distance_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test3) { + NDArray labels('c', {4}, {-0.1, 0.3, 2, -1.4}); + NDArray predictions('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {1}, sd::DataType::DOUBLE); - NDArray labels('c', {4}, {-0.1, 0.3, 2, -1.4}); - NDArray predictions('c', {4}, sd::DataType::DOUBLE); - NDArray weights('c', {1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {4}, {0.05, -0.15, -1., 0.7}); - NDArray dLdwExp('c', {1}, std::vector{1.3}); - NDArray dLdlExp('c', {4}, {0.2, 0.1, -0. , -0.1}); - - predictions.linspace(-0.4, 0.2); - weights.assign(0.5); + NDArray dLdpExp('c', {4}, {0.05, -0.15, -1., 0.7}); + NDArray dLdwExp('c', {1}, std::vector{1.3}); + NDArray dLdlExp('c', {4}, {0.2, 0.1, -0., -0.1}); - sd::ops::cosine_distance_loss_grad op; + predictions.linspace(-0.4, 0.2); + weights.assign(0.5); - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0}); + sd::ops::cosine_distance_loss_grad op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0}); - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test4) { + NDArray labels('c', {1, 4}, {-0.1, 0.3, 2, -1.4}); + NDArray predictions('c', {1, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {}, std::vector{0.}, sd::DataType::DOUBLE); - NDArray labels('c', {1,4}, {-0.1, 0.3, 2, -1.4}); - NDArray predictions('c', {1,4}, sd::DataType::DOUBLE); - NDArray weights('c', {}, std::vector{0.}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {1, 4}, {0.05, -0.15, -1., 0.7}); + NDArray dLdwExp('c', {}, std::vector{1.3}); + NDArray dLdlExp('c', {1, 4}, {0.2, 0.1, -0., -0.1}); - NDArray dLdpExp('c', {1,4}, {0.05, -0.15, -1., 0.7}); - NDArray dLdwExp('c', {}, std::vector{1.3}); - NDArray dLdlExp('c', {1,4}, {0.2, 0.1, -0. , -0.1}); + predictions.linspace(-0.4, 0.2); + weights.assign(0.5); - predictions.linspace(-0.4, 0.2); - weights.assign(0.5); + sd::ops::cosine_distance_loss_grad op; - sd::ops::cosine_distance_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1, 1}); - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1, 1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } - ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test5) { + NDArray labels('c', {4}, {-0.1, 0.3, 2, -1.4}, sd::DataType::DOUBLE); + NDArray predictions('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {4}, {-0.1, 0.3, 2, -1.4}, sd::DataType::DOUBLE); - NDArray predictions('c', {4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {4}, {0.1, -0.3, -2. , 1.4}); - NDArray dLdwExp('c', {1,1}, std::vector{0.}); - NDArray dLdlExp('c', {4}, {0.4, 0.2, -0. , -0.2}); - - predictions.linspace(-0.4, 0.2); - weights = 0.5; + NDArray dLdpExp('c', {4}, {0.1, -0.3, -2., 1.4}); + NDArray dLdwExp('c', {1, 1}, std::vector{0.}); + NDArray dLdlExp('c', {4}, {0.4, 0.2, -0., -0.2}); - sd::ops::cosine_distance_loss_grad op; + predictions.linspace(-0.4, 0.2); + weights = 0.5; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0}); + sd::ops::cosine_distance_loss_grad op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0}); - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test6) { + NDArray labels('c', {4, 1}, {-0.1, 0.3, 2, -1.4}, sd::DataType::DOUBLE); + NDArray predictions('c', {4, 1}, sd::DataType::DOUBLE); + NDArray weights('c', {4, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {4,1}, {-0.1, 0.3, 2, -1.4}, sd::DataType::DOUBLE); - NDArray predictions('c', {4,1}, sd::DataType::DOUBLE); - NDArray weights('c', {4,1}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {4, 1}, {0.0125, -0.0375, -0.25, 0.175}); + NDArray dLdwExp('c', {4, 1}, {0.24, 0.265, 0.25, 0.32}); + NDArray dLdlExp('c', {4, 1}, {0.05, 0.025, -0., -0.025}); - NDArray dLdpExp('c', {4,1}, {0.0125, -0.0375, -0.25 , 0.175}); - NDArray dLdwExp('c', {4,1}, {0.24 , 0.265, 0.25 , 0.32}); - NDArray dLdlExp('c', {4,1}, {0.05 , 0.025, -0. , -0.025}); + predictions.linspace(-0.4, 0.2); + weights = 0.5; - predictions.linspace(-0.4, 0.2); - weights = 0.5; + sd::ops::cosine_distance_loss_grad op; - sd::ops::cosine_distance_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); - - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test7) { + NDArray labels('c', {2, 3, 4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2, + -0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2, + -0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {1, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2,-0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2,-0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0.00833, -0.025 , -0.16667, 0.11667,-0.20833, 0.25 , -0.1 , -0.18333, 0.00833, -0.025 , -0.16667, 0.28333, - -0.20833, 0.25 , -0.1 , -0.18333, 0.01667, -0.025 , -0.16667, 0.11667,-0.225 , 0.25 , -0.1 , -0.35 }); - NDArray dLdwExp('c', {1,3,1}, {0.50444, 0.89778, -1.40222}); - NDArray dLdlExp('c', {2,3,4}, {0.03333, 0.01667, -0. , -0.01667,-0.03333, -0.05 , -0.06667, -0.08333,-0.1, -0.11667, -0.13333, -0.15, - -0.16667, -0.18333, -0.2 , -0.21667,-0.23333, -0.25 , -0.26667, -0.28333,-0.3, -0.31667, -0.33333, -0.35 }); + NDArray dLdpExp( + 'c', {2, 3, 4}, + {0.00833, -0.025, -0.16667, 0.11667, -0.20833, 0.25, -0.1, -0.18333, + 0.00833, -0.025, -0.16667, 0.28333, -0.20833, 0.25, -0.1, -0.18333, + 0.01667, -0.025, -0.16667, 0.11667, -0.225, 0.25, -0.1, -0.35}); + NDArray dLdwExp('c', {1, 3, 1}, {0.50444, 0.89778, -1.40222}); + NDArray dLdlExp('c', {2, 3, 4}, + {0.03333, 0.01667, -0., -0.01667, -0.03333, -0.05, + -0.06667, -0.08333, -0.1, -0.11667, -0.13333, -0.15, + -0.16667, -0.18333, -0.2, -0.21667, -0.23333, -0.25, + -0.26667, -0.28333, -0.3, -0.31667, -0.33333, -0.35}); - predictions.linspace(-0.4, 0.2); - weights = 0.5; + predictions.linspace(-0.4, 0.2); + weights = 0.5; - sd::ops::cosine_distance_loss_grad op; + sd::ops::cosine_distance_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test8) { + NDArray labels('c', {2, 3, 4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2, + -0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2, + -0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 1, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2,-0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2,-0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,1,1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0.00625, -0.01875, -0.125 , 0.0875,-0.15625, 0.1875 , -0.075 , -0.1375, 0.00625, -0.01875, -0.125 , 0.2125, - -0.15625, 0.1875 , -0.075 , -0.1375, 0.0125 , -0.01875, -0.125 , 0.0875,-0.16875, 0.1875 , -0.075 , -0.2625}); - NDArray dLdwExp('c', {2,1,1}, {0.57, -3.2175}); - NDArray dLdlExp('c', {2,3,4}, {0.025, 0.0125, -0. , -0.0125,-0.025, -0.0375, -0.05, -0.0625,-0.075, -0.0875, -0.1 , -0.1125, - -0.125, -0.1375, -0.15, -0.1625,-0.175, -0.1875, -0.2 , -0.2125,-0.225, -0.2375, -0.25, -0.2625}); + NDArray dLdpExp( + 'c', {2, 3, 4}, + {0.00625, -0.01875, -0.125, 0.0875, -0.15625, 0.1875, -0.075, -0.1375, + 0.00625, -0.01875, -0.125, 0.2125, -0.15625, 0.1875, -0.075, -0.1375, + 0.0125, -0.01875, -0.125, 0.0875, -0.16875, 0.1875, -0.075, -0.2625}); + NDArray dLdwExp('c', {2, 1, 1}, {0.57, -3.2175}); + NDArray dLdlExp( + 'c', {2, 3, 4}, + {0.025, 0.0125, -0., -0.0125, -0.025, -0.0375, -0.05, -0.0625, + -0.075, -0.0875, -0.1, -0.1125, -0.125, -0.1375, -0.15, -0.1625, + -0.175, -0.1875, -0.2, -0.2125, -0.225, -0.2375, -0.25, -0.2625}); - predictions.linspace(-0.4, 0.2); - weights = 0.5; + predictions.linspace(-0.4, 0.2); + weights = 0.5; - sd::ops::cosine_distance_loss_grad op; + sd::ops::cosine_distance_loss_grad op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); - - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) { + NDArray labels('c', {2, 3, 4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2, + -0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2, + -0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); + NDArray predictions('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {2, 3, 1}, sd::DataType::DOUBLE); - NDArray labels('c', {2,3,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2,-0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2,-0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); - NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE); - - NDArray dLdpExp('c', {2,3,4}, {0.05, -0.15, -1. , 0.7,-1.25, 1.5 , -0.6 , -1.1, 0.05, -0.15, -1. , 1.7, - -1.25, 1.5 , -0.6 , -1.1, 0.1 , -0.15, -1. , 0.7,-1.35, 1.5 , -0.6 , -2.1}); - NDArray dLdwExp('c', {2,3,1}, {1.3 , -1.36, 3.62, -6. , -0.98,-19.76}); - NDArray dLdlExp('c', {2,3,4}, {0.2, 0.1, -0. , -0.1,-0.2, -0.3, -0.4, -0.5,-0.6, -0.7, -0.8, -0.9, - -1. , -1.1, -1.2, -1.3,-1.4, -1.5, -1.6, -1.7,-1.8, -1.9, -2. , -2.1}); - - predictions.linspace(-0.4, 0.2); - weights = 0.5; + NDArray dLdpExp('c', {2, 3, 4}, + {0.05, -0.15, -1., 0.7, -1.25, 1.5, -0.6, -1.1, + 0.05, -0.15, -1., 1.7, -1.25, 1.5, -0.6, -1.1, + 0.1, -0.15, -1., 0.7, -1.35, 1.5, -0.6, -2.1}); + NDArray dLdwExp('c', {2, 3, 1}, {1.3, -1.36, 3.62, -6., -0.98, -19.76}); + NDArray dLdlExp( + 'c', {2, 3, 4}, + {0.2, 0.1, -0., -0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8, -0.9, + -1., -1.1, -1.2, -1.3, -1.4, -1.5, -1.6, -1.7, -1.8, -1.9, -2., -2.1}); - sd::ops::cosine_distance_loss_grad op; + predictions.linspace(-0.4, 0.2); + weights = 0.5; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 2}); + sd::ops::cosine_distance_loss_grad op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 2}); - auto dLdp = results.at(0); - auto dLdw = results.at(1); - auto dLdl = results.at(2); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); - ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); - ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); - ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); - ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); - ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); - + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); } - ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, hinge_loss_14) { + NDArray logits('c', {3, 4}, sd::DataType::DOUBLE); + NDArray weights('c', {}, std::vector{1.}); + NDArray labels('c', {3, 4}, {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0}); - NDArray logits('c', {3,4}, sd::DataType::DOUBLE); - NDArray weights('c', {}, std::vector{1.}); - NDArray labels('c', {3,4}, {0,1,1,0,1,0,1,0,1,0,1,0}); - - NDArray output('c', {}, std::vector{0.}, sd::DataType::DOUBLE); + NDArray output('c', {}, std::vector{0.}, sd::DataType::DOUBLE); - logits.linspace(1.); - weights.assign(1.); + logits.linspace(1.); + weights.assign(1.); - sd::ops::hinge_loss op; - Nd4jStatus status = op.execute({&logits, &weights, &labels}, {&output}, {}, {1}, {}); + sd::ops::hinge_loss op; + Nd4jStatus status = + op.execute({&logits, &weights, &labels}, {&output}, {}, {1}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output.e(0) == 47.); + ASSERT_TRUE(output.e(0) == 47.); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestDivideBP_1) { + NDArray x('c', {3, 4}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create(2.); + NDArray eps('c', {3, 4}, sd::DataType::DOUBLE); - NDArray x('c', {3,4}, sd::DataType::DOUBLE); - NDArray y = NDArrayFactory::create(2.); - NDArray eps('c', {3,4}, sd::DataType::DOUBLE); + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2(sd::DataType::DOUBLE); - NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); - NDArray output2(sd::DataType::DOUBLE); + x.linspace(2., 2.); + eps.linspace(1.); - x.linspace(2., 2.); - eps.linspace(1.); + sd::ops::divide_bp op; + Nd4jStatus status = + op.execute({&x, &y, &eps}, {&output1, &output2}, {}, {}, {}); - sd::ops::divide_bp op; - Nd4jStatus status = op.execute({&x, &y, &eps}, {&output1, &output2}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - //ASSERT_TRUE(output.e(0) == 47.); + ASSERT_EQ(ND4J_STATUS_OK, status); + // ASSERT_TRUE(output.e(0) == 47.); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestDivideBP_2) { - - NDArray x('c', {3,4}, sd::DataType::DOUBLE); - NDArray y = NDArrayFactory::create('c', {3,4}); - NDArray eps('c', {3,4}, sd::DataType::DOUBLE); - NDArray exp1('c', {3,4}, sd::DataType::DOUBLE); - NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); - NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); - NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); - exp1.assign(1.); - exp2.assign(-2.); - x.linspace(2., 2.); - y.linspace(1.); - eps.linspace(1.); - - sd::ops::divide_bp op; - Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector{&output1, &output2}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output1.equalsTo(exp1)); - ASSERT_TRUE(output2.equalsTo(exp2)); + NDArray x('c', {3, 4}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create('c', {3, 4}); + NDArray eps('c', {3, 4}, sd::DataType::DOUBLE); + NDArray exp1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray exp2('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + exp1.assign(1.); + exp2.assign(-2.); + x.linspace(2., 2.); + y.linspace(1.); + eps.linspace(1.); + + sd::ops::divide_bp op; + Nd4jStatus status = op.execute( + {&x, &y, &eps}, std::vector{&output1, &output2}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output1.equalsTo(exp1)); + ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestReverseDivideBP_1) { + NDArray x('c', {3, 4}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create(2.); + NDArray eps('c', {3, 4}, sd::DataType::DOUBLE); - NDArray x('c', {3,4}, sd::DataType::DOUBLE); - NDArray y = NDArrayFactory::create(2.); - NDArray eps('c', {3,4}, sd::DataType::DOUBLE); - - NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); - NDArray output2(sd::DataType::DOUBLE); + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2(sd::DataType::DOUBLE); - x.linspace(2., 2.); - eps.linspace(1.); + x.linspace(2., 2.); + eps.linspace(1.); - sd::ops::reversedivide_bp op; - Nd4jStatus status = op.execute({&y, &x, &eps}, std::vector{&output2, &output1}, {}, {}, {}); + sd::ops::reversedivide_bp op; + Nd4jStatus status = op.execute( + {&y, &x, &eps}, std::vector{&output2, &output1}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); - //ASSERT_TRUE(output.e(0) == 47.); + ASSERT_EQ(ND4J_STATUS_OK, status); + // ASSERT_TRUE(output.e(0) == 47.); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestReverseDivideBP_2) { + NDArray x('c', {3, 4}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create('c', {3, 4}); + NDArray eps('c', {3, 4}, sd::DataType::DOUBLE); + NDArray exp1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray exp2('c', {3, 4}, sd::DataType::DOUBLE); - NDArray x('c', {3,4}, sd::DataType::DOUBLE); - NDArray y = NDArrayFactory::create('c', {3,4}); - NDArray eps('c', {3,4}, sd::DataType::DOUBLE); - NDArray exp1('c', {3,4}, sd::DataType::DOUBLE); - NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); - NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); - NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + x.linspace(2., 2.); + y.linspace(1.); + eps.linspace(1.); + exp1.assign(1.); + exp2.assign(-2.); + sd::ops::reversedivide_bp op; + Nd4jStatus status = op.execute( + {&y, &x, &eps}, std::vector{&output2, &output1}, {}, {}, {}); - x.linspace(2., 2.); - y.linspace(1.); - eps.linspace(1.); - exp1.assign(1.); - exp2.assign(-2.); - sd::ops::reversedivide_bp op; - Nd4jStatus status = op.execute({&y, &x, &eps}, std::vector{&output2, &output1}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output1.equalsTo(exp1)); - ASSERT_TRUE(output2.equalsTo(exp2)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output1.equalsTo(exp1)); + ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestSliceBP_1) { - - NDArray x('c', {3,4}, sd::DataType::DOUBLE); - NDArray eps('c', {2,2}, sd::DataType::DOUBLE); - NDArray exp('c', {3,4}, {0., 0., 0., 0., 0., 1.,1., 0., 0., 1., 1., 0.}); - //NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); - - NDArray output('c', {3, 4}, sd::DataType::DOUBLE); - //NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); - output.assign(119.113); - x.linspace(1.); - eps.assign(1.); - //exp1.assign(1.); - //exp2.assign(-2.); - sd::ops::slice_bp op; - Nd4jStatus status = op.execute({&x, &eps}, {&output}, {}, {1,1,2,2}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output.equalsTo(exp)); - //ASSERT_TRUE(output2.equalsTo(exp2)); + NDArray x('c', {3, 4}, sd::DataType::DOUBLE); + NDArray eps('c', {2, 2}, sd::DataType::DOUBLE); + NDArray exp('c', {3, 4}, {0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0.}); + // NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); + + NDArray output('c', {3, 4}, sd::DataType::DOUBLE); + // NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + output.assign(119.113); + x.linspace(1.); + eps.assign(1.); + // exp1.assign(1.); + // exp2.assign(-2.); + sd::ops::slice_bp op; + Nd4jStatus status = op.execute({&x, &eps}, {&output}, {}, {1, 1, 2, 2}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); + // ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestConfusionZero_1) { - - NDArray x('c', {2}, {1,2}, sd::DataType::INT64); - NDArray i('c', {2}, {0,2}, sd::DataType::INT64); - //NDArray eps('c', {2,2}, sd::DataType::DOUBLE); - NDArray exp('c', {4,4}, {0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, sd::DataType::INT64); - //NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); - - NDArray output('c', {4, 4}, sd::DataType::INT64); - //NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); - output.assign(119.113); - x.linspace(1.); - //eps.assign(1.); - //exp1.assign(1.); - //exp2.assign(-2.); - sd::ops::confusion_matrix op; - Nd4jStatus status = op.execute({&x, &i}, {&output}, {}, {4}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output.equalsTo(exp)); - //ASSERT_TRUE(output2.equalsTo(exp2)); + NDArray x('c', {2}, {1, 2}, sd::DataType::INT64); + NDArray i('c', {2}, {0, 2}, sd::DataType::INT64); + // NDArray eps('c', {2,2}, sd::DataType::DOUBLE); + NDArray exp('c', {4, 4}, {0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, + sd::DataType::INT64); + // NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); + + NDArray output('c', {4, 4}, sd::DataType::INT64); + // NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + output.assign(119.113); + x.linspace(1.); + // eps.assign(1.); + // exp1.assign(1.); + // exp2.assign(-2.); + sd::ops::confusion_matrix op; + Nd4jStatus status = op.execute({&x, &i}, {&output}, {}, {4}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); + // ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestMaximumBP_1) { - - NDArray x('c', {3,4}, sd::DataType::DOUBLE); - NDArray y('c', {3,4}, sd::DataType::DOUBLE); - NDArray eps('c', {3,4}, sd::DataType::DOUBLE); - NDArray exp1('c', {3,4}, {0, 0, 0, 0, 0, 0, 7, 8, 9, 10, 11, 12}, sd::DataType::DOUBLE); - NDArray exp2('c', {3,4}, {1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0}, sd::DataType::DOUBLE); - - NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); - NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); - output1.assign(119); - x.linspace(1.); - y.linspace(12., -1.); - eps.linspace(1.); - //exp1.assign(1.); - //exp2.assign(-2.); - sd::ops::maximum_bp op; - Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector{&output1, &output2}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output1.equalsTo(exp1)); - ASSERT_TRUE(output2.equalsTo(exp2)); + NDArray x('c', {3, 4}, sd::DataType::DOUBLE); + NDArray y('c', {3, 4}, sd::DataType::DOUBLE); + NDArray eps('c', {3, 4}, sd::DataType::DOUBLE); + NDArray exp1('c', {3, 4}, {0, 0, 0, 0, 0, 0, 7, 8, 9, 10, 11, 12}, + sd::DataType::DOUBLE); + NDArray exp2('c', {3, 4}, {1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0}, + sd::DataType::DOUBLE); + + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + output1.assign(119); + x.linspace(1.); + y.linspace(12., -1.); + eps.linspace(1.); + // exp1.assign(1.); + // exp2.assign(-2.); + sd::ops::maximum_bp op; + Nd4jStatus status = op.execute( + {&x, &y, &eps}, std::vector{&output1, &output2}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output1.equalsTo(exp1)); + ASSERT_TRUE(output2.equalsTo(exp2)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TestMinimumBP_1) { - - NDArray x('c', {3,4}, sd::DataType::DOUBLE); - NDArray y('c', {3,4}, sd::DataType::DOUBLE); - NDArray eps('c', {3,4}, sd::DataType::DOUBLE); - NDArray exp1('c', {3,4}, {0, 0, 0, 0, 0, 0, 7, 8, 9, 10, 11, 12}, sd::DataType::DOUBLE); - NDArray exp2('c', {3,4}, {1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0}, sd::DataType::DOUBLE); - - NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); - NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); - output1.assign(119); - x.linspace(1.); - y.linspace(12., -1.); - eps.linspace(1.); - //exp1.assign(1.); - //exp2.assign(-2.); - sd::ops::minimum_bp op; - Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector{&output2, &output1}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output1.equalsTo(exp1)); - ASSERT_TRUE(output2.equalsTo(exp2)); + NDArray x('c', {3, 4}, sd::DataType::DOUBLE); + NDArray y('c', {3, 4}, sd::DataType::DOUBLE); + NDArray eps('c', {3, 4}, sd::DataType::DOUBLE); + NDArray exp1('c', {3, 4}, {0, 0, 0, 0, 0, 0, 7, 8, 9, 10, 11, 12}, + sd::DataType::DOUBLE); + NDArray exp2('c', {3, 4}, {1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0}, + sd::DataType::DOUBLE); + + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + output1.assign(119); + x.linspace(1.); + y.linspace(12., -1.); + eps.linspace(1.); + // exp1.assign(1.); + // exp2.assign(-2.); + sd::ops::minimum_bp op; + Nd4jStatus status = op.execute( + {&x, &y, &eps}, std::vector{&output2, &output1}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output1.equalsTo(exp1)); + ASSERT_TRUE(output2.equalsTo(exp2)); } - ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, reverse_test15) { + NDArray x('c', {5}, {1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray axis('c', {}, std::vector{0}, sd::DataType::INT32); + NDArray z('c', {5}, sd::DataType::DOUBLE); + NDArray exp('c', {5}, {5, 4, 3, 2, 1}, sd::DataType::DOUBLE); - NDArray x('c', {5}, {1,2,3,4,5}, sd::DataType::DOUBLE); - NDArray axis('c', {}, std::vector{0}, sd::DataType::INT32); - NDArray z('c', {5}, sd::DataType::DOUBLE); - NDArray exp('c', {5}, {5,4,3,2,1}, sd::DataType::DOUBLE); - + sd::ops::reverse op; + // auto result = op.execute({&x, &axis}, {}, {1}, {}); + Nd4jStatus status = op.execute({&x, &axis}, {&z}, {}, {1}, {}); + // auto z = result.at(0); + // z->printIndexedBuffer(); - sd::ops::reverse op; - // auto result = op.execute({&x, &axis}, {}, {1}, {}); - Nd4jStatus status = op.execute({&x, &axis}, {&z}, {}, {1}, {}); - // auto z = result.at(0); - // z->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - // + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + // } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, mirrorPad_test17) { + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray padding('c', {2, 2}, {1, 1, 2, 2}, sd::DataType::INT64); + NDArray z('c', {4, 7}, sd::DataType::DOUBLE); + NDArray exp1('c', {4, 7}, {6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, + 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}, + sd::DataType::DOUBLE); + NDArray exp2('c', {4, 7}, {2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, + 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}, + sd::DataType::DOUBLE); - NDArray x('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); - NDArray padding('c', {2,2}, {1,1,2,2}, sd::DataType::INT64); - NDArray z('c', {4,7}, sd::DataType::DOUBLE); - NDArray exp1('c', {4,7}, {6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1,6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1}, sd::DataType::DOUBLE); - NDArray exp2('c', {4,7}, {2, 1, 1, 2, 3, 3, 2,2, 1, 1, 2, 3, 3, 2,5, 4, 4, 5, 6, 6, 5,5, 4, 4, 5, 6, 6, 5}, sd::DataType::DOUBLE); - - sd::ops::mirror_pad op; - Nd4jStatus status = op.execute({&x, &padding}, {&z}, {}, {0}, {}); // reflect + sd::ops::mirror_pad op; + Nd4jStatus status = op.execute({&x, &padding}, {&z}, {}, {0}, {}); // reflect - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(exp1.isSameShape(z)); - ASSERT_TRUE(exp1.equalsTo(z)); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(exp1.isSameShape(z)); + ASSERT_TRUE(exp1.equalsTo(z)); - z = 0.; - status = op.execute({&x, &padding}, {&z}, {}, {1}, {}); // symmetric + z = 0.; + status = op.execute({&x, &padding}, {&z}, {}, {1}, {}); // symmetric - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(exp2.isSameShape(z)); - ASSERT_TRUE(exp2.equalsTo(z)); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(exp2.isSameShape(z)); + ASSERT_TRUE(exp2.equalsTo(z)); } ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, mirrorPad_test18) { + NDArray x('c', {3}, {1, 2, 3}, sd::DataType::DOUBLE); + NDArray padding('c', {1, 2}, {1, 1}, sd::DataType::INT32); + NDArray z('c', {5}, sd::DataType::DOUBLE); + NDArray exp('c', {5}, {2, 1, 2, 3, 2}, sd::DataType::DOUBLE); - NDArray x('c', {3}, {1,2,3}, sd::DataType::DOUBLE); - NDArray padding('c', {1, 2}, {1,1}, sd::DataType::INT32); - NDArray z('c', {5}, sd::DataType::DOUBLE); - NDArray exp('c', {5}, {2,1,2,3,2}, sd::DataType::DOUBLE); - - sd::ops::mirror_pad op; - Nd4jStatus status = op.execute({&x, &padding}, {&z}, {}, {0}, {}); // reflect + sd::ops::mirror_pad op; + Nd4jStatus status = op.execute({&x, &padding}, {&z}, {}, {0}, {}); // reflect - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, relu_1) { - - NDArray input('c', {1,5,5,6}, { 0.557449, 0.768277, 1.094015, -0.557449, -0.768277, -1.094015,0.563735, 0.900299, 0.789979, -0.563735, -0.900299, -0.789979, - 0.142528, 0.959611, 0.877506, -0.142528, -0.959611, -0.877506,0.448742, 0.995377, 1.171543, -0.448742, -0.995377, -1.171543, - 0.603772, 0.799391, 0.560310, -0.603772, -0.799391, -0.560310,0.529753, 0.906786, 0.737630, -0.529753, -0.906786, -0.737630, - 0.221464, 0.824996, 0.472221, -0.221464, -0.824996, -0.472221,0.427730, 0.397933, 0.714365, -0.427730, -0.397933, -0.714365, - 0.488365, 1.016589, 0.744197, -0.488365, -1.016589, -0.744197,0.789846, 0.940837, 0.838412, -0.789846, -0.940837, -0.838412, - 0.404485, 0.677328, 0.754997, -0.404485, -0.677328, -0.754997,0.436760, 0.794765, 0.729766, -0.436760, -0.794765, -0.729766, - 0.588081, 0.652226, 0.725522, -0.588081, -0.652226, -0.725522,0.374457, 1.225813, 1.053411, -0.374457, -1.225813, -1.053411, - 0.300958, 0.599417, 0.633234, -0.300958, -0.599417, -0.633234,0.241993, 1.025464, 0.695378, -0.241993, -1.025464, -0.695378, - 0.236289, 0.907919, 1.012100, -0.236289, -0.907919, -1.012100,0.627402, 0.565187, 0.766926, -0.627402, -0.565187, -0.766926, - 0.133276, 0.326284, 0.102804, -0.133276, -0.326284, -0.102804,0.426913, 0.256251, 0.305241, -0.426913, -0.256251, -0.305241, - 0.177977, 0.841799, 0.800615, -0.177977, -0.841799, -0.800615,0.001991, 0.518389, 0.439322, -0.001991, -0.518389, -0.439322, - 0.166846, 0.508224, 0.486687, -0.166846, -0.508224, -0.486687,0.167493, 0.930932, 0.868717, -0.167493, -0.930932, -0.868717, - 0.174864, 0.444607, 0.445000, -0.174864, -0.444607, -0.445000}, sd::DataType::FLOAT32); - - NDArray expected('c', {1,5,5,6}, { 0.557449, 0.768277, 1.094015, 0., 0., 0., 0.563735, 0.900299, 0.789979, 0., 0., 0., - 0.142528, 0.959611, 0.877506, 0., 0., 0., 0.448742, 0.995377, 1.171543, 0., 0., 0., - 0.603772, 0.799391, 0.560310, 0., 0., 0., 0.529753, 0.906786, 0.737630, 0., 0., 0., - 0.221464, 0.824996, 0.472221, 0., 0., 0., 0.427730, 0.397933, 0.714365, 0., 0., 0., - 0.488365, 1.016589, 0.744197, 0., 0., 0., 0.789846, 0.940837, 0.838412, 0., 0., 0., - 0.404485, 0.677328, 0.754997, 0., 0., 0., 0.436760, 0.794765, 0.729766, 0., 0., 0., - 0.588081, 0.652226, 0.725522, 0., 0., 0., 0.374457, 1.225813, 1.053411, 0., 0., 0., - 0.300958, 0.599417, 0.633234, 0., 0., 0., 0.241993, 1.025464, 0.695378, 0., 0., 0., - 0.236289, 0.907919, 1.012100, 0., 0., 0., 0.627402, 0.565187, 0.766926, 0., 0., 0., - 0.133276, 0.326284, 0.102804, 0., 0., 0., 0.426913, 0.256251, 0.305241, 0., 0., 0., - 0.177977, 0.841799, 0.800615, 0., 0., 0., 0.001991, 0.518389, 0.439322, 0., 0., 0., - 0.166846, 0.508224, 0.486687, 0., 0., 0., 0.167493, 0.930932, 0.868717, 0., 0., 0., - 0.174864, 0.444607, 0.445000, 0., 0., 0.}, sd::DataType::FLOAT32); - - NDArray z('c', {1,5,5,6}, sd::DataType::FLOAT32); - - sd::ops::relu op; - Nd4jStatus status = op.execute({&input}, {&z}, {0}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.isSameShapeStrict(z)); - ASSERT_TRUE(expected.equalsTo(z)); + NDArray input('c', {1, 5, 5, 6}, + {0.557449, 0.768277, 1.094015, -0.557449, -0.768277, -1.094015, + 0.563735, 0.900299, 0.789979, -0.563735, -0.900299, -0.789979, + 0.142528, 0.959611, 0.877506, -0.142528, -0.959611, -0.877506, + 0.448742, 0.995377, 1.171543, -0.448742, -0.995377, -1.171543, + 0.603772, 0.799391, 0.560310, -0.603772, -0.799391, -0.560310, + 0.529753, 0.906786, 0.737630, -0.529753, -0.906786, -0.737630, + 0.221464, 0.824996, 0.472221, -0.221464, -0.824996, -0.472221, + 0.427730, 0.397933, 0.714365, -0.427730, -0.397933, -0.714365, + 0.488365, 1.016589, 0.744197, -0.488365, -1.016589, -0.744197, + 0.789846, 0.940837, 0.838412, -0.789846, -0.940837, -0.838412, + 0.404485, 0.677328, 0.754997, -0.404485, -0.677328, -0.754997, + 0.436760, 0.794765, 0.729766, -0.436760, -0.794765, -0.729766, + 0.588081, 0.652226, 0.725522, -0.588081, -0.652226, -0.725522, + 0.374457, 1.225813, 1.053411, -0.374457, -1.225813, -1.053411, + 0.300958, 0.599417, 0.633234, -0.300958, -0.599417, -0.633234, + 0.241993, 1.025464, 0.695378, -0.241993, -1.025464, -0.695378, + 0.236289, 0.907919, 1.012100, -0.236289, -0.907919, -1.012100, + 0.627402, 0.565187, 0.766926, -0.627402, -0.565187, -0.766926, + 0.133276, 0.326284, 0.102804, -0.133276, -0.326284, -0.102804, + 0.426913, 0.256251, 0.305241, -0.426913, -0.256251, -0.305241, + 0.177977, 0.841799, 0.800615, -0.177977, -0.841799, -0.800615, + 0.001991, 0.518389, 0.439322, -0.001991, -0.518389, -0.439322, + 0.166846, 0.508224, 0.486687, -0.166846, -0.508224, -0.486687, + 0.167493, 0.930932, 0.868717, -0.167493, -0.930932, -0.868717, + 0.174864, 0.444607, 0.445000, -0.174864, -0.444607, -0.445000}, + sd::DataType::FLOAT32); + + NDArray expected( + 'c', {1, 5, 5, 6}, + {0.557449, 0.768277, 1.094015, 0., 0., 0., 0.563735, + 0.900299, 0.789979, 0., 0., 0., 0.142528, 0.959611, + 0.877506, 0., 0., 0., 0.448742, 0.995377, 1.171543, + 0., 0., 0., 0.603772, 0.799391, 0.560310, 0., + 0., 0., 0.529753, 0.906786, 0.737630, 0., 0., + 0., 0.221464, 0.824996, 0.472221, 0., 0., 0., + 0.427730, 0.397933, 0.714365, 0., 0., 0., 0.488365, + 1.016589, 0.744197, 0., 0., 0., 0.789846, 0.940837, + 0.838412, 0., 0., 0., 0.404485, 0.677328, 0.754997, + 0., 0., 0., 0.436760, 0.794765, 0.729766, 0., + 0., 0., 0.588081, 0.652226, 0.725522, 0., 0., + 0., 0.374457, 1.225813, 1.053411, 0., 0., 0., + 0.300958, 0.599417, 0.633234, 0., 0., 0., 0.241993, + 1.025464, 0.695378, 0., 0., 0., 0.236289, 0.907919, + 1.012100, 0., 0., 0., 0.627402, 0.565187, 0.766926, + 0., 0., 0., 0.133276, 0.326284, 0.102804, 0., + 0., 0., 0.426913, 0.256251, 0.305241, 0., 0., + 0., 0.177977, 0.841799, 0.800615, 0., 0., 0., + 0.001991, 0.518389, 0.439322, 0., 0., 0., 0.166846, + 0.508224, 0.486687, 0., 0., 0., 0.167493, 0.930932, + 0.868717, 0., 0., 0., 0.174864, 0.444607, 0.445000, + 0., 0., 0.}, + sd::DataType::FLOAT32); + + NDArray z('c', {1, 5, 5, 6}, sd::DataType::FLOAT32); + + sd::ops::relu op; + Nd4jStatus status = op.execute({&input}, {&z}, {0}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.isSameShapeStrict(z)); + ASSERT_TRUE(expected.equalsTo(z)); } #include "ops/declarable/helpers/multiUnique.h" //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, multiUnique_1) { + NDArray input1('c', {3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + sd::DataType::INT32); + NDArray input2('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + sd::DataType::INT32); + NDArray input3('c', {2, 3}, {10, 11, 12, 13, 14, 15}, sd::DataType::INT32); + NDArray input4('c', {1, 5}, {7, 8, 9, 10, 11}, sd::DataType::INT32); + NDArray input5('c', {5, 3}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + sd::DataType::INT32); - NDArray input1('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, sd::DataType::INT32); - NDArray input2('c', {3,4}, {1,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::INT32); - NDArray input3('c', {2,3}, {10,11,12,13,14,15}, sd::DataType::INT32); - NDArray input4('c', {1,5}, {7,8,9,10,11}, sd::DataType::INT32); - NDArray input5('c', {5,3}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, sd::DataType::INT32); - - //NDArray indices('c', {1}, {2}, sd::DataType::INT32); - //NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, sd::DataType::FLOAT32); + // NDArray indices('c', {1}, {2}, sd::DataType::INT32); + // NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, sd::DataType::FLOAT32); - std::vector arrayList({&input1, &input2, &input3, &input4, &input5}); + std::vector arrayList( + {&input1, &input2, &input3, &input4, &input5}); - ASSERT_FALSE(sd::ops::helpers::multiUnique(arrayList)); + ASSERT_FALSE(sd::ops::helpers::multiUnique(arrayList)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, multiUnique_2) { - - NDArray input1('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, sd::DataType::INT32); - NDArray input2('c', {3,4}, {21,22,23,24,25,26,27,28,29,210,211,212}, sd::DataType::INT32); - NDArray input3('c', {2,3}, {310,311,312,313,314,315}, sd::DataType::INT32); - NDArray input4('c', {1,5}, {47,48,49,410,411}, sd::DataType::INT32); - NDArray input5('c', {5,3}, {51,52,53,54,55,56,57,58,59,510,511,512,513,514,515}, sd::DataType::INT32); - - //NDArray indices('c', {1}, {2}, sd::DataType::INT32); - //NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, sd::DataType::FLOAT32); - - std::vector arrayList({&input1, &input2, &input3, &input4, &input5}); - ASSERT_TRUE(sd::ops::helpers::multiUnique(arrayList)); + NDArray input1('c', {3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + sd::DataType::INT32); + NDArray input2('c', {3, 4}, + {21, 22, 23, 24, 25, 26, 27, 28, 29, 210, 211, 212}, + sd::DataType::INT32); + NDArray input3('c', {2, 3}, {310, 311, 312, 313, 314, 315}, + sd::DataType::INT32); + NDArray input4('c', {1, 5}, {47, 48, 49, 410, 411}, sd::DataType::INT32); + NDArray input5( + 'c', {5, 3}, + {51, 52, 53, 54, 55, 56, 57, 58, 59, 510, 511, 512, 513, 514, 515}, + sd::DataType::INT32); + + // NDArray indices('c', {1}, {2}, sd::DataType::INT32); + // NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, sd::DataType::FLOAT32); + + std::vector arrayList( + {&input1, &input2, &input3, &input4, &input5}); + ASSERT_TRUE(sd::ops::helpers::multiUnique(arrayList)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, reduceMeanBp_4) { + NDArray x('c', {3, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + NDArray gradO('c', {5}, sd::DataType::DOUBLE); + NDArray exp('c', {3, 5}, sd::DataType::DOUBLE); - NDArray x('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}); - NDArray gradO('c', {5}, sd::DataType::DOUBLE); - NDArray exp('c', {3,5}, sd::DataType::DOUBLE); - - gradO = 1.; - exp = 0.333333; - - sd::ops::reduce_mean_bp op; - auto result = op.evaluate({&x, &gradO}, {}, {0}); - auto output = result.at(0); + gradO = 1.; + exp = 0.333333; - // output->printShapeInfo(); - // output->printIndexedBuffer(); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_mean_bp op; + auto result = op.evaluate({&x, &gradO}, {}, {0}); + auto output = result.at(0); - + // output->printShapeInfo(); + // output->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, reduceMeanBp_5) { + NDArray x('c', {3, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + NDArray gradO('c', {3}, sd::DataType::DOUBLE); + NDArray exp('c', {3, 5}, sd::DataType::DOUBLE); - NDArray x('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}); - NDArray gradO('c', {3}, sd::DataType::DOUBLE); - NDArray exp('c', {3,5}, sd::DataType::DOUBLE); + gradO = 1.; + exp = 0.2; - gradO = 1.; - exp = 0.2; + sd::ops::reduce_mean_bp op; + auto result = op.evaluate({&x, &gradO}, {}, {1}); + auto output = result.at(0); - sd::ops::reduce_mean_bp op; - auto result = op.evaluate({&x, &gradO}, {}, {1}); - auto output = result.at(0); - - // output->printShapeInfo(); - // output->printIndexedBuffer(); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + // output->printShapeInfo(); + // output->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, reduceSqnormBp_1) { + NDArray x('c', {8, 6, 4}, sd::DataType::DOUBLE); + NDArray gradO('c', {8, 6, 1}, sd::DataType::DOUBLE); - NDArray x('c', {8,6,4}, sd::DataType::DOUBLE); - NDArray gradO('c', {8,6,1}, sd::DataType::DOUBLE); - - sd::ops::reduce_sqnorm_bp op; - auto result = op.evaluate({&x, &gradO}, {1}, {2}); - ASSERT_EQ(Status::OK(), result.status()); - - + sd::ops::reduce_sqnorm_bp op; + auto result = op.evaluate({&x, &gradO}, {1}, {2}); + ASSERT_EQ(Status::OK(), result.status()); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pullRows_1) { + NDArray x('c', {5, 1}, {0, 1, 2, 3, 4}); + NDArray z('c', {4, 1}, sd::DataType::DOUBLE); + NDArray exp('c', {4, 1}, {0, 2, 3, 4}); - NDArray x('c', {5, 1}, {0,1,2,3,4}); - NDArray z('c', {4, 1}, sd::DataType::DOUBLE); - NDArray exp('c', {4, 1}, {0,2,3,4}); - - Nd4jLong indexes[] = {0,2,3,4}; - PointersManager pm(LaunchContext::defaultContext(), "pullRows"); - auto pidx = reinterpret_cast(pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); + Nd4jLong indexes[] = {0, 2, 3, 4}; + PointersManager pm(LaunchContext::defaultContext(), "pullRows"); + auto pidx = reinterpret_cast( + pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); - std::vector dims = {1}; + std::vector dims = {1}; - auto xTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dims); - auto zTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dims); + auto xTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x.shapeInfo(), dims); + auto zTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + z.shapeInfo(), dims); - Nd4jPointer nativeStart[2]; + Nd4jPointer nativeStart[2]; #ifdef __CUDABLAS__ - nativeStart[1] = (x.getContext()->getCudaStream()); + nativeStart[1] = (x.getContext()->getCudaStream()); #endif - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - pullRows(nativeStart, &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - 4, pidx, - xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), - zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + pullRows(nativeStart, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &zBuf, + z.shapeInfo(), z.specialShapeInfo(), 4, pidx, + xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), + zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); - ASSERT_TRUE(z.equalsTo(exp)); - pm.synchronize(); + ASSERT_TRUE(z.equalsTo(exp)); + pm.synchronize(); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pullRows_2) { + NDArray arr('f', {5, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + NDArray* y = new NDArray(arr.dup('c')); + NDArray x = (*y)({0, 0, 0, 1}, + true); // view, points on first column of y, shape is {5,1} - NDArray arr('f', {5, 2}, {0,1,2,3,4,5,6,7,8,9}); - NDArray* y = new NDArray(arr.dup('c')); - NDArray x = (*y)({0,0, 0,1}, true); // view, points on first column of y, shape is {5,1} + NDArray z('c', {4, 1}, sd::DataType::DOUBLE); + NDArray exp('c', {4, 1}, {0, 2, 3, 4}); - NDArray z('c', {4, 1}, sd::DataType::DOUBLE); - NDArray exp('c', {4, 1}, {0,2,3,4}); + Nd4jLong indexes[] = {0, 2, 3, 4}; + PointersManager pm(LaunchContext::defaultContext(), "pullRows"); + auto pidx = reinterpret_cast( + pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); - Nd4jLong indexes[] = {0,2,3,4}; - PointersManager pm(LaunchContext::defaultContext(), "pullRows"); - auto pidx = reinterpret_cast(pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); + std::vector dims = {1}; - std::vector dims = {1}; + auto xTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x.shapeInfo(), dims); + auto zTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + z.shapeInfo(), dims); - auto xTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dims); - auto zTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dims); - - Nd4jPointer nativeStart[2]; + Nd4jPointer nativeStart[2]; #ifdef __CUDABLAS__ - nativeStart[1] = (x.getContext()->getCudaStream()); + nativeStart[1] = (x.getContext()->getCudaStream()); #endif - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - pullRows(nativeStart, &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - 4, pidx, - xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), - zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + pullRows(nativeStart, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &zBuf, + z.shapeInfo(), z.specialShapeInfo(), 4, pidx, + xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), + zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); - ASSERT_TRUE(z.equalsTo(exp)); - pm.synchronize(); - delete y; + ASSERT_TRUE(z.equalsTo(exp)); + pm.synchronize(); + delete y; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, softmax_9) { - NDArray arrC('c', {5,2}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 1}, sd::DataType::FLOAT32); - NDArray* arrF = new NDArray(arrC.dup('f')); - - NDArray outCC('c', {5,2}, sd::DataType::FLOAT32); - NDArray outCF('f', {5,2}, sd::DataType::FLOAT32); - NDArray outFC('c', {5,2}, sd::DataType::FLOAT32); - NDArray outFF('c', {5,2}, sd::DataType::FLOAT32); - - sd::ops::softmax op; - auto status1 = op.execute({&arrC}, {&outCC}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status1); - auto status2 = op.execute({&arrC}, {&outCF}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status2); - auto status3 = op.execute({arrF}, {&outFC}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status3); - auto status4 = op.execute({arrF}, {&outFF}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status4); - - // outCC.printIndexedBuffer("\n"); - // outCF.printIndexedBuffer("\n"); - // outFC.printIndexedBuffer("\n"); - // outFF.printIndexedBuffer("\n"); - - ASSERT_EQ(outCC, outCF); - ASSERT_EQ(outCC, outFC); - ASSERT_EQ(outCC, outFF); - - delete arrF; + NDArray arrC('c', {5, 2}, + {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 1}, + sd::DataType::FLOAT32); + NDArray* arrF = new NDArray(arrC.dup('f')); + + NDArray outCC('c', {5, 2}, sd::DataType::FLOAT32); + NDArray outCF('f', {5, 2}, sd::DataType::FLOAT32); + NDArray outFC('c', {5, 2}, sd::DataType::FLOAT32); + NDArray outFF('c', {5, 2}, sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto status1 = op.execute({&arrC}, {&outCC}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status1); + auto status2 = op.execute({&arrC}, {&outCF}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status2); + auto status3 = op.execute({arrF}, {&outFC}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status3); + auto status4 = op.execute({arrF}, {&outFF}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status4); + + // outCC.printIndexedBuffer("\n"); + // outCF.printIndexedBuffer("\n"); + // outFC.printIndexedBuffer("\n"); + // outFF.printIndexedBuffer("\n"); + + ASSERT_EQ(outCC, outCF); + ASSERT_EQ(outCC, outFC); + ASSERT_EQ(outCC, outFF); + + delete arrF; } TEST_F(DeclarableOpsTests12, maxpool_bp_half_1) { - auto x = NDArrayFactory::create('c', {2, 3, 10, 1}, {0.2019043f, 0.6464844f, 0.9116211f, 0.60058594f, 0.34033203f, 0.7036133f, 0.6772461f, 0.3815918f, 0.87353516f, 0.04650879f, 0.67822266f, 0.8618164f, 0.88378906f, 0.7573242f, 0.66796875f, 0.63427734f, 0.33764648f, 0.46923828f, 0.62939453f, 0.76464844f, -0.8618164f, -0.94873047f, -0.9902344f, -0.88916016f, -0.86572266f, -0.92089844f, -0.90722656f, -0.96533203f, -0.97509766f, -0.4975586f, -0.84814453f, -0.984375f, -0.98828125f, -0.95458984f, -0.9472656f, -0.91064453f, -0.80859375f, -0.83496094f, -0.9140625f, -0.82470703f, 0.4802246f, 0.45361328f, 0.28125f, 0.28320312f, 0.79345703f, 0.44604492f, -0.30273438f, 0.11730957f, 0.56396484f, 0.73583984f, 0.1418457f, -0.44848633f, 0.6923828f, -0.40234375f, 0.40185547f, 0.48632812f, 0.14538574f, 0.4638672f, 0.13000488f, 0.5058594f}); - auto y = NDArrayFactory::create('c', {2, 3, 10, 1}, {0.0f, -0.13391113f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, -0.1751709f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.51904297f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.5107422f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); - auto z = NDArrayFactory::create('c', {2, 3, 10, 1}); - - sd::ops::maxpool2d_bp op; - Context ctx(1); - Nd4jLong iArgs[] = {5,1,1, 2,2,0, 1,1,1, 0,0}; - ctx.setIArguments(iArgs, 11); - ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - ctx.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); - + auto x = NDArrayFactory::create( + 'c', {2, 3, 10, 1}, + {0.2019043f, 0.6464844f, 0.9116211f, 0.60058594f, 0.34033203f, + 0.7036133f, 0.6772461f, 0.3815918f, 0.87353516f, 0.04650879f, + 0.67822266f, 0.8618164f, 0.88378906f, 0.7573242f, 0.66796875f, + 0.63427734f, 0.33764648f, 0.46923828f, 0.62939453f, 0.76464844f, + -0.8618164f, -0.94873047f, -0.9902344f, -0.88916016f, -0.86572266f, + -0.92089844f, -0.90722656f, -0.96533203f, -0.97509766f, -0.4975586f, + -0.84814453f, -0.984375f, -0.98828125f, -0.95458984f, -0.9472656f, + -0.91064453f, -0.80859375f, -0.83496094f, -0.9140625f, -0.82470703f, + 0.4802246f, 0.45361328f, 0.28125f, 0.28320312f, 0.79345703f, + 0.44604492f, -0.30273438f, 0.11730957f, 0.56396484f, 0.73583984f, + 0.1418457f, -0.44848633f, 0.6923828f, -0.40234375f, 0.40185547f, + 0.48632812f, 0.14538574f, 0.4638672f, 0.13000488f, 0.5058594f}); + auto y = NDArrayFactory::create( + 'c', {2, 3, 10, 1}, + {0.0f, -0.13391113f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, -0.1751709f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.51904297f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.5107422f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f}); + auto z = NDArrayFactory::create('c', {2, 3, 10, 1}); + + sd::ops::maxpool2d_bp op; + Context ctx(1); + Nd4jLong iArgs[] = {5, 1, 1, 2, 2, 0, 1, 1, 1, 0, 0}; + ctx.setIArguments(iArgs, 11); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + ctx.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_1) { - - NDArray input('c', {2,3,4,10}); - NDArray gradO('c', {2,3,4,10}); - NDArray exp('c', {2,3,4,10}, {1.00438418e-02, 5.25184907e-03, 1.78685773e-03, -1.14537543e-03, -4.00071684e-03, -5.31899510e-03, -4.97647980e-03, -4.42161644e-03, -3.95395281e-03, -3.59310722e-03, 2.91823584e-04, -2.18498681e-05, -3.12092161e-04, -6.07360795e-04, -9.36298165e-04, - -1.02553482e-03, -7.91735307e-04, -6.15672267e-04, -4.71792649e-04, -3.42114770e-04, 4.29357824e-05, -5.46473675e-05, -1.48361753e-04, -2.47166492e-04, -3.61090642e-04, -3.81607766e-04, -2.89086485e-04, -2.17203109e-04, -1.56231865e-04, -9.91634734e-05, - 8.99407951e-06, -3.76849275e-05, -8.32021178e-05, -1.31939698e-04, -1.89008832e-04, -1.96661276e-04, -1.47534331e-04, -1.08789405e-04, -7.53896020e-05, -4.36357586e-05, - 1.23124300e-06, -2.60028974e-05, -5.27824741e-05, -8.17063192e-05, -1.15871291e-04, -1.19515295e-04, -8.91248055e-05, -6.49499125e-05, -4.39216528e-05, -2.37579407e-05, -9.34046056e-07, -1.87477999e-05, -3.63574763e-05, -5.54830040e-05, -7.82010393e-05, - -8.02115537e-05, -5.95739621e-05, -4.30659420e-05, -2.86241393e-05, -1.47010251e-05, -1.52835810e-06, -1.40790498e-05, -2.65316012e-05, -4.01083526e-05, -5.62983550e-05, -5.75223821e-05, -4.25982689e-05, -3.06141737e-05, -2.00884024e-05, -9.90276021e-06, - -1.61666367e-06, -1.09328157e-05, -2.02010433e-05, -3.03347279e-05, -4.24536738e-05, -4.32532870e-05, -3.19610226e-05, -2.28673853e-05, -1.48570880e-05, -7.08444895e-06, - -1.53552355e-06, -8.72318924e-06, -1.58886232e-05, -2.37402273e-05, -3.31507035e-05, -3.37014644e-05, -2.48602537e-05, -1.77248403e-05, -1.14254890e-05, -5.30027773e-06, -1.40318230e-06, -7.11624580e-06, -1.28209140e-05, -1.90826468e-05, -2.66006646e-05, - -2.69959855e-05, -1.98865000e-05, -1.41387427e-05, -9.05554589e-06, -4.10473058e-06, -1.26330860e-06, -5.91293519e-06, -1.05618501e-05, -1.56718652e-05, -2.18157675e-05, -2.21090413e-05, -1.62681827e-05, -1.15394150e-05, -7.35144840e-06, -3.26711961e-06, - -1.13179840e-06, -4.98940426e-06, -8.85062400e-06, -1.30997241e-05, -1.82144904e-05, -1.84380206e-05, -1.35542105e-05, -9.59566933e-06, -6.08572736e-06, -2.65887866e-06, - -1.01367493e-06, -4.26561428e-06, -7.52358210e-06, -1.11123145e-05, -1.54364170e-05, -1.56106762e-05, -1.14666063e-05, -8.10436813e-06, -5.12021325e-06, -2.20401580e-06, -9.09635219e-07, -3.68808492e-06, -6.47385696e-06, -9.54499774e-06, -1.32485484e-05, - -1.33870126e-05, -9.82651000e-06, -6.93532820e-06, -4.36710525e-06, -1.85539375e-06, -8.18735487e-07, -3.22003825e-06, -5.62928972e-06, -8.28724023e-06, -1.14948289e-05, -1.16066676e-05, -8.51461300e-06, -6.00201292e-06, -3.76846447e-06, -1.58258263e-06, - -7.39498375e-07, -2.83553072e-06, -4.93973403e-06, -7.26259532e-06, -1.00675643e-05, -1.01591886e-05, -7.44886802e-06, -5.24508141e-06, -3.28481428e-06, -1.36524977e-06, - -6.70378654e-07, -2.51585061e-06, -4.36947221e-06, -6.41683391e-06, -8.89049170e-06, -8.96649362e-06, -6.57134478e-06, -4.62275193e-06, -2.88851857e-06, -1.18941352e-06, -6.09944266e-07, -2.24723408e-06, -3.89250545e-06, -5.71062310e-06, -7.90838203e-06, - -7.97212033e-06, -5.84020108e-06, -4.10491293e-06, -2.55976192e-06, -1.04521314e-06, -5.56935277e-07, -2.01937837e-06, -3.48954882e-06, -5.11487451e-06, -7.08044308e-06, -7.13442114e-06, -5.22460778e-06, -3.66942504e-06, -2.28403951e-06, -9.25535005e-07, - -5.10270809e-07, -1.82444705e-06, -3.14605040e-06, -4.60769843e-06, -6.37601988e-06, -6.42213308e-06, -4.70144141e-06, -3.29971408e-06, -2.05053857e-06, -8.25151346e-07, - -4.69036365e-07, -1.65639949e-06, -2.85086708e-06, -4.17237243e-06, -5.77171340e-06, -5.81141694e-06, -4.25308644e-06, -2.98317354e-06, -1.85106614e-06, -7.40148607e-07, -4.32460268e-07, -1.51051631e-06, -2.59534818e-06, -3.79594053e-06, -5.24941379e-06, - -5.28384317e-06, -3.86593183e-06, -2.71007866e-06, -1.67932183e-06, -6.67554332e-07, -3.99893480e-07, -1.38306928e-06, -2.37269478e-06, -3.46823890e-06, -4.79492701e-06, -4.82497671e-06, -3.52932648e-06, -2.47282924e-06, -1.53039912e-06, -6.05077048e-07, - -3.70789934e-07, -1.27108103e-06, -2.17750403e-06, -3.18120783e-06, -4.39700398e-06, -4.42338614e-06, -3.23483960e-06, -2.26541715e-06, -1.40042869e-06, -5.50929371e-07}); - input.linspace(1); - gradO = 1; - - sd::ops::lrn_bp op; - - auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {5}); - auto gradI = results.at(0); - - ASSERT_EQ(gradI, exp); - - + NDArray input('c', {2, 3, 4, 10}); + NDArray gradO('c', {2, 3, 4, 10}); + NDArray exp( + 'c', {2, 3, 4, 10}, + {1.00438418e-02, 5.25184907e-03, 1.78685773e-03, -1.14537543e-03, + -4.00071684e-03, -5.31899510e-03, -4.97647980e-03, -4.42161644e-03, + -3.95395281e-03, -3.59310722e-03, 2.91823584e-04, -2.18498681e-05, + -3.12092161e-04, -6.07360795e-04, -9.36298165e-04, -1.02553482e-03, + -7.91735307e-04, -6.15672267e-04, -4.71792649e-04, -3.42114770e-04, + 4.29357824e-05, -5.46473675e-05, -1.48361753e-04, -2.47166492e-04, + -3.61090642e-04, -3.81607766e-04, -2.89086485e-04, -2.17203109e-04, + -1.56231865e-04, -9.91634734e-05, 8.99407951e-06, -3.76849275e-05, + -8.32021178e-05, -1.31939698e-04, -1.89008832e-04, -1.96661276e-04, + -1.47534331e-04, -1.08789405e-04, -7.53896020e-05, -4.36357586e-05, + 1.23124300e-06, -2.60028974e-05, -5.27824741e-05, -8.17063192e-05, + -1.15871291e-04, -1.19515295e-04, -8.91248055e-05, -6.49499125e-05, + -4.39216528e-05, -2.37579407e-05, -9.34046056e-07, -1.87477999e-05, + -3.63574763e-05, -5.54830040e-05, -7.82010393e-05, -8.02115537e-05, + -5.95739621e-05, -4.30659420e-05, -2.86241393e-05, -1.47010251e-05, + -1.52835810e-06, -1.40790498e-05, -2.65316012e-05, -4.01083526e-05, + -5.62983550e-05, -5.75223821e-05, -4.25982689e-05, -3.06141737e-05, + -2.00884024e-05, -9.90276021e-06, -1.61666367e-06, -1.09328157e-05, + -2.02010433e-05, -3.03347279e-05, -4.24536738e-05, -4.32532870e-05, + -3.19610226e-05, -2.28673853e-05, -1.48570880e-05, -7.08444895e-06, + -1.53552355e-06, -8.72318924e-06, -1.58886232e-05, -2.37402273e-05, + -3.31507035e-05, -3.37014644e-05, -2.48602537e-05, -1.77248403e-05, + -1.14254890e-05, -5.30027773e-06, -1.40318230e-06, -7.11624580e-06, + -1.28209140e-05, -1.90826468e-05, -2.66006646e-05, -2.69959855e-05, + -1.98865000e-05, -1.41387427e-05, -9.05554589e-06, -4.10473058e-06, + -1.26330860e-06, -5.91293519e-06, -1.05618501e-05, -1.56718652e-05, + -2.18157675e-05, -2.21090413e-05, -1.62681827e-05, -1.15394150e-05, + -7.35144840e-06, -3.26711961e-06, -1.13179840e-06, -4.98940426e-06, + -8.85062400e-06, -1.30997241e-05, -1.82144904e-05, -1.84380206e-05, + -1.35542105e-05, -9.59566933e-06, -6.08572736e-06, -2.65887866e-06, + -1.01367493e-06, -4.26561428e-06, -7.52358210e-06, -1.11123145e-05, + -1.54364170e-05, -1.56106762e-05, -1.14666063e-05, -8.10436813e-06, + -5.12021325e-06, -2.20401580e-06, -9.09635219e-07, -3.68808492e-06, + -6.47385696e-06, -9.54499774e-06, -1.32485484e-05, -1.33870126e-05, + -9.82651000e-06, -6.93532820e-06, -4.36710525e-06, -1.85539375e-06, + -8.18735487e-07, -3.22003825e-06, -5.62928972e-06, -8.28724023e-06, + -1.14948289e-05, -1.16066676e-05, -8.51461300e-06, -6.00201292e-06, + -3.76846447e-06, -1.58258263e-06, -7.39498375e-07, -2.83553072e-06, + -4.93973403e-06, -7.26259532e-06, -1.00675643e-05, -1.01591886e-05, + -7.44886802e-06, -5.24508141e-06, -3.28481428e-06, -1.36524977e-06, + -6.70378654e-07, -2.51585061e-06, -4.36947221e-06, -6.41683391e-06, + -8.89049170e-06, -8.96649362e-06, -6.57134478e-06, -4.62275193e-06, + -2.88851857e-06, -1.18941352e-06, -6.09944266e-07, -2.24723408e-06, + -3.89250545e-06, -5.71062310e-06, -7.90838203e-06, -7.97212033e-06, + -5.84020108e-06, -4.10491293e-06, -2.55976192e-06, -1.04521314e-06, + -5.56935277e-07, -2.01937837e-06, -3.48954882e-06, -5.11487451e-06, + -7.08044308e-06, -7.13442114e-06, -5.22460778e-06, -3.66942504e-06, + -2.28403951e-06, -9.25535005e-07, -5.10270809e-07, -1.82444705e-06, + -3.14605040e-06, -4.60769843e-06, -6.37601988e-06, -6.42213308e-06, + -4.70144141e-06, -3.29971408e-06, -2.05053857e-06, -8.25151346e-07, + -4.69036365e-07, -1.65639949e-06, -2.85086708e-06, -4.17237243e-06, + -5.77171340e-06, -5.81141694e-06, -4.25308644e-06, -2.98317354e-06, + -1.85106614e-06, -7.40148607e-07, -4.32460268e-07, -1.51051631e-06, + -2.59534818e-06, -3.79594053e-06, -5.24941379e-06, -5.28384317e-06, + -3.86593183e-06, -2.71007866e-06, -1.67932183e-06, -6.67554332e-07, + -3.99893480e-07, -1.38306928e-06, -2.37269478e-06, -3.46823890e-06, + -4.79492701e-06, -4.82497671e-06, -3.52932648e-06, -2.47282924e-06, + -1.53039912e-06, -6.05077048e-07, -3.70789934e-07, -1.27108103e-06, + -2.17750403e-06, -3.18120783e-06, -4.39700398e-06, -4.42338614e-06, + -3.23483960e-06, -2.26541715e-06, -1.40042869e-06, -5.50929371e-07}); + input.linspace(1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {5}); + auto gradI = results.at(0); + + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_2) { - - NDArray input('c', {2,3,4,10}); - NDArray gradO('c', {2,3,4,10}); - NDArray exp('c', {2,3,4,10}, {-1.06179598e-03, -2.70050880e-03, -4.02126182e-03, -2.58826977e-03, -2.16024881e-03, -2.20575323e-03, -2.75954953e-03, -4.42477595e-03, -2.89176637e-03, -9.46942251e-04, -1.32603094e-03, -3.34868953e-03, -4.98152524e-03, -3.21313459e-03, -2.68880837e-03, -2.75207381e-03, -3.45109636e-03, -5.54159656e-03, -3.61320702e-03, -1.16457068e-03, - -1.70158676e-03, -4.26037982e-03, -6.33032294e-03, -4.09416296e-03, -3.43742501e-03, -3.52900685e-03, -4.43827361e-03, -7.13911094e-03, -4.64041065e-03, -1.46419462e-03, -2.26016506e-03, -5.59943309e-03, -8.30824208e-03, -5.39253885e-03, -4.54709725e-03, -4.68666852e-03, -5.91615774e-03, -9.53640230e-03, -6.17204653e-03, -1.89000927e-03, - -3.14102764e-03, -7.67878769e-03, -1.13740638e-02, -7.41857197e-03, -6.29213545e-03, -6.51977258e-03, -8.27047508e-03, -1.33656031e-02, -8.59564263e-03, -2.51553906e-03, -4.64272872e-03, -1.11560747e-02, -1.64905936e-02, -1.08321551e-02, -9.26420093e-03, -9.67171416e-03, -1.23506878e-02, -2.00199075e-02, -1.27442302e-02, -3.45497206e-03, - -7.49545777e-03, -1.76018942e-02, -2.59558801e-02, -1.72390267e-02, -1.49321631e-02, -1.57669969e-02, -2.03234926e-02, -3.30405571e-02, -2.06389092e-02, -4.78462130e-03, -1.38390735e-02, -3.14943902e-02, -4.63354364e-02, -3.13667879e-02, -2.77508944e-02, -2.98541505e-02, -3.89749333e-02, -6.32867143e-02, -3.77952419e-02, -5.26650995e-03, - -3.16195861e-02, -6.90807998e-02, -1.01725549e-01, -7.13700354e-02, -6.54785037e-02, -7.25797564e-02, -9.49372798e-02, -1.47399038e-01, -7.21285641e-02, 2.15010419e-02, -8.06625858e-02, -1.79638922e-01, -2.66877055e-01, -1.64447501e-01, -1.00968637e-01, -2.75682062e-02, 1.13596700e-01, 3.32260162e-01, 5.96845448e-01, 8.13161016e-01, - 9.52381015e-01, 8.13161016e-01, 5.96845508e-01, 3.32260162e-01, 1.13596708e-01, -2.75682174e-02, -1.37202948e-01, -2.71326721e-01, -1.84127048e-01, -7.94974267e-02, 3.29870060e-02, -7.39035010e-02, -1.60488203e-01, -1.04997143e-01, -8.06594491e-02, -7.25797564e-02, -7.87955597e-02, -1.11791104e-01, -7.58660138e-02, -3.48676592e-02, - -4.96974029e-03, -4.04525958e-02, -6.82792515e-02, -4.20900472e-02, -3.21968049e-02, -2.98541524e-02, -3.36477235e-02, -4.95737195e-02, -3.37007530e-02, -1.48636252e-02, -4.92655952e-03, -2.17927732e-02, -3.49853337e-02, -2.15152260e-02, -1.66727621e-02, -1.57669988e-02, -1.81730352e-02, -2.73226351e-02, -1.85334161e-02, -7.91355036e-03, - -3.57114570e-03, -1.33136865e-02, -2.09431648e-02, -1.29161589e-02, -1.01064872e-02, -9.67171136e-03, -1.12970043e-02, -1.71830691e-02, -1.16271935e-02, -4.84848116e-03, -2.59314431e-03, -8.91274121e-03, -1.38697922e-02, -8.58002994e-03, -6.75992295e-03, -6.51977304e-03, -7.68158771e-03, -1.17703741e-02, -7.94785097e-03, -3.25604435e-03, - -1.94202550e-03, -6.36530807e-03, -9.84015409e-03, -6.10316684e-03, -4.83274320e-03, -4.68666898e-03, -5.55526093e-03, -8.55536573e-03, -5.76688722e-03, -2.33053416e-03, -1.50016253e-03, -4.76644421e-03, -7.33569637e-03, -4.55961144e-03, -3.62428720e-03, -3.52900638e-03, -4.20164689e-03, -6.49448857e-03, -4.37143166e-03, -1.74761284e-03, - -1.19028054e-03, -3.69978836e-03, -5.67591935e-03, -3.53418733e-03, -2.81759514e-03, -2.75207404e-03, -3.28776496e-03, -5.09600528e-03, -3.42601724e-03, -1.35771628e-03, -9.65878542e-04, -2.95373448e-03, -4.52052988e-03, -2.81889434e-03, -2.25270819e-03, -2.20575323e-03, -2.64216494e-03, -4.10421193e-03, -2.75646802e-03, -1.08450721e-03, - -7.98697409e-04, -2.41194153e-03, -3.68447183e-03, -2.30037421e-03, -1.84193184e-03, -1.80714857e-03, -2.16938392e-03, -3.37567786e-03, -2.26523401e-03, -8.85842834e-04, -6.71049987e-04, -2.00629188e-03, -3.06024216e-03, -1.91263494e-03, -1.53396139e-03, -1.50748459e-03, -1.81288645e-03, -2.82496959e-03, -1.89429161e-03, -7.36965681e-04, - -5.71501616e-04, -1.69480499e-03, -2.58198148e-03, -1.61517004e-03, -1.29717519e-03, -1.27655920e-03, -1.53747783e-03, -2.39865575e-03, -1.60740130e-03, -6.22576685e-04, -4.92433901e-04, -1.45049067e-03, -2.20754091e-03, -1.38200901e-03, -1.11122860e-03, -1.09486456e-03, -1.32032647e-03, -2.06194492e-03, -1.38099224e-03, -5.32818493e-04}); - - input.linspace(-10, 0.1); - gradO = 1; - - sd::ops::lrn_bp op; - - auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {2}); - auto gradI = results.at(0); - - ASSERT_EQ(gradI, exp); - - + NDArray input('c', {2, 3, 4, 10}); + NDArray gradO('c', {2, 3, 4, 10}); + NDArray exp( + 'c', {2, 3, 4, 10}, + {-1.06179598e-03, -2.70050880e-03, -4.02126182e-03, -2.58826977e-03, + -2.16024881e-03, -2.20575323e-03, -2.75954953e-03, -4.42477595e-03, + -2.89176637e-03, -9.46942251e-04, -1.32603094e-03, -3.34868953e-03, + -4.98152524e-03, -3.21313459e-03, -2.68880837e-03, -2.75207381e-03, + -3.45109636e-03, -5.54159656e-03, -3.61320702e-03, -1.16457068e-03, + -1.70158676e-03, -4.26037982e-03, -6.33032294e-03, -4.09416296e-03, + -3.43742501e-03, -3.52900685e-03, -4.43827361e-03, -7.13911094e-03, + -4.64041065e-03, -1.46419462e-03, -2.26016506e-03, -5.59943309e-03, + -8.30824208e-03, -5.39253885e-03, -4.54709725e-03, -4.68666852e-03, + -5.91615774e-03, -9.53640230e-03, -6.17204653e-03, -1.89000927e-03, + -3.14102764e-03, -7.67878769e-03, -1.13740638e-02, -7.41857197e-03, + -6.29213545e-03, -6.51977258e-03, -8.27047508e-03, -1.33656031e-02, + -8.59564263e-03, -2.51553906e-03, -4.64272872e-03, -1.11560747e-02, + -1.64905936e-02, -1.08321551e-02, -9.26420093e-03, -9.67171416e-03, + -1.23506878e-02, -2.00199075e-02, -1.27442302e-02, -3.45497206e-03, + -7.49545777e-03, -1.76018942e-02, -2.59558801e-02, -1.72390267e-02, + -1.49321631e-02, -1.57669969e-02, -2.03234926e-02, -3.30405571e-02, + -2.06389092e-02, -4.78462130e-03, -1.38390735e-02, -3.14943902e-02, + -4.63354364e-02, -3.13667879e-02, -2.77508944e-02, -2.98541505e-02, + -3.89749333e-02, -6.32867143e-02, -3.77952419e-02, -5.26650995e-03, + -3.16195861e-02, -6.90807998e-02, -1.01725549e-01, -7.13700354e-02, + -6.54785037e-02, -7.25797564e-02, -9.49372798e-02, -1.47399038e-01, + -7.21285641e-02, 2.15010419e-02, -8.06625858e-02, -1.79638922e-01, + -2.66877055e-01, -1.64447501e-01, -1.00968637e-01, -2.75682062e-02, + 1.13596700e-01, 3.32260162e-01, 5.96845448e-01, 8.13161016e-01, + 9.52381015e-01, 8.13161016e-01, 5.96845508e-01, 3.32260162e-01, + 1.13596708e-01, -2.75682174e-02, -1.37202948e-01, -2.71326721e-01, + -1.84127048e-01, -7.94974267e-02, 3.29870060e-02, -7.39035010e-02, + -1.60488203e-01, -1.04997143e-01, -8.06594491e-02, -7.25797564e-02, + -7.87955597e-02, -1.11791104e-01, -7.58660138e-02, -3.48676592e-02, + -4.96974029e-03, -4.04525958e-02, -6.82792515e-02, -4.20900472e-02, + -3.21968049e-02, -2.98541524e-02, -3.36477235e-02, -4.95737195e-02, + -3.37007530e-02, -1.48636252e-02, -4.92655952e-03, -2.17927732e-02, + -3.49853337e-02, -2.15152260e-02, -1.66727621e-02, -1.57669988e-02, + -1.81730352e-02, -2.73226351e-02, -1.85334161e-02, -7.91355036e-03, + -3.57114570e-03, -1.33136865e-02, -2.09431648e-02, -1.29161589e-02, + -1.01064872e-02, -9.67171136e-03, -1.12970043e-02, -1.71830691e-02, + -1.16271935e-02, -4.84848116e-03, -2.59314431e-03, -8.91274121e-03, + -1.38697922e-02, -8.58002994e-03, -6.75992295e-03, -6.51977304e-03, + -7.68158771e-03, -1.17703741e-02, -7.94785097e-03, -3.25604435e-03, + -1.94202550e-03, -6.36530807e-03, -9.84015409e-03, -6.10316684e-03, + -4.83274320e-03, -4.68666898e-03, -5.55526093e-03, -8.55536573e-03, + -5.76688722e-03, -2.33053416e-03, -1.50016253e-03, -4.76644421e-03, + -7.33569637e-03, -4.55961144e-03, -3.62428720e-03, -3.52900638e-03, + -4.20164689e-03, -6.49448857e-03, -4.37143166e-03, -1.74761284e-03, + -1.19028054e-03, -3.69978836e-03, -5.67591935e-03, -3.53418733e-03, + -2.81759514e-03, -2.75207404e-03, -3.28776496e-03, -5.09600528e-03, + -3.42601724e-03, -1.35771628e-03, -9.65878542e-04, -2.95373448e-03, + -4.52052988e-03, -2.81889434e-03, -2.25270819e-03, -2.20575323e-03, + -2.64216494e-03, -4.10421193e-03, -2.75646802e-03, -1.08450721e-03, + -7.98697409e-04, -2.41194153e-03, -3.68447183e-03, -2.30037421e-03, + -1.84193184e-03, -1.80714857e-03, -2.16938392e-03, -3.37567786e-03, + -2.26523401e-03, -8.85842834e-04, -6.71049987e-04, -2.00629188e-03, + -3.06024216e-03, -1.91263494e-03, -1.53396139e-03, -1.50748459e-03, + -1.81288645e-03, -2.82496959e-03, -1.89429161e-03, -7.36965681e-04, + -5.71501616e-04, -1.69480499e-03, -2.58198148e-03, -1.61517004e-03, + -1.29717519e-03, -1.27655920e-03, -1.53747783e-03, -2.39865575e-03, + -1.60740130e-03, -6.22576685e-04, -4.92433901e-04, -1.45049067e-03, + -2.20754091e-03, -1.38200901e-03, -1.11122860e-03, -1.09486456e-03, + -1.32032647e-03, -2.06194492e-03, -1.38099224e-03, -5.32818493e-04}); + + input.linspace(-10, 0.1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {2}); + auto gradI = results.at(0); + + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_3) { - - NDArray input('c', {2,3,4,10}); - NDArray gradO('c', {2,3,4,10}); - NDArray exp('c', {2,3,4,10}, {-6.78180193e-04, -1.06947345e-03, -1.50362519e-03, -1.47711602e-03, -1.45060697e-03, -1.42409769e-03, -1.39758852e-03, -1.37107936e-03, -8.79839936e-04, -4.27795108e-04, -8.62496032e-04, -1.34585891e-03, -1.88281795e-03, -1.84591592e-03, -1.80901436e-03, -1.77211256e-03, -1.73521065e-03, -1.69830909e-03, -1.08184782e-03, -5.13895764e-04, - -1.13227055e-03, -1.74428569e-03, -2.42520543e-03, -2.37169350e-03, -2.31818156e-03, -2.26466986e-03, -2.21115816e-03, -2.15764646e-03, -1.36136822e-03, -6.26647263e-04, -1.54878304e-03, -2.34815548e-03, -3.23930010e-03, -3.15753091e-03, -3.07576265e-03, -2.99399323e-03, -2.91222427e-03, -2.83045508e-03, -1.76287338e-03, -7.75904860e-04, - -2.23870482e-03, -3.32566188e-03, -4.54067392e-03, -4.40674182e-03, -4.27281018e-03, -4.13887901e-03, -4.00494691e-03, -3.87101574e-03, -2.36659218e-03, -9.72117065e-04, -3.49745504e-03, -5.05724549e-03, -6.80746930e-03, -6.56589260e-03, -6.32431870e-03, -6.08274434e-03, -5.84116904e-03, -5.59959421e-03, -3.32604628e-03, -1.21081201e-03, - -6.14068285e-03, -8.55270587e-03, -1.12749329e-02, -1.07723922e-02, -1.02698486e-02, -9.76730697e-03, -9.26476624e-03, -8.76222178e-03, -4.94601438e-03, -1.37539487e-03, -1.30690653e-02, -1.72132626e-02, -2.19351258e-02, -2.06174850e-02, -1.92998387e-02, -1.79821979e-02, -1.66645572e-02, -1.53469117e-02, -7.72346184e-03, -5.22134826e-04, - -3.99478227e-02, -4.78655733e-02, -5.70126995e-02, -5.16961850e-02, -4.63796593e-02, -4.10631336e-02, -3.57466117e-02, -3.04300785e-02, -9.11374856e-03, 1.14024431e-02, -2.35893592e-01, -2.17480078e-01, -1.88097835e-01, -1.38812393e-01, -8.95269737e-02, -4.02415469e-02, 9.04385652e-03, 5.83292767e-02, 1.78530529e-01, 2.96026409e-01, - 4.16666657e-01, 2.79557735e-01, 1.36546940e-01, 7.49502778e-02, 1.33536234e-02, -4.82430384e-02, -1.09839723e-01, -1.71436355e-01, -2.33033031e-01, -2.74476141e-01, 1.54189002e-02, -8.10869783e-03, -3.24862264e-02, -3.88403721e-02, -4.51945364e-02, -5.15486896e-02, -5.79028539e-02, -6.42570183e-02, -5.45457527e-02, -4.61437553e-02, - -2.29711179e-04, -8.06892477e-03, -1.63567103e-02, -1.78351123e-02, -1.93135180e-02, -2.07919199e-02, -2.22703181e-02, -2.37487257e-02, -1.87229179e-02, -1.43175106e-02, -1.37000845e-03, -5.16320160e-03, -9.21433326e-03, -9.76086594e-03, -1.03073996e-02, -1.08539313e-02, -1.14004640e-02, -1.19469995e-02, -9.08647850e-03, -6.55380823e-03, - -1.23490533e-03, -3.45137389e-03, -5.83263952e-03, -6.09064987e-03, -6.34865928e-03, -6.60666777e-03, -6.86467718e-03, -7.12268520e-03, -5.30054048e-03, -3.67741752e-03, -9.94500006e-04, -2.44303374e-03, -4.00528917e-03, -4.14666394e-03, -4.28803731e-03, -4.42941114e-03, -4.57078544e-03, -4.71215881e-03, -3.45545518e-03, -2.33156094e-03, - -7.93270417e-04, -1.81236281e-03, -2.91444198e-03, -3.00004939e-03, -3.08565609e-03, -3.17126350e-03, -3.25687067e-03, -3.34247784e-03, -2.42513884e-03, -1.60246110e-03, -6.39747130e-04, -1.39506557e-03, -2.21352675e-03, -2.26921216e-03, -2.32489733e-03, -2.38058274e-03, -2.43626791e-03, -2.49195332e-03, -1.79354590e-03, -1.16592250e-03, - -5.23828785e-04, -1.10576022e-03, -1.73730974e-03, -1.77553250e-03, -1.81375467e-03, -1.85197743e-03, -1.89020019e-03, -1.92842260e-03, -1.37922564e-03, -8.84913374e-04, -4.35433642e-04, -8.97393096e-04, -1.39935245e-03, -1.42670958e-03, -1.45406683e-03, -1.48142409e-03, -1.50878134e-03, -1.53613824e-03, -1.09309505e-03, -6.93831593e-04, - -3.66991735e-04, -7.42538832e-04, -1.15100679e-03, -1.17125409e-03, -1.19150116e-03, -1.21174823e-03, -1.23199564e-03, -1.25224248e-03, -8.87364266e-04, -5.58210537e-04, -3.13144788e-04, -6.24410110e-04, -9.63238359e-04, -9.78639582e-04, -9.94040747e-04, -1.00944215e-03, -1.02484343e-03, -1.04024459e-03, -7.34565372e-04, -4.58585098e-04, - -2.70129647e-04, -5.32291830e-04, -8.17865424e-04, -8.29851197e-04, -8.41836852e-04, -8.53822567e-04, -8.65808397e-04, -8.77794111e-04, -6.18013146e-04, -3.83307983e-04, -2.35282409e-04, -4.59096394e-04, -7.03040219e-04, -7.12549896e-04, -7.22059398e-04, -7.31569016e-04, -7.41078693e-04, -7.50588137e-04, -5.27105702e-04, -3.25074652e-04}); - - input.linspace(-10, 0.1); - gradO = 1; - - sd::ops::lrn_bp op; - - auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {7}); - auto gradI = results.at(0); - - ASSERT_EQ(gradI, exp); - - + NDArray input('c', {2, 3, 4, 10}); + NDArray gradO('c', {2, 3, 4, 10}); + NDArray exp( + 'c', {2, 3, 4, 10}, + {-6.78180193e-04, -1.06947345e-03, -1.50362519e-03, -1.47711602e-03, + -1.45060697e-03, -1.42409769e-03, -1.39758852e-03, -1.37107936e-03, + -8.79839936e-04, -4.27795108e-04, -8.62496032e-04, -1.34585891e-03, + -1.88281795e-03, -1.84591592e-03, -1.80901436e-03, -1.77211256e-03, + -1.73521065e-03, -1.69830909e-03, -1.08184782e-03, -5.13895764e-04, + -1.13227055e-03, -1.74428569e-03, -2.42520543e-03, -2.37169350e-03, + -2.31818156e-03, -2.26466986e-03, -2.21115816e-03, -2.15764646e-03, + -1.36136822e-03, -6.26647263e-04, -1.54878304e-03, -2.34815548e-03, + -3.23930010e-03, -3.15753091e-03, -3.07576265e-03, -2.99399323e-03, + -2.91222427e-03, -2.83045508e-03, -1.76287338e-03, -7.75904860e-04, + -2.23870482e-03, -3.32566188e-03, -4.54067392e-03, -4.40674182e-03, + -4.27281018e-03, -4.13887901e-03, -4.00494691e-03, -3.87101574e-03, + -2.36659218e-03, -9.72117065e-04, -3.49745504e-03, -5.05724549e-03, + -6.80746930e-03, -6.56589260e-03, -6.32431870e-03, -6.08274434e-03, + -5.84116904e-03, -5.59959421e-03, -3.32604628e-03, -1.21081201e-03, + -6.14068285e-03, -8.55270587e-03, -1.12749329e-02, -1.07723922e-02, + -1.02698486e-02, -9.76730697e-03, -9.26476624e-03, -8.76222178e-03, + -4.94601438e-03, -1.37539487e-03, -1.30690653e-02, -1.72132626e-02, + -2.19351258e-02, -2.06174850e-02, -1.92998387e-02, -1.79821979e-02, + -1.66645572e-02, -1.53469117e-02, -7.72346184e-03, -5.22134826e-04, + -3.99478227e-02, -4.78655733e-02, -5.70126995e-02, -5.16961850e-02, + -4.63796593e-02, -4.10631336e-02, -3.57466117e-02, -3.04300785e-02, + -9.11374856e-03, 1.14024431e-02, -2.35893592e-01, -2.17480078e-01, + -1.88097835e-01, -1.38812393e-01, -8.95269737e-02, -4.02415469e-02, + 9.04385652e-03, 5.83292767e-02, 1.78530529e-01, 2.96026409e-01, + 4.16666657e-01, 2.79557735e-01, 1.36546940e-01, 7.49502778e-02, + 1.33536234e-02, -4.82430384e-02, -1.09839723e-01, -1.71436355e-01, + -2.33033031e-01, -2.74476141e-01, 1.54189002e-02, -8.10869783e-03, + -3.24862264e-02, -3.88403721e-02, -4.51945364e-02, -5.15486896e-02, + -5.79028539e-02, -6.42570183e-02, -5.45457527e-02, -4.61437553e-02, + -2.29711179e-04, -8.06892477e-03, -1.63567103e-02, -1.78351123e-02, + -1.93135180e-02, -2.07919199e-02, -2.22703181e-02, -2.37487257e-02, + -1.87229179e-02, -1.43175106e-02, -1.37000845e-03, -5.16320160e-03, + -9.21433326e-03, -9.76086594e-03, -1.03073996e-02, -1.08539313e-02, + -1.14004640e-02, -1.19469995e-02, -9.08647850e-03, -6.55380823e-03, + -1.23490533e-03, -3.45137389e-03, -5.83263952e-03, -6.09064987e-03, + -6.34865928e-03, -6.60666777e-03, -6.86467718e-03, -7.12268520e-03, + -5.30054048e-03, -3.67741752e-03, -9.94500006e-04, -2.44303374e-03, + -4.00528917e-03, -4.14666394e-03, -4.28803731e-03, -4.42941114e-03, + -4.57078544e-03, -4.71215881e-03, -3.45545518e-03, -2.33156094e-03, + -7.93270417e-04, -1.81236281e-03, -2.91444198e-03, -3.00004939e-03, + -3.08565609e-03, -3.17126350e-03, -3.25687067e-03, -3.34247784e-03, + -2.42513884e-03, -1.60246110e-03, -6.39747130e-04, -1.39506557e-03, + -2.21352675e-03, -2.26921216e-03, -2.32489733e-03, -2.38058274e-03, + -2.43626791e-03, -2.49195332e-03, -1.79354590e-03, -1.16592250e-03, + -5.23828785e-04, -1.10576022e-03, -1.73730974e-03, -1.77553250e-03, + -1.81375467e-03, -1.85197743e-03, -1.89020019e-03, -1.92842260e-03, + -1.37922564e-03, -8.84913374e-04, -4.35433642e-04, -8.97393096e-04, + -1.39935245e-03, -1.42670958e-03, -1.45406683e-03, -1.48142409e-03, + -1.50878134e-03, -1.53613824e-03, -1.09309505e-03, -6.93831593e-04, + -3.66991735e-04, -7.42538832e-04, -1.15100679e-03, -1.17125409e-03, + -1.19150116e-03, -1.21174823e-03, -1.23199564e-03, -1.25224248e-03, + -8.87364266e-04, -5.58210537e-04, -3.13144788e-04, -6.24410110e-04, + -9.63238359e-04, -9.78639582e-04, -9.94040747e-04, -1.00944215e-03, + -1.02484343e-03, -1.04024459e-03, -7.34565372e-04, -4.58585098e-04, + -2.70129647e-04, -5.32291830e-04, -8.17865424e-04, -8.29851197e-04, + -8.41836852e-04, -8.53822567e-04, -8.65808397e-04, -8.77794111e-04, + -6.18013146e-04, -3.83307983e-04, -2.35282409e-04, -4.59096394e-04, + -7.03040219e-04, -7.12549896e-04, -7.22059398e-04, -7.31569016e-04, + -7.41078693e-04, -7.50588137e-04, -5.27105702e-04, -3.25074652e-04}); + + input.linspace(-10, 0.1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {7}); + auto gradI = results.at(0); + + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_4) { - - NDArray input('c', {2,3,4,10}); - NDArray gradO('c', {2,3,4,10}); - NDArray exp('c', {2,3,4,10}, {-0.00119282, -0.00116995, -0.00114708, -0.00112421, -0.00110134, -0.00107847, -0.00105559, -0.00103272, -0.00100985, -0.00098698, -0.00150102, -0.00146918, -0.00143734, -0.0014055 , -0.00137366, -0.00134182, -0.00130998, -0.00127814, -0.0012463 , -0.00121446, - -0.00194534,-0.00189916, -0.00185299, -0.00180681, -0.00176064, -0.00171446, -0.00166829, -0.00162211, -0.00157593, -0.00152976, -0.0026189 , -0.00254833, -0.00247776, -0.00240719, -0.00233662, -0.00226605, -0.00219548, -0.00212491, -0.00205434, -0.00198377, - -0.00370962, -0.00359401, -0.00347839, -0.00336277, -0.00324716, -0.00313154, -0.00301593, -0.00290031, -0.00278469, -0.00266908, -0.00564327, -0.00543464, -0.00522602, -0.00501739, -0.00480876, -0.00460013, -0.0043915 , -0.00418288, -0.00397425, -0.00376562, - -0.00955302, -0.00911865, -0.00868428, -0.00824992, -0.00781555, -0.00738118, -0.00694682, -0.00651245, -0.00607808, -0.00564371, -0.01927758, -0.01813637, -0.01699515, -0.01585394, -0.01471272, -0.01357151, -0.01243029, -0.01128908, -0.01014786, -0.00900664, - -0.05409876, -0.04945958, -0.04482041, -0.04018124, -0.03554206, -0.03090289, -0.02626371, -0.02162454, -0.01698537, -0.01234619, -0.26145172, -0.214688 , -0.16792431, -0.12116055, -0.07439683, -0.02763309, 0.01913062, 0.06589434, 0.11265809, 0.15942183, - 0.25974026, 0.19902176, 0.13830325, 0.07758474, 0.01686624, -0.04385226, -0.10457078, -0.16528927, -0.22600779, -0.2867263 , -0.01177884, -0.0173331 , -0.02288735, -0.02844159, -0.03399584, -0.0395501 , -0.04510435, -0.05065861, -0.05621284, -0.0617671 , - -0.00944993, -0.01073084, -0.01201174, -0.01329265, -0.01457355, -0.01585446, -0.01713536, -0.01841627, -0.01969717, -0.02097807, -0.00589878, -0.00637122, -0.00684368, -0.00731612, -0.00778858, -0.00826102, -0.00873347, -0.00920592, -0.00967837, -0.01015082, - -0.00390961, -0.00413245, -0.00435528, -0.00457812, -0.00480095, -0.00502378, -0.00524662, -0.00546945, -0.00569229, -0.00591512, -0.00275609, -0.00287813, -0.00300018, -0.00312222, -0.00324427, -0.00336631, -0.00348836, -0.0036104 , -0.00373245, -0.00385449, - -0.00203982, -0.00211371, -0.00218759, -0.00226147, -0.00233536, -0.00240924, -0.00248312, -0.00255701, -0.00263089, -0.00270478, -0.00156781, -0.00161586, -0.00166391, -0.00171197, -0.00176002, -0.00180807, -0.00185612, -0.00190417, -0.00195223, -0.00200028, - -0.00124141, -0.00127439, -0.00130737, -0.00134035, -0.00137333, -0.00140631, -0.00143929, -0.00147227, -0.00150525, -0.00153822, -0.00100674, -0.00103034, -0.00105394, -0.00107754, -0.00110115, -0.00112475, -0.00114835, -0.00117195, -0.00119556, -0.00121916, - -0.00083255, -0.00085002, -0.00086748, -0.00088495, -0.00090242, -0.00091989, -0.00093735, -0.00095482, -0.00097229, -0.00098976, -0.0006998 , -0.00071308, -0.00072637, -0.00073965, -0.00075294, -0.00076623, -0.00077951, -0.0007928 , -0.00080609, -0.00081937, - -0.00059635, -0.00060669, -0.00061703, -0.00062737, -0.00063771, -0.00064805, -0.00065839, -0.00066873, -0.00067906, -0.0006894 , -0.0005142 , -0.0005224 , -0.00053061, -0.00053881, -0.00054701, -0.00055522, -0.00056342, -0.00057162, -0.00057983, -0.00058803}); - - input.linspace(-10, 0.1); - gradO = 1; - - sd::ops::lrn_bp op; - - auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {12}); - auto gradI = results.at(0); - - ASSERT_EQ(gradI, exp); - - + NDArray input('c', {2, 3, 4, 10}); + NDArray gradO('c', {2, 3, 4, 10}); + NDArray exp( + 'c', {2, 3, 4, 10}, + {-0.00119282, -0.00116995, -0.00114708, -0.00112421, -0.00110134, + -0.00107847, -0.00105559, -0.00103272, -0.00100985, -0.00098698, + -0.00150102, -0.00146918, -0.00143734, -0.0014055, -0.00137366, + -0.00134182, -0.00130998, -0.00127814, -0.0012463, -0.00121446, + -0.00194534, -0.00189916, -0.00185299, -0.00180681, -0.00176064, + -0.00171446, -0.00166829, -0.00162211, -0.00157593, -0.00152976, + -0.0026189, -0.00254833, -0.00247776, -0.00240719, -0.00233662, + -0.00226605, -0.00219548, -0.00212491, -0.00205434, -0.00198377, + -0.00370962, -0.00359401, -0.00347839, -0.00336277, -0.00324716, + -0.00313154, -0.00301593, -0.00290031, -0.00278469, -0.00266908, + -0.00564327, -0.00543464, -0.00522602, -0.00501739, -0.00480876, + -0.00460013, -0.0043915, -0.00418288, -0.00397425, -0.00376562, + -0.00955302, -0.00911865, -0.00868428, -0.00824992, -0.00781555, + -0.00738118, -0.00694682, -0.00651245, -0.00607808, -0.00564371, + -0.01927758, -0.01813637, -0.01699515, -0.01585394, -0.01471272, + -0.01357151, -0.01243029, -0.01128908, -0.01014786, -0.00900664, + -0.05409876, -0.04945958, -0.04482041, -0.04018124, -0.03554206, + -0.03090289, -0.02626371, -0.02162454, -0.01698537, -0.01234619, + -0.26145172, -0.214688, -0.16792431, -0.12116055, -0.07439683, + -0.02763309, 0.01913062, 0.06589434, 0.11265809, 0.15942183, + 0.25974026, 0.19902176, 0.13830325, 0.07758474, 0.01686624, + -0.04385226, -0.10457078, -0.16528927, -0.22600779, -0.2867263, + -0.01177884, -0.0173331, -0.02288735, -0.02844159, -0.03399584, + -0.0395501, -0.04510435, -0.05065861, -0.05621284, -0.0617671, + -0.00944993, -0.01073084, -0.01201174, -0.01329265, -0.01457355, + -0.01585446, -0.01713536, -0.01841627, -0.01969717, -0.02097807, + -0.00589878, -0.00637122, -0.00684368, -0.00731612, -0.00778858, + -0.00826102, -0.00873347, -0.00920592, -0.00967837, -0.01015082, + -0.00390961, -0.00413245, -0.00435528, -0.00457812, -0.00480095, + -0.00502378, -0.00524662, -0.00546945, -0.00569229, -0.00591512, + -0.00275609, -0.00287813, -0.00300018, -0.00312222, -0.00324427, + -0.00336631, -0.00348836, -0.0036104, -0.00373245, -0.00385449, + -0.00203982, -0.00211371, -0.00218759, -0.00226147, -0.00233536, + -0.00240924, -0.00248312, -0.00255701, -0.00263089, -0.00270478, + -0.00156781, -0.00161586, -0.00166391, -0.00171197, -0.00176002, + -0.00180807, -0.00185612, -0.00190417, -0.00195223, -0.00200028, + -0.00124141, -0.00127439, -0.00130737, -0.00134035, -0.00137333, + -0.00140631, -0.00143929, -0.00147227, -0.00150525, -0.00153822, + -0.00100674, -0.00103034, -0.00105394, -0.00107754, -0.00110115, + -0.00112475, -0.00114835, -0.00117195, -0.00119556, -0.00121916, + -0.00083255, -0.00085002, -0.00086748, -0.00088495, -0.00090242, + -0.00091989, -0.00093735, -0.00095482, -0.00097229, -0.00098976, + -0.0006998, -0.00071308, -0.00072637, -0.00073965, -0.00075294, + -0.00076623, -0.00077951, -0.0007928, -0.00080609, -0.00081937, + -0.00059635, -0.00060669, -0.00061703, -0.00062737, -0.00063771, + -0.00064805, -0.00065839, -0.00066873, -0.00067906, -0.0006894, + -0.0005142, -0.0005224, -0.00053061, -0.00053881, -0.00054701, + -0.00055522, -0.00056342, -0.00057162, -0.00057983, -0.00058803}); + + input.linspace(-10, 0.1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {12}); + auto gradI = results.at(0); + + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_5) { - - NDArray input('c', {2,2,2,5}); - NDArray gradO('c', {2,2,2,5}); - NDArray exp('c', {2,2,2,5}, {6.2497472e-03, -3.4008762e-03, -1.5232352e-02, 2.3018382e-04, 1.3257053e-02, 7.1492628e-03, -5.4330104e-03, -2.0878183e-02, 1.5153568e-03, 2.0571884e-02, - 6.7926152e-03, -1.0990440e-02, -3.2685306e-02, 7.2436016e-03, 4.2120241e-02, -1.3439789e-02, -3.4284033e-02, -4.4852167e-02, 8.8073254e-02, 2.2223940e-01, - 4.0824831e-01, 2.1201703e-01, 3.8555145e-02, -3.1969927e-02, -3.0673094e-02, 5.2034661e-02, 1.0463811e-02, -3.6619946e-02, -1.3280880e-02, 5.9767403e-03, - 2.3028374e-02, 2.0452859e-03, -2.2533152e-02, -6.1039329e-03, 7.2805062e-03, 1.4290780e-02, 3.8017845e-04, -1.6107092e-02,-3.6896234e-03, 6.4357026e-03}); - input.linspace(-20, 1); - // gradO.linspace(0.1, 0.1); - gradO = 1; - - sd::ops::lrn_bp op; - - auto results = op.evaluate({&input, &gradO}, {1., 1., 0.5}, {2}); - auto gradI = results.at(0); - - ASSERT_EQ(gradI, exp); - - + NDArray input('c', {2, 2, 2, 5}); + NDArray gradO('c', {2, 2, 2, 5}); + NDArray exp('c', {2, 2, 2, 5}, + {6.2497472e-03, -3.4008762e-03, -1.5232352e-02, 2.3018382e-04, + 1.3257053e-02, 7.1492628e-03, -5.4330104e-03, -2.0878183e-02, + 1.5153568e-03, 2.0571884e-02, 6.7926152e-03, -1.0990440e-02, + -3.2685306e-02, 7.2436016e-03, 4.2120241e-02, -1.3439789e-02, + -3.4284033e-02, -4.4852167e-02, 8.8073254e-02, 2.2223940e-01, + 4.0824831e-01, 2.1201703e-01, 3.8555145e-02, -3.1969927e-02, + -3.0673094e-02, 5.2034661e-02, 1.0463811e-02, -3.6619946e-02, + -1.3280880e-02, 5.9767403e-03, 2.3028374e-02, 2.0452859e-03, + -2.2533152e-02, -6.1039329e-03, 7.2805062e-03, 1.4290780e-02, + 3.8017845e-04, -1.6107092e-02, -3.6896234e-03, 6.4357026e-03}); + input.linspace(-20, 1); + // gradO.linspace(0.1, 0.1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 1., 0.5}, {2}); + auto gradI = results.at(0); + + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_6) { + NDArray input('c', {1, 1, 1, 5}, {1, 2., 3, 4, 5}); + NDArray gradO('c', {1, 1, 1, 5}); + NDArray exp('c', {1, 1, 1, 5}, + {0.06926288, 0.04360996, 0.01795704, -0.00769587, -0.0333488}); + // gradO.linspace(-1.5, 0.1); + gradO = 1; - NDArray input('c', {1,1,1,5}, {1, 2., 3, 4, 5}); - NDArray gradO('c', {1,1,1,5}); - NDArray exp('c', {1,1,1,5}, {0.06926288, 0.04360996, 0.01795704, -0.00769587, -0.0333488}); - // gradO.linspace(-1.5, 0.1); - gradO = 1; - - sd::ops::lrn_bp op; - - auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {10}); - auto gradI = results.at(0); + sd::ops::lrn_bp op; - ASSERT_EQ(gradI, exp); + auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {10}); + auto gradI = results.at(0); - + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_7) { + NDArray input('c', {2, 2, 2, 5}); + NDArray gradO('c', {2, 2, 2, 5}); - NDArray input('c', {2,2,2,5}); - NDArray gradO('c', {2,2,2,5}); + input.linspace(-20, 1); + gradO.linspace(-1.5, 0.1); - input.linspace(-20, 1); - gradO.linspace(-1.5, 0.1); + const OpArgsHolder argsHolderFF({&input}, {1, 2, 0.5}, {2}); + const OpArgsHolder argsHolderBP({&input, &gradO}, {1, 2, 0.5}, {2}); - const OpArgsHolder argsHolderFF({&input}, {1,2,0.5}, {2}); - const OpArgsHolder argsHolderBP({&input, &gradO}, {1,2,0.5}, {2}); + sd::ops::lrn opFF; + sd::ops::lrn_bp opBP; - sd::ops::lrn opFF; - sd::ops::lrn_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_8) { + NDArray input('c', {1, 1, 1, 5}, {1, 2, 3, 4, 5}); + NDArray gradO('c', {1, 1, 1, 5}, {2, 3, 4, 5, 6}); - NDArray input('c', {1,1,1,5}, {1, 2, 3, 4, 5}); - NDArray gradO('c', {1,1,1,5}, {2, 3, 4, 5, 6}); - - const OpArgsHolder argsHolderFF({&input}, {1,2,0.5}, {2}); - const OpArgsHolder argsHolderBP({&input, &gradO}, {1,2,0.5}, {2}); + const OpArgsHolder argsHolderFF({&input}, {1, 2, 0.5}, {2}); + const OpArgsHolder argsHolderBP({&input, &gradO}, {1, 2, 0.5}, {2}); - sd::ops::lrn opFF; - sd::ops::lrn_bp opBP; + sd::ops::lrn opFF; + sd::ops::lrn_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_9) { + NDArray input('c', {1, 1, 1, 5}, {1, 2, 3, 4, 5}); + NDArray gradO('c', {1, 1, 1, 5}, {1, 1, 1, 1, 1}); + NDArray exp('c', {1, 1, 1, 5}, + {0.1084472, 0.03816165, 0.00978456, -0.01859251, -0.02511311}); - NDArray input('c', {1,1,1,5}, {1,2,3,4,5}); - NDArray gradO('c', {1,1,1,5}, {1, 1, 1, 1, 1}); - NDArray exp('c', {1,1,1,5}, {0.1084472 , 0.03816165, 0.00978456, -0.01859251,-0.02511311}); - - sd::ops::lrn_bp op; + sd::ops::lrn_bp op; - auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {3}); - auto gradI = results.at(0); + auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {3}); + auto gradI = results.at(0); - // for (int i = 0; i < exp.lengthOf(); ++i) - // printf("%10.5f %10.5f\n", exp.e(i), gradI->e(i)); + // for (int i = 0; i < exp.lengthOf(); ++i) + // printf("%10.5f %10.5f\n", exp.e(i), gradI->e(i)); - ASSERT_EQ(gradI, exp); - - + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_10) { + NDArray input('c', {1, 1, 1, 1}, std::vector{1}); + NDArray gradO('c', {1, 1, 1, 1}, std::vector{1}); + NDArray exp('c', {1, 1, 1, 1}, std::vector{0.19245008}); - NDArray input('c', {1,1,1,1}, std::vector{1}); - NDArray gradO('c', {1,1,1,1}, std::vector{1}); - NDArray exp('c', {1,1,1,1}, std::vector{0.19245008}); - - sd::ops::lrn_bp op; + sd::ops::lrn_bp op; - auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {1}); - auto gradI = results.at(0); + auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {1}); + auto gradI = results.at(0); - ASSERT_EQ(gradI, exp); - - + ASSERT_EQ(gradI, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_1) { + NDArray input('c', {2, 2, 2, 5}); + NDArray exp('c', {2, 2, 2, 5}, + {-0.42923987, -0.3623817, -0.3152079, -0.34268343, -0.3836809, + -0.43648192, -0.3652726, -0.31428117, -0.3379276, -0.3731494, + -0.45129365, -0.37083852, -0.3111639, -0.3260225, -0.34698898, + -0.4975186, -0.3831305, -0.2847474, -0.25607377, -0.18569534, + 0., 0.18569534, 0.25607377, 0.38411066, 0.52075565, + 0.33633637, 0.32117262, 0.30966178, 0.37259716, 0.45631808, + 0.36986336, 0.33643705, 0.31394684, 0.36608824, 0.43857202, + 0.3821113, 0.34197718, 0.31508508, 0.36284128, 0.4303756}); - NDArray input('c', {2,2,2,5}); - NDArray exp('c', {2,2,2,5}, {-0.42923987, -0.3623817 , -0.3152079 , -0.34268343, -0.3836809, -0.43648192, -0.3652726 , -0.31428117, -0.3379276 , -0.3731494 , - -0.45129365, -0.37083852, -0.3111639 , -0.3260225 , -0.34698898, -0.4975186 , -0.3831305 , -0.2847474 , -0.25607377, -0.18569534, - 0., 0.18569534, 0.25607377, 0.38411066, 0.52075565,0.33633637, 0.32117262, 0.30966178, 0.37259716, 0.45631808, - 0.36986336, 0.33643705, 0.31394684, 0.36608824, 0.43857202, 0.3821113 , 0.34197718, 0.31508508, 0.36284128, 0.4303756 }); - - input.linspace(-20, 1); - - sd::ops::lrn op; + input.linspace(-20, 1); - auto results = op.evaluate({&input}, {1., 2., 0.5}, {2}); - auto output = results.at(0); + sd::ops::lrn op; - ASSERT_EQ(output, exp); + auto results = op.evaluate({&input}, {1., 2., 0.5}, {2}); + auto output = results.at(0); - + ASSERT_EQ(output, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_2) { + NDArray input('c', {1, 1, 1, 5}, {1, 2., 3, 4, 5}); + NDArray exp('c', {1, 1, 1, 5}, + {0.09530295, 0.1906059, 0.28590885, 0.3812118, 0.47651473}); - NDArray input('c', {1,1,1,5}, {1, 2., 3, 4, 5}); - NDArray exp('c', {1,1,1,5}, {0.09530295, 0.1906059 , 0.28590885, 0.3812118 , 0.47651473}); - - sd::ops::lrn op; - - auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5}); - auto output = results.at(0); - ASSERT_EQ(output, exp); + sd::ops::lrn op; - + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5}); + auto output = results.at(0); + ASSERT_EQ(output, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_3) { + NDArray input('c', {1, 1, 1, 1}, std::vector{1.}); + NDArray exp('c', {1, 1, 1, 1}, std::vector{0.69006556}); - NDArray input('c', {1,1,1,1}, std::vector{1.}); - NDArray exp('c', {1,1,1,1}, std::vector{0.69006556}); + sd::ops::lrn op; - sd::ops::lrn op; - - auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5}); - auto output = results.at(0); - ASSERT_EQ(output, exp); - - + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5}); + auto output = results.at(0); + ASSERT_EQ(output, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_4) { + NDArray input('c', {1, 1, 1, 1}, std::vector{1.}); + NDArray exp('c', {1, 1, 1, 1}, std::vector{0.69006556}); - NDArray input('c', {1,1,1,1}, std::vector{1.}); - NDArray exp('c', {1,1,1,1}, std::vector{0.69006556}); - - sd::ops::lrn op; + sd::ops::lrn op; - auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0}); - auto output = results.at(0); - ASSERT_EQ(output, exp); - - + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0}); + auto output = results.at(0); + ASSERT_EQ(output, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_5) { + NDArray input('c', {1, 1, 1, 5}, {1, 2., 3, 4, 5}); + NDArray exp('c', {1, 1, 1, 5}, + {0.69006556, 0.70272833, 0.7051508, 0.7060045, 0.7064008}); - NDArray input('c', {1,1,1,5}, {1, 2., 3, 4, 5}); - NDArray exp('c', {1,1,1,5}, {0.69006556, 0.70272833, 0.7051508 , 0.7060045 , 0.7064008}); - - sd::ops::lrn op; + sd::ops::lrn op; - auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0}); - auto output = results.at(0); - ASSERT_EQ(output, exp); - - + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0}); + auto output = results.at(0); + ASSERT_EQ(output, exp); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, inTopK_1) { + NDArray x('c', {4, 5}, + {11.0, 14.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 15.0, 6.0, + 9.0, 3.5, 7.0, 11.0, 13.0, 5.0, 16.0, 9.0, 13.5, 7.0}); + NDArray y('c', {4}, {0., 0, 0, 0}, sd::DataType::INT64); + NDArray z('c', {4}, {1., 1, 1, 1}, sd::DataType::BOOL); - NDArray x('c', {4, 5}, {11.0, 14.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 5.0, 16.0, 9.0, 13.5, 7.0}); - NDArray y('c', {4}, {0., 0, 0, 0}, sd::DataType::INT64); - NDArray z('c', {4}, {1., 1, 1, 1}, sd::DataType::BOOL); - - NDArray expV('c', {4}, {1., 0, 0, 0}, sd::DataType::BOOL); + NDArray expV('c', {4}, {1., 0, 0, 0}, sd::DataType::BOOL); - sd::ops::in_top_k op; - Nd4jStatus status = op.execute({&x, &y, }, {&z}, {}, {2}, {}); + sd::ops::in_top_k op; + Nd4jStatus status = op.execute( + { + &x, + &y, + }, + {&z}, {}, {2}, {}); - // z.printIndexedBuffer(); - ASSERT_EQ(ND4J_STATUS_OK, status); + // z.printIndexedBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expV.isSameShape(z)); - ASSERT_TRUE(expV.equalsTo(z)); + ASSERT_TRUE(expV.isSameShape(z)); + ASSERT_TRUE(expV.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, inTopK_2) { + auto input = NDArrayFactory::create('c', {4, 5}); + auto idx = NDArrayFactory::create('c', {4}); - auto input = NDArrayFactory::create('c', {4, 5}); - auto idx = NDArrayFactory::create('c', {4}); - - auto exp = NDArrayFactory::create({false, false, false, true}); + auto exp = NDArrayFactory::create({false, false, false, true}); - int exclusive, reverse; - input.linspace(1); - idx.linspace(1); + int exclusive, reverse; + input.linspace(1); + idx.linspace(1); - sd::ops::in_top_k op; + sd::ops::in_top_k op; - auto res = op.evaluate({&input, &idx}, {}, {1}); + auto res = op.evaluate({&input, &idx}, {}, {1}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - //res.at(0)->printIndexedBuffer("IN_TOP_K output"); - ASSERT_TRUE(res.at(0).equalsTo(&exp)); - + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + // res.at(0)->printIndexedBuffer("IN_TOP_K output"); + ASSERT_TRUE(res.at(0).equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, inTopK_3) { - auto x = NDArrayFactory::create('c', {2, 3}, {1.0, 11.0, 3.0, 14.0, 5.0, 6.0}); - auto y = NDArrayFactory::create('c', {2}, {1, 1}); - auto expV = NDArrayFactory::create('c', {2}, {true, false}); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.0, 11.0, 3.0, 14.0, 5.0, 6.0}); + auto y = NDArrayFactory::create('c', {2}, {1, 1}); + auto expV = NDArrayFactory::create('c', {2}, {true, false}); - sd::ops::in_top_k op; - auto result = op.evaluate({&x, &y}, {}, {2}); + sd::ops::in_top_k op; + auto result = op.evaluate({&x, &y}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(1, result.size()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); - auto v = result.at(0); + auto v = result.at(0); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); - - + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, inTopK_4) { - auto x = NDArrayFactory::create('c', {6, 4}, {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0} ); - auto y = NDArrayFactory::create('c', {6}, {0, 0, 0, 0, 0, 0}); - auto expV = NDArrayFactory::create('c', {6}, {true, false, true, false, false, true}); - - sd::ops::in_top_k op; - auto result = op.evaluate({&x, &y}, {}, {2}); + auto x = NDArrayFactory::create( + 'c', {6, 4}, + {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); + auto y = NDArrayFactory::create('c', {6}, {0, 0, 0, 0, 0, 0}); + auto expV = NDArrayFactory::create( + 'c', {6}, {true, false, true, false, false, true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(1, result.size()); + sd::ops::in_top_k op; + auto result = op.evaluate({&x, &y}, {}, {2}); - auto v = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); - - + auto v = result.at(0); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, inTopK_5) { - auto x = NDArrayFactory::create('f', {6, 4}, {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0} ); - auto y = NDArrayFactory::create('f', {6}, {0, 0, 0, 0, 0, 0}); - auto expV = NDArrayFactory::create('f', {6}, {true, false, false, false, false, false }); - - sd::ops::in_top_k op; - auto result = op.evaluate({&x, &y}, {}, {2}); + auto x = NDArrayFactory::create( + 'f', {6, 4}, + {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); + auto y = NDArrayFactory::create('f', {6}, {0, 0, 0, 0, 0, 0}); + auto expV = NDArrayFactory::create( + 'f', {6}, {true, false, false, false, false, false}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(1, result.size()); + sd::ops::in_top_k op; + auto result = op.evaluate({&x, &y}, {}, {2}); - auto v = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); + auto v = result.at(0); - + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cube_1) { + NDArray x('c', {2, 3}, {1., 2., 3., 4., 5, 6}); + NDArray exp('c', {2, 3}, {1., 8., 27., 64., 125, 216}); - NDArray x('c', {2, 3}, {1., 2., 3., 4., 5, 6}); - NDArray exp('c', {2, 3}, {1., 8., 27., 64., 125, 216}); - - sd::ops::cube op; + sd::ops::cube op; - auto result = op.evaluate({&x}); + auto result = op.evaluate({&x}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, cube_bp_1) { + NDArray x('c', {2, 3}, {1., 2., 3., 4., 5, 6}); + NDArray gradO('c', {2, 3}, sd::DataType::DOUBLE); + NDArray exp('c', {2, 3}, {1.5, 6., 13.5, 24., 37.5, 54}); - NDArray x('c', {2, 3}, {1., 2., 3., 4., 5, 6}); - NDArray gradO('c', {2, 3}, sd::DataType::DOUBLE); - NDArray exp('c', {2, 3}, {1.5, 6., 13.5, 24., 37.5, 54}); - - gradO = 0.5; + gradO = 0.5; - sd::ops::cube_bp op; + sd::ops::cube_bp op; - auto result = op.evaluate({&x, &gradO}); + auto result = op.evaluate({&x, &gradO}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - // z->printIndexedBuffer(); + auto z = result.at(0); + // z->printIndexedBuffer(); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// // CONSTANT mode 2D TEST_F(DeclarableOpsTests12, pad_tests1) { + NDArray input('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::FLOAT32); + NDArray paddings('c', {2, 2}, {1, 1, 2, 2}, sd::DataType::INT32); + NDArray expected('c', {4, 7}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0, 0, + 0, 0, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + sd::DataType::FLOAT32); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - NDArray input('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::FLOAT32); - NDArray paddings('c', {2,2}, {1,1,2,2}, sd::DataType::INT32); - NDArray expected('c', {4,7}, {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}, sd::DataType::FLOAT32); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result->printIndexedBuffer(); + auto result = results.at(0); + // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // REFLECT mode 2D TEST_F(DeclarableOpsTests12, pad_tests2) { + float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; + int padBuff[] = {1, 1, 2, 2}; + float expBuff[] = {6.f, 5.f, 4.f, 5.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, + 2.f, 3.f, 2.f, 1.f, 6.f, 5.f, 4.f, 5.f, 6.f, 5.f, + 4.f, 3.f, 2.f, 1.f, 2.f, 3.f, 2.f, 1.f}; - float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; - int padBuff[] = {1,1,2,2}; - float expBuff[] = {6.f, 5.f, 4.f, 5.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, 2.f, 3.f, 2.f, 1.f, 6.f, 5.f, 4.f, 5.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, 2.f, 3.f, 2.f, 1.f}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7}); - auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result->printIndexedBuffer(); - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // SYMMETRIC mode 2D TEST_F(DeclarableOpsTests12, pad_tests3) { + float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; + Nd4jLong padBuff[] = {1, 1, 2, 2}; + float expBuff[] = {2.f, 1.f, 1.f, 2.f, 3.f, 3.f, 2.f, 2.f, 1.f, 1.f, + 2.f, 3.f, 3.f, 2.f, 5.f, 4.f, 4.f, 5.f, 6.f, 6.f, + 5.f, 5.f, 4.f, 4.f, 5.f, 6.f, 6.f, 5.f}; - float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; - Nd4jLong padBuff[] = {1,1,2,2}; - float expBuff[] = {2.f, 1.f, 1.f, 2.f, 3.f, 3.f, 2.f, 2.f,1.f,1.f,2.f,3.f,3.f,2.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f}; - - auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); - auto result = results.at(0); - // result->printIndexedBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result->printIndexedBuffer(); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // CONSTANT mode 3D TEST_F(DeclarableOpsTests12, pad_tests4) { - - float inBuff[] = {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f}; - int padBuff[] = {1,1,2,2,2,2}; - float expBuff[] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 2.f, 3.f, 0.f, 0.f, 0.f, 0.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, - 7.f, 8.f, 9.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 11.f, 12.f, 0.f, - 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 16.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - // for(int i = 0; i < expected.lengthOf(); ++i) { - // float one = expected.e(i); - // float two = result->e(i); - // if(one != two) - // printf("%i : %f, %f\n", i, one, two); - // } - - + float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f}; + int padBuff[] = {1, 1, 2, 2, 2, 2}; + float expBuff[] = { + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 1.f, 2.f, 3.f, 0.f, 0.f, 0.f, 0.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, + 0.f, 7.f, 8.f, 9.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 11.f, 12.f, + 0.f, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 16.f, 17.f, + 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f}; + + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7, 7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + // for(int i = 0; i < expected.lengthOf(); ++i) { + // float one = expected.e(i); + // float two = result->e(i); + // if(one != two) + // printf("%i : %f, %f\n", i, one, two); + // } } - - //////////////////////////////////////////////////////////////////// // REFLECT mode 3D TEST_F(DeclarableOpsTests12, pad_tests5) { - - double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; - int padBuff[] = {1,1,2,2,2,2}; - double expBuff[] = {18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + int padBuff[] = {1, 1, 2, 2, 2, 2}; + double expBuff[] = { + 18, 17, 16, 17, 18, 17, 16, 15, 14, 13, 14, 15, 14, 13, 12, 11, 10, 11, + 12, 11, 10, 15, 14, 13, 14, 15, 14, 13, 18, 17, 16, 17, 18, 17, 16, 15, + 14, 13, 14, 15, 14, 13, 12, 11, 10, 11, 12, 11, 10, 9, 8, 7, 8, 9, + 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, + 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, + 4, 3, 2, 1, 2, 3, 2, 1, 18, 17, 16, 17, 18, 17, 16, 15, 14, 13, + 14, 15, 14, 13, 12, 11, 10, 11, 12, 11, 10, 15, 14, 13, 14, 15, 14, 13, + 18, 17, 16, 17, 18, 17, 16, 15, 14, 13, 14, 15, 14, 13, 12, 11, 10, 11, + 12, 11, 10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, + 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, + 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7, 7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // SYMMETRIC mode 3D TEST_F(DeclarableOpsTests12, pad_tests6) { - - double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; - int padBuff[] = {1,1,2,2,2,2}; - double expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14}; - - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + int padBuff[] = {1, 1, 2, 2, 2, 2}; + double expBuff[] = { + 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, + 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, + 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, + 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, + 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, + 8, 5, 4, 4, 5, 6, 6, 5, 14, 13, 13, 14, 15, 15, 14, 11, 10, 10, + 11, 12, 12, 11, 11, 10, 10, 11, 12, 12, 11, 14, 13, 13, 14, 15, 15, 14, + 17, 16, 16, 17, 18, 18, 17, 17, 16, 16, 17, 18, 18, 17, 14, 13, 13, 14, + 15, 15, 14, 14, 13, 13, 14, 15, 15, 14, 11, 10, 10, 11, 12, 12, 11, 11, + 10, 10, 11, 12, 12, 11, 14, 13, 13, 14, 15, 15, 14, 17, 16, 16, 17, 18, + 18, 17, 17, 16, 16, 17, 18, 18, 17, 14, 13, 13, 14, 15, 15, 14}; + + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7, 7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// // CONSTANT mode 4D -TEST_F(DeclarableOpsTests12, pad_tests7) -{ - - double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - double expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - +TEST_F(DeclarableOpsTests12, pad_tests7) { + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, + 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, + 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// // REFLECT mode 4D -TEST_F(DeclarableOpsTests12, pad_tests8) -{ - - double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - double expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - +TEST_F(DeclarableOpsTests12, pad_tests8) { + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = { + 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, + 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, + 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, + 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, + 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, + 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, + 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, + 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, + 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, + 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, + 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, + 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, + 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, + 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, + 2, 1, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////// // SYMMETRIC mode 4D -TEST_F(DeclarableOpsTests12, pad_tests9) -{ - - double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - double expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - +TEST_F(DeclarableOpsTests12, pad_tests9) { + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = { + 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, + 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, + 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, + 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, + 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, + 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, + 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, + 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, + 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, + 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, + 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, + 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, + 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, + 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, + 15, 15, 16, 16}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests10) { + auto input = NDArrayFactory::create('c', {2, 3, 4}); + auto paddings = NDArrayFactory::create('c', {3, 2}, {0, 0, 0, 1, 0, 0}); + auto expected = NDArrayFactory::create( + 'c', {2, 4, 4}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.}); - auto input = NDArrayFactory::create('c', {2,3,4}); - auto paddings = NDArrayFactory::create('c', {3,2}, {0,0, 0,1, 0,0}); - auto expected = NDArrayFactory::create('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - - input = 1.f; - //input.assign(1.); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); + input = 1.f; + // input.assign(1.); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests11) { + auto input = NDArrayFactory::create('c', {2, 3, 4}); + auto paddings = NDArrayFactory::create('c', {3, 2}, {0, 0, 0, 1, 0, 0}); + auto expected = NDArrayFactory::create( + 'c', {2, 4, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., + 12., 5., 6., 7., 8., 13., 14., 15., 16., 17., 18., + 19., 20., 21., 22., 23., 24., 17., 18., 19., 20.}); - auto input = NDArrayFactory::create('c', {2,3,4}); - auto paddings = NDArrayFactory::create('c', {3,2}, {0,0, 0,1, 0,0}); - auto expected = NDArrayFactory::create('c', {2,4,4}, {1., 2., 3., 4., 5., 6., 7., 8., 9.,10.,11.,12., 5., 6., 7., 8.,13.,14.,15.,16.,17.,18.,19.,20.,21.,22.,23.,24.,17.,18.,19.,20.}); - - input.linspace(1.f); + input.linspace(1.f); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests12) { - - auto input = NDArrayFactory::create('c', {2,3,4,5}); - auto paddings = NDArrayFactory::create('c', {4,2}, {0,0, 0,1, 0,1, 0,0}); - auto expected = NDArrayFactory::create('c', {2,4,5,5}, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 16., 17., 18., 19., 20., - 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 36., 37., 38., 39., 40., - 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 56., 57., 58., 59., 60., - 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 56., 57., 58., 59., 60., - 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77., 78., 79., 80., 76., 77., 78., 79., 80., - 81., 82., 83., 84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., 99.,100., 96., 97., 98., 99.,100., - 101.,102.,103.,104.,105.,106.,107.,108.,109.,110.,111.,112.,113.,114.,115.,116.,117.,118.,119.,120.,116.,117.,118.,119.,120., - 101.,102.,103.,104.,105.,106.,107.,108.,109.,110.,111.,112.,113.,114.,115.,116.,117.,118.,119.,120.,116.,117.,118.,119.,120.}); - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + auto input = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto paddings = + NDArrayFactory::create('c', {4, 2}, {0, 0, 0, 1, 0, 1, 0, 0}); + auto expected = NDArrayFactory::create( + 'c', {2, 4, 5, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 16., 17., 18., 19., + 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., + 32., 33., 34., 35., 36., 37., 38., 39., 40., 36., 37., 38., + 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., + 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 56., 57., + 58., 59., 60., 41., 42., 43., 44., 45., 46., 47., 48., 49., + 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 56., + 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., + 69., 70., 71., 72., 73., 74., 75., 76., 77., 78., 79., 80., + 76., 77., 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., + 88., 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., 99., + 100., 96., 97., 98., 99., 100., 101., 102., 103., 104., 105., 106., + 107., 108., 109., 110., 111., 112., 113., 114., 115., 116., 117., 118., + 119., 120., 116., 117., 118., 119., 120., 101., 102., 103., 104., 105., + 106., 107., 108., 109., 110., 111., 112., 113., 114., 115., 116., 117., + 118., 119., 120., 116., 117., 118., 119., 120.}); + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests13) { + auto input = NDArrayFactory::create('c', {5}); + auto paddings = NDArrayFactory::create('c', {1, 2}, {2, 3}); + auto expected = NDArrayFactory::create( + 'c', {10}, {3., 2., 1., 2., 3., 4., 5., 4., 3., 2.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {5}); - auto paddings = NDArrayFactory::create('c', {1,2}, {2,3}); - auto expected = NDArrayFactory::create('c', {10}, {3., 2., 1., 2., 3., 4., 5., 4., 3., 2.}); - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result->printIndexedBuffer(); + auto result = results.at(0); + // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests14) { + auto input = NDArrayFactory::create('c', {1, 5}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 2, 3}); + auto expected = NDArrayFactory::create( + 'c', {1, 10}, {2., 1., 1., 2., 3., 4., 5., 5., 4., 3.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {1,5}); - auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,2,3}); - auto expected = NDArrayFactory::create('c', {1,10}, {2., 1., 1., 2., 3., 4., 5., 5., 4., 3.}); - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests15) { + auto input = NDArrayFactory::create('c', {1, 5}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 0, 0}); + auto expected = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 1., 2., 3., 4., 5., 1., 2., 3., 4., 5.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {1,5}); - auto paddings = NDArrayFactory::create('c', {2,2}, {1,1,0,0}); - auto expected = NDArrayFactory::create('c', {3,5}, {1., 2., 3., 4., 5., 1., 2., 3., 4., 5., 1., 2., 3., 4., 5.}); - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests16) { + auto input = NDArrayFactory::create('c', {5, 1}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {2, 3, 0, 0}); + auto expected = NDArrayFactory::create( + 'c', {10, 1}, {3., 2., 1., 2., 3., 4., 5., 4., 3., 2.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {5,1}); - auto paddings = NDArrayFactory::create('c', {2,2}, {2,3,0,0}); - auto expected = NDArrayFactory::create('c', {10,1}, {3., 2., 1., 2., 3., 4., 5., 4., 3., 2.}); - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests17) { + auto input = NDArrayFactory::create('c', {5, 1}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 1, 0}); + auto expected = NDArrayFactory::create( + 'c', {5, 2}, {1., 1., 2., 2., 3., 3., 4., 4., 5., 5.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {5,1}); - auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,1,0}); - auto expected = NDArrayFactory::create('c', {5,2}, {1.,1., 2.,2., 3.,3., 4.,4., 5.,5.}); - input.linspace(1.f); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests18) { + auto input = NDArrayFactory::create('c', {5}); + auto paddings = NDArrayFactory::create('c', {1, 2}, {0, 0}); + auto expected = + NDArrayFactory::create('c', {5}, {1., 2., 3., 4., 5.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {5}); - auto paddings = NDArrayFactory::create('c', {1,2}, {0,0}); - auto expected = NDArrayFactory::create('c', {5}, {1.,2.,3.,4.,5.}); - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests19) { + auto input = NDArrayFactory::create('c', {5, 1}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + auto expected = + NDArrayFactory::create('c', {5, 1}, {1., 2., 3., 4., 5.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {5,1}); - auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,0,0}); - auto expected = NDArrayFactory::create('c', {5,1}, {1., 2., 3., 4., 5.}); - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests20) { + auto input = NDArrayFactory::create('c', {1, 5}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + auto expected = + NDArrayFactory::create('c', {1, 5}, {1., 2., 3., 4., 5.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {1,5}); - auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,0,0}); - auto expected = NDArrayFactory::create('c', {1,5}, {1., 2., 3., 4., 5.}); - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests21) { + auto input = NDArrayFactory::create('c', {1, 3, 1, 5}); + auto paddings = + NDArrayFactory::create('c', {4, 2}, {0, 0, 0, 1, 0, 1, 0, 0}); + auto expected = NDArrayFactory::create( + 'c', {1, 4, 2, 5}, + {1., 2., 3., 4., 5., 1., 2., 3., 4., 5., 6., 7., 8., 9., + 10., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 11., 12., 13., + 14., 15., 11., 12., 13., 14., 15., 11., 12., 13., 14., 15.}); + input.linspace(1.f); - auto input = NDArrayFactory::create('c', {1,3,1,5}); - auto paddings = NDArrayFactory::create('c', {4,2}, {0,0, 0,1, 0,1, 0,0}); - auto expected = NDArrayFactory::create('c', {1,4,2,5}, {1., 2., 3., 4., 5., 1., 2., 3., 4., 5., 6., 7., 8., 9.,10., 6., 7., 8., 9.,10., - 11.,12.,13.,14.,15.,11.,12.,13.,14.,15.,11.,12.,13.,14.,15.,11.,12.,13.,14.,15.}); - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); - auto result = results.at(0); - // result->printIndexedBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result->printIndexedBuffer(); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests22) { + auto input = NDArrayFactory::create('c', {1, 1}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + auto expected = NDArrayFactory::create('c', {1, 1}, {1.}); - auto input = NDArrayFactory::create('c', {1,1}); - auto paddings = NDArrayFactory::create('c', {2,2}, {0,0, 0,0}); - auto expected = NDArrayFactory::create('c', {1,1}, {1.}); + input.linspace(1.f); - input.linspace(1.f); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result->printIndexedBuffer(); - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests23) { + auto input = NDArrayFactory::create('c', {1, 1}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 1, 0}); + auto expected = NDArrayFactory::create('c', {1, 2}, {0., 1.}); - auto input = NDArrayFactory::create('c', {1,1}); - auto paddings = NDArrayFactory::create('c', {2,2}, {0,0, 1,0}); - auto expected = NDArrayFactory::create('c', {1,2}, {0.,1.}); - - input.linspace(1.f); + input.linspace(1.f); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result->printShapeInfo("r"); - // expected.printShapeInfo("e"); + auto result = results.at(0); + // result->printShapeInfo("r"); + // expected.printShapeInfo("e"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests24) { + auto input = NDArrayFactory::create('c', {1}); + auto paddings = NDArrayFactory::create('c', {1, 2}, {0, 0}); + auto expected = NDArrayFactory::create('c', {1}, {1.}); - auto input = NDArrayFactory::create('c', {1}); - auto paddings = NDArrayFactory::create('c', {1,2}, {0,0}); - auto expected = NDArrayFactory::create('c', {1}, {1.}); - - input.linspace(1.f); + input.linspace(1.f); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests25) { + auto input = NDArrayFactory::create('c', {1}); + auto paddings = NDArrayFactory::create('c', {1, 2}, {1, 1}); + auto expected = NDArrayFactory::create('c', {3}, {1., 1., 1}); - auto input = NDArrayFactory::create('c', {1}); - auto paddings = NDArrayFactory::create('c', {1,2}, {1,1}); - auto expected = NDArrayFactory::create('c', {3}, {1.,1.,1}); - - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); + input.linspace(1.f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests26) { + auto input = NDArrayFactory::create('c', {1}); + auto paddings = NDArrayFactory::create('c', {1, 2}, {3, 2}); + auto expected = + NDArrayFactory::create('c', {6}, {0., 0., 0., 1., 0., 0.}); - auto input = NDArrayFactory::create('c', {1}); - auto paddings = NDArrayFactory::create('c', {1,2}, {3,2}); - auto expected = NDArrayFactory::create('c', {6}, {0., 0., 0., 1., 0., 0.}); - - input.linspace(1.f); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); + input.linspace(1.f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests27) { + NDArray input('c', {2, 3}, sd::DataType::FLOAT32); + NDArray paddings('c', {2, 2}, {0, 0, 0, 1}, sd::DataType::INT32); + NDArray exp('c', {2, 4}, {1, 1, 1, 0, 1, 1, 1, 0}, sd::DataType::FLOAT32); + NDArray z('c', {2, 4}, sd::DataType::FLOAT32); + input = 1.; - NDArray input('c', {2,3}, sd::DataType::FLOAT32); - NDArray paddings('c', {2,2}, {0,0,0,1}, sd::DataType::INT32); - NDArray exp('c', {2,4}, {1,1,1,0,1,1,1,0}, sd::DataType::FLOAT32); - NDArray z('c', {2,4}, sd::DataType::FLOAT32); - input = 1.; + sd::ops::pad op; + Nd4jStatus status = + op.execute({&input, &paddings}, {&z}, {0}, {0}, {}); // constant + // z.printIndexedBuffer(); - sd::ops::pad op; - Nd4jStatus status = op.execute({&input, &paddings}, {&z}, {0}, {0}, {}); // constant - // z.printIndexedBuffer(); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(exp.isSameShapeStrict(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(exp.isSameShapeStrict(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests28) { + NDArray input('c', {1, 111, 111, 32}, sd::DataType::FLOAT32); + NDArray paddings('c', {4, 2}, {0, 0, 0, 1, 0, 1, 0, 0}, sd::DataType::INT32); + NDArray z('c', {1, 112, 112, 32}, sd::DataType::FLOAT32); + input = 1.; - NDArray input('c', {1,111,111,32}, sd::DataType::FLOAT32); - NDArray paddings('c', {4,2}, {0,0,0,1,0,1,0,0}, sd::DataType::INT32); - NDArray z('c', {1,112,112,32}, sd::DataType::FLOAT32); - input = 1.; - - sd::ops::pad op; - Nd4jStatus status = op.execute({&input, &paddings}, {&z}, {0}, {0}, {}); // constant - // z.printIndexedBuffer(); + sd::ops::pad op; + Nd4jStatus status = + op.execute({&input, &paddings}, {&z}, {0}, {0}, {}); // constant + // z.printIndexedBuffer(); - NDArray sum = z.reduceNumber(sd::reduce::Sum); + NDArray sum = z.reduceNumber(sd::reduce::Sum); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_EQ(sum.e(0), 111*111*32); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(sum.e(0), 111 * 111 * 32); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests29) { + auto in = NDArrayFactory::create({1., 1., 1., 1., 1.}); + // auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = + // Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); + auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); + // auto value(10.0); - auto in = NDArrayFactory::create({1., 1., 1., 1., 1.}); -// auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); - auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); -// auto value(10.0); - - auto exp = NDArrayFactory::create({10., 1., 1., 1., 1., 1., 10.}); + auto exp = NDArrayFactory::create({10., 1., 1., 1., 1., 1., 10.}); - sd::ops::pad op; + sd::ops::pad op; - auto res = op.evaluate({&in, &pad}, {10.0}, {0}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(exp.equalsTo(res.at(0))); - + auto res = op.evaluate({&in, &pad}, {10.0}, {0}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests30) { + auto in = NDArrayFactory::create({1., 11., 111., 11., 1.}); + auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); - auto in = NDArrayFactory::create({1., 11., 111., 11., 1.}); - auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); - - auto exp = NDArrayFactory::create({1., 1., 11., 111., 11., 1., 1.}); + auto exp = NDArrayFactory::create({1., 1., 11., 111., 11., 1., 1.}); - sd::ops::pad op; + sd::ops::pad op; - auto res = op.evaluate({&in, &pad}, {10.0}, {2}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(exp.equalsTo(res.at(0))); - + auto res = op.evaluate({&in, &pad}, {10.0}, {2}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests31) { + auto in = NDArrayFactory::create({1., 11., 111., 1111., 11111.}); + // auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = + // Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); + auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); + // auto value(10.0); - auto in = NDArrayFactory::create({1., 11., 111., 1111., 11111.}); -// auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); - auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); -// auto value(10.0); + auto exp = NDArrayFactory::create( + {11., 1., 11., 111., 1111., 11111., 1111.}); - auto exp = NDArrayFactory::create({11., 1., 11., 111., 1111., 11111., 1111.}); + sd::ops::pad op; - sd::ops::pad op; - - auto res = op.evaluate({&in, &pad}, {10.0}, {1}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(exp.equalsTo(res.at(0))); - + auto res = op.evaluate({&in, &pad}, {10.0}, {1}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } /////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests32) { + auto in = NDArrayFactory::create('c', {3, 3}, + {1., 2., 3., 4., 5., 6, 7, 8, 9}); + auto pad = NDArrayFactory::create('c', {2, 2}, {1, 2, 2, 3}); - auto in = NDArrayFactory::create('c', {3,3}, {1., 2., 3., 4., 5.,6,7,8,9}); - auto pad = NDArrayFactory::create('c', {2,2}, {1, 2, 2, 3}); - - auto exp = NDArrayFactory::create('c', {6,8}, {2, 1, 1, 2, 3, 3, 2, 1, 2, 1, 1, 2, 3, 3, 2, 1, 5, 4, 4, 5, 6, 6, 5, 4, 8, 7, 7, 8, 9, 9, 8, 7, 8, 7, 7, 8, 9, 9, 8, 7, 5, 4, 4, 5, 6, 6, 5, 4}); + auto exp = NDArrayFactory::create( + 'c', {6, 8}, + {2, 1, 1, 2, 3, 3, 2, 1, 2, 1, 1, 2, 3, 3, 2, 1, 5, 4, 4, 5, 6, 6, 5, 4, + 8, 7, 7, 8, 9, 9, 8, 7, 8, 7, 7, 8, 9, 9, 8, 7, 5, 4, 4, 5, 6, 6, 5, 4}); - sd::ops::pad op; + sd::ops::pad op; - auto res = op.evaluate({&in, &pad}, {10.0}, {2}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(exp.equalsTo(res.at(0))); - + auto res = op.evaluate({&in, &pad}, {10.0}, {2}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } /////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests33) { - - auto in = NDArrayFactory::create('c', {2,3,4}, {1, 2, 3, 4,5, 6, 7, 8,9,10,11,12,13, 14, 15, 16,17, 18, 19, 20,21, 22, 23, 24}); - - auto pad = NDArrayFactory::create('c', {3,2}, {1, 2, 2, 3, 3,3}); - - auto exp = NDArrayFactory::create('c', {5,8,10}, { 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., 11,10,9,9,10,11,12,12,11,10., - 11,10,9,9,10,11,12,12,11,10., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., - 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., 11,10,9,9,10,11,12,12,11,10., 11,10,9,9,10,11,12,12,11,10.,7,6,5,5,6,7,8,8,7,6., - 3,2,1,1,2,3,4,4,3,2., 19,18,17,17,18,19,20,20,19,18., 15,14,13,13,14,15,16,16,15,14., 15,14,13,13,14,15,16,16,15,14., - 19,18,17,17,18,19,20,20,19,18., 23,22,21,21,22,23,24,24,23,22., 23,22,21,21,22,23,24,24,23,22., 19,18,17,17,18,19,20,20,19,18., - 15,14,13,13,14,15,16,16,15,14., 19,18,17,17,18,19,20,20,19,18., 15,14,13,13,14,15,16,16,15,14., 15,14,13,13,14,15,16,16,15,14., - 19,18,17,17,18,19,20,20,19,18., 23,22,21,21,22,23,24,24,23,22., 23,22,21,21,22,23,24,24,23,22., 19,18,17,17,18,19,20,20,19,18., - 15,14,13,13,14,15,16,16,15,14., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., - 11,10,9,9,10,11,12,12,11,10., 11,10,9,9,10,11,12,12,11,10., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2.}); - sd::ops::pad op; - - auto res = op.evaluate({&in, &pad}, {10.0}, {2}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - ASSERT_TRUE(exp.equalsTo(res.at(0))); - + auto in = NDArrayFactory::create( + 'c', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + + auto pad = NDArrayFactory::create('c', {3, 2}, {1, 2, 2, 3, 3, 3}); + + auto exp = NDArrayFactory::create( + 'c', {5, 8, 10}, + {7, 6, 5, 5, 6, 7, 8, 8, 7, 6., 3, 2, 1, 1, 2, 3, + 4, 4, 3, 2., 3, 2, 1, 1, 2, 3, 4, 4, 3, 2., 7, 6, + 5, 5, 6, 7, 8, 8, 7, 6., 11, 10, 9, 9, 10, 11, 12, 12, + 11, 10., 11, 10, 9, 9, 10, 11, 12, 12, 11, 10., 7, 6, 5, 5, + 6, 7, 8, 8, 7, 6., 3, 2, 1, 1, 2, 3, 4, 4, 3, 2., + 7, 6, 5, 5, 6, 7, 8, 8, 7, 6., 3, 2, 1, 1, 2, 3, + 4, 4, 3, 2., 3, 2, 1, 1, 2, 3, 4, 4, 3, 2., 7, 6, + 5, 5, 6, 7, 8, 8, 7, 6., 11, 10, 9, 9, 10, 11, 12, 12, + 11, 10., 11, 10, 9, 9, 10, 11, 12, 12, 11, 10., 7, 6, 5, 5, + 6, 7, 8, 8, 7, 6., 3, 2, 1, 1, 2, 3, 4, 4, 3, 2., + 19, 18, 17, 17, 18, 19, 20, 20, 19, 18., 15, 14, 13, 13, 14, 15, + 16, 16, 15, 14., 15, 14, 13, 13, 14, 15, 16, 16, 15, 14., 19, 18, + 17, 17, 18, 19, 20, 20, 19, 18., 23, 22, 21, 21, 22, 23, 24, 24, + 23, 22., 23, 22, 21, 21, 22, 23, 24, 24, 23, 22., 19, 18, 17, 17, + 18, 19, 20, 20, 19, 18., 15, 14, 13, 13, 14, 15, 16, 16, 15, 14., + 19, 18, 17, 17, 18, 19, 20, 20, 19, 18., 15, 14, 13, 13, 14, 15, + 16, 16, 15, 14., 15, 14, 13, 13, 14, 15, 16, 16, 15, 14., 19, 18, + 17, 17, 18, 19, 20, 20, 19, 18., 23, 22, 21, 21, 22, 23, 24, 24, + 23, 22., 23, 22, 21, 21, 22, 23, 24, 24, 23, 22., 19, 18, 17, 17, + 18, 19, 20, 20, 19, 18., 15, 14, 13, 13, 14, 15, 16, 16, 15, 14., + 7, 6, 5, 5, 6, 7, 8, 8, 7, 6., 3, 2, 1, 1, 2, 3, + 4, 4, 3, 2., 3, 2, 1, 1, 2, 3, 4, 4, 3, 2., 7, 6, + 5, 5, 6, 7, 8, 8, 7, 6., 11, 10, 9, 9, 10, 11, 12, 12, + 11, 10., 11, 10, 9, 9, 10, 11, 12, 12, 11, 10., 7, 6, 5, 5, + 6, 7, 8, 8, 7, 6., 3, 2, 1, 1, 2, 3, 4, 4, 3, 2.}); + sd::ops::pad op; + + auto res = op.evaluate({&in, &pad}, {10.0}, {2}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, pad_tests34) { + NDArray input('c', {5}, {0.778786, 0.801198, 0.724375, 0.230894, 0.727141}, + sd::DataType::FLOAT32); + NDArray paddings('c', {1, 2}, {1, 1}, sd::DataType::INT32); + NDArray expected('c', {7}, + {10., 0.778786, 0.801198, 0.724375, 0.230894, 0.727141, 10.}, + sd::DataType::FLOAT32); + NDArray z('c', {7}, sd::DataType::FLOAT32); - NDArray input('c', {5}, {0.778786, 0.801198, 0.724375, 0.230894, 0.727141}, sd::DataType::FLOAT32); - NDArray paddings('c', {1,2}, {1,1}, sd::DataType::INT32); - NDArray expected('c', {7}, {10., 0.778786, 0.801198, 0.724375, 0.230894, 0.727141, 10.}, sd::DataType::FLOAT32); - NDArray z('c', {7}, sd::DataType::FLOAT32); - - sd::ops::pad op; - Nd4jStatus status = op.execute({&input, &paddings}, {&z}, {10}, {0}, {}); // constant + sd::ops::pad op; + Nd4jStatus status = + op.execute({&input, &paddings}, {&z}, {10}, {0}, {}); // constant - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.isSameShapeStrict(z)); - ASSERT_TRUE(expected.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.isSameShapeStrict(z)); + ASSERT_TRUE(expected.equalsTo(z)); } //////////////////////////////////////////////////////////////////// // CONSTANT mode 2D TEST_F(DeclarableOpsTests12, Pad_1) { + double inBuff[] = {1, 2, 3, 4, 5, 6}; + int padBuff[] = {1, 1, 2, 2}; + double expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0, 0, + 0, 0, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - double inBuff[] = {1,2,3,4,5,6}; - int padBuff[] = {1,1,2,2}; - double expBuff[] = {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}; - - auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result->printIndexedBuffer(); + auto result = results.at(0); + // result->printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // REFLECT mode 2D TEST_F(DeclarableOpsTests12, Pad_2) { + double inBuff[] = {1, 2, 3, 4, 5, 6}; + int padBuff[] = {1, 1, 2, 2}; + double expBuff[] = {6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, + 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; - double inBuff[] = {1,2,3,4,5,6}; - int padBuff[] = {1,1,2,2}; - double expBuff[] = {6,5,4,5,6,5,4, 3,2,1,2,3,2,1, 6,5,4,5,6,5,4, 3,2,1,2,3,2,1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7}); - auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result->printIndexedBuffer(); - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // SYMMETRIC mode 2D TEST_F(DeclarableOpsTests12, Pad_3) { + double inBuff[] = {1, 2, 3, 4, 5, 6}; + int padBuff[] = {1, 1, 2, 2}; + double expBuff[] = {2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, + 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}; - double inBuff[] = {1,2,3,4,5,6}; - int padBuff[] = {1,1,2,2}; - double expBuff[] = {2,1,1,2,3,3,2, 2,1,1,2,3,3,2, 5,4,4,5,6,6,5, 5,4,4,5,6,6,5}; - - auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); - auto result = results.at(0); - // result->printIndexedBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result->printIndexedBuffer(); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // CONSTANT mode 3D TEST_F(DeclarableOpsTests12, Pad_4) { + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + int padBuff[] = {1, 1, 2, 2, 2, 2}; + double expBuff[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0, 0, 7, 8, 9, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 10, 11, 12, 0, 0, 0, 0, 13, 14, 15, 0, 0, 0, 0, 16, 17, 18, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; - int padBuff[] = {1,1,2,2,2,2}; - double expBuff[] = {0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 1, 2, 3,0,0,0,0, 4, 5, 6,0,0,0,0, 7, 8, 9,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0,10,11,12,0,0,0,0,13,14,15,0,0,0,0,16,17,18,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0}; - - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7, 7}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); - auto result = results.at(0); - // result->printIndexedBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result->printIndexedBuffer(); - + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - - //////////////////////////////////////////////////////////////////// // REFLECT mode 3D TEST_F(DeclarableOpsTests12, Pad_5) { - - double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; - int padBuff[] = {1,1,2,2,2,2}; - double expBuff[] = {18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + int padBuff[] = {1, 1, 2, 2, 2, 2}; + double expBuff[] = { + 18, 17, 16, 17, 18, 17, 16, 15, 14, 13, 14, 15, 14, 13, 12, 11, 10, 11, + 12, 11, 10, 15, 14, 13, 14, 15, 14, 13, 18, 17, 16, 17, 18, 17, 16, 15, + 14, 13, 14, 15, 14, 13, 12, 11, 10, 11, 12, 11, 10, 9, 8, 7, 8, 9, + 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, + 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, + 4, 3, 2, 1, 2, 3, 2, 1, 18, 17, 16, 17, 18, 17, 16, 15, 14, 13, + 14, 15, 14, 13, 12, 11, 10, 11, 12, 11, 10, 15, 14, 13, 14, 15, 14, 13, + 18, 17, 16, 17, 18, 17, 16, 15, 14, 13, 14, 15, 14, 13, 12, 11, 10, 11, + 12, 11, 10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, + 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, + 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7, 7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// // SYMMETRIC mode 3D TEST_F(DeclarableOpsTests12, Pad_6) { - - double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; - int padBuff[] = {1,1,2,2,2,2}; - double expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14}; - - auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + int padBuff[] = {1, 1, 2, 2, 2, 2}; + double expBuff[] = { + 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, + 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, + 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, + 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, + 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, + 8, 5, 4, 4, 5, 6, 6, 5, 14, 13, 13, 14, 15, 15, 14, 11, 10, 10, + 11, 12, 12, 11, 11, 10, 10, 11, 12, 12, 11, 14, 13, 13, 14, 15, 15, 14, + 17, 16, 16, 17, 18, 18, 17, 17, 16, 16, 17, 18, 18, 17, 14, 13, 13, 14, + 15, 15, 14, 14, 13, 13, 14, 15, 15, 14, 11, 10, 10, 11, 12, 12, 11, 11, + 10, 10, 11, 12, 12, 11, 14, 13, 13, 14, 15, 15, 14, 17, 16, 16, 17, 18, + 18, 17, 17, 16, 16, 17, 18, 18, 17, 14, 13, 13, 14, 15, 15, 14}; + + auto input = NDArrayFactory::create(inBuff, 'c', {2, 3, 3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 7, 7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// // CONSTANT mode 4D -TEST_F(DeclarableOpsTests12, Pad_7) -{ - - double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - double expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - +TEST_F(DeclarableOpsTests12, Pad_7) { + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, + 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, + 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// // REFLECT mode 4D -TEST_F(DeclarableOpsTests12, Pad_8) -{ - - double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - double expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - +TEST_F(DeclarableOpsTests12, Pad_8) { + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = { + 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, + 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, + 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, + 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, + 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, + 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, + 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, + 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, + 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, + 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, + 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, + 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, + 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, + 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, + 2, 1, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } ////////////////////////////////////////////////////////////////// // SYMMETRIC mode 4D -TEST_F(DeclarableOpsTests12, Pad_9) -{ - - double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; - double expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16}; - auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); - - sd::ops::pad op; - auto results = op.evaluate({&input, &paddings}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result->printIndexedBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - +TEST_F(DeclarableOpsTests12, Pad_9) { + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = { + 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, + 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, + 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, + 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, + 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, + 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, + 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, + 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, + 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, + 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, + 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, + 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, + 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, + 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, + 15, 15, 16, 16}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); } TEST_F(DeclarableOpsTests12, Test_Expose_1) { - auto input0 = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 6, 5, 4}); - auto input1 = NDArrayFactory::create('c', {2, 3}, {3, 2, 1, 4, 5, 6}); - - sd::ops::expose op; + auto input0 = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 6, 5, 4}); + auto input1 = NDArrayFactory::create('c', {2, 3}, {3, 2, 1, 4, 5, 6}); - auto result = op.evaluate({&input0, &input1}); + sd::ops::expose op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&input0, &input1}); - auto z0 = result.at(0); - auto z1 = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(input0.equalsTo(z0)); - ASSERT_TRUE(input1.equalsTo(z1)); + auto z0 = result.at(0); + auto z1 = result.at(1); - + ASSERT_TRUE(input0.equalsTo(z0)); + ASSERT_TRUE(input1.equalsTo(z1)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, Pad_SGO_Test_1) { + auto in = NDArrayFactory::create({1., 1., 1., 1., 1.}); + // auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = + // Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); + auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); + // auto value(10.0); - auto in = NDArrayFactory::create({1., 1., 1., 1., 1.}); -// auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); - auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); -// auto value(10.0); - - auto exp = NDArrayFactory::create({10., 1., 1., 1., 1., 1., 10.}); + auto exp = NDArrayFactory::create({10., 1., 1., 1., 1., 1., 10.}); - sd::ops::pad op; + sd::ops::pad op; - auto res = op.evaluate({&in, &pad}, {10.0}, {0}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - // res.at(0)->printIndexedBuffer("PAD_SGO"); - // exp.printIndexedBuffer("PAD_EXP"); - ASSERT_TRUE(exp.equalsTo(res.at(0))); - + auto res = op.evaluate({&in, &pad}, {10.0}, {0}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + // res.at(0)->printIndexedBuffer("PAD_SGO"); + // exp.printIndexedBuffer("PAD_EXP"); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_1) { + auto in = NDArrayFactory::create( + 'c', {3, 3}, {1., 2., 3., 0., 2., 3., 0., 0., 7.}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {1., 2., 3., 0., 2., 3., 0., 0., 7}); + auto pExp = NDArrayFactory::create('c', {3}, {0, 1, 2}); + sd::ops::lu op; - auto in = NDArrayFactory::create('c', {3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7.}); - auto exp = NDArrayFactory::create('c', {3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7}); - auto pExp = NDArrayFactory::create('c', {3}, {0, 1, 2}); - sd::ops::lu op; - - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printIndexedBuffer("Triangulars"); -// p->printIndexedBuffer("Permutaions"); - - ASSERT_TRUE(exp.equalsTo(z)); - ASSERT_TRUE(pExp.equalsTo(p)); + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printIndexedBuffer("Triangulars"); + // p->printIndexedBuffer("Permutaions"); - + ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(pExp.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_2) { - auto in = NDArrayFactory::create('c', {3,3}, {1, 0, 0, 2, 3, 0, 4, 5, 6}); + auto in = + NDArrayFactory::create('c', {3, 3}, {1, 0, 0, 2, 3, 0, 4, 5, 6}); - auto expLU = NDArrayFactory::create('c', {3,3}, {4., 5., 6., 0.25, -1.25, -1.5, 0.5, -0.4, -3.6}); - auto expP = NDArrayFactory::create({2, 0, 1}); - sd::ops::lu op; + auto expLU = NDArrayFactory::create( + 'c', {3, 3}, {4., 5., 6., 0.25, -1.25, -1.5, 0.5, -0.4, -3.6}); + auto expP = NDArrayFactory::create({2, 0, 1}); + sd::ops::lu op; - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printIndexedBuffer("Triangulars2"); -// p->printIndexedBuffer("Permutaions2"); - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printIndexedBuffer("Triangulars2"); + // p->printIndexedBuffer("Permutaions2"); + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_3) { - auto in = NDArrayFactory::create('c', {3,3}, {1,2,3,4,7,9, 11, 12, 13}); + auto in = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, 4, 7, 9, 11, 12, 13}); - auto expLU = NDArrayFactory::create('c', {3,3}, { - 11., 12., 13., - 0.36363637, 2.6363635, 4.272727, - 0.09090909, 0.3448276, 0.34482753}); + auto expLU = NDArrayFactory::create( + 'c', {3, 3}, + {11., 12., 13., 0.36363637, 2.6363635, 4.272727, 0.09090909, 0.3448276, + 0.34482753}); - auto expP = NDArrayFactory::create({2, 1, 0}); - sd::ops::lu op; + auto expP = NDArrayFactory::create({2, 1, 0}); + sd::ops::lu op; - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printIndexedBuffer("Triangulars3"); -// p->printIndexedBuffer("Permutaions3"); - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printIndexedBuffer("Triangulars3"); + // p->printIndexedBuffer("Permutaions3"); + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_4) { - - auto in = NDArrayFactory::create('c', {10,10}, { - 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., - 5., 1., 13., 4., 15., 1., 17., 9., 11., 25., - 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., - 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., - 2., 3., 2., 5., 4., 4., 7., 3, 3., 4., - 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., - 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., - 3., 4., 3., 3., 4., 4., 4., 1., 3., 1., - 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., - 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); - - auto expLU = NDArrayFactory::create('c', {10,10}, { - 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0, - 0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0, - 0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, - 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957, - 0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323, - 0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, - 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300, - 0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119, - 0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, - 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695 - }); - - auto expP = NDArrayFactory::create({1, 2, 7, 3, 6, 8, 5, 4, 0, 9}); - sd::ops::lu op; - - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printBuffer("Triangulars4"); -// expLU.printBuffer("TriangulExp4"); -// p->printBuffer("Permutaions4"); - - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + auto in = NDArrayFactory::create( + 'c', {10, 10}, + {1., 2., 3., 4., 5., 6., 7., 8., 1., 15., 5., 1., 13., 4., 15., + 1., 17., 9., 11., 25., 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., + 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., 2., 3., 2., 5., 4., + 4., 7., 3, 3., 4., 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., + 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., 3., 4., 3., 3., 4., + 4., 4., 1., 3., 1., 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); + + auto expLU = NDArrayFactory::create( + 'c', {10, 10}, + {5.0, 1.0, 13.0, 4.0, 15.0, 1.0, + 17.0, 9.0, 11.0, 25.0, 0.2, 8.8, + -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, + 18.8, 10.0, 0.6, 0.386364, -4.181818, -0.636364, + -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, + 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, + -6.978261, -8.114130, -17.641304, -9.836957, 0.4, 0.068182, + 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, + 22.365079, 25.751323, 0.2, 0.090909, 0.347826, -0.031746, + -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, + 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, + -7.593805, -9.585099, 1.663379, -15.900300, 0.4, 0.295455, + 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, + -16.392106, -9.022119, 0.2, 0.204545, -0.173913, -0.592593, + 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, + 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, + -0.030154, -0.243578, 0.087256, 0.112695}); + + auto expP = NDArrayFactory::create({1, 2, 7, 3, 6, 8, 5, 4, 0, 9}); + sd::ops::lu op; + + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printBuffer("Triangulars4"); + // expLU.printBuffer("TriangulExp4"); + // p->printBuffer("Permutaions4"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } TEST_F(DeclarableOpsTests12, LU_Test_5) { - - auto in = NDArrayFactory::create('c', {2, 10,10}, { - 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., - 5., 1., 13., 4., 15., 1., 17., 9., 11., 25., - 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., - 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., - 2., 3., 2., 5., 4., 4., 7., 3, 3., 4., - 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., - 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., - 3., 4., 3., 3., 4., 4., 4., 1., 3., 1., - 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., - 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., - - 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., - 5., 1., 13., 4., 15., 1., 17., 9., 11., 25., - 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., - 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., - 2., 3., 2., 5., 4., 4., 7., 3, 3., 4., - 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., - 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., - 3., 4., 3., 3., 4., 4., 4., 1., 3., 1., - 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., - 1., 1., 1., 1., 1., 1., 1., 1., 1., 1. - }); - - auto expLU = NDArrayFactory::create('c', {2, 10,10}, { - 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0, - 0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0, - 0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, - 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957, - 0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323, - 0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, - 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300, - 0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119, - 0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, - 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695, - - 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0, - 0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0, - 0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, - 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957, - 0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323, - 0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, - 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300, - 0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119, - 0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, - 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695 - - }); - - auto expP = NDArrayFactory::create('c', {2, 10}, { - 1, 2, 7, 3, 6, 8, 5, 4, 0, 9, - 1, 2, 7, 3, 6, 8, 5, 4, 0, 9 - }); - sd::ops::lu op; - - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printBuffer("Triangulars5"); -// expLU.printBuffer("TriangulExp5"); -// p->printBuffer("Permutaions5"); - - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + auto in = NDArrayFactory::create( + 'c', {2, 10, 10}, + {1., 2., 3., 4., 5., 6., 7., 8., 1., 15., 5., 1., 13., 4., 15., + 1., 17., 9., 11., 25., 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., + 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., 2., 3., 2., 5., 4., + 4., 7., 3, 3., 4., 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., + 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., 3., 4., 3., 3., 4., + 4., 4., 1., 3., 1., 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + + 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., 5., 1., 13., 4., 15., + 1., 17., 9., 11., 25., 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., + 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., 2., 3., 2., 5., 4., + 4., 7., 3, 3., 4., 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., + 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., 3., 4., 3., 3., 4., + 4., 4., 1., 3., 1., 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); + + auto expLU = NDArrayFactory::create( + 'c', {2, 10, 10}, + {5.0, 1.0, 13.0, 4.0, 15.0, 1.0, + 17.0, 9.0, 11.0, 25.0, 0.2, 8.8, + -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, + 18.8, 10.0, 0.6, 0.386364, -4.181818, -0.636364, + -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, + 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, + -6.978261, -8.114130, -17.641304, -9.836957, 0.4, 0.068182, + 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, + 22.365079, 25.751323, 0.2, 0.090909, 0.347826, -0.031746, + -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, + 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, + -7.593805, -9.585099, 1.663379, -15.900300, 0.4, 0.295455, + 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, + -16.392106, -9.022119, 0.2, 0.204545, -0.173913, -0.592593, + 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, + 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, + -0.030154, -0.243578, 0.087256, 0.112695, + + 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, + 17.0, 9.0, 11.0, 25.0, 0.2, 8.8, + -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, + 18.8, 10.0, 0.6, 0.386364, -4.181818, -0.636364, + -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, + 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, + -6.978261, -8.114130, -17.641304, -9.836957, 0.4, 0.068182, + 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, + 22.365079, 25.751323, 0.2, 0.090909, 0.347826, -0.031746, + -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, + 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, + -7.593805, -9.585099, 1.663379, -15.900300, 0.4, 0.295455, + 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, + -16.392106, -9.022119, 0.2, 0.204545, -0.173913, -0.592593, + 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, + 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, + -0.030154, -0.243578, 0.087256, 0.112695 + + }); + + auto expP = NDArrayFactory::create( + 'c', {2, 10}, + {1, 2, 7, 3, 6, 8, 5, 4, 0, 9, 1, 2, 7, 3, 6, 8, 5, 4, 0, 9}); + sd::ops::lu op; + + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printBuffer("Triangulars5"); + // expLU.printBuffer("TriangulExp5"); + // p->printBuffer("Permutaions5"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_1_2) { + auto in = NDArrayFactory::create( + 'c', {2, 3, 3}, + {1., 2., 3., 0., 2., 3., 0., 0., 7., 1., 2., 3., 0., 2., 3., 0., 0., 7.}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3}, + {1., 2., 3., 0., 2., 3., 0., 0., 7, 1., 2., 3., 0., 2., 3., 0., 0., 7.}); - auto in = NDArrayFactory::create('c', {2, 3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7.,1., 2., 3., 0., 2., 3., 0., 0., 7.}); - auto exp = NDArrayFactory::create('c', {2, 3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7, 1., 2., 3., 0., 2., 3., 0., 0., 7.}); - - sd::ops::lu op; + sd::ops::lu op; - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printIndexedBuffer("Triangulars (2,3,3)"); -// p->printIndexedBuffer("Permutaions (2,3,3)"); - ASSERT_TRUE(exp.equalsTo(res.at(0))); - + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printIndexedBuffer("Triangulars (2,3,3)"); + // p->printIndexedBuffer("Permutaions (2,3,3)"); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_3_2) { + auto in = NDArrayFactory::create( + 'c', {2, 3, 3}, + {1, 2, 3, 4, 7, 9, 11, 12, 13, 1, 2, 3, 4, 7, 9, 11, 12, 13}); - auto in = NDArrayFactory::create('c', {2, 3,3}, {1,2,3,4,7,9, 11, 12, 13,1,2,3,4,7,9, 11, 12, 13}); + auto expLU = NDArrayFactory::create( + 'c', {2, 3, 3}, + {11., 12., 13., 0.36363637, 2.6363635, 4.272727, 0.09090909, 0.3448276, + 0.34482753, - auto expLU = NDArrayFactory::create('c', {2, 3,3}, { - 11., 12., 13., - 0.36363637, 2.6363635, 4.272727, - 0.09090909, 0.3448276, 0.34482753, + 11., 12., 13., 0.36363637, 2.6363635, 4.272727, 0.09090909, 0.3448276, + 0.34482753}); - 11., 12., 13., - 0.36363637, 2.6363635, 4.272727, - 0.09090909, 0.3448276, 0.34482753 - }); + auto expP = NDArrayFactory::create('c', {2, 3}, {2, 1, 0, 2, 1, 0}); + sd::ops::lu op; - auto expP = NDArrayFactory::create('c', {2,3}, {2, 1, 0, 2, 1, 0}); - sd::ops::lu op; + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printIndexedBuffer("Triangulars3_2"); + // p->printIndexedBuffer("Permutaions3_2"); - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printIndexedBuffer("Triangulars3_2"); -// p->printIndexedBuffer("Permutaions3_2"); - - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_3_3) { + auto in = NDArrayFactory::create( + 'c', {2, 3, 3}, + {1, 2, 3, 4, 7, 9, 11, 12, 13, 13, 2, 3, 4, 7, 9, 11, 12, 1}); + auto expLU = NDArrayFactory::create( + 'c', {2, 3, 3}, + {11., 12., 13., 0.36363637, 2.6363635, 4.272727, 0.09090909, 0.3448276, + 0.34482753, - auto in = NDArrayFactory::create('c', {2, 3,3}, {1,2,3,4,7,9, 11, 12, 13,13,2,3,4,7,9, 11, 12, 1}); - auto expLU = NDArrayFactory::create('c', {2, 3,3}, { - 11., 12., 13., - 0.36363637, 2.6363635, 4.272727, - 0.09090909, 0.3448276, 0.34482753, - - 13., 2., 3., - 0.84615386, 10.307693, -1.5384617, - 0.30769232, 0.619403, 9.029851}); + 13., 2., 3., 0.84615386, 10.307693, -1.5384617, 0.30769232, 0.619403, + 9.029851}); - auto expP = NDArrayFactory::create('c', {2,3}, {2, 1, 0, 0, 2, 1}); - sd::ops::lu op; + auto expP = NDArrayFactory::create('c', {2, 3}, {2, 1, 0, 0, 2, 1}); + sd::ops::lu op; - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printIndexedBuffer("Triangulars3_3"); -// p->printIndexedBuffer("Permutaions3_3"); + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printIndexedBuffer("Triangulars3_3"); + // p->printIndexedBuffer("Permutaions3_3"); - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_4_1) { + auto in = NDArrayFactory::create( + 'c', {2, 2, 2}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f}); - auto in = NDArrayFactory::create('c', {2, 2,2}, { - 0.7788f, 0.8012f, 0.7244f, 0.2309f, - 0.7271f, 0.1804f, 0.5056f, 0.8925f - }); + auto expLU = + NDArrayFactory::create('c', {2, 2, 2}, + {0.7788f, 0.8012f, 0.930149f, -0.514335f, + 0.7271f, 0.1804f, 0.695365f, 0.767056f}); - auto expLU = NDArrayFactory::create('c', {2, 2,2}, { - 0.7788f, 0.8012f, 0.930149f, -0.514335f, - 0.7271f, 0.1804f, 0.695365f, 0.767056f - }); + auto expP = NDArrayFactory::create('c', {2, 2}, {0, 1, 0, 1}); + sd::ops::lu op; - auto expP = NDArrayFactory::create('c', {2,2}, {0, 1, 0, 1}); - sd::ops::lu op; + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + // z->printIndexedBuffer("Triangulars4_1"); + // p->printIndexedBuffer("Permutaions4_1"); - auto res = op.evaluate({&in}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); -// z->printIndexedBuffer("Triangulars4_1"); -// p->printIndexedBuffer("Permutaions4_1"); - - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, LU_Test_4_2) { + auto in = NDArrayFactory::create( + 'c', {2, 2, 2}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f}); - auto in = NDArrayFactory::create('c', {2, 2,2}, { - 0.7788f, 0.8012f, 0.7244f, 0.2309f, - 0.7271f, 0.1804f, 0.5056f, 0.8925f - }); - - auto expLU = NDArrayFactory::create('c', {2, 2,2}, { - 0.7788f, 0.8012f, 0.930149f, -0.514335f, - 0.7271f, 0.1804f, 0.695365f, 0.767056f - }); + auto expLU = + NDArrayFactory::create('c', {2, 2, 2}, + {0.7788f, 0.8012f, 0.930149f, -0.514335f, + 0.7271f, 0.1804f, 0.695365f, 0.767056f}); - auto expP = NDArrayFactory::create('c', {2,2}, {0, 1, 0, 1}); - sd::ops::lu op; + auto expP = NDArrayFactory::create('c', {2, 2}, {0, 1, 0, 1}); + sd::ops::lu op; - auto res = op.evaluate({&in}, {}, {sd::DataType::INT64}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - auto p = res.at(1); + auto res = op.evaluate({&in}, {}, {sd::DataType::INT64}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); -// z->printIndexedBuffer("Triangulars4_2"); -// p->printIndexedBuffer("Permutaions4_2"); + // z->printIndexedBuffer("Triangulars4_2"); + // p->printIndexedBuffer("Permutaions4_2"); - ASSERT_TRUE(expLU.equalsTo(z)); - ASSERT_TRUE(expP.equalsTo(p)); - + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, QR_Test_1) { - - auto in = NDArrayFactory::create('c', {5,3}, { - 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3. - }); - auto expQ = NDArrayFactory::create('c', {5, 5}, { - 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485 - }); - - auto expR = NDArrayFactory::create('c', {5,3}, { - -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0. }); - sd::ops::qr op; - auto res = op.evaluate({&in}, {}, {}, {true}); - - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto q = res.at(0); - auto r = res.at(1); -// q->printIndexedBuffer("Orthogonal 5x5"); -// expQ.printBuffer("Orthogonal Exp"); -// r->printIndexedBuffer("Upper triangular 5x3"); -// expR.printBuffer("Upper triangular Exp"); -// q->printShapeInfo("Q shape"); -// r->printShapeInfo("R shape"); - sd::ops::matmul opMul; - auto res2 = opMul.evaluate({&q, &r}); //MmulHelper::matmul(q, r, &in, false, false); - auto exp = res2.at(0);//->printIndexedBuffer("Result as result"); - ASSERT_TRUE(exp.isSameShape(in)); -// ASSERT_TRUE(q->isSameShape(expQ)); - - //ASSERT_TRUE(expQ.equalsTo(q)); - ASSERT_TRUE(exp.equalsTo(in)); - - + auto in = NDArrayFactory::create( + 'c', {5, 3}, + {12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.}); + auto expQ = NDArrayFactory::create( + 'c', {5, 5}, + {0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, + -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, + 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, + 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, + -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485}); + + auto expR = NDArrayFactory::create( + 'c', {5, 3}, + {-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., + 35.201546, 0., 0., 0., 0., 0., 0.}); + sd::ops::qr op; + auto res = op.evaluate({&in}, {}, {}, {true}); + + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto q = res.at(0); + auto r = res.at(1); + // q->printIndexedBuffer("Orthogonal 5x5"); + // expQ.printBuffer("Orthogonal Exp"); + // r->printIndexedBuffer("Upper triangular 5x3"); + // expR.printBuffer("Upper triangular Exp"); + // q->printShapeInfo("Q shape"); + // r->printShapeInfo("R shape"); + sd::ops::matmul opMul; + auto res2 = + opMul.evaluate({&q, &r}); // MmulHelper::matmul(q, r, &in, false, false); + auto exp = res2.at(0); //->printIndexedBuffer("Result as result"); + ASSERT_TRUE(exp.isSameShape(in)); + // ASSERT_TRUE(q->isSameShape(expQ)); + + // ASSERT_TRUE(expQ.equalsTo(q)); + ASSERT_TRUE(exp.equalsTo(in)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, QR_Test_1_1) { - - auto in = NDArrayFactory::create('c', {4, 5, 3}, { - 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., - 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., - 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., - 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3. - }); - auto expQ = NDArrayFactory::create('c', {4, 5, 5}, { - 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, - 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, - 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, - 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485 - }); - - auto expR = NDArrayFactory::create('c', {4, 5,3}, { - -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0., - -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0., - -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0., - -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0. - }); - sd::ops::qr op; - auto res = op.evaluate({&in}, {}, {}, {true}); - - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto q = res.at(0); - auto r = res.at(1); -// q->printIndexedBuffer("Orthogonal 5x5"); -// expQ.printBuffer("Orthogonal Exp"); -// r->printIndexedBuffer("Upper triangular 5x3"); -// expR.printBuffer("Upper triangular Exp"); -// q->printShapeInfo("Q shape"); -// r->printShapeInfo("R shape"); - sd::ops::matmul opMul; - auto res2 = opMul.evaluate({&q, &r}); //MmulHelper::matmul(q, r, &in, false, false); - auto exp = res2.at(0);//->printIndexedBuffer("Result as result"); - ASSERT_TRUE(exp.isSameShape(in)); -// ASSERT_TRUE(q->isSameShape(expQ)); - - //ASSERT_TRUE(expQ.equalsTo(q)); - ASSERT_TRUE(exp.equalsTo(in)); - - + auto in = NDArrayFactory::create( + 'c', {4, 5, 3}, + {12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.}); + auto expQ = NDArrayFactory::create( + 'c', {4, 5, 5}, + {0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, + -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, + 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, + 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, + -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, + -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, + 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, + 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, + -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, + -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, + 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, + 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, + -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, + -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, + 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, + 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, + -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485}); + + auto expR = NDArrayFactory::create( + 'c', {4, 5, 3}, + {-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, + 0., 0., 35.201546, 0., 0., 0., + 0., 0., 0., -14.177447, -20.666622, 13.401566, + 0., -175.04254, 70.080315, 0., 0., 35.201546, + 0., 0., 0., 0., 0., 0., + -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, + 0., 0., 35.201546, 0., 0., 0., + 0., 0., 0., -14.177447, -20.666622, 13.401566, + 0., -175.04254, 70.080315, 0., 0., 35.201546, + 0., 0., 0., 0., 0., 0.}); + sd::ops::qr op; + auto res = op.evaluate({&in}, {}, {}, {true}); + + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto q = res.at(0); + auto r = res.at(1); + // q->printIndexedBuffer("Orthogonal 5x5"); + // expQ.printBuffer("Orthogonal Exp"); + // r->printIndexedBuffer("Upper triangular 5x3"); + // expR.printBuffer("Upper triangular Exp"); + // q->printShapeInfo("Q shape"); + // r->printShapeInfo("R shape"); + sd::ops::matmul opMul; + auto res2 = + opMul.evaluate({&q, &r}); // MmulHelper::matmul(q, r, &in, false, false); + auto exp = res2.at(0); //->printIndexedBuffer("Result as result"); + ASSERT_TRUE(exp.isSameShape(in)); + // ASSERT_TRUE(q->isSameShape(expQ)); + + // ASSERT_TRUE(expQ.equalsTo(q)); + ASSERT_TRUE(exp.equalsTo(in)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, QR_Test_2) { - - auto in = NDArrayFactory::create('c', {5,3}, {12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.}); - auto expQ = NDArrayFactory::create('c', {5, 3}, {0.8464148,0.3912908,-0.3431241,-0.42320737, -0.9040873,0.02927014,0.28213826, -0.17042054, -0.93285596,0.07053456, -0.01404065,0.00109937,-0.14106913,0.0166551,0.10577161}); - auto expR = NDArrayFactory::create('c', {3,3}, {-14.177447,-20.666622,13.401566,0.,-175.04254,70.080315,0.,0.,35.201546}); - - sd::ops::qr op; - auto res = op.evaluate({&in}, {}, {}, {false}); - - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto q = res.at(0); - auto r = res.at(1); - ASSERT_TRUE(q.isSameShape(expQ)); - ASSERT_TRUE(r.isSameShape(expR)); - - sd::ops::matmul opMul; - auto res2 = opMul.evaluate({&q, &r}); //MmulHelper::matmul(q, r, &in, false, false); - auto exp = res2.at(0);//->printIndexedBuffer("Result as result"); - ASSERT_TRUE(exp.isSameShape(in)); - ASSERT_TRUE(exp.equalsTo(in)); - + auto in = NDArrayFactory::create( + 'c', {5, 3}, + {12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.}); + auto expQ = NDArrayFactory::create( + 'c', {5, 3}, + {0.8464148, 0.3912908, -0.3431241, -0.42320737, -0.9040873, 0.02927014, + 0.28213826, -0.17042054, -0.93285596, 0.07053456, -0.01404065, + 0.00109937, -0.14106913, 0.0166551, 0.10577161}); + auto expR = NDArrayFactory::create( + 'c', {3, 3}, + {-14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., + 35.201546}); + + sd::ops::qr op; + auto res = op.evaluate({&in}, {}, {}, {false}); + + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto q = res.at(0); + auto r = res.at(1); + ASSERT_TRUE(q.isSameShape(expQ)); + ASSERT_TRUE(r.isSameShape(expR)); + + sd::ops::matmul opMul; + auto res2 = + opMul.evaluate({&q, &r}); // MmulHelper::matmul(q, r, &in, false, false); + auto exp = res2.at(0); //->printIndexedBuffer("Result as result"); + ASSERT_TRUE(exp.isSameShape(in)); + ASSERT_TRUE(exp.equalsTo(in)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TriangularSolve_Test_1) { + auto a = + NDArrayFactory::create('c', {4, 4}, + {3.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f}); - auto a = NDArrayFactory::create('c', {4, 4}, { - 3.f, 0.f, 0.f, 0.f, - 2.f, 1.f, 0.f, 0.f, - 1.f, 0.f, 1.f, 0.f, - 1.f, 1.f, 1.f, 1.f - }); + auto b = NDArrayFactory::create('c', {4, 1}, {4.f, 2.f, 4.f, 2.f}); - auto b = NDArrayFactory::create('c', {4, 1}, { - 4.f, 2.f, 4.f, 2.f - }); + auto exp = NDArrayFactory::create( + 'c', {4, 1}, {1.333333f, -0.6666667f, 2.6666667f, -1.3333333f}); - auto exp = NDArrayFactory::create('c', {4, 1}, { - 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f }); + sd::ops::triangular_solve op; - sd::ops::triangular_solve op; + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + // z->printIndexedBuffer("TriangularSolve"); -// z->printIndexedBuffer("TriangularSolve"); - - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TriangularSolve_Test_2) { - - auto a = NDArrayFactory::create('c', {4, 4}, { - 1.f, 1.f, 1.f, 1.f, - 0.f, 1.f, 1.f, 0.f, - 0.f, 0.f, 2.f, 1.f, - 0.f, 0.f, 0.f, 3.f, - }); - - auto b = NDArrayFactory::create('c', {4, 1}, { - 2.f, 4.f, 2.f, 4.f - }); - - auto exp = NDArrayFactory::create('c', {4, 1}, { - 2.f, 4.f, 1.f, 1.3333333f }); - - sd::ops::triangular_solve op; - - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printIndexedBuffer("TriangularSolve"); - - ASSERT_TRUE(exp.equalsTo(z)); - + auto a = NDArrayFactory::create('c', {4, 4}, + { + 1.f, + 1.f, + 1.f, + 1.f, + 0.f, + 1.f, + 1.f, + 0.f, + 0.f, + 0.f, + 2.f, + 1.f, + 0.f, + 0.f, + 0.f, + 3.f, + }); + + auto b = NDArrayFactory::create('c', {4, 1}, {2.f, 4.f, 2.f, 4.f}); + + auto exp = + NDArrayFactory::create('c', {4, 1}, {2.f, 4.f, 1.f, 1.3333333f}); + + sd::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + + // z->printIndexedBuffer("TriangularSolve"); + + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TriangularSolve_Test_3) { + auto a = NDArrayFactory::create( + 'c', {2, 4, 4}, {3.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, - auto a = NDArrayFactory::create('c', {2, 4, 4}, { - 3.f, 0.f, 0.f, 0.f, - 2.f, 1.f, 0.f, 0.f, - 1.f, 0.f, 1.f, 0.f, - 1.f, 1.f, 1.f, 1.f, + 3.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f}); - 3.f, 0.f, 0.f, 0.f, - 2.f, 1.f, 0.f, 0.f, - 1.f, 0.f, 1.f, 0.f, - 1.f, 1.f, 1.f, 1.f - }); + auto b = NDArrayFactory::create( + 'c', {2, 4, 1}, {4.f, 2.f, 4.f, 2.f, 4.f, 2.f, 4.f, 2.f}); - auto b = NDArrayFactory::create('c', {2, 4, 1}, { - 4.f, 2.f, 4.f, 2.f, - 4.f, 2.f, 4.f, 2.f - }); + auto exp = NDArrayFactory::create( + 'c', {2, 4, 1}, + {1.333333f, -0.6666667f, 2.6666667f, -1.3333333f, 1.333333f, -0.6666667f, + 2.6666667f, -1.3333333f}); - auto exp = NDArrayFactory::create('c', {2, 4, 1}, { - 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f, - 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f - }); + sd::ops::triangular_solve op; - sd::ops::triangular_solve op; + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + // z->printIndexedBuffer("TriangularSolve"); -// z->printIndexedBuffer("TriangularSolve"); - - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TriangularSolve_Test_4) { - - auto a = NDArrayFactory::create('c', {4, 4}, { - 1.f, 1.f, 1.f, 1.f, - 0.f, 1.f, 1.f, 0.f, - 0.f, 0.f, 2.f, 1.f, - 0.f, 0.f, 0.f, 3.f, - }); - - auto b = NDArrayFactory::create('c', {4, 1}, { - 2.f, 4.f, 2.f, 4.f - }); - - auto exp = NDArrayFactory::create('c', {4, 1}, { - -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f - }); - - sd::ops::triangular_solve op; - - auto res = op.evaluate({&a, &b}, {false}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - -// z->printIndexedBuffer("TriangularSolve"); - - ASSERT_TRUE(exp.equalsTo(z)); - + auto a = NDArrayFactory::create('c', {4, 4}, + { + 1.f, + 1.f, + 1.f, + 1.f, + 0.f, + 1.f, + 1.f, + 0.f, + 0.f, + 0.f, + 2.f, + 1.f, + 0.f, + 0.f, + 0.f, + 3.f, + }); + + auto b = NDArrayFactory::create('c', {4, 1}, {2.f, 4.f, 2.f, 4.f}); + + auto exp = NDArrayFactory::create( + 'c', {4, 1}, {-3.3333333f, 3.6666666f, 0.333333f, 1.3333333f}); + + sd::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + + // z->printIndexedBuffer("TriangularSolve"); + + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) { + auto a = + NDArrayFactory::create('c', {4, 4}, + {5.f, 1., -3.f, 3.f, 0.f, 1.f, 1.f, -1.f, + 0.f, 0.f, 2.f, -9.f, 0.f, 0.f, 0.f, 4.f}); - auto a = NDArrayFactory::create('c', {4, 4}, { - 5.f, 1., -3.f, 3.f, - 0.f, 1.f, 1.f, -1.f, - 0.f, 0.f, 2.f, -9.f, - 0.f, 0.f, 0.f, 4.f - }); - - auto b = NDArrayFactory::create('c', {4, 1}, { - 5.f, 2.f, 0.f, -3.f - }); + auto b = NDArrayFactory::create('c', {4, 1}, {5.f, 2.f, 0.f, -3.f}); - auto exp = NDArrayFactory::create('c', {4, 1}, { - 1.f, 1.f, 1.f, 1.f - }); + auto exp = NDArrayFactory::create('c', {4, 1}, {1.f, 1.f, 1.f, 1.f}); - sd::ops::triangular_solve op; + sd::ops::triangular_solve op; - auto res = op.evaluate({&a, &b}, {false, true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}, {false, true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// z->printIndexedBuffer("TriangularSolve with adjoint"); + // z->printIndexedBuffer("TriangularSolve with adjoint"); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, SolveLs_Test_1) { + auto a = + NDArrayFactory::create('c', {4, 4}, + {3.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f}); - auto a = NDArrayFactory::create('c', {4, 4}, { - 3.f, 0.f, 0.f, 0.f, - 2.f, 1.f, 0.f, 0.f, - 1.f, 0.f, 1.f, 0.f, - 1.f, 1.f, 1.f, 1.f - }); - - auto b = NDArrayFactory::create('c', {4, 1}, { - 4.f, 2.f, 4.f, 2.f - }); + auto b = NDArrayFactory::create('c', {4, 1}, {4.f, 2.f, 4.f, 2.f}); - auto exp = NDArrayFactory::create('c', {4, 1}, { - 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f }); + auto exp = NDArrayFactory::create( + 'c', {4, 1}, {1.333333f, -0.6666667f, 2.6666667f, -1.3333333f}); - sd::ops::lstsq op; + sd::ops::lstsq op; - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// z->printIndexedBuffer("MatrixSolveLS"); - MmulHelper::matmul(&a, &z, &exp, false, false); + // z->printIndexedBuffer("MatrixSolveLS"); + MmulHelper::matmul(&a, &z, &exp, false, false); - ASSERT_TRUE(exp.equalsTo(b)); - + ASSERT_TRUE(exp.equalsTo(b)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, SolveLs_Test_2) { + auto a = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 11.f, 8.f, 21.f}); - auto a = NDArrayFactory::create('c', {3, 3}, { - 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 11.f, 8.f, 21.f - }); + auto b = NDArrayFactory::create('c', {3, 1}, {1.f, 2.f, 3.f}); - auto b = NDArrayFactory::create('c', {3, 1}, { 1.f, 2.f, 3.f }); + auto exp = NDArrayFactory::create( + 'c', {3, 1}, {-0.24999914f, 0.4999994f, 0.08333314f}); - auto exp = NDArrayFactory::create('c', {3, 1}, { -0.24999914f, 0.4999994f, 0.08333314f }); + sd::ops::lstsq op; - sd::ops::lstsq op; + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + MmulHelper::matmul(&a, &z, &exp, false, false); - MmulHelper::matmul(&a, &z, &exp, false, false); + // z->printIndexedBuffer("MatrixSolveLS2"); -// z->printIndexedBuffer("MatrixSolveLS2"); - - ASSERT_TRUE(exp.equalsTo(b)); - + ASSERT_TRUE(exp.equalsTo(b)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, SolveLs_Test_3) { + auto a = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 1.f, 0.f, 0.f, -1.f, 1.f, 0.f, 0.f, 1.f, 1.f, -1.f, -1.f}); - auto a = NDArrayFactory::create('c', {3, 4}, { - 1.f,1.f,0.f,0.f,-1.f,1.f,0.f,0.f,1.f,1.f,-1.f,-1.f - }); - - auto b = NDArrayFactory::create('c', {3, 1}, { 1.f, 2.f, 3.f }); + auto b = NDArrayFactory::create('c', {3, 1}, {1.f, 2.f, 3.f}); - auto exp = NDArrayFactory::create('c', {3, 1}, { -0.5f, 1.5f, -2.f }); + auto exp = NDArrayFactory::create('c', {3, 1}, {-0.5f, 1.5f, -2.f}); - sd::ops::lstsq op; + sd::ops::lstsq op; - auto res = op.evaluate({&a, &b}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); -// z->printIndexedBuffer("MatrixSolveLS3"); - MmulHelper::matmul(&a, &z, &exp, false, false); - ASSERT_TRUE(exp.equalsTo(b)); - + // z->printIndexedBuffer("MatrixSolveLS3"); + MmulHelper::matmul(&a, &z, &exp, false, false); + ASSERT_TRUE(exp.equalsTo(b)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, SolveLs_Test_4) { + auto a = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 1.f, 0.f, 0.f, -1.f, 1.f, 0.f, 0.f, 1.f, 1.f, -1.f, -1.f}); - auto a = NDArrayFactory::create('c', {3, 4}, { - 1.f,1.f,0.f,0.f,-1.f,1.f,0.f,0.f,1.f,1.f,-1.f,-1.f - }); + auto b = NDArrayFactory::create('c', {3, 1}, {1.f, 2.f, 3.f}); - auto b = NDArrayFactory::create('c', {3, 1}, { 1.f, 2.f, 3.f }); + auto exp = + NDArrayFactory::create('c', {4, 1}, {-0.5f, 1.5f, -2.f, 0.f}); - auto exp = NDArrayFactory::create('c', {4, 1}, { -0.5f, 1.5f, -2.f, 0.f}); + sd::ops::lstsq op; - sd::ops::lstsq op; + auto res = op.evaluate({&a, &b}, {false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + // z->printIndexedBuffer("Output_12.4"); + // z->printShapeInfo("Output_12.4 shape"); + // MmulHelper::matmul(&a, z, &exp, false, false); - auto res = op.evaluate({&a, &b}, {false}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); -// z->printIndexedBuffer("Output_12.4"); -// z->printShapeInfo("Output_12.4 shape"); -// MmulHelper::matmul(&a, z, &exp, false, false); + // z->printIndexedBuffer("MatrixSolveLS4"); -// z->printIndexedBuffer("MatrixSolveLS4"); - - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, SolveLs_Test_5) { + auto a = NDArrayFactory::create('c', {1, 0, 3, 4}); + auto b = NDArrayFactory::create('c', {1, 0, 3, 1}); - auto a = NDArrayFactory::create('c', {1, 0, 3, 4}); - auto b = NDArrayFactory::create('c', {1, 0, 3, 1}); - - sd::ops::lstsq op; - - auto res = op.evaluate({&a, &b}, {false}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - ASSERT_TRUE(z.isEmpty()); + sd::ops::lstsq op; - + auto res = op.evaluate({&a, &b}, {false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + ASSERT_TRUE(z.isEmpty()); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, Solve_Test_6) { + auto a = NDArrayFactory::create('c', {1, 0, 3, 3}); + auto b = NDArrayFactory::create('c', {1, 0, 3, 1}); - auto a = NDArrayFactory::create('c', {1, 0, 3, 3}); - auto b = NDArrayFactory::create('c', {1, 0, 3, 1}); + sd::ops::solve op; - sd::ops::solve op; - - auto res = op.evaluate({&a, &b}, {true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); - ASSERT_TRUE(z.isEmpty()); + auto res = op.evaluate({&a, &b}, {true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + ASSERT_TRUE(z.isEmpty()); } ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, TriangularSolve_Test_6) { + auto a = + NDArrayFactory::create('c', {4, 4}, + {5.f, 1.f, -3.f, 3.f, 0.f, 1.f, 1.f, -1.f, + 0.f, 0.f, 2.f, -9.f, 0.f, 0.f, 0.f, 4.f}); - auto a = NDArrayFactory::create('c', {4, 4}, { - 5.f, 1.f, -3.f, 3.f, - 0.f, 1.f, 1.f, -1.f, - 0.f, 0.f, 2.f, -9.f, - 0.f, 0.f, 0.f, 4.f - }); - - auto b = NDArrayFactory::create('c', {4, 2}, { - 5.f, 1.f, 2.f, 1.f, 0.f, 1.f, -3.f, 1.f - }); + auto b = NDArrayFactory::create( + 'c', {4, 2}, {5.f, 1.f, 2.f, 1.f, 0.f, 1.f, -3.f, 1.f}); - auto exp = NDArrayFactory::create('c', {4, 2}, { - 1.f,0.2f, 1.f,0.8f, 1.f,0.4f, 1.f,1.2f - }); + auto exp = NDArrayFactory::create( + 'c', {4, 2}, {1.f, 0.2f, 1.f, 0.8f, 1.f, 0.4f, 1.f, 1.2f}); - sd::ops::triangular_solve op; + sd::ops::triangular_solve op; - auto res = op.evaluate({&a, &b}, {}, {}, {false, true}); - ASSERT_EQ(res.status(), ND4J_STATUS_OK); - auto z = res.at(0); + auto res = op.evaluate({&a, &b}, {}, {}, {false, true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); - //z.printIndexedBuffer("TriangularSolve with adjoint"); + // z.printIndexedBuffer("TriangularSolve with adjoint"); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 77851cb5f1d5..a9ae77875aef 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -14,3336 +14,3948 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // Created by raver on 8/4/2018. // -#include "testlayers.h" -#include #include -#include #include -#include #include +#include +#include -using namespace sd; +#include +#include "testlayers.h" -class DeclarableOpsTests13 : public testing::Test { -public: +using namespace sd; - DeclarableOpsTests13() { - //printf("\n"); - //fflush(stdout); - } +class DeclarableOpsTests13 : public testing::Test { + public: + DeclarableOpsTests13() { + // printf("\n"); + // fflush(stdout); + } }; template class TypedDeclarableOpsTests13 : public testing::Test { -public: - - TypedDeclarableOpsTests13() { - printf("\n"); - fflush(stdout); - } + public: + TypedDeclarableOpsTests13() { + printf("\n"); + fflush(stdout); + } }; typedef ::testing::Types TestingTypes; TYPED_TEST_CASE(TypedDeclarableOpsTests13, TestingTypes); TEST_F(DeclarableOpsTests13, test_pow_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {2.f, 2.f, 2.f, 2.f}); - auto y = NDArrayFactory::create('c', {2}, {3, 3}); - auto e = NDArrayFactory::create('c', {2, 2}, {8.f, 8.f, 8.f, 8.f}); + auto x = NDArrayFactory::create('c', {2, 2}, {2.f, 2.f, 2.f, 2.f}); + auto y = NDArrayFactory::create('c', {2}, {3, 3}); + auto e = NDArrayFactory::create('c', {2, 2}, {8.f, 8.f, 8.f, 8.f}); - sd::ops::Pow op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::Pow op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, test_empty_range_1) { - auto start = NDArrayFactory::create(0); - auto limit = NDArrayFactory::create(0); - - sd::ops::range op; - auto result = op.evaluate({&start, &limit}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(z.isEmpty()); + auto start = NDArrayFactory::create(0); + auto limit = NDArrayFactory::create(0); + sd::ops::range op; + auto result = op.evaluate({&start, &limit}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(z.isEmpty()); } TEST_F(DeclarableOpsTests13, test_empty_range_2) { + sd::ops::range op; + auto result = op.evaluate({}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::range op; - auto result = op.evaluate({}, {1.0, 1.0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(z.isEmpty()); + auto z = result.at(0); + ASSERT_TRUE(z.isEmpty()); } TEST_F(DeclarableOpsTests13, test_empty_range_3) { + sd::ops::range op; + auto result = op.evaluate({}, {1, 1}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::range op; - auto result = op.evaluate({}, {1, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(z.isEmpty()); + auto z = result.at(0); + ASSERT_TRUE(z.isEmpty()); } TEST_F(DeclarableOpsTests13, test_argmax_edge_1) { - auto ctx = Context(1); - auto arr = NDArrayFactory::create('c', {1024,1}); - - ctx.setInputArray(0, arr); - ctx.setOutputArray(0, NDArrayFactory::create('c', {1})); - ctx.setInputArray(1, NDArrayFactory::create(0)); //Axis 0 + auto ctx = Context(1); + auto arr = NDArrayFactory::create('c', {1024, 1}); + ctx.setInputArray(0, arr); + ctx.setOutputArray(0, NDArrayFactory::create('c', {1})); + ctx.setInputArray(1, NDArrayFactory::create(0)); // Axis 0 - sd::ops::argmax op; - auto result = op.execute(&ctx); - ASSERT_EQ(Status::OK(), result); + sd::ops::argmax op; + auto result = op.execute(&ctx); + ASSERT_EQ(Status::OK(), result); } TEST_F(DeclarableOpsTests13, test_add_1) { - auto x = NDArrayFactory::create('c', {1, 768}); - auto y = NDArrayFactory::create('c', {768}); - auto e = NDArrayFactory::create('c', {1, 768});; - y. assign(1.0f); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {1, 768}); + auto y = NDArrayFactory::create('c', {768}); + auto e = NDArrayFactory::create('c', {1, 768}); + ; + y.assign(1.0f); + e.assign(1.0f); - x += y; + x += y; - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(DeclarableOpsTests13, test_listdiff_1) { - auto x = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); - auto y = NDArrayFactory::create('c', {2}, {3, 1}); + auto x = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); + auto y = NDArrayFactory::create('c', {2}, {3, 1}); - auto od = NDArrayFactory::create('c', {2}); - auto oi = NDArrayFactory::create('c', {2}); + auto od = NDArrayFactory::create('c', {2}); + auto oi = NDArrayFactory::create('c', {2}); - sd::ops::listdiff op; - auto result = op.execute({&x, &y}, std::vector{&od, &oi}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); + sd::ops::listdiff op; + auto result = + op.execute({&x, &y}, std::vector{&od, &oi}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); } TEST_F(DeclarableOpsTests13, test_greater_1) { - auto x = NDArrayFactory::create('c', {3, 1}); - auto y = NDArrayFactory::create('c', {1, 4}); + auto x = NDArrayFactory::create('c', {3, 1}); + auto y = NDArrayFactory::create('c', {1, 4}); - sd::ops::greater op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::greater op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(DeclarableOpsTests13, test_eval_reduction_shape_1) { - Nd4jLong axis = 0L; - auto x = NDArrayFactory::create('c', {2}, {4, 2}); - auto y = NDArrayFactory::create('c', {1}, {axis}); - auto exp = NDArrayFactory::create('c', {2}, {1, 2}); + Nd4jLong axis = 0L; + auto x = NDArrayFactory::create('c', {2}, {4, 2}); + auto y = NDArrayFactory::create('c', {1}, {axis}); + auto exp = NDArrayFactory::create('c', {2}, {1, 2}); - sd::ops::evaluate_reduction_shape op; - auto result = op.evaluate({&x, &y}, {true}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::evaluate_reduction_shape op; + auto result = op.evaluate({&x, &y}, {true}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(DeclarableOpsTests13, test_or_1) { + NDArray x('c', {4}, {false, true, false, true}, sd::DataType::BOOL); + NDArray y('c', {4}, {false, false, true, true}, sd::DataType::BOOL); + NDArray e('c', {4}, {false, true, true, true}, sd::DataType::BOOL); - NDArray x('c', {4}, {false, true, false, true}, sd::DataType::BOOL); - NDArray y('c', {4}, {false, false, true, true}, sd::DataType::BOOL); - NDArray e('c', {4}, {false, true, true, true}, sd::DataType::BOOL); - - NDArray z('c', {4}, sd::DataType::BOOL); + NDArray z('c', {4}, sd::DataType::BOOL); - x.applyPairwiseTransform(pairwise::Or, y, z); + x.applyPairwiseTransform(pairwise::Or, y, z); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, test_and_1) { - auto x = NDArrayFactory::create('c', {4}, {false, true, false, true}); - auto y = NDArrayFactory::create('c', {4}, {false, false, true, true}); - auto e = NDArrayFactory::create('c', {4}, {false, false, false, true}); + auto x = NDArrayFactory::create('c', {4}, {false, true, false, true}); + auto y = NDArrayFactory::create('c', {4}, {false, false, true, true}); + auto e = NDArrayFactory::create('c', {4}, {false, false, false, true}); - auto z = NDArrayFactory::create('c', {4}); + auto z = NDArrayFactory::create('c', {4}); - x.applyPairwiseTransform(pairwise::And, y, z); + x.applyPairwiseTransform(pairwise::And, y, z); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, test_xor_1) { - auto x = NDArrayFactory::create('c', {4}, {false, true, false, true}); - auto y = NDArrayFactory::create('c', {4}, {false, false, true, true}); - auto e = NDArrayFactory::create('c', {4}, {false, true, true, false}); + auto x = NDArrayFactory::create('c', {4}, {false, true, false, true}); + auto y = NDArrayFactory::create('c', {4}, {false, false, true, true}); + auto e = NDArrayFactory::create('c', {4}, {false, true, true, false}); - auto z = NDArrayFactory::create('c', {4}); + auto z = NDArrayFactory::create('c', {4}); - x.applyPairwiseTransform(pairwise::Xor, y, z); + x.applyPairwiseTransform(pairwise::Xor, y, z); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_1) { - auto x = NDArrayFactory::create('c', {2,3}, {1,2,3, 4, 5, 6}); - auto y = NDArrayFactory::create('c', {2,3}, {1,-2,3, -4, 5, -6}); - auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); - auto exp = NDArrayFactory::create('c', {2,3}, {1.2,2.2,3.2,4.2,5.2,6.2}); - sd::ops::barnes_gains op; - auto result = op.evaluate({&x, &y, &eps}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printBuffer("Gains out"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto y = NDArrayFactory::create('c', {2, 3}, {1, -2, 3, -4, 5, -6}); + auto eps = NDArrayFactory::create('c', {2, 3}, + {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); + auto exp = NDArrayFactory::create('c', {2, 3}, + {1.2, 2.2, 3.2, 4.2, 5.2, 6.2}); + sd::ops::barnes_gains op; + auto result = op.evaluate({&x, &y, &eps}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("Gains out"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_2) { - auto x = NDArrayFactory::create('c', {2,3}, {1, -2, 3, -4, 5, -6}); - auto y = NDArrayFactory::create('c', {2,3}, {1, -2, 3, -4, 5, -6}); - auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); - auto exp = NDArrayFactory::create('c', {2,3}, {1.2, 0.01, 3.2, 0.01, 5.2, 0.01}); - sd::ops::barnes_gains op; - auto result = op.evaluate({&x, &y, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printBuffer("Gains out"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto x = NDArrayFactory::create('c', {2, 3}, {1, -2, 3, -4, 5, -6}); + auto y = NDArrayFactory::create('c', {2, 3}, {1, -2, 3, -4, 5, -6}); + auto eps = NDArrayFactory::create('c', {2, 3}, + {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); + auto exp = NDArrayFactory::create('c', {2, 3}, + {1.2, 0.01, 3.2, 0.01, 5.2, 0.01}); + sd::ops::barnes_gains op; + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("Gains out"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); - //ASSERT_EQ(e, z); + // ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_3) { - auto x = NDArrayFactory::create('c', {2,3}, {-1, 2, -3, 4, -5, 6}); - auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); - auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); - auto exp = NDArrayFactory::create('c', {2,3}, {0.01, 2.2, 0.01, 4.2, 0.01, 6.2}); - sd::ops::barnes_gains op; - auto result = op.evaluate({&x, &y, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printBuffer("Gains out"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - + auto x = NDArrayFactory::create('c', {2, 3}, {-1, 2, -3, 4, -5, 6}); + auto y = + NDArrayFactory::create('c', {2, 3}, {-0.1, -2, 3, -4, -0.5, -6}); + auto eps = NDArrayFactory::create('c', {2, 3}, + {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); + auto exp = NDArrayFactory::create('c', {2, 3}, + {0.01, 2.2, 0.01, 4.2, 0.01, 6.2}); + sd::ops::barnes_gains op; + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("Gains out"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_1) { - auto data = NDArrayFactory::create('c', {5,4}); - auto rows = NDArrayFactory::create('c', {2}, {2, 3}); - auto cols = NDArrayFactory::create('c', {5}, {0, 2, 1, 4, 3}); - auto vals = NDArrayFactory::create('c', {5}, {10., 20., 30., 40., 50.}); - //auto buf = NDArrayFactory::create('c', {4}); - auto exp1 = NDArrayFactory::create('c', {5,4}, {-1.846154, -1.846154, -1.846154, -1.846154, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); - //auto exp2 = NDArrayFactory::create({-4., -4., -4., -4. - //std::vector exp({&exp1, &exp2}); - data.linspace(1); - -// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); -// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); -// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::barnes_edge_forces op; - auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {1}); - - - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printBuffer("Output"); - ASSERT_TRUE(exp1.equalsTo(result.at(0))); - + auto data = NDArrayFactory::create('c', {5, 4}); + auto rows = NDArrayFactory::create('c', {2}, {2, 3}); + auto cols = NDArrayFactory::create('c', {5}, {0, 2, 1, 4, 3}); + auto vals = + NDArrayFactory::create('c', {5}, {10., 20., 30., 40., 50.}); + // auto buf = NDArrayFactory::create('c', {4}); + auto exp1 = NDArrayFactory::create( + 'c', {5, 4}, + {-1.846154, -1.846154, -1.846154, -1.846154, 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); + // auto exp2 = NDArrayFactory::create({-4., -4., -4., -4. + // std::vector exp({&exp1, &exp2}); + data.linspace(1); + + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_edge_forces op; + auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {1}); + + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("Output"); + ASSERT_TRUE(exp1.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_2) { - auto data = NDArrayFactory::create('c', {5,4}); - auto rows = NDArrayFactory::create('c', {3}, {1,2,3}); - auto cols = NDArrayFactory::create('c', {5}, {1, 2, 0, 4, 3}); - auto vals = NDArrayFactory::create('c', {5}, {10., 20., 30., 40., 50.}); - //auto buf = NDArrayFactory::create('c', {4}); - auto exp = NDArrayFactory::create('c', {5,4}, {-0.622568, -0.622568, -0.622568, -0.622568, 1.846154, 1.846154, 1.846154, 1.846154, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); - //auto exp2 = NDArrayFactory::create({-4., -4., -4., -4. - //std::vector exp({&exp1, &exp2}); - data.linspace(1); - -// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); -// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); -// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::barnes_edge_forces op; - auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {2}); - - - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printBuffer("Output"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - + auto data = NDArrayFactory::create('c', {5, 4}); + auto rows = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto cols = NDArrayFactory::create('c', {5}, {1, 2, 0, 4, 3}); + auto vals = + NDArrayFactory::create('c', {5}, {10., 20., 30., 40., 50.}); + // auto buf = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create( + 'c', {5, 4}, + {-0.622568, -0.622568, -0.622568, -0.622568, 1.846154, 1.846154, 1.846154, + 1.846154, 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + // auto exp2 = NDArrayFactory::create({-4., -4., -4., -4. + // std::vector exp({&exp1, &exp2}); + data.linspace(1); + + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_edge_forces op; + auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {2}); + + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("Output"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) { - auto data = NDArrayFactory::create('c', {11, 5}, {0.3, 0.2625, 0.2674, 0.8604, 0.4803, 0.1096, 0.795, 0.5918, 0.2738, 0.952, 0.969, 0.8586, 0.8088, 0.5338, 0.5961, 0.7187, 0.463, 0.0867, 0.7748, 0.4802, 0.2493, 0.3227, 0.3064, 0.698, 0.7977, 0.7674, 0.168, 0.3107, 0.0217, 0.138, 0.8619, 0.8413, 0.5285, 0.9703, 0.6774, 0.2624, 0.4374, 0.1569, 0.1107, 0.0601, 0.4094, 0.9564, 0.5994, 0.8279, 0.3859, 0.6202, 0.7604, 0.0788, 0.0865, 0.7445, 0.6548, 0.3385, 0.0582, 0.6249, 0.7432}); - auto rows = NDArrayFactory::create({0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99}); - auto cols = NDArrayFactory::create({4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1}); - auto vals = NDArrayFactory::create({0.6199614579042966, 0.19644097697184246, 0.13824979367331638, 0.01949900138247239, 0.008923198738222747, 0.008392793826291798, 0.0033348224714784204, 0.0026246189757042166, 0.0025733360563748838, 0.5877136110798608, 0.28250257562439585, 0.08098135424273815, 0.014862718272075049, 0.01219187321450782, 0.01152346362368888, 0.004243137936786281, 0.0034626999030188577, 0.0025185661029283168, 0.6777005651521399, 0.18321248222489303, 0.04018202465629351, 0.02941935889988646, 0.02164146250842832, 0.019898422145651618, 0.011683461395713935, 0.008439076090480863, 0.007823146926512332, 0.6770900431883232, 0.16617511239723026, 0.06039349887686468, 0.04650913399744179, 0.016886531410284355, 0.014591049666869658, 0.006407638669806174, 0.006074413005122801, 0.0058725787880570205, 0.6278185083409108, 0.235127797795446, 0.07023700015217448, 0.030885483448633774, 0.01229522088606573, 0.009238279699136107, 0.008219511168822047, 0.004303744819835723, 0.0018744536889749907, 0.7122603898978483, 0.07862620103245824, 0.07061257369349086, 0.06721483653169834, 0.028957853952131768, 0.01778978123182596, 0.01481713955181034, 0.005492728917348627, 0.0042284951913875955, 0.5266844101016999, 0.3304104787383107, 0.10930017433210941, 0.018514917515240075, 0.006969360999637938, 0.0063776901975396, 0.0010590388116165708, 6.526830884629785E-4, 3.1246215383067865E-5, 0.7176179284835663, 0.08741734015883978, 0.05927699083866909, 0.04663169573956976, 0.03287576269194147, 0.02993912340339554, 0.013365238657916641, 0.010616858763291145, 0.002259061262810172, 0.6891905160321706, 0.1397658294110526, 0.05438284759722162, 0.05437184733708826, 0.028683289714498808, 0.020986120697576355, 0.007218358114741088, 0.0032834770669826364, 0.002117714028667893, 0.6823873496503976, 0.1345267083671607, 0.08712863515505885, 0.04286621088946242, 0.02544804597749639, 0.01689343932533317, 0.007219134659004873, 0.0019232929717404616, 0.0016071830043453991, 0.6425809622897437, 0.18474464886441516, 0.10897036475298316, 0.03466939253836615, 0.013288054277817787, 0.005149178177380355, 0.0037974063158903518, 0.0037851733015991287, 0.0030148194818042273}); - //auto buf = NDArrayFactory::create('c', {4}); - auto exp = NDArrayFactory::create('c', {11, 5}, {-0.080205, -0.085862, 0.024045, 0.133551, -0.199896, -0.170597, 0.187301, 0.205824, -0.165268, 0.131228, 0.155135, 0.021446, 0.217583, -0.262873, -0.021075, 0.114537, 0.088023, -0.039205, 0.087984, -0.179565, -0.132683, 0.003677, 0.072081, -0.068737, 0.204481, 0.287223, -0.193989, 0.104569, -0.123401, -0.036368, 0.086745, 0.002961, -0.091327, 0.234853, 0.120270, -0.304006, 0.128305, -0.084867, -0.017550, -0.130837, -0.288569, 0.124679, 0.054078, -0.034187, -0.192599, 0.033196, 0.228182, -0.044972, -0.314217, 0.020287, 0.054427, -0.078887, -0.078246, -0.104543, 0.169803}); - //auto exp2 = NDArrayFactory::create({-4., -4., -4., -4. - //std::vector exp({&exp1, &exp2}); - //data.assign(1.0); //linspace(1); - -// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); -// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); -// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::barnes_edge_forces op; - auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {11}); - - //nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf()); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printBuffer("Output"); - //exp.printBuffer("Expect"); - //result.at(0)->printShapeInfo("Shape output"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - + auto data = NDArrayFactory::create( + 'c', {11, 5}, + {0.3, 0.2625, 0.2674, 0.8604, 0.4803, 0.1096, 0.795, 0.5918, + 0.2738, 0.952, 0.969, 0.8586, 0.8088, 0.5338, 0.5961, 0.7187, + 0.463, 0.0867, 0.7748, 0.4802, 0.2493, 0.3227, 0.3064, 0.698, + 0.7977, 0.7674, 0.168, 0.3107, 0.0217, 0.138, 0.8619, 0.8413, + 0.5285, 0.9703, 0.6774, 0.2624, 0.4374, 0.1569, 0.1107, 0.0601, + 0.4094, 0.9564, 0.5994, 0.8279, 0.3859, 0.6202, 0.7604, 0.0788, + 0.0865, 0.7445, 0.6548, 0.3385, 0.0582, 0.6249, 0.7432}); + auto rows = NDArrayFactory::create( + {0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99}); + auto cols = NDArrayFactory::create( + {4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, + 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, + 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, + 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, + 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1}); + auto vals = NDArrayFactory::create( + {0.6199614579042966, 0.19644097697184246, 0.13824979367331638, + 0.01949900138247239, 0.008923198738222747, 0.008392793826291798, + 0.0033348224714784204, 0.0026246189757042166, 0.0025733360563748838, + 0.5877136110798608, 0.28250257562439585, 0.08098135424273815, + 0.014862718272075049, 0.01219187321450782, 0.01152346362368888, + 0.004243137936786281, 0.0034626999030188577, 0.0025185661029283168, + 0.6777005651521399, 0.18321248222489303, 0.04018202465629351, + 0.02941935889988646, 0.02164146250842832, 0.019898422145651618, + 0.011683461395713935, 0.008439076090480863, 0.007823146926512332, + 0.6770900431883232, 0.16617511239723026, 0.06039349887686468, + 0.04650913399744179, 0.016886531410284355, 0.014591049666869658, + 0.006407638669806174, 0.006074413005122801, 0.0058725787880570205, + 0.6278185083409108, 0.235127797795446, 0.07023700015217448, + 0.030885483448633774, 0.01229522088606573, 0.009238279699136107, + 0.008219511168822047, 0.004303744819835723, 0.0018744536889749907, + 0.7122603898978483, 0.07862620103245824, 0.07061257369349086, + 0.06721483653169834, 0.028957853952131768, 0.01778978123182596, + 0.01481713955181034, 0.005492728917348627, 0.0042284951913875955, + 0.5266844101016999, 0.3304104787383107, 0.10930017433210941, + 0.018514917515240075, 0.006969360999637938, 0.0063776901975396, + 0.0010590388116165708, 6.526830884629785E-4, 3.1246215383067865E-5, + 0.7176179284835663, 0.08741734015883978, 0.05927699083866909, + 0.04663169573956976, 0.03287576269194147, 0.02993912340339554, + 0.013365238657916641, 0.010616858763291145, 0.002259061262810172, + 0.6891905160321706, 0.1397658294110526, 0.05438284759722162, + 0.05437184733708826, 0.028683289714498808, 0.020986120697576355, + 0.007218358114741088, 0.0032834770669826364, 0.002117714028667893, + 0.6823873496503976, 0.1345267083671607, 0.08712863515505885, + 0.04286621088946242, 0.02544804597749639, 0.01689343932533317, + 0.007219134659004873, 0.0019232929717404616, 0.0016071830043453991, + 0.6425809622897437, 0.18474464886441516, 0.10897036475298316, + 0.03466939253836615, 0.013288054277817787, 0.005149178177380355, + 0.0037974063158903518, 0.0037851733015991287, 0.0030148194818042273}); + // auto buf = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create( + 'c', {11, 5}, + {-0.080205, -0.085862, 0.024045, 0.133551, -0.199896, -0.170597, + 0.187301, 0.205824, -0.165268, 0.131228, 0.155135, 0.021446, + 0.217583, -0.262873, -0.021075, 0.114537, 0.088023, -0.039205, + 0.087984, -0.179565, -0.132683, 0.003677, 0.072081, -0.068737, + 0.204481, 0.287223, -0.193989, 0.104569, -0.123401, -0.036368, + 0.086745, 0.002961, -0.091327, 0.234853, 0.120270, -0.304006, + 0.128305, -0.084867, -0.017550, -0.130837, -0.288569, 0.124679, + 0.054078, -0.034187, -0.192599, 0.033196, 0.228182, -0.044972, + -0.314217, 0.020287, 0.054427, -0.078887, -0.078246, -0.104543, + 0.169803}); + // auto exp2 = NDArrayFactory::create({-4., -4., -4., -4. + // std::vector exp({&exp1, &exp2}); + // data.assign(1.0); //linspace(1); + + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_edge_forces op; + auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {11}); + + // nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", + // rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf()); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("Output"); + // exp.printBuffer("Expect"); + // result.at(0)->printShapeInfo("Shape output"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_1) { -// auto data = NDArrayFactory::create('c', {5,4}); - auto rows = NDArrayFactory::create('c', {2}, {0, 1}); - auto cols = NDArrayFactory::create('c', {4}, {0, 1, 1, 0}); - auto vals = NDArrayFactory::create('c', {4}, {20., 30., 40., 50.}); - auto exp = NDArrayFactory::create('c', {1,1}, {20.}); -// data.linspace(1); - -// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); -// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); -// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::barnes_symmetrized op; - auto result = op.evaluate({&rows, &cols, &vals}, {}, {1}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(2)->printBuffer("Symmetrized1"); - ASSERT_TRUE(exp.equalsTo(result.at(2))); + // auto data = NDArrayFactory::create('c', {5,4}); + auto rows = NDArrayFactory::create('c', {2}, {0, 1}); + auto cols = NDArrayFactory::create('c', {4}, {0, 1, 1, 0}); + auto vals = NDArrayFactory::create('c', {4}, {20., 30., 40., 50.}); + auto exp = NDArrayFactory::create('c', {1, 1}, {20.}); + // data.linspace(1); + + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_symmetrized op; + auto result = op.evaluate({&rows, &cols, &vals}, {}, {1}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(2)->printBuffer("Symmetrized1"); + ASSERT_TRUE(exp.equalsTo(result.at(2))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_2) { - auto rows = NDArrayFactory::create('c', {4}, {0, 2, 2, 3}); - auto cols = NDArrayFactory::create('c', {8}, {0, 1, 1, 0, 0, 1, 1, 1}); - auto vals = NDArrayFactory::create('c', {8}, {20., 30., 40., 50., 120., 130., 140., 150.}); - auto exp = NDArrayFactory::create('c', {1,5}, {20., 15., 15., 20., 20.}); -// data.linspace(1); - -// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); -// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); -// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::barnes_symmetrized op; - auto result = op.evaluate({&rows, &cols, &vals}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(2)->printBuffer("Symmetrized2"); - // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); - ASSERT_TRUE(exp.equalsTo(result.at(2))); - + auto rows = NDArrayFactory::create('c', {4}, {0, 2, 2, 3}); + auto cols = NDArrayFactory::create('c', {8}, {0, 1, 1, 0, 0, 1, 1, 1}); + auto vals = NDArrayFactory::create( + 'c', {8}, {20., 30., 40., 50., 120., 130., 140., 150.}); + auto exp = + NDArrayFactory::create('c', {1, 5}, {20., 15., 15., 20., 20.}); + // data.linspace(1); + + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_symmetrized op; + auto result = op.evaluate({&rows, &cols, &vals}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(2)->printBuffer("Symmetrized2"); + // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); + ASSERT_TRUE(exp.equalsTo(result.at(2))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) { - auto rows = NDArrayFactory::create('c', {12}, {0, 2, 3, 5, 7, 8, 9, 11, 12, 14, 18, 21}); - auto cols = NDArrayFactory::create('c', {24}, {0, 1, 2, 3, 4, 5, 4, 3, 2, 1, 0, 1, 0, 2, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5}); - auto vals = NDArrayFactory::create('c', {24}, {20., 30., 40., 50., 120., 130., 140., 150.,220., 230., 240., 250., 2120., 2130., 2140., 2150., 320., 330., 340., 350., 3120., 3130., 3140., 3150.}); - auto exp = NDArrayFactory::create('c', {1, 39}, {15.000000, 0.000000, 0.000000, 65.000000, 60.000000, 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); -// data.linspace(1); - -// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); -// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); -// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::barnes_symmetrized op; - auto result = op.evaluate({&rows, &cols, &vals}, {}, {11}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(2)->printBuffer("Symmetrized3"); - //exp.printBuffer("EXPect symm3"); - // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); - //ASSERT_TRUE(exp.equalsTo(result.at(0))); - + auto rows = NDArrayFactory::create( + 'c', {12}, {0, 2, 3, 5, 7, 8, 9, 11, 12, 14, 18, 21}); + auto cols = NDArrayFactory::create( + 'c', {24}, + {0, 1, 2, 3, 4, 5, 4, 3, 2, 1, 0, 1, 0, 2, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5}); + auto vals = NDArrayFactory::create( + 'c', {24}, {20., 30., 40., 50., 120., 130., 140., 150., + 220., 230., 240., 250., 2120., 2130., 2140., 2150., + 320., 330., 340., 350., 3120., 3130., 3140., 3150.}); + auto exp = NDArrayFactory::create( + 'c', {1, 39}, + {15.000000, 0.000000, 0.000000, 65.000000, 60.000000, 145.000000, + 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, + 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + 0.000000, 0.000000, 0.000000}); + // data.linspace(1); + + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_symmetrized op; + auto result = op.evaluate({&rows, &cols, &vals}, {}, {11}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(2)->printBuffer("Symmetrized3"); + // exp.printBuffer("EXPect symm3"); + // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); + // ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) { - auto rows = NDArrayFactory::create({0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99}); - auto cols = NDArrayFactory::create({4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1}); - auto vals = NDArrayFactory::create( {0.6200, 0.1964, 0.1382, 0.0195, 0.0089, 0.0084, 0.0033, 0.0026, 0.0026, 0.5877, 0.2825, 0.0810, 0.0149, 0.0122, 0.0115, 0.0042, 0.0035, 0.0025, 0.6777, 0.1832, 0.0402, 0.0294, 0.0216, 0.0199, 0.0117, 0.0084, 0.0078, 0.6771, 0.1662, 0.0604, 0.0465, 0.0169, 0.0146, 0.0064, 0.0061, 0.0059, 0.6278, 0.2351, 0.0702, 0.0309, 0.0123, 0.0092, 0.0082, 0.0043, 0.0019, 0.7123, 0.0786, 0.0706, 0.0672, 0.0290, 0.0178, 0.0148, 0.0055, 0.0042, 0.5267, 0.3304, 0.1093, 0.0185, 0.0070, 0.0064, 0.0011, 0.0007, 3.1246e-5, 0.7176, 0.0874, 0.0593, 0.0466, 0.0329, 0.0299, 0.0134, 0.0106, 0.0023, 0.6892, 0.1398, 0.0544, 0.0544, 0.0287, 0.0210, 0.0072, 0.0033, 0.0021, 0.6824, 0.1345, 0.0871, 0.0429, 0.0254, 0.0169, 0.0072, 0.0019, 0.0016, 0.6426, 0.1847, 0.1090, 0.0347, 0.0133, 0.0051, 0.0038, 0.0038, 0.0030}); - //auto exp = NDArrayFactory::create('c', {1, 39}, {15.000000, 0.000000, 0.000000, 65.000000, 60.000000, 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); -// data.linspace(1); - auto exp4 = NDArrayFactory::create('c', {1, 108}, {0.6239, 0.1813, 0.1236, 0.03695, 0.00795, 0.03385, 0.0074, 0.0158, 0.0013, 0.0042, 0.0074, 0.3093, 0.2085, 0.051, 0.00895, 0.01605, 0.00245, 0.00705, 0.00125, 0.0021, 0.01605, 0.6022, 0.1615, 0.0233, - 0.0183, 0.0108, 0.0068, 0.0042, 0.0113, 0.00115, 0.1813, 0.00125, 0.0233, 0.65985, 0.0653, 0.0779, 0.03565, 0.05085, 0.03835, 0.02625, 0.6239, 0.3093, 0.0068, 0.0653, 0.2099, 0.0205, 0.0173, 0.0073, - 0.0171, 0.0089, 0.0158, 0.0113, 0.03835, 0.71495, 0.04775, 0.03615, 0.0089, 0.00275, 0.0021, 1.5623E-5, 0.00795, 0.00245, 0.6022, 0.0779, 0.0073, 0.5098, 0.0159, 0.00135, 1.5623E-5, 0.03385, 0.00705, - 0.02625, 0.0171, 0.71495, 0.06515, 0.01835, 0.00775, 0.00115, 0.03695, 0.051, 0.1615, 0.03565, 0.0205, 0.00275, 0.5098, 0.00775, 0.0055, 0.0026, 0.0013, 0.2085, 0.0183, 0.05085, 0.0173, 0.04775, - 0.00135, 0.06515, 0.0026, 0.35855, 0.1236, 0.00895, 0.0108, 0.65985, 0.2099, 0.03615, 0.0159, 0.01835, 0.0055, 0.35855}); -// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); -// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); -// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::barnes_symmetrized op; - auto result = op.evaluate({&rows, &cols, &vals}, {}, {11}); - ASSERT_EQ(result.status(), Status::OK()); - auto res = result.at(2); + auto rows = NDArrayFactory::create( + {0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99}); + auto cols = NDArrayFactory::create( + {4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, + 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, + 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, + 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, + 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1}); + auto vals = NDArrayFactory::create( + {0.6200, 0.1964, 0.1382, 0.0195, 0.0089, 0.0084, 0.0033, 0.0026, + 0.0026, 0.5877, 0.2825, 0.0810, 0.0149, 0.0122, 0.0115, 0.0042, + 0.0035, 0.0025, 0.6777, 0.1832, 0.0402, 0.0294, 0.0216, 0.0199, + 0.0117, 0.0084, 0.0078, 0.6771, 0.1662, 0.0604, 0.0465, 0.0169, + 0.0146, 0.0064, 0.0061, 0.0059, 0.6278, 0.2351, 0.0702, 0.0309, + 0.0123, 0.0092, 0.0082, 0.0043, 0.0019, 0.7123, 0.0786, 0.0706, + 0.0672, 0.0290, 0.0178, 0.0148, 0.0055, 0.0042, 0.5267, 0.3304, + 0.1093, 0.0185, 0.0070, 0.0064, 0.0011, 0.0007, 3.1246e-5, 0.7176, + 0.0874, 0.0593, 0.0466, 0.0329, 0.0299, 0.0134, 0.0106, 0.0023, + 0.6892, 0.1398, 0.0544, 0.0544, 0.0287, 0.0210, 0.0072, 0.0033, + 0.0021, 0.6824, 0.1345, 0.0871, 0.0429, 0.0254, 0.0169, 0.0072, + 0.0019, 0.0016, 0.6426, 0.1847, 0.1090, 0.0347, 0.0133, 0.0051, + 0.0038, 0.0038, 0.0030}); + // auto exp = NDArrayFactory::create('c', {1, 39}, {15.000000, + // 0.000000, 0.000000, 65.000000, 60.000000, + // 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); + // data.linspace(1); + auto exp4 = NDArrayFactory::create( + 'c', {1, 108}, + {0.6239, 0.1813, 0.1236, 0.03695, 0.00795, 0.03385, 0.0074, + 0.0158, 0.0013, 0.0042, 0.0074, 0.3093, 0.2085, 0.051, + 0.00895, 0.01605, 0.00245, 0.00705, 0.00125, 0.0021, 0.01605, + 0.6022, 0.1615, 0.0233, 0.0183, 0.0108, 0.0068, 0.0042, + 0.0113, 0.00115, 0.1813, 0.00125, 0.0233, 0.65985, 0.0653, + 0.0779, 0.03565, 0.05085, 0.03835, 0.02625, 0.6239, 0.3093, + 0.0068, 0.0653, 0.2099, 0.0205, 0.0173, 0.0073, 0.0171, + 0.0089, 0.0158, 0.0113, 0.03835, 0.71495, 0.04775, 0.03615, + 0.0089, 0.00275, 0.0021, 1.5623E-5, 0.00795, 0.00245, 0.6022, + 0.0779, 0.0073, 0.5098, 0.0159, 0.00135, 1.5623E-5, 0.03385, + 0.00705, 0.02625, 0.0171, 0.71495, 0.06515, 0.01835, 0.00775, + 0.00115, 0.03695, 0.051, 0.1615, 0.03565, 0.0205, 0.00275, + 0.5098, 0.00775, 0.0055, 0.0026, 0.0013, 0.2085, 0.0183, + 0.05085, 0.0173, 0.04775, 0.00135, 0.06515, 0.0026, 0.35855, + 0.1236, 0.00895, 0.0108, 0.65985, 0.2099, 0.03615, 0.0159, + 0.01835, 0.0055, 0.35855}); + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_symmetrized op; + auto result = op.evaluate({&rows, &cols, &vals}, {}, {11}); + ASSERT_EQ(result.status(), Status::OK()); + auto res = result.at(2); // res->printBuffer("Symmetrized4"); // exp4.printBuffer("Expected sym"); // nd4j_printf("Total res is {1, %lld}\n", res->lengthOf()); // nd4j_printf("Expected is {1, %lld}\n", exp4.lengthOf()); - //exp.printBuffer("EXPect symm3"); - // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); - ASSERT_TRUE(exp4.equalsTo(res)); - + // exp.printBuffer("EXPect symm3"); + // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); + ASSERT_TRUE(exp4.equalsTo(res)); } TEST_F(DeclarableOpsTests13, CellContains_test_1) { - - auto corners = NDArrayFactory::create( {0.5384, 0.5640, 0.3449, 0.5257, 0.5505}); - auto width = NDArrayFactory::create({0.4306, 0.3960, 0.4639, 0.5040, 0.4904}); - auto point = NDArrayFactory::create({0.3000, 0.2625, 0.2674, 0.8604, 0.4803}); - //auto exp = NDArrayFactory::create('c', {1, 39}, {15.000000, 0.000000, 0.000000, 65.000000, 60.000000, 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); - // data.linspace(1); - - // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); - // auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); - // auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); - sd::ops::cell_contains op; - auto result = op.evaluate({&corners, &width, &point}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(result.at(0).e(0)); - //result.at(2)->printBuffer("Symmetrized3"); - //exp.printBuffer("EXPect symm3"); - // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); - //ASSERT_TRUE(exp.equalsTo(result.at(0))); - + auto corners = + NDArrayFactory::create({0.5384, 0.5640, 0.3449, 0.5257, 0.5505}); + auto width = + NDArrayFactory::create({0.4306, 0.3960, 0.4639, 0.5040, 0.4904}); + auto point = + NDArrayFactory::create({0.3000, 0.2625, 0.2674, 0.8604, 0.4803}); + // auto exp = NDArrayFactory::create('c', {1, 39}, {15.000000, + // 0.000000, 0.000000, 65.000000, 60.000000, + // 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); + // data.linspace(1); + + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, + // -0.5, -6}); auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, + // 0.2, -0.3, 0.4, -0.5, 0.6}); auto exp = + // NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::cell_contains op; + auto result = op.evaluate({&corners, &width, &point}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(result.at(0).e(0)); + // result.at(2)->printBuffer("Symmetrized3"); + // exp.printBuffer("EXPect symm3"); + // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); + // ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustHue_1) { + NDArray input('c', {2, 2, 3}, + {0, 100, 56, 17, 220, 5, 150, 97, 230, 255, 2, 13}, + sd::DataType::FLOAT32); + NDArray factor = NDArrayFactory::create(0.5); + NDArray exp('c', {2, 2, 3}, + {100, 0, 44, 208, 5, 220, 177, 230, 97, 2, 255, 244}, + sd::DataType::FLOAT32); - NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, sd::DataType::FLOAT32); - NDArray factor = NDArrayFactory::create(0.5); - NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, sd::DataType::FLOAT32); - - - sd::ops::adjust_hue op; - auto results (op.evaluate({&input, &factor}, {}, {2})); + sd::ops::adjust_hue op; + auto results(op.evaluate({&input, &factor}, {}, {2})); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result.printIndexedBuffer(); + auto result = results.at(0); + // result.printIndexedBuffer(); - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustHue_2) { + NDArray input('c', {2, 2, 3}, + {0.f, 100.f / 255.f, 56.f / 255.f, 17.f / 255.f, 220.f / 255.f, + 5.f / 255.f, 150.f / 255.f, 97.f / 255.f, 230.f / 255.f, + 255.f / 255.f, 2.f / 255.f, 13.f / 255.f}, + sd::DataType::FLOAT32); + NDArray exp('c', {2, 2, 3}, + {4.f / 255.f, 100.f / 255.f, 0.f, 146.f / 255.f, 220.f / 255.f, + 5.f / 255.f, 97.f / 255.f, 123.8f / 255.f, 230.f / 255.f, + 255.f / 255.f, 2.f / 255.f, 164.8f / 255.f}, + sd::DataType::FLOAT32); - NDArray input('c', { 2,2,3 }, { 0.f,100.f / 255.f,56.f / 255.f, 17.f / 255.f,220.f / 255.f,5.f / 255.f, 150.f / 255.f,97.f / 255.f,230.f / 255.f, 255.f / 255.f,2.f / 255.f,13.f / 255.f }, sd::DataType::FLOAT32); - NDArray exp('c', { 2,2,3 }, { 4.f / 255.f,100.f / 255.f,0.f, 146.f / 255.f,220.f / 255.f,5.f / 255.f, 97.f / 255.f,123.8f / 255.f,230.f / 255.f, 255.f / 255.f,2.f / 255.f,164.8f / 255.f }, sd::DataType::FLOAT32); + sd::ops::adjust_hue op; + auto results(op.evaluate({&input}, {0.9}, {2})); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - sd::ops::adjust_hue op; - auto results(op.evaluate({&input}, {0.9}, {2})); + auto result = results.at(0); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustHue_3) { + NDArray input('c', {2, 2, 3}, + {0, 100, 56, 17, 220, 5, 150, 97, 230, 255, 2, 13}, + sd::DataType::FLOAT32); + NDArray exp( + 'c', {2, 2, 3}, + {0., 84., 100., 5., 220., 122.0001, 229.8, 97., 230., 255., 142.8002, 2.}, + sd::DataType::FLOAT32); - NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, sd::DataType::FLOAT32); - NDArray exp ('c', {2,2,3}, {0.,84.,100., 5.,220.,122.0001, 229.8,97.,230., 255.,142.8002,2.}, sd::DataType::FLOAT32); + sd::ops::adjust_hue op; + auto results(op.evaluate({&input}, {-0.9}, {2})); - sd::ops::adjust_hue op; - auto results(op.evaluate({&input}, {-0.9}, {2})); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustHue_4) { + NDArray input('c', {2, 3, 2}, + {0, 17, 100, 220, 56, 5, 150, 255, 97, 2, 230, 13}, + sd::DataType::FLOAT32); + NDArray exp('c', {2, 3, 2}, + {100, 208, 0, 5, 44, 220, 177, 2, 230, 255, 97, 244}, + sd::DataType::FLOAT32); - NDArray input('c', {2,3,2}, {0,17, 100,220, 56,5, 150,255, 97,2, 230,13}, sd::DataType::FLOAT32); - NDArray exp ('c', {2,3,2}, {100,208, 0,5, 44,220, 177,2, 230,255, 97,244}, sd::DataType::FLOAT32); - - sd::ops::adjust_hue op; - auto results(op.evaluate({&input}, {0.5}, {1})); + sd::ops::adjust_hue op; + auto results(op.evaluate({&input}, {0.5}, {1})); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustHue_5) { + NDArray input('c', {3, 2, 2}, + {0, 17, 150, 255, 100, 220, 97, 2, 56, 5, 230, 13}, + sd::DataType::FLOAT32); + NDArray exp('c', {3, 2, 2}, + {100, 208, 177, 2, 0, 5, 230, 255, 44, 220, 97, 244}, + sd::DataType::FLOAT32); - NDArray input('c', {3,2,2}, {0,17, 150,255, 100,220, 97,2, 56,5, 230,13}, sd::DataType::FLOAT32); - NDArray exp ('c', {3,2,2}, {100,208, 177,2, 0,5, 230,255, 44,220, 97,244}, sd::DataType::FLOAT32); - - sd::ops::adjust_hue op; - auto results(op.evaluate({&input}, {0.5}, {0})); + sd::ops::adjust_hue op; + auto results(op.evaluate({&input}, {0.5}, {0})); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustSaturation_1) { + NDArray input('c', {2, 2, 3}, + {0, 100, 56, 17, 220, 5, 150, 97, 230, 255, 2, 13}, + sd::DataType::FLOAT32); + NDArray factor = NDArrayFactory::create(0.5); + NDArray exp( + 'c', {2, 2, 3}, + {50, 100, 78, 118.5, 220, 112.5, 190, 163.5, 230, 255, 128.5, 134}, + sd::DataType::FLOAT32); - NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, sd::DataType::FLOAT32); - NDArray factor = NDArrayFactory::create(0.5); - NDArray exp ('c', {2,2,3}, {50,100,78, 118.5,220,112.5, 190,163.5,230, 255,128.5,134}, sd::DataType::FLOAT32); - - sd::ops::adjust_saturation op; - auto results = op.evaluate({&input, &factor}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::adjust_saturation op; + auto results = op.evaluate({&input, &factor}, {}, {2}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustSaturation_2) { + NDArray input('c', {2, 2, 3}, + {0, 100, 56, 17, 220, 5, 150, 97, 230, 255, 2, 13}, + sd::DataType::DOUBLE); + NDArray exp('c', {2, 2, 3}, + {0., 100., 56., 12.279087, 220., 0., 91.654228, 0., 230., 255., + 0., 11.087015}, + sd::DataType::DOUBLE); - NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, sd::DataType::DOUBLE); - NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, sd::DataType::DOUBLE); + sd::ops::adjust_saturation op; + auto results = op.evaluate({&input}, {10}, {2}); - sd::ops::adjust_saturation op; - auto results = op.evaluate({&input}, {10}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); -// result.printIndexedBuffer("Result2"); -// exp.printIndexedBuffer("Expect2"); - - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + // result.printIndexedBuffer("Result2"); + // exp.printIndexedBuffer("Expect2"); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustSaturation_3) { + NDArray input('c', {2, 2, 3}, + {0, 100, 56, 17, 220, 5, 150, 97, 230, 255, 2, 13}, + sd::DataType::FLOAT32); + NDArray exp( + 'c', {2, 2, 3}, + {100., 100., 100., 220., 220., 220., 230., 230., 230., 255., 255., 255.}, + sd::DataType::FLOAT32); - NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, sd::DataType::FLOAT32); - NDArray exp ('c', {2,2,3}, {100.,100.,100., 220.,220.,220., 230.,230.,230., 255., 255., 255.}, sd::DataType::FLOAT32); - - sd::ops::adjust_saturation op; - auto results = op.evaluate({&input}, {-10}, {2}); + sd::ops::adjust_saturation op; + auto results = op.evaluate({&input}, {-10}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustSaturation_4) { + NDArray input('c', {2, 3, 2}, + {0, 17, 100, 220, 56, 5, 150, 255, 97, 2, 230, 13}, + sd::DataType::FLOAT32); + NDArray exp( + 'c', {2, 3, 2}, + {50, 118.5, 100, 220, 78, 112.5, 190, 255, 163.5, 128.5, 230, 134}, + sd::DataType::FLOAT32); - NDArray input('c', {2,3,2}, {0,17, 100,220, 56,5, 150,255, 97,2, 230,13}, sd::DataType::FLOAT32); - NDArray exp ('c', {2,3,2}, {50,118.5, 100,220, 78,112.5, 190,255, 163.5,128.5, 230,134}, sd::DataType::FLOAT32); - - sd::ops::adjust_saturation op; - auto results = op.evaluate({&input}, {0.5}, {1}); + sd::ops::adjust_saturation op; + auto results = op.evaluate({&input}, {0.5}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result.printIndexedBuffer(); - - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + // result.printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, adjustSaturation_5) { + NDArray input('c', {3, 2, 2}, + {0, 17, 150, 255, 100, 220, 97, 2, 56, 5, 230, 13}, + sd::DataType::FLOAT32); + NDArray exp( + 'c', {3, 2, 2}, + {50, 118.5, 190, 255, 100, 220, 163.5, 128.5, 78, 112.5, 230, 134}, + sd::DataType::FLOAT32); - NDArray input('c', {3,2,2}, {0,17, 150,255, 100,220, 97,2, 56,5, 230,13}, sd::DataType::FLOAT32); - NDArray exp ('c', {3,2,2}, {50,118.5, 190,255, 100,220, 163.5,128.5, 78,112.5, 230,134}, sd::DataType::FLOAT32); - - sd::ops::adjust_saturation op; - auto results = op.evaluate({&input}, {0.5}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::adjust_saturation op; + auto results = op.evaluate({&input}, {0.5}, {0}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(exp.isSameShape(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); } - TEST_F(DeclarableOpsTests13, shift_bits_1) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create(4); - auto e = x.ulike(); - x.assign(32); - e.assign(512); - - sd::ops::shift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); + auto e = x.ulike(); + x.assign(32); + e.assign(512); - auto z = result.at(0); + sd::ops::shift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, z); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, rshift_bits_1) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create(4); - auto e = x.ulike(); - x.assign(512); - e.assign(32); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); + auto e = x.ulike(); + x.assign(512); + e.assign(32); - sd::ops::rshift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + sd::ops::rshift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, z); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create(4); - auto e = x.ulike(); - x.assign(32); - e.assign(512); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); + auto e = x.ulike(); + x.assign(32); + e.assign(512); - sd::ops::cyclic_shift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::cyclic_shift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_EQ(e, z); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_1) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create(4); - auto e = x.ulike(); - x.assign(512); - e.assign(32); - - sd::ops::cyclic_rshift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); + auto e = x.ulike(); + x.assign(512); + e.assign(32); - auto z = result.at(0); + sd::ops::cyclic_rshift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, z); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, shift_bits_2) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create('c', {5}); - auto e = x.ulike(); - x.assign(32); - y.assign(4); - e.assign(512); - - sd::ops::shift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(32); + y.assign(4); + e.assign(512); - auto z = result.at(0); + sd::ops::shift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, z); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, rshift_bits_2) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create('c', {5}); - auto e = x.ulike(); - x.assign(512); - y.assign(4); - e.assign(32); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(512); + y.assign(4); + e.assign(32); - sd::ops::rshift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_EQ(e, z); + sd::ops::rshift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, cyclic_shift_bits_2) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create('c', {5}); - auto e = x.ulike(); - x.assign(32); - y.assign(4); - e.assign(512); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(32); + y.assign(4); + e.assign(512); - sd::ops::cyclic_shift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + sd::ops::cyclic_shift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, z); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_2) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create('c', {5}); - auto e = x.ulike(); - x.assign(512); - y.assign(4); - e.assign(32); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(512); + y.assign(4); + e.assign(32); - sd::ops::cyclic_rshift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::cyclic_rshift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_EQ(e, z); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests13, shift_bits_3) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {1, 5}); - auto e = x.ulike(); - x.assign(32); - y.assign(4); - e.assign(512); - - sd::ops::shift_bits op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {1, 5}); + auto e = x.ulike(); + x.assign(32); + y.assign(4); + e.assign(512); - auto z = result.at(0); + sd::ops::shift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, z); + auto z = result.at(0); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, space_to_batch_nd_1) { + NDArray x('c', {1, 2, 2, 2, 3}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2, 2, 2}, + sd::DataType::INT32); // three spatial dimensions + NDArray paddings('c', {3, 2}, std::vector{0, 0, 0, 0, 0, 0}, + sd::DataType::INT32); - NDArray x('c', {1, 2, 2, 2, 3}, sd::DataType::FLOAT32); - NDArray blockShape('c', {3}, {2, 2, 2} , sd::DataType::INT32); // three spatial dimensions - NDArray paddings('c', {3, 2}, std::vector{0, 0, 0, 0, 0, 0} , sd::DataType::INT32); - - NDArray exp('c', {8, 1, 1, 1, 3}, sd::DataType::FLOAT32); + NDArray exp('c', {8, 1, 1, 1, 3}, sd::DataType::FLOAT32); - x.linspace(1); - exp.linspace(1); + x.linspace(1); + exp.linspace(1); - sd::ops::space_to_batch_nd op; - auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::space_to_batch_nd op; + auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, space_to_batch_nd_2) { - - NDArray x('c', {2, 2,4,3, 1}, sd::DataType::FLOAT32); - NDArray blockShape('c', {3}, {2, 2, 3} , sd::DataType::INT32); // three spatial dimensions - NDArray paddings('c', {3, 2}, {0,0, 0,2, 2,1} , sd::DataType::INT32); - - NDArray exp('c', {24, 1,3,2, 1}, { 0, 2, 0, 8, 0, 0, 0, 26, 0, 32, 0, 0, 0, 3, 0, 9, 0, 0, 0, 27, 0, 33, 0, 0, 1, - 0, 7, 0, 0, 0, 25, 0, 31, 0, 0, 0, 0, 5, 0, 11, 0, 0, 0, 29, 0, 35, 0, 0, 0, 6, - 0, 12, 0, 0, 0, 30, 0, 36, 0, 0, 4, 0, 10, 0, 0, 0, 28, 0, 34, 0, 0, 0, 0, 14, - 0, 20, 0, 0, 0, 38, 0, 44, 0, 0, 0, 15, 0, 21, 0, 0, 0, 39, 0, 45, 0, 0, 13, 0, - 19, 0, 0, 0, 37, 0, 43, 0, 0, 0, 0, 17, 0, 23, 0, 0, 0, 41, 0, 47, 0, 0, 0, 18, - 0, 24, 0, 0, 0, 42, 0, 48, 0, 0, 16, 0, 22, 0, 0, 0, 40, 0, 46, 0, 0, 0}, sd::DataType::FLOAT32); - x.linspace(1); - - sd::ops::space_to_batch_nd op; - auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + NDArray x('c', {2, 2, 4, 3, 1}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2, 2, 3}, + sd::DataType::INT32); // three spatial dimensions + NDArray paddings('c', {3, 2}, {0, 0, 0, 2, 2, 1}, sd::DataType::INT32); + + NDArray exp('c', {24, 1, 3, 2, 1}, + {0, 2, 0, 8, 0, 0, 0, 26, 0, 32, 0, 0, 0, 3, 0, 9, 0, 0, + 0, 27, 0, 33, 0, 0, 1, 0, 7, 0, 0, 0, 25, 0, 31, 0, 0, 0, + 0, 5, 0, 11, 0, 0, 0, 29, 0, 35, 0, 0, 0, 6, 0, 12, 0, 0, + 0, 30, 0, 36, 0, 0, 4, 0, 10, 0, 0, 0, 28, 0, 34, 0, 0, 0, + 0, 14, 0, 20, 0, 0, 0, 38, 0, 44, 0, 0, 0, 15, 0, 21, 0, 0, + 0, 39, 0, 45, 0, 0, 13, 0, 19, 0, 0, 0, 37, 0, 43, 0, 0, 0, + 0, 17, 0, 23, 0, 0, 0, 41, 0, 47, 0, 0, 0, 18, 0, 24, 0, 0, + 0, 42, 0, 48, 0, 0, 16, 0, 22, 0, 0, 0, 40, 0, 46, 0, 0, 0}, + sd::DataType::FLOAT32); + x.linspace(1); + + sd::ops::space_to_batch_nd op; + auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, space_to_batch_nd_3) { - - NDArray x('c', {2, 2,4,3, 1}, sd::DataType::FLOAT32); - NDArray blockShape('c', {3}, {2, 2, 3} , sd::DataType::INT32); // three spatial dimensions - NDArray paddings('c', {3, 2}, {1,1, 0,2, 2,1} , sd::DataType::INT32); - - NDArray exp('c', {24, 2,3,2, 1}, { 0, 0, 0, 0, 0, 0, 0, 14, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 38, 0, 44, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, - 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 0, 45, 0, 0, 0, 0, 0, 0, 0, 0, 13, 0, 19, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 37, 0, 43, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 23, 0, 0, 0, 0, 0, 0, 0, 0, 0, 41, 0, 47, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 18, 0, 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 42, 0, 48, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, - 22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 40, 0, 46, 0, 0, 0, 0, 2, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 26, 0, 32, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 27, 0, 33, 0, 0, 0, 0, 0, 0, 0, 0, 1, - 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 25, 0, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 11, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 29, 0, 35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 0, 36, 0, 0, - 0, 0, 0, 0, 0, 0, 4, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 28, 0, 34, 0, 0, 0, 0, 0, 0, 0, 0, 0}, sd::DataType::FLOAT32); - x.linspace(1); - - sd::ops::space_to_batch_nd op; - auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + NDArray x('c', {2, 2, 4, 3, 1}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2, 2, 3}, + sd::DataType::INT32); // three spatial dimensions + NDArray paddings('c', {3, 2}, {1, 1, 0, 2, 2, 1}, sd::DataType::INT32); + + NDArray exp( + 'c', {24, 2, 3, 2, 1}, + {0, 0, 0, 0, 0, 0, 0, 14, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 38, 0, 44, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 21, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 39, 0, 45, 0, 0, 0, 0, 0, 0, 0, 0, + 13, 0, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 37, 0, 43, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 17, 0, 23, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 41, 0, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 0, 24, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 42, 0, 48, 0, 0, 0, 0, 0, 0, 0, 0, + 16, 0, 22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 40, 0, 46, 0, 0, 0, + 0, 2, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 26, 0, 32, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 3, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 27, 0, 33, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 25, 0, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 5, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 29, 0, 35, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 6, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 30, 0, 36, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 10, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 28, 0, 34, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + sd::DataType::FLOAT32); + x.linspace(1); + + sd::ops::space_to_batch_nd op; + auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batch_to_space_nd_1) { + NDArray x('c', {8, 1, 1, 1, 3}, sd::DataType::FLOAT32); - NDArray x('c', {8, 1, 1, 1, 3}, sd::DataType::FLOAT32); - - NDArray blockShape('c', {3}, {2., 2, 2} , sd::DataType::INT32); // three spatial dimensions - NDArray crop('c', {3, 2}, {0., 0, 0, 0, 0, 0} , sd::DataType::INT32); + NDArray blockShape('c', {3}, {2., 2, 2}, + sd::DataType::INT32); // three spatial dimensions + NDArray crop('c', {3, 2}, {0., 0, 0, 0, 0, 0}, sd::DataType::INT32); - NDArray exp('c', {1, 2, 2, 2, 3}, sd::DataType::FLOAT32); + NDArray exp('c', {1, 2, 2, 2, 3}, sd::DataType::FLOAT32); - x.linspace(1); - exp.linspace(1); + x.linspace(1); + exp.linspace(1); - sd::ops::batch_to_space_nd op; - auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::batch_to_space_nd op; + auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batch_to_space_nd_2) { + NDArray x('c', {24, 1, 3, 2, 1}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2, 2, 3}, + sd::DataType::INT32); // three spatial dimensions + NDArray crop('c', {3, 2}, {0, 0, 0, 2, 2, 1}, sd::DataType::INT32); - NDArray x('c', {24, 1,3,2, 1}, sd::DataType::FLOAT32); - NDArray blockShape('c', {3}, {2, 2, 3} , sd::DataType::INT32); // three spatial dimensions - NDArray crop('c', {3, 2}, {0,0, 0,2, 2,1} , sd::DataType::INT32); + NDArray exp('c', {2, 2, 4, 3, 1}, + {25, 2, 14, 61, 38, 50, 27, 4, 16, 63, 40, 52, + 97, 74, 86, 133, 110, 122, 99, 76, 88, 135, 112, 124, + 31, 8, 20, 67, 44, 56, 33, 10, 22, 69, 46, 58, + 103, 80, 92, 139, 116, 128, 105, 82, 94, 141, 118, 130}, + sd::DataType::FLOAT32); + x.linspace(1); - NDArray exp('c', {2, 2,4,3, 1}, {25, 2, 14, 61, 38, 50, 27, 4, 16, 63, 40, 52, 97, 74, 86, 133, 110, 122, 99, 76, 88, 135, 112, 124, - 31, 8, 20, 67, 44, 56, 33, 10, 22, 69, 46, 58, 103, 80, 92, 139, 116, 128, 105, 82, 94, 141, 118, 130}, sd::DataType::FLOAT32); - x.linspace(1); + sd::ops::batch_to_space_nd op; + auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::batch_to_space_nd op; - auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batch_to_space_nd_3) { + NDArray x('c', {24, 2, 3, 2, 1}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2, 2, 3}, + sd::DataType::INT32); // three spatial dimensions + NDArray crop('c', {3, 2}, {1, 1, 0, 2, 2, 1}, sd::DataType::INT32); - NDArray x('c', {24, 2,3,2, 1}, sd::DataType::FLOAT32); - NDArray blockShape('c', {3}, {2, 2, 3} , sd::DataType::INT32); // three spatial dimensions - NDArray crop('c', {3, 2}, {1,1, 0,2, 2,1} , sd::DataType::INT32); - - NDArray exp('c', {2, 2,4,3, 1}, {193, 146, 170, 265, 218, 242, 195, 148, 172, 267, 220, 244, 55, 8, 32, 127, 80, 104, 57, 10, 34, 129, 82, - 106, 205, 158, 182, 277, 230, 254, 207, 160, 184, 279, 232, 256, 67, 20, 44, 139, 92, 116, 69, 22, 46, 141, 94, 118}, sd::DataType::FLOAT32); - x.linspace(1); + NDArray exp('c', {2, 2, 4, 3, 1}, + {193, 146, 170, 265, 218, 242, 195, 148, 172, 267, 220, 244, + 55, 8, 32, 127, 80, 104, 57, 10, 34, 129, 82, 106, + 205, 158, 182, 277, 230, 254, 207, 160, 184, 279, 232, 256, + 67, 20, 44, 139, 92, 116, 69, 22, 46, 141, 94, 118}, + sd::DataType::FLOAT32); + x.linspace(1); - sd::ops::batch_to_space_nd op; - auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::batch_to_space_nd op; + auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, mergemax_1) { + NDArray x1('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x2('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x3('c', {5, 5}, sd::DataType::FLOAT32); + NDArray e('c', {5, 5}, sd::DataType::FLOAT32); + x1.assign(3); + x2.assign(1); + x3.assign(2); + e.assign(3); - NDArray x1('c', {5, 5}, sd::DataType::FLOAT32); - NDArray x2('c', {5, 5}, sd::DataType::FLOAT32); - NDArray x3('c', {5, 5}, sd::DataType::FLOAT32); - NDArray e('c', {5, 5}, sd::DataType::FLOAT32); - x1.assign(3); - x2.assign(1); - x3.assign(2); - e.assign(3); - + sd::ops::mergemax op; + auto result = op.evaluate({&x1, &x2, &x3}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::mergemax op; - auto result = op.evaluate({&x1, &x2, &x3}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, mergemax_2) { + NDArray x1('c', {1, 3}, {0., 1, 2}, sd::DataType::FLOAT32); + NDArray x2('c', {1, 1}, std::vector{1.}, sd::DataType::FLOAT32); + NDArray out('c', {1, 3}, {-1., -1, -1}, sd::DataType::FLOAT32); - NDArray x1('c', {1, 3}, {0., 1, 2}, sd::DataType::FLOAT32); - NDArray x2('c', {1, 1}, std::vector{1.}, sd::DataType::FLOAT32); - NDArray out('c', {1, 3}, {-1., -1, -1}, sd::DataType::FLOAT32); + sd::ops::mergemax op; + auto status = op.execute({&x1, &x2}, {&out}, {}, {}, {}); - sd::ops::mergemax op; - auto status = op.execute({&x1, &x2}, {&out}, {}, {}, {}); - - ASSERT_EQ(20, status); + ASSERT_EQ(20, status); } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, mergemax_bp_1) { + NDArray x1('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x2('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x3('c', {5, 5}, sd::DataType::FLOAT32); + NDArray grad('c', {5, 5}, sd::DataType::FLOAT32); - NDArray x1('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray x2('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray x3('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray grad('c', { 5, 5 }, sd::DataType::FLOAT32); - - x1.assign(3); - x2.assign(1); - x3.assign(2); - grad.linspace(.1, .1); + x1.assign(3); + x2.assign(1); + x3.assign(2); + grad.linspace(.1, .1); + sd::ops::mergemax_bp op; + auto result = op.evaluate({&x1, &x2, &x3, &grad}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); - sd::ops::mergemax_bp op; - auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(3, result.size()); - - auto z = result.at(0); - - ASSERT_TRUE(grad.isSameShape(z)); - ASSERT_TRUE(grad.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(grad.isSameShape(z)); + ASSERT_TRUE(grad.equalsTo(z)); } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, mergemax_bp_2) { - - NDArray x1('c', { 2, 5 }, { 1,2,3,4,5,4,3,2,1,0 }, sd::DataType::FLOAT32); - NDArray x2('c', { 2, 5 }, { 0,1,2,3,4,5,6,7,8,9 }, sd::DataType::FLOAT32); - NDArray x3('c', { 2, 5 }, { 0,1,1,2,3,4,7,5,8,10 }, sd::DataType::FLOAT32); - NDArray grad('c', { 2, 5 }, sd::DataType::FLOAT32); - - grad.linspace(.1, .1); - - NDArray exp1('c', { 2, 5 }, { 0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0 }, sd::DataType::FLOAT32); - NDArray exp2('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.6, 0.0, 0.8, 0.9, 0.0 }, sd::DataType::FLOAT32); - NDArray exp3('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0, 1.0 }, sd::DataType::FLOAT32); - - sd::ops::mergemax_bp op; - auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(3, result.size()); - - auto z1 = result.at(0); - auto z2 = result.at(1); - auto z3 = result.at(2); - - ASSERT_TRUE(exp1.isSameShape(z1)); - ASSERT_TRUE(exp1.equalsTo(z1)); - ASSERT_TRUE(exp2.isSameShape(z2)); - ASSERT_TRUE(exp2.equalsTo(z2)); - ASSERT_TRUE(exp3.isSameShape(z3)); - ASSERT_TRUE(exp3.equalsTo(z3)); - + NDArray x1('c', {2, 5}, {1, 2, 3, 4, 5, 4, 3, 2, 1, 0}, + sd::DataType::FLOAT32); + NDArray x2('c', {2, 5}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + sd::DataType::FLOAT32); + NDArray x3('c', {2, 5}, {0, 1, 1, 2, 3, 4, 7, 5, 8, 10}, + sd::DataType::FLOAT32); + NDArray grad('c', {2, 5}, sd::DataType::FLOAT32); + + grad.linspace(.1, .1); + + NDArray exp1('c', {2, 5}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0}, + sd::DataType::FLOAT32); + NDArray exp2('c', {2, 5}, {0.0, 0.0, 0.0, 0.0, 0.0, 0.6, 0.0, 0.8, 0.9, 0.0}, + sd::DataType::FLOAT32); + NDArray exp3('c', {2, 5}, {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0, 1.0}, + sd::DataType::FLOAT32); + + sd::ops::mergemax_bp op; + auto result = op.evaluate({&x1, &x2, &x3, &grad}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); + + auto z1 = result.at(0); + auto z2 = result.at(1); + auto z3 = result.at(2); + + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_TRUE(exp1.equalsTo(z1)); + ASSERT_TRUE(exp2.isSameShape(z2)); + ASSERT_TRUE(exp2.equalsTo(z2)); + ASSERT_TRUE(exp3.isSameShape(z3)); + ASSERT_TRUE(exp3.equalsTo(z3)); } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, mergemax_bp_3) { - - NDArray x1C('c', { 2, 5 }, { 1,2,3,4,5,4,3,2,1,0 }, sd::DataType::FLOAT32); - NDArray x2C('c', { 2, 5 }, { 0,1,2,3,4,5,6,7,8,9 }, sd::DataType::FLOAT32); - NDArray x3C('c', { 2, 5 }, { 0,1,1,2,3,4,7,5,8,10 }, sd::DataType::FLOAT32); - NDArray grad('c', { 2, 5 }, sd::DataType::FLOAT32); - - grad.linspace(.1, .1); - - NDArray x1('f', { 2, 5 }, sd::DataType::FLOAT32); - NDArray x2('f', { 2, 5 }, sd::DataType::FLOAT32); - NDArray x3('f', { 2, 5 }, sd::DataType::FLOAT32); - - NDArray exp1C('c', { 2, 5 }, { 0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0 }, sd::DataType::FLOAT32); - NDArray exp2C('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.6, 0.0, 0.8, 0.9, 0.0 }, sd::DataType::FLOAT32); - NDArray exp3C('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0, 1.0 }, sd::DataType::FLOAT32); - - NDArray exp1('f', { 2, 5 }, sd::DataType::FLOAT32); - NDArray exp2('f', { 2, 5 }, sd::DataType::FLOAT32); - NDArray exp3('f', { 2, 5 }, sd::DataType::FLOAT32); - - x1.assign(x1C); - x2.assign(x2C); - x3.assign(x3C); - - exp1.assign(exp1C); - exp2.assign(exp2C); - exp3.assign(exp3C); - - sd::ops::mergemax_bp op; - auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(3, result.size()); - - auto z1 = result.at(0); - auto z2 = result.at(1); - auto z3 = result.at(2); - - ASSERT_TRUE(exp1.isSameShape(z1)); - ASSERT_TRUE(exp1.equalsTo(z1)); - ASSERT_TRUE(exp2.isSameShape(z2)); - ASSERT_TRUE(exp2.equalsTo(z2)); - ASSERT_TRUE(exp3.isSameShape(z3)); - ASSERT_TRUE(exp3.equalsTo(z3)); - + NDArray x1C('c', {2, 5}, {1, 2, 3, 4, 5, 4, 3, 2, 1, 0}, + sd::DataType::FLOAT32); + NDArray x2C('c', {2, 5}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + sd::DataType::FLOAT32); + NDArray x3C('c', {2, 5}, {0, 1, 1, 2, 3, 4, 7, 5, 8, 10}, + sd::DataType::FLOAT32); + NDArray grad('c', {2, 5}, sd::DataType::FLOAT32); + + grad.linspace(.1, .1); + + NDArray x1('f', {2, 5}, sd::DataType::FLOAT32); + NDArray x2('f', {2, 5}, sd::DataType::FLOAT32); + NDArray x3('f', {2, 5}, sd::DataType::FLOAT32); + + NDArray exp1C('c', {2, 5}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0}, + sd::DataType::FLOAT32); + NDArray exp2C('c', {2, 5}, {0.0, 0.0, 0.0, 0.0, 0.0, 0.6, 0.0, 0.8, 0.9, 0.0}, + sd::DataType::FLOAT32); + NDArray exp3C('c', {2, 5}, {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0, 1.0}, + sd::DataType::FLOAT32); + + NDArray exp1('f', {2, 5}, sd::DataType::FLOAT32); + NDArray exp2('f', {2, 5}, sd::DataType::FLOAT32); + NDArray exp3('f', {2, 5}, sd::DataType::FLOAT32); + + x1.assign(x1C); + x2.assign(x2C); + x3.assign(x3C); + + exp1.assign(exp1C); + exp2.assign(exp2C); + exp3.assign(exp3C); + + sd::ops::mergemax_bp op; + auto result = op.evaluate({&x1, &x2, &x3, &grad}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); + + auto z1 = result.at(0); + auto z2 = result.at(1); + auto z3 = result.at(2); + + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_TRUE(exp1.equalsTo(z1)); + ASSERT_TRUE(exp2.isSameShape(z2)); + ASSERT_TRUE(exp2.equalsTo(z2)); + ASSERT_TRUE(exp3.isSameShape(z3)); + ASSERT_TRUE(exp3.equalsTo(z3)); } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, mergeadd_bp_1) { - - NDArray x1('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray x2('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray x3('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray grad('c', { 5, 5 }, sd::DataType::FLOAT32); - - x1.assign(3); - x2.assign(1); - x3.assign(2); - grad.linspace(.1, .1); - - sd::ops::mergeadd_bp op; - auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(3, result.size()); - - for (int i = 0; i < 3; i++) { - auto z = result.at(0); - ASSERT_TRUE(grad.isSameShape(z)); - ASSERT_TRUE(grad.equalsTo(z)); - } + NDArray x1('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x2('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x3('c', {5, 5}, sd::DataType::FLOAT32); + NDArray grad('c', {5, 5}, sd::DataType::FLOAT32); + + x1.assign(3); + x2.assign(1); + x3.assign(2); + grad.linspace(.1, .1); + + sd::ops::mergeadd_bp op; + auto result = op.evaluate({&x1, &x2, &x3, &grad}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); + + for (int i = 0; i < 3; i++) { + auto z = result.at(0); + ASSERT_TRUE(grad.isSameShape(z)); + ASSERT_TRUE(grad.equalsTo(z)); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, mergeavg_bp_1) { + NDArray x1('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x2('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x3('c', {5, 5}, sd::DataType::FLOAT32); + NDArray grad('c', {5, 5}, sd::DataType::FLOAT32); - NDArray x1('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray x2('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray x3('c', { 5, 5 }, sd::DataType::FLOAT32); - NDArray grad('c', { 5, 5 }, sd::DataType::FLOAT32); - - x1.assign(3); - x2.assign(1); - x3.assign(2); - grad.linspace(.1, .1); + x1.assign(3); + x2.assign(1); + x3.assign(2); + grad.linspace(.1, .1); - sd::ops::mergeavg_bp op; - auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(3, result.size()); + sd::ops::mergeavg_bp op; + auto result = op.evaluate({&x1, &x2, &x3, &grad}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); - grad.applyScalar(sd::scalar::Divide, 3, grad); - - for (int i = 0; i < 3; i++) { - auto z = result.at(i); - ASSERT_TRUE(grad.isSameShape(z)); - ASSERT_TRUE(grad.equalsTo(z)); - } + grad.applyScalar(sd::scalar::Divide, 3, grad); + for (int i = 0; i < 3; i++) { + auto z = result.at(i); + ASSERT_TRUE(grad.isSameShape(z)); + ASSERT_TRUE(grad.equalsTo(z)); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_1) { - - const int sL = 5; - const int bS = 3; - const int nIn = 3; - const int nOut = 3; - - // input arguments - - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 0; // forward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = false; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx = 0.003; - Wr = 0.006; - b = 0.5; - hI = 1.; - cI = 2.; - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - auto expH = NDArrayFactory::create('c', {sL, bS, nOut}, {0.57574f, 0.57574f, 0.57574f, 0.58006f, 0.58006f, 0.58006f, 0.58434f, 0.58434f, 0.58434f, - 0.55114f, 0.55114f, 0.55114f, 0.55732f, 0.55732f, 0.55732f, 0.56338f, 0.56338f, 0.56338f, - 0.53763f, 0.53763f, 0.53763f, 0.54534f, 0.54534f, 0.54534f, 0.55287f, 0.55287f, 0.55287f, - 0.53626f, 0.53626f, 0.53626f, 0.54487f, 0.54487f, 0.54487f, 0.55327f, 0.55327f, 0.55327f, - 0.54484f, 0.54484f, 0.54484f, 0.55379f, 0.55379f, 0.55379f, 0.5625f, 0.5625f, 0.5625f}); - - auto expClast = NDArrayFactory::create('c', {bS, nOut}, {1.1589154f, 1.1589154f, 1.1589154f, 1.1892855f, 1.1892855f, 1.1892855f, 1.219861f, 1.219861f, 1.219861f}); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expClast.isSameShape(cL)); - ASSERT_TRUE(expClast.equalsTo(cL)); - + const int sL = 5; + const int bS = 3; + const int nIn = 3; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + auto expH = NDArrayFactory::create( + 'c', {sL, bS, nOut}, + {0.57574f, 0.57574f, 0.57574f, 0.58006f, 0.58006f, 0.58006f, 0.58434f, + 0.58434f, 0.58434f, 0.55114f, 0.55114f, 0.55114f, 0.55732f, 0.55732f, + 0.55732f, 0.56338f, 0.56338f, 0.56338f, 0.53763f, 0.53763f, 0.53763f, + 0.54534f, 0.54534f, 0.54534f, 0.55287f, 0.55287f, 0.55287f, 0.53626f, + 0.53626f, 0.53626f, 0.54487f, 0.54487f, 0.54487f, 0.55327f, 0.55327f, + 0.55327f, 0.54484f, 0.54484f, 0.54484f, 0.55379f, 0.55379f, 0.55379f, + 0.5625f, 0.5625f, 0.5625f}); + + auto expClast = NDArrayFactory::create( + 'c', {bS, nOut}, + {1.1589154f, 1.1589154f, 1.1589154f, 1.1892855f, 1.1892855f, 1.1892855f, + 1.219861f, 1.219861f, 1.219861f}); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expClast.isSameShape(cL)); + ASSERT_TRUE(expClast.equalsTo(cL)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_2) { - - const int sL = 5; - const int bS = 3; - const int nIn = 3; - const int nOut = 3; - - // input arguments - - const int dataFormat = 1; // [bS,sL,nIn] - const int directionMode = 0; // forward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = false; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {bS, sL, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx = 0.003; - Wr = 0.006; - b = 0.5; - hI = 1.; - cI = 2.; - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - auto expH = NDArrayFactory::create('c', {bS, sL, nOut}, {0.575735f, 0.575735f, 0.575735f, 0.541562f, 0.541562f, 0.541562f, 0.514003f, 0.514003f, 0.514003f, 0.495597f, 0.495597f, 0.495597f, 0.485999f, 0.485999f, 0.485999f, - 0.596965f, 0.596965f, 0.596965f, 0.571978f, 0.571978f, 0.571978f, 0.552888f, 0.552888f, 0.552888f, 0.540606f, 0.540606f, 0.540606f, 0.534764f, 0.534764f, 0.534764f, - 0.61725f, 0.61725f, 0.61725f, 0.599828f, 0.599828f, 0.599828f, 0.587627f, 0.587627f, 0.587627f, 0.580408f, 0.580408f, 0.580408f, 0.577735f, 0.577735f, 0.577735f}); - - auto expClast = NDArrayFactory::create('c', {bS, nOut}, {0.996965f, 0.996965f, 0.996965f, 1.146756f, 1.146756f, 1.146756f, 1.301922f, 1.301922f, 1.301922f}); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expClast.isSameShape(cL)); - ASSERT_TRUE(expClast.equalsTo(cL)); - + const int sL = 5; + const int bS = 3; + const int nIn = 3; + const int nOut = 3; + + // input arguments + + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 0; // forward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {bS, sL, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + auto expH = NDArrayFactory::create( + 'c', {bS, sL, nOut}, + {0.575735f, 0.575735f, 0.575735f, 0.541562f, 0.541562f, 0.541562f, + 0.514003f, 0.514003f, 0.514003f, 0.495597f, 0.495597f, 0.495597f, + 0.485999f, 0.485999f, 0.485999f, 0.596965f, 0.596965f, 0.596965f, + 0.571978f, 0.571978f, 0.571978f, 0.552888f, 0.552888f, 0.552888f, + 0.540606f, 0.540606f, 0.540606f, 0.534764f, 0.534764f, 0.534764f, + 0.61725f, 0.61725f, 0.61725f, 0.599828f, 0.599828f, 0.599828f, + 0.587627f, 0.587627f, 0.587627f, 0.580408f, 0.580408f, 0.580408f, + 0.577735f, 0.577735f, 0.577735f}); + + auto expClast = NDArrayFactory::create( + 'c', {bS, nOut}, + {0.996965f, 0.996965f, 0.996965f, 1.146756f, 1.146756f, 1.146756f, + 1.301922f, 1.301922f, 1.301922f}); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expClast.isSameShape(cL)); + ASSERT_TRUE(expClast.equalsTo(cL)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_3) { - - const int sL = 5; - const int bS = 2; - const int nIn = 4; - const int nOut = 3; - - // input arguments - - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 1; // backward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = false; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL,bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx = 0.003; - Wr = 0.006; - b = 0.5; - hI = 1.; - cI = 2.; - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f, 0.534701f, 0.534701f, 0.534701f, 0.549139f, - 0.549139f, 0.549139f, 0.571900f, 0.571900f, 0.571900f, 0.583561f, 0.583561f, 0.583561f, 0.605106f, 0.605106f, - 0.605106f, 0.614114f, 0.614114f, 0.614114f, 0.635354f, 0.635354f, 0.635354f, 0.642045f, 0.642045f, 0.642045f}, sd::DataType::FLOAT32); - - NDArray expHL('c', {bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f}, sd::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {1.061274f, 1.061274f, 1.061274f, 1.115888f, 1.115888f, 1.115888f}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 1; // backward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH( + 'c', {sL, bS, nOut}, + {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f, + 0.534701f, 0.534701f, 0.534701f, 0.549139f, 0.549139f, 0.549139f, + 0.571900f, 0.571900f, 0.571900f, 0.583561f, 0.583561f, 0.583561f, + 0.605106f, 0.605106f, 0.605106f, 0.614114f, 0.614114f, 0.614114f, + 0.635354f, 0.635354f, 0.635354f, 0.642045f, 0.642045f, 0.642045f}, + sd::DataType::FLOAT32); + + NDArray expHL( + 'c', {bS, nOut}, + {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f}, + sd::DataType::FLOAT32); + NDArray expCL( + 'c', {bS, nOut}, + {1.061274f, 1.061274f, 1.061274f, 1.115888f, 1.115888f, 1.115888f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_4) { - - const int sL = 5; - const int bS = 2; - const int nIn = 4; - const int nOut = 3; - - // input arguments - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 3; // bidirectional concat - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = false; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {2,nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {2,nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {2,4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003f; - Wx({1,2, 0,0, 0,0}) = -0.003f; - Wr({0,1, 0,0, 0,0}) = 0.006f; - Wr({1,2, 0,0, 0,0}) = -0.006f; - b({0,1, 0,0}) = 0.5f; - b({1,2, 0,0}) = -0.5f; - hI({0,1, 0,0, 0,0}) = 1; - hI({1,2, 0,0, 0,0}) = -1; - cI({0,1, 0,0, 0,0}) = 2; - cI({1,2, 0,0, 0,0}) = -2; - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, 2 * nOut}, { - 0.577661f, 0.577661f, 0.577661f, -0.107642f, -0.107642f, -0.107642f, 0.585289f, 0.585289f, 0.585289f, - -0.106937f, -0.106937f, -0.106937f, 0.556517f, 0.556517f, 0.556517f, -0.111647f, -0.111647f, -0.111647f, - 0.567274f, 0.567274f, 0.567274f, -0.110214f, -0.110214f, -0.110214f, 0.547395f, 0.547395f, 0.547395f, - -0.123305f, -0.123305f, -0.123305f, 0.560640f, 0.560640f, 0.560640f, -0.120862f, -0.120862f, -0.120862f, - 0.550714f, 0.550714f, 0.550714f, -0.156223f, -0.156223f, -0.156223f, 0.565308f, 0.565308f, 0.565308f, - -0.152313f, -0.152313f, -0.152313f, 0.563741f, 0.563741f, 0.563741f, -0.234128f, -0.234128f, -0.234128f, - 0.578676f, 0.578676f, 0.578676f, -0.228917f, -0.228917f, -0.228917f}, sd::DataType::FLOAT32); - - NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, -0.107642f, - -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, sd::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, -0.295768f, - -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {2, nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {2, nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {2, 4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0, 1, 0, 0, 0, 0}) = 0.003f; + Wx({1, 2, 0, 0, 0, 0}) = -0.003f; + Wr({0, 1, 0, 0, 0, 0}) = 0.006f; + Wr({1, 2, 0, 0, 0, 0}) = -0.006f; + b({0, 1, 0, 0}) = 0.5f; + b({1, 2, 0, 0}) = -0.5f; + hI({0, 1, 0, 0, 0, 0}) = 1; + hI({1, 2, 0, 0, 0, 0}) = -1; + cI({0, 1, 0, 0, 0, 0}) = 2; + cI({1, 2, 0, 0, 0, 0}) = -2; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH( + 'c', {sL, bS, 2 * nOut}, + {0.577661f, 0.577661f, 0.577661f, -0.107642f, -0.107642f, -0.107642f, + 0.585289f, 0.585289f, 0.585289f, -0.106937f, -0.106937f, -0.106937f, + 0.556517f, 0.556517f, 0.556517f, -0.111647f, -0.111647f, -0.111647f, + 0.567274f, 0.567274f, 0.567274f, -0.110214f, -0.110214f, -0.110214f, + 0.547395f, 0.547395f, 0.547395f, -0.123305f, -0.123305f, -0.123305f, + 0.560640f, 0.560640f, 0.560640f, -0.120862f, -0.120862f, -0.120862f, + 0.550714f, 0.550714f, 0.550714f, -0.156223f, -0.156223f, -0.156223f, + 0.565308f, 0.565308f, 0.565308f, -0.152313f, -0.152313f, -0.152313f, + 0.563741f, 0.563741f, 0.563741f, -0.234128f, -0.234128f, -0.234128f, + 0.578676f, 0.578676f, 0.578676f, -0.228917f, -0.228917f, -0.228917f}, + sd::DataType::FLOAT32); + + NDArray expHL( + 'c', {2, bS, nOut}, + {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, + -0.107642f, -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, + sd::DataType::FLOAT32); + NDArray expCL( + 'c', {2, bS, nOut}, + {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, + -0.295768f, -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_5) { - - const int sL = 5; - const int bS = 2; - const int nIn = 4; - const int nOut = 3; - - // input arguments - const int dataFormat = 1; // [bS,sL,nIn] - const int directionMode = 3; // bidirectional concat - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = false; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {bS, sL, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {2,nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {2,nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {2,4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003; - Wx({1,2, 0,0, 0,0}) = -0.003; - Wr({0,1, 0,0, 0,0}) = 0.006; - Wr({1,2, 0,0, 0,0}) = -0.006; - b({0,1, 0,0}) = 0.5; - b({1,2, 0,0}) = -0.5; - hI({0,1, 0,0, 0,0}) = 1; - hI({1,2, 0,0, 0,0}) = -1; - cI({0,1, 0,0, 0,0}) = 2; - cI({1,2, 0,0, 0,0}) = -2; - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {bS, sL, 2*nOut}, { - 0.577661f, 0.577661f, 0.577661f, -0.107659f, -0.107659f, -0.107659f, 0.548099f, 0.548099f, 0.548099f, -0.113406f, -0.113406f, -0.113406f, - 0.526881f, 0.526881f, 0.526881f, -0.12883f, -0.12883f, -0.12883f, 0.515882f, 0.515882f, 0.515882f, -0.16868f, -0.16868f, -0.16868f, - 0.51409f, 0.51409f, 0.51409f, -0.255185f, -0.255185f, -0.255185f, 0.614599f, 0.614599f, 0.614599f, -0.102739f, -0.102739f, -0.102739f, - 0.599572f, 0.599572f, 0.599572f, -0.105802f, -0.105802f, -0.105802f, 0.591089f, 0.591089f, 0.591089f, -0.116681f, -0.116681f, -0.116681f, - 0.588694f, 0.588694f, 0.588694f, -0.149201f, -0.149201f, -0.149201f, 0.591492f, 0.591492f, 0.591492f, -0.228917f, -0.228917f, -0.228917f}, sd::DataType::FLOAT32); - - NDArray expHL('c', {2,bS, nOut}, {0.51409f, 0.51409f, 0.51409f, 0.591492f, 0.591492f, 0.591492f, - -0.107659f, -0.107659f, -0.107659f, -0.102739f, -0.102739f, -0.102739f}, sd::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.07293f , 1.07293f , 1.07293f, 1.346609f, 1.346609f, 1.346609f, - -0.295811f, -0.295811f, -0.295811f, -0.305394f, -0.305394f, -0.305394f}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - // h->printBuffer(); - // hL->printBuffer(); - // cL->printBuffer(); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {bS, sL, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {2, nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {2, nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {2, 4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0, 1, 0, 0, 0, 0}) = 0.003; + Wx({1, 2, 0, 0, 0, 0}) = -0.003; + Wr({0, 1, 0, 0, 0, 0}) = 0.006; + Wr({1, 2, 0, 0, 0, 0}) = -0.006; + b({0, 1, 0, 0}) = 0.5; + b({1, 2, 0, 0}) = -0.5; + hI({0, 1, 0, 0, 0, 0}) = 1; + hI({1, 2, 0, 0, 0, 0}) = -1; + cI({0, 1, 0, 0, 0, 0}) = 2; + cI({1, 2, 0, 0, 0, 0}) = -2; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH( + 'c', {bS, sL, 2 * nOut}, + {0.577661f, 0.577661f, 0.577661f, -0.107659f, -0.107659f, -0.107659f, + 0.548099f, 0.548099f, 0.548099f, -0.113406f, -0.113406f, -0.113406f, + 0.526881f, 0.526881f, 0.526881f, -0.12883f, -0.12883f, -0.12883f, + 0.515882f, 0.515882f, 0.515882f, -0.16868f, -0.16868f, -0.16868f, + 0.51409f, 0.51409f, 0.51409f, -0.255185f, -0.255185f, -0.255185f, + 0.614599f, 0.614599f, 0.614599f, -0.102739f, -0.102739f, -0.102739f, + 0.599572f, 0.599572f, 0.599572f, -0.105802f, -0.105802f, -0.105802f, + 0.591089f, 0.591089f, 0.591089f, -0.116681f, -0.116681f, -0.116681f, + 0.588694f, 0.588694f, 0.588694f, -0.149201f, -0.149201f, -0.149201f, + 0.591492f, 0.591492f, 0.591492f, -0.228917f, -0.228917f, -0.228917f}, + sd::DataType::FLOAT32); + + NDArray expHL( + 'c', {2, bS, nOut}, + {0.51409f, 0.51409f, 0.51409f, 0.591492f, 0.591492f, 0.591492f, + -0.107659f, -0.107659f, -0.107659f, -0.102739f, -0.102739f, -0.102739f}, + sd::DataType::FLOAT32); + NDArray expCL( + 'c', {2, bS, nOut}, + {1.07293f, 1.07293f, 1.07293f, 1.346609f, 1.346609f, 1.346609f, + -0.295811f, -0.295811f, -0.295811f, -0.305394f, -0.305394f, -0.305394f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + // h->printBuffer(); + // hL->printBuffer(); + // cL->printBuffer(); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_6) { - - const int sL = 5; - const int bS = 2; - const int nIn = 4; - const int nOut = 3; - - // input arguments - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 2; // bidirectional sum - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = false; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {2,nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {2,nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {2,4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003f; - Wx({1,2, 0,0, 0,0}) = -0.003f; - Wr({0,1, 0,0, 0,0}) = 0.006f; - Wr({1,2, 0,0, 0,0}) = -0.006f; - b({0,1, 0,0}) = 0.5f; - b({1,2, 0,0}) = -0.5f; - hI({0,1, 0,0, 0,0}) = 1; - hI({1,2, 0,0, 0,0}) = -1; - cI({0,1, 0,0, 0,0}) = 2; - cI({1,2, 0,0, 0,0}) = -2; - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, nOut}, { - 0.470019f, 0.470019f, 0.470019f, 0.478352f, 0.478352f, 0.478352f, 0.444871f, 0.444871f, 0.444871f, 0.457060f, - 0.457060f, 0.457060f, 0.424090f, 0.424090f, 0.424090f, 0.439778f, 0.439778f, 0.439778f, 0.394491f, 0.394491f, - 0.394491f, 0.412995f, 0.412995f, 0.412995f, 0.329613f, 0.329613f, 0.329613f, 0.349760f, 0.349760f, 0.349760f}, sd::DataType::FLOAT32); - - NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, - -0.107642f, -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, - sd::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, - -0.295768f, -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, - sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 2; // bidirectional sum + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {2, nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {2, nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {2, 4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0, 1, 0, 0, 0, 0}) = 0.003f; + Wx({1, 2, 0, 0, 0, 0}) = -0.003f; + Wr({0, 1, 0, 0, 0, 0}) = 0.006f; + Wr({1, 2, 0, 0, 0, 0}) = -0.006f; + b({0, 1, 0, 0}) = 0.5f; + b({1, 2, 0, 0}) = -0.5f; + hI({0, 1, 0, 0, 0, 0}) = 1; + hI({1, 2, 0, 0, 0, 0}) = -1; + cI({0, 1, 0, 0, 0, 0}) = 2; + cI({1, 2, 0, 0, 0, 0}) = -2; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH( + 'c', {sL, bS, nOut}, + {0.470019f, 0.470019f, 0.470019f, 0.478352f, 0.478352f, 0.478352f, + 0.444871f, 0.444871f, 0.444871f, 0.457060f, 0.457060f, 0.457060f, + 0.424090f, 0.424090f, 0.424090f, 0.439778f, 0.439778f, 0.439778f, + 0.394491f, 0.394491f, 0.394491f, 0.412995f, 0.412995f, 0.412995f, + 0.329613f, 0.329613f, 0.329613f, 0.349760f, 0.349760f, 0.349760f}, + sd::DataType::FLOAT32); + + NDArray expHL( + 'c', {2, bS, nOut}, + {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, + -0.107642f, -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, + sd::DataType::FLOAT32); + NDArray expCL( + 'c', {2, bS, nOut}, + {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, + -0.295768f, -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_7) { - #ifndef HAVE_MKLDNN - - const int sL = 5; - const int bS = 2; - const int nIn = 4; - const int nOut = 3; - - // input arguments - - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 0; // forward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx = 0.003; - Wr = 0.006; - b = 0.5; - hI = 1.; - cI = 2.; - Wp = -0.05; - - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, nOut}, {0.55533 , 0.55533 , 0.55533 , 0.562925, 0.562925, 0.562925, 0.531795, 0.531795, 0.531795, 0.542556, - 0.542556, 0.542556, 0.521466, 0.521466, 0.521466, 0.534638, 0.534638, 0.534638, 0.524805, 0.524805, - 0.524805, 0.539187, 0.539187, 0.539187, 0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923}, sd::DataType::FLOAT32); - - NDArray expHL('c', {bS, nOut}, {0.538309, 0.538309, 0.538309,0.552923, 0.552923, 0.552923}, sd::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {1.147089, 1.147089, 1.147089,1.197228, 1.197228, 1.197228}, sd::DataType::FLOAT32); +#ifndef HAVE_MKLDNN + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3 * nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, + cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, + hasInitC, hasPH, retFullSeq, + retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, + {0.55533, 0.55533, 0.55533, 0.562925, 0.562925, 0.562925, + 0.531795, 0.531795, 0.531795, 0.542556, 0.542556, 0.542556, + 0.521466, 0.521466, 0.521466, 0.534638, 0.534638, 0.534638, + 0.524805, 0.524805, 0.524805, 0.539187, 0.539187, 0.539187, + 0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923}, + sd::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, + {0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923}, + sd::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, + {1.147089, 1.147089, 1.147089, 1.197228, 1.197228, 1.197228}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = + op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - - - #endif +#endif } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_8) { - #ifndef HAVE_MKLDNN - - const int sL = 5; - const int bS = 2; - const int nIn = 4; - const int nOut = 3; - - // input arguments - - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 1; // backward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output +#ifndef HAVE_MKLDNN + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 1; // backward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 1.; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3 * nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, + cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, + hasInitC, hasPH, retFullSeq, + retLastH, retLastC}; + + NDArray expH( + 'c', {sL, bS, nOut}, + {0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f, + 0.463602f, 0.463602f, 0.463602f, 0.474674f, 0.474674f, 0.474674f, + 0.484039f, 0.484039f, 0.484039f, 0.490679f, 0.490679f, 0.490679f, + 0.494871f, 0.494871f, 0.494871f, 0.499028f, 0.499028f, 0.499028f, + 0.504649f, 0.504649f, 0.504649f, 0.508719f, 0.508719f, 0.508719f}, + sd::DataType::FLOAT32); + + NDArray expHL( + 'c', {bS, nOut}, + {0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f}, + sd::DataType::FLOAT32); + NDArray expCL( + 'c', {bS, nOut}, + {0.879804f, 0.879804f, 0.879804f, 0.914666f, 0.914666f, 0.914666f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = + op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 1.; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx = 0.003; - Wr = 0.006; - b = 0.5; - hI = 1.; - cI = 2.; - Wp = -0.05; - - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, nOut}, { - 0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f, 0.463602f, 0.463602f, 0.463602f, 0.474674f, 0.474674f, 0.474674f, - 0.484039f, 0.484039f, 0.484039f, 0.490679f, 0.490679f, 0.490679f, 0.494871f, 0.494871f, 0.494871f, 0.499028f, 0.499028f, 0.499028f, - 0.504649f, 0.504649f, 0.504649f, 0.508719f, 0.508719f, 0.508719f}, sd::DataType::FLOAT32); - - NDArray expHL('c', {bS, nOut}, {0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f}, sd::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {0.879804f, 0.879804f, 0.879804f, 0.914666f, 0.914666f, 0.914666f}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - - - #endif +#endif } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_9) { - #ifndef HAVE_MKLDNN - - const int sL = 5; - const int bS = 2; - const int nIn = 4; - const int nOut = 3; - - // input arguments - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 3; // bidirectional concat - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {2,nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {2,nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {2,4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - NDArray Wp('c', {2,3*nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003; - Wx({1,2, 0,0, 0,0}) = -0.003; - Wr({0,1, 0,0, 0,0}) = 0.006; - Wr({1,2, 0,0, 0,0}) = -0.006; - b({0,1, 0,0}) = 0.5; - b({1,2, 0,0}) = -0.5; - hI({0,1, 0,0, 0,0}) = 1; - hI({1,2, 0,0, 0,0}) = -1; - cI({0,1, 0,0, 0,0}) = 2; - cI({1,2, 0,0, 0,0}) = -2; - Wp({0,1, 0,0}) = -0.05; - Wp({1,2, 0,0}) = 0.05; - - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, 2*nOut}, { - 0.55533f, 0.55533f, 0.55533f, -0.104502f, -0.104502f, -0.104502f, 0.562925f, 0.562925f, 0.562925f, -0.103843f, -0.103843f, -0.103843f, - 0.531795f, 0.531795f, 0.531795f, -0.107456f, -0.107456f, -0.107456f, 0.542556f, 0.542556f, 0.542556f, -0.106139f, -0.106139f, -0.106139f, - 0.521466f, 0.521466f, 0.521466f, -0.11681f, -0.11681f, -0.11681f, 0.534638f, 0.534638f, 0.534638f, -0.11458f, -0.11458f, -0.11458f, - 0.524805f, 0.524805f, 0.524805f, -0.145177f, -0.145177f, -0.145177f, 0.539187f, 0.539187f, 0.539187f, -0.14157f, -0.14157f, -0.14157f, - 0.538309f, 0.538309f, 0.538309f, -0.218056f, -0.218056f, -0.218056f, 0.552923f, 0.552923f, 0.552923f, -0.213068f, -0.213068f, -0.213068f}, sd::DataType::FLOAT32); - - NDArray expHL('c', {2,bS, nOut}, {0.538309f, 0.538309f, 0.538309f, 0.552923f, 0.552923f, 0.552923f, -0.104502f, -0.104502f, -0.104502f, - -0.103843f, -0.103843f, -0.103843f}, sd::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.147089f, 1.147089f, 1.147089f, 1.197228f, 1.197228f, 1.197228f, -0.289425f, -0.289425f, -0.289425f, - -0.292174f, -0.292174f, -0.292174f}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - - - #endif +#ifndef HAVE_MKLDNN + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {2, nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {2, nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {2, 4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {2, 3 * nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0, 1, 0, 0, 0, 0}) = 0.003; + Wx({1, 2, 0, 0, 0, 0}) = -0.003; + Wr({0, 1, 0, 0, 0, 0}) = 0.006; + Wr({1, 2, 0, 0, 0, 0}) = -0.006; + b({0, 1, 0, 0}) = 0.5; + b({1, 2, 0, 0}) = -0.5; + hI({0, 1, 0, 0, 0, 0}) = 1; + hI({1, 2, 0, 0, 0, 0}) = -1; + cI({0, 1, 0, 0, 0, 0}) = 2; + cI({1, 2, 0, 0, 0, 0}) = -2; + Wp({0, 1, 0, 0}) = -0.05; + Wp({1, 2, 0, 0}) = 0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, + cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, + hasInitC, hasPH, retFullSeq, + retLastH, retLastC}; + + NDArray expH( + 'c', {sL, bS, 2 * nOut}, + {0.55533f, 0.55533f, 0.55533f, -0.104502f, -0.104502f, -0.104502f, + 0.562925f, 0.562925f, 0.562925f, -0.103843f, -0.103843f, -0.103843f, + 0.531795f, 0.531795f, 0.531795f, -0.107456f, -0.107456f, -0.107456f, + 0.542556f, 0.542556f, 0.542556f, -0.106139f, -0.106139f, -0.106139f, + 0.521466f, 0.521466f, 0.521466f, -0.11681f, -0.11681f, -0.11681f, + 0.534638f, 0.534638f, 0.534638f, -0.11458f, -0.11458f, -0.11458f, + 0.524805f, 0.524805f, 0.524805f, -0.145177f, -0.145177f, -0.145177f, + 0.539187f, 0.539187f, 0.539187f, -0.14157f, -0.14157f, -0.14157f, + 0.538309f, 0.538309f, 0.538309f, -0.218056f, -0.218056f, -0.218056f, + 0.552923f, 0.552923f, 0.552923f, -0.213068f, -0.213068f, -0.213068f}, + sd::DataType::FLOAT32); + + NDArray expHL( + 'c', {2, bS, nOut}, + {0.538309f, 0.538309f, 0.538309f, 0.552923f, 0.552923f, 0.552923f, + -0.104502f, -0.104502f, -0.104502f, -0.103843f, -0.103843f, -0.103843f}, + sd::DataType::FLOAT32); + NDArray expCL( + 'c', {2, bS, nOut}, + {1.147089f, 1.147089f, 1.147089f, 1.197228f, 1.197228f, 1.197228f, + -0.289425f, -0.289425f, -0.289425f, -0.292174f, -0.292174f, -0.292174f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = + op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + +#endif } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_10) { - #ifndef HAVE_MKLDNN - - const int sL = 6; - const int bS = 5; - const int nIn = 4; - const int nOut = 3; - - // input arguments - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 0; // forward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = true; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray seqLen('c', {bS}, {0,1,2,3,5}, sd::DataType::FLOAT32); - NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx = 0.003; - Wr = 0.006; - b = 0.5; - hI = 1.; - cI = 2.; - Wp = -0.05; - - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, nOut}, { - 0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.570404f, 0.570404f, 0.570404f, 0.57777f, - 0.57777f, 0.57777f, 0.585023f, 0.585023f, 0.585023f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.576568f, 0.576568f, 0.576568f, 0.586163f, 0.586163f, 0.586163f, 0.595462f, 0.595462f, 0.595462f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.611224f, - 0.611224f, 0.611224f, 0.621298f, 0.621298f, 0.621298f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.655858f, 0.655858f, 0.655858f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.692315f, 0.692315f, 0.692315f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}, - sd::DataType::FLOAT32); - - NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f}, sd::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - - - #endif +#ifndef HAVE_MKLDNN + + const int sL = 6; + const int bS = 5; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray seqLen('c', {bS}, {0, 1, 2, 3, 5}, sd::DataType::FLOAT32); + NDArray Wp('c', {3 * nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, + cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, + hasInitC, hasPH, retFullSeq, + retLastH, retLastC}; + + NDArray expH( + 'c', {sL, bS, nOut}, + {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, + 0.570404f, 0.570404f, 0.570404f, 0.57777f, 0.57777f, 0.57777f, + 0.585023f, 0.585023f, 0.585023f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.576568f, 0.576568f, 0.576568f, + 0.586163f, 0.586163f, 0.586163f, 0.595462f, 0.595462f, 0.595462f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.611224f, 0.611224f, 0.611224f, + 0.621298f, 0.621298f, 0.621298f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.655858f, 0.655858f, 0.655858f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.692315f, 0.692315f, 0.692315f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}, + sd::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, + {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, + 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, + 0.692315f, 0.692315f, 0.692315f}, + sd::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, + {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, + 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, + 1.767702f, 1.767702f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, + iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + +#endif } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_11) { - #ifndef HAVE_MKLDNN - - const int sL = 6; - const int bS = 5; - const int nIn = 4; - const int nOut = 3; - - // input arguments - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 1; // backward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = true; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping +#ifndef HAVE_MKLDNN + + const int sL = 6; + const int bS = 5; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 1; // backward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray seqLen('c', {bS}, {0, 1, 2, 3, 5}, sd::DataType::FLOAT32); + NDArray Wp('c', {3 * nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003f; + Wr = 0.006f; + b = 0.5f; + hI = 1.f; + cI = 2.f; + Wp = -0.05f; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, + cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, + hasInitC, hasPH, retFullSeq, + retLastH, retLastC}; + + NDArray expH( + 'c', {sL, bS, nOut}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.61209f, 0.61209f, 0.61209f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.652042f, 0.652042f, 0.652042f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.677708f, 0.677708f, 0.677708f, 0.684177f, 0.684177f, 0.684177f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.699627f, 0.699627f, 0.699627f, 0.705371f, 0.705371f, 0.705371f, + 0.710989f, 0.710989f, 0.710989f, 0., 0., 0., + 0.719014, 0.719014, 0.719014, 0.724087, 0.724087f, 0.724087f, + 0.729084f, 0.729084f, 0.729084f, 0.734004f, 0.734004f, 0.734004f}, + sd::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, + {0.f, 0.f, 0.f, 0.719014f, 0.719014f, 0.719014f, 0.699627f, + 0.699627f, 0.699627f, 0.677708f, 0.677708f, 0.677708f, + 0.61209f, 0.61209f, 0.61209f}, + sd::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, + {0.f, 0.f, 0.f, 2.092814f, 2.092814f, 2.092814f, 2.08832f, + 2.08832f, 2.08832f, 2.009851f, 2.009851f, 2.009851f, 1.646034f, + 1.646034f, 1.646034f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, + iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray seqLen('c', {bS}, {0,1,2,3,5}, sd::DataType::FLOAT32); - NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx = 0.003f; - Wr = 0.006f; - b = 0.5f; - hI = 1.f; - cI = 2.f; - Wp = -0.05f; - - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, nOut}, { - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.61209f, - 0.61209f, 0.61209f,0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.652042f, 0.652042f, 0.652042f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.677708f, 0.677708f, 0.677708f, 0.684177f, 0.684177f, 0.684177f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.699627f, 0.699627f, - 0.699627f, 0.705371f, 0.705371f, 0.705371f, 0.710989f, 0.710989f, 0.710989f, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087, - 0.724087f, 0.724087f, 0.729084f, 0.729084f, 0.729084f, 0.734004f, 0.734004f, 0.734004f }, sd::DataType::FLOAT32); - - NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.719014f, 0.719014f, 0.719014f, 0.699627f, 0.699627f, 0.699627f, 0.677708f, 0.677708f, 0.677708f, 0.61209f, 0.61209f, 0.61209f}, sd::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 2.092814f, 2.092814f, 2.092814f, 2.08832f, 2.08832f, 2.08832f, 2.009851f, 2.009851f, 2.009851f, 1.646034f, 1.646034f, 1.646034f}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - - - #endif +#endif } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_12) { - #ifndef HAVE_MKLDNN - - const int sL = 6; - const int bS = 5; - const int nIn = 4; - const int nOut = 3; - - // input arguments - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 3; // bidirectional concat - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = true; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step - - const double cellClip = 0; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {2,nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {2,nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {2,4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {2,bS, nOut}, sd::DataType::FLOAT32); - NDArray seqLen('c', {bS}, {0,1,2,3,5}, sd::DataType::FLOAT32); - NDArray Wp('c', {2,3*nOut}, sd::DataType::FLOAT32); - - x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003f; - Wx({1,2, 0,0, 0,0}) = -0.003f; - Wr({0,1, 0,0, 0,0}) = 0.006f; - Wr({1,2, 0,0, 0,0}) = -0.006f; - b({0,1, 0,0}) = 0.5f; - b({1,2, 0,0}) = -0.5f; - hI({0,1, 0,0, 0,0}) = 1; - hI({1,2, 0,0, 0,0}) = -1; - cI({0,1, 0,0, 0,0}) = 2; - cI({1,2, 0,0, 0,0}) = -2; - Wp({0,1, 0,0}) = -0.05f; - Wp({1,2, 0,0}) = 0.05f; - - std::initializer_list tArgs = {cellClip}; - std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - NDArray expH('c', {sL, bS, 2*nOut}, {0., 0., 0., 0., 0., 0., 0.562925, 0.562925, 0.562925, -0.25361 , -0.25361 , -0.25361 , 0.570404, 0.570404, 0.570404, -0.157103, - -0.157103, -0.157103, 0.57777 , 0.57777 , 0.57777 , -0.116502, -0.116502, -0.116502,0.585023, 0.585023, 0.585023, -0.100025, - -0.100025, -0.100025, 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, -0.223072, -0.223072, -0.223072, - 0.586163, 0.586163, 0.586163, -0.135714, -0.135714, -0.135714,0.595462, 0.595462, 0.595462, -0.094438, -0.094438, -0.094438, - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.611224, 0.611224, 0.611224, -0.193473, -0.193473, -0.193473, - 0.621298, 0.621298, 0.621298, -0.090626, -0.090626, -0.090626, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0.655858, 0.655858, 0.655858, -0.098015, -0.098015, -0.098015, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, sd::DataType::FLOAT32); - - NDArray expHL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f, - 0.f, 0.f, 0.f, -0.25361f, -0.25361f, -0.25361f, -0.157103f, -0.157103f, -0.157103f, -0.116502f, -0.116502f, -0.116502f, -0.100025f, -0.100025f, -0.100025f}, sd::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f, - 0.f, 0.f, 0.f, -0.86636f, -0.86636f, -0.86636f, -0.470245f, -0.470245f, -0.470245f, -0.341856f, -0.341856f, -0.341856f, -0.294986f, -0.294986f, -0.294986f}, sd::DataType::FLOAT32); - - sd::ops::lstmLayer op; - auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hL = results.at(1); - auto cL = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expHL.isSameShape(hL)); - ASSERT_TRUE(expHL.equalsTo(hL)); - - ASSERT_TRUE(expCL.isSameShape(cL)); - ASSERT_TRUE(expCL.equalsTo(cL)); - - #endif +#ifndef HAVE_MKLDNN + + const int sL = 6; + const int bS = 5; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {2, nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {2, nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {2, 4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {2, bS, nOut}, sd::DataType::FLOAT32); + NDArray seqLen('c', {bS}, {0, 1, 2, 3, 5}, sd::DataType::FLOAT32); + NDArray Wp('c', {2, 3 * nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0, 1, 0, 0, 0, 0}) = 0.003f; + Wx({1, 2, 0, 0, 0, 0}) = -0.003f; + Wr({0, 1, 0, 0, 0, 0}) = 0.006f; + Wr({1, 2, 0, 0, 0, 0}) = -0.006f; + b({0, 1, 0, 0}) = 0.5f; + b({1, 2, 0, 0}) = -0.5f; + hI({0, 1, 0, 0, 0, 0}) = 1; + hI({1, 2, 0, 0, 0, 0}) = -1; + cI({0, 1, 0, 0, 0, 0}) = 2; + cI({1, 2, 0, 0, 0, 0}) = -2; + Wp({0, 1, 0, 0}) = -0.05f; + Wp({1, 2, 0, 0}) = 0.05f; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, + cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, + hasInitC, hasPH, retFullSeq, + retLastH, retLastC}; + + NDArray expH('c', {sL, bS, 2 * nOut}, + {0., 0., 0., 0., 0., 0., + 0.562925, 0.562925, 0.562925, -0.25361, -0.25361, -0.25361, + 0.570404, 0.570404, 0.570404, -0.157103, -0.157103, -0.157103, + 0.57777, 0.57777, 0.57777, -0.116502, -0.116502, -0.116502, + 0.585023, 0.585023, 0.585023, -0.100025, -0.100025, -0.100025, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.576568, 0.576568, 0.576568, -0.223072, -0.223072, -0.223072, + 0.586163, 0.586163, 0.586163, -0.135714, -0.135714, -0.135714, + 0.595462, 0.595462, 0.595462, -0.094438, -0.094438, -0.094438, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.611224, 0.611224, 0.611224, -0.193473, -0.193473, -0.193473, + 0.621298, 0.621298, 0.621298, -0.090626, -0.090626, -0.090626, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.655858, 0.655858, 0.655858, -0.098015, -0.098015, -0.098015, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}, + sd::DataType::FLOAT32); + + NDArray expHL( + 'c', {2, bS, nOut}, + {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, + 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, + 0.692315f, 0.692315f, 0.692315f, 0.f, 0.f, 0.f, + -0.25361f, -0.25361f, -0.25361f, -0.157103f, -0.157103f, -0.157103f, + -0.116502f, -0.116502f, -0.116502f, -0.100025f, -0.100025f, -0.100025f}, + sd::DataType::FLOAT32); + NDArray expCL( + 'c', {2, bS, nOut}, + {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, + 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, + 1.767702f, 1.767702f, 1.767702f, 0.f, 0.f, 0.f, + -0.86636f, -0.86636f, -0.86636f, -0.470245f, -0.470245f, -0.470245f, + -0.341856f, -0.341856f, -0.341856f, -0.294986f, -0.294986f, -0.294986f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, + iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + +#endif } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) { - - const int sL = 3; - const int bS = 2; - const int nIn = 2; - const int nOut = 3; - - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 0; // forward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = true; // output at last time step - const auto retLastC = true; // cells state at last time step - - const double cellClip = 0.5; // clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); - NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); - NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); - NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE); - NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); - - x.linspace(-2,0.1); - hI.linspace(-1.5,0.1); - cI.linspace(0.7,-0.1); - Wx.linspace(1,-0.1); - Wr.linspace(-1,0.1); - Wp.linspace(0.2,0.2); - b.linspace(1,-0.15); - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); - - sd::ops::lstmLayer opFF; - sd::ops::lstmLayer_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 3; + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4 * nOut}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3 * nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2, 0.1); + hI.linspace(-1.5, 0.1); + cI.linspace(0.7, -0.1); + Wx.linspace(1, -0.1); + Wr.linspace(-1, 0.1); + Wp.linspace(0.2, 0.2); + b.linspace(1, -0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, + iArgs, bArgs); + const OpArgsHolder argsHolderBP( + {&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, + bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_bp_2) { - - const int sL = 3; - const int bS = 2; - const int nIn = 2; - const int nOut = 3; - - const int dataFormat = 1; // [bS,sL,nIn] - const int directionMode = 0; // forward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = false; // output at last time step - const auto retLastC = true; // cells state at last time step - - const double cellClip = 0.5; // clipping - - NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); - NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); - NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); - NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE); - NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); - - x.linspace(-2,0.1); - hI.linspace(-1.5,0.1); - cI.linspace(0.7,-0.1); - Wx.linspace(1,-0.1); - Wr.linspace(-1,0.1); - Wp.linspace(0.2,0.2); - b.linspace(1,-0.15); - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); - - sd::ops::lstmLayer opFF; - sd::ops::lstmLayer_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN); - - ASSERT_TRUE(isGradCorrect); + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 3; + + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 0; // forward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = + true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = false; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4 * nOut}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3 * nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2, 0.1); + hI.linspace(-1.5, 0.1); + cI.linspace(0.7, -0.1); + Wx.linspace(1, -0.1); + Wr.linspace(-1, 0.1); + Wp.linspace(0.2, 0.2); + b.linspace(1, -0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, + iArgs, bArgs); + const OpArgsHolder argsHolderBP( + {&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad( + opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, + GradCheck::LossFunc::MEAN); + + ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) { - - const int sL = 4; - const int bS = 3; - const int nIn = 3; - const int nOut = 2; - - const int dataFormat = 2; // [bS, nIn, sL] - const int directionMode = 0; // forward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = true; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = true; // output at last time step - const auto retLastC = true; // cells state at last time step - - const double cellClip = 0.5; // clipping - - NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); - NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); - NDArray seqLen('c', {bS}, {2,0,4}, sd::DataType::DOUBLE); - NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); - NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); - NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); - - x.linspace(-2,0.1); - hI.linspace(-1.5,0.1); - cI.linspace(0.7,-0.1); - Wx.linspace(1,-0.1); - Wr.linspace(-1,0.1); - Wp.linspace(0.2,0.2); - b.linspace(1,-0.15); - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); - - sd::ops::lstmLayer opFF; - sd::ops::lstmLayer_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); - - ASSERT_TRUE(isGradCorrect); + const int sL = 4; + const int bS = 3; + const int nIn = 3; + const int nOut = 2; + + const int dataFormat = 2; // [bS, nIn, sL] + const int directionMode = 0; // forward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4 * nOut}, sd::DataType::DOUBLE); + NDArray seqLen('c', {bS}, {2, 0, 4}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3 * nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2, 0.1); + hI.linspace(-1.5, 0.1); + cI.linspace(0.7, -0.1); + Wx.linspace(1, -0.1); + Wr.linspace(-1, 0.1); + Wp.linspace(0.2, 0.2); + b.linspace(1, -0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, + tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP( + {&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, + iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, + {true, true, true, true, false, true, true, true}); + + ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) { - - const int sL = 3; - const int bS = 2; - const int nIn = 2; - const int nOut = 3; - - const int dataFormat = 1; // [bS,sL,nIn] - const int directionMode = 1; // backward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = true; // output at last time step - const auto retLastC = true; // cells state at last time step - - const double cellClip = 0.5; // clipping - - NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); - NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); - NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); - NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE); - NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); - - x.linspace(-2,0.1); - hI.linspace(-1.5,0.1); - cI.linspace(0.7,-0.1); - Wx.linspace(1,-0.1); - Wr.linspace(-1,0.1); - Wp.linspace(0.2,0.2); - b.linspace(1,-0.15); - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); - - sd::ops::lstmLayer opFF; - sd::ops::lstmLayer_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 3; + + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 1; // backward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4 * nOut}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3 * nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2, 0.1); + hI.linspace(-1.5, 0.1); + cI.linspace(0.7, -0.1); + Wx.linspace(1, -0.1); + Wr.linspace(-1, 0.1); + Wp.linspace(0.2, 0.2); + b.linspace(1, -0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, + iArgs, bArgs); + const OpArgsHolder argsHolderBP( + {&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, + bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) { - - const int sL = 3; - const int bS = 2; - const int nIn = 2; - const int nOut = 2; - - const int dataFormat = 2; // [bS, nIn, sL] - const int directionMode = 1; // backward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = true; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = true; // output at last time step - const auto retLastC = true; // cells state at last time step - - const double cellClip = 0.5; // clipping - - NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); - NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); - NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE); - NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); - NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); - NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); - - x.linspace(-2,0.1); - hI.linspace(-1.5,0.1); - cI.linspace(0.7,-0.1); - Wx.linspace(1,-0.1); - Wr.linspace(-1,0.1); - Wp.linspace(0.2,0.2); - b.linspace(1,-0.15); - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); - - sd::ops::lstmLayer opFF; - sd::ops::lstmLayer_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); - - ASSERT_TRUE(isGradCorrect); + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 2; + + const int dataFormat = 2; // [bS, nIn, sL] + const int directionMode = 1; // backward + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4 * nOut}, sd::DataType::DOUBLE); + NDArray seqLen('c', {bS}, {0, 2}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3 * nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2, 0.1); + hI.linspace(-1.5, 0.1); + cI.linspace(0.7, -0.1); + Wx.linspace(1, -0.1); + Wr.linspace(-1, 0.1); + Wp.linspace(0.2, 0.2); + b.linspace(1, -0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, + tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP( + {&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, + iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, + {true, true, true, true, false, true, true, true}); + + ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) { - - const int sL = 3; - const int bS = 2; - const int nIn = 2; - const int nOut = 2; - - const int dataFormat = 2; // [bS, nIn, sL] - const int directionMode = 2; // bidirectional sum - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = true; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = true; // output at last time step - const auto retLastC = true; // cells state at last time step - - const double cellClip = 0.5; // clipping - - NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); - NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE); - NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE); - NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE); - NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE); - NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE); - NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); - NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); - NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); - NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); - NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); - - x.linspace(-2,0.1); - hI.linspace(-1.5,0.1); - cI.linspace(0.7,-0.1); - Wx.linspace(1,-0.1); - Wr.linspace(-1,0.1); - Wp.linspace(0.2,0.2); - b.linspace(1,-0.15); - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); - - sd::ops::lstmLayer opFF; - sd::ops::lstmLayer_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); - - ASSERT_TRUE(isGradCorrect); + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 2; + + const int dataFormat = 2; // [bS, nIn, sL] + const int directionMode = 2; // bidirectional sum + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); + NDArray Wx('c', {2, nIn, 4 * nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {2, nOut, 4 * nOut}, sd::DataType::DOUBLE); + NDArray b('c', {2, 4 * nOut}, sd::DataType::DOUBLE); + NDArray seqLen('c', {bS}, {0, 2}, sd::DataType::DOUBLE); + NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {2, 3 * nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2, 0.1); + hI.linspace(-1.5, 0.1); + cI.linspace(0.7, -0.1); + Wx.linspace(1, -0.1); + Wr.linspace(-1, 0.1); + Wp.linspace(0.2, 0.2); + b.linspace(1, -0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, + tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP( + {&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, + iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, + // &dLdh, &dLdhL}, tArgs, iArgs, bArgs); const OpArgsHolder argsHolderBP({&x, + // &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, + // &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); const OpArgsHolder argsHolderBP({&x, + // &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); const + // OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, + // &dLdhL}, tArgs, iArgs, bArgs); const OpArgsHolder argsHolderBP({&x, &Wx, + // &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, + {true, true, true, true, false, true, true, true}); + + ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) { - - const int sL = 3; - const int bS = 2; - const int nIn = 2; - const int nOut = 2; - - const int dataFormat = 1; // [bS,sL,nIn] - const int directionMode = 3; // bidirectional concat - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = true; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = true; // output at last time step - const auto retLastC = true; // cells state at last time step - - const double cellClip = 0.5; // clipping - - NDArray x('c', {bS,sL,nIn}, sd::DataType::DOUBLE); - NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE); - NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE); - NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE); - NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE); - NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE); - NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); - NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); - NDArray dLdh('c', {bS,sL,2*nOut}, sd::DataType::DOUBLE); - NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); - NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); - - x.linspace(-2,0.1); - hI.linspace(-1.5,0.1); - cI.linspace(0.7,-0.1); - Wx.linspace(1,-0.1); - Wr.linspace(-1,0.1); - Wp.linspace(0.2,0.2); - b.linspace(1,-0.15); - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); - - sd::ops::lstmLayer opFF; - sd::ops::lstmLayer_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); - - ASSERT_TRUE(isGradCorrect); + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 2; + + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {2, nIn, 4 * nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {2, nOut, 4 * nOut}, sd::DataType::DOUBLE); + NDArray b('c', {2, 4 * nOut}, sd::DataType::DOUBLE); + NDArray seqLen('c', {bS}, {0, 2}, sd::DataType::DOUBLE); + NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {2, 3 * nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, sL, 2 * nOut}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2, 0.1); + hI.linspace(-1.5, 0.1); + cI.linspace(0.7, -0.1); + Wx.linspace(1, -0.1); + Wr.linspace(-1, 0.1); + Wp.linspace(0.2, 0.2); + b.linspace(1, -0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, + tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP( + {&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, + iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, + // &dLdh, &dLdhL}, tArgs, iArgs, bArgs); const OpArgsHolder argsHolderBP({&x, + // &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, + // &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); const OpArgsHolder argsHolderBP({&x, + // &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); const + // OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, + // &dLdhL}, tArgs, iArgs, bArgs); const OpArgsHolder argsHolderBP({&x, &Wx, + // &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, + {true, true, true, true, false, true, true, true}); + + ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) { - - const int sL = 3; - const int bS = 2; - const int nIn = 2; - const int nOut = 2; - - const int dataFormat = 3; // [sL, bS, nIn] - const int directionMode = 4; // bidirectional extra output dim - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = true; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = true; // output at last time step - const auto retLastC = true; // cells state at last time step - - const double cellClip = 0.5; // clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); - NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE); - NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE); - NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE); - NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE); - NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE); - NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); - NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); - NDArray dLdh('c', {sL, 2, bS, nOut}, sd::DataType::DOUBLE); - NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); - NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); - - x.linspace(-2,0.1); - hI.linspace(-1.5,0.1); - cI.linspace(0.7,-0.1); - Wx.linspace(1,-0.1); - Wr.linspace(-1,0.1); - Wp.linspace(0.2,0.2); - b.linspace(1,-0.15); - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); - - sd::ops::lstmLayer opFF; - sd::ops::lstmLayer_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); - - ASSERT_TRUE(isGradCorrect); + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 2; + + const int dataFormat = 3; // [sL, bS, nIn] + const int directionMode = 4; // bidirectional extra output dim + const int gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {2, nIn, 4 * nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {2, nOut, 4 * nOut}, sd::DataType::DOUBLE); + NDArray b('c', {2, 4 * nOut}, sd::DataType::DOUBLE); + NDArray seqLen('c', {bS}, {0, 2}, sd::DataType::DOUBLE); + NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {2, 3 * nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {sL, 2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2, 0.1); + hI.linspace(-1.5, 0.1); + cI.linspace(0.7, -0.1); + Wx.linspace(1, -0.1); + Wr.linspace(-1, 0.1); + Wp.linspace(0.2, 0.2); + b.linspace(1, -0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, + outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, + hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, + tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP( + {&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, + iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, + // &dLdh, &dLdhL}, tArgs, iArgs, bArgs); const OpArgsHolder argsHolderBP({&x, + // &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, + // &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); const OpArgsHolder argsHolderBP({&x, + // &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); const + // OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, + // &dLdhL}, tArgs, iArgs, bArgs); const OpArgsHolder argsHolderBP({&x, &Wx, + // &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, + {true, true, true, true, false, true, true, true}); + + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_test1) { + NDArray input('c', {2, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, {10.f, 20.f, -10.f, -20.f}, sd::DataType::FLOAT32); - NDArray input ('c', {2,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, sd::DataType::FLOAT32); - - NDArray expected('c', {2,4}, {11.61218734f, 18.52390321f, -8.67185076f, -21.28716864f, 10.93337162f, 19.14541765f, -9.26213931f, -20.71509369f}, sd::DataType::FLOAT32); - - input.linspace(0.1, 0.1); + NDArray expected('c', {2, 4}, + {11.61218734f, 18.52390321f, -8.67185076f, -21.28716864f, + 10.93337162f, 19.14541765f, -9.26213931f, -20.71509369f}, + sd::DataType::FLOAT32); - sd::ops::batchnorm op; + input.linspace(0.1, 0.1); - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); + sd::ops::batchnorm op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = + op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1}); - auto output = results.at(0); - // output->printBuffer(); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); + // output->printBuffer(); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test2) { + auto input = NDArrayFactory::create('c', {2, 3, 4}); + auto mean = NDArrayFactory::create('c', {4}); + auto variance = NDArrayFactory::create('c', {4}); + auto gamma = NDArrayFactory::create('c', {4}); + auto beta = NDArrayFactory::create('c', {4}); - auto input = NDArrayFactory::create('c', {2,3,4}); - auto mean = NDArrayFactory::create('c', {4}); - auto variance = NDArrayFactory::create('c', {4}); - auto gamma = NDArrayFactory::create('c', {4}); - auto beta = NDArrayFactory::create('c', {4}); - - auto expected = NDArrayFactory::create('c', {2,3,4}, {-0.52733537f, -0.35763144f, -0.18792751f, -0.01822358f, 0.15148035f, 0.32118428f, 0.49088821f, 0.66059214f, 0.83029607f, 1.f, 1.16970393f, 1.33940786f, - 1.50911179f, 1.67881572f, 1.84851965f, 2.01822358f, 2.18792751f, 2.35763144f, 2.52733537f, 2.6970393f, 2.86674323f, 3.03644717f, 3.2061511f, 3.37585503f}); - - input.linspace(0.1, 0.1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); - beta.assign(1.); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-0.52733537f, -0.35763144f, -0.18792751f, -0.01822358f, 0.15148035f, + 0.32118428f, 0.49088821f, 0.66059214f, 0.83029607f, 1.f, + 1.16970393f, 1.33940786f, 1.50911179f, 1.67881572f, 1.84851965f, + 2.01822358f, 2.18792751f, 2.35763144f, 2.52733537f, 2.6970393f, + 2.86674323f, 3.03644717f, 3.2061511f, 3.37585503f}); - sd::ops::batchnorm op; + input.linspace(0.1, 0.1); + mean.assign(1.); + variance.assign(0.5); + gamma.assign(1.2); + beta.assign(1.); - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); + sd::ops::batchnorm op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = + op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1}); - auto output = results.at(0); - // output->printBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = results.at(0); + // output->printBuffer(); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test3) { + auto input = NDArrayFactory::create('c', {2, 3, 4}); + auto mean = NDArrayFactory::create('c', {3}, {1.05f, 1.1f, 1.15f}); + auto variance = + NDArrayFactory::create('c', {3}, {0.5f, 0.6f, 0.7f}); + auto gamma = NDArrayFactory::create('c', {3}, {1.2f, 1.3f, 1.4f}); + auto beta = NDArrayFactory::create('c', {3}, {0.1f, 0.2f, 0.3f}); - auto input = NDArrayFactory::create('c', {2,3,4}); - auto mean = NDArrayFactory::create('c', {3}, {1.05f, 1.1f, 1.15f}); - auto variance = NDArrayFactory::create('c', {3}, {0.5f, 0.6f, 0.7f}); - auto gamma = NDArrayFactory::create('c', {3}, {1.2f, 1.3f, 1.4f}); - auto beta = NDArrayFactory::create('c', {3}, {0.1f, 0.2f, 0.3f}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-1.51218734f, -1.34248341f, -1.17277948f, -1.00307555f, -0.80696728f, + -0.6391394f, -0.47131152f, -0.30348364f, -0.11832703f, 0.04900378f, + 0.21633459f, 0.38366541f, 0.52425983f, 0.69396376f, 0.86366769f, + 1.03337162f, 1.20696728f, 1.37479516f, 1.54262304f, 1.71045092f, + 1.8896427f, 2.05697351f, 2.22430432f, 2.39163513f}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.34248341f, -1.17277948f, -1.00307555f, -0.80696728f, -0.6391394f, -0.47131152f, -0.30348364f, -0.11832703f, 0.04900378f, 0.21633459f, 0.38366541f, - 0.52425983f, 0.69396376f, 0.86366769f, 1.03337162f, 1.20696728f, 1.37479516f, 1.54262304f, 1.71045092f, 1.8896427f, 2.05697351f, 2.22430432f, 2.39163513f}); + input.linspace(0.1, 0.1); - input.linspace(0.1, 0.1); + sd::ops::batchnorm op; - sd::ops::batchnorm op; + auto results = + op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1}); - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test4) { + auto input = NDArrayFactory::create('c', {2, 3, 4}); + auto mean = NDArrayFactory::create( + 'c', {2, 1, 4}, {1.05f, 1.1f, 1.15f, 1.2f, 1.25f, 1.3f, 1.35f, 1.4f}); + auto variance = NDArrayFactory::create( + 'c', {2, 1, 4}, {0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f}); + auto gamma = NDArrayFactory::create( + 'c', {2, 1, 4}, {1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f}); + auto beta = NDArrayFactory::create( + 'c', {2, 1, 4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.66f, 0.7f, 0.8f}); - auto input = NDArrayFactory::create('c', {2,3,4}); - auto mean = NDArrayFactory::create('c', {2,1,4}, {1.05f, 1.1f, 1.15f, 1.2f, 1.25f, 1.3f, 1.35f, 1.4f}); - auto variance = NDArrayFactory::create('c', {2,1,4}, {0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f}); - auto gamma = NDArrayFactory::create('c', {2,1,4}, {1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f}); - auto beta = NDArrayFactory::create('c', {2,1,4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.66f, 0.7f, 0.8f}); - - auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.31045092f, -1.12231189f, -0.9416324f, -0.83337162f, -0.6391394f, -0.45298865f, -0.2708162f, -0.1545559f, 0.03217212f, 0.21633459f, 0.4f, - 0.58432694f, 0.82999915f, 0.95743373f, 1.14688951f, 1.25894242f, 1.50999575f, 1.64392367f, 1.84066852f, 1.93355791f, 2.18999235f, 2.33041362f, 2.53444754f}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-1.51218734f, -1.31045092f, -1.12231189f, -0.9416324f, -0.83337162f, + -0.6391394f, -0.45298865f, -0.2708162f, -0.1545559f, 0.03217212f, + 0.21633459f, 0.4f, 0.58432694f, 0.82999915f, 0.95743373f, + 1.14688951f, 1.25894242f, 1.50999575f, 1.64392367f, 1.84066852f, + 1.93355791f, 2.18999235f, 2.33041362f, 2.53444754f}); - input.linspace(0.1, 0.1); + input.linspace(0.1, 0.1); - sd::ops::batchnorm op; + sd::ops::batchnorm op; - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, + {1, 1, 0, 2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_test5) { + NDArray input('c', {2, 4, 2, 2}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, {10.f, 20.f, -10.f, -20.f}, sd::DataType::FLOAT32); - NDArray input ('c', {2,4,2,2}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, sd::DataType::FLOAT32); - - NDArray expected('c', {2,4,2,2}, { 11.612187f, 11.442483f, 11.272779f, 11.103076f, 18.990039f, 19.145418f, 19.300796f, 19.456175f, -9.557284f, -9.704856f, -9.852428f, -10.f, -20.f, - -19.856981f, -19.713963f, -19.570944f, 8.896924f, 8.727221f, 8.557517f, 8.387813f, 21.476097f, 21.631475f, 21.786854f, 21.942233f, -11.918438f, - -12.06601f, -12.213582f, -12.361154f, -17.7117f, -17.568681f, -17.425663f, -17.282644f}, sd::DataType::FLOAT32); - input.linspace(0.1, 0.1); - - sd::ops::batchnorm op; + NDArray expected( + 'c', {2, 4, 2, 2}, + {11.612187f, 11.442483f, 11.272779f, 11.103076f, 18.990039f, + 19.145418f, 19.300796f, 19.456175f, -9.557284f, -9.704856f, + -9.852428f, -10.f, -20.f, -19.856981f, -19.713963f, + -19.570944f, 8.896924f, 8.727221f, 8.557517f, 8.387813f, + 21.476097f, 21.631475f, 21.786854f, 21.942233f, -11.918438f, + -12.06601f, -12.213582f, -12.361154f, -17.7117f, -17.568681f, + -17.425663f, -17.282644f}, + sd::DataType::FLOAT32); + input.linspace(0.1, 0.1); - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1}); + sd::ops::batchnorm op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = + op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1}); - auto output = results.at(0); - // output->printBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = results.at(0); + // output->printBuffer(); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_test6) { + NDArray input('c', {2, 2, 2, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9, 1.1f}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, {10.f, 20.f, -10.f, -20.f}, sd::DataType::FLOAT32); - NDArray input ('c', {2,2,2,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5f, 0.7f, 0.9, 1.1f}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, sd::DataType::FLOAT32); + NDArray expected( + 'c', {2, 2, 2, 4}, + {11.612187f, 18.523903f, -8.671851f, -21.287169f, 10.933372f, + 19.145418f, -9.262139f, -20.715094f, 10.254556f, 19.766932f, + -9.852428f, -20.143019f, 9.57574f, 20.388447f, -10.442716f, + -19.570944f, 8.896924f, 21.009961f, -11.033005f, -18.998869f, + 8.218109f, 21.631475f, -11.623294f, -18.426794f, 7.539293f, + 22.25299f, -12.213582f, -17.854719f, 6.860477f, 22.874504f, + -12.803871f, -17.282644f}, + sd::DataType::FLOAT32); + input.linspace(0.1, 0.1); - NDArray expected('c', {2,2,2,4}, {11.612187f, 18.523903f, -8.671851f, -21.287169f, 10.933372f, 19.145418f, -9.262139f, -20.715094f, 10.254556f, 19.766932f, -9.852428f, -20.143019f, 9.57574f, - 20.388447f, -10.442716f, -19.570944f, 8.896924f, 21.009961f, -11.033005f, -18.998869f, 8.218109f, 21.631475f, -11.623294f, -18.426794f, 7.539293f, 22.25299f, - -12.213582f, -17.854719f, 6.860477f, 22.874504f, -12.803871f, -17.282644f}, sd::DataType::FLOAT32); - input.linspace(0.1, 0.1); + sd::ops::batchnorm op; - sd::ops::batchnorm op; + auto results = + op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 3}); - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_test7) { + NDArray input1('c', {3, 3, 15, 15}, sd::DataType::FLOAT32); + NDArray input2('c', {3, 15, 15, 3}, sd::DataType::FLOAT32); + input2.permutei({0, 3, 1, 2}); - NDArray input1('c', {3,3,15,15}, sd::DataType::FLOAT32); - NDArray input2('c', {3,15,15,3}, sd::DataType::FLOAT32); - input2.permutei({0,3,1,2}); + NDArray mean('c', {3}, {0., 0, 0}, sd::DataType::FLOAT32); + NDArray variance('c', {3}, {1., 1, 1}, sd::DataType::FLOAT32); + NDArray gamma('c', {3}, {1., 1, 1}, sd::DataType::FLOAT32); + NDArray beta('c', {3}, {0., 0, 0}, sd::DataType::FLOAT32); - NDArray mean ('c', {3}, {0., 0, 0}, sd::DataType::FLOAT32); - NDArray variance('c', {3}, {1., 1, 1}, sd::DataType::FLOAT32); - NDArray gamma ('c', {3}, {1., 1, 1}, sd::DataType::FLOAT32); - NDArray beta ('c', {3}, {0., 0, 0}, sd::DataType::FLOAT32); + NDArray out1('c', {3, 3, 15, 15}, sd::DataType::FLOAT32); + NDArray out2('c', {3, 3, 15, 15}, sd::DataType::FLOAT32); - NDArray out1('c', {3,3,15,15}, sd::DataType::FLOAT32); - NDArray out2('c', {3,3,15,15}, sd::DataType::FLOAT32); + input1.linspace(-1012, 1); + input2.assign(input1); - input1.linspace(-1012, 1); - input2.assign(input1); + sd::ops::batchnorm op; - sd::ops::batchnorm op; + auto res1 = op.execute({&input1, &mean, &variance, &gamma, &beta}, {&out1}, + {1e-5}, {1, 1, 1}, {}); + ASSERT_EQ(ND4J_STATUS_OK, res1); - auto res1 = op.execute({&input1, &mean, &variance, &gamma, &beta}, {&out1}, {1e-5}, {1,1,1}, {}); - ASSERT_EQ(ND4J_STATUS_OK, res1); + auto res2 = op.execute({&input2, &mean, &variance, &gamma, &beta}, {&out2}, + {1e-5}, {1, 1, 1}, {}); + ASSERT_EQ(ND4J_STATUS_OK, res2); - auto res2 = op.execute({&input2, &mean, &variance, &gamma, &beta}, {&out2}, {1e-5}, {1,1,1}, {}); - ASSERT_EQ(ND4J_STATUS_OK, res2); - - ASSERT_TRUE(out1.equalsTo(out2)); + ASSERT_TRUE(out1.equalsTo(out2)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_test8) { - - NDArray input('c', {2,3,4,5}, sd::DataType::FLOAT32); - - NDArray mean ('c', {1,3,4,5}, sd::DataType::FLOAT32); - NDArray variance('c', {1,3,4,5}, sd::DataType::FLOAT32); - NDArray gamma ('c', {1,3,4,5}, sd::DataType::FLOAT32); - NDArray beta ('c', {1,3,4,5}, sd::DataType::FLOAT32); - - NDArray expected('c', {2,3,4,5}, {-105.019394, -103.322357, -101.625313, -99.928276, -98.231239, -96.534195, -94.837158, -93.140121, -91.443077, -89.746040, -88.049004, -86.351959, -84.654922, - -82.957886, -81.260841, -79.563805, -77.866768, -76.169724, -74.472687, -72.775650, -71.078606, -69.381569, -67.684532, -65.987488, -64.290451, -62.593414, - -60.896374, -59.199333, -57.502296, -55.805256, -54.108215, -52.411179, -50.714138, -49.017097, -47.320061, -45.623020, -43.925980, -42.228943, -40.531902, - -38.834862, -37.137825, -35.440784, -33.743744, -32.046707, -30.349667, -28.652628, -26.955589, -25.258549, -23.561510, -21.864471, -20.167431, -18.470392, - -16.773354, -15.076314, -13.379274, -11.682236, -9.985196, -8.288157, -6.591118, -4.894078, -3.197039, -1.500000, 0.197039, 1.894078, 3.591118, 5.288157, - 6.985196, 8.682236, 10.379274, 12.076314, 13.773354, 15.470392, 17.167431, 18.864471, 20.561510, 22.258549, 23.955589, 25.652628, 27.349667, 29.046707, 30.743744, - 32.440784, 34.137825, 35.834862, 37.531902, 39.228943, 40.925980, 42.623020, 44.320061, 46.017097, 47.714138, 49.411179, 51.108215, 52.805256, 54.502296, 56.199333, - 57.896374, 59.593414, 61.290451, 62.987488, 64.684532, 66.381569, 68.078606, 69.775650, 71.472687, 73.169724, 74.866768, 76.563805, 78.260841, 79.957886, 81.654922, - 83.351959, 85.049004, 86.746040, 88.443077, 90.140121, 91.837158, 93.534195, 95.231239, 96.928276}, sd::DataType::FLOAT32); - - input.linspace(-60, 1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); - beta.assign(-1.5); - - sd::ops::batchnorm op; - - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + NDArray input('c', {2, 3, 4, 5}, sd::DataType::FLOAT32); + + NDArray mean('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray variance('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray gamma('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray beta('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + + NDArray expected( + 'c', {2, 3, 4, 5}, + {-105.019394, -103.322357, -101.625313, -99.928276, -98.231239, + -96.534195, -94.837158, -93.140121, -91.443077, -89.746040, + -88.049004, -86.351959, -84.654922, -82.957886, -81.260841, + -79.563805, -77.866768, -76.169724, -74.472687, -72.775650, + -71.078606, -69.381569, -67.684532, -65.987488, -64.290451, + -62.593414, -60.896374, -59.199333, -57.502296, -55.805256, + -54.108215, -52.411179, -50.714138, -49.017097, -47.320061, + -45.623020, -43.925980, -42.228943, -40.531902, -38.834862, + -37.137825, -35.440784, -33.743744, -32.046707, -30.349667, + -28.652628, -26.955589, -25.258549, -23.561510, -21.864471, + -20.167431, -18.470392, -16.773354, -15.076314, -13.379274, + -11.682236, -9.985196, -8.288157, -6.591118, -4.894078, + -3.197039, -1.500000, 0.197039, 1.894078, 3.591118, + 5.288157, 6.985196, 8.682236, 10.379274, 12.076314, + 13.773354, 15.470392, 17.167431, 18.864471, 20.561510, + 22.258549, 23.955589, 25.652628, 27.349667, 29.046707, + 30.743744, 32.440784, 34.137825, 35.834862, 37.531902, + 39.228943, 40.925980, 42.623020, 44.320061, 46.017097, + 47.714138, 49.411179, 51.108215, 52.805256, 54.502296, + 56.199333, 57.896374, 59.593414, 61.290451, 62.987488, + 64.684532, 66.381569, 68.078606, 69.775650, 71.472687, + 73.169724, 74.866768, 76.563805, 78.260841, 79.957886, + 81.654922, 83.351959, 85.049004, 86.746040, 88.443077, + 90.140121, 91.837158, 93.534195, 95.231239, 96.928276}, + sd::DataType::FLOAT32); + + input.linspace(-60, 1); + mean.assign(1.); + variance.assign(0.5); + gamma.assign(1.2); + beta.assign(-1.5); + + sd::ops::batchnorm op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, + {1, 1, 1, 2, 3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_test9) { - - NDArray input('c', {2,3,3,3,3}, sd::DataType::FLOAT32); - - NDArray mean ('c', {1,3,3,3,3}, sd::DataType::FLOAT32); - NDArray variance('c', {1,3,3,3,3}, sd::DataType::FLOAT32); - NDArray gamma ('c', {1,3,3,3,3}, sd::DataType::FLOAT32); - NDArray beta ('c', {1,3,3,3,3}, sd::DataType::FLOAT32); - - NDArray expected('c', {2,3,3,3,3}, {-138.960175, -137.263138, -135.566101, -133.869064, -132.172028, -130.474976, -128.777954, -127.080902, -125.383865, -123.686829, -121.989784, -120.292747, - -118.595711, -116.898666, -115.201630, -113.504593, -111.807549, -110.110512, -108.413475, -106.716431, -105.019394, -103.322357, -101.625313, -99.928276, - -98.231239, -96.534195, -94.837158, -93.140121, -91.443077, -89.746040, -88.049004, -86.351959, -84.654922, -82.957886, -81.260841, -79.563805, -77.866768, - -76.169724, -74.472687, -72.775650, -71.078606, -69.381569, -67.684532, -65.987488, -64.290451, -62.593414, -60.896374, -59.199333, -57.502296, -55.805256, - -54.108215, -52.411179, -50.714138, -49.017097, -47.320061, -45.623020, -43.925980, -42.228943, -40.531902, -38.834862, -37.137825, -35.440784, -33.743744, - -32.046707, -30.349667, -28.652628, -26.955589, -25.258549, -23.561510, -21.864471, -20.167431, -18.470392, -16.773354, -15.076314, -13.379274, -11.682236, - -9.985196, -8.288157, -6.591118, -4.894078, -3.197039, -1.500000, 0.197039, 1.894078, 3.591118, 5.288157, 6.985196, 8.682236, 10.379274, 12.076314, 13.773354, - 15.470392, 17.167431, 18.864471, 20.561510, 22.258549, 23.955589, 25.652628, 27.349667, 29.046707, 30.743744, 32.440784, 34.137825, 35.834862, 37.531902, 39.228943, - 40.925980, 42.623020, 44.320061, 46.017097, 47.714138, 49.411179, 51.108215, 52.805256, 54.502296, 56.199333, 57.896374, 59.593414, 61.290451, 62.987488, 64.684532, - 66.381569, 68.078606, 69.775650, 71.472687, 73.169724, 74.866768, 76.563805, 78.260841, 79.957886, 81.654922, 83.351959, 85.049004, 86.746040, 88.443077, 90.140121, - 91.837158, 93.534195, 95.231239, 96.928276, 98.625313, 100.322357, 102.019394, 103.716431, 105.413475, 107.110512, 108.807549, 110.504593, 112.201630, 113.898666, - 115.595711, 117.292747, 118.989784, 120.686829, 122.383865, 124.080902, 125.777946, 127.474976, 129.172028, 130.869064, 132.566101, 134.263138}, sd::DataType::FLOAT32); - - input.linspace(-80, 1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); - beta.assign(-1.5); - - sd::ops::batchnorm op; - - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3,4}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto output = results.at(0); - // output->printBuffer(); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + NDArray input('c', {2, 3, 3, 3, 3}, sd::DataType::FLOAT32); + + NDArray mean('c', {1, 3, 3, 3, 3}, sd::DataType::FLOAT32); + NDArray variance('c', {1, 3, 3, 3, 3}, sd::DataType::FLOAT32); + NDArray gamma('c', {1, 3, 3, 3, 3}, sd::DataType::FLOAT32); + NDArray beta('c', {1, 3, 3, 3, 3}, sd::DataType::FLOAT32); + + NDArray expected( + 'c', {2, 3, 3, 3, 3}, + {-138.960175, -137.263138, -135.566101, -133.869064, -132.172028, + -130.474976, -128.777954, -127.080902, -125.383865, -123.686829, + -121.989784, -120.292747, -118.595711, -116.898666, -115.201630, + -113.504593, -111.807549, -110.110512, -108.413475, -106.716431, + -105.019394, -103.322357, -101.625313, -99.928276, -98.231239, + -96.534195, -94.837158, -93.140121, -91.443077, -89.746040, + -88.049004, -86.351959, -84.654922, -82.957886, -81.260841, + -79.563805, -77.866768, -76.169724, -74.472687, -72.775650, + -71.078606, -69.381569, -67.684532, -65.987488, -64.290451, + -62.593414, -60.896374, -59.199333, -57.502296, -55.805256, + -54.108215, -52.411179, -50.714138, -49.017097, -47.320061, + -45.623020, -43.925980, -42.228943, -40.531902, -38.834862, + -37.137825, -35.440784, -33.743744, -32.046707, -30.349667, + -28.652628, -26.955589, -25.258549, -23.561510, -21.864471, + -20.167431, -18.470392, -16.773354, -15.076314, -13.379274, + -11.682236, -9.985196, -8.288157, -6.591118, -4.894078, + -3.197039, -1.500000, 0.197039, 1.894078, 3.591118, + 5.288157, 6.985196, 8.682236, 10.379274, 12.076314, + 13.773354, 15.470392, 17.167431, 18.864471, 20.561510, + 22.258549, 23.955589, 25.652628, 27.349667, 29.046707, + 30.743744, 32.440784, 34.137825, 35.834862, 37.531902, + 39.228943, 40.925980, 42.623020, 44.320061, 46.017097, + 47.714138, 49.411179, 51.108215, 52.805256, 54.502296, + 56.199333, 57.896374, 59.593414, 61.290451, 62.987488, + 64.684532, 66.381569, 68.078606, 69.775650, 71.472687, + 73.169724, 74.866768, 76.563805, 78.260841, 79.957886, + 81.654922, 83.351959, 85.049004, 86.746040, 88.443077, + 90.140121, 91.837158, 93.534195, 95.231239, 96.928276, + 98.625313, 100.322357, 102.019394, 103.716431, 105.413475, + 107.110512, 108.807549, 110.504593, 112.201630, 113.898666, + 115.595711, 117.292747, 118.989784, 120.686829, 122.383865, + 124.080902, 125.777946, 127.474976, 129.172028, 130.869064, + 132.566101, 134.263138}, + sd::DataType::FLOAT32); + + input.linspace(-80, 1); + mean.assign(1.); + variance.assign(0.5); + gamma.assign(1.2); + beta.assign(-1.5); + + sd::ops::batchnorm op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, + {1, 1, 1, 2, 3, 4}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test1) { + NDArray input('c', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.1, 1.2, 1.3, 1.4}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 3, 4}, sd::DataType::FLOAT32); - NDArray input ('c', {2,3,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.1, 1.2, 1.3, 1.4}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,3,4}, sd::DataType::FLOAT32); - - NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, sd::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - variance.assign(0.46666667); - gamma.assign(1.2); - beta.assign(1.); // has no effect on gradient calculations - gradO.linspace(-0.9, 0.15); - - sd::ops::batchnorm_bp op; + NDArray expdLdI( + 'c', {2, 3, 4}, + {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, + -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, + 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, + 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, sd::DataType::FLOAT32); - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); + input.linspace(0.1, 0.1); + variance.assign(0.46666667); + gamma.assign(1.2); + beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::batchnorm_bp op; - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1}); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test2) { + NDArray input('c', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {3}, {1.05, 1.1, 1.15}, sd::DataType::FLOAT32); + NDArray variance('c', {3}, {0.5, 0.6, 0.7}, sd::DataType::FLOAT32); + NDArray gamma('c', {3}, {1.2, 1.3, 1.4}, sd::DataType::FLOAT32); + NDArray beta('c', {3}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 3, 4}, sd::DataType::FLOAT32); - NDArray input ('c', {2,3,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {3}, {1.05, 1.1, 1.15}, sd::DataType::FLOAT32); - NDArray variance('c', {3}, {0.5, 0.6, 0.7}, sd::DataType::FLOAT32); - NDArray gamma ('c', {3}, {1.2, 1.3, 1.4}, sd::DataType::FLOAT32); - NDArray beta ('c', {3}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,3,4}, sd::DataType::FLOAT32); - - NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784, 0.396631, 0.343747, - 0.290863, 0.237978, 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, 0.395465, 0.456306, -0.237978, - -0.290863, -0.343746, -0.396631}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, sd::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - // beta.assign(1.); // has no effect on gradient calculations - gradO.linspace(-0.9, 0.15); + NDArray expdLdI( + 'c', {2, 3, 4}, + {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, + -0.334624, -0.273784, 0.396631, 0.343747, 0.290863, 0.237978, + 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, + 0.395465, 0.456306, -0.237978, -0.290863, -0.343746, -0.396631}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {3}, {5.81236, 7.048771, 12.155388}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, sd::DataType::FLOAT32); - sd::ops::batchnorm_bp op; + input.linspace(0.1, 0.1); + // beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + sd::ops::batchnorm_bp op; - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 1}); - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test3) { - - NDArray input ('c', {2,3,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, sd::DataType::FLOAT32); - NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, sd::DataType::FLOAT32); - NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, sd::DataType::FLOAT32); - NDArray beta ('c', {2,1,4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,3,4}, sd::DataType::FLOAT32); - - NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000, 0.577002, - 0.744041, 0.850999, 0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000, - -0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 }, sd::DataType::FLOAT32); - NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. , 0.45, 4.5 , 4.95, 5.4 , 5.85}, sd::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - // beta.assign(1.); // has no effect on gradient calculations - gradO.linspace(-0.9, 0.15); - - sd::ops::batchnorm_bp op; - - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - + NDArray input('c', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {2, 1, 4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, + sd::DataType::FLOAT32); + NDArray variance('c', {2, 1, 4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, + sd::DataType::FLOAT32); + NDArray gamma('c', {2, 1, 4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, + sd::DataType::FLOAT32); + NDArray beta('c', {2, 1, 4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 3, 4}, sd::DataType::FLOAT32); + + NDArray expdLdI( + 'c', {2, 3, 4}, + {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, + -0.000000, -0.000000, 0.577002, 0.744041, 0.850999, 0.922373, + -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000, + -0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {2, 1, 4}, + {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, + 3.289431, 3.64234}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {2, 1, 4}, {-0.9, -0.45, 0., 0.45, 4.5, 4.95, 5.4, 5.85}, + sd::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + // beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 0, 2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test4) { + NDArray input('c', {2, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 4}, sd::DataType::FLOAT32); - NDArray input ('c', {2,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,4}, sd::DataType::FLOAT32); + NDArray expdLdI('c', {2, 4}, + {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, + 0.289673, -0.354174, 0.386151}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, sd::DataType::FLOAT32); - NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, sd::DataType::FLOAT32); + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); + sd::ops::batchnorm_bp op; - sd::ops::batchnorm_bp op; + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1}); - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test5) { - #if defined(HAVE_CUDNN) -return; + return; #endif - NDArray input ('c', {2,4,2,2}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,4,2,2}, sd::DataType::FLOAT32); + NDArray input('c', {2, 4, 2, 2}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 4, 2, 2}, sd::DataType::FLOAT32); - NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, 0.443214, 0.384118, -1.168243, - -1.045270, -0.922297, -0.799324, 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, 0.659880, 0.737512, -0.384118, - -0.443214, -0.502308, -0.561404, 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, sd::DataType::FLOAT32); + NDArray expdLdI( + 'c', {2, 4, 2, 2}, + {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, + 0.443214, 0.384118, -1.168243, -1.045270, -0.922297, -0.799324, + 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, + 0.659880, 0.737512, -0.384118, -0.443214, -0.502308, -0.561404, + 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, + -1.699129, -1.899026}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {4.2, 9., 13.8, 18.6}, sd::DataType::FLOAT32); - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); - sd::ops::batchnorm_bp op; + sd::ops::batchnorm_bp op; - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test6) { - #if defined(HAVE_CUDNN) -return; + return; #endif - NDArray input ('c', {2,2,2,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,2,2,4}, sd::DataType::FLOAT32); + NDArray input('c', {2, 2, 2, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 2, 2, 4}, sd::DataType::FLOAT32); - NDArray expdLdI('c', {2,2,2,4}, {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, -1.082159, 0.565549, -2.138196, 1.088724, -0.649295, - 0.339329, -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, 0.216432, -0.113110, 2.138195, -1.088724, 0.649295, - -0.339330, 3.563660,-1.814540, 1.082159, -0.565549, 4.989125, -2.540356, 1.515022, -0.791770}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, sd::DataType::FLOAT32); + NDArray expdLdI( + 'c', {2, 2, 2, 4}, + {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, + -1.082159, 0.565549, -2.138196, 1.088724, -0.649295, 0.339329, + -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, + 0.216432, -0.113110, 2.138195, -1.088724, 0.649295, -0.339330, + 3.563660, -1.814540, 1.082159, -0.565549, 4.989125, -2.540356, + 1.515022, -0.791770}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {9.6, 10.8, 12., 13.2}, sd::DataType::FLOAT32); - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); - sd::ops::batchnorm_bp op; + sd::ops::batchnorm_bp op; - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test7) { - #if defined(HAVE_CUDNN) -return; + return; #endif - NDArray input ('c', {2,2,2,2,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,2,2,2,4}, sd::DataType::FLOAT32); - - NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142, - -43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662, - 15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032, - -15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788, - -27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, sd::DataType::FLOAT32); + NDArray input('c', {2, 2, 2, 2, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 2, 2, 2, 4}, sd::DataType::FLOAT32); - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); + NDArray expdLdI( + 'c', {2, 2, 2, 2, 4}, + {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, + -50.901920, 40.412773, -87.585716, 57.317142, -43.070854, 34.195419, + -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, + -27.408726, 21.760721, -39.811687, 26.053242, -19.577662, 15.543370, + -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, + -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, + 23.887032, -15.631958, 11.746601, -9.326031, 39.811691, -26.053246, + 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, + 71.661064, -46.895851, 35.239788, -27.978077, 87.585732, -57.317154, + 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, + 119.435097, -78.159744, 58.732998, -46.630131}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {282.38734, 244.542027, 224.140995, 207.548793}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {57.6, 60., 62.4, 64.8}, sd::DataType::FLOAT32); + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); - sd::ops::batchnorm_bp op; + sd::ops::batchnorm_bp op; - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 4}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); - // dLdI->printBuffer(); + // dLdI->printBuffer(); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test8) { - #if defined(HAVE_CUDNN) -return; + return; #endif - NDArray input ('c', {2,4,2,2,2}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,4,2,2,2}, sd::DataType::FLOAT32); + NDArray input('c', {2, 4, 2, 2, 2}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 4, 2, 2, 2}, sd::DataType::FLOAT32); - NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301, - 32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767, - -27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526, - 30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773, - 31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, sd::DataType::FLOAT32); + NDArray expdLdI( + 'c', {2, 4, 2, 2, 2}, + {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, + -23.797251, -22.034491, 36.146996, 34.293301, 32.439610, 30.585917, + 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, + -38.478958, -36.280159, -34.081367, -31.882565, -29.683767, -27.484968, + 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, + 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, + 29.085526, 30.848286, 32.611046, 34.373802, -23.171146, -25.024837, + -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, + 27.484982, 29.683773, 31.882572, 34.081364, 36.280178, 38.478970, + 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, + -42.878403, -45.477081, -48.075775, -50.674484}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, sd::DataType::FLOAT32); - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); - sd::ops::batchnorm_bp op; + sd::ops::batchnorm_bp op; - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); - // dLdI->printBuffer(); + // dLdI->printBuffer(); - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test9) { - - NDArray input ('c', {2,4,2,2}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,4,2,2}, sd::DataType::FLOAT32); - - NDArray expdLdI('c', {2,4,2,2}, {0.032378, 0.028967, 0.025558, 0.022147, -0.035056, -0.031364, -0.027669, -0.024006, 0.037742, 0.033766, 0.029791, 0.025818, - -0.040429, -0.036172, -0.031913, -0.027656, -0.022155, -0.025564, -0.028974, -0.032359, 0.023982, 0.027677, 0.031373, 0.035063, - -0.025822, -0.029794, -0.033770, -0.037747, 0.027653, 0.031913, 0.036168, 0.040426}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {9.685875, 9.685880, 9.685887, 9.685891}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, sd::DataType::FLOAT32); - - input.linspace(1,0.01); - gradO.linspace(-0.9, 0.15); - - // calculate mean and variance of input - PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9"); - std::vector dimensions = {0,2,3}; - int* dims = reinterpret_cast(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int))); - input.reduceAlongDimension(sd::reduce::Mean, mean, dimensions); - NDArray::prepareSpecialUse({&variance}, {&input}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions); - NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.buffer(), input.shapeInfo(),input.specialBuffer(), input.specialShapeInfo(),nullptr,variance.buffer(), variance.shapeInfo(),variance.specialBuffer(), variance.specialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false); - manager.synchronize(); - NDArray::registerSpecialUse({&variance}, {&input}); - - sd::ops::batchnorm_bp op; - - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - + NDArray input('c', {2, 4, 2, 2}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 4, 2, 2}, sd::DataType::FLOAT32); + + NDArray expdLdI( + 'c', {2, 4, 2, 2}, + {0.032378, 0.028967, 0.025558, 0.022147, -0.035056, -0.031364, + -0.027669, -0.024006, 0.037742, 0.033766, 0.029791, 0.025818, + -0.040429, -0.036172, -0.031913, -0.027656, -0.022155, -0.025564, + -0.028974, -0.032359, 0.023982, 0.027677, 0.031373, 0.035063, + -0.025822, -0.029794, -0.033770, -0.037747, 0.027653, 0.031913, + 0.036168, 0.040426}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {9.685875, 9.685880, 9.685887, 9.685891}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {4.2, 9., 13.8, 18.6}, sd::DataType::FLOAT32); + + input.linspace(1, 0.01); + gradO.linspace(-0.9, 0.15); + + // calculate mean and variance of input + PointersManager manager(input.getContext(), + "DeclarableOpsTests13.batchnorm_bp_test9"); + std::vector dimensions = {0, 2, 3}; + int* dims = reinterpret_cast(manager.replicatePointer( + dimensions.data(), dimensions.size() * sizeof(int))); + input.reduceAlongDimension(sd::reduce::Mean, mean, dimensions); + NDArray::prepareSpecialUse({&variance}, {&input}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats( + input.getContext(), 0, input.buffer(), input.shapeInfo(), + input.specialBuffer(), input.specialShapeInfo(), nullptr, + variance.buffer(), variance.shapeInfo(), variance.specialBuffer(), + variance.specialShapeInfo(), dims, dimensions.size(), + packX.platformShapeInfo(), packX.platformOffsets(), false); + manager.synchronize(); + NDArray::registerSpecialUse({&variance}, {&input}); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test10) { - - NDArray input ('c', {2,2,2,4}, sd::DataType::FLOAT32); - NDArray mean ('c', {4}, sd::DataType::FLOAT32); - NDArray variance('c', {4}, sd::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); - NDArray beta ('c', {4}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,2,2,4}, sd::DataType::FLOAT32); - - NDArray expdLdI('c', {2,2,2,4}, {0.032634, -0.035423, 0.038110, -0.040864, 0.023302, -0.025294, 0.027213, -0.029205, 0.013996, -0.015192, 0.016343, - -0.017519, 0.004664, -0.005062, 0.005445, -0.005833, -0.004668, 0.005067, -0.005452, 0.005824, -0.013974, 0.015171, - -0.016325, 0.017508, -0.023309, 0.025301, -0.027221, 0.029197, -0.032639, 0.035428, -0.038118, 0.040878}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {10.991656, 10.991631, 10.991643, 10.991632}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {9.6, 10.8, 12., 13.2}, sd::DataType::FLOAT32); - - input.linspace(1,0.01); - gradO.linspace(-0.9, 0.15); - - // calculate mean and variance of input - PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9"); - std::vector dimensions = {0,1,2}; - int* dims = reinterpret_cast(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int))); - input.reduceAlongDimension(sd::reduce::Mean, mean, dimensions); - NDArray::prepareSpecialUse({&variance}, {&input}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions); - NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.buffer(), input.shapeInfo(),input.specialBuffer(), input.specialShapeInfo(),nullptr,variance.buffer(), variance.shapeInfo(),variance.specialBuffer(), variance.specialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false); - manager.synchronize(); - NDArray::registerSpecialUse({&variance}, {&input}); - - sd::ops::batchnorm_bp op; - - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - + NDArray input('c', {2, 2, 2, 4}, sd::DataType::FLOAT32); + NDArray mean('c', {4}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, sd::DataType::FLOAT32); + NDArray gamma('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta('c', {4}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 2, 2, 4}, sd::DataType::FLOAT32); + + NDArray expdLdI( + 'c', {2, 2, 2, 4}, + {0.032634, -0.035423, 0.038110, -0.040864, 0.023302, -0.025294, + 0.027213, -0.029205, 0.013996, -0.015192, 0.016343, -0.017519, + 0.004664, -0.005062, 0.005445, -0.005833, -0.004668, 0.005067, + -0.005452, 0.005824, -0.013974, 0.015171, -0.016325, 0.017508, + -0.023309, 0.025301, -0.027221, 0.029197, -0.032639, 0.035428, + -0.038118, 0.040878}, + sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {10.991656, 10.991631, 10.991643, 10.991632}, + sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {9.6, 10.8, 12., 13.2}, sd::DataType::FLOAT32); + + input.linspace(1, 0.01); + gradO.linspace(-0.9, 0.15); + + // calculate mean and variance of input + PointersManager manager(input.getContext(), + "DeclarableOpsTests13.batchnorm_bp_test9"); + std::vector dimensions = {0, 1, 2}; + int* dims = reinterpret_cast(manager.replicatePointer( + dimensions.data(), dimensions.size() * sizeof(int))); + input.reduceAlongDimension(sd::reduce::Mean, mean, dimensions); + NDArray::prepareSpecialUse({&variance}, {&input}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats( + input.getContext(), 0, input.buffer(), input.shapeInfo(), + input.specialBuffer(), input.specialShapeInfo(), nullptr, + variance.buffer(), variance.shapeInfo(), variance.specialBuffer(), + variance.specialShapeInfo(), dims, dimensions.size(), + packX.platformShapeInfo(), packX.platformOffsets(), false); + manager.synchronize(); + NDArray::registerSpecialUse({&variance}, {&input}); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests13, batchnorm_bp_test11) { - - NDArray input ('c', {2,3,4,5}, sd::DataType::FLOAT32); - NDArray mean ('c', {1,3,4,5}, sd::DataType::FLOAT32); - NDArray variance('c', {1,3,4,5}, sd::DataType::FLOAT32); - NDArray gamma ('c', {1,3,4,5}, sd::DataType::FLOAT32); - NDArray beta ('c', {1,3,4,5}, sd::DataType::FLOAT32); - NDArray gradO ('c', {2,3,4,5}, sd::DataType::FLOAT32); - - NDArray expdLdI('c', {2,3,4,5}, {0.004981, 0.004818, 0.004652, 0.004483, 0.004319, 0.004153, 0.003985, 0.003832, 0.003661, 0.003505, 0.003340, 0.003171, 0.003001, 0.002837, - 0.002670, 0.002505, 0.002337, 0.002167, 0.002003, 0.001835, 0.001666, 0.001499, 0.001327, 0.001162, 0.000996, 0.000830, 0.000664, 0.000498, - 0.000332, 0.000166, -0.0, -0.000166, -0.000333, -0.000500, -0.000668, -0.000835, -0.001003, -0.001168, -0.001337, -0.001502, -0.001670, - -0.001838, -0.002003, -0.002172, -0.002330, -0.002499, -0.002669, -0.002832, -0.003002, -0.003162, -0.003332, -0.003495, -0.003665, -0.003821, - -0.004001, -0.004163, -0.004324, -0.004516, -0.004678, -0.004851, -0.004981, -0.004818, -0.004652, -0.004483, -0.004319, -0.004151, -0.003985, - -0.003836, -0.003661, -0.003505, -0.003338, -0.003171, -0.003004, -0.002837, -0.002670, -0.002503, -0.002337, -0.002170, -0.002003, -0.001835, - -0.001664, -0.001499, -0.001328, -0.001162, -0.000996, -0.000829, -0.000664, -0.000498, -0.000332, -0.000166, 0.0, 0.000166, 0.000334, - 0.000500, 0.000668, 0.000834, 0.001003, 0.001170, 0.001337, 0.001502, 0.001669, 0.001838, 0.002005, 0.002172, 0.002330, 0.002496, 0.002669, - 0.002836, 0.003002, 0.003162, 0.003328, 0.003495, 0.003670, 0.003828, 0.003992, 0.004158, 0.004324, 0.004522, 0.004689, 0.004843}, sd::DataType::FLOAT32); - NDArray expdLdG('c', {1,3,4,5}, {8.999503, 8.999502, 8.999502, 8.999503, 8.999502, 8.999503, 8.999503, 8.999499, 8.999501, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, - 8.999498, 8.999498, 8.999498, 8.999498, 8.999499, 8.999501, 8.999500, 8.999503, 8.999503, 8.999503, 8.999504, 8.999503, 8.999503, 8.999504, 8.999503, - 8.999504, 8.999504, 8.999499, 8.999500, 8.999497, 8.999498, 8.999496, 8.999496, 8.999496, 8.999498, 8.999498, 8.999496, 8.999496, 8.999496, 8.999501, - 8.999501, 8.999499, 8.999499, 8.999499, 8.999501, 8.999501, 8.999501, 8.999499, 8.999500, 8.999501, 8.999501, 8.999501, 8.999495, 8.999495, 8.999497}, sd::DataType::FLOAT32); - NDArray expdLdB('c', {1,3,4,5}, {7.2, 7.5, 7.8, 8.1, 8.4, 8.7, 9.0, 9.3, 9.6, 9.9, 10.2, 10.5, 10.8, 11.1, 11.4, 11.7, 12.0, 12.3, 12.6, 12.9, 13.2, 13.5, 13.8, 14.1, 14.4, 14.7, 15.0, - 15.3, 15.6, 15.9, 16.2, 16.5, 16.8, 17.1, 17.4, 17.7, 18.0, 18.3, 18.6, 18.9, 19.2, 19.5, 19.8, 20.1, 20.4, 20.7, 21.0, 21.3, 21.6, 21.9, 22.2, 22.5, - 22.8, 23.1, 23.4, 23.7, 24.0, 24.3, 24.6, 24.9}, sd::DataType::FLOAT32); - - input.linspace(1,0.01); - gradO.linspace(-0.9, 0.15); - gamma.linspace(-3, 0.1); - - // calculate mean and variance of input - PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9"); - std::vector dimensions = {0}; - int* dims = reinterpret_cast(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int))); - input.reduceAlongDimension(sd::reduce::Mean, mean, dimensions, true); - NDArray::prepareSpecialUse({&variance}, {&input}); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions); - NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.buffer(), input.shapeInfo(),input.specialBuffer(), input.specialShapeInfo(),nullptr,variance.buffer(), variance.shapeInfo(),variance.specialBuffer(), variance.specialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false); - manager.synchronize(); - NDArray::registerSpecialUse({&variance}, {&input}); - - sd::ops::batchnorm_bp op; - - auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1, 1,2,3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdI = results.at(0); - auto dLdG = results.at(3); - auto dLdB = results.at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - + NDArray input('c', {2, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray mean('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray variance('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray gamma('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray beta('c', {1, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray gradO('c', {2, 3, 4, 5}, sd::DataType::FLOAT32); + + NDArray expdLdI( + 'c', {2, 3, 4, 5}, + {0.004981, 0.004818, 0.004652, 0.004483, 0.004319, 0.004153, + 0.003985, 0.003832, 0.003661, 0.003505, 0.003340, 0.003171, + 0.003001, 0.002837, 0.002670, 0.002505, 0.002337, 0.002167, + 0.002003, 0.001835, 0.001666, 0.001499, 0.001327, 0.001162, + 0.000996, 0.000830, 0.000664, 0.000498, 0.000332, 0.000166, + -0.0, -0.000166, -0.000333, -0.000500, -0.000668, -0.000835, + -0.001003, -0.001168, -0.001337, -0.001502, -0.001670, -0.001838, + -0.002003, -0.002172, -0.002330, -0.002499, -0.002669, -0.002832, + -0.003002, -0.003162, -0.003332, -0.003495, -0.003665, -0.003821, + -0.004001, -0.004163, -0.004324, -0.004516, -0.004678, -0.004851, + -0.004981, -0.004818, -0.004652, -0.004483, -0.004319, -0.004151, + -0.003985, -0.003836, -0.003661, -0.003505, -0.003338, -0.003171, + -0.003004, -0.002837, -0.002670, -0.002503, -0.002337, -0.002170, + -0.002003, -0.001835, -0.001664, -0.001499, -0.001328, -0.001162, + -0.000996, -0.000829, -0.000664, -0.000498, -0.000332, -0.000166, + 0.0, 0.000166, 0.000334, 0.000500, 0.000668, 0.000834, + 0.001003, 0.001170, 0.001337, 0.001502, 0.001669, 0.001838, + 0.002005, 0.002172, 0.002330, 0.002496, 0.002669, 0.002836, + 0.003002, 0.003162, 0.003328, 0.003495, 0.003670, 0.003828, + 0.003992, 0.004158, 0.004324, 0.004522, 0.004689, 0.004843}, + sd::DataType::FLOAT32); + NDArray expdLdG( + 'c', {1, 3, 4, 5}, + {8.999503, 8.999502, 8.999502, 8.999503, 8.999502, 8.999503, 8.999503, + 8.999499, 8.999501, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, + 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, 8.999499, 8.999501, + 8.999500, 8.999503, 8.999503, 8.999503, 8.999504, 8.999503, 8.999503, + 8.999504, 8.999503, 8.999504, 8.999504, 8.999499, 8.999500, 8.999497, + 8.999498, 8.999496, 8.999496, 8.999496, 8.999498, 8.999498, 8.999496, + 8.999496, 8.999496, 8.999501, 8.999501, 8.999499, 8.999499, 8.999499, + 8.999501, 8.999501, 8.999501, 8.999499, 8.999500, 8.999501, 8.999501, + 8.999501, 8.999495, 8.999495, 8.999497}, + sd::DataType::FLOAT32); + NDArray expdLdB( + 'c', {1, 3, 4, 5}, + {7.2, 7.5, 7.8, 8.1, 8.4, 8.7, 9.0, 9.3, 9.6, 9.9, 10.2, 10.5, + 10.8, 11.1, 11.4, 11.7, 12.0, 12.3, 12.6, 12.9, 13.2, 13.5, 13.8, 14.1, + 14.4, 14.7, 15.0, 15.3, 15.6, 15.9, 16.2, 16.5, 16.8, 17.1, 17.4, 17.7, + 18.0, 18.3, 18.6, 18.9, 19.2, 19.5, 19.8, 20.1, 20.4, 20.7, 21.0, 21.3, + 21.6, 21.9, 22.2, 22.5, 22.8, 23.1, 23.4, 23.7, 24.0, 24.3, 24.6, 24.9}, + sd::DataType::FLOAT32); + + input.linspace(1, 0.01); + gradO.linspace(-0.9, 0.15); + gamma.linspace(-3, 0.1); + + // calculate mean and variance of input + PointersManager manager(input.getContext(), + "DeclarableOpsTests13.batchnorm_bp_test9"); + std::vector dimensions = {0}; + int* dims = reinterpret_cast(manager.replicatePointer( + dimensions.data(), dimensions.size() * sizeof(int))); + input.reduceAlongDimension(sd::reduce::Mean, mean, dimensions, true); + NDArray::prepareSpecialUse({&variance}, {&input}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + input.shapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats( + input.getContext(), 0, input.buffer(), input.shapeInfo(), + input.specialBuffer(), input.specialShapeInfo(), nullptr, + variance.buffer(), variance.shapeInfo(), variance.specialBuffer(), + variance.specialShapeInfo(), dims, dimensions.size(), + packX.platformShapeInfo(), packX.platformOffsets(), false); + manager.synchronize(); + NDArray::registerSpecialUse({&variance}, {&input}); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, + {1e-5}, {1, 1, 1, 2, 3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp index e21b81dbe497..3ffc0d4e70d2 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp @@ -14,2406 +14,2706 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // Created by raver on 8/4/2018. // -#include "testlayers.h" -#include #include -#include #include +#include +#include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTests14 : public testing::Test { -public: - - DeclarableOpsTests14() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests14() { + printf("\n"); + fflush(stdout); + } }; TEST_F(DeclarableOpsTests14, Test_Validation_Edge_1) { - auto x = NDArrayFactory::create('c', {2}, {2, 2}); - auto exp = NDArrayFactory::create('c', {2, 2}, Environment::getInstance()->defaultFloatDataType()); - exp.assign(4.0f); - - sd::ops::fill op; - auto result = op.evaluate({&x}, {4.0f}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {2}, {2, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 2}, Environment::getInstance()->defaultFloatDataType()); + exp.assign(4.0f); - ASSERT_EQ(exp, z); + sd::ops::fill op; + auto result = op.evaluate({&x}, {4.0f}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(exp, z); } TEST_F(DeclarableOpsTests14, Test_Inf_Comparison_1) { - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, std::numeric_limits::infinity(), 5}); - auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, std::numeric_limits::infinity(), 5}); + auto x = NDArrayFactory::create( + 'c', {5}, {1, 2, 3, std::numeric_limits::infinity(), 5}); + auto y = NDArrayFactory::create( + 'c', {5}, {1, 2, 3, std::numeric_limits::infinity(), 5}); - ASSERT_EQ(x, y); + ASSERT_EQ(x, y); } TEST_F(DeclarableOpsTests14, Test_Inf_Comparison_2) { #ifdef FFAST_MATH - if (1 > 0) - return; + if (1 > 0) return; #endif - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, std::numeric_limits::infinity(), 5}); - auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, -std::numeric_limits::infinity(), 5}); + auto x = NDArrayFactory::create( + 'c', {5}, {1, 2, 3, std::numeric_limits::infinity(), 5}); + auto y = NDArrayFactory::create( + 'c', {5}, {1, 2, 3, -std::numeric_limits::infinity(), 5}); - ASSERT_NE(x, y); + ASSERT_NE(x, y); } TEST_F(DeclarableOpsTests14, Multiply_test) { + for (int k = 2; k < 10; k++) { + // nd4j_printf("k=%d\n", k); + NDArray x = NDArrayFactory::create('c', {k, 1}); + NDArray y = NDArrayFactory::create('c', {k}); + NDArray e = NDArrayFactory::create('c', {k, k}); + x.assign(1.0); + y.assign(1.0); + e.assign(1.0); - for(int k=2;k<10;k++){ - //nd4j_printf("k=%d\n", k); - NDArray x = NDArrayFactory::create('c', {k, 1}); - NDArray y = NDArrayFactory::create('c', {k}); - NDArray e = NDArrayFactory::create('c', {k, k}); - x.assign(1.0); - y.assign(1.0); - e.assign(1.0); - - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}); - auto f = result.at(0); + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}); + auto f = result.at(0); - ASSERT_EQ(e, f); - } + ASSERT_EQ(e, f); + } } TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) { - auto x = NDArrayFactory::create('c', {3}, {5, 3, 4}); - auto y = NDArrayFactory::create('c', {1}, {1}); - auto e = NDArrayFactory::create('c', {2}, {5, 4}); - - sd::ops::evaluate_reduction_shape op; - auto result = op.evaluate({&x, &y}, {}, {}, {false, false}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(e, z); + auto x = NDArrayFactory::create('c', {3}, {5, 3, 4}); + auto y = NDArrayFactory::create('c', {1}, {1}); + auto e = NDArrayFactory::create('c', {2}, {5, 4}); + sd::ops::evaluate_reduction_shape op; + auto result = op.evaluate({&x, &y}, {}, {}, {false, false}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) { - auto x = NDArrayFactory::create('c', {3}, {5, 3, 4}); - auto y = NDArrayFactory::create('c', {1}, {1}); - auto e = NDArrayFactory::create('c', {3}, {5, 1, 4}); - - sd::ops::evaluate_reduction_shape op; - auto result = op.evaluate({&x, &y}, {}, {}, {true, false}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(e, z); + auto x = NDArrayFactory::create('c', {3}, {5, 3, 4}); + auto y = NDArrayFactory::create('c', {1}, {1}); + auto e = NDArrayFactory::create('c', {3}, {5, 1, 4}); + sd::ops::evaluate_reduction_shape op; + auto result = op.evaluate({&x, &y}, {}, {}, {true, false}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Test_Reduce_Min_Small_0) { - auto x = NDArrayFactory::create('c', {3, 4}, {-999.f, 0.2236f, 0.7973f, 0.0962f, 0.7231f, 0.3381f, -0.7301f, 0.9115f, -0.5094f, 0.9749f, -2.1340f, 0.6023f}); - auto z = NDArrayFactory::create('c', {4}); - auto e = NDArrayFactory::create('c', {4}, {-999.f, 0.2236f, -2.1340f, 0.0962f}); + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {-999.f, 0.2236f, 0.7973f, 0.0962f, 0.7231f, 0.3381f, -0.7301f, 0.9115f, + -0.5094f, 0.9749f, -2.1340f, 0.6023f}); + auto z = NDArrayFactory::create('c', {4}); + auto e = NDArrayFactory::create('c', {4}, + {-999.f, 0.2236f, -2.1340f, 0.0962f}); - sd::ops::reduce_min op; - op.execute({&x}, {&z}, {}, {0}, {}); + sd::ops::reduce_min op; + op.execute({&x}, {&z}, {}, {0}, {}); - //z.printIndexedBuffer("Z"); + // z.printIndexedBuffer("Z"); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Test_Reduce_Min_Small_1) { - auto x = NDArrayFactory::create('c', {3, 4}, {-999.f, 0.2236f, 0.7973f, 0.0962f, 0.7231f, 0.3381f, -0.7301f, 0.9115f, -0.5094f, 0.9749f, -2.1340f, 0.6023f}); - auto z = NDArrayFactory::create('c', {3}); - auto e = NDArrayFactory::create('c', {3}, {-999.f, -0.7301f, -2.1340f}); + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {-999.f, 0.2236f, 0.7973f, 0.0962f, 0.7231f, 0.3381f, -0.7301f, 0.9115f, + -0.5094f, 0.9749f, -2.1340f, 0.6023f}); + auto z = NDArrayFactory::create('c', {3}); + auto e = + NDArrayFactory::create('c', {3}, {-999.f, -0.7301f, -2.1340f}); - sd::ops::reduce_min op; - op.execute({&x}, {&z}, {}, {1}, {}); + sd::ops::reduce_min op; + op.execute({&x}, {&z}, {}, {1}, {}); - //z.printIndexedBuffer("Z"); + // z.printIndexedBuffer("Z"); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Test_Diag_Zeros_1) { - auto x = NDArrayFactory::create('c', {2}, {1, 2}); - auto z = NDArrayFactory::create('c', {2, 2}, {-119, -119, -119, -119}); - auto exp = NDArrayFactory::create('c', {2, 2}, {1, 0, 0, 2}); + auto x = NDArrayFactory::create('c', {2}, {1, 2}); + auto z = + NDArrayFactory::create('c', {2, 2}, {-119, -119, -119, -119}); + auto exp = NDArrayFactory::create('c', {2, 2}, {1, 0, 0, 2}); - sd::ops::diag op; - auto status = op.execute({&x}, {&z}, {}, {}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::diag op; + auto status = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_1) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {5, 10}); - auto e = NDArrayFactory::create('c', {5, 10}); - e.assign(1.0); - - - sd::ops::add op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_EQ(e, result.at(0)); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {5, 10}); + auto e = NDArrayFactory::create('c', {5, 10}); + e.assign(1.0); + sd::ops::add op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) { - auto x = NDArrayFactory::create(1.0f); - auto y = NDArrayFactory::create('c', {5, 10}); - auto e = NDArrayFactory::create('c', {5, 10}); - y.assign(2.0f); - e.assign(-1.0f); - - - sd::ops::subtract op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_EQ(e, result.at(0)); + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {5, 10}); + auto e = NDArrayFactory::create('c', {5, 10}); + y.assign(2.0f); + e.assign(-1.0f); + sd::ops::subtract op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests14, test_empty_fill_1) { - auto x = NDArrayFactory::empty(); - auto y = NDArrayFactory::create(1); - - sd::ops::fill op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(y, z); + auto x = NDArrayFactory::empty(); + auto y = NDArrayFactory::create(1); + sd::ops::fill op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(y, z); } TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) { - auto a = NDArrayFactory::create('c', {1, 5}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f}); - auto b = NDArrayFactory::create('c', {1, 3}); - auto c = NDArrayFactory::create('c', {1, 3}); - auto d = NDArrayFactory::create('c', {8, 12}, {-0.15320599,-0.120416045,0.33126968,0.13921785,-0.32313538,-0.43956736,0.4756174,0.4335605,-0.5450856,-0.3943429,-0.28687626,0.068032146,-0.2793799,0.17298919,-0.36553562,-0.097853184,-0.2544747,-0.39872527,-0.14556861,-0.31479517,0.2559092,0.47166896,-0.31330687,0.47313118,0.5134543,-0.4678212,-0.12853557,0.26142156,0.43472284,-0.42842552,-0.1895876,0.538689,0.508651,-0.020272732,0.112327516,0.2704304,-0.046546757,0.32570732,-0.15148133,-0.19145513,0.18631572,-0.024152994,0.41603214,-0.3421499,0.0106860995,-0.2966229,-0.36713937,0.25841123,0.0843398,0.49082482,0.10800403,0.1874243,-0.26379472,-0.22531849,0.24924624,0.23119557,0.49940765,-0.051413506,0.20315129,-0.41888732,0.44097036,0.40453392,0.013338983,0.23434466,0.23942488,0.47894,-0.19898453,0.09253675,-0.032358468,-0.15213022,-0.3441009,-0.15600958,-0.08235118,0.12165731,-0.4481289,-0.4842423,-0.45797008,-0.4606034,0.08163166,-0.2981107,0.50207126,0.44195646,0.13850057,0.072246075,-0.34388685,0.030900061,0.35821778,0.47900867,0.5094063,0.23683065,0.18020362,-0.1369732,0.015235603,0.2786904,0.07954317,0.12543976}); - auto e = NDArrayFactory::create('c', {3}); - auto f = NDArrayFactory::create('c', {3}); - auto g = NDArrayFactory::create('c', {3}); - auto h = NDArrayFactory::create('c', {12}); - - auto z0 = NDArrayFactory::create('c', {1, 3}); - auto z1 = NDArrayFactory::create('c', {1, 3}); - auto z2 = NDArrayFactory::create('c', {1, 3}); - auto z3 = NDArrayFactory::create('c', {1, 3}); - auto z4 = NDArrayFactory::create('c', {1, 3}); - auto z5 = NDArrayFactory::create('c', {1, 3}); - auto z6 = NDArrayFactory::create('c', {1, 3}); - - sd::ops::lstmBlockCell op; - auto result = op.execute({&a, &b, &c, &d, &e, &f, &g, &h}, {&z0, &z1, &z2, &z3, &z4, &z5, &z6}, {1.0, -1.0}, {0}, {}); - ASSERT_EQ(Status::OK(), result); + auto a = NDArrayFactory::create( + 'c', {1, 5}, + {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f}); + auto b = NDArrayFactory::create('c', {1, 3}); + auto c = NDArrayFactory::create('c', {1, 3}); + auto d = NDArrayFactory::create( + 'c', {8, 12}, + {-0.15320599, -0.120416045, 0.33126968, 0.13921785, -0.32313538, + -0.43956736, 0.4756174, 0.4335605, -0.5450856, -0.3943429, + -0.28687626, 0.068032146, -0.2793799, 0.17298919, -0.36553562, + -0.097853184, -0.2544747, -0.39872527, -0.14556861, -0.31479517, + 0.2559092, 0.47166896, -0.31330687, 0.47313118, 0.5134543, + -0.4678212, -0.12853557, 0.26142156, 0.43472284, -0.42842552, + -0.1895876, 0.538689, 0.508651, -0.020272732, 0.112327516, + 0.2704304, -0.046546757, 0.32570732, -0.15148133, -0.19145513, + 0.18631572, -0.024152994, 0.41603214, -0.3421499, 0.0106860995, + -0.2966229, -0.36713937, 0.25841123, 0.0843398, 0.49082482, + 0.10800403, 0.1874243, -0.26379472, -0.22531849, 0.24924624, + 0.23119557, 0.49940765, -0.051413506, 0.20315129, -0.41888732, + 0.44097036, 0.40453392, 0.013338983, 0.23434466, 0.23942488, + 0.47894, -0.19898453, 0.09253675, -0.032358468, -0.15213022, + -0.3441009, -0.15600958, -0.08235118, 0.12165731, -0.4481289, + -0.4842423, -0.45797008, -0.4606034, 0.08163166, -0.2981107, + 0.50207126, 0.44195646, 0.13850057, 0.072246075, -0.34388685, + 0.030900061, 0.35821778, 0.47900867, 0.5094063, 0.23683065, + 0.18020362, -0.1369732, 0.015235603, 0.2786904, 0.07954317, + 0.12543976}); + auto e = NDArrayFactory::create('c', {3}); + auto f = NDArrayFactory::create('c', {3}); + auto g = NDArrayFactory::create('c', {3}); + auto h = NDArrayFactory::create('c', {12}); + + auto z0 = NDArrayFactory::create('c', {1, 3}); + auto z1 = NDArrayFactory::create('c', {1, 3}); + auto z2 = NDArrayFactory::create('c', {1, 3}); + auto z3 = NDArrayFactory::create('c', {1, 3}); + auto z4 = NDArrayFactory::create('c', {1, 3}); + auto z5 = NDArrayFactory::create('c', {1, 3}); + auto z6 = NDArrayFactory::create('c', {1, 3}); + + sd::ops::lstmBlockCell op; + auto result = + op.execute({&a, &b, &c, &d, &e, &f, &g, &h}, + {&z0, &z1, &z2, &z3, &z4, &z5, &z6}, {1.0, -1.0}, {0}, {}); + ASSERT_EQ(Status::OK(), result); } TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) { + auto e = NDArrayFactory::create('c', {1, 0}); + sd::ops::reduce_min sumOp; + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); + ASSERT_EQ(res2.status(), Status::OK()); + auto out = res2.at(0); - auto e = NDArrayFactory::create('c', {1, 0}); - sd::ops::reduce_min sumOp; - auto res2 = sumOp.evaluate({&e}, {1.}, {1}); - ASSERT_EQ(res2.status(), Status::OK()); - auto out = res2.at(0); - - ASSERT_EQ(out.e(0), DataTypeUtils::infOrMax()); - + ASSERT_EQ(out.e(0), DataTypeUtils::infOrMax()); } TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) { + auto e = NDArrayFactory::create('c', {1, 0}); + sd::ops::reduce_max sumOp; + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); + ASSERT_EQ(res2.status(), Status::OK()); + auto out = res2.at(0); - auto e = NDArrayFactory::create('c', {1, 0}); - sd::ops::reduce_max sumOp; - auto res2 = sumOp.evaluate({&e}, {1.}, {1}); - ASSERT_EQ(res2.status(), Status::OK()); - auto out = res2.at(0); - - ASSERT_EQ(out.e(0), -DataTypeUtils::infOrMax()); - + ASSERT_EQ(out.e(0), -DataTypeUtils::infOrMax()); } TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) { #ifdef FFAST_MATH - if (1 > 0) - return; + if (1 > 0) return; #endif - auto e = NDArrayFactory::create('c', {1, 0}); - sd::ops::reduce_sum sumOp; - auto res2 = sumOp.evaluate({&e}, {1.}, {1}); - ASSERT_EQ(res2.status(), Status::OK()); - auto out = res2.at(0); - ASSERT_EQ(out.e(0), 0.f); - + auto e = NDArrayFactory::create('c', {1, 0}); + sd::ops::reduce_sum sumOp; + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); + ASSERT_EQ(res2.status(), Status::OK()); + auto out = res2.at(0); + ASSERT_EQ(out.e(0), 0.f); } TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) { #ifdef FFAST_MATH - if (1 > 0) - return; + if (1 > 0) return; #endif - auto e = NDArrayFactory::create('c', {1, 0}); - sd::ops::reduce_mean sumOp; - auto res2 = sumOp.evaluate({&e}, {1.}, {1}); - ASSERT_EQ(res2.status(), Status::OK()); - auto out = res2.at(0); - // out->printShapeInfo("ReduceMean empty shape with keep dims"); - // out->printIndexedBuffer("ReduceMean scalar"); - ASSERT_TRUE(std::isnan(out.e(0))); - + auto e = NDArrayFactory::create('c', {1, 0}); + sd::ops::reduce_mean sumOp; + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); + ASSERT_EQ(res2.status(), Status::OK()); + auto out = res2.at(0); + // out->printShapeInfo("ReduceMean empty shape with keep dims"); + // out->printIndexedBuffer("ReduceMean scalar"); + ASSERT_TRUE(std::isnan(out.e(0))); } TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) { - auto matrix = NDArrayFactory::create('c', {1, 2, 0, 4}); - auto b = NDArrayFactory::create('c', {3}, {0, 0, 0}); - auto e = NDArrayFactory::create('c', {3}, {2,0,2}); - auto s = NDArrayFactory::create('c', {3}, {1,1,1}); - - auto exp = NDArrayFactory::create('c', {1,0,0,4}); - - matrix.linspace(1); + auto matrix = NDArrayFactory::create('c', {1, 2, 0, 4}); + auto b = NDArrayFactory::create('c', {3}, {0, 0, 0}); + auto e = NDArrayFactory::create('c', {3}, {2, 0, 2}); + auto s = NDArrayFactory::create('c', {3}, {1, 1, 1}); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 0}); - ASSERT_EQ(Status::OK(), result.status()); + auto exp = NDArrayFactory::create('c', {1, 0, 0, 4}); - auto z = result.at(0); + matrix.linspace(1); - ASSERT_TRUE(exp.isSameShape(z)); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); } TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) { - auto matrix = NDArrayFactory::create('c', {1, 2, 0, 4}); - auto b = NDArrayFactory::create('c', {3}, {0, 0, 0}); - auto e = NDArrayFactory::create('c', {3}, {2,0,2}); - auto s = NDArrayFactory::create('c', {3}, {1,1,1}); - - auto exp = NDArrayFactory::create('c', {0,0,4}); - - matrix.linspace(1); + auto matrix = NDArrayFactory::create('c', {1, 2, 0, 4}); + auto b = NDArrayFactory::create('c', {3}, {0, 0, 0}); + auto e = NDArrayFactory::create('c', {3}, {2, 0, 2}); + auto s = NDArrayFactory::create('c', {3}, {1, 1, 1}); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + auto exp = NDArrayFactory::create('c', {0, 0, 4}); - auto z = result.at(0); + matrix.linspace(1); - ASSERT_TRUE(exp.isSameShape(z)); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); } TEST_F(DeclarableOpsTests14, test_empty_argmax_1) { - auto x = NDArrayFactory::create('c', {1, 0}); - auto y = NDArrayFactory::create(0); - auto e = NDArrayFactory::create('c', {0}); + auto x = NDArrayFactory::create('c', {1, 0}); + auto y = NDArrayFactory::create(0); + auto e = NDArrayFactory::create('c', {0}); - sd::ops::argmax op; - //sd::ops::reduce_max op; + sd::ops::argmax op; + // sd::ops::reduce_max op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_EQ(e, z); + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, test_empty_argmax_2) { - auto x = NDArrayFactory::create('c', {1, 0}); - auto y = NDArrayFactory::create(1); + auto x = NDArrayFactory::create('c', {1, 0}); + auto y = NDArrayFactory::create(1); - sd::ops::argmax op; - try { - auto result = op.execute({&x, &y}, {&y}, {}, {}, {}); - ASSERT_TRUE(false); - } catch (std::exception &e) { - // - } + sd::ops::argmax op; + try { + auto result = op.execute({&x, &y}, {&y}, {}, {}, {}); + ASSERT_TRUE(false); + } catch (std::exception &e) { + // + } } TEST_F(DeclarableOpsTests14, test_empty_tanh_5) { - auto x = NDArrayFactory::create('c', {32, 0}); - - sd::ops::tanh op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {32, 0}); - auto z = result.at(0); - - ASSERT_TRUE(x.isSameShape(z)); - ASSERT_EQ(x, z); + sd::ops::tanh op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(x.isSameShape(z)); + ASSERT_EQ(x, z); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, repeat_1) { + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + NDArray e('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}); - NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - NDArray e('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}); - - sd::ops::repeat op; - auto result = op.evaluate({&x}, {}, {2, 0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + sd::ops::repeat op; + auto result = op.evaluate({&x}, {}, {2, 0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, repeat_2) { + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + NDArray e('c', {2, 6}, {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6}); - NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - NDArray e('c', {2, 6}, {1, 1, 2, 2, 3, 3,4, 4, 5, 5, 6, 6}); - - sd::ops::repeat op; - auto result = op.evaluate({&x}, {}, {2, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + sd::ops::repeat op; + auto result = op.evaluate({&x}, {}, {2, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, repeat_3) { + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + NDArray e('c', {2, 6}, {1, 2, 2, 3, 3, 3, 4, 5, 5, 6, 6, 6}); - NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - NDArray e('c', {2, 6}, {1, 2, 2, 3, 3, 3,4, 5, 5, 6, 6, 6}); - - sd::ops::repeat op; - auto result = op.evaluate({&x}, {}, {1,2,3, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + sd::ops::repeat op; + auto result = op.evaluate({&x}, {}, {1, 2, 3, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, repeat_4) { + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + NDArray e('c', {7, 3}, + {1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6}); - NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - NDArray e('c', {7, 3}, {1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6}); - - sd::ops::repeat op; - auto result = op.evaluate({&x}, {}, {3,4, 0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + sd::ops::repeat op; + auto result = op.evaluate({&x}, {}, {3, 4, 0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, repeat_5) { + NDArray x('c', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + NDArray e('c', {2, 4, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 17, 18, 19, 20, 21, 22, 23, 24}); - NDArray x('c', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - NDArray e('c', {2, 4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 17, 18, 19, 20, 21, 22, 23, 24}); - - sd::ops::repeat op; - auto result = op.evaluate({&x}, {}, {1,2,1, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + sd::ops::repeat op; + auto result = op.evaluate({&x}, {}, {1, 2, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) { + auto y = NDArray('c', {3}, sd::DataType::FLOAT32); + auto x = NDArray('c', {5, 2, 1}, sd::DataType::FLOAT32); - auto y = NDArray('c', { 3 }, sd::DataType::FLOAT32); - auto x = NDArray('c', { 5, 2, 1 }, sd::DataType::FLOAT32); - - auto e = NDArray('c', { 5, 2, 3 }, { 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11. }, sd::DataType::FLOAT32); + auto e = NDArray( + 'c', {5, 2, 3}, + {2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., + 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11.}, + sd::DataType::FLOAT32); - y.assign(1.0); - x.linspace(1.0); - - sd::ops::add op; - auto result = op.evaluate({ &x, &y }); - ASSERT_EQ(Status::OK(), result.status()); - - auto res = result.at(0); + y.assign(1.0); + x.linspace(1.0); - ASSERT_EQ(e, res); + sd::ops::add op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); + ASSERT_EQ(e, res); } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) { + auto y = NDArray('c', {1, 3}, sd::DataType::FLOAT32); + auto x = NDArray('c', {5, 2, 1}, sd::DataType::FLOAT32); - auto y = NDArray('c', { 1, 3 }, sd::DataType::FLOAT32); - auto x = NDArray('c', { 5, 2, 1 }, sd::DataType::FLOAT32); - - auto e = NDArray('c', { 5, 2, 3 }, { 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11. }, sd::DataType::FLOAT32); - - y.assign(1.0); - x.linspace(1.0); - - sd::ops::add op; - auto result = op.evaluate({ &x, &y }); - ASSERT_EQ(Status::OK(), result.status()); + auto e = NDArray( + 'c', {5, 2, 3}, + {2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., + 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11.}, + sd::DataType::FLOAT32); - auto res = result.at(0); + y.assign(1.0); + x.linspace(1.0); - ASSERT_EQ(e, res); + sd::ops::add op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); + ASSERT_EQ(e, res); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest3) { - - auto x = NDArray('c', { 3, 5, 1 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 3, 1, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 3, 5, 4 }, { 10., 11., 12., 13., 20., 22., 24., 26., 30., 33., 36., 39., 40., 44., 48., 52., 50., 55., 60., 65., 84., 90., 96., 102., 98., 105., 112., 119., 112., 120., 128., 136., 126., 135., 144., 153., 140., 150., 160., 170., 198., 209., 220., 231., 216., 228., 240., 252., 234., 247., 260., 273., 252., 266., 280., 294., 270., 285., 300., 315. }, sd::DataType::FLOAT32); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); - ASSERT_EQ(e, z); + auto x = NDArray('c', {3, 5, 1}, sd::DataType::FLOAT32); + auto y = NDArray('c', {3, 1, 4}, sd::DataType::FLOAT32); + auto z = NDArray('c', {3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray( + 'c', {3, 5, 4}, + {10., 11., 12., 13., 20., 22., 24., 26., 30., 33., 36., 39., + 40., 44., 48., 52., 50., 55., 60., 65., 84., 90., 96., 102., + 98., 105., 112., 119., 112., 120., 128., 136., 126., 135., 144., 153., + 140., 150., 160., 170., 198., 209., 220., 231., 216., 228., 240., 252., + 234., 247., 260., 273., 252., 266., 280., 294., 270., 285., 300., 315.}, + sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest4) { - - auto x = NDArray('c', { 2, 3, 5, 1 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 2, 3, 1, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 2, 3, 5, 4 }, { 10., 11., 12., 13.,20., 22., 24., 26.,30., 33., 36., 39.,40., 44., 48., 52.,50., 55., 60., 65.,84., 90., 96., 102.,98., 105., 112., 119.,112., 120., 128., 136.,126., 135., 144., 153.,140., 150., 160., 170.,198., 209., 220., 231.,216., 228., 240., 252.,234., 247., 260., 273.,252., 266., 280., 294.,270., 285., 300., 315.,352., 368., 384., 400.,374., 391., 408., 425.,396., 414., 432., 450.,418., 437., 456., 475.,440., 460., 480., 500.,546., 567., 588., 609.,572., 594., 616., 638.,598., 621., 644., 667.,624., 648., 672., 696.,650., 675., 700., 725.,780., 806., 832., 858.,810., 837., 864., 891.,840., 868., 896., 924.,870., 899., 928., 957.,900., 930., 960., 990. }, sd::DataType::FLOAT32); - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); - ASSERT_EQ(e, z); + auto x = NDArray('c', {2, 3, 5, 1}, sd::DataType::FLOAT32); + auto y = NDArray('c', {2, 3, 1, 4}, sd::DataType::FLOAT32); + auto z = NDArray('c', {2, 3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray( + 'c', {2, 3, 5, 4}, + {10., 11., 12., 13., 20., 22., 24., 26., 30., 33., 36., 39., + 40., 44., 48., 52., 50., 55., 60., 65., 84., 90., 96., 102., + 98., 105., 112., 119., 112., 120., 128., 136., 126., 135., 144., 153., + 140., 150., 160., 170., 198., 209., 220., 231., 216., 228., 240., 252., + 234., 247., 260., 273., 252., 266., 280., 294., 270., 285., 300., 315., + 352., 368., 384., 400., 374., 391., 408., 425., 396., 414., 432., 450., + 418., 437., 456., 475., 440., 460., 480., 500., 546., 567., 588., 609., + 572., 594., 616., 638., 598., 621., 644., 667., 624., 648., 672., 696., + 650., 675., 700., 725., 780., 806., 832., 858., 810., 837., 864., 891., + 840., 868., 896., 924., 870., 899., 928., 957., 900., 930., 960., 990.}, + sd::DataType::FLOAT32); + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest5) { - - auto x = NDArray('c', { 3, 5, 1 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 3, 1, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 3, 5, 4 }, { 0.1, 0.090909, 0.083333, 0.076923,0.2, 0.181818, 0.166667, 0.153846,0.3, 0.272727, 0.250000, 0.230769,0.4, 0.363636, 0.333333, 0.307692,0.5, 0.454545, 0.416667, 0.384615, 0.428571, 0.400000, 0.375000, 0.352941, 0.500000, 0.466667, 0.437500, 0.411765, 0.571429, 0.533333, 0.500000, 0.470588, 0.642857, 0.600000, 0.562500, 0.529412, 0.714286, 0.666667, 0.625000, 0.588235, 0.611111, 0.578947, 0.550000, 0.523810, 0.666667, 0.631579, 0.600000, 0.571429, 0.722222, 0.684211, 0.650000, 0.619048, 0.777778, 0.736842, 0.700000, 0.666667, 0.833333, 0.789474, 0.750000, 0.714286 }, sd::DataType::FLOAT32); - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); - ASSERT_EQ(e, z); + auto x = NDArray('c', {3, 5, 1}, sd::DataType::FLOAT32); + auto y = NDArray('c', {3, 1, 4}, sd::DataType::FLOAT32); + auto z = NDArray('c', {3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray( + 'c', {3, 5, 4}, + {0.1, 0.090909, 0.083333, 0.076923, 0.2, 0.181818, 0.166667, + 0.153846, 0.3, 0.272727, 0.250000, 0.230769, 0.4, 0.363636, + 0.333333, 0.307692, 0.5, 0.454545, 0.416667, 0.384615, 0.428571, + 0.400000, 0.375000, 0.352941, 0.500000, 0.466667, 0.437500, 0.411765, + 0.571429, 0.533333, 0.500000, 0.470588, 0.642857, 0.600000, 0.562500, + 0.529412, 0.714286, 0.666667, 0.625000, 0.588235, 0.611111, 0.578947, + 0.550000, 0.523810, 0.666667, 0.631579, 0.600000, 0.571429, 0.722222, + 0.684211, 0.650000, 0.619048, 0.777778, 0.736842, 0.700000, 0.666667, + 0.833333, 0.789474, 0.750000, 0.714286}, + sd::DataType::FLOAT32); + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest6) { - - auto x = NDArray('c', { 2, 3, 5, 1 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 2, 3, 1, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 2, 3, 5, 4 }, { 0.1, 0.090909, 0.083333, 0.076923,0.2, 0.181818, 0.166667, 0.153846,0.3, 0.272727, 0.250000, 0.230769,0.4, 0.363636, 0.333333, 0.307692,0.5, 0.454545, 0.416667, 0.384615, 0.428571, 0.400000, 0.375000, 0.352941, 0.500000, 0.466667, 0.437500, 0.411765, 0.571429, 0.533333, 0.500000, 0.470588, 0.642857, 0.600000, 0.562500, 0.529412, 0.714286, 0.666667, 0.625000, 0.588235,0.611111, 0.578947, 0.550000, 0.523810,0.666667, 0.631579, 0.600000, 0.571429,0.722222, 0.684211, 0.650000, 0.619048,0.777778, 0.736842, 0.700000, 0.666667,0.833333, 0.789474, 0.750000, 0.714286, 0.727273, 0.695652, 0.666667, 0.64, 0.772727, 0.739130, 0.708333, 0.68, 0.818182, 0.782609, 0.750000, 0.72, 0.863636, 0.826087, 0.791667, 0.76, 0.909091, 0.869565, 0.833333, 0.80, 0.807692, 0.777778, 0.750000, 0.724138, 0.846154, 0.814815, 0.785714, 0.758621, 0.884615, 0.851852, 0.821429, 0.793103, 0.923077, 0.888889, 0.857143, 0.827586, 0.961538, 0.925926, 0.892857, 0.862069, 0.866667, 0.838710, 0.812500, 0.787879, 0.900000, 0.870968, 0.843750, 0.818182, 0.933333, 0.903226, 0.875000, 0.848485, 0.966667, 0.935484, 0.906250, 0.878788, 1.000000, 0.967742, 0.937500, 0.909091 }, sd::DataType::FLOAT32); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); - ASSERT_EQ(e, z); + auto x = NDArray('c', {2, 3, 5, 1}, sd::DataType::FLOAT32); + auto y = NDArray('c', {2, 3, 1, 4}, sd::DataType::FLOAT32); + auto z = NDArray('c', {2, 3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray( + 'c', {2, 3, 5, 4}, + {0.1, 0.090909, 0.083333, 0.076923, 0.2, 0.181818, 0.166667, + 0.153846, 0.3, 0.272727, 0.250000, 0.230769, 0.4, 0.363636, + 0.333333, 0.307692, 0.5, 0.454545, 0.416667, 0.384615, 0.428571, + 0.400000, 0.375000, 0.352941, 0.500000, 0.466667, 0.437500, 0.411765, + 0.571429, 0.533333, 0.500000, 0.470588, 0.642857, 0.600000, 0.562500, + 0.529412, 0.714286, 0.666667, 0.625000, 0.588235, 0.611111, 0.578947, + 0.550000, 0.523810, 0.666667, 0.631579, 0.600000, 0.571429, 0.722222, + 0.684211, 0.650000, 0.619048, 0.777778, 0.736842, 0.700000, 0.666667, + 0.833333, 0.789474, 0.750000, 0.714286, 0.727273, 0.695652, 0.666667, + 0.64, 0.772727, 0.739130, 0.708333, 0.68, 0.818182, 0.782609, + 0.750000, 0.72, 0.863636, 0.826087, 0.791667, 0.76, 0.909091, + 0.869565, 0.833333, 0.80, 0.807692, 0.777778, 0.750000, 0.724138, + 0.846154, 0.814815, 0.785714, 0.758621, 0.884615, 0.851852, 0.821429, + 0.793103, 0.923077, 0.888889, 0.857143, 0.827586, 0.961538, 0.925926, + 0.892857, 0.862069, 0.866667, 0.838710, 0.812500, 0.787879, 0.900000, + 0.870968, 0.843750, 0.818182, 0.933333, 0.903226, 0.875000, 0.848485, + 0.966667, 0.935484, 0.906250, 0.878788, 1.000000, 0.967742, 0.937500, + 0.909091}, + sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest7) { - - auto x = NDArray('c', { 3, 5, 1 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 3, 1, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 3, 5, 4 }, { -9., -10., -11., -12.,-8., -9., -10., -11., -7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-8., -9., -10., -11.,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-7., -8.000000, -9.000000, -10.00,-6.000000, -7.000000, -8.000000, -9.000,-5.000000, -6.000000, -7.000000, -8.000,-4.000000, -5.000000, -6.000000, -7.000,-3.000000, -4.000000, -5.000000, -6.000 }, sd::DataType::FLOAT32); - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z); - ASSERT_EQ(e, z); + auto x = NDArray('c', {3, 5, 1}, sd::DataType::FLOAT32); + auto y = NDArray('c', {3, 1, 4}, sd::DataType::FLOAT32); + auto z = NDArray('c', {3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = + NDArray('c', {3, 5, 4}, + {-9., -10., -11., -12., -8., -9., + -10., -11., -7., -8., -9., -10., + -6., -7., -8., -9., -5., -6., + -7., -8., -8., -9., -10., -11., + -7., -8., -9., -10., -6., -7., + -8., -9., -5., -6., -7., -8., + -4., -5., -6., -7., -7., -8.000000, + -9.000000, -10.00, -6.000000, -7.000000, -8.000000, -9.000, + -5.000000, -6.000000, -7.000000, -8.000, -4.000000, -5.000000, + -6.000000, -7.000, -3.000000, -4.000000, -5.000000, -6.000}, + sd::DataType::FLOAT32); + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest8) { - - auto x = NDArray('c', { 2, 3, 5, 1 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 2, 3, 1, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 2, 3, 5, 4 }, { -9.0, -10., -11., -12.,-8., -9., -10., -11.0,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-8., -9., -10., -11.,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-3., -4., -5., -6.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-3., -4., -5., -6.,-2., -3., -4., -5.,-5., -6., -7., -8.,-4., -5., -6., -7.,-3., -4., -5., -6.,-2., -3., -4., -5.,-1., -2., -3., -4.,-4., -5., -6., -7.,-3., -4., -5., -6.,-2., -3., -4., -5.,-1., -2., -3., -4., 0., -1., -2., -3. }, sd::DataType::FLOAT32); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z); - ASSERT_EQ(e, z); + auto x = NDArray('c', {2, 3, 5, 1}, sd::DataType::FLOAT32); + auto y = NDArray('c', {2, 3, 1, 4}, sd::DataType::FLOAT32); + auto z = NDArray('c', {2, 3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray( + 'c', {2, 3, 5, 4}, + {-9.0, -10., -11., -12., -8., -9., -10., -11.0, -7., -8., -9., -10., + -6., -7., -8., -9., -5., -6., -7., -8., -8., -9., -10., -11., + -7., -8., -9., -10., -6., -7., -8., -9., -5., -6., -7., -8., + -4., -5., -6., -7., -7., -8., -9., -10., -6., -7., -8., -9., + -5., -6., -7., -8., -4., -5., -6., -7., -3., -4., -5., -6., + -6., -7., -8., -9., -5., -6., -7., -8., -4., -5., -6., -7., + -3., -4., -5., -6., -2., -3., -4., -5., -5., -6., -7., -8., + -4., -5., -6., -7., -3., -4., -5., -6., -2., -3., -4., -5., + -1., -2., -3., -4., -4., -5., -6., -7., -3., -4., -5., -6., + -2., -3., -4., -5., -1., -2., -3., -4., 0., -1., -2., -3.}, + sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z); + ASSERT_EQ(e, z); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test1) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create( + 'f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.}); - auto x = NDArrayFactory::create('c', {3, 4}); - auto y = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test2) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto y = NDArrayFactory::create('f', {4, 3}); + auto exp = NDArrayFactory::create( + 'f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.}); - auto x = NDArrayFactory::create('c', {3, 4}); - auto y = NDArrayFactory::create('f', {4, 3}); - auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test3) { + auto x = NDArrayFactory::create('f', {3, 4}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create( + 'f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.}); - auto x = NDArrayFactory::create('f', {3, 4}); - auto y = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test4) { + auto x = NDArrayFactory::create('f', {3, 4}); + auto y = NDArrayFactory::create('f', {4, 3}); + auto exp = NDArrayFactory::create( + 'f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.}); - auto x = NDArrayFactory::create ('f', {3, 4}); - auto y = NDArrayFactory::create('f', {4, 3}); - auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test5) { + auto x = NDArrayFactory::create('c', {4, 3}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create( + 'f', {3, 3}, {83., 94., 105., 94., 107., 120., 105., 120., 135.}); - auto x = NDArrayFactory::create('c', {4, 3}); - auto y = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('f', {3, 3}, {83., 94., 105., 94., 107., 120., 105., 120., 135.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test6) { + auto x = NDArrayFactory::create('c', {4, 3}); + auto y = NDArrayFactory::create('f', {3, 4}); + auto exp = NDArrayFactory::create( + 'f', {3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.}); - auto x = NDArrayFactory::create('c', {4, 3}); - auto y = NDArrayFactory::create('f', {3, 4}); - auto exp = NDArrayFactory::create('f', {3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test7) { - - auto x = NDArrayFactory::create('c', {5, 3,4}); - auto y = NDArrayFactory::create('f', {5, 3,4}); - auto exp = NDArrayFactory::create('f',{5, 3,3}, {3. , 84.6, 281.4, 593.4, 1020.6, 7. , 107.8, 323.8, 655. , 1101.4,11. , 131. , 366.2, 716.6, 1182.2, - 7. , 107.8, 323.8, 655. , 1101.4,17.4, 137.4, 372.6, 723. , 1188.6,27.8, 167. , 421.4, 791. , 1275.8, - 11. , 131. , 366.2, 716.6, 1182.2,27.8, 167. , 421.4, 791. , 1275.8,44.6, 203. , 476.6, 865.4, 1369.4,}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {0, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {5, 3, 4}); + auto y = NDArrayFactory::create('f', {5, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {5, 3, 3}, + { + 3., 84.6, 281.4, 593.4, 1020.6, 7., 107.8, 323.8, 655., 1101.4, + 11., 131., 366.2, 716.6, 1182.2, 7., 107.8, 323.8, 655., 1101.4, + 17.4, 137.4, 372.6, 723., 1188.6, 27.8, 167., 421.4, 791., 1275.8, + 11., 131., 366.2, 716.6, 1182.2, 27.8, 167., 421.4, 791., 1275.8, + 44.6, 203., 476.6, 865.4, 1369.4, + }); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test8) { - - auto x = NDArrayFactory::create('c', {2,5, 3,4}); - auto y = NDArrayFactory::create('f', {2,5, 3,4}); - auto exp = NDArrayFactory::create('f',{2,5, 3,3}, {3. , 1563. , 84.6, 2220.6, 281.4, 2993.4, 593.4, 3881.4,1020.6, 4884.6, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4, - 11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4, - 17.4, 1769.4, 137.4, 2465.4, 372.6, 3276.6, 723. , 4203. ,1188.6, 5244.6, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8, - 11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8, - 44.6, 1988.6, 203. , 2723. , 476.6, 3572.6, 865.4, 4537.4,1369.4, 5617.4}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {0, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {2, 5, 3, 4}); + auto y = NDArrayFactory::create('f', {2, 5, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {2, 5, 3, 3}, + {3., 1563., 84.6, 2220.6, 281.4, 2993.4, 593.4, 3881.4, 1020.6, + 4884.6, 7., 1663., 107.8, 2339.8, 323.8, 3131.8, 655., 4039., + 1101.4, 5061.4, 11., 1763., 131., 2459., 366.2, 3270.2, 716.6, + 4196.6, 1182.2, 5238.2, 7., 1663., 107.8, 2339.8, 323.8, 3131.8, + 655., 4039., 1101.4, 5061.4, 17.4, 1769.4, 137.4, 2465.4, 372.6, + 3276.6, 723., 4203., 1188.6, 5244.6, 27.8, 1875.8, 167., 2591., + 421.4, 3421.4, 791., 4367., 1275.8, 5427.8, 11., 1763., 131., + 2459., 366.2, 3270.2, 716.6, 4196.6, 1182.2, 5238.2, 27.8, 1875.8, + 167., 2591., 421.4, 3421.4, 791., 4367., 1275.8, 5427.8, 44.6, + 1988.6, 203., 2723., 476.6, 3572.6, 865.4, 4537.4, 1369.4, 5617.4}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test9) { - - auto x = NDArrayFactory::create('c', {2,5, 4,3}); - auto y = NDArrayFactory::create('f', {2,5, 3,4}); - auto exp = NDArrayFactory::create('f',{2,5, 3,3}, {7. , 1639. , 103. , 2311. , 314.2, 3098.2, 640.6, 4000.6,1082.2, 5018.2, 8. , 1664. , 108.8, 2340.8, 324.8, 3132.8, 656. , 4040. ,1102.4, 5062.4, - 9. , 1689. , 114.6, 2370.6, 335.4, 3167.4, 671.4, 4079.4,1122.6, 5106.6, 15.8, 1743.8, 131. , 2435. , 361.4, 3241.4, 707. , 4163. ,1167.8, 5199.8, - 18.4, 1770.4, 138.4, 2466.4, 373.6, 3277.6, 724. , 4204. ,1189.6, 5245.6, 21. , 1797. , 145.8, 2497.8, 385.8, 3313.8, 741. , 4245. ,1211.4, 5291.4, - 24.6, 1848.6, 159. , 2559. , 408.6, 3384.6, 773.4, 4325.4,1253.4, 5381.4, 28.8, 1876.8, 168. , 2592. , 422.4, 3422.4, 792. , 4368. ,1276.8, 5428.8, - 33. , 1905. , 177. , 2625. , 436.2, 3460.2, 810.6, 4410.6,1300.2, 5476.2}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {2, 5, 4, 3}); + auto y = NDArrayFactory::create('f', {2, 5, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {2, 5, 3, 3}, + {7., 1639., 103., 2311., 314.2, 3098.2, 640.6, 4000.6, 1082.2, + 5018.2, 8., 1664., 108.8, 2340.8, 324.8, 3132.8, 656., 4040., + 1102.4, 5062.4, 9., 1689., 114.6, 2370.6, 335.4, 3167.4, 671.4, + 4079.4, 1122.6, 5106.6, 15.8, 1743.8, 131., 2435., 361.4, 3241.4, + 707., 4163., 1167.8, 5199.8, 18.4, 1770.4, 138.4, 2466.4, 373.6, + 3277.6, 724., 4204., 1189.6, 5245.6, 21., 1797., 145.8, 2497.8, + 385.8, 3313.8, 741., 4245., 1211.4, 5291.4, 24.6, 1848.6, 159., + 2559., 408.6, 3384.6, 773.4, 4325.4, 1253.4, 5381.4, 28.8, 1876.8, + 168., 2592., 422.4, 3422.4, 792., 4368., 1276.8, 5428.8, 33., + 1905., 177., 2625., 436.2, 3460.2, 810.6, 4410.6, 1300.2, 5476.2}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, matmul_test10) { + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {3, 5}); - x.linspace(1); + auto y = NDArrayFactory::create('c', {5, 3}); + y.linspace(1); - auto y = NDArrayFactory::create('c', {5, 3}); - y.linspace(1); + float _expB[]{135.0f, 310.0f, 485.0f, 150.0f, 350.0f, + 550.0f, 165.0f, 390.0f, 615.0f}; + Nd4jLong _expS[]{2, 3, 3, 1, 3, 0, 1, 102}; // expected shape + ArrayOptions::setDataType(_expS, sd::DataType::FLOAT32); + NDArray exp(_expB, _expS); - float _expB[]{135.0f, 310.0f, 485.0f, 150.0f, 350.0f, 550.0f, 165.0f, 390.0f, 615.0f}; - Nd4jLong _expS[] {2, 3, 3, 1, 3, 0, 1, 102}; // expected shape - ArrayOptions::setDataType(_expS, sd::DataType::FLOAT32); - NDArray exp(_expB, _expS); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + variableSpace->putVariable(1, std::make_shared()); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(-2, y); - variableSpace->putVariable(1, std::make_shared()); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1, -2}); + sd::ops::matmul op; - sd::ops::matmul op; + Nd4jStatus status = op.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(variableSpace->hasVariable(1)); - Nd4jStatus status = op.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(variableSpace->hasVariable(1)); + auto result = variableSpace->getVariable(1)->getNDArray(); - auto result = variableSpace->getVariable(1)->getNDArray(); + ASSERT_TRUE(result->equalsTo(exp)); - ASSERT_TRUE(result->equalsTo(exp)); - - delete block; - delete variableSpace; + delete block; + delete variableSpace; } TEST_F(DeclarableOpsTests14, matmul_test11) { - auto A = NDArrayFactory::create('c', {3, 3}); - auto B = NDArrayFactory::create('c', {3, 1}); - auto exp = NDArrayFactory::create('c', {3, 1}, {14.00f, 32.00f, 50.00f}); - - A.linspace(1); - B.linspace(1); - - sd::ops::matmul op; + auto A = NDArrayFactory::create('c', {3, 3}); + auto B = NDArrayFactory::create('c', {3, 1}); + auto exp = + NDArrayFactory::create('c', {3, 1}, {14.00f, 32.00f, 50.00f}); - auto result = op.evaluate({&A, &B}, {}, {}); + A.linspace(1); + B.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::matmul op; - auto z = result.at(0); + auto result = op.evaluate({&A, &B}, {}, {}); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, matmul_test12) { - auto x= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); - auto y= NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); - auto exp= NDArrayFactory::create('f', {4, 4}, {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, 200.0, 173.0, 206.0, 239.0, 272.0}); - - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto y = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create( + 'f', {4, 4}, + {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, + 200.0, 173.0, 206.0, 239.0, 272.0}); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {1, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests14, matmul_test13) { - auto x= NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); - - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {1, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {3, 4}, + {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); - auto z = result.at(0); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - //z->printIndexedBuffer("z"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, matmul_test14) { - auto x= NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); - auto y= NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); - - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + auto y = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {3, 4}, + {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); - auto z = result.at(0); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {0, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - //z->printIndexedBuffer("z"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, matmul_test15) { - auto x= NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); - auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); - - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {3, 4}, + {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); - //z->printIndexedBuffer("z"); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, matmul_test16) { - auto x= NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp= NDArrayFactory::create('f', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); - - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {4, 4}, {1, 2, 3, 4, 2, 4, 6, 8, 3, 6, 9, 12, 4, 8, 12, 16}); - auto z = result.at(0); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - //z->printIndexedBuffer("z"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, matmul_test17) { - auto x = NDArrayFactory::create('c', {1, 2}, {2.0f, 2.0f}); - auto y = NDArrayFactory::create('c', {2, 1}, {2.0f, 2.0f}); - auto exp = NDArrayFactory::create('c', {1, 1}, {8.0f}); - - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_EQ(exp, result.at(0)); + auto x = NDArrayFactory::create('c', {1, 2}, {2.0f, 2.0f}); + auto y = NDArrayFactory::create('c', {2, 1}, {2.0f, 2.0f}); + auto exp = NDArrayFactory::create('c', {1, 1}, {8.0f}); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(exp, result.at(0)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test18) { + auto x = NDArrayFactory::create('c', {1, 4, 3}); + auto y = NDArrayFactory::create('f', {1, 3, 4}); + auto exp = NDArrayFactory::create( + 'f', {1, 3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.}); - auto x = NDArrayFactory::create('c', {1, 4, 3}); - auto y = NDArrayFactory::create('f', {1, 3, 4}); - auto exp = NDArrayFactory::create('f', {1, 3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test19) { + auto x = NDArrayFactory::create('c', {4, 1}); + auto y = NDArrayFactory::create('f', {1, 4}); + auto exp = NDArrayFactory::create('f', {1, 1}, {15}); - auto x = NDArrayFactory::create('c', {4, 1}); - auto y = NDArrayFactory::create('f', {1, 4}); - auto exp = NDArrayFactory::create('f', {1, 1}, {15}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - ASSERT_EQ(Status::OK(), results.status()); - - auto z = results.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + ASSERT_EQ(Status::OK(), results.status()); + auto z = results.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test20) { + auto x = NDArrayFactory::create('c', {1, 4, 1}); + auto y = NDArrayFactory::create('f', {1, 1, 4}); + auto exp = NDArrayFactory::create('f', {1, 1, 1}, {15}); - auto x = NDArrayFactory::create('c', {1, 4, 1}); - auto y = NDArrayFactory::create('f', {1, 1, 4}); - auto exp = NDArrayFactory::create('f', {1, 1, 1}, {15}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - - ASSERT_EQ(Status::OK(), results.status()); - auto z = results.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + ASSERT_EQ(Status::OK(), results.status()); + auto z = results.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test21) { + auto x = NDArrayFactory::create('c', {2, 3}); + auto y = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create( + 'f', {5, 2}, {23., 26., 29., 32., 35., 50., 57.5, 65., 72.5, 80.}); - auto x = NDArrayFactory::create('c', {2, 3}); - auto y = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('f', {5, 2}, {23. , 26. , 29. , 32. , 35., 50. , 57.5, 65. , 72.5, 80.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {0, 0, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 0, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test22) { + auto x = NDArrayFactory::create('c', {3, 2}); + auto y = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create( + 'f', {5, 2}, {37., 41.5, 46., 50.5, 55., 46., 52., 58., 64., 70.}); - auto x = NDArrayFactory::create('c', {3, 2}); - auto y = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 0, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test23) { + auto x = NDArrayFactory::create('c', {3, 2}); + auto y = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create( + 'f', {5, 2}, {37., 41.5, 46., 50.5, 55., 46., 52., 58., 64., 70.}); - auto x = NDArrayFactory::create('c', {3, 2}); - auto y = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.}); - - x.linspace(1.); - y.linspace(0.5, 0.5); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 0, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test24) { - - auto x = NDArrayFactory::create('c', {2,2, 3,5}); - auto y = NDArrayFactory::create('c', {2,2, 4,3}); - auto exp = NDArrayFactory::create('f',{2,2, 4,5}, {4.6, 281.8, 89.2, 582.4, 10. , 314.2,108.1, 628.3, 15.4, 346.6,127. , 674.2, 20.8, 379. ,145.9, 720.1, 5.2, 289.6, 93.4, 593.8, - 11.5, 322.9,113.2, 640.6, 17.8, 356.2,133. , 687.4, 24.1, 389.5,152.8, 734.2, 5.8, 297.4, 97.6, 605.2, 13. , 331.6,118.3, 652.9, - 20.2, 365.8,139. , 700.6, 27.4, 400. ,159.7, 748.3, 6.4, 305.2,101.8, 616.6, 14.5, 340.3,123.4, 665.2, 22.6, 375.4,145. , 713.8, - 30.7, 410.5,166.6, 762.4, 7. , 313. ,106. , 628. , 16. , 349. ,128.5, 677.5, 25. , 385. ,151. , 727. , 34. , 421. ,173.5, 776.5}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {2, 2, 3, 5}); + auto y = NDArrayFactory::create('c', {2, 2, 4, 3}); + auto exp = NDArrayFactory::create( + 'f', {2, 2, 4, 5}, + {4.6, 281.8, 89.2, 582.4, 10., 314.2, 108.1, 628.3, 15.4, 346.6, + 127., 674.2, 20.8, 379., 145.9, 720.1, 5.2, 289.6, 93.4, 593.8, + 11.5, 322.9, 113.2, 640.6, 17.8, 356.2, 133., 687.4, 24.1, 389.5, + 152.8, 734.2, 5.8, 297.4, 97.6, 605.2, 13., 331.6, 118.3, 652.9, + 20.2, 365.8, 139., 700.6, 27.4, 400., 159.7, 748.3, 6.4, 305.2, + 101.8, 616.6, 14.5, 340.3, 123.4, 665.2, 22.6, 375.4, 145., 713.8, + 30.7, 410.5, 166.6, 762.4, 7., 313., 106., 628., 16., 349., + 128.5, 677.5, 25., 385., 151., 727., 34., 421., 173.5, 776.5}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test25) { + auto x = NDArrayFactory::create('f', {4, 3}); + auto y = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create('f', {3}, {7., 8., 9.}); - auto x = NDArrayFactory::create('f', {4, 3}); - auto y = NDArrayFactory::create('c', {4}); - auto exp = NDArrayFactory::create('f',{3}, {7., 8., 9.}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 0}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.1, 0.1); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test26) { + auto x = NDArrayFactory::create('f', {3}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create('f', {4}, {1.4, 3.2, 5., 6.8}); - auto x = NDArrayFactory::create('f', {3}); - auto y = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('f',{4}, {1.4, 3.2, 5., 6.8}); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {0, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.1, 0.1); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test27) { + auto x = NDArrayFactory::create('f', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('f', {1, 1}, {0.2}); - auto x = NDArrayFactory::create('f', {1, 1}); - auto y = NDArrayFactory::create('c', {1, 1}); - auto exp = NDArrayFactory::create('f',{1, 1}, {0.2}); - - x.linspace(2.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(2.); + y.linspace(0.1, 0.1); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test28) { + auto x = NDArrayFactory::create('f', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('f', {1, 1}, {0.2}); - auto x = NDArrayFactory::create('f', {1, 1}); - auto y = NDArrayFactory::create('c', {1, 1}); - auto exp = NDArrayFactory::create('f',{1, 1}, {0.2}); - - x.linspace(2.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1,1,1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(2.); + y.linspace(0.1, 0.1); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test29) { + auto x = NDArrayFactory::create('f', {1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('f', {1}, {0.2}); - auto x = NDArrayFactory::create('f', {1}); - auto y = NDArrayFactory::create('c', {1, 1}); - auto exp = NDArrayFactory::create('f',{1}, {0.2}); - - x.linspace(2.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(2.); + y.linspace(0.1, 0.1); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test30) { + auto x = NDArrayFactory::create('f', {1, 1}); + auto y = NDArrayFactory::create('c', {1}); + auto exp = NDArrayFactory::create('f', {1}, {0.2}); - auto x = NDArrayFactory::create('f', {1,1}); - auto y = NDArrayFactory::create('c', {1}); - auto exp = NDArrayFactory::create('f',{1}, {0.2}); - - x.linspace(2.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(2.); + y.linspace(0.1, 0.1); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test31) { + auto x = NDArrayFactory::create('f', {4}); + auto y = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create(3.); - auto x = NDArrayFactory::create('f', {4}); - auto y = NDArrayFactory::create('c', {4}); - auto exp = NDArrayFactory::create(3.); - - x.linspace(1.); - y.linspace(0.1, 0.1); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.); + y.linspace(0.1, 0.1); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test32) { + auto x = NDArrayFactory::create('f', {1}, {2.}); + auto y = NDArrayFactory::create('c', {1}, {3.}); + auto exp = NDArrayFactory::create(6.); - auto x = NDArrayFactory::create('f', {1}, {2.}); - auto y = NDArrayFactory::create('c', {1}, {3.}); - auto exp = NDArrayFactory::create(6.); - - sd::ops::matmul op; - auto results = op.evaluate({&x, &y}, {}, {1, 1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ///////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test33) { - auto x = NDArrayFactory::create('c', {4, 3}); - auto y = NDArrayFactory::create('c', {4, 1}); - auto exp = NDArrayFactory::create('c',{ 3, 1}, {70, 80, 90}); - - x.linspace(1); - y.linspace(1); - - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {1, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {4, 3}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto exp = NDArrayFactory::create('c', {3, 1}, {70, 80, 90}); - auto z = result.at(0); + x.linspace(1); + y.linspace(1); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test34) { - auto a = NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto b = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {30, 70, 110}); + auto a = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto b = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {30, 70, 110}); - sd::ops::matmul op; - auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::matmul op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ///////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test35) { - auto a = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - auto b = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {3}, {70, 80, 90}); - - sd::ops::matmul op; - auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto a = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto b = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create('c', {3}, {70, 80, 90}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::matmul op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test36) { - auto a = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto b = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {1, 3}, {70, 80, 90}); - - sd::ops::matmul op; - auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto a = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto b = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create('c', {1, 3}, {70, 80, 90}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::matmul op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, matmul_test37) { + NDArray a('c', {32, 12, 128, 64}, sd::DataType::FLOAT32); + NDArray b('c', {32, 12, 128, 64}, sd::DataType::FLOAT32); + NDArray c('c', {32, 12, 128, 128}, sd::DataType::FLOAT32); + NDArray cExp('c', {32, 12, 128, 128}, sd::DataType::FLOAT32); - NDArray a('c', {32, 12, 128, 64}, sd::DataType::FLOAT32); - NDArray b('c', {32, 12, 128, 64}, sd::DataType::FLOAT32); - NDArray c('c', {32,12,128,128}, sd::DataType::FLOAT32); - NDArray cExp('c', {32,12,128,128}, sd::DataType::FLOAT32); - - a = 1; - b = 1; - cExp = 64; //Each entry in output c is sum of 64 (1.0 x 1.0) multiplications + a = 1; + b = 1; + cExp = 64; // Each entry in output c is sum of 64 (1.0 x 1.0) multiplications - sd::ops::matmul op; - auto status = op.execute({&a, &b}, {&c}, {}, {0,1}); + sd::ops::matmul op; + auto status = op.execute({&a, &b}, {&c}, {}, {0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(cExp.isSameShape(c)); - ASSERT_TRUE(cExp.equalsTo(c)); + ASSERT_TRUE(cExp.isSameShape(c)); + ASSERT_TRUE(cExp.equalsTo(c)); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_3D_1) { - - // x[4, 12, 128] * y[4, 128] = z[4, 12, 128] - - auto x = NDArray('c', { 2, 3, 5 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 2, 5 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 2, 3, 5 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 2, 3, 5 }, { 10.000000, 22.000000, 36.000000, 52.000000, 70.000000, 60.000000, 77.000000, 96.000000, 117.000000, 140.000000, 110.000000, 132.000000, 156.000000, 182.000000, 210.000000, 240.000000, 272.000000, 306.000000, 342.000000, 380.000000, 315.000000, 352.000000, 391.000000, 432.000000, 475.000000, 390.000000, 432.000000, 476.000000, 522.000000, 570.000000 }, sd::DataType::FLOAT32); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyBroadcast(sd::broadcast::Multiply, { 0,2 }, y, z); - //z.printBuffer(); - ASSERT_EQ(e, z); + // x[4, 12, 128] * y[4, 128] = z[4, 12, 128] + + auto x = NDArray('c', {2, 3, 5}, sd::DataType::FLOAT32); + auto y = NDArray('c', {2, 5}, sd::DataType::FLOAT32); + auto z = NDArray('c', {2, 3, 5}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray( + 'c', {2, 3, 5}, + {10.000000, 22.000000, 36.000000, 52.000000, 70.000000, 60.000000, + 77.000000, 96.000000, 117.000000, 140.000000, 110.000000, 132.000000, + 156.000000, 182.000000, 210.000000, 240.000000, 272.000000, 306.000000, + 342.000000, 380.000000, 315.000000, 352.000000, 391.000000, 432.000000, + 475.000000, 390.000000, 432.000000, 476.000000, 522.000000, 570.000000}, + sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Multiply, {0, 2}, y, z); + // z.printBuffer(); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_3D_2) { + auto x = NDArray('f', {2, 3, 5}, sd::DataType::FLOAT32); + auto y = NDArray('f', {2, 5}, sd::DataType::FLOAT32); + auto z = NDArray('f', {2, 3, 5}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = + NDArray('c', {2, 3, 5}, + {0.100000, 0.181818, 0.250000, 0.307692, 0.357143, 0.600000, + 0.636364, 0.666667, 0.692308, 0.714286, 1.100000, 1.090909, + 1.083333, 1.076923, 1.071429, 1.066667, 1.062500, 1.058824, + 1.055556, 1.052632, 1.400000, 1.375000, 1.352941, 1.333333, + 1.315789, 1.733333, 1.687500, 1.647059, 1.611111, 1.578947}, + sd::DataType::FLOAT32); - auto x = NDArray('f', { 2, 3, 5 }, sd::DataType::FLOAT32); - auto y = NDArray('f', { 2, 5 }, sd::DataType::FLOAT32); - auto z = NDArray('f', { 2, 3, 5 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto eC = NDArray('c', { 2, 3, 5 }, { 0.100000, 0.181818, 0.250000, 0.307692, 0.357143, 0.600000, 0.636364, 0.666667, 0.692308, 0.714286, 1.100000, 1.090909, 1.083333, 1.076923, 1.071429, 1.066667, 1.062500, 1.058824, 1.055556, 1.052632, 1.400000, 1.375000, 1.352941, 1.333333, 1.315789, 1.733333, 1.687500, 1.647059, 1.611111, 1.578947 }, sd::DataType::FLOAT32); - - auto e = NDArray('f', { 2, 3, 5 }, sd::DataType::FLOAT32); + auto e = NDArray('f', {2, 3, 5}, sd::DataType::FLOAT32); - e.assign(eC); + e.assign(eC); - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); - x.applyBroadcast(sd::broadcast::Divide, { 0,2 }, y, z); + x.applyBroadcast(sd::broadcast::Divide, {0, 2}, y, z); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_4D_1) { - - auto x = NDArray('c', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 2, 5, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 2, 3, 5, 4 }, { 10.000000, 22.000000, 36.000000, 52.000000, 70.000000, 90.000000, 112.000000, 136.000000, 162.000000, 190.000000, 220.000000, 252.000000, 286.000000, 322.000000, 360.000000, 400.000000, 442.000000, 486.000000, 532.000000, 580.000000, 210.000000, 242.000000, 276.000000, 312.000000, 350.000000, 390.000000, 432.000000, 476.000000, 522.000000, 570.000000, 620.000000, 672.000000, 726.000000, 782.000000, 840.000000, 900.000000, 962.000000, 1026.000000, 1092.000000, 1160.000000, 410.000000, 462.000000, 516.000000, 572.000000, 630.000000, 690.000000, 752.000000, 816.000000, 882.000000, 950.000000, 1020.000000, 1092.000000, 1166.000000, 1242.000000, 1320.000000, 1400.000000, 1482.000000, 1566.000000, 1652.000000, 1740.000000, 1830.000000, 1922.000000, 2016.000000, 2112.000000, 2210.000000, 2310.000000, 2412.000000, 2516.000000, 2622.000000, 2730.000000, 2840.000000, 2952.000000, 3066.000000, 3182.000000, 3300.000000, 3420.000000, 3542.000000, 3666.000000, 3792.000000, 3920.000000, 2430.000000, 2542.000000, 2656.000000, 2772.000000, 2890.000000, 3010.000000, 3132.000000, 3256.000000, 3382.000000, 3510.000000, 3640.000000, 3772.000000, 3906.000000, 4042.000000, 4180.000000, 4320.000000, 4462.000000, 4606.000000, 4752.000000, 4900.000000, 3030.000000, 3162.000000, 3296.000000, 3432.000000, 3570.000000, 3710.000000, 3852.000000, 3996.000000, 4142.000000, 4290.000000, 4440.000000, 4592.000000, 4746.000000, 4902.000000, 5060.000000, 5220.000000, 5382.000000, 5546.000000, 5712.000000, 5880.000000 }, sd::DataType::FLOAT32); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyBroadcast(sd::broadcast::Multiply, { 0,2,3 }, y, z); - - ASSERT_EQ(e, z); + auto x = NDArray('c', {2, 3, 5, 4}, sd::DataType::FLOAT32); + auto y = NDArray('c', {2, 5, 4}, sd::DataType::FLOAT32); + auto z = NDArray('c', {2, 3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = + NDArray('c', {2, 3, 5, 4}, + {10.000000, 22.000000, 36.000000, 52.000000, 70.000000, + 90.000000, 112.000000, 136.000000, 162.000000, 190.000000, + 220.000000, 252.000000, 286.000000, 322.000000, 360.000000, + 400.000000, 442.000000, 486.000000, 532.000000, 580.000000, + 210.000000, 242.000000, 276.000000, 312.000000, 350.000000, + 390.000000, 432.000000, 476.000000, 522.000000, 570.000000, + 620.000000, 672.000000, 726.000000, 782.000000, 840.000000, + 900.000000, 962.000000, 1026.000000, 1092.000000, 1160.000000, + 410.000000, 462.000000, 516.000000, 572.000000, 630.000000, + 690.000000, 752.000000, 816.000000, 882.000000, 950.000000, + 1020.000000, 1092.000000, 1166.000000, 1242.000000, 1320.000000, + 1400.000000, 1482.000000, 1566.000000, 1652.000000, 1740.000000, + 1830.000000, 1922.000000, 2016.000000, 2112.000000, 2210.000000, + 2310.000000, 2412.000000, 2516.000000, 2622.000000, 2730.000000, + 2840.000000, 2952.000000, 3066.000000, 3182.000000, 3300.000000, + 3420.000000, 3542.000000, 3666.000000, 3792.000000, 3920.000000, + 2430.000000, 2542.000000, 2656.000000, 2772.000000, 2890.000000, + 3010.000000, 3132.000000, 3256.000000, 3382.000000, 3510.000000, + 3640.000000, 3772.000000, 3906.000000, 4042.000000, 4180.000000, + 4320.000000, 4462.000000, 4606.000000, 4752.000000, 4900.000000, + 3030.000000, 3162.000000, 3296.000000, 3432.000000, 3570.000000, + 3710.000000, 3852.000000, 3996.000000, 4142.000000, 4290.000000, + 4440.000000, 4592.000000, 4746.000000, 4902.000000, 5060.000000, + 5220.000000, 5382.000000, 5546.000000, 5712.000000, 5880.000000}, + sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Multiply, {0, 2, 3}, y, z); + + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_4D_2) { - - auto x = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - auto y = NDArray('f', { 2, 5, 4 }, sd::DataType::FLOAT32); - auto z = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto eC = NDArray('c', { 2, 3, 5, 4 }, { 0.100000,0.181818,0.250000,0.307692,0.357143,0.400000,0.437500,0.470588,0.500000,0.526316,0.550000,0.571429, 0.590909,0.608696,0.625000,0.640000, 0.653846,0.666667,0.678571,0.689655, 2.100000,2.000000,1.916667, 1.846154, 1.785714, 1.733333,1.687500, 1.647059,1.611111, 1.578947,1.550000, 1.523810,1.500000, 1.478261,1.458333, 1.440000,1.423077, 1.407407,1.392857, 1.379310,4.100000, 3.818182,3.583333, 3.384615, 3.214286, 3.066667,2.937500, 2.823529,2.722222, 2.631579,2.550000, 2.476191,2.409091, 2.347826,2.291667, 2.240000,2.192308, 2.148148,2.107143, 2.068965,2.033333, 2.000000,1.968750, 1.939394,1.911765, 1.885714,1.861111, 1.837838,1.815789, 1.794872,1.775000, 1.756098,1.738095, 1.720930,1.704545, 1.688889,1.673913, 1.659575,1.645833,1.632653,2.700000,2.645161,2.593750,2.545455,2.500000,2.457143,2.416667,2.378378,2.342105,2.307692,2.275000,2.243902,2.214286,2.186047,2.159091,2.133333,2.108696,2.085106,2.062500,2.040816,3.366667,3.290323,3.218750,3.151515,3.088235,3.028571,2.972222,2.918919,2.868421,2.820513,2.775000,2.731707,2.690476,2.651163,2.613636,2.577778,2.543478,2.510638,2.479167,2.448980 }, sd::DataType::FLOAT32); - - auto e = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - - e.assign(eC); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyBroadcast(sd::broadcast::Divide, { 0,2,3 }, y, z); - - ASSERT_EQ(e, z); + auto x = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + auto y = NDArray('f', {2, 5, 4}, sd::DataType::FLOAT32); + auto z = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray( + 'c', {2, 3, 5, 4}, + {0.100000, 0.181818, 0.250000, 0.307692, 0.357143, 0.400000, 0.437500, + 0.470588, 0.500000, 0.526316, 0.550000, 0.571429, 0.590909, 0.608696, + 0.625000, 0.640000, 0.653846, 0.666667, 0.678571, 0.689655, 2.100000, + 2.000000, 1.916667, 1.846154, 1.785714, 1.733333, 1.687500, 1.647059, + 1.611111, 1.578947, 1.550000, 1.523810, 1.500000, 1.478261, 1.458333, + 1.440000, 1.423077, 1.407407, 1.392857, 1.379310, 4.100000, 3.818182, + 3.583333, 3.384615, 3.214286, 3.066667, 2.937500, 2.823529, 2.722222, + 2.631579, 2.550000, 2.476191, 2.409091, 2.347826, 2.291667, 2.240000, + 2.192308, 2.148148, 2.107143, 2.068965, 2.033333, 2.000000, 1.968750, + 1.939394, 1.911765, 1.885714, 1.861111, 1.837838, 1.815789, 1.794872, + 1.775000, 1.756098, 1.738095, 1.720930, 1.704545, 1.688889, 1.673913, + 1.659575, 1.645833, 1.632653, 2.700000, 2.645161, 2.593750, 2.545455, + 2.500000, 2.457143, 2.416667, 2.378378, 2.342105, 2.307692, 2.275000, + 2.243902, 2.214286, 2.186047, 2.159091, 2.133333, 2.108696, 2.085106, + 2.062500, 2.040816, 3.366667, 3.290323, 3.218750, 3.151515, 3.088235, + 3.028571, 2.972222, 2.918919, 2.868421, 2.820513, 2.775000, 2.731707, + 2.690476, 2.651163, 2.613636, 2.577778, 2.543478, 2.510638, 2.479167, + 2.448980}, + sd::DataType::FLOAT32); + + auto e = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Divide, {0, 2, 3}, y, z); + + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_4D_3) { - - auto x = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - auto y = NDArray('f', { 2, 5 }, sd::DataType::FLOAT32); - auto z = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto eC = NDArray('c', { 2, 3, 5, 4 }, { 0.100000, 0.200000, 0.300000, 0.400000, 0.454545, 0.545455, 0.636364, 0.727273, 0.750000, 0.833333, 0.916667, 1.000000, 1.000000, 1.076923, 1.153846, 1.230769, 1.214286, 1.285714, 1.357143, 1.428571, 2.100000, 2.200000, 2.300000, 2.400000, 2.272727, 2.363636, 2.454545, 2.545455, 2.416667, 2.500000, 2.583333, 2.666667, 2.538461, 2.615385, 2.692308, 2.769231, 2.642857, 2.714286, 2.785714, 2.857143, 4.100000, 4.200000, 4.300000, 4.400000, 4.090909, 4.181818, 4.272727, 4.363636, 4.083333, 4.166667, 4.250000, 4.333333, 4.076923, 4.153846, 4.230769, 4.307693, 4.071429, 4.142857, 4.214286, 4.285714, 4.066667, 4.133333, 4.200000, 4.266667, 4.062500, 4.125000, 4.187500, 4.250000, 4.058824, 4.117647, 4.176471, 4.235294, 4.055555, 4.111111, 4.166667, 4.222222, 4.052631, 4.105263, 4.157895, 4.210526, 5.400000, 5.466667, 5.533333, 5.600000, 5.312500, 5.375000, 5.437500, 5.500000, 5.235294, 5.294117, 5.352941, 5.411765, 5.166667, 5.222222, 5.277778, 5.333333, 5.105263, 5.157895, 5.210526, 5.263158, 6.733333, 6.800000, 6.866667, 6.933333, 6.562500, 6.625000, 6.687500, 6.750000, 6.411765, 6.470588, 6.529412, 6.588235, 6.277778, 6.333333, 6.388889, 6.444445, 6.157895, 6.210526, 6.263158, 6.315790 }, sd::DataType::FLOAT32); - - auto e = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - - e.assign(eC); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyBroadcast(sd::broadcast::Divide, { 0,2 }, y, z); - - ASSERT_EQ(e, z); + auto x = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + auto y = NDArray('f', {2, 5}, sd::DataType::FLOAT32); + auto z = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray( + 'c', {2, 3, 5, 4}, + {0.100000, 0.200000, 0.300000, 0.400000, 0.454545, 0.545455, 0.636364, + 0.727273, 0.750000, 0.833333, 0.916667, 1.000000, 1.000000, 1.076923, + 1.153846, 1.230769, 1.214286, 1.285714, 1.357143, 1.428571, 2.100000, + 2.200000, 2.300000, 2.400000, 2.272727, 2.363636, 2.454545, 2.545455, + 2.416667, 2.500000, 2.583333, 2.666667, 2.538461, 2.615385, 2.692308, + 2.769231, 2.642857, 2.714286, 2.785714, 2.857143, 4.100000, 4.200000, + 4.300000, 4.400000, 4.090909, 4.181818, 4.272727, 4.363636, 4.083333, + 4.166667, 4.250000, 4.333333, 4.076923, 4.153846, 4.230769, 4.307693, + 4.071429, 4.142857, 4.214286, 4.285714, 4.066667, 4.133333, 4.200000, + 4.266667, 4.062500, 4.125000, 4.187500, 4.250000, 4.058824, 4.117647, + 4.176471, 4.235294, 4.055555, 4.111111, 4.166667, 4.222222, 4.052631, + 4.105263, 4.157895, 4.210526, 5.400000, 5.466667, 5.533333, 5.600000, + 5.312500, 5.375000, 5.437500, 5.500000, 5.235294, 5.294117, 5.352941, + 5.411765, 5.166667, 5.222222, 5.277778, 5.333333, 5.105263, 5.157895, + 5.210526, 5.263158, 6.733333, 6.800000, 6.866667, 6.933333, 6.562500, + 6.625000, 6.687500, 6.750000, 6.411765, 6.470588, 6.529412, 6.588235, + 6.277778, 6.333333, 6.388889, 6.444445, 6.157895, 6.210526, 6.263158, + 6.315790}, + sd::DataType::FLOAT32); + + auto e = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Divide, {0, 2}, y, z); + + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_4D_4) { - - // x[4, 12, 128, 128] * y[4, 1, 128, 1] = z[4, 12, 128, 128] - - auto x = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - auto y = NDArray('f', { 2, 1, 5, 1 }, sd::DataType::FLOAT32); - auto z = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto eC = NDArray('c', { 2, 3, 5, 4 }, { 0.100000, 0.200000, 0.300000, 0.400000, 0.454545, 0.545455, 0.636364, 0.727273, 0.750000, 0.833333, 0.916667, 1.000000, 1.000000, 1.076923, 1.153846, 1.230769, 1.214286, 1.285714, 1.357143, 1.428571, 2.100000, 2.200000, 2.300000, 2.400000, 2.272727, 2.363636, 2.454545, 2.545455, 2.416667, 2.500000, 2.583333, 2.666667, 2.538461, 2.615385, 2.692308, 2.769231, 2.642857, 2.714286, 2.785714, 2.857143, 4.100000, 4.200000, 4.300000, 4.400000, 4.090909, 4.181818, 4.272727, 4.363636, 4.083333, 4.166667, 4.250000, 4.333333, 4.076923, 4.153846, 4.230769, 4.307693, 4.071429, 4.142857, 4.214286, 4.285714, 4.066667, 4.133333, 4.200000, 4.266667, 4.062500, 4.125000, 4.187500, 4.250000, 4.058824, 4.117647, 4.176471, 4.235294, 4.055555, 4.111111, 4.166667, 4.222222, 4.052631, 4.105263, 4.157895, 4.210526, 5.400000, 5.466667, 5.533333, 5.600000, 5.312500, 5.375000, 5.437500, 5.500000, 5.235294, 5.294117, 5.352941, 5.411765, 5.166667, 5.222222, 5.277778, 5.333333, 5.105263, 5.157895, 5.210526, 5.263158, 6.733333, 6.800000, 6.866667, 6.933333, 6.562500, 6.625000, 6.687500, 6.750000, 6.411765, 6.470588, 6.529412, 6.588235, 6.277778, 6.333333, 6.388889, 6.444445, 6.157895, 6.210526, 6.263158, 6.315790 }, sd::DataType::FLOAT32); - - auto e = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); - e.assign(eC); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); - - ASSERT_EQ(e, z); + // x[4, 12, 128, 128] * y[4, 1, 128, 1] = z[4, 12, 128, 128] + + auto x = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + auto y = NDArray('f', {2, 1, 5, 1}, sd::DataType::FLOAT32); + auto z = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray( + 'c', {2, 3, 5, 4}, + {0.100000, 0.200000, 0.300000, 0.400000, 0.454545, 0.545455, 0.636364, + 0.727273, 0.750000, 0.833333, 0.916667, 1.000000, 1.000000, 1.076923, + 1.153846, 1.230769, 1.214286, 1.285714, 1.357143, 1.428571, 2.100000, + 2.200000, 2.300000, 2.400000, 2.272727, 2.363636, 2.454545, 2.545455, + 2.416667, 2.500000, 2.583333, 2.666667, 2.538461, 2.615385, 2.692308, + 2.769231, 2.642857, 2.714286, 2.785714, 2.857143, 4.100000, 4.200000, + 4.300000, 4.400000, 4.090909, 4.181818, 4.272727, 4.363636, 4.083333, + 4.166667, 4.250000, 4.333333, 4.076923, 4.153846, 4.230769, 4.307693, + 4.071429, 4.142857, 4.214286, 4.285714, 4.066667, 4.133333, 4.200000, + 4.266667, 4.062500, 4.125000, 4.187500, 4.250000, 4.058824, 4.117647, + 4.176471, 4.235294, 4.055555, 4.111111, 4.166667, 4.222222, 4.052631, + 4.105263, 4.157895, 4.210526, 5.400000, 5.466667, 5.533333, 5.600000, + 5.312500, 5.375000, 5.437500, 5.500000, 5.235294, 5.294117, 5.352941, + 5.411765, 5.166667, 5.222222, 5.277778, 5.333333, 5.105263, 5.157895, + 5.210526, 5.263158, 6.733333, 6.800000, 6.866667, 6.933333, 6.562500, + 6.625000, 6.687500, 6.750000, 6.411765, 6.470588, 6.529412, 6.588235, + 6.277778, 6.333333, 6.388889, 6.444445, 6.157895, 6.210526, 6.263158, + 6.315790}, + sd::DataType::FLOAT32); + + auto e = NDArray('f', {2, 3, 5, 4}, sd::DataType::FLOAT32); + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); + + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_5D_1) { - // x[4, 12, 128, 128, 128] * y[4, 1, 128, 128, 128] = z[4, 12, 128, 128, 128] - auto x = NDArray('c', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - auto y = NDArray('c', { 2, 1, 5, 4, 3 }, sd::DataType::FLOAT32); - auto z = NDArray('c', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto e = NDArray('c', { 2, 3, 5, 4, 3 }, { 10.000000, 22.000000, 36.000000, 52.000000, 70.000000, 90.000000, 112.000000, 136.000000, 162.000000, 190.000000, 220.000000, 252.000000, 286.000000, 322.000000, 360.000000, 400.000000, 442.000000, 486.000000, 532.000000, 580.000000, 630.000000, 682.000000, 736.000000, 792.000000, 850.000000, 910.000000, 972.000000, 1036.000000, 1102.000000, 1170.000000, 1240.000000, 1312.000000, 1386.000000, 1462.000000, 1540.000000, 1620.000000, 1702.000000, 1786.000000, 1872.000000, 1960.000000, 2050.000000, 2142.000000, 2236.000000, 2332.000000, 2430.000000, 2530.000000, 2632.000000, 2736.000000, 2842.000000, 2950.000000, 3060.000000, 3172.000000, 3286.000000, 3402.000000, 3520.000000, 3640.000000, 3762.000000, 3886.000000, 4012.000000, 4140.000000, 610.000000, 682.000000, 756.000000, 832.000000, 910.000000, 990.000000, 1072.000000, 1156.000000, 1242.000000, 1330.000000, 1420.000000, 1512.000000, 1606.000000, 1702.000000, 1800.000000, 1900.000000, 2002.000000, 2106.000000, 2212.000000, 2320.000000, 2430.000000, 2542.000000, 2656.000000, 2772.000000, 2890.000000, 3010.000000, 3132.000000, 3256.000000, 3382.000000, 3510.000000, 3640.000000, 3772.000000, 3906.000000, 4042.000000, 4180.000000, 4320.000000, 4462.000000, 4606.000000, 4752.000000, 4900.000000, 5050.000000, 5202.000000, 5356.000000, 5512.000000, 5670.000000, 5830.000000, 5992.000000, 6156.000000, 6322.000000, 6490.000000, 6660.000000, 6832.000000, 7006.000000, 7182.000000, 7360.000000, 7540.000000, 7722.000000, 7906.000000, 8092.000000, 8280.000000, 1210.000000, 1342.000000, 1476.000000, 1612.000000, 1750.000000, 1890.000000, 2032.000000, 2176.000000, 2322.000000, 2470.000000, 2620.000000, 2772.000000, 2926.000000, 3082.000000, 3240.000000, 3400.000000, 3562.000000, 3726.000000, 3892.000000, 4060.000000, 4230.000000, 4402.000000, 4576.000000, 4752.000000, 4930.000000, 5110.000000, 5292.000000, 5476.000000, 5662.000000, 5850.000000, 6040.000000, 6232.000000, 6426.000000, 6622.000000, 6820.000000, 7020.000000, 7222.000000, 7426.000000, 7632.000000, 7840.000000, 8050.000000, 8262.000000, 8476.000000, 8692.000000, 8910.000000, 9130.000000, 9352.000000, 9576.000000, 9802.000000, 10030.000000, 10260.000000, 10492.000000, 10726.000000, 10962.000000, 11200.000000, 11440.000000, 11682.000000, 11926.000000, 12172.000000, 12420.000000, 12670.000000, 12922.000000, 13176.000000, 13432.000000, 13690.000000, 13950.000000, 14212.000000, 14476.000000, 14742.000000, 15010.000000, 15280.000000, 15552.000000, 15826.000000, 16102.000000, 16380.000000, 16660.000000, 16942.000000, 17226.000000, 17512.000000, 17800.000000, 18090.000000, 18382.000000, 18676.000000, 18972.000000, 19270.000000, 19570.000000, 19872.000000, 20176.000000, 20482.000000, 20790.000000, 21100.000000, 21412.000000, 21726.000000, 22042.000000, 22360.000000, 22680.000000, 23002.000000, 23326.000000, 23652.000000, 23980.000000, 24310.000000, 24642.000000, 24976.000000, 25312.000000, 25650.000000, 25990.000000, 26332.000000, 26676.000000, 27022.000000, 27370.000000, 27720.000000, 28072.000000, 28426.000000, 28782.000000, 29140.000000, 29500.000000, 29862.000000, 30226.000000, 30592.000000, 30960.000000, 16870.000000, 17182.000000, 17496.000000, 17812.000000, 18130.000000, 18450.000000, 18772.000000, 19096.000000, 19422.000000, 19750.000000, 20080.000000, 20412.000000, 20746.000000, 21082.000000, 21420.000000, 21760.000000, 22102.000000, 22446.000000, 22792.000000, 23140.000000, 23490.000000, 23842.000000, 24196.000000, 24552.000000, 24910.000000, 25270.000000, 25632.000000, 25996.000000, 26362.000000, 26730.000000, 27100.000000, 27472.000000, 27846.000000, 28222.000000, 28600.000000, 28980.000000, 29362.000000, 29746.000000, 30132.000000, 30520.000000, 30910.000000, 31302.000000, 31696.000000, 32092.000000, 32490.000000, 32890.000000, 33292.000000, 33696.000000, 34102.000000, 34510.000000, 34920.000000, 35332.000000, 35746.000000, 36162.000000, 36580.000000, 37000.000000, 37422.000000, 37846.000000, 38272.000000, 38700.000000, 21070.000000, 21442.000000, 21816.000000, 22192.000000, 22570.000000, 22950.000000, 23332.000000, 23716.000000, 24102.000000, 24490.000000, 24880.000000, 25272.000000, 25666.000000, 26062.000000, 26460.000000, 26860.000000, 27262.000000, 27666.000000, 28072.000000, 28480.000000, 28890.000000, 29302.000000, 29716.000000, 30132.000000, 30550.000000, 30970.000000, 31392.000000, 31816.000000, 32242.000000, 32670.000000, 33100.000000, 33532.000000, 33966.000000, 34402.000000, 34840.000000, 35280.000000, 35722.000000, 36166.000000, 36612.000000, 37060.000000, 37510.000000, 37962.000000, 38416.000000, 38872.000000, 39330.000000, 39790.000000, 40252.000000, 40716.000000, 41182.000000, 41650.000000, 42120.000000, 42592.000000, 43066.000000, 43542.000000, 44020.000000, 44500.000000, 44982.000000, 45466.000000, 45952.000000, 46440.000000 }, sd::DataType::FLOAT32); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); - // z.printBuffer(); - ASSERT_EQ(e, z); + // x[4, 12, 128, 128, 128] * y[4, 1, 128, 128, 128] = z[4, 12, 128, 128, 128] + auto x = NDArray('c', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + auto y = NDArray('c', {2, 1, 5, 4, 3}, sd::DataType::FLOAT32); + auto z = NDArray('c', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray( + 'c', {2, 3, 5, 4, 3}, + {10.000000, 22.000000, 36.000000, 52.000000, 70.000000, + 90.000000, 112.000000, 136.000000, 162.000000, 190.000000, + 220.000000, 252.000000, 286.000000, 322.000000, 360.000000, + 400.000000, 442.000000, 486.000000, 532.000000, 580.000000, + 630.000000, 682.000000, 736.000000, 792.000000, 850.000000, + 910.000000, 972.000000, 1036.000000, 1102.000000, 1170.000000, + 1240.000000, 1312.000000, 1386.000000, 1462.000000, 1540.000000, + 1620.000000, 1702.000000, 1786.000000, 1872.000000, 1960.000000, + 2050.000000, 2142.000000, 2236.000000, 2332.000000, 2430.000000, + 2530.000000, 2632.000000, 2736.000000, 2842.000000, 2950.000000, + 3060.000000, 3172.000000, 3286.000000, 3402.000000, 3520.000000, + 3640.000000, 3762.000000, 3886.000000, 4012.000000, 4140.000000, + 610.000000, 682.000000, 756.000000, 832.000000, 910.000000, + 990.000000, 1072.000000, 1156.000000, 1242.000000, 1330.000000, + 1420.000000, 1512.000000, 1606.000000, 1702.000000, 1800.000000, + 1900.000000, 2002.000000, 2106.000000, 2212.000000, 2320.000000, + 2430.000000, 2542.000000, 2656.000000, 2772.000000, 2890.000000, + 3010.000000, 3132.000000, 3256.000000, 3382.000000, 3510.000000, + 3640.000000, 3772.000000, 3906.000000, 4042.000000, 4180.000000, + 4320.000000, 4462.000000, 4606.000000, 4752.000000, 4900.000000, + 5050.000000, 5202.000000, 5356.000000, 5512.000000, 5670.000000, + 5830.000000, 5992.000000, 6156.000000, 6322.000000, 6490.000000, + 6660.000000, 6832.000000, 7006.000000, 7182.000000, 7360.000000, + 7540.000000, 7722.000000, 7906.000000, 8092.000000, 8280.000000, + 1210.000000, 1342.000000, 1476.000000, 1612.000000, 1750.000000, + 1890.000000, 2032.000000, 2176.000000, 2322.000000, 2470.000000, + 2620.000000, 2772.000000, 2926.000000, 3082.000000, 3240.000000, + 3400.000000, 3562.000000, 3726.000000, 3892.000000, 4060.000000, + 4230.000000, 4402.000000, 4576.000000, 4752.000000, 4930.000000, + 5110.000000, 5292.000000, 5476.000000, 5662.000000, 5850.000000, + 6040.000000, 6232.000000, 6426.000000, 6622.000000, 6820.000000, + 7020.000000, 7222.000000, 7426.000000, 7632.000000, 7840.000000, + 8050.000000, 8262.000000, 8476.000000, 8692.000000, 8910.000000, + 9130.000000, 9352.000000, 9576.000000, 9802.000000, 10030.000000, + 10260.000000, 10492.000000, 10726.000000, 10962.000000, 11200.000000, + 11440.000000, 11682.000000, 11926.000000, 12172.000000, 12420.000000, + 12670.000000, 12922.000000, 13176.000000, 13432.000000, 13690.000000, + 13950.000000, 14212.000000, 14476.000000, 14742.000000, 15010.000000, + 15280.000000, 15552.000000, 15826.000000, 16102.000000, 16380.000000, + 16660.000000, 16942.000000, 17226.000000, 17512.000000, 17800.000000, + 18090.000000, 18382.000000, 18676.000000, 18972.000000, 19270.000000, + 19570.000000, 19872.000000, 20176.000000, 20482.000000, 20790.000000, + 21100.000000, 21412.000000, 21726.000000, 22042.000000, 22360.000000, + 22680.000000, 23002.000000, 23326.000000, 23652.000000, 23980.000000, + 24310.000000, 24642.000000, 24976.000000, 25312.000000, 25650.000000, + 25990.000000, 26332.000000, 26676.000000, 27022.000000, 27370.000000, + 27720.000000, 28072.000000, 28426.000000, 28782.000000, 29140.000000, + 29500.000000, 29862.000000, 30226.000000, 30592.000000, 30960.000000, + 16870.000000, 17182.000000, 17496.000000, 17812.000000, 18130.000000, + 18450.000000, 18772.000000, 19096.000000, 19422.000000, 19750.000000, + 20080.000000, 20412.000000, 20746.000000, 21082.000000, 21420.000000, + 21760.000000, 22102.000000, 22446.000000, 22792.000000, 23140.000000, + 23490.000000, 23842.000000, 24196.000000, 24552.000000, 24910.000000, + 25270.000000, 25632.000000, 25996.000000, 26362.000000, 26730.000000, + 27100.000000, 27472.000000, 27846.000000, 28222.000000, 28600.000000, + 28980.000000, 29362.000000, 29746.000000, 30132.000000, 30520.000000, + 30910.000000, 31302.000000, 31696.000000, 32092.000000, 32490.000000, + 32890.000000, 33292.000000, 33696.000000, 34102.000000, 34510.000000, + 34920.000000, 35332.000000, 35746.000000, 36162.000000, 36580.000000, + 37000.000000, 37422.000000, 37846.000000, 38272.000000, 38700.000000, + 21070.000000, 21442.000000, 21816.000000, 22192.000000, 22570.000000, + 22950.000000, 23332.000000, 23716.000000, 24102.000000, 24490.000000, + 24880.000000, 25272.000000, 25666.000000, 26062.000000, 26460.000000, + 26860.000000, 27262.000000, 27666.000000, 28072.000000, 28480.000000, + 28890.000000, 29302.000000, 29716.000000, 30132.000000, 30550.000000, + 30970.000000, 31392.000000, 31816.000000, 32242.000000, 32670.000000, + 33100.000000, 33532.000000, 33966.000000, 34402.000000, 34840.000000, + 35280.000000, 35722.000000, 36166.000000, 36612.000000, 37060.000000, + 37510.000000, 37962.000000, 38416.000000, 38872.000000, 39330.000000, + 39790.000000, 40252.000000, 40716.000000, 41182.000000, 41650.000000, + 42120.000000, 42592.000000, 43066.000000, 43542.000000, 44020.000000, + 44500.000000, 44982.000000, 45466.000000, 45952.000000, 46440.000000}, + sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); + // z.printBuffer(); + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_5D_2) { - - auto x = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - auto y = NDArray('f', { 2, 5, 4, 3 }, sd::DataType::FLOAT32); - auto z = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto eC = NDArray('c', { 2, 3, 5, 4, 3 }, { 0.100000, 0.181818, 0.250000, 0.307692, 0.357143, 0.400000, 0.437500, 0.470588, 0.500000, 0.526316, 0.550000, 0.571429, 0.590909, 0.608696, 0.625000, 0.640000, 0.653846, 0.666667, 0.678571, 0.689655, 0.700000, 0.709677, 0.718750, 0.727273, 0.735294, 0.742857, 0.750000, 0.756757, 0.763158, 0.769231, 0.775000, 0.780488, 0.785714, 0.790698, 0.795455, 0.800000, 0.804348, 0.808511, 0.812500, 0.816327, 0.820000, 0.823529, 0.826923, 0.830189, 0.833333, 0.836364, 0.839286, 0.842105, 0.844828, 0.847458, 0.850000, 0.852459, 0.854839, 0.857143, 0.859375, 0.861538, 0.863636, 0.865672, 0.867647, 0.869565, 6.100000, 5.636364, 5.250000, 4.923077, 4.642857, 4.400000, 4.187500, 4.000000, 3.833333, 3.684211, 3.550000, 3.428571, 3.318182, 3.217391, 3.125000, 3.040000, 2.961539, 2.888889, 2.821429, 2.758621, 2.700000, 2.645161, 2.593750, 2.545455, 2.500000, 2.457143, 2.416667, 2.378378, 2.342105, 2.307692, 2.275000, 2.243902, 2.214286, 2.186047, 2.159091, 2.133333, 2.108696, 2.085106, 2.062500, 2.040816, 2.020000, 2.000000, 1.980769, 1.962264, 1.944444, 1.927273, 1.910714, 1.894737, 1.879310, 1.864407, 1.850000, 1.836066, 1.822581, 1.809524, 1.796875, 1.784615, 1.772727, 1.761194, 1.750000, 1.739130, 12.100000, 11.090909, 10.250000, 9.538462, 8.928572, 8.400000, 7.937500, 7.529412, 7.166667, 6.842105, 6.550000, 6.285714, 6.045455, 5.826087, 5.625000, 5.440000, 5.269231, 5.111111, 4.964286, 4.827586, 4.700000, 4.580645, 4.468750, 4.363636, 4.264706, 4.171429, 4.083333, 4.000000, 3.921053, 3.846154, 3.775000, 3.707317, 3.642857, 3.581395, 3.522727, 3.466667, 3.413043, 3.361702, 3.312500, 3.265306, 3.220000, 3.176471, 3.134615, 3.094340, 3.055556, 3.018182, 2.982143, 2.947368, 2.913793, 2.881356, 2.850000, 2.819672, 2.790323, 2.761905, 2.734375, 2.707692, 2.681818, 2.656716, 2.632353, 2.608696, 2.585714, 2.563380, 2.541667, 2.520548, 2.500000, 2.480000, 2.460526, 2.441558, 2.423077, 2.405063, 2.387500, 2.370370, 2.353658, 2.337349, 2.321429, 2.305882, 2.290698, 2.275862, 2.261364, 2.247191, 2.233333, 2.219780, 2.206522, 2.193548, 2.180851, 2.168421, 2.156250, 2.144330, 2.132653, 2.121212, 2.110000, 2.099010, 2.088235, 2.077670, 2.067308, 2.057143, 2.047170, 2.037383, 2.027778, 2.018349, 2.009091, 2.000000, 1.991071, 1.982301, 1.973684, 1.965217, 1.956897, 1.948718, 1.940678, 1.932773, 1.925000, 1.917355, 1.909836, 1.902439, 1.895161, 1.888000, 1.880952, 1.874016, 1.867188, 1.860465, 3.442857, 3.408451, 3.375000, 3.342466, 3.310811, 3.280000, 3.250000, 3.220779, 3.192308, 3.164557, 3.137500, 3.111111, 3.085366, 3.060241, 3.035714, 3.011765, 2.988372, 2.965517, 2.943182, 2.921348, 2.900000, 2.879121, 2.858696, 2.838710, 2.819149, 2.800000, 2.781250, 2.762887, 2.744898, 2.727273, 2.710000, 2.693069, 2.676471, 2.660194, 2.644231, 2.628572, 2.613208, 2.598131, 2.583333, 2.568807, 2.554545, 2.540540, 2.526786, 2.513274, 2.500000, 2.486957, 2.474138, 2.461539, 2.449152, 2.436975, 2.425000, 2.413223, 2.401639, 2.390244, 2.379032, 2.368000, 2.357143, 2.346457, 2.335938, 2.325581, 4.300000, 4.253521, 4.208333, 4.164383, 4.121622, 4.080000, 4.039474, 4.000000, 3.961539, 3.924051, 3.887500, 3.851852, 3.817073, 3.783133, 3.750000, 3.717647, 3.686047, 3.655172, 3.625000, 3.595506, 3.566667, 3.538461, 3.510870, 3.483871, 3.457447, 3.431579, 3.406250, 3.381443, 3.357143, 3.333333, 3.310000, 3.287129, 3.264706, 3.242718, 3.221154, 3.200000, 3.179245, 3.158879, 3.138889, 3.119266, 3.100000, 3.081081, 3.062500, 3.044248, 3.026316, 3.008696, 2.991379, 2.974359, 2.957627, 2.941176, 2.925000, 2.909091, 2.893443, 2.878049, 2.862903, 2.848000, 2.833333, 2.818898, 2.804688, 2.790698 }, sd::DataType::FLOAT32); - - auto e = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - - e.assign(eC); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyBroadcast(sd::broadcast::Divide, { 0,2,3,4 }, y, z); - - ASSERT_EQ(e, z); + auto x = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + auto y = NDArray('f', {2, 5, 4, 3}, sd::DataType::FLOAT32); + auto z = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray( + 'c', {2, 3, 5, 4, 3}, + {0.100000, 0.181818, 0.250000, 0.307692, 0.357143, 0.400000, 0.437500, + 0.470588, 0.500000, 0.526316, 0.550000, 0.571429, 0.590909, 0.608696, + 0.625000, 0.640000, 0.653846, 0.666667, 0.678571, 0.689655, 0.700000, + 0.709677, 0.718750, 0.727273, 0.735294, 0.742857, 0.750000, 0.756757, + 0.763158, 0.769231, 0.775000, 0.780488, 0.785714, 0.790698, 0.795455, + 0.800000, 0.804348, 0.808511, 0.812500, 0.816327, 0.820000, 0.823529, + 0.826923, 0.830189, 0.833333, 0.836364, 0.839286, 0.842105, 0.844828, + 0.847458, 0.850000, 0.852459, 0.854839, 0.857143, 0.859375, 0.861538, + 0.863636, 0.865672, 0.867647, 0.869565, 6.100000, 5.636364, 5.250000, + 4.923077, 4.642857, 4.400000, 4.187500, 4.000000, 3.833333, 3.684211, + 3.550000, 3.428571, 3.318182, 3.217391, 3.125000, 3.040000, 2.961539, + 2.888889, 2.821429, 2.758621, 2.700000, 2.645161, 2.593750, 2.545455, + 2.500000, 2.457143, 2.416667, 2.378378, 2.342105, 2.307692, 2.275000, + 2.243902, 2.214286, 2.186047, 2.159091, 2.133333, 2.108696, 2.085106, + 2.062500, 2.040816, 2.020000, 2.000000, 1.980769, 1.962264, 1.944444, + 1.927273, 1.910714, 1.894737, 1.879310, 1.864407, 1.850000, 1.836066, + 1.822581, 1.809524, 1.796875, 1.784615, 1.772727, 1.761194, 1.750000, + 1.739130, 12.100000, 11.090909, 10.250000, 9.538462, 8.928572, 8.400000, + 7.937500, 7.529412, 7.166667, 6.842105, 6.550000, 6.285714, 6.045455, + 5.826087, 5.625000, 5.440000, 5.269231, 5.111111, 4.964286, 4.827586, + 4.700000, 4.580645, 4.468750, 4.363636, 4.264706, 4.171429, 4.083333, + 4.000000, 3.921053, 3.846154, 3.775000, 3.707317, 3.642857, 3.581395, + 3.522727, 3.466667, 3.413043, 3.361702, 3.312500, 3.265306, 3.220000, + 3.176471, 3.134615, 3.094340, 3.055556, 3.018182, 2.982143, 2.947368, + 2.913793, 2.881356, 2.850000, 2.819672, 2.790323, 2.761905, 2.734375, + 2.707692, 2.681818, 2.656716, 2.632353, 2.608696, 2.585714, 2.563380, + 2.541667, 2.520548, 2.500000, 2.480000, 2.460526, 2.441558, 2.423077, + 2.405063, 2.387500, 2.370370, 2.353658, 2.337349, 2.321429, 2.305882, + 2.290698, 2.275862, 2.261364, 2.247191, 2.233333, 2.219780, 2.206522, + 2.193548, 2.180851, 2.168421, 2.156250, 2.144330, 2.132653, 2.121212, + 2.110000, 2.099010, 2.088235, 2.077670, 2.067308, 2.057143, 2.047170, + 2.037383, 2.027778, 2.018349, 2.009091, 2.000000, 1.991071, 1.982301, + 1.973684, 1.965217, 1.956897, 1.948718, 1.940678, 1.932773, 1.925000, + 1.917355, 1.909836, 1.902439, 1.895161, 1.888000, 1.880952, 1.874016, + 1.867188, 1.860465, 3.442857, 3.408451, 3.375000, 3.342466, 3.310811, + 3.280000, 3.250000, 3.220779, 3.192308, 3.164557, 3.137500, 3.111111, + 3.085366, 3.060241, 3.035714, 3.011765, 2.988372, 2.965517, 2.943182, + 2.921348, 2.900000, 2.879121, 2.858696, 2.838710, 2.819149, 2.800000, + 2.781250, 2.762887, 2.744898, 2.727273, 2.710000, 2.693069, 2.676471, + 2.660194, 2.644231, 2.628572, 2.613208, 2.598131, 2.583333, 2.568807, + 2.554545, 2.540540, 2.526786, 2.513274, 2.500000, 2.486957, 2.474138, + 2.461539, 2.449152, 2.436975, 2.425000, 2.413223, 2.401639, 2.390244, + 2.379032, 2.368000, 2.357143, 2.346457, 2.335938, 2.325581, 4.300000, + 4.253521, 4.208333, 4.164383, 4.121622, 4.080000, 4.039474, 4.000000, + 3.961539, 3.924051, 3.887500, 3.851852, 3.817073, 3.783133, 3.750000, + 3.717647, 3.686047, 3.655172, 3.625000, 3.595506, 3.566667, 3.538461, + 3.510870, 3.483871, 3.457447, 3.431579, 3.406250, 3.381443, 3.357143, + 3.333333, 3.310000, 3.287129, 3.264706, 3.242718, 3.221154, 3.200000, + 3.179245, 3.158879, 3.138889, 3.119266, 3.100000, 3.081081, 3.062500, + 3.044248, 3.026316, 3.008696, 2.991379, 2.974359, 2.957627, 2.941176, + 2.925000, 2.909091, 2.893443, 2.878049, 2.862903, 2.848000, 2.833333, + 2.818898, 2.804688, 2.790698}, + sd::DataType::FLOAT32); + + auto e = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Divide, {0, 2, 3, 4}, y, z); + + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_5D_3) { - - auto x = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - auto y = NDArray('f', { 2, 5 }, sd::DataType::FLOAT32); - auto z = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto eC = NDArray('c', { 2, 3, 5, 4, 3 }, { 0.100000, 0.200000, 0.300000, 0.400000, 0.500000, 0.600000, 0.700000, 0.800000, 0.900000, 1.000000, 1.100000, 1.200000, 1.181818, 1.272727, 1.363636, 1.454545, 1.545455, 1.636364, 1.727273, 1.818182, 1.909091, 2.000000, 2.090909, 2.181818, 2.083333, 2.166667, 2.250000, 2.333333, 2.416667, 2.500000, 2.583333, 2.666667, 2.750000, 2.833333, 2.916667, 3.000000, 2.846154, 2.923077, 3.000000, 3.076923, 3.153846, 3.230769, 3.307692, 3.384615, 3.461539, 3.538461, 3.615385, 3.692308, 3.500000, 3.571429, 3.642857, 3.714286, 3.785714, 3.857143, 3.928571, 4.000000, 4.071429, 4.142857, 4.214286, 4.285714, 6.100000, 6.200000, 6.300000, 6.400000, 6.500000, 6.600000, 6.700000, 6.800000, 6.900000, 7.000000, 7.100000, 7.200000, 6.636364, 6.727273, 6.818182, 6.909091, 7.000000, 7.090909, 7.181818, 7.272727, 7.363636, 7.454545, 7.545455, 7.636364, 7.083333, 7.166667, 7.250000, 7.333333, 7.416667, 7.500000, 7.583333, 7.666667, 7.750000, 7.833333, 7.916667, 8.000000, 7.461538, 7.538462, 7.615385, 7.692307, 7.769231, 7.846154, 7.923077, 8.000000, 8.076923, 8.153846, 8.230769, 8.307693, 7.785714, 7.857143, 7.928571, 8.000000, 8.071428, 8.142858, 8.214286, 8.285714, 8.357142, 8.428572, 8.500000, 8.571428, 12.100000, 12.200000, 12.300000, 12.400000, 12.500000, 12.600000, 12.700000, 12.800000, 12.900000, 13.000000, 13.100000, 13.200000, 12.090909, 12.181818, 12.272727, 12.363636, 12.454545, 12.545455, 12.636364, 12.727273, 12.818182, 12.909091, 13.000000, 13.090909, 12.083333, 12.166667, 12.250000, 12.333333, 12.416667, 12.500000, 12.583333, 12.666667, 12.750000, 12.833333, 12.916667, 13.000000, 12.076923, 12.153846, 12.230769, 12.307693, 12.384615, 12.461538, 12.538462, 12.615385, 12.692307, 12.769231, 12.846154, 12.923077, 12.071428, 12.142858, 12.214286, 12.285714, 12.357142, 12.428572, 12.500000, 12.571428, 12.642858, 12.714286, 12.785714, 12.857142, 12.066667, 12.133333, 12.200000, 12.266666, 12.333333, 12.400000, 12.466666, 12.533334, 12.600000, 12.666667, 12.733334, 12.800000, 12.062500, 12.125000, 12.187500, 12.250000, 12.312500, 12.375000, 12.437500, 12.500000, 12.562500, 12.625000, 12.687500, 12.750000, 12.058824, 12.117647, 12.176471, 12.235294, 12.294118, 12.352942, 12.411765, 12.470589, 12.529411, 12.588235, 12.647058, 12.705882, 12.055555, 12.111111, 12.166667, 12.222222, 12.277778, 12.333333, 12.388889, 12.444445, 12.500000, 12.555555, 12.611111, 12.666667, 12.052631, 12.105263, 12.157895, 12.210526, 12.263158, 12.315789, 12.368421, 12.421053, 12.473684, 12.526316, 12.578947, 12.631579, 16.066668, 16.133333, 16.200001, 16.266666, 16.333334, 16.400000, 16.466667, 16.533333, 16.600000, 16.666666, 16.733334, 16.799999, 15.812500, 15.875000, 15.937500, 16.000000, 16.062500, 16.125000, 16.187500, 16.250000, 16.312500, 16.375000, 16.437500, 16.500000, 15.588235, 15.647058, 15.705882, 15.764706, 15.823529, 15.882353, 15.941176, 16.000000, 16.058823, 16.117647, 16.176470, 16.235294, 15.388889, 15.444445, 15.500000, 15.555555, 15.611111, 15.666667, 15.722222, 15.777778, 15.833333, 15.888889, 15.944445, 16.000000, 15.210526, 15.263158, 15.315789, 15.368421, 15.421053, 15.473684, 15.526316, 15.578947, 15.631579, 15.684211, 15.736842, 15.789474, 20.066668, 20.133333, 20.200001, 20.266666, 20.333334, 20.400000, 20.466667, 20.533333, 20.600000, 20.666666, 20.733334, 20.799999, 19.562500, 19.625000, 19.687500, 19.750000, 19.812500, 19.875000, 19.937500, 20.000000, 20.062500, 20.125000, 20.187500, 20.250000, 19.117647, 19.176470, 19.235294, 19.294117, 19.352942, 19.411764, 19.470589, 19.529411, 19.588236, 19.647058, 19.705883, 19.764706, 18.722221, 18.777779, 18.833334, 18.888889, 18.944445, 19.000000, 19.055555, 19.111111, 19.166666, 19.222221, 19.277779, 19.333334, 18.368422, 18.421053, 18.473684, 18.526316, 18.578947, 18.631578, 18.684210, 18.736841, 18.789474, 18.842106, 18.894737, 18.947369 }, sd::DataType::FLOAT32); - - auto e = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - - e.assign(eC); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyBroadcast(sd::broadcast::Divide, { 0,2 }, y, z); - - ASSERT_EQ(e, z); + auto x = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + auto y = NDArray('f', {2, 5}, sd::DataType::FLOAT32); + auto z = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray( + 'c', {2, 3, 5, 4, 3}, + {0.100000, 0.200000, 0.300000, 0.400000, 0.500000, 0.600000, + 0.700000, 0.800000, 0.900000, 1.000000, 1.100000, 1.200000, + 1.181818, 1.272727, 1.363636, 1.454545, 1.545455, 1.636364, + 1.727273, 1.818182, 1.909091, 2.000000, 2.090909, 2.181818, + 2.083333, 2.166667, 2.250000, 2.333333, 2.416667, 2.500000, + 2.583333, 2.666667, 2.750000, 2.833333, 2.916667, 3.000000, + 2.846154, 2.923077, 3.000000, 3.076923, 3.153846, 3.230769, + 3.307692, 3.384615, 3.461539, 3.538461, 3.615385, 3.692308, + 3.500000, 3.571429, 3.642857, 3.714286, 3.785714, 3.857143, + 3.928571, 4.000000, 4.071429, 4.142857, 4.214286, 4.285714, + 6.100000, 6.200000, 6.300000, 6.400000, 6.500000, 6.600000, + 6.700000, 6.800000, 6.900000, 7.000000, 7.100000, 7.200000, + 6.636364, 6.727273, 6.818182, 6.909091, 7.000000, 7.090909, + 7.181818, 7.272727, 7.363636, 7.454545, 7.545455, 7.636364, + 7.083333, 7.166667, 7.250000, 7.333333, 7.416667, 7.500000, + 7.583333, 7.666667, 7.750000, 7.833333, 7.916667, 8.000000, + 7.461538, 7.538462, 7.615385, 7.692307, 7.769231, 7.846154, + 7.923077, 8.000000, 8.076923, 8.153846, 8.230769, 8.307693, + 7.785714, 7.857143, 7.928571, 8.000000, 8.071428, 8.142858, + 8.214286, 8.285714, 8.357142, 8.428572, 8.500000, 8.571428, + 12.100000, 12.200000, 12.300000, 12.400000, 12.500000, 12.600000, + 12.700000, 12.800000, 12.900000, 13.000000, 13.100000, 13.200000, + 12.090909, 12.181818, 12.272727, 12.363636, 12.454545, 12.545455, + 12.636364, 12.727273, 12.818182, 12.909091, 13.000000, 13.090909, + 12.083333, 12.166667, 12.250000, 12.333333, 12.416667, 12.500000, + 12.583333, 12.666667, 12.750000, 12.833333, 12.916667, 13.000000, + 12.076923, 12.153846, 12.230769, 12.307693, 12.384615, 12.461538, + 12.538462, 12.615385, 12.692307, 12.769231, 12.846154, 12.923077, + 12.071428, 12.142858, 12.214286, 12.285714, 12.357142, 12.428572, + 12.500000, 12.571428, 12.642858, 12.714286, 12.785714, 12.857142, + 12.066667, 12.133333, 12.200000, 12.266666, 12.333333, 12.400000, + 12.466666, 12.533334, 12.600000, 12.666667, 12.733334, 12.800000, + 12.062500, 12.125000, 12.187500, 12.250000, 12.312500, 12.375000, + 12.437500, 12.500000, 12.562500, 12.625000, 12.687500, 12.750000, + 12.058824, 12.117647, 12.176471, 12.235294, 12.294118, 12.352942, + 12.411765, 12.470589, 12.529411, 12.588235, 12.647058, 12.705882, + 12.055555, 12.111111, 12.166667, 12.222222, 12.277778, 12.333333, + 12.388889, 12.444445, 12.500000, 12.555555, 12.611111, 12.666667, + 12.052631, 12.105263, 12.157895, 12.210526, 12.263158, 12.315789, + 12.368421, 12.421053, 12.473684, 12.526316, 12.578947, 12.631579, + 16.066668, 16.133333, 16.200001, 16.266666, 16.333334, 16.400000, + 16.466667, 16.533333, 16.600000, 16.666666, 16.733334, 16.799999, + 15.812500, 15.875000, 15.937500, 16.000000, 16.062500, 16.125000, + 16.187500, 16.250000, 16.312500, 16.375000, 16.437500, 16.500000, + 15.588235, 15.647058, 15.705882, 15.764706, 15.823529, 15.882353, + 15.941176, 16.000000, 16.058823, 16.117647, 16.176470, 16.235294, + 15.388889, 15.444445, 15.500000, 15.555555, 15.611111, 15.666667, + 15.722222, 15.777778, 15.833333, 15.888889, 15.944445, 16.000000, + 15.210526, 15.263158, 15.315789, 15.368421, 15.421053, 15.473684, + 15.526316, 15.578947, 15.631579, 15.684211, 15.736842, 15.789474, + 20.066668, 20.133333, 20.200001, 20.266666, 20.333334, 20.400000, + 20.466667, 20.533333, 20.600000, 20.666666, 20.733334, 20.799999, + 19.562500, 19.625000, 19.687500, 19.750000, 19.812500, 19.875000, + 19.937500, 20.000000, 20.062500, 20.125000, 20.187500, 20.250000, + 19.117647, 19.176470, 19.235294, 19.294117, 19.352942, 19.411764, + 19.470589, 19.529411, 19.588236, 19.647058, 19.705883, 19.764706, + 18.722221, 18.777779, 18.833334, 18.888889, 18.944445, 19.000000, + 19.055555, 19.111111, 19.166666, 19.222221, 19.277779, 19.333334, + 18.368422, 18.421053, 18.473684, 18.526316, 18.578947, 18.631578, + 18.684210, 18.736841, 18.789474, 18.842106, 18.894737, 18.947369}, + sd::DataType::FLOAT32); + + auto e = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Divide, {0, 2}, y, z); + + ASSERT_EQ(e, z); } /////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Test_broadcast_5D_4) { - - auto x = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - auto y = NDArray('f', { 2, 1, 5, 1, 1 }, sd::DataType::FLOAT32); - auto z = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - // recieved by main algorithm - auto eC = NDArray('c', { 2, 3, 5, 4, 3 }, { 0.100000, 0.200000, 0.300000, 0.400000, 0.500000, 0.600000, 0.700000, 0.800000, 0.900000, 1.000000, 1.100000, 1.200000, 1.181818, 1.272727, 1.363636, 1.454545, 1.545455, 1.636364, 1.727273, 1.818182, 1.909091, 2.000000, 2.090909, 2.181818, 2.083333, 2.166667, 2.250000, 2.333333, 2.416667, 2.500000, 2.583333, 2.666667, 2.750000, 2.833333, 2.916667, 3.000000, 2.846154, 2.923077, 3.000000, 3.076923, 3.153846, 3.230769, 3.307692, 3.384615, 3.461539, 3.538461, 3.615385, 3.692308, 3.500000, 3.571429, 3.642857, 3.714286, 3.785714, 3.857143, 3.928571, 4.000000, 4.071429, 4.142857, 4.214286, 4.285714, 6.100000, 6.200000, 6.300000, 6.400000, 6.500000, 6.600000, 6.700000, 6.800000, 6.900000, 7.000000, 7.100000, 7.200000, 6.636364, 6.727273, 6.818182, 6.909091, 7.000000, 7.090909, 7.181818, 7.272727, 7.363636, 7.454545, 7.545455, 7.636364, 7.083333, 7.166667, 7.250000, 7.333333, 7.416667, 7.500000, 7.583333, 7.666667, 7.750000, 7.833333, 7.916667, 8.000000, 7.461538, 7.538462, 7.615385, 7.692307, 7.769231, 7.846154, 7.923077, 8.000000, 8.076923, 8.153846, 8.230769, 8.307693, 7.785714, 7.857143, 7.928571, 8.000000, 8.071428, 8.142858, 8.214286, 8.285714, 8.357142, 8.428572, 8.500000, 8.571428, 12.100000, 12.200000, 12.300000, 12.400000, 12.500000, 12.600000, 12.700000, 12.800000, 12.900000, 13.000000, 13.100000, 13.200000, 12.090909, 12.181818, 12.272727, 12.363636, 12.454545, 12.545455, 12.636364, 12.727273, 12.818182, 12.909091, 13.000000, 13.090909, 12.083333, 12.166667, 12.250000, 12.333333, 12.416667, 12.500000, 12.583333, 12.666667, 12.750000, 12.833333, 12.916667, 13.000000, 12.076923, 12.153846, 12.230769, 12.307693, 12.384615, 12.461538, 12.538462, 12.615385, 12.692307, 12.769231, 12.846154, 12.923077, 12.071428, 12.142858, 12.214286, 12.285714, 12.357142, 12.428572, 12.500000, 12.571428, 12.642858, 12.714286, 12.785714, 12.857142, 12.066667, 12.133333, 12.200000, 12.266666, 12.333333, 12.400000, 12.466666, 12.533334, 12.600000, 12.666667, 12.733334, 12.800000, 12.062500, 12.125000, 12.187500, 12.250000, 12.312500, 12.375000, 12.437500, 12.500000, 12.562500, 12.625000, 12.687500, 12.750000, 12.058824, 12.117647, 12.176471, 12.235294, 12.294118, 12.352942, 12.411765, 12.470589, 12.529411, 12.588235, 12.647058, 12.705882, 12.055555, 12.111111, 12.166667, 12.222222, 12.277778, 12.333333, 12.388889, 12.444445, 12.500000, 12.555555, 12.611111, 12.666667, 12.052631, 12.105263, 12.157895, 12.210526, 12.263158, 12.315789, 12.368421, 12.421053, 12.473684, 12.526316, 12.578947, 12.631579, 16.066668, 16.133333, 16.200001, 16.266666, 16.333334, 16.400000, 16.466667, 16.533333, 16.600000, 16.666666, 16.733334, 16.799999, 15.812500, 15.875000, 15.937500, 16.000000, 16.062500, 16.125000, 16.187500, 16.250000, 16.312500, 16.375000, 16.437500, 16.500000, 15.588235, 15.647058, 15.705882, 15.764706, 15.823529, 15.882353, 15.941176, 16.000000, 16.058823, 16.117647, 16.176470, 16.235294, 15.388889, 15.444445, 15.500000, 15.555555, 15.611111, 15.666667, 15.722222, 15.777778, 15.833333, 15.888889, 15.944445, 16.000000, 15.210526, 15.263158, 15.315789, 15.368421, 15.421053, 15.473684, 15.526316, 15.578947, 15.631579, 15.684211, 15.736842, 15.789474, 20.066668, 20.133333, 20.200001, 20.266666, 20.333334, 20.400000, 20.466667, 20.533333, 20.600000, 20.666666, 20.733334, 20.799999, 19.562500, 19.625000, 19.687500, 19.750000, 19.812500, 19.875000, 19.937500, 20.000000, 20.062500, 20.125000, 20.187500, 20.250000, 19.117647, 19.176470, 19.235294, 19.294117, 19.352942, 19.411764, 19.470589, 19.529411, 19.588236, 19.647058, 19.705883, 19.764706, 18.722221, 18.777779, 18.833334, 18.888889, 18.944445, 19.000000, 19.055555, 19.111111, 19.166666, 19.222221, 19.277779, 19.333334, 18.368422, 18.421053, 18.473684, 18.526316, 18.578947, 18.631578, 18.684210, 18.736841, 18.789474, 18.842106, 18.894737, 18.947369 }, sd::DataType::FLOAT32); - - auto e = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); - e.assign(eC); - - x.linspace(1.f); - y.linspace(10.f); - z.assign(0.f); - - x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); - - ASSERT_EQ(e, z); + auto x = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + auto y = NDArray('f', {2, 1, 5, 1, 1}, sd::DataType::FLOAT32); + auto z = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray( + 'c', {2, 3, 5, 4, 3}, + {0.100000, 0.200000, 0.300000, 0.400000, 0.500000, 0.600000, + 0.700000, 0.800000, 0.900000, 1.000000, 1.100000, 1.200000, + 1.181818, 1.272727, 1.363636, 1.454545, 1.545455, 1.636364, + 1.727273, 1.818182, 1.909091, 2.000000, 2.090909, 2.181818, + 2.083333, 2.166667, 2.250000, 2.333333, 2.416667, 2.500000, + 2.583333, 2.666667, 2.750000, 2.833333, 2.916667, 3.000000, + 2.846154, 2.923077, 3.000000, 3.076923, 3.153846, 3.230769, + 3.307692, 3.384615, 3.461539, 3.538461, 3.615385, 3.692308, + 3.500000, 3.571429, 3.642857, 3.714286, 3.785714, 3.857143, + 3.928571, 4.000000, 4.071429, 4.142857, 4.214286, 4.285714, + 6.100000, 6.200000, 6.300000, 6.400000, 6.500000, 6.600000, + 6.700000, 6.800000, 6.900000, 7.000000, 7.100000, 7.200000, + 6.636364, 6.727273, 6.818182, 6.909091, 7.000000, 7.090909, + 7.181818, 7.272727, 7.363636, 7.454545, 7.545455, 7.636364, + 7.083333, 7.166667, 7.250000, 7.333333, 7.416667, 7.500000, + 7.583333, 7.666667, 7.750000, 7.833333, 7.916667, 8.000000, + 7.461538, 7.538462, 7.615385, 7.692307, 7.769231, 7.846154, + 7.923077, 8.000000, 8.076923, 8.153846, 8.230769, 8.307693, + 7.785714, 7.857143, 7.928571, 8.000000, 8.071428, 8.142858, + 8.214286, 8.285714, 8.357142, 8.428572, 8.500000, 8.571428, + 12.100000, 12.200000, 12.300000, 12.400000, 12.500000, 12.600000, + 12.700000, 12.800000, 12.900000, 13.000000, 13.100000, 13.200000, + 12.090909, 12.181818, 12.272727, 12.363636, 12.454545, 12.545455, + 12.636364, 12.727273, 12.818182, 12.909091, 13.000000, 13.090909, + 12.083333, 12.166667, 12.250000, 12.333333, 12.416667, 12.500000, + 12.583333, 12.666667, 12.750000, 12.833333, 12.916667, 13.000000, + 12.076923, 12.153846, 12.230769, 12.307693, 12.384615, 12.461538, + 12.538462, 12.615385, 12.692307, 12.769231, 12.846154, 12.923077, + 12.071428, 12.142858, 12.214286, 12.285714, 12.357142, 12.428572, + 12.500000, 12.571428, 12.642858, 12.714286, 12.785714, 12.857142, + 12.066667, 12.133333, 12.200000, 12.266666, 12.333333, 12.400000, + 12.466666, 12.533334, 12.600000, 12.666667, 12.733334, 12.800000, + 12.062500, 12.125000, 12.187500, 12.250000, 12.312500, 12.375000, + 12.437500, 12.500000, 12.562500, 12.625000, 12.687500, 12.750000, + 12.058824, 12.117647, 12.176471, 12.235294, 12.294118, 12.352942, + 12.411765, 12.470589, 12.529411, 12.588235, 12.647058, 12.705882, + 12.055555, 12.111111, 12.166667, 12.222222, 12.277778, 12.333333, + 12.388889, 12.444445, 12.500000, 12.555555, 12.611111, 12.666667, + 12.052631, 12.105263, 12.157895, 12.210526, 12.263158, 12.315789, + 12.368421, 12.421053, 12.473684, 12.526316, 12.578947, 12.631579, + 16.066668, 16.133333, 16.200001, 16.266666, 16.333334, 16.400000, + 16.466667, 16.533333, 16.600000, 16.666666, 16.733334, 16.799999, + 15.812500, 15.875000, 15.937500, 16.000000, 16.062500, 16.125000, + 16.187500, 16.250000, 16.312500, 16.375000, 16.437500, 16.500000, + 15.588235, 15.647058, 15.705882, 15.764706, 15.823529, 15.882353, + 15.941176, 16.000000, 16.058823, 16.117647, 16.176470, 16.235294, + 15.388889, 15.444445, 15.500000, 15.555555, 15.611111, 15.666667, + 15.722222, 15.777778, 15.833333, 15.888889, 15.944445, 16.000000, + 15.210526, 15.263158, 15.315789, 15.368421, 15.421053, 15.473684, + 15.526316, 15.578947, 15.631579, 15.684211, 15.736842, 15.789474, + 20.066668, 20.133333, 20.200001, 20.266666, 20.333334, 20.400000, + 20.466667, 20.533333, 20.600000, 20.666666, 20.733334, 20.799999, + 19.562500, 19.625000, 19.687500, 19.750000, 19.812500, 19.875000, + 19.937500, 20.000000, 20.062500, 20.125000, 20.187500, 20.250000, + 19.117647, 19.176470, 19.235294, 19.294117, 19.352942, 19.411764, + 19.470589, 19.529411, 19.588236, 19.647058, 19.705883, 19.764706, + 18.722221, 18.777779, 18.833334, 18.888889, 18.944445, 19.000000, + 19.055555, 19.111111, 19.166666, 19.222221, 19.277779, 19.333334, + 18.368422, 18.421053, 18.473684, 18.526316, 18.578947, 18.631578, + 18.684210, 18.736841, 18.789474, 18.842106, 18.894737, 18.947369}, + sd::DataType::FLOAT32); + + auto e = NDArray('f', {2, 3, 5, 4, 3}, sd::DataType::FLOAT32); + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); + + ASSERT_EQ(e, z); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_1) { + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + float buff2[] = {13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99}; + Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; - float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; - float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24}; - Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99}; - Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99}; - Nd4jLong expShape[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray input2(buff2, shape2); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input2}, {}, {0}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {0}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_2) { + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + float buff2[] = {13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {1, 2, 3, 4, 13, 14, 16, 16, 5, 6, 7, 8, + 17, 18, 19, 20, 9, 10, 11, 12, 21, 22, 23, 24}; + Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99}; + Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 3, 2, 4, 8, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; - float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; - float expBuff[] = {1,2,3,4, 13, 14, 16, 16, 5,6,7,8, 17, 18, 19, 20, 9, 10, 11, 12, 21, 22, 23, 24}; - Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99}; - Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99}; - Nd4jLong expShape[] = {3, 3, 2, 4, 8, 4, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray input2(buff2, shape2); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input2}, {}, {1}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {1}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_3) { + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + float buff2[] = {13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99}; + Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 2, 1, 12, 12, 12, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; - float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; - float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24}; - Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99}; - Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99}; - Nd4jLong expShape[] = {3, 2, 1, 12, 12, 12, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray input2(buff2, shape2); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input2}, {}, {0}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {0}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_4) { + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + float buff2[] = {13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99}; + Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 1, 2, 12, 24, 12, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; - float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; - float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24}; - Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99}; - Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99}; - Nd4jLong expShape[] = {3, 1, 2, 12, 24, 12, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray input2(buff2, shape2); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input2}, {}, {1}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {1}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_5) { + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + float buff2[] = {13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + Nd4jLong shape1[] = {2, 12, 1, 1, 1, 0, 1, 99}; + Nd4jLong shape2[] = {2, 12, 1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 2, 12, 1, 12, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; - float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; - float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24}; - Nd4jLong shape1[] = {2, 12, 1, 1,1, 0, 1, 99}; - Nd4jLong shape2[] = {2, 12, 1, 1,1, 0, 1, 99}; - Nd4jLong expShape[] = {3, 2, 12, 1, 12, 1, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray input2(buff2, shape2); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input2}, {}, {0}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {0}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_6) { + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + float buff2[] = {13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + float expBuff[] = {1, 13, 2, 14, 3, 16, 4, 16, 5, 17, 6, 18, + 7, 19, 8, 20, 9, 21, 10, 22, 11, 23, 12, 24}; + Nd4jLong shape1[] = {2, 12, 1, 1, 12, 0, 1, 99}; + Nd4jLong shape2[] = {2, 12, 1, 1, 12, 0, 1, 99}; + Nd4jLong expShape[] = {3, 12, 2, 1, 2, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; - float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; - float expBuff[] = {1 ,13 ,2 ,14 ,3 ,16 ,4 ,16 ,5 ,17 ,6 ,18 ,7 ,19 ,8 ,20 ,9 ,21 ,10 ,22 ,11 ,23 ,12 ,24}; - Nd4jLong shape1[] = {2, 12, 1, 1, 12, 0, 1, 99}; - Nd4jLong shape2[] = {2, 12, 1, 1, 12, 0, 1, 99}; - Nd4jLong expShape[] = {3, 12, 2, 1, 2, 1, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray input2(buff2, shape2); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input2}, {}, {1}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {1}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_7) { + float buff1[] = {1}; + float expBuff[] = {1, 1, 1}; + Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 3, 1, 1, 1, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1}; - float expBuff[] = {1, 1, 1}; - Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99}; - Nd4jLong expShape[] = {3, 3, 1, 1, 1, 1, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input1, &input1}, {}, {0}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input1, &input1}, {}, {0}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_8) { + float buff1[] = {1}; + float expBuff[] = {1, 1, 1}; + Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1}; - float expBuff[] = {1, 1, 1}; - Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99}; - Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input1, &input1}, {}, {0}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input1, &input1}, {}, {0}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_9) { + float buff1[] = {1}; + float expBuff[] = {1, 1, 1}; + Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 1, 3, 1, 3, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1}; - float expBuff[] = {1, 1, 1}; - Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99}; - Nd4jLong expShape[] = {3, 1, 3, 1, 3, 1, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input1, &input1}, {}, {1}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input1, &input1}, {}, {1}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_10) { + float buff1[] = {1}; + float expBuff[] = {1, 1, 1}; + Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {2, 1, 3, 3, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1}; - float expBuff[] = {1, 1, 1}; - Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99}; - Nd4jLong expShape[] = {2, 1, 3, 3, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input1, &input1}, {}, {1}); - auto output = results.at(0); - - //expected.printShapeInfo("exp"); - //output->printShapeInfo("out"); + NDArray input1(buff1, shape1); + NDArray expected(expBuff, expShape); - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input1, &input1}, {}, {1}); + auto output = results.at(0); + // expected.printShapeInfo("exp"); + // output->printShapeInfo("out"); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } TEST_F(DeclarableOpsTests14, Stack_11) { + float buff1[] = {1}; + float expBuff[] = {1, 1, 1}; + Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - float buff1[] = {1}; - float expBuff[] = {1, 1, 1}; - Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99}; - Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99}; - ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); - ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); - - NDArray input1(buff1, shape1); - NDArray expected(expBuff, expShape); - - sd::ops::stack op; - auto results = op.evaluate({&input1, &input1, &input1}, {}, {}); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray input1(buff1, shape1); + NDArray expected(expBuff, expShape); + sd::ops::stack op; + auto results = op.evaluate({&input1, &input1, &input1}, {}, {}); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_12) { - float inBuff[] = {1.0f, 2.0f, 3.0f}; - float expBuff[] = {1.0f, 2.0f, 3.0f}; - - auto input = NDArrayFactory::create(inBuff, 'c', {1, 3}); + float inBuff[] = {1.0f, 2.0f, 3.0f}; + float expBuff[] = {1.0f, 2.0f, 3.0f}; - auto exp = NDArrayFactory::create(expBuff, 'c', {1, 1, 3}); + auto input = NDArrayFactory::create(inBuff, 'c', {1, 3}); - sd::ops::stack op; + auto exp = NDArrayFactory::create(expBuff, 'c', {1, 1, 3}); - auto result = op.evaluate({&input}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::stack op; - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_13) { - float inBuff[] = {1.0f, 2.0f, 3.0f}; - float expBuff[] = {1.0f, 2.0f, 3.0f}; - - auto input = NDArrayFactory::create(inBuff, 'c', {1, 1, 3}); - - auto exp = NDArrayFactory::create(expBuff, 'c', {1, 1, 1, 3}); + float inBuff[] = {1.0f, 2.0f, 3.0f}; + float expBuff[] = {1.0f, 2.0f, 3.0f}; - sd::ops::stack op; + auto input = NDArrayFactory::create(inBuff, 'c', {1, 1, 3}); - auto result = op.evaluate({&input}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto exp = NDArrayFactory::create(expBuff, 'c', {1, 1, 1, 3}); - auto z = result.at(0); + sd::ops::stack op; - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Stack_14) { - float inBuff[] = {1.0f, 2.0f, 3.0f}; - float expBuff[] = {1.0f, 2.0f, 3.0f}; - - auto input = NDArrayFactory::create(inBuff, 'c', {1, 3}); - - auto exp = NDArrayFactory::create(expBuff, 'c', {1, 1, 3}); + float inBuff[] = {1.0f, 2.0f, 3.0f}; + float expBuff[] = {1.0f, 2.0f, 3.0f}; - sd::ops::stack op; + auto input = NDArrayFactory::create(inBuff, 'c', {1, 3}); - auto result = op.evaluate({&input}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto exp = NDArrayFactory::create(expBuff, 'c', {1, 1, 3}); - auto z = result.at(0); + sd::ops::stack op; - //z->printShapeInfo(); + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Stack_15) { - auto t = NDArrayFactory::create('c', {2, 3, 5}); - auto u = NDArrayFactory::create('c', {2, 3, 5}); - auto v = NDArrayFactory::create('c', {2, 3, 5}); - auto exp = NDArrayFactory::create('c', {3, 2, 3, 5}); + auto t = NDArrayFactory::create('c', {2, 3, 5}); + auto u = NDArrayFactory::create('c', {2, 3, 5}); + auto v = NDArrayFactory::create('c', {2, 3, 5}); + auto exp = NDArrayFactory::create('c', {3, 2, 3, 5}); - sd::ops::stack op; - auto result = op.evaluate({&t, &u, &v}, {}, {-4}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); + sd::ops::stack op; + auto result = op.evaluate({&t, &u, &v}, {}, {-4}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests14, Stack_16) { - auto t = NDArrayFactory::create(1.0f); - auto u = NDArrayFactory::create(2.0f); - auto v = NDArrayFactory::create(3.0f); - auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - - sd::ops::stack op; - auto result = op.evaluate({&t, &u, &v}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto t = NDArrayFactory::create(1.0f); + auto u = NDArrayFactory::create(2.0f); + auto v = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::stack op; + auto result = op.evaluate({&t, &u, &v}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Stack_17) { - auto t = NDArrayFactory::create('c', {1, 1}, {1.0f}); - auto u = NDArrayFactory::create('c', {1, 1}, {2.0f}); - auto v = NDArrayFactory::create('c', {1, 1}, {3.0f}); - auto w = NDArrayFactory::create('c', {1, 1}, {4.0f}); - auto exp = NDArrayFactory::create('c', {4, 1, 1}, {1, 2, 3, 4}); - - sd::ops::stack op; - auto result = op.evaluate({&t, &u, &v, &w}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto t = NDArrayFactory::create('c', {1, 1}, {1.0f}); + auto u = NDArrayFactory::create('c', {1, 1}, {2.0f}); + auto v = NDArrayFactory::create('c', {1, 1}, {3.0f}); + auto w = NDArrayFactory::create('c', {1, 1}, {4.0f}); + auto exp = NDArrayFactory::create('c', {4, 1, 1}, {1, 2, 3, 4}); - // z->printShapeInfo("z shape"); + sd::ops::stack op; + auto result = op.evaluate({&t, &u, &v, &w}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printShapeInfo("z shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Stack_18) { - auto x = NDArrayFactory::create('c', {0}); - auto e = NDArrayFactory::create('c', {1, 0}); - - sd::ops::stack op; - auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(e, z); - sd::ops::reduce_min sumOp; - auto res2 = sumOp.evaluate({&e}, {1.}, {1}); - ASSERT_EQ(res2.status(), Status::OK()); - auto out = res2.at(0); + auto x = NDArrayFactory::create('c', {0}); + auto e = NDArrayFactory::create('c', {1, 0}); - ASSERT_EQ(out.e(0), DataTypeUtils::infOrMax()); + sd::ops::stack op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); + sd::ops::reduce_min sumOp; + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); + ASSERT_EQ(res2.status(), Status::OK()); + auto out = res2.at(0); + ASSERT_EQ(out.e(0), DataTypeUtils::infOrMax()); } TEST_F(DeclarableOpsTests14, Stack_19) { - auto x = NDArrayFactory::empty(); - auto e = NDArrayFactory::create('c', {0}); - - sd::ops::stack op; - auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(e, z); + auto x = NDArrayFactory::empty(); + auto e = NDArrayFactory::create('c', {0}); + sd::ops::stack op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Stack_20) { - auto x = NDArrayFactory::empty(); - auto e = NDArrayFactory::create('c', {2, 0}); - - sd::ops::stack op; - auto result = op.evaluate({&x, &x}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(e, z); + auto x = NDArrayFactory::empty(); + auto e = NDArrayFactory::create('c', {2, 0}); + sd::ops::stack op; + auto result = op.evaluate({&x, &x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Stack_21) { + NDArray x1('c', {3, 2}, sd::DataType::FLOAT32); + NDArray x2('c', {3, 2}, sd::DataType::FLOAT32); + x1.linspace(0); + x2.linspace(6); - NDArray x1('c', {3,2}, sd::DataType::FLOAT32); - NDArray x2('c', {3,2}, sd::DataType::FLOAT32); - x1.linspace(0); - x2.linspace(6); - - sd::ops::stack opStack; - auto resultStack = opStack.evaluate({&x1, &x2}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, resultStack.status()); + sd::ops::stack opStack; + auto resultStack = opStack.evaluate({&x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, resultStack.status()); + sd::ops::concat opConcat; + auto resultConcat = opConcat.evaluate({&x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, resultConcat.status()); - sd::ops::concat opConcat; - auto resultConcat = opConcat.evaluate({&x1, &x2}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, resultConcat.status()); + auto outStack = resultStack.at(0); + auto outConcat = resultConcat.at(0); - auto outStack = resultStack.at(0); - auto outConcat = resultConcat.at(0); + outConcat.reshapei({2, 3, 2}); - outConcat.reshapei({2,3,2}); - - ASSERT_TRUE(outStack.isSameShape(outConcat)); - ASSERT_TRUE(outStack.equalsTo(outConcat)); + ASSERT_TRUE(outStack.isSameShape(outConcat)); + ASSERT_TRUE(outStack.equalsTo(outConcat)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Reshape1) { - const std::vector xShape = { 5,4,3 }; - const std::vector yShape = { 3,5,4 }; - - auto x = NDArrayFactory::create('f', xShape); - auto y = NDArrayFactory::create('f', yShape); + const std::vector xShape = {5, 4, 3}; + const std::vector yShape = {3, 5, 4}; + auto x = NDArrayFactory::create('f', xShape); + auto y = NDArrayFactory::create('f', yShape); - VariableSpace variableSpace; - variableSpace.putVariable(-1, x); - variableSpace.putVariable(-2, y); + VariableSpace variableSpace; + variableSpace.putVariable(-1, x); + variableSpace.putVariable(-2, y); - Context block(1, &variableSpace, false); - block.fillInputs({ -1, -2 }); + Context block(1, &variableSpace, false); + block.fillInputs({-1, -2}); - sd::ops::reshapeas reshape; + sd::ops::reshapeas reshape; - reshape.execute(&block); + reshape.execute(&block); - ASSERT_TRUE(variableSpace.hasVariable(1)); - auto z = variableSpace.getVariable(1)->getNDArray().get(); + ASSERT_TRUE(variableSpace.hasVariable(1)); + auto z = variableSpace.getVariable(1)->getNDArray().get(); - ASSERT_TRUE(y.isSameShape(z)); + ASSERT_TRUE(y.isSameShape(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests14, Reshape2) { - const std::vector xShape = { 5,4,3 }; - const std::vector yShape = { 3,5,4 }; + const std::vector xShape = {5, 4, 3}; + const std::vector yShape = {3, 5, 4}; - auto x = NDArrayFactory::create('c', xShape); - auto y = NDArrayFactory::create('c', yShape); + auto x = NDArrayFactory::create('c', xShape); + auto y = NDArrayFactory::create('c', yShape); - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - variableSpace->putVariable(1, std::make_shared()); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(1, std::make_shared()); - auto block = new Context(1, variableSpace, false); - block->fillInputs({ -1 }); - block->appendI(-y.ordering()); - block->appendI(3); - block->appendI(5); - block->appendI(4); + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI(-y.ordering()); + block->appendI(3); + block->appendI(5); + block->appendI(4); - sd::ops::reshape reshape; + sd::ops::reshape reshape; - Nd4jStatus status = reshape.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + Nd4jStatus status = reshape.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(result->isSameShape(y)); + ASSERT_TRUE(result->isSameShape(y)); - delete block; - delete variableSpace; + delete block; + delete variableSpace; } TEST_F(DeclarableOpsTests14, Reshape3) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + auto x = NDArrayFactory::create('c', {3, 4, 5}); - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { -99, 3, 4, 5 }); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 3, 4, 5}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(x.isSameShape(z)); + ASSERT_TRUE(x.isSameShape(z)); } TEST_F(DeclarableOpsTests14, Reshape4) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + auto x = NDArrayFactory::create('c', {3, 4, 5}); - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { 3, 4, 5 }); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {3, 4, 5}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(x.isSameShape(z)); + ASSERT_TRUE(x.isSameShape(z)); } TEST_F(DeclarableOpsTests14, Reshape5) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + auto x = NDArrayFactory::create('c', {3, 4, 5}); - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { 5, 4, 3 }); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {5, 4, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); } TEST_F(DeclarableOpsTests14, Reshape6) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - auto exp = NDArrayFactory::create('c', { 4, 15 }); + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto exp = NDArrayFactory::create('c', {4, 15}); - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { 4, -1 }); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {4, -1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(z.isSameShape(exp)); + ASSERT_TRUE(z.isSameShape(exp)); } TEST_F(DeclarableOpsTests14, Reshape7) { - auto x = NDArrayFactory::create('c', { 3, 4, 5 }); - auto exp = NDArrayFactory::create('c', { 60 }); + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto exp = NDArrayFactory::create('c', {60}); - sd::ops::reshape op; - auto result = op.evaluate({ &x }, {}, { -1 }); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(z.isSameShape(exp)); + ASSERT_TRUE(z.isSameShape(exp)); } TEST_F(DeclarableOpsTests14, Reshape8) { - auto x = NDArrayFactory::create('f', {2, 3}, {1.0, 4.0, 2.0, 5.0, 3.0, 6.0}); - auto e = NDArrayFactory::create('f', {3, 2}, {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); + auto x = NDArrayFactory::create('f', {2, 3}, + {1.0, 4.0, 2.0, 5.0, 3.0, 6.0}); + auto e = NDArrayFactory::create('f', {3, 2}, + {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); - auto r = x.reshape('c', {3, 2});; - r.streamline('f'); + auto r = x.reshape('c', {3, 2}); + ; + r.streamline('f'); - sd::ops::reshape op; - auto result = op.evaluate({&x}, {3, 2}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {3, 2}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); } TEST_F(DeclarableOpsTests14, Reshape9) { - auto array = NDArrayFactory::create(119.f); - auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); + auto array = NDArrayFactory::create(119.f); + auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); - sd::ops::reshape op; - auto result = op.evaluate({&array}, {}, {1, 1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reshape op; + auto result = op.evaluate({&array}, {}, {1, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Reshape10) { - auto array = NDArrayFactory::create(119.f); - auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); - auto z = NDArrayFactory::create('c', {1, 1}); + auto array = NDArrayFactory::create(119.f); + auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); + auto z = NDArrayFactory::create('c', {1, 1}); - sd::ops::reshape op; - auto result = op.execute({&array}, {&z}, {}, {1, 1}, {}); - ASSERT_EQ(Status::OK(), result); - ASSERT_EQ(e, z); + sd::ops::reshape op; + auto result = op.execute({&array}, {&z}, {}, {1, 1}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Reshape11) { - auto x = NDArrayFactory::create('c', {4, 3}); - auto exp = NDArrayFactory::create('c', {4, 3}); + auto x = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create('c', {4, 3}); - x.linspace(1); - exp.linspace(1); + x.linspace(1); + exp.linspace(1); - sd::ops::reshape op; - auto result = op.evaluate({&x}, {-99, 4, 3}); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {-99, 4, 3}); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Reshape12) { - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - auto shape = NDArrayFactory::create('c', {2}, {-1, 2}); - auto exp = NDArrayFactory::create('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + auto x = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + auto shape = NDArrayFactory::create('c', {2}, {-1, 2}); + auto exp = + NDArrayFactory::create('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - sd::ops::reshape op; - auto result = op.evaluate({&x, &shape}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reshape op; + auto result = op.evaluate({&x, &shape}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Reshape13) { - auto vector = NDArrayFactory::create('c', {1}, {119.0f}); - auto exp = NDArrayFactory::create(119.f); - auto empty = NDArrayFactory::empty_(); + auto vector = NDArrayFactory::create('c', {1}, {119.0f}); + auto exp = NDArrayFactory::create(119.f); + auto empty = NDArrayFactory::empty_(); - sd::ops::reshape op; - auto result = op.evaluate({&vector, empty}, {}, {}); + sd::ops::reshape op; + auto result = op.evaluate({&vector, empty}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(exp, result.at(0)); + ASSERT_EQ(exp, result.at(0)); - delete empty; + delete empty; } TEST_F(DeclarableOpsTests14, Reshape14) { - auto x = NDArrayFactory::create('c', {1, 0, 0, 2}); - auto y = NDArrayFactory::create('c', {2}, {10, 0}); - auto e = NDArrayFactory::create('c', {10, 0}); - - sd::ops::reshape op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {1, 0, 0, 2}); + auto y = NDArrayFactory::create('c', {2}, {10, 0}); + auto e = NDArrayFactory::create('c', {10, 0}); - auto z = result.at(0); + sd::ops::reshape op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_EQ(e, z); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests14, Reshape15) { + auto x0 = NDArrayFactory::create('c', {2, 0}); + auto x1 = NDArrayFactory::create('c', {0, 1, 2}); - auto x0 = NDArrayFactory::create('c', {2, 0}); - auto x1 = NDArrayFactory::create('c', {0, 1, 2}); + auto shape0 = NDArrayFactory::create('c', {3}, {2, 0, -1}); + auto shape1 = NDArrayFactory::create('c', {2}, {-1, 1}); - auto shape0 = NDArrayFactory::create('c', {3}, {2, 0, -1}); - auto shape1 = NDArrayFactory::create('c', {2}, {-1, 1}); + auto e0 = NDArrayFactory::create('c', {2, 0, 1}); + auto e1 = NDArrayFactory::create('c', {0, 1}); - auto e0 = NDArrayFactory::create('c', {2, 0, 1}); - auto e1 = NDArrayFactory::create('c', {0, 1}); + sd::ops::reshape op; + auto result0 = op.evaluate({&x0, &shape0}, {}, {}); + ASSERT_EQ(Status::OK(), result0.status()); + auto z0 = result0.at(0); + ASSERT_EQ(e0, z0); - sd::ops::reshape op; - auto result0 = op.evaluate({&x0, &shape0}, {}, {}); - ASSERT_EQ(Status::OK(), result0.status()); - auto z0 = result0.at(0); - ASSERT_EQ(e0, z0); - - auto result1 = op.evaluate({&x1, &shape1}, {}, {}); - ASSERT_EQ(Status::OK(), result1.status()); - auto z1 = result1.at(0); - ASSERT_EQ(e1, z1); + auto result1 = op.evaluate({&x1, &shape1}, {}, {}); + ASSERT_EQ(Status::OK(), result1.status()); + auto z1 = result1.at(0); + ASSERT_EQ(e1, z1); } TEST_F(DeclarableOpsTests14, Reshape16) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto shape = NDArrayFactory::create('c', {1, 3}, {1, 2, 2}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto shape = NDArrayFactory::create('c', {1, 3}, {1, 2, 2}); - auto exp = NDArrayFactory::create('c', {1, 2, 2}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 2, 2}, {1, 2, 3, 4}); - sd::ops::reshape op; + sd::ops::reshape op; - auto result = op.evaluate({&x, &shape}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&x, &shape}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Reshape17) { - auto x = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {2.0f}); + auto x = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {2.0f}); - sd::ops::reshape op; - auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Reshape18) { - auto x = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - - sd::ops::reshape op; - auto result = op.evaluate({&x}, {}, {-99, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests14, Reshape19) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - sd::ops::reshape op; - auto result = op.evaluate({&x}, {}, {-99, 1, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests14, Reshape20) { - - NDArray x1('c', {2,0}, sd::DataType::FLOAT32); - NDArray x2('c', {10,0}, sd::DataType::FLOAT32); - NDArray x3('c', {2,0,0,10}, sd::DataType::FLOAT32); - NDArray x4('c', {0,0,10}, sd::DataType::FLOAT32); - NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32); - NDArray x6('c', {0,10,0}, sd::DataType::FLOAT32); - NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32); - NDArray x8('c', {1,2,0}, sd::DataType::FLOAT32); - - sd::ops::reshape op; - - auto result = op.evaluate({&x1}, {}, {2, -1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0).isSameShape({2,0})); - - result = op.evaluate({&x2}, {}, {2, 0, -1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0).isSameShape({2,0,5})); - - result = op.evaluate({&x2}, {}, {5, 2, -1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0).isSameShape({5,2,0})); - - result = op.evaluate({&x2}, {}, {-1, 2, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0).isSameShape({5,2,0})); - - result = op.evaluate({&x3}, {}, {2, 0, -1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0).isSameShape({2,0,10})); - - result = op.evaluate({&x4}, {}, {2, -1, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0).isSameShape({2,5,0})); - - result = op.evaluate({&x5}, {}, {2, 0, 0, 0, -1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0).isSameShape({2,0,0,0,10})); - - result = op.evaluate({&x6}, {}, {-1, 2, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0).isSameShape({5, 2, 0})); - - result = op.evaluate({&x7}, {}, {-1, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0).isSameShape({2, 0})); - - result = op.evaluate({&x7}, {}, {10,0,50,100}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0).isSameShape({10,0,50,100})); - - result = op.evaluate({&x7}, {}, {2,0,-1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(result.at(0).isSameShape({2,0,1})); + NDArray x1('c', {2, 0}, sd::DataType::FLOAT32); + NDArray x2('c', {10, 0}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 0, 0, 10}, sd::DataType::FLOAT32); + NDArray x4('c', {0, 0, 10}, sd::DataType::FLOAT32); + NDArray x5('c', {0, 2, 10}, sd::DataType::FLOAT32); + NDArray x6('c', {0, 10, 0}, sd::DataType::FLOAT32); + NDArray x7('c', {0, 1, 2}, sd::DataType::FLOAT32); + NDArray x8('c', {1, 2, 0}, sd::DataType::FLOAT32); + + sd::ops::reshape op; + + auto result = op.evaluate({&x1}, {}, {2, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({2, 0})); + + result = op.evaluate({&x2}, {}, {2, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({2, 0, 5})); + + result = op.evaluate({&x2}, {}, {5, 2, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({5, 2, 0})); + + result = op.evaluate({&x2}, {}, {-1, 2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({5, 2, 0})); + + result = op.evaluate({&x3}, {}, {2, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({2, 0, 10})); + + result = op.evaluate({&x4}, {}, {2, -1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({2, 5, 0})); + + result = op.evaluate({&x5}, {}, {2, 0, 0, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({2, 0, 0, 0, 10})); + + result = op.evaluate({&x6}, {}, {-1, 2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({5, 2, 0})); + + result = op.evaluate({&x7}, {}, {-1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({2, 0})); + + result = op.evaluate({&x7}, {}, {10, 0, 50, 100}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({10, 0, 50, 100})); + + result = op.evaluate({&x7}, {}, {2, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0).isSameShape({2, 0, 1})); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 2eff1150d317..d8d0c9c60808 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -14,184 +14,186 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // Created by raver on 8/4/2018. // -#include "testlayers.h" -#include #include -#include #include +#include +#include + #include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTests15 : public testing::Test { -public: - - DeclarableOpsTests15() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests15() { + printf("\n"); + fflush(stdout); + } }; TEST_F(DeclarableOpsTests15, Test_NormalizeMoments_1) { - auto d = NDArrayFactory::create('c', {10, 10}); - auto w = NDArrayFactory::create(10); - auto x = NDArrayFactory::create('c', {10}); - auto y = NDArrayFactory::create('c', {10}); + auto d = NDArrayFactory::create('c', {10, 10}); + auto w = NDArrayFactory::create(10); + auto x = NDArrayFactory::create('c', {10}); + auto y = NDArrayFactory::create('c', {10}); - auto z0 = NDArrayFactory::create('c', {10}); - auto z1 = NDArrayFactory::create('c', {10}); + auto z0 = NDArrayFactory::create('c', {10}); + auto z1 = NDArrayFactory::create('c', {10}); - sd::ops::normalize_moments op; - auto result = op.execute({&w, &x, &y}, std::vector{&z0, &z1}, {1e-4}, {}, {}); - ASSERT_EQ(Status::OK(), result); + sd::ops::normalize_moments op; + auto result = + op.execute({&w, &x, &y}, std::vector{&z0, &z1}, {1e-4}, {}, {}); + ASSERT_EQ(Status::OK(), result); } TEST_F(DeclarableOpsTests15, Test_Add_1) { - auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); - auto y = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); - auto e = NDArrayFactory::create('c', {5}, {2, 2, 2, 2, 2}); + auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); + auto y = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); + auto e = NDArrayFactory::create('c', {5}, {2, 2, 2, 2, 2}); - sd::ops::add op; - auto result = op.execute({&x, &y}, {&x}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); - ASSERT_EQ(e, x); + sd::ops::add op; + auto result = op.execute({&x, &y}, {&x}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, x); } TEST_F(DeclarableOpsTests15, Test_Half_assign_1) { - auto x = NDArrayFactory::create('c', {2, 5}); - int y = 1; - x.assign(y); + auto x = NDArrayFactory::create('c', {2, 5}); + int y = 1; + x.assign(y); - ASSERT_EQ(10, x.sumNumber().e(0)); + ASSERT_EQ(10, x.sumNumber().e(0)); } TEST_F(DeclarableOpsTests15, Test_standarize_1) { - auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto e = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); + auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); - sd::ops::standardize op; - auto result = op.execute({&x}, {&x}, {}, {0}, {}); - ASSERT_EQ(Status::OK(), result); - ASSERT_EQ(e, x); + sd::ops::standardize op; + auto result = op.execute({&x}, {&x}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, x); } TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) { - auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto eps = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); - - sd::ops::standardize_bp op; - auto result = op.evaluate({&x, &eps}, {0}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto eps = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); + sd::ops::standardize_bp op; + auto result = op.evaluate({&x, &eps}, {0}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) { - auto x = NDArrayFactory::create('c', {4,4,3}); - NDArray factor = NDArrayFactory::create(2.); - auto e = NDArrayFactory::create('c', {4,4,3}, {-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, - 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, - 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, - 50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5}); + auto x = NDArrayFactory::create('c', {4, 4, 3}); + NDArray factor = NDArrayFactory::create(2.); + auto e = NDArrayFactory::create( + 'c', {4, 4, 3}, + {-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, + -2.5, -1.5, 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, + 16.5, 20.5, 21.5, 22.5, 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, + 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, 50.5, 51.5, 52.5, 56.5, + 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5}); + x.linspace(1.); + sd::ops::adjust_contrast op; + auto result = op.evaluate({&x, &factor}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); - x.linspace(1.); - sd::ops::adjust_contrast op; - auto result = op.evaluate({&x, &factor}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); - - ASSERT_TRUE(e.equalsTo(out)); - + ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) { - auto x = NDArrayFactory::create('c', {1, 4,4,3}); - auto e = NDArrayFactory::create('c', {1, 4,4,3}, { - -21.5f, -20.5f, -19.5f, -15.5f, -14.5f, -13.5f, -9.5f, -8.5f, -7.5f, -3.5f, -2.5f, -1.5f, - 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f, 20.5f, 21.5f, 22.5f, - 26.5f, 27.5f, 28.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 44.5f, 45.5f, 46.5f, - 50.5f, 51.5f, 52.5f, 56.5f, 57.5f, 58.5f, 62.5f, 63.5f, 64.5f, 68.5f, 69.5f, 70.5f - }); - x.linspace(1.); - sd::ops::adjust_contrast op; - auto result = op.evaluate({&x}, {2.}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); -// out->printIndexedBuffer("Adjusted Constrast"); - ASSERT_TRUE(e.equalsTo(out)); - + auto x = NDArrayFactory::create('c', {1, 4, 4, 3}); + auto e = NDArrayFactory::create( + 'c', {1, 4, 4, 3}, + {-21.5f, -20.5f, -19.5f, -15.5f, -14.5f, -13.5f, -9.5f, -8.5f, + -7.5f, -3.5f, -2.5f, -1.5f, 2.5f, 3.5f, 4.5f, 8.5f, + 9.5f, 10.5f, 14.5f, 15.5f, 16.5f, 20.5f, 21.5f, 22.5f, + 26.5f, 27.5f, 28.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, + 40.5f, 44.5f, 45.5f, 46.5f, 50.5f, 51.5f, 52.5f, 56.5f, + 57.5f, 58.5f, 62.5f, 63.5f, 64.5f, 68.5f, 69.5f, 70.5f}); + x.linspace(1.); + sd::ops::adjust_contrast op; + auto result = op.evaluate({&x}, {2.}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + // out->printIndexedBuffer("Adjusted Constrast"); + ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) { - auto x = NDArrayFactory::create('c', {1, 4,4,3}); - auto e = NDArrayFactory::create('c', {1, 4,4,3}, { - -21.5f, -20.5f, -19.5f, -15.5f, -14.5f, -13.5f, -9.5f, -8.5f, -7.5f, -3.5f, -2.5f, -1.5f, - 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f, 20.5f, 21.5f, 22.5f, - 26.5f, 27.5f, 28.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 44.5f, 45.5f, 46.5f, - 50.5f, 51.5f, 52.5f, 56.5f, 57.5f, 58.5f, 62.5f, 63.5f, 64.5f, 68.5f, 69.5f, 70.5f - }); - x.linspace(1.); - sd::ops::adjust_contrast_v2 op; - auto result = op.evaluate({&x}, {2.}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); -// out->printIndexedBuffer("Adjusted Constrast"); - ASSERT_TRUE(e.equalsTo(out)); - + auto x = NDArrayFactory::create('c', {1, 4, 4, 3}); + auto e = NDArrayFactory::create( + 'c', {1, 4, 4, 3}, + {-21.5f, -20.5f, -19.5f, -15.5f, -14.5f, -13.5f, -9.5f, -8.5f, + -7.5f, -3.5f, -2.5f, -1.5f, 2.5f, 3.5f, 4.5f, 8.5f, + 9.5f, 10.5f, 14.5f, 15.5f, 16.5f, 20.5f, 21.5f, 22.5f, + 26.5f, 27.5f, 28.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, + 40.5f, 44.5f, 45.5f, 46.5f, 50.5f, 51.5f, 52.5f, 56.5f, + 57.5f, 58.5f, 62.5f, 63.5f, 64.5f, 68.5f, 69.5f, 70.5f}); + x.linspace(1.); + sd::ops::adjust_contrast_v2 op; + auto result = op.evaluate({&x}, {2.}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + // out->printIndexedBuffer("Adjusted Constrast"); + ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { - auto x = NDArrayFactory::create('c', {4, 4, 3}); - auto e = NDArrayFactory::create('c', {4, 4, 3}, { - -21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, - 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, - 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, - 50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5 - }); - x.linspace(1.); - sd::ops::adjust_contrast_v2 op; - auto result = op.evaluate({&x}, {2.}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); -// out->printIndexedBuffer("Adjusted Constrast"); - ASSERT_TRUE(e.equalsTo(out)); - + auto x = NDArrayFactory::create('c', {4, 4, 3}); + auto e = NDArrayFactory::create( + 'c', {4, 4, 3}, + {-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, + -2.5, -1.5, 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, + 16.5, 20.5, 21.5, 22.5, 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, + 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, 50.5, 51.5, 52.5, 56.5, + 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5}); + x.linspace(1.); + sd::ops::adjust_contrast_v2 op; + auto result = op.evaluate({&x}, {2.}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + // out->printIndexedBuffer("Adjusted Constrast"); + ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) { - auto x = NDArrayFactory::create('c', {1, 3, 4}); - auto e = NDArrayFactory::create('c', {1, 3, 4}, { - -3., -2., -1., 0., 5., 6., 7., 8., 13., 14., 15., 16. - }); - x.linspace(1.); - sd::ops::adjust_contrast_v2 op; - auto result = op.evaluate({&x}, {2.}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); -// out->printIndexedBuffer("Adjusted Constrast"); - ASSERT_TRUE(e.equalsTo(out)); - + auto x = NDArrayFactory::create('c', {1, 3, 4}); + auto e = NDArrayFactory::create( + 'c', {1, 3, 4}, {-3., -2., -1., 0., 5., 6., 7., 8., 13., 14., 15., 16.}); + x.linspace(1.); + sd::ops::adjust_contrast_v2 op; + auto result = op.evaluate({&x}, {2.}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + // out->printIndexedBuffer("Adjusted Constrast"); + ASSERT_TRUE(e.equalsTo(out)); } /* * public void testAdjustContrast1() { - INDArray in = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f, + INDArray in = Nd4j.createFromArray(new + float[]{0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f, 0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,0.3087f,0.1548f,0.4695f, 0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,0.4601f,0.8284f,0.2354f,0.9752f,0.8361f, 0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f, 0.0755f,0.6245f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f,0.3019f, 0.3574f,0.1704f,0.8395f,0.5468f,0.0744f,0.9011f,0.6574f,0.4124f,0.2445f,0.4248f,0.5219f, 0.6952f,0.4900f,0.2158f,0.9549f,0.1386f,0.1544f,0.5365f,0.0134f,0.4163f,0.1456f,0.4109f, - 0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, 0.7530f, 0.7215f, 0.6612f, 0.7270f, + 0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, + 0.7530f, 0.7215f, 0.6612f, 0.7270f, 0.5704f,0.2666f,0.7453f,0.0444f,0.3024f,0.4850f,0.7982f,0.0965f,0.7843f,0.5075f, - 0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, + 0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, + 0.2744f, 0.1981f, 0.4143f, 0.7821f,0.3505f,0.5040f,0.1180f,0.8307f,0.1817f,0.8442f,0.5074f,0.4471f,0.5105f,0.6666f, 0.2576f,0.2341f,0.6801f,0.2652f,0.5394f,0.4690f,0.6146f,0.1210f,0.2576f,0.0769f,0.4643f, 0.1628f,0.2026f,0.3774f,0.0506f,0.3462f,0.5720f,0.0838f,0.4228f,0.0588f,0.5362f,0.4756f, @@ -209,1782 +211,2148 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) { * */ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) { - auto x = NDArrayFactory::create('c', {8,8, 3, 1}, {0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f, - 0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,0.3087f,0.1548f,0.4695f, - 0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,0.4601f,0.8284f,0.2354f,0.9752f,0.8361f, - 0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f, - 0.0755f,0.6245f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f,0.3019f, - 0.3574f,0.1704f,0.8395f,0.5468f,0.0744f,0.9011f,0.6574f,0.4124f,0.2445f,0.4248f,0.5219f, - 0.6952f,0.4900f,0.2158f,0.9549f,0.1386f,0.1544f,0.5365f,0.0134f,0.4163f,0.1456f,0.4109f, - 0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, 0.7530f, 0.7215f, 0.6612f, 0.7270f, - 0.5704f,0.2666f,0.7453f,0.0444f,0.3024f,0.4850f,0.7982f,0.0965f,0.7843f,0.5075f, - 0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, - 0.7821f,0.3505f,0.5040f,0.1180f,0.8307f,0.1817f,0.8442f,0.5074f,0.4471f,0.5105f,0.6666f, - 0.2576f,0.2341f,0.6801f,0.2652f,0.5394f,0.4690f,0.6146f,0.1210f,0.2576f,0.0769f,0.4643f, - 0.1628f,0.2026f,0.3774f,0.0506f,0.3462f,0.5720f,0.0838f,0.4228f,0.0588f,0.5362f,0.4756f, - 0.2530f,0.1778f,0.0751f,0.8977f,0.3648f,0.3065f,0.4739f,0.7014f,0.4473f,0.5171f,0.1744f, - 0.3487f,0.7759f,0.9491f,0.2072f,0.2182f,0.6520f,0.3092f,0.9545f,0.1881f,0.9579f,0.1785f, - 0.9636f,0.4830f,0.6569f,0.3353f,0.9997f,0.5869f,0.5747f,0.0238f,0.2943f,0.5248f,0.5879f, - .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, - 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f}); - auto e = NDArrayFactory::create('c', {8, 8, 3, 1}, { - 1.0218375f, 1.0666375f, 0.9130375f, - -0.07396251f, 0.91843754f, -0.17496246f, - 0.47543746f, 1.2492375f, 0.55643755f, - 1.3110375f, -0.36456245f, 1.0518374f, - 0.7824375f, 0.57523745f, -0.21656245f, - 0.0816375f, -0.2261625f, 0.40323752f, - 1.4520376f, 0.6868375f, 0.81723756f, - -0.17576247f, 0.81423753f, -0.08656245f, - - -0.36249164f, 0.45590833f, 1.1925083f, - 0.00650835f, 1.4861084f, 1.2079083f, - 0.05270836f, 0.37350836f, 0.94130826f, - 1.0715083f, 0.6103083f, 0.9825083f, - 0.07370833f, -0.4518917f, -0.39889166f, - -0.3354917f, 1.2213084f, 1.0345083f, - -0.3132917f, 0.78470826f, 0.23390833f, - 0.6943083f, 0.68170834f, -0.09989169f, - - 0.8352709f, 1.3798709f, 0.15507084f, - 0.26607084f, -0.10792917f, 1.2302709f, - 0.6448709f, -0.29992914f, 1.3534708f, - 0.86607087f, 0.37607086f, 0.04027084f, - 0.40087086f, 0.59507084f, 0.9416709f, - 0.53127086f, -0.01712915f, 1.4610709f, - -0.17152917f, -0.13992918f, 0.6242708f, - -0.42192918f, 0.38387084f, -0.15752912f, - - 0.3311833f, 0.00618333f, 0.17538333f, - 0.10418332f, 0.8365834f, 0.27098334f, - 1.2421833f, -0.1114167f, 1.0153834f, - 0.9523833f, 0.8317833f, 0.9633833f, - 0.6501833f, 0.04258335f, 0.9999833f, - -0.40181667f, 0.11418331f, 0.47938335f, - 1.1057833f, -0.29761666f, 1.0779834f, - 0.5243833f, -0.32181668f, 1.1833833f, - - 0.73157084f, 0.4317708f, 0.7283708f, - 1.2297708f, 0.4307708f, 0.85377085f, - 0.05977082f, -0.09282917f, 0.33957082f, - 1.0751709f, 0.2119708f, 0.51897085f, - -0.25302917f, 1.1723708f, -0.12562919f, - 1.1993709f, 0.5257708f, 0.40517086f, - 0.53197086f, 0.8441708f, 0.02617085f, - -0.0208292f, 0.8711709f, 0.04137081f, - - 0.74936247f, 0.6085625f, 0.8997625f, - -0.08743751f, 0.18576252f, -0.17563748f, - 0.5991625f, -0.0038375f, 0.07576251f, - 0.42536253f, -0.22823751f, 0.36296248f, - 0.81456256f, -0.16183749f, 0.5161625f, - -0.21183747f, 0.7429625f, 0.6217625f, - 0.17656249f, 0.02616251f, -0.17923748f, - 1.4659625f, 0.40016252f, 0.28356248f, - - 0.4195791f, 0.8745791f, 0.36637908f, - 0.50597906f, -0.17942089f, 0.16917908f, - 1.0235791f, 1.3699791f, -0.11382091f, - -0.0918209f, 0.7757791f, 0.09017909f, - 1.3807791f, -0.15202093f, 1.3875791f, - -0.1712209f, 1.3989791f, 0.43777913f, - 0.7855791f, 0.1423791f, 1.4711791f, - 0.6455791f, 0.6211791f, -0.48062086f, - - 0.10189578f, 0.5628958f, 0.68909574f, - 0.96649575f, -0.09370419f, 1.3466958f, - 1.4584957f, 1.3544958f, -0.3829042f, - 0.11269578f, -0.47890422f, 1.0436958f, - 0.6128957f, 0.27209583f, 0.2714958f, - 0.21889582f, 0.08789578f, 1.1296958f, - 0.4596958f, 0.39309582f, 0.8344958f, - 0.71149576f, -0.4799042f, 0.4880958f - }); - - sd::ops::adjust_contrast op; - auto result = op.evaluate({&x}, {2.}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); -// out->printBuffer("Adjusted Constrast6"); -// e.printBuffer("Adjusted Expected 6"); -// ASSERT_TRUE(e.equalsTo(out)); - + auto x = NDArrayFactory::create( + 'c', {8, 8, 3, 1}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f, 0.3087f, + 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f, + 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, + 0.7028f, 0.7679f, 0.5373f, 0.7234f, 0.2690f, 0.0062f, 0.0327f, 0.0644f, + 0.8428f, 0.7494f, 0.0755f, 0.6245f, 0.3491f, 0.5793f, 0.5730f, 0.1822f, + 0.6420f, 0.9143f, 0.3019f, 0.3574f, 0.1704f, 0.8395f, 0.5468f, 0.0744f, + 0.9011f, 0.6574f, 0.4124f, 0.2445f, 0.4248f, 0.5219f, 0.6952f, 0.4900f, + 0.2158f, 0.9549f, 0.1386f, 0.1544f, 0.5365f, 0.0134f, 0.4163f, 0.1456f, + 0.4109f, 0.2484f, 0.3330f, 0.2974f, 0.6636f, 0.3808f, 0.8664f, 0.1896f, + 0.7530f, 0.7215f, 0.6612f, 0.7270f, 0.5704f, 0.2666f, 0.7453f, 0.0444f, + 0.3024f, 0.4850f, 0.7982f, 0.0965f, 0.7843f, 0.5075f, 0.0844f, 0.8370f, + 0.6103f, 0.4604f, 0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, + 0.4143f, 0.7821f, 0.3505f, 0.5040f, 0.1180f, 0.8307f, 0.1817f, 0.8442f, + 0.5074f, 0.4471f, 0.5105f, 0.6666f, 0.2576f, 0.2341f, 0.6801f, 0.2652f, + 0.5394f, 0.4690f, 0.6146f, 0.1210f, 0.2576f, 0.0769f, 0.4643f, 0.1628f, + 0.2026f, 0.3774f, 0.0506f, 0.3462f, 0.5720f, 0.0838f, 0.4228f, 0.0588f, + 0.5362f, 0.4756f, 0.2530f, 0.1778f, 0.0751f, 0.8977f, 0.3648f, 0.3065f, + 0.4739f, 0.7014f, 0.4473f, 0.5171f, 0.1744f, 0.3487f, 0.7759f, 0.9491f, + 0.2072f, 0.2182f, 0.6520f, 0.3092f, 0.9545f, 0.1881f, 0.9579f, 0.1785f, + 0.9636f, 0.4830f, 0.6569f, 0.3353f, 0.9997f, 0.5869f, 0.5747f, 0.0238f, + 0.2943f, 0.5248f, 0.5879f, .7266f, 0.1965f, 0.9167f, 0.9726f, 0.9206f, + 0.0519f, 0.2997f, 0.0039f, 0.7652f, 0.5498f, 0.3794f, 0.3791f, 0.3528f, + 0.2873f, 0.8082f, 0.4732f, 0.4399f, 0.6606f, 0.5991f, 0.0034f, 0.4874f}); + auto e = NDArrayFactory::create( + 'c', {8, 8, 3, 1}, + {1.0218375f, 1.0666375f, 0.9130375f, -0.07396251f, 0.91843754f, + -0.17496246f, 0.47543746f, 1.2492375f, 0.55643755f, 1.3110375f, + -0.36456245f, 1.0518374f, 0.7824375f, 0.57523745f, -0.21656245f, + 0.0816375f, -0.2261625f, 0.40323752f, 1.4520376f, 0.6868375f, + 0.81723756f, -0.17576247f, 0.81423753f, -0.08656245f, + + -0.36249164f, 0.45590833f, 1.1925083f, 0.00650835f, 1.4861084f, + 1.2079083f, 0.05270836f, 0.37350836f, 0.94130826f, 1.0715083f, + 0.6103083f, 0.9825083f, 0.07370833f, -0.4518917f, -0.39889166f, + -0.3354917f, 1.2213084f, 1.0345083f, -0.3132917f, 0.78470826f, + 0.23390833f, 0.6943083f, 0.68170834f, -0.09989169f, + + 0.8352709f, 1.3798709f, 0.15507084f, 0.26607084f, -0.10792917f, + 1.2302709f, 0.6448709f, -0.29992914f, 1.3534708f, 0.86607087f, + 0.37607086f, 0.04027084f, 0.40087086f, 0.59507084f, 0.9416709f, + 0.53127086f, -0.01712915f, 1.4610709f, -0.17152917f, -0.13992918f, + 0.6242708f, -0.42192918f, 0.38387084f, -0.15752912f, + + 0.3311833f, 0.00618333f, 0.17538333f, 0.10418332f, 0.8365834f, + 0.27098334f, 1.2421833f, -0.1114167f, 1.0153834f, 0.9523833f, + 0.8317833f, 0.9633833f, 0.6501833f, 0.04258335f, 0.9999833f, + -0.40181667f, 0.11418331f, 0.47938335f, 1.1057833f, -0.29761666f, + 1.0779834f, 0.5243833f, -0.32181668f, 1.1833833f, + + 0.73157084f, 0.4317708f, 0.7283708f, 1.2297708f, 0.4307708f, + 0.85377085f, 0.05977082f, -0.09282917f, 0.33957082f, 1.0751709f, + 0.2119708f, 0.51897085f, -0.25302917f, 1.1723708f, -0.12562919f, + 1.1993709f, 0.5257708f, 0.40517086f, 0.53197086f, 0.8441708f, + 0.02617085f, -0.0208292f, 0.8711709f, 0.04137081f, + + 0.74936247f, 0.6085625f, 0.8997625f, -0.08743751f, 0.18576252f, + -0.17563748f, 0.5991625f, -0.0038375f, 0.07576251f, 0.42536253f, + -0.22823751f, 0.36296248f, 0.81456256f, -0.16183749f, 0.5161625f, + -0.21183747f, 0.7429625f, 0.6217625f, 0.17656249f, 0.02616251f, + -0.17923748f, 1.4659625f, 0.40016252f, 0.28356248f, + + 0.4195791f, 0.8745791f, 0.36637908f, 0.50597906f, -0.17942089f, + 0.16917908f, 1.0235791f, 1.3699791f, -0.11382091f, -0.0918209f, + 0.7757791f, 0.09017909f, 1.3807791f, -0.15202093f, 1.3875791f, + -0.1712209f, 1.3989791f, 0.43777913f, 0.7855791f, 0.1423791f, + 1.4711791f, 0.6455791f, 0.6211791f, -0.48062086f, + + 0.10189578f, 0.5628958f, 0.68909574f, 0.96649575f, -0.09370419f, + 1.3466958f, 1.4584957f, 1.3544958f, -0.3829042f, 0.11269578f, + -0.47890422f, 1.0436958f, 0.6128957f, 0.27209583f, 0.2714958f, + 0.21889582f, 0.08789578f, 1.1296958f, 0.4596958f, 0.39309582f, + 0.8344958f, 0.71149576f, -0.4799042f, 0.4880958f}); + + sd::ops::adjust_contrast op; + auto result = op.evaluate({&x}, {2.}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + // out->printBuffer("Adjusted Constrast6"); + // e.printBuffer("Adjusted Expected 6"); + // ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) { - auto x = NDArrayFactory::create('c', {8,8, 3, 1}, {0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f, - 0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,0.3087f,0.1548f,0.4695f, - 0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,0.4601f,0.8284f,0.2354f,0.9752f,0.8361f, - 0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f, - 0.0755f,0.6245f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f,0.3019f, - 0.3574f,0.1704f,0.8395f,0.5468f,0.0744f,0.9011f,0.6574f,0.4124f,0.2445f,0.4248f,0.5219f, - 0.6952f,0.4900f,0.2158f,0.9549f,0.1386f,0.1544f,0.5365f,0.0134f,0.4163f,0.1456f,0.4109f, - 0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, 0.7530f, 0.7215f, 0.6612f, 0.7270f, - 0.5704f,0.2666f,0.7453f,0.0444f,0.3024f,0.4850f,0.7982f,0.0965f,0.7843f,0.5075f, - 0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, - 0.7821f,0.3505f,0.5040f,0.1180f,0.8307f,0.1817f,0.8442f,0.5074f,0.4471f,0.5105f,0.6666f, - 0.2576f,0.2341f,0.6801f,0.2652f,0.5394f,0.4690f,0.6146f,0.1210f,0.2576f,0.0769f,0.4643f, - 0.1628f,0.2026f,0.3774f,0.0506f,0.3462f,0.5720f,0.0838f,0.4228f,0.0588f,0.5362f,0.4756f, - 0.2530f,0.1778f,0.0751f,0.8977f,0.3648f,0.3065f,0.4739f,0.7014f,0.4473f,0.5171f,0.1744f, - 0.3487f,0.7759f,0.9491f,0.2072f,0.2182f,0.6520f,0.3092f,0.9545f,0.1881f,0.9579f,0.1785f, - 0.9636f,0.4830f,0.6569f,0.3353f,0.9997f,0.5869f,0.5747f,0.0238f,0.2943f,0.5248f,0.5879f, - .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, - 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f}); - auto e = NDArrayFactory::create('c', {8, 8, 3, 1}, { - 1.0218375, 1.0666375 , 0.9130375 , - -0.07396251, 0.91843754, -0.17496246, - 0.47543746, 1.2492375 , 0.55643755, - 1.3110375 , -0.36456245, 1.0518374 , - 0.7824375 , 0.57523745, -0.21656245, - 0.0816375 , -0.2261625 , 0.40323752, - 1.4520376 , 0.6868375 , 0.81723756, - -0.17576247, 0.81423753, -0.08656245, - - -0.36249164, 0.45590833, 1.1925083 , - 0.00650835, 1.4861084 , 1.2079083 , - 0.05270836, 0.37350836, 0.94130826, - 1.0715083 , 0.6103083 , 0.9825083 , - 0.07370833, -0.4518917 , -0.39889166, - -0.3354917 , 1.2213084 , 1.0345083 , - -0.3132917 , 0.78470826, 0.23390833, - 0.6943083 , 0.68170834, -0.09989169, - - 0.8352709 , 1.3798709 , 0.15507084, - 0.26607084, -0.10792917, 1.2302709 , - 0.6448709 , -0.29992914, 1.3534708 , - 0.86607087, 0.37607086, 0.04027084, - 0.40087086, 0.59507084, 0.9416709 , - 0.53127086, -0.01712915, 1.4610709 , - -0.17152917, -0.13992918, 0.6242708 , - -0.42192918, 0.38387084, -0.15752912, - - - 0.3311833 , 0.00618333, 0.17538333, - 0.10418332, 0.8365834 , 0.27098334, - 1.2421833 , -0.1114167 , 1.0153834 , - 0.9523833 , 0.8317833 , 0.9633833 , - 0.6501833 , 0.04258335, 0.9999833 , - -0.40181667, 0.11418331, 0.47938335, - 1.1057833 , -0.29761666, 1.0779834 , - 0.5243833 , -0.32181668, 1.1833833 , - - 0.73157084, 0.4317708 , 0.7283708 , - 1.2297708 , 0.4307708 , 0.85377085, - 0.05977082, -0.09282917, 0.33957082, - 1.0751709 , 0.2119708 , 0.51897085, - -0.25302917, 1.1723708 , -0.12562919, - 1.1993709 , 0.5257708 , 0.40517086, - 0.53197086, 0.8441708 , 0.02617085, - -0.0208292 , 0.8711709 , 0.04137081, - - 0.74936247, 0.6085625 , 0.8997625 , - -0.08743751, 0.18576252, -0.17563748, - 0.5991625 , -0.0038375 , 0.07576251, - 0.42536253, -0.22823751, 0.36296248, - 0.81456256, -0.16183749, 0.5161625 , - -0.21183747, 0.7429625 , 0.6217625 , - 0.17656249, 0.02616251, -0.17923748, - 1.4659625 , 0.40016252, 0.28356248, - - 0.4195791 , 0.8745791 , 0.36637908, - 0.50597906, -0.17942089, 0.16917908, - 1.0235791 , 1.3699791 , -0.11382091, - -0.0918209 , 0.7757791 , 0.09017909, - 1.3807791 , -0.15202093, 1.3875791 , - -0.1712209 , 1.3989791 , 0.43777913, - 0.7855791 , 0.1423791 , 1.4711791 , - 0.6455791 , 0.6211791 , -0.48062086, - - - 0.10189578, 0.5628958 , 0.68909574, - 0.96649575, -0.09370419, 1.3466958 , - 1.4584957 , 1.3544958 , -0.3829042 , - 0.11269578, -0.47890422, 1.0436958 , - 0.6128957 , 0.27209583, 0.2714958 , - 0.21889582, 0.08789578, 1.1296958 , - 0.4596958 , 0.39309582, 0.8344958 , - 0.71149576, -0.4799042, 0.4880958 - }); -// x.linspace(1.); - sd::ops::adjust_contrast_v2 op; - auto result = op.evaluate({&x}, {2.}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); -// out->printBuffer("Adjusted Constrast7"); -// e.printBuffer("Adjusted expected 7"); - auto diff = e - out; -// diff.printBuffer("Adjusted subtract 7"); - ASSERT_TRUE(e.equalsTo(out)); - + auto x = NDArrayFactory::create( + 'c', {8, 8, 3, 1}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f, 0.3087f, + 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f, + 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, 0.2585f, 0.4189f, + 0.7028f, 0.7679f, 0.5373f, 0.7234f, 0.2690f, 0.0062f, 0.0327f, 0.0644f, + 0.8428f, 0.7494f, 0.0755f, 0.6245f, 0.3491f, 0.5793f, 0.5730f, 0.1822f, + 0.6420f, 0.9143f, 0.3019f, 0.3574f, 0.1704f, 0.8395f, 0.5468f, 0.0744f, + 0.9011f, 0.6574f, 0.4124f, 0.2445f, 0.4248f, 0.5219f, 0.6952f, 0.4900f, + 0.2158f, 0.9549f, 0.1386f, 0.1544f, 0.5365f, 0.0134f, 0.4163f, 0.1456f, + 0.4109f, 0.2484f, 0.3330f, 0.2974f, 0.6636f, 0.3808f, 0.8664f, 0.1896f, + 0.7530f, 0.7215f, 0.6612f, 0.7270f, 0.5704f, 0.2666f, 0.7453f, 0.0444f, + 0.3024f, 0.4850f, 0.7982f, 0.0965f, 0.7843f, 0.5075f, 0.0844f, 0.8370f, + 0.6103f, 0.4604f, 0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, + 0.4143f, 0.7821f, 0.3505f, 0.5040f, 0.1180f, 0.8307f, 0.1817f, 0.8442f, + 0.5074f, 0.4471f, 0.5105f, 0.6666f, 0.2576f, 0.2341f, 0.6801f, 0.2652f, + 0.5394f, 0.4690f, 0.6146f, 0.1210f, 0.2576f, 0.0769f, 0.4643f, 0.1628f, + 0.2026f, 0.3774f, 0.0506f, 0.3462f, 0.5720f, 0.0838f, 0.4228f, 0.0588f, + 0.5362f, 0.4756f, 0.2530f, 0.1778f, 0.0751f, 0.8977f, 0.3648f, 0.3065f, + 0.4739f, 0.7014f, 0.4473f, 0.5171f, 0.1744f, 0.3487f, 0.7759f, 0.9491f, + 0.2072f, 0.2182f, 0.6520f, 0.3092f, 0.9545f, 0.1881f, 0.9579f, 0.1785f, + 0.9636f, 0.4830f, 0.6569f, 0.3353f, 0.9997f, 0.5869f, 0.5747f, 0.0238f, + 0.2943f, 0.5248f, 0.5879f, .7266f, 0.1965f, 0.9167f, 0.9726f, 0.9206f, + 0.0519f, 0.2997f, 0.0039f, 0.7652f, 0.5498f, 0.3794f, 0.3791f, 0.3528f, + 0.2873f, 0.8082f, 0.4732f, 0.4399f, 0.6606f, 0.5991f, 0.0034f, 0.4874f}); + auto e = NDArrayFactory::create( + 'c', {8, 8, 3, 1}, + {1.0218375, 1.0666375, 0.9130375, -0.07396251, 0.91843754, + -0.17496246, 0.47543746, 1.2492375, 0.55643755, 1.3110375, + -0.36456245, 1.0518374, 0.7824375, 0.57523745, -0.21656245, + 0.0816375, -0.2261625, 0.40323752, 1.4520376, 0.6868375, + 0.81723756, -0.17576247, 0.81423753, -0.08656245, + + -0.36249164, 0.45590833, 1.1925083, 0.00650835, 1.4861084, + 1.2079083, 0.05270836, 0.37350836, 0.94130826, 1.0715083, + 0.6103083, 0.9825083, 0.07370833, -0.4518917, -0.39889166, + -0.3354917, 1.2213084, 1.0345083, -0.3132917, 0.78470826, + 0.23390833, 0.6943083, 0.68170834, -0.09989169, + + 0.8352709, 1.3798709, 0.15507084, 0.26607084, -0.10792917, + 1.2302709, 0.6448709, -0.29992914, 1.3534708, 0.86607087, + 0.37607086, 0.04027084, 0.40087086, 0.59507084, 0.9416709, + 0.53127086, -0.01712915, 1.4610709, -0.17152917, -0.13992918, + 0.6242708, -0.42192918, 0.38387084, -0.15752912, + + 0.3311833, 0.00618333, 0.17538333, 0.10418332, 0.8365834, + 0.27098334, 1.2421833, -0.1114167, 1.0153834, 0.9523833, + 0.8317833, 0.9633833, 0.6501833, 0.04258335, 0.9999833, + -0.40181667, 0.11418331, 0.47938335, 1.1057833, -0.29761666, + 1.0779834, 0.5243833, -0.32181668, 1.1833833, + + 0.73157084, 0.4317708, 0.7283708, 1.2297708, 0.4307708, + 0.85377085, 0.05977082, -0.09282917, 0.33957082, 1.0751709, + 0.2119708, 0.51897085, -0.25302917, 1.1723708, -0.12562919, + 1.1993709, 0.5257708, 0.40517086, 0.53197086, 0.8441708, + 0.02617085, -0.0208292, 0.8711709, 0.04137081, + + 0.74936247, 0.6085625, 0.8997625, -0.08743751, 0.18576252, + -0.17563748, 0.5991625, -0.0038375, 0.07576251, 0.42536253, + -0.22823751, 0.36296248, 0.81456256, -0.16183749, 0.5161625, + -0.21183747, 0.7429625, 0.6217625, 0.17656249, 0.02616251, + -0.17923748, 1.4659625, 0.40016252, 0.28356248, + + 0.4195791, 0.8745791, 0.36637908, 0.50597906, -0.17942089, + 0.16917908, 1.0235791, 1.3699791, -0.11382091, -0.0918209, + 0.7757791, 0.09017909, 1.3807791, -0.15202093, 1.3875791, + -0.1712209, 1.3989791, 0.43777913, 0.7855791, 0.1423791, + 1.4711791, 0.6455791, 0.6211791, -0.48062086, + + 0.10189578, 0.5628958, 0.68909574, 0.96649575, -0.09370419, + 1.3466958, 1.4584957, 1.3544958, -0.3829042, 0.11269578, + -0.47890422, 1.0436958, 0.6128957, 0.27209583, 0.2714958, + 0.21889582, 0.08789578, 1.1296958, 0.4596958, 0.39309582, + 0.8344958, 0.71149576, -0.4799042, 0.4880958}); + // x.linspace(1.); + sd::ops::adjust_contrast_v2 op; + auto result = op.evaluate({&x}, {2.}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + // out->printBuffer("Adjusted Constrast7"); + // e.printBuffer("Adjusted expected 7"); + auto diff = e - out; + // diff.printBuffer("Adjusted subtract 7"); + ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_BitCast_1) { - auto x = NDArrayFactory::create('c', {2, 2, 2}); - auto e = NDArrayFactory::create('c', {2, 2}, {2., 512., 8192., 131072.032 }); - x.linspace(1.); - - sd::ops::bitcast op; - auto result = op.evaluate({&x}, {(int) sd::DataType::DOUBLE}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); -// out->printIndexedBuffer("Casted result"); - ASSERT_TRUE(e.equalsTo(out)); + auto x = NDArrayFactory::create('c', {2, 2, 2}); + auto e = NDArrayFactory::create('c', {2, 2}, + {2., 512., 8192., 131072.032}); + x.linspace(1.); + sd::ops::bitcast op; + auto result = op.evaluate({&x}, {(int)sd::DataType::DOUBLE}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + // out->printIndexedBuffer("Casted result"); + ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_BitCast_2) { - auto x = NDArrayFactory::create('c', {2, 4}); - auto e = NDArrayFactory::create('c', {2, 4, 2}, {0.f, 1.875f, 0.f, 2.f, 0.f, 2.125f, 0.f, 2.25f, - 0.f, 2.312f, 0.f, 2.375f, 0.f, 2.438f, 0.f, 2.5f}); - x.linspace(1.); + auto x = NDArrayFactory::create('c', {2, 4}); + auto e = NDArrayFactory::create( + 'c', {2, 4, 2}, + {0.f, 1.875f, 0.f, 2.f, 0.f, 2.125f, 0.f, 2.25f, 0.f, 2.312f, 0.f, 2.375f, + 0.f, 2.438f, 0.f, 2.5f}); + x.linspace(1.); - sd::ops::bitcast op; - auto result = op.evaluate({&x}, {(int) sd::DataType::HALF}); - ASSERT_EQ(Status::OK(), result.status()); - auto out = result.at(0); - - ASSERT_TRUE(e.equalsTo(out)); + sd::ops::bitcast op; + auto result = op.evaluate({&x}, {(int)sd::DataType::HALF}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + ASSERT_TRUE(e.equalsTo(out)); } TEST_F(DeclarableOpsTests15, Test_BitCast_3) { - auto x = NDArrayFactory::create('c', {1, 4}); + auto x = NDArrayFactory::create('c', {1, 4}); - x.linspace(1.); - sd::ops::bitcast op; - try { - auto result = op.evaluate({&x}, {(int) sd::DataType::INT64}); - ASSERT_NE(Status::OK(), result.status()); + x.linspace(1.); + sd::ops::bitcast op; + try { + auto result = op.evaluate({&x}, {(int)sd::DataType::INT64}); + ASSERT_NE(Status::OK(), result.status()); - } catch (std::exception& e) { - nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); - } + } catch (std::exception& e) { + nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); + } } TEST_F(DeclarableOpsTests15, Test_BitCast_4) { - auto x = NDArrayFactory::create('c', {1, 4}); - auto e = NDArrayFactory::create('c', {1, 2}, {1234567890LL, 2468013579LL}); - x.linspace(1.); - sd::ops::bitcast op; - try { - auto result = op.execute({&x}, {&e}, {}, {sd::DataType::INT64}, {}); - ASSERT_NE(Status::OK(), result); - } catch(std::exception& e) { - nd4j_printf("Error `%s' should be here. It's OK.\n",e.what()); - } - + auto x = NDArrayFactory::create('c', {1, 4}); + auto e = NDArrayFactory::create('c', {1, 2}, + {1234567890LL, 2468013579LL}); + x.linspace(1.); + sd::ops::bitcast op; + try { + auto result = op.execute({&x}, {&e}, {}, {sd::DataType::INT64}, {}); + ASSERT_NE(Status::OK(), result); + } catch (std::exception& e) { + nd4j_printf("Error `%s' should be here. It's OK.\n", e.what()); + } } TEST_F(DeclarableOpsTests15, Test_BitCast_4_1) { - auto x = NDArrayFactory::create('c', {1, 2}); - auto e = NDArrayFactory::create('c', {1, 2}, {4607182418800017408LL, 4611686018427387904LL}); // as TF 4607182418800017408, 4611686018427387904 - x.linspace(1.); - sd::ops::bitcast op; - - auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {1, 2}); + auto e = NDArrayFactory::create( + 'c', {1, 2}, + {4607182418800017408LL, + 4611686018427387904LL}); // as TF 4607182418800017408, + // 4611686018427387904 + x.linspace(1.); + sd::ops::bitcast op; - // e.printIndexedBuffer("Double to int64"); - auto res = result.at(0); - ASSERT_EQ(res, e); + auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result.status()); + // e.printIndexedBuffer("Double to int64"); + auto res = result.at(0); + ASSERT_EQ(res, e); } - TEST_F(DeclarableOpsTests15, Test_BitCast_5) { - auto x = NDArrayFactory::create('c', {4, 4}, { - 0.4922f, 0.2969f, 0.6172f, 0.8906f, - 0.9297f, 0.0859f, 0.2344f, 0.3828f, - 0.5781f, 0.7969f, 0.0391f, 0.1719f, - 0.8359f, 0.9297f, 0.3438f, 0.0938f}); - - auto e = NDArrayFactory::create('c', {4}, {4260467851820808160LL, 3900173902914993008LL, 3566895990128523424LL, - 3314989625590692528LL}); + auto x = NDArrayFactory::create( + 'c', {4, 4}, + {0.4922f, 0.2969f, 0.6172f, 0.8906f, 0.9297f, 0.0859f, 0.2344f, 0.3828f, + 0.5781f, 0.7969f, 0.0391f, 0.1719f, 0.8359f, 0.9297f, 0.3438f, 0.0938f}); - sd::ops::bitcast op; - auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto res = result.at(0); + auto e = NDArrayFactory::create( + 'c', {4}, + {4260467851820808160LL, 3900173902914993008LL, 3566895990128523424LL, + 3314989625590692528LL}); -// res->printIndexedBuffer("BITCAST5"); - ASSERT_TRUE(e.equalsTo(res)); + sd::ops::bitcast op; + auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); + // res->printIndexedBuffer("BITCAST5"); + ASSERT_TRUE(e.equalsTo(res)); } TEST_F(DeclarableOpsTests15, Test_BitCast_6) { - auto x = NDArrayFactory::create('c', {4, 4}, { - 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, - 13.f, 14.f, 15.f, 16.f}); - - auto e = NDArrayFactory::create('c', {4}, {4899988963420290048LL, 5188224837230806272LL, 5332342774136064128LL, - 5476460161268730496LL}); + auto x = NDArrayFactory::create( + 'c', {4, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, + 14.f, 15.f, 16.f}); - sd::ops::bitcast op; - auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto res = result.at(0); + auto e = NDArrayFactory::create( + 'c', {4}, + {4899988963420290048LL, 5188224837230806272LL, 5332342774136064128LL, + 5476460161268730496LL}); -// res->printIndexedBuffer("BITCAST6"); - ASSERT_TRUE(e.equalsTo(res)); + sd::ops::bitcast op; + auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); + // res->printIndexedBuffer("BITCAST6"); + ASSERT_TRUE(e.equalsTo(res)); } TEST_F(DeclarableOpsTests15, Test_BitCast_7) { - auto x = NDArrayFactory::create('c', {4, 4}, { - 1.1f, 2.2f, 3.3f, 4.4f, - 5.1f, 6.2f, 7.3f, 8.4f, - 9.1f, 10.2f, 11.3f, 12.4f, - 13.f, 14.2f, 15.3f, 16.4f}); + auto x = NDArrayFactory::create( + 'c', {4, 4}, + {1.1f, 2.2f, 3.3f, 4.4f, 5.1f, 6.2f, 7.3f, 8.4f, 9.1f, 10.2f, 11.3f, + 12.4f, 13.f, 14.2f, 15.3f, 16.4f}); - auto e = NDArrayFactory::create('c', {4}, { - 4928700072476425318LL, 5202580391758873882LL, 5346698272827918477LL, 5483778673873668736LL}); + auto e = NDArrayFactory::create( + 'c', {4}, + {4928700072476425318LL, 5202580391758873882LL, 5346698272827918477LL, + 5483778673873668736LL}); - sd::ops::bitcast op; - auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto res = result.at(0); - -// res->printIndexedBuffer("BITCAST7"); - ASSERT_TRUE(e.equalsTo(res)); + sd::ops::bitcast op; + auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); + // res->printIndexedBuffer("BITCAST7"); + ASSERT_TRUE(e.equalsTo(res)); } TEST_F(DeclarableOpsTests15, test_matmul_bp_1) { - auto a = NDArrayFactory::create('c', {1, 3}); - auto b = NDArrayFactory::create('c', {1, 4}); - auto gI = NDArrayFactory::create('c', {3, 4}); + auto a = NDArrayFactory::create('c', {1, 3}); + auto b = NDArrayFactory::create('c', {1, 4}); + auto gI = NDArrayFactory::create('c', {3, 4}); - auto gA = NDArrayFactory::create('c', {1, 3}); - auto gB = NDArrayFactory::create('c', {1, 4}); + auto gA = NDArrayFactory::create('c', {1, 3}); + auto gB = NDArrayFactory::create('c', {1, 4}); - sd::ops::matmul_bp op; - auto status = op.execute({&a, &b, &gI}, std::vector{&gA, &gB}, {}, {1, 0, 0}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::matmul_bp op; + auto status = op.execute({&a, &b, &gI}, std::vector{&gA, &gB}, {}, + {1, 0, 0}, {}); + ASSERT_EQ(Status::OK(), status); } TEST_F(DeclarableOpsTests15, test_non_decreasing_1) { - auto x = NDArrayFactory::create(1.0); - auto z = NDArrayFactory::create(false); - auto e = NDArrayFactory::create(true); + auto x = NDArrayFactory::create(1.0); + auto z = NDArrayFactory::create(false); + auto e = NDArrayFactory::create(true); - sd::ops::is_non_decreasing op; - Context ctx(1); - ctx.setInputArray(0, x); - ctx.setOutputArray(0, z); + sd::ops::is_non_decreasing op; + Context ctx(1); + ctx.setInputArray(0, x); + ctx.setOutputArray(0, z); - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests15, test_check_numeric_1) { - auto x = NDArrayFactory::create('c', {3},{1.f, 2.f, 3.f}); - auto y = NDArrayFactory::string("shouldn't ever trigger"); + auto x = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto y = NDArrayFactory::string("shouldn't ever trigger"); - sd::ops::check_numerics op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::check_numerics op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(x, z); + ASSERT_EQ(x, z); } TEST_F(DeclarableOpsTests15, test_check_numeric_2) { #ifdef FFAST_MATH - if (1 > 0) - return; + if (1 > 0) return; #endif - auto x = NDArrayFactory::create('c', {3},{1.f, 2.f, std::numeric_limits::infinity()}); - auto y = NDArrayFactory::string("should trigger"); - auto z = NDArrayFactory::create('c', {3} ); + auto x = NDArrayFactory::create( + 'c', {3}, {1.f, 2.f, std::numeric_limits::infinity()}); + auto y = NDArrayFactory::string("should trigger"); + auto z = NDArrayFactory::create('c', {3}); - sd::ops::check_numerics op; - try { - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_TRUE(false); - } catch (std::invalid_argument &e) { - // - } + sd::ops::check_numerics op; + try { + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + ASSERT_TRUE(false); + } catch (std::invalid_argument& e) { + // + } } TEST_F(DeclarableOpsTests15, test_check_numeric_3) { #ifdef FFAST_MATH - if (1 > 0) - return; + if (1 > 0) return; #endif - auto x = NDArrayFactory::create('c', {3},{1.f, 2.f, std::numeric_limits::quiet_NaN()}); - auto y = NDArrayFactory::string("should trigger"); - auto z = NDArrayFactory::create('c', {3} ); + auto x = NDArrayFactory::create( + 'c', {3}, {1.f, 2.f, std::numeric_limits::quiet_NaN()}); + auto y = NDArrayFactory::string("should trigger"); + auto z = NDArrayFactory::create('c', {3}); - sd::ops::check_numerics op; - try { - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_TRUE(false); - } catch (std::invalid_argument &e) { - // - } + sd::ops::check_numerics op; + try { + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + ASSERT_TRUE(false); + } catch (std::invalid_argument& e) { + // + } } TEST_F(DeclarableOpsTests15, Test_layer_norm_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f}); - auto g = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); - auto b = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); - - sd::ops::layer_norm op; - auto result = op.evaluate({&x, &g, &b}, {}, {0}, {false}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = + NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto g = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto b = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + sd::ops::layer_norm op; + auto result = op.evaluate({&x, &g, &b}, {}, {0}, {false}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f}); - auto g = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); - auto b = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); - auto eps = NDArrayFactory::create('c', {1, 5}, {0.f, 0.f, 0.f, 0.f, 0.f}); - - sd::ops::layer_norm_bp op; - auto result = op.evaluate({&x, &g, &b, &eps}, {}, {0}, {false}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = + NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto g = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto b = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto eps = + NDArrayFactory::create('c', {1, 5}, {0.f, 0.f, 0.f, 0.f, 0.f}); + sd::ops::layer_norm_bp op; + auto result = op.evaluate({&x, &g, &b, &eps}, {}, {0}, {false}); + ASSERT_EQ(Status::OK(), result.status()); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_2) { + NDArray x('c', {3, 4, 8, 8}, sd::DataType::FLOAT32); + NDArray gain('c', {4}, {-0.1, 0.1, -0.2, 0.2}, sd::DataType::FLOAT32); + NDArray bias('c', {4}, {-0.05, 0.05, -1.05, 1.05}, sd::DataType::FLOAT32); + NDArray gradO('c', {3, 4, 8, 8}, sd::DataType::FLOAT32); - NDArray x('c', {3, 4, 8, 8}, sd::DataType::FLOAT32); - NDArray gain('c', {4}, {-0.1, 0.1, -0.2, 0.2}, sd::DataType::FLOAT32); - NDArray bias('c', {4}, {-0.05, 0.05, -1.05, 1.05}, sd::DataType::FLOAT32); - NDArray gradO('c', {3, 4, 8, 8}, sd::DataType::FLOAT32); - - NDArray gradI('c', {3, 4, 8, 8}, sd::DataType::FLOAT32); - NDArray gradG('c', {4}, sd::DataType::FLOAT32); - NDArray gradB('c', {4}, sd::DataType::FLOAT32); + NDArray gradI('c', {3, 4, 8, 8}, sd::DataType::FLOAT32); + NDArray gradG('c', {4}, sd::DataType::FLOAT32); + NDArray gradB('c', {4}, sd::DataType::FLOAT32); - x.linspace(-20, 0.5); - gradO.linspace(-4, 0.05); + x.linspace(-20, 0.5); + gradO.linspace(-4, 0.05); - sd::ops::layer_norm_bp op; - auto status = op.execute({&x, &gain, &bias, &gradO}, {&gradI, &gradG, &gradB}, {}, {1,2,3}, {true}); - ASSERT_EQ(Status::OK(), status); + sd::ops::layer_norm_bp op; + auto status = op.execute({&x, &gain, &bias, &gradO}, {&gradI, &gradG, &gradB}, + {}, {1, 2, 3}, {true}); + ASSERT_EQ(Status::OK(), status); } TEST_F(DeclarableOpsTests15, test_hashCode_1) { - auto x = NDArrayFactory::create('c', {10}); - auto y = NDArrayFactory::create('c', {10}); + auto x = NDArrayFactory::create('c', {10}); + auto y = NDArrayFactory::create('c', {10}); - x.linspace(1.); - y.linspace(2.); + x.linspace(1.); + y.linspace(2.); - sd::ops::hashcode op; - auto resultA0 = op.evaluate({&x}); - auto resultA1 = op.evaluate({&x}); - auto resultB0 = op.evaluate({&y}); -// resultA0->at(0)->printIndexedBuffer("A0"); -// resultA1->at(0)->printIndexedBuffer("A1"); -// resultB0->at(0)->printIndexedBuffer("B0"); - ASSERT_EQ(resultA0.at(0), resultA1.at(0)); - ASSERT_NE(resultA0.at(0), resultB0.at(0)); + sd::ops::hashcode op; + auto resultA0 = op.evaluate({&x}); + auto resultA1 = op.evaluate({&x}); + auto resultB0 = op.evaluate({&y}); + // resultA0->at(0)->printIndexedBuffer("A0"); + // resultA1->at(0)->printIndexedBuffer("A1"); + // resultB0->at(0)->printIndexedBuffer("B0"); + ASSERT_EQ(resultA0.at(0), resultA1.at(0)); + ASSERT_NE(resultA0.at(0), resultB0.at(0)); } TEST_F(DeclarableOpsTests15, test_hashCode_2) { - auto x = NDArrayFactory::create('c', {1027}); - auto y = NDArrayFactory::create('c', {1027}); + auto x = NDArrayFactory::create('c', {1027}); + auto y = NDArrayFactory::create('c', {1027}); - x.linspace(1.); - y.linspace(2.); + x.linspace(1.); + y.linspace(2.); - sd::ops::hashcode op; - auto resultA0 = op.evaluate({&x}); - auto resultA1 = op.evaluate({&x}); - auto resultB0 = op.evaluate({&y}); + sd::ops::hashcode op; + auto resultA0 = op.evaluate({&x}); + auto resultA1 = op.evaluate({&x}); + auto resultB0 = op.evaluate({&y}); -// resultA0->at(0)->printIndexedBuffer("A0"); -// resultA1->at(0)->printIndexedBuffer("A1"); -// resultB0->at(0)->printIndexedBuffer("B0"); + // resultA0->at(0)->printIndexedBuffer("A0"); + // resultA1->at(0)->printIndexedBuffer("A1"); + // resultB0->at(0)->printIndexedBuffer("B0"); - ASSERT_EQ(resultA0.at(0), resultA1.at(0)); - ASSERT_NE(resultA0.at(0), resultB0.at(0)); + ASSERT_EQ(resultA0.at(0), resultA1.at(0)); + ASSERT_NE(resultA0.at(0), resultB0.at(0)); } TEST_F(DeclarableOpsTests15, test_rank_1) { - auto array = NDArrayFactory::create('c', {4, 64}); - auto e = NDArrayFactory::create('c', {}, {2}); - auto z = NDArrayFactory::create('c', {}); + auto array = NDArrayFactory::create('c', {4, 64}); + auto e = NDArrayFactory::create('c', {}, {2}); + auto z = NDArrayFactory::create('c', {}); - sd::ops::rank op; - auto result = op.execute({&array}, {&z}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); - ASSERT_EQ(e, z); + sd::ops::rank op; + auto result = op.execute({&array}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests15, test_rank_2) { - auto array = NDArrayFactory::create('c', {4, 64}); - auto e = NDArrayFactory::create('c', {}, {2}); + auto array = NDArrayFactory::create('c', {4, 64}); + auto e = NDArrayFactory::create('c', {}, {2}); - sd::ops::rank op; - auto result = op.evaluate({&array}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_EQ(e, z); + sd::ops::rank op; + auto result = op.evaluate({&array}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests15, test_lstmBlock_1) { - auto x0 = NDArrayFactory::create(5); - auto x1 = NDArrayFactory::create('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f}); - auto x2 = NDArrayFactory::create('c', {1, 3}, {0.7717289f, 0.9280778f, 0.98455656f}); - auto x3 = NDArrayFactory::create('c', {1, 3}, {0.94414854f, 0.5956861f, 0.8668989f}); - auto x4 = NDArrayFactory::create('c', {7, 12}, {0.460692f, 0.042572856f, 0.08420354f, -0.09538093f, -0.11416581f, -0.53166187f, 0.40133476f, -0.24381405f, 0.30778718f, 0.52713746f, 0.16253126f, -0.034891903f, 0.011679292f, -0.19076681f, 0.14710993f, -0.3704369f, 0.51872355f, 0.13536876f, -0.5568739f, -0.08727971f, 0.07601875f, -0.074174374f, -0.5345982f, -0.3581748f, -0.28263924f, -0.25141674f, 0.43328637f, -0.50227314f, -0.26641843f, -0.38241976f, -0.19636461f, -0.04020852f, -0.27312332f, 0.5207915f, -0.37247592f, -0.4713087f, -0.25670746f, -0.14942765f, -0.015806139f, -0.22531253f, 0.5582536f, 0.3093416f, 0.3221351f, -0.0964683f, 0.14318448f, 0.42279094f, -0.46992f, -0.43399644f, -0.51704615f, -0.11854091f, 0.21697259f, -0.049382925f, 0.14059627f, 0.3912331f, -0.41345632f, 0.5067368f, -0.3420229f, 0.485789f, 0.044918716f, 0.26209074f, 0.12357575f, 0.21778125f, -0.53791714f, 0.18346387f, 0.054183125f, 0.5480431f, 0.03675288f, -0.26656917f, -0.018610716f, 0.19917983f, 0.5566165f, 0.43570566f, -0.35720813f, 0.31097364f, -0.47134516f, -0.289197f, 0.091138184f, 0.13300979f, -0.36592877f, -0.17540845f, 0.21732038f, 0.4393713f, 0.42800313f, 0.5006979f}); - auto x5 = NDArrayFactory::create('c', {1, 3}); - auto x6 = NDArrayFactory::create('c', {1, 3}); - auto x7 = NDArrayFactory::create('c', {1, 3}); - auto x8 = NDArrayFactory::create('c', {12}); - - sd::ops::lstmBlock op; - auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {2.0, 0.3}, {0, 0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - // z->printIndexedBuffer("Z"); + auto x0 = NDArrayFactory::create(5); + auto x1 = NDArrayFactory::create( + 'c', {5, 1, 4}, + {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, + 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, + 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, + 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f}); + auto x2 = NDArrayFactory::create( + 'c', {1, 3}, {0.7717289f, 0.9280778f, 0.98455656f}); + auto x3 = NDArrayFactory::create( + 'c', {1, 3}, {0.94414854f, 0.5956861f, 0.8668989f}); + auto x4 = NDArrayFactory::create( + 'c', {7, 12}, + {0.460692f, 0.042572856f, 0.08420354f, -0.09538093f, -0.11416581f, + -0.53166187f, 0.40133476f, -0.24381405f, 0.30778718f, 0.52713746f, + 0.16253126f, -0.034891903f, 0.011679292f, -0.19076681f, 0.14710993f, + -0.3704369f, 0.51872355f, 0.13536876f, -0.5568739f, -0.08727971f, + 0.07601875f, -0.074174374f, -0.5345982f, -0.3581748f, -0.28263924f, + -0.25141674f, 0.43328637f, -0.50227314f, -0.26641843f, -0.38241976f, + -0.19636461f, -0.04020852f, -0.27312332f, 0.5207915f, -0.37247592f, + -0.4713087f, -0.25670746f, -0.14942765f, -0.015806139f, -0.22531253f, + 0.5582536f, 0.3093416f, 0.3221351f, -0.0964683f, 0.14318448f, + 0.42279094f, -0.46992f, -0.43399644f, -0.51704615f, -0.11854091f, + 0.21697259f, -0.049382925f, 0.14059627f, 0.3912331f, -0.41345632f, + 0.5067368f, -0.3420229f, 0.485789f, 0.044918716f, 0.26209074f, + 0.12357575f, 0.21778125f, -0.53791714f, 0.18346387f, 0.054183125f, + 0.5480431f, 0.03675288f, -0.26656917f, -0.018610716f, 0.19917983f, + 0.5566165f, 0.43570566f, -0.35720813f, 0.31097364f, -0.47134516f, + -0.289197f, 0.091138184f, 0.13300979f, -0.36592877f, -0.17540845f, + 0.21732038f, 0.4393713f, 0.42800313f, 0.5006979f}); + auto x5 = NDArrayFactory::create('c', {1, 3}); + auto x6 = NDArrayFactory::create('c', {1, 3}); + auto x7 = NDArrayFactory::create('c', {1, 3}); + auto x8 = NDArrayFactory::create('c', {12}); + + sd::ops::lstmBlock op; + auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, + {2.0, 0.3}, {0, 0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + // z->printIndexedBuffer("Z"); } TEST_F(DeclarableOpsTests15, test_lstmBlock_2) { - int seqLen = 8; - int bS = 16; - int nIn = 8; - - auto x0 = NDArrayFactory::create(5); - auto x1 = NDArrayFactory::create('f', {bS, nIn, seqLen}); - auto x2 = NDArrayFactory::create('f', {bS, nIn}); // nIn == nOut - auto x3 = NDArrayFactory::create('f', {bS, nIn}); - auto x4 = NDArrayFactory::create('f', {2 * nIn, 4 * nIn}); - auto x5 = NDArrayFactory::create('f', {nIn}); - auto x6 = NDArrayFactory::create('f', {nIn}); - auto x7 = NDArrayFactory::create('f', {nIn}); - auto x8 = NDArrayFactory::create('f', {4 * nIn}); + int seqLen = 8; + int bS = 16; + int nIn = 8; - sd::ops::lstmBlock op; - auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {1.0, 0.0}, {0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + auto x0 = NDArrayFactory::create(5); + auto x1 = NDArrayFactory::create('f', {bS, nIn, seqLen}); + auto x2 = NDArrayFactory::create('f', {bS, nIn}); // nIn == nOut + auto x3 = NDArrayFactory::create('f', {bS, nIn}); + auto x4 = NDArrayFactory::create('f', {2 * nIn, 4 * nIn}); + auto x5 = NDArrayFactory::create('f', {nIn}); + auto x6 = NDArrayFactory::create('f', {nIn}); + auto x7 = NDArrayFactory::create('f', {nIn}); + auto x8 = NDArrayFactory::create('f', {4 * nIn}); - auto z = result.at(0); + sd::ops::lstmBlock op; + auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, + {1.0, 0.0}, {0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); } TEST_F(DeclarableOpsTests15, test_lstmBlock_3) { + int seqLen = 3; + int bS = 2; + int nIn = 4; - int seqLen = 3; - int bS = 2; - int nIn = 4; + NDArray f('f', {bS, nIn, seqLen}, sd::DataType::FLOAT32); + NDArray cLast('f', {bS, nIn}, sd::DataType::FLOAT32); - NDArray f('f', {bS, nIn, seqLen}, sd::DataType::FLOAT32); - NDArray cLast('f', {bS, nIn}, sd::DataType::FLOAT32); + f = 2; + cLast = 3; - f = 2; - cLast = 3; + for (int t = 0; t < seqLen; ++t) { + // section 1 + // auto ft = f({0,0, 0,0, t,t+1}); + // auto temp = ft * cLast; - for (int t = 0; t < seqLen; ++t) { - - //section 1 - //auto ft = f({0,0, 0,0, t,t+1}); - //auto temp = ft * cLast; - - - // section 2 - auto ft = f({0,0, 0,0, t,t+1}); - auto temp1 = ft.reshape('f', {bS, nIn}); - auto temp2 = temp1 * cLast; - } + // section 2 + auto ft = f({0, 0, 0, 0, t, t + 1}); + auto temp1 = ft.reshape('f', {bS, nIn}); + auto temp2 = temp1 * cLast; + } } TEST_F(DeclarableOpsTests15, test_empty_increasing_1) { - auto x = NDArrayFactory::create('c', {1, 0, 3}); - auto z = NDArrayFactory::create(false); + auto x = NDArrayFactory::create('c', {1, 0, 3}); + auto z = NDArrayFactory::create(false); - Context ctx(1); - ctx.setInputArray(0, x); - ctx.setOutputArray(0, z); + Context ctx(1); + ctx.setInputArray(0, x); + ctx.setOutputArray(0, z); - sd::ops::is_strictly_increasing op; - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + sd::ops::is_strictly_increasing op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(true, z.e(0)); + ASSERT_EQ(true, z.e(0)); } TEST_F(DeclarableOpsTests15, test_empty_decreasing_1) { - auto x = NDArrayFactory::create('c', {1, 0, 3}); - auto z = NDArrayFactory::create(false); + auto x = NDArrayFactory::create('c', {1, 0, 3}); + auto z = NDArrayFactory::create(false); - Context ctx(1); - ctx.setInputArray(0, x); - ctx.setOutputArray(0, z); + Context ctx(1); + ctx.setInputArray(0, x); + ctx.setOutputArray(0, z); - sd::ops::is_non_decreasing op; - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + sd::ops::is_non_decreasing op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(true, z.e(0)); + ASSERT_EQ(true, z.e(0)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_1) { - // rank 1 - NDArray rgbs('c', { 3 }, { 10, 50, 200 }, sd::DataType::INT32); - NDArray expected('c', { 1 }, std::vector{ 55 }, sd::DataType::INT32); - sd::ops::rgb_to_grs op; - auto result = op.evaluate({&rgbs}, {}, {}); - auto output = result.at(0); + // rank 1 + NDArray rgbs('c', {3}, {10, 50, 200}, sd::DataType::INT32); + NDArray expected('c', {1}, std::vector{55}, sd::DataType::INT32); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_2) { - // rank 1 - auto rgbs = NDArrayFactory::create('f', { 3 }, { 1, 120, -25 }); - auto expected = NDArrayFactory::create('f', { 1 }, { 67 }); - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); + // rank 1 + auto rgbs = NDArrayFactory::create('f', {3}, {1, 120, -25}); + auto expected = NDArrayFactory::create('f', {1}, {67}); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_3) { - // rank 2 - NDArray rgbs('c', { 4, 3 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, sd::DataType::INT32); - NDArray expected('c', { 4, 1 }, { 41, 105, 101, 101 }, sd::DataType::INT32); - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); + // rank 2 + NDArray rgbs('c', {4, 3}, + {-94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102}, + sd::DataType::INT32); + NDArray expected('c', {4, 1}, {41, 105, 101, 101}, sd::DataType::INT32); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_4) { + NDArray rgbs('c', {3, 2}, {14, 99, 207, 10, 114, 201}, sd::DataType::INT32); - NDArray rgbs('c', { 3, 2 }, {14, 99, 207, 10, 114, 201 }, sd::DataType::INT32); + rgbs.permutei({1, 0}); + NDArray expected('c', {2, 1}, {138, 58}, sd::DataType::INT32); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); - rgbs.permutei({1,0}); - NDArray expected('c', { 2, 1 }, { 138, 58 }, sd::DataType::INT32); - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_5) { - // rank 2 - NDArray rgbs('c', { 3, 4 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, sd::DataType::INT32); - NDArray expected('c', { 1, 4 }, { 50, 100, 105, 94 }, sd::DataType::INT32); - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {0}); - auto output = result.at(0); + // rank 2 + NDArray rgbs('c', {3, 4}, + {-94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102}, + sd::DataType::INT32); + NDArray expected('c', {1, 4}, {50, 100, 105, 94}, sd::DataType::INT32); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {0}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_6) { - // rank 3 - auto rgbs = NDArrayFactory::create('c', { 5,4,3 }, {1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); - auto expected = NDArrayFactory::create('c', { 5,4,1 }, {-47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f,2.49686432f, -43.59792709f, 9.64180183f, 23.04854202f,40.7946167f, 44.98754883f, -25.19047546f, 20.64586449f,-4.97033119f, 30.0226841f, 30.30688286f, 15.61459541f,43.36166f, 18.22480774f, 13.74833488f, 21.59387016f}); - - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 3 + auto rgbs = NDArrayFactory::create( + 'c', {5, 4, 3}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f, 2.3125e+01f, 1.8145e+00f, + 1.4602e+01f, -4.5859e+00f, 3.9344e+01f, 1.1617e+01f, -8.6562e+01f, + 1.0038e+02f, 6.7938e+01f, 5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f, 6.8562e+01f, -7.4414e+00f, + 3.9656e+01f, 1.1641e+01f, -2.7516e+01f, 6.7562e+01f, 7.8438e+01f, + 5.4883e+00f, 2.9438e+01f, -3.1344e+01f, 6.5125e+01f, 1.2695e+01f, + 4.0531e+01f, -6.1211e+00f, 6.2219e+01f, 4.6812e+01f, 5.2250e+01f, + -1.1414e+01f, 1.5404e-02f, 2.9938e+01f, 5.6719e+00f, -2.0125e+01f, + 2.1531e+01f, 6.2500e+01f, 7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); + auto expected = NDArrayFactory::create( + 'c', {5, 4, 1}, + {-47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f, 2.49686432f, + -43.59792709f, 9.64180183f, 23.04854202f, 40.7946167f, 44.98754883f, + -25.19047546f, 20.64586449f, -4.97033119f, 30.0226841f, 30.30688286f, + 15.61459541f, 43.36166f, 18.22480774f, 13.74833488f, 21.59387016f}); + + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_7) { - // rank 3 - auto rgbs = NDArrayFactory::create('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); - auto expected = NDArrayFactory::create('c', { 5,1,4 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f, -51.545094f,2.234142f, 20.913160f, 8.783220f, 15.955761f, 55.273506f, 36.838833f, -29.751089f, 8.148357f, 13.676106f, 1.097548f, 68.766457f, 38.690712f, 27.176361f, -14.156269f, 7.157052f }); - - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {1}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 3 + auto rgbs = NDArrayFactory::create( + 'c', {5, 3, 4}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f, 2.3125e+01f, 1.8145e+00f, + 1.4602e+01f, -4.5859e+00f, 3.9344e+01f, 1.1617e+01f, -8.6562e+01f, + 1.0038e+02f, 6.7938e+01f, 5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f, 6.8562e+01f, -7.4414e+00f, + 3.9656e+01f, 1.1641e+01f, -2.7516e+01f, 6.7562e+01f, 7.8438e+01f, + 5.4883e+00f, 2.9438e+01f, -3.1344e+01f, 6.5125e+01f, 1.2695e+01f, + 4.0531e+01f, -6.1211e+00f, 6.2219e+01f, 4.6812e+01f, 5.2250e+01f, + -1.1414e+01f, 1.5404e-02f, 2.9938e+01f, 5.6719e+00f, -2.0125e+01f, + 2.1531e+01f, 6.2500e+01f, 7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); + auto expected = NDArrayFactory::create( + 'c', {5, 1, 4}, + {36.626545f, 38.607746f, -40.614971f, 18.233341f, -51.545094f, + 2.234142f, 20.913160f, 8.783220f, 15.955761f, 55.273506f, + 36.838833f, -29.751089f, 8.148357f, 13.676106f, 1.097548f, + 68.766457f, 38.690712f, 27.176361f, -14.156269f, 7.157052f}); + + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {1}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_8) { - // rank 3 - auto rgbs = NDArrayFactory::create('c', { 3,5,4 }, {1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); - try { - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {}); - ASSERT_EQ(Status::THROW(), result.status()); - - } catch (std::exception& e) { - nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); - } + // rank 3 + auto rgbs = NDArrayFactory::create( + 'c', {3, 5, 4}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f, 2.3125e+01f, 1.8145e+00f, + 1.4602e+01f, -4.5859e+00f, 3.9344e+01f, 1.1617e+01f, -8.6562e+01f, + 1.0038e+02f, 6.7938e+01f, 5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f, 6.8562e+01f, -7.4414e+00f, + 3.9656e+01f, 1.1641e+01f, -2.7516e+01f, 6.7562e+01f, 7.8438e+01f, + 5.4883e+00f, 2.9438e+01f, -3.1344e+01f, 6.5125e+01f, 1.2695e+01f, + 4.0531e+01f, -6.1211e+00f, 6.2219e+01f, 4.6812e+01f, 5.2250e+01f, + -1.1414e+01f, 1.5404e-02f, 2.9938e+01f, 5.6719e+00f, -2.0125e+01f, + 2.1531e+01f, 6.2500e+01f, 7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); + try { + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {}); + ASSERT_EQ(Status::THROW(), result.status()); + + } catch (std::exception& e) { + nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); + } } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_grs_9) { - // rank 3 - auto rgbs = NDArrayFactory::create('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f}); - auto expected = NDArrayFactory::create('f', { 2,2,1 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f }); + // rank 3 + auto rgbs = NDArrayFactory::create( + 'f', {2, 2, 3}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f}); + auto expected = NDArrayFactory::create( + 'f', {2, 2, 1}, {36.626545f, 38.607746f, -40.614971f, 18.233341f}); - sd::ops::rgb_to_grs op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_1) { - // rank 1 - NDArray rgbs('f', { 3 }, { 10, 50, 200 }, sd::DataType::FLOAT32); - NDArray expected('f', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, sd::DataType::FLOAT32); - sd::ops::rgb_to_yuv op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); + // rank 1 + NDArray rgbs('f', {3}, {10, 50, 200}, sd::DataType::FLOAT32); + NDArray expected('f', {3}, {55.14, 71.2872001, -39.6005542}, + sd::DataType::FLOAT32); + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_2) { + NDArray rgbs('c', {3, 2}, {14., 99., 207., 10., 114., 201.}, + sd::DataType::FLOAT32); + rgbs.permutei({1, 0}); - NDArray rgbs('c', { 3, 2 }, { 14., 99., 207., 10., 114., 201. }, sd::DataType::FLOAT32); - rgbs.permutei({ 1,0 }); + NDArray expected( + 'c', {2, 3}, + {138.691, -12.150713, -109.38929, 58.385, 70.18241, 35.63085}, + sd::DataType::FLOAT32); + sd::ops::rgb_to_yuv op; - NDArray expected('c', { 2, 3 }, { 138.691, -12.150713, -109.38929, 58.385, 70.18241, 35.63085 }, sd::DataType::FLOAT32); - sd::ops::rgb_to_yuv op; - - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_3) { - // rank 2 - NDArray rgbs('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, sd::DataType::FLOAT32); - NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, sd::DataType::FLOAT32); - - sd::ops::rgb_to_yuv op; - auto result = op.evaluate({ &rgbs }, {}, { 0 }); - auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 2 + NDArray rgbs( + 'c', {3, 4}, + {-9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22}, + sd::DataType::FLOAT32); + NDArray expected( + 'c', {3, 4}, + {-2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, + 0.358612, -6.472839, 4.568039, 5.290639, -0.430992}, + sd::DataType::FLOAT32); + + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({&rgbs}, {}, {0}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_4) { - // rank 3 - NDArray rgbs('c', { 5,4,3 }, { 1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, -7.4414e+00, -2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01 }, sd::DataType::FLOAT32); - NDArray expected('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, - 10.28950082, - 78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, - 18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, - 26.88963173, 47.0880442, - 0.13584441, - 35.60035823, 43.2050762, - 18.47048906, - 31.11782117, 47.642019, - 18.83162118, - 21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, sd::DataType::FLOAT32); - - sd::ops::rgb_to_yuv op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 3 + NDArray rgbs( + 'c', {5, 4, 3}, + {1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, + 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, + 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, + -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, + 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, + 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, + -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, + 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, + 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, + 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, + 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, + -7.4414e+00, -2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01}, + sd::DataType::FLOAT32); + NDArray expected( + 'c', {5, 4, 3}, + {14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, + 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, + -5.21515376, -9.41983935, -20.5835293, 24.61614501, -44.28390394, + 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, + 45.62757638, -11.550021, 36.44083018, -64.71012983, -10.435098, + -10.28950082, -78.74044941, 22.1427147, 19.72198103, 14.40435988, + 10.699559, 9.46744852, -18.5778351, -7.6957283, 39.31166179, + 7.41657542, 7.245035, 28.48336771, -26.88963173, 47.0880442, + -0.13584441, -35.60035823, 43.2050762, -18.47048906, -31.11782117, + 47.642019, -18.83162118, -21.50836396, -33.788558, 22.87507047, + 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, + -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749}, + sd::DataType::FLOAT32); + + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_5) { - // rank 3 - NDArray rgbs('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, sd::DataType::FLOAT32); - NDArray expected('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, - 14.822637, - 2.479566, - 8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,- 9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, - 3.555702,- 3.225931,3.063015, - 36.134724,58.302204, 8.477802, 38.695396,27.181587, - 14.157411,7.157054, 11.714512, 22.148155, 11.580557, - 27.204905,7.120562, 21.992094, 2.406748, - 6.265247, }, sd::DataType::FLOAT32); - - sd::ops::rgb_to_yuv op; - auto result = op.evaluate({ &rgbs }, {}, { 1 }); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + // rank 3 + NDArray rgbs( + 'c', {5, 3, 4}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f, 2.3125e+01f, 1.8145e+00f, + 1.4602e+01f, -4.5859e+00f, 3.9344e+01f, 1.1617e+01f, -8.6562e+01f, + 1.0038e+02f, 6.7938e+01f, 5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f, 6.8562e+01f, -7.4414e+00f, + 3.9656e+01f, 1.1641e+01f, -2.7516e+01f, 6.7562e+01f, 7.8438e+01f, + 5.4883e+00f, 2.9438e+01f, -3.1344e+01f, 6.5125e+01f, 1.2695e+01f, + 4.0531e+01f, -6.1211e+00f, 6.2219e+01f, 4.6812e+01f, 5.2250e+01f, + -1.1414e+01f, 1.5404e-02f, 2.9938e+01f, 5.6719e+00f, -2.0125e+01f, + 2.1531e+01f, 6.2500e+01f, 7.2188e+01f, 9.3750e+00f, -4.8125e+01f}, + sd::DataType::FLOAT32); + NDArray expected( + 'c', {5, 3, 4}, + { + 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, + -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, + -52.255379, -36.527435, -51.546139, 2.234915, 20.914114, + 8.785358, 32.552223, -3.356598, 9.069552, 1.393482, + 36.029255, 4.824605, -9.972263, 11.058715, 15.947105, + 55.283543, 36.845627, -29.750486, 0.887228, 6.534475, + -21.794132, 34.155693, -89.929497, 39.562351, 27.276817, + 31.359871, 8.149521, 13.673355, 1.104303, 68.774300, + 2.236881, 13.216944, -3.555702, -3.225931, 3.063015, + -36.134724, 58.302204, 8.477802, 38.695396, 27.181587, + -14.157411, 7.157054, 11.714512, 22.148155, 11.580557, + -27.204905, 7.120562, 21.992094, 2.406748, -6.265247, + }, + sd::DataType::FLOAT32); + + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({&rgbs}, {}, {1}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_6) { - // rank 3 - NDArray rgbs('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, sd::DataType::FLOAT32); - try { - sd::ops::rgb_to_yuv op; - auto result = op.evaluate({ &rgbs }, {}, {}); - ASSERT_EQ(Status::THROW(), result.status()); + // rank 3 + NDArray rgbs( + 'c', {3, 5, 4}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f, 2.3125e+01f, 1.8145e+00f, + 1.4602e+01f, -4.5859e+00f, 3.9344e+01f, 1.1617e+01f, -8.6562e+01f, + 1.0038e+02f, 6.7938e+01f, 5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f, 6.8562e+01f, -7.4414e+00f, + 3.9656e+01f, 1.1641e+01f, -2.7516e+01f, 6.7562e+01f, 7.8438e+01f, + 5.4883e+00f, 2.9438e+01f, -3.1344e+01f, 6.5125e+01f, 1.2695e+01f, + 4.0531e+01f, -6.1211e+00f, 6.2219e+01f, 4.6812e+01f, 5.2250e+01f, + -1.1414e+01f, 1.5404e-02f, 2.9938e+01f, 5.6719e+00f, -2.0125e+01f, + 2.1531e+01f, 6.2500e+01f, 7.2188e+01f, 9.3750e+00f, -4.8125e+01f}, + sd::DataType::FLOAT32); + try { + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({&rgbs}, {}, {}); + ASSERT_EQ(Status::THROW(), result.status()); - } - catch (std::exception & e) { - nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); - } + } catch (std::exception& e) { + nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); + } } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_7) { - // rank 3 - NDArray rgbs('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, sd::DataType::FLOAT32); - NDArray expected('f', { 2,2,3 }, { 36.628319,38.600643, -40.624989,18.231001, -14.822637,-2.479566, -8.965780, 2.223851, -16.561626,- 96.205162,-52.255379, -36.527435 }, sd::DataType::FLOAT32); - - sd::ops::rgb_to_yuv op; - auto result = op.evaluate({ &rgbs }, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 3 + NDArray rgbs('f', {2, 2, 3}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, + 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, + 6.5078e+00f, 3.3562e+01f, -5.8844e+01f, 2.2750e+01f}, + sd::DataType::FLOAT32); + NDArray expected( + 'f', {2, 2, 3}, + {36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, + -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435}, + sd::DataType::FLOAT32); + + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_1) { - // rank 1 - NDArray yuv('c', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, sd::DataType::FLOAT32); - NDArray expected('c', { 3 }, { 10, 50, 200 }, sd::DataType::FLOAT32); - sd::ops::yuv_to_rgb op; - auto result = op.evaluate({ &yuv }, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + // rank 1 + NDArray yuv('c', {3}, {55.14, 71.2872001, -39.6005542}, + sd::DataType::FLOAT32); + NDArray expected('c', {3}, {10, 50, 200}, sd::DataType::FLOAT32); + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({&yuv}, {}, {}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_2) { - // rank 1 - NDArray yuv('f', { 3 }, { 55.14, 71.2872001, -39.6005542 }, sd::DataType::FLOAT32); - NDArray expected('f', { 3 }, { 10, 50, 200 }, sd::DataType::FLOAT32); - sd::ops::yuv_to_rgb op; - auto result = op.evaluate({ &yuv }, {}, {}); - auto output = result.at(0); + // rank 1 + NDArray yuv('f', {3}, {55.14, 71.2872001, -39.6005542}, + sd::DataType::FLOAT32); + NDArray expected('f', {3}, {10, 50, 200}, sd::DataType::FLOAT32); + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({&yuv}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_3) { - // rank 2 - NDArray expected('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, sd::DataType::FLOAT32); - NDArray yuv('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, sd::DataType::FLOAT32); - - sd::ops::yuv_to_rgb op; - auto result = op.evaluate({ &yuv }, {}, { 0 }); - auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 2 + NDArray expected( + 'c', {3, 4}, + {-9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22}, + sd::DataType::FLOAT32); + NDArray yuv('c', {3, 4}, + {-2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, + -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992}, + sd::DataType::FLOAT32); + + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({&yuv}, {}, {0}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_4) { - // rank 3 - NDArray expected('c', { 5,4,3 }, { 1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, -7.4414e+00, -2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01 }, sd::DataType::FLOAT32); - NDArray yuv('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, -10.28950082, -78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, -18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, -26.88963173, 47.0880442, -0.13584441, -35.60035823, 43.2050762, -18.47048906, -31.11782117, 47.642019, -18.83162118, -21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, sd::DataType::FLOAT32); - - sd::ops::yuv_to_rgb op; - auto result = op.evaluate({ &yuv }, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 3 + NDArray expected( + 'c', {5, 4, 3}, + {1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, + 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, + 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, + -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, + 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, + 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, + -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, + 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, + 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, + 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, + 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, + -7.4414e+00, -2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01}, + sd::DataType::FLOAT32); + NDArray yuv( + 'c', {5, 4, 3}, + {14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, + 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, + -5.21515376, -9.41983935, -20.5835293, 24.61614501, -44.28390394, + 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, + 45.62757638, -11.550021, 36.44083018, -64.71012983, -10.435098, + -10.28950082, -78.74044941, 22.1427147, 19.72198103, 14.40435988, + 10.699559, 9.46744852, -18.5778351, -7.6957283, 39.31166179, + 7.41657542, 7.245035, 28.48336771, -26.88963173, 47.0880442, + -0.13584441, -35.60035823, 43.2050762, -18.47048906, -31.11782117, + 47.642019, -18.83162118, -21.50836396, -33.788558, 22.87507047, + 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, + -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749}, + sd::DataType::FLOAT32); + + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({&yuv}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_5) { - // rank 3 - NDArray expected('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, sd::DataType::FLOAT32); - NDArray yuv('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,-9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, -3.555702,-3.225931,3.063015, -36.134724,58.302204, 8.477802, 38.695396,27.181587, -14.157411,7.157054, 11.714512, 22.148155, 11.580557, -27.204905,7.120562, 21.992094, 2.406748, -6.265247, }, sd::DataType::FLOAT32); - - sd::ops::yuv_to_rgb op; - auto result = op.evaluate({ &yuv }, {}, { 1 }); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + // rank 3 + NDArray expected( + 'c', {5, 3, 4}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f, 2.3125e+01f, 1.8145e+00f, + 1.4602e+01f, -4.5859e+00f, 3.9344e+01f, 1.1617e+01f, -8.6562e+01f, + 1.0038e+02f, 6.7938e+01f, 5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f, 6.8562e+01f, -7.4414e+00f, + 3.9656e+01f, 1.1641e+01f, -2.7516e+01f, 6.7562e+01f, 7.8438e+01f, + 5.4883e+00f, 2.9438e+01f, -3.1344e+01f, 6.5125e+01f, 1.2695e+01f, + 4.0531e+01f, -6.1211e+00f, 6.2219e+01f, 4.6812e+01f, 5.2250e+01f, + -1.1414e+01f, 1.5404e-02f, 2.9938e+01f, 5.6719e+00f, -2.0125e+01f, + 2.1531e+01f, 6.2500e+01f, 7.2188e+01f, 9.3750e+00f, -4.8125e+01f}, + sd::DataType::FLOAT32); + NDArray yuv('c', {5, 3, 4}, + { + 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, + -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, + -52.255379, -36.527435, -51.546139, 2.234915, 20.914114, + 8.785358, 32.552223, -3.356598, 9.069552, 1.393482, + 36.029255, 4.824605, -9.972263, 11.058715, 15.947105, + 55.283543, 36.845627, -29.750486, 0.887228, 6.534475, + -21.794132, 34.155693, -89.929497, 39.562351, 27.276817, + 31.359871, 8.149521, 13.673355, 1.104303, 68.774300, + 2.236881, 13.216944, -3.555702, -3.225931, 3.063015, + -36.134724, 58.302204, 8.477802, 38.695396, 27.181587, + -14.157411, 7.157054, 11.714512, 22.148155, 11.580557, + -27.204905, 7.120562, 21.992094, 2.406748, -6.265247, + }, + sd::DataType::FLOAT32); + + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({&yuv}, {}, {1}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_6) { - // rank 3 - NDArray yuv('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, sd::DataType::FLOAT32); - try { - sd::ops::yuv_to_rgb op; - auto result = op.evaluate({ &yuv }, {}, {}); - ASSERT_EQ(Status::THROW(), result.status()); + // rank 3 + NDArray yuv( + 'c', {3, 5, 4}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f, 3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f, 2.3125e+01f, 1.8145e+00f, + 1.4602e+01f, -4.5859e+00f, 3.9344e+01f, 1.1617e+01f, -8.6562e+01f, + 1.0038e+02f, 6.7938e+01f, 5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f, 6.8562e+01f, -7.4414e+00f, + 3.9656e+01f, 1.1641e+01f, -2.7516e+01f, 6.7562e+01f, 7.8438e+01f, + 5.4883e+00f, 2.9438e+01f, -3.1344e+01f, 6.5125e+01f, 1.2695e+01f, + 4.0531e+01f, -6.1211e+00f, 6.2219e+01f, 4.6812e+01f, 5.2250e+01f, + -1.1414e+01f, 1.5404e-02f, 2.9938e+01f, 5.6719e+00f, -2.0125e+01f, + 2.1531e+01f, 6.2500e+01f, 7.2188e+01f, 9.3750e+00f, -4.8125e+01f}, + sd::DataType::FLOAT32); + try { + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({&yuv}, {}, {}); + ASSERT_EQ(Status::THROW(), result.status()); - } - catch (std::exception & e) { - nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); - } + } catch (std::exception& e) { + nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); + } } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_7) { - // rank 3 - NDArray expected('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, sd::DataType::FLOAT32); - NDArray yuv('f', { 2,2,3 }, { 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435 }, sd::DataType::FLOAT32); - - sd::ops::yuv_to_rgb op; - auto result = op.evaluate({ &yuv }, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // rank 3 + NDArray expected('f', {2, 2, 3}, + {1.7750e+01f, -7.1062e+01f, -1.0019e+02f, -2.3406e+01f, + 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, + 6.5078e+00f, 3.3562e+01f, -5.8844e+01f, 2.2750e+01f}, + sd::DataType::FLOAT32); + NDArray yuv( + 'f', {2, 2, 3}, + {36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, + -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435}, + sd::DataType::FLOAT32); + + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({&yuv}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, Pow_BP_Test1) { + // same shape + NDArray x('c', {2, 2, 2}, {4, 3, 2, 5, 7, 8, -9, -12}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2, 2}, {2, 3, -2, 4, -1, -4, 10, 8}, + sd::DataType::FLOAT32); - // same shape - NDArray x('c', { 2,2,2 }, { 4,3,2,5,7,8,-9,-12 }, sd::DataType::FLOAT32); - NDArray y('c', { 2,2,2 }, { 2,3,-2,4,-1,-4,10,8 }, sd::DataType::FLOAT32); + NDArray dLdz('c', {2, 2, 2}, sd::DataType::FLOAT32); + NDArray dLdxExp( + 'c', {2, 2, 2}, + {8, 27, -0.25, 500, -0.0204082, -0.000122, -3.87420e+09, -2.86654e+08}, + sd::DataType::FLOAT32); + NDArray dLdyExp( + 'c', {2, 2, 2}, + {22.18071, 29.66253, 0.17329, 1005.89874, 0.27799, 0.00051, 0, 0}, + sd::DataType::FLOAT32); + dLdz.assign(1.0); - NDArray dLdz('c', { 2,2,2 }, sd::DataType::FLOAT32); - NDArray dLdxExp('c', { 2,2,2 }, { 8, 27, -0.25, 500, -0.0204082, -0.000122, -3.87420e+09, -2.86654e+08 }, sd::DataType::FLOAT32); - NDArray dLdyExp('c', { 2,2,2 }, { 22.18071, 29.66253, 0.17329, 1005.89874, 0.27799, 0.00051, 0, 0 }, sd::DataType::FLOAT32); + sd::ops::Pow_bp op; + auto results = op.evaluate({&x, &y, &dLdz}, {}, {}); - dLdz.assign(1.0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - sd::ops::Pow_bp op; - auto results = op.evaluate({ &x, &y, &dLdz }, {}, {}); + auto dLdx = results.at(0); + auto dLdy = results.at(1); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto dLdx = results.at(0); - auto dLdy = results.at(1); - - ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); - ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); - ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); - ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test2) { + NDArray x('c', {1, 2, 3}, sd::DataType::FLOAT32); + NDArray y('c', {3, 2, 1}, sd::DataType::FLOAT32); + NDArray dLdz('c', {3, 2, 3}, sd::DataType::FLOAT32); - NDArray x('c', { 1,2,3 }, sd::DataType::FLOAT32); - NDArray y('c', { 3,2,1 }, sd::DataType::FLOAT32); - NDArray dLdz('c', { 3,2,3 }, sd::DataType::FLOAT32); - - NDArray dLdxExp('c', { 1,2,3 }, { 16.8, 19.2, 21.6, 24., 26.4, 28.8 }, sd::DataType::FLOAT32); - NDArray dLdyExp('c', { 3,2,1 }, { 13.30843, 33.27106, 53.2337, 73.19634, 93.15898, 113.12162 }, sd::DataType::FLOAT32); + NDArray dLdxExp('c', {1, 2, 3}, {16.8, 19.2, 21.6, 24., 26.4, 28.8}, + sd::DataType::FLOAT32); + NDArray dLdyExp('c', {3, 2, 1}, + {13.30843, 33.27106, 53.2337, 73.19634, 93.15898, 113.12162}, + sd::DataType::FLOAT32); - x.assign(4.0); - y.assign(2.0); - dLdz.linspace(0.1, 0.1); + x.assign(4.0); + y.assign(2.0); + dLdz.linspace(0.1, 0.1); - sd::ops::Pow_bp op; - auto results = op.evaluate({ &x, &y, &dLdz }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::Pow_bp op; + auto results = op.evaluate({&x, &y, &dLdz}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdx = results.at(0); - auto dLdy = results.at(1); - - ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); - ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); - ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); - ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + auto dLdx = results.at(0); + auto dLdy = results.at(1); + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test3) { + // y - same shape as dLdz + NDArray xY('c', {1, 2, 3}, sd::DataType::FLOAT32); + NDArray yY('c', {3, 2, 3}, sd::DataType::FLOAT32); - // y - same shape as dLdz - NDArray xY('c', { 1,2,3 }, sd::DataType::FLOAT32); - NDArray yY('c', { 3,2,3 }, sd::DataType::FLOAT32); - - NDArray dLdxExpY('c', { 1,2,3 }, { 16.8, 19.2, 21.6, 24. , 26.4, 28.8 }, sd::DataType::FLOAT32); - NDArray dLdyExpY('c', { 3,2,3 }, { 2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843, 15.5265 , 17.74457, 19.96264, 22.18071, 24.39878, 26.61685, 28.83492, 31.05299, 33.27106, 35.48914, 37.70721, 39.92528 }, sd::DataType::FLOAT32); - NDArray dLdz('c', { 3,2,3 }, sd::DataType::FLOAT32); + NDArray dLdxExpY('c', {1, 2, 3}, {16.8, 19.2, 21.6, 24., 26.4, 28.8}, + sd::DataType::FLOAT32); + NDArray dLdyExpY('c', {3, 2, 3}, + {2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843, + 15.5265, 17.74457, 19.96264, 22.18071, 24.39878, 26.61685, + 28.83492, 31.05299, 33.27106, 35.48914, 37.70721, 39.92528}, + sd::DataType::FLOAT32); + NDArray dLdz('c', {3, 2, 3}, sd::DataType::FLOAT32); - xY.assign(4.0); - yY.assign(2.0); - dLdz.linspace(0.1, 0.1); + xY.assign(4.0); + yY.assign(2.0); + dLdz.linspace(0.1, 0.1); - sd::ops::Pow_bp op; - auto resultsY = op.evaluate({ &xY, &yY, &dLdz }, {}, {}); + sd::ops::Pow_bp op; + auto resultsY = op.evaluate({&xY, &yY, &dLdz}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsY.status()); + ASSERT_EQ(ND4J_STATUS_OK, resultsY.status()); - auto dLdxY = resultsY.at(0); - auto dLdyY = resultsY.at(1); + auto dLdxY = resultsY.at(0); + auto dLdyY = resultsY.at(1); - ASSERT_TRUE(dLdxExpY.isSameShape(dLdxY)); - ASSERT_TRUE(dLdxExpY.equalsTo(dLdxY)); - ASSERT_TRUE(dLdyExpY.isSameShape(dLdyY)); - ASSERT_TRUE(dLdyExpY.equalsTo(dLdyY)); + ASSERT_TRUE(dLdxExpY.isSameShape(dLdxY)); + ASSERT_TRUE(dLdxExpY.equalsTo(dLdxY)); + ASSERT_TRUE(dLdyExpY.isSameShape(dLdyY)); + ASSERT_TRUE(dLdyExpY.equalsTo(dLdyY)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test4) { + // x - same shape ad dLdz + NDArray yX('c', {1, 2, 3}, sd::DataType::FLOAT32); + NDArray xX('c', {3, 2, 3}, sd::DataType::FLOAT32); - // x - same shape ad dLdz - NDArray yX('c', { 1,2,3 }, sd::DataType::FLOAT32); - NDArray xX('c', { 3,2,3 }, sd::DataType::FLOAT32); - - NDArray dLdxExpX('c', { 3,2,3 }, { 3.2, 6.4, 9.6, 12.8, 16. , 19.2, 22.4, 25.6, 28.8, 32. , 35.2, 38.4, 41.6, 44.8, 48., 51.2, 54.4, 57.6 }, sd::DataType::FLOAT32); - NDArray dLdyExpX('c', { 1,2,3 }, { 23.28975, 26.61685, 29.94396, 33.27106, 36.59817, 39.92528 }, sd::DataType::FLOAT32); + NDArray dLdxExpX('c', {3, 2, 3}, + {3.2, 6.4, 9.6, 12.8, 16., 19.2, 22.4, 25.6, 28.8, 32., 35.2, + 38.4, 41.6, 44.8, 48., 51.2, 54.4, 57.6}, + sd::DataType::FLOAT32); + NDArray dLdyExpX('c', {1, 2, 3}, + {23.28975, 26.61685, 29.94396, 33.27106, 36.59817, 39.92528}, + sd::DataType::FLOAT32); - NDArray dLdz('c', { 3,2,3 }, sd::DataType::FLOAT32); - dLdz.linspace(0.1, 0.1); + NDArray dLdz('c', {3, 2, 3}, sd::DataType::FLOAT32); + dLdz.linspace(0.1, 0.1); - sd::ops::Pow_bp op; + sd::ops::Pow_bp op; - xX.assign(2.0); - yX.assign(4.0); + xX.assign(2.0); + yX.assign(4.0); - auto resultsX = op.evaluate({ &xX, &yX, &dLdz }, {}, {}); + auto resultsX = op.evaluate({&xX, &yX, &dLdz}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsX.status()); + ASSERT_EQ(ND4J_STATUS_OK, resultsX.status()); - auto dLdxX = resultsX.at(0); - auto dLdyX = resultsX.at(1); + auto dLdxX = resultsX.at(0); + auto dLdyX = resultsX.at(1); - ASSERT_TRUE(dLdxExpX.isSameShape(dLdxX)); - ASSERT_TRUE(dLdxExpX.equalsTo(dLdxX)); - ASSERT_TRUE(dLdyExpX.isSameShape(dLdyX)); - ASSERT_TRUE(dLdyExpX.equalsTo(dLdyX)); + ASSERT_TRUE(dLdxExpX.isSameShape(dLdxX)); + ASSERT_TRUE(dLdxExpX.equalsTo(dLdxX)); + ASSERT_TRUE(dLdyExpX.isSameShape(dLdyX)); + ASSERT_TRUE(dLdyExpX.equalsTo(dLdyX)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test5) { + // both single array + NDArray xConst('c', {1}, sd::DataType::FLOAT32); + NDArray yConst('c', {1}, sd::DataType::FLOAT32); + NDArray dLdz('c', {1}, sd::DataType::FLOAT32); + NDArray dLdxExp('c', {1}, sd::DataType::FLOAT32); + NDArray dLdyExp('c', {1}, sd::DataType::FLOAT32); - // both single array - NDArray xConst('c', { 1 }, sd::DataType::FLOAT32); - NDArray yConst('c', { 1 }, sd::DataType::FLOAT32); - NDArray dLdz('c', { 1 }, sd::DataType::FLOAT32); - NDArray dLdxExp('c', { 1 }, sd::DataType::FLOAT32); - NDArray dLdyExp('c', { 1 }, sd::DataType::FLOAT32); + xConst.assign(3.0); + yConst.assign(4.0); + dLdz.assign(1.0); - xConst.assign(3.0); - yConst.assign(4.0); - dLdz.assign(1.0); + dLdxExp.assign(4.0 * pow(3, 3)); + dLdyExp.assign(pow(3, 4) * log(3)); - dLdxExp.assign(4.0 * pow(3, 3)); - dLdyExp.assign(pow(3, 4) * log(3)); + sd::ops::Pow_bp op; + auto results = op.evaluate({&xConst, &yConst, &dLdz}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - sd::ops::Pow_bp op; - auto results = op.evaluate({ &xConst, &yConst, &dLdz }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto dLdx = results.at(0); + auto dLdy = results.at(1); - auto dLdx = results.at(0); - auto dLdy = results.at(1); + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); - ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); - ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); - - ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); - ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test6) { + // x single array + NDArray xConst('c', {1}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2, 2}, sd::DataType::FLOAT32); + NDArray dLdzC('c', {2, 2, 2}, sd::DataType::FLOAT32); - // x single array - NDArray xConst('c', { 1 }, sd::DataType::FLOAT32); - NDArray y('c', { 2, 2, 2 }, sd::DataType::FLOAT32); - NDArray dLdzC('c', { 2, 2, 2 }, sd::DataType::FLOAT32); - - xConst.assign(2.0); - y.assign(4.0); - dLdzC.linspace(0.1, 0.1); - - NDArray dLdxExpXC('c', { 1 }, std::vector{ 115.2 }, sd::DataType::FLOAT32); - NDArray dLdyExpXC('c', { 2, 2, 2 }, { 1.10904, 2.21807, 3.32711, 4.43614, 5.54518, 6.65421, 7.76325, 8.87228 }, sd::DataType::FLOAT32); + xConst.assign(2.0); + y.assign(4.0); + dLdzC.linspace(0.1, 0.1); - sd::ops::Pow_bp op; - auto resultsXC = op.evaluate({ &xConst, &y, &dLdzC }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsXC.status()); + NDArray dLdxExpXC('c', {1}, std::vector{115.2}, + sd::DataType::FLOAT32); + NDArray dLdyExpXC( + 'c', {2, 2, 2}, + {1.10904, 2.21807, 3.32711, 4.43614, 5.54518, 6.65421, 7.76325, 8.87228}, + sd::DataType::FLOAT32); - auto dLdxXC = resultsXC.at(0); - auto dLdyXC = resultsXC.at(1); + sd::ops::Pow_bp op; + auto resultsXC = op.evaluate({&xConst, &y, &dLdzC}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsXC.status()); - ASSERT_TRUE(dLdxExpXC.isSameShape(dLdxXC)); - ASSERT_TRUE(dLdxExpXC.equalsTo(dLdxXC)); - ASSERT_TRUE(dLdyExpXC.isSameShape(dLdyXC)); - ASSERT_TRUE(dLdyExpXC.equalsTo(dLdyXC)); + auto dLdxXC = resultsXC.at(0); + auto dLdyXC = resultsXC.at(1); + ASSERT_TRUE(dLdxExpXC.isSameShape(dLdxXC)); + ASSERT_TRUE(dLdxExpXC.equalsTo(dLdxXC)); + ASSERT_TRUE(dLdyExpXC.isSameShape(dLdyXC)); + ASSERT_TRUE(dLdyExpXC.equalsTo(dLdyXC)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test7) { + // Y - scalar + auto Y = NDArrayFactory::create(2.f); + NDArray x('c', {2, 2, 2}, sd::DataType::FLOAT32); + NDArray dLdzC('c', {2, 2, 2}, sd::DataType::FLOAT32); - // Y - scalar - auto Y = NDArrayFactory::create(2.f); - NDArray x('c', { 2, 2, 2 }, sd::DataType::FLOAT32); - NDArray dLdzC('c', { 2, 2, 2 }, sd::DataType::FLOAT32); + dLdzC.linspace(0.1, 0.1); + x = 4.f; - dLdzC.linspace(0.1, 0.1); - x = 4.f; + NDArray dLdxExpYs('c', {2, 2, 2}, {0.8, 1.6, 2.4, 3.2, 4., 4.8, 5.6, 6.4}, + sd::DataType::FLOAT32); - NDArray dLdxExpYs('c', { 2, 2, 2 }, { 0.8, 1.6, 2.4, 3.2, 4., 4.8, 5.6, 6.4 }, sd::DataType::FLOAT32); + auto dLdyExpYs = NDArrayFactory::create(79.85056f); - auto dLdyExpYs = NDArrayFactory::create(79.85056f); + sd::ops::Pow_bp op; + auto resultsYs = op.evaluate({&x, &Y, &dLdzC}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsYs.status()); - sd::ops::Pow_bp op; - auto resultsYs = op.evaluate({ &x, &Y, &dLdzC }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsYs.status()); + auto dLdxY = resultsYs.at(0); + auto dLdyY = resultsYs.at(1); - auto dLdxY = resultsYs.at(0); - auto dLdyY = resultsYs.at(1); - - ASSERT_TRUE(dLdxExpYs.isSameShape(dLdxY)); - ASSERT_TRUE(dLdxExpYs.equalsTo(dLdxY)); - ASSERT_TRUE(dLdyExpYs.isSameShape(dLdyY)); - ASSERT_TRUE(dLdyExpYs.equalsTo(dLdyY)); + ASSERT_TRUE(dLdxExpYs.isSameShape(dLdxY)); + ASSERT_TRUE(dLdxExpYs.equalsTo(dLdxY)); + ASSERT_TRUE(dLdyExpYs.isSameShape(dLdyY)); + ASSERT_TRUE(dLdyExpYs.equalsTo(dLdyY)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test8) { - // both scalars - - auto X = NDArrayFactory::create(4.f); - auto Y = NDArrayFactory::create(2.f); - NDArray dLdz = NDArrayFactory::create(0.1f); + // both scalars - NDArray dLdxExp = NDArrayFactory::create(2.f*4.f*0.1f); + auto X = NDArrayFactory::create(4.f); + auto Y = NDArrayFactory::create(2.f); + NDArray dLdz = NDArrayFactory::create(0.1f); - NDArray dLdyExp = NDArrayFactory::create(pow(4.f, 2.f) * log(4.f) * 0.1f); + NDArray dLdxExp = NDArrayFactory::create(2.f * 4.f * 0.1f); - sd::ops::Pow_bp op; - auto results = op.evaluate({ &X, &Y, &dLdz }, {}, {}); + NDArray dLdyExp = + NDArrayFactory::create(pow(4.f, 2.f) * log(4.f) * 0.1f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::Pow_bp op; + auto results = op.evaluate({&X, &Y, &dLdz}, {}, {}); - auto dLdx = results.at(0); - auto dLdy = results.at(1); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); - ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); - ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); - ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + auto dLdx = results.at(0); + auto dLdy = results.at(1); + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test9) { + sd::ops::Pow_bp op; + // diff shapes + NDArray x('c', {3, 2, 1}, sd::DataType::FLOAT32); + NDArray y('c', {1, 2, 3}, sd::DataType::FLOAT32); + NDArray dLdz('c', {3, 2, 3}, sd::DataType::FLOAT32); - sd::ops::Pow_bp op; - // diff shapes - NDArray x('c', { 3,2,1 }, sd::DataType::FLOAT32); - NDArray y('c', { 1,2,3 }, sd::DataType::FLOAT32); - NDArray dLdz('c', { 3,2,3 }, sd::DataType::FLOAT32); - - NDArray dLdxExp('c', { 3,2,1 }, { 4.8, 12., 19.2, 26.4, 33.6, 40.8 }, sd::DataType::FLOAT32); - NDArray dLdyExp('c', { 1,2,3 }, { 46.57949, 53.2337 , 59.88792, 66.54213, 73.19634, 79.85056 }, sd::DataType::FLOAT32); + NDArray dLdxExp('c', {3, 2, 1}, {4.8, 12., 19.2, 26.4, 33.6, 40.8}, + sd::DataType::FLOAT32); + NDArray dLdyExp('c', {1, 2, 3}, + {46.57949, 53.2337, 59.88792, 66.54213, 73.19634, 79.85056}, + sd::DataType::FLOAT32); - x.assign(4.0); - y.assign(2.0); - dLdz.linspace(0.1, 0.1); + x.assign(4.0); + y.assign(2.0); + dLdz.linspace(0.1, 0.1); - auto results = op.evaluate({ &x, &y, &dLdz }, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto results = op.evaluate({&x, &y, &dLdz}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto dLdx = results.at(0); - auto dLdy = results.at(1); + auto dLdx = results.at(0); + auto dLdy = results.at(1); - ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); - ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); - ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); - ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test10) { + // diff shapes broadcastable + NDArray yB('c', {1, 2, 3, 1}, sd::DataType::FLOAT32); + NDArray xB('c', {2, 3, 1}, sd::DataType::FLOAT32); - // diff shapes broadcastable - NDArray yB('c', { 1,2,3,1 }, sd::DataType::FLOAT32); - NDArray xB('c', { 2,3,1 }, sd::DataType::FLOAT32); + NDArray dLdyExpB('c', {1, 2, 3, 1}, + {2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843}, + sd::DataType::FLOAT32); + NDArray dLdxExpB('c', {2, 3, 1}, {0.8, 1.6, 2.4, 3.2, 4., 4.8}, + sd::DataType::FLOAT32); + NDArray dLdzB('c', {1, 2, 3, 1}, sd::DataType::FLOAT32); - NDArray dLdyExpB('c', { 1,2,3,1 }, { 2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843 }, sd::DataType::FLOAT32); - NDArray dLdxExpB('c', { 2,3,1 }, { 0.8, 1.6, 2.4, 3.2, 4., 4.8 }, sd::DataType::FLOAT32); - NDArray dLdzB('c', { 1,2,3,1 }, sd::DataType::FLOAT32); + dLdzB.linspace(0.1, 0.1); + xB.assign(4.0); + yB.assign(2.0); - dLdzB.linspace(0.1, 0.1); - xB.assign(4.0); - yB.assign(2.0); + sd::ops::Pow_bp op; + auto resultsB = op.evaluate({&xB, &yB, &dLdzB}, {}, {}); - sd::ops::Pow_bp op; - auto resultsB = op.evaluate({ &xB, &yB, &dLdzB }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsB.status()); - ASSERT_EQ(ND4J_STATUS_OK, resultsB.status()); + auto dLdxB = resultsB.at(0); + auto dLdyB = resultsB.at(1); - auto dLdxB = resultsB.at(0); - auto dLdyB = resultsB.at(1); + ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); + ASSERT_TRUE(dLdxExpB.equalsTo(dLdxB)); - ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); - ASSERT_TRUE(dLdxExpB.equalsTo(dLdxB)); - - ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB)); - ASSERT_TRUE(dLdyExpB.equalsTo(dLdyB)); + ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB)); + ASSERT_TRUE(dLdyExpB.equalsTo(dLdyB)); } TEST_F(DeclarableOpsTests15, Pow_BP_Test11) { #ifdef FFAST_MATH - if (1 > 0) - return; + if (1 > 0) return; #endif - NDArray xB('c', { 3,2,1 }, { .4, 3, 5, .8, -9, -12 }, sd::DataType::FLOAT32); - NDArray yB('c', { 1,2,3 }, { 3, -2, .4, -4, 10, .8 }, sd::DataType::FLOAT32); - - NDArray dLdxExpB('c', { 3,2,1 }, { -5.994056, 39366.191406, 7.508829, -2.223537, -std::numeric_limits::quiet_NaN(), -std::numeric_limits::quiet_NaN() }, sd::DataType::FLOAT32); - NDArray dLdyExpB('c', { 1,2,3 }, { 20.11211, -1.119612, -std::numeric_limits::quiet_NaN(), -0.1076, 12974.389648, -std::numeric_limits::quiet_NaN() }, sd::DataType::FLOAT32); - - NDArray dLdzB('c', { 3,2,3 }, { .1,.2,.3, .1,.2,.3, .1,.4,.1, .2,.1,.1, .3,.1,.5, .1, .7, .1 }, sd::DataType::FLOAT32); - - sd::ops::Pow_bp op; - auto resultsB = op.evaluate({ &xB, &yB, &dLdzB }, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, resultsB.status()); - auto dLdxB = resultsB.at(0); - auto dLdyB = resultsB.at(1); - - ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); - for (int i = 0; i < dLdxB.lengthOf(); ++i) { - if (!sd::math::nd4j_isnan(dLdxB.e(i)) && !sd::math::nd4j_isnan(dLdxExpB.e(i))) - ASSERT_NEAR(dLdxB.e(i), dLdxExpB.e(i), 0.00001); - } - - ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB)); - for (int i = 0; i < dLdyB.lengthOf(); ++i) { - if (!sd::math::nd4j_isnan(dLdyB.e(i)) && !sd::math::nd4j_isnan(dLdyExpB.e(i))) - ASSERT_NEAR(dLdyB.e(i), dLdyExpB.e(i), 0.00001); - } - - + NDArray xB('c', {3, 2, 1}, {.4, 3, 5, .8, -9, -12}, sd::DataType::FLOAT32); + NDArray yB('c', {1, 2, 3}, {3, -2, .4, -4, 10, .8}, sd::DataType::FLOAT32); + + NDArray dLdxExpB('c', {3, 2, 1}, + {-5.994056, 39366.191406, 7.508829, -2.223537, + -std::numeric_limits::quiet_NaN(), + -std::numeric_limits::quiet_NaN()}, + sd::DataType::FLOAT32); + NDArray dLdyExpB( + 'c', {1, 2, 3}, + {20.11211, -1.119612, -std::numeric_limits::quiet_NaN(), -0.1076, + 12974.389648, -std::numeric_limits::quiet_NaN()}, + sd::DataType::FLOAT32); + + NDArray dLdzB( + 'c', {3, 2, 3}, + {.1, .2, .3, .1, .2, .3, .1, .4, .1, .2, .1, .1, .3, .1, .5, .1, .7, .1}, + sd::DataType::FLOAT32); + + sd::ops::Pow_bp op; + auto resultsB = op.evaluate({&xB, &yB, &dLdzB}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsB.status()); + auto dLdxB = resultsB.at(0); + auto dLdyB = resultsB.at(1); + + ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); + for (int i = 0; i < dLdxB.lengthOf(); ++i) { + if (!sd::math::nd4j_isnan(dLdxB.e(i)) && + !sd::math::nd4j_isnan(dLdxExpB.e(i))) + ASSERT_NEAR(dLdxB.e(i), dLdxExpB.e(i), 0.00001); + } + + ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB)); + for (int i = 0; i < dLdyB.lengthOf(); ++i) { + if (!sd::math::nd4j_isnan(dLdyB.e(i)) && + !sd::math::nd4j_isnan(dLdyExpB.e(i))) + ASSERT_NEAR(dLdyB.e(i), dLdyExpB.e(i), 0.00001); + } } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) { + NDArray A('c', {1, 2, 3}, {2.1, 2.2, 2.3, 2.4, 2.5, 2.6}, + sd::DataType::FLOAT32); + NDArray B('c', {1, 2, 4}, {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 4}, {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.1}, + sd::DataType::FLOAT32); - NDArray A('c', { 1, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6 }, sd::DataType::FLOAT32); - NDArray B('c', { 1, 2, 4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.1 }, sd::DataType::FLOAT32); + NDArray dLdA('c', {1, 2, 3}, {3.3, 8.5, 13.36, 3.7, 9.54, 15.}, + sd::DataType::FLOAT32); + NDArray dLdB('c', {1, 2, 4}, {3.38, 4.04, 4.7, 5.13, 3.83, 4.58, 5.33, 5.82}, + sd::DataType::FLOAT32); - NDArray dLdA('c', { 1, 2, 3 }, { 3.3, 8.5, 13.36, 3.7, 9.54, 15. }, sd::DataType::FLOAT32); - NDArray dLdB('c', { 1, 2, 4 }, { 3.38, 4.04, 4.7, 5.13, 3.83, 4.58, 5.33, 5.82 }, sd::DataType::FLOAT32); + sd::ops::tensormmul_bp op_bp; - sd::ops::tensormmul_bp op_bp; + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 0, 1, 2, 0, 1}, {}); - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - auto dLdAbp = resultsBP.at(0); - auto dLdBbp = resultsBP.at(1); - - ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); - ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); - - ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); - ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); + ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); + ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP2) { + NDArray A('c', {1, 2, 3}, {2, 2, 2, 2, 2, 2}, sd::DataType::FLOAT32); + NDArray B('c', {1, 2, 3}, {3, 3, 3, 3, 3, 3}, sd::DataType::FLOAT32); + NDArray dLdC('c', {1}, {1}, sd::DataType::FLOAT32); - NDArray A('c', { 1, 2, 3 }, { 2,2,2, 2,2,2 }, sd::DataType::FLOAT32); - NDArray B('c', { 1, 2, 3 }, { 3,3,3,3, 3,3 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 1 }, { 1 }, sd::DataType::FLOAT32); - - sd::ops::tensormmul_bp op_bp; - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + sd::ops::tensormmul_bp op_bp; + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto dLdAbp = resultsBP.at(0); - auto dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(B.isSameShape(dLdAbp)); - ASSERT_TRUE(B.equalsTo(dLdAbp)); - - ASSERT_TRUE(A.isSameShape(dLdBbp)); - ASSERT_TRUE(A.equalsTo(dLdBbp)); + ASSERT_TRUE(B.isSameShape(dLdAbp)); + ASSERT_TRUE(B.equalsTo(dLdAbp)); + ASSERT_TRUE(A.isSameShape(dLdBbp)); + ASSERT_TRUE(A.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP3) { + NDArray A('c', {3, 2, 2}, + {2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2}, + sd::DataType::FLOAT32); + NDArray B('c', {4, 2, 2}, + {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, + 4.4, 4.5, 4.6}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 4}, {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2}, + sd::DataType::FLOAT32); - NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::FLOAT32); - NDArray B('c', { 4, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, sd::DataType::FLOAT32); - - NDArray dA('c', { 3, 2, 2 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. }, sd::DataType::FLOAT32); - NDArray dB('c', { 4, 2, 2 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8, 7.04 }, sd::DataType::FLOAT32); - - sd::ops::tensormmul_bp op_bp; + NDArray dA( + 'c', {3, 2, 2}, + {3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17.}, + sd::DataType::FLOAT32); + NDArray dB('c', {4, 2, 2}, + {4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, + 6.2, 6.32, 6.56, 6.8, 7.04}, + sd::DataType::FLOAT32); - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + sd::ops::tensormmul_bp op_bp; - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}, {}); - auto dLdAbp = resultsBP.at(0); - auto dLdBbp = resultsBP.at(1); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - ASSERT_TRUE(dA.isSameShape(dLdAbp)); - ASSERT_TRUE(dA.equalsTo(dLdAbp)); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dB.isSameShape(dLdBbp)); - ASSERT_TRUE(dB.equalsTo(dLdBbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP4) { + NDArray A('c', {3, 4, 1}, {0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3}, + sd::DataType::FLOAT32); + NDArray B('c', {2, 4, 1}, {4, 13, .5, 19, 2.3, 1.2, 18, .9}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 2}, {1.1, 1.2, 1.3, 1.4, 1.5, 1.6}, + sd::DataType::FLOAT32); - NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, sd::DataType::FLOAT32); - NDArray B('c', { 2, 4, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 2 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, sd::DataType::FLOAT32); + NDArray dLdA('c', {3, 4, 1}, + {7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, + 21.42, 29.55, 29.94}, + sd::DataType::FLOAT32); + NDArray dLdB('c', {2, 4, 1}, + {30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2}, + sd::DataType::FLOAT32); - NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, sd::DataType::FLOAT32); - NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84 , 3.768, 215.6, 28.2 }, sd::DataType::FLOAT32); + sd::ops::tensormmul_bp op_bp; - sd::ops::tensormmul_bp op_bp; + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}, {}); - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - auto dLdAbp = resultsBP.at(0); - auto dLdBbp = resultsBP.at(1); - - ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); - ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); - - ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); - ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); + ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); + ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP5) { + NDArray A('c', {3, 4, 1, 1}, {0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3}, + sd::DataType::FLOAT32); + NDArray B('c', {2, 4, 1, 1}, {4, 13, .5, 19, 2.3, 1.2, 18, .9}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 1, 2, 1}, {1.1, 1.2, 1.3, 1.4, 1.5, 1.6}, + sd::DataType::FLOAT32); - NDArray A('c', { 3, 4, 1, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, sd::DataType::FLOAT32); - NDArray B('c', { 2, 4, 1, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 1, 2, 1 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, sd::DataType::FLOAT32); - - NDArray dLdA('c', { 3, 4, 1, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, sd::DataType::FLOAT32); - NDArray dLdB('c', { 2, 4, 1, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2 }, sd::DataType::FLOAT32); + NDArray dLdA('c', {3, 4, 1, 1}, + {7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, + 21.42, 29.55, 29.94}, + sd::DataType::FLOAT32); + NDArray dLdB('c', {2, 4, 1, 1}, + {30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2}, + sd::DataType::FLOAT32); - sd::ops::tensormmul_bp op_bp; + sd::ops::tensormmul_bp op_bp; - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto dLdAbp = resultsBP.at(0); - auto dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); - ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); - - ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); - ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); + ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); + ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP6) { + NDArray A('c', {2, 2, 2}, {2, 2, 2, 2, 2, 2, 2, 2}, sd::DataType::FLOAT32); + NDArray B('c', {2, 2, 2}, {3, 3, 3, 3, 3, 3, 3, 3}, sd::DataType::FLOAT32); - NDArray A('c', { 2, 2, 2 }, { 2,2, 2,2, 2,2, 2,2 }, sd::DataType::FLOAT32); - NDArray B('c', { 2, 2, 2 }, { 3,3, 3,3, 3,3, 3,3 }, sd::DataType::FLOAT32); - - auto dLdC = NDArrayFactory::create(1.f); - - sd::ops::tensormmul_bp op_bp; - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {}); + auto dLdC = NDArrayFactory::create(1.f); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + sd::ops::tensormmul_bp op_bp; + auto resultsBP = + op_bp.evaluate({&A, &B, &dLdC}, {}, {3, 0, 1, 2, 3, 0, 1, 2}, {}); - auto dLdAbp = resultsBP.at(0); - auto dLdBbp = resultsBP.at(1); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - ASSERT_TRUE(B.isSameShape(dLdAbp)); - ASSERT_TRUE(B.equalsTo(dLdAbp)); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(A.isSameShape(dLdBbp)); - ASSERT_TRUE(A.equalsTo(dLdBbp)); + ASSERT_TRUE(B.isSameShape(dLdAbp)); + ASSERT_TRUE(B.equalsTo(dLdAbp)); + ASSERT_TRUE(A.isSameShape(dLdBbp)); + ASSERT_TRUE(A.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP7) { + NDArray A('c', {3, 4, 1}, {0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3}, + sd::DataType::FLOAT32); + NDArray B('c', {2, 4, 1}, {4, 13, .5, 19, 2.3, 1.2, 18, .9}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 1, 2, 1}, {1.1, 1.2, 1.3, 1.4, 1.5, 1.6}, + sd::DataType::FLOAT32); - NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, sd::DataType::FLOAT32); - NDArray B('c', { 2, 4, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 1, 2, 1 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, sd::DataType::FLOAT32); - - NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, sd::DataType::FLOAT32); - NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2 }, sd::DataType::FLOAT32); - - sd::ops::tensormmul_bp op_bp; + NDArray dLdA('c', {3, 4, 1}, + {7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, + 21.42, 29.55, 29.94}, + sd::DataType::FLOAT32); + NDArray dLdB('c', {2, 4, 1}, + {30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2}, + sd::DataType::FLOAT32); - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); + sd::ops::tensormmul_bp op_bp; - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto dLdAbp = resultsBP.at(0); - auto dLdBbp = resultsBP.at(1); + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {1, 1, 1, 1}, {}); - ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); - ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); - ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); + ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); + ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP8) { + NDArray A('c', {1, 1, 4, 3}, {0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3}, + sd::DataType::FLOAT32); + NDArray B('c', {1, 1, 4, 2}, {4, 13, .5, 19, 2.3, 1.2, 18, .9}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 2}, {1.1, 1.2, 1.3, 1.4, 1.5, 1.6}, + sd::DataType::FLOAT32); - NDArray A('c', { 1, 1, 4, 3 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, sd::DataType::FLOAT32); - NDArray B('c', { 1, 1, 4, 2 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 2 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, sd::DataType::FLOAT32); + NDArray dLdA('c', {1, 1, 4, 3}, + {20., 23.4, 26.8, 23.35, 27.25, 31.15, 3.97, 4.67, 5.37, 20.88, + 24.66, 28.44}, + sd::DataType::FLOAT32); + NDArray dLdB('c', {1, 1, 4, 2}, + {11.84, 12.68, 39.98, 43.192, 20.65, 22.36, 165.7, 178.4}, + sd::DataType::FLOAT32); - NDArray dLdA('c', { 1, 1, 4, 3 }, { 20., 23.4, 26.8, 23.35, 27.25, 31.15, 3.97, 4.67, 5.37, 20.88, 24.66, 28.44 }, sd::DataType::FLOAT32); - NDArray dLdB('c', { 1, 1, 4, 2 }, { 11.84, 12.68, 39.98, 43.192, 20.65, 22.36, 165.7, 178.4 }, sd::DataType::FLOAT32); + sd::ops::tensormmul_bp op_bp; - sd::ops::tensormmul_bp op_bp; + auto resultsBP = + op_bp.evaluate({&A, &B, &dLdC}, {}, {3, 0, 1, 2, 3, 0, 1, 2}, {}); - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - auto dLdAbp = resultsBP.at(0); - auto dLdBbp = resultsBP.at(1); - - ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); - ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); - - ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); - ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); + ASSERT_TRUE(dLdA.isSameShape(dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(dLdAbp)); + ASSERT_TRUE(dLdB.isSameShape(dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP9) { + NDArray A('c', {3, 2, 2, 1}, + {2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2}, + sd::DataType::FLOAT32); + NDArray B('c', {4, 2, 2, 1}, + {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, + 4.4, 4.5, 4.6}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 1, 4, 1}, + {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2}, + sd::DataType::FLOAT32); - NDArray A('c', { 3, 2, 2, 1 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::FLOAT32); - NDArray B('c', { 4, 2, 2 ,1 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 1, 4, 1 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, sd::DataType::FLOAT32); - - NDArray dA('c', { 3, 2, 2, 1 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. }, sd::DataType::FLOAT32); - NDArray dB('c', { 4, 2, 2, 1 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8, 7.04 }, sd::DataType::FLOAT32); + NDArray dA( + 'c', {3, 2, 2, 1}, + {3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17.}, + sd::DataType::FLOAT32); + NDArray dB('c', {4, 2, 2, 1}, + {4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, + 6.2, 6.32, 6.56, 6.8, 7.04}, + sd::DataType::FLOAT32); - sd::ops::tensormmul_bp op_bp; + sd::ops::tensormmul_bp op_bp; - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto dLdAbp = resultsBP.at(0); - auto dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dA.isSameShape(dLdAbp)); - ASSERT_TRUE(dA.equalsTo(dLdAbp)); - - ASSERT_TRUE(dB.isSameShape(dLdBbp)); - ASSERT_TRUE(dB.equalsTo(dLdBbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP10) { + NDArray A('c', {1, 2, 2, 3}, + {2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2}, + sd::DataType::FLOAT32); + NDArray B('c', {1, 2, 2, 4}, + {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, + 4.4, 4.5, 4.6}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {1, 3, 1, 4}, + {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2}, + sd::DataType::FLOAT32); - NDArray A('c', { 1, 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::FLOAT32); - NDArray B('c', { 1, 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 1, 3, 1, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, sd::DataType::FLOAT32); - - - NDArray dA('c', { 1, 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74 }, sd::DataType::FLOAT32); - NDArray dB('c', { 1, 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, sd::DataType::FLOAT32); + NDArray dA( + 'c', {1, 2, 2, 3}, + {3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74}, + sd::DataType::FLOAT32); + NDArray dB('c', {1, 2, 2, 4}, + {3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, + 6.8, 4.73, 5.66, 6.59, 7.52}, + sd::DataType::FLOAT32); - sd::ops::tensormmul_bp op_bp; + sd::ops::tensormmul_bp op_bp; - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto dLdAbp = resultsBP.at(0); - auto dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dA.isSameShape(dLdAbp)); - ASSERT_TRUE(dA.equalsTo(dLdAbp)); - - ASSERT_TRUE(dB.isSameShape(dLdBbp)); - ASSERT_TRUE(dB.equalsTo(dLdBbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP11) { + NDArray A('c', {2, 2, 3}, + {2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2}, + sd::DataType::FLOAT32); + NDArray B('c', {2, 2, 4}, + {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, + 4.4, 4.5, 4.6}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {3, 4}, {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2}, + sd::DataType::FLOAT32); - NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::FLOAT32); - NDArray B('c', { 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, sd::DataType::FLOAT32); - - - NDArray dA('c', { 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74 }, sd::DataType::FLOAT32); - NDArray dB('c', { 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, sd::DataType::FLOAT32); + NDArray dA( + 'c', {2, 2, 3}, + {3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74}, + sd::DataType::FLOAT32); + NDArray dB('c', {2, 2, 4}, + {3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, + 6.8, 4.73, 5.66, 6.59, 7.52}, + sd::DataType::FLOAT32); - sd::ops::tensormmul_bp op_bp; + sd::ops::tensormmul_bp op_bp; - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {}); + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {2, 0, 1, 2, 0, 1}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto dLdAbp = resultsBP.at(0); - auto dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dA.isSameShape(dLdAbp)); - ASSERT_TRUE(dA.equalsTo(dLdAbp)); - - ASSERT_TRUE(dB.isSameShape(dLdBbp)); - ASSERT_TRUE(dB.equalsTo(dLdBbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP12) { - - NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::FLOAT32); - NDArray B('c', { 2, 2 ,3 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, sd::DataType::FLOAT32); - NDArray dLdC('c', { 2, 3, 2, 3 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, - 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, - 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, sd::DataType::FLOAT32); - - NDArray dA('c', { 2, 2, 3 }, { 7.66, 20.26, 32.86, 8.29, 21.97, 35.65, 45.46, 58.06, 70.66, 49.33, 63.01, 76.69 }, sd::DataType::FLOAT32); - NDArray dB('c', { 2, 2, 3 }, { 25.86, 27.36, 28.86, 28.74, 30.42, 32.1, 30.36, 31.86, 33.36, 33.78, 35.46, 37.14 }, sd::DataType::FLOAT32); - - sd::ops::tensormmul_bp op_bp; - - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); - - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - - auto dLdAbp = resultsBP.at(0); - auto dLdBbp = resultsBP.at(1); - - ASSERT_TRUE(dA.isSameShape(dLdAbp)); - ASSERT_TRUE(dA.equalsTo(dLdAbp)); - - ASSERT_TRUE(dB.isSameShape(dLdBbp)); - ASSERT_TRUE(dB.equalsTo(dLdBbp)); - + NDArray A('c', {2, 2, 3}, + {2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2}, + sd::DataType::FLOAT32); + NDArray B('c', {2, 2, 3}, + {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2}, + sd::DataType::FLOAT32); + NDArray dLdC('c', {2, 3, 2, 3}, + {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, + 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, + 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6}, + sd::DataType::FLOAT32); + + NDArray dA('c', {2, 2, 3}, + {7.66, 20.26, 32.86, 8.29, 21.97, 35.65, 45.46, 58.06, 70.66, + 49.33, 63.01, 76.69}, + sd::DataType::FLOAT32); + NDArray dB('c', {2, 2, 3}, + {25.86, 27.36, 28.86, 28.74, 30.42, 32.1, 30.36, 31.86, 33.36, + 33.78, 35.46, 37.14}, + sd::DataType::FLOAT32); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {1, 1, 1, 1}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); + + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP13) { - - NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::DOUBLE); - NDArray B('c', { 3, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, sd::DataType::DOUBLE); - NDArray dLdC('c', { 3, 2, 3, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, - 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, - 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, sd::DataType::DOUBLE); - - NDArray dA('c', { 3, 2, 2 }, { 7.79, 20.57, 8.21, 21.71, 33.35, 46.13, 35.21, 48.71, 58.91, 71.69, 62.21, 75.71 }, sd::DataType::DOUBLE); - NDArray dB('c', { 3, 2, 2 }, { 26.49, 28.02, 28.41, 30.06, 29.55, 31.08, 31.71, 33.36, 32.61, 34.14, 35.01, 36.66 }, sd::DataType::DOUBLE); - - sd::ops::tensormmul_bp op_bp; - - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); - - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - - auto dLdAbp = resultsBP.at(0); - auto dLdBbp = resultsBP.at(1); - - ASSERT_TRUE(dA.isSameShape(dLdAbp)); - ASSERT_TRUE(dA.equalsTo(dLdAbp)); - - ASSERT_TRUE(dB.isSameShape(dLdBbp)); - ASSERT_TRUE(dB.equalsTo(dLdBbp)); - + NDArray A('c', {3, 2, 2}, + {2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2}, + sd::DataType::DOUBLE); + NDArray B('c', {3, 2, 2}, + {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2}, + sd::DataType::DOUBLE); + NDArray dLdC('c', {3, 2, 3, 2}, + {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, + 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, + 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6}, + sd::DataType::DOUBLE); + + NDArray dA('c', {3, 2, 2}, + {7.79, 20.57, 8.21, 21.71, 33.35, 46.13, 35.21, 48.71, 58.91, + 71.69, 62.21, 75.71}, + sd::DataType::DOUBLE); + NDArray dB('c', {3, 2, 2}, + {26.49, 28.02, 28.41, 30.06, 29.55, 31.08, 31.71, 33.36, 32.61, + 34.14, 35.01, 36.66}, + sd::DataType::DOUBLE); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {1, 1, 1, 1}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); + + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP14) { + NDArray A('c', {2, 2, 2, 2}, + {2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, + 3.4, 3.5, 3.6}, + sd::DataType::DOUBLE); - NDArray A('c', { 2, 2, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, sd::DataType::DOUBLE); - - NDArray B('c', { 2, 2, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, sd::DataType::DOUBLE); + NDArray B('c', {2, 2, 2, 2}, + {3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, + 4.4, 4.5, 4.6}, + sd::DataType::DOUBLE); - NDArray dLdC('c', { 2, 2, 2, 2, 2, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, - 1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, - 1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, - 1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, - 1.3, 1.4, 1.5, 1.6 }, sd::DataType::DOUBLE); + NDArray dLdC( + 'c', {2, 2, 2, 2, 2, 2}, + {.1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, + .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, + .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, + .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6}, + sd::DataType::DOUBLE); - NDArray dA('c', { 2, 2, 2, 2 }, { 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24, 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24 }, sd::DataType::DOUBLE); - NDArray dB('c', { 2, 2, 2, 2 }, { 10.76, 12.88, 15., 17.12, 12.36, 14.8, 17.24, 19.68, 19.24, 21.36, 23.48, 25.6, 22.12, 24.56, 27., 29.44 }, sd::DataType::DOUBLE); + NDArray dA('c', {2, 2, 2, 2}, + {13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24, 13.88, + 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24}, + sd::DataType::DOUBLE); + NDArray dB('c', {2, 2, 2, 2}, + {10.76, 12.88, 15., 17.12, 12.36, 14.8, 17.24, 19.68, 19.24, 21.36, + 23.48, 25.6, 22.12, 24.56, 27., 29.44}, + sd::DataType::DOUBLE); - sd::ops::tensormmul_bp op_bp; + sd::ops::tensormmul_bp op_bp; - auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); + auto resultsBP = op_bp.evaluate({&A, &B, &dLdC}, {}, {1, 1, 1, 1}, {}); - ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); - auto dLdAbp = resultsBP.at(0); - auto dLdBbp = resultsBP.at(1); + auto dLdAbp = resultsBP.at(0); + auto dLdBbp = resultsBP.at(1); - ASSERT_TRUE(dA.isSameShape(dLdAbp)); - ASSERT_TRUE(dA.equalsTo(dLdAbp)); - - ASSERT_TRUE(dB.isSameShape(dLdBbp)); - ASSERT_TRUE(dB.equalsTo(dLdBbp)); + ASSERT_TRUE(dA.isSameShape(dLdAbp)); + ASSERT_TRUE(dA.equalsTo(dLdAbp)); + ASSERT_TRUE(dB.isSameShape(dLdBbp)); + ASSERT_TRUE(dB.equalsTo(dLdBbp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP15) { + NDArray A('c', {2, 2, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}, + sd::DataType::FLOAT32); + NDArray B('f', {2, 2, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}, + sd::DataType::FLOAT32); - NDArray A('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::FLOAT32); - NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::FLOAT32); - - NDArray dLdC('f', { 2, 2 }, { 23.0, 24.44, 2.0, 26. }, sd::DataType::FLOAT32); - - NDArray dA('c', { 2, 2, 3 }, { 27., 127., 227., 77., 177., 277., 76.44, 278.20001, 479.96002, 177.32, 379.08001, 580.839966 }, sd::DataType::FLOAT32); - NDArray dB('f', { 2, 2, 3 }, { 194.08, 184., 336.4, 268., 241.52, 212., 383.839996, 296., 288.96002, 240., 431.27999, 324. }, sd::DataType::FLOAT32); + NDArray dLdC('f', {2, 2}, {23.0, 24.44, 2.0, 26.}, sd::DataType::FLOAT32); - sd::ops::tensormmul_bp op; - auto results = op.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2,2,1,2 }); + NDArray dA('c', {2, 2, 3}, + {27., 127., 227., 77., 177., 277., 76.44, 278.20001, 479.96002, + 177.32, 379.08001, 580.839966}, + sd::DataType::FLOAT32); + NDArray dB('f', {2, 2, 3}, + {194.08, 184., 336.4, 268., 241.52, 212., 383.839996, 296., + 288.96002, 240., 431.27999, 324.}, + sd::DataType::FLOAT32); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::tensormmul_bp op; + auto results = op.evaluate({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}); - auto dLdA = results.at(0); - auto dLdB = results.at(1); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(dA.isSameShape(dLdA)); - ASSERT_TRUE(dA.equalsTo(dLdA)); + auto dLdA = results.at(0); + auto dLdB = results.at(1); - ASSERT_TRUE(dB.isSameShape(dLdB)); - ASSERT_TRUE(dB.equalsTo(dLdB)); + ASSERT_TRUE(dA.isSameShape(dLdA)); + ASSERT_TRUE(dA.equalsTo(dLdA)); + ASSERT_TRUE(dB.isSameShape(dLdB)); + ASSERT_TRUE(dB.equalsTo(dLdB)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP16) { + NDArray A('f', {2, 2, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}, + sd::DataType::DOUBLE); + NDArray B('c', {2, 2, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}, + sd::DataType::DOUBLE); - NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::DOUBLE); - NDArray B('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::DOUBLE); + NDArray dLdC('c', {2, 2}, sd::DataType::DOUBLE); - NDArray dLdC('c', { 2, 2 }, sd::DataType::DOUBLE); + const OpArgsHolder argsHolderFF({&A, &B}, {}, {2, 1, 2, 2, 1, 2}); + const OpArgsHolder argsHolderBP({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}); - const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 }); - const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }); + sd::ops::tensormmul op; + sd::ops::tensormmul_bp op_bp; - sd::ops::tensormmul op; - sd::ops::tensormmul_bp op_bp; - - const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, {1,0}); - ASSERT_TRUE(isGradCorrect); + const bool isGradCorrect = + GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, {1, 0}); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) { + NDArray A('f', {2, 2, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}, + sd::DataType::DOUBLE); + NDArray B('f', {2, 2, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}, + sd::DataType::DOUBLE); - NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::DOUBLE); - NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::DOUBLE); - - NDArray dLdC('c', { 2, 2 }, sd::DataType::DOUBLE); + NDArray dLdC('c', {2, 2}, sd::DataType::DOUBLE); - const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 }); - const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }); + const OpArgsHolder argsHolderFF({&A, &B}, {}, {2, 1, 2, 2, 1, 2}); + const OpArgsHolder argsHolderBP({&A, &B, &dLdC}, {}, {2, 1, 2, 2, 1, 2}); - sd::ops::tensormmul op; - sd::ops::tensormmul_bp op_bp; + sd::ops::tensormmul op; + sd::ops::tensormmul_bp op_bp; - const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, { 1,0 }); - ASSERT_TRUE(isGradCorrect); + const bool isGradCorrect = + GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, {1, 0}); + ASSERT_TRUE(isGradCorrect); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, gru_1) { + const int sL = 3; + const int bS = 2; + const int nIn = 5; + const int nOut = 4; - const int sL = 3; - const int bS = 2; - const int nIn = 5; - const int nOut = 4; - - - NDArray x('c', {sL, bS, nIn}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. , 5.5, 6. , 6.5, 7. , 7.5, 8. , 8.5, 9. , 9.5, 10. , 10.5, 11. , 11.5, 12. , 12.5, 13. , 13.5, 14. , 14.5, 15.}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, {-3,-2,-1,0,1,2,3,4}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 3*nOut}, sd::DataType::FLOAT32); - NDArray Wh('c', {nOut, 3*nOut}, sd::DataType::FLOAT32); - NDArray b('c', {3*nOut}, sd::DataType::FLOAT32); + NDArray x('c', {sL, bS, nIn}, + {0.5, 1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., + 5.5, 6., 6.5, 7., 7.5, 8., 8.5, 9., 9.5, 10., + 10.5, 11., 11.5, 12., 12.5, 13., 13.5, 14., 14.5, 15.}, + sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, {-3, -2, -1, 0, 1, 2, 3, 4}, + sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 3 * nOut}, sd::DataType::FLOAT32); + NDArray Wh('c', {nOut, 3 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {3 * nOut}, sd::DataType::FLOAT32); - NDArray expH('c', {sL, bS, nOut}, {-1.681847, -1.062565, -0.443283, 0.175998,0.837823, 1.488041, 2.13826 , 2.788478, -0.888747, -0.491826, -0.094907, 0.302014, - 0.751355, 1.182715, 1.614075, 2.045434, -0.388876, -0.126716, 0.135444, 0.397604,0.710558, 1.002922, 1.295287, 1.587651}, sd::DataType::FLOAT32); + NDArray expH( + 'c', {sL, bS, nOut}, + {-1.681847, -1.062565, -0.443283, 0.175998, 0.837823, 1.488041, + 2.13826, 2.788478, -0.888747, -0.491826, -0.094907, 0.302014, + 0.751355, 1.182715, 1.614075, 2.045434, -0.388876, -0.126716, + 0.135444, 0.397604, 0.710558, 1.002922, 1.295287, 1.587651}, + sd::DataType::FLOAT32); - Wx = 0.003; - Wh = 0.006; - b = 0.5; + Wx = 0.003; + Wh = 0.006; + b = 0.5; - NDArray dLdC('c', { 2, 2 }, sd::DataType::DOUBLE); + NDArray dLdC('c', {2, 2}, sd::DataType::DOUBLE); - sd::ops::gru op; - auto results = op.evaluate({&x, &hI, &Wx, &Wh, &b}, {}, {}); + sd::ops::gru op; + auto results = op.evaluate({&x, &hI, &Wx, &Wh, &b}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto h = results.at(0); + auto h = results.at(0); - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, gru_bp_1) { - - const int sL = 3; - const int bS = 2; - const int nIn = 5; - const int nOut = 4; - - - NDArray x('c', {sL, bS, nIn}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. , 5.5, 6. , 6.5, 7. , 7.5, 8. , 8.5, 9. , 9.5, 10. , 10.5, 11. , 11.5, 12. , 12.5, 13. , 13.5, 14. , 14.5, 15.}, sd::DataType::DOUBLE); - NDArray hI('c', {bS, nOut}, {-3,-2,-1,0,1,2,3,4}, sd::DataType::DOUBLE); - NDArray Wx('c', {nIn, 3*nOut}, sd::DataType::DOUBLE); - NDArray Wh('c', {nOut, 3*nOut}, sd::DataType::DOUBLE); - NDArray b('c', {3*nOut}, sd::DataType::DOUBLE); - - NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE); - - Wx.linspace(1,-0.1); - Wh.linspace(0.2,0.2); - b.linspace(1,-0.15); - - const OpArgsHolder argsHolderFF({&x, &hI, &Wx, &Wh, &b}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &hI, &Wx, &Wh, &b, &dLdh}, {}, {}); - - sd::ops::gru opFF; - sd::ops::gru_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const int sL = 3; + const int bS = 2; + const int nIn = 5; + const int nOut = 4; + + NDArray x('c', {sL, bS, nIn}, + {0.5, 1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., + 5.5, 6., 6.5, 7., 7.5, 8., 8.5, 9., 9.5, 10., + 10.5, 11., 11.5, 12., 12.5, 13., 13.5, 14., 14.5, 15.}, + sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, {-3, -2, -1, 0, 1, 2, 3, 4}, + sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 3 * nOut}, sd::DataType::DOUBLE); + NDArray Wh('c', {nOut, 3 * nOut}, sd::DataType::DOUBLE); + NDArray b('c', {3 * nOut}, sd::DataType::DOUBLE); + + NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE); + + Wx.linspace(1, -0.1); + Wh.linspace(0.2, 0.2); + b.linspace(1, -0.15); + + const OpArgsHolder argsHolderFF({&x, &hI, &Wx, &Wh, &b}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &hI, &Wx, &Wh, &b, &dLdh}, {}, {}); + + sd::ops::gru opFF; + sd::ops::gru_bp opBP; + + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index 77262e0f747d..de3c6aa756ec 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -14,1063 +14,1036 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +// +// @author raver119@gmail.com +// - // - // @author raver119@gmail.com - // - -#include "testlayers.h" -#include #include -#include #include +#include +#include + #include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTests16 : public testing::Test { -public: - - DeclarableOpsTests16() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests16() { + printf("\n"); + fflush(stdout); + } }; TEST_F(DeclarableOpsTests16, scatter_upd_1) { - auto x = NDArrayFactory::create('c', { 3 }, { 1.f, 1.f, 1.f }); - auto y = NDArrayFactory::create(0); - auto w = NDArrayFactory::create(3.0f); - auto e = NDArrayFactory::create('c', { 3 }, { 3.f, 1.f, 1.f }); + auto x = NDArrayFactory::create('c', {3}, {1.f, 1.f, 1.f}); + auto y = NDArrayFactory::create(0); + auto w = NDArrayFactory::create(3.0f); + auto e = NDArrayFactory::create('c', {3}, {3.f, 1.f, 1.f}); - sd::ops::scatter_upd op; - auto result = op.evaluate({ &x, &y, &w }); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::scatter_upd op; + auto result = op.evaluate({&x, &y, &w}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests16, scatter_upd_2) { + NDArray x('c', {10, 3}, sd::DataType::FLOAT32); + NDArray indices('c', {2}, {2, 5}, sd::DataType::INT32); + NDArray updates('c', {2, 3}, {100, 101, 102, 200, 201, 202}, + sd::DataType::FLOAT32); + NDArray e('c', {10, 3}, + {1, 2, 3, 4, 5, 6, 100, 101, 102, 10, 11, 12, 13, 14, 15, + 200, 201, 202, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}, + sd::DataType::FLOAT32); - NDArray x('c', { 10, 3 }, sd::DataType::FLOAT32); - NDArray indices('c', { 2 }, { 2,5 }, sd::DataType::INT32); - NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, sd::DataType::FLOAT32); - NDArray e('c', { 10, 3 }, { 1,2,3, 4,5,6, 100,101,102, 10,11,12, 13,14,15, 200,201,202, 19,20,21, 22,23,24, 25,26,27, 28,29,30 }, sd::DataType::FLOAT32); - - x.linspace(1); + x.linspace(1); - sd::ops::scatter_upd op; - auto result = op.evaluate({ &x, &indices, &updates }); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::scatter_upd op; + auto result = op.evaluate({&x, &indices, &updates}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests16, scatter_upd_3) { - - NDArray x('c', { 10, 3 }, sd::DataType::FLOAT32); - NDArray indices('c', { 2 }, { 20,5 }, sd::DataType::INT32); - NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, sd::DataType::FLOAT32); - NDArray output('c', { 10, 3 }, sd::DataType::FLOAT32); - - sd::ops::scatter_upd op; - ASSERT_ANY_THROW(op.execute({ &x, &indices, &updates }, { &output }, {}, {}, { true, true })); + NDArray x('c', {10, 3}, sd::DataType::FLOAT32); + NDArray indices('c', {2}, {20, 5}, sd::DataType::INT32); + NDArray updates('c', {2, 3}, {100, 101, 102, 200, 201, 202}, + sd::DataType::FLOAT32); + NDArray output('c', {10, 3}, sd::DataType::FLOAT32); + + sd::ops::scatter_upd op; + ASSERT_ANY_THROW( + op.execute({&x, &indices, &updates}, {&output}, {}, {}, {true, true})); } TEST_F(DeclarableOpsTests16, test_size_dtype_1) { - auto x = NDArrayFactory::create('c', { 3 }, { 1, 1, 1 }); - auto z = NDArrayFactory::create(0.0f); - auto e = NDArrayFactory::create(3.0f); + auto x = NDArrayFactory::create('c', {3}, {1, 1, 1}); + auto z = NDArrayFactory::create(0.0f); + auto e = NDArrayFactory::create(3.0f); - sd::ops::size op; - auto status = op.execute({ &x }, { &z }, {}, {}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::size op; + auto status = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests16, test_empty_noop_1) { - auto z = NDArrayFactory::empty(); + auto z = NDArrayFactory::empty(); - sd::ops::noop op; - auto status = op.execute({}, { &z }, {}, {}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::noop op; + auto status = op.execute({}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); } TEST_F(DeclarableOpsTests16, test_empty_noop_2) { - auto z = NDArrayFactory::empty(); + auto z = NDArrayFactory::empty(); - Context ctx(1); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + Context ctx(1); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); - sd::ops::noop op; - auto status = op.execute(&ctx); + sd::ops::noop op; + auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(Status::OK(), status); } TEST_F(DeclarableOpsTests16, test_svd_1) { - auto x = NDArrayFactory::create('c', { 3, 3 }, { 0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f,0.50563407f, 0.89252293f, 0.5461209f }); - auto z = NDArrayFactory::create('c', { 3 }); + auto x = NDArrayFactory::create( + 'c', {3, 3}, + {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, + 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f}); + auto z = NDArrayFactory::create('c', {3}); - sd::ops::svd op; - auto status = op.execute({ &x }, { &z }, {}, { 0, 0, 16 }, {}); + sd::ops::svd op; + auto status = op.execute({&x}, {&z}, {}, {0, 0, 16}, {}); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(Status::OK(), status); } TEST_F(DeclarableOpsTests16, test_hamming_distance_1) { - auto x = NDArrayFactory::create({ 37, 37, 37 }); - auto y = NDArrayFactory::create({ 8723, 8723, 8723 }); - auto e = NDArrayFactory::create(18); + auto x = NDArrayFactory::create({37, 37, 37}); + auto y = NDArrayFactory::create({8723, 8723, 8723}); + auto e = NDArrayFactory::create(18); - sd::ops::bits_hamming_distance op; - auto result = op.evaluate({ &x, &y }); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::bits_hamming_distance op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) { - auto input = NDArrayFactory::create('c', { 512 }); - auto low = NDArrayFactory::create('c', { 512 }); - auto high = NDArrayFactory::create('c', { 512 }); + auto input = NDArrayFactory::create('c', {512}); + auto low = NDArrayFactory::create('c', {512}); + auto high = NDArrayFactory::create('c', {512}); - auto output = NDArrayFactory::create(0.0f); + auto output = NDArrayFactory::create(0.0f); - input.linspace(1.0); - low.linspace(1.0); - high.linspace(1.0); + input.linspace(1.0); + low.linspace(1.0); + high.linspace(1.0); - sd::ops::knn_mindistance op; - auto result = op.execute({ &input, &low, &high }, { &output }, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); + sd::ops::knn_mindistance op; + auto result = op.execute({&input, &low, &high}, {&output}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); } TEST_F(DeclarableOpsTests16, test_empty_cast_1) { - auto x = NDArrayFactory::create('c', { 1, 0, 2 }); - auto e = NDArrayFactory::create('c', { 1, 0, 2 }); + auto x = NDArrayFactory::create('c', {1, 0, 2}); + auto e = NDArrayFactory::create('c', {1, 0, 2}); - sd::ops::cast op; - auto result = op.evaluate({&x}, {10}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, result.at(0)); + sd::ops::cast op; + auto result = op.evaluate({&x}, {10}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests16, test_range_1) { - sd::ops::range op; - auto z = NDArrayFactory::create('c', { 200 }); + sd::ops::range op; + auto z = NDArrayFactory::create('c', {200}); - Context ctx(1); - ctx.setTArguments({ -1.0, 1.0, 0.01 }); - ctx.setOutputArray(0, z); + Context ctx(1); + ctx.setTArguments({-1.0, 1.0, 0.01}); + ctx.setOutputArray(0, z); - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); } TEST_F(DeclarableOpsTests16, test_range_2) { - sd::ops::range op; - auto z = NDArrayFactory::create('c', { 200 }); + sd::ops::range op; + auto z = NDArrayFactory::create('c', {200}); - double tArgs[] = { -1.0, 1.0, 0.01 }; + double tArgs[] = {-1.0, 1.0, 0.01}; - auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0, nullptr, 0); - shape::printShapeInfoLinear("Result", shapes->at(0)); - ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0))); + auto shapes = + ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, + tArgs, 3, nullptr, 0, nullptr, 0, nullptr, 0); + shape::printShapeInfoLinear("Result", shapes->at(0)); + ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0))); - delete shapes; + delete shapes; } TEST_F(DeclarableOpsTests16, test_reverse_1) { - std::vector rows = { 3, 5, 7, 8, 9, 10, 119, 211 }; - std::vector columns = { 6, 5, 10, 100, 153, 171, 635 }; - - for (auto r : rows) { - for (auto c : columns) { - //nd4j_printf("Trying [%i, %i]\n", r, c); - auto array = NDArrayFactory::create('c', { r, c }); - auto exp = NDArrayFactory::create('c', { r, c }); - auto reversed = NDArrayFactory::create('c', { r, c }); - - auto rowOriginal = NDArrayFactory::create('c', { c }); - auto rowReversed = NDArrayFactory::create('c', { c }); - - for (int e = 0; e < c; e++) { - rowOriginal.p(e, (float)e); - rowReversed.p(c - e - 1, (float)e); - } - - - auto listI = array.allTensorsAlongDimension({ 1 }); - auto listE = exp.allTensorsAlongDimension({ 1 }); - - for (int e = 0; e < r; e++) { - listI.at(e).assign(rowOriginal); - listE.at(e).assign(rowReversed); - } - - sd::ops::reverse op; - Nd4jLong axis = 1; - auto status = op.execute({ &array }, { &reversed }, {}, { axis }, {}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_EQ(exp, reversed); - } + std::vector rows = {3, 5, 7, 8, 9, 10, 119, 211}; + std::vector columns = {6, 5, 10, 100, 153, 171, 635}; + + for (auto r : rows) { + for (auto c : columns) { + // nd4j_printf("Trying [%i, %i]\n", r, c); + auto array = NDArrayFactory::create('c', {r, c}); + auto exp = NDArrayFactory::create('c', {r, c}); + auto reversed = NDArrayFactory::create('c', {r, c}); + + auto rowOriginal = NDArrayFactory::create('c', {c}); + auto rowReversed = NDArrayFactory::create('c', {c}); + + for (int e = 0; e < c; e++) { + rowOriginal.p(e, (float)e); + rowReversed.p(c - e - 1, (float)e); + } + + auto listI = array.allTensorsAlongDimension({1}); + auto listE = exp.allTensorsAlongDimension({1}); + + for (int e = 0; e < r; e++) { + listI.at(e).assign(rowOriginal); + listE.at(e).assign(rowReversed); + } + + sd::ops::reverse op; + Nd4jLong axis = 1; + auto status = op.execute({&array}, {&reversed}, {}, {axis}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(exp, reversed); } + } } TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_1) { - /* - test case generated by python colorsys and scaled to suit our needs - from colorsys import * - from random import * - import numpy as np - rgbs = np.random.uniform(0,1, 5*4*3 ).astype('float32').reshape([5,4,3]) - hsvs=np.apply_along_axis(lambda x: np.array(rgb_to_hsv(x[0],x[1],x[2])),2,rgbs) - rgbs.ravel() - hsvs.ravel() - */ - auto rgbs = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f, - 0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f, - 0.54742825f, 0.684074104f, 0.52110225f, 0.761800349f, 0.486593395f, - 0.753103435f, 0.237176552f, 0.263826847f, 0.913557053f, 0.90049392f, - 0.290193319f, 0.46850124f, 0.965541422f, 0.148351923f, 0.674094439f, - 0.524110138f, 0.216262609f, 0.0361763388f, 0.2204483f, 0.279114306f, - 0.3721793f, 0.632020354f, 0.25007084f, 0.823592246f, 0.637001634f, - 0.30433768f, 0.0448598303f, 0.385092884f, 0.366362303f, 0.586083114f, - 0.218390301f, 0.931746006f, 0.978048146f, 0.762684941f, 0.00208298792f, - 0.91390729f, 0.505838513f, 0.875348926f, 0.428009957f, 0.367065936f, - 0.911922634f, 0.270003974f, 0.164243385f, 0.0581932105f, 0.313204288f, - 0.644775152f, 0.437950462f, 0.775881767f, 0.575452209f, 0.946475744f - }); - auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f, - 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f, - 0.199753001f, 0.684074104f, 0.312434604f, 0.361258626f, 0.761800349f, - 0.991390795f, 0.685067773f, 0.753103435f, 0.163174023f, 0.682347894f, - 0.913557053f, 0.268038541f, 0.84635365f, 0.965541422f, 0.112067183f, - 0.679180562f, 0.674094439f, 0.540247589f, 0.870388806f, 0.279114306f, - 0.280050347f, 0.604331017f, 0.632020354f, 0.106776128f, 0.630475283f, - 0.823592246f, 0.490824632f, 0.883509099f, 0.385092884f, 0.75257351f, - 0.765611768f, 0.931746006f, 0.129888852f, 0.997870266f, 0.978048146f, - 0.849081645f, 0.446510047f, 0.91390729f, 0.685308874f, 0.597481251f, - 0.911922634f, 0.0834472676f, 0.784472764f, 0.270003974f, 0.396037966f, - 0.514242649f, 0.644775152f, 0.756701186f, 0.392005324f, 0.946475744f - }); - - - auto actual = NDArrayFactory::create('c', { 5,4,3 }); - - Context ctx(1); - ctx.setInputArray(0, rgbs); - ctx.setOutputArray(0, actual); - - sd::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); + /* + test case generated by python colorsys and scaled to suit our needs + from colorsys import * + from random import * + import numpy as np + rgbs = np.random.uniform(0,1, 5*4*3 ).astype('float32').reshape([5,4,3]) + hsvs=np.apply_along_axis(lambda x: + np.array(rgb_to_hsv(x[0],x[1],x[2])),2,rgbs) rgbs.ravel() hsvs.ravel() + */ + auto rgbs = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, + 0.517642438f, 0.890151322f, 0.461456001f, 0.0869259685f, + 0.928968489f, 0.588904262f, 0.54742825f, 0.684074104f, + 0.52110225f, 0.761800349f, 0.486593395f, 0.753103435f, + 0.237176552f, 0.263826847f, 0.913557053f, 0.90049392f, + 0.290193319f, 0.46850124f, 0.965541422f, 0.148351923f, + 0.674094439f, 0.524110138f, 0.216262609f, 0.0361763388f, + 0.2204483f, 0.279114306f, 0.3721793f, 0.632020354f, + 0.25007084f, 0.823592246f, 0.637001634f, 0.30433768f, + 0.0448598303f, 0.385092884f, 0.366362303f, 0.586083114f, + 0.218390301f, 0.931746006f, 0.978048146f, 0.762684941f, + 0.00208298792f, 0.91390729f, 0.505838513f, 0.875348926f, + 0.428009957f, 0.367065936f, 0.911922634f, 0.270003974f, + 0.164243385f, 0.0581932105f, 0.313204288f, 0.644775152f, + 0.437950462f, 0.775881767f, 0.575452209f, 0.946475744f}); + auto expected = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f, + 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f, + 0.199753001f, 0.684074104f, 0.312434604f, 0.361258626f, 0.761800349f, + 0.991390795f, 0.685067773f, 0.753103435f, 0.163174023f, 0.682347894f, + 0.913557053f, 0.268038541f, 0.84635365f, 0.965541422f, 0.112067183f, + 0.679180562f, 0.674094439f, 0.540247589f, 0.870388806f, 0.279114306f, + 0.280050347f, 0.604331017f, 0.632020354f, 0.106776128f, 0.630475283f, + 0.823592246f, 0.490824632f, 0.883509099f, 0.385092884f, 0.75257351f, + 0.765611768f, 0.931746006f, 0.129888852f, 0.997870266f, 0.978048146f, + 0.849081645f, 0.446510047f, 0.91390729f, 0.685308874f, 0.597481251f, + 0.911922634f, 0.0834472676f, 0.784472764f, 0.270003974f, 0.396037966f, + 0.514242649f, 0.644775152f, 0.756701186f, 0.392005324f, 0.946475744f}); + + auto actual = NDArrayFactory::create('c', {5, 4, 3}); + + Context ctx(1); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); + + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); #if 0 //visual check rgbs.printBuffer("rgbs "); actual.printBuffer("HSV "); expected.printBuffer("exp"); #endif - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_2) { - /* - swapped_rgbs=rgbs.swapaxes(1,2).ravel() - swapped_hsvs=hsvs.swapaxes(1,2).ravel() - */ - auto rgbs = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, - 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, - 0.928968489f, 0.684074104f, 0.52110225f, 0.753103435f, 0.913557053f, - 0.46850124f, 0.761800349f, 0.237176552f, 0.90049392f, 0.965541422f, - 0.486593395f, 0.263826847f, 0.290193319f, 0.148351923f, 0.674094439f, - 0.0361763388f, 0.3721793f, 0.823592246f, 0.524110138f, 0.2204483f, - 0.632020354f, 0.637001634f, 0.216262609f, 0.279114306f, 0.25007084f, - 0.30433768f, 0.0448598303f, 0.586083114f, 0.978048146f, 0.91390729f, - 0.385092884f, 0.218390301f, 0.762684941f, 0.505838513f, 0.366362303f, - 0.931746006f, 0.00208298792f, 0.875348926f, 0.428009957f, 0.270003974f, - 0.313204288f, 0.775881767f, 0.367065936f, 0.164243385f, 0.644775152f, - 0.575452209f, 0.911922634f, 0.0581932105f, 0.437950462f, 0.946475744f - }); - auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, - 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, - 0.928968489f, 0.684074104f, 0.312434604f, 0.991390795f, 0.163174023f, - 0.268038541f, 0.361258626f, 0.685067773f, 0.682347894f, 0.84635365f, - 0.761800349f, 0.753103435f, 0.913557053f, 0.965541422f, 0.112067183f, - 0.540247589f, 0.280050347f, 0.106776128f, 0.679180562f, 0.870388806f, - 0.604331017f, 0.630475283f, 0.674094439f, 0.279114306f, 0.632020354f, - 0.823592246f, 0.490824632f, 0.75257351f, 0.129888852f, 0.849081645f, - 0.883509099f, 0.765611768f, 0.997870266f, 0.446510047f, 0.385092884f, - 0.931746006f, 0.978048146f, 0.91390729f, 0.685308874f, 0.0834472676f, - 0.396037966f, 0.756701186f, 0.597481251f, 0.784472764f, 0.514242649f, - 0.392005324f, 0.911922634f, 0.270003974f, 0.644775152f, 0.946475744f - }); - - - auto actual = NDArrayFactory::create('c', { 5,3,4 }); - - Context ctx(1); - ctx.setInputArray(0, rgbs); - ctx.setOutputArray(0, actual); - ctx.setIArguments({ 1 }); - sd::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + /* + swapped_rgbs=rgbs.swapaxes(1,2).ravel() + swapped_hsvs=hsvs.swapaxes(1,2).ravel() + */ + auto rgbs = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, + 0.725874603f, 0.517642438f, 0.0869259685f, 0.54742825f, + 0.413571358f, 0.890151322f, 0.928968489f, 0.684074104f, + 0.52110225f, 0.753103435f, 0.913557053f, 0.46850124f, + 0.761800349f, 0.237176552f, 0.90049392f, 0.965541422f, + 0.486593395f, 0.263826847f, 0.290193319f, 0.148351923f, + 0.674094439f, 0.0361763388f, 0.3721793f, 0.823592246f, + 0.524110138f, 0.2204483f, 0.632020354f, 0.637001634f, + 0.216262609f, 0.279114306f, 0.25007084f, 0.30433768f, + 0.0448598303f, 0.586083114f, 0.978048146f, 0.91390729f, + 0.385092884f, 0.218390301f, 0.762684941f, 0.505838513f, + 0.366362303f, 0.931746006f, 0.00208298792f, 0.875348926f, + 0.428009957f, 0.270003974f, 0.313204288f, 0.775881767f, + 0.367065936f, 0.164243385f, 0.644775152f, 0.575452209f, + 0.911922634f, 0.0581932105f, 0.437950462f, 0.946475744f}); + auto expected = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f, 0.312434604f, 0.991390795f, 0.163174023f, + 0.268038541f, 0.361258626f, 0.685067773f, 0.682347894f, 0.84635365f, + 0.761800349f, 0.753103435f, 0.913557053f, 0.965541422f, 0.112067183f, + 0.540247589f, 0.280050347f, 0.106776128f, 0.679180562f, 0.870388806f, + 0.604331017f, 0.630475283f, 0.674094439f, 0.279114306f, 0.632020354f, + 0.823592246f, 0.490824632f, 0.75257351f, 0.129888852f, 0.849081645f, + 0.883509099f, 0.765611768f, 0.997870266f, 0.446510047f, 0.385092884f, + 0.931746006f, 0.978048146f, 0.91390729f, 0.685308874f, 0.0834472676f, + 0.396037966f, 0.756701186f, 0.597481251f, 0.784472764f, 0.514242649f, + 0.392005324f, 0.911922634f, 0.270003974f, 0.644775152f, 0.946475744f}); + + auto actual = NDArrayFactory::create('c', {5, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); + ctx.setIArguments({1}); + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_3) { - - auto rgbs = NDArrayFactory::create('c', { 4, 3 }, { - 0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f, - 0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f, - 0.54742825f, 0.684074104f - }); - auto expected = NDArrayFactory::create('c', { 4, 3 }, { - 0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f, - 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f, - 0.199753001f, 0.684074104f - }); - - auto actual = NDArrayFactory::create('c', { 4, 3 }); - - Context ctx(1); - ctx.setInputArray(0, rgbs); - ctx.setOutputArray(0, actual); - - sd::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto rgbs = NDArrayFactory::create( + 'c', {4, 3}, + {0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f, + 0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f, + 0.54742825f, 0.684074104f}); + auto expected = NDArrayFactory::create( + 'c', {4, 3}, + {0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f, + 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f, + 0.199753001f, 0.684074104f}); + + auto actual = NDArrayFactory::create('c', {4, 3}); + + Context ctx(1); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); + + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } - TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_4) { - auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { - 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, - 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, - 0.928968489f, 0.684074104f - }); - auto expected = NDArrayFactory::create('c', { 3, 4 }, { - 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, - 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, - 0.928968489f, 0.684074104f - }); - - auto actual = NDArrayFactory::create('c', { 3, 4 }); - - Context ctx(1); - ctx.setInputArray(0, rgbs); - ctx.setOutputArray(0, actual); - ctx.setIArguments({ 0 }); - sd::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto rgbs = NDArrayFactory::create( + 'c', {3, 4}, + {0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, + 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, + 0.928968489f, 0.684074104f}); + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f}); + + auto actual = NDArrayFactory::create('c', {3, 4}); + + Context ctx(1); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); + ctx.setIArguments({0}); + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_5) { - auto rgbs = NDArrayFactory::create('c', { 3 }, { - 0.545678377f, 0.725874603f, 0.413571358f - }); - auto expected = NDArrayFactory::create('c', { 3 }, { - 0.262831867f, 0.430244058f, 0.725874603f - }); - - auto actual = NDArrayFactory::create('c', { 3 }); + auto rgbs = NDArrayFactory::create( + 'c', {3}, {0.545678377f, 0.725874603f, 0.413571358f}); + auto expected = NDArrayFactory::create( + 'c', {3}, {0.262831867f, 0.430244058f, 0.725874603f}); - Context ctx(1); - ctx.setInputArray(0, rgbs); - ctx.setOutputArray(0, actual); + auto actual = NDArrayFactory::create('c', {3}); - sd::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); + Context ctx(1); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } - TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) { - auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { - 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, - 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, - 0.928968489f, 0.684074104f - }); - auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { - 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, - 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, - 0.928968489f, 0.684074104f - }); - - //get subarray - //get subarray - NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); - NDArray expected = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) }); - subArrRgbs.reshapei({ 3 }); - expected.reshapei({ 3 }); + auto rgbs = NDArrayFactory::create( + 'c', {3, 4}, + {0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, + 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, + 0.928968489f, 0.684074104f}); + auto hsvs = NDArrayFactory::create( + 'c', {3, 4}, + {0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f}); + + // get subarray + // get subarray + NDArray subArrRgbs = rgbs.subarray({NDIndex::all(), NDIndex::point(0)}); + NDArray expected = hsvs.subarray({NDIndex::all(), NDIndex::point(0)}); + subArrRgbs.reshapei({3}); + expected.reshapei({3}); #if 0 //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] subArrRgbs.printShapeInfo("subArrRgbs"); #endif - auto actual = NDArrayFactory::create('c', { 3 }); - - Context ctx(1); - ctx.setInputArray(0, subArrRgbs); - ctx.setOutputArray(0, actual); - sd::ops::rgb_to_hsv op; - auto status = op.execute(&ctx); + auto actual = NDArrayFactory::create('c', {3}); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + Context ctx(1); + ctx.setInputArray(0, subArrRgbs); + ctx.setOutputArray(0, actual); + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_1) { - - auto hsvs = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, - 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, - 0.332347751f, 0.111181192f, 0.239250854f, 0.499201417f, 0.862712979f, - 0.0853395388f, 0.0810681432f, 0.226065159f, 0.851340771f, 0.602043271f, - 0.690895379f, 0.971996486f, 0.273846686f, 0.464318275f, 0.194078103f, - 0.219649255f, 0.616706491f, 0.847525477f, 0.653597355f, 0.700065672f, - 0.0299375951f, 0.184475258f, 0.274936169f, 0.196718201f, 0.179381892f, - 0.934476376f, 0.895766437f, 0.52967906f, 0.675635338f, 0.966644645f, - 0.770889699f, 0.556649387f, 0.13426739f, 0.899450243f, 0.817096591f, - 0.150202557f, 0.763557851f, 0.709604502f, 0.741747797f, 0.657703638f, - 0.167678103f, 0.828556478f, 0.615502477f, 0.478080243f, 0.447288662f, - 0.864299297f, 0.129833668f, 0.66402483f, 0.795475543f, 0.561332941f - }); - auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, - 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, - 0.111181192f, 0.074230373f, 0.675155059f, 0.862712979f, 0.432045438f, - 0.226065159f, 0.21712242f, 0.207738476f, 0.690895379f, 0.274946465f, - 0.645954334f, 0.464318275f, 0.337166255f, 0.358530475f, 0.594427716f, - 0.616706491f, 0.481247369f, 0.700065672f, 0.242504601f, 0.661103036f, - 0.274936169f, 0.233327664f, 0.224217249f, 0.904251479f, 0.934476376f, - 0.766848235f, 0.675635338f, 0.317765447f, 0.54157777f, 0.556649387f, - 0.127534108f, 0.213413864f, 0.817096591f, 0.674227886f, 0.0821588641f, - 0.709604502f, 0.656080596f, 0.167780413f, 0.107076412f, 0.0573956046f, - 0.167678103f, 0.46964643f, 0.183820669f, 0.478080243f, 0.01761852f, - 0.129833668f, 0.0943436049f, 0.114806315f, 0.121884218f, 0.561332941f - }); - - - auto actual = NDArrayFactory::create('c', { 5,4,3 }); - - Context ctx(1); - ctx.setInputArray(0, hsvs); - ctx.setOutputArray(0, actual); - - sd::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto hsvs = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, + 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, + 0.332347751f, 0.111181192f, 0.239250854f, 0.499201417f, 0.862712979f, + 0.0853395388f, 0.0810681432f, 0.226065159f, 0.851340771f, 0.602043271f, + 0.690895379f, 0.971996486f, 0.273846686f, 0.464318275f, 0.194078103f, + 0.219649255f, 0.616706491f, 0.847525477f, 0.653597355f, 0.700065672f, + 0.0299375951f, 0.184475258f, 0.274936169f, 0.196718201f, 0.179381892f, + 0.934476376f, 0.895766437f, 0.52967906f, 0.675635338f, 0.966644645f, + 0.770889699f, 0.556649387f, 0.13426739f, 0.899450243f, 0.817096591f, + 0.150202557f, 0.763557851f, 0.709604502f, 0.741747797f, 0.657703638f, + 0.167678103f, 0.828556478f, 0.615502477f, 0.478080243f, 0.447288662f, + 0.864299297f, 0.129833668f, 0.66402483f, 0.795475543f, 0.561332941f}); + auto expected = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, + 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, + 0.111181192f, 0.074230373f, 0.675155059f, 0.862712979f, 0.432045438f, + 0.226065159f, 0.21712242f, 0.207738476f, 0.690895379f, 0.274946465f, + 0.645954334f, 0.464318275f, 0.337166255f, 0.358530475f, 0.594427716f, + 0.616706491f, 0.481247369f, 0.700065672f, 0.242504601f, 0.661103036f, + 0.274936169f, 0.233327664f, 0.224217249f, 0.904251479f, 0.934476376f, + 0.766848235f, 0.675635338f, 0.317765447f, 0.54157777f, 0.556649387f, + 0.127534108f, 0.213413864f, 0.817096591f, 0.674227886f, 0.0821588641f, + 0.709604502f, 0.656080596f, 0.167780413f, 0.107076412f, 0.0573956046f, + 0.167678103f, 0.46964643f, 0.183820669f, 0.478080243f, 0.01761852f, + 0.129833668f, 0.0943436049f, 0.114806315f, 0.121884218f, 0.561332941f}); + + auto actual = NDArrayFactory::create('c', {5, 4, 3}); + + Context ctx(1); + ctx.setInputArray(0, hsvs); + ctx.setOutputArray(0, actual); + + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_2) { - auto hsvs = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, - 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, - 0.773604929f, 0.111181192f, 0.239250854f, 0.0853395388f, 0.851340771f, - 0.971996486f, 0.499201417f, 0.0810681432f, 0.602043271f, 0.273846686f, - 0.862712979f, 0.226065159f, 0.690895379f, 0.464318275f, 0.194078103f, - 0.847525477f, 0.0299375951f, 0.196718201f, 0.219649255f, 0.653597355f, - 0.184475258f, 0.179381892f, 0.616706491f, 0.700065672f, 0.274936169f, - 0.934476376f, 0.895766437f, 0.966644645f, 0.13426739f, 0.150202557f, - 0.52967906f, 0.770889699f, 0.899450243f, 0.763557851f, 0.675635338f, - 0.556649387f, 0.817096591f, 0.709604502f, 0.741747797f, 0.828556478f, - 0.447288662f, 0.66402483f, 0.657703638f, 0.615502477f, 0.864299297f, - 0.795475543f, 0.167678103f, 0.478080243f, 0.129833668f, 0.561332941f - }); - auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, - 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, - 0.773604929f, 0.074230373f, 0.675155059f, 0.226065159f, 0.690895379f, - 0.464318275f, 0.862712979f, 0.21712242f, 0.274946465f, 0.337166255f, - 0.432045438f, 0.207738476f, 0.645954334f, 0.358530475f, 0.594427716f, - 0.700065672f, 0.274936169f, 0.904251479f, 0.616706491f, 0.242504601f, - 0.233327664f, 0.934476376f, 0.481247369f, 0.661103036f, 0.224217249f, - 0.766848235f, 0.675635338f, 0.556649387f, 0.817096591f, 0.709604502f, - 0.317765447f, 0.127534108f, 0.674227886f, 0.656080596f, 0.54157777f, - 0.213413864f, 0.0821588641f, 0.167780413f, 0.107076412f, 0.46964643f, - 0.01761852f, 0.114806315f, 0.0573956046f, 0.183820669f, 0.129833668f, - 0.121884218f, 0.167678103f, 0.478080243f, 0.0943436049f, 0.561332941f - }); - auto actual = NDArrayFactory::create('c', { 5,3,4 }); - - Context ctx(1); - ctx.setInputArray(0, hsvs); - ctx.setOutputArray(0, actual); - ctx.setIArguments({ 1 }); - sd::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto hsvs = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, + 0.793608069f, 0.920532584f, 0.563831031f, 0.332347751f, + 0.65870738f, 0.887555957f, 0.773604929f, 0.111181192f, + 0.239250854f, 0.0853395388f, 0.851340771f, 0.971996486f, + 0.499201417f, 0.0810681432f, 0.602043271f, 0.273846686f, + 0.862712979f, 0.226065159f, 0.690895379f, 0.464318275f, + 0.194078103f, 0.847525477f, 0.0299375951f, 0.196718201f, + 0.219649255f, 0.653597355f, 0.184475258f, 0.179381892f, + 0.616706491f, 0.700065672f, 0.274936169f, 0.934476376f, + 0.895766437f, 0.966644645f, 0.13426739f, 0.150202557f, + 0.52967906f, 0.770889699f, 0.899450243f, 0.763557851f, + 0.675635338f, 0.556649387f, 0.817096591f, 0.709604502f, + 0.741747797f, 0.828556478f, 0.447288662f, 0.66402483f, + 0.657703638f, 0.615502477f, 0.864299297f, 0.795475543f, + 0.167678103f, 0.478080243f, 0.129833668f, 0.561332941f}); + auto expected = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, + 0.135951888f, 0.0705317783f, 0.337422464f, 0.111181192f, + 0.65870738f, 0.811602857f, 0.773604929f, 0.074230373f, + 0.675155059f, 0.226065159f, 0.690895379f, 0.464318275f, + 0.862712979f, 0.21712242f, 0.274946465f, 0.337166255f, + 0.432045438f, 0.207738476f, 0.645954334f, 0.358530475f, + 0.594427716f, 0.700065672f, 0.274936169f, 0.904251479f, + 0.616706491f, 0.242504601f, 0.233327664f, 0.934476376f, + 0.481247369f, 0.661103036f, 0.224217249f, 0.766848235f, + 0.675635338f, 0.556649387f, 0.817096591f, 0.709604502f, + 0.317765447f, 0.127534108f, 0.674227886f, 0.656080596f, + 0.54157777f, 0.213413864f, 0.0821588641f, 0.167780413f, + 0.107076412f, 0.46964643f, 0.01761852f, 0.114806315f, + 0.0573956046f, 0.183820669f, 0.129833668f, 0.121884218f, + 0.167678103f, 0.478080243f, 0.0943436049f, 0.561332941f}); + auto actual = NDArrayFactory::create('c', {5, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, hsvs); + ctx.setOutputArray(0, actual); + ctx.setIArguments({1}); + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_3) { - auto hsvs = NDArrayFactory::create('c', { 4, 3 }, { - 0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, - 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, - 0.332347751f, 0.111181192f - }); - auto expected = NDArrayFactory::create('c', { 4, 3 }, { - 0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, - 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, - 0.111181192f, 0.074230373f - }); - auto actual = NDArrayFactory::create('c', { 4,3 }); - - Context ctx(1); - ctx.setInputArray(0, hsvs); - ctx.setOutputArray(0, actual); - - sd::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto hsvs = NDArrayFactory::create( + 'c', {4, 3}, + {0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, + 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, + 0.332347751f, 0.111181192f}); + auto expected = NDArrayFactory::create( + 'c', {4, 3}, + {0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, + 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, + 0.111181192f, 0.074230373f}); + auto actual = NDArrayFactory::create('c', {4, 3}); + + Context ctx(1); + ctx.setInputArray(0, hsvs); + ctx.setOutputArray(0, actual); + + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } - TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_4) { - auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { - 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, - 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, - 0.773604929f, 0.111181192f - }); - auto expected = NDArrayFactory::create('c', { 3, 4 }, { - 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, - 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, - 0.773604929f, 0.074230373f - }); - auto actual = NDArrayFactory::create('c', { 3, 4 }); - - Context ctx(1); - ctx.setInputArray(0, hsvs); - ctx.setOutputArray(0, actual); - ctx.setIArguments({ 0 }); - sd::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto hsvs = NDArrayFactory::create( + 'c', {3, 4}, + {0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, + 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, + 0.773604929f, 0.111181192f}); + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, + 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, + 0.773604929f, 0.074230373f}); + auto actual = NDArrayFactory::create('c', {3, 4}); + + Context ctx(1); + ctx.setInputArray(0, hsvs); + ctx.setOutputArray(0, actual); + ctx.setIArguments({0}); + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_5) { + auto hsvs = NDArrayFactory::create( + 'c', {3}, {0.705504596f, 0.793608069f, 0.65870738f}); + auto expected = NDArrayFactory::create( + 'c', {3}, {0.257768334f, 0.135951888f, 0.65870738f}); - auto hsvs = NDArrayFactory::create('c', { 3 }, { - 0.705504596f, 0.793608069f, 0.65870738f - }); - auto expected = NDArrayFactory::create('c', { 3 }, { - 0.257768334f, 0.135951888f, 0.65870738f - }); - - auto actual = NDArrayFactory::create('c', { 3 }); - - Context ctx(1); - ctx.setInputArray(0, hsvs); - ctx.setOutputArray(0, actual); + auto actual = NDArrayFactory::create('c', {3}); - sd::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); + Context ctx(1); + ctx.setInputArray(0, hsvs); + ctx.setOutputArray(0, actual); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } - TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) { - - auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { - 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, - 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, - 0.773604929f, 0.111181192f - }); - auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { - 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, - 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, - 0.773604929f, 0.074230373f - }); - - auto actual = NDArrayFactory::create('c', { 3 }); - //get subarray - NDArray subArrHsvs = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) }); - subArrHsvs.reshapei({ 3 }); - NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); - expected.reshapei({ 3 }); + auto hsvs = NDArrayFactory::create( + 'c', {3, 4}, + {0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, + 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, + 0.773604929f, 0.111181192f}); + auto rgbs = NDArrayFactory::create( + 'c', {3, 4}, + {0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, + 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, + 0.773604929f, 0.074230373f}); + + auto actual = NDArrayFactory::create('c', {3}); + // get subarray + NDArray subArrHsvs = hsvs.subarray({NDIndex::all(), NDIndex::point(0)}); + subArrHsvs.reshapei({3}); + NDArray expected = rgbs.subarray({NDIndex::all(), NDIndex::point(0)}); + expected.reshapei({3}); #if 0 //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] subArrHsvs.printShapeInfo("subArrHsvs"); -#endif - - Context ctx(1); - ctx.setInputArray(0, subArrHsvs); - ctx.setOutputArray(0, actual); - sd::ops::hsv_to_rgb op; - auto status = op.execute(&ctx); +#endif - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + Context ctx(1); + ctx.setInputArray(0, subArrHsvs); + ctx.setOutputArray(0, actual); + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } - - TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_1) { - /** - generated using numpy - _rgb_to_yiq_kernel = np.array([[0.299f, 0.59590059f, 0.2115f], - [0.587f, -0.27455667f, -0.52273617f], - [0.114f, -0.32134392f, 0.31119955f]]) - nnrgbs = np.array([random() for x in range(0,3*4*5)],np.float32).reshape([5,4,3]) - out =np.tensordot(nnrgbs,_rgb_to_yiq_kernel,axes=[[len(nnrgbs.shape)-1],[0]]) - - #alternatively you could use just with apply - out_2=np.apply_along_axis(lambda x: _rgb_to_yiq_kernel.T @ x,len(nnrgbs.shape)-1,nnrgbs) - - */ - auto rgb = NDArrayFactory::create('c', { 5, 4 ,3 }, - { - 0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f, - 0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f , - 0.98633456f, 0.00158441f, 0.97605824f, 0.02462568f, 0.14837205f, - 0.00112842f, 0.99260217f, 0.9585542f , 0.41196227f, 0.3095014f , - 0.6620493f , 0.30888894f, 0.3122602f , 0.7993488f , 0.86656475f, - 0.5997049f , 0.9776477f , 0.72481847f, 0.7835693f , 0.14649455f, - 0.3573504f , 0.33301765f, 0.7853056f , 0.25830218f, 0.59289205f, - 0.41357264f, 0.5934154f , 0.72647524f, 0.6623308f , 0.96197623f, - 0.0720306f , 0.23853847f, 0.1427159f , 0.19581454f, 0.06766324f, - 0.10614152f, 0.26093867f, 0.9584985f , 0.01258832f, 0.8160156f , - 0.56506383f, 0.08418505f, 0.86440504f, 0.6807802f , 0.20662387f, - 0.4153733f , 0.76146203f, 0.50057423f, 0.08274968f, 0.9521758f - }); - - auto expected = NDArrayFactory::create('c', { 5, 4 ,3 }, - { - 0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f, - 0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f, - -0.07432612f, -0.44518381f, 0.32321111f, 0.52719408f, 0.2397369f , - 0.69227005f, -0.57987869f, -0.22032876f, 0.38032767f, -0.05223263f, - 0.13137188f, 0.3667803f , -0.15853189f, 0.15085728f, 0.72258149f, - 0.03757231f, 0.17403452f, 0.69337627f, 0.16971045f, -0.21071186f, - 0.39185397f, -0.13084008f, 0.145886f , 0.47240727f, -0.1417591f , - -0.12659159f, 0.67937788f, -0.05867803f, -0.04813048f, 0.35710624f, - 0.47681283f, 0.24003804f, 0.1653288f , 0.00953913f, -0.05111816f, - 0.29417614f, -0.31640032f, 0.18433114f, 0.54718234f, -0.39812097f, - -0.24805083f, 0.61018603f, -0.40592682f, -0.22219216f, 0.39241133f, - -0.23560742f, 0.06353694f, 0.3067938f , -0.0304029f , 0.35893188f - }); - - auto actual = NDArrayFactory::create('c', { 5, 4, 3 }); - - Context ctx(1); - ctx.setInputArray(0, rgb); - ctx.setOutputArray(0, actual); - - sd::ops::rgb_to_yiq op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + /** + generated using numpy + _rgb_to_yiq_kernel = np.array([[0.299f, 0.59590059f, 0.2115f], + [0.587f, -0.27455667f, -0.52273617f], + [0.114f, -0.32134392f, 0.31119955f]]) + nnrgbs = np.array([random() for x in + range(0,3*4*5)],np.float32).reshape([5,4,3]) out + =np.tensordot(nnrgbs,_rgb_to_yiq_kernel,axes=[[len(nnrgbs.shape)-1],[0]]) + + #alternatively you could use just with apply + out_2=np.apply_along_axis(lambda x: _rgb_to_yiq_kernel.T @ + x,len(nnrgbs.shape)-1,nnrgbs) + + */ + auto rgb = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.48055f, 0.80757356f, 0.2564435f, 0.94277316f, 0.17006584f, + 0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f, + 0.98633456f, 0.00158441f, 0.97605824f, 0.02462568f, 0.14837205f, + 0.00112842f, 0.99260217f, 0.9585542f, 0.41196227f, 0.3095014f, + 0.6620493f, 0.30888894f, 0.3122602f, 0.7993488f, 0.86656475f, + 0.5997049f, 0.9776477f, 0.72481847f, 0.7835693f, 0.14649455f, + 0.3573504f, 0.33301765f, 0.7853056f, 0.25830218f, 0.59289205f, + 0.41357264f, 0.5934154f, 0.72647524f, 0.6623308f, 0.96197623f, + 0.0720306f, 0.23853847f, 0.1427159f, 0.19581454f, 0.06766324f, + 0.10614152f, 0.26093867f, 0.9584985f, 0.01258832f, 0.8160156f, + 0.56506383f, 0.08418505f, 0.86440504f, 0.6807802f, 0.20662387f, + 0.4153733f, 0.76146203f, 0.50057423f, 0.08274968f, 0.9521758f}); + + auto expected = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f, + 0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f, + -0.07432612f, -0.44518381f, 0.32321111f, 0.52719408f, 0.2397369f, + 0.69227005f, -0.57987869f, -0.22032876f, 0.38032767f, -0.05223263f, + 0.13137188f, 0.3667803f, -0.15853189f, 0.15085728f, 0.72258149f, + 0.03757231f, 0.17403452f, 0.69337627f, 0.16971045f, -0.21071186f, + 0.39185397f, -0.13084008f, 0.145886f, 0.47240727f, -0.1417591f, + -0.12659159f, 0.67937788f, -0.05867803f, -0.04813048f, 0.35710624f, + 0.47681283f, 0.24003804f, 0.1653288f, 0.00953913f, -0.05111816f, + 0.29417614f, -0.31640032f, 0.18433114f, 0.54718234f, -0.39812097f, + -0.24805083f, 0.61018603f, -0.40592682f, -0.22219216f, 0.39241133f, + -0.23560742f, 0.06353694f, 0.3067938f, -0.0304029f, 0.35893188f}); + + auto actual = NDArrayFactory::create('c', {5, 4, 3}); + + Context ctx(1); + ctx.setInputArray(0, rgb); + ctx.setOutputArray(0, actual); + + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) { - - auto rgb = NDArrayFactory::create('c', { 5, 3, 4 }, - { - 0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f, - 0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f, - 0.48942474f, 0.00158441f, 0.97605824f, 0.00112842f, 0.41196227f, - 0.30888894f, 0.02462568f, 0.99260217f, 0.3095014f , 0.3122602f , - 0.14837205f, 0.9585542f , 0.6620493f , 0.7993488f , 0.86656475f, - 0.72481847f, 0.3573504f , 0.25830218f, 0.5997049f , 0.7835693f , - 0.33301765f, 0.59289205f, 0.9776477f , 0.14649455f, 0.7853056f , - 0.41357264f, 0.5934154f , 0.96197623f, 0.1427159f , 0.10614152f, - 0.72647524f, 0.0720306f , 0.19581454f, 0.26093867f, 0.6623308f , - 0.23853847f, 0.06766324f, 0.9584985f , 0.01258832f, 0.08418505f, - 0.20662387f, 0.50057423f, 0.8160156f , 0.86440504f, 0.4153733f , - 0.08274968f, 0.56506383f, 0.6807802f , 0.76146203f, 0.9521758f - }); - - auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, - { - 0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, - 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, - -0.04447775f, -0.44518381f, 0.32321111f, 0.69227005f, 0.38032767f, - 0.3667803f , 0.52719408f, -0.57987869f, -0.05223263f, -0.15853189f, - 0.2397369f , -0.22032876f, 0.13137188f, 0.15085728f, 0.72258149f, - 0.69337627f, 0.39185397f, 0.47240727f, 0.03757231f, 0.16971045f, - -0.13084008f, -0.1417591f , 0.17403452f, -0.21071186f, 0.145886f , - -0.12659159f, 0.67937788f, 0.35710624f, 0.1653288f , 0.29417614f, - -0.05867803f, 0.47681283f, 0.00953913f, -0.31640032f, -0.04813048f, - 0.24003804f, -0.05111816f, 0.18433114f, 0.54718234f, 0.61018603f, - 0.39241133f, 0.3067938f , -0.39812097f, -0.40592682f, -0.23560742f, - -0.0304029f , -0.24805083f, -0.22219216f, 0.06353694f, 0.35893188f - }); - - auto actual = NDArrayFactory::create('c', { 5, 3, 4 }); - - Context ctx(1); - ctx.setInputArray(0, rgb); - ctx.setOutputArray(0, actual); - ctx.setIArguments({ 1 }); - sd::ops::rgb_to_yiq op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto rgb = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.48055f, 0.94277316f, 0.41727918f, 0.3305715f, 0.80757356f, + 0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f, 0.33366168f, + 0.48942474f, 0.00158441f, 0.97605824f, 0.00112842f, 0.41196227f, + 0.30888894f, 0.02462568f, 0.99260217f, 0.3095014f, 0.3122602f, + 0.14837205f, 0.9585542f, 0.6620493f, 0.7993488f, 0.86656475f, + 0.72481847f, 0.3573504f, 0.25830218f, 0.5997049f, 0.7835693f, + 0.33301765f, 0.59289205f, 0.9776477f, 0.14649455f, 0.7853056f, + 0.41357264f, 0.5934154f, 0.96197623f, 0.1427159f, 0.10614152f, + 0.72647524f, 0.0720306f, 0.19581454f, 0.26093867f, 0.6623308f, + 0.23853847f, 0.06766324f, 0.9584985f, 0.01258832f, 0.08418505f, + 0.20662387f, 0.50057423f, 0.8160156f, 0.86440504f, 0.4153733f, + 0.08274968f, 0.56506383f, 0.6807802f, 0.76146203f, 0.9521758f}); + + auto expected = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, + 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, + -0.04447775f, -0.44518381f, 0.32321111f, 0.69227005f, 0.38032767f, + 0.3667803f, 0.52719408f, -0.57987869f, -0.05223263f, -0.15853189f, + 0.2397369f, -0.22032876f, 0.13137188f, 0.15085728f, 0.72258149f, + 0.69337627f, 0.39185397f, 0.47240727f, 0.03757231f, 0.16971045f, + -0.13084008f, -0.1417591f, 0.17403452f, -0.21071186f, 0.145886f, + -0.12659159f, 0.67937788f, 0.35710624f, 0.1653288f, 0.29417614f, + -0.05867803f, 0.47681283f, 0.00953913f, -0.31640032f, -0.04813048f, + 0.24003804f, -0.05111816f, 0.18433114f, 0.54718234f, 0.61018603f, + 0.39241133f, 0.3067938f, -0.39812097f, -0.40592682f, -0.23560742f, + -0.0304029f, -0.24805083f, -0.22219216f, 0.06353694f, 0.35893188f}); + + auto actual = NDArrayFactory::create('c', {5, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, rgb); + ctx.setOutputArray(0, actual); + ctx.setIArguments({1}); + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) { + auto rgb = NDArrayFactory::create( + 'c', {4, 3}, + {0.48055f, 0.80757356f, 0.2564435f, 0.94277316f, 0.17006584f, 0.33366168f, + 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f, 0.98633456f, + 0.00158441f}); - auto rgb = NDArrayFactory::create('c', { 4, 3 }, - { - 0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f, - 0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f , - 0.98633456f, 0.00158441f - }); - - auto expected = NDArrayFactory::create('c', { 4, 3 }, - { - 0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f, - 0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f, - -0.07432612f, -0.44518381f - }); + auto expected = NDArrayFactory::create( + 'c', {4, 3}, + {0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f, + 0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f, + -0.07432612f, -0.44518381f}); - auto actual = NDArrayFactory::create('c', { 4, 3 }); + auto actual = NDArrayFactory::create('c', {4, 3}); - Context ctx(1); - ctx.setInputArray(0, rgb); - ctx.setOutputArray(0, actual); + Context ctx(1); + ctx.setInputArray(0, rgb); + ctx.setOutputArray(0, actual); - sd::ops::rgb_to_yiq op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) { - - auto rgb = NDArrayFactory::create('c', { 3, 4 }, - { - 0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f, - 0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f, - 0.48942474f, 0.00158441f - }); - - auto expected = NDArrayFactory::create('c', { 3, 4 }, - { - 0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, - 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, - -0.04447775f, -0.44518381f - }); - - auto actual = NDArrayFactory::create('c', { 3, 4 }); - - Context ctx(1); - ctx.setInputArray(0, rgb); - ctx.setOutputArray(0, actual); - ctx.setIArguments({ 0 }); - sd::ops::rgb_to_yiq op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + auto rgb = NDArrayFactory::create( + 'c', {3, 4}, + {0.48055f, 0.94277316f, 0.41727918f, 0.3305715f, 0.80757356f, 0.17006584f, + 0.54528666f, 0.98633456f, 0.2564435f, 0.33366168f, 0.48942474f, + 0.00158441f}); + + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, + 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, + -0.04447775f, -0.44518381f}); + + auto actual = NDArrayFactory::create('c', {3, 4}); + + Context ctx(1); + ctx.setInputArray(0, rgb); + ctx.setOutputArray(0, actual); + ctx.setIArguments({0}); + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } - - TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_5) { - - auto rgbs = NDArrayFactory::create('c', { 3 }, - { 0.48055f , 0.80757356f, 0.2564435f }); - auto expected = NDArrayFactory::create('c', { 3 }, - { 0.64696468f, -0.01777124f, -0.24070648f, }); - - - auto actual = NDArrayFactory::create('c', { 3 }); - - Context ctx(1); - ctx.setInputArray(0, rgbs); - ctx.setOutputArray(0, actual); - - sd::ops::rgb_to_yiq op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + auto rgbs = NDArrayFactory::create( + 'c', {3}, {0.48055f, 0.80757356f, 0.2564435f}); + auto expected = NDArrayFactory::create('c', {3}, + { + 0.64696468f, + -0.01777124f, + -0.24070648f, + }); + + auto actual = NDArrayFactory::create('c', {3}); + + Context ctx(1); + ctx.setInputArray(0, rgbs); + ctx.setOutputArray(0, actual); + + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) { - - auto rgbs = NDArrayFactory::create('c', { 3, 4 }, - { - 0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f, - 0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f, - 0.48942474f, 0.00158441f - }); - - auto yiqs = NDArrayFactory::create('c', { 3, 4 }, - { - 0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, - 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, - -0.04447775f, -0.44518381f - }); - - //get subarray - NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); - NDArray expected = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) }); - subArrRgbs.reshapei({ 3 }); - expected.reshapei({ 3 }); + auto rgbs = NDArrayFactory::create( + 'c', {3, 4}, + {0.48055f, 0.94277316f, 0.41727918f, 0.3305715f, 0.80757356f, 0.17006584f, + 0.54528666f, 0.98633456f, 0.2564435f, 0.33366168f, 0.48942474f, + 0.00158441f}); + + auto yiqs = NDArrayFactory::create( + 'c', {3, 4}, + {0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, + 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, + -0.04447775f, -0.44518381f}); + + // get subarray + NDArray subArrRgbs = rgbs.subarray({NDIndex::all(), NDIndex::point(0)}); + NDArray expected = yiqs.subarray({NDIndex::all(), NDIndex::point(0)}); + subArrRgbs.reshapei({3}); + expected.reshapei({3}); #if 0 //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] subArrRgbs.printShapeInfo("subArrRgbs"); #endif - auto actual = NDArrayFactory::create('c', { 3 }); - - Context ctx(1); - ctx.setInputArray(0, subArrRgbs); - ctx.setOutputArray(0, actual); - sd::ops::rgb_to_yiq op; - auto status = op.execute(&ctx); + auto actual = NDArrayFactory::create('c', {3}); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + Context ctx(1); + ctx.setInputArray(0, subArrRgbs); + ctx.setOutputArray(0, actual); + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) { - - auto yiqs = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f, - 0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f, - -0.471601307f, 0.263960421f, 0.700227439f, 0.32434237f, -0.278446227f, - 0.130805135f, -0.438441873f, 0.187127829f, 0.0276055578f, -0.179727226f, - 0.305075705f, 0.716282248f, 0.278215706f, -0.44586885f, 0.76971364f, - 0.131288841f, -0.141177326f, 0.900081575f, -0.0788725987f, 0.14756602f, - 0.387832165f, 0.229834676f, 0.47921446f, 0.632930398f, 0.0443540029f, - -0.268817365f, 0.0977194682f, -0.141669706f, -0.140715122f, 0.946808815f, - -0.52525419f, -0.106209636f, 0.659476519f, 0.391066104f, 0.426448852f, - 0.496989518f, -0.283434421f, -0.177366048f, 0.715208411f, -0.496444523f, - 0.189553142f, 0.616444945f, 0.345852494f, 0.447739422f, 0.224696323f, - 0.451372236f, 0.298027098f, 0.446561724f, -0.187599331f, -0.448159873f - }); - auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { - 0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f, - 1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f, - 0.905021825f, 1.91936605f, 0.837427991f, 0.792213732f, -0.133271854f, - -0.17216571f, 0.128957025f, 0.934955336f, 0.0451873479f, -0.120952621f, - 0.746436225f, 0.705446224f, 0.929172217f, -0.351493549f, 0.807577594f, - 0.825371955f, 0.383812296f, 0.916293093f, 0.82603058f, 1.23885956f, - 0.905059196f, 0.015164554f, 0.950156781f, 0.508443732f, 0.794845279f, - 0.12571529f, -0.125074273f, 0.227326869f, 0.0147000261f, 0.378735409f, - 1.15842402f, 1.34712305f, 1.2980804f, 0.277102016f, 0.953435072f, - 0.115916842f, 0.688879376f, 0.508405162f, 0.35829352f, 0.727568094f, - 1.58768577f, 1.22504294f, 0.232589777f, 0.996727258f, 0.841224629f, - -0.0909671176f, 0.233051388f, -0.0110094378f, 0.787642119f, -0.109582274f - }); - auto actual = NDArrayFactory::create('c', { 5, 4, 3 }); - - Context ctx(1); - ctx.setInputArray(0, yiqs); - ctx.setOutputArray(0, actual); - - sd::ops::yiq_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto yiqs = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, + -0.212469354f, 0.455438733f, 0.418221354f, 0.349350512f, + 0.145902053f, 0.947576523f, -0.471601307f, 0.263960421f, + 0.700227439f, 0.32434237f, -0.278446227f, 0.130805135f, + -0.438441873f, 0.187127829f, 0.0276055578f, -0.179727226f, + 0.305075705f, 0.716282248f, 0.278215706f, -0.44586885f, + 0.76971364f, 0.131288841f, -0.141177326f, 0.900081575f, + -0.0788725987f, 0.14756602f, 0.387832165f, 0.229834676f, + 0.47921446f, 0.632930398f, 0.0443540029f, -0.268817365f, + 0.0977194682f, -0.141669706f, -0.140715122f, 0.946808815f, + -0.52525419f, -0.106209636f, 0.659476519f, 0.391066104f, + 0.426448852f, 0.496989518f, -0.283434421f, -0.177366048f, + 0.715208411f, -0.496444523f, 0.189553142f, 0.616444945f, + 0.345852494f, 0.447739422f, 0.224696323f, 0.451372236f, + 0.298027098f, 0.446561724f, -0.187599331f, -0.448159873f}); + auto expected = NDArrayFactory::create( + 'c', {5, 4, 3}, + {0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, + -0.170521997f, 1.07776645f, 0.842775284f, 0.228765106f, + 0.280231822f, 0.660605291f, 0.905021825f, 1.91936605f, + 0.837427991f, 0.792213732f, -0.133271854f, -0.17216571f, + 0.128957025f, 0.934955336f, 0.0451873479f, -0.120952621f, + 0.746436225f, 0.705446224f, 0.929172217f, -0.351493549f, + 0.807577594f, 0.825371955f, 0.383812296f, 0.916293093f, + 0.82603058f, 1.23885956f, 0.905059196f, 0.015164554f, + 0.950156781f, 0.508443732f, 0.794845279f, 0.12571529f, + -0.125074273f, 0.227326869f, 0.0147000261f, 0.378735409f, + 1.15842402f, 1.34712305f, 1.2980804f, 0.277102016f, + 0.953435072f, 0.115916842f, 0.688879376f, 0.508405162f, + 0.35829352f, 0.727568094f, 1.58768577f, 1.22504294f, + 0.232589777f, 0.996727258f, 0.841224629f, -0.0909671176f, + 0.233051388f, -0.0110094378f, 0.787642119f, -0.109582274f}); + auto actual = NDArrayFactory::create('c', {5, 4, 3}); + + Context ctx(1); + ctx.setInputArray(0, yiqs); + ctx.setOutputArray(0, actual); + + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) { - - auto yiqs = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, - -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, - 0.145902053f, 0.263960421f, 0.700227439f, 0.130805135f, 0.0276055578f, - 0.716282248f, 0.32434237f, -0.438441873f, -0.179727226f, 0.278215706f, - -0.278446227f, 0.187127829f, 0.305075705f, -0.44586885f, 0.76971364f, - 0.900081575f, 0.387832165f, 0.632930398f, 0.131288841f, -0.0788725987f, - 0.229834676f, 0.0443540029f, -0.141177326f, 0.14756602f, 0.47921446f, - -0.268817365f, 0.0977194682f, 0.946808815f, 0.659476519f, 0.496989518f, - -0.141669706f, -0.52525419f, 0.391066104f, -0.283434421f, -0.140715122f, - -0.106209636f, 0.426448852f, -0.177366048f, 0.715208411f, 0.616444945f, - 0.224696323f, 0.446561724f, -0.496444523f, 0.345852494f, 0.451372236f, - -0.187599331f, 0.189553142f, 0.447739422f, 0.298027098f, -0.448159873f - }); - auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, { - 0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, - -0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, - 0.280231822f, 1.91936605f, 0.837427991f, -0.17216571f, 0.0451873479f, - 0.705446224f, 0.792213732f, 0.128957025f, -0.120952621f, 0.929172217f, - -0.133271854f, 0.934955336f, 0.746436225f, -0.351493549f, 0.807577594f, - 0.916293093f, 0.905059196f, 0.508443732f, 0.825371955f, 0.82603058f, - 0.015164554f, 0.794845279f, 0.383812296f, 1.23885956f, 0.950156781f, - 0.12571529f, -0.125074273f, 0.378735409f, 1.2980804f, 0.115916842f, - 0.227326869f, 1.15842402f, 0.277102016f, 0.688879376f, 0.0147000261f, - 1.34712305f, 0.953435072f, 0.508405162f, 0.35829352f, 1.22504294f, - 0.841224629f, -0.0110094378f, 0.727568094f, 0.232589777f, -0.0909671176f, - 0.787642119f, 1.58768577f, 0.996727258f, 0.233051388f, -0.109582274f - }); - auto actual = NDArrayFactory::create('c', { 5, 3, 4 }); - - Context ctx(1); - ctx.setInputArray(0, yiqs); - ctx.setOutputArray(0, actual); - ctx.setIArguments({ 1 }); - sd::ops::yiq_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto yiqs = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, + -0.288912386f, -0.212469354f, 0.349350512f, -0.471601307f, + -0.132725924f, 0.455438733f, 0.145902053f, 0.263960421f, + 0.700227439f, 0.130805135f, 0.0276055578f, 0.716282248f, + 0.32434237f, -0.438441873f, -0.179727226f, 0.278215706f, + -0.278446227f, 0.187127829f, 0.305075705f, -0.44586885f, + 0.76971364f, 0.900081575f, 0.387832165f, 0.632930398f, + 0.131288841f, -0.0788725987f, 0.229834676f, 0.0443540029f, + -0.141177326f, 0.14756602f, 0.47921446f, -0.268817365f, + 0.0977194682f, 0.946808815f, 0.659476519f, 0.496989518f, + -0.141669706f, -0.52525419f, 0.391066104f, -0.283434421f, + -0.140715122f, -0.106209636f, 0.426448852f, -0.177366048f, + 0.715208411f, 0.616444945f, 0.224696323f, 0.446561724f, + -0.496444523f, 0.345852494f, 0.451372236f, -0.187599331f, + 0.189553142f, 0.447739422f, 0.298027098f, -0.448159873f}); + auto expected = NDArrayFactory::create( + 'c', {5, 3, 4}, + {0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, + 0.939747555f, -0.170521997f, 0.228765106f, 0.905021825f, + 0.868814286f, 1.07776645f, 0.280231822f, 1.91936605f, + 0.837427991f, -0.17216571f, 0.0451873479f, 0.705446224f, + 0.792213732f, 0.128957025f, -0.120952621f, 0.929172217f, + -0.133271854f, 0.934955336f, 0.746436225f, -0.351493549f, + 0.807577594f, 0.916293093f, 0.905059196f, 0.508443732f, + 0.825371955f, 0.82603058f, 0.015164554f, 0.794845279f, + 0.383812296f, 1.23885956f, 0.950156781f, 0.12571529f, + -0.125074273f, 0.378735409f, 1.2980804f, 0.115916842f, + 0.227326869f, 1.15842402f, 0.277102016f, 0.688879376f, + 0.0147000261f, 1.34712305f, 0.953435072f, 0.508405162f, + 0.35829352f, 1.22504294f, 0.841224629f, -0.0110094378f, + 0.727568094f, 0.232589777f, -0.0909671176f, 0.787642119f, + 1.58768577f, 0.996727258f, 0.233051388f, -0.109582274f}); + auto actual = NDArrayFactory::create('c', {5, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, yiqs); + ctx.setOutputArray(0, actual); + ctx.setIArguments({1}); + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) { - - auto yiqs = NDArrayFactory::create('c', { 4, 3 }, { - 0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f, - 0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f, - -0.471601307f, 0.263960421f - }); - auto expected = NDArrayFactory::create('c', { 4, 3 }, { - 0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f, - 1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f, - 0.905021825f, 1.91936605f - }); - auto actual = NDArrayFactory::create('c', { 4, 3 }); - - Context ctx(1); - ctx.setInputArray(0, yiqs); - ctx.setOutputArray(0, actual); - - sd::ops::yiq_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto yiqs = NDArrayFactory::create( + 'c', {4, 3}, + {0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f, + 0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f, + -0.471601307f, 0.263960421f}); + auto expected = NDArrayFactory::create( + 'c', {4, 3}, + {0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f, + 1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f, + 0.905021825f, 1.91936605f}); + auto actual = NDArrayFactory::create('c', {4, 3}); + + Context ctx(1); + ctx.setInputArray(0, yiqs); + ctx.setOutputArray(0, actual); + + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) { - - auto yiqs = NDArrayFactory::create('c', { 3, 4 }, { - 0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, - -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, - 0.145902053f, 0.263960421f - }); - auto expected = NDArrayFactory::create('c', { 3, 4 }, { - 0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, - -0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, - 0.280231822f, 1.91936605f - }); - auto actual = NDArrayFactory::create('c', { 3, 4 }); - - Context ctx(1); - ctx.setInputArray(0, yiqs); - ctx.setOutputArray(0, actual); - ctx.setIArguments({ 0 }); - sd::ops::yiq_to_rgb op; - auto status = op.execute(&ctx); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); - + auto yiqs = NDArrayFactory::create( + 'c', {3, 4}, + {0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, + -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, + 0.145902053f, 0.263960421f}); + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, + -0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, + 0.280231822f, 1.91936605f}); + auto actual = NDArrayFactory::create('c', {3, 4}); + + Context ctx(1); + ctx.setInputArray(0, yiqs); + ctx.setOutputArray(0, actual); + ctx.setIArguments({0}); + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) { - - auto yiqs = NDArrayFactory::create('c', { 3 }, { - 0.775258899f, -0.288912386f, -0.132725924f - }); - auto expected = NDArrayFactory::create('c', { 3 }, { - 0.416663059f, 0.939747555f, 0.868814286f - }); - auto actual = NDArrayFactory::create('c', { 3 }); - - Context ctx(1); - ctx.setInputArray(0, yiqs); - ctx.setOutputArray(0, actual); - - sd::ops::yiq_to_rgb op; - auto status = op.execute(&ctx); + auto yiqs = NDArrayFactory::create( + 'c', {3}, {0.775258899f, -0.288912386f, -0.132725924f}); + auto expected = NDArrayFactory::create( + 'c', {3}, {0.416663059f, 0.939747555f, 0.868814286f}); + auto actual = NDArrayFactory::create('c', {3}); + + Context ctx(1); + ctx.setInputArray(0, yiqs); + ctx.setOutputArray(0, actual); + + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); #if 0 actual.printBuffer("actual"); expected.printBuffer("expected"); #endif - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) { - - auto yiqs = NDArrayFactory::create('c', { 3, 4 }, { - 0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, - -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, - 0.145902053f, 0.263960421f - }); - auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { - 0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, - -0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, - 0.280231822f, 1.91936605f - }); - - //get subarray - NDArray subArrYiqs = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) }); - NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); - subArrYiqs.reshapei({ 3 }); - expected.reshapei({ 3 }); + auto yiqs = NDArrayFactory::create( + 'c', {3, 4}, + {0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, + -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, + 0.145902053f, 0.263960421f}); + auto rgbs = NDArrayFactory::create( + 'c', {3, 4}, + {0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, + -0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, + 0.280231822f, 1.91936605f}); + + // get subarray + NDArray subArrYiqs = yiqs.subarray({NDIndex::all(), NDIndex::point(0)}); + NDArray expected = rgbs.subarray({NDIndex::all(), NDIndex::point(0)}); + subArrYiqs.reshapei({3}); + expected.reshapei({3}); #if 0 //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] subArrYiqs.printShapeInfo("subArrYiqs"); #endif - auto actual = NDArrayFactory::create('c', { 3 }); + auto actual = NDArrayFactory::create('c', {3}); - Context ctx(1); - ctx.setInputArray(0, subArrYiqs); - ctx.setOutputArray(0, actual); - sd::ops::yiq_to_rgb op; - auto status = op.execute(&ctx); + Context ctx(1); + ctx.setInputArray(0, subArrYiqs); + ctx.setOutputArray(0, actual); + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.equalsTo(actual)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp index 08434b25b11e..e0c0bd451984 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests17.cpp @@ -14,77 +14,73 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include +#include +#include + #include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTests17 : public testing::Test { -public: - - DeclarableOpsTests17() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests17() { + printf("\n"); + fflush(stdout); + } }; TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) { - auto values = NDArrayFactory::create({1.f, 2.f, 3.f}); - auto shape = NDArrayFactory::create({3, 3}); - auto ranges = NDArrayFactory::create({0,0, 1,1, 2,2}); - auto def = NDArrayFactory::create(0.f); - auto exp = NDArrayFactory::create('c', {3, 3}, {1.f,0.f,0.f, 0.f,2.f,0.f, 0.f,0.f,3.f}); - - - sd::ops::compat_sparse_to_dense op; - auto result = op.evaluate({&ranges, &shape, &values, &def}); - ASSERT_EQ(Status::OK(), result.status()); + auto values = NDArrayFactory::create({1.f, 2.f, 3.f}); + auto shape = NDArrayFactory::create({3, 3}); + auto ranges = NDArrayFactory::create({0, 0, 1, 1, 2, 2}); + auto def = NDArrayFactory::create(0.f); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 0.f, 0.f, 0.f, 2.f, 0.f, 0.f, 0.f, 3.f}); + + sd::ops::compat_sparse_to_dense op; + auto result = op.evaluate({&ranges, &shape, &values, &def}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) { - auto values = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"}); - auto shape = NDArrayFactory::create({3, 3}); - auto ranges = NDArrayFactory::create({0,0, 1,1, 2,2}); - auto def = NDArrayFactory::string("d"); - auto exp = NDArrayFactory::string( {3, 3}, {"alpha","d","d", "d","beta","d", "d","d","gamma"}); - - - sd::ops::compat_sparse_to_dense op; - auto result = op.evaluate({&ranges, &shape, &values, &def}); - ASSERT_EQ(Status::OK(), result.status()); - + auto values = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"}); + auto shape = NDArrayFactory::create({3, 3}); + auto ranges = NDArrayFactory::create({0, 0, 1, 1, 2, 2}); + auto def = NDArrayFactory::string("d"); + auto exp = NDArrayFactory::string( + {3, 3}, {"alpha", "d", "d", "d", "beta", "d", "d", "d", "gamma"}); + + sd::ops::compat_sparse_to_dense op; + auto result = op.evaluate({&ranges, &shape, &values, &def}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(DeclarableOpsTests17, test_compat_string_split_1) { - auto x = NDArrayFactory::string( {2}, {"first string", "second"}); - auto delimiter = NDArrayFactory::string(" "); - - auto exp0 = NDArrayFactory::create({0,0, 0,1, 1,0}); - auto exp1 = NDArrayFactory::string( {3}, {"first", "string", "second"}); + auto x = NDArrayFactory::string({2}, {"first string", "second"}); + auto delimiter = NDArrayFactory::string(" "); - sd::ops::compat_string_split op; - auto result = op.evaluate({&x, &delimiter}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(2, result.size()); + auto exp0 = NDArrayFactory::create({0, 0, 0, 1, 1, 0}); + auto exp1 = NDArrayFactory::string({3}, {"first", "string", "second"}); - auto z0 = result.at(0); - auto z1 = result.at(1); + sd::ops::compat_string_split op; + auto result = op.evaluate({&x, &delimiter}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(2, result.size()); - ASSERT_TRUE(exp0.isSameShape(z0)); - ASSERT_TRUE(exp1.isSameShape(z1)); + auto z0 = result.at(0); + auto z1 = result.at(1); - ASSERT_EQ(exp0, z0); - ASSERT_EQ(exp1, z1); + ASSERT_TRUE(exp0.isSameShape(z0)); + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_EQ(exp0, z0); + ASSERT_EQ(exp1, z1); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp index 7d7a39cdac63..bcea71efc00b 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp @@ -14,1596 +14,2613 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +// +// @author raver119@gmail.com +// - // - // @author raver119@gmail.com - // - -#include "testlayers.h" -#include #include -#include #include +#include +#include + #include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTests18 : public testing::Test { -public: - - DeclarableOpsTests18() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests18() { + printf("\n"); + fflush(stdout); + } }; TEST_F(DeclarableOpsTests18, test_bitcast_1) { - auto x = NDArrayFactory::create(0.23028551377579154); - auto z = NDArrayFactory::create(0); - auto e = NDArrayFactory::create(4597464930322771456L); + auto x = NDArrayFactory::create(0.23028551377579154); + auto z = NDArrayFactory::create(0); + auto e = NDArrayFactory::create(4597464930322771456L); - sd::ops::bitcast op; - auto status = op.execute({ &x }, { &z }, {}, { (Nd4jLong)sd::DataType::INT64 }, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::bitcast op; + auto status = op.execute({&x}, {&z}, {}, {(Nd4jLong)sd::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests18, test_tanh_1) { - auto x = NDArrayFactory::create('c', { 8 }, { 0.23f, -0.23f, 0.35f, -0.35f, 0.64f, -0.64f, 100000.f, -100000.f }); - auto z = x.ulike(); - auto e = NDArrayFactory::create('c', { 8 }, { 0.226028f, -0.226028f, 0.336376f, -0.336376f, 0.564900f, -0.564900f, 1.f, -1.f }); - - sd::ops::tanh op; - op.execute({ &x }, { &z }); - - ASSERT_EQ(e, z); + auto x = NDArrayFactory::create( + 'c', {8}, + {0.23f, -0.23f, 0.35f, -0.35f, 0.64f, -0.64f, 100000.f, -100000.f}); + auto z = x.ulike(); + auto e = NDArrayFactory::create( + 'c', {8}, + {0.226028f, -0.226028f, 0.336376f, -0.336376f, 0.564900f, -0.564900f, 1.f, + -1.f}); + + sd::ops::tanh op; + op.execute({&x}, {&z}); + + ASSERT_EQ(e, z); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, test_tanh_2) { - - NDArray x('c', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); - NDArray z('c', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); - - x.linspace(-1., 0.003); - - NDArray e('c', { 2, 2, 3, 3, 4, 4 }, { -0.761594, -0.760331, -0.759063, -0.757788, -0.756508, -0.755222, -0.753930, -0.752633, -0.751329, -0.750020, -0.748704, -0.747383, -0.746056, -0.744723, -0.743383, -0.742038, -0.740687, -0.739330, -0.737967, -0.736598, -0.735222, -0.733841, -0.732453, -0.731060, -0.729660, -0.728254, -0.726842, -0.725424, -0.724000, -0.722569, -0.721132, -0.719689, -0.718240, -0.716784, -0.715323, -0.713854, -0.712380, -0.710899, -0.709412, -0.707919, -0.706419, -0.704913, -0.703401, -0.701882, -0.700357, -0.698825, -0.697287, -0.695742, -0.694191, -0.692634, -0.691069, -0.689499, -0.687922, -0.686338, -0.684748, -0.683152, -0.681548, -0.679939, -0.678322, -0.676699, -0.675070, -0.673434, -0.671791, -0.670142, -0.668486, -0.666823, -0.665153, -0.663477, -0.661795, -0.660105, -0.658409, -0.656706, -0.654997, -0.653280, -0.651557, -0.649827, -0.648091, -0.646348, -0.644597, -0.642841, -0.641077, -0.639306, -0.637529, -0.635745, -0.633954, -0.632157, -0.630352, -0.628541, -0.626722, -0.624897, -0.623065, -0.621227, -0.619381, -0.617528, -0.615669, -0.613803, -0.611929, -0.610049, -0.608162, -0.606269, -0.604368, -0.602460, -0.600546, -0.598624, -0.596696, -0.594760, -0.592818, -0.590869, -0.588913, -0.586950, -0.584980, -0.583003, -0.581019, -0.579029, -0.577031, -0.575026, -0.573015, -0.570996, -0.568971, -0.566939, -0.564900, -0.562853, -0.560800, -0.558740, -0.556674, -0.554600, -0.552519, -0.550431, -0.548337, -0.546235, -0.544127, -0.542012, -0.539890, -0.537761, -0.535625, -0.533482, -0.531332, -0.529176, -0.527013, -0.524842, -0.522665, -0.520482, -0.518291, -0.516093, -0.513889, -0.511678, -0.509460, -0.507235, -0.505004, -0.502765, -0.500520, -0.498268, -0.496010, -0.493745, -0.491472, -0.489194, -0.486908, -0.484616, -0.482318, -0.480012, -0.477700, -0.475381, -0.473056, -0.470724, -0.468385, -0.466040, -0.463689, -0.461330, -0.458966, -0.456594, -0.454216, -0.451832, -0.449441, -0.447044, -0.444640, -0.442230, -0.439814, -0.437391, -0.434962, -0.432526, -0.430084, -0.427636, -0.425181, -0.422721, -0.420254, -0.417780, -0.415301, -0.412815, -0.410323, -0.407825, -0.405321, -0.402811, -0.400295, -0.397773, -0.395244, -0.392710, -0.390170, -0.387623, -0.385071, -0.382513, -0.379949, -0.377379, -0.374803, -0.372222, -0.369635, -0.367042, -0.364443, -0.361839, -0.359229, -0.356613, -0.353992, -0.351365, -0.348732, -0.346095, -0.343451, -0.340802, -0.338148, -0.335488, -0.332823, -0.330153, -0.327477, -0.324796, -0.322110, -0.319419, -0.316723, -0.314021, -0.311314, -0.308602, -0.305886, -0.303164, -0.300437, -0.297705, -0.294969, -0.292227, -0.289481, -0.286730, -0.283975, -0.281214, -0.278449, -0.275679, -0.272905, -0.270126, -0.267343, -0.264555, -0.261763, -0.258966, -0.256165, -0.253360, -0.250550, -0.247737, -0.244919, -0.242097, -0.239270, -0.236440, -0.233606, -0.230768, -0.227925, -0.225079, -0.222229, -0.219376, -0.216518, -0.213657, -0.210792, -0.207923, -0.205051, -0.202176, -0.199297, -0.196414, -0.193528, -0.190639, -0.187746, -0.184850, -0.181951, -0.179049, -0.176144, -0.173235, -0.170324, -0.167409, -0.164492, -0.161572, -0.158649, -0.155723, -0.152794, -0.149863, -0.146929, -0.143992, -0.141053, -0.138112, -0.135168, -0.132221, -0.129273, -0.126322, -0.123368, -0.120413, -0.117455, -0.114496, -0.111534, -0.108570, -0.105605, -0.102637, -0.099668, -0.096697, -0.093724, -0.090750, -0.087774, -0.084796, -0.081817, -0.078836, -0.075854, -0.072871, -0.069886, -0.066900, -0.063913, -0.060924, -0.057935, -0.054945, -0.051953, -0.048961, -0.045968, -0.042974, -0.039979, -0.036983, -0.033987, -0.030990, -0.027993, -0.024995, -0.021996, -0.018998, -0.015999, -0.012999, -0.010000, -0.007000, -0.004000, -0.001000, 0.002000, 0.005000, 0.008000, 0.011000, 0.013999, 0.016998, 0.019997, 0.022996, 0.025994, 0.028992, 0.031989, 0.034986, 0.037982, 0.040977, 0.043972, 0.046965, 0.049958, 0.052950, 0.055942, 0.058932, 0.061921, 0.064909, 0.067895, 0.070881, 0.073865, 0.076848, 0.079830, 0.082810, 0.085789, 0.088766, 0.091741, 0.094715, 0.097687, 0.100658, 0.103627, 0.106594, 0.109558, 0.112521, 0.115482, 0.118441, 0.121398, 0.124353, 0.127305, 0.130256, 0.133204, 0.136149, 0.139092, 0.142033, 0.144971, 0.147907, 0.150840, 0.153771, 0.156698, 0.159623, 0.162545, 0.165465, 0.168381, 0.171294, 0.174205, 0.177112, 0.180017, 0.182918, 0.185816, 0.188711, 0.191602, 0.194490, 0.197375, 0.200257, 0.203135, 0.206009, 0.208880, 0.211747, 0.214611, 0.217471, 0.220327, 0.223180, 0.226028, 0.228873, 0.231714, 0.234551, 0.237384, 0.240213, 0.243038, 0.245858, 0.248675, 0.251487, 0.254296, 0.257099, 0.259899, 0.262694, 0.265485, 0.268271, 0.271053, 0.273830, 0.276603, 0.279371, 0.282135, 0.284894, 0.287648, 0.290397, 0.293142, 0.295882, 0.298617, 0.301347, 0.304072, 0.306792, 0.309507, 0.312217, 0.314922, 0.317622, 0.320317, 0.323006, 0.325691, 0.328370, 0.331044, 0.333712, 0.336376, 0.339033, 0.341686, 0.344333, 0.346974, 0.349611, 0.352241, 0.354866, 0.357485, 0.360099, 0.362707, 0.365310, 0.367907, 0.370498, 0.373083, 0.375663, 0.378236, 0.380804, 0.383366, 0.385922, 0.388473, 0.391017, 0.393555, 0.396088, 0.398614, 0.401134, 0.403649, 0.406157, 0.408659, 0.411155, 0.413644, 0.416128, 0.418605, 0.421077, 0.423542, 0.426000, 0.428453, 0.430899, 0.433339, 0.435772, 0.438199, 0.440620, 0.443034, 0.445442, 0.447844, 0.450239, 0.452628, 0.455010, 0.457385, 0.459755, 0.462117, 0.464473, 0.466823, 0.469166, 0.471502, 0.473832, 0.476155, 0.478471, 0.480781, 0.483085, 0.485381, 0.487671, 0.489954, 0.492231, 0.494500, 0.496763, 0.499020, 0.501269, 0.503512, 0.505748, 0.507977, 0.510200, 0.512416, 0.514624, 0.516827, 0.519022, 0.521210, 0.523392, 0.525567, 0.527735, 0.529896, 0.532050, 0.534197, 0.536338, 0.538471, 0.540598, 0.542718, 0.544831, 0.546937, 0.549036, 0.551128, 0.553213, 0.555292, 0.557363, 0.559428, 0.561486, 0.563536, 0.565580, 0.567617, 0.569647, 0.571670, 0.573686, 0.575695, 0.577697, 0.579693, 0.581681, 0.583663, 0.585637, 0.587605, 0.589566, 0.591519, 0.593466, 0.595406, 0.597339, 0.599265, 0.601184, 0.603097, 0.605002, 0.606901, 0.608792, 0.610677, 0.612555, 0.614425, 0.616289, 0.618147, 0.619997 }, sd::DataType::FLOAT32); - - sd::ops::tanh op; - op.execute({ &x }, { &z }); - ASSERT_EQ(e, z); + NDArray x('c', {2, 2, 3, 3, 4, 4}, sd::DataType::FLOAT32); + NDArray z('c', {2, 2, 3, 3, 4, 4}, sd::DataType::FLOAT32); + + x.linspace(-1., 0.003); + + NDArray e('c', {2, 2, 3, 3, 4, 4}, + {-0.761594, -0.760331, -0.759063, -0.757788, -0.756508, -0.755222, + -0.753930, -0.752633, -0.751329, -0.750020, -0.748704, -0.747383, + -0.746056, -0.744723, -0.743383, -0.742038, -0.740687, -0.739330, + -0.737967, -0.736598, -0.735222, -0.733841, -0.732453, -0.731060, + -0.729660, -0.728254, -0.726842, -0.725424, -0.724000, -0.722569, + -0.721132, -0.719689, -0.718240, -0.716784, -0.715323, -0.713854, + -0.712380, -0.710899, -0.709412, -0.707919, -0.706419, -0.704913, + -0.703401, -0.701882, -0.700357, -0.698825, -0.697287, -0.695742, + -0.694191, -0.692634, -0.691069, -0.689499, -0.687922, -0.686338, + -0.684748, -0.683152, -0.681548, -0.679939, -0.678322, -0.676699, + -0.675070, -0.673434, -0.671791, -0.670142, -0.668486, -0.666823, + -0.665153, -0.663477, -0.661795, -0.660105, -0.658409, -0.656706, + -0.654997, -0.653280, -0.651557, -0.649827, -0.648091, -0.646348, + -0.644597, -0.642841, -0.641077, -0.639306, -0.637529, -0.635745, + -0.633954, -0.632157, -0.630352, -0.628541, -0.626722, -0.624897, + -0.623065, -0.621227, -0.619381, -0.617528, -0.615669, -0.613803, + -0.611929, -0.610049, -0.608162, -0.606269, -0.604368, -0.602460, + -0.600546, -0.598624, -0.596696, -0.594760, -0.592818, -0.590869, + -0.588913, -0.586950, -0.584980, -0.583003, -0.581019, -0.579029, + -0.577031, -0.575026, -0.573015, -0.570996, -0.568971, -0.566939, + -0.564900, -0.562853, -0.560800, -0.558740, -0.556674, -0.554600, + -0.552519, -0.550431, -0.548337, -0.546235, -0.544127, -0.542012, + -0.539890, -0.537761, -0.535625, -0.533482, -0.531332, -0.529176, + -0.527013, -0.524842, -0.522665, -0.520482, -0.518291, -0.516093, + -0.513889, -0.511678, -0.509460, -0.507235, -0.505004, -0.502765, + -0.500520, -0.498268, -0.496010, -0.493745, -0.491472, -0.489194, + -0.486908, -0.484616, -0.482318, -0.480012, -0.477700, -0.475381, + -0.473056, -0.470724, -0.468385, -0.466040, -0.463689, -0.461330, + -0.458966, -0.456594, -0.454216, -0.451832, -0.449441, -0.447044, + -0.444640, -0.442230, -0.439814, -0.437391, -0.434962, -0.432526, + -0.430084, -0.427636, -0.425181, -0.422721, -0.420254, -0.417780, + -0.415301, -0.412815, -0.410323, -0.407825, -0.405321, -0.402811, + -0.400295, -0.397773, -0.395244, -0.392710, -0.390170, -0.387623, + -0.385071, -0.382513, -0.379949, -0.377379, -0.374803, -0.372222, + -0.369635, -0.367042, -0.364443, -0.361839, -0.359229, -0.356613, + -0.353992, -0.351365, -0.348732, -0.346095, -0.343451, -0.340802, + -0.338148, -0.335488, -0.332823, -0.330153, -0.327477, -0.324796, + -0.322110, -0.319419, -0.316723, -0.314021, -0.311314, -0.308602, + -0.305886, -0.303164, -0.300437, -0.297705, -0.294969, -0.292227, + -0.289481, -0.286730, -0.283975, -0.281214, -0.278449, -0.275679, + -0.272905, -0.270126, -0.267343, -0.264555, -0.261763, -0.258966, + -0.256165, -0.253360, -0.250550, -0.247737, -0.244919, -0.242097, + -0.239270, -0.236440, -0.233606, -0.230768, -0.227925, -0.225079, + -0.222229, -0.219376, -0.216518, -0.213657, -0.210792, -0.207923, + -0.205051, -0.202176, -0.199297, -0.196414, -0.193528, -0.190639, + -0.187746, -0.184850, -0.181951, -0.179049, -0.176144, -0.173235, + -0.170324, -0.167409, -0.164492, -0.161572, -0.158649, -0.155723, + -0.152794, -0.149863, -0.146929, -0.143992, -0.141053, -0.138112, + -0.135168, -0.132221, -0.129273, -0.126322, -0.123368, -0.120413, + -0.117455, -0.114496, -0.111534, -0.108570, -0.105605, -0.102637, + -0.099668, -0.096697, -0.093724, -0.090750, -0.087774, -0.084796, + -0.081817, -0.078836, -0.075854, -0.072871, -0.069886, -0.066900, + -0.063913, -0.060924, -0.057935, -0.054945, -0.051953, -0.048961, + -0.045968, -0.042974, -0.039979, -0.036983, -0.033987, -0.030990, + -0.027993, -0.024995, -0.021996, -0.018998, -0.015999, -0.012999, + -0.010000, -0.007000, -0.004000, -0.001000, 0.002000, 0.005000, + 0.008000, 0.011000, 0.013999, 0.016998, 0.019997, 0.022996, + 0.025994, 0.028992, 0.031989, 0.034986, 0.037982, 0.040977, + 0.043972, 0.046965, 0.049958, 0.052950, 0.055942, 0.058932, + 0.061921, 0.064909, 0.067895, 0.070881, 0.073865, 0.076848, + 0.079830, 0.082810, 0.085789, 0.088766, 0.091741, 0.094715, + 0.097687, 0.100658, 0.103627, 0.106594, 0.109558, 0.112521, + 0.115482, 0.118441, 0.121398, 0.124353, 0.127305, 0.130256, + 0.133204, 0.136149, 0.139092, 0.142033, 0.144971, 0.147907, + 0.150840, 0.153771, 0.156698, 0.159623, 0.162545, 0.165465, + 0.168381, 0.171294, 0.174205, 0.177112, 0.180017, 0.182918, + 0.185816, 0.188711, 0.191602, 0.194490, 0.197375, 0.200257, + 0.203135, 0.206009, 0.208880, 0.211747, 0.214611, 0.217471, + 0.220327, 0.223180, 0.226028, 0.228873, 0.231714, 0.234551, + 0.237384, 0.240213, 0.243038, 0.245858, 0.248675, 0.251487, + 0.254296, 0.257099, 0.259899, 0.262694, 0.265485, 0.268271, + 0.271053, 0.273830, 0.276603, 0.279371, 0.282135, 0.284894, + 0.287648, 0.290397, 0.293142, 0.295882, 0.298617, 0.301347, + 0.304072, 0.306792, 0.309507, 0.312217, 0.314922, 0.317622, + 0.320317, 0.323006, 0.325691, 0.328370, 0.331044, 0.333712, + 0.336376, 0.339033, 0.341686, 0.344333, 0.346974, 0.349611, + 0.352241, 0.354866, 0.357485, 0.360099, 0.362707, 0.365310, + 0.367907, 0.370498, 0.373083, 0.375663, 0.378236, 0.380804, + 0.383366, 0.385922, 0.388473, 0.391017, 0.393555, 0.396088, + 0.398614, 0.401134, 0.403649, 0.406157, 0.408659, 0.411155, + 0.413644, 0.416128, 0.418605, 0.421077, 0.423542, 0.426000, + 0.428453, 0.430899, 0.433339, 0.435772, 0.438199, 0.440620, + 0.443034, 0.445442, 0.447844, 0.450239, 0.452628, 0.455010, + 0.457385, 0.459755, 0.462117, 0.464473, 0.466823, 0.469166, + 0.471502, 0.473832, 0.476155, 0.478471, 0.480781, 0.483085, + 0.485381, 0.487671, 0.489954, 0.492231, 0.494500, 0.496763, + 0.499020, 0.501269, 0.503512, 0.505748, 0.507977, 0.510200, + 0.512416, 0.514624, 0.516827, 0.519022, 0.521210, 0.523392, + 0.525567, 0.527735, 0.529896, 0.532050, 0.534197, 0.536338, + 0.538471, 0.540598, 0.542718, 0.544831, 0.546937, 0.549036, + 0.551128, 0.553213, 0.555292, 0.557363, 0.559428, 0.561486, + 0.563536, 0.565580, 0.567617, 0.569647, 0.571670, 0.573686, + 0.575695, 0.577697, 0.579693, 0.581681, 0.583663, 0.585637, + 0.587605, 0.589566, 0.591519, 0.593466, 0.595406, 0.597339, + 0.599265, 0.601184, 0.603097, 0.605002, 0.606901, 0.608792, + 0.610677, 0.612555, 0.614425, 0.616289, 0.618147, 0.619997}, + sd::DataType::FLOAT32); + + sd::ops::tanh op; + op.execute({&x}, {&z}); + ASSERT_EQ(e, z); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, test_tanh_bp) { - - NDArray x('c', { 2, 3, 4 }, sd::DataType::FLOAT32); - NDArray dLdz('c', { 2, 3, 4 }, sd::DataType::FLOAT32); - NDArray dLdx('c', { 2, 3, 4 }, sd::DataType::FLOAT32); - - x.linspace(-1., 0.003); - dLdz.linspace(0.01, 0.01); - - NDArray e('c', { 2, 3, 4 }, { 0.004200, 0.008438, 0.012715, 0.017030, 0.021385, 0.025778, 0.030211, 0.034684, 0.039195, 0.043747, 0.048339, 0.052970, 0.057642, 0.062354, 0.067107, 0.071901, 0.076735, 0.081610, 0.086527, 0.091485, 0.096484, 0.101525, 0.106608, 0.111732 }, sd::DataType::FLOAT32); - - sd::ops::tanh_bp op; - op.execute({ &x, &dLdz }, { &dLdx }); - ASSERT_EQ(e, dLdx); + NDArray x('c', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray dLdz('c', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray dLdx('c', {2, 3, 4}, sd::DataType::FLOAT32); + + x.linspace(-1., 0.003); + dLdz.linspace(0.01, 0.01); + + NDArray e('c', {2, 3, 4}, + {0.004200, 0.008438, 0.012715, 0.017030, 0.021385, 0.025778, + 0.030211, 0.034684, 0.039195, 0.043747, 0.048339, 0.052970, + 0.057642, 0.062354, 0.067107, 0.071901, 0.076735, 0.081610, + 0.086527, 0.091485, 0.096484, 0.101525, 0.106608, 0.111732}, + sd::DataType::FLOAT32); + + sd::ops::tanh_bp op; + op.execute({&x, &dLdz}, {&dLdx}); + ASSERT_EQ(e, dLdx); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, test_tanh_bp2) { - - NDArray x('f', { 2, 3, 4 }, sd::DataType::FLOAT32); - NDArray dLdz('f', { 2, 3, 4 }, sd::DataType::FLOAT32); - NDArray dLdx('f', { 2, 3, 4 }, sd::DataType::FLOAT32); - - x.linspace(-1., 0.003); - dLdz.linspace(0.01, 0.01); - - NDArray exp('c', { 2, 3, 4 }, { 0.004200, 0.008438, 0.012715, 0.017030, 0.021385, 0.025778, 0.030211, 0.034684, 0.039195, 0.043747, 0.048339, 0.052970, 0.057642, 0.062354, 0.067107, 0.071901, 0.076735, 0.081610, 0.086527, 0.091485, 0.096484, 0.101525, 0.106608, 0.111732 }, sd::DataType::FLOAT32); - NDArray e('f', { 2, 3, 4 }, sd::DataType::FLOAT32); - e.assign(exp); - - sd::ops::tanh_bp op; - op.execute({ &x, &dLdz }, { &dLdx }); - ASSERT_EQ(e, dLdx); + NDArray x('f', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray dLdz('f', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray dLdx('f', {2, 3, 4}, sd::DataType::FLOAT32); + + x.linspace(-1., 0.003); + dLdz.linspace(0.01, 0.01); + + NDArray exp('c', {2, 3, 4}, + {0.004200, 0.008438, 0.012715, 0.017030, 0.021385, 0.025778, + 0.030211, 0.034684, 0.039195, 0.043747, 0.048339, 0.052970, + 0.057642, 0.062354, 0.067107, 0.071901, 0.076735, 0.081610, + 0.086527, 0.091485, 0.096484, 0.101525, 0.106608, 0.111732}, + sd::DataType::FLOAT32); + NDArray e('f', {2, 3, 4}, sd::DataType::FLOAT32); + e.assign(exp); + + sd::ops::tanh_bp op; + op.execute({&x, &dLdz}, {&dLdx}); + ASSERT_EQ(e, dLdx); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, test_tanh_bp3) { - - NDArray x('f', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); - NDArray dLdz('f', { 2,2, 3,3, 4,4 }, sd::DataType::FLOAT32); - NDArray dLdx('f', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); - - x.linspace(-1.5, 0.005); - dLdz.linspace(-1., 0.01); - - NDArray exp('c', { 2, 2, 3, 3, 4, 4 }, { -0.180707, -0.180525, -0.180324, -0.180103, -0.179861, -0.179599, -0.179315, -0.179009, -0.178682, -0.178333, -0.177961, -0.177566, -0.177148, -0.176706, -0.176240, -0.175750, -0.175236, -0.174696, -0.174130, -0.173539, -0.172922, -0.172278, -0.171607, -0.170909, -0.170183, -0.169429, -0.168646, -0.167834, -0.166993, -0.166123, -0.165222, -0.164290, -0.163327, -0.162334, -0.161308, -0.160250, -0.159159, -0.158035, -0.156877, -0.155686, -0.154460, -0.153199, -0.151903, -0.150571, -0.149203, -0.147798, -0.146356, -0.144876, -0.143359, -0.141803, -0.140207, -0.138573, -0.136898, -0.135183, -0.133428, -0.131630, -0.129792, -0.127910, -0.125986, -0.124019, -0.122008, -0.119953, -0.117853, -0.115708, -0.113517, -0.111279, -0.108996, -0.106665, -0.104286, -0.101859, -0.099383, -0.096859, -0.094284, -0.091660, -0.088984, -0.086258, -0.083480, -0.080649, -0.077766, -0.074830, -0.071840, -0.068796, -0.065697, -0.062543, -0.059334, -0.056068, -0.052745, -0.049365, -0.045928, -0.042432, -0.038878, -0.035264, -0.031591, -0.027858, -0.024064, -0.020209, -0.016292, -0.012313, -0.008272, -0.004168, 0.000000, 0.004232, 0.008528, 0.012889, 0.017316, 0.021808, 0.026367, 0.030992, 0.035684, 0.040444, 0.045272, 0.050169, 0.055134, 0.060168, 0.065273, 0.070447, 0.075692, 0.081007, 0.086394, 0.091853, 0.097383, 0.102986, 0.108662, 0.114411, 0.120233, 0.126129, 0.132099, 0.138144, 0.144263, 0.150457, 0.156727, 0.163072, 0.169493, 0.175990, 0.182564, 0.189214, 0.195941, 0.202745, 0.209627, 0.216585, 0.223622, 0.230736, 0.237929, 0.245200, 0.252549, 0.259976, 0.267482, 0.275066, 0.282730, 0.290472, 0.298293, 0.306193, 0.314172, 0.322230, 0.330366, 0.338582, 0.346877, 0.355250, 0.363703, 0.372234, 0.380844, 0.389532, 0.398299, 0.407144, 0.416067, 0.425068, 0.434147, 0.443303, 0.452537, 0.461848, 0.471235, 0.480699, 0.490240, 0.499856, 0.509548, 0.519314, 0.529156, 0.539072, 0.549062, 0.559126, 0.569262, 0.579471, 0.589753, 0.600106, 0.610530, 0.621024, 0.631588, 0.642222, 0.652924, 0.663694, 0.674532, 0.685436, 0.696406, 0.707441, 0.718541, 0.729704, 0.740931, 0.752219, 0.763568, 0.774978, 0.786448, 0.797976, 0.809561, 0.821203, 0.832901, 0.844654, 0.856460, 0.868319, 0.880230, 0.892191, 0.904201, 0.916260, 0.928366, 0.940518, 0.952715, 0.964955, 0.977238, 0.989561, 1.001925, 1.014327, 1.026767, 1.039242, 1.051752, 1.064295, 1.076870, 1.089475, 1.102109, 1.114771, 1.127459, 1.140171, 1.152907, 1.165664, 1.178441, 1.191237, 1.204050, 1.216878, 1.229720, 1.242573, 1.255438, 1.268311, 1.281192, 1.294078, 1.306968, 1.319860, 1.332753, 1.345644, 1.358533, 1.371417, 1.384294, 1.397163, 1.410022, 1.422870, 1.435704, 1.448522, 1.461323, 1.474105, 1.486867, 1.499606, 1.512321, 1.525009, 1.537669, 1.550299, 1.562897, 1.575462, 1.587991, 1.600483, 1.612935, 1.625347, 1.637715, 1.650040, 1.662317, 1.674545, 1.686724, 1.698850, 1.710922, 1.722939, 1.734897, 1.746797, 1.758635, 1.770409, 1.782119, 1.793762, 1.805337, 1.816842, 1.828274, 1.839633, 1.850916, 1.862121, 1.873248, 1.884294, 1.895258, 1.906137, 1.916931, 1.927637, 1.938255, 1.948782, 1.959216, 1.969557, 1.979802, 1.989950, 2.000000, 2.009950, 2.019798, 2.029543, 2.039184, 2.048719, 2.058147, 2.067466, 2.076675, 2.085773, 2.094759, 2.103630, 2.112386, 2.121026, 2.129548, 2.137952, 2.146235, 2.154397, 2.162437, 2.170354, 2.178146, 2.185813, 2.193353, 2.200766, 2.208051, 2.215207, 2.222232, 2.229127, 2.235889, 2.242520, 2.249017, 2.255379, 2.261607, 2.267699, 2.273656, 2.279475, 2.285158, 2.290702, 2.296108, 2.301376, 2.306503, 2.311491, 2.316339, 2.321046, 2.325613, 2.330038, 2.334321, 2.338464, 2.342464, 2.346322, 2.350037, 2.353610, 2.357041, 2.360329, 2.363475, 2.366478, 2.369338, 2.372056, 2.374632, 2.377065, 2.379356, 2.381505, 2.383512, 2.385378, 2.387103, 2.388686, 2.390128, 2.391431, 2.392593, 2.393615, 2.394499, 2.395244, 2.395850, 2.396319, 2.396650, 2.396845, 2.396904, 2.396826, 2.396615, 2.396268, 2.395789, 2.395176, 2.394431, 2.393554, 2.392547, 2.391410, 2.390144, 2.388749, 2.387227, 2.385578, 2.383804, 2.381904, 2.379880, 2.377734, 2.375465, 2.373075, 2.370565, 2.367936, 2.365188, 2.362324, 2.359343, 2.356247, 2.353038, 2.349715, 2.346280, 2.342735, 2.339080, 2.335316, 2.331445, 2.327468, 2.323386, 2.319200, 2.314912, 2.310522, 2.306031, 2.301442, 2.296754, 2.291970, 2.287090, 2.282116, 2.277049, 2.271890, 2.266641, 2.261302, 2.255876, 2.250362, 2.244763, 2.239080, 2.233314, 2.227467, 2.221538, 2.215531, 2.209445, 2.203284, 2.197047, 2.190736, 2.184352, 2.177897, 2.171371, 2.164777, 2.158115, 2.151386, 2.144592, 2.137735, 2.130815, 2.123833, 2.116792, 2.109692, 2.102533, 2.095320, 2.088051, 2.080727, 2.073352, 2.065925, 2.058447, 2.050921, 2.043347, 2.035727, 2.028061, 2.020351, 2.012599, 2.004804, 1.996969, 1.989094, 1.981181, 1.973232, 1.965246, 1.957225, 1.949171, 1.941084, 1.932965, 1.924816, 1.916638, 1.908432, 1.900198, 1.891938, 1.883654, 1.875345, 1.867014, 1.858661, 1.850286, 1.841892, 1.833479, 1.825048, 1.816600, 1.808136, 1.799657, 1.791165, 1.782659, 1.774141, 1.765612, 1.757073, 1.748523, 1.739967, 1.731401, 1.722829, 1.714251, 1.705668, 1.697082, 1.688491, 1.679897, 1.671302, 1.662707, 1.654110, 1.645514, 1.636920, 1.628328, 1.619738, 1.611152, 1.602570, 1.593993, 1.585422, 1.576857, 1.568299, 1.559749, 1.551207, 1.542674, 1.534151, 1.525638, 1.517136, 1.508645, 1.500167, 1.491701, 1.483248, 1.474810, 1.466385, 1.457976, 1.449581, 1.441203, 1.432841, 1.424496, 1.416169, 1.407860, 1.399569, 1.391297, 1.383045, 1.374812, 1.366600, 1.358408, 1.350237, 1.342088, 1.333961, 1.325856, 1.317774, 1.309715, 1.301679, 1.293668, 1.285680, 1.277718, 1.269780, 1.261867, 1.253980, 1.246119, 1.238283, 1.230474, 1.222692, 1.214937, 1.207210, 1.199510, 1.191837, 1.184193, 1.176577, 1.168990, 1.161430, 1.153901, 1.146401, 1.138930, 1.131489, 1.124077, 1.116696, 1.109345, 1.102024, 1.094734, 1.087475, 1.080246, 1.073049 }, sd::DataType::FLOAT32); - - NDArray e('f', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); - e.assign(exp); - - sd::ops::tanh_bp op; - op.execute({ &x, &dLdz }, { &dLdx }); - ASSERT_EQ(e, dLdx); + NDArray x('f', {2, 2, 3, 3, 4, 4}, sd::DataType::FLOAT32); + NDArray dLdz('f', {2, 2, 3, 3, 4, 4}, sd::DataType::FLOAT32); + NDArray dLdx('f', {2, 2, 3, 3, 4, 4}, sd::DataType::FLOAT32); + + x.linspace(-1.5, 0.005); + dLdz.linspace(-1., 0.01); + + NDArray exp('c', {2, 2, 3, 3, 4, 4}, + {-0.180707, -0.180525, -0.180324, -0.180103, -0.179861, -0.179599, + -0.179315, -0.179009, -0.178682, -0.178333, -0.177961, -0.177566, + -0.177148, -0.176706, -0.176240, -0.175750, -0.175236, -0.174696, + -0.174130, -0.173539, -0.172922, -0.172278, -0.171607, -0.170909, + -0.170183, -0.169429, -0.168646, -0.167834, -0.166993, -0.166123, + -0.165222, -0.164290, -0.163327, -0.162334, -0.161308, -0.160250, + -0.159159, -0.158035, -0.156877, -0.155686, -0.154460, -0.153199, + -0.151903, -0.150571, -0.149203, -0.147798, -0.146356, -0.144876, + -0.143359, -0.141803, -0.140207, -0.138573, -0.136898, -0.135183, + -0.133428, -0.131630, -0.129792, -0.127910, -0.125986, -0.124019, + -0.122008, -0.119953, -0.117853, -0.115708, -0.113517, -0.111279, + -0.108996, -0.106665, -0.104286, -0.101859, -0.099383, -0.096859, + -0.094284, -0.091660, -0.088984, -0.086258, -0.083480, -0.080649, + -0.077766, -0.074830, -0.071840, -0.068796, -0.065697, -0.062543, + -0.059334, -0.056068, -0.052745, -0.049365, -0.045928, -0.042432, + -0.038878, -0.035264, -0.031591, -0.027858, -0.024064, -0.020209, + -0.016292, -0.012313, -0.008272, -0.004168, 0.000000, 0.004232, + 0.008528, 0.012889, 0.017316, 0.021808, 0.026367, 0.030992, + 0.035684, 0.040444, 0.045272, 0.050169, 0.055134, 0.060168, + 0.065273, 0.070447, 0.075692, 0.081007, 0.086394, 0.091853, + 0.097383, 0.102986, 0.108662, 0.114411, 0.120233, 0.126129, + 0.132099, 0.138144, 0.144263, 0.150457, 0.156727, 0.163072, + 0.169493, 0.175990, 0.182564, 0.189214, 0.195941, 0.202745, + 0.209627, 0.216585, 0.223622, 0.230736, 0.237929, 0.245200, + 0.252549, 0.259976, 0.267482, 0.275066, 0.282730, 0.290472, + 0.298293, 0.306193, 0.314172, 0.322230, 0.330366, 0.338582, + 0.346877, 0.355250, 0.363703, 0.372234, 0.380844, 0.389532, + 0.398299, 0.407144, 0.416067, 0.425068, 0.434147, 0.443303, + 0.452537, 0.461848, 0.471235, 0.480699, 0.490240, 0.499856, + 0.509548, 0.519314, 0.529156, 0.539072, 0.549062, 0.559126, + 0.569262, 0.579471, 0.589753, 0.600106, 0.610530, 0.621024, + 0.631588, 0.642222, 0.652924, 0.663694, 0.674532, 0.685436, + 0.696406, 0.707441, 0.718541, 0.729704, 0.740931, 0.752219, + 0.763568, 0.774978, 0.786448, 0.797976, 0.809561, 0.821203, + 0.832901, 0.844654, 0.856460, 0.868319, 0.880230, 0.892191, + 0.904201, 0.916260, 0.928366, 0.940518, 0.952715, 0.964955, + 0.977238, 0.989561, 1.001925, 1.014327, 1.026767, 1.039242, + 1.051752, 1.064295, 1.076870, 1.089475, 1.102109, 1.114771, + 1.127459, 1.140171, 1.152907, 1.165664, 1.178441, 1.191237, + 1.204050, 1.216878, 1.229720, 1.242573, 1.255438, 1.268311, + 1.281192, 1.294078, 1.306968, 1.319860, 1.332753, 1.345644, + 1.358533, 1.371417, 1.384294, 1.397163, 1.410022, 1.422870, + 1.435704, 1.448522, 1.461323, 1.474105, 1.486867, 1.499606, + 1.512321, 1.525009, 1.537669, 1.550299, 1.562897, 1.575462, + 1.587991, 1.600483, 1.612935, 1.625347, 1.637715, 1.650040, + 1.662317, 1.674545, 1.686724, 1.698850, 1.710922, 1.722939, + 1.734897, 1.746797, 1.758635, 1.770409, 1.782119, 1.793762, + 1.805337, 1.816842, 1.828274, 1.839633, 1.850916, 1.862121, + 1.873248, 1.884294, 1.895258, 1.906137, 1.916931, 1.927637, + 1.938255, 1.948782, 1.959216, 1.969557, 1.979802, 1.989950, + 2.000000, 2.009950, 2.019798, 2.029543, 2.039184, 2.048719, + 2.058147, 2.067466, 2.076675, 2.085773, 2.094759, 2.103630, + 2.112386, 2.121026, 2.129548, 2.137952, 2.146235, 2.154397, + 2.162437, 2.170354, 2.178146, 2.185813, 2.193353, 2.200766, + 2.208051, 2.215207, 2.222232, 2.229127, 2.235889, 2.242520, + 2.249017, 2.255379, 2.261607, 2.267699, 2.273656, 2.279475, + 2.285158, 2.290702, 2.296108, 2.301376, 2.306503, 2.311491, + 2.316339, 2.321046, 2.325613, 2.330038, 2.334321, 2.338464, + 2.342464, 2.346322, 2.350037, 2.353610, 2.357041, 2.360329, + 2.363475, 2.366478, 2.369338, 2.372056, 2.374632, 2.377065, + 2.379356, 2.381505, 2.383512, 2.385378, 2.387103, 2.388686, + 2.390128, 2.391431, 2.392593, 2.393615, 2.394499, 2.395244, + 2.395850, 2.396319, 2.396650, 2.396845, 2.396904, 2.396826, + 2.396615, 2.396268, 2.395789, 2.395176, 2.394431, 2.393554, + 2.392547, 2.391410, 2.390144, 2.388749, 2.387227, 2.385578, + 2.383804, 2.381904, 2.379880, 2.377734, 2.375465, 2.373075, + 2.370565, 2.367936, 2.365188, 2.362324, 2.359343, 2.356247, + 2.353038, 2.349715, 2.346280, 2.342735, 2.339080, 2.335316, + 2.331445, 2.327468, 2.323386, 2.319200, 2.314912, 2.310522, + 2.306031, 2.301442, 2.296754, 2.291970, 2.287090, 2.282116, + 2.277049, 2.271890, 2.266641, 2.261302, 2.255876, 2.250362, + 2.244763, 2.239080, 2.233314, 2.227467, 2.221538, 2.215531, + 2.209445, 2.203284, 2.197047, 2.190736, 2.184352, 2.177897, + 2.171371, 2.164777, 2.158115, 2.151386, 2.144592, 2.137735, + 2.130815, 2.123833, 2.116792, 2.109692, 2.102533, 2.095320, + 2.088051, 2.080727, 2.073352, 2.065925, 2.058447, 2.050921, + 2.043347, 2.035727, 2.028061, 2.020351, 2.012599, 2.004804, + 1.996969, 1.989094, 1.981181, 1.973232, 1.965246, 1.957225, + 1.949171, 1.941084, 1.932965, 1.924816, 1.916638, 1.908432, + 1.900198, 1.891938, 1.883654, 1.875345, 1.867014, 1.858661, + 1.850286, 1.841892, 1.833479, 1.825048, 1.816600, 1.808136, + 1.799657, 1.791165, 1.782659, 1.774141, 1.765612, 1.757073, + 1.748523, 1.739967, 1.731401, 1.722829, 1.714251, 1.705668, + 1.697082, 1.688491, 1.679897, 1.671302, 1.662707, 1.654110, + 1.645514, 1.636920, 1.628328, 1.619738, 1.611152, 1.602570, + 1.593993, 1.585422, 1.576857, 1.568299, 1.559749, 1.551207, + 1.542674, 1.534151, 1.525638, 1.517136, 1.508645, 1.500167, + 1.491701, 1.483248, 1.474810, 1.466385, 1.457976, 1.449581, + 1.441203, 1.432841, 1.424496, 1.416169, 1.407860, 1.399569, + 1.391297, 1.383045, 1.374812, 1.366600, 1.358408, 1.350237, + 1.342088, 1.333961, 1.325856, 1.317774, 1.309715, 1.301679, + 1.293668, 1.285680, 1.277718, 1.269780, 1.261867, 1.253980, + 1.246119, 1.238283, 1.230474, 1.222692, 1.214937, 1.207210, + 1.199510, 1.191837, 1.184193, 1.176577, 1.168990, 1.161430, + 1.153901, 1.146401, 1.138930, 1.131489, 1.124077, 1.116696, + 1.109345, 1.102024, 1.094734, 1.087475, 1.080246, 1.073049}, + sd::DataType::FLOAT32); + + NDArray e('f', {2, 2, 3, 3, 4, 4}, sd::DataType::FLOAT32); + e.assign(exp); + + sd::ops::tanh_bp op; + op.execute({&x, &dLdz}, {&dLdx}); + ASSERT_EQ(e, dLdx); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST) { + NDArray input('c', {2, 2}, {1, 2, 3, 4}, DataType::FLOAT32); + NDArray epsilon('c', {2, 2}, {.1, .2, .3, .4}, DataType::FLOAT32); - NDArray input('c', { 2, 2 }, { 1,2,3,4 }, DataType::FLOAT32); - NDArray epsilon('c', { 2, 2 }, { .1, .2, .3, .4 }, DataType::FLOAT32); - - int axis = 1; + int axis = 1; - NDArray output('c', { 2, 2 }, DataType::FLOAT32); + NDArray output('c', {2, 2}, DataType::FLOAT32); - NDArray exp('c', { 2, 2 }, { -0.019661, 0.019661, -0.019661, 0.019661 }, DataType::FLOAT32); + NDArray exp('c', {2, 2}, {-0.019661, 0.019661, -0.019661, 0.019661}, + DataType::FLOAT32); - sd::ops::softmax_bp op; - - Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis }); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output.equalsTo(exp)); + sd::ops::softmax_bp op; + Nd4jStatus status = op.execute({&input, &epsilon}, {&output}, {}, {axis}); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST2) { - - NDArray input('c', { 4, 5, 2, 3 }, DataType::FLOAT32); - NDArray epsilon('c', { 4, 5, 2, 3 }, { -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855 }, DataType::FLOAT32); - input.linspace(0.1, 0.2); - - int axis = -1; - - NDArray output('c', { 4, 5, 2, 3 }, DataType::FLOAT32); - NDArray exp('c', { 4, 5, 2, 3 }, { -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253 }, DataType::FLOAT32); - - sd::ops::softmax_bp op; - - Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis }); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output.equalsTo(exp)); - + NDArray input('c', {4, 5, 2, 3}, DataType::FLOAT32); + NDArray epsilon( + 'c', {4, 5, 2, 3}, + {-0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855}, + DataType::FLOAT32); + input.linspace(0.1, 0.2); + + int axis = -1; + + NDArray output('c', {4, 5, 2, 3}, DataType::FLOAT32); + NDArray exp('c', {4, 5, 2, 3}, + {-0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, + -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253}, + DataType::FLOAT32); + + sd::ops::softmax_bp op; + + Nd4jStatus status = op.execute({&input, &epsilon}, {&output}, {}, {axis}); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST3) { - - NDArray input('f', { 4, 5, 2, 3 }, DataType::FLOAT32); - NDArray epsilon('f', { 4, 5, 2, 3 }, { -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855 }, DataType::FLOAT32); - input.linspace(-5., 0.5); - - int axis = 1; - - NDArray output('f', { 4, 5, 2, 3 }, DataType::FLOAT32); - NDArray expC('c', { 4, 5, 2, 3 }, { -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909, -0.000000, 0.000000, -0.000000, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, -0.000149, 0.000054, 0.000095, 0.000095, -0.000149, 0.000054, -0.001760, 0.002943, -0.001183, -0.001183, -0.001760, 0.002943, 0.001909, -0.002997, 0.001088, 0.001088, 0.001909, -0.002997, 0.000000, -0.000000, -0.000000, -0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, 0.000000, 0.000054, 0.000095, -0.000149, -0.000149, 0.000054, 0.000095, 0.002943, -0.001183, -0.001760, -0.001760, 0.002943, -0.001183, -0.002997, 0.001088, 0.001909, 0.001909, -0.002997, 0.001088, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909 }, DataType::FLOAT32); - - NDArray exp('f', { 4, 5, 2, 3 }, DataType::FLOAT32); - exp.assign(expC); - - sd::ops::softmax_bp op; - - Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis }); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(output.equalsTo(exp)); + NDArray input('f', {4, 5, 2, 3}, DataType::FLOAT32); + NDArray epsilon( + 'f', {4, 5, 2, 3}, + {-0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, + -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855}, + DataType::FLOAT32); + input.linspace(-5., 0.5); + + int axis = 1; + + NDArray output('f', {4, 5, 2, 3}, DataType::FLOAT32); + NDArray expC( + 'c', {4, 5, 2, 3}, + {-0.0, -0.0, 0.0, 0.0, -0.0, -0.0, + 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, + 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, + -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, + 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909, + -0.000000, 0.000000, -0.000000, -0.000000, -0.000000, 0.000000, + 0.000000, -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, + -0.000149, 0.000054, 0.000095, 0.000095, -0.000149, 0.000054, + -0.001760, 0.002943, -0.001183, -0.001183, -0.001760, 0.002943, + 0.001909, -0.002997, 0.001088, 0.001088, 0.001909, -0.002997, + 0.000000, -0.000000, -0.000000, -0.000000, 0.000000, -0.000000, + -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, 0.000000, + 0.000054, 0.000095, -0.000149, -0.000149, 0.000054, 0.000095, + 0.002943, -0.001183, -0.001760, -0.001760, 0.002943, -0.001183, + -0.002997, 0.001088, 0.001909, 0.001909, -0.002997, 0.001088, + -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, -0.000000, + 0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, + 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, + -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, + 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909}, + DataType::FLOAT32); + + NDArray exp('f', {4, 5, 2, 3}, DataType::FLOAT32); + exp.assign(expC); + + sd::ops::softmax_bp op; + + Nd4jStatus status = op.execute({&input, &epsilon}, {&output}, {}, {axis}); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, XWPlusB_Bp_1) { - - auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto w = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create({ 100.f, 200.f }); - - NDArray dLdz('c', { 2, 2 }, DataType::FLOAT32); - dLdz.linspace(1); - - sd::ops::xw_plus_b_bp op; - auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto dLdx = result.at(0); - auto dLdw = result.at(1); - auto dLdb = result.at(2); - - auto edLdx = NDArrayFactory::create('c', { 2,3 }, { 17.f, 14.f, 10.f, 45.f, 32.f, 26.f }); - auto edLdw = NDArrayFactory::create('c', { 3,2 }, { 43.f, 58.f, 26.f, 42.f, 21.f, 30.f }); - auto edLdb = NDArrayFactory::create('c', { 2 }, { 4.f, 6.f }); - - ASSERT_TRUE(edLdx.isSameShape(dLdx)); - ASSERT_TRUE(edLdw.isSameShape(dLdw)); - ASSERT_TRUE(edLdb.isSameShape(dLdb)); - ASSERT_TRUE(edLdx.equalsTo(dLdx)); - ASSERT_TRUE(edLdw.equalsTo(dLdw)); - ASSERT_TRUE(edLdb.equalsTo(dLdb)); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto w = NDArrayFactory::create('c', {3, 2}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = NDArrayFactory::create({100.f, 200.f}); + + NDArray dLdz('c', {2, 2}, DataType::FLOAT32); + dLdz.linspace(1); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({&x, &w, &b, &dLdz}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create( + 'c', {2, 3}, {17.f, 14.f, 10.f, 45.f, 32.f, 26.f}); + auto edLdw = NDArrayFactory::create( + 'c', {3, 2}, {43.f, 58.f, 26.f, 42.f, 21.f, 30.f}); + auto edLdb = NDArrayFactory::create('c', {2}, {4.f, 6.f}); + + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, XWPlusB_Bp_2) { - - auto x = NDArrayFactory::create('c', { 6,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto w = NDArrayFactory::create('c', { 3,4 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f, 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create('c', { 4 }, { 100.f, 200.f, 100.f, 200.f }); - - NDArray dLdz('c', { 6, 4 }, DataType::FLOAT32); - dLdz.linspace(.1, .5); - - sd::ops::xw_plus_b_bp op; - auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto dLdx = result.at(0); - auto dLdw = result.at(1); - auto dLdb = result.at(2); - - auto edLdx = NDArrayFactory::create('c', { 6,3 }, { 15.3f, 18.700001f, 13.2f, 61.299995f, 62.699997f, 47.200001f, 107.299995f, 106.699997f, 81.199997f, 153.299988f, 150.699997f, 115.199997f, 199.300018f, 194.700012f, 149.199997f, 245.300018f, 238.700012f, 183.199997f }); - auto edLdw = NDArrayFactory::create('c', { 3,4 }, { 268.5f, 291.f, 313.5f, 336.f, 226.800003f, 250.800003f, 274.799988f, 298.799988f, 146.699997f, 160.199997f, 173.700012f, 187.200012f }); - auto edLdb = NDArrayFactory::create('c', { 4 }, { 30.6f, 33.599998f, 36.599998f, 39.599998f }); - ASSERT_TRUE(edLdx.isSameShape(dLdx)); - ASSERT_TRUE(edLdw.isSameShape(dLdw)); - ASSERT_TRUE(edLdb.isSameShape(dLdb)); - ASSERT_TRUE(edLdx.equalsTo(dLdx)); - ASSERT_TRUE(edLdw.equalsTo(dLdw)); - ASSERT_TRUE(edLdb.equalsTo(dLdb)); + auto x = NDArrayFactory::create( + 'c', {6, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, + 11.f, 3.f, 14.f, 5.f, 6.f}); + auto w = NDArrayFactory::create( + 'c', {3, 4}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f, 11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = + NDArrayFactory::create('c', {4}, {100.f, 200.f, 100.f, 200.f}); + + NDArray dLdz('c', {6, 4}, DataType::FLOAT32); + dLdz.linspace(.1, .5); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({&x, &w, &b, &dLdz}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create( + 'c', {6, 3}, + {15.3f, 18.700001f, 13.2f, 61.299995f, 62.699997f, 47.200001f, + 107.299995f, 106.699997f, 81.199997f, 153.299988f, 150.699997f, + 115.199997f, 199.300018f, 194.700012f, 149.199997f, 245.300018f, + 238.700012f, 183.199997f}); + auto edLdw = NDArrayFactory::create( + 'c', {3, 4}, + {268.5f, 291.f, 313.5f, 336.f, 226.800003f, 250.800003f, 274.799988f, + 298.799988f, 146.699997f, 160.199997f, 173.700012f, 187.200012f}); + auto edLdb = NDArrayFactory::create( + 'c', {4}, {30.6f, 33.599998f, 36.599998f, 39.599998f}); + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, XWPlusB_Bp_3) { - - auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); - auto w = NDArrayFactory::create('c', { 2, 3 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create({ 100.f, 200.f, 300.f }); - - auto dLdz = NDArrayFactory::create('c', { 1, 3 }, { 166.f, 269.f, 326.f }); - - sd::ops::xw_plus_b_bp op; - auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto dLdx = result.at(0); - auto dLdw = result.at(1); - auto dLdb = result.at(2); - - auto edLdx = NDArrayFactory::create('c', { 1,2 }, { 3937.f, 3096.f }); - auto edLdw = NDArrayFactory::create('c', { 2,3 }, { 166.f, 269.f, 326.f, 1826.f, 2959.f, 3586.f }); - auto edLdb = NDArrayFactory::create('c', { 3 }, { 166.f, 269.f, 326.f }); - ASSERT_TRUE(edLdx.isSameShape(dLdx)); - ASSERT_TRUE(edLdw.isSameShape(dLdw)); - ASSERT_TRUE(edLdb.isSameShape(dLdb)); - ASSERT_TRUE(edLdx.equalsTo(dLdx)); - ASSERT_TRUE(edLdw.equalsTo(dLdw)); - ASSERT_TRUE(edLdb.equalsTo(dLdb)); - + auto x = NDArrayFactory::create('c', {1, 2}, {1.f, 11.f}); + auto w = NDArrayFactory::create('c', {2, 3}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = NDArrayFactory::create({100.f, 200.f, 300.f}); + + auto dLdz = NDArrayFactory::create('c', {1, 3}, {166.f, 269.f, 326.f}); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({&x, &w, &b, &dLdz}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', {1, 2}, {3937.f, 3096.f}); + auto edLdw = NDArrayFactory::create( + 'c', {2, 3}, {166.f, 269.f, 326.f, 1826.f, 2959.f, 3586.f}); + auto edLdb = NDArrayFactory::create('c', {3}, {166.f, 269.f, 326.f}); + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, XWPlusB_Bp_4) { - - auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); - auto w = NDArrayFactory::create('c', { 2, 1 }, { 11.f, 3.f }); - auto b = NDArrayFactory::create('c', { 1 }, { 200.f }); - - auto dLdz = NDArrayFactory::create('c', { 1,1 }, { 244.f }); - - sd::ops::xw_plus_b_bp op; - auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto dLdx = result.at(0); - auto dLdw = result.at(1); - auto dLdb = result.at(2); - - auto edLdx = NDArrayFactory::create('c', { 1,2 }, { 2684.f, 732.f }); - auto edLdw = NDArrayFactory::create('c', { 2,1 }, { 244.f, 2684.f }); - auto edLdb = NDArrayFactory::create('c', { 1 }, { 244.f }); - ASSERT_TRUE(edLdx.isSameShape(dLdx)); - ASSERT_TRUE(edLdw.isSameShape(dLdw)); - ASSERT_TRUE(edLdb.isSameShape(dLdb)); - ASSERT_TRUE(edLdx.equalsTo(dLdx)); - ASSERT_TRUE(edLdw.equalsTo(dLdw)); - ASSERT_TRUE(edLdb.equalsTo(dLdb)); - + auto x = NDArrayFactory::create('c', {1, 2}, {1.f, 11.f}); + auto w = NDArrayFactory::create('c', {2, 1}, {11.f, 3.f}); + auto b = NDArrayFactory::create('c', {1}, {200.f}); + + auto dLdz = NDArrayFactory::create('c', {1, 1}, {244.f}); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({&x, &w, &b, &dLdz}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', {1, 2}, {2684.f, 732.f}); + auto edLdw = NDArrayFactory::create('c', {2, 1}, {244.f, 2684.f}); + auto edLdb = NDArrayFactory::create('c', {1}, {244.f}); + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, XWPlusB_Bp_5) { - - auto x = NDArrayFactory::create('f', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto w = NDArrayFactory::create('f', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create({ 100.f, 200.f }); - - auto dLdz = NDArrayFactory::create('f', { 2,2 }, { 140.f, 287.f, 233.f, 351.f }); - - sd::ops::xw_plus_b_bp op; - auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto dLdx = result.at(0); - auto dLdw = result.at(1); - auto dLdb = result.at(2); - - auto edLdxC = NDArrayFactory::create('c', { 2,3 }, { 2705.f, 1818.f, 1026.f, 4912.f, 2967.f, 1850.f }); - auto edLdwC = NDArrayFactory::create('c', { 3,2 }, { 3297.f, 4094.f, 4438.f, 5613.f, 2422.f, 3271.f }); - auto edLdbC = NDArrayFactory::create('c', { 2 }, { 427.f, 584.f }); - - auto edLdx = NDArrayFactory::create('f', { 2,3 }); - auto edLdw = NDArrayFactory::create('f', { 3,2 }); - auto edLdb = NDArrayFactory::create('f', { 2 }); - - edLdx.assign(edLdxC); - edLdw.assign(edLdwC); - edLdb.assign(edLdbC); - - ASSERT_TRUE(edLdx.isSameShape(dLdx)); - ASSERT_TRUE(edLdw.isSameShape(dLdw)); - ASSERT_TRUE(edLdb.isSameShape(dLdb)); - ASSERT_TRUE(edLdx.equalsTo(dLdx)); - ASSERT_TRUE(edLdw.equalsTo(dLdw)); - ASSERT_TRUE(edLdb.equalsTo(dLdb)); - + auto x = NDArrayFactory::create('f', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto w = NDArrayFactory::create('f', {3, 2}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = NDArrayFactory::create({100.f, 200.f}); + + auto dLdz = + NDArrayFactory::create('f', {2, 2}, {140.f, 287.f, 233.f, 351.f}); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({&x, &w, &b, &dLdz}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdxC = NDArrayFactory::create( + 'c', {2, 3}, {2705.f, 1818.f, 1026.f, 4912.f, 2967.f, 1850.f}); + auto edLdwC = NDArrayFactory::create( + 'c', {3, 2}, {3297.f, 4094.f, 4438.f, 5613.f, 2422.f, 3271.f}); + auto edLdbC = NDArrayFactory::create('c', {2}, {427.f, 584.f}); + + auto edLdx = NDArrayFactory::create('f', {2, 3}); + auto edLdw = NDArrayFactory::create('f', {3, 2}); + auto edLdb = NDArrayFactory::create('f', {2}); + + edLdx.assign(edLdxC); + edLdw.assign(edLdwC); + edLdb.assign(edLdbC); + + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, XWPlusB_Bp_6) { - - auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto w = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create({ 100.f, 200.f }); - - auto dLdz = NDArrayFactory::create('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f }); - - // mkl-format - w.permutei({ 1,0 }); - - sd::ops::xw_plus_b_bp op; - auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, { 1 }); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto dLdx = result.at(0); - auto dLdw = result.at(1); - auto dLdb = result.at(2); - - auto edLdx = NDArrayFactory::create('c', { 2,3 }, { 2695.f, 2012.f, 1566.f, 4247.f, 2635.f, 2418.f }); - auto edLdwC = NDArrayFactory::create('c', { 3,2 }, { 4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f }); - auto edLdb = NDArrayFactory::create('c', { 2 }, { 483.f, 543.f }); - auto edLdw = NDArrayFactory::create('c', { 3,2 }, { 4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f }); - edLdw.permutei({ 1,0 }); - edLdw.assign(edLdwC); - - ASSERT_TRUE(edLdx.isSameShape(dLdx)); - ASSERT_TRUE(edLdw.isSameShape(dLdw)); - ASSERT_TRUE(edLdb.isSameShape(dLdb)); - ASSERT_TRUE(edLdx.equalsTo(dLdx)); - ASSERT_TRUE(edLdw.equalsTo(dLdw)); - ASSERT_TRUE(edLdb.equalsTo(dLdb)); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto w = NDArrayFactory::create('c', {3, 2}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = NDArrayFactory::create({100.f, 200.f}); + + auto dLdz = + NDArrayFactory::create('c', {2, 2}, {173.f, 264.f, 310.f, 279.f}); + + // mkl-format + w.permutei({1, 0}); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({&x, &w, &b, &dLdz}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create( + 'c', {2, 3}, {2695.f, 2012.f, 1566.f, 4247.f, 2635.f, 2418.f}); + auto edLdwC = NDArrayFactory::create( + 'c', {3, 2}, {4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f}); + auto edLdb = NDArrayFactory::create('c', {2}, {483.f, 543.f}); + auto edLdw = NDArrayFactory::create( + 'c', {3, 2}, {4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f}); + edLdw.permutei({1, 0}); + edLdw.assign(edLdwC); + + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterSgd1) { + NDArray gradient('c', {1, 5}, + {0.21138794720172882, 0.38947954773902893, + 0.2822134494781494, 0.4342866837978363, 0.7928546667098999}, + DataType::FLOAT32); + auto lr = NDArrayFactory::create(0.001f); - NDArray gradient('c', { 1, 5 }, { 0.21138794720172882, 0.38947954773902893, 0.2822134494781494, 0.4342866837978363, 0.7928546667098999 }, DataType::FLOAT32); - auto lr = NDArrayFactory::create(0.001f); - - NDArray update('c', { 1, 5 }, { 0.00021138794720173, 0.00038947954773903, 0.00028221344947815, 0.00043428668379784, 0.0007928546667099 }, DataType::FLOAT32); - - sd::ops::sgd_updater op; + NDArray update('c', {1, 5}, + {0.00021138794720173, 0.00038947954773903, 0.00028221344947815, + 0.00043428668379784, 0.0007928546667099}, + DataType::FLOAT32); - Nd4jStatus status = op.execute({ &gradient, &lr }, { &gradient }, {}, { }); + sd::ops::sgd_updater op; - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(update.equalsTo(gradient)); + Nd4jStatus status = op.execute({&gradient, &lr}, {&gradient}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(update.equalsTo(gradient)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterSgd2) { + NDArray gradient('c', {1, 5}, + {0.21138794720172882, 0.38947954773902893, + 0.2822134494781494, 0.4342866837978363, 0.7928546667098999}, + DataType::FLOAT32); - NDArray gradient('c', { 1, 5 }, { 0.21138794720172882, 0.38947954773902893, 0.2822134494781494, 0.4342866837978363, 0.7928546667098999 }, DataType::FLOAT32); + NDArray update('c', {1, 5}, + {0.00021138794720173, 0.00038947954773903, 0.00028221344947815, + 0.00043428668379784, 0.0007928546667099}, + DataType::FLOAT32); - NDArray update('c', { 1, 5 }, { 0.00021138794720173, 0.00038947954773903, 0.00028221344947815, 0.00043428668379784, 0.0007928546667099 }, DataType::FLOAT32); - - sd::ops::sgd_updater op; - - Nd4jStatus status = op.execute({ &gradient }, { &gradient }, { 0.001f }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(update.equalsTo(gradient)); + sd::ops::sgd_updater op; + Nd4jStatus status = op.execute({&gradient}, {&gradient}, {0.001f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(update.equalsTo(gradient)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterSgd3) { - - NDArray gradientC('c', { 1, 5 }, { 0.21138794720172882, 0.38947954773902893, 0.2822134494781494, 0.4342866837978363, 0.7928546667098999 }, DataType::FLOAT32); - - NDArray updateC('c', { 1, 5 }, { 0.00021138794720173, 0.00038947954773903, 0.00028221344947815, 0.00043428668379784, 0.0007928546667099 }, DataType::FLOAT32); - - NDArray gradient('f', { 1, 5 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - gradient.assign(gradientC); - update.assign(updateC); - - sd::ops::sgd_updater op; - - auto results = op.evaluate({ &gradient }, { 0.001f }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); + NDArray gradientC( + 'c', {1, 5}, + {0.21138794720172882, 0.38947954773902893, 0.2822134494781494, + 0.4342866837978363, 0.7928546667098999}, + DataType::FLOAT32); + + NDArray updateC( + 'c', {1, 5}, + {0.00021138794720173, 0.00038947954773903, 0.00028221344947815, + 0.00043428668379784, 0.0007928546667099}, + DataType::FLOAT32); + + NDArray gradient('f', {1, 5}, DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + gradient.assign(gradientC); + update.assign(updateC); + + sd::ops::sgd_updater op; + + auto results = op.evaluate({&gradient}, {0.001f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterRmsProm1) { - - NDArray grad0('c', { 1, 5 }, { 0.1811431348323822, 0.10499879717826843, 0.8736756443977356, 0.9707390666007996, 0.7415646314620972 }, DataType::FLOAT32); - NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); - - auto lr = NDArrayFactory::create(0.1f); - auto decay = NDArrayFactory::create(0.95f); - auto epsilon = NDArrayFactory::create(1.e-8f); - - sd::ops::rms_prop_updater op; - - Nd4jStatus status = op.execute({ &grad0, &init, &lr, &decay, &epsilon }, { &grad0, &init }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp0('c', { 1, 5 }, { 0.4472121903197142, 0.4472095514452829, 0.4472135169488324, 0.44721352981195367, 0.44721349127249754 }, DataType::FLOAT32); - NDArray stateG0('c', { 1, 5 }, { 0.00164065126484513, 0.00055124687044416, 0.03816546608068996, 0.04711672627124962, 0.02749591463177582 }, DataType::FLOAT32); - - ASSERT_TRUE(grad0.equalsTo(updateExp0)); - ASSERT_TRUE(init.equalsTo(stateG0)); - - - NDArray grad1('c', { 1, 5 }, { 0.0139725673943758, 0.19333727657794952, 0.9288347363471985, 0.9253600239753723, 0.3578299283981323 }, DataType::FLOAT32); - status = op.execute({ &grad1, &init, &lr, &decay, &epsilon }, { &grad1, &init }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp1('c', { 1, 5 }, { 0.03528177364993147, 0.3952537075263024, 0.32964378302079766, 0.31269398966616074, 0.1984174163852542 }, DataType::FLOAT32); - NDArray stateG1('c', { 1, 5 }, { 0.00156838033358239, 0.00239264965265088, 0.07939389114891399, 0.08757544865627226, 0.03252323178305766 }, DataType::FLOAT32); - - ASSERT_TRUE(grad1.equalsTo(updateExp1)); - ASSERT_TRUE(init.equalsTo(stateG1)); - - NDArray grad2('c', { 1, 5 }, { 0.5442887544631958, 0.5386605262756348, 0.884294331073761, 0.15599730610847473, 0.7259345054626465 }, DataType::FLOAT32); - status = op.execute({ &grad2, &init, &lr, &decay, &epsilon }, { &grad2, &init }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp2('c', { 1, 5 }, { 0.4262874753567082, 0.41582357367557454, 0.2613066321005825, 0.05369221235564697, 0.3034061716240995 }, DataType::FLOAT32); - NDArray stateG2('c', { 1, 5 }, { 0.01630247372865814, 0.01678077529839554, 0.11452301978992785, 0.0844134341991137, 0.05724611550496966 }, DataType::FLOAT32); - - ASSERT_TRUE(grad2.equalsTo(updateExp2)); - ASSERT_TRUE(init.equalsTo(stateG2)); - + NDArray grad0('c', {1, 5}, + {0.1811431348323822, 0.10499879717826843, 0.8736756443977356, + 0.9707390666007996, 0.7415646314620972}, + DataType::FLOAT32); + NDArray init('c', {1, 5}, + {0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001}, + DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.1f); + auto decay = NDArrayFactory::create(0.95f); + auto epsilon = NDArrayFactory::create(1.e-8f); + + sd::ops::rms_prop_updater op; + + Nd4jStatus status = op.execute({&grad0, &init, &lr, &decay, &epsilon}, + {&grad0, &init}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp0( + 'c', {1, 5}, + {0.4472121903197142, 0.4472095514452829, 0.4472135169488324, + 0.44721352981195367, 0.44721349127249754}, + DataType::FLOAT32); + NDArray stateG0( + 'c', {1, 5}, + {0.00164065126484513, 0.00055124687044416, 0.03816546608068996, + 0.04711672627124962, 0.02749591463177582}, + DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateG0)); + + NDArray grad1('c', {1, 5}, + {0.0139725673943758, 0.19333727657794952, 0.9288347363471985, + 0.9253600239753723, 0.3578299283981323}, + DataType::FLOAT32); + status = op.execute({&grad1, &init, &lr, &decay, &epsilon}, {&grad1, &init}, + {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1( + 'c', {1, 5}, + {0.03528177364993147, 0.3952537075263024, 0.32964378302079766, + 0.31269398966616074, 0.1984174163852542}, + DataType::FLOAT32); + NDArray stateG1( + 'c', {1, 5}, + {0.00156838033358239, 0.00239264965265088, 0.07939389114891399, + 0.08757544865627226, 0.03252323178305766}, + DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateG1)); + + NDArray grad2('c', {1, 5}, + {0.5442887544631958, 0.5386605262756348, 0.884294331073761, + 0.15599730610847473, 0.7259345054626465}, + DataType::FLOAT32); + status = op.execute({&grad2, &init, &lr, &decay, &epsilon}, {&grad2, &init}, + {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2( + 'c', {1, 5}, + {0.4262874753567082, 0.41582357367557454, 0.2613066321005825, + 0.05369221235564697, 0.3034061716240995}, + DataType::FLOAT32); + NDArray stateG2( + 'c', {1, 5}, + {0.01630247372865814, 0.01678077529839554, 0.11452301978992785, + 0.0844134341991137, 0.05724611550496966}, + DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateG2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterRmsProm2) { - - NDArray grad('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); - - NDArray update('c', { 1, 5 }, DataType::FLOAT32); - - sd::ops::rms_prop_updater op; - - Nd4jStatus status = op.execute({ &grad, &init }, { &update, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp0('c', { 1, 5 }, { 0.4472135330146769, 0.44721357487863594, 0.44721358411270346, 0.4472135878446271, 0.447213589800546 }, DataType::FLOAT32); - NDArray stateG0('c', { 1, 5 }, { 0.05000000950000005, 0.2000000095000002, 0.4500000095000004, 0.8000000095000007, 1.250000009500001 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp0)); - ASSERT_TRUE(init.equalsTo(stateG0)); - - status = op.execute({ &grad, &init }, { &update, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp1('c', { 1, 5 }, { 0.32025628253164734, 0.3202562987764395, 0.32025630254446874, 0.3202563041196892, 0.3202563049660074 }, DataType::FLOAT32); - NDArray stateG1('c', { 1, 5 }, { 0.09750000902500008, 0.3900000090250003, 0.8775000090250007, 1.5600000090250012, 2.437500009025002 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp1)); - ASSERT_TRUE(init.equalsTo(stateG1)); - - status = op.execute({ &grad, &init }, { &update, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp2('c', { 1, 5 }, { 0.2647903457769699, 0.2647903552517623, 0.26479035752571606, 0.2647903584968847, 0.2647903590265272 }, DataType::FLOAT32); - NDArray stateG2('c', { 1, 5 }, { 0.1426250085737501, 0.5705000085737504, 1.283625008573751, 2.2820000085737515, 3.565625008573753 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp2)); - ASSERT_TRUE(init.equalsTo(stateG2)); - + NDArray grad('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray init('c', {1, 5}, + {0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001}, + DataType::FLOAT32); + + NDArray update('c', {1, 5}, DataType::FLOAT32); + + sd::ops::rms_prop_updater op; + + Nd4jStatus status = + op.execute({&grad, &init}, {&update, &init}, {0.1f, 0.95f, 1.e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp0( + 'c', {1, 5}, + {0.4472135330146769, 0.44721357487863594, 0.44721358411270346, + 0.4472135878446271, 0.447213589800546}, + DataType::FLOAT32); + NDArray stateG0('c', {1, 5}, + {0.05000000950000005, 0.2000000095000002, 0.4500000095000004, + 0.8000000095000007, 1.250000009500001}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateG0)); + + status = + op.execute({&grad, &init}, {&update, &init}, {0.1f, 0.95f, 1.e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1( + 'c', {1, 5}, + {0.32025628253164734, 0.3202562987764395, 0.32025630254446874, + 0.3202563041196892, 0.3202563049660074}, + DataType::FLOAT32); + NDArray stateG1('c', {1, 5}, + {0.09750000902500008, 0.3900000090250003, 0.8775000090250007, + 1.5600000090250012, 2.437500009025002}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateG1)); + + status = + op.execute({&grad, &init}, {&update, &init}, {0.1f, 0.95f, 1.e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2( + 'c', {1, 5}, + {0.2647903457769699, 0.2647903552517623, 0.26479035752571606, + 0.2647903584968847, 0.2647903590265272}, + DataType::FLOAT32); + NDArray stateG2('c', {1, 5}, + {0.1426250085737501, 0.5705000085737504, 1.283625008573751, + 2.2820000085737515, 3.565625008573753}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateG2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterRmsProm3) { - - NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray initC('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); - - NDArray grad('f', { 1, 5 }, DataType::FLOAT32); - NDArray init('f', { 1, 5 }, DataType::FLOAT32); - grad.assign(gradC); - init.assign(initC); - - sd::ops::rms_prop_updater op; - auto results = op.evaluate({ &grad, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); - - NDArray updateC('c', { 1, 5 }, { 0.4472135330146769, 0.44721357487863594, 0.44721358411270346, 0.4472135878446271, 0.447213589800546 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateG0C('c', { 1, 5 }, { 0.05000000950000005, 0.2000000095000002, 0.4500000095000004, 0.8000000095000007, 1.250000009500001 }, DataType::FLOAT32); - NDArray stateG('f', { 1, 5 }, DataType::FLOAT32); - - update.assign(updateC); - stateG.assign(stateG0C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateG.isSameShape(results.at(1))); - ASSERT_TRUE(stateG.equalsTo(results.at(1))); - - results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.95f, 1.e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update1C('c', { 1, 5 }, { 0.32025628253164734, 0.3202562987764395, 0.32025630254446874, 0.3202563041196892, 0.3202563049660074 }, DataType::FLOAT32); - NDArray stateG1C('c', { 1, 5 }, { 0.09750000902500008, 0.3900000090250003, 0.8775000090250007, 1.5600000090250012, 2.437500009025002 }, DataType::FLOAT32); - - update.assign(update1C); - stateG.assign(stateG1C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateG.isSameShape(results.at(1))); - ASSERT_TRUE(stateG.equalsTo(results.at(1))); - - - results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.95f, 1.e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update2C('c', { 1, 5 }, { 0.2647903457769699, 0.2647903552517623, 0.26479035752571606, 0.2647903584968847, 0.2647903590265272 }, DataType::FLOAT32); - NDArray stateG2C('c', { 1, 5 }, { 0.1426250085737501, 0.5705000085737504, 1.283625008573751, 2.2820000085737515, 3.565625008573753 }, DataType::FLOAT32); - - update.assign(update2C); - stateG.assign(stateG2C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateG.isSameShape(results.at(1))); - ASSERT_TRUE(stateG.equalsTo(results.at(1))); + NDArray gradC('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initC('c', {1, 5}, + {0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001}, + DataType::FLOAT32); + + NDArray grad('f', {1, 5}, DataType::FLOAT32); + NDArray init('f', {1, 5}, DataType::FLOAT32); + grad.assign(gradC); + init.assign(initC); + + sd::ops::rms_prop_updater op; + auto results = op.evaluate({&grad, &init}, {0.1f, 0.95f, 1.e-8}, {}); + + NDArray updateC('c', {1, 5}, + {0.4472135330146769, 0.44721357487863594, 0.44721358411270346, + 0.4472135878446271, 0.447213589800546}, + DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + NDArray stateG0C('c', {1, 5}, + {0.05000000950000005, 0.2000000095000002, 0.4500000095000004, + 0.8000000095000007, 1.250000009500001}, + DataType::FLOAT32); + NDArray stateG('f', {1, 5}, DataType::FLOAT32); + + update.assign(updateC); + stateG.assign(stateG0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + results = op.evaluate({&grad, &stateG}, {0.1f, 0.95f, 1.e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C( + 'c', {1, 5}, + {0.32025628253164734, 0.3202562987764395, 0.32025630254446874, + 0.3202563041196892, 0.3202563049660074}, + DataType::FLOAT32); + NDArray stateG1C('c', {1, 5}, + {0.09750000902500008, 0.3900000090250003, 0.8775000090250007, + 1.5600000090250012, 2.437500009025002}, + DataType::FLOAT32); + + update.assign(update1C); + stateG.assign(stateG1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + results = op.evaluate({&grad, &stateG}, {0.1f, 0.95f, 1.e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update2C('c', {1, 5}, + {0.2647903457769699, 0.2647903552517623, 0.26479035752571606, + 0.2647903584968847, 0.2647903590265272}, + DataType::FLOAT32); + NDArray stateG2C('c', {1, 5}, + {0.1426250085737501, 0.5705000085737504, 1.283625008573751, + 2.2820000085737515, 3.565625008573753}, + DataType::FLOAT32); + + update.assign(update2C); + stateG.assign(stateG2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdaGrad1) { + // need Java test - // need Java test + NDArray grad0('c', {1, 5}, + {0.1811431348323822, 0.10499879717826843, 0.8736756443977356, + 0.9707390666007996, 0.7415646314620972}, + DataType::FLOAT32); + NDArray init('c', {1, 5}, + {0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001}, + DataType::FLOAT32); - NDArray grad0('c', { 1, 5 }, { 0.1811431348323822, 0.10499879717826843, 0.8736756443977356, 0.9707390666007996, 0.7415646314620972 }, DataType::FLOAT32); - NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + auto lr = NDArrayFactory::create(0.1f); + auto epsilon = NDArrayFactory::create(1.e-8f); - auto lr = NDArrayFactory::create(0.1f); - auto epsilon = NDArrayFactory::create(1.e-8f); - - sd::ops::ada_grad_updater op; - - Nd4jStatus status = op.execute({ &grad0, &init, &lr, &epsilon }, { &grad0, &init }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); + sd::ops::ada_grad_updater op; + Nd4jStatus status = + op.execute({&grad0, &init, &lr, &epsilon}, {&grad0, &init}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterNesterovs1) { - - NDArray grad0('c', { 1, 5 }, { 0.6877592206001282, 0.7830561399459839, 0.7647699117660522, 0.6183066964149475, 0.3303879499435425 }, DataType::FLOAT32); - NDArray init('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - sd::ops::nesterovs_updater op; - - Nd4jStatus status = op.execute({ &grad0, &init }, { &grad0, &init }, { 0.1f, 0.9f }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.13067425191402435, 0.14878066658973696, 0.14530628323554992, 0.11747827231884002, 0.06277371048927306 }, DataType::FLOAT32); - NDArray stateV0('c', { 1, 5 }, { -0.06877592206001282, -0.0783056139945984, -0.07647699117660522, -0.06183066964149475, -0.03303879499435425 }, DataType::FLOAT32); - - ASSERT_TRUE(grad0.equalsTo(updateExp0)); - ASSERT_TRUE(init.equalsTo(stateV0)); - - NDArray grad1('c', { 1, 5 }, { 0.3676236569881439, 0.07645636051893234, 0.45949840545654297, 0.6335387825965881, 0.2953402101993561 }, DataType::FLOAT32); - status = op.execute({ &grad1, &init }, { &grad1, &init }, { 0.1f, 0.9f }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp1('c', { 1, 5 }, { 0.12555699169635773, 0.07795425583422186, 0.14925105988979342, 0.17045521110296247, 0.08287606388330458 }, DataType::FLOAT32); - NDArray stateV1('c', { 1, 5 }, { -0.09866069555282593, -0.0781206886470318, -0.11477913260459902, -0.11900148093700408, -0.05926893651485443 }, DataType::FLOAT32); - - ASSERT_TRUE(grad1.equalsTo(updateExp1)); - ASSERT_TRUE(init.equalsTo(stateV1)); - - NDArray grad2('c', { 1, 5 }, { 0.9874004125595093, 0.41817641258239746, 0.16838215291500092, 0.00803728867322206, 0.37015461921691895 }, DataType::FLOAT32); - status = op.execute({ &grad2, &init }, { &grad2, &init }, { 0.1f, 0.9f }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp2('c', { 1, 5 }, { 0.26752124178409575, 0.1427312761947513, 0.12496370646357537, 0.09791828440688549, 0.11833721622824667 }, DataType::FLOAT32); - NDArray stateV2('c', { 1, 5 }, { -0.18753466725349427, -0.11212626104056837, -0.12013943463563921, -0.10790506171062587, -0.09035750478506088 }, DataType::FLOAT32); - - ASSERT_TRUE(grad2.equalsTo(updateExp2)); - ASSERT_TRUE(init.equalsTo(stateV2)); - + NDArray grad0('c', {1, 5}, + {0.6877592206001282, 0.7830561399459839, 0.7647699117660522, + 0.6183066964149475, 0.3303879499435425}, + DataType::FLOAT32); + NDArray init('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + sd::ops::nesterovs_updater op; + + Nd4jStatus status = + op.execute({&grad0, &init}, {&grad0, &init}, {0.1f, 0.9f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.13067425191402435, 0.14878066658973696, 0.14530628323554992, + 0.11747827231884002, 0.06277371048927306}, + DataType::FLOAT32); + NDArray stateV0( + 'c', {1, 5}, + {-0.06877592206001282, -0.0783056139945984, -0.07647699117660522, + -0.06183066964149475, -0.03303879499435425}, + DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateV0)); + + NDArray grad1('c', {1, 5}, + {0.3676236569881439, 0.07645636051893234, 0.45949840545654297, + 0.6335387825965881, 0.2953402101993561}, + DataType::FLOAT32); + status = op.execute({&grad1, &init}, {&grad1, &init}, {0.1f, 0.9f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1( + 'c', {1, 5}, + {0.12555699169635773, 0.07795425583422186, 0.14925105988979342, + 0.17045521110296247, 0.08287606388330458}, + DataType::FLOAT32); + NDArray stateV1( + 'c', {1, 5}, + {-0.09866069555282593, -0.0781206886470318, -0.11477913260459902, + -0.11900148093700408, -0.05926893651485443}, + DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateV1)); + + NDArray grad2('c', {1, 5}, + {0.9874004125595093, 0.41817641258239746, 0.16838215291500092, + 0.00803728867322206, 0.37015461921691895}, + DataType::FLOAT32); + status = op.execute({&grad2, &init}, {&grad2, &init}, {0.1f, 0.9f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2( + 'c', {1, 5}, + {0.26752124178409575, 0.1427312761947513, 0.12496370646357537, + 0.09791828440688549, 0.11833721622824667}, + DataType::FLOAT32); + NDArray stateV2( + 'c', {1, 5}, + {-0.18753466725349427, -0.11212626104056837, -0.12013943463563921, + -0.10790506171062587, -0.09035750478506088}, + DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateV2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterNesterovs2) { - - NDArray grad('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); - - NDArray update('c', { 1, 5 }, DataType::FLOAT32); - - auto lr = NDArrayFactory::create(0.1f); - auto momentum = NDArrayFactory::create(0.9f); - - sd::ops::nesterovs_updater op; - - Nd4jStatus status = op.execute({ &grad, &init, &lr, &momentum }, { &update, &init }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.19, 0.38, 0.5700000000000001, 0.76, 0.95 }, DataType::FLOAT32); - NDArray stateV0('c', { 1, 5 }, { -0.1, -0.2, -0.30000000000000004, -0.4, -0.5 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp0)); - ASSERT_TRUE(init.equalsTo(stateV0)); - - status = op.execute({ &grad, &init, &lr, &momentum }, { &update, &init }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp1('c', { 1, 5 }, { 0.27099999999999996, 0.5419999999999999, 0.813, 1.0839999999999999, 1.355 }, DataType::FLOAT32); - NDArray stateV1('c', { 1, 5 }, { -0.19, -0.38, -0.5700000000000001, -0.76, -0.95 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp1)); - ASSERT_TRUE(init.equalsTo(stateV1)); - - status = op.execute({ &grad, &init, &lr, &momentum }, { &update, &init }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - NDArray updateExp2('c', { 1, 5 }, { 0.3439, 0.6878, 1.0317, 1.3756, 1.7195 }, DataType::FLOAT32); - NDArray stateV2('c', { 1, 5 }, { -0.271, -0.542, -0.8130000000000002, -1.084, -1.355 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp2)); - ASSERT_TRUE(init.equalsTo(stateV2)); + NDArray grad('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray init('c', {1, 5}, + {0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001}, + DataType::FLOAT32); + + NDArray update('c', {1, 5}, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.1f); + auto momentum = NDArrayFactory::create(0.9f); + + sd::ops::nesterovs_updater op; + + Nd4jStatus status = + op.execute({&grad, &init, &lr, &momentum}, {&update, &init}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', {1, 5}, {0.19, 0.38, 0.5700000000000001, 0.76, 0.95}, + DataType::FLOAT32); + NDArray stateV0('c', {1, 5}, {-0.1, -0.2, -0.30000000000000004, -0.4, -0.5}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateV0)); + + status = op.execute({&grad, &init, &lr, &momentum}, {&update, &init}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1('c', {1, 5}, + {0.27099999999999996, 0.5419999999999999, 0.813, + 1.0839999999999999, 1.355}, + DataType::FLOAT32); + NDArray stateV1('c', {1, 5}, + {-0.19, -0.38, -0.5700000000000001, -0.76, -0.95}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateV1)); + + status = op.execute({&grad, &init, &lr, &momentum}, {&update, &init}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2('c', {1, 5}, {0.3439, 0.6878, 1.0317, 1.3756, 1.7195}, + DataType::FLOAT32); + NDArray stateV2('c', {1, 5}, + {-0.271, -0.542, -0.8130000000000002, -1.084, -1.355}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateV2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterNesterovs3) { - - NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray initC('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); - - NDArray grad('f', { 1, 5 }, DataType::FLOAT32); - NDArray init('f', { 1, 5 }, DataType::FLOAT32); - grad.assign(gradC); - init.assign(initC); - - sd::ops::nesterovs_updater op; - auto results = op.evaluate({ &grad, &init }, { 0.1f, 0.9f }, { }); - - NDArray updateC('c', { 1, 5 }, { 0.19, 0.38, 0.5700000000000001, 0.76, 0.95 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateG0C('c', { 1, 5 }, { -0.1, -0.2, -0.30000000000000004, -0.4, -0.5 }, DataType::FLOAT32); - NDArray stateG('f', { 1, 5 }, DataType::FLOAT32); - - update.assign(updateC); - stateG.assign(stateG0C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateG.isSameShape(results.at(1))); - ASSERT_TRUE(stateG.equalsTo(results.at(1))); - - results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.9f }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update1C('c', { 1, 5 }, { 0.27099999999999996, 0.5419999999999999, 0.813, 1.0839999999999999, 1.355 }, DataType::FLOAT32); - NDArray stateG1C('c', { 1, 5 }, { -0.19, -0.38, -0.5700000000000001, -0.76, -0.95 }, DataType::FLOAT32); - - update.assign(update1C); - stateG.assign(stateG1C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateG.isSameShape(results.at(1))); - ASSERT_TRUE(stateG.equalsTo(results.at(1))); - - - results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.9f }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update2C('c', { 1, 5 }, { 0.3439, 0.6878, 1.0317, 1.3756, 1.7195 }, DataType::FLOAT32); - NDArray stateG2C('c', { 1, 5 }, { -0.271, -0.542, -0.8130000000000002, -1.084, -1.355 }, DataType::FLOAT32); - - update.assign(update2C); - stateG.assign(stateG2C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateG.isSameShape(results.at(1))); - ASSERT_TRUE(stateG.equalsTo(results.at(1))); + NDArray gradC('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initC('c', {1, 5}, + {0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001}, + DataType::FLOAT32); + + NDArray grad('f', {1, 5}, DataType::FLOAT32); + NDArray init('f', {1, 5}, DataType::FLOAT32); + grad.assign(gradC); + init.assign(initC); + + sd::ops::nesterovs_updater op; + auto results = op.evaluate({&grad, &init}, {0.1f, 0.9f}, {}); + + NDArray updateC('c', {1, 5}, {0.19, 0.38, 0.5700000000000001, 0.76, 0.95}, + DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + NDArray stateG0C('c', {1, 5}, {-0.1, -0.2, -0.30000000000000004, -0.4, -0.5}, + DataType::FLOAT32); + NDArray stateG('f', {1, 5}, DataType::FLOAT32); + + update.assign(updateC); + stateG.assign(stateG0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + results = op.evaluate({&grad, &stateG}, {0.1f, 0.9f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', {1, 5}, + {0.27099999999999996, 0.5419999999999999, 0.813, + 1.0839999999999999, 1.355}, + DataType::FLOAT32); + NDArray stateG1C('c', {1, 5}, + {-0.19, -0.38, -0.5700000000000001, -0.76, -0.95}, + DataType::FLOAT32); + + update.assign(update1C); + stateG.assign(stateG1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + results = op.evaluate({&grad, &stateG}, {0.1f, 0.9f}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update2C('c', {1, 5}, {0.3439, 0.6878, 1.0317, 1.3756, 1.7195}, + DataType::FLOAT32); + NDArray stateG2C('c', {1, 5}, + {-0.271, -0.542, -0.8130000000000002, -1.084, -1.355}, + DataType::FLOAT32); + + update.assign(update2C); + stateG.assign(stateG2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdaMax1) { - - NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); - NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray update('c', { 1, 5 }, DataType::FLOAT32); - - sd::ops::ada_max_updater op; - - Nd4jStatus status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.001, 0.001, 0.001, 0.001, 0.001 }, DataType::FLOAT32); - NDArray stateU('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp0)); - ASSERT_TRUE(initU.equalsTo(stateU)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.0019, 0.0019, 0.0019, 0.0019, 0.0019 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - ASSERT_TRUE(update.equalsTo(updateExp1)); - ASSERT_TRUE(initU.equalsTo(stateU)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00271, 0.00271, 0.00271, 0.00271, 0.00271 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp2)); - ASSERT_TRUE(initU.equalsTo(stateU)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initU('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray update('c', {1, 5}, DataType::FLOAT32); + + sd::ops::ada_max_updater op; + + Nd4jStatus status = + op.execute({&grad, &initU, &initM}, {&update, &initU, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', {1, 5}, {0.001, 0.001, 0.001, 0.001, 0.001}, + DataType::FLOAT32); + NDArray stateU('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray stateM0('c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, + 0.29999999999999993, 0.3999999999999999, 0.4999999999999999}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateU)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({&grad, &initU, &initM}, {&update, &initU, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', {1, 5}, {0.0019, 0.0019, 0.0019, 0.0019, 0.0019}, + DataType::FLOAT32); + NDArray stateM1('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateU)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({&grad, &initU, &initM}, {&update, &initU, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', {1, 5}, {0.00271, 0.00271, 0.00271, 0.00271, 0.00271}, + DataType::FLOAT32); + NDArray stateM2('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateU)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdaMax2) { - - NDArray grad0('c', { 1, 5 }, { 0.05387359112501144, 0.9700437784194946, 0.8912011384963989, 0.8891847729682922, 0.18823780119419098 }, DataType::FLOAT32); - NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - auto lr = NDArrayFactory::create(0.001f); - auto beta1 = NDArrayFactory::create(0.9f); - auto beta2 = NDArrayFactory::create(0.999f); - auto epsilon = NDArrayFactory::create(1.0e-8); - - sd::ops::ada_max_updater op; - - Nd4jStatus status = op.execute({ &grad0, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initU, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.001, 0.001, 0.001, 0.001, 0.001 }, DataType::FLOAT32); - NDArray stateU0('c', { 1, 5 }, { 0.05387359112501144, 0.9700437784194946, 0.8912011384963989, 0.8891847729682922, 0.18823780119419098 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.00538735911250114, 0.09700437784194944, 0.08912011384963987, 0.08891847729682921, 0.01882378011941909 }, DataType::FLOAT32); - - ASSERT_TRUE(grad0.equalsTo(updateExp0)); - ASSERT_TRUE(initU.equalsTo(stateU0)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - NDArray grad1('c', { 1, 5 }, { 0.6400517821311951, 0.3779360353946686, 0.35128724575042725, 0.6554615497589111, 0.8420050740242004 }, DataType::FLOAT32); - - status = op.execute({ &grad1, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initU, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.00107575360832691, 0.00129089809294599, 0.00129546826560191, 0.00163878765669416, 0.00120120308808246 }, DataType::FLOAT32); - NDArray stateU1('c', { 1, 5 }, { 0.6400517821311951, 0.9690737346410752, 0.8903099373579025, 0.888295588195324, 0.8420050740242004 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.06885380141437052, 0.12509754359722136, 0.11533682703971859, 0.1455727845430374, 0.10114190950989721 }, DataType::FLOAT32); - - ASSERT_TRUE(grad1.equalsTo(updateExp1)); - ASSERT_TRUE(initU.equalsTo(stateU1)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - NDArray grad2('c', { 1, 5 }, { 0.5984494686126709, 0.05978915095329285, 0.5749519467353821, 0.2804091274738312, 0.0192152876406908 }, DataType::FLOAT32); - - status = op.execute({ &grad2, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initU, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00190508497658779, 0.00122473022928962, 0.00181352349370876, 0.00179237223044249, 0.00110500865710834 }, DataType::FLOAT32); - NDArray stateU2('c', { 1, 5 }, { 0.6394117303490638, 0.9681046609064341, 0.8894196274205446, 0.8874072926071286, 0.8411630689501762 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.12181336813420054, 0.11856670433282851, 0.16129833900928492, 0.15905641883611676, 0.09294924732297657 }, DataType::FLOAT32); - - ASSERT_TRUE(grad2.equalsTo(updateExp2)); - ASSERT_TRUE(initU.equalsTo(stateU2)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad0('c', {1, 5}, + {0.05387359112501144, 0.9700437784194946, 0.8912011384963989, + 0.8891847729682922, 0.18823780119419098}, + DataType::FLOAT32); + NDArray initU('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::ada_max_updater op; + + Nd4jStatus status = + op.execute({&grad0, &initU, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad0, &initU, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', {1, 5}, {0.001, 0.001, 0.001, 0.001, 0.001}, + DataType::FLOAT32); + NDArray stateU0('c', {1, 5}, + {0.05387359112501144, 0.9700437784194946, 0.8912011384963989, + 0.8891847729682922, 0.18823780119419098}, + DataType::FLOAT32); + NDArray stateM0( + 'c', {1, 5}, + {0.00538735911250114, 0.09700437784194944, 0.08912011384963987, + 0.08891847729682921, 0.01882378011941909}, + DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateU0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', {1, 5}, + {0.6400517821311951, 0.3779360353946686, 0.35128724575042725, + 0.6554615497589111, 0.8420050740242004}, + DataType::FLOAT32); + + status = op.execute({&grad1, &initU, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad1, &initU, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.00107575360832691, 0.00129089809294599, 0.00129546826560191, + 0.00163878765669416, 0.00120120308808246}, + DataType::FLOAT32); + NDArray stateU1('c', {1, 5}, + {0.6400517821311951, 0.9690737346410752, 0.8903099373579025, + 0.888295588195324, 0.8420050740242004}, + DataType::FLOAT32); + NDArray stateM1( + 'c', {1, 5}, + {0.06885380141437052, 0.12509754359722136, 0.11533682703971859, + 0.1455727845430374, 0.10114190950989721}, + DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateU1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', {1, 5}, + {0.5984494686126709, 0.05978915095329285, 0.5749519467353821, + 0.2804091274738312, 0.0192152876406908}, + DataType::FLOAT32); + + status = op.execute({&grad2, &initU, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad2, &initU, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.00190508497658779, 0.00122473022928962, 0.00181352349370876, + 0.00179237223044249, 0.00110500865710834}, + DataType::FLOAT32); + NDArray stateU2('c', {1, 5}, + {0.6394117303490638, 0.9681046609064341, 0.8894196274205446, + 0.8874072926071286, 0.8411630689501762}, + DataType::FLOAT32); + NDArray stateM2( + 'c', {1, 5}, + {0.12181336813420054, 0.11856670433282851, 0.16129833900928492, + 0.15905641883611676, 0.09294924732297657}, + DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateU2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdaMax3) { - - NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray grad('f', { 1, 5 }, DataType::FLOAT32); - NDArray initV('f', { 1, 5 }, DataType::FLOAT32); - NDArray initM('f', { 1, 5 }, DataType::FLOAT32); - - grad.assign(gradC); - initV.assign(initVC); - initM.assign(initMC); - - sd::ops::ada_max_updater op; - auto results = op.evaluate({ &grad, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - - NDArray updateC('c', { 1, 5 }, { 0.001, 0.001, 0.001, 0.001, 0.001 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateV0C('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); - NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); - NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); - - update.assign(updateC); - stateV.assign(stateV0C); - stateM.assign(stateM0C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - - results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update1C('c', { 1, 5 }, { 0.0019, 0.0019, 0.0019, 0.0019, 0.0019 }, DataType::FLOAT32); - NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - - update.assign(update1C); - stateM.assign(stateM1C); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - - - results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update2C('c', { 1, 5 }, { 0.00271, 0.00271, 0.00271, 0.00271, 0.00271 }, DataType::FLOAT32); - NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - - update.assign(update2C); - stateM.assign(stateM2C); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); + NDArray gradC('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initVC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initMC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray grad('f', {1, 5}, DataType::FLOAT32); + NDArray initV('f', {1, 5}, DataType::FLOAT32); + NDArray initM('f', {1, 5}, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + + sd::ops::ada_max_updater op; + auto results = + op.evaluate({&grad, &initV, &initM}, {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + + NDArray updateC('c', {1, 5}, {0.001, 0.001, 0.001, 0.001, 0.001}, + DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + NDArray stateV0C('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray stateV('f', {1, 5}, DataType::FLOAT32); + + NDArray stateM0C( + 'c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, 0.29999999999999993, + 0.3999999999999999, 0.4999999999999999}, + DataType::FLOAT32); + NDArray stateM('f', {1, 5}, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({&grad, &stateV, &stateM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', {1, 5}, {0.0019, 0.0019, 0.0019, 0.0019, 0.0019}, + DataType::FLOAT32); + NDArray stateM1C('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + + update.assign(update1C); + stateM.assign(stateM1C); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({&grad, &stateV, &stateM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update2C('c', {1, 5}, {0.00271, 0.00271, 0.00271, 0.00271, 0.00271}, + DataType::FLOAT32); + NDArray stateM2C('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + + update.assign(update2C); + stateM.assign(stateM2C); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdam1) { - - NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); - NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray update('c', { 1, 5 }, DataType::FLOAT32); - - sd::ops::adam_updater op; - - Nd4jStatus status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); - NDArray stateV('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp0)); - ASSERT_TRUE(initU.equalsTo(stateV)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); - NDArray stateV1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp1)); - ASSERT_TRUE(initU.equalsTo(stateV1)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); - NDArray stateV2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp2)); - ASSERT_TRUE(initU.equalsTo(stateV2)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initU('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray update('c', {1, 5}, DataType::FLOAT32); + + sd::ops::adam_updater op; + + Nd4jStatus status = + op.execute({&grad, &initU, &initM}, {&update, &initU, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.00099999968377233, 0.00099999984188614, 0.00099999989459076, + 0.00099999992094306, 0.00099999993675445}, + DataType::FLOAT32); + NDArray stateV('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateM0('c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, + 0.29999999999999993, 0.3999999999999999, 0.4999999999999999}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateV)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({&grad, &initU, &initM}, {&update, &initU, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.00134383858541481, 0.00134383873569809, 0.00134383878579252, + 0.00134383881083974, 0.00134383882586807}, + DataType::FLOAT32); + NDArray stateV1('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + NDArray stateM1('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({&grad, &initU, &initM}, {&update, &initU, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.00156540157923389, 0.00156540172220632, 0.0015654017698638, + 0.00156540179369254, 0.00156540180798979}, + DataType::FLOAT32); + NDArray stateV2('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + NDArray stateM2('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdam2) { - - NDArray grad0('c', { 1, 5 }, { 0.7124611735343933, 0.7283763289451599, 0.8196553587913513, 0.9501070976257324, 0.2654055953025818 }, DataType::FLOAT32); - NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - auto lr = NDArrayFactory::create(0.001f); - auto beta1 = NDArrayFactory::create(0.9f); - auto beta2 = NDArrayFactory::create(0.999f); - auto epsilon = NDArrayFactory::create(1.0e-8); - - sd::ops::adam_updater op; - - Nd4jStatus status = op.execute({ &grad0, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initU, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.00099999955614757, 0.00099999956584582, 0.00099999961419438, 0.0009999996671663, 0.00099999880851273 }, DataType::FLOAT32); - NDArray stateU0('c', { 1, 5 }, { 0.00050760092379401, 0.00053053207656763, 0.00067183490719538, 0.00090270349695879, 0.00007044013001792 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.07124611735343932, 0.07283763289451597, 0.08196553587913512, 0.09501070976257323, 0.02654055953025817 }, DataType::FLOAT32); - - ASSERT_TRUE(grad0.equalsTo(updateExp0)); - ASSERT_TRUE(initU.equalsTo(stateU0)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - NDArray grad1('c', { 1, 5 }, { 0.4374369978904724, 0.11488933861255646, 0.6765823364257812, 0.7659900188446045, 0.04410457238554955 }, DataType::FLOAT32); - - status = op.execute({ &grad1, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initU, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.00129067017716555, 0.00104532555849556, 0.00133106720937621, 0.00132869584719374, 0.00105226561254395 }, DataType::FLOAT32); - NDArray stateU1('c', { 1, 5 }, { 0.00069844444999364, 0.00054320110461789, 0.00112892673025155, 0.00148854150243139, 0.00007231490319321 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.10786520540714262, 0.07704280346632002, 0.14142721593379973, 0.16210864067077635, 0.02829696081578731 }, DataType::FLOAT32); - - ASSERT_TRUE(grad1.equalsTo(updateExp1)); - ASSERT_TRUE(initU.equalsTo(stateU1)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - NDArray grad2('c', { 1, 5 }, { 0.496029257774353, 0.11621368676424026, 0.9112075567245483, 0.5717480182647705, 0.5975669026374817 }, DataType::FLOAT32); - - status = op.execute({ &grad2, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initU, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00150986322036664, 0.00108559662275258, 0.00156079502787382, 0.00150778241516558, 0.00130066803775601 }, DataType::FLOAT32); - NDArray stateU2('c', { 1, 5 }, { 0.00094379103011182, 0.00055616352450461, 0.00195809701495322, 0.00181394875731865, 0.00042932879141777 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.14668161064386365, 0.08095989179611204, 0.21840525001287456, 0.20307257843017573, 0.08522395499795674 }, DataType::FLOAT32); - - ASSERT_TRUE(grad2.equalsTo(updateExp2)); - ASSERT_TRUE(initU.equalsTo(stateU2)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad0('c', {1, 5}, + {0.7124611735343933, 0.7283763289451599, 0.8196553587913513, + 0.9501070976257324, 0.2654055953025818}, + DataType::FLOAT32); + NDArray initU('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::adam_updater op; + + Nd4jStatus status = + op.execute({&grad0, &initU, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad0, &initU, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.00099999955614757, 0.00099999956584582, 0.00099999961419438, + 0.0009999996671663, 0.00099999880851273}, + DataType::FLOAT32); + NDArray stateU0( + 'c', {1, 5}, + {0.00050760092379401, 0.00053053207656763, 0.00067183490719538, + 0.00090270349695879, 0.00007044013001792}, + DataType::FLOAT32); + NDArray stateM0( + 'c', {1, 5}, + {0.07124611735343932, 0.07283763289451597, 0.08196553587913512, + 0.09501070976257323, 0.02654055953025817}, + DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateU0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', {1, 5}, + {0.4374369978904724, 0.11488933861255646, 0.6765823364257812, + 0.7659900188446045, 0.04410457238554955}, + DataType::FLOAT32); + + status = op.execute({&grad1, &initU, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad1, &initU, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.00129067017716555, 0.00104532555849556, 0.00133106720937621, + 0.00132869584719374, 0.00105226561254395}, + DataType::FLOAT32); + NDArray stateU1( + 'c', {1, 5}, + {0.00069844444999364, 0.00054320110461789, 0.00112892673025155, + 0.00148854150243139, 0.00007231490319321}, + DataType::FLOAT32); + NDArray stateM1( + 'c', {1, 5}, + {0.10786520540714262, 0.07704280346632002, 0.14142721593379973, + 0.16210864067077635, 0.02829696081578731}, + DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateU1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', {1, 5}, + {0.496029257774353, 0.11621368676424026, 0.9112075567245483, + 0.5717480182647705, 0.5975669026374817}, + DataType::FLOAT32); + + status = op.execute({&grad2, &initU, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad2, &initU, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.00150986322036664, 0.00108559662275258, 0.00156079502787382, + 0.00150778241516558, 0.00130066803775601}, + DataType::FLOAT32); + NDArray stateU2( + 'c', {1, 5}, + {0.00094379103011182, 0.00055616352450461, 0.00195809701495322, + 0.00181394875731865, 0.00042932879141777}, + DataType::FLOAT32); + NDArray stateM2( + 'c', {1, 5}, + {0.14668161064386365, 0.08095989179611204, 0.21840525001287456, + 0.20307257843017573, 0.08522395499795674}, + DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateU2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdam3) { - - NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray grad('f', { 1, 5 }, DataType::FLOAT32); - NDArray initV('f', { 1, 5 }, DataType::FLOAT32); - NDArray initM('f', { 1, 5 }, DataType::FLOAT32); - - grad.assign(gradC); - initV.assign(initVC); - initM.assign(initMC); - - sd::ops::adam_updater op; - auto results = op.evaluate({ &grad, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - - NDArray updateC('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateV0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); - NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); - - update.assign(updateC); - stateV.assign(stateV0C); - stateM.assign(stateM0C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - - results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update1C('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); - NDArray stateV1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - - update.assign(update1C); - stateV.assign(stateV1C); - stateM.assign(stateM1C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - - results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - - NDArray update2C('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); - NDArray stateV2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - - update.assign(update2C); - stateV.assign(stateV2C); - stateM.assign(stateM2C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); + NDArray gradC('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initVC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initMC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray grad('f', {1, 5}, DataType::FLOAT32); + NDArray initV('f', {1, 5}, DataType::FLOAT32); + NDArray initM('f', {1, 5}, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + + sd::ops::adam_updater op; + auto results = + op.evaluate({&grad, &initV, &initM}, {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + + NDArray updateC( + 'c', {1, 5}, + {0.00099999968377233, 0.00099999984188614, 0.00099999989459076, + 0.00099999992094306, 0.00099999993675445}, + DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + NDArray stateV0C('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateV('f', {1, 5}, DataType::FLOAT32); + + NDArray stateM0C( + 'c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, 0.29999999999999993, + 0.3999999999999999, 0.4999999999999999}, + DataType::FLOAT32); + NDArray stateM('f', {1, 5}, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({&grad, &stateV, &stateM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C( + 'c', {1, 5}, + {0.00134383858541481, 0.00134383873569809, 0.00134383878579252, + 0.00134383881083974, 0.00134383882586807}, + DataType::FLOAT32); + NDArray stateV1C('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + NDArray stateM1C('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + + update.assign(update1C); + stateV.assign(stateV1C); + stateM.assign(stateM1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({&grad, &stateV, &stateM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + + NDArray update2C( + 'c', {1, 5}, + {0.00156540157923389, 0.00156540172220632, 0.0015654017698638, + 0.00156540179369254, 0.00156540180798979}, + DataType::FLOAT32); + NDArray stateV2C('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + NDArray stateM2C('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + + update.assign(update2C); + stateV.assign(stateV2C); + stateM.assign(stateM2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta1) { - - NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); - NDArray initMsg('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initMsdx('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray update('c', { 1, 5 }, DataType::FLOAT32); - - sd::ops::ada_delta_updater op; - - Nd4jStatus status = op.execute({ &grad, &initMsg, &initMsdx }, { &update, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.00447209123431084, 0.00447212477470162, 0.00447213098596791, 0.00447213315991723, 0.00447213416614627 }, DataType::FLOAT32); - NDArray stateMsg0('c', { 1, 5 }, { 0.05000000000000004, 0.20000000000000018, 0.4500000000000004, 0.8000000000000007, 1.250000000000001 }, DataType::FLOAT32); - NDArray stateMsdx0('c', { 1, 5 }, { 0.0000009999800004, 0.00000099999500002, 0.00000099999777778, 0.00000099999875, 0.0000009999992 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp0)); - ASSERT_TRUE(initMsg.equalsTo(stateMsg0)); - ASSERT_TRUE(initMsdx.equalsTo(stateMsdx0)); - - status = op.execute({ &grad, &initMsg, &initMsdx }, { &update, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.0045290622655332, 0.00452909666868751, 0.00452910303972733, 0.00452910526959756, 0.00452910630171004 }, DataType::FLOAT32); - NDArray stateMsg1('c', { 1, 5 }, { 0.09750000000000009, 0.39000000000000035, 0.8775000000000008, 1.5600000000000014, 2.4375000000000018 }, DataType::FLOAT32); - NDArray stateMsdx1('c', { 1, 5 }, { 0.00000197560125063, 0.00000197563108174, 0.00000197563660612, 0.00000197563853966, 0.00000197563943461 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp1)); - ASSERT_TRUE(initMsg.equalsTo(stateMsg1)); - ASSERT_TRUE(initMsdx.equalsTo(stateMsdx1)); - - status = op.execute({ &grad, &initMsg, &initMsdx }, { &update, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00456759948242601, 0.00456763438748812, 0.00456764085147516, 0.00456764311387702, 0.004567644161047 }, DataType::FLOAT32); - NDArray stateMsg2('c', { 1, 5 }, { 0.1426250000000001, 0.5705000000000005, 1.2836250000000011, 2.282000000000002, 3.5656250000000025 }, DataType::FLOAT32); - NDArray stateMsdx2('c', { 1, 5 }, { 0.0000029199694397, 0.00000292001372254, 0.00000292002192321, 0.00000292002479346, 0.00000292002612198 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp2)); - ASSERT_TRUE(initMsg.equalsTo(stateMsg2)); - ASSERT_TRUE(initMsdx.equalsTo(stateMsdx2)); + NDArray grad('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initMsg('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initMsdx('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray update('c', {1, 5}, DataType::FLOAT32); + + sd::ops::ada_delta_updater op; + + Nd4jStatus status = + op.execute({&grad, &initMsg, &initMsdx}, {&update, &initMsg, &initMsdx}, + {0.95f, 1.0e-6}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.00447209123431084, 0.00447212477470162, 0.00447213098596791, + 0.00447213315991723, 0.00447213416614627}, + DataType::FLOAT32); + NDArray stateMsg0('c', {1, 5}, + {0.05000000000000004, 0.20000000000000018, + 0.4500000000000004, 0.8000000000000007, 1.250000000000001}, + DataType::FLOAT32); + NDArray stateMsdx0('c', {1, 5}, + {0.0000009999800004, 0.00000099999500002, + 0.00000099999777778, 0.00000099999875, 0.0000009999992}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg0)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx0)); + + status = op.execute({&grad, &initMsg, &initMsdx}, + {&update, &initMsg, &initMsdx}, {0.95f, 1.0e-6}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.0045290622655332, 0.00452909666868751, 0.00452910303972733, + 0.00452910526959756, 0.00452910630171004}, + DataType::FLOAT32); + NDArray stateMsg1( + 'c', {1, 5}, + {0.09750000000000009, 0.39000000000000035, 0.8775000000000008, + 1.5600000000000014, 2.4375000000000018}, + DataType::FLOAT32); + NDArray stateMsdx1( + 'c', {1, 5}, + {0.00000197560125063, 0.00000197563108174, 0.00000197563660612, + 0.00000197563853966, 0.00000197563943461}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg1)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx1)); + + status = op.execute({&grad, &initMsg, &initMsdx}, + {&update, &initMsg, &initMsdx}, {0.95f, 1.0e-6}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.00456759948242601, 0.00456763438748812, 0.00456764085147516, + 0.00456764311387702, 0.004567644161047}, + DataType::FLOAT32); + NDArray stateMsg2('c', {1, 5}, + {0.1426250000000001, 0.5705000000000005, 1.2836250000000011, + 2.282000000000002, 3.5656250000000025}, + DataType::FLOAT32); + NDArray stateMsdx2( + 'c', {1, 5}, + {0.0000029199694397, 0.00000292001372254, 0.00000292002192321, + 0.00000292002479346, 0.00000292002612198}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg2)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta2) { - - NDArray grad0('c', { 1, 5 }, { 0.22060230374336243, 0.10593396425247192, 0.9027279019355774, 0.831809401512146, 0.2733047902584076 }, DataType::FLOAT32); - NDArray initMsg('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initMsdx('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - auto rho = NDArrayFactory::create(0.95f); - auto epsilon = NDArrayFactory::create(1.0e-6); - - sd::ops::ada_delta_updater op; - - Nd4jStatus status = op.execute({ &grad0, &initMsg, &initMsdx, &rho, &epsilon }, { &grad0, &initMsg, &initMsdx }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.0044712172817412, 0.00446815612502933, 0.00447208107763182, 0.004472071321461, 0.00447153735969189 }, DataType::FLOAT32); - NDArray stateMsg0('c', { 1, 5 }, { 0.00243326882084394, 0.0005611002391122, 0.04074588324665051, 0.03459534402219976, 0.00373477541890961 }, DataType::FLOAT32); - NDArray stateMsdx0('c', { 1, 5 }, { 0.00000099958919903, 0.00000099822095788, 0.00000099997545825, 0.00000099997109521, 0.00000099973231796 }, DataType::FLOAT32); - - ASSERT_TRUE(grad0.equalsTo(updateExp0)); - ASSERT_TRUE(initMsg.equalsTo(stateMsg0)); - ASSERT_TRUE(initMsdx.equalsTo(stateMsdx0)); - - NDArray grad1('c', { 1, 5 }, { 0.6351608633995056, 0.21878601610660553, 0.6470938920974731, 0.3742971122264862, 0.9453978538513184 }, DataType::FLOAT32); - - status = op.execute({ &grad1, &initMsg, &initMsdx, &rho, &epsilon }, { &grad1, &initMsg, &initMsdx }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.00598985959779411, 0.00571609509028959, 0.00374704195122062, 0.00265092283150538, 0.00608704322078556 }, DataType::FLOAT32); - NDArray stateMsg1('c', { 1, 5 }, { 0.02248307149952203, 0.00292641126934659, 0.05964511434381081, 0.03987049323214412, 0.0482368917512981 }, DataType::FLOAT32); - NDArray stateMsdx1('c', { 1, 5 }, { 0.00000274353063914, 0.00000258199706405, 0.00000165199285454, 0.00000130134213338, 0.00000280235046064 }, DataType::FLOAT32); - - ASSERT_TRUE(grad1.equalsTo(updateExp1)); - ASSERT_TRUE(initMsg.equalsTo(stateMsg1)); - ASSERT_TRUE(initMsdx.equalsTo(stateMsdx1)); - - NDArray grad2('c', { 1, 5 }, { 0.8484492301940918, 0.9634076952934265, 0.6676893830299377, 0.4450211524963379, 0.32364124059677124 }, DataType::FLOAT32); - - status = op.execute({ &grad2, &initMsg, &initMsdx, &rho, &epsilon }, { &grad2, &initMsg, &initMsdx }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00685468722145889, 0.00822128238053265, 0.00386965914609878, 0.00308849888680941, 0.00279277397245112 }, DataType::FLOAT32); - NDArray stateMsg2('c', { 1, 5 }, { 0.05735222273539331, 0.04918781007340889, 0.07895331423716523, 0.04777915987899536, 0.05106222979448406 }, DataType::FLOAT32); - NDArray stateMsdx2('c', { 1, 5 }, { 0.00000495569095238, 0.00000583237140987, 0.00000231810630717, 0.0000017132162954, 0.00000305221226067 }, DataType::FLOAT32); - - ASSERT_TRUE(grad2.equalsTo(updateExp2)); - ASSERT_TRUE(initMsg.equalsTo(stateMsg2)); - ASSERT_TRUE(initMsdx.equalsTo(stateMsdx2)); + NDArray grad0('c', {1, 5}, + {0.22060230374336243, 0.10593396425247192, 0.9027279019355774, + 0.831809401512146, 0.2733047902584076}, + DataType::FLOAT32); + NDArray initMsg('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initMsdx('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + auto rho = NDArrayFactory::create(0.95f); + auto epsilon = NDArrayFactory::create(1.0e-6); + + sd::ops::ada_delta_updater op; + + Nd4jStatus status = op.execute({&grad0, &initMsg, &initMsdx, &rho, &epsilon}, + {&grad0, &initMsg, &initMsdx}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.0044712172817412, 0.00446815612502933, 0.00447208107763182, + 0.004472071321461, 0.00447153735969189}, + DataType::FLOAT32); + NDArray stateMsg0( + 'c', {1, 5}, + {0.00243326882084394, 0.0005611002391122, 0.04074588324665051, + 0.03459534402219976, 0.00373477541890961}, + DataType::FLOAT32); + NDArray stateMsdx0( + 'c', {1, 5}, + {0.00000099958919903, 0.00000099822095788, 0.00000099997545825, + 0.00000099997109521, 0.00000099973231796}, + DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg0)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx0)); + + NDArray grad1('c', {1, 5}, + {0.6351608633995056, 0.21878601610660553, 0.6470938920974731, + 0.3742971122264862, 0.9453978538513184}, + DataType::FLOAT32); + + status = op.execute({&grad1, &initMsg, &initMsdx, &rho, &epsilon}, + {&grad1, &initMsg, &initMsdx}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.00598985959779411, 0.00571609509028959, 0.00374704195122062, + 0.00265092283150538, 0.00608704322078556}, + DataType::FLOAT32); + NDArray stateMsg1( + 'c', {1, 5}, + {0.02248307149952203, 0.00292641126934659, 0.05964511434381081, + 0.03987049323214412, 0.0482368917512981}, + DataType::FLOAT32); + NDArray stateMsdx1( + 'c', {1, 5}, + {0.00000274353063914, 0.00000258199706405, 0.00000165199285454, + 0.00000130134213338, 0.00000280235046064}, + DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg1)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx1)); + + NDArray grad2('c', {1, 5}, + {0.8484492301940918, 0.9634076952934265, 0.6676893830299377, + 0.4450211524963379, 0.32364124059677124}, + DataType::FLOAT32); + + status = op.execute({&grad2, &initMsg, &initMsdx, &rho, &epsilon}, + {&grad2, &initMsg, &initMsdx}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.00685468722145889, 0.00822128238053265, 0.00386965914609878, + 0.00308849888680941, 0.00279277397245112}, + DataType::FLOAT32); + NDArray stateMsg2( + 'c', {1, 5}, + {0.05735222273539331, 0.04918781007340889, 0.07895331423716523, + 0.04777915987899536, 0.05106222979448406}, + DataType::FLOAT32); + NDArray stateMsdx2( + 'c', {1, 5}, + {0.00000495569095238, 0.00000583237140987, 0.00000231810630717, + 0.0000017132162954, 0.00000305221226067}, + DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg2)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta3) { - - NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); // Msg - NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); // Msdx - - NDArray grad('f', { 1, 5 }, DataType::FLOAT32); - NDArray initMsg('f', { 1, 5 }, DataType::FLOAT32); - NDArray initMsdx('f', { 1, 5 }, DataType::FLOAT32); - - grad.assign(gradC); - initMsg.assign(initVC); - initMsdx.assign(initMC); - - sd::ops::ada_delta_updater op; - auto results = op.evaluate({ &grad, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); - - NDArray updateC('c', { 1, 5 }, { 0.00447209123431084, 0.00447212477470162, 0.00447213098596791, 0.00447213315991723, 0.00447213416614627 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateV0C('c', { 1, 5 }, { 0.05000000000000004, 0.20000000000000018, 0.4500000000000004, 0.8000000000000007, 1.250000000000001 }, DataType::FLOAT32); - NDArray stateMsg('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateM0C('c', { 1, 5 }, { 0.0000009999800004, 0.00000099999500002, 0.00000099999777778, 0.00000099999875, 0.0000009999992 }, DataType::FLOAT32); - NDArray stateMsdx('f', { 1, 5 }, DataType::FLOAT32); - - update.assign(updateC); - stateMsg.assign(stateV0C); - stateMsdx.assign(stateM0C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); - ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); - ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); - ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); - - results = op.evaluate({ &grad, &results.at(1), &results.at(2) }, { 0.95, 1.0e-6 }, { }); - - NDArray update1C('c', { 1, 5 }, { 0.0045290622655332, 0.00452909666868751, 0.00452910303972733, 0.00452910526959756, 0.00452910630171004 }, DataType::FLOAT32); - - NDArray stateV1C('c', { 1, 5 }, { 0.09750000000000009, 0.39000000000000035, 0.8775000000000008, 1.5600000000000014, 2.4375000000000018 }, DataType::FLOAT32); - NDArray stateM1C('c', { 1, 5 }, { 0.00000197560125063, 0.00000197563108174, 0.00000197563660612, 0.00000197563853966, 0.00000197563943461 }, DataType::FLOAT32); - - update.assign(update1C); - stateMsg.assign(stateV1C); - stateMsdx.assign(stateM1C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); - ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); - ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); - ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); - - results = op.evaluate({ &grad, &stateMsg, &stateMsdx }, { 0.95f, 1.0e-6 }, { }); - - NDArray update2C('c', { 1, 5 }, { 0.00456759948242601, 0.00456763438748812, 0.00456764085147516, 0.00456764311387702, 0.004567644161047 }, DataType::FLOAT32); - NDArray stateV2C('c', { 1, 5 }, { 0.1426250000000001, 0.5705000000000005, 1.2836250000000011, 2.282000000000002, 3.5656250000000025 }, DataType::FLOAT32); - NDArray stateM2C('c', { 1, 5 }, { 0.0000029199694397, 0.00000292001372254, 0.00000292002192321, 0.00000292002479346, 0.00000292002612198 }, DataType::FLOAT32); - - update.assign(update2C); - stateMsg.assign(stateV2C); - stateMsdx.assign(stateM2C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); - ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); - ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); - ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); + NDArray gradC('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initVC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, + DataType::FLOAT32); // Msg + NDArray initMC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, + DataType::FLOAT32); // Msdx + + NDArray grad('f', {1, 5}, DataType::FLOAT32); + NDArray initMsg('f', {1, 5}, DataType::FLOAT32); + NDArray initMsdx('f', {1, 5}, DataType::FLOAT32); + + grad.assign(gradC); + initMsg.assign(initVC); + initMsdx.assign(initMC); + + sd::ops::ada_delta_updater op; + auto results = op.evaluate({&grad, &initMsg, &initMsdx}, {0.95f, 1.0e-6}, {}); + + NDArray updateC( + 'c', {1, 5}, + {0.00447209123431084, 0.00447212477470162, 0.00447213098596791, + 0.00447213315991723, 0.00447213416614627}, + DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + NDArray stateV0C('c', {1, 5}, + {0.05000000000000004, 0.20000000000000018, + 0.4500000000000004, 0.8000000000000007, 1.250000000000001}, + DataType::FLOAT32); + NDArray stateMsg('f', {1, 5}, DataType::FLOAT32); + + NDArray stateM0C('c', {1, 5}, + {0.0000009999800004, 0.00000099999500002, + 0.00000099999777778, 0.00000099999875, 0.0000009999992}, + DataType::FLOAT32); + NDArray stateMsdx('f', {1, 5}, DataType::FLOAT32); + + update.assign(updateC); + stateMsg.assign(stateV0C); + stateMsdx.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); + ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); + ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); + ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); + + results = + op.evaluate({&grad, &results.at(1), &results.at(2)}, {0.95, 1.0e-6}, {}); + + NDArray update1C( + 'c', {1, 5}, + {0.0045290622655332, 0.00452909666868751, 0.00452910303972733, + 0.00452910526959756, 0.00452910630171004}, + DataType::FLOAT32); + + NDArray stateV1C('c', {1, 5}, + {0.09750000000000009, 0.39000000000000035, + 0.8775000000000008, 1.5600000000000014, 2.4375000000000018}, + DataType::FLOAT32); + NDArray stateM1C( + 'c', {1, 5}, + {0.00000197560125063, 0.00000197563108174, 0.00000197563660612, + 0.00000197563853966, 0.00000197563943461}, + DataType::FLOAT32); + + update.assign(update1C); + stateMsg.assign(stateV1C); + stateMsdx.assign(stateM1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); + ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); + ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); + ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); + + results = op.evaluate({&grad, &stateMsg, &stateMsdx}, {0.95f, 1.0e-6}, {}); + + NDArray update2C( + 'c', {1, 5}, + {0.00456759948242601, 0.00456763438748812, 0.00456764085147516, + 0.00456764311387702, 0.004567644161047}, + DataType::FLOAT32); + NDArray stateV2C('c', {1, 5}, + {0.1426250000000001, 0.5705000000000005, 1.2836250000000011, + 2.282000000000002, 3.5656250000000025}, + DataType::FLOAT32); + NDArray stateM2C( + 'c', {1, 5}, + {0.0000029199694397, 0.00000292001372254, 0.00000292002192321, + 0.00000292002479346, 0.00000292002612198}, + DataType::FLOAT32); + + update.assign(update2C); + stateMsg.assign(stateV2C); + stateMsdx.assign(stateM2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); + ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); + ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); + ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterNadam1) { - - NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); - NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray update('c', { 1, 5 }, DataType::FLOAT32); - - sd::ops::nadam_updater op; - - Nd4jStatus status = op.execute({ &grad, &initV, &initM }, { &update, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.06008325654320519, 0.06008326604320069, 0.06008326920986652, 0.06008327079319956, 0.0600832717431994 }, DataType::FLOAT32); - NDArray stateV('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.499999999999999 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp0)); - ASSERT_TRUE(initV.equalsTo(stateV)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - status = op.execute({ &grad, &initV, &initM }, { &update, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.06061258367739481, 0.06061259045578174, 0.06061259271524436, 0.06061259384497576, 0.06061259452281461 }, DataType::FLOAT32); - NDArray stateV1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp1)); - ASSERT_TRUE(initV.equalsTo(stateV1)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - status = op.execute({ &grad, &initV, &initM }, { &update, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.06281865774973168, 0.06281866348713228, 0.06281866539959938, 0.06281866635583296, 0.06281866692957314 }, DataType::FLOAT32); - NDArray stateV2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp2)); - ASSERT_TRUE(initV.equalsTo(stateV2)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initV('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray update('c', {1, 5}, DataType::FLOAT32); + + sd::ops::nadam_updater op; + + Nd4jStatus status = + op.execute({&grad, &initV, &initM}, {&update, &initV, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.06008325654320519, 0.06008326604320069, 0.06008326920986652, + 0.06008327079319956, 0.0600832717431994}, + DataType::FLOAT32); + NDArray stateV('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateM0('c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, + 0.29999999999999993, 0.3999999999999999, 0.499999999999999}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({&grad, &initV, &initM}, {&update, &initV, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.06061258367739481, 0.06061259045578174, 0.06061259271524436, + 0.06061259384497576, 0.06061259452281461}, + DataType::FLOAT32); + NDArray stateV1('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + NDArray stateM1('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({&grad, &initV, &initM}, {&update, &initV, &initM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.06281865774973168, 0.06281866348713228, 0.06281866539959938, + 0.06281866635583296, 0.06281866692957314}, + DataType::FLOAT32); + NDArray stateV2('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + NDArray stateM2('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterNadam2) { - - NDArray grad0('c', { 1, 5 }, { 0.8047558665275574, 0.9653639197349548, 0.31240877509117126, 0.9530212879180908, 0.01295729912817478 }, DataType::FLOAT32); - NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - auto lr = NDArrayFactory::create(0.001f); - auto beta1 = NDArrayFactory::create(0.9f); - auto beta2 = NDArrayFactory::create(0.999f); - auto epsilon = NDArrayFactory::create(1.0e-8); - - sd::ops::nadam_updater op; - - Nd4jStatus status = op.execute({ &grad0, &initV, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initV, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.06008325193356386, 0.0600832558615088, 0.06008321472550684, 0.06008325560661022, 0.0600818092240132 }, DataType::FLOAT32); - NDArray stateV0('c', { 1, 5 }, { 0.00064763200471052, 0.00093192749752604, 0.00009759924275397, 0.00090824957522506, 0.0000001678916007 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.08047558665275573, 0.09653639197349546, 0.03124087750911712, 0.09530212879180906, 0.00129572991281748 }, DataType::FLOAT32); - - ASSERT_TRUE(grad0.equalsTo(updateExp0)); - ASSERT_TRUE(initV.equalsTo(stateV0)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - NDArray grad1('c', { 1, 5 }, { 0.9839006662368774, 0.8964805603027344, 0.3631269931793213, 0.00931886397302151, 0.6320028901100159 }, DataType::FLOAT32); - - status = op.execute({ &grad1, &initV, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initV, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.06273730114378717, 0.0596708938019245, 0.06226533928512862, 0.02621380498466489, 0.06059567064824535 }, DataType::FLOAT32); - NDArray stateV1('c', { 1, 5 }, { 0.00161504489372718, 0.00173467296502922, 0.00022936285668667, 0.00090742816687558, 0.0003995953768165 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.17081809461116787, 0.17653080880641933, 0.06442948907613753, 0.08670380230993031, 0.06436644593253729 }, DataType::FLOAT32); - - ASSERT_TRUE(grad1.equalsTo(updateExp1)); - ASSERT_TRUE(initV.equalsTo(stateV1)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - NDArray grad2('c', { 1, 5 }, { 0.7712154984474182, 0.1282273381948471, 0.7019220590591431, 0.8883536458015442, 0.33057701587677 }, DataType::FLOAT32); - - status = op.execute({ &grad2, &initV, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initV, &initM }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.06062658222261493, 0.04001212712739213, 0.06906390273197544, 0.05804376499107734, 0.05097529565845974 }, DataType::FLOAT32); - NDArray stateV2('c', { 1, 5 }, { 0.00220820319387896, 0.00174938054232472, 0.00072182807082381, 0.0016956929387176, 0.00050847694486568 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.2308578349947929, 0.1717004617452621, 0.12817874607443808, 0.16686878665909166, 0.09098750292696056 }, DataType::FLOAT32); - - ASSERT_TRUE(grad2.equalsTo(updateExp2)); - ASSERT_TRUE(initV.equalsTo(stateV2)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad0('c', {1, 5}, + {0.8047558665275574, 0.9653639197349548, 0.31240877509117126, + 0.9530212879180908, 0.01295729912817478}, + DataType::FLOAT32); + NDArray initV('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::nadam_updater op; + + Nd4jStatus status = + op.execute({&grad0, &initV, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad0, &initV, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.06008325193356386, 0.0600832558615088, 0.06008321472550684, + 0.06008325560661022, 0.0600818092240132}, + DataType::FLOAT32); + NDArray stateV0( + 'c', {1, 5}, + {0.00064763200471052, 0.00093192749752604, 0.00009759924275397, + 0.00090824957522506, 0.0000001678916007}, + DataType::FLOAT32); + NDArray stateM0( + 'c', {1, 5}, + {0.08047558665275573, 0.09653639197349546, 0.03124087750911712, + 0.09530212879180906, 0.00129572991281748}, + DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', {1, 5}, + {0.9839006662368774, 0.8964805603027344, 0.3631269931793213, + 0.00931886397302151, 0.6320028901100159}, + DataType::FLOAT32); + + status = op.execute({&grad1, &initV, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad1, &initV, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.06273730114378717, 0.0596708938019245, 0.06226533928512862, + 0.02621380498466489, 0.06059567064824535}, + DataType::FLOAT32); + NDArray stateV1( + 'c', {1, 5}, + {0.00161504489372718, 0.00173467296502922, 0.00022936285668667, + 0.00090742816687558, 0.0003995953768165}, + DataType::FLOAT32); + NDArray stateM1( + 'c', {1, 5}, + {0.17081809461116787, 0.17653080880641933, 0.06442948907613753, + 0.08670380230993031, 0.06436644593253729}, + DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', {1, 5}, + {0.7712154984474182, 0.1282273381948471, 0.7019220590591431, + 0.8883536458015442, 0.33057701587677}, + DataType::FLOAT32); + + status = op.execute({&grad2, &initV, &initM, &lr, &beta1, &beta2, &epsilon}, + {&grad2, &initV, &initM}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.06062658222261493, 0.04001212712739213, 0.06906390273197544, + 0.05804376499107734, 0.05097529565845974}, + DataType::FLOAT32); + NDArray stateV2( + 'c', {1, 5}, + {0.00220820319387896, 0.00174938054232472, 0.00072182807082381, + 0.0016956929387176, 0.00050847694486568}, + DataType::FLOAT32); + NDArray stateM2('c', {1, 5}, + {0.2308578349947929, 0.1717004617452621, 0.12817874607443808, + 0.16686878665909166, 0.09098750292696056}, + DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterNadam3) { - - NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray grad('f', { 1, 5 }, DataType::FLOAT32); - NDArray initV('f', { 1, 5 }, DataType::FLOAT32); - NDArray initM('f', { 1, 5 }, DataType::FLOAT32); - - grad.assign(gradC); - initV.assign(initVC); - initM.assign(initMC); - - sd::ops::nadam_updater op; - auto results = op.evaluate({ &grad, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - - NDArray updateC('c', { 1, 5 }, { 0.06008325654320519, 0.06008326604320069, 0.06008326920986652, 0.06008327079319956, 0.0600832717431994 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateV0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.499999999999999 }, DataType::FLOAT32); - NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); - - update.assign(updateC); - stateV.assign(stateV0C); - stateM.assign(stateM0C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - - results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update1C('c', { 1, 5 }, { 0.06061258367739481, 0.06061259045578174, 0.06061259271524436, 0.06061259384497576, 0.06061259452281461 }, DataType::FLOAT32); - NDArray stateV1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - - update.assign(update1C); - stateV.assign(stateV1C); - stateM.assign(stateM1C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - - results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - - NDArray update2C('c', { 1, 5 }, { 0.06281865774973168, 0.06281866348713228, 0.06281866539959938, 0.06281866635583296, 0.06281866692957314 }, DataType::FLOAT32); - NDArray stateV2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - - update.assign(update2C); - stateV.assign(stateV2C); - stateM.assign(stateM2C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); + NDArray gradC('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initVC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initMC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray grad('f', {1, 5}, DataType::FLOAT32); + NDArray initV('f', {1, 5}, DataType::FLOAT32); + NDArray initM('f', {1, 5}, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + + sd::ops::nadam_updater op; + auto results = + op.evaluate({&grad, &initV, &initM}, {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + + NDArray updateC( + 'c', {1, 5}, + {0.06008325654320519, 0.06008326604320069, 0.06008326920986652, + 0.06008327079319956, 0.0600832717431994}, + DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + NDArray stateV0C('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateV('f', {1, 5}, DataType::FLOAT32); + + NDArray stateM0C('c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, + 0.29999999999999993, 0.3999999999999999, 0.499999999999999}, + DataType::FLOAT32); + NDArray stateM('f', {1, 5}, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({&grad, &stateV, &stateM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C( + 'c', {1, 5}, + {0.06061258367739481, 0.06061259045578174, 0.06061259271524436, + 0.06061259384497576, 0.06061259452281461}, + DataType::FLOAT32); + NDArray stateV1C('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + NDArray stateM1C('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + + update.assign(update1C); + stateV.assign(stateV1C); + stateM.assign(stateM1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({&grad, &stateV, &stateM}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + + NDArray update2C( + 'c', {1, 5}, + {0.06281865774973168, 0.06281866348713228, 0.06281866539959938, + 0.06281866635583296, 0.06281866692957314}, + DataType::FLOAT32); + NDArray stateV2C('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + NDArray stateM2C('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + + update.assign(update2C); + stateV.assign(stateV2C); + stateM.assign(stateM2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAmsGrad1) { - - NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); - NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initH('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray update('c', { 1, 5 }, DataType::FLOAT32); - - sd::ops::ams_grad_updater op; - - Nd4jStatus status = op.execute({ &grad, &initV, &initM, &initH }, { &update, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); - NDArray stateV0('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateH0('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp0)); - ASSERT_TRUE(initV.equalsTo(stateV0)); - ASSERT_TRUE(initH.equalsTo(stateH0)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - status = op.execute({ &grad, &initV, &initM, &initH }, { &update, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); - NDArray stateV1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - NDArray stateH1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp1)); - ASSERT_TRUE(initV.equalsTo(stateV1)); - ASSERT_TRUE(initH.equalsTo(stateH1)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - status = op.execute({ &grad, &initV, &initM, &initH }, { &update, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); - NDArray stateV2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - NDArray stateH2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - - ASSERT_TRUE(update.equalsTo(updateExp2)); - ASSERT_TRUE(initV.equalsTo(stateV2)); - ASSERT_TRUE(initH.equalsTo(stateH2)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initV('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initH('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray update('c', {1, 5}, DataType::FLOAT32); + + sd::ops::ams_grad_updater op; + + Nd4jStatus status = op.execute({&grad, &initV, &initM, &initH}, + {&update, &initV, &initM, &initH}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.00099999968377233, 0.00099999984188614, 0.00099999989459076, + 0.00099999992094306, 0.00099999993675445}, + DataType::FLOAT32); + NDArray stateV0('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateH0('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateM0('c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, + 0.29999999999999993, 0.3999999999999999, 0.4999999999999999}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV0)); + ASSERT_TRUE(initH.equalsTo(stateH0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({&grad, &initV, &initM, &initH}, + {&update, &initV, &initM, &initH}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.00134383858541481, 0.00134383873569809, 0.00134383878579252, + 0.00134383881083974, 0.00134383882586807}, + DataType::FLOAT32); + NDArray stateV1('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + NDArray stateH1('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + NDArray stateM1('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initH.equalsTo(stateH1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({&grad, &initV, &initM, &initH}, + {&update, &initV, &initM, &initH}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.00156540157923389, 0.00156540172220632, 0.0015654017698638, + 0.00156540179369254, 0.00156540180798979}, + DataType::FLOAT32); + NDArray stateV2('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + NDArray stateH2('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + NDArray stateM2('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initH.equalsTo(stateH2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAmsGrad2) { - - NDArray grad0('c', { 1, 5 }, { 0.5730348229408264, 0.04330538213253021, 0.249028742313385, 0.6514443755149841, 0.7017051577568054 }, DataType::FLOAT32); - NDArray initH('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - auto lr = NDArrayFactory::create(0.001f); - auto beta1 = NDArrayFactory::create(0.9f); - auto beta2 = NDArrayFactory::create(0.999f); - auto epsilon = NDArrayFactory::create(1.0e-8); - - sd::ops::ams_grad_updater op; - - Nd4jStatus status = op.execute({ &grad0, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initV, &initM, &initH }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp0('c', { 1, 5 }, { 0.00099999944815292, 0.00099999269777932, 0.00099999873015716, 0.00099999951457465, 0.00099999954934402 }, DataType::FLOAT32); - NDArray stateV0('c', { 1, 5 }, { 0.00032836890830282, 0.00000187535612164, 0.00006201531449819, 0.00042437977439011, 0.0004923901284225 }, DataType::FLOAT32); - NDArray stateH0('c', { 1, 5 }, { 0.00032836890830282, 0.00000187535612164, 0.00006201531449819, 0.00042437977439011, 0.00049239012842255 }, DataType::FLOAT32); - NDArray stateM0('c', { 1, 5 }, { 0.05730348229408263, 0.00433053821325302, 0.0249028742313385, 0.0651444375514984, 0.07017051577568052 }, DataType::FLOAT32); - - ASSERT_TRUE(grad0.equalsTo(updateExp0)); - ASSERT_TRUE(initV.equalsTo(stateV0)); - ASSERT_TRUE(initH.equalsTo(stateH0)); - ASSERT_TRUE(initM.equalsTo(stateM0)); - - NDArray grad1('c', { 1, 5 }, { 0.6404328346252441, 0.9432603120803833, 0.45608729124069214, 0.9097326993942261, 0.748093843460083 }, DataType::FLOAT32); - - status = op.execute({ &grad1, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initV, &initM, &initH }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp1('c', { 1, 5 }, { 0.00134565543815267, 0.00104022434054697, 0.00130914539820157, 0.00133725290576052, 0.0013453914974122 }, DataType::FLOAT32); - NDArray stateV1('c', { 1, 5 }, { 0.00073819475506065, 0.00089161349711151, 0.00026996891641496, 0.00125156897896282, 0.00105154213691696 }, DataType::FLOAT32); - NDArray stateH1('c', { 1, 5 }, { 0.00073819475506065, 0.00089161349711151, 0.00026996891641496, 0.00125156897896282, 0.00105154213691696 }, DataType::FLOAT32); - NDArray stateM1('c', { 1, 5 }, { 0.11561641752719877, 0.09822351559996603, 0.06802131593227385, 0.14960326373577115, 0.13796284854412078 }, DataType::FLOAT32); - - ASSERT_TRUE(grad1.equalsTo(updateExp1)); - ASSERT_TRUE(initV.equalsTo(stateV1)); - ASSERT_TRUE(initH.equalsTo(stateH1)); - ASSERT_TRUE(initM.equalsTo(stateM1)); - - NDArray grad2('c', { 1, 5 }, { 0.46250319480895996, 0.09698919206857681, 0.21754667162895203, 0.46824514865875244, 0.6005083918571472 }, DataType::FLOAT32); - - status = op.execute({ &grad2, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initV, &initM, &initH }, { }, { }); - ASSERT_EQ(ND4J_STATUS_OK, status); - - NDArray updateExp2('c', { 1, 5 }, { 0.00154098993679222, 0.00103399135000281, 0.00147364850040774, 0.00149693641196572, 0.00155078467854623 }, DataType::FLOAT32); - NDArray stateV2('c', { 1, 5 }, { 0.00095136576551408, 0.00090012878699251, 0.00031702550183538, 0.00146957092922632, 0.0014111009234709 }, DataType::FLOAT32); - NDArray stateH2('c', { 1, 5 }, { 0.00095136576551408, 0.00090012878699251, 0.00031702550183538, 0.00146957092922632, 0.0014111009234709 }, DataType::FLOAT32); - NDArray stateM2('c', { 1, 5 }, { 0.1503050952553749, 0.09810008324682712, 0.08297385150194167, 0.1814674522280693, 0.1842174028754234 }, DataType::FLOAT32); - - ASSERT_TRUE(grad2.equalsTo(updateExp2)); - ASSERT_TRUE(initV.equalsTo(stateV2)); - ASSERT_TRUE(initH.equalsTo(stateH2)); - ASSERT_TRUE(initM.equalsTo(stateM2)); + NDArray grad0('c', {1, 5}, + {0.5730348229408264, 0.04330538213253021, 0.249028742313385, + 0.6514443755149841, 0.7017051577568054}, + DataType::FLOAT32); + NDArray initH('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initV('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initM('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::ams_grad_updater op; + + Nd4jStatus status = op.execute( + {&grad0, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon}, + {&grad0, &initV, &initM, &initH}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0( + 'c', {1, 5}, + {0.00099999944815292, 0.00099999269777932, 0.00099999873015716, + 0.00099999951457465, 0.00099999954934402}, + DataType::FLOAT32); + NDArray stateV0( + 'c', {1, 5}, + {0.00032836890830282, 0.00000187535612164, 0.00006201531449819, + 0.00042437977439011, 0.0004923901284225}, + DataType::FLOAT32); + NDArray stateH0( + 'c', {1, 5}, + {0.00032836890830282, 0.00000187535612164, 0.00006201531449819, + 0.00042437977439011, 0.00049239012842255}, + DataType::FLOAT32); + NDArray stateM0('c', {1, 5}, + {0.05730348229408263, 0.00433053821325302, 0.0249028742313385, + 0.0651444375514984, 0.07017051577568052}, + DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV0)); + ASSERT_TRUE(initH.equalsTo(stateH0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', {1, 5}, + {0.6404328346252441, 0.9432603120803833, 0.45608729124069214, + 0.9097326993942261, 0.748093843460083}, + DataType::FLOAT32); + + status = op.execute( + {&grad1, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon}, + {&grad1, &initV, &initM, &initH}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1( + 'c', {1, 5}, + {0.00134565543815267, 0.00104022434054697, 0.00130914539820157, + 0.00133725290576052, 0.0013453914974122}, + DataType::FLOAT32); + NDArray stateV1( + 'c', {1, 5}, + {0.00073819475506065, 0.00089161349711151, 0.00026996891641496, + 0.00125156897896282, 0.00105154213691696}, + DataType::FLOAT32); + NDArray stateH1( + 'c', {1, 5}, + {0.00073819475506065, 0.00089161349711151, 0.00026996891641496, + 0.00125156897896282, 0.00105154213691696}, + DataType::FLOAT32); + NDArray stateM1( + 'c', {1, 5}, + {0.11561641752719877, 0.09822351559996603, 0.06802131593227385, + 0.14960326373577115, 0.13796284854412078}, + DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initH.equalsTo(stateH1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', {1, 5}, + {0.46250319480895996, 0.09698919206857681, 0.21754667162895203, + 0.46824514865875244, 0.6005083918571472}, + DataType::FLOAT32); + + status = op.execute( + {&grad2, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon}, + {&grad2, &initV, &initM, &initH}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2( + 'c', {1, 5}, + {0.00154098993679222, 0.00103399135000281, 0.00147364850040774, + 0.00149693641196572, 0.00155078467854623}, + DataType::FLOAT32); + NDArray stateV2( + 'c', {1, 5}, + {0.00095136576551408, 0.00090012878699251, 0.00031702550183538, + 0.00146957092922632, 0.0014111009234709}, + DataType::FLOAT32); + NDArray stateH2( + 'c', {1, 5}, + {0.00095136576551408, 0.00090012878699251, 0.00031702550183538, + 0.00146957092922632, 0.0014111009234709}, + DataType::FLOAT32); + NDArray stateM2('c', {1, 5}, + {0.1503050952553749, 0.09810008324682712, 0.08297385150194167, + 0.1814674522280693, 0.1842174028754234}, + DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initH.equalsTo(stateH2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests18, TestUpdaterAmsGrad3) { - - NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); - NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - NDArray initHC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); - - NDArray grad('f', { 1, 5 }, DataType::FLOAT32); - NDArray initV('f', { 1, 5 }, DataType::FLOAT32); - NDArray initM('f', { 1, 5 }, DataType::FLOAT32); - NDArray initH('f', { 1, 5 }, DataType::FLOAT32); - - grad.assign(gradC); - initV.assign(initVC); - initM.assign(initMC); - initH.assign(initHC); - - sd::ops::ams_grad_updater op; - auto results = op.evaluate({ &grad, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - - NDArray updateC('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); - NDArray update('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateV0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); - NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); - - NDArray stateH0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); - NDArray stateH('f', { 1, 5 }, DataType::FLOAT32); - - update.assign(updateC); - stateV.assign(stateV0C); - stateM.assign(stateM0C); - stateH.assign(stateH0C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - ASSERT_TRUE(stateH.isSameShape(results.at(3))); - ASSERT_TRUE(stateH.equalsTo(results.at(3))); - - results = op.evaluate({ &grad, &stateV, &stateM, &stateH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - NDArray update1C('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); - NDArray stateV1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); - NDArray stateH1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); - - - update.assign(update1C); - stateV.assign(stateV1C); - stateM.assign(stateM1C); - stateH.assign(stateH1C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - ASSERT_TRUE(stateH.isSameShape(results.at(3))); - ASSERT_TRUE(stateH.equalsTo(results.at(3))); - - results = op.evaluate({ &grad, &stateV, &stateM, &stateH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); - - - NDArray update2C('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); - NDArray stateV2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); - NDArray stateH2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); - - - update.assign(update2C); - stateV.assign(stateV2C); - stateM.assign(stateM2C); - stateH.assign(stateH2C); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(update.isSameShape(results.at(0))); - ASSERT_TRUE(update.equalsTo(results.at(0))); - ASSERT_TRUE(stateV.isSameShape(results.at(1))); - ASSERT_TRUE(stateV.equalsTo(results.at(1))); - ASSERT_TRUE(stateM.isSameShape(results.at(2))); - ASSERT_TRUE(stateM.equalsTo(results.at(2))); - ASSERT_TRUE(stateH.isSameShape(results.at(3))); - ASSERT_TRUE(stateH.equalsTo(results.at(3))); + NDArray gradC('c', {1, 5}, {1, 2, 3, 4, 5}, DataType::FLOAT32); + NDArray initVC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initMC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + NDArray initHC('c', {1, 5}, {0.0, 0.0, 0.0, 0.0, 0.0}, DataType::FLOAT32); + + NDArray grad('f', {1, 5}, DataType::FLOAT32); + NDArray initV('f', {1, 5}, DataType::FLOAT32); + NDArray initM('f', {1, 5}, DataType::FLOAT32); + NDArray initH('f', {1, 5}, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + initH.assign(initHC); + + sd::ops::ams_grad_updater op; + auto results = op.evaluate({&grad, &initV, &initM, &initH}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + + NDArray updateC( + 'c', {1, 5}, + {0.00099999968377233, 0.00099999984188614, 0.00099999989459076, + 0.00099999992094306, 0.00099999993675445}, + DataType::FLOAT32); + NDArray update('f', {1, 5}, DataType::FLOAT32); + + NDArray stateV0C('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateV('f', {1, 5}, DataType::FLOAT32); + + NDArray stateM0C( + 'c', {1, 5}, + {0.09999999999999998, 0.19999999999999996, 0.29999999999999993, + 0.3999999999999999, 0.4999999999999999}, + DataType::FLOAT32); + NDArray stateM('f', {1, 5}, DataType::FLOAT32); + + NDArray stateH0C('c', {1, 5}, + {0.001, 0.004, 0.00900000000000001, 0.01600000000000001, + 0.02500000000000002}, + DataType::FLOAT32); + NDArray stateH('f', {1, 5}, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + stateH.assign(stateH0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + ASSERT_TRUE(stateH.isSameShape(results.at(3))); + ASSERT_TRUE(stateH.equalsTo(results.at(3))); + + results = op.evaluate({&grad, &stateV, &stateM, &stateH}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C( + 'c', {1, 5}, + {0.00134383858541481, 0.00134383873569809, 0.00134383878579252, + 0.00134383881083974, 0.00134383882586807}, + DataType::FLOAT32); + NDArray stateV1C('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + NDArray stateM1C('c', {1, 5}, + {0.18999999999999995, 0.3799999999999999, 0.5699999999999998, + 0.7599999999999998, 0.9499999999999997}, + DataType::FLOAT32); + NDArray stateH1C('c', {1, 5}, + {0.001999, 0.00799600000000001, 0.01799100000000001, + 0.03198400000000003, 0.04997500000000005}, + DataType::FLOAT32); + + update.assign(update1C); + stateV.assign(stateV1C); + stateM.assign(stateM1C); + stateH.assign(stateH1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + ASSERT_TRUE(stateH.isSameShape(results.at(3))); + ASSERT_TRUE(stateH.equalsTo(results.at(3))); + + results = op.evaluate({&grad, &stateV, &stateM, &stateH}, + {0.001f, 0.9f, 0.999f, 1.0e-8}, {}); + + NDArray update2C( + 'c', {1, 5}, + {0.00156540157923389, 0.00156540172220632, 0.0015654017698638, + 0.00156540179369254, 0.00156540180798979}, + DataType::FLOAT32); + NDArray stateV2C('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + NDArray stateM2C('c', {1, 5}, + {0.2709999999999999, 0.5419999999999998, 0.8129999999999998, + 1.0839999999999996, 1.3549999999999995}, + DataType::FLOAT32); + NDArray stateH2C('c', {1, 5}, + {0.002997001, 0.01198800400000001, 0.02697300900000002, + 0.04795201600000004, 0.07492502500000006}, + DataType::FLOAT32); + + update.assign(update2C); + stateV.assign(stateV2C); + stateM.assign(stateM2C); + stateH.assign(stateH2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + ASSERT_TRUE(stateH.isSameShape(results.at(3))); + ASSERT_TRUE(stateH.equalsTo(results.at(3))); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp index f8a084f43984..a2682a996f90 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -14,345 +14,343 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include -#include #include +#include +#include +#include -using namespace sd; +#include "testlayers.h" +using namespace sd; class DeclarableOpsTests19 : public testing::Test { -public: - - DeclarableOpsTests19() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests19() { + printf("\n"); + fflush(stdout); + } }; TEST_F(DeclarableOpsTests19, test_threshold_encode_1) { - auto x = NDArrayFactory::create('c', {3}, {1.5, 2.5, -3.5}); - auto exp_encoded = NDArrayFactory::create('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3}); - auto exp_gradients = NDArrayFactory::create('c', {3}, {1.0, 2.0, -3.0}); + auto x = NDArrayFactory::create('c', {3}, {1.5, 2.5, -3.5}); + auto exp_encoded = + NDArrayFactory::create('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3}); + auto exp_gradients = + NDArrayFactory::create('c', {3}, {1.0, 2.0, -3.0}); - sd::ops::encode_threshold op; - auto result = op.evaluate({&x}, {0.5}); + sd::ops::encode_threshold op; + auto result = op.evaluate({&x}, {0.5}); - auto gradients = result.at(0); - auto encoded = result.at(1); + auto gradients = result.at(0); + auto encoded = result.at(1); - //encoded->printIndexedBuffer("ENC"); + // encoded->printIndexedBuffer("ENC"); - ASSERT_EQ(exp_encoded, encoded); - ASSERT_EQ(exp_gradients, x); + ASSERT_EQ(exp_encoded, encoded); + ASSERT_EQ(exp_gradients, x); - // FIXME: we need to add a way to declare individual inplace outputs - //ASSERT_EQ(exp_gradients, *gradients); + // FIXME: we need to add a way to declare individual inplace outputs + // ASSERT_EQ(exp_gradients, *gradients); } TEST_F(DeclarableOpsTests19, test_threshold_encode_2) { - for (int length = 5; length < 35; length++) { - auto x = NDArrayFactory::create('c', {10000}); - auto exp_gradients = NDArrayFactory::create('c', {10000}); + for (int length = 5; length < 35; length++) { + auto x = NDArrayFactory::create('c', {10000}); + auto exp_gradients = NDArrayFactory::create('c', {10000}); - for (int e = 0; e < length; e++) { - x.p(e, 2e-3); - exp_gradients.p(e, 1e-3); - } + for (int e = 0; e < length; e++) { + x.p(e, 2e-3); + exp_gradients.p(e, 1e-3); + } - sd::ops::encode_threshold op; - auto result = op.evaluate({&x}, {1e-3}); + sd::ops::encode_threshold op; + auto result = op.evaluate({&x}, {1e-3}); - auto encoded = result.at(1); + auto encoded = result.at(1); - ASSERT_EQ(length + 4, encoded.lengthOf()); - ASSERT_EQ(exp_gradients, x); - } + ASSERT_EQ(length + 4, encoded.lengthOf()); + ASSERT_EQ(exp_gradients, x); + } } TEST_F(DeclarableOpsTests19, test_threshold_encode_boundary_1) { - auto x = NDArrayFactory::create('c', {6}); - x = 1.0f; + auto x = NDArrayFactory::create('c', {6}); + x = 1.0f; - sd::ops::encode_threshold op; - auto result = op.evaluate({&x}, {1.0}, {3}); + sd::ops::encode_threshold op; + auto result = op.evaluate({&x}, {1.0}, {3}); - auto gradients = result.at(0); - auto encoded = result.at(1); + auto gradients = result.at(0); + auto encoded = result.at(1); - ASSERT_EQ(7, encoded.lengthOf()); - ASSERT_EQ(3, x.sumNumber().e(0)); + ASSERT_EQ(7, encoded.lengthOf()); + ASSERT_EQ(3, x.sumNumber().e(0)); } TEST_F(DeclarableOpsTests19, test_threshold_encode_boundary_2) { - auto x = NDArrayFactory::create('c', {1000}); - x = 1.0f; + auto x = NDArrayFactory::create('c', {1000}); + x = 1.0f; - sd::ops::encode_threshold op; - auto result = op.evaluate({&x}, {1.0}, {100}); + sd::ops::encode_threshold op; + auto result = op.evaluate({&x}, {1.0}, {100}); - auto gradients = result.at(0); - auto encoded = result.at(1); + auto gradients = result.at(0); + auto encoded = result.at(1); - ASSERT_EQ(104, encoded.lengthOf()); + ASSERT_EQ(104, encoded.lengthOf()); - ASSERT_EQ(900, x.sumNumber().e(0)); + ASSERT_EQ(900, x.sumNumber().e(0)); } TEST_F(DeclarableOpsTests19, test_threshold_decode_1) { - auto x = NDArrayFactory::create('c', {3}, {1.0, 2.0, -3.0}); - auto y = NDArrayFactory::create('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3}); - auto exp_gradients = NDArrayFactory::create('c', {3}, {1.5, 2.5, -3.5}); - - sd::ops::decode_threshold op; - auto status = op.execute({&x, &y}, {&x}); - ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(exp_gradients, x); + auto x = NDArrayFactory::create('c', {3}, {1.0, 2.0, -3.0}); + auto y = + NDArrayFactory::create('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3}); + auto exp_gradients = + NDArrayFactory::create('c', {3}, {1.5, 2.5, -3.5}); + + sd::ops::decode_threshold op; + auto status = op.execute({&x, &y}, {&x}); + ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(exp_gradients, x); } TEST_F(DeclarableOpsTests19, test_bitmap_encode_1) { - auto initial = NDArrayFactory::create('c', {6}, {0.0f, 0.0f, 1e-3f, -1e-3f, 0.0f, 0.0f}); - auto exp_0 = initial.like(); - auto exp_1 = initial.dup(); - auto exp_c = NDArrayFactory::create(2L); - - sd::ops::encode_bitmap enc; - auto enc_result = enc.evaluate({&initial}, {1e-3f}); - ASSERT_EQ(Status::OK(), enc_result.status()); + auto initial = NDArrayFactory::create( + 'c', {6}, {0.0f, 0.0f, 1e-3f, -1e-3f, 0.0f, 0.0f}); + auto exp_0 = initial.like(); + auto exp_1 = initial.dup(); + auto exp_c = NDArrayFactory::create(2L); - //initial.printIndexedBuffer("initial"); - ASSERT_EQ(exp_0, initial); + sd::ops::encode_bitmap enc; + auto enc_result = enc.evaluate({&initial}, {1e-3f}); + ASSERT_EQ(Status::OK(), enc_result.status()); - auto encoded = enc_result.at(1); - auto counter = enc_result.at(2); + // initial.printIndexedBuffer("initial"); + ASSERT_EQ(exp_0, initial); - //encoded->printIndexedBuffer("encoded"); + auto encoded = enc_result.at(1); + auto counter = enc_result.at(2); - ASSERT_EQ(exp_c, counter); + // encoded->printIndexedBuffer("encoded"); - sd::ops::decode_bitmap dec; - auto status = dec.execute({&initial, &encoded}, {&initial}); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(exp_c, counter); + sd::ops::decode_bitmap dec; + auto status = dec.execute({&initial, &encoded}, {&initial}); + ASSERT_EQ(Status::OK(), status); - //initial.printIndexedBuffer(); + // initial.printIndexedBuffer(); - ASSERT_EQ(exp_1, initial); + ASSERT_EQ(exp_1, initial); } TEST_F(DeclarableOpsTests19, test_bitmap_encode_decode) { - auto initial = NDArrayFactory::create('c', {256000}); - initial = 1.0f; - auto exp = initial.dup(); - auto neg = initial.like(); - neg = 0.5f; - - sd::ops::encode_bitmap enc; - auto enc_result = enc.evaluate({&initial}, {0.5f}); - auto encoded = enc_result.at(1); - - // checking equality of all encoded bits - for (int e = 5; e < encoded.lengthOf() - 1; e++) { - if (encoded.e(e) != encoded.e(e - 1)) - nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, encoded.e(e)); - } - - ASSERT_NE(exp, initial); - ASSERT_EQ(neg, initial); - - sd::ops::decode_bitmap dec; - auto status = dec.execute({&initial, &encoded}, {&initial}); - ASSERT_EQ(Status::OK(), status); - - // checking equality of all dedoded bits - for (int e = 0; e < initial.lengthOf(); e++) { - auto f = initial.e(e); - if (f != 1.0f) - nd4j_printf("initial[%i] = %f\n", e, f); - } - - - ASSERT_EQ(exp, initial); + auto initial = NDArrayFactory::create('c', {256000}); + initial = 1.0f; + auto exp = initial.dup(); + auto neg = initial.like(); + neg = 0.5f; + + sd::ops::encode_bitmap enc; + auto enc_result = enc.evaluate({&initial}, {0.5f}); + auto encoded = enc_result.at(1); + + // checking equality of all encoded bits + for (int e = 5; e < encoded.lengthOf() - 1; e++) { + if (encoded.e(e) != encoded.e(e - 1)) + nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, + encoded.e(e)); + } + + ASSERT_NE(exp, initial); + ASSERT_EQ(neg, initial); + + sd::ops::decode_bitmap dec; + auto status = dec.execute({&initial, &encoded}, {&initial}); + ASSERT_EQ(Status::OK(), status); + + // checking equality of all dedoded bits + for (int e = 0; e < initial.lengthOf(); e++) { + auto f = initial.e(e); + if (f != 1.0f) nd4j_printf("initial[%i] = %f\n", e, f); + } + + ASSERT_EQ(exp, initial); } TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) { - auto initial = NDArrayFactory::create('c', {256000}); - initial = 1.0f; - auto exp = initial.dup(); - auto neg = initial.like(); - neg = 0.5f; - - sd::ops::encode_threshold enc; - auto enc_result = enc.evaluate({&initial}, {0.5f}); - auto encoded = enc_result.at(1); - - ASSERT_EQ(256000 + 4, encoded.lengthOf()); - ASSERT_NE(exp, initial); - - for (int e = 0; e < initial.lengthOf(); e++) { - auto f = initial.e(e); - if (f != 0.5f) { - nd4j_printf("initial[%i] = %f\n", e, f); - throw std::runtime_error(""); - } + auto initial = NDArrayFactory::create('c', {256000}); + initial = 1.0f; + auto exp = initial.dup(); + auto neg = initial.like(); + neg = 0.5f; + + sd::ops::encode_threshold enc; + auto enc_result = enc.evaluate({&initial}, {0.5f}); + auto encoded = enc_result.at(1); + + ASSERT_EQ(256000 + 4, encoded.lengthOf()); + ASSERT_NE(exp, initial); + + for (int e = 0; e < initial.lengthOf(); e++) { + auto f = initial.e(e); + if (f != 0.5f) { + nd4j_printf("initial[%i] = %f\n", e, f); + throw std::runtime_error(""); } - ASSERT_EQ(neg, initial); - - // checking equality of all encoded bits - //for (int e = 5; e < encoded->lengthOf() - 1; e++) { - //if (encoded->e(e) != encoded->e(e - 1) + 1) - //nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, encoded->e(e)); - //} - - sd::ops::decode_threshold dec; - auto status = dec.execute({&initial, &encoded}, {&initial}); - ASSERT_EQ(Status::OK(), status); - - // checking equality of all dedoded bits - for (int e = 0; e < initial.lengthOf(); e++) { - auto f = initial.e(e); - if (f != 1.0f) - nd4j_printf("initial[%i] = %f\n", e, f); - } - - ASSERT_EQ(exp, initial); + } + ASSERT_EQ(neg, initial); + + // checking equality of all encoded bits + // for (int e = 5; e < encoded->lengthOf() - 1; e++) { + // if (encoded->e(e) != encoded->e(e - 1) + 1) + // nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, + // encoded->e(e)); + //} + + sd::ops::decode_threshold dec; + auto status = dec.execute({&initial, &encoded}, {&initial}); + ASSERT_EQ(Status::OK(), status); + + // checking equality of all dedoded bits + for (int e = 0; e < initial.lengthOf(); e++) { + auto f = initial.e(e); + if (f != 1.0f) nd4j_printf("initial[%i] = %f\n", e, f); + } + + ASSERT_EQ(exp, initial); } - TEST_F(DeclarableOpsTests19, test_matmul_ccc) { - auto x = NDArrayFactory::create('c', {10, 10}); - auto y = NDArrayFactory::create('c', {10, 10}); - auto e = NDArrayFactory::create('c', {10, 10}); - auto z = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create('c', {10, 10}); + auto y = NDArrayFactory::create('c', {10, 10}); + auto e = NDArrayFactory::create('c', {10, 10}); + auto z = NDArrayFactory::create('c', {10, 10}); - z.assign(100.f); - e.assign(110.f); - x.assign(1.0f); - y.assign(1.0f); + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); - sd::ops::matmul op; - auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); - ASSERT_EQ(Status::OK(), status); + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests19, test_matmul_fcf) { - auto x = NDArrayFactory::create('f', {10, 10}); - auto y = NDArrayFactory::create('c', {10, 10}); - auto e = NDArrayFactory::create('f', {10, 10}); - auto z = NDArrayFactory::create('f', {10, 10}); + auto x = NDArrayFactory::create('f', {10, 10}); + auto y = NDArrayFactory::create('c', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); - z.assign(100.f); - e.assign(110.f); - x.assign(1.0f); - y.assign(1.0f); + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); - sd::ops::matmul op; - auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); - ASSERT_EQ(Status::OK(), status); + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests19, test_matmul_cff) { - auto x = NDArrayFactory::create('c', {10, 10}); - auto y = NDArrayFactory::create('f', {10, 10}); - auto e = NDArrayFactory::create('f', {10, 10}); - auto z = NDArrayFactory::create('f', {10, 10}); + auto x = NDArrayFactory::create('c', {10, 10}); + auto y = NDArrayFactory::create('f', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); - z.assign(100.f); - e.assign(110.f); - x.assign(1.0f); - y.assign(1.0f); + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); - sd::ops::matmul op; - auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); - ASSERT_EQ(Status::OK(), status); + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } - TEST_F(DeclarableOpsTests19, test_matmul_ccf) { - auto x = NDArrayFactory::create('c', {10, 10}); - auto y = NDArrayFactory::create('c', {10, 10}); - auto e = NDArrayFactory::create('f', {10, 10}); - auto z = NDArrayFactory::create('f', {10, 10}); + auto x = NDArrayFactory::create('c', {10, 10}); + auto y = NDArrayFactory::create('c', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); - z.assign(100.f); - e.assign(110.f); - x.assign(1.0f); - y.assign(1.0f); + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); - sd::ops::matmul op; - auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); - ASSERT_EQ(Status::OK(), status); + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests19, test_matmul_fff) { - auto x = NDArrayFactory::create('f', {10, 10}); - auto y = NDArrayFactory::create('f', {10, 10}); - auto e = NDArrayFactory::create('f', {10, 10}); - auto z = NDArrayFactory::create('f', {10, 10}); + auto x = NDArrayFactory::create('f', {10, 10}); + auto y = NDArrayFactory::create('f', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); - z.assign(100.f); - e.assign(110.f); - x.assign(1.0f); - y.assign(1.0f); + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); - sd::ops::matmul op; - auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); - ASSERT_EQ(Status::OK(), status); + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) { - /* - DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp") - .addInputs( - Nd4j.create(DataType.FLOAT, 2,2,12), - Nd4j.create(DataType.FLOAT, 3,2,3), - Nd4j.create(DataType.FLOAT, 2,3,6) - ) - .addOutputs( - Nd4j.create(DataType.FLOAT, 2,2,12), - Nd4j.create(DataType.FLOAT, 3,2,3)) - .addIntegerArguments(3,2,0,1,2,0) - .build(); - - Nd4j.exec(op); - */ - - auto t = NDArrayFactory::create('c', {2, 2, 12}); - auto u = NDArrayFactory::create('c', {3, 2, 3}); - auto v = NDArrayFactory::create('c', {2, 3, 6}); - - sd::ops::conv1d_bp op; - auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2,0}); - ASSERT_EQ(Status::OK(), result.status()); - + /* + DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp") + .addInputs( + Nd4j.create(DataType.FLOAT, 2,2,12), + Nd4j.create(DataType.FLOAT, 3,2,3), + Nd4j.create(DataType.FLOAT, 2,3,6) + ) + .addOutputs( + Nd4j.create(DataType.FLOAT, 2,2,12), + Nd4j.create(DataType.FLOAT, 3,2,3)) + .addIntegerArguments(3,2,0,1,2,0) + .build(); + + Nd4j.exec(op); + */ + + auto t = NDArrayFactory::create('c', {2, 2, 12}); + auto u = NDArrayFactory::create('c', {3, 2, 3}); + auto v = NDArrayFactory::create('c', {2, 3, 6}); + + sd::ops::conv1d_bp op; + auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2, 0}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(DeclarableOpsTests19, test_squeeze_1) { - auto x = NDArrayFactory::create('c', {3, 4, 1}); - auto e = NDArrayFactory::create('c', {3, 4}); - int axis = 2; + auto x = NDArrayFactory::create('c', {3, 4, 1}); + auto e = NDArrayFactory::create('c', {3, 4}); + int axis = 2; - sd::ops::squeeze op; - auto status = op.execute({&x}, {&e}, {axis}); - ASSERT_EQ(Status::OK(), status); + sd::ops::squeeze op; + auto status = op.execute({&x}, {&e}, {axis}); + ASSERT_EQ(Status::OK(), status); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp index 1661eed56fe2..68fddd6f7a71 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp @@ -14,4166 +14,4290 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -#include "testlayers.h" -#include -#include #include #include +#include +#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class DeclarableOpsTests2 : public testing::Test { -public: - - DeclarableOpsTests2() { - printf("\n"); - } + public: + DeclarableOpsTests2() { printf("\n"); } }; //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_1) { + NDArray input('c', {2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); + NDArray indices('c', {1, 6}, {0, 1, 2, 2, 1, 2}, sd::DataType::INT32); + NDArray expected( + 'c', {2, 1, 6, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 9, 10, 11, 12, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 21, 22, 23, 24, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); - NDArray input('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); - NDArray indices('c', {1,6}, {0,1, 2,2, 1,2}, sd::DataType::INT32); - NDArray expected('c', {2,1,6,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16, 17,18,19,20, 21,22,23,24, 21,22,23,24, 17,18,19,20, 21,22,23,24}, sd::DataType::FLOAT32); - - sd::ops::gather op; + sd::ops::gather op; - auto result = op.evaluate({&input, &indices}, {1}); + auto result = op.evaluate({&input, &indices}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } TEST_F(DeclarableOpsTests2, gather_2) { + NDArray input('c', {2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + // auto indices ('c', {1,6}, {0,1, 2,2, 1,2}); + NDArray expected( + 'c', {2, 6, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 9, 10, 11, 12, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 21, 22, 23, 24, 17, 18, 19, 20, 21, 22, 23, 24}); - NDArray input('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - //auto indices ('c', {1,6}, {0,1, 2,2, 1,2}); - NDArray expected('c', {2,6,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16, 17,18,19,20, 21,22,23,24, 21,22,23,24, 17,18,19,20, 21,22,23,24}); + sd::ops::gather op; - sd::ops::gather op; + auto result = op.evaluate({&input}, {}, {1, 0, 1, 2, 2, 1, 2}, {true}); - auto result = op.evaluate({&input}, {}, {1, 0,1, 2,2, 1,2}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_3) { + NDArray input('c', {2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + NDArray indices('c', {1, 1}, std::vector{2}, sd::DataType::INT32); + NDArray expected('c', {2, 1, 1, 4}, {9, 10, 11, 12, 21, 22, 23, 24}); - NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - NDArray indices ('c', {1,1}, std::vector{2}, sd::DataType::INT32); - NDArray expected('c', {2,1,1,4}, {9,10,11,12,21,22,23,24}); + sd::ops::gather op; - sd::ops::gather op; + auto result = op.evaluate({&input, &indices}, {}, {1}); - auto result = op.evaluate({&input, &indices}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } TEST_F(DeclarableOpsTests2, gather_4) { + NDArray input('c', {2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + // auto indices ('c', {1,1}, {2}); + NDArray expected('c', {2, 4}, {9, 10, 11, 12, 21, 22, 23, 24}); - NDArray input('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - //auto indices ('c', {1,1}, {2}); - NDArray expected('c', {2,4}, {9,10,11,12,21,22,23,24}); - - sd::ops::gather op; + sd::ops::gather op; - auto result = op.evaluate({&input}, {}, {1, 2}); + auto result = op.evaluate({&input}, {}, {1, 2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_5) { + NDArray input('c', {2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + NDArray indices('c', {2, 3}, {0, 1, 2, 2, 1, 2}, sd::DataType::INT32); + NDArray expected( + 'c', {2, 2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 9, 10, 11, 12, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 21, 22, 23, 24, 17, 18, 19, 20, 21, 22, 23, 24}); - NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - NDArray indices ('c', {2,3}, {0, 1, 2, 2, 1,2}, sd::DataType::INT32); - NDArray expected('c', {2,2,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 21,22,23,24,17,18,19,20,21,22,23,24}); + sd::ops::gather op; - sd::ops::gather op; + auto result = op.evaluate({&input, &indices}, {}, {1}, {true}); - auto result = op.evaluate({&input, &indices}, {}, {1}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_6) { + NDArray input( + 'c', {3, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}); + NDArray indices('c', {2, 3}, {0, 1, 2, 2, 1, 2}, sd::DataType::INT32); + NDArray expected( + 'c', {2, 3, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}); - NDArray input ('c', {3,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36}); - NDArray indices ('c', {2,3}, {0, 1, 2, 2, 1,2}, sd::DataType::INT32); - NDArray expected('c', {2,3,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36, 25,26,27,28,29,30,31,32,33,34,35,36, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36}); - - sd::ops::gather op; + sd::ops::gather op; - auto result = op.evaluate({&input, &indices}, {}, {0}); + auto result = op.evaluate({&input, &indices}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_7) { + NDArray input('c', {2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + NDArray indices('c', {2, 3}, {0, 1, 2, 2, 1, 2}, sd::DataType::INT64); + NDArray expected( + 'c', {2, 3, 2, 3}, + {1, 2, 3, 3, 2, 3, 5, 6, 7, 7, 6, 7, 9, 10, 11, 11, 10, 11, + 13, 14, 15, 15, 14, 15, 17, 18, 19, 19, 18, 19, 21, 22, 23, 23, 22, 23}); - NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - NDArray indices ('c', {2,3}, {0, 1, 2, 2, 1,2}, sd::DataType::INT64); - NDArray expected('c', {2,3,2,3}, {1, 2, 3, 3, 2, 3, 5, 6, 7, 7, 6, 7, 9,10,11,11,10,11, 13,14,15,15,14,15, 17,18,19,19,18,19, 21,22,23,23,22,23}); - - sd::ops::gather op; + sd::ops::gather op; - auto result = op.evaluate({&input, &indices}, {}, {2}); + auto result = op.evaluate({&input, &indices}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_8) { + NDArray input('c', {3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + sd::DataType::FLOAT32); + NDArray indices('c', {1}, std::vector{2}, sd::DataType::INT32); + NDArray expected('c', {1, 5}, {11, 12, 13, 14, 15.}, sd::DataType::FLOAT32); - NDArray input('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, sd::DataType::FLOAT32); - NDArray indices('c', {1}, std::vector{2}, sd::DataType::INT32); - NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, sd::DataType::FLOAT32); - - sd::ops::gather op; - - auto result = op.evaluate({&input, &indices}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - // output->printShapeInfo(); - // output->printIndexedBuffer(); + sd::ops::gather op; - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto result = op.evaluate({&input, &indices}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printShapeInfo(); + // output->printIndexedBuffer(); + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_9) { - NDArray x('c', {2, 4, 3, 2}, sd::DataType::FLOAT32); - NDArray indices('c', {2}, std::vector{1, 0}, sd::DataType::INT32); + NDArray x('c', {2, 4, 3, 2}, sd::DataType::FLOAT32); + NDArray indices('c', {2}, std::vector{1, 0}, sd::DataType::INT32); - sd::ops::gather op; - auto result = op.evaluate({&x, &indices}, {}, {-2}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::gather op; + auto result = op.evaluate({&x, &indices}, {}, {-2}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_10) { - NDArray x('c', {2, 2}, {1, 2, 3, 4}); - NDArray e('c', {2, 2}, {3, 4, 1, 2}); + NDArray x('c', {2, 2}, {1, 2, 3, 4}); + NDArray e('c', {2, 2}, {3, 4, 1, 2}); - sd::ops::gather op; - auto result = op.evaluate({&x}, {}, {0, 1, 0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::gather op; + auto result = op.evaluate({&x}, {}, {0, 1, 0}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_11) { + NDArray x('c', {2, 2}, {1, 2, 3, 4}); + NDArray indices('c', {2}, std::vector{1, 0}, sd::DataType::INT64); + NDArray e('c', {2, 2}, {3, 4, 1, 2}); - NDArray x('c', {2, 2}, {1, 2, 3, 4}); - NDArray indices('c', {2}, std::vector{1, 0}, sd::DataType::INT64); - NDArray e('c', {2, 2}, {3, 4, 1, 2}); - - sd::ops::gather op; - auto result = op.evaluate({&x, &indices}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::gather op; + auto result = op.evaluate({&x, &indices}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); - + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_12) { + NDArray input('c', {4}, {2.f, 3.f, 4.f, 5.f}); + NDArray indices('c', {2}, {0, 2}, sd::DataType::INT32); + NDArray exp('c', {2}, {2.f, 4.f}); - NDArray input('c', {4}, {2.f, 3.f, 4.f, 5.f}); - NDArray indices('c', {2}, {0, 2}, sd::DataType::INT32); - NDArray exp('c', {2}, {2.f, 4.f}); - - sd::ops::gather op; - auto result = op.evaluate({&input, &indices}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::gather op; + auto result = op.evaluate({&input, &indices}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_13) { - - NDArray input ('c', {2,3,4,5}, sd::DataType::DOUBLE); - NDArray indices ('c', {2,3,4}, {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3,0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}, sd::DataType::INT32); - NDArray expected('c', {2,3, 2,3,4, 5}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, - 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, - 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, - 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, - 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, - 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, - 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, - 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, - 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, - 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119, 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119, 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119, - 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119, 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119, 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119}); - - input.linspace(0); - - sd::ops::gather op; - - auto result = op.evaluate({&input, &indices}, {}, {2}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + NDArray input('c', {2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray indices('c', {2, 3, 4}, {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}, + sd::DataType::INT32); + NDArray expected( + 'c', {2, 3, 2, 3, 4, 5}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, + 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, + 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, + 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, + 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, + 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, + 78, 79, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, + 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 65, + 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, + 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, + 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 65, 66, 67, + 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, + 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, + 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, + 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, + 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, + 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, + 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, + 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, + 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, + 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, + 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, + 116, 117, 118, 119, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, + 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 100, 101, 102, 103, + 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, + 118, 119, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, + 112, 113, 114, 115, 116, 117, 118, 119, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, + 114, 115, 116, 117, 118, 119}); + + input.linspace(0); + + sd::ops::gather op; + + auto result = op.evaluate({&input, &indices}, {}, {2}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_14) { + NDArray input('c', {2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + NDArray indices('c', {2, 3}, {0, 10, 2, 20, 1, 2}, sd::DataType::INT32); + NDArray output('c', {2, 2, 3, 4}); - NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - NDArray indices ('c', {2,3}, {0, 10, 2, 20, 1,2}, sd::DataType::INT32); - NDArray output('c', {2,2,3,4}); + sd::ops::gather op; - sd::ops::gather op; - - ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {1}, {true})); + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {1}, {true})); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, gather_15) { + NDArray input('c', {2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray indices('c', {2, 3, 4}, {0, 10, 2, 3, 0, 1, 20, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 30, 0, 1, 2, 3}, + sd::DataType::INT32); + NDArray output('c', {2, 3, 2, 3, 4, 5}); - NDArray input ('c', {2,3,4,5}, sd::DataType::DOUBLE); - NDArray indices ('c', {2,3,4}, {0, 10, 2, 3, 0, 1, 20, 3, 0, 1, 2, 3,0, 1, 2, 3, 0, 1, 2, 30, 0, 1, 2, 3}, sd::DataType::INT32); - NDArray output('c', {2,3, 2,3,4, 5}); - - sd::ops::gather op; + sd::ops::gather op; - ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {2}, {true})); + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {2}, {true})); } TEST_F(DeclarableOpsTests2, BroadcastGradientArgs_1) { + NDArray input( + 'c', {3, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}, + sd::DataType::INT32); + NDArray indices('c', {2, 3}, {0, 1, 2, 2, 1, 2}, sd::DataType::INT32); - NDArray input ('c', {3,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36}, sd::DataType::INT32); - NDArray indices ('c', {2,3}, {0, 1, 2, 2, 1,2}, sd::DataType::INT32); + sd::ops::broadcastgradientargs op; - sd::ops::broadcastgradientargs op; + auto result = op.evaluate({&input, &indices}, {}, {}); - auto result = op.evaluate({&input, &indices}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_KERNEL_FAILURE, result.status()); - + ASSERT_EQ(ND4J_STATUS_KERNEL_FAILURE, result.status()); } TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.0095); - exp1.assign(0.019875); - exp2.assign(0.02); - - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::empty(); - auto context = NDArrayFactory::create('c', {3}, {0, 1, 2}); - auto locked = NDArrayFactory::create('c', {3}); - auto indices = NDArrayFactory::create('c', {2}, {4, 5}); - auto codes = NDArrayFactory::create('c', {2}, {1, 1}); - auto syn0 = NDArrayFactory::create('c', {100, 10}); - auto syn1 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::empty(); - auto numWords = NDArrayFactory::create('c', {1}, {1}); - - syn0.assign(0.01); - syn1.assign(0.02); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create(0.025); - auto randomValue = NDArrayFactory::create(2L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::cbow op; - auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto row_s0_0 = syn0({0,1, 0,0}, true); - auto row_s0_1 = syn0({1,2, 0,0}, true); - auto row_s0_2 = syn0({2,3, 0,0}, true); - - auto row_s1_4 = syn1({4,5, 0,0}, true); - auto row_s1_5 = syn1({5,6, 0,0}, true); - auto row_s1_6 = syn1({6,7, 0,0}, true); - - ASSERT_EQ(exp0, row_s0_0); - ASSERT_EQ(exp0, row_s0_1); - ASSERT_EQ(exp0, row_s0_2); - - ASSERT_EQ(exp1, row_s1_4); - ASSERT_EQ(exp1, row_s1_5); - ASSERT_EQ(exp2, row_s1_6); - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.0095); + exp1.assign(0.019875); + exp2.assign(0.02); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto context = NDArrayFactory::create('c', {3}, {0, 1, 2}); + auto locked = NDArrayFactory::create('c', {3}); + auto indices = NDArrayFactory::create('c', {2}, {4, 5}); + auto codes = NDArrayFactory::create('c', {2}, {1, 1}); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto numWords = NDArrayFactory::create('c', {1}, {1}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.025); + auto randomValue = NDArrayFactory::create(2L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::cbow op; + auto result = + op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, + &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, + &numWords, &locked, &inferenceVector}, + {}, {}, {true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row_s0_0 = syn0({0, 1, 0, 0}, true); + auto row_s0_1 = syn0({1, 2, 0, 0}, true); + auto row_s0_2 = syn0({2, 3, 0, 0}, true); + + auto row_s1_4 = syn1({4, 5, 0, 0}, true); + auto row_s1_5 = syn1({5, 6, 0, 0}, true); + auto row_s1_6 = syn1({6, 7, 0, 0}, true); + + ASSERT_EQ(exp0, row_s0_0); + ASSERT_EQ(exp0, row_s0_1); + ASSERT_EQ(exp0, row_s0_2); + + ASSERT_EQ(exp1, row_s1_4); + ASSERT_EQ(exp1, row_s1_5); + ASSERT_EQ(exp2, row_s1_6); } TEST_F(DeclarableOpsTests2, Test_Squeeze_1) { - auto x = NDArrayFactory::create('c', {2, 1, 3, 1, 1, 1, 4}); - x.linspace(1); - auto exp = x.reshape('c', {2, 3, 4}); + auto x = NDArrayFactory::create('c', {2, 1, 3, 1, 1, 1, 4}); + x.linspace(1); + auto exp = x.reshape('c', {2, 3, 4}); - sd::ops::squeeze op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests2, Test_Squeeze_2) { - auto x = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - auto exp = new NDArray(x.dup()); + auto x = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + auto exp = new NDArray(x.dup()); - sd::ops::squeeze op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - - delete exp; + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + + delete exp; } TEST_F(DeclarableOpsTests2, Test_FloorMod_1) { - auto x = NDArrayFactory::create('c', {1, 3}, {2.0f, 6.0f, -3.0f}); - auto y = NDArrayFactory::create('c', {1, 3}, {-3.0f, 2.0f, -2.0f}); - auto exp = NDArrayFactory::create('c', {1, 3}, {-1.f, 0.f, -1.f}); + auto x = NDArrayFactory::create('c', {1, 3}, {2.0f, 6.0f, -3.0f}); + auto y = NDArrayFactory::create('c', {1, 3}, {-3.0f, 2.0f, -2.0f}); + auto exp = NDArrayFactory::create('c', {1, 3}, {-1.f, 0.f, -1.f}); - sd::ops::floormod op; + sd::ops::floormod op; - auto result = op.evaluate({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests2, Test_FloorDiv_1) { - auto x = NDArrayFactory::create('c', {1, 3}, {3.0f, 6.0f, -3.0f}); - auto y = NDArrayFactory::create('c', {1, 3}, {-2.0f, 2.0f, -2.0f}); - auto exp = NDArrayFactory::create('c', {1, 3}, {-2.f, 3.f, 1.f}); + auto x = NDArrayFactory::create('c', {1, 3}, {3.0f, 6.0f, -3.0f}); + auto y = NDArrayFactory::create('c', {1, 3}, {-2.0f, 2.0f, -2.0f}); + auto exp = NDArrayFactory::create('c', {1, 3}, {-2.f, 3.f, 1.f}); - sd::ops::floordiv op; + sd::ops::floordiv op; - auto result = op.evaluate({&x, &y}, {}, {}); + auto result = op.evaluate({&x, &y}, {}, {}); - auto z = result.at(0); -// z->printShapeInfo("FloorDiv1 shape"); -// z->printIndexedBuffer("FloorDiv1"); - ASSERT_TRUE(exp.isSameShape(z)); + auto z = result.at(0); + // z->printShapeInfo("FloorDiv1 shape"); + // z->printIndexedBuffer("FloorDiv1"); + ASSERT_TRUE(exp.isSameShape(z)); } TEST_F(DeclarableOpsTests2, Test_FloorDiv_2) { - auto x = NDArrayFactory::create('c', {1, 3}, {3.0f, 6.0f, -3.0f}); - auto y = NDArrayFactory::create('c', {1, 3}, {-2.0f, 2.0f, -2.0f}); - auto eps = NDArrayFactory::create('c', {1, 3}, {1.f, 2.f, 3.f}); - - auto exp1 = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); - auto exp2 = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); + auto x = NDArrayFactory::create('c', {1, 3}, {3.0f, 6.0f, -3.0f}); + auto y = NDArrayFactory::create('c', {1, 3}, {-2.0f, 2.0f, -2.0f}); + auto eps = NDArrayFactory::create('c', {1, 3}, {1.f, 2.f, 3.f}); - sd::ops::floordiv_bp op; + auto exp1 = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); + auto exp2 = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); - auto result = op.evaluate({&x, &y, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - auto z1 = result.at(0); - auto z2 = result.at(1); -// z->printShapeInfo("FloorDiv1 shape"); -// z1->printIndexedBuffer("FloorDiv2_1"); -// z2->printIndexedBuffer("FloorDiv2_2"); + sd::ops::floordiv_bp op; - ASSERT_TRUE(exp1.equalsTo(z1)); - ASSERT_TRUE(exp2.equalsTo(z2)); + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto z1 = result.at(0); + auto z2 = result.at(1); + // z->printShapeInfo("FloorDiv1 shape"); + // z1->printIndexedBuffer("FloorDiv2_1"); + // z2->printIndexedBuffer("FloorDiv2_2"); + ASSERT_TRUE(exp1.equalsTo(z1)); + ASSERT_TRUE(exp2.equalsTo(z2)); } TEST_F(DeclarableOpsTests2, Test_CRelu_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}); - auto exp = NDArrayFactory::create('c', {2, 4}, {1.0f, 2.0f, 0.f, 0.f, 3.0f, 4.0f, 0.f, 0.f}); - - sd::ops::crelu op; + auto x = NDArrayFactory::create('c', {2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}); + auto exp = NDArrayFactory::create( + 'c', {2, 4}, {1.0f, 2.0f, 0.f, 0.f, 3.0f, 4.0f, 0.f, 0.f}); - auto result = op.evaluate({&x}, {}, {}); + sd::ops::crelu op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests2, Test_CRelu_BP_2) { - auto x = NDArrayFactory::create('c', {2, 2}, {1.0f, 2.0f, -3.0f, 4.0f}); - auto eps = NDArrayFactory::create('c', {2, 4}, {1.0f, 2.0f, 4.f, 3.f, 3.0f, 4.0f, 2.f, 1.f}); - auto exp = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, -2.f, 4.f}); + auto x = + NDArrayFactory::create('c', {2, 2}, {1.0f, 2.0f, -3.0f, 4.0f}); + auto eps = NDArrayFactory::create( + 'c', {2, 4}, {1.0f, 2.0f, 4.f, 3.f, 3.0f, 4.0f, 2.f, 1.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, -2.f, 4.f}); - sd::ops::crelu_bp op; - auto result = op.evaluate({&x, &eps}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(1, result.size()); + sd::ops::crelu_bp op; + auto result = op.evaluate({&x, &eps}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests2, Test_Concat_BP_1) { - auto x = NDArrayFactory::create('c', {2, 2}); - auto y = NDArrayFactory::create('c', {2, 2}); - auto eps = NDArrayFactory::create('c', {2, 4}, {1.0f, 2.0f, 0.f, 1.f, 3.0f, 4.0f, 0.f, 1.f}); - auto expEX = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - auto expEY = NDArrayFactory::create('c', {2, 2}, {0.f, 1.f, 0.f, 1.f}); + auto x = NDArrayFactory::create('c', {2, 2}); + auto y = NDArrayFactory::create('c', {2, 2}); + auto eps = NDArrayFactory::create( + 'c', {2, 4}, {1.0f, 2.0f, 0.f, 1.f, 3.0f, 4.0f, 0.f, 1.f}); + auto expEX = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto expEY = NDArrayFactory::create('c', {2, 2}, {0.f, 1.f, 0.f, 1.f}); - sd::ops::concat_bp op; - auto result = op.evaluate({&x, &y, &eps}, {}, {-1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + sd::ops::concat_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {-1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - auto epsX = result.at(0); - auto epsY = result.at(1); + auto epsX = result.at(0); + auto epsY = result.at(1); - ASSERT_TRUE(expEX.isSameShape(epsX)); - ASSERT_TRUE(expEX.equalsTo(epsX)); + ASSERT_TRUE(expEX.isSameShape(epsX)); + ASSERT_TRUE(expEX.equalsTo(epsX)); - ASSERT_TRUE(expEY.isSameShape(epsY)); - ASSERT_TRUE(expEY.equalsTo(epsY)); + ASSERT_TRUE(expEY.isSameShape(epsY)); + ASSERT_TRUE(expEY.equalsTo(epsY)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_1) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto expected = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - auto expected = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); - expected.assign(0.5f); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + expected.assign(0.5f); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_2) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1, 4, 5}); + auto expected = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1,4,5}); - auto expected = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); - expected.assign(0.5f); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + expected.assign(0.5f); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - // result.printIndexedBuffer("ADL test2"); - // expected.printIndexedBuffer("ADL expec"); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); + // result.printIndexedBuffer("ADL test2"); + // expected.printIndexedBuffer("ADL expec"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_3) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 5}); + auto expected = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1,1,5}); - auto expected = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); - expected.assign(0.5f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + expected.assign(0.5f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_4) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 1, 1, 5}); + auto expected = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,1,1,5}); - auto expected = NDArrayFactory::create('c', {2,3,4,5}); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + expected.assign(0.5f); - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); - expected.assign(0.5f); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_5) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); - expected.assign(0.5f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + expected.assign(0.5f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_6) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,3,4,5}); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.f); + expected.assign(0.f); - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.f); - expected.assign(0.f); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_7) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 60.f); + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 60.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_8) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.f); - auto result = results.at(0); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 0.f); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 0.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_9) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 1, 4, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,1,4,1}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); - auto result = results.at(0); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 60.); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 60.); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_10) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1}); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 60.f); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 60.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_11) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 1.f); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_12) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 0.f); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 0.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_13) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); - auto result = results.at(0); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 1.f); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_14) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5); - weights.p(1, 0.f); - weights.p(2, 0.f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5); + weights.p(1, 0.f); + weights.p(2, 0.f); - auto result = results.at(0); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 1.f); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_15) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5f); - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5f); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 2.f); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 2.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_16) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5f); - predictions.p(0, 0.f); - predictions.p(1, 0.f); - predictions.p(2, 0.f); - predictions.p(3, 0.f); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5f); + predictions.p(0, 0.f); + predictions.p(1, 0.f); + predictions.p(2, 0.f); + predictions.p(3, 0.f); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 2.01667, 1e-5); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 2.01667, 1e-5); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_17) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5f); - predictions.p(0, 0.f); - predictions.p(1, 0.f); - predictions.p(2, 0.f); - predictions.p(3, 0.f); - labels.p(0, 0.f); - labels.p(1, 0.f); - labels.p(2, 0.f); - labels.p(3, 0.f); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5f); + predictions.p(0, 0.f); + predictions.p(1, 0.f); + predictions.p(2, 0.f); + predictions.p(3, 0.f); + labels.p(0, 0.f); + labels.p(1, 0.f); + labels.p(2, 0.f); + labels.p(3, 0.f); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 1.93333, 1e-5); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.93333, 1e-5); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_18) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 1, 1, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,1,1,5}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5f); - predictions.p(0, 0.f); - predictions.p(1, 0.f); - predictions.p(2, 0.f); - predictions.p(3, 0.); - labels.p(0, 0.f); - labels.p(1, 0.f); - labels.p(2, 0.f); - labels.p(3, 0.f); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5f); + predictions.p(0, 0.f); + predictions.p(1, 0.f); + predictions.p(2, 0.f); + predictions.p(3, 0.); + labels.p(0, 0.f); + labels.p(1, 0.f); + labels.p(2, 0.f); + labels.p(3, 0.f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 1.93333f, 1e-5); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.93333f, 1e-5); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_19) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5); - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 1.); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_20) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 1.); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_21) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,1,1}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 1.f); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 1.f); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_22) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.); - auto result = results.at(0); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 0.); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 0.); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_23) { + auto labels = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto weights = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto labels = NDArrayFactory::create('c', {2,3,4,5}); - auto predictions = NDArrayFactory::create('c', {2,3,4,5}); - auto weights = NDArrayFactory::create('c', {2,3,4,5}); - - labels.linspace(1); - predictions.linspace(3); - weights.assign(0.5); - predictions.p(0, 0.); - predictions.p(1, 0.); - predictions.p(2, 0.); - predictions.p(3, 0.); - labels.p(0, 0.); - labels.p(1, 0.); - labels.p(2, 0.); - labels.p(3, 0.); - weights.p(40+0, 0.); - weights.p(40+1, 0.); - weights.p(40+2, 0.); - weights.p(40+3, 0.); - - sd::ops::absolute_difference_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5); + predictions.p(0, 0.); + predictions.p(1, 0.); + predictions.p(2, 0.); + predictions.p(3, 0.); + labels.p(0, 0.); + labels.p(1, 0.); + labels.p(2, 0.); + labels.p(3, 0.); + weights.p(40 + 0, 0.); + weights.p(40 + 1, 0.); + weights.p(40 + 2, 0.); + weights.p(40 + 3, 0.); - auto result = results.at(0); + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 0.965517, 1e-5); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.965517, 1e-5); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test1) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {1, 3, 4}, + {-91.5f, -107.5f, -125.5f, -145.5f, -167.5f, -191.5f, -217.5f, -245.5f, + -275.5f, -307.5f, -341.5f, -377.5f}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3,4}); - auto expected = NDArrayFactory::create('c', {1,3,4}, {-91.5f, -107.5f, -125.5f, -145.5f, -167.5f, -191.5f, -217.5f, -245.5f, -275.5f, -307.5f, -341.5f, -377.5f}); + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5); - labels.linspace(1); - predictions.linspace(2); - weights.assign(0.5); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0}); - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test2) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 1, 4}, + {-3.25f, -4.f, -4.75f, -5.5f, -12.25f, -13.f, -13.75f, -14.5f}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,4}); - auto expected = NDArrayFactory::create('c', {2,1,4}, {-3.25f, -4.f, -4.75f, -5.5f, -12.25f, -13.f, -13.75f, -14.5f}); - - labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,1}); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test3) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 1}, {-2.f, -6.f, -10.f, -14.f, -18.f, -22.f}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); - auto expected = NDArrayFactory::create('c', {2,3,1}, {-2.f, -6.f,-10.f,-14.f,-18.f,-22.f}); + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); - labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 2}); - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test4) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 1}, {-2.f, -6.f, -10.f, -14.f, -18.f, -22.f}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,3,1}, {-2.f, -6.f,-10.f,-14.f,-18.f,-22.f}); - - labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); - - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,2}); + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 2}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test5) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,4}); - - labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); - - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1,1}); + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1, 1}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == -71.); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -71.); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test6) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); - - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); - auto result = results.at(0); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1, 1}); - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == -71.f); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -71.f); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test7) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1,4}); + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); - labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1, 0}); - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1,0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == -69.f); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -69.f); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); - - labels.linspace(1); - weights.assign(0.5f); - predictions.assign(0.5f); + labels.linspace(1); + weights.assign(0.5f); + predictions.assign(0.5f); - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2,2}); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == -24.f); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -24.f); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test9) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(1); - weights.assign(0.5f); - predictions.assign(0.5f); - - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2,2}); + labels.linspace(1); + weights.assign(0.5f); + predictions.assign(0.5f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 2}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == -24.); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -24.); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, cosine_distance_loss_test10) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); - - labels.linspace(1); - weights.assign(0.5f); - predictions.assign(0.5f); - weights.p(0, 0.f); - weights.p(1, 0.f); - - sd::ops::cosine_distance_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2,2}); + labels.linspace(1); + weights.assign(0.5f); + predictions.assign(0.5f); + weights.p(0, 0.f); + weights.p(1, 0.f); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 2}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == -32.); - - + auto result = results.at(0); + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == -32.); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test1) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 0., 0., 2.5, 0., 3.5, 0., 4.5, 0., 5.5, 0., 6.5, + 0., 7.5, 0., 8.5, 0., 9.5, 10., 0., 0., 11.5, 0., 12.5}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5}); - - logits.linspace(1); - weights.assign(0.5); + logits.linspace(1); + weights.assign(0.5); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {0}); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result.printBuffer(); + auto result = results.at(0); + // result.printBuffer(); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test2) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 0., 0., 2.5, 0., 3.5, 0., 4.5, 0., 5.5, 0., 6.5, + 0., 7.5, 0., 8.5, 0., 9.5, 10., 0., 0., 11.5, 0., 12.5}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5}); - - logits.linspace(1); - weights.assign(0.5); + logits.linspace(1); + weights.assign(0.5); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {0}); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - // result.printBuffer(); + auto result = results.at(0); + // result.printBuffer(); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test3) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 0., 0., 2.5, 0., 3.5, 0., 4.5, 0., 5.5, 0., 6.5, + 0., 7.5, 0., 8.5, 0., 9.5, 10., 0., 0., 11.5, 0., 12.5}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3,1}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5}); - - logits.linspace(1); - weights.assign(0.5); - - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {0}); + logits.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {0}); - auto result = results.at(0); - // result.printBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); + // result.printBuffer(); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test4) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); + logits.linspace(1); + weights.assign(0.5); - logits.linspace(1); - weights.assign(0.5); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {1}); + auto result = results.at(0); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 83.); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 83.); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test5) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - logits.linspace(1); - weights.assign(0.5); + logits.linspace(1); + weights.assign(0.5); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {1}); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 83.); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 83.); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test6) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,1}); - - logits.linspace(1); - weights.assign(0.5); + logits.linspace(1); + weights.assign(0.5); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {1}); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 83.); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 83.); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test7) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - logits.linspace(1); - weights.assign(0.5); - + logits.linspace(1); + weights.assign(0.5); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {2}); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 6.91667, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 6.91667, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test8) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - logits.linspace(1); - weights.assign(0.5); - + logits.linspace(1); + weights.assign(0.5); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {2}); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 6.91667, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 6.91667, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test9) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1,4}); - - logits.linspace(1); - weights.assign(0.5); - - - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {2}); + logits.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {2}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 6.91667, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 6.91667, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test10) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - logits.linspace(1); - weights.assign(0.5); - - - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); + logits.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 3.45833, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 3.45833, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test11) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,4}); + logits.linspace(1); + weights.assign(0.5); - logits.linspace(1); - weights.assign(0.5); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); + auto result = results.at(0); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 3.45833, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 3.45833, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test12) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - logits.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); + logits.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 3.975, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 3.975, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, hinge_loss_test13) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - logits.linspace(1); - weights.assign(0.); + logits.linspace(1); + weights.assign(0.); + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); - sd::ops::hinge_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_TRUE(result.e(0) == 0.); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_TRUE(result.e(0) == 0.); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test1) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.0425, 0.0875, 0.13250001, 0.17749999, 0.22250001, 0.26750001, + 0.31250003, 0.35749999, 0.4025, 0.44749999, 0.49249998, 0.53750002, + 0.58249998, 0.6275, 0.67250001, 0.71749997, 0.76249999, 0.8075, + 0.85250002, 0.89749998, 0.9425, 0.98749995, 1.03250015, 1.0775001}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.0425 ,0.0875 ,0.13250001,0.17749999,0.22250001,0.26750001,0.31250003,0.35749999,0.4025 ,0.44749999,0.49249998,0.53750002, 0.58249998,0.6275 ,0.67250001,0.71749997,0.76249999,0.8075 ,0.85250002,0.89749998,0.9425 ,0.98749995,1.03250015,1.0775001}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); - - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0}); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test2) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.0425, 0.0875, 0.13250001, 0.17749999, 0.22250001, 0.26750001, + 0.31250003, 0.35749999, 0.4025, 0.44749999, 0.49249998, 0.53750002, + 0.58249998, 0.6275, 0.67250001, 0.71749997, 0.76249999, 0.8075, + 0.85250002, 0.89749998, 0.9425, 0.98749995, 1.03250015, 1.0775001}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.0425 ,0.0875 ,0.13250001,0.17749999,0.22250001,0.26750001,0.31250003,0.35749999,0.4025 ,0.44749999,0.49249998,0.53750002, 0.58249998,0.6275 ,0.67250001,0.71749997,0.76249999,0.8075 ,0.85250002,0.89749998,0.9425 ,0.98749995,1.03250015,1.0775001}); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0}); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test3) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.0425, 0.0875, 0.13250001, 0.17749999, 0.22250001, 0.26750001, + 0.31250003, 0.35749999, 0.4025, 0.44749999, 0.49249998, 0.53750002, + 0.58249998, 0.6275, 0.67250001, 0.71749997, 0.76249999, 0.8075, + 0.85250002, 0.89749998, 0.9425, 0.98749995, 1.03250015, 1.0775001}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.0425 ,0.0875 ,0.13250001,0.17749999,0.22250001,0.26750001,0.31250003,0.35749999,0.4025 ,0.44749999,0.49249998,0.53750002, 0.58249998,0.6275 ,0.67250001,0.71749997,0.76249999,0.8075 ,0.85250002,0.89749998,0.9425 ,0.98749995,1.03250015,1.0775001}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0}); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test4) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {1}); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 13.44, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 13.44, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test5) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {1}); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 13.44, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 13.44, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test6) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2}); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 1.12, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.12, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test7) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,1}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2}); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 1.12, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.12, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test8) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2}); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 1.3, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.3, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test9) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3}); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 0.56, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.56, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test10) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); - - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3}); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 0.56, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.56, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, huber_loss_test11) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); - labels.linspace(0.1, 0.1); - predictions.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3}); - sd::ops::huber_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 0.65, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.65, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test1) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.60943663, 2.48403668, 3.05256081, 3.40363169, 3.57730675, + 3.59525585, 3.46986699, 3.20791793, 2.81228209, 2.28273821, + 1.61630058, 0.80721998, -0.15329313, -1.27764463, -2.5828433, + -4.09208679, -5.83734226, -7.8636713, -10.23689461, -13.05822182, + -16.49509811, -20.85659218, -26.82411766, -36.52717209}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {1.60943663, 2.48403668, 3.05256081, 3.40363169, 3.57730675, 3.59525585, 3.46986699, 3.20791793, 2.81228209, 2.28273821, 1.61630058, 0.80721998, -0.15329313, -1.27764463, -2.5828433 , -4.09208679, -5.83734226, -7.8636713 ,-10.23689461,-13.05822182,-16.49509811,-20.85659218,-26.82411766,-36.52717209}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test2) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.60943663, 2.48403668, 3.05256081, 3.40363169, 3.57730675, + 3.59525585, 3.46986699, 3.20791793, 2.81228209, 2.28273821, + 1.61630058, 0.80721998, -0.15329313, -1.27764463, -2.5828433, + -4.09208679, -5.83734226, -7.8636713, -10.23689461, -13.05822182, + -16.49509811, -20.85659218, -26.82411766, -36.52717209}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {1.60943663, 2.48403668, 3.05256081, 3.40363169, 3.57730675, 3.59525585, 3.46986699, 3.20791793, 2.81228209, 2.28273821, 1.61630058, 0.80721998, -0.15329313, -1.27764463, -2.5828433 , -4.09208679, -5.83734226, -7.8636713 ,-10.23689461,-13.05822182,-16.49509811,-20.85659218,-26.82411766,-36.52717209}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test3) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + NDArray weights(sd::DataType::DOUBLE); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.60943663, 2.48403668, 3.05256081, 3.40363169, 3.57730675, + 3.59525585, 3.46986699, 3.20791793, 2.81228209, 2.28273821, + 1.61630058, 0.80721998, -0.15329313, -1.27764463, -2.5828433, + -4.09208679, -5.83734226, -7.8636713, -10.23689461, -13.05822182, + -16.49509811, -20.85659218, -26.82411766, -36.52717209}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - NDArray weights(sd::DataType::DOUBLE); - auto expected = NDArrayFactory::create('c', {2,3,4}, {1.60943663, 2.48403668, 3.05256081, 3.40363169, 3.57730675, 3.59525585, 3.46986699, 3.20791793, 2.81228209, 2.28273821, 1.61630058, 0.80721998, -0.15329313, -1.27764463, -2.5828433 , -4.09208679, -5.83734226, -7.8636713 ,-10.23689461,-13.05822182,-16.49509811,-20.85659218,-26.82411766,-36.52717209}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test4) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -113.886429, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -113.886429, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test5) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3,1}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -113.886429, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -113.886429, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test6) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + NDArray weights(sd::DataType::DOUBLE); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - NDArray weights(sd::DataType::DOUBLE); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -113.886429, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -113.886429, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test7) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -9.490536, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -9.490536, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test8) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3,1}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -9.490536, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -9.490536, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test9) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + NDArray weights(sd::DataType::DOUBLE); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - NDArray weights(sd::DataType::DOUBLE); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -9.490536, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -9.490536, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test10) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -12.443609, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -12.443609, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test11) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -4.745268, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -4.745268, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test12) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -4.745268, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -4.745268, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, log_loss_test13) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - predictions.linspace(0.04, 0.04); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); - - sd::ops::log_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -6.221805, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -6.221805, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test1) { - auto labels = NDArrayFactory::create('c', {1,3}, {0., 0.5, 1.}); - auto predictions = NDArrayFactory::create('c', {1,3}, {1., 1., 1.}); - auto weights = NDArrayFactory::create('c', {1,1}, {1}); - auto expected = NDArrayFactory::create('c', {1,1}, {1.}); + auto labels = NDArrayFactory::create('c', {1, 3}, {0., 0.5, 1.}); + auto predictions = NDArrayFactory::create('c', {1, 3}, {1., 1., 1.}); + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); + auto expected = NDArrayFactory::create('c', {1, 1}, {1.}); - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test2) { - auto labels = NDArrayFactory::create('c', {10,4}, {-0.5533444483384939, -0.4045807428083095, -0.38990808632111873, -1.3367815555936828, 2.2110825342567204, -0.3322538938773163, 0.5683588435736076, 1.401524673423209, -0.2216208609234102, -0.23645194877057543, -1.9319189398422172, 0.6106128799796062, 1.6973842275926025, -2.8306371397325553E-4, -1.1550401544465256, -0.08357706614294765, -0.27784822018757077, 0.8290894318337857, 1.6484476009013025, -0.7752524785358668, -0.9700596207063842, 3.0809371469543207, -0.23684959888998405, 0.22403535560739518, 0.6146150452128438, -1.1250088686147994, -0.5915314787415693, -0.0944090155356556, 0.7995514825959854, -1.2290496239142903, -1.8329592004926936, -0.1694821152623061, -1.7614978090471403, 0.07929168376086736, 0.4086255139492943, 2.045562727396195, -0.48701853719962834, 0.10304152395720723, -0.8993147347502636, -0.49078404206110715}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-0.5982871220907984, 1.2010665656903237, 0.30243355682445544, -0.2070857400459659, 0.6962389393180044, -0.5878034128580758, 0.8325626284025988, -0.3555823702782838, -0.7099759151434476, 1.7971905051128672, -1.1018498592680859, 0.008705918349147959, -1.713038986676157, 0.5029671900704719, 0.7491261275031563, -0.34800067781360444, -1.3529065441284513, -0.6075230577852321, -0.6153583973120907, 1.6014780660677996, 0.6444219215516616, 0.7925830851904783, -0.5006063079380708, 1.7812300901376552, 0.4736193941708224, 1.411502849640833, 0.9555142545037492, -0.03936687661890644, 1.31661624967917, 0.7344531724786305, 0.8388550872918745, 0.7010030219905558, -0.5442944240155373, 0.4437344837841118, -1.7502823958671712, -1.9271369730241665, 0.9256612923554498, 1.9065401403827893, 0.42450175148842717, -0.11783183865542822}); - auto weights = NDArrayFactory::create('c', {1,1}, {1}); - auto expected = NDArrayFactory::create('c', {10,1}, {1.9665822560405073, 3.806679563402927, 6.185624212589066, 20.237895345263905, 16.739700814450472, 13.655430201400929, 6.473256392322658, 3.9337379694106325, 22.509455553531062, 1.4741234749089487}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {-0.5533444483384939, -0.4045807428083095, -0.38990808632111873, + -1.3367815555936828, 2.2110825342567204, -0.3322538938773163, + 0.5683588435736076, 1.401524673423209, -0.2216208609234102, + -0.23645194877057543, -1.9319189398422172, 0.6106128799796062, + 1.6973842275926025, -2.8306371397325553E-4, -1.1550401544465256, + -0.08357706614294765, -0.27784822018757077, 0.8290894318337857, + 1.6484476009013025, -0.7752524785358668, -0.9700596207063842, + 3.0809371469543207, -0.23684959888998405, 0.22403535560739518, + 0.6146150452128438, -1.1250088686147994, -0.5915314787415693, + -0.0944090155356556, 0.7995514825959854, -1.2290496239142903, + -1.8329592004926936, -0.1694821152623061, -1.7614978090471403, + 0.07929168376086736, 0.4086255139492943, 2.045562727396195, + -0.48701853719962834, 0.10304152395720723, -0.8993147347502636, + -0.49078404206110715}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-0.5982871220907984, 1.2010665656903237, 0.30243355682445544, + -0.2070857400459659, 0.6962389393180044, -0.5878034128580758, + 0.8325626284025988, -0.3555823702782838, -0.7099759151434476, + 1.7971905051128672, -1.1018498592680859, 0.008705918349147959, + -1.713038986676157, 0.5029671900704719, 0.7491261275031563, + -0.34800067781360444, -1.3529065441284513, -0.6075230577852321, + -0.6153583973120907, 1.6014780660677996, 0.6444219215516616, + 0.7925830851904783, -0.5006063079380708, 1.7812300901376552, + 0.4736193941708224, 1.411502849640833, 0.9555142545037492, + -0.03936687661890644, 1.31661624967917, 0.7344531724786305, + 0.8388550872918745, 0.7010030219905558, -0.5442944240155373, + 0.4437344837841118, -1.7502823958671712, -1.9271369730241665, + 0.9256612923554498, 1.9065401403827893, 0.42450175148842717, + -0.11783183865542822}); + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); + auto expected = NDArrayFactory::create( + 'c', {10, 1}, + {1.9665822560405073, 3.806679563402927, 6.185624212589066, + 20.237895345263905, 16.739700814450472, 13.655430201400929, + 6.473256392322658, 3.9337379694106325, 22.509455553531062, + 1.4741234749089487}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test3) { - auto labels = NDArrayFactory::create('c', {10,4}, {0.9165069946629816, 0.166426191704143, 0.13873357227527264, -0.5986162145785378, 0.4763504550662989, 1.2259816058633732, -0.4653205175596491, -1.7447031523970766, 1.349525448316014, 2.433089865629357, -2.54858150221601, -0.6060282162911894, 0.2625377104613349, -0.5007107584102752, 0.9576065700956302, -0.35787770401703584, -0.2608532564720665, 0.65688909921908, -0.1705876431948587, 1.2052884124800949, -0.976783296084278, 1.1163504624016534, -0.10545986164581109, -1.0632271027867568, 0.26460250034147065, -0.2299030354616135, -0.418989869909565, 0.7954060747536896, 0.37934127200736545, 0.8550487997440007, 0.2984909806904042, 0.1329065864221682, 1.478600294413247, 0.05421279873635542, -1.0552978360622536, -0.743808639782604, -1.3371851696151362, 2.7752972493355963, -1.6107187893743549, 1.5030902829432997}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-3.398114657004427, 0.40587455906092945, 1.587706448479039, 0.27394335709083156, 1.0463122023764637, -0.6552570653663903, -0.26929204111727345, -2.710461824817806, 0.9141296064806023, -0.7632270851454939, -0.4077235519855459, 0.5555107559107472, -0.6776140976423888, 1.2422270521180823, 0.2372445100636733, 0.08522757123963924, -2.708523129389936, 0.09738215252575103, -0.8797837670498875, 0.8714091607391934, -0.628958978867591, 0.49380147969660415, -0.6663578349373824, 0.14570184758600965, -0.4710388511314244, 0.7708214742640788, 0.06836525442683238, -1.2786368797129386, -0.5077556003990912, 0.45383439418987664, 1.1686877788409553, -0.3078567969393852, -2.2375730522738198, 1.0108200459611192, 0.21955367964983963, 1.2268011099696847, 0.48061693077695455, -0.5306373077054981, 1.5005367299570744, -2.1005486985463966}); - auto weights = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); - auto expected = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 21.748459867092496, 6.090581568657439, 7.51315897553838, 5.999534225166869, 22.58050883748054, 6.8600435676788605, 107.5976928688877, 191.56864939172544}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {0.9165069946629816, 0.166426191704143, 0.13873357227527264, + -0.5986162145785378, 0.4763504550662989, 1.2259816058633732, + -0.4653205175596491, -1.7447031523970766, 1.349525448316014, + 2.433089865629357, -2.54858150221601, -0.6060282162911894, + 0.2625377104613349, -0.5007107584102752, 0.9576065700956302, + -0.35787770401703584, -0.2608532564720665, 0.65688909921908, + -0.1705876431948587, 1.2052884124800949, -0.976783296084278, + 1.1163504624016534, -0.10545986164581109, -1.0632271027867568, + 0.26460250034147065, -0.2299030354616135, -0.418989869909565, + 0.7954060747536896, 0.37934127200736545, 0.8550487997440007, + 0.2984909806904042, 0.1329065864221682, 1.478600294413247, + 0.05421279873635542, -1.0552978360622536, -0.743808639782604, + -1.3371851696151362, 2.7752972493355963, -1.6107187893743549, + 1.5030902829432997}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-3.398114657004427, 0.40587455906092945, 1.587706448479039, + 0.27394335709083156, 1.0463122023764637, -0.6552570653663903, + -0.26929204111727345, -2.710461824817806, 0.9141296064806023, + -0.7632270851454939, -0.4077235519855459, 0.5555107559107472, + -0.6776140976423888, 1.2422270521180823, 0.2372445100636733, + 0.08522757123963924, -2.708523129389936, 0.09738215252575103, + -0.8797837670498875, 0.8714091607391934, -0.628958978867591, + 0.49380147969660415, -0.6663578349373824, 0.14570184758600965, + -0.4710388511314244, 0.7708214742640788, 0.06836525442683238, + -1.2786368797129386, -0.5077556003990912, 0.45383439418987664, + 1.1686877788409553, -0.3078567969393852, -2.2375730522738198, + 1.0108200459611192, 0.21955367964983963, 1.2268011099696847, + 0.48061693077695455, -0.5306373077054981, 1.5005367299570744, + -2.1005486985463966}); + auto weights = NDArrayFactory::create( + 'c', {10, 1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); + auto expected = NDArrayFactory::create( + 'c', {10, 1}, + {0.0, 0.0, 21.748459867092496, 6.090581568657439, 7.51315897553838, + 5.999534225166869, 22.58050883748054, 6.8600435676788605, + 107.5976928688877, 191.56864939172544}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test4) { - auto labels = NDArrayFactory::create('c', {10,4}, {-1.9540657282602247, -0.37099621218123746, 0.24959541842365968, 0.4125896396216978, -0.8661959659606203, 0.3651479206362867, -1.7475031047706964, -1.0962133982440159, 0.8451229874730279, 0.6876932162478913, 1.2598782790596628, 0.9372328828104118, 1.383555504464105, -0.816048166961237, 0.009041816630426176, -0.004376554457540983, -0.2386352931506252, -0.6494407817111416, 1.7888273635934742, -1.2157303560822368, -0.2446697859467434, -0.3040881765177774, -0.25843499040765916, -0.16479617511053568, 1.8063435075905592, 0.36002291874022285, -0.43317974028771883, 1.070086390817373, -1.0788479808458253, -0.3364318348487324, -0.859106579072977, 0.43984270049845064, -0.23662331183489546, -1.263417124724063, -0.3123732566483939, -0.125249623799724, -1.951308433393268, -0.4925779190927575, -1.081735149025745, -1.9910331435034687}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-1.7053977111021588, 1.7704125629388408, -0.0876171627499475, 0.9428762101237441, 0.9080108618240852, -0.478732892339118, -0.8189639230649537, 1.3359668242925342, -0.07499867017894829, 0.6169780756804321, -1.1891117691972148, -0.319354110980483, -1.4287263424900434, -0.3556443786879834, 0.6389682186473912, 0.3161742985911756, 0.9047447733840537, -1.9974117226910393, 2.1067775658502326, 0.17035521714679938, -1.1393894489992826, 1.4570837278971687, 0.6312249731754015, -0.42793125692777634, -1.0685964336386844, -0.3590636581851568, -0.19147354841437528, -0.10128937266756889, -0.5714869078294972, 0.2682604831358205, 0.6608524575561853, 0.35658907103040305, -0.7053263272861181, -0.6318441042427088, 2.131292677079184, -0.3624048087249232, 1.6008209804575328, 0.1245980660014825, 1.0685424462364297, -0.5672594432046791}); - auto weights = NDArrayFactory::create('c', {1,1}, {1}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 60.74394998193965, 1e-5); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {-1.9540657282602247, -0.37099621218123746, 0.24959541842365968, + 0.4125896396216978, -0.8661959659606203, 0.3651479206362867, + -1.7475031047706964, -1.0962133982440159, 0.8451229874730279, + 0.6876932162478913, 1.2598782790596628, 0.9372328828104118, + 1.383555504464105, -0.816048166961237, 0.009041816630426176, + -0.004376554457540983, -0.2386352931506252, -0.6494407817111416, + 1.7888273635934742, -1.2157303560822368, -0.2446697859467434, + -0.3040881765177774, -0.25843499040765916, -0.16479617511053568, + 1.8063435075905592, 0.36002291874022285, -0.43317974028771883, + 1.070086390817373, -1.0788479808458253, -0.3364318348487324, + -0.859106579072977, 0.43984270049845064, -0.23662331183489546, + -1.263417124724063, -0.3123732566483939, -0.125249623799724, + -1.951308433393268, -0.4925779190927575, -1.081735149025745, + -1.9910331435034687}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-1.7053977111021588, 1.7704125629388408, -0.0876171627499475, + 0.9428762101237441, 0.9080108618240852, -0.478732892339118, + -0.8189639230649537, 1.3359668242925342, -0.07499867017894829, + 0.6169780756804321, -1.1891117691972148, -0.319354110980483, + -1.4287263424900434, -0.3556443786879834, 0.6389682186473912, + 0.3161742985911756, 0.9047447733840537, -1.9974117226910393, + 2.1067775658502326, 0.17035521714679938, -1.1393894489992826, + 1.4570837278971687, 0.6312249731754015, -0.42793125692777634, + -1.0685964336386844, -0.3590636581851568, -0.19147354841437528, + -0.10128937266756889, -0.5714869078294972, 0.2682604831358205, + 0.6608524575561853, 0.35658907103040305, -0.7053263272861181, + -0.6318441042427088, 2.131292677079184, -0.3624048087249232, + 1.6008209804575328, 0.1245980660014825, 1.0685424462364297, + -0.5672594432046791}); + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 60.74394998193965, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test5) { - auto labels = NDArrayFactory::create('c', {10,4}, {0.9165069946629816, 0.166426191704143, 0.13873357227527264, -0.5986162145785378, 0.4763504550662989, 1.2259816058633732, -0.4653205175596491, -1.7447031523970766, 1.349525448316014, 2.433089865629357, -2.54858150221601, -0.6060282162911894, 0.2625377104613349, -0.5007107584102752, 0.9576065700956302, -0.35787770401703584, -0.2608532564720665, 0.65688909921908, -0.1705876431948587, 1.2052884124800949, -0.976783296084278, 1.1163504624016534, -0.10545986164581109, -1.0632271027867568, 0.26460250034147065, -0.2299030354616135, -0.418989869909565, 0.7954060747536896, 0.37934127200736545, 0.8550487997440007, 0.2984909806904042, 0.1329065864221682, 1.478600294413247, 0.05421279873635542, -1.0552978360622536, -0.743808639782604, -1.3371851696151362, 2.7752972493355963, -1.6107187893743549, 1.5030902829432997}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-3.398114657004427, 0.40587455906092945, 1.587706448479039, 0.27394335709083156, 1.0463122023764637, -0.6552570653663903, -0.26929204111727345, -2.710461824817806, 0.9141296064806023, -0.7632270851454939, -0.4077235519855459, 0.5555107559107472, -0.6776140976423888, 1.2422270521180823, 0.2372445100636733, 0.08522757123963924, -2.708523129389936, 0.09738215252575103, -0.8797837670498875, 0.8714091607391934, -0.628958978867591, 0.49380147969660415, -0.6663578349373824, 0.14570184758600965, -0.4710388511314244, 0.7708214742640788, 0.06836525442683238, -1.2786368797129386, -0.5077556003990912, 0.45383439418987664, 1.1686877788409553, -0.3078567969393852, -2.2375730522738198, 1.0108200459611192, 0.21955367964983963, 1.2268011099696847, 0.48061693077695455, -0.5306373077054981, 1.5005367299570744, -2.1005486985463966}); - auto weights = NDArrayFactory::create('c', {1,1}, {1}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 15.189082270182983, 1e-5); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {0.9165069946629816, 0.166426191704143, 0.13873357227527264, + -0.5986162145785378, 0.4763504550662989, 1.2259816058633732, + -0.4653205175596491, -1.7447031523970766, 1.349525448316014, + 2.433089865629357, -2.54858150221601, -0.6060282162911894, + 0.2625377104613349, -0.5007107584102752, 0.9576065700956302, + -0.35787770401703584, -0.2608532564720665, 0.65688909921908, + -0.1705876431948587, 1.2052884124800949, -0.976783296084278, + 1.1163504624016534, -0.10545986164581109, -1.0632271027867568, + 0.26460250034147065, -0.2299030354616135, -0.418989869909565, + 0.7954060747536896, 0.37934127200736545, 0.8550487997440007, + 0.2984909806904042, 0.1329065864221682, 1.478600294413247, + 0.05421279873635542, -1.0552978360622536, -0.743808639782604, + -1.3371851696151362, 2.7752972493355963, -1.6107187893743549, + 1.5030902829432997}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-3.398114657004427, 0.40587455906092945, 1.587706448479039, + 0.27394335709083156, 1.0463122023764637, -0.6552570653663903, + -0.26929204111727345, -2.710461824817806, 0.9141296064806023, + -0.7632270851454939, -0.4077235519855459, 0.5555107559107472, + -0.6776140976423888, 1.2422270521180823, 0.2372445100636733, + 0.08522757123963924, -2.708523129389936, 0.09738215252575103, + -0.8797837670498875, 0.8714091607391934, -0.628958978867591, + 0.49380147969660415, -0.6663578349373824, 0.14570184758600965, + -0.4710388511314244, 0.7708214742640788, 0.06836525442683238, + -1.2786368797129386, -0.5077556003990912, 0.45383439418987664, + 1.1686877788409553, -0.3078567969393852, -2.2375730522738198, + 1.0108200459611192, 0.21955367964983963, 1.2268011099696847, + 0.48061693077695455, -0.5306373077054981, 1.5005367299570744, + -2.1005486985463966}); + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 15.189082270182983, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test6) { - auto labels = NDArrayFactory::create('c', {10,4}, {0.7712557146220891, 0.37344724586647443, -1.465944048516541, 0.3226845250222374, 0.3153238532645865, -0.6453963287132424, -1.7695663855309438, -0.31350813714835285, 0.6209850696184357, -1.0632582557661083, 0.8971205782356552, -0.7361143357044725, 0.4349813432397299, 1.1012674501462072, -1.846028584047857, -0.04711049067212126, 0.3511384383511822, -1.5908669452488973, 0.6271232025632083, -0.5370025878354387, 0.09775855957778733, 0.8465118033582384, -0.5118005514773271, -0.8215749768059044, -0.5154271246850248, -0.6614138367887438, -2.721743038982485, -0.20634785234624944, 1.074134378795222, -0.515671736473577, 0.33574452224656587, -0.4258992514621533, -1.6946210614398756, 2.0853105493575246, -0.23223717047374226, -1.3145231337861756, -0.307739072607248, -0.13713627422120406, -0.05615471338688221, -0.7031780205843188}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-0.8253096544930751, 0.81324545672996, 1.2530858908292535, 0.6881658781201572, 0.11626814971230247, 0.810096847233213, -0.41726775033902014, -0.07246036077805246, -0.3491325803119671, -0.7381717490678714, -1.258884944199858, 2.6195012275145992, 0.3241066697239042, -1.3306435333372646, -0.3413119919683999, 0.13167356361127197, -0.3992424507051653, 0.14454163796541403, -2.4931643208872316, 1.8740911656038526, -2.3404306490682956, -0.8036392545918644, -1.9726177395274997, -0.20128619801149433, -1.0680828820641624, -0.6228179015361869, 1.0785520122486962, -0.26148573195062036, -0.9154287856620913, 0.6612224269248097, -0.21735407368781667, 0.5584864652543093, 1.0208212201167435, -0.7560947201084579, -0.9092906572495081, 0.47525819203475833, 1.2215678456801444, -0.39319465979983964, 1.9435677135606038, 1.4540100039010526}); - auto weights = NDArrayFactory::create('c', {1,1}, {1}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 13.568564090650312, 1e-5); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {0.7712557146220891, 0.37344724586647443, -1.465944048516541, + 0.3226845250222374, 0.3153238532645865, -0.6453963287132424, + -1.7695663855309438, -0.31350813714835285, 0.6209850696184357, + -1.0632582557661083, 0.8971205782356552, -0.7361143357044725, + 0.4349813432397299, 1.1012674501462072, -1.846028584047857, + -0.04711049067212126, 0.3511384383511822, -1.5908669452488973, + 0.6271232025632083, -0.5370025878354387, 0.09775855957778733, + 0.8465118033582384, -0.5118005514773271, -0.8215749768059044, + -0.5154271246850248, -0.6614138367887438, -2.721743038982485, + -0.20634785234624944, 1.074134378795222, -0.515671736473577, + 0.33574452224656587, -0.4258992514621533, -1.6946210614398756, + 2.0853105493575246, -0.23223717047374226, -1.3145231337861756, + -0.307739072607248, -0.13713627422120406, -0.05615471338688221, + -0.7031780205843188}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-0.8253096544930751, 0.81324545672996, 1.2530858908292535, + 0.6881658781201572, 0.11626814971230247, 0.810096847233213, + -0.41726775033902014, -0.07246036077805246, -0.3491325803119671, + -0.7381717490678714, -1.258884944199858, 2.6195012275145992, + 0.3241066697239042, -1.3306435333372646, -0.3413119919683999, + 0.13167356361127197, -0.3992424507051653, 0.14454163796541403, + -2.4931643208872316, 1.8740911656038526, -2.3404306490682956, + -0.8036392545918644, -1.9726177395274997, -0.20128619801149433, + -1.0680828820641624, -0.6228179015361869, 1.0785520122486962, + -0.26148573195062036, -0.9154287856620913, 0.6612224269248097, + -0.21735407368781667, 0.5584864652543093, 1.0208212201167435, + -0.7560947201084579, -0.9092906572495081, 0.47525819203475833, + 1.2215678456801444, -0.39319465979983964, 1.9435677135606038, + 1.4540100039010526}); + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 13.568564090650312, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test7) { - auto labels = NDArrayFactory::create('c', {10,4}, {-0.06125002348040258, 0.5143643450377119, 2.6790723358660036, -0.8032552006036418, -2.4374371040644163, -0.1562964773317163, -1.3957988654288038, 1.2791626503391635, -1.433421873294552, -1.1819478586737284, 0.05162930965054662, -0.538650473505593, -0.548171720093084, -0.3103900587344872, -2.3955103171953342, 0.7127238680062526, 0.7182079438418053, 1.1842662402382182, 0.09585189676958715, 0.9276146067349225, 0.7856673461867428, 0.41368195133354113, -0.2939280190178078, -2.400566355562181, -1.1841519118039245, -1.066170501847581, -0.9274507409610022, 1.7671863041813334, -1.2849985781031494, -1.275990164491566, -0.8866824403466698, -0.6074077385015517, 0.7647344603897107, -1.048099070426831, 0.9433828938345293, -0.5591415819237762, 1.7962773615541947, -0.42365710367758247, -0.0385518907389571, -1.109959713481321}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-0.7445687252538243, 0.2293875300325241, -1.0231630280206505, -0.18532545069458992, -0.07797403344353356, -0.9132035669873787, 0.9352296415512886, -1.7406458535354787, 0.8578334648119594, -0.6186274065269556, 0.4874824473654153, -0.9285817343788997, 0.1654680500853023, -0.6371334533926012, 1.3115245864160707, -2.072558735678832, 0.660795731844733, -0.34942292767044864, 0.05787182311194333, -0.12939210444705632, -0.6457028552461069, -0.6048992126598505, -0.17179604529778109, 1.292989642826032, -0.28867767615688045, 0.7635565516046265, -1.5464151753137487, -1.273368390129285, -1.074046012825826, -0.3534580692302915, 0.5757285568118223, 1.823271242883469, 0.31618576929075215, 0.5422847605415213, -0.7836698021860683, -0.6292022623165172, 2.1114596721927508, 0.4634986528550097, 0.08922001427846013, 1.5767749644913223}); - auto weights = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 198.318201904499, 1e-5); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {-0.06125002348040258, 0.5143643450377119, 2.6790723358660036, + -0.8032552006036418, -2.4374371040644163, -0.1562964773317163, + -1.3957988654288038, 1.2791626503391635, -1.433421873294552, + -1.1819478586737284, 0.05162930965054662, -0.538650473505593, + -0.548171720093084, -0.3103900587344872, -2.3955103171953342, + 0.7127238680062526, 0.7182079438418053, 1.1842662402382182, + 0.09585189676958715, 0.9276146067349225, 0.7856673461867428, + 0.41368195133354113, -0.2939280190178078, -2.400566355562181, + -1.1841519118039245, -1.066170501847581, -0.9274507409610022, + 1.7671863041813334, -1.2849985781031494, -1.275990164491566, + -0.8866824403466698, -0.6074077385015517, 0.7647344603897107, + -1.048099070426831, 0.9433828938345293, -0.5591415819237762, + 1.7962773615541947, -0.42365710367758247, -0.0385518907389571, + -1.109959713481321}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-0.7445687252538243, 0.2293875300325241, -1.0231630280206505, + -0.18532545069458992, -0.07797403344353356, -0.9132035669873787, + 0.9352296415512886, -1.7406458535354787, 0.8578334648119594, + -0.6186274065269556, 0.4874824473654153, -0.9285817343788997, + 0.1654680500853023, -0.6371334533926012, 1.3115245864160707, + -2.072558735678832, 0.660795731844733, -0.34942292767044864, + 0.05787182311194333, -0.12939210444705632, -0.6457028552461069, + -0.6048992126598505, -0.17179604529778109, 1.292989642826032, + -0.28867767615688045, 0.7635565516046265, -1.5464151753137487, + -1.273368390129285, -1.074046012825826, -0.3534580692302915, + 0.5757285568118223, 1.823271242883469, 0.31618576929075215, + 0.5422847605415213, -0.7836698021860683, -0.6292022623165172, + 2.1114596721927508, 0.4634986528550097, 0.08922001427846013, + 1.5767749644913223}); + auto weights = NDArrayFactory::create( + 'c', {10, 1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 198.318201904499, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test8) { - auto labels = NDArrayFactory::create('c', {10,4}, {1.2003157672694111, -1.0738078620687983, 1.4513396266923826, 0.5753935722952708, -0.5424028602429585, 0.9816221437385002, -1.0566397385428794, 1.503481308203513, -0.6543147953583112, 1.7453669976827346, -0.1557689124924227, 0.3387794658137257, -1.2306868494328145, -0.3299042398395769, 0.026464968146954395, -1.5077479623528403, -0.27514168845621795, 0.18739335150879793, 1.7319910646645431, 1.5228099405663476, 0.8522684742808536, 0.2362049362675063, 0.2610756525241469, 0.457998065505686, -2.7342179885912623, -0.10968795695808314, 0.581598742956297, -1.9309885922934567, -1.5775788440607954, -0.04254899350225641, -0.3125858556254039, -1.1328154327730207, 0.00566243314780096, 0.8492052576274621, 0.05945202212214481, 1.4976918834497108, 0.8869512918387292, 0.4014181932175132, -0.015512552855187248, -1.3609667909108454}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-1.1088399463364795, 0.09302972835006071, 0.033839927431215555, -0.39567507675572494, 0.8269497207597863, 1.111162272517752, 0.4930937252630912, -1.4561668998323452, 0.9417715392862969, -1.0553855492735509, 0.05848285303876081, 0.8852337518047972, -0.7472824481835305, 0.404906922583895, -0.2198309547562547, 1.9536515925189717, 0.8165036568007779, -0.19524282774410398, -0.09111693087754393, 1.1604245932512238, -0.6243762858131077, 1.4297003275591034, -0.17220079411538428, -2.3139504326793032, 0.3839796486999712, 2.0287791964679234, 0.1534441713632995, -0.6062103319229825, -0.4965880982906036, -0.373907747810053, -1.6566345746154432, 0.17534987728494222, -1.6713458890334796, 1.254628987947714, 1.914596591838086, -1.0816010467183583, 0.25033738231939673, -1.605752685708275, 1.1029112741353981, 0.3237822320282494}); - auto weights = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 10.709003499121707, 1e-5); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {1.2003157672694111, -1.0738078620687983, 1.4513396266923826, + 0.5753935722952708, -0.5424028602429585, 0.9816221437385002, + -1.0566397385428794, 1.503481308203513, -0.6543147953583112, + 1.7453669976827346, -0.1557689124924227, 0.3387794658137257, + -1.2306868494328145, -0.3299042398395769, 0.026464968146954395, + -1.5077479623528403, -0.27514168845621795, 0.18739335150879793, + 1.7319910646645431, 1.5228099405663476, 0.8522684742808536, + 0.2362049362675063, 0.2610756525241469, 0.457998065505686, + -2.7342179885912623, -0.10968795695808314, 0.581598742956297, + -1.9309885922934567, -1.5775788440607954, -0.04254899350225641, + -0.3125858556254039, -1.1328154327730207, 0.00566243314780096, + 0.8492052576274621, 0.05945202212214481, 1.4976918834497108, + 0.8869512918387292, 0.4014181932175132, -0.015512552855187248, + -1.3609667909108454}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-1.1088399463364795, 0.09302972835006071, 0.033839927431215555, + -0.39567507675572494, 0.8269497207597863, 1.111162272517752, + 0.4930937252630912, -1.4561668998323452, 0.9417715392862969, + -1.0553855492735509, 0.05848285303876081, 0.8852337518047972, + -0.7472824481835305, 0.404906922583895, -0.2198309547562547, + 1.9536515925189717, 0.8165036568007779, -0.19524282774410398, + -0.09111693087754393, 1.1604245932512238, -0.6243762858131077, + 1.4297003275591034, -0.17220079411538428, -2.3139504326793032, + 0.3839796486999712, 2.0287791964679234, 0.1534441713632995, + -0.6062103319229825, -0.4965880982906036, -0.373907747810053, + -1.6566345746154432, 0.17534987728494222, -1.6713458890334796, + 1.254628987947714, 1.914596591838086, -1.0816010467183583, + 0.25033738231939673, -1.605752685708275, 1.1029112741353981, + 0.3237822320282494}); + auto weights = NDArrayFactory::create( + 'c', {10, 1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 10.709003499121707, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test9) { - auto labels = NDArrayFactory::create('c', {10,4}, {0.054445708809271035, 2.107634671009908, -0.7906421810578572, -1.075840781788665, 0.11881403008710377, 0.8444812915085994, -0.305754504070933, 1.6429935026781464, 0.8155105031719394, 0.04900134907242568, 0.6847004530975871, 0.23315535615893132, 0.17011663306483038, -1.1865513655938285, 1.5931597087896407, -1.7937514075547496, -0.036695307704292295, -1.6416280650778925, 1.130578912176608, -1.1267224667674058, -0.8690453889645526, 0.6717944721406133, 0.0850200492927782, 1.1294419289013125, 0.2154793028698133, 0.4557382556428947, -0.7343674069166273, -0.20013117860162175, -0.6096905108192562, 0.42022878041905926, -0.7446306649741321, 0.01724811509597817, 1.843091605690758, 1.008879504632424, 1.198292190689489, -0.4474144618813475, 0.25202981742888664, 0.07036737843407408, 1.2400630276444486, -1.1072825235557615}); - auto predictions = NDArrayFactory::create('c', {10,4}, {-1.6788168943811437, 1.1823653279081687, -0.3580541857004183, -0.4449970504370699, -1.3031645333940127, 0.5755013195969282, -0.7997343141774744, -0.8806735270004084, 0.9705277499376251, -1.6360067944580943, 0.12579369136710156, 1.0525902242414313, -1.625751312422252, -0.03900152587147075, 0.4112500942756277, 0.6589999986358094, 0.6144107111689617, 2.8561269030217264, 1.5299963640392247, -0.314093051147705, 1.6523278218751989, -0.5504653447714114, 0.53395260877978, 0.409795577698306, 0.4466825218051794, 1.2382059301630401, 0.4834869732526594, -0.635409128905636, -1.9343816841697272, -0.4192523056060229, -1.0662979055059818, 0.4270901960618144, -0.7391311480757151, -0.8268168961897452, -1.0855715553457785, -9.410401291588706E-4, -0.7721838774717349, 0.4784019579457375, -0.6979798841469268, -0.319729737118584}); - auto weights = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); - - sd::ops::mean_pairwssqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 17.686067864414472, 1e-5); - - + auto labels = NDArrayFactory::create( + 'c', {10, 4}, + {0.054445708809271035, 2.107634671009908, -0.7906421810578572, + -1.075840781788665, 0.11881403008710377, 0.8444812915085994, + -0.305754504070933, 1.6429935026781464, 0.8155105031719394, + 0.04900134907242568, 0.6847004530975871, 0.23315535615893132, + 0.17011663306483038, -1.1865513655938285, 1.5931597087896407, + -1.7937514075547496, -0.036695307704292295, -1.6416280650778925, + 1.130578912176608, -1.1267224667674058, -0.8690453889645526, + 0.6717944721406133, 0.0850200492927782, 1.1294419289013125, + 0.2154793028698133, 0.4557382556428947, -0.7343674069166273, + -0.20013117860162175, -0.6096905108192562, 0.42022878041905926, + -0.7446306649741321, 0.01724811509597817, 1.843091605690758, + 1.008879504632424, 1.198292190689489, -0.4474144618813475, + 0.25202981742888664, 0.07036737843407408, 1.2400630276444486, + -1.1072825235557615}); + auto predictions = NDArrayFactory::create( + 'c', {10, 4}, + {-1.6788168943811437, 1.1823653279081687, -0.3580541857004183, + -0.4449970504370699, -1.3031645333940127, 0.5755013195969282, + -0.7997343141774744, -0.8806735270004084, 0.9705277499376251, + -1.6360067944580943, 0.12579369136710156, 1.0525902242414313, + -1.625751312422252, -0.03900152587147075, 0.4112500942756277, + 0.6589999986358094, 0.6144107111689617, 2.8561269030217264, + 1.5299963640392247, -0.314093051147705, 1.6523278218751989, + -0.5504653447714114, 0.53395260877978, 0.409795577698306, + 0.4466825218051794, 1.2382059301630401, 0.4834869732526594, + -0.635409128905636, -1.9343816841697272, -0.4192523056060229, + -1.0662979055059818, 0.4270901960618144, -0.7391311480757151, + -0.8268168961897452, -1.0855715553457785, -9.410401291588706E-4, + -0.7721838774717349, 0.4784019579457375, -0.6979798841469268, + -0.319729737118584}); + auto weights = NDArrayFactory::create( + 'c', {10, 1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 17.686067864414472, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test1) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, {0.125, 0.5, 1.125, 2., 3.125, 4.5, 6.125, 8., + 10.125, 12.5, 15.125, 18., 21.125, 24.5, 28.125, 32., + 36.125, 40.5, 45.125, 50., 55.125, 60.5, 66.125, 72.}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.125, 0.5, 1.125, 2., 3.125, 4.5, 6.125, 8.,10.125,12.5,15.125,18.,21.125,24.5,28.125,32.,36.125,40.5,45.125,50.,55.125,60.5,66.125,72.}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test2) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, {0.125, 0.5, 1.125, 2., 3.125, 4.5, 6.125, 8., + 10.125, 12.5, 15.125, 18., 21.125, 24.5, 28.125, 32., + 36.125, 40.5, 45.125, 50., 55.125, 60.5, 66.125, 72.}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.125, 0.5, 1.125, 2., 3.125, 4.5, 6.125, 8.,10.125,12.5,15.125,18.,21.125,24.5,28.125,32.,36.125,40.5,45.125,50.,55.125,60.5,66.125,72.}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test3) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, {0.125, 0.5, 1.125, 2., 3.125, 4.5, 6.125, 8., + 10.125, 12.5, 15.125, 18., 21.125, 24.5, 28.125, 32., + 36.125, 40.5, 45.125, 50., 55.125, 60.5, 66.125, 72.}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,1}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.125, 0.5, 1.125, 2., 3.125, 4.5, 6.125, 8.,10.125,12.5,15.125,18.,21.125,24.5,28.125,32.,36.125,40.5,45.125,50.,55.125,60.5,66.125,72.}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test4) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, {0., 0., 0., 0., 3.125, 4.5, 6.125, 8., + 10.125, 12.5, 15.125, 18., 21.125, 24.5, 28.125, 32., + 36.125, 40.5, 45.125, 50., 55.125, 60.5, 66.125, 72.}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0., 0., 0., 0., 3.125, 4.5, 6.125, 8.,10.125,12.5,15.125,18.,21.125,24.5,28.125,32.,36.125,40.5,45.125,50.,55.125,60.5,66.125,72.}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test5) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 612.5, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 612.5, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test6) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1,4}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 612.5, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 612.5, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test7) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 612.5, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 612.5, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test8) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - weights.p(3, 0.); - - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 608.75, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 608.75, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test9) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 51.041668, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 51.041668, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test10) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3,1}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 51.041668, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 51.041668, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test11) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 51.041668, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 51.041668, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test12) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 88.541664, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 88.541664, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test13) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 25.520834, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 25.520834, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test14) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,4}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 25.520834, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 25.520834, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test15) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 25.520834, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 25.520834, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test16) { + auto labels = NDArrayFactory::create('c', {2, 3, 4}); + auto predictions = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4}); - auto predictions = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); - predictions.linspace(0.5, 0.5); - labels.linspace(1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); - sd::ops::mean_sqerr_loss op; - auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 44.270832, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 44.270832, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test1) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.37219834, 0.29906943, 0.27717763, 0.45650762, 0.23703849, 0.51874399, + 0.20159303, 0.58555031, 0.17057693, 0.65663081, 0.14366767, 0.73164123, + 0.12050423, 0.81020868, 0.10070664, 0.89195037, 0.08389302, 0.97648883, + 1.01969337, 0.06346401, 0.05775976, 1.15254164, 0.04777273, 1.2434181}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.37219834,0.29906943,0.27717763,0.45650762,0.23703849,0.51874399,0.20159303,0.58555031,0.17057693,0.65663081,0.14366767,0.73164123,0.12050423,0.81020868,0.10070664,0.89195037,0.08389302,0.97648883,1.01969337,0.06346401,0.05775976,1.15254164,0.04777273,1.2434181 }); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test2) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.37219834, 0.29906943, 0.27717763, 0.45650762, 0.23703849, 0.51874399, + 0.20159303, 0.58555031, 0.17057693, 0.65663081, 0.14366767, 0.73164123, + 0.12050423, 0.81020868, 0.10070664, 0.89195037, 0.08389302, 0.97648883, + 1.01969337, 0.06346401, 0.05775976, 1.15254164, 0.04777273, 1.2434181}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1,1}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.37219834,0.29906943,0.27717763,0.45650762,0.23703849,0.51874399,0.20159303,0.58555031,0.17057693,0.65663081,0.14366767,0.73164123,0.12050423,0.81020868,0.10070664,0.89195037,0.08389302,0.97648883,1.01969337,0.06346401,0.05775976,1.15254164,0.04777273,1.2434181 }); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test3) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.37219834, 0.29906943, 0.27717763, 0.45650762, 0.23703849, 0.51874399, + 0.20159303, 0.58555031, 0.17057693, 0.65663081, 0.14366767, 0.73164123, + 0.12050423, 0.81020868, 0.10070664, 0.89195037, 0.08389302, 0.97648883, + 1.01969337, 0.06346401, 0.05775976, 1.15254164, 0.04777273, 1.2434181}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.37219834,0.29906943,0.27717763,0.45650762,0.23703849,0.51874399,0.20159303,0.58555031,0.17057693,0.65663081,0.14366767,0.73164123,0.12050423,0.81020868,0.10070664,0.89195037,0.08389302,0.97648883,1.01969337,0.06346401,0.05775976,1.15254164,0.04777273,1.2434181 }); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test4) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.24719833, 0.54906946, 0.65217763, -0.04349237, 0.86203849, + -0.23125602, 1.07659304, -0.41444966, 1.29557693, -0.59336919, + 1.5186677, -0.76835877, 1.74550426, -0.93979132, 1.9757067, + -1.10804963, 2.20889306, -1.27351117, -1.35530663, 2.56346393, + 2.68275976, -1.59745836, 2.92277265, -1.7565819}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {0.24719833, 0.54906946, 0.65217763,-0.04349237,0.86203849,-0.23125602, 1.07659304,-0.41444966,1.29557693,-0.59336919, 1.5186677 ,-0.76835877,1.74550426,-0.93979132, 1.9757067 ,-1.10804963,2.20889306,-1.27351117,-1.35530663, 2.56346393,2.68275976,-1.59745836, 2.92277265,-1.7565819 }); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test5) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 11.2187976837, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 11.2187976837, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test6) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 11.2187976837, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 11.2187976837, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test7) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 11.2187976837, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 11.2187976837, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test8) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 10.2187976837, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 10.2187976837, 1e-5); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test9) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); - logits.linspace(0.1, 0.1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 6.06840181351, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 6.06840181351, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test10) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 0.934899806976, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.934899806976, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test11) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1,4}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 0.934899806976, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.934899806976, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test12) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 0.851566493511, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.851566493511, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test13) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); - - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 1.01140034199, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 1.01140034199, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test14) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 4}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,4}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 0.467449903488, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.467449903488, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test15) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3,1}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 0.467449903488, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.467449903488, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test16) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 0.425783246756, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.425783246756, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test17) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3,1}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); - weights.p(2, 0.); + logits.linspace(0.1, 0.1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); - sd::ops::sigm_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}); + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 0.505700170994, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 0.505700170994, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test1) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3}); + auto expected = NDArrayFactory::create( + 'c', {2, 3}, + {1.39253557, 1.44253552, 1.44253552, 1.44253552, 1.39253557, 1.44253552}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3}); - auto expected = NDArrayFactory::create('c', {2,3}, {1.39253557,1.44253552,1.44253552,1.44253552,1.39253557,1.44253552}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + auto result = results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test2) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3}); + auto expected = + NDArrayFactory::create('c', {2, 3}, + {-0.92835701, -1.12835705, -1.12835705, + -1.12835705, -0.92835701, -1.12835705}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3}); - auto expected = NDArrayFactory::create('c', {2,3}, {-0.92835701,-1.12835705,-1.12835705,-1.12835705,-0.92835701,-1.12835705}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}, {}); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test3) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 1}); + auto expected = + NDArrayFactory::create('c', {2, 3}, + {-0.92835701, -1.12835705, -1.12835705, + -1.12835705, -0.92835701, -1.12835705}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,1}); - auto expected = NDArrayFactory::create('c', {2,3}, {-0.92835701,-1.12835705,-1.12835705,-1.12835705,-0.92835701,-1.12835705}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test4) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3}); + auto expected = + NDArrayFactory::create('c', {2, 3}, + {-0.92835701, -1.12835705, -1.12835705, + -1.12835705, -0.92835701, -1.12835705}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3}); - auto expected = NDArrayFactory::create('c', {2,3}, {-0.92835701,-1.12835705,-1.12835705,-1.12835705,-0.92835701,-1.12835705}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test5) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = + NDArrayFactory::create('c', {2, 3}, + {-0.92835701, -1.12835705, -1.12835705, + -1.12835705, -0.92835701, -1.12835705}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,3}, {-0.92835701,-1.12835705,-1.12835705,-1.12835705,-0.92835701,-1.12835705}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test6) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), 8.55521392822, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), 8.55521392822, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test7) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {2, 3}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {2,3}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -6.37014198303, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -6.37014198303, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test8) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -6.37014198303, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -6.37014198303, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test9) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto result = results.at(0); + auto result = results.at(0); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -6.37014198303, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -6.37014198303, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test10) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -2.12338066101, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -2.12338066101, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test11) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto weights = NDArrayFactory::create('c', {1, 3}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto weights = NDArrayFactory::create('c', {1,3}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::softmax_cross_entropy_loss op; + auto results = + op.evaluate({&logits, &weights, &labels}, {5.}, {3}, {}, {}, false); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}, {}, {}, false); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -1.06169033051, 1e-5); - - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -1.06169033051, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test12) { + auto labels = + NDArrayFactory::create('c', {2, 4}, {0, 1, 1, 0, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 4}); + auto weights = NDArrayFactory::create('c', {2, 1}); - auto labels = NDArrayFactory::create('c', {2,4},{0,1,1,0,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,4}); - auto weights = NDArrayFactory::create('c', {2,1}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}, {}, {}, false); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::softmax_cross_entropy_loss op; + auto results = + op.evaluate({&logits, &weights, &labels}, {5.}, {3}, {}, {}, false); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(result.isScalar()); - ASSERT_NEAR(result.e(0), -2.18880319595, 1e-5); + auto result = results.at(0); - + ASSERT_TRUE(result.isScalar()); + ASSERT_NEAR(result.e(0), -2.18880319595, 1e-5); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test13) { + auto labels = + NDArrayFactory::create('c', {2, 4}, {0, 1, 1, 0, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 4}); + auto weights = NDArrayFactory::create('c', {2, 1}); + auto expected = + NDArrayFactory::create('c', {2, 1}, {1.39253557, 1.44253552}); - auto labels = NDArrayFactory::create('c', {2,4},{0,1,1,0,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,4}); - auto weights = NDArrayFactory::create('c', {2,1}); - auto expected = NDArrayFactory::create('c', {2,1}, {1.39253557,1.44253552}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } - - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test14) { + auto labels = + NDArrayFactory::create('c', {2, 4}, {0, 1, 1, 0, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 4}); + auto weights = NDArrayFactory::create('c', {2, 1}); + auto expected = + NDArrayFactory::create('c', {2, 1}, {-2.08880329, -2.28880334}); - auto labels = NDArrayFactory::create('c', {2,4},{0,1,1,0,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,4}); - auto weights = NDArrayFactory::create('c', {2,1}); - auto expected = NDArrayFactory::create('c', {2,1}, {-2.08880329, -2.28880334}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - logits.linspace(0.1, 0.1); - weights.assign(0.5); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto result = results.at(0); - auto result = results.at(0); - - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test15) { + auto labels = + NDArrayFactory::create('c', {2, 4}, {0, 1, 1, 0, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 4}); + auto weights = NDArrayFactory::create('c', {1, 1}); + auto expected = + NDArrayFactory::create('c', {2, 1}, {-2.08880329, -2.28880334}); - auto labels = NDArrayFactory::create('c', {2,4},{0,1,1,0,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,4}); - auto weights = NDArrayFactory::create('c', {1,1}); - auto expected = NDArrayFactory::create('c', {2,1}, {-2.08880329, -2.28880334}); - - logits.linspace(0.1, 0.1); - weights.assign(0.5); - - sd::ops::softmax_cross_entropy_loss op; - auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + logits.linspace(0.1, 0.1); + weights.assign(0.5); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); - auto result = results.at(0); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + auto result = results.at(0); - + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test1) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 4; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.99926789,0.99926789,0.99926789,0.99926789,0.99926789,0.99926789,0.99926789,0.99926789}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99987108,3.99987108,3.99987108,3.99987108,3.99987108,3.99987108,3.99987108,3.99987108}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 1.}, {0, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto ht = results.at(0); - auto ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {0.99926789, 0.99926789, 0.99926789, 0.99926789, 0.99926789, 0.99926789, + 0.99926789, 0.99926789}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {3.99987108, 3.99987108, 3.99987108, 3.99987108, 3.99987108, 3.99987108, + 3.99987108, 3.99987108}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0., 0., 1.}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test2) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 4; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.95867589,0.95867589,0.95867589,0.95867589,0.95867589,0.95867589,0.95867589,0.95867589}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{1.93001527,1.93001527,1.93001527,1.93001527, 1.93001527,1.93001527,1.93001527,1.93001527}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., -10.5}, {0, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto ht = results.at(0); - auto ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {0.95867589, 0.95867589, 0.95867589, 0.95867589, 0.95867589, 0.95867589, + 0.95867589, 0.95867589}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {1.93001527, 1.93001527, 1.93001527, 1.93001527, 1.93001527, 1.93001527, + 1.93001527, 1.93001527}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0., 0., -10.5}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test3) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 4; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0., 1.5}, {0, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto ht = results.at(0); - auto ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {0.37992568, 0.37992568, 0.37992568, 0.37992568, 0.37992568, 0.37992568, + 0.37992568, 0.37992568}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, {0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0.4, 0., 1.5}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test4) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 4; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0.3, 1.5}, {0, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto ht = results.at(0); - auto ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {0.37992568, 0.37992568, 0.37992568, 0.37992568, 0.37992568, 0.37992568, + 0.37992568, 0.37992568}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, {0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0.4, 0.3, 1.5}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test5) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 3; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.3,0.3,0.3,0.3,0.3,0.3}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0.3, 1.5}, {0, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto ht = results.at(0); - auto ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, + {0.3, 0.3, 0.3, 0.3, 0.3, 0.3}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, {0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0.4, 0.3, 1.5}, {0, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test6) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 3; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {1.99832496,1.99832496,1.99832496,1.99832496,1.99832496,1.99832496}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99972188,3.99972188,3.99972188,3.99972188,3.99972188,3.99972188,3.99972188,3.99972188}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 1.5}, {0, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto ht = results.at(0); - auto ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {1.99832496, 1.99832496, 1.99832496, 1.99832496, 1.99832496, 1.99832496}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {3.99972188, 3.99972188, 3.99972188, 3.99972188, 3.99972188, 3.99972188, + 3.99972188, 3.99972188}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0., 0., 1.5}, {0, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test7) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 3; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.75977136,0.75977136,0.75977136,0.75977136,0.75977136,0.75977136}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0., 1.5}, {0, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto ht = results.at(0); - auto ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {0.75977136, 0.75977136, 0.75977136, 0.75977136, 0.75977136, 0.75977136}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, {0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0.4, 0., 1.5}, {0, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test8) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 4; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.99930672,0.99930672,0.99930672,0.99930672, 0.99930672,0.99930672,0.99930672,0.99930672}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99996277,3.99996277,3.99996277,3.99996277,3.99996277,3.99996277,3.99996277,3.99996277}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 10.5}, {1, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto ht = results.at(0); - auto ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht,1e-4)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct,1e-4)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {0.99930672, 0.99930672, 0.99930672, 0.99930672, 0.99930672, 0.99930672, + 0.99930672, 0.99930672}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {3.99996277, 3.99996277, 3.99996277, 3.99996277, 3.99996277, 3.99996277, + 3.99996277, 3.99996277}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0., 0., 10.5}, {1, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht, 1e-4)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct, 1e-4)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test9) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 4; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.99501777,0.99501777,0.99501777,0.99501777,0.99501777,0.99501777,0.99501777,0.99501777}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.,3.,3.,3.,3.,3.,3.,3.}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {3., 0., 10.5}, {1, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto ht = results.at(0); - auto ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht,1e-4)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {0.99501777, 0.99501777, 0.99501777, 0.99501777, 0.99501777, 0.99501777, + 0.99501777, 0.99501777}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits}, + {3., 3., 3., 3., 3., 3., 3., 3.}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {3., 0., 10.5}, {1, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht, 1e-4)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test10) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 3; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {1.99861344,1.99861344,1.99861344,1.99861344,1.99861344,1.99861344}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99996277, 3.99996277, 3.99996277, 3.99996277,3.99996277, 3.99996277, 3.99996277, 3.99996277}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 10.5}, {1, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto ht = results.at(0); - auto ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {1.99861344, 1.99861344, 1.99861344, 1.99861344, 1.99861344, 1.99861344}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {3.99996277, 3.99996277, 3.99996277, 3.99996277, 3.99996277, 3.99996277, + 3.99996277, 3.99996277}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {0., 0., 10.5}, {1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test11) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 3; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {1.99003554,1.99003554,1.99003554,1.99003554,1.99003554,1.99003554}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.,3.,3.,3.,3.,3.,3.,3.}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {3., 0., 10.5}, {1, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto ht = results.at(0); - auto ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numProj}, + {1.99003554, 1.99003554, 1.99003554, 1.99003554, 1.99003554, 1.99003554}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits}, + {3., 3., 3., 3., 3., 3., 3., 3.}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {3., 0., 10.5}, {1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests2, lstmCell_test12) { - - const int batchSize = 2; - const int inSize = 10; - const int numProj = 3; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); - auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - ct_1.assign(3.); - Wx.assign(0.5); - Wh.assign(0.5); - Wc.assign(0.5); - Wp.assign(0.5); - b.assign(0.7); - - auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {1.,1.,1.,1.,1.,1.}); - auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.,3.,3.,3.,3.,3.,3.,3.}); - - sd::ops::lstmCell op; - auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {3., 1.,-5.}, {1, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto ht = results.at(0); - auto ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, + {1., 1., 1., 1., 1., 1.}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits}, + {3., 3., 3., 3., 3., 3., 3., 3.}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, + {3., 1., -5.}, {1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto ht = results.at(0); + auto ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index e4883e7cc832..44058198e30f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -14,2381 +14,2770 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -#include "testlayers.h" -#include -#include #include #include #include #include +#include +#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class DeclarableOpsTests3 : public testing::Test { -public: - - DeclarableOpsTests3() { -// - } + public: + DeclarableOpsTests3() { + // + } }; - TEST_F(DeclarableOpsTests3, Test_Tile_1) { - auto x= NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto rep_vector= NDArrayFactory::create('c', {1, 2}, {2, 2}); - std::vector reps({2, 2}); - - auto exp = x.tile(reps); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto rep_vector = NDArrayFactory::create('c', {1, 2}, {2, 2}); + std::vector reps({2, 2}); - sd::ops::tile op; - auto result = op.evaluate({&x, &rep_vector}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto exp = x.tile(reps); - auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::tile op; + auto result = op.evaluate({&x, &rep_vector}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests3, Test_Tile_2) { - auto x= NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - std::vector reps({2, 2}); - - auto exp = x.tile(reps); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + std::vector reps({2, 2}); - sd::ops::tile op; - auto result = op.evaluate({&x}, {}, {2, 2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto exp = x.tile(reps); - auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::tile op; + auto result = op.evaluate({&x}, {}, {2, 2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests3, Test_Permute_1) { - auto x= NDArrayFactory::create('c', {2, 3, 4}); - auto permute= NDArrayFactory::create('c', {1, 3}, {0, 2, 1}); - auto exp= NDArrayFactory::create('c', {2, 4, 3}); + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto permute = NDArrayFactory::create('c', {1, 3}, {0, 2, 1}); + auto exp = NDArrayFactory::create('c', {2, 4, 3}); - sd::ops::permute op; - auto result = op.evaluate({&x, &permute}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::permute op; + auto result = op.evaluate({&x, &permute}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(z)); } TEST_F(DeclarableOpsTests3, Test_Permute_2) { - auto x= NDArrayFactory::create('c', {2, 3, 4}); - auto exp= NDArrayFactory::create('c', {4, 3, 2}); + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4, 3, 2}); - sd::ops::permute op; - auto result = op.evaluate({&x}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::permute op; + auto result = op.evaluate({&x}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests3, Test_Unique_1) { - auto x= NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 1.f, 2.f, 3.f}); - auto expV= NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto expI= NDArrayFactory::create('c', {5}, {0, 1, 0, 1, 2}); -// auto expI= NDArrayFactory::create('c', {3}, {0, 1, 4}); - - sd::ops::unique op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); - - auto v = result.at(0); - auto i = result.at(1); - // v->printIndexedBuffer("Values"); - // i->printIndexedBuffer("Indices"); - // i->printShapeInfo("Indices shape"); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); + auto x = + NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 1.f, 2.f, 3.f}); + auto expV = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto expI = NDArrayFactory::create('c', {5}, {0, 1, 0, 1, 2}); + // auto expI= NDArrayFactory::create('c', {3}, {0, 1, 4}); + + sd::ops::unique op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - + auto v = result.at(0); + auto i = result.at(1); + // v->printIndexedBuffer("Values"); + // i->printIndexedBuffer("Indices"); + // i->printShapeInfo("Indices shape"); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); } TEST_F(DeclarableOpsTests3, Test_Unique_2) { - auto x= NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 1.f, 2.f, 3.f}); - auto expV= NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto expI= NDArrayFactory::create('c', {5}, {0, 1, 0, 1, 2}); - auto expC= NDArrayFactory::create('c', {3}, {2, 2, 1}); + auto x = + NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 1.f, 2.f, 3.f}); + auto expV = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto expI = NDArrayFactory::create('c', {5}, {0, 1, 0, 1, 2}); + auto expC = NDArrayFactory::create('c', {3}, {2, 2, 1}); - sd::ops::unique_with_counts op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::unique_with_counts op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(3, result.size()); - auto v = result.at(0); - auto i = result.at(1); - auto c = result.at(2); + auto v = result.at(0); + auto i = result.at(1); + auto c = result.at(2); - // v->printShapeInfo(); - // v->printIndexedBuffer("Values"); - // i->printShapeInfo(); - // i->printIndexedBuffer("Indices"); - // c->printShapeInfo(); - // c->printIndexedBuffer("Counts"); + // v->printShapeInfo(); + // v->printIndexedBuffer("Values"); + // i->printShapeInfo(); + // i->printIndexedBuffer("Indices"); + // c->printShapeInfo(); + // c->printIndexedBuffer("Counts"); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); - ASSERT_TRUE(expC.isSameShape(c)); - ASSERT_TRUE(expC.equalsTo(c)); + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); } TEST_F(DeclarableOpsTests3, Test_Rint_1) { - auto x= NDArrayFactory::create('c', {1, 7}, {-1.7f, -1.5f, -0.2f, 0.2f, 1.5f, 1.7f, 2.0f}); - auto exp= NDArrayFactory::create('c', {1, 7}, {-2.f, -2.f, -0.f, 0.f, 2.f, 2.f, 2.f}); + auto x = NDArrayFactory::create( + 'c', {1, 7}, {-1.7f, -1.5f, -0.2f, 0.2f, 1.5f, 1.7f, 2.0f}); + auto exp = NDArrayFactory::create( + 'c', {1, 7}, {-2.f, -2.f, -0.f, 0.f, 2.f, 2.f, 2.f}); - sd::ops::rint op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::rint op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests3, Test_Norm_1) { - auto x = NDArrayFactory::create('c', {100, 100}); - x.linspace(1); + auto x = NDArrayFactory::create('c', {100, 100}); + x.linspace(1); - std::vector empty; - std::vector dims({1}); - sd::ops::norm op; + std::vector empty; + std::vector dims({1}); + sd::ops::norm op; - auto result0 = op.evaluate({&x}, {0.}, {}); + auto result0 = op.evaluate({&x}, {0.}, {}); - auto z0 = result0.at(0); - auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); - ASSERT_TRUE(exp0.isSameShape(z0)); - ASSERT_TRUE(exp0.equalsTo(z0)); + auto z0 = result0.at(0); + auto exp0 = + x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); + ASSERT_TRUE(exp0.isSameShape(z0)); + ASSERT_TRUE(exp0.equalsTo(z0)); - auto result1 = op.evaluate({&x}, {1.}, {1}); - ASSERT_EQ(result1.status(), ND4J_STATUS_OK); - auto z1 = result1.at(0); - // z1->printIndexedBuffer("Z1"); - auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims, false, false); - // exp1.printIndexedBuffer("EXP1"); - // z1->printShapeInfo("Z1 shape"); - // exp1.printShapeInfo("EXP1 shape"); - ASSERT_TRUE(exp1.isSameShape(z1)); - ASSERT_TRUE(exp1.equalsTo(z1)); + auto result1 = op.evaluate({&x}, {1.}, {1}); + ASSERT_EQ(result1.status(), ND4J_STATUS_OK); + auto z1 = result1.at(0); + // z1->printIndexedBuffer("Z1"); + auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims, false, false); + // exp1.printIndexedBuffer("EXP1"); + // z1->printShapeInfo("Z1 shape"); + // exp1.printShapeInfo("EXP1 shape"); + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_TRUE(exp1.equalsTo(z1)); - auto result4 = op.evaluate({&x}, {4.}, {1}); + auto result4 = op.evaluate({&x}, {4.}, {1}); - auto z4 = result4.at(0); - auto exp4= x.reduceAlongDimension(reduce::NormMax, dims, false, false); - ASSERT_TRUE(exp4.isSameShape(z4)); - ASSERT_TRUE(exp4.equalsTo(z4)); + auto z4 = result4.at(0); + auto exp4 = x.reduceAlongDimension(reduce::NormMax, dims, false, false); + ASSERT_TRUE(exp4.isSameShape(z4)); + ASSERT_TRUE(exp4.equalsTo(z4)); } - TEST_F(DeclarableOpsTests3, Test_Norm_2) { - auto x = NDArrayFactory::create('c', {100, 100}); - x.linspace(1); - auto axis= NDArrayFactory::create('c', {1, 1}, {1}); - - std::vector empty; - std::vector dims({1}); - sd::ops::norm op; + auto x = NDArrayFactory::create('c', {100, 100}); + x.linspace(1); + auto axis = NDArrayFactory::create('c', {1, 1}, {1}); - auto result0 = op.evaluate({&x}, {0}, {}); + std::vector empty; + std::vector dims({1}); + sd::ops::norm op; - auto z0 = result0.at(0); - auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); - ASSERT_TRUE(exp0.isSameShape(z0)); - ASSERT_TRUE(exp0.equalsTo(z0)); + auto result0 = op.evaluate({&x}, {0}, {}); - + auto z0 = result0.at(0); + auto exp0 = + x.reduceAlongDimension(reduce::NormFrobenius, empty, false, false); + ASSERT_TRUE(exp0.isSameShape(z0)); + ASSERT_TRUE(exp0.equalsTo(z0)); - auto result1 = op.evaluate({&x, &axis}, {1}, {}); + auto result1 = op.evaluate({&x, &axis}, {1}, {}); - auto z1 = result1.at(0); - auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims, false, false); - ASSERT_TRUE(exp1.isSameShape(z1)); - ASSERT_TRUE(exp1.equalsTo(z1)); + auto z1 = result1.at(0); + auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims, false, false); + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_TRUE(exp1.equalsTo(z1)); - auto result4 = op.evaluate({&x, &axis}, {4}, {}); - - auto z4 = result4.at(0); - auto exp4= x.reduceAlongDimension(reduce::NormMax, dims, false, false); - ASSERT_TRUE(exp4.isSameShape(z4)); - ASSERT_TRUE(exp4.equalsTo(z4)); + auto result4 = op.evaluate({&x, &axis}, {4}, {}); + auto z4 = result4.at(0); + auto exp4 = x.reduceAlongDimension(reduce::NormMax, dims, false, false); + ASSERT_TRUE(exp4.isSameShape(z4)); + ASSERT_TRUE(exp4.equalsTo(z4)); } - TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); - auto exp = NDArrayFactory::create('c', {2, 3}, {-2.88, 0.0, 0.0, 3.84, 0.0, 0.0}); + auto x = NDArrayFactory::create('c', {2, 3}, + {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); + auto exp = NDArrayFactory::create('c', {2, 3}, + {-2.88, 0.0, 0.0, 3.84, 0.0, 0.0}); - sd::ops::clipbyavgnorm op; - auto result = op.evaluate({&x}, {0.8}, {}); + sd::ops::clipbyavgnorm op; + auto result = op.evaluate({&x}, {0.8}, {}); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_2) { - auto x= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); - auto exp= NDArrayFactory::create('c', {2, 3}, {-3.f, 0.0f, 0.0f, 4.f, 0.0f, 0.0f}); + auto x = NDArrayFactory::create('c', {2, 3}, + {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); + auto exp = NDArrayFactory::create('c', {2, 3}, + {-3.f, 0.0f, 0.0f, 4.f, 0.0f, 0.0f}); - sd::ops::clipbyavgnorm op; - auto result = op.evaluate({&x}, {0.9}, {}); + sd::ops::clipbyavgnorm op; + auto result = op.evaluate({&x}, {0.9}, {}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests3, Test_ClipByNorm_1) { - auto x= NDArrayFactory::create('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); - auto exp= NDArrayFactory::create('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0}); + auto x = NDArrayFactory::create('c', {2, 3}, + {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); + auto exp = NDArrayFactory::create('c', {2, 3}, + {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0}); - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {4.0}, {}); + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {4.0}, {}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests3, Test_ClipByNorm_2) { - auto x= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); - auto exp= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {6.0}, {}); + auto x = NDArrayFactory::create( + 'c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); - auto z = result.at(0); + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {6.0}, {}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, Test_ClipByNorm_3) { + auto x = NDArrayFactory::create('c', {3, 5}); + auto unities = NDArrayFactory::create('c', {3, 1}, {1., 1., 1.}); + auto scale = NDArrayFactory::create('c', {3, 1}, {1.1, 1., 0.9}); - auto x = NDArrayFactory::create('c', {3, 5}); - auto unities = NDArrayFactory::create('c', {3, 1}, {1., 1., 1.}); - auto scale = NDArrayFactory::create('c', {3, 1}, {1.1, 1., 0.9}); + x.linspace(100.); - x.linspace(100.); + auto xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); + x /= xNorm1; + xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); - auto xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); - x /= xNorm1; - xNorm1 = x.reduceAlongDimension(reduce::Norm2,{1}, true); + ASSERT_TRUE(unities.isSameShape(xNorm1)); + ASSERT_TRUE(unities.equalsTo(xNorm1)); - ASSERT_TRUE(unities.isSameShape(xNorm1)); - ASSERT_TRUE(unities.equalsTo(xNorm1)); + x *= scale; + xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); - x *= scale; - xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {1.0}, {1}); + auto z = result.at(0); - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {1.0}, {1}); - auto z = result.at(0); - - auto zNorm1 = z.reduceAlongDimension(reduce::Norm2, {1}, true); - auto exp = NDArrayFactory::create('c', {3, 1}, {1., 1., xNorm1.e(2)}); - - ASSERT_TRUE(exp.isSameShape(&zNorm1)); - ASSERT_TRUE(exp.equalsTo(&zNorm1)); + auto zNorm1 = z.reduceAlongDimension(reduce::Norm2, {1}, true); + auto exp = NDArrayFactory::create('c', {3, 1}, + {1., 1., xNorm1.e(2)}); + ASSERT_TRUE(exp.isSameShape(&zNorm1)); + ASSERT_TRUE(exp.equalsTo(&zNorm1)); } TEST_F(DeclarableOpsTests3, Test_ListDiff_1) { - auto x= NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto y= NDArrayFactory::create('c', {3}, {1.f, 3.f, 5.f}); - - auto exp0= NDArrayFactory::create('c', {3}, {2.f, 4.f, 6.f}); - auto exp1= NDArrayFactory::create('c', {3}, {1, 3, 5}); + auto x = + NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto y = NDArrayFactory::create('c', {3}, {1.f, 3.f, 5.f}); - sd::ops::listdiff op; - auto result = op.evaluate({&x, &y}); + auto exp0 = NDArrayFactory::create('c', {3}, {2.f, 4.f, 6.f}); + auto exp1 = NDArrayFactory::create('c', {3}, {1, 3, 5}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::listdiff op; + auto result = op.evaluate({&x, &y}); - auto z0 = result.at(0); - auto z1 = result.at(1); + ASSERT_EQ(Status::OK(), result.status()); - z0.getDataBuffer()->syncToSpecial(true); // force sync - z1.getDataBuffer()->syncToSpecial(true); // force sync + auto z0 = result.at(0); + auto z1 = result.at(1); - ASSERT_TRUE(exp0.isSameShape(z0)); - ASSERT_TRUE(exp0.equalsTo(z0)); + z0.getDataBuffer()->syncToSpecial(true); // force sync + z1.getDataBuffer()->syncToSpecial(true); // force sync - ASSERT_TRUE(exp1.isSameShape(z1)); - ASSERT_TRUE(exp1.equalsTo(z1)); + ASSERT_TRUE(exp0.isSameShape(z0)); + ASSERT_TRUE(exp0.equalsTo(z0)); + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_TRUE(exp1.equalsTo(z1)); } TEST_F(DeclarableOpsTests3, Test_Range_1) { - auto start = NDArrayFactory::create(0.3f); - auto stop = NDArrayFactory::create(-5.f); - auto step = NDArrayFactory::create(-0.33f); - auto exp= NDArrayFactory::create('c', {17}, { 0.3f, -0.03f, -0.36f, -0.69f, -1.02f, -1.35f, -1.68f, -2.01f, -2.34f, -2.67f, -3.f, -3.33f, -3.66f, -3.99f, -4.32f, -4.65f, -4.98f}); + auto start = NDArrayFactory::create(0.3f); + auto stop = NDArrayFactory::create(-5.f); + auto step = NDArrayFactory::create(-0.33f); + auto exp = NDArrayFactory::create( + 'c', {17}, + {0.3f, -0.03f, -0.36f, -0.69f, -1.02f, -1.35f, -1.68f, -2.01f, -2.34f, + -2.67f, -3.f, -3.33f, -3.66f, -3.99f, -4.32f, -4.65f, -4.98f}); - sd::ops::range op; - auto result = op.evaluate({&start, &stop, &step}); + sd::ops::range op; + auto result = op.evaluate({&start, &stop, &step}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests3, Test_Range_2) { - auto start= NDArrayFactory::create('c', {1, 1}, {2.f}); - auto stop= NDArrayFactory::create('c', {1, 1}, {0.f}); - auto step= NDArrayFactory::create('c', {1, 1}, {-1.f}); - auto exp= NDArrayFactory::create('c', {2}, {2.f, 1.f}); + auto start = NDArrayFactory::create('c', {1, 1}, {2.f}); + auto stop = NDArrayFactory::create('c', {1, 1}, {0.f}); + auto step = NDArrayFactory::create('c', {1, 1}, {-1.f}); + auto exp = NDArrayFactory::create('c', {2}, {2.f, 1.f}); - sd::ops::range op; - auto result = op.evaluate({&start, &stop, &step}); + sd::ops::range op; + auto result = op.evaluate({&start, &stop, &step}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests3, Test_Range_3) { - auto start= NDArrayFactory::create('c', {1, 1}, {0.f}); - auto stop= NDArrayFactory::create('c', {1, 1}, {2.f}); - auto step= NDArrayFactory::create('c', {1, 1}, {1.f}); - auto exp= NDArrayFactory::create('c', {2}, {0.f, 1.f}); - - sd::ops::range op; - auto result = op.evaluate({&start, &stop, &step}); + auto start = NDArrayFactory::create('c', {1, 1}, {0.f}); + auto stop = NDArrayFactory::create('c', {1, 1}, {2.f}); + auto step = NDArrayFactory::create('c', {1, 1}, {1.f}); + auto exp = NDArrayFactory::create('c', {2}, {0.f, 1.f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::range op; + auto result = op.evaluate({&start, &stop, &step}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests3, Test_Range_10) { - auto start= NDArrayFactory::create('c', {1, 1}, {0.f}); - auto stop= NDArrayFactory::create('c', {1, 1}, {2.f}); - auto step= NDArrayFactory::create('c', {1, 1}, {1.f}); - auto exp= NDArrayFactory::create('c', {2}, {0.f, 1.f}); - - sd::ops::range op; - auto result = op.evaluate({&start, &stop, &step}, {sd::DataType::DOUBLE}); + auto start = NDArrayFactory::create('c', {1, 1}, {0.f}); + auto stop = NDArrayFactory::create('c', {1, 1}, {2.f}); + auto step = NDArrayFactory::create('c', {1, 1}, {1.f}); + auto exp = NDArrayFactory::create('c', {2}, {0.f, 1.f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::range op; + auto result = op.evaluate({&start, &stop, &step}, {sd::DataType::DOUBLE}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests3, Test_Range_4) { - auto exp= NDArrayFactory::create('c', {13}, {-10.f, -8.334f, -6.668f, -5.002f, -3.336f, -1.67f, -0.004f, 1.662f, 3.328f, 4.994f, 6.66f, 8.326f, 9.992f}); - - sd::ops::range op; - auto result = op.evaluate({}, {-10., 10., 1.666}, {}); + auto exp = NDArrayFactory::create( + 'c', {13}, + {-10.f, -8.334f, -6.668f, -5.002f, -3.336f, -1.67f, -0.004f, 1.662f, + 3.328f, 4.994f, 6.66f, 8.326f, 9.992f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::range op; + auto result = op.evaluate({}, {-10., 10., 1.666}, {}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests3, Test_Range_5) { - auto exp= NDArrayFactory::create('c', {2}, {2.f, 1.f}); - - sd::ops::range op; - auto result = op.evaluate({}, {2, 0, -1}, {}); + auto exp = NDArrayFactory::create('c', {2}, {2.f, 1.f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::range op; + auto result = op.evaluate({}, {2, 0, -1}, {}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests3, Test_Range_6) { - auto exp= NDArrayFactory::create('c', {2}, {0.f, 1.f}); + auto exp = NDArrayFactory::create('c', {2}, {0.f, 1.f}); - sd::ops::range op; - auto result = op.evaluate({}, {0, 2, 1}, {}); + sd::ops::range op; + auto result = op.evaluate({}, {0, 2, 1}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests3, Test_Range_7) { - auto exp= NDArrayFactory::create('c', {10}, {10.f, 8.334f, 6.668f, 5.002f, 3.336f, 1.67f, 0.004f, -1.662f, -3.328f, -4.994f}); - - sd::ops::range op; - auto result = op.evaluate({}, {10,-5,-1.666}, {}); + auto exp = + NDArrayFactory::create('c', {10}, + {10.f, 8.334f, 6.668f, 5.002f, 3.336f, + 1.67f, 0.004f, -1.662f, -3.328f, -4.994f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::range op; + auto result = op.evaluate({}, {10, -5, -1.666}, {}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - - TEST_F(DeclarableOpsTests3, Test_Range_8) { - auto exp= NDArrayFactory::create('c', {2}, {2, 1}); - - sd::ops::range op; - auto result = op.evaluate({}, {}, {2, 0, -1}); + auto exp = NDArrayFactory::create('c', {2}, {2, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::range op; + auto result = op.evaluate({}, {}, {2, 0, -1}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests3, Test_Range_9) { - auto exp= NDArrayFactory::create('c', {2}, {0, 1}); + auto exp = NDArrayFactory::create('c', {2}, {0, 1}); - sd::ops::range op; - auto result = op.evaluate({}, {}, {0, 2, 1}); + sd::ops::range op; + auto result = op.evaluate({}, {}, {0, 2, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_1) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y= NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = + NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = MmulHelper::mmul(&x, &y); + auto exp = MmulHelper::mmul(&x, &y); - sd::ops::batched_gemm op; - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 3, 3, 3, 3, 3, 3, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {111, 111, 3, 3, 3, 3, 3, 3, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(3, result.size()); - for (int e = 0; e < 3; e++) { - auto z = result.at(e); + for (int e = 0; e < 3; e++) { + auto z = result.at(e); -// exp->printIndexedBuffer("e"); -// z->printIndexedBuffer("z"); + // exp->printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - } + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } - delete exp; - + delete exp; } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y= NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = MmulHelper::mmul(&x, &y); + auto exp = MmulHelper::mmul(&x, &y); - sd::ops::batched_gemm op; - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 3, 3, 3, 3, 3, 3, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {112, 112, 3, 3, 3, 3, 3, 3, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(3, result.size()); - for (int e = 0; e < 3; e++) { - auto z = result.at(e); + for (int e = 0; e < 3; e++) { + auto z = result.at(e); - //exp->printIndexedBuffer("e"); - //z->printIndexedBuffer("z"); + // exp->printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - } + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } - delete exp; - + delete exp; } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y= NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = MmulHelper::mmul(&x, &y); + auto exp = MmulHelper::mmul(&x, &y); - sd::ops::batched_gemm op; - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 111, 3, 3, 3, 3, 3, 3, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {112, 111, 3, 3, 3, 3, 3, 3, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(3, result.size()); - for (int e = 0; e < 3; e++) { - auto z = result.at(e); + for (int e = 0; e < 3; e++) { + auto z = result.at(e); -// exp->printIndexedBuffer("e"); -// z->printIndexedBuffer("z"); + // exp->printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - } + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } - delete exp; - + delete exp; } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - auto y= NDArrayFactory::create('f', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = NDArrayFactory::create( + 'f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto y = NDArrayFactory::create( + 'f', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = MmulHelper::mmul(&x, &y); + auto exp = MmulHelper::mmul(&x, &y); - sd::ops::batched_gemm op; - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 5, 4, 3, 5, 3, 5, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {111, 111, 5, 4, 3, 5, 3, 5, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(3, result.size()); - for (int e = 0; e < 3; e++) { - auto z = result.at(e); + for (int e = 0; e < 3; e++) { + auto z = result.at(e); - //exp->printIndexedBuffer("e"); - //z->printIndexedBuffer("z"); + // exp->printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - } + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } - delete exp; - + delete exp; } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - auto y= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = NDArrayFactory::create( + 'c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto y = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = MmulHelper::mmul(&x, &y); + auto exp = MmulHelper::mmul(&x, &y); - sd::ops::batched_gemm op; - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 5, 4, 3, 3, 4, 5, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {112, 112, 5, 4, 3, 3, 4, 5, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(3, result.size()); - for (int e = 0; e < 3; e++) { - auto z = result.at(e); + for (int e = 0; e < 3; e++) { + auto z = result.at(e); - //exp->printIndexedBuffer("e"); - //z->printIndexedBuffer("z"); + // exp->printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - } + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } - delete exp; - + delete exp; } - TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_6) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('f', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - auto y= NDArrayFactory::create('f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = NDArrayFactory::create('f', {2, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + auto y = NDArrayFactory::create( + 'f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - auto exp = MmulHelper::mmul(&x, &y); + auto exp = MmulHelper::mmul(&x, &y); - sd::ops::batched_gemm op; - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 2, 3, 5, 2, 5, 2, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {111, 111, 2, 3, 5, 2, 5, 2, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(3, result.size()); - for (int e = 0; e < 3; e++) { - auto z = result.at(e); + for (int e = 0; e < 3; e++) { + auto z = result.at(e); - //exp->printIndexedBuffer("e"); - //z->printIndexedBuffer("z"); + // exp->printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - } + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } - delete exp; - + delete exp; } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) { - auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); - auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); - auto x= NDArrayFactory::create('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - auto y= NDArrayFactory::create('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto a = NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b = NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x = NDArrayFactory::create('c', {2, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + auto y = NDArrayFactory::create( + 'c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - auto exp = MmulHelper::mmul(&x, &y); + auto exp = MmulHelper::mmul(&x, &y); - // exp->printShapeInfo("exp shape"); + // exp->printShapeInfo("exp shape"); - sd::ops::batched_gemm op; - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {112, 112, 2, 3, 5, 5, 3, 2, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(3, result.size()); - for (int e = 0; e < 3; e++) { - auto z = result.at(e); + for (int e = 0; e < 3; e++) { + auto z = result.at(e); - //exp->printIndexedBuffer("e"); - //z->printIndexedBuffer("z"); + // exp->printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - ASSERT_TRUE(exp->isSameShape(z)); - ASSERT_TRUE(exp->equalsTo(z)); - } + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } - delete exp; - + delete exp; } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) { - auto a = NDArrayFactory::create('c', {1, 3}, {1.f, 1.f, 1.f}); - auto b = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); - auto x = NDArrayFactory::create('c', {2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); - auto y = NDArrayFactory::create('c', {5, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); - - sd::ops::batched_gemm op; - try { - auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}); - - ASSERT_TRUE(false); - } catch (std::invalid_argument &e) { - // - } + auto a = NDArrayFactory::create('c', {1, 3}, {1.f, 1.f, 1.f}); + auto b = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); + auto x = NDArrayFactory::create( + 'c', {2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); + auto y = + NDArrayFactory::create('c', {5, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + + sd::ops::batched_gemm op; + try { + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, + {112, 112, 2, 3, 5, 5, 3, 2, 3}); + + ASSERT_TRUE(false); + } catch (std::invalid_argument &e) { + // + } } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_2) { - auto a = NDArrayFactory::create('c', {1, 3}, {1.f, 1.f, 1.f}); - auto b = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); - auto x = NDArrayFactory::create('c', {2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); - auto y = NDArrayFactory::create('c', {5, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); - - auto z = NDArrayFactory::create('c', {2, 3}); - - sd::ops::batched_gemm op; - try { - auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {&z}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}, {}); - ASSERT_TRUE(false); - } catch (std::invalid_argument &e) { - // - } + auto a = NDArrayFactory::create('c', {1, 3}, {1.f, 1.f, 1.f}); + auto b = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); + auto x = NDArrayFactory::create( + 'c', {2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); + auto y = + NDArrayFactory::create('c', {5, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + + auto z = NDArrayFactory::create('c', {2, 3}); + + sd::ops::batched_gemm op; + try { + auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {&z}, {}, + {112, 112, 2, 3, 5, 5, 3, 2, 3}, {}); + ASSERT_TRUE(false); + } catch (std::invalid_argument &e) { + // + } } TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) { - auto x= NDArrayFactory::create('c', {1, 3}, {2, 2, 2}); - auto y= NDArrayFactory::create('c', {1, 3}, {4, 6, 8}); - auto exp= NDArrayFactory::create('c', {1, 3}, {2, 3, 4}); - - sd::ops::reversedivide op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {1, 3}, {2, 2, 2}); + auto y = NDArrayFactory::create('c', {1, 3}, {4, 6, 8}); + auto exp = NDArrayFactory::create('c', {1, 3}, {2, 3, 4}); - auto z = result.at(0); + sd::ops::reversedivide op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, sruCell_test1) { + const int batchSize = 2; + const int inSize = 5; - const int batchSize = 2; - const int inSize = 5; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ct_1= NDArrayFactory::create('c', {batchSize, inSize}); - auto w = NDArrayFactory::create('c', {inSize, 3*inSize}); - auto b = NDArrayFactory::create('c', {2*inSize}); - - xt.assign(1.); - ct_1.assign(2.); - w.assign(0.5); - b.assign(0.7); + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, inSize}); + auto w = NDArrayFactory::create('c', {inSize, 3 * inSize}); + auto b = NDArrayFactory::create('c', {2 * inSize}); - auto expHt= NDArrayFactory::create('c', {batchSize, inSize}, {0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f}); - auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f}); + xt.assign(1.); + ct_1.assign(2.); + w.assign(0.5); + b.assign(0.7); - sd::ops::sruCell op; - auto results = op.evaluate({&xt, &ct_1, &w, &b}); + auto expHt = NDArrayFactory::create( + 'c', {batchSize, inSize}, + {0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, + 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, inSize}, + {2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, + 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::sruCell op; + auto results = op.evaluate({&xt, &ct_1, &w, &b}); - auto ht = results.at(0); - auto ct = results.at(1); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); + auto ht = results.at(0); + auto ct = results.at(1); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } - //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, sruCell_test2) { + const int batchSize = 2; + const int inSize = 5; - const int batchSize = 2; - const int inSize = 5; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ct_1= NDArrayFactory::create('c', {batchSize, inSize}); - auto w = NDArrayFactory::create('c', {inSize, 3*inSize}); - auto b = NDArrayFactory::create('c', {2*inSize}); + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, inSize}); + auto w = NDArrayFactory::create('c', {inSize, 3 * inSize}); + auto b = NDArrayFactory::create('c', {2 * inSize}); - xt.assign(1.); - ct_1.assign(2.); - w.assign(0.5); - b.assign(-1.); + xt.assign(1.); + ct_1.assign(2.); + w.assign(0.5); + b.assign(-1.); - auto expHt= NDArrayFactory::create('c', {batchSize, inSize}, {0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f}); - auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f}); + auto expHt = NDArrayFactory::create( + 'c', {batchSize, inSize}, + {0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, + 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, inSize}, + {2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, + 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f}); - sd::ops::sruCell op; - auto results = op.evaluate({&xt, &ct_1, &w, &b}); + sd::ops::sruCell op; + auto results = op.evaluate({&xt, &ct_1, &w, &b}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto ht = results.at(0); - auto ct = results.at(1); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); + auto ht = results.at(0); + auto ct = results.at(1); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, sruCell_test3) { + const int batchSize = 2; + const int inSize = 5; - const int batchSize = 2; - const int inSize = 5; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ct_1= NDArrayFactory::create('c', {batchSize, inSize}); - auto w = NDArrayFactory::create('c', {inSize, 3*inSize}); - auto b = NDArrayFactory::create('c', {2*inSize}); + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, inSize}); + auto w = NDArrayFactory::create('c', {inSize, 3 * inSize}); + auto b = NDArrayFactory::create('c', {2 * inSize}); - xt.assign(10.); - ct_1.assign(1.); - w.assign(0.5); - b.assign(-1.); + xt.assign(10.); + ct_1.assign(1.); + w.assign(0.5); + b.assign(-1.); - auto expHt= NDArrayFactory::create('c', {batchSize, inSize}, {0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f}); - auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); + auto expHt = NDArrayFactory::create( + 'c', {batchSize, inSize}, + {0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, + 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f}); + auto expCt = NDArrayFactory::create( + 'c', {batchSize, inSize}, + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - sd::ops::sruCell op; - auto results = op.evaluate({&xt, &ct_1, &w, &b}); + sd::ops::sruCell op; + auto results = op.evaluate({&xt, &ct_1, &w, &b}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto ht = results.at(0); - auto ct = results.at(1); + auto ht = results.at(0); + auto ct = results.at(1); - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - ASSERT_TRUE(expCt.isSameShape(ct)); - ASSERT_TRUE(expCt.equalsTo(ct)); - - + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, gruCell_test1) { + const int batchSize = 2; + const int inSize = 10; + const int numUnits = 4; - const int batchSize = 2; - const int inSize = 10; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wru = NDArrayFactory::create('c', {(inSize+numUnits), 2*numUnits}); - auto Wc = NDArrayFactory::create('c', {(inSize+numUnits), numUnits}); - auto bru = NDArrayFactory::create('c', {2*numUnits}); - auto bc = NDArrayFactory::create('c', {numUnits}); - - xt.assign(1.); - ht_1.assign(2.); - Wru.assign(0.5); - Wc.assign(0.5); - bru.assign(0.7); - bc.assign(0.7); + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wru = + NDArrayFactory::create('c', {(inSize + numUnits), 2 * numUnits}); + auto Wc = NDArrayFactory::create('c', {(inSize + numUnits), numUnits}); + auto bru = NDArrayFactory::create('c', {2 * numUnits}); + auto bc = NDArrayFactory::create('c', {numUnits}); - auto expHt = NDArrayFactory::create('c', {batchSize, numUnits}, {1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f}); + xt.assign(1.); + ht_1.assign(2.); + Wru.assign(0.5); + Wc.assign(0.5); + bru.assign(0.7); + bc.assign(0.7); - sd::ops::gruCell op; - auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc}); + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, + 1.99993872f, 1.99993872f, 1.99993872f}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + sd::ops::gruCell op; + auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc}); - auto ht = results.at(3); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); + auto ht = results.at(3); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, gruCell_test2) { + const int batchSize = 2; + const int inSize = 10; + const int numUnits = 4; - const int batchSize = 2; - const int inSize = 10; - const int numUnits = 4; + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wru = + NDArrayFactory::create('c', {(inSize + numUnits), 2 * numUnits}); + auto Wc = NDArrayFactory::create('c', {(inSize + numUnits), numUnits}); + auto bru = NDArrayFactory::create('c', {2 * numUnits}); + auto bc = NDArrayFactory::create('c', {numUnits}); - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wru = NDArrayFactory::create('c', {(inSize+numUnits), 2*numUnits}); - auto Wc = NDArrayFactory::create('c', {(inSize+numUnits), numUnits}); - auto bru = NDArrayFactory::create('c', {2*numUnits}); - auto bc = NDArrayFactory::create('c', {numUnits}); + xt.assign(1.); + ht_1.assign(0.); + Wru.assign(1.5); + Wc.assign(1.5); + bru.assign(-10); + bc.assign(-10); - xt.assign(1.); - ht_1.assign(0.); - Wru.assign(1.5); - Wc.assign(1.5); - bru.assign(-10); - bc.assign(-10); + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, + 0.00669224f, 0.00669224f, 0.00669224f}); - auto expHt= NDArrayFactory::create('c', {batchSize, numUnits}, {0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f}); + sd::ops::gruCell op; + auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc}); - sd::ops::gruCell op; - auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto ht = results.at(3); - auto ht = results.at(3); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); - - + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, gruCell_test3) { + const int batchSize = 2; + const int inSize = 10; + const int numUnits = 4; - const int batchSize = 2; - const int inSize = 10; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {batchSize, inSize}); - auto ht_1= NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wru = NDArrayFactory::create('c', {(inSize+numUnits), 2*numUnits}); - auto Wc = NDArrayFactory::create('c', {(inSize+numUnits), numUnits}); - auto bru = NDArrayFactory::create('c', {2*numUnits}); - auto bc = NDArrayFactory::create('c', {numUnits}); - - xt.assign(1.); - ht_1.assign(0.); - Wru.assign(0.1); - Wc.assign(0.1); - bru.assign(1); - bc.assign(1); + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wru = + NDArrayFactory::create('c', {(inSize + numUnits), 2 * numUnits}); + auto Wc = NDArrayFactory::create('c', {(inSize + numUnits), numUnits}); + auto bru = NDArrayFactory::create('c', {2 * numUnits}); + auto bc = NDArrayFactory::create('c', {numUnits}); - auto expHt= NDArrayFactory::create('c', {batchSize, numUnits}, {0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f}); + xt.assign(1.); + ht_1.assign(0.); + Wru.assign(0.1); + Wc.assign(0.1); + bru.assign(1); + bc.assign(1); + auto expHt = NDArrayFactory::create( + 'c', {batchSize, numUnits}, + {0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, + 0.1149149f, 0.1149149f}); - sd::ops::gruCell op; - auto result = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc}); + sd::ops::gruCell op; + auto result = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto ht = result.at(3); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); + auto ht = result.at(3); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, invertPermutation_test1) { + auto input = + NDArrayFactory::create('c', {1, 8}, {5, 2, 7, 4, 6, 3, 1, 0}); + auto expected = + NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); - auto input= NDArrayFactory::create('c', {1, 8}, {5,2,7,4,6,3,1,0}); - auto expected= NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); - - sd::ops::invert_permutation op; - auto result = op.evaluate({&input}); + sd::ops::invert_permutation op; + auto result = op.evaluate({&input}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, invertPermutation_test2) { + auto input = + NDArrayFactory::create('c', {1, 8}, {5, 2, 7, 4, 6, 3, 1, 0}); + auto expected = + NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); - auto input= NDArrayFactory::create('c', {1, 8}, {5,2,7,4,6,3,1,0}); - auto expected= NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); - + sd::ops::invert_permutation op; + auto result = op.evaluate({&input}); - sd::ops::invert_permutation op; - auto result = op.evaluate({&input}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, invertPermutation_test3) { + auto input = + NDArrayFactory::create('c', {1, 8}, {1, 2, 0, 4, 6, 3, 5, 7}); + auto expected = + NDArrayFactory::create('c', {1, 8}, {2, 0, 1, 5, 3, 6, 4, 7}); - auto input= NDArrayFactory::create('c', {1, 8}, {1,2,0,4,6,3,5,7}); - auto expected= NDArrayFactory::create('c', {1, 8}, {2, 0, 1, 5, 3, 6, 4, 7}); - - sd::ops::invert_permutation op; - auto result = op.evaluate({&input}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::invert_permutation op; + auto result = op.evaluate({&input}); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test1) { + auto input = NDArrayFactory::create('c', {3, 2}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {3, 2}); - input.linspace(1); + auto expected = NDArrayFactory::create( + 'c', {3, 2, 3, 2}, + {1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, + 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 6}); - auto expected= NDArrayFactory::create('c', {3,2,3,2}, {1,0,0,0,0,0, 0,2,0,0,0,0, 0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6}); + sd::ops::diag op; + auto result = op.evaluate({&input}); - sd::ops::diag op; - auto result = op.evaluate({&input}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test2) { + auto input = NDArrayFactory::create('c', {2, 3}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {2, 3}); - input.linspace(1); - - auto expected= NDArrayFactory::create('c', {2,3,2,3}, {1,0,0,0,0,0, 0,2,0,0,0,0, 0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6}); - - sd::ops::diag op; - auto result = op.evaluate({&input}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2, 3}, + {1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, + 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 6}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::diag op; + auto result = op.evaluate({&input}); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test_vector) { + auto input = NDArrayFactory::linspace(1, 4, 4); + auto expected = NDArrayFactory::create( + 'c', {4, 4}, {1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 4}); + sd::ops::diag op; + auto result = op.evaluate({input}); - auto input = NDArrayFactory::linspace(1,4,4); - auto expected= NDArrayFactory::create('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4}); - - sd::ops::diag op; - auto result = op.evaluate({input}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete input; + delete input; } TEST_F(DeclarableOpsTests3, diag_test_col_vector) { + auto input = NDArrayFactory::linspace(1, 4, 4); + input->reshapei({4, 1}); + auto expected = NDArrayFactory::create( + 'c', {4, 4}, {1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 4}); + sd::ops::diag op; + auto result = op.evaluate({input}, {}, {}); - auto input = NDArrayFactory::linspace(1,4,4); - input->reshapei({4,1}); - auto expected= NDArrayFactory::create('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4}); - - sd::ops::diag op; - auto result = op.evaluate({input}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); - - delete input; + delete input; } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test3) { + auto input = NDArrayFactory::create('c', {1, 3}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {1, 3}); - input.linspace(1); + auto expected = + NDArrayFactory::create('c', {3, 3}, {1, 0, 0, 0, 2, 0, 0, 0, 3}); - auto expected= NDArrayFactory::create('c', {3,3}, {1,0,0, 0,2,0, 0,0,3}); + sd::ops::diag op; + auto result = op.evaluate({&input}, {}, {}); - sd::ops::diag op; - auto result = op.evaluate({&input}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test4) { + auto input = NDArrayFactory::create('c', {3, 1}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {3, 1}); - input.linspace(1); - - auto expected= NDArrayFactory::create('c', {3,3}, {1,0,0, 0,2,0, 0,0,3}); - - sd::ops::diag op; - auto result = op.evaluate({&input}, {}, {}); + auto expected = + NDArrayFactory::create('c', {3, 3}, {1, 0, 0, 0, 2, 0, 0, 0, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::diag op; + auto result = op.evaluate({&input}, {}, {}); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test5) { + auto input = NDArrayFactory::create('c', {1, 1}); + input.linspace(2); - auto input= NDArrayFactory::create('c', {1, 1}); - input.linspace(2); + auto expected = NDArrayFactory::create('c', {1, 1}, {2}); - auto expected= NDArrayFactory::create('c', {1,1}, {2}); + sd::ops::diag op; + auto result = op.evaluate({&input}, {}, {}); - sd::ops::diag op; - auto result = op.evaluate({&input}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diag_test6) { + auto input = NDArrayFactory::create('c', {2, 2, 2}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {2,2,2}); - input.linspace(1); - - auto expected= NDArrayFactory::create('c', {2,2,2,2,2,2}, {1,0,0,0, 0,0,0,0, 0,2,0,0, 0,0,0,0, 0,0,3,0, 0,0,0,0, 0,0,0,4, 0,0,0,0, 0,0,0,0, 5,0,0,0, 0,0,0,0, 0,6,0,0, 0,0,0,0, 0,0,7,0, 0,0,0,0, 0,0,0,8}); - - sd::ops::diag op; - auto result = op.evaluate({&input}, {}, {}); + auto expected = NDArrayFactory::create( + 'c', {2, 2, 2, 2, 2, 2}, + {1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, + 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, + 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 8}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::diag op; + auto result = op.evaluate({&input}, {}, {}); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, matrixSetDiag_test1) { + auto input = NDArrayFactory::create('c', {4, 3, 2}); + auto diagonal = NDArrayFactory::create('c', {4, 2}); + input.assign(0.); + diagonal.assign(1.); - auto input= NDArrayFactory::create('c', {4,3,2}); - auto diagonal= NDArrayFactory::create('c', {4,2}); - input.assign(0.); - diagonal.assign(1.); + auto expected = NDArrayFactory::create( + 'c', {4, 3, 2}, + {1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0}); - auto expected= NDArrayFactory::create('c', {4,3,2}, {1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0}); + sd::ops::matrix_set_diag op; + auto result = op.evaluate({&input, &diagonal}, {}, {}); - sd::ops::matrix_set_diag op; - auto result = op.evaluate({&input, &diagonal}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, matrixSetDiag_test2) { + auto input = NDArrayFactory::create('c', {1, 1, 2}); + auto diagonal = NDArrayFactory::create('c', {1, 1}); + input.assign(0.); + diagonal.assign(1.); - auto input= NDArrayFactory::create('c', {1,1,2}); - auto diagonal= NDArrayFactory::create('c', {1,1}); - input.assign(0.); - diagonal.assign(1.); - - auto expected= NDArrayFactory::create('c', {1,1,2}, {1.f, 0.f}); + auto expected = NDArrayFactory::create('c', {1, 1, 2}, {1.f, 0.f}); - sd::ops::matrix_set_diag op; - auto result = op.evaluate({&input, &diagonal}, {}, {}); + sd::ops::matrix_set_diag op; + auto result = op.evaluate({&input, &diagonal}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, matrixSetDiag_test3) { + auto input = NDArrayFactory::create('c', {2, 1, 4}); + auto diagonal = NDArrayFactory::create('c', {2, 1}); + input.assign(0.); + diagonal.assign(1.); - auto input= NDArrayFactory::create('c', {2,1,4}); - auto diagonal= NDArrayFactory::create('c', {2,1}); - input.assign(0.); - diagonal.assign(1.); - - auto expected= NDArrayFactory::create('c', {2,1,4}, {1,0,0,0,1,0,0,0}); + auto expected = + NDArrayFactory::create('c', {2, 1, 4}, {1, 0, 0, 0, 1, 0, 0, 0}); - sd::ops::matrix_set_diag op; - auto result = op.evaluate({&input, &diagonal}, {}, {}); + sd::ops::matrix_set_diag op; + auto result = op.evaluate({&input, &diagonal}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, matrixSetDiag_test4) { + auto input = NDArrayFactory::create('c', {2, 1, 4, 1}); + auto diagonal = NDArrayFactory::create('c', {2, 1, 1}); + input.assign(0.); + diagonal.assign(1.); - auto input= NDArrayFactory::create('c', {2,1,4,1}); - auto diagonal= NDArrayFactory::create('c', {2,1,1}); - input.assign(0.); - diagonal.assign(1.); - - auto expected= NDArrayFactory::create('c', {2,1,4,1}, {1,0,0,0,1,0,0,0}); + auto expected = NDArrayFactory::create('c', {2, 1, 4, 1}, + {1, 0, 0, 0, 1, 0, 0, 0}); - sd::ops::matrix_set_diag op; - auto result = op.evaluate({&input, &diagonal}, {}, {}); + sd::ops::matrix_set_diag op; + auto result = op.evaluate({&input, &diagonal}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diagPart_test1) { + auto input = NDArrayFactory::create('c', {2, 2}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {2,2}); - input.linspace(1); - - auto expected= NDArrayFactory::create('c', {2}, {1,4}); + auto expected = NDArrayFactory::create('c', {2}, {1, 4}); - sd::ops::diag_part op; - auto result = op.evaluate({&input}, {}, {}); + sd::ops::diag_part op; + auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - // output->printBuffer(); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + // output->printBuffer(); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diagPart_test2) { + auto input = NDArrayFactory::create('c', {2, 2, 2, 2}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {2,2,2,2}); - input.linspace(1); - - auto expected= NDArrayFactory::create('c', {2,2}, {1,6,11,16}); + auto expected = NDArrayFactory::create('c', {2, 2}, {1, 6, 11, 16}); - sd::ops::diag_part op; - auto result = op.evaluate({&input}, {}, {}); + sd::ops::diag_part op; + auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, diagPart_test3) { + auto input = NDArrayFactory::create('c', {2, 2, 2, 2, 2, 2}); + input.linspace(1); - auto input= NDArrayFactory::create('c', {2,2,2,2,2,2}); - input.linspace(1); - - auto expected= NDArrayFactory::create('c', {2,2,2}, {1,10,19,28,37,46,55,64}); + auto expected = NDArrayFactory::create( + 'c', {2, 2, 2}, {1, 10, 19, 28, 37, 46, 55, 64}); - sd::ops::diag_part op; - auto result = op.evaluate({&input}, {}, {}); + sd::ops::diag_part op; + auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test1) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a = NDArrayFactory::create('c', {3,3}); - auto b = NDArrayFactory::create('c', {3,3}); - auto x = NDArrayFactory::create('c', {3,3}); - - a.linspace((float16)0.1, (float16)0.1); - b.linspace((float16)0.1, (float16)0.1); - x.assign(0.1); + a.linspace((float16)0.1, (float16)0.1); + b.linspace((float16)0.1, (float16)0.1); + x.assign(0.1); - auto expected = NDArrayFactory::create('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, + 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output, 1e-2)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-2)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test2) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - a.linspace(0.1, 0.1); - b.linspace(0.1, 0.1); - x.assign(0.1); - - auto expected= NDArrayFactory::create('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); + a.linspace(0.1, 0.1); + b.linspace(0.1, 0.1); + x.assign(0.1); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, + 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test3) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); + a.linspace(0.1, 0.1); + b.linspace(0.1, 0.1); + x.assign(0.1); - a.linspace(0.1, 0.1); - b.linspace(0.1, 0.1); - x.assign(0.1); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, + 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); - auto expected= NDArrayFactory::create('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test4) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - a.linspace(1); - b.linspace(1); - x.assign(0.1); - - auto expected= NDArrayFactory::create('c', {3,3}, {1.00000000e-01f, 2.80000000e-02f, 8.56000000e-03f, 2.72800000e-03f, 8.90920000e-04f, 2.95706080e-04f, 9.92854864e-05f, 3.36248880e-05f, 1.14644360e-05f}); + a.linspace(1); + b.linspace(1); + x.assign(0.1); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {1.00000000e-01f, 2.80000000e-02f, 8.56000000e-03f, 2.72800000e-03f, + 8.90920000e-04f, 2.95706080e-04f, 9.92854864e-05f, 3.36248880e-05f, + 1.14644360e-05f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output, 1e-6)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-6)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test5) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); + a.linspace(3200.); + b.linspace(3200.); + x.assign(0.1); - a.linspace(3200.); - b.linspace(3200.); - x.assign(0.1); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - auto expected= NDArrayFactory::create('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output, 1e-6)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-6)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test6) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - a.linspace(10.); - b.linspace(10.); - x.assign(0.1); - - auto expected= NDArrayFactory::create('c', {3,3}, {3.92988233e-06f, 1.35306497e-06f, 4.67576826e-07f, 1.62083416e-07f, 5.63356971e-08f, 1.96261318e-08f, 6.85120307e-09f, 2.39594668e-09f, 8.39227685e-10f}); + a.linspace(10.); + b.linspace(10.); + x.assign(0.1); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {3.92988233e-06f, 1.35306497e-06f, 4.67576826e-07f, 1.62083416e-07f, + 5.63356971e-08f, 1.96261318e-08f, 6.85120307e-09f, 2.39594668e-09f, + 8.39227685e-10f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output, 1e-6)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-6)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test7) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); + a.linspace(10.); + b.linspace(10.); + x.assign(0.9); - a.linspace(10.); - b.linspace(10.); - x.assign(0.9); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {0.99999607f, 0.99999865f, 0.99999953f, 0.99999984f, 0.99999994f, + 0.99999998f, 0.99999999f, 1.f, 1.f}); - auto expected= NDArrayFactory::create('c', {3,3}, {0.99999607f, 0.99999865f, 0.99999953f, 0.99999984f, 0.99999994f, 0.99999998f, 0.99999999f, 1.f, 1.f}); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output, 1e-6)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-6)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test8) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - a.linspace(10.); - b.linspace(10.); - x.assign(1.); - - auto expected= NDArrayFactory::create('c', {3,3}, {1.f, 1.f, 1.f,1.f,1.f,1.f,1.f,1.f,1.f}); + a.linspace(10.); + b.linspace(10.); + x.assign(1.); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output, 1e-6)); + auto output = result.at(0); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-6)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test9) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); + a.linspace(10.); + b.linspace(10.); + x.assign(0.); - a.linspace(10.); - b.linspace(10.); - x.assign(0.); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - auto expected= NDArrayFactory::create('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test10) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto a= NDArrayFactory::create('c', {3,3}); - auto b= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - a.linspace(10.); - b.linspace(10.); - x.assign(0.5); - - auto expected= NDArrayFactory::create('c', {3,3}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); + a.linspace(10.); + b.linspace(10.); + x.assign(0.5); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test11) { + NDArray a('c', {4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f}, + sd::DataType::FLOAT32); + NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, + sd::DataType::FLOAT32); + NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, + sd::DataType::FLOAT32); - NDArray a('c', {4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f}, sd::DataType::FLOAT32); - NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, sd::DataType::FLOAT32); - NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, sd::DataType::FLOAT32); + NDArray expected('c', {4}, {0.912156, 0.634460, 0.898314, 0.624538}, + sd::DataType::FLOAT32); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - NDArray expected('c', {4}, {0.912156, 0.634460, 0.898314, 0.624538}, sd::DataType::FLOAT32); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, betainc_test12) { + NDArray a('c', {4}, {8.0091f, 8.2108f, 7.5194f, 3.0780f}, + sd::DataType::FLOAT32); + NDArray b('c', {4}, {7.9456f, 9.3527f, 9.8610f, 5.3541f}, + sd::DataType::FLOAT32); + NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, + sd::DataType::FLOAT32); - NDArray a('c', {4}, {8.0091f, 8.2108f, 7.5194f, 3.0780f}, sd::DataType::FLOAT32); - NDArray b('c', {4}, {7.9456f, 9.3527f, 9.8610f, 5.3541f}, sd::DataType::FLOAT32); - NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, sd::DataType::FLOAT32); + NDArray expected('c', {4}, {0.9999995, 0.8594694, 0.999988, 0.49124345}, + sd::DataType::FLOAT32); - NDArray expected('c', {4}, {0.9999995 , 0.8594694 , 0.999988 , 0.49124345}, sd::DataType::FLOAT32); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); - sd::ops::betainc op; - auto result = op.evaluate({&a, &b, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test1) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto q = NDArrayFactory::create('c', {3, 3}); - auto x= NDArrayFactory::create('c', {3,3}); - auto q= NDArrayFactory::create('c', {3,3}); - - q.linspace(1.); - x.assign(2.); + q.linspace(1.); + x.assign(2.); - auto expected= NDArrayFactory::create('c', {3,3}, {1.64493407f, 0.64493407f, 0.39493407f, 0.28382296f, 0.22132296f, 0.18132296f, 0.15354518f, 0.13313701f, 0.11751201f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {1.64493407f, 0.64493407f, 0.39493407f, 0.28382296f, 0.22132296f, + 0.18132296f, 0.15354518f, 0.13313701f, 0.11751201f}); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test2) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto q = NDArrayFactory::create('c', {3, 3}); - auto x= NDArrayFactory::create('c', {3,3}); - auto q= NDArrayFactory::create('c', {3,3}); - - q.linspace(10.); - x.assign(2.); + q.linspace(10.); + x.assign(2.); - auto expected= NDArrayFactory::create('c', {3,3}, {0.10516634f, 0.09516634f, 0.08690187f, 0.07995743f, 0.07404027f, 0.06893823f, 0.06449378f, 0.06058753f, 0.05712733f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {0.10516634f, 0.09516634f, 0.08690187f, 0.07995743f, 0.07404027f, + 0.06893823f, 0.06449378f, 0.06058753f, 0.05712733f}); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test3) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto q = NDArrayFactory::create('c', {3, 3}); - auto x= NDArrayFactory::create('c', {3,3}); - auto q= NDArrayFactory::create('c', {3,3}); - - q.linspace(100.); - x.assign(2.); + q.linspace(100.); + x.assign(2.); - auto expected= NDArrayFactory::create('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, + 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f}); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test4) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto q = NDArrayFactory::create('c', {3, 3}); - auto x= NDArrayFactory::create('c', {3,3}); - auto q= NDArrayFactory::create('c', {3,3}); - - q.linspace(100.); - x.assign(2.); + q.linspace(100.); + x.assign(2.); - auto expected= NDArrayFactory::create('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, + 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f}); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test5) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto q = NDArrayFactory::create('c', {3, 3}); - auto x= NDArrayFactory::create('c', {3,3}); - auto q= NDArrayFactory::create('c', {3,3}); + q.linspace(1.); + x.assign(1.1); - q.linspace(1.); - x.assign(1.1); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {10.58444846f, 9.58444846f, 9.11793197f, 8.81927915f, 8.60164151f, + 8.43137352f, 8.29204706f, 8.17445116f, 8.07291961f}); - auto expected= NDArrayFactory::create('c', {3,3}, {10.58444846f, 9.58444846f, 9.11793197f, 8.81927915f, 8.60164151f, 8.43137352f, 8.29204706f, 8.17445116f, 8.07291961f}); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test6) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto q = NDArrayFactory::create('c', {3, 3}); - auto x= NDArrayFactory::create('c', {3,3}); - auto q= NDArrayFactory::create('c', {3,3}); - - q.linspace(1.); - x.assign(1.01); + q.linspace(1.); + x.assign(1.01); - auto expected= NDArrayFactory::create('c', {3,3}, {100.57794334f, 99.57794334f, 99.08139709f, 98.75170576f, 98.50514758f, 98.30834069f, 98.1446337f, 98.00452955f, 97.88210202f}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {100.57794334f, 99.57794334f, 99.08139709f, 98.75170576f, 98.50514758f, + 98.30834069f, 98.1446337f, 98.00452955f, 97.88210202f}); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test7) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto q = NDArrayFactory::create('c', {3, 3}); - auto x= NDArrayFactory::create('c', {3,3}); - auto q= NDArrayFactory::create('c', {3,3}); - - q.linspace(1.); - x.assign(10.); - - auto expected= NDArrayFactory::create('c', {3,3}, {1.00099458e+00f, 9.94575128e-04f, 1.80126278e-05f, 1.07754001e-06f, 1.23865693e-07f, 2.14656932e-08f, 4.92752156e-09f, 1.38738839e-09f, 4.56065812e-10f}); - + q.linspace(1.); + x.assign(10.); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {1.00099458e+00f, 9.94575128e-04f, 1.80126278e-05f, 1.07754001e-06f, + 1.23865693e-07f, 2.14656932e-08f, 4.92752156e-09f, 1.38738839e-09f, + 4.56065812e-10f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test8) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.01, 1.11, 1.12}); + auto q = NDArrayFactory::create( + 'c', {3, 4}, + {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); - auto x= NDArrayFactory::create('c', {3,4}, {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12}); - auto q= NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); + // q.linspace(1.); + // x.assign(10.); - //q.linspace(1.); - //x.assign(10.); + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, + 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); - auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); - sd::ops::zeta op; - auto result = op.evaluate({&x, &q}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test9) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.01, 1.11, 1.12}); + auto q = NDArrayFactory::create( + 'c', {3, 4}, + {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); + auto z = NDArrayFactory::create( + 'c', {3, 4}, {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); - auto x= NDArrayFactory::create('c', {3,4}, {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12}); - auto q= NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); - auto z= NDArrayFactory::create('c', {3,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.}); + // q.linspace(1.); + // x.assign(10.); - //q.linspace(1.); - //x.assign(10.); + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, + 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); - auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); + sd::ops::zeta op; + auto results = op.execute({&x, &q}, {&z}, {}, {}, {}); - sd::ops::zeta op; - auto results = op.execute({&x, &q}, {&z}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results); - ASSERT_EQ(ND4J_STATUS_OK, results); + // auto output = result.at(0); + // z.printIndexedBuffer("Zeta output"); + ASSERT_TRUE(expected.isSameShape(z)); + ASSERT_TRUE(expected.equalsTo(z)); - //auto output = result.at(0); - // z.printIndexedBuffer("Zeta output"); - ASSERT_TRUE(expected.isSameShape(z)); - ASSERT_TRUE(expected.equalsTo(z)); - -// + // } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test10) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.01, 1.11, 1.12}); + auto q = NDArrayFactory::create( + 'c', {3, 4}, + {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); + auto z = NDArrayFactory::create('c', {3, 4}); - auto x= NDArrayFactory::create('c', {3,4}, {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12}); - auto q= NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); - auto z= NDArrayFactory::create('c', {3,4}); - - //q.linspace(1.); - //x.assign(10.); + // q.linspace(1.); + // x.assign(10.); - auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, + 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); - sd::ops::zeta op; - auto results = op.execute({&x, &q}, {&z}, {}, {}, {}); + sd::ops::zeta op; + auto results = op.execute({&x, &q}, {&z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results); + ASSERT_EQ(ND4J_STATUS_OK, results); - //auto output = result.at(0); - // z.printIndexedBuffer("Zeta output"); - ASSERT_TRUE(expected.isSameShape(z)); - ASSERT_TRUE(expected.equalsTo(z)); + // auto output = result.at(0); + // z.printIndexedBuffer("Zeta output"); + ASSERT_TRUE(expected.isSameShape(z)); + ASSERT_TRUE(expected.equalsTo(z)); -// + // } - TEST_F(DeclarableOpsTests3, Test_SplitV_Validation_1) { - auto x = NDArrayFactory::create('c', {8, 7}); - auto indices = NDArrayFactory::create('c',{2}, {5, 3}); - auto axis = NDArrayFactory::create(-2); + auto x = NDArrayFactory::create('c', {8, 7}); + auto indices = NDArrayFactory::create('c', {2}, {5, 3}); + auto axis = NDArrayFactory::create(-2); - auto z0 = NDArrayFactory::create('c', {5, 7}); - auto z1 = NDArrayFactory::create('c', {3, 7}); + auto z0 = NDArrayFactory::create('c', {5, 7}); + auto z1 = NDArrayFactory::create('c', {3, 7}); - sd::ops::split_v op; - auto status = op.execute({&x, &indices, &axis}, std::vector{&z0, &z1}, {}, {}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::split_v op; + auto status = op.execute({&x, &indices, &axis}, + std::vector{&z0, &z1}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, polygamma_test1) { + auto n = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); + // ASSERT_FALSE(true); + n.linspace(1.); + x.assign(0.5); - auto n= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); -// ASSERT_FALSE(true); - n.linspace(1.); - x.assign(0.5); - - auto expected= NDArrayFactory::create('c', {3,3}, {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, 1290440.250000, -20644900.000000, 3.71595e+08}); - - sd::ops::polygamma op; - auto result = op.evaluate({&n, &x}, {}, {}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, + 1290440.250000, -20644900.000000, 3.71595e+08}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::polygamma op; + auto result = op.evaluate({&n, &x}, {}, {}); - auto output = result.at(0); - // output->printBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + // output->printBuffer(); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, polygamma_test2) { + auto n = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto n= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); - - n.linspace(10.); - x.linspace(0.5); - - auto expected= NDArrayFactory::create('c', {3,3}, {-7.43182451e+09, 3.08334759e+05,-3.25669798e+03, 1.55186197e+02,-1.46220433e+01, 2.00905201e+00,-3.48791235e-01, 7.08016273e-02,-1.60476052e-02}); + n.linspace(10.); + x.linspace(0.5); - //ASSERT_FALSE(true); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {-7.43182451e+09, 3.08334759e+05, -3.25669798e+03, 1.55186197e+02, + -1.46220433e+01, 2.00905201e+00, -3.48791235e-01, 7.08016273e-02, + -1.60476052e-02}); - sd::ops::polygamma op; - auto result = op.evaluate({&n, &x}, {}, {}); + // ASSERT_FALSE(true); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::polygamma op; + auto result = op.evaluate({&n, &x}, {}, {}); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, polygamma_test3) { + auto n = NDArrayFactory::create('c', {3, 3}); + auto x = NDArrayFactory::create('c', {3, 3}); - auto n= NDArrayFactory::create('c', {3,3}); - auto x= NDArrayFactory::create('c', {3,3}); + n.linspace(1.); + x.linspace(10.); - n.linspace(1.); - x.linspace(10.); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, + {1.05166336e-01, -9.04983497e-03, 1.31009323e-03, -2.44459433e-04, + 5.31593880e-05, -1.28049888e-05, 3.31755364e-06, -9.07408791e-07, + 2.58758130e-07}); - auto expected= NDArrayFactory::create('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07}); + sd::ops::polygamma op; + auto result = op.evaluate({&n, &x}, {}, {}); - sd::ops::polygamma op; - auto result = op.evaluate({&n, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } TEST_F(DeclarableOpsTests3, polygamma_test4) { + NDArray n('c', {3, 4}, {/*0.7788*/ 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + sd::DataType::DOUBLE); + NDArray x('c', {3, 4}, + {0.7717, 0.9281, 0.9846, 0.4838, 0.6433, 0.6041, 0.6501, 0.7612, + 0.7605, 0.3948, 0.9493, 0.8600}, + sd::DataType::DOUBLE); - NDArray n('c', {3,4}, {/*0.7788*/0, 0,1,2,3,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); - NDArray x('c', {3,4}, {0.7717,0.9281,0.9846,0.4838,0.6433,0.6041,0.6501,0.7612,0.7605,0.3948,0.9493,0.8600}, sd::DataType::DOUBLE); - - NDArray expected('c', {3,4}, {/*std::numeric_limits::quiet_NaN()*/-1.031918, -7.021327e-01, 1.682743e+00, -1.851378e+01,3.604167e+01, -3.008293e+02, - 1.596005e+03, -4.876665e+03,4.510025e+04, -1.730340e+08, 6.110257e+05, -1.907087e+07}, sd::DataType::DOUBLE); + NDArray expected( + 'c', {3, 4}, + {/*std::numeric_limits::quiet_NaN()*/ -1.031918, -7.021327e-01, + 1.682743e+00, -1.851378e+01, 3.604167e+01, -3.008293e+02, 1.596005e+03, + -4.876665e+03, 4.510025e+04, -1.730340e+08, 6.110257e+05, -1.907087e+07}, + sd::DataType::DOUBLE); - sd::ops::polygamma op; - auto result = op.evaluate({&n, &x}, {}, {}); + sd::ops::polygamma op; + auto result = op.evaluate({&n, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } TEST_F(DeclarableOpsTests3, digamma_1) { + NDArray x('c', {18}, + {-25, -24.99999, -21.5, -21.2, -5.5, -4.1, -2.1, -0.5, -0.3, 0., + 0.2, 1, 1.5, 2.2, 5.2, 19., 21, 22.2}, + sd::DataType::DOUBLE); - NDArray x('c', {18}, {-25, -24.99999, -21.5, -21.2, -5.5, -4.1, -2.1, -0.5, -0.3, 0., 0.2, 1, 1.5, 2.2, 5.2, 19., 21, 22.2}, sd::DataType::DOUBLE); - - NDArray expected('c', {18}, {std::numeric_limits::infinity(), -99996.761229, 3.091129, 7.401432, 1.792911,11.196838,10.630354, 0.03649, 2.11331, - std::numeric_limits::infinity(),-5.28904,-0.577216, 0.03649, 0.544293, 1.549434,2.917892, 3.020524, 3.077401}, sd::DataType::DOUBLE); + NDArray expected( + 'c', {18}, + {std::numeric_limits::infinity(), -99996.761229, 3.091129, + 7.401432, 1.792911, 11.196838, 10.630354, 0.03649, 2.11331, + std::numeric_limits::infinity(), -5.28904, -0.577216, 0.03649, + 0.544293, 1.549434, 2.917892, 3.020524, 3.077401}, + sd::DataType::DOUBLE); - sd::ops::digamma op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::digamma op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test1) { - - auto x= NDArrayFactory::create('c', {6,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16}); - auto expS= NDArrayFactory::create('c', {6}, {54.12775, 38.79293, 25.89287, 9.82168, 6.07227, 2.91827}); - auto expU= NDArrayFactory::create('c', {6,6}, {0.14692,-0.11132,-0.69568, 0.59282,-0.14881, 0.32935,-0.38751, 0.60378,-0.04927,-0.01397,-0.69456,-0.01581, 0.19293,-0.12795,-0.18682,-0.69065,-0.20597, 0.62617, 0.66806, 0.4314 ,-0.33849,-0.22166, 0.04099,-0.44967, 0.11121,-0.64065,-0.02138,-0.07378,-0.60568,-0.45216,-0.5765 ,-0.1007 ,-0.60305,-0.34175, 0.29068,-0.3042}); - auto expV= NDArrayFactory::create('c', {6,6}, {-0.24577,-0.24512, 0.00401,-0.04585,-0.62058, 0.70162, 0.27937, 0.75961, 0.43885,-0.06857,-0.3839 , 0.01669,-0.35944,-0.09629, 0.44593, 0.78602,-0.09103,-0.19125, 0.53973, 0.07613,-0.10721, 0.49559, 0.35687, 0.56431,-0.6226 , 0.39742, 0.12785,-0.15716, 0.52372, 0.37297, 0.23113,-0.43578, 0.76204,-0.32414, 0.23996, 0.11543}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {1, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - auto u = result.at(1); - auto v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance()->isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5); - } - + auto x = NDArrayFactory::create( + 'c', {6, 6}, {0., -9., -6, 9, -10, -12, 2, 13, 5, -11, 20, -17, + 1, -2, -11, 3, -8, 3, -14, 19, -20, 20, -17, -5, + 6, -16, 0, -1, -16, 11, 7, -19, 2, -17, 17, -16}); + auto expS = NDArrayFactory::create( + 'c', {6}, {54.12775, 38.79293, 25.89287, 9.82168, 6.07227, 2.91827}); + auto expU = NDArrayFactory::create( + 'c', {6, 6}, {0.14692, -0.11132, -0.69568, 0.59282, -0.14881, 0.32935, + -0.38751, 0.60378, -0.04927, -0.01397, -0.69456, -0.01581, + 0.19293, -0.12795, -0.18682, -0.69065, -0.20597, 0.62617, + 0.66806, 0.4314, -0.33849, -0.22166, 0.04099, -0.44967, + 0.11121, -0.64065, -0.02138, -0.07378, -0.60568, -0.45216, + -0.5765, -0.1007, -0.60305, -0.34175, 0.29068, -0.3042}); + auto expV = NDArrayFactory::create( + 'c', {6, 6}, {-0.24577, -0.24512, 0.00401, -0.04585, -0.62058, 0.70162, + 0.27937, 0.75961, 0.43885, -0.06857, -0.3839, 0.01669, + -0.35944, -0.09629, 0.44593, 0.78602, -0.09103, -0.19125, + 0.53973, 0.07613, -0.10721, 0.49559, 0.35687, 0.56431, + -0.6226, 0.39742, 0.12785, -0.15716, 0.52372, 0.37297, + 0.23113, -0.43578, 0.76204, -0.32414, 0.23996, 0.11543}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance()->isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test2) { - - auto x = NDArrayFactory::create('c', {7,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); - auto expS= NDArrayFactory::create('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672}); - auto expU= NDArrayFactory::create('c', {7,7}, {-0.13417,-0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.41683, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 , -0.12183,-0.17329,-0.14666, -0.19639, -0.55355, 0.0614 , 0.75729, 0.1619 ,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656, -0.26134,-0.08027,-0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, -0.44712, 0.55906,-0.06235, -0.58017, -0.12911, -0.359 , -0.00393, -0.44877, 0.30645,-0.11953, -0.09083, -0.54163, 0.14283, -0.50417, 0.56178}); - auto expV= NDArrayFactory::create('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {1, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - auto u = result.at(1); - auto v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance()->isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5); - } - + auto x = NDArrayFactory::create( + 'c', {7, 6}, + {0., -9., -6, 9, -10, -12, 2, 13, 5, -11, 20, -17, 1, -2, + -11, 3, -8, 3, -14, 19, -20, 20, -17, -5, 6, -16, 0, -1, + -16, 11, 7, -19, 2, -17, 17, -16, 4, -9, 1, -15, 7, -2}); + auto expS = NDArrayFactory::create( + 'c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672}); + auto expU = NDArrayFactory::create( + 'c', {7, 7}, + {-0.13417, -0.12443, -0.68854, 0.5196, 0.21706, 0.03974, 0.41683, + 0.347, 0.62666, -0.04964, -0.01912, 0.66932, 0.1457, -0.12183, + -0.17329, -0.14666, -0.19639, -0.55355, 0.0614, 0.75729, 0.1619, + -0.64703, 0.37056, -0.37398, -0.32922, -0.0186, -0.35656, -0.26134, + -0.08027, -0.64405, -0.0127, -0.06934, 0.59287, -0.14956, -0.44712, + 0.55906, -0.06235, -0.58017, -0.12911, -0.359, -0.00393, -0.44877, + 0.30645, -0.11953, -0.09083, -0.54163, 0.14283, -0.50417, 0.56178}); + auto expV = NDArrayFactory::create( + 'c', {6, 6}, {0.2508, -0.2265, 0.01689, 0.04486, 0.53132, 0.77537, + -0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, + 0.33139, -0.05528, 0.47186, 0.73171, 0.18905, -0.3055, + -0.57263, 0.06276, -0.09542, 0.59396, -0.36152, 0.419, + 0.59193, 0.4361, 0.13557, -0.03632, -0.5755, 0.32944, + -0.21165, -0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance()->isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test3) { - - auto x= NDArrayFactory::create('c', {7,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); - auto expS= NDArrayFactory::create('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672}); - auto expU= NDArrayFactory::create('c', {7,6}, {-0.13417, -0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 ,-0.17329, -0.14666, -0.19639, -0.55355, 0.0614 , 0.75729,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656,-0.08027, -0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, 0.55906, -0.06235, -0.58017, -0.12911, -0.359 , -0.00393, 0.30645, -0.11953, -0.09083, -0.54163, 0.14283, -0.50417}); - auto expV= NDArrayFactory::create('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {0, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - auto u = result.at(1); - auto v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance()->isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5f); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5f); - } - + auto x = NDArrayFactory::create( + 'c', {7, 6}, + {0., -9., -6, 9, -10, -12, 2, 13, 5, -11, 20, -17, 1, -2, + -11, 3, -8, 3, -14, 19, -20, 20, -17, -5, 6, -16, 0, -1, + -16, 11, 7, -19, 2, -17, 17, -16, 4, -9, 1, -15, 7, -2}); + auto expS = NDArrayFactory::create( + 'c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672}); + auto expU = NDArrayFactory::create( + 'c', {7, 6}, + {-0.13417, -0.12443, -0.68854, 0.5196, 0.21706, 0.03974, 0.347, + 0.62666, -0.04964, -0.01912, 0.66932, 0.1457, -0.17329, -0.14666, + -0.19639, -0.55355, 0.0614, 0.75729, -0.64703, 0.37056, -0.37398, + -0.32922, -0.0186, -0.35656, -0.08027, -0.64405, -0.0127, -0.06934, + 0.59287, -0.14956, 0.55906, -0.06235, -0.58017, -0.12911, -0.359, + -0.00393, 0.30645, -0.11953, -0.09083, -0.54163, 0.14283, -0.50417}); + auto expV = NDArrayFactory::create( + 'c', {6, 6}, {0.2508, -0.2265, 0.01689, 0.04486, 0.53132, 0.77537, + -0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, + 0.33139, -0.05528, 0.47186, 0.73171, 0.18905, -0.3055, + -0.57263, 0.06276, -0.09542, 0.59396, -0.36152, 0.419, + 0.59193, 0.4361, 0.13557, -0.03632, -0.5755, 0.32944, + -0.21165, -0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {0, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance()->isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5f); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5f); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test4) { - - auto x= NDArrayFactory::create('c', {6,7}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); - auto expS= NDArrayFactory::create('c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217}); - auto expU= NDArrayFactory::create('c', {6,6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102,-0.49879, 0.12076, 0.37629, -0.7211 , -0.24585, 0.12086,-0.36569,-0.70218, -0.08012, 0.21274, -0.07314, 0.56231,-0.44508, 0.4329 , 0.1356 , 0.60909, -0.47398, -0.02164, 0.61238,-0.05674, 0.59489, 0.06588, -0.3874 , 0.33685,-0.13044,-0.50644, 0.46552, 0.13236, -0.00474, -0.70161}); - auto expV= NDArrayFactory::create('c', {7,7}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, -0.16709, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, -0.06862, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979, 0.84807,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.36692, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651, -0.27155,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.15069, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151 , 0.13065}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {1, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - auto u = result.at(1); - auto v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance()->isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5f); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5f); - } - + auto x = NDArrayFactory::create( + 'c', {6, 7}, + {0., -9., -6, 9, -10, -12, 2, 13, 5, -11, 20, -17, 1, -2, + -11, 3, -8, 3, -14, 19, -20, 20, -17, -5, 6, -16, 0, -1, + -16, 11, 7, -19, 2, -17, 17, -16, 4, -9, 1, -15, 7, -2}); + auto expS = NDArrayFactory::create( + 'c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217}); + auto expU = NDArrayFactory::create( + 'c', {6, 6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102, + -0.49879, 0.12076, 0.37629, -0.7211, -0.24585, 0.12086, + -0.36569, -0.70218, -0.08012, 0.21274, -0.07314, 0.56231, + -0.44508, 0.4329, 0.1356, 0.60909, -0.47398, -0.02164, + 0.61238, -0.05674, 0.59489, 0.06588, -0.3874, 0.33685, + -0.13044, -0.50644, 0.46552, 0.13236, -0.00474, -0.70161}); + auto expV = NDArrayFactory::create( + 'c', {7, 7}, + {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, -0.16709, + 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, -0.06862, + 0.32179, 0.12812, -0.25812, 0.0691, -0.12891, 0.26979, 0.84807, + -0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.36692, + 0.48118, 0.15876, -0.65132, -0.24602, 0.3963, -0.16651, -0.27155, + -0.31605, -0.46947, -0.50195, 0.0378, -0.34937, -0.53062, 0.15069, + 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151, 0.13065}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance()->isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5f); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5f); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test5) { - - auto x= NDArrayFactory::create('c', {6,7}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); - auto expS= NDArrayFactory::create('c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217}); - auto expU= NDArrayFactory::create('c', {6,6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102,-0.49879, 0.12076, 0.37629, -0.7211 , -0.24585, 0.12086,-0.36569,-0.70218, -0.08012, 0.21274, -0.07314, 0.56231,-0.44508, 0.4329 , 0.1356 , 0.60909, -0.47398, -0.02164, 0.61238,-0.05674, 0.59489, 0.06588, -0.3874 , 0.33685,-0.13044,-0.50644, 0.46552, 0.13236, -0.00474, -0.70161}); - auto expV= NDArrayFactory::create('c', {7,6}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {0, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - auto u = result.at(1); - auto v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance()->isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5f); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5f); - } - + auto x = NDArrayFactory::create( + 'c', {6, 7}, + {0., -9., -6, 9, -10, -12, 2, 13, 5, -11, 20, -17, 1, -2, + -11, 3, -8, 3, -14, 19, -20, 20, -17, -5, 6, -16, 0, -1, + -16, 11, 7, -19, 2, -17, 17, -16, 4, -9, 1, -15, 7, -2}); + auto expS = NDArrayFactory::create( + 'c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217}); + auto expU = NDArrayFactory::create( + 'c', {6, 6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102, + -0.49879, 0.12076, 0.37629, -0.7211, -0.24585, 0.12086, + -0.36569, -0.70218, -0.08012, 0.21274, -0.07314, 0.56231, + -0.44508, 0.4329, 0.1356, 0.60909, -0.47398, -0.02164, + 0.61238, -0.05674, 0.59489, 0.06588, -0.3874, 0.33685, + -0.13044, -0.50644, 0.46552, 0.13236, -0.00474, -0.70161}); + auto expV = NDArrayFactory::create( + 'c', {7, 6}, + {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, 0.21989, + -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, 0.32179, 0.12812, + -0.25812, 0.0691, -0.12891, 0.26979, -0.50833, 0.13793, 0.06658, + -0.53001, 0.52572, -0.16194, 0.48118, 0.15876, -0.65132, -0.24602, + 0.3963, -0.16651, -0.31605, -0.46947, -0.50195, 0.0378, -0.34937, + -0.53062, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {0, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance()->isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5f); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5f); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test6) { - - auto x= NDArrayFactory::create('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2 - ,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17 - ,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14 - ,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16 ,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5}); - auto expS= NDArrayFactory::create('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031, - 38.18412, 31.52287, 23.52755, 11.79484, 1.90195, - 39.34498, 32.54861, 17.52492, 7.03003, 2.2399, - 44.72126, 32.3164 , 16.60139, 6.88783, 0.78122}); - auto expU= NDArrayFactory::create('c', {2,2,5,5}, {0.25441, 0.16908, -0.68564, 0.58844, -0.30054, - -0.32285, -0.58332, 0.3451 , 0.4746 , -0.45953,0.58332, 0.10605, 0.51533, 0.50234, 0.36136,0.12588, -0.73123, -0.37812, -0.00215, 0.55361, - 0.68915, -0.2919 , 0.04767, -0.4197 , -0.51132,0.44464, -0.25326, -0.42493, -0.01712, -0.74653,0.516 , -0.16688, 0.1854 , -0.77155, 0.27611, - -0.19321, -0.14317, -0.85886, -0.15224, 0.42585,-0.60155, -0.68323, 0.18819, -0.29053, -0.22696,-0.36993, 0.64862, -0.10956, -0.54483, -0.36552, - -0.57697, -0.32277, 0.11229, 0.55495, 0.4923 ,-0.02937, 0.01689, -0.63257, 0.57075, -0.52245,-0.56002, -0.2036 , -0.53119, -0.6022 , 0.01017, - -0.33605, -0.35257, 0.53215, -0.04936, -0.69075,0.48958, -0.85427, -0.14796, -0.03449, 0.08633,0.15008, 0.60996, 0.31071, -0.67721, 0.22421, - 0.67717, -0.59857, 0.04372, -0.2565 , 0.33979,0.68116, 0.49852, -0.13441, 0.51374, -0.07421,-0.20066, 0.04504, 0.42865, 0.44418, 0.75939,0.12113, -0.13826, 0.83651, 0.11988, -0.50209}); - auto expV= NDArrayFactory::create('c', {2,2,5,5}, {0.01858, 0.17863, 0.51259, 0.14048, 0.82781, - 0.59651, -0.13439, -0.395 , 0.66979, 0.14654,0.73731, 0.47061, 0.19357, -0.41127, -0.16817,0.1047 , -0.29727, 0.73711, 0.38235, -0.45951, - -0.29873, 0.80012, -0.02078, 0.4651 , -0.23201,-0.05314, -0.0419 , -0.52146, 0.77792, 0.344 ,-0.66438, 0.05648, 0.03756, -0.31531, 0.67422, - 0.74471, 0.01504, -0.03081, -0.24335, 0.62049,0.03172, 0.91947, 0.30828, 0.23713, 0.04796,-0.01311, 0.38652, -0.79415, -0.42423, -0.19945, - -0.13783, -0.54667, -0.58527, 0.49955, 0.3001 ,0.85214, 0.01628, 0.02688, -0.02891, 0.52157,0.16608, -0.20181, 0.61371, 0.69894, -0.25794, - 0.45726, -0.33952, -0.32659, -0.18938, -0.73015,0.13486, 0.73816, -0.41646, 0.47458, -0.1956 ,0.5536 , -0.137 , 0.64688, 0.50536, 0.03017, - -0.51827, -0.31837, -0.16732, 0.71378, -0.30425,-0.39314, 0.15266, 0.63693, -0.30945, -0.5663 ,-0.51981, 0.03325, 0.37603, 0.05147, 0.76462,-0.01282, 0.92491, -0.08042, 0.36977, -0.03428}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {1, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - auto u = result.at(1); - auto v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance()->isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5f); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5f); - } - + auto x = NDArrayFactory::create( + 'c', {2, 2, 5, 5}, + {-7., 17, 4, -10, 5, 1, -5, -19, 13, -8, 9, 13, 19, 13, -2, + -8, 10, -9, 0, -20, -2, 14, 19, 5, -18, 4, -13, 12, -10, 5, + -10, -10, 17, -5, -2, 10, 5, -4, -11, 15, -3, 15, -17, -20, -10, + -4, 12, -9, 16, 13, 10, -19, 2, -9, -10, 8, -2, -4, 3, 7, + 10, -19, -11, -4, -6, 2, -12, 6, -4, -14, 14, 16, 7, 19, -17, + 2, -14, 5, -1, 16, 19, -11, -14, -16, -19, 15, -18, -12, -16, 16, + 1, 5, 7, 8, 2, 13, -3, 6, 2, -5}); + auto expS = NDArrayFactory::create( + 'c', {2, 2, 5}, + {40.95395, 31.46869, 24.79993, 12.33768, 1.80031, 38.18412, 31.52287, + 23.52755, 11.79484, 1.90195, 39.34498, 32.54861, 17.52492, 7.03003, + 2.2399, 44.72126, 32.3164, 16.60139, 6.88783, 0.78122}); + auto expU = NDArrayFactory::create( + 'c', {2, 2, 5, 5}, + {0.25441, 0.16908, -0.68564, 0.58844, -0.30054, -0.32285, -0.58332, + 0.3451, 0.4746, -0.45953, 0.58332, 0.10605, 0.51533, 0.50234, + 0.36136, 0.12588, -0.73123, -0.37812, -0.00215, 0.55361, 0.68915, + -0.2919, 0.04767, -0.4197, -0.51132, 0.44464, -0.25326, -0.42493, + -0.01712, -0.74653, 0.516, -0.16688, 0.1854, -0.77155, 0.27611, + -0.19321, -0.14317, -0.85886, -0.15224, 0.42585, -0.60155, -0.68323, + 0.18819, -0.29053, -0.22696, -0.36993, 0.64862, -0.10956, -0.54483, + -0.36552, -0.57697, -0.32277, 0.11229, 0.55495, 0.4923, -0.02937, + 0.01689, -0.63257, 0.57075, -0.52245, -0.56002, -0.2036, -0.53119, + -0.6022, 0.01017, -0.33605, -0.35257, 0.53215, -0.04936, -0.69075, + 0.48958, -0.85427, -0.14796, -0.03449, 0.08633, 0.15008, 0.60996, + 0.31071, -0.67721, 0.22421, 0.67717, -0.59857, 0.04372, -0.2565, + 0.33979, 0.68116, 0.49852, -0.13441, 0.51374, -0.07421, -0.20066, + 0.04504, 0.42865, 0.44418, 0.75939, 0.12113, -0.13826, 0.83651, + 0.11988, -0.50209}); + auto expV = NDArrayFactory::create( + 'c', {2, 2, 5, 5}, + {0.01858, 0.17863, 0.51259, 0.14048, 0.82781, 0.59651, -0.13439, + -0.395, 0.66979, 0.14654, 0.73731, 0.47061, 0.19357, -0.41127, + -0.16817, 0.1047, -0.29727, 0.73711, 0.38235, -0.45951, -0.29873, + 0.80012, -0.02078, 0.4651, -0.23201, -0.05314, -0.0419, -0.52146, + 0.77792, 0.344, -0.66438, 0.05648, 0.03756, -0.31531, 0.67422, + 0.74471, 0.01504, -0.03081, -0.24335, 0.62049, 0.03172, 0.91947, + 0.30828, 0.23713, 0.04796, -0.01311, 0.38652, -0.79415, -0.42423, + -0.19945, -0.13783, -0.54667, -0.58527, 0.49955, 0.3001, 0.85214, + 0.01628, 0.02688, -0.02891, 0.52157, 0.16608, -0.20181, 0.61371, + 0.69894, -0.25794, 0.45726, -0.33952, -0.32659, -0.18938, -0.73015, + 0.13486, 0.73816, -0.41646, 0.47458, -0.1956, 0.5536, -0.137, + 0.64688, 0.50536, 0.03017, -0.51827, -0.31837, -0.16732, 0.71378, + -0.30425, -0.39314, 0.15266, 0.63693, -0.30945, -0.5663, -0.51981, + 0.03325, 0.37603, 0.05147, 0.76462, -0.01282, 0.92491, -0.08042, + 0.36977, -0.03428}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance()->isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5f); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5f); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test7) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 5, 5}, + {-7., 17, 4, -10, 5, 1, -5, -19, 13, -8, 9, 13, 19, 13, -2, + -8, 10, -9, 0, -20, -2, 14, 19, 5, -18, 4, -13, 12, -10, 5, + -10, -10, 17, -5, -2, 10, 5, -4, -11, 15, -3, 15, -17, -20, -10, + -4, 12, -9, 16, 13, 10, -19, 2, -9, -10, 8, -2, -4, 3, 7, + 10, -19, -11, -4, -6, 2, -12, 6, -4, -14, 14, 16, 7, 19, -17, + 2, -14, 5, -1, 16, 19, -11, -14, -16, -19, 15, -18, -12, -16, 16, + 1, 5, 7, 8, 2, 13, -3, 6, 2, -5}); + auto expS = NDArrayFactory::create( + 'c', {2, 2, 5}, + {40.95395, 31.46869, 24.79993, 12.33768, 1.80031, 38.18412, 31.52287, + 23.52755, 11.79484, 1.90195, 39.34498, 32.54861, 17.52492, 7.03003, + 2.2399, 44.72126, 32.3164, 16.60139, 6.88783, 0.78122}); - auto x= NDArrayFactory::create('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 - ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 - ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5}); - auto expS= NDArrayFactory::create('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031,38.18412, 31.52287, 23.52755, 11.79484, 1.90195, - 39.34498, 32.54861, 17.52492, 7.03003, 2.2399,44.72126, 32.3164 , 16.60139, 6.88783, 0.78122}); - - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {0, 0, 16}); + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {0, 0, 16}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto s = result.at(0); + auto s = result.at(0); - ASSERT_TRUE(expS.equalsTo(s)); - ASSERT_TRUE(expS.isSameShape(s)); - + ASSERT_TRUE(expS.equalsTo(s)); + ASSERT_TRUE(expS.isSameShape(s)); } /////////////////////////////////////////////////////////////////// // TEST_F(DeclarableOpsTests3, svd_test8) { -// auto x= NDArrayFactory::create('c', {2,2,11,10}, {3 ,-8 ,0 ,3 ,-5 ,16 ,-3 ,7 ,-4 ,19 ,19 ,13 ,15 ,15 ,9 ,6 ,-7 ,-5 ,-9 ,-12 ,7 ,-1 ,-1 ,6 ,19 -// ,-6 ,16 ,0 ,16 ,16 ,7 ,14 ,18. ,0 ,18 ,-4 ,10 ,-16 ,-17 ,15 ,13 ,-17 ,-14 ,-17 ,-5 ,-9 ,-1 ,-19 -// ,-18 ,5 ,-5 ,-13 ,17 ,-19 ,-5 ,18 ,4 ,10 ,17 ,-7 ,-10 ,16 ,10 ,8 ,-10 ,-3 ,10 ,1 ,-4 ,-16 ,-1 -// ,-1 ,5 ,5 ,17 ,14 ,20 ,15 ,-6 ,19 ,14 ,17 ,0 ,-17 ,-16 ,-8 ,-6 ,3 ,-6 ,-11 ,-4 ,-2 ,-7 ,4 ,-6 -// ,-6 ,-17 ,16 ,-8 ,-20 ,2 ,7 ,-12 ,15 ,-15 ,-19 ,14 ,17 ,9 ,10 ,5 ,18 ,2 ,-6 ,0 ,2 ,-10 ,7 ,8 -// ,-13 ,2 ,8 ,20 ,11 ,-15 ,13 ,-10 ,-14 ,-2 ,20 ,5 ,2 ,16 ,18 ,-3 ,3 ,-18 ,15 ,-11 ,17 ,-8 ,-18 -// ,20 ,-12 ,20 ,20 ,-16 ,20 ,-8. ,19 ,-8 ,3 ,-3 ,17 ,7 ,13 ,9 ,-2 ,11 ,16 ,4 ,-18 ,5 ,0 ,-12 ,9 -// ,-6 ,6 ,0 ,-9 ,-13 ,13 ,17 ,-12 ,3 ,-13 ,17 ,-19 ,17 ,0 ,-8 ,4 ,-19 ,-9 ,-7 ,12 ,-1 ,-12 ,-1 -// ,7 ,2 ,19 ,10 ,19 ,-15 ,-18 ,17 ,-1 ,1 ,14 ,-7 ,-10 ,12 ,-20 ,6 ,-5 ,14 ,5 ,5 ,3 ,-18 ,5 ,17 -// ,-13 ,20 ,-1 ,-2 ,-11 ,-5 ,14 ,8 ,7 ,-13 ,-9 ,-12 ,11 ,3 ,14 ,-6 ,-2 ,13 ,8 ,-15 ,-5 ,-6 ,-7 ,19 -// ,-1 ,6 ,1 ,14 ,8 ,18 ,-20 ,-14 ,-3 ,-5 ,19 ,15 ,13 ,2 ,-20 ,2 ,14 ,13 ,4 ,-15 ,1 ,-14 -// ,0 ,9 ,-1 ,10 ,4 ,6 ,4 ,-7 ,-2 ,-1 ,-15 ,-1 ,-16 ,-5 ,-12 ,-10 ,16 ,-16 ,-15 ,-17 ,-5 ,-6 -// ,18 ,14 ,-3 ,-10 ,8 ,20 ,19 ,20 ,-3 ,-6 ,9 ,10 ,-1 ,-20 ,-5 ,5 ,12 ,8 ,17 ,13 ,-18 ,-14 ,0 -// ,4 ,-11 ,3 ,-12 ,-2 ,-5 ,19 ,-15 ,19 ,16 ,-16 ,13 ,-6 ,11 ,11 ,0 ,-18 ,4 ,5 ,6 ,-12 ,-10 -// ,-3 ,2 ,-18 ,16 ,-5 ,17 ,16 ,-16 ,-20 ,14 ,6 ,10 ,-5 ,-3 ,4 ,20 ,18 ,5 ,1 ,-10 ,15 ,10 ,16 -// ,-18 ,2 ,12 ,20 ,6 ,14 ,8 ,3 ,-2 ,9 ,15 ,-4 ,13 ,-19 ,-5 ,3 ,3 ,-20 ,-4 ,18 ,-11 ,11 ,-10 -// ,3 ,8 ,9 ,20 ,-19 ,6 ,18 ,9 ,20 ,-12 ,4 ,15 ,19 ,3 ,5 ,1 ,2 ,20 ,-3 ,-1 ,-8 ,-3 ,8 ,17 , -// -14 ,18 ,-10 ,4 ,13 ,-5 ,13 ,-6 ,12 ,-10 ,19 ,4 ,-7 ,-17 ,20 ,8 ,6 ,-3 ,3 ,-7 ,-18 ,17 , -// -13 ,18 ,-20 ,-16 ,-5 ,12 ,5 ,17 ,-4 ,4 ,7 ,8 ,17 ,-9 ,-12 ,-10 ,8 ,-14 ,-11 ,7 ,19 ,-17}); - -// auto expS= NDArrayFactory::create('c', {2,2,10}, { 64.12636, 54.37044, 50.63744, 48.10308, 33.7364 , 29.96456, +// auto x= NDArrayFactory::create('c', {2,2,11,10}, {3 ,-8 ,0 ,3 ,-5 +// ,16 ,-3 ,7 ,-4 ,19 ,19 ,13 ,15 ,15 ,9 ,6 ,-7 ,-5 ,-9 ,-12 ,7 ,-1 ,-1 ,6 +// ,19 +// ,-6 ,16 ,0 ,16 ,16 ,7 ,14 ,18. ,0 ,18 ,-4 ,10 ,-16 ,-17 ,15 ,13 +// ,-17 ,-14 ,-17 ,-5 ,-9 ,-1 ,-19 +// ,-18 ,5 ,-5 ,-13 ,17 ,-19 ,-5 ,18 ,4 ,10 ,17 ,-7 ,-10 ,16 ,10 ,8 +// ,-10 ,-3 ,10 ,1 ,-4 ,-16 ,-1 +// ,-1 ,5 ,5 ,17 ,14 ,20 ,15 ,-6 ,19 ,14 ,17 ,0 ,-17 ,-16 ,-8 ,-6 ,3 +// ,-6 ,-11 ,-4 ,-2 ,-7 ,4 ,-6 +// ,-6 ,-17 ,16 ,-8 ,-20 ,2 ,7 ,-12 ,15 ,-15 ,-19 ,14 ,17 ,9 ,10 ,5 +// ,18 ,2 ,-6 ,0 ,2 ,-10 ,7 ,8 +// ,-13 ,2 ,8 ,20 ,11 ,-15 ,13 ,-10 ,-14 ,-2 ,20 ,5 ,2 ,16 ,18 ,-3 ,3 +// ,-18 ,15 ,-11 ,17 ,-8 ,-18 ,20 ,-12 ,20 ,20 ,-16 ,20 ,-8. ,19 ,-8 +// ,3 ,-3 ,17 ,7 ,13 ,9 ,-2 ,11 ,16 ,4 ,-18 ,5 ,0 ,-12 ,9 +// ,-6 ,6 ,0 ,-9 ,-13 ,13 ,17 ,-12 ,3 ,-13 ,17 ,-19 ,17 ,0 ,-8 ,4 +// ,-19 ,-9 ,-7 ,12 ,-1 ,-12 ,-1 ,7 ,2 ,19 ,10 ,19 ,-15 ,-18 ,17 ,-1 +// ,1 ,14 ,-7 ,-10 ,12 ,-20 ,6 ,-5 ,14 ,5 ,5 ,3 ,-18 ,5 ,17 +// ,-13 ,20 ,-1 ,-2 ,-11 ,-5 ,14 ,8 ,7 ,-13 ,-9 ,-12 ,11 ,3 ,14 ,-6 +// ,-2 ,13 ,8 ,-15 ,-5 ,-6 ,-7 ,19 +// ,-1 ,6 ,1 ,14 ,8 ,18 ,-20 ,-14 ,-3 ,-5 ,19 ,15 ,13 ,2 ,-20 ,2 ,14 +// ,13 ,4 ,-15 ,1 ,-14 ,0 ,9 ,-1 ,10 ,4 ,6 ,4 ,-7 ,-2 ,-1 ,-15 ,-1 +// ,-16 ,-5 ,-12 ,-10 ,16 ,-16 ,-15 ,-17 ,-5 ,-6 ,18 ,14 ,-3 ,-10 ,8 +// ,20 ,19 ,20 ,-3 ,-6 ,9 ,10 ,-1 ,-20 ,-5 ,5 ,12 ,8 ,17 ,13 ,-18 +// ,-14 ,0 ,4 ,-11 ,3 ,-12 ,-2 ,-5 ,19 ,-15 ,19 ,16 ,-16 ,13 ,-6 ,11 +// ,11 ,0 ,-18 ,4 ,5 ,6 ,-12 ,-10 +// ,-3 ,2 ,-18 ,16 ,-5 ,17 ,16 ,-16 ,-20 ,14 ,6 ,10 ,-5 ,-3 ,4 ,20 +// ,18 ,5 ,1 ,-10 ,15 ,10 ,16 +// ,-18 ,2 ,12 ,20 ,6 ,14 ,8 ,3 ,-2 ,9 ,15 ,-4 ,13 ,-19 ,-5 ,3 ,3 +// ,-20 ,-4 ,18 ,-11 ,11 ,-10 ,3 ,8 ,9 ,20 ,-19 ,6 ,18 ,9 ,20 ,-12 ,4 +// ,15 ,19 ,3 ,5 ,1 ,2 ,20 ,-3 ,-1 ,-8 ,-3 ,8 ,17 , +// -14 ,18 ,-10 ,4 ,13 ,-5 ,13 ,-6 ,12 +// ,-10 ,19 ,4 ,-7 ,-17 ,20 ,8 ,6 ,-3 ,3 +// ,-7 ,-18 ,17 , -13 ,18 ,-20 ,-16 ,-5 +// ,12 ,5 ,17 ,-4 ,4 ,7 ,8 ,17 ,-9 ,-12 +// ,-10 ,8 ,-14 ,-11 ,7 ,19 ,-17}); + +// auto expS= NDArrayFactory::create('c', {2,2,10}, +// { 64.12636, 54.37044, 50.63744, 48.10308, 33.7364 , 29.96456, // 25.53945, 19.31856, 15.30939, 9.31349, -// 67.41342, 59.64963, 58.72687, 39.22496, 32.39772, 29.30833, -// 23.1491 , 16.92442, 6.38613, 3.49563, -// 74.37477, 52.07016, 46.10758, 39.10742, 32.02261, 27.05888, -// 20.54921, 13.17989, 8.4158 , 4.39974, -// 65.47447, 56.31305, 54.13371, 46.26955, 43.47755, 30.25799, +// 67.41342, 59.64963, 58.72687, 39.22496, +// 32.39772, 29.30833, 23.1491 +// , 16.92442, 6.38613, 3.49563, +// 74.37477, 52.07016, 46.10758, 39.10742, +// 32.02261, 27.05888, +// 20.54921, 13.17989, 8.4158 +// , 4.39974, +// 65.47447, 56.31305, 54.13371, 46.26955, +// 43.47755, 30.25799, // 20.71463, 16.89671, 10.39572, 7.81631}); -// auto expU= NDArrayFactory::create('c', {2,2,11,11}, {-0.177870, -0.149461, -0.196911, 0.036990, -0.338237, 0.548901, -// -0.074396, 0.497067, -0.083636, -0.111810, -0.466989, -0.010465, 0.434732, 0.337198, 0.305239, -0.292813, -// 0.041280, -0.517144, 0.121499, 0.464908, 0.003658, 0.135017, -0.446916, -0.098318, 0.073571, -0.200521, -// 0.186776, -0.353022, -0.435582, -0.225959, 0.052972, 0.032390, -0.583801, -0.402790, 0.562809, 0.102744, -// 0.066555, 0.206079, 0.115322, 0.217220, -0.062591, -0.273173, -0.569645, 0.005612, 0.092601, 0.350055, -// -0.608007, -0.367743, 0.064860, 0.112656, 0.091576, -0.144262, 0.554655, -0.042100, -0.092023, 0.026986, -// -0.395811, -0.245209, 0.572522, 0.429430, 0.099621, -0.159236, -0.086263, 0.268160, -0.391298, 0.050417, -// 0.150175, 0.045253, 0.464173, 0.138376, 0.265551, 0.049691, 0.528778, 0.116951, 0.384609, 0.144416, -// -0.453591, -0.519390, -0.150671, 0.072897, 0.102406, -0.154184, 0.450735, 0.174171, -0.519405, 0.147109, -// 0.333670, 0.178053, 0.360763, 0.226976, 0.069976, -0.046765, 0.448897, 0.511309, -0.361050, -0.191690, -// -0.304442, 0.270383, -0.124133, 0.417183, -0.083359, 0.137022, 0.004276, -0.462336, 0.051267, 0.020622, -// -0.566932, -0.051351, -0.417106, -0.292202, -0.021595, -0.315956, 0.396626, -0.604952, 0.155990, 0.258395, -// -0.125080, 0.115404, 0.234517, -0.357460, 0.271271, 0.063771, -0.087400, -0.024710, -0.179892, 0.584339, -// -0.413085, 0.510580, 0.334646, 0.044424, 0.224735, 0.134434, -0.147861, 0.291853, 0.487948, 0.238917, -// 0.433893, 0.435884, 0.056370, -0.051216, -0.450902, 0.062411, 0.080733, -0.365211, 0.031931, 0.493926, -// -0.239428, 0.038247, -0.180721, -0.118035, 0.042175, 0.377296, -0.516399, 0.324744, -0.756196, 0.160856, -// -0.152527, -0.046867, -0.092933, -0.044945, 0.137659, 0.246552, -0.071709, 0.032821, -0.529356, -0.029669, -// 0.200178, 0.188916, 0.428036, -0.496734, -0.164185, 0.629070, -0.131588, 0.073992, 0.066877, 0.208450, -// -0.156170, -0.253670, -0.000365, -0.121172, 0.067774, 0.618226, 0.230460, -0.118865, 0.579424, 0.324523, -// 0.038653, 0.310308, 0.570186, -0.217271, -0.110967, 0.196375, 0.167058, 0.264071, -0.130023, 0.254189, -// -0.459057, -0.301033, 0.069932, -0.033338, -0.070600, 0.685064, 0.130274, 0.074929, -0.206899, 0.574057, -// 0.327277, -0.131588, -0.018497, 0.312445, 0.314594, 0.480422, -0.293858, -0.273277, -0.006598, -0.134574, -// 0.403501, 0.140025, 0.380693, -0.257039, -0.067012, 0.248776, -0.361838, -0.270296, -0.225844, 0.320245, -// 0.055730, 0.454809, -0.212163, -0.063281, 0.563112, -0.200737, 0.537389, -0.210845, 0.109997, 0.166215, -// -0.243725, -0.347349, -0.274348, 0.263950, 0.437134, 0.265820, -0.127520, -0.033325, -0.137156, 0.518557, -// 0.246720, 0.389394, -0.600568, 0.062027, -0.047838, -0.338416, 0.032778, -0.141998, -0.338022, -0.381467, -// 0.210512, -0.314413, 0.256321, 0.001460, 0.238901, 0.139840, 0.633423, -0.182575, -0.461504, 0.290250, -// -0.025930, 0.336998, -0.211280, -0.662387, -0.207946, -0.003860, -0.147842, 0.157217, 0.123704, 0.345686, -// 0.337946, 0.138261, -0.178814, -0.109597, 0.087135, -0.509500, -0.300296, -0.262279, 0.377476, -0.366815, -// 0.091787, 0.247495, -0.193812, -0.179714, 0.238552, -0.162305, -0.029549, 0.785426, -0.157586, -0.084533, -// -0.357024, 0.317878, 0.217656, 0.125319, 0.648832, 0.344045, -0.001109, 0.457190, -0.072439, -0.106278, -// 0.228962, -0.136139, -0.528342, -0.020840, -0.108908, -0.231661, 0.396864, 0.234925, 0.180894, -0.179430, -// -0.587730, 0.178276, -0.008672, -0.386172, 0.033155, 0.319568, 0.101457, -0.272011, 0.126007, 0.175374, -// -0.081668, 0.112987, -0.296422, -0.713743, 0.269413, -0.082098, -0.338649, 0.131035, -0.518616, 0.022478, -// 0.177802, -0.042432, -0.606219, -0.343848, 0.014416, -0.141375, 0.748332, -0.165911, -0.049067, -0.241062, -// 0.436318, 0.173318, 0.058066, 0.193764, -0.000647, 0.265777, -0.027847, -0.096305, 0.711632, 0.066506, -// -0.223124, 0.219165, -0.038165, 0.427444, -0.296887, 0.139982, 0.298976, 0.294876, -0.001315, 0.419802, -// 0.475401, -0.156256, -0.289477, -0.438761, -0.116348, 0.108350, -0.369368, -0.219943, 0.433088, 0.187565, -// -0.217259, 0.147014, -0.538991, -0.065052, 0.310337, 0.491887, 0.254439, 0.075052, 0.071155, -0.084856, -// 0.402098, 0.096270, 0.093662, -0.475769, 0.256832, 0.161394, -0.390050, -0.513551, -0.184665, 0.211506, -// -0.112525, -0.493409, -0.258765, 0.262124, -0.272998, 0.269370, 0.266226, -0.367919, 0.192386, -0.006422, -// -0.466728, -0.481792, 0.090611, -0.156359, 0.178693, -0.371658, -0.214190, -0.469058, -0.006134, 0.081902, -// 0.536950, 0.064836, -0.334010, 0.523530, -0.182061, -0.206686, 0.002985, 0.054858, -0.038727, -0.075390, -// 0.543839, -0.442964, -0.190550, -0.298127, -0.065323, 0.131415, 0.329899, 0.122096, -0.507075, 0.523751, -// -0.167317, 0.198593, -0.069066, 0.402739, 0.328583, 0.314184, -0.268003, -0.148549, 0.118925, -0.508174, -// 0.128716, -0.405597, -0.157224, 0.271021, -0.384444, -0.174935, 0.343919, -0.076726, 0.607931, 0.383931, -// 0.198254, 0.133707, 0.321460, -0.232543, 0.099988, -0.321954, -0.366304, -0.137440, 0.232835, -0.290306, -// -0.260804, -0.347721, 0.182895, 0.382311, -0.332847, -0.192469, -0.438258, -0.017533, -0.192976, -0.702531, -// 0.124463, 0.039719, -0.221319, -0.224785, 0.096356, -0.302131, -0.462598, 0.194320}); - - -// auto expV= NDArrayFactory::create('c', {2,2,10,10}, {-0.050761, 0.370975, -0.061567, -0.125530, 0.024081, 0.275524, -0.800334, -// -0.025855, 0.348132, 0.036882, 0.034921, 0.307295, 0.629837, 0.014276, 0.265687, 0.188407, -0.035481, 0.082827, -// -0.490175, 0.391118, -0.180180, 0.169108, 0.206663, 0.623321, 0.260009, 0.081943, 0.004485, 0.136199, 0.060353, -// -0.641224, -0.181559, -0.041761, 0.578416, -0.161798, -0.573128, -0.187563, 0.012533, 0.368041, 0.314619, -// -0.079349, -0.527508, 0.216020, 0.004721, 0.188769, -0.242534, -0.442685, -0.121683, -0.565306, -0.202894, -// 0.095280, -0.181900, -0.170627, -0.201655, 0.620259, -0.257996, 0.277656, -0.009623, 0.266775, 0.081952, -// 0.539241, -0.452254, -0.136142, 0.177049, -0.144734, 0.494673, 0.101613, 0.280091, -0.186281, 0.548779, -// 0.235160, 0.054763, -0.571503, 0.298086, 0.035312, -0.195188, 0.474030, -0.175457, -0.497267, -0.101439, -// -0.170678, -0.060605, -0.557305, 0.073433, 0.057195, 0.352091, -0.486102, -0.483569, 0.252091, -0.121245, -// 0.068719, -0.638919, -0.078029, -0.236556, -0.351440, -0.024437, 0.319855, -0.007406, 0.319691, -0.402334, -// -0.197966, 0.058936, -0.360900, 0.233414, -0.251532, 0.105457, 0.048097, 0.029321, 0.002714, -0.845953, -// -0.136344, 0.378037, 0.277491, 0.278420, 0.037491, 0.432117, -0.586745, 0.104573, 0.316569, -0.039848, -// 0.239645, -0.320923, 0.555156, 0.145059, -0.546959, 0.267760, 0.298029, 0.177831, -0.191286, -0.032427, -// 0.197034, 0.081887, -0.113063, 0.711713, 0.020279, -0.362346, -0.145776, 0.173289, -0.500880, 0.181624, -// 0.084391, -0.278967, 0.212143, -0.413382, 0.012879, -0.216886, -0.625774, 0.066795, -0.421937, -0.291320, -// 0.011402, -0.416660, -0.134200, 0.043039, 0.554715, 0.126867, 0.147315, 0.474334, 0.094354, -0.156458, -// 0.450168, 0.447448, 0.261750, -0.161426, -0.064309, -0.592417, 0.210891, 0.104312, 0.176178, -0.237020, -// 0.455579, -0.358056, -0.307454, 0.033700, -0.486831, -0.303963, -0.284916, 0.241549, 0.510701, 0.206104, -// 0.062587, 0.248212, 0.132088, -0.122704, 0.026342, -0.011108, 0.066306, 0.763127, 0.009491, 0.038822, -// -0.562773, -0.320104, 0.477773, 0.354169, 0.293329, -0.304227, -0.001662, -0.213324, 0.365277, -0.198056, -// -0.383499, -0.017789, 0.324542, -0.642856, 0.238689, -0.360461, -0.060599, -0.257192, 0.342400, 0.180845, -// 0.272810, -0.452278, -0.409323, 0.077013, -0.082561, 0.334893, -0.103309, -0.198049, 0.480416, 0.470593, -// 0.029072, -0.300574, 0.532293, 0.250892, -0.355298, 0.079716, -0.319781, 0.259925, 0.277872, -0.251917, -// 0.346821, 0.161642, 0.205861, 0.107125, -0.594779, -0.226272, 0.610183, -0.065926, 0.170332, 0.312553, -// -0.108093, 0.368268, -0.183109, -0.192222, -0.544559, 0.136824, -0.412352, -0.398250, -0.257291, 0.019911, -// 0.288797, 0.013350, 0.349817, -0.108331, 0.180576, 0.652863, 0.319319, 0.020218, -0.324499, 0.290877, -// 0.338518, -0.301776, -0.440871, -0.281683, -0.158759, -0.080281, 0.418260, 0.189926, -0.064112, -0.390914, -// 0.485420, -0.464327, 0.211070, 0.044295, -0.032292, 0.043985, 0.147160, -0.702247, -0.198395, -0.352940, -// -0.237014, -0.438235, 0.073448, -0.418712, -0.280275, -0.091373, -0.194273, 0.347558, -0.421767, 0.283011, -// -0.351869, -0.210088, -0.034628, 0.448410, 0.149194, -0.488551, -0.068805, -0.117007, -0.390999, 0.377100, -// 0.423252, -0.041944, 0.455115, -0.537818, 0.266732, 0.218202, 0.047475, -0.383506, -0.158858, 0.450881, -// 0.072415, 0.355772, 0.002360, 0.138976, 0.541349, -0.295405, 0.463832, 0.400676, -0.168962, 0.259334, -// -0.047960, 0.272197, 0.582658, 0.198052, 0.127300, -0.320468, -0.104858, -0.229698, 0.046672, -0.474224, -// 0.370765, -0.246450, 0.212667, 0.024935, -0.344530, -0.238547, 0.185931, 0.269068, 0.487414, 0.421376, -// 0.442391, -0.284247, 0.304973, -0.365006, -0.159016, -0.129088, -0.126454, 0.600462, -0.461163, -0.243552, -// -0.049814, -0.381340, -0.054504, 0.436237, 0.126120, -0.359677, -0.409734, -0.179422, -0.414820, 0.371149, -// 0.078299, 0.503544, 0.322165, 0.148341, -0.495447, -0.084355, -0.174667, 0.016802, -0.066954, 0.318825, -// -0.480771, -0.060163, 0.144302, -0.041555, 0.459106, 0.029882, -0.565026, 0.282336, 0.528472, 0.044916, -// -0.286167, -0.101052, -0.181529, -0.419406, -0.032204, -0.732282, 0.106833, -0.288881, 0.171516, -0.096242, -// -0.331834, -0.493188, 0.393195, 0.358365, 0.049125, 0.123457, 0.438169, -0.105015, 0.092386, -0.130413, -0.476991}); +// auto expU= NDArrayFactory::create('c', {2,2,11,11}, {-0.177870, +// -0.149461, -0.196911, 0.036990, -0.338237, 0.548901, +// -0.074396, 0.497067, -0.083636, +// -0.111810, -0.466989, -0.010465, +// 0.434732, 0.337198, 0.305239, +// -0.292813, 0.041280, -0.517144, +// 0.121499, 0.464908, 0.003658, +// 0.135017, -0.446916, -0.098318, +// 0.073571, -0.200521, 0.186776, +// -0.353022, -0.435582, -0.225959, +// 0.052972, 0.032390, -0.583801, +// -0.402790, 0.562809, 0.102744, +// 0.066555, 0.206079, 0.115322, +// 0.217220, -0.062591, -0.273173, +// -0.569645, 0.005612, 0.092601, +// 0.350055, -0.608007, -0.367743, +// 0.064860, 0.112656, 0.091576, +// -0.144262, 0.554655, -0.042100, +// -0.092023, 0.026986, -0.395811, +// -0.245209, 0.572522, 0.429430, +// 0.099621, -0.159236, -0.086263, +// 0.268160, -0.391298, 0.050417, +// 0.150175, 0.045253, 0.464173, +// 0.138376, 0.265551, 0.049691, +// 0.528778, 0.116951, 0.384609, +// 0.144416, -0.453591, -0.519390, +// -0.150671, 0.072897, 0.102406, +// -0.154184, 0.450735, 0.174171, +// -0.519405, 0.147109, 0.333670, +// 0.178053, 0.360763, 0.226976, +// 0.069976, -0.046765, 0.448897, +// 0.511309, -0.361050, -0.191690, +// -0.304442, 0.270383, -0.124133, +// 0.417183, -0.083359, 0.137022, +// 0.004276, -0.462336, 0.051267, +// 0.020622, -0.566932, -0.051351, +// -0.417106, -0.292202, -0.021595, +// -0.315956, 0.396626, -0.604952, +// 0.155990, 0.258395, -0.125080, +// 0.115404, 0.234517, -0.357460, +// 0.271271, 0.063771, -0.087400, +// -0.024710, -0.179892, 0.584339, +// -0.413085, 0.510580, 0.334646, +// 0.044424, 0.224735, 0.134434, +// -0.147861, 0.291853, 0.487948, +// 0.238917, 0.433893, 0.435884, +// 0.056370, -0.051216, -0.450902, +// 0.062411, 0.080733, -0.365211, +// 0.031931, 0.493926, -0.239428, +// 0.038247, -0.180721, -0.118035, +// 0.042175, 0.377296, -0.516399, +// 0.324744, -0.756196, 0.160856, +// -0.152527, -0.046867, -0.092933, +// -0.044945, 0.137659, 0.246552, +// -0.071709, 0.032821, -0.529356, +// -0.029669, 0.200178, 0.188916, +// 0.428036, -0.496734, -0.164185, +// 0.629070, -0.131588, 0.073992, +// 0.066877, 0.208450, -0.156170, +// -0.253670, -0.000365, -0.121172, +// 0.067774, 0.618226, 0.230460, +// -0.118865, 0.579424, 0.324523, +// 0.038653, 0.310308, 0.570186, +// -0.217271, -0.110967, 0.196375, +// 0.167058, 0.264071, -0.130023, +// 0.254189, -0.459057, -0.301033, +// 0.069932, -0.033338, -0.070600, +// 0.685064, 0.130274, 0.074929, +// -0.206899, 0.574057, 0.327277, +// -0.131588, -0.018497, 0.312445, +// 0.314594, 0.480422, -0.293858, +// -0.273277, -0.006598, -0.134574, +// 0.403501, 0.140025, 0.380693, +// -0.257039, -0.067012, 0.248776, +// -0.361838, -0.270296, -0.225844, +// 0.320245, 0.055730, 0.454809, +// -0.212163, -0.063281, 0.563112, +// -0.200737, 0.537389, -0.210845, +// 0.109997, 0.166215, -0.243725, +// -0.347349, -0.274348, 0.263950, +// 0.437134, 0.265820, -0.127520, +// -0.033325, -0.137156, 0.518557, +// 0.246720, 0.389394, -0.600568, +// 0.062027, -0.047838, -0.338416, +// 0.032778, -0.141998, -0.338022, +// -0.381467, 0.210512, -0.314413, +// 0.256321, 0.001460, 0.238901, +// 0.139840, 0.633423, -0.182575, +// -0.461504, 0.290250, -0.025930, +// 0.336998, -0.211280, -0.662387, +// -0.207946, -0.003860, -0.147842, +// 0.157217, 0.123704, 0.345686, +// 0.337946, 0.138261, -0.178814, +// -0.109597, 0.087135, -0.509500, +// -0.300296, -0.262279, 0.377476, +// -0.366815, 0.091787, 0.247495, +// -0.193812, -0.179714, 0.238552, +// -0.162305, -0.029549, 0.785426, +// -0.157586, -0.084533, -0.357024, +// 0.317878, 0.217656, 0.125319, +// 0.648832, 0.344045, -0.001109, +// 0.457190, -0.072439, -0.106278, +// 0.228962, -0.136139, -0.528342, +// -0.020840, -0.108908, -0.231661, +// 0.396864, 0.234925, 0.180894, +// -0.179430, -0.587730, 0.178276, +// -0.008672, -0.386172, 0.033155, +// 0.319568, 0.101457, -0.272011, +// 0.126007, 0.175374, -0.081668, +// 0.112987, -0.296422, -0.713743, +// 0.269413, -0.082098, -0.338649, +// 0.131035, -0.518616, 0.022478, +// 0.177802, -0.042432, -0.606219, +// -0.343848, 0.014416, -0.141375, +// 0.748332, -0.165911, -0.049067, +// -0.241062, 0.436318, 0.173318, +// 0.058066, 0.193764, -0.000647, +// 0.265777, -0.027847, -0.096305, +// 0.711632, 0.066506, -0.223124, +// 0.219165, -0.038165, 0.427444, +// -0.296887, 0.139982, 0.298976, +// 0.294876, -0.001315, 0.419802, +// 0.475401, -0.156256, -0.289477, +// -0.438761, -0.116348, 0.108350, +// -0.369368, -0.219943, 0.433088, +// 0.187565, -0.217259, 0.147014, +// -0.538991, -0.065052, 0.310337, +// 0.491887, 0.254439, 0.075052, +// 0.071155, -0.084856, 0.402098, +// 0.096270, 0.093662, -0.475769, +// 0.256832, 0.161394, -0.390050, +// -0.513551, -0.184665, 0.211506, +// -0.112525, -0.493409, -0.258765, +// 0.262124, -0.272998, 0.269370, +// 0.266226, -0.367919, 0.192386, +// -0.006422, -0.466728, -0.481792, +// 0.090611, -0.156359, 0.178693, +// -0.371658, -0.214190, -0.469058, +// -0.006134, 0.081902, 0.536950, +// 0.064836, -0.334010, 0.523530, +// -0.182061, -0.206686, 0.002985, +// 0.054858, -0.038727, -0.075390, +// 0.543839, -0.442964, -0.190550, +// -0.298127, -0.065323, 0.131415, +// 0.329899, 0.122096, -0.507075, +// 0.523751, -0.167317, 0.198593, +// -0.069066, 0.402739, 0.328583, +// 0.314184, -0.268003, -0.148549, +// 0.118925, -0.508174, 0.128716, +// -0.405597, -0.157224, 0.271021, +// -0.384444, -0.174935, 0.343919, +// -0.076726, 0.607931, 0.383931, +// 0.198254, 0.133707, 0.321460, +// -0.232543, 0.099988, -0.321954, +// -0.366304, -0.137440, 0.232835, +// -0.290306, -0.260804, -0.347721, +// 0.182895, 0.382311, -0.332847, +// -0.192469, -0.438258, -0.017533, +// -0.192976, -0.702531, 0.124463, +// 0.039719, -0.221319, -0.224785, +// 0.096356, -0.302131, -0.462598, +// 0.194320}); + +// auto expV= NDArrayFactory::create('c', {2,2,10,10}, {-0.050761, +// 0.370975, -0.061567, -0.125530, 0.024081, 0.275524, -0.800334, +// -0.025855, 0.348132, 0.036882, +// 0.034921, 0.307295, 0.629837, +// 0.014276, 0.265687, 0.188407, +// -0.035481, 0.082827, -0.490175, +// 0.391118, -0.180180, 0.169108, +// 0.206663, 0.623321, 0.260009, +// 0.081943, 0.004485, 0.136199, +// 0.060353, -0.641224, -0.181559, +// -0.041761, 0.578416, -0.161798, +// -0.573128, -0.187563, 0.012533, +// 0.368041, 0.314619, -0.079349, +// -0.527508, 0.216020, 0.004721, +// 0.188769, -0.242534, -0.442685, +// -0.121683, -0.565306, -0.202894, +// 0.095280, -0.181900, -0.170627, +// -0.201655, 0.620259, -0.257996, +// 0.277656, -0.009623, 0.266775, +// 0.081952, 0.539241, -0.452254, +// -0.136142, 0.177049, -0.144734, +// 0.494673, 0.101613, 0.280091, +// -0.186281, 0.548779, 0.235160, +// 0.054763, -0.571503, 0.298086, +// 0.035312, -0.195188, 0.474030, +// -0.175457, -0.497267, -0.101439, +// -0.170678, -0.060605, -0.557305, +// 0.073433, 0.057195, 0.352091, +// -0.486102, -0.483569, 0.252091, +// -0.121245, 0.068719, -0.638919, +// -0.078029, -0.236556, -0.351440, +// -0.024437, 0.319855, -0.007406, +// 0.319691, -0.402334, -0.197966, +// 0.058936, -0.360900, 0.233414, +// -0.251532, 0.105457, 0.048097, +// 0.029321, 0.002714, -0.845953, +// -0.136344, 0.378037, 0.277491, +// 0.278420, 0.037491, 0.432117, +// -0.586745, 0.104573, 0.316569, +// -0.039848, 0.239645, -0.320923, +// 0.555156, 0.145059, -0.546959, +// 0.267760, 0.298029, 0.177831, +// -0.191286, -0.032427, 0.197034, +// 0.081887, -0.113063, 0.711713, +// 0.020279, -0.362346, -0.145776, +// 0.173289, -0.500880, 0.181624, +// 0.084391, -0.278967, 0.212143, +// -0.413382, 0.012879, -0.216886, +// -0.625774, 0.066795, -0.421937, +// -0.291320, 0.011402, -0.416660, +// -0.134200, 0.043039, 0.554715, +// 0.126867, 0.147315, 0.474334, +// 0.094354, -0.156458, 0.450168, +// 0.447448, 0.261750, -0.161426, +// -0.064309, -0.592417, 0.210891, +// 0.104312, 0.176178, -0.237020, +// 0.455579, -0.358056, -0.307454, +// 0.033700, -0.486831, -0.303963, +// -0.284916, 0.241549, 0.510701, +// 0.206104, 0.062587, 0.248212, +// 0.132088, -0.122704, 0.026342, +// -0.011108, 0.066306, 0.763127, +// 0.009491, 0.038822, -0.562773, +// -0.320104, 0.477773, 0.354169, +// 0.293329, -0.304227, -0.001662, +// -0.213324, 0.365277, -0.198056, +// -0.383499, -0.017789, 0.324542, +// -0.642856, 0.238689, -0.360461, +// -0.060599, -0.257192, 0.342400, +// 0.180845, 0.272810, -0.452278, +// -0.409323, 0.077013, -0.082561, +// 0.334893, -0.103309, -0.198049, +// 0.480416, 0.470593, 0.029072, +// -0.300574, 0.532293, 0.250892, +// -0.355298, 0.079716, -0.319781, +// 0.259925, 0.277872, -0.251917, +// 0.346821, 0.161642, 0.205861, +// 0.107125, -0.594779, -0.226272, +// 0.610183, -0.065926, 0.170332, +// 0.312553, -0.108093, 0.368268, +// -0.183109, -0.192222, -0.544559, +// 0.136824, -0.412352, -0.398250, +// -0.257291, 0.019911, 0.288797, +// 0.013350, 0.349817, -0.108331, +// 0.180576, 0.652863, 0.319319, +// 0.020218, -0.324499, 0.290877, +// 0.338518, -0.301776, -0.440871, +// -0.281683, -0.158759, -0.080281, +// 0.418260, 0.189926, -0.064112, +// -0.390914, 0.485420, -0.464327, +// 0.211070, 0.044295, -0.032292, +// 0.043985, 0.147160, -0.702247, +// -0.198395, -0.352940, -0.237014, +// -0.438235, 0.073448, -0.418712, +// -0.280275, -0.091373, -0.194273, +// 0.347558, -0.421767, 0.283011, +// -0.351869, -0.210088, -0.034628, +// 0.448410, 0.149194, -0.488551, +// -0.068805, -0.117007, -0.390999, +// 0.377100, 0.423252, -0.041944, +// 0.455115, -0.537818, 0.266732, +// 0.218202, 0.047475, -0.383506, +// -0.158858, 0.450881, 0.072415, +// 0.355772, 0.002360, 0.138976, +// 0.541349, -0.295405, 0.463832, +// 0.400676, -0.168962, 0.259334, +// -0.047960, 0.272197, 0.582658, +// 0.198052, 0.127300, -0.320468, +// -0.104858, -0.229698, 0.046672, +// -0.474224, 0.370765, -0.246450, +// 0.212667, 0.024935, -0.344530, +// -0.238547, 0.185931, 0.269068, +// 0.487414, 0.421376, 0.442391, +// -0.284247, 0.304973, -0.365006, +// -0.159016, -0.129088, -0.126454, +// 0.600462, -0.461163, -0.243552, +// -0.049814, -0.381340, -0.054504, +// 0.436237, 0.126120, -0.359677, +// -0.409734, -0.179422, -0.414820, +// 0.371149, 0.078299, 0.503544, +// 0.322165, 0.148341, -0.495447, +// -0.084355, -0.174667, 0.016802, +// -0.066954, 0.318825, -0.480771, +// -0.060163, 0.144302, -0.041555, +// 0.459106, 0.029882, -0.565026, +// 0.282336, 0.528472, 0.044916, +// -0.286167, -0.101052, -0.181529, +// -0.419406, -0.032204, -0.732282, +// 0.106833, -0.288881, 0.171516, +// -0.096242, -0.331834, -0.493188, +// 0.393195, 0.358365, 0.049125, +// 0.123457, 0.438169, -0.105015, +// 0.092386, -0.130413, -0.476991}); // sd::ops::svd op; // auto results = op.execute({&x}, {}, {1, 1, 7}); @@ -2399,452 +2788,555 @@ TEST_F(DeclarableOpsTests3, svd_test7) { // auto u = result.at(1); // auto v = result.at(2); - // ASSERT_TRUE(expS.isSameShape(s)); - // ASSERT_TRUE(expU.isSameShape(u)); - // ASSERT_TRUE(expV.isSameShape(v)); +// ASSERT_TRUE(expS.isSameShape(s)); +// ASSERT_TRUE(expU.isSameShape(u)); +// ASSERT_TRUE(expV.isSameShape(v)); - // ASSERT_TRUE(expS.equalsTo(s)); +// ASSERT_TRUE(expS.equalsTo(s)); - // if(sd::Environment::getInstance()->isCPU()) { - // ASSERT_TRUE(expU.equalsTo(u)); - // ASSERT_TRUE(expV.equalsTo(v)); - // } - // else { - // for(uint i = 0; i < expU.lengthOf(); ++i) - // ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); - // for(uint i = 0; i < expV.lengthOf(); ++i) - // ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); - // } +// if(sd::Environment::getInstance()->isCPU()) { +// ASSERT_TRUE(expU.equalsTo(u)); +// ASSERT_TRUE(expV.equalsTo(v)); +// } +// else { +// for(uint i = 0; i < expU.lengthOf(); ++i) +// ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), +// sd::math::nd4j_abs(u->e(i)), 1e-5); +// for(uint i = 0; i < expV.lengthOf(); ++i) +// ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), +// sd::math::nd4j_abs(v->e(i)), 1e-5); +// } -// +// // } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test9) { - - auto x= NDArrayFactory::create('c', {2,2,5,6}, {17 ,-11 ,20 ,-10 ,19 ,13 ,-18 ,6 ,-2 ,-6 ,-10 ,4 ,-6 ,-4 ,3 ,16 ,12 , - -15 ,8 ,-8 ,12 ,-1 ,20 ,19 ,-13 ,0 ,20 ,17 ,-8 ,16 ,-19 ,7 ,-16 ,-14 ,-5 ,7 ,7 ,-5 ,12 ,-15 ,7 ,8 , - 1 ,-8 ,-17 ,10 ,-11 ,8 ,-10 ,1 ,-6 ,10 ,15 ,19 ,-15 ,8 ,2 ,8 ,12 ,7 ,-5 ,1 ,8 ,4 ,-13 ,2 ,19 ,-2 ,-10 , - -8 ,11 ,1 ,20 ,-11 ,4 ,1 ,-17 ,-15 ,0 ,-9 ,-4 ,-1 ,-6 ,-9 ,-13 ,10 ,7 ,-2 ,15 ,-10 ,-1 ,11 ,-20 ,-2 , - -1 ,-18 ,12 ,16 ,8 ,-9 ,-20 ,-7 ,-20 ,3 ,-9 ,12 ,8 ,-19 ,-2 ,2 ,1 ,7 ,10 ,-18 ,13 ,6 ,14 ,0 ,19 ,8}); - - auto expS= NDArrayFactory::create('c', {2,2,5}, {50.46507, 35.75599, 28.12787, 12.45245, 9.08545, - 38.56035, 30.62846, 26.31646, 19.42605, 3.01162, - 38.56369, 29.18881, 19.54565, 10.89746, 2.017 , - 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); - - auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025,0.26329, 0.3079 , 0.38582, 0.77696, 0.28872,0.03076, 0.03015, -0.9128 , 0.36387, 0.18039, - -0.61335, 0.10076, 0.01381, 0.40922, -0.66783,-0.10577, 0.93946, -0.0871 , -0.31058, 0.04677,0.52823, 0.31163, -0.78777, 0.02322, -0.05234, - -0.23942, -0.45801, -0.34248, 0.71286, 0.32778,0.26147, 0.60409, 0.39933, 0.46862, 0.43318,0.62118, -0.37993, 0.30992, 0.34537, -0.50444, - 0.45763, -0.42877, 0.08128, -0.3904 , 0.66912,-0.05428, 0.53632, 0.19774, -0.32198, 0.75276,-0.21986, -0.8214 , -0.00392, -0.1659 , 0.49944, - -0.79443, 0.1633 , -0.45374, -0.31666, -0.18989,-0.24459, 0.10463, -0.27652, 0.85595, 0.34657,0.50772, 0.00757, -0.82374, -0.18941, 0.16658, 0.49473, -0.39923, -0.20758, 0.74339, -0.01213, - -0.2024 , -0.80239, -0.35502, -0.3982 , -0.17492,0.68875, 0.1822 , -0.08046, -0.39238, -0.57619,0.34555, 0.12488, -0.50703, -0.29269, 0.72267,-0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); - - auto expV= NDArrayFactory::create('c', {2,2,6,6}, {-4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01,-1.10690000e-01, 1.37280000e-01,2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03,-1.00090000e-01, 9.35890000e-01, - -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01,6.70320000e-01, 2.10040000e-01,1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01,-4.38680000e-01, 1.83200000e-02, - -5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01,-2.10060000e-01, 2.41550000e-01,-4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02,-5.40210000e-01, -4.97000000e-02, - -6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01,-4.21850000e-01, 4.00490000e-01,1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01,-2.06580000e-01, 7.68890000e-01, - -4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01,5.88210000e-01, 7.12900000e-02,2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01,1.51570000e-01, 6.02100000e-02, - 1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02,-5.79750000e-01, -2.92870000e-01,4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01,2.72560000e-01, 3.92350000e-01, - -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01,-1.17970000e-01, -4.08100000e-02,4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01,1.93050000e-01, -6.83340000e-01, - 8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01,2.37500000e-02, 5.78250000e-01,-6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02,-9.19220000e-01, -2.15420000e-01, - 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02,-3.19580000e-01, 2.92020000e-01,2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01,3.39100000e-02, 2.55590000e-01, - -4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01,4.98470000e-01, -3.65370000e-01,6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 8.90000000e-02,-7.30460000e-01, -1.09390000e-01, - -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01,5.20000000e-04, 1.90420000e-01,2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01,-3.57600000e-02, -8.60450000e-01, - 1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01,-4.39400000e-02, 2.17750000e-01,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01, -1.74620000e-01}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {1, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - auto u = result.at(1); - auto v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance()->isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5); - } - + auto x = NDArrayFactory::create( + 'c', {2, 2, 5, 6}, + {17, -11, 20, -10, 19, 13, -18, 6, -2, -6, -10, 4, -6, -4, 3, + 16, 12, -15, 8, -8, 12, -1, 20, 19, -13, 0, 20, 17, -8, 16, + -19, 7, -16, -14, -5, 7, 7, -5, 12, -15, 7, 8, 1, -8, -17, + 10, -11, 8, -10, 1, -6, 10, 15, 19, -15, 8, 2, 8, 12, 7, + -5, 1, 8, 4, -13, 2, 19, -2, -10, -8, 11, 1, 20, -11, 4, + 1, -17, -15, 0, -9, -4, -1, -6, -9, -13, 10, 7, -2, 15, -10, + -1, 11, -20, -2, -1, -18, 12, 16, 8, -9, -20, -7, -20, 3, -9, + 12, 8, -19, -2, 2, 1, 7, 10, -18, 13, 6, 14, 0, 19, 8}); + + auto expS = NDArrayFactory::create( + 'c', {2, 2, 5}, + {50.46507, 35.75599, 28.12787, 12.45245, 9.08545, 38.56035, 30.62846, + 26.31646, 19.42605, 3.01162, 38.56369, 29.18881, 19.54565, 10.89746, + 2.017, 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); + + auto expU = NDArrayFactory::create( + 'c', {2, 2, 5, 5}, + {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025, 0.26329, 0.3079, + 0.38582, 0.77696, 0.28872, 0.03076, 0.03015, -0.9128, 0.36387, + 0.18039, -0.61335, 0.10076, 0.01381, 0.40922, -0.66783, -0.10577, + 0.93946, -0.0871, -0.31058, 0.04677, 0.52823, 0.31163, -0.78777, + 0.02322, -0.05234, -0.23942, -0.45801, -0.34248, 0.71286, 0.32778, + 0.26147, 0.60409, 0.39933, 0.46862, 0.43318, 0.62118, -0.37993, + 0.30992, 0.34537, -0.50444, 0.45763, -0.42877, 0.08128, -0.3904, + 0.66912, -0.05428, 0.53632, 0.19774, -0.32198, 0.75276, -0.21986, + -0.8214, -0.00392, -0.1659, 0.49944, -0.79443, 0.1633, -0.45374, + -0.31666, -0.18989, -0.24459, 0.10463, -0.27652, 0.85595, 0.34657, + 0.50772, 0.00757, -0.82374, -0.18941, 0.16658, 0.49473, -0.39923, + -0.20758, 0.74339, -0.01213, -0.2024, -0.80239, -0.35502, -0.3982, + -0.17492, 0.68875, 0.1822, -0.08046, -0.39238, -0.57619, 0.34555, + 0.12488, -0.50703, -0.29269, 0.72267, -0.34713, 0.3847, -0.7532, + 0.22176, -0.33913}); + + auto expV = NDArrayFactory::create( + 'c', {2, 2, 6, 6}, + {-4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01, + -1.10690000e-01, 1.37280000e-01, 2.86620000e-01, 5.88200000e-02, + 1.68760000e-01, -2.55000000e-03, -1.00090000e-01, 9.35890000e-01, + -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01, + 6.70320000e-01, 2.10040000e-01, 1.00910000e-01, 4.35740000e-01, + -6.90500000e-01, -3.61090000e-01, -4.38680000e-01, 1.83200000e-02, + -5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01, + -2.10060000e-01, 2.41550000e-01, -4.42450000e-01, 4.56640000e-01, + 5.48020000e-01, 3.32100000e-02, -5.40210000e-01, -4.97000000e-02, + -6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01, + -4.21850000e-01, 4.00490000e-01, 1.83740000e-01, -1.36190000e-01, + -2.29380000e-01, -5.11090000e-01, -2.06580000e-01, 7.68890000e-01, + -4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01, + 5.88210000e-01, 7.12900000e-02, 2.25200000e-01, 4.30600000e-02, + 9.08510000e-01, -3.08940000e-01, 1.51570000e-01, 6.02100000e-02, + 1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02, + -5.79750000e-01, -2.92870000e-01, 4.89620000e-01, -2.24300000e-01, + 5.31200000e-02, 6.92040000e-01, 2.72560000e-01, 3.92350000e-01, + -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01, + -1.17970000e-01, -4.08100000e-02, 4.25340000e-01, -1.65500000e-02, + -2.82400000e-02, -5.60180000e-01, 1.93050000e-01, -6.83340000e-01, + 8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01, + 2.37500000e-02, 5.78250000e-01, -6.10000000e-04, 3.00110000e-01, + 1.17290000e-01, -6.92400000e-02, -9.19220000e-01, -2.15420000e-01, + 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02, + -3.19580000e-01, 2.92020000e-01, 2.25920000e-01, -1.10170000e-01, + 9.17020000e-01, -1.71540000e-01, 3.39100000e-02, 2.55590000e-01, + -4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01, + 4.98470000e-01, -3.65370000e-01, 6.39700000e-02, -4.04150000e-01, + -5.28310000e-01, 8.90000000e-02, -7.30460000e-01, -1.09390000e-01, + -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01, + 5.20000000e-04, 1.90420000e-01, 2.55960000e-01, 3.17040000e-01, + -3.47800000e-02, -3.01860000e-01, -3.57600000e-02, -8.60450000e-01, + 1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01, + -4.39400000e-02, 2.17750000e-01, -6.57270000e-01, 2.91000000e-01, + 4.17280000e-01, 2.52880000e-01, -4.63400000e-01, -1.74620000e-01}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance()->isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test10) { - - auto x= NDArrayFactory::create('c', {2,2,5,6}, {17 ,-11 ,20 ,-10 ,19 ,13 ,-18 ,6 ,-2 ,-6 ,-10 ,4 ,-6 ,-4 ,3 ,16 ,12 , - -15 ,8 ,-8 ,12 ,-1 ,20 ,19 ,-13 ,0 ,20 ,17 ,-8 ,16 ,-19 ,7 ,-16 ,-14 ,-5 ,7 ,7 ,-5 ,12 ,-15 ,7 ,8 , - 1 ,-8 ,-17 ,10 ,-11 ,8 ,-10 ,1 ,-6 ,10 ,15 ,19 ,-15 ,8 ,2 ,8 ,12 ,7 ,-5 ,1 ,8 ,4 ,-13 ,2 ,19 ,-2 ,-10 , - -8 ,11 ,1 ,20 ,-11 ,4 ,1 ,-17 ,-15 ,0 ,-9 ,-4 ,-1 ,-6 ,-9 ,-13 ,10 ,7 ,-2 ,15 ,-10 ,-1 ,11 ,-20 ,-2 , - -1 ,-18 ,12 ,16 ,8 ,-9 ,-20 ,-7 ,-20 ,3 ,-9 ,12 ,8 ,-19 ,-2 ,2 ,1 ,7 ,10 ,-18 ,13 ,6 ,14 ,0 ,19 ,8}); - - auto expS= NDArrayFactory::create('c', {2,2,5}, {50.46507, 35.75599, 28.12787, 12.45245, 9.08545, - 38.56035, 30.62846, 26.31646, 19.42605, 3.01162, - 38.56369, 29.18881, 19.54565, 10.89746, 2.017 , - 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); - - auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025,0.26329, 0.3079 , 0.38582, 0.77696, 0.28872,0.03076, 0.03015, -0.9128 , 0.36387, 0.18039,-0.61335, 0.10076, 0.01381, 0.40922, -0.66783, - -0.10577, 0.93946, -0.0871 , -0.31058, 0.04677,0.52823, 0.31163, -0.78777, 0.02322, -0.05234,-0.23942, -0.45801, -0.34248, 0.71286, 0.32778,0.26147, 0.60409, 0.39933, 0.46862, 0.43318, - 0.62118, -0.37993, 0.30992, 0.34537, -0.50444,0.45763, -0.42877, 0.08128, -0.3904 , 0.66912,-0.05428, 0.53632, 0.19774, -0.32198, 0.75276,-0.21986, -0.8214 , -0.00392, -0.1659 , 0.49944, - -0.79443, 0.1633 , -0.45374, -0.31666, -0.18989,-0.24459, 0.10463, -0.27652, 0.85595, 0.34657,0.50772, 0.00757, -0.82374, -0.18941, 0.16658,0.49473, -0.39923, -0.20758, 0.74339, -0.01213, - -0.2024 , -0.80239, -0.35502, -0.3982 , -0.17492,0.68875, 0.1822 , -0.08046, -0.39238, -0.57619,0.34555, 0.12488, -0.50703, -0.29269, 0.72267,-0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); - - auto expV= NDArrayFactory::create('c', {2,2,6,5}, { -4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01,-1.10690000e-01,2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03,-1.00090000e-01, - -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01,6.70320000e-01,1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01,-4.38680000e-01,-5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01, - -2.10060000e-01,-4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02,-5.40210000e-01,-6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01,-4.21850000e-01, - 1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01,-2.06580000e-01,-4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01,5.88210000e-01,2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01, - 1.51570000e-01,1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02,-5.79750000e-01,4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01,2.72560000e-01, - -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01,-1.17970000e-01,4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01,1.93050000e-01,8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01,2.37500000e-02,-6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02,-9.19220000e-01, - 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02,-3.19580000e-01,2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01,3.39100000e-02,-4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01,4.98470000e-01,6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 8.90000000e-02,-7.30460000e-01, - -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01,5.20000000e-04,2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01,-3.57600000e-02,1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01, - -4.39400000e-02,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {0, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - auto u = result.at(1); - auto v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance()->isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5); - } - + auto x = NDArrayFactory::create( + 'c', {2, 2, 5, 6}, + {17, -11, 20, -10, 19, 13, -18, 6, -2, -6, -10, 4, -6, -4, 3, + 16, 12, -15, 8, -8, 12, -1, 20, 19, -13, 0, 20, 17, -8, 16, + -19, 7, -16, -14, -5, 7, 7, -5, 12, -15, 7, 8, 1, -8, -17, + 10, -11, 8, -10, 1, -6, 10, 15, 19, -15, 8, 2, 8, 12, 7, + -5, 1, 8, 4, -13, 2, 19, -2, -10, -8, 11, 1, 20, -11, 4, + 1, -17, -15, 0, -9, -4, -1, -6, -9, -13, 10, 7, -2, 15, -10, + -1, 11, -20, -2, -1, -18, 12, 16, 8, -9, -20, -7, -20, 3, -9, + 12, 8, -19, -2, 2, 1, 7, 10, -18, 13, 6, 14, 0, 19, 8}); + + auto expS = NDArrayFactory::create( + 'c', {2, 2, 5}, + {50.46507, 35.75599, 28.12787, 12.45245, 9.08545, 38.56035, 30.62846, + 26.31646, 19.42605, 3.01162, 38.56369, 29.18881, 19.54565, 10.89746, + 2.017, 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); + + auto expU = NDArrayFactory::create( + 'c', {2, 2, 5, 5}, + {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025, 0.26329, 0.3079, + 0.38582, 0.77696, 0.28872, 0.03076, 0.03015, -0.9128, 0.36387, + 0.18039, -0.61335, 0.10076, 0.01381, 0.40922, -0.66783, -0.10577, + 0.93946, -0.0871, -0.31058, 0.04677, 0.52823, 0.31163, -0.78777, + 0.02322, -0.05234, -0.23942, -0.45801, -0.34248, 0.71286, 0.32778, + 0.26147, 0.60409, 0.39933, 0.46862, 0.43318, 0.62118, -0.37993, + 0.30992, 0.34537, -0.50444, 0.45763, -0.42877, 0.08128, -0.3904, + 0.66912, -0.05428, 0.53632, 0.19774, -0.32198, 0.75276, -0.21986, + -0.8214, -0.00392, -0.1659, 0.49944, -0.79443, 0.1633, -0.45374, + -0.31666, -0.18989, -0.24459, 0.10463, -0.27652, 0.85595, 0.34657, + 0.50772, 0.00757, -0.82374, -0.18941, 0.16658, 0.49473, -0.39923, + -0.20758, 0.74339, -0.01213, -0.2024, -0.80239, -0.35502, -0.3982, + -0.17492, 0.68875, 0.1822, -0.08046, -0.39238, -0.57619, 0.34555, + 0.12488, -0.50703, -0.29269, 0.72267, -0.34713, 0.3847, -0.7532, + 0.22176, -0.33913}); + + auto expV = NDArrayFactory::create( + 'c', {2, 2, 6, 5}, + {-4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01, + -1.10690000e-01, 2.86620000e-01, 5.88200000e-02, 1.68760000e-01, + -2.55000000e-03, -1.00090000e-01, -4.88230000e-01, 4.84470000e-01, + -1.09150000e-01, -1.46810000e-01, 6.70320000e-01, 1.00910000e-01, + 4.35740000e-01, -6.90500000e-01, -3.61090000e-01, -4.38680000e-01, + -5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01, + -2.10060000e-01, -4.42450000e-01, 4.56640000e-01, 5.48020000e-01, + 3.32100000e-02, -5.40210000e-01, -6.36070000e-01, 5.57600000e-02, + 3.28740000e-01, 3.81950000e-01, -4.21850000e-01, 1.83740000e-01, + -1.36190000e-01, -2.29380000e-01, -5.11090000e-01, -2.06580000e-01, + -4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01, + 5.88210000e-01, 2.25200000e-01, 4.30600000e-02, 9.08510000e-01, + -3.08940000e-01, 1.51570000e-01, 1.97510000e-01, -7.26560000e-01, + 1.05370000e-01, 1.10600000e-02, -5.79750000e-01, 4.89620000e-01, + -2.24300000e-01, 5.31200000e-02, 6.92040000e-01, 2.72560000e-01, + -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01, + -1.17970000e-01, 4.25340000e-01, -1.65500000e-02, -2.82400000e-02, + -5.60180000e-01, 1.93050000e-01, 8.08800000e-02, 4.38260000e-01, + -2.48340000e-01, -6.36220000e-01, 2.37500000e-02, -6.10000000e-04, + 3.00110000e-01, 1.17290000e-01, -6.92400000e-02, -9.19220000e-01, + 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02, + -3.19580000e-01, 2.25920000e-01, -1.10170000e-01, 9.17020000e-01, + -1.71540000e-01, 3.39100000e-02, -4.86810000e-01, -2.32390000e-01, + -4.31500000e-01, 3.75290000e-01, 4.98470000e-01, 6.39700000e-02, + -4.04150000e-01, -5.28310000e-01, 8.90000000e-02, -7.30460000e-01, + -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01, + 5.20000000e-04, 2.55960000e-01, 3.17040000e-01, -3.47800000e-02, + -3.01860000e-01, -3.57600000e-02, 1.31650000e-01, 7.57150000e-01, + -4.89030000e-01, 3.47710000e-01, -4.39400000e-02, -6.57270000e-01, + 2.91000000e-01, 4.17280000e-01, 2.52880000e-01, -4.63400000e-01}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {0, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance()->isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test11) { - - NDArray x('c', {2,2,3,3}, {0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, 0.5461, 0.9234, 0.0856, 0.7938, 0.6591, 0.5555, - 0.1596, 0.3087, 0.1548, 0.4695, 0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, -0.5461, 0.9234, - 0.0856, -0.7938, 0.6591, 0.5555, 0.1500, 0.3087, 0.1548, 0.4695}); - NDArray expS('c', {2,2,3}, {1.89671, 0.37095, 0.05525,1.51296, 0.52741, 0.17622, 1.69095, 0.90438, 0.24688,1.33551, 0.87475, 0.21571}); - NDArray expU('c', {2,2,3,3}, {6.9205e-01, 6.0147e-01, -3.9914e-01, 3.8423e-01, -7.7503e-01, -5.0170e-01, 6.1110e-01, -1.9384e-01, 7.6746e-01, - 7.8967e-01, 4.5442e-01, -4.1222e-01, 4.9381e-01, -8.6948e-01, -1.2540e-02, 3.6412e-01, 1.9366e-01, 9.1100e-01, - 7.1764e-01, 5.9844e-01, 3.5617e-01, 4.4477e-01, -3.1000e-04, -8.9564e-01, 5.3588e-01, -8.0116e-01, 2.6639e-01, - 8.7050e-01, -4.2088e-01, -2.5513e-01, 4.8622e-01, 6.5499e-01, 5.7843e-01, 7.6340e-02, 6.2757e-01, -7.7481e-01}); - NDArray expV('c', {2,2,3,3}, {0.49383, 0.51614, -0.69981, 0.72718, -0.68641, 0.00688, 0.4768 , 0.51228, 0.7143 , 0.77137, -0.17763, - -0.6111 , 0.26324, -0.7852 , 0.56051, 0.57939, 0.59322, 0.55892, 0.55149, 0.06737, 0.83146, 0.81413, - -0.26072, -0.51887, 0.18182, 0.96306, -0.19863, 0.85948, 0.2707 , -0.4336 , 0.26688, 0.48582, 0.83232, - -0.43596, 0.83108, -0.34531}); - - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {0, 1, 16}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - auto u = result.at(1); - auto v = result.at(2); - - ASSERT_TRUE(expS.isSameShape(s)); - ASSERT_TRUE(expU.isSameShape(u)); - ASSERT_TRUE(expV.isSameShape(v)); - - ASSERT_TRUE(expS.equalsTo(s)); - - if(sd::Environment::getInstance()->isCPU()) { - ASSERT_TRUE(expU.equalsTo(u)); - ASSERT_TRUE(expV.equalsTo(v)); - } - else { - for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u.e(i)), 1e-5); - for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v.e(i)), 1e-5); - } - + NDArray x('c', {2, 2, 3, 3}, + {0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, + 0.5461, 0.9234, 0.0856, 0.7938, 0.6591, 0.5555, 0.1596, 0.3087, + 0.1548, 0.4695, 0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, + 0.5056, 0.8925, -0.5461, 0.9234, 0.0856, -0.7938, 0.6591, 0.5555, + 0.1500, 0.3087, 0.1548, 0.4695}); + NDArray expS('c', {2, 2, 3}, + {1.89671, 0.37095, 0.05525, 1.51296, 0.52741, 0.17622, 1.69095, + 0.90438, 0.24688, 1.33551, 0.87475, 0.21571}); + NDArray expU('c', {2, 2, 3, 3}, + {6.9205e-01, 6.0147e-01, -3.9914e-01, 3.8423e-01, -7.7503e-01, + -5.0170e-01, 6.1110e-01, -1.9384e-01, 7.6746e-01, 7.8967e-01, + 4.5442e-01, -4.1222e-01, 4.9381e-01, -8.6948e-01, -1.2540e-02, + 3.6412e-01, 1.9366e-01, 9.1100e-01, 7.1764e-01, 5.9844e-01, + 3.5617e-01, 4.4477e-01, -3.1000e-04, -8.9564e-01, 5.3588e-01, + -8.0116e-01, 2.6639e-01, 8.7050e-01, -4.2088e-01, -2.5513e-01, + 4.8622e-01, 6.5499e-01, 5.7843e-01, 7.6340e-02, 6.2757e-01, + -7.7481e-01}); + NDArray expV('c', {2, 2, 3, 3}, + {0.49383, 0.51614, -0.69981, 0.72718, -0.68641, 0.00688, + 0.4768, 0.51228, 0.7143, 0.77137, -0.17763, -0.6111, + 0.26324, -0.7852, 0.56051, 0.57939, 0.59322, 0.55892, + 0.55149, 0.06737, 0.83146, 0.81413, -0.26072, -0.51887, + 0.18182, 0.96306, -0.19863, 0.85948, 0.2707, -0.4336, + 0.26688, 0.48582, 0.83232, -0.43596, 0.83108, -0.34531}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {0, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if (sd::Environment::getInstance()->isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } else { + for (uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), + sd::math::nd4j_abs(u.e(i)), 1e-5); + for (uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), + sd::math::nd4j_abs(v.e(i)), 1e-5); + } } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test12) { + NDArray x( + 'c', {4, 3}, + {1.7787856, 0.80119777, 0.72437465, 0.23089433, 1.7271413, 0.18039072, + 0.50563407, 0.89252293, 1.5461209, 0.92336726, 0.085571885, 0.79378015}); + NDArray expS('c', {3}, {3.024703, 1.459483, 1.026371}); - NDArray x('c', {4,3}, {1.7787856,0.80119777,0.72437465,0.23089433,1.7271413,0.18039072,0.50563407,0.89252293,1.5461209,0.92336726,0.085571885,0.79378015}); - NDArray expS('c', {3}, {3.024703, 1.459483, 1.026371}); - + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 0, 16}); - sd::ops::svd op; - auto result = op.evaluate({&x}, {}, {1, 0, 16}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - - ASSERT_TRUE(expS.equalsTo(s)); - ASSERT_TRUE(expS.isSameShape(s)); + auto s = result.at(0); + ASSERT_TRUE(expS.equalsTo(s)); + ASSERT_TRUE(expS.isSameShape(s)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, elu_test1) { + auto x = NDArrayFactory::create( + 'c', {3, 3}, {0.1, .2, .3, -.4, -.5, -.6, .7, .8, .9}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {.1, .2, .3, 0.5 * -0.32968, 0.5 * -0.393469, 0.5 * -0.451188, .7, .8, + .9}); - auto x = NDArrayFactory::create('c', {3,3}, {0.1, .2, .3, -.4,-.5,-.6, .7, .8, .9}); - auto exp = NDArrayFactory::create('c', {3,3}, {.1, .2, .3, 0.5*-0.32968, 0.5*-0.393469, 0.5*-0.451188, .7, .8, .9}); - - sd::ops::elu op; - auto result = op.evaluate({&x}, {0.5}, {}); + sd::ops::elu op; + auto result = op.evaluate({&x}, {0.5}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - ASSERT_TRUE(exp.equalsTo(s)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto s = result.at(0); + ASSERT_TRUE(exp.equalsTo(s)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, elu_bp_test1) { + auto x = NDArrayFactory::create( + 'c', {3, 3}, {0.1, .2, .3, -.4, -.5, -.6, .7, .8, .9}); + auto eps = NDArrayFactory::create('c', {3, 3}); + eps.assign(2.); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {2, 2, 2, 0.5 * 1.34064, 0.5 * 1.213061, 0.5 * 1.097623, 2, 2, 2}); - auto x = NDArrayFactory::create('c', {3, 3}, {0.1, .2, .3, -.4, -.5, -.6, .7, .8, .9}); - auto eps = NDArrayFactory::create('c', {3,3}); - eps.assign(2.); - auto exp = NDArrayFactory::create('c', {3, 3}, {2, 2, 2, 0.5*1.34064, 0.5*1.213061, 0.5*1.097623, 2, 2, 2}); - - sd::ops::elu_bp op; - auto result = op.evaluate({ &x, &eps }, {0.5}, {}); + sd::ops::elu_bp op; + auto result = op.evaluate({&x, &eps}, {0.5}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - ASSERT_TRUE(exp.equalsTo(s)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto s = result.at(0); + ASSERT_TRUE(exp.equalsTo(s)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, lrelu_test1) { + auto x = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, -4, -5, -6, 7, 8, 9}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9}); - auto x = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9}); - - sd::ops::lrelu op; - auto result = op.evaluate({&x}, {0.2}, {}); + sd::ops::lrelu op; + auto result = op.evaluate({&x}, {0.2}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - ASSERT_TRUE(exp.equalsTo(s)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto s = result.at(0); + ASSERT_TRUE(exp.equalsTo(s)); } TEST_F(DeclarableOpsTests3, lrelu_bp_test1) { + auto x = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, -4, -5, -6, 7, 8, 9}); + auto eps = + NDArrayFactory::create('c', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto exp = NDArrayFactory::create('c', {3, 3}, + {2, 2, 2, 0.4, 0.4, 0.4, 2, 2, 2}); - auto x = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); - auto eps = NDArrayFactory::create('c', {3,3}, {2,2,2,2,2,2,2, 2,2}); - auto exp = NDArrayFactory::create('c', {3,3}, {2, 2, 2, 0.4, 0.4, 0.4, 2, 2, 2}); - - sd::ops::lrelu_bp op; - auto result = op.evaluate({&x, &eps}, {0.2}, {}); + sd::ops::lrelu_bp op; + auto result = op.evaluate({&x, &eps}, {0.2}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - ASSERT_TRUE(exp.equalsTo(s)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto s = result.at(0); + ASSERT_TRUE(exp.equalsTo(s)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, selu_test1) { + auto x = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, -4, -5, -6, 7, 8, 9}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, + 8.405608, 9.456309}); - auto x = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3,3}, {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, 8.405608, 9.456309}); - - sd::ops::selu op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::selu op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); - ASSERT_TRUE(exp.equalsTo(s)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto s = result.at(0); + ASSERT_TRUE(exp.equalsTo(s)); } TEST_F(DeclarableOpsTests3, selu_test2) { + auto x = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, -4, -5, -6, 7, 8, 9}); + // auto expS = NDArrayFactory::create('c', {3}); + auto eps = + NDArrayFactory::create('c', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {2.101401, 2.101402, 2.101402, 0.064401, 0.023692, 0.008716, 2.101402, + 2.101402, 2.101402}); - auto x = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); -// auto expS = NDArrayFactory::create('c', {3}); - auto eps = NDArrayFactory::create('c', {3,3}, {2,2,2,2,2,2,2, 2,2}); - auto exp = NDArrayFactory::create('c', {3,3}, {2.101401, 2.101402, 2.101402, 0.064401, 0.023692, 0.008716, 2.101402, 2.101402, 2.101402}); - - sd::ops::selu_bp op; - auto result = op.evaluate({&x, &eps}, {0.2}, {}); + sd::ops::selu_bp op; + auto result = op.evaluate({&x, &eps}, {0.2}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto s = result.at(0); -// auto u = result.at(1); -// auto v = result.at(2); -// s->printIndexedBuffer("SELU_BP"); - ASSERT_TRUE(exp.equalsTo(s)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto s = result.at(0); + // auto u = result.at(1); + // auto v = result.at(2); + // s->printIndexedBuffer("SELU_BP"); + ASSERT_TRUE(exp.equalsTo(s)); } TEST_F(DeclarableOpsTests3, EQScalarTests_1) { - Graph graph; - - auto x = NDArrayFactory::create(1.0f); - auto scalar = NDArrayFactory::create(1.0f); + Graph graph; - sd::ops::eq_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_TRUE(res); + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(1.0f); + sd::ops::eq_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); } TEST_F(DeclarableOpsTests3, EQScalarTests_2) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(2.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::eq_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_FALSE(res); + sd::ops::eq_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_FALSE(res); } TEST_F(DeclarableOpsTests3, GTScalarTests_1) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(1.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::gt_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_FALSE(res); + sd::ops::gt_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_FALSE(res); } TEST_F(DeclarableOpsTests3, GTScalarTests_2) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(2.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::gt_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_TRUE(res); + sd::ops::gt_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); } TEST_F(DeclarableOpsTests3, GTEScalarTests_1) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(1.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::gte_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_TRUE(res); + sd::ops::gte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); } TEST_F(DeclarableOpsTests3, GTEScalarTests_2) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(2.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::gte_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_TRUE(res); + sd::ops::gte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); } TEST_F(DeclarableOpsTests3, GTEScalarTests_3) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(1.0f); - auto scalar = NDArrayFactory::create(2.0f); + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(2.0f); - sd::ops::gte_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_FALSE(res); + sd::ops::gte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_FALSE(res); } TEST_F(DeclarableOpsTests3, LTEScalarTests_1) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(1.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::lte_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_TRUE(res); + sd::ops::lte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); } TEST_F(DeclarableOpsTests3, LTEScalarTests_2) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(2.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::lte_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_FALSE(res); + sd::ops::lte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_FALSE(res); } TEST_F(DeclarableOpsTests3, LTEScalarTests_3) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(1.0f); - auto scalar = NDArrayFactory::create(2.0f); + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(2.0f); - sd::ops::lte_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_TRUE(res); + sd::ops::lte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); } TEST_F(DeclarableOpsTests3, NEQScalarTests_1) { - Graph graph; - - auto x = NDArrayFactory::create(1.0f); - auto scalar = NDArrayFactory::create(1.0f); + Graph graph; - sd::ops::neq_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_FALSE(res); + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(1.0f); + sd::ops::neq_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_FALSE(res); } TEST_F(DeclarableOpsTests3, NEQScalarTests_2) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(2.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::neq_scalar op; - auto res = op.verify({&x, &scalar}); - ASSERT_TRUE(res); + sd::ops::neq_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); } TEST_F(DeclarableOpsTests3, NOOPTests_1) { - Graph graph; + Graph graph; - auto x = NDArrayFactory::create(2.0f); - auto scalar = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); - sd::ops::noop op; - auto res = op.evaluate({&x, &scalar}, {}, {}); - ASSERT_TRUE(res.status() == sd::Status::OK()); + sd::ops::noop op; + auto res = op.evaluate({&x, &scalar}, {}, {}); + ASSERT_TRUE(res.status() == sd::Status::OK()); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index 2af6b480e2b6..59856bcc3b9c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -18,39 +18,37 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include -#include #include #include +#include +#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class DeclarableOpsTests4 : public testing::Test { -public: - - DeclarableOpsTests4() { - printf("\n"); - fflush(stdout); - - sd::ops::adjust_hue op0; - sd::ops::adjust_saturation op1; - } + public: + DeclarableOpsTests4() { + printf("\n"); + fflush(stdout); + + sd::ops::adjust_hue op0; + sd::ops::adjust_saturation op1; + } }; template class TypedDeclarableOpsTests4 : public testing::Test { -public: - - TypedDeclarableOpsTests4() { - printf("\n"); - fflush(stdout); - - sd::ops::adjust_hue op0; - sd::ops::adjust_saturation op1; - } + public: + TypedDeclarableOpsTests4() { + printf("\n"); + fflush(stdout); + + sd::ops::adjust_hue op0; + sd::ops::adjust_saturation op1; + } }; typedef ::testing::Types TestingTypes; @@ -58,2359 +56,2557 @@ TYPED_TEST_CASE(TypedDeclarableOpsTests4, TestingTypes); ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_1) { - auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}); - - x.linspace(1); - - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1}); + auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, + 54.f, 55.f, 58.f, 59.f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + x.linspace(1); - auto z = result.at(0); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_2) { - auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}); - - x.linspace(1); - + auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, + 54.f, 55.f, 58.f, 59.f}); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_3) { - auto x = NDArrayFactory::create('c', {2, 5, 5, 2}); - auto exp = NDArrayFactory::create('c', {2, 3, 3, 2}, {7.f, 8.f, 11.f, 12.f, 14.f, 15.f, 27.f, 28.f, 31.f, 32.f, 34.f, 35.f, 42.f, 43.f, 46.f, 47.f, 49.f, 50.f, 57.f, 58.f, 61.f, 62.f, 64.f, 65.f, 77.f, 78.f, 81.f, 82.f, 84.f, 85.f, 92.f, 93.f, 96.f, 97.f, 99.f, 100.f,}); - - x.linspace(1); - + auto x = NDArrayFactory::create('c', {2, 5, 5, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3, 2}, + { + 7.f, 8.f, 11.f, 12.f, 14.f, 15.f, 27.f, 28.f, 31.f, + 32.f, 34.f, 35.f, 42.f, 43.f, 46.f, 47.f, 49.f, 50.f, + 57.f, 58.f, 61.f, 62.f, 64.f, 65.f, 77.f, 78.f, 81.f, + 82.f, 84.f, 85.f, 92.f, 93.f, 96.f, 97.f, 99.f, 100.f, + }); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 1}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 1}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_4) { - auto x = NDArrayFactory::create('c', {2, 5, 5, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {7.f, 8.f, 11.f, 12.f, 27.f, 28.f, 31.f, 32.f, 57.f, 58.f, 61.f, 62.f, 77.f, 78.f, 81.f, 82.f}); - - x.linspace(1); + auto x = NDArrayFactory::create('c', {2, 5, 5, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {7.f, 8.f, 11.f, 12.f, 27.f, 28.f, 31.f, 32.f, 57.f, 58.f, 61.f, 62.f, + 77.f, 78.f, 81.f, 82.f}); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_5) { - auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); - auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {1.f, 2.5f, 4.5f, 8.5f, 10.f, 12.f, 18.5f, 20.f, 22.f, 26.f, 27.5f, 29.5f, 33.5f, 35.f, 37.f, 43.5f, 45.f, 47.f, 51.f, 52.5f, 54.5f, 58.5f, 60.f, 62.f, 68.5f, 70.f, 72.f, 76.f, 77.5f, 79.5f, 83.5f, 85.f, 87.f, 93.5f, 95.f, 97.f}); - - x.linspace(1); - + auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 3, 3}, + {1.f, 2.5f, 4.5f, 8.5f, 10.f, 12.f, 18.5f, 20.f, 22.f, + 26.f, 27.5f, 29.5f, 33.5f, 35.f, 37.f, 43.5f, 45.f, 47.f, + 51.f, 52.5f, 54.5f, 58.5f, 60.f, 62.f, 68.5f, 70.f, 72.f, + 76.f, 77.5f, 79.5f, 83.5f, 85.f, 87.f, 93.5f, 95.f, 97.f}); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_6) { - auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); - auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {0.25f, 1.25f, 2.25f, 4.25f, 10.f, 12.f, 9.25f, 20.f, 22.f, 6.5f, 13.75f, 14.75, 16.75f, 35.f, 37.f, 21.75f, 45.f, 47.f, 12.75f, 26.25f, 27.25f, 29.25f, 60.f, 62.f, 34.25f, 70.f, 72.f, 19.f, 38.75f, 39.75f, 41.75f, 85.f, 87.f, 46.75f, 95.f, 97.f}); - - x.linspace(1); - + auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 3, 3}, + {0.25f, 1.25f, 2.25f, 4.25f, 10.f, 12.f, 9.25f, 20.f, 22.f, + 6.5f, 13.75f, 14.75, 16.75f, 35.f, 37.f, 21.75f, 45.f, 47.f, + 12.75f, 26.25f, 27.25f, 29.25f, 60.f, 62.f, 34.25f, 70.f, 72.f, + 19.f, 38.75f, 39.75f, 41.75f, 85.f, 87.f, 46.75f, 95.f, 97.f}); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 1, 0}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 1, 0}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_7) { - auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); - auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {4.f, 6.f, 7.5f, 14.f, 16.f, 17.5f, 21.5f, 23.5f, 25.f, 29.f, 31.f, 32.5f, 39.f, 41.f, 42.5f, 46.5f, 48.5f, 50.f, 54.f, 56.f, 57.5f, 64.f, 66.f, 67.5f, 71.5f, 73.5f, 75.f, 79.f, 81.f, 82.5f, 89.f, 91.f, 92.5f, 96.5f, 98.5f, 100.f}); - - x.linspace(1); - + auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 3, 3}, + {4.f, 6.f, 7.5f, 14.f, 16.f, 17.5f, 21.5f, 23.5f, 25.f, + 29.f, 31.f, 32.5f, 39.f, 41.f, 42.5f, 46.5f, 48.5f, 50.f, + 54.f, 56.f, 57.5f, 64.f, 66.f, 67.5f, 71.5f, 73.5f, 75.f, + 79.f, 81.f, 82.5f, 89.f, 91.f, 92.5f, 96.5f, 98.5f, 100.f}); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 0}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 0}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_8) { - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); - auto exp = NDArrayFactory::create('c', {1, 1, 2, 2}, {3.f, 4.f, 6.f, 7.f}); - - x.linspace(1); + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto exp = NDArrayFactory::create('c', {1, 1, 2, 2}, + {3.f, 4.f, 6.f, 7.f}); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_9) { - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); - auto exp = NDArrayFactory::create('c', {1, 1, 3, 3}, {3.f, 4.f, 4.5f, 6.f, 7.f, 7.5f, 7.5f, 8.5f, 9.f}); - - x.linspace(1); - + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 3, 3}, {3.f, 4.f, 4.5f, 6.f, 7.f, 7.5f, 7.5f, 8.5f, 9.f}); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0, 0}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0, 0}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - //z->printShapeInfo("z shape:"); - //z->printBuffer("z buffer:"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printShapeInfo("z shape:"); + // z->printBuffer("z buffer:"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_10) { - - auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f, 2.91434479f, 5.43639755f, -2.10573769f, 4.08528662f, 5.86908436f, -4.46203756f, 2.21057916f, 5.35849190f, 0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, - 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, - -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, - -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, - 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, 9.55441570f, - 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, -2.85825086f, - -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, -1.42709637f, - 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, 10.11775303f, - -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, -1.17453325f, - 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, -0.24346280f, - 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); - auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, 0.86383581f, -1.91504073f}); - - sd::ops::avgpool2d op; - auto result = op.evaluate({&input}, {3,3, 3,3, 0,0, 1,1,1, 0,1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - // z->printIndexedBuffer("z"); - // exp.printIndexedBuffer("e"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create( + 'c', {4, 10, 10, 3}, + {9.37125111f, 2.20166993f, 2.91434479f, 5.43639755f, -2.10573769f, + 4.08528662f, 5.86908436f, -4.46203756f, 2.21057916f, 5.35849190f, + 0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, + 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, + 1.70707977f, 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, + -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, + 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, + 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, + -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, + 4.30761862f, -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, + -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, + -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, + 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, + 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, + -1.98828590f, -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, + 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, + 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, + -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, + 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, + 0.77703512f, 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, + 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, + 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, + 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, + -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, + 9.55441570f, 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, + 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, + -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, + 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, + -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, + -2.85825086f, -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, + -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, + -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, + -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, + 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, + -1.42709637f, 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, + 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, + 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, + -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, + 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, + 10.11775303f, -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, + -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, + -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, + 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, + -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, + -1.17453325f, 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, + 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, + 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, + 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, + 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, + -0.24346280f, 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, + -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, + 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, + 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, + -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, + 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, + -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, + -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, + 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, + -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, + 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, + 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, + 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, + -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, + 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, + 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, + 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, + 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, + 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, + 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, + 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, + 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, + -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, + 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, + -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, + -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, + 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, + -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, + -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, + 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, + -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, + -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, + 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, + 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, + -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, + 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, + -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, + 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, + 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, + -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, + -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, + 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, + 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, + 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, + 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, + 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, + -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, + 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, + -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, + -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, + 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, + -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, + 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, + 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, + -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, + -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, + 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, + 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, + 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, + 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, + -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, + -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, + 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, + 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, + 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, + 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, + -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, + -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, + 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, + -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, + -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, + 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, + 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, + -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, + 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, + 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, + -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, + 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, + -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, + -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, + 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, + 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, + -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, + 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, + 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, + -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, + 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, + -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, + 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, + 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, + 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, + -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, + 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, + 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, + -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, + 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, + 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, + -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, + 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, + 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, + 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, + 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, + -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, + -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, + 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, + -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, + -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, + 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, + 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, + 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, + 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, + -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, + 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, + 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, + -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, + -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, + 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, + 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, + -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, + 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, + -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, + 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, + 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, + -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, + -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, + -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, + 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, + -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, + 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, + -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, + -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, + 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, + 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, + 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, + 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, + 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, + -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, + 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, + 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, + -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, + 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, + 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, + -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, + 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, + -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, + 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, + 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, + 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, + -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, + 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, + -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, + -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, + 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, + 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, + 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, + 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, + -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, + -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, + 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, + 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, + -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, + 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, + 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, + 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, + 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, + -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, + 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, + 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, + 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, + -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, + 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, + -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, + -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, + 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, + -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, + 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, + 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, + -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, + -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, + 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, + -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, + 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, + 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, + 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, + 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, + -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, + -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, + -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, + 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, + -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, + -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); + auto exp = NDArrayFactory::create( + 'c', {4, 4, 4, 3}, + {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, + 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, + -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, + 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, + 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, + -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, + 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, + -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, + -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, + 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, + 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, + -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, + 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, + -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, + -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, + 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, + 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, + 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, + 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, + -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, + -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, + 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, + 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, + -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, + 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, + 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, + -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, + 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, + -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, + -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, + 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, + -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, + -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, + 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, + 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, + 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, + 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, + -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, + 0.86383581f, -1.91504073f}); + + sd::ops::avgpool2d op; + auto result = op.evaluate({&input}, {3, 3, 3, 3, 0, 0, 1, 1, 1, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + // z->printIndexedBuffer("z"); + // exp.printIndexedBuffer("e"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, avgpool2d_11) { - int inOutH = 5;// 35; - int inOutW = 5;// 35; - int inOutC = 10;// 192; + int inOutH = 5; // 35; + int inOutW = 5; // 35; + int inOutC = 10; // 192; - auto x = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + x.linspace(1.0); - sd::ops::avgpool2d op; - auto result = op.evaluate({&x}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {3, 3, 1, 1, 0, 0, 1, 1, 1, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH; - int padTop = totalPadHeight / 2; - int padBottom = totalPadHeight - totalPadHeight / 2; + int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH; + int padTop = totalPadHeight / 2; + int padBottom = totalPadHeight - totalPadHeight / 2; - int k = 3; + int k = 3; - auto m = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); - auto c = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + auto m = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + auto c = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); - for (int h = 0; h < inOutH; h++) { - for (int w = 0; w < inOutW; w++) { - int hFrom = h - padTop; - int wFrom = w - padBottom; + for (int h = 0; h < inOutH; h++) { + for (int w = 0; w < inOutW; w++) { + int hFrom = h - padTop; + int wFrom = w - padBottom; - int hTo = hFrom + k; - int wTo = wFrom + k; + int hTo = hFrom + k; + int wTo = wFrom + k; - hFrom = sd::math::nd4j_max(0, hFrom); - wFrom = sd::math::nd4j_max(0, wFrom); + hFrom = sd::math::nd4j_max(0, hFrom); + wFrom = sd::math::nd4j_max(0, wFrom); - hTo = sd::math::nd4j_min(inOutH, hTo); - wTo = sd::math::nd4j_min(inOutW, wTo); + hTo = sd::math::nd4j_min(inOutH, hTo); + wTo = sd::math::nd4j_min(inOutW, wTo); - int idxOut[4]; - int idxIn[4]; - for (int ch = 0; ch < inOutC; ch++) { - idxOut[1] = h; - idxOut[2] = w; - idxOut[3] = ch; - idxIn[3] = ch; + int idxOut[4]; + int idxIn[4]; + for (int ch = 0; ch < inOutC; ch++) { + idxOut[1] = h; + idxOut[2] = w; + idxOut[3] = ch; + idxIn[3] = ch; - for (int kh = hFrom; kh < hTo; kh++) { - for (int kw = wFrom; kw < wTo; kw++) { - idxIn[1] = kh; - idxIn[2] = kw; + for (int kh = hFrom; kh < hTo; kh++) { + for (int kw = wFrom; kw < wTo; kw++) { + idxIn[1] = kh; + idxIn[2] = kw; - auto inVal = x.e(0, kh, kw, ch); - m.p(0, h, w, ch, inVal + m.e(0, h, w, ch)); - c.p(0, h, w, ch, 1 + c.e(0, h, w, ch)); - } - } - } + auto inVal = x.e(0, kh, kw, ch); + m.p(0, h, w, ch, inVal + m.e(0, h, w, ch)); + c.p(0, h, w, ch, 1 + c.e(0, h, w, ch)); + } } + } } - m /= c; - - ASSERT_EQ(m, z); - + } + m /= c; + ASSERT_EQ(m, z); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, avgpool2d_12) { - - int bS=4, iH=10,iW=10, iC=3, kH=3,kW=3, sH=3,sW=3, pH=0,pW=0, dH=1,dW=1; - int oH=4, oW=4; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NHWC, 0-NDHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oH, oW, iC}, { 17.5, 18.5, 19.5, 25. , 26. , 27. , 34. , 35. , 36. , 41.5, 42.5, 43.5, 92.5, 93.5, 94.5, 100. , 101. , 102. , 109. , 110. , 111. , 116.5, 117.5, 118.5, - 182.5, 183.5, 184.5, 190. , 191. , 192. , 199. , 200. , 201. , 206.5, 207.5, 208.5, 257.5, 258.5, 259.5, 265. , 266. , 267. , 274. , 275. , 276. , 281.5, 282.5, 283.5, - 317.5, 318.5, 319.5, 325. , 326. , 327. , 334. , 335. , 336. , 341.5, 342.5, 343.5, 392.5, 393.5, 394.5, 400. , 401. , 402. , 409. , 410. , 411. , 416.5, 417.5, 418.5, - 482.5, 483.5, 484.5, 490. , 491. , 492. , 499. , 500. , 501. , 506.5, 507.5, 508.5, 557.5, 558.5, 559.5, 565. , 566. , 567. , 574. , 575. , 576. , 581.5, 582.5, 583.5, - 617.5, 618.5, 619.5, 625. , 626. , 627. , 634. , 635. , 636. , 641.5, 642.5, 643.5, 692.5, 693.5, 694.5, 700. , 701. , 702. , 709. , 710. , 711. , 716.5, 717.5, 718.5, - 782.5, 783.5, 784.5, 790. , 791. , 792. , 799. , 800. , 801. , 806.5, 807.5, 808.5, 857.5, 858.5, 859.5, 865. , 866. , 867. , 874. , 875. , 876. , 881.5, 882.5, 883.5, - 917.5, 918.5, 919.5, 925. , 926. , 927. , 934. , 935. , 936. , 941.5, 942.5, 943.5, 992.5, 993.5, 994.5,1000. , 1001. , 1002. ,1009. , 1010. , 1011. ,1016.5, 1017.5, 1018.5, - 1082.5, 1083.5, 1084.5,1090. , 1091. , 1092. ,1099. , 1100. , 1101. ,1106.5, 1107.5, 1108.5,1157.5, 1158.5, 1159.5,1165. , 1166. , 1167. ,1174. , 1175. , 1176. ,1181.5, 1182.5, 1183.5}); - input.linspace(1.); - - sd::ops::avgpool2d op; - auto results = op.evaluate({&input}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - //output->printIndexedBuffer("output"); - //expected.printIndexedBuffer("expected"); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + int bS = 4, iH = 10, iW = 10, iC = 3, kH = 3, kW = 3, sH = 3, sW = 3, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 4, oW = 4; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NHWC, 0-NDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto expected = NDArrayFactory::create( + 'c', {bS, oH, oW, iC}, + {17.5, 18.5, 19.5, 25., 26., 27., 34., 35., 36., + 41.5, 42.5, 43.5, 92.5, 93.5, 94.5, 100., 101., 102., + 109., 110., 111., 116.5, 117.5, 118.5, 182.5, 183.5, 184.5, + 190., 191., 192., 199., 200., 201., 206.5, 207.5, 208.5, + 257.5, 258.5, 259.5, 265., 266., 267., 274., 275., 276., + 281.5, 282.5, 283.5, 317.5, 318.5, 319.5, 325., 326., 327., + 334., 335., 336., 341.5, 342.5, 343.5, 392.5, 393.5, 394.5, + 400., 401., 402., 409., 410., 411., 416.5, 417.5, 418.5, + 482.5, 483.5, 484.5, 490., 491., 492., 499., 500., 501., + 506.5, 507.5, 508.5, 557.5, 558.5, 559.5, 565., 566., 567., + 574., 575., 576., 581.5, 582.5, 583.5, 617.5, 618.5, 619.5, + 625., 626., 627., 634., 635., 636., 641.5, 642.5, 643.5, + 692.5, 693.5, 694.5, 700., 701., 702., 709., 710., 711., + 716.5, 717.5, 718.5, 782.5, 783.5, 784.5, 790., 791., 792., + 799., 800., 801., 806.5, 807.5, 808.5, 857.5, 858.5, 859.5, + 865., 866., 867., 874., 875., 876., 881.5, 882.5, 883.5, + 917.5, 918.5, 919.5, 925., 926., 927., 934., 935., 936., + 941.5, 942.5, 943.5, 992.5, 993.5, 994.5, 1000., 1001., 1002., + 1009., 1010., 1011., 1016.5, 1017.5, 1018.5, 1082.5, 1083.5, 1084.5, + 1090., 1091., 1092., 1099., 1100., 1101., 1106.5, 1107.5, 1108.5, + 1157.5, 1158.5, 1159.5, 1165., 1166., 1167., 1174., 1175., 1176., + 1181.5, 1182.5, 1183.5}); + input.linspace(1.); + + sd::ops::avgpool2d op; + auto results = op.evaluate( + {&input}, {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 0, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + // output->printIndexedBuffer("output"); + // expected.printIndexedBuffer("expected"); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, avgpool2d_13) { - - const int bS = 2; // batch size - const int iD = 1; // input depth (number of picture channels, for example rgb=3) - const int iH = 28; // picture height in pixels - const int iW = 28; // picture width in pixels - const int kH = 5; // kernel height in pixels - const int kW = 5; // kernel width in pixels - const int sH = 1; // stride step in horizontal direction - const int sW = 1; // stride step in vertical direction - const int pH = 0; // padding height - const int pW = 0; // padding width - const int dH = 2; // dilation height - const int dW = 2; // dilation width - const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height - const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width - - auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - block->appendI({kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::avgpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(*result)); - - - delete variableSpace; - delete block; + const int bS = 2; // batch size + const int iD = + 1; // input depth (number of picture channels, for example rgb=3) + const int iH = 28; // picture height in pixels + const int iW = 28; // picture width in pixels + const int kH = 5; // kernel height in pixels + const int kW = 5; // kernel width in pixels + const int sH = 1; // stride step in horizontal direction + const int sW = 1; // stride step in vertical direction + const int pH = 0; // padding height + const int pW = 0; // padding width + const int dH = 2; // dilation height + const int dW = 2; // dilation width + const int oH = + (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = + (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width + + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dW, dH, 0, 0, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::avgpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, avgpool2d_14) { - const int bS = 2; - const int iD = 1; - const int iH = 28; - const int iW = 28; - const int kH = 5; - const int kW = 5; - const int sH = 1; - const int sW = 1; - const int pH = 0; - const int pW = 0; - const int dH = 1; - const int dW = 1; - const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height - const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width - - - auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - block->appendI({kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::avgpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // result->printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(*result)); - - delete variableSpace; - delete block; + const int bS = 2; + const int iD = 1; + const int iH = 28; + const int iW = 28; + const int kH = 5; + const int kW = 5; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = + (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = + (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width + + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dW, dH, 0, 0, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::avgpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result->printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, Avgpool2d_test15) { - const int bS = 2; - const int iD = 1; - const int iH = 28; - const int iW = 28; - const int kH = 5; - const int kW = 5; - const int sH = 1; - const int sW = 1; - const int pH = 0; - const int pW = 0; - const int dH = 1; - const int dW = 1; - const int oH = (int) sd::math::nd4j_ceil(iH * 1.f / sH); - const int oW = (int) sd::math::nd4j_ceil(iW * 1.f / sW); - - - auto x = NDArrayFactory::create('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - block->appendI({kH,kW, sH,sW, pH,pW, dW,dH, 1, 0, 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - sd::ops::avgpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // result->printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(*result)); - - delete variableSpace; - delete block; + const int bS = 2; + const int iD = 1; + const int iH = 28; + const int iW = 28; + const int kH = 5; + const int kW = 5; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = (int)sd::math::nd4j_ceil(iH * 1.f / sH); + const int oW = (int)sd::math::nd4j_ceil(iW * 1.f / sW); + + auto x = NDArrayFactory::create('c', {bS, iD, iH, iW}); + auto exp = NDArrayFactory::create('c', {bS, iD, oH, oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->appendI( + {kH, kW, sH, sW, pH, pW, dW, dH, 1, 0, + 0}); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad + // Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::avgpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result->printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(*result)); + + delete variableSpace; + delete block; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, avgpool2d_16) { + int bS = 2, iH = 4, iW = 4, iC = 2, kH = 2, kW = 2, sH = 2, sW = 2, pH = 0, + pW = 0, dH = 1, dW = 1; + int oH = 2, oW = 2; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NHWC, 0-NDHW - int bS=2, iH=4,iW=4, iC=2, kH=2,kW=2, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1; - int oH=2,oW=2; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NHWC, 0-NDHW + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray output('f', {bS, oH, oW, iC}, sd::DataType::FLOAT32); + NDArray expected('c', {bS, oH, oW, iC}, + {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, + 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}, + sd::DataType::FLOAT32); - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray output('f', {bS, oH, oW, iC}, sd::DataType::FLOAT32); - NDArray expected('c', {bS, oH, oW, iC}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}, sd::DataType::FLOAT32); + input.linspace(1.); - input.linspace(1.); + sd::ops::avgpool2d op; + auto status = op.execute( + {&input}, {&output}, {}, + {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, 0, dataFormat}, {}); - sd::ops::avgpool2d op; - auto status = op.execute({&input}, {&output}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}, {}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(Status::OK(), status); + // output.printBuffer(); + // expected.printIndexedBuffer("expected"); - // output.printBuffer(); - //expected.printIndexedBuffer("expected"); - - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, biasadd_1) { - auto x = NDArrayFactory::create('c', {2, 3, 3, 2}); - auto bias = NDArrayFactory::create('c', {2}, {1, 2}); - auto exp = NDArrayFactory::create('c', {2, 3, 3, 2}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f}); - - sd::ops::biasadd op; - auto result = op.evaluate({&x, &bias}, {}, {}, {}); + auto x = NDArrayFactory::create('c', {2, 3, 3, 2}); + auto bias = NDArrayFactory::create('c', {2}, {1, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3, 2}, + {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, + 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, + 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::biasadd op; + auto result = op.evaluate({&x, &bias}, {}, {}, {}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, biasadd_2) { - auto x = NDArrayFactory::create('c', {2, 2, 3, 3}); - auto bias = NDArrayFactory::create('c', {2}, {1, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2}); - - sd::ops::biasadd op; - auto result = op.evaluate({&x, &bias}, {}, {}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {2, 2, 3, 3}); + auto bias = NDArrayFactory::create('c', {2}, {1, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 3, 3}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2}); - auto z = result.at(0); + sd::ops::biasadd op; + auto result = op.evaluate({&x, &bias}, {}, {}, {true}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, biasadd_3) { - auto x = NDArrayFactory::create('c', {2, 3}); - auto row = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); - - sd::ops::biasadd op; - auto result = op.evaluate({&x, &row}, {}, {}, {true}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {2, 3}); + auto row = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); - auto z = result.at(0); + sd::ops::biasadd op; + auto result = op.evaluate({&x, &row}, {}, {}, {true}); - ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, biasadd_bp_1) { + NDArray x('c', {2, 2, 2, 3}, {1., 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); + NDArray gradO('c', {2, 2, 2, 3}, sd::DataType::FLOAT32); + NDArray bias('c', {3}, {-1., -2, -3}, sd::DataType::FLOAT32); - NDArray x('c', {2,2,2,3}, {1.,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); - NDArray gradO('c', {2,2,2,3}, sd::DataType::FLOAT32); - NDArray bias('c', {3}, {-1., -2, -3}, sd::DataType::FLOAT32); + NDArray expGradB('c', {3}, {9.2, 10., 10.8}, sd::DataType::FLOAT32); - NDArray expGradB('c', {3}, {9.2, 10. , 10.8}, sd::DataType::FLOAT32); + gradO.linspace(0.1, 0.1); - gradO.linspace(0.1, 0.1); + sd::ops::biasadd_bp op; + auto result = op.evaluate({&x, &bias, &gradO}, {}, {}, {false}); // NHWC - sd::ops::biasadd_bp op; - auto result = op.evaluate({&x, &bias, &gradO}, {}, {}, {false}); // NHWC + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto gradI = result.at(0); - auto gradB = result.at(1); - - ASSERT_TRUE(gradI.isSameShape(gradO)); - ASSERT_TRUE(gradI.equalsTo(gradO)); - - ASSERT_TRUE(gradB.isSameShape(expGradB)); - ASSERT_TRUE(gradB.equalsTo(expGradB)); + auto gradI = result.at(0); + auto gradB = result.at(1); + ASSERT_TRUE(gradI.isSameShape(gradO)); + ASSERT_TRUE(gradI.equalsTo(gradO)); + ASSERT_TRUE(gradB.isSameShape(expGradB)); + ASSERT_TRUE(gradB.equalsTo(expGradB)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, biasadd_bp_2) { + NDArray x('c', {2, 3, 2, 2}, {1., 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); + NDArray gradO('c', {2, 3, 2, 2}, sd::DataType::FLOAT32); + NDArray bias('c', {3}, {-1., -2, -3}, sd::DataType::FLOAT32); - NDArray x('c', {2,3,2,2}, {1.,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); - NDArray gradO('c', {2,3,2,2}, sd::DataType::FLOAT32); - NDArray bias('c', {3}, {-1., -2, -3}, sd::DataType::FLOAT32); - - NDArray expGradB('c', {3}, {6.8, 10., 13.2}, sd::DataType::FLOAT32); - - gradO.linspace(0.1, 0.1); + NDArray expGradB('c', {3}, {6.8, 10., 13.2}, sd::DataType::FLOAT32); - sd::ops::biasadd_bp op; - auto result = op.evaluate({&x, &bias, &gradO}, {}, {}, {true}); // NCHW + gradO.linspace(0.1, 0.1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::biasadd_bp op; + auto result = op.evaluate({&x, &bias, &gradO}, {}, {}, {true}); // NCHW - auto gradI = result.at(0); - auto gradB = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(gradI.isSameShape(gradO)); - ASSERT_TRUE(gradI.equalsTo(gradO)); - - ASSERT_TRUE(gradB.isSameShape(expGradB)); - ASSERT_TRUE(gradB.equalsTo(expGradB)); + auto gradI = result.at(0); + auto gradB = result.at(1); + ASSERT_TRUE(gradI.isSameShape(gradO)); + ASSERT_TRUE(gradI.equalsTo(gradO)); + ASSERT_TRUE(gradB.isSameShape(expGradB)); + ASSERT_TRUE(gradB.equalsTo(expGradB)); } TEST_F(DeclarableOpsTests4, biasadd_4) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}); - auto y = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto z = NDArrayFactory::create('c', {2, 3}); - auto exp = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f}); + auto x = NDArrayFactory::create('c', {2, 3}); + auto y = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto z = NDArrayFactory::create('c', {2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3}, + {1.f, 2.f, 3.f, 1.f, 2.f, 3.f}); - sd::ops::biasadd op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {true}); - ASSERT_EQ(Status::OK(), status); + sd::ops::biasadd op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {true}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(DeclarableOpsTests4, Test_Fill_1) { - auto x = NDArrayFactory::create('c', {1, 3}, {3, 2, 4}); - auto v = NDArrayFactory::create(2.); - auto exp = NDArrayFactory::create('c', {3, 2, 4}); - exp.assign(2.0f); - - sd::ops::fill op; - auto result = op.evaluate({&x, &v}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {1, 3}, {3, 2, 4}); + auto v = NDArrayFactory::create(2.); + auto exp = NDArrayFactory::create('c', {3, 2, 4}); + exp.assign(2.0f); - auto z = result.at(0); + sd::ops::fill op; + auto result = op.evaluate({&x, &v}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_FirasSparce_1) { - auto x = NDArrayFactory::create('c', {1, 81}); - auto exp = NDArrayFactory::create('c', {1, 2}, {0, 1}); - - x.p(51, 1); - x.p(52, 0); - x.p(60, 1); - x.p(61, 0); - sd::ops::firas_sparse op; - auto result = op.evaluate({&x}, {0, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {1, 81}); + auto exp = NDArrayFactory::create('c', {1, 2}, {0, 1}); - auto z = result.at(0); -// z->printIndexedBuffer("FIRAS"); -// z->printShapeInfo("OUTSHAPE"); -// ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + x.p(51, 1); + x.p(52, 0); + x.p(60, 1); + x.p(61, 0); + sd::ops::firas_sparse op; + auto result = op.evaluate({&x}, {0, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("FIRAS"); + // z->printShapeInfo("OUTSHAPE"); + // ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_FlattenTests_1) { - auto x = NDArrayFactory::create('c', {3, 3, 3, 3}); - auto exp = NDArrayFactory::create('c', {81}); + auto x = NDArrayFactory::create('c', {3, 3, 3, 3}); + auto exp = NDArrayFactory::create('c', {81}); - x.linspace(1); - exp.linspace(1); - sd::ops::flatten op; - auto result = op.evaluate({&x}, {}, {'c'}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("Flatten1"); -// z->printShapeInfo("Flatten1 shape"); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1); + exp.linspace(1); + sd::ops::flatten op; + auto result = op.evaluate({&x}, {}, {'c'}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Flatten1"); + // z->printShapeInfo("Flatten1 shape"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_FlattenTests_2) { - auto x = NDArrayFactory::create('c', {3, 3, 3, 3}); - auto y = NDArrayFactory::create('c', {3, 3}); - auto exp = NDArrayFactory::create('c', {90}); - - x.linspace(1); - y.linspace(82); - exp.linspace(1); - sd::ops::flatten op; - auto result = op.evaluate({&x, &y}, {}, {'c'}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {3, 3, 3, 3}); + auto y = NDArrayFactory::create('c', {3, 3}); + auto exp = NDArrayFactory::create('c', {90}); - auto z = result.at(0); -// z->printIndexedBuffer("Flatten2"); -// z->printShapeInfo("Flatten2 shape"); - - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1); + y.linspace(82); + exp.linspace(1); + sd::ops::flatten op; + auto result = op.evaluate({&x, &y}, {}, {'c'}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Flatten2"); + // z->printShapeInfo("Flatten2 shape"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_FlattenTests_3) { - NDArray x('c', {2,2}, {1, 2, 3, 4}, sd::DataType::INT32); - NDArray y('f', {2,2}, sd::DataType::INT32); - NDArray exp('c', {8}, {1, 2, 3, 4, 1, 2, 3, 4}, sd::DataType::INT32); - - y.assign(x); + NDArray x('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray y('f', {2, 2}, sd::DataType::INT32); + NDArray exp('c', {8}, {1, 2, 3, 4, 1, 2, 3, 4}, sd::DataType::INT32); - sd::ops::flatten op; - auto result = op.evaluate({&x, &y}, {}, {'c'}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + y.assign(x); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::flatten op; + auto result = op.evaluate({&x, &y}, {}, {'c'}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_FlattenTests_4) { - NDArray x('c', {2,2}, {1, 2, 3, 4}, sd::DataType::INT32); - NDArray y('f', {2,2}, sd::DataType::INT32); - NDArray exp('c', {8}, {1, 3, 2, 4, 1, 3, 2, 4}, sd::DataType::INT32); - - y.assign(x); - - sd::ops::flatten op; - auto result = op.evaluate({&x, &y}, {}, {'f'}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + NDArray x('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray y('f', {2, 2}, sd::DataType::INT32); + NDArray exp('c', {8}, {1, 3, 2, 4, 1, 3, 2, 4}, sd::DataType::INT32); - auto z = result.at(0); + y.assign(x); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::flatten op; + auto result = op.evaluate({&x, &y}, {}, {'f'}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_FloorTests_1) { - auto x = NDArrayFactory::create('c', {3, 3}, {1.5, 2.3, 3.4, 4.3, 5.9, 6.1, 7.2, 8.9, 9.7}); - auto exp = NDArrayFactory::create('c', {3,3}); + auto x = NDArrayFactory::create( + 'c', {3, 3}, {1.5, 2.3, 3.4, 4.3, 5.9, 6.1, 7.2, 8.9, 9.7}); + auto exp = NDArrayFactory::create('c', {3, 3}); - exp.linspace(1); - sd::ops::Floor op; - auto result = op.evaluate({&x}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("Flatten1"); -// z->printShapeInfo("Flatten1 shape"); - ASSERT_TRUE(exp.equalsTo(z)); + exp.linspace(1); + sd::ops::Floor op; + auto result = op.evaluate({&x}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Flatten1"); + // z->printShapeInfo("Flatten1 shape"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_Split_1) { - auto x = NDArrayFactory::create('c', {5, 30}); - auto sizes = NDArrayFactory::create('c', {1, 3}, {4, 15, 11}); - - std::vector list0({0,0, 0,4}); - std::vector list1({0,0, 4,19}); - std::vector list2({0,0, 19,30}); + auto x = NDArrayFactory::create('c', {5, 30}); + auto sizes = NDArrayFactory::create('c', {1, 3}, {4, 15, 11}); - auto sub0 = x(list0, true); - auto sub1 = x(list1, true); - auto sub2 = x(list2, true); + std::vector list0({0, 0, 0, 4}); + std::vector list1({0, 0, 4, 19}); + std::vector list2({0, 0, 19, 30}); - sub0.assign(0.0); - sub1.assign(1.0); - sub2.assign(2.0); + auto sub0 = x(list0, true); + auto sub1 = x(list1, true); + auto sub2 = x(list2, true); + sub0.assign(0.0); + sub1.assign(1.0); + sub2.assign(2.0); - sd::ops::split_v op; - auto result = op.evaluate({&x, &sizes}, {}, {1}); + sd::ops::split_v op; + auto result = op.evaluate({&x, &sizes}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(3, result.size()); - auto z0 = result.at(0); - auto z1 = result.at(1); - auto z2 = result.at(2); - - ASSERT_TRUE(sub0.isSameShape(z0)); - ASSERT_TRUE(sub1.isSameShape(z1)); - ASSERT_TRUE(sub2.isSameShape(z2)); - - ASSERT_TRUE(sub0.equalsTo(z0)); - ASSERT_TRUE(sub1.equalsTo(z1)); - ASSERT_TRUE(sub2.equalsTo(z2)); + auto z0 = result.at(0); + auto z1 = result.at(1); + auto z2 = result.at(2); + ASSERT_TRUE(sub0.isSameShape(z0)); + ASSERT_TRUE(sub1.isSameShape(z1)); + ASSERT_TRUE(sub2.isSameShape(z2)); + ASSERT_TRUE(sub0.equalsTo(z0)); + ASSERT_TRUE(sub1.equalsTo(z1)); + ASSERT_TRUE(sub2.equalsTo(z2)); } // special test for TF mode, when axis goes first TEST_F(DeclarableOpsTests4, Test_Split_2) { - auto x = NDArrayFactory::create('c', {5, 12}); - auto axis = NDArrayFactory::create('c', {1, 1}, {1.f}); - - std::vector list0 = {0,0, 0,3}; - std::vector list1 = {0,0, 3,6}; - std::vector list2 = {0,0, 6,9}; - std::vector list3 = {0,0, 9,12}; - - auto sub0 = x(list0, true); - auto sub1 = x(list1, true); - auto sub2 = x(list2, true); - auto sub3 = x(list3, true); - - sub0.assign(0.0f); - sub1.assign(1.0f); - sub2.assign(2.0f); - sub3.assign(3.0f); + auto x = NDArrayFactory::create('c', {5, 12}); + auto axis = NDArrayFactory::create('c', {1, 1}, {1.f}); + std::vector list0 = {0, 0, 0, 3}; + std::vector list1 = {0, 0, 3, 6}; + std::vector list2 = {0, 0, 6, 9}; + std::vector list3 = {0, 0, 9, 12}; - sd::ops::split op; - auto result = op.evaluate({&axis, &x}, {}, {4}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto sub0 = x(list0, true); + auto sub1 = x(list1, true); + auto sub2 = x(list2, true); + auto sub3 = x(list3, true); - auto z0 = result.at(0); - auto z1 = result.at(1); - auto z2 = result.at(2); - auto z3 = result.at(3); + sub0.assign(0.0f); + sub1.assign(1.0f); + sub2.assign(2.0f); + sub3.assign(3.0f); - ASSERT_TRUE(sub0.isSameShape(z0)); - ASSERT_TRUE(sub1.isSameShape(z1)); - ASSERT_TRUE(sub2.isSameShape(z2)); - ASSERT_TRUE(sub3.isSameShape(z3)); + sd::ops::split op; + auto result = op.evaluate({&axis, &x}, {}, {4}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(sub0.equalsTo(z0)); - ASSERT_TRUE(sub1.equalsTo(z1)); - ASSERT_TRUE(sub2.equalsTo(z2)); - ASSERT_TRUE(sub3.equalsTo(z3)); + auto z0 = result.at(0); + auto z1 = result.at(1); + auto z2 = result.at(2); + auto z3 = result.at(3); + ASSERT_TRUE(sub0.isSameShape(z0)); + ASSERT_TRUE(sub1.isSameShape(z1)); + ASSERT_TRUE(sub2.isSameShape(z2)); + ASSERT_TRUE(sub3.isSameShape(z3)); + ASSERT_TRUE(sub0.equalsTo(z0)); + ASSERT_TRUE(sub1.equalsTo(z1)); + ASSERT_TRUE(sub2.equalsTo(z2)); + ASSERT_TRUE(sub3.equalsTo(z3)); } // special test for TF mode, when axis goes first TEST_F(DeclarableOpsTests4, Test_Split_3) { - auto x = NDArrayFactory::create('c', {6, 12}); - auto axis = NDArrayFactory::create('c', {1, 1}, {0.f}); + auto x = NDArrayFactory::create('c', {6, 12}); + auto axis = NDArrayFactory::create('c', {1, 1}, {0.f}); - std::vector list0 = {0,2, 0,0}; - std::vector list1 = {2,4, 0,0}; - std::vector list2 = {4,6, 0,0}; + std::vector list0 = {0, 2, 0, 0}; + std::vector list1 = {2, 4, 0, 0}; + std::vector list2 = {4, 6, 0, 0}; - auto sub0 = x(list0, true); - auto sub1 = x(list1, true); - auto sub2 = x(list2, true); + auto sub0 = x(list0, true); + auto sub1 = x(list1, true); + auto sub2 = x(list2, true); - sub0.assign(0.0f); - sub1.assign(1.0f); - sub2.assign(2.0f); + sub0.assign(0.0f); + sub1.assign(1.0f); + sub2.assign(2.0f); - sd::ops::split op; - auto result = op.evaluate({&axis, &x}, {}, {3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::split op; + auto result = op.evaluate({&axis, &x}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z0 = result.at(0); - auto z1 = result.at(1); - auto z2 = result.at(2); + auto z0 = result.at(0); + auto z1 = result.at(1); + auto z2 = result.at(2); - ASSERT_TRUE(sub0.isSameShape(z0)); - ASSERT_TRUE(sub1.isSameShape(z1)); - ASSERT_TRUE(sub2.isSameShape(z2)); + ASSERT_TRUE(sub0.isSameShape(z0)); + ASSERT_TRUE(sub1.isSameShape(z1)); + ASSERT_TRUE(sub2.isSameShape(z2)); - ASSERT_TRUE(sub0.equalsTo(z0)); - ASSERT_TRUE(sub1.equalsTo(z1)); - ASSERT_TRUE(sub2.equalsTo(z2)); + ASSERT_TRUE(sub0.equalsTo(z0)); + ASSERT_TRUE(sub1.equalsTo(z1)); + ASSERT_TRUE(sub2.equalsTo(z2)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, split_test4) { + auto input = NDArrayFactory::create( + 'c', {10}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); + auto axis = NDArrayFactory::create(-1); + auto exp1 = + NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto exp2 = + NDArrayFactory::create('c', {5}, {6.f, 7.f, 8.f, 9.f, 10.f}); - auto input = NDArrayFactory::create('c', {10},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f}); - auto axis = NDArrayFactory::create(-1); - auto exp1 = NDArrayFactory::create('c', {5}, {1.f,2.f,3.f,4.f,5.f}); - auto exp2 = NDArrayFactory::create('c', {5}, {6.f,7.f,8.f,9.f,10.f}); + sd::ops::split op; + auto results = op.evaluate({&input, &axis}, {}, {2}, {}); - sd::ops::split op; - auto results = op.evaluate({&input, &axis}, {}, {2}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto out1 = results.at(0); + auto out2 = results.at(1); - auto out1 = results.at(0); - auto out2 = results.at(1); - - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.equalsTo(out2)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.equalsTo(out2)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, split_test5) { + auto input = NDArrayFactory::create( + 'c', {3, 8}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f}); + auto exp1 = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 9.f, 10.f, 11.f, 12.f, 17.f, 18.f, 19.f, 20.f}); + auto exp2 = NDArrayFactory::create( + 'c', {3, 4}, + {5.f, 6.f, 7.f, 8.f, 13.f, 14.f, 15.f, 16.f, 21.f, 22.f, 23.f, 24.f}); - auto input = NDArrayFactory::create('c', {3,8},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f}); - auto exp1 = NDArrayFactory::create('c', {3,4}, {1.f,2.f,3.f,4.f, 9.f,10.f,11.f,12.f, 17.f,18.f,19.f,20.f}); - auto exp2 = NDArrayFactory::create('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f}); - - sd::ops::split op; - auto results = op.evaluate({&input}, {}, {2,-1},{}); + sd::ops::split op; + auto results = op.evaluate({&input}, {}, {2, -1}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - auto out1 = results.at(0); - auto out2 = results.at(1); + auto out1 = results.at(0); + auto out2 = results.at(1); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.equalsTo(out2)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.equalsTo(out2)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, split_test6) { + NDArray input('c', {0, 4}, sd::DataType::FLOAT32); + std::vector expShape = {0, 1}; - NDArray input('c', {0,4}, sd::DataType::FLOAT32); - std::vector expShape = {0,1}; + const int numSplits = 4; + const int axis = 1; - const int numSplits = 4; - const int axis = 1; + sd::ops::split op; + auto results = op.evaluate({&input}, {}, {numSplits, axis}, {}); - sd::ops::split op; - auto results = op.evaluate({&input}, {}, {numSplits, axis}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - for (int i = 0; i < numSplits; ++i) - ASSERT_TRUE(results.at(i).isSameShape(expShape)); + for (int i = 0; i < numSplits; ++i) + ASSERT_TRUE(results.at(i).isSameShape(expShape)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, split_test7) { + NDArray input('c', {0, 4}, sd::DataType::FLOAT32); + std::vector expShape = {0, 4}; - NDArray input('c', {0,4}, sd::DataType::FLOAT32); - std::vector expShape = {0,4}; - - const int numSplits = 4; - const int axis = 0; + const int numSplits = 4; + const int axis = 0; - sd::ops::split op; - auto results = op.evaluate({&input}, {}, {numSplits, axis}, {}); + sd::ops::split op; + auto results = op.evaluate({&input}, {}, {numSplits, axis}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); - for (int i = 0; i < numSplits; ++i) - ASSERT_TRUE(results.at(i).isSameShape(expShape)); + for (int i = 0; i < numSplits; ++i) + ASSERT_TRUE(results.at(i).isSameShape(expShape)); } - TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) { - auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); - - sd::ops::squeeze op; - auto result = op.evaluate({&x}, {}, {1, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_Squeeze_args_2) { - auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {2}, {1.f, 3.f}); - auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); - - sd::ops::squeeze op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2}, {1.f, 3.f}); + auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::squeeze op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests4, Test_Squeeze_args_3) { - auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); - - sd::ops::squeeze op; - auto result = op.evaluate({&x}, {}, {-2, -3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {-2, -3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_1) { - auto x = NDArrayFactory::create('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {1, 1, 1, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto x = NDArrayFactory::create( + 'c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 1, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - sd::ops::space_to_depth op; - auto result = op.evaluate({&x}, {}, {2, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::space_to_depth op; + auto result = op.evaluate({&x}, {}, {2, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_2) { - auto x = NDArrayFactory::create('c', {1, 3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {1, 12, 1, 1}, {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12}); - - sd::ops::space_to_depth op; - auto result = op.evaluate({&x}, {}, {2, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create( + 'c', {1, 3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create( + 'c', {1, 12, 1, 1}, {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::space_to_depth op; + auto result = op.evaluate({&x}, {}, {2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests4, Test_DepthToSpace_1) { - auto x = NDArrayFactory::create('c', {1, 1, 1, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - - sd::ops::depth_to_space op; - auto result = op.evaluate({&x}, {}, {2, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create( + 'c', {1, 1, 1, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create( + 'c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::depth_to_space op; + auto result = op.evaluate({&x}, {}, {2, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests4, Test_DepthToSpace_2) { - auto x = NDArrayFactory::create('c', {1, 12, 1, 1}, {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12}); - auto exp = NDArrayFactory::create('c', {1, 3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto x = NDArrayFactory::create( + 'c', {1, 12, 1, 1}, {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12}); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - sd::ops::depth_to_space op; - auto result = op.evaluate({&x}, {}, {2, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::depth_to_space op; + auto result = op.evaluate({&x}, {}, {2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_DepthToSpace_3) { - auto x = NDArrayFactory::create('c', {4, 4, 16, 16}); - auto exp = NDArrayFactory::create('c', {4, 16, 64, 1}); - - sd::ops::depth_to_space op; - auto result = op.evaluate({&x}, {}, {4, 1}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {4, 4, 16, 16}); + auto exp = NDArrayFactory::create('c', {4, 16, 64, 1}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); + sd::ops::depth_to_space op; + auto result = op.evaluate({&x}, {}, {4, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests4, Test_Cross_1) { - auto a = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto b = NDArrayFactory::create('c', {3}, {6, 7, 8}); - auto exp = NDArrayFactory::create('c', {3}, {-5, 10, -5}); + auto a = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto b = NDArrayFactory::create('c', {3}, {6, 7, 8}); + auto exp = NDArrayFactory::create('c', {3}, {-5, 10, -5}); - sd::ops::cross op; - auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cross op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests4, Test_Cross_2) { - auto a = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); - auto b = NDArrayFactory::create('c', {2, 3}, {6, 7, 8, 6, 7, 8}); - auto exp = NDArrayFactory::create('c', {2, 3}, {-5, 10, -5, -5, 10, -5}); - - sd::ops::cross op; - auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto a = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); + auto b = NDArrayFactory::create('c', {2, 3}, {6, 7, 8, 6, 7, 8}); + auto exp = + NDArrayFactory::create('c', {2, 3}, {-5, 10, -5, -5, 10, -5}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cross op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests4, Test_Cross_3) { - auto a = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto b = NDArrayFactory::create('c', {3, 3}, {2, 3, 4, 7, 6, 5, 6, 3, 2}); - auto exp = NDArrayFactory::create('c', {3, 3}, { -1, 2, -1, -11, 22, -11, -11, 40, -27}); - - sd::ops::cross op; - auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto a = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto b = + NDArrayFactory::create('c', {3, 3}, {2, 3, 4, 7, 6, 5, 6, 3, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {-1, 2, -1, -11, 22, -11, -11, 40, -27}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cross op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_Add_119) { - auto a = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto b = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 4}, {2, 4, 6, 8}); + auto a = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto b = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 4}, {2, 4, 6, 8}); - sd::ops::add op; - auto result = op.evaluate({&a, &b}); + sd::ops::add op; + auto result = op.evaluate({&a, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_EQ(2, z.rankOf()); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_EQ(2, z.rankOf()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_TileToShape_1) { - auto x = NDArrayFactory::create('c', {2, 1, 3}); - auto exp = NDArrayFactory::create('c', {2, 4, 3}, {1.f, 2.f, 3.f,1.f, 2.f, 3.f,1.f, 2.f, 3.f,1.f, 2.f, 3.f, - 4.f, 5.f, 6.f,4.f, 5.f, 6.f,4.f, 5.f, 6.f,4.f, 5.f, 6.f}); - x.linspace(1.f); - - sd::ops::tile_to_shape op; - auto result = op.evaluate({&x},{}, {2, 4, 3}); + auto x = NDArrayFactory::create('c', {2, 1, 3}); + auto exp = NDArrayFactory::create( + 'c', {2, 4, 3}, + {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f}); + x.linspace(1.f); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::tile_to_shape op; + auto result = op.evaluate({&x}, {}, {2, 4, 3}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_1) { - auto x = NDArrayFactory::create('c', {3, 4, 5}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1,3,4,5}); - exp.linspace(1); - - sd::ops::strided_slice op; - auto result = op.evaluate({&x}, {}, {0,0,0,1,0, -999,0,0,0, -999,3,4,5, -999,1,1,1}); - - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {3, 4, 5}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 3, 4, 5}); + exp.linspace(1); - auto z = result.at(0); + sd::ops::strided_slice op; + auto result = op.evaluate( + {&x}, {}, {0, 0, 0, 1, 0, -999, 0, 0, 0, -999, 3, 4, 5, -999, 1, 1, 1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) { - auto x = NDArrayFactory::create('c', {3, 4, 5}); - auto begin = NDArrayFactory::create('c', {4}, {-999,0,0,0}); - auto end = NDArrayFactory::create('c', {4}, {-999,3,4,5}); - auto stride = NDArrayFactory::create('c', {4}, {-999,1,1,1}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1,3,4,5}); - exp.linspace(1); - - sd::ops::strided_slice op; - auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {0,0,0,1,0}); - - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto begin = NDArrayFactory::create('c', {4}, {-999, 0, 0, 0}); + auto end = NDArrayFactory::create('c', {4}, {-999, 3, 4, 5}); + auto stride = NDArrayFactory::create('c', {4}, {-999, 1, 1, 1}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 3, 4, 5}); + exp.linspace(1); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {0, 0, 0, 1, 0}); - auto z = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) { - int axis = 0; - auto x = NDArrayFactory::create('c', {1}, {10}); - auto begin = NDArrayFactory::create('c', {1}, {axis}); - auto end = NDArrayFactory::create('c', {1}, {axis}); - auto stride = NDArrayFactory::create('c', {1}, {1}); - //x.linspace(1); - //auto exp = NDArrayFactory::create('c', {1,3,4,5}); - //exp.linspace(1); + int axis = 0; + auto x = NDArrayFactory::create('c', {1}, {10}); + auto begin = NDArrayFactory::create('c', {1}, {axis}); + auto end = NDArrayFactory::create('c', {1}, {axis}); + auto stride = NDArrayFactory::create('c', {1}, {1}); + // x.linspace(1); + // auto exp = NDArrayFactory::create('c', {1,3,4,5}); + // exp.linspace(1); - sd::ops::strided_slice op; - auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {1,0,0,0,0}); - - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(z.isEmpty()); + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {1, 0, 0, 0, 0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(z.isEmpty()); } TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_4) { - auto x = NDArrayFactory::create('c', {1,3}, {1, 2, 3}); - auto begin = NDArrayFactory::create('c', {2}, {0, 0}); - auto end = NDArrayFactory::create('c', {2}, {0,1}); - auto stride = NDArrayFactory::create('c', {2}, {1,1}); -// x.linspace(1); - auto exp = NDArrayFactory::create('c', {1}, {1}); - //exp.linspace(1); - - sd::ops::strided_slice op; - auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {1,0,1,0,2}); + auto x = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto begin = NDArrayFactory::create('c', {2}, {0, 0}); + auto end = NDArrayFactory::create('c', {2}, {0, 1}); + auto stride = NDArrayFactory::create('c', {2}, {1, 1}); + // x.linspace(1); + auto exp = NDArrayFactory::create('c', {1}, {1}); + // exp.linspace(1); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {1, 0, 1, 0, 2}); - auto z = result.at(0); - ASSERT_TRUE(z.lengthOf() == 1); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(z.lengthOf() == 1); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, parallel_stack_test1) { + auto x1 = NDArrayFactory::create('c', {2, 2, 2}); + auto x2 = NDArrayFactory::create('c', {2, 2, 2}); + auto x3 = NDArrayFactory::create('c', {2, 2, 2}); + x1.linspace(1); + x2.linspace(9); + x3.linspace(17); - auto x1 = NDArrayFactory::create('c', {2,2,2}); - auto x2 = NDArrayFactory::create('c', {2,2,2}); - auto x3 = NDArrayFactory::create('c', {2,2,2}); - x1.linspace(1); - x2.linspace(9); - x3.linspace(17); - - auto expected = NDArrayFactory::create('c', {3,2,2,2}); - expected.linspace(1); - - sd::ops::parallel_stack op; - auto results = op.evaluate({&x1, &x2, &x3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto expected = NDArrayFactory::create('c', {3, 2, 2, 2}); + expected.linspace(1); + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, parallel_stack_test2) { + auto x1 = NDArrayFactory::create('c', {1, 2}, {1, 2}); + auto x2 = NDArrayFactory::create('c', {1, 2}, {3, 4}); + auto x3 = NDArrayFactory::create('c', {1, 2}, {5, 6}); - auto x1 = NDArrayFactory::create('c', {1,2}, {1,2}); - auto x2 = NDArrayFactory::create('c', {1,2}, {3,4}); - auto x3 = NDArrayFactory::create('c', {1,2}, {5,6}); - - auto expected = NDArrayFactory::create('c', {3,1,2}, {1,2,3,4,5,6}); - - sd::ops::parallel_stack op; - auto results = op.evaluate({&x1, &x2, &x3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + auto expected = + NDArrayFactory::create('c', {3, 1, 2}, {1, 2, 3, 4, 5, 6}); + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, parallel_stack_test3) { + auto x1 = NDArrayFactory::create('c', {2, 1}, {1, 2}); + auto x2 = NDArrayFactory::create('c', {2, 1}, {3, 4}); + auto x3 = NDArrayFactory::create('c', {2, 1}, {5, 6}); - auto x1 = NDArrayFactory::create('c', {2,1}, {1,2}); - auto x2 = NDArrayFactory::create('c', {2,1}, {3,4}); - auto x3 = NDArrayFactory::create('c', {2,1}, {5,6}); - - auto expected = NDArrayFactory::create('c', {3,2,1}, {1,2,3,4,5,6}); - - sd::ops::parallel_stack op; - auto results = op.evaluate({&x1, &x2, &x3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto expected = + NDArrayFactory::create('c', {3, 2, 1}, {1, 2, 3, 4, 5, 6}); + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } -\ + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, parallel_stack_test4) { + auto x1 = NDArrayFactory::create('c', {2}, {1, 2}); + auto x2 = NDArrayFactory::create('c', {2}, {3, 4}); + auto x3 = NDArrayFactory::create('c', {2}, {5, 6}); - auto x1 = NDArrayFactory::create('c', {2}, {1,2}); - auto x2 = NDArrayFactory::create('c', {2}, {3,4}); - auto x3 = NDArrayFactory::create('c', {2}, {5,6}); - - auto expected = NDArrayFactory::create('c', {3,2}, {1,2,3,4,5,6}); - - sd::ops::parallel_stack op; - auto results = op.evaluate({&x1, &x2, &x3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto expected = + NDArrayFactory::create('c', {3, 2}, {1, 2, 3, 4, 5, 6}); + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, parallel_stack_test5) { + auto x1 = NDArrayFactory::create('c', {1}, {1}); + auto x2 = NDArrayFactory::create('c', {1}, {3}); + auto x3 = NDArrayFactory::create('c', {1}, {5}); - auto x1 = NDArrayFactory::create('c', {1}, {1}); - auto x2 = NDArrayFactory::create('c', {1}, {3}); - auto x3 = NDArrayFactory::create('c', {1}, {5}); - - auto expected = NDArrayFactory::create('c', {3,1}, {1,3,5}); - - sd::ops::parallel_stack op; - auto results = op.evaluate({&x1, &x2, &x3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto expected = NDArrayFactory::create('c', {3, 1}, {1, 3, 5}); + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, parallel_stack_test6) { + auto x1 = NDArrayFactory::create(1.); + auto x2 = NDArrayFactory::create(3.); + auto x3 = NDArrayFactory::create(5.); - auto x1 = NDArrayFactory::create(1.); - auto x2 = NDArrayFactory::create(3.); - auto x3 = NDArrayFactory::create(5.); - - auto expected = NDArrayFactory::create('c', {3}, {1,3,5}); - - sd::ops::parallel_stack op; - auto results = op.evaluate({&x1, &x2, &x3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto expected = NDArrayFactory::create('c', {3}, {1, 3, 5}); + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, parallel_stack_test7) { + auto x1 = NDArrayFactory::create(1.); + auto expected = NDArrayFactory::create('c', {1}, {1.}); - auto x1 = NDArrayFactory::create(1.); - auto expected = NDArrayFactory::create('c', {1}, {1.}); - - sd::ops::parallel_stack op; - auto results = op.evaluate({&x1}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test1) { - - auto in0 = NDArrayFactory::create('c', {2}, {1, 2}); - auto in1 = NDArrayFactory::create('c', {3}, {10, 20, 30}); - auto in2 = NDArrayFactory::create('c', {4}, {100, 200, 300, 400}); - auto exp0 = NDArrayFactory::create('c', {2,3,4}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}); - auto exp1 = NDArrayFactory::create('c', {2,3,4}, {10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30, 10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30}); - auto exp2 = NDArrayFactory::create('c', {2,3,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); - auto out0 = results.at(0); - auto out1 = results.at(1); - auto out2 = results.at(2); - - // out0->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp2.equalsTo(out2)); - - + auto in0 = NDArrayFactory::create('c', {2}, {1, 2}); + auto in1 = NDArrayFactory::create('c', {3}, {10, 20, 30}); + auto in2 = NDArrayFactory::create('c', {4}, {100, 200, 300, 400}); + auto exp0 = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto exp1 = NDArrayFactory::create( + 'c', {2, 3, 4}, {10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30, + 10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30}); + auto exp2 = NDArrayFactory::create( + 'c', {2, 3, 4}, + {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, + 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + // out0->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test2) { - - auto in0 = NDArrayFactory::create('c', {2}, {1, 2}); - auto in1 = NDArrayFactory::create('c', {3}, {10, 20, 30}); - auto in2 = NDArrayFactory::create('c', {4}, {100, 200, 300, 400}); - auto exp0 = NDArrayFactory::create('c', {3,2,4}, {1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2}); - auto exp1 = NDArrayFactory::create('c', {3,2,4}, {10, 10, 10, 10, 10, 10, 10, 10, 20, 20, 20, 20, 20, 20, 20, 20, 30, 30, 30, 30, 30, 30, 30, 30}); - auto exp2 = NDArrayFactory::create('c', {3,2,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0, &in1, &in2}); - auto out0 = results.at(0); - auto out1 = results.at(1); - auto out2 = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp2.equalsTo(out2)); - - + auto in0 = NDArrayFactory::create('c', {2}, {1, 2}); + auto in1 = NDArrayFactory::create('c', {3}, {10, 20, 30}); + auto in2 = NDArrayFactory::create('c', {4}, {100, 200, 300, 400}); + auto exp0 = NDArrayFactory::create( + 'c', {3, 2, 4}, + {1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2}); + auto exp1 = NDArrayFactory::create( + 'c', {3, 2, 4}, {10, 10, 10, 10, 10, 10, 10, 10, 20, 20, 20, 20, + 20, 20, 20, 20, 30, 30, 30, 30, 30, 30, 30, 30}); + auto exp2 = NDArrayFactory::create( + 'c', {3, 2, 4}, + {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, + 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test3) { - - auto in0 = NDArrayFactory::create('c', {2}, {1, 2}); - auto in1 = NDArrayFactory::create('c', {1,3}, {10, 20, 30}); - auto in2 = NDArrayFactory::create('c', {2,2}, {100, 200, 300, 400}); - auto exp0 = NDArrayFactory::create('c', {3,2,4}, {1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2}); - auto exp1 = NDArrayFactory::create('c', {3,2,4}, {10, 10, 10, 10, 10, 10, 10, 10, 20, 20, 20, 20, 20, 20, 20, 20, 30, 30, 30, 30, 30, 30, 30, 30}); - auto exp2 = NDArrayFactory::create('c', {3,2,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0, &in1, &in2}); - auto out0 = results.at(0); - auto out1 = results.at(1); - auto out2 = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp2.equalsTo(out2)); - - + auto in0 = NDArrayFactory::create('c', {2}, {1, 2}); + auto in1 = NDArrayFactory::create('c', {1, 3}, {10, 20, 30}); + auto in2 = NDArrayFactory::create('c', {2, 2}, {100, 200, 300, 400}); + auto exp0 = NDArrayFactory::create( + 'c', {3, 2, 4}, + {1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2}); + auto exp1 = NDArrayFactory::create( + 'c', {3, 2, 4}, {10, 10, 10, 10, 10, 10, 10, 10, 20, 20, 20, 20, + 20, 20, 20, 20, 30, 30, 30, 30, 30, 30, 30, 30}); + auto exp2 = NDArrayFactory::create( + 'c', {3, 2, 4}, + {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, + 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test4) { - - auto in0 = NDArrayFactory::create('c', {1,2}, {1, 2}); - auto in1 = NDArrayFactory::create('c', {3,1}, {10, 20, 30}); - auto in2 = NDArrayFactory::create('c', {1,4,1}, {100, 200, 300, 400}); - auto exp0 = NDArrayFactory::create('c', {2,3,4}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}); - auto exp1 = NDArrayFactory::create('c', {2,3,4}, {10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30, 10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30}); - auto exp2 = NDArrayFactory::create('c', {2,3,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); - auto out0 = results.at(0); - auto out1 = results.at(1); - auto out2 = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp2.equalsTo(out2)); - - + auto in0 = NDArrayFactory::create('c', {1, 2}, {1, 2}); + auto in1 = NDArrayFactory::create('c', {3, 1}, {10, 20, 30}); + auto in2 = + NDArrayFactory::create('c', {1, 4, 1}, {100, 200, 300, 400}); + auto exp0 = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto exp1 = NDArrayFactory::create( + 'c', {2, 3, 4}, {10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30, + 10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30}); + auto exp2 = NDArrayFactory::create( + 'c', {2, 3, 4}, + {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, + 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test5) { - - auto in0 = NDArrayFactory::create(1); - auto in1 = NDArrayFactory::create(2); - auto in2 = NDArrayFactory::create(3); - auto exp0 = NDArrayFactory::create('c', {1,1,1}, {1}); - auto exp1 = NDArrayFactory::create('c', {1,1,1}, {2}); - auto exp2 = NDArrayFactory::create('c', {1,1,1}, {3}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); - auto out0 = results.at(0); - auto out1 = results.at(1); - auto out2 = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp2.equalsTo(out2)); - - + auto in0 = NDArrayFactory::create(1); + auto in1 = NDArrayFactory::create(2); + auto in2 = NDArrayFactory::create(3); + auto exp0 = NDArrayFactory::create('c', {1, 1, 1}, {1}); + auto exp1 = NDArrayFactory::create('c', {1, 1, 1}, {2}); + auto exp2 = NDArrayFactory::create('c', {1, 1, 1}, {3}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test6) { - - auto in0 = NDArrayFactory::create('c', {2,2},{1,2,3,4}); - auto in1 = NDArrayFactory::create(5); - auto in2 = NDArrayFactory::create(6); - auto exp0 = NDArrayFactory::create('c', {4,1,1}, {1,2,3,4}); - auto exp1 = NDArrayFactory::create('c', {4,1,1}, {5,5,5,5}); - auto exp2 = NDArrayFactory::create('c', {4,1,1}, {6,6,6,6}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); - auto out0 = results.at(0); - auto out1 = results.at(1); - auto out2 = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp2.equalsTo(out2)); - - + auto in0 = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto in1 = NDArrayFactory::create(5); + auto in2 = NDArrayFactory::create(6); + auto exp0 = NDArrayFactory::create('c', {4, 1, 1}, {1, 2, 3, 4}); + auto exp1 = NDArrayFactory::create('c', {4, 1, 1}, {5, 5, 5, 5}); + auto exp2 = NDArrayFactory::create('c', {4, 1, 1}, {6, 6, 6, 6}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test7) { - - auto in0 = NDArrayFactory::create('c', {2,2},{1,2,3,4}); - auto in1 = NDArrayFactory::create(5); - auto in2 = NDArrayFactory::create(6); - auto exp0 = NDArrayFactory::create('c', {1,4,1}, {1,2,3,4}); - auto exp1 = NDArrayFactory::create('c', {1,4,1}, {5,5,5,5}); - auto exp2 = NDArrayFactory::create('c', {1,4,1}, {6,6,6,6}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0, &in1, &in2}, {}, {1}); - auto out0 = results.at(0); - auto out1 = results.at(1); - auto out2 = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp2.equalsTo(out2)); - - + auto in0 = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto in1 = NDArrayFactory::create(5); + auto in2 = NDArrayFactory::create(6); + auto exp0 = NDArrayFactory::create('c', {1, 4, 1}, {1, 2, 3, 4}); + auto exp1 = NDArrayFactory::create('c', {1, 4, 1}, {5, 5, 5, 5}); + auto exp2 = NDArrayFactory::create('c', {1, 4, 1}, {6, 6, 6, 6}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}, {}, {1}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test8) { + auto in0 = NDArrayFactory::create(5); + auto exp0 = NDArrayFactory::create('c', {1}, {5}); - auto in0 = NDArrayFactory::create(5); - auto exp0 = NDArrayFactory::create('c', {1}, {5}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0}, {}, {0}); - auto out0 = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - + sd::ops::meshgrid op; + auto results = op.evaluate({&in0}, {}, {0}); + auto out0 = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, meshgrid_test9) { + auto in0 = NDArrayFactory::create(5); + auto exp0 = NDArrayFactory::create('c', {1}, {5}); - auto in0 = NDArrayFactory::create(5); - auto exp0 = NDArrayFactory::create('c', {1}, {5}); - - sd::ops::meshgrid op; - auto results = op.evaluate({&in0}, {}, {1}); - auto out0 = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp0.isSameShape(out0)); - ASSERT_TRUE(exp0.equalsTo(out0)); - + sd::ops::meshgrid op; + auto results = op.evaluate({&in0}, {}, {1}); + auto out0 = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_1) { - - - auto input = NDArrayFactory::create('c', {2, 3}, {11.f, 13.f, 4.f, 15.f, 6.f, 3.f}); - auto targets = NDArrayFactory::create('c', {2, 3}, {15.5f, 15.7f, 5.f , 15.f, 5.f, 6.f}); - auto weight = NDArrayFactory::create(0.7f); - auto expected = NDArrayFactory::create('c', {2, 3}, {-159.50006, -191.1, -16.009075, -210., -24.001238, -15.03887}); - -//Targets {15.5f, 15.7f, 5.f , 15.f, 5.f, 6.f}; -//---------- -//Inputs {11.f, 13.f, 4.f, 15.f, 6.f, 3.f}; -//---------- -//Weights [0.7] -//Result {-159.50006, -191.1, -16.009075, -210., -24.001238, -15.03887} - - sd::ops::weighted_cross_entropy_with_logits op; - auto results = op.evaluate({&targets, &input, &weight}); - auto output = results.at(0); - - // output->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + auto input = NDArrayFactory::create( + 'c', {2, 3}, {11.f, 13.f, 4.f, 15.f, 6.f, 3.f}); + auto targets = NDArrayFactory::create( + 'c', {2, 3}, {15.5f, 15.7f, 5.f, 15.f, 5.f, 6.f}); + auto weight = NDArrayFactory::create(0.7f); + auto expected = NDArrayFactory::create( + 'c', {2, 3}, + {-159.50006, -191.1, -16.009075, -210., -24.001238, -15.03887}); + + // Targets {15.5f, 15.7f, 5.f , 15.f, 5.f, 6.f}; + //---------- + // Inputs {11.f, 13.f, 4.f, 15.f, 6.f, 3.f}; + //---------- + // Weights [0.7] + // Result {-159.50006, -191.1, -16.009075, -210., -24.001238, + // -15.03887} + + sd::ops::weighted_cross_entropy_with_logits op; + auto results = op.evaluate({&targets, &input, &weight}); + auto output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_2) { + auto input = NDArrayFactory::create( + 'c', {2, 3}, {11.f, 13.f, 4.f, 15.f, 6.f, 3.f}); + auto targets = NDArrayFactory::create( + 'c', {2, 3}, {15.5f, 15.7f, 5.f, 15.f, 5.f, 6.f}); + auto weights = NDArrayFactory::create({0.5f, 0.7f, 1.0f}); + auto expected = NDArrayFactory::create( + 'c', {2, 3}, + {-159.5001f, -191.1f, -15.98185f, -210.f, -24.001238f, -14.951412f}); + sd::ops::weighted_cross_entropy_with_logits op; + auto results = op.evaluate({&targets, &input, &weights}); + auto output = results.at(0); - auto input = NDArrayFactory::create('c', {2, 3}, {11.f, 13.f, 4.f, 15.f, 6.f, 3.f}); - auto targets = NDArrayFactory::create('c', {2, 3}, {15.5f, 15.7f, 5.f, 15.f, 5.f, 6.f}); - auto weights = NDArrayFactory::create({0.5f, 0.7f, 1.0f}) ; - auto expected = NDArrayFactory::create('c', {2, 3}, {-159.5001f, -191.1f, -15.98185f, -210.f, -24.001238f, -14.951412f}); - - sd::ops::weighted_cross_entropy_with_logits op; - auto results = op.evaluate({&targets, &input, &weights}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, lstm_test1) { - - const int time = 5; - const int batchSize = 3; - const int inSize = 3; - const int numProj = 3; - const int numUnits = 3; - - auto x = NDArrayFactory::create('c', {time, batchSize, inSize}); - auto h0 = NDArrayFactory::create('c', {batchSize, numProj}); - auto c0 = NDArrayFactory::create('c', {batchSize, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); - auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); - auto Wc = NDArrayFactory::create('c', {3*numUnits}); - auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); - auto b = NDArrayFactory::create('c', {4*numUnits}); - - x.linspace(0.5, 0.5); - h0 = 1.; - c0 = 2.; - Wx = 0.003; - Wh = 0.006; - Wc = 0.; - Wp = 0.; - b = 0.5; - - auto expH = NDArrayFactory::create('c', {time, batchSize, numProj}, {0.57574,0.57574,0.57574,0.58006,0.58006,0.58006,0.58434,0.58434,0.58434, - 0.55114,0.55114,0.55114,0.55732,0.55732,0.55732,0.56338,0.56338,0.56338, - 0.53763,0.53763,0.53763,0.54534,0.54534,0.54534,0.55287,0.55287,0.55287, - 0.53626,0.53626,0.53626,0.54487,0.54487,0.54487,0.55327,0.55327,0.55327, - 0.54484,0.54484,0.54484,0.55379,0.55379,0.55379,0.5625 ,0.5625 ,0.5625}); - - auto expClast = NDArrayFactory::create('c', {1, batchSize, numProj}, {1.1589154,1.1589154,1.1589154,1.1892855,1.1892855,1.1892855,1.219861 ,1.219861 ,1.219861}); - - sd::ops::lstm op; - auto results = op.evaluate({&x, &h0, &c0, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 0.}, {0, 0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto c = results.at(1); - auto cLast = c({4,5,0,0,0,0},true); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - - ASSERT_TRUE(expClast.isSameShape(&cLast)); - ASSERT_TRUE(expClast.equalsTo(&cLast)); - - + const int time = 5; + const int batchSize = 3; + const int inSize = 3; + const int numProj = 3; + const int numUnits = 3; + + auto x = NDArrayFactory::create('c', {time, batchSize, inSize}); + auto h0 = NDArrayFactory::create('c', {batchSize, numProj}); + auto c0 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4 * numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4 * numUnits}); + auto Wc = NDArrayFactory::create('c', {3 * numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4 * numUnits}); + + x.linspace(0.5, 0.5); + h0 = 1.; + c0 = 2.; + Wx = 0.003; + Wh = 0.006; + Wc = 0.; + Wp = 0.; + b = 0.5; + + auto expH = NDArrayFactory::create( + 'c', {time, batchSize, numProj}, + {0.57574, 0.57574, 0.57574, 0.58006, 0.58006, 0.58006, 0.58434, 0.58434, + 0.58434, 0.55114, 0.55114, 0.55114, 0.55732, 0.55732, 0.55732, 0.56338, + 0.56338, 0.56338, 0.53763, 0.53763, 0.53763, 0.54534, 0.54534, 0.54534, + 0.55287, 0.55287, 0.55287, 0.53626, 0.53626, 0.53626, 0.54487, 0.54487, + 0.54487, 0.55327, 0.55327, 0.55327, 0.54484, 0.54484, 0.54484, 0.55379, + 0.55379, 0.55379, 0.5625, 0.5625, 0.5625}); + + auto expClast = NDArrayFactory::create( + 'c', {1, batchSize, numProj}, + {1.1589154, 1.1589154, 1.1589154, 1.1892855, 1.1892855, 1.1892855, + 1.219861, 1.219861, 1.219861}); + + sd::ops::lstm op; + auto results = + op.evaluate({&x, &h0, &c0, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 0.}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto c = results.at(1); + auto cLast = c({4, 5, 0, 0, 0, 0}, true); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expClast.isSameShape(&cLast)); + ASSERT_TRUE(expClast.equalsTo(&cLast)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, relu6_test1) { + auto input = NDArrayFactory::create('c', {2, 4}, + {-13., 10, -5, 0, 2, 7, 6, 12}); + auto expected = NDArrayFactory::create( + 'c', {2, 4}, {0., 6., 0., 0., 2., 6., 6., 6.}); - auto input = NDArrayFactory::create('c', {2,4}, {-13.,10,-5,0,2,7,6,12}); - auto expected = NDArrayFactory::create('c', {2,4}, {0., 6., 0., 0.,2., 6., 6., 6.}); + sd::ops::relu6 op; + auto results = op.evaluate({&input}, {0.}, {}); - sd::ops::relu6 op; - auto results = op.evaluate({&input}, {0.}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, relu6_bp_test1) { + auto input = NDArrayFactory::create('c', {2, 4}, + {-13., 10, -5, 0, 2, 7, 6, 5}); + auto gradO = NDArrayFactory::create( + 'c', {2, 4}, {-1., -2., 0., 4., 5., 6., 7., 8.}); - auto input = NDArrayFactory::create('c', {2,4}, {-13.,10, -5, 0, 2, 7, 6, 5}); - auto gradO = NDArrayFactory::create('c', {2,4}, {-1., -2., 0., 4., 5., 6., 7., 8.}); - - auto expected = NDArrayFactory::create('c', {2,4}, {0., 0., 0., 0., 5., 0., 0., 8.}); + auto expected = NDArrayFactory::create( + 'c', {2, 4}, {0., 0., 0., 0., 5., 0., 0., 8.}); - sd::ops::relu6_bp op; - auto results = op.evaluate({&input, &gradO}, {0.}); + sd::ops::relu6_bp op; + auto results = op.evaluate({&input, &gradO}, {0.}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + auto output = results.at(0); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_1) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {5.5f, 0.f, 0.3f, 5.5f, 8.6f, 0.f, 0.f, 0.4f, 1.5f, 1.f, 1.3f, 1.5f, 2.6f, + 2.f, 3.f, 1.4f}); - auto x = NDArrayFactory::create('c', {2, 2, 2, 2}, { 5.5f, 0.f, 0.3f, 5.5f, - 8.6f, 0.f, 0.f, 0.4f, - 1.5f, 1.f, 1.3f, 1.5f, - 2.6f, 2.f, 3.f, 1.4f} - ); - - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, { - 0.98386997f, 0.f, 0.05358852f, 0.9824562f, - 0.99330735f, 0.f, 0.f, 0.37139067f, - 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, - 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f} - ); - - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); - // out->printIndexedBuffer("LRN out"); - // exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {0.98386997f, 0.f, 0.05358852f, 0.9824562f, 0.99330735f, 0.f, 0.f, + 0.37139067f, 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, + 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f}); + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5}); + auto out = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } - //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_2) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {5.5f, 0.f, 0.3f, 5.5f, 8.6f, 0.f, 0.f, 0.4f, 1.5f, 1.f, 1.3f, 1.5f, 2.6f, + 2.f, 3.f, 1.4f}); - auto x = NDArrayFactory::create('c', {2, 2, 2, 2}, { 5.5f, 0.f, 0.3f, 5.5f, - 8.6f, 0.f, 0.f, 0.4f, - 1.5f, 1.f, 1.3f, 1.5f, - 2.6f, 2.f, 3.f, 1.4f}); - - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, { - 0.98386997f, 0.f, 0.05358852f, 0.9824562f, - 0.99330735f, 0.f, 0.f, 0.37139067f, - 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, - 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f}); - - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); - // out->printIndexedBuffer("LRN out"); - // exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {0.98386997f, 0.f, 0.05358852f, 0.9824562f, 0.99330735f, 0.f, 0.f, + 0.37139067f, 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, + 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f}); + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_3) { - - auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { - - 5.5f, 0.f, 0.3f, 5.5f, - 1.5f, 0.f, 1.3f, 6.5f, - 8.6f, 0.f, 0.f, 0.4f, - 2.5f, 1.f, 0.3f, 4.5f, - 1.5f, 1.f, 1.3f, 1.5f, - 3.5f, 0.f, 1.3f, 2.5f, - 2.6f, 2.f, 3.f, 1.4f, - 4.5f, 1.f, 0.3f, 0.5f} - ); - - auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 0.9824562f, 0.f, 0.03822664f, 0.9824562f, - 0.67488194f, 0.f, 0.18924236f, 0.96960944f, - 0.99330735f, 0.f, 0.f, 0.37139067f, - 0.86567914f, 0.18702209f, 0.05610663f, 0.9520745f, - 0.6154575f, 0.34942827f, 0.45425674f, 0.6154575f, - 0.905509f, 0.f, 0.2824086f, 0.8361251f, - 0.57063663f, 0.41959068f, 0.629386f, 0.3504383f, - 0.9520745f, 0.21039814f, 0.06311944f, 0.3268602f } - ); - - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); - // out->printIndexedBuffer("LRN out"); - // exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, + { + + 5.5f, 0.f, 0.3f, 5.5f, 1.5f, 0.f, 1.3f, 6.5f, 8.6f, 0.f, 0.f, + 0.4f, 2.5f, 1.f, 0.3f, 4.5f, 1.5f, 1.f, 1.3f, 1.5f, 3.5f, 0.f, + 1.3f, 2.5f, 2.6f, 2.f, 3.f, 1.4f, 4.5f, 1.f, 0.3f, 0.5f}); + + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, + {0.9824562f, 0.f, 0.03822664f, 0.9824562f, 0.67488194f, + 0.f, 0.18924236f, 0.96960944f, 0.99330735f, 0.f, + 0.f, 0.37139067f, 0.86567914f, 0.18702209f, 0.05610663f, + 0.9520745f, 0.6154575f, 0.34942827f, 0.45425674f, 0.6154575f, + 0.905509f, 0.f, 0.2824086f, 0.8361251f, 0.57063663f, + 0.41959068f, 0.629386f, 0.3504383f, 0.9520745f, 0.21039814f, + 0.06311944f, 0.3268602f}); + + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_4) { - - auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { - - 5.5f, 0.f, 0.3f, 5.5f, - 1.5f, 0.f, 1.3f, 6.5f, - 8.6f, 0.f, 0.f, 0.4f, - 2.5f, 1.f, 0.3f, 4.5f, - 1.5f, 1.f, 1.3f, 1.5f, - 3.5f, 0.f, 1.3f, 2.5f, - 2.6f, 2.f, 3.f, 1.4f, - 4.5f, 1.f, 0.3f, 0.5f} - ); - - auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 0.70082176f, 0.f, 0.03822664f, 0.70082176f, - 0.21835658f, 0.f, 0.18924236f, 0.9462118f, - 0.9922489f, 0.f, 0.f, 0.04615111f, - 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, - 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, - 0.76033086f, 0.f, 0.2824086f, 0.54309344f, - 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, - 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f} - ); - - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); - // out->printIndexedBuffer("LRN out"); - // exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, + { + + 5.5f, 0.f, 0.3f, 5.5f, 1.5f, 0.f, 1.3f, 6.5f, 8.6f, 0.f, 0.f, + 0.4f, 2.5f, 1.f, 0.3f, 4.5f, 1.5f, 1.f, 1.3f, 1.5f, 3.5f, 0.f, + 1.3f, 2.5f, 2.6f, 2.f, 3.f, 1.4f, 4.5f, 1.f, 0.3f, 0.5f}); + + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, + {0.70082176f, 0.f, 0.03822664f, 0.70082176f, 0.21835658f, + 0.f, 0.18924236f, 0.9462118f, 0.9922489f, 0.f, + 0.f, 0.04615111f, 0.46755522f, 0.18702209f, 0.05610663f, + 0.8415994f, 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, + 0.76033086f, 0.f, 0.2824086f, 0.54309344f, 0.54546785f, + 0.41959068f, 0.629386f, 0.29371348f, 0.94679165f, 0.21039814f, + 0.06311944f, 0.10519907f}); + + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_5) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, + { - auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { - - 5.5f, 0.f, 0.3f, 5.5f, - 1.5f, 0.f, 1.3f, 6.5f, - 8.6f, 0.f, 0.f, 0.4f, - 2.5f, 1.f, 0.3f, 4.5f, - 1.5f, 1.f, 1.3f, 1.5f, - 3.5f, 0.f, 1.3f, 2.5f, - 2.6f, 2.f, 3.f, 1.4f, - 4.5f, 1.f, 0.3f, 0.5f} - ); - - auto eps = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 0.70082176f, 0.f, 0.03822664f, 0.70082176f, - 0.21835658f, 0.f, 0.18924236f, 0.9462118f, - - 0.9922489f, 0.f, 0.f, 0.04615111f, - 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, + 5.5f, 0.f, 0.3f, 5.5f, 1.5f, 0.f, 1.3f, 6.5f, 8.6f, 0.f, 0.f, + 0.4f, 2.5f, 1.f, 0.3f, 4.5f, 1.5f, 1.f, 1.3f, 1.5f, 3.5f, 0.f, + 1.3f, 2.5f, 2.6f, 2.f, 3.f, 1.4f, 4.5f, 1.f, 0.3f, 0.5f}); + auto eps = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, {0.70082176f, 0.f, 0.03822664f, 0.70082176f, + 0.21835658f, 0.f, 0.18924236f, 0.9462118f, - 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, - 0.76033086f, 0.f, 0.2824086f, 0.54309344f, + 0.9922489f, 0.f, 0.f, 0.04615111f, + 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, - 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, - 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f} - ); + 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, + 0.76033086f, 0.f, 0.2824086f, 0.54309344f, - auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}); + 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, + 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f}); - sd::ops::lrn_bp op; - auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {5}, {}, {}, false); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); - // out->printIndexedBuffer("LRN out"); - // exp.printIndexedBuffer("LRN exp"); -// ASSERT_TRUE(exp.equalsTo(out)); + auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}); + sd::ops::lrn_bp op; + auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {5}, {}, {}, false); + auto out = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + // ASSERT_TRUE(exp.equalsTo(out)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, tri_test1) { + const int rows = 3; + const int cols = 5; - const int rows = 3; - const int cols = 5; - - auto expected = NDArrayFactory::create('c', {rows, cols}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f}); - - sd::ops::tri op; - auto results = op.evaluate({}, {}, {rows, cols}); - auto output = results.at(0); - - // output->printIndexedBuffer(); + auto expected = + NDArrayFactory::create('c', {rows, cols}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, + 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f}); - ASSERT_EQ(Status::OK(), results.status()); + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols}); + auto output = results.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + // output->printIndexedBuffer(); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, tri_test2) { + const int rows = 3; + const int cols = 5; + const int diag = 2; - const int rows = 3; - const int cols = 5; - const int diag = 2; + auto expected = + NDArrayFactory::create('c', {rows, cols}, + {1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, + 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - auto expected = NDArrayFactory::create('c', {rows, cols}, {1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - - sd::ops::tri op; - auto results = op.evaluate({}, {}, {rows, cols, diag}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols, diag}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, tri_test3) { + const int rows = 3; + const int cols = 5; + const int diag = -1; - const int rows = 3; - const int cols = 5; - const int diag = -1; - - auto expected = NDArrayFactory::create('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f}); - - sd::ops::tri op; - auto results = op.evaluate({}, {}, {rows, cols, diag}); - auto output = results.at(0); + auto expected = + NDArrayFactory::create('c', {rows, cols}, + {0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, + 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f}); - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols, diag}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, tri_test4) { + const int rows = 3; + const int cols = 5; + const int diag = -2; - const int rows = 3; - const int cols = 5; - const int diag = -2; - - auto expected = NDArrayFactory::create('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f}); - - sd::ops::tri op; - auto results = op.evaluate({}, {}, {rows, cols, diag}); - auto output = results.at(0); + auto expected = + NDArrayFactory::create('c', {rows, cols}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f}); - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols, diag}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, tri_test5) { + const int rows = 5; - const int rows = 5; - - auto expected = NDArrayFactory::create('c', {rows, rows}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f}); + auto expected = NDArrayFactory::create( + 'c', {rows, rows}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, + 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - sd::ops::tri op; - auto results = op.evaluate({}, {}, {rows}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, tri_test6) { + const int rows = 3; + const int cols = 5; + const int diag = -20; - const int rows = 3; - const int cols = 5; - const int diag = -20; - - auto expected = NDArrayFactory::create('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - - sd::ops::tri op; - auto results = op.evaluate({}, {}, {rows, cols, diag}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); + auto expected = + NDArrayFactory::create('c', {rows, cols}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols, diag}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, tri_test7) { + const int rows = 3; + const int cols = 5; + const int diag = 20; - const int rows = 3; - const int cols = 5; - const int diag = 20; + auto expected = + NDArrayFactory::create('c', {rows, cols}, + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - auto expected = NDArrayFactory::create('c', {rows, cols}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - - sd::ops::tri op; - auto results = op.evaluate({}, {}, {rows, cols, diag}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols, diag}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test1) { + auto input = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 0, 5, 6, 0, 0, 9, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto expected = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 0, 5, 6, 0, 0, 9, 0, 0, 0}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test2) { + auto input = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 4, 5, 6, 0, 8, 9, 0, 0, 12}); - auto input = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto expected = NDArrayFactory::create('c', {4, 3}, {1, 2, 3,4, 5, 6,0, 8, 9,0, 0, 12}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {-1}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {-1}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test3) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 0, 6, 7, 8, 9, 10, 0, 12}); - auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto expected = NDArrayFactory::create('c', {2, 3, 2}, {1, 2,3, 4,0, 6,7, 8,9,10,0,12}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {-1}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {-1}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test4) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 0, 4, 0, 0, 7, 8, 0, 10, 0, 0}); - auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto expected = NDArrayFactory::create('c', {2, 3, 2}, {1, 2,0, 4,0, 0,7, 8,0, 10,0, 0}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test5) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2}, {0, 2, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0, 2,0, 0,0, 0,0, 8,0, 0,0, 0}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {1}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {1}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test6) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0, 0,0, 0,0, 0,0, 0,0, 0,0, 0}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {10}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {10}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test7) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto expected = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {-10}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {-10}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test8) { + auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto expected = NDArrayFactory::create( + 'c', {6, 6}, {1, 2, 3, 4, 5, 6, 0, 2, 3, 4, 5, 6, 0, 0, 3, 4, 5, 6, + 0, 0, 0, 4, 5, 6, 0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 6}); - auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); - auto expected = NDArrayFactory::create('c', {6, 6}, {1, 2, 3, 4, 5, 6,0, 2, 3, 4, 5, 6,0, 0, 3, 4, 5, 6,0, 0, 0, 4, 5, 6,0, 0, 0, 0, 5, 6,0, 0, 0, 0, 0, 6}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test9) { + auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto expected = NDArrayFactory::create( + 'c', {6, 6}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, + 1, 2, 3, 4, 5, 6, 0, 2, 3, 4, 5, 6, 0, 0, 3, 4, 5, 6}); - auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); - auto expected = NDArrayFactory::create('c', {6, 6}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 0, 2, 3, 4, 5, 6, 0, 0, 3, 4, 5, 6}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {-3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {-3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test10) { + auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto expected = NDArrayFactory::create( + 'c', {6, 6}, {0, 0, 0, 4, 5, 6, 0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 6, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); - auto expected = NDArrayFactory::create('c', {6, 6}, {0, 0, 0, 4, 5, 6, 0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {3}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {3}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_test11) { + auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto expected = NDArrayFactory::create( + 'c', {6, 6}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, + 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}); - auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); - auto expected = NDArrayFactory::create('c', {6, 6}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}); - - sd::ops::triu op; - auto results = op.evaluate({&input}, {}, {-58}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {-58}); + auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_bp_test1) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto gradO = NDArrayFactory::create('c', {2, 3, 2}); + gradO = 0.5; - auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto gradO = NDArrayFactory::create('c', {2, 3, 2}); - gradO = 0.5; - - auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0.,0.5,0.,0. ,0.,0. ,0.,0.5,0.,0. ,0.,0.}); - - sd::ops::triu_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {1}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2}, {0., 0.5, 0., 0., 0., 0., 0., 0.5, 0., 0., 0., 0.}); - ASSERT_TRUE(expected.isSameShape(gradI)); - ASSERT_TRUE(expected.equalsTo(gradI)); + sd::ops::triu_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {1}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(gradI)); + ASSERT_TRUE(expected.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_bp_test2) { + auto input = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto gradO = NDArrayFactory::create('c', {2, 3, 2}); + gradO = 0.5; - auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto gradO = NDArrayFactory::create('c', {2, 3, 2}); - gradO = 0.5; + auto expected = NDArrayFactory::create( + 'c', {2, 3, 2}, {0.5, 0.5, 0., 0.5, 0., 0., 0.5, 0.5, 0., 0.5, 0., 0.}); - auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0.5,0.5,0. ,0.5,0. ,0. ,0.5,0.5,0. ,0.5,0. ,0.}); - - sd::ops::triu_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(gradI)); - ASSERT_TRUE(expected.equalsTo(gradI)); + sd::ops::triu_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(gradI)); + ASSERT_TRUE(expected.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_bp_test3) { + auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto gradO = NDArrayFactory::create('c', {6, 6}); + gradO = 0.5; - auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); - auto gradO = NDArrayFactory::create('c', {6,6}); - gradO = 0.5; - - auto expected = NDArrayFactory::create('c', {6,6}, {0.5, 0.5, 0.5, 0.5, 0.5, 0.5,0.5, 0.5, 0.5, 0.5, 0.5, 0.5,0.5, 0.5, 0.5, 0.5, 0.5, 0.5,0. , 0.5, 0.5, 0.5, 0.5, 0.5,0. , 0. , 0.5, 0.5, 0.5, 0.5,0. , 0. , 0. , 0.5, 0.5, 0.5}); - - sd::ops::triu_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {-2}); - auto gradI = results.at(0); + auto expected = NDArrayFactory::create( + 'c', {6, 6}, + {0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0., 0.5, 0.5, 0.5, 0.5, 0.5, + 0., 0., 0.5, 0.5, 0.5, 0.5, 0., 0., 0., 0.5, 0.5, 0.5}); - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(gradI)); - ASSERT_TRUE(expected.equalsTo(gradI)); + sd::ops::triu_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {-2}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(gradI)); + ASSERT_TRUE(expected.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, triu_bp_test4) { + auto input = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto gradO = NDArrayFactory::create('c', {2, 3}); + gradO = 0.5; - auto input = NDArrayFactory::create('c', {2,3}, {1, 2, 3, 4, 5, 6}); - auto gradO = NDArrayFactory::create('c', {2,3}); - gradO = 0.5; - - auto expected = NDArrayFactory::create('c', {2,3}, {0., 0., 0., 0., 0., 0.}); - - sd::ops::triu_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {10}); - auto gradI = results.at(0); + auto expected = + NDArrayFactory::create('c', {2, 3}, {0., 0., 0., 0., 0., 0.}); - ASSERT_EQ(Status::OK(), results.status()); - - ASSERT_TRUE(expected.isSameShape(gradI)); - ASSERT_TRUE(expected.equalsTo(gradI)); + sd::ops::triu_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {10}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(gradI)); + ASSERT_TRUE(expected.equalsTo(gradI)); } - diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index f319dcf7cb11..3eac6cf72675 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -18,1908 +18,1943 @@ // @author raver119@gmail.com // - -#include "testlayers.h" -#include -#include #include #include +#include +#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class DeclarableOpsTests5 : public testing::Test { -public: - - DeclarableOpsTests5() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests5() { + printf("\n"); + fflush(stdout); + } }; - TEST_F(DeclarableOpsTests5, Test_PermuteEquality_1) { - auto x = NDArrayFactory::create('c', {1, 60}); - auto exp = NDArrayFactory::create('c', {3, 5, 4}, {1.0, 6.0, 11.0, 16.0, 2.0, 7.0, 12.0, 17.0, 3.0, 8.0, 13.0, 18.0, 4.0, 9.0, 14.0, 19.0, 5.0, 10.0, 15.0, 20.0, 21.0, 26.0, 31.0, 36.0, 22.0, 27.0, 32.0, 37.0, 23.0, 28.0, 33.0, 38.0, 24.0, 29.0, 34.0, 39.0, 25.0, 30.0, 35.0, 40.0, 41.0, 46.0, 51.0, 56.0, 42.0, 47.0, 52.0, 57.0, 43.0, 48.0, 53.0, 58.0, 44.0, 49.0, 54.0, 59.0, 45.0, 50.0, 55.0, 60.0}); - x.linspace(1); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + auto exp = NDArrayFactory::create( + 'c', {3, 5, 4}, + {1.0, 6.0, 11.0, 16.0, 2.0, 7.0, 12.0, 17.0, 3.0, 8.0, 13.0, 18.0, + 4.0, 9.0, 14.0, 19.0, 5.0, 10.0, 15.0, 20.0, 21.0, 26.0, 31.0, 36.0, + 22.0, 27.0, 32.0, 37.0, 23.0, 28.0, 33.0, 38.0, 24.0, 29.0, 34.0, 39.0, + 25.0, 30.0, 35.0, 40.0, 41.0, 46.0, 51.0, 56.0, 42.0, 47.0, 52.0, 57.0, + 43.0, 48.0, 53.0, 58.0, 44.0, 49.0, 54.0, 59.0, 45.0, 50.0, 55.0, 60.0}); + x.linspace(1); + x.reshapei('c', {3, 4, 5}); - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {0, 2, 1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {0, 2, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests5, Test_PermuteEquality_0) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); - x.reshapei('c', {3, 4, 5}); - -// x.printShapeInfo("{0, 1, 2} shape"); -// x.printBuffer("{0, 1, 2} data"); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, + 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, + 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, + 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + x.reshapei('c', {3, 4, 5}); - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {0, 1, 2}); - ASSERT_EQ(Status::OK(), result.status()); + // x.printShapeInfo("{0, 1, 2} shape"); + // x.printBuffer("{0, 1, 2} data"); - auto z = result.at(0); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests5, Test_PermuteEquality_2) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {4, 3, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 41.0, 42.0, 43.0, 44.0, 45.0, 6.0, 7.0, 8.0, 9.0, 10.0, 26.0, 27.0, 28.0, 29.0, 30.0, 46.0, 47.0, 48.0, 49.0, 50.0, 11.0, 12.0, 13.0, 14.0, 15.0, 31.0, 32.0, 33.0, 34.0, 35.0, 51.0, 52.0, 53.0, 54.0, 55.0, 16.0, 17.0, 18.0, 19.0, 20.0, 36.0, 37.0, 38.0, 39.0, 40.0, 56.0, 57.0, 58.0, 59.0, 60.0}); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {4, 3, 5}, + {1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 41.0, 42.0, + 43.0, 44.0, 45.0, 6.0, 7.0, 8.0, 9.0, 10.0, 26.0, 27.0, 28.0, 29.0, + 30.0, 46.0, 47.0, 48.0, 49.0, 50.0, 11.0, 12.0, 13.0, 14.0, 15.0, 31.0, + 32.0, 33.0, 34.0, 35.0, 51.0, 52.0, 53.0, 54.0, 55.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 36.0, 37.0, 38.0, 39.0, 40.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + x.reshapei('c', {3, 4, 5}); -// x.printShapeInfo("{1, 0, 2} shape"); -// x.printBuffer("{1, 0, 2} data"); + // x.printShapeInfo("{1, 0, 2} shape"); + // x.printBuffer("{1, 0, 2} data"); - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {1, 0, 2}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {1, 0, 2}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests5, Test_PermuteEquality_3) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {4, 5, 3}, {1.0, 21.0, 41.0, 2.0, 22.0, 42.0, 3.0, 23.0, 43.0, 4.0, 24.0, 44.0, 5.0, 25.0, 45.0, 6.0, 26.0, 46.0, 7.0, 27.0, 47.0, 8.0, 28.0, 48.0, 9.0, 29.0, 49.0, 10.0, 30.0, 50.0, 11.0, 31.0, 51.0, 12.0, 32.0, 52.0, 13.0, 33.0, 53.0, 14.0, 34.0, 54.0, 15.0, 35.0, 55.0, 16.0, 36.0, 56.0, 17.0, 37.0, 57.0, 18.0, 38.0, 58.0, 19.0, 39.0, 59.0, 20.0, 40.0, 60.0}); - x.reshapei('c', {3, 4, 5}); - -// x.printShapeInfo("{1, 2, 0} shape"); -// x.printBuffer("{1, 2, 0} data"); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {4, 5, 3}, + {1.0, 21.0, 41.0, 2.0, 22.0, 42.0, 3.0, 23.0, 43.0, 4.0, 24.0, 44.0, + 5.0, 25.0, 45.0, 6.0, 26.0, 46.0, 7.0, 27.0, 47.0, 8.0, 28.0, 48.0, + 9.0, 29.0, 49.0, 10.0, 30.0, 50.0, 11.0, 31.0, 51.0, 12.0, 32.0, 52.0, + 13.0, 33.0, 53.0, 14.0, 34.0, 54.0, 15.0, 35.0, 55.0, 16.0, 36.0, 56.0, + 17.0, 37.0, 57.0, 18.0, 38.0, 58.0, 19.0, 39.0, 59.0, 20.0, 40.0, 60.0}); + x.reshapei('c', {3, 4, 5}); - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {1, 2, 0}); - ASSERT_EQ(Status::OK(), result.status()); + // x.printShapeInfo("{1, 2, 0} shape"); + // x.printBuffer("{1, 2, 0} data"); - auto z = result.at(0); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {1, 2, 0}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests5, Test_PermuteEquality_4) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {5, 3, 4}, {1.0, 6.0, 11.0, 16.0, 21.0, 26.0, 31.0, 36.0, 41.0, 46.0, 51.0, 56.0, 2.0, 7.0, 12.0, 17.0, 22.0, 27.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 3.0, 8.0, 13.0, 18.0, 23.0, 28.0, 33.0, 38.0, 43.0, 48.0, 53.0, 58.0, 4.0, 9.0, 14.0, 19.0, 24.0, 29.0, 34.0, 39.0, 44.0, 49.0, 54.0, 59.0, 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0}); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {5, 3, 4}, + {1.0, 6.0, 11.0, 16.0, 21.0, 26.0, 31.0, 36.0, 41.0, 46.0, 51.0, 56.0, + 2.0, 7.0, 12.0, 17.0, 22.0, 27.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, + 3.0, 8.0, 13.0, 18.0, 23.0, 28.0, 33.0, 38.0, 43.0, 48.0, 53.0, 58.0, + 4.0, 9.0, 14.0, 19.0, 24.0, 29.0, 34.0, 39.0, 44.0, 49.0, 54.0, 59.0, + 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0}); + x.reshapei('c', {3, 4, 5}); -// x.printShapeInfo("{2, 0, 1} shape"); -// x.printBuffer("{2, 0, 1} data"); + // x.printShapeInfo("{2, 0, 1} shape"); + // x.printBuffer("{2, 0, 1} data"); - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {2, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {2, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests5, Test_PermuteEquality_5) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {5, 4, 3}, {1.0, 21.0, 41.0, 6.0, 26.0, 46.0, 11.0, 31.0, 51.0, 16.0, 36.0, 56.0, 2.0, 22.0, 42.0, 7.0, 27.0, 47.0, 12.0, 32.0, 52.0, 17.0, 37.0, 57.0, 3.0, 23.0, 43.0, 8.0, 28.0, 48.0, 13.0, 33.0, 53.0, 18.0, 38.0, 58.0, 4.0, 24.0, 44.0, 9.0, 29.0, 49.0, 14.0, 34.0, 54.0, 19.0, 39.0, 59.0, 5.0, 25.0, 45.0, 10.0, 30.0, 50.0, 15.0, 35.0, 55.0, 20.0, 40.0, 60.0}); - x.reshapei('c', {3, 4, 5}); - -// x.printShapeInfo("{2, 1, 0} shape"); -// x.printBuffer("{2, 1, 0} data"); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {5, 4, 3}, + {1.0, 21.0, 41.0, 6.0, 26.0, 46.0, 11.0, 31.0, 51.0, 16.0, 36.0, 56.0, + 2.0, 22.0, 42.0, 7.0, 27.0, 47.0, 12.0, 32.0, 52.0, 17.0, 37.0, 57.0, + 3.0, 23.0, 43.0, 8.0, 28.0, 48.0, 13.0, 33.0, 53.0, 18.0, 38.0, 58.0, + 4.0, 24.0, 44.0, 9.0, 29.0, 49.0, 14.0, 34.0, 54.0, 19.0, 39.0, 59.0, + 5.0, 25.0, 45.0, 10.0, 30.0, 50.0, 15.0, 35.0, 55.0, 20.0, 40.0, 60.0}); + x.reshapei('c', {3, 4, 5}); - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {2, 1, 0}); - ASSERT_EQ(Status::OK(), result.status()); + // x.printShapeInfo("{2, 1, 0} shape"); + // x.printBuffer("{2, 1, 0} data"); - auto z = result.at(0); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {2, 1, 0}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests5, Test_TTS_bp_1) { - auto x = NDArrayFactory::create('c', {2, 1, 3}); - auto eps = NDArrayFactory::create('c', {2, 4, 3}); + auto x = NDArrayFactory::create('c', {2, 1, 3}); + auto eps = NDArrayFactory::create('c', {2, 4, 3}); - auto exp = NDArrayFactory::create('c', {2, 1, 3}, {22.f, 26.f, 30.f, 70.f, 74.f, 78.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 1, 3}, {22.f, 26.f, 30.f, 70.f, 74.f, 78.f}); - eps.linspace(1.f); + eps.linspace(1.f); - sd::ops::tile_to_shape_bp op; - auto result = op.evaluate({&x, &eps}, {}, {2, 4, 3}); + sd::ops::tile_to_shape_bp op; + auto result = op.evaluate({&x, &eps}, {}, {2, 4, 3}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - // z->printShapeInfo("RES shape"); - // x.printShapeInfo("EXP shape"); - // z->printIndexedBuffer("RES output"); - ASSERT_TRUE(x.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto z = result.at(0); + // z->printShapeInfo("RES shape"); + // x.printShapeInfo("EXP shape"); + // z->printIndexedBuffer("RES output"); + ASSERT_TRUE(x.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests5, Test_Rdiv_bp_1) { - auto x = NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); - auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto eps = NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - + auto x = NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto eps = NDArrayFactory::create( + 'c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - sd::ops::reversedivide op_ff; - auto result_ff = op_ff.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result_ff.status()); + sd::ops::reversedivide op_ff; + auto result_ff = op_ff.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result_ff.status()); - auto z_ff = result_ff.at(0); - ASSERT_TRUE(eps.isSameShape(z_ff)); + auto z_ff = result_ff.at(0); + ASSERT_TRUE(eps.isSameShape(z_ff)); - sd::ops::reversedivide_bp op_bp; - auto result_bp = op_bp.evaluate({&x, &y, &eps}, {}, {}); - ASSERT_EQ(Status::OK(), result_bp.status()); + sd::ops::reversedivide_bp op_bp; + auto result_bp = op_bp.evaluate({&x, &y, &eps}, {}, {}); + ASSERT_EQ(Status::OK(), result_bp.status()); - auto z_bp = result_bp.at(0); - ASSERT_TRUE(x.isSameShape(z_bp)); + auto z_bp = result_bp.at(0); + ASSERT_TRUE(x.isSameShape(z_bp)); } - TEST_F(DeclarableOpsTests5, Test_Boolean_diff_1) { - auto x = NDArrayFactory::create('c', {1, 1}, {1.0f}); - auto y = NDArrayFactory::create(2.0f); + auto x = NDArrayFactory::create('c', {1, 1}, {1.0f}); + auto y = NDArrayFactory::create(2.0f); - sd::ops::less op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(result.at(0).t(0), true); - + sd::ops::less op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(result.at(0).t(0), true); } TEST_F(DeclarableOpsTests5, Test_SetSeed_1) { - auto x = NDArrayFactory::create('c', {1, 1}, {120}); - auto y = NDArrayFactory::create(5); + auto x = NDArrayFactory::create('c', {1, 1}, {120}); + auto y = NDArrayFactory::create(5); + + sd::ops::set_seed op; + auto result = op.evaluate({&x, &y}, {}, {120, 5}); - sd::ops::set_seed op; - auto result = op.evaluate({&x, &y}, {}, {120, 5}); + ASSERT_EQ(Status::OK(), result.status()); + // result->at(0)->printIndexedBuffer("RES SEED"); - ASSERT_EQ(Status::OK(), result.status()); -// result->at(0)->printIndexedBuffer("RES SEED"); - - sd::ops::get_seed getOp; - auto getRes = getOp.evaluate({}); - ASSERT_EQ(Status::OK(), getRes.status()); -// getres.at(0)->printIndexedBuffer("Output RES GET SEED"); -// ASSERT_EQ(result.at(0)->t(0), true); + sd::ops::get_seed getOp; + auto getRes = getOp.evaluate({}); + ASSERT_EQ(Status::OK(), getRes.status()); + // getres.at(0)->printIndexedBuffer("Output RES GET SEED"); + // ASSERT_EQ(result.at(0)->t(0), true); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, scatterMul_test1) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - NDArray idc('c', {1}, std::vector({0LL}), sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); - auto exp = NDArrayFactory::create('c', {2, 2}, {10.f, 2.f, 3.f, 4.f}); - - sd::ops::scatter_mul op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = + NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + NDArray idc('c', {1}, std::vector({0LL}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {10.f, 2.f, 3.f, 4.f}); - auto z = result.at(0); + sd::ops::scatter_mul op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, scatterDiv_test1) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - NDArray idc('c', {1}, std::vector({0LL}), sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); - auto exp = NDArrayFactory::create('c', {2, 2}, {0.10f, 2.f, 3.f, 4.f}); + auto matrix = + NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + NDArray idc('c', {1}, std::vector({0LL}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0.10f, 2.f, 3.f, 4.f}); - sd::ops::scatter_div op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::scatter_div op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Scatter Div"); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto z = result.at(0); + // z->printIndexedBuffer("Scatter Div"); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, scatterSub_test1) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - NDArray idc('c', {1}, std::vector({0LL}), sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); - auto exp = NDArrayFactory::create('c', {2, 2}, {-9.f, 1.f, 3.f, 4.f}); - - sd::ops::scatter_sub op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = + NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + NDArray idc('c', {1}, std::vector({0LL}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {-9.f, 1.f, 3.f, 4.f}); - auto z = result.at(0); -// z->printIndexedBuffer("Scatter Sub"); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_sub op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + auto z = result.at(0); + // z->printIndexedBuffer("Scatter Sub"); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, hardsigmoid_test1) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 2}, {0.7f, 0.9f, 1.f, 1.f}); - - sd::ops::hardsigmoid op; - auto result = op.evaluate({&matrix}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = + NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0.7f, 0.9f, 1.f, 1.f}); - auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::hardsigmoid op; + auto result = op.evaluate({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, hardsigmoid_test2) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - auto eps = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 2}, {0.2f, 0.4f, 0.f, 0.f}); + auto matrix = + NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto eps = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0.2f, 0.4f, 0.f, 0.f}); - sd::ops::hardsigmoid_bp op; - auto result = op.evaluate({&matrix, &eps}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::hardsigmoid_bp op; + auto result = op.evaluate({&matrix, &eps}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, hardtanh_test1) { - auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3, 3}, {-1, -1, -1, -1, 0, 1, 1, 1, 1}); - - sd::ops::hardtanh op; - auto result = op.evaluate({&matrix}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = NDArrayFactory::create('c', {3, 3}, + {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3, 3}, + {-1, -1, -1, -1, 0, 1, 1, 1, 1}); - auto z = result.at(0); -// z->printIndexedBuffer("Hardtanh 2x2"); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::hardtanh op; + auto result = op.evaluate({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + auto z = result.at(0); + // z->printIndexedBuffer("Hardtanh 2x2"); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, hardtanh_test2) { - auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); - auto eps = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 4, 5, 6, 0, 0, 0}); + auto matrix = NDArrayFactory::create('c', {3, 3}, + {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = + NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 4, 5, 6, 0, 0, 0}); - sd::ops::hardtanh_bp op; - auto result = op.evaluate({&matrix, &eps}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::hardtanh_bp op; + auto result = op.evaluate({&matrix, &eps}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Hardtanh_bp 2x2"); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto z = result.at(0); + // z->printIndexedBuffer("Hardtanh_bp 2x2"); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, histogram_test1) { - auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {3, 3, 3}); - - sd::ops::histogram op; - auto result = op.evaluate({&matrix}, {}, {3}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = NDArrayFactory::create('c', {3, 3}, + {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {3, 3, 3}); - auto z = result.at(0); -// z->printIndexedBuffer("Histogram3"); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::histogram op; + auto result = op.evaluate({&matrix}, {}, {3}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + auto z = result.at(0); + // z->printIndexedBuffer("Histogram3"); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, histogram_test2) { - auto matrix = NDArrayFactory::create('c', {3}, {1, 2, 1}); - auto exp = NDArrayFactory::create('c', {4}, {2, 0, 0, 1}); + auto matrix = NDArrayFactory::create('c', {3}, {1, 2, 1}); + auto exp = NDArrayFactory::create('c', {4}, {2, 0, 0, 1}); - sd::ops::histogram op; - auto result = op.evaluate({&matrix}, {}, {4}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::histogram op; + auto result = op.evaluate({&matrix}, {}, {4}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Identity_test1) { - auto matrix = NDArrayFactory::create('c', {3, 3}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f}); -// auto exp = NDArrayFactory::create('c', {3, 3}, {3, 3, 3}); - - sd::ops::identity op; - auto result = op.evaluate({&matrix}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = NDArrayFactory::create( + 'c', {3, 3}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f}); + // auto exp = NDArrayFactory::create('c', {3, 3}, {3, 3, 3}); - auto z = result.at(0); - ASSERT_TRUE(matrix.equalsTo(z)); + sd::ops::identity op; + auto result = op.evaluate({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + auto z = result.at(0); + ASSERT_TRUE(matrix.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Identity_test2) { - auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); - auto eps = NDArrayFactory::create('c', {3, 3}, {1,2,3,4,5,6,7,8,9}); -// auto exp = NDArrayFactory::create('c', {3,3}); - sd::ops::identity_bp op; - auto result = op.evaluate({&matrix, &eps}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - ASSERT_TRUE(z.equalsTo(eps)); - - + auto matrix = NDArrayFactory::create('c', {3, 3}, + {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + // auto exp = NDArrayFactory::create('c', {3,3}); + sd::ops::identity_bp op; + auto result = op.evaluate({&matrix, &eps}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(z.equalsTo(eps)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Log1p_test1) { - auto matrix = NDArrayFactory::create('c', {3, 3}, {4, 3, 2, 1, 0, 1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {3,3}, {5,4,3,2,1,2,3,4,5}); - // auto eps = NDArrayFactory::create('c', {3, 3}, {1,2,3,4,5,6,7,8,9}); -// auto exp = NDArrayFactory::create('c', {3,3}); - sd::ops::Log1p op; - y.applyTransform(sd::transform::Log, y); - auto result = op.evaluate({&matrix}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - ASSERT_TRUE(z.equalsTo(y)); - - + auto matrix = + NDArrayFactory::create('c', {3, 3}, {4, 3, 2, 1, 0, 1, 2, 3, 4}); + auto y = + NDArrayFactory::create('c', {3, 3}, {5, 4, 3, 2, 1, 2, 3, 4, 5}); + // auto eps = NDArrayFactory::create('c', {3, 3}, + // {1,2,3,4,5,6,7,8,9}); + // auto exp = NDArrayFactory::create('c', {3,3}); + sd::ops::Log1p op; + y.applyTransform(sd::transform::Log, y); + auto result = op.evaluate({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(z.equalsTo(y)); } TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_1) { + auto x = NDArrayFactory::create( + 'c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create( + 'c', {4, 1, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - auto x = NDArrayFactory::create('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {4, 1, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + sd::ops::space_to_batch op; + auto result = op.evaluate({&x, &paddings}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::space_to_batch op; - auto result = op.evaluate({&x, &paddings}, {}, {2}); - ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_2) { - auto x = NDArrayFactory::create('c', {1, 2, 2, 1}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4, 1, 1, 1}, {1, 2, 3, 4}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + auto x = NDArrayFactory::create('c', {1, 2, 2, 1}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4, 1, 1, 1}, {1, 2, 3, 4}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - sd::ops::space_to_batch op; - auto result = op.evaluate({&x, &paddings}, {}, {2}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::space_to_batch op; + auto result = op.evaluate({&x, &paddings}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_3) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 4, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 2, 0}); + auto exp = NDArrayFactory::create( + 'c', {8, 1, 3, 1}, {0, 1, 3, 0, 9, 11, 0, 2, 4, 0, 10, 12, + 0, 5, 7, 0, 13, 15, 0, 6, 8, 0, 14, 16}); - auto x = NDArrayFactory::create('c', {2, 2, 4, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 2, 0}); - auto exp = NDArrayFactory::create('c', {8, 1, 3, 1}, {0, 1, 3, 0, 9, 11,0, 2, 4, 0, 10, 12,0, 5, 7, 0, 13, 15,0, 6, 8, 0, 14, 16}); - - sd::ops::space_to_batch op; - auto result = op.evaluate({&x, &paddings}, {}, {2}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::space_to_batch op; + auto result = op.evaluate({&x, &paddings}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - // z->printIndexedBuffer(); + auto z = result.at(0); + // z->printIndexedBuffer(); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_4) { - - const int blockSize = 2; - NDArray x('c', {3, 3*blockSize - 1 - 2, 4*blockSize - 2 - 3, 2}, {147, 148, 219, 220, 149, 150, 11, 12, 83, 84, 13, 14, 155, 156, 227, 228, 157, 158, 171, 172, 243, 244, 173, 174, 35, 36, 107, 108, 37, 38, 179, 180, 251, 252, 181, 182, 195, 196, 267, 268, 197, 198, 59, 60, 131, 132, 61, 62, 203, 204, 275, 276, 205, 206}, sd::DataType::FLOAT32); - NDArray paddings = NDArrayFactory::create('c', {2, 2}, {1, 2, 2, 3}); - - NDArray exp('c', {3*blockSize*blockSize, 3, 4, 2}, {0,0, 0,0, 0,0, 0,0, 0,0, 11,12, 13,14, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, - 0,0, 0,0, 0,0, 35,36, 37,38, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 59,60, 61,62, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, - 0,0, 0,0, 0,0, 0,0, 83,84, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 107, 108, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, - 0,0, 0,0, 0,0, 0,0, 0,0, 131, 132, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 147, 148, 149, 150, 0,0, 0,0, 155, 156, 157, 158, - 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 171, 172, 173, 174, 0,0, 0,0, 179, 180, 181, 182, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 195, 196, - 197, 198, 0,0, 0,0, 203, 204, 205, 206, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 219, 220, 0,0, 0,0, 0,0, 227, 228, 0,0, 0,0, 0,0, - 0,0, 0,0, 0,0, 0,0, 243, 244, 0,0, 0,0, 0,0, 251, 252, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 267, 268, 0,0, 0,0, 0,0, 275, - 276, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0}, sd::DataType::FLOAT32); - - sd::ops::space_to_batch op; - auto result = op.evaluate({&x, &paddings}, {}, {blockSize}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - // z->printIndexedBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + const int blockSize = 2; + NDArray x( + 'c', {3, 3 * blockSize - 1 - 2, 4 * blockSize - 2 - 3, 2}, + {147, 148, 219, 220, 149, 150, 11, 12, 83, 84, 13, 14, 155, 156, + 227, 228, 157, 158, 171, 172, 243, 244, 173, 174, 35, 36, 107, 108, + 37, 38, 179, 180, 251, 252, 181, 182, 195, 196, 267, 268, 197, 198, + 59, 60, 131, 132, 61, 62, 203, 204, 275, 276, 205, 206}, + sd::DataType::FLOAT32); + NDArray paddings = NDArrayFactory::create('c', {2, 2}, {1, 2, 2, 3}); + + NDArray exp('c', {3 * blockSize * blockSize, 3, 4, 2}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 12, 13, 14, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 35, 36, 37, 38, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 59, 60, 61, 62, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 83, 84, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 107, 108, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 131, 132, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 147, 148, 149, 150, 0, 0, 0, 0, 155, 156, 157, 158, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 171, 172, 173, 174, 0, 0, + 0, 0, 179, 180, 181, 182, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 195, 196, 197, 198, 0, 0, 0, 0, 203, 204, 205, 206, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 219, 220, 0, 0, 0, 0, + 0, 0, 227, 228, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 243, 244, 0, 0, 0, 0, 0, 0, 251, 252, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 267, 268, 0, 0, 0, 0, + 0, 0, 275, 276, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + sd::DataType::FLOAT32); + + sd::ops::space_to_batch op; + auto result = op.evaluate({&x, &paddings}, {}, {blockSize}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printIndexedBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests5, Test_BatchToSpace_1) { - auto x = NDArrayFactory::create('c', {4, 1, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - - sd::ops::batch_to_space op; - auto result = op.evaluate({&x, &crops}, {}, {2}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create( + 'c', {4, 1, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create( + 'c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - auto z = result.at(0); - // z->printIndexedBuffer(); + sd::ops::batch_to_space op; + auto result = op.evaluate({&x, &crops}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer(); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests5, Test_BatchToSpace_2) { - auto x = NDArrayFactory::create('c', {4, 1, 1, 1}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 2, 2, 1}, {1, 2, 3, 4}); - auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + auto x = NDArrayFactory::create('c', {4, 1, 1, 1}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 2, 2, 1}, {1, 2, 3, 4}); + auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - sd::ops::batch_to_space op; - auto result = op.evaluate({&x, &crops}, {}, {2}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::batch_to_space op; + auto result = op.evaluate({&x, &crops}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(DeclarableOpsTests5, Test_BatchToSpace_3) { - auto x = NDArrayFactory::create('c', {8, 1, 3, 1}, {0, 1, 3, 0, 9, 11, - 0, 2, 4, 0, 10, 12, - 0, 5, 7, 0, 13, 15, - 0, 6, 8, 0, 14, 16}); - auto exp = NDArrayFactory::create('c', {2, 2, 4, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 2, 0}); - - sd::ops::batch_to_space op; - auto result = op.evaluate({&x, &crops}, {}, {2}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create( + 'c', {8, 1, 3, 1}, {0, 1, 3, 0, 9, 11, 0, 2, 4, 0, 10, 12, + 0, 5, 7, 0, 13, 15, 0, 6, 8, 0, 14, 16}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 4, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 2, 0}); - auto z = result.at(0); + sd::ops::batch_to_space op; + auto result = op.evaluate({&x, &crops}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Test_BatchToSpace_4) { + const int blockSize = 2; + NDArray x('c', {3 * blockSize * blockSize, 3, 4, 2}, sd::DataType::FLOAT32); + x.linspace(1, 1); + NDArray crops = NDArrayFactory::create('c', {2, 2}, {1, 2, 2, 3}); - const int blockSize = 2; - NDArray x('c', {3*blockSize*blockSize, 3, 4, 2}, sd::DataType::FLOAT32); - x.linspace(1, 1); - NDArray crops = NDArrayFactory::create('c', {2, 2}, {1, 2, 2, 3}); + NDArray exp( + 'c', {3, 3 * blockSize - 1 - 2, 4 * blockSize - 2 - 3, 2}, + {147, 148, 219, 220, 149, 150, 11, 12, 83, 84, 13, 14, 155, 156, + 227, 228, 157, 158, 171, 172, 243, 244, 173, 174, 35, 36, 107, 108, + 37, 38, 179, 180, 251, 252, 181, 182, 195, 196, 267, 268, 197, 198, + 59, 60, 131, 132, 61, 62, 203, 204, 275, 276, 205, 206}, + sd::DataType::FLOAT32); - NDArray exp('c', {3, 3*blockSize - 1 - 2, 4*blockSize - 2 - 3, 2}, {147, 148, 219, 220, 149, 150, 11, 12, 83, 84, 13, 14, 155, 156, 227, 228, 157, 158, 171, 172, 243, 244, 173, 174, 35, 36, 107, 108, 37, 38, 179, 180, 251, 252, 181, 182, 195, 196, 267, 268, 197, 198, 59, 60, 131, 132, 61, 62, 203, 204, 275, 276, 205, 206}, sd::DataType::FLOAT32); + sd::ops::batch_to_space op; + auto result = op.evaluate({&x, &crops}, {}, {blockSize}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::batch_to_space op; - auto result = op.evaluate({&x, &crops}, {}, {blockSize}); - ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, eye_test1) { + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f}); - auto expected = NDArrayFactory::create('c', {3, 3}, {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f}); + sd::ops::eye op; + auto results = op.evaluate({}, {}, {-99, 3}); + auto output = results.at(0); + // output->printIndexedBuffer(); - sd::ops::eye op; - auto results = op.evaluate({}, {}, {-99, 3}); - auto output = results.at(0); - // output->printIndexedBuffer(); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, eye_test2) { + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); - auto expected = NDArrayFactory::create('c', {3, 4}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); - - sd::ops::eye op; - auto results = op.evaluate({}, {}, {-99, 3, 4}); - auto output = results.at(0); + sd::ops::eye op; + auto results = op.evaluate({}, {}, {-99, 3, 4}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, eye_test3) { + auto expected = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}); - auto expected = NDArrayFactory::create('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}); - - sd::ops::eye op; - auto results = op.evaluate({}, {9 /*int*/}, {-99, 3, 4, 2}); - auto output = results.at(0); - // output->printIndexedBuffer("Output eye"); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::eye op; + auto results = op.evaluate({}, {9 /*int*/}, {-99, 3, 4, 2}); + auto output = results.at(0); + // output->printIndexedBuffer("Output eye"); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, eye_test4) { + auto expected = NDArrayFactory::create( + 'c', {2, 2, 3, 4}, + {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., + 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., + 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.}); - auto expected = NDArrayFactory::create('c', {2, 2, 3, 4}, {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.}); - - sd::ops::eye op; - auto results = op.evaluate({}, {6/*double*/}, {-99, 3, 4, 2, 2}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::eye op; + auto results = op.evaluate({}, {6 /*double*/}, {-99, 3, 4, 2, 2}); + auto output = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, eye_test5) { + sd::ops::eye op; + auto result = op.evaluate({}, {}, {3, 2}); - sd::ops::eye op; - auto result = op.evaluate({},{},{3, 2}); - - auto z = result.at(0); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - + auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test1) { + auto input = NDArrayFactory::create('c', {4, 3, 2}); + input.linspace(1); + auto indices = NDArrayFactory::create('c', {2, 2, 1}, {3, 2, 3, 2}); - auto input = NDArrayFactory::create('c', {4, 3, 2}); - input.linspace(1); - auto indices = NDArrayFactory::create('c', {2,2,1}, {3,2,3,2}); - - auto expected = NDArrayFactory::create('c', {2,2,3,2}, {19, 20, 21, 22, 23, 24, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 13, 14, 15, 16, 17, 18}); + auto expected = NDArrayFactory::create( + 'c', {2, 2, 3, 2}, {19, 20, 21, 22, 23, 24, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 13, 14, 15, 16, 17, 18}); - sd::ops::gather_nd op; - auto results = op.evaluate({&input, &indices}, {}, {}); - auto output = results.at(0); + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test2) { + auto input = NDArrayFactory::create('c', {4, 3, 2}); + input.linspace(1); + auto indices = + NDArrayFactory::create('c', {2, 2, 2}, {3, 2, 1, 2, 0, 1, 0, 1}); - auto input = NDArrayFactory::create('c', {4, 3, 2}); - input.linspace(1); - auto indices = NDArrayFactory::create('c', {2,2,2}, {3,2,1,2, 0,1,0,1}); - - auto expected = NDArrayFactory::create('c', {2,2,2}, {23, 24, 11, 12, 3, 4, 3, 4}); + auto expected = NDArrayFactory::create('c', {2, 2, 2}, + {23, 24, 11, 12, 3, 4, 3, 4}); - sd::ops::gather_nd op; - auto results = op.evaluate({&input, &indices}, {}, {}, {true}); - auto output = results.at(0); + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}, {true}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test3) { + auto input = NDArrayFactory::create('c', {4, 3, 2}); + input.linspace(1); + auto indices = NDArrayFactory::create('c', {3}, {3, 2, 1}); + auto expected = NDArrayFactory::create(24.); - auto input = NDArrayFactory::create('c', {4, 3, 2}); - input.linspace(1); - auto indices = NDArrayFactory::create('c', {3}, {3,2,1}); - auto expected = NDArrayFactory::create(24.); - - sd::ops::gather_nd op; - auto results = op.evaluate({&input, &indices}, {}, {}); - auto output = results.at(0); + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test4) { + auto input = NDArrayFactory::create('c', {4, 3, 2}); + input.linspace(1); + auto indices = NDArrayFactory::create('c', {2, 3}, {3, 2, 1, 0, 2, 1}); + auto expected = NDArrayFactory::create('c', {2}, {24., 6}); - auto input = NDArrayFactory::create('c', {4, 3, 2}); - input.linspace(1); - auto indices = NDArrayFactory::create('c', {2,3}, {3,2,1,0,2,1}); - auto expected = NDArrayFactory::create('c',{2}, {24., 6}); - - sd::ops::gather_nd op; - auto results = op.evaluate({&input, &indices}, {}, {}); - auto output = results.at(0); + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test5) { + auto input = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto indices = NDArrayFactory::create('c', {5, 1}, {3, 2, 0, 1, 1}); + auto expected = NDArrayFactory::create('c', {5}, {4., 3, 1, 2, 2}); - auto input = NDArrayFactory::create('c', {4}, {1,2,3,4}); - auto indices = NDArrayFactory::create('c', {5,1}, {3,2,0,1,1}); - auto expected = NDArrayFactory::create('c',{5}, {4.,3,1,2,2}); + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}); + auto output = results.at(0); - sd::ops::gather_nd op; - auto results = op.evaluate({&input, &indices}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test6) { + auto input = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + std::vector shape = {1}; + auto indices = NDArrayFactory::create('c', shape, {2}); + auto expected = NDArrayFactory::create(3.); - auto input = NDArrayFactory::create('c', {4}, {1,2,3,4}); - std::vector shape = {1}; - auto indices = NDArrayFactory::create('c', shape, {2}); - auto expected = NDArrayFactory::create(3.); - - sd::ops::gather_nd op; - auto results = op.evaluate({&input, &indices}, {}, {}); - auto output = results.at(0); + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test7) { + auto input = NDArrayFactory::create('c', {4, 4}); + input.linspace(1); + auto indices = NDArrayFactory::create( + 'c', {3, 3, 2}, {0, 2, 1, 0, 1, 0, 1, 3, 1, 0, 2, 1, 0, 1, 0, 1, 3, 1}); + auto expected = NDArrayFactory::create('c', {3, 3}, + {3, 5, 5, 8, 5, 10, 2, 2, 14}); - auto input = NDArrayFactory::create('c', {4, 4}); - input.linspace(1); - auto indices = NDArrayFactory::create('c', {3,3,2}, {0,2,1, 0,1,0, 1,3,1, 0,2,1, 0,1,0, 1,3,1}); - auto expected = NDArrayFactory::create('c', {3,3}, {3,5,5,8,5,10,2,2,14}); - - sd::ops::gather_nd op; - auto results = op.evaluate({&input, &indices}, {}, {}, {true}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}, {true}); + auto output = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test8) { - auto x = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); - auto y = NDArrayFactory::create('c', {2, 2}, {0, 0, 1, 1}); - auto e = NDArrayFactory::create('c', {2}, {1., 4.}); - - sd::ops::gather_nd op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); + auto y = NDArrayFactory::create('c', {2, 2}, {0, 0, 1, 1}); + auto e = NDArrayFactory::create('c', {2}, {1., 4.}); - auto z = result.at(0); + sd::ops::gather_nd op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(e, z); + auto z = result.at(0); - + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests5, gatherNd_test9) { - auto x = NDArrayFactory::create('c', {2, 4, 2, 2}); - auto indices = NDArrayFactory::create('c', {3, 3}, {0,2,1, 0,1,0, 1,3,1}); - auto exp = NDArrayFactory::create('c', {3,2}, {11.f, 12.f, 5.f, 6.f, 31.f, 32.f}); - x.linspace(1); - - sd::ops::gather_nd op; - auto result = op.evaluate({&x, &indices}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {2, 4, 2, 2}); + auto indices = + NDArrayFactory::create('c', {3, 3}, {0, 2, 1, 0, 1, 0, 1, 3, 1}); + auto exp = NDArrayFactory::create('c', {3, 2}, + {11.f, 12.f, 5.f, 6.f, 31.f, 32.f}); + x.linspace(1); - auto z = result.at(0); + sd::ops::gather_nd op; + auto result = op.evaluate({&x, &indices}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - //z->printIndexedBuffer(); - //z->printShapeInfo("z shape"); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + // z->printIndexedBuffer(); + // z->printShapeInfo("z shape"); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test10) { + auto input = NDArrayFactory::create('c', {4, 3, 2}); + auto indices = + NDArrayFactory::create('c', {2, 2, 2}, {30, 20, 1, 2, 0, 10, 0, 1}); - auto input = NDArrayFactory::create('c', {4, 3, 2}); - auto indices = NDArrayFactory::create('c', {2,2,2}, {30,20,1,2, 0,10,0,1}); + auto output = NDArrayFactory::create('c', {2, 2, 2}); - auto output = NDArrayFactory::create('c', {2,2,2}); + sd::ops::gather_nd op; - sd::ops::gather_nd op; - - ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true})); + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true})); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, gatherNd_test11) { + auto input = NDArrayFactory::create('c', {4, 4}); + auto indices = NDArrayFactory::create( + 'c', {3, 3, 2}, + {0, 2, 1, 0, 10, 0, 1, 30, 1, 0, 20, 1, 0, 1, 0, 1, 30, 1}); + auto output = NDArrayFactory::create('c', {3, 3}); - auto input = NDArrayFactory::create('c', {4, 4}); - auto indices = NDArrayFactory::create('c', {3,3,2}, {0,2,1, 0,10,0, 1,30,1, 0,20,1, 0,1,0, 1,30,1}); - auto output = NDArrayFactory::create('c', {3,3}); - - sd::ops::gather_nd op; + sd::ops::gather_nd op; - ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true})); + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true})); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test1) { + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {4}, {4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {4, 3, 2, 1, 5, 9, 8, 7, 6, 10, 14, 13, 12, 11, 15, + 19, 18, 17, 16, 20, 24, 23, 22, 21, 25, 29, 28, 27, 26, 30, + 34, 33, 32, 31, 35, 39, 38, 37, 36, 40, 44, 43, 42, 41, 45, + 49, 48, 47, 46, 50, 54, 53, 52, 51, 55, 59, 58, 57, 56, 60}); - auto input = NDArrayFactory::create('c', {3, 4, 5}); - input.linspace(1); - auto seqLengths = NDArrayFactory::create('c', {4}, {4,4,4,4}); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {4, 3, 2, 1, 5, 9, 8, 7, 6, 10, 14, 13, 12, 11, 15, 19, 18, 17, 16, 20, 24, 23, 22, 21, 25, 29, 28, 27, 26, 30, 34, 33, 32, 31, 35, 39, 38, 37, 36, 40, 44, 43, 42, 41, 45, 49, 48, 47, 46, 50, 54, 53, 52, 51, 55, 59, 58, 57, 56, 60}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {2, 1}); - ASSERT_EQ(Status::OK(), results.status()); - - auto output = results.at(0); + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {2, 1}); + ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto output = results.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test2) { + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 11, 13, 14, 15, + 18, 17, 16, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 32, 31, 33, 34, 35, 38, 37, 36, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 52, 51, 53, 54, 55, 58, 57, 56, 59, 60}); - auto input = NDArrayFactory::create('c', {3, 4, 5}); - input.linspace(1); - auto seqLengths = NDArrayFactory::create('c', {4}, {0,1,2,3}); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 11, 13, 14, 15, 18, 17, 16, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 31, 33, 34, 35, 38, 37, 36, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 52, 51, 53, 54, 55, 58, 57, 56, 59, 60}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {2, 1}); - auto output = results.at(0); + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {2, 1}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test3) { + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {3}, {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {2, 1, 3, 4, 5, 7, 6, 8, 9, 10, 12, 11, 13, 14, 15, + 17, 16, 18, 19, 20, 23, 22, 21, 24, 25, 28, 27, 26, 29, 30, + 33, 32, 31, 34, 35, 38, 37, 36, 39, 40, 44, 43, 42, 41, 45, + 49, 48, 47, 46, 50, 54, 53, 52, 51, 55, 59, 58, 57, 56, 60}); - auto input = NDArrayFactory::create('c', {3, 4, 5}); - input.linspace(1); - auto seqLengths = NDArrayFactory::create('c', {3}, {2,3,4}); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {2, 1, 3, 4, 5, 7, 6, 8, 9, 10, 12, 11, 13, 14, 15, 17, 16, 18, 19, 20, 23, 22, 21, 24, 25, 28, 27, 26, 29, 30, 33, 32, 31, 34, 35, 38, 37, 36, 39, 40, 44, 43, 42, 41, 45, 49, 48, 47, 46, 50, 54, 53, 52, 51, 55, 59, 58, 57, 56, 60}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {2, 0}); - auto output = results.at(0); + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {2, 0}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test4) { + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {5}, {1, 2, 1, 2, 3}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1, 22, 3, 24, 45, 6, 27, 8, 29, 50, 11, 32, 13, 34, 55, + 16, 37, 18, 39, 60, 21, 2, 23, 4, 25, 26, 7, 28, 9, 30, + 31, 12, 33, 14, 35, 36, 17, 38, 19, 40, 41, 42, 43, 44, 5, + 46, 47, 48, 49, 10, 51, 52, 53, 54, 15, 56, 57, 58, 59, 20}); - auto input = NDArrayFactory::create('c', {3, 4, 5}); - input.linspace(1); - auto seqLengths = NDArrayFactory::create('c', {5}, {1, 2, 1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 22, 3, 24, 45, 6, 27, 8, 29, 50, 11, 32, 13, 34, 55, 16, 37, 18, 39, 60, 21, 2, 23, 4, 25, 26, 7, 28, 9, 30, 31, 12, 33, 14, 35, 36, 17, 38, 19, 40, 41, 42, 43, 44, 5, 46, 47, 48, 49, 10, 51, 52, 53, 54, 15, 56, 57, 58, 59, 20}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {0, 2}); - auto output = results.at(0); + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {0, 2}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test5) { + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {5}, {1, 2, 4, 2, 3}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1, 7, 18, 9, 15, 6, 2, 13, 4, 10, 11, 12, 8, 14, 5, + 16, 17, 3, 19, 20, 21, 27, 38, 29, 35, 26, 22, 33, 24, 30, + 31, 32, 28, 34, 25, 36, 37, 23, 39, 40, 41, 47, 58, 49, 55, + 46, 42, 53, 44, 50, 51, 52, 48, 54, 45, 56, 57, 43, 59, 60}); - auto input = NDArrayFactory::create('c', {3, 4, 5}); - input.linspace(1); - auto seqLengths = NDArrayFactory::create('c', {5}, {1, 2, 4, 2, 3}); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 7, 18, 9, 15, 6, 2, 13, 4, 10, 11, 12, 8, 14, 5, 16, 17, 3, 19, 20, 21, 27, 38, 29, 35, 26, 22, 33, 24, 30, 31, 32, 28, 34, 25, 36, 37, 23, 39, 40, 41, 47, 58, 49, 55, 46, 42, 53, 44, 50, 51, 52, 48, 54, 45, 56, 57, 43, 59, 60}); + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {1, 2}); + auto output = results.at(0); - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {1, 2}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test6) { + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {4}, {1, 2, 3, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1, 2, 3, 4, 5, 26, 27, 28, 29, 30, 51, 52, 53, 54, 55, + 36, 37, 38, 39, 40, 21, 22, 23, 24, 25, 6, 7, 8, 9, 10, + 31, 32, 33, 34, 35, 16, 17, 18, 19, 20, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 11, 12, 13, 14, 15, 56, 57, 58, 59, 60}); - auto input = NDArrayFactory::create('c', {3, 4, 5}); - input.linspace(1); - auto seqLengths = NDArrayFactory::create('c', {4}, {1, 2, 3, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 2, 3, 4, 5, 26, 27, 28, 29, 30, 51, 52, 53, 54, 55, 36, 37, 38, 39, 40, 21, 22, 23, 24, 25, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 16, 17, 18, 19, 20, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 11, 12, 13, 14, 15, 56, 57, 58, 59, 60}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1}); - auto output = results.at(0); + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test7) { + auto input = NDArrayFactory::create('c', {1, 5}); + input.linspace(1); + std::vector data = {3}; + auto seqLengths = NDArrayFactory::create('c', {1}, data); + auto exp = NDArrayFactory::create('c', {1, 5}, {3, 2, 1, 4, 5}); - auto input = NDArrayFactory::create('c', {1, 5}); - input.linspace(1); - std::vector data = {3}; - auto seqLengths = NDArrayFactory::create('c', {1}, data); - auto exp = NDArrayFactory::create('c', {1, 5}, {3, 2, 1, 4, 5}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {1, 0}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {1, 0}); + auto output = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test8) { + auto input = NDArrayFactory::create('c', {1, 5}); + input.linspace(1); + std::vector data = {1, 0, 1, 0, 1}; + auto seqLengths = NDArrayFactory::create('c', {5}, data); + auto exp = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto input = NDArrayFactory::create('c', {1, 5}); - input.linspace(1); - std::vector data = {1,0,1,0,1}; - auto seqLengths = NDArrayFactory::create('c', {5}, data); - auto exp = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1}); + auto output = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test9) { + auto input = NDArrayFactory::create('c', {5, 1}); + input.linspace(1); + std::vector data = {1, 0, 1, 0, 1}; + auto seqLengths = NDArrayFactory::create('c', {5}, data); + auto exp = NDArrayFactory::create('c', {5, 1}, {1, 2, 3, 4, 5}); - auto input = NDArrayFactory::create('c', {5, 1}); - input.linspace(1); - std::vector data = {1,0,1,0,1}; - auto seqLengths = NDArrayFactory::create('c', {5}, data); - auto exp = NDArrayFactory::create('c', {5, 1}, {1, 2, 3, 4, 5}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {1, 0}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {1, 0}); + auto output = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test10) { + auto input = NDArrayFactory::create('c', {5, 1}); + input.linspace(1); + std::vector data = {3}; + auto seqLengths = NDArrayFactory::create('c', {1}, data); + auto exp = NDArrayFactory::create('c', {5, 1}, {3, 2, 1, 4, 5}); - auto input = NDArrayFactory::create('c', {5, 1}); - input.linspace(1); - std::vector data = {3}; - auto seqLengths = NDArrayFactory::create('c', {1}, data); - auto exp = NDArrayFactory::create('c', {5, 1}, {3, 2, 1, 4, 5}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1}); + auto output = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test11) { + auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); + input.linspace(1); + std::vector data = {1, 0, 1, 0, 1}; + auto seqLengths = NDArrayFactory::create('c', {5}, data); + auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {1, 2, 3, 4, 5}); - auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); - input.linspace(1); - std::vector data = {1, 0, 1, 0, 1}; - auto seqLengths = NDArrayFactory::create('c', {5}, data); - auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {1, 2, 3, 4, 5}); + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {1, 2}); + auto output = results.at(0); - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {1, 2}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test12) { + auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); + input.linspace(1); + std::vector data = {3}; + auto seqLengths = NDArrayFactory::create('c', {1}, data); + auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {3, 2, 1, 4, 5}); - auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); - input.linspace(1); - std::vector data = {3}; - auto seqLengths = NDArrayFactory::create('c', {1}, data); - auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {3, 2, 1, 4, 5}); - - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {2, 0}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {2, 0}); + auto output = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test13) { + auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); + input.linspace(1); + std::vector data = {1}; + auto seqLengths = NDArrayFactory::create('c', {1}, data); + auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {1, 2, 3, 4, 5}); - auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); - input.linspace(1); - std::vector data = {1}; - auto seqLengths = NDArrayFactory::create('c', {1}, data); - auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {1, 2, 3, 4, 5}); + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {3, 0}); + auto output = results.at(0); - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &seqLengths}, {}, {3, 0}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test14) { - auto input = NDArrayFactory::create('c', {8, 8, 3, 2}, {0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743, 0.54555996, 0.23407607, 0.11372584, 0.49965927, 0.15210842, 0.53268608, 0.38700677, 0.68832738, 0.37292716, 0.94616004, 0.77735792, 0.60803430, 0.61523204, 0.64298760, 0.26848351, 0.75015615, 0.28683049, 0.70937606, 0.06478678, 0.68985848, 0.55216783, 0.55382648, 0.34652863, 0.17261296, 0.54193264, 0.05176904, 0.82555761, 0.71106697, 0.04416722, 0.07653656, 0.01034390, 0.99430482, 0.59944390, 0.17973880, 0.36437840, 0.86383673, 0.45025550, 0.97136977, 0.13565978, 0.71567448, 0.92094825, 0.93536442, 0.93630291, 0.67277404, 0.93899264, 0.52422773, 0.44892176, 0.03127759, 0.85910449, 0.18252879, 0.72830945, 0.96736828, 0.89831575, 0.83437150, 0.59050780, 0.36145925, 0.16483070, 0.44021176, 0.76018652, 0.44227383, 0.13052339, 0.18204235, 0.99743733, 0.26885190, 0.87726522, 0.16396056, 0.94943412, 0.40016700, 0.65267938, 0.71073267, 0.40094733, 0.91182634, 0.05391789, 0.49520416, 0.24963864, 0.34847086, 0.74088617, 0.36115701, 0.63074210, 0.97423085, 0.42216846, 0.06326975, 0.07858702, 0.20586622, 0.28752144, 0.38146961, 0.83518735, 0.08207577, 0.82083487, 0.81665728, 0.33309570, 0.67563176, 0.98343578, 0.95919930, 0.66994391, 0.89296165, 0.34755773, 0.63166554, 0.18849320, 0.34828456, 0.98477707, 0.75163124, 0.83306004, 0.14203056, 0.01497920, 0.85727447, 0.71194544, 0.85654019, 0.86160433, 0.79580411, 0.47710411, 0.09318029, 0.31369071, 0.64122249, 0.58399725, 0.26706597, 0.05655339, 0.91025211, 0.30330468, 0.33142930, 0.05668627, 0.02936449, 0.12613087, 0.09960114, 0.16218074, 0.15088139, 0.31239040, 0.55980062, 0.34804391, 0.34941538, 0.61370555, 0.07022964, 0.59757058, 0.31189846, 0.25215345, 0.52546591, 0.55744218, 0.59485650, 0.60553664, 0.07536713, 0.55971796, 0.38764845, 0.20737843, 0.37989120, 0.18361641, 0.48636240, 0.06052657, 0.04241913, 0.66710351, 0.07007925, 0.59371493, 0.74479056, 0.84699625, 0.51210368, 0.12489571, 0.23371067, 0.27274571, 0.83306066, 0.75830824, 0.25963478, 0.87137718, 0.24418835, 0.05032742, 0.52076188, 0.47762345, 0.89829370, 0.34417708, 0.84705151, 0.08203183, 0.10632956, 0.78431292, 0.86441722, 0.36487598, 0.09833603, 0.85863594, 0.11010505, 0.11659283, 0.42500288, 0.02747301, 0.12359903, 0.01753431, 0.41160932, 0.47245979, 0.08268172, 0.21580773, 0.75770279, 0.19736489, 0.44461885, 0.33341706, 0.22519571, 0.31528710, 0.14802902, 0.64171939, 0.52643769, 0.19261234, 0.98032835, 0.15401656, 0.85274458, 0.66408502, 0.23212704, 0.74630026, 0.05713613, 0.49025892, 0.48418810, 0.59541513, 0.09243053, 0.93919152, 0.95357019, 0.52377729, 0.65963871, 0.47934951, 0.49919534, 0.34369898, 0.78211256, 0.13908708, 0.95754117, 0.84107746, 0.09126213, 0.42979124, 0.10295325, 0.34631257, 0.69448345, 0.41720536, 0.15282440, 0.74329854, 0.45775009, 0.12786280, 0.39830299, 0.20386769, 0.59703523, 0.94077086, 0.42255597, 0.80453309, 0.79757204, 0.28653229, 0.60175909, 0.55859623, 0.34318230, 0.63002770, 0.36533324, 0.89689906, 0.73236186, 0.61491989, 0.83787947, 0.67939463, 0.72016694, 0.77499849, 0.72428343, 0.34571059, 0.23143007, 0.20099338, 0.85583142, 0.73174191, 0.54284092, 0.20264181, 0.53037061, 0.30493131, 0.82279766, 0.58542432, 0.72632070, 0.18394258, 0.00608118, 0.23808232, 0.17007573, 0.75245459, 0.84990616, 0.38827634, 0.33809538, 0.01080317, 0.27250145, 0.81769542, 0.15323253, 0.71668395, 0.99427044, 0.11355576, 0.50511923, 0.60248266, 0.36610154, 0.99123140, 0.10519719, 0.18754650, 0.43232584, 0.25247084, 0.47968157, 0.88649124, 0.33588961, 0.92338319, 0.18808573, 0.79433656, 0.12074559, 0.02325163, 0.10117917, 0.83559239, 0.67213900, 0.67265260, 0.11917707, 0.76574855, 0.43842117, 0.28530411, 0.79648090, 0.47939640, 0.73564612, 0.41465671, 0.10995635, 0.20271728, 0.00521771, 0.22952055, 0.78271870, 0.12833592, 0.88639055, 0.76398188, 0.49533508, 0.85447872, 0.15937568, 0.92947480, 0.62705964, 0.85960084, 0.13435660, 0.81845809, 0.60715133, 0.83030708, 0.83071910, 0.38883408, 0.92033237, 0.46066239, 0.48806761, 0.50688779, 0.00654483, 0.32076493, 0.42367646, 0.07381865, 0.22801110, 0.26669388, 0.99691302, 0.12113623, 0.34373057, 0.98977921, 0.96225332, 0.90143562, 0.19559914, 0.08978307, 0.09687492, 0.59820890, 0.75527947, 0.67683355, 0.21847023, 0.29395619, 0.50477953, 0.07112842, 0.54090558, 0.68230725, 0.49713828, 0.41958965, 0.68013847, 0.47691765, 0.63269259, 0.94304095, 0.54587271, 0.72447569, 0.28913523, 0.75766936, 0.52965692, 0.96854824, 0.15589071, 0.84128672, 0.16337522, 0.05771034, 0.21556356, 0.12094140, 0.29721207, 0.00811008, 0.66184926}); - auto lengths = NDArrayFactory::create('c', {8}, {7, 2, 3, 5, 2, 1, 6, 4}); - auto e = NDArrayFactory::create('c', {8, 8, 3, 2}, {0.54193264, 0.05176904, 0.82555761, 0.71106697, 0.04416722, 0.07653656, 0.06478678, 0.68985848, 0.55216783, 0.55382648, 0.34652863, 0.17261296, 0.61523204, 0.64298760, 0.26848351, 0.75015615, 0.28683049, 0.70937606, 0.38700677, 0.68832738, 0.37292716, 0.94616004, 0.77735792, 0.60803430, 0.54555996, 0.23407607, 0.11372584, 0.49965927, 0.15210842, 0.53268608, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743, 0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.01034390, 0.99430482, 0.59944390, 0.17973880, 0.36437840, 0.86383673, 0.93630291, 0.67277404, 0.93899264, 0.52422773, 0.44892176, 0.03127759, 0.45025550, 0.97136977, 0.13565978, 0.71567448, 0.92094825, 0.93536442, 0.85910449, 0.18252879, 0.72830945, 0.96736828, 0.89831575, 0.83437150, 0.59050780, 0.36145925, 0.16483070, 0.44021176, 0.76018652, 0.44227383, 0.13052339, 0.18204235, 0.99743733, 0.26885190, 0.87726522, 0.16396056, 0.94943412, 0.40016700, 0.65267938, 0.71073267, 0.40094733, 0.91182634, 0.05391789, 0.49520416, 0.24963864, 0.34847086, 0.74088617, 0.36115701, 0.63074210, 0.97423085, 0.42216846, 0.06326975, 0.07858702, 0.20586622, 0.34755773, 0.63166554, 0.18849320, 0.34828456, 0.98477707, 0.75163124, 0.33309570, 0.67563176, 0.98343578, 0.95919930, 0.66994391, 0.89296165, 0.28752144, 0.38146961, 0.83518735, 0.08207577, 0.82083487, 0.81665728, 0.83306004, 0.14203056, 0.01497920, 0.85727447, 0.71194544, 0.85654019, 0.86160433, 0.79580411, 0.47710411, 0.09318029, 0.31369071, 0.64122249, 0.58399725, 0.26706597, 0.05655339, 0.91025211, 0.30330468, 0.33142930, 0.05668627, 0.02936449, 0.12613087, 0.09960114, 0.16218074, 0.15088139, 0.31239040, 0.55980062, 0.34804391, 0.34941538, 0.61370555, 0.07022964, 0.27274571, 0.83306066, 0.75830824, 0.25963478, 0.87137718, 0.24418835, 0.59371493, 0.74479056, 0.84699625, 0.51210368, 0.12489571, 0.23371067, 0.18361641, 0.48636240, 0.06052657, 0.04241913, 0.66710351, 0.07007925, 0.60553664, 0.07536713, 0.55971796, 0.38764845, 0.20737843, 0.37989120, 0.59757058, 0.31189846, 0.25215345, 0.52546591, 0.55744218, 0.59485650, 0.05032742, 0.52076188, 0.47762345, 0.89829370, 0.34417708, 0.84705151, 0.08203183, 0.10632956, 0.78431292, 0.86441722, 0.36487598, 0.09833603, 0.85863594, 0.11010505, 0.11659283, 0.42500288, 0.02747301, 0.12359903, 0.19736489, 0.44461885, 0.33341706, 0.22519571, 0.31528710, 0.14802902, 0.01753431, 0.41160932, 0.47245979, 0.08268172, 0.21580773, 0.75770279, 0.64171939, 0.52643769, 0.19261234, 0.98032835, 0.15401656, 0.85274458, 0.66408502, 0.23212704, 0.74630026, 0.05713613, 0.49025892, 0.48418810, 0.59541513, 0.09243053, 0.93919152, 0.95357019, 0.52377729, 0.65963871, 0.47934951, 0.49919534, 0.34369898, 0.78211256, 0.13908708, 0.95754117, 0.84107746, 0.09126213, 0.42979124, 0.10295325, 0.34631257, 0.69448345, 0.41720536, 0.15282440, 0.74329854, 0.45775009, 0.12786280, 0.39830299, 0.20386769, 0.59703523, 0.94077086, 0.42255597, 0.80453309, 0.79757204, 0.28653229, 0.60175909, 0.55859623, 0.34318230, 0.63002770, 0.36533324, 0.89689906, 0.73236186, 0.61491989, 0.83787947, 0.67939463, 0.72016694, 0.77499849, 0.72428343, 0.34571059, 0.23143007, 0.20099338, 0.85583142, 0.73174191, 0.54284092, 0.20264181, 0.53037061, 0.30493131, 0.82279766, 0.58542432, 0.72632070, 0.18394258, 0.00608118, 0.23808232, 0.17007573, 0.75245459, 0.84990616, 0.38827634, 0.33809538, 0.01080317, 0.27250145, 0.81769542, 0.15323253, 0.71668395, 0.99427044, 0.11355576, 0.50511923, 0.22952055, 0.78271870, 0.12833592, 0.88639055, 0.76398188, 0.49533508, 0.47939640, 0.73564612, 0.41465671, 0.10995635, 0.20271728, 0.00521771, 0.67265260, 0.11917707, 0.76574855, 0.43842117, 0.28530411, 0.79648090, 0.79433656, 0.12074559, 0.02325163, 0.10117917, 0.83559239, 0.67213900, 0.25247084, 0.47968157, 0.88649124, 0.33588961, 0.92338319, 0.18808573, 0.60248266, 0.36610154, 0.99123140, 0.10519719, 0.18754650, 0.43232584, 0.85447872, 0.15937568, 0.92947480, 0.62705964, 0.85960084, 0.13435660, 0.81845809, 0.60715133, 0.83030708, 0.83071910, 0.38883408, 0.92033237, 0.59820890, 0.75527947, 0.67683355, 0.21847023, 0.29395619, 0.50477953, 0.98977921, 0.96225332, 0.90143562, 0.19559914, 0.08978307, 0.09687492, 0.07381865, 0.22801110, 0.26669388, 0.99691302, 0.12113623, 0.34373057, 0.46066239, 0.48806761, 0.50688779, 0.00654483, 0.32076493, 0.42367646, 0.07112842, 0.54090558, 0.68230725, 0.49713828, 0.41958965, 0.68013847, 0.47691765, 0.63269259, 0.94304095, 0.54587271, 0.72447569, 0.28913523, 0.75766936, 0.52965692, 0.96854824, 0.15589071, 0.84128672, 0.16337522, 0.05771034, 0.21556356, 0.12094140, 0.29721207, 0.00811008, 0.66184926}); + auto input = NDArrayFactory::create( + 'c', {8, 8, 3, 2}, + {0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, + 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743, + 0.54555996, 0.23407607, 0.11372584, 0.49965927, 0.15210842, 0.53268608, + 0.38700677, 0.68832738, 0.37292716, 0.94616004, 0.77735792, 0.60803430, + 0.61523204, 0.64298760, 0.26848351, 0.75015615, 0.28683049, 0.70937606, + 0.06478678, 0.68985848, 0.55216783, 0.55382648, 0.34652863, 0.17261296, + 0.54193264, 0.05176904, 0.82555761, 0.71106697, 0.04416722, 0.07653656, + 0.01034390, 0.99430482, 0.59944390, 0.17973880, 0.36437840, 0.86383673, + 0.45025550, 0.97136977, 0.13565978, 0.71567448, 0.92094825, 0.93536442, + 0.93630291, 0.67277404, 0.93899264, 0.52422773, 0.44892176, 0.03127759, + 0.85910449, 0.18252879, 0.72830945, 0.96736828, 0.89831575, 0.83437150, + 0.59050780, 0.36145925, 0.16483070, 0.44021176, 0.76018652, 0.44227383, + 0.13052339, 0.18204235, 0.99743733, 0.26885190, 0.87726522, 0.16396056, + 0.94943412, 0.40016700, 0.65267938, 0.71073267, 0.40094733, 0.91182634, + 0.05391789, 0.49520416, 0.24963864, 0.34847086, 0.74088617, 0.36115701, + 0.63074210, 0.97423085, 0.42216846, 0.06326975, 0.07858702, 0.20586622, + 0.28752144, 0.38146961, 0.83518735, 0.08207577, 0.82083487, 0.81665728, + 0.33309570, 0.67563176, 0.98343578, 0.95919930, 0.66994391, 0.89296165, + 0.34755773, 0.63166554, 0.18849320, 0.34828456, 0.98477707, 0.75163124, + 0.83306004, 0.14203056, 0.01497920, 0.85727447, 0.71194544, 0.85654019, + 0.86160433, 0.79580411, 0.47710411, 0.09318029, 0.31369071, 0.64122249, + 0.58399725, 0.26706597, 0.05655339, 0.91025211, 0.30330468, 0.33142930, + 0.05668627, 0.02936449, 0.12613087, 0.09960114, 0.16218074, 0.15088139, + 0.31239040, 0.55980062, 0.34804391, 0.34941538, 0.61370555, 0.07022964, + 0.59757058, 0.31189846, 0.25215345, 0.52546591, 0.55744218, 0.59485650, + 0.60553664, 0.07536713, 0.55971796, 0.38764845, 0.20737843, 0.37989120, + 0.18361641, 0.48636240, 0.06052657, 0.04241913, 0.66710351, 0.07007925, + 0.59371493, 0.74479056, 0.84699625, 0.51210368, 0.12489571, 0.23371067, + 0.27274571, 0.83306066, 0.75830824, 0.25963478, 0.87137718, 0.24418835, + 0.05032742, 0.52076188, 0.47762345, 0.89829370, 0.34417708, 0.84705151, + 0.08203183, 0.10632956, 0.78431292, 0.86441722, 0.36487598, 0.09833603, + 0.85863594, 0.11010505, 0.11659283, 0.42500288, 0.02747301, 0.12359903, + 0.01753431, 0.41160932, 0.47245979, 0.08268172, 0.21580773, 0.75770279, + 0.19736489, 0.44461885, 0.33341706, 0.22519571, 0.31528710, 0.14802902, + 0.64171939, 0.52643769, 0.19261234, 0.98032835, 0.15401656, 0.85274458, + 0.66408502, 0.23212704, 0.74630026, 0.05713613, 0.49025892, 0.48418810, + 0.59541513, 0.09243053, 0.93919152, 0.95357019, 0.52377729, 0.65963871, + 0.47934951, 0.49919534, 0.34369898, 0.78211256, 0.13908708, 0.95754117, + 0.84107746, 0.09126213, 0.42979124, 0.10295325, 0.34631257, 0.69448345, + 0.41720536, 0.15282440, 0.74329854, 0.45775009, 0.12786280, 0.39830299, + 0.20386769, 0.59703523, 0.94077086, 0.42255597, 0.80453309, 0.79757204, + 0.28653229, 0.60175909, 0.55859623, 0.34318230, 0.63002770, 0.36533324, + 0.89689906, 0.73236186, 0.61491989, 0.83787947, 0.67939463, 0.72016694, + 0.77499849, 0.72428343, 0.34571059, 0.23143007, 0.20099338, 0.85583142, + 0.73174191, 0.54284092, 0.20264181, 0.53037061, 0.30493131, 0.82279766, + 0.58542432, 0.72632070, 0.18394258, 0.00608118, 0.23808232, 0.17007573, + 0.75245459, 0.84990616, 0.38827634, 0.33809538, 0.01080317, 0.27250145, + 0.81769542, 0.15323253, 0.71668395, 0.99427044, 0.11355576, 0.50511923, + 0.60248266, 0.36610154, 0.99123140, 0.10519719, 0.18754650, 0.43232584, + 0.25247084, 0.47968157, 0.88649124, 0.33588961, 0.92338319, 0.18808573, + 0.79433656, 0.12074559, 0.02325163, 0.10117917, 0.83559239, 0.67213900, + 0.67265260, 0.11917707, 0.76574855, 0.43842117, 0.28530411, 0.79648090, + 0.47939640, 0.73564612, 0.41465671, 0.10995635, 0.20271728, 0.00521771, + 0.22952055, 0.78271870, 0.12833592, 0.88639055, 0.76398188, 0.49533508, + 0.85447872, 0.15937568, 0.92947480, 0.62705964, 0.85960084, 0.13435660, + 0.81845809, 0.60715133, 0.83030708, 0.83071910, 0.38883408, 0.92033237, + 0.46066239, 0.48806761, 0.50688779, 0.00654483, 0.32076493, 0.42367646, + 0.07381865, 0.22801110, 0.26669388, 0.99691302, 0.12113623, 0.34373057, + 0.98977921, 0.96225332, 0.90143562, 0.19559914, 0.08978307, 0.09687492, + 0.59820890, 0.75527947, 0.67683355, 0.21847023, 0.29395619, 0.50477953, + 0.07112842, 0.54090558, 0.68230725, 0.49713828, 0.41958965, 0.68013847, + 0.47691765, 0.63269259, 0.94304095, 0.54587271, 0.72447569, 0.28913523, + 0.75766936, 0.52965692, 0.96854824, 0.15589071, 0.84128672, 0.16337522, + 0.05771034, 0.21556356, 0.12094140, 0.29721207, 0.00811008, 0.66184926}); + auto lengths = + NDArrayFactory::create('c', {8}, {7, 2, 3, 5, 2, 1, 6, 4}); + auto e = NDArrayFactory::create( + 'c', {8, 8, 3, 2}, + {0.54193264, 0.05176904, 0.82555761, 0.71106697, 0.04416722, 0.07653656, + 0.06478678, 0.68985848, 0.55216783, 0.55382648, 0.34652863, 0.17261296, + 0.61523204, 0.64298760, 0.26848351, 0.75015615, 0.28683049, 0.70937606, + 0.38700677, 0.68832738, 0.37292716, 0.94616004, 0.77735792, 0.60803430, + 0.54555996, 0.23407607, 0.11372584, 0.49965927, 0.15210842, 0.53268608, + 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743, + 0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, + 0.01034390, 0.99430482, 0.59944390, 0.17973880, 0.36437840, 0.86383673, + 0.93630291, 0.67277404, 0.93899264, 0.52422773, 0.44892176, 0.03127759, + 0.45025550, 0.97136977, 0.13565978, 0.71567448, 0.92094825, 0.93536442, + 0.85910449, 0.18252879, 0.72830945, 0.96736828, 0.89831575, 0.83437150, + 0.59050780, 0.36145925, 0.16483070, 0.44021176, 0.76018652, 0.44227383, + 0.13052339, 0.18204235, 0.99743733, 0.26885190, 0.87726522, 0.16396056, + 0.94943412, 0.40016700, 0.65267938, 0.71073267, 0.40094733, 0.91182634, + 0.05391789, 0.49520416, 0.24963864, 0.34847086, 0.74088617, 0.36115701, + 0.63074210, 0.97423085, 0.42216846, 0.06326975, 0.07858702, 0.20586622, + 0.34755773, 0.63166554, 0.18849320, 0.34828456, 0.98477707, 0.75163124, + 0.33309570, 0.67563176, 0.98343578, 0.95919930, 0.66994391, 0.89296165, + 0.28752144, 0.38146961, 0.83518735, 0.08207577, 0.82083487, 0.81665728, + 0.83306004, 0.14203056, 0.01497920, 0.85727447, 0.71194544, 0.85654019, + 0.86160433, 0.79580411, 0.47710411, 0.09318029, 0.31369071, 0.64122249, + 0.58399725, 0.26706597, 0.05655339, 0.91025211, 0.30330468, 0.33142930, + 0.05668627, 0.02936449, 0.12613087, 0.09960114, 0.16218074, 0.15088139, + 0.31239040, 0.55980062, 0.34804391, 0.34941538, 0.61370555, 0.07022964, + 0.27274571, 0.83306066, 0.75830824, 0.25963478, 0.87137718, 0.24418835, + 0.59371493, 0.74479056, 0.84699625, 0.51210368, 0.12489571, 0.23371067, + 0.18361641, 0.48636240, 0.06052657, 0.04241913, 0.66710351, 0.07007925, + 0.60553664, 0.07536713, 0.55971796, 0.38764845, 0.20737843, 0.37989120, + 0.59757058, 0.31189846, 0.25215345, 0.52546591, 0.55744218, 0.59485650, + 0.05032742, 0.52076188, 0.47762345, 0.89829370, 0.34417708, 0.84705151, + 0.08203183, 0.10632956, 0.78431292, 0.86441722, 0.36487598, 0.09833603, + 0.85863594, 0.11010505, 0.11659283, 0.42500288, 0.02747301, 0.12359903, + 0.19736489, 0.44461885, 0.33341706, 0.22519571, 0.31528710, 0.14802902, + 0.01753431, 0.41160932, 0.47245979, 0.08268172, 0.21580773, 0.75770279, + 0.64171939, 0.52643769, 0.19261234, 0.98032835, 0.15401656, 0.85274458, + 0.66408502, 0.23212704, 0.74630026, 0.05713613, 0.49025892, 0.48418810, + 0.59541513, 0.09243053, 0.93919152, 0.95357019, 0.52377729, 0.65963871, + 0.47934951, 0.49919534, 0.34369898, 0.78211256, 0.13908708, 0.95754117, + 0.84107746, 0.09126213, 0.42979124, 0.10295325, 0.34631257, 0.69448345, + 0.41720536, 0.15282440, 0.74329854, 0.45775009, 0.12786280, 0.39830299, + 0.20386769, 0.59703523, 0.94077086, 0.42255597, 0.80453309, 0.79757204, + 0.28653229, 0.60175909, 0.55859623, 0.34318230, 0.63002770, 0.36533324, + 0.89689906, 0.73236186, 0.61491989, 0.83787947, 0.67939463, 0.72016694, + 0.77499849, 0.72428343, 0.34571059, 0.23143007, 0.20099338, 0.85583142, + 0.73174191, 0.54284092, 0.20264181, 0.53037061, 0.30493131, 0.82279766, + 0.58542432, 0.72632070, 0.18394258, 0.00608118, 0.23808232, 0.17007573, + 0.75245459, 0.84990616, 0.38827634, 0.33809538, 0.01080317, 0.27250145, + 0.81769542, 0.15323253, 0.71668395, 0.99427044, 0.11355576, 0.50511923, + 0.22952055, 0.78271870, 0.12833592, 0.88639055, 0.76398188, 0.49533508, + 0.47939640, 0.73564612, 0.41465671, 0.10995635, 0.20271728, 0.00521771, + 0.67265260, 0.11917707, 0.76574855, 0.43842117, 0.28530411, 0.79648090, + 0.79433656, 0.12074559, 0.02325163, 0.10117917, 0.83559239, 0.67213900, + 0.25247084, 0.47968157, 0.88649124, 0.33588961, 0.92338319, 0.18808573, + 0.60248266, 0.36610154, 0.99123140, 0.10519719, 0.18754650, 0.43232584, + 0.85447872, 0.15937568, 0.92947480, 0.62705964, 0.85960084, 0.13435660, + 0.81845809, 0.60715133, 0.83030708, 0.83071910, 0.38883408, 0.92033237, + 0.59820890, 0.75527947, 0.67683355, 0.21847023, 0.29395619, 0.50477953, + 0.98977921, 0.96225332, 0.90143562, 0.19559914, 0.08978307, 0.09687492, + 0.07381865, 0.22801110, 0.26669388, 0.99691302, 0.12113623, 0.34373057, + 0.46066239, 0.48806761, 0.50688779, 0.00654483, 0.32076493, 0.42367646, + 0.07112842, 0.54090558, 0.68230725, 0.49713828, 0.41958965, 0.68013847, + 0.47691765, 0.63269259, 0.94304095, 0.54587271, 0.72447569, 0.28913523, + 0.75766936, 0.52965692, 0.96854824, 0.15589071, 0.84128672, 0.16337522, + 0.05771034, 0.21556356, 0.12094140, 0.29721207, 0.00811008, 0.66184926}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &lengths}, {}, {1, 0}); + ASSERT_EQ(Status::OK(), results.status()); + + auto z = results.at(0); + + ASSERT_EQ(e, z); +} - sd::ops::reverse_sequence op; - auto results = op.evaluate({&input, &lengths}, {}, {1, 0}); - ASSERT_EQ(Status::OK(), results.status()); +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Test_TopK_0) { + auto x = NDArrayFactory::create( + 'c', {2, 6}, + {1.0, 1.0, 1.0, 1.0, 11.0, 3.0, 1.0, 1.0, 1.0, 14.0, 5.0, 6.0}); + auto expV = NDArrayFactory::create('c', {2, 1}, {11.0, 14.0}); + auto expI = NDArrayFactory::create('c', {2, 1}, {4, 3}); - auto z = results.at(0); + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {1, 0}); // without sorting - ASSERT_EQ(e, z); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - -} + auto v = result.at(0); + auto i = result.at(1); + /* + v->printShapeInfo("topK_0: shape v"); + expV.printShapeInfo("topK_0: shape expV"); -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests5, Test_TopK_0) { - auto x = NDArrayFactory::create('c', {2, 6}, {1.0, 1.0, 1.0, 1.0, 11.0, 3.0, 1.0, 1.0, 1.0, 14.0, 5.0, 6.0}); - auto expV = NDArrayFactory::create('c', {2, 1}, {11.0, 14.0}); - auto expI = NDArrayFactory::create('c', {2, 1}, {4, 3}); - - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {1, 0}); // without sorting - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); - - auto v = result.at(0); - auto i = result.at(1); -/* - v->printShapeInfo("topK_0: shape v"); - expV.printShapeInfo("topK_0: shape expV"); - - i->printShapeInfo("topK_0: shape I"); - expI.printShapeInfo("topK_0: shape expI"); - - v->printIndexedBuffer("topK_0: v"); - expV.printIndexedBuffer("topK_0: expV"); - i->printIndexedBuffer("topK_0: i"); - expI.printIndexedBuffer("topK_0: expI"); -*/ - - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); - - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); - // repeat res again - for (int cases = 0; cases < 100; ++cases) { - op.execute({&x}, {&v, &i}, {}, {1, 0}, {}); // without sorting - } - + i->printShapeInfo("topK_0: shape I"); + expI.printShapeInfo("topK_0: shape expI"); + + v->printIndexedBuffer("topK_0: v"); + expV.printIndexedBuffer("topK_0: expV"); + i->printIndexedBuffer("topK_0: i"); + expI.printIndexedBuffer("topK_0: expI"); + */ + + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); + // repeat res again + for (int cases = 0; cases < 100; ++cases) { + op.execute({&x}, {&v, &i}, {}, {1, 0}, {}); // without sorting + } } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Test_TopK_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1.0f, 11.0f, 3.0f, 14.0f, 5.0f, 6.0f}); - auto expV = NDArrayFactory::create('c', {2, 1}, {11.0f, 14.0f}); - auto expI = NDArrayFactory::create('c', {2, 1}, {1, 0}); + auto x = NDArrayFactory::create( + 'c', {2, 3}, {1.0f, 11.0f, 3.0f, 14.0f, 5.0f, 6.0f}); + auto expV = NDArrayFactory::create('c', {2, 1}, {11.0f, 14.0f}); + auto expI = NDArrayFactory::create('c', {2, 1}, {1, 0}); - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {1, 0}); // without sorting + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {1, 0}); // without sorting - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - auto v = result.at(0); - auto i = result.at(1); + auto v = result.at(0); + auto i = result.at(1); -// v->printShapeInfo("topK_1: shape v"); -// expV.printShapeInfo("topK_1: shape expV"); + // v->printShapeInfo("topK_1: shape v"); + // expV.printShapeInfo("topK_1: shape expV"); -// i->printShapeInfo("topK_1: shape I"); -// expI.printShapeInfo("topK_1: shape expI"); + // i->printShapeInfo("topK_1: shape I"); + // expI.printShapeInfo("topK_1: shape expI"); -// v->printIndexedBuffer("topK_1: v"); -// expV.printIndexedBuffer("topK_1: expV"); -// i->printIndexedBuffer("topK_1: i"); -// expI.printIndexedBuffer("topK_1: expI"); + // v->printIndexedBuffer("topK_1: v"); + // expV.printIndexedBuffer("topK_1: expV"); + // i->printIndexedBuffer("topK_1: i"); + // expI.printIndexedBuffer("topK_1: expI"); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); - - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); - // repeat res again - for (int cases = 0; cases < 100; ++cases) { - op.execute({&x}, {&v, &i}, {}, {1, 0}, {}); // without sorting - } - + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); + // repeat res again + for (int cases = 0; cases < 100; ++cases) { + op.execute({&x}, {&v, &i}, {}, {1, 0}, {}); // without sorting + } } /////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Test_TopK_2) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, - 6.0, 9.0, 3.5, 7.0, - 21.0, 3.0, 14.0, 15.0, - 6.0, 9.0, 3.5, 7.0, - 11.0, 13.0, 14.0, 5.0, - 16.0, 9.0, 13.5, 7.0 - } - ); -// <<<14.>,<9.>>, <<21.>,<9.>>, <<14.>,<16.>>> - auto expV = NDArrayFactory::create('c', {2, 3, 1}, {14.0f, 9.0f, - 21.0f, - 9.0f, 14.0f, - 16.0f - } - ); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); + // <<<14.>,<9.>>, <<21.>,<9.>>, <<14.>,<16.>>> + auto expV = NDArrayFactory::create( + 'c', {2, 3, 1}, {14.0f, 9.0f, 21.0f, 9.0f, 14.0f, 16.0f}); - auto expI = NDArrayFactory::create('c', {2, 3, 1 }, {2, 1, 0, 1, 2, 0}); + auto expI = + NDArrayFactory::create('c', {2, 3, 1}, {2, 1, 0, 1, 2, 0}); - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {1, 1}); + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - auto v = result.at(0); - auto i = result.at(1); + auto v = result.at(0); + auto i = result.at(1); -// v->printShapeInfo("shape v"); -// expV.printShapeInfo("shape expV"); + // v->printShapeInfo("shape v"); + // expV.printShapeInfo("shape expV"); -// i->printShapeInfo("shape I"); -// expI.printShapeInfo("shape expI"); + // i->printShapeInfo("shape I"); + // expI.printShapeInfo("shape expI"); -// v->printIndexedBuffer("v"); -// expV.printIndexedBuffer("expV"); -// i->printIndexedBuffer("i"); -// expI.printIndexedBuffer("expI"); + // v->printIndexedBuffer("v"); + // expV.printIndexedBuffer("expV"); + // i->printIndexedBuffer("i"); + // expI.printIndexedBuffer("expI"); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); - - + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); } TEST_F(DeclarableOpsTests5, Test_TopK_3) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, - 6.0, 9.0, 3.5, 7.0, - 21.0, 3.0, 14.0, 15.0, - 6.0, 9.0, 3.5, 7.0, - 11.0, 13.0, 14.0, 5.0, - 16.0, 9.0, 13.5, 7.0 - } - ); - - auto expV = NDArrayFactory::create('c', {2, 3, 2}, {14.0f, 11.0f, - 9.0f, 7.0f, - 21.0f, 15.0f, - 9.0f, 7.0f, - 14.0f, 13.0f, - 16.0f, 13.5f - } - ); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); - auto expI = NDArrayFactory::create('c', {2, 3, 2 }, {2, 0, 1, 3, 0, 3, 1, 3, 2, 1, 0, 2}); + auto expV = + NDArrayFactory::create('c', {2, 3, 2}, + {14.0f, 11.0f, 9.0f, 7.0f, 21.0f, 15.0f, + 9.0f, 7.0f, 14.0f, 13.0f, 16.0f, 13.5f}); - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {2, 1}); + auto expI = NDArrayFactory::create( + 'c', {2, 3, 2}, {2, 0, 1, 3, 0, 3, 1, 3, 2, 1, 0, 2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {2, 1}); - auto v = result.at(0); - auto i = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); -// v->printShapeInfo("shape v"); -// expV.printShapeInfo("shape expV"); + auto v = result.at(0); + auto i = result.at(1); -// i->printShapeInfo("shape I"); -// expI.printShapeInfo("shape expI"); + // v->printShapeInfo("shape v"); + // expV.printShapeInfo("shape expV"); -// v->printIndexedBuffer("v"); -// expV.printIndexedBuffer("expV"); -// i->printIndexedBuffer("i"); -// expI.printIndexedBuffer("expI"); + // i->printShapeInfo("shape I"); + // expI.printShapeInfo("shape expI"); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); + // v->printIndexedBuffer("v"); + // expV.printIndexedBuffer("expV"); + // i->printIndexedBuffer("i"); + // expI.printIndexedBuffer("expI"); - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); - + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); } TEST_F(DeclarableOpsTests5, Test_TopK_3_unsorted) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, - 6.0, 9.0, 3.5, 7.0, - 21.0, 3.0, 14.0, 15.0, - 6.0, 9.0, 3.5, 7.0, - 11.0, 13.0, 14.0, 5.0, - 16.0, 9.0, 13.5, 7.0 - } - ); - - auto expV = NDArrayFactory::create('c', {2, 3, 2}, {11.0f, 14.0f, - 9.0f, 7.0f, - 21.0f, 15.0f, - 9.0f, 7.0f, - 13.0f, 14.0f, - 16.0f, 13.5f - } - ); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); - auto expI = NDArrayFactory::create('c', {2, 3, 2 }, {0, 2, 1, 3, 0, 3, 1, 3, 1, 2, 0, 2}); + auto expV = + NDArrayFactory::create('c', {2, 3, 2}, + {11.0f, 14.0f, 9.0f, 7.0f, 21.0f, 15.0f, + 9.0f, 7.0f, 13.0f, 14.0f, 16.0f, 13.5f}); - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {2}, {false}); + auto expI = NDArrayFactory::create( + 'c', {2, 3, 2}, {0, 2, 1, 3, 0, 3, 1, 3, 1, 2, 0, 2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {2}, {false}); - auto v = result.at(0); - auto i = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); + auto v = result.at(0); + auto i = result.at(1); - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); - + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Test_TopK_4) { - auto x = NDArrayFactory::create('c', {2, 3}, {1.0f, 11.0f, 3.0f, 14.0f, 5.0f, 6.0f}); - auto expV = NDArrayFactory::create('c', {2, 2}, {11.0f, 3.0f, 14.0f, 6.0f}); - auto expI = NDArrayFactory::create('c', {2, 2}, {1, 2, 0, 2}); + auto x = NDArrayFactory::create( + 'c', {2, 3}, {1.0f, 11.0f, 3.0f, 14.0f, 5.0f, 6.0f}); + auto expV = + NDArrayFactory::create('c', {2, 2}, {11.0f, 3.0f, 14.0f, 6.0f}); + auto expI = NDArrayFactory::create('c', {2, 2}, {1, 2, 0, 2}); - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {2, 1}); + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {2, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - auto v = result.at(0); - auto i = result.at(1); + auto v = result.at(0); + auto i = result.at(1); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); - - + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Test_TopK_5) { - auto x = NDArrayFactory::create('f', {2, 3}, {1.1, 5.2, 3.1, 14.2, 11.1, 6.2}); - auto expV = NDArrayFactory::create('f', {2, 2}, {11.1, 14.2, 3.1, 6.2}); - auto expI = NDArrayFactory::create('f', {2, 2}, {2, 1, 1, 2}); - - sd::ops::top_k op; - auto result = op.evaluate({&x}, {}, {2, 1}); + auto x = NDArrayFactory::create('f', {2, 3}, + {1.1, 5.2, 3.1, 14.2, 11.1, 6.2}); + auto expV = + NDArrayFactory::create('f', {2, 2}, {11.1, 14.2, 3.1, 6.2}); + auto expI = NDArrayFactory::create('f', {2, 2}, {2, 1, 1, 2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {2, 1}); - auto v = result.at(0); - auto i = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - ASSERT_TRUE(expV.isSameShape(v)); - ASSERT_TRUE(expV.equalsTo(v)); + auto v = result.at(0); + auto i = result.at(1); - ASSERT_TRUE(expI.isSameShape(i)); - ASSERT_TRUE(expI.equalsTo(i)); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); - + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); } /////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Test_Moments_1) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, - 6.0, 9.0, 3.5, 7.0, - 21.0, 3.0, 14.0, 15.0, - 6.0, 9.0, 3.5, 7.0, - 11.0, 13.0, 14.0, 5.0, - 16.0, 9.0, 13.5, 7.0} - ); - - auto y = NDArrayFactory::create('c', {3}, {0, 1, 2}); - //auto expV('f', {6}, {1, 0, 0, 0, 0, 0 }); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); - float expMean = 9.395833f; - float expDeviation = 22.4579f; -//Mean 9.395833 -//Deviance 22.4579 + auto y = NDArrayFactory::create('c', {3}, {0, 1, 2}); + // auto expV('f', {6}, {1, 0, 0, 0, 0, 0 }); - float inf = 1.e-5f; + float expMean = 9.395833f; + float expDeviation = 22.4579f; + // Mean 9.395833 + // Deviance 22.4579 - sd::ops::moments op; - auto result = op.evaluate({&x, &y}, {}, {}); + float inf = 1.e-5f; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + sd::ops::moments op; + auto result = op.evaluate({&x, &y}, {}, {}); - auto v = result.at(0); - auto d = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); -// v->printIndexedBuffer("Result is "); -// d->printIndexedBuffer("Result is "); + auto v = result.at(0); + auto d = result.at(1); - ASSERT_TRUE(v.isScalar()); - ASSERT_NEAR(expMean, v.e(0), inf); - ASSERT_NEAR(expDeviation, d.e(0), inf); + // v->printIndexedBuffer("Result is "); + // d->printIndexedBuffer("Result is "); - + ASSERT_TRUE(v.isScalar()); + ASSERT_NEAR(expMean, v.e(0), inf); + ASSERT_NEAR(expDeviation, d.e(0), inf); } TEST_F(DeclarableOpsTests5, Test_Moments_2) { - NDArray x('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, - 6.0, 9.0, 3.5, 7.0, - 21.0, 3.0, 14.0, 15.0, - 6.0, 9.0, 3.5, 7.0, - 11.0, 13.0, 14.0, 5.0, - 16.0, 9.0, 13.5, 7.0} - ); + NDArray x('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, + 21.0, 3.0, 14.0, 15.0, 6.0, 9.0, 3.5, 7.0, + 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); - NDArray expV('c', {4}, {11.833333, 7.6666665, 10.416667, 7.6666665}); - NDArray expD('c', {4}, {28.472221, 12.888889, 23.951387, 11.555554}); + NDArray expV('c', {4}, {11.833333, 7.6666665, 10.416667, 7.6666665}); + NDArray expD('c', {4}, {28.472221, 12.888889, 23.951387, 11.555554}); - sd::ops::moments op; - auto result = op.evaluate({&x}, {}, {0, 1}); + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - auto v = result.at(0); - auto d = result.at(1); + auto v = result.at(0); + auto d = result.at(1); - ASSERT_TRUE(v.isVector()); - ASSERT_TRUE(d.isVector()); + ASSERT_TRUE(v.isVector()); + ASSERT_TRUE(d.isVector()); - ASSERT_TRUE(v.equalsTo(&expV)); - ASSERT_TRUE(d.equalsTo(&expD)); - - + ASSERT_TRUE(v.equalsTo(&expV)); + ASSERT_TRUE(d.equalsTo(&expD)); } TEST_F(DeclarableOpsTests5, Test_Moments_3) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, - 6.0, 9.0, 3.5, 7.0, - 21.0, 3.0, 14.0, 15.0, - 6.0, 9.0, 3.5, 7.0, - 11.0, 13.0, 14.0, 5.0, - 16.0, 9.0, 13.5, 7.0} - ); - - auto expV = NDArrayFactory::create('c', {3, 4}, { 8.5f, 6.f , 8.75f, 6.f, - 8.5f, 11.f, 8.75f, 6.f, - 18.5f, 6.f, 13.75f, 11.f}); - auto expD = NDArrayFactory::create('c', {3, 4}, { 6.25f, 9.f, 27.5625f, 1.f, - 6.25f, 4.f, 27.5625f, 1.f, - 6.25f, 9.f, 0.0625f, 16.f}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0}); - sd::ops::moments op; - auto result = op.evaluate({&x}, {}, {0}); + auto expV = + NDArrayFactory::create('c', {3, 4}, + {8.5f, 6.f, 8.75f, 6.f, 8.5f, 11.f, 8.75f, + 6.f, 18.5f, 6.f, 13.75f, 11.f}); + auto expD = NDArrayFactory::create( + 'c', {3, 4}, + {6.25f, 9.f, 27.5625f, 1.f, 6.25f, 4.f, 27.5625f, 1.f, 6.25f, 9.f, + 0.0625f, 16.f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0}); - auto v = result.at(0); - auto d = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - ASSERT_TRUE(v.isMatrix()); - ASSERT_TRUE(d.isMatrix()); + auto v = result.at(0); + auto d = result.at(1); - ASSERT_TRUE(v.equalsTo(&expV)); - ASSERT_TRUE(d.equalsTo(&expD)); + ASSERT_TRUE(v.isMatrix()); + ASSERT_TRUE(d.isMatrix()); - + ASSERT_TRUE(v.equalsTo(&expV)); + ASSERT_TRUE(d.equalsTo(&expD)); } TEST_F(DeclarableOpsTests5, Test_Moments_4) { + auto x = NDArrayFactory::create( + 'f', {2, 3, 4}, {11.0f, 6.0f, 6.0f, 11.0f, 21.0f, 16.0f, 3.0f, 9.0f, + 9.0f, 13.0f, 3.0f, 9.0f, 14.0f, 3.5f, 3.5f, 14.0f, + 14.0f, 13.5f, 5.0f, 7.0f, 7.0f, 5.0f, 15.0f, 7.0f}); - auto x = NDArrayFactory::create('f', {2, 3, 4}, {11.0f, 6.0f, 6.0f, 11.0f, 21.0f, 16.0f, 3.0f, 9.0f, 9.0f, 13.0f, 3.0f, 9.0f, - 14.0f, 3.5f, 3.5f, 14.0f, 14.0f, 13.5f, 5.0f, 7.0f, 7.0f, 5.0f, 15.0f, 7.0f}); - + auto expV = + NDArrayFactory::create('c', {3, 4}, + {8.5f, 6.f, 8.75f, 6.f, 8.5f, 11.f, 8.75f, + 6.f, 18.5f, 6.f, 13.75f, 11.f}); + auto expD = NDArrayFactory::create( + 'c', {3, 4}, + {6.25f, 9.f, 27.5625f, 1.f, 6.25f, 4.f, 27.5625f, 1.f, 6.25f, 9.f, + 0.0625f, 16.f}); - auto expV = NDArrayFactory::create('c', {3, 4}, { 8.5f, 6.f , 8.75f, 6.f, 8.5f, 11.f, 8.75f, 6.f, 18.5f, 6.f, 13.75f, 11.f}); - auto expD = NDArrayFactory::create('c', {3, 4}, { 6.25f, 9.f, 27.5625f, 1.f, 6.25f, 4.f, 27.5625f, 1.f, 6.25f, 9.f, 0.0625f, 16.f}); + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0}); - sd::ops::moments op; - auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(2, result.size()); + auto v = result.at(0); + auto d = result.at(1); - auto v = result.at(0); - auto d = result.at(1); + ASSERT_TRUE(v.isMatrix()); + ASSERT_TRUE(d.isMatrix()); - ASSERT_TRUE(v.isMatrix()); - ASSERT_TRUE(d.isMatrix()); + // v->printIndexedBuffer("v"); + // expV.printIndexedBuffer("expV"); - // v->printIndexedBuffer("v"); - // expV.printIndexedBuffer("expV"); + // d->printIndexedBuffer("d"); + // expD.printIndexedBuffer("expD"); - // d->printIndexedBuffer("d"); - // expD.printIndexedBuffer("expD"); - - ASSERT_TRUE(v.equalsTo(&expV)); - ASSERT_TRUE(d.equalsTo(&expD)); - - + ASSERT_TRUE(v.equalsTo(&expV)); + ASSERT_TRUE(d.equalsTo(&expD)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, trace_test1) { - - auto input = NDArrayFactory::create('c', {3, 4, 5}); - input.linspace(1); - auto exp = NDArrayFactory::create('c', {3}, {40, 120, 200}); - NDArray matrix('c', {3, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9.}); - sd::ops::trace op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); - double traceM = matrix.getTrace(); - // nd4j_printf("Trace for matrix is %f\n", traceM); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - // exp.printIndexedBuffer("EXP TRACE"); - // output->printIndexedBuffer("OUT TRACE"); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto exp = NDArrayFactory::create('c', {3}, {40, 120, 200}); + NDArray matrix('c', {3, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9.}); + sd::ops::trace op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + double traceM = matrix.getTrace(); + // nd4j_printf("Trace for matrix is %f\n", traceM); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + // exp.printIndexedBuffer("EXP TRACE"); + // output->printIndexedBuffer("OUT TRACE"); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, trace_test2) { + auto input = NDArrayFactory::create('c', {4, 5}); + input.linspace(1); + auto exp = NDArrayFactory::create(40.); - auto input = NDArrayFactory::create('c', {4, 5}); - input.linspace(1); - auto exp = NDArrayFactory::create(40.); - - sd::ops::trace op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); + sd::ops::trace op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, trace_test3) { + auto input = NDArrayFactory::create('c', {1, 5}); + input.linspace(1); + auto exp = NDArrayFactory::create(1.); - auto input = NDArrayFactory::create('c', {1, 5}); - input.linspace(1); - auto exp = NDArrayFactory::create(1.); - - sd::ops::trace op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::trace op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, trace_test4) { + auto input = NDArrayFactory::create('c', {5, 1}); + input.linspace(1); + auto exp = NDArrayFactory::create(1.); - auto input = NDArrayFactory::create('c', {5, 1}); - input.linspace(1); - auto exp = NDArrayFactory::create(1.); - - sd::ops::trace op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::trace op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, trace_test5) { + auto input = NDArrayFactory::create('c', {3, 4, 5, 6}); + input.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {75, 225, 375, 525, 675, 825, 975, 1125, 1275, 1425, 1575, 1725}); - auto input = NDArrayFactory::create('c', {3, 4, 5, 6}); - input.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4}, {75, 225, 375, 525, 675, 825, 975, 1125, 1275, 1425, 1575, 1725}); - - sd::ops::trace op; - auto results = op.evaluate({&input}); - auto output = results.at(0); + sd::ops::trace op; + auto results = op.evaluate({&input}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test1) { + auto input = NDArrayFactory::create('c', {2, 2, 2}); + input.linspace(1); - auto input = NDArrayFactory::create('c', {2, 2, 2}); - input.linspace(1); - - sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); - auto output = results.at(0); + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}); + auto output = results.at(0); - bool haveZeros = false; - for(int i = 0; i < output.lengthOf(); ++i) - if(output.e(i) == (float)0.) - haveZeros = true; + bool haveZeros = false; + for (int i = 0; i < output.lengthOf(); ++i) + if (output.e(i) == (float)0.) haveZeros = true; - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(input.isSameShape(output)); + ASSERT_TRUE(!input.equalsTo(output)); + ASSERT_TRUE(!haveZeros); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test2) { + auto input = NDArrayFactory::create('c', {1, 3, 2}); + input.linspace(1); - auto input = NDArrayFactory::create('c', {1, 3, 2}); - input.linspace(1); - - sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(input.equalsTo(output)); + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}); + auto output = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(input.isSameShape(output)); + ASSERT_TRUE(input.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test3) { + auto input = NDArrayFactory::create('c', {3, 2, 1}); + input.linspace(1); - auto input = NDArrayFactory::create('c', {3, 2, 1}); - input.linspace(1); + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}); + auto output = results.at(0); - sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); - auto output = results.at(0); - - bool haveZeros = false; - for(int i = 0; i < output.lengthOf(); ++i) - if(output.e(i) == (float)0.) - haveZeros = true; - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); + bool haveZeros = false; + for (int i = 0; i < output.lengthOf(); ++i) + if (output.e(i) == (float)0.) haveZeros = true; - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(input.isSameShape(output)); + ASSERT_TRUE(!input.equalsTo(output)); + ASSERT_TRUE(!haveZeros); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test04) { - auto input = NDArrayFactory::create('c', {4}); - input.linspace(1); + auto input = NDArrayFactory::create('c', {4}); + input.linspace(1); - sd::ops::random_shuffle op; - //NDArray* output; - auto results = op.evaluate({&input}, {}, {}, {}, {}, true); - ASSERT_EQ(Status::OK(), results.status()); - auto output = &input; //results.at(0); - bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) - haveZeros = true; - - ASSERT_TRUE(input.isSameShape(output)); - //ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); + sd::ops::random_shuffle op; + // NDArray* output; + auto results = op.evaluate({&input}, {}, {}, {}, {}, true); + ASSERT_EQ(Status::OK(), results.status()); + auto output = &input; // results.at(0); + bool haveZeros = false; + for (int i = 0; i < output->lengthOf(); ++i) + if (output->e(i) == (float)0.) haveZeros = true; - + ASSERT_TRUE(input.isSameShape(output)); + // ASSERT_TRUE(!input.equalsTo(output)); + ASSERT_TRUE(!haveZeros); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test4) { - auto input = NDArrayFactory::create('c', {4}); - input.linspace(1); + auto input = NDArrayFactory::create('c', {4}); + input.linspace(1); - sd::ops::random_shuffle op; - //NDArray* output; - auto results = op.evaluate({&input}); - ASSERT_EQ(Status::OK(), results.status()); - auto output = results.at(0); - bool haveZeros = false; - for(int i = 0; i < output.lengthOf(); ++i) - if(output.e(i) == (float)0.) - haveZeros = true; - - ASSERT_TRUE(input.isSameShape(output)); - //ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); + sd::ops::random_shuffle op; + // NDArray* output; + auto results = op.evaluate({&input}); + ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); + bool haveZeros = false; + for (int i = 0; i < output.lengthOf(); ++i) + if (output.e(i) == (float)0.) haveZeros = true; - + ASSERT_TRUE(input.isSameShape(output)); + // ASSERT_TRUE(!input.equalsTo(output)); + ASSERT_TRUE(!haveZeros); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test5) { + auto input = NDArrayFactory::create('c', {4, 1}); + input.linspace(1); - auto input = NDArrayFactory::create('c', {4,1}); - input.linspace(1); - - sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); - auto output = results.at(0); - - bool haveZeros = false; - for(int i = 0; i < output.lengthOf(); ++i) - if(output.e(i) == (float)0.) - haveZeros = true; + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); + bool haveZeros = false; + for (int i = 0; i < output.lengthOf(); ++i) + if (output.e(i) == (float)0.) haveZeros = true; - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(input.isSameShape(output)); + ASSERT_TRUE(!input.equalsTo(output)); + ASSERT_TRUE(!haveZeros); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test6) { + auto input = NDArrayFactory::create('c', {4, 1, 1}); + input.linspace(1); - auto input = NDArrayFactory::create('c', {4,1,1}); - input.linspace(1); + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}); + auto output = results.at(0); - sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); - auto output = results.at(0); - - bool haveZeros = false; - for(int i = 0; i < output.lengthOf(); ++i) - if(output.e(i) == (float)0.) - haveZeros = true; - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); + bool haveZeros = false; + for (int i = 0; i < output.lengthOf(); ++i) + if (output.e(i) == (float)0.) haveZeros = true; - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(input.isSameShape(output)); + ASSERT_TRUE(!input.equalsTo(output)); + ASSERT_TRUE(!haveZeros); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test7) { + auto input = NDArrayFactory::create('c', {1, 4}); + input.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto input = NDArrayFactory::create('c', {1,4}); - input.linspace(1); - auto exp = NDArrayFactory::create('c', {1,4}, {1, 2, 3, 4}); - - sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(input.equalsTo(output)); + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}); + auto output = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(input.isSameShape(output)); + ASSERT_TRUE(input.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, EmbeddingLookup_1) { - - auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20, 11, 21, 12, 22, 13, 23, - 14, 24, 15, 25, 16, 26, 17, 27, - 18, 28, 19, 29, 20, 30, 21, 31}); - - auto y = NDArrayFactory::create({1, 1, 1, 0, 0, 0, 2, 2, 2}); - auto exp = NDArrayFactory::create('c', {9, 4, 2}, {14, 24, 15, 25, 16, 26, 17, 27, 14, 24, 15, 25, - 16, 26, 17, 27, 14, 24, 15, 25, 16, 26, 17, 27, - 10, 20, 11, 21, 12, 22, 13, 23, 10, 20, 11, 21, - 12, 22, 13, 23, 10, 20, 11, 21, 12, 22, 13, 23, - 18, 28, 19, 29, 20, 30, 21, 31, 18, 28, 19, 29, - 20, 30, 21, 31, 18, 28, 19, 29, 20, 30, 21, 31}); - - // y.printShapeInfo("y shape"); - // y.printIndexedBuffer("y buffer"); - - sd::ops::embedding_lookup op; - auto result = op.evaluate({&x, &y}, {}, {0}); - auto output = result.at(0); - // x.printShapeInfo("Input"); - //output->printShapeInfo("Output"); - //exp.printShapeInfo("Expected"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - //output->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); - - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create( + 'c', {3, 4, 2}, {10, 20, 11, 21, 12, 22, 13, 23, 14, 24, 15, 25, + 16, 26, 17, 27, 18, 28, 19, 29, 20, 30, 21, 31}); + + auto y = NDArrayFactory::create({1, 1, 1, 0, 0, 0, 2, 2, 2}); + auto exp = NDArrayFactory::create( + 'c', {9, 4, 2}, + {14, 24, 15, 25, 16, 26, 17, 27, 14, 24, 15, 25, 16, 26, 17, 27, 14, 24, + 15, 25, 16, 26, 17, 27, 10, 20, 11, 21, 12, 22, 13, 23, 10, 20, 11, 21, + 12, 22, 13, 23, 10, 20, 11, 21, 12, 22, 13, 23, 18, 28, 19, 29, 20, 30, + 21, 31, 18, 28, 19, 29, 20, 30, 21, 31, 18, 28, 19, 29, 20, 30, 21, 31}); + + // y.printShapeInfo("y shape"); + // y.printIndexedBuffer("y buffer"); + + sd::ops::embedding_lookup op; + auto result = op.evaluate({&x, &y}, {}, {0}); + auto output = result.at(0); + // x.printShapeInfo("Input"); + // output->printShapeInfo("Output"); + // exp.printShapeInfo("Expected"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + // output->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests5, EmbeddingLookup_2) { - - auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20, 30, 40, 50, 60, - 70, 80, 90, 10, 11, 12, - 13, 14, 15, 16, 17, 18, - 19, 20, 21, 22, 23, 24}); - //1, 0, 1, 0, 1, 0 - auto y = NDArrayFactory::create({1, 0, 1, 0, 1, 0}); - auto exp = NDArrayFactory::create('c', {6, 4, 2}, {90, 10, 11, 12, 13, 14, - 15, 16, 10, 20, 30, 40, - 50, 60, 70, 80, 90, 10, - 11, 12, 13, 14, 15, 16, - 10, 20, 30, 40, 50, 60, - 70, 80, 90, 10, 11, 12, - 13, 14, 15, 16, 10, 20, - 30, 40, 50, 60, 70, 80}); - - // y.printShapeInfo("y shape"); - // y.printIndexedBuffer("y buffer"); - - sd::ops::embedding_lookup op; - auto result = op.evaluate({&x, &y}, {}, {0}); - auto output = result.at(0); - // x.printShapeInfo("Input"); - // output->printShapeInfo("Output"); - // exp.printShapeInfo("Expected"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - // output->printIndexedBuffer("Output"); - // exp.printIndexedBuffer("Expect"); - - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create( + 'c', {3, 4, 2}, {10, 20, 30, 40, 50, 60, 70, 80, 90, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + // 1, 0, 1, 0, 1, 0 + auto y = NDArrayFactory::create({1, 0, 1, 0, 1, 0}); + auto exp = NDArrayFactory::create( + 'c', {6, 4, 2}, + {90, 10, 11, 12, 13, 14, 15, 16, 10, 20, 30, 40, 50, 60, 70, 80, + 90, 10, 11, 12, 13, 14, 15, 16, 10, 20, 30, 40, 50, 60, 70, 80, + 90, 10, 11, 12, 13, 14, 15, 16, 10, 20, 30, 40, 50, 60, 70, 80}); + + // y.printShapeInfo("y shape"); + // y.printIndexedBuffer("y buffer"); + + sd::ops::embedding_lookup op; + auto result = op.evaluate({&x, &y}, {}, {0}); + auto output = result.at(0); + // x.printShapeInfo("Input"); + // output->printShapeInfo("Output"); + // exp.printShapeInfo("Expected"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + // output->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests5, EmbeddingLookup_3) { - - - auto y = NDArrayFactory::create('c', {3,2}, {5, 4, 4, 5, 3, 3}); - auto exp = NDArrayFactory::create('c', {6, 3, 3}, { - 6, 20, 11, 21, 12, 22, 13, 23, 14, - 5, 20, 11, 21, 12, 22, 13, 23, 14, - 5, 20, 11, 21, 12, 22, 13, 23, 14, - 6, 20, 11, 21, 12, 22, 13, 23, 14, - 4, 20, 11, 21, 12, 22, 13, 23, 14, - 4, 20, 11, 21, 12, 22, 13, 23, 14 }); - - // y.printShapeInfo("y shape"); - // y.printIndexedBuffer("y buffer"); - auto p1 = NDArrayFactory::create('c', {3,3}, {1, 20, 11, 21, 12, 22, 13, 23, 14}); - auto p2 = NDArrayFactory::create('c', {3,3}, {2, 20, 11, 21, 12, 22, 13, 23, 14}); - auto p3 = NDArrayFactory::create('c', {3,3}, {3, 20, 11, 21, 12, 22, 13, 23, 14}); - auto p4 = NDArrayFactory::create('c', {3,3}, {4, 20, 11, 21, 12, 22, 13, 23, 14}); - auto p5 = NDArrayFactory::create('c', {3,3}, {5, 20, 11, 21, 12, 22, 13, 23, 14}); - auto p6 = NDArrayFactory::create('c', {3,3}, {6, 20, 11, 21, 12, 22, 13, 23, 14}); - auto p7 = NDArrayFactory::create('c', {3,3}, {7, 20, 11, 21, 12, 22, 13, 23, 14}); - auto p8 = NDArrayFactory::create('c', {3,3}, {8, 20, 11, 21, 12, 22, 13, 23, 14}); - -// res = tf.nn.embedding_lookup((p1, p2, p3, p4, p5, p6, p7), ids, 'mod') - - sd::ops::embedding_lookup op; - auto result = op.evaluate({&p1, &p2, &p3, &p4, &p5, &p6, &p7, &p8, &y}, {}, {1}); - auto output = result.at(0); - // x.printShapeInfo("Input"); - // output->printIndexedBuffer("Output"); - // exp.printShapeInfo("Expected"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - // output->printIndexedBuffer("Output"); - // exp.printIndexedBuffer("Expect"); - - ASSERT_TRUE(exp.equalsTo(output)); - - + auto y = NDArrayFactory::create('c', {3, 2}, {5, 4, 4, 5, 3, 3}); + auto exp = NDArrayFactory::create( + 'c', {6, 3, 3}, + {6, 20, 11, 21, 12, 22, 13, 23, 14, 5, 20, 11, 21, 12, 22, 13, 23, 14, + 5, 20, 11, 21, 12, 22, 13, 23, 14, 6, 20, 11, 21, 12, 22, 13, 23, 14, + 4, 20, 11, 21, 12, 22, 13, 23, 14, 4, 20, 11, 21, 12, 22, 13, 23, 14}); + + // y.printShapeInfo("y shape"); + // y.printIndexedBuffer("y buffer"); + auto p1 = NDArrayFactory::create('c', {3, 3}, + {1, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p2 = NDArrayFactory::create('c', {3, 3}, + {2, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p3 = NDArrayFactory::create('c', {3, 3}, + {3, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p4 = NDArrayFactory::create('c', {3, 3}, + {4, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p5 = NDArrayFactory::create('c', {3, 3}, + {5, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p6 = NDArrayFactory::create('c', {3, 3}, + {6, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p7 = NDArrayFactory::create('c', {3, 3}, + {7, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p8 = NDArrayFactory::create('c', {3, 3}, + {8, 20, 11, 21, 12, 22, 13, 23, 14}); + + // res = tf.nn.embedding_lookup((p1, p2, p3, p4, p5, p6, p7), ids, 'mod') + + sd::ops::embedding_lookup op; + auto result = + op.evaluate({&p1, &p2, &p3, &p4, &p5, &p6, &p7, &p8, &y}, {}, {1}); + auto output = result.at(0); + // x.printShapeInfo("Input"); + // output->printIndexedBuffer("Output"); + // exp.printShapeInfo("Expected"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + // output->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(output)); } /* @Test public void testDynamicPartition(){ INDArray data = Nd4j.createFromArray(2, 1, 2, 0); INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition") - .addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1)) - .addIntegerArguments(3) //3 partitions - .addInputs(data, partitions).build()); + .addOutputs(Nd4j.createUninitialized(DataType.INT, 2), + Nd4j.createUninitialized(DataType.INT, 1), + Nd4j.createUninitialized(DataType.INT, 1)) .addIntegerArguments(3) //3 + partitions .addInputs(data, partitions).build()); INDArray exp0 = Nd4j.createFromArray(2, 0); INDArray exp1 = Nd4j.createFromArray(2); @@ -1930,1167 +1965,1170 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_3) { assertEquals(exp2, out[2]); }*/ TEST_F(DeclarableOpsTests5, DynamicPartition_01) { + auto x = NDArrayFactory::create({2, 1, 2, 0}); - auto x = NDArrayFactory::create({2,1,2,0}); - - auto y = NDArrayFactory::create({0,2,1,0}); + auto y = NDArrayFactory::create({0, 2, 1, 0}); - int numPartition = 3; - std::vector exp( { NDArrayFactory::create('c', {2}, {2, 0}), - NDArrayFactory::create('c', {1}, {2}), - NDArrayFactory::create('c', {1}, {1})}); + int numPartition = 3; + std::vector exp({NDArrayFactory::create('c', {2}, {2, 0}), + NDArrayFactory::create('c', {1}, {2}), + NDArrayFactory::create('c', {1}, {1})}); - sd::ops::dynamic_partition op; - auto result = op.evaluate({&x, &y}, {}, {numPartition}); + sd::ops::dynamic_partition op; + auto result = op.evaluate({&x, &y}, {}, {numPartition}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(result.size(), numPartition); // result has the same size as given param 4 + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(result.size(), + numPartition); // result has the same size as given param 4 - for (int e = 0; e < result.size(); e++) { - auto output = result.at(e); - // output->printShapeInfo("Output shape> "); - // output->printIndexedBuffer("Output data> "); - ASSERT_TRUE(exp[e].isSameShape(output)); - ASSERT_TRUE(exp[e].equalsTo(output)); - } - - + for (int e = 0; e < result.size(); e++) { + auto output = result.at(e); + // output->printShapeInfo("Output shape> "); + // output->printIndexedBuffer("Output data> "); + ASSERT_TRUE(exp[e].isSameShape(output)); + ASSERT_TRUE(exp[e].equalsTo(output)); + } } TEST_F(DeclarableOpsTests5, DynamicPartition_1) { - - auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20, 11, 21, 12, 22, - 13, 23, 14, 24, 15, 25, 16, 26, 17, 27, - 18, 28, 19, 29, 20, 30, 21, 31}); - - auto y = NDArrayFactory::create('c', {3, 4, 2}, {0, 0, 0, 0, 0, 0, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 1, 1, 1, 1, 1, 1, 1, 1 - } - ); -/* auto y = NDArrayFactory::create('c', {3, 4}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, - 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f - } - ); -*/ - int numPartition = 3; - std::vector exp( { NDArrayFactory::create('c', {6}, {10, 20, 11, 21, 12, 22}), - NDArrayFactory::create('c', {8}, {18, 28, 19, 29, 20, 30, 21, 31}), - NDArrayFactory::create('c', {10}, {13, 23, 14, 24, 15, 25, 16, 26, 17, 27})}); - - sd::ops::dynamic_partition op; - auto result = op.evaluate({&x, &y}, {}, {numPartition}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(result.size(), numPartition); // result has the same size as given param 4 - - for (int e = 0; e < result.size(); e++) { - auto output = result.at(e); - // output->printShapeInfo("Output shape> "); - // output->printIndexedBuffer("Output data> "); - ASSERT_TRUE(exp[e].isSameShape(output)); - ASSERT_TRUE(exp[e].equalsTo(output)); - } - - + auto x = NDArrayFactory::create( + 'c', {3, 4, 2}, {10, 20, 11, 21, 12, 22, 13, 23, 14, 24, 15, 25, + 16, 26, 17, 27, 18, 28, 19, 29, 20, 30, 21, 31}); + + auto y = NDArrayFactory::create( + 'c', {3, 4, 2}, + {0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1}); + /* auto y = NDArrayFactory::create('c', {3, 4}, {0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f + } + ); + */ + int numPartition = 3; + std::vector exp( + {NDArrayFactory::create('c', {6}, {10, 20, 11, 21, 12, 22}), + NDArrayFactory::create('c', {8}, + {18, 28, 19, 29, 20, 30, 21, 31}), + NDArrayFactory::create( + 'c', {10}, {13, 23, 14, 24, 15, 25, 16, 26, 17, 27})}); + + sd::ops::dynamic_partition op; + auto result = op.evaluate({&x, &y}, {}, {numPartition}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(result.size(), + numPartition); // result has the same size as given param 4 + + for (int e = 0; e < result.size(); e++) { + auto output = result.at(e); + // output->printShapeInfo("Output shape> "); + // output->printIndexedBuffer("Output data> "); + ASSERT_TRUE(exp[e].isSameShape(output)); + ASSERT_TRUE(exp[e].equalsTo(output)); + } } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, DynamicPartition_2) { + auto x = NDArrayFactory::create( + 'c', {2, 4}, {0.1f, -1.f, 5.2f, 4.3f, -1.f, 7.4f, 0.0f, -2.2f}); + auto y = NDArrayFactory::create('c', {2, 4}, {1, 2, 1, 2, 1, 2, 3, 0}); - auto x = NDArrayFactory::create('c', {2, 4}, {0.1f, -1.f, 5.2f, 4.3f, -1.f, 7.4f, 0.0f, -2.2f}); - auto y = NDArrayFactory::create('c', {2, 4}, {1, 2, 1, 2, 1, 2, 3, 0}); + std::vector exp( + {NDArrayFactory::create('c', {1}, {-2.2}), + NDArrayFactory::create('c', {3}, {0.1, 5.2, -1.}), + NDArrayFactory::create('c', {3}, {-1., 4.3, 7.4}), + NDArrayFactory::create('c', {1}, {0.0})}); - std::vector exp( {NDArrayFactory::create('c', {1}, {-2.2}), - NDArrayFactory::create('c', {3}, {0.1, 5.2, -1.}), - NDArrayFactory::create('c', {3}, {-1., 4.3, 7.4}), - NDArrayFactory::create('c', {1}, {0.0})}); + sd::ops::dynamic_partition op; + int numPartition = 4; + auto result = op.evaluate({&x, &y}, {}, {numPartition}); - sd::ops::dynamic_partition op; - int numPartition = 4; - auto result = op.evaluate({&x, &y}, {}, {numPartition}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(result.size(), + numPartition); // result has the same size as given param 4 - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(result.size(), numPartition); // result has the same size as given param 4 + for (int e = 0; e < result.size(); e++) { + auto output = result.at(e); - for (int e = 0; e < result.size(); e++) { - auto output = result.at(e); - - ASSERT_TRUE(exp[e].isSameShape(output)); - ASSERT_TRUE(exp[e].equalsTo(output)); - } - - + ASSERT_TRUE(exp[e].isSameShape(output)); + ASSERT_TRUE(exp[e].equalsTo(output)); + } } - TEST_F(DeclarableOpsTests5, DynamicPartition_3) { - - auto x = NDArrayFactory::create('c', {2, 4}, {0.1f, -1.f, 5.2f, 4.3f, -1.f, 7.4f, 0.0f, -2.2f}); - auto y = NDArrayFactory::create('c', {2, 4}, {0, 1, 0, 2, 0, 2, 3, 0}); - - std::vector exp( {NDArrayFactory::create({0.1f, 5.2f, -1.f, -2.2f}), - NDArrayFactory::create('c', {1}, {-1.f}), - NDArrayFactory::create({4.3f, 7.4f}), - NDArrayFactory::create('c', {1}, {0.0f})}); - - sd::ops::dynamic_partition op; - int numPartition = 4; - auto result = op.evaluate({&x, &y}, {}, {numPartition}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(result.size(), numPartition); // result has the same size as given param 4 - - for (int e = 0; e < result.size(); e++) { - auto output = result.at(e); - if (output.shapeInfo()) - { - // output->printShapeInfo("Output shape> "); - // exp[e].printShapeInfo("Expected shape> "); - // output->printIndexedBuffer("Output data> "); - - ASSERT_TRUE(exp[e].isSameShape(output)); - ASSERT_TRUE(exp[e].equalsTo(output)); - } - else - { - ASSERT_TRUE(exp[e].lengthOf() == 0); - } + auto x = NDArrayFactory::create( + 'c', {2, 4}, {0.1f, -1.f, 5.2f, 4.3f, -1.f, 7.4f, 0.0f, -2.2f}); + auto y = + NDArrayFactory::create('c', {2, 4}, {0, 1, 0, 2, 0, 2, 3, 0}); + + std::vector exp( + {NDArrayFactory::create({0.1f, 5.2f, -1.f, -2.2f}), + NDArrayFactory::create('c', {1}, {-1.f}), + NDArrayFactory::create({4.3f, 7.4f}), + NDArrayFactory::create('c', {1}, {0.0f})}); + + sd::ops::dynamic_partition op; + int numPartition = 4; + auto result = op.evaluate({&x, &y}, {}, {numPartition}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(result.size(), + numPartition); // result has the same size as given param 4 + + for (int e = 0; e < result.size(); e++) { + auto output = result.at(e); + if (output.shapeInfo()) { + // output->printShapeInfo("Output shape> "); + // exp[e].printShapeInfo("Expected shape> "); + // output->printIndexedBuffer("Output data> "); + + ASSERT_TRUE(exp[e].isSameShape(output)); + ASSERT_TRUE(exp[e].equalsTo(output)); + } else { + ASSERT_TRUE(exp[e].lengthOf() == 0); } - - + } } TEST_F(DeclarableOpsTests5, DynamicStitch_empty_1) { - auto i0 = NDArrayFactory::create('c', {2}, {2, 3}); - auto i1 = NDArrayFactory::empty(); - auto i2 = NDArrayFactory::create('c', {2}, {0, 1}); - - auto d0 = NDArrayFactory::create('c', {2, 5}, {0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126}); - auto d1 = NDArrayFactory::empty(); - auto d2 = NDArrayFactory::create('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326}); + auto i0 = NDArrayFactory::create('c', {2}, {2, 3}); + auto i1 = NDArrayFactory::empty(); + auto i2 = NDArrayFactory::create('c', {2}, {0, 1}); - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto d0 = NDArrayFactory::create( + 'c', {2, 5}, + {0.085571885, 0.7937801, 0.65908563, 0.55552566, 0.15962744, 0.7787856, + 0.80119777, 0.72437465, 0.23089433, 0.72714126}); + auto d1 = NDArrayFactory::empty(); + auto d2 = NDArrayFactory::create( + 'c', {2, 5}, + {0.94414854, 0.5956861, 0.8668989, 0.3502196, 0.5100082, 0.061725974, + 0.6621324, 0.034165382, 0.32576954, 0.51917326}); - + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(DeclarableOpsTests5, DynamicStitch_empty_2) { - auto i0 = NDArrayFactory::create('c', {2}, {2, 3}); - auto i1 = NDArrayFactory::create('c', {0}); - auto i2 = NDArrayFactory::create('c', {2}, {0, 1}); - - auto d0 = NDArrayFactory::create('c', {2, 5}, {0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126}); - auto d1 = NDArrayFactory::create('c', {0, 5}); - auto d2 = NDArrayFactory::create('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326}); + auto i0 = NDArrayFactory::create('c', {2}, {2, 3}); + auto i1 = NDArrayFactory::create('c', {0}); + auto i2 = NDArrayFactory::create('c', {2}, {0, 1}); - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto d0 = NDArrayFactory::create( + 'c', {2, 5}, + {0.085571885, 0.7937801, 0.65908563, 0.55552566, 0.15962744, 0.7787856, + 0.80119777, 0.72437465, 0.23089433, 0.72714126}); + auto d1 = NDArrayFactory::create('c', {0, 5}); + auto d2 = NDArrayFactory::create( + 'c', {2, 5}, + {0.94414854, 0.5956861, 0.8668989, 0.3502196, 0.5100082, 0.061725974, + 0.6621324, 0.034165382, 0.32576954, 0.51917326}); - + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, DynamicStitch_1) { + auto x1 = NDArrayFactory::create({1, 3, 5, 0}); + auto x2 = NDArrayFactory::create({2, 4}); + auto y2 = NDArrayFactory::create({-1., -1.}); + auto y1 = NDArrayFactory::create({0.1f, 5.2f, 4.3f, 7.4f}); - auto x1 = NDArrayFactory::create({1, 3, 5, 0}); - auto x2 = NDArrayFactory::create({2, 4}); - auto y2 = NDArrayFactory::create({-1., -1.}); - auto y1 = NDArrayFactory::create({0.1f, 5.2f, 4.3f, 7.4f}); - - - auto exp = NDArrayFactory::create({7.4f, 0.1f, -1.f, 5.2f, -1.f, 4.3f}); - - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&x1, &x2, &y1, &y2}, {}, {}); + auto exp = + NDArrayFactory::create({7.4f, 0.1f, -1.f, 5.2f, -1.f, 4.3f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&x1, &x2, &y1, &y2}, {}, {}); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, DynamicStitch_2) { + auto x1 = NDArrayFactory::create({1, 3}); + auto x2 = NDArrayFactory::create({5, 0, 2, 4}); + auto y1 = NDArrayFactory::create({-1.f, -1.f}); + auto y2 = NDArrayFactory::create({0.1f, 5.2f, 4.3f, 7.4f}); - auto x1 = NDArrayFactory::create({1, 3}); - auto x2 = NDArrayFactory::create({5, 0, 2, 4}); - auto y1 = NDArrayFactory::create({-1.f, -1.f}); - auto y2 = NDArrayFactory::create({0.1f, 5.2f, 4.3f, 7.4f}); + auto exp = + NDArrayFactory::create({5.2f, -1.f, 4.3f, -1.f, 7.4f, 0.1f}); + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&x1, &x2, &y1, &y2}, {}, {}); - auto exp = NDArrayFactory::create({5.2f, -1.f, 4.3f, -1.f, 7.4f, 0.1f}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&x1, &x2, &y1, &y2}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - // output->printShapeInfo("Output shape> "); - // exp.printShapeInfo("Expected shape> "); - // output->printIndexedBuffer("Output data> "); - // exp.printIndexedBuffer("Expected res>"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + // output->printShapeInfo("Output shape> "); + // exp.printShapeInfo("Expected shape> "); + // output->printIndexedBuffer("Output data> "); + // exp.printIndexedBuffer("Expected res>"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, fusedBatchNorm_test1) { - - auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); - x.linspace(1); - auto scale = NDArrayFactory::create('c', {4}); - - scale = 0.5; - auto offset = NDArrayFactory::create('c', {4}); - offset = 2.; - auto expY = NDArrayFactory::create('c', {2, 2, 3, 4}, {1.20337462, 1.20337462, 1.20337462, 1.20337462, 1.34821558, 1.34821558, 1.34821558, 1.34821558, 1.49305654, 1.49305654, 1.49305654, 1.49305654, 1.63789749, 1.63789749, 1.63789749, 1.63789749, 1.78273857, 1.78273857, 1.78273857, 1.78273857, 1.92757952, 1.92757952, 1.92757952, 1.92757952, 2.0724206 , 2.0724206 , 2.0724206 , 2.0724206 , 2.21726155, 2.21726155, 2.21726155, 2.21726155, 2.36210251, 2.36210251, 2.36210251, 2.36210251, 2.50694346, 2.50694346, 2.50694346, 2.50694346, 2.65178442, 2.65178442, 2.65178442, 2.65178442, 2.79662538, 2.79662538, 2.79662538, 2.79662538}); - auto expBatchMean = NDArrayFactory::create('c', {4}, {23., 24., 25., 26.}); - auto expBatchVar = NDArrayFactory::create('c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526}); - - - sd::ops::fused_batch_norm op; - auto results = op.evaluate({&x, &scale, &offset}, {}, {0,1}); - auto y = results.at(0); - auto batchMean = results.at(1); - auto batchVar = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expY.isSameShape(y)); - ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); - ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); - - + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); + x.linspace(1); + auto scale = NDArrayFactory::create('c', {4}); + + scale = 0.5; + auto offset = NDArrayFactory::create('c', {4}); + offset = 2.; + auto expY = NDArrayFactory::create( + 'c', {2, 2, 3, 4}, + {1.20337462, 1.20337462, 1.20337462, 1.20337462, 1.34821558, 1.34821558, + 1.34821558, 1.34821558, 1.49305654, 1.49305654, 1.49305654, 1.49305654, + 1.63789749, 1.63789749, 1.63789749, 1.63789749, 1.78273857, 1.78273857, + 1.78273857, 1.78273857, 1.92757952, 1.92757952, 1.92757952, 1.92757952, + 2.0724206, 2.0724206, 2.0724206, 2.0724206, 2.21726155, 2.21726155, + 2.21726155, 2.21726155, 2.36210251, 2.36210251, 2.36210251, 2.36210251, + 2.50694346, 2.50694346, 2.50694346, 2.50694346, 2.65178442, 2.65178442, + 2.65178442, 2.65178442, 2.79662538, 2.79662538, 2.79662538, 2.79662538}); + auto expBatchMean = + NDArrayFactory::create('c', {4}, {23., 24., 25., 26.}); + auto expBatchVar = NDArrayFactory::create( + 'c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526}); + + sd::ops::fused_batch_norm op; + auto results = op.evaluate({&x, &scale, &offset}, {}, {0, 1}); + auto y = results.at(0); + auto batchMean = results.at(1); + auto batchVar = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expY.isSameShape(y)); + ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); + ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, fusedBatchNorm_test2) { - - auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); - x.linspace(1); - - auto scale = NDArrayFactory::create('c', {4}); - - scale = 0.5; - auto offset = NDArrayFactory::create('c', {4}); - offset = 2.; - auto expY = NDArrayFactory::create('c', {2, 2, 3, 4}, {1.20347691, 1.20347691, 1.20347691, 1.20347691, 1.34829926, 1.34829926, 1.34829926, 1.34829926, 1.49312162, 1.49312162, 1.49312162, 1.49312162, 1.6379441 , 1.6379441 , 1.6379441 , 1.6379441 , 1.78276646, 1.78276646, 1.78276646, 1.78276646, 1.92758882, 1.92758882, 1.92758882, 1.92758882, 2.0724113 , 2.0724113 , 2.0724113 , 2.0724113 , 2.21723366, 2.21723366, 2.21723366, 2.21723366, 2.36205602, 2.36205602, 2.36205602, 2.36205602, 2.50687838, 2.50687838, 2.50687838, 2.50687838, 2.65170074, 2.65170074, 2.65170074, 2.65170074, 2.79652309, 2.79652309, 2.79652309, 2.79652309}); - auto expBatchMean = NDArrayFactory::create('c', {4}, {23., 24., 25., 26.}); - auto expBatchVar = NDArrayFactory::create('c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526}); - - sd::ops::fused_batch_norm op; - auto results = op.evaluate({&x, &scale, &offset}, {0.05}, {0,1}); - auto y = results.at(0); - auto batchMean = results.at(1); - auto batchVar = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expY.isSameShape(y)); - ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); - ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); - - + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); + x.linspace(1); + + auto scale = NDArrayFactory::create('c', {4}); + + scale = 0.5; + auto offset = NDArrayFactory::create('c', {4}); + offset = 2.; + auto expY = NDArrayFactory::create( + 'c', {2, 2, 3, 4}, + {1.20347691, 1.20347691, 1.20347691, 1.20347691, 1.34829926, 1.34829926, + 1.34829926, 1.34829926, 1.49312162, 1.49312162, 1.49312162, 1.49312162, + 1.6379441, 1.6379441, 1.6379441, 1.6379441, 1.78276646, 1.78276646, + 1.78276646, 1.78276646, 1.92758882, 1.92758882, 1.92758882, 1.92758882, + 2.0724113, 2.0724113, 2.0724113, 2.0724113, 2.21723366, 2.21723366, + 2.21723366, 2.21723366, 2.36205602, 2.36205602, 2.36205602, 2.36205602, + 2.50687838, 2.50687838, 2.50687838, 2.50687838, 2.65170074, 2.65170074, + 2.65170074, 2.65170074, 2.79652309, 2.79652309, 2.79652309, 2.79652309}); + auto expBatchMean = + NDArrayFactory::create('c', {4}, {23., 24., 25., 26.}); + auto expBatchVar = NDArrayFactory::create( + 'c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526}); + + sd::ops::fused_batch_norm op; + auto results = op.evaluate({&x, &scale, &offset}, {0.05}, {0, 1}); + auto y = results.at(0); + auto batchMean = results.at(1); + auto batchVar = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expY.isSameShape(y)); + ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); + ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, fusedBatchNorm_test3) { - - auto x = NDArrayFactory::create('c', {2, 4, 2, 3}); - x.linspace(1); - - auto scale = NDArrayFactory::create('c', {4}); - - scale = 0.5; - auto offset = NDArrayFactory::create('c', {4}); - offset = 2.; - auto expY = NDArrayFactory::create('c', {2, 4, 2, 3}, {1.20337462, 1.20337462, 1.20337462, 1.20337462, 1.34821558, 1.34821558, 1.34821558, 1.34821558, 1.49305654, 1.49305654, 1.49305654, 1.49305654, 1.63789749, 1.63789749, 1.63789749, 1.63789749, 1.78273857, 1.78273857, 1.78273857, 1.78273857, 1.92757952, 1.92757952, 1.92757952, 1.92757952, 2.0724206 , 2.0724206 , 2.0724206 , 2.0724206 , 2.21726155, 2.21726155, 2.21726155, 2.21726155, 2.36210251, 2.36210251, 2.36210251, 2.36210251, 2.50694346, 2.50694346, 2.50694346, 2.50694346, 2.65178442, 2.65178442, 2.65178442, 2.65178442, 2.79662538, 2.79662538, 2.79662538, 2.79662538}); - auto expBatchMean = NDArrayFactory::create('c', {4}, {23., 24., 25., 26.}); - auto expBatchVar = NDArrayFactory::create('c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526}); - - sd::ops::fused_batch_norm op; - auto results = op.evaluate({&x, &scale, &offset}, {}, {1,1}); - auto y = results.at(0); - auto batchMean = results.at(1); - auto batchVar = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expY.isSameShape(y)); - ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); - ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); - - + auto x = NDArrayFactory::create('c', {2, 4, 2, 3}); + x.linspace(1); + + auto scale = NDArrayFactory::create('c', {4}); + + scale = 0.5; + auto offset = NDArrayFactory::create('c', {4}); + offset = 2.; + auto expY = NDArrayFactory::create( + 'c', {2, 4, 2, 3}, + {1.20337462, 1.20337462, 1.20337462, 1.20337462, 1.34821558, 1.34821558, + 1.34821558, 1.34821558, 1.49305654, 1.49305654, 1.49305654, 1.49305654, + 1.63789749, 1.63789749, 1.63789749, 1.63789749, 1.78273857, 1.78273857, + 1.78273857, 1.78273857, 1.92757952, 1.92757952, 1.92757952, 1.92757952, + 2.0724206, 2.0724206, 2.0724206, 2.0724206, 2.21726155, 2.21726155, + 2.21726155, 2.21726155, 2.36210251, 2.36210251, 2.36210251, 2.36210251, + 2.50694346, 2.50694346, 2.50694346, 2.50694346, 2.65178442, 2.65178442, + 2.65178442, 2.65178442, 2.79662538, 2.79662538, 2.79662538, 2.79662538}); + auto expBatchMean = + NDArrayFactory::create('c', {4}, {23., 24., 25., 26.}); + auto expBatchVar = NDArrayFactory::create( + 'c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526}); + + sd::ops::fused_batch_norm op; + auto results = op.evaluate({&x, &scale, &offset}, {}, {1, 1}); + auto y = results.at(0); + auto batchMean = results.at(1); + auto batchVar = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expY.isSameShape(y)); + ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); + ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, fusedBatchNorm_test4) { - - auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); - x.linspace(1); - std::vector shape = {4}; - auto scale = NDArrayFactory::create('c', shape); - auto offset = NDArrayFactory::create('c', shape); - auto mean = NDArrayFactory::create('c', shape); - auto variance = NDArrayFactory::create('c', shape); - - scale = 0.5; - offset = 2.; - mean = 25.; - variance = 5.; - - auto expY = NDArrayFactory::create('c', {2, 2, 3, 4}, {-3.36602688, -3.14244223, -2.91885757, -2.6952734 , -2.47168875, -2.24810457, -2.02451992, -1.80093551, -1.57735109, -1.35376668, -1.13018227, -0.90659785, -0.68301344, -0.45942879, -0.23584437, -0.01225996, 0.21132445, 0.43490887, 0.65849328, 0.88207781, 1.10566223, 1.32924664, 1.55283117, 1.77641559, 2. , 2.22358441, 2.44716883, 2.67075348, 2.89433765, 3.11792231, 3.34150672, 3.56509113, 3.78867555, 4.01225996, 4.23584461, 4.45942879, 4.68301344, 4.90659809, 5.13018227, 5.35376644, 5.57735109, 5.80093575, 6.02451992, 6.24810457, 6.47168875, 6.6952734 , 6.91885757, 7.14244223}); - auto expBatchMean = NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); - auto expBatchVar = NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); - - - sd::ops::fused_batch_norm op; - auto results = op.evaluate({&x, &scale, &offset}, {}, {0,1}); - auto y = results.at(0); - auto batchMean = results.at(1); - auto batchVar = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expY.isSameShape(y)); - ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); - ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); - - + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); + x.linspace(1); + std::vector shape = {4}; + auto scale = NDArrayFactory::create('c', shape); + auto offset = NDArrayFactory::create('c', shape); + auto mean = NDArrayFactory::create('c', shape); + auto variance = NDArrayFactory::create('c', shape); + + scale = 0.5; + offset = 2.; + mean = 25.; + variance = 5.; + + auto expY = NDArrayFactory::create( + 'c', {2, 2, 3, 4}, + {-3.36602688, -3.14244223, -2.91885757, -2.6952734, -2.47168875, + -2.24810457, -2.02451992, -1.80093551, -1.57735109, -1.35376668, + -1.13018227, -0.90659785, -0.68301344, -0.45942879, -0.23584437, + -0.01225996, 0.21132445, 0.43490887, 0.65849328, 0.88207781, + 1.10566223, 1.32924664, 1.55283117, 1.77641559, 2., + 2.22358441, 2.44716883, 2.67075348, 2.89433765, 3.11792231, + 3.34150672, 3.56509113, 3.78867555, 4.01225996, 4.23584461, + 4.45942879, 4.68301344, 4.90659809, 5.13018227, 5.35376644, + 5.57735109, 5.80093575, 6.02451992, 6.24810457, 6.47168875, + 6.6952734, 6.91885757, 7.14244223}); + auto expBatchMean = + NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); + auto expBatchVar = + NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); + + sd::ops::fused_batch_norm op; + auto results = op.evaluate({&x, &scale, &offset}, {}, {0, 1}); + auto y = results.at(0); + auto batchMean = results.at(1); + auto batchVar = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expY.isSameShape(y)); + ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); + ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, fusedBatchNorm_test5) { - - auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); - x.linspace(1); - std::vector shape = {4}; - auto scale = NDArrayFactory::create('c', shape); - auto offset = NDArrayFactory::create('c', shape); - auto mean = NDArrayFactory::create('c', shape); - auto variance = NDArrayFactory::create('c', shape); - - scale = 0.5; - offset = 2.; - mean = 25.; - variance = 5.; - - auto expY = NDArrayFactory::create('c', {2, 2, 3, 4}, {-3.33992958e+00, -3.11743259e+00, -2.89493513e+00, -2.67243814e+00, -2.44994116e+00, -2.22744417e+00, -2.00494719e+00, -1.78244996e+00, -1.55995297e+00, -1.33745599e+00, -1.11495876e+00, -8.92461777e-01, -6.69964790e-01, -4.47467566e-01, -2.24970579e-01, -2.47359276e-03, 2.20023513e-01, 4.42520618e-01, 6.65017605e-01, 8.87514710e-01, 1.11001182e+00, 1.33250880e+00, 1.55500591e+00, 1.77750289e+00, 2.00000000e+00, 2.22249699e+00, 2.44499421e+00, 2.66749120e+00, 2.88998818e+00, 3.11248541e+00, 3.33498240e+00, 3.55747938e+00, 3.77997637e+00, 4.00247383e+00, 4.22497082e+00, 4.44746780e+00, 4.66996479e+00, 4.89246178e+00, 5.11495876e+00, 5.33745575e+00, 5.55995274e+00, 5.78244972e+00, 6.00494719e+00, 6.22744417e+00, 6.44994116e+00, 6.67243814e+00, 6.89493513e+00, 7.11743259e+00}); - auto expBatchMean = NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); - auto expBatchVar = NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); - - - sd::ops::fused_batch_norm op; - auto results = op.evaluate({&x, &scale, &offset}, {0.05}, {0,1}); - auto y = results.at(0); - auto batchMean = results.at(1); - auto batchVar = results.at(2); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expY.isSameShape(y)); - ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); - ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); - - + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); + x.linspace(1); + std::vector shape = {4}; + auto scale = NDArrayFactory::create('c', shape); + auto offset = NDArrayFactory::create('c', shape); + auto mean = NDArrayFactory::create('c', shape); + auto variance = NDArrayFactory::create('c', shape); + + scale = 0.5; + offset = 2.; + mean = 25.; + variance = 5.; + + auto expY = NDArrayFactory::create( + 'c', {2, 2, 3, 4}, + {-3.33992958e+00, -3.11743259e+00, -2.89493513e+00, -2.67243814e+00, + -2.44994116e+00, -2.22744417e+00, -2.00494719e+00, -1.78244996e+00, + -1.55995297e+00, -1.33745599e+00, -1.11495876e+00, -8.92461777e-01, + -6.69964790e-01, -4.47467566e-01, -2.24970579e-01, -2.47359276e-03, + 2.20023513e-01, 4.42520618e-01, 6.65017605e-01, 8.87514710e-01, + 1.11001182e+00, 1.33250880e+00, 1.55500591e+00, 1.77750289e+00, + 2.00000000e+00, 2.22249699e+00, 2.44499421e+00, 2.66749120e+00, + 2.88998818e+00, 3.11248541e+00, 3.33498240e+00, 3.55747938e+00, + 3.77997637e+00, 4.00247383e+00, 4.22497082e+00, 4.44746780e+00, + 4.66996479e+00, 4.89246178e+00, 5.11495876e+00, 5.33745575e+00, + 5.55995274e+00, 5.78244972e+00, 6.00494719e+00, 6.22744417e+00, + 6.44994116e+00, 6.67243814e+00, 6.89493513e+00, 7.11743259e+00}); + auto expBatchMean = + NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); + auto expBatchVar = + NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); + + sd::ops::fused_batch_norm op; + auto results = op.evaluate({&x, &scale, &offset}, {0.05}, {0, 1}); + auto y = results.at(0); + auto batchMean = results.at(1); + auto batchVar = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expY.isSameShape(y)); + ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); + ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, confusion_matrix_test1) { + auto labels = NDArrayFactory::create('c', {1, 3}, {1, 2, 4}); + auto predictions = NDArrayFactory::create('c', {1, 3}, {2, 2, 4}); + auto expected = NDArrayFactory::create( + 'c', {5, 5}, {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}); - auto labels = NDArrayFactory::create('c', {1, 3}, {1, 2, 4}); - auto predictions = NDArrayFactory::create('c', {1, 3}, {2, 2, 4}); - auto expected = NDArrayFactory::create('c', {5, 5}, {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}); + sd::ops::confusion_matrix op; + auto results = op.evaluate({&labels, &predictions}, {}, {}); + ASSERT_EQ(Status::OK(), results.status()); - sd::ops::confusion_matrix op; - auto results = op.evaluate({&labels, &predictions}, {}, {}); - ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, confusion_matrix_test2) { + auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); + auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); + auto expected = NDArrayFactory::create('c', {3, 3}, + {0, 0, 0, 1, 0, 0, 0, 0, 1}); - auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); - auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); - auto expected = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 1, 0, 0, 0, 0, 1}); + sd::ops::confusion_matrix op; + auto results = op.evaluate({&labels, &predictions}, {}, {3}); + ASSERT_EQ(Status::OK(), results.status()); - sd::ops::confusion_matrix op; - auto results = op.evaluate({&labels, &predictions}, {}, {3}); - ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, confusion_matrix_test3) { + auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); + auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); + auto weights = NDArrayFactory::create('c', {1, 2}, {100, 200}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200}); - auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); - auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); - auto weights = NDArrayFactory::create('c', {1, 2}, {100, 200}); - auto expected = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200}); - - sd::ops::confusion_matrix op; - auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3}); - auto output = results.at(0); + sd::ops::confusion_matrix op; + auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, confusion_matrix_test4) { + auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); + auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); + auto weights = NDArrayFactory::create('c', {1, 2}, {100, 200}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200}); - auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); - auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); - auto weights = NDArrayFactory::create('c', {1, 2}, {100, 200}); - auto expected = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200}); - - sd::ops::confusion_matrix op; - auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3}, {}, {sd::DataType::DOUBLE}); - auto output = results.at(0); + sd::ops::confusion_matrix op; + auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3}, {}, + {sd::DataType::DOUBLE}); + auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, ZeroFraction_1) { + auto x = NDArrayFactory::create( + 'c', {3, 4, 2}, {0, 20, 30, 0, 50, 0, 70, 0, 90, 0, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 0, 21, 22, 23, 24}); - auto x = NDArrayFactory::create('c', {3, 4, 2}, {0, 20, 30, 0, 50, 0, - 70, 0, 90, 0, 11, 12, - 13, 14, 15, 16, 17, 18, - 19, 0, 21, 22, 23, 24}); - - sd::ops::zero_fraction op; - auto res = op.evaluate({&x}, {}, {}); + sd::ops::zero_fraction op; + auto res = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), res.status()); - ASSERT_TRUE(res.at(0).isScalar()); - ASSERT_EQ(res.at(0).e(0), 0.25); - - + ASSERT_EQ(Status::OK(), res.status()); + ASSERT_TRUE(res.at(0).isScalar()); + ASSERT_EQ(res.at(0).e(0), 0.25); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, ZeroFraction_2) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4}); - auto x = NDArrayFactory::create('c', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4}); - - sd::ops::zero_fraction op; - auto res = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(Status::OK(), res.status()); - ASSERT_TRUE(res.at(0).isScalar()); - ASSERT_EQ(res.at(0).e(0), 0.375); + sd::ops::zero_fraction op; + auto res = op.evaluate({&x}, {}, {}); - + ASSERT_EQ(Status::OK(), res.status()); + ASSERT_TRUE(res.at(0).isScalar()); + ASSERT_EQ(res.at(0).e(0), 0.375); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, ZeroFraction_3) { + auto x = NDArrayFactory::create( + 'f', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4}); - auto x = NDArrayFactory::create('f', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4}); + sd::ops::zero_fraction op; + auto res = op.evaluate({&x}, {}, {}); - sd::ops::zero_fraction op; - auto res = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(Status::OK(), res.status()); - ASSERT_TRUE(res.at(0).isScalar()); - ASSERT_EQ(res.at(0).e(0), 0.375); - - + ASSERT_EQ(Status::OK(), res.status()); + ASSERT_TRUE(res.at(0).isScalar()); + ASSERT_EQ(res.at(0).e(0), 0.375); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_1) { + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto y = NDArrayFactory::create('c', {3, 2}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = NDArrayFactory::create({100.f, 200.f}); - auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto y = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create({ 100.f, 200.f }); - - auto exp = NDArrayFactory::create('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f }); + auto exp = + NDArrayFactory::create('c', {2, 2}, {173.f, 264.f, 310.f, 279.f}); - sd::ops::xw_plus_b op; - auto result = op.evaluate({ &x, &y, &b }); + sd::ops::xw_plus_b op; + auto result = op.evaluate({&x, &y, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_2) { + auto x = NDArrayFactory::create('c', {1, 2}, {1.f, 11.f}); + auto y = NDArrayFactory::create('c', {2, 3}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = NDArrayFactory::create({100.f, 200.f, 300.f}); - auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); - auto y = NDArrayFactory::create('c', { 2, 3 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create({ 100.f, 200.f, 300.f }); + auto exp = NDArrayFactory::create('c', {1, 3}, {166.f, 269.f, 326.f}); - auto exp = NDArrayFactory::create('c', { 1, 3 }, { 166.f, 269.f, 326.f }); + sd::ops::xw_plus_b op; + auto result = op.evaluate({&x, &y, &b}, {}, {}); - sd::ops::xw_plus_b op; - auto result = op.evaluate({ &x, &y, &b }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_3) { + auto x = NDArrayFactory::create('c', {1, 2}, {1.f, 11.f}); + auto y = NDArrayFactory::create('c', {2, 1}, {11.f, 3.f}); + auto b = NDArrayFactory::create('c', {1}, {200.f}); - auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); - auto y = NDArrayFactory::create('c', { 2, 1 }, { 11.f, 3.f }); - auto b = NDArrayFactory::create('c', { 1 }, { 200.f }); - - auto exp = NDArrayFactory::create('c', { 1,1 }, { 244.f }); + auto exp = NDArrayFactory::create('c', {1, 1}, {244.f}); - sd::ops::xw_plus_b op; - auto result = op.evaluate({ &x, &y, &b }); + sd::ops::xw_plus_b op; + auto result = op.evaluate({&x, &y, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_4) { + auto x = NDArrayFactory::create('f', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto y = NDArrayFactory::create('f', {3, 2}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); + auto b = NDArrayFactory::create({100.f, 200.f}); - auto x = NDArrayFactory::create('f', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto y = NDArrayFactory::create('f', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - auto b = NDArrayFactory::create({ 100.f, 200.f }); + auto exp = + NDArrayFactory::create('f', {2, 2}, {140.f, 287.f, 233.f, 351.f}); - auto exp = NDArrayFactory::create('f', { 2,2 }, { 140.f, 287.f, 233.f, 351.f }); + sd::ops::xw_plus_b op; + auto result = op.evaluate({&x, &y, &b}); - sd::ops::xw_plus_b op; - auto result = op.evaluate({ &x, &y, &b }); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_5) { + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto y = NDArrayFactory::create('c', {3, 2}, + {11.f, 3.f, 4.f, 5.f, 6.f, 2.f}); - auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto y = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); - - y = y.transpose(); - - auto b = NDArrayFactory::create({ 100.f, 200.f }); + y = y.transpose(); - auto exp = NDArrayFactory::create('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f }); + auto b = NDArrayFactory::create({100.f, 200.f}); + auto exp = + NDArrayFactory::create('c', {2, 2}, {173.f, 264.f, 310.f, 279.f}); - sd::ops::xw_plus_b op; - auto result = op.evaluate({ &x, &y, &b }, {}, { 1 }); + sd::ops::xw_plus_b op; + auto result = op.evaluate({&x, &y, &b}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_6) { + auto x = NDArrayFactory::create('c', {3, 2}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto y = NDArrayFactory::create('c', {2, 1}, {11.f, 3.f}); - auto x = NDArrayFactory::create('c', { 3, 2 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto y = NDArrayFactory::create('c', { 2, 1 }, { 11.f, 3.f }); + auto b = NDArrayFactory::create('c', {1}, {100.f}); - auto b = NDArrayFactory::create('c', { 1 }, { 100.f }); + auto exp = NDArrayFactory::create('c', {3, 1}, {144.f, 175.f, 173.f}); - auto exp = NDArrayFactory::create('c', { 3, 1 }, { 144.f, 175.f, 173.f }); + sd::ops::xw_plus_b op; + auto result = op.evaluate({&x, &y, &b}); - sd::ops::xw_plus_b op; - auto result = op.evaluate({ &x, &y, &b }); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_7) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + auto y = NDArrayFactory::create( + 'c', {4, 5}, {11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, + 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 3.f, 11.f, 3.f, 11.f}); - auto x = NDArrayFactory::create('c', { 3, 4 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); - auto y = NDArrayFactory::create('c', { 4, 5 }, { 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 3.f, 11.f, 3.f, 11.f }); - - auto b = NDArrayFactory::create('c', { 5 }, { 100.f, 200.f, 300.f, 400.f, 500.f }); + auto b = NDArrayFactory::create('c', {5}, + {100.f, 200.f, 300.f, 400.f, 500.f}); - auto exp = NDArrayFactory::create('c', { 3, 5 }, { 219.f, 375.f, 531.f, 575.f, 731.f, 217.f, 317.f, 505.f, 517.f, 705.f, 248.f, 396.f, 496.f, 596.f, 696.f }); + auto exp = NDArrayFactory::create( + 'c', {3, 5}, + {219.f, 375.f, 531.f, 575.f, 731.f, 217.f, 317.f, 505.f, 517.f, 705.f, + 248.f, 396.f, 496.f, 596.f, 696.f}); - sd::ops::xw_plus_b op; - auto result = op.evaluate({ &x, &y, &b }); + sd::ops::xw_plus_b op; + auto result = op.evaluate({&x, &y, &b}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, StopGradient_1) { + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); - auto x = NDArrayFactory::create('c', {2,3}, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); - - sd::ops::stop_gradient op; - auto result = op.evaluate({&x}); + sd::ops::stop_gradient op; + auto result = op.evaluate({&x}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + auto output = result.at(0); - // output->printShapeInfo("Output shape> "); - // x.printShapeInfo("Expected shape> "); - // output->printIndexedBuffer("Output data> "); - // x.printIndexedBuffer("Expected res>"); - - ASSERT_TRUE(x.isSameShape(output)); - ASSERT_TRUE(x.equalsTo(output)); + // output->printShapeInfo("Output shape> "); + // x.printShapeInfo("Expected shape> "); + // output->printIndexedBuffer("Output data> "); + // x.printIndexedBuffer("Expected res>"); - + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, StopGradient_2) { + auto x = NDArrayFactory::create('f', {2, 3}, + {1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); - auto x = NDArrayFactory::create('f', {2,3}, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + sd::ops::stop_gradient op; + auto result = op.evaluate({&x}); - sd::ops::stop_gradient op; - auto result = op.evaluate({&x}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto output = result.at(0); - - // output->printShapeInfo("Output shape> "); - // x.printShapeInfo("Expected shape> "); - // output->printIndexedBuffer("Output data> "); - // x.printIndexedBuffer("Expected res>"); + auto output = result.at(0); - ASSERT_TRUE(x.isSameShape(output)); - ASSERT_TRUE(x.equalsTo(output)); + // output->printShapeInfo("Output shape> "); + // x.printShapeInfo("Expected shape> "); + // output->printIndexedBuffer("Output data> "); + // x.printIndexedBuffer("Expected res>"); - + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test1) { + auto input = NDArrayFactory::create( + 'c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, + -8, 8, -9, 9, -10, 10, -11, 11, -12, 12, -13, 13, 14}); + auto expOutput = NDArrayFactory::create( + 'c', {3, 3, 3}, + {-2.16985e+00, -1.69846e-01, -3.16985e+00, -1.31507e+00, -6.31507e+00, + -3.15072e-01, -8.00046e+00, -4.58767e-04, -9.00046e+00, -1.31327e+00, + -1.23133e+01, -3.13266e-01, -1.40000e+01, -1.13743e-06, -1.50000e+01, + -1.31326e+00, -1.83133e+01, -3.13262e-01, -2.00000e+01, -2.81941e-09, + -2.10000e+01, -1.31326e+00, -2.43133e+01, -3.13262e-01, -2.73133e+01, + -1.31326e+00, -3.13262e-01}); - auto input = NDArrayFactory::create('c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14}); - auto expOutput = NDArrayFactory::create('c', {3, 3, 3}, {-2.16985e+00,-1.69846e-01,-3.16985e+00, -1.31507e+00,-6.31507e+00,-3.15072e-01, -8.00046e+00,-4.58767e-04,-9.00046e+00, -1.31327e+00,-1.23133e+01,-3.13266e-01, -1.40000e+01,-1.13743e-06,-1.50000e+01, -1.31326e+00,-1.83133e+01,-3.13262e-01, -2.00000e+01,-2.81941e-09,-2.10000e+01, -1.31326e+00,-2.43133e+01,-3.13262e-01, -2.73133e+01,-1.31326e+00,-3.13262e-01}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test2) { + auto input = NDArrayFactory::create( + 'c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, + -8, 8, -9, 9, -10, 10, -11, 11, -12, 12, -13, 13, 14}); + auto expOutput = NDArrayFactory::create( + 'c', {3, 3, 3}, + {-3.05095e+00, -3.04946e+00, -5.00705e+00, -5.09458e-02, -7.04946e+00, + -7.04851e-03, -6.05095e+00, -4.94556e-02, -8.00705e+00, -3.04859e+00, + -1.30000e+01, -3.04859e+00, -1.50486e+01, -2.37286e-06, -1.70486e+01, + -4.85876e-02, -1.60000e+01, -4.85874e-02, -2.10000e+01, -3.04859e+00, + -2.51269e+01, -7.96007e-10, -2.50486e+01, -2.12693e+00, -2.40000e+01, + -4.85874e-02, -1.26928e-01}); - auto input = NDArrayFactory::create('c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14}); - auto expOutput = NDArrayFactory::create('c', {3, 3, 3}, {-3.05095e+00,-3.04946e+00,-5.00705e+00, -5.09458e-02,-7.04946e+00,-7.04851e-03, -6.05095e+00,-4.94556e-02,-8.00705e+00, -3.04859e+00,-1.30000e+01,-3.04859e+00, -1.50486e+01,-2.37286e-06,-1.70486e+01, -4.85876e-02,-1.60000e+01,-4.85874e-02, -2.10000e+01,-3.04859e+00,-2.51269e+01, -7.96007e-10,-2.50486e+01,-2.12693e+00, -2.40000e+01,-4.85874e-02,-1.26928e-01}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}, {}, {1}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + sd::ops::log_softmax op; + auto results = op.evaluate({&input}, {}, {1}); + auto z = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test3) { + auto input = NDArrayFactory::create( + 'c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, + -8, 8, -9, 9, -10, 10, -11, 11, -12, 12, -13, 13, 14}); + auto expOutput = NDArrayFactory::create( + 'c', {3, 3, 3}, + {-2.16985e+00, -1.69846e-01, -3.16985e+00, -1.31507e+00, -6.31507e+00, + -3.15072e-01, -8.00046e+00, -4.58767e-04, -9.00046e+00, -1.31327e+00, + -1.23133e+01, -3.13266e-01, -1.40000e+01, -1.13743e-06, -1.50000e+01, + -1.31326e+00, -1.83133e+01, -3.13262e-01, -2.00000e+01, -2.81941e-09, + -2.10000e+01, -1.31326e+00, -2.43133e+01, -3.13262e-01, -2.73133e+01, + -1.31326e+00, -3.13262e-01}); - auto input = NDArrayFactory::create('c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14}); - auto expOutput = NDArrayFactory::create('c', {3, 3, 3}, {-2.16985e+00,-1.69846e-01,-3.16985e+00, -1.31507e+00,-6.31507e+00,-3.15072e-01, -8.00046e+00,-4.58767e-04,-9.00046e+00, -1.31327e+00,-1.23133e+01,-3.13266e-01, -1.40000e+01,-1.13743e-06,-1.50000e+01, -1.31326e+00,-1.83133e+01,-3.13262e-01, -2.00000e+01,-2.81941e-09,-2.10000e+01, -1.31326e+00,-2.43133e+01,-3.13262e-01, -2.73133e+01,-1.31326e+00,-3.13262e-01}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}, {}, {2}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - - + sd::ops::log_softmax op; + auto results = op.evaluate({&input}, {}, {2}); + auto z = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test5) { + auto input = NDArrayFactory::create('c', {3, 3}, + {-1, 1, -2, 2, -3, 3, -4, 4, 5}); + auto expOutput = NDArrayFactory::create( + 'c', {3, 3}, + {-2.16985, -0.16985, -3.16985, -1.31507, -6.31507, -0.31507, -9.31335, + -1.31335, -0.31335}); - auto input = NDArrayFactory::create('c', {3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, 5}); - auto expOutput = NDArrayFactory::create('c', {3, 3}, {-2.16985, -0.16985, -3.16985, -1.31507, -6.31507, -0.31507, -9.31335, -1.31335, -0.31335}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test6) { + auto input = NDArrayFactory::create('c', {3, 3}, + {-1, 1, -2, 2, -3, 3, -4, 4, 5}); + auto expOutput = NDArrayFactory::create( + 'c', {3, 3}, + {-3.05095, -3.04946, -7.12773, -0.05095, -7.04946, -2.12773, -6.05095, + -0.04946, -0.12773}); - auto input = NDArrayFactory::create('c', {3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, 5}); - auto expOutput = NDArrayFactory::create('c', {3, 3}, {-3.05095,-3.04946,-7.12773, -0.05095,-7.04946,-2.12773, -6.05095,-0.04946,-0.12773}); + sd::ops::log_softmax op; + auto results = op.evaluate({&input}, {}, {0}); + auto z = results.at(0); - sd::ops::log_softmax op; - auto results = op.evaluate({&input}, {}, {0}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test7) { + auto input = NDArrayFactory::create('c', {1, 5}, {-1, 1, -2, 2, 3}); + auto expOutput = NDArrayFactory::create( + 'c', {1, 5}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414}); - auto input = NDArrayFactory::create('c', {1, 5}, {-1, 1, -2, 2, 3}); - auto expOutput = NDArrayFactory::create('c', {1, 5}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414}); + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); - sd::ops::log_softmax op; - auto results = op.evaluate({&input}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test8) { + auto input = NDArrayFactory::create('c', {1, 5}, {-1, 1, -2, 2, 3}); + auto expOutput = NDArrayFactory::create('c', {1, 5}, {0, 0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {1, 5}, {-1, 1, -2, 2, 3}); - auto expOutput = NDArrayFactory::create('c', {1, 5}, {0, 0, 0, 0, 0}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}, {}, {0}); - auto z = results.at(0); + sd::ops::log_softmax op; + auto results = op.evaluate({&input}, {}, {0}); + auto z = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test9) { + auto input = NDArrayFactory::create('c', {5, 1}, {-1, 1, -2, 2, 3}); + auto expOutput = NDArrayFactory::create('c', {5, 1}, {0, 0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {5, 1}, {-1, 1, -2, 2, 3}); - auto expOutput = NDArrayFactory::create('c', {5, 1}, {0, 0, 0, 0, 0}); + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); - sd::ops::log_softmax op; - auto results = op.evaluate({&input}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test10) { + auto input = NDArrayFactory::create('c', {5, 1}, {-1, 1, -2, 2, 3}); + auto expOutput = NDArrayFactory::create( + 'c', {5, 1}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414}); - auto input = NDArrayFactory::create('c', {5, 1}, {-1, 1, -2, 2, 3}); - auto expOutput = NDArrayFactory::create('c', {5, 1}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414}); + sd::ops::log_softmax op; + auto results = op.evaluate({&input}, {}, {0}); + auto z = results.at(0); - sd::ops::log_softmax op; - auto results = op.evaluate({&input}, {}, {0}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test11) { + auto input = NDArrayFactory::create('c', {5}, {-1, 1, -2, 2, 3}); + auto expOutput = NDArrayFactory::create( + 'c', {5}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414}); - auto input = NDArrayFactory::create('c', {5}, {-1, 1, -2, 2, 3}); - auto expOutput = NDArrayFactory::create('c', {5}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414}); - - sd::ops::log_softmax op; - auto results = op.evaluate({&input}); - auto z = results.at(0); + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_test12) { + auto input = NDArrayFactory::create( + 'c', {1, 4}, {0.1869, -1.4918, -0.6497, -0.8864}); + auto expOutput = NDArrayFactory::create( + 'c', {1, 4}, {-0.6738, -2.3525, -1.5104, -1.7472}); - auto input = NDArrayFactory::create('c', {1, 4}, {0.1869, -1.4918, -0.6497, -0.8864}); - auto expOutput = NDArrayFactory::create('c', {1, 4}, {-0.6738, -2.3525, -1.5104, -1.7472}); - - for (int i = 0; i < 10; ++i) - { - sd::ops::log_softmax op; - auto results = op.evaluate({&input}); - auto z = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z, 1e-4)); + for (int i = 0; i < 10; ++i) { + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); - - } + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z, 1e-4)); + } } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_bp_test1) { + auto input = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto epsilon = + NDArrayFactory::create('c', {2, 2}, {0.1, 0.2, 0.3, 0.4}); + auto exp = NDArrayFactory::create( + 'c', {2, 2}, {-0.07311, 0.02689, -0.07311, 0.02689}); - auto input = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); - auto epsilon = NDArrayFactory::create('c', {2, 2}, {0.1, 0.2, 0.3, 0.4}); - auto exp = NDArrayFactory::create('c', {2, 2}, {-0.07311,0.02689, -0.07311,0.02689}); - - sd::ops::log_softmax_bp op; - auto results = op.evaluate({&input, &epsilon}); - auto output = results.at(0); + sd::ops::log_softmax_bp op; + auto results = op.evaluate({&input, &epsilon}); + auto output = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_bp_test2) { + auto input = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto epsilon = + NDArrayFactory::create('c', {2, 2}, {0.1, 0.2, 0.3, 0.4}); + auto exp = NDArrayFactory::create( + 'c', {2, 2}, {-0.17616, -0.17616, 0.02384, 0.02384}); - auto input = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); - auto epsilon = NDArrayFactory::create('c', {2, 2}, {0.1, 0.2, 0.3, 0.4}); - auto exp = NDArrayFactory::create('c', {2, 2}, {-0.17616, -0.17616, 0.02384, 0.02384}); - - sd::ops::log_softmax_bp op; - auto results = op.evaluate({&input, &epsilon}, {}, {0}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::log_softmax_bp op; + auto results = op.evaluate({&input, &epsilon}, {}, {0}); + auto output = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, ELU_1) { + auto input = NDArrayFactory::create( + 'c', {2, 2, 2}, {-1., 2., 1.5, -1.4, 1., 2., 2., 1.}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2}, {-0.63212055, 2., 1.5, -0.753403, 1., 2., 2., 1.}); + auto res = NDArrayFactory::create('c', {2, 2, 2}); - auto input = NDArrayFactory::create('c', {2, 2, 2}, { -1., 2. , 1.5, -1.4, 1., 2., 2., 1.}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, { -0.63212055, 2. , 1.5, -0.753403, 1., 2., 2., 1.}); - auto res = NDArrayFactory::create('c', {2, 2, 2}); - - input.applyScalar(sd::scalar::ELU, 1.f, res); + input.applyScalar(sd::scalar::ELU, 1.f, res); - ASSERT_TRUE(res.equalsTo(&exp)); + ASSERT_TRUE(res.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, L2_Loss_1) { + auto input = NDArrayFactory::create( + 'c', {2, 2, 2}, {-1., 2., 1.5, -1.4, 1., 2., 2., 1.}); + double exp(9.605); - auto input = NDArrayFactory::create('c', {2, 2, 2}, { -1., 2. , 1.5, -1.4, 1., 2., 2., 1.}); - double exp(9.605); + sd::ops::l2_loss op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); - sd::ops::l2_loss op; - auto results = op.evaluate({&input}, {}, {}); - auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output.isScalar()); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(output.isScalar()); - - ASSERT_EQ(output.e(0), exp); - - + ASSERT_EQ(output.e(0), exp); } TEST_F(DeclarableOpsTests5, L2_Loss_2) { - auto x = NDArrayFactory::create(0.7787855863571167); - auto e = NDArrayFactory::create(0.303254); + auto x = NDArrayFactory::create(0.7787855863571167); + auto e = NDArrayFactory::create(0.303254); - sd::ops::l2_loss op; - auto results = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), results.status()); + sd::ops::l2_loss op; + auto results = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), results.status()); - auto z = results.at(0); + auto z = results.at(0); - ASSERT_EQ(e, z); - - + ASSERT_EQ(e, z); } TEST_F(DeclarableOpsTests5, L2_Loss_3) { - auto x = NDArrayFactory::create(0.7787855863571167); - auto e = NDArrayFactory::create(0.303254); - auto z = NDArrayFactory::create(0.0); + auto x = NDArrayFactory::create(0.7787855863571167); + auto e = NDArrayFactory::create(0.303254); + auto z = NDArrayFactory::create(0.0); - sd::ops::l2_loss op; - auto status = op.execute({&x}, {&z} , {}, {}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::l2_loss op; + auto status = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, LogPoissonLoss_1) { - auto weights = NDArrayFactory::create('c', {1, 1}, {1}); - - auto input = NDArrayFactory::create('c', {2, 2, 2}, { -1., 2. , 1.5, -1.4, 1., 2., 2., 1.}); - auto targets = NDArrayFactory::create('c', {2, 2, 2}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1.3678794, 5.389056, 2.981689, 1.6465969, 1.7182817, 5.389056, 5.389056, 1.7182817}); + auto input = NDArrayFactory::create( + 'c', {2, 2, 2}, {-1., 2., 1.5, -1.4, 1., 2., 2., 1.}); + auto targets = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); - sd::ops::log_poisson_loss op; - auto results = op.evaluate({&input, &weights, &targets}, {}, {0}); - auto output = results.at(0); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2}, + {1.3678794, 5.389056, 2.981689, 1.6465969, 1.7182817, 5.389056, 5.389056, + 1.7182817}); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::log_poisson_loss op; + auto results = op.evaluate({&input, &weights, &targets}, {}, {0}); + auto output = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, LogPoissonLoss_2) { + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); - auto weights = NDArrayFactory::create('c', {1, 1}, {1}); + auto input = NDArrayFactory::create( + 'c', {2, 2, 2}, {-1., 2., 1.5, -1.4, 1., 2., 2., 1.}); + auto targets = NDArrayFactory::create( + 'c', {2, 2, 2}, {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0}); - auto input = NDArrayFactory::create('c', {2, 2, 2}, { -1., 2. , 1.5, -1.4, 1., 2., 2., 1.}); - auto targets = NDArrayFactory::create('c', {2, 2, 2}, {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2}, + {3.0196857, 4.0408626, 2.1334953, 3.6984034, 1.3700882, 4.0408626, + 4.0408626, 1.3700882}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {3.0196857, 4.0408626, 2.1334953, 3.6984034, 1.3700882, 4.0408626, 4.0408626, 1.3700882}); + sd::ops::log_poisson_loss op; + auto results = op.evaluate({&input, &weights, &targets}, {}, {0, 1}); + auto output = results.at(0); - sd::ops::log_poisson_loss op; - auto results = op.evaluate({&input, &weights, &targets}, {}, {0, 1}); - auto output = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, NormalizeMoments_1) { + auto means = NDArrayFactory::create( + 'c', {2, 3, 4}, {11., 3., 14., 5., 6., 9., 3.5, 7., 21., 3., 14., 15., + 6., 9., 3.5, 7., 11., 13., 14., 5., 16., 9., 13.5, 7.}); - auto means = NDArrayFactory::create('c', {2, 3, 4}, { 11., 3., 14., 5., - 6., 9., 3.5, 7., - 21., 3., 14., 15., - 6., 9., 3.5, 7., - 11., 13., 14., 5., - 16., 9., 13.5, 7.}); - - auto deviance = NDArrayFactory::create('c', {2, 3, 4}, { 21., 13., 24., 15., - 16., 19., 13.5, 17., - 31., 13., 24., 25., - 16., 19., 13.5, 17., - 21., 23., 24., 15., - 26., 19., 23.5, 17.}); - - auto counts = NDArrayFactory::create(2.0); - - auto expMeans = NDArrayFactory::create('c', {2, 3, 4}, { - 5.5, 1.5, 7., 2.5, - 3., 4.5, 1.75, 3.5, - 10.5, 1.5, 7., 7.5, - 3., 4.5, 1.75, 3.5, - 5.5, 6.5, 7., 2.5, - 8., 4.5, 6.75, 3.5}); - - auto expDeviance = NDArrayFactory::create('c', {2, 3, 4}, { - -19.75, 4.25, -37., 1.25, - -1., -10.75, 3.6875, -3.75, - -94.75, 4.25, -37., -43.75, - -1., -10.75, 3.6875, -3.75, - -19.75, -30.75, -37., 1.25, - -51., -10.75, -33.8125, -3.75}); - - sd::ops::normalize_moments op; - auto results = op.evaluate({&counts, &means, &deviance}, {0.0}, {}); + auto deviance = NDArrayFactory::create( + 'c', {2, 3, 4}, + {21., 13., 24., 15., 16., 19., 13.5, 17., 31., 13., 24., 25., + 16., 19., 13.5, 17., 21., 23., 24., 15., 26., 19., 23.5, 17.}); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_EQ(results.size(), 2); + auto counts = NDArrayFactory::create(2.0); - auto outputMeans = results.at(0); - auto outputDeviance = results.at(1); + auto expMeans = NDArrayFactory::create( + 'c', {2, 3, 4}, + {5.5, 1.5, 7., 2.5, 3., 4.5, 1.75, 3.5, 10.5, 1.5, 7., 7.5, + 3., 4.5, 1.75, 3.5, 5.5, 6.5, 7., 2.5, 8., 4.5, 6.75, 3.5}); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); - ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); + auto expDeviance = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-19.75, 4.25, -37., 1.25, -1., -10.75, 3.6875, -3.75, + -94.75, 4.25, -37., -43.75, -1., -10.75, 3.6875, -3.75, + -19.75, -30.75, -37., 1.25, -51., -10.75, -33.8125, -3.75}); - -} + sd::ops::normalize_moments op; + auto results = op.evaluate({&counts, &means, &deviance}, {0.0}, {}); -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests5, NormalizeMoments_2) { + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(results.size(), 2); - auto means = NDArrayFactory::create('c', {3, 2, 4}, { 11., 3., 14., 5., - 6., 9., 3.5, 7., - 21., 3., 14., 15., - 6., 9., 3.5, 7., - 11., 13., 14., 5., - 16., 9., 13.5, 7.}); - - auto deviance = NDArrayFactory::create('c', {3, 2, 4}, { 21., 13., 24., 15., - 16., 19., 13.5, 17., - 31., 13., 24., 25., - 16., 19., 13.5, 17., - 21., 23., 24., 15., - 26., 19., 23.5, 17.}); - - auto counts = NDArrayFactory::create(12.0); - - auto expMeans = NDArrayFactory::create('c', {3, 2, 4}, { 0.9166667, 0.25, 1.1666667, 0.4166667, - 0.5, 0.75, 0.2916667, 0.5833334, - 1.75, 0.25, 1.1666667, 1.25, - 0.5, 0.75, 0.2916667, 0.5833334, - 0.9166667, 1.0833334, 1.1666667, 0.4166667, - 1.3333334, 0.75, 1.125, 0.5833334}); - - auto expDeviance = NDArrayFactory::create('c', {3, 2, 4}, { - 0.9097222, 1.0208334, 0.6388887, 1.0763888, - 1.0833334, 1.0208334, 1.0399306, 1.076389, - -0.4791665, 1.0208334, 0.6388887, 0.5208335, - 1.0833334, 1.0208334, 1.0399306, 1.076389, - 0.9097222, 0.7430556, 0.6388887, 1.0763888, - 0.38888884, 1.0208334, 0.6927084, 1.076389}); - - sd::ops::normalize_moments op; - auto results = op.evaluate({&counts, &means, &deviance}, {0.0}, {}); + auto outputMeans = results.at(0); + auto outputDeviance = results.at(1); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_EQ(results.size(), 2); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); + ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); +} - auto outputMeans = results.at(0); - auto outputDeviance = results.at(1); +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, NormalizeMoments_2) { + auto means = NDArrayFactory::create( + 'c', {3, 2, 4}, {11., 3., 14., 5., 6., 9., 3.5, 7., 21., 3., 14., 15., + 6., 9., 3.5, 7., 11., 13., 14., 5., 16., 9., 13.5, 7.}); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); - ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); + auto deviance = NDArrayFactory::create( + 'c', {3, 2, 4}, + {21., 13., 24., 15., 16., 19., 13.5, 17., 31., 13., 24., 25., + 16., 19., 13.5, 17., 21., 23., 24., 15., 26., 19., 23.5, 17.}); - -} + auto counts = NDArrayFactory::create(12.0); -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests5, NormalizeMoments_3) { + auto expMeans = NDArrayFactory::create( + 'c', {3, 2, 4}, + {0.9166667, 0.25, 1.1666667, 0.4166667, 0.5, 0.75, + 0.2916667, 0.5833334, 1.75, 0.25, 1.1666667, 1.25, + 0.5, 0.75, 0.2916667, 0.5833334, 0.9166667, 1.0833334, + 1.1666667, 0.4166667, 1.3333334, 0.75, 1.125, 0.5833334}); - auto means = NDArrayFactory::create('c', {3, 2, 4}, { 11., 3., 14., 5., - 6., 9., 3.5, 7., - 21., 3., 14., 15., - 6., 9., 3.5, 7., - 11., 13., 14., 5., - 16., 9., 13.5, 7.}); - - auto deviance = NDArrayFactory::create('c', {3, 2, 4}, { 21., 13., 24., 15., - 16., 19., 13.5, 17., - 31., 13., 24., 25., - 16., 19., 13.5, 17., - 21., 23., 24., 15., - 26., 19., 23.5, 17.}); - - auto counts = NDArrayFactory::create(12.0); - double shift = 10.0; - auto expMeans = NDArrayFactory::create('c', {3, 2, 4}, { 10.9166667, 10.25, 11.1666667, 10.4166667, - 10.5, 10.75, 10.2916667, 10.5833334, - 11.75, 10.25, 11.1666667, 11.25, - 10.5, 10.75, 10.2916667, 10.5833334, - 10.9166667, 11.0833334, 11.1666667, 10.4166667, - 11.3333334, 10.75, 11.125, 10.5833334}); - - auto expDeviance = NDArrayFactory::create('c', {3, 2, 4}, { - 0.9097222, 1.0208334, 0.6388887, 1.0763888, - 1.0833334, 1.0208334, 1.0399306, 1.076389, - -0.4791665, 1.0208334, 0.6388887, 0.5208335, - 1.0833334, 1.0208334, 1.0399306, 1.076389, - 0.9097222, 0.7430556, 0.6388887, 1.0763888, - 0.38888884, 1.0208334, 0.6927084, 1.076389}); - - sd::ops::normalize_moments op; - auto results = op.evaluate({&counts, &means, &deviance}, {shift}, {}); + auto expDeviance = NDArrayFactory::create( + 'c', {3, 2, 4}, + {0.9097222, 1.0208334, 0.6388887, 1.0763888, 1.0833334, 1.0208334, + 1.0399306, 1.076389, -0.4791665, 1.0208334, 0.6388887, 0.5208335, + 1.0833334, 1.0208334, 1.0399306, 1.076389, 0.9097222, 0.7430556, + 0.6388887, 1.0763888, 0.38888884, 1.0208334, 0.6927084, 1.076389}); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_EQ(results.size(), 2); + sd::ops::normalize_moments op; + auto results = op.evaluate({&counts, &means, &deviance}, {0.0}, {}); - auto outputMeans = results.at(0); - auto outputDeviance = results.at(1); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(results.size(), 2); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); - ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); + auto outputMeans = results.at(0); + auto outputDeviance = results.at(1); - + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); + ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, NormalizeMoments_3) { + auto means = NDArrayFactory::create( + 'c', {3, 2, 4}, {11., 3., 14., 5., 6., 9., 3.5, 7., 21., 3., 14., 15., + 6., 9., 3.5, 7., 11., 13., 14., 5., 16., 9., 13.5, 7.}); + + auto deviance = NDArrayFactory::create( + 'c', {3, 2, 4}, + {21., 13., 24., 15., 16., 19., 13.5, 17., 31., 13., 24., 25., + 16., 19., 13.5, 17., 21., 23., 24., 15., 26., 19., 23.5, 17.}); + + auto counts = NDArrayFactory::create(12.0); + double shift = 10.0; + auto expMeans = NDArrayFactory::create( + 'c', {3, 2, 4}, + {10.9166667, 10.25, 11.1666667, 10.4166667, 10.5, 10.75, + 10.2916667, 10.5833334, 11.75, 10.25, 11.1666667, 11.25, + 10.5, 10.75, 10.2916667, 10.5833334, 10.9166667, 11.0833334, + 11.1666667, 10.4166667, 11.3333334, 10.75, 11.125, 10.5833334}); + + auto expDeviance = NDArrayFactory::create( + 'c', {3, 2, 4}, + {0.9097222, 1.0208334, 0.6388887, 1.0763888, 1.0833334, 1.0208334, + 1.0399306, 1.076389, -0.4791665, 1.0208334, 0.6388887, 0.5208335, + 1.0833334, 1.0208334, 1.0399306, 1.076389, 0.9097222, 0.7430556, + 0.6388887, 1.0763888, 0.38888884, 1.0208334, 0.6927084, 1.076389}); + + sd::ops::normalize_moments op; + auto results = op.evaluate({&counts, &means, &deviance}, {shift}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(results.size(), 2); + + auto outputMeans = results.at(0); + auto outputDeviance = results.at(1); + + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); + ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index a8451e1c4026..49b3860f0f2e 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -18,1646 +18,1569 @@ // Created by raver119 on 09.02.18. // - -#include "testlayers.h" -#include -#include #include #include +#include +#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class DeclarableOpsTests6 : public testing::Test { -public: - - DeclarableOpsTests6() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests6() { + printf("\n"); + fflush(stdout); + } }; - TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_1) { - auto matrix = NDArrayFactory::create('c', {5, 2}); - auto b = NDArrayFactory::create('c', {1}, {0.}); - auto e = NDArrayFactory::create('c', {1}, {1}); - auto s = NDArrayFactory::create('c', {1}, {1}); - - auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); + auto matrix = NDArrayFactory::create('c', {5, 2}); + auto b = NDArrayFactory::create('c', {1}, {0.}); + auto e = NDArrayFactory::create('c', {1}, {1}); + auto s = NDArrayFactory::create('c', {1}, {1}); - matrix.linspace(1); + auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + matrix.linspace(1); - auto z = result.at(0); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) { - auto matrix = NDArrayFactory::create('c', {5, 2}); - auto b = NDArrayFactory::create('c', {1}, {0.0f}); - auto e = NDArrayFactory::create('c', {1}, {1.0f}); - auto s = NDArrayFactory::create('c', {1}, {1.0f}); + auto matrix = NDArrayFactory::create('c', {5, 2}); + auto b = NDArrayFactory::create('c', {1}, {0.0f}); + auto e = NDArrayFactory::create('c', {1}, {1.0f}); + auto s = NDArrayFactory::create('c', {1}, {1.0f}); - auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); + auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); - matrix.linspace(1); + matrix.linspace(1); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_EQ(exp, z); + auto z = result.at(0); - + ASSERT_EQ(exp, z); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) { - auto matrix = NDArrayFactory::create(10); - auto b = NDArrayFactory::create(0); - auto e = NDArrayFactory::create(0); - auto s = NDArrayFactory::create(1.0); + auto matrix = NDArrayFactory::create(10); + auto b = NDArrayFactory::create(0); + auto e = NDArrayFactory::create(0); + auto s = NDArrayFactory::create(1.0); - //auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); + // auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); - //matrix.linspace(1); + // matrix.linspace(1); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - //z->printShapeInfo("SS OS shape"); - ASSERT_TRUE(z.isEmpty()); - //ASSERT_EQ(exp, *z); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); - + auto z = result.at(0); + // z->printShapeInfo("SS OS shape"); + ASSERT_TRUE(z.isEmpty()); + // ASSERT_EQ(exp, *z); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) { - auto matrix = NDArrayFactory::create('c', {1}, {10}); - auto b = NDArrayFactory::create('c', {1}, {0.}); - auto e = NDArrayFactory::create('c', {1}, {0.}); - auto s = NDArrayFactory::create('c', {1}, {1.0}); + auto matrix = NDArrayFactory::create('c', {1}, {10}); + auto b = NDArrayFactory::create('c', {1}, {0.}); + auto e = NDArrayFactory::create('c', {1}, {0.}); + auto s = NDArrayFactory::create('c', {1}, {1.0}); - auto exp = NDArrayFactory::create(10); + auto exp = NDArrayFactory::create(10); - //matrix.linspace(1); + // matrix.linspace(1); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(z.equalsTo(exp)); - //ASSERT_EQ(exp, *z); + auto z = result.at(0); - + ASSERT_TRUE(z.equalsTo(exp)); + // ASSERT_EQ(exp, *z); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) { - int z = 0; - auto matrix = NDArrayFactory::create('c', {1}, {10}); - auto b = NDArrayFactory::create('c', {1}, {1}); - auto e = NDArrayFactory::create('c', {1}, {z}); - auto s = NDArrayFactory::create('c', {1}, {1}); - sd::ops::ones_as opOnes; - //auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); - auto onesRes = opOnes.evaluate({&matrix}); - //matrix.linspace(1); - ASSERT_EQ(onesRes.status(), Status::OK()); - - auto ones = onesRes.at(0); - ones *= 10; - auto onesD = ones.dup(); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, onesD); - variableSpace->putVariable(-2, b); - variableSpace->putVariable(-3, e); - variableSpace->putVariable(-4, s); - auto block = new Context(1, variableSpace, false); // not-in-place - block->fillInputs({-1}); - block->fillInputs({-2}); - block->fillInputs({-3}); - block->fillInputs({-4}); - - block->appendI(0); - block->appendI(0); - block->appendI(1); - block->appendI(0); - block->appendI(0); - - auto inputShapes = new ShapeList({ones.shapeInfo(), b.shapeInfo(), e.shapeInfo(), s.shapeInfo()}); - - sd::ops::strided_slice op; - - auto result = op.calculateOutputShape(inputShapes, *block); //execute({ones, &b, &e, &s}, {}, {0, 1, 0, 0, 0}); - ASSERT_EQ(result->size(), 1); - ASSERT_TRUE(shape::isEmpty(result->at(0))); - //ASSERT_EQ(exp, *z); - delete block; - delete result; - delete variableSpace; - delete inputShapes; + int z = 0; + auto matrix = NDArrayFactory::create('c', {1}, {10}); + auto b = NDArrayFactory::create('c', {1}, {1}); + auto e = NDArrayFactory::create('c', {1}, {z}); + auto s = NDArrayFactory::create('c', {1}, {1}); + sd::ops::ones_as opOnes; + // auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); + auto onesRes = opOnes.evaluate({&matrix}); + // matrix.linspace(1); + ASSERT_EQ(onesRes.status(), Status::OK()); + + auto ones = onesRes.at(0); + ones *= 10; + auto onesD = ones.dup(); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, onesD); + variableSpace->putVariable(-2, b); + variableSpace->putVariable(-3, e); + variableSpace->putVariable(-4, s); + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); + block->fillInputs({-2}); + block->fillInputs({-3}); + block->fillInputs({-4}); + + block->appendI(0); + block->appendI(0); + block->appendI(1); + block->appendI(0); + block->appendI(0); + + auto inputShapes = new ShapeList( + {ones.shapeInfo(), b.shapeInfo(), e.shapeInfo(), s.shapeInfo()}); + + sd::ops::strided_slice op; + + auto result = op.calculateOutputShape( + inputShapes, + *block); // execute({ones, &b, &e, &s}, {}, {0, 1, 0, 0, 0}); + ASSERT_EQ(result->size(), 1); + ASSERT_TRUE(shape::isEmpty(result->at(0))); + // ASSERT_EQ(exp, *z); + delete block; + delete result; + delete variableSpace; + delete inputShapes; } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_5) { - auto matrix = NDArrayFactory::create('c', {3, 2, 2}); - auto b = NDArrayFactory::create('c', {1}, {2}); - auto e = NDArrayFactory::create('c', {1}, {3}); - auto s = NDArrayFactory::create('c', {1}, {1}); + auto matrix = NDArrayFactory::create('c', {3, 2, 2}); + auto b = NDArrayFactory::create('c', {1}, {2}); + auto e = NDArrayFactory::create('c', {1}, {3}); + auto s = NDArrayFactory::create('c', {1}, {1}); - auto exp = NDArrayFactory::create('c', {2,2}, {0.0f, 0.0f, 0., 0.}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0.0f, 0.0f, 0., 0.}); - //matrix.linspace(1); + // matrix.linspace(1); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); - + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) { - auto matrix = NDArrayFactory::create('c', {3, 2, 2}); - auto b = NDArrayFactory::create('c', {1}, {2}); - auto e = NDArrayFactory::create('c', {1}, {3}); - auto s = NDArrayFactory::create('c', {1}, {1}); - - auto exp = NDArrayFactory::create('c', {1,2,2}, {0.0f, 0.0f, 0., 0.}); + auto matrix = NDArrayFactory::create('c', {3, 2, 2}); + auto b = NDArrayFactory::create('c', {1}, {2}); + auto e = NDArrayFactory::create('c', {1}, {3}); + auto s = NDArrayFactory::create('c', {1}, {1}); - //matrix.linspace(1); + auto exp = + NDArrayFactory::create('c', {1, 2, 2}, {0.0f, 0.0f, 0., 0.}); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 2}); - ASSERT_EQ(Status::OK(), result.status()); + // matrix.linspace(1); - auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 2}); + ASSERT_EQ(Status::OK(), result.status()); - + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) { - int zero = 0; - auto matrix = NDArrayFactory::create('c', {5, 4}); - auto b = NDArrayFactory::create('c', {1}, {zero}); - auto e = NDArrayFactory::create('c', {1}, {zero}); - auto s = NDArrayFactory::create('c', {1}, {1}); + int zero = 0; + auto matrix = NDArrayFactory::create('c', {5, 4}); + auto b = NDArrayFactory::create('c', {1}, {zero}); + auto e = NDArrayFactory::create('c', {1}, {zero}); + auto s = NDArrayFactory::create('c', {1}, {1}); - //auto exp = NDArrayFactory::create('c', {1,2,2}, {0.0f, 0.0f, 0., 0.}); + // auto exp = NDArrayFactory::create('c', {1,2,2}, {0.0f, 0.0f, 0., + // 0.}); - //matrix.linspace(1); + // matrix.linspace(1); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {1, 0, 0, 0, 0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {1, 0, 0, 0, 0}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - //ASSERT_TRUE(exp.equalsTo(z)); - - + auto z = result.at(0); + // ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) { - int zero = 0; - auto matrix = NDArrayFactory::create('c', {5, 4}); -// auto b = NDArrayFactory::create('c', {1}, {zero}); -// auto e = NDArrayFactory::create('c', {1}, {zero}); -// auto s = NDArrayFactory::create('c', {1}, {1}); + int zero = 0; + auto matrix = NDArrayFactory::create('c', {5, 4}); + // auto b = NDArrayFactory::create('c', {1}, {zero}); + // auto e = NDArrayFactory::create('c', {1}, {zero}); + // auto s = NDArrayFactory::create('c', {1}, {1}); - auto grad = NDArrayFactory::create('c', {5}); + auto grad = NDArrayFactory::create('c', {5}); - matrix.linspace(1); - grad.linspace(1); + matrix.linspace(1); + grad.linspace(1); - sd::ops::strided_slice_bp op; - auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - //ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::strided_slice_bp op; + auto result = + op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); - + auto z = result.at(0); + // ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) { - int zero = 0; - auto matrix = NDArrayFactory::create('c', {1, 2}); -// auto b = NDArrayFactory::create('c', {1}, {zero}); -// auto e = NDArrayFactory::create('c', {1}, {zero}); -// auto s = NDArrayFactory::create('c', {1}, {1}); + int zero = 0; + auto matrix = NDArrayFactory::create('c', {1, 2}); + // auto b = NDArrayFactory::create('c', {1}, {zero}); + // auto e = NDArrayFactory::create('c', {1}, {zero}); + // auto s = NDArrayFactory::create('c', {1}, {1}); - auto grad = NDArrayFactory::create('c', {1}, {1.}); + auto grad = NDArrayFactory::create('c', {1}, {1.}); - matrix.linspace(1); - //grad.linspace(1); + matrix.linspace(1); + // grad.linspace(1); - sd::ops::strided_slice_bp op; - auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::strided_slice_bp op; + auto result = + op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - //ASSERT_TRUE(exp.equalsTo(z)); - - + auto z = result.at(0); + // ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) { - int zero = 0; - auto matrix = NDArrayFactory::create('c', {4, 8192}); -// auto b = NDArrayFactory::create('c', {1}, {zero}); -// auto e = NDArrayFactory::create('c', {1}, {zero}); -// auto s = NDArrayFactory::create('c', {1}, {1}); + int zero = 0; + auto matrix = NDArrayFactory::create('c', {4, 8192}); + // auto b = NDArrayFactory::create('c', {1}, {zero}); + // auto e = NDArrayFactory::create('c', {1}, {zero}); + // auto s = NDArrayFactory::create('c', {1}, {1}); - auto grad = NDArrayFactory::create('c', {4, 256}); + auto grad = NDArrayFactory::create('c', {4, 256}); - matrix.linspace(1); - grad.linspace(1); + matrix.linspace(1); + grad.linspace(1); - sd::ops::strided_slice_bp op; - auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - //ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::strided_slice_bp op; + auto result = + op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); - + auto z = result.at(0); + // ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) { - auto x = NDArrayFactory::create('c', {1, 1}, {2.0f}); - auto exp = NDArrayFactory::create('c', {1, 1}, {4.0f}); + auto x = NDArrayFactory::create('c', {1, 1}, {2.0f}); + auto exp = NDArrayFactory::create('c', {1, 1}, {4.0f}); - sd::ops::test_scalar op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::test_scalar op; + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_Order_1) { - auto x = NDArrayFactory::create('f', {2, 3}); - auto exp = NDArrayFactory::create('c', {2, 3}); - x.linspace(1); - exp.linspace(1); - - sd::ops::order op; - auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('f', {2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3}); + x.linspace(1); + exp.linspace(1); - auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); - ASSERT_NE(x.ordering(), z.ordering()); + sd::ops::order op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); - + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_NE(x.ordering(), z.ordering()); } TEST_F(DeclarableOpsTests6, cumSum_1) { - auto x = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {1, 4}, {1.f, 3.f, 6.f, 10.f}); + auto x = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {1, 4}, {1.f, 3.f, 6.f, 10.f}); - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, cumSum_2) { - auto x= NDArrayFactory::create('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); - auto exp= NDArrayFactory::create('c', {2, 4}, {1.f, 3.f, 6.f, 10.f, 1.f, 3.f, 6.f, 10.f}); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create( + 'c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 4}, {1.f, 3.f, 6.f, 10.f, 1.f, 3.f, 6.f, 10.f}); - auto z = result.at(0); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - // z->printIndexedBuffer("CumSum1"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + // z->printIndexedBuffer("CumSum1"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, cumSum_3) { - auto x= NDArrayFactory::create('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); - auto exp= NDArrayFactory::create('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f}); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 0, 0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create( + 'c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f}); - auto z = result.at(0); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, cumSum_4) { - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3, 3}, {12., 15., 18., 11., 13., 15., 7., 8., 9.}); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 1, 0}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {12., 15., 18., 11., 13., 15., 7., 8., 9.}); - auto z = result.at(0); - // z->printBuffer(); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 1, 0}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printBuffer(); - + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, cumSum_5) { - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3, 3}, {6.f, 5.f, 3.f, 15.f, 11.f, 6.f, 24.f, 17.f, 9.f,}); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 1, 1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); - - + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create('c', {3, 3}, + { + 6.f, + 5.f, + 3.f, + 15.f, + 11.f, + 6.f, + 24.f, + 17.f, + 9.f, + }); + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 1, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, cumSum_6) { - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3, 3}, {11.f, 13.f, 15.f, 7.f, 8.f, 9.f, 0.f, 0.f, 0.f}); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {11.f, 13.f, 15.f, 7.f, 8.f, 9.f, 0.f, 0.f, 0.f}); - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {1, 1, 0}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 1, 0}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, cumSum_7) { - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f}); - - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {1, 1, 1}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f}); - auto z = result.at(0); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 1, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, cumSum_8) { - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto axis = NDArrayFactory::create('c', {1}, {1}); - auto exp = NDArrayFactory::create('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f}); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto axis = NDArrayFactory::create('c', {1}, {1}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f}); - sd::ops::cumsum op; - auto result = op.evaluate({&x, &axis}, {}, {1, 1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); + sd::ops::cumsum op; + auto result = op.evaluate({&x, &axis}, {}, {1, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_9) { - - auto inputC = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto axis = NDArrayFactory::create(1); - - auto expFF = NDArrayFactory::create('c', {3, 5}, {1., 3., 6., 10., 15., 6., 13., 21., 30., 40., 11., 23., 36., 50., 65.}); - auto expTF = NDArrayFactory::create('c', {3, 5}, {0., 1., 3., 6., 10., 0., 6., 13., 21., 30., 0., 11., 23., 36., 50.}); - - auto expFT = NDArrayFactory::create('c', {3, 5}, {15, 14, 12, 9, 5,40, 34, 27, 19, 10,65, 54, 42, 29, 15}); //+++ - auto expTT = NDArrayFactory::create('c', {3, 5}, {14, 12, 9, 5, 0,34, 27, 19, 10, 0,54, 42, 29, 15, 0}); - - int exclusive, reverse; - - //************************************// - exclusive = 0; reverse = 0; - - sd::ops::cumsum op; - auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - ASSERT_TRUE(expFF.equalsTo(z)); - - - //************************************// - exclusive = 1; reverse = 0; - - result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result.status()); - z = result.at(0); - ASSERT_TRUE(expTF.equalsTo(z)); - - - //************************************// - exclusive = 0; reverse = 1; - - result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result.status()); - z = result.at(0); - ASSERT_TRUE(expFT.equalsTo(z)); - - - //************************************// - exclusive = 1; reverse = 1; - - result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result.status()); - z = result.at(0); - ASSERT_TRUE(expTT.equalsTo(z)); - - + auto inputC = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto axis = NDArrayFactory::create(1); + + auto expFF = NDArrayFactory::create( + 'c', {3, 5}, + {1., 3., 6., 10., 15., 6., 13., 21., 30., 40., 11., 23., 36., 50., 65.}); + auto expTF = NDArrayFactory::create( + 'c', {3, 5}, + {0., 1., 3., 6., 10., 0., 6., 13., 21., 30., 0., 11., 23., 36., 50.}); + + auto expFT = NDArrayFactory::create( + 'c', {3, 5}, + {15, 14, 12, 9, 5, 40, 34, 27, 19, 10, 65, 54, 42, 29, 15}); //+++ + auto expTT = NDArrayFactory::create( + 'c', {3, 5}, {14, 12, 9, 5, 0, 34, 27, 19, 10, 0, 54, 42, 29, 15, 0}); + + int exclusive, reverse; + + //************************************// + exclusive = 0; + reverse = 0; + + sd::ops::cumsum op; + auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(expFF.equalsTo(z)); + + //************************************// + exclusive = 1; + reverse = 0; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expTF.equalsTo(z)); + + //************************************// + exclusive = 0; + reverse = 1; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expFT.equalsTo(z)); + + //************************************// + exclusive = 1; + reverse = 1; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expTT.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_10) { - auto x = NDArrayFactory::create('c', {4, 16, 16, 1}); - auto y = NDArrayFactory::create(-3); - - sd::ops::cumsum op; - auto result = op.evaluate({&x, &y}, {}, {1, 1}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {4, 16, 16, 1}); + auto y = NDArrayFactory::create(-3); - + sd::ops::cumsum op; + auto result = op.evaluate({&x, &y}, {}, {1, 1}); + ASSERT_EQ(Status::OK(), result.status()); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_11) { + NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); + auto exp = NDArrayFactory::create( + 'c', {3, 3, 3}, + {12., 15., 18., 11., 13., 15., 7., 8., 9., 39., 42., 45., 29., 31., + 33., 16., 17., 18., 66., 69., 72., 47., 49., 51., 25., 26., 27.}); - NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); - auto exp = NDArrayFactory::create('c', {3,3,3}, {12., 15., 18.,11., 13., 15.,7., 8., 9., 39., 42., 45.,29., 31., 33.,16., 17., 18., 66., 69., 72.,47., 49., 51.,25., 26., 27.}); + x.linspace(1); - x.linspace(1); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 1, 1}); - ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_12) { + NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); + auto exp = NDArrayFactory::create( + 'c', {3, 3, 3}, + {1., 2., 3., 5., 7., 9., 12., 15., 18., 10., 11., 12., 23., 25., + 27., 39., 42., 45., 19., 20., 21., 41., 43., 45., 66., 69., 72.}); - NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); - auto exp = NDArrayFactory::create('c', {3,3,3}, {1., 2., 3.,5., 7., 9.,12., 15., 18., 10., 11., 12.,23., 25., 27.,39., 42., 45., 19., 20., 21.,41., 43., 45., 66., 69., 72.}); + x.linspace(1); - x.linspace(1); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_13) { + NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); + auto exp = NDArrayFactory::create( + 'c', {3, 3, 3}, + {11., 13., 15., 7., 8., 9., 0., 0., 0., 29., 31., 33., 16., 17., + 18., 0., 0., 0., 47., 49., 51., 25., 26., 27., 0., 0., 0.}); - NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); - auto exp = NDArrayFactory::create('c', {3,3,3}, {11., 13., 15.,7., 8., 9.,0., 0., 0., 29., 31., 33.,16., 17., 18.,0., 0., 0., 47., 49., 51.,25., 26., 27.,0., 0., 0.}); + x.linspace(1); - x.linspace(1); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {1, 1, 1}); - ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_14) { + NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); + auto exp = NDArrayFactory::create( + 'c', {3, 3, 3}, + {29., 31., 33., 35., 37., 39., 41., 43., 45., 19., 20., 21., 22., 23., + 24., 25., 26., 27., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); - NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); - auto exp = NDArrayFactory::create('c', {3,3,3}, {29., 31., 33.,35., 37., 39.,41., 43., 45., 19., 20., 21.,22., 23., 24.,25., 26., 27., 0., 0., 0.,0., 0., 0.,0., 0., 0.}); + x.linspace(1); - x.linspace(1); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 1, 0}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {1, 1, 0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_15) { + NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); + auto exp = NDArrayFactory::create( + 'c', {3, 3, 3}, + {6., 5., 3., 15., 11., 6., 24., 17., 9., 33., 23., 12., 42., 29., + 15., 51., 35., 18., 60., 41., 21., 69., 47., 24., 78., 53., 27.}); - NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); - auto exp = NDArrayFactory::create('c', {3,3,3}, {6., 5., 3.,15., 11., 6.,24., 17., 9., 33., 23., 12.,42., 29., 15.,51., 35., 18., 60., 41., 21.,69., 47., 24.,78., 53., 27.}); + x.linspace(1); - x.linspace(1); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 1, 2}); - ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_16) { + NDArray x('f', {3, 4}, sd::DataType::FLOAT32); - NDArray x('f', {3, 4}, sd::DataType::FLOAT32); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printShapeInfo(); + // x.printShapeInfo(); - auto z = result.at(0); - // z->printShapeInfo(); - // x.printShapeInfo(); - - ASSERT_TRUE(z.ews() == 1); - ASSERT_TRUE(x.ews() == 1); - - + ASSERT_TRUE(z.ews() == 1); + ASSERT_TRUE(x.ews() == 1); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_17) { + NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray x0 = x(0, {0}); + NDArray x1 = x(1, {0}); + x0.linspace(1); + x1.linspace(1); - NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray x0 = x(0, {0}); - NDArray x1 = x(1, {0}); - x0.linspace(1); - x1.linspace(1); - - NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray exp0 = exp(0, {0}); - NDArray exp1 = exp(1, {0}); + NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray exp0 = exp(0, {0}); + NDArray exp1 = exp(1, {0}); - exp0.p(0, 1.); - exp1.p(0, 1.); + exp0.p(0, 1.); + exp1.p(0, 1.); - for (int i = 1; i < 1500; ++i) { - const auto prev = exp0.e(i-1); - exp0.p(i, prev + i + 1); - exp1.p(i, prev + i + 1); - } + for (int i = 1; i < 1500; ++i) { + const auto prev = exp0.e(i - 1); + exp0.p(i, prev + i + 1); + exp1.p(i, prev + i + 1); + } - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_18) { + NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray x0 = x(0, {0}); + NDArray x1 = x(1, {0}); + x0.linspace(1); + x1.linspace(1); - NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray x0 = x(0, {0}); - NDArray x1 = x(1, {0}); - x0.linspace(1); - x1.linspace(1); - - NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray exp0 = exp(0, {0}); - NDArray exp1 = exp(1, {0}); + NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray exp0 = exp(0, {0}); + NDArray exp1 = exp(1, {0}); - exp0.p(0, 0.); - exp1.p(0, 0.); + exp0.p(0, 0.); + exp1.p(0, 0.); - for (int i = 1; i < 1500; ++i) { - const auto prev = exp0.e(i-1); - exp0.p(i, prev + i); - exp1.p(i, prev + i); - } + for (int i = 1; i < 1500; ++i) { + const auto prev = exp0.e(i - 1); + exp0.p(i, prev + i); + exp1.p(i, prev + i); + } - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {1, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_19) { + NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray x0 = x(0, {0}); + NDArray x1 = x(1, {0}); + x0.linspace(1); + x1.linspace(1); - NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray x0 = x(0, {0}); - NDArray x1 = x(1, {0}); - x0.linspace(1); - x1.linspace(1); - - NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray exp0 = exp(0, {0}); - NDArray exp1 = exp(1, {0}); + NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray exp0 = exp(0, {0}); + NDArray exp1 = exp(1, {0}); - exp0.p(1499, 1500.f); - exp1.p(1499, 1500.f); + exp0.p(1499, 1500.f); + exp1.p(1499, 1500.f); - for (int i = 1498; i >= 0; --i) { - const auto prev = exp0.e(i + 1); - exp0.p(i, prev + i + 1); - exp1.p(i, prev + i + 1); - } + for (int i = 1498; i >= 0; --i) { + const auto prev = exp0.e(i + 1); + exp0.p(i, prev + i + 1); + exp1.p(i, prev + i + 1); + } - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {0, 1, 1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - // exp0.printBuffer(); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // exp0.printBuffer(); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, cumSum_20) { + NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray x0 = x(0, {0}); + NDArray x1 = x(1, {0}); + x0.linspace(1); + x1.linspace(1); - NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray x0 = x(0, {0}); - NDArray x1 = x(1, {0}); - x0.linspace(1); - x1.linspace(1); - - NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray exp0 = exp(0, {0}); - NDArray exp1 = exp(1, {0}); + NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray exp0 = exp(0, {0}); + NDArray exp1 = exp(1, {0}); - exp0.p(1499, 0.); - exp1.p(1499, 0.); + exp0.p(1499, 0.); + exp1.p(1499, 0.); - for (int i = 1498; i >= 0; --i) { - const auto prev = exp0.e(i + 1); - exp0.p(i, prev + i + 2); - exp1.p(i, prev + i + 2); - } + for (int i = 1498; i >= 0; --i) { + const auto prev = exp0.e(i + 1); + exp0.p(i, prev + i + 2); + exp1.p(i, prev + i + 2); + } - sd::ops::cumsum op; - auto result = op.evaluate({&x}, {}, {1, 1, 1}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_1) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create( + 'c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto z = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2}); + sd::ops::mergemaxindex op; - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); - auto z = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2}); - sd::ops::mergemaxindex op; - - auto res = op.evaluate({&x, &y, &z}, {}, {}, {}); + auto res = op.evaluate({&x, &y, &z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); -// res.at(0).printIndexedBuffer("MergeMaxIndex Result is "); -// res.at(0).printShapeInfo("Shape info for MergeMaxIdex"); -// x.printIndexedBuffer("Input is"); - ASSERT_TRUE(res.at(0).equalsTo(exp)); - + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + // res.at(0).printIndexedBuffer("MergeMaxIndex Result is "); + // res.at(0).printShapeInfo("Shape info for MergeMaxIdex"); + // x.printIndexedBuffer("Input is"); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_2) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create( + 'c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto z = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, + {1, 2, 1, 2, 1, 2, 1, 2}); + sd::ops::mergemaxindex op; - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); - auto z = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2}); - sd::ops::mergemaxindex op; + auto ress = op.evaluate({&x, &y, &z}, {}, {sd::DataType::INT64}); - auto ress = op.evaluate({&x, &y, &z}, {}, {sd::DataType::INT64}); - - ASSERT_EQ(ND4J_STATUS_OK, ress.status()); -// res.at(0).printIndexedBuffer("MergeMaxIndex2 Result is "); -// res.at(0).printShapeInfo("Shape info for MergeMaxIdex2"); -// x.printIndexedBuffer("Input is"); - ASSERT_TRUE(ress.at(0).equalsTo(exp)); - + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + // res.at(0).printIndexedBuffer("MergeMaxIndex2 Result is "); + // res.at(0).printShapeInfo("Shape info for MergeMaxIdex2"); + // x.printIndexedBuffer("Input is"); + ASSERT_TRUE(ress.at(0).equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, TestDropout_1) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto shape = NDArrayFactory::create({2, 2}); + sd::ops::dropout op; - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto shape = NDArrayFactory::create({2, 2}); - sd::ops::dropout op; - - auto res = op.evaluate({&x, &shape}, {0.2f}, {113}); + auto res = op.evaluate({&x, &shape}, {0.2f}, {113}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - //res.at(0).printIndexedBuffer("Result is "); - //x.printIndexedBuffer("Input is"); - - + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + // res.at(0).printIndexedBuffer("Result is "); + // x.printIndexedBuffer("Input is"); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, TestMod_1) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create( + 'c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {1, 0, 3, 0, 5, 0, 7, 0}); + sd::ops::mod op; - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 0, 3, 0, 5, 0, 7, 0}); - sd::ops::mod op; - - auto res = op.evaluate({&x, &y}); + auto res = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); -// res.at(0).printIndexedBuffer("MOD Result is "); -// x.printIndexedBuffer("Input is"); - ASSERT_TRUE(res.at(0).equalsTo(exp)); - + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + // res.at(0).printIndexedBuffer("MOD Result is "); + // x.printIndexedBuffer("Input is"); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, TestMod_BP_1) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create( + 'c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto eps = NDArrayFactory::create( + 'c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}); + sd::ops::mod_bp op; - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); - auto eps = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}); - sd::ops::mod_bp op; + auto res = op.evaluate({&x, &y, &eps}); - auto res = op.evaluate({&x, &y, &eps}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + // res.at(0).printIndexedBuffer("MOD_BP Result is "); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); -// res.at(0).printIndexedBuffer("MOD_BP Result is "); - - // x.printIndexedBuffer("Input is"); - ASSERT_TRUE(res.at(0).equalsTo(exp)); - + // x.printIndexedBuffer("Input is"); + ASSERT_TRUE(res.at(0).equalsTo(exp)); } /////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, TestRank_1) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create( + 'c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto eps = NDArrayFactory::create( + 'c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto exp = NDArrayFactory::create(3); + sd::ops::rank op; - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); - auto eps = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); - auto exp = NDArrayFactory::create(3); - sd::ops::rank op; - - auto res = op.evaluate({&x}); + auto res = op.evaluate({&x}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(res.at(0).equalsTo(exp)); - + ASSERT_TRUE(res.at(0).equalsTo(exp)); } TEST_F(DeclarableOpsTests6, TestDropout_2) { -// auto x0 = NDArrayFactory::create('c', {10, 10}); -// auto x1 = NDArrayFactory::create('c', {10, 10}); - auto x = NDArrayFactory::create('c', {3, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}); + // auto x0 = NDArrayFactory::create('c', {10, 10}); + // auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create( + 'c', {3, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}); - sd::ops::dropout op; + sd::ops::dropout op; - auto res = op.evaluate({&x}, {0.4f}, {113}); + auto res = op.evaluate({&x}, {0.4f}, {113}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - - + ASSERT_EQ(ND4J_STATUS_OK, res.status()); } TEST_F(DeclarableOpsTests6, TestDropout_3) { -// auto x0 = NDArrayFactory::create('c', {10, 10}); -// auto x1 = NDArrayFactory::create('c', {10, 10}); - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto shape = NDArrayFactory::create({1, 2}); - - sd::ops::dropout op; + // auto x0 = NDArrayFactory::create('c', {10, 10}); + // auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto shape = NDArrayFactory::create({1, 2}); - auto res = op.evaluate({&x, &shape}, {0.4f}, {113}); + sd::ops::dropout op; - ASSERT_EQ(ND4J_STATUS_OK, res.status()); + auto res = op.evaluate({&x, &shape}, {0.4f}, {113}); - + ASSERT_EQ(ND4J_STATUS_OK, res.status()); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MaxPoolWithArgmax_1) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5, 1.5, 0., 1.3, 6.5, 8.6, 0., 0., + 0.4, 2.5, 1., 0.3, 4.5, 1.5, 1., 1.3, 1.5, 3.5, 0., + 1.3, 2.5, 2.6, 2., 3., 1.4, 4.5, 1., 0.3, 0.5}); + auto expI = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5, - 1.5, 1., 1.3, 1.5, 3.5, 0., 1.3, 2.5, 2.6, 2., 3., 1.4, 4.5, 1., 0.3, 0.5}); - auto expI = NDArrayFactory::create('c', {2, 2, 2, 4}, {0, 1, 2, 3,4, 5, 6, 7,8, 9, 10, 11,12, 13, 14, 15, - 0, 1, 2, 3,4, 5, 6, 7,8, 9, 10, 11,12, 13, 14, 15}); + sd::ops::max_pool_with_argmax op; - sd::ops::max_pool_with_argmax op; + auto res = op.evaluate({&x}, {}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); - auto res = op.evaluate({&x}, {}, {1,1,1,1,1,1,1,1,1}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(expI.isSameShape(res.at(0))); + ASSERT_TRUE(expI.isSameShape(res.at(1))); + ASSERT_TRUE(x.equalsTo(res.at(0))); + ASSERT_TRUE(expI.equalsTo(res.at(1))); + // x.printIndexedBuffer("Input is"); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(expI.isSameShape(res.at(0))); - ASSERT_TRUE(expI.isSameShape(res.at(1))); - ASSERT_TRUE(x.equalsTo(res.at(0))); - ASSERT_TRUE(expI.equalsTo(res.at(1))); - //x.printIndexedBuffer("Input is"); - - ASSERT_TRUE(expI.equalsTo(res.at(1))); - - + ASSERT_TRUE(expI.equalsTo(res.at(1))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, SufficientStatistics_1) { -// auto x0 = NDArrayFactory::create('c', {10, 10}); -// auto x1 = NDArrayFactory::create('c', {10, 10}); - auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1., - 1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5}); -// ------------------------------------ - double count = 8.0; - auto sumExp = NDArrayFactory::create({30.2, 5., 7.8, 22.8}); - auto sqrExp = NDArrayFactory::create({154.22, 7., 14.34, 103.62}); + // auto x0 = NDArrayFactory::create('c', {10, 10}); + // auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5, 1.5, 0., 1.3, 6.5, 8.6, 0., 0., + 0.4, 2.5, 1., 0.3, 4.5, 1.5, 1., 1.3, 1.5, 3.5, 0., + 1.3, 2.5, 2.6, 2., 3., 1.4, 4.5, 1., 0.3, 0.5}); + // ------------------------------------ + double count = 8.0; + auto sumExp = NDArrayFactory::create({30.2, 5., 7.8, 22.8}); + auto sqrExp = NDArrayFactory::create({154.22, 7., 14.34, 103.62}); - auto axis = NDArrayFactory::create({0, 1, 2}); + auto axis = NDArrayFactory::create({0, 1, 2}); - sd::ops::sufficient_statistics op; + sd::ops::sufficient_statistics op; - auto res = op.evaluate({&x, &axis}); + auto res = op.evaluate({&x, &axis}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_EQ(res.at(0).e(0), count); - ASSERT_TRUE(sumExp.equalsTo(res.at(1))); - ASSERT_TRUE(sqrExp.equalsTo(res.at(2))); - - + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_EQ(res.at(0).e(0), count); + ASSERT_TRUE(sumExp.equalsTo(res.at(1))); + ASSERT_TRUE(sqrExp.equalsTo(res.at(2))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, SufficientStatistics_2) { -// auto x0 = NDArrayFactory::create('c', {10, 10}); -// auto x1 = NDArrayFactory::create('c', {10, 10}); - auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5, - 1.5, 1., 1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5}); -// ------------------------------------ - double count = 4.0; - auto sumExp = NDArrayFactory::create('c', {2, 4}, { - 18.2, 3., 4.6, 8.8, - 12., 2., 3.2, 14.} - ); - - auto sqrExp = NDArrayFactory::create('c', {2, 4}, { - 113.22, 5., 10.78, 34.62, - 41., 2., 3.56, 69.} - ); + // auto x0 = NDArrayFactory::create('c', {10, 10}); + // auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5, 1.5, 0., 1.3, 6.5, 8.6, 0., 0., + 0.4, 2.5, 1., 0.3, 4.5, 1.5, 1., 1.3, 1.5, 3.5, 0., + 1.3, 2.5, 2.6, 2., 3., 1.4, 4.5, 1., 0.3, 0.5}); + // ------------------------------------ + double count = 4.0; + auto sumExp = NDArrayFactory::create( + 'c', {2, 4}, {18.2, 3., 4.6, 8.8, 12., 2., 3.2, 14.}); - auto axis = NDArrayFactory::create({0, 1}); + auto sqrExp = NDArrayFactory::create( + 'c', {2, 4}, {113.22, 5., 10.78, 34.62, 41., 2., 3.56, 69.}); - sd::ops::sufficient_statistics op; + auto axis = NDArrayFactory::create({0, 1}); - auto res = op.evaluate({&x, &axis}); + sd::ops::sufficient_statistics op; - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_EQ(res.at(0).e(0), count); - ASSERT_TRUE(sumExp.equalsTo(res.at(1))); - ASSERT_TRUE(sqrExp.equalsTo(res.at(2))); + auto res = op.evaluate({&x, &axis}); - + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_EQ(res.at(0).e(0), count); + ASSERT_TRUE(sumExp.equalsTo(res.at(1))); + ASSERT_TRUE(sqrExp.equalsTo(res.at(2))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BinCount_1) { + auto x = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 0, 1, 2, 2, 1, 2}); + // ------------------------------------ - auto x = NDArrayFactory::create('c', {2, 2, 2}, { - 1, 2, 0, 1, 2, 2, 1, 2} - ); -// ------------------------------------ + NDArray exp('c', {3}, {1, 3, 4}, sd::DataType::INT32); - NDArray exp('c', {3}, {1, 3, 4}, sd::DataType::INT32); + sd::ops::bincount op; - sd::ops::bincount op; - - auto res = op.evaluate({&x}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + auto res = op.evaluate({&x}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BinCount_2) { + auto x = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 0, 1, 2, 2, 1, 2}); - auto x = NDArrayFactory::create('c', {2, 2, 2}, { - 1, 2, 0, 1, 2, 2, 1, 2} - ); + auto weights = + NDArrayFactory::create('c', {2, 2, 2}, {2, 1, 3, 1, 5, 1, 1, 6}); - auto weights = NDArrayFactory::create('c', {2, 2, 2}, { - 2, 1, 3, 1, 5, 1, 1, 6} - ); + // ------------------------------------ -// ------------------------------------ + auto exp = NDArrayFactory::create({3., 4., 13.}); - auto exp = NDArrayFactory::create({3., 4., 13.}); + sd::ops::bincount op; - sd::ops::bincount op; + auto res = op.evaluate({&x, &weights}); - auto res = op.evaluate({&x, &weights}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BinCount_3) { + auto x = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 0, 1, 2, 2, 1, 2}); - auto x = NDArrayFactory::create('c', {2, 2, 2}, { - 1, 2, 0, 1, 2, 2, 1, 2} - ); - - auto weights = NDArrayFactory::create('c', {2, 2, 2}, { - 2, 1, 3, 1, 5, 1, 1, 6} - ); + auto weights = + NDArrayFactory::create('c', {2, 2, 2}, {2, 1, 3, 1, 5, 1, 1, 6}); -// ------------------------------------ + // ------------------------------------ - auto exp = NDArrayFactory::create({3., 4.}); + auto exp = NDArrayFactory::create({3., 4.}); - sd::ops::bincount op; + sd::ops::bincount op; - auto res = op.evaluate({&x, &weights}, {}, {0, 2}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + auto res = op.evaluate({&x, &weights}, {}, {0, 2}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BinCount_4) { + auto x = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 0, 1, 2, 2, 1, 2}); - auto x = NDArrayFactory::create('c', {2, 2, 2}, { - 1, 2, 0, 1, 2, 2, 1, 2} - ); - - auto weights = NDArrayFactory::create('c', {2, 2, 2}, { - 2, 1, 3, 1, 5, 1, 1, 6} - ); + auto weights = + NDArrayFactory::create('c', {2, 2, 2}, {2, 1, 3, 1, 5, 1, 1, 6}); -// ------------------------------------ + // ------------------------------------ - auto exp = NDArrayFactory::create({3., 4., 13., 0.0}); + auto exp = NDArrayFactory::create({3., 4., 13., 0.0}); - sd::ops::bincount op; + sd::ops::bincount op; - auto res = op.evaluate({&x, &weights}, {}, {4, 4}); + auto res = op.evaluate({&x, &weights}, {}, {4, 4}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BinCount_5) { + auto x = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 0, 1, 2, 2, 1, 2}); - auto x = NDArrayFactory::create('c', {2, 2, 2}, { - 1, 2, 0, 1, 2, 2, 1, 2} - ); - - auto weights = NDArrayFactory::create('c', {2, 2, 2}, { - 2, 1, 3, 1, 5, 1, 1, 6} - ); - auto minV = NDArrayFactory::create(4); - auto maxV = NDArrayFactory::create(4); -// ------------------------------------ - - auto exp = NDArrayFactory::create({3., 4., 13., 0.0}); + auto weights = + NDArrayFactory::create('c', {2, 2, 2}, {2, 1, 3, 1, 5, 1, 1, 6}); + auto minV = NDArrayFactory::create(4); + auto maxV = NDArrayFactory::create(4); + // ------------------------------------ - sd::ops::bincount op; + auto exp = NDArrayFactory::create({3., 4., 13., 0.0}); - auto res = op.evaluate({&x, &weights, &minV, &maxV}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - // res->at(0)->printBuffer("BC out"); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + sd::ops::bincount op; + auto res = op.evaluate({&x, &weights, &minV, &maxV}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + // res->at(0)->printBuffer("BC out"); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_1) { + auto x = NDArrayFactory::create({2, 2, 2}); - auto x = NDArrayFactory::create( {2, 2, 2} ); + auto y = NDArrayFactory::create({2, 1, 2}); - auto y = NDArrayFactory::create({ 2, 1, 2}); + auto exp = NDArrayFactory::create({2, 2, 2}); - auto exp = NDArrayFactory::create({2, 2, 2}); + sd::ops::broadcast_dynamic_shape op; - sd::ops::broadcast_dynamic_shape op; - - auto res = op.evaluate({&x, &y}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + auto res = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) { + auto x = NDArrayFactory::create({2, 2}); - auto x = NDArrayFactory::create( {2, 2} ); - - auto y = NDArrayFactory::create({2, 1, 2}); - - auto exp = NDArrayFactory::create({2, 2, 2}); + auto y = NDArrayFactory::create({2, 1, 2}); - sd::ops::broadcast_dynamic_shape op; + auto exp = NDArrayFactory::create({2, 2, 2}); - auto res = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + sd::ops::broadcast_dynamic_shape op; + auto res = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_3) { + auto x = NDArrayFactory::create({2, 2, 2}); - auto x = NDArrayFactory::create( {2, 2, 2} ); + auto y = NDArrayFactory::create({2, 1}); - auto y = NDArrayFactory::create({2, 1}); + auto exp = NDArrayFactory::create({2, 2, 2}); - auto exp = NDArrayFactory::create({2, 2, 2}); + sd::ops::broadcast_dynamic_shape op; - sd::ops::broadcast_dynamic_shape op; + auto res = op.evaluate({&x, &y}, {}, {}, {}); - auto res = op.evaluate({&x, &y}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) { + auto x = NDArrayFactory::create({2, 1}); - auto x = NDArrayFactory::create( {2, 1} ); - - auto y = NDArrayFactory::create('c', {1}, {4}); + auto y = NDArrayFactory::create('c', {1}, {4}); - auto exp = NDArrayFactory::create({2, 4}); + auto exp = NDArrayFactory::create({2, 4}); - sd::ops::broadcast_dynamic_shape op; + sd::ops::broadcast_dynamic_shape op; - auto res = op.evaluate({&x, &y}); + auto res = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - //res->at(0)->printBuffer("Shape SGO 4"); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + // res->at(0)->printBuffer("Shape SGO 4"); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) { + auto x = NDArrayFactory::create({2, 1, 4}); - auto x = NDArrayFactory::create({2, 1, 4}); - - auto y = NDArrayFactory::create({2, 2, 4}); - - auto exp = NDArrayFactory::create({2, 2, 4}); + auto y = NDArrayFactory::create({2, 2, 4}); - sd::ops::broadcast_dynamic_shape op; - auto res = op.evaluate({&x, &y}); + auto exp = NDArrayFactory::create({2, 2, 4}); - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + sd::ops::broadcast_dynamic_shape op; + auto res = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_7) { + auto x = NDArrayFactory::create({1, 1, 3}); - auto x = NDArrayFactory::create({1, 1, 3}); + auto y = NDArrayFactory::create({2, 4, 1}); - auto y = NDArrayFactory::create({2, 4, 1}); + auto exp = NDArrayFactory::create({2, 4, 3}); - auto exp = NDArrayFactory::create({2, 4, 3}); + sd::ops::broadcast_dynamic_shape op; + auto res = op.evaluate({&x, &y}); - sd::ops::broadcast_dynamic_shape op; - auto res = op.evaluate({&x, &y}); - - ASSERT_EQ(ND4J_STATUS_OK, res.status()); - ASSERT_TRUE(exp.equalsTo(res.at(0))); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_8) { + auto x = NDArrayFactory::create('c', {1}, {1}); - auto x = NDArrayFactory::create('c', {1}, {1}); - - auto y = NDArrayFactory::create('c', {1}, {4}); + auto y = NDArrayFactory::create('c', {1}, {4}); - auto z = NDArrayFactory::create('c', {1}); + auto z = NDArrayFactory::create('c', {1}); - auto exp = NDArrayFactory::create('c', {1}, {4}); + auto exp = NDArrayFactory::create('c', {1}, {4}); - sd::ops::broadcast_dynamic_shape op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + sd::ops::broadcast_dynamic_shape op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(exp.equalsTo(z)); } ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_9) { + auto x = NDArrayFactory::create('c', {2}, {2, 2}); - auto x = NDArrayFactory::create('c', {2}, {2,2}); - - auto y = NDArrayFactory::create('c', {1}, {1}); + auto y = NDArrayFactory::create('c', {1}, {1}); - auto z = NDArrayFactory::create('c', {2}); + auto z = NDArrayFactory::create('c', {2}); - auto exp = NDArrayFactory::create('c', {2}, {2,2}); + auto exp = NDArrayFactory::create('c', {2}, {2, 2}); - sd::ops::broadcast_dynamic_shape op; - auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - // ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::broadcast_dynamic_shape op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + // ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, + 0.0, 4.0, 0.0, 0.0}); - auto x = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, - -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, - -3.0, 0.0, 0.0, 4.0, 0.0, 0.0} - ); - - auto exp = NDArrayFactory::create('c', {2, 3, 3}, { - -0.2771281, 0., 0., - 0.36950415, 0., 0., - -0.2771281, 0., 0., - 0.36950415, 0., 0., - -0.2771281, 0., 0., - 0.36950415, 0., 0.} - ); -// 8.660254 -// auto expNorm(8.660254); - - sd::ops::clip_by_global_norm op; - auto result = op.evaluate({&x}, {0.8}, {}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-0.2771281, 0., 0., 0.36950415, 0., 0., -0.2771281, 0., 0., 0.36950415, + 0., 0., -0.2771281, 0., 0., 0.36950415, 0., 0.}); + // 8.660254 + // auto expNorm(8.660254); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::clip_by_global_norm op; + auto result = op.evaluate({&x}, {0.8}, {}); - auto z = result.at(0); - auto norm = result.at(1); - //z->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expected"); - //norm->printIndexedBuffer("Norm"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); -// ASSERT_TRUE(expNorm.equalsTo(norm)); + auto z = result.at(0); + auto norm = result.at(1); + // z->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expected"); + // norm->printIndexedBuffer("Norm"); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + // ASSERT_TRUE(expNorm.equalsTo(norm)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_2) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, + 0.0, 4.0, 0.0, 0.0}); - auto x = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, - -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, - -3.0, 0.0, 0.0, 4.0, 0.0, 0.0} - ); - - auto a = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, - -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, - -3.0, 0.0, 0.0, 4.0, 0.0, 0.0} - ); + auto a = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, + 0.0, 4.0, 0.0, 0.0}); - auto exp = NDArrayFactory::create('c', {2, 3, 3}, { - -0.44090813, 0., 0., - 0.5878775, 0., 0., - -0.44090813, 0., 0., - 0.5878775, 0., 0., - -0.44090813, 0., 0., - 0.5878775, 0., 0.} -//12.247449 + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-0.44090813, 0., 0., 0.5878775, 0., 0., -0.44090813, 0., 0., 0.5878775, + 0., 0., -0.44090813, 0., 0., 0.5878775, 0., 0.} // 12.247449 - ); + ); - sd::ops::clip_by_global_norm op; - auto result = op.evaluate({&x, &a}, {1.8}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::clip_by_global_norm op; + auto result = op.evaluate({&x, &a}, {1.8}, {}); - auto z = result.at(0); - auto y = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.isSameShape(y)); - ASSERT_TRUE(exp.equalsTo(z)); - ASSERT_TRUE(exp.equalsTo(y)); + auto z = result.at(0); + auto y = result.at(1); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(y)); + ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.equalsTo(y)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_3) { - - auto x = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); - auto a = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); - auto exp = NDArrayFactory::create('c', {2, 3, 3}, { - -0.19595918, 0., 0., - 0.2612789, 0., 0., - -0.19595918, 0., 0., - 0.2612789, 0., 0., - -0.19595918, 0., 0., - 0.2612789, 0., 0.} - ); - - sd::ops::clip_by_global_norm op; - auto result = op.evaluate({&x, &a}, {0.8}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - auto y = result.at(1); - //z->printIndexedBuffer("Output 1"); - //y->printIndexedBuffer("Output 2"); - //result.at(2)->printIndexedBuffer("Global norm is"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.isSameShape(y)); - ASSERT_TRUE(result.at(2).isScalar()); - ASSERT_TRUE(exp.equalsTo(z)); - ASSERT_TRUE(exp.equalsTo(y)); - - + auto x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, + 0.0, 4.0, 0.0, 0.0}); + auto a = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, + 0.0, 4.0, 0.0, 0.0}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-0.19595918, 0., 0., 0.2612789, 0., 0., -0.19595918, 0., 0., 0.2612789, + 0., 0., -0.19595918, 0., 0., 0.2612789, 0., 0.}); + + sd::ops::clip_by_global_norm op; + auto result = op.evaluate({&x, &a}, {0.8}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + auto y = result.at(1); + // z->printIndexedBuffer("Output 1"); + // y->printIndexedBuffer("Output 2"); + // result.at(2)->printIndexedBuffer("Global norm is"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(y)); + ASSERT_TRUE(result.at(2).isScalar()); + ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.equalsTo(y)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixDeterminant_1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, + 0.0, 0.0, 0.0, 4.0}); + auto exp = NDArrayFactory::create({36.0, -48.0}); - auto x = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, 0.0, 0.0, 0.0, 4.0}); - auto exp = NDArrayFactory::create({36.0, -48.0}); - - sd::ops::matrix_determinant op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); - //z->printIndexedBuffer("Output "); - //exp.printIndexedBuffer("Expected "); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2}, {1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0}); + auto exp = NDArrayFactory::create({-2.0, -2.0}); - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0}); - auto exp = NDArrayFactory::create({-2.0, -2.0}); - - sd::ops::matrix_determinant op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); - //z->printIndexedBuffer("Output "); - //exp.printIndexedBuffer("Expected "); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) { + auto x = NDArrayFactory::create( + 'c', {1, 3, 3}, {3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 3.0}); + NDArray exp('c', {1}, std::vector{-54.0}); - auto x = NDArrayFactory::create('c', {1, 3, 3}, {3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 3.0}); - NDArray exp('c', {1}, std::vector{-54.0}); - - sd::ops::matrix_determinant op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); - //z->printIndexedBuffer("Output "); - //exp.printIndexedBuffer("Expected "); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) { + auto x = NDArrayFactory::create( + 'c', {1, 3, 3}, {12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 13.0}); + auto exp = NDArrayFactory::create('c', {1}, {189.0}); - auto x = NDArrayFactory::create('c', {1, 3, 3}, {12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 13.0}); - auto exp = NDArrayFactory::create('c', {1}, {189.0}); - - sd::ops::matrix_determinant op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); - // z->printIndexedBuffer("Output "); - // exp.printIndexedBuffer("Expected "); - // z->printShapeInfo("Output shape"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); + // z->printShapeInfo("Output shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) { + auto x = NDArrayFactory::create('c', {1, 4, 4}); + NDArray exp('c', {1}, std::vector{-16.0}); + x.linspace(1); + x.p(5, 4.0); + x.p(12, 12.0); - auto x = NDArrayFactory::create('c', {1, 4, 4}); - NDArray exp('c', {1}, std::vector{-16.0}); - x.linspace(1); - x.p(5, 4.0); - x.p(12, 12.0); - - sd::ops::matrix_determinant op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); - //z->printIndexedBuffer("Output "); - //exp.printIndexedBuffer("Expected "); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixDeterminant_6) { + auto x = NDArrayFactory::create('c', {4, 4}); + auto exp = NDArrayFactory::create(-16.0); + x.linspace(1); + x.p(5, 4.0); + x.p(12, 12.0); - auto x = NDArrayFactory::create('c', {4, 4}); - auto exp = NDArrayFactory::create(-16.0); - x.linspace(1); - x.p(5, 4.0); - x.p(12, 12.0); + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); - sd::ops::matrix_determinant op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - //z->printIndexedBuffer("Output "); - //z->printShapeInfo("Shape"); - //exp.printIndexedBuffer("Expected "); - ASSERT_TRUE(z.isScalar()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // z->printShapeInfo("Shape"); + // exp.printIndexedBuffer("Expected "); + ASSERT_TRUE(z.isScalar()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, + 0.0, 0.0, 0.0, 4.0}); + auto exp = + NDArrayFactory::create({3.58351893845611, 3.871201010907891}); - auto x = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, 0.0, 0.0, 0.0, 4.0}); - auto exp = NDArrayFactory::create({3.58351893845611, 3.871201010907891}); - - sd::ops::log_matrix_determinant op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::log_matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, LogDet_1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {4, 12, -16, 12, 37, -43, -16, -43, 98, 4, 1.2, -1.6, 1.2, 3.7, -4.3, + -1.6, -4.3, 9.8}); + auto exp = NDArrayFactory::create({3.5835189, 4.159008}); - auto x = NDArrayFactory::create('c', {2, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8}); - auto exp = NDArrayFactory::create({ 3.5835189, 4.159008}); - - sd::ops::logdet op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::logdet op; + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, LogDet_2) { + auto x = NDArrayFactory::create( + 'c', {1, 3, 3}, {4, 12, -16, 12, 37, -43, -16, -43, 98}); + auto exp = NDArrayFactory::create('c', {1}, {3.5835189}); - auto x = NDArrayFactory::create('c', {1, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98}); - auto exp = NDArrayFactory::create('c', {1}, { 3.5835189}); - - sd::ops::logdet op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::logdet op; + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, LogDet_3) { + auto x = NDArrayFactory::create( + 'c', {3, 3}, {4, 12, -16, 12, 37, -43, -16, -43, 98}); + auto exp = NDArrayFactory::create(3.5835189); - auto x = NDArrayFactory::create('c', {3, 3}, {4,12,-16,12,37,-43,-16,-43,98}); - auto exp = NDArrayFactory::create( 3.5835189); - - sd::ops::logdet op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::logdet op; + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_1) { + auto x = NDArrayFactory::create( + 'c', {2, 5, 5}, + {2.f, 4.f, 60.f, 8.f, 10.f, 0.f, 1.f, 2.f, 3.f, 4.f, 0.f, 0.f, 2.f, + 4.f, 6.f, 0.f, 0.f, 0.f, 1.f, 2.f, 0.f, 0.f, 0.f, 0.f, 4.f, - auto x = NDArrayFactory::create('c', {2, 5, 5}, { - 2.f, 4.f, 60.f, 8.f, 10.f, - 0.f, 1.f, 2.f, 3.f, 4.f, - 0.f, 0.f, 2.f, 4.f, 6.f, - 0.f, 0.f, 0.f, 1.f, 2.f, - 0.f, 0.f, 0.f, 0.f, 4.f, - - 1.f, 0.f, 0.f, 0.f, 0.f, - 2.f, 1.f, 0.f, 0.f, 0.f, - 30.f, 2.f, 1.f, 0.f, 0.f, - 4.f, 3.f, 2.f, 1.f, 0.f, - 5.f, 4.f, 3.f, 2.f, 1.f - }); + 1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, + 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f}); - auto exp = NDArrayFactory::create('c', {2, 5, 5}, { - 0.5f, -2.0f, -13.0f, 54.0f, -6.75f, - 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, - 0.f, 0.f, 0.5f, -2.0f, 0.25f, - 0.f, 0.f, 0.f, 1.0f, -0.5f, - 0.f, 0.f, 0.f, 0.f, 0.25f, - - 1.0f, 0.0f, 0.0f, 0.0f, 0.f, - -2.0f, 1.0f, 0.f, 0.f, 0.f, - -26.0f, -2.0f, 1.f, 0.f, 0.f, - 54.0f, 1.0f, -2.0f, 1.f, 0.f, - -27.0f, 0.0f, 1.0f, -2.0f, 1.f, - }); + auto exp = NDArrayFactory::create( + 'c', {2, 5, 5}, + { + 0.5f, -2.0f, -13.0f, 54.0f, -6.75f, 0.0f, 1.0f, -1.0f, 1.0f, + 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, + 1.0f, -0.5f, 0.f, 0.f, 0.f, 0.f, 0.25f, - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); + 1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, + 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, + 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f, + }); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_010) { + auto x = NDArrayFactory::create( + 'c', {1, 5, 5}, + { + 1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, + 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f, + }); + auto exp = NDArrayFactory::create( + 'c', {1, 5, 5}, + {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, + 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, + 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f}); - auto x = NDArrayFactory::create('c', {1, 5, 5}, {1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f, }); - auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f}); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_01) { + auto x = NDArrayFactory::create( + 'c', {1, 5, 5}, + {2.f, 4.f, 60.f, 8.f, 10.f, 0.f, 1.f, 2.f, 3.f, 4.f, 0.f, 0.f, 2.f, + 4.f, 6.f, 0.f, 0.f, 0.f, 1.f, 2.f, 0.f, 0.f, 0.f, 0.f, 4.f}); - auto x = NDArrayFactory::create('c', {1, 5, 5}, {2.f, 4.f, 60.f, 8.f, 10.f, 0.f, 1.f, 2.f, 3.f, 4.f, 0.f, 0.f, 2.f, 4.f, 6.f, 0.f, 0.f, 0.f, 1.f, 2.f, 0.f, 0.f, 0.f, 0.f, 4.f }); + auto exp = NDArrayFactory::create( + 'c', {1, 5, 5}, + {0.5f, -2.0f, -13.0f, 54.0f, -6.75f, 0.0f, 1.0f, -1.0f, 1.0f, + 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, + 1.0f, -0.5f, 0.f, 0.f, 0.f, 0.f, 0.25f}); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - auto exp = NDArrayFactory::create('c', {1, 5, 5}, {0.5f, -2.0f, -13.0f, 54.0f, -6.75f, 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, 1.0f, -0.5f, 0.f, 0.f, 0.f, 0.f, 0.25f }); - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_02) { + auto x = NDArrayFactory::create( + 'c', {1, 5, 5}, + {1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, + 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f}); + auto exp = NDArrayFactory::create( + 'c', {1, 5, 5}, + {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, + 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, + 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f}); - auto x = NDArrayFactory::create('c', {1, 5, 5}, {1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f }); - auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f }); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// @@ -1704,1096 +1627,1269 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } */ TEST_F(DeclarableOpsTests6, MatrixInverse_03) { + auto x = NDArrayFactory::create( + 'c', {5, 5}, + { + 4.f, 0.f, 0.f, 0.f, 0.f, 4.f, 2.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, + 0.f, 0.f, 8.f, 6.f, 4.f, 2.f, 0.f, 15.f, 12.f, 9.f, 6.f, 3.f, + }); - auto x = NDArrayFactory::create('c', {5, 5}, { - 4.f, 0.f, 0.f, 0.f, 0.f, - 4.f, 2.f, 0.f, 0.f, 0.f, - 30.f, 2.f, 1.f, 0.f, 0.f, - 8.f, 6.f, 4.f, 2.f, 0.f, - 15.f, 12.f, 9.f, 6.f, 3.f, - }); - - auto exp = NDArrayFactory::create('c', {5, 5}, { - 0.25f, 0.0f, 0.0f, 0.0f, 0.0f, - -0.50f, 0.5f, 0.0f, 0.0f, 0.0f, - -6.50f, -1.0f, 1.0f, 0.0f, 0.0f, - 13.50f, 0.5f, -2.0f, 0.5f, 0.0f, - -6.75f, 0.0f, 1.0f, -1.0f, 0.33333333f - }); - - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); + auto exp = NDArrayFactory::create( + 'c', {5, 5}, + {0.25f, 0.0f, 0.0f, 0.0f, 0.0f, -0.50f, 0.5f, 0.0f, 0.0f, + 0.0f, -6.50f, -1.0f, 1.0f, 0.0f, 0.0f, 13.50f, 0.5f, -2.0f, + 0.5f, 0.0f, -6.75f, 0.0f, 1.0f, -1.0f, 0.33333333f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - auto z = result.at(0); -// z->printIndexedBuffer("Output "); -// exp.printIndexedBuffer("Expected "); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_3) { + auto x = NDArrayFactory::create( + 'c', {5, 5}, + { + 4.f, 0.f, 0.f, 0.f, 0.f, 4.f, 2.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, + 0.f, 0.f, 8.f, 6.f, 4.f, 2.f, 0.f, 15.f, 12.f, 9.f, 6.f, 3.f, + }); - auto x = NDArrayFactory::create('c', {5, 5}, { - 4.f, 0.f, 0.f, 0.f, 0.f, - 4.f, 2.f, 0.f, 0.f, 0.f, - 30.f, 2.f, 1.f, 0.f, 0.f, - 8.f, 6.f, 4.f, 2.f, 0.f, - 15.f, 12.f, 9.f, 6.f, 3.f, - }); - - auto exp = NDArrayFactory::create('c', {5, 5}, { - 0.25f, 0.0f, 0.0f, 0.0f, 0.0f, - -0.50f, 0.5f, 0.0f, 0.0f, 0.0f, - -6.50f, -1.0f, 1.0f, 0.0f, 0.0f, - 13.50f, 0.5f, -2.0f, 0.5f, 0.0f, - -6.75f, 0.0f, 1.0f, -1.0f, 0.33333333f - }); - - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); + auto exp = NDArrayFactory::create( + 'c', {5, 5}, + {0.25f, 0.0f, 0.0f, 0.0f, 0.0f, -0.50f, 0.5f, 0.0f, 0.0f, + 0.0f, -6.50f, -1.0f, 1.0f, 0.0f, 0.0f, 13.50f, 0.5f, -2.0f, + 0.5f, 0.0f, -6.75f, 0.0f, 1.0f, -1.0f, 0.33333333f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - auto z = result.at(0); -// exp.printIndexedBuffer("Expected "); -// z->printIndexedBuffer("Output "); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // exp.printIndexedBuffer("Expected "); + // z->printIndexedBuffer("Output "); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_4) { + auto x = NDArrayFactory::create( + 'c', {5, 5}, + {1.f, 2.f, 30.f, 4.f, 5.f, 0.f, 1.f, 2.f, 3.f, 4.f, 0.f, 0.f, 1.f, + 2.f, 3.f, 0.f, 0.f, 0.f, 1.f, 2.f, 0.f, 0.f, 0.f, 0.f, 1.f}); - auto x = NDArrayFactory::create('c', {5, 5}, { - 1.f, 2.f, 30.f, 4.f, 5.f, - 0.f, 1.f, 2.f, 3.f, 4.f, - 0.f, 0.f, 1.f, 2.f, 3.f, - 0.f, 0.f, 0.f, 1.f, 2.f, - 0.f, 0.f, 0.f, 0.f, 1.f - }); - - auto exp = NDArrayFactory::create('c', {5, 5}, { - 1.0f, -2.0f, -26.0f, 54.0f, -27.0f, - 0.0f, 1.0f, -2.0f, 1.0f, 0.0f, - 0.0f, 0.0f, 1.0f, -2.0f, 1.0f, - 0.0f, 0.0f, 0.0f, 1.0f, -2.0f, - 0.0f, 0.0f, 0.0f, 0.0f, 1.0f - }); - - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); + auto exp = NDArrayFactory::create( + 'c', {5, 5}, {1.0f, -2.0f, -26.0f, 54.0f, -27.0f, 0.0f, 1.0f, -2.0f, 1.0f, + 0.0f, 0.0f, 0.0f, 1.0f, -2.0f, 1.0f, 0.0f, 0.0f, 0.0f, + 1.0f, -2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - auto z = result.at(0); -// z->printIndexedBuffer("Output "); -// exp.printIndexedBuffer("Expected "); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, MatrixInverse_04) { + auto x = NDArrayFactory::create( + 'c', {5, 5}, + {1.f, 2.f, 30.f, 4.f, 5.f, 0.f, 1.f, 2.f, 3.f, 4.f, 0.f, 0.f, 1.f, + 2.f, 3.f, 0.f, 0.f, 0.f, 1.f, 2.f, 0.f, 0.f, 0.f, 0.f, 1.f}); - auto x = NDArrayFactory::create('c', {5, 5}, { - 1.f, 2.f, 30.f, 4.f, 5.f, - 0.f, 1.f, 2.f, 3.f, 4.f, - 0.f, 0.f, 1.f, 2.f, 3.f, - 0.f, 0.f, 0.f, 1.f, 2.f, - 0.f, 0.f, 0.f, 0.f, 1.f - }); - - auto exp = NDArrayFactory::create('c', {5, 5}, { - 1.0f, -2.0f, -26.0f, 54.0f, -27.0f, - 0.0f, 1.0f, -2.0f, 1.0f, 0.0f, - 0.0f, 0.0f, 1.0f, -2.0f, 1.0f, - 0.0f, 0.0f, 0.0f, 1.0f, -2.0f, - 0.0f, 0.0f, 0.0f, 0.0f, 1.0f - }); - - sd::ops::matrix_inverse op; - auto result = op.evaluate({&x}); + auto exp = NDArrayFactory::create( + 'c', {5, 5}, {1.0f, -2.0f, -26.0f, 54.0f, -27.0f, 0.0f, 1.0f, -2.0f, 1.0f, + 0.0f, 0.0f, 0.0f, 1.0f, -2.0f, 1.0f, 0.0f, 0.0f, 0.0f, + 1.0f, -2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); - auto z = result.at(0); -// z->printIndexedBuffer("Output "); -// exp.printIndexedBuffer("Expected "); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, ReluLayer_1) { - auto x = NDArrayFactory::create('c', {3, 4}, {1.0, -2.0, 3.0, 4.0, 5.0, -6.0, 7.0, 8.0, 9.0, -10.0, 11.0, 12}); - auto w = NDArrayFactory::create('c', {4, 3}, {0.5, 0.1, 0.8, 0.5, 0.2, 0.5, 0.5, 0.25, 0.5, 0.1, 0.0, 0.25}); - auto b = NDArrayFactory::create({20.0, 30.0, 50.0}); - + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {1.0, -2.0, 3.0, 4.0, 5.0, -6.0, 7.0, 8.0, 9.0, -10.0, 11.0, 12}); + auto w = NDArrayFactory::create( + 'c', {4, 3}, + {0.5, 0.1, 0.8, 0.5, 0.2, 0.5, 0.5, 0.25, 0.5, 0.1, 0.0, 0.25}); + auto b = NDArrayFactory::create({20.0, 30.0, 50.0}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {21.4, 30.45, 52.3, 23.8, 31.05, 56.5, 26.2, 31.65, 60.7}); - auto exp = NDArrayFactory::create('c', {3, 3}, { - 21.4, 30.45, 52.3, - 23.8, 31.05, 56.5, - 26.2, 31.65, 60.7}); + sd::ops::relu_layer op; + auto result = op.evaluate({&x, &w, &b}); - sd::ops::relu_layer op; - auto result = op.evaluate({&x, &w, &b}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - // z->printShapeInfo("Output shape"); - // z->printIndexedBuffer("Output "); - // exp.printIndexedBuffer("Expected "); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printShapeInfo("Output shape"); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests6, Test_Reduce3_Edge) { - auto x = NDArrayFactory::create('c', {3, 4, 5}); - auto y = NDArrayFactory::create('c', {3, 4, 5}); - + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto y = NDArrayFactory::create('c', {3, 4, 5}); - std::vector dims = {0, 1}; - auto z = x.applyReduce3(reduce3::CosineSimilarity, y, dims); - ASSERT_TRUE(&z != nullptr); + std::vector dims = {0, 1}; + auto z = x.applyReduce3(reduce3::CosineSimilarity, y, dims); + ASSERT_TRUE(&z != nullptr); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_rnn_test1) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3}); - - x.linspace(0.01, 0.01); - h0 = 0.2; - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484, 0.9312333 , 0.9312333 , 0.9312333 , 0.9312333 , - 0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0., 0., 0., 0. , - 0.97732812, 0.97732812, 0.97732812, 0.97732812,0., 0., 0., 0. ,0., 0., 0., 0.,0., 0., 0., 0.}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527}); - - sd::ops::static_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + auto maxTimeStep = + NDArrayFactory::create('c', {bS}, {time - 1, time - 3}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnits}, + {0.68474828, 0.68474828, 0.68474828, 0.68474828, 0.69882484, 0.69882484, + 0.69882484, 0.69882484, 0.9312333, 0.9312333, 0.9312333, 0.9312333, + 0.93751527, 0.93751527, 0.93751527, 0.93751527, 0.97136768, 0.97136768, + 0.97136768, 0.97136768, 0., 0., 0., 0., + 0.97732812, 0.97732812, 0.97732812, 0.97732812, 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0.}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, + 0.93751527, 0.93751527}); + + sd::ops::static_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_rnn_test2) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - - x.linspace(0.01, 0.01); - h0 = 0.2; - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484,0.9312333 , 0.9312333 , 0.9312333 , 0.9312333, - 0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0.97338548, 0.97338548, 0.97338548, 0.97338548, - 0.97732812, 0.97732812, 0.97732812, 0.97732812,0.97864398, 0.97864398, 0.97864398, 0.97864398,0.98000654, 0.98000654, 0.98000654, 0.98000654, - 0.98112648, 0.98112648, 0.98112648, 0.98112648}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.98000654, 0.98000654, 0.98000654, 0.98000654,0.98112648, 0.98112648, 0.98112648, 0.98112648}); - - sd::ops::static_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnits}, + {0.68474828, 0.68474828, 0.68474828, 0.68474828, 0.69882484, 0.69882484, + 0.69882484, 0.69882484, 0.9312333, 0.9312333, 0.9312333, 0.9312333, + 0.93751527, 0.93751527, 0.93751527, 0.93751527, 0.97136768, 0.97136768, + 0.97136768, 0.97136768, 0.97338548, 0.97338548, 0.97338548, 0.97338548, + 0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.97864398, 0.97864398, + 0.97864398, 0.97864398, 0.98000654, 0.98000654, 0.98000654, 0.98000654, + 0.98112648, 0.98112648, 0.98112648, 0.98112648}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.98000654, 0.98000654, 0.98000654, 0.98000654, 0.98112648, 0.98112648, + 0.98112648, 0.98112648}); + + sd::ops::static_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_rnn_test3) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, 0}); - - x.linspace(0.01, 0.01); - h0 = 0.2; - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0., 0., 0., 0., 0.9312333, 0.9312333, 0.9312333, 0.9312333, - 0., 0., 0., 0. , 0.97136768, 0.97136768, 0.97136768, 0.97136768,0., 0., 0., 0. , - 0.97732812, 0.97732812, 0.97732812, 0.97732812,0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.2 , 0.2 , 0.2 , 0.2}); - - sd::ops::static_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time - 1, 0}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnits}, + {0.68474828, 0.68474828, 0.68474828, 0.68474828, 0., 0., 0., 0., + 0.9312333, 0.9312333, 0.9312333, 0.9312333, 0., 0., 0., 0., + 0.97136768, 0.97136768, 0.97136768, 0.97136768, 0., 0., 0., 0., + 0.97732812, 0.97732812, 0.97732812, 0.97732812, 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0.}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.2, 0.2, 0.2, 0.2}); + + sd::ops::static_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_rnn_test4) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3}); - - x.linspace(0.01, 0.01); - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.49676344, 0.49676344, 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664, - 0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, 0.96529784, 0.96529784,0., 0., 0., 0. , - 0.97688859, 0.97688859, 0.97688859, 0.97688859,0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97688859, 0.97688859, 0.97688859, 0.97688859, 0.88400882, 0.88400882, 0.88400882, 0.88400882}); - - sd::ops::static_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + auto maxTimeStep = + NDArrayFactory::create('c', {bS}, {time - 1, time - 3}); + + x.linspace(0.01, 0.01); + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnits}, + {0.47615493, 0.47615493, 0.47615493, 0.47615493, 0.49676344, 0.49676344, + 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664, + 0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, + 0.96529784, 0.96529784, 0., 0., 0., 0., + 0.97688859, 0.97688859, 0.97688859, 0.97688859, 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0.}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97688859, 0.97688859, 0.97688859, 0.97688859, 0.88400882, 0.88400882, + 0.88400882, 0.88400882}); + + sd::ops::static_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_rnn_test5) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - - x.linspace(0.01, 0.01); - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.49676344, 0.49676344, 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664, - 0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, 0.96529784, 0.96529784,0.96849345, 0.96849345, 0.96849345, 0.96849345, - 0.97688859, 0.97688859, 0.97688859, 0.97688859,0.97831069, 0.97831069, 0.97831069, 0.97831069, 0.97997868, 0.97997868, 0.97997868, 0.97997868, - 0.98110653, 0.98110653, 0.98110653, 0.98110653}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97997868, 0.97997868, 0.97997868, 0.97997868, 0.98110653, 0.98110653, 0.98110653, 0.98110653}); - - sd::ops::static_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + + x.linspace(0.01, 0.01); + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnits}, + {0.47615493, 0.47615493, 0.47615493, 0.47615493, 0.49676344, 0.49676344, + 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664, + 0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, + 0.96529784, 0.96529784, 0.96849345, 0.96849345, 0.96849345, 0.96849345, + 0.97688859, 0.97688859, 0.97688859, 0.97688859, 0.97831069, 0.97831069, + 0.97831069, 0.97831069, 0.97997868, 0.97997868, 0.97997868, 0.97997868, + 0.98110653, 0.98110653, 0.98110653, 0.98110653}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97997868, 0.97997868, 0.97997868, 0.97997868, 0.98110653, 0.98110653, + 0.98110653, 0.98110653}); + + sd::ops::static_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_bidir_rnn_test1) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); - auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3, time-4, 0}); - - x.linspace(0.01, 0.01); - h0FW = 0.2; - h0BW = 0.25; - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnitsFW+numUnitsBW}, {0.43819931, 0.43819931, 0.43819931, 0.86708881, 0.86708881,0.86708881,0.47615493, 0.47615493, 0.47615493, 0.78347842, 0.78347842,0.78347842, - 0.51241561, 0.51241561, 0.51241561, 0.55529176, 0.55529176,0.55529176,0., 0., 0., 0., 0.,0.,0.73880324, 0.73880324, 0.73880324, 0.90935605, 0.90935605, - 0.90935605, 0.77843476, 0.77843476, 0.77843476, 0.64692945, 0.64692945,0.64692945,0., 0., 0., 0., 0.,0.,0., 0., 0., 0., 0.,0., - 0.9052501, 0.9052501, 0.9052501, 0.9181592, 0.9181592, 0.9181592,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., - 0.9555734, 0.9555734, 0.9555734, 0.8026439, 0.8026439, 0.8026439,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.9555734 , 0.9555734 , 0.9555734 , 0.77843476, 0.77843476, 0.77843476, 0.51241561, 0.51241561, 0.51241561, 0.2, 0.2, 0.2}); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25, 0.25, 0.25}); - - sd::ops::static_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFWfinal = results.at(1); - auto hBWfinal = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); + auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); + auto maxTimeStep = NDArrayFactory::create( + 'c', {bS}, {time - 1, time - 3, time - 4, 0}); + + x.linspace(0.01, 0.01); + h0FW = 0.2; + h0BW = 0.25; + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnitsFW + numUnitsBW}, + {0.43819931, 0.43819931, 0.43819931, 0.86708881, 0.86708881, 0.86708881, + 0.47615493, 0.47615493, 0.47615493, 0.78347842, 0.78347842, 0.78347842, + 0.51241561, 0.51241561, 0.51241561, 0.55529176, 0.55529176, 0.55529176, + 0., 0., 0., 0., 0., 0., + 0.73880324, 0.73880324, 0.73880324, 0.90935605, 0.90935605, 0.90935605, + 0.77843476, 0.77843476, 0.77843476, 0.64692945, 0.64692945, 0.64692945, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.9052501, 0.9052501, 0.9052501, 0.9181592, 0.9181592, 0.9181592, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.9555734, 0.9555734, 0.9555734, 0.8026439, 0.8026439, 0.8026439, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.9555734, 0.9555734, 0.9555734, 0.77843476, 0.77843476, 0.77843476, + 0.51241561, 0.51241561, 0.51241561, 0.2, 0.2, 0.2}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, + 0.55529176, 0.55529176, 0.55529176, 0.25, 0.25, 0.25}); + + sd::ops::static_bidirectional_rnn op; + auto results = op.evaluate( + {&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW, &h0FW, &h0BW, &maxTimeStep}, + {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFWfinal = results.at(1); + auto hBWfinal = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_bidir_rnn_test2) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3, time-4, 0}); - - x.linspace(0.01, 0.01); - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnitsFW+numUnitsBW}, {0.22602835, 0.22602835, 0.22602835, 0.86518273, 0.86518273,0.86518273,0.27105303, 0.27105303, 0.27105303, 0.66617761, 0.66617761,0.66617761, - 0.31492203, 0.31492203, 0.31492203, 0.31492203, 0.31492203,0.31492203,0. , 0. , 0. , 0. , 0. ,0. , - 0.60005558, 0.60005558, 0.60005558, 0.9029975 , 0.9029975 ,0.9029975 ,0.66138054, 0.66138054, 0.66138054, 0.43819931, 0.43819931,0.43819931, - 0. , 0. , 0. , 0. , 0. ,0. ,0. , 0. , 0. , 0. , 0. ,0. , - 0.87023975, 0.87023975, 0.87023975, 0.88852032, 0.88852032,0.88852032,0. , 0. , 0. , 0. , 0. ,0. , - 0. , 0. , 0. , 0. , 0. ,0. ,0. , 0. , 0. , 0. , 0. ,0. , - 0.95177305, 0.95177305, 0.95177305, 0.66737775, 0.66737775,0.66737775,0. , 0. , 0. , 0. , 0. ,0. , - 0. , 0. , 0. , 0. , 0. ,0. ,0. , 0. , 0. , 0. , 0. ,0. , - 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.}); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.95177305, 0.95177305, 0.95177305, 0.66138054, 0.66138054, 0.66138054, 0.31492203, 0.31492203, 0.31492203, 0. , 0. , 0.}); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86518273, 0.86518273, 0.86518273, 0.66617761, 0.66617761, 0.66617761, 0.31492203, 0.31492203, 0.31492203, 0. , 0. , 0.}); - - sd::ops::static_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFWfinal = results.at(1); - auto hBWfinal = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + auto maxTimeStep = NDArrayFactory::create( + 'c', {bS}, {time - 1, time - 3, time - 4, 0}); + + x.linspace(0.01, 0.01); + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnitsFW + numUnitsBW}, + {0.22602835, 0.22602835, 0.22602835, 0.86518273, 0.86518273, 0.86518273, + 0.27105303, 0.27105303, 0.27105303, 0.66617761, 0.66617761, 0.66617761, + 0.31492203, 0.31492203, 0.31492203, 0.31492203, 0.31492203, 0.31492203, + 0., 0., 0., 0., 0., 0., + 0.60005558, 0.60005558, 0.60005558, 0.9029975, 0.9029975, 0.9029975, + 0.66138054, 0.66138054, 0.66138054, 0.43819931, 0.43819931, 0.43819931, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.87023975, 0.87023975, 0.87023975, 0.88852032, 0.88852032, 0.88852032, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.95177305, 0.95177305, 0.95177305, 0.66737775, 0.66737775, 0.66737775, + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.95177305, 0.95177305, 0.95177305, 0.66138054, 0.66138054, 0.66138054, + 0.31492203, 0.31492203, 0.31492203, 0., 0., 0.}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.86518273, 0.86518273, 0.86518273, 0.66617761, 0.66617761, 0.66617761, + 0.31492203, 0.31492203, 0.31492203, 0., 0., 0.}); + + sd::ops::static_bidirectional_rnn op; + auto results = op.evaluate( + {&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFWfinal = results.at(1); + auto hBWfinal = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, static_bidir_rnn_test3) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - x.linspace(0.01, 0.01); - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnitsFW+numUnitsBW}, {0.22602835, 0.22602835, 0.22602835, 0.86841012, 0.86841012,0.86841012,0.27105303, 0.27105303, 0.27105303, 0.88207531, 0.88207531,0.88207531, - 0.31492203, 0.31492203, 0.31492203, 0.8941667 , 0.8941667 ,0.8941667 ,0.35748551, 0.35748551, 0.35748551, 0.90489713, 0.90489713, - 0.90489713, 0.60005558, 0.60005558, 0.60005558, 0.91381375, 0.91381375,0.91381375,0.66138054, 0.66138054, 0.66138054, 0.92253504, 0.92253504, - 0.92253504,0.71429879, 0.71429879, 0.71429879, 0.93027876, 0.93027876,0.93027876,0.75947891, 0.75947891, 0.75947891, 0.9371767 , 0.9371767 , - 0.9371767 , 0.87023975, 0.87023975, 0.87023975, 0.94014274, 0.94014274,0.94014274,0.89680574, 0.89680574, 0.89680574, 0.94648926, 0.94648926, - 0.94648926,0.91657261, 0.91657261, 0.91657261, 0.95204779, 0.95204779,0.95204779,0.93146896, 0.93146896, 0.93146896, 0.95694206, 0.95694206, - 0.95694206, 0.95177305, 0.95177305, 0.95177305, 0.93773086, 0.93773086,0.93773086,0.95874689, 0.95874689, 0.95874689, 0.94579176, 0.94579176, - 0.94579176,0.96416067, 0.96416067, 0.96416067, 0.95267886, 0.95267886,0.95267886,0.96851506, 0.96851506, 0.96851506, 0.95857985, 0.95857985, - 0.95857985, 0.97269956, 0.97269956, 0.97269956, 0.76075293, 0.76075293,0.76075293,0.97557464, 0.97557464, 0.97557464, 0.78024637, 0.78024637, - 0.78024637,0.97806922, 0.97806922, 0.97806922, 0.79833344, 0.79833344,0.79833344,0.98026195, 0.98026195, 0.98026195, 0.81508646, 0.81508646,0.81508646}); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.97269956, 0.97269956, 0.97269956, 0.97557464, 0.97557464, 0.97557464, 0.97806922, 0.97806922, 0.97806922, 0.98026195, 0.98026195, 0.98026195}); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86841012, 0.86841012, 0.86841012, 0.88207531, 0.88207531, 0.88207531, 0.8941667 , 0.8941667 , 0.8941667 , 0.90489713, 0.90489713, 0.90489713}); - - sd::ops::static_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFWfinal = results.at(1); - auto hBWfinal = results.at(2); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + x.linspace(0.01, 0.01); + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnitsFW + numUnitsBW}, + {0.22602835, 0.22602835, 0.22602835, 0.86841012, 0.86841012, 0.86841012, + 0.27105303, 0.27105303, 0.27105303, 0.88207531, 0.88207531, 0.88207531, + 0.31492203, 0.31492203, 0.31492203, 0.8941667, 0.8941667, 0.8941667, + 0.35748551, 0.35748551, 0.35748551, 0.90489713, 0.90489713, 0.90489713, + 0.60005558, 0.60005558, 0.60005558, 0.91381375, 0.91381375, 0.91381375, + 0.66138054, 0.66138054, 0.66138054, 0.92253504, 0.92253504, 0.92253504, + 0.71429879, 0.71429879, 0.71429879, 0.93027876, 0.93027876, 0.93027876, + 0.75947891, 0.75947891, 0.75947891, 0.9371767, 0.9371767, 0.9371767, + 0.87023975, 0.87023975, 0.87023975, 0.94014274, 0.94014274, 0.94014274, + 0.89680574, 0.89680574, 0.89680574, 0.94648926, 0.94648926, 0.94648926, + 0.91657261, 0.91657261, 0.91657261, 0.95204779, 0.95204779, 0.95204779, + 0.93146896, 0.93146896, 0.93146896, 0.95694206, 0.95694206, 0.95694206, + 0.95177305, 0.95177305, 0.95177305, 0.93773086, 0.93773086, 0.93773086, + 0.95874689, 0.95874689, 0.95874689, 0.94579176, 0.94579176, 0.94579176, + 0.96416067, 0.96416067, 0.96416067, 0.95267886, 0.95267886, 0.95267886, + 0.96851506, 0.96851506, 0.96851506, 0.95857985, 0.95857985, 0.95857985, + 0.97269956, 0.97269956, 0.97269956, 0.76075293, 0.76075293, 0.76075293, + 0.97557464, 0.97557464, 0.97557464, 0.78024637, 0.78024637, 0.78024637, + 0.97806922, 0.97806922, 0.97806922, 0.79833344, 0.79833344, 0.79833344, + 0.98026195, 0.98026195, 0.98026195, 0.81508646, 0.81508646, 0.81508646}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.97269956, 0.97269956, 0.97269956, 0.97557464, 0.97557464, 0.97557464, + 0.97806922, 0.97806922, 0.97806922, 0.98026195, 0.98026195, 0.98026195}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.86841012, 0.86841012, 0.86841012, 0.88207531, 0.88207531, 0.88207531, + 0.8941667, 0.8941667, 0.8941667, 0.90489713, 0.90489713, 0.90489713}); + + sd::ops::static_bidirectional_rnn op; + auto results = + op.evaluate({&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFWfinal = results.at(1); + auto hBWfinal = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_rnn_test1) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3}); - - x.linspace(0.01, 0.01); - h0 = 0.2; - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484,0.9312333 , 0.9312333 , 0.9312333 , 0.9312333 , - 0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0. , 0. , 0. , 0. , - 0.97732812, 0.97732812, 0.97732812, 0.97732812,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. }); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527}); - - sd::ops::dynamic_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + auto maxTimeStep = + NDArrayFactory::create('c', {bS}, {time - 1, time - 3}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {time, bS, numUnits}, + {0.68474828, 0.68474828, 0.68474828, 0.68474828, 0.69882484, 0.69882484, + 0.69882484, 0.69882484, 0.9312333, 0.9312333, 0.9312333, 0.9312333, + 0.93751527, 0.93751527, 0.93751527, 0.93751527, 0.97136768, 0.97136768, + 0.97136768, 0.97136768, 0., 0., 0., 0., + 0.97732812, 0.97732812, 0.97732812, 0.97732812, 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0.}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, + 0.93751527, 0.93751527}); + + sd::ops::dynamic_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_rnn_test2) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time}); - - x.linspace(0.01, 0.01); - h0 = 0.2; - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {bS, time, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.92755601, 0.92755601, 0.92755601, 0.92755601,0.96778334, 0.96778334, 0.96778334, - 0.96778334,0.97309129, 0.97309129, 0.97309129, 0.97309129,0. , 0. , 0. , 0. , - 0.75001965, 0.75001965, 0.75001965, 0.75001965,0.95449491, 0.95449491, 0.95449491, 0.95449491,0.97732828, 0.97732828, 0.97732828, - 0.97732828,0.98000655, 0.98000655, 0.98000655, 0.98000655,0.98120782, 0.98120782, 0.98120782, 0.98120782}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97309129, 0.97309129, 0.97309129, 0.97309129, 0.98120782, 0.98120782, 0.98120782, 0.98120782}); - - sd::ops::dynamic_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time - 1, time}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {bS, time, numUnits}, + {0.68474828, 0.68474828, 0.68474828, 0.68474828, 0.92755601, 0.92755601, + 0.92755601, 0.92755601, 0.96778334, 0.96778334, 0.96778334, 0.96778334, + 0.97309129, 0.97309129, 0.97309129, 0.97309129, 0., 0., + 0., 0., 0.75001965, 0.75001965, 0.75001965, 0.75001965, + 0.95449491, 0.95449491, 0.95449491, 0.95449491, 0.97732828, 0.97732828, + 0.97732828, 0.97732828, 0.98000655, 0.98000655, 0.98000655, 0.98000655, + 0.98120782, 0.98120782, 0.98120782, 0.98120782}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97309129, 0.97309129, 0.97309129, 0.97309129, 0.98120782, 0.98120782, + 0.98120782, 0.98120782}); + + sd::ops::dynamic_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_rnn_test3) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto h0 = NDArrayFactory::create('c', {bS, numUnits}); - - x.linspace(0.01, 0.01); - h0 = 0.2; - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {bS, time, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.92755601, 0.92755601, 0.92755601, 0.92755601,0.96778334, 0.96778334, 0.96778334, 0.96778334,0.97309129, - 0.97309129, 0.97309129, 0.97309129,0.97491207, 0.97491207, 0.97491207, 0.97491207,0.75001965, 0.75001965, 0.75001965, 0.75001965,0.95449491, 0.95449491, - 0.95449491, 0.95449491,0.97732828, 0.97732828, 0.97732828, 0.97732828,0.98000655, 0.98000655, 0.98000655, 0.98000655,0.98120782, 0.98120782, 0.98120782, 0.98120782}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97491207, 0.97491207, 0.97491207, 0.97491207, 0.98120782, 0.98120782, 0.98120782, 0.98120782}); - - sd::ops::dynamic_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {bS, time, numUnits}, + {0.68474828, 0.68474828, 0.68474828, 0.68474828, 0.92755601, 0.92755601, + 0.92755601, 0.92755601, 0.96778334, 0.96778334, 0.96778334, 0.96778334, + 0.97309129, 0.97309129, 0.97309129, 0.97309129, 0.97491207, 0.97491207, + 0.97491207, 0.97491207, 0.75001965, 0.75001965, 0.75001965, 0.75001965, + 0.95449491, 0.95449491, 0.95449491, 0.95449491, 0.97732828, 0.97732828, + 0.97732828, 0.97732828, 0.98000655, 0.98000655, 0.98000655, 0.98000655, + 0.98120782, 0.98120782, 0.98120782, 0.98120782}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97491207, 0.97491207, 0.97491207, 0.97491207, 0.98120782, 0.98120782, + 0.98120782, 0.98120782}); + + sd::ops::dynamic_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_rnn_test4) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-4}); - - x.linspace(0.01, 0.01); - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {bS, time, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.86347567, 0.86347567, 0.86347567, 0.86347567,0.96059545, 0.96059545, - 0.96059545, 0.96059545,0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0. , 0. , 0. , 0. , - 0.57368608, 0.57368608, 0.57368608, 0.57368608,0. , 0. , 0 , 0. ,0., 0. , 0, 0.,0., 0., 0. , 0. ,0. , 0. , 0., 0. }); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.57368608, 0.57368608, 0.57368608, 0.57368608}); - - sd::ops::dynamic_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto maxTimeStep = + NDArrayFactory::create('c', {bS}, {time - 1, time - 4}); + + x.linspace(0.01, 0.01); + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {bS, time, numUnits}, + {0.47615493, 0.47615493, 0.47615493, 0.47615493, 0.86347567, 0.86347567, + 0.86347567, 0.86347567, 0.96059545, 0.96059545, 0.96059545, 0.96059545, + 0.9724738, 0.9724738, 0.9724738, 0.9724738, 0., 0., + 0., 0., 0.57368608, 0.57368608, 0.57368608, 0.57368608, + 0., 0., 0, 0., 0., 0., + 0, 0., 0., 0., 0., 0., + 0., 0., 0., 0.}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.9724738, 0.9724738, 0.9724738, 0.9724738, 0.57368608, 0.57368608, + 0.57368608, 0.57368608}); + + sd::ops::dynamic_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_rnn_test5) { - - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); - - x.linspace(0.01, 0.01); - Wx = 0.3; - Wh = 0.4; - b = 0.25; - - auto expH = NDArrayFactory::create('c', {bS, time, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.86347567, 0.86347567, 0.86347567, 0.86347567,0.96059545, 0.96059545, 0.96059545, 0.96059545, - 0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.97486307, 0.97486307, 0.97486307, 0.97486307,0.57368608, 0.57368608, 0.57368608, 0.57368608, - 0.92135149, 0.92135149, 0.92135149, 0.92135149,0.97482354, 0.97482354, 0.97482354, 0.97482354,0.97984727, 0.97984727, 0.97984727, 0.97984727, - 0.98119833, 0.98119833, 0.98119833, 0.98119833}); - - auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97486307, 0.97486307, 0.97486307, 0.97486307,0.98119833, 0.98119833, 0.98119833, 0.98119833}); - - sd::ops::dynamic_rnn op; - auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto h = results.at(0); - auto hFinal = results.at(1); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expHFinal.isSameShape(hFinal)); - ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - - + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); + + x.linspace(0.01, 0.01); + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create( + 'c', {bS, time, numUnits}, + {0.47615493, 0.47615493, 0.47615493, 0.47615493, 0.86347567, 0.86347567, + 0.86347567, 0.86347567, 0.96059545, 0.96059545, 0.96059545, 0.96059545, + 0.9724738, 0.9724738, 0.9724738, 0.9724738, 0.97486307, 0.97486307, + 0.97486307, 0.97486307, 0.57368608, 0.57368608, 0.57368608, 0.57368608, + 0.92135149, 0.92135149, 0.92135149, 0.92135149, 0.97482354, 0.97482354, + 0.97482354, 0.97482354, 0.97984727, 0.97984727, 0.97984727, 0.97984727, + 0.98119833, 0.98119833, 0.98119833, 0.98119833}); + + auto expHFinal = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.97486307, 0.97486307, 0.97486307, 0.97486307, 0.98119833, 0.98119833, + 0.98119833, 0.98119833}); + + sd::ops::dynamic_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test1) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {time, bS, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); - auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3, time-4, 0}); - - x.linspace(0.01, 0.01); - h0FW = 0.2; - h0BW = 0.25; - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expHFW = NDArrayFactory::create('c', {time, bS, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.47615493, 0.47615493, 0.47615493,0.51241561, 0.51241561, 0.51241561,0. , 0. , 0. , - 0.73880324, 0.73880324, 0.73880324,0.77843476, 0.77843476, 0.77843476,0. , 0. , 0. ,0. , 0. , 0. , - 0.9052501 , 0.9052501 , 0.9052501 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0.9555734 , 0.9555734 , 0.9555734 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); - - auto expHBW = NDArrayFactory::create('c', {time, bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881,0.78347842, 0.78347842, 0.78347842,0.55529176, 0.55529176, 0.55529176,0. , 0. , 0. , - 0.90935605, 0.90935605, 0.90935605,0.64692945, 0.64692945, 0.64692945,0. , 0. , 0. ,0. , 0. , 0. , - 0.9181592 , 0.9181592 , 0.9181592 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0.8026439 , 0.8026439 , 0.8026439 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.9555734 , 0.9555734 , 0.9555734 , 0.77843476, 0.77843476, 0.77843476, 0.51241561, 0.51241561, 0.51241561, 0.2 , 0.2 , 0.2}); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25 , 0.25 , 0.25}); - - sd::ops::dynamic_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto hFW = results.at(0); - auto hBW = results.at(1); - auto hFWfinal = results.at(2); - auto hBWfinal = results.at(3); - - ASSERT_TRUE(expHFW.isSameShape(hFW)); - ASSERT_TRUE(expHFW.equalsTo(hFW)); - ASSERT_TRUE(expHBW.isSameShape(hBW)); - ASSERT_TRUE(expHBW.equalsTo(hBW)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); + auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); + auto maxTimeStep = + NDArrayFactory::create('c', {bS}, {time - 1, time - 3, time - 4, 0}); + + x.linspace(0.01, 0.01); + h0FW = 0.2; + h0BW = 0.25; + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expHFW = NDArrayFactory::create( + 'c', {time, bS, numUnitsFW}, + {0.43819931, 0.43819931, 0.43819931, 0.47615493, 0.47615493, 0.47615493, + 0.51241561, 0.51241561, 0.51241561, 0., 0., 0., + 0.73880324, 0.73880324, 0.73880324, 0.77843476, 0.77843476, 0.77843476, + 0., 0., 0., 0., 0., 0., + 0.9052501, 0.9052501, 0.9052501, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.9555734, 0.9555734, 0.9555734, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHBW = NDArrayFactory::create( + 'c', {time, bS, numUnitsBW}, + {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, + 0.55529176, 0.55529176, 0.55529176, 0., 0., 0., + 0.90935605, 0.90935605, 0.90935605, 0.64692945, 0.64692945, 0.64692945, + 0., 0., 0., 0., 0., 0., + 0.9181592, 0.9181592, 0.9181592, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.8026439, 0.8026439, 0.8026439, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.9555734, 0.9555734, 0.9555734, 0.77843476, 0.77843476, 0.77843476, + 0.51241561, 0.51241561, 0.51241561, 0.2, 0.2, 0.2}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, + 0.55529176, 0.55529176, 0.55529176, 0.25, 0.25, 0.25}); + + sd::ops::dynamic_bidirectional_rnn op; + auto results = op.evaluate( + {&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW, &h0FW, &h0BW, &maxTimeStep}, + {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto hFW = results.at(0); + auto hBW = results.at(1); + auto hFWfinal = results.at(2); + auto hBWfinal = results.at(3); + + ASSERT_TRUE(expHFW.isSameShape(hFW)); + ASSERT_TRUE(expHFW.equalsTo(hFW)); + ASSERT_TRUE(expHBW.isSameShape(hBW)); + ASSERT_TRUE(expHBW.equalsTo(hBW)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test2) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); - auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3, time-4, 0}); - - x.linspace(0.01, 0.01); - h0FW = 0.2; - h0BW = 0.25; - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expHFW = NDArrayFactory::create('c', {bS, time, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.66617761, 0.66617761, 0.66617761,0.80944357, 0.80944357, 0.80944357,0.87294706, 0.87294706, 0.87294706,0. , 0. , 0. , - 0.61067683, 0.61067683, 0.61067683,0.84851124, 0.84851124, 0.84851124,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0.73978305, 0.73978305, 0.73978305,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); - - auto expHBW = NDArrayFactory::create('c', {bS, time, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207,0.83584708, 0.83584708, 0.83584708,0.77435951, 0.77435951, 0.77435951,0.58760492, 0.58760492, 0.58760492,0. , 0. , 0. , - 0.85615841, 0.85615841, 0.85615841,0.67397984, 0.67397984, 0.67397984,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0.76576202, 0.76576202, 0.76576202,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.87294706, 0.87294706, 0.87294706,0.84851124, 0.84851124, 0.84851124,0.73978305, 0.73978305, 0.73978305,0.2 , 0.2 , 0.2}); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207, 0.85615841, 0.85615841, 0.85615841, 0.76576202, 0.76576202, 0.76576202, 0.25 , 0.25 , 0.25}); - - sd::ops::dynamic_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto hFW = results.at(0); - auto hBW = results.at(1); - auto hFWfinal = results.at(2); - auto hBWfinal = results.at(3); - - ASSERT_TRUE(expHFW.isSameShape(hFW)); - ASSERT_TRUE(expHFW.equalsTo(hFW)); - ASSERT_TRUE(expHBW.isSameShape(hBW)); - ASSERT_TRUE(expHBW.equalsTo(hBW)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); + auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); + auto maxTimeStep = + NDArrayFactory::create('c', {bS}, {time - 1, time - 3, time - 4, 0}); + + x.linspace(0.01, 0.01); + h0FW = 0.2; + h0BW = 0.25; + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expHFW = NDArrayFactory::create( + 'c', {bS, time, numUnitsFW}, + {0.43819931, 0.43819931, 0.43819931, 0.66617761, 0.66617761, 0.66617761, + 0.80944357, 0.80944357, 0.80944357, 0.87294706, 0.87294706, 0.87294706, + 0., 0., 0., 0.61067683, 0.61067683, 0.61067683, + 0.84851124, 0.84851124, 0.84851124, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.73978305, 0.73978305, 0.73978305, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHBW = NDArrayFactory::create( + 'c', {bS, time, numUnitsBW}, + {0.84345207, 0.84345207, 0.84345207, 0.83584708, 0.83584708, 0.83584708, + 0.77435951, 0.77435951, 0.77435951, 0.58760492, 0.58760492, 0.58760492, + 0., 0., 0., 0.85615841, 0.85615841, 0.85615841, + 0.67397984, 0.67397984, 0.67397984, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.76576202, 0.76576202, 0.76576202, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.87294706, 0.87294706, 0.87294706, 0.84851124, 0.84851124, 0.84851124, + 0.73978305, 0.73978305, 0.73978305, 0.2, 0.2, 0.2}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.84345207, 0.84345207, 0.84345207, 0.85615841, 0.85615841, 0.85615841, + 0.76576202, 0.76576202, 0.76576202, 0.25, 0.25, 0.25}); + + sd::ops::dynamic_bidirectional_rnn op; + auto results = op.evaluate( + {&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW, &h0FW, &h0BW, &maxTimeStep}, + {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto hFW = results.at(0); + auto hBW = results.at(1); + auto hFWfinal = results.at(2); + auto hBWfinal = results.at(3); + + ASSERT_TRUE(expHFW.isSameShape(hFW)); + ASSERT_TRUE(expHFW.equalsTo(hFW)); + ASSERT_TRUE(expHBW.isSameShape(hBW)); + ASSERT_TRUE(expHBW.equalsTo(hBW)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test3) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3, time-4, 0}); - - x.linspace(0.01, 0.01); - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expHFW = NDArrayFactory::create('c', {bS, time, numUnitsFW}, {0.22602835, 0.22602835, 0.22602835,0.49994591, 0.49994591, 0.49994591,0.72869307, 0.72869307, 0.72869307,0.84784327, 0.84784327, 0.84784327,0. , 0. , 0. , - 0.43819931, 0.43819931, 0.43819931,0.7793996 , 0.7793996 , 0.7793996 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0.61067683, 0.61067683, 0.61067683,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); - - auto expHBW = NDArrayFactory::create('c', {bS, time, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707,0.77935851, 0.77935851, 0.77935851,0.6381121 , 0.6381121 , 0.6381121 ,0.35748551, 0.35748551, 0.35748551,0. , 0. , 0. , - 0.77843476, 0.77843476, 0.77843476,0.47615493, 0.47615493, 0.47615493,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0.61067683, 0.61067683, 0.61067683,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , - 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.84784327, 0.84784327, 0.84784327, 0.7793996 , 0.7793996 , 0.7793996 , 0.61067683, 0.61067683, 0.61067683, 0. , 0. , 0.}); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707, 0.77843476, 0.77843476, 0.77843476, 0.61067683, 0.61067683, 0.61067683, 0. , 0. , 0.}); - - sd::ops::dynamic_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto hFW = results.at(0); - auto hBW = results.at(1); - auto hFWfinal = results.at(2); - auto hBWfinal = results.at(3); - - ASSERT_TRUE(expHFW.isSameShape(hFW)); - ASSERT_TRUE(expHFW.equalsTo(hFW)); - ASSERT_TRUE(expHBW.isSameShape(hBW)); - ASSERT_TRUE(expHBW.equalsTo(hBW)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + auto maxTimeStep = + NDArrayFactory::create('c', {bS}, {time - 1, time - 3, time - 4, 0}); + + x.linspace(0.01, 0.01); + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expHFW = NDArrayFactory::create( + 'c', {bS, time, numUnitsFW}, + {0.22602835, 0.22602835, 0.22602835, 0.49994591, 0.49994591, 0.49994591, + 0.72869307, 0.72869307, 0.72869307, 0.84784327, 0.84784327, 0.84784327, + 0., 0., 0., 0.43819931, 0.43819931, 0.43819931, + 0.7793996, 0.7793996, 0.7793996, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.61067683, 0.61067683, 0.61067683, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHBW = NDArrayFactory::create( + 'c', {bS, time, numUnitsBW}, + {0.82273707, 0.82273707, 0.82273707, 0.77935851, 0.77935851, 0.77935851, + 0.6381121, 0.6381121, 0.6381121, 0.35748551, 0.35748551, 0.35748551, + 0., 0., 0., 0.77843476, 0.77843476, 0.77843476, + 0.47615493, 0.47615493, 0.47615493, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0.61067683, 0.61067683, 0.61067683, 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0.}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.84784327, 0.84784327, 0.84784327, 0.7793996, 0.7793996, 0.7793996, + 0.61067683, 0.61067683, 0.61067683, 0., 0., 0.}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.82273707, 0.82273707, 0.82273707, 0.77843476, 0.77843476, 0.77843476, + 0.61067683, 0.61067683, 0.61067683, 0., 0., 0.}); + + sd::ops::dynamic_bidirectional_rnn op; + auto results = op.evaluate( + {&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto hFW = results.at(0); + auto hBW = results.at(1); + auto hFWfinal = results.at(2); + auto hBWfinal = results.at(3); + + ASSERT_TRUE(expHFW.isSameShape(hFW)); + ASSERT_TRUE(expHFW.equalsTo(hFW)); + ASSERT_TRUE(expHBW.isSameShape(hBW)); + ASSERT_TRUE(expHBW.equalsTo(hBW)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test4) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); - auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); - - x.linspace(0.01, 0.01); - h0FW = 0.2; - h0BW = 0.25; - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expHFW = NDArrayFactory::create('c', {bS, time, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.66617761, 0.66617761, 0.66617761,0.80944357, 0.80944357, 0.80944357,0.87294706, 0.87294706, 0.87294706,0.89948899, 0.89948899, 0.89948899, - 0.61067683, 0.61067683, 0.61067683,0.84851124, 0.84851124, 0.84851124,0.91925737, 0.91925737, 0.91925737,0.93751395, 0.93751395, 0.93751395,0.94544483, 0.94544483, 0.94544483, - 0.73978305, 0.73978305, 0.73978305,0.92827068, 0.92827068, 0.92827068,0.95791111, 0.95791111, 0.95791111,0.96427356, 0.96427356, 0.96427356,0.96797541, 0.96797541, 0.96797541, - 0.83057887, 0.83057887, 0.83057887,0.96365083, 0.96365083, 0.96365083,0.97585698, 0.97585698, 0.97585698,0.97866981, 0.97866981, 0.97866981,0.9807326 , 0.9807326 , 0.9807326 }); - - auto expHBW = NDArrayFactory::create('c', {bS, time, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722,0.86427295, 0.86427295, 0.86427295,0.8599919 , 0.8599919 , 0.8599919 ,0.80609463, 0.80609463, 0.80609463,0.61814662, 0.61814662, 0.61814662, - 0.91888753, 0.91888753, 0.91888753,0.92652672, 0.92652672, 0.92652672,0.92939674, 0.92939674, 0.92939674,0.90661931, 0.90661931, 0.90661931,0.74516764, 0.74516764, 0.74516764, - 0.95254269, 0.95254269, 0.95254269,0.95710717, 0.95710717, 0.95710717,0.96021584, 0.96021584, 0.96021584,0.95222547, 0.95222547, 0.95222547,0.83426363, 0.83426363, 0.83426363, - 0.97154357, 0.97154357, 0.97154357,0.97424915, 0.97424915, 0.97424915,0.97644817, 0.97644817, 0.97644817,0.97410547, 0.97410547, 0.97410547,0.89409962, 0.89409962, 0.89409962}); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.89948899, 0.89948899, 0.89948899, 0.94544483, 0.94544483, 0.94544483, 0.96797541, 0.96797541, 0.96797541, 0.9807326 , 0.9807326 , 0.9807326 }); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722, 0.91888753, 0.91888753, 0.91888753, 0.95254269, 0.95254269, 0.95254269, 0.97154357, 0.97154357, 0.97154357}); - - sd::ops::dynamic_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto hFW = results.at(0); - auto hBW = results.at(1); - auto hFWfinal = results.at(2); - auto hBWfinal = results.at(3); - - ASSERT_TRUE(expHFW.isSameShape(hFW)); - ASSERT_TRUE(expHFW.equalsTo(hFW)); - ASSERT_TRUE(expHBW.isSameShape(hBW)); - ASSERT_TRUE(expHBW.equalsTo(hBW)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); + auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); + + x.linspace(0.01, 0.01); + h0FW = 0.2; + h0BW = 0.25; + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expHFW = NDArrayFactory::create( + 'c', {bS, time, numUnitsFW}, + {0.43819931, 0.43819931, 0.43819931, 0.66617761, 0.66617761, 0.66617761, + 0.80944357, 0.80944357, 0.80944357, 0.87294706, 0.87294706, 0.87294706, + 0.89948899, 0.89948899, 0.89948899, 0.61067683, 0.61067683, 0.61067683, + 0.84851124, 0.84851124, 0.84851124, 0.91925737, 0.91925737, 0.91925737, + 0.93751395, 0.93751395, 0.93751395, 0.94544483, 0.94544483, 0.94544483, + 0.73978305, 0.73978305, 0.73978305, 0.92827068, 0.92827068, 0.92827068, + 0.95791111, 0.95791111, 0.95791111, 0.96427356, 0.96427356, 0.96427356, + 0.96797541, 0.96797541, 0.96797541, 0.83057887, 0.83057887, 0.83057887, + 0.96365083, 0.96365083, 0.96365083, 0.97585698, 0.97585698, 0.97585698, + 0.97866981, 0.97866981, 0.97866981, 0.9807326, 0.9807326, 0.9807326}); + + auto expHBW = NDArrayFactory::create( + 'c', {bS, time, numUnitsBW}, + {0.85301722, 0.85301722, 0.85301722, 0.86427295, 0.86427295, 0.86427295, + 0.8599919, 0.8599919, 0.8599919, 0.80609463, 0.80609463, 0.80609463, + 0.61814662, 0.61814662, 0.61814662, 0.91888753, 0.91888753, 0.91888753, + 0.92652672, 0.92652672, 0.92652672, 0.92939674, 0.92939674, 0.92939674, + 0.90661931, 0.90661931, 0.90661931, 0.74516764, 0.74516764, 0.74516764, + 0.95254269, 0.95254269, 0.95254269, 0.95710717, 0.95710717, 0.95710717, + 0.96021584, 0.96021584, 0.96021584, 0.95222547, 0.95222547, 0.95222547, + 0.83426363, 0.83426363, 0.83426363, 0.97154357, 0.97154357, 0.97154357, + 0.97424915, 0.97424915, 0.97424915, 0.97644817, 0.97644817, 0.97644817, + 0.97410547, 0.97410547, 0.97410547, 0.89409962, 0.89409962, 0.89409962}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.89948899, 0.89948899, 0.89948899, 0.94544483, 0.94544483, 0.94544483, + 0.96797541, 0.96797541, 0.96797541, 0.9807326, 0.9807326, 0.9807326}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.85301722, 0.85301722, 0.85301722, 0.91888753, 0.91888753, 0.91888753, + 0.95254269, 0.95254269, 0.95254269, 0.97154357, 0.97154357, 0.97154357}); + + sd::ops::dynamic_bidirectional_rnn op; + auto results = op.evaluate( + {&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW, &h0FW, &h0BW}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto hFW = results.at(0); + auto hBW = results.at(1); + auto hFWfinal = results.at(2); + auto hBWfinal = results.at(3); + + ASSERT_TRUE(expHFW.isSameShape(hFW)); + ASSERT_TRUE(expHFW.equalsTo(hFW)); + ASSERT_TRUE(expHBW.isSameShape(hBW)); + ASSERT_TRUE(expHBW.equalsTo(hBW)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) { - - const int bS = 4; - const int inSize = 4; - const int numUnitsFW = 3; - const int numUnitsBW = 3; - const int time = 5; - - auto x = NDArrayFactory::create('c', {bS, time, inSize}); - auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); - auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); - auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); - - x.linspace(0.01, 0.01); - WxFW = 0.3; - WhFW = 0.4; - bFW = 0.1; - - auto expHFW = NDArrayFactory::create('c', {bS, time, numUnitsFW}, {0.22602835, 0.22602835, 0.22602835,0.49994591, 0.49994591, 0.49994591,0.72869307, 0.72869307, 0.72869307,0.84784327, 0.84784327, 0.84784327,0.89357928, 0.89357928, 0.89357928, - 0.43819931, 0.43819931, 0.43819931,0.7793996 , 0.7793996 , 0.7793996 ,0.9053792 , 0.9053792 , 0.9053792 ,0.93546593, 0.93546593, 0.93546593,0.94518339, 0.94518339, 0.94518339, - 0.61067683, 0.61067683, 0.61067683,0.90347408, 0.90347408, 0.90347408,0.95538786, 0.95538786, 0.95538786,0.96406045, 0.96406045, 0.96406045,0.96795929, 0.96795929, 0.96795929, - 0.73978305, 0.73978305, 0.73978305,0.95499984, 0.95499984, 0.95499984,0.97535671, 0.97535671, 0.97535671,0.97864446, 0.97864446, 0.97864446,0.98073144, 0.98073144, 0.98073144}); - - auto expHBW = NDArrayFactory::create('c', {bS, time, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345,0.85160683, 0.85160683, 0.85160683,0.81997657, 0.81997657, 0.81997657,0.69228829, 0.69228829, 0.69228829,0.39861399, 0.39861399, 0.39861399, - 0.91865453, 0.91865453, 0.91865453,0.92528094, 0.92528094, 0.92528094,0.92212167, 0.92212167, 0.92212167,0.86418213, 0.86418213, 0.86418213,0.57969286, 0.57969286, 0.57969286, - 0.95252666, 0.95252666, 0.95252666,0.95696305, 0.95696305, 0.95696305,0.95878749, 0.95878749, 0.95878749,0.93722463, 0.93722463, 0.93722463,0.71727031, 0.71727031, 0.71727031, - 0.97154234, 0.97154234, 0.97154234,0.97423089, 0.97423089, 0.97423089,0.976149 , 0.976149 , 0.976149 ,0.96878298, 0.96878298, 0.96878298,0.81508646, 0.81508646, 0.81508646}); - - auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.89357928, 0.89357928, 0.89357928, 0.94518339, 0.94518339, 0.94518339, 0.96795929, 0.96795929, 0.96795929, 0.98073144, 0.98073144, 0.98073144}); - auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345, 0.91865453, 0.91865453, 0.91865453, 0.95252666, 0.95252666, 0.95252666, 0.97154234, 0.97154234, 0.97154234}); - - sd::ops::dynamic_bidirectional_rnn op; - auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto hFW = results.at(0); - auto hBW = results.at(1); - auto hFWfinal = results.at(2); - auto hBWfinal = results.at(3); - - ASSERT_TRUE(expHFW.isSameShape(hFW)); - ASSERT_TRUE(expHFW.equalsTo(hFW)); - ASSERT_TRUE(expHBW.isSameShape(hBW)); - ASSERT_TRUE(expHBW.equalsTo(hBW)); - ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); - ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); - ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); - ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - - + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2 * numUnitsFW}); + + x.linspace(0.01, 0.01); + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expHFW = NDArrayFactory::create( + 'c', {bS, time, numUnitsFW}, + {0.22602835, 0.22602835, 0.22602835, 0.49994591, 0.49994591, 0.49994591, + 0.72869307, 0.72869307, 0.72869307, 0.84784327, 0.84784327, 0.84784327, + 0.89357928, 0.89357928, 0.89357928, 0.43819931, 0.43819931, 0.43819931, + 0.7793996, 0.7793996, 0.7793996, 0.9053792, 0.9053792, 0.9053792, + 0.93546593, 0.93546593, 0.93546593, 0.94518339, 0.94518339, 0.94518339, + 0.61067683, 0.61067683, 0.61067683, 0.90347408, 0.90347408, 0.90347408, + 0.95538786, 0.95538786, 0.95538786, 0.96406045, 0.96406045, 0.96406045, + 0.96795929, 0.96795929, 0.96795929, 0.73978305, 0.73978305, 0.73978305, + 0.95499984, 0.95499984, 0.95499984, 0.97535671, 0.97535671, 0.97535671, + 0.97864446, 0.97864446, 0.97864446, 0.98073144, 0.98073144, 0.98073144}); + + auto expHBW = NDArrayFactory::create( + 'c', {bS, time, numUnitsBW}, + {0.84882345, 0.84882345, 0.84882345, 0.85160683, 0.85160683, 0.85160683, + 0.81997657, 0.81997657, 0.81997657, 0.69228829, 0.69228829, 0.69228829, + 0.39861399, 0.39861399, 0.39861399, 0.91865453, 0.91865453, 0.91865453, + 0.92528094, 0.92528094, 0.92528094, 0.92212167, 0.92212167, 0.92212167, + 0.86418213, 0.86418213, 0.86418213, 0.57969286, 0.57969286, 0.57969286, + 0.95252666, 0.95252666, 0.95252666, 0.95696305, 0.95696305, 0.95696305, + 0.95878749, 0.95878749, 0.95878749, 0.93722463, 0.93722463, 0.93722463, + 0.71727031, 0.71727031, 0.71727031, 0.97154234, 0.97154234, 0.97154234, + 0.97423089, 0.97423089, 0.97423089, 0.976149, 0.976149, 0.976149, + 0.96878298, 0.96878298, 0.96878298, 0.81508646, 0.81508646, 0.81508646}); + + auto expHFWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsFW}, + {0.89357928, 0.89357928, 0.89357928, 0.94518339, 0.94518339, 0.94518339, + 0.96795929, 0.96795929, 0.96795929, 0.98073144, 0.98073144, 0.98073144}); + auto expHBWfinal = NDArrayFactory::create( + 'c', {bS, numUnitsBW}, + {0.84882345, 0.84882345, 0.84882345, 0.91865453, 0.91865453, 0.91865453, + 0.95252666, 0.95252666, 0.95252666, 0.97154234, 0.97154234, 0.97154234}); + + sd::ops::dynamic_bidirectional_rnn op; + auto results = + op.evaluate({&x, &WxFW, &WhFW, &bFW, &WxFW, &WhFW, &bFW}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto hFW = results.at(0); + auto hBW = results.at(1); + auto hFWfinal = results.at(2); + auto hBWfinal = results.at(3); + + ASSERT_TRUE(expHFW.isSameShape(hFW)); + ASSERT_TRUE(expHFW.equalsTo(hFW)); + ASSERT_TRUE(expHBW.isSameShape(hBW)); + ASSERT_TRUE(expHBW.equalsTo(hBW)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); } - TEST_F(DeclarableOpsTests6, Test_Diag_119_1) { - auto x = NDArrayFactory::create('c', {3}, {0.15f, 0.25f, 0.35f}); - auto e = NDArrayFactory::create('c', {3, 3}, {0.15f, 0.0f, 0.0f, 0.0f, 0.25f, 0.0f, 0.0f, 0.0f, 0.35f}); + auto x = NDArrayFactory::create('c', {3}, {0.15f, 0.25f, 0.35f}); + auto e = NDArrayFactory::create( + 'c', {3, 3}, {0.15f, 0.0f, 0.0f, 0.0f, 0.25f, 0.0f, 0.0f, 0.0f, 0.35f}); - sd::ops::diag op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_EQ(e, result.at(0)); + sd::ops::diag op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests6, Test_Diag_119_2) { - auto x = NDArrayFactory::create('c', {1}, {0.15f}); - auto e = NDArrayFactory::create('c', {1, 1}, {0.15f}); - - sd::ops::diag op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {1}, {0.15f}); + auto e = NDArrayFactory::create('c', {1, 1}, {0.15f}); - ASSERT_EQ(e, result.at(0)); + sd::ops::diag op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_EQ(e, result.at(0)); } TEST_F(DeclarableOpsTests6, Test_Diag_119_3) { - auto x = NDArrayFactory::create(0.15f); - auto e = NDArrayFactory::create('c', {1, 1}, {0.15f}); + auto x = NDArrayFactory::create(0.15f); + auto e = NDArrayFactory::create('c', {1, 1}, {0.15f}); - sd::ops::diag op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_EQ(e, result.at(0)); + sd::ops::diag op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_EQ(e, result.at(0)); } - - diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index 778c20e22f22..d0640091f180 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -18,6975 +18,7333 @@ // Created by raver119 on 09.02.18. // - -#include "testlayers.h" -#include -#include #include #include #include +#include +#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; - class DeclarableOpsTests7 : public testing::Test { -public: - - DeclarableOpsTests7() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests7() { + printf("\n"); + fflush(stdout); + } }; template class TypedDeclarableOpsTests7 : public testing::Test { -public: - - TypedDeclarableOpsTests7() { - printf("\n"); - fflush(stdout); - } + public: + TypedDeclarableOpsTests7() { + printf("\n"); + fflush(stdout); + } }; typedef ::testing::Types TestingTypes; TYPED_TEST_CASE(TypedDeclarableOpsTests7, TestingTypes); TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LARGE) { - double inputData[150] = { - 0, 0.51, 0.68, 0.69, 0.86, 0.91, 0.96, 0.97, 0.97, 1.03, 1.13, 1.16, 1.16, 1.17, 1.19, 1.25, 1.25, 1.26, 1.27, 1.28, 1.29, 1.29, 1.29, 1.30, 1.31, 1.32, 1.33, 1.33, 1.35, 1.35, 1.36, 1.37, 1.38, 1.40, 1.41, 1.42, 1.43, 1.44, 1.44, 1.45, 1.45, 1.47, 1.47, 1.51, 1.51, 1.51, 1.52, 1.53, 1.56, 1.57, 1.58, 1.59, 1.61, 1.62, 1.63, 1.63, 1.64, 1.64, 1.66, 1.66, 1.67, 1.67, 1.70, 1.70, 1.70, 1.72, 1.72, 1.72, 1.72, 1.73, 1.74, 1.74, 1.76, 1.76, 1.77, 1.77, 1.80, 1.80, 1.81, 1.82, 1.83, 1.83, 1.84, 1.84, 1.84, 1.85, 1.85, 1.85, 1.86, 1.86, 1.87, 1.88, 1.89, 1.89, 1.89, 1.89, 1.89, 1.91, 1.91, 1.91, 1.92, 1.94, 1.95, 1.97, 1.98, 1.98, 1.98, 1.98, 1.98, 1.99, 2, 2, 2.01, 2.01, 2.02, 2.03, 2.03, 2.03, 2.04, 2.04, 2.05, 2.06, 2.07, 2.08, 2.08, 2.08, 2.08, 2.09, 2.09, 2.10, 2.10, 2.11, 2.11, 2.11, 2.12, 2.12, 2.13, 2.13, 2.14, 2.14, 2.14, 2.14, 2.15, 2.15, 2.16, 2.16, 2.16, 2.16, 2.16, 2.17 - }; - - auto x = NDArrayFactory::create(inputData,'c',{1,149}); - sd::ops::choose op; - //greater than test - auto result = op.evaluate({&x}, {0.0},{3}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(1); - - ASSERT_EQ(148,z.e(0)); - //ASSERT_TRUE(exp.isSameShape(z)); - - - + double inputData[150] = { + 0, 0.51, 0.68, 0.69, 0.86, 0.91, 0.96, 0.97, 0.97, 1.03, 1.13, 1.16, + 1.16, 1.17, 1.19, 1.25, 1.25, 1.26, 1.27, 1.28, 1.29, 1.29, 1.29, 1.30, + 1.31, 1.32, 1.33, 1.33, 1.35, 1.35, 1.36, 1.37, 1.38, 1.40, 1.41, 1.42, + 1.43, 1.44, 1.44, 1.45, 1.45, 1.47, 1.47, 1.51, 1.51, 1.51, 1.52, 1.53, + 1.56, 1.57, 1.58, 1.59, 1.61, 1.62, 1.63, 1.63, 1.64, 1.64, 1.66, 1.66, + 1.67, 1.67, 1.70, 1.70, 1.70, 1.72, 1.72, 1.72, 1.72, 1.73, 1.74, 1.74, + 1.76, 1.76, 1.77, 1.77, 1.80, 1.80, 1.81, 1.82, 1.83, 1.83, 1.84, 1.84, + 1.84, 1.85, 1.85, 1.85, 1.86, 1.86, 1.87, 1.88, 1.89, 1.89, 1.89, 1.89, + 1.89, 1.91, 1.91, 1.91, 1.92, 1.94, 1.95, 1.97, 1.98, 1.98, 1.98, 1.98, + 1.98, 1.99, 2, 2, 2.01, 2.01, 2.02, 2.03, 2.03, 2.03, 2.04, 2.04, + 2.05, 2.06, 2.07, 2.08, 2.08, 2.08, 2.08, 2.09, 2.09, 2.10, 2.10, 2.11, + 2.11, 2.11, 2.12, 2.12, 2.13, 2.13, 2.14, 2.14, 2.14, 2.14, 2.15, 2.15, + 2.16, 2.16, 2.16, 2.16, 2.16, 2.17}; + + auto x = NDArrayFactory::create(inputData, 'c', {1, 149}); + sd::ops::choose op; + // greater than test + auto result = op.evaluate({&x}, {0.0}, {3}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(1); + + ASSERT_EQ(148, z.e(0)); + // ASSERT_TRUE(exp.isSameShape(z)); } TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_ZERO) { - std::vector data; - for(Nd4jLong i = 0; i < 4; i++) { - data.push_back(i); - } + std::vector data; + for (Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + } + auto x = NDArrayFactory::create('c', {1, 4}, data); + sd::ops::choose op; + // greater than test + auto result = op.evaluate({&x}, {0.0}, {3}); + ASSERT_EQ(Status::OK(), result.status()); - - auto x = NDArrayFactory::create('c',{1,4},data); - sd::ops::choose op; - //greater than test - auto result = op.evaluate({&x}, {0.0},{3}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(1); - ASSERT_EQ(3, z.e(0)); - //ASSERT_TRUE(exp.isSameShape(z)); - - - + auto z = result.at(1); + ASSERT_EQ(3, z.e(0)); + // ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR) { - std::vector data; - for(Nd4jLong i = 0; i < 4; i++) { - data.push_back(i); - } - - + std::vector data; + for (Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + } - auto x = NDArrayFactory::create('c',{1,4},data); - auto scalar = NDArrayFactory::create('c',{1,1},{0.0}); - sd::ops::choose op; - //greater than test - auto result = op.evaluate({&x,&scalar}, {1.0},{3}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(3, z.lengthOf()); - //ASSERT_TRUE(exp.isSameShape(z)); - - + auto x = NDArrayFactory::create('c', {1, 4}, data); + auto scalar = NDArrayFactory::create('c', {1, 1}, {0.0}); + sd::ops::choose op; + // greater than test + auto result = op.evaluate({&x, &scalar}, {1.0}, {3}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(3, z.lengthOf()); + // ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LEFT) { - std::vector data; - for(Nd4jLong i = 0; i < 4; i++) { - data.push_back(i); - } - - + std::vector data; + for (Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + } - auto x = NDArrayFactory::create('c',{1,4},data); - auto scalar = NDArrayFactory::create('c',{1,1},{0.0}); - sd::ops::choose op; - //greater than test - auto result = op.evaluate({&scalar,&x}, {1.0},{3}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(3,z.lengthOf()); - //ASSERT_TRUE(exp.isSameShape(z)); - - + auto x = NDArrayFactory::create('c', {1, 4}, data); + auto scalar = NDArrayFactory::create('c', {1, 1}, {0.0}); + sd::ops::choose op; + // greater than test + auto result = op.evaluate({&scalar, &x}, {1.0}, {3}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(3, z.lengthOf()); + // ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR) { - std::vector data; - for(Nd4jLong i = 0; i < 4; i++) { - data.push_back(i); - } - - + std::vector data; + for (Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + } - auto x = NDArrayFactory::create('c',{1,4},data); - sd::ops::choose op; - //greater than test - auto result = op.evaluate({&x}, {1.0},{3}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_EQ(2,z.lengthOf()); - //ASSERT_TRUE(exp.isSameShape(z)); - - + auto x = NDArrayFactory::create('c', {1, 4}, data); + sd::ops::choose op; + // greater than test + auto result = op.evaluate({&x}, {1.0}, {3}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(2, z.lengthOf()); + // ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR_GTE) { - std::vector data; - for(Nd4jLong i = 0; i < 4; i++) { - data.push_back(i); - } - - + std::vector data; + for (Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + } - auto x = NDArrayFactory::create('c',{1,4},data); - sd::ops::choose op; - //greater than test - auto result = op.evaluate({&x}, {1.0},{5}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {1, 4}, data); + sd::ops::choose op; + // greater than test + auto result = op.evaluate({&x}, {1.0}, {5}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - ASSERT_EQ(3,z.lengthOf()); - //ASSERT_TRUE(exp.isSameShape(z)); + auto z = result.at(0); + ASSERT_EQ(3, z.lengthOf()); + // ASSERT_TRUE(exp.isSameShape(z)); } - TEST_F(DeclarableOpsTests7, TEST_WHERE) { - std::vector data; - std::vector mask; - std::vector put; - std::vector resultData; - std::vector assertion; - for(Nd4jLong i = 0; i < 4; i++) { - data.push_back(i); - if(i > 1) { - assertion.push_back(5.0); - mask.push_back(true); - } - else { - assertion.push_back(i); - mask.push_back(false); - } - - put.push_back(5.0); - resultData.push_back(0.0); + std::vector data; + std::vector mask; + std::vector put; + std::vector resultData; + std::vector assertion; + for (Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + if (i > 1) { + assertion.push_back(5.0); + mask.push_back(true); + } else { + assertion.push_back(i); + mask.push_back(false); } - - - - auto x = NDArrayFactory::create('c',{1,4},data); - auto maskArr = NDArrayFactory::create('c',{1,4},mask); - auto putArr = NDArrayFactory::create('c',{1,4},put); - auto resultArr = NDArrayFactory::create('c',{1,4},resultData); - sd::ops::where_np op; - //greater than test - // Nd4jStatus execute(std::initializer_list*> inputs, std::initializer_list*> outputs , std::initializer_list tArgs, std::initializer_list iArgs, bool isInplace = false); - - auto result = op.execute({&maskArr,&x,&putArr},{&resultArr}, {},{3}, {}, {}, false); - ASSERT_EQ(Status::OK(), result); - for(int i = 0; i < 4; i++) - ASSERT_EQ(assertion[i],resultArr.e(i)); - // auto z = result.at(0); - //ASSERT_EQ(4,z->lengthOf()); - //ASSERT_TRUE(exp.isSameShape(z)); - - + put.push_back(5.0); + resultData.push_back(0.0); + } + + auto x = NDArrayFactory::create('c', {1, 4}, data); + auto maskArr = NDArrayFactory::create('c', {1, 4}, mask); + auto putArr = NDArrayFactory::create('c', {1, 4}, put); + auto resultArr = NDArrayFactory::create('c', {1, 4}, resultData); + sd::ops::where_np op; + // greater than test + // Nd4jStatus execute(std::initializer_list*> inputs, + // std::initializer_list*> outputs , + // std::initializer_list tArgs, std::initializer_list + // iArgs, bool isInplace = false); + + auto result = + op.execute({&maskArr, &x, &putArr}, {&resultArr}, {}, {3}, {}, {}, false); + ASSERT_EQ(Status::OK(), result); + for (int i = 0; i < 4; i++) ASSERT_EQ(assertion[i], resultArr.e(i)); + // auto z = result.at(0); + // ASSERT_EQ(4,z->lengthOf()); + // ASSERT_TRUE(exp.isSameShape(z)); } TEST_F(DeclarableOpsTests7, TEST_WHERE_MASK) { - double x[300] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0}; - double z[300] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0}; - bool mask[300] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - double put[200] = {0.99666107,0.9867112,0.97686064,0.9671082,0.95745337,0.9478948,0.9384318,0.92906314,0.9197881,0.91060543,0.9015147,0.8925147,0.8836044,0.8747831,0.86605,0.85740393,0.8488442,0.84037,0.83198035,0.8236745,0.8154515,0.8073106,0.79925096,0.79127187,0.7833724,0.77555174,0.76780915,0.7601439,0.75255525,0.7450422,0.7376043,0.73024046,0.72295034,0.715733,0.7085876,0.7015135,0.69451016,0.68757665,0.6807124,0.6739167,0.66718876,0.66052806,0.6539338,0.6474054,0.6409421,0.6345435,0.6282087,0.6219371,0.6157281,0.60958105,0.6034956,0.59747064,0.5915059,0.5856007,0.57975453,0.5739667,0.5682366,0.5625637,0.5569475,0.5513874,0.54588276,0.540433,0.53503764,0.5296962,0.52440816,0.51917285,0.5139898,0.5088585,0.50377846,0.4987491,0.4937699,0.48884052,0.48396033,0.47912875,0.47434545,0.4696099,0.46492168,0.46028027,0.45568514,0.4511359,0.44663212,0.4421733,0.43775895,0.43338865,0.42906195,0.42477852,0.4205379,0.41633952,0.41218308,0.40806815,0.40399432,0.3999611,0.3959682,0.39201516,0.38810158,0.384227,0.38039115,0.37659356,0.37283397,0.3691119,0.36542687,0.36177874,0.35816705,0.3545914,0.35105142,0.34754673,0.34407702,0.34064204,0.33724132,0.3338745,0.33054137,0.3272415,0.32397458,0.32074028,0.3175382,0.31436813,0.31122974,0.3081226,0.30504647,0.30200112,0.2989862,0.29600134,0.29304633,0.2901207,0.28722438,0.28435695,0.2815181,0.27870762,0.27592525,0.27317056,0.27044344,0.26774356,0.26507056,0.2624243,0.25980446,0.25721073,0.25464293,0.25210077,0.249584,0.24709237,0.24462552,0.24218333,0.23976555,0.23737194,0.23500215,0.23265606,0.23033342,0.22803394,0.22575743,0.2235036,0.22127232,0.21906327,0.21687631,0.21471114,0.21256764,0.21044552,0.20834461,0.20626466,0.20420544,0.20216681,0.20014854,0.19815037,0.19617215,0.19421372,0.19227484,0.19035533,0.18845497,0.18657354,0.18471093,0.18286693,0.18104129,0.17923392,0.17744459,0.17567308,0.1739193,0.17218304,0.17046405,0.16876228,0.16707748,0.16540948,0.16375816,0.16212334,0.16050482,0.15890247,0.15731607,0.15574552,0.15419069,0.15265137,0.15112738,0.14961864,0.14812498,0.14664622,0.1451822,0.14373279,0.14229788,0.14087726,0.13947085,0.13807845,0.13669999,0.13533528}; - double assertion[300] = {1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,9.966611049434810354e-01,9.867111603284486332e-01,9.768605487739230320e-01,9.671082786103732953e-01,9.574533680683808834e-01,9.478948451798039354e-01,9.384317476799283186e-01,9.290631229105962285e-01,9.197880277243004610e-01,9.106055283892373620e-01,9.015147004953073528e-01,8.925146288610534828e-01,8.836044074415293492e-01,8.747831392370875037e-01,8.660499362030764647e-01,8.574039191604412302e-01,8.488442177072155204e-01,8.403699701308978698e-01,8.319803233217017979e-01,8.236744326866727306e-01,8.154514620646623468e-01,8.073105836421510251e-01,7.992509778699116163e-01,7.912718333805045523e-01,7.833723469065965173e-01,7.755517232000953554e-01,7.678091749520912224e-01,7.601439227135980969e-01,7.525551948170853267e-01,7.450422272987937689e-01,7.376042638218265335e-01,7.302405556000080011e-01,7.229503613225031211e-01,7.157329470791886639e-01,7.085875862867698771e-01,7.015135596156351072e-01,6.945101549174396149e-01,6.875766671534137009e-01,6.807123983233853703e-01,6.739166573955123196e-01,6.671887602367149173e-01,6.605280295438040739e-01,6.539337947752965619e-01,6.474053920839111242e-01,6.409421642497381555e-01,6.345434606140767375e-01,6.282086370139332576e-01,6.219370557171712832e-01,6.157280853583116942e-01,6.095811008749726367e-01,6.034954834449430816e-01,5.974706204238864338e-01,5.915059052836644238e-01,5.856007375512777280e-01,5.797545227484157682e-01,5.739666723316099173e-01,5.682366036329845604e-01,5.625637398015992385e-01,5.569475097453767676e-01,5.513873480736106725e-01,5.458826950400470501e-01,5.404329964865340896e-01,5.350377037872348085e-01,5.296962737933965659e-01,5.244081687786711354e-01,5.191728563849821176e-01,5.139898095689314772e-01,5.088585065487419845e-01,5.037784307517284565e-01,4.987490707622945774e-01,4.937699202704479151e-01,4.888404780208293054e-01,4.839602477622509946e-01,4.791287381977387683e-01,4.743454629350723484e-01,4.696099404378203390e-01,4.649216939768630041e-01,4.602802515824001017e-01,4.556851459964368911e-01,4.511359146257447605e-01,4.466320994952920342e-01,4.421732472021388527e-01,4.377589088697927955e-01,4.333886401030203062e-01,4.290620009431086457e-01,4.247785558235752101e-01,4.205378735263185508e-01,4.163395271382073215e-01,4.121830940081024908e-01,4.080681557043087104e-01,4.039942979724505667e-01,3.999611106937689398e-01,3.959681878438343627e-01,3.920151274516718853e-01,3.881015315592946102e-01,3.842270061816405180e-01,3.803911612669100828e-01,3.765936106572991271e-01,3.728339720501240850e-01,3.691118669593352886e-01,3.654269206774144463e-01,3.617787622376523182e-01,3.581670243768036999e-01,3.545913434981138868e-01,3.510513596347161203e-01,3.475467164133922426e-01,3.440770610186974499e-01,3.406420441574410929e-01,3.372413200235238606e-01,3.338745462631242389e-01,3.305413839402346898e-01,3.272414975025391692e-01,3.239745547476344245e-01,3.207402267895853032e-01,3.175381880258169032e-01,3.143681161043347383e-01,3.112296918912743071e-01,3.081225994387726264e-01,3.050465259531625062e-01,3.020011617634821843e-01,2.989862002903017069e-01,2.960013380148582840e-01,2.930462744485015647e-01,2.901207121024425017e-01,2.872243564578055852e-01,2.843569159359789489e-01,2.815181018692606840e-01,2.787076284717992514e-01,2.759252128108221624e-01,2.731705747781537075e-01,2.704434370620155681e-01,2.677435251191103149e-01,2.650705671469821278e-01,2.624242940566549609e-01,2.598044394455423789e-01,2.572107395706292876e-01,2.546429333219200064e-01,2.521007621961529055e-01,2.495839702707757235e-01,2.470923041781825646e-01,2.446255130802063582e-01,2.421833486428674187e-01,2.397655650113727777e-01,2.373719187853666479e-01,2.350021689944260528e-01,2.326560770738031469e-01,2.303334068404078172e-01,2.280339244690317291e-01,2.257573984688081292e-01,2.235035996599082919e-01,2.212723011504689752e-01,2.190632783137518302e-01,2.168763087655291855e-01,2.147111723416972873e-01,2.125676510761114746e-01,2.104455291786438698e-01,2.083445930134591173e-01,2.062646310775079761e-01,2.042054339792348794e-01,2.021667944174980747e-01,2.001485071607009836e-01,1.981503690261307848e-01,1.961721788595043592e-01,1.942137375147174327e-01,1.922748478337968081e-01,1.903553146270518526e-01,1.884549446534251604e-01,1.865735466010380594e-01,1.847109310679319050e-01,1.828669105430000552e-01,1.810412993871116094e-01,1.792339138144224131e-01,1.774445718738737465e-01,1.756730934308744496e-01,1.739193001491673995e-01,1.721830154728755669e-01,1.704640646087285105e-01,1.687622745084652875e-01,1.670774738514141378e-01,1.654094930272448083e-01,1.637581641188943782e-01,1.621233208856623365e-01,1.605047987464754966e-01,1.589024347633189727e-01,1.573160676248336609e-01,1.557455376300762306e-01,1.541906866724424563e-01,1.526513582237501165e-01,1.511273973184814046e-01,1.496186505381822129e-01,1.481249659960175158e-01,1.466461933214808777e-01,1.451821836452561187e-01,1.437327895842310799e-01,1.422978652266598532e-01,1.408772661174743090e-01,1.394708492437411185e-01,1.380784730202649913e-01,1.366999972753347725e-01,1.353352832366127023e-01}; - Nd4jLong threeHundredShapePointer[8] = {2,1,300,1,1,0,1,99}; - Nd4jLong twoHundredShapePointer[8] = {2,1,200,1,1,0,1,99}; - sd::ops::where_np op; - ArrayOptions::setDataType(threeHundredShapePointer, sd::DataType::DOUBLE); - ArrayOptions::setDataType(twoHundredShapePointer, sd::DataType::DOUBLE); - - NDArray xArr(x,threeHundredShapePointer); - NDArray putArr(put,twoHundredShapePointer); - NDArray resultArr(z,threeHundredShapePointer); - - resultArr.assign(0.0); - ArrayOptions::setDataType(threeHundredShapePointer, sd::DataType::BOOL); - NDArray maskArr(mask,threeHundredShapePointer); - - ArrayOptions::setDataType(threeHundredShapePointer, sd::DataType::DOUBLE); - NDArray assertArr(assertion, threeHundredShapePointer); - Nd4jStatus result = op.execute({&maskArr, &xArr, &putArr},{&resultArr},{},{},{}); - ASSERT_EQ(Status::OK(),result); - ASSERT_TRUE(assertArr.isSameShape(resultArr)); - ASSERT_TRUE (assertArr.equalsTo(resultArr)); + double x[300] = { + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + double z[300] = { + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + bool mask[300] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + double put[200] = { + 0.99666107, 0.9867112, 0.97686064, 0.9671082, 0.95745337, 0.9478948, + 0.9384318, 0.92906314, 0.9197881, 0.91060543, 0.9015147, 0.8925147, + 0.8836044, 0.8747831, 0.86605, 0.85740393, 0.8488442, 0.84037, + 0.83198035, 0.8236745, 0.8154515, 0.8073106, 0.79925096, 0.79127187, + 0.7833724, 0.77555174, 0.76780915, 0.7601439, 0.75255525, 0.7450422, + 0.7376043, 0.73024046, 0.72295034, 0.715733, 0.7085876, 0.7015135, + 0.69451016, 0.68757665, 0.6807124, 0.6739167, 0.66718876, 0.66052806, + 0.6539338, 0.6474054, 0.6409421, 0.6345435, 0.6282087, 0.6219371, + 0.6157281, 0.60958105, 0.6034956, 0.59747064, 0.5915059, 0.5856007, + 0.57975453, 0.5739667, 0.5682366, 0.5625637, 0.5569475, 0.5513874, + 0.54588276, 0.540433, 0.53503764, 0.5296962, 0.52440816, 0.51917285, + 0.5139898, 0.5088585, 0.50377846, 0.4987491, 0.4937699, 0.48884052, + 0.48396033, 0.47912875, 0.47434545, 0.4696099, 0.46492168, 0.46028027, + 0.45568514, 0.4511359, 0.44663212, 0.4421733, 0.43775895, 0.43338865, + 0.42906195, 0.42477852, 0.4205379, 0.41633952, 0.41218308, 0.40806815, + 0.40399432, 0.3999611, 0.3959682, 0.39201516, 0.38810158, 0.384227, + 0.38039115, 0.37659356, 0.37283397, 0.3691119, 0.36542687, 0.36177874, + 0.35816705, 0.3545914, 0.35105142, 0.34754673, 0.34407702, 0.34064204, + 0.33724132, 0.3338745, 0.33054137, 0.3272415, 0.32397458, 0.32074028, + 0.3175382, 0.31436813, 0.31122974, 0.3081226, 0.30504647, 0.30200112, + 0.2989862, 0.29600134, 0.29304633, 0.2901207, 0.28722438, 0.28435695, + 0.2815181, 0.27870762, 0.27592525, 0.27317056, 0.27044344, 0.26774356, + 0.26507056, 0.2624243, 0.25980446, 0.25721073, 0.25464293, 0.25210077, + 0.249584, 0.24709237, 0.24462552, 0.24218333, 0.23976555, 0.23737194, + 0.23500215, 0.23265606, 0.23033342, 0.22803394, 0.22575743, 0.2235036, + 0.22127232, 0.21906327, 0.21687631, 0.21471114, 0.21256764, 0.21044552, + 0.20834461, 0.20626466, 0.20420544, 0.20216681, 0.20014854, 0.19815037, + 0.19617215, 0.19421372, 0.19227484, 0.19035533, 0.18845497, 0.18657354, + 0.18471093, 0.18286693, 0.18104129, 0.17923392, 0.17744459, 0.17567308, + 0.1739193, 0.17218304, 0.17046405, 0.16876228, 0.16707748, 0.16540948, + 0.16375816, 0.16212334, 0.16050482, 0.15890247, 0.15731607, 0.15574552, + 0.15419069, 0.15265137, 0.15112738, 0.14961864, 0.14812498, 0.14664622, + 0.1451822, 0.14373279, 0.14229788, 0.14087726, 0.13947085, 0.13807845, + 0.13669999, 0.13533528}; + double assertion[300] = {1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 9.966611049434810354e-01, 9.867111603284486332e-01, + 9.768605487739230320e-01, 9.671082786103732953e-01, + 9.574533680683808834e-01, 9.478948451798039354e-01, + 9.384317476799283186e-01, 9.290631229105962285e-01, + 9.197880277243004610e-01, 9.106055283892373620e-01, + 9.015147004953073528e-01, 8.925146288610534828e-01, + 8.836044074415293492e-01, 8.747831392370875037e-01, + 8.660499362030764647e-01, 8.574039191604412302e-01, + 8.488442177072155204e-01, 8.403699701308978698e-01, + 8.319803233217017979e-01, 8.236744326866727306e-01, + 8.154514620646623468e-01, 8.073105836421510251e-01, + 7.992509778699116163e-01, 7.912718333805045523e-01, + 7.833723469065965173e-01, 7.755517232000953554e-01, + 7.678091749520912224e-01, 7.601439227135980969e-01, + 7.525551948170853267e-01, 7.450422272987937689e-01, + 7.376042638218265335e-01, 7.302405556000080011e-01, + 7.229503613225031211e-01, 7.157329470791886639e-01, + 7.085875862867698771e-01, 7.015135596156351072e-01, + 6.945101549174396149e-01, 6.875766671534137009e-01, + 6.807123983233853703e-01, 6.739166573955123196e-01, + 6.671887602367149173e-01, 6.605280295438040739e-01, + 6.539337947752965619e-01, 6.474053920839111242e-01, + 6.409421642497381555e-01, 6.345434606140767375e-01, + 6.282086370139332576e-01, 6.219370557171712832e-01, + 6.157280853583116942e-01, 6.095811008749726367e-01, + 6.034954834449430816e-01, 5.974706204238864338e-01, + 5.915059052836644238e-01, 5.856007375512777280e-01, + 5.797545227484157682e-01, 5.739666723316099173e-01, + 5.682366036329845604e-01, 5.625637398015992385e-01, + 5.569475097453767676e-01, 5.513873480736106725e-01, + 5.458826950400470501e-01, 5.404329964865340896e-01, + 5.350377037872348085e-01, 5.296962737933965659e-01, + 5.244081687786711354e-01, 5.191728563849821176e-01, + 5.139898095689314772e-01, 5.088585065487419845e-01, + 5.037784307517284565e-01, 4.987490707622945774e-01, + 4.937699202704479151e-01, 4.888404780208293054e-01, + 4.839602477622509946e-01, 4.791287381977387683e-01, + 4.743454629350723484e-01, 4.696099404378203390e-01, + 4.649216939768630041e-01, 4.602802515824001017e-01, + 4.556851459964368911e-01, 4.511359146257447605e-01, + 4.466320994952920342e-01, 4.421732472021388527e-01, + 4.377589088697927955e-01, 4.333886401030203062e-01, + 4.290620009431086457e-01, 4.247785558235752101e-01, + 4.205378735263185508e-01, 4.163395271382073215e-01, + 4.121830940081024908e-01, 4.080681557043087104e-01, + 4.039942979724505667e-01, 3.999611106937689398e-01, + 3.959681878438343627e-01, 3.920151274516718853e-01, + 3.881015315592946102e-01, 3.842270061816405180e-01, + 3.803911612669100828e-01, 3.765936106572991271e-01, + 3.728339720501240850e-01, 3.691118669593352886e-01, + 3.654269206774144463e-01, 3.617787622376523182e-01, + 3.581670243768036999e-01, 3.545913434981138868e-01, + 3.510513596347161203e-01, 3.475467164133922426e-01, + 3.440770610186974499e-01, 3.406420441574410929e-01, + 3.372413200235238606e-01, 3.338745462631242389e-01, + 3.305413839402346898e-01, 3.272414975025391692e-01, + 3.239745547476344245e-01, 3.207402267895853032e-01, + 3.175381880258169032e-01, 3.143681161043347383e-01, + 3.112296918912743071e-01, 3.081225994387726264e-01, + 3.050465259531625062e-01, 3.020011617634821843e-01, + 2.989862002903017069e-01, 2.960013380148582840e-01, + 2.930462744485015647e-01, 2.901207121024425017e-01, + 2.872243564578055852e-01, 2.843569159359789489e-01, + 2.815181018692606840e-01, 2.787076284717992514e-01, + 2.759252128108221624e-01, 2.731705747781537075e-01, + 2.704434370620155681e-01, 2.677435251191103149e-01, + 2.650705671469821278e-01, 2.624242940566549609e-01, + 2.598044394455423789e-01, 2.572107395706292876e-01, + 2.546429333219200064e-01, 2.521007621961529055e-01, + 2.495839702707757235e-01, 2.470923041781825646e-01, + 2.446255130802063582e-01, 2.421833486428674187e-01, + 2.397655650113727777e-01, 2.373719187853666479e-01, + 2.350021689944260528e-01, 2.326560770738031469e-01, + 2.303334068404078172e-01, 2.280339244690317291e-01, + 2.257573984688081292e-01, 2.235035996599082919e-01, + 2.212723011504689752e-01, 2.190632783137518302e-01, + 2.168763087655291855e-01, 2.147111723416972873e-01, + 2.125676510761114746e-01, 2.104455291786438698e-01, + 2.083445930134591173e-01, 2.062646310775079761e-01, + 2.042054339792348794e-01, 2.021667944174980747e-01, + 2.001485071607009836e-01, 1.981503690261307848e-01, + 1.961721788595043592e-01, 1.942137375147174327e-01, + 1.922748478337968081e-01, 1.903553146270518526e-01, + 1.884549446534251604e-01, 1.865735466010380594e-01, + 1.847109310679319050e-01, 1.828669105430000552e-01, + 1.810412993871116094e-01, 1.792339138144224131e-01, + 1.774445718738737465e-01, 1.756730934308744496e-01, + 1.739193001491673995e-01, 1.721830154728755669e-01, + 1.704640646087285105e-01, 1.687622745084652875e-01, + 1.670774738514141378e-01, 1.654094930272448083e-01, + 1.637581641188943782e-01, 1.621233208856623365e-01, + 1.605047987464754966e-01, 1.589024347633189727e-01, + 1.573160676248336609e-01, 1.557455376300762306e-01, + 1.541906866724424563e-01, 1.526513582237501165e-01, + 1.511273973184814046e-01, 1.496186505381822129e-01, + 1.481249659960175158e-01, 1.466461933214808777e-01, + 1.451821836452561187e-01, 1.437327895842310799e-01, + 1.422978652266598532e-01, 1.408772661174743090e-01, + 1.394708492437411185e-01, 1.380784730202649913e-01, + 1.366999972753347725e-01, 1.353352832366127023e-01}; + Nd4jLong threeHundredShapePointer[8] = {2, 1, 300, 1, 1, 0, 1, 99}; + Nd4jLong twoHundredShapePointer[8] = {2, 1, 200, 1, 1, 0, 1, 99}; + sd::ops::where_np op; + ArrayOptions::setDataType(threeHundredShapePointer, sd::DataType::DOUBLE); + ArrayOptions::setDataType(twoHundredShapePointer, sd::DataType::DOUBLE); + + NDArray xArr(x, threeHundredShapePointer); + NDArray putArr(put, twoHundredShapePointer); + NDArray resultArr(z, threeHundredShapePointer); + + resultArr.assign(0.0); + ArrayOptions::setDataType(threeHundredShapePointer, sd::DataType::BOOL); + NDArray maskArr(mask, threeHundredShapePointer); + + ArrayOptions::setDataType(threeHundredShapePointer, sd::DataType::DOUBLE); + NDArray assertArr(assertion, threeHundredShapePointer); + Nd4jStatus result = + op.execute({&maskArr, &xArr, &putArr}, {&resultArr}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_TRUE(assertArr.isSameShape(resultArr)); + ASSERT_TRUE(assertArr.equalsTo(resultArr)); } TEST_F(DeclarableOpsTests7, TEST_WHERE_SCALAR) { - std::vector data; - std::vector mask; - std::vector put; - std::vector resultData; - std::vector assertion; - for(Nd4jLong i = 0; i < 4; i++) { - data.push_back(i); - if(i > 1) { - assertion.push_back(5.0); - mask.push_back(true); - } - else { - assertion.push_back(i); - mask.push_back(false); - } - - resultData.push_back(0.0); + std::vector data; + std::vector mask; + std::vector put; + std::vector resultData; + std::vector assertion; + for (Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + if (i > 1) { + assertion.push_back(5.0); + mask.push_back(true); + } else { + assertion.push_back(i); + mask.push_back(false); } + resultData.push_back(0.0); + } - put.push_back(5.0); - - - auto x = NDArrayFactory::create('c',{1,4},data); - auto maskArr = NDArrayFactory::create('c',{1,4},mask); - auto putArr = NDArrayFactory::create('c',{1,1},put); - auto resultArr = NDArrayFactory::create('c',{1,4},resultData); - sd::ops::where_np op; - //greater than test - // Nd4jStatus execute(std::initializer_list*> inputs, std::initializer_list*> outputs , std::initializer_list tArgs, std::initializer_list iArgs, bool isInplace = false); - - auto result = op.execute({&maskArr,&x,&putArr},{&resultArr}, {},{3}, {}, {}, false); - // ASSERT_EQ(Status::OK(), result.status()); - for(int i = 0; i < 4; i++) - ASSERT_EQ(assertion[i],resultArr.e(i)); - // auto z = result.at(0); - //ASSERT_EQ(4,z->lengthOf()); - //ASSERT_TRUE(exp.isSameShape(z)); + put.push_back(5.0); + auto x = NDArrayFactory::create('c', {1, 4}, data); + auto maskArr = NDArrayFactory::create('c', {1, 4}, mask); + auto putArr = NDArrayFactory::create('c', {1, 1}, put); + auto resultArr = NDArrayFactory::create('c', {1, 4}, resultData); + sd::ops::where_np op; + // greater than test + // Nd4jStatus execute(std::initializer_list*> inputs, + // std::initializer_list*> outputs , + // std::initializer_list tArgs, std::initializer_list + // iArgs, bool isInplace = false); + auto result = + op.execute({&maskArr, &x, &putArr}, {&resultArr}, {}, {3}, {}, {}, false); + // ASSERT_EQ(Status::OK(), result.status()); + for (int i = 0; i < 4; i++) ASSERT_EQ(assertion[i], resultArr.e(i)); + // auto z = result.at(0); + // ASSERT_EQ(4,z->lengthOf()); + // ASSERT_TRUE(exp.isSameShape(z)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestMatrixDiagPart_1) { - auto x = NDArrayFactory::create('c', {2, 4, 4}, {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0., 0., 0., 0., 4.,5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0., 0., 0., 0., 8.}); - - auto z = NDArrayFactory::create('c', {2, 4}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}); + auto x = NDArrayFactory::create( + 'c', {2, 4, 4}, + {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0., 0., 0., 0., 4., + 5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0., 0., 0., 0., 8.}); - sd::ops::matrix_diag_part op; + auto z = NDArrayFactory::create( + 'c', {2, 4}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}); - auto result = op.evaluate({&x}, {}, {}); + sd::ops::matrix_diag_part op; - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(z.equalsTo(result.at(0))); + auto result = op.evaluate({&x}, {}, {}); - + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(z.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestMatrixDiagPart_2) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0.,5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0.}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0., + 5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0.}); - auto z = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 5, 6, 7}); + auto z = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 5, 6, 7}); - sd::ops::matrix_diag_part op; + sd::ops::matrix_diag_part op; - auto result = op.evaluate({&x}, {}, {}); + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(z.equalsTo(result.at(0))); - - + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(z.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestMatrixDiag_1) { - auto z = NDArrayFactory::create('c', {2, 4, 4}, {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0., 0., 0., 0., 4.,5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0., 0., 0., 0., 8.}); - - auto x = NDArrayFactory::create('c', {2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}); + auto z = NDArrayFactory::create( + 'c', {2, 4, 4}, + {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0., 0., 0., 0., 4., + 5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0., 0., 0., 0., 8.}); - sd::ops::matrix_diag op; + auto x = + NDArrayFactory::create('c', {2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}); - auto result = op.evaluate({&x}, {}, {}); + sd::ops::matrix_diag op; - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(z.equalsTo(result.at(0))); + auto result = op.evaluate({&x}, {}, {}); - + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(z.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestMatrixDiag_2) { - auto z = NDArrayFactory::create('c', {2, 3, 3}, {1., 0., 0., 0., 2., 0., 0., 0., 3.,5., 0., 0., 0., 6., 0.,0., 0., 7.}); - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 5, 6, 7}); - - sd::ops::matrix_diag op; + auto z = NDArrayFactory::create( + 'c', {2, 3, 3}, + {1., 0., 0., 0., 2., 0., 0., 0., 3., 5., 0., 0., 0., 6., 0., 0., 0., 7.}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 5, 6, 7}); - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(z.equalsTo(result.at(0))); + sd::ops::matrix_diag op; - + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(z.equalsTo(result.at(0))); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRandomCrop_1) { - auto x = NDArrayFactory::create('c', {2, 2, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); - auto shape = NDArrayFactory::create({1, 2, 3}); - sd::ops::random_crop op; - - auto result = op.evaluate({&x, &shape}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); -// ASSERT_TRUE(z.equalsTo(result.at(0))); + auto x = + NDArrayFactory::create('c', {2, 2, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto shape = NDArrayFactory::create({1, 2, 3}); + sd::ops::random_crop op; - + auto result = op.evaluate({&x, &shape}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // ASSERT_TRUE(z.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRandomCrop_2) { - auto x = NDArrayFactory::create('c', {2, 2, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); - auto shape = NDArrayFactory::create({2, 2, 2}); - sd::ops::random_crop op; - - auto result = op.evaluate({&x, &shape}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); -// ASSERT_TRUE(z.equalsTo(result.at(0))); + auto x = + NDArrayFactory::create('c', {2, 2, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto shape = NDArrayFactory::create({2, 2, 2}); + sd::ops::random_crop op; - + auto result = op.evaluate({&x, &shape}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // ASSERT_TRUE(z.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119) { - auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); - auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); - auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); - auto data0 = NDArrayFactory::create('c', {2,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); - - auto data1 = NDArrayFactory::create('c', {2,3,5,4},{1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f, - 29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f, - 57.f, 58.f, 59.f, 60.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f,81.f, 82.f, 83.f, 84.f, - 85.f, 86.f, 87.f, 88.f,89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 96.f,97.f, 98.f, 99.f, 100.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, - 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f}); - - auto data2 = NDArrayFactory::create('c', {3,1,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, - 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f}); - - auto exp = NDArrayFactory::create('c', {11, 5, 4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, - 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, - 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f, - 37.f, 38.f, 39.f, 40.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f, - 1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,81.f, 82.f, 83.f, 84.f,85.f, 86.f, 87.f, 88.f, - 89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 96.f,97.f, 98.f, 99.f, 100.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f, - 53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); - - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); -// result.at(0)->printIndexedBuffer("Output"); -// exp.printIndexedBuffer("Expect"); -// result.at(0)->printShapeInfo("Output shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - + auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); + auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); + auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); + auto data0 = NDArrayFactory::create( + 'c', {2, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f}); + + auto data1 = NDArrayFactory::create( + 'c', {2, 3, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, + 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, + 71.f, 72.f, 73.f, 74.f, 75.f, 76.f, 77.f, 78.f, 79.f, 80.f, + 81.f, 82.f, 83.f, 84.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, + 91.f, 92.f, 93.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 100.f, + 101.f, 102.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 109.f, 110.f, + 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f}); + + auto data2 = NDArrayFactory::create( + 'c', {3, 1, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, + 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f}); + + auto exp = NDArrayFactory::create( + 'c', {11, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, + 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, + 101.f, 102.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 109.f, 110.f, + 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, + 71.f, 72.f, 73.f, 74.f, 75.f, 76.f, 77.f, 78.f, 79.f, 80.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 81.f, 82.f, 83.f, 84.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, + 91.f, 92.f, 93.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 100.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, + 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f}); + + sd::ops::dynamic_stitch op; + auto result = op.evaluate( + {&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printShapeInfo("Output shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_Prof_1) { - auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); - auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); - auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); - auto data0 = NDArrayFactory::create('c', {2,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); - - auto data1 = NDArrayFactory::create('c', {2,3,5,4},{1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f, - 29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f, - 57.f, 58.f, 59.f, 60.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f,81.f, 82.f, 83.f, 84.f, - 85.f, 86.f, 87.f, 88.f,89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 96.f,97.f, 98.f, 99.f, 100.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, - 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f}); - - auto data2 = NDArrayFactory::create('c', {3,1,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, - 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f}); - - auto exp = NDArrayFactory::create('c', {11, 5, 4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, - 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, - 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f, - 37.f, 38.f, 39.f, 40.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f, - 1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,81.f, 82.f, 83.f, 84.f,85.f, 86.f, 87.f, 88.f, - 89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 96.f,97.f, 98.f, 99.f, 100.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f, - 53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); - - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); -// result.at(0)->printIndexedBuffer("Output"); -// exp.printIndexedBuffer("Expect"); -// result.at(0)->printShapeInfo("Output shape"); - auto res = result.at(0); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - int numOfCases = 100; - auto timeStart = std::chrono::system_clock::now(); - - for (int i = 0; i < numOfCases; i++) { - op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {&res}, {}, {}, {}); - } - - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - //nd4j_printf("dynamic_stitch: Process with %i iterations was load: %lld us.\n", numOfCases, outerTime / numOfCases); - - + auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); + auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); + auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); + auto data0 = NDArrayFactory::create( + 'c', {2, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f}); + + auto data1 = NDArrayFactory::create( + 'c', {2, 3, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, + 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, + 71.f, 72.f, 73.f, 74.f, 75.f, 76.f, 77.f, 78.f, 79.f, 80.f, + 81.f, 82.f, 83.f, 84.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, + 91.f, 92.f, 93.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 100.f, + 101.f, 102.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 109.f, 110.f, + 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f}); + + auto data2 = NDArrayFactory::create( + 'c', {3, 1, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, + 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f}); + + auto exp = NDArrayFactory::create( + 'c', {11, 5, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, + 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, + 101.f, 102.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 109.f, 110.f, + 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, + 71.f, 72.f, 73.f, 74.f, 75.f, 76.f, 77.f, 78.f, 79.f, 80.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, + 81.f, 82.f, 83.f, 84.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, + 91.f, 92.f, 93.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 100.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, + 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f}); + + sd::ops::dynamic_stitch op; + auto result = op.evaluate( + {&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printShapeInfo("Output shape"); + auto res = result.at(0); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + int numOfCases = 100; + auto timeStart = std::chrono::system_clock::now(); + + for (int i = 0; i < numOfCases; i++) { + op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, + {&res}, {}, {}, {}); + } + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = + std::chrono::duration_cast(timeEnd - timeStart) + .count(); + // nd4j_printf("dynamic_stitch: Process with %i iterations was load: %lld + // us.\n", numOfCases, outerTime / numOfCases); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119_1) { - auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); - auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); - auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); - - auto data0 = NDArrayFactory::create('c', {2,5,4}); - auto data1 = NDArrayFactory::create('c', {2,3,5,4}); - auto data2 = NDArrayFactory::create('c', {3,1,5,4}); - - auto exp = NDArrayFactory::create('c', {11, 5, 4}, { - 21, 22, 23, 24, - 25, 26, 27, 28, - 29, 30, 31, 32, - 33, 34, 35, 36, - 37, 38, 39, 40, - - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - - 181, 182, 183, 184, - 185, 186, 187, 188, - 189, 190, 191, 192, - 193, 194, 195, 196, - 197, 198, 199, 200, - - 121, 122, 123, 124, - 125, 126, 127, 128, - 129, 130, 131, 132, - 133, 134, 135, 136, - 137, 138, 139, 140, - - 161, 162, 163, 164, - 165, 166, 167, 168, - 169, 170, 171, 172, - 173, 174, 175, 176, - 177, 178, 179, 180, - - 81, 82, 83, 84, - 85, 86, 87, 88, - 89, 90, 91, 92, - 93, 94, 95, 96, - 97, 98, 99, 100, - - 141, 142, 143, 144, - 145, 146, 147, 148, - 149, 150, 151, 152, - 153, 154, 155, 156, - 157, 158, 159, 160, - - 41, 42, 43, 44, - 45, 46, 47, 48, - 49, 50, 51, 52, - 53, 54, 55, 56, - 57, 58, 59, 60, - - 101, 102, 103, 104, - 105, 106, 107, 108, - 109, 110, 111, 112, - 113, 114, 115, 116, - 117, 118, 119, 120, - - 61, 62, 63, 64, - 65, 66, 67, 68, - 69, 70, 71, 72, - 73, 74, 75, 76, - 77, 78, 79, 80, - - 21, 22, 23, 24, - 25, 26, 27, 28, - 29, 30, 31, 32, - 33, 34, 35, 36, - 37, 38, 39, 40, - }); - data0.linspace(1); - data1.linspace(21); - data2.linspace(141); - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(z.isSameShape(exp)); - ASSERT_TRUE(z.equalsTo(exp)); - - + auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); + auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); + auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); + + auto data0 = NDArrayFactory::create('c', {2, 5, 4}); + auto data1 = NDArrayFactory::create('c', {2, 3, 5, 4}); + auto data2 = NDArrayFactory::create('c', {3, 1, 5, 4}); + + auto exp = NDArrayFactory::create( + 'c', {11, 5, 4}, + { + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + + 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, + + 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, + 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, + + 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, + 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, + + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, + 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, + + 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, + 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, + + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, + 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, + + 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, + 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, + + 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, + 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + }); + data0.linspace(1); + data1.linspace(21); + data2.linspace(141); + sd::ops::dynamic_stitch op; + auto result = op.evaluate( + {&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + + ASSERT_TRUE(z.isSameShape(exp)); + ASSERT_TRUE(z.equalsTo(exp)); } TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119_2) { - auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); - auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); - auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); - - auto data0 = NDArrayFactory::create('c', {2,5,4}); - auto data1 = NDArrayFactory::create('c', {2,3,5,4}); - auto data2 = NDArrayFactory::create('c', {3,1,5,4}); - - auto exp = NDArrayFactory::create('c', {11, 5, 4}, { - 41, 42, 43, 44, - 45, 46, 47, 48, - 49, 50, 51, 52, - 53, 54, 55, 56, - 57, 58, 59, 60, - - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16, - 17, 18, 19, 20, - - 201, 202, 203, 204, - 205, 206, 207, 208, - 209, 210, 211, 212, - 213, 214, 215, 216, - 217, 218, 219, 220, - - 141, 142, 143, 144, - 145, 146, 147, 148, - 149, 150, 151, 152, - 153, 154, 155, 156, - 157, 158, 159, 160, - - 181, 182, 183, 184, - 185, 186, 187, 188, - 189, 190, 191, 192, - 193, 194, 195, 196, - 197, 198, 199, 200, - - 101, 102, 103, 104, - 105, 106, 107, 108, - 109, 110, 111, 112, - 113, 114, 115, 116, - 117, 118, 119, 120, - - 161, 162, 163, 164, - 165, 166, 167, 168, - 169, 170, 171, 172, - 173, 174, 175, 176, - 177, 178, 179, 180, - - 61, 62, 63, 64, - 65, 66, 67, 68, - 69, 70, 71, 72, - 73, 74, 75, 76, - 77, 78, 79, 80, - - 121, 122, 123, 124, - 125, 126, 127, 128, - 129, 130, 131, 132, - 133, 134, 135, 136, - 137, 138, 139, 140, - - 81, 82, 83, 84, - 85, 86, 87, 88, - 89, 90, 91, 92, - 93, 94, 95, 96, - 97, 98, 99, 100, - - 21, 22, 23, 24, - 25, 26, 27, 28, - 29, 30, 31, 32, - 33, 34, 35, 36, - 37, 38, 39, 40, - }); - data0.linspace(1); - data1.linspace(41); - data2.linspace(161); - sd::ops::dynamic_stitch op; - auto result = op.evaluate({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - ASSERT_TRUE(z.isSameShape(exp)); - ASSERT_TRUE(z.equalsTo(exp)); - - -} + auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); + auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); + auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); -TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119) { - auto x = NDArrayFactory::create('c', {5, 4, 11}); - auto y = NDArrayFactory::create('c', {5, 4}, {0,1,2,3, 1,0,2,3, 2,3,1,0, 2,1,0,3, 0,1,2,3}); - auto e = NDArrayFactory::create('c', {5, 11}); - x.assign(1.f); - e.assign(1.f); - sd::ops::dynamic_partition op; - auto result = op.evaluate({&x, &y}, {}, {4}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(4, result.size()); - auto z = result.at(0); + auto data0 = NDArrayFactory::create('c', {2, 5, 4}); + auto data1 = NDArrayFactory::create('c', {2, 3, 5, 4}); + auto data2 = NDArrayFactory::create('c', {3, 1, 5, 4}); - ASSERT_TRUE(e.isSameShape(z)); + auto exp = NDArrayFactory::create( + 'c', {11, 5, 4}, + { + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, + 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, - -} + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, -TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_1) { - auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20,11, 21,12, 22,13, 23,14, 24,15, 25,16, 26,17, 27,18, 28,19, 29,20, 30,21, 31}); - - auto y = NDArrayFactory::create('c', {3, 4}, {0,0,0,0, 2,2,2,2, 2,1,1,1}); - auto e = NDArrayFactory::create('c', {4, 2}, {10, 20, 11, 21, 12, 22, 13, 23}); - -// x.assign(1.f); -// e.assign(1.f); - sd::ops::dynamic_partition op; - auto result = op.evaluate({&x, &y}, {}, {3}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(3, result.size()); - auto z = result.at(0); -// z->printShapeInfo("Output shape info"); -// result.at(1)->printShapeInfo("Shape2"); -// result.at(2)->printShapeInfo("Shape3"); -// result.at(3)->printShapeInfo("Shape4"); -// z->printIndexedBuffer("Output1"); -// result.at(1)->printIndexedBuffer("Output2"); -// result.at(2)->printIndexedBuffer("Output3"); -// result.at(3)->printIndexedBuffer("Output4"); - ASSERT_TRUE(e.isSameShape(z)); - - -} -TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_2) { - auto x = NDArrayFactory::create('c', {5, 4, 11}); - auto y = NDArrayFactory::create('c', {5, 4}, {0,1,2,3, 1,0,2,3, 2,3,1,0, 2,1,0,3, 0,1,2,3}); - auto e1 = NDArrayFactory::create('c', {5, 11}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, - 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, - 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, - 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187}); - auto e2 = NDArrayFactory::create('c', {5, 11}, { 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, - 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, - 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, - 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198}); - auto e3 = NDArrayFactory::create('c', {5, 11}, {23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, - 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, - 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, - 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, - 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209}); - auto e4 = NDArrayFactory::create('c', {5, 11}, { 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, - 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, - 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, - 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, - 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220}) ; - std::vector e({&e1, &e2, &e3, &e4}); - x.linspace(1.f); - //.assign(1.f); - sd::ops::dynamic_partition op; - auto result = op.evaluate({&x, &y}, {}, {4}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(4, result.size()); - for (size_t i = 0; i < result.size(); i++) { - auto z = result.at(i); -// z->printShapeInfo("Output shape info"); -// z->printIndexedBuffer("Output1"); -// result.at(1)->printIndexedBuffer("Output2"); -// result.at(2)->printIndexedBuffer("Output3"); -// result.at(3)->printIndexedBuffer("Output4"); - ASSERT_TRUE(e[i]->isSameShape(z)); - ASSERT_TRUE(e[i]->equalsTo(z)); - } + 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, - -} + 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, + 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, + 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, -TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) { - auto input = NDArrayFactory::create('c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - auto exp = NDArrayFactory::create('c', {4, 4, 16}, {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }); + 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, + 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, - sd::ops::sequence_mask op; - auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, + 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, - auto z = result.at(0); -// z->printIndexedBuffer("Output"); -// z->printShapeInfo("Shape"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, + 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, - + 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, + 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, -} + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, + 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, -TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) { - auto input = NDArrayFactory::create('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + }); + data0.linspace(1); + data1.linspace(41); + data2.linspace(161); + sd::ops::dynamic_stitch op; + auto result = op.evaluate( + {&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); - sd::ops::sequence_mask op; - auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); - auto z = result.at(0); -// z->printBuffer("Output"); -// z->printShapeInfo("Shape"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(z.isSameShape(exp)); + ASSERT_TRUE(z.equalsTo(exp)); +} - +TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119) { + auto x = NDArrayFactory::create('c', {5, 4, 11}); + auto y = NDArrayFactory::create( + 'c', {5, 4}, + {0, 1, 2, 3, 1, 0, 2, 3, 2, 3, 1, 0, 2, 1, 0, 3, 0, 1, 2, 3}); + auto e = NDArrayFactory::create('c', {5, 11}); + x.assign(1.f); + e.assign(1.f); + sd::ops::dynamic_partition op; + auto result = op.evaluate({&x, &y}, {}, {4}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(4, result.size()); + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); } -TEST_F(DeclarableOpsTests7, Test_SequenceMask_3) { - auto input = NDArrayFactory::create('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8}); - auto exp = NDArrayFactory::create('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); +TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_1) { + auto x = NDArrayFactory::create( + 'c', {3, 4, 2}, {10, 20, 11, 21, 12, 22, 13, 23, 14, 24, 15, 25, + 16, 26, 17, 27, 18, 28, 19, 29, 20, 30, 21, 31}); + + auto y = NDArrayFactory::create('c', {3, 4}, + {0, 0, 0, 0, 2, 2, 2, 2, 2, 1, 1, 1}); + auto e = NDArrayFactory::create('c', {4, 2}, + {10, 20, 11, 21, 12, 22, 13, 23}); + + // x.assign(1.f); + // e.assign(1.f); + sd::ops::dynamic_partition op; + auto result = op.evaluate({&x, &y}, {}, {3}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); + auto z = result.at(0); + // z->printShapeInfo("Output shape info"); + // result.at(1)->printShapeInfo("Shape2"); + // result.at(2)->printShapeInfo("Shape3"); + // result.at(3)->printShapeInfo("Shape4"); + // z->printIndexedBuffer("Output1"); + // result.at(1)->printIndexedBuffer("Output2"); + // result.at(2)->printIndexedBuffer("Output3"); + // result.at(3)->printIndexedBuffer("Output4"); + ASSERT_TRUE(e.isSameShape(z)); +} +TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_2) { + auto x = NDArrayFactory::create('c', {5, 4, 11}); + auto y = NDArrayFactory::create( + 'c', {5, 4}, + {0, 1, 2, 3, 1, 0, 2, 3, 2, 3, 1, 0, 2, 1, 0, 3, 0, 1, 2, 3}); + auto e1 = NDArrayFactory::create( + 'c', {5, 11}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 56, 57, 58, + 59, 60, 61, 62, 63, 64, 65, 66, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 155, 156, 157, 158, 159, 160, 161, 162, 163, + 164, 165, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187}); + auto e2 = NDArrayFactory::create( + 'c', {5, 11}, + {12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 111, 112, 113, 114, 115, 116, + 117, 118, 119, 120, 121, 144, 145, 146, 147, 148, 149, 150, 151, 152, + 153, 154, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198}); + auto e3 = NDArrayFactory::create( + 'c', {5, 11}, + {23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 67, 68, 69, + 70, 71, 72, 73, 74, 75, 76, 77, 89, 90, 91, 92, 93, 94, + 95, 96, 97, 98, 99, 133, 134, 135, 136, 137, 138, 139, 140, 141, + 142, 143, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209}); + auto e4 = NDArrayFactory::create( + 'c', {5, 11}, + {34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 166, 167, 168, 169, 170, 171, 172, 173, 174, + 175, 176, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220}); + std::vector e({&e1, &e2, &e3, &e4}); + x.linspace(1.f); + //.assign(1.f); + sd::ops::dynamic_partition op; + auto result = op.evaluate({&x, &y}, {}, {4}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(4, result.size()); + for (size_t i = 0; i < result.size(); i++) { + auto z = result.at(i); + // z->printShapeInfo("Output shape info"); + // z->printIndexedBuffer("Output1"); + // result.at(1)->printIndexedBuffer("Output2"); + // result.at(2)->printIndexedBuffer("Output3"); + // result.at(3)->printIndexedBuffer("Output4"); + ASSERT_TRUE(e[i]->isSameShape(z)); + ASSERT_TRUE(e[i]->equalsTo(z)); + } +} - sd::ops::sequence_mask op; - auto result = op.evaluate({&input}, {sd::DataType::INT32}); - ASSERT_EQ(Status::OK(), result.status()); +TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) { + auto input = NDArrayFactory::create( + 'c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto exp = NDArrayFactory::create( + 'c', {4, 4, 16}, + {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + + sd::ops::sequence_mask op; + auto result = op.evaluate({&input}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printIndexedBuffer("Output"); + // z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} - auto z = result.at(0); -// z->printBuffer("Output"); -// z->printShapeInfo("Shape"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); +TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) { + auto input = + NDArrayFactory::create('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 30}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + + sd::ops::sequence_mask op; + auto result = op.evaluate({&input}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printBuffer("Output"); + // z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} - +TEST_F(DeclarableOpsTests7, Test_SequenceMask_3) { + auto input = + NDArrayFactory::create('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 30}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + + sd::ops::sequence_mask op; + auto result = op.evaluate({&input}, {sd::DataType::INT32}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printBuffer("Output"); + // z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests7, Test_SequenceMask_4) { - auto input = NDArrayFactory::create({1, 3, 2}); - auto maxLen = NDArrayFactory::create(5); - auto exp = NDArrayFactory::create('c', {3,5}, { - 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f - }); + auto input = NDArrayFactory::create({1, 3, 2}); + auto maxLen = NDArrayFactory::create(5); + auto exp = + NDArrayFactory::create('c', {3, 5}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, + 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f}); - sd::ops::sequence_mask op; - auto result = op.evaluate({&input, &maxLen}, {sd::DataType::FLOAT32}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::sequence_mask op; + auto result = op.evaluate({&input, &maxLen}, {sd::DataType::FLOAT32}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printBuffer("Output"); -// z->printShapeInfo("Shape"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto z = result.at(0); + // z->printBuffer("Output"); + // z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(DeclarableOpsTests7, Test_SequenceMask_5) { - auto input = NDArrayFactory::create({1, 3, 2}); - auto exp = NDArrayFactory::create('c', {3,5}, { - 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f - }); - - sd::ops::sequence_mask op; - auto result = op.evaluate({&input}, {5, (int)sd::DataType::FLOAT32}); - ASSERT_EQ(Status::OK(), result.status()); + auto input = NDArrayFactory::create({1, 3, 2}); + auto exp = + NDArrayFactory::create('c', {3, 5}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, + 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f}); - auto z = result.at(0); -// z->printBuffer("Output"); -// z->printShapeInfo("Shape"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::sequence_mask op; + auto result = op.evaluate({&input}, {5, (int)sd::DataType::FLOAT32}); + ASSERT_EQ(Status::OK(), result.status()); - + auto z = result.at(0); + // z->printBuffer("Output"); + // z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({2.5, 9, 3, 9, 4.2}); - - sd::ops::segment_max op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({2.5, 9, 3, 9, 4.2}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printBuffer("MaX1"); -// exp.printBuffer("ExP1"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_max op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("MaX1"); + // exp.printBuffer("ExP1"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_01) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1., 10, 40, 30}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5,5, 5}); - auto exp = NDArrayFactory::create({2.5, 9, 3, 9, 4.2, 40}); + auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., + 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1., + 10, 40, 30}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5}); + auto exp = NDArrayFactory::create({2.5, 9, 3, 9, 4.2, 40}); - sd::ops::segment_max op; + sd::ops::segment_max op; - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printBuffer("MaX01"); -// exp.printBuffer("ExP01"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printBuffer("MaX01"); + // exp.printBuffer("ExP01"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMaxBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({0., 1., 0., 2., 0., 0., 3., 4., 0., 0.,0., 0., 0., 5., 0.,0.}); - auto eps = NDArrayFactory::create('c', {5}); - sd::ops::segment_max_bp op; - eps.linspace(1); - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("OutputMaxBP"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + {0., 1., 0., 2., 0., 0., 3., 4., 0., 0., 0., 0., 0., 5., 0., 0.}); + auto eps = NDArrayFactory::create('c', {5}); + sd::ops::segment_max_bp op; + eps.linspace(1); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("OutputMaxBP"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_2) { - auto x = NDArrayFactory::create('c', {5, 4}, { 0, 1.8, 2.5, 4., - 1, 9., 2.1, 2.4, - 0, 3., 9., 2.1, - 2, 1, 2.1, 0.7, - 3, 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1, 9, 9, 4, - 2, 1, 2.1, 0.7, - 3, 4.2, 2.2, 1.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - sd::ops::segment_max op; + auto x = NDArrayFactory::create( + 'c', {5, 4}, {0, 1.8, 2.5, 4., 1, 9., 2.1, 2.4, 0, 3., + 9., 2.1, 2, 1, 2.1, 0.7, 3, 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {1, 9, 9, 4, 2, 1, 2.1, 0.7, 3, 4.2, 2.2, 1.}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); - auto out = result.at(0); -// out->printIndexedBuffer("Output2Max"); -// exp.printIndexedBuffer("Expect2Max"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + sd::ops::segment_max op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + auto out = result.at(0); + // out->printIndexedBuffer("Output2Max"); + // exp.printIndexedBuffer("Expect2Max"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMaxBP_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto eps = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); -// NDArray exp('c', {3, 4}, {2.1, 2.5, 4, 9,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto exp = NDArrayFactory::create('c', {4, 4}, {0., 2., 3., 4., 1., 0., 0., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto eps = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + // NDArray exp('c', {3, 4}, {2.1, 2.5, 4, 9,2.1, 2.1, 0.7, + // 0.1,3., 4.2, 2.2, 1.}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {0., 2., 3., 4., 1., 0., 0., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - sd::ops::segment_max_bp op; + sd::ops::segment_max_bp op; - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 2); - //exp.printIndexedBuffer("BP Max Expect"); - //result.at(0)->printIndexedBuffer("BP Max Output"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 2); + // exp.printIndexedBuffer("BP Max Expect"); + // result.at(0)->printIndexedBuffer("BP Max Output"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, {91. , 82. , 37. , 64.,55.1, 46.4, 73. , 28.,119.1, 12.1,112.7, 13.1,14. ,114.2, 16.2,117.,51. , 42. , 87. , 44., - 55.1, 56.4, 93. , 28.,119.1, 82.1,112.7,113.1,114. ,114.2,116.2,117.,91. , 82. , 37. , 64.,55.1, 46.4, 73. , 28., 119.1, 12.1,112.7, 13.1,14. ,114.2, 16.2,117. }); + // ---------------------------------------------------------------- - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 87., 44., + 55.1, 56.4, 93., 28., 119.1, 82.1, 112.7, 113.1, 114., 114.2, + 116.2, 117., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); - sd::ops::segment_max op; + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output3Max"); -// result.at(0)->printShapeInfo("Out Shape 3 Max"); -// exp.printIndexedBuffer("Expect3Max"); -// exp.printShapeInfo("Exp Shape 3 Max"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_max op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output3Max"); + // result.at(0)->printShapeInfo("Out Shape 3 Max"); + // exp.printIndexedBuffer("Expect3Max"); + // exp.printShapeInfo("Exp Shape 3 Max"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24., - 15.1, 56.4, 93. , 28.,109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. , - 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , - 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::segment_max op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, + 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, + 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 31., 22., 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, + 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_max op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_1) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({4, 4, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 0, 0}); - auto exp = NDArrayFactory::create({2.2, 9., 3., 9., 4.2}); - - sd::ops::unsorted_segment_max op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {4, 4, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 0, 0}); + auto exp = NDArrayFactory::create({2.2, 9., 3., 9., 4.2}); - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_max op; - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMaxBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({0., 1., 0., 2., 0., 0., 3., 4., 0., 0.,0., 0., 0., 5., 0.,0.}); - auto eps = NDArrayFactory::create('c', {5}); - sd::ops::segment_max_bp op; - eps.linspace(1); - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + {0., 1., 0., 2., 0., 0., 3., 4., 0., 0., 0., 0., 0., 5., 0., 0.}); + auto eps = NDArrayFactory::create('c', {5}); + sd::ops::segment_max_bp op; + eps.linspace(1); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMaxBP_2) { - auto x = NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({3., 0., 1., 0., 2., 0., 0., 4., 0., 0.,0., 0., 0., 5., 0.,0.}); - auto eps = NDArrayFactory::create('c', {5}); - sd::ops::segment_max_bp op; - eps.linspace(1); - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = + NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + {3., 0., 1., 0., 2., 0., 0., 4., 0., 0., 0., 0., 0., 5., 0., 0.}); + auto eps = NDArrayFactory::create('c', {5}); + sd::ops::segment_max_bp op; + eps.linspace(1); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_2) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({4, 4, 1, 1, 1, 1, 3, 3, 3, 3, 4, 4, 4, 4, 0, 0}); - auto exp = NDArrayFactory::create({2.2, 9., -DataTypeUtils::max(), 9., 4.2}); - - sd::ops::unsorted_segment_max op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {4, 4, 1, 1, 1, 1, 3, 3, 3, 3, 4, 4, 4, 4, 0, 0}); + auto exp = NDArrayFactory::create( + {2.2, 9., -DataTypeUtils::max(), 9., 4.2}); - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("OutputUnsortedMax"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_max op; - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("OutputUnsortedMax"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_3) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, {2.1, 2.5, 4, 9,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {2.1, 2.5, 4, 9, 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - sd::ops::unsorted_segment_max op; + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); - //exp.printIndexedBuffer("Expect"); - //result.at(0)->printIndexedBuffer("Output"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_max op; - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printIndexedBuffer("Output"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_4) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 8., 2.1, 2.1, 11.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 0, 2}); - double principalMax = DataTypeUtils::max(); - auto exp = NDArrayFactory::create('c', {3, 4}, {2.1, 2.5, 11.7, 9, - -principalMax, -principalMax, -principalMax, -principalMax, - 3., 4.2, 2.2, 1.}); + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 8., 2.1, + 2.1, 11.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 0, 2}); + double principalMax = DataTypeUtils::max(); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {2.1, 2.5, 11.7, 9, -principalMax, -principalMax, -principalMax, + -principalMax, 3., 4.2, 2.2, 1.}); - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - sd::ops::unsorted_segment_max op; + sd::ops::unsorted_segment_max op; - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); - //exp.printIndexedBuffer("Expect"); - //result.at(0)->printIndexedBuffer("Output"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printIndexedBuffer("Output"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4, 3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); - - sd::ops::segment_min op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); + sd::ops::segment_min op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_01) { - auto x = NDArrayFactory::create({1.8, -2.5,4., -9., 2.1, 2.4,-3.,-9., 2.1, 2.1,0.7, 0.1, 3., -4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({-2.5, -9, -3., -9, -4.2}); - - sd::ops::segment_min op; + auto x = + NDArrayFactory::create({1.8, -2.5, 4., -9., 2.1, 2.4, -3., -9., + 2.1, 2.1, 0.7, 0.1, 3., -4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({-2.5, -9, -3., -9, -4.2}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); + sd::ops::segment_min op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_02) { - auto x = NDArrayFactory::create({1.8f, -2.5f, 4.f, -9.f, 2.1f, 2.4f, -3.f, -9.f, 2.1f, 2.1f,0.7f, 0.1f, 3.f, -4.2f, 2.2f, 1.f}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({-2.5f, -9.f, -3.f, -9.f, -4.2f}); + auto x = NDArrayFactory::create({1.8f, -2.5f, 4.f, -9.f, 2.1f, 2.4f, + -3.f, -9.f, 2.1f, 2.1f, 0.7f, 0.1f, + 3.f, -4.2f, 2.2f, 1.f}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({-2.5f, -9.f, -3.f, -9.f, -4.2f}); - sd::ops::segment_min op; + sd::ops::segment_min op; - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMinBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({ 1., 0., 0., 0., 2., 0., 3., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); - auto eps = NDArrayFactory::create('c', {5}); - eps.linspace(1); - sd::ops::segment_min_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + {1., 0., 0., 0., 2., 0., 3., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); + auto eps = NDArrayFactory::create('c', {5}); + eps.linspace(1); + sd::ops::segment_min_bp op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMinBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({ 1., 0., 0., 0., 2., 0., 3., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); - auto eps = NDArrayFactory::create('c', {5}); - eps.linspace(1); - sd::ops::unsorted_segment_min_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output1"); - //exp.printIndexedBuffer("Expecte"); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + {1., 0., 0., 0., 2., 0., 3., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); + auto eps = NDArrayFactory::create('c', {5}); + eps.linspace(1); + sd::ops::unsorted_segment_min_bp op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output1"); + // exp.printIndexedBuffer("Expecte"); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMinBP_2) { - auto x = NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({3., 1., 0., 0., 0., 2., 0., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); - auto eps = NDArrayFactory::create('c', {5}); - eps.linspace(1); - sd::ops::unsorted_segment_min_bp op; + auto x = + NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + {3., 1., 0., 0., 0., 2., 0., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); + auto eps = NDArrayFactory::create('c', {5}); + eps.linspace(1); + sd::ops::unsorted_segment_min_bp op; - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output1"); - //exp.printIndexedBuffer("Expecte"); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output1"); + // exp.printIndexedBuffer("Expecte"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1.8, 2.4, 3. , 9.,2.1, 2.1, 0.7, 0.1,3. , 4.2, 2.2, 1.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {1.8, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - sd::ops::segment_min op; + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_min op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMinBP_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto eps = NDArrayFactory::create('c', {3, 4}, {1., 2., 3. , 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1., 0., 0., 4., 0., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto eps = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {1., 0., 0., 4., 0., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - sd::ops::segment_min_bp op; + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 2); -// exp.printIndexedBuffer("Expect"); -// result.at(0)->printIndexedBuffer("Output"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_min_bp op; - + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 2); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printIndexedBuffer("Output"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28.,109.1, 82.1, 12.7, 113.1, - 114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. , - 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,31. , 22. , 67. , 24. , - 15.1, 46.4, 73. , 28. ,109.1, 12.1, 12.7, 13.1,14. , 14.2, 16.2, 11. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::segment_min op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 31., 22., 67., 24., + 15.1, 46.4, 73., 28., 109.1, 12.1, 12.7, 13.1, 14., 14.2, + 16.2, 11., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::segment_min op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, + 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, + 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 31., 22., 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, + 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); - - sd::ops::unsorted_segment_min op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_min op; - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_01) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); - sd::ops::unsorted_segment_min op; + sd::ops::unsorted_segment_min op; - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1.8, 2.4, 3. , 9.,2.1, 2.1, 0.7, 0.1,3. , 4.2, 2.2, 1.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {1.8, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - sd::ops::unsorted_segment_min op; + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_min op; - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28.,109.1, 82.1, 12.7, 113.1, - 114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. , - 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,31. , 22. , 67. , 24. , - 15.1, 46.4, 73. , 28. ,109.1, 12.1, 12.7, 13.1,14. , 14.2, 16.2, 11. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::unsorted_segment_min op; - - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 31., 22., 67., 24., + 15.1, 46.4, 73., 28., 109.1, 12.1, 12.7, 13.1, 14., 14.2, + 16.2, 11., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., - 51., 42., 67., 24., 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., - 31., 22., 87., 44., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., - 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - double principalMax = DataTypeUtils::max(); - - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 51., - 42., 67., 24., 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, - 31., 22., 87., 44., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, - 91., 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::unsorted_segment_min op; - - auto result = op.evaluate({&x, &idx}, {}, {8}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - // exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + double principalMax = DataTypeUtils::max(); + + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, + 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, + 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, + 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, 31., 22., + 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117., principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, 91., 82., 37., + 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., + 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {8}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMean_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({2.15, 4.375, 3., 4.4, 1.8666667}); - - sd::ops::segment_mean op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({2.15, 4.375, 3., 4.4, 1.8666667}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_mean op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests7, TestSegmentMean_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, { 1.95, 2.45, 3.5, 9., 2.1, 2.1, 0.7, 0.1, 3. , 4.2, 2.2, 1.}); - - sd::ops::segment_mean op; + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {1.95, 2.45, 3.5, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// result.at(0)->printIndexedBuffer("Output"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_mean op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests7, TestSegmentMean_02) { - auto x = NDArrayFactory::create('c', {6, 3}, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2}); - auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); + auto x = + NDArrayFactory::create('c', {6, 3}, + {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., + 11., 12., 13., 14., 15., 16., 17., 18.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 2, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); - sd::ops::segment_mean op; + sd::ops::segment_mean op; - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests7, TestSegmentMean_021) { - auto x = NDArrayFactory::create('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2}); - auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f}); - - sd::ops::segment_mean op; - x.linspace(1.); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {6, 3}); //, {1, + //2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., + //16., 17., 18.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 2, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f}); + + sd::ops::segment_mean op; + x.linspace(1.); + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests7, TestSegmentMean_022) { - auto x = NDArrayFactory::create('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2}); - auto z = NDArrayFactory::create('c', {3, 3}); //, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); - auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f}); + auto x = NDArrayFactory::create( + 'c', {6, 3}); //, {1, + //2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., + //16., 17., 18.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 2, 2}); + auto z = NDArrayFactory::create( + 'c', + {3, + 3}); //, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, {2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f}); - sd::ops::segment_mean op; - x.linspace(1.); - auto result = op.execute({&x, &idx}, {&z}); - ASSERT_EQ(result, Status::OK()); + sd::ops::segment_mean op; + x.linspace(1.); + auto result = op.execute({&x, &idx}, {&z}); + ASSERT_EQ(result, Status::OK()); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.equalsTo(z)); -// + // } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto eps = NDArrayFactory::create('c', {3, 4}); - auto exp = NDArrayFactory::create('c', {4, 4}, { 0.5, 1., 1.5, 2., 0.5, 1., 1.5, 2., 5., 6., 7., 8., 9., 10., 11., 12.}); - eps.linspace(1); + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto eps = NDArrayFactory::create('c', {3, 4}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {0.5, 1., 1.5, 2., 0.5, 1., 1.5, 2., 5., 6., 7., 8., 9., 10., 11., 12.}); + eps.linspace(1); - sd::ops::segment_mean_bp op; + sd::ops::segment_mean_bp op; - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 2); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 2); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMean_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. , - 41. , 32. , 77. , 34. ,35.1 , 51.4 , 83. , 28. ,114.1 , 47.1 , 62.7, 63.1,64. , 64.2 , 66.2 , 64. , - 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. }); - - sd::ops::segment_mean op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 41., 32., 77., 34., + 35.1, 51.4, 83., 28., 114.1, 47.1, 62.7, 63.1, 64., 64.2, + 66.2, 64., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMean_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , - 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::segment_mean op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, + 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, + 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 31., 22., 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, + 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({2.15, 4.375, 3., 4.4, 1.8666667}); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({2.15, 4.375, 3., 4.4, 1.8666667}); - sd::ops::unsorted_segment_mean op; + sd::ops::unsorted_segment_mean op; - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({1./2., 1./2., 2./4., 2./4., 2./4., 2./4, 3., 4./3., 4./3., 4./3., - 5./6., 5./6., 5./6., 5./6., 5./6., 5./6.}); - sd::ops::segment_mean_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create( + {1. / 2., 1. / 2., 2. / 4., 2. / 4., 2. / 4., 2. / 4, 3., 4. / 3., + 4. / 3., 4. / 3., 5. / 6., 5. / 6., 5. / 6., 5. / 6., 5. / 6., 5. / 6.}); + sd::ops::segment_mean_bp op; - + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMeanBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({1./2., 1./2., 2./4., 2./4., 2./4., 2./4, 3., 4./3., 4./3., 4./3., - 5./6., 5./6., 5./6., 5./6., 5./6., 5./6.}); - sd::ops::unsorted_segment_mean_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create( + {1. / 2., 1. / 2., 2. / 4., 2. / 4., 2. / 4., 2. / 4, 3., 4. / 3., + 4. / 3., 4. / 3., 5. / 6., 5. / 6., 5. / 6., 5. / 6., 5. / 6., 5. / 6.}); + sd::ops::unsorted_segment_mean_bp op; - + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMeanBP_2) { - auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({3., 1./2., 1./2., 2./4., 2./4., 2./4., 2./4, 4./3., 4./3., 4./3., - 5./6., 5./6., 5./6., 5./6., 5./6., 5./6.}); - sd::ops::unsorted_segment_mean_bp op; + auto x = + NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create( + {3., 1. / 2., 1. / 2., 2. / 4., 2. / 4., 2. / 4., 2. / 4, 4. / 3., + 4. / 3., 4. / 3., 5. / 6., 5. / 6., 5. / 6., 5. / 6., 5. / 6., 5. / 6.}); + sd::ops::unsorted_segment_mean_bp op; - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, { 1.95, 2.45, 3.5, 9., 2.1, 2.1, 0.7, 0.1, 3. , 4.2, 2.2, 1.}); - - sd::ops::unsorted_segment_mean op; + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {1.95, 2.45, 3.5, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// result.at(0)->printIndexedBuffer("Output"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_mean op; - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. , - 41. , 32. , 77. , 34. ,35.1 , 51.4 , 83. , 28. ,114.1 , 47.1 , 62.7, 63.1,64. , 64.2 , 66.2 , 64. , - 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. }); - - sd::ops::unsorted_segment_mean op; - - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 41., 32., 77., 34., + 35.1, 51.4, 83., 28., 114.1, 47.1, 62.7, 63.1, 64., 64.2, + 66.2, 64., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , - 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::unsorted_segment_mean op; - - auto result = op.evaluate({&x, &idx}, {}, {8}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, + 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, + 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 31., 22., 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, + 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {8}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({3.0405593, 8.75, 3., 7.621024, 4.5723805}); - - sd::ops::unsorted_segment_sqrt_n op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create( + {3.0405593, 8.75, 3., 7.621024, 4.5723805}); - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_sqrt_n op; - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_BP_1) { - auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); -// NDArray exp({3.0405593, 8.75, 3., 7.621024, 4.5723805}); - auto exp = NDArrayFactory::create({3., 0.707107, 0.707107, 1., 1., 1., 1., 2.309401, 2.309401, 2.309401, 2.041241, 2.041241, 2.041241, 2.041241, 2.041241, 2.041241}); - sd::ops::unsorted_segment_sqrt_n_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Hello Out:"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = + NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + // NDArray + // exp({3.0405593, 8.75, 3., 7.621024, 4.5723805}); + auto exp = NDArrayFactory::create( + {3., 0.707107, 0.707107, 1., 1., 1., 1., 2.309401, 2.309401, 2.309401, + 2.041241, 2.041241, 2.041241, 2.041241, 2.041241, 2.041241}); + sd::ops::unsorted_segment_sqrt_n_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Hello Out:"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, { 2.7577164, 3.4648232, 4.9497476, 12.727922, - 2.1, 2.1, 0.7, 0.1, - 3. , 4.2, 2.2, 1. - }); - - sd::ops::unsorted_segment_sqrt_n op; + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {2.7577164, 3.4648232, 4.9497476, 12.727922, 2.1, 2.1, 0.7, 0.1, 3., 4.2, + 2.2, 1.}); - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// result.at(0)->printIndexedBuffer("Output"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_sqrt_n op; - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. , - 57.982758, 45.254833, 108.89445, 48.083263, 49.638893, 72.69058, 117.37973, 39.59798, 161.36177, 66.60946, 88.67119, 89.23688, 90.50967, 90.79251, 93.62093, 90.50967, - 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. }); - - sd::ops::unsorted_segment_sqrt_n op; - - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, + 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117., 57.982758, 45.254833, + 108.89445, 48.083263, 49.638893, 72.69058, 117.37973, 39.59798, + 161.36177, 66.60946, 88.67119, 89.23688, 90.50967, 90.79251, + 93.62093, 90.50967, 91., 82., 37., 64., + 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_sqrt_n op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , - 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::unsorted_segment_sqrt_n op; - - auto result = op.evaluate({&x, &idx}, {}, {8}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, + 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, + 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 31., 22., 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, + 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_sqrt_n op; + + auto result = op.evaluate({&x, &idx}, {}, {8}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_5) { - auto x = NDArrayFactory::create({1.,2.,5.,7.,3.,1.,3.,4.}); - auto idx = NDArrayFactory::create({3, 1, 0, 0, 2, 0, 3, 2}); - //NDArray exp({1.7320508075688772, 1., 1.4142135623730951, 1.4142135623730951}); - auto exp = NDArrayFactory::create({7.5055537, 2., 4.9497476, 2.828427}); - sd::ops::unsorted_segment_sqrt_n op; - - auto result = op.evaluate({&x, &idx}, {}, {4}); - ASSERT_EQ(result.status(), Status::OK()); - // result.at(0)->printIndexedBuffer("Output"); - // exp.printIndexedBuffer("Expect"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto x = NDArrayFactory::create({1., 2., 5., 7., 3., 1., 3., 4.}); + auto idx = NDArrayFactory::create({3, 1, 0, 0, 2, 0, 3, 2}); + // NDArray exp({1.7320508075688772, 1., 1.4142135623730951, + // 1.4142135623730951}); + auto exp = + NDArrayFactory::create({7.5055537, 2., 4.9497476, 2.828427}); + sd::ops::unsorted_segment_sqrt_n op; + auto result = op.evaluate({&x, &idx}, {}, {4}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_6) { - auto x = NDArrayFactory::create({5,1,7,2,3,4,1,3}); - auto idx = NDArrayFactory::create({0,0,0,1,2,2,3,3}); - //NDArray exp({1.7320508075688772, 1., 1.4142135623730951, 1.4142135623730951}); -// auto exp = NDArrayFactory::create({7.5055537, 2., 4.9497476, 2.828427}); - sd::ops::unsorted_segment_sqrt_n op; - -try { + auto x = NDArrayFactory::create({5, 1, 7, 2, 3, 4, 1, 3}); + auto idx = NDArrayFactory::create({0, 0, 0, 1, 2, 2, 3, 3}); + // NDArray exp({1.7320508075688772, 1., 1.4142135623730951, + // 1.4142135623730951}); + // auto exp = + // NDArrayFactory::create({7.5055537, 2., 4.9497476, 2.828427}); + sd::ops::unsorted_segment_sqrt_n op; + + try { auto result = op.evaluate({&x, &idx}, {}, {1}); ASSERT_NE(result.status(), Status::OK()); -} -catch (std::exception& err) { - -} - // result.at(0)->printIndexedBuffer("Output"); - // exp.printIndexedBuffer("Expect"); - //ASSERT_TRUE(exp.equalsTo(result.at(0))); + } catch (std::exception& err) { + } + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + // ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSum_1) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({ 0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({4.3, 17.5, 3., 13.2, 11.2}); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({4.3, 17.5, 3., 13.2, 11.2}); - sd::ops::segment_sum op; + sd::ops::segment_sum op; - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSumBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({ 1., 1., 2., 2., 2., 2., 3., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); - sd::ops::segment_sum_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create( + {1., 1., 2., 2., 2., 2., 3., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); + sd::ops::segment_sum_bp op; - + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSumBP_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1, 2, 3, 4, 5}); - auto exp = NDArrayFactory::create({ 1., 1., 2., 2., 2., 2., 3., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); - sd::ops::unsorted_segment_sum_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1, 2, 3, 4, 5}); + auto exp = NDArrayFactory::create( + {1., 1., 2., 2., 2., 2., 3., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); + sd::ops::unsorted_segment_sum_bp op; - + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSumBP_2) { - auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({ 3., 1., 1., 2., 2., 2., 2., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); - sd::ops::unsorted_segment_sum_bp op; + auto x = + NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create( + {3., 1., 1., 2., 2., 2., 2., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); + sd::ops::unsorted_segment_sum_bp op; - auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSum_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, {3.9 , 4.9, 7. , 18.,2.1 , 2.1, 0.7, 0.1,3. , 4.2, 2.2, 1.}); - - sd::ops::segment_sum op; + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {3.9, 4.9, 7., 18., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_sum op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSumBP_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1. , 2., 3., 4., 1. , 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {3, 4}); - eps.linspace(1); - - sd::ops::segment_sum_bp op; + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {1., 2., 3., 4., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {3, 4}); + eps.linspace(1); - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 2); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_sum_bp op; - + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 2); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSum_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,82. , 64. , 154. , 68. , - 70.2, 102.8, 166. , 56. ,228.2, 94.2, 125.4, 126.2 ,128. , 128.4, 132.4, 128. ,91. , 82. , 37. , 64. , - 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::segment_sum op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 82., 64., 154., 68., + 70.2, 102.8, 166., 56., 228.2, 94.2, 125.4, 126.2, 128., 128.4, + 132.4, 128., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSum_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , - 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::segment_sum op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, + 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, + 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 31., 22., 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, + 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({4.3, 17.5, 3., 13.2, 11.2}); - - sd::ops::unsorted_segment_sum op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({4.3, 17.5, 3., 13.2, 11.2}); - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_sum op; - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, {3.9 , 4.9, 7. , 18.,2.1 , 2.1, 0.7, 0.1,3. , 4.2, 2.2, 1.}); + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {3.9, 4.9, 7., 18., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - sd::ops::unsorted_segment_sum op; + sd::ops::unsorted_segment_sum op; - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,82. , 64. , 154. , 68. , - 70.2, 102.8, 166. , 56. ,228.2, 94.2, 125.4, 126.2 ,128. , 128.4, 132.4, 128. ,91. , 82. , 37. , 64. , - 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::unsorted_segment_sum op; - - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 82., 64., 154., 68., + 70.2, 102.8, 166., 56., 228.2, 94.2, 125.4, 126.2, 128., 128.4, + 132.4, 128., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , - 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 3, 7}); - auto exp = NDArrayFactory::create('c', {8, 4, 4}, { - 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , - 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , - 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , - 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); - - sd::ops::unsorted_segment_sum op; - - auto result = op.evaluate({&x, &idx}, {}, {8}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create( + 'c', {8, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, + 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, + 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 31., 22., 87., 44., 55.1, 46.4, 73., + 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, + 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {8}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_1) { - auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); - - sd::ops::segment_prod op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProdBP_1) { - auto x = NDArrayFactory::create({ 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 3., 17.64, 75.6, 75.6, 13.86, 97.02, 3.234, 2.31, 4.41, 9.702}); - sd::ops::segment_prod_bp op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create( + {2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 3., 17.64, 75.6, 75.6, 13.86, + 97.02, 3.234, 2.31, 4.41, 9.702}); + sd::ops::segment_prod_bp op; - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("ProdBP Output"); -// exp.printIndexedBuffer("ProdBP Expect"); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("ProdBP Output"); + // exp.printIndexedBuffer("ProdBP Expect"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_1) { - auto x = NDArrayFactory::create({ 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 3., 17.64, 75.6, 75.6, 13.86, 97.02, 3.234, 2.31, 4.41, 9.702}); - sd::ops::segment_prod_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("ProdBP Output"); - //exp.printIndexedBuffer("ProdBP Expect"); + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create( + {2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 3., 17.64, 75.6, 75.6, 13.86, + 97.02, 3.234, 2.31, 4.41, 9.702}); + sd::ops::segment_prod_bp op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("ProdBP Output"); + // exp.printIndexedBuffer("ProdBP Expect"); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_2) { - auto x = NDArrayFactory::create({ 3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); - auto exp = NDArrayFactory::create({3., 2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 17.64, 75.6, 75.6, 13.86, 97.02, 3.234, 2.31, 4.41, 9.702}); - auto n = NDArrayFactory::create(5LL); - sd::ops::unsorted_segment_prod_bp op; + auto x = + NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create({3., 2.5, 1.8, 90.72, 40.32, 172.8, + 151.2, 17.64, 75.6, 75.6, 13.86, + 97.02, 3.234, 2.31, 4.41, 9.702}); + auto n = NDArrayFactory::create(5LL); + sd::ops::unsorted_segment_prod_bp op; - auto result = op.evaluate({&x, &idx, &eps, &n}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Unsorted ProdBP Output"); - //exp.printIndexedBuffer("Unsorted ProdBP Expect"); + auto result = op.evaluate({&x, &idx, &eps, &n}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Unsorted ProdBP Output"); + // exp.printIndexedBuffer("Unsorted ProdBP Expect"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_2) { - auto x = NDArrayFactory::create('c', {4, 4}, { - 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, { 3.78, 6. , 12. , 81., 2.1 , 2.1, 0.7 , 0.1, 3. , 4.2, 2.2 , 1.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {3.78, 6., 12., 81., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - sd::ops::segment_prod op; + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProdBP_2) { - auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9., - 2.1, 2.4, 3., 9., - 2.1, 2.1, 0.7, 0.1, - 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto eps = NDArrayFactory::create('c', {3, 4}); - auto exp = NDArrayFactory::create('c', {4, 4}, {2.1, 4.8, 9., 36., 1.8, 5., 12., 36., 5., 6., 7., 8., 9., 10., 11., 12.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - eps.linspace(1); - sd::ops::segment_prod_bp op; - - auto result = op.evaluate({&x, &idx, &eps}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 2); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto eps = NDArrayFactory::create('c', {3, 4}); + auto exp = + NDArrayFactory::create('c', {4, 4}, + {2.1, 4.8, 9., 36., 1.8, 5., 12., 36., 5., + 6., 7., 8., 9., 10., 11., 12.}); + + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + eps.linspace(1); + sd::ops::segment_prod_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 2); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 51. , 42. , 67. , 24., - 15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. , 31. , 22. , 87., 44. , 55.1, 46.4, 73., 28. , - 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. , 55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , - 1581, 924, 5829, 1056,832.01001, 2616.9602, 6789, 784, 12993.810, 993.41003, 1431.2899, 1481.61, 1596, 1621.64, 1882.4401, 1287, - 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - - sd::ops::segment_prod op; - - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, + 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117., 1581, 924, + 5829, 1056, 832.01001, 2616.9602, 6789, 784, + 12993.810, 993.41003, 1431.2899, 1481.61, 1596, 1621.64, + 1882.4401, 1287, 91., 82., 37., 64., + 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_04) { - auto x = NDArrayFactory::create({1,2,3,4,5,6,7,8 }); - -// ---------------------------------------------------------------- + auto x = NDArrayFactory::create({1, 2, 3, 4, 5, 6, 7, 8}); - auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); - auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); + // ---------------------------------------------------------------- - sd::ops::segment_prod op; + auto idx = NDArrayFactory::create({0, 0, 1, 2, 2, 2, 3, 3}); + auto exp = NDArrayFactory::create({2, 3, 120, 56}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_05) { - auto x = NDArrayFactory::create({1,2,3,4,5,6,7,8 }); - -// ---------------------------------------------------------------- + auto x = NDArrayFactory::create({1, 2, 3, 4, 5, 6, 7, 8}); - auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); - auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); + // ---------------------------------------------------------------- - sd::ops::segment_prod op; + auto idx = NDArrayFactory::create({0, 0, 1, 2, 2, 2, 3, 3}); + auto exp = NDArrayFactory::create({2, 3, 120, 56}); - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - auto res = result.at(0); -// res->printIndexedBuffer("Segment prod 05"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto res = result.at(0); + // res->printIndexedBuffer("Segment prod 05"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_05_1) { - auto x = NDArrayFactory::create({1,2,3,4,5,6,7,8 }); + auto x = NDArrayFactory::create({1, 2, 3, 4, 5, 6, 7, 8}); -// ---------------------------------------------------------------- + // ---------------------------------------------------------------- - auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); - auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); + auto idx = NDArrayFactory::create({0, 0, 1, 2, 2, 2, 3, 3}); + auto exp = NDArrayFactory::create({2, 3, 120, 56}); - sd::ops::segment_prod op; + sd::ops::segment_prod op; - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - auto res = result.at(0); -// res->printIndexedBuffer("Segment prod 05_1"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto res = result.at(0); + // res->printIndexedBuffer("Segment prod 05_1"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_06) { - auto x = NDArrayFactory::create({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8' }); - -// ---------------------------------------------------------------- + auto x = NDArrayFactory::create( + {'\x1', '\x2', '\x3', '\x4', '\x5', '\x6', '\x7', '\x8'}); - auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); - auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); - sd::ops::segment_prod op; + // ---------------------------------------------------------------- - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto idx = NDArrayFactory::create({0, 0, 1, 2, 2, 2, 3, 3}); + auto exp = NDArrayFactory::create({2, 3, 120, 56}); + sd::ops::segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_07) { - auto x = NDArrayFactory::create({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8' }); - -// ---------------------------------------------------------------- + auto x = NDArrayFactory::create( + {'\x1', '\x2', '\x3', '\x4', '\x5', '\x6', '\x7', '\x8'}); - auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); - auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); - sd::ops::segment_prod op; + // ---------------------------------------------------------------- - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto idx = NDArrayFactory::create({0, 0, 1, 2, 2, 2, 3, 3}); + auto exp = NDArrayFactory::create({2, 3, 120, 56}); + sd::ops::segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_08) { - auto x = NDArrayFactory::create({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8', '\x9', '\xA' }); + auto x = NDArrayFactory::create( + {'\x1', '\x2', '\x3', '\x4', '\x5', '\x6', '\x7', '\x8', '\x9', '\xA'}); -// ---------------------------------------------------------------- + // ---------------------------------------------------------------- - auto idx = NDArrayFactory::create({0,0,2,2,2,2,3,3,3,3}); - auto exp = NDArrayFactory::create({ 2, 1,360, 5040}); - sd::ops::segment_prod op; + auto idx = NDArrayFactory::create({0, 0, 2, 2, 2, 2, 3, 3, 3, 3}); + auto exp = NDArrayFactory::create({2, 1, 360, 5040}); + sd::ops::segment_prod op; - auto result = op.evaluate({&x, &idx}, {}, {}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_1) { - auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); - - sd::ops::unsorted_segment_prod op; + auto x = + NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_11) { - auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); - auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); - auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); - - sd::ops::unsorted_segment_prod op; + auto x = + NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create( + {2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); - auto result = op.evaluate({&x, &idx}, {}, {5}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_2) { - auto x = NDArrayFactory::create('c', {4, 4}, { - 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1. }); - auto idx = NDArrayFactory::create({0, 0, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4}, { 3.78, 6. , 12. , 81., 2.1 , 2.1, 0.7 , 0.1, 3. , 4.2, 2.2 , 1.}); + auto x = + NDArrayFactory::create('c', {4, 4}, + {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, + 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {3.78, 6., 12., 81., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - sd::ops::unsorted_segment_prod op; + sd::ops::unsorted_segment_prod op; - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_12) { - auto x = NDArrayFactory::create('c', {4, 4}, { - 3., 4.2, 2.2, 1., - 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1 }); - auto idx = NDArrayFactory::create({2, 0, 0, 1}); - auto exp = NDArrayFactory::create('c', {3, 4}, { 3.78, 6. , 12. , 81., 2.1 , 2.1, 0.7 , 0.1, 3. , 4.2, 2.2 , 1.}); - - //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + auto x = + NDArrayFactory::create('c', {4, 4}, + {3., 4.2, 2.2, 1., 1.8, 2.5, 4., 9., 2.1, + 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1}); + auto idx = NDArrayFactory::create({2, 0, 0, 1}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {3.78, 6., 12., 81., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); - sd::ops::unsorted_segment_prod op; + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_EQ(result.size(), 1); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_08) { - auto x = NDArrayFactory::create({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8', '\x9', '\xA' }); - -// ---------------------------------------------------------------- + auto x = NDArrayFactory::create( + {'\x1', '\x2', '\x3', '\x4', '\x5', '\x6', '\x7', '\x8', '\x9', '\xA'}); - auto idx = NDArrayFactory::create({0,0,2,2,2,2,3,3,3,3}); - auto exp = NDArrayFactory::create({ 2, 1,360, 5040}); - sd::ops::unsorted_segment_prod op; + // ---------------------------------------------------------------- - auto result = op.evaluate({&x, &idx}, {}, {4}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto idx = NDArrayFactory::create({0, 0, 2, 2, 2, 2, 3, 3, 3, 3}); + auto exp = NDArrayFactory::create({2, 1, 360, 5040}); + sd::ops::unsorted_segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {4}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_3) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 51. , 42. , 67. , 24., - 15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. , 31. , 22. , 87., 44. , 55.1, 46.4, 73., 28. , - 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. , 55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , - 1581, 924, 5829, 1056,832.01001, 2616.9602, 6789, 784, 12993.810, 993.41003, 1431.2899, 1481.61, 1596.0000, 1621.6399, 1882.4401, 1287, - 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); - - sd::ops::unsorted_segment_prod op; - - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); + + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, + 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117., 1581, 924, + 5829, 1056, 832.01001, 2616.9602, 6789, 784, + 12993.810, 993.41003, 1431.2899, 1481.61, 1596.0000, 1621.6399, + 1882.4401, 1287, 91., 82., 37., 64., + 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117.}); + + sd::ops::unsorted_segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_4) { - auto x = NDArrayFactory::create('c', {4, 4, 4}, { - 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 51. , 42. , 67. , 24., - 15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. , 31. , 22. , 87., 44. , 55.1, 46.4, 73., 28. , - 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. , 55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); + auto x = NDArrayFactory::create( + 'c', {4, 4, 4}, + {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, + 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., + 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, + 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., + 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., + 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, + 14., 114.2, 16.2, 117.}); -// ---------------------------------------------------------------- + // ---------------------------------------------------------------- - auto idx = NDArrayFactory::create({1, 1, 1, 2}); - auto exp = NDArrayFactory::create('c', {3, 4, 4}, { - 1., 1., 1., 1., 1., 1.,1.,1., 1.,1.,1.,1., 1.,1.,1.,1., + auto idx = NDArrayFactory::create({1, 1, 1, 2}); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 4}, + {1., 1., 1., 1., 1., 1., + 1., 1., 1., 1., 1., 1., + 1., 1., 1., 1., - 143871, 75768, 215673, 67584., 45843.75, 121426.96, 495597, 21952, - 1547562.8, 12020.262, 161306.38, 19409.092, 22344, 185191.27, 30495.531, 150579, + 143871, 75768, 215673, 67584., 45843.75, 121426.96, + 495597, 21952, 1547562.8, 12020.262, 161306.38, 19409.092, + 22344, 185191.27, 30495.531, 150579, - 91., 82., 37., 64, 55.1, 46.400002, 73, 28, 119.1, 12.1, 112.7, 13.1, 14, 114.2, 16.2, 117}); + 91., 82., 37., 64, 55.1, 46.400002, + 73, 28, 119.1, 12.1, 112.7, 13.1, + 14, 114.2, 16.2, 117}); - sd::ops::unsorted_segment_prod op; + sd::ops::unsorted_segment_prod op; - auto result = op.evaluate({&x, &idx}, {}, {3}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_5) { - auto x = NDArrayFactory::create('c', {8, 15}); - -// ---------------------------------------------------------------- + auto x = NDArrayFactory::create('c', {8, 15}); - auto idx = NDArrayFactory::create({3, 1, 2, 1, 2, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {4, 15}, { - 1., 1., 1., 1., 1., - 1., 1., 1., 1., 1., - 1., 1., 1., 1., 1., - 78016., 85493., 93312., 101479., 110000., - 118881., 128128., 137747., 147744., 158125., - 168896., 180063., 191632., 203609., 216000., - 172081., 182528., 193347., 204544., 216125., - 228096., 240463., 253232., 266409., 280000., - 294011., 308448., 323317., 338624., 354375., - 76., 154., 234., 316., 400., - 486., 574., 664., 756., 850., - 946., 1044., 1144., 1246., 1350.}); - x.linspace(1.); + // ---------------------------------------------------------------- - sd::ops::unsorted_segment_prod op; + auto idx = NDArrayFactory::create({3, 1, 2, 1, 2, 3, 2, 1}); + auto exp = NDArrayFactory::create( + 'c', {4, 15}, + {1., 1., 1., 1., 1., 1., 1., 1., + 1., 1., 1., 1., 1., 1., 1., 78016., + 85493., 93312., 101479., 110000., 118881., 128128., 137747., 147744., + 158125., 168896., 180063., 191632., 203609., 216000., 172081., 182528., + 193347., 204544., 216125., 228096., 240463., 253232., 266409., 280000., + 294011., 308448., 323317., 338624., 354375., 76., 154., 234., + 316., 400., 486., 574., 664., 756., 850., 946., + 1044., 1144., 1246., 1350.}); + x.linspace(1.); - auto result = op.evaluate({&x, &idx}, {}, {4}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); -// result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + sd::ops::unsorted_segment_prod op; - + auto result = op.evaluate({&x, &idx}, {}, {4}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_4) { - auto x = NDArrayFactory::create('c', {8}, { - 5,1,7,2,3,4,1,3}); - auto gradO = NDArrayFactory::create('c', {4}, {1,2,3,4}); -// ---------------------------------------------------------------- - - auto idx = NDArrayFactory::create({0,0,0,1,2,2,3,3}); - auto exp = NDArrayFactory::create('c', {8}, { - 7.000000, 35.000000, 5.000000, 2.000000, 12.000000, 9.000000, 12.000000, 4.000000 - }); -// 1., 1., 1., 1., 1., 1.,1.,1., 1.,1.,1.,1., 1.,1.,1.,1., -// -// 143871, 75768, 215673, 67584., 45843.75, 121426.96, 495597, 21952, -// 1547562.8, 12020.262, 161306.38, 19409.092, 22344, 185191.27, 30495.531, 150579, -// -// 91., 82., 37., 64, 55.1, 46.400002, 73, 28, 119.1, 12.1, 112.7, 13.1, 14, 114.2, 16.2, 117}); - - sd::ops::unsorted_segment_prod_bp op; - - auto result = op.evaluate({&x, &idx, &gradO}, {}, {4}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); - //exp.printIndexedBuffer("Expect"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create('c', {8}, {5, 1, 7, 2, 3, 4, 1, 3}); + auto gradO = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + // ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 0, 0, 1, 2, 2, 3, 3}); + auto exp = NDArrayFactory::create( + 'c', {8}, + {7.000000, 35.000000, 5.000000, 2.000000, 12.000000, 9.000000, 12.000000, + 4.000000}); + // 1., 1., 1., 1., 1., 1.,1.,1., 1.,1.,1.,1., 1.,1.,1.,1., + // + // 143871, 75768, 215673, 67584., 45843.75, 121426.96, 495597, + // 21952, 1547562.8, 12020.262, 161306.38, 19409.092, 22344, + // 185191.27, 30495.531, 150579, + // + // 91., 82., 37., 64, 55.1, 46.400002, 73, 28, 119.1, 12.1, + // 112.7, 13.1, 14, 114.2, 16.2, 117}); + + sd::ops::unsorted_segment_prod_bp op; + + auto result = op.evaluate({&x, &idx, &gradO}, {}, {4}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_1) { - auto x = NDArrayFactory::create('c', {2,4, 4, 4}, { - 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 14., 114., 16.2, 117., - 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 140., 110., 160., 107.}); - -// ---------------------------------------------------------------- - - - auto exp = NDArrayFactory::create('c', {2, 4, 4, 4}, { - 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 14., 114., 16.2, 117., - 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 140., 110., 160., 107.}); - - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {1,1,1,1,1,1,0}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {2, 4, 4, 4}, + {91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., + 14., 114., 16., 117., 51., 42., 67., 24., 15., 56., 93., 28., + 109., 82., 12., 113., 114., 14., 116., 11., 31., 22., 87., 44., + 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., + 14., 114., 16.2, 117., 91., 82., 37., 64., 55., 46., 73., 28., + 119., 12., 112., 13., 14., 114., 16., 117., 51., 42., 67., 24., + 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., + 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., + 14., 114., 16., 117., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119., 12., 112., 13., 140., 110., 160., 107.}); + + // ---------------------------------------------------------------- + + auto exp = NDArrayFactory::create( + 'c', {2, 4, 4, 4}, + {91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., + 14., 114., 16., 117., 51., 42., 67., 24., 15., 56., 93., 28., + 109., 82., 12., 113., 114., 14., 116., 11., 31., 22., 87., 44., + 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., + 14., 114., 16.2, 117., 91., 82., 37., 64., 55., 46., 73., 28., + 119., 12., 112., 13., 14., 114., 16., 117., 51., 42., 67., 24., + 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., + 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., + 14., 114., 16., 117., 91., 82., 37., 64., 55.1, 46.4, 73., 28., + 119., 12., 112., 13., 140., 110., 160., 107.}); + + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {1, 1, 1, 1, 1, 1, 0}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_2) { - auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { - 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., - 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., - 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., - 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., - 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., - 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., - 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., - 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., - 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); - -//Images shape is (3, 3, 4, 3) -//[1, 1, 1, 1] -//[1, 3, 2, 1] -auto exp = NDArrayFactory::create('c', {3, 1, 1, 12}, { - 11., 12., 13., 12., 13., 14., 1., 2., 3., 2., 3., 4., - 9., 8., 7., 6., 5., 4., 3., 2., 1., 0., 1., 2., - 211., 12., 13., 12., 213., 14., 21., 2., 3., 2., 3., 24. - }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,2, 3,3, 1,1,0}); - ASSERT_EQ(result.status(), Status::OK()); - - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {3, 3, 4, 3}, + {11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., + 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., + 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., + 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., + 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., + 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., + 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., + 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., + 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); + + // Images shape is (3, 3, 4, 3) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {3, 1, 1, 12}, + {11., 12., 13., 12., 13., 14., 1., 2., 3., 2., 3., 4., + 9., 8., 7., 6., 5., 4., 3., 2., 1., 0., 1., 2., + 211., 12., 13., 12., 213., 14., 21., 2., 3., 2., 3., 24.}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {2, 2, 3, 3, 1, 1, 0}); + ASSERT_EQ(result.status(), Status::OK()); + + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_3) { - auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { - 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., - 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., - 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., - 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., - 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., - 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., - 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., - 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., - 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); - -//Images shape is (3, 3, 4, 3) -//[1, 1, 1, 1] -//[1, 3, 2, 1] -auto exp = NDArrayFactory::create('c', {3, 1, 2, 6}, { - 11., 12., 13., 5., 6., 7., 15., 16., 17., 35., 36., 37., 9., 8., - 7., 15., 16., 17., 49., 48., 47., 135., 136., 137., 211., 12., 13., 25., - 6., 7., 15., 216., 17., 35., 36., 327. - }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,1,3,2,2,2,0}); - ASSERT_EQ(result.status(), Status::OK()); - - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {3, 3, 4, 3}, + {11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., + 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., + 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., + 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., + 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., + 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., + 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., + 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., + 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); + + // Images shape is (3, 3, 4, 3) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {3, 1, 2, 6}, + {11., 12., 13., 5., 6., 7., 15., 16., 17., 35., 36., 37., + 9., 8., 7., 15., 16., 17., 49., 48., 47., 135., 136., 137., + 211., 12., 13., 25., 6., 7., 15., 216., 17., 35., 36., 327.}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {2, 1, 3, 2, 2, 2, 0}); + ASSERT_EQ(result.status(), Status::OK()); + + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_4) { - auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { - 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., - 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., - 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., - 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., - 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., - 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., - 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., - 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., - 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); - -//Images shape is (3, 3, 4, 3) -//[1, 1, 1, 1] -//[1, 3, 2, 1] -auto exp = NDArrayFactory::create('c', {3, 3, 4, 3}, { - 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., - 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., - 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., - 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., - 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., - 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., - 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., - 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., - 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {1,1,1,1,1,1,0}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {3, 3, 4, 3}, + {11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., + 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., + 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., + 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., + 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., + 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., + 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., + 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., + 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); + + // Images shape is (3, 3, 4, 3) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {3, 3, 4, 3}, + {11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., + 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., + 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., + 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., + 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., + 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., + 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., + 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., + 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {1, 1, 1, 1, 1, 1, 0}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_5) { - auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { - 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., - 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., - 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., - 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., - 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., - 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., -211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., - 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., - 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); - -//Images shape is (3, 3, 4, 3) -//[1, 1, 1, 1] -//[1, 3, 2, 1] -auto exp = NDArrayFactory::create('c', {3, 1, 1, 18}, { - 11., 12., 13., 15., 16., 17., 1., 2., 3., 21., 22., 23., 5., 6., 7., 35., 36., 37., - 9., 8., 7., 49., 48., 47., 3., 2., 1., 53., 52., 51., 15., 16., 17., 135., 136., 137., - 211., 12., 13., 15., 216., 17., 21., 2., 3., 21., 22., 223., 25., 6., 7., 35., 36., 327. - -//Patch shape is (3, 1, 2, 18) - - }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {3,2,3,2,1,2,0}); - ASSERT_EQ(result.status(), Status::OK()); -// result.at(0)->printIndexedBuffer("Output"); - //result.at(0)->printShapeInfo("Out Shape"); -// exp.printIndexedBuffer("Expect"); - //exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {3, 3, 4, 3}, + {11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., + 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., + 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., + 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., + 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., + 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., + 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., + 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., + 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); + + // Images shape is (3, 3, 4, 3) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {3, 1, 1, 18}, + { + 11., 12., 13., 15., 16., 17., 1., 2., 3., 21., 22., + 23., 5., 6., 7., 35., 36., 37., 9., 8., 7., 49., + 48., 47., 3., 2., 1., 53., 52., 51., 15., 16., 17., + 135., 136., 137., 211., 12., 13., 15., 216., 17., 21., 2., + 3., 21., 22., 223., 25., 6., 7., 35., 36., 327. + + // Patch shape is (3, 1, 2, 18) + + }); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {3, 2, 3, 2, 1, 2, 0}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_6) { - auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, - 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 -}); - -//Images shape is (3, 3, 4, 3) -//[1, 1, 1, 1] -//[1, 3, 2, 1] -auto exp = NDArrayFactory::create('c', {2, 1, 4, 4}, { - 11.11, 11.12, 12.11, 12.12, 11.21, 11.22, 12.21, 12.22, 11.31, 11.32, 12.31, 12.32, 11.41, 11.42, 12.41, 12.42, - 21.11, 21.12, 22.11, 22.12, 21.21, 21.22, 22.21, 22.22, 21.31, 21.32, 22.31, 22.32, 21.41, 21.42, 22.41, 22.42 - }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,1, 1,1, 1,1,0}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto x = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}); + + // Images shape is (3, 3, 4, 3) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {2, 1, 4, 4}, + {11.11, 11.12, 12.11, 12.12, 11.21, 11.22, 12.21, 12.22, + 11.31, 11.32, 12.31, 12.32, 11.41, 11.42, 12.41, 12.42, + 21.11, 21.12, 22.11, 22.12, 21.21, 21.22, 22.21, 22.22, + 21.31, 21.32, 22.31, 22.32, 21.41, 21.42, 22.41, 22.42}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {2, 1, 1, 1, 1, 1, 0}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_7) { - auto x = NDArrayFactory::create('c', {1, 3, 3, 1}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 3, 3, 4}, { - 1., 2., 4., 5., 2., 3., 5., 6., 3., 0., 6., 0., - 4., 5., 7., 8., 5., 6., 8., 9., 6., 0., 9., 0., 7., 8., 0., 0., 8., 9., 0., 0., 9., 0., 0., 0. }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("Output"); -// exp.printBuffer("Expect"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 3, 3, 1}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 3, 3, 4}, + {1., 2., 4., 5., 2., 3., 5., 6., 3., 0., 6., 0., 4., 5., 7., 8., 5., 6., + 8., 9., 6., 0., 9., 0., 7., 8., 0., 0., 8., 9., 0., 0., 9., 0., 0., 0.}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 1, 1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("Output"); + // exp.printBuffer("Expect"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_8) { - auto x = NDArrayFactory::create('c', {1, 3, 3, 2}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 3, 3, 8}, { - 1, 2, 3, 4, 7, 8, 9, 10, 3, 4, 5, 6, 9, 10, 11, 12, 5, 6, 0, 0, 11, 12, 0, 0, - 7, 8, 9, 10, 13, 14, 15, 16, 9, 10, 11, 12, 15, 16, 17, 18, 11, 12, 0, 0, 17, 18, 0, 0, - 13, 14, 15, 16, 0, 0, 0, 0, 15, 16, 17, 18, 0, 0, 0, 0, 17, 18, 0, 0, 0, 0, 0, 0 }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("Output"); -// exp.printBuffer("Expect"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 3, 3, 2}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 3, 3, 8}, + {1, 2, 3, 4, 7, 8, 9, 10, 3, 4, 5, 6, 9, 10, 11, 12, 5, 6, + 0, 0, 11, 12, 0, 0, 7, 8, 9, 10, 13, 14, 15, 16, 9, 10, 11, 12, + 15, 16, 17, 18, 11, 12, 0, 0, 17, 18, 0, 0, 13, 14, 15, 16, 0, 0, + 0, 0, 15, 16, 17, 18, 0, 0, 0, 0, 17, 18, 0, 0, 0, 0, 0, 0}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 1, 1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("Output"); + // exp.printBuffer("Expect"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_9) { - auto x = NDArrayFactory::create('c', {1, 6, 6, 2}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 6, 6, 18}, { - 0., 0., 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 0., 0., 13., 14., 15., 16., - 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., - 0., 0., 0., 0., 0., 0., 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., 19., 20., - 0., 0., 0., 0., 0., 0., 5., 6., 7., 8., 9., 10., 17., 18., 19., 20., 21., 22., - 0., 0., 0., 0., 0., 0., 7., 8., 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., - 0., 0., 0., 0., 0., 0., 9., 10., 11., 12., 0., 0., 21., 22., 23., 24., 0., 0., - 0., 0., 1., 2., 3., 4., 0., 0., 13., 14., 15., 16., 0., 0., 25., 26., 27., 28., - 1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., - 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., - 5., 6., 7., 8., 9., 10., 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., - 7., 8., 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., - 9., 10., 11., 12., 0., 0., 21., 22., 23., 24., 0., 0., 33., 34., 35., 36., 0., 0., - 0., 0., 13., 14., 15., 16., 0., 0., 25., 26., 27., 28., 0., 0., 37., 38., 39., 40., - 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., - 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., - 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., - 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., - 21., 22., 23., 24., 0., 0., 33., 34., 35., 36., 0., 0., 45., 46., 47., 48., 0., 0., - 0., 0., 25., 26., 27., 28., 0., 0., 37., 38., 39., 40., 0., 0., 49., 50., 51., 52., - 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., - 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., - 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., - 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., - 33., 34., 35., 36., 0., 0., 45., 46., 47., 48., 0., 0., 57., 58., 59., 60., 0., 0., - 0., 0., 37., 38., 39., 40., 0., 0., 49., 50., 51., 52., 0., 0., 61., 62., 63., 64., - 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., - 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., - 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., 65., 66., 67., 68., 69., 70., - 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., 67., 68., 69., 70., 71., 72., - 45., 46., 47., 48., 0., 0., 57., 58., 59., 60., 0., 0., 69., 70., 71., 72., 0., 0., - 0., 0., 49., 50., 51., 52., 0., 0., 61., 62., 63., 64., 0., 0., 0., 0., 0., 0., - 49., 50., 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., 0., 0., 0., 0., 0., 0., - 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., 0., 0., 0., 0., 0., 0., - 53., 54., 55., 56., 57., 58., 65., 66., 67., 68., 69., 70., 0., 0., 0., 0., 0., 0., - 55., 56., 57., 58., 59., 60., 67., 68., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., - 57., 58., 59., 60., 0., 0., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0.}); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {3,3, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("OutputSame"); -// exp.printBuffer("ExpectSame"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 6, 6, 2}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 6, 6, 18}, + {0., 0., 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 0., 0., + 13., 14., 15., 16., 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., + 5., 6., 13., 14., 15., 16., 17., 18., 0., 0., 0., 0., 0., 0., + 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., 19., 20., 0., 0., + 0., 0., 0., 0., 5., 6., 7., 8., 9., 10., 17., 18., 19., 20., + 21., 22., 0., 0., 0., 0., 0., 0., 7., 8., 9., 10., 11., 12., + 19., 20., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 9., 10., + 11., 12., 0., 0., 21., 22., 23., 24., 0., 0., 0., 0., 1., 2., + 3., 4., 0., 0., 13., 14., 15., 16., 0., 0., 25., 26., 27., 28., + 1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., 25., 26., + 27., 28., 29., 30., 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., + 19., 20., 27., 28., 29., 30., 31., 32., 5., 6., 7., 8., 9., 10., + 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., 7., 8., + 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., + 35., 36., 9., 10., 11., 12., 0., 0., 21., 22., 23., 24., 0., 0., + 33., 34., 35., 36., 0., 0., 0., 0., 13., 14., 15., 16., 0., 0., + 25., 26., 27., 28., 0., 0., 37., 38., 39., 40., 13., 14., 15., 16., + 17., 18., 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., + 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., 39., 40., + 41., 42., 43., 44., 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., + 33., 34., 41., 42., 43., 44., 45., 46., 19., 20., 21., 22., 23., 24., + 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., 21., 22., + 23., 24., 0., 0., 33., 34., 35., 36., 0., 0., 45., 46., 47., 48., + 0., 0., 0., 0., 25., 26., 27., 28., 0., 0., 37., 38., 39., 40., + 0., 0., 49., 50., 51., 52., 25., 26., 27., 28., 29., 30., 37., 38., + 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., 27., 28., 29., 30., + 31., 32., 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., + 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., 53., 54., + 55., 56., 57., 58., 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., + 47., 48., 55., 56., 57., 58., 59., 60., 33., 34., 35., 36., 0., 0., + 45., 46., 47., 48., 0., 0., 57., 58., 59., 60., 0., 0., 0., 0., + 37., 38., 39., 40., 0., 0., 49., 50., 51., 52., 0., 0., 61., 62., + 63., 64., 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., + 61., 62., 63., 64., 65., 66., 39., 40., 41., 42., 43., 44., 51., 52., + 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., 41., 42., 43., 44., + 45., 46., 53., 54., 55., 56., 57., 58., 65., 66., 67., 68., 69., 70., + 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., 67., 68., + 69., 70., 71., 72., 45., 46., 47., 48., 0., 0., 57., 58., 59., 60., + 0., 0., 69., 70., 71., 72., 0., 0., 0., 0., 49., 50., 51., 52., + 0., 0., 61., 62., 63., 64., 0., 0., 0., 0., 0., 0., 49., 50., + 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., 0., 0., 0., 0., + 0., 0., 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., + 0., 0., 0., 0., 0., 0., 53., 54., 55., 56., 57., 58., 65., 66., + 67., 68., 69., 70., 0., 0., 0., 0., 0., 0., 55., 56., 57., 58., + 59., 60., 67., 68., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., + 57., 58., 59., 60., 0., 0., 69., 70., 71., 72., 0., 0., 0., 0., + 0., 0., 0., 0.}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate( + {&x}, {}, + {3, 3, 1, 1, 1, 1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("OutputSame"); + // exp.printBuffer("ExpectSame"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_9_1) { - auto x = NDArrayFactory::create('c', {1, 4, 4, 2}, {1, 116, 2, 116, 3, 116, 4, 116, - 5, 117, 6, 117, 7, 117, 8, 117, - 9, 118, 10, 118, 11, 118, 12, 118, - 13, 119, 14, 119, 15, 119, 16, 119}); - //x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 4, 4, 8}, { - 1, 116, 2, 116, 5, 117, 6, 117, 2, 116, 3, 116, 6, 117, 7, 117, 3, 116, - 4, 116, 7, 117, 8, 117, 4, 116, 0, 0, 8, 117, 0, 0, 5, 117, 6, 117, - 9, 118, 10, 118, 6, 117, 7, 117, 10, 118, 11, 118, 7, 117, 8, 117, 11, 118, -12, 118, 8, 117, 0, 0, 12, 118, 0, 0, 9, 118, 10, 118, 13, 119, 14, 119, -10, 118, 11, 118, 14, 119, 15, 119, 11, 118, 12, 118, 15, 119, 16, 119, 12, 118, - 0, 0, 16, 119, 0, 0, 13, 119, 14, 119, 0, 0, 0, 0, 14, 119, 15, 119, - 0, 0, 0, 0, 15, 119, 16, 119, 0, 0, 0, 0, 16, 119, 0, 0, 0, 0, - 0, 0 - - }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("OutputSame"); -// exp.printBuffer("ExpectSame"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create( + 'c', {1, 4, 4, 2}, + {1, 116, 2, 116, 3, 116, 4, 116, 5, 117, 6, 117, 7, 117, 8, 117, + 9, 118, 10, 118, 11, 118, 12, 118, 13, 119, 14, 119, 15, 119, 16, 119}); + // x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 4, 4, 8}, + {1, 116, 2, 116, 5, 117, 6, 117, 2, 116, 3, 116, 6, + 117, 7, 117, 3, 116, 4, 116, 7, 117, 8, 117, 4, 116, + 0, 0, 8, 117, 0, 0, 5, 117, 6, 117, 9, 118, 10, + 118, 6, 117, 7, 117, 10, 118, 11, 118, 7, 117, 8, 117, + 11, 118, 12, 118, 8, 117, 0, 0, 12, 118, 0, 0, 9, + 118, 10, 118, 13, 119, 14, 119, 10, 118, 11, 118, 14, 119, + 15, 119, 11, 118, 12, 118, 15, 119, 16, 119, 12, 118, 0, + 0, 16, 119, 0, 0, 13, 119, 14, 119, 0, 0, 0, 0, + 14, 119, 15, 119, 0, 0, 0, 0, 15, 119, 16, 119, 0, + 0, 0, 0, 16, 119, 0, 0, 0, 0, 0, 0 + + }); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 1, 1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("OutputSame"); + // exp.printBuffer("ExpectSame"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } // // //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_10) { - auto x = NDArrayFactory::create('c', {1, 6, 6, 2}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 4, 4, 18}, { - 1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., - 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., - 5., 6., 7., 8., 9., 10., 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., - 7., 8., 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., - 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., - 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., - 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., - 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., - 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., - 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., - 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., - 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., - 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., - 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., - 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., 65., 66., 67., 68., 69., 70., - 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., 67., 68., 69., 70., 71., 72.}); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - //x.printIndexedBuffer("Images"); - //x.printBuffer("Images linear"); - auto result = op.evaluate({&x}, {}, {3,3, 1,1, 1,1, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("OutputValid"); -// exp.printBuffer("ExpectValid"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 6, 6, 2}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 4, 4, 18}, + {1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., 25., 26., + 27., 28., 29., 30., 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., + 19., 20., 27., 28., 29., 30., 31., 32., 5., 6., 7., 8., 9., 10., + 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., 7., 8., + 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., + 35., 36., 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., + 37., 38., 39., 40., 41., 42., 15., 16., 17., 18., 19., 20., 27., 28., + 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., 17., 18., 19., 20., + 21., 22., 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., + 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., 43., 44., + 45., 46., 47., 48., 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., + 41., 42., 49., 50., 51., 52., 53., 54., 27., 28., 29., 30., 31., 32., + 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., 29., 30., + 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., + 57., 58., 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., + 55., 56., 57., 58., 59., 60., 37., 38., 39., 40., 41., 42., 49., 50., + 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., 39., 40., 41., 42., + 43., 44., 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., + 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., 65., 66., + 67., 68., 69., 70., 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., + 59., 60., 67., 68., 69., 70., 71., 72.}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + // x.printIndexedBuffer("Images"); + // x.printBuffer("Images linear"); + auto result = op.evaluate( + {&x}, {}, + {3, 3, 1, 1, 1, 1, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="VALID" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("OutputValid"); + // exp.printBuffer("ExpectValid"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_010) { - auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 3, 3, 4}, { - 1, 2, 5, 6, 2, 3, 6, 7, 3, 4, 7, 8, 5, 6, 9, 10, 6, 7, 10, 11, 7, 8, 11, 12, - 9, 10, 13, 14, 10, 11, 14, 15, 11, 12, 15, 16}); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - //x.printIndexedBuffer("Images"); - //x.printBuffer("Images linear"); - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("OutputValid"); -// exp.printBuffer("ExpectValid"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 3, 3, 4}, + {1, 2, 5, 6, 2, 3, 6, 7, 3, 4, 7, 8, 5, 6, 9, 10, 6, 7, + 10, 11, 7, 8, 11, 12, 9, 10, 13, 14, 10, 11, 14, 15, 11, 12, 15, 16}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + // x.printIndexedBuffer("Images"); + // x.printBuffer("Images linear"); + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 1, 1, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="VALID" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("OutputValid"); + // exp.printBuffer("ExpectValid"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_010_1) { - auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 4, 4, 4}, { - 1, 2, 5, 6, 2, 3, 6, 7, 3, 4, 7, 8, 4, 0, 8, 0, 5, 6, 9, 10, 6, 7, 10, 11, - 7, 8, 11, 12, 8, 0, 12, 0, 9, 10, 13, 14, 10, 11, 14, 15, 11, 12, 15, 16, 12, 0, 16, 0, - 13, 14, 0, 0, 14, 15, 0, 0, 15, 16, 0, 0, 16, 0, 0, 0}); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - //x.printIndexedBuffer("Images"); - //x.printBuffer("Images linear"); - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("OutputSame"); -// exp.printBuffer("ExpectSame"); -// exp.printIndexedBuffer("Expect Same Formatted"); -// output->printIndexedBuffer("Output Same Formatted"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 4, 4, 4}, + {1, 2, 5, 6, 2, 3, 6, 7, 3, 4, 7, 8, 4, 0, 8, 0, + 5, 6, 9, 10, 6, 7, 10, 11, 7, 8, 11, 12, 8, 0, 12, 0, + 9, 10, 13, 14, 10, 11, 14, 15, 11, 12, 15, 16, 12, 0, 16, 0, + 13, 14, 0, 0, 14, 15, 0, 0, 15, 16, 0, 0, 16, 0, 0, 0}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + // x.printIndexedBuffer("Images"); + // x.printBuffer("Images linear"); + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 1, 1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="VALID" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("OutputSame"); + // exp.printBuffer("ExpectSame"); + // exp.printIndexedBuffer("Expect Same Formatted"); + // output->printIndexedBuffer("Output Same Formatted"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_011) { - auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 2, 2, 4}, { - 1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16, - }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - //x.printIndexedBuffer("Images"); - //x.printBuffer("Images linear"); - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 2,2, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("OutputValid"); -// exp.printBuffer("ExpectValid"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create('c', {1, 2, 2, 4}, + { + 1, + 3, + 9, + 11, + 2, + 4, + 10, + 12, + 5, + 7, + 13, + 15, + 6, + 8, + 14, + 16, + }); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + // x.printIndexedBuffer("Images"); + // x.printBuffer("Images linear"); + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 2, 2, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="VALID" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("OutputValid"); + // exp.printBuffer("ExpectValid"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_11) { - auto x = NDArrayFactory::create('c', {1, 8, 8, 2}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 4, 4, 8}, { - 1, 2, 3, 4, 17, 18, 19, 20, 5, 6, 7, 8, 21, 22, 23, 24, 9, 10, - 11, 12, 25, 26, 27, 28, 13, 14, 15, 16, 29, 30, 31, 32, 33, 34, 35, 36, - 49, 50, 51, 52, 37, 38, 39, 40, 53, 54, 55, 56, 41, 42, 43, 44, 57, 58, - 59, 60, 45, 46, 47, 48, 61, 62, 63, 64, 65, 66, 67, 68, 81, 82, 83, 84, - 69, 70, 71, 72, 85, 86, 87, 88, 73, 74, 75, 76, 89, 90, 91, 92, 77, 78, - 79, 80, 93, 94, 95, 96, 97, 98, 99, 100, 113, 114, 115, 116, 101, 102, 103, 104, - 117, 118, 119, 120, 105, 106, 107, 108, 121, 122, 123, 124, 109, 110, 111, 112, 125, 126, - 127, 128}); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,2, 2,2, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); -// output->printBuffer("Output"); -// exp.printBuffer("Expect"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 8, 8, 2}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 4, 4, 8}, + {1, 2, 3, 4, 17, 18, 19, 20, 5, 6, 7, 8, 21, + 22, 23, 24, 9, 10, 11, 12, 25, 26, 27, 28, 13, 14, + 15, 16, 29, 30, 31, 32, 33, 34, 35, 36, 49, 50, 51, + 52, 37, 38, 39, 40, 53, 54, 55, 56, 41, 42, 43, 44, + 57, 58, 59, 60, 45, 46, 47, 48, 61, 62, 63, 64, 65, + 66, 67, 68, 81, 82, 83, 84, 69, 70, 71, 72, 85, 86, + 87, 88, 73, 74, 75, 76, 89, 90, 91, 92, 77, 78, 79, + 80, 93, 94, 95, 96, 97, 98, 99, 100, 113, 114, 115, 116, + 101, 102, 103, 104, 117, 118, 119, 120, 105, 106, 107, 108, 121, + 122, 123, 124, 109, 110, 111, 112, 125, 126, 127, 128}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate( + {&x}, {}, + {2, 2, 2, 2, 1, 1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("Output"); + // exp.printBuffer("Expect"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_12) { - auto x = NDArrayFactory::create('c', {1, 8, 8, 2}); - x.linspace(1); - -//Images shape is (1, 3, 3, 4) -//[1, 1, 1, 1] -//[1, 3, 2, 1] - auto exp = NDArrayFactory::create('c', {1, 8, 8, 8}, { - 0, 0, 0, 0, 0, 0, 19, 20, 0, 0, 0, 0, 17, 18, 21, 22, 0, 0, - 0, 0, 19, 20, 23, 24, 0, 0, 0, 0, 21, 22, 25, 26, 0, 0, 0, 0, - 23, 24, 27, 28, 0, 0, 0, 0, 25, 26, 29, 30, 0, 0, 0, 0, 27, 28, - 31, 32, 0, 0, 0, 0, 29, 30, 0, 0, 0, 0, 3, 4, 0, 0, 35, 36, - 1, 2, 5, 6, 33, 34, 37, 38, 3, 4, 7, 8, 35, 36, 39, 40, 5, 6, - 9, 10, 37, 38, 41, 42, 7, 8, 11, 12, 39, 40, 43, 44, 9, 10, 13, 14, - 41, 42, 45, 46, 11, 12, 15, 16, 43, 44, 47, 48, 13, 14, 0, 0, 45, 46, - 0, 0, 0, 0, 19, 20, 0, 0, 51, 52, 17, 18, 21, 22, 49, 50, 53, 54, - 19, 20, 23, 24, 51, 52, 55, 56, 21, 22, 25, 26, 53, 54, 57, 58, 23, 24, - 27, 28, 55, 56, 59, 60, 25, 26, 29, 30, 57, 58, 61, 62, 27, 28, 31, 32, - 59, 60, 63, 64, 29, 30, 0, 0, 61, 62, 0, 0, 0, 0, 35, 36, 0, 0, - 67, 68, 33, 34, 37, 38, 65, 66, 69, 70, 35, 36, 39, 40, 67, 68, 71, 72, - 37, 38, 41, 42, 69, 70, 73, 74, 39, 40, 43, 44, 71, 72, 75, 76, 41, 42, - 45, 46, 73, 74, 77, 78, 43, 44, 47, 48, 75, 76, 79, 80, 45, 46, 0, 0, - 77, 78, 0, 0, 0, 0, 51, 52, 0, 0, 83, 84, 49, 50, 53, 54, 81, 82, - 85, 86, 51, 52, 55, 56, 83, 84, 87, 88, 53, 54, 57, 58, 85, 86, 89, 90, - 55, 56, 59, 60, 87, 88, 91, 92, 57, 58, 61, 62, 89, 90, 93, 94, 59, 60, - 63, 64, 91, 92, 95, 96, 61, 62, 0, 0, 93, 94, 0, 0, 0, 0, 67, 68, - 0, 0, 99, 100, 65, 66, 69, 70, 97, 98, 101, 102, 67, 68, 71, 72, 99, 100, - 103, 104, 69, 70, 73, 74, 101, 102, 105, 106, 71, 72, 75, 76, 103, 104, 107, 108, - 73, 74, 77, 78, 105, 106, 109, 110, 75, 76, 79, 80, 107, 108, 111, 112, 77, 78, - 0, 0, 109, 110, 0, 0, 0, 0, 83, 84, 0, 0, 115, 116, 81, 82, 85, 86, - 113, 114, 117, 118, 83, 84, 87, 88, 115, 116, 119, 120, 85, 86, 89, 90, 117, 118, - 121, 122, 87, 88, 91, 92, 119, 120, 123, 124, 89, 90, 93, 94, 121, 122, 125, 126, - 91, 92, 95, 96, 123, 124, 127, 128, 93, 94, 0, 0, 125, 126, 0, 0, 0, 0, - 99, 100, 0, 0, 0, 0, 97, 98, 101, 102, 0, 0, 0, 0, 99, 100, 103, 104, - 0, 0, 0, 0, 101, 102, 105, 106, 0, 0, 0, 0, 103, 104, 107, 108, 0, 0, - 0, 0, 105, 106, 109, 110, 0, 0, 0, 0, 107, 108, 111, 112, 0, 0, 0, 0, - 109, 110, 0, 0, 0, 0, 0, 0}); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; - - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 2,2, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,2,2,1], padding="SAME" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); - //output->printShapeInfo("Output shape"); -// output->printIndexedBuffer("Output"); -// exp.printBuffer("Expect"); -// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) -// if (exp.e(e) != output->e(e)) -// printf("%lld ", e); -// printf("\n"); - //result.at(1)->printBuffer("OUtput2"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {1, 8, 8, 2}); + x.linspace(1); + + // Images shape is (1, 3, 3, 4) + //[1, 1, 1, 1] + //[1, 3, 2, 1] + auto exp = NDArrayFactory::create( + 'c', {1, 8, 8, 8}, + {0, 0, 0, 0, 0, 0, 19, 20, 0, 0, 0, 0, 17, 18, + 21, 22, 0, 0, 0, 0, 19, 20, 23, 24, 0, 0, 0, 0, + 21, 22, 25, 26, 0, 0, 0, 0, 23, 24, 27, 28, 0, 0, + 0, 0, 25, 26, 29, 30, 0, 0, 0, 0, 27, 28, 31, 32, + 0, 0, 0, 0, 29, 30, 0, 0, 0, 0, 3, 4, 0, 0, + 35, 36, 1, 2, 5, 6, 33, 34, 37, 38, 3, 4, 7, 8, + 35, 36, 39, 40, 5, 6, 9, 10, 37, 38, 41, 42, 7, 8, + 11, 12, 39, 40, 43, 44, 9, 10, 13, 14, 41, 42, 45, 46, + 11, 12, 15, 16, 43, 44, 47, 48, 13, 14, 0, 0, 45, 46, + 0, 0, 0, 0, 19, 20, 0, 0, 51, 52, 17, 18, 21, 22, + 49, 50, 53, 54, 19, 20, 23, 24, 51, 52, 55, 56, 21, 22, + 25, 26, 53, 54, 57, 58, 23, 24, 27, 28, 55, 56, 59, 60, + 25, 26, 29, 30, 57, 58, 61, 62, 27, 28, 31, 32, 59, 60, + 63, 64, 29, 30, 0, 0, 61, 62, 0, 0, 0, 0, 35, 36, + 0, 0, 67, 68, 33, 34, 37, 38, 65, 66, 69, 70, 35, 36, + 39, 40, 67, 68, 71, 72, 37, 38, 41, 42, 69, 70, 73, 74, + 39, 40, 43, 44, 71, 72, 75, 76, 41, 42, 45, 46, 73, 74, + 77, 78, 43, 44, 47, 48, 75, 76, 79, 80, 45, 46, 0, 0, + 77, 78, 0, 0, 0, 0, 51, 52, 0, 0, 83, 84, 49, 50, + 53, 54, 81, 82, 85, 86, 51, 52, 55, 56, 83, 84, 87, 88, + 53, 54, 57, 58, 85, 86, 89, 90, 55, 56, 59, 60, 87, 88, + 91, 92, 57, 58, 61, 62, 89, 90, 93, 94, 59, 60, 63, 64, + 91, 92, 95, 96, 61, 62, 0, 0, 93, 94, 0, 0, 0, 0, + 67, 68, 0, 0, 99, 100, 65, 66, 69, 70, 97, 98, 101, 102, + 67, 68, 71, 72, 99, 100, 103, 104, 69, 70, 73, 74, 101, 102, + 105, 106, 71, 72, 75, 76, 103, 104, 107, 108, 73, 74, 77, 78, + 105, 106, 109, 110, 75, 76, 79, 80, 107, 108, 111, 112, 77, 78, + 0, 0, 109, 110, 0, 0, 0, 0, 83, 84, 0, 0, 115, 116, + 81, 82, 85, 86, 113, 114, 117, 118, 83, 84, 87, 88, 115, 116, + 119, 120, 85, 86, 89, 90, 117, 118, 121, 122, 87, 88, 91, 92, + 119, 120, 123, 124, 89, 90, 93, 94, 121, 122, 125, 126, 91, 92, + 95, 96, 123, 124, 127, 128, 93, 94, 0, 0, 125, 126, 0, 0, + 0, 0, 99, 100, 0, 0, 0, 0, 97, 98, 101, 102, 0, 0, + 0, 0, 99, 100, 103, 104, 0, 0, 0, 0, 101, 102, 105, 106, + 0, 0, 0, 0, 103, 104, 107, 108, 0, 0, 0, 0, 105, 106, + 109, 110, 0, 0, 0, 0, 107, 108, 111, 112, 0, 0, 0, 0, + 109, 110, 0, 0, 0, 0, 0, 0}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 2, 2, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,2,2,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printShapeInfo("Output shape"); + // output->printIndexedBuffer("Output"); + // exp.printBuffer("Expect"); + // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) + // if (exp.e(e) != output->e(e)) + // printf("%lld ", e); + // printf("\n"); + // result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_13) { - auto x = NDArrayFactory::create('c', {1, 3, 3, 2}); - x.linspace(1); - - auto exp = NDArrayFactory::create('c', {1, 3, 3, 8}, { - 1., 2., 3., 4., 7., 8., 9., 10., 3., 4., 5., 6., 9., 10., 11., 12., 5., 6., - 0., 0., 11., 12., 0., 0., 7., 8., 9., 10., 13., 14., 15., 16., 9., 10., 11., 12., - 15., 16., 17., 18., 11., 12., 0., 0., 17., 18., 0., 0., 13., 14., 15., 16., 0., 0., - 0., 0., 15., 16., 17., 18., 0., 0., 0., 0., 17., 18., 0., 0., 0., 0., 0., 0. }); -// ---------------------------------------------------------------- - sd::ops::extract_image_patches op; + auto x = NDArrayFactory::create('c', {1, 3, 3, 2}); + x.linspace(1); - auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 3, 8}, + {1., 2., 3., 4., 7., 8., 9., 10., 3., 4., 5., 6., 9., 10., 11., + 12., 5., 6., 0., 0., 11., 12., 0., 0., 7., 8., 9., 10., 13., 14., + 15., 16., 9., 10., 11., 12., 15., 16., 17., 18., 11., 12., 0., 0., 17., + 18., 0., 0., 13., 14., 15., 16., 0., 0., 0., 0., 15., 16., 17., 18., + 0., 0., 0., 0., 17., 18., 0., 0., 0., 0., 0., 0.}); + // ---------------------------------------------------------------- + sd::ops::extract_image_patches op; - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto result = op.evaluate( + {&x}, {}, + {2, 2, 1, 1, 1, 1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], + // rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_1) { - auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, - 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 -}); - -auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, - 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, - 21.41, 21.42, 22.11, 22.12 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto x = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}); - auto result = op.evaluate({&x}, {}, {6}); - ASSERT_EQ(result.status(), Status::OK()); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, + 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, + 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, + 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12}); + // ---------------------------------------------------------------- + sd::ops::roll op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x}, {}, {6}); + ASSERT_EQ(result.status(), Status::OK()); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_2) { - auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, - 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 -}); + auto x = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}); -auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, - 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto exp = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, + 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42}); + // ---------------------------------------------------------------- + sd::ops::roll op; - auto result = op.evaluate({&x}, {}, {-8}); - ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + auto result = op.evaluate({&x}, {}, {-8}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_3) { - auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, - 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 -}); - -auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, - 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto x = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}); - auto result = op.evaluate({&x}, {}, {-40}); - ASSERT_EQ(result.status(), Status::OK()); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, + 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42}); + // ---------------------------------------------------------------- + sd::ops::roll op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x}, {}, {-40}); + ASSERT_EQ(result.status(), Status::OK()); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_4) { - auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, - 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 -}); - -auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, - 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, - 21.41, 21.42, 22.11, 22.12 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto x = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}); - auto result = op.evaluate({&x}, {}, {38}); - ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output 4"); - //exp.printIndexedBuffer("Expect 4"); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, + 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, + 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, + 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12}); + // ---------------------------------------------------------------- + sd::ops::roll op; - ASSERT_TRUE(exp.equalsTo(result.at(0))); + auto result = op.evaluate({&x}, {}, {38}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output 4"); + // exp.printIndexedBuffer("Expect 4"); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_4_inplace) { - auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, - 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 -}); + auto x = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}); -auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { - 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, - 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, - 21.41, 21.42, 22.11, 22.12 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; - auto result = op.evaluate({&x}, {}, {38}, {}, {}, true); - ASSERT_EQ(result.status(), Status::OK()); - //x.printIndexedBuffer("Output 4 inplace"); - //exp.printIndexedBuffer("Expect 4 inplace"); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 4, 2}, + {22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, + 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, + 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, + 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12}); + // ---------------------------------------------------------------- + sd::ops::roll op; + auto result = op.evaluate({&x}, {}, {38}, {}, {}, true); + ASSERT_EQ(result.status(), Status::OK()); + // x.printIndexedBuffer("Output 4 inplace"); + // exp.printIndexedBuffer("Expect 4 inplace"); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); -// + // } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_5) { - auto x = NDArrayFactory::create('c', {3, 4}, { - 0., 1., 2., 3., 4, 5., 6., 7., 8., 9., 10., 11. -}); + auto x = NDArrayFactory::create( + 'c', {3, 4}, {0., 1., 2., 3., 4, 5., 6., 7., 8., 9., 10., 11.}); -auto exp = NDArrayFactory::create('c', {3, 4}, { - 2., 3., 0., 1., 6., 7., 4., 5., 10., 11., 8., 9. -// 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3 -}); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + { + 2., 3., 0., 1., 6., 7., 4., 5., 10., 11., 8., 9. + // 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3 + }); + // ---------------------------------------------------------------- + sd::ops::roll op; - auto result = op.evaluate({&x}, {}, {2, 1}); - ASSERT_EQ(result.status(), Status::OK()); + auto result = op.evaluate({&x}, {}, {2, 1}); + ASSERT_EQ(result.status(), Status::OK()); - //result.at(0)->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); - ASSERT_TRUE(exp.equalsTo(result.at(0))); - - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_6) { - auto x = NDArrayFactory::create('c', {2, 3, 2}, { - 0., 1., 2., 3., 4, 5., 6., 7., 8., 9., 10., 11. -}); - -auto exp = NDArrayFactory::create('c', {2, 3, 2}, { - 1., 0., 3., 2., 5., 4., 7., 6., 9., 8., 11., 10. -}); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto x = NDArrayFactory::create( + 'c', {2, 3, 2}, {0., 1., 2., 3., 4, 5., 6., 7., 8., 9., 10., 11.}); - auto result = op.evaluate({&x}, {}, {1, 2}); - ASSERT_EQ(result.status(), Status::OK()); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 2}, {1., 0., 3., 2., 5., 4., 7., 6., 9., 8., 11., 10.}); + // ---------------------------------------------------------------- + sd::ops::roll op; - //result.at(0)->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); + auto result = op.evaluate({&x}, {}, {1, 2}); + ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_7) { - auto x = NDArrayFactory::create('c', {2, 3, 2}, { - 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11. -}); - -auto exp = NDArrayFactory::create('c', {2, 3, 2}, { - 11., 10., 7., 6., 9., 8., 5., 4., 1., 0., 3., 2. -}); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto x = NDArrayFactory::create( + 'c', {2, 3, 2}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); - auto result = op.evaluate({&x}, {}, {1, 2, 1, 0}); - ASSERT_EQ(result.status(), Status::OK()); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 2}, {11., 10., 7., 6., 9., 8., 5., 4., 1., 0., 3., 2.}); + // ---------------------------------------------------------------- + sd::ops::roll op; - //result.at(0)->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); + auto result = op.evaluate({&x}, {}, {1, 2, 1, 0}); + ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(result.at(0))); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); - + ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_8) { - auto x = NDArrayFactory::create('c', {2, 3, 2}, { - 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11. -}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 2}, {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); -auto exp = NDArrayFactory::create('c', {2, 3, 2}, { - 11., 10., 7., 6., 9., 8., 5., 4., 1., 0., 3., 2. -}); -// ---------------------------------------------------------------- - sd::ops::roll op; - NDArray* y = nullptr; - auto result = op.evaluate({&x}, {}, {1, 2, 1, 0}, {}, {}, true); - ASSERT_EQ(result.status(), Status::OK()); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 2}, {11., 10., 7., 6., 9., 8., 5., 4., 1., 0., 3., 2.}); + // ---------------------------------------------------------------- + sd::ops::roll op; + NDArray* y = nullptr; + auto result = op.evaluate({&x}, {}, {1, 2, 1, 0}, {}, {}, true); + ASSERT_EQ(result.status(), Status::OK()); - //x.printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); + // x.printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); -// + // } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_9) { - auto x = NDArrayFactory::create('c', {2, 3, 3}, { - 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17. -}); + auto x = + NDArrayFactory::create('c', {2, 3, 3}, + {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., + 10., 11., 12., 13., 14., 15., 16., 17.}); -auto exp = NDArrayFactory::create('c', {2, 3, 3}, { - 6., 7., 8., 0., 1., 2., 3., 4., 5., 15., 16., 17., 9., 10., 11., 12., 13., 14. -}); -// ---------------------------------------------------------------- - sd::ops::roll op; - auto result = op.evaluate({&x}, {}, {1, 1}, {}, {}, true); - ASSERT_EQ(result.status(), Status::OK()); + auto exp = + NDArrayFactory::create('c', {2, 3, 3}, + {6., 7., 8., 0., 1., 2., 3., 4., 5., 15., + 16., 17., 9., 10., 11., 12., 13., 14.}); + // ---------------------------------------------------------------- + sd::ops::roll op; + auto result = op.evaluate({&x}, {}, {1, 1}, {}, {}, true); + ASSERT_EQ(result.status(), Status::OK()); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); -// + // } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_10) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, { - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. - }); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, { - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. - }); -// ---------------------------------------------------------------- - sd::ops::roll op; - auto result = op.evaluate({&x}, {}, {3, 1}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}); + // ---------------------------------------------------------------- + sd::ops::roll op; + auto result = op.evaluate({&x}, {}, {3, 1}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); -// out->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); + // out->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); - ASSERT_TRUE(exp.equalsTo(out)); - - + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_11) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, { - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. - }); - auto shift = NDArrayFactory::create({1,2}); - auto axis = NDArrayFactory::create({0, 1}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, { - 17., 18., 19., 20., 21., 22., 23., 24., 13., 14., 15., 16., 5., 6., 7, 8, 9, 10, 11, 12, 1, 2, 3, 4 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; - NDArray* y = nullptr; - auto result = op.evaluate({&x, &shift, &axis}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); - -// out->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); - - ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}); + auto shift = NDArrayFactory::create({1, 2}); + auto axis = NDArrayFactory::create({0, 1}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {17., 18., 19., 20., 21., 22., 23., 24., 13., 14., 15., 16., + 5., 6., 7, 8, 9, 10, 11, 12, 1, 2, 3, 4}); + // ---------------------------------------------------------------- + sd::ops::roll op; + NDArray* y = nullptr; + auto result = op.evaluate({&x, &shift, &axis}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); + + // out->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_12) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, { - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. - }); - auto shift = NDArrayFactory::create({1,1,1}); - auto axis = NDArrayFactory::create({0, 1, 2}); - - auto exp = NDArrayFactory::create('c', {2, 3, 4}, { - 24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; - NDArray* y = nullptr; - auto result = op.evaluate({&x, &shift, &axis}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}); + auto shift = NDArrayFactory::create({1, 1, 1}); + auto axis = NDArrayFactory::create({0, 1, 2}); - ASSERT_TRUE(exp.equalsTo(out)); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, + 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7}); + // ---------------------------------------------------------------- + sd::ops::roll op; + NDArray* y = nullptr; + auto result = op.evaluate({&x, &shift, &axis}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); - + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_13) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, { - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. - }); - auto shift = NDArrayFactory::create(3); - auto axis = NDArrayFactory::create(2); - - auto exp = NDArrayFactory::create('c', {2, 3, 4}, { - 2,3,4,1,6,7,8,5,10,11,12,9,14, 15, 16, 13, 18, 19, 20, 17, 22, 23, 24, 21 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; - NDArray* y = nullptr; - auto result = op.evaluate({&x}, {}, {3,2}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}); + auto shift = NDArrayFactory::create(3); + auto axis = NDArrayFactory::create(2); - ASSERT_TRUE(exp.equalsTo(out)); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {2, 3, 4, 1, 6, 7, 8, 5, 10, 11, 12, 9, + 14, 15, 16, 13, 18, 19, 20, 17, 22, 23, 24, 21}); + // ---------------------------------------------------------------- + sd::ops::roll op; + NDArray* y = nullptr; + auto result = op.evaluate({&x}, {}, {3, 2}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); - + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_14) { - auto x = NDArrayFactory::create('c', {2, 3, 4}, { - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. - }); - auto shift = NDArrayFactory::create({1,1,1}); - auto axis = NDArrayFactory::create({0, 1, 2}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., + 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.}); + auto shift = NDArrayFactory::create({1, 1, 1}); + auto axis = NDArrayFactory::create({0, 1, 2}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, { - 24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7 - }); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, + 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7}); + // ---------------------------------------------------------------- + sd::ops::roll op; - auto result = op.evaluate({&x, &shift, &axis}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); -// out->printIndexedBuffer("Output"); - //exp.printIndexedBuffer("Expect"); + auto result = op.evaluate({&x, &shift, &axis}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); + // out->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); - ASSERT_TRUE(exp.equalsTo(out)); - - + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_15) { - auto x = NDArrayFactory::create({0.7788f, 0.8012f, 0.7244f, 0.2309f }); - auto shift = NDArrayFactory::create(2); - auto axis = NDArrayFactory::create(0); - - auto exp = NDArrayFactory::create({0.7244f, 0.2309f, 0.7788f, 0.8012f }); -// ---------------------------------------------------------------- - sd::ops::roll op; + auto x = NDArrayFactory::create({0.7788f, 0.8012f, 0.7244f, 0.2309f}); + auto shift = NDArrayFactory::create(2); + auto axis = NDArrayFactory::create(0); - auto result = op.evaluate({&x, &shift, &axis}); - ASSERT_EQ(result.status(), Status::OK()); - auto out = result.at(0); -// out->printIndexedBuffer("Output 15"); -// exp.printIndexedBuffer("Expect 15"); + auto exp = + NDArrayFactory::create({0.7244f, 0.2309f, 0.7788f, 0.8012f}); + // ---------------------------------------------------------------- + sd::ops::roll op; - ASSERT_TRUE(exp.equalsTo(out)); + auto result = op.evaluate({&x, &shift, &axis}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); + // out->printIndexedBuffer("Output 15"); + // exp.printIndexedBuffer("Expect 15"); - + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test1) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto expected = NDArrayFactory::create(50.); - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto expected = NDArrayFactory::create(50.); + sd::ops::percentile op; - sd::ops::percentile op; + auto result = op.evaluate({&input}, {50.}, {}); + auto output = result.at(0); - auto result = op.evaluate({&input}, {50.}, {}); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test2) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto expected = NDArrayFactory::create('c', {1, 1, 1}, {11.}); - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto expected = NDArrayFactory::create('c', {1,1,1}, {11.}); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 1}, {}); + auto output = result.at(0); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 2, 1}, {}); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test3) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto expected = NDArrayFactory::create('c', {1, 1, 1}, {10.}); - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto expected = NDArrayFactory::create('c', {1,1,1}, {10.}); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 0, 1}, {}); + auto output = result.at(0); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 0, 1}, {}); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test4) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto expected = NDArrayFactory::create('c', {1, 1, 1}, {11.}); - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto expected = NDArrayFactory::create('c', {1,1,1}, {11.}); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 1, 1}, {}); + auto output = result.at(0); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 1, 1}, {}); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test5) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; - - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto expected = NDArrayFactory::create('c', {1,1,4}, {12., 7., 11., 10.}); + auto expected = + NDArrayFactory::create('c', {1, 1, 4}, {12., 7., 11., 10.}); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 0, 1}, {0,1}); - auto output = result.at(0); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 0, 1}, {0, 1}); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test6) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; - - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - - auto expected = NDArrayFactory::create('c', {1,1,4}, {16., 14., 15., 13.}); + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 1, 1}, {0,1}); - auto output = result.at(0); + auto expected = + NDArrayFactory::create('c', {1, 1, 4}, {16., 14., 15., 13.}); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 1, 1}, {0, 1}); + auto output = result.at(0); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test7) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto expected = + NDArrayFactory::create('c', {1, 1, 4}, {12., 7., 11., 10.}); - auto expected = NDArrayFactory::create('c', {1,1,4}, {12., 7., 11., 10.}); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 1}, {0, 1}); + auto output = result.at(0); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 2, 1}, {0,1}); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test8) { + const int dim0 = 5, dim1 = 5, dim2 = 4; - const int dim0=5, dim1=5, dim2=4; - - auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto input = NDArrayFactory::create( + 'c', {dim0, dim1, dim2}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - auto expected = NDArrayFactory::create('c', {4}, {12., 7., 11., 10.}); + auto expected = NDArrayFactory::create('c', {4}, {12., 7., 11., 10.}); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 2, 0}, {0,1}); - auto output = result.at(0); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 0}, {0, 1}); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test9) { + const int dim0 = 100; - const int dim0=100; - - auto input = NDArrayFactory::create('c', {dim0}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - - auto expected = NDArrayFactory::create(11.); + auto input = NDArrayFactory::create( + 'c', {dim0}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 2, 0}, {0}); - auto output = result.at(0); + auto expected = NDArrayFactory::create(11.); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 0}, {0}); + auto output = result.at(0); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test10) { + const int dim0 = 100; - const int dim0=100; - - auto input = NDArrayFactory::create('c', {dim0}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., - 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., - 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., - 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., - 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - - auto expected = NDArrayFactory::create('c', {1}, {11.}); + auto input = NDArrayFactory::create( + 'c', {dim0}, + {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., + 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., + 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., + 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., + 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., + 97., 100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., + 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 2, 1}, {0}); - auto output = result.at(0); + auto expected = NDArrayFactory::create('c', {1}, {11.}); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 1}, {0}); + auto output = result.at(0); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test11) { + const int dim0 = 1; - const int dim0=1; + auto input = NDArrayFactory::create('c', {dim0}, {100.}); - auto input = NDArrayFactory::create('c', {dim0}, {100.}); + auto expected = NDArrayFactory::create('c', {1}, {100.}); - auto expected = NDArrayFactory::create('c', {1}, {100.}); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 1}, {0}); + auto output = result.at(0); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 2, 1}, {0}); - auto output = result.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test12) { + const int dim0 = 1; - const int dim0=1; - - auto input = NDArrayFactory::create('c', {dim0}, {100.}); + auto input = NDArrayFactory::create('c', {dim0}, {100.}); - auto expected = NDArrayFactory::create(100.); + auto expected = NDArrayFactory::create(100.); - sd::ops::percentile op; - //q, interpolation, keepDims - auto result = op.evaluate({&input}, {10, 2, 0}, {}); - auto output = result.at(0); + sd::ops::percentile op; + // q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 0}, {}); + auto output = result.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, transpose_test3) { + auto input = + NDArrayFactory::create('c', {5, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + auto exp = + NDArrayFactory::create('c', {3, 5}, + {1.f, 4.f, 7.f, 10.f, 13.f, 2.f, 5.f, 8.f, + 11.f, 14.f, 3.f, 6.f, 9.f, 12.f, 15.f}); - auto input = NDArrayFactory::create('c', {5, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); - auto exp = NDArrayFactory::create('c', {3, 5}, {1.f, 4.f, 7.f, 10.f, 13.f, 2.f, 5.f, 8.f, 11.f, 14.f, 3.f, 6.f, 9.f, 12.f, 15.f}); - - sd::ops::transpose op; - auto result = op.evaluate({&input}, {}, {}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::transpose op; + auto result = op.evaluate({&input}, {}, {}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, rationaltanh_test1) { + auto input = + NDArrayFactory::create('c', {8}, {0, 1, 2, 3, 4, 5, 6, 7}); + NDArray exp = + NDArrayFactory::create({0.000000, 0.998222, 1.516093, 1.658054, + 1.695077, 1.706884, 1.711427, 1.713446}); - auto input = NDArrayFactory::create('c', {8}, {0, 1, 2, 3, 4, 5, 6, 7}); - NDArray exp = NDArrayFactory::create({0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446}); - - sd::ops::rationaltanh op; - auto result = op.evaluate({&input}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Output rationaltanh"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + sd::ops::rationaltanh op; + auto result = op.evaluate({&input}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Output rationaltanh"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, rationaltanh_test2) { + auto input = + NDArrayFactory::create('c', {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}); + NDArray exp = + NDArrayFactory::create('c', {2, 2, 2}, + {0.000000, 0.998222, 1.516093, 1.658054, + 1.695077, 1.706884, 1.711427, 1.713446}); - auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); - NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446}); - - sd::ops::rationaltanh op; - auto result = op.evaluate({&input}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Output rationaltanh"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + sd::ops::rationaltanh op; + auto result = op.evaluate({&input}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Output rationaltanh"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, rationaltanh_test3) { + auto input = + NDArrayFactory::create('c', {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}); + auto eps = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray exp = + NDArrayFactory::create('c', {2, 2, 2}, + {1.143933, 1.605747, 0.795557, 0.261710, + 0.095832, 0.041218, 0.020221, 0.010971}); - auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); - auto eps = NDArrayFactory::create('c', {2,2,2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray exp = NDArrayFactory::create('c', {2,2,2}, {1.143933, 1.605747, 0.795557, 0.261710, 0.095832, 0.041218, 0.020221, 0.010971}); - - sd::ops::rationaltanh_bp op; - auto result = op.evaluate({&input, &eps}, {}, {}); - auto output = result.at(0); -// output->printBuffer("Output rationaltanh BP"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + sd::ops::rationaltanh_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); + auto output = result.at(0); + // output->printBuffer("Output rationaltanh BP"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, rectifiedtanh_test1) { + auto input = + NDArrayFactory::create('c', {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}); + NDArray exp = + NDArrayFactory::create('c', {2, 2, 2}, + {0.000000, 0.761594, 0.964028, 0.995055, + 0.999329, 0.999909, 0.999988, 0.999998}); - auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); - NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.761594, 0.964028, 0.995055, 0.999329, 0.999909, 0.999988, 0.999998}); - - sd::ops::rectifiedtanh op; - auto result = op.evaluate({&input}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Output rectifiedtanh"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + sd::ops::rectifiedtanh op; + auto result = op.evaluate({&input}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Output rectifiedtanh"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, rectifiedtanh_test2) { - - auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); - auto eps = NDArrayFactory::create('c', {2,2,2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.839949, 0.211952, 0.039464, 0.006705, 0.001089, 0.000172, 0.000027}); - - sd::ops::rectifiedtanh_bp op; - auto result = op.evaluate({&input, &eps}, {}, {}); - auto output = result.at(0); -// output->printBuffer("Output rectifiedtanh BP"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto input = + NDArrayFactory::create('c', {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}); + auto eps = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray exp = + NDArrayFactory::create('c', {2, 2, 2}, + {0.000000, 0.839949, 0.211952, 0.039464, + 0.006705, 0.001089, 0.000172, 0.000027}); + + sd::ops::rectifiedtanh_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); + auto output = result.at(0); + // output->printBuffer("Output rectifiedtanh BP"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests7, RealDiv_1) { + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f, 2.f}); + NDArray e = + NDArrayFactory::create('c', {1, 2, 2}, {2.f, 1.f, 4.f, 2.f}); - NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); - NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f,2.f}); - NDArray e = NDArrayFactory::create('c', {1, 2, 2}, {2.f, 1.f, 4.f, 2.f}); - - sd::ops::realdiv op; - auto result = op.evaluate({&x, &y}, {}, {}); + sd::ops::realdiv op; + auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("OUtput RealDiv"); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); - - + auto z = result.at(0); + // z->printIndexedBuffer("OUtput RealDiv"); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, RealDiv_BP_1) { + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f, 2.f}); + NDArray e0 = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 5.f}); + NDArray e1 = NDArrayFactory::create('c', {1, 2}, {-14.f, -5.f}); + NDArray eps = + NDArrayFactory::create('c', {1, 2, 2}, {1.f, 2.f, 3.f, 4.f}); - NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); - NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f, 2.f}); - NDArray e0 = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 5.f}); - NDArray e1 = NDArrayFactory::create('c', {1, 2}, {-14.f, -5.f}); - NDArray eps = NDArrayFactory::create('c', {1, 2, 2}, {1.f, 2.f, 3.f, 4.f}); - - sd::ops::realdiv_bp op; - auto result = op.evaluate({&x, &y, &eps}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::realdiv_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {}); - auto z0 = result.at(0); - auto z1 = result.at(1); -// z0->printShapeInfo("OUtput RealDiv BP0 shape"); -// z1->printShapeInfo("OUtput RealDiv BP1 shape"); -// z0->printIndexedBuffer("OUtput RealDiv BP0"); -// z1->printIndexedBuffer("OUtput RealDiv BP1"); -// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e0.equalsTo(z0)); - ASSERT_TRUE(e1.equalsTo(z1)); + ASSERT_EQ(Status::OK(), result.status()); - + auto z0 = result.at(0); + auto z1 = result.at(1); + // z0->printShapeInfo("OUtput RealDiv BP0 shape"); + // z1->printShapeInfo("OUtput RealDiv BP1 shape"); + // z0->printIndexedBuffer("OUtput RealDiv BP0"); + // z1->printIndexedBuffer("OUtput RealDiv BP1"); + // ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e0.equalsTo(z0)); + ASSERT_TRUE(e1.equalsTo(z1)); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, ShapesOf_1) { + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); + // NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2}); + NDArray e = NDArrayFactory::create({1, 2, 1}); - NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); -// NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2}); - NDArray e = NDArrayFactory::create({1, 2, 1}); - - sd::ops::shapes_of op; - auto result = op.evaluate({&x}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::shapes_of op; + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); -// z->printIndexedBuffer("OUtput RealDiv"); -// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); - + auto z = result.at(0); + // z->printIndexedBuffer("OUtput RealDiv"); + // ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, ShapesOf_2) { + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f, 2.f}); + NDArray e0 = NDArrayFactory::create({1, 2, 1}); + NDArray e1 = NDArrayFactory::create({1, 2}); - NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); - NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f, 2.f}); - NDArray e0 = NDArrayFactory::create({1, 2, 1}); - NDArray e1 = NDArrayFactory::create({1, 2}); + sd::ops::shapes_of op; + auto result = op.evaluate({&x, &y}, {}, {}); - sd::ops::shapes_of op; - auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - auto z0 = result.at(0); - auto z1 = result.at(1); -// z0->printIndexedBuffer("OUtput shapes2"); -// z1->printIndexedBuffer("OUtput shapes2"); -// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e0.equalsTo(z0)); - ASSERT_TRUE(e1.equalsTo(z1)); - - + auto z0 = result.at(0); + auto z1 = result.at(1); + // z0->printIndexedBuffer("OUtput shapes2"); + // z1->printIndexedBuffer("OUtput shapes2"); + // ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e0.equalsTo(z0)); + ASSERT_TRUE(e1.equalsTo(z1)); } TEST_F(DeclarableOpsTests7, Size_1) { + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); + NDArray y = NDArrayFactory::create( + 'c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f}); + NDArray e = NDArrayFactory::create(2); - NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); - NDArray y = NDArrayFactory::create('c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f}); - NDArray e = NDArrayFactory::create(2); - - sd::ops::size op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::size op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("OUtput SIZE"); -/// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); - - + auto z = result.at(0); + // z->printIndexedBuffer("OUtput SIZE"); + /// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(DeclarableOpsTests7, Size_2) { + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); + NDArray y = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray e = NDArrayFactory::create(10); - NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); - NDArray y = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray e = NDArrayFactory::create(10); - - sd::ops::size op; - auto result = op.evaluate({&y}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::size op; + auto result = op.evaluate({&y}, {}, {}); - auto z = result.at(0); -// z->printIndexedBuffer("OUtput SIZE"); -/// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_EQ(Status::OK(), result.status()); - + auto z = result.at(0); + // z->printIndexedBuffer("OUtput SIZE"); + /// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(DeclarableOpsTests7, Softplus_1) { + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray e = NDArrayFactory::create( + 'c', {5, 2}, + {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, + 10.000046, 10.000046, 11.000016}); - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); + sd::ops::softplus op; + auto result = op.evaluate({&x}, {}, {}); - sd::ops::softplus op; - auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("OUtput Softplus"); -/// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); - - + auto z = result.at(0); + // z->printIndexedBuffer("OUtput Softplus"); + /// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(DeclarableOpsTests7, Softplus_BP_1) { - - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); -// NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); - NDArray eps = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10}); - sd::ops::softplus ffOP; - sd::ops::softplus_bp bpOp; - const OpArgsHolder argsHolderFF({&x}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &eps}, {}, {}); - - bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(gradOK); -// -// auto z = result.at(0); -// z->printIndexedBuffer("OUtput Softplus"); -///// ASSERT_TRUE(e.isSameShape(z)); -// ASSERT_TRUE(e.equalsTo(*z)); -// -// + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + // NDArray e = NDArrayFactory::create('c', {5, 2}, + // {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, + // 10.000046, 10.000046, 11.000016}); + NDArray eps = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + sd::ops::softplus ffOP; + sd::ops::softplus_bp bpOp; + const OpArgsHolder argsHolderFF({&x}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &eps}, {}, {}); + + bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(gradOK); + // + // auto z = result.at(0); + // z->printIndexedBuffer("OUtput Softplus"); + ///// ASSERT_TRUE(e.isSameShape(z)); + // ASSERT_TRUE(e.equalsTo(*z)); + // + // } TEST_F(DeclarableOpsTests7, Softsign_1) { + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray e = NDArrayFactory::create( + 'c', {5, 2}, + {0.5, 0.6666667, 0.75, 0.8, 0.8333333, 0.875, 0.9, 0.90909094, 0.90909094, + 0.9166667}); - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray e = NDArrayFactory::create('c', {5, 2}, {0.5, 0.6666667, 0.75, 0.8, 0.8333333, 0.875, 0.9, 0.90909094, 0.90909094, 0.9166667}); - - sd::ops::softsign op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::softsign op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("OUtput Softsign"); -/// ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); - - + auto z = result.at(0); + // z->printIndexedBuffer("OUtput Softsign"); + /// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(DeclarableOpsTests7, Softsign_BP_1) { + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + // NDArray e = NDArrayFactory::create('c', {5, 2}, + // {1.3132616f, 2.126928f, 3.0485873f, 4.01815f, 5.0067153f, 7.0009117f, 9.000123f, + // 10.000046f, 10.000046f, 11.000016f}); + NDArray eps = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + sd::ops::softsign ffOP; + sd::ops::softsign_bp bpOp; + const OpArgsHolder argsHolderFF({&x}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &eps}, {}, {}); - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); -// NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616f, 2.126928f, 3.0485873f, 4.01815f, 5.0067153f, 7.0009117f, 9.000123f, 10.000046f, 10.000046f, 11.000016f}); - NDArray eps = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10}); - sd::ops::softsign ffOP; - sd::ops::softsign_bp bpOp; - const OpArgsHolder argsHolderFF({&x}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &eps}, {}, {}); - - bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP); + bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP); - ASSERT_TRUE(gradOK); + ASSERT_TRUE(gradOK); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, fill_test2) { + auto x = NDArrayFactory::create('c', {1, 2}, {2, 2}); + auto v = NDArrayFactory::create(42.); + auto exp = + NDArrayFactory::create('c', {2, 2}, {42.f, 42.f, 42.f, 42.f}); - auto x = NDArrayFactory::create('c', {1,2}, {2, 2}); - auto v = NDArrayFactory::create(42.); - auto exp = NDArrayFactory::create('c', {2, 2},{42.f, 42.f, 42.f, 42.f}); - - sd::ops::fill op; - auto result = op.evaluate({&x, &v}, {}, {}); + sd::ops::fill op; + auto result = op.evaluate({&x, &v}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, fill_test3) { + auto x = NDArrayFactory::create('c', {2}, {2, 2}); + auto v = NDArrayFactory::create(42.); + auto exp = + NDArrayFactory::create('c', {2, 2}, {42.f, 42.f, 42.f, 42.f}); - auto x = NDArrayFactory::create('c', {2}, {2, 2}); - auto v = NDArrayFactory::create(42.); - auto exp = NDArrayFactory::create('c', {2, 2}, {42.f, 42.f, 42.f, 42.f}); - - sd::ops::fill op; - auto result = op.evaluate({&x, &v}, {}, {}); - auto output = result.at(0); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::fill op; + auto result = op.evaluate({&x, &v}, {}, {}); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, ToggleBits_test1) { + auto x = NDArrayFactory::create('c', {2}, {2, 2}); + auto exp = NDArrayFactory::create('c', {2}, {-3, -3}); - auto x = NDArrayFactory::create('c', {2}, {2, 2}); - auto exp = NDArrayFactory::create('c', {2}, {-3, -3}); - - sd::ops::toggle_bits op; - auto result = op.evaluate({&x}); - auto output = result.at(0); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); -// output->printIndexedBuffer("Toggled"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::toggle_bits op; + auto result = op.evaluate({&x}); + auto output = result.at(0); - + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + // output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, ToggleBits_test2) { + auto x = NDArrayFactory::create('c', {2}, {2, 2}); + auto y = NDArrayFactory::create('c', {2}, {1, 1}); + auto exp0 = NDArrayFactory::create('c', {2}, {-3, -3}); + auto exp1 = NDArrayFactory::create('c', {2}, {-2, -2}); - auto x = NDArrayFactory::create('c', {2}, {2, 2}); - auto y = NDArrayFactory::create('c', {2}, {1, 1}); - auto exp0 = NDArrayFactory::create('c', {2}, {-3, -3}); - auto exp1 = NDArrayFactory::create('c', {2}, {-2, -2}); + sd::ops::toggle_bits op; + auto result = op.evaluate({&x, &y}); + auto output = result.at(0); + auto z = result.at(1); - sd::ops::toggle_bits op; - auto result = op.evaluate({&x, &y}); - auto output = result.at(0); - auto z = result.at(1); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); -// output->printIndexedBuffer("Toggled"); - ASSERT_TRUE(exp0.isSameShape(output)); - ASSERT_TRUE(exp0.equalsTo(output)); - ASSERT_TRUE(exp1.isSameShape(z)); - ASSERT_TRUE(exp1.equalsTo(z)); - - + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + // output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp0.isSameShape(output)); + ASSERT_TRUE(exp0.equalsTo(output)); + ASSERT_TRUE(exp1.isSameShape(z)); + ASSERT_TRUE(exp1.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Truncatediv_test1) { - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray y = NDArrayFactory::create('c', {5, 2}, {2,2,2,2,2,2,2,2, 2, 2}); - NDArray exp = NDArrayFactory::create('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray y = NDArrayFactory::create('c', {5, 2}, + {2, 2, 2, 2, 2, 2, 2, 2, 2, 2}); + NDArray exp = NDArrayFactory::create( + 'c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); - sd::ops::truncatediv op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); -// output->printIndexedBuffer("Toggled"); - ASSERT_TRUE(exp.isSameShape(output)); - - + sd::ops::truncatediv op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp.isSameShape(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Truncatediv_test2) { - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray y = NDArrayFactory::create('c', {1, 2}, {2,2}); - NDArray exp = NDArrayFactory::create('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); - - sd::ops::truncatediv op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); -// output->printIndexedBuffer("Toggled"); - ASSERT_TRUE(exp.isSameShape(output)); + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {2, 2}); + NDArray exp = NDArrayFactory::create( + 'c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); - + sd::ops::truncatediv op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp.isSameShape(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TypesConversion_test1) { - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray expI = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray expL = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray expF = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); - NDArray expF16 = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); - - sd::ops::to_int32 op32; - sd::ops::to_int64 op64; - auto result32 = op32.evaluate({&x}, {}, {}); - auto result64 = op64.evaluate({&x}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, result32.status()); - ASSERT_EQ(ND4J_STATUS_OK, result64.status()); - auto out1 = result32.at(0); -// out1->printIndexedBuffer("OUT_I"); - auto out2 = result64.at(0); -// out2->printIndexedBuffer("OUT_L"); - -// output->printIndexedBuffer("Toggled"); - ASSERT_TRUE(expI.equalsTo(out1)); - ASSERT_TRUE(expL.equalsTo(out2)); - + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray expI = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray expL = NDArrayFactory::create( + 'c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray expF = NDArrayFactory::create( + 'c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f}); + NDArray expF16 = NDArrayFactory::create( + 'c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f}); + + sd::ops::to_int32 op32; + sd::ops::to_int64 op64; + auto result32 = op32.evaluate({&x}, {}, {}); + auto result64 = op64.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result32.status()); + ASSERT_EQ(ND4J_STATUS_OK, result64.status()); + auto out1 = result32.at(0); + // out1->printIndexedBuffer("OUT_I"); + auto out2 = result64.at(0); + // out2->printIndexedBuffer("OUT_L"); + + // output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(expI.equalsTo(out1)); + ASSERT_TRUE(expL.equalsTo(out2)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TypesConversion_test2) { - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray expF = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); - NDArray expH = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); + NDArray x = NDArrayFactory::create('c', {5, 2}, + {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray expF = NDArrayFactory::create( + 'c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f}); + NDArray expH = NDArrayFactory::create( + 'c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f}); - sd::ops::to_float32 op32; - sd::ops::to_float16 op16; - auto result32 = op32.evaluate({&x}, {}, {}); - auto result16 = op16.evaluate({&x}, {}, {}); + sd::ops::to_float32 op32; + sd::ops::to_float16 op16; + auto result32 = op32.evaluate({&x}, {}, {}); + auto result16 = op16.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result32.status()); - ASSERT_EQ(ND4J_STATUS_OK, result16.status()); - auto out1 = result32.at(0); -// out1->printIndexedBuffer("OUT_F"); - auto out2 = result16.at(0); -// out2->printIndexedBuffer("OUT_H"); - -// output->printIndexedBuffer("Toggled"); - ASSERT_TRUE(expF.equalsTo(out1)); - ASSERT_TRUE(expH.equalsTo(out2)); + ASSERT_EQ(ND4J_STATUS_OK, result32.status()); + ASSERT_EQ(ND4J_STATUS_OK, result16.status()); + auto out1 = result32.at(0); + // out1->printIndexedBuffer("OUT_F"); + auto out2 = result16.at(0); + // out2->printIndexedBuffer("OUT_H"); + // output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(expF.equalsTo(out1)); + ASSERT_TRUE(expH.equalsTo(out2)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TypesConversion_test3) { - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray exp32 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray exp64 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray x = NDArrayFactory::create( + 'c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray exp32 = NDArrayFactory::create( + 'c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray exp64 = NDArrayFactory::create( + 'c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); - sd::ops::to_uint32 op32; - sd::ops::to_uint64 op64; - auto result32 = op32.evaluate({&x}, {}, {}); - auto result64 = op64.evaluate({&x}, {}, {}); + sd::ops::to_uint32 op32; + sd::ops::to_uint64 op64; + auto result32 = op32.evaluate({&x}, {}, {}); + auto result64 = op64.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result32.status()); - ASSERT_EQ(ND4J_STATUS_OK, result64.status()); - auto out1 = result32.at(0); -// out1->printIndexedBuffer("OUT_U32"); - auto out2 = result64.at(0); -// out2->printIndexedBuffer("OUT_U64"); + ASSERT_EQ(ND4J_STATUS_OK, result32.status()); + ASSERT_EQ(ND4J_STATUS_OK, result64.status()); + auto out1 = result32.at(0); + // out1->printIndexedBuffer("OUT_U32"); + auto out2 = result64.at(0); + // out2->printIndexedBuffer("OUT_U64"); -// output->printIndexedBuffer("Toggled"); - ASSERT_TRUE(exp32.equalsTo(out1)); - ASSERT_TRUE(exp64.equalsTo(out2)); + // output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp32.equalsTo(out1)); + ASSERT_TRUE(exp64.equalsTo(out2)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TypesConversion_test4) { - NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - NDArray exp32 = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); - NDArray exp64 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); - - sd::ops::to_float32 op32; - sd::ops::to_double op64; - auto result32 = op32.evaluate({&x}, {}, {}); - auto result64 = op64.evaluate({&x}, {}, {}); + NDArray x = NDArrayFactory::create( + 'c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); + NDArray exp32 = NDArrayFactory::create( + 'c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f}); + NDArray exp64 = NDArrayFactory::create( + 'c', {5, 2}, {1, 2, 3, 4, 5, 7, 9, 10, 10, 11}); - ASSERT_EQ(ND4J_STATUS_OK, result32.status()); - ASSERT_EQ(ND4J_STATUS_OK, result64.status()); - auto out1 = result32.at(0); - auto out2 = result64.at(0); + sd::ops::to_float32 op32; + sd::ops::to_double op64; + auto result32 = op32.evaluate({&x}, {}, {}); + auto result64 = op64.evaluate({&x}, {}, {}); - ASSERT_TRUE(exp32.equalsTo(out1)); - ASSERT_TRUE(exp64.equalsTo(out2)); + ASSERT_EQ(ND4J_STATUS_OK, result32.status()); + ASSERT_EQ(ND4J_STATUS_OK, result64.status()); + auto out1 = result32.at(0); + auto out2 = result64.at(0); + ASSERT_TRUE(exp32.equalsTo(out1)); + ASSERT_TRUE(exp64.equalsTo(out2)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test1) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 2, 2}); - auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 2, 2}); - - auto exp = NDArrayFactory::create('c', {4, 7}, {2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); + auto exp = NDArrayFactory::create( + 'c', {4, 7}, {2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, + 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test2) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 2, 2}); - auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 2, 2}); + auto exp = NDArrayFactory::create( + 'c', {4, 7}, {6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, + 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {4, 7}, {6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {0}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test3) { + auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {1, 2}, {2, 2}); - auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); - auto paddings = NDArrayFactory::create('c', {1,2}, {2, 2}); - - auto exp = NDArrayFactory::create('c', {7}, {2, 1, 1, 2, 3, 3, 2}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); + auto exp = NDArrayFactory::create('c', {7}, {2, 1, 1, 2, 3, 3, 2}); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test4) { + auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2}, {2, 3}); - auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); - auto paddings = NDArrayFactory::create('c', {2}, {2, 3}); + auto exp = NDArrayFactory::create('c', {8}, {2, 1, 1, 2, 3, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {8}, {2, 1, 1, 2, 3, 3, 2, 1}); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test5) { + auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2}, {2, 2}); - auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); - auto paddings = NDArrayFactory::create('c', {2}, {2, 2}); - - auto exp = NDArrayFactory::create('c', {7}, {3, 2, 1, 2, 3, 2, 1}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {0}); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto exp = NDArrayFactory::create('c', {7}, {3, 2, 1, 2, 3, 2, 1}); - + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test6) { + auto input = NDArrayFactory::create(1.); + auto paddings = NDArrayFactory::create('c', {1, 2, 1, 1}, {1, 1}); - auto input = NDArrayFactory::create(1.); - auto paddings = NDArrayFactory::create('c', {1,2,1,1}, {1, 1}); + auto exp = NDArrayFactory::create('c', {3}, {1, 1, 1}); - auto exp = NDArrayFactory::create('c', {3}, {1,1,1}); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test7) { + auto input = NDArrayFactory::create(1.); + auto paddings = NDArrayFactory::create('c', {2}, {1, 1}); - auto input = NDArrayFactory::create(1.); - auto paddings = NDArrayFactory::create('c', {2}, {1, 1}); - - auto exp = NDArrayFactory::create('c', {3}, {1,1,1}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); + auto exp = NDArrayFactory::create('c', {3}, {1, 1, 1}); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test8) { + auto input = NDArrayFactory::create('c', {1, 3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 3, 3}); - auto input = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 3, 3}); + auto exp = NDArrayFactory::create( + 'c', {3, 9}, {3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, + 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {3,9}, {3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1}); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + ASSERT_EQ(result.status(), Status::OK()); - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - ASSERT_EQ(result.status(), Status::OK()); - - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test9) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {2, 2, 3, 3}); - auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {2, 2, 3, 3}); - - auto exp = NDArrayFactory::create('c', {6, 9}, {6, 5, 4, 4, 5, 6, 6, 5, 4, 3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1, 6, 5, 4, 4, 5, 6, 6, 5, 4, 6, 5, 4, 4, 5, 6, 6, 5, 4, 3, 2, 1, 1, 2, 3, 3, 2, 1}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); + auto exp = NDArrayFactory::create( + 'c', {6, 9}, {6, 5, 4, 4, 5, 6, 6, 5, 4, 3, 2, 1, 1, 2, 3, 3, 2, 1, + 3, 2, 1, 1, 2, 3, 3, 2, 1, 6, 5, 4, 4, 5, 6, 6, 5, 4, + 6, 5, 4, 4, 5, 6, 6, 5, 4, 3, 2, 1, 1, 2, 3, 3, 2, 1}); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test10) { + auto input = NDArrayFactory::create('c', {1, 3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + auto exp = NDArrayFactory::create('c', {1, 3}, {1., 2., 3.}); - auto exp = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test11) { + auto input = NDArrayFactory::create('c', {1, 3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - - auto exp = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {0}); - auto output = result.at(0); + auto exp = NDArrayFactory::create('c', {1, 3}, {1., 2., 3.}); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test12) { + auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2, 1}, {0, 0}); - auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); - auto paddings = NDArrayFactory::create('c', {2,1}, {0, 0}); + auto exp = NDArrayFactory::create('c', {3}, {1., 2., 3.}); - auto exp = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {0}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test13) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); - - auto exp = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {0}); - auto output = result.at(0); + auto exp = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test14) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = + NDArrayFactory::create('c', {2, 2}, {1LL, 0LL, 0LL, 1LL}); - auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {1LL, 0LL, 0LL, 1LL}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, {4, 5, 6, 5, 1, 2, 3, 2, 4, 5, 6, 5}); - auto exp = NDArrayFactory::create('c', {3, 4}, {4, 5, 6, 5, 1, 2, 3, 2, 4, 5, 6, 5}); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {0}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test15) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 0, 0}); - auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 0, 0}); - - auto exp = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {1}); - auto output = result.at(0); + auto exp = NDArrayFactory::create( + 'c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test16) { - - auto input = NDArrayFactory::create('c', {4,3,2}); - auto paddings = NDArrayFactory::create('c', {3,2}, {3,3,2,2,1,1}); - - auto exp = NDArrayFactory::create('c', {10,7,4}, {24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,22., 21., 22., 21.,24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13.,16., 15., 16., 15.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13., - 12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,10., 9., 10., 9.,12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., 4., 3., 4., 3., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., - 12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,10., 9., 10., 9.,12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13.,16., 15., 16., 15.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13., - 24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,22., 21., 22., 21.,24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13.,16., 15., 16., 15.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13., - 12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,10., 9., 10., 9.,12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., 4., 3., 4., 3., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1.}); - input.linspace(1.); - - sd::ops::mirror_pad op; - auto result = op.evaluate({&input, &paddings}, {}, {0}); - ASSERT_EQ(result.status(), Status::OK()); - auto output = result.at(0); - //output->printBuffer("VVV"); - //exp.printBuffer("EXP"); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto input = NDArrayFactory::create('c', {4, 3, 2}); + auto paddings = NDArrayFactory::create('c', {3, 2}, {3, 3, 2, 2, 1, 1}); + + auto exp = NDArrayFactory::create( + 'c', {10, 7, 4}, + {24., 23., 24., 23., 22., 21., 22., 21., 20., 19., 20., 19., 22., 21., + 22., 21., 24., 23., 24., 23., 22., 21., 22., 21., 20., 19., 20., 19., + 18., 17., 18., 17., 16., 15., 16., 15., 14., 13., 14., 13., 16., 15., + 16., 15., 18., 17., 18., 17., 16., 15., 16., 15., 14., 13., 14., 13., + 12., 11., 12., 11., 10., 9., 10., 9., 8., 7., 8., 7., 10., 9., + 10., 9., 12., 11., 12., 11., 10., 9., 10., 9., 8., 7., 8., 7., + 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., 4., 3., + 4., 3., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., + 12., 11., 12., 11., 10., 9., 10., 9., 8., 7., 8., 7., 10., 9., + 10., 9., 12., 11., 12., 11., 10., 9., 10., 9., 8., 7., 8., 7., + 18., 17., 18., 17., 16., 15., 16., 15., 14., 13., 14., 13., 16., 15., + 16., 15., 18., 17., 18., 17., 16., 15., 16., 15., 14., 13., 14., 13., + 24., 23., 24., 23., 22., 21., 22., 21., 20., 19., 20., 19., 22., 21., + 22., 21., 24., 23., 24., 23., 22., 21., 22., 21., 20., 19., 20., 19., + 18., 17., 18., 17., 16., 15., 16., 15., 14., 13., 14., 13., 16., 15., + 16., 15., 18., 17., 18., 17., 16., 15., 16., 15., 14., 13., 14., 13., + 12., 11., 12., 11., 10., 9., 10., 9., 8., 7., 8., 7., 10., 9., + 10., 9., 12., 11., 12., 11., 10., 9., 10., 9., 8., 7., 8., 7., + 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., 4., 3., + 4., 3., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1.}); + input.linspace(1.); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + // output->printBuffer("VVV"); + // exp.printBuffer("EXP"); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_1) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create(120.f); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create(120.f); - //************************************// - - sd::ops::reduce_sum op; - auto result = op.evaluate({&input}, {}, {}); + sd::ops::reduce_sum op; + auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - //z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_2) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); - //************************************// + sd::ops::reduce_sum op; + auto result = op.evaluate({&input}, {}, {1}); - sd::ops::reduce_sum op; - auto result = op.evaluate({&input}, {}, {1}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_1) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create(1307674368000.f); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create(1307674368000.f); - //************************************// - - sd::ops::reduce_prod op; - auto result = op.evaluate({&input}, {}, {}); + sd::ops::reduce_prod op; + auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - //z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_2) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create({120.f, 30240.f, 360360.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create({120.f, 30240.f, 360360.f}); - //************************************// + sd::ops::reduce_prod op; + auto result = op.evaluate({&input}, {}, {1}); - sd::ops::reduce_prod op; - auto result = op.evaluate({&input}, {}, {1}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_01) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_02) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {300.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,1}, {300.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_01) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {2}, {10395.f, 46080.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {2}, {10395.f, 46080.f}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_02) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {1, 1, 2}, {10395.f, 46080.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {1,1,2}, {10395.f, 46080.f}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_3) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {3}, {112.f, 1080.f, 3960.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {3}, {112.f, 1080.f, 3960.f}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_4) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {112.f, 1080.f, 3960.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {112.f, 1080.f, 3960.f}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_5) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create(479001600.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create(479001600.f); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_6) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create(479001600.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create(479001600.f); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_7) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {479001600.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {479001600.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TYPED_TEST(TypedDeclarableOpsTests7, Test_Pnorm_Once_Again) { - auto input = NDArrayFactory::create('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f}); - auto exp = NDArrayFactory::create('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f}); - - sd::ops::pnormpool2d op; - auto result = op.evaluate({&input}, {}, {1,1, 1,1, 0,0, 1,1,1, 3, 0}); - ASSERT_EQ(Status::OK(), result.status()); + auto input = NDArrayFactory::create( + 'c', {1, 1, 5, 5}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, + 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 5, 5}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, + 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f}); - ASSERT_EQ(exp, result.at(0)); + sd::ops::pnormpool2d op; + auto result = op.evaluate({&input}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 3, 0}); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_EQ(exp, result.at(0)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {1.f, 2.f, 3.f, 4.f}); - x.linspace(1); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {1.f, 5.f, 9.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {1.f, 5.f, 9.f}); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {1.f, 5.f, 9.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {1.f, 5.f, 9.f}); - x.linspace(1); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(1.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(1.f); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(1.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(1.f); - x.linspace(1); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + // output->printShapeInfo("Output shape"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); -// output->printShapeInfo("Output shape"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); - x.linspace(1); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {16.f, 20.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); - x.linspace(1); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {300.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {300.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); - x.linspace(1); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {3}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {29.597298f, 39.344631f, 49.759422f}); - x.linspace(1); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 1}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {29.597298f, 39.344631f, 49.759422f}); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(70.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(70.f); - x.linspace(1); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(70.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(70.f); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {70.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {70.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {1.f}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); - x.linspace(1); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {1.f}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {0, 1, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {1.f}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, + {1006.f, 1144.f, 1294.f, 1456.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {1006.f, 1144.f, 1294.f, 1456.f}); - x.linspace(1); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 4}, + {1006.f, 1144.f, 1294.f, 1456.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {1006.f, 1144.f, 1294.f, 1456.f}); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {1.f}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {876.f, 1548.f, 2476.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {876.f, 1548.f, 2476.f}); - x.linspace(1); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {1.f}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(4900.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(4900.f); - x.linspace(1); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(4900.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(4900.f); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {0, 1, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {4900.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {4900.f}); - x.linspace(1); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {1.f}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_1) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create(0.5f); - auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); - //************************************// + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_2) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); - auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, - 0.5f, 0.5f, 0.5f, 0.5f, - 0.5f, 0.5f, 0.5f,0.5f}); - //************************************// - - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {1.f}, {}); + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {1.f}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_3) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f}); - //************************************// + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {}, {0}); - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {}, {0}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_4) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f}); - //************************************// - - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {1.f}, {0}); + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {1.f}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_1) { - - auto input = NDArrayFactory::create('c', {3, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); - auto eps = NDArrayFactory::create(1307674368000.f); - //************************************// -// auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); - //************************************// - auto exp = NDArrayFactory::create('c', {3, 5}, {1710012166826558903812096.f, 855006083413279451906048.f, 570004067618451974258688.f, - 427503041706639725953024.f, 342002454982589992140800.f, 285002033809225987129344.f, - 244287457550765131825152.f, 213751520853319862976512.f, 190001355872817324752896.f, - 171001227491294996070400.f, 155455648254341989531648.f, 142501016904612993564672.f, - 131539399526781282156544.f, 122143728775382565912576.f, 114000815325130245799936.f}); - - sd::ops::reduce_prod_bp op; - auto result = op.evaluate({&input, &eps}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + auto input = + NDArrayFactory::create('c', {3, 5}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + auto eps = NDArrayFactory::create(1307674368000.f); + //************************************// + // auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, + // 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); + //************************************// + auto exp = NDArrayFactory::create( + 'c', {3, 5}, + {1710012166826558903812096.f, 855006083413279451906048.f, + 570004067618451974258688.f, 427503041706639725953024.f, + 342002454982589992140800.f, 285002033809225987129344.f, + 244287457550765131825152.f, 213751520853319862976512.f, + 190001355872817324752896.f, 171001227491294996070400.f, + 155455648254341989531648.f, 142501016904612993564672.f, + 131539399526781282156544.f, 122143728775382565912576.f, + 114000815325130245799936.f}); + + sd::ops::reduce_prod_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_2) { - - auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); - auto eps = NDArrayFactory::create(0.5f); - //************************************// -// auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); - //************************************// - auto exp = NDArrayFactory::create('c', {3, 4}); - - sd::ops::reduce_prod_bp op; - sd::ops::reduce_prod op_exp; - auto res = op_exp.evaluate({&input}); - auto result = op.evaluate({&input, &eps}, {}, {}); - exp.assign(res.at(0).e(0)); - exp /= input; - exp *= eps.e(0); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - //z->printIndexedBuffer("Result is "); - //exp.printIndexedBuffer("Expected"); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + auto eps = NDArrayFactory::create(0.5f); + //************************************// + // auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, + // 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); + //************************************// + auto exp = NDArrayFactory::create('c', {3, 4}); + + sd::ops::reduce_prod_bp op; + sd::ops::reduce_prod op_exp; + auto res = op_exp.evaluate({&input}); + auto result = op.evaluate({&input, &eps}, {}, {}); + exp.assign(res.at(0).e(0)); + exp /= input; + exp *= eps.e(0); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // exp.printIndexedBuffer("Expected"); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_3) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + //************************************// + auto exp = + NDArrayFactory::create('c', {3, 4}, + {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, + 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); - auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); - auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); - //************************************// - auto exp = NDArrayFactory::create('c', {3, 4}, {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); + sd::ops::reduce_prod_bp op; + // sd::ops::reduce_prod op_exp; + auto result = op.evaluate({&input, &eps}, {1.f}, {0}); - sd::ops::reduce_prod_bp op; - //sd::ops::reduce_prod op_exp; - auto result = op.evaluate({&input, &eps}, {1.f}, {0}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// exp.printIndexedBuffer("Expected"); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // exp.printIndexedBuffer("Expected"); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_03) { - int ax = 0; - auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); - auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); - //************************************// - auto exp = NDArrayFactory::create('c', {3, 4}, {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); - auto axis = NDArrayFactory::create('c', {1}, {ax}); - sd::ops::reduce_prod_bp op; - //sd::ops::reduce_prod op_exp; - auto result = op.evaluate({&input, &eps, &axis}, {}, {}, {true}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// exp.printIndexedBuffer("Expected"); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + int ax = 0; + auto input = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + //************************************// + auto exp = + NDArrayFactory::create('c', {3, 4}, + {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, + 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); + auto axis = NDArrayFactory::create('c', {1}, {ax}); + sd::ops::reduce_prod_bp op; + // sd::ops::reduce_prod op_exp; + auto result = op.evaluate({&input, &eps, &axis}, {}, {}, {true}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // exp.printIndexedBuffer("Expected"); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_4) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + //************************************// + auto exp = + NDArrayFactory::create('c', {3, 4}, + {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, + 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); - auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - //************************************// - auto exp = NDArrayFactory::create('c', {3, 4}, {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); + sd::ops::reduce_prod_bp op; + sd::ops::reduce_prod op_exp; + // auto res = op_exp.execute({&input}, {}, {}); + auto result = op.evaluate({&input, &eps}, {0.f}, {0}); - sd::ops::reduce_prod_bp op; - sd::ops::reduce_prod op_exp; -// auto res = op_exp.execute({&input}, {}, {}); - auto result = op.evaluate({&input, &eps}, {0.f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // exp.printIndexedBuffer("Expected"); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// exp.printIndexedBuffer("Expected"); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - -// + // } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_5) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + auto eps = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + //************************************// + auto exp = + NDArrayFactory::create('c', {3, 4}, + {24.f, 12.f, 8.f, 6.f, 672.f, 560.f, 480.f, + 420.f, 3960.f, 3564.f, 3240.f, 2970.f}); - auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); - auto eps = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - //************************************// - auto exp = NDArrayFactory::create('c', {3, 4}, {24.f, 12.f, 8.f, 6.f, 672.f, 560.f, 480.f, 420.f, 3960.f, 3564.f, 3240.f, 2970.f}); + sd::ops::reduce_prod_bp op; + sd::ops::reduce_prod op_exp; + // auto res = op_exp.execute({&input}, {}, {}); + auto result = op.evaluate({&input, &eps}, {0.f}, {1}); - sd::ops::reduce_prod_bp op; - sd::ops::reduce_prod op_exp; -// auto res = op_exp.execute({&input}, {}, {}); - auto result = op.evaluate({&input, &eps}, {0.f}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // exp.printIndexedBuffer("Expected"); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// exp.printIndexedBuffer("Expected"); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - -// + // } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_1) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - exp.p(0, eps.e(0)); - exp.p(1, eps.e(1)); - exp.p(2, eps.e(2)); - exp.p(3, eps.e(3)); - x.linspace(1); -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_min_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(0, eps.e(0)); + exp.p(1, eps.e(1)); + exp.p(2, eps.e(2)); + exp.p(3, eps.e(3)); + x.linspace(1); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_2) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - exp.p(0, eps.e(0)); - exp.p(1, eps.e(1)); - exp.p(2, eps.e(2)); - exp.p(3, eps.e(3)); - x.linspace(1); -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_min_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(0, eps.e(0)); + exp.p(1, eps.e(1)); + exp.p(2, eps.e(2)); + exp.p(3, eps.e(3)); + x.linspace(1); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_02) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - exp.p(0, eps.e(0)); - exp.p(1, eps.e(1)); - exp.p(2, eps.e(2)); - exp.p(3, eps.e(3)); - auto axes = NDArrayFactory::create({0,1}); - x.linspace(1); -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_min_bp op; - auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(0, eps.e(0)); + exp.p(1, eps.e(1)); + exp.p(2, eps.e(2)); + exp.p(3, eps.e(3)); + auto axes = NDArrayFactory::create({0, 1}); + x.linspace(1); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_3) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); + auto exp = NDArrayFactory::create('c', {3, 4}); + x.linspace(1); + x.p(2, 2, -1.f); + exp.p(2, 2, 0.5f); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto x = NDArrayFactory::create('c', {3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); - auto exp = NDArrayFactory::create('c', {3, 4}); - x.linspace(1); - x.p(2,2, -1.f); - exp.p(2,2, 0.5f); - //x.printIndexedBuffer("Input is"); - // exp.printIndexedBuffer("Expected "); - sd::ops::reduce_min_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_4) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create('c', {3, 4}); + x.linspace(1); + x.p(2, 2, -1.f); + exp.p(2, 2, 0.5f); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto x = NDArrayFactory::create('c', {3, 4}); - auto eps = NDArrayFactory::create(0.5f); - auto exp = NDArrayFactory::create('c', {3, 4}); - x.linspace(1); - x.p(2,2, -1.f); - exp.p(2,2, 0.5f); -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_min_bp op; - auto result = op.evaluate({&x, &eps}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_5) { - - auto x = NDArrayFactory::create('c', {4, 4}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {4, 4}); - x.linspace(1); - x.p(0,0, -1.f); - x.p(1,1, -2.f); - x.p(2,2, -3.f); - x.p(3,3, -4.f); - exp.p(0,0, 1.f); - exp.p(1,1, 2.f); - exp.p(2,2, 3.f); - exp.p(3,3, 4.f); -// exp(2,2) = 0.5f; -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_min_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {4, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); + x.p(0, 0, -1.f); + x.p(1, 1, -2.f); + x.p(2, 2, -3.f); + x.p(3, 3, -4.f); + exp.p(0, 0, 1.f); + exp.p(1, 1, 2.f); + exp.p(2, 2, 3.f); + exp.p(3, 3, 4.f); + // exp(2,2) = 0.5f; + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_6) { - - auto x = NDArrayFactory::create('c', {4, 4}); - auto eps = NDArrayFactory::create('c', {1,4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {4, 4}); - x.linspace(1); - x.p(0,0, -1.f); - x.p(1,1, -2.f); - x.p(2,2, -3.f); - x.p(3,3, -4.f); - exp.p(0,0, 1.f); - exp.p(1,1, 2.f); - exp.p(2,2, 3.f); - exp.p(3,3, 4.f); -// exp(2,2) = 0.5f; -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_min_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {4, 4}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); + x.p(0, 0, -1.f); + x.p(1, 1, -2.f); + x.p(2, 2, -3.f); + x.p(3, 3, -4.f); + exp.p(0, 0, 1.f); + exp.p(1, 1, 2.f); + exp.p(2, 2, 3.f); + exp.p(3, 3, 4.f); + // exp(2,2) = 0.5f; + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_1) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - exp.p(20, eps.e(0)); - exp.p(21, eps.e(1)); - exp.p(22, eps.e(2)); - exp.p(23, eps.e(3)); - x.linspace(1); - // x.printIndexedBuffer("Input is"); - // exp.printIndexedBuffer("Expected "); - sd::ops::reduce_max_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0, 1}); - auto output = result.at(0); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(20, eps.e(0)); + exp.p(21, eps.e(1)); + exp.p(22, eps.e(2)); + exp.p(23, eps.e(3)); + x.linspace(1); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1}); + auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_2) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - exp.p(20, eps.e(0)); - exp.p(21, eps.e(1)); - exp.p(22, eps.e(2)); - exp.p(23, eps.e(3)); - x.linspace(1); -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_max_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(20, eps.e(0)); + exp.p(21, eps.e(1)); + exp.p(22, eps.e(2)); + exp.p(23, eps.e(3)); + x.linspace(1); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_max_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_02) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - exp.p(20, eps.e(0)); - exp.p(21, eps.e(1)); - exp.p(22, eps.e(2)); - exp.p(23, eps.e(3)); - auto axes = NDArrayFactory::create({0, 1}); - x.linspace(1); -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_max_bp op; - auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(20, eps.e(0)); + exp.p(21, eps.e(1)); + exp.p(22, eps.e(2)); + exp.p(23, eps.e(3)); + auto axes = NDArrayFactory::create({0, 1}); + x.linspace(1); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_max_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_3) { - - auto x = NDArrayFactory::create('c', {4, 4}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {4, 4}); - x.linspace(1); - x.p(0,0, 21.f); - x.p(1,1, 22.f); - x.p(2,2, 23.f); - x.p(3,3, 24.f); - exp.p(0,0, 1.f); - exp.p(1,1, 2.f); - exp.p(2,2, 3.f); - exp.p(3,3, 4.f); -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_max_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {4, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); + x.p(0, 0, 21.f); + x.p(1, 1, 22.f); + x.p(2, 2, 23.f); + x.p(3, 3, 24.f); + exp.p(0, 0, 1.f); + exp.p(1, 1, 2.f); + exp.p(2, 2, 3.f); + exp.p(3, 3, 4.f); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_4) { - - auto x = NDArrayFactory::create('c', {4, 4}); - auto eps = NDArrayFactory::create('c', {1,4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {4, 4}); - x.linspace(1); - x.p(0,0, 21.f); - x.p(1,1, 22.f); - x.p(2,2, 23.f); - x.p(3,3, 24.f); - exp.p(0,0, 1.f); - exp.p(1,1, 2.f); - exp.p(2,2, 3.f); - exp.p(3,3, 4.f); - -// x.printIndexedBuffer("Input is"); -// exp.printIndexedBuffer("Expected "); - sd::ops::reduce_max_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {4, 4}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); + x.p(0, 0, 21.f); + x.p(1, 1, 22.f); + x.p(2, 2, 23.f); + x.p(3, 3, 24.f); + exp.p(0, 0, 1.f); + exp.p(1, 1, 2.f); + exp.p(2, 2, 3.f); + exp.p(3, 3, 4.f); + + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_max_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_1) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create(5.f); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.p(12, -2.f); - x.p(20, -3.f); - exp.assign(5.f); - exp.p(12, -exp.e(12)); - exp.p(20, -exp.e(20)); - sd::ops::reduce_norm1_bp op; - auto result = op.evaluate({&x, &eps}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create(5.f); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.p(12, -2.f); + x.p(20, -3.f); + exp.assign(5.f); + exp.p(12, -exp.e(12)); + exp.p(20, -exp.e(20)); + sd::ops::reduce_norm1_bp op; + auto result = op.evaluate({&x, &eps}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_2) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create({1.f, 2.f, 3.f, 4.f}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f}); - sd::ops::reduce_norm1_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0,1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - // exp.printIndexedBuffer("Expect is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create({1.f, 2.f, 3.f, 4.f}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + sd::ops::reduce_norm1_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + // exp.printIndexedBuffer("Expect is"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_02) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create({1.f, 2.f, 3.f, 4.f}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f}); - auto axes = NDArrayFactory::create({0,1}); - sd::ops::reduce_norm1_bp op; - auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {false}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create({1.f, 2.f, 3.f, 4.f}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + auto axes = NDArrayFactory::create({0, 1}); + sd::ops::reduce_norm1_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {false}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + sd::ops::reduce_norm1_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); + auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f}); - sd::ops::reduce_norm1_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0,1}); - auto output = result.at(0); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create( + 'c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); - x.linspace(1); - - sd::ops::reduce_norm2_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0,1}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm2_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(x.isSameShape(output)); - ASSERT_TRUE(x.equalsTo(output)); - - + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create( + 'c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); - x.linspace(1); - - sd::ops::reduce_norm2_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm2_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(x.isSameShape(output)); - ASSERT_TRUE(x.equalsTo(output)); - - + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_02) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create( + 'c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + auto axes = NDArrayFactory::create({0, 1}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); - auto axes = NDArrayFactory::create({0, 1}); - x.linspace(1); - - sd::ops::reduce_norm2_bp op; - auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm2_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(x.isSameShape(output)); - ASSERT_TRUE(x.equalsTo(output)); - - + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create( + 'c', {3}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {3}, {29.597298f, 39.344631f, 49.759422f}); - x.linspace(1); - - sd::ops::reduce_norm2_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_norm2_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(x.isSameShape(output)); - ASSERT_TRUE(x.equalsTo(output)); - - + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create( + 'c', {1, 3, 1}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1,3,1}, {29.597298f, 39.344631f, 49.759422f}); - x.linspace(1); - - sd::ops::reduce_norm2_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_norm2_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(x.isSameShape(output)); - ASSERT_TRUE(x.equalsTo(output)); - - + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_BP_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {2.f, 8.f, 18.f, 32.f, 10.f, 24.f, 42.f, 64.f, + 18.f, 40.f, 66.f, 96.f, 26.f, 56.f, 90.f, 128.f, + 34.f, 72.f, 114.f, 160.f, 42.f, 88.f, 138.f, 192.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, { 2.f, 8.f, 18.f, 32.f, - 10.f, 24.f, 42.f, 64.f, - 18.f, 40.f, 66.f, 96.f, - 26.f, 56.f, 90.f, 128.f, - 34.f, 72.f, 114.f, 160.f, - 42.f, 88.f, 138.f, 192.f}); - x.linspace(1); - - sd::ops::reduce_sqnorm_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0,1}); + sd::ops::reduce_sqnorm_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_BP_01) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {2.f, 8.f, 18.f, 32.f, 10.f, 24.f, 42.f, 64.f, + 18.f, 40.f, 66.f, 96.f, 26.f, 56.f, 90.f, 128.f, + 34.f, 72.f, 114.f, 160.f, 42.f, 88.f, 138.f, 192.f}); + auto axes = NDArrayFactory::create({0, 1}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, { 2.f, 8.f, 18.f, 32.f, - 10.f, 24.f, 42.f, 64.f, - 18.f, 40.f, 66.f, 96.f, - 26.f, 56.f, 90.f, 128.f, - 34.f, 72.f, 114.f, 160.f, - 42.f, 88.f, 138.f, 192.f}); - auto axes = NDArrayFactory::create({0, 1}); - x.linspace(1); - - sd::ops::reduce_sqnorm_bp op; - auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {false}); + sd::ops::reduce_sqnorm_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {false}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(20, 1.f); + exp.p(21, 2.f); + exp.p(22, 3.f); + exp.p(23, 4.f); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - exp.p(20, 1.f); - exp.p(21, 2.f); - exp.p(22, 3.f); - exp.p(23, 4.f); - - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(20, 1.f); + exp.p(21, 2.f); + exp.p(22, 3.f); + exp.p(23, 4.f); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1,1,4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - exp.p(20, 1.f); - exp.p(21, 2.f); - exp.p(22, 3.f); - exp.p(23, 4.f); - - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_02) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1,1,4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - auto axes = NDArrayFactory::create({0,1}); - x.linspace(1); - exp.p(20, 1.f); - exp.p(21, 2.f); - exp.p(22, 3.f); - exp.p(23, 4.f); - - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = + NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + auto axes = NDArrayFactory::create({0, 1}); + x.linspace(1); + exp.p(20, 1.f); + exp.p(21, 2.f); + exp.p(22, 3.f); + exp.p(23, 4.f); + + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - - exp.p(15, 1.f); - exp.p(19, 2.f); - exp.p(23, 3.f); + exp.p(15, 1.f); + exp.p(19, 2.f); + exp.p(23, 3.f); - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1, 3, 1}, {1.f, 2.f, 3.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(15, 1.f); + exp.p(19, 2.f); + exp.p(23, 3.f); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 3, 1}, {1.f, 2.f, 3.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - exp.p(15, 1.f); - exp.p(19, 2.f); - exp.p(23, 3.f); - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create(1.f); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(23, 1.f); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create(1.f); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - exp.p(23, 1.f); - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create(1.f); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(23, 1.f); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create(1.f); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - exp.p(23, 1.f); - - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0, 1, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(23, 1.f); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto eps = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - exp.p(23, 1.f); - sd::ops::reduce_norm_max_bp op; - auto result = op.evaluate({&x, &eps}, {1.f}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 3, 4}); + NDArray* z; // = NDArrayFactory::create('c', {4}); + auto eps = NDArrayFactory::create(1.f); + // auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + y.linspace(2); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto y = NDArrayFactory::create('c', {2, 3, 4}); - NDArray* z; // = NDArrayFactory::create('c', {4}); - auto eps = NDArrayFactory::create(1.f); -// auto exp = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); - y.linspace(2); - + sd::ops::reduce_dot_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + auto output = result.at(0); + auto outputX = result.at(1); + // tput->printIndexedBuffer("Result is"); - sd::ops::reduce_dot_bp op; - auto result = op.evaluate({&x, &y, &eps}, {}, {}); - auto output = result.at(0); - auto outputX = result.at(1); - //tput->printIndexedBuffer("Result is"); + // ASSERT_EQ(ND4J_STATUS_OK, result.status()); -// ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(x.equalsTo(outputX)); + ASSERT_TRUE(y.equalsTo(output)); - ASSERT_TRUE(x.equalsTo(outputX)); - ASSERT_TRUE(y.equalsTo(output)); - - -// delete z; + // delete z; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_2) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto y = NDArrayFactory::create('c', {2, 3, 4}); -// auto z; // = NDArrayFactory::create('c', {4}); - auto eps = NDArrayFactory::create('c', {2, 4}); - auto expX = NDArrayFactory::create('c', {2, 3, 4}, {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, - 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f - }); - auto expY = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f - }); - x.assign(1.f); - eps.linspace(1); - y.assign(2.f); - sd::ops::reduce_dot_bp op; - auto result = op.evaluate({&x, &y, &eps}, {}, {1}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - ASSERT_EQ(result.size(), 2); - auto outputX = result.at(0); - auto outputY = result.at(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(expX.equalsTo(outputX)); - ASSERT_TRUE(expY.equalsTo(outputY)); - - -// delete z; + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 3, 4}); + // auto z; // = NDArrayFactory::create('c', {4}); + auto eps = NDArrayFactory::create('c', {2, 4}); + auto expX = NDArrayFactory::create( + 'c', {2, 3, 4}, + {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, + 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f}); + auto expY = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f}); + x.assign(1.f); + eps.linspace(1); + y.assign(2.f); + sd::ops::reduce_dot_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {1}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + ASSERT_EQ(result.size(), 2); + auto outputX = result.at(0); + auto outputY = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(expX.equalsTo(outputX)); + ASSERT_TRUE(expY.equalsTo(outputY)); + + // delete z; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_02) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto y = NDArrayFactory::create('c', {2, 3, 4}); -// auto z; // = NDArrayFactory::create('c', {4}); - auto eps = NDArrayFactory::create('c', {2, 4}); - auto expX = NDArrayFactory::create('c', {2, 3, 4}, {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, - 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f - }); - auto expY = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f - }); - auto axis = NDArrayFactory::create('c', {1}, {1}); - x.assign(1.f); - eps.linspace(1); - y.assign(2.f); - sd::ops::reduce_dot_bp op; - auto result = op.evaluate({&x, &y, &eps, &axis}, {}, {}, {false}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - ASSERT_EQ(result.size(), 2); - auto outputX = result.at(0); - auto outputY = result.at(1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(expX.equalsTo(outputX)); - ASSERT_TRUE(expY.equalsTo(outputY)); - - -// delete z; + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 3, 4}); + // auto z; // = NDArrayFactory::create('c', {4}); + auto eps = NDArrayFactory::create('c', {2, 4}); + auto expX = NDArrayFactory::create( + 'c', {2, 3, 4}, + {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, + 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f}); + auto expY = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f}); + auto axis = NDArrayFactory::create('c', {1}, {1}); + x.assign(1.f); + eps.linspace(1); + y.assign(2.f); + sd::ops::reduce_dot_bp op; + auto result = op.evaluate({&x, &y, &eps, &axis}, {}, {}, {false}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + ASSERT_EQ(result.size(), 2); + auto outputX = result.at(0); + auto outputY = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(expX.equalsTo(outputX)); + ASSERT_TRUE(expY.equalsTo(outputY)); + + // delete z; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_3) { - - auto x = NDArrayFactory::create('c', {3, 4}); - auto y = NDArrayFactory::create('c', {3, 4}); - auto eps = NDArrayFactory::create('c', {3}); - auto expX = NDArrayFactory::create('c', {3, 4}, {2.f, 2.f, 2.f, 2.f, 4.f, 4.f, 4.f, 4.f, 6.f, 6.f, 6.f, 6.f}); - auto expY = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 10.f, 12.f, 14.f, 16.f, 27.f, 30.f, 33.f, 36.f}); - x.linspace(1); - eps.linspace(1); - y.assign(2.f); - - sd::ops::reduce_dot_bp op; - auto result = op.evaluate({&x,&y, &eps}, {}, {1}); - auto outputX = result.at(0); - auto outputY = result.at(1); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(expX.equalsTo(outputX)); - ASSERT_TRUE(expY.equalsTo(outputY)); - - + auto x = NDArrayFactory::create('c', {3, 4}); + auto y = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create('c', {3}); + auto expX = NDArrayFactory::create( + 'c', {3, 4}, + {2.f, 2.f, 2.f, 2.f, 4.f, 4.f, 4.f, 4.f, 6.f, 6.f, 6.f, 6.f}); + auto expY = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 10.f, 12.f, 14.f, 16.f, 27.f, 30.f, 33.f, 36.f}); + x.linspace(1); + eps.linspace(1); + y.assign(2.f); + + sd::ops::reduce_dot_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {1}); + auto outputX = result.at(0); + auto outputY = result.at(1); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(expX.equalsTo(outputX)); + ASSERT_TRUE(expY.equalsTo(outputY)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, cumsum_bp_1) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create('c', {3, 4}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {12.f, 11.f, 10.f, 9.f, 8.f, 7.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f}); + x.linspace(1); + eps.assign(1.f); - auto x = NDArrayFactory::create('c', {3, 4}); - auto eps = NDArrayFactory::create('c', {3, 4}); - auto exp = NDArrayFactory::create('c', {3, 4}, {12.f, 11.f, 10.f, 9.f, 8.f, 7.f, - 6.f, 5.f, 4.f, 3.f, 2.f, 1.f}); - x.linspace(1); - eps.assign(1.f); - - sd::ops::cumsum_bp op; - auto result = op.evaluate({&x, &eps}, {}, {0,0}); - auto output = result.at(0); + sd::ops::cumsum_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 0}); + auto output = result.at(0); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, cumsum_bp_2) { - auto x = NDArrayFactory::create('c', {3, 4}); - auto eps = NDArrayFactory::create('c', {3, 4}); - auto exp = NDArrayFactory::create('c', {3, 4}, { 11.f, 10.f, 9.f, 8.f, 7.f, 6.f, - 5.f, 4.f, 3.f, 2.f, 1.f, 0.f}); - x.linspace(1); - eps.assign(1.f); - + auto x = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create('c', {3, 4}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {11.f, 10.f, 9.f, 8.f, 7.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, 0.f}); + x.linspace(1); + eps.assign(1.f); - sd::ops::cumsum_bp op; - auto result = op.evaluate({&x, &eps}, {}, {1,0}); - auto output = result.at(0); + sd::ops::cumsum_bp op; + auto result = op.evaluate({&x, &eps}, {}, {1, 0}); + auto output = result.at(0); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, cumsum_bp_3) { - auto x = NDArrayFactory::create('c', {3, 4}); - auto eps = NDArrayFactory::create('c', {3, 4}); - auto exp = NDArrayFactory::create('c', {3, 4}); - - x.linspace(1); - exp.linspace(0); - eps.assign(1.f); + auto x = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create('c', {3, 4}); + auto exp = NDArrayFactory::create('c', {3, 4}); - sd::ops::cumsum_bp op; - auto result = op.evaluate({&x, &eps}, {}, {1,1}); - auto output = result.at(0); + x.linspace(1); + exp.linspace(0); + eps.assign(1.f); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(output)); - - + sd::ops::cumsum_bp op; + auto result = op.evaluate({&x, &eps}, {}, {1, 1}); + auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.equalsTo(output)); } - diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp index 9b5cf41848f1..7999eb75b71c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp @@ -18,33 +18,30 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 10.06.2018 // - -#include "testlayers.h" -#include #include #include +#include + +#include "testlayers.h" // #include using namespace sd; - class DeclarableOpsTests8 : public testing::Test { -public: - - DeclarableOpsTests8() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests8() { + printf("\n"); + fflush(stdout); + } }; template class TypedDeclarableOpsTests8 : public testing::Test { -public: - - TypedDeclarableOpsTests8() { - printf("\n"); - fflush(stdout); - } + public: + TypedDeclarableOpsTests8() { + printf("\n"); + fflush(stdout); + } }; typedef ::testing::Types TestingTypes; @@ -52,3626 +49,3584 @@ TYPED_TEST_CASE(TypedDeclarableOpsTests8, TestingTypes); //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.f}); + auto exp = NDArrayFactory::create( + 'c', {4}, {602.2222f, 727.13885f, 993.5555f, 755.8889f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.f}); - auto exp = NDArrayFactory::create('c', {4}, {602.2222f, 727.13885f, 993.5555f, 755.8889f}); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); - sd::ops::reduce_variance op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test2) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.f}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 4}, {602.2222f, 727.13885f, 993.5555f, 755.8889f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.f}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {602.2222f, 727.13885f, 993.5555f, 755.8889f}); - - sd::ops::reduce_variance op; - auto result = op.evaluate({&x}, {1.}, {0,1}); - auto output = result.at(0); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test3) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.f}); + auto exp = NDArrayFactory::create('c', {3}, + {900.9375f, 969.8594f, 424.1875f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.f}); - auto exp = NDArrayFactory::create('c', {3}, {900.9375f, 969.8594f, 424.1875f}); - - sd::ops::reduce_variance op; - auto result = op.evaluate({&x}, {}, {0,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test4) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.f}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, + {900.9375f, 969.8594f, 424.1875f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.f}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {900.9375f, 969.8594f, 424.1875f}); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); - sd::ops::reduce_variance op; - auto result = op.evaluate({&x}, {1.}, {0,2}); - auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test5) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.f}); + auto exp = NDArrayFactory::create(788.6927f); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.f}); - auto exp = NDArrayFactory::create(788.6927f); - - sd::ops::reduce_variance op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test6) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create(788.6927f); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create(788.6927f); - - sd::ops::reduce_variance op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test7) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {788.6927f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {1,1,1}, {788.6927f}); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); - sd::ops::reduce_variance op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test8) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {788.6927f}); + auto axes = NDArrayFactory::create({0, 1, 2}); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {1,1,1}, {788.6927f}); - auto axes = NDArrayFactory::create({0, 1, 2}); - sd::ops::reduce_variance op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create( + 'c', {4}, {24.54022f, 26.96551f, 31.52072f, 27.49343f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {4}, {24.54022f, 26.96551f, 31.52072f, 27.49343f}); - - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test2) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 4}, {24.54022f, 26.96551f, 31.52072f, 27.49343f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {24.54022f, 26.96551f, 31.52072f, 27.49343f}); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {1.}, {0,1}); - auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test3) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create('c', {3}, + {30.01562f, 31.14257f, 20.59581f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {3}, {30.01562f, 31.14257f, 20.59581f}); - - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {}, {0,2}); - auto output = result.at(0); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test4) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, + {30.01562f, 31.14257f, 20.59581f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {30.01562f, 31.14257f, 20.59581f}); - - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {1.}, {0,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test5) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create(28.08367f); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create(28.08367f); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test6) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create(28.08367f); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create(28.08367f); - - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test7) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {28.08367f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {1,1,1}, {28.08367f}); - - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {1.f}, {0,1,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {1.f}, {0, 1, 2}); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test8) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create( + 'c', {4}, {26.88246f, 29.53924f, 34.52921f, 30.11755f}); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {4}, {26.88246f, 29.53924f, 34.52921f, 30.11755f}); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {0.f, 1.f}, {0, 1}); + auto output = result.at(0); - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x}, {0.f,1.f}, {0,1}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - // output->printBuffer("Reduced STDDEV"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), result.status()); + // output->printBuffer("Reduced STDDEV"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDev_test08) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {27.f, 34.f, 5.f, 4.f, 54.f, 6.f, 65.f, 8.f, 37.f, 45.f, 8.f, 67.f, + 96.f, 10.f, 65.f, 41.f, 33.f, 85.f, 92.f, 24.f, 25.f, 55.f, 49.f, 76.}); + auto exp = NDArrayFactory::create( + 'c', {4}, {26.88246f, 29.53924f, 34.52921f, 30.11755f}); + auto axes = NDArrayFactory::create({0, 1}); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x, &axes}, {}, {}, {false, true}); + auto output = result.at(0); - auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); - auto exp = NDArrayFactory::create('c', {4}, {26.88246f, 29.53924f, 34.52921f, 30.11755f}); - auto axes = NDArrayFactory::create({0,1}); - sd::ops::reduce_stdev op; - auto result = op.evaluate({&x, &axes}, {}, {}, {false, true}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - // output->printBuffer("Reduced STDDEV08"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_EQ(Status::OK(), result.status()); + // output->printBuffer("Reduced STDDEV08"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVarianceBP_test1) { - - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {1,1}, {0.5f}); - auto gradO2 = NDArrayFactory::create(0.5f); - auto exp12 = NDArrayFactory::create('c', {3,4}, {-0.5f, -0.4090909f, -0.3181818f, -0.22727273f, -0.13636364f, -0.045454547f, 0.045454547f, 0.13636364f, 0.22727273f, 0.3181818f, 0.4090909f, 0.5f}); - auto exp34 = NDArrayFactory::create('c', {3,4}, {-0.45833334f, -0.375f, -0.29166666f, -0.20833333f, -0.125f, -0.041666668f, 0.041666668f, 0.125f, 0.20833333f, 0.29166666f, 0.375f, 0.45833334f}); - - x.linspace(1); - - sd::ops::reduce_variance_bp op; - - auto result = op.evaluate({&x, &gradO2}, {0,1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {0,0}, {}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,0}, {}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {1, 1}, {0.5f}); + auto gradO2 = NDArrayFactory::create(0.5f); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.5f, -0.4090909f, -0.3181818f, -0.22727273f, -0.13636364f, + -0.045454547f, 0.045454547f, 0.13636364f, 0.22727273f, 0.3181818f, + 0.4090909f, 0.5f}); + auto exp34 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.45833334f, -0.375f, -0.29166666f, -0.20833333f, -0.125f, + -0.041666668f, 0.041666668f, 0.125f, 0.20833333f, 0.29166666f, 0.375f, + 0.45833334f}); + + x.linspace(1); + + sd::ops::reduce_variance_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2}, {0, 0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVarianceBP_test2) { - - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {1,4}, {1.f,2.f,3.f,4.f}); - auto gradO2 = NDArrayFactory::create('c', {4}, {1.,2.,3.,4.}); - auto exp12 = NDArrayFactory::create('c', {3,4}, {-2.666667f, -5.333333f, -8.000000f, -10.666667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 2.666667f, 5.333333f, 8.000000f, 10.666667f}); - auto exp34 = NDArrayFactory::create('c', {3,4}, {-4.000000f, -8.000000f, -12.000000f, -16.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 4.000000f, 8.000000f, 12.000000f, 16.000000f}); - - x.linspace(1); - - sd::ops::reduce_variance_bp op; - - auto result = op.evaluate({&x, &gradO2}, {0,0}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,0}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {0,1}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,1}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = + NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto gradO2 = NDArrayFactory::create('c', {4}, {1., 2., 3., 4.}); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-2.666667f, -5.333333f, -8.000000f, -10.666667f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 2.666667f, 5.333333f, 8.000000f, 10.666667f}); + auto exp34 = NDArrayFactory::create( + 'c', {3, 4}, + {-4.000000f, -8.000000f, -12.000000f, -16.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 4.000000f, 8.000000f, 12.000000f, 16.000000f}); + + x.linspace(1); + + sd::ops::reduce_variance_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0, 0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2}, {0, 1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVarianceBP_test02) { - - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {1,4}, {1.f,2.f,3.f,4.f}); - auto gradO2 = NDArrayFactory::create('c', {4}, {1.f,2.f,3.f,4.f}); - auto exp12 = NDArrayFactory::create('c', {3,4}, {-2.666667f, -5.333333f, -8.000000f, -10.666667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 2.666667f, 5.333333f, 8.000000f, 10.666667f}); - auto exp34 = NDArrayFactory::create('c', {3,4}, {-4.000000f, -8.000000f, -12.000000f, -16.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 4.000000f, 8.000000f, 12.000000f, 16.000000f}); - auto axes = NDArrayFactory::create({(int)0,}); - x.linspace(1); - - sd::ops::reduce_variance_bp op; - - auto result = op.evaluate({&x, &gradO2, &axes}, {}, {}, {false, false}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1, &axes}, {}, {}, {true, false}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2, &axes}, {}, {}, {false, true}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1, &axes}, {}, {}, {true, true}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = + NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto gradO2 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-2.666667f, -5.333333f, -8.000000f, -10.666667f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 2.666667f, 5.333333f, 8.000000f, 10.666667f}); + auto exp34 = NDArrayFactory::create( + 'c', {3, 4}, + {-4.000000f, -8.000000f, -12.000000f, -16.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 4.000000f, 8.000000f, 12.000000f, 16.000000f}); + auto axes = NDArrayFactory::create({ + (int)0, + }); + x.linspace(1); + + sd::ops::reduce_variance_bp op; + + auto result = op.evaluate({&x, &gradO2, &axes}, {}, {}, {false, false}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1, &axes}, {}, {}, {true, false}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2, &axes}, {}, {}, {false, true}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1, &axes}, {}, {}, {true, true}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVarianceBP_test3) { - - auto x = NDArrayFactory::create('c', {3, 4}); - auto gradO1 = NDArrayFactory::create('c', {3, 1}, {1.f, 2.f, 3.f}); - auto gradO2 = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto exp12 = NDArrayFactory::create('c', {3, 4}, - {-0.750000f, -0.250000f, 0.250000f, 0.750000f, -1.500000f, -0.500000f, - 0.500000f, 1.500000f, -2.250000f, -0.750000f, 0.750000f, 2.250000f}); - auto exp34 = NDArrayFactory::create('c', {3, 4}, - {-1.000000f, -0.333333f, 0.333333f, 1.000000f, -2.000000f, -0.666667f, - 0.666667f, 2.000000f, -3.000000f, -1.000000f, 1.000000f, 3.000000f}); - - x.linspace(1); - - sd::ops::reduce_variance_bp op; - - auto result = op.evaluate({&x, &gradO2}, {0, 0}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1, 0}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {0, 1}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1, 1}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {3, 1}, {1.f, 2.f, 3.f}); + auto gradO2 = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.750000f, -0.250000f, 0.250000f, 0.750000f, -1.500000f, -0.500000f, + 0.500000f, 1.500000f, -2.250000f, -0.750000f, 0.750000f, 2.250000f}); + auto exp34 = NDArrayFactory::create( + 'c', {3, 4}, + {-1.000000f, -0.333333f, 0.333333f, 1.000000f, -2.000000f, -0.666667f, + 0.666667f, 2.000000f, -3.000000f, -1.000000f, 1.000000f, 3.000000f}); + + x.linspace(1); + + sd::ops::reduce_variance_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0, 0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2}, {0, 1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDevBP_test1) { - - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {1,1}, {0.5f}); - auto gradO2 = NDArrayFactory::create(0.5f); - auto exp12 = NDArrayFactory::create('c', {3,4}, {-0.069337524f, -0.056730703f, -0.04412388f, -0.031517055f, -0.018910235f, -0.0063034114f, 0.0063034114f, 0.018910235f, 0.031517055f, 0.04412388f, 0.056730703f, 0.069337524f}); - auto exp34 = NDArrayFactory::create('c', {3,4}, {-0.06638563f, -0.05431551f, -0.0422454f, -0.030175284f, -0.01810517f, -0.006035057f, 0.006035057f, 0.01810517f, 0.030175284f, 0.0422454f, 0.05431551f, 0.06638563f}); - - x.linspace(1); - - sd::ops::reduce_stdev_bp op; - - auto result = op.evaluate({&x, &gradO2}, {0,1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - // output->printIndexedBuffer(); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {0,0}, {}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,0}, {}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {1, 1}, {0.5f}); + auto gradO2 = NDArrayFactory::create(0.5f); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.069337524f, -0.056730703f, -0.04412388f, -0.031517055f, -0.018910235f, + -0.0063034114f, 0.0063034114f, 0.018910235f, 0.031517055f, 0.04412388f, + 0.056730703f, 0.069337524f}); + auto exp34 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.06638563f, -0.05431551f, -0.0422454f, -0.030175284f, -0.01810517f, + -0.006035057f, 0.006035057f, 0.01810517f, 0.030175284f, 0.0422454f, + 0.05431551f, 0.06638563f}); + + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + // output->printIndexedBuffer(); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2}, {0, 0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDevBP_test2) { - - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {1,4}, {1.f,2.f,3.f,4.f}); - auto gradO2 = NDArrayFactory::create('c', {4}, {1.f,2.f,3.f,4.f}); - auto exp12 = NDArrayFactory::create('c', {3,4}, {-0.4082483f, -0.8164966f, -1.2247449f, -1.6329932f, 0.0, 0.0, 0.0, 0.0, 0.4082483f, 0.8164966f, 1.2247449f, 1.6329932f}); - auto exp34 = NDArrayFactory::create('c', {3,4}, {-0.5f, -1.0f, -1.5f, -2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.5f, 1.0f, 1.5f, 2.0f}); - - x.linspace(1); - - sd::ops::reduce_stdev_bp op; - - auto result = op.evaluate({&x, &gradO2}, {0,0}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,0}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {0,1}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,1}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = + NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto gradO2 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.4082483f, -0.8164966f, -1.2247449f, -1.6329932f, 0.0, 0.0, 0.0, 0.0, + 0.4082483f, 0.8164966f, 1.2247449f, 1.6329932f}); + auto exp34 = + NDArrayFactory::create('c', {3, 4}, + {-0.5f, -1.0f, -1.5f, -2.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.5f, 1.0f, 1.5f, 2.0f}); + + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0, 0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2}, {0, 1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDevBP_test02) { - - int ax = 0; - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {1,4}, {1.f,2.f,3.f,4.f}); - auto gradO2 = NDArrayFactory::create('c', {4}, {1.f,2.f,3.f,4.f}); - auto exp12 = NDArrayFactory::create('c', {3,4}, {-0.4082483f, -0.8164966f, -1.2247449f, -1.6329932f, 0.0, 0.0, 0.0, 0.0, 0.4082483f, 0.8164966f, 1.2247449f, 1.6329932f}); - auto exp34 = NDArrayFactory::create('c', {3,4}, {-0.5f, -1.0f, -1.5f, -2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.5f, 1.0f, 1.5f, 2.0f}); - auto axis = NDArrayFactory::create('c', {1}, {ax}); - x.linspace(1); - - sd::ops::reduce_stdev_bp op; - - auto result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, false}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {true, false}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, true}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {true, true}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - + int ax = 0; + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = + NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto gradO2 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.4082483f, -0.8164966f, -1.2247449f, -1.6329932f, 0.0, 0.0, 0.0, 0.0, + 0.4082483f, 0.8164966f, 1.2247449f, 1.6329932f}); + auto exp34 = + NDArrayFactory::create('c', {3, 4}, + {-0.5f, -1.0f, -1.5f, -2.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.5f, 1.0f, 1.5f, 2.0f}); + auto axis = NDArrayFactory::create('c', {1}, {ax}); + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, false}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {true, false}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, true}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {true, true}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDevBP_test3) { - - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {3,1}, {1.f,2.f,3.f}); - auto gradO2 = NDArrayFactory::create('c', {3}, {1.f,2.f,3.f}); - auto exp12 = NDArrayFactory::create('c', {3,4}, {-0.3354102f, -0.1118034f, 0.1118034f, 0.3354102f, -0.6708204f, -0.2236068f, 0.2236068f, 0.6708204f, -1.0062306f, -0.3354102f, 0.3354102f, 1.0062306f}); - auto exp34 = NDArrayFactory::create('c', {3,4}, {-0.38729835f, -0.12909944f, 0.12909944f, 0.38729835f, -0.7745967f, -0.2581989f, 0.2581989f, 0.7745967f, -1.161895f, -0.38729835f, 0.38729835f, 1.161895f}); - - x.linspace(1); - - sd::ops::reduce_stdev_bp op; - - auto result = op.evaluate({&x, &gradO2}, {0,0}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,0}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp12.isSameShape(output)); - ASSERT_TRUE(exp12.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {0,1}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,1}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp34.isSameShape(output)); - ASSERT_TRUE(exp34.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {3, 1}, {1.f, 2.f, 3.f}); + auto gradO2 = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto exp12 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.3354102f, -0.1118034f, 0.1118034f, 0.3354102f, -0.6708204f, + -0.2236068f, 0.2236068f, 0.6708204f, -1.0062306f, -0.3354102f, + 0.3354102f, 1.0062306f}); + auto exp34 = NDArrayFactory::create( + 'c', {3, 4}, + {-0.38729835f, -0.12909944f, 0.12909944f, 0.38729835f, -0.7745967f, + -0.2581989f, 0.2581989f, 0.7745967f, -1.161895f, -0.38729835f, + 0.38729835f, 1.161895f}); + + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0, 0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + result = op.evaluate({&x, &gradO2}, {0, 1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_1) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create(120.f); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create(120.f); - //************************************// - - sd::ops::reduce_sum op; - auto result = op.evaluate({&input}, {}, {}); + sd::ops::reduce_sum op; + auto result = op.evaluate({&input}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - //z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_2) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); - //************************************// + sd::ops::reduce_sum op; + auto result = op.evaluate({&input}, {}, {1}); - sd::ops::reduce_sum op; - auto result = op.evaluate({&input}, {}, {1}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_03) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); + auto axis = NDArrayFactory::create('c', {1}, {1}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); - auto axis = NDArrayFactory::create('c', {1}, {1}); - //************************************// - - sd::ops::reduce_sum op; - auto result = op.evaluate({&input, &axis}, {}, {}, {false}); + sd::ops::reduce_sum op; + auto result = op.evaluate({&input, &axis}, {}, {}, {false}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - // z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_1) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create(1307674368000.f); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create(1307674368000.f); - //************************************// + sd::ops::reduce_prod op; + auto result = op.evaluate({&input}, {}, {}); - sd::ops::reduce_prod op; - auto result = op.evaluate({&input}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - //z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_2) { + auto input = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create({120.f, 30240.f, 360360.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto exp = NDArrayFactory::create({120.f, 30240.f, 360360.f}); - //************************************// - - sd::ops::reduce_prod op; - auto result = op.evaluate({&input}, {}, {1}); + sd::ops::reduce_prod op; + auto result = op.evaluate({&input}, {}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_01) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_02) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); - x.linspace(1); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); - - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {300.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,1}, {300.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_sum op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_01) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {2}, {10395.f, 46080.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {2}, {10395.f, 46080.f}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_02) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {1, 1, 2}, {10395.f, 46080.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {1,1,2}, {10395.f, 46080.f}); - x.linspace(1); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_3) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {3}, {112.f, 1080.f, 3960.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {3}, {112.f, 1080.f, 3960.f}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_4) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {112.f, 1080.f, 3960.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {112.f, 1080.f, 3960.f}); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_04) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {112.f, 1080.f, 3960.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {112.f, 1080.f, 3960.f}); - auto axes = NDArrayFactory::create({0, 2}); - x.linspace(1); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_prod op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_5) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create(479001600.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create(479001600.f); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_6) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create(479001600.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create(479001600.f); - x.linspace(1); - - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_7) { + auto x = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {479001600.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2,3,2}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {479001600.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_prod op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {0, 1}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {1.f, 2.f, 3.f, 4.f}); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {1.f, 5.f, 9.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {1.f, 5.f, 9.f}); - x.linspace(1); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {1.f, 5.f, 9.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {1.f, 5.f, 9.f}); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_04) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {1.f, 5.f, 9.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {1.f, 5.f, 9.f}); - auto axes = NDArrayFactory::create({0, 2}); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_min op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(1.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(1.f); - x.linspace(1); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(1.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(1.f); - x.linspace(1); - - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Min_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); - x.linspace(1); - // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_min op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - // output->printShapeInfo("Output shape"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + // output->printShapeInfo("Output shape"); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {16.f, 20.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_04) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {16.f, 20.f, 24.f}); - auto axes = NDArrayFactory::create({0, 2}); - x.linspace(1); + sd::ops::reduce_max op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_max op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); - - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Max_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_max op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); - // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {66.f, 72.f, 78.f, 84.f}); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); - x.linspace(1); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {68.f, 100.f, 132.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_04) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {68.f, 100.f, 132.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); - auto axes = NDArrayFactory::create({0, 2}); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(300.f); - x.linspace(1); - - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {300.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {300.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_norm1 op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); - x.linspace(1); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {3}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {29.597298f, 39.344631f, 49.759422f}); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 1}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {29.597298f, 39.344631f, 49.759422f}); - x.linspace(1); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_04) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {1, 3, 1}, {29.597298f, 39.344631f, 49.759422f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {29.597298f, 39.344631f, 49.759422f}); - auto axes = NDArrayFactory::create({0,2}); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(70.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(70.f); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(70.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(70.f); - x.linspace(1); - - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {70.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {70.f}); - x.linspace(1); -// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::reduce_norm2 op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {21.f, 22.f, 23.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {1.f}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); - x.linspace(1); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {1.f}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_04) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); - auto axes = NDArrayFactory::create({0,2}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(24.f); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {}, {0, 1, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); - x.linspace(1); - - sd::ops::reduce_norm_max op; - auto result = op.evaluate({&x}, {1.f}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, + {1006.f, 1144.f, 1294.f, 1456.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {1006.f, 1144.f, 1294.f, 1456.f}); - x.linspace(1); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 4}, + {1006.f, 1144.f, 1294.f, 1456.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {1006.f, 1144.f, 1294.f, 1456.f}); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {1.f}, {0,1}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {876.f, 1548.f, 2476.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {3}, {876.f, 1548.f, 2476.f}); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); - x.linspace(1); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {1.f}, {0,2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_04) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); - auto axes = NDArrayFactory::create({0, 2}); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(4900.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(4900.f); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(4900.f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create(4900.f); - x.linspace(1); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {}, {0, 1, 2}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {4900.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1, 1}, {4900.f}); - x.linspace(1); - - sd::ops::reduce_sqnorm op; - auto result = op.evaluate({&x}, {1.f}, {}); - auto output = result.at(0); -// output->printIndexedBuffer("Result is"); + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_1) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create(0.5f); - auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); - //************************************// - - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {}, {}); + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_2) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); - auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, - 0.5f, 0.5f, 0.5f, 0.5f, - 0.5f, 0.5f, 0.5f,0.5f}); - //************************************// - - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {1.f}, {}); + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {1.f}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_3) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f}); - //************************************// + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {}, {0}); - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {}, {0}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_4) { + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + //************************************// - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f}); - //************************************// - - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps}, {1.f}, {0}); + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {1.f}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_04) { + int ax = 0; + auto input = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + auto axis = NDArrayFactory::create('c', {1}, {ax}); + //************************************// - int ax = 0; - auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); - auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f, - 1.f, 2.f, 3.f, 4.f}); - auto axis = NDArrayFactory::create('c', {1}, {ax}); - //************************************// - - sd::ops::reduce_sum_bp op; - auto result = op.evaluate({&input, &eps, &axis}, {}, {}, {true}); + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps, &axis}, {}, {}, {true}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_BP_1) { - - auto input = NDArrayFactory::create('c', {3, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); - auto eps = NDArrayFactory::create(1307674368000.f); - //************************************// -// auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); - //************************************// - auto exp = NDArrayFactory::create('c', {3, 5}, {1710012166826558903812096.f, 855006083413279451906048.f, 570004067618451974258688.f, - 427503041706639725953024.f, 342002454982589992140800.f, 285002033809225987129344.f, - 244287457550765131825152.f, 213751520853319862976512.f, 190001355872817324752896.f, - 171001227491294996070400.f, 155455648254341989531648.f, 142501016904612993564672.f, - 131539399526781282156544.f, 122143728775382565912576.f, 114000815325130245799936.f}); - - sd::ops::reduce_prod_bp op; - auto result = op.evaluate({&input, &eps}, {}, {}); - - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Result is "); -// z->printShapeInfo(); - ASSERT_TRUE(exp.equalsTo(z)); - + auto input = + NDArrayFactory::create('c', {3, 5}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + auto eps = NDArrayFactory::create(1307674368000.f); + //************************************// + // auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, + // 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); + //************************************// + auto exp = NDArrayFactory::create( + 'c', {3, 5}, + {1710012166826558903812096.f, 855006083413279451906048.f, + 570004067618451974258688.f, 427503041706639725953024.f, + 342002454982589992140800.f, 285002033809225987129344.f, + 244287457550765131825152.f, 213751520853319862976512.f, + 190001355872817324752896.f, 171001227491294996070400.f, + 155455648254341989531648.f, 142501016904612993564672.f, + 131539399526781282156544.f, 122143728775382565912576.f, + 114000815325130245799936.f}); + + sd::ops::reduce_prod_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + // z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {11.f, 12.f, 13.f, 14.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {4}, {11.f, 12.f, 13.f, 14.f}); - x.linspace(1); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::reduce_mean op; - auto result = op.evaluate({&x}, {}, {0,1}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 1, 4}, {11.f, 12.f, 13.f, 14.f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,4}, {11.f, 12.f, 13.f, 14.f}); - x.linspace(1); - - - sd::ops::reduce_mean op; - auto result = op.evaluate({&x}, {1.}, {0,1}); - auto output = result.at(0); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {8.5f, 12.5f, 16.5f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {3}, {8.5f, 12.5f, 16.5f}); - x.linspace(1); - - - sd::ops::reduce_mean op; - auto result = op.evaluate({&x}, {}, {0,2}); - auto output = result.at(0); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {1, 3, 1}, {8.5f, 12.5f, 16.5f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {8.5f, 12.5f, 16.5f}); - x.linspace(1); - - - sd::ops::reduce_mean op; - auto result = op.evaluate({&x}, {1.f}, {0,2}); - auto output = result.at(0); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {1.f}, {0, 2}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test5) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(12.5f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create(12.5f); - x.linspace(1); - - - sd::ops::reduce_mean op; - auto result = op.evaluate({&x}, {}, {}); - auto output = result.at(0); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test6) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(12.5f); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create(12.5f); - x.linspace(1); - - sd::ops::reduce_mean op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - auto output = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_EQ(Status::OK(), result.status()); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {12.5f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,1}, {12.5f}); - x.linspace(1); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + auto output = result.at(0); - sd::ops::reduce_mean op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(Status::OK(), result.status()); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMean_test8) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {12.5f}); + auto axes = NDArrayFactory::create({0, 1, 2}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {1,1,1}, {12.5f}); - auto axes = NDArrayFactory::create({0, 1, 2}); - x.linspace(1); - - sd::ops::reduce_mean op; - auto result = op.evaluate({&x, &axes}, {}, {}, {true}); - auto output = result.at(0); + sd::ops::reduce_mean op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMeanBP_test1) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create(0.5f); + auto gradO2 = NDArrayFactory::create('c', {1, 1}, {0.5f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1. / 24, 1. / 24, 1. / 24, 1. / 24, 1. / 24, 1. / 24, 1. / 24, 1. / 24, + 1. / 24, 1. / 24, 1. / 24, 1. / 24}); - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create(0.5f); - auto gradO2 = NDArrayFactory::create('c', {1,1}, {0.5f}); - auto exp = NDArrayFactory::create('c', {3,4}, {1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24}); - - x.linspace(1); + x.linspace(1); - sd::ops::reduce_mean_bp op; + sd::ops::reduce_mean_bp op; - auto result = op.evaluate({&x, &gradO1}, {0}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); + auto result = op.evaluate({&x, &gradO1}, {0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); - // output->printShapeInfo("o"); + // output->printShapeInfo("o"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); - result = op.evaluate({&x, &gradO2}, {1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMeanBP_test2) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto gradO2 = + NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f / 3.f, 2.f / 3.f, 1.f, 4.f / 3.f, 1.f / 3.f, 2.f / 3.f, 1.f, + 4.f / 3.f, 1.f / 3.f, 2.f / 3.f, 1.f, 4.f / 3.f}); - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto gradO2 = NDArrayFactory::create('c', {1,4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {3,4}, {1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f, 1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f, 1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f}); + x.linspace(1); - x.linspace(1); + sd::ops::reduce_mean_bp op; - sd::ops::reduce_mean_bp op; + auto result = op.evaluate({&x, &gradO1}, {0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); - auto result = op.evaluate({&x, &gradO1}, {0}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {1}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMeanBP_test02) { - - int ax = 0; - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto gradO2 = NDArrayFactory::create('c', {1,4}, {1.f, 2.f, 3.f, 4.f}); - auto exp = NDArrayFactory::create('c', {3,4}, {1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f, 1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f, 1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f}); - auto axis = NDArrayFactory::create('c', {1}, {ax}); - x.linspace(1); - - sd::ops::reduce_mean_bp op; - - auto result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {false}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {true}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + int ax = 0; + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto gradO2 = + NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f / 3.f, 2.f / 3.f, 1.f, 4.f / 3.f, 1.f / 3.f, 2.f / 3.f, 1.f, + 4.f / 3.f, 1.f / 3.f, 2.f / 3.f, 1.f, 4.f / 3.f}); + auto axis = NDArrayFactory::create('c', {1}, {ax}); + x.linspace(1); + + sd::ops::reduce_mean_bp op; + + auto result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {false}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {true}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMeanBP_test3) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto gradO2 = NDArrayFactory::create('c', {3, 1}, {1.f, 2.f, 3.f}); + auto exp = + NDArrayFactory::create('c', {3, 4}, + {0.25f, 0.25f, 0.25f, 0.25f, 0.5f, 0.5f, + 0.5f, 0.5f, 0.75f, 0.75f, 0.75f, 0.75f}); - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto gradO2 = NDArrayFactory::create('c', {3,1}, {1.f, 2.f, 3.f}); - auto exp = NDArrayFactory::create('c', {3,4}, {0.25f, 0.25f, 0.25f, 0.25f, 0.5f, 0.5f, 0.5f, 0.5f, 0.75f, 0.75f, 0.75f, 0.75f}); - - x.linspace(1); + x.linspace(1); - sd::ops::reduce_mean_bp op; + sd::ops::reduce_mean_bp op; - auto result = op.evaluate({&x, &gradO1}, {0}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + auto result = op.evaluate({&x, &gradO1}, {0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); - result = op.evaluate({&x, &gradO2}, {1}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDevBP_test4) { + auto x = NDArrayFactory::create('c', {3}, {2.f, 3.f, 4.f}); + auto gradO = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create('c', {3}, {-0.25f, 0.f, 0.25f}); - auto x = NDArrayFactory::create('c', {3}, {2.f, 3.f, 4.f}); - auto gradO = NDArrayFactory::create(0.5f); - auto exp = NDArrayFactory::create('c', {3}, {-0.25f, 0.f, 0.25f}); + sd::ops::reduce_stdev_bp op; - sd::ops::reduce_stdev_bp op; - - auto result = op.evaluate({&x, &gradO}, {0,1}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + auto result = op.evaluate({&x, &gradO}, {0, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test1) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 3}, {2.78507, 1.34254, 4.12761, 2.88507, 2.78507, 2.88507}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,0,0,1,0,1,0,1,1,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3}, {2.78507, 1.34254, 4.12761, 2.88507, 2.78507, 2.88507}); - - logits.linspace(0.1, 0.1); + logits.linspace(0.1, 0.1); - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {}); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {}); - ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(Status::OK(), results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test2) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {3, 4}, + {0.26328, 1.46328, 1.72656, 0., 0.26328, 0., 1.46328, 0.26328, 1.72656, + 0., 1.72656, 1.46328}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,0,0,1,0,1,0,1,1,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {3,4}, {0.26328, 1.46328, 1.72656, 0. , 0.26328, 0. , 1.46328, 0.26328, 1.72656, 0. , 1.72656, 1.46328}); - - logits.linspace(0.1, 0.1); - - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {0}); + logits.linspace(0.1, 0.1); - ASSERT_EQ(Status::OK(), results.status()); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {0}); - auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = results.at(0); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test3) { + auto labels = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + auto logits = NDArrayFactory::create('c', {2, 3, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 4}, + {0.75125, 1.55125, 3.45375, 0.75125, 3.45375, 0., 2.3025, 1.15125}); - auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,0,0,1,0,1,0,1,1,1,0,1,0,1,0,0,1,1,0,1,0}); - auto logits = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,4}, {0.75125, 1.55125, 3.45375, 0.75125, 3.45375, 0. , 2.3025 , 1.15125}); + logits.linspace(0.1, 0.1); - logits.linspace(0.1, 0.1); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {1}); - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {1}); + ASSERT_EQ(Status::OK(), results.status()); - ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test4) { + auto labels = NDArrayFactory::create('c', {2, 3}, {0, 1, 1, 0, 0, 1}); + auto logits = NDArrayFactory::create('c', {2, 3}); + auto expected = NDArrayFactory::create('c', {2}, {2.10389, 1.00194}); - auto labels = NDArrayFactory::create('c', {2,3},{0,1,1,0,0,1}); - auto logits = NDArrayFactory::create('c', {2,3}); - auto expected = NDArrayFactory::create('c', {2}, {2.10389, 1.00194}); - - logits.linspace(0.1, 0.1); + logits.linspace(0.1, 0.1); - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {}); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {}); - ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(Status::OK(), results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test5) { + auto labels = NDArrayFactory::create('c', {2, 3}, {0, 1, 1, 0, 0, 1}); + auto logits = NDArrayFactory::create('c', {2, 3}); + auto expected = + NDArrayFactory::create('c', {3}, {0., 0.85436, 1.40871}); - auto labels = NDArrayFactory::create('c', {2,3},{0,1,1,0,0,1}); - auto logits = NDArrayFactory::create('c', {2,3}); - auto expected = NDArrayFactory::create('c', {3}, {0., 0.85436, 1.40871}); - - logits.linspace(0.1, 0.1); - - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {0}); + logits.linspace(0.1, 0.1); - ASSERT_EQ(Status::OK(), results.status()); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {0}); - auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = results.at(0); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test6) { + auto labels = NDArrayFactory::create('c', {2, 1}, {0, 1}); + auto logits = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create('c', {1}, {0.6444}); - auto labels = NDArrayFactory::create('c', {2,1}, {0,1}); - auto logits = NDArrayFactory::create('c', {2,1}); - auto expected = NDArrayFactory::create('c', {1}, {0.6444}); + logits.linspace(0.1, 0.1); - logits.linspace(0.1, 0.1); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {0}); - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {0}); + ASSERT_EQ(Status::OK(), results.status()); - ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test7) { + auto labels = NDArrayFactory::create('c', {2, 1}, {0, 1}); + auto logits = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create('c', {2}, {0., 0.}); - auto labels = NDArrayFactory::create('c', {2,1}, {0,1}); - auto logits = NDArrayFactory::create('c', {2,1}); - auto expected = NDArrayFactory::create('c', {2}, {0., 0.}); - - logits.linspace(0.1, 0.1); + logits.linspace(0.1, 0.1); - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {1}); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {1}); - ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(Status::OK(), results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test8) { + auto labels = NDArrayFactory::create('c', {2}, {0, 1}); + auto logits = NDArrayFactory::create('c', {2}); + auto expected = NDArrayFactory::create(0.6444); - auto labels = NDArrayFactory::create('c', {2}, {0,1}); - auto logits = NDArrayFactory::create('c', {2}); - auto expected = NDArrayFactory::create(0.6444); - - logits.linspace(0.1, 0.1); - - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {}); + logits.linspace(0.1, 0.1); - ASSERT_EQ(Status::OK(), results.status()); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {}); - auto output = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + auto output = results.at(0); - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test9) { + auto labels = NDArrayFactory::create('c', {1}, {0.}); + auto logits = NDArrayFactory::create('c', {1}, {0.2}); + auto expected = NDArrayFactory::create(0.); - auto labels = NDArrayFactory::create('c', {1}, {0.}); - auto logits = NDArrayFactory::create('c', {1}, {0.2}); - auto expected = NDArrayFactory::create(0.); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {}); - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {}); + ASSERT_EQ(Status::OK(), results.status()); - ASSERT_EQ(Status::OK(), results.status()); + auto output = results.at(0); - auto output = results.at(0); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test10) { + auto labels = NDArrayFactory::create('c', {1, 2}, {0, 1}); + auto logits = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create('c', {2}, {0., 0.}); - auto labels = NDArrayFactory::create('c', {1,2}, {0,1}); - auto logits = NDArrayFactory::create('c', {1,2}); - auto expected = NDArrayFactory::create('c', {2}, {0., 0.}); - - logits.linspace(0.1, 0.1); + logits.linspace(0.1, 0.1); - sd::ops::softmax_cross_entropy_loss_with_logits op; - auto results = op.evaluate({&logits, &labels}, {}, {0}); + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {0}); - ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(Status::OK(), results.status()); - auto output = results.at(0); + auto output = results.at(0); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, clipbynorm_test4) { + auto x = NDArrayFactory::create( + 'c', {3, 5}, + {0.7044955, 0.55606544, 0.15833677, 0.001874401, 0.61595726, 0.3924779, + 0.7414847, 0.4127324, 0.24026828, 0.26093036, 0.46741188, 0.01863421, + 0.08528871, 0.529365, 0.5510694}); + auto exp = NDArrayFactory::create( + 'c', {3, 5}, + {0.405392, 0.319980, 0.091113, 0.001079, 0.354444, 0.225846, 0.426676, + 0.237501, 0.138259, 0.150149, 0.268965, 0.010723, 0.049078, 0.304615, + 0.317105}); - auto x = NDArrayFactory::create('c', {3, 5}, {0.7044955, 0.55606544, 0.15833677, 0.001874401, 0.61595726, 0.3924779, 0.7414847, 0.4127324, 0.24026828, 0.26093036, 0.46741188, 0.01863421, 0.08528871, 0.529365, 0.5510694}); - auto exp = NDArrayFactory::create('c', {3, 5}, {0.405392, 0.319980, 0.091113, 0.001079, 0.354444, 0.225846, 0.426676, 0.237501, 0.138259, 0.150149, 0.268965, 0.010723, 0.049078, 0.304615, 0.317105}); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {1.f}, {}); - auto output = result.at(0); + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {1.f}, {}); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, clipbynorm_test5) { + // auto x = NDArrayFactory::create('c', {3, 5}, {1,2,3,4,5, 1,2,3,4,5, + // 1,2,3,4,5}); + auto x = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 2.89271, 3.50524, 4.00892, 6., 7., 7.71389, 7.88678, 8.01784, + 11., 12., 12.53507, 12.26833, 12.02676}); + // auto exp = NDArrayFactory::create('c', {3, 5}, + // {1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}); - // auto x = NDArrayFactory::create('c', {3, 5}, {1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5}); - auto x = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('c', {3, 5}, {1., 2., 2.89271, 3.50524, 4.00892, 6., 7., 7.71389, 7.88678, 8.01784, 11., 12., 12.53507, 12.26833, 12.02676}); - // auto exp = NDArrayFactory::create('c', {3, 5}, {1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}); - - x.linspace(1); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {15.f}, {0}); - auto output = result.at(0); + x.linspace(1); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {15.f}, {0}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, clipbynorm_test6) { + auto x = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 4.95434, 5.78006, 6.60578, 7.43151, 8.25723, 5.64288, + 6.15587, 6.66886, 7.18185, 7.69484}); - auto x = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 4.95434, 5.78006, 6.60578, 7.43151, 8.25723, 5.64288, 6.15587, 6.66886, 7.18185, 7.69484}); + x.linspace(1); - x.linspace(1); + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {15.f}, {1}); + auto output = result.at(0); - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {15.f}, {1}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, clipbynorm_test7) { + auto x = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create( + 'c', {3, 5}, + {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818, 3.40777, + 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636, 6.38957}); - auto x = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('c', {3, 5}, {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818 , 3.40777, 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636 , 6.38957}); - - x.linspace(1); + x.linspace(1); - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {15.f}, {0,1}); - auto output = result.at(0); + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {15.f}, {0, 1}); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, clipbynorm_test8) { + auto x = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create( + 'c', {3, 5}, + {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818, 3.40777, + 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636, 6.38957}); - auto x = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('c', {3, 5}, {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818 , 3.40777, 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636 , 6.38957}); - - x.linspace(1); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {15.}, {}); - auto output = result.at(0); + x.linspace(1); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {15.}, {}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, clipbynorm_test9) { + auto x = NDArrayFactory::create('c', {2}, {3., 4.}); + auto exp = NDArrayFactory::create('c', {2}, {2.4, 3.2}); - auto x = NDArrayFactory::create('c', {2}, {3., 4.}); - auto exp = NDArrayFactory::create('c', {2}, {2.4, 3.2}); + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {4.}, {}); + auto output = result.at(0); - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {4.}, {}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, clipbynorm_test10) { + auto x = NDArrayFactory::create(6.); + auto exp = NDArrayFactory::create(5.); - auto x = NDArrayFactory::create(6.); - auto exp = NDArrayFactory::create(5.); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {5.}, {}); - auto output = result.at(0); + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {5.}, {}); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, clipbynorm_test11) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1., 2., 3., 4., 4.44787, 5.33745, + 6.22702, 7.1166, 6.33046, 7.03384, 7.73723, 8.44061, + 13., 14., 15., 16., 15.12277, 16.01235, + 16.90192, 17.7915, 14.77107, 15.47446, 16.17784, 16.88123}); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1., 2., 3., 4., 4.44787, 5.33745, 6.22702, 7.1166 , 6.33046, 7.03384, 7.73723, 8.44061, - 13., 14., 15., 16., 15.12277, 16.01235, 16.90192, 17.7915 ,14.77107, 15.47446, 16.17784, 16.88123}); - - x.linspace(1); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {35.}, {0, 2}); - auto output = result.at(0); + x.linspace(1); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {35.}, {0, 2}); + auto output = result.at(0); - + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } - TEST_F(DeclarableOpsTests8, clipbynorm_test_tf_119_1) { - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5,6, 7, 8, 9}); - auto e = NDArrayFactory::create('c', {3, 3}, {0.03198684, 0.06397368, 0.09596053, 0.12794736, 0.15993419, 0.19192106, 0.22390789, 0.25589472, 0.28788155}); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {0.54}, {}); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto e = NDArrayFactory::create( + 'c', {3, 3}, + {0.03198684, 0.06397368, 0.09596053, 0.12794736, 0.15993419, 0.19192106, + 0.22390789, 0.25589472, 0.28788155}); - ASSERT_EQ(e, result.at(0)); + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {0.54}, {}); - + ASSERT_EQ(e, result.at(0)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMeanBP_test4) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto gradO1 = NDArrayFactory::create('c', {4}, {1., 2., 3., 4.}); + auto gradO2 = NDArrayFactory::create('c', {1, 4}, {1., 2., 3., 4.}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {0.333333, 0.666667, 1.000000, 1.333333, 0.333333, 0.666667, 1.000000, + 1.333333, 0.333333, 0.666667, 1.000000, 1.333333}); - auto x = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }); - auto gradO1 = NDArrayFactory::create('c', {4}, {1., 2., 3., 4.}); - auto gradO2 = NDArrayFactory::create('c', {1, 4}, {1., 2., 3., 4.}); - auto exp = NDArrayFactory::create('c', {3,4}, {0.333333, 0.666667, 1.000000, 1.333333, 0.333333, 0.666667, 1.000000, 1.333333, 0.333333, 0.666667, 1.000000, 1.333333}); - - sd::ops::reduce_mean_bp op; + sd::ops::reduce_mean_bp op; - auto result = op.evaluate({&x, &gradO1}, {0}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + auto result = op.evaluate({&x, &gradO1}, {0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); - result = op.evaluate({&x, &gradO2}, {1}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceMeanBP_test5) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto gradO1 = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto gradO2 = NDArrayFactory::create('c', {3, 1}, {1., 2., 3.}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {0.2500, 0.2500, 0.2500, 0.2500, 0.5000, 0.5000, 0.5000, 0.5000, 0.7500, + 0.7500, 0.7500, 0.7500}); - auto x = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }); - auto gradO1 = NDArrayFactory::create('c', {3}, {1., 2., 3.}); - auto gradO2 = NDArrayFactory::create('c', {3, 1}, {1., 2., 3.}); - auto exp = NDArrayFactory::create('c', {3,4}, {0.2500,0.2500,0.2500,0.2500, 0.5000,0.5000,0.5000,0.5000, 0.7500,0.7500,0.7500,0.7500}); - - sd::ops::reduce_mean_bp op; + sd::ops::reduce_mean_bp op; - auto result = op.evaluate({&x, &gradO1}, {0}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + auto result = op.evaluate({&x, &gradO1}, {0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); - result = op.evaluate({&x, &gradO2}, {1}, {1}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceStDevBP_test5) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto gradO1 = NDArrayFactory::create('c', {4}, {1., 2., 3., 4.}); + auto gradO2 = NDArrayFactory::create('c', {1, 4}, {1., 2., 3., 4.}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {-0.408248, -0.816497, -1.224745, -1.632993, 0.000000, 0.000000, 0.000000, + 0.000000, 0.408248, 0.816497, 1.224745, 1.632993}); - auto x = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }); - auto gradO1 = NDArrayFactory::create('c', {4}, {1., 2., 3., 4.}); - auto gradO2 = NDArrayFactory::create('c', {1, 4}, {1., 2., 3., 4.}); - auto exp = NDArrayFactory::create('c', {3,4}, {-0.408248, -0.816497, -1.224745, -1.632993, 0.000000, 0.000000, 0.000000, 0.000000, 0.408248, 0.816497, 1.224745, 1.632993}); + sd::ops::reduce_stdev_bp op; - sd::ops::reduce_stdev_bp op; + auto result = op.evaluate({&x, &gradO1}, {0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); - auto result = op.evaluate({&x, &gradO1}, {0}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - - result = op.evaluate({&x, &gradO2}, {1}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, zeros_as_test1) { + auto x = NDArrayFactory::create(10.f); + auto y = NDArrayFactory::create(100.f); + auto exp = NDArrayFactory::create(0.f); - auto x = NDArrayFactory::create(10.f); - auto y = NDArrayFactory::create(100.f); - auto exp = NDArrayFactory::create(0.f); - - sd::ops::zeros_as op; - - Nd4jStatus status = op.execute({&x}, {&y}, {}, {}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::zeros_as op; - ASSERT_TRUE(y.isSameShape(exp)); - ASSERT_TRUE(y.equalsTo(exp)); + Nd4jStatus status = op.execute({&x}, {&y}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(y.isSameShape(exp)); + ASSERT_TRUE(y.equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, zeros_as_test2) { + auto x = NDArrayFactory::create(10.f); + // auto y = NDArrayFactory::create(100.f); + auto exp = NDArrayFactory::create(0.f); - auto x = NDArrayFactory::create(10.f); - //auto y = NDArrayFactory::create(100.f); - auto exp = NDArrayFactory::create(0.f); + sd::ops::zeros_as op; - sd::ops::zeros_as op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto y = result.at(0); - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - auto y = result.at(0); - - ASSERT_TRUE(y.isSameShape(exp)); - ASSERT_TRUE(y.equalsTo(exp)); - + ASSERT_TRUE(y.isSameShape(exp)); + ASSERT_TRUE(y.equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, ones_as_test1) { + auto x = NDArrayFactory::create(10.); + auto y = NDArrayFactory::create(100.); + auto exp = NDArrayFactory::create(1.); - auto x = NDArrayFactory::create(10.); - auto y = NDArrayFactory::create(100.); - auto exp = NDArrayFactory::create(1.); - - sd::ops::ones_as op; - - Nd4jStatus status = op.execute({&x}, {&y}); - ASSERT_EQ(Status::OK(), status); + sd::ops::ones_as op; - ASSERT_TRUE(y.isSameShape(exp)); - ASSERT_TRUE(y.equalsTo(exp)); + Nd4jStatus status = op.execute({&x}, {&y}); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(y.isSameShape(exp)); + ASSERT_TRUE(y.equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, ones_as_test2) { + auto x = NDArrayFactory::create(10.); + // auto y = NDArrayFactory::create(100.); + auto exp = NDArrayFactory::create(1.); - auto x = NDArrayFactory::create(10.); - //auto y = NDArrayFactory::create(100.); - auto exp = NDArrayFactory::create(1.); + sd::ops::ones_as op; - sd::ops::ones_as op; - - auto results = op.evaluate({&x}); - ASSERT_EQ(Status::OK(), results.status()); - auto y = results.at(0); - ASSERT_TRUE(y.isSameShape(exp)); - ASSERT_TRUE(y.equalsTo(exp)); - - + auto results = op.evaluate({&x}); + ASSERT_EQ(Status::OK(), results.status()); + auto y = results.at(0); + ASSERT_TRUE(y.isSameShape(exp)); + ASSERT_TRUE(y.equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, ones_as_test3) { + auto x = NDArrayFactory::create(10.); + // auto y = NDArrayFactory::create(100.); + auto exp = NDArrayFactory::create(1.); - auto x = NDArrayFactory::create(10.); - //auto y = NDArrayFactory::create(100.); - auto exp = NDArrayFactory::create(1.); - - sd::ops::ones_as op; + sd::ops::ones_as op; - auto results = op.evaluate({&x}, {}, {}, {}, {sd::DataType::INT32}); - ASSERT_EQ(Status::OK(), results.status()); - auto y = results.at(0); + auto results = op.evaluate({&x}, {}, {}, {}, {sd::DataType::INT32}); + ASSERT_EQ(Status::OK(), results.status()); + auto y = results.at(0); - ASSERT_TRUE(y.isSameShape(exp)); - ASSERT_TRUE(y.equalsTo(exp)); - - + ASSERT_TRUE(y.isSameShape(exp)); + ASSERT_TRUE(y.equalsTo(exp)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, NormalizeMoments_SGO_1) { - - auto data = NDArrayFactory::create('c', {10, 10}); - data.linspace(1); - - auto means = data.reduceAlongDimension(reduce::Sum, {0}); - auto deviance = NDArrayFactory::create('c', {10}, {825., 825. , 825., 825., 825., 825., 825., 825., 825., 825. }); // data.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); // = NDArrayFactory::create('c', {10, 10}); - - auto counts = NDArrayFactory::create(10.0); - -// auto expMeans = NDArrayFactory::create('c', {10, 10}); - -// auto expDeviance = NDArrayFactory::create('c', {10, 10}); - auto squared = NDArrayFactory::create('c', {10, 10}); - data.applyTransform(transform::Square, squared); - auto ssSquared = squared.reduceAlongDimension(reduce::Sum, {0}); -// ssSquared->printBuffer("Sum squared"); -// squared.printBuffer("Squared"); - sd::ops::normalize_moments op; - auto results = op.evaluate({&counts, &means, &ssSquared}, {0.0}, {0}); - means /= counts; -// sd::ops::normalize_moments op; -// auto results = op.evaluate({&counts, means, deviance}, {0.0}, {}); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_EQ(results.size(), 2); - - auto outputMeans = results.at(0); - auto outputDeviance = results.at(1); - -// outputMeans->printIndexedBuffer("Means"); -// outputDeviance->printIndexedBuffer("Variance"); -// deviance.printIndexedBuffer("Expected"); -// means->printIndexedBuffer("Expected means"); - ASSERT_TRUE(means.isSameShape(outputMeans)); - ASSERT_TRUE(means.equalsTo(outputMeans)); - ASSERT_TRUE(deviance.isSameShape(outputDeviance)); - ASSERT_TRUE(deviance.equalsTo(outputDeviance)); - //delete deviance; -// ASSERT_TRUE(expMeans.isSameShape(outputMeans)); -// ASSERT_TRUE(expMeans.equalsTo(outputMeans)); -// ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); -// ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); - - + auto data = NDArrayFactory::create('c', {10, 10}); + data.linspace(1); + + auto means = data.reduceAlongDimension(reduce::Sum, {0}); + auto deviance = NDArrayFactory::create( + 'c', {10}, + {825., 825., 825., 825., 825., 825., 825., 825., 825., + 825.}); // data.varianceAlongDimension(variance::SummaryStatsVariance, + // false, {0}); // = NDArrayFactory::create('c', {10, + // 10}); + + auto counts = NDArrayFactory::create(10.0); + + // auto expMeans = NDArrayFactory::create('c', {10, 10}); + + // auto expDeviance = NDArrayFactory::create('c', {10, 10}); + auto squared = NDArrayFactory::create('c', {10, 10}); + data.applyTransform(transform::Square, squared); + auto ssSquared = squared.reduceAlongDimension(reduce::Sum, {0}); + // ssSquared->printBuffer("Sum squared"); + // squared.printBuffer("Squared"); + sd::ops::normalize_moments op; + auto results = op.evaluate({&counts, &means, &ssSquared}, {0.0}, {0}); + means /= counts; + // sd::ops::normalize_moments op; + // auto results = op.evaluate({&counts, means, deviance}, {0.0}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(results.size(), 2); + + auto outputMeans = results.at(0); + auto outputDeviance = results.at(1); + + // outputMeans->printIndexedBuffer("Means"); + // outputDeviance->printIndexedBuffer("Variance"); + // deviance.printIndexedBuffer("Expected"); + // means->printIndexedBuffer("Expected means"); + ASSERT_TRUE(means.isSameShape(outputMeans)); + ASSERT_TRUE(means.equalsTo(outputMeans)); + ASSERT_TRUE(deviance.isSameShape(outputDeviance)); + ASSERT_TRUE(deviance.equalsTo(outputDeviance)); + // delete deviance; + // ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + // ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + // ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); + // ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Moments_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto expMeans = + NDArrayFactory::create('c', {4}, {11.f, 12.f, 13.f, 14.f}); + auto expVariance = NDArrayFactory::create( + 'c', {4}, {46.666668f, 46.666668f, 46.66666f, 46.666668f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto expMeans = NDArrayFactory::create('c', {4}, {11.f, 12.f, 13.f, 14.f}); - auto expVariance = NDArrayFactory::create('c', {4}, {46.666668f, 46.666668f, 46.66666f, 46.666668f}); - x.linspace(1); - - sd::ops::moments op; - auto result = op.evaluate({&x}, {}, {0, 1}); - - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0, 1}); - auto outputMeans = result.at(0); - auto outputVariance = result.at(1); + ASSERT_EQ(Status::OK(), result.status()); -// outputMeans->printIndexedBuffer("Means"); -// outputVariance->printIndexedBuffer("Variance"); -// outputMeans->printShapeInfo("Result shape"); + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); + // outputMeans->printIndexedBuffer("Means"); + // outputVariance->printIndexedBuffer("Variance"); + // outputMeans->printShapeInfo("Result shape"); -// ASSERT_TRUE(exp.isSameShape(output)); -// ASSERT_TRUE(exp.equalsTo(output)); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expVariance.isSameShape(outputVariance)); - ASSERT_TRUE(expVariance.equalsTo(outputVariance)); - - + // ASSERT_TRUE(exp.isSameShape(output)); + // ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Moments_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto expMeans = + NDArrayFactory::create('c', {1, 1, 4}, {11.f, 12.f, 13.f, 14.f}); + auto expVariance = NDArrayFactory::create( + 'c', {1, 1, 4}, {46.666668f, 46.666668f, 46.66666f, 46.666668f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto expMeans = NDArrayFactory::create('c', {1,1,4}, {11.f, 12.f, 13.f, 14.f}); - auto expVariance = NDArrayFactory::create('c', {1,1,4}, {46.666668f, 46.666668f, 46.66666f, 46.666668f}); - x.linspace(1); - - sd::ops::moments op; - auto result = op.evaluate({&x}, {1.}, {0, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto outputMeans = result.at(0); - auto outputVariance = result.at(1); + sd::ops::moments op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + ASSERT_EQ(Status::OK(), result.status()); -// outputMeans->printIndexedBuffer("Means"); -// outputVariance->printIndexedBuffer("Variance"); -// outputMeans->printShapeInfo("Result shape"); + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); -// ASSERT_TRUE(exp.isSameShape(output)); -// ASSERT_TRUE(exp.equalsTo(output)); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expVariance.isSameShape(outputVariance)); - ASSERT_TRUE(expVariance.equalsTo(outputVariance)); + // outputMeans->printIndexedBuffer("Means"); + // outputVariance->printIndexedBuffer("Variance"); + // outputMeans->printShapeInfo("Result shape"); - + // ASSERT_TRUE(exp.isSameShape(output)); + // ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Moments_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto expMeans = + NDArrayFactory::create('c', {3}, {8.5f, 12.5f, 16.5f}); + auto expVariance = + NDArrayFactory::create('c', {3}, {37.25f, 37.25f, 37.25f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto expMeans = NDArrayFactory::create('c', {3}, {8.5f, 12.5f, 16.5f}); - auto expVariance = NDArrayFactory::create('c', {3}, {37.25f, 37.25f, 37.25f}); - x.linspace(1); - - sd::ops::moments op; - auto result = op.evaluate({&x}, {}, {0, 2}); - ASSERT_EQ(Status::OK(), result.status()); - - auto outputMeans = result.at(0); - auto outputVariance = result.at(1); + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0, 2}); + ASSERT_EQ(Status::OK(), result.status()); -// outputMeans->printIndexedBuffer("Means"); -// outputVariance->printIndexedBuffer("Variance"); -// outputMeans->printShapeInfo("Result shape"); + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); -// ASSERT_TRUE(exp.isSameShape(output)); -// ASSERT_TRUE(exp.equalsTo(output)); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expVariance.isSameShape(outputVariance)); - ASSERT_TRUE(expVariance.equalsTo(outputVariance)); + // outputMeans->printIndexedBuffer("Means"); + // outputVariance->printIndexedBuffer("Variance"); + // outputMeans->printShapeInfo("Result shape"); - + // ASSERT_TRUE(exp.isSameShape(output)); + // ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Moments_4) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto expMeans = + NDArrayFactory::create('c', {1, 3, 1}, {8.5f, 12.5f, 16.5f}); + auto expVariance = + NDArrayFactory::create('c', {1, 3, 1}, {37.25f, 37.25f, 37.25f}); + x.linspace(1); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto expMeans = NDArrayFactory::create('c', {1,3,1}, {8.5f, 12.5f, 16.5f}); - auto expVariance = NDArrayFactory::create('c', {1,3,1}, {37.25f, 37.25f, 37.25f}); - x.linspace(1); + sd::ops::moments op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::moments op; - auto result = op.evaluate({&x}, {1.}, {0, 2}); - ASSERT_EQ(Status::OK(), result.status()); + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); - auto outputMeans = result.at(0); - auto outputVariance = result.at(1); + // outputMeans->printIndexedBuffer("Means"); + // outputVariance->printIndexedBuffer("Variance"); + // outputMeans->printShapeInfo("Result shape"); -// outputMeans->printIndexedBuffer("Means"); -// outputVariance->printIndexedBuffer("Variance"); -// outputMeans->printShapeInfo("Result shape"); - -// ASSERT_TRUE(exp.isSameShape(output)); -// ASSERT_TRUE(exp.equalsTo(output)); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expVariance.isSameShape(outputVariance)); - ASSERT_TRUE(expVariance.equalsTo(outputVariance)); - - + // ASSERT_TRUE(exp.isSameShape(output)); + // ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Moments_6) { - auto expMeans = NDArrayFactory::create(12.5f); - auto expVariance = NDArrayFactory::create(47.916668f); - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1); + auto expMeans = NDArrayFactory::create(12.5f); + auto expVariance = NDArrayFactory::create(47.916668f); - sd::ops::moments op; - auto result = op.evaluate({&x}, {}, {0,1,2}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); - auto outputMeans = result.at(0); - auto outputVariance = result.at(1); + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + ASSERT_EQ(Status::OK(), result.status()); -// outputMeans->printIndexedBuffer("Means"); -// outputVariance->printIndexedBuffer("Variance"); -// outputMeans->printShapeInfo("Result shape"); + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expVariance.isSameShape(outputVariance)); - ASSERT_TRUE(expVariance.equalsTo(outputVariance)); + // outputMeans->printIndexedBuffer("Means"); + // outputVariance->printIndexedBuffer("Variance"); + // outputMeans->printShapeInfo("Result shape"); - + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, Test_Moments_7) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto expMeans = NDArrayFactory::create('c', {1, 1, 1}, {12.5f}); + auto expVariance = + NDArrayFactory::create('c', {1, 1, 1}, {47.916668f}); - auto expMeans = NDArrayFactory::create('c', {1,1,1}, {12.5f}); - auto expVariance = NDArrayFactory::create('c', {1,1,1}, {47.916668f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::moments op; + auto result = op.evaluate({&x}, {1.}, {0, 1, 2}); + ASSERT_EQ(Status::OK(), result.status()); - x.linspace(1); - // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); - sd::ops::moments op; - auto result = op.evaluate({&x}, {1.}, {0,1,2}); - ASSERT_EQ(Status::OK(), result.status()); + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); - auto outputMeans = result.at(0); - auto outputVariance = result.at(1); - -// outputMeans->printIndexedBuffer("Means"); -// outputVariance->printIndexedBuffer("Variance"); -// outputMeans->printShapeInfo("Result shape"); - ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); - ASSERT_TRUE(expVariance.isSameShape(outputVariance)); - ASSERT_TRUE(expVariance.equalsTo(outputVariance)); - - + // outputMeans->printIndexedBuffer("Means"); + // outputVariance->printIndexedBuffer("Variance"); + // outputMeans->printShapeInfo("Result shape"); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_01) { + auto x = NDArrayFactory::create( + 'c', {1, 1, 2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); - auto x = NDArrayFactory::create('c', {1, 1, 2, 5}, { 1.f, 2.f, 3.f, 4.f, 5.f, - 6.f, 7.f, 8.f, 9.f, 10.f} - ); - - auto exp = NDArrayFactory::create('c', {1, 1, 2, 5}, {0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, 0.4898979f, 0.46056613f, 0.43971977f, 0.5240003f, 0.6375767f}// 0.72760683, 0.4850712, 0.5848977, 0.67488194, -// 0.7581754, 0.58321184, 0.86747235, 0.4048204} - ); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 2, 5}, + {0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, 0.4898979f, + 0.46056613f, 0.43971977f, 0.5240003f, 0.6375767f} + // 0.72760683, 0.4850712, 0.5848977, 0.67488194, + // 0.7581754, 0.58321184, 0.86747235, 0.4048204} + ); - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - //ASSERT_TRUE(exp.isSameShape(out)); - //out->printBuffer("LRN out"); - //exp.printBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); - - + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(exp.isSameShape(out)); + // out->printBuffer("LRN out"); + // exp.printBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_02) { + auto x = NDArrayFactory::create('c', {1, 1, 1, 6}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto x = NDArrayFactory::create('c', {1, 1, 1, 6}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - - auto exp = NDArrayFactory::create('c', {1, 1, 1, 6}, { - 0.2581989f, 0.3592106f, 0.40089184f, 0.4193139f, 0.5360563f, 0.67936623f} - ); - - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); + auto exp = + NDArrayFactory::create('c', {1, 1, 1, 6}, + {0.2581989f, 0.3592106f, 0.40089184f, + 0.4193139f, 0.5360563f, 0.67936623f}); - ASSERT_EQ(Status::OK(), results.status()); - //ASSERT_TRUE(exp.isSameShape(out)); - //out->printIndexedBuffer("LRN out"); -// exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); - + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_03) { + auto x = NDArrayFactory::create( + 'c', {1, 1, 1, 10}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); + auto exp = NDArrayFactory::create( + 'c', {1, 1, 1, 10}, + {0.10425719f, 0.16843036f, 0.2095291f, 0.23652494f, 0.25449327f, + 0.3053919f, 0.35675305f, 0.4098524f, 0.46662825f, 0.52999896f}); - auto x = NDArrayFactory::create('c', {1, 1, 1, 10}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); - auto exp = NDArrayFactory::create('c', {1, 1, 1, 10}, {0.10425719f, 0.16843036f, 0.2095291f, 0.23652494f, 0.25449327f, 0.3053919f, 0.35675305f, 0.4098524f, 0.46662825f, 0.52999896f}); + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5}); + auto out = results.at(0); - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); - // out->printIndexedBuffer("LRN out"); -// exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_1) { + auto x = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {5.5f, 0.f, 0.3f, 5.5f, 8.6f, 0.f, 0.f, 0.4f, 1.5f, 1.f, 1.3f, 1.5f, 2.6f, + 2.f, 3.f, 1.4f}); - auto x = NDArrayFactory::create('c', {2, 2, 2, 2}, { 5.5f, 0.f, 0.3f, 5.5f, - 8.6f, 0.f, 0.f, 0.4f, - 1.5f, 1.f, 1.3f, 1.5f, - 2.6f, 2.f, 3.f, 1.4f} - ); - - auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, { - 0.98386997f, 0.f, 0.05358852f, 0.9824562f, - 0.99330735f, 0.f, 0.f, 0.37139067f, - 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, - 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f} - ); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, + {0.98386997f, 0.f, 0.05358852f, 0.9824562f, 0.99330735f, 0.f, 0.f, + 0.37139067f, 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, + 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f}); - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); -// out->printIndexedBuffer("LRN out"); -// exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); - - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_2) { - - auto x = NDArrayFactory::create('c', {3, 3, 5, 5}); - x.linspace(1); - - auto exp = NDArrayFactory::create('c', {3, 3, 5, 5}, { - 0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, - 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, - 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, - 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, - 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, - - 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, - 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, - 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, - 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, - 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, - - 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, - 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, - 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, - 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, - 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, - - - 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, - 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, - 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, - 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, - 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, - - 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, - 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, - 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, - 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, - 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, - - 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, - 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, - 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, - 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, - 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, - - - 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, - 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, - 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, - 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, - 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, - - 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, - 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, - 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, - 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, - 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, - - 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, - 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, - 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, - 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, - 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f } - ); -// - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); -// ASSERT_TRUE(exp.isSameShape(out)); -// out->printIndexedBuffer("LRN out"); -// exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create('c', {3, 3, 5, 5}); + x.linspace(1); + + auto exp = NDArrayFactory::create( + 'c', {3, 3, 5, 5}, + {0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, + 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, + 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, + 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, + 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, + + 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, + 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, + 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, + 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, + 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, + + 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, + 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, + 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, + 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, + 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, + + 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, + 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, + 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, + 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, + 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, + + 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, + 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, + 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, + 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, + 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, + + 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, + 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, + 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, + 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, + 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, + + 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, + 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, + 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, + 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, + 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, + + 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, + 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, + 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, + 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, + 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, + + 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, + 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, + 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, + 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, + 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f}); + // + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_3) { - - auto x = NDArrayFactory::create('c', {3, 3, 5, 5}); - x.linspace(1); - - auto exp = NDArrayFactory::create('c', {3, 3, 5, 5}, { - 0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, - 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, - 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, - 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, - 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, - - 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, - 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, - 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, - 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, - 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, - - 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, - 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, - 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, - 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, - 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, - - - 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, - 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, - 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, - 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, - 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, - - 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, - 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, - 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, - 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, - 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, - - 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, - 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, - 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, - 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, - 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, - - - 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, - 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, - 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, - 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, - 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, - - 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, - 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, - 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, - 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, - 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, - - 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, - 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, - 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, - 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, - 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f } - ); -// - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); -// ASSERT_TRUE(exp.isSameShape(out)); -// out->printIndexedBuffer("LRN out"); -// exp.printIndexedBuffer("LRN exp"); - ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create('c', {3, 3, 5, 5}); + x.linspace(1); + + auto exp = NDArrayFactory::create( + 'c', {3, 3, 5, 5}, + {0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, + 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, + 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, + 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, + 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, + + 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, + 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, + 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, + 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, + 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, + + 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, + 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, + 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, + 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, + 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, + + 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, + 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, + 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, + 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, + 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, + + 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, + 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, + 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, + 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, + 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, + + 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, + 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, + 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, + 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, + 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, + + 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, + 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, + 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, + 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, + 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, + + 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, + 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, + 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, + 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, + 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, + + 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, + 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, + 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, + 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, + 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f}); + // + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_4) { + // auto x = NDArrayFactory::create('c', {8, 32, 64, 64}); + auto x = NDArrayFactory::create('c', {2, 8, 16, 16}); + x.linspace(1); - // auto x = NDArrayFactory::create('c', {8, 32, 64, 64}); - auto x = NDArrayFactory::create('c', {2, 8, 16, 16}); - x.linspace(1); + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); - sd::ops::lrn op; - auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); -// ASSERT_TRUE(exp.isSameShape(out)); -// out->printIndexedBuffer("LRN out"); -// exp.printIndexedBuffer("LRN exp"); -// ASSERT_TRUE(exp.equalsTo(out)); - - + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + // ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_4_119) { - int iterations = 1000; - // auto x = NDArrayFactory::create('c', {8, 32, 64, 64}); - // auto z = NDArrayFactory::create('c', {8, 32, 64, 64}); - auto x = NDArrayFactory::create('c', {2, 8, 16, 16}); - auto z = NDArrayFactory::create('c', {2, 8, 16, 16}); - x.linspace(1); + int iterations = 1000; + // auto x = NDArrayFactory::create('c', {8, 32, 64, 64}); + // auto z = NDArrayFactory::create('c', {8, 32, 64, 64}); + auto x = NDArrayFactory::create('c', {2, 8, 16, 16}); + auto z = NDArrayFactory::create('c', {2, 8, 16, 16}); + x.linspace(1); - sd::ops::lrn op; + sd::ops::lrn op; - op.execute({&x}, {&z}, {1.0, 1.0, 0.5}, {2}); + op.execute({&x}, {&z}, {1.0, 1.0, 0.5}, {2}); - auto timeStart = std::chrono::system_clock::now(); + auto timeStart = std::chrono::system_clock::now(); - for (int e = 0; e < iterations; e++) - op.execute({&x}, {&z}, {1.0, 1.0, 0.5}, {2}); - - auto timeEnd = std::chrono::system_clock::now(); - auto spanTime = std::chrono::duration_cast ((timeEnd - timeStart) / iterations).count(); - auto ttlTime = std::chrono::duration_cast ((timeEnd - timeStart)).count(); + for (int e = 0; e < iterations; e++) + op.execute({&x}, {&z}, {1.0, 1.0, 0.5}, {2}); + auto timeEnd = std::chrono::system_clock::now(); + auto spanTime = std::chrono::duration_cast( + (timeEnd - timeStart) / iterations) + .count(); + auto ttlTime = std::chrono::duration_cast( + (timeEnd - timeStart)) + .count(); -// ASSERT_TRUE(exp.isSameShape(out)); -// ASSERT_TRUE(exp.equalsTo(out)); + // ASSERT_TRUE(exp.isSameShape(out)); + // ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_01) { - - auto x = NDArrayFactory::create( 'c', {1, 1, 1, 10}); - x.linspace(1); - auto eps = NDArrayFactory::create('c', {1,1,1,10}); - eps.linspace(1); -// -// auto exp = NDArrayFactory::create('c', {3,3,5,5}, { -// 0.238337, 0.309664, 0.334077, 0.376534, 0.342926, 0.370734, 0.362017, 0.354182, 0.379140, 0.376275, 0.380027, 0.368347, 0.356401, 0.378316, 0.381315, 0.382465, 0.370592, 0.357055, 0.377670, 0.382950, 0.383445, 0.371718, 0.357332, 0.377217, 0.383677, 0.383933, 0.372391, 0.357475, 0.376891, 0.384062, 0.384212, 0.372837, 0.357557, 0.376646, 0.384290, 0.384385, 0.373153, 0.357610, 0.376457, 0.384436, 0.384500, 0.373389, 0.357645, 0.376306, 0.384536, 0.384581, 0.373572, 0.357670, 0.376184, 0.384606, 0.384639, 0.373718, 0.357688, 0.376082, 0.384658, 0.384683, 0.373837, 0.357702, 0.375996, 0.384698, 0.384717, 0.373935, 0.357712, 0.375923, 0.384728, 0.384743, 0.374019, 0.357721, 0.375860, 0.384752, 0.384764, 0.374090, 0.357727, 0.375804, 0.384771, 0.384781, 0.374152, 0.357733, 0.375756, 0.384787, 0.384795, 0.374205, 0.357737, 0.375713, 0.384800, 0.384807, 0.374253, 0.357741, 0.375674, 0.384811, 0.384817, 0.374295, 0.357744, 0.375640, 0.384820, 0.384825, 0.374333, 0.357747, 0.375609, 0.384828, 0.384832, 0.374366, 0.357749, 0.375581, 0.384835, 0.384839, 0.374397, 0.357751, 0.375555, 0.384841, 0.384844, 0.374425, 0.357753, 0.375531, 0.384846, 0.384849, 0.374450, 0.357754, 0.375510, 0.384850, 0.384853, 0.374473, 0.357756, 0.375490, 0.384854, 0.384856, 0.374494, 0.357757, 0.375471, 0.384858, 0.384860, 0.374514, 0.357758, 0.375454, 0.384861, 0.384863, 0.374532, 0.357759, 0.375438, 0.384864, 0.384865, 0.374549, 0.357760, 0.375423, 0.384866, 0.384868, 0.374565, 0.357760, 0.375410, 0.384868, 0.384870, 0.374579, 0.357761, 0.375397, 0.384870, 0.384872, 0.374593, 0.357762, 0.375384, 0.384872, 0.384873, 0.374606, 0.357762, 0.375373, 0.384874, 0.384875, 0.374618, 0.357763, 0.375362, 0.384875, 0.384876, 0.374629, 0.357763, 0.375352, 0.384877, 0.384878, 0.374640, 0.357764, 0.375342, 0.384878, 0.384879, 0.374650, 0.357764, 0.375333, 0.384879, 0.384880, 0.374660, 0.357764, 0.375325, 0.384880, 0.384881, 0.374669, 0.357765, 0.375316, 0.384881, 0.384882, 0.374677, 0.357765, 0.375309, 0.384882, 0.384883, 0.374685, 0.357765, 0.375301, 0.384883, 0.384884, 0.374693, 0.357765, 0.375294, 0.384884, 0.384884, 0.374700, 0.357766, 0.375287, 0.384885, 0.384885, 0.374707, 0.357766, 0.375281, 0.384885, 0.384886, 0.374714, 0.357766, 0.375275, 0.384886} -// ); -/// - sd::ops::lrn_bp op; - auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {5}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); -// ASSERT_TRUE(exp.isSameShape(out)); - //out->printBuffer("LRN BP out"); - //exp.printBuffer("LRN BP exp"); - //ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create('c', {1, 1, 1, 10}); + x.linspace(1); + auto eps = NDArrayFactory::create('c', {1, 1, 1, 10}); + eps.linspace(1); + // + // auto exp = NDArrayFactory::create('c', {3,3,5,5}, { + // 0.238337, 0.309664, 0.334077, 0.376534, 0.342926, 0.370734, + // 0.362017, 0.354182, 0.379140, 0.376275, 0.380027, 0.368347, + // 0.356401, 0.378316, 0.381315, 0.382465, 0.370592, 0.357055, + // 0.377670, 0.382950, 0.383445, 0.371718, 0.357332, 0.377217, + // 0.383677, 0.383933, 0.372391, 0.357475, 0.376891, 0.384062, + // 0.384212, 0.372837, 0.357557, 0.376646, 0.384290, 0.384385, + // 0.373153, 0.357610, 0.376457, 0.384436, 0.384500, 0.373389, + // 0.357645, 0.376306, 0.384536, 0.384581, 0.373572, 0.357670, + // 0.376184, 0.384606, 0.384639, 0.373718, 0.357688, 0.376082, + // 0.384658, 0.384683, 0.373837, 0.357702, 0.375996, 0.384698, + // 0.384717, 0.373935, 0.357712, 0.375923, 0.384728, 0.384743, + // 0.374019, 0.357721, 0.375860, 0.384752, 0.384764, 0.374090, + // 0.357727, 0.375804, 0.384771, 0.384781, 0.374152, 0.357733, + // 0.375756, 0.384787, 0.384795, 0.374205, 0.357737, 0.375713, + // 0.384800, 0.384807, 0.374253, 0.357741, 0.375674, 0.384811, + // 0.384817, 0.374295, 0.357744, 0.375640, 0.384820, 0.384825, + // 0.374333, 0.357747, 0.375609, 0.384828, 0.384832, 0.374366, + // 0.357749, 0.375581, 0.384835, 0.384839, 0.374397, 0.357751, + // 0.375555, 0.384841, 0.384844, 0.374425, 0.357753, 0.375531, + // 0.384846, 0.384849, 0.374450, 0.357754, 0.375510, 0.384850, + // 0.384853, 0.374473, 0.357756, 0.375490, 0.384854, 0.384856, + // 0.374494, 0.357757, 0.375471, 0.384858, 0.384860, 0.374514, + // 0.357758, 0.375454, 0.384861, 0.384863, 0.374532, 0.357759, + // 0.375438, 0.384864, 0.384865, 0.374549, 0.357760, 0.375423, + // 0.384866, 0.384868, 0.374565, 0.357760, 0.375410, 0.384868, + // 0.384870, 0.374579, 0.357761, 0.375397, 0.384870, 0.384872, + // 0.374593, 0.357762, 0.375384, 0.384872, 0.384873, 0.374606, + // 0.357762, 0.375373, 0.384874, 0.384875, 0.374618, 0.357763, + // 0.375362, 0.384875, 0.384876, 0.374629, 0.357763, 0.375352, + // 0.384877, 0.384878, 0.374640, 0.357764, 0.375342, 0.384878, + // 0.384879, 0.374650, 0.357764, 0.375333, 0.384879, 0.384880, + // 0.374660, 0.357764, 0.375325, 0.384880, 0.384881, 0.374669, + // 0.357765, 0.375316, 0.384881, 0.384882, 0.374677, 0.357765, + // 0.375309, 0.384882, 0.384883, 0.374685, 0.357765, 0.375301, + // 0.384883, 0.384884, 0.374693, 0.357765, 0.375294, 0.384884, + // 0.384884, 0.374700, 0.357766, 0.375287, 0.384885, 0.384885, + // 0.374707, 0.357766, 0.375281, 0.384885, 0.384886, 0.374714, + // 0.357766, 0.375275, 0.384886} + // ); + /// + sd::ops::lrn_bp op; + auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {5}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(exp.isSameShape(out)); + // out->printBuffer("LRN BP out"); + // exp.printBuffer("LRN BP exp"); + // ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_02) { - - auto x = NDArrayFactory::create( 'c', {1, 1, 1, 10}); - x.linspace(1); - auto eps = NDArrayFactory::create('c', {1,1,1,10}); - eps.linspace(1); -// -// auto exp = NDArrayFactory::create('c', {3,3,5,5}, { -// 0.238337, 0.309664, 0.334077, 0.376534, 0.342926, 0.370734, 0.362017, 0.354182, 0.379140, 0.376275, 0.380027, 0.368347, 0.356401, 0.378316, 0.381315, 0.382465, 0.370592, 0.357055, 0.377670, 0.382950, 0.383445, 0.371718, 0.357332, 0.377217, 0.383677, 0.383933, 0.372391, 0.357475, 0.376891, 0.384062, 0.384212, 0.372837, 0.357557, 0.376646, 0.384290, 0.384385, 0.373153, 0.357610, 0.376457, 0.384436, 0.384500, 0.373389, 0.357645, 0.376306, 0.384536, 0.384581, 0.373572, 0.357670, 0.376184, 0.384606, 0.384639, 0.373718, 0.357688, 0.376082, 0.384658, 0.384683, 0.373837, 0.357702, 0.375996, 0.384698, 0.384717, 0.373935, 0.357712, 0.375923, 0.384728, 0.384743, 0.374019, 0.357721, 0.375860, 0.384752, 0.384764, 0.374090, 0.357727, 0.375804, 0.384771, 0.384781, 0.374152, 0.357733, 0.375756, 0.384787, 0.384795, 0.374205, 0.357737, 0.375713, 0.384800, 0.384807, 0.374253, 0.357741, 0.375674, 0.384811, 0.384817, 0.374295, 0.357744, 0.375640, 0.384820, 0.384825, 0.374333, 0.357747, 0.375609, 0.384828, 0.384832, 0.374366, 0.357749, 0.375581, 0.384835, 0.384839, 0.374397, 0.357751, 0.375555, 0.384841, 0.384844, 0.374425, 0.357753, 0.375531, 0.384846, 0.384849, 0.374450, 0.357754, 0.375510, 0.384850, 0.384853, 0.374473, 0.357756, 0.375490, 0.384854, 0.384856, 0.374494, 0.357757, 0.375471, 0.384858, 0.384860, 0.374514, 0.357758, 0.375454, 0.384861, 0.384863, 0.374532, 0.357759, 0.375438, 0.384864, 0.384865, 0.374549, 0.357760, 0.375423, 0.384866, 0.384868, 0.374565, 0.357760, 0.375410, 0.384868, 0.384870, 0.374579, 0.357761, 0.375397, 0.384870, 0.384872, 0.374593, 0.357762, 0.375384, 0.384872, 0.384873, 0.374606, 0.357762, 0.375373, 0.384874, 0.384875, 0.374618, 0.357763, 0.375362, 0.384875, 0.384876, 0.374629, 0.357763, 0.375352, 0.384877, 0.384878, 0.374640, 0.357764, 0.375342, 0.384878, 0.384879, 0.374650, 0.357764, 0.375333, 0.384879, 0.384880, 0.374660, 0.357764, 0.375325, 0.384880, 0.384881, 0.374669, 0.357765, 0.375316, 0.384881, 0.384882, 0.374677, 0.357765, 0.375309, 0.384882, 0.384883, 0.374685, 0.357765, 0.375301, 0.384883, 0.384884, 0.374693, 0.357765, 0.375294, 0.384884, 0.384884, 0.374700, 0.357766, 0.375287, 0.384885, 0.384885, 0.374707, 0.357766, 0.375281, 0.384885, 0.384886, 0.374714, 0.357766, 0.375275, 0.384886} -// ); -/// - sd::ops::lrn opFF; - sd::ops::lrn_bp opBP; - - const OpArgsHolder argsHolderFF({&x}, {1., 1., 0.5}, {5}); - const OpArgsHolder argsHolderBP({&x, &eps}, {1., 1., 0.5}, {5}); - - bool gradOK = true; //GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - //auto results = op.execute({&x, &eps}, {1.0, 1.0, 0.5}, {5}, {}, false, sd::DataType::DOUBLE); - //auto out = results.at(0); - - //ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradOK); - //out->printBuffer("LRN BP out"); - //exp.printBuffer("LRN BP exp"); - //ASSERT_TRUE(exp.equalsTo(out)); - - // + auto x = NDArrayFactory::create('c', {1, 1, 1, 10}); + x.linspace(1); + auto eps = NDArrayFactory::create('c', {1, 1, 1, 10}); + eps.linspace(1); + // + // auto exp = NDArrayFactory::create('c', {3,3,5,5}, { + // 0.238337, 0.309664, 0.334077, 0.376534, 0.342926, 0.370734, + // 0.362017, 0.354182, 0.379140, 0.376275, 0.380027, 0.368347, + // 0.356401, 0.378316, 0.381315, 0.382465, 0.370592, 0.357055, + // 0.377670, 0.382950, 0.383445, 0.371718, 0.357332, 0.377217, + // 0.383677, 0.383933, 0.372391, 0.357475, 0.376891, 0.384062, + // 0.384212, 0.372837, 0.357557, 0.376646, 0.384290, 0.384385, + // 0.373153, 0.357610, 0.376457, 0.384436, 0.384500, 0.373389, + // 0.357645, 0.376306, 0.384536, 0.384581, 0.373572, 0.357670, + // 0.376184, 0.384606, 0.384639, 0.373718, 0.357688, 0.376082, + // 0.384658, 0.384683, 0.373837, 0.357702, 0.375996, 0.384698, + // 0.384717, 0.373935, 0.357712, 0.375923, 0.384728, 0.384743, + // 0.374019, 0.357721, 0.375860, 0.384752, 0.384764, 0.374090, + // 0.357727, 0.375804, 0.384771, 0.384781, 0.374152, 0.357733, + // 0.375756, 0.384787, 0.384795, 0.374205, 0.357737, 0.375713, + // 0.384800, 0.384807, 0.374253, 0.357741, 0.375674, 0.384811, + // 0.384817, 0.374295, 0.357744, 0.375640, 0.384820, 0.384825, + // 0.374333, 0.357747, 0.375609, 0.384828, 0.384832, 0.374366, + // 0.357749, 0.375581, 0.384835, 0.384839, 0.374397, 0.357751, + // 0.375555, 0.384841, 0.384844, 0.374425, 0.357753, 0.375531, + // 0.384846, 0.384849, 0.374450, 0.357754, 0.375510, 0.384850, + // 0.384853, 0.374473, 0.357756, 0.375490, 0.384854, 0.384856, + // 0.374494, 0.357757, 0.375471, 0.384858, 0.384860, 0.374514, + // 0.357758, 0.375454, 0.384861, 0.384863, 0.374532, 0.357759, + // 0.375438, 0.384864, 0.384865, 0.374549, 0.357760, 0.375423, + // 0.384866, 0.384868, 0.374565, 0.357760, 0.375410, 0.384868, + // 0.384870, 0.374579, 0.357761, 0.375397, 0.384870, 0.384872, + // 0.374593, 0.357762, 0.375384, 0.384872, 0.384873, 0.374606, + // 0.357762, 0.375373, 0.384874, 0.384875, 0.374618, 0.357763, + // 0.375362, 0.384875, 0.384876, 0.374629, 0.357763, 0.375352, + // 0.384877, 0.384878, 0.374640, 0.357764, 0.375342, 0.384878, + // 0.384879, 0.374650, 0.357764, 0.375333, 0.384879, 0.384880, + // 0.374660, 0.357764, 0.375325, 0.384880, 0.384881, 0.374669, + // 0.357765, 0.375316, 0.384881, 0.384882, 0.374677, 0.357765, + // 0.375309, 0.384882, 0.384883, 0.374685, 0.357765, 0.375301, + // 0.384883, 0.384884, 0.374693, 0.357765, 0.375294, 0.384884, + // 0.384884, 0.374700, 0.357766, 0.375287, 0.384885, 0.384885, + // 0.374707, 0.357766, 0.375281, 0.384885, 0.384886, 0.374714, + // 0.357766, 0.375275, 0.384886} + // ); + /// + sd::ops::lrn opFF; + sd::ops::lrn_bp opBP; + + const OpArgsHolder argsHolderFF({&x}, {1., 1., 0.5}, {5}); + const OpArgsHolder argsHolderBP({&x, &eps}, {1., 1., 0.5}, {5}); + + bool gradOK = + true; // GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + // auto results = op.execute({&x, &eps}, {1.0, 1.0, 0.5}, {5}, {}, false, + // sd::DataType::DOUBLE); auto out = results.at(0); + + // ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradOK); + // out->printBuffer("LRN BP out"); + // exp.printBuffer("LRN BP exp"); + // ASSERT_TRUE(exp.equalsTo(out)); + + // } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_1) { - - auto x = NDArrayFactory::create( 'c', {3, 3, 5, 5}); - x.linspace(1); - auto eps = NDArrayFactory::create('c', {3,3,5,5}); - eps.linspace(1); -// -auto exp = NDArrayFactory::create('c', {3,3,5,5}, { - 0.238337f, 0.309664f, 0.334077f, 0.376534f, 0.342926f, 0.370734f, 0.362017f, 0.354182f, 0.379140f, 0.376275f, 0.380027f, 0.368347f, 0.356401f, 0.378316f, 0.381315f, 0.382465f, 0.370592f, 0.357055f, 0.377670f, 0.382950f, 0.383445f, 0.371718f, 0.357332f, 0.377217f, 0.383677f, 0.383933f, 0.372391f, 0.357475f, 0.376891f, 0.384062f, 0.384212f, 0.372837f, 0.357557f, 0.376646f, 0.384290f, 0.384385f, 0.373153f, 0.357610f, 0.376457f, 0.384436f, 0.384500f, 0.373389f, 0.357645f, 0.376306f, 0.384536f, 0.384581f, 0.373572f, 0.357670f, 0.376184f, 0.384606f, 0.384639f, 0.373718f, 0.357688f, 0.376082f, 0.384658f, 0.384683f, 0.373837f, 0.357702f, 0.375996f, 0.384698f, 0.384717f, 0.373935f, 0.357712f, 0.375923f, 0.384728f, 0.384743f, 0.374019f, 0.357721f, 0.375860f, 0.384752f, 0.384764f, 0.374090f, 0.357727f, 0.375804f, 0.384771f, 0.384781f, 0.374152f, 0.357733f, 0.375756f, 0.384787f, 0.384795f, 0.374205f, 0.357737f, 0.375713f, 0.384800f, 0.384807f, 0.374253f, 0.357741f, 0.375674f, 0.384811f, 0.384817f, 0.374295f, 0.357744f, 0.375640f, 0.384820f, 0.384825f, 0.374333f, 0.357747f, 0.375609f, 0.384828f, 0.384832f, 0.374366f, 0.357749f, 0.375581f, 0.384835f, 0.384839f, 0.374397f, 0.357751f, 0.375555f, 0.384841f, 0.384844f, 0.374425f, 0.357753f, 0.375531f, 0.384846f, 0.384849f, 0.374450f, 0.357754f, 0.375510f, 0.384850f, 0.384853f, 0.374473f, 0.357756f, 0.375490f, 0.384854f, 0.384856f, 0.374494f, 0.357757f, 0.375471f, 0.384858f, 0.384860f, 0.374514f, 0.357758f, 0.375454f, 0.384861f, 0.384863f, 0.374532f, 0.357759f, 0.375438f, 0.384864f, 0.384865f, 0.374549f, 0.357760f, 0.375423f, 0.384866f, 0.384868f, 0.374565f, 0.357760f, 0.375410f, 0.384868f, 0.384870f, 0.374579f, 0.357761f, 0.375397f, 0.384870f, 0.384872f, 0.374593f, 0.357762f, 0.375384f, 0.384872f, 0.384873f, 0.374606f, 0.357762f, 0.375373f, 0.384874f, 0.384875f, 0.374618f, 0.357763f, 0.375362f, 0.384875f, 0.384876f, 0.374629f, 0.357763f, 0.375352f, 0.384877f, 0.384878f, 0.374640f, 0.357764f, 0.375342f, 0.384878f, 0.384879f, 0.374650f, 0.357764f, 0.375333f, 0.384879f, 0.384880f, 0.374660f, 0.357764f, 0.375325f, 0.384880f, 0.384881f, 0.374669f, 0.357765f, 0.375316f, 0.384881f, 0.384882f, 0.374677f, 0.357765f, 0.375309f, 0.384882f, 0.384883f, 0.374685f, 0.357765f, 0.375301f, 0.384883f, 0.384884f, 0.374693f, 0.357765f, 0.375294f, 0.384884f, 0.384884f, 0.374700f, 0.357766f, 0.375287f, 0.384885f, 0.384885f, 0.374707f, 0.357766f, 0.375281f, 0.384885f, 0.384886f, 0.374714f, 0.357766f, 0.375275f, 0.384886f} - ); -/// - sd::ops::lrn_bp op; - auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {2}, {}, {}, false); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); -// ASSERT_TRUE(exp.isSameShape(out)); - // out->printBuffer("LRN BP out"); - // exp.printBuffer("LRN BP exp"); - //ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create('c', {3, 3, 5, 5}); + x.linspace(1); + auto eps = NDArrayFactory::create('c', {3, 3, 5, 5}); + eps.linspace(1); + // + auto exp = NDArrayFactory::create( + 'c', {3, 3, 5, 5}, + {0.238337f, 0.309664f, 0.334077f, 0.376534f, 0.342926f, 0.370734f, + 0.362017f, 0.354182f, 0.379140f, 0.376275f, 0.380027f, 0.368347f, + 0.356401f, 0.378316f, 0.381315f, 0.382465f, 0.370592f, 0.357055f, + 0.377670f, 0.382950f, 0.383445f, 0.371718f, 0.357332f, 0.377217f, + 0.383677f, 0.383933f, 0.372391f, 0.357475f, 0.376891f, 0.384062f, + 0.384212f, 0.372837f, 0.357557f, 0.376646f, 0.384290f, 0.384385f, + 0.373153f, 0.357610f, 0.376457f, 0.384436f, 0.384500f, 0.373389f, + 0.357645f, 0.376306f, 0.384536f, 0.384581f, 0.373572f, 0.357670f, + 0.376184f, 0.384606f, 0.384639f, 0.373718f, 0.357688f, 0.376082f, + 0.384658f, 0.384683f, 0.373837f, 0.357702f, 0.375996f, 0.384698f, + 0.384717f, 0.373935f, 0.357712f, 0.375923f, 0.384728f, 0.384743f, + 0.374019f, 0.357721f, 0.375860f, 0.384752f, 0.384764f, 0.374090f, + 0.357727f, 0.375804f, 0.384771f, 0.384781f, 0.374152f, 0.357733f, + 0.375756f, 0.384787f, 0.384795f, 0.374205f, 0.357737f, 0.375713f, + 0.384800f, 0.384807f, 0.374253f, 0.357741f, 0.375674f, 0.384811f, + 0.384817f, 0.374295f, 0.357744f, 0.375640f, 0.384820f, 0.384825f, + 0.374333f, 0.357747f, 0.375609f, 0.384828f, 0.384832f, 0.374366f, + 0.357749f, 0.375581f, 0.384835f, 0.384839f, 0.374397f, 0.357751f, + 0.375555f, 0.384841f, 0.384844f, 0.374425f, 0.357753f, 0.375531f, + 0.384846f, 0.384849f, 0.374450f, 0.357754f, 0.375510f, 0.384850f, + 0.384853f, 0.374473f, 0.357756f, 0.375490f, 0.384854f, 0.384856f, + 0.374494f, 0.357757f, 0.375471f, 0.384858f, 0.384860f, 0.374514f, + 0.357758f, 0.375454f, 0.384861f, 0.384863f, 0.374532f, 0.357759f, + 0.375438f, 0.384864f, 0.384865f, 0.374549f, 0.357760f, 0.375423f, + 0.384866f, 0.384868f, 0.374565f, 0.357760f, 0.375410f, 0.384868f, + 0.384870f, 0.374579f, 0.357761f, 0.375397f, 0.384870f, 0.384872f, + 0.374593f, 0.357762f, 0.375384f, 0.384872f, 0.384873f, 0.374606f, + 0.357762f, 0.375373f, 0.384874f, 0.384875f, 0.374618f, 0.357763f, + 0.375362f, 0.384875f, 0.384876f, 0.374629f, 0.357763f, 0.375352f, + 0.384877f, 0.384878f, 0.374640f, 0.357764f, 0.375342f, 0.384878f, + 0.384879f, 0.374650f, 0.357764f, 0.375333f, 0.384879f, 0.384880f, + 0.374660f, 0.357764f, 0.375325f, 0.384880f, 0.384881f, 0.374669f, + 0.357765f, 0.375316f, 0.384881f, 0.384882f, 0.374677f, 0.357765f, + 0.375309f, 0.384882f, 0.384883f, 0.374685f, 0.357765f, 0.375301f, + 0.384883f, 0.384884f, 0.374693f, 0.357765f, 0.375294f, 0.384884f, + 0.384884f, 0.374700f, 0.357766f, 0.375287f, 0.384885f, 0.384885f, + 0.374707f, 0.357766f, 0.375281f, 0.384885f, 0.384886f, 0.374714f, + 0.357766f, 0.375275f, 0.384886f}); + /// + sd::ops::lrn_bp op; + auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {2}, {}, {}, false); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(exp.isSameShape(out)); + // out->printBuffer("LRN BP out"); + // exp.printBuffer("LRN BP exp"); + // ASSERT_TRUE(exp.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_2) { - - auto x = NDArrayFactory::create( 'c', {3, 3, 5, 5}); - x.linspace(1); - - auto eps = NDArrayFactory::create('c', {3, 3, 5, 5}, { 0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, - 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, - 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, - 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, - 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, - - 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, - 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, - 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, - 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, - 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, - - 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, - 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, - 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, - 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, - 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, - - - 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, - 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, - 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, - 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, - 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, - - 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, - 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, - 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, - 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, - 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, - - 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, - 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, - 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, - 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, - 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, - - - 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, - 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, - 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, - 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, - 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, - - 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, - 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, - 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, - 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, - 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, - - 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, - 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, - 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, - 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, - 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f }); -// - auto exp = NDArrayFactory::create('c', {3,3,5,5}, { - 0.061538f, 0.055617f, 0.044643f, 0.050772f, 0.048019f, 0.030270f, 0.023819f, 0.019468f, 0.022074f, 0.023990f, 0.018221f, 0.014664f, 0.012182f, 0.013954f, 0.015685f, 0.012967f, 0.010563f, 0.008841f, 0.010185f, 0.011621f, 0.010052f, 0.008248f, 0.006934f, 0.008015f, 0.009222f, 0.008204f, 0.006764f, 0.005702f, 0.006606f, 0.007642f, 0.006929f, 0.005732f, 0.004841f, 0.005618f, 0.006523f, 0.005996f, 0.004973f, 0.004205f, 0.004887f, 0.005689f, 0.005284f, 0.004391f, 0.003717f, 0.004324f, 0.005044f, 0.004723f, 0.003931f, 0.003331f, 0.003877f, 0.004531f, 0.004270f, 0.003558f, 0.003017f, 0.003514f, 0.004112f, 0.003896f, 0.003250f, 0.002757f, 0.003213f, 0.003764f, 0.003582f, 0.002991f, 0.002539f, 0.002959f, 0.003470f, 0.003315f, 0.002770f, 0.002352f, 0.002743f, 0.003219f, 0.003085f, 0.002580f, 0.002191f, 0.002556f, 0.003002f, 0.002885f, 0.002414f, 0.002051f, 0.002393f, 0.002812f, 0.002709f, 0.002268f, 0.001927f, 0.002250f, 0.002645f, 0.002553f, 0.002138f, 0.001818f, 0.002122f, 0.002496f, 0.002415f, 0.002023f, 0.001720f, 0.002009f, 0.002363f, 0.002290f, 0.001920f, 0.001632f, 0.001906f, 0.002244f, 0.002178f, 0.001826f, 0.001553f, 0.001814f, 0.002136f, 0.002076f, 0.001741f, 0.001481f, 0.001731f, 0.002038f, 0.001984f, 0.001664f, 0.001416f, 0.001654f, 0.001949f, 0.001899f, 0.001593f, 0.001356f, 0.001584f, 0.001867f, 0.001821f, 0.001528f, 0.001301f, 0.001520f, 0.001792f, 0.001750f, 0.001469f, 0.001250f, 0.001461f, 0.001722f, 0.001683f, 0.001413f, 0.001203f, 0.001406f, 0.001658f, 0.001622f, 0.001362f, 0.001159f, 0.001355f, 0.001599f, 0.001565f, 0.001314f, 0.001119f, 0.001308f, 0.001543f, 0.001512f, 0.001270f, 0.001081f, 0.001264f, 0.001491f, 0.001462f, 0.001228f, 0.001046f, 0.001223f, 0.001443f, 0.001415f, 0.001189f, 0.001013f, 0.001184f, 0.001397f, 0.001372f, 0.001153f, 0.000982f, 0.001148f, 0.001355f, 0.001331f, 0.001118f, 0.000952f, 0.001114f, 0.001315f, 0.001292f, 0.001086f, 0.000925f, 0.001082f, 0.001277f, 0.001255f, 0.001055f, 0.000899f, 0.001051f, 0.001241f, 0.001221f, 0.001026f, 0.000874f, 0.001023f, 0.001208f, 0.001188f, 0.000999f, 0.000851f, 0.000996f, 0.001176f, 0.001157f, 0.000973f, 0.000829f, 0.000970f, 0.001145f, 0.001128f, 0.000949f, 0.000808f, 0.000945f, 0.001117f, 0.001100f, 0.000925f, 0.000788f, 0.000922f, 0.001089f, 0.001073f, 0.000903f, 0.000769f, 0.000900f, 0.001063f, 0.001048f, 0.000882f, 0.000751f, 0.000879f, 0.001038f, 0.001024f, 0.000861f, 0.000734f, 0.000859f, 0.001015f, 0.001001f, 0.000842f, 0.000717f, 0.000840f, 0.000992f} - // 0.009859f, 0.013075f, 0.013874f, 0.017893f, 0.022344f, 0.014551f, 0.012859f, 0.011511f, 0.013311f, 0.015834f, 0.012025f, 0.010047f, 0.008601f, 0.009920f, 0.011885f, 0.009505f, 0.007636f, 0.006299f, 0.007413f, 0.009095f, 0.007446f, 0.005743f, 0.004540f, 0.005533f, 0.007033f, 0.005821f, 0.004282f, 0.003209f, 0.004123f, 0.005491f, 0.004577f, 0.003198f, 0.002247f, 0.003097f, 0.004355f, 0.003652f, 0.002412f, 0.001565f, 0.002357f, 0.003517f, 0.002965f, 0.001844f, 0.001084f, 0.001821f, 0.002893f, 0.002451f, 0.001430f, 0.000741f, 0.001428f, 0.002422f, -0.111434f, -0.105946f, -0.100351f, -0.091868f, -0.083323f, -0.078775f, -0.076222f, -0.073291f, -0.067635f, -0.061692f, -0.058943f, -0.057832f, -0.056263f, -0.052198f, -0.047768f, -0.046002f, -0.045655f, -0.044839f, -0.041748f, -0.038271f, -0.037084f, -0.037161f, -0.036786f, -0.034331f, -0.031495f, 0.000077f, -0.000673f, -0.001181f, -0.000667f, 0.000079f, -0.000089f, -0.000802f, -0.001285f, -0.000793f, -0.000079f, -0.000228f, -0.000908f, -0.001368f, -0.000896f, -0.000212f, -0.000345f, -0.000996f, -0.001434f, -0.000981f, -0.000325f, -0.000444f, -0.001067f, -0.001487f, -0.001051f, -0.000421f, 0.000697f, 0.000188f, -0.000152f, 0.000210f, 0.000731f, 0.000650f, 0.000165f, -0.000161f, 0.000185f, 0.000683f, 0.000610f, 0.000145f, -0.000168f, 0.000164f, 0.000641f, 0.000574f, 0.000128f, -0.000172f, 0.000146f, 0.000604f, 0.000542f, 0.000113f, -0.000175f, 0.000131f, 0.000571f, -0.009490f, -0.010070f, -0.010409f, -0.009734f, -0.008834f, -0.008785f, -0.009351f, -0.009687f, -0.009054f, -0.008207f, -0.008167f, -0.008718f, -0.009050f, -0.008455f, -0.007654f, -0.007622f, -0.008159f, -0.008485f, -0.007924f, -0.007164f, -0.007138f, -0.007661f, -0.007981f, -0.007450f, -0.006728f, -0.000901f, -0.001327f, -0.001614f, -0.001310f, -0.000869f, -0.000913f, -0.001328f, -0.001607f, -0.001310f, -0.000882f, -0.000922f, -0.001326f, -0.001598f, -0.001309f, -0.000892f, -0.000930f, -0.001323f, -0.001588f, -0.001306f, -0.000900f, -0.000936f, -0.001319f, -0.001577f, -0.001302f, -0.000906f, 0.000339f, 0.000038f, -0.000164f, 0.000048f, 0.000355f, 0.000328f, 0.000035f, -0.000162f, 0.000045f, 0.000343f, 0.000318f, 0.000033f, -0.000159f, 0.000041f, 0.000332f, 0.000308f, 0.000030f, -0.000157f, 0.000039f, 0.000322f, 0.000299f, 0.000028f, -0.000155f, 0.000036f, 0.000312f, -0.004085f, -0.004479f, -0.004733f, -0.004396f, -0.003925f, -0.003925f, -0.004309f, -0.004558f, -0.004232f, -0.003775f, -0.003776f, -0.004151f, -0.004395f, -0.004079f, -0.003636f, -0.003637f, -0.004004f, -0.004242f, -0.003936f, -0.003505f, -0.003507f, -0.003866f, -0.004100f, -0.003802f, -0.003383f} - ); - - sd::ops::lrn_bp op; - auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {2}, {}, {}, false); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(out)); - //out->printBuffer("LRN BP out"); -// exp.printIndexedBuffer("LRN exp"); - // ASSERT_TRUE(exp.equalsTo(out)); - - + auto x = NDArrayFactory::create('c', {3, 3, 5, 5}); + x.linspace(1); + + auto eps = NDArrayFactory::create( + 'c', {3, 3, 5, 5}, + {0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, + 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, + 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, + 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, + 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, + + 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, + 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, + 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, + 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, + 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, + + 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, + 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, + 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, + 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, + 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, + + 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, + 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, + 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, + 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, + 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, + + 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, + 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, + 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, + 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, + 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, + + 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, + 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, + 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, + 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, + 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, + + 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, + 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, + 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, + 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, + 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, + + 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, + 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, + 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, + 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, + 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, + + 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, + 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, + 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, + 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, + 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f}); + // + auto exp = NDArrayFactory::create( + 'c', {3, 3, 5, 5}, + {0.061538f, 0.055617f, 0.044643f, 0.050772f, 0.048019f, 0.030270f, + 0.023819f, 0.019468f, 0.022074f, 0.023990f, 0.018221f, 0.014664f, + 0.012182f, 0.013954f, 0.015685f, 0.012967f, 0.010563f, 0.008841f, + 0.010185f, 0.011621f, 0.010052f, 0.008248f, 0.006934f, 0.008015f, + 0.009222f, 0.008204f, 0.006764f, 0.005702f, 0.006606f, 0.007642f, + 0.006929f, 0.005732f, 0.004841f, 0.005618f, 0.006523f, 0.005996f, + 0.004973f, 0.004205f, 0.004887f, 0.005689f, 0.005284f, 0.004391f, + 0.003717f, 0.004324f, 0.005044f, 0.004723f, 0.003931f, 0.003331f, + 0.003877f, 0.004531f, 0.004270f, 0.003558f, 0.003017f, 0.003514f, + 0.004112f, 0.003896f, 0.003250f, 0.002757f, 0.003213f, 0.003764f, + 0.003582f, 0.002991f, 0.002539f, 0.002959f, 0.003470f, 0.003315f, + 0.002770f, 0.002352f, 0.002743f, 0.003219f, 0.003085f, 0.002580f, + 0.002191f, 0.002556f, 0.003002f, 0.002885f, 0.002414f, 0.002051f, + 0.002393f, 0.002812f, 0.002709f, 0.002268f, 0.001927f, 0.002250f, + 0.002645f, 0.002553f, 0.002138f, 0.001818f, 0.002122f, 0.002496f, + 0.002415f, 0.002023f, 0.001720f, 0.002009f, 0.002363f, 0.002290f, + 0.001920f, 0.001632f, 0.001906f, 0.002244f, 0.002178f, 0.001826f, + 0.001553f, 0.001814f, 0.002136f, 0.002076f, 0.001741f, 0.001481f, + 0.001731f, 0.002038f, 0.001984f, 0.001664f, 0.001416f, 0.001654f, + 0.001949f, 0.001899f, 0.001593f, 0.001356f, 0.001584f, 0.001867f, + 0.001821f, 0.001528f, 0.001301f, 0.001520f, 0.001792f, 0.001750f, + 0.001469f, 0.001250f, 0.001461f, 0.001722f, 0.001683f, 0.001413f, + 0.001203f, 0.001406f, 0.001658f, 0.001622f, 0.001362f, 0.001159f, + 0.001355f, 0.001599f, 0.001565f, 0.001314f, 0.001119f, 0.001308f, + 0.001543f, 0.001512f, 0.001270f, 0.001081f, 0.001264f, 0.001491f, + 0.001462f, 0.001228f, 0.001046f, 0.001223f, 0.001443f, 0.001415f, + 0.001189f, 0.001013f, 0.001184f, 0.001397f, 0.001372f, 0.001153f, + 0.000982f, 0.001148f, 0.001355f, 0.001331f, 0.001118f, 0.000952f, + 0.001114f, 0.001315f, 0.001292f, 0.001086f, 0.000925f, 0.001082f, + 0.001277f, 0.001255f, 0.001055f, 0.000899f, 0.001051f, 0.001241f, + 0.001221f, 0.001026f, 0.000874f, 0.001023f, 0.001208f, 0.001188f, + 0.000999f, 0.000851f, 0.000996f, 0.001176f, 0.001157f, 0.000973f, + 0.000829f, 0.000970f, 0.001145f, 0.001128f, 0.000949f, 0.000808f, + 0.000945f, 0.001117f, 0.001100f, 0.000925f, 0.000788f, 0.000922f, + 0.001089f, 0.001073f, 0.000903f, 0.000769f, 0.000900f, 0.001063f, + 0.001048f, 0.000882f, 0.000751f, 0.000879f, 0.001038f, 0.001024f, + 0.000861f, 0.000734f, 0.000859f, 0.001015f, 0.001001f, 0.000842f, + 0.000717f, 0.000840f, 0.000992f} + // 0.009859f, 0.013075f, 0.013874f, 0.017893f, 0.022344f, 0.014551f, + // 0.012859f, 0.011511f, 0.013311f, 0.015834f, 0.012025f, 0.010047f, + // 0.008601f, 0.009920f, 0.011885f, 0.009505f, 0.007636f, 0.006299f, + // 0.007413f, 0.009095f, 0.007446f, 0.005743f, 0.004540f, 0.005533f, + // 0.007033f, 0.005821f, 0.004282f, 0.003209f, 0.004123f, 0.005491f, + // 0.004577f, 0.003198f, 0.002247f, 0.003097f, 0.004355f, 0.003652f, + // 0.002412f, 0.001565f, 0.002357f, 0.003517f, 0.002965f, 0.001844f, + // 0.001084f, 0.001821f, 0.002893f, 0.002451f, 0.001430f, 0.000741f, + // 0.001428f, 0.002422f, -0.111434f, -0.105946f, -0.100351f, + // -0.091868f, -0.083323f, -0.078775f, -0.076222f, -0.073291f, + // -0.067635f, -0.061692f, -0.058943f, -0.057832f, -0.056263f, + // -0.052198f, -0.047768f, -0.046002f, -0.045655f, -0.044839f, + // -0.041748f, -0.038271f, -0.037084f, -0.037161f, -0.036786f, + // -0.034331f, -0.031495f, 0.000077f, -0.000673f, -0.001181f, + // -0.000667f, 0.000079f, -0.000089f, -0.000802f, -0.001285f, + // -0.000793f, -0.000079f, -0.000228f, -0.000908f, -0.001368f, + // -0.000896f, -0.000212f, -0.000345f, -0.000996f, -0.001434f, + // -0.000981f, -0.000325f, -0.000444f, -0.001067f, -0.001487f, + // -0.001051f, -0.000421f, 0.000697f, 0.000188f, -0.000152f, 0.000210f, + // 0.000731f, 0.000650f, 0.000165f, -0.000161f, 0.000185f, 0.000683f, + // 0.000610f, 0.000145f, -0.000168f, 0.000164f, 0.000641f, 0.000574f, + // 0.000128f, -0.000172f, 0.000146f, 0.000604f, 0.000542f, 0.000113f, + // -0.000175f, 0.000131f, 0.000571f, -0.009490f, -0.010070f, + // -0.010409f, -0.009734f, -0.008834f, -0.008785f, -0.009351f, + // -0.009687f, -0.009054f, -0.008207f, -0.008167f, -0.008718f, + // -0.009050f, -0.008455f, -0.007654f, -0.007622f, -0.008159f, + // -0.008485f, -0.007924f, -0.007164f, -0.007138f, -0.007661f, + // -0.007981f, -0.007450f, -0.006728f, -0.000901f, -0.001327f, + // -0.001614f, -0.001310f, -0.000869f, -0.000913f, -0.001328f, + // -0.001607f, -0.001310f, -0.000882f, -0.000922f, -0.001326f, + // -0.001598f, -0.001309f, -0.000892f, -0.000930f, -0.001323f, + // -0.001588f, -0.001306f, -0.000900f, -0.000936f, -0.001319f, + // -0.001577f, -0.001302f, -0.000906f, 0.000339f, 0.000038f, + // -0.000164f, 0.000048f, 0.000355f, 0.000328f, 0.000035f, -0.000162f, + // 0.000045f, 0.000343f, 0.000318f, 0.000033f, -0.000159f, 0.000041f, + // 0.000332f, 0.000308f, 0.000030f, -0.000157f, 0.000039f, 0.000322f, + // 0.000299f, 0.000028f, -0.000155f, 0.000036f, 0.000312f, -0.004085f, + // -0.004479f, -0.004733f, -0.004396f, -0.003925f, -0.003925f, + // -0.004309f, -0.004558f, -0.004232f, -0.003775f, -0.003776f, + // -0.004151f, -0.004395f, -0.004079f, -0.003636f, -0.003637f, + // -0.004004f, -0.004242f, -0.003936f, -0.003505f, -0.003507f, + // -0.003866f, -0.004100f, -0.003802f, -0.003383f} + ); + + sd::ops::lrn_bp op; + auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {2}, {}, {}, false); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printBuffer("LRN BP out"); + // exp.printIndexedBuffer("LRN exp"); + // ASSERT_TRUE(exp.equalsTo(out)); } - - diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 0b00f1effb2f..a42cf64d856b 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -18,83 +18,78 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 22.06.2018 // - -#include "testlayers.h" -#include #include -#include #include #include +#include +#include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTests9 : public testing::Test { -public: - - DeclarableOpsTests9() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTests9() { + printf("\n"); + fflush(stdout); + } }; //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {3, 1}, {1., 2., 3.}); + auto gradO2 = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {-0.335410, -0.111803, 0.111803, 0.335410, -0.670820, -0.223607, 0.223607, + 0.670820, -1.006231, -0.335410, 0.335410, 1.006231}); - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {3,1}, {1.,2.,3.}); - auto gradO2 = NDArrayFactory::create('c', {3}, {1.,2.,3.}); - auto exp = NDArrayFactory::create('c', {3,4}, {-0.335410, -0.111803, 0.111803, 0.335410, -0.670820, -0.223607, 0.223607, 0.670820, -1.006231, -0.335410, 0.335410, 1.006231}); + x.linspace(1); - x.linspace(1); - - sd::ops::reduce_stdev_bp op; - - auto result = op.evaluate({&x, &gradO2}, {0,0}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - // output->printIndexedBuffer(); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,0}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::reduce_stdev_bp op; + auto result = op.evaluate({&x, &gradO2}, {0, 0}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + result = op.evaluate({&x, &gradO1}, {1, 0}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, reduceStDevBP_test03) { - - auto x = NDArrayFactory::create('c', {3,4}); - auto gradO1 = NDArrayFactory::create('c', {3,1}, {1.,2.,3.}); - auto gradO2 = NDArrayFactory::create('c', {3}, {1.,2.,3.}); - auto exp = NDArrayFactory::create('c', {3,4}, {-0.335410, -0.111803, 0.111803, 0.335410, -0.670820, -0.223607, 0.223607, 0.670820, -1.006231, -0.335410, 0.335410, 1.006231}); - auto axis = NDArrayFactory::create('c', {1}, {1}); - x.linspace(1); - - sd::ops::reduce_stdev_bp op; - - auto result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, false}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - // output->printIndexedBuffer(); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - - result = op.evaluate({&x, &gradO1}, {1,0}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {3, 1}, {1., 2., 3.}); + auto gradO2 = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {-0.335410, -0.111803, 0.111803, 0.335410, -0.670820, -0.223607, 0.223607, + 0.670820, -1.006231, -0.335410, 0.335410, 1.006231}); + auto axis = NDArrayFactory::create('c', {1}, {1}); + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, false}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + result = op.evaluate({&x, &gradO1}, {1, 0}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } /* @@ -110,13 +105,16 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) { double extraParams[] = {lambda}; Nd4jLong *buffer = new Nd4jLong[N]; - auto rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer); - if (rng == nullptr) - throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test1: RNG initialization failed !"); + auto rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, +(Nd4jPointer) buffer); if (rng == nullptr) throw +std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test1: RNG +initialization failed !"); - functions::random::RandomFunction::template execTransform>(rng, x.getBuffer(), x.shapeInfo(), extraParams); - const double actualMean = x.meanNumber().e(0); - const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + functions::random::RandomFunction::template +execTransform>(rng, x.getBuffer(), +x.shapeInfo(), extraParams); const double actualMean = +x.meanNumber().e(0); const double actualStd = +x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); ASSERT_NEAR(mean, actualMean, 0.01); ASSERT_NEAR(std, actualStd, 0.01); @@ -141,14 +139,18 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) { Nd4jLong *buffer = new Nd4jLong[N]; - auto rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer); - if (rng == nullptr) - throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test2: RNG initialization failed !"); + auto rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, +(Nd4jPointer) buffer); if (rng == nullptr) throw +std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test2: RNG +initialization failed !"); - functions::random::RandomFunction::template execTransform>(rng, y.getBuffer(), y.shapeInfo(), x.getBuffer(), x.shapeInfo(), extraParams); + functions::random::RandomFunction::template +execTransform>(rng, y.getBuffer(), +y.shapeInfo(), x.getBuffer(), x.shapeInfo(), extraParams); const double actualMean = x.meanNumber().e(0); - const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + const double actualStd = +x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); ASSERT_NEAR(mean, actualMean, 0.01); ASSERT_NEAR(std, actualStd, 0.01); @@ -170,13 +172,16 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) { double extraParams[] = {lambda}; Nd4jLong *buffer = new Nd4jLong[N]; - auto rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer); - if (rng == nullptr) - throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test1: RNG initialization failed !"); + auto rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, +(Nd4jPointer) buffer); if (rng == nullptr) throw +std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test1: RNG +initialization failed !"); - functions::random::RandomFunction::template execTransform>(rng, x.getBuffer(), x.shapeInfo(), extraParams); - const double actualMean = x.meanNumber().e(0); - const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + functions::random::RandomFunction::template +execTransform>(rng, x.getBuffer(), +x.shapeInfo(), extraParams); const double actualMean = +x.meanNumber().e(0); const double actualStd = +x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); ASSERT_NEAR(mean, actualMean, 0.01); ASSERT_NEAR(std, actualStd, 0.01); @@ -203,16 +208,20 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) { Nd4jLong *buffer = new Nd4jLong[N]; // Nd4jPointer extra[2]; #ifndef __CUDABLAS__ - sd::random::RandomBuffer* rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer); - if (rng == nullptr) - throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test2: RNG initialization failed !"); + sd::random::RandomBuffer* rng = (sd::random::RandomBuffer *) +initRandom(nullptr, 123, N, (Nd4jPointer) buffer); if (rng == nullptr) throw +std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test2: RNG +initialization failed !"); - functions::random::RandomFunction::template execTransform>(rng, y.getBuffer(), y.shapeInfo(), x.getBuffer(), x.shapeInfo(), extraParams); + functions::random::RandomFunction::template +execTransform>(rng, y.getBuffer(), +y.shapeInfo(), x.getBuffer(), x.shapeInfo(), extraParams); destroyRandom((Nd4jPointer) rng); #endif const double actualMean = x.meanNumber().e(0); - const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + const double actualStd = +x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); ASSERT_NEAR(mean, actualMean, 0.01); ASSERT_NEAR(std, actualStd, 0.01); @@ -224,2100 +233,2184 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) { */ TEST_F(DeclarableOpsTests9, ScalarOpTest_MixedOrders_1) { - auto x = NDArrayFactory::create('f', {2, 2}, {1.0, 3.0, 2.0, 4.0}); - auto e = NDArrayFactory::create('c', {2, 2}, {2.0, 3.0, 4.0, 5.0}); - auto z = NDArrayFactory::create('c', {2, 2}, {0.0, 0.0, 0.0, 0.0}); + auto x = NDArrayFactory::create('f', {2, 2}, {1.0, 3.0, 2.0, 4.0}); + auto e = NDArrayFactory::create('c', {2, 2}, {2.0, 3.0, 4.0, 5.0}); + auto z = NDArrayFactory::create('c', {2, 2}, {0.0, 0.0, 0.0, 0.0}); - x.applyScalar(scalar::Add, 1.0, z); + x.applyScalar(scalar::Add, 1.0, z); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test1) { + auto x0 = NDArrayFactory::create('c', {2, 3, 4}); + auto x1 = NDArrayFactory::create('c', {2, 2, 4}); + auto x2 = NDArrayFactory::create('c', {2, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {2, 6, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.}); - auto x0 = NDArrayFactory::create('c', {2,3,4}); - auto x1 = NDArrayFactory::create('c', {2,2,4}); - auto x2 = NDArrayFactory::create('c', {2,1,4}); - auto exp = NDArrayFactory::create('c', {2,6,4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, - 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.}); - - x0.linspace(1); - x1.linspace(1); - x2.linspace(1); + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test2) { + auto x0 = NDArrayFactory::create('c', {1, 3, 1}); + auto x1 = NDArrayFactory::create('c', {1, 2, 1}); + auto x2 = NDArrayFactory::create('c', {1, 1, 1}); + auto exp = NDArrayFactory::create('c', {1, 6, 1}, + {1.f, 2.f, 3.f, 1.f, 2.f, 1.f}); - auto x0 = NDArrayFactory::create('c', {1,3,1}); - auto x1 = NDArrayFactory::create('c', {1,2,1}); - auto x2 = NDArrayFactory::create('c', {1,1,1}); - auto exp = NDArrayFactory::create('c', {1,6,1}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f}); - - x0.linspace(1); - x1.linspace(1); - x2.linspace(1); - - sd::ops::concat op; + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); - auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test3) { + auto x0 = NDArrayFactory::create('c', {3}); + auto x1 = NDArrayFactory::create('c', {2}); + auto x2 = NDArrayFactory::create('c', {1}); + auto exp = + NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f}); - auto x0 = NDArrayFactory::create('c', {3}); - auto x1 = NDArrayFactory::create('c', {2}); - auto x2 = NDArrayFactory::create('c', {1}); - auto exp = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f}); - - x0.linspace(1); - x1.linspace(1); - x2.linspace(1); - - sd::ops::concat op; + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); - auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test4) { + auto x0 = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); + auto x1 = NDArrayFactory::create('c', {1, 1, 1}, {2.f}); + auto x2 = NDArrayFactory::create('c', {1, 1, 1}, {3.f}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {1.f, 2.f, 3.f}); - auto x0 = NDArrayFactory::create('c', {1,1,1}, {1.f}); - auto x1 = NDArrayFactory::create('c', {1,1,1}, {2.f}); - auto x2 = NDArrayFactory::create('c', {1,1,1}, {3.f}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {1.f, 2.f, 3.f}); - - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test5) { + auto x0 = NDArrayFactory::create(1.f); + auto x1 = NDArrayFactory::create('c', {1}, {2.f}); + auto x2 = NDArrayFactory::create(3.f); + auto exp = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto x0 = NDArrayFactory::create(1.f); - auto x1 = NDArrayFactory::create('c', {1}, {2.f}); - auto x2 = NDArrayFactory::create(3.f); - auto exp = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test6) { + auto x0 = NDArrayFactory::create(1.f); + auto x1 = NDArrayFactory::create('c', {2}, {2.f, 20.f}); + auto x2 = NDArrayFactory::create(3.f); + auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 20.f, 3.f}); - auto x0 = NDArrayFactory::create(1.f); - auto x1 = NDArrayFactory::create('c', {2}, {2.f, 20.f}); - auto x2 = NDArrayFactory::create(3.f); - auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 20.f, 3.f}); - - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test7) { + auto x0 = NDArrayFactory::create(1.f); + auto x1 = NDArrayFactory::create(2.f); + auto x2 = NDArrayFactory::create(3.f); + auto exp = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto x0 = NDArrayFactory::create(1.f); - auto x1 = NDArrayFactory::create(2.f); - auto x2 = NDArrayFactory::create(3.f); - auto exp = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test8) { + auto x0 = NDArrayFactory::create(1.f); + auto exp = NDArrayFactory::create('c', {1}, {1.f}); - auto x0 = NDArrayFactory::create(1.f); - auto exp = NDArrayFactory::create('c', {1}, {1.f}); - - sd::ops::concat op; - - auto result = op.evaluate({&x0}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test9) { + auto x0 = NDArrayFactory::create('c', {1}, {1.f}); + auto exp = NDArrayFactory::create('c', {1}, {1.f}); - auto x0 = NDArrayFactory::create('c', {1}, {1.f}); - auto exp = NDArrayFactory::create('c', {1}, {1.f}); - - sd::ops::concat op; - - auto result = op.evaluate({&x0}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test10) { + auto x0 = NDArrayFactory::create('c', {2, 3, 4}); + auto x1 = NDArrayFactory::create('f', {2, 2, 4}); + auto x2 = NDArrayFactory::create('c', {2, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {2, 6, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); - auto x0 = NDArrayFactory::create('c', {2,3,4}); - auto x1 = NDArrayFactory::create('f', {2,2,4}); - auto x2 = NDArrayFactory::create('c', {2,1,4}); - auto exp = NDArrayFactory::create('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, - 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); - - x0.linspace(1); - x1.linspace(1); - x2.linspace(1); + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test11) { + auto x0 = NDArrayFactory::create('c', {2, 3, 4}); + auto x1 = NDArrayFactory::create('f', {2, 2, 4}); + auto x2 = NDArrayFactory::create('f', {2, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {2, 6, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); - auto x0 = NDArrayFactory::create('c', {2,3,4}); - auto x1 = NDArrayFactory::create('f', {2,2,4}); - auto x2 = NDArrayFactory::create('f', {2,1,4}); - auto exp = NDArrayFactory::create('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, - 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); - - x0.linspace(1); - x1.linspace(1); - x2.linspace(1); - - sd::ops::concat op; + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); - auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test12) { + auto x0 = NDArrayFactory::create('c', {2, 3, 4}); + auto x1 = NDArrayFactory::create('f', {2, 2, 4}); + auto x2 = NDArrayFactory::create('f', {2, 1, 4}); + auto exp = NDArrayFactory::create( + 'c', {2, 6, 4}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); - auto x0 = NDArrayFactory::create('c', {2,3,4}); - auto x1 = NDArrayFactory::create('f', {2,2,4}); - auto x2 = NDArrayFactory::create('f', {2,1,4}); - auto exp = NDArrayFactory::create('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, - 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); - - x0.linspace(1); - x1.linspace(1); - x2.linspace(1); - - sd::ops::concat op; + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); - auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test13) { + auto x0 = NDArrayFactory::create('f', {2, 3, 4}); + auto x1 = NDArrayFactory::create('f', {2, 2, 4}); + auto x2 = NDArrayFactory::create('f', {2, 1, 4}); + auto exp = NDArrayFactory::create( + 'f', {2, 6, 4}, + {1.f, 13.f, 5.f, 17.f, 9.f, 21.f, 1.f, 9.f, 5.f, 13.f, 1.f, 5.f, + 2.f, 14.f, 6.f, 18.f, 10.f, 22.f, 2.f, 10.f, 6.f, 14.f, 2.f, 6.f, + 3.f, 15.f, 7.f, 19.f, 11.f, 23.f, 3.f, 11.f, 7.f, 15.f, 3.f, 7.f, + 4.f, 16.f, 8.f, 20.f, 12.f, 24.f, 4.f, 12.f, 8.f, 16.f, 4.f, 8.f}); - auto x0 = NDArrayFactory::create('f', {2,3,4}); - auto x1 = NDArrayFactory::create('f', {2,2,4}); - auto x2 = NDArrayFactory::create('f', {2,1,4}); - auto exp = NDArrayFactory::create('f', {2,6,4}, { 1.f, 13.f, 5.f, 17.f, 9.f, 21.f, 1.f, 9.f, 5.f, 13.f, 1.f, 5.f, 2.f, 14.f, 6.f, 18.f,10.f, 22.f, 2.f, 10.f, 6.f, 14.f, 2.f, 6.f, - 3.f, 15.f, 7.f, 19.f,11.f, 23.f, 3.f, 11.f, 7.f, 15.f, 3.f, 7.f, 4.f, 16.f, 8.f, 20.f,12.f, 24.f, 4.f, 12.f, 8.f, 16.f, 4.f, 8.f}); - - x0.linspace(1); - x1.linspace(1); - x2.linspace(1); + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } TEST_F(DeclarableOpsTests9, concat_test14) { + NDArray x0('c', {1, 40, 60}, sd::DataType::DOUBLE); + NDArray x1('c', {1, 40, 60}, sd::DataType::DOUBLE); - NDArray x0('c', {1, 40, 60}, sd::DataType::DOUBLE); - NDArray x1('c', {1, 40, 60}, sd::DataType::DOUBLE); - - x0 = 1.; - x1 = 2.; + x0 = 1.; + x1 = 2.; - sd::ops::concat op; - auto result = op.evaluate({&x0, &x1}, {}, {0}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - - Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z.shapeInfo(), {0}); - ASSERT_TRUE(2 == numOfTads); - - for (int e = 0; e < numOfTads; ++e) { - NDArray tad = z(e, {0}); - auto mean = tad.meanNumber().e(0); - ASSERT_NEAR((e+1)*1., mean, 1e-5); - } + auto z = result.at(0); + Nd4jLong numOfTads = ShapeUtils::getNumOfSubArrs(z.shapeInfo(), {0}); + ASSERT_TRUE(2 == numOfTads); + for (int e = 0; e < numOfTads; ++e) { + NDArray tad = z(e, {0}); + auto mean = tad.meanNumber().e(0); + ASSERT_NEAR((e + 1) * 1., mean, 1e-5); + } } TEST_F(DeclarableOpsTests9, concat_test15) { - auto x = NDArrayFactory::create('c', {2}, {1, 0}); - auto y = NDArrayFactory::create (3.0f); - auto exp = NDArrayFactory::create('c', {3}, {1, 0, 3}); - - sd::ops::concat op; - auto result = op.evaluate({&x, &y}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {2}, {1, 0}); + auto y = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create('c', {3}, {1, 0, 3}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::concat op; + auto result = op.evaluate({&x, &y}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test16) { + auto x = NDArrayFactory::create('c', {0, 2, 3}); + auto y = NDArrayFactory::create('c', {0, 2, 3}); + auto exp = NDArrayFactory::create('c', {0, 2, 3}); - auto x = NDArrayFactory::create('c', {0,2,3}); - auto y = NDArrayFactory::create('c', {0,2,3}); - auto exp = NDArrayFactory::create('c', {0,2,3}); - - sd::ops::concat op; - auto result = op.evaluate({&x, &y}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::concat op; + auto result = op.evaluate({&x, &y}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test17) { + NDArray x0('c', {1, 55, 40}, sd::DataType::DOUBLE); + NDArray x1('c', {1, 55, 40}, sd::DataType::DOUBLE); - NDArray x0('c', {1, 55, 40}, sd::DataType::DOUBLE); - NDArray x1('c', {1, 55, 40}, sd::DataType::DOUBLE); + x0 = 1.; + x1 = 2.; - x0 = 1.; - x1 = 2.; + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::concat op; - auto result = op.evaluate({&x0, &x1}, {}, {0}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printShapeInfo(); + // z->printIndexedBuffer(); - auto z = result.at(0); - // z->printShapeInfo(); - // z->printIndexedBuffer(); + Nd4jLong numOfTads = ShapeUtils::getNumOfSubArrs(z.shapeInfo(), {0}); + ASSERT_TRUE(2 == numOfTads); - Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z.shapeInfo(), {0}); - ASSERT_TRUE(2 == numOfTads); - - for (int e = 0; e < numOfTads; ++e) { - NDArray tad = z(e, {0}); - auto mean = tad.meanNumber().e(0); - ASSERT_NEAR((e+1)*1., mean, 1e-5); - } + for (int e = 0; e < numOfTads; ++e) { + NDArray tad = z(e, {0}); + auto mean = tad.meanNumber().e(0); + ASSERT_NEAR((e + 1) * 1., mean, 1e-5); + } } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test18) { - Context context(1); - Nd4jLong axis = 0; + Context context(1); + Nd4jLong axis = 0; - // we crate bunch of arrays, filled with specific values - for (int e = 0; e < 2000; e++) { - auto array = NDArrayFactory::create('c', {1, 300}); - array.assign(e); - context.setInputArray(e, array); - } + // we crate bunch of arrays, filled with specific values + for (int e = 0; e < 2000; e++) { + auto array = NDArrayFactory::create('c', {1, 300}); + array.assign(e); + context.setInputArray(e, array); + } - auto z = NDArrayFactory::create('c', {2000, 300}); - context.setOutputArray(0, z); - context.setIArguments(&axis, 1); + auto z = NDArrayFactory::create('c', {2000, 300}); + context.setOutputArray(0, z); + context.setIArguments(&axis, 1); - sd::ops::concat op; - op.execute(&context); + sd::ops::concat op; + op.execute(&context); - for (int e = 0; e < 2000; e++) { - auto exp = NDArrayFactory::create('c', {300}); - exp.assign(e); - auto row = z(e, {0}); - ASSERT_EQ(exp, row); - } + for (int e = 0; e < 2000; e++) { + auto exp = NDArrayFactory::create('c', {300}); + exp.assign(e); + auto row = z(e, {0}); + ASSERT_EQ(exp, row); + } } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test19) { + Context context(1); + Nd4jLong axis = 0; - Context context(1); - Nd4jLong axis = 0; - - // we crate bunch of arrays, filled with specific values - for (int e = 0; e < 10; e++) { - auto array = NDArrayFactory::create('c', {1, 5, 20}); - array.assign(e); - context.setInputArray(e, array); - } + // we crate bunch of arrays, filled with specific values + for (int e = 0; e < 10; e++) { + auto array = NDArrayFactory::create('c', {1, 5, 20}); + array.assign(e); + context.setInputArray(e, array); + } - auto z = NDArrayFactory::create('c', {10, 5, 20}); - context.setOutputArray(0, z); - context.setIArguments(&axis, 1); + auto z = NDArrayFactory::create('c', {10, 5, 20}); + context.setOutputArray(0, z); + context.setIArguments(&axis, 1); - sd::ops::concat op; - op.execute(&context); + sd::ops::concat op; + op.execute(&context); - for (int e = 0; e < 10; e++) - ASSERT_NEAR((float) e, z(e, {0}).meanNumber().e(0), 1e-5f); + for (int e = 0; e < 10; e++) + ASSERT_NEAR((float)e, z(e, {0}).meanNumber().e(0), 1e-5f); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test20) { - auto x0 = NDArrayFactory::create('c', {1, 100, 150}); - auto x1 = NDArrayFactory::create('c', {1, 100, 150}); - auto x2 = NDArrayFactory::create('c', {1, 100, 150}); - auto x3 = NDArrayFactory::create('c', {1, 100, 150}); - - x0.assign(1.0); - x1.assign(2.0); - x2.assign(3.0); - x3.assign(4.0); - - sd::ops::concat op; - auto result = op.evaluate({&x0, &x1, &x2, &x3}, {}, {0}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x0 = NDArrayFactory::create('c', {1, 100, 150}); + auto x1 = NDArrayFactory::create('c', {1, 100, 150}); + auto x2 = NDArrayFactory::create('c', {1, 100, 150}); + auto x3 = NDArrayFactory::create('c', {1, 100, 150}); - auto z = result.at(0); + x0.assign(1.0); + x1.assign(2.0); + x2.assign(3.0); + x3.assign(4.0); - Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z.shapeInfo(), {0}); - ASSERT_TRUE(4 == numOfTads); + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2, &x3}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result.status()); - for (int e = 0; e < numOfTads; e++) { - NDArray tad = z(e, {0}); - auto mean = tad.meanNumber().e(0); - ASSERT_NEAR((double) e+1, mean, 1e-5); - } + auto z = result.at(0); + Nd4jLong numOfTads = ShapeUtils::getNumOfSubArrs(z.shapeInfo(), {0}); + ASSERT_TRUE(4 == numOfTads); + for (int e = 0; e < numOfTads; e++) { + NDArray tad = z(e, {0}); + auto mean = tad.meanNumber().e(0); + ASSERT_NEAR((double)e + 1, mean, 1e-5); + } } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test21) { + NDArray x0('c', {1, 4, 5}, sd::DataType::FLOAT32); + NDArray x1('c', {2, 4, 5}, sd::DataType::FLOAT32); + NDArray z('f', {3, 4, 5}, sd::DataType::FLOAT32); - NDArray x0('c', {1,4,5}, sd::DataType::FLOAT32); - NDArray x1('c', {2,4,5}, sd::DataType::FLOAT32); - NDArray z('f', {3,4,5}, sd::DataType::FLOAT32); + x0 = 0.; + x1 = 1.; - x0 = 0.; - x1 = 1.; - - sd::ops::concat op; - auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + sd::ops::concat op; + auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test22) { + NDArray x0('c', {1, 6}, {1, 2, 3, 4, 5, 6}); + NDArray x1('c', {1, 6}, {7, 8, 9, 10, 11, 12}); + NDArray output('f', {2, 6}, sd::DataType::DOUBLE); + NDArray exp('c', {2, 6}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - NDArray x0('c', {1,6}, {1,2,3,4,5,6}); - NDArray x1('c', {1,6}, {7,8,9,10,11,12}); - NDArray output('f', {2,6}, sd::DataType::DOUBLE); - NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12}); - - sd::ops::concat op; + sd::ops::concat op; - auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test23) { + NDArray x0('c', {1, 4}, {1, 2, 3, 4}); + NDArray x1('c', {1, 4}, {5, 6, 7, 8}); + NDArray output('c', {2, 4}, sd::DataType::DOUBLE); + NDArray exp('c', {2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray x0('c', {1,4}, {1,2,3,4}); - NDArray x1('c', {1,4}, {5,6,7,8}); - NDArray output('c', {2,4}, sd::DataType::DOUBLE); - NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8}); - - sd::ops::concat op; + sd::ops::concat op; - auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test24) { - auto x = NDArrayFactory::create('c', {2, 1}, {1, 1}); - auto y = NDArrayFactory::create('c', {2, 1}, {0, 0}); - auto e = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); - auto z = NDArrayFactory::create('c', {2, 2}); + auto x = NDArrayFactory::create('c', {2, 1}, {1, 1}); + auto y = NDArrayFactory::create('c', {2, 1}, {0, 0}); + auto e = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); + auto z = NDArrayFactory::create('c', {2, 2}); - sd::ops::concat op; - auto status = op.execute({&x, &y}, {&z}, {}, {1}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::concat op; + auto status = op.execute({&x, &y}, {&z}, {}, {1}, {}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test25) { + auto x0 = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto x1 = NDArrayFactory::create('c', {1, 4}, {5, 6, 7, 8}); + auto axis = NDArrayFactory::create('c', {1}, {0.}); + auto exp = + NDArrayFactory::create('c', {2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}); - auto x0 = NDArrayFactory::create('c', {1,4}, {1,2,3,4}); - auto x1 = NDArrayFactory::create('c', {1,4}, {5,6,7,8}); - auto axis = NDArrayFactory::create('c', {1}, {0.}); - auto exp = NDArrayFactory::create('c', {2,4}, {1,2,3,4,5,6,7,8}); - - sd::ops::concat op; - - auto result = op.evaluate({&x0, &x1, &axis}, {}, {}, {true}); + sd::ops::concat op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + auto result = op.evaluate({&x0, &x1, &axis}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test26) { + NDArray x0('f', {1, 2, 3}, sd::DataType::INT32); + NDArray x1('f', {1, 2, 3}, sd::DataType::INT32); + NDArray x2('f', {1, 2, 3}, sd::DataType::INT32); - NDArray x0('f', {1, 2, 3}, sd::DataType::INT32); - NDArray x1('f', {1, 2, 3}, sd::DataType::INT32); - NDArray x2('f', {1, 2, 3}, sd::DataType::INT32); - - NDArray exp('f', {3, 2, 3}, {0, 6, 12, 3, 9, 15, 1, 7, 13, 4, 10, 16, 2, 8, 14, 5, 11, 17}, sd::DataType::INT32); + NDArray exp('f', {3, 2, 3}, + {0, 6, 12, 3, 9, 15, 1, 7, 13, 4, 10, 16, 2, 8, 14, 5, 11, 17}, + sd::DataType::INT32); - x0.linspace(0); - x1.linspace(6); - x2.linspace(12); + x0.linspace(0); + x1.linspace(6); + x2.linspace(12); - sd::ops::concat op; + sd::ops::concat op; - auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}, {}); + auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test27) { + auto x1 = NDArrayFactory::create('c', {0, 1}); + auto x2 = NDArrayFactory::create('c', {0, 1}); + auto x3 = NDArrayFactory::create('c', {0, 1}); + auto x4 = NDArrayFactory::create('c', {0, 1}); - auto x1 = NDArrayFactory::create('c', {0,1}); - auto x2 = NDArrayFactory::create('c', {0,1}); - auto x3 = NDArrayFactory::create('c', {0,1}); - auto x4 = NDArrayFactory::create('c', {0,1}); + std::vector expShape = {0, 4}; - std::vector expShape = {0, 4}; + sd::ops::concat op; + auto result = op.evaluate({&x1, &x2, &x3, &x4}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::concat op; - auto result = op.evaluate({&x1, &x2, &x3, &x4}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_TRUE(z.isSameShape(expShape)); + ASSERT_TRUE(z.isSameShape(expShape)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test1) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto gradO = NDArrayFactory::create('c', {4, 9}); + auto gradIExp = NDArrayFactory::create( + 'c', {2, 3}, {0.78, 0.84, 0.9, 1.32, 1.38, 1.44}); - auto input = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); - auto gradO = NDArrayFactory::create('c', {4, 9}); - auto gradIExp = NDArrayFactory::create('c', {2, 3}, {0.78, 0.84, 0.9,1.32, 1.38, 1.44}); - - gradO.linspace(0.01, 0.01); - - sd::ops::tile_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {2, 3}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); + gradO.linspace(0.01, 0.01); + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {2, 3}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test2) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto gradO = NDArrayFactory::create('c', {2, 9}); + auto gradIExp = NDArrayFactory::create( + 'c', {2, 3}, {0.12, 0.15, 0.18, 0.39, 0.42, 0.45}); - auto input = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); - auto gradO = NDArrayFactory::create('c', {2, 9}); - auto gradIExp = NDArrayFactory::create('c', {2, 3}, {0.12, 0.15, 0.18, 0.39, 0.42, 0.45}); - - gradO.linspace(0.01, 0.01); - - sd::ops::tile_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {1, 3}); - auto gradI = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); - + gradO.linspace(0.01, 0.01); + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {1, 3}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test3) { + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto gradO = NDArrayFactory::create('c', {2, 3}); + auto gradIExp = NDArrayFactory::create( + 'c', {2, 3}, {0.01, 0.02, 0.03, 0.04, 0.05, 0.06}); - auto input = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); - auto gradO = NDArrayFactory::create('c', {2, 3}); - auto gradIExp = NDArrayFactory::create('c', {2, 3}, {0.01, 0.02, 0.03,0.04, 0.05, 0.06}); - - gradO.linspace(0.01, 0.01); - - sd::ops::tile_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {1, 1}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); + gradO.linspace(0.01, 0.01); + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {1, 1}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test4) { + auto input = + NDArrayFactory::create('c', {6}, {1., 2., 3., 4., 5., 6.}); + auto gradO = NDArrayFactory::create('c', {12}); + auto gradIExp = NDArrayFactory::create( + 'c', {6}, {0.08, 0.1, 0.12, 0.14, 0.16, 0.18}); - auto input = NDArrayFactory::create('c', {6}, {1.,2.,3.,4.,5.,6.}); - auto gradO = NDArrayFactory::create('c', {12}); - auto gradIExp = NDArrayFactory::create('c', {6}, {0.08, 0.1 , 0.12, 0.14, 0.16, 0.18}); - - gradO.linspace(0.01, 0.01); - - sd::ops::tile_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {2}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); + gradO.linspace(0.01, 0.01); + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {2}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test5) { + auto input = NDArrayFactory::create('c', {1}, {1.}); + auto gradO = NDArrayFactory::create('c', {1}); + auto gradIExp = NDArrayFactory::create('c', {1}, {0.01}); - auto input = NDArrayFactory::create('c', {1}, {1.}); - auto gradO = NDArrayFactory::create('c', {1}); - auto gradIExp = NDArrayFactory::create('c', {1}, {0.01}); - - gradO.linspace(0.01, 0.01); - - sd::ops::tile_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {1}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); + gradO.linspace(0.01, 0.01); + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {1}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test6) { + auto input = + NDArrayFactory::create('c', {2, 1, 3}, {1., 2., 3., 4., 5., 6.}); + auto gradO = NDArrayFactory::create('c', {2, 3, 6}); + auto gradIExp = NDArrayFactory::create( + 'c', {2, 1, 3}, {0.51, 0.57, 0.63, 1.59, 1.65, 1.71}); - auto input = NDArrayFactory::create('c', {2, 1, 3}, {1.,2.,3.,4.,5.,6.}); - auto gradO = NDArrayFactory::create('c', {2, 3, 6}); - auto gradIExp = NDArrayFactory::create('c', {2, 1, 3}, {0.51, 0.57, 0.63, 1.59, 1.65, 1.71}); - - gradO.linspace(0.01, 0.01); - - sd::ops::tile_bp op; - auto results = op.evaluate({&input, &gradO}, {}, {1, 3, 2}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); + gradO.linspace(0.01, 0.01); + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {1, 3, 2}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test7) { + auto input = + NDArrayFactory::create('c', {2, 1, 3}, {1., 2., 3., 4., 5., 6.}); + auto reps = NDArrayFactory::create('c', {1, 3}, {1, 3, 2}); + auto gradO = NDArrayFactory::create('c', {2, 3, 6}); + auto gradIExp = NDArrayFactory::create( + 'c', {2, 1, 3}, {0.51, 0.57, 0.63, 1.59, 1.65, 1.71}); - auto input = NDArrayFactory::create('c', {2, 1, 3}, {1.,2.,3.,4.,5.,6.}); - auto reps = NDArrayFactory::create('c', {1, 3}, {1, 3, 2}); - auto gradO = NDArrayFactory::create('c', {2, 3, 6}); - auto gradIExp = NDArrayFactory::create('c', {2, 1, 3}, {0.51, 0.57, 0.63, 1.59, 1.65, 1.71}); - - gradO.linspace(0.01, 0.01); - - sd::ops::tile_bp op; - auto results = op.evaluate({&input, &reps, &gradO}, {}, {}); - auto gradI = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); + gradO.linspace(0.01, 0.01); + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &reps, &gradO}, {}, {}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_test1) { + auto input = + NDArrayFactory::create('c', {1, 6}, {1., 2., 3., 4., 5., 6.}); + auto reps = NDArrayFactory::create('c', {1, 2}, {2, 1}); + auto expOut = NDArrayFactory::create( + 'c', + { + 2, + 6, + }, + {1., 2., 3., 4., 5., 6., 1., 2., 3., 4., 5., 6.}); - auto input = NDArrayFactory::create('c', {1, 6}, {1.,2.,3.,4.,5.,6.}); - auto reps = NDArrayFactory::create('c', {1, 2}, {2, 1}); - auto expOut = NDArrayFactory::create('c', {2, 6,}, {1.,2.,3.,4.,5.,6., 1.,2.,3.,4.,5.,6.}); - - sd::ops::tile op; - auto results = op.evaluate({&input, &reps}, {}, {}); - auto out = results.at(0); - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(expOut.isSameShape(out)); - ASSERT_TRUE(expOut.equalsTo(out)); - + sd::ops::tile op; + auto results = op.evaluate({&input, &reps}, {}, {}); + auto out = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOut.isSameShape(out)); + ASSERT_TRUE(expOut.equalsTo(out)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, TestDropout_BP_1) { + NDArray x('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + NDArray errs('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + NDArray shape('c', {2}, {2, 2}); + sd::ops::dropout_bp op; - NDArray x('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - NDArray errs('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - NDArray shape('c', {2}, {2, 2}); - sd::ops::dropout_bp op; - - auto ress = op.evaluate({&x, &errs, &shape}, {0.2f}, {113}); - - ASSERT_EQ(ND4J_STATUS_OK, ress.status()); - //ress.at(0)->printIndexedBuffer("Result is "); - //x.printIndexedBuffer("Input is"); - ASSERT_FALSE(ress.at(0).equalsTo(errs)); + auto ress = op.evaluate({&x, &errs, &shape}, {0.2f}, {113}); + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + // ress.at(0)->printIndexedBuffer("Result is "); + // x.printIndexedBuffer("Input is"); + ASSERT_FALSE(ress.at(0).equalsTo(errs)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, TestDropout_1) { - - NDArray x('c', {10, 10}, sd::DataType::FLOAT32); -// NDArray errs('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - //NDArray shape({2.f, 2.f}); - sd::ops::dropout op; - x.linspace(1); - auto ress = op.evaluate({&x}, {0.2f}, {113}); - - ASSERT_EQ(ND4J_STATUS_OK, ress.status()); - auto res = ress.at(0); //->printIndexedBuffer("Result is "); - //x.printIndexedBuffer("Input is"); - //res->printIndexedBuffer("Result for Dropout_1"); - auto countZero = res.reduceNumber(reduce::CountZero); - ASSERT_NEAR(countZero.e(0), 80, 5); - auto ress2 = op.evaluate({&x}, {0.2f}, {113}); - - ASSERT_EQ(ND4J_STATUS_OK, ress2.status()); - auto res2 = ress2.at(0); - - countZero = res.reduceNumber(reduce::CountZero); - ASSERT_NEAR(countZero.e(0), 80, 5); - //res2->printIndexedBuffer("Result for Dropout_2"); - ASSERT_TRUE(res.equalsTo(res2)); - //res->printIndexedBuffer("FF dropout"); - //res2->printIndexedBuffer("BP dropout"); - - - + NDArray x('c', {10, 10}, sd::DataType::FLOAT32); + // NDArray errs('c', {2, 2, 2}, + // {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + // NDArray shape({2.f, 2.f}); + sd::ops::dropout op; + x.linspace(1); + auto ress = op.evaluate({&x}, {0.2f}, {113}); + + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + auto res = ress.at(0); //->printIndexedBuffer("Result is "); + // x.printIndexedBuffer("Input is"); + // res->printIndexedBuffer("Result for Dropout_1"); + auto countZero = res.reduceNumber(reduce::CountZero); + ASSERT_NEAR(countZero.e(0), 80, 5); + auto ress2 = op.evaluate({&x}, {0.2f}, {113}); + + ASSERT_EQ(ND4J_STATUS_OK, ress2.status()); + auto res2 = ress2.at(0); + + countZero = res.reduceNumber(reduce::CountZero); + ASSERT_NEAR(countZero.e(0), 80, 5); + // res2->printIndexedBuffer("Result for Dropout_2"); + ASSERT_TRUE(res.equalsTo(res2)); + // res->printIndexedBuffer("FF dropout"); + // res2->printIndexedBuffer("BP dropout"); } TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) { - NDArray x0('c', {10, 10}, sd::DataType::FLOAT32); - NDArray x1('c', {10, 10}, sd::DataType::FLOAT32); - - x0.linspace(1); - x1.linspace(1); -/* - float prob[] = {0.5f}; - Nd4jLong* _bufferA = new Nd4jLong[100000]; - long _seed = 119L; - auto _rngA = (sd::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA); - - x0. applyTransform(random::DropOutInverted, &x0, prob); -// x1.template applyRandom>(_rngB, nullptr, &x1, prob); -// x0.printIndexedBuffer("01Result1"); - int count = 0; - for (int e = 0; e < x0.lengthOf(); e++) - if (x0.e(e) != 0.f) - count++; -// nd4j_printf("\nX0 count %i\n", count); -// ASSERT_TRUE(x0.equalsTo(&x1)); - - // this check is required to ensure we're calling wrong signature -// ASSERT_FALSE(x0.equalsTo(nexp0)); -// ASSERT_FALSE(x0.equalsTo(nexp1)); -// ASSERT_FALSE(x0.equalsTo(nexp2)); - destroyRandom(_rngA); - delete [] _bufferA; -*/ - sd::ops::dropout op; - - auto ress = op.evaluate({&x1}, {0.5f}, {119}); - - ASSERT_EQ(ND4J_STATUS_OK, ress.status()); - //ress.at(0)->printIndexedBuffer("01Dropout result is "); - auto count = ress.at(0).reduceNumber(reduce::CountNonZero); -// nd4j_printf("\n01Dropout count %i\n\n", count); - - sd::ops::dropout_bp op2; - //NDArray exp('c', {10,10}, {4.f, 0.f, 12.f, 0.f, 20.f, 24.f, 0.f, 32.f, 0.f, 0.f, 0.f, 0.f, 52.f, 56.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 84.f, 88.f, 0.f, 0.f, 0.f, 0.f, 108.f, 0.f, 0.f, 120.f, 0.f, 0.f, 132.f, 0.f, 0.f, 0.f, 0.f, 0.f, 156.f, 0.f, 164.f, 168.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 200.f, 204.f, 0.f, 0.f, 0.f, 220.f, 0.f, 0.f, 232.f, 236.f, 240.f, 0.f, 248.f, 0.f, 0.f, 260.f, 0.f, 0.f, 0.f, 276.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 316.f, 0.f, 324.f, 0.f, 0.f, 336.f, 0.f, 0.f, 0.f, 0.f, 356.f, 0.f, 0.f, 368.f, 0.f, 0.f, 0.f, 384.f, 388.f, 0.f, 0.f, 400.f}); - //02Dropout result is [4.000000, 0.000000, 12.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 36.000000, 0.000000, 0.000000, 0.000000, 0.000000, 56.000000, 60.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 88.000000, 0.000000, 96.000000, 0.000000, 0.000000, 108.000000, 0.000000, 0.000000, 120.000000, 0.000000, 128.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 156.000000, 0.000000, 164.000000, 0.000000, 0.000000, 0.000000, 0.000000, 184.000000, 0.000000, 0.000000, 0.000000, 200.000000, 0.000000, 0.000000, 0.000000, 216.000000, 0.000000, 0.000000, 0.000000, 232.000000, 0.000000, 240.000000, 0.000000, 248.000000, 0.000000, 0.000000, 260.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 308.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 348.000000, 0.000000, 356.000000, 0.000000, 0.000000, 0.000000, 0.000000, 376.000000, 0.000000, 384.000000, 0.000000, 0.000000, 0.000000, 400.000000] - - auto ressX = op2.evaluate({&x1, &x1}, {0.5f}, {119}); // , false, sd::DataType::FLOAT32); // skipped due given by default - //x0.printIndexedBuffer("X0"); - //x1.printIndexedBuffer("X1"); - ASSERT_EQ(ND4J_STATUS_OK, ressX.status()); - auto ressY = op2.evaluate({&x1, &x0}, {0.5f}, {119}); - ASSERT_EQ(ND4J_STATUS_OK, ressY.status()); - //ressY->at(0)->printIndexedBuffer("BP"); - //ress.at(0)->printIndexedBuffer("FF"); - bool ret = true; - for (int e = 0; e < ress.at(0).lengthOf(); e++) { - if (ress.at(0).e(e) == 0.f) - if (ressX.at(0).e(e) != ress.at(0).e(e)) { - ret = false; - break; - } - } - ASSERT_TRUE(ret); - // ASSERT_FALSE(ressX->at(0)->equalsTo(ressY->at(0))); - //ressX->at(0)->printIndexedBuffer("02Dropout result is "); -/* float countZero = ressX->at(0)->template reduceNumber>(); - ASSERT_NEAR(countZero, 50.f, 5.f); - countZero = ress.at(0)->template reduceNumber>(); - ASSERT_NEAR(countZero, 50.f, 5.f); - countZero = ressY->at(0)->template reduceNumber>(); - ASSERT_NEAR(countZero, 50.f, 5.f); - */ -// ASSERT_TRUE(exp.equalsTo(ressX->at(0))); - - + NDArray x0('c', {10, 10}, sd::DataType::FLOAT32); + NDArray x1('c', {10, 10}, sd::DataType::FLOAT32); + + x0.linspace(1); + x1.linspace(1); + /* + float prob[] = {0.5f}; + Nd4jLong* _bufferA = new Nd4jLong[100000]; + long _seed = 119L; + auto _rngA = (sd::random::RandomBuffer *) initRandom(nullptr, _seed, + 100000, (Nd4jPointer) _bufferA); + + x0. applyTransform(random::DropOutInverted, &x0, prob); + // x1.template applyRandom>(_rngB, + nullptr, &x1, prob); + // x0.printIndexedBuffer("01Result1"); + int count = 0; + for (int e = 0; e < x0.lengthOf(); e++) + if (x0.e(e) != 0.f) + count++; + // nd4j_printf("\nX0 count %i\n", count); + // ASSERT_TRUE(x0.equalsTo(&x1)); + + // this check is required to ensure we're calling wrong signature + // ASSERT_FALSE(x0.equalsTo(nexp0)); + // ASSERT_FALSE(x0.equalsTo(nexp1)); + // ASSERT_FALSE(x0.equalsTo(nexp2)); + destroyRandom(_rngA); + delete [] _bufferA; + */ + sd::ops::dropout op; + + auto ress = op.evaluate({&x1}, {0.5f}, {119}); + + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + // ress.at(0)->printIndexedBuffer("01Dropout result is "); + auto count = ress.at(0).reduceNumber(reduce::CountNonZero); + // nd4j_printf("\n01Dropout count %i\n\n", count); + + sd::ops::dropout_bp op2; + // NDArray exp('c', {10,10}, {4.f, 0.f, 12.f, 0.f, 20.f, 24.f, + // 0.f, 32.f, 0.f, 0.f, 0.f, 0.f, 52.f, 56.f, 60.f, 0.f, 0.f, 0.f, 0.f, + // 0.f, 84.f, 88.f, 0.f, 0.f, 0.f, 0.f, 108.f, 0.f, 0.f, 120.f, 0.f, 0.f, + // 132.f, 0.f, 0.f, 0.f, 0.f, 0.f, 156.f, 0.f, 164.f, 168.f, 0.f, 0.f, 0.f, + // 0.f, 0.f, 0.f, 0.f, 200.f, 204.f, 0.f, 0.f, 0.f, 220.f, 0.f, 0.f, 232.f, + // 236.f, 240.f, 0.f, 248.f, 0.f, 0.f, 260.f, 0.f, 0.f, 0.f, 276.f, 0.f, 0.f, + // 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 316.f, 0.f, 324.f, 0.f, 0.f, 336.f, 0.f, + // 0.f, 0.f, 0.f, 356.f, 0.f, 0.f, 368.f, 0.f, 0.f, 0.f, 384.f, 388.f, 0.f, + // 0.f, 400.f}); 02Dropout result is [4.000000, 0.000000, 12.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 36.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 56.000000, 60.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 88.000000, 0.000000, 96.000000, 0.000000, + // 0.000000, 108.000000, 0.000000, 0.000000, 120.000000, 0.000000, 128.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 156.000000, + // 0.000000, 164.000000, 0.000000, 0.000000, 0.000000, 0.000000, 184.000000, + // 0.000000, 0.000000, 0.000000, 200.000000, 0.000000, 0.000000, 0.000000, + // 216.000000, 0.000000, 0.000000, 0.000000, 232.000000, 0.000000, 240.000000, + // 0.000000, 248.000000, 0.000000, 0.000000, 260.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 308.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 348.000000, 0.000000, + // 356.000000, 0.000000, 0.000000, 0.000000, 0.000000, 376.000000, 0.000000, + // 384.000000, 0.000000, 0.000000, 0.000000, 400.000000] + + auto ressX = op2.evaluate({&x1, &x1}, {0.5f}, + {119}); // , false, sd::DataType::FLOAT32); // + // skipped due given by default + // x0.printIndexedBuffer("X0"); + // x1.printIndexedBuffer("X1"); + ASSERT_EQ(ND4J_STATUS_OK, ressX.status()); + auto ressY = op2.evaluate({&x1, &x0}, {0.5f}, {119}); + ASSERT_EQ(ND4J_STATUS_OK, ressY.status()); + // ressY->at(0)->printIndexedBuffer("BP"); + // ress.at(0)->printIndexedBuffer("FF"); + bool ret = true; + for (int e = 0; e < ress.at(0).lengthOf(); e++) { + if (ress.at(0).e(e) == 0.f) + if (ressX.at(0).e(e) != ress.at(0).e(e)) { + ret = false; + break; + } + } + ASSERT_TRUE(ret); + // ASSERT_FALSE(ressX->at(0)->equalsTo(ressY->at(0))); + // ressX->at(0)->printIndexedBuffer("02Dropout result is "); + /* float countZero = ressX->at(0)->template + reduceNumber>(); + ASSERT_NEAR(countZero, 50.f, 5.f); + countZero = ress.at(0)->template + reduceNumber>(); + ASSERT_NEAR(countZero, 50.f, 5.f); + countZero = ressY->at(0)->template + reduceNumber>(); + ASSERT_NEAR(countZero, 50.f, 5.f); + */ + // ASSERT_TRUE(exp.equalsTo(ressX->at(0))); } TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) { - NDArray x('c', {10, 10}, sd::DataType::FLOAT32); - - x.linspace(1); - - sd::ops::dropout op; + NDArray x('c', {10, 10}, sd::DataType::FLOAT32); - auto ress = op.evaluate({&x}, {0.5f}, {119}); + x.linspace(1); - ASSERT_EQ(ND4J_STATUS_OK, ress.status()); -// ress.at(0)->printIndexedBuffer("01Dropout result is "); + sd::ops::dropout op; - sd::ops::dropout_bp op2; + auto ress = op.evaluate({&x}, {0.5f}, {119}); - auto ressX = op2.evaluate({&x, &x}, {0.5f}, {119}); + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + // ress.at(0)->printIndexedBuffer("01Dropout result is "); - ASSERT_EQ(ND4J_STATUS_OK, ressX.status()); - auto ressY = op2.evaluate({&x, &x}, {0.5f}, {119}); - ASSERT_EQ(ND4J_STATUS_OK, ressY.status()); + sd::ops::dropout_bp op2; - //ress.at(0)->printIndexedBuffer("FF Dropout result is "); - //ressY->at(0)->printIndexedBuffer("BP Dropout result is "); + auto ressX = op2.evaluate({&x, &x}, {0.5f}, {119}); + ASSERT_EQ(ND4J_STATUS_OK, ressX.status()); + auto ressY = op2.evaluate({&x, &x}, {0.5f}, {119}); + ASSERT_EQ(ND4J_STATUS_OK, ressY.status()); - auto countZero = ress.at(0).reduceNumber(reduce::CountZero); - ASSERT_NEAR(countZero.e(0), 50.f, 10.f); - countZero = ressX.at(0).reduceNumber(reduce::CountZero); - //nd4j_printf("X zero count is %f\n", countZero); - ASSERT_NEAR(countZero.e(0), 50.f, 10.f); - countZero = ressY.at(0).reduceNumber(reduce::CountZero); - //nd4j_printf("Y zero count is %f\n", countZero); - ASSERT_NEAR(countZero.e(0), 50.f, 10.f); -// ASSERT_TRUE(exp.equalsTo(ressX->at(0))); - ASSERT_TRUE(ressX.at(0).equalsTo(ressY.at(0))); + // ress.at(0)->printIndexedBuffer("FF Dropout result is "); + // ressY->at(0)->printIndexedBuffer("BP Dropout result is "); + auto countZero = ress.at(0).reduceNumber(reduce::CountZero); + ASSERT_NEAR(countZero.e(0), 50.f, 10.f); + countZero = ressX.at(0).reduceNumber(reduce::CountZero); + // nd4j_printf("X zero count is %f\n", countZero); + ASSERT_NEAR(countZero.e(0), 50.f, 10.f); + countZero = ressY.at(0).reduceNumber(reduce::CountZero); + // nd4j_printf("Y zero count is %f\n", countZero); + ASSERT_NEAR(countZero.e(0), 50.f, 10.f); + // ASSERT_TRUE(exp.equalsTo(ressX->at(0))); + ASSERT_TRUE(ressX.at(0).equalsTo(ressY.at(0))); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, Test_AlphaDropout_BP_1) { - NDArray x('c', {10, 10}, sd::DataType::FLOAT32); - NDArray eps('c', {10, 10}, sd::DataType::FLOAT32); - - x.linspace(1); - eps.linspace(1); + NDArray x('c', {10, 10}, sd::DataType::FLOAT32); + NDArray eps('c', {10, 10}, sd::DataType::FLOAT32); - sd::ops::alpha_dropout_bp op; + x.linspace(1); + eps.linspace(1); - auto ress = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119}); + sd::ops::alpha_dropout_bp op; - ASSERT_EQ(ND4J_STATUS_OK, ress.status()); - auto res = ress.at(0); + auto ress = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119}); - auto ress2 = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119}); + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + auto res = ress.at(0); - ASSERT_EQ(ND4J_STATUS_OK, ress2.status()); - auto res2 = ress2.at(0); - //res->printIndexedBuffer("Result1AlphaBP1"); - //res2->printIndexedBuffer("Result1AlphaBP2"); - ASSERT_TRUE(res2.equalsTo(res)); + auto ress2 = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119}); + ASSERT_EQ(ND4J_STATUS_OK, ress2.status()); + auto res2 = ress2.at(0); + // res->printIndexedBuffer("Result1AlphaBP1"); + // res2->printIndexedBuffer("Result1AlphaBP2"); + ASSERT_TRUE(res2.equalsTo(res)); } TEST_F(DeclarableOpsTests9, test_range_int_1) { - auto x0 = NDArrayFactory::create(0); - auto x1 = NDArrayFactory::create(2); - auto x2 = NDArrayFactory::create(1); + auto x0 = NDArrayFactory::create(0); + auto x1 = NDArrayFactory::create(2); + auto x2 = NDArrayFactory::create(1); - sd::ops::range op; - auto result = op.evaluate({&x0, &x1, &x2}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::range op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); } TEST_F(DeclarableOpsTests9, test_range_empty_1) { - auto x0 = NDArrayFactory::create(0); - auto x1 = NDArrayFactory::create(0); - auto x2 = NDArrayFactory::create(1); + auto x0 = NDArrayFactory::create(0); + auto x1 = NDArrayFactory::create(0); + auto x2 = NDArrayFactory::create(1); - sd::ops::range op; - auto result = op.evaluate({&x0, &x1, &x2}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::range op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(z.isEmpty()); + ASSERT_TRUE(z.isEmpty()); } - TEST_F(DeclarableOpsTests9, test_broadcast_bool_1) { - auto x = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); - auto y = NDArrayFactory::create('c', {1, 2, 4, 4}); - auto z = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); + auto x = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); + auto y = NDArrayFactory::create('c', {1, 2, 4, 4}); + auto z = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); - std::vector dims = {0, 2, 3, 4}; - x.applyBroadcast(broadcast::LessThan, dims, y, z); + std::vector dims = {0, 2, 3, 4}; + x.applyBroadcast(broadcast::LessThan, dims, y, z); } TEST_F(DeclarableOpsTests9, test_broadcast_bool_2) { - auto orig = NDArrayFactory::create('c', {1, 7, 4, 4}); - std::vector list = {0,0, 0,2, 0,0, 0,0}; - auto x = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); - - auto y = orig(list, true); + auto orig = NDArrayFactory::create('c', {1, 7, 4, 4}); + std::vector list = {0, 0, 0, 2, 0, 0, 0, 0}; + auto x = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); - auto z = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); + auto y = orig(list, true); - std::vector dims = {0, 2, 3, 4}; - x.applyBroadcast(broadcast::LessThan, dims, y, z); + auto z = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); + std::vector dims = {0, 2, 3, 4}; + x.applyBroadcast(broadcast::LessThan, dims, y, z); } TEST_F(DeclarableOpsTests9, test_unstack_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.linspace(1.0); - sd::ops::unstack op; - auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(5, result.size()); + sd::ops::unstack op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(5, result.size()); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) { - auto x = NDArrayFactory::create({1, 2, 3, 4, 5}); - x.linspace(1.0); - auto z1 = NDArrayFactory::create(1); - auto z2 = NDArrayFactory::create(2); - auto z3 = NDArrayFactory::create(3); - auto z4 = NDArrayFactory::create(4); - auto z5 = NDArrayFactory::create(5); - std::vector z({&z1, &z2, &z3, &z4, &z5}); - sd::ops::unstack op; - auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(5, result.size()); - for (size_t i = 0; i < result.size(); i++) { - ASSERT_TRUE(result.at(i).isSameShape(z[i])); - ASSERT_TRUE(result.at(i).equalsTo(z[i])); - } - + auto x = NDArrayFactory::create({1, 2, 3, 4, 5}); + x.linspace(1.0); + auto z1 = NDArrayFactory::create(1); + auto z2 = NDArrayFactory::create(2); + auto z3 = NDArrayFactory::create(3); + auto z4 = NDArrayFactory::create(4); + auto z5 = NDArrayFactory::create(5); + std::vector z({&z1, &z2, &z3, &z4, &z5}); + sd::ops::unstack op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(5, result.size()); + for (size_t i = 0; i < result.size(); i++) { + ASSERT_TRUE(result.at(i).isSameShape(z[i])); + ASSERT_TRUE(result.at(i).equalsTo(z[i])); + } } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, clipbynorm_test12) { + const int bS = 5; + const int nOut = 4; + const int axis = 0; + const double clip = 2.; - const int bS = 5; - const int nOut = 4; - const int axis = 0; - const double clip = 2.; - - auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.897 ,0.173 ,0.931 ,0.736 ,0.540 ,0.953 ,0.278 ,0.573 ,0.787 ,0.320 ,0.776 ,0.338 ,0.311 ,0.835 ,0.909 ,0.890 ,0.290}); // uniform random in range [0,1] - auto colVect = NDArrayFactory::create('c', {bS, 1}, {0.9, 0.95, 1.00, 1.05, 1.1}); - auto expect = NDArrayFactory::create('c', {bS, nOut}); + auto x = NDArrayFactory::create( + 'c', {bS, nOut}, + {0.412, 0.184, 0.961, 0.897, 0.173, 0.931, 0.736, + 0.540, 0.953, 0.278, 0.573, 0.787, 0.320, 0.776, + 0.338, 0.311, 0.835, 0.909, 0.890, 0.290}); // uniform random in range + // [0,1] + auto colVect = NDArrayFactory::create('c', {bS, 1}, + {0.9, 0.95, 1.00, 1.05, 1.1}); + auto expect = NDArrayFactory::create('c', {bS, nOut}); - auto norm2 = x.reduceAlongDimension(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut] + auto norm2 = x.reduceAlongDimension(reduce::Norm2, {axis}, + true); // norm2 has shape [1, nOut] - auto y = ( (x / norm2) * clip) * colVect ; - auto temp = (x / norm2) * clip; + auto y = ((x / norm2) * clip) * colVect; + auto temp = (x / norm2) * clip; - for (int j = 0; j < nOut; ++j) { - auto yCol = y({0,0, j,j+1}); - const double norm2Col = yCol.reduceNumber(reduce::Norm2).e(0); - if (norm2Col <= clip) - expect({0,0, j,j+1}).assign(yCol); - else - expect({0,0, j,j+1}).assign ( yCol * (clip / norm2Col) ); - } - - sd::ops::clipbynorm op; - auto result = op.evaluate({&y}, {clip}, {axis}); - auto outFF = result.at(0); - - ASSERT_TRUE(expect.isSameShape(outFF)); - ASSERT_TRUE(expect.equalsTo(outFF)); + for (int j = 0; j < nOut; ++j) { + auto yCol = y({0, 0, j, j + 1}); + const double norm2Col = yCol.reduceNumber(reduce::Norm2).e(0); + if (norm2Col <= clip) + expect({0, 0, j, j + 1}).assign(yCol); + else + expect({0, 0, j, j + 1}).assign(yCol * (clip / norm2Col)); + } + sd::ops::clipbynorm op; + auto result = op.evaluate({&y}, {clip}, {axis}); + auto outFF = result.at(0); + ASSERT_TRUE(expect.isSameShape(outFF)); + ASSERT_TRUE(expect.equalsTo(outFF)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, clipbynorm_bp_test1) { + const int bS = 2; + const int nOut = 3; + const double clip = 0.7; - const int bS = 2; - const int nOut = 3; - const double clip = 0.7; + auto x = + NDArrayFactory::create('c', {bS, nOut}, + {0.412, 0.184, 0.961, 0.173, 0.736, + 0.540}); // uniform random in range [0,1] + auto gradO = NDArrayFactory::create('c', {bS, nOut}); - auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] - auto gradO = NDArrayFactory::create('c', {bS, nOut}); + const OpArgsHolder argsHolderFF({&x}, {clip}, {}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {}); - const OpArgsHolder argsHolderFF({&x}, {clip}, {}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {}); + sd::ops::clipbynorm opFF; + sd::ops::clipbynorm_bp opBP; - sd::ops::clipbynorm opFF; - sd::ops::clipbynorm_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, clipbynorm_bp_test2) { + const int bS = 2; + const int nOut = 3; + const int axis = 0; + const double clip = 0.7; - const int bS = 2; - const int nOut = 3; - const int axis = 0; - const double clip = 0.7; + auto x = + NDArrayFactory::create('c', {bS, nOut}, + {0.412, 0.184, 0.961, 0.173, 0.736, + 0.540}); // uniform random in range [0,1] + auto gradO = NDArrayFactory::create('c', {bS, nOut}); - auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] - auto gradO = NDArrayFactory::create('c', {bS, nOut}); + const OpArgsHolder argsHolderFF({&x}, {clip}, {axis}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis}); - const OpArgsHolder argsHolderFF({&x}, {clip}, {axis}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis}); + sd::ops::clipbynorm opFF; + sd::ops::clipbynorm_bp opBP; - sd::ops::clipbynorm opFF; - sd::ops::clipbynorm_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, clipbynorm_bp_test3) { + const int bS = 2; + const int nOut = 3; + const int axis = 1; + const double clip = 1.; - const int bS = 2; - const int nOut = 3; - const int axis = 1; - const double clip = 1.; - - auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] - auto gradO = NDArrayFactory::create('c', {bS, nOut}); + auto x = + NDArrayFactory::create('c', {bS, nOut}, + {0.412, 0.184, 0.961, 0.173, 0.736, + 0.540}); // uniform random in range [0,1] + auto gradO = NDArrayFactory::create('c', {bS, nOut}); - const OpArgsHolder argsHolderFF({&x}, {clip}, {axis}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis}); + const OpArgsHolder argsHolderFF({&x}, {clip}, {axis}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis}); - sd::ops::clipbynorm opFF; - sd::ops::clipbynorm_bp opBP; + sd::ops::clipbynorm opFF; + sd::ops::clipbynorm_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_1) { - - auto inputC = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto axis = NDArrayFactory::create(1); - - auto expFF = NDArrayFactory::create('c', {3, 5}, {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240.,11., 132.,1716., 24024.,360360.}); - auto expTF = NDArrayFactory::create('c', {3, 5}, {1, 1, 2, 6, 24,1, 6, 42, 336, 3024,1, 11, 132, 1716, 24024}); - - auto expFT = NDArrayFactory::create('c', {3, 5}, {120, 120, 60, 20, 5,30240, 5040, 720, 90, 10,360360, 32760, 2730, 210, 15}); //+++ - auto expTT = NDArrayFactory::create('c', {3, 5}, {120, 60, 20, 5, 1,5040, 720, 90, 10, 1,32760, 2730, 210, 15, 1}); - - int exclusive, reverse; - - //************************************// - exclusive = 0; reverse = 0; - - sd::ops::cumprod op; - auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - ASSERT_TRUE(expFF.equalsTo(z)); - - - //************************************// - exclusive = 1; reverse = 0; - - result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result.status()); - z = result.at(0); - ASSERT_TRUE(expTF.equalsTo(z)); - - - //************************************// - exclusive = 0; reverse = 1; - - result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result.status()); - z = result.at(0); - ASSERT_TRUE(expFT.equalsTo(z)); - - - //************************************// - exclusive = 1; reverse = 1; - - result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result.status()); - z = result.at(0); - ASSERT_TRUE(expTT.equalsTo(z)); - - + auto inputC = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto axis = NDArrayFactory::create(1); + + auto expFF = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240., 11., 132., 1716., + 24024., 360360.}); + auto expTF = NDArrayFactory::create( + 'c', {3, 5}, + {1, 1, 2, 6, 24, 1, 6, 42, 336, 3024, 1, 11, 132, 1716, 24024}); + + auto expFT = + NDArrayFactory::create('c', {3, 5}, + {120, 120, 60, 20, 5, 30240, 5040, 720, 90, + 10, 360360, 32760, 2730, 210, 15}); //+++ + auto expTT = NDArrayFactory::create( + 'c', {3, 5}, + {120, 60, 20, 5, 1, 5040, 720, 90, 10, 1, 32760, 2730, 210, 15, 1}); + + int exclusive, reverse; + + //************************************// + exclusive = 0; + reverse = 0; + + sd::ops::cumprod op; + auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(expFF.equalsTo(z)); + + //************************************// + exclusive = 1; + reverse = 0; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expTF.equalsTo(z)); + + //************************************// + exclusive = 0; + reverse = 1; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expFT.equalsTo(z)); + + //************************************// + exclusive = 1; + reverse = 1; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expTT.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_2) { + NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray x0 = x(0, {0}); + NDArray x1 = x(1, {0}); + x0.linspace(1, 0.1); + x1.linspace(1, 0.1); - NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray x0 = x(0, {0}); - NDArray x1 = x(1, {0}); - x0.linspace(1, 0.1); - x1.linspace(1, 0.1); + NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray exp0 = exp(0, {0}); + NDArray exp1 = exp(1, {0}); - NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); - NDArray exp0 = exp(0, {0}); - NDArray exp1 = exp(1, {0}); + exp0.p(0, 1.f); + exp1.p(0, 1.f); - exp0.p(0, 1.f); - exp1.p(0, 1.f); + for (int i = 1; i < 1500; ++i) { + const auto prev = exp0.e(i - 1); + exp0.p(i, prev * x0.e(i)); + exp1.p(i, prev * x1.e(i)); + } - for (int i = 1; i < 1500; ++i) { - const auto prev = exp0.e(i-1); - exp0.p(i, prev * x0.e(i)); - exp1.p(i, prev * x1.e(i)); - } - - sd::ops::cumprod op; - auto result = op.evaluate({&x}, {}, {0, 0, 1}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::cumprod op; + auto result = op.evaluate({&x}, {}, {0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_bp_check_1) { + auto x = NDArrayFactory::create('c', {4, 4}); + auto gradO = NDArrayFactory::create('c', {4, 4}); - auto x = NDArrayFactory::create('c', {4, 4}); - auto gradO = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); - x.linspace(1); + const OpArgsHolder argsHolderFF({&x}, {}, {0, 0}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {0, 0}); - const OpArgsHolder argsHolderFF({&x}, {}, {0, 0}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {0, 0}); + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; - sd::ops::cumprod opFF; - sd::ops::cumprod_bp opBP; + const bool isGradCorrect = GradCheck::checkGrad( + opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1}, GradCheck::MEAN); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_bp_check_2) { + auto x = NDArrayFactory::create('c', {4, 4}); + auto gradO = NDArrayFactory::create('c', {4, 4}); - auto x = NDArrayFactory::create('c', {4, 4}); - auto gradO = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); - x.linspace(1); + const OpArgsHolder argsHolderFF({&x}, {}, {1, 1}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {1, 1}); - const OpArgsHolder argsHolderFF({&x}, {}, {1, 1}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {1, 1}); + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; - sd::ops::cumprod opFF; - sd::ops::cumprod_bp opBP; + const bool isGradCorrect = GradCheck::checkGrad( + opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1}, GradCheck::MEAN); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_bp_check_3) { + auto x = NDArrayFactory::create('c', {4, 4}); + auto gradO = NDArrayFactory::create('c', {4, 4}); - auto x = NDArrayFactory::create('c', {4, 4}); - auto gradO = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); - x.linspace(1); + const OpArgsHolder argsHolderFF({&x}, {}, {1, 0}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {1, 0}); - const OpArgsHolder argsHolderFF({&x}, {}, {1, 0}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {1, 0}); + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; - sd::ops::cumprod opFF; - sd::ops::cumprod_bp opBP; + const bool isGradCorrect = GradCheck::checkGrad( + opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1}, GradCheck::MEAN); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_bp_check_4) { + auto x = NDArrayFactory::create('c', {4, 4}); + auto gradO = NDArrayFactory::create('c', {4, 4}); - auto x = NDArrayFactory::create('c', {4, 4}); - auto gradO = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); - x.linspace(1); + const OpArgsHolder argsHolderFF({&x}, {}, {0, 1}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {0, 1}); - const OpArgsHolder argsHolderFF({&x}, {}, {0, 1}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {0, 1}); + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; - sd::ops::cumprod opFF; - sd::ops::cumprod_bp opBP; + const bool isGradCorrect = GradCheck::checkGrad( + opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1}, GradCheck::MEAN); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumsum_bp_check_2) { + auto x = NDArrayFactory::create('c', {4, 4}); + auto gradO = NDArrayFactory::create('c', {4, 4}); - auto x = NDArrayFactory::create('c', {4, 4}); - auto gradO = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); - x.linspace(1); + const OpArgsHolder argsHolderFF({&x}, {}, {1, 1}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {1, 1}); - const OpArgsHolder argsHolderFF({&x}, {}, {1, 1}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {1, 1}); + sd::ops::cumsum opFF; + sd::ops::cumsum_bp opBP; - sd::ops::cumsum opFF; - sd::ops::cumsum_bp opBP; + const bool isGradCorrect = GradCheck::checkGrad( + opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1}, GradCheck::MEAN); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_test1) { + auto inputC = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto axis = NDArrayFactory::create(1.); - auto inputC = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); - auto axis = NDArrayFactory::create(1.); + auto expFF = NDArrayFactory::create( + 'c', {3, 5}, + {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240., 11., 132., 1716., + 24024., 360360.}); + auto expTF = NDArrayFactory::create( + 'c', {3, 5}, + {1, 1, 2, 6, 24, 1, 6, 42, 336, 3024, 1, 11, 132, 1716, 24024}); - auto expFF = NDArrayFactory::create('c', {3, 5}, {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240.,11., 132.,1716., 24024.,360360.}); - auto expTF = NDArrayFactory::create('c', {3, 5}, {1, 1, 2, 6, 24,1, 6, 42, 336, 3024,1, 11, 132, 1716, 24024}); + auto expFT = + NDArrayFactory::create('c', {3, 5}, + {120, 120, 60, 20, 5, 30240, 5040, 720, 90, + 10, 360360, 32760, 2730, 210, 15}); //+++ + auto expTT = NDArrayFactory::create( + 'c', {3, 5}, + {120, 60, 20, 5, 1, 5040, 720, 90, 10, 1, 32760, 2730, 210, 15, 1}); + auto gradO = NDArrayFactory::create('c', {3, 5}); - auto expFT = NDArrayFactory::create('c', {3, 5}, {120, 120, 60, 20, 5,30240, 5040, 720, 90, 10,360360, 32760, 2730, 210, 15}); //+++ - auto expTT = NDArrayFactory::create('c', {3, 5}, {120, 60, 20, 5, 1,5040, 720, 90, 10, 1,32760, 2730, 210, 15, 1}); - auto gradO = NDArrayFactory::create('c', {3, 5}); + int exclusive, reverse; - int exclusive, reverse; + //************************************// + exclusive = 0; + reverse = 0; - //************************************// - exclusive = 0; reverse = 0; + const OpArgsHolder argsHolderFF({&inputC, &axis}, {}, {exclusive, reverse}); + const OpArgsHolder argsHolderBP({&inputC, &axis, &gradO}, {}, + {exclusive, reverse}); - const OpArgsHolder argsHolderFF({&inputC, &axis}, {}, {exclusive, reverse}); - const OpArgsHolder argsHolderBP({&inputC, &axis, &gradO}, {}, {exclusive, reverse}); + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; - sd::ops::cumprod opFF; - sd::ops::cumprod_bp opBP; + const bool isGradCorrect = GradCheck::checkGrad( + opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1}, GradCheck::MEAN); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_test2) { + auto inputC = NDArrayFactory::create('c', {2, 2}); + auto axis = NDArrayFactory::create(1.); - auto inputC = NDArrayFactory::create('c', {2, 2}); - auto axis = NDArrayFactory::create(1.); + auto gradO = NDArrayFactory::create('c', {2, 2}); - auto gradO = NDArrayFactory::create('c', {2, 2}); + int exclusive, reverse; - int exclusive, reverse; + //************************************// + exclusive = 0; + reverse = 0; + inputC.linspace(1); + const OpArgsHolder argsHolderFF({&inputC, &axis}, {}, {exclusive, reverse}); + const OpArgsHolder argsHolderBP({&inputC, &axis, &gradO}, {}, + {exclusive, reverse}); - //************************************// - exclusive = 0; reverse = 0; - inputC.linspace(1); - const OpArgsHolder argsHolderFF({&inputC, &axis}, {}, {exclusive, reverse}); - const OpArgsHolder argsHolderBP({&inputC, &axis, &gradO}, {}, {exclusive, reverse}); + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; - sd::ops::cumprod opFF; - sd::ops::cumprod_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1, 1}, + {1, 1}, GradCheck::MEAN); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1, 1}, {1, 1},GradCheck::MEAN); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = + NDArrayFactory::create('c', {3, 4}, + {-0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, + 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {7.2f, 5.5f, 4.f, 2.7f, 1.6f, 0.7f, 0.f, -0.5f, + -0.8f, -0.9f, -0.8f, -0.5f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create('c', {3, 4}, {-0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 5.5f, 4.f, 2.7f, 1.6f, 0.7f, 0.f, -0.5f,-0.8f, -0.9f, -0.8f, -0.5f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - - auto result = op.evaluate({&x, &alpha}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test2) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create('c', {3}, {-0.6f, 2.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, + -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create('c', {3}, {-0.6f, 2.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test3) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create('c', {3, 1}, {-0.6f, 2.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, + -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create('c', {3,1}, {-0.6f, 2.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test4) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create('c', {1, 3}, {-0.6f, 2.f, 4.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, + -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create('c', {1, 3}, {-0.6f, 2.f, 4.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test5) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = + NDArrayFactory::create('c', {4}, {-0.6f, 2.f, 4.f, -1.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, {7.2f, -22.f, -40.f, 9.f, 4.8f, -14.f, -24.f, 5.f, + 2.4f, -6.f, -8.f, 1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create('c', {4}, {-0.6f, 2.f, 4.f, -1.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, -22.f, -40.f, 9.f, 4.8f, -14.f, -24.f, 5.f, 2.4f, -6.f, -8.f, 1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test6) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create('c', {1, 1, 1}, {-2.}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, + 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create('c', {1,1,1}, {-2.}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {1,0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test7) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create(-2.f); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, + 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create(-2.f); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {1,0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test8) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create(-2.f); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, + 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto alpha = NDArrayFactory::create(-2.f); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {1,0,1,0,1,0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1, 0, 1, 0, 1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test9) { + auto x = NDArrayFactory::create( + 'c', {2, 4}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f}); + auto alpha = NDArrayFactory::create(-2.f); + auto exp = NDArrayFactory::create( + 'c', {2, 4}, {8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f}); - auto x = NDArrayFactory::create('c', {2, 4}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f}); - auto alpha = NDArrayFactory::create(-2.f); - auto exp = NDArrayFactory::create('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test10) { + auto x = NDArrayFactory::create( + 'c', {2, 4}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f}); + auto alpha = NDArrayFactory::create(-2.f); + auto exp = NDArrayFactory::create( + 'c', {2, 4}, {8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f}); - auto x = NDArrayFactory::create('c', {2, 4}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f}); - auto alpha = NDArrayFactory::create(-2.f); - auto exp = NDArrayFactory::create('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test11) { - - auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); - x.linspace(-50.); - auto alpha = NDArrayFactory::create('c', {4}, {0.f, -0.5f, 0.5f, -1.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {0.f, 0.f, 0.f, 0.f, 0.f, 22.5f, 22.f, 21.5f, 21.f, 20.5f, -20.f, -19.5f, -19.f, -18.5f, -18.f, 35.f, 34.f, 33.f, - 32.f, 31.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.5f, 12.f, 11.5f, 11.f, 10.5f, -10.f, -9.5f, -9.f, -8.5f, -8.f, 15.f, - 14.f, 13.f, 12.f, 11.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.f, 1.5f, 1.f, 0.5f, 0.f, 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, - 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, - 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, - 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {1,3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + x.linspace(-50.); + auto alpha = + NDArrayFactory::create('c', {4}, {0.f, -0.5f, 0.5f, -1.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4, 5}, + {0.f, 0.f, 0.f, 0.f, 0.f, 22.5f, 22.f, 21.5f, 21.f, 20.5f, + -20.f, -19.5f, -19.f, -18.5f, -18.f, 35.f, 34.f, 33.f, 32.f, 31.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 12.5f, 12.f, 11.5f, 11.f, 10.5f, + -10.f, -9.5f, -9.f, -8.5f, -8.f, 15.f, 14.f, 13.f, 12.f, 11.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.f, 1.5f, 1.f, 0.5f, + 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, + 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, + 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, + 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, + 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, + 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test12) { - - auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); - x.linspace(-50.); - auto alpha = NDArrayFactory::create('c', {3,5}, {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {35.f, 29.4f, 24.f, 18.8f, 13.8f, 31.5f, 26.4f, 21.5f, 16.8f, 12.3f, 28.f, 23.4f, 19.f, 14.8f, 10.8f, 24.5f, 20.4f, 16.5f, 12.8f, - 9.3f, 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f, - -2.2f, -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, - 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, - 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {-1, 2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + x.linspace(-50.); + auto alpha = NDArrayFactory::create( + 'c', {3, 5}, + {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f, 0.7f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4, 5}, + {35.f, 29.4f, 24.f, 18.8f, 13.8f, 31.5f, 26.4f, 21.5f, 16.8f, 12.3f, + 28.f, 23.4f, 19.f, 14.8f, 10.8f, 24.5f, 20.4f, 16.5f, 12.8f, 9.3f, + 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, + 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f, -2.2f, + -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, + 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, + 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, + 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, + 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, + 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, + 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {-1, 2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test13) { - - auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); - x.linspace(-50.); - auto alpha = NDArrayFactory::create('c', {5,3}, {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {35.f, 29.4f, 24.f, 18.8f, 13.8f, 31.5f, 26.4f, 21.5f, 16.8f, 12.3f, 28.f, 23.4f, 19.f, 14.8f, 10.8f, 24.5f, 20.4f, 16.5f, 12.8f, - 9.3f, 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f, - -2.2f, -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, - 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, - 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {-1, 2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + x.linspace(-50.); + auto alpha = NDArrayFactory::create( + 'c', {5, 3}, + {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f, 0.7f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4, 5}, + {35.f, 29.4f, 24.f, 18.8f, 13.8f, 31.5f, 26.4f, 21.5f, 16.8f, 12.3f, + 28.f, 23.4f, 19.f, 14.8f, 10.8f, 24.5f, 20.4f, 16.5f, 12.8f, 9.3f, + 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, + 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f, -2.2f, + -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, + 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, + 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, + 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, + 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, + 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, + 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {-1, 2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test14) { - - auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); - x.linspace(-50.); - auto alpha = NDArrayFactory::create('c', {2,10}, {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {35.f, 29.4f, 24.f, 18.8f, 13.8f, 9.f, 4.4f, 0.f, -4.2f, -8.2f, -12.f, -15.6f, -19.f, -22.2f, -25.2f, -28.f, -30.6f, - -33.f,-35.2f, -37.2f, 21.f, 17.4f, 14.f, 10.8f, 7.8f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, -6.f, -7.6f, -9.f, -10.2f, - -11.2f, -12.f, -12.6f, -13.f, -13.2f, -13.2f, 7.f, 5.4f, 4.f, 2.8f, 1.8f, 1.f, 0.4f, 0.f, -0.2f, -0.2f, 0.f, - 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, - 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, - 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); - - sd::ops::prelu op; - auto result = op.evaluate({&x, &alpha}, {}, {-2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + x.linspace(-50.); + auto alpha = NDArrayFactory::create( + 'c', {2, 10}, + {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, + 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4, 5}, + {35.f, 29.4f, 24.f, 18.8f, 13.8f, 9.f, 4.4f, 0.f, -4.2f, + -8.2f, -12.f, -15.6f, -19.f, -22.2f, -25.2f, -28.f, -30.6f, -33.f, + -35.2f, -37.2f, 21.f, 17.4f, 14.f, 10.8f, 7.8f, 5.f, 2.4f, + 0.f, -2.2f, -4.2f, -6.f, -7.6f, -9.f, -10.2f, -11.2f, -12.f, + -12.6f, -13.f, -13.2f, -13.2f, 7.f, 5.4f, 4.f, 2.8f, 1.8f, + 1.f, 0.4f, 0.f, -0.2f, -0.2f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, + 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, + 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, + 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, + 67.f, 68.f, 69.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {-2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) { + const float theta = 2.f; + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - const float theta = 2.f; - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 3.f,4.f, 5.f, 6.f, 7.f,8.f, 9.f,10.f,11.f}); - - sd::ops::thresholdedrelu op; - - auto result = op.evaluate({&x}, {theta}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::thresholdedrelu op; + auto result = op.evaluate({&x}, {theta}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, + -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, + 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto threshold = NDArrayFactory::create(2.0); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); - auto threshold = NDArrayFactory::create(2.0); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - - sd::ops::compare_and_bitpack op; - - auto result = op.evaluate({&x, &threshold}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); -// output->printIndexedBuffer("Packed to uint8"); - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::compare_and_bitpack op; + auto result = op.evaluate({&x, &threshold}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printIndexedBuffer("Packed to uint8"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) { + const float theta = -2.f; + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.f, -4.f, -10.f, -8.f, 0.f, -9.f, -8.f, 5.f, 6.f, 6.f, 9.f, 6.f, + -8.f, 5.f, 10.f, -2.f, 3.f, -7.f, 4.f, -8.f, -4.f, -9.f, -9.f, 3.f}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 5.f, 6.f, 6.f, 9.f, 6.f, + 0.f, 5.f, 10.f, 0.f, 3.f, 0.f, 4.f, 0.f, 0.f, 0.f, 0.f, 3.f}); - const float theta = -2.f; - auto x = NDArrayFactory::create('c', {2, 3, 4}, {0.f,-4.f, -10.f, -8.f, 0.f, -9.f, -8.f, 5.f, 6.f, 6.f, 9.f, 6.f, -8.f, 5.f, 10.f, -2.f, 3.f, -7.f, 4.f, -8.f, -4.f, -9.f, -9.f, 3.f}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 5.f, 6.f, 6.f, 9.f, 6.f, 0.f, 5.f, 10.f, 0.f, 3.f, 0.f, 4.f, 0.f, 0.f, 0.f, 0.f, 3.f}); - - sd::ops::thresholdedrelu op; - - auto result = op.evaluate({&x}, {theta}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + sd::ops::thresholdedrelu op; + auto result = op.evaluate({&x}, {theta}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_bp_test1) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1., + 0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); + auto alpha = NDArrayFactory::create( + 'c', {3, 4}, + {-0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5}); + auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1., 0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); - auto alpha = NDArrayFactory::create('c', {3, 4}, {-0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5}); - auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); + const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {}); - const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {}); + sd::ops::prelu opFF; + sd::ops::prelu_bp opBP; - sd::ops::prelu opFF; - sd::ops::prelu_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_bp_test2) { + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1., + 0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); + auto alpha = NDArrayFactory::create('c', {4}, {-0.6, 2., 4., -1.}); + auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1., 0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); - auto alpha = NDArrayFactory::create('c', {4}, {-0.6, 2., 4., -1.}); - auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); + const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {1}); + const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {1}); - const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {1}); - const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {1}); + sd::ops::prelu opFF; + sd::ops::prelu_bp opBP; - sd::ops::prelu opFF; - sd::ops::prelu_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_bp_test3) { + auto x = NDArrayFactory::create('c', {2, 3, 2, 5}); + x.linspace(-30.); + x.p(30, 0.5); // avoid zero, since it is points of discontinuity for prelu + auto alpha = + NDArrayFactory::create('c', {5, 3}, + {-0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, + 0.5, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); + auto dLdO = NDArrayFactory::create('c', {2, 3, 2, 5}); - auto x = NDArrayFactory::create('c', {2, 3, 2, 5}); - x.linspace(-30.); - x.p(30, 0.5); // avoid zero, since it is points of discontinuity for prelu - auto alpha = NDArrayFactory::create('c', {5,3}, {-0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); - auto dLdO = NDArrayFactory::create('c', {2, 3, 2, 5}); + const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {-1, 2}); + const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {-1, 2}); - const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {-1, 2}); - const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {-1, 2}); + sd::ops::prelu opFF; + sd::ops::prelu_bp opBP; - sd::ops::prelu opFF; - sd::ops::prelu_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_bp_test4) { + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + x.linspace(-50.); + x.p(50, 0.5); // avoid zero, since it is points of discontinuity for prele + auto alpha = NDArrayFactory::create( + 'c', {2, 10}, {-0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.25, 0.1, 0.2, + 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}); + auto dLdO = NDArrayFactory::create('c', {2, 3, 4, 5}); - auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); - x.linspace(-50.); - x.p(50, 0.5); // avoid zero, since it is points of discontinuity for prele - auto alpha = NDArrayFactory::create('c', {2,10}, {-0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.25, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}); - auto dLdO = NDArrayFactory::create('c', {2, 3, 4, 5}); - - const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {-2}); - const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {-2}); + const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {-2}); + const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {-2}); - sd::ops::prelu opFF; - sd::ops::prelu_bp opBP; + sd::ops::prelu opFF; + sd::ops::prelu_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, thresholdedrelu_bp_test1) { + const double theta = 0.15; - const double theta = 0.15; - - auto x = NDArrayFactory::create('c', {2, 3, 4}, {1.2, 1.1, 1., 0.9, 0.8, -0.7, -0.6,-0.5,-0.4,-0.3,-0.2,-0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.9, -1.0, -1.1}); - auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); + auto x = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1.2, 1.1, 1., 0.9, 0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, + 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.9, -1.0, -1.1}); + auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); - const OpArgsHolder argsHolderFF({&x}, {theta}, {}); - const OpArgsHolder argsHolderBP({&x, &dLdO}, {theta}, {}); + const OpArgsHolder argsHolderFF({&x}, {theta}, {}); + const OpArgsHolder argsHolderBP({&x, &dLdO}, {theta}, {}); - sd::ops::thresholdedrelu opFF; - sd::ops::thresholdedrelu_bp opBP; + sd::ops::thresholdedrelu opFF; + sd::ops::thresholdedrelu_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_test1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.1f, 0.4f, 0.9f, 1.6f, 0.5f, 1.2f, 2.1f, 3.2f, 0.9f, 2.f, 3.3f, 4.8f, + 1.3f, 2.8f, 4.5f, 6.4f, 1.7f, 3.6f, 5.7f, 8.f, 2.1f, 4.4f, 6.9f, 9.6f}); + x.linspace(1.f); + y.linspace(0.1f, 0.1f); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto y = NDArrayFactory::create('c', {4}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.1f, 0.4f, 0.9f, 1.6f, 0.5f, 1.2f, 2.1f, 3.2f, 0.9f, 2.f, 3.3f, 4.8f, 1.3f, 2.8f, 4.5f, 6.4f, 1.7f, 3.6f, 5.7f, 8.f, 2.1f, 4.4f, 6.9f, 9.6f}); - x.linspace(1.f); - y.linspace(0.1f, 0.1f); - - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_test2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create(0.1); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, + 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f}); + x.linspace(1.f); + // y.linspace(0.1f, 0.1f); - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto y = NDArrayFactory::create(0.1); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f}); - x.linspace(1.f); - // y.linspace(0.1f, 0.1f); - - sd::ops::multiply op; - auto result = op.evaluate({&y, &x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + sd::ops::multiply op; + auto result = op.evaluate({&y, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_test3) { + auto x = NDArrayFactory::create('c', {2, 1, 4}); + auto y = NDArrayFactory::create('c', {3, 1}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {0.1f, 0.2f, 0.3f, 0.4f, 0.2f, 0.4f, 0.6f, 0.8f, 0.3f, 0.6f, 0.9f, 1.2f, + 0.5f, 0.6f, 0.7f, 0.8f, 1.f, 1.2f, 1.4f, 1.6f, 1.5f, 1.8f, 2.1f, 2.4f}); + x.linspace(1.f); + y.linspace(0.1f, 0.1f); - auto x = NDArrayFactory::create('c', {2, 1, 4}); - auto y = NDArrayFactory::create('c', {3,1}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.2f, 0.4f, 0.6f, 0.8f, 0.3f, 0.6f, 0.9f, 1.2f, 0.5f, 0.6f, 0.7f, 0.8f, 1.f, 1.2f, 1.4f, 1.6f, 1.5f, 1.8f, 2.1f, 2.4f}); - x.linspace(1.f); - y.linspace(0.1f, 0.1f); - - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_test4) { + auto x = NDArrayFactory::create('c', {1, 1}); + auto y = NDArrayFactory::create(0.1f); + auto exp = NDArrayFactory::create('c', {1, 1}, {0.1f}); + x.linspace(1.f); - auto x = NDArrayFactory::create('c', {1, 1}); - auto y = NDArrayFactory::create(0.1f); - auto exp = NDArrayFactory::create('c', {1, 1}, {0.1f}); - x.linspace(1.f); - - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_test5) { + auto x = NDArrayFactory::create(1.f); + auto y = NDArrayFactory::create(0.1f); + auto exp = NDArrayFactory::create(0.1f); - auto x = NDArrayFactory::create(1.f); - auto y = NDArrayFactory::create(0.1f); - auto exp = NDArrayFactory::create(0.1f); - - sd::ops::multiply op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test1) { + auto x = NDArrayFactory::create('c', {1, 1}, {100.}); + auto y = NDArrayFactory::create(0.1); + auto dLdz = NDArrayFactory::create('c', {1, 1}); - auto x = NDArrayFactory::create('c', {1, 1}, {100.}); - auto y = NDArrayFactory::create(0.1); - auto dLdz = NDArrayFactory::create('c', {1, 1}); - - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; - auto resFF = opFF.evaluate({&x, &y}, {}, {}); - auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {}); -// resFF->at(0)->printIndexedBuffer("Multiply 1x1"); -// resBP->at(0)->printIndexedBuffer("Multiply BP 1x1 x"); -// resBP->at(1)->printIndexedBuffer("Multyply BP 1x1 y");*/ - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; + auto resFF = opFF.evaluate({&x, &y}, {}, {}); + auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {}); + // resFF->at(0)->printIndexedBuffer("Multiply 1x1"); + // resBP->at(0)->printIndexedBuffer("Multiply BP 1x1 x"); + // resBP->at(1)->printIndexedBuffer("Multyply BP 1x1 y");*/ + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test2) { + auto x = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); + auto y = NDArrayFactory::create(0.1); + auto dLdz = NDArrayFactory::create('c', {2, 2}); - auto x = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); - auto y = NDArrayFactory::create(0.1); - auto dLdz = NDArrayFactory::create('c', {2, 2}); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test3) { + auto y = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); + auto x = NDArrayFactory::create(0.1); + auto dLdz = NDArrayFactory::create('c', {2, 2}); - auto y = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); - auto x = NDArrayFactory::create(0.1); - auto dLdz = NDArrayFactory::create('c', {2, 2}); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test4) { + auto x = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); + auto y = NDArrayFactory::create('c', {2, 2}, {0.1, 0.2, 0.3, 0.4}); + auto dLdz = NDArrayFactory::create('c', {2, 2}); - auto x = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); - auto y = NDArrayFactory::create('c', {2, 2}, {0.1,0.2,0.3,0.4}); - auto dLdz = NDArrayFactory::create('c', {2, 2}); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test5) { + auto x = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); + auto y = NDArrayFactory::create('c', {2}, {0.1, 0.2}); + auto dLdz = NDArrayFactory::create('c', {2, 2}); - auto x = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); - auto y = NDArrayFactory::create('c', {2}, {0.1,0.2}); - auto dLdz = NDArrayFactory::create('c', {2, 2}); - - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test6) { + auto y = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); + auto x = NDArrayFactory::create('c', {2}, {0.1, 0.2}); + auto dLdz = NDArrayFactory::create('c', {2, 2}); - auto y = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); - auto x = NDArrayFactory::create('c', {2}, {0.1,0.2}); - auto dLdz = NDArrayFactory::create('c', {2, 2}); - - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test7) { + auto y = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto x = NDArrayFactory::create('c', {2, 1}, {0.1, 0.2}); + auto dLdz = NDArrayFactory::create('c', {2, 3}); - auto y = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); - auto x = NDArrayFactory::create('c', {2, 1}, {0.1,0.2}); - auto dLdz = NDArrayFactory::create('c', {2, 3}); - - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test8) { + auto y = NDArrayFactory::create('c', {2, 1, 4}); + auto x = NDArrayFactory::create('c', {1, 3, 4}); + auto dLdz = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1., 0.5); + y.linspace(0.1, 0.05); - auto y = NDArrayFactory::create('c', {2, 1, 4}); - auto x = NDArrayFactory::create('c', {1, 3, 4}); - auto dLdz = NDArrayFactory::create('c', {2, 3, 4}); - x.linspace(1., 0.5); - y.linspace(0.1, 0.05); - - const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - sd::ops::multiply opFF; - sd::ops::multiply_bp opBP; + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, Floormod_BP_Test_2) { + auto y = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create('c', {10, 10}); + auto dLdz = NDArrayFactory::create('c', {10, 10}); + // auto eps = NDArrayFactory::create('c', {10, 10}); + x.linspace(4); // 2., 2.0); + y.linspace(3); + dLdz.linspace(1); + // const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + // const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - auto y = NDArrayFactory::create('c', {10, 10}); - auto x = NDArrayFactory::create('c', {10, 10}); - auto dLdz = NDArrayFactory::create('c', {10, 10}); - //auto eps = NDArrayFactory::create('c', {10, 10}); - x.linspace(4); //2., 2.0); - y.linspace(3); - dLdz.linspace(1); -// const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); -// const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); - -// sd::ops::floormod opFF; -// auto resFF = opFF.execute({&x, &y}, {}, {}); -// resFF->at(0)->printIndexedBuffer("FF floormod"); -// delete resFF; - sd::ops::floormod_bp opBP; - auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {}); - ASSERT_TRUE(resBP.status() == ND4J_STATUS_OK); - -// resBP->at(0)->printIndexedBuffer("BP floormod /dx"); -// resBP->at(1)->printIndexedBuffer("BP floormod /dy"); - ASSERT_TRUE(dLdz.equalsTo(resBP.at(0))); - ASSERT_TRUE(dLdz.equalsTo(resBP.at(1))); - -// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + // sd::ops::floormod opFF; + // auto resFF = opFF.execute({&x, &y}, {}, {}); + // resFF->at(0)->printIndexedBuffer("FF floormod"); + // delete resFF; + sd::ops::floormod_bp opBP; + auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {}); + ASSERT_TRUE(resBP.status() == ND4J_STATUS_OK); -// ASSERT_TRUE(isGradCorrect); + // resBP->at(0)->printIndexedBuffer("BP floormod /dx"); + // resBP->at(1)->printIndexedBuffer("BP floormod /dy"); + ASSERT_TRUE(dLdz.equalsTo(resBP.at(0))); + ASSERT_TRUE(dLdz.equalsTo(resBP.at(1))); + + // const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, + // argsHolderFF, argsHolderBP); + + // ASSERT_TRUE(isGradCorrect); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_1) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto y = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 1, 0, 2}); - auto dLdzX = NDArrayFactory::create('c', {2, 4}); - auto dLdzY = NDArrayFactory::create('c', {2, 4}); - auto dLdzZ = NDArrayFactory::create('c', {2, 4}); - auto exp = NDArrayFactory::create('c', {2,3,4}, {1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3}); - x.linspace(1); -// dLdzX.linspace(1); -// dLdzY.linspace(2); -// dLdzZ.linspace(3); - dLdzX.assign(1); - dLdzY.assign(2); - dLdzZ.assign(3); - - sd::ops::dynamic_partition op1; - auto res1 = op1.evaluate({&x, &y}, {}, {3}); - - sd::ops::dynamic_partition_bp op2; - auto res2 = op2.evaluate({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3}); - ASSERT_TRUE(res2.status() == ND4J_STATUS_OK); - ASSERT_TRUE(res2.size() == 2); -// printf("How many: %ul\n", res2->size()); -// res2->at(0)->printBuffer("Ouputput0"); -// res2->at(1)->printBuffer("Ouputput1"); - ASSERT_TRUE(res2.at(0).equalsTo(exp)); - + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 1, 0, 2}); + auto dLdzX = NDArrayFactory::create('c', {2, 4}); + auto dLdzY = NDArrayFactory::create('c', {2, 4}); + auto dLdzZ = NDArrayFactory::create('c', {2, 4}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 4}, + {1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3}); + x.linspace(1); + // dLdzX.linspace(1); + // dLdzY.linspace(2); + // dLdzZ.linspace(3); + dLdzX.assign(1); + dLdzY.assign(2); + dLdzZ.assign(3); + + sd::ops::dynamic_partition op1; + auto res1 = op1.evaluate({&x, &y}, {}, {3}); + + sd::ops::dynamic_partition_bp op2; + auto res2 = op2.evaluate({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3}); + ASSERT_TRUE(res2.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res2.size() == 2); + // printf("How many: %ul\n", res2->size()); + // res2->at(0)->printBuffer("Ouputput0"); + // res2->at(1)->printBuffer("Ouputput1"); + ASSERT_TRUE(res2.at(0).equalsTo(exp)); } ////////////////////////////////////////////////////////////////////// -//TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_2) { +// TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_2) { // // auto x = NDArrayFactory::create('c', {2, 3, 4}); // auto y = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 1, 0, 2}); @@ -2330,41 +2423,41 @@ TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_1) { // dLdzZ.linspace(1); // // const OpArgsHolder argsHolderFF({&x, &y}, {}, {3}); -// const OpArgsHolder argsHolderBP({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3}); +// const OpArgsHolder argsHolderBP({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, +// {3}); // // sd::ops::dynamic_partition opFF; // sd::ops::dynamic_partition_bp opBP; // -// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); +// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, +// argsHolderBP); // // ASSERT_TRUE(isGradCorrect); //} //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) { + auto x = NDArrayFactory::create('c', {2, 1, 3}, + {2.0, 6.0, -3.0, 2.0, 6.0, -3.0}); + auto y = NDArrayFactory::create('c', {1, 3}, {-3.0, 2.0, -2.0}); + auto exp = NDArrayFactory::create('c', {1, 3}, {-1., 0., -1.}); + auto eps = NDArrayFactory::create('c', {2, 1, 3}); + eps.assign(1.f); + sd::ops::floormod_bp op; - auto x = NDArrayFactory::create('c', {2, 1, 3}, {2.0, 6.0, -3.0, 2.0, 6.0, -3.0}); - auto y = NDArrayFactory::create('c', {1, 3}, {-3.0, 2.0, -2.0}); - auto exp = NDArrayFactory::create('c', {1, 3}, {-1., 0., -1.}); - auto eps = NDArrayFactory::create('c', {2, 1, 3}); - eps.assign(1.f); - sd::ops::floormod_bp op; + auto result = op.evaluate({&x, &y, &eps}); - auto result = op.evaluate({&x, &y, &eps}); + ASSERT_TRUE(result.size() == 2); + auto gradX = result.at(0); + auto gradY = result.at(1); - ASSERT_TRUE(result.size() == 2); - auto gradX = result.at(0); - auto gradY = result.at(1); - -// gradX->printIndexedBuffer("gradX"); -// gradY->printIndexedBuffer("gradY"); - ASSERT_TRUE(exp.isSameShape(gradY)); - - ASSERT_TRUE(exp.equalsTo(gradY)); + // gradX->printIndexedBuffer("gradX"); + // gradY->printIndexedBuffer("gradY"); + ASSERT_TRUE(exp.isSameShape(gradY)); + ASSERT_TRUE(exp.equalsTo(gradY)); } - /* //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { @@ -2415,12 +2508,15 @@ TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { - const OpArgsHolder argsHolderBP({&x, &hi, &W, &Wc, &b, &bc, &dLdr, &dLdu, &dLdc, &dLdh}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &hi, &W, &Wc, &b, &bc, &dLdr, &dLdu, +&dLdc, &dLdh}, {}, {}); sd::ops::gruCell opFF; sd::ops::gruCell_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1, 1 , 1, 1}, {0., 1.}, sd::GradCheck::LossFunc::SUM, true); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, +argsHolderBP, {1, 1, 1, 1 , 1, 1}, {0., 1.}, sd::GradCheck::LossFunc::SUM, +true); ASSERT_TRUE(isGradCorrect); } @@ -2428,50 +2524,56 @@ TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, Cholesky_Test_1) { + NDArray x = NDArrayFactory::create( + 'c', {3, 3}, {4, 12, -16, 12, 37, -43, -16, -43, 98}); + NDArray exp = NDArrayFactory::create( + 'c', {3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3.}); - NDArray x = NDArrayFactory::create('c', {3, 3}, {4,12,-16, 12 ,37,-43, -16, -43, 98}); - NDArray exp = NDArrayFactory::create('c', {3,3}, {2., 0., 0., 6., 1., 0., -8., 5., 3.}); - - sd::ops::cholesky op; - - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto res = result.at(0); -// res->printIndexedBuffer("Output for Cholesky1"); - ASSERT_TRUE(exp.equalsTo(res)); + sd::ops::cholesky op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + // res->printIndexedBuffer("Output for Cholesky1"); + ASSERT_TRUE(exp.equalsTo(res)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, Cholesky_Test_2) { + NDArray x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {4, 12, -16, 12, 37, -43, -16, -43, 98, 1, 1, 1, 1, 2, 2, 1, 2., 6}); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 3}, + {2., 0., 0., 6., 1., 0., -8., 5., 3., 1., 0., 0., 1., 1., 0, 1., 1., 2.}); - NDArray x = NDArrayFactory::create('c', {2, 3, 3}, {4, 12,-16, 12 ,37,-43, -16, -43, 98, 1, 1, 1, 1, 2, 2, 1, 2., 6}); - NDArray exp = NDArrayFactory::create('c', {2, 3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3., 1., 0., 0., 1., 1., 0,1., 1., 2.}); - - sd::ops::cholesky op; - - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto res = result.at(0); -// res->printIndexedBuffer("Output for Cholesky 2"); - ASSERT_TRUE(exp.equalsTo(res)); + sd::ops::cholesky op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + // res->printIndexedBuffer("Output for Cholesky 2"); + ASSERT_TRUE(exp.equalsTo(res)); } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, Cholesky_Test_3) { + NDArray x = NDArrayFactory::create( + 'c', {2, 3, 3}, + {4.f, 12.f, -16.f, 12.f, 37.f, -43.f, -16.f, -43.f, 98.f, 1.f, 1.f, 1.f, + 1.f, 2.f, 2.f, 1.f, 2.f, 6.f}); + NDArray exp = NDArrayFactory::create( + 'c', {2, 3, 3}, + {2.f, 0.f, 0.f, 6.f, 1.f, 0.f, -8.f, 5.f, 3.f, 1.f, 0.f, 0.f, 1.f, 1.f, + 0.f, 1.f, 1.f, 2.f}); - NDArray x = NDArrayFactory::create('c', {2, 3, 3}, {4.f, 12.f, -16.f, 12.f, 37.f, -43.f, -16.f, -43.f, 98.f, 1.f, 1.f, 1.f, 1.f, 2.f, 2.f, 1.f, 2.f, 6.f}); - NDArray exp = NDArrayFactory::create('c', {2, 3, 3}, {2.f, 0.f, 0.f, 6.f, 1.f, 0.f, -8.f, 5.f, 3.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 1.f, 1.f, 2.f}); - - sd::ops::cholesky op; - - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(result.status(), ND4J_STATUS_OK); - auto res = result.at(0); - // res->printIndexedBuffer("Output for Cholesky 3"); - ASSERT_TRUE(exp.equalsTo(res, 1e-4)); + sd::ops::cholesky op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + // res->printIndexedBuffer("Output for Cholesky 3"); + ASSERT_TRUE(exp.equalsTo(res, 1e-4)); } //////////////////////////////////////////////////////////////////// @@ -2496,12 +2598,14 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_3) { // b = 0.5; // const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {}); -// const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {}); +// const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, +// {}, {}); // sd::ops::gru opFF; // sd::ops::gru_bp opBP; -// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); +// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, +// argsHolderBP); // ASSERT_TRUE(isGradCorrect); // } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu index 4f69da61b233..f0f3fdc05c42 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu @@ -14,49 +14,57 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include +#include +#include + #include +#include "testlayers.h" using namespace sd; - class DeclarableOpsTestsCuda1 : public testing::Test { -public: - - DeclarableOpsTestsCuda1() { - printf("\n"); - fflush(stdout); - } + public: + DeclarableOpsTestsCuda1() { + printf("\n"); + fflush(stdout); + } }; - TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) { - double inputData[150] = { - 0, 0.51, 0.68, 0.69, 0.86, 0.91, 0.96, 0.97, 0.97, 1.03, 1.13, 1.16, 1.16, 1.17, 1.19, 1.25, 1.25, 1.26, 1.27, 1.28, 1.29, 1.29, 1.29, 1.30, 1.31, 1.32, 1.33, 1.33, 1.35, 1.35, 1.36, 1.37, 1.38, 1.40, 1.41, 1.42, 1.43, 1.44, 1.44, 1.45, 1.45, 1.47, 1.47, 1.51, 1.51, 1.51, 1.52, 1.53, 1.56, 1.57, 1.58, 1.59, 1.61, 1.62, 1.63, 1.63, 1.64, 1.64, 1.66, 1.66, 1.67, 1.67, 1.70, 1.70, 1.70, 1.72, 1.72, 1.72, 1.72, 1.73, 1.74, 1.74, 1.76, 1.76, 1.77, 1.77, 1.80, 1.80, 1.81, 1.82, 1.83, 1.83, 1.84, 1.84, 1.84, 1.85, 1.85, 1.85, 1.86, 1.86, 1.87, 1.88, 1.89, 1.89, 1.89, 1.89, 1.89, 1.91, 1.91, 1.91, 1.92, 1.94, 1.95, 1.97, 1.98, 1.98, 1.98, 1.98, 1.98, 1.99, 2, 2, 2.01, 2.01, 2.02, 2.03, 2.03, 2.03, 2.04, 2.04, 2.05, 2.06, 2.07, 2.08, 2.08, 2.08, 2.08, 2.09, 2.09, 2.10, 2.10, 2.11, 2.11, 2.11, 2.12, 2.12, 2.13, 2.13, 2.14, 2.14, 2.14, 2.14, 2.15, 2.15, 2.16, 2.16, 2.16, 2.16, 2.16, 2.17 - }; - - auto precursor = NDArrayFactory::create(inputData,'c',{1,149}); - NDArray x(nullptr, precursor.specialBuffer(), precursor.shapeInfo()); - - sd::ops::choose op; - //greater than test - auto result = op.evaluate({&x}, {0.0},{3}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(1); - - ASSERT_EQ(148,z->e(0)); - //ASSERT_TRUE(exp.isSameShape(z)); + double inputData[150] = { + 0, 0.51, 0.68, 0.69, 0.86, 0.91, 0.96, 0.97, 0.97, 1.03, 1.13, 1.16, + 1.16, 1.17, 1.19, 1.25, 1.25, 1.26, 1.27, 1.28, 1.29, 1.29, 1.29, 1.30, + 1.31, 1.32, 1.33, 1.33, 1.35, 1.35, 1.36, 1.37, 1.38, 1.40, 1.41, 1.42, + 1.43, 1.44, 1.44, 1.45, 1.45, 1.47, 1.47, 1.51, 1.51, 1.51, 1.52, 1.53, + 1.56, 1.57, 1.58, 1.59, 1.61, 1.62, 1.63, 1.63, 1.64, 1.64, 1.66, 1.66, + 1.67, 1.67, 1.70, 1.70, 1.70, 1.72, 1.72, 1.72, 1.72, 1.73, 1.74, 1.74, + 1.76, 1.76, 1.77, 1.77, 1.80, 1.80, 1.81, 1.82, 1.83, 1.83, 1.84, 1.84, + 1.84, 1.85, 1.85, 1.85, 1.86, 1.86, 1.87, 1.88, 1.89, 1.89, 1.89, 1.89, + 1.89, 1.91, 1.91, 1.91, 1.92, 1.94, 1.95, 1.97, 1.98, 1.98, 1.98, 1.98, + 1.98, 1.99, 2, 2, 2.01, 2.01, 2.02, 2.03, 2.03, 2.03, 2.04, 2.04, + 2.05, 2.06, 2.07, 2.08, 2.08, 2.08, 2.08, 2.09, 2.09, 2.10, 2.10, 2.11, + 2.11, 2.11, 2.12, 2.12, 2.13, 2.13, 2.14, 2.14, 2.14, 2.14, 2.15, 2.15, + 2.16, 2.16, 2.16, 2.16, 2.16, 2.17}; + + auto precursor = NDArrayFactory::create(inputData, 'c', {1, 149}); + NDArray x(nullptr, precursor.specialBuffer(), precursor.shapeInfo()); + + sd::ops::choose op; + // greater than test + auto result = op.evaluate({&x}, {0.0}, {3}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(1); + + ASSERT_EQ(148, z->e(0)); + // ASSERT_TRUE(exp.isSameShape(z)); } /* @@ -69,8 +77,8 @@ TEST_F(DeclarableOpsTestsCuda1, Test_Reverse_TAD_1) { auto timeStart = std::chrono::system_clock::now(); auto status = op.execute({&x}, {&z}, {}, {1}, {}); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - nd4j_printf("exec time: %lld us\n", outerTime); + auto outerTime = std::chrono::duration_cast +(timeEnd - timeStart).count(); nd4j_printf("exec time: %lld us\n", outerTime); ASSERT_EQ(Status::OK(), status); } */ \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp index 0b73bbd96eba..ec6b15698575 100644 --- a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp @@ -18,224 +18,224 @@ // Created by raver on 6/18/2018. // -#include "testlayers.h" -#include #include +#include + +#include "testlayers.h" // #include using namespace sd; - class EmptyTests : public testing::Test { -public: - - EmptyTests() { - printf("\n"); - fflush(stdout); - } + public: + EmptyTests() { + printf("\n"); + fflush(stdout); + } }; TEST_F(EmptyTests, Test_Create_Empty_1) { - auto empty = NDArrayFactory::empty_(); - ASSERT_TRUE(empty->isEmpty()); + auto empty = NDArrayFactory::empty_(); + ASSERT_TRUE(empty->isEmpty()); - ASSERT_EQ(0, empty->lengthOf()); - ASSERT_TRUE(empty->buffer() == nullptr); + ASSERT_EQ(0, empty->lengthOf()); + ASSERT_TRUE(empty->buffer() == nullptr); - ASSERT_TRUE(shape::isEmpty(empty->shapeInfo())); + ASSERT_TRUE(shape::isEmpty(empty->shapeInfo())); - delete empty; + delete empty; } TEST_F(EmptyTests, Test_Create_Empty_2) { - auto empty = NDArrayFactory::empty(); - ASSERT_TRUE(empty.isEmpty()); + auto empty = NDArrayFactory::empty(); + ASSERT_TRUE(empty.isEmpty()); - ASSERT_EQ(0, empty.lengthOf()); - ASSERT_TRUE(empty.buffer() == nullptr); + ASSERT_EQ(0, empty.lengthOf()); + ASSERT_TRUE(empty.buffer() == nullptr); - ASSERT_TRUE(shape::isEmpty(empty.shapeInfo())); - ASSERT_TRUE(empty.isEmpty()); + ASSERT_TRUE(shape::isEmpty(empty.shapeInfo())); + ASSERT_TRUE(empty.isEmpty()); } TEST_F(EmptyTests, Test_Concat_1) { -// auto empty = NDArrayFactory::empty_(); - auto empty = NDArray('c', {0}, sd::DataType::FLOAT32);//NDArrayFactory::create_('c', {(Nd4jLong)0}}; - auto vector = NDArrayFactory::create('c', {1}, {1.0f}); + // auto empty = NDArrayFactory::empty_(); + auto empty = NDArray('c', {0}, sd::DataType::FLOAT32); // NDArrayFactory::create_('c', + // {(Nd4jLong)0}}; + auto vector = NDArrayFactory::create('c', {1}, {1.0f}); - ASSERT_TRUE(empty.isEmpty()); + ASSERT_TRUE(empty.isEmpty()); - sd::ops::concat op; - auto result = op.evaluate({&empty, &vector}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::concat op; + auto result = op.evaluate({&empty, &vector}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); -// z->printShapeInfo("z shape"); -// z->printIndexedBuffer("z buffr"); + // z->printShapeInfo("z shape"); + // z->printIndexedBuffer("z buffr"); - ASSERT_EQ(vector, z); + ASSERT_EQ(vector, z); } - TEST_F(EmptyTests, Test_Concat_2) { - auto empty = NDArray('c', {0}, sd::DataType::FLOAT32); //NDArrayFactory::empty_(); - auto scalar1 = NDArrayFactory::create('c', {1}, {1.0f}); - auto scalar2 = NDArrayFactory::create('c', {1}, {2.0f}); - auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + auto empty = NDArray( + 'c', {0}, sd::DataType::FLOAT32); // NDArrayFactory::empty_(); + auto scalar1 = NDArrayFactory::create('c', {1}, {1.0f}); + auto scalar2 = NDArrayFactory::create('c', {1}, {2.0f}); + auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); - ASSERT_TRUE(empty.isEmpty()); + ASSERT_TRUE(empty.isEmpty()); - sd::ops::concat op; - auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::concat op; + auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(EmptyTests, Test_Concat_3) { - auto empty = NDArrayFactory::empty(); //NDArrayFactory::empty_(); - auto scalar1 = NDArrayFactory::create(1.0f); - auto scalar2 = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + auto empty = + NDArrayFactory::empty(); // NDArrayFactory::empty_(); + auto scalar1 = NDArrayFactory::create(1.0f); + auto scalar2 = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); - ASSERT_TRUE(empty.isEmpty()); + ASSERT_TRUE(empty.isEmpty()); - sd::ops::concat op; - auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::concat op; + auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(EmptyTests, Test_Concat_4) { - auto empty = NDArrayFactory::empty(); //NDArrayFactory::empty_(); - auto scalar1 = NDArrayFactory::create(1.0f); - auto scalar2 = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + auto empty = + NDArrayFactory::empty(); // NDArrayFactory::empty_(); + auto scalar1 = NDArrayFactory::create(1.0f); + auto scalar2 = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); - ASSERT_TRUE(empty.isEmpty()); + ASSERT_TRUE(empty.isEmpty()); - sd::ops::concat op; - auto result = op.evaluate({&scalar1, &empty, &scalar2}, {}, {0}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::concat op; + auto result = op.evaluate({&scalar1, &empty, &scalar2}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(EmptyTests, Test_dup_1) { - auto empty = NDArrayFactory::empty(); - auto dup = empty.dup(); + auto empty = NDArrayFactory::empty(); + auto dup = empty.dup(); - ASSERT_TRUE(dup.isEmpty()); - ASSERT_EQ(empty, dup); + ASSERT_TRUE(dup.isEmpty()); + ASSERT_EQ(empty, dup); } TEST_F(EmptyTests, test_empty_scatter_1) { - auto x = NDArrayFactory::create('c', {5}); - auto indices = NDArrayFactory::create('c', {0}); - auto updates = NDArrayFactory::create('c', {0}); - - x.linspace(1.0f); + auto x = NDArrayFactory::create('c', {5}); + auto indices = NDArrayFactory::create('c', {0}); + auto updates = NDArrayFactory::create('c', {0}); - sd::ops::scatter_upd op; - auto result = op.evaluate({&x, &indices, &updates}, {}, {}, {true}); - ASSERT_EQ(Status::OK(), result.status()); + x.linspace(1.0f); - auto z = result.at(0); - ASSERT_EQ(x, z); + sd::ops::scatter_upd op; + auto result = op.evaluate({&x, &indices, &updates}, {}, {}, {true}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(x, z); } TEST_F(EmptyTests, test_empty_scatter_2) { - NDArray x ('c', {5}, sd::DataType::FLOAT32); - NDArray z ('c', {5}, sd::DataType::FLOAT32); - auto indices = NDArrayFactory::create('c', {0}); - auto updates = NDArrayFactory::create('c', {0}); + NDArray x('c', {5}, sd::DataType::FLOAT32); + NDArray z('c', {5}, sd::DataType::FLOAT32); + auto indices = NDArrayFactory::create('c', {0}); + auto updates = NDArrayFactory::create('c', {0}); - x.linspace(1.0f); + x.linspace(1.0f); - sd::ops::scatter_upd op; - auto status = op.execute({&x, &indices, &updates}, {&z}, {}, {}, {true}); + sd::ops::scatter_upd op; + auto status = op.execute({&x, &indices, &updates}, {&z}, {}, {}, {true}); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(x, z); + ASSERT_EQ(x, z); } TEST_F(EmptyTests, test_shaped_empty_1) { - auto empty = NDArrayFactory::create('c', {2, 0, 3}); - std::vector shape = {2, 0, 3}; - - ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType()); - ASSERT_EQ(0, empty.lengthOf()); - ASSERT_TRUE(empty.isEmpty()); - ASSERT_EQ(shape, empty.getShapeAsVector()); - ASSERT_EQ(3, empty.rankOf()); + auto empty = NDArrayFactory::create('c', {2, 0, 3}); + std::vector shape = {2, 0, 3}; + + ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType()); + ASSERT_EQ(0, empty.lengthOf()); + ASSERT_TRUE(empty.isEmpty()); + ASSERT_EQ(shape, empty.getShapeAsVector()); + ASSERT_EQ(3, empty.rankOf()); } TEST_F(EmptyTests, test_shaped_empty_2) { - auto empty = NDArrayFactory::create('c', {0, 3}); - std::vector shape = {0, 3}; - - ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType()); - ASSERT_EQ(0, empty.lengthOf()); - ASSERT_TRUE(empty.isEmpty()); - ASSERT_EQ(shape, empty.getShapeAsVector()); - ASSERT_EQ(2, empty.rankOf()); + auto empty = NDArrayFactory::create('c', {0, 3}); + std::vector shape = {0, 3}; + + ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType()); + ASSERT_EQ(0, empty.lengthOf()); + ASSERT_TRUE(empty.isEmpty()); + ASSERT_EQ(shape, empty.getShapeAsVector()); + ASSERT_EQ(2, empty.rankOf()); } TEST_F(EmptyTests, test_shaped_empty_3) { - auto empty = NDArrayFactory::create('c', {0}); - std::vector shape = {0}; - - ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType()); - ASSERT_EQ(0, empty.lengthOf()); - ASSERT_TRUE(empty.isEmpty()); - ASSERT_EQ(shape, empty.getShapeAsVector()); - ASSERT_EQ(1, empty.rankOf()); + auto empty = NDArrayFactory::create('c', {0}); + std::vector shape = {0}; + + ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType()); + ASSERT_EQ(0, empty.lengthOf()); + ASSERT_TRUE(empty.isEmpty()); + ASSERT_EQ(shape, empty.getShapeAsVector()); + ASSERT_EQ(1, empty.rankOf()); } TEST_F(EmptyTests, test_shaped_empty_4) { - const auto shape = ConstantShapeHelper::getInstance()->vectorShapeInfo(0, sd::DataType::FLOAT32); - NDArray array(shape, true, sd::LaunchContext::defaultContext()); - std::vector shapeOf({0}); - - ASSERT_TRUE(array.isEmpty()); - ASSERT_EQ(1, array.rankOf()); - ASSERT_EQ(shapeOf, array.getShapeAsVector()); + const auto shape = ConstantShapeHelper::getInstance()->vectorShapeInfo( + 0, sd::DataType::FLOAT32); + NDArray array(shape, true, sd::LaunchContext::defaultContext()); + std::vector shapeOf({0}); + + ASSERT_TRUE(array.isEmpty()); + ASSERT_EQ(1, array.rankOf()); + ASSERT_EQ(shapeOf, array.getShapeAsVector()); } - TEST_F(EmptyTests, test_empty_matmul_1) { - auto x = NDArrayFactory::create('c', {0, 1}); - auto y = NDArrayFactory::create('c', {1, 0}); - auto e = NDArrayFactory::create('c', {0, 0}); - - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto x = NDArrayFactory::create('c', {0, 1}); + auto y = NDArrayFactory::create('c', {1, 0}); + auto e = NDArrayFactory::create('c', {0, 0}); - auto z = result.at(0); - ASSERT_EQ(e, z); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_EQ(e, z); } TEST_F(EmptyTests, test_empty_matmul_2) { - auto x = NDArrayFactory::create('c', {1, 0, 4}); - auto y = NDArrayFactory::create('c', {1, 4, 0}); - auto e = NDArrayFactory::create('c', {1, 0, 0}); + auto x = NDArrayFactory::create('c', {1, 0, 4}); + auto y = NDArrayFactory::create('c', {1, 4, 0}); + auto e = NDArrayFactory::create('c', {1, 0, 0}); - sd::ops::matmul op; - auto result = op.evaluate({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); - ASSERT_EQ(e, z); + auto z = result.at(0); + ASSERT_EQ(e, z); } diff --git a/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp b/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp index f80178d19d6a..59798c27ec7e 100644 --- a/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp @@ -18,46 +18,47 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include +#include #include +#include + #include -#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class ExecutionLayerTests : public testing::Test { -public: - ExecutionLayerTests() { - /// - } + public: + ExecutionLayerTests() { + /// + } }; TEST_F(ExecutionLayerTests, test_reassign_1) { - ExecutionLayer layer; - OpSequence sequence1, sequence2; + ExecutionLayer layer; + OpSequence sequence1, sequence2; - ops::add op1; - ops::multiply op2; - ops::divide op3; + ops::add op1; + ops::multiply op2; + ops::divide op3; - Context ctx1(1); - Context ctx2(2); - Context ctx3(3); + Context ctx1(1); + Context ctx2(2); + Context ctx3(3); - sequence1.append(&op1, ctx1); - sequence2.append(&op2, ctx2); - sequence2.append(&op3, ctx3); + sequence1.append(&op1, ctx1); + sequence2.append(&op2, ctx2); + sequence2.append(&op3, ctx3); - layer.append(sequence1); - layer.append(sequence2); + layer.append(sequence1); + layer.append(sequence2); - auto seq = layer[0]; - ASSERT_EQ(1, seq.length()); + auto seq = layer[0]; + ASSERT_EQ(1, seq.length()); - seq = layer[1]; - ASSERT_EQ(2, seq.length()); + seq = layer[1]; + ASSERT_EQ(2, seq.length()); } - diff --git a/libnd4j/tests_cpu/layers_tests/ExtraArgumentsTests.cpp b/libnd4j/tests_cpu/layers_tests/ExtraArgumentsTests.cpp index 87ac750b2caa..35cda293bdc7 100644 --- a/libnd4j/tests_cpu/layers_tests/ExtraArgumentsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ExtraArgumentsTests.cpp @@ -18,49 +18,47 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include + #include +#include "testlayers.h" + using namespace sd; class ExtraArgumentsTests : public testing::Test { -public: - - ExtraArgumentsTests() { - printf("\n"); - fflush(stdout); - } + public: + ExtraArgumentsTests() { + printf("\n"); + fflush(stdout); + } }; TEST_F(ExtraArgumentsTests, Basic_Test_1) { - if (!Environment::getInstance()->isCPU()) - return; + if (!Environment::getInstance()->isCPU()) return; - ExtraArguments args({1.0, 2.0, 3.0}); + ExtraArguments args({1.0, 2.0, 3.0}); - float ef[] = {1.f, 2.f, 3.f}; - double ed[] = {1., 2., 3.}; + float ef[] = {1.f, 2.f, 3.f}; + double ed[] = {1., 2., 3.}; - auto ptrFloat = reinterpret_cast(args.argumentsAsT()); - auto ptrDouble = reinterpret_cast(args.argumentsAsT()); - ASSERT_TRUE(ptrFloat != nullptr); - ASSERT_TRUE(ptrDouble != nullptr); + auto ptrFloat = reinterpret_cast(args.argumentsAsT()); + auto ptrDouble = reinterpret_cast(args.argumentsAsT()); + ASSERT_TRUE(ptrFloat != nullptr); + ASSERT_TRUE(ptrDouble != nullptr); - for (int e = 0; e < 3; e++) { - ASSERT_NEAR(ef[e], ptrFloat[e], 1e-5f); - } + for (int e = 0; e < 3; e++) { + ASSERT_NEAR(ef[e], ptrFloat[e], 1e-5f); + } - for (int e = 0; e < 3; e++) { - ASSERT_NEAR(ed[e], ptrDouble[e], 1e-5); - } + for (int e = 0; e < 3; e++) { + ASSERT_NEAR(ed[e], ptrDouble[e], 1e-5); + } } - TEST_F(ExtraArgumentsTests, Basic_Test_2) { - ExtraArguments args; + ExtraArguments args; - auto ptrInt = args.argumentsAsT(); - ASSERT_TRUE(ptrInt == nullptr); + auto ptrInt = args.argumentsAsT(); + ASSERT_TRUE(ptrInt == nullptr); } - diff --git a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp index c83b11f0ec69..45f1a210db79 100644 --- a/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/FlatBuffersTests.cpp @@ -18,69 +18,69 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include -#include +#include +#include #include +#include #include -#include -#include #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class FlatBuffersTest : public testing::Test { -public: - int alpha = 0; - - Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; - Nd4jLong *fShape = new Nd4jLong[8]{2, 2, 2, 1, 2, 8192, 1, 102}; - - FlatBuffersTest() { - Environment::getInstance()->setDebug(false); - Environment::getInstance()->setVerbose(false); - Environment::getInstance()->setProfiling(false); - } - - ~FlatBuffersTest() { - Environment::getInstance()->setDebug(false); - Environment::getInstance()->setVerbose(false); - Environment::getInstance()->setProfiling(false); - - delete[] cShape; - delete[] fShape; - } + public: + int alpha = 0; + + Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong *fShape = new Nd4jLong[8]{2, 2, 2, 1, 2, 8192, 1, 102}; + + FlatBuffersTest() { + Environment::getInstance()->setDebug(false); + Environment::getInstance()->setVerbose(false); + Environment::getInstance()->setProfiling(false); + } + + ~FlatBuffersTest() { + Environment::getInstance()->setDebug(false); + Environment::getInstance()->setVerbose(false); + Environment::getInstance()->setProfiling(false); + + delete[] cShape; + delete[] fShape; + } }; /** * Simple test that creates Node & reads it */ TEST_F(FlatBuffersTest, BasicTest1) { - flatbuffers::FlatBufferBuilder builder(1024); + flatbuffers::FlatBufferBuilder builder(1024); - auto name = builder.CreateString("wow"); + auto name = builder.CreateString("wow"); - auto node = CreateFlatNode(builder, -1, name, OpType_TRANSFORM_SAME, transform::Ones, {0}); + auto node = CreateFlatNode(builder, -1, name, OpType_TRANSFORM_SAME, + transform::Ones, {0}); - builder.Finish(node); - - // now we have our buffer with data - uint8_t *buf = builder.GetBufferPointer(); - int size = builder.GetSize(); - ASSERT_TRUE(size > 0); + builder.Finish(node); + // now we have our buffer with data + uint8_t *buf = builder.GetBufferPointer(); + int size = builder.GetSize(); + ASSERT_TRUE(size > 0); + auto restored = GetFlatNode(buf); - auto restored = GetFlatNode(buf); + auto gA = new Node(restored); + auto gB = new Node(restored); - auto gA = new Node(restored); - auto gB = new Node(restored); + ASSERT_TRUE(gA->equals(gB)); - ASSERT_TRUE(gA->equals(gB)); - - delete gA; - delete gB; + delete gA; + delete gB; } /* @@ -93,10 +93,11 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) { auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector()); auto fBuffer = builder.CreateVector(array->asByteVector()); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_FLOAT); - auto fVid = CreateIntPair(builder, -1); + auto fArray = CreateFlatArray(builder, fShape, fBuffer, +sd::graph::DType::DType_FLOAT); auto fVid = CreateIntPair(builder, -1); - auto fVar = CreateFlatVariable(builder, fVid, 0, sd::graph::DType::DType_FLOAT, 0, fArray); + auto fVar = CreateFlatVariable(builder, fVid, 0, +sd::graph::DType::DType_FLOAT, 0, fArray); std::vector outputs1, outputs2, inputs1, inputs2; outputs1.push_back(2); @@ -115,8 +116,9 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) { auto name1 = builder.CreateString("wow1"); auto name2 = builder.CreateString("wow2"); - auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM_SAME, transform::Abs, 0, in1, 0, vec1); - auto node2 = CreateFlatNode(builder, 2, name2, OpType_TRANSFORM_STRICT, transform::Cosine, 0, in2, 0, vec2); + auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM_SAME, +transform::Abs, 0, in1, 0, vec1); auto node2 = CreateFlatNode(builder, 2, name2, +OpType_TRANSFORM_STRICT, transform::Cosine, 0, in2, 0, vec2); std::vector> variables_vector; variables_vector.push_back(fVar); @@ -170,7 +172,8 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) { auto vs = graph.variableSpace(); - ASSERT_EQ(OutputMode_IMPLICIT, graph.getExecutorConfiguration()->_outputMode); + ASSERT_EQ(OutputMode_IMPLICIT, +graph.getExecutorConfiguration()->_outputMode); ASSERT_EQ(3, vs->totalEntries()); ASSERT_EQ(1, vs->externalEntries()); @@ -183,14 +186,16 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) { sd::graph::GraphExecutioner::execute(&graph); - auto resultWrapper = sd::graph::GraphExecutioner::executeFlatBuffer((Nd4jPointer) buf); + auto resultWrapper = +sd::graph::GraphExecutioner::executeFlatBuffer((Nd4jPointer) buf); auto flatResults = GetFlatResult(resultWrapper->pointer()); ASSERT_EQ(1, flatResults->variables()->size()); ASSERT_TRUE(flatResults->variables()->Get(0)->name() != nullptr); ASSERT_TRUE(flatResults->variables()->Get(0)->name()->c_str() != nullptr); - //nd4j_printf("VARNAME: %s\n", flatResults->variables()->Get(0)->name()->c_str()); + //nd4j_printf("VARNAME: %s\n", +flatResults->variables()->Get(0)->name()->c_str()); auto var0 = new Variable(flatResults->variables()->Get(0)); //auto var1 = new Variable(flatResults->variables()->Get(1)); @@ -206,23 +211,23 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) { } */ TEST_F(FlatBuffersTest, ExecutionTest1) { - auto gA = new Node(OpType_TRANSFORM_SAME); + auto gA = new Node(OpType_TRANSFORM_SAME); - auto c = new float[4] {-1, -2, -3, -4}; - auto array = new NDArray(c, cShape); + auto c = new float[4]{-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); - auto e = new float[4] {1, 2, 3, 4}; - auto exp = new NDArray(e, cShape); + auto e = new float[4]{1, 2, 3, 4}; + auto exp = new NDArray(e, cShape); - //gA->execute(array, nullptr, array); + // gA->execute(array, nullptr, array); - //ASSERT_TRUE(exp->equalsTo(array)); + // ASSERT_TRUE(exp->equalsTo(array)); - delete gA; - delete[] c; - delete array; - delete[] e; - delete exp; + delete gA; + delete[] c; + delete array; + delete[] e; + delete exp; } /* @@ -264,7 +269,8 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) { auto name1 = builder.CreateString("wow1"); - auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, sd::graph::DType::FLOAT); + auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, +sd::graph::DType::FLOAT); std::vector> variables_vector; variables_vector.push_back(fXVar); @@ -289,7 +295,8 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) { auto flatGraph = graphBuilder.Finish(); builder.Finish(flatGraph); - auto restoredGraph = new Graph(GetFlatGraph(builder.GetBufferPointer())); + auto restoredGraph = new +Graph(GetFlatGraph(builder.GetBufferPointer())); GraphExecutioner::execute(restoredGraph); @@ -298,9 +305,12 @@ TEST_F(FlatBuffersTest, ExplicitOutputTest1) { // IMPLICIT is default ASSERT_EQ(1, results->size()); - //ASSERT_NEAR(-2.0, results->at(0)->getNDArray()->reduceNumber>(), 1e-5); - //ASSERT_NEAR(-1.0, results->at(1)->getNDArray()->reduceNumber>(), 1e-5); - ASSERT_NEAR(-3.0, results->at(0)->getNDArray()->reduceNumber>(), 1e-5); + //ASSERT_NEAR(-2.0, +results->at(0)->getNDArray()->reduceNumber>(), 1e-5); + //ASSERT_NEAR(-1.0, +results->at(1)->getNDArray()->reduceNumber>(), 1e-5); + ASSERT_NEAR(-3.0, +results->at(0)->getNDArray()->reduceNumber>(), 1e-5); //ASSERT_EQ(-1, results->at(0)->id()); //ASSERT_EQ(-2, results->at(1)->id()); @@ -323,7 +333,8 @@ TEST_F(FlatBuffersTest, ReadFile1) { ASSERT_EQ(1, restoredGraph->rootNodes()); ASSERT_EQ(2, restoredGraph->totalNodes()); - auto ones = restoredGraph->getVariableSpace()->getVariable(-1)->getNDArray(); + auto ones = +restoredGraph->getVariableSpace()->getVariable(-1)->getNDArray(); ASSERT_EQ(4, ones->lengthOf()); ASSERT_NEAR(4.0f, ones->template reduceNumber>(), 1e-5); @@ -331,9 +342,9 @@ TEST_F(FlatBuffersTest, ReadFile1) { Nd4jStatus status = GraphExecutioner::execute(restoredGraph); ASSERT_EQ(ND4J_STATUS_OK, status); - auto result = restoredGraph->getVariableSpace()->getVariable(2)->getNDArray(); - ASSERT_EQ(1, result->lengthOf()); - ASSERT_EQ(8, result->e(0)); + auto result = +restoredGraph->getVariableSpace()->getVariable(2)->getNDArray(); ASSERT_EQ(1, +result->lengthOf()); ASSERT_EQ(8, result->e(0)); delete[] data; delete restoredGraph; @@ -341,7 +352,8 @@ TEST_F(FlatBuffersTest, ReadFile1) { TEST_F(FlatBuffersTest, ReadFile2) { uint8_t* data = sd::graph::readFlatBuffers("./resources/adam_sum.fb"); - Nd4jPointer result = GraphExecutioner::executeFlatBuffer((Nd4jPointer) data); + Nd4jPointer result = +GraphExecutioner::executeFlatBuffer((Nd4jPointer) data); ResultSet arrays(GetFlatResult(result)); @@ -354,7 +366,8 @@ TEST_F(FlatBuffersTest, ReadFile2) { } TEST_F(FlatBuffersTest, ReadFile3) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/adam_sum.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/adam_sum.fb"); Nd4jStatus status = GraphExecutioner::execute(graph); ASSERT_EQ(ND4J_STATUS_OK, status); @@ -369,7 +382,8 @@ TEST_F(FlatBuffersTest, ReadFile3) { TEST_F(FlatBuffersTest, ReadInception1) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/inception.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/inception.fb"); Nd4jStatus status = GraphExecutioner::execute(graph); @@ -395,7 +409,8 @@ TEST_F(FlatBuffersTest, ReadInception1) { TEST_F(FlatBuffersTest, ReadLoops_3argsWhile_1) { // TF graph: // https://gist.github.com/raver119/b86ef727e9a094aab386e2b35e878966 - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/three_args_while.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/three_args_while.fb"); ASSERT_TRUE(graph != nullptr); @@ -434,7 +449,8 @@ TEST_F(FlatBuffersTest, ReadLoops_3argsWhile_1) { TEST_F(FlatBuffersTest, ReadTensorArrayLoop_1) { auto exp('c', {5, 2}, {3., 6., 9., 12., 15., 18., 21., 24., 27., 30.}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_loop.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_loop.fb"); ASSERT_TRUE(graph != nullptr); @@ -467,7 +483,8 @@ TEST_F(FlatBuffersTest, ReadLoops_NestedWhile_1) { // https://gist.github.com/raver119/2aa49daf7ec09ed4ddddbc6262f213a0 sd::ops::assign op1; - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/nested_while.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/nested_while.fb"); ASSERT_TRUE(graph != nullptr); @@ -491,11 +508,14 @@ TEST_F(FlatBuffersTest, ReadLoops_NestedWhile_1) { /* TEST_F(FlatBuffersTest, ReadTensorArray_1) { - // TF graph: https://gist.github.com/raver119/3265923eed48feecc465d17ec842b6e2 + // TF graph: +https://gist.github.com/raver119/3265923eed48feecc465d17ec842b6e2 - auto exp('c', {3, 2}, {1.000000, 1.000000, 2.000000, 2.000000, 3.000000, 3.000000}); + auto exp('c', {3, 2}, +{1.000000, 1.000000, 2.000000, 2.000000, 3.000000, 3.000000}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/tensor_array.fb"); ASSERT_TRUE(graph != nullptr); @@ -516,8 +536,9 @@ TEST_F(FlatBuffersTest, ReadTensorArray_1) { */ /* TEST_F(FlatBuffersTest, ReadStridedSlice_1) { - // TF graph: https://gist.github.com/raver119/fc3bf2d31c91e465c635b24020fd798d - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_slice.fb"); + // TF graph: +https://gist.github.com/raver119/fc3bf2d31c91e465c635b24020fd798d auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/tensor_slice.fb"); ASSERT_TRUE(graph != nullptr); @@ -541,7 +562,8 @@ TEST_F(FlatBuffersTest, ReduceDim_1) { exp.assign(3.0); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); graph->printOut(); @@ -574,7 +596,8 @@ TEST_F(FlatBuffersTest, ReduceDim_2) { exp.assign(3.0); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_true.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_true.fb"); graph->printOut(); @@ -604,160 +627,445 @@ TEST_F(FlatBuffersTest, ReduceDim_2) { #ifdef GRAPH_FILES_OK TEST_F(FlatBuffersTest, Ae_00) { - sd::ops::rank op1; + sd::ops::rank op1; - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); - auto exp = NDArrayFactory::create('c', {5, 4}, {0.32454616f, -0.06604697f, 0.22593613f, 0.43166467f, -0.18320604f, 0.00102305f, -0.06963076f, 0.25266643f, 0.07568010f, -0.03009197f, 0.07805517f, 0.33180334f, -0.06220427f, 0.07249600f, -0.06726961f, -0.22998397f, -0.06343779f, 0.07384885f, -0.06891008f, -0.23745790f}); + auto exp = NDArrayFactory::create( + 'c', {5, 4}, + {0.32454616f, -0.06604697f, 0.22593613f, 0.43166467f, -0.18320604f, + 0.00102305f, -0.06963076f, 0.25266643f, 0.07568010f, -0.03009197f, + 0.07805517f, 0.33180334f, -0.06220427f, 0.07249600f, -0.06726961f, + -0.22998397f, -0.06343779f, 0.07384885f, -0.06891008f, -0.23745790f}); -// graph->printOut(); + // graph->printOut(); - ASSERT_EQ(OutputMode_VARIABLE_SPACE, graph->getExecutorConfiguration()->_outputMode); + ASSERT_EQ(OutputMode_VARIABLE_SPACE, + graph->getExecutorConfiguration()->_outputMode); - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(ND4J_STATUS_OK, result); + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); - ASSERT_TRUE(graph->variableSpace()->hasVariable(18)); + ASSERT_TRUE(graph->variableSpace()->hasVariable(18)); - auto z = graph->variableSpace()->getVariable(18)->getNDArray(); + auto z = graph->variableSpace()->getVariable(18)->getNDArray(); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); - delete graph; + delete graph; } TEST_F(FlatBuffersTest, expand_dims) { - sd::ops::rank op1; + sd::ops::rank op1; - auto exp = NDArrayFactory::create('c', {3, 1, 4}, {-0.95938617f, -1.20301781f, 1.22260064f, 0.50172403f, 0.59972949f, 0.78568028f, 0.31609724f, 1.51674747f, 0.68013491f, -0.05227458f, 0.25903158f, 1.13243439f}); + auto exp = NDArrayFactory::create( + 'c', {3, 1, 4}, + {-0.95938617f, -1.20301781f, 1.22260064f, 0.50172403f, 0.59972949f, + 0.78568028f, 0.31609724f, 1.51674747f, 0.68013491f, -0.05227458f, + 0.25903158f, 1.13243439f}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/expand_dim.fb"); + auto graph = + GraphExecutioner::importFromFlatBuffers("./resources/expand_dim.fb"); -// graph->printOut(); + // graph->printOut(); - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(ND4J_STATUS_OK, result); - ASSERT_TRUE(graph->variableSpace()->hasVariable(5)); + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); + ASSERT_TRUE(graph->variableSpace()->hasVariable(5)); - auto z = graph->variableSpace()->getVariable(5)->getNDArray(); + auto z = graph->variableSpace()->getVariable(5)->getNDArray(); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); - delete graph; + delete graph; } TEST_F(FlatBuffersTest, transpose) { - sd::ops::rank op1; + sd::ops::rank op1; - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/transpose.fb"); + auto graph = + GraphExecutioner::importFromFlatBuffers("./resources/transpose.fb"); - //graph->printOut(); + // graph->printOut(); - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(ND4J_STATUS_OK, result); + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); - delete graph; + delete graph; } TEST_F(FlatBuffersTest, Test_Stitches) { - sd::ops::realdiv op0; + sd::ops::realdiv op0; - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/partition_stitch_misc.fb"); - //graph->printOut(); + auto graph = GraphExecutioner::importFromFlatBuffers( + "./resources/partition_stitch_misc.fb"); + // graph->printOut(); + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(ND4J_STATUS_OK, result); - - delete graph; + delete graph; } TEST_F(FlatBuffersTest, Test_GruDynamicMnist) { - sd::Environment::getInstance()->setDebug(false); - sd::Environment::getInstance()->setVerbose(false); + sd::Environment::getInstance()->setDebug(false); + sd::Environment::getInstance()->setVerbose(false); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/gru_dynamic_mnist.fb"); - //graph->printOut(); + auto graph = GraphExecutioner::importFromFlatBuffers( + "./resources/gru_dynamic_mnist.fb"); + // graph->printOut(); - auto timeStart = std::chrono::system_clock::now(); - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(ND4J_STATUS_OK, result); + auto timeStart = std::chrono::system_clock::now(); + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); - auto timeEnd = std::chrono::system_clock::now(); + auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + auto outerTime = + std::chrono::duration_cast(timeEnd - timeStart) + .count(); - // nd4j_printf("GRU time 1 time %lld us\n", outerTime); + // nd4j_printf("GRU time 1 time %lld us\n", outerTime); - delete graph; + delete graph; } TEST_F(FlatBuffersTest, Test_Non2D_2) { - sd::Environment::getInstance()->setDebug(false); - sd::Environment::getInstance()->setVerbose(false); - sd::ops::realdiv op0; + sd::Environment::getInstance()->setDebug(false); + sd::Environment::getInstance()->setVerbose(false); + sd::ops::realdiv op0; - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/non2d_2.fb"); - //graph->printOut(); + auto graph = + GraphExecutioner::importFromFlatBuffers("./resources/non2d_2.fb"); + // graph->printOut(); - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(ND4J_STATUS_OK, result); + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); - delete graph; + delete graph; } - TEST_F(FlatBuffersTest, Test_TensorDotMisc) { - Environment::getInstance()->setVerbose(false); - Environment::getInstance()->setDebug(false); - - auto e = NDArrayFactory::create('c', {1, 3, 16, 20}, {4.f, 6.f, 6.f, 5.f, 6.f, 4.f, 2.f, 3.f, 5.f, 5.f, 1.f, 4.f, 6.f, 3.f, 2.f, 1.f, 5.f, 4.f, 4.f, 4.f, 4.f, 4.f, 3.f, 4.f, 2.f, 3.f, 3.f, 5.f, 3.f, 6.f, 5.f, 4.f, 4.f, 3.f, 6.f, 1.f, 2.f, 4.f, 2.f, 6.f, 4.f, 2.f, 3.f, 2.f, 3.f, 1.f, 2.f, 4.f, 3.f, 5.f, 3.f, 3.f, 5.f, 2.f, 6.f, 3.f, 4.f, 4.f, 4.f, 4.f, 6.f, 4.f, 5.f, 2.f, 5.f, 5.f, 5.f, 5.f, 2.f, 4.f, 4.f, 4.f, 5.f, 4.f, 3.f, 6.f, 3.f, 4.f, 5.f, 2.f, 5.f, 4.f, 4.f, 5.f, 4.f, 3.f, 4.f, 5.f, 5.f, 3.f, 5.f, 6.f, 6.f, 3.f, 4.f, 5.f, 7.f, 6.f, 5.f, 2.f, 4.f, 5.f, 5.f, 4.f, 5.f, 4.f, 4.f, 6.f, 3.f, 4.f, 5.f, 4.f, 6.f, 2.f, 3.f, 4.f, 3.f, 3.f, 2.f, 2.f, 3.f, 4.f, 7.f, 3.f, 5.f, 4.f, 5.f, 4.f, 4.f, 4.f, 4.f, 6.f, 2.f, 3.f, 2.f, 5.f, 5.f, 4.f, 5.f, 2.f, 2.f, 1.f, 6.f, 2.f, 2.f, 3.f, 4.f, 5.f, 5.f, 3.f, 6.f, 6.f, 4.f, 3.f, 3.f, 3.f, 3.f, 3.f, 4.f, 5.f, 4.f, 4.f, 3.f, 5.f, 2.f, 3.f, 4.f, 5.f, 3.f, 4.f, 5.f, 5.f, 8.f, 4.f, 5.f, 3.f, 3.f, 4.f, 4.f, 5.f, 4.f, 5.f, 3.f, 3.f, 7.f, 2.f, 3.f, 2.f, 6.f, 6.f, 4.f, 4.f, 3.f, 5.f, 6.f, 2.f, 4.f, 3.f, 3.f, 4.f, 5.f, 3.f, 3.f, 6.f, 5.f, 3.f, 2.f, 5.f, 4.f, 4.f, 3.f, 5.f, 5.f, 6.f, 7.f, 3.f, 4.f, 3.f, 5.f, 6.f, 7.f, 5.f, 6.f, 5.f, 7.f, 4.f, 6.f, 5.f, 5.f, 6.f, 4.f, 2.f, 5.f, 4.f, 3.f, 4.f, 1.f, 5.f, 5.f, 3.f, 2.f, 2.f, 6.f, 5.f, 5.f, 2.f, 5.f, 2.f, 4.f, 4.f, 5.f, 5.f, 4.f, 3.f, 7.f, 4.f, 5.f, 3.f, 3.f, 3.f, 2.f, 3.f, 2.f, 3.f, 3.f, 4.f, 4.f, 2.f, 4.f, 5.f, 3.f, 4.f, 5.f, 3.f, 7.f, 2.f, 1.f, 3.f, 2.f, 3.f, 2.f, 3.f, 3.f, 4.f, 3.f, 4.f, 2.f, 4.f, 4.f, 4.f, 5.f, 3.f, 5.f, 3.f, 6.f, 6.f, 5.f, 3.f, 5.f, 3.f, 4.f, 3.f, 5.f, 3.f, 5.f, 6.f, 5.f, 3.f, 4.f, 5.f, 5.f, 3.f, 3.f, 3.f, 4.f, 6.f, 4.f, 3.f, 7.f, 4.f, 4.f, 6.f, 7.f, 5.f, 5.f, 3.f, 1.f, 2.f, 5.f, 5.f, 2.f, 5.f, 7.f, 5.f, 3.f, 1.f, 4.f, 6.f, 5.f, 7.f, 5.f, 6.f, 5.f, 6.f, 4.f, 3.f, 3.f, 4.f, 3.f, 4.f, 4.f, 4.f, 4.f, 3.f, 5.f, 2.f, 4.f, 5.f, 2.f, 5.f, 5.f, 4.f, 5.f, 4.f, 5.f, 2.f, 3.f, 5.f, 3.f, 6.f, 3.f, 4.f, 5.f, 3.f, 6.f, 5.f, 5.f, 6.f, 4.f, 6.f, 7.f, 4.f, 5.f, 3.f, 5.f, 4.f, 4.f, 4.f, 2.f, 2.f, 5.f, 3.f, 5.f, 3.f, 4.f, 6.f, 3.f, 5.f, 5.f, 3.f, 5.f, 4.f, 4.f, 4.f, 5.f, 2.f, 3.f, 5.f, 4.f, 2.f, 4.f, 5.f, 4.f, 2.f, 3.f, 4.f, 4.f, 5.f, 5.f, 1.f, 4.f, 4.f, 4.f, 3.f, 4.f, 5.f, 5.f, 8.f, 4.f, 4.f, 4.f, 3.f, 6.f, 2.f, 3.f, 4.f, 4.f, 4.f, 3.f, 2.f, 3.f, 4.f, 8.f, 3.f, 5.f, 5.f, 5.f, 3.f, 3.f, 4.f, 5.f, 7.f, 3.f, 3.f, 3.f, 6.f, 6.f, 5.f, 5.f, 3.f, 4.f, 3.f, 8.f, 3.f, 4.f, 2.f, 3.f, 4.f, 4.f, 3.f, 5.f, 5.f, 3.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 6.f, 6.f, 5.f, 6.f, 4.f, 5.f, 4.f, 6.f, 4.f, 5.f, 5.f, 4.f, 7.f, 3.f, 5.f, 5.f, 3.f, 5.f, 5.f, 6.f, 4.f, 5.f, 4.f, 2.f, 7.f, 2.f, 3.f, 1.f, 4.f, 5.f, 5.f, 4.f, 4.f, 5.f, 7.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 3.f, 3.f, 6.f, 6.f, 3.f, 2.f, 4.f, 3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 5.f, 1.f, 2.f, 3.f, 3.f, 4.f, 5.f, 4.f, 5.f, 4.f, 5.f, 6.f, 6.f, 6.f, 6.f, 7.f, 4.f, 3.f, 4.f, 5.f, 4.f, 4.f, 2.f, 5.f, 6.f, 4.f, 2.f, 2.f, 6.f, 5.f, 5.f, 1.f, 4.f, 2.f, 3.f, 4.f, 5.f, 5.f, 4.f, 5.f, 9.f, 4.f, 6.f, 4.f, 5.f, 5.f, 3.f, 4.f, 5.f, 5.f, 5.f, 4.f, 3.f, 1.f, 3.f, 4.f, 3.f, 4.f, 4.f, 3.f, 6.f, 2.f, 3.f, 3.f, 2.f, 3.f, 3.f, 4.f, 5.f, 6.f, 5.f, 5.f, 3.f, 4.f, 5.f, 5.f, 4.f, 3.f, 4.f, 3.f, 6.f, 7.f, 6.f, 4.f, 6.f, 4.f, 3.f, 3.f, 4.f, 3.f, 5.f, 5.f, 4.f, 2.f, 3.f, 4.f, 5.f, 3.f, 4.f, 2.f, 4.f, 5.f, 3.f, 3.f, 7.f, 4.f, 2.f, 5.f, 6.f, 5.f, 5.f, 3.f, 1.f, 2.f, 4.f, 4.f, 1.f, 3.f, 6.f, 3.f, 3.f, 1.f, 4.f, 4.f, 4.f, 5.f, 3.f, 4.f, 3.f, 4.f, 2.f, 3.f, 3.f, 4.f, 3.f, 4.f, 3.f, 3.f, 4.f, 2.f, 5.f, 1.f, 3.f, 4.f, 2.f, 6.f, 4.f, 3.f, 4.f, 3.f, 3.f, 1.f, 2.f, 5.f, 2.f, 6.f, 4.f, 5.f, 6.f, 3.f, 6.f, 4.f, 4.f, 5.f, 3.f, 5.f, 6.f, 3.f, 4.f, 2.f, 4.f, 5.f, 5.f, 5.f, 2.f, 3.f, 4.f, 3.f, 5.f, 3.f, 3.f, 9.f, 6.f, 7.f, 7.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 4.f, 6.f, 5.f, 3.f, 5.f, 5.f, 5.f, 2.f, 4.f, 6.f, 7.f, 7.f, 5.f, 3.f, 4.f, 5.f, 4.f, 4.f, 5.f, 5.f, 5.f, 8.f, 4.f, 4.f, 4.f, 3.f, 5.f, 3.f, 3.f, 4.f, 4.f, 5.f, 3.f, 3.f, 2.f, 3.f, 6.f, 2.f, 5.f, 4.f, 4.f, 3.f, 3.f, 3.f, 5.f, 7.f, 2.f, 3.f, 2.f, 5.f, 5.f, 4.f, 4.f, 2.f, 2.f, 1.f, 6.f, 1.f, 2.f, 2.f, 3.f, 5.f, 4.f, 3.f, 5.f, 5.f, 3.f, 2.f, 2.f, 2.f, 2.f, 4.f, 3.f, 4.f, 4.f, 4.f, 4.f, 5.f, 2.f, 4.f, 4.f, 5.f, 2.f, 4.f, 4.f, 5.f, 9.f, 4.f, 5.f, 4.f, 3.f, 5.f, 5.f, 6.f, 4.f, 4.f, 3.f, 3.f, 6.f, 2.f, 3.f, 2.f, 5.f, 6.f, 4.f, 4.f, 3.f, 5.f, 6.f, 4.f, 5.f, 5.f, 6.f, 7.f, 4.f, 2.f, 3.f, 5.f, 4.f, 4.f, 3.f, 5.f, 5.f, 4.f, 3.f, 4.f, 5.f, 4.f, 6.f, 3.f, 4.f, 4.f, 5.f, 6.f, 6.f, 4.f, 6.f, 6.f, 6.f, 5.f, 6.f, 6.f, 7.f, 7.f, 4.f, 3.f, 4.f, 4.f, 4.f, 5.f, 2.f, 5.f, 7.f, 5.f, 2.f, 1.f, 5.f, 5.f, 4.f, 1.f, 4.f, 1.f, 3.f, 3.f, 5.f, 4.f, 4.f, 3.f, 7.f, 3.f, 6.f, 3.f, 3.f, 4.f, 1.f, 3.f, 2.f, 3.f, 3.f, 4.f, 3.f, 1.f, 3.f, 4.f, 2.f, 4.f, 4.f, 2.f, 6.f, 1.f, 2.f, 2.f, 2.f, 3.f, 2.f, 3.f, 3.f, 4.f, 4.f, 4.f, 2.f, 4.f, 4.f, 4.f, 5.f, 5.f, 5.f, 4.f, 8.f, 5.f, 5.f, 3.f, 5.f, 3.f, 3.f, 2.f, 4.f, 3.f, 5.f, 6.f, 5.f, 3.f, 4.f, 5.f, 5.f, 3.f, 4.f, 3.f, 4.f, 8.f, 6.f, 5.f, 9.f, 6.f}); - - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_dot_misc.fb"); -// graph->printOut(); - - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), result); - - ASSERT_TRUE(graph->variableSpace()->hasVariable(77)); - - auto z = graph->variableSpace()->getVariable(77,0)->getNDArray(); - - ASSERT_EQ(e, *z); - - delete graph; + Environment::getInstance()->setVerbose(false); + Environment::getInstance()->setDebug(false); + + auto e = NDArrayFactory::create( + 'c', {1, 3, 16, 20}, + {4.f, 6.f, 6.f, 5.f, 6.f, 4.f, 2.f, 3.f, 5.f, 5.f, 1.f, 4.f, 6.f, 3.f, + 2.f, 1.f, 5.f, 4.f, 4.f, 4.f, 4.f, 4.f, 3.f, 4.f, 2.f, 3.f, 3.f, 5.f, + 3.f, 6.f, 5.f, 4.f, 4.f, 3.f, 6.f, 1.f, 2.f, 4.f, 2.f, 6.f, 4.f, 2.f, + 3.f, 2.f, 3.f, 1.f, 2.f, 4.f, 3.f, 5.f, 3.f, 3.f, 5.f, 2.f, 6.f, 3.f, + 4.f, 4.f, 4.f, 4.f, 6.f, 4.f, 5.f, 2.f, 5.f, 5.f, 5.f, 5.f, 2.f, 4.f, + 4.f, 4.f, 5.f, 4.f, 3.f, 6.f, 3.f, 4.f, 5.f, 2.f, 5.f, 4.f, 4.f, 5.f, + 4.f, 3.f, 4.f, 5.f, 5.f, 3.f, 5.f, 6.f, 6.f, 3.f, 4.f, 5.f, 7.f, 6.f, + 5.f, 2.f, 4.f, 5.f, 5.f, 4.f, 5.f, 4.f, 4.f, 6.f, 3.f, 4.f, 5.f, 4.f, + 6.f, 2.f, 3.f, 4.f, 3.f, 3.f, 2.f, 2.f, 3.f, 4.f, 7.f, 3.f, 5.f, 4.f, + 5.f, 4.f, 4.f, 4.f, 4.f, 6.f, 2.f, 3.f, 2.f, 5.f, 5.f, 4.f, 5.f, 2.f, + 2.f, 1.f, 6.f, 2.f, 2.f, 3.f, 4.f, 5.f, 5.f, 3.f, 6.f, 6.f, 4.f, 3.f, + 3.f, 3.f, 3.f, 3.f, 4.f, 5.f, 4.f, 4.f, 3.f, 5.f, 2.f, 3.f, 4.f, 5.f, + 3.f, 4.f, 5.f, 5.f, 8.f, 4.f, 5.f, 3.f, 3.f, 4.f, 4.f, 5.f, 4.f, 5.f, + 3.f, 3.f, 7.f, 2.f, 3.f, 2.f, 6.f, 6.f, 4.f, 4.f, 3.f, 5.f, 6.f, 2.f, + 4.f, 3.f, 3.f, 4.f, 5.f, 3.f, 3.f, 6.f, 5.f, 3.f, 2.f, 5.f, 4.f, 4.f, + 3.f, 5.f, 5.f, 6.f, 7.f, 3.f, 4.f, 3.f, 5.f, 6.f, 7.f, 5.f, 6.f, 5.f, + 7.f, 4.f, 6.f, 5.f, 5.f, 6.f, 4.f, 2.f, 5.f, 4.f, 3.f, 4.f, 1.f, 5.f, + 5.f, 3.f, 2.f, 2.f, 6.f, 5.f, 5.f, 2.f, 5.f, 2.f, 4.f, 4.f, 5.f, 5.f, + 4.f, 3.f, 7.f, 4.f, 5.f, 3.f, 3.f, 3.f, 2.f, 3.f, 2.f, 3.f, 3.f, 4.f, + 4.f, 2.f, 4.f, 5.f, 3.f, 4.f, 5.f, 3.f, 7.f, 2.f, 1.f, 3.f, 2.f, 3.f, + 2.f, 3.f, 3.f, 4.f, 3.f, 4.f, 2.f, 4.f, 4.f, 4.f, 5.f, 3.f, 5.f, 3.f, + 6.f, 6.f, 5.f, 3.f, 5.f, 3.f, 4.f, 3.f, 5.f, 3.f, 5.f, 6.f, 5.f, 3.f, + 4.f, 5.f, 5.f, 3.f, 3.f, 3.f, 4.f, 6.f, 4.f, 3.f, 7.f, 4.f, 4.f, 6.f, + 7.f, 5.f, 5.f, 3.f, 1.f, 2.f, 5.f, 5.f, 2.f, 5.f, 7.f, 5.f, 3.f, 1.f, + 4.f, 6.f, 5.f, 7.f, 5.f, 6.f, 5.f, 6.f, 4.f, 3.f, 3.f, 4.f, 3.f, 4.f, + 4.f, 4.f, 4.f, 3.f, 5.f, 2.f, 4.f, 5.f, 2.f, 5.f, 5.f, 4.f, 5.f, 4.f, + 5.f, 2.f, 3.f, 5.f, 3.f, 6.f, 3.f, 4.f, 5.f, 3.f, 6.f, 5.f, 5.f, 6.f, + 4.f, 6.f, 7.f, 4.f, 5.f, 3.f, 5.f, 4.f, 4.f, 4.f, 2.f, 2.f, 5.f, 3.f, + 5.f, 3.f, 4.f, 6.f, 3.f, 5.f, 5.f, 3.f, 5.f, 4.f, 4.f, 4.f, 5.f, 2.f, + 3.f, 5.f, 4.f, 2.f, 4.f, 5.f, 4.f, 2.f, 3.f, 4.f, 4.f, 5.f, 5.f, 1.f, + 4.f, 4.f, 4.f, 3.f, 4.f, 5.f, 5.f, 8.f, 4.f, 4.f, 4.f, 3.f, 6.f, 2.f, + 3.f, 4.f, 4.f, 4.f, 3.f, 2.f, 3.f, 4.f, 8.f, 3.f, 5.f, 5.f, 5.f, 3.f, + 3.f, 4.f, 5.f, 7.f, 3.f, 3.f, 3.f, 6.f, 6.f, 5.f, 5.f, 3.f, 4.f, 3.f, + 8.f, 3.f, 4.f, 2.f, 3.f, 4.f, 4.f, 3.f, 5.f, 5.f, 3.f, 2.f, 3.f, 3.f, + 3.f, 4.f, 4.f, 4.f, 6.f, 6.f, 5.f, 6.f, 4.f, 5.f, 4.f, 6.f, 4.f, 5.f, + 5.f, 4.f, 7.f, 3.f, 5.f, 5.f, 3.f, 5.f, 5.f, 6.f, 4.f, 5.f, 4.f, 2.f, + 7.f, 2.f, 3.f, 1.f, 4.f, 5.f, 5.f, 4.f, 4.f, 5.f, 7.f, 2.f, 3.f, 3.f, + 4.f, 4.f, 5.f, 3.f, 3.f, 6.f, 6.f, 3.f, 2.f, 4.f, 3.f, 3.f, 3.f, 3.f, + 4.f, 4.f, 5.f, 1.f, 2.f, 3.f, 3.f, 4.f, 5.f, 4.f, 5.f, 4.f, 5.f, 6.f, + 6.f, 6.f, 6.f, 7.f, 4.f, 3.f, 4.f, 5.f, 4.f, 4.f, 2.f, 5.f, 6.f, 4.f, + 2.f, 2.f, 6.f, 5.f, 5.f, 1.f, 4.f, 2.f, 3.f, 4.f, 5.f, 5.f, 4.f, 5.f, + 9.f, 4.f, 6.f, 4.f, 5.f, 5.f, 3.f, 4.f, 5.f, 5.f, 5.f, 4.f, 3.f, 1.f, + 3.f, 4.f, 3.f, 4.f, 4.f, 3.f, 6.f, 2.f, 3.f, 3.f, 2.f, 3.f, 3.f, 4.f, + 5.f, 6.f, 5.f, 5.f, 3.f, 4.f, 5.f, 5.f, 4.f, 3.f, 4.f, 3.f, 6.f, 7.f, + 6.f, 4.f, 6.f, 4.f, 3.f, 3.f, 4.f, 3.f, 5.f, 5.f, 4.f, 2.f, 3.f, 4.f, + 5.f, 3.f, 4.f, 2.f, 4.f, 5.f, 3.f, 3.f, 7.f, 4.f, 2.f, 5.f, 6.f, 5.f, + 5.f, 3.f, 1.f, 2.f, 4.f, 4.f, 1.f, 3.f, 6.f, 3.f, 3.f, 1.f, 4.f, 4.f, + 4.f, 5.f, 3.f, 4.f, 3.f, 4.f, 2.f, 3.f, 3.f, 4.f, 3.f, 4.f, 3.f, 3.f, + 4.f, 2.f, 5.f, 1.f, 3.f, 4.f, 2.f, 6.f, 4.f, 3.f, 4.f, 3.f, 3.f, 1.f, + 2.f, 5.f, 2.f, 6.f, 4.f, 5.f, 6.f, 3.f, 6.f, 4.f, 4.f, 5.f, 3.f, 5.f, + 6.f, 3.f, 4.f, 2.f, 4.f, 5.f, 5.f, 5.f, 2.f, 3.f, 4.f, 3.f, 5.f, 3.f, + 3.f, 9.f, 6.f, 7.f, 7.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 4.f, 6.f, + 5.f, 3.f, 5.f, 5.f, 5.f, 2.f, 4.f, 6.f, 7.f, 7.f, 5.f, 3.f, 4.f, 5.f, + 4.f, 4.f, 5.f, 5.f, 5.f, 8.f, 4.f, 4.f, 4.f, 3.f, 5.f, 3.f, 3.f, 4.f, + 4.f, 5.f, 3.f, 3.f, 2.f, 3.f, 6.f, 2.f, 5.f, 4.f, 4.f, 3.f, 3.f, 3.f, + 5.f, 7.f, 2.f, 3.f, 2.f, 5.f, 5.f, 4.f, 4.f, 2.f, 2.f, 1.f, 6.f, 1.f, + 2.f, 2.f, 3.f, 5.f, 4.f, 3.f, 5.f, 5.f, 3.f, 2.f, 2.f, 2.f, 2.f, 4.f, + 3.f, 4.f, 4.f, 4.f, 4.f, 5.f, 2.f, 4.f, 4.f, 5.f, 2.f, 4.f, 4.f, 5.f, + 9.f, 4.f, 5.f, 4.f, 3.f, 5.f, 5.f, 6.f, 4.f, 4.f, 3.f, 3.f, 6.f, 2.f, + 3.f, 2.f, 5.f, 6.f, 4.f, 4.f, 3.f, 5.f, 6.f, 4.f, 5.f, 5.f, 6.f, 7.f, + 4.f, 2.f, 3.f, 5.f, 4.f, 4.f, 3.f, 5.f, 5.f, 4.f, 3.f, 4.f, 5.f, 4.f, + 6.f, 3.f, 4.f, 4.f, 5.f, 6.f, 6.f, 4.f, 6.f, 6.f, 6.f, 5.f, 6.f, 6.f, + 7.f, 7.f, 4.f, 3.f, 4.f, 4.f, 4.f, 5.f, 2.f, 5.f, 7.f, 5.f, 2.f, 1.f, + 5.f, 5.f, 4.f, 1.f, 4.f, 1.f, 3.f, 3.f, 5.f, 4.f, 4.f, 3.f, 7.f, 3.f, + 6.f, 3.f, 3.f, 4.f, 1.f, 3.f, 2.f, 3.f, 3.f, 4.f, 3.f, 1.f, 3.f, 4.f, + 2.f, 4.f, 4.f, 2.f, 6.f, 1.f, 2.f, 2.f, 2.f, 3.f, 2.f, 3.f, 3.f, 4.f, + 4.f, 4.f, 2.f, 4.f, 4.f, 4.f, 5.f, 5.f, 5.f, 4.f, 8.f, 5.f, 5.f, 3.f, + 5.f, 3.f, 3.f, 2.f, 4.f, 3.f, 5.f, 6.f, 5.f, 3.f, 4.f, 5.f, 5.f, 3.f, + 4.f, 3.f, 4.f, 8.f, 6.f, 5.f, 9.f, 6.f}); + + auto graph = + GraphExecutioner::importFromFlatBuffers("./resources/tensor_dot_misc.fb"); + // graph->printOut(); + + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), result); + + ASSERT_TRUE(graph->variableSpace()->hasVariable(77)); + + auto z = graph->variableSpace()->getVariable(77, 0)->getNDArray(); + + ASSERT_EQ(e, *z); + + delete graph; } - TEST_F(FlatBuffersTest, Test_MNIST_00_1) { - auto e = NDArrayFactory::create('c', {100, 10}, {0.00066107f, 0.00002358f, 0.00031518f, 0.00238039f, 0.00027216f, 0.00030300f, 0.00004659f, 0.98962247f, 0.00050380f, 0.00587174f, 0.05895791f, 0.00323104f, 0.52636790f, 0.12912551f, 0.00003951f, 0.03615341f, 0.22013727f, 0.00007333f, 0.02566659f, 0.00024759f, 0.00192367f, 0.90509874f, 0.01985082f, 0.02080356f, 0.00260053f, 0.00497826f, 0.01107823f, 0.00872595f, 0.01559795f, 0.00934229f, 0.98202229f, 0.00000150f, 0.00137381f, 0.00082931f, 0.00001806f, 0.00384426f, 0.00758274f, 0.00305049f, 0.00052152f, 0.00075617f, 0.01094264f, 0.00044708f, 0.03576852f, 0.00711267f, 0.65963465f, 0.00734364f, 0.02747800f, 0.06494589f, 0.02966754f, 0.15665947f, 0.00035806f, 0.95196360f, 0.00622721f, 0.01610696f, 0.00084180f, 0.00139947f, 0.00127350f, 0.00577912f, 0.00980321f, 0.00624705f, 0.00167418f, 0.00125611f, 0.00109477f, 0.04061511f, 0.57403159f, 0.08173440f, 0.00423709f, 0.10187119f, 0.07103974f, 0.12244581f, 0.00073566f, 0.00624759f, 0.00559816f, 0.01215601f, 0.08299568f, 0.06209232f, 0.01742392f, 0.01341172f, 0.02181461f, 0.77752429f, 0.08474547f, 0.00957346f, 0.29235491f, 0.00243696f, 0.06653537f, 0.03792902f, 0.43910959f, 0.00344940f, 0.02626713f, 0.03759870f, 0.00143713f, 0.00011047f, 0.00018820f, 0.00047970f, 0.02127167f, 0.00308758f, 0.00093357f, 0.17067374f, 0.00545499f, 0.79636300f, 0.95257199f, 0.00002157f, 0.00647615f, 0.01024892f, 0.00005942f, 0.01910058f, 0.00044579f, 0.00008416f, 0.01097712f, 0.00001441f, 0.16705236f, 0.01782482f, 0.17580827f, 0.06262068f, 0.03860324f, 0.01763505f, 0.32766294f, 0.00555595f, 0.17227779f, 0.01495883f, 0.00180449f, 0.00010494f, 0.00075124f, 0.00161161f, 0.08859238f, 0.00364861f, 0.00162414f, 0.06005199f, 0.00805061f, 0.83375996f, 0.97355360f, 0.00000305f, 0.00144336f, 0.00051544f, 0.00010043f, 0.00714774f, 0.00021183f, 0.00042562f, 0.01294680f, 0.00365222f, 0.00026871f, 0.95752406f, 0.00408361f, 0.02153200f, 0.00015639f, 0.00153930f, 0.00323335f, 0.00178700f, 0.00516464f, 0.00471107f, 0.07408376f, 0.00468759f, 0.02638813f, 0.33325842f, 0.01172767f, 0.36993489f, 0.01118315f, 0.01460529f, 0.14850292f, 0.00562817f, 0.00551083f, 0.00015134f, 0.01184739f, 0.00643833f, 0.11686873f, 0.00163741f, 0.00582776f, 0.11497385f, 0.02010887f, 0.71663547f, 0.00154932f, 0.00001290f, 0.00023825f, 0.01393047f, 0.00012438f, 0.00033184f, 0.00010033f, 0.98197538f, 0.00022847f, 0.00150876f, 0.00597587f, 0.00819661f, 0.03041674f, 0.43121871f, 0.00986523f, 0.13834484f, 0.29576671f, 0.01305170f, 0.03919542f, 0.02796829f, 0.00139392f, 0.00031466f, 0.00229704f, 0.00647669f, 0.86193180f, 0.01064646f, 0.00494287f, 0.00901443f, 0.00526376f, 0.09771839f, 0.00184158f, 0.00040986f, 0.00008309f, 0.01634205f, 0.01102151f, 0.01133229f, 0.00011603f, 0.30489817f, 0.00813993f, 0.64581543f, 0.00132390f, 0.00009014f, 0.00471620f, 0.00419161f, 0.01024686f, 0.02504917f, 0.94500881f, 0.00010234f, 0.00620976f, 0.00306121f, 0.00971363f, 0.05415262f, 0.05265132f, 0.01217585f, 0.16251956f, 0.00188165f, 0.61800343f, 0.04541704f, 0.01950107f, 0.02398386f, 0.05354780f, 0.00129718f, 0.00762409f, 0.06902183f, 0.01746517f, 0.71758413f, 0.04491642f, 0.00194128f, 0.07204670f, 0.01455537f, 0.00356139f, 0.00223315f, 0.01881612f, 0.01844147f, 0.65686893f, 0.01172961f, 0.01321550f, 0.06555344f, 0.00993031f, 0.19965005f, 0.99641657f, 0.00000005f, 0.00027076f, 0.00000523f, 0.00001288f, 0.00173779f, 0.00140848f, 0.00001787f, 0.00012701f, 0.00000342f, 0.00364264f, 0.00040242f, 0.00199880f, 0.01658181f, 0.00522031f, 0.00494563f, 0.00134627f, 0.87392259f, 0.00277323f, 0.08916643f, 0.00200165f, 0.00006030f, 0.00265544f, 0.00137030f, 0.85328883f, 0.00988892f, 0.00416652f, 0.00394441f, 0.00617034f, 0.11645336f, 0.97291315f, 0.00000182f, 0.00194084f, 0.01498440f, 0.00001028f, 0.00389095f, 0.00023297f, 0.00044887f, 0.00528154f, 0.00029516f, 0.00188889f, 0.79829764f, 0.01104437f, 0.04222726f, 0.00522182f, 0.04550264f, 0.03192228f, 0.01099020f, 0.04107348f, 0.01183154f, 0.00058263f, 0.00048307f, 0.00013920f, 0.96885711f, 0.00005209f, 0.01755359f, 0.00061751f, 0.00787173f, 0.00087605f, 0.00296709f, 0.00342248f, 0.68736714f, 0.01477064f, 0.11038199f, 0.00979373f, 0.03290173f, 0.02064420f, 0.03154078f, 0.03068676f, 0.05849051f, 0.00054699f, 0.00028973f, 0.00066918f, 0.79915440f, 0.00078404f, 0.18881910f, 0.00078736f, 0.00024780f, 0.00598373f, 0.00271761f, 0.37178108f, 0.00029151f, 0.11573081f, 0.00016159f, 0.08614764f, 0.05626433f, 0.33961067f, 0.00184490f, 0.01931754f, 0.00884999f, 0.00103338f, 0.00105793f, 0.01583840f, 0.01417849f, 0.00086645f, 0.00075313f, 0.00009471f, 0.92975640f, 0.00786521f, 0.02855594f, 0.00831110f, 0.00041050f, 0.95547730f, 0.01004958f, 0.00024040f, 0.00674337f, 0.01100292f, 0.00229303f, 0.00543977f, 0.00003204f, 0.00073861f, 0.00003656f, 0.00233217f, 0.00864751f, 0.00044351f, 0.00055325f, 0.00046273f, 0.97456056f, 0.00097461f, 0.01125053f, 0.00035382f, 0.94428235f, 0.00286066f, 0.01286138f, 0.00111129f, 0.00731637f, 0.00518610f, 0.00538214f, 0.01197775f, 0.00866815f, 0.06013579f, 0.03228600f, 0.20441757f, 0.54548728f, 0.00006484f, 0.02362618f, 0.05482962f, 0.00106437f, 0.07713205f, 0.00095635f, 0.00029120f, 0.94839782f, 0.00271641f, 0.02038633f, 0.00010249f, 0.00270848f, 0.00299053f, 0.00069419f, 0.01599395f, 0.00571855f, 0.00580072f, 0.81594771f, 0.03097420f, 0.03646614f, 0.00565077f, 0.01715674f, 0.02362122f, 0.01730293f, 0.02312471f, 0.02395495f, 0.00083797f, 0.00032276f, 0.00475549f, 0.00577861f, 0.00193654f, 0.00201117f, 0.00095864f, 0.89032167f, 0.00238766f, 0.09068950f, 0.00007685f, 0.00309113f, 0.00165920f, 0.00566203f, 0.79406202f, 0.00106585f, 0.00073159f, 0.02779965f, 0.01331810f, 0.15253356f, 0.01362522f, 0.17258310f, 0.57671696f, 0.04606603f, 0.02204953f, 0.00909986f, 0.04971812f, 0.00135137f, 0.09417208f, 0.01461779f, 0.00351132f, 0.01659229f, 0.02209206f, 0.77456558f, 0.00303461f, 0.07932901f, 0.06269170f, 0.01151956f, 0.01363456f, 0.01302921f, 0.04056359f, 0.00052574f, 0.00214679f, 0.41835260f, 0.00373941f, 0.47472891f, 0.00819933f, 0.00047488f, 0.04602791f, 0.00524084f, 0.00085833f, 0.19585223f, 0.03986045f, 0.44138056f, 0.01866945f, 0.11297230f, 0.03688592f, 0.03147812f, 0.04306961f, 0.07897298f, 0.00580970f, 0.00654101f, 0.80165571f, 0.01388136f, 0.04366852f, 0.00407737f, 0.07712067f, 0.01289223f, 0.01437380f, 0.01997955f, 0.00013239f, 0.00000585f, 0.00003676f, 0.00288744f, 0.76327205f, 0.00911173f, 0.00025323f, 0.00345270f, 0.00977252f, 0.21107534f, 0.00238540f, 0.00011487f, 0.01707160f, 0.00274678f, 0.85196322f, 0.00066304f, 0.01279381f, 0.02112481f, 0.00446795f, 0.08666852f, 0.01046857f, 0.00011744f, 0.00377885f, 0.00806424f, 0.00110093f, 0.01087467f, 0.96216726f, 0.00024677f, 0.00213707f, 0.00104427f, 0.00835356f, 0.00037980f, 0.00540865f, 0.91882282f, 0.00084274f, 0.03935680f, 0.00700863f, 0.00609934f, 0.00307425f, 0.01065346f, 0.09310398f, 0.00066428f, 0.00076882f, 0.02210450f, 0.04447530f, 0.77650899f, 0.00945148f, 0.00689890f, 0.00886871f, 0.03715509f, 0.07214937f, 0.00624633f, 0.01399398f, 0.29444799f, 0.03825752f, 0.36904955f, 0.02109544f, 0.01373637f, 0.14653027f, 0.02449317f, 0.01878268f, 0.01089148f, 0.36442387f, 0.01426089f, 0.02649262f, 0.00308395f, 0.51123023f, 0.00987128f, 0.02856500f, 0.01239803f, 0.65732223f, 0.00001665f, 0.00257388f, 0.02261361f, 0.00056261f, 0.08028404f, 0.00753943f, 0.00092872f, 0.22300763f, 0.00515121f, 0.00238470f, 0.00001802f, 0.00303019f, 0.00282769f, 0.93392336f, 0.00829813f, 0.00937593f, 0.00232166f, 0.00606702f, 0.03175319f, 0.00192149f, 0.89188498f, 0.01474108f, 0.03585867f, 0.00123343f, 0.00441551f, 0.00399710f, 0.00857630f, 0.01781271f, 0.01955875f, 0.00221238f, 0.00005268f, 0.00038176f, 0.00141851f, 0.07513693f, 0.00153898f, 0.00254140f, 0.04116146f, 0.00216117f, 0.87339473f, 0.17824675f, 0.04543359f, 0.01501061f, 0.03382575f, 0.09682461f, 0.29989448f, 0.02655865f, 0.16809541f, 0.09566309f, 0.04044705f, 0.00052125f, 0.00006512f, 0.00041621f, 0.03254773f, 0.00120942f, 0.00177929f, 0.00091721f, 0.95285058f, 0.00068729f, 0.00900588f, 0.04185560f, 0.00125587f, 0.33473280f, 0.00119652f, 0.00552071f, 0.03358750f, 0.04974457f, 0.00243473f, 0.41644078f, 0.11323092f, 0.00945223f, 0.00509389f, 0.04602458f, 0.02943204f, 0.23871920f, 0.06141117f, 0.05274383f, 0.03511769f, 0.09954999f, 0.42245534f, 0.00686926f, 0.01075546f, 0.49830484f, 0.37111449f, 0.00928881f, 0.00910977f, 0.00822666f, 0.00448587f, 0.04094843f, 0.04089646f, 0.00190534f, 0.00074783f, 0.02465805f, 0.02045769f, 0.02690129f, 0.00249506f, 0.00202899f, 0.84847659f, 0.01121813f, 0.06111111f, 0.00527403f, 0.00617689f, 0.00719898f, 0.17549324f, 0.25461593f, 0.15036304f, 0.04163047f, 0.01647436f, 0.08906800f, 0.25370511f, 0.10200825f, 0.03916828f, 0.22575049f, 0.08762794f, 0.06703069f, 0.01087492f, 0.27197123f, 0.15926389f, 0.02289790f, 0.01340644f, 0.00233572f, 0.00071111f, 0.01389953f, 0.00187034f, 0.89338356f, 0.00067592f, 0.00535080f, 0.02598928f, 0.01003115f, 0.04575264f, 0.00010197f, 0.00006095f, 0.00021980f, 0.99164659f, 0.00011408f, 0.00474983f, 0.00004892f, 0.00012496f, 0.00257160f, 0.00036128f, 0.91125363f, 0.00012225f, 0.02511939f, 0.00156989f, 0.00002669f, 0.03335980f, 0.01791442f, 0.00531134f, 0.00345027f, 0.00187230f, 0.00210833f, 0.00001888f, 0.00016036f, 0.00394190f, 0.00016232f, 0.00026980f, 0.00012382f, 0.99098623f, 0.00036967f, 0.00185874f, 0.99578768f, 0.00000018f, 0.00162244f, 0.00012927f, 0.00000136f, 0.00158810f, 0.00016544f, 0.00000476f, 0.00069853f, 0.00000226f, 0.19834445f, 0.00044551f, 0.40857196f, 0.34896207f, 0.00023418f, 0.00828141f, 0.02426279f, 0.00148875f, 0.00938030f, 0.00002860f, 0.00201644f, 0.06109568f, 0.01542680f, 0.05984236f, 0.00112191f, 0.00419699f, 0.00110061f, 0.28937989f, 0.13231210f, 0.43350723f, 0.00055382f, 0.92216444f, 0.00396460f, 0.01456171f, 0.00061405f, 0.00972675f, 0.00677260f, 0.00454273f, 0.02471014f, 0.01238921f, 0.00027888f, 0.02572848f, 0.00290584f, 0.00748292f, 0.08441166f, 0.00232722f, 0.00188305f, 0.81133318f, 0.01191756f, 0.05173124f, 0.00315098f, 0.00499059f, 0.00158580f, 0.92859417f, 0.00035086f, 0.04807130f, 0.00101955f, 0.00034313f, 0.01119398f, 0.00069962f, 0.00112821f, 0.00214349f, 0.03968662f, 0.00325992f, 0.00253143f, 0.00199443f, 0.00964058f, 0.90529889f, 0.00384289f, 0.03047365f, 0.00174196f, 0.06674320f, 0.00283191f, 0.09274873f, 0.01944309f, 0.03424436f, 0.00694406f, 0.07912937f, 0.15087396f, 0.54529935f, 0.00007096f, 0.00001000f, 0.00001498f, 0.00007066f, 0.00002792f, 0.00005677f, 0.00000490f, 0.99606401f, 0.00030978f, 0.00337013f, 0.00286575f, 0.00011636f, 0.00064778f, 0.00992065f, 0.04501861f, 0.03149971f, 0.00287679f, 0.37334359f, 0.00214695f, 0.53156382f, 0.00600238f, 0.00003215f, 0.02112119f, 0.00084685f, 0.00497269f, 0.00753993f, 0.95174772f, 0.00150877f, 0.00212018f, 0.00410815f, 0.00006566f, 0.00001179f, 0.99827027f, 0.00028396f, 0.00004237f, 0.00000550f, 0.00091406f, 0.00003423f, 0.00036640f, 0.00000567f, 0.00079063f, 0.00006855f, 0.00051338f, 0.00590454f, 0.00732460f, 0.00195139f, 0.00034534f, 0.90222436f, 0.00163695f, 0.07924022f, 0.00362202f, 0.01493629f, 0.01135249f, 0.00781013f, 0.05138498f, 0.22704794f, 0.00442778f, 0.00350683f, 0.59828150f, 0.07762999f, 0.00016529f, 0.00001219f, 0.00006521f, 0.00446292f, 0.94456083f, 0.00407963f, 0.00102245f, 0.00057420f, 0.00344479f, 0.04161252f, 0.00000981f, 0.00030270f, 0.00017082f, 0.00029943f, 0.00010159f, 0.00003605f, 0.00001875f, 0.99310946f, 0.00063157f, 0.00531995f, 0.01100852f, 0.00021492f, 0.00049603f, 0.59714299f, 0.00454595f, 0.33691072f, 0.03074775f, 0.00427598f, 0.00512297f, 0.00953417f, 0.00064403f, 0.00001687f, 0.00822414f, 0.00012918f, 0.02522905f, 0.00046274f, 0.95950085f, 0.00174588f, 0.00070707f, 0.00334025f, 0.00014754f, 0.96842438f, 0.00752080f, 0.00713038f, 0.00074491f, 0.00107368f, 0.00245372f, 0.00181830f, 0.00883226f, 0.00185409f, 0.00210863f, 0.00017522f, 0.00039881f, 0.98836052f, 0.00003650f, 0.00535216f, 0.00001887f, 0.00069545f, 0.00265663f, 0.00019714f, 0.00028919f, 0.00026057f, 0.00356666f, 0.00034738f, 0.00413719f, 0.00133701f, 0.98608136f, 0.00009625f, 0.00153734f, 0.00234698f, 0.01427079f, 0.04020482f, 0.04733688f, 0.03817881f, 0.16299380f, 0.04943828f, 0.03522370f, 0.05902825f, 0.23904003f, 0.31428465f, 0.00029359f, 0.00005619f, 0.00007707f, 0.98437482f, 0.00000957f, 0.00828004f, 0.00002787f, 0.00510217f, 0.00087425f, 0.00090444f, 0.00011413f, 0.83918202f, 0.01017746f, 0.03100164f, 0.00308035f, 0.01615586f, 0.02608237f, 0.00337026f, 0.05493741f, 0.01589854f, 0.00053240f, 0.00144792f, 0.00108170f, 0.00027300f, 0.86477506f, 0.00072790f, 0.01062538f, 0.00428096f, 0.00233054f, 0.11392505f, 0.00411633f, 0.33660546f, 0.01735369f, 0.18114267f, 0.03090077f, 0.11699959f, 0.03416851f, 0.06780743f, 0.07481573f, 0.13608985f, 0.00073468f, 0.20941530f, 0.01012138f, 0.17237675f, 0.01661461f, 0.02184150f, 0.03694551f, 0.30870155f, 0.04255475f, 0.18069389f, 0.06343270f, 0.00037455f, 0.06623310f, 0.00041474f, 0.00209181f, 0.04566626f, 0.81232506f, 0.00054500f, 0.00807252f, 0.00084416f, 0.00008067f, 0.00003926f, 0.00225794f, 0.00115743f, 0.01925980f, 0.00010427f, 0.00062067f, 0.02234522f, 0.00210706f, 0.95202768f}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/mnist_00.fb"); - //graph->printOut(); - - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), result); - - ASSERT_TRUE(graph->variableSpace()->hasVariable(6)); - - auto z = graph->variableSpace()->getVariable(6,0)->getNDArray(); - - ASSERT_EQ(e, *z); - - delete graph; + auto e = NDArrayFactory::create( + 'c', {100, 10}, + {0.00066107f, 0.00002358f, 0.00031518f, 0.00238039f, 0.00027216f, + 0.00030300f, 0.00004659f, 0.98962247f, 0.00050380f, 0.00587174f, + 0.05895791f, 0.00323104f, 0.52636790f, 0.12912551f, 0.00003951f, + 0.03615341f, 0.22013727f, 0.00007333f, 0.02566659f, 0.00024759f, + 0.00192367f, 0.90509874f, 0.01985082f, 0.02080356f, 0.00260053f, + 0.00497826f, 0.01107823f, 0.00872595f, 0.01559795f, 0.00934229f, + 0.98202229f, 0.00000150f, 0.00137381f, 0.00082931f, 0.00001806f, + 0.00384426f, 0.00758274f, 0.00305049f, 0.00052152f, 0.00075617f, + 0.01094264f, 0.00044708f, 0.03576852f, 0.00711267f, 0.65963465f, + 0.00734364f, 0.02747800f, 0.06494589f, 0.02966754f, 0.15665947f, + 0.00035806f, 0.95196360f, 0.00622721f, 0.01610696f, 0.00084180f, + 0.00139947f, 0.00127350f, 0.00577912f, 0.00980321f, 0.00624705f, + 0.00167418f, 0.00125611f, 0.00109477f, 0.04061511f, 0.57403159f, + 0.08173440f, 0.00423709f, 0.10187119f, 0.07103974f, 0.12244581f, + 0.00073566f, 0.00624759f, 0.00559816f, 0.01215601f, 0.08299568f, + 0.06209232f, 0.01742392f, 0.01341172f, 0.02181461f, 0.77752429f, + 0.08474547f, 0.00957346f, 0.29235491f, 0.00243696f, 0.06653537f, + 0.03792902f, 0.43910959f, 0.00344940f, 0.02626713f, 0.03759870f, + 0.00143713f, 0.00011047f, 0.00018820f, 0.00047970f, 0.02127167f, + 0.00308758f, 0.00093357f, 0.17067374f, 0.00545499f, 0.79636300f, + 0.95257199f, 0.00002157f, 0.00647615f, 0.01024892f, 0.00005942f, + 0.01910058f, 0.00044579f, 0.00008416f, 0.01097712f, 0.00001441f, + 0.16705236f, 0.01782482f, 0.17580827f, 0.06262068f, 0.03860324f, + 0.01763505f, 0.32766294f, 0.00555595f, 0.17227779f, 0.01495883f, + 0.00180449f, 0.00010494f, 0.00075124f, 0.00161161f, 0.08859238f, + 0.00364861f, 0.00162414f, 0.06005199f, 0.00805061f, 0.83375996f, + 0.97355360f, 0.00000305f, 0.00144336f, 0.00051544f, 0.00010043f, + 0.00714774f, 0.00021183f, 0.00042562f, 0.01294680f, 0.00365222f, + 0.00026871f, 0.95752406f, 0.00408361f, 0.02153200f, 0.00015639f, + 0.00153930f, 0.00323335f, 0.00178700f, 0.00516464f, 0.00471107f, + 0.07408376f, 0.00468759f, 0.02638813f, 0.33325842f, 0.01172767f, + 0.36993489f, 0.01118315f, 0.01460529f, 0.14850292f, 0.00562817f, + 0.00551083f, 0.00015134f, 0.01184739f, 0.00643833f, 0.11686873f, + 0.00163741f, 0.00582776f, 0.11497385f, 0.02010887f, 0.71663547f, + 0.00154932f, 0.00001290f, 0.00023825f, 0.01393047f, 0.00012438f, + 0.00033184f, 0.00010033f, 0.98197538f, 0.00022847f, 0.00150876f, + 0.00597587f, 0.00819661f, 0.03041674f, 0.43121871f, 0.00986523f, + 0.13834484f, 0.29576671f, 0.01305170f, 0.03919542f, 0.02796829f, + 0.00139392f, 0.00031466f, 0.00229704f, 0.00647669f, 0.86193180f, + 0.01064646f, 0.00494287f, 0.00901443f, 0.00526376f, 0.09771839f, + 0.00184158f, 0.00040986f, 0.00008309f, 0.01634205f, 0.01102151f, + 0.01133229f, 0.00011603f, 0.30489817f, 0.00813993f, 0.64581543f, + 0.00132390f, 0.00009014f, 0.00471620f, 0.00419161f, 0.01024686f, + 0.02504917f, 0.94500881f, 0.00010234f, 0.00620976f, 0.00306121f, + 0.00971363f, 0.05415262f, 0.05265132f, 0.01217585f, 0.16251956f, + 0.00188165f, 0.61800343f, 0.04541704f, 0.01950107f, 0.02398386f, + 0.05354780f, 0.00129718f, 0.00762409f, 0.06902183f, 0.01746517f, + 0.71758413f, 0.04491642f, 0.00194128f, 0.07204670f, 0.01455537f, + 0.00356139f, 0.00223315f, 0.01881612f, 0.01844147f, 0.65686893f, + 0.01172961f, 0.01321550f, 0.06555344f, 0.00993031f, 0.19965005f, + 0.99641657f, 0.00000005f, 0.00027076f, 0.00000523f, 0.00001288f, + 0.00173779f, 0.00140848f, 0.00001787f, 0.00012701f, 0.00000342f, + 0.00364264f, 0.00040242f, 0.00199880f, 0.01658181f, 0.00522031f, + 0.00494563f, 0.00134627f, 0.87392259f, 0.00277323f, 0.08916643f, + 0.00200165f, 0.00006030f, 0.00265544f, 0.00137030f, 0.85328883f, + 0.00988892f, 0.00416652f, 0.00394441f, 0.00617034f, 0.11645336f, + 0.97291315f, 0.00000182f, 0.00194084f, 0.01498440f, 0.00001028f, + 0.00389095f, 0.00023297f, 0.00044887f, 0.00528154f, 0.00029516f, + 0.00188889f, 0.79829764f, 0.01104437f, 0.04222726f, 0.00522182f, + 0.04550264f, 0.03192228f, 0.01099020f, 0.04107348f, 0.01183154f, + 0.00058263f, 0.00048307f, 0.00013920f, 0.96885711f, 0.00005209f, + 0.01755359f, 0.00061751f, 0.00787173f, 0.00087605f, 0.00296709f, + 0.00342248f, 0.68736714f, 0.01477064f, 0.11038199f, 0.00979373f, + 0.03290173f, 0.02064420f, 0.03154078f, 0.03068676f, 0.05849051f, + 0.00054699f, 0.00028973f, 0.00066918f, 0.79915440f, 0.00078404f, + 0.18881910f, 0.00078736f, 0.00024780f, 0.00598373f, 0.00271761f, + 0.37178108f, 0.00029151f, 0.11573081f, 0.00016159f, 0.08614764f, + 0.05626433f, 0.33961067f, 0.00184490f, 0.01931754f, 0.00884999f, + 0.00103338f, 0.00105793f, 0.01583840f, 0.01417849f, 0.00086645f, + 0.00075313f, 0.00009471f, 0.92975640f, 0.00786521f, 0.02855594f, + 0.00831110f, 0.00041050f, 0.95547730f, 0.01004958f, 0.00024040f, + 0.00674337f, 0.01100292f, 0.00229303f, 0.00543977f, 0.00003204f, + 0.00073861f, 0.00003656f, 0.00233217f, 0.00864751f, 0.00044351f, + 0.00055325f, 0.00046273f, 0.97456056f, 0.00097461f, 0.01125053f, + 0.00035382f, 0.94428235f, 0.00286066f, 0.01286138f, 0.00111129f, + 0.00731637f, 0.00518610f, 0.00538214f, 0.01197775f, 0.00866815f, + 0.06013579f, 0.03228600f, 0.20441757f, 0.54548728f, 0.00006484f, + 0.02362618f, 0.05482962f, 0.00106437f, 0.07713205f, 0.00095635f, + 0.00029120f, 0.94839782f, 0.00271641f, 0.02038633f, 0.00010249f, + 0.00270848f, 0.00299053f, 0.00069419f, 0.01599395f, 0.00571855f, + 0.00580072f, 0.81594771f, 0.03097420f, 0.03646614f, 0.00565077f, + 0.01715674f, 0.02362122f, 0.01730293f, 0.02312471f, 0.02395495f, + 0.00083797f, 0.00032276f, 0.00475549f, 0.00577861f, 0.00193654f, + 0.00201117f, 0.00095864f, 0.89032167f, 0.00238766f, 0.09068950f, + 0.00007685f, 0.00309113f, 0.00165920f, 0.00566203f, 0.79406202f, + 0.00106585f, 0.00073159f, 0.02779965f, 0.01331810f, 0.15253356f, + 0.01362522f, 0.17258310f, 0.57671696f, 0.04606603f, 0.02204953f, + 0.00909986f, 0.04971812f, 0.00135137f, 0.09417208f, 0.01461779f, + 0.00351132f, 0.01659229f, 0.02209206f, 0.77456558f, 0.00303461f, + 0.07932901f, 0.06269170f, 0.01151956f, 0.01363456f, 0.01302921f, + 0.04056359f, 0.00052574f, 0.00214679f, 0.41835260f, 0.00373941f, + 0.47472891f, 0.00819933f, 0.00047488f, 0.04602791f, 0.00524084f, + 0.00085833f, 0.19585223f, 0.03986045f, 0.44138056f, 0.01866945f, + 0.11297230f, 0.03688592f, 0.03147812f, 0.04306961f, 0.07897298f, + 0.00580970f, 0.00654101f, 0.80165571f, 0.01388136f, 0.04366852f, + 0.00407737f, 0.07712067f, 0.01289223f, 0.01437380f, 0.01997955f, + 0.00013239f, 0.00000585f, 0.00003676f, 0.00288744f, 0.76327205f, + 0.00911173f, 0.00025323f, 0.00345270f, 0.00977252f, 0.21107534f, + 0.00238540f, 0.00011487f, 0.01707160f, 0.00274678f, 0.85196322f, + 0.00066304f, 0.01279381f, 0.02112481f, 0.00446795f, 0.08666852f, + 0.01046857f, 0.00011744f, 0.00377885f, 0.00806424f, 0.00110093f, + 0.01087467f, 0.96216726f, 0.00024677f, 0.00213707f, 0.00104427f, + 0.00835356f, 0.00037980f, 0.00540865f, 0.91882282f, 0.00084274f, + 0.03935680f, 0.00700863f, 0.00609934f, 0.00307425f, 0.01065346f, + 0.09310398f, 0.00066428f, 0.00076882f, 0.02210450f, 0.04447530f, + 0.77650899f, 0.00945148f, 0.00689890f, 0.00886871f, 0.03715509f, + 0.07214937f, 0.00624633f, 0.01399398f, 0.29444799f, 0.03825752f, + 0.36904955f, 0.02109544f, 0.01373637f, 0.14653027f, 0.02449317f, + 0.01878268f, 0.01089148f, 0.36442387f, 0.01426089f, 0.02649262f, + 0.00308395f, 0.51123023f, 0.00987128f, 0.02856500f, 0.01239803f, + 0.65732223f, 0.00001665f, 0.00257388f, 0.02261361f, 0.00056261f, + 0.08028404f, 0.00753943f, 0.00092872f, 0.22300763f, 0.00515121f, + 0.00238470f, 0.00001802f, 0.00303019f, 0.00282769f, 0.93392336f, + 0.00829813f, 0.00937593f, 0.00232166f, 0.00606702f, 0.03175319f, + 0.00192149f, 0.89188498f, 0.01474108f, 0.03585867f, 0.00123343f, + 0.00441551f, 0.00399710f, 0.00857630f, 0.01781271f, 0.01955875f, + 0.00221238f, 0.00005268f, 0.00038176f, 0.00141851f, 0.07513693f, + 0.00153898f, 0.00254140f, 0.04116146f, 0.00216117f, 0.87339473f, + 0.17824675f, 0.04543359f, 0.01501061f, 0.03382575f, 0.09682461f, + 0.29989448f, 0.02655865f, 0.16809541f, 0.09566309f, 0.04044705f, + 0.00052125f, 0.00006512f, 0.00041621f, 0.03254773f, 0.00120942f, + 0.00177929f, 0.00091721f, 0.95285058f, 0.00068729f, 0.00900588f, + 0.04185560f, 0.00125587f, 0.33473280f, 0.00119652f, 0.00552071f, + 0.03358750f, 0.04974457f, 0.00243473f, 0.41644078f, 0.11323092f, + 0.00945223f, 0.00509389f, 0.04602458f, 0.02943204f, 0.23871920f, + 0.06141117f, 0.05274383f, 0.03511769f, 0.09954999f, 0.42245534f, + 0.00686926f, 0.01075546f, 0.49830484f, 0.37111449f, 0.00928881f, + 0.00910977f, 0.00822666f, 0.00448587f, 0.04094843f, 0.04089646f, + 0.00190534f, 0.00074783f, 0.02465805f, 0.02045769f, 0.02690129f, + 0.00249506f, 0.00202899f, 0.84847659f, 0.01121813f, 0.06111111f, + 0.00527403f, 0.00617689f, 0.00719898f, 0.17549324f, 0.25461593f, + 0.15036304f, 0.04163047f, 0.01647436f, 0.08906800f, 0.25370511f, + 0.10200825f, 0.03916828f, 0.22575049f, 0.08762794f, 0.06703069f, + 0.01087492f, 0.27197123f, 0.15926389f, 0.02289790f, 0.01340644f, + 0.00233572f, 0.00071111f, 0.01389953f, 0.00187034f, 0.89338356f, + 0.00067592f, 0.00535080f, 0.02598928f, 0.01003115f, 0.04575264f, + 0.00010197f, 0.00006095f, 0.00021980f, 0.99164659f, 0.00011408f, + 0.00474983f, 0.00004892f, 0.00012496f, 0.00257160f, 0.00036128f, + 0.91125363f, 0.00012225f, 0.02511939f, 0.00156989f, 0.00002669f, + 0.03335980f, 0.01791442f, 0.00531134f, 0.00345027f, 0.00187230f, + 0.00210833f, 0.00001888f, 0.00016036f, 0.00394190f, 0.00016232f, + 0.00026980f, 0.00012382f, 0.99098623f, 0.00036967f, 0.00185874f, + 0.99578768f, 0.00000018f, 0.00162244f, 0.00012927f, 0.00000136f, + 0.00158810f, 0.00016544f, 0.00000476f, 0.00069853f, 0.00000226f, + 0.19834445f, 0.00044551f, 0.40857196f, 0.34896207f, 0.00023418f, + 0.00828141f, 0.02426279f, 0.00148875f, 0.00938030f, 0.00002860f, + 0.00201644f, 0.06109568f, 0.01542680f, 0.05984236f, 0.00112191f, + 0.00419699f, 0.00110061f, 0.28937989f, 0.13231210f, 0.43350723f, + 0.00055382f, 0.92216444f, 0.00396460f, 0.01456171f, 0.00061405f, + 0.00972675f, 0.00677260f, 0.00454273f, 0.02471014f, 0.01238921f, + 0.00027888f, 0.02572848f, 0.00290584f, 0.00748292f, 0.08441166f, + 0.00232722f, 0.00188305f, 0.81133318f, 0.01191756f, 0.05173124f, + 0.00315098f, 0.00499059f, 0.00158580f, 0.92859417f, 0.00035086f, + 0.04807130f, 0.00101955f, 0.00034313f, 0.01119398f, 0.00069962f, + 0.00112821f, 0.00214349f, 0.03968662f, 0.00325992f, 0.00253143f, + 0.00199443f, 0.00964058f, 0.90529889f, 0.00384289f, 0.03047365f, + 0.00174196f, 0.06674320f, 0.00283191f, 0.09274873f, 0.01944309f, + 0.03424436f, 0.00694406f, 0.07912937f, 0.15087396f, 0.54529935f, + 0.00007096f, 0.00001000f, 0.00001498f, 0.00007066f, 0.00002792f, + 0.00005677f, 0.00000490f, 0.99606401f, 0.00030978f, 0.00337013f, + 0.00286575f, 0.00011636f, 0.00064778f, 0.00992065f, 0.04501861f, + 0.03149971f, 0.00287679f, 0.37334359f, 0.00214695f, 0.53156382f, + 0.00600238f, 0.00003215f, 0.02112119f, 0.00084685f, 0.00497269f, + 0.00753993f, 0.95174772f, 0.00150877f, 0.00212018f, 0.00410815f, + 0.00006566f, 0.00001179f, 0.99827027f, 0.00028396f, 0.00004237f, + 0.00000550f, 0.00091406f, 0.00003423f, 0.00036640f, 0.00000567f, + 0.00079063f, 0.00006855f, 0.00051338f, 0.00590454f, 0.00732460f, + 0.00195139f, 0.00034534f, 0.90222436f, 0.00163695f, 0.07924022f, + 0.00362202f, 0.01493629f, 0.01135249f, 0.00781013f, 0.05138498f, + 0.22704794f, 0.00442778f, 0.00350683f, 0.59828150f, 0.07762999f, + 0.00016529f, 0.00001219f, 0.00006521f, 0.00446292f, 0.94456083f, + 0.00407963f, 0.00102245f, 0.00057420f, 0.00344479f, 0.04161252f, + 0.00000981f, 0.00030270f, 0.00017082f, 0.00029943f, 0.00010159f, + 0.00003605f, 0.00001875f, 0.99310946f, 0.00063157f, 0.00531995f, + 0.01100852f, 0.00021492f, 0.00049603f, 0.59714299f, 0.00454595f, + 0.33691072f, 0.03074775f, 0.00427598f, 0.00512297f, 0.00953417f, + 0.00064403f, 0.00001687f, 0.00822414f, 0.00012918f, 0.02522905f, + 0.00046274f, 0.95950085f, 0.00174588f, 0.00070707f, 0.00334025f, + 0.00014754f, 0.96842438f, 0.00752080f, 0.00713038f, 0.00074491f, + 0.00107368f, 0.00245372f, 0.00181830f, 0.00883226f, 0.00185409f, + 0.00210863f, 0.00017522f, 0.00039881f, 0.98836052f, 0.00003650f, + 0.00535216f, 0.00001887f, 0.00069545f, 0.00265663f, 0.00019714f, + 0.00028919f, 0.00026057f, 0.00356666f, 0.00034738f, 0.00413719f, + 0.00133701f, 0.98608136f, 0.00009625f, 0.00153734f, 0.00234698f, + 0.01427079f, 0.04020482f, 0.04733688f, 0.03817881f, 0.16299380f, + 0.04943828f, 0.03522370f, 0.05902825f, 0.23904003f, 0.31428465f, + 0.00029359f, 0.00005619f, 0.00007707f, 0.98437482f, 0.00000957f, + 0.00828004f, 0.00002787f, 0.00510217f, 0.00087425f, 0.00090444f, + 0.00011413f, 0.83918202f, 0.01017746f, 0.03100164f, 0.00308035f, + 0.01615586f, 0.02608237f, 0.00337026f, 0.05493741f, 0.01589854f, + 0.00053240f, 0.00144792f, 0.00108170f, 0.00027300f, 0.86477506f, + 0.00072790f, 0.01062538f, 0.00428096f, 0.00233054f, 0.11392505f, + 0.00411633f, 0.33660546f, 0.01735369f, 0.18114267f, 0.03090077f, + 0.11699959f, 0.03416851f, 0.06780743f, 0.07481573f, 0.13608985f, + 0.00073468f, 0.20941530f, 0.01012138f, 0.17237675f, 0.01661461f, + 0.02184150f, 0.03694551f, 0.30870155f, 0.04255475f, 0.18069389f, + 0.06343270f, 0.00037455f, 0.06623310f, 0.00041474f, 0.00209181f, + 0.04566626f, 0.81232506f, 0.00054500f, 0.00807252f, 0.00084416f, + 0.00008067f, 0.00003926f, 0.00225794f, 0.00115743f, 0.01925980f, + 0.00010427f, 0.00062067f, 0.02234522f, 0.00210706f, 0.95202768f}); + auto graph = + GraphExecutioner::importFromFlatBuffers("./resources/mnist_00.fb"); + // graph->printOut(); + + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), result); + + ASSERT_TRUE(graph->variableSpace()->hasVariable(6)); + + auto z = graph->variableSpace()->getVariable(6, 0)->getNDArray(); + + ASSERT_EQ(e, *z); + + delete graph; } - - TEST_F(FlatBuffersTest, Test_MNIST_1) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/mnist.fb"); - //graph->printOut(); + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/mnist.fb"); + // graph->printOut(); - auto result = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), result); + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), result); - delete graph; + delete graph; } /* @@ -765,9 +1073,11 @@ TEST_F(FlatBuffersTest, Test_MNIST_1) { TEST_F(FlatBuffersTest, nhwc_conv_0) { sd::ops::rank op1; - auto exp('c', {4, 2}, {2.958640f, 0.602521f, 7.571267f, 1.496686f, -2.292647f, -1.791460f, 13.055838f, 4.278642f}); + auto exp('c', {4, 2}, {2.958640f, 0.602521f, 7.571267f, 1.496686f, +-2.292647f, -1.791460f, 13.055838f, 4.278642f}); - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/conv_0.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/conv_0.fb"); graph->printOut(); @@ -794,12 +1104,12 @@ TEST_F(FlatBuffersTest, nhwc_conv_0) { */ - /* TEST_F(FlatBuffersTest, ReadLoops_SimpleWhile_1) { // TF graph: // https://gist.github.com/raver119/2aa49daf7ec09ed4ddddbc6262f213a0 - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simple_while.fb"); + auto graph = +GraphExecutioner::importFromFlatBuffers("./resources/simple_while.fb"); ASSERT_TRUE(graph != nullptr); diff --git a/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp b/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp index ae4d654e9007..400ca6cb84a5 100644 --- a/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp @@ -20,73 +20,70 @@ #include #include -#include "testlayers.h" -#include #include +#include + +#include "testlayers.h" using namespace sd; class FlatUtilsTests : public testing::Test { -public: - + public: }; TEST_F(FlatUtilsTests, flat_float_serde_1) { - auto array = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - - flatbuffers::FlatBufferBuilder builder(1024); - auto flatArray = FlatUtils::toFlatArray(builder, array); - builder.Finish(flatArray); + auto array = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); - auto pfArray = GetFlatArray(builder.GetBufferPointer()); + auto pfArray = GetFlatArray(builder.GetBufferPointer()); - auto restored = FlatUtils::fromFlatArray(pfArray); + auto restored = FlatUtils::fromFlatArray(pfArray); - ASSERT_EQ(array, restored); + ASSERT_EQ(array, restored); } TEST_F(FlatUtilsTests, flat_int_serde_1) { - auto array = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - - flatbuffers::FlatBufferBuilder builder(1024); - auto flatArray = FlatUtils::toFlatArray(builder, array); - builder.Finish(flatArray); + auto array = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); - auto pfArray = GetFlatArray(builder.GetBufferPointer()); + auto pfArray = GetFlatArray(builder.GetBufferPointer()); - auto restored = FlatUtils::fromFlatArray(pfArray); + auto restored = FlatUtils::fromFlatArray(pfArray); - ASSERT_EQ(array, restored); + ASSERT_EQ(array, restored); } TEST_F(FlatUtilsTests, flat_bool_serde_1) { - auto array = NDArrayFactory::create('c', {4}, {true, false, true, false}); + auto array = + NDArrayFactory::create('c', {4}, {true, false, true, false}); - flatbuffers::FlatBufferBuilder builder(1024); - auto flatArray = FlatUtils::toFlatArray(builder, array); - builder.Finish(flatArray); + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); + auto pfArray = GetFlatArray(builder.GetBufferPointer()); - auto pfArray = GetFlatArray(builder.GetBufferPointer()); + auto restored = FlatUtils::fromFlatArray(pfArray); - auto restored = FlatUtils::fromFlatArray(pfArray); - - ASSERT_EQ(array, restored); + ASSERT_EQ(array, restored); } TEST_F(FlatUtilsTests, flat_string_serde_1) { - auto array = NDArrayFactory::string( {3}, {"alpha", "beta", "gamma"}); - - flatbuffers::FlatBufferBuilder builder(1024); - auto flatArray = FlatUtils::toFlatArray(builder, array); - builder.Finish(flatArray); + auto array = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"}); + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); - auto pfArray = GetFlatArray(builder.GetBufferPointer()); + auto pfArray = GetFlatArray(builder.GetBufferPointer()); - auto restored = FlatUtils::fromFlatArray(pfArray); + auto restored = FlatUtils::fromFlatArray(pfArray); - ASSERT_EQ(array, restored); + ASSERT_EQ(array, restored); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 71c6a73a56ba..4dd898f94c8f 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -18,1016 +18,1011 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include +#include #include +#include + #include -#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class GraphAnalysisTests : public testing::Test { -public: - GraphAnalysisTests() { - /// - } + public: + GraphAnalysisTests() { + /// + } }; TEST_F(GraphAnalysisTests, basic_toposort_test_1) { - Graph graph; + Graph graph; - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - Node a(sd::ops::multiply(), "multiply"); - Node b(sd::ops::add(), "add"); + Node a(sd::ops::multiply(), "multiply"); + Node b(sd::ops::add(), "add"); - graph.addNode(a, {"A", "B"}); - graph.addNode(b, {"multiply", "C"}); + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"multiply", "C"}); - // we just check that nodes were really added - ASSERT_EQ(2, graph.size()); + // we just check that nodes were really added + ASSERT_EQ(2, graph.size()); - auto optimized = graph.optimizedGraph(); + auto optimized = graph.optimizedGraph(); - // we expect that OptimizedGraph has exactly 1 layer - ASSERT_EQ(1, optimized.layers()); + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(1, optimized.layers()); - auto layer = optimized.layer(0); + auto layer = optimized.layer(0); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer.width()); - auto sequence = layer[0]; + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer.width()); + auto sequence = layer[0]; - // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(2, sequence.length()); + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(2, sequence.length()); - ASSERT_EQ(4, sequence.at(0).protoContext().nodeId()); - ASSERT_EQ(5, sequence.at(1).protoContext().nodeId()); + ASSERT_EQ(4, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(5, sequence.at(1).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_2) { - Graph graph; - - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + Graph graph; - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {4, 4, 4})); + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - Node a(sd::ops::multiply(), "multiply"); - Node b(sd::ops::add(), "add"); - Node c(sd::ops::subtract(), "subtract"); + // D + graph.addVariable("D", NDArrayFactory::create('c', {3}, {4, 4, 4})); + Node a(sd::ops::multiply(), "multiply"); + Node b(sd::ops::add(), "add"); + Node c(sd::ops::subtract(), "subtract"); - graph.addNode(a, {"A", "B"}); - graph.addNode(b, {"multiply", "C"}); - graph.addNode(c, {"multiply", "D"}); + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"multiply", "C"}); + graph.addNode(c, {"multiply", "D"}); - // we just check that nodes were really added - ASSERT_EQ(3, graph.size()); + // we just check that nodes were really added + ASSERT_EQ(3, graph.size()); - auto optimized = graph.optimizedGraph(); + auto optimized = graph.optimizedGraph(); - // graph size must stay the same - ASSERT_EQ(3, graph.size()); + // graph size must stay the same + ASSERT_EQ(3, graph.size()); - // we expect that OptimizedGraph has exactly 2 layers - ASSERT_EQ(2, optimized.layers()); + // we expect that OptimizedGraph has exactly 2 layers + ASSERT_EQ(2, optimized.layers()); - // checking first layer first - auto layer0 = optimized.layer(0); + // checking first layer first + auto layer0 = optimized.layer(0); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer0.width()); - ; + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer0.width()); + ; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer0[0].length()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[0].length()); - ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); + ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); - // checking second layer now - auto layer1 = optimized.layer(1); + // checking second layer now + auto layer1 = optimized.layer(1); - // we expect layer has exactly 2 OpSequences - ASSERT_EQ(2, layer1.width()); + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.width()); + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); - - ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_3) { - Graph graph; - - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + Graph graph; - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); - Node a(sd::ops::multiply(), "a"); - Node b(sd::ops::add(), "b"); - Node c(sd::ops::subtract(), "c"); - Node d(sd::ops::add(), "d"); - Node e(sd::ops::multiply(), "e"); + // D + graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); - graph.addNode(a, {"A", "B"}); - graph.addNode(b, {"a", "C"}); + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::subtract(), "c"); + Node d(sd::ops::add(), "d"); + Node e(sd::ops::multiply(), "e"); - graph.addNode(c, {"b", "D"}); - graph.addNode(d, {"b", "D"}); + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"a", "C"}); - graph.addNode(e, {"c", "d"}); + graph.addNode(c, {"b", "D"}); + graph.addNode(d, {"b", "D"}); - // we just check that nodes were really added - ASSERT_EQ(5, graph.size()); + graph.addNode(e, {"c", "d"}); - auto optimized = graph.optimizedGraph(); + // we just check that nodes were really added + ASSERT_EQ(5, graph.size()); - // we expect that OptimizedGraph has exactly 3 layer - ASSERT_EQ(3, optimized.layers()); + auto optimized = graph.optimizedGraph(); - // checking first layer first - auto layer0 = optimized.layer(0); + // we expect that OptimizedGraph has exactly 3 layer + ASSERT_EQ(3, optimized.layers()); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer0.width()); - //auto sequence = layer0[0]; + // checking first layer first + auto layer0 = optimized.layer(0); - // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(2, layer0[0].length()); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer0.width()); + // auto sequence = layer0[0]; - ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); - ASSERT_EQ(6, layer0[0].at(1).protoContext().nodeId()); + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(2, layer0[0].length()); - // checking second layer now - auto layer1 = optimized.layer(1); + ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); + ASSERT_EQ(6, layer0[0].at(1).protoContext().nodeId()); - // we expect layer has exactly 2 OpSequences - ASSERT_EQ(2, layer1.width()); + // checking second layer now + auto layer1 = optimized.layer(1); - //sequence = layer1[0]; + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.width()); - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); + // sequence = layer1[0]; - //sequence = layer1[1]; + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); - ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); + // sequence = layer1[1]; + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); - // checking last layer - auto layer2 = optimized.layer(2); + // checking last layer + auto layer2 = optimized.layer(2); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer2.width()); - //sequence = layer2[0]; + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer2.width()); + // sequence = layer2[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[0].length()); - ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_4) { - Graph graph; + Graph graph; - // A - graph.addVariable("A", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // C - graph.addVariable("C", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // D - graph.addVariable("D", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + // D + graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // E - graph.addVariable("E", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + // E + graph.addVariable("E", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // F - graph.addVariable("F", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + // F + graph.addVariable("F", NDArrayFactory::create('c', {3}, {1, 1, 1})); + Node a1(sd::ops::multiply(), "a1"); + Node a2(sd::ops::add(), "a2"); - Node a1(sd::ops::multiply(), "a1"); - Node a2(sd::ops::add(), "a2"); + Node b1(sd::ops::subtract(), "b1"); + Node b2(sd::ops::add(), "b2"); + Node b3(sd::ops::multiply(), "b3"); - Node b1(sd::ops::subtract(), "b1"); - Node b2(sd::ops::add(), "b2"); - Node b3(sd::ops::multiply(), "b3"); + Node d1(sd::ops::multiply(), "d1"); + Node d2(sd::ops::add(), "d2"); - Node d1(sd::ops::multiply(), "d1"); - Node d2(sd::ops::add(), "d2"); + Node e(sd::ops::subtract(), "e"); - Node e(sd::ops::subtract(), "e"); + graph.addNode(a1, {"A", "B"}); + graph.addNode(a2, {"C", "D"}); - graph.addNode(a1, { "A", "B" }); - graph.addNode(a2, { "C", "D" }); + graph.addNode(b1, {"a1", "E"}); + graph.addNode(b2, {"a1", "a2"}); + graph.addNode(b3, {"a2", "F"}); - graph.addNode(b1, { "a1", "E" }); - graph.addNode(b2, { "a1", "a2" }); - graph.addNode(b3, { "a2", "F" }); + graph.addNode(d1, {"b1", "b2"}); + graph.addNode(d2, {"b3", "b2"}); - graph.addNode(d1, { "b1", "b2" }); - graph.addNode(d2, { "b3", "b2" }); + graph.addNode(e, {"d1", "d2"}); - graph.addNode(e, { "d1", "d2" }); + // we just check that nodes were really added + ASSERT_EQ(8, graph.size()); - // we just check that nodes were really added - ASSERT_EQ(8, graph.size()); + auto optimized = graph.optimizedGraph(); - auto optimized = graph.optimizedGraph(); + // we expect that OptimizedGraph has exactly 4 layer + ASSERT_EQ(4, optimized.layers()); - // we expect that OptimizedGraph has exactly 4 layer - ASSERT_EQ(4, optimized.layers()); + // checking first layer first + auto layer0 = optimized.layer(0); - // checking first layer first - auto layer0 = optimized.layer(0); + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer0.width()); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(2, layer0.width()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[0].length()); + ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer0[0].length()); - ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[1].length()); + ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer0[1].length()); - ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); + // checking second layer now + auto layer1 = optimized.layer(1); - // checking second layer now - auto layer1 = optimized.layer(1); + // we expect layer has exactly 3 OpSequences + ASSERT_EQ(3, layer1.width()); - // we expect layer has exactly 3 OpSequences - ASSERT_EQ(3, layer1.width()); + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(9, layer1[0].at(0).protoContext().nodeId()); - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(9, layer1[0].at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(10, layer1[1].at(0).protoContext().nodeId()); - ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(10, layer1[1].at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[2].length()); + ASSERT_EQ(11, layer1[2].at(0).protoContext().nodeId()); - ASSERT_EQ(1, layer1[2].length()); - ASSERT_EQ(11, layer1[2].at(0).protoContext().nodeId()); + auto layer2 = optimized.layer(2); - auto layer2 = optimized.layer(2); + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer2.width()); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(2, layer2.width()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(12, layer2[0].at(0).protoContext().nodeId()); - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[0].length()); - ASSERT_EQ(12, layer2[0].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(13, layer2[1].at(0).protoContext().nodeId()); - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[1].length()); - ASSERT_EQ(13, layer2[1].at(0).protoContext().nodeId()); + // checking last layer + auto layer3 = optimized.layer(3); - // checking last layer - auto layer3 = optimized.layer(3); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer3.width()); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer3.width()); - - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer3[0].length()); - ASSERT_EQ(14, layer3[0].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[0].length()); + ASSERT_EQ(14, layer3[0].at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_5) { - Graph graph; + Graph graph; + + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // D + graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::subtract(), "c"); + Node d(sd::ops::add(), "d"); + Node e(sd::ops::multiply(), "e"); + Node f(sd::ops::multiply(), "f"); - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); + Node g(sd::ops::multiply(), "g"); + Node h(sd::ops::multiply(), "h"); - Node a(sd::ops::multiply(), "a"); - Node b(sd::ops::add(), "b"); - Node c(sd::ops::subtract(), "c"); - Node d(sd::ops::add(), "d"); - Node e(sd::ops::multiply(), "e"); - Node f(sd::ops::multiply(), "f"); + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"C", "D"}); - - Node g(sd::ops::multiply(), "g"); - Node h(sd::ops::multiply(), "h"); + graph.addNode(c, {"a", "b"}); + graph.addNode(d, {"a", "b"}); - graph.addNode(a, {"A", "B"}); - graph.addNode(b, {"C", "D"}); + graph.addNode(e, {"c", "d"}); + graph.addNode(f, {"c", "d"}); - graph.addNode(c, {"a", "b"}); - graph.addNode(d, {"a", "b"}); + graph.addNode(g, {"c", "e"}); + graph.addNode(h, {"d", "f"}); - graph.addNode(e, {"c", "d"}); - graph.addNode(f, {"c", "d"}); + // we just check that nodes were really added + ASSERT_EQ(8, graph.size()); - graph.addNode(g, {"c", "e"}); - graph.addNode(h, {"d", "f"}); + auto optimized = graph.optimizedGraph(); - // we just check that nodes were really added - ASSERT_EQ(8, graph.size()); + // we expect that OptimizedGraph has exactly 3 layer + ASSERT_EQ(4, optimized.layers()); - auto optimized = graph.optimizedGraph(); + // checking first layer first + auto layer0 = optimized.layer(0); - // we expect that OptimizedGraph has exactly 3 layer - ASSERT_EQ(4, optimized.layers()); + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer0.width()); + // auto sequence = layer0[0]; - // checking first layer first - auto layer0 = optimized.layer(0); + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, layer0[0].length()); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(2, layer0.width()); - //auto sequence = layer0[0]; + ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); - // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, layer0[0].length()); + // sequence = layer0[1]; - ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); - - //sequence = layer0[1]; + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, layer0[1].length()); + ASSERT_EQ(6, layer0[1].at(0).protoContext().nodeId()); - // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, layer0[1].length()); - ASSERT_EQ(6, layer0[1].at(0).protoContext().nodeId()); + // checking second layer now + auto layer1 = optimized.layer(1); - // checking second layer now - auto layer1 = optimized.layer(1); + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.width()); - // we expect layer has exactly 2 OpSequences - ASSERT_EQ(2, layer1.width()); + // sequence = layer1[0]; - //sequence = layer1[0]; + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); - - //sequence = layer1[1]; + // sequence = layer1[1]; - ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); - // checking before last layer - auto layer2 = optimized.layer(2); + // checking before last layer + auto layer2 = optimized.layer(2); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(2, layer2.width()); - //sequence = layer2[0]; + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer2.width()); + // sequence = layer2[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[0].length()); - ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); - //sequence = layer2[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); + // sequence = layer2[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[1].length()); - ASSERT_EQ(10, layer2[1].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(10, layer2[1].at(0).protoContext().nodeId()); - // checking last layer - auto layer3 = optimized.layer(3); + // checking last layer + auto layer3 = optimized.layer(3); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(2, layer3.width()); - //sequence = layer3[0]; + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer3.width()); + // sequence = layer3[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer3[0].length()); - ASSERT_EQ(11, layer3[0].at(0).protoContext().nodeId()); - - //sequence = layer3[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[0].length()); + ASSERT_EQ(11, layer3[0].at(0).protoContext().nodeId()); - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer3[1].length()); - ASSERT_EQ(12, layer3[1].at(0).protoContext().nodeId()); + // sequence = layer3[1]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[1].length()); + ASSERT_EQ(12, layer3[1].at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_6) { - Graph graph; + Graph graph; + + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // D + graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // E + graph.addVariable("E", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // F + graph.addVariable("F", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // E - graph.addVariable("E", NDArrayFactory::create('c', {3}, {1, 1, 1})); + Node a(sd::ops::multiply(), "a"); + Node b1(sd::ops::add(), "b1"); + Node b2(sd::ops::subtract(), "b2"); - // F - graph.addVariable("F", NDArrayFactory::create('c', {3}, {1, 1, 1})); + Node c1(sd::ops::add(), "c1"); + Node c2(sd::ops::multiply(), "c2"); + Node c3(sd::ops::subtract(), "c3"); - Node a(sd::ops::multiply(), "a"); - Node b1(sd::ops::add(), "b1"); - Node b2(sd::ops::subtract(), "b2"); - - Node c1(sd::ops::add(), "c1"); - Node c2(sd::ops::multiply(), "c2"); - Node c3(sd::ops::subtract(), "c3"); - - Node d1(sd::ops::multiply(), "d1"); - Node d2(sd::ops::multiply(), "d2"); + Node d1(sd::ops::multiply(), "d1"); + Node d2(sd::ops::multiply(), "d2"); - Node e(sd::ops::add(), "e"); + Node e(sd::ops::add(), "e"); - graph.addNode(a, {"A", "B"}); + graph.addNode(a, {"A", "B"}); - graph.addNode(b1, {"a", "C"}); - graph.addNode(b2, {"a", "D"}); + graph.addNode(b1, {"a", "C"}); + graph.addNode(b2, {"a", "D"}); - graph.addNode(c1, {"b1", "E"}); - graph.addNode(c2, {"b1", "b2"}); - graph.addNode(c3, {"b2", "F"}); + graph.addNode(c1, {"b1", "E"}); + graph.addNode(c2, {"b1", "b2"}); + graph.addNode(c3, {"b2", "F"}); - graph.addNode(d1, {"c1", "c2"}); - graph.addNode(d2, {"c2", "c3"}); + graph.addNode(d1, {"c1", "c2"}); + graph.addNode(d2, {"c2", "c3"}); - graph.addNode(e, {"d1", "d2"}); + graph.addNode(e, {"d1", "d2"}); - // we just check that nodes were really added - ASSERT_EQ(9, graph.size()); + // we just check that nodes were really added + ASSERT_EQ(9, graph.size()); - auto optimized = graph.optimizedGraph(); + auto optimized = graph.optimizedGraph(); - // we expect that OptimizedGraph has exactly 3 layer - ASSERT_EQ(5, optimized.layers()); + // we expect that OptimizedGraph has exactly 3 layer + ASSERT_EQ(5, optimized.layers()); - // checking first layer first - auto layer0 = optimized.layer(0); + // checking first layer first + auto layer0 = optimized.layer(0); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer0.width()); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer0.width()); - //auto sequence = layer0[0]; - // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, layer0[0].length()); - ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); + // auto sequence = layer0[0]; + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, layer0[0].length()); + ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); - // checking second layer now - auto layer1 = optimized.layer(1); + // checking second layer now + auto layer1 = optimized.layer(1); - // we expect layer has exactly 2 OpSequences - ASSERT_EQ(2, layer1.width()); + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.width()); - //sequence = layer1[0]; - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(8, layer1[0].at(0).protoContext().nodeId()); - - //sequence = layer1[1]; - ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(9, layer1[1].at(0).protoContext().nodeId()); + // sequence = layer1[0]; + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(8, layer1[0].at(0).protoContext().nodeId()); - // checking midle layer - auto layer2 = optimized.layer(2); + // sequence = layer1[1]; + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(9, layer1[1].at(0).protoContext().nodeId()); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(3, layer2.width()); + // checking midle layer + auto layer2 = optimized.layer(2); - //sequence = layer2[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[0].length()); - ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(3, layer2.width()); - //sequence = layer2[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[1].length()); - ASSERT_EQ(11, layer2[1].at(0).protoContext().nodeId()); + // sequence = layer2[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); - //sequence = layer2[2]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[2].length()); - ASSERT_EQ(12, layer2[2].at(0).protoContext().nodeId()); + // sequence = layer2[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(11, layer2[1].at(0).protoContext().nodeId()); - // checking before last layer - auto layer3 = optimized.layer(3); + // sequence = layer2[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[2].length()); + ASSERT_EQ(12, layer2[2].at(0).protoContext().nodeId()); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(2, layer3.width()); - //sequence = layer3[0]; + // checking before last layer + auto layer3 = optimized.layer(3); - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer3[0].length()); - ASSERT_EQ(13, layer3[0].at(0).protoContext().nodeId()); - - //sequence = layer3[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer3[1].length()); - ASSERT_EQ(14, layer3[1].at(0).protoContext().nodeId()); - - // checking last layer - auto layer4 = optimized.layer(4); + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer3.width()); + // sequence = layer3[0]; - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(1, layer4.width()); - //sequence = layer4[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[0].length()); + ASSERT_EQ(13, layer3[0].at(0).protoContext().nodeId()); - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer4[0].length()); - ASSERT_EQ(15, layer4[0].at(0).protoContext().nodeId()); + // sequence = layer3[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[1].length()); + ASSERT_EQ(14, layer3[1].at(0).protoContext().nodeId()); + // checking last layer + auto layer4 = optimized.layer(4); + + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(1, layer4.width()); + // sequence = layer4[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer4[0].length()); + ASSERT_EQ(15, layer4[0].at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_7) { - Graph graph; + Graph graph; - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); - Node a(sd::ops::multiply(), "a"); - Node b(sd::ops::add(), "b"); - Node c(sd::ops::subtract(), "c"); - Node d(sd::ops::add(), "d"); - Node e(sd::ops::multiply(), "e"); + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::subtract(), "c"); + Node d(sd::ops::add(), "d"); + Node e(sd::ops::multiply(), "e"); - graph.addNode(a, {"A", "B"}); - graph.addNode(b, {"a", "C"}); + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"a", "C"}); - graph.addNode(c, {"a", "b"}); - graph.addNode(d, {"b", "c"}); + graph.addNode(c, {"a", "b"}); + graph.addNode(d, {"b", "c"}); - graph.addNode(e, {"b", "c", "d"}); + graph.addNode(e, {"b", "c", "d"}); - // we just check that nodes were really added - ASSERT_EQ(5, graph.size()); + // we just check that nodes were really added + ASSERT_EQ(5, graph.size()); - auto optimized = graph.optimizedGraph(); + auto optimized = graph.optimizedGraph(); - // we expect that OptimizedGraph has exactly 3 layer - ASSERT_EQ(5, optimized.layers()); + // we expect that OptimizedGraph has exactly 3 layer + ASSERT_EQ(5, optimized.layers()); - // checking first layer first - auto layer0 = optimized.layer(0); + // checking first layer first + auto layer0 = optimized.layer(0); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer0.width()); - //auto sequence = layer0[0]; + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer0.width()); + // auto sequence = layer0[0]; - // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, layer0[0].length()); - ASSERT_EQ(4, layer0[0].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, layer0[0].length()); + ASSERT_EQ(4, layer0[0].at(0).protoContext().nodeId()); - // checking second layer now - auto layer1 = optimized.layer(1); + // checking second layer now + auto layer1 = optimized.layer(1); - // we expect layer has exactly 2 OpSequences - ASSERT_EQ(1, layer1.width()); - //sequence = layer1[0]; + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(1, layer1.width()); + // sequence = layer1[0]; - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(5, layer1[0].at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(5, layer1[0].at(0).protoContext().nodeId()); - // checking layer 2 - auto layer2 = optimized.layer(2); + // checking layer 2 + auto layer2 = optimized.layer(2); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer2.width()); - //sequence = layer2[0]; + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer2.width()); + // sequence = layer2[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[0].length()); - ASSERT_EQ(6, layer2[0].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(6, layer2[0].at(0).protoContext().nodeId()); - // checking layer 3 - auto layer3 = optimized.layer(3); + // checking layer 3 + auto layer3 = optimized.layer(3); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer3.width()); - //sequence = layer3[0]; + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer3.width()); + // sequence = layer3[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer3[0].length()); - ASSERT_EQ(7, layer3[0].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[0].length()); + ASSERT_EQ(7, layer3[0].at(0).protoContext().nodeId()); - // checking layer 3 - auto layer4 = optimized.layer(4); + // checking layer 3 + auto layer4 = optimized.layer(4); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer4.width()); - //sequence = layer4[0]; + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer4.width()); + // sequence = layer4[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer4[0].length()); - ASSERT_EQ(8, layer4[0].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer4[0].length()); + ASSERT_EQ(8, layer4[0].at(0).protoContext().nodeId()); } - TEST_F(GraphAnalysisTests, basic_toposort_test_8) { - Graph graph; + Graph graph; - // A - graph.addVariable("A", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // C - graph.addVariable("C", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // D - graph.addVariable("D", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + // D + graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // E - graph.addVariable("E", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + // E + graph.addVariable("E", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // F - graph.addVariable("F", NDArrayFactory::create('c', { 3 }, { 1, 1, 1 })); + // F + graph.addVariable("F", NDArrayFactory::create('c', {3}, {1, 1, 1})); + Node a1(sd::ops::multiply(), "a1"); + Node a2(sd::ops::add(), "a2"); + Node a3(sd::ops::add(), "a3"); - Node a1(sd::ops::multiply(), "a1"); - Node a2(sd::ops::add(), "a2"); - Node a3(sd::ops::add(), "a3"); + Node b1(sd::ops::subtract(), "b1"); + Node b2(sd::ops::add(), "b2"); + Node b3(sd::ops::multiply(), "b3"); - Node b1(sd::ops::subtract(), "b1"); - Node b2(sd::ops::add(), "b2"); - Node b3(sd::ops::multiply(), "b3"); + graph.addNode(a1, {"A", "B"}); + graph.addNode(a2, {"C", "D"}); + graph.addNode(a3, {"E", "F"}); - graph.addNode(a1, { "A", "B" }); - graph.addNode(a2, { "C", "D" }); - graph.addNode(a3, { "E", "F" }); + graph.addNode(b1, {"a1", "a2"}); + graph.addNode(b2, {"a1", "a2", "a3"}); + graph.addNode(b3, {"a2", "a3"}); - graph.addNode(b1, { "a1", "a2" }); - graph.addNode(b2, { "a1", "a2", "a3" }); - graph.addNode(b3, { "a2", "a3" }); + // we just check that nodes were really added + ASSERT_EQ(6, graph.size()); - // we just check that nodes were really added - ASSERT_EQ(6, graph.size()); + auto optimized = graph.optimizedGraph(); - auto optimized = graph.optimizedGraph(); + // we expect that OptimizedGraph has exactly 2 layer + ASSERT_EQ(2, optimized.layers()); - // we expect that OptimizedGraph has exactly 2 layer - ASSERT_EQ(2, optimized.layers()); + // checking first layer first + auto layer0 = optimized.layer(0); - // checking first layer first - auto layer0 = optimized.layer(0); + // we expect layer has exactly 3 OpSequence + ASSERT_EQ(3, layer0.width()); + // auto sequence = layer0[0]; - // we expect layer has exactly 3 OpSequence - ASSERT_EQ(3, layer0.width()); - //auto sequence = layer0[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[0].length()); + ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer0[0].length()); - ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); + // sequence = layer0[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[1].length()); + ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); - //sequence = layer0[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer0[1].length()); - ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); + // sequence = layer0[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[2].length()); + ASSERT_EQ(9, layer0[2].at(0).protoContext().nodeId()); - //sequence = layer0[2]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer0[2].length()); - ASSERT_EQ(9, layer0[2].at(0).protoContext().nodeId()); + // checking second layer now + auto layer1 = optimized.layer(1); - // checking second layer now - auto layer1 = optimized.layer(1); + // we expect layer has exactly 3 OpSequences + ASSERT_EQ(3, layer1.width()); - // we expect layer has exactly 3 OpSequences - ASSERT_EQ(3, layer1.width()); + // sequence = layer1[0]; + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(10, layer1[0].at(0).protoContext().nodeId()); - //sequence = layer1[0]; - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(10, layer1[0].at(0).protoContext().nodeId()); + // sequence = layer1[1]; - //sequence = layer1[1]; - - ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(11, layer1[1].at(0).protoContext().nodeId()); - - //sequence = layer1[2]; - ASSERT_EQ(1, layer1[2].length()); - ASSERT_EQ(12, layer1[2].at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(11, layer1[1].at(0).protoContext().nodeId()); + // sequence = layer1[2]; + ASSERT_EQ(1, layer1[2].length()); + ASSERT_EQ(12, layer1[2].at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_9) { + // start graph - // start graph + Graph graph; - Graph graph; + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); - - Node a(sd::ops::multiply(), "a"); - - Node b1(sd::ops::add(), "b1"); - Node b2(sd::ops::multiply(), "b2"); - Node b3(sd::ops::subtract(), "b3"); - Node b4(sd::ops::Pow(), "b4"); - - Node c1(sd::ops::Pow(), "c1"); - Node c2(sd::ops::subtract(), "c2"); - Node c3(sd::ops::multiply(), "c3"); - Node c4(sd::ops::add(), "c4"); - - Node c5(sd::ops::Pow(), "c5"); - Node c6(sd::ops::subtract(), "c6"); - Node c7(sd::ops::multiply(), "c7"); - Node c8(sd::ops::add(), "c8"); - - graph.addNode(a, {"A", "B"}); - - graph.addNode(b1, {"a", "C"}); - graph.addNode(b2, {"a", "C"}); - graph.addNode(b3, {"a", "C"}); - graph.addNode(b4, {"a", "C"}); - - graph.addNode(c1, {"b1", "D"}); - graph.addNode(c2, {"b2", "D"}); - graph.addNode(c3, {"b3", "D"}); - graph.addNode(c4, {"b4", "D"}); - - graph.addNode(c5, {"b1", "D"}); - graph.addNode(c6, {"b2", "D"}); - graph.addNode(c7, {"b3", "D"}); - graph.addNode(c8, {"b4", "D"}); - - // we just check that nodes were really added - ASSERT_EQ(13, graph.size()); - - auto optimized = graph.optimizedGraph(); - - // we expect that OptimizedGraph has exactly 1 layer - ASSERT_EQ(3, optimized.layers()); - - auto layer = optimized.layer(0); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer.width()); - //auto sequence = layer[0]; - - // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, layer[0].length()); - ASSERT_EQ(5, layer[0].at(0).protoContext().nodeId()); - - auto layer1 = optimized.layer(1); - // we expect layer has exactly 4 OpSequence - ASSERT_EQ(4, layer1.width()); - //sequence = layer1[0]; - - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); - - //sequence = layer1[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); - - //sequence = layer1[2]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer1[2].length()); - ASSERT_EQ(8, layer1[2].at(0).protoContext().nodeId()); - - //sequence = layer1[3]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer1[3].length()); - ASSERT_EQ(9, layer1[3].at(0).protoContext().nodeId()); - - auto layer2 = optimized.layer(2); - // we expect layer has exactly 4 OpSequence - ASSERT_EQ(8, layer2.width()); - //sequence = layer2[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[0].length()); - ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); - - //sequence = layer2[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[1].length()); - ASSERT_EQ(14, layer2[1].at(0).protoContext().nodeId()); - - //sequence = layer2[2]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[2].length()); - ASSERT_EQ(11, layer2[2].at(0).protoContext().nodeId()); - - //sequence = layer2[3]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[3].length()); - ASSERT_EQ(15, layer2[3].at(0).protoContext().nodeId()); - - //sequence = layer2[4]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[4].length()); - ASSERT_EQ(12, layer2[4].at(0).protoContext().nodeId()); - - //sequence = layer2[5]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[5].length()); - ASSERT_EQ(16, layer2[5].at(0).protoContext().nodeId()); - - //sequence = layer2[6]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[6].length()); - ASSERT_EQ(13, layer2[6].at(0).protoContext().nodeId()); - - //sequence = layer2[7]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[7].length()); - ASSERT_EQ(17, layer2[7].at(0).protoContext().nodeId()); + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + + // D + graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); + + Node a(sd::ops::multiply(), "a"); + + Node b1(sd::ops::add(), "b1"); + Node b2(sd::ops::multiply(), "b2"); + Node b3(sd::ops::subtract(), "b3"); + Node b4(sd::ops::Pow(), "b4"); + + Node c1(sd::ops::Pow(), "c1"); + Node c2(sd::ops::subtract(), "c2"); + Node c3(sd::ops::multiply(), "c3"); + Node c4(sd::ops::add(), "c4"); + + Node c5(sd::ops::Pow(), "c5"); + Node c6(sd::ops::subtract(), "c6"); + Node c7(sd::ops::multiply(), "c7"); + Node c8(sd::ops::add(), "c8"); + + graph.addNode(a, {"A", "B"}); + + graph.addNode(b1, {"a", "C"}); + graph.addNode(b2, {"a", "C"}); + graph.addNode(b3, {"a", "C"}); + graph.addNode(b4, {"a", "C"}); + + graph.addNode(c1, {"b1", "D"}); + graph.addNode(c2, {"b2", "D"}); + graph.addNode(c3, {"b3", "D"}); + graph.addNode(c4, {"b4", "D"}); + + graph.addNode(c5, {"b1", "D"}); + graph.addNode(c6, {"b2", "D"}); + graph.addNode(c7, {"b3", "D"}); + graph.addNode(c8, {"b4", "D"}); + + // we just check that nodes were really added + ASSERT_EQ(13, graph.size()); + + auto optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(3, optimized.layers()); + + auto layer = optimized.layer(0); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer.width()); + // auto sequence = layer[0]; + + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, layer[0].length()); + ASSERT_EQ(5, layer[0].at(0).protoContext().nodeId()); + + auto layer1 = optimized.layer(1); + // we expect layer has exactly 4 OpSequence + ASSERT_EQ(4, layer1.width()); + // sequence = layer1[0]; + + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); + + // sequence = layer1[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); + + // sequence = layer1[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer1[2].length()); + ASSERT_EQ(8, layer1[2].at(0).protoContext().nodeId()); + + // sequence = layer1[3]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer1[3].length()); + ASSERT_EQ(9, layer1[3].at(0).protoContext().nodeId()); + + auto layer2 = optimized.layer(2); + // we expect layer has exactly 4 OpSequence + ASSERT_EQ(8, layer2.width()); + // sequence = layer2[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); + + // sequence = layer2[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(14, layer2[1].at(0).protoContext().nodeId()); + + // sequence = layer2[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[2].length()); + ASSERT_EQ(11, layer2[2].at(0).protoContext().nodeId()); + + // sequence = layer2[3]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[3].length()); + ASSERT_EQ(15, layer2[3].at(0).protoContext().nodeId()); + + // sequence = layer2[4]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[4].length()); + ASSERT_EQ(12, layer2[4].at(0).protoContext().nodeId()); + + // sequence = layer2[5]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[5].length()); + ASSERT_EQ(16, layer2[5].at(0).protoContext().nodeId()); + + // sequence = layer2[6]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[6].length()); + ASSERT_EQ(13, layer2[6].at(0).protoContext().nodeId()); + + // sequence = layer2[7]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[7].length()); + ASSERT_EQ(17, layer2[7].at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_10) { - Graph graph; + Graph graph; - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); + // D + graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); - Node a(sd::ops::multiply(), "a"); - Node b(sd::ops::add(), "b"); - Node c(sd::ops::multiply(), "c"); - Node d(sd::ops::subtract(), "d"); + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::multiply(), "c"); + Node d(sd::ops::subtract(), "d"); - graph.addNode(a, {"A", "B"}); - graph.addNode(b, {"a", "C"}); - graph.addNode(c, {"a", "D"}); - graph.addNode(d, {"a", "b", "c"}); + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"a", "C"}); + graph.addNode(c, {"a", "D"}); + graph.addNode(d, {"a", "b", "c"}); - // we just check that nodes were really added - ASSERT_EQ(4, graph.size()); + // we just check that nodes were really added + ASSERT_EQ(4, graph.size()); - auto optimized = graph.optimizedGraph(); + auto optimized = graph.optimizedGraph(); - // we expect that OptimizedGraph has exactly 1 layer - ASSERT_EQ(3, optimized.layers()); + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(3, optimized.layers()); - auto layer = optimized.layer(0); + auto layer = optimized.layer(0); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer.width()); - auto sequence = layer[0]; + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer.width()); + auto sequence = layer[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); - auto layer1 = optimized.layer(1); + auto layer1 = optimized.layer(1); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(2, layer1.width()); - sequence = layer1[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); - sequence = layer1[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer1.width()); + sequence = layer1[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); + sequence = layer1[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); - auto layer2 = optimized.layer(2); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer2.width()); - sequence = layer2[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + auto layer2 = optimized.layer(2); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer2.width()); + sequence = layer2[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, basic_toposort_test_11) { - Graph graph; - - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); - - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); - - Node a(sd::ops::multiply(), "a"); - Node b(sd::ops::add(), "b"); - Node c(sd::ops::multiply(), "c"); - Node d(sd::ops::subtract(), "d"); - - graph.addNode(a, {"A", "B"}); - graph.addNode(b, {"A", "C"}); - graph.addNode(c, {"B", "D"}); - graph.addNode(d, {"C", "D"}); - - // we just check that nodes were really added - ASSERT_EQ(4, graph.size()); - - auto optimized = graph.optimizedGraph(); - - // we expect that OptimizedGraph has exactly 1 layer - ASSERT_EQ(1, optimized.layers()); - - auto layer = optimized.layer(0); - - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(4, layer.width()); - auto sequence = layer[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); - sequence = layer[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); - sequence = layer[2]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); - sequence = layer[3]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); + Graph graph; + + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + + // D + graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); + + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::multiply(), "c"); + Node d(sd::ops::subtract(), "d"); + + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"A", "C"}); + graph.addNode(c, {"B", "D"}); + graph.addNode(d, {"C", "D"}); + + // we just check that nodes were really added + ASSERT_EQ(4, graph.size()); + + auto optimized = graph.optimizedGraph(); + + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(1, optimized.layers()); + + auto layer = optimized.layer(0); + + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(4, layer.width()); + auto sequence = layer[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + sequence = layer[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); + sequence = layer[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + sequence = layer[3]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, test_cond_1) { - auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); - - auto optimized = graph.optimizedGraph(); - /* - some infor that would be useful for implementation - currently on optimization graph is passing next data - - Node name: cond/switch_f; ID: 11; Input: 9, 0; Operation type: 21; Operation class: -1719689536 - Node name: cond/switch_t; ID: 10; Input: 9, 1; Operation type: 21; Operation class: -1719689536 - Node name: cond/Switch; ID: 9; Input: 1, 0; Operation type: 119; Operation class: -1719689536 - Node name: cond/Switch; ID: 9; Input: 6, 0; Operation type: 119; Operation class: -1719689536 - Node name: cond/Merge; ID: 8; Input: 5, 0; Operation type: 119; Operation class: -1719689536 - Node name: cond/Merge; ID: 8; Input: 7, 0; Operation type: 119; Operation class: -1719689536 - Node name: in_0/read; ID: 6; Input: 1, 0; Operation type: 21; Operation class: -1719689536 - Node name: cond/LinSpace; ID: 7; Input: 2, 0; Operation type: 21; Operation class: -1719689536 - Node name: cond/LinSpace; ID: 7; Input: 3, 0; Operation type: 21; Operation class: -1719689536 - Node name: cond/LinSpace; ID: 7; Input: 4, 0; Operation type: 21; Operation class: -1719689536 - - as it can be seen cond/LinSpace is not connected with any switch node(s) that causes wrong results of optimization. - also maybe to cover all conditional operations will be need "Operation class", but this have to discovered deeper. - - All above is true for test_cond_2 - */ - + auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); + + auto optimized = graph.optimizedGraph(); + /* + some infor that would be useful for implementation + currently on optimization graph is passing next data + + Node name: cond/switch_f; ID: 11; Input: 9, 0; Operation type: 21; Operation + class: -1719689536 Node name: cond/switch_t; ID: 10; Input: 9, 1; Operation + type: 21; Operation class: -1719689536 Node name: cond/Switch; ID: 9; + Input: 1, 0; Operation type: 119; Operation class: -1719689536 Node name: + cond/Switch; ID: 9; Input: 6, 0; Operation type: 119; Operation class: + -1719689536 Node name: cond/Merge; ID: 8; Input: 5, 0; Operation type: + 119; Operation class: -1719689536 Node name: cond/Merge; ID: 8; Input: 7, + 0; Operation type: 119; Operation class: -1719689536 Node name: in_0/read; ID: + 6; Input: 1, 0; Operation type: 21; Operation class: -1719689536 Node name: + cond/LinSpace; ID: 7; Input: 2, 0; Operation type: 21; Operation class: + -1719689536 Node name: cond/LinSpace; ID: 7; Input: 3, 0; Operation type: 21; + Operation class: -1719689536 Node name: cond/LinSpace; ID: 7; Input: 4, 0; + Operation type: 21; Operation class: -1719689536 + + as it can be seen cond/LinSpace is not connected with any switch node(s) that + causes wrong results of optimization. also maybe to cover all conditional + operations will be need "Operation class", but this have to discovered deeper. + + All above is true for test_cond_2 + */ } TEST_F(GraphAnalysisTests, test_cond_2) { - auto graph = Graph::fromFlatBuffers("resources/cond_false.fb"); + auto graph = Graph::fromFlatBuffers("resources/cond_false.fb"); } diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index f5d967469354..556b688e759e 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -19,88 +19,87 @@ // Created by raver119 on 29.11.17. // - -#include "testlayers.h" +#include #include -#include -#include -#include #include -#include -#include -#include +#include #include +#include +#include +#include +#include #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class GraphExecutorTests : public testing::Test { -public: - + public: }; TEST_F(GraphExecutorTests, test_basic_exec_1) { - GraphMemoryManager memoryManager; - Graph graph; + GraphMemoryManager memoryManager; + Graph graph; - OptimizedGraph optimizedGraph; - OpSequence sequence; + OptimizedGraph optimizedGraph; + OpSequence sequence; - optimizedGraph.append(sequence); + optimizedGraph.append(sequence); - VariableProxy proxy(&graph.variableSpace()); - GraphExecutor executor; - executor.execute(optimizedGraph, proxy); + VariableProxy proxy(&graph.variableSpace()); + GraphExecutor executor; + executor.execute(optimizedGraph, proxy); } TEST_F(GraphExecutorTests, test_basic_exec_2) { - GraphMemoryManager mgr; - Graph graph(nullptr, mgr); + GraphMemoryManager mgr; + Graph graph(nullptr, mgr); - auto A = NDArrayFactory::create('c', {3}, {1, 1, 1}); - auto B = NDArrayFactory::create('c', {3}, {2, 2, 2}); - auto C = NDArrayFactory::create('c', {3}, {3, 3, 3}); + auto A = NDArrayFactory::create('c', {3}, {1, 1, 1}); + auto B = NDArrayFactory::create('c', {3}, {2, 2, 2}); + auto C = NDArrayFactory::create('c', {3}, {3, 3, 3}); - auto exp = NDArrayFactory::create('c', {3}, {5, 5, 5}); + auto exp = NDArrayFactory::create('c', {3}, {5, 5, 5}); - graph.addVariable("A", A); - graph.addVariable("B", B); - graph.addVariable("C", C); + graph.addVariable("A", A); + graph.addVariable("B", B); + graph.addVariable("C", C); - Node m(sd::ops::multiply(), "mul"); - Node a(sd::ops::add(), "add"); + Node m(sd::ops::multiply(), "mul"); + Node a(sd::ops::add(), "add"); - graph.addNode(m, {"A", "B"}); - graph.addNode(a, {"mul", "C"}); + graph.addNode(m, {"A", "B"}); + graph.addNode(a, {"mul", "C"}); - OptimizedGraph optimizedGraph; - OpSequence sequence; + OptimizedGraph optimizedGraph; + OpSequence sequence; - ASSERT_EQ(2, m.protoContext().inputs().size()); - ASSERT_EQ(2, a.protoContext().inputs().size()); + ASSERT_EQ(2, m.protoContext().inputs().size()); + ASSERT_EQ(2, a.protoContext().inputs().size()); - sequence.append(m.customOp(), m.protoContext()); - sequence.append(a.customOp(), a.protoContext()); + sequence.append(m.customOp(), m.protoContext()); + sequence.append(a.customOp(), a.protoContext()); - optimizedGraph.append(sequence); + optimizedGraph.append(sequence); - ASSERT_EQ(2, sequence.length()); - ASSERT_EQ(1, optimizedGraph.layers()); + ASSERT_EQ(2, sequence.length()); + ASSERT_EQ(1, optimizedGraph.layers()); - VariableProxy proxy(&graph.variableSpace()); - GraphExecutor executor; - executor.execute(optimizedGraph, proxy); + VariableProxy proxy(&graph.variableSpace()); + GraphExecutor executor; + executor.execute(optimizedGraph, proxy); - // checking results by ID - ASSERT_TRUE(proxy.hasVariable(m.id())); - ASSERT_TRUE(proxy.hasVariable(a.id())); + // checking results by ID + ASSERT_TRUE(proxy.hasVariable(m.id())); + ASSERT_TRUE(proxy.hasVariable(a.id())); - // checking results by name - ASSERT_TRUE(proxy.hasVariable("mul")); - ASSERT_TRUE(proxy.hasVariable("add")); + // checking results by name + ASSERT_TRUE(proxy.hasVariable("mul")); + ASSERT_TRUE(proxy.hasVariable("add")); - // checking if result is valid - auto result = proxy.getVariable(a.id())->getNDArray(); - ASSERT_EQ(exp, *result); + // checking if result is valid + auto result = proxy.getVariable(a.id())->getNDArray(); + ASSERT_EQ(exp, *result); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp index 9faf685abf01..1891325c58b0 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp @@ -18,26 +18,26 @@ // Created by raver119 on 11.12.17. // -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::ops; using namespace sd::graph; class GraphHolderTests : public testing::Test { -public: - + public: }; TEST_F(GraphHolderTests, SimpleTests_1) { - Graph graph; - Nd4jLong graphId = 119; - GraphHolder::getInstance()->registerGraph(graphId, graph); + Graph graph; + Nd4jLong graphId = 119; + GraphHolder::getInstance()->registerGraph(graphId, graph); - ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(graphId)); + ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(graphId)); - GraphHolder::getInstance()->forgetGraph(graphId); + GraphHolder::getInstance()->forgetGraph(graphId); - ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(graphId)); + ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(graphId)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphRandomGeneratorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphRandomGeneratorTests.cpp index 8fe46cd2f791..b93721b1e1d4 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphRandomGeneratorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphRandomGeneratorTests.cpp @@ -14,251 +14,247 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -#include "testlayers.h" -#include #include #include +#include + #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class GraphRandomGeneratorTests : public testing::Test { -public: - + public: }; TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_1) { - sd::graph::RandomGenerator g0(119); - sd::graph::RandomGenerator g1(119); + sd::graph::RandomGenerator g0(119); + sd::graph::RandomGenerator g1(119); - auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); + auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); - ASSERT_EQ(i0, i1); + ASSERT_EQ(i0, i1); } TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_2) { - sd::graph::RandomGenerator g0(119); - sd::graph::RandomGenerator g1(117); + sd::graph::RandomGenerator g0(119); + sd::graph::RandomGenerator g1(117); - auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); + auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); - ASSERT_NE(i0, i1); + ASSERT_NE(i0, i1); } TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_3) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(119, 10); + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 10); - auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); + auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); - ASSERT_NE(i0, i1); + ASSERT_NE(i0, i1); } TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_4) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(117, 5); + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(117, 5); - auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); + auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); - ASSERT_NE(i0, i1); + ASSERT_NE(i0, i1); } TEST_F(GraphRandomGeneratorTests, Sequential_Test_1) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(119, 5); + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); - auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); - g0.rewindH(200); - auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); + auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(200); + auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); - // values after rewind aren't equal - ASSERT_NE(r0, v0); + // values after rewind aren't equal + ASSERT_NE(r0, v0); - // two generators must give the same output - ASSERT_EQ(v0, v1); + // two generators must give the same output + ASSERT_EQ(v0, v1); - // but not after one of them was rewinded - ASSERT_NE(r1, r0); + // but not after one of them was rewinded + ASSERT_NE(r1, r0); } TEST_F(GraphRandomGeneratorTests, Sequential_Test_2) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(119, 5); + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); - auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); - g0.rewindH(200); - g1.rewindH(199); - auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); + auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(200); + g1.rewindH(199); + auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); - // values after rewind aren't equal - ASSERT_NE(r0, v0); + // values after rewind aren't equal + ASSERT_NE(r0, v0); - // two generators must give the same output - ASSERT_EQ(v0, v1); + // two generators must give the same output + ASSERT_EQ(v0, v1); - // but not after they was rewinded with different number of elements - ASSERT_NE(r1, r0); + // but not after they was rewinded with different number of elements + ASSERT_NE(r1, r0); } TEST_F(GraphRandomGeneratorTests, Sequential_Test_3) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(119, 5); + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); - auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); - g0.rewindH(200); - g1.rewindH(200); - auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); + auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(200); + g1.rewindH(200); + auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); - // values after rewind aren't equal - ASSERT_NE(r0, v0); + // values after rewind aren't equal + ASSERT_NE(r0, v0); - // two generators must give the same output - ASSERT_EQ(v0, v1); + // two generators must give the same output + ASSERT_EQ(v0, v1); - // and here output must be equal as well - ASSERT_EQ(r1, r0); + // and here output must be equal as well + ASSERT_EQ(r1, r0); } TEST_F(GraphRandomGeneratorTests, Sequential_Test_4) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(119, 5); - - auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); - g0.rewindH(200); - g1.rewindH(200); - auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); - g0.rewindH(200); - g1.rewindH(200); - auto z0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto z1 = g1.relativeT(15, 0, DataTypeUtils::max()); - g0.rewindH(201); - g1.rewindH(199); - auto y0 = g0.relativeT(15, 0, DataTypeUtils::max()); - auto y1 = g1.relativeT(15, 0, DataTypeUtils::max()); - - // values after rewind aren't equal - ASSERT_NE(r0, v0); - - // two generators must give the same output - ASSERT_EQ(v0, v1); - - // and here output must be equal as well - ASSERT_EQ(r0, r1); - - ASSERT_EQ(z0, z1); - - ASSERT_NE(r0, z0); - ASSERT_NE(r1, z1); - - ASSERT_NE(y0, z0); - ASSERT_NE(y1, z1); + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); + + auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(200); + g1.rewindH(200); + auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(200); + g1.rewindH(200); + auto z0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto z1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(201); + g1.rewindH(199); + auto y0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto y1 = g1.relativeT(15, 0, DataTypeUtils::max()); + + // values after rewind aren't equal + ASSERT_NE(r0, v0); + + // two generators must give the same output + ASSERT_EQ(v0, v1); + + // and here output must be equal as well + ASSERT_EQ(r0, r1); + + ASSERT_EQ(z0, z1); + + ASSERT_NE(r0, z0); + ASSERT_NE(r1, z1); + + ASSERT_NE(y0, z0); + ASSERT_NE(y1, z1); } - //#ifndef __clang__ TEST_F(GraphRandomGeneratorTests, Long_Test_1) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(119, 5); + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); - std::array z0, z1, z2, z3; + std::array z0, z1, z2, z3; - for (int e = 0; e < z0.size(); e++) { - z0[e] = g0.relativeT(e); - z1[e] = g1.relativeT(e); - } + for (int e = 0; e < z0.size(); e++) { + z0[e] = g0.relativeT(e); + z1[e] = g1.relativeT(e); + } - g0.rewindH(z0.size()); - g1.rewindH(z0.size()); + g0.rewindH(z0.size()); + g1.rewindH(z0.size()); - for (int e = 0; e < z0.size(); e++) { - z2[e] = g0.relativeT(e); - z3[e] = g1.relativeT(e); - } + for (int e = 0; e < z0.size(); e++) { + z2[e] = g0.relativeT(e); + z3[e] = g1.relativeT(e); + } - // these sequences should be equal - ASSERT_EQ(z0, z1); - ASSERT_EQ(z2, z3); + // these sequences should be equal + ASSERT_EQ(z0, z1); + ASSERT_EQ(z2, z3); - // these sequences should be different due to rewind - ASSERT_NE(z0, z3); + // these sequences should be different due to rewind + ASSERT_NE(z0, z3); - // we'll be counting values > MAX_INT here - int maxes = 0; + // we'll be counting values > MAX_INT here + int maxes = 0; - for (int e = 0; e < z0.size(); e++) { - auto v = z0[e]; + for (int e = 0; e < z0.size(); e++) { + auto v = z0[e]; - // we don't want any negatives here - ASSERT_TRUE(v > 0); + // we don't want any negatives here + ASSERT_TRUE(v > 0); - if (v > DataTypeUtils::max()) - maxes++; - } + if (v > DataTypeUtils::max()) maxes++; + } - // and now we're ensuring there ARE values above MAX_INT - ASSERT_NE(0, maxes); + // and now we're ensuring there ARE values above MAX_INT + ASSERT_NE(0, maxes); } - TEST_F(GraphRandomGeneratorTests, FloatingPoint_Test_1) { - sd::graph::RandomGenerator g0(119, 5); - sd::graph::RandomGenerator g1(119, 5); - - std::array z0, z1, z2, z3; + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); - for (int e = 0; e < z0.size(); e++) { - z0[e] = g0.relativeT(e, -1.0, 1.0); - z1[e] = g1.relativeT(e, -1.0, 1.0); - } + std::array z0, z1, z2, z3; - g0.rewindH(z0.size()); - g1.rewindH(z0.size()); + for (int e = 0; e < z0.size(); e++) { + z0[e] = g0.relativeT(e, -1.0, 1.0); + z1[e] = g1.relativeT(e, -1.0, 1.0); + } - for (int e = 0; e < z0.size(); e++) { - z2[e] = g0.relativeT(e, -1.0, 1.0); - z3[e] = g1.relativeT(e, -1.0, 1.0); - } + g0.rewindH(z0.size()); + g1.rewindH(z0.size()); - // these sequences should be equal - ASSERT_EQ(z0, z1); - ASSERT_EQ(z2, z3); + for (int e = 0; e < z0.size(); e++) { + z2[e] = g0.relativeT(e, -1.0, 1.0); + z3[e] = g1.relativeT(e, -1.0, 1.0); + } - // these sequences should be different due to rewind - ASSERT_NE(z0, z3); + // these sequences should be equal + ASSERT_EQ(z0, z1); + ASSERT_EQ(z2, z3); - // we'll count negatives as well - int negs = 0; + // these sequences should be different due to rewind + ASSERT_NE(z0, z3); - // make sure every value stays within distribution borders - for (int e = 0; e < z0.size(); e++) { - auto v = z0[e]; - if (!(v >= -1.0 && v <= 1.0)) { - nd4j_printf("Failed at idx [%i]: %f\n", e, (float) v); - ASSERT_TRUE(v >= -1.0 && v <= 1.0); - } + // we'll count negatives as well + int negs = 0; - if (v < 0.0) - negs++; + // make sure every value stays within distribution borders + for (int e = 0; e < z0.size(); e++) { + auto v = z0[e]; + if (!(v >= -1.0 && v <= 1.0)) { + nd4j_printf("Failed at idx [%i]: %f\n", e, (float)v); + ASSERT_TRUE(v >= -1.0 && v <= 1.0); } - // there should be negatives - ASSERT_TRUE(negs > 0); + if (v < 0.0) negs++; + } - // and positives - ASSERT_NE(z0.size(), negs); -} + // there should be negatives + ASSERT_TRUE(negs > 0); + // and positives + ASSERT_NE(z0.size(), negs); +} diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests.cpp index 4f77d0501574..9f5b2fe66e55 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests.cpp @@ -18,40 +18,43 @@ // @author raver119@gmail.com // -#include "testlayers.h" +#include #include -#include -#include -#include #include #include -#include +#include +#include +#include #include + #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class GraphTests : public testing::Test { -public: - /* - int cShape[] = {2, 2, 2, 2, 1, 0, 1, 99}; - int fShape[] = {2, 2, 2, 1, 2, 0, 1, 102}; - */ - GraphTests() { - //Environment::getInstance()->setDebug(true); - //Environment::getInstance()->setVerbose(true); - } + public: + /* + int cShape[] = {2, 2, 2, 2, 1, 0, 1, 99}; + int fShape[] = {2, 2, 2, 1, 2, 0, 1, 102}; + */ + GraphTests() { + // Environment::getInstance()->setDebug(true); + // Environment::getInstance()->setVerbose(true); + } }; - /* TEST_F(GraphTests, Test_Minifier_1) { // run preprocessor to produce single header // if all ok - return value is 0, if error - non-zero value will be returned - std::string input("../include/ops/ops.h"); //declarable/CustomOperations.h"); + std::string input("../include/ops/ops.h"); +//declarable/CustomOperations.h"); - ASSERT_EQ(0, GraphUtils::runPreprocessor(input.c_str(), "libnd4j_mini.hpp")); + ASSERT_EQ(0, GraphUtils::runPreprocessor(input.c_str(), +"libnd4j_mini.hpp")); // remove file from filesystem #ifdef __linux__ ASSERT_EQ(0, unlink("libnd4j_mini.hpp")); @@ -60,24 +63,23 @@ TEST_F(GraphTests, Test_Minifier_1) { */ TEST_F(GraphTests, Test_Minifier_2) { - - // run preprocessor to produce single header - // if all ok - return value is 0, if error - non-zero value will be returned - ASSERT_EQ(0, GraphUtils::runPreprocessor("../include/ops/specials.h", "libnd4j_mini2.hpp")); - // remove file from filesystem + // run preprocessor to produce single header + // if all ok - return value is 0, if error - non-zero value will be returned + ASSERT_EQ(0, GraphUtils::runPreprocessor("../include/ops/specials.h", + "libnd4j_mini2.hpp")); + // remove file from filesystem #ifdef __linux__ - ASSERT_EQ(0, unlink("libnd4j_mini2.hpp")); + ASSERT_EQ(0, unlink("libnd4j_mini2.hpp")); #endif } TEST_F(GraphTests, Test_Minifier_3) { - - // run preprocessor to produce single header - // if all ok - return value is 0, if error - non-zero value will be returned + // run preprocessor to produce single header + // if all ok - return value is 0, if error - non-zero value will be returned #ifdef __linux__ - ASSERT_EQ(0x100, GraphUtils::runPreprocessor("/include/ops/ops.h", "libnd4j_mini3.hpp")); + ASSERT_EQ(0x100, GraphUtils::runPreprocessor("/include/ops/ops.h", + "libnd4j_mini3.hpp")); #endif - // remove file from filesystem - //ASSERT_EQ(0, unlink("libnd4j_mini3.hpp")); - + // remove file from filesystem + // ASSERT_EQ(0, unlink("libnd4j_mini3.hpp")); } diff --git a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp index 117cf7265482..f343058e70c6 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphTests2.cpp @@ -18,158 +18,171 @@ // @author raver119@gmail.com // -#include "testlayers.h" +#include +#include +#include #include -#include -#include -#include #include #include -#include -#include -#include -#include +#include #include +#include +#include +#include #include -#include -#include +#include + +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class GraphTests2 : public testing::Test { -public: - - GraphTests2() { - // - } + public: + GraphTests2() { + // + } }; TEST_F(GraphTests2, test_placeholder_1) { - Graph graph; + Graph graph; - graph.addPlaceholder("input", DataType::BFLOAT16, {4, 12, 48}); + graph.addPlaceholder("input", DataType::BFLOAT16, {4, 12, 48}); - ASSERT_TRUE(graph.variableSpace().hasVariable("input")); + ASSERT_TRUE(graph.variableSpace().hasVariable("input")); - auto variable = graph.variableSpace().getVariable("input"); + auto variable = graph.variableSpace().getVariable("input"); - ASSERT_NE(nullptr, variable); - ASSERT_TRUE(variable->isPlaceholder()); - ASSERT_EQ(DataType::BFLOAT16, variable->dataType()); - ASSERT_EQ(std::vector({4, 12, 48}), variable->shape()); + ASSERT_NE(nullptr, variable); + ASSERT_TRUE(variable->isPlaceholder()); + ASSERT_EQ(DataType::BFLOAT16, variable->dataType()); + ASSERT_EQ(std::vector({4, 12, 48}), variable->shape()); - auto placeholders = graph.placeholders(); - ASSERT_EQ(1, placeholders.size()); - ASSERT_EQ(placeholders[0], variable); + auto placeholders = graph.placeholders(); + ASSERT_EQ(1, placeholders.size()); + ASSERT_EQ(placeholders[0], variable); } TEST_F(GraphTests2, test_execution_1) { - Graph graph; + Graph graph; - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // A + graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + // B + graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + // C + graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - Node b(sd::ops::add(), "add_node"); + Node b(sd::ops::add(), "add_node"); - graph.addNode(Node( sd::ops::multiply(), "multiply_node"), {"A", "B"}); - graph.addNode(b, {"multiply_node", "C"}); + graph.addNode(Node(sd::ops::multiply(), "multiply_node"), {"A", "B"}); + graph.addNode(b, {"multiply_node", "C"}); - auto result = graph.execute({}, {"add_node"}); - ASSERT_EQ(1, result.size()); - ASSERT_EQ(1, result.count("add_node")); + auto result = graph.execute({}, {"add_node"}); + ASSERT_EQ(1, result.size()); + ASSERT_EQ(1, result.count("add_node")); } TEST_F(GraphTests2, test_placeholder_resolution_1) { - Graph graph; + Graph graph; - graph.addPlaceholder("input", DataType::FLOAT32); + graph.addPlaceholder("input", DataType::FLOAT32); - Node node(sd::ops::tanh(), "tanh_node"); - graph.addNode(node, {"input"}); + Node node(sd::ops::tanh(), "tanh_node"); + graph.addNode(node, {"input"}); - // this test must throw an exception, because input isn't resolved yet - ASSERT_ANY_THROW(graph.execute()); + // this test must throw an exception, because input isn't resolved yet + ASSERT_ANY_THROW(graph.execute()); } TEST_F(GraphTests2, test_placeholder_resolution_2) { - Graph graph; + Graph graph; - graph.addPlaceholder("input", DataType::FLOAT32); + graph.addPlaceholder("input", DataType::FLOAT32); - graph.addNode(Node(sd::ops::rationaltanh(), "tanh_node"), {"input"}); + graph.addNode(Node(sd::ops::rationaltanh(), "tanh_node"), {"input"}); - auto result = graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"tanh_node"}); + auto result = + graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"tanh_node"}); - // TODO: add result validation here + // TODO: add result validation here } TEST_F(GraphTests2, test_placeholder_resolution_3) { - Graph graph; + Graph graph; - graph.addPlaceholder("input", DataType::FLOAT32); + graph.addPlaceholder("input", DataType::FLOAT32); - graph.addNode(Node(sd::ops::tanh(), "tanh_node"), {"input"}); + graph.addNode(Node(sd::ops::tanh(), "tanh_node"), {"input"}); - ASSERT_THROW(graph.execute({{"input", NDArrayFactory::create(5)}}, {"tanh_node"}), sd::datatype_exception); + ASSERT_THROW( + graph.execute({{"input", NDArrayFactory::create(5)}}, {"tanh_node"}), + sd::datatype_exception); } TEST_F(GraphTests2, test_placeholder_resolution_4) { - Graph graph; + Graph graph; - graph.addPlaceholder("input", DataType::FLOAT32, {3, 4, 5}); + graph.addPlaceholder("input", DataType::FLOAT32, {3, 4, 5}); - Node a(sd::ops::tanh(), "tanh_node"); - graph.addNode(a, {"input"}); + Node a(sd::ops::tanh(), "tanh_node"); + graph.addNode(a, {"input"}); - ASSERT_THROW(graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"tanh_node"}), sd::shape_mismatch_exception); + ASSERT_THROW(graph.execute({{"input", NDArrayFactory::create(0.5f)}}, + {"tanh_node"}), + sd::shape_mismatch_exception); } TEST_F(GraphTests2, test_output_resolution_1) { - Graph graph; + Graph graph; - graph.addPlaceholder("input", DataType::FLOAT32); + graph.addPlaceholder("input", DataType::FLOAT32); - Node node(sd::ops::tanh(), "tanh_node"); - graph.addNode(node, {"input"}); + Node node(sd::ops::tanh(), "tanh_node"); + graph.addNode(node, {"input"}); - // since we're requesting output of non-existent node - we expect exception - ASSERT_THROW(graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"pow_node"}), graph::unresolved_output_exception); + // since we're requesting output of non-existent node - we expect exception + ASSERT_THROW( + graph.execute({{"input", NDArrayFactory::create(0.5f)}}, {"pow_node"}), + graph::unresolved_output_exception); } TEST_F(GraphTests2, test_input_resolution_1) { - Graph graph; + Graph graph; - graph.addPlaceholder("input", DataType::FLOAT32); + graph.addPlaceholder("input", DataType::FLOAT32); - Node a(sd::ops::tanh(), "tanh_node"); - graph.addNode(a, {"input"}); + Node a(sd::ops::tanh(), "tanh_node"); + graph.addNode(a, {"input"}); - // since we're trying to resolve non-existent placeholder - we expect exception - ASSERT_THROW(graph.execute({{"array", NDArrayFactory::create(0.5f)}}, {"tanh_node"}), graph::unresolved_input_exception); + // since we're trying to resolve non-existent placeholder - we expect + // exception + ASSERT_THROW( + graph.execute({{"array", NDArrayFactory::create(0.5f)}}, {"tanh_node"}), + graph::unresolved_input_exception); } TEST_F(GraphTests2, test_double_name_1) { - Graph graph; + Graph graph; - graph.addPlaceholder("input", DataType::FLOAT32); + graph.addPlaceholder("input", DataType::FLOAT32); - graph.addNode(Node(sd::ops::tanh(), "tanh_node"), {"input"}); - graph.addNode(Node(sd::ops::add(), "add_node"), {"tanh_node"}); - ASSERT_ANY_THROW(graph.addNode(Node(sd::ops::add(), "add_node"), {"tanh_node"})); + graph.addNode(Node(sd::ops::tanh(), "tanh_node"), {"input"}); + graph.addNode(Node(sd::ops::add(), "add_node"), {"tanh_node"}); + ASSERT_ANY_THROW( + graph.addNode(Node(sd::ops::add(), "add_node"), {"tanh_node"})); } TEST_F(GraphTests2, test_self_reference) { - Graph graph; + Graph graph; - graph.addPlaceholder("input", DataType::FLOAT32); + graph.addPlaceholder("input", DataType::FLOAT32); - graph.addNode(Node(sd::ops::tanh(), "tanh_node"), {"input"}); - ASSERT_ANY_THROW(graph.addNode(Node(sd::ops::add(), "add_node"), {"add_node"})); + graph.addNode(Node(sd::ops::tanh(), "tanh_node"), {"input"}); + ASSERT_ANY_THROW( + graph.addNode(Node(sd::ops::add(), "add_node"), {"add_node"})); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/HashUtilsTests.cpp b/libnd4j/tests_cpu/layers_tests/HashUtilsTests.cpp index da513f7d4640..e632267e34a1 100644 --- a/libnd4j/tests_cpu/layers_tests/HashUtilsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/HashUtilsTests.cpp @@ -18,26 +18,22 @@ // Created by raver119 on 02.09.17. // -#include "testlayers.h" #include -class HashUtilsTests : public testing::Test { - -}; +#include "testlayers.h" +class HashUtilsTests : public testing::Test {}; TEST_F(HashUtilsTests, TestEquality1) { - std::string str("Conv2D"); + std::string str("Conv2D"); - Nd4jLong hash1 = sd::ops::HashHelper::getInstance()->getLongHash(str); - ASSERT_EQ(-1637140380760460323L, hash1); + Nd4jLong hash1 = sd::ops::HashHelper::getInstance()->getLongHash(str); + ASSERT_EQ(-1637140380760460323L, hash1); } - - TEST_F(HashUtilsTests, TestEquality2) { - std::string str("switch"); + std::string str("switch"); - Nd4jLong hash1 = sd::ops::HashHelper::getInstance()->getLongHash(str); - ASSERT_EQ(-1988317239813741487L, hash1); + Nd4jLong hash1 = sd::ops::HashHelper::getInstance()->getLongHash(str); + ASSERT_EQ(-1988317239813741487L, hash1); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp index e25bd0144d24..7c95f43bdcd9 100644 --- a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp @@ -15,1310 +15,1894 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -#include "testlayers.h" -#include +#include +#include #include -#include -#include #include -#include +#include +#include #include -#include +#include +#include #include +#include +#include #include #include -#include -#include -#include -#include +#include + +#include "testlayers.h" using namespace sd; class HelpersTests1 : public testing::Test { -public: - - HelpersTests1() { - - std::cout< array = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::array array = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - auto idx = sd::ops::helpers::binarySearch(array.data(), 2, 10); - ASSERT_EQ(2, idx); + auto idx = sd::ops::helpers::binarySearch(array.data(), 2, 10); + ASSERT_EQ(2, idx); } TEST_F(HelpersTests1, test_binary_search_2) { - std::array array = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::array array = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - auto idx = sd::ops::helpers::binarySearch(array.data(), 18, 10); - ASSERT_EQ(-1, idx); + auto idx = sd::ops::helpers::binarySearch(array.data(), 18, 10); + ASSERT_EQ(-1, idx); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, evalHHmatrix_test1) { + auto x = NDArrayFactory::create('c', {1, 4}, {14, 17, 3, 1}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {-0.629253, -0.764093, -0.13484, -0.0449467, -0.764093, 0.641653, + -0.0632377, -0.0210792, -0.13484, -0.0632377, 0.98884, -0.00371987, + -0.0449467, -0.0210792, -0.00371987, 0.99876}); - - auto x = NDArrayFactory::create('c', {1,4}, {14,17,3,1}); - auto exp = NDArrayFactory::create('c', {4,4}, {-0.629253, -0.764093, -0.13484, -0.0449467, -0.764093, 0.641653, -0.0632377, -0.0210792, -0.13484,-0.0632377, 0.98884,-0.00371987, -0.0449467,-0.0210792,-0.00371987, 0.99876}); - - auto result = ops::helpers::Householder::evalHHmatrix(x); - ASSERT_TRUE(result.isSameShape(&exp)); - ASSERT_TRUE(result.equalsTo(&exp)); - + auto result = ops::helpers::Householder::evalHHmatrix(x); + ASSERT_TRUE(result.isSameShape(&exp)); + ASSERT_TRUE(result.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, evalHHmatrix_test2) { +#ifdef __CUDABLAS__ + return; +#endif + auto x = NDArrayFactory::create('c', {1, 3}, {14, -4, 3}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {-0.941742, 0.269069, -0.201802, 0.269069, 0.962715, 0.0279639, -0.201802, + 0.0279639, 0.979027}); - #ifdef __CUDABLAS__ - return; - #endif - auto x = NDArrayFactory::create('c', {1,3}, {14,-4,3}); - auto exp = NDArrayFactory::create('c', {3,3}, {-0.941742, 0.269069,-0.201802, 0.269069, 0.962715,0.0279639, -0.201802,0.0279639, 0.979027}); - - auto result = ops::helpers::Householder::evalHHmatrix(x); - - ASSERT_TRUE(result.isSameShape(&exp)); - ASSERT_TRUE(result.equalsTo(&exp)); + auto result = ops::helpers::Householder::evalHHmatrix(x); + ASSERT_TRUE(result.isSameShape(&exp)); + ASSERT_TRUE(result.equalsTo(&exp)); } - ///////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, evalHHmatrixData_test1) { +#ifdef __CUDABLAS__ + return; +#endif + auto x = NDArrayFactory::create('c', {1, 4}, {14, 17, 3, 1}); + auto tail = NDArrayFactory::create('c', {1, 3}); + auto expTail = NDArrayFactory::create( + 'c', {1, 3}, {0.468984, 0.0827618, 0.0275873}); + const double normXExpected = -22.2486; + const double coeffExpected = 1.62925; - #ifdef __CUDABLAS__ - return; - #endif - auto x = NDArrayFactory::create('c', {1,4}, {14,17,3,1}); - auto tail = NDArrayFactory::create('c', {1,3}); - auto expTail = NDArrayFactory::create('c', {1,3}, {0.468984, 0.0827618, 0.0275873}); - const double normXExpected = -22.2486; - const double coeffExpected = 1.62925; - - double normX, coeff; - ops::helpers::Householder::evalHHmatrixData(x, tail, coeff, normX); - - ASSERT_NEAR(normX, normXExpected, 1e-5); - ASSERT_NEAR(coeff, coeffExpected, 1e-5); - ASSERT_TRUE(tail.isSameShapeStrict(expTail)); - ASSERT_TRUE(tail.equalsTo(&expTail)); + double normX, coeff; + ops::helpers::Householder::evalHHmatrixData(x, tail, coeff, normX); + ASSERT_NEAR(normX, normXExpected, 1e-5); + ASSERT_NEAR(coeff, coeffExpected, 1e-5); + ASSERT_TRUE(tail.isSameShapeStrict(expTail)); + ASSERT_TRUE(tail.equalsTo(&expTail)); } - ///////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, Householder_mulLeft_test1) { +#ifdef __CUDABLAS__ + return; +#endif + auto x = NDArrayFactory::create( + 'c', {4, 4}, {12, 19, 14, 3, 10, 4, 17, 19, 19, 18, 5, 3, 6, 4, 2, 16}); + auto tail = NDArrayFactory::create('c', {1, 3}, {0.5, 0.5, 0.5}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {9.05, 15.8, 11.4, 0.8, 8.525, 2.4, 15.7, 17.9, 17.525, 16.4, 3.7, 1.9, + 4.525, 2.4, 0.7, 14.9}); - #ifdef __CUDABLAS__ - return; - #endif - auto x = NDArrayFactory::create('c', {4,4}, {12 ,19 ,14 ,3 ,10 ,4 ,17 ,19 ,19 ,18 ,5 ,3 ,6 ,4 ,2 ,16}); - auto tail = NDArrayFactory::create('c', {1,3}, {0.5,0.5,0.5}); - auto exp = NDArrayFactory::create('c', {4,4}, {9.05,15.8,11.4, 0.8, 8.525, 2.4,15.7,17.9, 17.525,16.4, 3.7, 1.9, 4.525, 2.4, 0.7,14.9}); - - ops::helpers::Householder::mulLeft(x, tail, 0.1); - // expTail.printShapeInfo(); - - ASSERT_TRUE(x.isSameShapeStrict(exp)); - ASSERT_TRUE(x.equalsTo(&exp)); + ops::helpers::Householder::mulLeft(x, tail, 0.1); + // expTail.printShapeInfo(); + ASSERT_TRUE(x.isSameShapeStrict(exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } ///////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, Householder_mulLeft_test2) { +#ifdef __CUDABLAS__ + return; +#endif + auto x = NDArrayFactory::create( + 'c', {4, 4}, {12, 19, 14, 3, 10, 4, 17, 19, 19, 18, 5, 3, 6, 4, 2, 16}); + auto tail = NDArrayFactory::create('c', {3, 1}, {0.5, 0.5, 0.5}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {9.05, 15.8, 11.4, 0.8, 8.525, 2.4, 15.7, 17.9, 17.525, 16.4, 3.7, 1.9, + 4.525, 2.4, 0.7, 14.9}); - #ifdef __CUDABLAS__ - return; - #endif - auto x = NDArrayFactory::create('c', {4,4}, {12 ,19 ,14 ,3 ,10 ,4 ,17 ,19 ,19 ,18 ,5 ,3 ,6 ,4 ,2 ,16}); - auto tail = NDArrayFactory::create('c', {3,1}, {0.5,0.5,0.5}); - auto exp = NDArrayFactory::create('c', {4,4}, {9.05,15.8,11.4, 0.8, 8.525, 2.4,15.7,17.9, 17.525,16.4, 3.7, 1.9, 4.525, 2.4, 0.7,14.9}); - - ops::helpers::Householder::mulLeft(x, tail, 0.1); - - ASSERT_TRUE(x.isSameShapeStrict(exp)); - ASSERT_TRUE(x.equalsTo(&exp)); + ops::helpers::Householder::mulLeft(x, tail, 0.1); + ASSERT_TRUE(x.isSameShapeStrict(exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } ///////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, Householder_mulRight_test1) { +#ifdef __CUDABLAS__ + return; +#endif + auto x = NDArrayFactory::create( + 'c', {4, 4}, {12, 19, 14, 3, 10, 4, 17, 19, 19, 18, 5, 3, 6, 4, 2, 16}); + auto tail = NDArrayFactory::create('c', {1, 3}, {0.5, 0.5, 0.5}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {9, 17.5, 12.5, 1.5, 7, 2.5, 15.5, 17.5, 15.8, 16.4, 3.4, 1.4, 4.3, 3.15, + 1.15, 15.15}); - #ifdef __CUDABLAS__ - return; - #endif - auto x = NDArrayFactory::create('c', {4,4}, {12 ,19 ,14 ,3 ,10 ,4 ,17 ,19 ,19 ,18 ,5 ,3 ,6 ,4 ,2 ,16}); - auto tail = NDArrayFactory::create('c', {1,3}, {0.5,0.5,0.5}); - auto exp = NDArrayFactory::create('c', {4,4}, {9,17.5,12.5, 1.5, 7, 2.5,15.5, 17.5, 15.8,16.4, 3.4, 1.4, 4.3,3.15,1.15,15.15}); - - ops::helpers::Householder::mulRight(x, tail, 0.1); - - ASSERT_TRUE(x.isSameShapeStrict(exp)); - ASSERT_TRUE(x.equalsTo(&exp)); + ops::helpers::Householder::mulRight(x, tail, 0.1); + ASSERT_TRUE(x.isSameShapeStrict(exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } - ///////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, BiDiagonalizeUp_test1) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {4, 4}, {9, 13, 3, 6, 13, 11, 7, 6, 3, 7, 4, 7, 6, 6, 7, 10}); + auto hhMatrixExp = NDArrayFactory::create( + 'c', {4, 4}, + {1.524000, 1.75682, 0.233741, 0.289458, 0.496646, 1.5655, 1.02929, + 0.971124, 0.114611, -0.451039, 1.06367, 0, 0.229221, -0.272237, 0.938237, + 0}); + auto hhBidiagExp = NDArrayFactory::create( + 'c', {4, 4}, + {-17.1756, 24.3869, 0, 0, 0, -8.61985, -3.89823, 0, 0, 0, 4.03047, + 4.13018, 0, 0, 0, 1.21666}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6,13,11,7,6,3,7,4,7,6,6,7,10}); - auto hhMatrixExp = NDArrayFactory::create('c', {4,4}, {1.524000, 1.75682,0.233741,0.289458, 0.496646, 1.5655, 1.02929,0.971124, 0.114611,-0.451039, 1.06367,0, 0.229221,-0.272237,0.938237,0}); - auto hhBidiagExp = NDArrayFactory::create('c', {4,4}, {-17.1756, 24.3869, 0, 0, 0,-8.61985,-3.89823, 0, 0, 0, 4.03047,4.13018, 0, 0, 0,1.21666}); - - ops::helpers::BiDiagonalUp object(matrix); - // object._HHmatrix.printBuffer(); + ops::helpers::BiDiagonalUp object(matrix); + // object._HHmatrix.printBuffer(); - ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); - ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); - ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); - ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); + ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); + ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); + ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); + ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, BiDiagonalizeUp_test2) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 4}, + {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, -6, 6, 7, 10, 2, 17, 9, 12}); + auto hhMatrixExp = NDArrayFactory::create( + 'c', {5, 4}, + {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979, + -0.444696, 0.114105, 0.130601, 1.58392, 0, -0.22821, 0.215638, + 0.0524781, 1.99303, 0.0760699, 0.375605, 0.509835, 0.0591568}); + auto hhBidiagExp = NDArrayFactory::create( + 'c', {4, 4}, + {-17.2916, 7.03123, 0, 0, 0, 16.145, -22.9275, 0, 0, 0, -9.9264, -11.5516, + 0, 0, 0, -12.8554}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); - auto hhMatrixExp = NDArrayFactory::create('c', {5,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979,-0.444696, 0.114105,0.130601, 1.58392, 0, -0.22821, 0.215638,0.0524781, 1.99303, 0.0760699,0.375605, 0.509835,0.0591568}); - auto hhBidiagExp = NDArrayFactory::create('c', {4,4}, {-17.2916,7.03123, 0, 0, 0, 16.145,-22.9275, 0, 0, 0, -9.9264,-11.5516, 0, 0, 0,-12.8554}); - - ops::helpers::BiDiagonalUp object(matrix); - // object._HHmatrix.printBuffer(); + ops::helpers::BiDiagonalUp object(matrix); + // object._HHmatrix.printBuffer(); - ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); - ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); - ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); - ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); + ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); + ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); + ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); + ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, BiDiagonalizeUp_test3) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {6,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12, 0,-15,10,2}); - auto hhMatrixExp = NDArrayFactory::create('c', {6,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, 1.59666,-0.502606, 0.114105, 0.129651, 1.35075, 0, -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, 0.389936, 0.2398, 0,0.0935171,-0.563777, 0.428587}); - auto hhBidiagExp = NDArrayFactory::create('c', {4,4}, {-17.2916,7.03123, 0, 0, 0,16.3413,-20.7828, 0, 0, 0,-18.4892,4.13261, 0, 0, 0,-21.323}); - - ops::helpers::BiDiagonalUp object(matrix); - // object._HHmatrix.printBuffer(); - - ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); - ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); - ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); - ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {6, 4}, {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, + -6, 6, 7, 10, 2, 17, 9, 12, 0, -15, 10, 2}); + auto hhMatrixExp = NDArrayFactory::create( + 'c', {6, 4}, + {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, + 1.59666, -0.502606, 0.114105, 0.129651, 1.35075, 0, + -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, + 0.389936, 0.2398, 0, 0.0935171, -0.563777, 0.428587}); + auto hhBidiagExp = NDArrayFactory::create( + 'c', {4, 4}, + {-17.2916, 7.03123, 0, 0, 0, 16.3413, -20.7828, 0, 0, 0, -18.4892, + 4.13261, 0, 0, 0, -21.323}); + + ops::helpers::BiDiagonalUp object(matrix); + // object._HHmatrix.printBuffer(); + + ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); + ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); + ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); + ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test1) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); - auto vectorsUseqExp = NDArrayFactory::create('c', {5,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979,-0.444696, 0.114105,0.130601, 1.58392, 0, -0.22821,0.215638,0.0524781, 1.99303, 0.0760699,0.375605, 0.509835,0.0591568}); - auto vectorsVseqExp = NDArrayFactory::create('c', {5,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979,-0.444696, 0.114105,0.130601, 1.58392, 0, -0.22821,0.215638,0.0524781, 1.99303, 0.0760699,0.375605, 0.509835,0.0591568}); - auto coeffsUseqExp = NDArrayFactory::create('c', {4,1}, {1.52048,1.66025,1.58392,1.99303}); - auto coeffsVseqExp = NDArrayFactory::create('c', {3,1}, {1.37012,1.66979,0}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - - ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); - ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); - - ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); - ASSERT_TRUE(vSeq._shift == 1); - ASSERT_TRUE(uSeq._shift == 0); - +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 4}, + {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, -6, 6, 7, 10, 2, 17, 9, 12}); + auto vectorsUseqExp = NDArrayFactory::create( + 'c', {5, 4}, + {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979, + -0.444696, 0.114105, 0.130601, 1.58392, 0, -0.22821, 0.215638, + 0.0524781, 1.99303, 0.0760699, 0.375605, 0.509835, 0.0591568}); + auto vectorsVseqExp = NDArrayFactory::create( + 'c', {5, 4}, + {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979, + -0.444696, 0.114105, 0.130601, 1.58392, 0, -0.22821, 0.215638, + 0.0524781, 1.99303, 0.0760699, 0.375605, 0.509835, 0.0591568}); + auto coeffsUseqExp = NDArrayFactory::create( + 'c', {4, 1}, {1.52048, 1.66025, 1.58392, 1.99303}); + auto coeffsVseqExp = + NDArrayFactory::create('c', {3, 1}, {1.37012, 1.66979, 0}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + + ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); + + ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); + ASSERT_TRUE(vSeq._shift == 1); + ASSERT_TRUE(uSeq._shift == 0); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test2) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {6,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12 ,0,-15,10,2}); - auto vectorsUseqExp = NDArrayFactory::create('c', {6,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, 1.59666,-0.502606, 0.114105, 0.129651, 1.35075, 0, -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, 0.389936, 0.2398, 0,0.0935171,-0.563777, 0.428587}); - auto vectorsVseqExp = NDArrayFactory::create('c', {6,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, 1.59666,-0.502606, 0.114105, 0.129651, 1.35075, 0, -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, 0.389936, 0.2398, 0,0.0935171,-0.563777, 0.428587}); - auto coeffsUseqExp = NDArrayFactory::create('c', {4,1}, {1.52048,1.65232,1.35075,1.61136}); - auto coeffsVseqExp = NDArrayFactory::create('c', {3,1}, {1.37012,1.59666,0}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - - ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); - ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); - - ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); - ASSERT_TRUE(vSeq._shift == 1); - ASSERT_TRUE(uSeq._shift == 0); - +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {6, 4}, {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, + -6, 6, 7, 10, 2, 17, 9, 12, 0, -15, 10, 2}); + auto vectorsUseqExp = NDArrayFactory::create( + 'c', {6, 4}, + {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, + 1.59666, -0.502606, 0.114105, 0.129651, 1.35075, 0, + -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, + 0.389936, 0.2398, 0, 0.0935171, -0.563777, 0.428587}); + auto vectorsVseqExp = NDArrayFactory::create( + 'c', {6, 4}, + {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, + 1.59666, -0.502606, 0.114105, 0.129651, 1.35075, 0, + -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, + 0.389936, 0.2398, 0, 0.0935171, -0.563777, 0.428587}); + auto coeffsUseqExp = NDArrayFactory::create( + 'c', {4, 1}, {1.52048, 1.65232, 1.35075, 1.61136}); + auto coeffsVseqExp = + NDArrayFactory::create('c', {3, 1}, {1.37012, 1.59666, 0}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + + ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); + + ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); + ASSERT_TRUE(vSeq._shift == 1); + ASSERT_TRUE(uSeq._shift == 0); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test3) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); - auto vectorsUseqExp = NDArrayFactory::create('c', {4,4}, {1.524, 1.75682,0.233741,0.289458, 0.496646, 1.5655, 1.02929,0.971124, 0.114611,-0.451039, 1.06367, 0, 0.229221,-0.272237,0.938237, 0}); - auto vectorsVseqExp = NDArrayFactory::create('c', {4,4}, {1.524, 1.75682,0.233741,0.289458, 0.496646, 1.5655, 1.02929,0.971124, 0.114611,-0.451039, 1.06367, 0, 0.229221,-0.272237,0.938237, 0}); - auto coeffsUseqExp = NDArrayFactory::create('c', {4,1}, { 1.524, 1.5655,1.06367,0}); - auto coeffsVseqExp = NDArrayFactory::create('c', {3,1}, {1.75682,1.02929, 0}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - - ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); - ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); - ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); - - ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); - ASSERT_TRUE(vSeq._shift == 1); - ASSERT_TRUE(uSeq._shift == 0); - +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {4, 4}, {9, 13, 3, 6, 13, 11, 7, 6, 3, 7, 4, 7, 6, 6, 7, 10}); + auto vectorsUseqExp = NDArrayFactory::create( + 'c', {4, 4}, + {1.524, 1.75682, 0.233741, 0.289458, 0.496646, 1.5655, 1.02929, 0.971124, + 0.114611, -0.451039, 1.06367, 0, 0.229221, -0.272237, 0.938237, 0}); + auto vectorsVseqExp = NDArrayFactory::create( + 'c', {4, 4}, + {1.524, 1.75682, 0.233741, 0.289458, 0.496646, 1.5655, 1.02929, 0.971124, + 0.114611, -0.451039, 1.06367, 0, 0.229221, -0.272237, 0.938237, 0}); + auto coeffsUseqExp = + NDArrayFactory::create('c', {4, 1}, {1.524, 1.5655, 1.06367, 0}); + auto coeffsVseqExp = + NDArrayFactory::create('c', {3, 1}, {1.75682, 1.02929, 0}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + + ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); + + ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); + ASSERT_TRUE(vSeq._shift == 1); + ASSERT_TRUE(uSeq._shift == 0); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test4) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {4, 4}, {9, 13, 3, 6, 13, 11, 7, 6, 3, 7, 4, 7, 6, 6, 7, 10}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {2.49369, 2.62176, 5.88386, 7.69905, -16.0588, -18.7319, -9.15007, + -12.6164, 4.7247, 3.46252, 1.02038, -1.4533, 2.9279, -2.29178, 1.90139, + -0.66187}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); - auto exp = NDArrayFactory::create('c', {4,4}, {2.49369, 2.62176, 5.88386, 7.69905, -16.0588,-18.7319,-9.15007,-12.6164, 4.7247, 3.46252, 1.02038, -1.4533, 2.9279,-2.29178, 1.90139,-0.66187}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - uSeq.mulLeft(matrix); - - ASSERT_TRUE(matrix.equalsTo(&exp)); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.mulLeft(matrix); + ASSERT_TRUE(matrix.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test5) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 4}, + {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, -6, 6, 7, 10, 2, 17, 9, 12}); + auto exp = NDArrayFactory::create( + 'c', {5, 4}, + {4.52891, 8.09473, -2.73704, -13.0302, -11.0752, 7.41549, -3.75125, + 0.815252, -7.76818, -15.9102, -9.90869, -11.8677, 1.63942, -17.0312, + -9.05102, -4.49088, -9.63311, 0.540226, -1.52764, 5.79111}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); - auto exp = NDArrayFactory::create('c', {5,4}, {4.52891, 8.09473,-2.73704,-13.0302, -11.0752, 7.41549,-3.75125,0.815252, -7.76818,-15.9102,-9.90869,-11.8677, 1.63942,-17.0312,-9.05102,-4.49088, -9.63311,0.540226,-1.52764, 5.79111}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - uSeq.mulLeft(matrix); - - ASSERT_TRUE(matrix.equalsTo(&exp)); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.mulLeft(matrix); + ASSERT_TRUE(matrix.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test6) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 4}, + {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, -6, 6, 7, 10, 2, 17, 9, 12}); + auto matrix2 = NDArrayFactory::create( + 'c', {6, 4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, + -1, 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, + {9, -1, 3, 9, -4.43019, -15.1713, + -3.2854, -7.65743, -9.39162, -7.03599, 8.03827, 9.48453, + -2.97785, -16.424, 5.35265, -20.1171, -0.0436177, -13.118, + -8.37287, -17.3012, -1.14074, 4.18282, -10.0914, -5.69014}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); - auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); - auto exp = NDArrayFactory::create('c', {6,4}, {9,-1,3,9, -4.43019,-15.1713, -3.2854,-7.65743, -9.39162,-7.03599, 8.03827, 9.48453, -2.97785, -16.424, 5.35265,-20.1171, -0.0436177, -13.118,-8.37287,-17.3012, -1.14074, 4.18282,-10.0914,-5.69014}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - uSeq.mulLeft(matrix2); - - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.mulLeft(matrix2); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test7) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {4, 4}, {9, 13, 3, 6, 13, 11, 7, 6, 3, 7, 4, 7, 6, 6, 7, 10}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {9, 13, 3, 6, -5.90424, -2.30926, -0.447417, 3.05712, -10.504, -9.31339, + -8.85493, -10.8886, -8.29494, -10.6737, -5.94895, -7.55591}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); - auto exp = NDArrayFactory::create('c', {4,4}, {9,13,3,6,-5.90424,-2.30926,-0.447417, 3.05712, -10.504,-9.31339, -8.85493,-10.8886, -8.29494,-10.6737, -5.94895,-7.55591}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.mulLeft(matrix); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix); - ASSERT_TRUE(matrix.equalsTo(&exp)); + ASSERT_TRUE(matrix.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test8) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 4}, + {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, -6, 6, 7, 10, 2, 17, 9, 12}); + auto exp = NDArrayFactory::create( + 'c', {5, 4}, + {9, -13, 3, 6, 13, 11, 7, + -6, -6.90831, -5.01113, 0.381677, 0.440128, -0.80107, 0.961605, + -0.308019, -1.96153, -0.795985, 18.6538, 12.0731, 16.9988}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); - auto exp = NDArrayFactory::create('c', {5,4}, {9, -13, 3, 6, 13, 11, 7, -6, -6.90831,-5.01113, 0.381677,0.440128, -0.80107,0.961605,-0.308019,-1.96153, -0.795985, 18.6538, 12.0731, 16.9988}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.mulLeft(matrix); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix); - ASSERT_TRUE(matrix.equalsTo(&exp)); + ASSERT_TRUE(matrix.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test9) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {6, 4}, {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, + -6, 6, 7, 10, 2, 17, 9, 12, 0, -15, 10, 2}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, {9, -13, 3, 6, 13, 11, + 7, -6, 3, 7, 4, 7, + 3.77597, 18.6226, -0.674868, 4.61365, 5.02738, -14.1486, + -2.22877, -8.98245, -0.683766, 1.73722, 14.9859, 12.0843}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {6,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12 ,0,-15,10,2}); - auto exp = NDArrayFactory::create('c', {6,4}, {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, 3.77597, 18.6226,-0.674868, 4.61365, 5.02738,-14.1486, -2.22877,-8.98245, -0.683766, 1.73722, 14.9859, 12.0843}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.mulLeft(matrix); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix); - ASSERT_TRUE(matrix.equalsTo(&exp)); + ASSERT_TRUE(matrix.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test10) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {4, 4}, {9, 13, 3, 6, 13, 11, 7, 6, 3, 7, 4, 7, 6, 6, 7, 10}); + auto matrix2 = NDArrayFactory::create( + 'c', {6, 4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, + -1, 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, + {9, -1, 3, 9, 10, 11, + -7, -5, 3, 2, 4, 7, + 2.58863, 11.0295, -4.17483, -0.641012, -1.21892, -16.3151, + 6.12049, -20.0239, -0.901799, -15.0389, -12.4944, -20.2394}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); - auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); - auto exp = NDArrayFactory::create('c', {6,4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, 2.58863, 11.0295,-4.17483,-0.641012, -1.21892,-16.3151, 6.12049, -20.0239, -0.901799,-15.0389,-12.4944, -20.2394}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.mulLeft(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test11) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 4}, + {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, -6, 6, 7, 10, 2, 17, 9, 12}); + auto matrix2 = NDArrayFactory::create( + 'c', {6, 4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, + -1, 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, {9, -1, 3, 9, 10, 11, + -7, -5, 3, 2, 4, 7, + 1.14934, 4.40257, 8.70127, -1.18824, 1.5132, 0.220419, + -11.6285, -11.7549, 2.32148, 24.3838, 0.256531, 25.9116}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); - auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); - auto exp = NDArrayFactory::create('c', {6,4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, 1.14934, 4.40257, 8.70127,-1.18824, 1.5132,0.220419,-11.6285,-11.7549, 2.32148, 24.3838,0.256531, 25.9116}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.mulLeft(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test12) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 3}, {9, -13, 3, 13, 11, 7, 3, 7, 4, -6, 6, 7, 2, 17, 9}); + auto matrix2 = NDArrayFactory::create( + 'c', {6, 4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, + -1, 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, {9, -1, 3, 9, 10, 11, + -7, -5, 3, 2, 4, 7, + -1, 6, 7, 19, -2.62252, -22.2914, + 4.76743, -19.6689, -1.05943, -9.00514, -11.8013, -7.94571}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); - auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); - auto exp = NDArrayFactory::create('c', {6,4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, 6, 7, 19, -2.62252,-22.2914, 4.76743,-19.6689, -1.05943,-9.00514,-11.8013,-7.94571}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.mulLeft(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test13) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 3}, {9, -13, 3, 13, 11, 7, 3, 7, 4, -6, 6, 7, 2, 17, 9}); + auto matrix2 = NDArrayFactory::create( + 'c', {6, 4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, + -1, 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, + {9, -1, 3, 9, -4.65167, 3.44652, + 7.83593, 22.6899, -9.48514, -21.902, 5.66559, -13.0533, + -0.343184, 15.2895, 7.2888, 14.0489, 0.289638, -1.87752, + 3.944, -1.49707, -2.48845, 3.18285, -10.6685, 0.406502}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); - auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); - auto exp = NDArrayFactory::create('c', {6,4}, {9 , -1 , 3 , 9, -4.65167, 3.44652, 7.83593, 22.6899, -9.48514, -21.902, 5.66559,-13.0533, -0.343184, 15.2895, 7.2888, 14.0489, 0.289638,-1.87752, 3.944,-1.49707, -2.48845, 3.18285,-10.6685,0.406502}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - uSeq.mulLeft(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.mulLeft(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test14) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 3}, {9, -13, 3, 13, 11, 7, 3, 7, 4, -6, 6, 7, 2, 17, 9}); + auto matrix2 = NDArrayFactory::create( + 'c', {5, 5}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, + 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15, 2}); + auto exp = NDArrayFactory::create( + 'c', {5, 5}, + {1.78958, 8.06962, -6.13687, 4.36267, 1.06472, -14.9578, -8.1522, + 1.30442, -18.3343, -13.2578, 13.5536, 5.50764, 15.7859, 7.60831, + 11.7871, -1.3626, -0.634986, 7.60934, -2.1841, 5.62694, -13.0577, + 15.1554, -7.6511, 3.76365, -5.87368}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); - auto matrix2 = NDArrayFactory::create('c',{5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); - auto exp = NDArrayFactory::create('c', {5,5}, {1.78958, 8.06962,-6.13687, 4.36267, 1.06472, -14.9578, -8.1522, 1.30442,-18.3343,-13.2578, 13.5536, 5.50764, 15.7859, 7.60831, 11.7871, -1.3626,-0.634986, 7.60934, -2.1841, 5.62694, -13.0577, 15.1554, -7.6511, 3.76365,-5.87368}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - uSeq.mulLeft(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.mulLeft(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } - /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test15) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 3}, {9, -13, 3, 13, 11, 7, 3, 7, 4, -6, 6, 7, 2, 17, 9}); + auto matrix2 = NDArrayFactory::create( + 'c', {5, 5}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, + 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15, 2}); + auto exp = NDArrayFactory::create( + 'c', {5, 5}, + {9, -1, 3, 9, 10, 11, -7, + -5, 3, 2, 4, 7, -1, 6, + 7, -9.26566, -16.4298, 1.64125, -17.3243, -7.70257, -16.7077, + 4.80216, -19.1652, -2.42279, -13.0258}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); - auto matrix2 = NDArrayFactory::create('c',{5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); - auto exp = NDArrayFactory::create('c', {5,5}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, 6, 7, -9.26566,-16.4298, 1.64125,-17.3243,-7.70257, -16.7077, 4.80216,-19.1652,-2.42279,-13.0258}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.mulLeft(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test16) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 5}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, + 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15, 2}); + auto matrix2 = NDArrayFactory::create('c', {10, 10}); + matrix2 = 100.; + auto exp = NDArrayFactory::create( + 'c', {5, 5}, {-0.372742, 0.295145, 0.325359, 0.790947, 0.20615, + -0.455573, -0.824221, -0.239444, 0.216163, -0.0951492, + -0.165663, 0.285319, -0.18501, 0.130431, -0.916465, + -0.7869, 0.245393, 0.116952, -0.541267, 0.117997, + -0.0828315, 0.303191, -0.888202, 0.133021, 0.3076}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); - auto matrix2 = NDArrayFactory::create('c', {10,10}); - matrix2 = 100.; - auto exp = NDArrayFactory::create('c',{5,5}, {-0.372742, 0.295145, 0.325359, 0.790947, 0.20615, -0.455573,-0.824221,-0.239444, 0.216163,-0.0951492, -0.165663, 0.285319, -0.18501, 0.130431, -0.916465, -0.7869, 0.245393, 0.116952,-0.541267, 0.117997, -0.0828315, 0.303191,-0.888202, 0.133021, 0.3076}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - uSeq.applyTo(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.applyTo(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test17) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 5}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, + 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15, 2}); + auto matrix2 = NDArrayFactory::create('c', {10, 10}); + matrix2 = 100.; + auto exp = NDArrayFactory::create( + 'c', {5, 5}, {1, 0, 0, 0, 0, + 0, -0.022902, 0.986163, 0.0411914, 0.158935, + 0, -0.44659, 0.021539, 0.797676, -0.404731, + 0, -0.554556, 0.103511, -0.600701, -0.56649, + 0, -0.701784, -0.127684, -0.0342758, 0.700015}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); - auto matrix2 = NDArrayFactory::create('c', {10,10}); - matrix2 = 100.; - auto exp = NDArrayFactory::create('c',{5,5}, {1, 0, 0, 0, 0, 0,-0.022902, 0.986163, 0.0411914, 0.158935, 0, -0.44659, 0.021539, 0.797676,-0.404731, 0,-0.554556, 0.103511, -0.600701, -0.56649, 0,-0.701784,-0.127684,-0.0342758, 0.700015}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.applyTo(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.applyTo(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test18) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {6, 4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, + -1, 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15}); + auto matrix2 = NDArrayFactory::create('c', {10, 10}); + matrix2 = 100.; + auto exp = NDArrayFactory::create( + 'c', {6, 6}, + {-0.637993, 0.190621, -0.524821, -0.312287, 0.407189, 0.133659, + -0.708881, 0.0450803, 0.47462, 0.232701, -0.204602, -0.417348, + -0.212664, -0.0405892, -0.297123, 0.0240276, -0.821557, 0.435099, + 0.0708881, -0.432466, -0.49252, -0.145004, -0.199312, -0.710367, + -0.141776, -0.56468, -0.180549, 0.706094, 0.274317, 0.233707, + -0.141776, -0.673865, 0.368567, -0.572848, 0.0490246, 0.243733}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); - auto matrix2 = NDArrayFactory::create('c', {10,10}); - matrix2 = 100.; - auto exp = NDArrayFactory::create('c',{6,6}, {-0.637993, 0.190621,-0.524821,-0.312287, 0.407189, 0.133659, -0.708881, 0.0450803, 0.47462, 0.232701,-0.204602,-0.417348, -0.212664,-0.0405892,-0.297123,0.0240276,-0.821557, 0.435099, 0.0708881, -0.432466, -0.49252,-0.145004,-0.199312,-0.710367, -0.141776, -0.56468,-0.180549, 0.706094, 0.274317, 0.233707, -0.141776, -0.673865, 0.368567,-0.572848,0.0490246, 0.243733}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); - uSeq.applyTo(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.applyTo(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test19) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {6, 4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, + -1, 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15}); + auto matrix2 = NDArrayFactory::create('c', {10, 10}); + matrix2 = 100.; + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {1, 0, 0, 0, 0, -0.859586, 0.28601, -0.42345, 0, 0.19328, -0.585133, + -0.787567, 0, -0.473027, -0.758826, 0.447693}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); - auto matrix2 = NDArrayFactory::create('c', {10,10}); - matrix2 = 100.; - auto exp = NDArrayFactory::create('c',{4,4}, {1, 0, 0, 0, 0,-0.859586, 0.28601, -0.42345, 0, 0.19328,-0.585133,-0.787567, 0,-0.473027,-0.758826, 0.447693}); - - ops::helpers::BiDiagonalUp object(matrix); - ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); - vSeq.applyTo(matrix2); + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.applyTo(matrix2); - ASSERT_TRUE(matrix2.equalsTo(&exp)); + ASSERT_TRUE(matrix2.equalsTo(&exp)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test1) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,5}, {-17 ,14 ,9 ,-12 ,-12 ,5 ,-4 ,-19 ,-7 ,-12 ,15 ,16 ,17 ,-6 ,8 ,-10 ,14 ,-15 ,6 ,-10 ,-14 ,12 ,-1 ,-16 ,3}); - auto matrix2 = NDArrayFactory::create('c', {5,5}, {18 ,3 ,2 ,7 ,-11 ,7 ,7 ,10 ,-13 ,-8 ,13 ,20 ,-4 ,-16 ,-9 ,-17 ,-5 ,-7 ,-19 ,-8 ,-9 ,9 ,6 ,14 ,-11}); - auto expM = NDArrayFactory::create('c', {5,5}, {-17,14,9,-12,-12, 5,-4, -19, -7,-12, 15,16,17.0294, -6, 8, -10,14, -15, 6,-10, -14,12, 0,-16, 0}); - auto expU = NDArrayFactory::create('c', {5,5}, {18,3, 2,7,-11, 7, 7.75131,10,-12.5665, -8, 13, 20.905,-4,-14.7979, -9, -17,-3.87565,-7,-19.2608, -8, -9, 9, 6, 14,-11}); - - ops::helpers::SVD svd(matrix, 4, true, true, true, 't'); - svd._m = matrix; - svd._u = matrix2; - svd.deflation1(1,1,2,2); - - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 5}, {-17, 14, 9, -12, -12, 5, -4, -19, -7, -12, 15, 16, 17, + -6, 8, -10, 14, -15, 6, -10, -14, 12, -1, -16, 3}); + auto matrix2 = NDArrayFactory::create( + 'c', {5, 5}, {18, 3, 2, 7, -11, 7, 7, 10, -13, -8, 13, 20, -4, + -16, -9, -17, -5, -7, -19, -8, -9, 9, 6, 14, -11}); + auto expM = NDArrayFactory::create( + 'c', {5, 5}, + {-17, 14, 9, -12, -12, 5, -4, -19, -7, -12, 15, 16, 17.0294, + -6, 8, -10, 14, -15, 6, -10, -14, 12, 0, -16, 0}); + auto expU = NDArrayFactory::create( + 'c', {5, 5}, + {18, 3, 2, 7, -11, 7, 7.75131, 10, -12.5665, + -8, 13, 20.905, -4, -14.7979, -9, -17, -3.87565, -7, + -19.2608, -8, -9, 9, 6, 14, -11}); + + ops::helpers::SVD svd(matrix, 4, true, true, true, 't'); + svd._m = matrix; + svd._u = matrix2; + svd.deflation1(1, 1, 2, 2); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test2) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix= NDArrayFactory::create('c', {5,5}, {-17 ,14 ,9 ,-12 ,-12 ,5 ,-4 ,-19 ,-7 ,-12 ,15 ,16 ,17 ,-6 ,8 ,-10 ,14 ,-15 ,6 ,-10 ,-14 ,12 ,-1 ,-16 ,3}); - auto matrix2 = NDArrayFactory::create('c', {5,5}, {18 ,3 ,2 ,7 ,-11 ,7 ,7 ,10 ,-13 ,-8 ,13 ,20 ,-4 ,-16 ,-9 ,-17 ,-5 ,-7 ,-19 ,-8 ,-9 ,9 ,6 ,14 ,-11}); - auto expM = NDArrayFactory::create('c', {5,5}, {22.6716,14, 9,-12,-12, 5,-4,-19, -7,-12, 0,16, 0, -6, 8, -10,14,-15, 6,-10, -14,12, -1,-16, 3}); - auto expU = NDArrayFactory::create('c', {5,5}, {-12.1738, 3, -13.4089, 7,-11, 1.36735, 7, -12.1297,-13, -8, -12.3944,20, -5.60173,-16, -9, -17,-5,-7,-19, -8, -9, 9, 6, 14,-11}); - - ops::helpers::SVD svd(matrix, 4, true, true, true); - svd._m = matrix; - svd._u = matrix2; - svd.deflation1(0,0,2,2); - - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 5}, {-17, 14, 9, -12, -12, 5, -4, -19, -7, -12, 15, 16, 17, + -6, 8, -10, 14, -15, 6, -10, -14, 12, -1, -16, 3}); + auto matrix2 = NDArrayFactory::create( + 'c', {5, 5}, {18, 3, 2, 7, -11, 7, 7, 10, -13, -8, 13, 20, -4, + -16, -9, -17, -5, -7, -19, -8, -9, 9, 6, 14, -11}); + auto expM = NDArrayFactory::create( + 'c', {5, 5}, + {22.6716, 14, 9, -12, -12, 5, -4, -19, -7, -12, 0, 16, 0, + -6, 8, -10, 14, -15, 6, -10, -14, 12, -1, -16, 3}); + auto expU = NDArrayFactory::create( + 'c', {5, 5}, + {-12.1738, 3, -13.4089, 7, -11, 1.36735, 7, -12.1297, -13, + -8, -12.3944, 20, -5.60173, -16, -9, -17, -5, -7, + -19, -8, -9, 9, 6, 14, -11}); + + ops::helpers::SVD svd(matrix, 4, true, true, true); + svd._m = matrix; + svd._u = matrix2; + svd.deflation1(0, 0, 2, 2); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test3) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 5}, {-17, 14, 9, -12, -12, 5, -4, -19, -7, -12, 15, 16, 17, + -6, 8, -10, 14, -15, 6, -10, -14, 12, -1, -16, 3}); + auto matrix2 = NDArrayFactory::create( + 'c', {2, 6}, {18, 3, 2, 7, -11, 7, 7, 10, -13, -8, 13, 20}); + auto expM = NDArrayFactory::create( + 'c', {5, 5}, + {-17, 14, 9, -12, -12, 5, -4, -19, -7, -12, 15, 16, 17.0294, + -6, 8, -10, 14, -15, 6, -10, -14, 12, 0, -16, 0}); + auto expU = NDArrayFactory::create( + 'c', {2, 6}, + {18, 2.58377, 2, 7.16409, -11, 7, 7, 10.4525, -13, -7.39897, 13, 20}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix= NDArrayFactory::create('c', {5,5}, {-17 ,14 ,9 ,-12 ,-12 ,5 ,-4 ,-19 ,-7 ,-12 ,15 ,16 ,17 ,-6 ,8 ,-10 ,14 ,-15 ,6 ,-10 ,-14 ,12 ,-1 ,-16 ,3}); - auto matrix2 = NDArrayFactory::create('c', {2,6}, {18 ,3 ,2 ,7 ,-11 ,7 ,7 ,10 ,-13 ,-8 ,13 ,20}); - auto expM = NDArrayFactory::create('c', {5,5}, {-17,14,9,-12,-12, 5,-4, -19, -7,-12, 15,16,17.0294, -6, 8, -10,14, -15, 6,-10, -14,12, 0,-16, 0}); - auto expU = NDArrayFactory::create('c', {2,6}, {18, 2.58377, 2, 7.16409,-11, 7, 7 ,10.4525 ,-13, -7.39897 ,13 ,20}); - - ops::helpers::SVD svd(matrix, 4, false, true, true, 't'); - svd._m = matrix; - svd._u = matrix2; - svd.deflation1(1,1,2,2); + ops::helpers::SVD svd(matrix, 4, false, true, true, 't'); + svd._m = matrix; + svd._u = matrix2; + svd.deflation1(1, 1, 2, 2); - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test4) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expM = NDArrayFactory::create('c', {6,5}, {12, 20, 19,-18, -6, 3, 6, 2, -7, -7, 14, 8, 18,-17, 18, -14,-15,8.06226, 2, 2, -3,-18, 0,-17, 2, 12, 18, 6, -2,-17}); - auto expU = NDArrayFactory::create('c', {6,6}, {-10,-16, -20, 13, 20,-10, -9, -1,-20.7138,4.46525, -4, 20, -11, 19,-18.4812,2.72876, 12,-19, 18,-18, 17, -10,-19, 14, -2, -7, -17, -14, -4,-16, 18, -6, -18, 1,-15,-12}); - auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-18, -13, 14, 2, -2,-11,2.97683,-7.69015,-6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - - ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; - svd.deflation2(1, 2, 2, 1, 1, 2, 1); - - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); - ASSERT_TRUE(expV.equalsTo(&svd._v)); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix1 = NDArrayFactory::create( + 'c', {6, 5}, + {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 18, -17, 18, + -14, -15, 1, 2, 2, -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); + auto matrix2 = NDArrayFactory::create( + 'c', {6, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20, + -11, 19, -5, -18, 12, -19, 18, -18, 17, -10, -19, 14, + -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); + auto matrix3 = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto expM = NDArrayFactory::create( + 'c', {6, 5}, + {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 18, -17, 18, + -14, -15, 8.06226, 2, 2, -3, -18, 0, -17, 2, 12, 18, 6, -2, -17}); + auto expU = NDArrayFactory::create( + 'c', {6, 6}, + {-10, -16, -20, 13, 20, -10, -9, -1, -20.7138, + 4.46525, -4, 20, -11, 19, -18.4812, 2.72876, 12, -19, + 18, -18, 17, -10, -19, 14, -2, -7, -17, + -14, -4, -16, 18, -6, -18, 1, -15, -12}); + auto expV = NDArrayFactory::create( + 'c', {5, 5}, + {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 2.97683, + -7.69015, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + svd.deflation2(1, 2, 2, 1, 1, 2, 1); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test5) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expM = NDArrayFactory::create('c', {6,5}, {18.4391, 20, 19,-18, -6, 3, 6, 2, -7, -7, 0, 8,18.4391,-17, 18, -14,-15, 1, 2, 2, -3,-18, 8,-17,-19, 12, 18, 6, -2,-17}); - auto expU = NDArrayFactory::create('c', {6,6}, {-10,-16,-20,13, 20,-10, -9,-15.8359, -7,-12.2566, -4, 20, -11,-1.30158, -5,-26.1401, 12,-19, 18,-19.3068, 17, 7.15871,-19, 14, -2, -7,-17, -14, -4,-16, 18, -6,-18, 1,-15,-12}); - auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-1.08465,-13,22.7777, 2, -2,-5.64019, 8,9.65341,-6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - - ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; - svd.deflation2(1, 0, 1, 1, 0, 2, 2); - - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); - ASSERT_TRUE(expV.equalsTo(&svd._v)); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix1 = NDArrayFactory::create( + 'c', {6, 5}, + {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 18, -17, 18, + -14, -15, 1, 2, 2, -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); + auto matrix2 = NDArrayFactory::create( + 'c', {6, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20, + -11, 19, -5, -18, 12, -19, 18, -18, 17, -10, -19, 14, + -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); + auto matrix3 = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto expM = NDArrayFactory::create( + 'c', {6, 5}, {18.4391, 20, 19, -18, -6, 3, 6, 2, -7, -7, + 0, 8, 18.4391, -17, 18, -14, -15, 1, 2, 2, + -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); + auto expU = NDArrayFactory::create( + 'c', {6, 6}, + {-10, -16, -20, 13, 20, -10, -9, -15.8359, -7, -12.2566, + -4, 20, -11, -1.30158, -5, -26.1401, 12, -19, 18, -19.3068, + 17, 7.15871, -19, 14, -2, -7, -17, -14, -4, -16, + 18, -6, -18, 1, -15, -12}); + auto expV = NDArrayFactory::create( + 'c', {5, 5}, + {-18, 1, 19, -7, 1, 2, -1.08465, -13, 22.7777, 2, -2, -5.64019, 8, + 9.65341, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + svd.deflation2(1, 0, 1, 1, 0, 2, 2); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test6) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - auto matrix2 = NDArrayFactory::create('c', {2,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expM = NDArrayFactory::create('c', {6,5}, {18.4391, 20, 19,-18, -6, 3, 6, 2, -7, -7, 0, 8,18.4391,-17, 18, -14,-15, 1, 2, 2, -3,-18, 8,-17,-19, 12, 18, 6, -2,-17}); - auto expU = NDArrayFactory::create('c', {2,6}, {-10, -0.542326,-20, 20.6084,20,-10, -9, -15.8359, -7,-12.2566,-4, 20}); - auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-1.08465,-13,22.7777, 2, -2,-5.64019, 8,9.65341,-6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - - ops::helpers::SVD svd(matrix3, 4, false, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; - svd.deflation2(1, 0, 1, 1, 0, 2, 2); - - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); - ASSERT_TRUE(expV.equalsTo(&svd._v)); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix1 = NDArrayFactory::create( + 'c', {6, 5}, + {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 18, -17, 18, + -14, -15, 1, 2, 2, -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); + auto matrix2 = NDArrayFactory::create( + 'c', {2, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20}); + auto matrix3 = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto expM = NDArrayFactory::create( + 'c', {6, 5}, {18.4391, 20, 19, -18, -6, 3, 6, 2, -7, -7, + 0, 8, 18.4391, -17, 18, -14, -15, 1, 2, 2, + -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); + auto expU = + NDArrayFactory::create('c', {2, 6}, + {-10, -0.542326, -20, 20.6084, 20, -10, -9, + -15.8359, -7, -12.2566, -4, 20}); + auto expV = NDArrayFactory::create( + 'c', {5, 5}, + {-18, 1, 19, -7, 1, 2, -1.08465, -13, 22.7777, 2, -2, -5.64019, 8, + 9.65341, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + + ops::helpers::SVD svd(matrix3, 4, false, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + svd.deflation2(1, 0, 1, 1, 0, 2, 2); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test7) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - - auto expM = NDArrayFactory::create('c', {6,5}, {12, 20, 19,-18, -6, 3, 6, 2, -7, -7, 14, 8,19.6977,-17, 18, -14,-15, 1, 2, 2, -3,-18, 0,-17, 0, 12, 18, 6, -2,-17}); - auto expU = NDArrayFactory::create('c', {6,6}, {-10, -16,-20, 13, 20,-10, -9,-9.03658, -7,-17.8701, -4, 20, -11, 10.0519, -5,-24.1652, 12,-19, 18, -20.51, 17,-1.82762,-19, 14, -2,-12.0826,-17,-9.95039, -4,-16, 18, -6,-18, 1,-15,-12}); - auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19,-7, 1, 2,-18,-13,14, 2, -2,-11, 8, 2,-6, -3, -8, 8,-2, 7, 16, 15, -3, 7, 0}); - - ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; - svd.deflation(1, 3, 1, 1, 2, 1); - - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); - ASSERT_TRUE(expV.equalsTo(&svd._v)); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix1 = NDArrayFactory::create( + 'c', {6, 5}, + {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 18, -17, 18, + -14, -15, 1, 2, 2, -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); + auto matrix2 = NDArrayFactory::create( + 'c', {6, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20, + -11, 19, -5, -18, 12, -19, 18, -18, 17, -10, -19, 14, + -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); + auto matrix3 = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + + auto expM = NDArrayFactory::create( + 'c', {6, 5}, + {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 19.6977, -17, 18, + -14, -15, 1, 2, 2, -3, -18, 0, -17, 0, 12, 18, 6, -2, -17}); + auto expU = NDArrayFactory::create( + 'c', {6, 6}, + {-10, -16, -20, 13, 20, -10, -9, -9.03658, -7, + -17.8701, -4, 20, -11, 10.0519, -5, -24.1652, 12, -19, + 18, -20.51, 17, -1.82762, -19, 14, -2, -12.0826, -17, + -9.95039, -4, -16, 18, -6, -18, 1, -15, -12}); + auto expV = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + svd.deflation(1, 3, 1, 1, 2, 1); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test8) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - - auto expM = NDArrayFactory::create('c', {6,5}, {12, 20,19,-18, -6, 3, 6, 2, -7, -7, 14,-15, 2,-17, 18, -14, 8, 1, 18, 2, -3,-18, 8,-17,-19, 12, 18, 6, -2,-17}); - auto expU = NDArrayFactory::create('c', {6,6}, {-10,-20,-16, 13, 20,-10, -9, -7, -1,-20, -4, 20, -11, -5, 19,-18, 12,-19, 18, 17,-18,-10,-19, 14, -2, -7,-17,-14, -4,-16, 18, -6,-18, 1,-15,-12}); - auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19,-7, 1, 2,-18,-13, 2,14, -2,-11, 8,-6, 2, -3, -8, 8, 7,-2, 16, 15, -3, 7, 0}); - - ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; - svd.deflation(0, 2, 2, 1, 2, 1); - - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); - ASSERT_TRUE(expV.equalsTo(&svd._v)); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix1 = NDArrayFactory::create( + 'c', {6, 5}, + {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 18, -17, 18, + -14, -15, 1, 2, 2, -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); + auto matrix2 = NDArrayFactory::create( + 'c', {6, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20, + -11, 19, -5, -18, 12, -19, 18, -18, 17, -10, -19, 14, + -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); + auto matrix3 = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + + auto expM = NDArrayFactory::create( + 'c', {6, 5}, + {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, -15, 2, -17, 18, + -14, 8, 1, 18, 2, -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); + auto expU = NDArrayFactory::create( + 'c', {6, 6}, {-10, -20, -16, 13, 20, -10, -9, -7, -1, -20, -4, 20, + -11, -5, 19, -18, 12, -19, 18, 17, -18, -10, -19, 14, + -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); + auto expV = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 2, 14, -2, -11, 8, + -6, 2, -3, -8, 8, 7, -2, 16, 15, -3, 7, 0}); + + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + svd.deflation(0, 2, 2, 1, 2, 1); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test9) { - - #ifdef __CUDABLAS__ - return; - #endif - auto col0 = NDArrayFactory::create('c', {10,1}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,14}); - auto diag = NDArrayFactory::create('c', {10,1}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2}); - auto permut = NDArrayFactory::create('c', {1,10}, {8 ,1 ,4 ,0, 5 ,2 ,9 ,3 ,7 ,6}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - - auto expSingVals = NDArrayFactory::create('c', {10,1}, {-2, 15.304323, 11.2, -1, 1.73489, -12, -15.3043, -12.862, 5.6, 41.4039}); - auto expShifts = NDArrayFactory::create('c', {10,1}, {1, 19, 19, 1, 2, -18, -18, -13, 2, 2}); - auto expMus = NDArrayFactory::create('c', {10,1}, {-3, -3.695677, -7.8, -2, -0.265108, 6, 2.69568, 0.138048, 3.6, 39.4039}); - - auto singVals = NDArrayFactory::create('c', {10,1}); - auto shifts = NDArrayFactory::create('c', {10,1}); - auto mus = NDArrayFactory::create('c', {10,1}); - - ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd.calcSingVals(col0, diag, permut, singVals, shifts, mus); - - ASSERT_TRUE(expSingVals.equalsTo(&singVals)); - ASSERT_TRUE(expShifts.equalsTo(&shifts)); - ASSERT_TRUE(expMus.equalsTo(&mus)); +#ifdef __CUDABLAS__ + return; +#endif + auto col0 = NDArrayFactory::create( + 'c', {10, 1}, {12, 20, 19, -18, -6, 3, 6, 2, -7, 14}); + auto diag = NDArrayFactory::create( + 'c', {10, 1}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2}); + auto permut = NDArrayFactory::create('c', {1, 10}, + {8, 1, 4, 0, 5, 2, 9, 3, 7, 6}); + auto matrix3 = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + + auto expSingVals = NDArrayFactory::create( + 'c', {10, 1}, + {-2, 15.304323, 11.2, -1, 1.73489, -12, -15.3043, -12.862, 5.6, 41.4039}); + auto expShifts = NDArrayFactory::create( + 'c', {10, 1}, {1, 19, 19, 1, 2, -18, -18, -13, 2, 2}); + auto expMus = NDArrayFactory::create( + 'c', {10, 1}, + {-3, -3.695677, -7.8, -2, -0.265108, 6, 2.69568, 0.138048, 3.6, 39.4039}); + + auto singVals = NDArrayFactory::create('c', {10, 1}); + auto shifts = NDArrayFactory::create('c', {10, 1}); + auto mus = NDArrayFactory::create('c', {10, 1}); + + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd.calcSingVals(col0, diag, permut, singVals, shifts, mus); + + ASSERT_TRUE(expSingVals.equalsTo(&singVals)); + ASSERT_TRUE(expShifts.equalsTo(&shifts)); + ASSERT_TRUE(expMus.equalsTo(&mus)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test10) { +#ifdef __CUDABLAS__ + return; +#endif + auto singVals = NDArrayFactory::create('c', {4, 1}, {1, 1, 1, 1}); + auto col0 = NDArrayFactory::create('c', {4, 1}, {1, 1, 1, 1}); + auto diag = NDArrayFactory::create('c', {4, 1}, {5, 7, -13, 14}); + auto permut = NDArrayFactory::create('c', {1, 4}, {0, 2, 3, 1}); + auto mus = NDArrayFactory::create('c', {4, 1}, {4, 1, 4, 6}); + auto shifts = NDArrayFactory::create('c', {4, 1}, {4, 2, 5, 6}); + auto matrix3 = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - #ifdef __CUDABLAS__ - return; - #endif - auto singVals = NDArrayFactory::create('c', {4,1}, {1 ,1 ,1 ,1}); - auto col0 = NDArrayFactory::create('c', {4,1}, {1 ,1 ,1 ,1}); - auto diag = NDArrayFactory::create('c', {4,1}, {5 ,7 ,-13 ,14}); - auto permut = NDArrayFactory::create('c', {1,4}, {0 ,2 ,3 ,1 }); - auto mus = NDArrayFactory::create('c', {4,1}, {4,1,4,6}); - auto shifts = NDArrayFactory::create('c', {4,1}, {4,2,5,6}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - - auto expZhat = NDArrayFactory::create('c', {4,1}, {0, 0.278208, 72.501953, 0}); + auto expZhat = + NDArrayFactory::create('c', {4, 1}, {0, 0.278208, 72.501953, 0}); - auto zhat = NDArrayFactory::create('c', {4,1}); + auto zhat = NDArrayFactory::create('c', {4, 1}); - ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd.perturb(col0, diag, permut, singVals, shifts, mus, zhat); + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd.perturb(col0, diag, permut, singVals, shifts, mus, zhat); - ASSERT_NEAR(expZhat.e(1), zhat.e(1), EPS); - ASSERT_NEAR(expZhat.e(2), zhat.e(2), EPS); + ASSERT_NEAR(expZhat.e(1), zhat.e(1), EPS); + ASSERT_NEAR(expZhat.e(2), zhat.e(2), EPS); } - /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test11) { - - #ifdef __CUDABLAS__ - return; - #endif - auto singVals = NDArrayFactory::create('c', {4,1}, {1 ,1 ,1 ,1}); - auto zhat = NDArrayFactory::create('c', {4,1}, {2 ,1 ,2 ,1}); - auto diag = NDArrayFactory::create('c', {4,1}, {5 ,7 ,-13 ,14}); - auto permut = NDArrayFactory::create('c', {1,4}, {0 ,2 ,3 ,1 }); - auto mus = NDArrayFactory::create('c', {4,1}, {4,1,4,6}); - auto shifts = NDArrayFactory::create('c', {4,1}, {4,2,5,6}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - - auto expU = NDArrayFactory::create('c', {5,5}, {-0.662161, 0.980399,-0.791469,-0.748434, 0, -0.744931, 0.183825,-0.593602,-0.392928, 0, 0.0472972, 0.061275,0.0719517, 0.104781, 0, 0.0662161,0.0356509, 0.126635, 0.523904, 0, 0, 0, 0, 0, 1}); - auto expV = NDArrayFactory::create('c', {4,4}, {-0.745259,-0.965209, -0.899497, -0.892319, -0.652102, 0.21114, -0.39353, -0.156156, -0.0768918,-0.130705,-0.0885868,-0.0773343, 0.115929,0.0818966, 0.167906, 0.416415}); - auto U = NDArrayFactory::create('c', {5,5}); - auto V = NDArrayFactory::create('c', {4,4}); - - - ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd.calcSingVecs(zhat, diag,permut, singVals, shifts, mus, U, V); - - ASSERT_TRUE(expU.equalsTo(&U)); - ASSERT_TRUE(expV.equalsTo(&V)); - +#ifdef __CUDABLAS__ + return; +#endif + auto singVals = NDArrayFactory::create('c', {4, 1}, {1, 1, 1, 1}); + auto zhat = NDArrayFactory::create('c', {4, 1}, {2, 1, 2, 1}); + auto diag = NDArrayFactory::create('c', {4, 1}, {5, 7, -13, 14}); + auto permut = NDArrayFactory::create('c', {1, 4}, {0, 2, 3, 1}); + auto mus = NDArrayFactory::create('c', {4, 1}, {4, 1, 4, 6}); + auto shifts = NDArrayFactory::create('c', {4, 1}, {4, 2, 5, 6}); + auto matrix3 = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + + auto expU = NDArrayFactory::create( + 'c', {5, 5}, {-0.662161, 0.980399, -0.791469, -0.748434, 0, + -0.744931, 0.183825, -0.593602, -0.392928, 0, + 0.0472972, 0.061275, 0.0719517, 0.104781, 0, + 0.0662161, 0.0356509, 0.126635, 0.523904, 0, + 0, 0, 0, 0, 1}); + auto expV = NDArrayFactory::create( + 'c', {4, 4}, + {-0.745259, -0.965209, -0.899497, -0.892319, -0.652102, 0.21114, -0.39353, + -0.156156, -0.0768918, -0.130705, -0.0885868, -0.0773343, 0.115929, + 0.0818966, 0.167906, 0.416415}); + auto U = NDArrayFactory::create('c', {5, 5}); + auto V = NDArrayFactory::create('c', {4, 4}); + + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd.calcSingVecs(zhat, diag, permut, singVals, shifts, mus, U, V); + + ASSERT_TRUE(expU.equalsTo(&U)); + ASSERT_TRUE(expV.equalsTo(&V)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test12) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix1 = NDArrayFactory::create('c', {6,5}, {-2 ,-3 ,2 ,1 ,0 ,0 ,-4 ,5 ,-2 ,-3 ,-4 ,0 ,5 ,-1 ,-5 ,-3 ,-5 ,3 ,3 ,3 ,-5 ,5 ,-5 ,0 ,2 ,-2 ,-3 ,-4 ,-5 ,-3}); - auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto matrix4 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); - - auto expSingVals = NDArrayFactory::create('c', {4,1}, {8.43282, 5, 2.3, 1.10167}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.401972,0, 0.206791, 0.891995,0, 0,1, 0, 0,0, 0.816018,0,-0.522818,-0.246529,0, -0.415371,0,-0.826982, 0.378904,0, 0,0, 0, 0,1}); - auto expV = NDArrayFactory::create('c', {4,4}, {-0.951851,0,-0.133555,-0.275939, 0,1, 0, 0, 0.290301,0,-0.681937,-0.671333, -0.098513,0,-0.719114, 0.687873}); - - ops::helpers::SVD svd(matrix4, 4, true, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; - NDArray U, singVals, V; - svd.calcBlockSVD(1, 4, U, singVals, V); - - ASSERT_TRUE(expSingVals.equalsTo(&singVals)); - ASSERT_TRUE(expU.equalsTo(&U)); - ASSERT_TRUE(expV.equalsTo(&V)); - - ASSERT_TRUE(expSingVals.isSameShapeStrict(singVals)); - ASSERT_TRUE(expU.isSameShapeStrict(U)); - ASSERT_TRUE(expV.isSameShapeStrict(V)); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix1 = NDArrayFactory::create( + 'c', {6, 5}, {-2, -3, 2, 1, 0, 0, -4, 5, -2, -3, -4, 0, 5, -1, -5, + -3, -5, 3, 3, 3, -5, 5, -5, 0, 2, -2, -3, -4, -5, -3}); + auto matrix2 = NDArrayFactory::create( + 'c', {6, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20, + -11, 19, -5, -18, 12, -19, 18, -18, 17, -10, -19, 14, + -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); + auto matrix3 = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto matrix4 = NDArrayFactory::create( + 'c', {5, 5}, {3, -8, 5, 7, -8, 4, -19, -12, -4, -5, -11, 19, -2, + -7, 1, 16, -5, 10, 19, -19, 0, -20, 0, -8, -13}); + + auto expSingVals = + NDArrayFactory::create('c', {4, 1}, {8.43282, 5, 2.3, 1.10167}); + auto expU = NDArrayFactory::create( + 'c', {5, 5}, + {0.401972, 0, 0.206791, 0.891995, 0, 0, 1, + 0, 0, 0, 0.816018, 0, -0.522818, -0.246529, + 0, -0.415371, 0, -0.826982, 0.378904, 0, 0, + 0, 0, 0, 1}); + auto expV = NDArrayFactory::create( + 'c', {4, 4}, + {-0.951851, 0, -0.133555, -0.275939, 0, 1, 0, 0, 0.290301, 0, -0.681937, + -0.671333, -0.098513, 0, -0.719114, 0.687873}); + + ops::helpers::SVD svd(matrix4, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + NDArray U, singVals, V; + svd.calcBlockSVD(1, 4, U, singVals, V); + + ASSERT_TRUE(expSingVals.equalsTo(&singVals)); + ASSERT_TRUE(expU.equalsTo(&U)); + ASSERT_TRUE(expV.equalsTo(&V)); + + ASSERT_TRUE(expSingVals.isSameShapeStrict(singVals)); + ASSERT_TRUE(expU.isSameShapeStrict(U)); + ASSERT_TRUE(expV.isSameShapeStrict(V)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test13) { +#ifdef __CUDABLAS__ + return; +#endif + NDArray matrix1('c', {6, 5}, {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, + 14, 8, 18, -17, 18, -14, -15, 1, 2, 2, + -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); - #ifdef __CUDABLAS__ - return; - #endif - NDArray matrix1('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - - auto expQR = NDArrayFactory::create('c', {6,5}, {-37.054 , 0.323852 , 8.04231 , -22.9395 ,-13.089, 0.105164, 32.6021, 6.42277, -0.262898,-1.58766, 0.140218, -0.485058, 29.2073, -9.92301,-23.7111, -0.262909,-0.00866538, 0.103467, 8.55831,-1.86455, -0.315491, 0.539207, 0.40754,-0.0374124,-7.10401, 0.315491, 0.385363,-0.216459, -0.340008,0.628595}); - auto expCoeffs = NDArrayFactory::create('c', {1,5}, {1.53975, 1.19431, 1.63446, 1.7905, 1.43356}); - auto expPermut = NDArrayFactory::create('c', {5,5}, {0,0,0,1,0, 1,0,0,0,0, 0,0,0,0,1, 0,0,1,0,0, 0,1,0,0,0}); - - ops::helpers::HHcolPivQR qr(matrix1); + auto expQR = NDArrayFactory::create( + 'c', {6, 5}, + {-37.054, 0.323852, 8.04231, -22.9395, -13.089, 0.105164, + 32.6021, 6.42277, -0.262898, -1.58766, 0.140218, -0.485058, + 29.2073, -9.92301, -23.7111, -0.262909, -0.00866538, 0.103467, + 8.55831, -1.86455, -0.315491, 0.539207, 0.40754, -0.0374124, + -7.10401, 0.315491, 0.385363, -0.216459, -0.340008, 0.628595}); + auto expCoeffs = NDArrayFactory::create( + 'c', {1, 5}, {1.53975, 1.19431, 1.63446, 1.7905, 1.43356}); + auto expPermut = NDArrayFactory::create( + 'c', {5, 5}, {0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0}); - ASSERT_TRUE(expQR.equalsTo(&qr._qr)); - ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); - ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); + ops::helpers::HHcolPivQR qr(matrix1); - ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); - ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); - ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); + ASSERT_TRUE(expQR.equalsTo(&qr._qr)); + ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); + ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test14) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix1 = NDArrayFactory::create( + 'c', {5, 6}, + {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 18, -17, 18, + -14, -15, 1, 2, 2, -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix1 = NDArrayFactory::create('c', {5,6}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - - auto expQR = NDArrayFactory::create('c', {5,6}, {-32.665, -4.95944, -8.26574, 7.22487, 16.5927, 11.7251, -0.135488, -29.0586, 10.9776, -14.6886, 4.18841, 20.7116, 0.348399, 0.323675, 25.5376, 1.64324, 9.63959, -9.0238, -0.0580664,0.0798999,-0.0799029, 19.5281,-4.97736, 16.0969, 0.348399,-0.666783, 0.0252425,0.0159188, 10.6978,-4.69198}); - auto expCoeffs = NDArrayFactory::create('c', {1,5}, {1.58166, 1.28555, 1.98605, 1.99949, 0}); - auto expPermut = NDArrayFactory::create('c', {6,6}, {0,1,0,0,0,0, 0,0,1,0,0,0, 1,0,0,0,0,0, 0,0,0,0,0,1, 0,0,0,0,1,0, 0,0,0,1,0,0}); + auto expQR = NDArrayFactory::create( + 'c', {5, 6}, + {-32.665, -4.95944, -8.26574, 7.22487, 16.5927, 11.7251, + -0.135488, -29.0586, 10.9776, -14.6886, 4.18841, 20.7116, + 0.348399, 0.323675, 25.5376, 1.64324, 9.63959, -9.0238, + -0.0580664, 0.0798999, -0.0799029, 19.5281, -4.97736, 16.0969, + 0.348399, -0.666783, 0.0252425, 0.0159188, 10.6978, -4.69198}); + auto expCoeffs = NDArrayFactory::create( + 'c', {1, 5}, {1.58166, 1.28555, 1.98605, 1.99949, 0}); + auto expPermut = NDArrayFactory::create( + 'c', {6, 6}, {0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0}); - ops::helpers::HHcolPivQR qr(matrix1); + ops::helpers::HHcolPivQR qr(matrix1); - ASSERT_TRUE(expQR.equalsTo(&qr._qr)); - ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); - ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); + ASSERT_TRUE(expQR.equalsTo(&qr._qr)); + ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); + ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); - ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); - ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); - ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); } - /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test15) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix1 = NDArrayFactory::create( + 'c', {6, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20, + -11, 19, -5, -18, 12, -19, 18, -18, 17, -10, -19, 14, + -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix1 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - - auto expQR = NDArrayFactory::create('c', {6,6}, {38.1707, -3.03898, 5.16103, 23.0805, -7.57126, -13.885, -0.41519, 34.3623, 3.77403, 2.62327, -8.17784, 9.10312, 0.394431, 0.509952,-30.2179, -6.78341, 12.8421, 28.5491, -0.290633, 0.111912,0.450367, 28.1139, 15.5195, 2.60562, 0.332152, 0.405161,0.308163,0.0468127, 22.294,-2.94931, 0.249114,0.0627956,0.657873, 0.76767,-0.752594,-7.46986}); - auto expCoeffs = NDArrayFactory::create('c', {1,6}, {1.26198, 1.38824, 1.15567, 1.25667, 1.27682, 0}); - auto expPermut = NDArrayFactory::create('c', {6,6}, {0,0,1,0,0,0, 0,0,0,0,1,0, 0,0,0,1,0,0, 0,1,0,0,0,0, 0,0,0,0,0,1, 1,0,0,0,0,0}); + auto expQR = NDArrayFactory::create( + 'c', {6, 6}, + {38.1707, -3.03898, 5.16103, 23.0805, -7.57126, -13.885, + -0.41519, 34.3623, 3.77403, 2.62327, -8.17784, 9.10312, + 0.394431, 0.509952, -30.2179, -6.78341, 12.8421, 28.5491, + -0.290633, 0.111912, 0.450367, 28.1139, 15.5195, 2.60562, + 0.332152, 0.405161, 0.308163, 0.0468127, 22.294, -2.94931, + 0.249114, 0.0627956, 0.657873, 0.76767, -0.752594, -7.46986}); + auto expCoeffs = NDArrayFactory::create( + 'c', {1, 6}, {1.26198, 1.38824, 1.15567, 1.25667, 1.27682, 0}); + auto expPermut = NDArrayFactory::create( + 'c', {6, 6}, {0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0}); - ops::helpers::HHcolPivQR qr(matrix1); + ops::helpers::HHcolPivQR qr(matrix1); - ASSERT_TRUE(expQR.equalsTo(&qr._qr)); - ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); - ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); + ASSERT_TRUE(expQR.equalsTo(&qr._qr)); + ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); + ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); - ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); - ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); - ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); } - /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test1) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix3 = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto left = NDArrayFactory::create('c', {2, 2}); + auto right = NDArrayFactory::create('c', {2, 2}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto left = NDArrayFactory::create('c', {2,2}); - auto right = NDArrayFactory::create('c', {2,2}); - - auto expLeft = NDArrayFactory::create('c', {2,2}, {0.972022, 0.23489, -0.23489, 0.972022}); - auto expRight = NDArrayFactory::create('c', {2,2}, {0.827657, 0.561234, -0.561234, 0.827657}); + auto expLeft = NDArrayFactory::create( + 'c', {2, 2}, {0.972022, 0.23489, -0.23489, 0.972022}); + auto expRight = NDArrayFactory::create( + 'c', {2, 2}, {0.827657, 0.561234, -0.561234, 0.827657}); - ops::helpers::JacobiSVD::svd2x2(matrix3, 1, 3, left, right); + ops::helpers::JacobiSVD::svd2x2(matrix3, 1, 3, left, right); - ASSERT_TRUE(expLeft.equalsTo(&left)); - ASSERT_TRUE(expRight.equalsTo(&right)); + ASSERT_TRUE(expLeft.equalsTo(&left)); + ASSERT_TRUE(expRight.equalsTo(&right)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test2) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto matrix4 = NDArrayFactory::create('c', {5,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19}); - auto matrix5 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); - - auto exp3 = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, -0.609208,19.6977, 8.63044,-11.9811,-4.67059, -2, -11, 8, 2, -6, 3.55371, 0,-12.5903, 7.51356, -5.5844, 16, 15, -3, 7, 0}); - auto exp4 = NDArrayFactory::create('c', {5,5}, {12, -10.9657,19,24.5714, -6, 3, -2.6399, 2,8.83351, -7, 14,-0.406138,18,18.7839, 18, -14, 12.8949, 1,-7.9197, 2, -3, 23.353, 8, 8.2243,-19}); - auto exp5 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); - - ops::helpers::JacobiSVD jac(matrix3, true, true, true); - jac._m = matrix3; - jac._u = matrix4; - jac._v = matrix5; - - double maxElem; - bool result = jac.isBlock2x2NotDiag(matrix3, 1, 3, maxElem); - - // ASSERT_NEAR(maxElem, 19.69772, 1e-5); - ASSERT_TRUE(exp3.equalsTo(&matrix3)); - ASSERT_TRUE(exp4.equalsTo(&jac._u)); - ASSERT_TRUE(exp5.equalsTo(&jac._v)); - - ASSERT_TRUE(result); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix3 = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto matrix4 = NDArrayFactory::create( + 'c', {5, 5}, {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 18, + -17, 18, -14, -15, 1, 2, 2, -3, -18, 8, -17, -19}); + auto matrix5 = NDArrayFactory::create( + 'c', {5, 5}, {3, -8, 5, 7, -8, 4, -19, -12, -4, -5, -11, 19, -2, + -7, 1, 16, -5, 10, 19, -19, 0, -20, 0, -8, -13}); + + auto exp3 = NDArrayFactory::create( + 'c', {5, 5}, + {-18, 1, 19, -7, 1, -0.609208, 19.6977, 8.63044, -11.9811, + -4.67059, -2, -11, 8, 2, -6, 3.55371, 0, -12.5903, + 7.51356, -5.5844, 16, 15, -3, 7, 0}); + auto exp4 = NDArrayFactory::create( + 'c', {5, 5}, + {12, -10.9657, 19, 24.5714, -6, 3, -2.6399, 2, 8.83351, -7, + 14, -0.406138, 18, 18.7839, 18, -14, 12.8949, 1, -7.9197, 2, + -3, 23.353, 8, 8.2243, -19}); + auto exp5 = NDArrayFactory::create( + 'c', {5, 5}, {3, -8, 5, 7, -8, 4, -19, -12, -4, -5, -11, 19, -2, + -7, 1, 16, -5, 10, 19, -19, 0, -20, 0, -8, -13}); + + ops::helpers::JacobiSVD jac(matrix3, true, true, true); + jac._m = matrix3; + jac._u = matrix4; + jac._v = matrix5; + + double maxElem; + bool result = jac.isBlock2x2NotDiag(matrix3, 1, 3, maxElem); + + // ASSERT_NEAR(maxElem, 19.69772, 1e-5); + ASSERT_TRUE(exp3.equalsTo(&matrix3)); + ASSERT_TRUE(exp4.equalsTo(&jac._u)); + ASSERT_TRUE(exp5.equalsTo(&jac._v)); + + ASSERT_TRUE(result); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test3) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto rotation = NDArrayFactory::create( + 'c', {2, 2}, + {0.2, math::nd4j_sqrt(0.6), + -math::nd4j_sqrt(0.6), 0.2}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - - auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, -1.14919,-12.1206,3.59677, 4.34919,-4.24758, -1.94919, 11.7427,11.6698,-10.4444,-2.74919, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto expected = NDArrayFactory::create( + 'c', {5, 5}, + {-18, 1, 19, -7, 1, -1.14919, -12.1206, + 3.59677, 4.34919, -4.24758, -1.94919, 11.7427, 11.6698, -10.4444, + -2.74919, -3, -8, 8, -2, 7, 16, + 15, -3, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnLeft(1, 2, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnLeft(1, 2, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test4) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto rotation = NDArrayFactory::create( + 'c', {2, 2}, + {0.2, math::nd4j_sqrt(0.6), + -math::nd4j_sqrt(0.6), 0.2}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - - auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 1.94919, 4.92056,-8.79677,1.25081, 5.04758, 1.14919,-16.1427,-8.46976,11.2444,0.349193, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto expected = NDArrayFactory::create( + 'c', {5, 5}, + {-18, 1, 19, -7, 1, 1.94919, 4.92056, + -8.79677, 1.25081, 5.04758, 1.14919, -16.1427, -8.46976, 11.2444, + 0.349193, -3, -8, 8, -2, 7, 16, + 15, -3, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnLeft(2, 1, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnLeft(2, 1, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test5) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto rotation = NDArrayFactory::create( + 'c', {2, 2}, + {0.2, math::nd4j_sqrt(0.6), + -math::nd4j_sqrt(0.6), 0.2}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - - auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, 1.14919,6.32056,-4.59677,-1.14919, 3.44758, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto expected = NDArrayFactory::create( + 'c', {5, 5}, + {-18, 1, 19, -7, 1, 2, -18, -13, 14, + 2, 1.14919, 6.32056, -4.59677, -1.14919, 3.44758, -3, -8, 8, + -2, 7, 16, 15, -3, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnLeft(2, 2, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnLeft(2, 2, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test6) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto rotation = NDArrayFactory::create( + 'c', {2, 2}, + {0.2, math::nd4j_sqrt(0.6), + -math::nd4j_sqrt(0.6), 0.2}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - - auto expected = NDArrayFactory::create('c', {5,5}, {-18,-14.5173, 4.5746,-7, 1, 2, 6.46976,-16.5427,14, 2, -2,-8.39677,-6.92056, 2,-6, -3,-7.79677,-4.59677,-2, 7, 16, 5.32379, 11.019, 7, 0}); + auto expected = NDArrayFactory::create( + 'c', {5, 5}, + {-18, -14.5173, 4.5746, -7, 1, 2, 6.46976, -16.5427, 14, 2, + -2, -8.39677, -6.92056, 2, -6, -3, -7.79677, -4.59677, -2, 7, + 16, 5.32379, 11.019, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnRight(1, 2, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnRight(1, 2, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test7) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto rotation = NDArrayFactory::create( + 'c', {2, 2}, + {0.2, math::nd4j_sqrt(0.6), + -math::nd4j_sqrt(0.6), 0.2}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - - auto expected = NDArrayFactory::create('c', {5,5}, {-18, 14.9173, 3.0254,-7, 1, 2,-13.6698,11.3427,14, 2, -2, 3.99677,10.1206, 2,-6, -3, 4.59677,7.79677,-2, 7, 16, 0.67621,-12.219, 7, 0}); + auto expected = NDArrayFactory::create( + 'c', {5, 5}, {-18, 14.9173, 3.0254, -7, 1, 2, -13.6698, 11.3427, 14, 2, + -2, 3.99677, 10.1206, 2, -6, -3, 4.59677, 7.79677, -2, 7, + 16, 0.67621, -12.219, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnRight(2, 1, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnRight(2, 1, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } ////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test8) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto rotation = NDArrayFactory::create( + 'c', {2, 2}, + {0.2, math::nd4j_sqrt(0.6), + -math::nd4j_sqrt(0.6), 0.2}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - - auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 18.5173,-7, 1, 2,-18,-12.6698,14, 2, -2,-11, 7.79677, 2,-6, -3, -8, 7.79677,-2, 7, 16, 15,-2.92379, 7, 0}); + auto expected = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 18.5173, -7, 1, 2, -18, -12.6698, 14, 2, + -2, -11, 7.79677, 2, -6, -3, -8, 7.79677, -2, 7, + 16, 15, -2.92379, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnRight(2, 2, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnRight(2, 2, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test9) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - - auto expS = NDArrayFactory::create('c', {5,1}, {35.7975, 29.1924, 11.1935, 9.2846, 6.77071}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.744855,0.0686476, 0.079663,0.0889877, 0.65285, -0.386297,-0.760021,0.00624688, 0.156774, 0.498522, 0.186491,-0.322427, 0.773083,-0.468826,-0.209299, 0.246053,-0.215594, 0.240942, 0.821793,-0.399475, -0.447933, 0.516928, 0.581295, 0.269001, 0.349106}); - auto expV = NDArrayFactory::create('c', {5,5}, {-0.627363, 0.23317, 0.501211, 0.160272, -0.524545, -0.0849394, 0.917171,-0.155876,-0.0124053, 0.356555, 0.66983, 0.182569, 0.696897, 0.179807,0.000864568, -0.387647, -0.264316, 0.416597, 0.0941014, 0.772955, 0.0160818,-0.0351459,-0.255484, 0.965905, 0.0161524}); + auto expS = NDArrayFactory::create( + 'c', {5, 1}, {35.7975, 29.1924, 11.1935, 9.2846, 6.77071}); + auto expU = NDArrayFactory::create( + 'c', {5, 5}, {0.744855, 0.0686476, 0.079663, 0.0889877, 0.65285, + -0.386297, -0.760021, 0.00624688, 0.156774, 0.498522, + 0.186491, -0.322427, 0.773083, -0.468826, -0.209299, + 0.246053, -0.215594, 0.240942, 0.821793, -0.399475, + -0.447933, 0.516928, 0.581295, 0.269001, 0.349106}); + auto expV = NDArrayFactory::create( + 'c', {5, 5}, {-0.627363, 0.23317, 0.501211, 0.160272, -0.524545, + -0.0849394, 0.917171, -0.155876, -0.0124053, 0.356555, + 0.66983, 0.182569, 0.696897, 0.179807, 0.000864568, + -0.387647, -0.264316, 0.416597, 0.0941014, 0.772955, + 0.0160818, -0.0351459, -0.255484, 0.965905, 0.0161524}); - ops::helpers::JacobiSVD jac(matrix, true, true, true); + ops::helpers::JacobiSVD jac(matrix, true, true, true); - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test10) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - - auto expS = NDArrayFactory::create('c', {5,1}, {35.7975, 29.1924, 11.1935, 9.2846, 6.77071}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.744855,0.0686476, 0.079663,0.0889877, 0.65285, -0.386297,-0.760021,0.00624688, 0.156774, 0.498522, 0.186491,-0.322427, 0.773083,-0.468826,-0.209299, 0.246053,-0.215594, 0.240942, 0.821793,-0.399475, -0.447933, 0.516928, 0.581295, 0.269001, 0.349106}); - auto expV = NDArrayFactory::create('c', {5,5}, {-0.627363, 0.23317, 0.501211, 0.160272, -0.524545, -0.0849394, 0.917171,-0.155876,-0.0124053, 0.356555, 0.66983, 0.182569, 0.696897, 0.179807,0.000864568, -0.387647, -0.264316, 0.416597, 0.0941014, 0.772955, 0.0160818,-0.0351459,-0.255484, 0.965905, 0.0161524}); + auto expS = NDArrayFactory::create( + 'c', {5, 1}, {35.7975, 29.1924, 11.1935, 9.2846, 6.77071}); + auto expU = NDArrayFactory::create( + 'c', {5, 5}, {0.744855, 0.0686476, 0.079663, 0.0889877, 0.65285, + -0.386297, -0.760021, 0.00624688, 0.156774, 0.498522, + 0.186491, -0.322427, 0.773083, -0.468826, -0.209299, + 0.246053, -0.215594, 0.240942, 0.821793, -0.399475, + -0.447933, 0.516928, 0.581295, 0.269001, 0.349106}); + auto expV = NDArrayFactory::create( + 'c', {5, 5}, {-0.627363, 0.23317, 0.501211, 0.160272, -0.524545, + -0.0849394, 0.917171, -0.155876, -0.0124053, 0.356555, + 0.66983, 0.182569, 0.696897, 0.179807, 0.000864568, + -0.387647, -0.264316, 0.416597, 0.0941014, 0.772955, + 0.0160818, -0.0351459, -0.255484, 0.965905, 0.0161524}); - ops::helpers::JacobiSVD jac(matrix, true, true, false); + ops::helpers::JacobiSVD jac(matrix, true, true, false); - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test11) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {6,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); - - auto expS = NDArrayFactory::create('c', {5,1}, {36.27, 32.1997, 15.9624, 10.6407, 6.9747}); - auto expU = NDArrayFactory::create('c', {6,5}, {0.720125,-0.149734, 0.227784,-0.0288531, 0.595353, -0.509487,-0.567298, -0.237169,-0.0469077, 0.38648, 0.120912, -0.32916,-0.0202265, 0.921633, -0.153994, 0.180033,-0.294831, 0.357867, -0.194106, -0.646595, -0.354033, 0.521937, 0.556566, 0.305582, 0.211013, -0.222425,-0.433662, 0.673515, -0.128465, 0.099309}); - auto expV = NDArrayFactory::create('c', {5,5}, {-0.581609, 0.315327,0.333158, 0.34476, -0.576582, 0.117364, 0.889461,0.175174,-0.166603, 0.369651, 0.643246,-0.0899117,0.613288, 0.442462,-0.0790943, -0.480818, -0.264384,0.395122, 0.223126, 0.702145, -0.0548207, -0.177325,0.571031,-0.779632, -0.1779}); - - ops::helpers::JacobiSVD jac(matrix, true, true, false); - - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {6, 5}, + {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, 2, -6, + -3, -8, 8, -2, 7, 16, 15, -3, 7, 0, 3, -11, 2, 12, 10}); + + auto expS = NDArrayFactory::create( + 'c', {5, 1}, {36.27, 32.1997, 15.9624, 10.6407, 6.9747}); + auto expU = NDArrayFactory::create( + 'c', {6, 5}, + {0.720125, -0.149734, 0.227784, -0.0288531, 0.595353, -0.509487, + -0.567298, -0.237169, -0.0469077, 0.38648, 0.120912, -0.32916, + -0.0202265, 0.921633, -0.153994, 0.180033, -0.294831, 0.357867, + -0.194106, -0.646595, -0.354033, 0.521937, 0.556566, 0.305582, + 0.211013, -0.222425, -0.433662, 0.673515, -0.128465, 0.099309}); + auto expV = NDArrayFactory::create( + 'c', {5, 5}, {-0.581609, 0.315327, 0.333158, 0.34476, -0.576582, + 0.117364, 0.889461, 0.175174, -0.166603, 0.369651, + 0.643246, -0.0899117, 0.613288, 0.442462, -0.0790943, + -0.480818, -0.264384, 0.395122, 0.223126, 0.702145, + -0.0548207, -0.177325, 0.571031, -0.779632, -0.1779}); + + ops::helpers::JacobiSVD jac(matrix, true, true, false); + + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test12) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {6,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); - - auto expS = NDArrayFactory::create('c', {5,1}, {36.27, 32.1997, 15.9624, 10.6407, 6.9747}); - auto expU = NDArrayFactory::create('c', {6,6}, {0.720125,-0.149734, 0.227784,-0.0288531, 0.595353,-0.227676, -0.509487,-0.567298, -0.237169,-0.0469077, 0.38648,-0.459108, 0.120912, -0.32916,-0.0202265, 0.921633,-0.153994,0.0591992, 0.180033,-0.294831, 0.357867, -0.194106,-0.646595,-0.544823, -0.354033, 0.521937, 0.556566, 0.305582, 0.211013,-0.393155, -0.222425,-0.433662, 0.673515, -0.128465, 0.099309, 0.531485}); - auto expV = NDArrayFactory::create('c', {5,5}, {-0.581609, 0.315327,0.333158, 0.34476, -0.576582, 0.117364, 0.889461,0.175174,-0.166603, 0.369651, 0.643246,-0.0899117,0.613288, 0.442462,-0.0790943, -0.480818, -0.264384,0.395122, 0.223126, 0.702145, -0.0548207, -0.177325,0.571031,-0.779632, -0.1779}); - - ops::helpers::JacobiSVD jac(matrix, true, true, true); - - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {6, 5}, + {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, 2, -6, + -3, -8, 8, -2, 7, 16, 15, -3, 7, 0, 3, -11, 2, 12, 10}); + + auto expS = NDArrayFactory::create( + 'c', {5, 1}, {36.27, 32.1997, 15.9624, 10.6407, 6.9747}); + auto expU = NDArrayFactory::create( + 'c', {6, 6}, + {0.720125, -0.149734, 0.227784, -0.0288531, 0.595353, -0.227676, + -0.509487, -0.567298, -0.237169, -0.0469077, 0.38648, -0.459108, + 0.120912, -0.32916, -0.0202265, 0.921633, -0.153994, 0.0591992, + 0.180033, -0.294831, 0.357867, -0.194106, -0.646595, -0.544823, + -0.354033, 0.521937, 0.556566, 0.305582, 0.211013, -0.393155, + -0.222425, -0.433662, 0.673515, -0.128465, 0.099309, 0.531485}); + auto expV = NDArrayFactory::create( + 'c', {5, 5}, {-0.581609, 0.315327, 0.333158, 0.34476, -0.576582, + 0.117364, 0.889461, 0.175174, -0.166603, 0.369651, + 0.643246, -0.0899117, 0.613288, 0.442462, -0.0790943, + -0.480818, -0.264384, 0.395122, 0.223126, 0.702145, + -0.0548207, -0.177325, 0.571031, -0.779632, -0.1779}); + + ops::helpers::JacobiSVD jac(matrix, true, true, true); + + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test13) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); - - auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); - auto expV = NDArrayFactory::create('c', {6,6}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, 0.53571, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079,-0.556052, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.431988, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339,-0.165176, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, 0.368038, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387, 0.233392}); - - ops::helpers::JacobiSVD jac(matrix, true, true, true); - - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 6}, + {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, 2, -6, + -3, -8, 8, -2, 7, 16, 15, -3, 7, 0, 3, -11, 2, 12, 10}); + + auto expS = NDArrayFactory::create( + 'c', {5, 1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); + auto expU = NDArrayFactory::create( + 'c', {5, 5}, {0.592324, -0.121832, -0.484064, -0.624878, -0.0975619, + 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, + -0.272693, -0.138725, 0.249336, -0.540587, 0.742962, + 0.263619, -0.903996, 0.179714, 0.276206, 0.0686237, + -0.284717, -0.117079, -0.810818, 0.321741, 0.379848}); + auto expV = NDArrayFactory::create( + 'c', {6, 6}, + {-0.619634, -0.158345, 0.462262, -0.021009, -0.299779, 0.53571, + -0.183441, -0.504296, -0.150804, -0.251078, -0.563079, -0.556052, + 0.724925, -0.404744, 0.154104, -0.177039, -0.262604, 0.431988, + 0.0335645, -0.501546, 0.221702, 0.797602, 0.186339, -0.165176, + -0.0675636, 0.0663677, -0.728788, 0.414614, -0.390566, 0.368038, + -0.226262, -0.54849, -0.399426, -0.311613, 0.580387, 0.233392}); + + ops::helpers::JacobiSVD jac(matrix, true, true, true); + + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test14) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); - - auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); - auto expV = NDArrayFactory::create('c', {6,5}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387}); - - ops::helpers::JacobiSVD jac(matrix, true, true, false); - - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 6}, + {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, 2, -6, + -3, -8, 8, -2, 7, 16, 15, -3, 7, 0, 3, -11, 2, 12, 10}); + + auto expS = NDArrayFactory::create( + 'c', {5, 1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); + auto expU = NDArrayFactory::create( + 'c', {5, 5}, {0.592324, -0.121832, -0.484064, -0.624878, -0.0975619, + 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, + -0.272693, -0.138725, 0.249336, -0.540587, 0.742962, + 0.263619, -0.903996, 0.179714, 0.276206, 0.0686237, + -0.284717, -0.117079, -0.810818, 0.321741, 0.379848}); + auto expV = NDArrayFactory::create( + 'c', {6, 5}, + {-0.619634, -0.158345, 0.462262, -0.021009, -0.299779, -0.183441, + -0.504296, -0.150804, -0.251078, -0.563079, 0.724925, -0.404744, + 0.154104, -0.177039, -0.262604, 0.0335645, -0.501546, 0.221702, + 0.797602, 0.186339, -0.0675636, 0.0663677, -0.728788, 0.414614, + -0.390566, -0.226262, -0.54849, -0.399426, -0.311613, 0.580387}); + + ops::helpers::JacobiSVD jac(matrix, true, true, false); + + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test15) { +#ifdef __CUDABLAS__ + return; +#endif + auto matrix = NDArrayFactory::create( + 'c', {5, 6}, + {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, 2, -6, + -3, -8, 8, -2, 7, 16, 15, -3, 7, 0, 3, -11, 2, 12, 10}); - #ifdef __CUDABLAS__ - return; - #endif - auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); - - auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); - auto expV = NDArrayFactory::create('c', {6,5}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387}); + auto expS = NDArrayFactory::create( + 'c', {5, 1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); + auto expU = NDArrayFactory::create( + 'c', {5, 5}, {0.592324, -0.121832, -0.484064, -0.624878, -0.0975619, + 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, + -0.272693, -0.138725, 0.249336, -0.540587, 0.742962, + 0.263619, -0.903996, 0.179714, 0.276206, 0.0686237, + -0.284717, -0.117079, -0.810818, 0.321741, 0.379848}); + auto expV = NDArrayFactory::create( + 'c', {6, 5}, + {-0.619634, -0.158345, 0.462262, -0.021009, -0.299779, -0.183441, + -0.504296, -0.150804, -0.251078, -0.563079, 0.724925, -0.404744, + 0.154104, -0.177039, -0.262604, 0.0335645, -0.501546, 0.221702, + 0.797602, 0.186339, -0.0675636, 0.0663677, -0.728788, 0.414614, + -0.390566, -0.226262, -0.54849, -0.399426, -0.311613, 0.580387}); - ops::helpers::JacobiSVD jac(matrix, false, false, false); + ops::helpers::JacobiSVD jac(matrix, false, false, false); - ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); } - /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test16) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix1 = NDArrayFactory::create('c', {6,5}, {-2 ,-3 ,2 ,1 ,0 ,0 ,-4 ,5 ,-2 ,-3 ,-4 ,0 ,5 ,-1 ,-5 ,-3 ,-5 ,3 ,3 ,3 ,-5 ,5 ,-5 ,0 ,2 ,-2 ,-3 ,-4 ,-5 ,-3}); - auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto matrix4 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); - - auto expM = NDArrayFactory::create('c', {6,5}, {-2, -3, 2, 1, 0, 0,7.07022, 0, 0, 0, -4, 0,5.09585, 0, 0, -3, 0, 0,3.32256, 0, -5, 0, 0, 0,1.00244, -2, -3, -4, -5, 0}); - auto expU = NDArrayFactory::create('c', {6,6}, {-5.58884,-2.18397,-11.0944, 3.30292, 0,-10, 8.19094, 5.05917, 16.9641,-4.53112, 0, 20, 6.55878, 3.76734, 15.9255,-3.76399, 0,-19, 1.36021, 23.3551,-8.01165, -1.5816, 0, 14, -15.6318,-2.85386, 8.83051, 2.74286, 1,-16, 18, -6, -18, 1,-15,-12}); - auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2, 14.5866, 3.90133, 1.06593, 9.99376, -2, 9.97311, 2.44445, 6.85159, 2.37014, -3, 0.56907,-8.93313,-5.31596, 3.10096, 16,-10.6859, 1.70708,-7.24295,-10.6975}); - - ops::helpers::SVD svd(matrix4, 4, true, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; - - svd.DivideAndConquer(0, 3, 1, 1, 1); - // svd._m.printIndexedBuffer(); - ASSERT_TRUE(expM.isSameShapeStrict(svd._m)); - ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); - ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); - - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); - ASSERT_TRUE(expV.equalsTo(&svd._v)); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix1 = NDArrayFactory::create( + 'c', {6, 5}, {-2, -3, 2, 1, 0, 0, -4, 5, -2, -3, -4, 0, 5, -1, -5, + -3, -5, 3, 3, 3, -5, 5, -5, 0, 2, -2, -3, -4, -5, -3}); + auto matrix2 = NDArrayFactory::create( + 'c', {6, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20, + -11, 19, -5, -18, 12, -19, 18, -18, 17, -10, -19, 14, + -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); + auto matrix3 = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto matrix4 = NDArrayFactory::create( + 'c', {5, 5}, {3, -8, 5, 7, -8, 4, -19, -12, -4, -5, -11, 19, -2, + -7, 1, 16, -5, 10, 19, -19, 0, -20, 0, -8, -13}); + + auto expM = NDArrayFactory::create( + 'c', {6, 5}, {-2, -3, 2, 1, 0, 0, 7.07022, 0, 0, 0, + -4, 0, 5.09585, 0, 0, -3, 0, 0, 3.32256, 0, + -5, 0, 0, 0, 1.00244, -2, -3, -4, -5, 0}); + auto expU = NDArrayFactory::create( + 'c', {6, 6}, {-5.58884, -2.18397, -11.0944, 3.30292, 0, -10, + 8.19094, 5.05917, 16.9641, -4.53112, 0, 20, + 6.55878, 3.76734, 15.9255, -3.76399, 0, -19, + 1.36021, 23.3551, -8.01165, -1.5816, 0, 14, + -15.6318, -2.85386, 8.83051, 2.74286, 1, -16, + 18, -6, -18, 1, -15, -12}); + auto expV = NDArrayFactory::create( + 'c', {5, 5}, + {-18, 1, 19, -7, 1, 2, 14.5866, + 3.90133, 1.06593, 9.99376, -2, 9.97311, 2.44445, 6.85159, + 2.37014, -3, 0.56907, -8.93313, -5.31596, 3.10096, 16, + -10.6859, 1.70708, -7.24295, -10.6975}); + + ops::helpers::SVD svd(matrix4, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + + svd.DivideAndConquer(0, 3, 1, 1, 1); + // svd._m.printIndexedBuffer(); + ASSERT_TRUE(expM.isSameShapeStrict(svd._m)); + ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); + ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test17) { - - #ifdef __CUDABLAS__ - return; - #endif - auto matrix1 = NDArrayFactory::create('c', {6,5}, {-2 ,-3 ,2 ,1 ,0 ,0 ,-4 ,5 ,-2 ,-3 ,-4 ,0 ,5 ,-1 ,-5 ,-3 ,-5 ,3 ,3 ,3 ,-5 ,5 ,-5 ,0 ,2 ,-2 ,-3 ,-4 ,-5 ,-3}); - auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto matrix4 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); - - auto expM = NDArrayFactory::create('c', {6,5}, {-2, -3, 2, 1, 0, 0,12.1676, 0, 0, 0, -4, 0,7.49514, 0, 0, -3, 0, 0,5.00951, 0, -5, 0, 0, 0, 1.63594, -2, 0, 0, 0, 0}); - auto expU = NDArrayFactory::create('c', {6,6}, {0.295543,-0.238695, 0.262095,-0.231772, -0.85631,-10, 0.519708,0.0571492,-0.368706,-0.727615, 0.247527, 20, 0.313717,-0.561567,-0.602941, 0.469567,-0.0468295,-19, 0.474589,-0.372165, 0.656962, 0.124776, 0.434845, 14, -0.564717,-0.697061,0.0150082, -0.4252, 0.119081,-16, 18, -6, -18, 1, -15,-12}); - auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-0.0366659, 0.977361,-0.0316106,0.205967, -2, -0.670795, -0.151697, -0.503288,0.523185, -3, 0.740124,-0.0841435, -0.486714,0.456339, 16, 0.0300945, -0.121135, 0.71331,0.689645}); - - ops::helpers::SVD svd(matrix4, 10, true, true, true, 't'); - svd._m = matrix1; - svd._u = matrix2; - svd._v = matrix3; - - svd.DivideAndConquer(0, 3, 1, 1, 1); - - ASSERT_TRUE(expM.equalsTo(&svd._m)); - ASSERT_TRUE(expU.equalsTo(&svd._u)); - ASSERT_TRUE(expV.equalsTo(&svd._v)); - - ASSERT_TRUE(expM.isSameShapeStrict(svd._m)); - ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); - ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); +#ifdef __CUDABLAS__ + return; +#endif + auto matrix1 = NDArrayFactory::create( + 'c', {6, 5}, {-2, -3, 2, 1, 0, 0, -4, 5, -2, -3, -4, 0, 5, -1, -5, + -3, -5, 3, 3, 3, -5, 5, -5, 0, 2, -2, -3, -4, -5, -3}); + auto matrix2 = NDArrayFactory::create( + 'c', {6, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20, + -11, 19, -5, -18, 12, -19, 18, -18, 17, -10, -19, 14, + -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); + auto matrix3 = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, + 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto matrix4 = NDArrayFactory::create( + 'c', {5, 5}, {3, -8, 5, 7, -8, 4, -19, -12, -4, -5, -11, 19, -2, + -7, 1, 16, -5, 10, 19, -19, 0, -20, 0, -8, -13}); + + auto expM = NDArrayFactory::create( + 'c', {6, 5}, {-2, -3, 2, 1, 0, 0, 12.1676, 0, 0, 0, + -4, 0, 7.49514, 0, 0, -3, 0, 0, 5.00951, 0, + -5, 0, 0, 0, 1.63594, -2, 0, 0, 0, 0}); + auto expU = NDArrayFactory::create( + 'c', {6, 6}, + {0.295543, -0.238695, 0.262095, -0.231772, -0.85631, -10, + 0.519708, 0.0571492, -0.368706, -0.727615, 0.247527, 20, + 0.313717, -0.561567, -0.602941, 0.469567, -0.0468295, -19, + 0.474589, -0.372165, 0.656962, 0.124776, 0.434845, 14, + -0.564717, -0.697061, 0.0150082, -0.4252, 0.119081, -16, + 18, -6, -18, 1, -15, -12}); + auto expV = NDArrayFactory::create( + 'c', {5, 5}, {-18, 1, 19, -7, 1, + 2, -0.0366659, 0.977361, -0.0316106, 0.205967, + -2, -0.670795, -0.151697, -0.503288, 0.523185, + -3, 0.740124, -0.0841435, -0.486714, 0.456339, + 16, 0.0300945, -0.121135, 0.71331, 0.689645}); + + ops::helpers::SVD svd(matrix4, 10, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + + svd.DivideAndConquer(0, 3, 1, 1, 1); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); + + ASSERT_TRUE(expM.isSameShapeStrict(svd._m)); + ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); + ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); } // /////////////////////////////////////////////////////////////////// // TEST_F(HelpersTests1, SVD_test18) { -// auto matrix('c', {10,10}, {10 ,7 ,5 ,2 ,17 ,18 ,-18 ,10 ,18 ,1 ,4 ,2 ,-7 ,-18 ,20 ,14 , -// -3 ,-10 ,-4 ,2 ,-17 ,-17 ,1 ,2 ,-9 ,-6 ,-13 ,16 ,-18 ,-13 , -// -10 ,16 ,-10 ,-13 ,-11 ,-6 ,-19 ,17 ,-12 ,3 ,-14 ,7 ,7 ,-9 , -// 5 ,-16 ,7 ,16 ,13 ,12 ,2 ,18 ,6 ,3 ,-8 ,11 ,-1 ,5 ,16 ,-16 , -// -9 ,8 ,10 ,-7 ,-4 ,1 ,-10 ,0 ,20 ,7 ,-11 ,-13 ,-3 ,20 ,-6 , -// 9 ,10 ,8 ,-20 ,1 ,19 ,19 ,-12 ,-20 ,-2 ,17 ,-18 ,-5 ,-14 ,0 -// ,9 ,-16 ,9 ,-15 ,7 ,18 ,-10 ,8 ,-11 ,-4}); - -// auto expS('c', {10, 1}, {65.0394, 56.1583, 48.9987, 39.2841, 35.7296, 22.8439, 17.474, 15.2708, 15.0768, 0.846648}); - -// auto expU('c', {10,10}, {0.413187, 0.159572,0.0238453, 0.601154,-0.0428558, -0.461779, 0.41787, -0.221153, 0.0206268, 0.0532219, -// 0.364377,-0.154281, 0.199857,-0.0943331, 0.415653, -0.139834, -0.258458, 0.10677, 0.72003,-0.0749772, -// -0.315063,-0.418079,-0.377499, 0.37031, 0.0123835, 0.300036, 0.153702, -0.129223, 0.390675, 0.403962, -// 0.102001,-0.216667, -0.74093,-0.166164,-0.0269665, -0.240065, 0.0549761,-0.0178001, 0.0197525, -0.55134, -// -0.107298, 0.386899,-0.377536, 0.033214, 0.486739, -0.245438, -0.43788, -0.208875, -0.170449, 0.365491, -// 0.18026, 0.240482,-0.115801, 0.237399, -0.643413, 0.139274, -0.582963, -0.116222, 0.224524,-0.0525887, -// 0.141172, 0.340505,-0.261653, 0.186411, 0.0625811, 0.19585, 0.128195, 0.832893, 0.0319884, 0.0864513, -// -0.385777,-0.330504, 0.128342, 0.156083, -0.200883, -0.648548, -0.256507, 0.40519,-0.0434365, 0.0909978, -// 0.574478,-0.371028,-0.136672,-0.328417, -0.190226,-0.0476664,-0.0399815, 0.0687528, -0.242039, 0.549918, -// 0.209886,-0.398294,0.0919207, 0.490454, 0.305228, 0.280486, -0.341358, 0.0540678, -0.432618, -0.264332}); - -// auto expV('c', {10,10}, {0.423823,-0.0845148, 0.389647, -0.10717,-0.168732, 0.123783, 0.159237, -0.450407, -0.611513,-0.0629076, -// 0.412121, 0.317493, -0.355665,-0.383203,-0.382616,-0.309073, -0.21869,-0.0746378, 0.0829771, 0.392186, -// -0.0603483, 0.232234, 0.0383737, 0.435441,0.0829318, 0.327822,-0.206101, 0.184083, -0.34018, 0.667018, -// -0.453935, 0.119616, 0.288392, 0.184366,-0.524289, -0.42264, 0.41005,-0.0505891,0.00333608, 0.195602, -// 0.247802, 0.0776165, 0.33026, 0.190986, 0.526809,-0.345006,0.0651023, -0.386472, 0.395169, 0.284091, -// 0.426355, -0.269507, 0.304685, 0.386708,-0.257916,-0.287742,-0.329622, 0.463719, 0.0613767, -0.16261, -// -0.384582, 0.241486, 0.425935,-0.292636,0.0465594,-0.125018,-0.685871, -0.112806,-0.0977978, -0.127356, -// -0.121678, -0.06796, -0.501443, 0.473165,0.0422977,-0.369324,-0.248758, -0.408769, -0.305785, -0.211138, -// 0.186099, 0.809997, 0.0338281, 0.268965, -0.04829, 0.141617, 0.12121, 0.0362537, 0.0831986, -0.436428, -// 0.0174496, 0.161638,-0.0334757,-0.224027, 0.439364,-0.478697, 0.237318, 0.457809, -0.483235,-0.0253522}); +// auto matrix('c', {10,10}, {10 ,7 ,5 ,2 ,17 ,18 ,-18 ,10 ,18 ,1 ,4 ,2 ,-7 +// ,-18 ,20 ,14 , +// -3 ,-10 ,-4 ,2 ,-17 ,-17 ,1 ,2 ,-9 +// ,-6 ,-13 ,16 ,-18 ,-13 , -10 ,16 +// ,-10 ,-13 ,-11 ,-6 ,-19 ,17 ,-12 ,3 +// ,-14 ,7 ,7 ,-9 , 5 ,-16 ,7 ,16 ,13 +// ,12 ,2 ,18 ,6 ,3 ,-8 ,11 ,-1 ,5 ,16 +// ,-16 , -9 ,8 ,10 ,-7 ,-4 ,1 ,-10 ,0 +// ,20 ,7 ,-11 ,-13 ,-3 ,20 ,-6 , 9 +// ,10 ,8 ,-20 ,1 ,19 ,19 ,-12 ,-20 +// ,-2 ,17 ,-18 ,-5 ,-14 ,0 ,9 ,-16 ,9 +// ,-15 ,7 ,18 ,-10 ,8 ,-11 ,-4}); + +// auto expS('c', {10, 1}, +// {65.0394, 56.1583, 48.9987, 39.2841, 35.7296, 22.8439, 17.474, 15.2708, 15.0768, +// 0.846648}); + +// auto expU('c', {10,10}, {0.413187, 0.159572,0.0238453, +// 0.601154,-0.0428558, -0.461779, 0.41787, -0.221153, 0.0206268, +// 0.0532219, +// 0.364377,-0.154281, +// 0.199857,-0.0943331, 0.415653, +// -0.139834, -0.258458, 0.10677, +// 0.72003,-0.0749772, +// -0.315063,-0.418079,-0.377499, +// 0.37031, 0.0123835, 0.300036, +// 0.153702, -0.129223, 0.390675, +// 0.403962, +// 0.102001,-0.216667, +// -0.74093,-0.166164,-0.0269665, +// -0.240065, 0.0549761,-0.0178001, +// 0.0197525, -0.55134, +// -0.107298, 0.386899,-0.377536, +// 0.033214, 0.486739, -0.245438, +// -0.43788, -0.208875, -0.170449, +// 0.365491, +// 0.18026, 0.240482,-0.115801, +// 0.237399, -0.643413, 0.139274, +// -0.582963, -0.116222, +// 0.224524,-0.0525887, +// 0.141172, 0.340505,-0.261653, +// 0.186411, 0.0625811, 0.19585, +// 0.128195, 0.832893, 0.0319884, +// 0.0864513, +// -0.385777,-0.330504, 0.128342, +// 0.156083, -0.200883, -0.648548, +// -0.256507, 0.40519,-0.0434365, +// 0.0909978, +// 0.574478,-0.371028,-0.136672,-0.328417, +// -0.190226,-0.0476664,-0.0399815, +// 0.0687528, -0.242039, 0.549918, +// 0.209886,-0.398294,0.0919207, +// 0.490454, 0.305228, 0.280486, +// -0.341358, 0.0540678, -0.432618, +// -0.264332}); + +// auto expV('c', {10,10}, {0.423823,-0.0845148, 0.389647, +// -0.10717,-0.168732, 0.123783, 0.159237, -0.450407, -0.611513,-0.0629076, +// 0.412121, 0.317493, +// -0.355665,-0.383203,-0.382616,-0.309073, +// -0.21869,-0.0746378, 0.0829771, +// 0.392186, +// -0.0603483, 0.232234, 0.0383737, +// 0.435441,0.0829318, +// 0.327822,-0.206101, 0.184083, +// -0.34018, 0.667018, +// -0.453935, 0.119616, 0.288392, +// 0.184366,-0.524289, -0.42264, +// 0.41005,-0.0505891,0.00333608, +// 0.195602, +// 0.247802, 0.0776165, 0.33026, +// 0.190986, +// 0.526809,-0.345006,0.0651023, +// -0.386472, 0.395169, 0.284091, +// 0.426355, -0.269507, 0.304685, +// 0.386708,-0.257916,-0.287742,-0.329622, +// 0.463719, 0.0613767, -0.16261, +// -0.384582, 0.241486, +// 0.425935,-0.292636,0.0465594,-0.125018,-0.685871, +// -0.112806,-0.0977978, -0.127356, +// -0.121678, -0.06796, -0.501443, +// 0.473165,0.0422977,-0.369324,-0.248758, +// -0.408769, -0.305785, -0.211138, +// 0.186099, 0.809997, 0.0338281, +// 0.268965, -0.04829, 0.141617, +// 0.12121, 0.0362537, 0.0831986, +// -0.436428, +// 0.0174496, +// 0.161638,-0.0334757,-0.224027, +// 0.439364,-0.478697, 0.237318, +// 0.457809, -0.483235,-0.0253522}); // ops::helpers::SVD svd(matrix, 8, true, true, true); // // svd._u.printShapeInfo(); @@ -1333,43 +1917,109 @@ TEST_F(HelpersTests1, SVD_test17) { // ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); // } - // /////////////////////////////////////////////////////////////////// // TEST_F(HelpersTests1, SVD_test19) { -// auto matrix('c', {11,10}, {10 ,7 ,5 ,2 ,17 ,18 ,-18 ,10 ,18 ,1 ,4 ,2 ,-7 ,-18 ,20 ,14 , -// -3 ,-10 ,-4 ,2 ,-17 ,-17 ,1 ,2 ,-9 ,-6 ,-13 ,16 ,-18 ,-13 , -// -10 ,16 ,-10 ,-13 ,-11 ,-6 ,-19 ,17 ,-12 ,3 ,-14 ,7 ,7 ,-9 , -// 5 ,-16 ,7 ,16 ,13 ,12 ,2 ,18 ,6 ,3 ,-8 ,11 ,-1 ,5 ,16 ,-16 , -// -9 ,8 ,10 ,-7 ,-4 ,1 ,-10 ,0 ,20 ,7 ,-11 ,-13 ,-3 ,20 ,-6 , -// 9 ,10 ,8 ,-20 ,1 ,19 ,19 ,-12 ,-20 ,-2 ,17 ,-18 ,-5 ,-14 ,0 -// ,9 ,-16 ,9 ,-15 ,7 ,18 ,-10 ,8 ,-11 ,-4, -// -7, 1, -2, 15, 0, 4, -9,19, -3, 10 }); - -// auto expS('c', {10, 1}, {65.5187, 56.305, 50.9808, 41.6565, 35.8698, 29.3898, 17.9743, 15.3568, 15.2223, 0.846847}); - -// auto expU('c', {11,11}, {-0.387999,-0.117659, 0.162976, 0.641067,-0.0174306, -0.181469,-0.218643, -0.308042, 0.0670776,-0.0632539, -0.462228, -// -0.37021, 0.14822, -0.195157,-0.0467394, -0.381275, -0.183363, 0.326599, -0.370579, -0.56626, 0.0798798, 0.225133, -// 0.339692, 0.433146, 0.30841, 0.134184, -0.108725, 0.466056,-0.153546, -0.359783, -0.189621, -0.402737, 0.0605675, -// -0.0650167, 0.268868, 0.662416, -0.327524, 0.0339198,-0.0916729,0.0415428, -0.0765093,-0.0288338, 0.546108, -0.247418, -// 0.114029,-0.361828, 0.379255,-0.0935836, -0.488912, -0.125232, 0.480666,-0.00544881, 0.280747, -0.36698,-0.0648559, -// -0.174798, -0.21859, 0.178313, 0.212153, 0.579101, 0.369942, 0.551063, -0.139813,-0.0296135, 0.0572204, 0.212783, -// -0.133981,-0.311817, 0.304673, 0.0865395, -0.104221, 0.196295,-0.191271, 0.571084, -0.603697,-0.0868996,-0.0196788, -// 0.398676, 0.319697, -0.112145, 0.235089, 0.201666, -0.337134, 0.43406, 0.261686, -0.283102,-0.0999458, -0.411893, -// -0.559998, 0.392802, 0.0996997, -0.281135, 0.24017, -0.136769,0.0121463, 0.218664, 0.127577, -0.550001,0.00227476, -// -0.197522, 0.403875,-0.0647804, 0.383315, -0.388502, 0.335719, 0.20912, 0.404926, 0.309087, 0.266437, 0.0942471, -// 0.140425,0.0934688, 0.325994, 0.345081, 0.0825574, -0.521239,-0.129018, 0.0806886, 0.0442647, 0.014397, 0.665103}); - -// auto expV('c', {10,10}, {-0.4428, 0.0661762,-0.361903, 0.0307317, 0.19574,-0.0356551,-0.241991, 0.0866805, 0.74701, 0.062837, -// -0.400091, -0.277277, 0.375095, -0.323052, 0.443668, -0.264809, 0.292881, -0.106586,-0.00623963,-0.392226, -// 0.0536693, -0.232105,0.0106246, 0.332557, -0.167406, 0.400872,0.0835708, 0.414598, 0.141906,-0.666936, -// 0.473793, -0.121962,-0.147941, 0.414665, 0.538964, -0.372149,-0.285458, -0.132952, -0.0166319,-0.195945, -// -0.251722,-0.0813691,-0.233887, 0.280439, -0.512597, -0.328782, 0.074277, -0.581806, -0.0327555,-0.284121, -// -0.406324, 0.284462,-0.168731, 0.518021, 0.226396, -0.109282, 0.381083, 0.305342, -0.359301, 0.162524, -// 0.335857, -0.302206,-0.484806, -0.196382,0.00286755, -0.111789, 0.672115, 0.0705632, 0.191787, 0.127533, -// 0.185896, 0.134279, 0.608397, 0.382412,-0.0997649, -0.117987, 0.326934,-0.0941208, 0.496913, 0.210914, -// -0.201675, -0.795446,0.0916484, 0.267237,0.00604554, 0.167517, -0.13914,-0.0355323, -0.0869256, 0.436465, -// 0.00123325, -0.142684,0.0978458,-0.0945446, -0.349755, -0.674457,-0.196126, 0.587134,-0.00964182,0.0249317}); +// auto matrix('c', {11,10}, {10 ,7 ,5 ,2 ,17 ,18 ,-18 ,10 ,18 ,1 ,4 ,2 ,-7 +// ,-18 ,20 ,14 , +// -3 ,-10 ,-4 ,2 ,-17 ,-17 ,1 ,2 ,-9 +// ,-6 ,-13 ,16 ,-18 ,-13 , -10 ,16 +// ,-10 ,-13 ,-11 ,-6 ,-19 ,17 ,-12 ,3 +// ,-14 ,7 ,7 ,-9 , 5 ,-16 ,7 ,16 ,13 +// ,12 ,2 ,18 ,6 ,3 ,-8 ,11 ,-1 ,5 ,16 +// ,-16 , -9 ,8 ,10 ,-7 ,-4 ,1 ,-10 ,0 +// ,20 ,7 ,-11 ,-13 ,-3 ,20 ,-6 , 9 +// ,10 ,8 ,-20 ,1 ,19 ,19 ,-12 ,-20 +// ,-2 ,17 ,-18 ,-5 ,-14 ,0 ,9 ,-16 ,9 +// ,-15 ,7 ,18 ,-10 ,8 ,-11 ,-4, -7, +// 1, -2, 15, 0, 4, -9,19, -3, 10 +// }); + +// auto expS('c', {10, 1}, +// {65.5187, 56.305, 50.9808, 41.6565, 35.8698, 29.3898, 17.9743, 15.3568, 15.2223, +// 0.846847}); + +// auto expU('c', {11,11}, {-0.387999,-0.117659, 0.162976, +// 0.641067,-0.0174306, -0.181469,-0.218643, -0.308042, +// 0.0670776,-0.0632539, -0.462228, +// -0.37021, 0.14822, +// -0.195157,-0.0467394, -0.381275, +// -0.183363, 0.326599, -0.370579, +// -0.56626, 0.0798798, 0.225133, +// 0.339692, 0.433146, 0.30841, +// 0.134184, -0.108725, +// 0.466056,-0.153546, -0.359783, +// -0.189621, -0.402737, 0.0605675, +// -0.0650167, 0.268868, 0.662416, +// -0.327524, +// 0.0339198,-0.0916729,0.0415428, +// -0.0765093,-0.0288338, 0.546108, +// -0.247418, +// 0.114029,-0.361828, +// 0.379255,-0.0935836, -0.488912, +// -0.125232, 0.480666,-0.00544881, +// 0.280747, -0.36698,-0.0648559, +// -0.174798, -0.21859, 0.178313, +// 0.212153, 0.579101, 0.369942, +// 0.551063, -0.139813,-0.0296135, +// 0.0572204, 0.212783, +// -0.133981,-0.311817, 0.304673, +// 0.0865395, -0.104221, +// 0.196295,-0.191271, 0.571084, +// -0.603697,-0.0868996,-0.0196788, +// 0.398676, 0.319697, -0.112145, +// 0.235089, 0.201666, -0.337134, +// 0.43406, 0.261686, +// -0.283102,-0.0999458, -0.411893, +// -0.559998, 0.392802, 0.0996997, +// -0.281135, 0.24017, +// -0.136769,0.0121463, 0.218664, +// 0.127577, -0.550001,0.00227476, +// -0.197522, 0.403875,-0.0647804, +// 0.383315, -0.388502, 0.335719, +// 0.20912, 0.404926, 0.309087, +// 0.266437, 0.0942471, +// 0.140425,0.0934688, 0.325994, +// 0.345081, 0.0825574, +// -0.521239,-0.129018, 0.0806886, +// 0.0442647, 0.014397, 0.665103}); + +// auto expV('c', {10,10}, {-0.4428, 0.0661762,-0.361903, 0.0307317, +// 0.19574,-0.0356551,-0.241991, 0.0866805, 0.74701, 0.062837, +// -0.400091, -0.277277, 0.375095, +// -0.323052, 0.443668, -0.264809, +// 0.292881, +// -0.106586,-0.00623963,-0.392226, +// 0.0536693, -0.232105,0.0106246, +// 0.332557, -0.167406, +// 0.400872,0.0835708, 0.414598, +// 0.141906,-0.666936, +// 0.473793, -0.121962,-0.147941, +// 0.414665, 0.538964, +// -0.372149,-0.285458, -0.132952, +// -0.0166319,-0.195945, +// -0.251722,-0.0813691,-0.233887, +// 0.280439, -0.512597, -0.328782, +// 0.074277, -0.581806, +// -0.0327555,-0.284121, -0.406324, +// 0.284462,-0.168731, 0.518021, +// 0.226396, -0.109282, 0.381083, +// 0.305342, -0.359301, 0.162524, +// 0.335857, -0.302206,-0.484806, +// -0.196382,0.00286755, -0.111789, +// 0.672115, 0.0705632, 0.191787, +// 0.127533, 0.185896, 0.134279, +// 0.608397, 0.382412,-0.0997649, +// -0.117987, 0.326934,-0.0941208, +// 0.496913, 0.210914, +// -0.201675, -0.795446,0.0916484, +// 0.267237,0.00604554, 0.167517, +// -0.13914,-0.0355323, -0.0869256, +// 0.436465, +// 0.00123325, +// -0.142684,0.0978458,-0.0945446, +// -0.349755, -0.674457,-0.196126, +// 0.587134,-0.00964182,0.0249317}); // ops::helpers::SVD svd(matrix, 8, true, true, true); @@ -1382,43 +2032,111 @@ TEST_F(HelpersTests1, SVD_test17) { // ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); // } - // /////////////////////////////////////////////////////////////////// // TEST_F(HelpersTests1, SVD_test20) { -// auto matrix('c', {10,11}, {10 ,7 ,5 ,2 ,17 ,18 ,-18 ,10 ,18 ,1 ,4 ,2 ,-7 ,-18 ,20 ,14 , -// -3 ,-10 ,-4 ,2 ,-17 ,-17 ,1 ,2 ,-9 ,-6 ,-13 ,16 ,-18 ,-13 , -// -10 ,16 ,-10 ,-13 ,-11 ,-6 ,-19 ,17 ,-12 ,3 ,-14 ,7 ,7 ,-9 , -// 5 ,-16 ,7 ,16 ,13 ,12 ,2 ,18 ,6 ,3 ,-8 ,11 ,-1 ,5 ,16 ,-16 , -// -9 ,8 ,10 ,-7 ,-4 ,1 ,-10 ,0 ,20 ,7 ,-11 ,-13 ,-3 ,20 ,-6 , -// 9 ,10 ,8 ,-20 ,1 ,19 ,19 ,-12 ,-20 ,-2 ,17 ,-18 ,-5 ,-14 ,0 -// ,9 ,-16 ,9 ,-15 ,7 ,18 ,-10 ,8 ,-11 ,-4, -// -7, 1, -2, 15, 0, 4, -9,19, -3, 10 }); - -// auto expS('c', {10, 1}, {68.9437, 54.8773, 50.7858, 42.4898, 35.1984, 26.6285, 21.376, 12.2334, 5.9112, 0.38292}); - -// auto expU('c', {10,10}, {0.30332,-0.0677785, 0.155514, -0.722623,-0.0843687,-0.0712535, 0.414936, -0.15422, -0.381536,-0.057561, -// 0.473286, 0.0231518, 0.0878106, 0.45493, -0.311654, 0.138957, 0.311305, 0.509971, -0.288207,0.0656506, -// -0.131548, 0.32051, 0.489848,-0.0539042, -0.521328, -0.363728, -0.328685,-0.0329672,-0.0726502, 0.344431, -// 0.072974, 0.522632, -0.477056, 0.0618953,-0.0980883, -0.095653, -0.26596, -0.15453, -0.475107,-0.388594, -// 0.267569, -0.336154,-0.0930604, -0.261336, -0.39945, 0.480346, -0.568317, 0.0593335, 0.102036,-0.106029, -// -0.0919782, -0.460136, 0.106434, 0.327722, 0.0952523, 0.0915698, -0.129052, -0.460878, -0.59722, 0.240608, -// -0.248827, -0.48834, -0.243788, -0.106636,-0.0803772, -0.567457, -0.12005, 0.480504, -0.188409,-0.139802, -// 0.643408, -0.16245, -0.152596, 0.16849,-0.0120438, -0.51616,-0.0694232, -0.36172, 0.322169,0.0440701, -// -0.229467,-0.0227008, -0.588303,-0.0327104, -0.482264, 0.0794715, 0.340158, -0.175969, 0.108784, 0.449731, -// 0.229718, 0.169979, -0.227516, -0.21815, 0.454459, 0.017476, -0.278516, 0.287333, -0.148844, 0.655637}); - -// auto expV('c', {11,11}, {0.190806, -0.193628, 0.383793,-0.0266376, 0.113035, 0.158361, 0.0297803, -0.793229, -0.13761,-0.260666, -0.152503, -// -0.303449, 0.0392386, 0.250627, -0.165231, 0.141567, 0.0479565, 0.72763, 0.14053, -0.339907, 0.224366, -0.280806, -// -0.159724, -0.38984, -0.256355, -0.337861, 0.075089, -0.237427, -0.153718, -0.217747, 0.320899, 0.455058, -0.446697, -// 0.376823, -0.560303, 0.269135, 0.265416,-0.00742902, 0.0263377, -0.192808, 0.435842, -0.275365,0.0511804, -0.30799, -// 0.522537, 0.209791, -0.44191, -0.282323, -0.12139, 0.226382, 0.221075, 0.0844301, 0.0285412,-0.297578, -0.443394, -// 0.0588008, 0.115035, 0.54835, -0.52266, -0.141345, 0.411122, -0.182423, 0.213721, 0.353022, 0.119504, 0.0508673, -// -0.299021,-0.0424794, -0.285618, 0.177961, 0.35831, 0.769783, -0.215983,-0.00423939, -0.110575,0.0928082,-0.0841152, -// -0.0977062, -0.624782, -0.240391, -0.276154, -0.342018, 0.199695, 0.268881, 0.00402219,-0.0536164, -0.17679, 0.450283, -// 0.428931, 0.0748696, -0.120853, -0.360103, 0.37093,-0.0611563, -0.100263, -0.0604207, -0.432926, 0.412875, 0.39142, -// -0.35553, 0.127463,-0.0199906, -0.343149, -0.315968, -0.115698, -0.442585, 0.0126156, -0.584161,-0.219242, -0.20156, -// -0.134753, -0.154272, 0.037343, -0.281348, 0.666324, -0.213813,-0.0427932, 0.238783, 0.132347,-0.557478, 0.0253325}); +// auto matrix('c', {10,11}, {10 ,7 ,5 ,2 ,17 ,18 ,-18 ,10 ,18 ,1 ,4 ,2 ,-7 +// ,-18 ,20 ,14 , +// -3 ,-10 ,-4 ,2 ,-17 ,-17 ,1 ,2 ,-9 +// ,-6 ,-13 ,16 ,-18 ,-13 , -10 ,16 +// ,-10 ,-13 ,-11 ,-6 ,-19 ,17 ,-12 ,3 +// ,-14 ,7 ,7 ,-9 , 5 ,-16 ,7 ,16 ,13 +// ,12 ,2 ,18 ,6 ,3 ,-8 ,11 ,-1 ,5 ,16 +// ,-16 , -9 ,8 ,10 ,-7 ,-4 ,1 ,-10 ,0 +// ,20 ,7 ,-11 ,-13 ,-3 ,20 ,-6 , 9 +// ,10 ,8 ,-20 ,1 ,19 ,19 ,-12 ,-20 +// ,-2 ,17 ,-18 ,-5 ,-14 ,0 ,9 ,-16 ,9 +// ,-15 ,7 ,18 ,-10 ,8 ,-11 ,-4, -7, +// 1, -2, 15, 0, 4, -9,19, -3, 10 +// }); + +// auto expS('c', {10, 1}, +// {68.9437, 54.8773, 50.7858, 42.4898, 35.1984, 26.6285, 21.376, 12.2334, 5.9112, +// 0.38292}); + +// auto expU('c', {10,10}, {0.30332,-0.0677785, 0.155514, +// -0.722623,-0.0843687,-0.0712535, 0.414936, -0.15422, +// -0.381536,-0.057561, +// 0.473286, 0.0231518, 0.0878106, +// 0.45493, -0.311654, 0.138957, +// 0.311305, 0.509971, +// -0.288207,0.0656506, +// -0.131548, 0.32051, +// 0.489848,-0.0539042, -0.521328, +// -0.363728, +// -0.328685,-0.0329672,-0.0726502, +// 0.344431, +// 0.072974, 0.522632, -0.477056, +// 0.0618953,-0.0980883, -0.095653, +// -0.26596, -0.15453, +// -0.475107,-0.388594, 0.267569, +// -0.336154,-0.0930604, -0.261336, +// -0.39945, 0.480346, -0.568317, +// 0.0593335, 0.102036,-0.106029, +// -0.0919782, -0.460136, 0.106434, +// 0.327722, 0.0952523, 0.0915698, +// -0.129052, -0.460878, -0.59722, +// 0.240608, +// -0.248827, -0.48834, -0.243788, +// -0.106636,-0.0803772, -0.567457, +// -0.12005, 0.480504, +// -0.188409,-0.139802, +// 0.643408, -0.16245, -0.152596, +// 0.16849,-0.0120438, +// -0.51616,-0.0694232, -0.36172, +// 0.322169,0.0440701, +// -0.229467,-0.0227008, +// -0.588303,-0.0327104, -0.482264, +// 0.0794715, 0.340158, -0.175969, +// 0.108784, 0.449731, +// 0.229718, 0.169979, -0.227516, +// -0.21815, 0.454459, 0.017476, +// -0.278516, 0.287333, -0.148844, +// 0.655637}); + +// auto expV('c', {11,11}, {0.190806, -0.193628, 0.383793,-0.0266376, +// 0.113035, 0.158361, 0.0297803, -0.793229, -0.13761,-0.260666, +// -0.152503, +// -0.303449, 0.0392386, 0.250627, +// -0.165231, 0.141567, 0.0479565, +// 0.72763, 0.14053, -0.339907, +// 0.224366, -0.280806, -0.159724, +// -0.38984, -0.256355, -0.337861, +// 0.075089, -0.237427, -0.153718, +// -0.217747, 0.320899, 0.455058, +// -0.446697, +// 0.376823, -0.560303, 0.269135, +// 0.265416,-0.00742902, 0.0263377, +// -0.192808, 0.435842, +// -0.275365,0.0511804, -0.30799, +// 0.522537, 0.209791, -0.44191, +// -0.282323, -0.12139, 0.226382, +// 0.221075, 0.0844301, +// 0.0285412,-0.297578, -0.443394, +// 0.0588008, 0.115035, 0.54835, +// -0.52266, -0.141345, 0.411122, +// -0.182423, 0.213721, 0.353022, +// 0.119504, 0.0508673, +// -0.299021,-0.0424794, -0.285618, +// 0.177961, 0.35831, 0.769783, +// -0.215983,-0.00423939, +// -0.110575,0.0928082,-0.0841152, +// -0.0977062, -0.624782, -0.240391, +// -0.276154, -0.342018, 0.199695, +// 0.268881, 0.00402219,-0.0536164, +// -0.17679, 0.450283, +// 0.428931, 0.0748696, -0.120853, +// -0.360103, 0.37093,-0.0611563, +// -0.100263, -0.0604207, -0.432926, +// 0.412875, 0.39142, -0.35553, +// 0.127463,-0.0199906, -0.343149, +// -0.315968, -0.115698, -0.442585, +// 0.0126156, -0.584161,-0.219242, +// -0.20156, +// -0.134753, -0.154272, 0.037343, +// -0.281348, 0.666324, +// -0.213813,-0.0427932, 0.238783, +// 0.132347,-0.557478, 0.0253325}); // ops::helpers::SVD svd(matrix, 8, true, true, true); @@ -1431,16 +2149,18 @@ TEST_F(HelpersTests1, SVD_test17) { // ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); // } - ///////////////////////////////////////////////////////////////////// -//TEST_F(HelpersTests1, reverseArray_test1) { +// TEST_F(HelpersTests1, reverseArray_test1) { // -// auto inArr = NDArrayFactory::create('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}); -// auto exp = NDArrayFactory::create('c', {2,5}, {10,9,8,7,6,5,4,3,2,1}); -// auto outArr = NDArrayFactory::create('c', {2,5}); +// auto inArr = NDArrayFactory::create('c', {2,5}, +// {1,2,3,4,5,6,7,8,9,10}); auto exp = NDArrayFactory::create('c', +// {2,5}, {10,9,8,7,6,5,4,3,2,1}); auto outArr = +// NDArrayFactory::create('c', {2,5}); // // -// ops::helpers::reverseArray(sd::LaunchContext ::defaultContext(), inArr.getBuffer(), inArr.shapeInfo(), outArr.getBuffer(), outArr.shapeInfo()); +// ops::helpers::reverseArray(sd::LaunchContext ::defaultContext(), +// inArr.getBuffer(), inArr.shapeInfo(), outArr.getBuffer(), +// outArr.shapeInfo()); // // ASSERT_TRUE(outArr.equalsTo(&exp)); // ASSERT_TRUE(outArr.isSameShapeStrict(exp)); @@ -1448,13 +2168,16 @@ TEST_F(HelpersTests1, SVD_test17) { // // ///////////////////////////////////////////////////////////////////// -//TEST_F(HelpersTests1, reverseArray_test2) { +// TEST_F(HelpersTests1, reverseArray_test2) { // -// auto inArr = NDArrayFactory::create('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}); -// auto exp = NDArrayFactory::create('c', {2,5}, {10,9,8,7,6,5,4,3,2,1}); +// auto inArr = NDArrayFactory::create('c', {2,5}, +// {1,2,3,4,5,6,7,8,9,10}); auto exp = NDArrayFactory::create('c', +// {2,5}, {10,9,8,7,6,5,4,3,2,1}); // // -// ops::helpers::reverseArray(sd::LaunchContext ::defaultContext(), inArr.getBuffer(), inArr.shapeInfo(), inArr.getBuffer(), inArr.shapeInfo()); +// ops::helpers::reverseArray(sd::LaunchContext ::defaultContext(), +// inArr.getBuffer(), inArr.shapeInfo(), inArr.getBuffer(), +// inArr.shapeInfo()); // // ASSERT_TRUE(inArr.equalsTo(&exp)); // ASSERT_TRUE(inArr.isSameShapeStrict(exp)); @@ -1462,13 +2185,16 @@ TEST_F(HelpersTests1, SVD_test17) { // // ///////////////////////////////////////////////////////////////////// -//TEST_F(HelpersTests1, reverseArray_test3) { +// TEST_F(HelpersTests1, reverseArray_test3) { // -// auto inArr = NDArrayFactory::create('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}); -// auto exp = NDArrayFactory::create('c', {2,5}, {5,4,3,2,1,6,7,8,9,10}); -// auto outArr = NDArrayFactory::create('c', {2,5}); +// auto inArr = NDArrayFactory::create('c', {2,5}, +// {1,2,3,4,5,6,7,8,9,10}); auto exp = NDArrayFactory::create('c', +// {2,5}, {5,4,3,2,1,6,7,8,9,10}); auto outArr = +// NDArrayFactory::create('c', {2,5}); // -// ops::helpers::reverseArray(sd::LaunchContext ::defaultContext(), inArr.getBuffer(), inArr.shapeInfo(), outArr.getBuffer(), outArr.shapeInfo(), 5); +// ops::helpers::reverseArray(sd::LaunchContext ::defaultContext(), +// inArr.getBuffer(), inArr.shapeInfo(), outArr.getBuffer(), +// outArr.shapeInfo(), 5); // // ASSERT_TRUE(outArr.equalsTo(&exp)); // ASSERT_TRUE(outArr.isSameShapeStrict(exp)); @@ -1476,1023 +2202,1639 @@ TEST_F(HelpersTests1, SVD_test17) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, rnnCell_test1) { + const int bS = 2; + const int inSize = 4; + const int numUnits = 4; - const int bS = 2; - const int inSize = 4; - const int numUnits = 4; - - NDArray xt('c', {bS, inSize}, sd::DataType::DOUBLE); - NDArray ht_1('c', {bS, numUnits}, sd::DataType::DOUBLE); - NDArray Wx('c', {inSize, numUnits}, sd::DataType::DOUBLE); - NDArray Wh('c', {numUnits, numUnits}, sd::DataType::DOUBLE); - NDArray b ('c', {2*numUnits}, {0.0,0.0,0.0,0.0, 0.1,0.2,0.3,0.4}); - NDArray ht('c', {bS, numUnits}, sd::DataType::DOUBLE); + NDArray xt('c', {bS, inSize}, sd::DataType::DOUBLE); + NDArray ht_1('c', {bS, numUnits}, sd::DataType::DOUBLE); + NDArray Wx('c', {inSize, numUnits}, sd::DataType::DOUBLE); + NDArray Wh('c', {numUnits, numUnits}, sd::DataType::DOUBLE); + NDArray b('c', {2 * numUnits}, {0.0, 0.0, 0.0, 0.0, 0.1, 0.2, 0.3, 0.4}); + NDArray ht('c', {bS, numUnits}, sd::DataType::DOUBLE); - xt.assign(0.1); - ht_1.assign(0.2); - Wx.assign(0.3); - Wh.assign(0.4); + xt.assign(0.1); + ht_1.assign(0.2); + Wx.assign(0.3); + Wh.assign(0.4); - NDArray expHt('c', {bS, numUnits}, {0.492988, 0.56489956, 0.6291452 , 0.6858091,0.492988, 0.56489956, 0.6291452 , 0.6858091}); + NDArray expHt('c', {bS, numUnits}, + {0.492988, 0.56489956, 0.6291452, 0.6858091, 0.492988, + 0.56489956, 0.6291452, 0.6858091}); - ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, + &ht_1, &ht); - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); } - /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, rnnCell_test2) { + const int bS = 2; + const int inSize = 10; + const int numUnits = 4; - const int bS = 2; - const int inSize = 10; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {bS, inSize}); - auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}, {0.0,0.0,0.0,0.0, 0.1,0.2,0.3,0.4}); + auto xt = NDArrayFactory::create('c', {bS, inSize}); + auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create( + 'c', {2 * numUnits}, {0.0, 0.0, 0.0, 0.0, 0.1, 0.2, 0.3, 0.4}); - auto ht = NDArrayFactory::create('c', {bS, numUnits}); + auto ht = NDArrayFactory::create('c', {bS, numUnits}); - xt.assign(0.1); - ht_1.assign(0.2); - Wx.assign(0.3); - Wh.assign(0.4); + xt.assign(0.1); + ht_1.assign(0.2); + Wx.assign(0.3); + Wh.assign(0.4); - auto expHt = NDArrayFactory::create('c', {bS, numUnits}, {0.6169093,0.67506987,0.72589741,0.76986654,0.6169093,0.67506987,0.72589741,0.76986654}); + auto expHt = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.6169093, 0.67506987, 0.72589741, 0.76986654, 0.6169093, 0.67506987, + 0.72589741, 0.76986654}); - ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, + &ht_1, &ht); - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, rnnCell_test3) { + const int bS = 2; + const int inSize = 10; + const int numUnits = 4; - const int bS = 2; - const int inSize = 10; - const int numUnits = 4; - - auto xt = NDArrayFactory::create('c', {bS, inSize}); - auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}, {0.01,0.02,0.03,0.04, 0.05,0.06,0.07,0.08}); + auto xt = NDArrayFactory::create('c', {bS, inSize}); + auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create( + 'c', {2 * numUnits}, {0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08}); - auto ht = NDArrayFactory::create('c', {bS, numUnits}); + auto ht = NDArrayFactory::create('c', {bS, numUnits}); - xt.assign(0.1); - ht_1.assign(0.2); - Wx.assign(0.3); - Wh.assign(0.4); + xt.assign(0.1); + ht_1.assign(0.2); + Wx.assign(0.3); + Wh.assign(0.4); - auto expHt = NDArrayFactory::create('c', {bS, numUnits}, {0.5915195, 0.6043678, 0.6169093, 0.6291452,0.5915195, 0.6043678, 0.6169093, 0.6291452}); + auto expHt = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.5915195, 0.6043678, 0.6169093, 0.6291452, 0.5915195, 0.6043678, + 0.6169093, 0.6291452}); - ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, + &ht_1, &ht); - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, rnnCell_test4) { + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; - const int bS = 2; - const int inSize = 3; - const int numUnits = 4; + auto xt = NDArrayFactory::create('c', {bS, inSize}); + auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2 * numUnits}); - auto xt = NDArrayFactory::create('c', {bS, inSize}); - auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); - auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); - auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2*numUnits}); + auto ht = NDArrayFactory::create('c', {bS, numUnits}); - auto ht = NDArrayFactory::create('c', {bS, numUnits}); + xt.linspace(0.01, 0.01); + ht_1 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; - xt.linspace(0.01, 0.01); - ht_1 = 0.2; - Wx = 0.3; - Wh = 0.4; - b = 0.25; + auto expHt = NDArrayFactory::create( + 'c', {bS, numUnits}, + {0.68474828, 0.68474828, 0.68474828, 0.68474828, 0.69882484, 0.69882484, + 0.69882484, 0.69882484}); - auto expHt = NDArrayFactory::create('c', {bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484}); + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, + &ht_1, &ht); - ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); - - ASSERT_TRUE(expHt.isSameShape(ht)); - ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); } #endif //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_1) { + auto x = NDArrayFactory::create('c', {3, 3}, + {10, 11, 12, 13, 14, 15, 16, 17, 18}); + auto y = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {138., 171., 204., 174., 216., 258., 210., 261., 312.}); - auto x = NDArrayFactory::create('c', {3,3}, {10,11,12,13,14,15,16,17,18}); - auto y = NDArrayFactory::create('c', {3,3}, {1,2,3,4,5,6,7,8,9}); - auto expected = NDArrayFactory::create('c', {3,3}, {138.,171.,204. ,174.,216.,258. ,210.,261.,312.}); - - auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); + auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - delete result; + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete result; } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_2) { + auto x = NDArrayFactory::create('c', {3, 3}, + {10, 11, 12, 13, 14, 15, 16, 17, 18}); + auto y = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto expected = NDArrayFactory::create( + 'c', {3, 3}, {138., 171., 204., 174., 216., 258., 210., 261., 312.}); + auto result = NDArrayFactory::create('c', {3, 3}); - auto x = NDArrayFactory::create('c', {3,3}, {10,11,12,13,14,15,16,17,18}); - auto y = NDArrayFactory::create('c', {3,3}, {1,2,3,4,5,6,7,8,9}); - auto expected = NDArrayFactory::create('c', {3,3}, {138.,171.,204. ,174.,216.,258. ,210.,261.,312.}); - auto result = NDArrayFactory::create('c', {3,3}); - - MmulHelper::mmul(&x, &y, &result, 1., 0.); - - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + MmulHelper::mmul(&x, &y, &result, 1., 0.); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_3) { + auto x = NDArrayFactory::create('c', {3, 4}); + x.linspace(1); + auto y = NDArrayFactory::create('c', {4, 5}); + y.linspace(1); + auto expected = NDArrayFactory::create( + 'c', {3, 5}, + {110., 120., 130., 140., 150., 246., 272., 298., 324., 350., 382., 424., + 466., 508., 550.}); - auto x = NDArrayFactory::create('c', {3,4}); x.linspace(1); - auto y = NDArrayFactory::create('c', {4,5}); y.linspace(1); - auto expected = NDArrayFactory::create('c', {3,5}, {110.,120.,130.,140.,150.,246.,272.,298.,324.,350.,382.,424.,466.,508.,550.}); + auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); - auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); - - delete result; + delete result; } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_4) { + auto x = NDArrayFactory::create('c', {3, 4}); + x.linspace(1); + auto y = NDArrayFactory::create('c', {4, 5}); + y.linspace(1); + auto expected = NDArrayFactory::create( + 'c', {3, 5}, + {110., 120., 130., 140., 150., 246., 272., 298., 324., 350., 382., 424., + 466., 508., 550.}); + auto result = NDArrayFactory::create('c', {3, 5}); - auto x = NDArrayFactory::create('c', {3,4}); x.linspace(1); - auto y = NDArrayFactory::create('c', {4,5}); y.linspace(1); - auto expected = NDArrayFactory::create('c', {3,5}, {110.,120.,130.,140.,150.,246.,272.,298.,324.,350.,382.,424.,466.,508.,550.}); - auto result = NDArrayFactory::create('c', {3,5}); - - MmulHelper::mmul(&x, &y, &result, 1., 0.); + MmulHelper::mmul(&x, &y, &result, 1., 0.); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_5) { + auto x = NDArrayFactory::create('c', {4, 3}); + x.linspace(1); + auto y = NDArrayFactory::create('c', {3, 5}); + y.linspace(1); + auto expected = NDArrayFactory::create( + 'c', {4, 5}, + {46., 52., 58., 64., 70., 100., 115., 130., 145., 160., + 154., 178., 202., 226., 250., 208., 241., 274., 307., 340.}); - auto x = NDArrayFactory::create('c', {4,3}); x.linspace(1); - auto y = NDArrayFactory::create('c', {3,5}); y.linspace(1); - auto expected = NDArrayFactory::create('c', {4,5}, {46., 52., 58., 64., 70.,100.,115.,130.,145.,160.,154.,178.,202.,226.,250.,208.,241.,274.,307.,340.}); - - auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); + auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); - ASSERT_TRUE(expected.isSameShape(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); - delete result; + delete result; } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_6) { + auto x = NDArrayFactory::create('c', {4, 3}); + x.linspace(1); + auto y = NDArrayFactory::create('c', {3, 5}); + y.linspace(1); + auto expected = NDArrayFactory::create( + 'c', {4, 5}, + {46., 52., 58., 64., 70., 100., 115., 130., 145., 160., + 154., 178., 202., 226., 250., 208., 241., 274., 307., 340.}); + auto result = NDArrayFactory::create('c', {4, 5}); - auto x = NDArrayFactory::create('c', {4,3}); x.linspace(1); - auto y = NDArrayFactory::create('c', {3,5}); y.linspace(1); - auto expected = NDArrayFactory::create('c', {4,5}, {46., 52., 58., 64., 70.,100.,115.,130.,145.,160.,154.,178.,202.,226.,250.,208.,241.,274.,307.,340.}); - auto result = NDArrayFactory::create('c', {4,5}); - - MmulHelper::mmul(&x, &y, &result, 1., 0.); - - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + MmulHelper::mmul(&x, &y, &result, 1., 0.); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_7) { + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, {1, 2, 3, 4, 2, 4, 6, 8, 3, 6, 9, 12, 4, 8, 12, 16}); + auto result = NDArrayFactory::create('c', {4, 4}); - auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); - auto result = NDArrayFactory::create('c', {4,4}); - - MmulHelper::mmul(&x, &y, &result, 1., 0.); - - ASSERT_TRUE(exp.isSameShape(&result)); - ASSERT_TRUE(exp.equalsTo(&result)); + MmulHelper::mmul(&x, &y, &result, 1., 0.); + ASSERT_TRUE(exp.isSameShape(&result)); + ASSERT_TRUE(exp.equalsTo(&result)); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_1) { + auto a = NDArrayFactory::create('c', {2, 3, 4}); + auto b = NDArrayFactory::create('c', {2, 5, 3}); - auto a = NDArrayFactory::create('c', {2, 3, 4}); - auto b = NDArrayFactory::create('c', {2, 5, 3}); - - auto c = MmulHelper::tensorDot(&a, &b, {1}, {2}); + auto c = MmulHelper::tensorDot(&a, &b, {1}, {2}); - ASSERT_TRUE(c->isSameShape({2,4,2,5})); - delete c; + ASSERT_TRUE(c->isSameShape({2, 4, 2, 5})); + delete c; } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_2) { + auto a = NDArrayFactory::create('c', {7, 3, 4, 6}); + auto b = NDArrayFactory::create('c', {2, 5, 3, 8, 4}); - auto a = NDArrayFactory::create('c', {7, 3, 4, 6}); - auto b = NDArrayFactory::create('c', {2, 5, 3, 8, 4}); - - auto c = MmulHelper::tensorDot(&a, &b, {2,1}, {4,2}); + auto c = MmulHelper::tensorDot(&a, &b, {2, 1}, {4, 2}); - ASSERT_TRUE(c->isSameShape({7,6,2,5,8})); - delete c; + ASSERT_TRUE(c->isSameShape({7, 6, 2, 5, 8})); + delete c; } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_3) { + auto a = NDArrayFactory::create('c', {7, 3, 4, 6}); + auto b = NDArrayFactory::create('c', {2, 5, 3, 8, 4}); + auto c = NDArrayFactory::create('f', {7, 6, 2, 8, 5}); - auto a = NDArrayFactory::create('c', {7, 3, 4, 6}); - auto b = NDArrayFactory::create('c', {2, 5, 3, 8, 4}); - auto c = NDArrayFactory::create('f', {7,6,2,8,5}); + MmulHelper::tensorDot(&a, &b, &c, {2, 1}, {4, 2}, {0, 1, 2, 4, 3}); - MmulHelper::tensorDot(&a, &b, &c, {2,1}, {4,2}, {0,1,2,4,3}); - - ASSERT_TRUE(c.isSameShape({7,6,2,8,5})); + ASSERT_TRUE(c.isSameShape({7, 6, 2, 8, 5})); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_4) { - - auto a = NDArrayFactory::create('c', {7, 3, 4, 3}); - auto b = NDArrayFactory::create('c', {2, 5, 3, 2, 4}); - auto c = NDArrayFactory::create('f', {7,3,2,2,5}); - auto expected = NDArrayFactory::create('c', {7,3,2,2,5}, { 754.5, 2014.5, 3274.5, 4534.5 , 5794.5, 964.5, 2224.5, 3484.5, 4744.5, 6004.5, 7054.5, 8314.5, 9574.5, 10834.5, 12094.5, 7264.5, 8524.5, 9784.5, 11044.5, 12304.5, 786. , 2118. , 3450. , 4782. , 6114. , 1008. , 2340. , 3672. , 5004. , 6336. , - 7446. , 8778. , 10110. , 11442. , 12774. , 7668. , 9000. , 10332. , 11664. , 12996. , 817.5, 2221.5, 3625.5, 5029.5, 6433.5, 1051.5, 2455.5, 3859.5, 5263.5, 6667.5, 7837.5, 9241.5, 10645.5, 12049.5, 13453.5, 8071.5, 9475.5, 10879.5, 12283.5, 13687.5, - 1888.5, 5740.5, 9592.5, 13444.5, 17296.5, 2530.5, 6382.5, 10234.5, 14086.5, 17938.5,21148.5, 25000.5, 28852.5, 32704.5, 36556.5,21790.5, 25642.5, 29494.5, 33346.5, 37198.5, 1920. , 5844. , 9768. , 13692. , 17616. , 2574. , 6498. , 10422. , 14346. , 18270. , - 21540. , 25464. , 29388. , 33312. , 37236. ,22194. , 26118. , 30042. , 33966. , 37890. , 1951.5, 5947.5, 9943.5, 13939.5, 17935.5, 2617.5, 6613.5, 10609.5, 14605.5, 18601.5,21931.5, 25927.5, 29923.5, 33919.5, 37915.5,22597.5, 26593.5, 30589.5, 34585.5, 38581.5, - 3022.5, 9466.5, 15910.5, 22354.5, 28798.5, 4096.5, 10540.5, 16984.5, 23428.5, 29872.5,35242.5, 41686.5, 48130.5, 54574.5, 61018.5,36316.5, 42760.5, 49204.5, 55648.5, 62092.5, 3054. , 9570. , 16086. , 22602. , 29118. , 4140. , 10656. , 17172. , 23688. , 30204. , - 35634. , 42150. , 48666. , 55182. , 61698. ,36720. , 43236. , 49752. , 56268. , 62784. , 3085.5, 9673.5, 16261.5, 22849.5, 29437.5, 4183.5, 10771.5, 17359.5, 23947.5, 30535.5,36025.5, 42613.5, 49201.5, 55789.5, 62377.5,37123.5, 43711.5, 50299.5, 56887.5, 63475.5, - 4156.5, 13192.5, 22228.5, 31264.5, 40300.5, 5662.5, 14698.5, 23734.5, 32770.5, 41806.5,49336.5, 58372.5, 67408.5, 76444.5, 85480.5,50842.5, 59878.5, 68914.5, 77950.5, 86986.5, 4188. , 13296. , 22404. , 31512. , 40620. , 5706. , 14814. , 23922. , 33030. , 42138. , - 49728. , 58836. , 67944. , 77052. , 86160. ,51246. , 60354. , 69462. , 78570. , 87678. , 4219.5, 13399.5, 22579.5, 31759.5, 40939.5, 5749.5, 14929.5, 24109.5, 33289.5, 42469.5,50119.5, 59299.5, 68479.5, 77659.5, 86839.5,51649.5, 60829.5, 70009.5, 79189.5, 88369.5, - 5290.5, 16918.5, 28546.5, 40174.5, 51802.5, 7228.5, 18856.5, 30484.5, 42112.5, 53740.5,63430.5, 75058.5, 86686.5, 98314.5,109942.5,65368.5, 76996.5, 88624.5,100252.5,111880.5, 5322. , 17022. , 28722. , 40422. , 52122. , 7272. , 18972. , 30672. , 42372. , 54072. , - 63822. , 75522. , 87222. , 98922. ,110622. ,65772. , 77472. , 89172. ,100872. ,112572. , 5353.5, 17125.5, 28897.5, 40669.5, 52441.5, 7315.5, 19087.5, 30859.5, 42631.5, 54403.5,64213.5, 75985.5, 87757.5, 99529.5,111301.5,66175.5, 77947.5, 89719.5,101491.5,113263.5, - 6424.5, 20644.5, 34864.5, 49084.5, 63304.5, 8794.5, 23014.5, 37234.5, 51454.5, 65674.5,77524.5, 91744.5,105964.5,120184.5,134404.5,79894.5, 94114.5,108334.5,122554.5,136774.5, 6456. , 20748. , 35040. , 49332. , 63624. , 8838. , 23130. , 37422. , 51714. , 66006. , - 77916. , 92208. ,106500. ,120792. ,135084. ,80298. , 94590. ,108882. ,123174. ,137466. , 6487.5, 20851.5, 35215.5, 49579.5, 63943.5, 8881.5, 23245.5, 37609.5, 51973.5, 66337.5,78307.5, 92671.5,107035.5,121399.5,135763.5,80701.5, 95065.5,109429.5,123793.5,138157.5, - 7558.5, 24370.5, 41182.5, 57994.5, 74806.5,10360.5, 27172.5, 43984.5, 60796.5, 77608.5,91618.5,108430.5,125242.5,142054.5,158866.5,94420.5,111232.5,128044.5,144856.5,161668.5, 7590. , 24474. , 41358. , 58242. , 75126. ,10404. , 27288. , 44172. , 61056. , 77940. , - 92010. ,108894. ,125778. ,142662. ,159546. ,94824. ,111708. ,128592. ,145476. ,162360. , 7621.5, 24577.5, 41533.5, 58489.5, 75445.5,10447.5, 27403.5, 44359.5, 61315.5, 78271.5,92401.5,109357.5,126313.5,143269.5,160225.5,95227.5,112183.5,129139.5,146095.5,163051.5}); - - a.linspace(0.5, 0.5); - b.linspace(0.5, 0.5); - - MmulHelper::tensorDot(&a, &b, &c, {2,1}, {4,2}, {0,1,2,4,3}); - - ASSERT_TRUE(c.isSameShape(expected)); - ASSERT_TRUE(c.equalsTo(expected)); + auto a = NDArrayFactory::create('c', {7, 3, 4, 3}); + auto b = NDArrayFactory::create('c', {2, 5, 3, 2, 4}); + auto c = NDArrayFactory::create('f', {7, 3, 2, 2, 5}); + auto expected = NDArrayFactory::create( + 'c', {7, 3, 2, 2, 5}, + {754.5, 2014.5, 3274.5, 4534.5, 5794.5, 964.5, 2224.5, + 3484.5, 4744.5, 6004.5, 7054.5, 8314.5, 9574.5, 10834.5, + 12094.5, 7264.5, 8524.5, 9784.5, 11044.5, 12304.5, 786., + 2118., 3450., 4782., 6114., 1008., 2340., 3672., + 5004., 6336., 7446., 8778., 10110., 11442., 12774., + 7668., 9000., 10332., 11664., 12996., 817.5, 2221.5, + 3625.5, 5029.5, 6433.5, 1051.5, 2455.5, 3859.5, 5263.5, + 6667.5, 7837.5, 9241.5, 10645.5, 12049.5, 13453.5, 8071.5, + 9475.5, 10879.5, 12283.5, 13687.5, 1888.5, 5740.5, 9592.5, + 13444.5, 17296.5, 2530.5, 6382.5, 10234.5, 14086.5, 17938.5, + 21148.5, 25000.5, 28852.5, 32704.5, 36556.5, 21790.5, 25642.5, + 29494.5, 33346.5, 37198.5, 1920., 5844., 9768., 13692., + 17616., 2574., 6498., 10422., 14346., 18270., 21540., + 25464., 29388., 33312., 37236., 22194., 26118., 30042., + 33966., 37890., 1951.5, 5947.5, 9943.5, 13939.5, 17935.5, + 2617.5, 6613.5, 10609.5, 14605.5, 18601.5, 21931.5, 25927.5, + 29923.5, 33919.5, 37915.5, 22597.5, 26593.5, 30589.5, 34585.5, + 38581.5, 3022.5, 9466.5, 15910.5, 22354.5, 28798.5, 4096.5, + 10540.5, 16984.5, 23428.5, 29872.5, 35242.5, 41686.5, 48130.5, + 54574.5, 61018.5, 36316.5, 42760.5, 49204.5, 55648.5, 62092.5, + 3054., 9570., 16086., 22602., 29118., 4140., 10656., + 17172., 23688., 30204., 35634., 42150., 48666., 55182., + 61698., 36720., 43236., 49752., 56268., 62784., 3085.5, + 9673.5, 16261.5, 22849.5, 29437.5, 4183.5, 10771.5, 17359.5, + 23947.5, 30535.5, 36025.5, 42613.5, 49201.5, 55789.5, 62377.5, + 37123.5, 43711.5, 50299.5, 56887.5, 63475.5, 4156.5, 13192.5, + 22228.5, 31264.5, 40300.5, 5662.5, 14698.5, 23734.5, 32770.5, + 41806.5, 49336.5, 58372.5, 67408.5, 76444.5, 85480.5, 50842.5, + 59878.5, 68914.5, 77950.5, 86986.5, 4188., 13296., 22404., + 31512., 40620., 5706., 14814., 23922., 33030., 42138., + 49728., 58836., 67944., 77052., 86160., 51246., 60354., + 69462., 78570., 87678., 4219.5, 13399.5, 22579.5, 31759.5, + 40939.5, 5749.5, 14929.5, 24109.5, 33289.5, 42469.5, 50119.5, + 59299.5, 68479.5, 77659.5, 86839.5, 51649.5, 60829.5, 70009.5, + 79189.5, 88369.5, 5290.5, 16918.5, 28546.5, 40174.5, 51802.5, + 7228.5, 18856.5, 30484.5, 42112.5, 53740.5, 63430.5, 75058.5, + 86686.5, 98314.5, 109942.5, 65368.5, 76996.5, 88624.5, 100252.5, + 111880.5, 5322., 17022., 28722., 40422., 52122., 7272., + 18972., 30672., 42372., 54072., 63822., 75522., 87222., + 98922., 110622., 65772., 77472., 89172., 100872., 112572., + 5353.5, 17125.5, 28897.5, 40669.5, 52441.5, 7315.5, 19087.5, + 30859.5, 42631.5, 54403.5, 64213.5, 75985.5, 87757.5, 99529.5, + 111301.5, 66175.5, 77947.5, 89719.5, 101491.5, 113263.5, 6424.5, + 20644.5, 34864.5, 49084.5, 63304.5, 8794.5, 23014.5, 37234.5, + 51454.5, 65674.5, 77524.5, 91744.5, 105964.5, 120184.5, 134404.5, + 79894.5, 94114.5, 108334.5, 122554.5, 136774.5, 6456., 20748., + 35040., 49332., 63624., 8838., 23130., 37422., 51714., + 66006., 77916., 92208., 106500., 120792., 135084., 80298., + 94590., 108882., 123174., 137466., 6487.5, 20851.5, 35215.5, + 49579.5, 63943.5, 8881.5, 23245.5, 37609.5, 51973.5, 66337.5, + 78307.5, 92671.5, 107035.5, 121399.5, 135763.5, 80701.5, 95065.5, + 109429.5, 123793.5, 138157.5, 7558.5, 24370.5, 41182.5, 57994.5, + 74806.5, 10360.5, 27172.5, 43984.5, 60796.5, 77608.5, 91618.5, + 108430.5, 125242.5, 142054.5, 158866.5, 94420.5, 111232.5, 128044.5, + 144856.5, 161668.5, 7590., 24474., 41358., 58242., 75126., + 10404., 27288., 44172., 61056., 77940., 92010., 108894., + 125778., 142662., 159546., 94824., 111708., 128592., 145476., + 162360., 7621.5, 24577.5, 41533.5, 58489.5, 75445.5, 10447.5, + 27403.5, 44359.5, 61315.5, 78271.5, 92401.5, 109357.5, 126313.5, + 143269.5, 160225.5, 95227.5, 112183.5, 129139.5, 146095.5, 163051.5}); + + a.linspace(0.5, 0.5); + b.linspace(0.5, 0.5); + + MmulHelper::tensorDot(&a, &b, &c, {2, 1}, {4, 2}, {0, 1, 2, 4, 3}); + + ASSERT_TRUE(c.isSameShape(expected)); + ASSERT_TRUE(c.equalsTo(expected)); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_5) { + auto a = NDArrayFactory::create('c', {2, 3}); + auto b = NDArrayFactory::create('c', {3, 4}); + auto c = NDArrayFactory::create('f', {2, 4}); + auto expected = NDArrayFactory::create( + 'c', {2, 4}, {9.5, 11., 12.5, 14., 20.75, 24.5, 28.25, 32.}); - auto a = NDArrayFactory::create('c', {2, 3}); - auto b = NDArrayFactory::create('c', {3, 4}); - auto c = NDArrayFactory::create('f', {2, 4}); - auto expected = NDArrayFactory::create('c', {2, 4}, {9.5,11.,12.5 ,14.,20.75 ,24.5,28.25,32.}); + a.linspace(0.5, 0.5); + b.linspace(0.5, 0.5); - a.linspace(0.5, 0.5); - b.linspace(0.5, 0.5); + MmulHelper::tensorDot(&a, &b, &c, {1}, {0}); + // c.printIndexedBuffer(); - MmulHelper::tensorDot(&a, &b, &c, {1}, {0}); - // c.printIndexedBuffer(); - - ASSERT_TRUE(c.isSameShape(expected)); - ASSERT_TRUE(c.equalsTo(expected)); + ASSERT_TRUE(c.isSameShape(expected)); + ASSERT_TRUE(c.equalsTo(expected)); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_6) { + int bS = 2, iH = 3, iW = 2, iC = 2, mC = 2, kH = 2, kW = 2; + int oC = iC * mC; + int oH = 3, oW = 2; - int bS=2, iH=3,iW=2, iC=2,mC=2, kH=2,kW=2; - int oC=iC*mC; - int oH=3,oW=2; - - auto a = NDArrayFactory::create('c', {bS, iC, kH, kW, oH, oW}); - auto b = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto c = NDArrayFactory::create('c', {bS, oH, oW, iC*mC}); - auto expected = NDArrayFactory::create('c', {bS, oH, oW, iC*mC}, {100.,110.,336.,370.,107.,118.,345.,380.,114.,126.,354.,390.,121.,134.,363.,400.,128.,142.,372.,410.,135.,150.,381.,420., - 436.,494.,768.,850.,443.,502.,777.,860.,450.,510.,786.,870.,457.,518.,795.,880.,464.,526.,804.,890.,471.,534.,813.,900.}); + auto a = NDArrayFactory::create('c', {bS, iC, kH, kW, oH, oW}); + auto b = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto c = NDArrayFactory::create('c', {bS, oH, oW, iC * mC}); + auto expected = NDArrayFactory::create( + 'c', {bS, oH, oW, iC * mC}, + {100., 110., 336., 370., 107., 118., 345., 380., 114., 126., 354., 390., + 121., 134., 363., 400., 128., 142., 372., 410., 135., 150., 381., 420., + 436., 494., 768., 850., 443., 502., 777., 860., 450., 510., 786., 870., + 457., 518., 795., 880., 464., 526., 804., 890., 471., 534., 813., 900.}); - a.linspace(0.5, 0.5); - b.linspace(0.5, 0.5); + a.linspace(0.5, 0.5); + b.linspace(0.5, 0.5); - auto cR = c.reshape(a.ordering(), {bS, oH, oW, iC, mC}); + auto cR = c.reshape(a.ordering(), {bS, oH, oW, iC, mC}); - // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] - MmulHelper::tensorDot(&a, &b, &cR, {{1,0,4,5,2,3}, {iC,bS*oH*oW,kW*kH}}, {{2,0,1,3},{iC,kH*kW,mC}}, {{3,0,1,2,4},{iC, bS*oH*oW, mC}}); - // c.printBuffer(); + // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] + MmulHelper::tensorDot(&a, &b, &cR, + {{1, 0, 4, 5, 2, 3}, {iC, bS * oH * oW, kW * kH}}, + {{2, 0, 1, 3}, {iC, kH * kW, mC}}, + {{3, 0, 1, 2, 4}, {iC, bS * oH * oW, mC}}); + // c.printBuffer(); - ASSERT_TRUE(c.isSameShape(expected)); - ASSERT_TRUE(c.equalsTo(expected)); + ASSERT_TRUE(c.isSameShape(expected)); + ASSERT_TRUE(c.equalsTo(expected)); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmmulHelperAgain) { - auto x = NDArrayFactory::create('c', {128, 156}); - auto y = NDArrayFactory::create('c', {156, 256}); - auto z = NDArrayFactory::create('c', {128, 256}); - auto e = NDArrayFactory::create('c', {128, 256}); + auto x = NDArrayFactory::create('c', {128, 156}); + auto y = NDArrayFactory::create('c', {156, 256}); + auto z = NDArrayFactory::create('c', {128, 256}); + auto e = NDArrayFactory::create('c', {128, 256}); - x.assign(1.0f); - y.assign(1.0f); - e.assign(156.0f); + x.assign(1.0f); + y.assign(1.0f); + e.assign(156.0f); - MmulHelper::mmul(&x, &y, &z); + MmulHelper::mmul(&x, &y, &z); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, OpArgsHolder_test1) { + auto x1 = NDArrayFactory::create('c', {1, 1}); + auto x2 = NDArrayFactory::create('c', {2, 2}); + auto x3 = NDArrayFactory::create('c', {3, 3}); - auto x1 = NDArrayFactory::create('c', {1, 1}); - auto x2 = NDArrayFactory::create('c', {2, 2}); - auto x3 = NDArrayFactory::create('c', {3, 3}); - - OpArgsHolder holder1({&x1}); - OpArgsHolder holder2({&x1,&x2,&x3}, {4.f, 5.f}, {6}); + OpArgsHolder holder1({&x1}); + OpArgsHolder holder2({&x1, &x2, &x3}, {4.f, 5.f}, {6}); - ASSERT_TRUE(holder1.getNumInArrs() == 1); - ASSERT_TRUE(holder1.getNumTArgs() == 0); - ASSERT_TRUE(holder1.getNumIArgs() == 0); + ASSERT_TRUE(holder1.getNumInArrs() == 1); + ASSERT_TRUE(holder1.getNumTArgs() == 0); + ASSERT_TRUE(holder1.getNumIArgs() == 0); - ASSERT_TRUE(holder2.getNumInArrs() == 3); - ASSERT_TRUE(holder2.getNumTArgs() == 2); - ASSERT_TRUE(holder2.getNumIArgs() == 1); + ASSERT_TRUE(holder2.getNumInArrs() == 3); + ASSERT_TRUE(holder2.getNumTArgs() == 2); + ASSERT_TRUE(holder2.getNumIArgs() == 1); - const std::vector& isArrAlloc1 = holder1.getAllocInfo(); - ASSERT_TRUE(isArrAlloc1.size() == 0); + const std::vector& isArrAlloc1 = holder1.getAllocInfo(); + ASSERT_TRUE(isArrAlloc1.size() == 0); - const std::vector& isArrAlloc2 = holder2.getAllocInfo(); - ASSERT_TRUE(isArrAlloc2.size() == 0); + const std::vector& isArrAlloc2 = holder2.getAllocInfo(); + ASSERT_TRUE(isArrAlloc2.size() == 0); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, OpArgsHolder_test2) { + auto x1 = NDArrayFactory::create('c', {1, 1}); + auto x2 = NDArrayFactory::create('c', {2, 2}); + auto x3 = NDArrayFactory::create('c', {3, 3}); + auto grad = NDArrayFactory::create('c', {2, 3}); - auto x1 = NDArrayFactory::create('c', {1, 1}); - auto x2 = NDArrayFactory::create('c', {2, 2}); - auto x3 = NDArrayFactory::create('c', {3, 3}); - auto grad = NDArrayFactory::create('c', {2, 3}); + OpArgsHolder holderFF({&x1, &x2, &x3}, {4.f, 5.f}, {6}); + OpArgsHolder holderBP1 = holderFF.createArgsHolderForBP({&grad}); + OpArgsHolder holderBP2 = holderFF.createArgsHolderForBP({&grad}, true); - OpArgsHolder holderFF({&x1,&x2,&x3}, {4.f, 5.f}, {6}); - OpArgsHolder holderBP1 = holderFF.createArgsHolderForBP({&grad}); - OpArgsHolder holderBP2 = holderFF.createArgsHolderForBP({&grad}, true); + ASSERT_TRUE(holderBP1.getNumInArrs() == 4); + ASSERT_TRUE(holderBP1.getNumTArgs() == 2); + ASSERT_TRUE(holderBP1.getNumIArgs() == 1); + ASSERT_TRUE(holderBP2.getNumInArrs() == 4); + ASSERT_TRUE(holderBP2.getNumTArgs() == 2); + ASSERT_TRUE(holderBP2.getNumIArgs() == 1); - ASSERT_TRUE(holderBP1.getNumInArrs() == 4); - ASSERT_TRUE(holderBP1.getNumTArgs() == 2); - ASSERT_TRUE(holderBP1.getNumIArgs() == 1); - ASSERT_TRUE(holderBP2.getNumInArrs() == 4); - ASSERT_TRUE(holderBP2.getNumTArgs() == 2); - ASSERT_TRUE(holderBP2.getNumIArgs() == 1); + const std::vector& isArrAllocBP1 = holderBP1.getAllocInfo(); + ASSERT_TRUE(isArrAllocBP1.size() == 0); - const std::vector& isArrAllocBP1 = holderBP1.getAllocInfo(); - ASSERT_TRUE(isArrAllocBP1.size() == 0); + const std::vector& isArrAllocBP2 = holderBP2.getAllocInfo(); + for (int i = 0; i < holderFF.getNumInArrs(); ++i) { + ASSERT_TRUE(static_cast(isArrAllocBP2[i]) == true); + } - const std::vector& isArrAllocBP2 = holderBP2.getAllocInfo(); - for(int i = 0; i < holderFF.getNumInArrs(); ++i) { - ASSERT_TRUE(static_cast(isArrAllocBP2[i]) == true); - } - - ASSERT_TRUE(static_cast(isArrAllocBP2[holderFF.getNumInArrs()+1]) == false); + ASSERT_TRUE(static_cast(isArrAllocBP2[holderFF.getNumInArrs() + 1]) == + false); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, OpArgsHolder_test3) { - - auto input = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); - auto gradO = NDArrayFactory::create('c', {4, 9}); - auto exp = NDArrayFactory::create('c', {4, 9}, {1, 2, 3, 1, 2, 3, 1, 2, 3,4, 5, 6, 4, 5, 6, 4, 5, 6,1, 2, 3, 1, 2, 3, 1, 2, 3,4, 5, 6, 4, 5, 6, 4, 5, 6}); - auto gradIExp = NDArrayFactory::create('c', {2, 3}, {0.78, 0.84, 0.9,1.32, 1.38, 1.44}); - - gradO.linspace(0.01, 0.01); - - OpArgsHolder holderFF({&input}, {}, {2, 3}); - sd::ops::tile opFF; // the kind of op doesn't matter, we simply check here whether op.execute() works with OpArgsHolder correctly - auto results = opFF.execute(holderFF); - auto tiled = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(exp.isSameShape(tiled)); - ASSERT_TRUE(exp.equalsTo(tiled)); - - OpArgsHolder holderBP = holderFF.createArgsHolderForBP({&gradO}, true); - sd::ops::tile_bp opBP; - results = opBP.execute(holderBP); - auto gradI = results.at(0); - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(gradIExp.isSameShape(gradI)); - ASSERT_TRUE(gradIExp.equalsTo(gradI)); - + auto input = + NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto gradO = NDArrayFactory::create('c', {4, 9}); + auto exp = NDArrayFactory::create( + 'c', {4, 9}, {1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, + 1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6}); + auto gradIExp = NDArrayFactory::create( + 'c', {2, 3}, {0.78, 0.84, 0.9, 1.32, 1.38, 1.44}); + + gradO.linspace(0.01, 0.01); + + OpArgsHolder holderFF({&input}, {}, {2, 3}); + sd::ops::tile opFF; // the kind of op doesn't matter, we simply check here + // whether op.execute() works with OpArgsHolder correctly + auto results = opFF.execute(holderFF); + auto tiled = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(tiled)); + ASSERT_TRUE(exp.equalsTo(tiled)); + + OpArgsHolder holderBP = holderFF.createArgsHolderForBP({&gradO}, true); + sd::ops::tile_bp opBP; + results = opBP.execute(holderBP); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); } - ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test1) { + auto x = NDArrayFactory::create('c', {2, 3}, + {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}); + auto gradO = NDArrayFactory::create('c', {2, 3}); - auto x = NDArrayFactory::create('c', {2, 3}, {0.1, 0.2, 0.3, 0.4, 0.5 ,0.6}); - auto gradO = NDArrayFactory::create('c', {2, 3}); - - const OpArgsHolder argsHolderFF({&x}, {}, {}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {}); + const OpArgsHolder argsHolderFF({&x}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {}); - sd::ops::sigmoid opFF; - sd::ops::sigmoid_bp opBP; + sd::ops::sigmoid opFF; + sd::ops::sigmoid_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test2) { + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); - auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); - auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + x.linspace(1); + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 1, 0}); - x.linspace(1); - weights.linspace(0.1, 0.1); - weights.permutei({2,3,1,0}); + const OpArgsHolder argsHolderFF({&x, &weights}, {}, + {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &gradO}, {}, + {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderFF({&x, &weights}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderBP({&x, &weights, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + sd::ops::conv2d opFF; + sd::ops::conv2d_bp opBP; - sd::ops::conv2d opFF; - sd::ops::conv2d_bp opBP; + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test3) { + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); - auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); - auto bias = NDArrayFactory::create('c', {2, 1}); - auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); - - x.linspace(1); - weights.linspace(0.1, 0.1); - bias = 0.5; - weights.permutei({2,3,1,0}); + x.linspace(1); + weights.linspace(0.1, 0.1); + bias = 0.5; + weights.permutei({2, 3, 1, 0}); - const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, + {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, + {2, 2, 1, 1, 0, 0, 1, 1, 1}); - sd::ops::conv2d opFF; - sd::ops::conv2d_bp opBP; + sd::ops::conv2d opFF; + sd::ops::conv2d_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test4) { + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); - auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); - auto bias = NDArrayFactory::create('c', {2, 1}); - auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); - - x.linspace(1); - weights.linspace(0.1, 0.1); - bias = 0.5; - weights.permutei({2,3,1,0}); + x.linspace(1); + weights.linspace(0.1, 0.1); + bias = 0.5; + weights.permutei({2, 3, 1, 0}); - const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, + {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, + {2, 2, 1, 1, 0, 0, 1, 1, 1}); - sd::ops::conv2d opFF; - sd::ops::conv2d_bp opBP; + sd::ops::conv2d opFF; + sd::ops::conv2d_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 0, 1}); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 0, 1}); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test5) { + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); - auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); - auto bias = NDArrayFactory::create('c', {2, 1}); - auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + x.linspace(1); + weights.linspace(0.1, 0.1); + bias = 0.5; + weights.permutei({2, 3, 1, 0}); - x.linspace(1); - weights.linspace(0.1, 0.1); - bias = 0.5; - weights.permutei({2,3,1,0}); + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, + {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, + {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + sd::ops::conv2d opFF; + sd::ops::conv2d_bp opBP; - sd::ops::conv2d opFF; - sd::ops::conv2d_bp opBP; + const bool isGradCorrect = GradCheck::checkGrad( + opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1}, {0.5, 1}); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1}, {0.5, 1}); - - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test6) { + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); - auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); - auto bias = NDArrayFactory::create('c', {2, 1}); - auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); - - x.linspace(1); - weights.linspace(0.1, 0.1); - bias = 0.5; - weights.permutei({2,3,1,0}); + x.linspace(1); + weights.linspace(0.1, 0.1); + bias = 0.5; + weights.permutei({2, 3, 1, 0}); - const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, + {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, + {2, 2, 1, 1, 0, 0, 1, 1, 1}); - sd::ops::conv2d opFF; - sd::ops::conv2d_bp opBP; + sd::ops::conv2d opFF; + sd::ops::conv2d_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 0, 1}, {0.5, 1}, GradCheck::MEAN); + const bool isGradCorrect = + GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 0, 1}, + {0.5, 1}, GradCheck::MEAN); - ASSERT_TRUE(isGradCorrect); + ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softMaxForVector_test1) { + auto input = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto output = NDArrayFactory::create('c', {1, 5}); + auto expOutput = NDArrayFactory::create('c', {1, 5}); + expOutput = 1; - auto input = NDArrayFactory::create('c', {1,5}, {1,2,3,4,5}); - auto output = NDArrayFactory::create('c', {1,5}); - auto expOutput = NDArrayFactory::create('c', {1,5}); - expOutput = 1; + ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); - ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); - - ASSERT_TRUE(output.equalsTo(&expOutput)); + ASSERT_TRUE(output.equalsTo(&expOutput)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softMaxForVector_test2) { + auto input = NDArrayFactory::create('c', {5, 1}, {1, 2, 3, 4, 5}); + auto output = NDArrayFactory::create('c', {5, 1}); + auto expOutput = NDArrayFactory::create( + 'c', {5, 1}, + {0.01165623, 0.03168492, 0.08612854, 0.23412166, 0.63640865}); - auto input = NDArrayFactory::create('c', {5,1}, {1,2,3,4,5}); - auto output = NDArrayFactory::create('c', {5,1}); - auto expOutput = NDArrayFactory::create('c', {5,1}, {0.01165623, 0.03168492, 0.08612854, 0.23412166, 0.63640865}); - - ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); + ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); - ASSERT_TRUE(output.equalsTo(&expOutput)); + ASSERT_TRUE(output.equalsTo(&expOutput)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softMaxForVector_test3) { + auto input = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto output = NDArrayFactory::create('c', {5}); + auto expOutput = NDArrayFactory::create( + 'c', {5}, {0.01165623, 0.03168492, 0.08612854, 0.23412166, 0.63640865}); - auto input= NDArrayFactory::create('c', {5}, {1,2,3,4,5}); - auto output = NDArrayFactory::create('c', {5}); - auto expOutput = NDArrayFactory::create('c', {5}, {0.01165623, 0.03168492, 0.08612854, 0.23412166, 0.63640865}); + ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); - ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); - - ASSERT_TRUE(output.equalsTo(&expOutput)); + ASSERT_TRUE(output.equalsTo(&expOutput)); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softMaxForVector_test4) { - - NDArray input('c', {1500}, sd::DataType::DOUBLE); - NDArray output('c', {1500}, sd::DataType::DOUBLE); - NDArray expOutput('c', {1500}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.00001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, -0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001,0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, -0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002,0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, -0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003,0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, -0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005,0.000005, 0.000005, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, -0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009,0.000009, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, -0.000012, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000016, 0.000016, 0.000016, 0.000016, 0.000016, 0.000016,0.000017, 0.000017, 0.000017, 0.000017, 0.000017, 0.000017, 0.000018, 0.000018, 0.000018, 0.000018, 0.000018, 0.000018, 0.000019, 0.000019, 0.000019, 0.000019, 0.000019, 0.000020, 0.000020, 0.000020, 0.000020, 0.000020, 0.000021, 0.000021, 0.000021, 0.000021, 0.000021, 0.000022, -0.000022, 0.000022, 0.000022, 0.000023, 0.000023, 0.000023, 0.000023, 0.000023, 0.000024, 0.000024, 0.000024, 0.000024, 0.000025, 0.000025, 0.000025, 0.000025, 0.000026, 0.000026, 0.000026, 0.000026, 0.000027, 0.000027, 0.000027, 0.000028, 0.000028, 0.000028, 0.000028, 0.000029,0.000029, 0.000029, 0.000030, 0.000030, 0.000030, 0.000030, 0.000031, 0.000031, 0.000031, 0.000032, 0.000032, 0.000032, 0.000033, 0.000033, 0.000033, 0.000034, 0.000034, 0.000034, 0.000035, 0.000035, 0.000035, 0.000036, 0.000036, 0.000036, 0.000037, 0.000037, 0.000038, 0.000038, -0.000038, 0.000039, 0.000039, 0.000039, 0.000040, 0.000040, 0.000041, 0.000041, 0.000041, 0.000042, 0.000042, 0.000043, 0.000043, 0.000044, 0.000044, 0.000044, 0.000045, 0.000045, 0.000046, 0.000046, 0.000047, 0.000047, 0.000048, 0.000048, 0.000049, 0.000049, 0.000050, 0.000050,0.000051, 0.000051, 0.000052, 0.000052, 0.000053, 0.000053, 0.000054, 0.000054, 0.000055, 0.000055, 0.000056, 0.000057, 0.000057, 0.000058, 0.000058, 0.000059, 0.000059, 0.000060, 0.000061, 0.000061, 0.000062, 0.000063, 0.000063, 0.000064, 0.000064, 0.000065, 0.000066, 0.000066, -0.000067, 0.000068, 0.000068, 0.000069, 0.000070, 0.000070, 0.000071, 0.000072, 0.000073, 0.000073, 0.000074, 0.000075, 0.000076, 0.000076, 0.000077, 0.000078, 0.000079, 0.000079, 0.000080, 0.000081, 0.000082, 0.000083, 0.000084, 0.000084, 0.000085, 0.000086, 0.000087, 0.000088,0.000089, 0.000090, 0.000090, 0.000091, 0.000092, 0.000093, 0.000094, 0.000095, 0.000096, 0.000097, 0.000098, 0.000099, 0.000100, 0.000101, 0.000102, 0.000103, 0.000104, 0.000105, 0.000106, 0.000107, 0.000108, 0.000109, 0.000111, 0.000112, 0.000113, 0.000114, 0.000115, 0.000116, -0.000117, 0.000119, 0.000120, 0.000121, 0.000122, 0.000123, 0.000125, 0.000126, 0.000127, 0.000128, 0.000130, 0.000131, 0.000132, 0.000134, 0.000135, 0.000136, 0.000138, 0.000139, 0.000141, 0.000142, 0.000143, 0.000145, 0.000146, 0.000148, 0.000149, 0.000151, 0.000152, 0.000154,0.000155, 0.000157, 0.000158, 0.000160, 0.000162, 0.000163, 0.000165, 0.000167, 0.000168, 0.000170, 0.000172, 0.000173, 0.000175, 0.000177, 0.000179, 0.000180, 0.000182, 0.000184, 0.000186, 0.000188, 0.000190, 0.000192, 0.000194, 0.000195, 0.000197, 0.000199, 0.000201, 0.000203, -0.000205, 0.000208, 0.000210, 0.000212, 0.000214, 0.000216, 0.000218, 0.000220, 0.000223, 0.000225, 0.000227, 0.000229, 0.000232, 0.000234, 0.000236, 0.000239, 0.000241, 0.000244, 0.000246, 0.000248, 0.000251, 0.000253, 0.000256, 0.000259, 0.000261, 0.000264, 0.000266, 0.000269,0.000272, 0.000275, 0.000277, 0.000280, 0.000283, 0.000286, 0.000289, 0.000292, 0.000295, 0.000297, 0.000300, 0.000303, 0.000307, 0.000310, 0.000313, 0.000316, 0.000319, 0.000322, 0.000325, 0.000329, 0.000332, 0.000335, 0.000339, 0.000342, 0.000346, 0.000349, 0.000353, 0.000356, -0.000360, 0.000363, 0.000367, 0.000371, 0.000374, 0.000378, 0.000382, 0.000386, 0.000390, 0.000394, 0.000398, 0.000402, 0.000406, 0.000410, 0.000414, 0.000418, 0.000422, 0.000426, 0.000431, 0.000435, 0.000439, 0.000444, 0.000448, 0.000453, 0.000457, 0.000462, 0.000467, 0.000471,0.000476, 0.000481, 0.000486, 0.000490, 0.000495, 0.000500, 0.000505, 0.000510, 0.000516, 0.000521, 0.000526, 0.000531, 0.000537, 0.000542, 0.000547, 0.000553, 0.000559, 0.000564, 0.000570, 0.000576, 0.000581, 0.000587, 0.000593, 0.000599, 0.000605, 0.000611, 0.000617, 0.000623, -0.000630, 0.000636, 0.000642, 0.000649, 0.000655, 0.000662, 0.000669, 0.000675, 0.000682, 0.000689, 0.000696, 0.000703, 0.000710, 0.000717, 0.000724, 0.000732, 0.000739, 0.000746, 0.000754, 0.000762, 0.000769, 0.000777, 0.000785, 0.000793, 0.000801, 0.000809, 0.000817, 0.000825,0.000833, 0.000842, 0.000850, 0.000859, 0.000867, 0.000876, 0.000885, 0.000894, 0.000903, 0.000912, 0.000921, 0.000930, 0.000939, 0.000949, 0.000958, 0.000968, 0.000978, 0.000988, 0.000998, 0.001008, 0.001018, 0.001028, 0.001038, 0.001049, 0.001059, 0.001070, 0.001081, 0.001092, -0.001103, 0.001114, 0.001125, 0.001136, 0.001148, 0.001159, 0.001171, 0.001182, 0.001194, 0.001206, 0.001218, 0.001231, 0.001243, 0.001256, 0.001268, 0.001281, 0.001294, 0.001307, 0.001320, 0.001333, 0.001347, 0.001360, 0.001374, 0.001388, 0.001402, 0.001416, 0.001430, 0.001444,0.001459, 0.001473, 0.001488, 0.001503, 0.001518, 0.001534, 0.001549, 0.001565, 0.001580, 0.001596, 0.001612, 0.001628, 0.001645, 0.001661, 0.001678, 0.001695, 0.001712, 0.001729, 0.001746, 0.001764, 0.001782, 0.001800, 0.001818, 0.001836, 0.001854, 0.001873, 0.001892, 0.001911, -0.001930, 0.001950, 0.001969, 0.001989, 0.002009, 0.002029, 0.002049, 0.002070, 0.002091, 0.002112, 0.002133, 0.002155, 0.002176, 0.002198, 0.002220, 0.002242, 0.002265, 0.002288, 0.002311, 0.002334, 0.002357, 0.002381, 0.002405, 0.002429, 0.002454, 0.002478, 0.002503, 0.002528,0.002554, 0.002579, 0.002605, 0.002632, 0.002658, 0.002685, 0.002712, 0.002739, 0.002767, 0.002794, 0.002822, 0.002851, 0.002879, 0.002908, 0.002938, 0.002967, 0.002997, 0.003027, 0.003057, 0.003088, 0.003119, 0.003151, 0.003182, 0.003214, 0.003247, 0.003279, 0.003312, 0.003345, -0.003379, 0.003413, 0.003447, 0.003482, 0.003517, 0.003552, 0.003588, 0.003624, 0.003660, 0.003697, 0.003734, 0.003772, 0.003810, 0.003848, 0.003887, 0.003926, 0.003965, 0.004005, 0.004045, 0.004086, 0.004127, 0.004169, 0.004211, 0.004253, 0.004296, 0.004339, 0.004382, 0.004426,0.004471, 0.004516, 0.004561, 0.004607, 0.004653, 0.004700, 0.004747, 0.004795, 0.004843, 0.004892, 0.004941, 0.004991, 0.005041, 0.005092, 0.005143, 0.005194, 0.005247, 0.005299, 0.005353, 0.005406, 0.005461, 0.005516, 0.005571, 0.005627, 0.005684, 0.005741, 0.005798, 0.005857, -0.005916, 0.005975, 0.006035, 0.006096, 0.006157, 0.006219, 0.006281, 0.006345, 0.006408, 0.006473, 0.006538, 0.006603, 0.006670, 0.006737, 0.006805, 0.006873, 0.006942, 0.007012, 0.007082, 0.007153, 0.007225, 0.007298, 0.007371, 0.007445, 0.007520, 0.007596, 0.007672, 0.007749,0.007827, 0.007906, 0.007985, 0.008065, 0.008147, 0.008228, 0.008311, 0.008395, 0.008479, 0.008564, 0.008650, 0.008737, 0.008825, 0.008914, 0.009003, 0.009094, 0.009185, 0.009277, 0.009371, 0.009465, 0.009560, 0.009656, 0.009753, 0.009851, 0.009950}, sd::DataType::DOUBLE); - input.linspace(0.01, 0.01); - - ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); - - ASSERT_TRUE(output.equalsTo(&expOutput)); + NDArray input('c', {1500}, sd::DataType::DOUBLE); + NDArray output('c', {1500}, sd::DataType::DOUBLE); + NDArray expOutput( + 'c', {1500}, + {0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.00001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000002, 0.000002, 0.000002, + 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, + 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, + 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, + 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, + 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, + 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, + 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000003, + 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, + 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, + 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, + 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, + 0.000003, 0.000003, 0.000003, 0.000003, 0.000004, 0.000004, 0.000004, + 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, + 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, + 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, + 0.000004, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, + 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, + 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, + 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, + 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, + 0.000006, 0.000006, 0.000006, 0.000007, 0.000007, 0.000007, 0.000007, + 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, + 0.000007, 0.000007, 0.000007, 0.000008, 0.000008, 0.000008, 0.000008, + 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, + 0.000008, 0.000008, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, + 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000010, + 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, + 0.000010, 0.000010, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, + 0.000011, 0.000011, 0.000011, 0.000011, 0.000012, 0.000012, 0.000012, + 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000013, + 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000014, + 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000015, + 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000016, + 0.000016, 0.000016, 0.000016, 0.000016, 0.000016, 0.000017, 0.000017, + 0.000017, 0.000017, 0.000017, 0.000017, 0.000018, 0.000018, 0.000018, + 0.000018, 0.000018, 0.000018, 0.000019, 0.000019, 0.000019, 0.000019, + 0.000019, 0.000020, 0.000020, 0.000020, 0.000020, 0.000020, 0.000021, + 0.000021, 0.000021, 0.000021, 0.000021, 0.000022, 0.000022, 0.000022, + 0.000022, 0.000023, 0.000023, 0.000023, 0.000023, 0.000023, 0.000024, + 0.000024, 0.000024, 0.000024, 0.000025, 0.000025, 0.000025, 0.000025, + 0.000026, 0.000026, 0.000026, 0.000026, 0.000027, 0.000027, 0.000027, + 0.000028, 0.000028, 0.000028, 0.000028, 0.000029, 0.000029, 0.000029, + 0.000030, 0.000030, 0.000030, 0.000030, 0.000031, 0.000031, 0.000031, + 0.000032, 0.000032, 0.000032, 0.000033, 0.000033, 0.000033, 0.000034, + 0.000034, 0.000034, 0.000035, 0.000035, 0.000035, 0.000036, 0.000036, + 0.000036, 0.000037, 0.000037, 0.000038, 0.000038, 0.000038, 0.000039, + 0.000039, 0.000039, 0.000040, 0.000040, 0.000041, 0.000041, 0.000041, + 0.000042, 0.000042, 0.000043, 0.000043, 0.000044, 0.000044, 0.000044, + 0.000045, 0.000045, 0.000046, 0.000046, 0.000047, 0.000047, 0.000048, + 0.000048, 0.000049, 0.000049, 0.000050, 0.000050, 0.000051, 0.000051, + 0.000052, 0.000052, 0.000053, 0.000053, 0.000054, 0.000054, 0.000055, + 0.000055, 0.000056, 0.000057, 0.000057, 0.000058, 0.000058, 0.000059, + 0.000059, 0.000060, 0.000061, 0.000061, 0.000062, 0.000063, 0.000063, + 0.000064, 0.000064, 0.000065, 0.000066, 0.000066, 0.000067, 0.000068, + 0.000068, 0.000069, 0.000070, 0.000070, 0.000071, 0.000072, 0.000073, + 0.000073, 0.000074, 0.000075, 0.000076, 0.000076, 0.000077, 0.000078, + 0.000079, 0.000079, 0.000080, 0.000081, 0.000082, 0.000083, 0.000084, + 0.000084, 0.000085, 0.000086, 0.000087, 0.000088, 0.000089, 0.000090, + 0.000090, 0.000091, 0.000092, 0.000093, 0.000094, 0.000095, 0.000096, + 0.000097, 0.000098, 0.000099, 0.000100, 0.000101, 0.000102, 0.000103, + 0.000104, 0.000105, 0.000106, 0.000107, 0.000108, 0.000109, 0.000111, + 0.000112, 0.000113, 0.000114, 0.000115, 0.000116, 0.000117, 0.000119, + 0.000120, 0.000121, 0.000122, 0.000123, 0.000125, 0.000126, 0.000127, + 0.000128, 0.000130, 0.000131, 0.000132, 0.000134, 0.000135, 0.000136, + 0.000138, 0.000139, 0.000141, 0.000142, 0.000143, 0.000145, 0.000146, + 0.000148, 0.000149, 0.000151, 0.000152, 0.000154, 0.000155, 0.000157, + 0.000158, 0.000160, 0.000162, 0.000163, 0.000165, 0.000167, 0.000168, + 0.000170, 0.000172, 0.000173, 0.000175, 0.000177, 0.000179, 0.000180, + 0.000182, 0.000184, 0.000186, 0.000188, 0.000190, 0.000192, 0.000194, + 0.000195, 0.000197, 0.000199, 0.000201, 0.000203, 0.000205, 0.000208, + 0.000210, 0.000212, 0.000214, 0.000216, 0.000218, 0.000220, 0.000223, + 0.000225, 0.000227, 0.000229, 0.000232, 0.000234, 0.000236, 0.000239, + 0.000241, 0.000244, 0.000246, 0.000248, 0.000251, 0.000253, 0.000256, + 0.000259, 0.000261, 0.000264, 0.000266, 0.000269, 0.000272, 0.000275, + 0.000277, 0.000280, 0.000283, 0.000286, 0.000289, 0.000292, 0.000295, + 0.000297, 0.000300, 0.000303, 0.000307, 0.000310, 0.000313, 0.000316, + 0.000319, 0.000322, 0.000325, 0.000329, 0.000332, 0.000335, 0.000339, + 0.000342, 0.000346, 0.000349, 0.000353, 0.000356, 0.000360, 0.000363, + 0.000367, 0.000371, 0.000374, 0.000378, 0.000382, 0.000386, 0.000390, + 0.000394, 0.000398, 0.000402, 0.000406, 0.000410, 0.000414, 0.000418, + 0.000422, 0.000426, 0.000431, 0.000435, 0.000439, 0.000444, 0.000448, + 0.000453, 0.000457, 0.000462, 0.000467, 0.000471, 0.000476, 0.000481, + 0.000486, 0.000490, 0.000495, 0.000500, 0.000505, 0.000510, 0.000516, + 0.000521, 0.000526, 0.000531, 0.000537, 0.000542, 0.000547, 0.000553, + 0.000559, 0.000564, 0.000570, 0.000576, 0.000581, 0.000587, 0.000593, + 0.000599, 0.000605, 0.000611, 0.000617, 0.000623, 0.000630, 0.000636, + 0.000642, 0.000649, 0.000655, 0.000662, 0.000669, 0.000675, 0.000682, + 0.000689, 0.000696, 0.000703, 0.000710, 0.000717, 0.000724, 0.000732, + 0.000739, 0.000746, 0.000754, 0.000762, 0.000769, 0.000777, 0.000785, + 0.000793, 0.000801, 0.000809, 0.000817, 0.000825, 0.000833, 0.000842, + 0.000850, 0.000859, 0.000867, 0.000876, 0.000885, 0.000894, 0.000903, + 0.000912, 0.000921, 0.000930, 0.000939, 0.000949, 0.000958, 0.000968, + 0.000978, 0.000988, 0.000998, 0.001008, 0.001018, 0.001028, 0.001038, + 0.001049, 0.001059, 0.001070, 0.001081, 0.001092, 0.001103, 0.001114, + 0.001125, 0.001136, 0.001148, 0.001159, 0.001171, 0.001182, 0.001194, + 0.001206, 0.001218, 0.001231, 0.001243, 0.001256, 0.001268, 0.001281, + 0.001294, 0.001307, 0.001320, 0.001333, 0.001347, 0.001360, 0.001374, + 0.001388, 0.001402, 0.001416, 0.001430, 0.001444, 0.001459, 0.001473, + 0.001488, 0.001503, 0.001518, 0.001534, 0.001549, 0.001565, 0.001580, + 0.001596, 0.001612, 0.001628, 0.001645, 0.001661, 0.001678, 0.001695, + 0.001712, 0.001729, 0.001746, 0.001764, 0.001782, 0.001800, 0.001818, + 0.001836, 0.001854, 0.001873, 0.001892, 0.001911, 0.001930, 0.001950, + 0.001969, 0.001989, 0.002009, 0.002029, 0.002049, 0.002070, 0.002091, + 0.002112, 0.002133, 0.002155, 0.002176, 0.002198, 0.002220, 0.002242, + 0.002265, 0.002288, 0.002311, 0.002334, 0.002357, 0.002381, 0.002405, + 0.002429, 0.002454, 0.002478, 0.002503, 0.002528, 0.002554, 0.002579, + 0.002605, 0.002632, 0.002658, 0.002685, 0.002712, 0.002739, 0.002767, + 0.002794, 0.002822, 0.002851, 0.002879, 0.002908, 0.002938, 0.002967, + 0.002997, 0.003027, 0.003057, 0.003088, 0.003119, 0.003151, 0.003182, + 0.003214, 0.003247, 0.003279, 0.003312, 0.003345, 0.003379, 0.003413, + 0.003447, 0.003482, 0.003517, 0.003552, 0.003588, 0.003624, 0.003660, + 0.003697, 0.003734, 0.003772, 0.003810, 0.003848, 0.003887, 0.003926, + 0.003965, 0.004005, 0.004045, 0.004086, 0.004127, 0.004169, 0.004211, + 0.004253, 0.004296, 0.004339, 0.004382, 0.004426, 0.004471, 0.004516, + 0.004561, 0.004607, 0.004653, 0.004700, 0.004747, 0.004795, 0.004843, + 0.004892, 0.004941, 0.004991, 0.005041, 0.005092, 0.005143, 0.005194, + 0.005247, 0.005299, 0.005353, 0.005406, 0.005461, 0.005516, 0.005571, + 0.005627, 0.005684, 0.005741, 0.005798, 0.005857, 0.005916, 0.005975, + 0.006035, 0.006096, 0.006157, 0.006219, 0.006281, 0.006345, 0.006408, + 0.006473, 0.006538, 0.006603, 0.006670, 0.006737, 0.006805, 0.006873, + 0.006942, 0.007012, 0.007082, 0.007153, 0.007225, 0.007298, 0.007371, + 0.007445, 0.007520, 0.007596, 0.007672, 0.007749, 0.007827, 0.007906, + 0.007985, 0.008065, 0.008147, 0.008228, 0.008311, 0.008395, 0.008479, + 0.008564, 0.008650, 0.008737, 0.008825, 0.008914, 0.009003, 0.009094, + 0.009185, 0.009277, 0.009371, 0.009465, 0.009560, 0.009656, 0.009753, + 0.009851, 0.009950}, + sd::DataType::DOUBLE); + input.linspace(0.01, 0.01); + + ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); + + ASSERT_TRUE(output.equalsTo(&expOutput)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, logSoftMaxForVector_test1) { + auto input = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto output = NDArrayFactory::create('c', {1, 5}); + auto expOutput = NDArrayFactory::create('c', {1, 5}); + expOutput = 0; - auto input = NDArrayFactory::create('c', {1,5}, {1,2,3,4,5}); - auto output = NDArrayFactory::create('c', {1,5}); - auto expOutput = NDArrayFactory::create('c', {1,5}); - expOutput = 0; + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, + 0); - ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); - - ASSERT_TRUE(output.equalsTo(&expOutput)); + ASSERT_TRUE(output.equalsTo(&expOutput)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, logSoftMaxForVector_test2) { + auto input = NDArrayFactory::create('c', {5, 1}, {1, 2, 3, 4, 5}); + auto output = NDArrayFactory::create('c', {5, 1}); + auto expOutput = NDArrayFactory::create( + 'c', {5, 1}, + {-4.4519144, -3.4519144, -2.4519144, -1.4519144, -0.4519144}); - auto input= NDArrayFactory::create('c', {5,1}, {1,2,3,4,5}); - auto output = NDArrayFactory::create('c', {5,1}); - auto expOutput = NDArrayFactory::create('c', {5,1}, {-4.4519144, -3.4519144, -2.4519144, -1.4519144, -0.4519144}); - - ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, + 0); - ASSERT_TRUE(output.equalsTo(&expOutput)); + ASSERT_TRUE(output.equalsTo(&expOutput)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, logSoftMaxForVector_test3) { + auto input = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto output = NDArrayFactory::create('c', {5}); + auto expOutput = NDArrayFactory::create( + 'c', {5}, {-4.4519144, -3.4519144, -2.4519144, -1.4519144, -0.4519144}); - auto input= NDArrayFactory::create('c', {5}, {1,2,3,4,5}); - auto output = NDArrayFactory::create('c', {5}); - auto expOutput = NDArrayFactory::create('c', {5}, {-4.4519144, -3.4519144, -2.4519144, -1.4519144, -0.4519144}); - - ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, + 0); - ASSERT_TRUE(output.equalsTo(&expOutput)); + ASSERT_TRUE(output.equalsTo(&expOutput)); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, logSoftMaxForVector_test4) { - - NDArray input('c', {1500}, sd::DataType::DOUBLE); - NDArray output('c', {1500}, sd::DataType::DOUBLE); - NDArray expOutput('c', {1500}, {-8.154773, -8.153772, -8.152773, -8.151772, -8.150773, -8.149773, -8.148773, -8.147773, -8.146772, -8.145773, -8.144773, -8.143773, -8.142773, -8.141773, -8.140773, -8.139772, -8.138773, -8.137773, -8.136773, -8.135773, -8.134773, -8.133773, -8.132772, -8.131773, -8.130773, -8.129773, -8.128773, -8.127772, -8.126773, -8.125772, -8.124773, -8.123773, -8.122773, -8.121773, -8.120772, -8.119773, -8.118773, -8.117773, -8.116773, -8.115773, -8.114773, -8.113772, -8.112773, -8.111773, -8.110773, -8.109773, -8.108773, -8.107773, -8.106772, -8.105773, -8.104773, -8.103773, -8.102773, -8.101772, -8.100773, -8.099772, -8.098773, -8.097773, -8.096773, -8.095773, -8.094772, -8.093773, -8.092772, -8.091773, -8.090773, -8.089773, -8.088773, -8.087772, -8.086773, -8.085773, -8.084773, -8.083773, -8.082773, -8.081773, -8.080772, -8.079773, -8.078773, -8.077773, -8.076773, -8.075773, -8.074773, -8.073772, -8.072773, -8.071773, -8.070773, -8.069773, -8.068772, -8.067773, -8.066772, -8.065773, -8.064773, -8.063773, -8.062773, -8.061772, -8.060773, -8.059772, -8.058773, -8.057773, -8.056773, -8.055773, -8.054772, --8.053773, -8.052773, -8.051773, -8.050773, -8.049773, -8.048773, -8.047772, -8.046773, -8.045773, -8.044773, -8.043773, -8.042773, -8.041773, -8.040772, -8.039773, -8.038773, -8.037773, -8.036773, -8.035772, -8.034773, -8.033772, -8.032773, -8.031773, -8.030773, -8.029773, -8.028772, -8.027773, -8.026772, -8.025773, -8.024773, -8.023773, -8.022773, -8.021772, -8.020773, -8.019773, -8.018773, -8.017773, -8.016773, -8.015773, -8.014772, -8.013773, -8.012773, -8.011773, -8.010773, -8.009773, -8.008773, -8.007772, -8.006773, -8.005773, -8.004773, -8.003773, -8.002772, -8.001773, -8.000772, -7.999773, -7.998773, -7.997773, -7.996773, -7.995773, -7.994773, -7.993773, -7.992773, -7.991773, -7.990773, -7.989773, -7.988773, -7.987773, -7.986773, -7.985773, -7.984773, -7.983773, -7.982773, -7.981773, -7.980773, -7.979773, -7.978773, -7.977773, -7.976773, -7.975773, -7.974773, -7.973773, -7.972773, -7.971773, -7.970773, -7.969773, -7.968773, -7.967773, -7.966773, -7.965773, -7.964773, -7.963773, -7.962773, -7.961773, -7.960773, -7.959773, -7.958773, -7.957773, -7.956773, -7.955773, -7.954773, -7.953773, -7.952773, --7.951773, -7.950773, -7.949773, -7.948773, -7.947773, -7.946773, -7.945773, -7.944773, -7.943773, -7.942773, -7.941773, -7.940773, -7.939773, -7.938773, -7.937773, -7.936773, -7.935773, -7.934773, -7.933773, -7.932773, -7.931773, -7.930773, -7.929773, -7.928773, -7.927773, -7.926773, -7.925773, -7.924773, -7.923773, -7.922773, -7.921773, -7.920773, -7.919773, -7.918773, -7.917773, -7.916773, -7.915773, -7.914773, -7.913773, -7.912773, -7.911773, -7.910773, -7.909773, -7.908773, -7.907773, -7.906773, -7.905773, -7.904773, -7.903773, -7.902773, -7.901773, -7.900773, -7.899773, -7.898773, -7.897773, -7.896773, -7.895773, -7.894773, -7.893773, -7.892773, -7.891773, -7.890773, -7.889773, -7.888773, -7.887773, -7.886773, -7.885773, -7.884773, -7.883773, -7.882773, -7.881773, -7.880773, -7.879773, -7.878773, -7.877773, -7.876773, -7.875773, -7.874773, -7.873773, -7.872773, -7.871773, -7.870773, -7.869773, -7.868773, -7.867773, -7.866773, -7.865773, -7.864773, -7.863773, -7.862773, -7.861773, -7.860773, -7.859773, -7.858773, -7.857773, -7.856773, -7.855773, -7.854773, -7.853773, -7.852773, -7.851773, -7.850773, -7.849773, --7.848773, -7.847773, -7.846773, -7.845773, -7.844773, -7.843773, -7.842773, -7.841773, -7.840773, -7.839773, -7.838773, -7.837773, -7.836773, -7.835773, -7.834773, -7.833773, -7.832773, -7.831773, -7.830773, -7.829773, -7.828773, -7.827773, -7.826773, -7.825773, -7.824773, -7.823773, -7.822773, -7.821773, -7.820773, -7.819773, -7.818773, -7.817773, -7.816773, -7.815773, -7.814773, -7.813773, -7.812773, -7.811773, -7.810773, -7.809773, -7.808773, -7.807773, -7.806773, -7.805773, -7.804773, -7.803773, -7.802773, -7.801773, -7.800773, -7.799773, -7.798773, -7.797773, -7.796773, -7.795773, -7.794773, -7.793773, -7.792773, -7.791773, -7.790773, -7.789773, -7.788773, -7.787773, -7.786773, -7.785773, -7.784773, -7.783773, -7.782773, -7.781773, -7.780773, -7.779773, -7.778773, -7.777773, -7.776773, -7.775773, -7.774773, -7.773773, -7.772773, -7.771773, -7.770773, -7.769773, -7.768773, -7.767773, -7.766773, -7.765773, -7.764773, -7.763773, -7.762773, -7.761773, -7.760773, -7.759773, -7.758773, -7.757773, -7.756773, -7.755773, -7.754773, -7.753773, -7.752773, -7.751773, -7.750773, -7.749773, -7.748773, -7.747773, -7.746773, --7.745773, -7.744773, -7.743773, -7.742773, -7.741773, -7.740773, -7.739773, -7.738773, -7.737773, -7.736773, -7.735773, -7.734773, -7.733773, -7.732773, -7.731773, -7.730773, -7.729773, -7.728773, -7.727773, -7.726773, -7.725773, -7.724773, -7.723773, -7.722773, -7.721773, -7.720773, -7.719773, -7.718773, -7.717773, -7.716773, -7.715773, -7.714773, -7.713773, -7.712773, -7.711773, -7.710773, -7.709773, -7.708773, -7.707773, -7.706773, -7.705773, -7.704773, -7.703773, -7.702773, -7.701773, -7.700773, -7.699773, -7.698773, -7.697773, -7.696773, -7.695773, -7.694773, -7.693773, -7.692773, -7.691773, -7.690773, -7.689773, -7.688773, -7.687773, -7.686773, -7.685773, -7.684773, -7.683773, -7.682773, -7.681773, -7.680773, -7.679773, -7.678773, -7.677773, -7.676773, -7.675773, -7.674773, -7.673773, -7.672773, -7.671773, -7.670773, -7.669773, -7.668773, -7.667773, -7.666773, -7.665773, -7.664773, -7.663773, -7.662773, -7.661773, -7.660773, -7.659773, -7.658773, -7.657773, -7.656773, -7.655773, -7.654773, -7.653773, -7.652773, -7.651773, -7.650773, -7.649773, -7.648773, -7.647773, -7.646773, -7.645773, -7.644773, -7.643773, --7.642773, -7.641773, -7.640773, -7.639773, -7.638773, -7.637773, -7.636773, -7.635773, -7.634773, -7.633773, -7.632773, -7.631773, -7.630773, -7.629773, -7.628773, -7.627773, -7.626773, -7.625773, -7.624773, -7.623773, -7.622773, -7.621773, -7.620773, -7.619773, -7.618773, -7.617773, -7.616773, -7.615773, -7.614773, -7.613773, -7.612773, -7.611773, -7.610773, -7.609773, -7.608773, -7.607773, -7.606773, -7.605773, -7.604773, -7.603773, -7.602773, -7.601773, -7.600773, -7.599773, -7.598773, -7.597773, -7.596773, -7.595773, -7.594773, -7.593773, -7.592773, -7.591773, -7.590773, -7.589773, -7.588773, -7.587773, -7.586773, -7.585773, -7.584773, -7.583773, -7.582773, -7.581773, -7.580773, -7.579773, -7.578773, -7.577773, -7.576773, -7.575773, -7.574773, -7.573773, -7.572773, -7.571773, -7.570773, -7.569773, -7.568773, -7.567773, -7.566773, -7.565773, -7.564773, -7.563773, -7.562773, -7.561773, -7.560773, -7.559773, -7.558773, -7.557773, -7.556773, -7.555773, -7.554773, -7.553773, -7.552773, -7.551773, -7.550773, -7.549773, -7.548773, -7.547773, -7.546773, -7.545773, -7.544773, -7.543773, -7.542773, -7.541773, -7.540773, --7.539773, -7.538773, -7.537773, -7.536773, -7.535773, -7.534773, -7.533773, -7.532773, -7.531773, -7.530773, -7.529773, -7.528773, -7.527773, -7.526773, -7.525773, -7.524773, -7.523773, -7.522773, -7.521773, -7.520773, -7.519773, -7.518773, -7.517773, -7.516773, -7.515773, -7.514773, -7.513773, -7.512773, -7.511773, -7.510773, -7.509773, -7.508773, -7.507773, -7.506773, -7.505773, -7.504773, -7.503773, -7.502773, -7.501773, -7.500773, -7.499773, -7.498773, -7.497773, -7.496773, -7.495773, -7.494773, -7.493773, -7.492773, -7.491773, -7.490773, -7.489773, -7.488773, -7.487773, -7.486773, -7.485773, -7.484773, -7.483773, -7.482773, -7.481773, -7.480773, -7.479773, -7.478773, -7.477773, -7.476773, -7.475773, -7.474773, -7.473773, -7.472773, -7.471773, -7.470773, -7.469773, -7.468773, -7.467773, -7.466773, -7.465773, -7.464773, -7.463773, -7.462773, -7.461773, -7.460773, -7.459773, -7.458773, -7.457773, -7.456773, -7.455773, -7.454773, -7.453773, -7.452773, -7.451773, -7.450773, -7.449773, -7.448773, -7.447773, -7.446773, -7.445773, -7.444773, -7.443773, -7.442773, -7.441773, -7.440773, -7.439773, -7.438773, -7.437773, --7.436773, -7.435773, -7.434773, -7.433773, -7.432773, -7.431773, -7.430773, -7.429773, -7.428773, -7.427773, -7.426773, -7.425773, -7.424773, -7.423773, -7.422773, -7.421773, -7.420773, -7.419773, -7.418773, -7.417773, -7.416773, -7.415773, -7.414773, -7.413773, -7.412773, -7.411773, -7.410773, -7.409773, -7.408773, -7.407773, -7.406773, -7.405773, -7.404773, -7.403773, -7.402773, -7.401773, -7.400773, -7.399773, -7.398773, -7.397773, -7.396773, -7.395773, -7.394773, -7.393773, -7.392773, -7.391773, -7.390773, -7.389773, -7.388773, -7.387773, -7.386773, -7.385773, -7.384773, -7.383773, -7.382773, -7.381773, -7.380773, -7.379773, -7.378773, -7.377773, -7.376773, -7.375773, -7.374773, -7.373773, -7.372773, -7.371773, -7.370773, -7.369773, -7.368773, -7.367773, -7.366773, -7.365773, -7.364773, -7.363773, -7.362773, -7.361773, -7.360773, -7.359773, -7.358773, -7.357773, -7.356773, -7.355773, -7.354773, -7.353773, -7.352773, -7.351773, -7.350773, -7.349773, -7.348773, -7.347773, -7.346773, -7.345773, -7.344773, -7.343773, -7.342773, -7.341773, -7.340773, -7.339773, -7.338773, -7.337773, -7.336773, -7.335773, -7.334773, --7.333773, -7.332773, -7.331773, -7.330773, -7.329773, -7.328773, -7.327773, -7.326773, -7.325773, -7.324773, -7.323773, -7.322773, -7.321773, -7.320773, -7.319773, -7.318773, -7.317773, -7.316773, -7.315773, -7.314773, -7.313773, -7.312773, -7.311773, -7.310773, -7.309773, -7.308773, -7.307773, -7.306773, -7.305773, -7.304773, -7.303773, -7.302773, -7.301773, -7.300773, -7.299773, -7.298773, -7.297773, -7.296773, -7.295773, -7.294773, -7.293773, -7.292773, -7.291773, -7.290773, -7.289773, -7.288773, -7.287773, -7.286773, -7.285773, -7.284773, -7.283773, -7.282773, -7.281773, -7.280773, -7.279773, -7.278773, -7.277773, -7.276773, -7.275773, -7.274773, -7.273773, -7.272773, -7.271773, -7.270773, -7.269773, -7.268773, -7.267773, -7.266773, -7.265773, -7.264773, -7.263773, -7.262773, -7.261773, -7.260773, -7.259773, -7.258773, -7.257773, -7.256773, -7.255773, -7.254773, -7.253773, -7.252773, -7.251773, -7.250773, -7.249773, -7.248773, -7.247773, -7.246773, -7.245773, -7.244773, -7.243773, -7.242773, -7.241773, -7.240773, -7.239773, -7.238773, -7.237773, -7.236773, -7.235773, -7.234773, -7.233773, -7.232773, -7.231773, --7.230773, -7.229773, -7.228773, -7.227773, -7.226773, -7.225773, -7.224773, -7.223773, -7.222773, -7.221773, -7.220773, -7.219773, -7.218773, -7.217773, -7.216773, -7.215773, -7.214773, -7.213773, -7.212773, -7.211773, -7.210773, -7.209773, -7.208773, -7.207773, -7.206773, -7.205773, -7.204773, -7.203773, -7.202773, -7.201773, -7.200773, -7.199773, -7.198773, -7.197773, -7.196773, -7.195773, -7.194773, -7.193773, -7.192773, -7.191773, -7.190773, -7.189773, -7.188773, -7.187773, -7.186773, -7.185773, -7.184773, -7.183773, -7.182773, -7.181773, -7.180773, -7.179773, -7.178773, -7.177773, -7.176773, -7.175773, -7.174773, -7.173773, -7.172773, -7.171773, -7.170773, -7.169773, -7.168773, -7.167773, -7.166773, -7.165773, -7.164773, -7.163773, -7.162773, -7.161773, -7.160773, -7.159773, -7.158773, -7.157773, -7.156773, -7.155773, -7.154773, -7.153773, -7.152773, -7.151773, -7.150773, -7.149773, -7.148773, -7.147773, -7.146773, -7.145773, -7.144773, -7.143773, -7.142773, -7.141773, -7.140773, -7.139773, -7.138773, -7.137773, -7.136773, -7.135773, -7.134773, -7.133773, -7.132773, -7.131773, -7.130773, -7.129773, -7.128773, --7.127773, -7.126773, -7.125773, -7.124773, -7.123773, -7.122773, -7.121773, -7.120773, -7.119773, -7.118773, -7.117773, -7.116773, -7.115773, -7.114773, -7.113773, -7.112773, -7.111773, -7.110773, -7.109773, -7.108773, -7.107773, -7.106773, -7.105773, -7.104773, -7.103773, -7.102773, -7.101773, -7.100773, -7.099773, -7.098773, -7.097773, -7.096773, -7.095773, -7.094773, -7.093773, -7.092773, -7.091773, -7.090773, -7.089773, -7.088773, -7.087773, -7.086773, -7.085773, -7.084773, -7.083773, -7.082773, -7.081773, -7.080773, -7.079773, -7.078773, -7.077773, -7.076773, -7.075773, -7.074773, -7.073773, -7.072773, -7.071773, -7.070773, -7.069773, -7.068773, -7.067773, -7.066773, -7.065773, -7.064773, -7.063773, -7.062773, -7.061773, -7.060773, -7.059773, -7.058773, -7.057773, -7.056773, -7.055773, -7.054773, -7.053773, -7.052773, -7.051773, -7.050773, -7.049773, -7.048773, -7.047773, -7.046773, -7.045773, -7.044773, -7.043773, -7.042773, -7.041773, -7.040773, -7.039773, -7.038773, -7.037773, -7.036773, -7.035773, -7.034773, -7.033773, -7.032773, -7.031773, -7.030773, -7.029773, -7.028773, -7.027773, -7.026773, -7.025773, --7.024773, -7.023773, -7.022773, -7.021773, -7.020773, -7.019773, -7.018773, -7.017773, -7.016773, -7.015773, -7.014773, -7.013773, -7.012773, -7.011773, -7.010773, -7.009773, -7.008773, -7.007773, -7.006773, -7.005773, -7.004773, -7.003773, -7.002773, -7.001773, -7.000773, -6.999773, -6.998773, -6.997773, -6.996773, -6.995773, -6.994773, -6.993773, -6.992773, -6.991773, -6.990773, -6.989773, -6.988773, -6.987773, -6.986773, -6.985773, -6.984773, -6.983773, -6.982773, -6.981773, -6.980773, -6.979773, -6.978773, -6.977773, -6.976773, -6.975773, -6.974773, -6.973773, -6.972773, -6.971773, -6.970773, -6.969773, -6.968773, -6.967773, -6.966773, -6.965773, -6.964773, -6.963773, -6.962773, -6.961773, -6.960773, -6.959773, -6.958773, -6.957773, -6.956773, -6.955773, -6.954773, -6.953773, -6.952773, -6.951773, -6.950773, -6.949773, -6.948773, -6.947773, -6.946773, -6.945773, -6.944773, -6.943773, -6.942773, -6.941773, -6.940773, -6.939773, -6.938773, -6.937773, -6.936773, -6.935773, -6.934773, -6.933773, -6.932773, -6.931773, -6.930773, -6.929773, -6.928773, -6.927773, -6.926773, -6.925773, -6.924773, -6.923773, -6.922773, --6.921773, -6.920773, -6.919773, -6.918773, -6.917773, -6.916773, -6.915773, -6.914773, -6.913773, -6.912773, -6.911773, -6.910773, -6.909773, -6.908773, -6.907773, -6.906773, -6.905773, -6.904773, -6.903773, -6.902773, -6.901773, -6.900773, -6.899773, -6.898773, -6.897773, -6.896773, -6.895773, -6.894773, -6.893773, -6.892773, -6.891773, -6.890773, -6.889773, -6.888773, -6.887773, -6.886773, -6.885773, -6.884773, -6.883773, -6.882773, -6.881773, -6.880773, -6.879773, -6.878773, -6.877773, -6.876773, -6.875773, -6.874773, -6.873773, -6.872773, -6.871773, -6.870773, -6.869773, -6.868773, -6.867773, -6.866773, -6.865773, -6.864773, -6.863773, -6.862773, -6.861773, -6.860773, -6.859773, -6.858773, -6.857773, -6.856773, -6.855773, -6.854773, -6.853773, -6.852773, -6.851773, -6.850773, -6.849773, -6.848773, -6.847773, -6.846773, -6.845773, -6.844773, -6.843773, -6.842773, -6.841773, -6.840773, -6.839773, -6.838773, -6.837773, -6.836773, -6.835773, -6.834773, -6.833773, -6.832773, -6.831773, -6.830773, -6.829773, -6.828773, -6.827773, -6.826773, -6.825773, -6.824773, -6.823773, -6.822773, -6.821773, -6.820773, -6.819773, --6.818773, -6.817773, -6.816773, -6.815773, -6.814773, -6.813773, -6.812773, -6.811773, -6.810773, -6.809773, -6.808773, -6.807773, -6.806773, -6.805773, -6.804773, -6.803773, -6.802773, -6.801773, -6.800773, -6.799773, -6.798773, -6.797773, -6.796773, -6.795773, -6.794773, -6.793773, -6.792773, -6.791773, -6.790773, -6.789773, -6.788773, -6.787773, -6.786773, -6.785773, -6.784773, -6.783773, -6.782773, -6.781773, -6.780773, -6.779773, -6.778773, -6.777773, -6.776773, -6.775773, -6.774773, -6.773773, -6.772773, -6.771773, -6.770773, -6.769773, -6.768773, -6.767773, -6.766773, -6.765773, -6.764773, -6.763773, -6.762773, -6.761773, -6.760773, -6.759773, -6.758773, -6.757773, -6.756773, -6.755773, -6.754773, -6.753773, -6.752773, -6.751773, -6.750773, -6.749773, -6.748773, -6.747773, -6.746773, -6.745773, -6.744773, -6.743773, -6.742773, -6.741773, -6.740773, -6.739773, -6.738773, -6.737773, -6.736773, -6.735773, -6.734773, -6.733773, -6.732773, -6.731773, -6.730773, -6.729773, -6.728773, -6.727773, -6.726773, -6.725773, -6.724773, -6.723773, -6.722773, -6.721773, -6.720773, -6.719773, -6.718773, -6.717773, -6.716773, -6.715773, --6.714773, -6.713773, -6.712773, -6.711773, -6.710773, -6.709773, -6.708773, -6.707773, -6.706773, -6.705773, -6.704773, -6.703773, -6.702773, -6.701773, -6.700773, -6.699773, -6.698773, -6.697773, -6.696773, -6.695773, -6.694773, -6.693773, -6.692773, -6.691773, -6.690773, -6.689773, -6.688773, -6.687773, -6.686773, -6.685773, -6.684773, -6.683773, -6.682773, -6.681773, -6.680773, -6.679773, -6.678773, -6.677773, -6.676773, -6.675773, -6.674773, -6.673773, -6.672773, -6.671773, -6.670773, -6.669773, -6.668773, -6.667773, -6.666773, -6.665773, -6.664773, -6.663773, -6.662773, -6.661773, -6.660773, -6.659773, -6.658773, -6.657773, -6.656773, -6.655773}, sd::DataType::DOUBLE); - input.linspace(0.01, 0.001); - - ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); - - ASSERT_TRUE(output.equalsTo(&expOutput)); + NDArray input('c', {1500}, sd::DataType::DOUBLE); + NDArray output('c', {1500}, sd::DataType::DOUBLE); + NDArray expOutput( + 'c', {1500}, + {-8.154773, -8.153772, -8.152773, -8.151772, -8.150773, -8.149773, + -8.148773, -8.147773, -8.146772, -8.145773, -8.144773, -8.143773, + -8.142773, -8.141773, -8.140773, -8.139772, -8.138773, -8.137773, + -8.136773, -8.135773, -8.134773, -8.133773, -8.132772, -8.131773, + -8.130773, -8.129773, -8.128773, -8.127772, -8.126773, -8.125772, + -8.124773, -8.123773, -8.122773, -8.121773, -8.120772, -8.119773, + -8.118773, -8.117773, -8.116773, -8.115773, -8.114773, -8.113772, + -8.112773, -8.111773, -8.110773, -8.109773, -8.108773, -8.107773, + -8.106772, -8.105773, -8.104773, -8.103773, -8.102773, -8.101772, + -8.100773, -8.099772, -8.098773, -8.097773, -8.096773, -8.095773, + -8.094772, -8.093773, -8.092772, -8.091773, -8.090773, -8.089773, + -8.088773, -8.087772, -8.086773, -8.085773, -8.084773, -8.083773, + -8.082773, -8.081773, -8.080772, -8.079773, -8.078773, -8.077773, + -8.076773, -8.075773, -8.074773, -8.073772, -8.072773, -8.071773, + -8.070773, -8.069773, -8.068772, -8.067773, -8.066772, -8.065773, + -8.064773, -8.063773, -8.062773, -8.061772, -8.060773, -8.059772, + -8.058773, -8.057773, -8.056773, -8.055773, -8.054772, -8.053773, + -8.052773, -8.051773, -8.050773, -8.049773, -8.048773, -8.047772, + -8.046773, -8.045773, -8.044773, -8.043773, -8.042773, -8.041773, + -8.040772, -8.039773, -8.038773, -8.037773, -8.036773, -8.035772, + -8.034773, -8.033772, -8.032773, -8.031773, -8.030773, -8.029773, + -8.028772, -8.027773, -8.026772, -8.025773, -8.024773, -8.023773, + -8.022773, -8.021772, -8.020773, -8.019773, -8.018773, -8.017773, + -8.016773, -8.015773, -8.014772, -8.013773, -8.012773, -8.011773, + -8.010773, -8.009773, -8.008773, -8.007772, -8.006773, -8.005773, + -8.004773, -8.003773, -8.002772, -8.001773, -8.000772, -7.999773, + -7.998773, -7.997773, -7.996773, -7.995773, -7.994773, -7.993773, + -7.992773, -7.991773, -7.990773, -7.989773, -7.988773, -7.987773, + -7.986773, -7.985773, -7.984773, -7.983773, -7.982773, -7.981773, + -7.980773, -7.979773, -7.978773, -7.977773, -7.976773, -7.975773, + -7.974773, -7.973773, -7.972773, -7.971773, -7.970773, -7.969773, + -7.968773, -7.967773, -7.966773, -7.965773, -7.964773, -7.963773, + -7.962773, -7.961773, -7.960773, -7.959773, -7.958773, -7.957773, + -7.956773, -7.955773, -7.954773, -7.953773, -7.952773, -7.951773, + -7.950773, -7.949773, -7.948773, -7.947773, -7.946773, -7.945773, + -7.944773, -7.943773, -7.942773, -7.941773, -7.940773, -7.939773, + -7.938773, -7.937773, -7.936773, -7.935773, -7.934773, -7.933773, + -7.932773, -7.931773, -7.930773, -7.929773, -7.928773, -7.927773, + -7.926773, -7.925773, -7.924773, -7.923773, -7.922773, -7.921773, + -7.920773, -7.919773, -7.918773, -7.917773, -7.916773, -7.915773, + -7.914773, -7.913773, -7.912773, -7.911773, -7.910773, -7.909773, + -7.908773, -7.907773, -7.906773, -7.905773, -7.904773, -7.903773, + -7.902773, -7.901773, -7.900773, -7.899773, -7.898773, -7.897773, + -7.896773, -7.895773, -7.894773, -7.893773, -7.892773, -7.891773, + -7.890773, -7.889773, -7.888773, -7.887773, -7.886773, -7.885773, + -7.884773, -7.883773, -7.882773, -7.881773, -7.880773, -7.879773, + -7.878773, -7.877773, -7.876773, -7.875773, -7.874773, -7.873773, + -7.872773, -7.871773, -7.870773, -7.869773, -7.868773, -7.867773, + -7.866773, -7.865773, -7.864773, -7.863773, -7.862773, -7.861773, + -7.860773, -7.859773, -7.858773, -7.857773, -7.856773, -7.855773, + -7.854773, -7.853773, -7.852773, -7.851773, -7.850773, -7.849773, + -7.848773, -7.847773, -7.846773, -7.845773, -7.844773, -7.843773, + -7.842773, -7.841773, -7.840773, -7.839773, -7.838773, -7.837773, + -7.836773, -7.835773, -7.834773, -7.833773, -7.832773, -7.831773, + -7.830773, -7.829773, -7.828773, -7.827773, -7.826773, -7.825773, + -7.824773, -7.823773, -7.822773, -7.821773, -7.820773, -7.819773, + -7.818773, -7.817773, -7.816773, -7.815773, -7.814773, -7.813773, + -7.812773, -7.811773, -7.810773, -7.809773, -7.808773, -7.807773, + -7.806773, -7.805773, -7.804773, -7.803773, -7.802773, -7.801773, + -7.800773, -7.799773, -7.798773, -7.797773, -7.796773, -7.795773, + -7.794773, -7.793773, -7.792773, -7.791773, -7.790773, -7.789773, + -7.788773, -7.787773, -7.786773, -7.785773, -7.784773, -7.783773, + -7.782773, -7.781773, -7.780773, -7.779773, -7.778773, -7.777773, + -7.776773, -7.775773, -7.774773, -7.773773, -7.772773, -7.771773, + -7.770773, -7.769773, -7.768773, -7.767773, -7.766773, -7.765773, + -7.764773, -7.763773, -7.762773, -7.761773, -7.760773, -7.759773, + -7.758773, -7.757773, -7.756773, -7.755773, -7.754773, -7.753773, + -7.752773, -7.751773, -7.750773, -7.749773, -7.748773, -7.747773, + -7.746773, -7.745773, -7.744773, -7.743773, -7.742773, -7.741773, + -7.740773, -7.739773, -7.738773, -7.737773, -7.736773, -7.735773, + -7.734773, -7.733773, -7.732773, -7.731773, -7.730773, -7.729773, + -7.728773, -7.727773, -7.726773, -7.725773, -7.724773, -7.723773, + -7.722773, -7.721773, -7.720773, -7.719773, -7.718773, -7.717773, + -7.716773, -7.715773, -7.714773, -7.713773, -7.712773, -7.711773, + -7.710773, -7.709773, -7.708773, -7.707773, -7.706773, -7.705773, + -7.704773, -7.703773, -7.702773, -7.701773, -7.700773, -7.699773, + -7.698773, -7.697773, -7.696773, -7.695773, -7.694773, -7.693773, + -7.692773, -7.691773, -7.690773, -7.689773, -7.688773, -7.687773, + -7.686773, -7.685773, -7.684773, -7.683773, -7.682773, -7.681773, + -7.680773, -7.679773, -7.678773, -7.677773, -7.676773, -7.675773, + -7.674773, -7.673773, -7.672773, -7.671773, -7.670773, -7.669773, + -7.668773, -7.667773, -7.666773, -7.665773, -7.664773, -7.663773, + -7.662773, -7.661773, -7.660773, -7.659773, -7.658773, -7.657773, + -7.656773, -7.655773, -7.654773, -7.653773, -7.652773, -7.651773, + -7.650773, -7.649773, -7.648773, -7.647773, -7.646773, -7.645773, + -7.644773, -7.643773, -7.642773, -7.641773, -7.640773, -7.639773, + -7.638773, -7.637773, -7.636773, -7.635773, -7.634773, -7.633773, + -7.632773, -7.631773, -7.630773, -7.629773, -7.628773, -7.627773, + -7.626773, -7.625773, -7.624773, -7.623773, -7.622773, -7.621773, + -7.620773, -7.619773, -7.618773, -7.617773, -7.616773, -7.615773, + -7.614773, -7.613773, -7.612773, -7.611773, -7.610773, -7.609773, + -7.608773, -7.607773, -7.606773, -7.605773, -7.604773, -7.603773, + -7.602773, -7.601773, -7.600773, -7.599773, -7.598773, -7.597773, + -7.596773, -7.595773, -7.594773, -7.593773, -7.592773, -7.591773, + -7.590773, -7.589773, -7.588773, -7.587773, -7.586773, -7.585773, + -7.584773, -7.583773, -7.582773, -7.581773, -7.580773, -7.579773, + -7.578773, -7.577773, -7.576773, -7.575773, -7.574773, -7.573773, + -7.572773, -7.571773, -7.570773, -7.569773, -7.568773, -7.567773, + -7.566773, -7.565773, -7.564773, -7.563773, -7.562773, -7.561773, + -7.560773, -7.559773, -7.558773, -7.557773, -7.556773, -7.555773, + -7.554773, -7.553773, -7.552773, -7.551773, -7.550773, -7.549773, + -7.548773, -7.547773, -7.546773, -7.545773, -7.544773, -7.543773, + -7.542773, -7.541773, -7.540773, -7.539773, -7.538773, -7.537773, + -7.536773, -7.535773, -7.534773, -7.533773, -7.532773, -7.531773, + -7.530773, -7.529773, -7.528773, -7.527773, -7.526773, -7.525773, + -7.524773, -7.523773, -7.522773, -7.521773, -7.520773, -7.519773, + -7.518773, -7.517773, -7.516773, -7.515773, -7.514773, -7.513773, + -7.512773, -7.511773, -7.510773, -7.509773, -7.508773, -7.507773, + -7.506773, -7.505773, -7.504773, -7.503773, -7.502773, -7.501773, + -7.500773, -7.499773, -7.498773, -7.497773, -7.496773, -7.495773, + -7.494773, -7.493773, -7.492773, -7.491773, -7.490773, -7.489773, + -7.488773, -7.487773, -7.486773, -7.485773, -7.484773, -7.483773, + -7.482773, -7.481773, -7.480773, -7.479773, -7.478773, -7.477773, + -7.476773, -7.475773, -7.474773, -7.473773, -7.472773, -7.471773, + -7.470773, -7.469773, -7.468773, -7.467773, -7.466773, -7.465773, + -7.464773, -7.463773, -7.462773, -7.461773, -7.460773, -7.459773, + -7.458773, -7.457773, -7.456773, -7.455773, -7.454773, -7.453773, + -7.452773, -7.451773, -7.450773, -7.449773, -7.448773, -7.447773, + -7.446773, -7.445773, -7.444773, -7.443773, -7.442773, -7.441773, + -7.440773, -7.439773, -7.438773, -7.437773, -7.436773, -7.435773, + -7.434773, -7.433773, -7.432773, -7.431773, -7.430773, -7.429773, + -7.428773, -7.427773, -7.426773, -7.425773, -7.424773, -7.423773, + -7.422773, -7.421773, -7.420773, -7.419773, -7.418773, -7.417773, + -7.416773, -7.415773, -7.414773, -7.413773, -7.412773, -7.411773, + -7.410773, -7.409773, -7.408773, -7.407773, -7.406773, -7.405773, + -7.404773, -7.403773, -7.402773, -7.401773, -7.400773, -7.399773, + -7.398773, -7.397773, -7.396773, -7.395773, -7.394773, -7.393773, + -7.392773, -7.391773, -7.390773, -7.389773, -7.388773, -7.387773, + -7.386773, -7.385773, -7.384773, -7.383773, -7.382773, -7.381773, + -7.380773, -7.379773, -7.378773, -7.377773, -7.376773, -7.375773, + -7.374773, -7.373773, -7.372773, -7.371773, -7.370773, -7.369773, + -7.368773, -7.367773, -7.366773, -7.365773, -7.364773, -7.363773, + -7.362773, -7.361773, -7.360773, -7.359773, -7.358773, -7.357773, + -7.356773, -7.355773, -7.354773, -7.353773, -7.352773, -7.351773, + -7.350773, -7.349773, -7.348773, -7.347773, -7.346773, -7.345773, + -7.344773, -7.343773, -7.342773, -7.341773, -7.340773, -7.339773, + -7.338773, -7.337773, -7.336773, -7.335773, -7.334773, -7.333773, + -7.332773, -7.331773, -7.330773, -7.329773, -7.328773, -7.327773, + -7.326773, -7.325773, -7.324773, -7.323773, -7.322773, -7.321773, + -7.320773, -7.319773, -7.318773, -7.317773, -7.316773, -7.315773, + -7.314773, -7.313773, -7.312773, -7.311773, -7.310773, -7.309773, + -7.308773, -7.307773, -7.306773, -7.305773, -7.304773, -7.303773, + -7.302773, -7.301773, -7.300773, -7.299773, -7.298773, -7.297773, + -7.296773, -7.295773, -7.294773, -7.293773, -7.292773, -7.291773, + -7.290773, -7.289773, -7.288773, -7.287773, -7.286773, -7.285773, + -7.284773, -7.283773, -7.282773, -7.281773, -7.280773, -7.279773, + -7.278773, -7.277773, -7.276773, -7.275773, -7.274773, -7.273773, + -7.272773, -7.271773, -7.270773, -7.269773, -7.268773, -7.267773, + -7.266773, -7.265773, -7.264773, -7.263773, -7.262773, -7.261773, + -7.260773, -7.259773, -7.258773, -7.257773, -7.256773, -7.255773, + -7.254773, -7.253773, -7.252773, -7.251773, -7.250773, -7.249773, + -7.248773, -7.247773, -7.246773, -7.245773, -7.244773, -7.243773, + -7.242773, -7.241773, -7.240773, -7.239773, -7.238773, -7.237773, + -7.236773, -7.235773, -7.234773, -7.233773, -7.232773, -7.231773, + -7.230773, -7.229773, -7.228773, -7.227773, -7.226773, -7.225773, + -7.224773, -7.223773, -7.222773, -7.221773, -7.220773, -7.219773, + -7.218773, -7.217773, -7.216773, -7.215773, -7.214773, -7.213773, + -7.212773, -7.211773, -7.210773, -7.209773, -7.208773, -7.207773, + -7.206773, -7.205773, -7.204773, -7.203773, -7.202773, -7.201773, + -7.200773, -7.199773, -7.198773, -7.197773, -7.196773, -7.195773, + -7.194773, -7.193773, -7.192773, -7.191773, -7.190773, -7.189773, + -7.188773, -7.187773, -7.186773, -7.185773, -7.184773, -7.183773, + -7.182773, -7.181773, -7.180773, -7.179773, -7.178773, -7.177773, + -7.176773, -7.175773, -7.174773, -7.173773, -7.172773, -7.171773, + -7.170773, -7.169773, -7.168773, -7.167773, -7.166773, -7.165773, + -7.164773, -7.163773, -7.162773, -7.161773, -7.160773, -7.159773, + -7.158773, -7.157773, -7.156773, -7.155773, -7.154773, -7.153773, + -7.152773, -7.151773, -7.150773, -7.149773, -7.148773, -7.147773, + -7.146773, -7.145773, -7.144773, -7.143773, -7.142773, -7.141773, + -7.140773, -7.139773, -7.138773, -7.137773, -7.136773, -7.135773, + -7.134773, -7.133773, -7.132773, -7.131773, -7.130773, -7.129773, + -7.128773, -7.127773, -7.126773, -7.125773, -7.124773, -7.123773, + -7.122773, -7.121773, -7.120773, -7.119773, -7.118773, -7.117773, + -7.116773, -7.115773, -7.114773, -7.113773, -7.112773, -7.111773, + -7.110773, -7.109773, -7.108773, -7.107773, -7.106773, -7.105773, + -7.104773, -7.103773, -7.102773, -7.101773, -7.100773, -7.099773, + -7.098773, -7.097773, -7.096773, -7.095773, -7.094773, -7.093773, + -7.092773, -7.091773, -7.090773, -7.089773, -7.088773, -7.087773, + -7.086773, -7.085773, -7.084773, -7.083773, -7.082773, -7.081773, + -7.080773, -7.079773, -7.078773, -7.077773, -7.076773, -7.075773, + -7.074773, -7.073773, -7.072773, -7.071773, -7.070773, -7.069773, + -7.068773, -7.067773, -7.066773, -7.065773, -7.064773, -7.063773, + -7.062773, -7.061773, -7.060773, -7.059773, -7.058773, -7.057773, + -7.056773, -7.055773, -7.054773, -7.053773, -7.052773, -7.051773, + -7.050773, -7.049773, -7.048773, -7.047773, -7.046773, -7.045773, + -7.044773, -7.043773, -7.042773, -7.041773, -7.040773, -7.039773, + -7.038773, -7.037773, -7.036773, -7.035773, -7.034773, -7.033773, + -7.032773, -7.031773, -7.030773, -7.029773, -7.028773, -7.027773, + -7.026773, -7.025773, -7.024773, -7.023773, -7.022773, -7.021773, + -7.020773, -7.019773, -7.018773, -7.017773, -7.016773, -7.015773, + -7.014773, -7.013773, -7.012773, -7.011773, -7.010773, -7.009773, + -7.008773, -7.007773, -7.006773, -7.005773, -7.004773, -7.003773, + -7.002773, -7.001773, -7.000773, -6.999773, -6.998773, -6.997773, + -6.996773, -6.995773, -6.994773, -6.993773, -6.992773, -6.991773, + -6.990773, -6.989773, -6.988773, -6.987773, -6.986773, -6.985773, + -6.984773, -6.983773, -6.982773, -6.981773, -6.980773, -6.979773, + -6.978773, -6.977773, -6.976773, -6.975773, -6.974773, -6.973773, + -6.972773, -6.971773, -6.970773, -6.969773, -6.968773, -6.967773, + -6.966773, -6.965773, -6.964773, -6.963773, -6.962773, -6.961773, + -6.960773, -6.959773, -6.958773, -6.957773, -6.956773, -6.955773, + -6.954773, -6.953773, -6.952773, -6.951773, -6.950773, -6.949773, + -6.948773, -6.947773, -6.946773, -6.945773, -6.944773, -6.943773, + -6.942773, -6.941773, -6.940773, -6.939773, -6.938773, -6.937773, + -6.936773, -6.935773, -6.934773, -6.933773, -6.932773, -6.931773, + -6.930773, -6.929773, -6.928773, -6.927773, -6.926773, -6.925773, + -6.924773, -6.923773, -6.922773, -6.921773, -6.920773, -6.919773, + -6.918773, -6.917773, -6.916773, -6.915773, -6.914773, -6.913773, + -6.912773, -6.911773, -6.910773, -6.909773, -6.908773, -6.907773, + -6.906773, -6.905773, -6.904773, -6.903773, -6.902773, -6.901773, + -6.900773, -6.899773, -6.898773, -6.897773, -6.896773, -6.895773, + -6.894773, -6.893773, -6.892773, -6.891773, -6.890773, -6.889773, + -6.888773, -6.887773, -6.886773, -6.885773, -6.884773, -6.883773, + -6.882773, -6.881773, -6.880773, -6.879773, -6.878773, -6.877773, + -6.876773, -6.875773, -6.874773, -6.873773, -6.872773, -6.871773, + -6.870773, -6.869773, -6.868773, -6.867773, -6.866773, -6.865773, + -6.864773, -6.863773, -6.862773, -6.861773, -6.860773, -6.859773, + -6.858773, -6.857773, -6.856773, -6.855773, -6.854773, -6.853773, + -6.852773, -6.851773, -6.850773, -6.849773, -6.848773, -6.847773, + -6.846773, -6.845773, -6.844773, -6.843773, -6.842773, -6.841773, + -6.840773, -6.839773, -6.838773, -6.837773, -6.836773, -6.835773, + -6.834773, -6.833773, -6.832773, -6.831773, -6.830773, -6.829773, + -6.828773, -6.827773, -6.826773, -6.825773, -6.824773, -6.823773, + -6.822773, -6.821773, -6.820773, -6.819773, -6.818773, -6.817773, + -6.816773, -6.815773, -6.814773, -6.813773, -6.812773, -6.811773, + -6.810773, -6.809773, -6.808773, -6.807773, -6.806773, -6.805773, + -6.804773, -6.803773, -6.802773, -6.801773, -6.800773, -6.799773, + -6.798773, -6.797773, -6.796773, -6.795773, -6.794773, -6.793773, + -6.792773, -6.791773, -6.790773, -6.789773, -6.788773, -6.787773, + -6.786773, -6.785773, -6.784773, -6.783773, -6.782773, -6.781773, + -6.780773, -6.779773, -6.778773, -6.777773, -6.776773, -6.775773, + -6.774773, -6.773773, -6.772773, -6.771773, -6.770773, -6.769773, + -6.768773, -6.767773, -6.766773, -6.765773, -6.764773, -6.763773, + -6.762773, -6.761773, -6.760773, -6.759773, -6.758773, -6.757773, + -6.756773, -6.755773, -6.754773, -6.753773, -6.752773, -6.751773, + -6.750773, -6.749773, -6.748773, -6.747773, -6.746773, -6.745773, + -6.744773, -6.743773, -6.742773, -6.741773, -6.740773, -6.739773, + -6.738773, -6.737773, -6.736773, -6.735773, -6.734773, -6.733773, + -6.732773, -6.731773, -6.730773, -6.729773, -6.728773, -6.727773, + -6.726773, -6.725773, -6.724773, -6.723773, -6.722773, -6.721773, + -6.720773, -6.719773, -6.718773, -6.717773, -6.716773, -6.715773, + -6.714773, -6.713773, -6.712773, -6.711773, -6.710773, -6.709773, + -6.708773, -6.707773, -6.706773, -6.705773, -6.704773, -6.703773, + -6.702773, -6.701773, -6.700773, -6.699773, -6.698773, -6.697773, + -6.696773, -6.695773, -6.694773, -6.693773, -6.692773, -6.691773, + -6.690773, -6.689773, -6.688773, -6.687773, -6.686773, -6.685773, + -6.684773, -6.683773, -6.682773, -6.681773, -6.680773, -6.679773, + -6.678773, -6.677773, -6.676773, -6.675773, -6.674773, -6.673773, + -6.672773, -6.671773, -6.670773, -6.669773, -6.668773, -6.667773, + -6.666773, -6.665773, -6.664773, -6.663773, -6.662773, -6.661773, + -6.660773, -6.659773, -6.658773, -6.657773, -6.656773, -6.655773}, + sd::DataType::DOUBLE); + input.linspace(0.01, 0.001); + + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, + 0); + + ASSERT_TRUE(output.equalsTo(&expOutput)); } - ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_1) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(6, {0,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray a('f', {M, N}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + NDArray temp('f', {M, N, 5}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(6, {0, 2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_2) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; + NDArray a('f', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {M, N, 5}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(6, {0, 2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(6, {0,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {5.1, 3.3, 1.5}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {5.1, 3.3, 1.5}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_3) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {N,M,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(4, {1,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray a('f', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {N, M, 5}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(4, {1, 2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {6.2, 4.5, 1.7}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {6.2, 4.5, 1.7}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_4) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; + NDArray a('f', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(3, {0, 1}); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(3, {0,1}); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_5) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(2, {0,1}); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('f', {5, M, N}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(2, {0, 1}); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_6) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; - - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(13, {0,2}); - NDArray y('f', {M}, sd::DataType::DOUBLE); + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('c', {5, N, M}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(13, {0, 2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); - NDArray exp('f', {M}, {-12.1, -10.9, -9.7}, sd::DataType::DOUBLE); + NDArray exp('f', {M}, {-12.1, -10.9, -9.7}, sd::DataType::DOUBLE); - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_7) { + const Nd4jLong M = 3; + const Nd4jLong N = 4; - const Nd4jLong M = 3; - const Nd4jLong N = 4; + NDArray a('c', {N, M}, + {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, + sd::DataType::DOUBLE); + a.permutei({1, 0}); + NDArray temp('c', {5, N, M}, + {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, + 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, + 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, + 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, + sd::DataType::DOUBLE); + NDArray x = temp(10, {0, 2}); + NDArray y('c', {M}, sd::DataType::DOUBLE); - NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); - a.permutei({1,0}); - NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); - NDArray x = temp(10, {0,2}); - NDArray y('c', {M}, sd::DataType::DOUBLE); + NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); - NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); - - sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); - ASSERT_TRUE(y.equalsTo(&exp)); + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softmaxDerivative_1) { + NDArray input('c', {3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, 5.}, + sd::DataType::DOUBLE); + NDArray expOutput('c', {3, 3}, + {0.04508, 0.04514, 0.0008, 0.0472, 0.00087, 0.10492, + 0.00235, 0.04592, 0.10553}, + sd::DataType::DOUBLE); + NDArray output('c', {3, 3}, sd::DataType::DOUBLE); - NDArray input('c', {3,3}, {-1, 1, -2, 2, -3, 3, -4, 4, 5.}, sd::DataType::DOUBLE); - NDArray expOutput('c', {3,3}, {0.04508, 0.04514, 0.0008 , 0.0472 , 0.00087, 0.10492, 0.00235, 0.04592, 0.10553}, sd::DataType::DOUBLE); - NDArray output('c', {3,3}, sd::DataType::DOUBLE); - - // input.applyTransform(sd::transform::SoftMaxDerivative, &output); + // input.applyTransform(sd::transform::SoftMaxDerivative, &output); - sd::ops::helpers::softmaxDerivative(input.getContext(), input, output, 0); - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + sd::ops::helpers::softmaxDerivative(input.getContext(), input, output, 0); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softmaxDerivative_2) { - - NDArray input('c', {3,3,3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14.}, sd::DataType::DOUBLE); - NDArray expOutput('c', {3,3,3}, {4.50755e-02, 4.51394e-02, 6.64586e-03,4.72027e-02, 8.67128e-04, 6.97440e-03,2.35008e-03, 4.59243e-02, 3.32995e-04, - 4.51766e-02, 2.26032e-06, 4.51767e-02,2.91394e-07, 2.37285e-06, 3.94360e-08,4.51769e-02, 1.12535e-07, 4.51767e-02, - 7.58256e-10, 4.51767e-02, 1.22325e-11,7.96007e-10, 1.32293e-11, 1.04994e-01,3.77513e-11, 4.51767e-02, 1.04994e-01}, sd::DataType::DOUBLE); - NDArray output('c', {3,3,3}, sd::DataType::DOUBLE); - - // input.applyTransform(sd::transform::SoftMaxDerivative, &output); - - sd::ops::helpers::softmaxDerivative(input.getContext(), input, output, 1); - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + NDArray input('c', {3, 3, 3}, + {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, + -8, 8, -9, 9, -10, 10, -11, 11, -12, 12, -13, 13, 14.}, + sd::DataType::DOUBLE); + NDArray expOutput( + 'c', {3, 3, 3}, + {4.50755e-02, 4.51394e-02, 6.64586e-03, 4.72027e-02, 8.67128e-04, + 6.97440e-03, 2.35008e-03, 4.59243e-02, 3.32995e-04, 4.51766e-02, + 2.26032e-06, 4.51767e-02, 2.91394e-07, 2.37285e-06, 3.94360e-08, + 4.51769e-02, 1.12535e-07, 4.51767e-02, 7.58256e-10, 4.51767e-02, + 1.22325e-11, 7.96007e-10, 1.32293e-11, 1.04994e-01, 3.77513e-11, + 4.51767e-02, 1.04994e-01}, + sd::DataType::DOUBLE); + NDArray output('c', {3, 3, 3}, sd::DataType::DOUBLE); + + // input.applyTransform(sd::transform::SoftMaxDerivative, &output); + + sd::ops::helpers::softmaxDerivative(input.getContext(), input, output, 1); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softmaxDerivative_3) { + NDArray input('c', {5}, {-1., 1, -2, 2, 3}, sd::DataType::DOUBLE); + NDArray expOutput('c', {5}, {0.01184, 0.08071, 0.00439, 0.18277, 0.22618}, + sd::DataType::DOUBLE); + NDArray output('c', {5}, sd::DataType::DOUBLE); - NDArray input('c', {5}, {-1., 1, -2, 2, 3}, sd::DataType::DOUBLE); - NDArray expOutput('c', {5}, {0.01184, 0.08071, 0.00439, 0.18277, 0.22618}, sd::DataType::DOUBLE); - NDArray output('c', {5}, sd::DataType::DOUBLE); + // input.applyTransform(sd::transform::SoftMaxDerivative, &output); - // input.applyTransform(sd::transform::SoftMaxDerivative, &output); - - sd::ops::helpers::softmaxDerivative(input.getContext(), input, output, 0); - ASSERT_TRUE(expOutput.isSameShape(output)); - ASSERT_TRUE(expOutput.equalsTo(output)); + sd::ops::helpers::softmaxDerivative(input.getContext(), input, output, 0); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, lstmLayerCell_1) { - - const int bS = 2; - const int nIn = 10; - const int nOut = 4; - - const float dataFormat = 0; // is ignored in cell op - const float cellClip = 5; // clipping value - const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid - const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid - const float cellAct = 0; // tanh activation for cell state - const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh - const float cellBeta = 0; // beta value for cell state activation, not required for tanh - const float outAct = 0; // tanh activation for output - const float outAlpha = 0; // alpha value for output activation, not required for tanh - const float outBeta = 0; // beta value for output activation, not required for tanh - - NDArray x ('c', {bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b ('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); - - NDArray h('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray c('c', {bS, nOut}, sd::DataType::FLOAT32); - - NDArray expH('c', {bS, nOut}, {0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288}, sd::DataType::FLOAT32); - NDArray expC('c', {bS, nOut}, {3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778}, sd::DataType::FLOAT32); - - std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; - - x = 1.; - hI = 2.; - cI = 3.; - Wx = 0.5; - Wr = 0.4; - Wp = 0.3; - b = 0.7; - - sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expC.isSameShape(c)); - ASSERT_TRUE(expC.equalsTo(c)); + const int bS = 2; + const int nIn = 10; + const int nOut = 4; + + const float dataFormat = 0; // is ignored in cell op + const float cellClip = 5; // clipping value + const float gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = + 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = + 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = + 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = + 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = + 0; // alpha value for output activation, not required for tanh + const float outBeta = + 0; // beta value for output activation, not required for tanh + + NDArray x('c', {bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3 * nOut}, sd::DataType::FLOAT32); + + NDArray h('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray c('c', {bS, nOut}, sd::DataType::FLOAT32); + + NDArray expH('c', {bS, nOut}, + {0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, + 0.999288, 0.999288}, + sd::DataType::FLOAT32); + NDArray expC('c', {bS, nOut}, + {3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, + 3.999778, 3.999778}, + sd::DataType::FLOAT32); + + std::vector params = {dataFormat, 0, cellClip, gateAct, + gateAlpha, gateBeta, cellAct, cellAlpha, + cellBeta, outAct, outAlpha, outBeta}; + + x = 1.; + hI = 2.; + cI = 3.; + Wx = 0.5; + Wr = 0.4; + Wp = 0.3; + b = 0.7; + + sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, + &c); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, lstmLayerCell_2) { - - const int bS = 2; - const int nIn = 10; - const int nOut = 4; - - const float dataFormat = 0; // is ignored in cell op - const float cellClip = 3; // clipping value - const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid - const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid - const float cellAct = 0; // tanh activation for cell state - const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh - const float cellBeta = 0; // beta value for cell state activation, not required for tanh - const float outAct = 0; // tanh activation for output - const float outAlpha = 0; // alpha value for output activation, not required for tanh - const float outBeta = 0; // beta value for output activation, not required for tanh - - NDArray x ('c', {bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b ('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); - - NDArray h('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray c('c', {bS, nOut}, sd::DataType::FLOAT32); - - NDArray expH('c', {bS, nOut}, {0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995}, sd::DataType::FLOAT32); - NDArray expC('c', {bS, nOut}, {3., 3., 3., 3., 3., 3., 3., 3.}, sd::DataType::FLOAT32); - - std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; - - x = 1.; - hI = 2.; - cI = 3.; - Wx = 0.5; - Wr = 0.4; - Wp = 0.3; - b = 0.7; - - sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expC.isSameShape(c)); - ASSERT_TRUE(expC.equalsTo(c)); + const int bS = 2; + const int nIn = 10; + const int nOut = 4; + + const float dataFormat = 0; // is ignored in cell op + const float cellClip = 3; // clipping value + const float gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = + 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = + 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = + 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = + 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = + 0; // alpha value for output activation, not required for tanh + const float outBeta = + 0; // beta value for output activation, not required for tanh + + NDArray x('c', {bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3 * nOut}, sd::DataType::FLOAT32); + + NDArray h('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray c('c', {bS, nOut}, sd::DataType::FLOAT32); + + NDArray expH('c', {bS, nOut}, + {0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995}, + sd::DataType::FLOAT32); + NDArray expC('c', {bS, nOut}, {3., 3., 3., 3., 3., 3., 3., 3.}, + sd::DataType::FLOAT32); + + std::vector params = {dataFormat, 0, cellClip, gateAct, + gateAlpha, gateBeta, cellAct, cellAlpha, + cellBeta, outAct, outAlpha, outBeta}; + + x = 1.; + hI = 2.; + cI = 3.; + Wx = 0.5; + Wr = 0.4; + Wp = 0.3; + b = 0.7; + + sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, + &c); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, lstmLayerCell_3) { - - const int nIn = 10; - const int nOut = 4; - - const float dataFormat = 0; // is ignored in cell op - const float cellClip = 5; // clipping value - const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid - const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid - const float cellAct = 0; // tanh activation for cell state - const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh - const float cellBeta = 0; // beta value for cell state activation, not required for tanh - const float outAct = 0; // tanh activation for output - const float outAlpha = 0; // alpha value for output activation, not required for tanh - const float outBeta = 0; // beta value for output activation, not required for tanh - - NDArray x ('c', {nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); - NDArray b ('c', {4*nOut}, sd::DataType::FLOAT32); - NDArray hI('c', {nOut}, sd::DataType::FLOAT32); - NDArray cI('c', {nOut}, sd::DataType::FLOAT32); - NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); - - NDArray h('c', {nOut}, sd::DataType::FLOAT32); - NDArray c('c', {nOut}, sd::DataType::FLOAT32); - - NDArray expH('c', {nOut}, {0.999288, 0.999288, 0.999288, 0.999288}, sd::DataType::FLOAT32); - NDArray expC('c', {nOut}, {3.999778, 3.999778, 3.999778, 3.999778}, sd::DataType::FLOAT32); - - std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; - - x = 1.; - hI = 2.; - cI = 3.; - Wx = 0.5; - Wr = 0.4; - Wp = 0.3; - b = 0.7; - - sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); - - ASSERT_TRUE(expH.isSameShape(h)); - ASSERT_TRUE(expH.equalsTo(h)); - ASSERT_TRUE(expC.isSameShape(c)); - ASSERT_TRUE(expC.equalsTo(c)); + const int nIn = 10; + const int nOut = 4; + + const float dataFormat = 0; // is ignored in cell op + const float cellClip = 5; // clipping value + const float gateAct = + 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = + 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = + 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = + 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = + 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = + 0; // alpha value for output activation, not required for tanh + const float outBeta = + 0; // beta value for output activation, not required for tanh + + NDArray x('c', {nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3 * nOut}, sd::DataType::FLOAT32); + + NDArray h('c', {nOut}, sd::DataType::FLOAT32); + NDArray c('c', {nOut}, sd::DataType::FLOAT32); + + NDArray expH('c', {nOut}, {0.999288, 0.999288, 0.999288, 0.999288}, + sd::DataType::FLOAT32); + NDArray expC('c', {nOut}, {3.999778, 3.999778, 3.999778, 3.999778}, + sd::DataType::FLOAT32); + + std::vector params = {dataFormat, 0, cellClip, gateAct, + gateAlpha, gateBeta, cellAct, cellAlpha, + cellBeta, outAct, outAlpha, outBeta}; + + x = 1.; + hI = 2.; + cI = 3.; + Wx = 0.5; + Wr = 0.4; + Wp = 0.3; + b = 0.7; + + sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, + &c); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); } - diff --git a/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp b/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp index dc0bcb7443f9..04158a8f6bbf 100644 --- a/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/IndexingTests.cpp @@ -18,424 +18,422 @@ // Created by raver119 on 31.10.2017. // -#include "testlayers.h" -#include #include #include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class IndexingTests : public testing::Test { -public: - + public: }; TEST_F(IndexingTests, StridedSlice_1) { - auto x = NDArrayFactory::create('c', {3, 3, 3}); - auto exp = NDArrayFactory::create('c', {1, 1, 3}); - exp.p(0, 25.f); - exp.p(1, 26.f); - exp.p(2, 27.f); - - x.linspace(1); - auto begin = NDArrayFactory::create({2,2, 0}); - auto end = NDArrayFactory::create({3,3,3}); - auto strides = NDArrayFactory::create({1,1,1}); - - - sd::ops::strided_slice op; - - auto result = op.evaluate({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {3, 3, 3}); + auto exp = NDArrayFactory::create('c', {1, 1, 3}); + exp.p(0, 25.f); + exp.p(1, 26.f); + exp.p(2, 27.f); + + x.linspace(1); + auto begin = NDArrayFactory::create({2, 2, 0}); + auto end = NDArrayFactory::create({3, 3, 3}); + auto strides = NDArrayFactory::create({1, 1, 1}); + + sd::ops::strided_slice op; + + auto result = op.evaluate({&x, &begin, &end, &strides}, {}, + {0, 0, 0, 0, 0}); //, 2,2,0, 3,3,3, 1,1,1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, StridedSlice_2) { - auto x = NDArrayFactory::create('c', {5, 5, 5}); - auto exp = NDArrayFactory::create('c', {2, 3, 3}, {86.f, 87.f, 88.f, 91.f, 92.f, 93.f, 96.f, 97.f, 98.f, 111.f, 112.f, 113.f, 116.f, 117.f, 118.f, 121.f, 122.f, 123.f}); + auto x = NDArrayFactory::create('c', {5, 5, 5}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 3}, + {86.f, 87.f, 88.f, 91.f, 92.f, 93.f, 96.f, 97.f, 98.f, 111.f, 112.f, + 113.f, 116.f, 117.f, 118.f, 121.f, 122.f, 123.f}); - x.linspace(1); + x.linspace(1); - sd::ops::strided_slice op; - - auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::strided_slice op; - auto z = result.at(0); + auto result = + op.evaluate({&x}, {}, {0, 0, 0, 0, 0, 3, 2, 0, 5, 5, 3, 1, 1, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, StridedSlice_3) { - auto x = NDArrayFactory::create('c', {5, 5, 5}); - auto exp = NDArrayFactory::create('c', {2, 3, 2}, {86.f, 88.f, 91.f, 93.f, 96.f, 98.f, 111.f, 113.f, 116.f, 118.f, 121.f, 123.f}); + auto x = NDArrayFactory::create('c', {5, 5, 5}); + auto exp = + NDArrayFactory::create('c', {2, 3, 2}, + {86.f, 88.f, 91.f, 93.f, 96.f, 98.f, 111.f, + 113.f, 116.f, 118.f, 121.f, 123.f}); - x.linspace(1); + x.linspace(1); - sd::ops::strided_slice op; + sd::ops::strided_slice op; - auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = + op.evaluate({&x}, {}, {0, 0, 0, 0, 0, 3, 2, 0, 5, 5, 3, 1, 1, 2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, SimpleSlice_1) { + auto input = NDArrayFactory::create( + 'c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); - auto input = NDArrayFactory::create('c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); - - auto exp = NDArrayFactory::create('c', {1, 1, 3}); - exp.p(0, 3.0f); - exp.p(1, 3.0f); - exp.p(2, 3.0f); + auto exp = NDArrayFactory::create('c', {1, 1, 3}); + exp.p(0, 3.0f); + exp.p(1, 3.0f); + exp.p(2, 3.0f); - sd::ops::slice op; + sd::ops::slice op; - auto result = op.evaluate({&input}, {}, {1,0,0, 1,1,3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto result = op.evaluate({&input}, {}, {1, 0, 0, 1, 1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); + auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, SimpleSlice_2) { - auto input = NDArrayFactory::create('c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + auto input = NDArrayFactory::create( + 'c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); - auto exp = NDArrayFactory::create('c', {1, 2, 3}); - exp.p(0, 3.0f); - exp.p(1, 3.0f); - exp.p(2, 3.0f); - exp.p(3, 4.0f); - exp.p(4, 4.0f); - exp.p(5, 4.0f); + auto exp = NDArrayFactory::create('c', {1, 2, 3}); + exp.p(0, 3.0f); + exp.p(1, 3.0f); + exp.p(2, 3.0f); + exp.p(3, 4.0f); + exp.p(4, 4.0f); + exp.p(5, 4.0f); - sd::ops::slice op; + sd::ops::slice op; - auto result = op.evaluate({&input}, {}, {1,0,0, 1,2,3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&input}, {}, {1, 0, 0, 1, 2, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(IndexingTests, SimpleSlice_3) { - auto input = NDArrayFactory::create('c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + auto input = NDArrayFactory::create( + 'c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); - auto exp = NDArrayFactory::create('c', {2, 1, 3}); - exp.p(0, 3.0f); - exp.p(1, 3.0f); - exp.p(2, 3.0f); - exp.p(3, 5.0f); - exp.p(4, 5.0f); - exp.p(5, 5.0f); + auto exp = NDArrayFactory::create('c', {2, 1, 3}); + exp.p(0, 3.0f); + exp.p(1, 3.0f); + exp.p(2, 3.0f); + exp.p(3, 5.0f); + exp.p(4, 5.0f); + exp.p(5, 5.0f); - sd::ops::slice op; + sd::ops::slice op; - auto result = op.evaluate({&input}, {}, {1,0,0, 2,1,3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&input}, {}, {1, 0, 0, 2, 1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); + auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); - + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(IndexingTests, SimpleSlice_4) { - auto input = NDArrayFactory::create('c', {3, 2, 3}, {1.0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); - auto start = NDArrayFactory::create('c', {3}, {1.0, 0.0, 0.0}); - auto stop = NDArrayFactory::create('c', {3}, {2.0, 1.0, 3.0}); - auto exp = NDArrayFactory::create('c', {2, 1, 3}, {3.0, 3.0, 3.0, 5.0, 5.0, 5.0}); - - sd::ops::slice op; + auto input = NDArrayFactory::create( + 'c', {3, 2, 3}, {1.0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + auto start = NDArrayFactory::create('c', {3}, {1.0, 0.0, 0.0}); + auto stop = NDArrayFactory::create('c', {3}, {2.0, 1.0, 3.0}); + auto exp = NDArrayFactory::create('c', {2, 1, 3}, + {3.0, 3.0, 3.0, 5.0, 5.0, 5.0}); - auto result = op.evaluate({&input, &start, &stop}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::slice op; - auto z = result.at(0); + auto result = op.evaluate({&input, &start, &stop}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, MaskedSlice_0) { - auto matrix = NDArrayFactory::create('c', {3, 5}); - auto tads = matrix.allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - tads.at(e).assign((float) (e+1)); - } + auto matrix = NDArrayFactory::create('c', {3, 5}); + auto tads = matrix.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + tads.at(e).assign((float)(e + 1)); + } - auto exp = NDArrayFactory::create('c', {1, 5}); - exp.assign(2.0f); + auto exp = NDArrayFactory::create('c', {1, 5}); + exp.assign(2.0f); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 2, 1}); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix}, {}, {0, 0, 0, 0, 0, 1, 2, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - // z->printShapeInfo("z"); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + // z->printShapeInfo("z"); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, MaskedSlice_00) { - auto matrix = NDArrayFactory::create('c', {3, 5}); - auto tads = matrix.allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - tads.at(e).assign((float) (e+1)); - } - - auto exp = NDArrayFactory::create('c', {1, 2}, {2, 2}); - + auto matrix = NDArrayFactory::create('c', {3, 5}); + auto tads = matrix.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + tads.at(e).assign((float)(e + 1)); + } - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1}); + auto exp = NDArrayFactory::create('c', {1, 2}, {2, 2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix}, {}, {0, 0, 0, 0, 0, 1, 1, 2, 3, 1, 1}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, MaskedSlice_1) { - auto matrix = NDArrayFactory::create('c', {3, 5}); - auto tads = matrix.allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - tads.at(e).assign((float) (e+1)); - } + auto matrix = NDArrayFactory::create('c', {3, 5}); + auto tads = matrix.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + tads.at(e).assign((float)(e + 1)); + } - auto exp = NDArrayFactory::create('c', {5}); - exp.assign(2.0f); + auto exp = NDArrayFactory::create('c', {5}); + exp.assign(2.0f); - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 2, 1}); + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix}, {}, {0, 0, 0, 0, 1, 1, 2, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - // z->printShapeInfo("z"); + // z->printShapeInfo("z"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(IndexingTests, MaskedSlice_2) { - - auto matrix = NDArrayFactory::create('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); - auto exp = NDArrayFactory::create('c', {3, 3}, {4.000000f, 4.200000f, 4.300000f, 5.000000f, 5.200000f, 5.300000f, 6.000000f, 6.200000f, 6.300000f}); - - // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto matrix = NDArrayFactory::create( + 'c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, + 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, + 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); + auto exp = NDArrayFactory::create( + 'c', {3, 3}, + {4.000000f, 4.200000f, 4.300000f, 5.000000f, 5.200000f, 5.300000f, + 6.000000f, 6.200000f, 6.300000f}); + + // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) + sd::ops::strided_slice op; + auto result = + op.evaluate({&matrix}, {}, {0, 0, 0, 0, 1, 1, 0, 0, 3, 3, 3, 1, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, MaskedSlice_3) { + auto matrix = NDArrayFactory::create( + 'c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, + 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, + 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); + auto exp = NDArrayFactory::create('c', {2, 3}, + {4.f, 4.2f, 4.3f, 7.f, 7.2f, 7.3f}); - auto matrix = NDArrayFactory::create('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); - auto exp = NDArrayFactory::create('c', {2, 3}, { 4.f, 4.2f, 4.3f, 7.f, 7.2f, 7.3f}); + // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) + sd::ops::strided_slice op; + auto result = + op.evaluate({&matrix}, {}, {0, 0, 0, 0, 2, 1, 0, 0, 3, 3, 3, 1, 1, 1}); - // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, MaskedSlice_4) { + auto matrix = NDArrayFactory::create( + 'c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, + 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, + 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); + auto exp = NDArrayFactory::create('c', {3}, {4.f, 4.2f, 4.3f}); - auto matrix = NDArrayFactory::create('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); - auto exp = NDArrayFactory::create('c', {3}, { 4.f, 4.2f, 4.3f}); - - // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) + sd::ops::strided_slice op; + auto result = + op.evaluate({&matrix}, {}, {0, 0, 0, 0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(IndexingTests, Live_Slice_1) { - auto matrix = NDArrayFactory::create('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); - auto exp = NDArrayFactory::create('c', {3}, { 4.f, 4.2f, 4.3f}); + auto matrix = NDArrayFactory::create( + 'c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, + 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, + 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); + auto exp = NDArrayFactory::create('c', {3}, {4.f, 4.2f, 4.3f}); - auto begin = NDArrayFactory::create('c', {3}, {1.0f, 0.0f, 0.0f}); - auto end = NDArrayFactory::create('c', {3}, {3.0f, 3.0f, 3.0f}); - auto stride = NDArrayFactory::create('c', {3}, {1.0f, 1.0f, 1.0f}); + auto begin = NDArrayFactory::create('c', {3}, {1.0f, 0.0f, 0.0f}); + auto end = NDArrayFactory::create('c', {3}, {3.0f, 3.0f, 3.0f}); + auto stride = NDArrayFactory::create('c', {3}, {1.0f, 1.0f, 1.0f}); - // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) - sd::ops::strided_slice op; - auto result = op.evaluate({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3}); + // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) + sd::ops::strided_slice op; + auto result = + op.evaluate({&matrix, &begin, &end, &stride}, {}, {0, 0, 0, 0, 3}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - // z->printShapeInfo("z shape"); + // z->printShapeInfo("z shape"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, Test_StridedSlice_1) { - auto x = NDArrayFactory::create('c', {1, 2}, {5.f, 2.f}); - auto a = NDArrayFactory::create('c', {1}, {0.f}); - auto b = NDArrayFactory::create('c', {1}, {1.f}); - auto c = NDArrayFactory::create('c', {1}, {1.f}); - auto exp = NDArrayFactory::create({5.0f, 2}); + auto x = NDArrayFactory::create('c', {1, 2}, {5.f, 2.f}); + auto a = NDArrayFactory::create('c', {1}, {0.f}); + auto b = NDArrayFactory::create('c', {1}, {1.f}); + auto c = NDArrayFactory::create('c', {1}, {1.f}); + auto exp = NDArrayFactory::create({5.0f, 2}); - sd::ops::strided_slice op; - auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(IndexingTests, Test_StridedSlice_2) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - auto a = NDArrayFactory::create('c', {2}, {1, 1}); - auto b = NDArrayFactory::create('c', {2}, {2, 2}); - auto c = NDArrayFactory::create('c', {2}, {1, 1}); - auto exp = NDArrayFactory::create('c', {1}, {5.0}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto a = NDArrayFactory::create('c', {2}, {1, 1}); + auto b = NDArrayFactory::create('c', {2}, {2, 2}); + auto c = NDArrayFactory::create('c', {2}, {1, 1}); + auto exp = NDArrayFactory::create('c', {1}, {5.0}); - sd::ops::strided_slice op; - auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - // z->printIndexedBuffer("Z"); + // z->printIndexedBuffer("Z"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, Test_StridedSlice_3) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - auto a = NDArrayFactory::create('c', {2}, {1, 2}); - auto b = NDArrayFactory::create('c', {2}, {2, 3}); - auto c = NDArrayFactory::create('c', {2}, {1, 1}); - auto exp = NDArrayFactory::create('c', {1}, {6.0}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto a = NDArrayFactory::create('c', {2}, {1, 2}); + auto b = NDArrayFactory::create('c', {2}, {2, 3}); + auto c = NDArrayFactory::create('c', {2}, {1, 1}); + auto exp = NDArrayFactory::create('c', {1}, {6.0}); - sd::ops::strided_slice op; - auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(IndexingTests, Test_StridedSlice_4) { - auto x = NDArrayFactory::create('c', {1, 2}, {5, 2}); - auto a = NDArrayFactory::create('c', {1}, {0.}); - auto b = NDArrayFactory::create('c', {1}, {1}); - auto c = NDArrayFactory::create('c', {1}, {1}); - auto exp = NDArrayFactory::create({5.0f, 2}); + auto x = NDArrayFactory::create('c', {1, 2}, {5, 2}); + auto a = NDArrayFactory::create('c', {1}, {0.}); + auto b = NDArrayFactory::create('c', {1}, {1}); + auto c = NDArrayFactory::create('c', {1}, {1}); + auto exp = NDArrayFactory::create({5.0f, 2}); - sd::ops::strided_slice op; - auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); -// auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, 1}); + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); + // auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, + // 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); + auto z = result.at(0); - //z->printIndexedBuffer("Z"); + // z->printIndexedBuffer("Z"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(IndexingTests, Test_Subarray_Strided_1) { - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto exp = NDArrayFactory::create('c', {3, 2}, {1, 3, 4, 6, 7, 9}); - auto sub = x({0,0,0, 0,3,2}, true, true); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create('c', {3, 2}, {1, 3, 4, 6, 7, 9}); + auto sub = x({0, 0, 0, 0, 3, 2}, true, true); - ASSERT_TRUE(exp.isSameShape(sub)); - ASSERT_TRUE(exp.equalsTo(sub)); + ASSERT_TRUE(exp.isSameShape(sub)); + ASSERT_TRUE(exp.equalsTo(sub)); } - /* TEST_F(IndexingTests, MaskedSlice_5) { - auto matrix('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); + auto matrix('c', {3, 3, 3}, +{1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, +6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); auto exp('c', {2, 3}, { 4.f, 4.2f, 4.3f, 7.f, 7.2f, 7.3f}); // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropCudaTests.cu b/libnd4j/tests_cpu/layers_tests/JavaInteropCudaTests.cu index aa2c13eb5eed..160f315576c0 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropCudaTests.cu +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropCudaTests.cu @@ -18,70 +18,84 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include -#include #include #include +#include +#include + +#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; class JavaInteropCudaTests : public testing::Test { -public: - + public: }; TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_1) { - auto x = NDArrayFactory::create('c', {3, 5}); - auto y = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto e = NDArrayFactory::create('c', {3, 5}); - x.assign(1.f); - e.assign(2.f); - - sd::ops::add op; - Context context(1); - - context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer()); - context.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - context.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo()); - - context.setOutputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - - PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_1"); - execCustomOp2(nullptr, op.getOpHash(), &context); - - pm.synchronize(); - - ASSERT_EQ(e, x); + auto x = NDArrayFactory::create('c', {3, 5}); + auto y = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {3, 5}); + x.assign(1.f); + e.assign(2.f); + + sd::ops::add op; + Context context(1); + + context.setCudaContext( + LaunchContext::defaultContext()->getCudaStream(), + LaunchContext::defaultContext()->getReductionPointer(), + LaunchContext::defaultContext()->getAllocationPointer()); + context.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + context.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo()); + + context.setOutputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + + PointersManager pm(LaunchContext::defaultContext(), + "test_DeclarableOp_execution_1"); + execCustomOp2(nullptr, op.getOpHash(), &context); + + pm.synchronize(); + + ASSERT_EQ(e, x); } TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_2) { - NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); - NDArray y('c', {2, 2}, sd::DataType::FLOAT32); - NDArray z('c', {3, 2, 2}, sd::DataType::BOOL); - NDArray e('c', {3, 2, 2}, sd::DataType::BOOL); + NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2}, sd::DataType::FLOAT32); + NDArray z('c', {3, 2, 2}, sd::DataType::BOOL); + NDArray e('c', {3, 2, 2}, sd::DataType::BOOL); - x.assign(1.f); - y.assign(2.f); - e.assign(false); + x.assign(1.f); + y.assign(2.f); + e.assign(false); - sd::ops::equals op; - Context context(1); + sd::ops::equals op; + Context context(1); - context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer()); - context.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - context.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo()); + context.setCudaContext( + LaunchContext::defaultContext()->getCudaStream(), + LaunchContext::defaultContext()->getReductionPointer(), + LaunchContext::defaultContext()->getAllocationPointer()); + context.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + context.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo()); - context.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + context.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); - PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_2"); - execCustomOp2(nullptr, op.getOpHash(), &context); + PointersManager pm(LaunchContext::defaultContext(), + "test_DeclarableOp_execution_2"); + execCustomOp2(nullptr, op.getOpHash(), &context); - pm.synchronize(); + pm.synchronize(); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } - diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index a9774aa8ff74..83bf35f98560 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -18,450 +18,524 @@ // @author raver119@gmail.com // -#include #include +#include +#include +#include #include #include -#include -#include -#include "testlayers.h" + #include +#include "testlayers.h" + using namespace sd; using namespace sd::ops; class JavaInteropTests : public testing::Test { -public: - + public: }; - TEST_F(JavaInteropTests, TestShapeExposure1) { - auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); - auto weights = NDArrayFactory::create('c', {2, 2, 2, 3}); - auto exp = NDArrayFactory::create('c', {1, 3, 5, 4}); + auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); + auto weights = NDArrayFactory::create('c', {2, 2, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3, 5, 4}); - sd::ops::conv2d op; + sd::ops::conv2d op; - std::vector tArgs({}); - std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); + std::vector tArgs({}); + std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); - Nd4jPointer ptrs[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) weights.shapeInfo()}; + Nd4jPointer ptrs[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)weights.shapeInfo()}; - auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size()); + auto shapeList = + calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), + tArgs.size(), iArgs.data(), iArgs.size()); - ASSERT_EQ(1, shapeList->size()); + ASSERT_EQ(1, shapeList->size()); - ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); - ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); - ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]); - ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]); - ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]); + ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); + ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); + ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]); + ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]); + ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]); - //int *ptr = (int *) shapeList[0]; - //delete[] ptr; - //delete shapeList; + // int *ptr = (int *) shapeList[0]; + // delete[] ptr; + // delete shapeList; - deleteShapeList((Nd4jPointer) shapeList); + deleteShapeList((Nd4jPointer)shapeList); } - TEST_F(JavaInteropTests, TestShapeExposure2) { - auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); - auto exp = NDArrayFactory::create('c', {4}, {1, 2, 5, 4}); + auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); + auto exp = NDArrayFactory::create('c', {4}, {1, 2, 5, 4}); - sd::ops::shape_of op; + sd::ops::shape_of op; - std::vector tArgs({}); - std::vector iArgs({}); + std::vector tArgs({}); + std::vector iArgs({}); + Nd4jPointer ptrs[] = {(Nd4jPointer)input.shapeInfo()}; - Nd4jPointer ptrs[] = {(Nd4jPointer) input.shapeInfo()}; + auto shapeList = + calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), + tArgs.size(), iArgs.data(), iArgs.size()); - auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size()); + ASSERT_EQ(1, shapeList->size()); - ASSERT_EQ(1, shapeList->size()); + ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); + ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); - ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); - ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); - - deleteShapeList((Nd4jPointer) shapeList); + deleteShapeList((Nd4jPointer)shapeList); } TEST_F(JavaInteropTests, TestShapeExposure3) { - auto x = NDArrayFactory::create('c', {5, 30}); - auto sizes = NDArrayFactory::create('c', {3}, {4, 15, 11}); + auto x = NDArrayFactory::create('c', {5, 30}); + auto sizes = NDArrayFactory::create('c', {3}, {4, 15, 11}); - std::vector list0 = {0,0, 0,4}; - std::vector list1 = {0,0, 4,19}; - std::vector list2 = {0,0, 19,30}; + std::vector list0 = {0, 0, 0, 4}; + std::vector list1 = {0, 0, 4, 19}; + std::vector list2 = {0, 0, 19, 30}; - auto sub0 = x(list0, true); - auto sub1 = x(list1, true); - auto sub2 = x(list2, true); + auto sub0 = x(list0, true); + auto sub1 = x(list1, true); + auto sub2 = x(list2, true); - sub0.assign(0.0f); - sub1.assign(1.0f); - sub2.assign(2.0f); + sub0.assign(0.0f); + sub1.assign(1.0f); + sub2.assign(2.0f); - Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer(), x.specialBuffer(), sizes.specialBuffer()}; - Nd4jPointer inputShapes[] = {(Nd4jPointer)x.shapeInfo(), (Nd4jPointer)sizes.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)sizes.specialShapeInfo()}; + Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer(), x.specialBuffer(), + sizes.specialBuffer()}; + Nd4jPointer inputShapes[] = { + (Nd4jPointer)x.shapeInfo(), (Nd4jPointer)sizes.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)sizes.specialShapeInfo()}; - sd::ops::split_v op; + sd::ops::split_v op; - Nd4jLong iArgs[] = {1}; - auto hash = op.getOpHash(); + Nd4jLong iArgs[] = {1}; + auto hash = op.getOpHash(); - auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0, nullptr, 0); + auto shapeList = + calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, + nullptr, 0, iArgs, 1, nullptr, 0, nullptr, 0); - ASSERT_EQ(3, shapeList->size()); + ASSERT_EQ(3, shapeList->size()); - ASSERT_TRUE(shape::equalsSoft(sub0.shapeInfo(), shapeList->at(0))); - ASSERT_TRUE(shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1))); - ASSERT_TRUE(shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2))); + ASSERT_TRUE(shape::equalsSoft(sub0.shapeInfo(), shapeList->at(0))); + ASSERT_TRUE(shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1))); + ASSERT_TRUE(shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2))); - deleteShapeList((Nd4jPointer) shapeList); + deleteShapeList((Nd4jPointer)shapeList); } TEST_F(JavaInteropTests, Test_Squeeze_1) { - auto x = NDArrayFactory::create('c', {1, 6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {6}); - auto e = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - - sd::ops::squeeze op; - - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; - - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; - auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); - ASSERT_EQ(Status::OK(), status); - - ASSERT_EQ(e, z); + auto x = NDArrayFactory::create('c', {1, 6}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {6}); + auto e = + NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + + sd::ops::squeeze op; + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; + auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, + ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, + nullptr, 0, nullptr, 0, nullptr, 0, false); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); } TEST_F(JavaInteropTests, Test_RDiv_1) { - auto x = NDArrayFactory::create('c', {3}, {2, 2, 2}); - auto y = NDArrayFactory::create('c', {3}, {4, 6, 8}); - auto z = NDArrayFactory::create('c', {3}); - auto e = NDArrayFactory::create('c', {3}, {2, 3, 4}); - - NDArray::prepareSpecialUse({&z}, {&x, &y}); - - sd::ops::reversedivide op; - - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), (Nd4jPointer) y.buffer(), x.specialBuffer(), y.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer) y.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; - - - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), (Nd4jPointer)z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; - auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); - - NDArray::registerSpecialUse({&z}, {&x, &y}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_EQ(e, z); + auto x = NDArrayFactory::create('c', {3}, {2, 2, 2}); + auto y = NDArrayFactory::create('c', {3}, {4, 6, 8}); + auto z = NDArrayFactory::create('c', {3}); + auto e = NDArrayFactory::create('c', {3}, {2, 3, 4}); + + NDArray::prepareSpecialUse({&z}, {&x, &y}); + + sd::ops::reversedivide op; + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), + (Nd4jPointer)y.buffer(), x.specialBuffer(), + y.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = { + (Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)z.buffer(), + (Nd4jPointer)z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; + auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, + ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, + nullptr, 0, nullptr, 0, nullptr, 0, false); + + NDArray::registerSpecialUse({&z}, {&x, &y}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); } TEST_F(JavaInteropTests, TestSconv2d_1) { - auto input = NDArrayFactory::create('c', {3, 3, 8, 8}); - auto weightsD = NDArrayFactory::create('c', {1, 3, 1, 1}); - auto weightsP = NDArrayFactory::create('c', {2, 3, 1, 1}); - auto bias = NDArrayFactory::create('c', {2}); - auto output = NDArrayFactory::create('c', {3, 2, 8, 8}); - output.assign(0.0); - - input.linspace(1); - weightsD.linspace(1); - weightsP.linspace(1); - bias.linspace(1); - weightsD.permutei({2,3,1,0}); - weightsP.permutei({2,3,1,0}); - - auto expOutput = NDArrayFactory::create('c', {3, 2, 8, 8}); - - sd::ops::sconv2d op; - - NDArray::prepareSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias}); - - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), (Nd4jPointer) weightsD.buffer(), (Nd4jPointer) weightsP.buffer(), (Nd4jPointer) bias.buffer(), (Nd4jPointer) input.specialBuffer(), (Nd4jPointer) weightsD.specialBuffer(), (Nd4jPointer) weightsP.specialBuffer(), (Nd4jPointer) bias.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) weightsD.shapeInfo(), (Nd4jPointer) weightsP.shapeInfo(), (Nd4jPointer) bias.shapeInfo(), (Nd4jPointer) input.specialShapeInfo(), (Nd4jPointer) weightsD.specialShapeInfo(), (Nd4jPointer) weightsP.specialShapeInfo(), (Nd4jPointer) bias.specialShapeInfo()}; - - - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), (Nd4jPointer) output.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer) output.specialShapeInfo()}; - - Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}; - - execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1, - nullptr, 0, exp, 9, nullptr, 0, false); - - //output.printBuffer("output"); - NDArray::registerSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias}); - - ASSERT_NEAR(1423, output.e(0), 1e-5); - //nd4j_printf("Iter %i passed...\n", e); + auto input = NDArrayFactory::create('c', {3, 3, 8, 8}); + auto weightsD = NDArrayFactory::create('c', {1, 3, 1, 1}); + auto weightsP = NDArrayFactory::create('c', {2, 3, 1, 1}); + auto bias = NDArrayFactory::create('c', {2}); + auto output = NDArrayFactory::create('c', {3, 2, 8, 8}); + output.assign(0.0); + + input.linspace(1); + weightsD.linspace(1); + weightsP.linspace(1); + bias.linspace(1); + weightsD.permutei({2, 3, 1, 0}); + weightsP.permutei({2, 3, 1, 0}); + + auto expOutput = NDArrayFactory::create('c', {3, 2, 8, 8}); + + sd::ops::sconv2d op; + + NDArray::prepareSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)input.buffer(), + (Nd4jPointer)weightsD.buffer(), + (Nd4jPointer)weightsP.buffer(), + (Nd4jPointer)bias.buffer(), + (Nd4jPointer)input.specialBuffer(), + (Nd4jPointer)weightsD.specialBuffer(), + (Nd4jPointer)weightsP.specialBuffer(), + (Nd4jPointer)bias.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)weightsD.shapeInfo(), + (Nd4jPointer)weightsP.shapeInfo(), + (Nd4jPointer)bias.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo(), + (Nd4jPointer)weightsD.specialShapeInfo(), + (Nd4jPointer)weightsP.specialShapeInfo(), + (Nd4jPointer)bias.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)output.buffer(), + (Nd4jPointer)output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)output.shapeInfo(), + (Nd4jPointer)output.specialShapeInfo()}; + + Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}; + + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, + ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, + false); + + // output.printBuffer("output"); + NDArray::registerSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias}); + + ASSERT_NEAR(1423, output.e(0), 1e-5); + // nd4j_printf("Iter %i passed...\n", e); } TEST_F(JavaInteropTests, TestSconv2d_2) { - auto input = NDArrayFactory::create('c', {3, 3, 8, 8}); - auto weightsD = NDArrayFactory::create('c', {1, 3, 1, 1}); - auto output = NDArrayFactory::create('c', {3, 3, 8, 8}); - output.assign(0.0); + auto input = NDArrayFactory::create('c', {3, 3, 8, 8}); + auto weightsD = NDArrayFactory::create('c', {1, 3, 1, 1}); + auto output = NDArrayFactory::create('c', {3, 3, 8, 8}); + output.assign(0.0); - input.linspace(1); - weightsD.linspace(1); - weightsD.permutei({2,3,1,0}); + input.linspace(1); + weightsD.linspace(1); + weightsD.permutei({2, 3, 1, 0}); - auto expOutput = NDArrayFactory::create('c', {3, 3, 8, 8}); + auto expOutput = NDArrayFactory::create('c', {3, 3, 8, 8}); - sd::ops::sconv2d op; + sd::ops::sconv2d op; - NDArray::prepareSpecialUse({&output}, {&input, &weightsD}); + NDArray::prepareSpecialUse({&output}, {&input, &weightsD}); - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), (Nd4jPointer) weightsD.buffer(), input.specialBuffer(), weightsD.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) weightsD.shapeInfo(), (Nd4jPointer)input.specialShapeInfo(), (Nd4jPointer)weightsD.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = { + (Nd4jPointer)input.buffer(), (Nd4jPointer)weightsD.buffer(), + input.specialBuffer(), weightsD.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)weightsD.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo(), + (Nd4jPointer)weightsD.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)output.buffer(), + output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)output.shapeInfo(), + (Nd4jPointer)output.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), output.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer)output.specialShapeInfo()}; + Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0}; - Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0}; + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, + ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, + false); - execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false); + NDArray::registerSpecialUse({&output}, {&input, &weightsD}); - NDArray::registerSpecialUse({&output}, {&input, &weightsD}); - - ASSERT_NEAR(1, output.e(0), 1e-5); + ASSERT_NEAR(1, output.e(0), 1e-5); } - TEST_F(JavaInteropTests, TestMaxPooling2d_1) { - auto input = NDArrayFactory::create('c', {1, 2, 4, 5}); - auto output = NDArrayFactory::create('c', {1, 2, 4, 5}); - input.linspace(1); + auto input = NDArrayFactory::create('c', {1, 2, 4, 5}); + auto output = NDArrayFactory::create('c', {1, 2, 4, 5}); + input.linspace(1); - NDArray::prepareSpecialUse({&output}, {&input}); + NDArray::prepareSpecialUse({&output}, {&input}); - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), input.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)input.buffer(), + input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), output.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer)output.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)output.buffer(), + output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)output.shapeInfo(), + (Nd4jPointer)output.specialShapeInfo()}; - std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); + std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); - sd::ops::maxpool2d op; + sd::ops::maxpool2d op; - Nd4jStatus status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false); - - NDArray::registerSpecialUse({&output}, {&input}); - ASSERT_EQ(ND4J_STATUS_OK, status); + Nd4jStatus status = execCustomOp( + nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false); + NDArray::registerSpecialUse({&output}, {&input}); + ASSERT_EQ(ND4J_STATUS_OK, status); } TEST_F(JavaInteropTests, TestCol2Im_1) { - /* - o.d.n.l.c.ConvolutionLayer - eps shape: [6, 1, 2, 2, 2, 4, 5, 160, 4, 2, 1, 40, 8, 0, -1, 99] - o.d.n.l.c.ConvolutionLayer - epsNext shape: [4, 1, 2, 4, 5, 20, 20, 5, 1, 0, 1, 99] - o.d.n.l.c.ConvolutionLayer - Strides: [1, 1] - o.d.n.l.c.ConvolutionLayer - Padding: [0, 0] - o.d.n.l.c.ConvolutionLayer - Input: [4,5] - o.d.n.l.c.ConvolutionLayer - Dilation: [1, 1] - */ - auto input = NDArrayFactory::create('c', {1, 2, 2, 2, 4, 5}); - auto output = NDArrayFactory::create('c', {1, 2, 4, 5}); - input.linspace(1); + /* + o.d.n.l.c.ConvolutionLayer - eps shape: [6, 1, 2, 2, 2, 4, 5, 160, 4, 2, + 1, 40, 8, 0, -1, 99] o.d.n.l.c.ConvolutionLayer - epsNext shape: [4, 1, 2, + 4, 5, 20, 20, 5, 1, 0, 1, 99] o.d.n.l.c.ConvolutionLayer - Strides: [1, 1] + o.d.n.l.c.ConvolutionLayer - Padding: [0, 0] + o.d.n.l.c.ConvolutionLayer - Input: [4,5] + o.d.n.l.c.ConvolutionLayer - Dilation: [1, 1] + */ + auto input = NDArrayFactory::create('c', {1, 2, 2, 2, 4, 5}); + auto output = NDArrayFactory::create('c', {1, 2, 4, 5}); + input.linspace(1); - NDArray::prepareSpecialUse({&output}, {&input}); + NDArray::prepareSpecialUse({&output}, {&input}); - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), input.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)input.buffer(), + input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), output.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer)output.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)output.buffer(), + output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)output.shapeInfo(), + (Nd4jPointer)output.specialShapeInfo()}; - sd::ops::col2im op; + sd::ops::col2im op; - Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1}; + Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1}; - auto hash = op.getOpHash(); + auto hash = op.getOpHash(); - execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false); + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false); - NDArray::registerSpecialUse({&output}, {&input}); + NDArray::registerSpecialUse({&output}, {&input}); - ASSERT_TRUE(output.meanNumber().e(0) > 0.0f); + ASSERT_TRUE(output.meanNumber().e(0) > 0.0f); } TEST_F(JavaInteropTests, TestPNorm_1) { - /* - o.d.n.l.c.s.SubsamplingLayer - input: [4, 1, 3, 4, 4, 16, 16, 4, 1, 0, 1, 99] - o.d.n.l.c.s.SubsamplingLayer - output: [4, 1, 3, 3, 3, 27, 9, 3, 1, 0, 1, 99] - o.d.n.l.c.s.SubsamplingLayer - Kernel: [2, 2] - o.d.n.l.c.s.SubsamplingLayer - Strides: [1, 1] - o.d.n.l.c.s.SubsamplingLayer - Pad: [0, 0] - o.d.n.l.c.s.SubsamplingLayer - Dilation: [1, 1] - o.d.n.l.c.s.SubsamplingLayer - Same: false - o.d.n.l.c.s.SubsamplingLayer - pnorm: 2 - */ - auto input = NDArrayFactory::create('c', {1, 3, 4, 4}); - auto output = NDArrayFactory::create('c', {1, 3, 3, 3}); - input.linspace(1); - - NDArray::prepareSpecialUse({&output}, {&input}); - - sd::ops::pnormpool2d op; - - Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0}; - - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), input.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; - - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), output.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer)output.specialShapeInfo()}; - - - execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false); - - NDArray::registerSpecialUse({&output}, {&input}); - - ASSERT_TRUE(output.meanNumber().e(0) > 0.0); + /* + o.d.n.l.c.s.SubsamplingLayer - input: [4, 1, 3, 4, 4, 16, 16, 4, 1, 0, 1, + 99] o.d.n.l.c.s.SubsamplingLayer - output: [4, 1, 3, 3, 3, 27, 9, 3, 1, 0, + 1, 99] o.d.n.l.c.s.SubsamplingLayer - Kernel: [2, 2] + o.d.n.l.c.s.SubsamplingLayer - Strides: [1, 1] + o.d.n.l.c.s.SubsamplingLayer - Pad: [0, 0] + o.d.n.l.c.s.SubsamplingLayer - Dilation: [1, 1] + o.d.n.l.c.s.SubsamplingLayer - Same: false + o.d.n.l.c.s.SubsamplingLayer - pnorm: 2 + */ + auto input = NDArrayFactory::create('c', {1, 3, 4, 4}); + auto output = NDArrayFactory::create('c', {1, 3, 3, 3}); + input.linspace(1); + + NDArray::prepareSpecialUse({&output}, {&input}); + + sd::ops::pnormpool2d op; + + Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0}; + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)input.buffer(), + input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)output.buffer(), + output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)output.shapeInfo(), + (Nd4jPointer)output.specialShapeInfo()}; + + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, + ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, + 0, false); + + NDArray::registerSpecialUse({&output}, {&input}); + + ASSERT_TRUE(output.meanNumber().e(0) > 0.0); } - TEST_F(JavaInteropTests, TestInplace_1) { - auto input = NDArrayFactory::create('c', {10, 10}); - //auto exp('c', {10, 10}); - input.linspace(1); - - NDArray::prepareSpecialUse({}, {&input}); + auto input = NDArrayFactory::create('c', {10, 10}); + // auto exp('c', {10, 10}); + input.linspace(1); - sd::ops::clipbyvalue op; + NDArray::prepareSpecialUse({}, {&input}); - double extras[] = {-1.0f, 1.0f}; + sd::ops::clipbyvalue op; - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), input.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + double extras[] = {-1.0f, 1.0f}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)input.buffer(), + input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; - Nd4jStatus result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true); + Nd4jStatus result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, + ptrsInShapes, 1, nullptr, nullptr, 0, extras, + 2, nullptr, 0, nullptr, 0, true); - NDArray::registerSpecialUse({}, {&input}); + NDArray::registerSpecialUse({}, {&input}); - ASSERT_EQ(ND4J_STATUS_OK, result); + ASSERT_EQ(ND4J_STATUS_OK, result); - ASSERT_NEAR(1.0, input.meanNumber().e(0), 1e-5); + ASSERT_NEAR(1.0, input.meanNumber().e(0), 1e-5); } TEST_F(JavaInteropTests, Test_Synonyms_1) { - auto op = OpRegistrator::getInstance()->getOperation("RDiv"); - auto opRef = OpRegistrator::getInstance()->getOperation("reversedivide"); - std::string nameExp("reversedivide"); + auto op = OpRegistrator::getInstance()->getOperation("RDiv"); + auto opRef = OpRegistrator::getInstance()->getOperation("reversedivide"); + std::string nameExp("reversedivide"); - ASSERT_TRUE(op != nullptr); - ASSERT_TRUE(opRef != nullptr); + ASSERT_TRUE(op != nullptr); + ASSERT_TRUE(opRef != nullptr); - std::string name = op->getOpName(); - std::string nameRef = opRef->getOpName(); + std::string name = op->getOpName(); + std::string nameRef = opRef->getOpName(); - ASSERT_EQ(nameExp, nameRef); - ASSERT_EQ(nameRef, name); + ASSERT_EQ(nameExp, nameRef); + ASSERT_EQ(nameRef, name); } TEST_F(JavaInteropTests, Test_Synonyms_2) { - auto op = OpRegistrator::getInstance()->getOperation("RDiv"); - auto opRef = OpRegistrator::getInstance()->getOperation("reversedivide"); - std::string nameExp("reversedivide"); + auto op = OpRegistrator::getInstance()->getOperation("RDiv"); + auto opRef = OpRegistrator::getInstance()->getOperation("reversedivide"); + std::string nameExp("reversedivide"); - ASSERT_TRUE(op != nullptr); - ASSERT_TRUE(opRef != nullptr); + ASSERT_TRUE(op != nullptr); + ASSERT_TRUE(opRef != nullptr); - std::string name = op->getOpName(); - std::string nameRef = opRef->getOpName(); + std::string name = op->getOpName(); + std::string nameRef = opRef->getOpName(); - ASSERT_EQ(nameExp, nameRef); - ASSERT_EQ(nameRef, name); + ASSERT_EQ(nameExp, nameRef); + ASSERT_EQ(nameRef, name); } TEST_F(JavaInteropTests, Test_Synonyms_3) { - auto op = OpRegistrator::getInstance()->getOperation("RDiv"); - auto opRef = OpRegistrator::getInstance()->getOperation("reversedivide"); - std::string nameExp("reversedivide"); + auto op = OpRegistrator::getInstance()->getOperation("RDiv"); + auto opRef = OpRegistrator::getInstance()->getOperation("reversedivide"); + std::string nameExp("reversedivide"); - ASSERT_TRUE(op != nullptr); - ASSERT_TRUE(opRef != nullptr); + ASSERT_TRUE(op != nullptr); + ASSERT_TRUE(opRef != nullptr); - std::string name = op->getOpName(); - std::string nameRef = opRef->getOpName(); + std::string name = op->getOpName(); + std::string nameRef = opRef->getOpName(); - ASSERT_EQ(nameExp, nameRef); - ASSERT_EQ(nameRef, name); + ASSERT_EQ(nameExp, nameRef); + ASSERT_EQ(nameRef, name); } TEST_F(JavaInteropTests, Test_FastPath_Validation_1) { - auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - auto z = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - - Context ctx(1); - ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - sd::ops::softmax op; - auto status = op.execute(&ctx); - ASSERT_NE(Status::OK(), status); + auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto z = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + sd::ops::softmax op; + auto status = op.execute(&ctx); + ASSERT_NE(Status::OK(), status); } TEST_F(JavaInteropTests, Test_FastPath_Validation_2) { - auto x = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto z = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - - Context ctx(1); - ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - sd::ops::softmax op; - auto status = op.execute(&ctx); - ASSERT_NE(Status::OK(), status); + auto x = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto z = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + sd::ops::softmax op; + auto status = op.execute(&ctx); + ASSERT_NE(Status::OK(), status); } TEST_F(JavaInteropTests, Test_FastPath_Validation_3) { - auto x = NDArrayFactory::create('c', {3, 5}, { 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); - - auto min = NDArrayFactory::create({ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); - auto max = NDArrayFactory::create({ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); - - auto z = NDArrayFactory::create('c', {3, 5}); - - Context ctx(1); - ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - ctx.setInputArray(1, min.buffer(), min.shapeInfo(), min.specialBuffer(), min.specialShapeInfo()); - ctx.setInputArray(2, max.buffer(), max.shapeInfo(), max.specialBuffer(), max.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - sd::ops::fake_quant_with_min_max_vars_per_channel op; - ASSERT_ANY_THROW(op.execute(&ctx)); + auto x = NDArrayFactory::create( + 'c', {3, 5}, + {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + + auto min = NDArrayFactory::create( + {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + auto max = NDArrayFactory::create( + {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + auto z = NDArrayFactory::create('c', {3, 5}); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + ctx.setInputArray(1, min.buffer(), min.shapeInfo(), min.specialBuffer(), + min.specialShapeInfo()); + ctx.setInputArray(2, max.buffer(), max.shapeInfo(), max.specialBuffer(), + max.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + ASSERT_ANY_THROW(op.execute(&ctx)); } TEST_F(JavaInteropTests, Test_empty_cast_1) { - auto x = NDArrayFactory::create('c', {1, 0, 2}); - auto z = NDArrayFactory::create('c', {1, 0, 2}); - auto e = NDArrayFactory::create('c', {1, 0, 2}); - - Nd4jLong iArgs[] = {10}; - - Context ctx(1); - ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - ctx.setIArguments(iArgs, 1); - - sd::ops::cast op; - auto result = op.execute(&ctx); - ASSERT_EQ(Status::OK(), result); - ASSERT_EQ(e, z); + auto x = NDArrayFactory::create('c', {1, 0, 2}); + auto z = NDArrayFactory::create('c', {1, 0, 2}); + auto e = NDArrayFactory::create('c', {1, 0, 2}); + + Nd4jLong iArgs[] = {10}; + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + ctx.setIArguments(iArgs, 1); + + sd::ops::cast op; + auto result = op.execute(&ctx); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, z); } /* @@ -483,12 +557,16 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) { Nd4jLong exp[] = {3,3, 1,1, 0,0, 1,1, 1, 0, 1}; Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), x.specialShapeInfo()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), +x.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), z.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), +z.specialBuffer()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), +z.specialShapeInfo()}; - auto result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false); + auto result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, +ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, +0, false); NDArray::registerSpecialUse({&z}, {&x}); @@ -549,8 +627,8 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) { auto _z = z.e(e); auto eq = sd::math::nd4j_eq(_m, _z, 1e-5); if (!eq) { - nd4j_printf("Difference at element e [%i]: <%f> vs <%f>\n", e, _m, _z); - cnt++; + nd4j_printf("Difference at element e [%i]: <%f> vs <%f>\n", e, _m, +_z); cnt++; } } @@ -559,7 +637,8 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) { TEST_F(JavaInteropTests, Test_GraphReuse_1) { - uint8_t* data = sd::graph::readFlatBuffers("./resources/reduce_dim_false.fb"); + uint8_t* data = +sd::graph::readFlatBuffers("./resources/reduce_dim_false.fb"); registerGraph(nullptr, 119, (Nd4jPointer) data); @@ -581,8 +660,9 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) { auto exp1 = NDArrayFactory::create('c', {3}, {6, 6, 6}); auto exp2 = NDArrayFactory::create('c', {3}, {9, 9, 9}); - // we load graph from file, because we're not in java here, and dont have buffer ready - uint8_t* data = sd::graph::readFlatBuffers("./resources/reduce_dim_false.fb"); + // we load graph from file, because we're not in java here, and dont have +buffer ready uint8_t* data = +sd::graph::readFlatBuffers("./resources/reduce_dim_false.fb"); // we ensure that there's no such a graph stored earlier ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119)); @@ -605,10 +685,9 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) { Nd4jPointer inputs_0[] = {(Nd4jPointer) input_0.buffer()}; Nd4jPointer shapes_0[] = {(Nd4jPointer) input_0.shapeInfo()}; - // now we're executing stored graph and providing replacement for input variable - auto res_0 = executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1); - ASSERT_EQ(ND4J_STATUS_OK, res_0->status()); - ASSERT_EQ(1, res_0->size()); + // now we're executing stored graph and providing replacement for input +variable auto res_0 = executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, +1); ASSERT_EQ(ND4J_STATUS_OK, res_0->status()); ASSERT_EQ(1, res_0->size()); auto z0 = res_0->at(0)->getNDArray(); ASSERT_TRUE(exp0.isSameShape(z0)); @@ -658,197 +737,247 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) { */ TEST_F(JavaInteropTests, Test_Greater_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 1, 2}); - auto y = NDArrayFactory::create('c', {2, 2}, {1, 2, 0, 0}); -// auto o = NDArrayFactory::create('c', {2, 2}, {3, 3, 3, 3}); - auto o = NDArrayFactory::create('c', {2, 2}, {true, true, true, true}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 1, 2}); + auto y = NDArrayFactory::create('c', {2, 2}, {1, 2, 0, 0}); + // auto o = NDArrayFactory::create('c', {2, 2}, {3, 3, 3, 3}); + auto o = NDArrayFactory::create('c', {2, 2}, {true, true, true, true}); - auto exp = NDArrayFactory::create('c', {2, 2}, {false, false, true, true}); + auto exp = + NDArrayFactory::create('c', {2, 2}, {false, false, true, true}); - NDArray::prepareSpecialUse({&o}, {&x, &y}); + NDArray::prepareSpecialUse({&o}, {&x, &y}); - sd::ops::greater op; + sd::ops::greater op; - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), (Nd4jPointer) y.buffer(), x.specialBuffer(), y.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer) y.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), + (Nd4jPointer)y.buffer(), x.specialBuffer(), + y.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = { + (Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.buffer(), o.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.shapeInfo(), (Nd4jPointer)o.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)o.buffer(), o.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)o.shapeInfo(), + (Nd4jPointer)o.specialShapeInfo()}; - execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, + ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, + nullptr, 0, false); - NDArray::registerSpecialUse({&o}, {&x, &y}); - ASSERT_TRUE(exp.equalsTo(&o)); + NDArray::registerSpecialUse({&o}, {&x, &y}); + ASSERT_TRUE(exp.equalsTo(&o)); } - TEST_F(JavaInteropTests, Test_Greater_2) { - auto x = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 1.f, 2.f}); - auto y = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 0.f, 0.f}); - auto o = NDArrayFactory::create('c', {2, 2}, {true, true, true, true}); + auto x = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 1.f, 2.f}); + auto y = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 0.f, 0.f}); + auto o = NDArrayFactory::create('c', {2, 2}, {true, true, true, true}); - auto exp = NDArrayFactory::create('c', {2, 2}, {false, false, true, true}); + auto exp = + NDArrayFactory::create('c', {2, 2}, {false, false, true, true}); - sd::ops::greater op; + sd::ops::greater op; - NDArray::prepareSpecialUse({&o}, {&x, &y}); + NDArray::prepareSpecialUse({&o}, {&x, &y}); - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), (Nd4jPointer) y.buffer(), x.specialBuffer(), y.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer) y.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), + (Nd4jPointer)y.buffer(), x.specialBuffer(), + y.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = { + (Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.buffer(), o.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.shapeInfo(), (Nd4jPointer)o.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)o.buffer(), o.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)o.shapeInfo(), + (Nd4jPointer)o.specialShapeInfo()}; - execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, + ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, + nullptr, 0, false); - NDArray::registerSpecialUse({&o}, {&x, &y}); + NDArray::registerSpecialUse({&o}, {&x, &y}); - ASSERT_TRUE(exp.equalsTo(&o)); + ASSERT_TRUE(exp.equalsTo(&o)); } TEST_F(JavaInteropTests, Test_Boolean_Op_1) { + sd::ops::is_non_decreasing op; - sd::ops::is_non_decreasing op; - - auto x = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); - auto o = NDArrayFactory::create(false); - auto exp = NDArrayFactory::create(1); + auto x = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto o = NDArrayFactory::create(false); + auto exp = NDArrayFactory::create(1); - NDArray::prepareSpecialUse({&o}, {&x}); + NDArray::prepareSpecialUse({&o}, {&x}); - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.buffer(), o.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.shapeInfo(), (Nd4jPointer)o.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)o.buffer(), o.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)o.shapeInfo(), + (Nd4jPointer)o.specialShapeInfo()}; - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); - NDArray::registerSpecialUse({&o}, {&x}); - ASSERT_EQ(Status::OK(), status); + NDArray::registerSpecialUse({&o}, {&x}); + ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(exp.equalsTo(&o)); + ASSERT_TRUE(exp.equalsTo(&o)); } - TEST_F(JavaInteropTests, Test_Inplace_Outputs_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto exp = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {2, 3}); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto exp = NDArrayFactory::create('c', {2, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {2, 3}); - sd::ops::test_output_reshape op; + sd::ops::test_output_reshape op; - NDArray::prepareSpecialUse({&z}, {&x}); + NDArray::prepareSpecialUse({&z}, {&x}); - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); - NDArray::registerSpecialUse({&z}, {&x}); - ASSERT_EQ(Status::OK(), status); + NDArray::registerSpecialUse({&z}, {&x}); + ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(JavaInteropTests, Test_Inplace_Outputs_2) { - auto x = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto y = NDArrayFactory::create(2.0f); - auto z = NDArrayFactory::create('f', {2, 3}); - auto e = NDArrayFactory::create('c', {2, 3}, {3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - - - sd::ops::add op; - - NDArray::prepareSpecialUse({&z}, {&x, &y}); - - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), (Nd4jPointer) y.buffer(), x.specialBuffer(), y.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer) y.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; - - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; - - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); - - NDArray::prepareSpecialUse({&z}, {&x, &y}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); - ASSERT_FALSE(e.ordering() == z.ordering()); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto y = NDArrayFactory::create(2.0f); + auto z = NDArrayFactory::create('f', {2, 3}); + auto e = NDArrayFactory::create('c', {2, 3}, + {3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + + sd::ops::add op; + + NDArray::prepareSpecialUse({&z}, {&x, &y}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), + (Nd4jPointer)y.buffer(), x.specialBuffer(), + y.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = { + (Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; + + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + + NDArray::prepareSpecialUse({&z}, {&x, &y}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); + ASSERT_FALSE(e.ordering() == z.ordering()); } TEST_F(JavaInteropTests, Test_Inplace_Outputs_3) { - auto input = NDArrayFactory::create('c', {2, 3, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); - auto indices = NDArrayFactory::create('c', {1, 6}, {0,1, 2,2, 1,2}); - auto output = NDArrayFactory::create('f', {2, 1, 6, 4}); - auto e = NDArrayFactory::create('c', {2, 1, 6, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16, 17,18,19,20, 21,22,23,24, 21,22,23,24, 17,18,19,20, 21,22,23,24}); - - sd::ops::gather op; - - NDArray::prepareSpecialUse({&output}, {&input, &indices}); - - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), (Nd4jPointer) indices.buffer(), input.specialBuffer(), indices.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) indices.shapeInfo(), (Nd4jPointer)input.specialShapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; - - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), output.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer)output.specialShapeInfo()}; - - Nd4jLong iArgs[] = {1}; - - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false); - - NDArray::registerSpecialUse({&output}, {&input, &indices}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(e.isSameShape(output)); - ASSERT_TRUE(e.equalsTo(output)); - ASSERT_FALSE(e.ordering() == output.ordering()); + auto input = NDArrayFactory::create( + 'c', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + auto indices = + NDArrayFactory::create('c', {1, 6}, {0, 1, 2, 2, 1, 2}); + auto output = NDArrayFactory::create('f', {2, 1, 6, 4}); + auto e = NDArrayFactory::create( + 'c', {2, 1, 6, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 9, 10, 11, 12, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 21, 22, 23, 24, 17, 18, 19, 20, 21, 22, 23, 24}); + + sd::ops::gather op; + + NDArray::prepareSpecialUse({&output}, {&input, &indices}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)input.buffer(), + (Nd4jPointer)indices.buffer(), + input.specialBuffer(), indices.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)indices.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)output.buffer(), + output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)output.shapeInfo(), + (Nd4jPointer)output.specialShapeInfo()}; + + Nd4jLong iArgs[] = {1}; + + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false); + + NDArray::registerSpecialUse({&output}, {&input, &indices}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(e.isSameShape(output)); + ASSERT_TRUE(e.equalsTo(output)); + ASSERT_FALSE(e.ordering() == output.ordering()); } TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) { - auto x = NDArrayFactory::create('c', {3, 4, 5}); - auto y = NDArrayFactory::create('c', {3, 4, 5}); - auto z = NDArrayFactory::create('c', {5}); - - auto dims = NDArrayFactory::create('c', {2}, {0, 1}); - dims.syncToHost(); - - sd::LaunchContext* context = sd::LaunchContext::defaultContext(); - - Nd4jPointer* extraPointers = nullptr; - #ifdef __CUDABLAS__ - extraPointers = new Nd4jPointer[6] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer()}; - #endif - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), {0,1}); - auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), {0,1}); - - NDArray::prepareSpecialUse({&z}, {&x, &y, &dims}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer dimBuf(dims.dataBuffer()); - - execReduce3Tad(extraPointers, 2, &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &dimBuf, dims.shapeInfo(), dims.specialShapeInfo(), packX.platformShapeInfo(), - packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); - - NDArray::registerSpecialUse({&z}, {&x, &y, &dims}); - - delete []extraPointers; + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto y = NDArrayFactory::create('c', {3, 4, 5}); + auto z = NDArrayFactory::create('c', {5}); + + auto dims = NDArrayFactory::create('c', {2}, {0, 1}); + dims.syncToHost(); + + sd::LaunchContext *context = sd::LaunchContext::defaultContext(); + + Nd4jPointer *extraPointers = nullptr; +#ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[6]{ + nullptr, context->getCudaStream(), context->getScalarPointer(), + nullptr, context->getCudaSpecialStream(), context->getReductionPointer()}; +#endif + + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x.shapeInfo(), {0, 1}); + auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions( + y.shapeInfo(), {0, 1}); + + NDArray::prepareSpecialUse({&z}, {&x, &y, &dims}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dims.dataBuffer()); + + execReduce3Tad(extraPointers, 2, &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(), &zBuf, + z.shapeInfo(), z.specialShapeInfo(), &dimBuf, dims.shapeInfo(), + dims.specialShapeInfo(), packX.platformShapeInfo(), + packX.platformOffsets(), packY.platformShapeInfo(), + packY.platformOffsets()); + + NDArray::registerSpecialUse({&z}, {&x, &y, &dims}); + + delete[] extraPointers; } /* @@ -868,487 +997,1128 @@ TEST_F(JavaInteropTests, Test_SimpleIf_Output) { */ TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_double) { - - auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111, 2.20166993, 2.91434479, 5.43639755, -2.10573769, 4.08528662, 5.86908436, -4.46203756, 2.21057916, 5.35849190, 0.01394637, 4.40566349, 7.07982206, -0.09633455, 2.42429352, 3.97301817, -1.89553940, 1.99690318, 6.33141708, 0.55401880, 1.70707977, 5.55204201, -0.03513752, 1.60011971, 2.62700319, -2.74582434, 3.06697464, 1.06277943, -1.16075921, -0.78095782, 9.72352791, -1.22686064, 1.99644792, 7.35571337, 1.40607321, 0.11390255, 9.53334427, 2.28303599, -1.66728830, 6.16678810, -0.04532295, -1.97708666, 9.74906158, 1.46223176, -1.46734393, 4.30761862, -1.23790228, 1.24823606, 6.13938427, -3.83689475, -1.19625473, 7.91535568, 6.05868721, -3.22946382, 8.81633949, -0.19967777, 0.66053957, 2.30919123, 0.74543846, -0.39347672, 11.11058044, 0.53720862, 1.52645731, 5.70012379, -1.15213466, 1.16451406, 7.00526333, 1.57362783, -2.44384766, 5.54213285, -1.98828590, -0.70483637, 7.88281822, -3.59875536, 0.80745387, 13.41578484, -1.55507684, -0.65855008, 9.32583523, -0.14544789, 0.73436141, 3.61176538, -1.71268058, -2.58490300, 9.09280205, -3.27405524, -2.04569697, 4.44761324, -0.62955856, -2.61917663, 8.04890442, 0.54579324, 0.85929775, 9.82259560, -1.93825579, 0.77703512, 4.67090321, -4.79267597, -2.38906908, 9.31265545, 0.96026313, -1.14109385, 11.54231834, -0.01417295, -0.39500344, 8.49191666, 0.55300158, 2.79490185, 6.92466164, 1.72254205, 2.82222271, 8.83112717, 2.95033407, 2.18054962, 6.73509789, -2.22272944, 0.51127720, -1.04563558, 2.15747333, -2.30959272, 9.55441570, 1.50396204, 1.77370787, 7.38146257, -1.79076433, 3.20961165, 7.18864202, 2.91217351, 0.43018937, 7.11078024, -1.17386127, -0.16817921, 6.12327290, -2.82205725, 3.30696845, 13.51291752, -1.30856836, -2.38332748, 11.09487438, -1.47190213, -0.53050828, 4.38285351, -5.07309771, 1.50714362, 5.72274446, -2.85825086, -0.89673209, 3.73791552, -0.67708802, -4.13149452, -0.00671843, -0.26566532, 0.32961160, 7.14501762, -1.41608179, -4.96590328, 12.26205540, -0.65158135, -0.88641000, 6.95777559, -0.79058206, -0.10260171, 7.87169170, 1.35921454, 1.11759663, 5.46187401, -2.57214499, 2.48484039, 4.04043484, -2.07137156, -1.42709637, 9.25487137, -0.12605135, -2.66949964, 2.89412403, 0.74451172, -2.96250391, 3.99258423, 0.27084303, 0.32213116, 5.42332172, -0.44414216, 1.70881832, 6.69346905, 0.53058422, -4.73146200, 4.22051668, 2.24834967, 0.66996074, 4.30173683, 0.11849818, -4.07520294, 8.27318478, -2.54398274, -2.86705542, 10.11775303, -0.99382895, 0.65881538, 7.93556786, -1.27934420, -1.69343162, 9.68042564, -1.02609646, -1.18189347, 5.75370646, -1.67888868, -4.48871994, 4.79537392, -0.79212248, -0.19855022, 6.15060997, -0.01081491, 3.64454579, 10.82562447, 1.58859253, -2.65847278, 8.60093212, -1.59196103, 0.07635692, 11.76175690, -1.17453325, 0.10122013, 6.86458445, -2.18891335, -2.74004745, 8.07066154, 0.71818852, -2.03035975, 6.31053686, 0.51509416, 1.39789927, 9.43515587, 2.04256630, 0.13985133, 4.65010691, 2.40911126, -0.36255789, -3.06867862, -0.45225358, -1.56778407, 6.05917358, -1.09891272, 1.77184200, 6.46248102, 0.96042323, -0.24346280, 4.63436460, -4.69907761, 1.25187206, 11.46173859, -2.21917558, 1.28007793, 6.92173195, 2.11268163, -3.47389889, 5.08722782, -3.03950930, -4.17154264, 11.30568314, 0.80361372, 2.53214502, 7.18707085, -4.49114513, 2.85449266, 10.14906883, -0.31974933, -0.84472644, -0.52459574, 0.12921631, -1.81390119, 2.76170087, 1.03982210, 2.91744232, -0.29048753, 5.87453508, -1.53684759, 1.85800636, -0.91404629, 1.28954852, 5.11354685, -2.47475505, -1.33179152, 2.58552408, 1.37316465, -3.32339454, 1.54122913, 3.24953628, -0.29758382, 2.82391763, -1.51142192, -1.22699404, 6.75745535, 0.65452754, -3.29385471, 2.06008053, 2.53172946, -4.23532820, -1.53909743, -0.07010663, -1.42173731, 7.29031610, -0.18448229, 4.59496164, 6.73027277, 0.73441899, 0.14426160, 4.14915276, -2.97010231, 6.05851364, 4.95218086, -2.39145470, 2.40494704, 2.10288811, 0.53503096, 1.44511235, 6.66344261, -3.05803776, 7.21418667, 3.30303526, -0.24163735, 3.47409391, 3.64520788, 2.15189481, -3.11243272, 3.62310791, 0.37379482, 0.40865007, -0.83132005, -4.78246069, 2.07030797, 6.51765442, 3.16178989, 5.06180477, 3.78434467, -0.96689719, 0.35965276, 5.89967585, 1.40294051, 1.11952639, 10.59778214, 0.26739889, -1.61297631, 6.24801159, -0.93914318, -0.57812452, 9.92604542, -0.73025000, -3.38530874, 2.45646000, -2.47949195, 0.51638460, 10.65636063, 1.97816694, -3.00407791, 2.66914415, -0.81951088, -0.23316640, 2.40737987, -2.70007610, 1.51531935, 4.08860207, -0.27552786, -1.31721711, 7.11568260, -3.33498216, -4.02545023, 7.22675610, -0.81690705, -2.52689576, 1.04016697, -0.79291463, -0.34875512, 10.00498390, -4.24167728, 1.46162593, 11.82569408, -1.70359993, -0.30161047, 16.44085884, -0.82253462, -0.09435523, 6.13080597, -0.20259480, 0.68308711, 6.15663004, -6.61776876, 0.33295766, 2.55449438, -0.17819691, -1.14892209, 5.56776142, 1.99279118, 1.33035934, 4.45823956, 3.34916544, -2.59905386, 6.16164446, -2.03881931, -2.45273542, 12.46793365, -2.22743297, 2.83738565, 8.48628139, -1.39347959, -1.30867767, 11.08041477, -4.00363779, 2.09183025, 11.30395889, -2.20504737, 1.37426853, 8.98735619, 1.04676604, -0.72757077, 8.28050232, -6.70741081, -0.65798020, 5.68592072, -0.60760021, 0.35854483, 6.26852131, 1.94100165, 1.32112014, 0.80987954, -1.74617672, -0.25434083, 7.16045523, 1.58884013, -2.64847064, 13.14820385, 1.21393633, -2.47258949, 9.41650105, -0.79384226, 2.48954105, 10.95629311, 0.47723705, 4.02126694, 8.02593136, -2.20726371, -1.18794477, 1.50836647, 0.93118095, -1.73513174, 8.85493565, -2.99670315, -0.79055870, 2.39473820, 2.05046916, -2.38055134, 11.82299423, 0.15609655, 0.68744308, 5.66401434, -0.69281673, 2.09855556, 7.74626589, -0.34283102, 1.00542057, 9.95838642, 0.80161905, 2.33455157, 9.80057335, -0.93561798, 2.56991577, 8.29711342, 0.94213426, 0.44209945, 11.70259857, 0.92710167, 2.60957146, 0.24971688, -0.86529571, 3.78628922, 6.80884457, -0.68178189, 2.21103406, 3.18895817, 0.60283208, -2.92716241, 6.72060776, -1.06625068, 2.56543374, 9.97404480, 3.58080721, -0.94936347, 10.16736984, -1.38464379, 1.18191063, 6.66179037, -3.56115270, 0.32329530, 10.90870762, 2.20638227, 0.19653285, 7.34650040, -3.63859272, -1.03027737, 5.98829985, -3.66606474, -3.89746714, 8.63469028, 1.22569811, 1.63240814, 3.74385309, 0.58243257, -0.56981975, 3.69260955, 1.00979900, -1.44030499, 8.57058144, -1.10648811, 1.20474911, 5.43133020, -2.14822555, -0.07928789, 11.25825310, 0.19645604, -5.49546146, 10.41917038, -0.68178523, -2.99639869, 6.50054455, 0.46488351, -5.42328453, 9.09500027, -2.82107449, 0.05601966, 15.34610748, -0.06820253, 3.86699796, 10.73316956, -3.04795432, -0.14702171, 5.64813185, 1.44028485, -2.47596145, 0.07280898, -3.03187990, -1.35183525, 9.35835648, 2.72966957, 1.88199532, 10.36187744, -0.22834805, -3.26738238, 6.92025137, -2.34061313, 4.77379704, 5.28559113, -2.96323752, -1.76186585, 5.94436455, 0.38647744, -5.73869514, 6.76849556, 1.40892124, -1.19068217, 5.37919092, -6.65328646, 3.62782669, 12.34744644, 2.44762444, -4.19242620, 6.14906216, 0.08121119, 0.61355996, 2.69666457, -1.88962626, -0.55314136, 1.84937525, 1.56048691, 1.17460012, 3.75674725, 1.06198275, -5.74625874, 5.41645575, -1.28946674, -1.51689398, 4.32400894, -0.05222082, -4.83948946, 1.80747867, 1.63144708, -2.73887825, 1.63975775, -2.02163982, -0.16210437, 2.93518686, 1.14427686, -2.83246303, 4.79283667, 2.69697428, -3.12678456, -1.19225168, -2.37022972, -3.09429741, 1.94225383, -1.13747168, -2.55048585, 5.40242243, 1.12777328, 3.43713188, 3.62658787, -2.16878843, 0.30164462, 2.97407579, -0.07275413, -1.31149673, 4.70066261, -2.01323795, 4.85255766, 4.59128904, 1.68084168, 1.60336494, 6.58138466, -1.04759812, 2.69906545, 3.55769277, -0.74327278, 2.65819693, 5.39528131, 2.11248922, -1.06446671, 5.24546766, -2.43146014, 4.58907509, 0.06521678, -2.24503994, 2.45722699, 6.94863081, 0.35258654, 2.83396196, 9.92525196, -1.12225175, -0.34365177, 7.19116688, -4.39813757, 0.46517885, 13.22028065, -2.57483673, -6.37226963, 7.58046293, -2.74600363, 0.42231262, 8.04881668, 0.17289802, -0.53447008, 16.55157471, -5.63614368, 0.39288223, 3.37079263, 1.26484549, -0.12820500, 8.46440125, -4.39304399, 2.97676420, 0.65650189, 0.83158541, -1.11556435, 6.32885838, -0.36087769, 2.80724382, 9.90292645, 1.15936041, 0.20947981, 6.91249275, -2.67404819, 2.93782163, 6.65656614, -2.30828357, 2.98214006, 6.80611229, -4.93821478, -7.66555262, 7.59763002, -0.54159302, 3.87403512, 12.42607784, 2.59284401, -0.23375344, 8.95293331, -0.71807784, 0.61873478, 8.66713524, 1.24289191, -2.37835455, 2.08071637, -0.88315344, -3.41891551, 6.85245323, 1.73007369, 1.02169311, 7.69170332, -2.85411978, 2.69790673, 8.12906551, -1.19351399, -2.26442742, 12.26104450, -0.75579089, -1.73274946, 10.68729019, 2.20655656, -0.90522075, 12.42165184, -1.67929137, 2.44851565, 9.31565762, -0.06645700, 1.52762020, 6.18427515, -1.68882596, 3.70261097, 3.02252960, -3.44125366, -1.31575799, 2.84617424, -0.96849400, -4.52356243, 9.95027161, 0.19966406, -0.78874779, 8.18595028, -4.08300209, 1.75126517, 0.96418417, -4.04913044, -0.95200396, 12.03637886, -0.03041124, 0.41642749, 8.88267422, -3.24985337, -2.24919462, 7.32566118, 0.16964148, -2.74123430, 7.05264473, -3.30191112, 0.17163286, 4.81851053, -1.64463484, -0.85933101, 7.29276276, 2.34066939, -2.14860010, 3.46148157, -0.01782012, 1.51504040, 4.79304934, 1.85281146, -1.70663762, 6.93470192, -4.15440845, -1.25983095, 10.52491760, 0.42930329, -1.85146868, 11.70042324, -0.41704914, 3.83796859, 9.21148491, -2.79719448, 0.79470479, 6.26926661, -5.85230207, 3.95105338, 7.84790897, -1.38680744, -1.78099084, 11.95235348, -2.99841452, -1.34507811, 6.15714645, -1.07552516, -2.81228638, 1.66234732, -4.55166149, -1.92601109, 8.64634514, -0.48158705, 3.31595659, 7.67371941, 2.56964207, 0.12107098, 4.56467867, -0.93541539, 1.39432955, 11.99714088, 1.05353570, -2.13099813, 3.67617917, 3.45895386, 1.37365830, 8.74344158, -4.17585802, 1.43908918, 6.28764772, 3.97346330, -0.69144285, 9.07983303, -0.41635889, -0.14965028, 8.85469818, 1.11306190, 2.59440994, 5.38982344, -1.07948279, 1.37252975, 10.26984596, -0.09318046, 2.73104119, 12.45902252, -1.55446684, -2.76124811, 12.19395065, -0.51846564, 1.02764034, 11.42673588, -0.95940983, -0.04781032, 8.78379822, -4.88957930, 0.32534006, 11.97696400, -3.35108662, 1.95104563, 4.46915388, -2.32061648, 3.45230985, 8.29983711, 2.81034684, -2.35529327, 6.07801294, -0.98105043, -0.05359888, 2.52291036, -0.01986909, -2.35321999, 10.51954269, 2.11145401, 3.53506470, 7.29093266, 0.03721160, -1.13496494, 7.43886709, -5.84201956, 2.50796294, 12.14647675, 2.77490377, -2.18896222, 6.05641937, 5.32617044, 1.04221284, 10.79106712, -2.95749092, -2.75414610, 11.30037117, -3.40654182, -2.24673963, 7.49126101, 0.70811015, -6.18003702, 13.83951187, -1.01204085, 1.36298490, -1.04451632, 2.42435336, -0.02346706, -0.85528886, 1.04731262, 0.22192979, 4.15708160, 0.34933877, 0.04814529, 2.24107265, 0.49676740, -1.47752666, 0.45040059, -0.70471478, -1.19759345, 0.21711677, 0.88461423, -2.76830935, 5.52066898, 1.97664857, -1.75381601, 3.45877838, 1.52617192, -1.61350942, 0.85337949, 1.97610760, -3.40310287, 3.40319014, -3.38691044, -0.71319139, 1.65463758, -0.60680127, -1.80700517, 8.02592373, 2.59627104, 2.65895891, 5.93043184, -4.48425817, 3.92670918, 4.19496679, -2.28286791, 6.41634607, 5.72330523, 1.16269672, -0.28753027, 2.46342492, 0.36693189, 0.26712441, 6.37652683, -2.50139046, 2.43923736, 5.56310415, 0.98065847, 1.04267502, 4.16403675, -0.04966142, 4.40897894, 3.72905660, -3.46129870, 3.59962773, 1.34830284, -1.76661730, 0.47943926, 5.29946661, -1.12711561, 1.26970029, 15.17655945, -1.50971997, 5.81345224, 8.48562050, -4.36049604, 2.48144460, 8.23780441, -3.46030426, -0.84656560, 5.94946814, 1.12747943, -2.65683913, 8.69085693, 1.31309867, -2.79958344, 8.76840591, -1.56444156, 1.62710834, 2.41177034, -0.72804940, 5.70619011, 4.67169666, -0.86167198, -1.83803177, 2.96346045, 2.82692933, -2.81557131, 7.11113358, -1.90071094, 2.54244423, 11.19284058, -0.06298946, -1.71517313, 12.98388577, 0.84510714, 3.00816894, 2.57200313, 0.03899818, -1.49330592, 9.60099125, -3.59513044, -1.30045319, 7.09241819, -0.65233821, -2.33627677, 8.81366920, 0.84154201, 1.03312039, 9.85289097, 0.19351870, 1.78496623, 7.34631205, -2.16530800, -0.65016162, 2.46842360, 0.24016285, -1.24308395, 4.78175163, -0.97682536, 2.20942235, 6.68382788, 3.76786447, -1.44454038, 6.26453733, -3.23575711, -2.30137897, 9.53092670, -5.55222607, 3.25999236, 9.37559509, 1.86339056, -0.23551451, 10.23400211, 3.93031883, -0.52629089, 7.85724449, -2.91549587, 4.46612740, 5.66530371, -2.70820427, 4.81359577, 10.31247330, 1.92230141, 2.53931546, 0.74986327, 1.70303428, 0.48063779, 5.31099129, -0.78976244, 3.75864220, 4.23051405, 2.34042454, -7.98193836, 9.83987141, -1.46722627, 3.54497814, 10.36455154, -4.51249075, 0.77715248, 7.78694630, -4.59989023, -2.49585629, 9.90296268, 1.38535416, 1.17441154, 10.10452843, -0.98628229, 0.60194463, 9.12639141, -3.90754628, 2.88526392, 7.24123430, -0.15283313, -0.75728363, -1.15116858, -2.53791571, 0.77229571, 6.44114161, 0.02646767, 4.95463037, 7.21066380, 1.79384065, 0.73250306, 8.04447937, 0.32576546, -0.79447043, 10.12717724, 2.33392906, 1.30716443, 12.36073112, -0.36694977, -1.20438910, 7.03105593, 0.59557682, 0.69267452, 10.18113136, 2.49944925, -0.42229167, 8.83143330, -1.18805945, -2.87509322, 4.53596449, 4.09732771, -3.39088297, -1.02536607, 0.82119560, -3.47302604, 9.29991817, 0.21001509, 4.97036457, 9.50018406, 1.04420102, 1.96560478, 10.74769592, -6.22709799, 3.11690164, 5.06759691, -1.23724771, -3.05831861, 8.12925529, -1.93435478, -1.10151744, 9.32263088, -0.04249470, -5.98547363, 10.49398136, 0.26400441, -0.78915191, 13.28219604, 2.99276900, 0.74853164, 2.49364305, -3.43529654, 4.05278301, 2.13498688, -2.35444307, -0.79900265, 4.66968822, -0.31095147, 3.60674143, 12.37222099, -0.07855003, -3.30292702, 12.15215874, 0.60886210, 2.87075138, 7.75271845, 0.38044083, 3.34402204, 6.40583277, -0.87888050, 0.67438459, 6.91080809, 1.98332930, -0.08303714, 8.08630371, -0.16772588, -2.74058914, 7.17253590, -2.69122696, 1.48173678, 8.99470139, -1.43302310, -0.88651133, 2.66944790, -0.29186964, 2.00838661, 5.09587479, -0.76676071, -2.88322186, 8.31110573, -0.14550979, -1.37726915, 10.28355122, -1.60575438, -0.04118848, 9.97510815, 0.14440438, -3.24632120, 9.00034523, 4.14319563, -1.31023729, 7.16950464, -0.70428526, 2.01559544, 7.26155043, 2.40816474, 2.09847403, 7.31264496, -0.75401551, 2.13392544, 7.03648758, 1.04036045, -1.15636516, 1.09634531, -0.06340861, -0.58107805, -0.65623116, 1.18972754, -0.80717683, 1.40118241, -0.61932516, -3.60596156, 1.59904599, -2.23774099, -1.13721037, 3.89620137, -0.09115922, -7.51356888, 2.36975193, -1.42520905, -2.34173775, 3.33830214, -2.74016523, -3.04115510, 6.00119495, -1.36084354, -2.45065260, 4.56992292, -3.02825928,-3.74182844,5.11069250,-0.91531068,-2.31385994,1.83399653,3.39370203,-3.60886002}); - auto z = NDArrayFactory::create('c', {4, 4, 4, 3}); - auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260, 0.06878620, 2.27749538, 7.29276514, -0.14074677, 0.65480286, 5.70313978, -0.06546132, 0.35443667, 3.70382833, -0.84020567, 0.63826996, 8.60301399, -0.38236514, 1.55177069, 7.37542057, -0.99374938, -0.29971302, 8.84352493, -0.67121059, 0.43132120, 4.78175592, -1.25070143, -1.91523600, 6.03855371, -0.00292124, -1.11214364, 7.90158176, -0.57949901, -0.96735370, 7.81192017, -0.53255427, -0.48009714, 3.16953635, 0.08353355, -1.54299748, 3.74821687, 1.69396687, 0.72724354, 5.42915201, -1.13686812, -0.71793109, 5.78376389, -0.72239977, -0.60055625, 2.53636408, 0.56777251, -2.07892323, 6.08064651, 0.68620735, 2.54017019, 5.65828180, -0.68255502, 1.47283304, 6.10842514, -0.39655915, 0.28380761, 1.96707797, -1.98206317, 0.94027776, 4.71811438, 0.32104525, -0.92409706, 8.34588146, -1.05581069, -0.55217457, 9.58440876, -0.96549922, 0.45820439, 5.65453672, -2.50953507, -0.71441835, 8.03059578, -0.21281289, 0.92125505, 9.26900673, -0.35963219, -0.70039093, 8.59924412, -1.22358346, 0.81318003, 3.85920119, -0.01305223, -1.09234154, 6.33158875, 1.28094780, -1.48926139, 4.94969177, -0.77126902, -1.97033751, 5.64381838, -0.16285487, -1.31277227, 2.39893222, -1.32902908, -1.39609122, 6.47572327, -0.45267010, 1.55727172, 6.70965624, -1.68735468, -0.05672536, 7.25092363, -0.64613032, 0.67050058, 3.60789680, -2.05948973, 2.22687531, 8.15202713, -0.70148355, 1.28314006, 8.14842319, -1.88807654, -1.04808438, 8.45500565, -0.76425624, 0.94542569, 4.56179953, -0.28786001, -2.04502511, 8.46278095, -0.31019822, 0.07339200, 9.34214592, -0.61948007, 0.52481830, 8.32515621, -1.52418160, 0.49678251, 5.11082315, -1.09908783, -0.52969611, 5.27806664, 0.88632923, 0.66754371, 4.75839233, 0.48928693, -0.68036932, 6.56925392, -0.02949905, -2.99189186, 4.46320581, -0.64534980, -0.29516968, 8.60809517, -1.13120568, 3.41720533, 5.84243155, -1.24109328, 0.89566326, 5.99578333, -0.42496428, 2.07076764, 3.17812920, -0.81566459, -0.14363396, 6.55184317, 0.39633346, -0.43852386, 8.70214558, -2.24613595, 0.30708700, 8.73882294, -0.53545928, 1.54409575, 4.49452257, -0.16509305, 0.19028664, 8.24897003, 0.44750381, 2.15448594, 8.97640514, -0.77728152, 0.57272542, 9.03467560, 0.47173575, -1.10807717, 3.30056310, -0.43268481, -0.41470885, 3.53798294, -0.08546703, -2.16840744, 6.18733406, -0.17871059, -2.59837723, 5.94218683, -1.02990067, -0.49760687, 3.76938033, 0.86383581, -1.91504073}); - - sd::ops::avgpool2d op; - - NDArray::prepareSpecialUse({&z}, {&input}); - - Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(input.buffer()), input.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; - - Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast(z.buffer()), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; - - Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1}; - - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); - - NDArray::registerSpecialUse({&z}, {&input}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto input = NDArrayFactory::create( + 'c', {4, 10, 10, 3}, + {9.37125111, 2.20166993, 2.91434479, 5.43639755, -2.10573769, + 4.08528662, 5.86908436, -4.46203756, 2.21057916, 5.35849190, + 0.01394637, 4.40566349, 7.07982206, -0.09633455, 2.42429352, + 3.97301817, -1.89553940, 1.99690318, 6.33141708, 0.55401880, + 1.70707977, 5.55204201, -0.03513752, 1.60011971, 2.62700319, + -2.74582434, 3.06697464, 1.06277943, -1.16075921, -0.78095782, + 9.72352791, -1.22686064, 1.99644792, 7.35571337, 1.40607321, + 0.11390255, 9.53334427, 2.28303599, -1.66728830, 6.16678810, + -0.04532295, -1.97708666, 9.74906158, 1.46223176, -1.46734393, + 4.30761862, -1.23790228, 1.24823606, 6.13938427, -3.83689475, + -1.19625473, 7.91535568, 6.05868721, -3.22946382, 8.81633949, + -0.19967777, 0.66053957, 2.30919123, 0.74543846, -0.39347672, + 11.11058044, 0.53720862, 1.52645731, 5.70012379, -1.15213466, + 1.16451406, 7.00526333, 1.57362783, -2.44384766, 5.54213285, + -1.98828590, -0.70483637, 7.88281822, -3.59875536, 0.80745387, + 13.41578484, -1.55507684, -0.65855008, 9.32583523, -0.14544789, + 0.73436141, 3.61176538, -1.71268058, -2.58490300, 9.09280205, + -3.27405524, -2.04569697, 4.44761324, -0.62955856, -2.61917663, + 8.04890442, 0.54579324, 0.85929775, 9.82259560, -1.93825579, + 0.77703512, 4.67090321, -4.79267597, -2.38906908, 9.31265545, + 0.96026313, -1.14109385, 11.54231834, -0.01417295, -0.39500344, + 8.49191666, 0.55300158, 2.79490185, 6.92466164, 1.72254205, + 2.82222271, 8.83112717, 2.95033407, 2.18054962, 6.73509789, + -2.22272944, 0.51127720, -1.04563558, 2.15747333, -2.30959272, + 9.55441570, 1.50396204, 1.77370787, 7.38146257, -1.79076433, + 3.20961165, 7.18864202, 2.91217351, 0.43018937, 7.11078024, + -1.17386127, -0.16817921, 6.12327290, -2.82205725, 3.30696845, + 13.51291752, -1.30856836, -2.38332748, 11.09487438, -1.47190213, + -0.53050828, 4.38285351, -5.07309771, 1.50714362, 5.72274446, + -2.85825086, -0.89673209, 3.73791552, -0.67708802, -4.13149452, + -0.00671843, -0.26566532, 0.32961160, 7.14501762, -1.41608179, + -4.96590328, 12.26205540, -0.65158135, -0.88641000, 6.95777559, + -0.79058206, -0.10260171, 7.87169170, 1.35921454, 1.11759663, + 5.46187401, -2.57214499, 2.48484039, 4.04043484, -2.07137156, + -1.42709637, 9.25487137, -0.12605135, -2.66949964, 2.89412403, + 0.74451172, -2.96250391, 3.99258423, 0.27084303, 0.32213116, + 5.42332172, -0.44414216, 1.70881832, 6.69346905, 0.53058422, + -4.73146200, 4.22051668, 2.24834967, 0.66996074, 4.30173683, + 0.11849818, -4.07520294, 8.27318478, -2.54398274, -2.86705542, + 10.11775303, -0.99382895, 0.65881538, 7.93556786, -1.27934420, + -1.69343162, 9.68042564, -1.02609646, -1.18189347, 5.75370646, + -1.67888868, -4.48871994, 4.79537392, -0.79212248, -0.19855022, + 6.15060997, -0.01081491, 3.64454579, 10.82562447, 1.58859253, + -2.65847278, 8.60093212, -1.59196103, 0.07635692, 11.76175690, + -1.17453325, 0.10122013, 6.86458445, -2.18891335, -2.74004745, + 8.07066154, 0.71818852, -2.03035975, 6.31053686, 0.51509416, + 1.39789927, 9.43515587, 2.04256630, 0.13985133, 4.65010691, + 2.40911126, -0.36255789, -3.06867862, -0.45225358, -1.56778407, + 6.05917358, -1.09891272, 1.77184200, 6.46248102, 0.96042323, + -0.24346280, 4.63436460, -4.69907761, 1.25187206, 11.46173859, + -2.21917558, 1.28007793, 6.92173195, 2.11268163, -3.47389889, + 5.08722782, -3.03950930, -4.17154264, 11.30568314, 0.80361372, + 2.53214502, 7.18707085, -4.49114513, 2.85449266, 10.14906883, + -0.31974933, -0.84472644, -0.52459574, 0.12921631, -1.81390119, + 2.76170087, 1.03982210, 2.91744232, -0.29048753, 5.87453508, + -1.53684759, 1.85800636, -0.91404629, 1.28954852, 5.11354685, + -2.47475505, -1.33179152, 2.58552408, 1.37316465, -3.32339454, + 1.54122913, 3.24953628, -0.29758382, 2.82391763, -1.51142192, + -1.22699404, 6.75745535, 0.65452754, -3.29385471, 2.06008053, + 2.53172946, -4.23532820, -1.53909743, -0.07010663, -1.42173731, + 7.29031610, -0.18448229, 4.59496164, 6.73027277, 0.73441899, + 0.14426160, 4.14915276, -2.97010231, 6.05851364, 4.95218086, + -2.39145470, 2.40494704, 2.10288811, 0.53503096, 1.44511235, + 6.66344261, -3.05803776, 7.21418667, 3.30303526, -0.24163735, + 3.47409391, 3.64520788, 2.15189481, -3.11243272, 3.62310791, + 0.37379482, 0.40865007, -0.83132005, -4.78246069, 2.07030797, + 6.51765442, 3.16178989, 5.06180477, 3.78434467, -0.96689719, + 0.35965276, 5.89967585, 1.40294051, 1.11952639, 10.59778214, + 0.26739889, -1.61297631, 6.24801159, -0.93914318, -0.57812452, + 9.92604542, -0.73025000, -3.38530874, 2.45646000, -2.47949195, + 0.51638460, 10.65636063, 1.97816694, -3.00407791, 2.66914415, + -0.81951088, -0.23316640, 2.40737987, -2.70007610, 1.51531935, + 4.08860207, -0.27552786, -1.31721711, 7.11568260, -3.33498216, + -4.02545023, 7.22675610, -0.81690705, -2.52689576, 1.04016697, + -0.79291463, -0.34875512, 10.00498390, -4.24167728, 1.46162593, + 11.82569408, -1.70359993, -0.30161047, 16.44085884, -0.82253462, + -0.09435523, 6.13080597, -0.20259480, 0.68308711, 6.15663004, + -6.61776876, 0.33295766, 2.55449438, -0.17819691, -1.14892209, + 5.56776142, 1.99279118, 1.33035934, 4.45823956, 3.34916544, + -2.59905386, 6.16164446, -2.03881931, -2.45273542, 12.46793365, + -2.22743297, 2.83738565, 8.48628139, -1.39347959, -1.30867767, + 11.08041477, -4.00363779, 2.09183025, 11.30395889, -2.20504737, + 1.37426853, 8.98735619, 1.04676604, -0.72757077, 8.28050232, + -6.70741081, -0.65798020, 5.68592072, -0.60760021, 0.35854483, + 6.26852131, 1.94100165, 1.32112014, 0.80987954, -1.74617672, + -0.25434083, 7.16045523, 1.58884013, -2.64847064, 13.14820385, + 1.21393633, -2.47258949, 9.41650105, -0.79384226, 2.48954105, + 10.95629311, 0.47723705, 4.02126694, 8.02593136, -2.20726371, + -1.18794477, 1.50836647, 0.93118095, -1.73513174, 8.85493565, + -2.99670315, -0.79055870, 2.39473820, 2.05046916, -2.38055134, + 11.82299423, 0.15609655, 0.68744308, 5.66401434, -0.69281673, + 2.09855556, 7.74626589, -0.34283102, 1.00542057, 9.95838642, + 0.80161905, 2.33455157, 9.80057335, -0.93561798, 2.56991577, + 8.29711342, 0.94213426, 0.44209945, 11.70259857, 0.92710167, + 2.60957146, 0.24971688, -0.86529571, 3.78628922, 6.80884457, + -0.68178189, 2.21103406, 3.18895817, 0.60283208, -2.92716241, + 6.72060776, -1.06625068, 2.56543374, 9.97404480, 3.58080721, + -0.94936347, 10.16736984, -1.38464379, 1.18191063, 6.66179037, + -3.56115270, 0.32329530, 10.90870762, 2.20638227, 0.19653285, + 7.34650040, -3.63859272, -1.03027737, 5.98829985, -3.66606474, + -3.89746714, 8.63469028, 1.22569811, 1.63240814, 3.74385309, + 0.58243257, -0.56981975, 3.69260955, 1.00979900, -1.44030499, + 8.57058144, -1.10648811, 1.20474911, 5.43133020, -2.14822555, + -0.07928789, 11.25825310, 0.19645604, -5.49546146, 10.41917038, + -0.68178523, -2.99639869, 6.50054455, 0.46488351, -5.42328453, + 9.09500027, -2.82107449, 0.05601966, 15.34610748, -0.06820253, + 3.86699796, 10.73316956, -3.04795432, -0.14702171, 5.64813185, + 1.44028485, -2.47596145, 0.07280898, -3.03187990, -1.35183525, + 9.35835648, 2.72966957, 1.88199532, 10.36187744, -0.22834805, + -3.26738238, 6.92025137, -2.34061313, 4.77379704, 5.28559113, + -2.96323752, -1.76186585, 5.94436455, 0.38647744, -5.73869514, + 6.76849556, 1.40892124, -1.19068217, 5.37919092, -6.65328646, + 3.62782669, 12.34744644, 2.44762444, -4.19242620, 6.14906216, + 0.08121119, 0.61355996, 2.69666457, -1.88962626, -0.55314136, + 1.84937525, 1.56048691, 1.17460012, 3.75674725, 1.06198275, + -5.74625874, 5.41645575, -1.28946674, -1.51689398, 4.32400894, + -0.05222082, -4.83948946, 1.80747867, 1.63144708, -2.73887825, + 1.63975775, -2.02163982, -0.16210437, 2.93518686, 1.14427686, + -2.83246303, 4.79283667, 2.69697428, -3.12678456, -1.19225168, + -2.37022972, -3.09429741, 1.94225383, -1.13747168, -2.55048585, + 5.40242243, 1.12777328, 3.43713188, 3.62658787, -2.16878843, + 0.30164462, 2.97407579, -0.07275413, -1.31149673, 4.70066261, + -2.01323795, 4.85255766, 4.59128904, 1.68084168, 1.60336494, + 6.58138466, -1.04759812, 2.69906545, 3.55769277, -0.74327278, + 2.65819693, 5.39528131, 2.11248922, -1.06446671, 5.24546766, + -2.43146014, 4.58907509, 0.06521678, -2.24503994, 2.45722699, + 6.94863081, 0.35258654, 2.83396196, 9.92525196, -1.12225175, + -0.34365177, 7.19116688, -4.39813757, 0.46517885, 13.22028065, + -2.57483673, -6.37226963, 7.58046293, -2.74600363, 0.42231262, + 8.04881668, 0.17289802, -0.53447008, 16.55157471, -5.63614368, + 0.39288223, 3.37079263, 1.26484549, -0.12820500, 8.46440125, + -4.39304399, 2.97676420, 0.65650189, 0.83158541, -1.11556435, + 6.32885838, -0.36087769, 2.80724382, 9.90292645, 1.15936041, + 0.20947981, 6.91249275, -2.67404819, 2.93782163, 6.65656614, + -2.30828357, 2.98214006, 6.80611229, -4.93821478, -7.66555262, + 7.59763002, -0.54159302, 3.87403512, 12.42607784, 2.59284401, + -0.23375344, 8.95293331, -0.71807784, 0.61873478, 8.66713524, + 1.24289191, -2.37835455, 2.08071637, -0.88315344, -3.41891551, + 6.85245323, 1.73007369, 1.02169311, 7.69170332, -2.85411978, + 2.69790673, 8.12906551, -1.19351399, -2.26442742, 12.26104450, + -0.75579089, -1.73274946, 10.68729019, 2.20655656, -0.90522075, + 12.42165184, -1.67929137, 2.44851565, 9.31565762, -0.06645700, + 1.52762020, 6.18427515, -1.68882596, 3.70261097, 3.02252960, + -3.44125366, -1.31575799, 2.84617424, -0.96849400, -4.52356243, + 9.95027161, 0.19966406, -0.78874779, 8.18595028, -4.08300209, + 1.75126517, 0.96418417, -4.04913044, -0.95200396, 12.03637886, + -0.03041124, 0.41642749, 8.88267422, -3.24985337, -2.24919462, + 7.32566118, 0.16964148, -2.74123430, 7.05264473, -3.30191112, + 0.17163286, 4.81851053, -1.64463484, -0.85933101, 7.29276276, + 2.34066939, -2.14860010, 3.46148157, -0.01782012, 1.51504040, + 4.79304934, 1.85281146, -1.70663762, 6.93470192, -4.15440845, + -1.25983095, 10.52491760, 0.42930329, -1.85146868, 11.70042324, + -0.41704914, 3.83796859, 9.21148491, -2.79719448, 0.79470479, + 6.26926661, -5.85230207, 3.95105338, 7.84790897, -1.38680744, + -1.78099084, 11.95235348, -2.99841452, -1.34507811, 6.15714645, + -1.07552516, -2.81228638, 1.66234732, -4.55166149, -1.92601109, + 8.64634514, -0.48158705, 3.31595659, 7.67371941, 2.56964207, + 0.12107098, 4.56467867, -0.93541539, 1.39432955, 11.99714088, + 1.05353570, -2.13099813, 3.67617917, 3.45895386, 1.37365830, + 8.74344158, -4.17585802, 1.43908918, 6.28764772, 3.97346330, + -0.69144285, 9.07983303, -0.41635889, -0.14965028, 8.85469818, + 1.11306190, 2.59440994, 5.38982344, -1.07948279, 1.37252975, + 10.26984596, -0.09318046, 2.73104119, 12.45902252, -1.55446684, + -2.76124811, 12.19395065, -0.51846564, 1.02764034, 11.42673588, + -0.95940983, -0.04781032, 8.78379822, -4.88957930, 0.32534006, + 11.97696400, -3.35108662, 1.95104563, 4.46915388, -2.32061648, + 3.45230985, 8.29983711, 2.81034684, -2.35529327, 6.07801294, + -0.98105043, -0.05359888, 2.52291036, -0.01986909, -2.35321999, + 10.51954269, 2.11145401, 3.53506470, 7.29093266, 0.03721160, + -1.13496494, 7.43886709, -5.84201956, 2.50796294, 12.14647675, + 2.77490377, -2.18896222, 6.05641937, 5.32617044, 1.04221284, + 10.79106712, -2.95749092, -2.75414610, 11.30037117, -3.40654182, + -2.24673963, 7.49126101, 0.70811015, -6.18003702, 13.83951187, + -1.01204085, 1.36298490, -1.04451632, 2.42435336, -0.02346706, + -0.85528886, 1.04731262, 0.22192979, 4.15708160, 0.34933877, + 0.04814529, 2.24107265, 0.49676740, -1.47752666, 0.45040059, + -0.70471478, -1.19759345, 0.21711677, 0.88461423, -2.76830935, + 5.52066898, 1.97664857, -1.75381601, 3.45877838, 1.52617192, + -1.61350942, 0.85337949, 1.97610760, -3.40310287, 3.40319014, + -3.38691044, -0.71319139, 1.65463758, -0.60680127, -1.80700517, + 8.02592373, 2.59627104, 2.65895891, 5.93043184, -4.48425817, + 3.92670918, 4.19496679, -2.28286791, 6.41634607, 5.72330523, + 1.16269672, -0.28753027, 2.46342492, 0.36693189, 0.26712441, + 6.37652683, -2.50139046, 2.43923736, 5.56310415, 0.98065847, + 1.04267502, 4.16403675, -0.04966142, 4.40897894, 3.72905660, + -3.46129870, 3.59962773, 1.34830284, -1.76661730, 0.47943926, + 5.29946661, -1.12711561, 1.26970029, 15.17655945, -1.50971997, + 5.81345224, 8.48562050, -4.36049604, 2.48144460, 8.23780441, + -3.46030426, -0.84656560, 5.94946814, 1.12747943, -2.65683913, + 8.69085693, 1.31309867, -2.79958344, 8.76840591, -1.56444156, + 1.62710834, 2.41177034, -0.72804940, 5.70619011, 4.67169666, + -0.86167198, -1.83803177, 2.96346045, 2.82692933, -2.81557131, + 7.11113358, -1.90071094, 2.54244423, 11.19284058, -0.06298946, + -1.71517313, 12.98388577, 0.84510714, 3.00816894, 2.57200313, + 0.03899818, -1.49330592, 9.60099125, -3.59513044, -1.30045319, + 7.09241819, -0.65233821, -2.33627677, 8.81366920, 0.84154201, + 1.03312039, 9.85289097, 0.19351870, 1.78496623, 7.34631205, + -2.16530800, -0.65016162, 2.46842360, 0.24016285, -1.24308395, + 4.78175163, -0.97682536, 2.20942235, 6.68382788, 3.76786447, + -1.44454038, 6.26453733, -3.23575711, -2.30137897, 9.53092670, + -5.55222607, 3.25999236, 9.37559509, 1.86339056, -0.23551451, + 10.23400211, 3.93031883, -0.52629089, 7.85724449, -2.91549587, + 4.46612740, 5.66530371, -2.70820427, 4.81359577, 10.31247330, + 1.92230141, 2.53931546, 0.74986327, 1.70303428, 0.48063779, + 5.31099129, -0.78976244, 3.75864220, 4.23051405, 2.34042454, + -7.98193836, 9.83987141, -1.46722627, 3.54497814, 10.36455154, + -4.51249075, 0.77715248, 7.78694630, -4.59989023, -2.49585629, + 9.90296268, 1.38535416, 1.17441154, 10.10452843, -0.98628229, + 0.60194463, 9.12639141, -3.90754628, 2.88526392, 7.24123430, + -0.15283313, -0.75728363, -1.15116858, -2.53791571, 0.77229571, + 6.44114161, 0.02646767, 4.95463037, 7.21066380, 1.79384065, + 0.73250306, 8.04447937, 0.32576546, -0.79447043, 10.12717724, + 2.33392906, 1.30716443, 12.36073112, -0.36694977, -1.20438910, + 7.03105593, 0.59557682, 0.69267452, 10.18113136, 2.49944925, + -0.42229167, 8.83143330, -1.18805945, -2.87509322, 4.53596449, + 4.09732771, -3.39088297, -1.02536607, 0.82119560, -3.47302604, + 9.29991817, 0.21001509, 4.97036457, 9.50018406, 1.04420102, + 1.96560478, 10.74769592, -6.22709799, 3.11690164, 5.06759691, + -1.23724771, -3.05831861, 8.12925529, -1.93435478, -1.10151744, + 9.32263088, -0.04249470, -5.98547363, 10.49398136, 0.26400441, + -0.78915191, 13.28219604, 2.99276900, 0.74853164, 2.49364305, + -3.43529654, 4.05278301, 2.13498688, -2.35444307, -0.79900265, + 4.66968822, -0.31095147, 3.60674143, 12.37222099, -0.07855003, + -3.30292702, 12.15215874, 0.60886210, 2.87075138, 7.75271845, + 0.38044083, 3.34402204, 6.40583277, -0.87888050, 0.67438459, + 6.91080809, 1.98332930, -0.08303714, 8.08630371, -0.16772588, + -2.74058914, 7.17253590, -2.69122696, 1.48173678, 8.99470139, + -1.43302310, -0.88651133, 2.66944790, -0.29186964, 2.00838661, + 5.09587479, -0.76676071, -2.88322186, 8.31110573, -0.14550979, + -1.37726915, 10.28355122, -1.60575438, -0.04118848, 9.97510815, + 0.14440438, -3.24632120, 9.00034523, 4.14319563, -1.31023729, + 7.16950464, -0.70428526, 2.01559544, 7.26155043, 2.40816474, + 2.09847403, 7.31264496, -0.75401551, 2.13392544, 7.03648758, + 1.04036045, -1.15636516, 1.09634531, -0.06340861, -0.58107805, + -0.65623116, 1.18972754, -0.80717683, 1.40118241, -0.61932516, + -3.60596156, 1.59904599, -2.23774099, -1.13721037, 3.89620137, + -0.09115922, -7.51356888, 2.36975193, -1.42520905, -2.34173775, + 3.33830214, -2.74016523, -3.04115510, 6.00119495, -1.36084354, + -2.45065260, 4.56992292, -3.02825928, -3.74182844, 5.11069250, + -0.91531068, -2.31385994, 1.83399653, 3.39370203, -3.60886002}); + auto z = NDArrayFactory::create('c', {4, 4, 4, 3}); + auto exp = NDArrayFactory::create( + 'c', {4, 4, 4, 3}, + {7.97172260, 0.06878620, 2.27749538, 7.29276514, -0.14074677, + 0.65480286, 5.70313978, -0.06546132, 0.35443667, 3.70382833, + -0.84020567, 0.63826996, 8.60301399, -0.38236514, 1.55177069, + 7.37542057, -0.99374938, -0.29971302, 8.84352493, -0.67121059, + 0.43132120, 4.78175592, -1.25070143, -1.91523600, 6.03855371, + -0.00292124, -1.11214364, 7.90158176, -0.57949901, -0.96735370, + 7.81192017, -0.53255427, -0.48009714, 3.16953635, 0.08353355, + -1.54299748, 3.74821687, 1.69396687, 0.72724354, 5.42915201, + -1.13686812, -0.71793109, 5.78376389, -0.72239977, -0.60055625, + 2.53636408, 0.56777251, -2.07892323, 6.08064651, 0.68620735, + 2.54017019, 5.65828180, -0.68255502, 1.47283304, 6.10842514, + -0.39655915, 0.28380761, 1.96707797, -1.98206317, 0.94027776, + 4.71811438, 0.32104525, -0.92409706, 8.34588146, -1.05581069, + -0.55217457, 9.58440876, -0.96549922, 0.45820439, 5.65453672, + -2.50953507, -0.71441835, 8.03059578, -0.21281289, 0.92125505, + 9.26900673, -0.35963219, -0.70039093, 8.59924412, -1.22358346, + 0.81318003, 3.85920119, -0.01305223, -1.09234154, 6.33158875, + 1.28094780, -1.48926139, 4.94969177, -0.77126902, -1.97033751, + 5.64381838, -0.16285487, -1.31277227, 2.39893222, -1.32902908, + -1.39609122, 6.47572327, -0.45267010, 1.55727172, 6.70965624, + -1.68735468, -0.05672536, 7.25092363, -0.64613032, 0.67050058, + 3.60789680, -2.05948973, 2.22687531, 8.15202713, -0.70148355, + 1.28314006, 8.14842319, -1.88807654, -1.04808438, 8.45500565, + -0.76425624, 0.94542569, 4.56179953, -0.28786001, -2.04502511, + 8.46278095, -0.31019822, 0.07339200, 9.34214592, -0.61948007, + 0.52481830, 8.32515621, -1.52418160, 0.49678251, 5.11082315, + -1.09908783, -0.52969611, 5.27806664, 0.88632923, 0.66754371, + 4.75839233, 0.48928693, -0.68036932, 6.56925392, -0.02949905, + -2.99189186, 4.46320581, -0.64534980, -0.29516968, 8.60809517, + -1.13120568, 3.41720533, 5.84243155, -1.24109328, 0.89566326, + 5.99578333, -0.42496428, 2.07076764, 3.17812920, -0.81566459, + -0.14363396, 6.55184317, 0.39633346, -0.43852386, 8.70214558, + -2.24613595, 0.30708700, 8.73882294, -0.53545928, 1.54409575, + 4.49452257, -0.16509305, 0.19028664, 8.24897003, 0.44750381, + 2.15448594, 8.97640514, -0.77728152, 0.57272542, 9.03467560, + 0.47173575, -1.10807717, 3.30056310, -0.43268481, -0.41470885, + 3.53798294, -0.08546703, -2.16840744, 6.18733406, -0.17871059, + -2.59837723, 5.94218683, -1.02990067, -0.49760687, 3.76938033, + 0.86383581, -1.91504073}); + + sd::ops::avgpool2d op; + + NDArray::prepareSpecialUse({&z}, {&input}); + + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(input.buffer()), + input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast(z.buffer()), + z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; + + Nd4jLong iArgs[] = {3, 3, 3, 3, 0, 0, 1, 1, 1, 0, 1}; + + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); + + NDArray::registerSpecialUse({&z}, {&input}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(JavaInteropTests, Test_MaxPool2D_float_1) { - auto input = NDArrayFactory::create('c', {1, 1, 4, 5}); - auto z = NDArrayFactory::create('c', {1, 1, 4, 5}); + auto input = NDArrayFactory::create('c', {1, 1, 4, 5}); + auto z = NDArrayFactory::create('c', {1, 1, 4, 5}); - input.linspace(1.0); + input.linspace(1.0); - NDArray::prepareSpecialUse({&z}, {&input}); + NDArray::prepareSpecialUse({&z}, {&input}); - Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(input.buffer()), input.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(input.buffer()), + input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast(z.buffer()), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast(z.buffer()), + z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; - Nd4jLong iArgs[] = {2,2, 1,1, 1,1, 2,2,1, 0,0}; + Nd4jLong iArgs[] = {2, 2, 1, 1, 1, 1, 2, 2, 1, 0, 0}; - sd::ops::maxpool2d op; + sd::ops::maxpool2d op; - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); - NDArray::registerSpecialUse({&z}, {&input}); - ASSERT_EQ(Status::OK(), status); + NDArray::registerSpecialUse({&z}, {&input}); + ASSERT_EQ(Status::OK(), status); } TEST_F(JavaInteropTests, Test_Unstack_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.linspace(1.0); - auto z0 = NDArrayFactory::create('c',{5}); - auto z1 = NDArrayFactory::create('c',{5}); - auto z2 = NDArrayFactory::create('c',{5}); - auto z3 = NDArrayFactory::create('c',{5}); - auto z4 = NDArrayFactory::create('c',{5}); - - NDArray::prepareSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x}); - - Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(x.buffer()), x.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; - - Nd4jPointer ptrsOutBuffers[] = {z0.buffer(), z1.buffer(), z2.buffer(), z3.buffer(), z4.buffer(), z0.specialBuffer(), z1.specialBuffer(), z2.specialBuffer(), z3.specialBuffer(), z4.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z0.shapeInfo(), (Nd4jPointer)z1.shapeInfo(), (Nd4jPointer)z2.shapeInfo(), - (Nd4jPointer)z3.shapeInfo(), (Nd4jPointer)z4.shapeInfo(), (Nd4jPointer)z0.specialShapeInfo(), - (Nd4jPointer)z1.specialShapeInfo(), (Nd4jPointer)z2.specialShapeInfo(), - (Nd4jPointer)z3.specialShapeInfo(), (Nd4jPointer)z4.specialShapeInfo()}; - - Nd4jLong iArgs[] = {0}; - - sd::ops::unstack op; - - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false); - - NDArray::registerSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x}); - ASSERT_EQ(Status::OK(), status); + auto x = NDArrayFactory::create('c', {5, 5}); + x.linspace(1.0); + auto z0 = NDArrayFactory::create('c', {5}); + auto z1 = NDArrayFactory::create('c', {5}); + auto z2 = NDArrayFactory::create('c', {5}); + auto z3 = NDArrayFactory::create('c', {5}); + auto z4 = NDArrayFactory::create('c', {5}); + + NDArray::prepareSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x}); + + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(x.buffer()), + x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {z0.buffer(), z1.buffer(), + z2.buffer(), z3.buffer(), + z4.buffer(), z0.specialBuffer(), + z1.specialBuffer(), z2.specialBuffer(), + z3.specialBuffer(), z4.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = { + (Nd4jPointer)z0.shapeInfo(), (Nd4jPointer)z1.shapeInfo(), + (Nd4jPointer)z2.shapeInfo(), (Nd4jPointer)z3.shapeInfo(), + (Nd4jPointer)z4.shapeInfo(), (Nd4jPointer)z0.specialShapeInfo(), + (Nd4jPointer)z1.specialShapeInfo(), (Nd4jPointer)z2.specialShapeInfo(), + (Nd4jPointer)z3.specialShapeInfo(), (Nd4jPointer)z4.specialShapeInfo()}; + + Nd4jLong iArgs[] = {0}; + + sd::ops::unstack op; + + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false); + + NDArray::registerSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x}); + ASSERT_EQ(Status::OK(), status); } TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_float) { - - auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f,2.91434479f,5.43639755f,-2.10573769f, 4.08528662f,5.86908436f,-4.46203756f,2.21057916f,5.35849190f,0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, 9.55441570f, 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, -2.85825086f, -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, -1.42709637f, 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, 10.11775303f, -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, -1.17453325f, 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, -0.24346280f, 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); - auto z = NDArrayFactory::create('c', {4, 4, 4, 3}); - auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, 0.86383581f, -1.91504073f}); - - sd::ops::avgpool2d op; - - NDArray::prepareSpecialUse({&z}, {&input}); - - Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(input.buffer()), input.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; - - Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast(z.buffer()), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; - Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1}; - - auto hash = op.getOpHash(); - auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); - - NDArray::registerSpecialUse({&z}, {&input}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto input = NDArrayFactory::create( + 'c', {4, 10, 10, 3}, + {9.37125111f, 2.20166993f, 2.91434479f, 5.43639755f, -2.10573769f, + 4.08528662f, 5.86908436f, -4.46203756f, 2.21057916f, 5.35849190f, + 0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, + 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, + 1.70707977f, 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, + -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, + 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, + 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, + -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, + 4.30761862f, -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, + -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, + -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, + 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, + 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, + -1.98828590f, -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, + 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, + 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, + -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, + 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, + 0.77703512f, 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, + 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, + 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, + 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, + -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, + 9.55441570f, 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, + 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, + -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, + 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, + -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, + -2.85825086f, -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, + -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, + -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, + -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, + 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, + -1.42709637f, 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, + 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, + 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, + -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, + 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, + 10.11775303f, -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, + -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, + -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, + 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, + -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, + -1.17453325f, 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, + 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, + 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, + 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, + 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, + -0.24346280f, 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, + -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, + 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, + 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, + -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, + 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, + -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, + -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, + 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, + -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, + 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, + 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, + 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, + -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, + 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, + 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, + 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, + 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, + 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, + 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, + 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, + 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, + -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, + 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, + -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, + -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, + 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, + -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, + -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, + 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, + -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, + -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, + 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, + 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, + -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, + 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, + -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, + 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, + 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, + -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, + -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, + 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, + 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, + 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, + 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, + 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, + -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, + 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, + -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, + -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, + 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, + -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, + 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, + 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, + -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, + -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, + 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, + 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, + 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, + 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, + -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, + -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, + 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, + 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, + 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, + 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, + -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, + -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, + 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, + -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, + -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, + 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, + 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, + -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, + 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, + 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, + -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, + 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, + -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, + -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, + 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, + 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, + -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, + 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, + 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, + -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, + 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, + -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, + 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, + 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, + 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, + -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, + 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, + 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, + -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, + 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, + 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, + -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, + 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, + 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, + 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, + 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, + -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, + -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, + 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, + -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, + -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, + 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, + 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, + 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, + 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, + -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, + 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, + 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, + -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, + -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, + 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, + 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, + -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, + 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, + -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, + 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, + 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, + -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, + -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, + -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, + 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, + -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, + 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, + -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, + -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, + 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, + 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, + 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, + 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, + 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, + -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, + 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, + 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, + -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, + 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, + 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, + -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, + 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, + -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, + 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, + 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, + 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, + -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, + 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, + -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, + -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, + 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, + 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, + 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, + 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, + -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, + -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, + 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, + 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, + -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, + 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, + 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, + 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, + 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, + -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, + 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, + 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, + 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, + -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, + 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, + -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, + -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, + 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, + -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, + 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, + 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, + -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, + -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, + 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, + -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, + 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, + 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, + 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, + 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, + -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, + -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, + -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, + 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, + -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, + -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); + auto z = NDArrayFactory::create('c', {4, 4, 4, 3}); + auto exp = NDArrayFactory::create( + 'c', {4, 4, 4, 3}, + {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, + 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, + -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, + 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, + 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, + -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, + 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, + -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, + -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, + 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, + 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, + -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, + 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, + -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, + -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, + 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, + 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, + 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, + 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, + -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, + -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, + 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, + 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, + -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, + 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, + 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, + -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, + 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, + -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, + -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, + 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, + -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, + -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, + 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, + 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, + 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, + 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, + -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, + 0.86383581f, -1.91504073f}); + + sd::ops::avgpool2d op; + + NDArray::prepareSpecialUse({&z}, {&input}); + + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(input.buffer()), + input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)input.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast(z.buffer()), + z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; + Nd4jLong iArgs[] = {3, 3, 3, 3, 0, 0, 1, 1, 1, 0, 1}; + + auto hash = op.getOpHash(); + auto status = + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, + ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); + + NDArray::registerSpecialUse({&z}, {&input}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(JavaInteropTests, Test_Mixed_Add_1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - auto arrayX = NDArrayFactory::create({1, 2, 3, 4}); - auto arrayY = NDArrayFactory::create({1, 2, 3, 4}); - auto arrayZ = NDArrayFactory::create({0, 0, 0, 0}); - auto arrayE = NDArrayFactory::create({2, 4, 6, 8}); + auto arrayX = NDArrayFactory::create({1, 2, 3, 4}); + auto arrayY = NDArrayFactory::create({1, 2, 3, 4}); + auto arrayZ = NDArrayFactory::create({0, 0, 0, 0}); + auto arrayE = NDArrayFactory::create({2, 4, 6, 8}); - NDArray::prepareSpecialUse({&arrayZ}, {&arrayX, &arrayY}); + NDArray::prepareSpecialUse({&arrayZ}, {&arrayX, &arrayY}); - OpaqueDataBuffer xBuf(arrayX.dataBuffer()); - OpaqueDataBuffer yBuf(arrayY.dataBuffer()); - OpaqueDataBuffer zBuf(arrayZ.dataBuffer()); + OpaqueDataBuffer xBuf(arrayX.dataBuffer()); + OpaqueDataBuffer yBuf(arrayY.dataBuffer()); + OpaqueDataBuffer zBuf(arrayZ.dataBuffer()); - execPairwiseTransform(nullptr, pairwise::Add, - &xBuf, arrayX.shapeInfo(), arrayX.specialShapeInfo(), - &yBuf, arrayY.shapeInfo(), arrayY.specialShapeInfo(), - &zBuf, arrayZ.shapeInfo(), arrayZ.specialShapeInfo(), - nullptr); + execPairwiseTransform(nullptr, pairwise::Add, &xBuf, arrayX.shapeInfo(), + arrayX.specialShapeInfo(), &yBuf, arrayY.shapeInfo(), + arrayY.specialShapeInfo(), &zBuf, arrayZ.shapeInfo(), + arrayZ.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arrayZ}, {&arrayX, &arrayY}); + NDArray::registerSpecialUse({&arrayZ}, {&arrayX, &arrayY}); - ASSERT_EQ(arrayE, arrayZ); + ASSERT_EQ(arrayE, arrayZ); } TEST_F(JavaInteropTests, Test_Add_1) { - auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); - auto y = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); - auto e = NDArrayFactory::create('c', {5}, {2, 2, 2, 2, 2}); + auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); + auto y = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); + auto e = NDArrayFactory::create('c', {5}, {2, 2, 2, 2, 2}); - NDArray::prepareSpecialUse({&x}, {&x, &y}); + NDArray::prepareSpecialUse({&x}, {&x, &y}); - sd::ops::add op; + sd::ops::add op; - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), y.buffer(), x.specialBuffer(), y.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)y.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), y.buffer(), + x.specialBuffer(), y.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = { + (Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo()}; - execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, + ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, + nullptr, 0, false); - NDArray::registerSpecialUse({&x}, {&x, &y}); + NDArray::registerSpecialUse({&x}, {&x, &y}); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(JavaInteropTests, zeta_test10) { + auto x = NDArrayFactory::create( + 'c', {3, 4}, + {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.01, 1.11, 1.12}); + auto q = NDArrayFactory::create( + 'c', {3, 4}, + {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); + auto z = NDArrayFactory::create('c', {3, 4}); - auto x = NDArrayFactory::create('c', {3, 4}, {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.01, 1.11, 1.12}); - auto q = NDArrayFactory::create('c', {3, 4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); - auto z = NDArrayFactory::create('c', {3, 4}); - - auto e = NDArrayFactory::create('c', {3, 4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); + auto e = NDArrayFactory::create( + 'c', {3, 4}, + {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, + 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); - sd::ops::zeta op; + sd::ops::zeta op; - NDArray::prepareSpecialUse({&z}, {&x, &q}); + NDArray::prepareSpecialUse({&z}, {&x, &q}); - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), q.buffer(), x.specialBuffer(), q.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)q.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)q.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), q.buffer(), + x.specialBuffer(), q.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = { + (Nd4jPointer)x.shapeInfo(), (Nd4jPointer)q.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)q.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; - execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, + ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, + nullptr, 0, false); - NDArray::registerSpecialUse({&z}, {&x, &q}); + NDArray::registerSpecialUse({&z}, {&x, &q}); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(JavaInteropTests, Test_IAMax_1) { - auto arrayX = NDArrayFactory::create({-0.24f, -0.26f, -0.07f, -0.01f}); - auto arrayZ = arrayX.indexReduceNumber(indexreduce::IndexAbsoluteMax, nullptr); - auto exp = NDArrayFactory::create(1); + auto arrayX = NDArrayFactory::create({-0.24f, -0.26f, -0.07f, -0.01f}); + auto arrayZ = + arrayX.indexReduceNumber(indexreduce::IndexAbsoluteMax, nullptr); + auto exp = NDArrayFactory::create(1); - ASSERT_EQ(exp, arrayZ); + ASSERT_EQ(exp, arrayZ); } TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) { - auto arrayX = NDArrayFactory::create('c', {10, 10}); - auto arrayY = NDArrayFactory::create('c', {10, 10}); - - Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(arrayX.buffer()), reinterpret_cast(arrayY.buffer()), arrayX.specialBuffer(), arrayY.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)arrayX.shapeInfo(), (Nd4jPointer)arrayY.shapeInfo(), (Nd4jPointer)arrayX.specialShapeInfo(), (Nd4jPointer)arrayY.specialShapeInfo()}; - - NDArray::prepareSpecialUse({}, {&arrayX, &arrayY}); - sd::ops::greater_equal op; - auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0); - NDArray::registerSpecialUse({}, {&arrayX, &arrayY}); - delete shapeList; + auto arrayX = NDArrayFactory::create('c', {10, 10}); + auto arrayY = NDArrayFactory::create('c', {10, 10}); + + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(arrayX.buffer()), + reinterpret_cast(arrayY.buffer()), + arrayX.specialBuffer(), arrayY.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)arrayX.shapeInfo(), + (Nd4jPointer)arrayY.shapeInfo(), + (Nd4jPointer)arrayX.specialShapeInfo(), + (Nd4jPointer)arrayY.specialShapeInfo()}; + + NDArray::prepareSpecialUse({}, {&arrayX, &arrayY}); + sd::ops::greater_equal op; + auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, + ptrsInShapes, 2, nullptr, 0, nullptr, + 0, nullptr, 0, nullptr, 0); + NDArray::registerSpecialUse({}, {&arrayX, &arrayY}); + delete shapeList; } TEST_F(JavaInteropTests, Test_L2_Loss_3) { - auto x = NDArrayFactory::create(0.7787855863571167); - auto e = NDArrayFactory::create(0.303254); - auto z = NDArrayFactory::create(0.0); + auto x = NDArrayFactory::create(0.7787855863571167); + auto e = NDArrayFactory::create(0.303254); + auto z = NDArrayFactory::create(0.0); - NDArray::prepareSpecialUse({&z}, {&x}); + NDArray::prepareSpecialUse({&z}, {&x}); - Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(x.buffer()), x.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(x.buffer()), + x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffer[] = {reinterpret_cast(z.buffer()), (Nd4jPointer)z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffer[] = {reinterpret_cast(z.buffer()), + (Nd4jPointer)z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; - sd::ops::l2_loss op; - auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); - ASSERT_EQ(Status::OK(), status); + sd::ops::l2_loss op; + auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, + ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, + nullptr, 0, nullptr, 0, nullptr, 0, false); + ASSERT_EQ(Status::OK(), status); - NDArray::registerSpecialUse({&z}, {&x}); + NDArray::registerSpecialUse({&z}, {&x}); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(JavaInteropTests, Test_Fastpath_3) { - auto array0 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto array1 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {3, 2}); + auto array0 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array1 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {3, 2}); - auto exp = NDArrayFactory::create('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); - Context ctx(1); + auto exp = NDArrayFactory::create('c', {3, 2}, + {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); + Context ctx(1); - NDArray::prepareSpecialUse({&z}, {&array0, &array1}); + NDArray::prepareSpecialUse({&z}, {&array0, &array1}); - ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo()); - ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), + array0.specialBuffer(), array0.specialShapeInfo()); + ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), + array1.specialBuffer(), array1.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); - ASSERT_EQ(2, ctx.width()); + ASSERT_EQ(2, ctx.width()); - sd::ops::add op; - execCustomOp2(nullptr, op.getOpHash(), &ctx); + sd::ops::add op; + execCustomOp2(nullptr, op.getOpHash(), &ctx); - NDArray::registerSpecialUse({&z}, {&array0, &array1}); + NDArray::registerSpecialUse({&z}, {&array0, &array1}); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(JavaInteropTests, Test_Fastpath_4) { + auto exp = NDArrayFactory::create( + 'c', {3, 5}, {1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1}); + auto z = NDArrayFactory::create('c', {3, 5}); + Nd4jLong iArgs[] = {3, 5, 2}; - auto exp = NDArrayFactory::create('c', {3, 5}, {1,1,1,0,0, 1,1,1,1,0, 1,1,1,1,1}); - auto z = NDArrayFactory::create('c', {3, 5}); - Nd4jLong iArgs[] = {3, 5, 2}; - - - NDArray::prepareSpecialUse({&z}, {}); + NDArray::prepareSpecialUse({&z}, {}); - Context ctx(1); + Context ctx(1); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - ctx.setIArguments(iArgs, 3); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + ctx.setIArguments(iArgs, 3); - sd::ops::tri op; - execCustomOp2(nullptr, op.getOpHash(), &ctx); + sd::ops::tri op; + execCustomOp2(nullptr, op.getOpHash(), &ctx); - NDArray::registerSpecialUse({&z}, {}); + NDArray::registerSpecialUse({&z}, {}); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(JavaInteropTests, Test_Fastpath_5) { - auto a = NDArrayFactory::create('c', {3, 3}); - auto b = NDArrayFactory::create('c', {3, 3}); - auto c = NDArrayFactory::create('c', {3, 3}); - a.linspace(1.0); - b.linspace(1.0); + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto c = NDArrayFactory::create('c', {3, 3}); + a.linspace(1.0); + b.linspace(1.0); - NDArray::prepareSpecialUse({&c}, {&b, &c}); + NDArray::prepareSpecialUse({&c}, {&b, &c}); - Context ctx(1); + Context ctx(1); - ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo()); - ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo()); - ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), c.specialShapeInfo()); + ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), + a.specialShapeInfo()); + ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), + b.specialShapeInfo()); + ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), + c.specialShapeInfo()); - sd::ops::matmul op; - auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx); + sd::ops::matmul op; + auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx); - NDArray::registerSpecialUse({&c}, {&b, &c}); + NDArray::registerSpecialUse({&c}, {&b, &c}); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(Status::OK(), status); } TEST_F(JavaInteropTests, Test_Fastpath_6) { - auto a = NDArrayFactory::create('c', {2, 3}); - auto b = NDArrayFactory::create('c', {3, 4}); - auto gI = NDArrayFactory::create('c', {2, 4}); + auto a = NDArrayFactory::create('c', {2, 3}); + auto b = NDArrayFactory::create('c', {3, 4}); + auto gI = NDArrayFactory::create('c', {2, 4}); - auto gA = NDArrayFactory::create('c', {2, 3}); - auto gB = NDArrayFactory::create('c', {3, 4}); - a.linspace(1.0); - b.linspace(1.0); - gI.linspace(1.0); + auto gA = NDArrayFactory::create('c', {2, 3}); + auto gB = NDArrayFactory::create('c', {3, 4}); + a.linspace(1.0); + b.linspace(1.0); + gI.linspace(1.0); - NDArray::prepareSpecialUse({&gA, &gB}, {&a, &b, &gI}); + NDArray::prepareSpecialUse({&gA, &gB}, {&a, &b, &gI}); - Context ctx(1); - Nd4jLong iArgs[] = {0L, 0L, 0L}; + Context ctx(1); + Nd4jLong iArgs[] = {0L, 0L, 0L}; - ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo()); - ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo()); - ctx.setInputArray(2, gI.buffer(), gI.shapeInfo(), gI.specialBuffer(), gI.specialShapeInfo()); + ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), + a.specialShapeInfo()); + ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), + b.specialShapeInfo()); + ctx.setInputArray(2, gI.buffer(), gI.shapeInfo(), gI.specialBuffer(), + gI.specialShapeInfo()); - ctx.setOutputArray(0, gA.buffer(), gA.shapeInfo(), gA.specialBuffer(), gA.specialShapeInfo()); - ctx.setOutputArray(1, gB.buffer(), gB.shapeInfo(), gB.specialBuffer(), gB.specialShapeInfo()); + ctx.setOutputArray(0, gA.buffer(), gA.shapeInfo(), gA.specialBuffer(), + gA.specialShapeInfo()); + ctx.setOutputArray(1, gB.buffer(), gB.shapeInfo(), gB.specialBuffer(), + gB.specialShapeInfo()); - ctx.setIArguments(iArgs, 3); + ctx.setIArguments(iArgs, 3); - sd::ops::matmul_bp op; - auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx); + sd::ops::matmul_bp op; + auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx); - NDArray::registerSpecialUse({&gA, &gB}, {&a, &b, &gI}); + NDArray::registerSpecialUse({&gA, &gB}, {&a, &b, &gI}); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(Status::OK(), status); } TEST_F(JavaInteropTests, Test_Fastpath_7) { - auto a = NDArrayFactory::create('c', {2}, {1.f, 2.f}); - auto b = NDArrayFactory::create(3.f); - auto z = NDArrayFactory::create('c', {3}); - auto e = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto a = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + auto b = NDArrayFactory::create(3.f); + auto z = NDArrayFactory::create('c', {3}); + auto e = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - NDArray::prepareSpecialUse({&z}, {&a, &b}); + NDArray::prepareSpecialUse({&z}, {&a, &b}); - Context ctx(1); - Nd4jLong iArgs[] = {0L, 0L, 0L}; + Context ctx(1); + Nd4jLong iArgs[] = {0L, 0L, 0L}; - ctx.setIArguments(iArgs, 1); + ctx.setIArguments(iArgs, 1); - sd::ops::concat op; + sd::ops::concat op; - ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo()); - ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo()); + ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), + a.specialShapeInfo()); + ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), + b.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); - auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx); + auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx); - NDArray::registerSpecialUse({&z}, {&a, &b}); - ASSERT_EQ(Status::OK(), status); + NDArray::registerSpecialUse({&z}, {&a, &b}); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(JavaInteropTests, test_bfloat16_rng) { - if (!Environment::getInstance()->isCPU()) - return; + if (!Environment::getInstance()->isCPU()) return; - auto z = NDArrayFactory::create('c', {10}); - RandomGenerator rng(119, 323841120L); - bfloat16 args[2] = {(bfloat16) 0.0f, (bfloat16) 1.0f}; - OpaqueDataBuffer zBuf(z.dataBuffer()); - execRandom(nullptr, sd::random::Ops::UniformDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), args); + auto z = NDArrayFactory::create('c', {10}); + RandomGenerator rng(119, 323841120L); + bfloat16 args[2] = {(bfloat16)0.0f, (bfloat16)1.0f}; + OpaqueDataBuffer zBuf(z.dataBuffer()); + execRandom(nullptr, sd::random::Ops::UniformDistribution, &rng, &zBuf, + z.shapeInfo(), z.specialShapeInfo(), args); - //z.printIndexedBuffer("z"); - ASSERT_TRUE(z.sumNumber().e(0) > 0); + // z.printIndexedBuffer("z"); + ASSERT_TRUE(z.sumNumber().e(0) > 0); } TEST_F(JavaInteropTests, test_ismax_view) { - auto original = NDArrayFactory::create('c', {2, 3, 40}); - auto v = original.subarray({NDIndex::all(), NDIndex::all(), NDIndex::interval(0, 40, 2)}); - v.assign(1.0); + auto original = NDArrayFactory::create('c', {2, 3, 40}); + auto v = original.subarray( + {NDIndex::all(), NDIndex::all(), NDIndex::interval(0, 40, 2)}); + v.assign(1.0); - auto e = v.like(); - auto t = e(0, {2}); - t.assign(1.0); + auto e = v.like(); + auto t = e(0, {2}); + t.assign(1.0); - auto z = v.ulike(); + auto z = v.ulike(); + Nd4jLong iArgs[] = {2L, 0L}; + Context ctx(1); + ctx.setInputArray(0, v.buffer(), v.shapeInfo(), v.specialBuffer(), + v.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); + ctx.setIArguments(iArgs, 1); - Nd4jLong iArgs[] = {2L, 0L}; - Context ctx(1); - ctx.setInputArray(0, v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - ctx.setIArguments(iArgs, 1); + sd::ops::ismax op; + op.execute(&ctx); - sd::ops::ismax op; - op.execute(&ctx); - - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(JavaInteropTests, test_size_dtype_1) { - auto x = NDArrayFactory::create('c', {3}, {1.f, 1.f, 1.f}); - auto z = NDArrayFactory::create(0.0f); - auto e = NDArrayFactory::create(3.0f); + auto x = NDArrayFactory::create('c', {3}, {1.f, 1.f, 1.f}); + auto z = NDArrayFactory::create(0.0f); + auto e = NDArrayFactory::create(3.0f); - Context ctx(1); - ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); - sd::ops::size op; - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + sd::ops::size op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(JavaInteropTests, test_expandable_array_op_1) { - auto x = NDArrayFactory::string( {2}, {"first string", "second"}); - auto d = NDArrayFactory::string(" ", sd::DataType::UTF8); + auto x = NDArrayFactory::string({2}, {"first string", "second"}); + auto d = NDArrayFactory::string(" ", sd::DataType::UTF8); - auto z0 = NDArrayFactory::create('c', {6}); - auto z1 = NDArrayFactory::string( {3}, {"", "", ""}); + auto z0 = NDArrayFactory::create('c', {6}); + auto z1 = NDArrayFactory::string({3}, {"", "", ""}); - auto exp0 = NDArrayFactory::create({0,0, 0,1, 1,0}); - auto exp1 = NDArrayFactory::string( {3}, {"first", "string", "second"}); + auto exp0 = NDArrayFactory::create({0, 0, 0, 1, 1, 0}); + auto exp1 = NDArrayFactory::string({3}, {"first", "string", "second"}); - InteropDataBuffer iz0(z0.dataBuffer()); - InteropDataBuffer iz1(z1.dataBuffer()); + InteropDataBuffer iz0(z0.dataBuffer()); + InteropDataBuffer iz1(z1.dataBuffer()); - Context ctx(1); - ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); - ctx.setInputArray(1, d.buffer(), d.shapeInfo(), d.specialBuffer(), d.specialShapeInfo()); - ctx.setOutputArray(0, &iz0, z0.shapeInfo(), z0.specialShapeInfo()); - ctx.setOutputArray(1, &iz1, z1.shapeInfo(), z1.specialShapeInfo()); + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo()); + ctx.setInputArray(1, d.buffer(), d.shapeInfo(), d.specialBuffer(), + d.specialShapeInfo()); + ctx.setOutputArray(0, &iz0, z0.shapeInfo(), z0.specialShapeInfo()); + ctx.setOutputArray(1, &iz1, z1.shapeInfo(), z1.specialShapeInfo()); - sd::ops::compat_string_split op; - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + sd::ops::compat_string_split op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); - ASSERT_EQ(exp0, z0); - ASSERT_EQ(exp1, z1); + ASSERT_EQ(exp0, z0); + ASSERT_EQ(exp1, z1); } TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) { - if (!Environment::getInstance()->isCPU()) - return; + if (!Environment::getInstance()->isCPU()) return; - auto x = NDArrayFactory::create('c', {4, 3, 4, 4}); - auto y = NDArrayFactory::create('c', {4, 3, 3, 3}); - auto z = NDArrayFactory::create('c', {4, 3, 4, 4}); + auto x = NDArrayFactory::create('c', {4, 3, 4, 4}); + auto y = NDArrayFactory::create('c', {4, 3, 3, 3}); + auto z = NDArrayFactory::create('c', {4, 3, 4, 4}); - double buffer[2048]; + double buffer[2048]; - InteropDataBuffer ix(0, DataType::DOUBLE, false); - InteropDataBuffer iy(0, DataType::DOUBLE, false); - InteropDataBuffer iz(0, DataType::DOUBLE, false); + InteropDataBuffer ix(0, DataType::DOUBLE, false); + InteropDataBuffer iy(0, DataType::DOUBLE, false); + InteropDataBuffer iz(0, DataType::DOUBLE, false); - // we're imitating workspace-managed array here - ix.setPrimary(buffer + 64, x.lengthOf()); - iy.setPrimary(buffer + 64 + x.lengthOf(), y.lengthOf()); - iz.setPrimary(buffer + 64 + x.lengthOf() + y.lengthOf(), z.lengthOf()); + // we're imitating workspace-managed array here + ix.setPrimary(buffer + 64, x.lengthOf()); + iy.setPrimary(buffer + 64 + x.lengthOf(), y.lengthOf()); + iz.setPrimary(buffer + 64 + x.lengthOf() + y.lengthOf(), z.lengthOf()); - Context ctx(1); - ctx.setInputArray(0, &ix, x.shapeInfo(), x.specialShapeInfo()); - ctx.setInputArray(1, &iy, y.shapeInfo(), y.specialShapeInfo()); - ctx.setOutputArray(0, &iz, z.shapeInfo(), z.specialShapeInfo()); + Context ctx(1); + ctx.setInputArray(0, &ix, x.shapeInfo(), x.specialShapeInfo()); + ctx.setInputArray(1, &iy, y.shapeInfo(), y.specialShapeInfo()); + ctx.setOutputArray(0, &iz, z.shapeInfo(), z.specialShapeInfo()); - ctx.setIArguments({2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0}); + ctx.setIArguments({2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0}); - sd::ops::maxpool2d_bp op; - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); + sd::ops::maxpool2d_bp op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); } TEST_F(JavaInteropTests, test_linspace_shape_1) { - if (!Environment::getInstance()->isCPU()) - return; - - sd::ops::lin_space op; - double tArgs[2] = {1.0, 10.0}; - Nd4jLong iArgs = 10L; - int dArg = (int) sd::DataType::FLOAT32; - auto result = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 2, &iArgs, 1, nullptr, 0, &dArg, 1); - - ASSERT_EQ(1, result->size()); - delete result; + if (!Environment::getInstance()->isCPU()) return; + + sd::ops::lin_space op; + double tArgs[2] = {1.0, 10.0}; + Nd4jLong iArgs = 10L; + int dArg = (int)sd::DataType::FLOAT32; + auto result = + ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, + tArgs, 2, &iArgs, 1, nullptr, 0, &dArg, 1); + + ASSERT_EQ(1, result->size()); + delete result; } /* @@ -1403,8 +2173,20 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) { } */ // TEST_F(JavaInteropTests, Test_NLP_Aggregations_1) { -// std::array syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 5.673498E-4f, -0.020073576f, 0.010949242f}; -// std::array syn1; +// std::array syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, +// -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, +// -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, +// -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, +// -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, +// 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, +// 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, +// 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, +// -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, +// -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, +// 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, +// -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, +// -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 5.673498E-4f, +// -0.020073576f, 0.010949242f}; std::array syn1; // std::array exp; // for (int e = 0; e < syn1.size(); e++) @@ -1412,8 +2194,8 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) { // for (int e = 0; e < exp.size(); e++) { // auto f = static_cast(e); -// auto tmp = sd::math::nd4j_exp((f / 100000.0 * 2.0 - 1.0) * 6.0); -// exp[e] = static_cast(tmp / (tmp + 1.0)); +// auto tmp = sd::math::nd4j_exp((f / 100000.0 * 2.0 +// - 1.0) * 6.0); exp[e] = static_cast(tmp / (tmp + 1.0)); // } // auto maxTypes = 5; @@ -1432,9 +2214,9 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) { // int indexPos = maxTypes * batchLimit; // int intArraysPos = indexPos + (maxIndexArguments * batchLimit); -// int realPos = (intArraysPos + (maxIntArrays * maxIntArraySize * batchLimit)); -// int argsPos = (realPos + ((maxRealArguments * batchLimit))) / 2; -// int shapesPos = argsPos + (maxArgs * batchLimit); +// int realPos = (intArraysPos + (maxIntArrays * maxIntArraySize * +// batchLimit)); int argsPos = (realPos + ((maxRealArguments * batchLimit))) +// / 2; int shapesPos = argsPos + (maxArgs * batchLimit); // std::vector intArray0({0, 0, 0, 0, 0}); // std::vector intArray1({1, 0, 0, 0, 0}); @@ -1442,7 +2224,8 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) { // std::vector indexingArgs0({1, 20, 5, 0, 100000, 3, 0, 0, 0}); // std::vector indexingArgs1({0, 20, 5, 0, 100000, 3, 1, 0, 0}); -// std::vector realArgs0({0.024964055335354007f, 3.0768702268737162E18f}); +// std::vector +// realArgs0({0.024964055335354007f, 3.0768702268737162E18f}); // int argSize = 6; // int shapesSize = 0; @@ -1493,6 +2276,7 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) { // ptrptr[idx+1] = reinterpret_cast(syn1.data()); // ptrptr[idx+2] = reinterpret_cast(exp.data()); - -// execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data()); +// execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, +// maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, +// maxRealArguments, pointer.data()); // } diff --git a/libnd4j/tests_cpu/layers_tests/LambdaTests.cu b/libnd4j/tests_cpu/layers_tests/LambdaTests.cu index a114f71798ad..f210fa8cf198 100644 --- a/libnd4j/tests_cpu/layers_tests/LambdaTests.cu +++ b/libnd4j/tests_cpu/layers_tests/LambdaTests.cu @@ -18,202 +18,192 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include -#include #include #include +#include + +#include "testlayers.h" + using namespace sd; class LambdaTests : public testing::Test { -public: - - LambdaTests() { - printf("\n"); - fflush(stdout); - } + public: + LambdaTests() { + printf("\n"); + fflush(stdout); + } }; template -__global__ void runLambda(double *input, double *output, Nd4jLong length, Lambda lambda) { - auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (Nd4jLong e = tid; e < length; e += gridDim.x * blockDim.x) { - output[e] = lambda(input[e]); - } +__global__ void runLambda(double *input, double *output, Nd4jLong length, + Lambda lambda) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + for (Nd4jLong e = tid; e < length; e += gridDim.x * blockDim.x) { + output[e] = lambda(input[e]); + } } -void launcher(cudaStream_t *stream, double *input, double *output, Nd4jLong length) { - //auto f = [] __host__ __device__ (double x) -> double { - // return x + 1.; - //}; - auto f = LAMBDA_D(x) { - return x+1.; - }; +void launcher(cudaStream_t *stream, double *input, double *output, + Nd4jLong length) { + // auto f = [] __host__ __device__ (double x) -> double { + // return x + 1.; + //}; + auto f = LAMBDA_D(x) { return x + 1.; }; - - runLambda<<<128, 128, 128, *stream>>>(input, output, length, f); + runLambda<<<128, 128, 128, *stream>>>(input, output, length, f); } - TEST_F(LambdaTests, test_basic_1) { - auto x = NDArrayFactory::create('c', {5}); - auto e = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + auto x = NDArrayFactory::create('c', {5}); + auto e = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + // x.applyLambda(f, nullptr); + launcher(LaunchContext::defaultContext()->getCudaStream(), + (double *)x.specialBuffer(), (double *)x.specialBuffer(), + x.lengthOf()); + auto res = + cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); + ASSERT_EQ(0, res); - - //x.applyLambda(f, nullptr); - launcher(LaunchContext::defaultContext()->getCudaStream(), (double *)x.specialBuffer(), (double *)x.specialBuffer(), x.lengthOf()); - auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); - ASSERT_EQ(0, res); - - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } void test(NDArray &x) { - auto f = LAMBDA_D(x) { - return x+1.; - }; + auto f = LAMBDA_D(x) { return x + 1.; }; - x.applyLambda(f, x); + x.applyLambda(f, x); } template void test2(NDArray &x) { - auto f = LAMBDA_T(x) { - return x+1.; - }; + auto f = LAMBDA_T(x) { return x + 1.; }; - x.applyLambda(f, x); + x.applyLambda(f, x); } void testPairwise(NDArray &x, NDArray &y) { - auto f = LAMBDA_DD(x, y) { - return x + y +1.; - }; + auto f = LAMBDA_DD(x, y) { return x + y + 1.; }; - x.applyPairwiseLambda(y, f, x); + x.applyPairwiseLambda(y, f, x); } void testTriplewise(NDArray &i, NDArray &j, NDArray &k) { - auto f = LAMBDA_DDD(i, j, k) { - return i + j + k + 2.; - }; + auto f = LAMBDA_DDD(i, j, k) { return i + j + k + 2.; }; - i.applyTriplewiseLambda(j, k, f, i); + i.applyTriplewiseLambda(j, k, f, i); } void testIndexed(NDArray &x) { - auto f = ILAMBDA_D(x) { - return _idx + 1.; - }; + auto f = ILAMBDA_D(x) { return _idx + 1.; }; - x.applyIndexedLambda(f, x); + x.applyIndexedLambda(f, x); } void testIndexedPairwise(NDArray &x, NDArray &y) { - auto f = ILAMBDA_DD(x, y) { - return _idx + x + y +1.; - }; + auto f = ILAMBDA_DD(x, y) { return _idx + x + y + 1.; }; - x.applyIndexedPairwiseLambda(y, f, x); + x.applyIndexedPairwiseLambda(y, f, x); } TEST_F(LambdaTests, test_basic_2) { - auto x = NDArrayFactory::create('c', {5}); - auto e = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + auto x = NDArrayFactory::create('c', {5}); + auto e = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); - test(x); + test(x); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(LambdaTests, test_basic_3) { - auto x = NDArrayFactory::create('c', {5}); - auto e = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto x = NDArrayFactory::create('c', {5}); + auto e = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - test(x); + test(x); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(LambdaTests, test_basic_4) { - auto x = NDArrayFactory::create('c', {5}); - auto e = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto x = NDArrayFactory::create('c', {5}); + auto e = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - test2(x); + test2(x); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(LambdaTests, test_basic_5) { - auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); - auto y = NDArrayFactory::create('c', {5}, {2., 2., 2., 2., 2.}); - auto e = NDArrayFactory::create('c', {5}, {4., 4., 4., 4., 4.}); + auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + auto y = NDArrayFactory::create('c', {5}, {2., 2., 2., 2., 2.}); + auto e = NDArrayFactory::create('c', {5}, {4., 4., 4., 4., 4.}); - testPairwise(x, y); + testPairwise(x, y); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(LambdaTests, test_basic_6) { - auto x = NDArrayFactory::create('c', {5}); - auto e = NDArrayFactory::create('c', {5}, {1., 2., 3., 4., 5.}); + auto x = NDArrayFactory::create('c', {5}); + auto e = NDArrayFactory::create('c', {5}, {1., 2., 3., 4., 5.}); - testIndexed(x); + testIndexed(x); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(LambdaTests, test_basic_7) { - auto w = NDArrayFactory::create('c', {5}, {0., 0., 0., 0., 0.}); - auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); - auto y = NDArrayFactory::create('c', {5}, {2., 2., 2., 2., 2.}); - auto e = NDArrayFactory::create('c', {5}, {5., 5., 5., 5., 5.}); + auto w = NDArrayFactory::create('c', {5}, {0., 0., 0., 0., 0.}); + auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + auto y = NDArrayFactory::create('c', {5}, {2., 2., 2., 2., 2.}); + auto e = NDArrayFactory::create('c', {5}, {5., 5., 5., 5., 5.}); - testTriplewise(w, x, y); + testTriplewise(w, x, y); - ASSERT_EQ(e, w); + ASSERT_EQ(e, w); } TEST_F(LambdaTests, test_basic_8) { - auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); - auto y = NDArrayFactory::create('c', {5}, {2., 2., 2., 2., 2.}); - auto e = NDArrayFactory::create('c', {5}, {4., 5., 6., 7., 8.}); + auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + auto y = NDArrayFactory::create('c', {5}, {2., 2., 2., 2., 2.}); + auto e = NDArrayFactory::create('c', {5}, {4., 5., 6., 7., 8.}); - testIndexedPairwise(x, y); + testIndexedPairwise(x, y); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } - template void testPairwiseMy(NDArray &x, NDArray &y, NDArray &z) { + auto f = LAMBDA_TT(x, y) { + return sd::math::nd4j_max(x, (T)0.f) - x * y + + sd::math::nd4j_log( + (T)1.f + sd::math::nd4j_exp(-sd::math::nd4j_abs(x))); + }; - auto f = LAMBDA_TT(x, y){ - return sd::math::nd4j_max(x, (T)0.f) - - x * y - + sd::math::nd4j_log((T)1.f - + sd::math::nd4j_exp(-sd::math::nd4j_abs(x))); - }; - - x.applyPairwiseLambda(y, f, z); + x.applyPairwiseLambda(y, f, z); } /////////////////////////////////////////////////////////////////// TEST_F(LambdaTests, test_basic_9) { - - NDArray labels('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); - NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray output('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray expected('c', {2,3,4}, {0.744397, 0.598139, 0.554355, 0.913015, 0.474077, 1.037488, 0.403186, 1.171101, 0.341154, 1.313262, 0.287335, 1.463282, 0.241008, 1.620417, 0.201413, 1.783901, 0.167786, 1.952978, 2.039387, 0.126928, 0.115520, 2.305083, 0.095545, 2.486836}); - - logits.linspace(0.1, 0.1); - - NDArray::prepareSpecialUse({&output}, {&logits, &labels}); - testPairwiseMy(logits, labels, output); - NDArray::registerSpecialUse({&output}, {&logits, &labels}); - - // output.printBuffer(nullptr, -1, true); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray labels('c', {2, 3, 4}, {0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, + 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0}); + NDArray logits('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray output('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray expected( + 'c', {2, 3, 4}, + {0.744397, 0.598139, 0.554355, 0.913015, 0.474077, 1.037488, + 0.403186, 1.171101, 0.341154, 1.313262, 0.287335, 1.463282, + 0.241008, 1.620417, 0.201413, 1.783901, 0.167786, 1.952978, + 2.039387, 0.126928, 0.115520, 2.305083, 0.095545, 2.486836}); + + logits.linspace(0.1, 0.1); + + NDArray::prepareSpecialUse({&output}, {&logits, &labels}); + testPairwiseMy(logits, labels, output); + NDArray::registerSpecialUse({&output}, {&logits, &labels}); + + // output.printBuffer(nullptr, -1, true); + ASSERT_TRUE(expected.equalsTo(output)); } diff --git a/libnd4j/tests_cpu/layers_tests/LaunchContextCudaTests.cu b/libnd4j/tests_cpu/layers_tests/LaunchContextCudaTests.cu index e16df80e680c..30003277de28 100644 --- a/libnd4j/tests_cpu/layers_tests/LaunchContextCudaTests.cu +++ b/libnd4j/tests_cpu/layers_tests/LaunchContextCudaTests.cu @@ -18,108 +18,111 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include +#include +#include #include +#include #include -#include +#include +#include #include -#include -#include #include -#include -#include -#include -#include +#include +#include +#include + #include -#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; class LaunchContextCudaTests : public testing::Test { - // + // }; - void acquireContext(int threadId, int &deviceId) { - deviceId = AffinityManager::currentDeviceId(); + deviceId = AffinityManager::currentDeviceId(); - nd4j_printf("Creating thread: [%i]; assigned deviceId: [%i];\n", threadId, deviceId); + nd4j_printf("Creating thread: [%i]; assigned deviceId: [%i];\n", threadId, + deviceId); - auto lc = LaunchContext::defaultContext(); - nd4j_printf("LC: [%p]\n", lc); + auto lc = LaunchContext::defaultContext(); + nd4j_printf("LC: [%p]\n", lc); - nd4j_printf("reductionPtr: [%p]; stream: [%p];\n", lc->getReductionPointer(), lc->getCudaStream()); + nd4j_printf("reductionPtr: [%p]; stream: [%p];\n", lc->getReductionPointer(), + lc->getCudaStream()); } TEST_F(LaunchContextCudaTests, basic_test_1) { - int deviceA, deviceB; - std::thread threadA(acquireContext, 0, std::ref(deviceA)); - std::thread threadB(acquireContext, 1, std::ref(deviceB)); + int deviceA, deviceB; + std::thread threadA(acquireContext, 0, std::ref(deviceA)); + std::thread threadB(acquireContext, 1, std::ref(deviceB)); - threadA.join(); - threadB.join(); - nd4j_printf("All threads joined\n",""); + threadA.join(); + threadB.join(); + nd4j_printf("All threads joined\n", ""); - if (AffinityManager::numberOfDevices() > 1) - ASSERT_NE(deviceA, deviceB); + if (AffinityManager::numberOfDevices() > 1) ASSERT_NE(deviceA, deviceB); } -void fillArray(int tid, std::vector &arrays) { - auto array = NDArrayFactory::create_('c', {3, 10}); - nd4j_printf("Array created on device [%i]\n", AffinityManager::currentDeviceId()); - array->assign(tid); - arrays[tid] = array; +void fillArray(int tid, std::vector &arrays) { + auto array = NDArrayFactory::create_('c', {3, 10}); + nd4j_printf("Array created on device [%i]\n", + AffinityManager::currentDeviceId()); + array->assign(tid); + arrays[tid] = array; } TEST_F(LaunchContextCudaTests, basic_test_2) { - std::vector arrays(2); + std::vector arrays(2); - std::thread threadA(fillArray, 0, std::ref(arrays)); - std::thread threadB(fillArray, 1, std::ref(arrays)); + std::thread threadA(fillArray, 0, std::ref(arrays)); + std::thread threadB(fillArray, 1, std::ref(arrays)); - threadA.join(); - threadB.join(); + threadA.join(); + threadB.join(); - for (int e = 0; e < 2; e++) { - auto array = arrays[e]; - ASSERT_EQ(e, array->e(0)); + for (int e = 0; e < 2; e++) { + auto array = arrays[e]; + ASSERT_EQ(e, array->e(0)); - delete array; - } + delete array; + } } void initAffinity(int tid, std::vector &aff) { - auto affinity = AffinityManager::currentDeviceId(); - aff[tid] = affinity; - nd4j_printf("Thread [%i] affined with device [%i]\n", tid, affinity); + auto affinity = AffinityManager::currentDeviceId(); + aff[tid] = affinity; + nd4j_printf("Thread [%i] affined with device [%i]\n", tid, affinity); } TEST_F(LaunchContextCudaTests, basic_test_3) { - auto totalThreads = AffinityManager::numberOfDevices() * 4; - nd4j_printf("Total threads: %i\n", totalThreads); - std::vector affinities(totalThreads); + auto totalThreads = AffinityManager::numberOfDevices() * 4; + nd4j_printf("Total threads: %i\n", totalThreads); + std::vector affinities(totalThreads); - for (int e = 0; e < totalThreads; e++) { - std::thread thread(initAffinity, e, std::ref(affinities)); + for (int e = 0; e < totalThreads; e++) { + std::thread thread(initAffinity, e, std::ref(affinities)); - thread.join(); - } + thread.join(); + } - std::vector hits(AffinityManager::numberOfDevices()); - std::fill(hits.begin(), hits.end(), 0); + std::vector hits(AffinityManager::numberOfDevices()); + std::fill(hits.begin(), hits.end(), 0); - // we need to make sure all threads were attached to "valid" devices - for (int e = 0; e < totalThreads; e++) { - auto aff = affinities[e]; - ASSERT_TRUE(aff >= 0 && aff < AffinityManager::numberOfDevices()); + // we need to make sure all threads were attached to "valid" devices + for (int e = 0; e < totalThreads; e++) { + auto aff = affinities[e]; + ASSERT_TRUE(aff >= 0 && aff < AffinityManager::numberOfDevices()); - hits[aff]++; - } + hits[aff]++; + } - // now we check if all devices got some threads - for (int e = 0; e < AffinityManager::numberOfDevices(); e++) { - ASSERT_GT(hits[e], 0); - } + // now we check if all devices got some threads + for (int e = 0; e < AffinityManager::numberOfDevices(); e++) { + ASSERT_GT(hits[e], 0); + } } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsCudaTests.cu b/libnd4j/tests_cpu/layers_tests/LegacyOpsCudaTests.cu index 53179cd6882d..97fd5a512c7c 100644 --- a/libnd4j/tests_cpu/layers_tests/LegacyOpsCudaTests.cu +++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsCudaTests.cu @@ -18,43 +18,49 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include +#include #include +#include #include -#include +#include +#include #include -#include -#include #include -#include -#include -#include -#include +#include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; -class LegacyOpsCudaTests : public testing::Test { - -}; - +class LegacyOpsCudaTests : public testing::Test {}; TEST_F(LegacyOpsCudaTests, test_sortTad_1) { - auto x = NDArrayFactory::create('c', {3, 5}, {1.f, 3.f, 0.f, 2.f, 4.f, - 6.f, 5.f, 9.f, 7.f, 8.f, - 10.f, 11.f, 14.f, 12.f, 13.f}); + auto x = + NDArrayFactory::create('c', {3, 5}, + {1.f, 3.f, 0.f, 2.f, 4.f, 6.f, 5.f, 9.f, + 7.f, 8.f, 10.f, 11.f, 14.f, 12.f, 13.f}); - auto e = NDArrayFactory::create('c', {3, 5}, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f}); + auto e = + NDArrayFactory::create('c', {3, 5}, + {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, + 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f}); - int axis = 1; - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), axis); + int axis = 1; + auto packX = + ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), axis); - Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + Nd4jPointer extras[2] = {nullptr, + LaunchContext::defaultContext()->getCudaStream()}; - x.syncToDevice(); - sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), packX.platformOffsets(), false); - x.tickWriteDevice(); + x.syncToDevice(); + sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), + packX.platformOffsets(), false); + x.tickWriteDevice(); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp index 9ad9a83587c0..47a42aa5470c 100644 --- a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp @@ -18,751 +18,788 @@ // Created by raver119 on 16.10.2017. // -#include "testlayers.h" #include +#include #include +#include #include -#include +#include +#include #include -#include -#include #include -#include -#include -#include -#include +#include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; -class LegacyOpsTests : public testing::Test { - -}; - +class LegacyOpsTests : public testing::Test {}; TEST_F(LegacyOpsTests, TransformTests_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(1.0); - auto z = NDArrayFactory::create('c', {5,5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - exp.assign(-1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); + auto z = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(-1.0); - sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg - auto status = op.execute({&x}, {&z}, {}, {}, {}); - ASSERT_EQ(status, ND4J_STATUS_OK); - //z.printIndexedBuffer("Output NEG"); - ASSERT_TRUE(z.equalsTo(&exp)); + sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg + auto status = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(status, ND4J_STATUS_OK); + // z.printIndexedBuffer("Output NEG"); + ASSERT_TRUE(z.equalsTo(&exp)); } TEST_F(LegacyOpsTests, TransformTests_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); - auto exp = NDArrayFactory::create('c', {5, 5}); - exp.assign(-1.0); + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(-1.0); - sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg - auto result = op.evaluate({&x}, {}, {}); + sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(1, result.size()); + ASSERT_EQ(1, result.size()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.equalsTo(z)); } -TEST_F(LegacyOpsTests, Reciprocal_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(2.0f); - - auto ethalon = NDArrayFactory::create('c', {5, 5}); - ethalon.assign(0.5f); +TEST_F(LegacyOpsTests, Reciprocal_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(2.0f); - sd::ops::LegacyTransformSameOp op(transform::Reciprocal); // Reciprocal - Nd4jStatus status = op.execute({&x}, {&x}, {}, {}, {}); + auto ethalon = NDArrayFactory::create('c', {5, 5}); + ethalon.assign(0.5f); - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(ethalon.equalsTo(&x)); + sd::ops::LegacyTransformSameOp op(transform::Reciprocal); // Reciprocal + Nd4jStatus status = op.execute({&x}, {&x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(ethalon.equalsTo(&x)); } -TEST_F(LegacyOpsTests, PWT_Tests_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(2.0); +TEST_F(LegacyOpsTests, PWT_Tests_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(2.0); - auto y = NDArrayFactory::create('c', {5, 5}); - y.assign(3.0); + auto y = NDArrayFactory::create('c', {5, 5}); + y.assign(3.0); - auto exp = NDArrayFactory::create('c', {5, 5}); - exp.assign(6.0); + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(6.0); - sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply - Nd4jStatus status = op.execute({&x, &y}, {&x}, {}, {}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, status); - - ASSERT_TRUE(exp.equalsTo(&x)); + sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply + Nd4jStatus status = op.execute({&x, &y}, {&x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(exp.equalsTo(&x)); } -TEST_F(LegacyOpsTests, PWT_Tests_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(2.0); - - auto y = NDArrayFactory::create('c', {5, 5}); - y.assign(3.0); +TEST_F(LegacyOpsTests, PWT_Tests_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(2.0); - auto exp = NDArrayFactory::create('c', {5, 5}); - exp.assign(6.0); + auto y = NDArrayFactory::create('c', {5, 5}); + y.assign(3.0); - sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply - auto result = op.evaluate({&x, &y}, {}, {}); + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(6.0); - auto z = result.at(0); + sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply + auto result = op.evaluate({&x, &y}, {}, {}); - //z->printBuffer("Z"); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + // z->printBuffer("Z"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(LegacyOpsTests, Scalar_Test_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(2.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(2.0); - auto exp = NDArrayFactory::create('c', {5, 5}); - exp.assign(7.0); + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(7.0); - sd::ops::LegacyScalarOp op(scalar::Add); - op.execute({&x}, {&x}, {5.0}, {}, {}); // + sd::ops::LegacyScalarOp op(scalar::Add); + op.execute({&x}, {&x}, {5.0}, {}, {}); // - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(LegacyOpsTests, Scalar_Test_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(2.0); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(2.0); - auto exp = NDArrayFactory::create('c', {5, 5}); - exp.assign(7.0); + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(7.0); - auto y = NDArrayFactory::create(5.0f); + auto y = NDArrayFactory::create(5.0f); - sd::ops::LegacyScalarOp op(scalar::Add, y); - auto result = op.evaluate({&x}, {}, {}); + sd::ops::LegacyScalarOp op(scalar::Add, y); + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(LegacyOpsTests, ReduceTests_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(1.0); - int opNum = reduce::Sum; - sd::ops::LegacyReduceSameOp op(opNum); - - auto result = op.evaluate({&x}, {}, {}); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); + int opNum = reduce::Sum; + sd::ops::LegacyReduceSameOp op(opNum); - ASSERT_EQ(1, result.size()); + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); - // z->printBuffer("ReduceTest1"); - ASSERT_TRUE(z.isScalar()); - ASSERT_NEAR(x.sumNumber().e(0), z.e(0), 1e-5f); + ASSERT_EQ(1, result.size()); - + auto z = result.at(0); + // z->printBuffer("ReduceTest1"); + ASSERT_TRUE(z.isScalar()); + ASSERT_NEAR(x.sumNumber().e(0), z.e(0), 1e-5f); } - TEST_F(LegacyOpsTests, ReduceTests_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(1.0); - - sd::ops::LegacyReduceSameOp op(reduce::Sum); - auto axis = NDArrayFactory::create('c', {1}, {1}); - auto result = op.evaluate({&x, &axis}, {}, {}); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); - ASSERT_EQ(1, result.size()); + sd::ops::LegacyReduceSameOp op(reduce::Sum); + auto axis = NDArrayFactory::create('c', {1}, {1}); + auto result = op.evaluate({&x, &axis}, {}, {}); - auto z = result.at(0); + ASSERT_EQ(1, result.size()); - auto exp = x.reduceAlongDimension(reduce::Sum, {1}); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto exp = x.reduceAlongDimension(reduce::Sum, {1}); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(LegacyOpsTests, ReduceTests_3) { - auto x = NDArrayFactory::create('c', {3, 5}); - x.linspace(1); - auto indices = NDArrayFactory::create('c', {1,1}, {1}); - - - sd::ops::LegacyReduceSameOp op(reduce::Sum); - auto result = op.evaluate({&x, &indices}, {}, {}); - auto z = result.at(0); - auto exp = x.reduceAlongDimension(reduce::Sum,{1}); + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); + auto indices = NDArrayFactory::create('c', {1, 1}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::LegacyReduceSameOp op(reduce::Sum); + auto result = op.evaluate({&x, &indices}, {}, {}); + auto z = result.at(0); + auto exp = x.reduceAlongDimension(reduce::Sum, {1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(LegacyOpsTests, ReduceTests_4) { - auto x = NDArrayFactory::create('c', {2, 3, 5}); - x.linspace(1); - auto indices = NDArrayFactory::create('c', {1, 1}, {1}); - - - sd::ops::LegacyReduceSameOp op(reduce::Sum); - auto result = op.evaluate({&x, &indices}, {}, {}, {true}); - auto z = result.at(0); - auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true); - // indices.printShapeInfo("Indices shape"); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - // z->printIndexedBuffer("Output reduce 4"); - // exp.printIndexedBuffer("Expected reduce 4"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {2, 3, 5}); + x.linspace(1); + auto indices = NDArrayFactory::create('c', {1, 1}, {1}); + + sd::ops::LegacyReduceSameOp op(reduce::Sum); + auto result = op.evaluate({&x, &indices}, {}, {}, {true}); + auto z = result.at(0); + auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true); + // indices.printShapeInfo("Indices shape"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + // z->printIndexedBuffer("Output reduce 4"); + // exp.printIndexedBuffer("Expected reduce 4"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(LegacyOpsTests, ReduceTests_5) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(1.0); - int opNum = reduce::Mean; - sd::ops::LegacyReduceFloatOp op(opNum); - - auto result = op.evaluate({&x}); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); + int opNum = reduce::Mean; + sd::ops::LegacyReduceFloatOp op(opNum); - ASSERT_EQ(1, result.size()); + auto result = op.evaluate({&x}); - auto z = result.at(0); - // z->printBuffer("ReduceTest1"); - ASSERT_TRUE(z.isScalar()); - ASSERT_NEAR(x.meanNumber().e(0), z.e(0), 1e-5f); + ASSERT_EQ(1, result.size()); - + auto z = result.at(0); + // z->printBuffer("ReduceTest1"); + ASSERT_TRUE(z.isScalar()); + ASSERT_NEAR(x.meanNumber().e(0), z.e(0), 1e-5f); } - TEST_F(LegacyOpsTests, ReduceTests_6) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(1.0); - auto axis = NDArrayFactory::create('c', {1}, {1}); - sd::ops::LegacyReduceFloatOp op(reduce::Mean); - - auto result = op.evaluate({&x, &axis}, {}, {}); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); + auto axis = NDArrayFactory::create('c', {1}, {1}); + sd::ops::LegacyReduceFloatOp op(reduce::Mean); - ASSERT_EQ(1, result.size()); + auto result = op.evaluate({&x, &axis}, {}, {}); - auto z = result.at(0); + ASSERT_EQ(1, result.size()); - auto exp = x.reduceAlongDimension(reduce::Mean, {1}); + auto z = result.at(0); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto exp = x.reduceAlongDimension(reduce::Mean, {1}); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(LegacyOpsTests, ReduceTests_7) { - auto x = NDArrayFactory::create('c', {3, 5}); - x.linspace(1); - auto indices = NDArrayFactory::create('c', {1,1}, {1}); - - - sd::ops::LegacyReduceFloatOp op(reduce::Mean); - auto result = op.evaluate({&x, &indices}, {}, {}); - auto z = result.at(0); - auto exp = x.reduceAlongDimension(reduce::Mean,{1}); + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); + auto indices = NDArrayFactory::create('c', {1, 1}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::LegacyReduceFloatOp op(reduce::Mean); + auto result = op.evaluate({&x, &indices}, {}, {}); + auto z = result.at(0); + auto exp = x.reduceAlongDimension(reduce::Mean, {1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(LegacyOpsTests, ReduceTests_8) { - auto x = NDArrayFactory::create('c', {2, 3, 5}); - x.linspace(1); - auto indices = NDArrayFactory::create('c', {1}, {1}); - - - sd::ops::LegacyReduceFloatOp op(reduce::Mean); - auto result = op.evaluate({&x, &indices}, {}, {}, {true}); - auto z = result.at(0); - auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true); + auto x = NDArrayFactory::create('c', {2, 3, 5}); + x.linspace(1); + auto indices = NDArrayFactory::create('c', {1}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - // z->printIndexedBuffer("Reduce8 output"); - // z->printShapeInfo("Reduce8 shape"); - // exp.printShapeInfo("Reduce8 expected shape"); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::LegacyReduceFloatOp op(reduce::Mean); + auto result = op.evaluate({&x, &indices}, {}, {}, {true}); + auto z = result.at(0); + auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true); - + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + // z->printIndexedBuffer("Reduce8 output"); + // z->printShapeInfo("Reduce8 shape"); + // exp.printShapeInfo("Reduce8 expected shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(LegacyOpsTests, IndexReduceTests_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.linspace(1); - - sd::ops::LegacyIndexReduceOp op(indexreduce::IndexMax); + auto x = NDArrayFactory::create('c', {5, 5}); + x.linspace(1); - auto result = op.evaluate({&x}, {}, {}); + sd::ops::LegacyIndexReduceOp op(indexreduce::IndexMax); - ASSERT_EQ(1, result.size()); + auto result = op.evaluate({&x}, {}, {}); - auto z = result.at(0); + ASSERT_EQ(1, result.size()); - ASSERT_TRUE(z.isScalar()); - ASSERT_EQ(24, z.e(0)); + auto z = result.at(0); - + ASSERT_TRUE(z.isScalar()); + ASSERT_EQ(24, z.e(0)); } - TEST_F(LegacyOpsTests, IndexReduceTests_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto indices = NDArrayFactory::create('c', {1}, {1}); - x.linspace(1); - auto exp = NDArrayFactory::create({4,4,4,4,4}); - sd::ops::LegacyIndexReduceOp op(indexreduce::IndexMax); + auto x = NDArrayFactory::create('c', {5, 5}); + auto indices = NDArrayFactory::create('c', {1}, {1}); + x.linspace(1); + auto exp = NDArrayFactory::create({4, 4, 4, 4, 4}); + sd::ops::LegacyIndexReduceOp op(indexreduce::IndexMax); - auto result = op.evaluate({&x, &indices}, {}, {}); + auto result = op.evaluate({&x, &indices}, {}, {}); - ASSERT_EQ(1, result.size()); + ASSERT_EQ(1, result.size()); - auto z = result.at(0); - // z->printIndexedBuffer("Hello indexreduce2"); - ASSERT_TRUE(exp.equalsTo(z)); - //ASSERT_EQ(4, z->e(0)); - //ASSERT_EQ(4, z->e(1)); - //ASSERT_EQ(4, z->e(2)); - //ASSERT_EQ(4, z->e(3)); - //ASSERT_EQ(4, z->e(4)); - - + auto z = result.at(0); + // z->printIndexedBuffer("Hello indexreduce2"); + ASSERT_TRUE(exp.equalsTo(z)); + // ASSERT_EQ(4, z->e(0)); + // ASSERT_EQ(4, z->e(1)); + // ASSERT_EQ(4, z->e(2)); + // ASSERT_EQ(4, z->e(3)); + // ASSERT_EQ(4, z->e(4)); } TEST_F(LegacyOpsTests, BroadcastingTests_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - x.assign(0.0f); + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(0.0f); - auto row = NDArrayFactory::create('c', {1, 5}); - row.linspace(1); - auto axis = NDArrayFactory::create('c', {1}, {1}); - sd::ops::LegacyBroadcastOp op(broadcast::Add); - Nd4jStatus status = op.execute({&x, &row, &axis}, {&x}, {}, {}, {}); + auto row = NDArrayFactory::create('c', {1, 5}); + row.linspace(1); + auto axis = NDArrayFactory::create('c', {1}, {1}); + sd::ops::LegacyBroadcastOp op(broadcast::Add); + Nd4jStatus status = op.execute({&x, &row, &axis}, {&x}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(ND4J_STATUS_OK, status); - auto list = x.allTensorsAlongDimension({1}); - // x.printIndexedBuffer("Output broadcast"); - // list->at(0)->printIndexedBuffer("Column 0:"); - for (int e = 0; e < list.size(); e++) - ASSERT_TRUE(row.equalsTo(list.at(e))); + auto list = x.allTensorsAlongDimension({1}); + // x.printIndexedBuffer("Output broadcast"); + // list->at(0)->printIndexedBuffer("Column 0:"); + for (int e = 0; e < list.size(); e++) ASSERT_TRUE(row.equalsTo(list.at(e))); } TEST_F(LegacyOpsTests, BroadcastingTests_2) { - auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); - auto y = NDArrayFactory::create('c', {10, 5}); - auto e = NDArrayFactory::create('c', {10, 5}); - y.assign(3.0); - e.assign(4.0); + auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); + auto y = NDArrayFactory::create('c', {10, 5}); + auto e = NDArrayFactory::create('c', {10, 5}); + y.assign(3.0); + e.assign(4.0); - int axis = 1; + int axis = 1; - // shape::printShapeInfoLinear("tad shape", tad.tadOnlyShapeInfo); - auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), {axis}); + // shape::printShapeInfoLinear("tad shape", tad.tadOnlyShapeInfo); + auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions( + y.shapeInfo(), {axis}); - NDArray::prepareSpecialUse({&y}, {&x}); + NDArray::prepareSpecialUse({&y}, {&x}); - NativeOpExecutioner::execInverseBroadcast(LaunchContext::defaultContext(), broadcast::Add, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), &axis, 1, packY.platformShapeInfo(), packY.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + NativeOpExecutioner::execInverseBroadcast( + LaunchContext::defaultContext(), broadcast::Add, x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), + y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), + y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), &axis, 1, + packY.platformShapeInfo(), packY.platformOffsets(), + packY.platformShapeInfo(), packY.platformOffsets()); - NDArray::registerSpecialUse({&y}, {&x}); + NDArray::registerSpecialUse({&y}, {&x}); - ASSERT_EQ(e, y); + ASSERT_EQ(e, y); } TEST_F(LegacyOpsTests, PowDerivative_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - x.assign(3.f); - exp.assign(6.f); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.assign(3.f); + exp.assign(6.f); - float p = 2.0f; + float p = 2.0f; - x.applyScalar(scalar::PowDerivative, p, x); + x.applyScalar(scalar::PowDerivative, p, x); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } #ifndef __CUDABLAS__ TEST_F(LegacyOpsTests, reduce3_1) { - - Nd4jLong yShape[2] = {4,4}; - Nd4jLong xShape[1] = {4}; - float y[16] ={1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}; - float x[4] = {1,2,3,4}; - int dimension[1] = {1}; - int dimensionLength = 1; - int opNum = 1; - float extraVals[1] = {0}; - float result[4] = {0.0,0.0,0.0,0.0}; - - std::vector dim = {1}; - - auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, yShape); - auto xShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 1, xShape); - - //int *tadShapeBuffer = shape::computeResultShape(shapeBuffer,dimension,dimensionLength); - auto tadShapeBuffer = sd::ShapeUtils::evalReduceShapeInfo('c', dim, shapeBuffer, false, true, nullptr); - functions::reduce3::Reduce3::exec(opNum, x, xShapeBuffer, extraVals, y, shapeBuffer, result, tadShapeBuffer, dimension, dimensionLength, 0, 4); - - float distancesAssertion[4] = {0.0,8.0,16.0,24.0}; - for(int i = 0; i < 4; i++) - ASSERT_NEAR(distancesAssertion[i],result[i], 1e-5); - - delete[] shapeBuffer; - delete[] xShapeBuffer; + Nd4jLong yShape[2] = {4, 4}; + Nd4jLong xShape[1] = {4}; + float y[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + float x[4] = {1, 2, 3, 4}; + int dimension[1] = {1}; + int dimensionLength = 1; + int opNum = 1; + float extraVals[1] = {0}; + float result[4] = {0.0, 0.0, 0.0, 0.0}; + + std::vector dim = {1}; + + auto shapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, yShape); + auto xShapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 1, xShape); + + // int *tadShapeBuffer = + // shape::computeResultShape(shapeBuffer,dimension,dimensionLength); + auto tadShapeBuffer = sd::ShapeUtils::evalReduceShapeInfo( + 'c', dim, shapeBuffer, false, true, nullptr); + functions::reduce3::Reduce3::exec( + opNum, x, xShapeBuffer, extraVals, y, shapeBuffer, result, tadShapeBuffer, + dimension, dimensionLength, 0, 4); + + float distancesAssertion[4] = {0.0, 8.0, 16.0, 24.0}; + for (int i = 0; i < 4; i++) + ASSERT_NEAR(distancesAssertion[i], result[i], 1e-5); + + delete[] shapeBuffer; + delete[] xShapeBuffer; } #endif - TEST_F(LegacyOpsTests, Reduce3_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5}); - auto z = NDArrayFactory::create('c', {5}); - - auto dim = NDArrayFactory::create('c', {1}, {1}); - dim.syncToHost(); - - sd::LaunchContext* context = sd::LaunchContext::defaultContext(); - - Nd4jPointer* extraPointers = nullptr; - #ifdef __CUDABLAS__ - extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; - #endif + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5}); + auto z = NDArrayFactory::create('c', {5}); + + auto dim = NDArrayFactory::create('c', {1}, {1}); + dim.syncToHost(); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; +#ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[7]{nullptr, + context->getCudaStream(), + context->getScalarPointer(), + nullptr, + context->getCudaSpecialStream(), + context->getReductionPointer(), + context->getAllocationPointer()}; +#endif - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), {1}); - auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), {1}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x.shapeInfo(), {1}); + auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions( + y.shapeInfo(), {1}); - NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer dimBuf(dim.dataBuffer()); + NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); - execReduce3Tad(extraPointers, reduce3::CosineSimilarity, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), - packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + execReduce3Tad(extraPointers, reduce3::CosineSimilarity, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &yBuf, y.shapeInfo(), + y.specialShapeInfo(), &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &dimBuf, dim.shapeInfo(), + dim.specialShapeInfo(), packX.platformShapeInfo(), + packX.platformOffsets(), packY.platformShapeInfo(), + packY.platformOffsets()); - NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); + NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); - delete []extraPointers; + delete[] extraPointers; } TEST_F(LegacyOpsTests, Reduce3_3) { - auto x = NDArrayFactory::create('c', {3, 5}, {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951, - -0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673, - 0.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247}); - - auto y = NDArrayFactory::create('c', {5}, {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673}); - auto e = NDArrayFactory::create('c', {3}, {0.577452, 0.0, 1.80182}); - auto z = NDArrayFactory::create('c', {3}); - - auto dim = NDArrayFactory::create('c', {1}, {1}); - dim.syncToHost(); - - sd::LaunchContext* context = sd::LaunchContext::defaultContext(); - - Nd4jPointer* extraPointers = nullptr; - #ifdef __CUDABLAS__ - extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; - #endif - - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), {1}); - auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), {1}); - - NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer dimBuf(dim.dataBuffer()); + auto x = NDArrayFactory::create( + 'c', {3, 5}, + {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, + -0.77555125951, -0.99536740779, -0.0257304441183, -0.6512106060, + -0.345789492130, -1.25485503673, 0.62955373525, -0.31357592344, + 1.03362500667, -0.59279078245, 1.1914824247}); + + auto y = NDArrayFactory::create( + 'c', {5}, + {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, + -1.25485503673}); + auto e = NDArrayFactory::create('c', {3}, {0.577452, 0.0, 1.80182}); + auto z = NDArrayFactory::create('c', {3}); + + auto dim = NDArrayFactory::create('c', {1}, {1}); + dim.syncToHost(); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; +#ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[7]{nullptr, + context->getCudaStream(), + context->getScalarPointer(), + nullptr, + context->getCudaSpecialStream(), + context->getReductionPointer(), + context->getAllocationPointer()}; +#endif - execReduce3Tad(extraPointers, reduce3::CosineDistance, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), - packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); - ASSERT_EQ(e, z); - NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); - delete []extraPointers; + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x.shapeInfo(), {1}); + auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions( + y.shapeInfo(), {1}); + + NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); + + execReduce3Tad(extraPointers, reduce3::CosineDistance, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &yBuf, y.shapeInfo(), + y.specialShapeInfo(), &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &dimBuf, dim.shapeInfo(), + dim.specialShapeInfo(), packX.platformShapeInfo(), + packX.platformOffsets(), packY.platformShapeInfo(), + packY.platformOffsets()); + ASSERT_EQ(e, z); + NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); + delete[] extraPointers; } TEST_F(LegacyOpsTests, Reduce3_4) { - auto x = NDArrayFactory::create('c', {3, 5}, {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951, - -0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673, - 0.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247}); - - auto y = NDArrayFactory::create('c', {1, 5}, {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673}); - auto e = NDArrayFactory::create('c', {1, 3}, {0.577452, 0.0, 1.80182}); - auto z = NDArrayFactory::create('c', {1, 3}); - - auto dim = NDArrayFactory::create('c', {1}, {1}); - dim.syncToHost(); - - sd::LaunchContext* context = sd::LaunchContext::defaultContext(); - - Nd4jPointer* extraPointers = nullptr; - #ifdef __CUDABLAS__ - extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; - #endif + auto x = NDArrayFactory::create( + 'c', {3, 5}, + {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, + -0.77555125951, -0.99536740779, -0.0257304441183, -0.6512106060, + -0.345789492130, -1.25485503673, 0.62955373525, -0.31357592344, + 1.03362500667, -0.59279078245, 1.1914824247}); + + auto y = NDArrayFactory::create( + 'c', {1, 5}, + {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, + -1.25485503673}); + auto e = + NDArrayFactory::create('c', {1, 3}, {0.577452, 0.0, 1.80182}); + auto z = NDArrayFactory::create('c', {1, 3}); + + auto dim = NDArrayFactory::create('c', {1}, {1}); + dim.syncToHost(); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; +#ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[7]{nullptr, + context->getCudaStream(), + context->getScalarPointer(), + nullptr, + context->getCudaSpecialStream(), + context->getReductionPointer(), + context->getAllocationPointer()}; +#endif - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), {1}); - auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), {1}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x.shapeInfo(), {1}); + auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions( + y.shapeInfo(), {1}); - NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer dimBuf(dim.dataBuffer()); + NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); - execReduce3Tad(extraPointers, reduce3::CosineDistance, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), - packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + execReduce3Tad(extraPointers, reduce3::CosineDistance, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &yBuf, y.shapeInfo(), + y.specialShapeInfo(), &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &dimBuf, dim.shapeInfo(), + dim.specialShapeInfo(), packX.platformShapeInfo(), + packX.platformOffsets(), packY.platformShapeInfo(), + packY.platformOffsets()); - // z.printIndexedBuffer("z"); - NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); - ASSERT_EQ(e, z); - delete []extraPointers; + // z.printIndexedBuffer("z"); + NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); + ASSERT_EQ(e, z); + delete[] extraPointers; } TEST_F(LegacyOpsTests, Reduce3_5) { - auto x = NDArrayFactory::create('c', {3, 5}, {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951, - -0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673, - 0.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247}); - - auto y = NDArrayFactory::create('c', {1, 5}, {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673}); - auto e = NDArrayFactory::create('c', {1, 3}, {0.577452, 0.0, 1.80182}); - auto z = NDArrayFactory::create('c', {1, 3}); - - auto dim = NDArrayFactory::create('c', {1}, {1}); - dim.syncToHost(); - - sd::LaunchContext* context = sd::LaunchContext::defaultContext(); - - Nd4jPointer* extraPointers = nullptr; - #ifdef __CUDABLAS__ - extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; - #endif + auto x = NDArrayFactory::create( + 'c', {3, 5}, + {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, + -0.77555125951, -0.99536740779, -0.0257304441183, -0.6512106060, + -0.345789492130, -1.25485503673, 0.62955373525, -0.31357592344, + 1.03362500667, -0.59279078245, 1.1914824247}); + + auto y = NDArrayFactory::create( + 'c', {1, 5}, + {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, + -1.25485503673}); + auto e = + NDArrayFactory::create('c', {1, 3}, {0.577452, 0.0, 1.80182}); + auto z = NDArrayFactory::create('c', {1, 3}); + + auto dim = NDArrayFactory::create('c', {1}, {1}); + dim.syncToHost(); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; +#ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[7]{nullptr, + context->getCudaStream(), + context->getScalarPointer(), + nullptr, + context->getCudaSpecialStream(), + context->getReductionPointer(), + context->getAllocationPointer()}; +#endif - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), {1}); - auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), {1}); + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x.shapeInfo(), {1}); + auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions( + y.shapeInfo(), {1}); - NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer dimBuf(dim.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); - execReduce3Tad(extraPointers, reduce3::CosineDistance, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), - packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + execReduce3Tad(extraPointers, reduce3::CosineDistance, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &yBuf, y.shapeInfo(), + y.specialShapeInfo(), &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &dimBuf, dim.shapeInfo(), + dim.specialShapeInfo(), packX.platformShapeInfo(), + packX.platformOffsets(), packY.platformShapeInfo(), + packY.platformOffsets()); - NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); - ASSERT_EQ(e, z); - delete []extraPointers; + NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); + ASSERT_EQ(e, z); + delete[] extraPointers; } TEST_F(LegacyOpsTests, test_Reduce3_All_1) { - auto x = NDArrayFactory::create('c', {1000, 100}); - auto y = NDArrayFactory::create('c', {1, 100}); - auto z = NDArrayFactory::create('c', {1000, 1}); - auto dim = NDArrayFactory::create('c', {1}, {-1}); - - auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), -1); - auto tadPackY = sd::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), -1); - - sd::LaunchContext* context = sd::LaunchContext::defaultContext(); - - Nd4jPointer* extraPointers = nullptr; - #ifdef __CUDABLAS__ - extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; - #endif + auto x = NDArrayFactory::create('c', {1000, 100}); + auto y = NDArrayFactory::create('c', {1, 100}); + auto z = NDArrayFactory::create('c', {1000, 1}); + auto dim = NDArrayFactory::create('c', {1}, {-1}); + + auto tadPackX = + sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), -1); + auto tadPackY = + sd::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), -1); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; +#ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[7]{nullptr, + context->getCudaStream(), + context->getScalarPointer(), + nullptr, + context->getCudaSpecialStream(), + context->getReductionPointer(), + context->getAllocationPointer()}; +#endif - NDArray::prepareSpecialUse({&z}, {&x, &y}); + NDArray::prepareSpecialUse({&z}, {&x, &y}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer dimBuf(dim.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); - execReduce3All(extraPointers, reduce3::EuclideanDistance, &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), - tadPackX.platformShapeInfo(), tadPackX.platformOffsets(), - tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); + execReduce3All(extraPointers, reduce3::EuclideanDistance, &xBuf, + x.shapeInfo(), x.specialShapeInfo(), nullptr, &yBuf, + y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &dimBuf, dim.shapeInfo(), + dim.specialShapeInfo(), tadPackX.platformShapeInfo(), + tadPackX.platformOffsets(), tadPackY.platformShapeInfo(), + tadPackY.platformOffsets()); - NDArray::registerSpecialUse({&z}, {&x, &y}); + NDArray::registerSpecialUse({&z}, {&x, &y}); - delete []extraPointers; + delete[] extraPointers; } - TEST_F(LegacyOpsTests, test_inverse_broadcast_1) { - auto x = NDArrayFactory::create('c', {4}, {2.0f, 2.0f, 2.0f, 2.0f}); - auto y = NDArrayFactory::create('c', {3, 4}); - auto e = NDArrayFactory::create('c', {3, 4}); - e.assign(2.0f); + auto x = NDArrayFactory::create('c', {4}, {2.0f, 2.0f, 2.0f, 2.0f}); + auto y = NDArrayFactory::create('c', {3, 4}); + auto e = NDArrayFactory::create('c', {3, 4}); + e.assign(2.0f); - auto tadPackY = sd::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), 1); + auto tadPackY = + sd::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), 1); - y.tickWriteDevice(); + y.tickWriteDevice(); - NativeOpExecutioner::execInverseBroadcast(LaunchContext::defaultContext(), broadcast::Add, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, 0, - tadPackY.platformShapeInfo(), tadPackY.platformOffsets(), - tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); + NativeOpExecutioner::execInverseBroadcast( + LaunchContext::defaultContext(), broadcast::Add, x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), + y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), + y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, 0, + tadPackY.platformShapeInfo(), tadPackY.platformOffsets(), + tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); - ASSERT_EQ(e, y); + ASSERT_EQ(e, y); } TEST_F(LegacyOpsTests, test_inverse_broadcast_2) { - auto x = NDArrayFactory::create('c', {4}, {2.0f, 2.0f, 2.0f, 2.0f}); - auto y = NDArrayFactory::create('c', {3, 4}); - auto z = NDArrayFactory::create('c', {3, 4}); - auto e = NDArrayFactory::create('c', {3, 4}); - e.assign(false); + auto x = NDArrayFactory::create('c', {4}, {2.0f, 2.0f, 2.0f, 2.0f}); + auto y = NDArrayFactory::create('c', {3, 4}); + auto z = NDArrayFactory::create('c', {3, 4}); + auto e = NDArrayFactory::create('c', {3, 4}); + e.assign(false); - auto row = y(1, {0}); - row.assign(2.0f); + auto row = y(1, {0}); + row.assign(2.0f); - auto erow = e(1, {0}); - erow.assign(true); + auto erow = e(1, {0}); + erow.assign(true); - auto tadPackY = sd::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), 1); + auto tadPackY = + sd::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), 1); - z.tickWriteDevice(); + z.tickWriteDevice(); - NativeOpExecutioner::execInverseBroadcastBool(LaunchContext::defaultContext(), broadcast::BoolOps::EqualTo, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, - nullptr, 0, - tadPackY.platformShapeInfo(), tadPackY.platformOffsets(), - tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); + NativeOpExecutioner::execInverseBroadcastBool( + LaunchContext::defaultContext(), broadcast::BoolOps::EqualTo, x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), + y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), + z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr, + 0, tadPackY.platformShapeInfo(), tadPackY.platformOffsets(), + tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(LegacyOpsTests, test_legacy_reduce_empty_1) { - auto x = NDArrayFactory::create('c', {2, 0, 3}); - auto z = NDArrayFactory::create('c', {2, 3}); - auto e = NDArrayFactory::create('c', {2, 3}); + auto x = NDArrayFactory::create('c', {2, 0, 3}); + auto z = NDArrayFactory::create('c', {2, 3}); + auto e = NDArrayFactory::create('c', {2, 3}); - int dim = 1; + int dim = 1; - NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Sum, - x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - &dim, 1, x.platformShapeInfo(), nullptr); + NativeOpExecutioner::execReduceSame( + LaunchContext::defaultContext(), reduce::SameOps::Sum, x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, + z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, + 1, x.platformShapeInfo(), nullptr); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(LegacyOpsTests, test_legacy_reduce_empty_2) { - auto x = NDArrayFactory::create('c', {2, 0, 3}); - auto z = NDArrayFactory::create('c', {2, 3}); - auto e = NDArrayFactory::create('c', {2, 3}); - e.assign(std::numeric_limits::infinity()); + auto x = NDArrayFactory::create('c', {2, 0, 3}); + auto z = NDArrayFactory::create('c', {2, 3}); + auto e = NDArrayFactory::create('c', {2, 3}); + e.assign(std::numeric_limits::infinity()); - int dim = 1; + int dim = 1; - NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Min, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.platformShapeInfo(), nullptr); + NativeOpExecutioner::execReduceSame( + LaunchContext::defaultContext(), reduce::SameOps::Min, x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, + z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, + 1, x.platformShapeInfo(), nullptr); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(LegacyOpsTests, test_legacy_reduce_empty_3) { - auto x = NDArrayFactory::create('c', {2, 0, 3}); - auto z = NDArrayFactory::create('c', {2, 3}); - auto e = NDArrayFactory::create('c', {2, 3}); - e.assign(-std::numeric_limits::infinity()); + auto x = NDArrayFactory::create('c', {2, 0, 3}); + auto z = NDArrayFactory::create('c', {2, 3}); + auto e = NDArrayFactory::create('c', {2, 3}); + e.assign(-std::numeric_limits::infinity()); - int dim = 1; + int dim = 1; - NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Max, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.platformShapeInfo(), nullptr); + NativeOpExecutioner::execReduceSame( + LaunchContext::defaultContext(), reduce::SameOps::Max, x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, + z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, + 1, x.platformShapeInfo(), nullptr); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(LegacyOpsTests, test_legacy_reduce_empty_4) { - if (!Environment::getInstance()->isCPU()) - return; - int a = 0; - - auto x = NDArrayFactory::create('c', {1, 0, 2}); - auto d = NDArrayFactory::create('c', {1}, {a}); - auto z = NDArrayFactory::create('c', {0, 2}); - auto e = NDArrayFactory::create('c', {0, 2}); - - InteropDataBuffer xdb(x.dataBuffer()); - InteropDataBuffer ddb(d.dataBuffer()); - InteropDataBuffer zdb(z.dataBuffer()); + if (!Environment::getInstance()->isCPU()) return; + int a = 0; + auto x = NDArrayFactory::create('c', {1, 0, 2}); + auto d = NDArrayFactory::create('c', {1}, {a}); + auto z = NDArrayFactory::create('c', {0, 2}); + auto e = NDArrayFactory::create('c', {0, 2}); - ::execReduceSame2(nullptr, reduce::SameOps::Sum, - &xdb, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &zdb, z.shapeInfo(), z.specialShapeInfo(), - &ddb, d.shapeInfo(), d.specialShapeInfo()); + InteropDataBuffer xdb(x.dataBuffer()); + InteropDataBuffer ddb(d.dataBuffer()); + InteropDataBuffer zdb(z.dataBuffer()); + ::execReduceSame2(nullptr, reduce::SameOps::Sum, &xdb, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &zdb, z.shapeInfo(), + z.specialShapeInfo(), &ddb, d.shapeInfo(), + d.specialShapeInfo()); } TEST_F(LegacyOpsTests, test_legacy_transform_float_1) { - auto x = NDArrayFactory::create('c', {1, 0, 4}); + auto x = NDArrayFactory::create('c', {1, 0, 4}); - NativeOpExecutioner::execTransformFloat(LaunchContext::defaultContext(), transform::FloatOps::RSqrt, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, nullptr); + NativeOpExecutioner::execTransformFloat( + LaunchContext::defaultContext(), transform::FloatOps::RSqrt, x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), + x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, + nullptr); } diff --git a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp index 5f22d48a554e..541a5bcd2df0 100644 --- a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp @@ -18,93 +18,88 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include #include +#include "testlayers.h" + using namespace sd; using namespace sd::ops; -class ListOperationsTests : public testing::Test { - -}; +class ListOperationsTests : public testing::Test {}; TEST_F(ListOperationsTests, BasicTest_Write_1) { - NDArrayList list(5); - auto x = NDArrayFactory::create('c', {128}); - x.linspace(1); + NDArrayList list(5); + auto x = NDArrayFactory::create('c', {128}); + x.linspace(1); - sd::ops::write_list op; + sd::ops::write_list op; - auto result = op.execute(list, {&x}, {}, {1}); + auto result = op.execute(list, {&x}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(1, list.elements()); + ASSERT_EQ(1, list.elements()); - auto result2 = op.execute(list, {&x}, {}, {2}); + auto result2 = op.execute(list, {&x}, {}, {2}); - ASSERT_EQ(2, list.elements()); + ASSERT_EQ(2, list.elements()); } TEST_F(ListOperationsTests, BasicTest_Stack_1) { - NDArrayList list(10); - auto exp = NDArrayFactory::create('c', {10, 100}); - auto tads = exp.allTensorsAlongDimension({1}); - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create('c', {100}); - row.assign((double) e); - list.write(e, row); - tads.at(e).assign(row); - } + NDArrayList list(10); + auto exp = NDArrayFactory::create('c', {10, 100}); + auto tads = exp.allTensorsAlongDimension({1}); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create('c', {100}); + row.assign((double)e); + list.write(e, row); + tads.at(e).assign(row); + } - sd::ops::stack_list op; + sd::ops::stack_list op; - auto result = op.execute(list, {}, {}, {1}); + auto result = op.execute(list, {}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - // z->printShapeInfo(); + auto z = result.at(0); + // z->printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { - NDArrayList list(0, true); - auto x = NDArrayFactory::create('c', {10, 100}); - auto tads = x.allTensorsAlongDimension({1}); - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create('c', {100}); - row.assign((double) e); - //list.write(e, row); - tads.at(e).assign(row); - } - - sd::ops::unstack_list op; - - auto result = op.execute(list, {&x}, {}, {0}); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_EQ(10, list.elements()); - -// auto z = result.at(0); -// z->printShapeInfo("The first of"); -// ASSERT_TRUE(exp.isSameShape(z)); -// ASSERT_TRUE(exp.equalsTo(z)); - for (int e = 0; e < 10; e++) { - auto row = list.read(e); - ASSERT_TRUE(row.equalsTo(tads.at(e))); - //list.write(e, row); - } - - + NDArrayList list(0, true); + auto x = NDArrayFactory::create('c', {10, 100}); + auto tads = x.allTensorsAlongDimension({1}); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create('c', {100}); + row.assign((double)e); + // list.write(e, row); + tads.at(e).assign(row); + } + + sd::ops::unstack_list op; + + auto result = op.execute(list, {&x}, {}, {0}); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(10, list.elements()); + + // auto z = result.at(0); + // z->printShapeInfo("The first of"); + // ASSERT_TRUE(exp.isSameShape(z)); + // ASSERT_TRUE(exp.equalsTo(z)); + for (int e = 0; e < 10; e++) { + auto row = list.read(e); + ASSERT_TRUE(row.equalsTo(tads.at(e))); + // list.write(e, row); + } } -//TEST_F(ListOperationsTests, BasicTest_UnStackList_2) { +// TEST_F(ListOperationsTests, BasicTest_UnStackList_2) { //// NDArrayList list(0, true); // auto x = NDArrayFactory::create('c', {10, 100}); // auto tads = x.allTensorsAlongDimension({1}); @@ -133,256 +128,241 @@ TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { // //list.write(e, row); // } // -// +// // delete tads; //} TEST_F(ListOperationsTests, BasicTest_Read_1) { - NDArrayList list(10); - auto exp = NDArrayFactory::create('c', {1, 100}); - exp.assign(4.0f); - - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {1, 100}); - row->assign((double) e); - list.write(e, row->dup()); + NDArrayList list(10); + auto exp = NDArrayFactory::create('c', {1, 100}); + exp.assign(4.0f); - delete row; - } + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {1, 100}); + row->assign((double)e); + list.write(e, row->dup()); - sd::ops::read_list op; + delete row; + } - auto result = op.execute(list, {}, {}, {4}); + sd::ops::read_list op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.execute(list, {}, {}, {4}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ListOperationsTests, BasicTest_Pick_1) { - NDArrayList list(10); - auto exp = NDArrayFactory::create('c', {4, 100}); + NDArrayList list(10); + auto exp = NDArrayFactory::create('c', {4, 100}); - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {100}); - row->assign((double) e); - list.write(e, row->dup()); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {100}); + row->assign((double)e); + list.write(e, row->dup()); - delete row; - } + delete row; + } - auto tads = exp.allTensorsAlongDimension({1}); - tads.at(0).assign(1.0f); - tads.at(1).assign(1.0f); - tads.at(2).assign(3.0f); - tads.at(3).assign(3.0f); + auto tads = exp.allTensorsAlongDimension({1}); + tads.at(0).assign(1.0f); + tads.at(1).assign(1.0f); + tads.at(2).assign(3.0f); + tads.at(3).assign(3.0f); + sd::ops::pick_list op; + auto result = op.execute(list, {}, {}, {1, 1, 3, 3}); - sd::ops::pick_list op; - auto result = op.execute(list, {}, {}, {1, 1, 3, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ListOperationsTests, BasicTest_Size_1) { - NDArrayList list(10); - auto exp = NDArrayFactory::create(10); - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {100}); - row->assign((double) e); - list.write(e, row->dup()); - - delete row; - } + NDArrayList list(10); + auto exp = NDArrayFactory::create(10); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {100}); + row->assign((double)e); + list.write(e, row->dup()); - sd::ops::size_list op; + delete row; + } - auto result = op.execute(list, {}, {}, {1}); + sd::ops::size_list op; - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.execute(list, {}, {}, {1}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); - + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ListOperationsTests, BasicTest_Create_1) { - auto matrix = NDArrayFactory::create('c', {3, 2}); - matrix.linspace(1); + auto matrix = NDArrayFactory::create('c', {3, 2}); + matrix.linspace(1); - sd::ops::create_list op; + sd::ops::create_list op; - auto result = op.execute(NDArrayList(), {&matrix}, {}, {1, 1}); + auto result = op.execute(NDArrayList(), {&matrix}, {}, {1, 1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - // we return flow as well - ASSERT_EQ(1, result.size()); - - + // we return flow as well + ASSERT_EQ(1, result.size()); } TEST_F(ListOperationsTests, BasicTest_Split_1) { - NDArrayList list(0, true); - - auto exp0 = NDArrayFactory::create('c', {2, 5}); - auto exp1 = NDArrayFactory::create('c', {3, 5}); - auto exp2 = NDArrayFactory::create('c', {5, 5}); + NDArrayList list(0, true); - auto matrix = NDArrayFactory::create('c', {10, 5}); + auto exp0 = NDArrayFactory::create('c', {2, 5}); + auto exp1 = NDArrayFactory::create('c', {3, 5}); + auto exp2 = NDArrayFactory::create('c', {5, 5}); - auto lengths = NDArrayFactory::create('c', {3}); - lengths.p(0, 2); - lengths.p(1, 3); - lengths.p(2, 5); + auto matrix = NDArrayFactory::create('c', {10, 5}); - auto tads = matrix.allTensorsAlongDimension({1}); + auto lengths = NDArrayFactory::create('c', {3}); + lengths.p(0, 2); + lengths.p(1, 3); + lengths.p(2, 5); - auto tads0 = exp0.allTensorsAlongDimension({1}); - auto tads1 = exp1.allTensorsAlongDimension({1}); - auto tads2 = exp2.allTensorsAlongDimension({1}); + auto tads = matrix.allTensorsAlongDimension({1}); - int cnt0 = 0; - int cnt1 = 0; - int cnt2 = 0; - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create('c', {5}); - row.assign((double) e); - tads.at(e).assign(row); + auto tads0 = exp0.allTensorsAlongDimension({1}); + auto tads1 = exp1.allTensorsAlongDimension({1}); + auto tads2 = exp2.allTensorsAlongDimension({1}); - if (e < 2) - tads0.at(cnt0++).assign(row); - else if (e < 5) - tads1.at(cnt1++).assign(row); - else - tads2.at(cnt2++).assign(row); - } + int cnt0 = 0; + int cnt1 = 0; + int cnt2 = 0; + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create('c', {5}); + row.assign((double)e); + tads.at(e).assign(row); - sd::ops::split_list op; - auto result = op.execute(list, {&matrix, &lengths}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + if (e < 2) + tads0.at(cnt0++).assign(row); + else if (e < 5) + tads1.at(cnt1++).assign(row); + else + tads2.at(cnt2++).assign(row); + } - ASSERT_EQ(3, list.height()); + sd::ops::split_list op; + auto result = op.execute(list, {&matrix, &lengths}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp0.isSameShape(list.readRaw(0))); - ASSERT_TRUE(exp0.equalsTo(list.readRaw(0))); + ASSERT_EQ(3, list.height()); - ASSERT_TRUE(exp1.isSameShape(list.readRaw(1))); - ASSERT_TRUE(exp1.equalsTo(list.readRaw(1))); + ASSERT_TRUE(exp0.isSameShape(list.readRaw(0))); + ASSERT_TRUE(exp0.equalsTo(list.readRaw(0))); - ASSERT_TRUE(exp2.isSameShape(list.readRaw(2))); - ASSERT_TRUE(exp2.equalsTo(list.readRaw(2))); + ASSERT_TRUE(exp1.isSameShape(list.readRaw(1))); + ASSERT_TRUE(exp1.equalsTo(list.readRaw(1))); - + ASSERT_TRUE(exp2.isSameShape(list.readRaw(2))); + ASSERT_TRUE(exp2.equalsTo(list.readRaw(2))); } TEST_F(ListOperationsTests, BasicTest_Scatter_1) { - NDArrayList list(0, true); - auto s = NDArrayFactory::create(0.0); + NDArrayList list(0, true); + auto s = NDArrayFactory::create(0.0); - auto matrix = NDArrayFactory::create('c', {10, 5}); - auto tads = matrix.allTensorsAlongDimension({1}); - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create_('c', {1, 5}); - row->assign((double) e); - tads.at(e).assign(row); + auto matrix = NDArrayFactory::create('c', {10, 5}); + auto tads = matrix.allTensorsAlongDimension({1}); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {1, 5}); + row->assign((double)e); + tads.at(e).assign(row); - delete row; - } - auto indices = NDArrayFactory::create('c', {1, 10}); - for (int e = 0; e < matrix.rows(); e++) - indices.p(e, 9 - e); + delete row; + } + auto indices = NDArrayFactory::create('c', {1, 10}); + for (int e = 0; e < matrix.rows(); e++) indices.p(e, 9 - e); - sd::ops::scatter_list op; - auto result = op.execute(list, {&indices, &matrix, &s}, {}, {}); + sd::ops::scatter_list op; + auto result = op.execute(list, {&indices, &matrix, &s}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - for (int e = 0; e < 10; e++) { - auto row = tads.at(9 - e); - auto chunk = list.readRaw(e); + for (int e = 0; e < 10; e++) { + auto row = tads.at(9 - e); + auto chunk = list.readRaw(e); - ASSERT_TRUE(chunk.isSameShape(row)); + ASSERT_TRUE(chunk.isSameShape(row)); - ASSERT_TRUE(chunk.equalsTo(row)); - } - + ASSERT_TRUE(chunk.equalsTo(row)); + } } TEST_F(ListOperationsTests, BasicTest_Clone_1) { - NDArrayList list(0, true); + NDArrayList list(0, true); - VariableSpace variableSpace; - auto var = std::make_shared(list, "", -1); - variableSpace.putVariable(-1, var); + VariableSpace variableSpace; + auto var = std::make_shared(list, "", -1); + variableSpace.putVariable(-1, var); - Context block(1, &variableSpace); - block.pickInput(-1); + Context block(1, &variableSpace); + block.pickInput(-1); - sd::ops::clone_list op; + sd::ops::clone_list op; - //ASSERT_TRUE(list == block.variable(0)->getNDArrayList().get()); + // ASSERT_TRUE(list == block.variable(0)->getNDArrayList().get()); - auto result = op.execute(&block); + auto result = op.execute(&block); - ASSERT_EQ(ND4J_STATUS_OK, result); + ASSERT_EQ(ND4J_STATUS_OK, result); - auto resVar = variableSpace.getVariable(1); + auto resVar = variableSpace.getVariable(1); - auto resList = resVar->getNDArrayList().get(); + auto resList = resVar->getNDArrayList().get(); - ASSERT_TRUE( resList != nullptr); + ASSERT_TRUE(resList != nullptr); - ASSERT_TRUE(list.equals(*resList)); + ASSERT_TRUE(list.equals(*resList)); } TEST_F(ListOperationsTests, BasicTest_Gather_1) { - NDArrayList list(0, true); - for (int e = 0; e < 10; e++) { - auto row = NDArrayFactory::create('c', {3}); - row.assign((double) e); - list.write(e, row.dup()); - } - - auto exp = NDArrayFactory::create('c', {10, 3}); - auto tads = exp.allTensorsAlongDimension({1}); - for (int e = 0; e < 10; e++) { - auto tad = tads.at(9 - e); - tad.assign(e); - } + NDArrayList list(0, true); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create('c', {3}); + row.assign((double)e); + list.write(e, row.dup()); + } - auto indices = NDArrayFactory::create('c', {1, 10}); - indices.linspace(9, -1); + auto exp = NDArrayFactory::create('c', {10, 3}); + auto tads = exp.allTensorsAlongDimension({1}); + for (int e = 0; e < 10; e++) { + auto tad = tads.at(9 - e); + tad.assign(e); + } - sd::ops::gather_list op; - auto result = op.execute(list, {&indices}, {}, {}); + auto indices = NDArrayFactory::create('c', {1, 10}); + indices.linspace(9, -1); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_EQ(1, result.size()); + sd::ops::gather_list op; + auto result = op.execute(list, {&indices}, {}, {}); - auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); - ASSERT_TRUE(exp.isSameShape(z)); + auto z = result.at(0); - //exp.printIndexedBuffer("e"); - //z->printIndexedBuffer("z"); + ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + // exp.printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); - + ASSERT_TRUE(exp.equalsTo(z)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/LoopCoordsHelperTests.cpp b/libnd4j/tests_cpu/layers_tests/LoopCoordsHelperTests.cpp index 976e89550201..b3cf4833292c 100644 --- a/libnd4j/tests_cpu/layers_tests/LoopCoordsHelperTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/LoopCoordsHelperTests.cpp @@ -14,210 +14,202 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author Abdelrauf - // +// +// @author Abdelrauf +// -#include "testlayers.h" #include + #include + +#include "testlayers.h" using namespace sd; class LoopCoordsHelper : public testing::Test { -public: - + public: }; - -template -FORCEINLINE -typename std::enable_if<(Rank - 1 == rankIndex), bool>::type +template +FORCEINLINE typename std::enable_if<(Rank - 1 == rankIndex), bool>::type eq_strides(CoordsState& cbs, const Nd4jLong* strides) { - return STRIDE(cbs, rankIndex) == strides[rankIndex]; + return STRIDE(cbs, rankIndex) == strides[rankIndex]; } -template -FORCEINLINE -typename std::enable_if<(Rank - 1 != rankIndex), bool>::type +template +FORCEINLINE typename std::enable_if<(Rank - 1 != rankIndex), bool>::type eq_strides(CoordsState& cbs, const Nd4jLong* strides) { - return STRIDE(cbs, rankIndex) == strides[rankIndex] && eq_strides(cbs, strides); + return STRIDE(cbs, rankIndex) == strides[rankIndex] && + eq_strides(cbs, strides); } -template -FORCEINLINE -typename std::enable_if<(Rank - 1 == rankIndex), bool>::type -eq_zip_strides(ZipCoordsState& cbs, const Nd4jLong* strides1, const Nd4jLong* strides2) { - return ZIP_STRIDE1(cbs, rankIndex) == strides1[rankIndex] && ZIP_STRIDE2(cbs, rankIndex) == strides2[rankIndex]; +template +FORCEINLINE typename std::enable_if<(Rank - 1 == rankIndex), bool>::type +eq_zip_strides(ZipCoordsState& cbs, const Nd4jLong* strides1, + const Nd4jLong* strides2) { + return ZIP_STRIDE1(cbs, rankIndex) == strides1[rankIndex] && + ZIP_STRIDE2(cbs, rankIndex) == strides2[rankIndex]; } -template -FORCEINLINE -typename std::enable_if<(Rank - 1 != rankIndex), bool>::type -eq_zip_strides(ZipCoordsState& cbs, const Nd4jLong* strides1, const Nd4jLong* strides2) { - return ZIP_STRIDE1(cbs, rankIndex) == strides1[rankIndex] && ZIP_STRIDE2(cbs, rankIndex) == strides2[rankIndex] - && eq_zip_strides(cbs, strides1, strides2); +template +FORCEINLINE typename std::enable_if<(Rank - 1 != rankIndex), bool>::type +eq_zip_strides(ZipCoordsState& cbs, const Nd4jLong* strides1, + const Nd4jLong* strides2) { + return ZIP_STRIDE1(cbs, rankIndex) == strides1[rankIndex] && + ZIP_STRIDE2(cbs, rankIndex) == strides2[rankIndex] && + eq_zip_strides(cbs, strides1, strides2); } - - - TEST_F(LoopCoordsHelper, Init_Tests) { + constexpr size_t test_Index = 131; + constexpr size_t Rank = 5; - constexpr size_t test_Index = 131; - constexpr size_t Rank = 5; + Nd4jLong shape[Rank] = {3, 5, 7, 8, 9}; + Nd4jLong multiply_st[] = {2, 3, 3, 5, 6, 7, 9, 3}; + Nd4jLong strides_c[Rank]; + Nd4jLong strides_f[Rank]; - Nd4jLong shape[Rank] = { 3, 5 ,7, 8, 9}; - Nd4jLong multiply_st[] = { 2,3,3,5,6,7,9,3 }; - Nd4jLong strides_c[Rank] ; - Nd4jLong strides_f[Rank]; + Nd4jLong coords[Rank]; + Nd4jLong coords_f[Rank]; - Nd4jLong coords[Rank]; - Nd4jLong coords_f[Rank]; + strides_f[0] = multiply_st[0] * shape[0]; + strides_c[Rank - 1] = multiply_st[Rank - 1] * shape[Rank - 1]; - strides_f[0] = multiply_st[0] * shape[0]; - strides_c[Rank-1] = multiply_st[Rank-1] * shape[Rank-1]; + for (int i = 1; i < Rank; i++) { + strides_f[i] = strides_f[i - 1] * multiply_st[i] * shape[i]; + } - for (int i = 1; i < Rank; i++) { - strides_f[i] = strides_f[i - 1] * multiply_st[i] * shape[i]; - } + for (int i = Rank - 2; i >= 0; i--) { + strides_c[i] = strides_c[i + 1] * multiply_st[i] * shape[i]; + } - for (int i = Rank-2; i >=0; i--) { - strides_c[i] = strides_c[i+1] * multiply_st[i] * shape[i]; - } + // init our base coords + index2coords_C(test_Index, Rank, shape, coords); + index2coords_F(test_Index, Rank, shape, coords_f); - //init our base coords - index2coords_C(test_Index, Rank, shape, coords); - index2coords_F(test_Index, Rank, shape, coords_f); + size_t offset_calc = offset_from_coords(strides_c, coords, Rank); + size_t offset_calc_f = offset_from_coords(strides_f, coords_f, Rank); + CoordsState cts; + CoordsState cts_f; - size_t offset_calc = offset_from_coords(strides_c, coords, Rank); - size_t offset_calc_f = offset_from_coords(strides_f, coords_f, Rank); - - CoordsState cts; - CoordsState cts_f; + ZipCoordsState zcts; + ZipCoordsState zcts_f; - ZipCoordsState zcts; - ZipCoordsState zcts_f; + size_t offset = init_coords(cts, test_Index, shape, strides_c); + size_t offset_f = + init_coords(cts_f, test_Index, shape, strides_f); - size_t offset = init_coords(cts, test_Index, shape, strides_c); - size_t offset_f = init_coords(cts_f, test_Index, shape, strides_f); - - zip_size_t zoffset = init_coords(zcts, test_Index, shape, strides_c, strides_c); - zip_size_t zoffset_f = init_coords(zcts_f, test_Index, shape, strides_f, strides_f); - - ASSERT_TRUE(eq_coords(cts, coords)); - ASSERT_TRUE(eq_coords(cts_f, coords_f)); + zip_size_t zoffset = + init_coords(zcts, test_Index, shape, strides_c, strides_c); + zip_size_t zoffset_f = init_coords(zcts_f, test_Index, shape, + strides_f, strides_f); - ASSERT_TRUE(eq_zip_coords(zcts, coords)); - ASSERT_TRUE(eq_zip_coords(zcts_f, coords_f)); + ASSERT_TRUE(eq_coords(cts, coords)); + ASSERT_TRUE(eq_coords(cts_f, coords_f)); - ASSERT_TRUE(eq_strides(cts,strides_c)); - ASSERT_TRUE(eq_strides(cts_f,strides_f)); + ASSERT_TRUE(eq_zip_coords(zcts, coords)); + ASSERT_TRUE(eq_zip_coords(zcts_f, coords_f)); - ASSERT_TRUE(eq_zip_strides(zcts, strides_c, strides_c)); - ASSERT_TRUE(eq_zip_strides(zcts_f, strides_f, strides_f)); + ASSERT_TRUE(eq_strides(cts, strides_c)); + ASSERT_TRUE(eq_strides(cts_f, strides_f)); + ASSERT_TRUE(eq_zip_strides(zcts, strides_c, strides_c)); + ASSERT_TRUE(eq_zip_strides(zcts_f, strides_f, strides_f)); - ASSERT_EQ(offset , offset_calc); - ASSERT_EQ(zoffset.first , offset_calc); - ASSERT_EQ(zoffset.second , offset_calc); - ASSERT_EQ(offset_f , offset_calc_f); - ASSERT_EQ(zoffset_f.first , offset_calc_f); - ASSERT_EQ(zoffset_f.second , offset_calc_f); + ASSERT_EQ(offset, offset_calc); + ASSERT_EQ(zoffset.first, offset_calc); + ASSERT_EQ(zoffset.second, offset_calc); + ASSERT_EQ(offset_f, offset_calc_f); + ASSERT_EQ(zoffset_f.first, offset_calc_f); + ASSERT_EQ(zoffset_f.second, offset_calc_f); } - -TEST_F(LoopCoordsHelper, Increment_Use_Tests) { +TEST_F(LoopCoordsHelper, Increment_Use_Tests) { + constexpr size_t Rank = 4; - constexpr size_t Rank = 4; - - Nd4jLong shape[Rank] = { 3, 5 ,7, 8 }; - Nd4jLong multiply_st[] = { 2,3,3,5,6,7,9,3 }; - Nd4jLong strides_c[Rank]; - Nd4jLong strides_f[Rank]; - - Nd4jLong coords[Rank] = {}; - Nd4jLong coords_f[Rank] = {}; - Nd4jLong coords2[Rank] = {}; - Nd4jLong coords2_f[Rank] = {}; - Nd4jLong zcoords2[Rank] = {}; - Nd4jLong zcoords2_f[Rank] = {}; - - strides_f[0] = multiply_st[0] * shape[0]; - strides_c[Rank - 1] = multiply_st[Rank - 1] * shape[Rank - 1]; - - for (int i = 1; i < Rank; i++) { - strides_f[i] = strides_f[i - 1] * multiply_st[i] * shape[i]; - } - - for (int i = Rank - 2; i >= 0; i--) { - strides_c[i] = strides_c[i + 1] * multiply_st[i] * shape[i]; - } - - int total = 1; - for (int i = 0; i < Rank; i++) { - total *= shape[i]; - } - - CoordsState cts; - CoordsState cts_f; - - ZipCoordsState zcts; - ZipCoordsState zcts_f; - - size_t offset = init_coords(cts, 0, shape, strides_c); - size_t offset_f = init_coords(cts_f, 0, shape, strides_f); + Nd4jLong shape[Rank] = {3, 5, 7, 8}; + Nd4jLong multiply_st[] = {2, 3, 3, 5, 6, 7, 9, 3}; + Nd4jLong strides_c[Rank]; + Nd4jLong strides_f[Rank]; - zip_size_t zoffset = init_coords(zcts, 0, shape, strides_c, strides_c); - zip_size_t zoffset_f = init_coords(zcts_f, 0, shape, strides_f, strides_f); + Nd4jLong coords[Rank] = {}; + Nd4jLong coords_f[Rank] = {}; + Nd4jLong coords2[Rank] = {}; + Nd4jLong coords2_f[Rank] = {}; + Nd4jLong zcoords2[Rank] = {}; + Nd4jLong zcoords2_f[Rank] = {}; - size_t offset2 = 0; - size_t offset2_f = 0; - zip_size_t zoffset2 = {}; - zip_size_t zoffset2_f = {}; + strides_f[0] = multiply_st[0] * shape[0]; + strides_c[Rank - 1] = multiply_st[Rank - 1] * shape[Rank - 1]; - for (int j = 0; j < total; j++) { + for (int i = 1; i < Rank; i++) { + strides_f[i] = strides_f[i - 1] * multiply_st[i] * shape[i]; + } + for (int i = Rank - 2; i >= 0; i--) { + strides_c[i] = strides_c[i + 1] * multiply_st[i] * shape[i]; + } - index2coords_C(j, Rank, shape, coords); - index2coords_F(j, Rank, shape, coords_f); + int total = 1; + for (int i = 0; i < Rank; i++) { + total *= shape[i]; + } - size_t offset_calc = offset_from_coords(strides_c, coords, Rank); - size_t offset_calc_f = offset_from_coords(strides_f, coords_f, Rank); + CoordsState cts; + CoordsState cts_f; + ZipCoordsState zcts; + ZipCoordsState zcts_f; - ASSERT_TRUE(eq_coords(cts, coords)); - ASSERT_TRUE(eq_coords(cts_f, coords_f)); + size_t offset = init_coords(cts, 0, shape, strides_c); + size_t offset_f = init_coords(cts_f, 0, shape, strides_f); - ASSERT_TRUE(eq_zip_coords(zcts, coords)); - ASSERT_TRUE(eq_zip_coords(zcts_f, coords_f)); + zip_size_t zoffset = init_coords(zcts, 0, shape, strides_c, strides_c); + zip_size_t zoffset_f = + init_coords(zcts_f, 0, shape, strides_f, strides_f); - ASSERT_EQ(offset, offset_calc); - ASSERT_EQ(zoffset.first, offset_calc); - ASSERT_EQ(zoffset.second, offset_calc); - ASSERT_EQ(offset_f, offset_calc_f); - ASSERT_EQ(zoffset_f.first, offset_calc_f); - ASSERT_EQ(zoffset_f.second, offset_calc_f); + size_t offset2 = 0; + size_t offset2_f = 0; + zip_size_t zoffset2 = {}; + zip_size_t zoffset2_f = {}; + for (int j = 0; j < total; j++) { + index2coords_C(j, Rank, shape, coords); + index2coords_F(j, Rank, shape, coords_f); - ASSERT_EQ(offset2, offset_calc); - ASSERT_EQ(zoffset2.first, offset_calc); - ASSERT_EQ(zoffset2.second, offset_calc); - ASSERT_EQ(offset2_f, offset_calc_f); - ASSERT_EQ(zoffset2_f.first, offset_calc_f); - ASSERT_EQ(zoffset2_f.second, offset_calc_f); + size_t offset_calc = offset_from_coords(strides_c, coords, Rank); + size_t offset_calc_f = offset_from_coords(strides_f, coords_f, Rank); - offset = inc_coords(cts, offset); - offset_f = inc_coords(cts_f, offset_f); - zoffset = inc_coords(zcts, zoffset); - zoffset_f = inc_coords(zcts_f, zoffset_f); + ASSERT_TRUE(eq_coords(cts, coords)); + ASSERT_TRUE(eq_coords(cts_f, coords_f)); - offset2 = inc_coords(shape,strides_c, coords2, offset2, Rank); - offset2_f = inc_coords(shape, strides_f, coords2_f, offset2_f, Rank); - zoffset2 = inc_coords(shape, strides_c, strides_c, zcoords2, zoffset2, Rank); - zoffset2_f = inc_coords(shape, strides_f, strides_f, zcoords2_f, zoffset2_f, Rank); + ASSERT_TRUE(eq_zip_coords(zcts, coords)); + ASSERT_TRUE(eq_zip_coords(zcts_f, coords_f)); - } - + ASSERT_EQ(offset, offset_calc); + ASSERT_EQ(zoffset.first, offset_calc); + ASSERT_EQ(zoffset.second, offset_calc); + ASSERT_EQ(offset_f, offset_calc_f); + ASSERT_EQ(zoffset_f.first, offset_calc_f); + ASSERT_EQ(zoffset_f.second, offset_calc_f); + + ASSERT_EQ(offset2, offset_calc); + ASSERT_EQ(zoffset2.first, offset_calc); + ASSERT_EQ(zoffset2.second, offset_calc); + ASSERT_EQ(offset2_f, offset_calc_f); + ASSERT_EQ(zoffset2_f.first, offset_calc_f); + ASSERT_EQ(zoffset2_f.second, offset_calc_f); + + offset = inc_coords(cts, offset); + offset_f = inc_coords(cts_f, offset_f); + zoffset = inc_coords(zcts, zoffset); + zoffset_f = inc_coords(zcts_f, zoffset_f); + + offset2 = inc_coords(shape, strides_c, coords2, offset2, Rank); + offset2_f = inc_coords(shape, strides_f, coords2_f, offset2_f, Rank); + zoffset2 = + inc_coords(shape, strides_c, strides_c, zcoords2, zoffset2, Rank); + zoffset2_f = inc_coords(shape, strides_f, strides_f, zcoords2_f, + zoffset2_f, Rank); + } } - diff --git a/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp b/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp index 5202fc3a1bc1..e381702db92b 100644 --- a/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ManagedDataBufferTests.cpp @@ -18,40 +18,44 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include +#include #include #include +#include + #include -#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class ManagedDataBufferTests : public testing::Test { -public: - ManagedDataBufferTests() { - /// - } + public: + ManagedDataBufferTests() { + /// + } }; TEST_F(ManagedDataBufferTests, basic_constructor_test_1) { - GraphMemoryManager mgr; - auto mdb = std::make_shared(mgr, 0, DataType::FLOAT32, memory::MemoryZone::HOT); + GraphMemoryManager mgr; + auto mdb = std::make_shared(mgr, 0, DataType::FLOAT32, + memory::MemoryZone::HOT); - NDArray array(mdb, 'c', {0}); + NDArray array(mdb, 'c', {0}); } TEST_F(ManagedDataBufferTests, basic_constructor_test_2) { - auto exp = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto exp = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - GraphMemoryManager mgr; - auto mdb = std::make_shared(mgr, 20, DataType::FLOAT32, memory::MemoryZone::HOT); + GraphMemoryManager mgr; + auto mdb = std::make_shared(mgr, 20, DataType::FLOAT32, + memory::MemoryZone::HOT); - ASSERT_NE(nullptr, mdb->platform()); + ASSERT_NE(nullptr, mdb->platform()); - NDArray array(mdb, 'c', {5}); - array.assign(1.0f); + NDArray array(mdb, 'c', {5}); + array.assign(1.0f); - ASSERT_EQ(exp, array); + ASSERT_EQ(exp, array); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/MemoryUtilsTests.cpp b/libnd4j/tests_cpu/layers_tests/MemoryUtilsTests.cpp index 4bfe40405012..bee1e37355cf 100644 --- a/libnd4j/tests_cpu/layers_tests/MemoryUtilsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MemoryUtilsTests.cpp @@ -20,27 +20,24 @@ #include #include + #include "testlayers.h" using namespace sd::memory; class MemoryUtilsTests : public testing::Test { -public: - + public: }; TEST_F(MemoryUtilsTests, BasicRetrieve_1) { - MemoryReport reportA; - MemoryReport reportB; + MemoryReport reportA; + MemoryReport reportB; #ifdef _WIN32 - if (1 > 0) - return; + if (1 > 0) return; #endif + MemoryUtils::retrieveMemoryStatistics(reportA); - MemoryUtils::retrieveMemoryStatistics(reportA); - - - ASSERT_NE(reportA, reportB); + ASSERT_NE(reportA, reportB); } diff --git a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp index c91c1c5c7811..8c95507df314 100644 --- a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp @@ -14,96 +14,99 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // // @author raver119@gmail.com // #ifdef HAVE_MKLDNN -#include "testlayers.h" -#include -#include -#include #include #include +#include +#include +#include + +#include "testlayers.h" using namespace sd; class MklDnnTests : public testing::Test { -public: - + public: }; -static void printer(std::initializer_list helpers) { - - for (auto v:helpers) { - nd4j_printf("Initialized [%s]\n", v->name().c_str()); - } +static void printer( + std::initializer_list helpers) { + for (auto v : helpers) { + nd4j_printf("Initialized [%s]\n", v->name().c_str()); + } } - TEST_F(MklDnnTests, helpers_includer) { - // we need this block, to make sure all helpers are still available within binary, and not optimized out by linker - sd::ops::platforms::PLATFORM_conv2d_ENGINE_CPU conv2d; - sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CPU conv2d_bp; + // we need this block, to make sure all helpers are still available within + // binary, and not optimized out by linker + sd::ops::platforms::PLATFORM_conv2d_ENGINE_CPU conv2d; + sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CPU conv2d_bp; - sd::ops::platforms::PLATFORM_conv2d_ENGINE_CPU conv3d; - sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CPU conv3d_bp; + sd::ops::platforms::PLATFORM_conv2d_ENGINE_CPU conv3d; + sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CPU conv3d_bp; - sd::ops::platforms::PLATFORM_avgpool2d_ENGINE_CPU avgpool2d; - sd::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CPU avgpool2d_bp; + sd::ops::platforms::PLATFORM_avgpool2d_ENGINE_CPU avgpool2d; + sd::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CPU avgpool2d_bp; - sd::ops::platforms::PLATFORM_maxpool2d_ENGINE_CPU maxpool2d; - sd::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CPU maxpool2d_bp; + sd::ops::platforms::PLATFORM_maxpool2d_ENGINE_CPU maxpool2d; + sd::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CPU maxpool2d_bp; - sd::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CPU avgpool3d; - sd::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CPU avgpool3d_bp; + sd::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CPU avgpool3d; + sd::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CPU avgpool3d_bp; - sd::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CPU maxpool3d; - sd::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CPU maxpool3d_bp; + sd::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CPU maxpool3d; + sd::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CPU maxpool3d_bp; - sd::ops::platforms::PLATFORM_lrn_ENGINE_CPU lrn; + sd::ops::platforms::PLATFORM_lrn_ENGINE_CPU lrn; - sd::ops::platforms::PLATFORM_batchnorm_ENGINE_CPU batchnorm; + sd::ops::platforms::PLATFORM_batchnorm_ENGINE_CPU batchnorm; - sd::ops::platforms::PLATFORM_matmul_ENGINE_CPU matmul; + sd::ops::platforms::PLATFORM_matmul_ENGINE_CPU matmul; - sd::ops::platforms::PLATFORM_softmax_ENGINE_CPU softmax; + sd::ops::platforms::PLATFORM_softmax_ENGINE_CPU softmax; - sd::ops::platforms::PLATFORM_softmax_bp_ENGINE_CPU softmax_bp; + sd::ops::platforms::PLATFORM_softmax_bp_ENGINE_CPU softmax_bp; - sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh; + sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh; - sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh_bp; - - sd::ops::platforms::PLATFORM_xw_plus_b_ENGINE_CPU xw_plus_b; + sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh_bp; - sd::ops::platforms::PLATFORM_xw_plus_b_bp_ENGINE_CPU xw_plus_b_bp; + sd::ops::platforms::PLATFORM_xw_plus_b_ENGINE_CPU xw_plus_b; + sd::ops::platforms::PLATFORM_xw_plus_b_bp_ENGINE_CPU xw_plus_b_bp; - printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh, &tanh_bp, &xw_plus_b, &xw_plus_b_bp }); + printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, + &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, + &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, + &lrn, &batchnorm, &matmul, &softmax, + &softmax_bp, &tanh, &tanh_bp, &xw_plus_b, + &xw_plus_b_bp}); } TEST_F(MklDnnTests, test_tanh_1) { - auto x = NDArrayFactory::create(1.0f); - auto z = NDArrayFactory::create(0.0f); + auto x = NDArrayFactory::create(1.0f); + auto z = NDArrayFactory::create(0.0f); - sd::ops::tanh op; - auto status = op.execute({&x}, {&z}); + sd::ops::tanh op; + auto status = op.execute({&x}, {&z}); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(Status::OK(), status); } TEST_F(MklDnnTests, test_tanh_2) { - auto x = NDArrayFactory::create('c', {1}, {1.0f}); - auto z = NDArrayFactory::create('c', {1}, {0.0f}); + auto x = NDArrayFactory::create('c', {1}, {1.0f}); + auto z = NDArrayFactory::create('c', {1}, {0.0f}); - sd::ops::tanh op; - auto status = op.execute({&x}, {&z}); + sd::ops::tanh op; + auto status = op.execute({&x}, {&z}); - ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(Status::OK(), status); } #endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/MmapTests.cpp b/libnd4j/tests_cpu/layers_tests/MmapTests.cpp index c1df42fd173e..4318eedbc0eb 100644 --- a/libnd4j/tests_cpu/layers_tests/MmapTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MmapTests.cpp @@ -18,38 +18,38 @@ // Created by raver on 5/13/2018. // -#include "testlayers.h" -#include #include #include +#include + #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class MmapTests : public testing::Test { -public: - + public: }; TEST_F(MmapTests, Test_Basic_Mmap_1) { - // FIXME: we must adopt this for CUDA as well - if (!Environment::getInstance()->isCPU()) - return; + // FIXME: we must adopt this for CUDA as well + if (!Environment::getInstance()->isCPU()) return; - // just 10GB - Nd4jLong size = 100000L; + // just 10GB + Nd4jLong size = 100000L; - std::ofstream ofs("file", std::ios::binary | std::ios::out); - ofs.seekp(size + 1024L); - ofs.write("", 1); - ofs.close(); + std::ofstream ofs("file", std::ios::binary | std::ios::out); + ofs.seekp(size + 1024L); + ofs.write("", 1); + ofs.close(); - auto result = mmapFile(nullptr, "file", size); + auto result = mmapFile(nullptr, "file", size); - ASSERT_FALSE(result == nullptr); + ASSERT_FALSE(result == nullptr); - munmapFile(nullptr, result, size); + munmapFile(nullptr, result, size); - remove("file"); + remove("file"); } diff --git a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp index 94325c50dfb7..0652d9fd4c7a 100644 --- a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp @@ -18,1752 +18,1760 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include #include #include -#include #include +#include +#include "testlayers.h" using namespace sd; class MultiDataTypeTests : public testing::Test { -public: - + public: }; //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, DataTypeUtils_Test_1) { - auto dtype = DataTypeUtils::pickPairwiseResultType(sd::INT32, sd::FLOAT32); + auto dtype = DataTypeUtils::pickPairwiseResultType(sd::INT32, sd::FLOAT32); - ASSERT_EQ(sd::FLOAT32, dtype); + ASSERT_EQ(sd::FLOAT32, dtype); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, DataTypeUtils_Test_2) { - auto dtype = DataTypeUtils::pickPairwiseResultType(sd::INT32, sd::DOUBLE); - ASSERT_EQ(sd::DOUBLE, dtype); + auto dtype = DataTypeUtils::pickPairwiseResultType(sd::INT32, sd::DOUBLE); + ASSERT_EQ(sd::DOUBLE, dtype); - ASSERT_EQ(sd::DOUBLE, DataTypeUtils::pickPairwiseResultType(sd::DOUBLE, sd::INT32)); + ASSERT_EQ(sd::DOUBLE, + DataTypeUtils::pickPairwiseResultType(sd::DOUBLE, sd::INT32)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, DataTypeUtils_Test_3) { - auto dtype = DataTypeUtils::pickPairwiseResultType(sd::FLOAT32, sd::DOUBLE); - ASSERT_EQ(sd::FLOAT32, dtype); + auto dtype = DataTypeUtils::pickPairwiseResultType(sd::FLOAT32, sd::DOUBLE); + ASSERT_EQ(sd::FLOAT32, dtype); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, Basic_Test_1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - auto y = NDArrayFactory::create('c', {2, 3}, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}); - auto e = NDArrayFactory::create('c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); + auto x = NDArrayFactory::create('c', {2, 3}, + {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + auto y = NDArrayFactory::create('c', {2, 3}, + {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}); + auto e = NDArrayFactory::create('c', {2, 3}, + {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); - auto z = x + y; + auto z = x + y; - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, Basic_Test_2) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - auto y = NDArrayFactory::create(2.0); - auto e = NDArrayFactory::create('c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); + auto x = NDArrayFactory::create('c', {2, 3}, + {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + auto y = NDArrayFactory::create(2.0); + auto e = NDArrayFactory::create('c', {2, 3}, + {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); - auto z = x * y; + auto z = x * y; - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, Basic_Test_3) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create(2.0); - auto e = NDArrayFactory::create('c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); + auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create(2.0); + auto e = NDArrayFactory::create( + 'c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); - auto z = x * y; + auto z = x * y; - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, Basic_Test_4) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - auto y = NDArrayFactory::create(2.0); - auto e = NDArrayFactory::create('c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); + auto x = NDArrayFactory::create('c', {2, 3}, + {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + auto y = NDArrayFactory::create(2.0); + auto e = NDArrayFactory::create( + 'c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); - auto z = x * y; + auto z = x * y; - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, Basic_Test_5) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create(2); - auto e = NDArrayFactory::create('c', {2, 3}, {0, 2, 4, 6, 8, 10}); + auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create(2); + auto e = NDArrayFactory::create('c', {2, 3}, {0, 2, 4, 6, 8, 10}); - auto z = x * y; + auto z = x * y; - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(MultiDataTypeTests, Basic_Test_7) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {2, 3}, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f}); - auto e = NDArrayFactory::create('c', {2, 3}, {0.f, 2.f, 4.f, 6.f, 8.f, 10.f}); + auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {2, 3}, + {0.f, 1.f, 2.f, 3.f, 4.f, 5.f}); + auto e = NDArrayFactory::create('c', {2, 3}, + {0.f, 2.f, 4.f, 6.f, 8.f, 10.f}); - sd::ops::add op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::add op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, Basic_Test_6) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create(2); - auto e = NDArrayFactory::create('c', {2, 3}, {0, 2, 4, 6, 8, 10}); + auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create(2); + auto e = NDArrayFactory::create('c', {2, 3}, {0, 2, 4, 6, 8, 10}); - auto z = x * y; + auto z = x * y; - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_assign_number_test1) { - NDArray x('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::UINT8); - NDArray exp('c', {2, 3}, {10, 10, 10, 10, 10, 10}, sd::DataType::UINT8); + NDArray x('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::UINT8); + NDArray exp('c', {2, 3}, {10, 10, 10, 10, 10, 10}, sd::DataType::UINT8); - const double number = 10.8; - x = number; + const double number = 10.8; + x = number; - ASSERT_EQ(x,exp); + ASSERT_EQ(x, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_assign_number_test2) { - NDArray x('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT64); - NDArray exp('c', {2, 3}, {1, 1, 1, 1, 1, 1}, sd::DataType::INT64); + NDArray x('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT64); + NDArray exp('c', {2, 3}, {1, 1, 1, 1, 1, 1}, sd::DataType::INT64); - const bool number = 1000; - x = number; + const bool number = 1000; + x = number; - ASSERT_EQ(x,exp); + ASSERT_EQ(x, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_assign_number_test3) { - NDArray x('c', {2, 3}, {0, 1, 0, 1, 0, 1}, sd::DataType::BOOL); - NDArray exp('c', {2, 3}, {1, 1, 1, 1, 1, 1}, sd::DataType::BOOL); + NDArray x('c', {2, 3}, {0, 1, 0, 1, 0, 1}, sd::DataType::BOOL); + NDArray exp('c', {2, 3}, {1, 1, 1, 1, 1, 1}, sd::DataType::BOOL); - const int number = 1000; - x = number; + const int number = 1000; + x = number; - ASSERT_EQ(x,exp); + ASSERT_EQ(x, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_repeat_test1) { - NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray y('c', {2, 4}, sd::DataType::HALF); - NDArray exp('c', {2, 4}, {0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5}, sd::DataType::HALF); + NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray y('c', {2, 4}, sd::DataType::HALF); + NDArray exp('c', {2, 4}, {0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5}, + sd::DataType::HALF); - x.repeat(1, {2}, y); + x.repeat(1, {2}, y); - ASSERT_EQ(y, exp); + ASSERT_EQ(y, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_bufferAsT_test1) { - NDArray x('f', {2}, {1.5, 3.5}, sd::DataType::FLOAT32); - NDArray y('c', {}, std::vector{1.5}, sd::DataType::FLOAT32); + NDArray x('f', {2}, {1.5, 3.5}, sd::DataType::FLOAT32); + NDArray y('c', {}, std::vector{1.5}, sd::DataType::FLOAT32); - const int* buffX = x.bufferAsT(); - const int* buffY = y.bufferAsT(); + const int* buffX = x.bufferAsT(); + const int* buffY = y.bufferAsT(); - ASSERT_EQ(*buffX, *buffY); + ASSERT_EQ(*buffX, *buffY); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_assign_test1) { - NDArray x('c', {2,2}, {0, 1, 2, 3}, sd::DataType::UINT8); - NDArray exp('c', {2,2}, {10, 10, 20, 20}, sd::DataType::UINT8); + NDArray x('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::UINT8); + NDArray exp('c', {2, 2}, {10, 10, 20, 20}, sd::DataType::UINT8); - NDArray scalar1('c', {}, std::vector{10.5}, sd::DataType::FLOAT32); - NDArray scalar2('c', {}, std::vector{20.8}, sd::DataType::DOUBLE); + NDArray scalar1('c', {}, std::vector{10.5}, sd::DataType::FLOAT32); + NDArray scalar2('c', {}, std::vector{20.8}, sd::DataType::DOUBLE); - x(0,{0}).assign(scalar1); - x(1,{0}).assign(scalar2); + x(0, {0}).assign(scalar1); + x(1, {0}).assign(scalar2); - ASSERT_EQ(x, exp); + ASSERT_EQ(x, exp); - x.assign(exp); + x.assign(exp); - ASSERT_EQ(x, exp); + ASSERT_EQ(x, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test1) { - NDArray x('f', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray exp1('c', {}, std::vector{3}, sd::DataType::INT64); - NDArray exp2('c', {1,1}, std::vector{1}, sd::DataType::INT64); - NDArray exp3('c', {2}, std::vector{1,2}, sd::DataType::INT64); + NDArray x('f', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray exp1('c', {}, std::vector{3}, sd::DataType::INT64); + NDArray exp2('c', {1, 1}, std::vector{1}, sd::DataType::INT64); + NDArray exp3('c', {2}, std::vector{1, 2}, sd::DataType::INT64); - auto scalar1 = x.reduceAlongDimension(sd::reduce::CountNonZero, {}/*whole range*/); - ASSERT_EQ(scalar1, exp1); + auto scalar1 = + x.reduceAlongDimension(sd::reduce::CountNonZero, {} /*whole range*/); + ASSERT_EQ(scalar1, exp1); - auto scalar2 = x.reduceAlongDimension(sd::reduce::CountZero, {}/*whole range*/, true); - ASSERT_EQ(scalar2, exp2); + auto scalar2 = + x.reduceAlongDimension(sd::reduce::CountZero, {} /*whole range*/, true); + ASSERT_EQ(scalar2, exp2); - auto scalar3 = x.reduceAlongDimension(sd::reduce::CountNonZero, {1}); - ASSERT_EQ(scalar3, exp3); + auto scalar3 = x.reduceAlongDimension(sd::reduce::CountNonZero, {1}); + ASSERT_EQ(scalar3, exp3); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test2) { - NDArray x('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); - NDArray exp1('c', {}, std::vector{1.5}, sd::DataType::FLOAT32); - NDArray exp2('c', {2}, {0.5,2.5}, sd::DataType::FLOAT32); + NDArray x('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray exp1('c', {}, std::vector{1.5}, sd::DataType::FLOAT32); + NDArray exp2('c', {2}, {0.5, 2.5}, sd::DataType::FLOAT32); - auto scalar1 = x.reduceAlongDimension(sd::reduce::Mean, {}/*whole range*/); - // scalar1->printShapeInfo(); - // scalar1->printIndexedBuffer(); - ASSERT_EQ(scalar1, exp1); + auto scalar1 = x.reduceAlongDimension(sd::reduce::Mean, {} /*whole range*/); + // scalar1->printShapeInfo(); + // scalar1->printIndexedBuffer(); + ASSERT_EQ(scalar1, exp1); - auto scalar2 = x.reduceAlongDimension(sd::reduce::Mean, {1}); - ASSERT_EQ(scalar2, exp2); + auto scalar2 = x.reduceAlongDimension(sd::reduce::Mean, {1}); + ASSERT_EQ(scalar2, exp2); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test3) { - NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray exp1('c', {}, std::vector{8.}, sd::DataType::HALF); - NDArray exp2('c', {2}, {2.,6.}, sd::DataType::HALF); + NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray exp1('c', {}, std::vector{8.}, sd::DataType::HALF); + NDArray exp2('c', {2}, {2., 6.}, sd::DataType::HALF); - auto scalar1 = x.reduceAlongDimension(sd::reduce::Sum, {}/*whole range*/); - ASSERT_EQ(scalar1, exp1); + auto scalar1 = x.reduceAlongDimension(sd::reduce::Sum, {} /*whole range*/); + ASSERT_EQ(scalar1, exp1); - auto scalar2 = x.reduceAlongDimension(sd::reduce::Sum, {1}); - ASSERT_EQ(scalar2, exp2); + auto scalar2 = x.reduceAlongDimension(sd::reduce::Sum, {1}); + ASSERT_EQ(scalar2, exp2); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test4) { - NDArray x('c', {2, 2}, {10.5, 1.5, -2.5, -3.5}, sd::DataType::HALF); - NDArray exp1('c', {}, std::vector{1}, sd::DataType::BOOL); - NDArray exp2('c', {2}, std::vector{1, 0}, sd::DataType::BOOL); + NDArray x('c', {2, 2}, {10.5, 1.5, -2.5, -3.5}, sd::DataType::HALF); + NDArray exp1('c', {}, std::vector{1}, sd::DataType::BOOL); + NDArray exp2('c', {2}, std::vector{1, 0}, sd::DataType::BOOL); - auto scalar1 = x.reduceAlongDimension(sd::reduce::IsPositive, {}/*whole range*/); - ASSERT_EQ(scalar1, exp1); + auto scalar1 = + x.reduceAlongDimension(sd::reduce::IsPositive, {} /*whole range*/); + ASSERT_EQ(scalar1, exp1); - auto scalar2 = x.reduceAlongDimension(sd::reduce::IsPositive, {1}); - ASSERT_EQ(scalar2, exp2); + auto scalar2 = x.reduceAlongDimension(sd::reduce::IsPositive, {1}); + ASSERT_EQ(scalar2, exp2); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_varianceNumber_test1) { - NDArray x('f', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray exp1('c', {}, std::vector{1.666666667}, sd::DataType::FLOAT32); - NDArray exp2('c', {}, std::vector{1.118033989}, sd::DataType::FLOAT32); + NDArray x('f', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray exp1('c', {}, std::vector{1.666666667}, + sd::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{1.118033989}, + sd::DataType::FLOAT32); - auto scalar1 = x.varianceNumber(variance::SummaryStatsVariance); - ASSERT_EQ(scalar1, exp1); + auto scalar1 = x.varianceNumber(variance::SummaryStatsVariance); + ASSERT_EQ(scalar1, exp1); - auto scalar2 = x.varianceNumber(variance::SummaryStatsStandardDeviation, false); - ASSERT_EQ(scalar2, exp2); + auto scalar2 = + x.varianceNumber(variance::SummaryStatsStandardDeviation, false); + ASSERT_EQ(scalar2, exp2); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorPlus_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2, 2}, {-1, -2, -1, -2}, sd::DataType::FLOAT32); - NDArray x3('c', {2}, {-1, -2}, sd::DataType::FLOAT32); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {-1, -2, -1, -2}, sd::DataType::FLOAT32); + NDArray x3('c', {2}, {-1, -2}, sd::DataType::FLOAT32); - NDArray exp('c', {2, 2}, {-1, -1, 1, 1}, sd::DataType::FLOAT32); + NDArray exp('c', {2, 2}, {-1, -1, 1, 1}, sd::DataType::FLOAT32); - ASSERT_EQ(x1+x2, exp); - ASSERT_EQ(x1+x3, exp); + ASSERT_EQ(x1 + x2, exp); + ASSERT_EQ(x1 + x3, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorPlus_test2) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); - const double val1 = -2; - const int val2 = -2; - NDArray exp1('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::FLOAT32); - NDArray exp3('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::HALF); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + const double val1 = -2; + const int val2 = -2; + NDArray exp1('c', {2, 2}, {-2, -1, 0, 1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {-2, -1, 0, 1}, sd::DataType::FLOAT32); + NDArray exp3('c', {2, 2}, {-2, -1, 0, 1}, sd::DataType::HALF); - ASSERT_EQ(x1+val1, exp1); - ASSERT_EQ(val1+x1, exp1); + ASSERT_EQ(x1 + val1, exp1); + ASSERT_EQ(val1 + x1, exp1); - ASSERT_EQ(x2+val2, exp2); - ASSERT_EQ(val2+x2, exp2); + ASSERT_EQ(x2 + val2, exp2); + ASSERT_EQ(val2 + x2, exp2); - ASSERT_EQ(x3+val1, exp3); - ASSERT_EQ(val1+x3, exp3); + ASSERT_EQ(x3 + val1, exp3); + ASSERT_EQ(val1 + x3, exp3); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorMinus_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2, 2}, {-1, -2, -1, -2}, sd::DataType::HALF); - NDArray x3('c', {2}, {-1, -2}, sd::DataType::HALF); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {-1, -2, -1, -2}, sd::DataType::HALF); + NDArray x3('c', {2}, {-1, -2}, sd::DataType::HALF); - NDArray exp('c', {2, 2}, {1, 3, 3, 5}, sd::DataType::HALF); + NDArray exp('c', {2, 2}, {1, 3, 3, 5}, sd::DataType::HALF); - ASSERT_EQ(x1-x2, exp); - ASSERT_EQ(x1-x3, exp); + ASSERT_EQ(x1 - x2, exp); + ASSERT_EQ(x1 - x3, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorMinus_test2) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); - const double val1 = 2; - const int val2 = 2; - NDArray exp1('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {2, 1, 0, -1}, sd::DataType::DOUBLE); - NDArray exp3('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::HALF); - NDArray exp5('c', {2,2}, {2, 1, 0, -1}, sd::DataType::FLOAT32); - NDArray exp6('c', {2,2}, {2, 1, 0, -1}, sd::DataType::HALF); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + const double val1 = 2; + const int val2 = 2; + NDArray exp1('c', {2, 2}, {-2, -1, 0, 1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {2, 1, 0, -1}, sd::DataType::DOUBLE); + NDArray exp3('c', {2, 2}, {-2, -1, 0, 1}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {-2, -1, 0, 1}, sd::DataType::HALF); + NDArray exp5('c', {2, 2}, {2, 1, 0, -1}, sd::DataType::FLOAT32); + NDArray exp6('c', {2, 2}, {2, 1, 0, -1}, sd::DataType::HALF); - ASSERT_EQ(x1-val1, exp1); - ASSERT_EQ(val1-x1, exp2); + ASSERT_EQ(x1 - val1, exp1); + ASSERT_EQ(val1 - x1, exp2); - ASSERT_EQ(x2-val2, exp3); - ASSERT_EQ(val2-x2, exp5); + ASSERT_EQ(x2 - val2, exp3); + ASSERT_EQ(val2 - x2, exp5); - ASSERT_EQ(x3-val1, exp4); - ASSERT_EQ(val1-x3, exp6); + ASSERT_EQ(x3 - val1, exp4); + ASSERT_EQ(val1 - x3, exp6); } -//////////////////////////////////////////////////////////////////////////////// multiply +//////////////////////////////////////////////////////////////////////////////// +///multiply TEST_F(MultiDataTypeTests, ndarray_operatorMultiply_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2, 2}, {-1, -2, -1, -2}, sd::DataType::DOUBLE); - NDArray x3('c', {2}, {-1, -2}, sd::DataType::DOUBLE); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {-1, -2, -1, -2}, sd::DataType::DOUBLE); + NDArray x3('c', {2}, {-1, -2}, sd::DataType::DOUBLE); - NDArray exp('c', {2, 2}, {0, -2, -2, -6}, sd::DataType::DOUBLE); + NDArray exp('c', {2, 2}, {0, -2, -2, -6}, sd::DataType::DOUBLE); - ASSERT_EQ(x1*x2, exp); - ASSERT_EQ(x1*x3, exp); + ASSERT_EQ(x1 * x2, exp); + ASSERT_EQ(x1 * x3, exp); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorMultiply_test2) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); - const double val1 = -2; - const int val2 = -2; - NDArray exp1('c', {2,2}, {0, -2, -4, -6}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {0, -2, -4, -6}, sd::DataType::FLOAT32); - NDArray exp3('c', {2,2}, {0, -2, -4, -6}, sd::DataType::HALF); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + const double val1 = -2; + const int val2 = -2; + NDArray exp1('c', {2, 2}, {0, -2, -4, -6}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {0, -2, -4, -6}, sd::DataType::FLOAT32); + NDArray exp3('c', {2, 2}, {0, -2, -4, -6}, sd::DataType::HALF); - ASSERT_EQ(x1*val1, exp1); - ASSERT_EQ(val1*x1, exp1); + ASSERT_EQ(x1 * val1, exp1); + ASSERT_EQ(val1 * x1, exp1); - ASSERT_EQ(x2*val2, exp2); - ASSERT_EQ(val2*x2, exp2); + ASSERT_EQ(x2 * val2, exp2); + ASSERT_EQ(val2 * x2, exp2); - ASSERT_EQ(x3*val1, exp3); - ASSERT_EQ(val1*x3, exp3); + ASSERT_EQ(x3 * val1, exp3); + ASSERT_EQ(val1 * x3, exp3); } - -//////////////////////////////////////////////////////////////////////////////// multiply +//////////////////////////////////////////////////////////////////////////////// +///multiply TEST_F(MultiDataTypeTests, ndarray_operatorDivide_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {4, 1, 2, 3}, sd::DataType::HALF); - NDArray x2('c', {2, 2}, {-1, -2, -1, -9}, sd::DataType::DOUBLE); - NDArray x3('c', {2}, {-1, -2}, sd::DataType::FLOAT32); + NDArray x1('c', {2, 2}, {4, 1, 2, 3}, sd::DataType::HALF); + NDArray x2('c', {2, 2}, {-1, -2, -1, -9}, sd::DataType::DOUBLE); + NDArray x3('c', {2}, {-1, -2}, sd::DataType::FLOAT32); - NDArray exp1('c', {2, 2}, {-4, -0.5, -2, -0.3333333}, sd::DataType::HALF); - NDArray exp2('c', {2, 2}, {-0.25, -2, -0.5, -0.666667}, sd::DataType::HALF); + NDArray exp1('c', {2, 2}, {-4, -0.5, -2, -0.3333333}, sd::DataType::HALF); + NDArray exp2('c', {2, 2}, {-0.25, -2, -0.5, -0.666667}, sd::DataType::HALF); - ASSERT_EQ(x1/x2, exp1); - ASSERT_EQ(x3/x1, exp2); + ASSERT_EQ(x1 / x2, exp1); + ASSERT_EQ(x3 / x1, exp2); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorDivide_test2) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT64); - NDArray x2('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); - NDArray x3('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::HALF); - const double val1 = 2; - const int val2 = -2; - NDArray exp1('c', {2,2}, {0.5, 1, 1.5, 2}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {2, 1, 0.666667, 0.5}, sd::DataType::DOUBLE); - NDArray exp3('c', {2,2}, {0, -1, -1, -2}, sd::DataType::INT64); - NDArray exp4('c', {2,2}, {-2, -1, 0., 0.}, sd::DataType::INT64); - NDArray exp5('c', {2,2}, {-0.5, -1, -1.5, -2}, sd::DataType::FLOAT32); - NDArray exp6('c', {2,2}, {-2, -1, -0.666667, -0.5}, sd::DataType::FLOAT32); - NDArray exp7('c', {2,2}, {0.5, 1, 1.5, 2}, sd::DataType::HALF); - NDArray exp8('c', {2,2}, {2, 1, 0.666667, 0.5}, sd::DataType::HALF); + NDArray x1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::HALF); + const double val1 = 2; + const int val2 = -2; + NDArray exp1('c', {2, 2}, {0.5, 1, 1.5, 2}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {2, 1, 0.666667, 0.5}, sd::DataType::DOUBLE); + NDArray exp3('c', {2, 2}, {0, -1, -1, -2}, sd::DataType::INT64); + NDArray exp4('c', {2, 2}, {-2, -1, 0., 0.}, sd::DataType::INT64); + NDArray exp5('c', {2, 2}, {-0.5, -1, -1.5, -2}, sd::DataType::FLOAT32); + NDArray exp6('c', {2, 2}, {-2, -1, -0.666667, -0.5}, sd::DataType::FLOAT32); + NDArray exp7('c', {2, 2}, {0.5, 1, 1.5, 2}, sd::DataType::HALF); + NDArray exp8('c', {2, 2}, {2, 1, 0.666667, 0.5}, sd::DataType::HALF); - ASSERT_EQ(x1/val1, exp1); - ASSERT_EQ(val1/x1, exp2); + ASSERT_EQ(x1 / val1, exp1); + ASSERT_EQ(val1 / x1, exp2); - ASSERT_EQ(x1/val2, exp3); - ASSERT_EQ(val2/x1, exp4); + ASSERT_EQ(x1 / val2, exp3); + ASSERT_EQ(val2 / x1, exp4); - ASSERT_EQ(x2/val2, exp5); - ASSERT_EQ(val2/x2, exp6); + ASSERT_EQ(x2 / val2, exp5); + ASSERT_EQ(val2 / x2, exp6); - ASSERT_EQ(x3/val1, exp7); - ASSERT_EQ(val1/x3, exp8); + ASSERT_EQ(x3 / val1, exp7); + ASSERT_EQ(val1 / x3, exp8); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorPlusEqual_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray scalar1('c', {0}, std::vector{4}, sd::DataType::INT32); - NDArray scalar2('c', {0}, std::vector{1.5}, sd::DataType::HALF); + NDArray scalar1('c', {0}, std::vector{4}, sd::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{1.5}, sd::DataType::HALF); - NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, sd::DataType::FLOAT32); - NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); - NDArray x5('c', {2,2}, {0, 1, 2, 3}, sd::DataType::HALF); - NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); + NDArray x1('c', {2, 3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, + sd::DataType::FLOAT32); + NDArray x2('c', {3, 2}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); + NDArray x5('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); - NDArray exp1('c', {0}, std::vector{5}, sd::DataType::INT32); - NDArray exp2('c', {0}, std::vector{6.5}, sd::DataType::HALF); - NDArray exp3('c', {3,2}, {11, 22, 33, 44, 55, 66}, sd::DataType::INT64); - NDArray exp4('c', {2,3}, {12.5, 24.5, 36.5, 48.5, 60.5, 72.5}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0.4, 1.5, 2.4, 3.5}, sd::DataType::HALF); + NDArray exp1('c', {0}, std::vector{5}, sd::DataType::INT32); + NDArray exp2('c', {0}, std::vector{6.5}, sd::DataType::HALF); + NDArray exp3('c', {3, 2}, {11, 22, 33, 44, 55, 66}, sd::DataType::INT64); + NDArray exp4('c', {2, 3}, {12.5, 24.5, 36.5, 48.5, 60.5, 72.5}, + sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {0.4, 1.5, 2.4, 3.5}, sd::DataType::HALF); - scalar1 += scalar2; - ASSERT_EQ(scalar1, exp1); + scalar1 += scalar2; + ASSERT_EQ(scalar1, exp1); - scalar2 += scalar1; - ASSERT_EQ(scalar2, exp2); + scalar2 += scalar1; + ASSERT_EQ(scalar2, exp2); - x2 += x1; - ASSERT_EQ(x2, exp3); + x2 += x1; + ASSERT_EQ(x2, exp3); - x1 += x2; - ASSERT_EQ(x1, exp4); + x1 += x2; + ASSERT_EQ(x1, exp4); - x4 += x3; - ASSERT_EQ(x4, exp5); + x4 += x3; + ASSERT_EQ(x4, exp5); - x6 += x5; - ASSERT_EQ(x6, exp5); + x6 += x5; + ASSERT_EQ(x6, exp5); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorPlusEqual_test2) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); - const Nd4jLong val1 = 1; - const float16 val2 = 1.5; - const double val3 = 2.2; + const Nd4jLong val1 = 1; + const float16 val2 = 1.5; + const double val3 = 2.2; - NDArray exp1('c', {2,2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,2}, {1, 2, 3, 4}, sd::DataType::INT32); - NDArray exp3('c', {2,2}, {2.5, 3.5, 4.5, 5.5}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {2, 3, 4.5, 5}, sd::DataType::INT32); - NDArray exp5('c', {2,2}, {4.7, 5.7, 6.7, 7.7}, sd::DataType::FLOAT32); - NDArray exp6('c', {2,2}, {4, 5, 6, 7}, sd::DataType::INT32); + NDArray exp1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray exp3('c', {2, 2}, {2.5, 3.5, 4.5, 5.5}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {2, 3, 4.5, 5}, sd::DataType::INT32); + NDArray exp5('c', {2, 2}, {4.7, 5.7, 6.7, 7.7}, sd::DataType::FLOAT32); + NDArray exp6('c', {2, 2}, {4, 5, 6, 7}, sd::DataType::INT32); - x1 += val1; - ASSERT_EQ(x1, exp1); + x1 += val1; + ASSERT_EQ(x1, exp1); - x2 += val1; - ASSERT_EQ(x2, exp2); + x2 += val1; + ASSERT_EQ(x2, exp2); - x1 += val2; - ASSERT_EQ(x1, exp3); + x1 += val2; + ASSERT_EQ(x1, exp3); - x2 += val2; - ASSERT_EQ(x2, exp4); + x2 += val2; + ASSERT_EQ(x2, exp4); - x1 += val3; - ASSERT_EQ(x1, exp5); + x1 += val3; + ASSERT_EQ(x1, exp5); - x2 += val3; - ASSERT_EQ(x2, exp6); + x2 += val3; + ASSERT_EQ(x2, exp6); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorMinusEqual_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray scalar1('c', {0}, std::vector{4}, sd::DataType::INT32); - NDArray scalar2('c', {0}, std::vector{1.5}, sd::DataType::HALF); + NDArray scalar1('c', {0}, std::vector{4}, sd::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{1.5}, sd::DataType::HALF); - NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, sd::DataType::FLOAT32); - NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); - NDArray x5('c', {2,2}, {0, 1, 2, 3}, sd::DataType::HALF); - NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); + NDArray x1('c', {2, 3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, + sd::DataType::FLOAT32); + NDArray x2('c', {3, 2}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); + NDArray x5('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); - NDArray exp1('c', {0}, std::vector{2}, sd::DataType::INT32); - NDArray exp2('c', {0}, std::vector{-0.5}, sd::DataType::HALF); - NDArray exp3('c', {3,2}, {8, 17, 26, 35, 44, 53}, sd::DataType::INT64); - NDArray exp4('c', {2,3}, {-6.5, -14.5, -22.5, -30.5, -38.5, -46.5}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0.4, -0.5, -1.6, -2.5}, sd::DataType::HALF); + NDArray exp1('c', {0}, std::vector{2}, sd::DataType::INT32); + NDArray exp2('c', {0}, std::vector{-0.5}, sd::DataType::HALF); + NDArray exp3('c', {3, 2}, {8, 17, 26, 35, 44, 53}, sd::DataType::INT64); + NDArray exp4('c', {2, 3}, {-6.5, -14.5, -22.5, -30.5, -38.5, -46.5}, + sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {0.4, -0.5, -1.6, -2.5}, sd::DataType::HALF); - scalar1 -= scalar2; - ASSERT_EQ(scalar1, exp1); + scalar1 -= scalar2; + ASSERT_EQ(scalar1, exp1); - scalar2 -= scalar1; - ASSERT_EQ(scalar2, exp2); + scalar2 -= scalar1; + ASSERT_EQ(scalar2, exp2); - x2 -= x1; - ASSERT_EQ(x2, exp3); + x2 -= x1; + ASSERT_EQ(x2, exp3); - x1 -= x2; - ASSERT_EQ(x1, exp4); + x1 -= x2; + ASSERT_EQ(x1, exp4); - x4 -= x3; - ASSERT_EQ(x4, exp5); + x4 -= x3; + ASSERT_EQ(x4, exp5); - x6 -= x5; - ASSERT_EQ(x6, exp5); + x6 -= x5; + ASSERT_EQ(x6, exp5); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorMinusEqual_test2) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); - const Nd4jLong val1 = 1; - const float16 val2 = 1.5; - const double val3 = 2.2; + const Nd4jLong val1 = 1; + const float16 val2 = 1.5; + const double val3 = 2.2; - NDArray exp1('c', {2,2}, {-1, 0, 1, 2}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,2}, {-1, 0, 1, 2}, sd::DataType::INT32); - NDArray exp3('c', {2,2}, {-2.5, -1.5, -0.5, 0.5}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {-2., -1., 0., 0.}, sd::DataType::INT32); - NDArray exp5('c', {2,2}, {-4.7, -3.7, -2.7, -1.7}, sd::DataType::FLOAT32); - NDArray exp6('c', {2,2}, {-4, -3, -2, -2}, sd::DataType::INT32); + NDArray exp1('c', {2, 2}, {-1, 0, 1, 2}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 2}, {-1, 0, 1, 2}, sd::DataType::INT32); + NDArray exp3('c', {2, 2}, {-2.5, -1.5, -0.5, 0.5}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {-2., -1., 0., 0.}, sd::DataType::INT32); + NDArray exp5('c', {2, 2}, {-4.7, -3.7, -2.7, -1.7}, sd::DataType::FLOAT32); + NDArray exp6('c', {2, 2}, {-4, -3, -2, -2}, sd::DataType::INT32); - x1 -= val1; - ASSERT_EQ(x1, exp1); + x1 -= val1; + ASSERT_EQ(x1, exp1); - x2 -= val1; - ASSERT_EQ(x2, exp2); + x2 -= val1; + ASSERT_EQ(x2, exp2); - x1 -= val2; - ASSERT_EQ(x1, exp3); + x1 -= val2; + ASSERT_EQ(x1, exp3); - x2 -= val2; - ASSERT_EQ(x2, exp4); + x2 -= val2; + ASSERT_EQ(x2, exp4); - x1 -= val3; - ASSERT_EQ(x1, exp5); + x1 -= val3; + ASSERT_EQ(x1, exp5); - x2 -= val3; - ASSERT_EQ(x2, exp6); + x2 -= val3; + ASSERT_EQ(x2, exp6); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorMultiplyEqual_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray scalar1('c', {0}, std::vector{3}, sd::DataType::INT32); - NDArray scalar2('c', {0}, std::vector{2.5}, sd::DataType::HALF); + NDArray scalar1('c', {0}, std::vector{3}, sd::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{2.5}, sd::DataType::HALF); - NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, sd::DataType::FLOAT32); - NDArray x2('c', {3,2}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); - NDArray x5('c', {2,2}, {0, 1, 2, 3}, sd::DataType::HALF); - NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); + NDArray x1('c', {2, 3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, + sd::DataType::FLOAT32); + NDArray x2('c', {3, 2}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); + NDArray x5('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); - NDArray exp1('c', {0}, std::vector{7}, sd::DataType::INT32); - NDArray exp2('c', {0}, std::vector{17.5}, sd::DataType::HALF); - NDArray exp3('c', {3,2}, {1, 5, 10, 18, 27, 39}, sd::DataType::INT64); - NDArray exp4('c', {2,3}, {1.5, 12.5, 35, 81, 148.5, 253.5}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0., 0.5, 0.8, 1.5}, sd::DataType::HALF); + NDArray exp1('c', {0}, std::vector{7}, sd::DataType::INT32); + NDArray exp2('c', {0}, std::vector{17.5}, sd::DataType::HALF); + NDArray exp3('c', {3, 2}, {1, 5, 10, 18, 27, 39}, sd::DataType::INT64); + NDArray exp4('c', {2, 3}, {1.5, 12.5, 35, 81, 148.5, 253.5}, + sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {0., 0.5, 0.8, 1.5}, sd::DataType::HALF); - scalar1 *= scalar2; - ASSERT_EQ(scalar1, exp1); + scalar1 *= scalar2; + ASSERT_EQ(scalar1, exp1); - scalar2 *= scalar1; - ASSERT_EQ(scalar2, exp2); + scalar2 *= scalar1; + ASSERT_EQ(scalar2, exp2); - x2 *= x1; - ASSERT_EQ(x2, exp3); + x2 *= x1; + ASSERT_EQ(x2, exp3); - x1 *= x2; - ASSERT_EQ(x1, exp4); + x1 *= x2; + ASSERT_EQ(x1, exp4); - x4 *= x3; - ASSERT_EQ(x4, exp5); + x4 *= x3; + ASSERT_EQ(x4, exp5); - x6 *= x5; - ASSERT_EQ(x6, exp5); + x6 *= x5; + ASSERT_EQ(x6, exp5); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorMultiplyEqual_test2) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); - const Nd4jLong val1 = 1; - const float16 val2 = 1.5; - const double val3 = 2.2; + const Nd4jLong val1 = 1; + const float16 val2 = 1.5; + const double val3 = 2.2; - NDArray exp1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); - NDArray exp3('c', {2,2}, {0, 1.5, 3, 4.5}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {0, 1, 3, 4}, sd::DataType::INT32); - NDArray exp5('c', {2,2}, {0, 3.3, 6.6, 9.9}, sd::DataType::FLOAT32); - NDArray exp6('c', {2,2}, {0, 2, 6, 8}, sd::DataType::INT32); + NDArray exp1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray exp3('c', {2, 2}, {0, 1.5, 3, 4.5}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {0, 1, 3, 4}, sd::DataType::INT32); + NDArray exp5('c', {2, 2}, {0, 3.3, 6.6, 9.9}, sd::DataType::FLOAT32); + NDArray exp6('c', {2, 2}, {0, 2, 6, 8}, sd::DataType::INT32); - x1 *= val1; - ASSERT_EQ(x1, exp1); + x1 *= val1; + ASSERT_EQ(x1, exp1); - x2 *= val1; - ASSERT_EQ(x2, exp2); + x2 *= val1; + ASSERT_EQ(x2, exp2); - x1 *= val2; - ASSERT_EQ(x1, exp3); + x1 *= val2; + ASSERT_EQ(x1, exp3); - x2 *= val2; - ASSERT_EQ(x2, exp4); + x2 *= val2; + ASSERT_EQ(x2, exp4); - x1 *= val3; - ASSERT_EQ(x1, exp5); + x1 *= val3; + ASSERT_EQ(x1, exp5); - x2 *= val3; - ASSERT_EQ(x2, exp6); + x2 *= val3; + ASSERT_EQ(x2, exp6); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorDivideEqual_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray scalar1('c', {0}, std::vector{3}, sd::DataType::INT32); - NDArray scalar2('c', {0}, std::vector{2.5}, sd::DataType::HALF); + NDArray scalar1('c', {0}, std::vector{3}, sd::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{2.5}, sd::DataType::HALF); - NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, sd::DataType::FLOAT32); - NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {1, 2, 3, 4}, sd::DataType::INT64); - NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); - NDArray x5('c', {2,2}, {1, 2, 3, 4}, sd::DataType::HALF); - NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); + NDArray x1('c', {2, 3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, + sd::DataType::FLOAT32); + NDArray x2('c', {3, 2}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT64); + NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); + NDArray x5('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::HALF); + NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); - NDArray exp1('c', {0}, std::vector{1}, sd::DataType::INT32); - NDArray exp2('c', {0}, std::vector{2.5}, sd::DataType::HALF); - NDArray exp3('c', {3,2}, {6, 8, 8, 8, 9, 9}, sd::DataType::INT64); - NDArray exp4('c', {2,3}, {0.25, 0.3125, 0.4375, 0.5625, 0.611111111, 0.722222222}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0.4, 0.25, 0.1333333, 0.125}, sd::DataType::HALF); + NDArray exp1('c', {0}, std::vector{1}, sd::DataType::INT32); + NDArray exp2('c', {0}, std::vector{2.5}, sd::DataType::HALF); + NDArray exp3('c', {3, 2}, {6, 8, 8, 8, 9, 9}, sd::DataType::INT64); + NDArray exp4('c', {2, 3}, + {0.25, 0.3125, 0.4375, 0.5625, 0.611111111, 0.722222222}, + sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {0.4, 0.25, 0.1333333, 0.125}, sd::DataType::HALF); - scalar1 /= scalar2; - ASSERT_EQ(scalar1, exp1); + scalar1 /= scalar2; + ASSERT_EQ(scalar1, exp1); - scalar2 /= scalar1; - ASSERT_EQ(scalar2, exp2); + scalar2 /= scalar1; + ASSERT_EQ(scalar2, exp2); - x2 /= x1; - ASSERT_EQ(x2, exp3); + x2 /= x1; + ASSERT_EQ(x2, exp3); - x1 /= x2; - ASSERT_EQ(x1, exp4); + x1 /= x2; + ASSERT_EQ(x1, exp4); - x4 /= x3; - ASSERT_EQ(x4, exp5); + x4 /= x3; + ASSERT_EQ(x4, exp5); - x6 /= x5; - ASSERT_EQ(x6, exp5); + x6 /= x5; + ASSERT_EQ(x6, exp5); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_operatorDivideEqual_test2) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 2, 4, 6}, sd::DataType::FLOAT32); - NDArray x2('c', {2,2}, {0, 2, 4, 6}, sd::DataType::INT32); + NDArray x1('c', {2, 2}, {0, 2, 4, 6}, sd::DataType::FLOAT32); + NDArray x2('c', {2, 2}, {0, 2, 4, 6}, sd::DataType::INT32); - const Nd4jLong val1 = 1; - const float16 val2 = 2.; - const double val3 = 2.2; + const Nd4jLong val1 = 1; + const float16 val2 = 2.; + const double val3 = 2.2; - NDArray exp1('c', {2,2}, {0, 2, 4, 6}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,2}, {0, 2, 4, 6}, sd::DataType::INT32); - NDArray exp3('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); - NDArray exp5('c', {2,2}, {0, 0.45454545, 0.909090909, 1.363636364}, sd::DataType::FLOAT32); - NDArray exp6('c', {2,2}, {0, 0, 0, 1}, sd::DataType::INT32); + NDArray exp1('c', {2, 2}, {0, 2, 4, 6}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 2}, {0, 2, 4, 6}, sd::DataType::INT32); + NDArray exp3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray exp5('c', {2, 2}, {0, 0.45454545, 0.909090909, 1.363636364}, + sd::DataType::FLOAT32); + NDArray exp6('c', {2, 2}, {0, 0, 0, 1}, sd::DataType::INT32); - x1 /= val1; - ASSERT_EQ(x1, exp1); + x1 /= val1; + ASSERT_EQ(x1, exp1); - x2 /= val1; - ASSERT_EQ(x2, exp2); + x2 /= val1; + ASSERT_EQ(x2, exp2); - x1 /= val2; - ASSERT_EQ(x1, exp3); + x1 /= val2; + ASSERT_EQ(x1, exp3); - x2 /= val2; - ASSERT_EQ(x2, exp4); + x2 /= val2; + ASSERT_EQ(x2, exp4); - x1 /= val3; - ASSERT_EQ(x1, exp5); + x1 /= val3; + ASSERT_EQ(x1, exp5); - x2 /= val3; - ASSERT_EQ(x2, exp6); + x2 /= val3; + ASSERT_EQ(x2, exp6); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceNumberFloat_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray exp1('c', {0}, std::vector{1.5}, sd::DataType::FLOAT32); - NDArray exp2('c', {0}, std::vector{2}, sd::DataType::HALF); - NDArray exp3('c', {0}, std::vector{2}, sd::DataType::DOUBLE); - NDArray exp4('c', {0}, std::vector{0.25},sd::DataType::FLOAT32); + NDArray exp1('c', {0}, std::vector{1.5}, sd::DataType::FLOAT32); + NDArray exp2('c', {0}, std::vector{2}, sd::DataType::HALF); + NDArray exp3('c', {0}, std::vector{2}, sd::DataType::DOUBLE); + NDArray exp4('c', {0}, std::vector{0.25}, sd::DataType::FLOAT32); + NDArray scalar = x1.reduceNumber(reduce::Mean); + ASSERT_EQ(scalar, exp1); + x1.reduceNumber(reduce::Mean, scalar); + ASSERT_EQ(scalar, exp1); - NDArray scalar = x1.reduceNumber(reduce::Mean); - ASSERT_EQ(scalar, exp1); - x1.reduceNumber(reduce::Mean, scalar); - ASSERT_EQ(scalar, exp1); + scalar = x2.reduceNumber(reduce::Mean); + ASSERT_EQ(scalar, exp2); + x2.reduceNumber(reduce::Mean, scalar); + ASSERT_EQ(scalar, exp2); - scalar = x2.reduceNumber(reduce::Mean); - ASSERT_EQ(scalar, exp2); - x2.reduceNumber(reduce::Mean, scalar); - ASSERT_EQ(scalar, exp2); + scalar = x3.reduceNumber(reduce::Mean); + ASSERT_EQ(scalar, exp3); + x3.reduceNumber(reduce::Mean, scalar); + ASSERT_EQ(scalar, exp3); - scalar = x3.reduceNumber(reduce::Mean); - ASSERT_EQ(scalar, exp3); - x3.reduceNumber(reduce::Mean,scalar); - ASSERT_EQ(scalar, exp3); - - scalar = x4.reduceNumber(reduce::Mean); - ASSERT_EQ(scalar, exp4); - x4.reduceNumber(reduce::Mean, scalar); - ASSERT_EQ(scalar, exp4); + scalar = x4.reduceNumber(reduce::Mean); + ASSERT_EQ(scalar, exp4); + x4.reduceNumber(reduce::Mean, scalar); + ASSERT_EQ(scalar, exp4); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceNumberSame_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; - - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray exp1('c', {0}, std::vector{6}, sd::DataType::INT64); - NDArray exp2('c', {0}, std::vector{8}, sd::DataType::HALF); - NDArray exp3('c', {0}, std::vector{8}, sd::DataType::DOUBLE); - NDArray exp4('c', {0}, std::vector{1}, sd::DataType::BOOL); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray exp1('c', {0}, std::vector{6}, sd::DataType::INT64); + NDArray exp2('c', {0}, std::vector{8}, sd::DataType::HALF); + NDArray exp3('c', {0}, std::vector{8}, sd::DataType::DOUBLE); + NDArray exp4('c', {0}, std::vector{1}, sd::DataType::BOOL); - NDArray scalar = x1.reduceNumber(reduce::Sum); - ASSERT_EQ(scalar, exp1); - x1.reduceNumber(reduce::Sum, scalar); - ASSERT_EQ(scalar, exp1); + NDArray scalar = x1.reduceNumber(reduce::Sum); + ASSERT_EQ(scalar, exp1); + x1.reduceNumber(reduce::Sum, scalar); + ASSERT_EQ(scalar, exp1); - scalar = x2.reduceNumber(reduce::Sum); - ASSERT_EQ(scalar, exp2); - x2.reduceNumber(reduce::Sum, scalar); - ASSERT_EQ(scalar, exp2); + scalar = x2.reduceNumber(reduce::Sum); + ASSERT_EQ(scalar, exp2); + x2.reduceNumber(reduce::Sum, scalar); + ASSERT_EQ(scalar, exp2); - scalar = x3.reduceNumber(reduce::Sum); - ASSERT_EQ(scalar, exp3); - x3.reduceNumber(reduce::Sum, scalar); - ASSERT_EQ(scalar, exp3); + scalar = x3.reduceNumber(reduce::Sum); + ASSERT_EQ(scalar, exp3); + x3.reduceNumber(reduce::Sum, scalar); + ASSERT_EQ(scalar, exp3); - scalar = x4.reduceNumber(reduce::Sum); - ASSERT_EQ(scalar, exp4); - x4.reduceNumber(reduce::Sum, scalar); - ASSERT_EQ(scalar, exp4); + scalar = x4.reduceNumber(reduce::Sum); + ASSERT_EQ(scalar, exp4); + x4.reduceNumber(reduce::Sum, scalar); + ASSERT_EQ(scalar, exp4); } //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceNumberBool_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, -1, 2, -3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0.5, -1.5, 2.5, -3.5}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::BOOL); + NDArray x1('c', {2, 2}, {0, -1, 2, -3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0.5, -1.5, 2.5, -3.5}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {-2, -1, 0, 1}, sd::DataType::BOOL); - NDArray exp1('c', {0}, std::vector{1}, sd::DataType::BOOL); + NDArray exp1('c', {0}, std::vector{1}, sd::DataType::BOOL); - NDArray scalar = x1.reduceNumber(reduce::IsFinite); - ASSERT_EQ(scalar, exp1); - x1.reduceNumber(reduce::IsFinite, scalar); - ASSERT_EQ(scalar, exp1); + NDArray scalar = x1.reduceNumber(reduce::IsFinite); + ASSERT_EQ(scalar, exp1); + x1.reduceNumber(reduce::IsFinite, scalar); + ASSERT_EQ(scalar, exp1); - scalar = x2.reduceNumber(reduce::IsFinite); - ASSERT_EQ(scalar, exp1); - x2.reduceNumber(reduce::IsFinite, scalar); - ASSERT_EQ(scalar, exp1); + scalar = x2.reduceNumber(reduce::IsFinite); + ASSERT_EQ(scalar, exp1); + x2.reduceNumber(reduce::IsFinite, scalar); + ASSERT_EQ(scalar, exp1); - scalar = x3.reduceNumber(reduce::IsFinite); - ASSERT_EQ(scalar, exp1); - x3.reduceNumber(reduce::IsFinite, scalar); - ASSERT_EQ(scalar, exp1); + scalar = x3.reduceNumber(reduce::IsFinite); + ASSERT_EQ(scalar, exp1); + x3.reduceNumber(reduce::IsFinite, scalar); + ASSERT_EQ(scalar, exp1); - scalar = x4.reduceNumber(reduce::IsFinite); - ASSERT_EQ(scalar, exp1); - x4.reduceNumber(reduce::IsFinite, scalar); - ASSERT_EQ(scalar, exp1); + scalar = x4.reduceNumber(reduce::IsFinite); + ASSERT_EQ(scalar, exp1); + x4.reduceNumber(reduce::IsFinite, scalar); + ASSERT_EQ(scalar, exp1); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceNumberLong_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0.5, -1.5, 0, 3.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0.5, -1.5, 0, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray exp1('c', {0}, std::vector{3}, sd::DataType::INT64); - NDArray exp2('c', {0}, std::vector{4}, sd::DataType::INT64); - NDArray exp3('c', {0}, std::vector{3}, sd::DataType::INT64); - NDArray exp4('c', {0}, std::vector{2}, sd::DataType::INT64); + NDArray exp1('c', {0}, std::vector{3}, sd::DataType::INT64); + NDArray exp2('c', {0}, std::vector{4}, sd::DataType::INT64); + NDArray exp3('c', {0}, std::vector{3}, sd::DataType::INT64); + NDArray exp4('c', {0}, std::vector{2}, sd::DataType::INT64); - NDArray scalar = x1.reduceNumber(reduce::CountNonZero); - ASSERT_EQ(scalar, exp1); - x1.reduceNumber(reduce::CountNonZero, scalar); - ASSERT_EQ(scalar, exp1); + NDArray scalar = x1.reduceNumber(reduce::CountNonZero); + ASSERT_EQ(scalar, exp1); + x1.reduceNumber(reduce::CountNonZero, scalar); + ASSERT_EQ(scalar, exp1); - scalar = x2.reduceNumber(reduce::CountNonZero); - ASSERT_EQ(scalar, exp2); - x2.reduceNumber(reduce::CountNonZero, scalar); - ASSERT_EQ(scalar, exp2); + scalar = x2.reduceNumber(reduce::CountNonZero); + ASSERT_EQ(scalar, exp2); + x2.reduceNumber(reduce::CountNonZero, scalar); + ASSERT_EQ(scalar, exp2); - scalar = x3.reduceNumber(reduce::CountNonZero); - ASSERT_EQ(scalar, exp3); - x3.reduceNumber(reduce::CountNonZero, scalar); - ASSERT_EQ(scalar, exp3); + scalar = x3.reduceNumber(reduce::CountNonZero); + ASSERT_EQ(scalar, exp3); + x3.reduceNumber(reduce::CountNonZero, scalar); + ASSERT_EQ(scalar, exp3); - scalar = x4.reduceNumber(reduce::CountNonZero); - ASSERT_EQ(scalar, exp4); - x4.reduceNumber(reduce::CountNonZero, scalar); - ASSERT_EQ(scalar, exp4); + scalar = x4.reduceNumber(reduce::CountNonZero); + ASSERT_EQ(scalar, exp4); + x4.reduceNumber(reduce::CountNonZero, scalar); + ASSERT_EQ(scalar, exp4); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_indexReduceNumber_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); - NDArray x2('c', {2,2}, {0.5, 1.5, -4.5, 3.5}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0, -1, 0, 1}, sd::DataType::BOOL); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray x2('c', {2, 2}, {0.5, 1.5, -4.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0, -1, 0, 1}, sd::DataType::BOOL); - NDArray exp1('c', {0}, std::vector{3}, sd::DataType::INT64); - NDArray exp2('c', {0}, std::vector{2}, sd::DataType::INT64); - NDArray exp3('c', {0}, std::vector{1}, sd::DataType::INT64); + NDArray exp1('c', {0}, std::vector{3}, sd::DataType::INT64); + NDArray exp2('c', {0}, std::vector{2}, sd::DataType::INT64); + NDArray exp3('c', {0}, std::vector{1}, sd::DataType::INT64); - NDArray scalar = x1.indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); - ASSERT_EQ(scalar, exp1); + NDArray scalar = x1.indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); + ASSERT_EQ(scalar, exp1); - scalar = x2.indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); - ASSERT_EQ(scalar, exp2); + scalar = x2.indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); + ASSERT_EQ(scalar, exp2); - scalar = x3.indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); - ASSERT_EQ(scalar, exp3); + scalar = x3.indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); + ASSERT_EQ(scalar, exp3); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyTransformFloat_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 4, 9, 16}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0, 2.25, 6.25, 12.25}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0, 2.25, 6.25, 12.25}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x1('c', {2, 2}, {0, 4, 9, 16}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 2.25, 6.25, 12.25}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0, 2.25, 6.25, 12.25}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray exp1('c', {2,2}, {0, 2, 3, 4}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); - NDArray exp3('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray exp4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::HALF); + NDArray exp1('c', {2, 2}, {0, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray exp3('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray exp4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::HALF); - NDArray result1('c', {2,2}, sd::DataType::FLOAT32); - NDArray result2('c', {2,2}, sd::DataType::DOUBLE); - NDArray result3('c', {2,2}, sd::DataType::HALF); + NDArray result1('c', {2, 2}, sd::DataType::FLOAT32); + NDArray result2('c', {2, 2}, sd::DataType::DOUBLE); + NDArray result3('c', {2, 2}, sd::DataType::HALF); - x1.applyTransform(sd::transform::Sqrt, result1); - ASSERT_EQ(result1, exp1); + x1.applyTransform(sd::transform::Sqrt, result1); + ASSERT_EQ(result1, exp1); - x2.applyTransform(sd::transform::Sqrt, result2); - ASSERT_EQ(result2, exp2); + x2.applyTransform(sd::transform::Sqrt, result2); + ASSERT_EQ(result2, exp2); - x3.applyTransform(sd::transform::Sqrt, result3); - ASSERT_EQ(result3, exp3); + x3.applyTransform(sd::transform::Sqrt, result3); + ASSERT_EQ(result3, exp3); - x4.applyTransform(sd::transform::Sqrt, result3); - ASSERT_EQ(result3, exp4); + x4.applyTransform(sd::transform::Sqrt, result3); + ASSERT_EQ(result3, exp4); - x2.applyTransform(sd::transform::Sqrt, x2); - ASSERT_EQ(x2, exp3); + x2.applyTransform(sd::transform::Sqrt, x2); + ASSERT_EQ(x2, exp3); - x3.applyTransform(sd::transform::Sqrt, x3); - ASSERT_EQ(x3, exp2); + x3.applyTransform(sd::transform::Sqrt, x3); + ASSERT_EQ(x3, exp2); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyTransformSame_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray x5('c', {2,3}, {0, 1.5, 2.5, 3.5, 4.5, 5.5}, sd::DataType::DOUBLE); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x5('c', {2, 3}, {0, 1.5, 2.5, 3.5, 4.5, 5.5}, sd::DataType::DOUBLE); - NDArray exp1('c', {2,2}, {0, 1, 4, 9}, sd::DataType::INT64); - NDArray exp2('c', {2,2}, {0, 2.25, 6.25, 12.25}, sd::DataType::HALF); - NDArray exp3('c', {2,2}, {0, 2.25, 6.25, 12.25}, sd::DataType::DOUBLE); - NDArray exp4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray exp5('c', {3,2}, {0, 2.25, 6.25, 12.25, 20.25, 30.25}, sd::DataType::DOUBLE); + NDArray exp1('c', {2, 2}, {0, 1, 4, 9}, sd::DataType::INT64); + NDArray exp2('c', {2, 2}, {0, 2.25, 6.25, 12.25}, sd::DataType::HALF); + NDArray exp3('c', {2, 2}, {0, 2.25, 6.25, 12.25}, sd::DataType::DOUBLE); + NDArray exp4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray exp5('c', {3, 2}, {0, 2.25, 6.25, 12.25, 20.25, 30.25}, + sd::DataType::DOUBLE); - NDArray result1('c', {2,2}, sd::DataType::INT64); - NDArray result2('c', {2,2}, sd::DataType::HALF); - NDArray result3('c', {2,2}, sd::DataType::DOUBLE); - NDArray result4('c', {2,2}, sd::DataType::BOOL); - NDArray result5('c', {3,2}, sd::DataType::DOUBLE); + NDArray result1('c', {2, 2}, sd::DataType::INT64); + NDArray result2('c', {2, 2}, sd::DataType::HALF); + NDArray result3('c', {2, 2}, sd::DataType::DOUBLE); + NDArray result4('c', {2, 2}, sd::DataType::BOOL); + NDArray result5('c', {3, 2}, sd::DataType::DOUBLE); - x1.applyTransform(sd::transform::Square, result1); - ASSERT_EQ(result1, exp1); + x1.applyTransform(sd::transform::Square, result1); + ASSERT_EQ(result1, exp1); - x2.applyTransform(sd::transform::Square, result2); - ASSERT_EQ(result2, exp2); + x2.applyTransform(sd::transform::Square, result2); + ASSERT_EQ(result2, exp2); - x3.applyTransform(sd::transform::Square, result3); - ASSERT_EQ(result3, exp3); + x3.applyTransform(sd::transform::Square, result3); + ASSERT_EQ(result3, exp3); - x4.applyTransform(sd::transform::Square, result4); - ASSERT_EQ(result4, exp4); + x4.applyTransform(sd::transform::Square, result4); + ASSERT_EQ(result4, exp4); - x2.applyTransform(sd::transform::Square, x2); - ASSERT_EQ(x2, exp2); + x2.applyTransform(sd::transform::Square, x2); + ASSERT_EQ(x2, exp2); - x3.applyTransform(sd::transform::Square, x3); - ASSERT_EQ(x3, exp3); + x3.applyTransform(sd::transform::Square, x3); + ASSERT_EQ(x3, exp3); - x5.applyTransform(sd::transform::Square, result5); - ASSERT_EQ(result5, exp5); + x5.applyTransform(sd::transform::Square, result5); + ASSERT_EQ(result5, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyTransformBool_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); - NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray x5('c', {2,3}, {0, 1.5, 2.5, 3.5, 4.5, 5.5}, sd::DataType::DOUBLE); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x5('c', {2, 3}, {0, 1.5, 2.5, 3.5, 4.5, 5.5}, sd::DataType::DOUBLE); - NDArray exp1('c', {2,2}, {0, 0, 0, 1}, sd::DataType::BOOL); - NDArray exp2('c', {2,2}, {0, 1, 0, 0}, sd::DataType::BOOL); - NDArray exp3('c', {3,2}, {0, 0, 0, 0, 0, 1}, sd::DataType::BOOL); + NDArray exp1('c', {2, 2}, {0, 0, 0, 1}, sd::DataType::BOOL); + NDArray exp2('c', {2, 2}, {0, 1, 0, 0}, sd::DataType::BOOL); + NDArray exp3('c', {3, 2}, {0, 0, 0, 0, 0, 1}, sd::DataType::BOOL); - NDArray result1('c', {2,2}, sd::DataType::BOOL); - NDArray result2('c', {3,2}, sd::DataType::BOOL); + NDArray result1('c', {2, 2}, sd::DataType::BOOL); + NDArray result2('c', {3, 2}, sd::DataType::BOOL); - /* - x1.applyTransform(sd::transform::IsMax, result1); - ASSERT_EQ(result1, exp1); + /* + x1.applyTransform(sd::transform::IsMax, result1); + ASSERT_EQ(result1, exp1); - x2.applyTransform(sd::transform::IsMax, result1); - ASSERT_EQ(result1, exp1); + x2.applyTransform(sd::transform::IsMax, result1); + ASSERT_EQ(result1, exp1); - x3.applyTransform(sd::transform::IsMax, result1); - ASSERT_EQ(result1, exp1); + x3.applyTransform(sd::transform::IsMax, result1); + ASSERT_EQ(result1, exp1); - x4.applyTransform(sd::transform::IsMax, result1); - ASSERT_EQ(result1, exp2); + x4.applyTransform(sd::transform::IsMax, result1); + ASSERT_EQ(result1, exp2); - x5.applyTransform(sd::transform::IsMax, result2); - ASSERT_EQ(result2, exp3); - */ + x5.applyTransform(sd::transform::IsMax, result2); + ASSERT_EQ(result2, exp3); + */ } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyTransformStrict_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::HALF); - NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); - NDArray x3('c', {2,2}, {0, 1, 2, 3}, sd::DataType::DOUBLE); - NDArray x4('c', {2,3}, {0, 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); - NDArray exp1('c', {2,2}, {0, 3, 12, 27}, sd::DataType::HALF); - NDArray exp2('c', {2,2}, {0, 3, 12, 27}, sd::DataType::FLOAT32); - NDArray exp3('c', {2,2}, {0, 3, 12, 27}, sd::DataType::DOUBLE); - NDArray exp4('c', {3,2}, {0, 3, 12, 27, 48, 75}, sd::DataType::DOUBLE); - NDArray exp5('c', {2,3}, {0, 3, 12, 27, 48, 75}, sd::DataType::DOUBLE); + NDArray exp1('c', {2, 2}, {0, 3, 12, 27}, sd::DataType::HALF); + NDArray exp2('c', {2, 2}, {0, 3, 12, 27}, sd::DataType::FLOAT32); + NDArray exp3('c', {2, 2}, {0, 3, 12, 27}, sd::DataType::DOUBLE); + NDArray exp4('c', {3, 2}, {0, 3, 12, 27, 48, 75}, sd::DataType::DOUBLE); + NDArray exp5('c', {2, 3}, {0, 3, 12, 27, 48, 75}, sd::DataType::DOUBLE); - NDArray result1('c', {2,2}, sd::DataType::HALF); - NDArray result2('c', {2,2}, sd::DataType::FLOAT32); - NDArray result3('c', {2,2}, sd::DataType::DOUBLE); - NDArray result4('c', {3,2}, sd::DataType::DOUBLE); + NDArray result1('c', {2, 2}, sd::DataType::HALF); + NDArray result2('c', {2, 2}, sd::DataType::FLOAT32); + NDArray result3('c', {2, 2}, sd::DataType::DOUBLE); + NDArray result4('c', {3, 2}, sd::DataType::DOUBLE); - x1.applyTransform(sd::transform::CubeDerivative, result1); - ASSERT_EQ(result1, exp1); + x1.applyTransform(sd::transform::CubeDerivative, result1); + ASSERT_EQ(result1, exp1); - x2.applyTransform(sd::transform::CubeDerivative, result2); - ASSERT_EQ(result2, exp2); + x2.applyTransform(sd::transform::CubeDerivative, result2); + ASSERT_EQ(result2, exp2); - x3.applyTransform(sd::transform::CubeDerivative, result3); - ASSERT_EQ(result3, exp3); + x3.applyTransform(sd::transform::CubeDerivative, result3); + ASSERT_EQ(result3, exp3); - x4.applyTransform(sd::transform::CubeDerivative, result4); - ASSERT_EQ(result4, exp4); + x4.applyTransform(sd::transform::CubeDerivative, result4); + ASSERT_EQ(result4, exp4); - x1.applyTransform(sd::transform::CubeDerivative, x1); - ASSERT_EQ(x1, exp1); + x1.applyTransform(sd::transform::CubeDerivative, x1); + ASSERT_EQ(x1, exp1); - x2.applyTransform(sd::transform::CubeDerivative, x2); - ASSERT_EQ(x2, exp2); + x2.applyTransform(sd::transform::CubeDerivative, x2); + ASSERT_EQ(x2, exp2); - x3.applyTransform(sd::transform::CubeDerivative, x3); - ASSERT_EQ(x3, exp3); + x3.applyTransform(sd::transform::CubeDerivative, x3); + ASSERT_EQ(x3, exp3); - x4.applyTransform(sd::transform::CubeDerivative, x4); - ASSERT_EQ(x4, exp5); + x4.applyTransform(sd::transform::CubeDerivative, x4); + ASSERT_EQ(x4, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyPairwiseTransform_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT32); - NDArray x2('c', {2,3}, {0, 1, 2, 3, 4, 5}, sd::DataType::FLOAT32); - NDArray x3('c', {2,3}, {0, 1, 0, 1, 0, 0}, sd::DataType::BOOL); - NDArray x4('c', {3,2}, {0.5, 1.5, 2.5, 3.5, 4.5, 0}, sd::DataType::DOUBLE); - NDArray x5('c', {3,2}, sd::DataType::INT32); - NDArray x6('c', {2,3}, sd::DataType::DOUBLE); + NDArray x1('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT32); + NDArray x2('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 3}, {0, 1, 0, 1, 0, 0}, sd::DataType::BOOL); + NDArray x4('c', {3, 2}, {0.5, 1.5, 2.5, 3.5, 4.5, 0}, sd::DataType::DOUBLE); + NDArray x5('c', {3, 2}, sd::DataType::INT32); + NDArray x6('c', {2, 3}, sd::DataType::DOUBLE); - NDArray exp1('c', {2,3}, {0, 2, 4, 6, 8, 5}, sd::DataType::INT32); - NDArray exp2('c', {2,3}, {0.5, 2.5, 4.5, 6.5, 8.5, 5.}, sd::DataType::FLOAT32); - NDArray exp3('c', {2,3}, {1, 1, 1, 1, 1, 0}, sd::DataType::BOOL); - NDArray exp4('c', {2,3}, {0.5, 2.5, 4.5, 6.5, 8.5, 5.}, sd::DataType::DOUBLE); - NDArray exp5('c', {3,2}, {0, 2, 4, 6, 8, 5}, sd::DataType::INT32); + NDArray exp1('c', {2, 3}, {0, 2, 4, 6, 8, 5}, sd::DataType::INT32); + NDArray exp2('c', {2, 3}, {0.5, 2.5, 4.5, 6.5, 8.5, 5.}, + sd::DataType::FLOAT32); + NDArray exp3('c', {2, 3}, {1, 1, 1, 1, 1, 0}, sd::DataType::BOOL); + NDArray exp4('c', {2, 3}, {0.5, 2.5, 4.5, 6.5, 8.5, 5.}, + sd::DataType::DOUBLE); + NDArray exp5('c', {3, 2}, {0, 2, 4, 6, 8, 5}, sd::DataType::INT32); - x1.applyPairwiseTransform(sd::pairwise::Add, x4, x5); - ASSERT_EQ(x5, exp5); + x1.applyPairwiseTransform(sd::pairwise::Add, x4, x5); + ASSERT_EQ(x5, exp5); - x1.applyPairwiseTransform(sd::pairwise::Add, x4, x6); - ASSERT_EQ(x6, exp4); + x1.applyPairwiseTransform(sd::pairwise::Add, x4, x6); + ASSERT_EQ(x6, exp4); - x1.applyPairwiseTransform(sd::pairwise::Add, x4); - ASSERT_EQ(x1, exp1); + x1.applyPairwiseTransform(sd::pairwise::Add, x4); + ASSERT_EQ(x1, exp1); - x2.applyPairwiseTransform(sd::pairwise::Add, x4); - ASSERT_EQ(x2, exp2); + x2.applyPairwiseTransform(sd::pairwise::Add, x4); + ASSERT_EQ(x2, exp2); - x3.applyPairwiseTransform(sd::pairwise::Add, x4); - ASSERT_EQ(x3, exp3); + x3.applyPairwiseTransform(sd::pairwise::Add, x4); + ASSERT_EQ(x3, exp3); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyPairwiseTransform_test2) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,3}, {1, 1, 2, 3, 4, 5}, sd::DataType::INT32); - NDArray x2('c', {3,2}, {1, 0, 2, 0, 4, 0}, sd::DataType::INT32); - NDArray x3('c', {3,2}, {0.5, 1.5, 2.5, 3, 4.5, 0}, sd::DataType::DOUBLE); - NDArray x4('c', {2,3}, {0.5, 1., 2.5, 3, 4., 0}, sd::DataType::DOUBLE); - NDArray x5('c', {3,2}, {0, 1, 0, 1, 0, 1}, sd::DataType::BOOL); - NDArray x6('c', {2,3}, {1, 1, 1, 0, 1, 0}, sd::DataType::BOOL); + NDArray x1('c', {2, 3}, {1, 1, 2, 3, 4, 5}, sd::DataType::INT32); + NDArray x2('c', {3, 2}, {1, 0, 2, 0, 4, 0}, sd::DataType::INT32); + NDArray x3('c', {3, 2}, {0.5, 1.5, 2.5, 3, 4.5, 0}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 3}, {0.5, 1., 2.5, 3, 4., 0}, sd::DataType::DOUBLE); + NDArray x5('c', {3, 2}, {0, 1, 0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x6('c', {2, 3}, {1, 1, 1, 0, 1, 0}, sd::DataType::BOOL); - NDArray x7('c', {3,2}, sd::DataType::BOOL); - NDArray x8('c', {2,3}, sd::DataType::BOOL); + NDArray x7('c', {3, 2}, sd::DataType::BOOL); + NDArray x8('c', {2, 3}, sd::DataType::BOOL); - NDArray exp1('c', {3,2}, {1, 0, 1, 0, 1, 0}, sd::DataType::BOOL); - NDArray exp2('c', {2,3}, {1, 0, 1, 1, 0, 1}, sd::DataType::BOOL); - NDArray exp3('c', {2,3}, {0, 1, 0, 0, 0, 0}, sd::DataType::BOOL); + NDArray exp1('c', {3, 2}, {1, 0, 1, 0, 1, 0}, sd::DataType::BOOL); + NDArray exp2('c', {2, 3}, {1, 0, 1, 1, 0, 1}, sd::DataType::BOOL); + NDArray exp3('c', {2, 3}, {0, 1, 0, 0, 0, 0}, sd::DataType::BOOL); - x1.applyPairwiseTransform(sd::pairwise::EqualTo, x2, x7); - ASSERT_EQ(x7, exp1); + x1.applyPairwiseTransform(sd::pairwise::EqualTo, x2, x7); + ASSERT_EQ(x7, exp1); - x3.applyPairwiseTransform(sd::pairwise::EqualTo, x4, x8); - ASSERT_EQ(x8, exp2); + x3.applyPairwiseTransform(sd::pairwise::EqualTo, x4, x8); + ASSERT_EQ(x8, exp2); - x5.applyPairwiseTransform(sd::pairwise::EqualTo, x6, x8); - ASSERT_EQ(x8, exp3); + x5.applyPairwiseTransform(sd::pairwise::EqualTo, x6, x8); + ASSERT_EQ(x8, exp3); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyBroadcast_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,3}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT32); - NDArray x2('c', {2}, {1, 2}, sd::DataType::INT64); - NDArray x3('c', {2,3}, sd::DataType::INT32); - NDArray x4('c', {2}, {1, 2}, sd::DataType::FLOAT32); - NDArray x5('c', {2,3}, sd::DataType::FLOAT32); - NDArray x6('c', {2}, {1, 1}, sd::DataType::BOOL); + NDArray x1('c', {2, 3}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT32); + NDArray x2('c', {2}, {1, 2}, sd::DataType::INT64); + NDArray x3('c', {2, 3}, sd::DataType::INT32); + NDArray x4('c', {2}, {1, 2}, sd::DataType::FLOAT32); + NDArray x5('c', {2, 3}, sd::DataType::FLOAT32); + NDArray x6('c', {2}, {1, 1}, sd::DataType::BOOL); - NDArray exp1('c', {2,3}, {11, 21, 31, 42, 52, 62}, sd::DataType::INT32); - NDArray exp2('c', {2,3}, {11, 21, 31, 42, 52, 62}, sd::DataType::FLOAT32); - NDArray exp3('c', {2,3}, {11, 21, 31, 41, 51, 61}, sd::DataType::INT32); + NDArray exp1('c', {2, 3}, {11, 21, 31, 42, 52, 62}, sd::DataType::INT32); + NDArray exp2('c', {2, 3}, {11, 21, 31, 42, 52, 62}, sd::DataType::FLOAT32); + NDArray exp3('c', {2, 3}, {11, 21, 31, 41, 51, 61}, sd::DataType::INT32); - x1.applyBroadcast(sd::broadcast::Add, {0}, x2, x3); - ASSERT_EQ(x3, exp1); + x1.applyBroadcast(sd::broadcast::Add, {0}, x2, x3); + ASSERT_EQ(x3, exp1); - x1.applyBroadcast(sd::broadcast::Add, {0}, x4, x5); - ASSERT_EQ(x5, exp2); + x1.applyBroadcast(sd::broadcast::Add, {0}, x4, x5); + ASSERT_EQ(x5, exp2); - x1.applyBroadcast(sd::broadcast::Add, {0}, x6, x3); - ASSERT_EQ(x3, exp3); + x1.applyBroadcast(sd::broadcast::Add, {0}, x6, x3); + ASSERT_EQ(x3, exp3); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyBroadcast_test2) { + NDArray x1('c', {2, 3}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT32); + NDArray x2('c', {2}, {10, 60}, sd::DataType::INT32); + NDArray x3('c', {2, 3}, sd::DataType::BOOL); - NDArray x1('c', {2,3}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT32); - NDArray x2('c', {2}, {10, 60}, sd::DataType::INT32); - NDArray x3('c', {2,3}, sd::DataType::BOOL); - - NDArray x4('c', {2,3}, {0, 0, 0, 0, 0, 1}, sd::DataType::BOOL); - NDArray x5('c', {2}, {0, 1}, sd::DataType::BOOL); + NDArray x4('c', {2, 3}, {0, 0, 0, 0, 0, 1}, sd::DataType::BOOL); + NDArray x5('c', {2}, {0, 1}, sd::DataType::BOOL); - NDArray exp1('c', {2,3}, {1, 0, 0, 0, 0, 1}, sd::DataType::BOOL); - NDArray exp2('c', {2,3}, {1, 1, 1, 0, 0, 1}, sd::DataType::BOOL); + NDArray exp1('c', {2, 3}, {1, 0, 0, 0, 0, 1}, sd::DataType::BOOL); + NDArray exp2('c', {2, 3}, {1, 1, 1, 0, 0, 1}, sd::DataType::BOOL); - x1.applyBroadcast(sd::broadcast::EqualTo, {0}, x2, x3); - ASSERT_EQ(x3, exp1); + x1.applyBroadcast(sd::broadcast::EqualTo, {0}, x2, x3); + ASSERT_EQ(x3, exp1); - x4.applyBroadcast(sd::broadcast::EqualTo, {0}, x5, x3); - ASSERT_EQ(x3, exp2); + x4.applyBroadcast(sd::broadcast::EqualTo, {0}, x5, x3); + ASSERT_EQ(x3, exp2); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {10, 20, 30, 40}, sd::DataType::INT32); - NDArray x2('c', {2}, {1, 2}, sd::DataType::HALF); - NDArray x3('c', {2,2}, sd::DataType::HALF); + NDArray x1('c', {2, 2}, {10, 20, 30, 40}, sd::DataType::INT32); + NDArray x2('c', {2}, {1, 2}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, sd::DataType::HALF); - NDArray x4('c', {2}, {1, 2}, sd::DataType::INT64); - NDArray x5('c', {2,2}, sd::DataType::INT32); + NDArray x4('c', {2}, {1, 2}, sd::DataType::INT64); + NDArray x5('c', {2, 2}, sd::DataType::INT32); - NDArray x6('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray x7('c', {2}, {1, 2}, sd::DataType::INT64); - NDArray x8('c', {2,2}, sd::DataType::BOOL); + NDArray x6('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x7('c', {2}, {1, 2}, sd::DataType::INT64); + NDArray x8('c', {2, 2}, sd::DataType::BOOL); - NDArray x13('c', {0}, std::vector{3}, sd::DataType::INT64); - NDArray x14('c', {0}, std::vector{1.5}, sd::DataType::DOUBLE); - NDArray x15(sd::DataType::DOUBLE); - NDArray x16('c', {2,2}, sd::DataType::DOUBLE); + NDArray x13('c', {0}, std::vector{3}, sd::DataType::INT64); + NDArray x14('c', {0}, std::vector{1.5}, sd::DataType::DOUBLE); + NDArray x15(sd::DataType::DOUBLE); + NDArray x16('c', {2, 2}, sd::DataType::DOUBLE); - NDArray exp1('c', {2,2}, {11, 22, 31, 42}, sd::DataType::HALF); - NDArray exp2('c', {2,2}, {11, 22, 31, 42}, sd::DataType::INT32); - NDArray exp3('c', {2,2}, {1, 1, 1, 1}, sd::DataType::BOOL); - NDArray exp4('c', {0}, std::vector{4.5}, sd::DataType::DOUBLE); - NDArray exp5('c', {2,2}, {11.5, 21.5, 31.5, 41.5}, sd::DataType::DOUBLE); + NDArray exp1('c', {2, 2}, {11, 22, 31, 42}, sd::DataType::HALF); + NDArray exp2('c', {2, 2}, {11, 22, 31, 42}, sd::DataType::INT32); + NDArray exp3('c', {2, 2}, {1, 1, 1, 1}, sd::DataType::BOOL); + NDArray exp4('c', {0}, std::vector{4.5}, sd::DataType::DOUBLE); + NDArray exp5('c', {2, 2}, {11.5, 21.5, 31.5, 41.5}, sd::DataType::DOUBLE); - x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x2, x3); - ASSERT_EQ(x3, exp1); + x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x2, x3); + ASSERT_EQ(x3, exp1); - x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x4, x5); - ASSERT_EQ(x5, exp2); + x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x4, x5); + ASSERT_EQ(x5, exp2); - x6.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x7, x8); - ASSERT_EQ(x8, exp3); + x6.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x7, x8); + ASSERT_EQ(x8, exp3); - auto x9 = x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x2); - ASSERT_EQ(x9, exp1); + auto x9 = x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x2); + ASSERT_EQ(x9, exp1); - auto x10 = x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x4); - ASSERT_EQ(x10, exp2); + auto x10 = x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x4); + ASSERT_EQ(x10, exp2); - auto x11 = x6.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x7); - ASSERT_EQ(x11, exp3); + auto x11 = x6.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x7); + ASSERT_EQ(x11, exp3); - auto x12 = x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x2); - ASSERT_EQ(x12, exp1); + auto x12 = x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x2); + ASSERT_EQ(x12, exp1); - x13.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x14, x15); - ASSERT_EQ(x15, exp4); + x13.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x14, x15); + ASSERT_EQ(x15, exp4); - x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x14, x16); - ASSERT_EQ(x16, exp5); - - x14.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x1, x16); - ASSERT_EQ(x16, exp5); + x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x14, x16); + ASSERT_EQ(x16, exp5); + x14.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x1, x16); + ASSERT_EQ(x16, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test2) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; - - NDArray x1('c', {2,2}, {10, 20, 30, 40}, sd::DataType::HALF); - NDArray x2('c', {2}, {10, 40}, sd::DataType::HALF); - NDArray x3('c', {2,2}, sd::DataType::BOOL); - NDArray x4('c', {0}, std::vector{10}, sd::DataType::HALF); - NDArray x5('c', {0}, std::vector{20}, sd::DataType::HALF); - NDArray x6(sd::DataType::BOOL); - - NDArray exp1('c', {2,2}, {1, 0, 0, 1}, sd::DataType::BOOL); - NDArray exp2('c', {2,2}, {1, 0, 0, 0}, sd::DataType::BOOL); - NDArray exp3('c', {0}, std::vector{0}, sd::DataType::BOOL); - - x1.applyTrueBroadcast(BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, sd::broadcast::EqualTo), x2, x3); - ASSERT_EQ(x3, exp1); - - x1.applyTrueBroadcast(BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, sd::broadcast::EqualTo), x4, x3); - ASSERT_EQ(x3, exp2); - - x4.applyTrueBroadcast(BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, sd::broadcast::EqualTo), x1, x3); - ASSERT_EQ(x3, exp2); - - x5.applyTrueBroadcast(BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, sd::broadcast::EqualTo), x4, x6); - ASSERT_EQ(x6, exp3); + if (!Environment::getInstance()->isExperimentalBuild()) return; + + NDArray x1('c', {2, 2}, {10, 20, 30, 40}, sd::DataType::HALF); + NDArray x2('c', {2}, {10, 40}, sd::DataType::HALF); + NDArray x3('c', {2, 2}, sd::DataType::BOOL); + NDArray x4('c', {0}, std::vector{10}, sd::DataType::HALF); + NDArray x5('c', {0}, std::vector{20}, sd::DataType::HALF); + NDArray x6(sd::DataType::BOOL); + + NDArray exp1('c', {2, 2}, {1, 0, 0, 1}, sd::DataType::BOOL); + NDArray exp2('c', {2, 2}, {1, 0, 0, 0}, sd::DataType::BOOL); + NDArray exp3('c', {0}, std::vector{0}, sd::DataType::BOOL); + + x1.applyTrueBroadcast( + BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, + sd::broadcast::EqualTo), + x2, x3); + ASSERT_EQ(x3, exp1); + + x1.applyTrueBroadcast( + BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, + sd::broadcast::EqualTo), + x4, x3); + ASSERT_EQ(x3, exp2); + + x4.applyTrueBroadcast( + BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, + sd::broadcast::EqualTo), + x1, x3); + ASSERT_EQ(x3, exp2); + + x5.applyTrueBroadcast( + BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, + sd::broadcast::EqualTo), + x4, x6); + ASSERT_EQ(x6, exp3); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyScalar_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x3('c', {2,2}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::BOOL); - NDArray exp1('c', {2,2}, {1, 2, 3, 4}, sd::DataType::INT64); - NDArray exp2('c', {2,2}, {1.5, 2.5, 3.5, 4.5}, sd::DataType::DOUBLE); - NDArray exp3('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {1.1, 2.1, 1.1, 2.1}, sd::DataType::DOUBLE); - NDArray exp5('c', {2,2}, {1, 1, 1, 1}, sd::DataType::BOOL); + NDArray exp1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT64); + NDArray exp2('c', {2, 2}, {1.5, 2.5, 3.5, 4.5}, sd::DataType::DOUBLE); + NDArray exp3('c', {2, 2}, {0.1, 1.6, 2.6, 3.6}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {1.1, 2.1, 1.1, 2.1}, sd::DataType::DOUBLE); + NDArray exp5('c', {2, 2}, {1, 1, 1, 1}, sd::DataType::BOOL); - x1.applyScalar(sd::scalar::Add, 1, x1); - ASSERT_EQ(x1, exp1); + x1.applyScalar(sd::scalar::Add, 1, x1); + ASSERT_EQ(x1, exp1); - x1.applyScalar(sd::scalar::Add, 0.5, x3); - ASSERT_EQ(x3, exp2); + x1.applyScalar(sd::scalar::Add, 0.5, x3); + ASSERT_EQ(x3, exp2); - x2.applyScalar(sd::scalar::Add, 0.1, x2); - ASSERT_EQ(x2, exp3); + x2.applyScalar(sd::scalar::Add, 0.1, x2); + ASSERT_EQ(x2, exp3); - x4.applyScalar(sd::scalar::Add, 1.1, x3); - ASSERT_EQ(x3, exp4); + x4.applyScalar(sd::scalar::Add, 1.1, x3); + ASSERT_EQ(x3, exp4); - x4.applyScalar(sd::scalar::Add, 1, x4); - ASSERT_EQ(x4, exp5); + x4.applyScalar(sd::scalar::Add, 1, x4); + ASSERT_EQ(x4, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyScalar_test2) { + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, {0, 1, 1, 0}, sd::DataType::BOOL); + NDArray x4('c', {2, 2}, sd::DataType::BOOL); - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x2('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x3('c', {2,2}, {0, 1, 1, 0}, sd::DataType::BOOL); - NDArray x4('c', {2,2}, sd::DataType::BOOL); - - - NDArray exp1('c', {2,2}, {0, 1, 0, 0}, sd::DataType::BOOL); - NDArray exp2('c', {2,2}, {0, 1, 1, 0}, sd::DataType::BOOL); - - x1.applyScalar(sd::scalar::EqualTo, 1, x4); - ASSERT_EQ(x4, exp1); + NDArray exp1('c', {2, 2}, {0, 1, 0, 0}, sd::DataType::BOOL); + NDArray exp2('c', {2, 2}, {0, 1, 1, 0}, sd::DataType::BOOL); - x2.applyScalar(sd::scalar::EqualTo, 1.5, x4); - ASSERT_EQ(x4, exp1); + x1.applyScalar(sd::scalar::EqualTo, 1, x4); + ASSERT_EQ(x4, exp1); - x3.applyScalar(sd::scalar::EqualTo, true, x4); - ASSERT_EQ(x4, exp2); + x2.applyScalar(sd::scalar::EqualTo, 1.5, x4); + ASSERT_EQ(x4, exp1); + x3.applyScalar(sd::scalar::EqualTo, true, x4); + ASSERT_EQ(x4, exp2); } #ifndef __CUDABLAS__ ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyLambda_test1) { + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x4('c', {2, 2}, sd::DataType::DOUBLE); + NDArray x5('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x6('c', {2, 2}, {0, -1, -1, 0.1}, sd::DataType::BOOL); + NDArray x7('c', {2, 2}, sd::DataType::BOOL); - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::DOUBLE); - NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x4('c', {2,2}, sd::DataType::DOUBLE); - NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x6('c', {2,2}, {0, -1, -1, 0.1}, sd::DataType::BOOL); - NDArray x7('c', {2,2}, sd::DataType::BOOL); + const float item1 = 0.1; + const double item2 = 0.1; + auto func1 = [=](float elem) { return elem + item1; }; + auto func2 = [=](int elem) { return elem + item1; }; + auto func3 = [=](int elem) { return elem + item2; }; + auto func4 = [=](double elem) { return elem + item1; }; + auto func5 = [=](float elem) { return elem - (int)1; }; - const float item1 = 0.1; - const double item2 = 0.1; - auto func1 = [=](float elem) { return elem + item1; }; - auto func2 = [=](int elem) { return elem + item1; }; - auto func3 = [=](int elem) { return elem + item2; }; - auto func4 = [=](double elem) { return elem + item1; }; - auto func5 = [=](float elem) { return elem - (int)1; }; - - NDArray exp1('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray exp3('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {1, 0, 0, 0}, sd::DataType::BOOL); - - x1.applyLambda(func1, x4); - ASSERT_EQ(x4, exp1); + NDArray exp1('c', {2, 2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray exp3('c', {2, 2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {0.1, 1.6, 2.6, 3.6}, sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {1, 0, 0, 0}, sd::DataType::BOOL); - x2.applyLambda(func1, x2); - ASSERT_EQ(x2, exp2); + x1.applyLambda(func1, x4); + ASSERT_EQ(x4, exp1); - x2.applyLambda(func2, x2); - ASSERT_EQ(x2, exp2); + x2.applyLambda(func1, x2); + ASSERT_EQ(x2, exp2); - x3.applyLambda(func3, x3); - ASSERT_EQ(x3, exp3); + x2.applyLambda(func2, x2); + ASSERT_EQ(x2, exp2); - x5.applyLambda(func4, x5); - // x5.printBuffer(); - ASSERT_EQ(x5, exp4); + x3.applyLambda(func3, x3); + ASSERT_EQ(x3, exp3); - x6.applyLambda(func5, x7); - ASSERT_EQ(x7, exp5); + x5.applyLambda(func4, x5); + // x5.printBuffer(); + ASSERT_EQ(x5, exp4); + + x6.applyLambda(func5, x7); + ASSERT_EQ(x7, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyIndexedLambda_test1) { + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x4('c', {2, 2}, sd::DataType::DOUBLE); + NDArray x5('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x6('c', {2, 2}, {1, -1, -1, 0.1}, sd::DataType::BOOL); + NDArray x7('c', {2, 2}, sd::DataType::BOOL); - NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::DOUBLE); - NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x4('c', {2,2}, sd::DataType::DOUBLE); - NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x6('c', {2,2}, {1, -1, -1, 0.1}, sd::DataType::BOOL); - NDArray x7('c', {2,2}, sd::DataType::BOOL); + const float item1 = 0.1; + const double item2 = 0.1; + auto func1 = [=](Nd4jLong idx, float elem) { return idx + elem + item1; }; + auto func2 = [=](Nd4jLong idx, int elem) { return idx + elem + item1; }; + auto func3 = [=](Nd4jLong idx, int elem) { return idx + elem + item2; }; + auto func4 = [=](Nd4jLong idx, double elem) { return idx + elem + item1; }; + auto func5 = [=](Nd4jLong idx, float elem) { return idx + elem - (int)1; }; - const float item1 = 0.1; - const double item2 = 0.1; - auto func1 = [=](Nd4jLong idx, float elem) { return idx + elem + item1; }; - auto func2 = [=](Nd4jLong idx, int elem) { return idx + elem + item1; }; - auto func3 = [=](Nd4jLong idx, int elem) { return idx + elem + item2; }; - auto func4 = [=](Nd4jLong idx, double elem) { return idx + elem + item1; }; - auto func5 = [=](Nd4jLong idx, float elem) { return idx + elem - (int)1; }; - - NDArray exp1('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {0, 2, 4, 6}, sd::DataType::INT64); - NDArray exp3('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {0.1, 2.6, 4.6, 6.6}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0, 1, 1, 1}, sd::DataType::BOOL); - NDArray exp6('c', {2,2}, {0, 3, 6, 9}, sd::DataType::INT64); - - x1.applyIndexedLambda(func1, x4); - ASSERT_EQ(x4, exp1); + NDArray exp1('c', {2, 2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {0, 2, 4, 6}, sd::DataType::INT64); + NDArray exp3('c', {2, 2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {0.1, 2.6, 4.6, 6.6}, sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {0, 1, 1, 1}, sd::DataType::BOOL); + NDArray exp6('c', {2, 2}, {0, 3, 6, 9}, sd::DataType::INT64); - x2.applyIndexedLambda(func1, x2); - ASSERT_EQ(x2, exp2); + x1.applyIndexedLambda(func1, x4); + ASSERT_EQ(x4, exp1); - x2.applyIndexedLambda(func2, x2); - ASSERT_EQ(x2, exp6); + x2.applyIndexedLambda(func1, x2); + ASSERT_EQ(x2, exp2); - x3.applyIndexedLambda(func3, x3); - ASSERT_EQ(x3, exp3); + x2.applyIndexedLambda(func2, x2); + ASSERT_EQ(x2, exp6); + + x3.applyIndexedLambda(func3, x3); + ASSERT_EQ(x3, exp3); - x5.applyIndexedLambda(func4, x5); - ASSERT_EQ(x5, exp4); + x5.applyIndexedLambda(func4, x5); + ASSERT_EQ(x5, exp4); - x6.applyIndexedLambda(func5, x7); - ASSERT_EQ(x7, exp5); + x6.applyIndexedLambda(func5, x7); + ASSERT_EQ(x7, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyPairwiseLambda_test1) { + NDArray x1('c', {2, 2}, {0., 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x2('c', {2, 2}, {0., 1, 2, 3}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {0., 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x4('c', {2, 2}, sd::DataType::DOUBLE); + NDArray x5('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x6('c', {2, 2}, {0.1, -1, -1, 0.1}, sd::DataType::BOOL); + NDArray x7('c', {2, 2}, sd::DataType::BOOL); + NDArray other1('c', {2, 2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::FLOAT32); + NDArray other2('c', {2, 2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::DOUBLE); + NDArray other3('c', {2, 2}, {0., -1, -2, -3}, sd::DataType::INT64); + NDArray other4('c', {2, 2}, {1, 0, 0.1, 0}, sd::DataType::BOOL); - NDArray x1('c', {2,2}, {0., 1, 2, 3}, sd::DataType::DOUBLE); - NDArray x2('c', {2,2}, {0., 1, 2, 3}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {0., 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x4('c', {2,2}, sd::DataType::DOUBLE); - NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x6('c', {2,2}, {0.1, -1, -1, 0.1}, sd::DataType::BOOL); - NDArray x7('c', {2,2}, sd::DataType::BOOL); - NDArray other1('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::FLOAT32); - NDArray other2('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::DOUBLE); - NDArray other3('c', {2,2}, {0., -1, -2, -3}, sd::DataType::INT64); - NDArray other4('c', {2,2}, {1, 0, 0.1, 0}, sd::DataType::BOOL); - - auto func1 = [](float elem1, float elem2) { return elem1 + elem2; }; - auto func2 = [](int elem1, float elem2) { return elem1 + elem2; }; - auto func3 = [](int elem1, double elem2) { return elem1 + elem2; }; - auto func4 = [](double elem1, float elem2) { return elem1 + elem2; }; - auto func5 = [](float elem1, int elem2) { return elem1 - elem2; }; - - NDArray exp1('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {0., 0, 0, 0}, sd::DataType::INT64); - NDArray exp3('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0., 1, 0, 1}, sd::DataType::BOOL); - - x1.applyPairwiseLambda(other2, func1, x4); - ASSERT_EQ(x4, exp1); + auto func1 = [](float elem1, float elem2) { return elem1 + elem2; }; + auto func2 = [](int elem1, float elem2) { return elem1 + elem2; }; + auto func3 = [](int elem1, double elem2) { return elem1 + elem2; }; + auto func4 = [](double elem1, float elem2) { return elem1 + elem2; }; + auto func5 = [](float elem1, int elem2) { return elem1 - elem2; }; - x2.applyPairwiseLambda(other3, func1, x2); - ASSERT_EQ(x2, exp2); + NDArray exp1('c', {2, 2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {0., 0, 0, 0}, sd::DataType::INT64); + NDArray exp3('c', {2, 2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {0.1, 1.6, 2.6, 3.6}, sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {0., 1, 0, 1}, sd::DataType::BOOL); - x2.applyPairwiseLambda(other3, func2, x2); - ASSERT_EQ(x2, other3); + x1.applyPairwiseLambda(other2, func1, x4); + ASSERT_EQ(x4, exp1); - x3.applyPairwiseLambda(other1, func3, x3); - ASSERT_EQ(x3, exp3); + x2.applyPairwiseLambda(other3, func1, x2); + ASSERT_EQ(x2, exp2); + + x2.applyPairwiseLambda(other3, func2, x2); + ASSERT_EQ(x2, other3); - x5.applyPairwiseLambda(other1, func4, x5); - ASSERT_EQ(x5, exp4); + x3.applyPairwiseLambda(other1, func3, x3); + ASSERT_EQ(x3, exp3); - x6.applyPairwiseLambda(other4, func5, x7); - ASSERT_EQ(x7, exp5); + x5.applyPairwiseLambda(other1, func4, x5); + ASSERT_EQ(x5, exp4); + + x6.applyPairwiseLambda(other4, func5, x7); + ASSERT_EQ(x7, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyIndexedPairwiseLambda_test1) { - - NDArray x1('c', {2,2}, {0., 1, 2, 3}, sd::DataType::DOUBLE); - NDArray x2('c', {2,2}, {0., 1, 2, 3}, sd::DataType::INT64); - NDArray x3('c', {2,2}, {0., 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x4('c', {2,2}, sd::DataType::DOUBLE); - NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); - NDArray x6('c', {2,2}, {0.1, -1, -1, 0.1}, sd::DataType::BOOL); - NDArray x7('c', {2,2}, sd::DataType::BOOL); - NDArray other1('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::FLOAT32); - NDArray other2('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::DOUBLE); - NDArray other3('c', {2,2}, {0., -1, -2, -3}, sd::DataType::INT64); - NDArray other4('c', {2,2}, {1, 0, 0.1, 0}, sd::DataType::BOOL); - - auto func1 = [](Nd4jLong idx, float elem1, float elem2) { return elem1 + elem2 + idx; }; - auto func2 = [](Nd4jLong idx, int elem1, float elem2) { return elem1 + elem2 + idx; }; - auto func3 = [](Nd4jLong idx, int elem1, double elem2) { return elem1 + elem2 + idx; }; - auto func4 = [](Nd4jLong idx, double elem1, float elem2) { return elem1 + elem2 + idx; }; - auto func5 = [](Nd4jLong idx, float elem1, int elem2) { return elem1 - elem2 + idx; }; - - NDArray exp1('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {0., 1, 2, 3}, sd::DataType::INT64); - NDArray exp3('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,2}, {0.1, 2.6, 4.6, 6.6}, sd::DataType::FLOAT32); - NDArray exp5('c', {2,2}, {0., 1, 1, 1}, sd::DataType::BOOL); - - x1.applyIndexedPairwiseLambda(other2, func1, x4); - ASSERT_EQ(x4, exp1); - - x2.applyIndexedPairwiseLambda(other3, func1, x2); - ASSERT_EQ(x2, exp2); - - x2.applyIndexedPairwiseLambda(other3, func2, x2); - ASSERT_EQ(x2, exp2); - - x3.applyIndexedPairwiseLambda(other1, func3, x3); - ASSERT_EQ(x3, exp3); - - x5.applyIndexedPairwiseLambda(other1, func4, x5); - ASSERT_EQ(x5, exp4); - - x6.applyIndexedPairwiseLambda(other4, func5, x7); - ASSERT_EQ(x7, exp5); + NDArray x1('c', {2, 2}, {0., 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x2('c', {2, 2}, {0., 1, 2, 3}, sd::DataType::INT64); + NDArray x3('c', {2, 2}, {0., 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x4('c', {2, 2}, sd::DataType::DOUBLE); + NDArray x5('c', {2, 2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x6('c', {2, 2}, {0.1, -1, -1, 0.1}, sd::DataType::BOOL); + NDArray x7('c', {2, 2}, sd::DataType::BOOL); + NDArray other1('c', {2, 2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::FLOAT32); + NDArray other2('c', {2, 2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::DOUBLE); + NDArray other3('c', {2, 2}, {0., -1, -2, -3}, sd::DataType::INT64); + NDArray other4('c', {2, 2}, {1, 0, 0.1, 0}, sd::DataType::BOOL); + + auto func1 = [](Nd4jLong idx, float elem1, float elem2) { + return elem1 + elem2 + idx; + }; + auto func2 = [](Nd4jLong idx, int elem1, float elem2) { + return elem1 + elem2 + idx; + }; + auto func3 = [](Nd4jLong idx, int elem1, double elem2) { + return elem1 + elem2 + idx; + }; + auto func4 = [](Nd4jLong idx, double elem1, float elem2) { + return elem1 + elem2 + idx; + }; + auto func5 = [](Nd4jLong idx, float elem1, int elem2) { + return elem1 - elem2 + idx; + }; + + NDArray exp1('c', {2, 2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {0., 1, 2, 3}, sd::DataType::INT64); + NDArray exp3('c', {2, 2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::FLOAT32); + NDArray exp4('c', {2, 2}, {0.1, 2.6, 4.6, 6.6}, sd::DataType::FLOAT32); + NDArray exp5('c', {2, 2}, {0., 1, 1, 1}, sd::DataType::BOOL); + + x1.applyIndexedPairwiseLambda(other2, func1, x4); + ASSERT_EQ(x4, exp1); + + x2.applyIndexedPairwiseLambda(other3, func1, x2); + ASSERT_EQ(x2, exp2); + + x2.applyIndexedPairwiseLambda(other3, func2, x2); + ASSERT_EQ(x2, exp2); + + x3.applyIndexedPairwiseLambda(other1, func3, x3); + ASSERT_EQ(x3, exp3); + + x5.applyIndexedPairwiseLambda(other1, func4, x5); + ASSERT_EQ(x5, exp4); + + x6.applyIndexedPairwiseLambda(other4, func5, x7); + ASSERT_EQ(x7, exp5); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyTriplewiseLambda_test1) { + NDArray x1('c', {2, 2}, {0., 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x2('c', {2, 2}, {0., -1, -2, -3}, sd::DataType::DOUBLE); + NDArray x3('c', {2, 2}, {0, -1.5, -2.5, -3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, sd::DataType::DOUBLE); - NDArray x1('c', {2,2}, {0., 1, 2, 3}, sd::DataType::DOUBLE); - NDArray x2('c', {2,2}, {0., -1, -2, -3}, sd::DataType::DOUBLE); - NDArray x3('c', {2,2}, {0, -1.5, -2.5, -3.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, sd::DataType::DOUBLE); - - NDArray x5('c', {2,2}, {0., 1, 2, 3}, sd::DataType::INT32); - NDArray x6('c', {2,2}, {0., -1, -2, -3}, sd::DataType::INT32); - NDArray x7('c', {2,2}, {0., 10, 20, 30}, sd::DataType::INT32); + NDArray x5('c', {2, 2}, {0., 1, 2, 3}, sd::DataType::INT32); + NDArray x6('c', {2, 2}, {0., -1, -2, -3}, sd::DataType::INT32); + NDArray x7('c', {2, 2}, {0., 10, 20, 30}, sd::DataType::INT32); - NDArray x8('c', {2,2}, {0., 1, 0, 1}, sd::DataType::BOOL); - NDArray x9('c', {2,2}, {1., 1, 0, 1}, sd::DataType::BOOL); - NDArray x10('c', {2,2}, {0., 0, 0, 0}, sd::DataType::BOOL); + NDArray x8('c', {2, 2}, {0., 1, 0, 1}, sd::DataType::BOOL); + NDArray x9('c', {2, 2}, {1., 1, 0, 1}, sd::DataType::BOOL); + NDArray x10('c', {2, 2}, {0., 0, 0, 0}, sd::DataType::BOOL); - auto func1 = [](double elem1, float elem2, int elem3) { return elem1 + elem2 + elem3; }; - auto func2 = [](float elem1, float elem2, float elem3) { return elem1 + elem2 + elem3; }; - auto func3 = [](int elem1, int elem2, int elem3) { return elem1 + elem2 + elem3; }; - auto func4 = [](bool elem1, bool elem2, bool elem3) { return elem1 + elem2 + elem3; }; + auto func1 = [](double elem1, float elem2, int elem3) { + return elem1 + elem2 + elem3; + }; + auto func2 = [](float elem1, float elem2, float elem3) { + return elem1 + elem2 + elem3; + }; + auto func3 = [](int elem1, int elem2, int elem3) { + return elem1 + elem2 + elem3; + }; + auto func4 = [](bool elem1, bool elem2, bool elem3) { + return elem1 + elem2 + elem3; + }; - NDArray exp('c', {2,2}, {1., 1, 0, 1}, sd::DataType::BOOL); + NDArray exp('c', {2, 2}, {1., 1, 0, 1}, sd::DataType::BOOL); - x1.applyTriplewiseLambda(x2, x3, func1, x4); - ASSERT_EQ(x4, x2); + x1.applyTriplewiseLambda(x2, x3, func1, x4); + ASSERT_EQ(x4, x2); - x1.applyTriplewiseLambda(x2, x3, func2, x1); - ASSERT_EQ(x1, x3); + x1.applyTriplewiseLambda(x2, x3, func2, x1); + ASSERT_EQ(x1, x3); - x5.applyTriplewiseLambda(x6, x7, func3, x5); - ASSERT_EQ(x5, x7); + x5.applyTriplewiseLambda(x6, x7, func3, x5); + ASSERT_EQ(x5, x7); - x8.applyTriplewiseLambda(x9, x10, func4, x8); - ASSERT_EQ(x8, exp); + x8.applyTriplewiseLambda(x9, x10, func4, x8); + ASSERT_EQ(x8, exp); } #endif ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test1) { + NDArray x1('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray exp1('c', {}, std::vector{5}, sd::DataType::INT64); + NDArray exp2('c', {2}, {2, 2}, sd::DataType::INT64); + NDArray exp3('c', {3}, {1, 1, 1}, sd::DataType::INT64); - NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); - NDArray exp1('c', {}, std::vector{5}, sd::DataType::INT64); - NDArray exp2('c', {2}, {2,2}, sd::DataType::INT64); - NDArray exp3('c', {3}, {1,1,1}, sd::DataType::INT64); - - NDArray scalar = x1.applyIndexReduce(sd::indexreduce::IndexMax, {0,1}); - ASSERT_EQ(scalar, exp1); + NDArray scalar = x1.applyIndexReduce(sd::indexreduce::IndexMax, {0, 1}); + ASSERT_EQ(scalar, exp1); - NDArray vec1 = x1.applyIndexReduce(sd::indexreduce::IndexMax, {1}); - ASSERT_EQ(vec1, exp2); + NDArray vec1 = x1.applyIndexReduce(sd::indexreduce::IndexMax, {1}); + ASSERT_EQ(vec1, exp2); - NDArray vec2 = x1.applyIndexReduce(sd::indexreduce::IndexMax, {0}); - ASSERT_EQ(vec2, exp3); + NDArray vec2 = x1.applyIndexReduce(sd::indexreduce::IndexMax, {0}); + ASSERT_EQ(vec2, exp3); } ////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test2) { + NDArray x1('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray scalar('c', {}, std::vector{5}, sd::DataType::INT64); + NDArray vec1('c', {2}, {2, 2}, sd::DataType::INT64); + NDArray vec2('c', {3}, {1, 1, 1}, sd::DataType::INT64); + NDArray exp1('c', {}, std::vector{5}, sd::DataType::INT64); + NDArray exp2('c', {2}, {2, 2}, sd::DataType::INT64); + NDArray exp3('c', {3}, {1, 1, 1}, sd::DataType::INT64); - NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); - NDArray scalar('c', {}, std::vector{5}, sd::DataType::INT64); - NDArray vec1('c', {2}, {2,2}, sd::DataType::INT64); - NDArray vec2('c', {3}, {1,1,1}, sd::DataType::INT64); - NDArray exp1('c', {}, std::vector{5}, sd::DataType::INT64); - NDArray exp2('c', {2}, {2,2}, sd::DataType::INT64); - NDArray exp3('c', {3}, {1,1,1}, sd::DataType::INT64); + x1.applyIndexReduce(sd::indexreduce::IndexMax, scalar, {0, 1}); + ASSERT_EQ(scalar, exp1); - x1.applyIndexReduce(sd::indexreduce::IndexMax, scalar, {0,1}); - ASSERT_EQ(scalar, exp1); + x1.applyIndexReduce(sd::indexreduce::IndexMax, vec1, {1}); + ASSERT_EQ(vec1, exp2); - x1.applyIndexReduce(sd::indexreduce::IndexMax, vec1, {1}); - ASSERT_EQ(vec1, exp2); - - x1.applyIndexReduce(sd::indexreduce::IndexMax, vec2, {0}); - ASSERT_EQ(vec2, exp3); + x1.applyIndexReduce(sd::indexreduce::IndexMax, vec2, {0}); + ASSERT_EQ(vec2, exp3); } ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, applyReduce3_test1) { + NDArray x1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray x2('c', {2, 2}, {-1, -2, -3, -4}, sd::DataType::INT32); + NDArray x3('c', {2, 2}, {1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::DOUBLE); + NDArray exp1('c', {}, std::vector{-30}, sd::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{15}, sd::DataType::DOUBLE); - NDArray x1('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray x2('c', {2,2}, {-1,-2,-3,-4}, sd::DataType::INT32); - NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); - NDArray exp1('c', {}, std::vector{-30}, sd::DataType::FLOAT32); - NDArray exp2('c', {}, std::vector{15}, sd::DataType::DOUBLE); - - auto result = x1.applyReduce3(reduce3::Dot, x2); - ASSERT_EQ(result, exp1); + auto result = x1.applyReduce3(reduce3::Dot, x2); + ASSERT_EQ(result, exp1); - result = x3.applyReduce3(reduce3::Dot, x4); - ASSERT_EQ(result, exp2); + result = x3.applyReduce3(reduce3::Dot, x4); + ASSERT_EQ(result, exp2); } ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, applyReduce3_test2) { + NDArray x1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray x2('c', {2, 2}, {-1, -2, -3, -4}, sd::DataType::INT32); + NDArray x3('c', {2, 2}, {1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::DOUBLE); + NDArray x5('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT32); + NDArray x6('c', {2, 3}, {-6, -5, -4, -3, -2, -1}, sd::DataType::INT32); + NDArray x7('c', {2, 3}, {1.5, 1.5, 1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray x8('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); - NDArray x1('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray x2('c', {2,2}, {-1,-2,-3,-4}, sd::DataType::INT32); - NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); - NDArray x5('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT32); - NDArray x6('c', {2,3}, {-6,-5,-4,-3,-2,-1}, sd::DataType::INT32); - NDArray x7('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray x8('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); + NDArray exp1('c', {}, std::vector{-30}, sd::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{15}, sd::DataType::DOUBLE); + NDArray exp3('c', {3}, {-18, -20, -18}, sd::DataType::FLOAT32); + NDArray exp4('c', {2}, {-28, -28}, sd::DataType::FLOAT32); + NDArray exp5('c', {3}, {7.5, 10.5, 13.5}, sd::DataType::DOUBLE); + NDArray exp6('c', {2}, {9, 22.5}, sd::DataType::DOUBLE); - NDArray exp1('c', {}, std::vector{-30}, sd::DataType::FLOAT32); - NDArray exp2('c', {}, std::vector{15}, sd::DataType::DOUBLE); - NDArray exp3('c', {3}, {-18,-20,-18}, sd::DataType::FLOAT32); - NDArray exp4('c', {2}, {-28,-28}, sd::DataType::FLOAT32); - NDArray exp5('c', {3}, {7.5,10.5,13.5}, sd::DataType::DOUBLE); - NDArray exp6('c', {2}, {9,22.5}, sd::DataType::DOUBLE); + auto result = x1.applyReduce3(reduce3::Dot, x2, {0, 1}); + ASSERT_EQ(result, exp1); - auto result = x1.applyReduce3(reduce3::Dot, x2, {0,1}); - ASSERT_EQ(result, exp1); + result = x3.applyReduce3(reduce3::Dot, x4, {0, 1}); + ASSERT_EQ(result, exp2); - result = x3.applyReduce3(reduce3::Dot, x4, {0,1}); - ASSERT_EQ(result, exp2); + result = x5.applyReduce3(reduce3::Dot, x6, std::vector({0})); + ASSERT_EQ(result, exp3); - result = x5.applyReduce3(reduce3::Dot, x6, std::vector({0})); - ASSERT_EQ(result, exp3); + result = x5.applyReduce3(reduce3::Dot, x6, std::vector({1})); + ASSERT_EQ(result, exp4); - result = x5.applyReduce3(reduce3::Dot, x6, std::vector({1})); - ASSERT_EQ(result, exp4); + result = x8.applyReduce3(reduce3::Dot, x7, std::vector({0})); + ASSERT_EQ(result, exp5); - result = x8.applyReduce3(reduce3::Dot, x7, std::vector({0})); - ASSERT_EQ(result, exp5); - - result = x8.applyReduce3(reduce3::Dot, x7, std::vector({1})); - ASSERT_EQ(result, exp6); + result = x8.applyReduce3(reduce3::Dot, x7, std::vector({1})); + ASSERT_EQ(result, exp6); } ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, applyAllReduce3_test1) { + NDArray x1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray x2('c', {2, 3}, {-1, 1, -1, 1, -1, 1}, sd::DataType::INT32); + NDArray x3('c', {2, 3}, {1.5, 1.5, 1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::DOUBLE); + NDArray exp1('c', {2, 3}, {2, -2, 2, 2, -2, 2}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 3}, {6, 6, 6, 9, 9, 9}, sd::DataType::DOUBLE); - NDArray x1('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray x2('c', {2,3}, {-1,1,-1,1,-1,1}, sd::DataType::INT32); - NDArray x3('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray x4('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); - NDArray exp1('c', {2,3}, {2,-2,2,2,-2,2}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,3}, {6,6,6,9,9,9}, sd::DataType::DOUBLE); - - auto result = x1.applyAllReduce3(reduce3::Dot, x2, {0}); - ASSERT_EQ(result, exp1); + auto result = x1.applyAllReduce3(reduce3::Dot, x2, {0}); + ASSERT_EQ(result, exp1); - result = x4.applyAllReduce3(reduce3::Dot, x3, {0}); - ASSERT_EQ(result, exp2); + result = x4.applyAllReduce3(reduce3::Dot, x3, {0}); + ASSERT_EQ(result, exp2); } ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, RowCol_test1) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; + if (!Environment::getInstance()->isExperimentalBuild()) return; - NDArray x1('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT32); - NDArray x2('c', {2}, {0.5,0.6}, sd::DataType::FLOAT32); - NDArray x3('c', {3}, {1.5,1.6,1.7}, sd::DataType::FLOAT32); - NDArray x4('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); - NDArray x5('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT32); + NDArray x1('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT32); + NDArray x2('c', {2}, {0.5, 0.6}, sd::DataType::FLOAT32); + NDArray x3('c', {3}, {1.5, 1.6, 1.7}, sd::DataType::FLOAT32); + NDArray x4('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray x5('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT32); - NDArray exp1('c', {2,3}, {2,3,4,5,6,7}, sd::DataType::INT32); - NDArray exp2('c', {2,3}, {0,1,2,3,4,5}, sd::DataType::INT32); - NDArray exp3('c', {2,3}, {1.5,2.5,3.5,4.6,5.6,6.6}, sd::DataType::DOUBLE); - NDArray exp4('c', {2,3}, {0,1,1,2,3,3}, sd::DataType::INT32); + NDArray exp1('c', {2, 3}, {2, 3, 4, 5, 6, 7}, sd::DataType::INT32); + NDArray exp2('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT32); + NDArray exp3('c', {2, 3}, {1.5, 2.5, 3.5, 4.6, 5.6, 6.6}, + sd::DataType::DOUBLE); + NDArray exp4('c', {2, 3}, {0, 1, 1, 2, 3, 3}, sd::DataType::INT32); - x1.addiRowVector(x3); - ASSERT_EQ(x1, exp1); + x1.addiRowVector(x3); + ASSERT_EQ(x1, exp1); - x1.addiColumnVector(x2); - ASSERT_EQ(x1, exp1); + x1.addiColumnVector(x2); + ASSERT_EQ(x1, exp1); - x4.addiColumnVector(x2); - ASSERT_EQ(x4, exp3); + x4.addiColumnVector(x2); + ASSERT_EQ(x4, exp3); - x5.muliColumnVector(x2); - ASSERT_EQ(x5, exp4); + x5.muliColumnVector(x2); + ASSERT_EQ(x5, exp4); } ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, RowCol_test2) { - if (!Environment::getInstance()->isExperimentalBuild()) - return; - - NDArray x1('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT32); - NDArray x2('c', {2}, {0.5,0.6}, sd::DataType::FLOAT32); - NDArray x3('c', {3}, {1.5,1.6,1.7}, sd::DataType::FLOAT32); - NDArray x4('c', {2,3}, sd::DataType::FLOAT32); - NDArray x5('c', {3}, {1,2,3}, sd::DataType::INT64); - NDArray x6('c', {2,3}, sd::DataType::INT32); - NDArray x7('c', {3}, {1.5,1.6,1.7}, sd::DataType::DOUBLE); - NDArray x8('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::FLOAT32); - NDArray x9('c', {3}, {1,2,3}, sd::DataType::DOUBLE); - NDArray x10('c', {2,3}, sd::DataType::DOUBLE); - - NDArray exp1('c', {2,3}, {2.5,3.6,4.7,5.5,6.6,7.7}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,3}, {2, 4, 6, 5, 7, 9}, sd::DataType::INT32); - NDArray exp3('c', {2,3}, {-0.5,0.4,1.3,2.5,3.4,4.3}, sd::DataType::FLOAT32); - NDArray exp4('c', {2,3}, {1,4,9,4,10,18}, sd::DataType::DOUBLE); - NDArray exp5('c', {2,3}, {1,1,1,4,2.5,2}, sd::DataType::DOUBLE); - NDArray exp6('c', {2,3}, {1.5,2.5,3.5,4.6,5.6,6.6}, sd::DataType::FLOAT32); - - x1.addRowVector(x3, x4); - ASSERT_EQ(x4, exp1); + if (!Environment::getInstance()->isExperimentalBuild()) return; + + NDArray x1('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT32); + NDArray x2('c', {2}, {0.5, 0.6}, sd::DataType::FLOAT32); + NDArray x3('c', {3}, {1.5, 1.6, 1.7}, sd::DataType::FLOAT32); + NDArray x4('c', {2, 3}, sd::DataType::FLOAT32); + NDArray x5('c', {3}, {1, 2, 3}, sd::DataType::INT64); + NDArray x6('c', {2, 3}, sd::DataType::INT32); + NDArray x7('c', {3}, {1.5, 1.6, 1.7}, sd::DataType::DOUBLE); + NDArray x8('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::FLOAT32); + NDArray x9('c', {3}, {1, 2, 3}, sd::DataType::DOUBLE); + NDArray x10('c', {2, 3}, sd::DataType::DOUBLE); - x1.addRowVector(x5, x6); - ASSERT_EQ(x6, exp2); + NDArray exp1('c', {2, 3}, {2.5, 3.6, 4.7, 5.5, 6.6, 7.7}, + sd::DataType::FLOAT32); + NDArray exp2('c', {2, 3}, {2, 4, 6, 5, 7, 9}, sd::DataType::INT32); + NDArray exp3('c', {2, 3}, {-0.5, 0.4, 1.3, 2.5, 3.4, 4.3}, + sd::DataType::FLOAT32); + NDArray exp4('c', {2, 3}, {1, 4, 9, 4, 10, 18}, sd::DataType::DOUBLE); + NDArray exp5('c', {2, 3}, {1, 1, 1, 4, 2.5, 2}, sd::DataType::DOUBLE); + NDArray exp6('c', {2, 3}, {1.5, 2.5, 3.5, 4.6, 5.6, 6.6}, + sd::DataType::FLOAT32); - x8.subRowVector(x7, x4); - ASSERT_EQ(x4, exp3); + x1.addRowVector(x3, x4); + ASSERT_EQ(x4, exp1); - x1.mulRowVector(x9, x10); - ASSERT_EQ(x10, exp4); + x1.addRowVector(x5, x6); + ASSERT_EQ(x6, exp2); - x1.divRowVector(x9, x10); - ASSERT_EQ(x10, exp5); + x8.subRowVector(x7, x4); + ASSERT_EQ(x4, exp3); - x1.addColumnVector(x2, x4); - ASSERT_EQ(x4, exp6); + x1.mulRowVector(x9, x10); + ASSERT_EQ(x10, exp4); + + x1.divRowVector(x9, x10); + ASSERT_EQ(x10, exp5); + + x1.addColumnVector(x2, x4); + ASSERT_EQ(x4, exp6); } ////////////////////////////////////////////////////////////////////// @@ -1805,178 +1813,167 @@ TEST_F(MultiDataTypeTests, tile_test1) { ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, asT_test1) { + NDArray x1('c', {2}, {1.5, 2.5}, sd::DataType::FLOAT32); - NDArray x1('c', {2}, {1.5, 2.5}, sd::DataType::FLOAT32); - - NDArray exp1('c', {2}, {1, 2}, sd::DataType::INT32); - NDArray exp2('c', {2}, {1.5, 2.5}, sd::DataType::DOUBLE); + NDArray exp1('c', {2}, {1, 2}, sd::DataType::INT32); + NDArray exp2('c', {2}, {1.5, 2.5}, sd::DataType::DOUBLE); - auto result = new NDArray(x1.asT()); - ASSERT_EQ(*result, exp1); - delete result; + auto result = new NDArray(x1.asT()); + ASSERT_EQ(*result, exp1); + delete result; - result = new NDArray(x1.asT()); - ASSERT_EQ(*result, exp2); - delete result; + result = new NDArray(x1.asT()); + ASSERT_EQ(*result, exp2); + delete result; - result = new NDArray(x1.asT(sd::DataType::INT32)); - ASSERT_EQ(*result, exp1); - delete result; + result = new NDArray(x1.asT(sd::DataType::INT32)); + ASSERT_EQ(*result, exp1); + delete result; - result = new NDArray(x1.asT(sd::DataType::DOUBLE)); - ASSERT_EQ(*result, exp2); - delete result; + result = new NDArray(x1.asT(sd::DataType::DOUBLE)); + ASSERT_EQ(*result, exp2); + delete result; } ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, assign_test2) { + NDArray x1('c', {2, 3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, + sd::DataType::FLOAT32); + NDArray x2('c', {3, 2}, sd::DataType::INT32); + NDArray x3('c', {3, 2}, sd::DataType::DOUBLE); + NDArray x4('c', {3, 2}, sd::DataType::BOOL); + NDArray x5('c', {2, 3}, {1.5, 2.5, 0, 4.5, 5.5, 6.5}, sd::DataType::FLOAT32); - NDArray x1('c', {2,3}, {1.5,2.5,3.5,4.5,5.5,6.5}, sd::DataType::FLOAT32); - NDArray x2('c', {3,2}, sd::DataType::INT32); - NDArray x3('c', {3,2}, sd::DataType::DOUBLE); - NDArray x4('c', {3,2}, sd::DataType::BOOL); - NDArray x5('c', {2,3}, {1.5,2.5,0,4.5,5.5,6.5}, sd::DataType::FLOAT32); + NDArray exp1('c', {3, 2}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT32); + NDArray exp2('c', {3, 2}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, + sd::DataType::DOUBLE); + NDArray exp3('c', {3, 2}, {1, 1, 0, 1, 1, 1}, sd::DataType::BOOL); - NDArray exp1('c', {3,2}, {1, 2,3,4,5,6}, sd::DataType::INT32); - NDArray exp2('c', {3,2}, {1.5,2.5,3.5,4.5,5.5,6.5}, sd::DataType::DOUBLE); - NDArray exp3('c', {3,2}, {1,1,0,1,1,1}, sd::DataType::BOOL); + x2.assign(x1); + ASSERT_EQ(x2, exp1); - x2.assign(x1); - ASSERT_EQ(x2, exp1); + x3.assign(x1); + ASSERT_EQ(x3, exp2); - x3.assign(x1); - ASSERT_EQ(x3, exp2); - - x4.assign(x5); - ASSERT_EQ(x4, exp3); + x4.assign(x5); + ASSERT_EQ(x4, exp3); } TEST_F(MultiDataTypeTests, Test_Cast_1) { - auto first = NDArrayFactory::create('c', {10}); - auto asBool = NDArrayFactory::create('c', {10}); - auto _not = NDArrayFactory::create('c', {10}); - auto asFloat = NDArrayFactory::create('c', {10}); - auto exp = NDArrayFactory::create('c', {10}); - exp.assign(0.0f); + auto first = NDArrayFactory::create('c', {10}); + auto asBool = NDArrayFactory::create('c', {10}); + auto _not = NDArrayFactory::create('c', {10}); + auto asFloat = NDArrayFactory::create('c', {10}); + auto exp = NDArrayFactory::create('c', {10}); + exp.assign(0.0f); - asBool.assign(first); + asBool.assign(first); - // asBool.printIndexedBuffer("asBool"); - asBool.applyScalar(scalar::Not, false, _not); + // asBool.printIndexedBuffer("asBool"); + asBool.applyScalar(scalar::Not, false, _not); - // _not.printIndexedBuffer("_not"); + // _not.printIndexedBuffer("_not"); - asFloat.assign(_not); + asFloat.assign(_not); - // asFloat.printIndexedBuffer("asFloat"); - ASSERT_EQ(exp, asFloat); + // asFloat.printIndexedBuffer("asFloat"); + ASSERT_EQ(exp, asFloat); } TEST_F(MultiDataTypeTests, Test_Cast_2) { - auto first = NDArrayFactory::create('c', {10}); - auto asBool = NDArrayFactory::create('c', {10}); - auto _not = NDArrayFactory::create('c', {10}); - auto asFloat = NDArrayFactory::create('c', {10}); - auto exp = NDArrayFactory::create('c', {10}); - exp.assign(1.0f); + auto first = NDArrayFactory::create('c', {10}); + auto asBool = NDArrayFactory::create('c', {10}); + auto _not = NDArrayFactory::create('c', {10}); + auto asFloat = NDArrayFactory::create('c', {10}); + auto exp = NDArrayFactory::create('c', {10}); + exp.assign(1.0f); - asBool.assign(first); + asBool.assign(first); - // asBool.printIndexedBuffer("asBool"); - asBool.applyTransform(transform::Not, _not); + // asBool.printIndexedBuffer("asBool"); + asBool.applyTransform(transform::Not, _not); - // _not.printIndexedBuffer("_not"); + // _not.printIndexedBuffer("_not"); - asFloat.assign(_not); + asFloat.assign(_not); - // asFloat.printIndexedBuffer("asFloat"); - ASSERT_EQ(exp, asFloat); + // asFloat.printIndexedBuffer("asFloat"); + ASSERT_EQ(exp, asFloat); } ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, divide_bool_test1) { - - NDArray x1('c', {2,3}, {1.5,0,3.5,0,5.5,6.5}, sd::DataType::FLOAT32); - NDArray x2('c', {3,2}, {1,1,0,1,0,1}, sd::DataType::BOOL); - NDArray x3('c', {2,3}, sd::DataType::FLOAT32); - NDArray x4('c', {2}, sd::DataType::BOOL); - - try { - NDArray x3 = x1 / x2; - } - catch (std::exception& message) { - // printf("%s\n", message.what()); - ASSERT_TRUE(1); - } - - try { - x1 /= x2; - } - catch (std::exception& message) { - // printf("%s\n", message.what()); - ASSERT_TRUE(1); - } - - try { - NDArray x3 = 150. / x2; - } - catch (std::exception& message) { - // printf("%s\n", message.what()); - ASSERT_TRUE(1); - } - - try { - x1.divRowVector(x4, x3); - } - catch (std::exception& message) { - // printf("%s\n", message.what()); - ASSERT_TRUE(1); - } - - try { - x1.applyBroadcast(sd::broadcast::FloorDiv, {1}, x4, x3); - } - catch (std::exception& message) { - // printf("%s\n", message.what()); - ASSERT_TRUE(1); - } - - try { - x1.applyTrueBroadcast(BROADCAST(FloorMod), x2, x3); - } - catch (std::exception& message) { - // printf("%s\n", message.what()); - ASSERT_TRUE(1); - } + NDArray x1('c', {2, 3}, {1.5, 0, 3.5, 0, 5.5, 6.5}, sd::DataType::FLOAT32); + NDArray x2('c', {3, 2}, {1, 1, 0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x3('c', {2, 3}, sd::DataType::FLOAT32); + NDArray x4('c', {2}, sd::DataType::BOOL); + + try { + NDArray x3 = x1 / x2; + } catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } + + try { + x1 /= x2; + } catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } + + try { + NDArray x3 = 150. / x2; + } catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } + + try { + x1.divRowVector(x4, x3); + } catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } + + try { + x1.applyBroadcast(sd::broadcast::FloorDiv, {1}, x4, x3); + } catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } + + try { + x1.applyTrueBroadcast(BROADCAST(FloorMod), x2, x3); + } catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } } - ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, aaa) { + NDArray z('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sd::DataType::DOUBLE); + z.permutei({1, 0}); - NDArray z('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); - z.permutei({1,0}); - - sd::graph::RandomGenerator gen(119,5); - ExtraArguments extras({1.5, 2.5}); - - NativeOpExecutioner::execRandom(LaunchContext::defaultContext(), sd::random::UniformDistribution, - &gen, - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - extras.argumentsAsT()); - // z.printIndexedBuffer(); + sd::graph::RandomGenerator gen(119, 5); + ExtraArguments extras({1.5, 2.5}); + NativeOpExecutioner::execRandom( + LaunchContext::defaultContext(), sd::random::UniformDistribution, &gen, + z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + extras.argumentsAsT()); + // z.printIndexedBuffer(); } ////////////////////////////////////////////////////////////////////// -TEST_F(MultiDataTypeTests, assign_2) -{ - NDArray x('c', {4}, {1.5,2.5,3.5,4.5}, sd::DataType::FLOAT32); - NDArray y('c', {4}, sd::DataType::INT32); - NDArray expected('c', {4}, {1,2,3,4}, sd::DataType::INT32); +TEST_F(MultiDataTypeTests, assign_2) { + NDArray x('c', {4}, {1.5, 2.5, 3.5, 4.5}, sd::DataType::FLOAT32); + NDArray y('c', {4}, sd::DataType::INT32); + NDArray expected('c', {4}, {1, 2, 3, 4}, sd::DataType::INT32); - y.assign(x); - // y.printBuffer(); + y.assign(x); + // y.printBuffer(); - ASSERT_TRUE(expected.equalsTo(&y)); + ASSERT_TRUE(expected.equalsTo(&y)); } diff --git a/libnd4j/tests_cpu/layers_tests/MultiDeviceTests.cpp b/libnd4j/tests_cpu/layers_tests/MultiDeviceTests.cpp index 1c12f2d7258f..306fc1525b56 100644 --- a/libnd4j/tests_cpu/layers_tests/MultiDeviceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MultiDeviceTests.cpp @@ -18,51 +18,52 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include -#include #include #include -#include +#include #include +#include + #include +#include "testlayers.h" using namespace sd; class MultiDeviceTests : public testing::Test { -public: - + public: }; -void createArrays(int limit, std::vector &arrays) { - auto deviceId = AffinityManager::currentDeviceId(); - auto numDevices = AffinityManager::numberOfDevices(); +void createArrays(int limit, std::vector &arrays) { + auto deviceId = AffinityManager::currentDeviceId(); + auto numDevices = AffinityManager::numberOfDevices(); - for (int e = 0; e < limit; e++) { - auto value = deviceId * limit + e; - arrays[value] = NDArrayFactory::create_('c', {10}); - arrays[value]->assign(value); - //nd4j_printf("device_%i; value: [%i]; mean: [%f]\n", deviceId, value, arrays[value]->meanNumber().e(0)); - } + for (int e = 0; e < limit; e++) { + auto value = deviceId * limit + e; + arrays[value] = NDArrayFactory::create_('c', {10}); + arrays[value]->assign(value); + // nd4j_printf("device_%i; value: [%i]; mean: [%f]\n", deviceId, value, + // arrays[value]->meanNumber().e(0)); + } } TEST_F(MultiDeviceTests, test_multi_device_migration_1) { - auto deviceId = AffinityManager::currentDeviceId(); - auto numDevices = AffinityManager::numberOfDevices(); - auto numArrays = 10; - std::vector arrays(numDevices * numArrays); + auto deviceId = AffinityManager::currentDeviceId(); + auto numDevices = AffinityManager::numberOfDevices(); + auto numArrays = 10; + std::vector arrays(numDevices * numArrays); - // filling list of arrays on multiple threads - for (int e = 0; e < numDevices; e++) { - std::thread t1(createArrays, numArrays, std::ref(arrays)); + // filling list of arrays on multiple threads + for (int e = 0; e < numDevices; e++) { + std::thread t1(createArrays, numArrays, std::ref(arrays)); - t1.join(); - } + t1.join(); + } - // at this moment all arrays are build, so we can test migration - for (int e = 0; e < arrays.size(); e++) { - ASSERT_NEAR((float) e, arrays[e]->meanNumber().e(0), 1e-5f); - delete arrays[e]; - } + // at this moment all arrays are build, so we can test migration + for (int e = 0; e < arrays.size(); e++) { + ASSERT_NEAR((float)e, arrays[e]->meanNumber().e(0), 1e-5f); + delete arrays[e]; + } } diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayConstructorsTests.cu b/libnd4j/tests_cpu/layers_tests/NDArrayConstructorsTests.cu index 24ac087d1673..6428169853be 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayConstructorsTests.cu +++ b/libnd4j/tests_cpu/layers_tests/NDArrayConstructorsTests.cu @@ -18,189 +18,186 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include #include +#include +#include #include #include #include #include -#include -#include #include +#include -#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class NDArrayConstructorsTests : public testing::Test { -public: - + public: }; TEST_F(NDArrayConstructorsTests, test_constructor_1) { - auto x = NDArrayFactory::empty_(); + auto x = NDArrayFactory::empty_(); - ASSERT_TRUE(x->buffer() == nullptr); - ASSERT_TRUE(x->specialBuffer() == nullptr); + ASSERT_TRUE(x->buffer() == nullptr); + ASSERT_TRUE(x->specialBuffer() == nullptr); - ASSERT_FALSE(x->shapeInfo() == nullptr); - ASSERT_FALSE(x->specialShapeInfo() == nullptr); + ASSERT_FALSE(x->shapeInfo() == nullptr); + ASSERT_FALSE(x->specialShapeInfo() == nullptr); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_TRUE(x->isActualOnHostSide()); + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_TRUE(x->isActualOnHostSide()); - delete x; + delete x; } TEST_F(NDArrayConstructorsTests, test_constructor_2) { - auto x = NDArrayFactory::vector(5, 1.0f); + auto x = NDArrayFactory::vector(5, 1.0f); + ASSERT_FALSE(x->buffer() == nullptr); + ASSERT_FALSE(x->specialBuffer() == nullptr); - ASSERT_FALSE(x->buffer() == nullptr); - ASSERT_FALSE(x->specialBuffer() == nullptr); + ASSERT_FALSE(x->shapeInfo() == nullptr); + ASSERT_FALSE(x->specialShapeInfo() == nullptr); - ASSERT_FALSE(x->shapeInfo() == nullptr); - ASSERT_FALSE(x->specialShapeInfo() == nullptr); + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_FALSE(x->isActualOnHostSide()); - - delete x; + delete x; } TEST_F(NDArrayConstructorsTests, test_constructor_3) { - auto x = NDArrayFactory::create('c',{5, 5}); + auto x = NDArrayFactory::create('c', {5, 5}); - ASSERT_TRUE(x.buffer() == nullptr); - ASSERT_FALSE(x.specialBuffer() == nullptr); + ASSERT_TRUE(x.buffer() == nullptr); + ASSERT_FALSE(x.specialBuffer() == nullptr); - ASSERT_FALSE(x.shapeInfo() == nullptr); - ASSERT_FALSE(x.specialShapeInfo() == nullptr); + ASSERT_FALSE(x.shapeInfo() == nullptr); + ASSERT_FALSE(x.specialShapeInfo() == nullptr); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); } TEST_F(NDArrayConstructorsTests, test_constructor_4) { - auto x = NDArrayFactory::create(sd::DataType::FLOAT32, 1.0f); + auto x = NDArrayFactory::create(sd::DataType::FLOAT32, 1.0f); - ASSERT_FALSE(x.buffer() == nullptr); - ASSERT_FALSE(x.specialBuffer() == nullptr); + ASSERT_FALSE(x.buffer() == nullptr); + ASSERT_FALSE(x.specialBuffer() == nullptr); - ASSERT_FALSE(x.shapeInfo() == nullptr); - ASSERT_FALSE(x.specialShapeInfo() == nullptr); + ASSERT_FALSE(x.shapeInfo() == nullptr); + ASSERT_FALSE(x.specialShapeInfo() == nullptr); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_TRUE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_TRUE(x.isActualOnHostSide()); } TEST_F(NDArrayConstructorsTests, test_constructor_5) { - auto x = NDArrayFactory::create('c',{2, 2}, {1, 2, 3, 4}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - ASSERT_TRUE(x.buffer() == nullptr); - ASSERT_FALSE(x.specialBuffer() == nullptr); + ASSERT_TRUE(x.buffer() == nullptr); + ASSERT_FALSE(x.specialBuffer() == nullptr); - ASSERT_FALSE(x.shapeInfo() == nullptr); - ASSERT_FALSE(x.specialShapeInfo() == nullptr); + ASSERT_FALSE(x.shapeInfo() == nullptr); + ASSERT_FALSE(x.specialShapeInfo() == nullptr); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); } TEST_F(NDArrayConstructorsTests, test_constructor_6) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - NDArray y(x); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + NDArray y(x); - ASSERT_TRUE(y.buffer() == nullptr); - ASSERT_FALSE(y.specialBuffer() == nullptr); + ASSERT_TRUE(y.buffer() == nullptr); + ASSERT_FALSE(y.specialBuffer() == nullptr); - ASSERT_FALSE(y.shapeInfo() == nullptr); - ASSERT_FALSE(y.specialShapeInfo() == nullptr); + ASSERT_FALSE(y.shapeInfo() == nullptr); + ASSERT_FALSE(y.specialShapeInfo() == nullptr); - ASSERT_TRUE(y.isActualOnDeviceSide()); - ASSERT_FALSE(y.isActualOnHostSide()); + ASSERT_TRUE(y.isActualOnDeviceSide()); + ASSERT_FALSE(y.isActualOnHostSide()); } TEST_F(NDArrayConstructorsTests, test_constructor_7) { - auto x = NDArrayFactory::create(1.0f); + auto x = NDArrayFactory::create(1.0f); - ASSERT_FALSE(x.buffer() == nullptr); - ASSERT_FALSE(x.specialBuffer() == nullptr); + ASSERT_FALSE(x.buffer() == nullptr); + ASSERT_FALSE(x.specialBuffer() == nullptr); - ASSERT_FALSE(x.shapeInfo() == nullptr); - ASSERT_FALSE(x.specialShapeInfo() == nullptr); + ASSERT_FALSE(x.shapeInfo() == nullptr); + ASSERT_FALSE(x.specialShapeInfo() == nullptr); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_TRUE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_TRUE(x.isActualOnHostSide()); } TEST_F(NDArrayConstructorsTests, test_constructor_8) { - auto x = NDArrayFactory::create_('c',{2, 2}, {1, 2, 3, 4}); + auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); - ASSERT_TRUE(x->buffer() == nullptr); - ASSERT_FALSE(x->specialBuffer() == nullptr); + ASSERT_TRUE(x->buffer() == nullptr); + ASSERT_FALSE(x->specialBuffer() == nullptr); - ASSERT_FALSE(x->shapeInfo() == nullptr); - ASSERT_FALSE(x->specialShapeInfo() == nullptr); + ASSERT_FALSE(x->shapeInfo() == nullptr); + ASSERT_FALSE(x->specialShapeInfo() == nullptr); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_FALSE(x->isActualOnHostSide()); + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); - delete x; + delete x; } TEST_F(NDArrayConstructorsTests, test_constructor_9) { - auto x = NDArrayFactory::create_('c',{2, 2}); + auto x = NDArrayFactory::create_('c', {2, 2}); - ASSERT_TRUE(x->buffer() == nullptr); - ASSERT_FALSE(x->specialBuffer() == nullptr); + ASSERT_TRUE(x->buffer() == nullptr); + ASSERT_FALSE(x->specialBuffer() == nullptr); - ASSERT_FALSE(x->shapeInfo() == nullptr); - ASSERT_FALSE(x->specialShapeInfo() == nullptr); + ASSERT_FALSE(x->shapeInfo() == nullptr); + ASSERT_FALSE(x->specialShapeInfo() == nullptr); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_FALSE(x->isActualOnHostSide()); + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); - delete x; + delete x; } TEST_F(NDArrayConstructorsTests, test_linspace_1) { - auto x = NDArrayFactory::linspace(1.0f, 10.0f, 20); + auto x = NDArrayFactory::linspace(1.0f, 10.0f, 20); - ASSERT_FALSE(x->buffer() == nullptr); - ASSERT_FALSE(x->specialBuffer() == nullptr); + ASSERT_FALSE(x->buffer() == nullptr); + ASSERT_FALSE(x->specialBuffer() == nullptr); - ASSERT_FALSE(x->shapeInfo() == nullptr); - ASSERT_FALSE(x->specialShapeInfo() == nullptr); + ASSERT_FALSE(x->shapeInfo() == nullptr); + ASSERT_FALSE(x->specialShapeInfo() == nullptr); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_TRUE(x->isActualOnHostSide()); + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_TRUE(x->isActualOnHostSide()); - delete x; + delete x; } TEST_F(NDArrayConstructorsTests, test_constructor_10) { - - NDArray scalar1(sd::DataType::DOUBLE); // scalar1 = 0 - NDArray scalar2('c', {}, std::vector{0}); - - ASSERT_TRUE(scalar1.isActualOnDeviceSide()); - ASSERT_TRUE(!scalar1.isActualOnHostSide()); - ASSERT_TRUE(scalar2.isActualOnDeviceSide()); - ASSERT_TRUE(scalar2.isActualOnHostSide()); - - ASSERT_TRUE(scalar2.equalsTo(scalar1)); - - ASSERT_TRUE(scalar1.isActualOnDeviceSide()); - ASSERT_TRUE(!scalar1.isActualOnHostSide()); - ASSERT_TRUE(scalar2.isActualOnDeviceSide()); - ASSERT_TRUE(scalar2.isActualOnHostSide()); - - ASSERT_TRUE(scalar1.buffer() == nullptr); - ASSERT_TRUE(scalar1.specialBuffer() != nullptr); - ASSERT_TRUE(scalar1.shapeInfo() != nullptr); - ASSERT_TRUE(scalar1.specialShapeInfo() != nullptr); - ASSERT_TRUE(scalar1.lengthOf() == 1); + NDArray scalar1(sd::DataType::DOUBLE); // scalar1 = 0 + NDArray scalar2('c', {}, std::vector{0}); + + ASSERT_TRUE(scalar1.isActualOnDeviceSide()); + ASSERT_TRUE(!scalar1.isActualOnHostSide()); + ASSERT_TRUE(scalar2.isActualOnDeviceSide()); + ASSERT_TRUE(scalar2.isActualOnHostSide()); + + ASSERT_TRUE(scalar2.equalsTo(scalar1)); + + ASSERT_TRUE(scalar1.isActualOnDeviceSide()); + ASSERT_TRUE(!scalar1.isActualOnHostSide()); + ASSERT_TRUE(scalar2.isActualOnDeviceSide()); + ASSERT_TRUE(scalar2.isActualOnHostSide()); + + ASSERT_TRUE(scalar1.buffer() == nullptr); + ASSERT_TRUE(scalar1.specialBuffer() != nullptr); + ASSERT_TRUE(scalar1.shapeInfo() != nullptr); + ASSERT_TRUE(scalar1.specialShapeInfo() != nullptr); + ASSERT_TRUE(scalar1.lengthOf() == 1); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu index f95705f08e93..f6af6438f4e6 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu +++ b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu @@ -14,2185 +14,2402 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ - // - // @author raver119@gmail.com - // +// +// @author raver119@gmail.com +// -#include "testlayers.h" #include #include +#include +#include #include #include #include #include -#include -#include #include #include +#include -#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class NDArrayCudaBasicsTests : public testing::Test { -public: - + public: }; ////////////////////////////////////////////////////////////////////////// -static cudaError_t allocateDeviceMem(LaunchContext& lc, std::vector& devicePtrs, const std::vector>& hostData) { - - if(devicePtrs.size() != hostData.size()) - throw std::invalid_argument("prepareDataForCuda: two input sts::vectors should same sizes !"); - - cudaError_t cudaResult; - - void* reductionPointer; - cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); if(cudaResult != 0) return cudaResult; - int* allocationPointer; - cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); if(cudaResult != 0) return cudaResult; - - lc.setReductionPointer(reductionPointer); - lc.setAllocationPointer(allocationPointer); - cudaStream_t stream = *lc.getCudaStream(); - - for(int i = 0; i < devicePtrs.size(); ++i) { - - cudaResult = cudaMalloc(reinterpret_cast(&devicePtrs[i]), hostData[i].second); if(cudaResult != 0) return cudaResult; - cudaMemcpyAsync(devicePtrs[i], hostData[i].first, hostData[i].second, cudaMemcpyHostToDevice, stream); - } - return cudaResult; +static cudaError_t allocateDeviceMem( + LaunchContext& lc, std::vector& devicePtrs, + const std::vector>& hostData) { + if (devicePtrs.size() != hostData.size()) + throw std::invalid_argument( + "prepareDataForCuda: two input sts::vectors should same sizes !"); + + cudaError_t cudaResult; + + void* reductionPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&reductionPointer), 1024 * 1024); + if (cudaResult != 0) return cudaResult; + int* allocationPointer; + cudaResult = + cudaMalloc(reinterpret_cast(&allocationPointer), 1024 * 1024); + if (cudaResult != 0) return cudaResult; + + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + cudaStream_t stream = *lc.getCudaStream(); + + for (int i = 0; i < devicePtrs.size(); ++i) { + cudaResult = cudaMalloc(reinterpret_cast(&devicePtrs[i]), + hostData[i].second); + if (cudaResult != 0) return cudaResult; + cudaMemcpyAsync(devicePtrs[i], hostData[i].first, hostData[i].second, + cudaMemcpyHostToDevice, stream); + } + return cudaResult; } TEST_F(NDArrayCudaBasicsTests, Test_Registration_1) { - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}, {5, 4, 3, 2, 1}); + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {5, 4, 3, 2, 1}); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); } TEST_F(NDArrayCudaBasicsTests, Test_Registration_2) { - auto x = NDArrayFactory::create('c', {5}); - auto y = NDArrayFactory::create('c', {5}); + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); } TEST_F(NDArrayCudaBasicsTests, Test_Registration_3) { - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}, {5, 4, 3, 2, 1}); + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {5, 4, 3, 2, 1}); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); - NDArray::registerSpecialUse({&x}, {&y}); + NDArray::registerSpecialUse({&x}, {&y}); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); - ASSERT_TRUE(y.isActualOnDeviceSide()); - ASSERT_FALSE(y.isActualOnHostSide()); + ASSERT_TRUE(y.isActualOnDeviceSide()); + ASSERT_FALSE(y.isActualOnHostSide()); } TEST_F(NDArrayCudaBasicsTests, Test_Registration_01) { - auto x = NDArrayFactory::create_('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create_('c', {5}, {5, 4, 3, 2, 1}); + auto x = NDArrayFactory::create_('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create_('c', {5}, {5, 4, 3, 2, 1}); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_FALSE(x->isActualOnHostSide()); - delete x; - delete y; + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); + delete x; + delete y; } TEST_F(NDArrayCudaBasicsTests, Test_Registration_02) { - auto x = NDArrayFactory::create_('c', {5}); - auto y = NDArrayFactory::create_('c', {5}); + auto x = NDArrayFactory::create_('c', {5}); + auto y = NDArrayFactory::create_('c', {5}); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_FALSE(x->isActualOnHostSide()); - delete x; - delete y; + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); + delete x; + delete y; } TEST_F(NDArrayCudaBasicsTests, Test_Registration_03) { - auto x = NDArrayFactory::create_('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create_('c', {5}, {5, 4, 3, 2, 1}); + auto x = NDArrayFactory::create_('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create_('c', {5}, {5, 4, 3, 2, 1}); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_FALSE(x->isActualOnHostSide()); + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); - NDArray::registerSpecialUse({y}, {x}); - x->applyTransform(transform::Neg, *y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); + NDArray::registerSpecialUse({y}, {x}); + x->applyTransform(transform::Neg, *y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //y->syncToHost(); - // y->printBuffer("Negatives"); - delete x; - delete y; + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // y->syncToHost(); + // y->printBuffer("Negatives"); + delete x; + delete y; } TEST_F(NDArrayCudaBasicsTests, Test_Cosine_1) { - auto x = NDArrayFactory::create_('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create_('c', {5}, {5, 4, 3, 2, 1}); + auto x = NDArrayFactory::create_('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create_('c', {5}, {5, 4, 3, 2, 1}); - ASSERT_TRUE(x->isActualOnDeviceSide()); - ASSERT_FALSE(x->isActualOnHostSide()); + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); - NDArray::registerSpecialUse({y}, {x}); - x->applyTransform(transform::Cosine, *y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); + NDArray::registerSpecialUse({y}, {x}); + x->applyTransform(transform::Cosine, *y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //y->syncToHost(); - delete x; - delete y; + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // y->syncToHost(); + delete x; + delete y; } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestAdd_1) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto z = NDArrayFactory::create('c', { 5 }, {10, 10, 10, 10, 10}); - - auto exp = NDArrayFactory::create('c', { 5 }, { 2, 4, 6, 8, 10 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - - Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); - CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", sizeof(cudaStream_t)); - cudaError_t dZ = cudaStreamCreate(reinterpret_cast(&nativeStream)); - auto stream = reinterpret_cast(&nativeStream); - - //cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), cudaMemcpyHostToDevice, *stream); - //cudaMemcpyAsync(devShapePtrX, x.shapeInfo(), shape::shapeInfoByteLength(x.shapeInfo()), cudaMemcpyHostToDevice, *stream); - - LaunchContext lc(stream, nullptr, nullptr); - NativeOpExecutioner::execPairwiseTransform(&lc, pairwise::Add, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr); - z.tickWriteDevice(); - auto res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); - - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto z = NDArrayFactory::create('c', {5}, {10, 10, 10, 10, 10}); + + auto exp = NDArrayFactory::create('c', {5}, {2, 4, 6, 8, 10}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); + + Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); + CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", + sizeof(cudaStream_t)); + cudaError_t dZ = + cudaStreamCreate(reinterpret_cast(&nativeStream)); + auto stream = reinterpret_cast(&nativeStream); + + // cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), + // cudaMemcpyHostToDevice, *stream); cudaMemcpyAsync(devShapePtrX, + // x.shapeInfo(), shape::shapeInfoByteLength(x.shapeInfo()), + // cudaMemcpyHostToDevice, *stream); + + LaunchContext lc(stream, nullptr, nullptr); + NativeOpExecutioner::execPairwiseTransform( + &lc, pairwise::Add, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr); + z.tickWriteDevice(); + auto res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestAdd_2) { - // allocating host-side arrays - NDArray x('c', { 5 }, { 1, 2, 3, 4, 5}); - NDArray y('c', { 5 }, { 1, 2, 3, 4, 5}); - NDArray z('c', { 5 }, sd::DataType::DOUBLE); + // allocating host-side arrays + NDArray x('c', {5}, {1, 2, 3, 4, 5}); + NDArray y('c', {5}, {1, 2, 3, 4, 5}); + NDArray z('c', {5}, sd::DataType::DOUBLE); - NDArray exp('c', { 5 }, { 2, 4, 6, 8, 10 }); + NDArray exp('c', {5}, {2, 4, 6, 8, 10}); - Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); - CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", sizeof(cudaStream_t)); - cudaError_t dZ = cudaStreamCreate(reinterpret_cast(&nativeStream)); - auto stream = reinterpret_cast(&nativeStream); + Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); + CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", + sizeof(cudaStream_t)); + cudaError_t dZ = + cudaStreamCreate(reinterpret_cast(&nativeStream)); + auto stream = reinterpret_cast(&nativeStream); - LaunchContext lc(stream, *stream, nullptr, nullptr); - NativeOpExecutioner::execPairwiseTransform(&lc, pairwise::Add, nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr); - auto res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); + LaunchContext lc(stream, *stream, nullptr, nullptr); + NativeOpExecutioner::execPairwiseTransform( + &lc, pairwise::Add, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr); + auto res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestAdd_3) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto z = NDArrayFactory::create('c', { 5 }, {10, 10, 10, 10, 10}); - - auto exp = NDArrayFactory::create('c', { 5 }, { 2, 4, 6, 8, 10 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - - Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); - CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", sizeof(cudaStream_t)); - cudaError_t dZ = cudaStreamCreate(reinterpret_cast(&nativeStream)); - auto stream = reinterpret_cast(&nativeStream); - - //cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), cudaMemcpyHostToDevice, *stream); - //cudaMemcpyAsync(devShapePtrX, x.shapeInfo(), shape::shapeInfoByteLength(x.shapeInfo()), cudaMemcpyHostToDevice, *stream); - - LaunchContext lc(stream, *stream, nullptr, nullptr); - NativeOpExecutioner::execPairwiseTransform(&lc, pairwise::Add, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr); - z.tickWriteDevice(); - auto res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); - //double* localBuffer = ; - z.syncToHost(); - cudaMemcpy(z.buffer(), z.specialBuffer(), z.lengthOf() * z.sizeOfT(), cudaMemcpyDeviceToHost); - res = cudaStreamSynchronize(*stream); - z.tickWriteHost(); - ASSERT_EQ(0, res); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < z.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto z = NDArrayFactory::create('c', {5}, {10, 10, 10, 10, 10}); + + auto exp = NDArrayFactory::create('c', {5}, {2, 4, 6, 8, 10}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); + + Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); + CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", + sizeof(cudaStream_t)); + cudaError_t dZ = + cudaStreamCreate(reinterpret_cast(&nativeStream)); + auto stream = reinterpret_cast(&nativeStream); + + // cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), + // cudaMemcpyHostToDevice, *stream); cudaMemcpyAsync(devShapePtrX, + // x.shapeInfo(), shape::shapeInfoByteLength(x.shapeInfo()), + // cudaMemcpyHostToDevice, *stream); + + LaunchContext lc(stream, *stream, nullptr, nullptr); + NativeOpExecutioner::execPairwiseTransform( + &lc, pairwise::Add, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), nullptr); + z.tickWriteDevice(); + auto res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + // double* localBuffer = ; + z.syncToHost(); + cudaMemcpy(z.buffer(), z.specialBuffer(), z.lengthOf() * z.sizeOfT(), + cudaMemcpyDeviceToHost); + res = cudaStreamSynchronize(*stream); + z.tickWriteHost(); + ASSERT_EQ(0, res); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < z.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestAdd_4) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 5 }, { 2, 4, 6, 8, 10 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Add, y, z); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < z.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto z = NDArrayFactory::create('c', {5}); + + auto exp = NDArrayFactory::create('c', {5}, {2, 4, 6, 8, 10}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); + x.applyPairwiseTransform(pairwise::Add, y, z); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < z.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestAdd_5) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - //auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 5 }, { 2, 4, 6, 8, 10 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - x += y; - //x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); - x.syncToHost(); - //y.printBuffer("3Y = "); - //z.printBuffer("3Result out"); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < x.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + // auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', {5}, {2, 4, 6, 8, 10}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); + x += y; + // x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); + x.syncToHost(); + // y.printBuffer("3Y = "); + // z.printBuffer("3Result out"); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < x.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestAdd_6) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create(2); //.'c', { 5 }, { 1, 2, 3, 4, 5}); - //auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 5 }, { 3, 4, 5, 6, 7 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - x += y; - //x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); - x.syncToHost(); - - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < x.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create(2); //.'c', { 5 }, { 1, 2, 3, 4, 5}); + // auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', {5}, {3, 4, 5, 6, 7}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); + x += y; + // x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); + x.syncToHost(); + + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < x.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestAdd_7) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - //auto y = NDArrayFactory::create(2); //.'c', { 5 }, { 1, 2, 3, 4, 5}); - //auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 5 }, { 3, 4, 5, 6, 7 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - x += 2.; - //x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); - x.syncToHost(); - - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < x.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + // auto y = NDArrayFactory::create(2); //.'c', { 5 }, { 1, 2, 3, 4, + // 5}); auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', {5}, {3, 4, 5, 6, 7}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); + x += 2.; + // x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); + x.syncToHost(); + + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < x.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestMultiply_1) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 5 }, { 1, 4, 9, 16, 25 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Multiply, y, z); - // x.printBuffer("3X = "); - // y.printBuffer("3Y = "); - // z.printBuffer("3Result out"); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < z.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto z = NDArrayFactory::create('c', {5}); + + auto exp = NDArrayFactory::create('c', {5}, {1, 4, 9, 16, 25}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); + x.applyPairwiseTransform(pairwise::Multiply, y, z); + // x.printBuffer("3X = "); + // y.printBuffer("3Y = "); + // z.printBuffer("3Result out"); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < z.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestMultiply_2) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - NDArray z('c', { 5 }, sd::DataType::DOUBLE); - - auto exp = NDArrayFactory::create('c', { 5 }, { 1, 4, 9, 16, 25 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Multiply, y, z); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < z.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - } + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + NDArray z('c', {5}, sd::DataType::DOUBLE); + + auto exp = NDArrayFactory::create('c', {5}, {1, 4, 9, 16, 25}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); + x.applyPairwiseTransform(pairwise::Multiply, y, z); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < z.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestMultiply_3) { - // allocating host-side arrays - NDArray x('c', { 5 }, { 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); - NDArray y('c', { 5 }, { 1., 2., 3., 4., 5.}, sd::DataType::DOUBLE); - auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 5 }, { 1, 4, 9, 16, 25 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - x.applyPairwiseTransform(pairwise::Multiply, y, z); - //x.printBuffer("23X = "); - //y.printBuffer("23Y = "); - // z.printBuffer("23Result out"); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < z.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - } + // allocating host-side arrays + NDArray x('c', {5}, {1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray y('c', {5}, {1., 2., 3., 4., 5.}, sd::DataType::DOUBLE); + auto z = NDArrayFactory::create('c', {5}); + + auto exp = NDArrayFactory::create('c', {5}, {1, 4, 9, 16, 25}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); + x.applyPairwiseTransform(pairwise::Multiply, y, z); + // x.printBuffer("23X = "); + // y.printBuffer("23Y = "); + // z.printBuffer("23Result out"); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < z.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestMultiply_4) { - // allocating host-side arrays - NDArray x('c', { 5 }, { 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); - NDArray y('c', { 5 }, { 1., 2., 3., 4., 5.}, sd::DataType::DOUBLE); - //auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 5 }, { 1, 4, 9, 16, 25 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - //x.printBuffer("23X = "); - //y.printBuffer("23Y = "); - x *= y; - //x.tickWriteDevice(); - // x.printBuffer("33Result out"); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - for (int e = 0; e < x.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); - } + // allocating host-side arrays + NDArray x('c', {5}, {1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray y('c', {5}, {1., 2., 3., 4., 5.}, sd::DataType::DOUBLE); + // auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', {5}, {1, 4, 9, 16, 25}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + // x.printBuffer("23X = "); + // y.printBuffer("23Y = "); + x *= y; + // x.tickWriteDevice(); + // x.printBuffer("33Result out"); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + for (int e = 0; e < x.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + } } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestPrimitiveNeg_01) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto exp = NDArrayFactory::create('c', { 5 }, { -1, -2, -3, -4, -5 }); + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto exp = NDArrayFactory::create('c', {5}, {-1, -2, -3, -4, -5}); - auto stream = x.getContext()->getCudaStream();//reinterpret_cast(&nativeStream); + auto stream = x.getContext()->getCudaStream(); // reinterpret_cast(&nativeStream); - NativeOpExecutioner::execTransformSame(x.getContext(), transform::Neg, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, nullptr, nullptr); - auto res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); - y.tickWriteDevice(); + NativeOpExecutioner::execTransformSame( + x.getContext(), transform::Neg, x.buffer(), x.shapeInfo(), + x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), + y.specialBuffer(), y.specialShapeInfo(), nullptr, nullptr, nullptr); + auto res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + y.tickWriteDevice(); - // x.printBuffer("X = "); - // y.printBuffer("Y = "); + // x.printBuffer("X = "); + // y.printBuffer("Y = "); - for (int e = 0; e < y.lengthOf(); e++) { - ASSERT_NEAR(exp.e(e), y.e(e), 1e-5); - } + for (int e = 0; e < y.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), y.e(e), 1e-5); + } } TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveNeg_2) { - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}); - - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); - - x.applyTransform(transform::Neg, y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); - - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); - //ASSERT_EQ(0, res); - // y.printBuffer("Negatives2"); - //delete x; - //delete y; -} - -TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveSqrt_1) { // strict - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}); - auto exp = NDArrayFactory::create({1.000000, 1.414214, 1.732051, 2.000000, 2.236068}); - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); - - x.applyTransform(transform::Sqrt, y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); - - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); - //ASSERT_EQ(0, res); - ASSERT_TRUE(y.equalsTo(exp)); - //y.printBuffer("SQRT output"); - //delete x; - //delete y; -} - -TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveAssign_1) { // strict - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}); - //auto exp = NDArrayFactory::create({1.000000, 1.414214, 1.732051, 2.000000, 2.236068}); - //ASSERT_TRUE(x.isActualOnDeviceSide()); - //ASSERT_TRUE(x.isActualOnHostSide()); - - x.applyTransform(transform::Assign, y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); - - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); - //ASSERT_EQ(0, res); - - // printf("Assigned to another array\n"); - // y.printBuffer("OUput"); - ASSERT_TRUE(y.equalsTo(x)); - //y.syncToHost(); - //y.printBuffer("IsMax output"); - //delete x; - //delete y; -} - -TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_1) { // strict - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}); - auto exp = NDArrayFactory::create('c', {5}, {0.540302, -0.416147, -0.989992, -0.653644, 0.283662}); - - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); - - x.applyTransform(transform::Cosine, y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); - - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); - //ASSERT_EQ(0, res); - ASSERT_TRUE(exp.isSameShape(y)); - ASSERT_TRUE(exp.dataType() == y.dataType()); - //y.printBuffer("Cosine2"); - //delete x; - //delete y; + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + + x.applyTransform(transform::Neg, y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); + + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + // ASSERT_EQ(0, res); + // y.printBuffer("Negatives2"); + // delete x; + // delete y; +} + +TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveSqrt_1) { // strict + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + auto exp = NDArrayFactory::create( + {1.000000, 1.414214, 1.732051, 2.000000, 2.236068}); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + + x.applyTransform(transform::Sqrt, y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); + + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + // ASSERT_EQ(0, res); + ASSERT_TRUE(y.equalsTo(exp)); + // y.printBuffer("SQRT output"); + // delete x; + // delete y; +} + +TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveAssign_1) { // strict + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + // auto exp = + // NDArrayFactory::create({1.000000, 1.414214, 1.732051, 2.000000, 2.236068}); + // ASSERT_TRUE(x.isActualOnDeviceSide()); + // ASSERT_TRUE(x.isActualOnHostSide()); + + x.applyTransform(transform::Assign, y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); + + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + // ASSERT_EQ(0, res); + + // printf("Assigned to another array\n"); + // y.printBuffer("OUput"); + ASSERT_TRUE(y.equalsTo(x)); + // y.syncToHost(); + // y.printBuffer("IsMax output"); + // delete x; + // delete y; +} + +TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_1) { // strict + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + auto exp = NDArrayFactory::create( + 'c', {5}, {0.540302, -0.416147, -0.989992, -0.653644, 0.283662}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + + x.applyTransform(transform::Cosine, y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); + + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + // ASSERT_EQ(0, res); + ASSERT_TRUE(exp.isSameShape(y)); + ASSERT_TRUE(exp.dataType() == y.dataType()); + // y.printBuffer("Cosine2"); + // delete x; + // delete y; } TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_2) { - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}); - auto exp = NDArrayFactory::create('c', {5}, {0.540302, -0.416147, -0.989992, -0.653644, 0.283662}); - - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); - x.applyTransform(transform::Cosine, y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); - - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); - //ASSERT_EQ(0, res); - //exp.syncToHost(); - //y.printBuffer("PrimitiveCosine2"); - //exp.printBuffer("Primitive Cosine exp"); - ASSERT_TRUE(exp.isSameShape(y)); - ASSERT_TRUE(exp.dataType() == y.dataType()); - //for (int e = 0; e < y.lengthOf(); e++) { - // ASSERT_NEAR(exp.e(e), y.e(e), 1e-5); - //} - - ASSERT_TRUE(exp.equalsTo(y)); - //delete x; - //delete y; + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + auto exp = NDArrayFactory::create( + 'c', {5}, {0.540302, -0.416147, -0.989992, -0.653644, 0.283662}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + x.applyTransform(transform::Cosine, y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); + + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + // ASSERT_EQ(0, res); + // exp.syncToHost(); + // y.printBuffer("PrimitiveCosine2"); + // exp.printBuffer("Primitive Cosine exp"); + ASSERT_TRUE(exp.isSameShape(y)); + ASSERT_TRUE(exp.dataType() == y.dataType()); + // for (int e = 0; e < y.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), y.e(e), 1e-5); + //} + + ASSERT_TRUE(exp.equalsTo(y)); + // delete x; + // delete y; } TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_3) { - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {5}); - auto exp = NDArrayFactory::create({0.540302, -0.416147, -0.989992, -0.653644, 0.283662}); - - ASSERT_TRUE(x.isActualOnDeviceSide()); - ASSERT_FALSE(x.isActualOnHostSide()); - x.applyTransform(transform::Cosine, y); - //ASSERT_TRUE(x->isActualOnDeviceSide()); - //ASSERT_FALSE(x->isActualOnHostSide()); - - //ASSERT_TRUE(y->isActualOnDeviceSide()); - //ASSERT_TRUE(y->isActualOnHostSide()); - //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); - //ASSERT_EQ(0, res); - //exp.syncToHost(); -// y.printBuffer("PrimitiveCosine3"); -// exp.printBuffer("Primitive Cosine3 exp"); -// y.printShapeInfo("Y shape"); -// exp.printShapeInfo("Exp Shape"); - ASSERT_TRUE(exp.isSameShape(y)); -// -// for (int e = 0; e < y.lengthOf(); e++) { -// printf("%lf == %lf\n", exp.e(e), y.e(e)); -//// ASSERT_NEAR(exp.e(e), y.e(e), 1e-5); -// } - - ASSERT_TRUE(exp.equalsTo(y)); - //delete x; - //delete y; + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + auto exp = NDArrayFactory::create( + {0.540302, -0.416147, -0.989992, -0.653644, 0.283662}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + x.applyTransform(transform::Cosine, y); + // ASSERT_TRUE(x->isActualOnDeviceSide()); + // ASSERT_FALSE(x->isActualOnHostSide()); + + // ASSERT_TRUE(y->isActualOnDeviceSide()); + // ASSERT_TRUE(y->isActualOnHostSide()); + // auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + // ASSERT_EQ(0, res); + // exp.syncToHost(); + // y.printBuffer("PrimitiveCosine3"); + // exp.printBuffer("Primitive Cosine3 exp"); + // y.printShapeInfo("Y shape"); + // exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(y)); + // + // for (int e = 0; e < y.lengthOf(); e++) { + // printf("%lf == %lf\n", exp.e(e), y.e(e)); + //// ASSERT_NEAR(exp.e(e), y.e(e), 1e-5); + // } + + ASSERT_TRUE(exp.equalsTo(y)); + // delete x; + // delete y; } TEST_F(NDArrayCudaBasicsTests, TestRawBroadcast_2) { - - //if (!Environment::getInstance()->isExperimentalBuild()) - // return; - - NDArray x = NDArrayFactory::create('c', {2,3,4}); - NDArray y('c', {2,4}, {10,20,30,40,50,60,70,80}, sd::DataType::DOUBLE); - NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); -// NDArray exp('c', {2,3,4}, {10., 21., 32., 43., 14., 25., 36., 47., 18., 29., 40., 51., 62., 73., 84., 95., 66., 77., 88., 99., 70., 81., 92., 103}, sd::DataType::DOUBLE); - NDArray exp('c', {2,3,4}, {10., 40., 90., 160., 50., 120., 210., 320., 90., 200., 330., 480., 650., 840., 1050., 1280., 850., 1080., 1330., 1600., 1050., 1320., 1610., 1920.}, sd::DataType::DOUBLE); - x.linspace(1); x.syncToDevice(); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t stream; - cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext lc(&stream); - - // allocate required amount of global device memory and copy host data to it - cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - - // call cuda kernel which calculates result - NativeOpExecutioner::execBroadcast(&lc, sd::broadcast::Multiply, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - z.tickWriteDevice(); - - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + // if (!Environment::getInstance()->isExperimentalBuild()) + // return; + + NDArray x = NDArrayFactory::create('c', {2, 3, 4}); + NDArray y('c', {2, 4}, {10, 20, 30, 40, 50, 60, 70, 80}, + sd::DataType::DOUBLE); + NDArray z('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::DOUBLE); + // NDArray exp('c', {2,3,4}, + // {10., 21., 32., 43., 14., 25., 36., 47., 18., 29., 40., 51., 62., 73., 84., + // 95., 66., 77., 88., 99., 70., 81., 92., 103}, sd::DataType::DOUBLE); + NDArray exp('c', {2, 3, 4}, + {10., 40., 90., 160., 50., 120., 210., 320., + 90., 200., 330., 480., 650., 840., 1050., 1280., + 850., 1080., 1330., 1600., 1050., 1320., 1610., 1920.}, + sd::DataType::DOUBLE); + x.linspace(1); + x.syncToDevice(); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcast( + &lc, sd::broadcast::Multiply, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); } TEST_F(NDArrayCudaBasicsTests, TestRawBroadcast_3) { - - //if (!Environment::getInstance()->isExperimentalBuild()) - // return; - - NDArray x('c', {2,3,4}, sd::DataType::DOUBLE); - NDArray y('c', {2,4}, {10,20,30,40,50,60,70,80}, sd::DataType::DOUBLE); - NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); -// NDArray exp('c', {2,3,4}, {10., 21., 32., 43., 14., 25., 36., 47., 18., 29., 40., 51., 62., 73., 84., 95., 66., 77., 88., 99., 70., 81., 92., 103}, sd::DataType::DOUBLE); - NDArray exp('c', {2,3,4}, {10., 40., 90., 160., 50., 120., 210., 320., 90., 200., 330., 480., 650., 840., 1050., 1280., 850., 1080., 1330., 1600., 1050., 1320., 1610., 1920.}, sd::DataType::DOUBLE); - x.linspace(1); x.syncToDevice(); - - std::vector dimensions = {0,2}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - //cudaStream_t stream; - //cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); - LaunchContext* pLc = x.getContext();//(&stream); - cudaStream_t* stream = pLc->getCudaStream(); - // allocate required amount of global device memory and copy host data to it -// cudaResult = allocateDeviceMem(*pLc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - for(int i = 0; i < devicePtrs.size(); ++i) { - - cudaResult = cudaMalloc(reinterpret_cast(&devicePtrs[i]), hostData[i].second); ASSERT_EQ(0, cudaResult); - cudaMemcpyAsync(devicePtrs[i], hostData[i].first, hostData[i].second, cudaMemcpyHostToDevice, *stream); - } - - NDArray::registerSpecialUse({&z}, {&x, &y}); - // call cuda kernel which calculates result - NativeOpExecutioner::execBroadcast(pLc, sd::broadcast::Multiply, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - //cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); - //z.syncToHost(); - // verify results - for (int e = 0; e < z.lengthOf(); e++) - ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - ASSERT_TRUE(exp.equalsTo(z)); - // delete cuda stream - //cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + // if (!Environment::getInstance()->isExperimentalBuild()) + // return; + + NDArray x('c', {2, 3, 4}, sd::DataType::DOUBLE); + NDArray y('c', {2, 4}, {10, 20, 30, 40, 50, 60, 70, 80}, + sd::DataType::DOUBLE); + NDArray z('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::DOUBLE); + // NDArray exp('c', {2,3,4}, + // {10., 21., 32., 43., 14., 25., 36., 47., 18., 29., 40., 51., 62., 73., 84., + // 95., 66., 77., 88., 99., 70., 81., 92., 103}, sd::DataType::DOUBLE); + NDArray exp('c', {2, 3, 4}, + {10., 40., 90., 160., 50., 120., 210., 320., + 90., 200., 330., 480., 650., 840., 1050., 1280., + 850., 1080., 1330., 1600., 1050., 1320., 1610., 1920.}, + sd::DataType::DOUBLE); + x.linspace(1); + x.syncToDevice(); + + std::vector dimensions = {0, 2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), + dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + // cudaStream_t stream; + // cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext* pLc = x.getContext(); //(&stream); + cudaStream_t* stream = pLc->getCudaStream(); + // allocate required amount of global device memory and copy host data to it + // cudaResult = allocateDeviceMem(*pLc, devicePtrs, hostData); + // ASSERT_EQ(0, cudaResult); + for (int i = 0; i < devicePtrs.size(); ++i) { + cudaResult = cudaMalloc(reinterpret_cast(&devicePtrs[i]), + hostData[i].second); + ASSERT_EQ(0, cudaResult); + cudaMemcpyAsync(devicePtrs[i], hostData[i].first, hostData[i].second, + cudaMemcpyHostToDevice, *stream); + } + + NDArray::registerSpecialUse({&z}, {&x, &y}); + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcast( + pLc, sd::broadcast::Multiply, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], nullptr, nullptr); + + // cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + // z.syncToHost(); + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + ASSERT_TRUE(exp.equalsTo(z)); + // delete cuda stream + // cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); } - TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_1) { - // allocating host-side arrays - NDArray x('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); - NDArray y = NDArrayFactory::create(3.); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); - //auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 2, 3 }, { 3, 6, 9, 12, 15, 18 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - x *= y; - //x.syncToHost(); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - ASSERT_TRUE(exp.equalsTo(x)); -// for (int e = 0; e < x.lengthOf(); e++) { -// ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); -// } + // allocating host-side arrays + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create( + 3.); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); + // auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', {2, 3}, {3, 6, 9, 12, 15, 18}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + x *= y; + // x.syncToHost(); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + ASSERT_TRUE(exp.equalsTo(x)); + // for (int e = 0; e < x.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + // } } TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_01) { - // allocating host-side arrays - NDArray x('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); - NDArray y = NDArrayFactory::create(3.); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); - auto z = NDArrayFactory::create('c', { 2, 3 }); - - auto exp = NDArrayFactory::create('c', { 2, 3 }, { 3, 6, 9, 12, 15, 18 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - //x.printBuffer("23X = "); - //y.printBuffer("23Y = "); - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);// *= y; - // z.printBuffer("53Result out"); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - ASSERT_TRUE(exp.equalsTo(z)); - -// for (int e = 0; e < x.lengthOf(); e++) { -// ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); -// } + // allocating host-side arrays + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create( + 3.); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); + auto z = NDArrayFactory::create('c', {2, 3}); + + auto exp = NDArrayFactory::create('c', {2, 3}, {3, 6, 9, 12, 15, 18}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + // x.printBuffer("23X = "); + // y.printBuffer("23Y = "); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); // *= y; + // z.printBuffer("53Result out"); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + ASSERT_TRUE(exp.equalsTo(z)); + + // for (int e = 0; e < x.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + // } } TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_02) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}); //, sd::DataType::DOUBLE); - auto y = NDArrayFactory::create('c', {2,3}, {3, 3, 3, 3, 3, 3}); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); - auto z = NDArrayFactory::create('c', { 2, 3 }); - - auto exp = NDArrayFactory::create('c', { 2, 3 }, { 3, 6, 9, 12, 15, 18 }); - //if (x.isActualOnHostSide() && !x.isActualOnDeviceSide()) - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - //x.printBuffer("23X = "); - //y.printBuffer("23Y = "); - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);// *= y; - - // z.printBuffer("52Result out"); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - ASSERT_TRUE(exp.equalsTo(z)); - -// for (int e = 0; e < x.lengthOf(); e++) { -// ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); -// } + // allocating host-side arrays + auto x = NDArrayFactory::create( + 'c', {2, 3}, {1, 2, 3, 4, 5, 6}); //, sd::DataType::DOUBLE); + auto y = NDArrayFactory::create( + 'c', {2, 3}, + {3, 3, 3, 3, 3, 3}); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); + auto z = NDArrayFactory::create('c', {2, 3}); + + auto exp = NDArrayFactory::create('c', {2, 3}, {3, 6, 9, 12, 15, 18}); + // if (x.isActualOnHostSide() && !x.isActualOnDeviceSide()) + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + // x.printBuffer("23X = "); + // y.printBuffer("23Y = "); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); // *= y; + + // z.printBuffer("52Result out"); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + ASSERT_TRUE(exp.equalsTo(z)); + + // for (int e = 0; e < x.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + // } } TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_002) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}); //, sd::DataType::DOUBLE); - auto y = NDArrayFactory::create('c', {2, 3}, {2., 3., 3., 3., 3., 3.}); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); - auto z = NDArrayFactory::create('c', { 2, 3 }); - - auto exp = NDArrayFactory::create('c', { 2, 3 }, { 2, 6, 9, 12, 15, 18 }); - //if (x.isActualOnHostSide() && !x.isActualOnDeviceSide()) - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - //x.printBuffer("23X = "); - //y.printBuffer("23Y = "); - x.applyPairwiseTransform(pairwise::Multiply, y, z);// *= y; - - // z.printBuffer("51Result out"); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - ASSERT_TRUE(exp.equalsTo(z)); - -// for (int e = 0; e < x.lengthOf(); e++) { -// ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); -// } + // allocating host-side arrays + auto x = NDArrayFactory::create( + 'c', {2, 3}, {1, 2, 3, 4, 5, 6}); //, sd::DataType::DOUBLE); + auto y = NDArrayFactory::create( + 'c', {2, 3}, + {2., 3., 3., 3., 3., + 3.}); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); + auto z = NDArrayFactory::create('c', {2, 3}); + + auto exp = NDArrayFactory::create('c', {2, 3}, {2, 6, 9, 12, 15, 18}); + // if (x.isActualOnHostSide() && !x.isActualOnDeviceSide()) + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + // x.printBuffer("23X = "); + // y.printBuffer("23Y = "); + x.applyPairwiseTransform(pairwise::Multiply, y, z); // *= y; + + // z.printBuffer("51Result out"); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + ASSERT_TRUE(exp.equalsTo(z)); + + // for (int e = 0; e < x.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + // } } //////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestBroadcastRaw_1) { - - //if (!Environment::getInstance()->isExperimentalBuild()) - // return; - - NDArray x('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); - NDArray y('c', {3}, {10, 20, 30}, sd::DataType::INT64); - NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); - NDArray exp('c', {2,3,4}, {10, 11, 12, 13,24, 25, 26, 27,38, 39, 40, 41,22, 23, 24, 25,36, 37, 38, 39,50, 51, 52, 53}, sd::DataType::INT32); - //real output [10, 11, 12, 13, 4, 5, 6, 7, 28, 29, 30, 31, 22, 23, 24, 25, 16, 17, 18, 19, 40, 41, 42, 43] - x.linspace(0); x.syncToDevice(); - - std::vector dimensions = {1}; - - // evaluate xTad data - shape::TAD xTad; - xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); - - // prepare input arrays for prepareDataForCuda function - std::vector> hostData; - hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(Nd4jLong)); // 0 -- dimensions - hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo - hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets - std::vector devicePtrs(hostData.size(), nullptr); - - // create cuda stream and LaunchContext - cudaError_t cudaResult; - cudaStream_t* stream = x.getContext()->getCudaStream(); - LaunchContext* pLc = x.getContext(); - - // allocate required amount of global device memory and copy host data to it - //cudaResult = allocateDeviceMem(*pLc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); - for(size_t i = 0; i < devicePtrs.size(); ++i) { - cudaResult = cudaMalloc(&devicePtrs[i], hostData[i].second); //if(cudaResult != 0) return cudaResult; - ASSERT_EQ(cudaResult, 0); - cudaMemcpy(devicePtrs[i], hostData[i].first, hostData[i].second, cudaMemcpyHostToDevice); - } - - // call cuda kernel which calculates result - NativeOpExecutioner::execBroadcast(pLc, sd::broadcast::Add, - nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), - nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - (int*)devicePtrs[0], dimensions.size(), - (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], - nullptr, nullptr); - - cudaResult = cudaStreamSynchronize(*stream); ASSERT_EQ(0, cudaResult); - - // x.printIndexedBuffer(" X"); - // y.printIndexedBuffer("+Y"); - // z.printBuffer("ADD broadcasted output"); - // verify results - // for (int e = 0; e < z.lengthOf(); e++) - // ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); - - // free allocated global device memory - for(int i = 0; i < devicePtrs.size(); ++i) - cudaFree(devicePtrs[i]); - - // delete cuda stream - //cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); + // if (!Environment::getInstance()->isExperimentalBuild()) + // return; + + NDArray x('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT32); + NDArray y('c', {3}, {10, 20, 30}, sd::DataType::INT64); + NDArray z('c', {2, 3, 4}, + {100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100}, + sd::DataType::INT32); + NDArray exp('c', {2, 3, 4}, {10, 11, 12, 13, 24, 25, 26, 27, 38, 39, 40, 41, + 22, 23, 24, 25, 36, 37, 38, 39, 50, 51, 52, 53}, + sd::DataType::INT32); + // real output [10, 11, 12, 13, 4, 5, 6, 7, 28, 29, 30, 31, 22, 23, 24, 25, + // 16, 17, 18, 19, 40, 41, 42, 43] + x.linspace(0); + x.syncToDevice(); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back( + dimensions.data(), + dimensions.size() * sizeof(Nd4jLong)); // 0 -- dimensions + hostData.emplace_back( + xTad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, + xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t* stream = x.getContext()->getCudaStream(); + LaunchContext* pLc = x.getContext(); + + // allocate required amount of global device memory and copy host data to it + // cudaResult = allocateDeviceMem(*pLc, devicePtrs, hostData); ASSERT_EQ(0, + // cudaResult); + for (size_t i = 0; i < devicePtrs.size(); ++i) { + cudaResult = cudaMalloc( + &devicePtrs[i], + hostData[i].second); // if(cudaResult != 0) return cudaResult; + ASSERT_EQ(cudaResult, 0); + cudaMemcpy(devicePtrs[i], hostData[i].first, hostData[i].second, + cudaMemcpyHostToDevice); + } + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcast( + pLc, sd::broadcast::Add, nullptr, x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), + y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, cudaResult); + + // x.printIndexedBuffer(" X"); + // y.printIndexedBuffer("+Y"); + // z.printBuffer("ADD broadcasted output"); + // verify results + // for (int e = 0; e < z.lengthOf(); e++) + // ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for (int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + // cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); } TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply) { - // allocating host-side arrays - NDArray x('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); - NDArray y('c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); - //auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 2, 3 }, { 2, 6, 12, 8, 15, 24 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - //x.printBuffer("23X = "); - //y.printBuffer("23Y = "); - x *= y; - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - //for (int e = 0; e < x.lengthOf(); e++) { - // ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); - //} + // allocating host-side arrays + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray y('c', {3}, {2., 3., 4.}, sd::DataType::DOUBLE); + // auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', {2, 3}, {2, 6, 12, 8, 15, 24}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + // x.printBuffer("23X = "); + // y.printBuffer("23Y = "); + x *= y; + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + // for (int e = 0; e < x.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + //} } - TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_2) { - // allocating host-side arrays - NDArray x('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); - NDArray y('c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); - //auto z = NDArrayFactory::create('c', { 5 }); - - auto exp = NDArrayFactory::create('c', { 2, 3 }, { 11,12, 13,14, 15, 16 }); - auto expZ = NDArrayFactory::create('c', { 2, 3 }, { 2, 6, 12, 8, 15, 24 }); - - // making raw buffers - //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; - //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); - //ASSERT_EQ(0, res); - //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); - //ASSERT_EQ(0, res); - //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - //x.printBuffer("23X = "); - //y.printBuffer("23Y = "); - //void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) - x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, exp); - - // - // cudaFree(devBufferPtrX); - //cudaFree(devBufferPtrZ); - //cudaFree(devShapePtrX); - - //for (int e = 0; e < x.lengthOf(); e++) { - // ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); - //} - ASSERT_TRUE(exp.equalsTo(expZ)); - + // allocating host-side arrays + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray y('c', {3}, {2., 3., 4.}, sd::DataType::DOUBLE); + // auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = + NDArrayFactory::create('c', {2, 3}, {11, 12, 13, 14, 15, 16}); + auto expZ = + NDArrayFactory::create('c', {2, 3}, {2, 6, 12, 8, 15, 24}); + + // making raw buffers + // Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + // cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), + // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = + // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * + // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, + // res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + // x.printBuffer("23X = "); + // y.printBuffer("23Y = "); + // void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray* + // other, NDArray* target, const bool checkTargetShape, ExtraArguments + // *extraArgs) + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, exp); + + // + // cudaFree(devBufferPtrX); + // cudaFree(devBufferPtrZ); + // cudaFree(devShapePtrX); + + // for (int e = 0; e < x.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + //} + ASSERT_TRUE(exp.equalsTo(expZ)); } - ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestReduceSum_1) { - // allocating host-side arrays - auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create(15); - auto exp = NDArrayFactory::create(15); + // allocating host-side arrays + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create(15); + auto exp = NDArrayFactory::create(15); - auto stream = x.getContext()->getCudaStream();//reinterpret_cast(&nativeStream); + auto stream = x.getContext()->getCudaStream(); // reinterpret_cast(&nativeStream); - NativeOpExecutioner::execReduceSameScalar(x.getContext(), reduce::Sum, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo()); - auto res = cudaStreamSynchronize(*stream); - ASSERT_EQ(0, res); - y.syncToHost(); + NativeOpExecutioner::execReduceSameScalar( + x.getContext(), reduce::Sum, x.buffer(), x.shapeInfo(), x.specialBuffer(), + x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), + y.specialBuffer(), y.specialShapeInfo()); + auto res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + y.syncToHost(); - ASSERT_NEAR(y.e(0), 15, 1e-5); + ASSERT_NEAR(y.e(0), 15, 1e-5); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestDup1) { + NDArray array('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto arrC = array.dup('c'); + auto arrF = array.dup('f'); + // arrC->printBuffer("arrC"); - NDArray array('c', {2,3}, {1,2,3,4,5,6}); - auto arrC = array.dup('c'); - auto arrF = array.dup('f'); - // arrC->printBuffer("arrC"); + // arrF->printBuffer("arrF"); + // arrC->printShapeInfo("C shape"); + // arrF->printShapeInfo("F shape"); - // arrF->printBuffer("arrF"); - //arrC->printShapeInfo("C shape"); - //arrF->printShapeInfo("F shape"); + ASSERT_TRUE(array.equalsTo(arrF)); + ASSERT_TRUE(array.equalsTo(arrC)); - ASSERT_TRUE(array.equalsTo(arrF)); - ASSERT_TRUE(array.equalsTo(arrC)); - - ASSERT_TRUE(arrF.equalsTo(arrC)); + ASSERT_TRUE(arrF.equalsTo(arrC)); } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, equalsTo_1) { + NDArray x('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sd::DataType::DOUBLE); + NDArray y('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sd::DataType::DOUBLE); - NDArray x('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); - NDArray y('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); - - ASSERT_TRUE(x.equalsTo(y)); + ASSERT_TRUE(x.equalsTo(y)); - x.permutei({1,0}); - y.permutei({1,0}); + x.permutei({1, 0}); + y.permutei({1, 0}); - ASSERT_TRUE(x.equalsTo(y)); + ASSERT_TRUE(x.equalsTo(y)); } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, equalsTo_2) { + NDArray x('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 10, 10}, + sd::DataType::DOUBLE); + NDArray y('c', {2, 5}, {1, 2, 5, 4, 5, 6, 7, 8, 9, 10}, sd::DataType::DOUBLE); - NDArray x('c', {2,5}, {1,2,3,4,5,6,7,8,10,10}, sd::DataType::DOUBLE); - NDArray y('c', {2,5}, {1,2,5,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); - - ASSERT_FALSE(x.equalsTo(y)); + ASSERT_FALSE(x.equalsTo(y)); - x.permutei({1,0}); - y.permutei({1,0}); + x.permutei({1, 0}); + y.permutei({1, 0}); - ASSERT_FALSE(x.equalsTo(y)); + ASSERT_FALSE(x.equalsTo(y)); } ////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, equalsTo_3) { + NDArray x('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sd::DataType::DOUBLE); + NDArray y('c', {2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}, + sd::DataType::FLOAT32); - NDArray x('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); - NDArray y('c', {2,5}, {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f}, sd::DataType::FLOAT32); - - ASSERT_FALSE(x.equalsTo(y)); + ASSERT_FALSE(x.equalsTo(y)); - x.permutei({1,0}); - y.permutei({1,0}); + x.permutei({1, 0}); + y.permutei({1, 0}); - ASSERT_FALSE(x.equalsTo(y)); + ASSERT_FALSE(x.equalsTo(y)); } //////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, applyReduce3_1) { + NDArray x('c', {2, 3, 4}, {-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}, + sd::DataType::INT32); + NDArray x2('c', {2, 3, 4}, {-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}, + sd::DataType::INT32); + NDArray y('c', {2, 3, 4}, {-2, 3, -4, 5, -2, 3, -4, 5, -2, 3, -4, 5, + -2, 3, -4, 5, -2, 3, -4, 5, -2, 3, -4, 5}, + sd::DataType::INT32); + NDArray k('c', {2, 3}, {-2, 3, -4, 5, -2, 3}, sd::DataType::INT32); + NDArray k2('c', {3, 2}, {-2, 3, -4, 5, -2, 3}, sd::DataType::INT32); - NDArray x('c', {2,3,4}, {-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13}, sd::DataType::INT32); - NDArray x2('c', {2,3,4}, {-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13}, sd::DataType::INT32); - NDArray y('c', {2,3,4}, {-2,3,-4,5,-2,3,-4,5,-2,3,-4,5,-2,3,-4,5,-2,3,-4,5,-2,3,-4,5}, sd::DataType::INT32); - NDArray k('c', {2,3}, {-2,3,-4,5,-2,3}, sd::DataType::INT32); - NDArray k2('c', {3,2}, {-2,3,-4,5,-2,3}, sd::DataType::INT32); + NDArray exp1('c', {3}, {4.f, 20.f, 36.f}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 3}, {-10.f, -2.f, 6.f, 14.f, 22.f, 30.f}, + sd::DataType::FLOAT32); + NDArray exp3('c', {4}, {38.f, 41.f, 44.f, 47.f}, sd::DataType::FLOAT32); + NDArray exp4('c', {4}, {114.f, 117.f, 120.f, 123.f}, sd::DataType::FLOAT32); - NDArray exp1('c', {3}, {4.f, 20.f, 36.f}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,3}, {-10.f, -2.f, 6.f,14.f, 22.f, 30.f}, sd::DataType::FLOAT32); - NDArray exp3('c', {4}, {38.f, 41.f, 44.f, 47.f}, sd::DataType::FLOAT32); - NDArray exp4('c', {4}, {114.f, 117.f, 120.f, 123.f}, sd::DataType::FLOAT32); + NDArray z = x.applyReduce3(sd::reduce3::Dot, y, {0, 2}); + ASSERT_TRUE(z.equalsTo(&exp1)); + z = x.applyReduce3(sd::reduce3::Dot, k, {0, 1}); + ASSERT_TRUE(z.equalsTo(&exp3)); - NDArray z = x.applyReduce3(sd::reduce3::Dot, y, {0,2}); - ASSERT_TRUE(z.equalsTo(&exp1)); + x.permutei({0, 2, 1}); + y.permutei({0, 2, 1}); - z = x.applyReduce3(sd::reduce3::Dot, k, {0,1}); - ASSERT_TRUE(z.equalsTo(&exp3)); + z = y.applyReduce3(sd::reduce3::Dot, x, {1}); + ASSERT_TRUE(z.equalsTo(&exp2)); - x.permutei({0,2,1}); - y.permutei({0,2,1}); + x2.permutei({1, 0, 2}); - z = y.applyReduce3(sd::reduce3::Dot, x, {1}); - ASSERT_TRUE(z.equalsTo(&exp2)); - - x2.permutei({1,0,2}); - - z = x2.applyReduce3(sd::reduce3::Dot, k2, {0,1}); - ASSERT_TRUE(z.equalsTo(&exp4)); + z = x2.applyReduce3(sd::reduce3::Dot, k2, {0, 1}); + ASSERT_TRUE(z.equalsTo(&exp4)); } //////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, applyReduce3_2) { + NDArray x('c', {2, 3, 4}, {-10, -9, -8.5, -7, -6, -5, -4, -3, -2, -1, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}, + sd::DataType::DOUBLE); + NDArray x2('c', {2, 3, 4}, {-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0.5, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}, + sd::DataType::DOUBLE); + NDArray y('c', {2, 3, 4}, {-2, 3, -4, 5, -2, 3, -4, 5, -2, 3, -4, 5, + -2.5, 3, -4, 5, -2, 3, -4, 5, -2, 3, -4, 5}, + sd::DataType::DOUBLE); + NDArray k('c', {2, 3}, {-2, 3, -4, 5.5, -2, 3}, sd::DataType::DOUBLE); + NDArray k2('c', {3, 2}, {-2, 3, -4, 5, -2, 3.5}, sd::DataType::DOUBLE); - NDArray x('c', {2,3,4}, {-10,-9,-8.5,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13}, sd::DataType::DOUBLE); - NDArray x2('c', {2,3,4}, {-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0.5,1,2,3,4,5,6,7,8,9,10,11,12,13}, sd::DataType::DOUBLE); - NDArray y('c', {2,3,4}, {-2,3,-4,5,-2,3,-4,5,-2,3,-4,5,-2.5,3,-4,5,-2,3,-4,5,-2,3,-4,5}, sd::DataType::DOUBLE); - NDArray k('c', {2,3}, {-2,3,-4,5.5,-2,3}, sd::DataType::DOUBLE); - NDArray k2('c', {3,2}, {-2,3,-4,5,-2,3.5}, sd::DataType::DOUBLE); - - NDArray exp1('c', {3}, {5., 20., 36.}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,3}, {-8., -2., 6., 13., 22., 30.}, sd::DataType::DOUBLE); - NDArray exp3('c', {4}, {39., 42.5, 47., 49.5}, sd::DataType::DOUBLE); - NDArray exp4('c', {4}, {119., 122.5, 125., 129.5}, sd::DataType::DOUBLE); + NDArray exp1('c', {3}, {5., 20., 36.}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 3}, {-8., -2., 6., 13., 22., 30.}, + sd::DataType::DOUBLE); + NDArray exp3('c', {4}, {39., 42.5, 47., 49.5}, sd::DataType::DOUBLE); + NDArray exp4('c', {4}, {119., 122.5, 125., 129.5}, sd::DataType::DOUBLE); - NDArray z = x.applyReduce3(sd::reduce3::Dot, y, {0,2}); - ASSERT_TRUE(z.equalsTo(&exp1)); + NDArray z = x.applyReduce3(sd::reduce3::Dot, y, {0, 2}); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x.applyReduce3(sd::reduce3::Dot, k, {0,1}); - ASSERT_TRUE(z.equalsTo(&exp3)); + z = x.applyReduce3(sd::reduce3::Dot, k, {0, 1}); + ASSERT_TRUE(z.equalsTo(&exp3)); - x.permutei({0,2,1}); - y.permutei({0,2,1}); + x.permutei({0, 2, 1}); + y.permutei({0, 2, 1}); - z = y.applyReduce3(sd::reduce3::Dot, x, {1}); - ASSERT_TRUE(z.equalsTo(&exp2)); + z = y.applyReduce3(sd::reduce3::Dot, x, {1}); + ASSERT_TRUE(z.equalsTo(&exp2)); - x2.permutei({1,0,2}); + x2.permutei({1, 0, 2}); - z = x2.applyReduce3(sd::reduce3::Dot, k2, {0,1}); - ASSERT_TRUE(z.equalsTo(&exp4)); + z = x2.applyReduce3(sd::reduce3::Dot, k2, {0, 1}); + ASSERT_TRUE(z.equalsTo(&exp4)); } //////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, applyReduce3_3) { + NDArray x1('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}, sd::DataType::INT32); + NDArray x2('c', {2, 2, 2}, {-1, -2, -3, -4, -5, -6, -7, -8}, + sd::DataType::INT32); + NDArray x3('c', {3, 2}, {1.5, 1.5, 1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {3, 2}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); - NDArray x1('c', {2,2,2}, {1,2,3,4,5,6,7,8}, sd::DataType::INT32); - NDArray x2('c', {2,2,2}, {-1,-2,-3,-4,-5,-6,-7,-8}, sd::DataType::INT32); - NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray x4('c', {3,2}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); - - NDArray exp1('c', {}, std::vector{-204}, sd::DataType::FLOAT32); - NDArray exp2('c', {}, std::vector{31.5}, sd::DataType::DOUBLE); - + NDArray exp1('c', {}, std::vector{-204}, sd::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{31.5}, sd::DataType::DOUBLE); - auto z = x1.applyReduce3(reduce3::Dot, x2); - ASSERT_TRUE(z.equalsTo(&exp1)); + auto z = x1.applyReduce3(reduce3::Dot, x2); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x3.applyReduce3(reduce3::Dot, x4); - ASSERT_TRUE(z.equalsTo(&exp2)); + z = x3.applyReduce3(reduce3::Dot, x4); + ASSERT_TRUE(z.equalsTo(&exp2)); - x1.permutei({2,1,0}); - x2.permutei({2,1,0}); - x3.permutei({1,0}); - x4.permutei({1,0}); + x1.permutei({2, 1, 0}); + x2.permutei({2, 1, 0}); + x3.permutei({1, 0}); + x4.permutei({1, 0}); - z = x1.applyReduce3(reduce3::Dot, x2); - ASSERT_TRUE(z.equalsTo(&exp1)); + z = x1.applyReduce3(reduce3::Dot, x2); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x3.applyReduce3(reduce3::Dot, x4); - ASSERT_TRUE(z.equalsTo(&exp2)); + z = x3.applyReduce3(reduce3::Dot, x4); + ASSERT_TRUE(z.equalsTo(&exp2)); } //////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, applyAllReduce3_1) { - - NDArray x1('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, sd::DataType::INT32); - NDArray x2('c', {2,2,2}, {-1,-2,-3,-4,-5,-6,-7,-8}, sd::DataType::INT32); - NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); - NDArray x4('c', {3,2}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); - - NDArray exp1('c', {3,2}, {-88.f, -124.f, 6.f, -2.f, 22.f, 14.f}, sd::DataType::FLOAT32); - NDArray exp2('c', {6,4}, {-36.f, -44.f, -52.f, -60.f,-42.f, -52.f, -62.f, -72.f, 2.f, 0.f, -2.f, - -4.f, 6.f, 4.f, 2.f, 0.f, 10.f, 8.f, 6.f, 4.f, 14.f, 12.f, 10.f, 8.f}, - sd::DataType::FLOAT32); - NDArray exp3('c', {1,1}, std::vector{31.5}, sd::DataType::DOUBLE); - NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, sd::DataType::DOUBLE); - - auto z = x1.applyAllReduce3(reduce3::Dot, x2, {0,2}); - ASSERT_TRUE(z.equalsTo(&exp1)); - - z = x1.applyAllReduce3(reduce3::Dot, x2, {0}); - ASSERT_TRUE(z.equalsTo(&exp2)); - - z = x3.applyAllReduce3(reduce3::Dot, x4, {0,1}); - ASSERT_TRUE(z.equalsTo(&exp3)); - - z = x3.applyAllReduce3(reduce3::Dot, x4, {1}); - ASSERT_TRUE(z.equalsTo(&exp4)); - - x1.permutei({2,1,0}); - x2.permutei({2,1,0}); - x3.permutei({1,0}); - x4.permutei({1,0}); - - z = x1.applyAllReduce3(reduce3::Dot, x2, {0,2}); - ASSERT_TRUE(z.equalsTo(&exp1)); - - z = x3.applyAllReduce3(reduce3::Dot, x4, {0}); - ASSERT_TRUE(z.equalsTo(&exp4)); + NDArray x1('c', {2, 3, 2}, + { + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + -1, + -2, + -3, + -4, + }, + sd::DataType::INT32); + NDArray x2('c', {2, 2, 2}, {-1, -2, -3, -4, -5, -6, -7, -8}, + sd::DataType::INT32); + NDArray x3('c', {3, 2}, {1.5, 1.5, 1.5, 1.5, 1.5, 1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {3, 2}, {1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + + NDArray exp1('c', {3, 2}, {-88.f, -124.f, 6.f, -2.f, 22.f, 14.f}, + sd::DataType::FLOAT32); + NDArray exp2('c', {6, 4}, + {-36.f, -44.f, -52.f, -60.f, -42.f, -52.f, -62.f, -72.f, + 2.f, 0.f, -2.f, -4.f, 6.f, 4.f, 2.f, 0.f, + 10.f, 8.f, 6.f, 4.f, 14.f, 12.f, 10.f, 8.f}, + sd::DataType::FLOAT32); + NDArray exp3('c', {1, 1}, std::vector{31.5}, sd::DataType::DOUBLE); + NDArray exp4('c', {3, 3}, {4.5, 10.5, 16.5, 4.5, 10.5, 16.5, 4.5, 10.5, 16.5}, + sd::DataType::DOUBLE); + + auto z = x1.applyAllReduce3(reduce3::Dot, x2, {0, 2}); + ASSERT_TRUE(z.equalsTo(&exp1)); + + z = x1.applyAllReduce3(reduce3::Dot, x2, {0}); + ASSERT_TRUE(z.equalsTo(&exp2)); + + z = x3.applyAllReduce3(reduce3::Dot, x4, {0, 1}); + ASSERT_TRUE(z.equalsTo(&exp3)); + + z = x3.applyAllReduce3(reduce3::Dot, x4, {1}); + ASSERT_TRUE(z.equalsTo(&exp4)); + + x1.permutei({2, 1, 0}); + x2.permutei({2, 1, 0}); + x3.permutei({1, 0}); + x4.permutei({1, 0}); + + z = x1.applyAllReduce3(reduce3::Dot, x2, {0, 2}); + ASSERT_TRUE(z.equalsTo(&exp1)); + + z = x3.applyAllReduce3(reduce3::Dot, x4, {0}); + ASSERT_TRUE(z.equalsTo(&exp4)); } ////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test1) { + NDArray x('c', {2, 3}, {0, 10, 1, 2, 2.5, -4}, sd::DataType::DOUBLE); - NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, sd::DataType::DOUBLE); + NDArray scalar('c', {}, std::vector{100}, sd::DataType::INT64); + NDArray vec1('c', {2}, {100, 100}, sd::DataType::INT64); + NDArray vec2('c', {3}, {100, 100, 100}, sd::DataType::INT64); - NDArray scalar('c', {}, std::vector{100}, sd::DataType::INT64); - NDArray vec1('c', {2}, {100,100}, sd::DataType::INT64); - NDArray vec2('c', {3}, {100,100,100}, sd::DataType::INT64); + NDArray exp1('c', {}, std::vector{1}, sd::DataType::INT64); + NDArray exp2('c', {2}, {1, 1}, sd::DataType::INT64); + NDArray exp3('c', {3}, {1, 0, 0}, sd::DataType::INT64); - NDArray exp1('c', {}, std::vector{1}, sd::DataType::INT64); - NDArray exp2('c', {2}, {1,1}, sd::DataType::INT64); - NDArray exp3('c', {3}, {1,0,0}, sd::DataType::INT64); + NDArray exp4('c', {}, std::vector{2}, sd::DataType::INT64); + NDArray exp5('c', {2}, {1, 1}, sd::DataType::INT64); + NDArray exp6('c', {3}, {1, 0, 0}, sd::DataType::INT64); - NDArray exp4('c', {}, std::vector{2}, sd::DataType::INT64); - NDArray exp5('c', {2}, {1,1}, sd::DataType::INT64); - NDArray exp6('c', {3}, {1,0,0}, sd::DataType::INT64); + x.applyIndexReduce(sd::indexreduce::IndexMax, scalar, {0, 1}); + ASSERT_TRUE(scalar.equalsTo(&exp1)); - x.applyIndexReduce(sd::indexreduce::IndexMax, scalar, {0,1}); - ASSERT_TRUE(scalar.equalsTo(&exp1)); + x.applyIndexReduce(sd::indexreduce::IndexMax, vec1, {1}); + ASSERT_TRUE(vec1.equalsTo(&exp2)); - x.applyIndexReduce(sd::indexreduce::IndexMax, vec1, {1}); - ASSERT_TRUE(vec1.equalsTo(&exp2)); + x.applyIndexReduce(sd::indexreduce::IndexMax, vec2, {0}); + ASSERT_TRUE(vec2.equalsTo(&exp3)); - x.applyIndexReduce(sd::indexreduce::IndexMax, vec2, {0}); - ASSERT_TRUE(vec2.equalsTo(&exp3)); + x.permutei({1, 0}); - x.permutei({1,0}); + x.applyIndexReduce(sd::indexreduce::IndexMax, scalar, {0, 1}); + ASSERT_TRUE(scalar.equalsTo(&exp4)); - x.applyIndexReduce(sd::indexreduce::IndexMax, scalar, {0,1}); - ASSERT_TRUE(scalar.equalsTo(&exp4)); + x.applyIndexReduce(sd::indexreduce::IndexMax, vec1, {0}); + ASSERT_TRUE(vec1.equalsTo(&exp5)); - x.applyIndexReduce(sd::indexreduce::IndexMax, vec1, {0}); - ASSERT_TRUE(vec1.equalsTo(&exp5)); - - x.applyIndexReduce(sd::indexreduce::IndexMax, vec2, {1}); - ASSERT_TRUE(vec2.equalsTo(&exp6)); + x.applyIndexReduce(sd::indexreduce::IndexMax, vec2, {1}); + ASSERT_TRUE(vec2.equalsTo(&exp6)); } - ////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test2) { + NDArray x('c', {2, 3}, {0, 10, 1, 2, 2.5, -4}, sd::DataType::DOUBLE); - NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, sd::DataType::DOUBLE); - - NDArray exp1('c', {}, std::vector{1}, sd::DataType::INT64); - NDArray exp2('c', {2}, {1,1}, sd::DataType::INT64); - NDArray exp3('c', {3}, {1,0,0}, sd::DataType::INT64); + NDArray exp1('c', {}, std::vector{1}, sd::DataType::INT64); + NDArray exp2('c', {2}, {1, 1}, sd::DataType::INT64); + NDArray exp3('c', {3}, {1, 0, 0}, sd::DataType::INT64); - NDArray exp4('c', {}, std::vector{2}, sd::DataType::INT64); - NDArray exp5('c', {2}, {1,1}, sd::DataType::INT64); - NDArray exp6('c', {3}, {1,0,0}, sd::DataType::INT64); + NDArray exp4('c', {}, std::vector{2}, sd::DataType::INT64); + NDArray exp5('c', {2}, {1, 1}, sd::DataType::INT64); + NDArray exp6('c', {3}, {1, 0, 0}, sd::DataType::INT64); - auto z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0,1}); - ASSERT_TRUE(z.equalsTo(&exp1)); + auto z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0, 1}); + ASSERT_TRUE(z.equalsTo(&exp1)); - z = x.applyIndexReduce(sd::indexreduce::IndexMax, {1}); - ASSERT_TRUE(z.equalsTo(&exp2)); + z = x.applyIndexReduce(sd::indexreduce::IndexMax, {1}); + ASSERT_TRUE(z.equalsTo(&exp2)); - z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0}); - ASSERT_TRUE(z.equalsTo(&exp3)); + z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0}); + ASSERT_TRUE(z.equalsTo(&exp3)); - x.permutei({1,0}); + x.permutei({1, 0}); - z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0,1}); - ASSERT_TRUE(z.equalsTo(&exp4)); + z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0, 1}); + ASSERT_TRUE(z.equalsTo(&exp4)); - z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0}); - ASSERT_TRUE(z.equalsTo(&exp5)); + z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0}); + ASSERT_TRUE(z.equalsTo(&exp5)); - z = x.applyIndexReduce(sd::indexreduce::IndexMax, {1}); - ASSERT_TRUE(z.equalsTo(&exp6)); + z = x.applyIndexReduce(sd::indexreduce::IndexMax, {1}); + ASSERT_TRUE(z.equalsTo(&exp6)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test1) { - - NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, sd::DataType::INT32); - - NDArray z1('c', {}, std::vector{100}, sd::DataType::DOUBLE); - NDArray z2('c', {2,2}, {100,100,100,100}, sd::DataType::FLOAT32); - NDArray z3('c', {3}, {100,100,100}, sd::DataType::DOUBLE); - NDArray z4('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); - NDArray z5('c', {2}, {100,100}, sd::DataType::FLOAT32); - - NDArray exp1('c', {}, std::vector{2.166667}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {3.f,4.f,1.f,0.666667f}, sd::DataType::FLOAT32); - NDArray exp3('c', {3}, {4.5,1,1}, sd::DataType::DOUBLE); - NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, sd::DataType::FLOAT32); - NDArray exp5('c', {2}, {3.5f,0.833333f}, sd::DataType::FLOAT32); - - x.reduceAlongDimension(sd::reduce::Mean, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); - - x.reduceAlongDimension(sd::reduce::Mean, z2, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); - - x.reduceAlongDimension(sd::reduce::Mean, z3, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); - - x.permutei({1,0,2}); // 3x2x2 - - x.reduceAlongDimension(sd::reduce::Mean, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); - - x.reduceAlongDimension(sd::reduce::Mean, z4, {1}); - ASSERT_TRUE(z4.equalsTo(&exp4)); - - x.reduceAlongDimension(sd::reduce::Mean, z5, {0,2}); - ASSERT_TRUE(z5.equalsTo(&exp5)); + NDArray x('c', {2, 3, 2}, + { + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + -1, + -2, + -3, + -4, + }, + sd::DataType::INT32); + + NDArray z1('c', {}, std::vector{100}, sd::DataType::DOUBLE); + NDArray z2('c', {2, 2}, {100, 100, 100, 100}, sd::DataType::FLOAT32); + NDArray z3('c', {3}, {100, 100, 100}, sd::DataType::DOUBLE); + NDArray z4('c', {3, 2}, {100, 100, 100, 100, 100, 100}, + sd::DataType::FLOAT32); + NDArray z5('c', {2}, {100, 100}, sd::DataType::FLOAT32); + + NDArray exp1('c', {}, std::vector{2.166667}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {3.f, 4.f, 1.f, 0.666667f}, sd::DataType::FLOAT32); + NDArray exp3('c', {3}, {4.5, 1, 1}, sd::DataType::DOUBLE); + NDArray exp4('c', {3, 2}, {4, 5, 1, 1, 1, 1}, sd::DataType::FLOAT32); + NDArray exp5('c', {2}, {3.5f, 0.833333f}, sd::DataType::FLOAT32); + + x.reduceAlongDimension(sd::reduce::Mean, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + x.reduceAlongDimension(sd::reduce::Mean, z2, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); + + x.reduceAlongDimension(sd::reduce::Mean, z3, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); + + x.permutei({1, 0, 2}); // 3x2x2 + + x.reduceAlongDimension(sd::reduce::Mean, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + x.reduceAlongDimension(sd::reduce::Mean, z4, {1}); + ASSERT_TRUE(z4.equalsTo(&exp4)); + + x.reduceAlongDimension(sd::reduce::Mean, z5, {0, 2}); + ASSERT_TRUE(z5.equalsTo(&exp5)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test2) { - - NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, sd::DataType::DOUBLE); - - NDArray exp1('c', {}, std::vector{2.166667}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {3,4,1,0.666667}, sd::DataType::DOUBLE); - NDArray exp3('c', {3}, {4.5,1,1}, sd::DataType::DOUBLE); - NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, sd::DataType::DOUBLE); - NDArray exp5('c', {2}, {3.5,0.833333}, sd::DataType::DOUBLE); - - NDArray z1 = x.reduceAlongDimension(sd::reduce::Mean, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); - - NDArray z2 = x.reduceAlongDimension(sd::reduce::Mean, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); - - NDArray z3 = x.reduceAlongDimension(sd::reduce::Mean, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); - - x.permutei({1,0,2}); // 3x2x2 - - NDArray z4 = x.reduceAlongDimension(sd::reduce::Mean, {0,1,2}); - ASSERT_TRUE(z4.equalsTo(&exp1)); - - NDArray z5 = x.reduceAlongDimension(sd::reduce::Mean, {1}); - ASSERT_TRUE(z5.equalsTo(&exp4)); - - NDArray z6 = x.reduceAlongDimension(sd::reduce::Mean, {0,2}); - ASSERT_TRUE(z6.equalsTo(&exp5)); + NDArray x('c', {2, 3, 2}, + { + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + -1, + -2, + -3, + -4, + }, + sd::DataType::DOUBLE); + + NDArray exp1('c', {}, std::vector{2.166667}, sd::DataType::DOUBLE); + NDArray exp2('c', {2, 2}, {3, 4, 1, 0.666667}, sd::DataType::DOUBLE); + NDArray exp3('c', {3}, {4.5, 1, 1}, sd::DataType::DOUBLE); + NDArray exp4('c', {3, 2}, {4, 5, 1, 1, 1, 1}, sd::DataType::DOUBLE); + NDArray exp5('c', {2}, {3.5, 0.833333}, sd::DataType::DOUBLE); + + NDArray z1 = x.reduceAlongDimension(sd::reduce::Mean, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + NDArray z2 = x.reduceAlongDimension(sd::reduce::Mean, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); + + NDArray z3 = x.reduceAlongDimension(sd::reduce::Mean, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); + + x.permutei({1, 0, 2}); // 3x2x2 + + NDArray z4 = x.reduceAlongDimension(sd::reduce::Mean, {0, 1, 2}); + ASSERT_TRUE(z4.equalsTo(&exp1)); + + NDArray z5 = x.reduceAlongDimension(sd::reduce::Mean, {1}); + ASSERT_TRUE(z5.equalsTo(&exp4)); + + NDArray z6 = x.reduceAlongDimension(sd::reduce::Mean, {0, 2}); + ASSERT_TRUE(z6.equalsTo(&exp5)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, EqualityTest1) { - auto arrayA = NDArrayFactory::create_('f', {3, 5}); - auto arrayB = NDArrayFactory::create_('f', {3, 5}); - auto arrayC = NDArrayFactory::create_('f', {3, 5}); + auto arrayA = NDArrayFactory::create_('f', {3, 5}); + auto arrayB = NDArrayFactory::create_('f', {3, 5}); + auto arrayC = NDArrayFactory::create_('f', {3, 5}); - auto arrayD = NDArrayFactory::create_('f', {2, 4}); - auto arrayE = NDArrayFactory::create_('f', {1, 15}); + auto arrayD = NDArrayFactory::create_('f', {2, 4}); + auto arrayE = NDArrayFactory::create_('f', {1, 15}); - for (int i = 0; i < arrayA->rows(); i++) { - for (int k = 0; k < arrayA->columns(); k++) { - arrayA->p(i, k, (float) i); - } + for (int i = 0; i < arrayA->rows(); i++) { + for (int k = 0; k < arrayA->columns(); k++) { + arrayA->p(i, k, (float)i); } + } - for (int i = 0; i < arrayB->rows(); i++) { - for (int k = 0; k < arrayB->columns(); k++) { - arrayB->p(i, k, (float) i); - } + for (int i = 0; i < arrayB->rows(); i++) { + for (int k = 0; k < arrayB->columns(); k++) { + arrayB->p(i, k, (float)i); } + } - for (int i = 0; i < arrayC->rows(); i++) { - for (int k = 0; k < arrayC->columns(); k++) { - arrayC->p(i, k, (float) i+1); - } + for (int i = 0; i < arrayC->rows(); i++) { + for (int k = 0; k < arrayC->columns(); k++) { + arrayC->p(i, k, (float)i + 1); } + } - ASSERT_TRUE(arrayA->equalsTo(arrayB, 1e-5)); + ASSERT_TRUE(arrayA->equalsTo(arrayB, 1e-5)); - ASSERT_FALSE(arrayC->equalsTo(arrayB, 1e-5)); + ASSERT_FALSE(arrayC->equalsTo(arrayB, 1e-5)); - ASSERT_FALSE(arrayD->equalsTo(arrayB, 1e-5)); + ASSERT_FALSE(arrayD->equalsTo(arrayB, 1e-5)); - ASSERT_FALSE(arrayE->equalsTo(arrayB, 1e-5)); + ASSERT_FALSE(arrayE->equalsTo(arrayB, 1e-5)); - delete arrayA; - delete arrayB; - delete arrayC; - delete arrayD; - delete arrayE; + delete arrayA; + delete arrayB; + delete arrayC; + delete arrayD; + delete arrayE; } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) { + NDArray x('c', {2, 3, 2}, + {1.5f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.5f, 8.f, -1.f, -2.f, -3.5f, -4.f}, + sd::DataType::FLOAT32); - NDArray x('c', {2,3,2}, {1.5f,2.f,3.f,4.f,5.f,6.f,7.5f,8.f,-1.f,-2.f,-3.5f,-4.f}, sd::DataType::FLOAT32); - - NDArray z1('c', {}, std::vector{100}, sd::DataType::FLOAT32); - NDArray z2('c', {2,2}, {100,100,100,100}, sd::DataType::FLOAT32); - NDArray z3('c', {3}, {100,100,100}, sd::DataType::FLOAT32); - NDArray z4('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); - NDArray z5('c', {2}, {100,100}, sd::DataType::FLOAT32); + NDArray z1('c', {}, std::vector{100}, sd::DataType::FLOAT32); + NDArray z2('c', {2, 2}, {100, 100, 100, 100}, sd::DataType::FLOAT32); + NDArray z3('c', {3}, {100, 100, 100}, sd::DataType::FLOAT32); + NDArray z4('c', {3, 2}, {100, 100, 100, 100, 100, 100}, + sd::DataType::FLOAT32); + NDArray z5('c', {2}, {100, 100}, sd::DataType::FLOAT32); - NDArray exp1('c', {}, std::vector{26.5f}, sd::DataType::FLOAT32); - NDArray exp2('c', {2,2}, {9.5f,12.f,3.f,2.f}, sd::DataType::FLOAT32); - NDArray exp3('c', {3}, {19.f,4.f,3.5f}, sd::DataType::FLOAT32); - NDArray exp4('c', {3,2}, {9.f,10.f,2.f,2.f,1.5f,2.f}, sd::DataType::FLOAT32); - NDArray exp5('c', {2}, {21.5f,5.f}, sd::DataType::FLOAT32); + NDArray exp1('c', {}, std::vector{26.5f}, sd::DataType::FLOAT32); + NDArray exp2('c', {2, 2}, {9.5f, 12.f, 3.f, 2.f}, sd::DataType::FLOAT32); + NDArray exp3('c', {3}, {19.f, 4.f, 3.5f}, sd::DataType::FLOAT32); + NDArray exp4('c', {3, 2}, {9.f, 10.f, 2.f, 2.f, 1.5f, 2.f}, + sd::DataType::FLOAT32); + NDArray exp5('c', {2}, {21.5f, 5.f}, sd::DataType::FLOAT32); - x.reduceAlongDimension(sd::reduce::Sum, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + x.reduceAlongDimension(sd::reduce::Sum, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(sd::reduce::Sum, z2, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); + x.reduceAlongDimension(sd::reduce::Sum, z2, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); - x.reduceAlongDimension(sd::reduce::Sum, z3, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); + x.reduceAlongDimension(sd::reduce::Sum, z3, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); - x.permutei({1,0,2}); // 3x2x2 + x.permutei({1, 0, 2}); // 3x2x2 - x.reduceAlongDimension(sd::reduce::Sum, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + x.reduceAlongDimension(sd::reduce::Sum, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(sd::reduce::Sum, z4, {1}); - ASSERT_TRUE(z4.equalsTo(&exp4)); + x.reduceAlongDimension(sd::reduce::Sum, z4, {1}); + ASSERT_TRUE(z4.equalsTo(&exp4)); - x.reduceAlongDimension(sd::reduce::Sum, z5, {0,2}); - ASSERT_TRUE(z5.equalsTo(&exp5)); + x.reduceAlongDimension(sd::reduce::Sum, z5, {0, 2}); + ASSERT_TRUE(z5.equalsTo(&exp5)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test2) { - - NDArray x('c', {2,3,2}, {1.5,2,3,4,5,6,7.5,8,-1,-2,-3.5,-4,}, sd::DataType::INT64); - - NDArray exp1('c', {}, std::vector{26}, sd::DataType::INT64); - NDArray exp2('c', {2,2}, {9,12,3,2}, sd::DataType::INT64); - NDArray exp3('c', {3}, {18,4,4}, sd::DataType::INT64); - NDArray exp4('c', {3,2}, {8,10,2,2,2,2}, sd::DataType::INT64); - NDArray exp5('c', {2}, {21,5}, sd::DataType::INT64); - - NDArray z1 = x.reduceAlongDimension(sd::reduce::Sum, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); - - NDArray z2 = x.reduceAlongDimension(sd::reduce::Sum, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); - - NDArray z3 = x.reduceAlongDimension(sd::reduce::Sum, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); - - x.permutei({1,0,2}); // 3x2x2 - - NDArray z4 = x.reduceAlongDimension(sd::reduce::Sum, {0,1,2}); - ASSERT_TRUE(z4.equalsTo(&exp1)); - - NDArray z5 = x.reduceAlongDimension(sd::reduce::Sum, {1}); - ASSERT_TRUE(z5.equalsTo(&exp4)); - - NDArray z6 = x.reduceAlongDimension(sd::reduce::Sum, {0,2}); - ASSERT_TRUE(z6.equalsTo(&exp5)); + NDArray x('c', {2, 3, 2}, + { + 1.5, + 2, + 3, + 4, + 5, + 6, + 7.5, + 8, + -1, + -2, + -3.5, + -4, + }, + sd::DataType::INT64); + + NDArray exp1('c', {}, std::vector{26}, sd::DataType::INT64); + NDArray exp2('c', {2, 2}, {9, 12, 3, 2}, sd::DataType::INT64); + NDArray exp3('c', {3}, {18, 4, 4}, sd::DataType::INT64); + NDArray exp4('c', {3, 2}, {8, 10, 2, 2, 2, 2}, sd::DataType::INT64); + NDArray exp5('c', {2}, {21, 5}, sd::DataType::INT64); + + NDArray z1 = x.reduceAlongDimension(sd::reduce::Sum, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + NDArray z2 = x.reduceAlongDimension(sd::reduce::Sum, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); + + NDArray z3 = x.reduceAlongDimension(sd::reduce::Sum, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); + + x.permutei({1, 0, 2}); // 3x2x2 + + NDArray z4 = x.reduceAlongDimension(sd::reduce::Sum, {0, 1, 2}); + ASSERT_TRUE(z4.equalsTo(&exp1)); + + NDArray z5 = x.reduceAlongDimension(sd::reduce::Sum, {1}); + ASSERT_TRUE(z5.equalsTo(&exp4)); + + NDArray z6 = x.reduceAlongDimension(sd::reduce::Sum, {0, 2}); + ASSERT_TRUE(z6.equalsTo(&exp5)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test1) { + NDArray x('c', {2, 3, 2}, {0.5, 2, 3, -4, 5, 6, -7.5, 8, -1, -0.5, -3.5, 4}, + sd::DataType::DOUBLE); - NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, sd::DataType::DOUBLE); + NDArray z1('c', {}, std::vector{true}, sd::DataType::BOOL); + NDArray z2('c', {2, 2}, {true, true, true, true}, sd::DataType::BOOL); + NDArray z3('c', {3}, {true, true, true}, sd::DataType::BOOL); + NDArray z4('c', {3, 2}, {true, true, true, true, true, true}, + sd::DataType::BOOL); + NDArray z5('c', {2}, {true, true}, sd::DataType::BOOL); - NDArray z1('c', {}, std::vector{true}, sd::DataType::BOOL); - NDArray z2('c', {2,2}, {true,true,true,true}, sd::DataType::BOOL); - NDArray z3('c', {3}, {true,true,true}, sd::DataType::BOOL); - NDArray z4('c', {3,2}, {true,true,true,true,true,true}, sd::DataType::BOOL); - NDArray z5('c', {2}, {true,true}, sd::DataType::BOOL); + NDArray exp1('c', {}, std::vector{true}, sd::DataType::BOOL); + NDArray exp2('c', {2, 2}, {true, true, false, true}, sd::DataType::BOOL); + NDArray exp3('c', {3}, {true, true, true}, sd::DataType::BOOL); + NDArray exp4('c', {3, 2}, {true, true, true, false, true, true}, + sd::DataType::BOOL); + NDArray exp5('c', {2}, {true, true}, sd::DataType::BOOL); - NDArray exp1('c', {}, std::vector{true}, sd::DataType::BOOL); - NDArray exp2('c', {2,2}, {true,true,false,true}, sd::DataType::BOOL); - NDArray exp3('c', {3}, {true,true,true}, sd::DataType::BOOL); - NDArray exp4('c', {3,2}, {true,true,true,false,true,true}, sd::DataType::BOOL); - NDArray exp5('c', {2}, {true,true}, sd::DataType::BOOL); + x.reduceAlongDimension(sd::reduce::IsPositive, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(sd::reduce::IsPositive, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + x.reduceAlongDimension(sd::reduce::IsPositive, z2, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); - x.reduceAlongDimension(sd::reduce::IsPositive, z2, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); + x.reduceAlongDimension(sd::reduce::IsPositive, z3, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); - x.reduceAlongDimension(sd::reduce::IsPositive, z3, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); + x.permutei({1, 0, 2}); // 3x2x2 - x.permutei({1,0,2}); // 3x2x2 + x.reduceAlongDimension(sd::reduce::IsPositive, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(sd::reduce::IsPositive, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + x.reduceAlongDimension(sd::reduce::IsPositive, z4, {1}); + ASSERT_TRUE(z4.equalsTo(&exp4)); - x.reduceAlongDimension(sd::reduce::IsPositive, z4, {1}); - ASSERT_TRUE(z4.equalsTo(&exp4)); - - x.reduceAlongDimension(sd::reduce::IsPositive, z5, {0,2}); - ASSERT_TRUE(z5.equalsTo(&exp5)); + x.reduceAlongDimension(sd::reduce::IsPositive, z5, {0, 2}); + ASSERT_TRUE(z5.equalsTo(&exp5)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test2) { + NDArray x('c', {2, 3, 2}, {0.5, 2, 3, -4, 5, 6, -7.5, 8, -1, -0.5, -3.5, 4}, + sd::DataType::INT32); - NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, sd::DataType::INT32); - - NDArray exp1('c', {}, std::vector{1}, sd::DataType::BOOL); - NDArray exp2('c', {2,2}, {1,1,0,1}, sd::DataType::BOOL); - NDArray exp3('c', {3}, {1,1,1}, sd::DataType::BOOL); - NDArray exp4('c', {3,2}, {0,1,1,0,1,1}, sd::DataType::BOOL); - NDArray exp5('c', {2}, {1,1}, sd::DataType::BOOL); + NDArray exp1('c', {}, std::vector{1}, sd::DataType::BOOL); + NDArray exp2('c', {2, 2}, {1, 1, 0, 1}, sd::DataType::BOOL); + NDArray exp3('c', {3}, {1, 1, 1}, sd::DataType::BOOL); + NDArray exp4('c', {3, 2}, {0, 1, 1, 0, 1, 1}, sd::DataType::BOOL); + NDArray exp5('c', {2}, {1, 1}, sd::DataType::BOOL); - NDArray z1 = x.reduceAlongDimension(sd::reduce::IsPositive, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + NDArray z1 = x.reduceAlongDimension(sd::reduce::IsPositive, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - NDArray z2 = x.reduceAlongDimension(sd::reduce::IsPositive, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); + NDArray z2 = x.reduceAlongDimension(sd::reduce::IsPositive, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); - NDArray z3 = x.reduceAlongDimension(sd::reduce::IsPositive, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); + NDArray z3 = x.reduceAlongDimension(sd::reduce::IsPositive, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); - x.permutei({1,0,2}); // 3x2x2 + x.permutei({1, 0, 2}); // 3x2x2 - NDArray z4 = x.reduceAlongDimension(sd::reduce::IsPositive, {0,1,2}); - ASSERT_TRUE(z4.equalsTo(&exp1)); + NDArray z4 = x.reduceAlongDimension(sd::reduce::IsPositive, {0, 1, 2}); + ASSERT_TRUE(z4.equalsTo(&exp1)); - NDArray z5 = x.reduceAlongDimension(sd::reduce::IsPositive, {1}); - ASSERT_TRUE(z5.equalsTo(&exp4)); + NDArray z5 = x.reduceAlongDimension(sd::reduce::IsPositive, {1}); + ASSERT_TRUE(z5.equalsTo(&exp4)); - NDArray z6 = x.reduceAlongDimension(sd::reduce::IsPositive, {0,2}); - ASSERT_TRUE(z6.equalsTo(&exp5)); + NDArray z6 = x.reduceAlongDimension(sd::reduce::IsPositive, {0, 2}); + ASSERT_TRUE(z6.equalsTo(&exp5)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) { + NDArray x( + 'c', {2, 3, 2}, + {0.5f, 2.f, 3.f, -0.f, 5.f, 6.f, -7.5f, 0.f, -1.f, -0.5f, -3.5f, 4.f}, + sd::DataType::FLOAT32); - NDArray x('c', {2,3,2}, {0.5f,2.f,3.f,-0.f,5.f,6.f,-7.5f,0.f,-1.f,-0.5f,-3.5f,4.f}, sd::DataType::FLOAT32); + NDArray z1('c', {}, std::vector{100}, sd::DataType::INT64); + NDArray z2('c', {2, 2}, {100, 100, 100, 100}, sd::DataType::INT64); + NDArray z3('c', {3}, {100, 100, 100}, sd::DataType::INT64); + NDArray z4('c', {3, 2}, {100, 100, 100, 100, 100, 100}, sd::DataType::INT64); + NDArray z5('c', {2}, {100, 100}, sd::DataType::INT64); - NDArray z1('c', {}, std::vector{100}, sd::DataType::INT64); - NDArray z2('c', {2,2}, {100,100,100,100}, sd::DataType::INT64); - NDArray z3('c', {3}, {100,100,100}, sd::DataType::INT64); - NDArray z4('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::INT64); - NDArray z5('c', {2}, {100,100}, sd::DataType::INT64); + NDArray exp1('c', {}, std::vector{2}, sd::DataType::INT64); + NDArray exp2('c', {2, 2}, {0, 1, 0, 1}, sd::DataType::INT64); + NDArray exp3('c', {3}, {1, 1, 0}, sd::DataType::INT64); + NDArray exp4('c', {3, 2}, {0, 1, 0, 1, 0, 0}, sd::DataType::INT64); + NDArray exp5('c', {2}, {1, 1}, sd::DataType::INT64); - NDArray exp1('c', {}, std::vector{2}, sd::DataType::INT64); - NDArray exp2('c', {2,2}, {0,1,0,1}, sd::DataType::INT64); - NDArray exp3('c', {3}, {1,1,0}, sd::DataType::INT64); - NDArray exp4('c', {3,2}, {0,1,0,1,0,0}, sd::DataType::INT64); - NDArray exp5('c', {2}, {1,1}, sd::DataType::INT64); + x.reduceAlongDimension(sd::reduce::CountZero, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(sd::reduce::CountZero, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + x.reduceAlongDimension(sd::reduce::CountZero, z2, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); - x.reduceAlongDimension(sd::reduce::CountZero, z2, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); + x.reduceAlongDimension(sd::reduce::CountZero, z3, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); - x.reduceAlongDimension(sd::reduce::CountZero, z3, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); + x.permutei({1, 0, 2}); // 3x2x2 - x.permutei({1,0,2}); // 3x2x2 + x.reduceAlongDimension(sd::reduce::CountZero, z1, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - x.reduceAlongDimension(sd::reduce::CountZero, z1, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + x.reduceAlongDimension(sd::reduce::CountZero, z4, {1}); + ASSERT_TRUE(z4.equalsTo(&exp4)); - x.reduceAlongDimension(sd::reduce::CountZero, z4, {1}); - ASSERT_TRUE(z4.equalsTo(&exp4)); - - x.reduceAlongDimension(sd::reduce::CountZero, z5, {0,2}); - ASSERT_TRUE(z5.equalsTo(&exp5)); + x.reduceAlongDimension(sd::reduce::CountZero, z5, {0, 2}); + ASSERT_TRUE(z5.equalsTo(&exp5)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test2) { + NDArray x('c', {2, 3, 2}, {0.5, 2, 3, -0, 5, 6, -7.5, 0, -1, -0.5, -3.5, 4}, + sd::DataType::INT32); - NDArray x('c', {2,3,2}, {0.5,2,3,-0,5,6,-7.5,0,-1,-0.5,-3.5,4}, sd::DataType::INT32); - - NDArray exp1('c', {}, std::vector{4}, sd::DataType::INT64); - NDArray exp2('c', {2,2}, {1,1,0,2}, sd::DataType::INT64); - NDArray exp3('c', {3}, {2,2,0}, sd::DataType::INT64); - NDArray exp4('c', {3,2}, {1,1,0,2,0,0}, sd::DataType::INT64); - NDArray exp5('c', {2}, {2,2}, sd::DataType::INT64); + NDArray exp1('c', {}, std::vector{4}, sd::DataType::INT64); + NDArray exp2('c', {2, 2}, {1, 1, 0, 2}, sd::DataType::INT64); + NDArray exp3('c', {3}, {2, 2, 0}, sd::DataType::INT64); + NDArray exp4('c', {3, 2}, {1, 1, 0, 2, 0, 0}, sd::DataType::INT64); + NDArray exp5('c', {2}, {2, 2}, sd::DataType::INT64); - NDArray z1 = x.reduceAlongDimension(sd::reduce::CountZero, {0,1,2}); - ASSERT_TRUE(z1.equalsTo(&exp1)); + NDArray z1 = x.reduceAlongDimension(sd::reduce::CountZero, {0, 1, 2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); - NDArray z2 = x.reduceAlongDimension(sd::reduce::CountZero, {1}); - ASSERT_TRUE(z2.equalsTo(&exp2)); + NDArray z2 = x.reduceAlongDimension(sd::reduce::CountZero, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); - NDArray z3 = x.reduceAlongDimension(sd::reduce::CountZero, {0,2}); - ASSERT_TRUE(z3.equalsTo(&exp3)); + NDArray z3 = x.reduceAlongDimension(sd::reduce::CountZero, {0, 2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); - x.permutei({1,0,2}); // 3x2x2 + x.permutei({1, 0, 2}); // 3x2x2 - NDArray z4 = x.reduceAlongDimension(sd::reduce::CountZero, {0,1,2}); - ASSERT_TRUE(z4.equalsTo(&exp1)); + NDArray z4 = x.reduceAlongDimension(sd::reduce::CountZero, {0, 1, 2}); + ASSERT_TRUE(z4.equalsTo(&exp1)); - NDArray z5 = x.reduceAlongDimension(sd::reduce::CountZero, {1}); - ASSERT_TRUE(z5.equalsTo(&exp4)); + NDArray z5 = x.reduceAlongDimension(sd::reduce::CountZero, {1}); + ASSERT_TRUE(z5.equalsTo(&exp4)); - NDArray z6 = x.reduceAlongDimension(sd::reduce::CountZero, {0,2}); - ASSERT_TRUE(z6.equalsTo(&exp5)); + NDArray z6 = x.reduceAlongDimension(sd::reduce::CountZero, {0, 2}); + ASSERT_TRUE(z6.equalsTo(&exp5)); } TEST_F(NDArrayCudaBasicsTests, BroadcastOpsTest1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); + auto row = NDArrayFactory::linspace(1.0f, 5.0f, 5); + NDArray expRow('c', + { + 1, + 5, + }, + {1, 2, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray exp('c', {5, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, + 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, + sd::DataType::FLOAT32); - auto x = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create('c', {5, 5}); - auto row = NDArrayFactory::linspace(1.0f, 5.0f, 5); - NDArray expRow('c', {1, 5,}, {1,2,3,4,5}, sd::DataType::FLOAT32); - NDArray exp('c', {5,5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, sd::DataType::FLOAT32); - - ASSERT_TRUE(row->equalsTo(&expRow)); + ASSERT_TRUE(row->equalsTo(&expRow)); - x.applyBroadcast(broadcast::Add, {1}, *row, z); - x += *row; + x.applyBroadcast(broadcast::Add, {1}, *row, z); + x += *row; - ASSERT_TRUE(x.equalsTo(z)); - //ASSERT_TRUE(z.equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(z)); + // ASSERT_TRUE(z.equalsTo(&exp)); - delete row; + delete row; } TEST_F(NDArrayCudaBasicsTests, BroadcastOpsTest2) { - - auto x = NDArrayFactory::create('c', {5, 5}); - //auto z = NDArrayFactory::create('c', {5, 5}); - auto row = NDArrayFactory::linspace(1.0f, 5.0f, 5); - NDArray expRow('c', {1, 5,}, {1,2,3,4,5}, sd::DataType::FLOAT32); - NDArray exp('c', {5,5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, sd::DataType::FLOAT32); - - ASSERT_TRUE(row->equalsTo(&expRow)); - x.applyBroadcast(broadcast::Add, {1}, *row, x); - ASSERT_TRUE(x.equalsTo(&exp)); + auto x = NDArrayFactory::create('c', {5, 5}); + // auto z = NDArrayFactory::create('c', {5, 5}); + auto row = NDArrayFactory::linspace(1.0f, 5.0f, 5); + NDArray expRow('c', + { + 1, + 5, + }, + {1, 2, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray exp('c', {5, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, + 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, + sd::DataType::FLOAT32); + + ASSERT_TRUE(row->equalsTo(&expRow)); + x.applyBroadcast(broadcast::Add, {1}, *row, x); + ASSERT_TRUE(x.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, TestBroadcast_1) { + NDArray exp('c', {2, 3, 2, 2}, + {1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3., + 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.}, + sd::DataType::DOUBLE); - NDArray exp('c', {2, 3, 2, 2}, {1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3., 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.}, sd::DataType::DOUBLE); + auto input = NDArrayFactory::create('c', {2, 3, 2, 2}); + auto bias = NDArrayFactory::create('c', {1, 3}); - auto input = NDArrayFactory::create('c',{ 2, 3, 2, 2}); - auto bias = NDArrayFactory::create('c', {1, 3}); - - bias.linspace(1); - input.applyBroadcast(broadcast::Add, {1}, bias, input); - ASSERT_TRUE(exp.equalsTo(&input)); + bias.linspace(1); + input.applyBroadcast(broadcast::Add, {1}, bias, input); + ASSERT_TRUE(exp.equalsTo(&input)); } TEST_F(NDArrayCudaBasicsTests, TestFloat16_1) { - auto x = NDArrayFactory::create({1,2,3,4,5,7,8,9}); - auto y = NDArrayFactory::create({1,2,3,4,5,7,8,9}); - ASSERT_TRUE(x.equalsTo(&y)); + auto x = NDArrayFactory::create({1, 2, 3, 4, 5, 7, 8, 9}); + auto y = NDArrayFactory::create({1, 2, 3, 4, 5, 7, 8, 9}); + ASSERT_TRUE(x.equalsTo(&y)); } TEST_F(NDArrayCudaBasicsTests, TestFloat16_2) { - auto x = NDArrayFactory::create('c', {9}, {1,2,3,4,5,6,7,8,9}); - auto y = NDArrayFactory::create('c', {9}, {1,2,3,4,5,6,7,8,9}); - ASSERT_TRUE(x.equalsTo(y)); - //for (int e = 0; e < x.lengthOf(); e++) - // ASSERT_NEAR(x.e(e), y.e(e), 1.e-5f); + auto x = + NDArrayFactory::create('c', {9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('c', {9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + ASSERT_TRUE(x.equalsTo(y)); + // for (int e = 0; e < x.lengthOf(); e++) + // ASSERT_NEAR(x.e(e), y.e(e), 1.e-5f); } TEST_F(NDArrayCudaBasicsTests, TestFloat16_3) { - auto x = NDArrayFactory::create({1,2,3,4,5,7,8,9}); - auto y = NDArrayFactory::create({1,2,3,4,5,7,8,9}); - ASSERT_TRUE(x.equalsTo(&y)); + auto x = NDArrayFactory::create({1, 2, 3, 4, 5, 7, 8, 9}); + auto y = NDArrayFactory::create({1, 2, 3, 4, 5, 7, 8, 9}); + ASSERT_TRUE(x.equalsTo(&y)); } TEST_F(NDArrayCudaBasicsTests, TestFloat_4) { - auto x = NDArrayFactory::create({1,2,3,4,5,7,8,9}); - auto y = NDArrayFactory::create({2,4,5,5,6,7,8,9}); - ASSERT_FALSE(x.equalsTo(&y)); + auto x = NDArrayFactory::create({1, 2, 3, 4, 5, 7, 8, 9}); + auto y = NDArrayFactory::create({2, 4, 5, 5, 6, 7, 8, 9}); + ASSERT_FALSE(x.equalsTo(&y)); } TEST_F(NDArrayCudaBasicsTests, TestFloat_5) { - auto x = NDArrayFactory::create('c', {3,3}, {1,2,3,4,5,6,7,8,9}); - auto y = NDArrayFactory::create('c', {3,3}, {2,4,5,5,6,7,8,9, 10}); - ASSERT_FALSE(x.equalsTo(&y)); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('c', {3, 3}, {2, 4, 5, 5, 6, 7, 8, 9, 10}); + ASSERT_FALSE(x.equalsTo(&y)); } TEST_F(NDArrayCudaBasicsTests, TestFloat_6) { - auto x = NDArrayFactory::create('f', {3,3}, {1,2,3,4,5,6,7,8,9}); - auto y = NDArrayFactory::create('f', {3,3}, {2,4,5,5,6,7,8,9,10}); - ASSERT_FALSE(x.equalsTo(&y)); + auto x = + NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('f', {3, 3}, {2, 4, 5, 5, 6, 7, 8, 9, 10}); + ASSERT_FALSE(x.equalsTo(&y)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_05) -{ - auto x = NDArrayFactory::create('c', {8, 8, 8}); - auto y = NDArrayFactory::create('c', {1, 8, 8}); - auto expected = NDArrayFactory::create('c', {8, 8, 8}); - NDArray res2 = NDArrayFactory::create(expected.ordering(), expected.getShapeAsVector()); - x = 1.; - y = 2.; - expected = 3.; - res2 = 0.f; +TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_05) { + auto x = NDArrayFactory::create('c', {8, 8, 8}); + auto y = NDArrayFactory::create('c', {1, 8, 8}); + auto expected = NDArrayFactory::create('c', {8, 8, 8}); + NDArray res2 = NDArrayFactory::create(expected.ordering(), + expected.getShapeAsVector()); + x = 1.; + y = 2.; + expected = 3.; + res2 = 0.f; - x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, res2);// *= y; + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, res2); // *= y; - ASSERT_TRUE(expected.isSameShape(&res2)); - ASSERT_TRUE(expected.equalsTo(&res2)); + ASSERT_TRUE(expected.isSameShape(&res2)); + ASSERT_TRUE(expected.equalsTo(&res2)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_5) -{ - auto x = NDArrayFactory::create('c', {8, 8, 8}); - auto y = NDArrayFactory::create('c', {8, 1, 8}); - auto expected = NDArrayFactory::create('c', {8, 8, 8}); - NDArray res2(expected); - x = 1.; - y = 2.; - expected = 3.; - //x.printBuffer("X="); - //y.printBuffer("Y="); - //expected.printBuffer("EXPECTED"); - auto result = x + y; - //result.printBuffer("1 + 2 ="); - //res2.assign(x + y); - - //x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2); - //res2.printBuffer("Z="); - //x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2);// *= y; -// x += y; - //x.printBuffer("OutputX"); - //res2.syncToHost(); - //res2.printBuffer("OUputZ"); - //x.printIndexedBuffer("OUtputX"); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); +TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_5) { + auto x = NDArrayFactory::create('c', {8, 8, 8}); + auto y = NDArrayFactory::create('c', {8, 1, 8}); + auto expected = NDArrayFactory::create('c', {8, 8, 8}); + NDArray res2(expected); + x = 1.; + y = 2.; + expected = 3.; + // x.printBuffer("X="); + // y.printBuffer("Y="); + // expected.printBuffer("EXPECTED"); + auto result = x + y; + // result.printBuffer("1 + 2 ="); + // res2.assign(x + y); + + // x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2); + // res2.printBuffer("Z="); + // x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2);// *= y; + // x += y; + // x.printBuffer("OutputX"); + // res2.syncToHost(); + // res2.printBuffer("OUputZ"); + // x.printIndexedBuffer("OUtputX"); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_51) -{ - auto x = NDArrayFactory::create('c', {8, 8, 8}); - auto y = NDArrayFactory::create('c', {8, 8}); - auto expected = NDArrayFactory::create('c', {8, 8, 8}); - NDArray res2(expected); - x = 1.; - y = 2.; - expected = 3.; - //x.printBuffer("X="); - //y.printBuffer("Y="); - //expected.printBuffer("EXPECTED"); - auto result = x + y; - //result.printBuffer("1 + 2 ="); - //res2.assign(x + y); - - //x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2); - //res2.printBuffer("Z="); - //x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2);// *= y; -// x += y; - //x.printBuffer("OutputX"); - //res2.syncToHost(); - //res2.printBuffer("OUputZ"); - //x.printIndexedBuffer("OUtputX"); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); -} - -TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_1) -{ - auto x = NDArrayFactory::create('c', {2, 1, 2}); - x = 10.; - auto y = x.tile({1,2,1}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}); - exp = 10.; - - // y.printShapeInfo("Output SHAPE"); - // y.printBuffer("Output TILE"); - // exp.printBuffer("Expect TILE"); - ASSERT_TRUE(exp.equalsTo(y)); -} - -TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_2) -{ - auto x = NDArrayFactory::create('f', {2, 1, 2}); - x = 10.; - auto y = x.tile({1,2,1}); - auto exp = NDArrayFactory::create('f', {2, 2, 2}); - exp = 10.; - ASSERT_TRUE(exp.equalsTo(y)); -} - -TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_3) -{ - auto x = NDArrayFactory::create('f', {2, 1, 2}); - x = 10.; - x.p(1,0,1, 20); - x.syncToDevice(); - auto y = x.tile({1,2,1}); - auto exp = NDArrayFactory::create('f', {2, 2, 2}); - exp = 10.; - exp.p(1,0,1, 20.); - exp.p(1, 1, 1, 20.); - exp.syncToDevice(); - ASSERT_TRUE(exp.equalsTo(y)); +TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_51) { + auto x = NDArrayFactory::create('c', {8, 8, 8}); + auto y = NDArrayFactory::create('c', {8, 8}); + auto expected = NDArrayFactory::create('c', {8, 8, 8}); + NDArray res2(expected); + x = 1.; + y = 2.; + expected = 3.; + // x.printBuffer("X="); + // y.printBuffer("Y="); + // expected.printBuffer("EXPECTED"); + auto result = x + y; + // result.printBuffer("1 + 2 ="); + // res2.assign(x + y); + + // x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2); + // res2.printBuffer("Z="); + // x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2);// *= y; + // x += y; + // x.printBuffer("OutputX"); + // res2.syncToHost(); + // res2.printBuffer("OUputZ"); + // x.printIndexedBuffer("OUtputX"); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + +TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_1) { + auto x = NDArrayFactory::create('c', {2, 1, 2}); + x = 10.; + auto y = x.tile({1, 2, 1}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}); + exp = 10.; + + // y.printShapeInfo("Output SHAPE"); + // y.printBuffer("Output TILE"); + // exp.printBuffer("Expect TILE"); + ASSERT_TRUE(exp.equalsTo(y)); +} + +TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_2) { + auto x = NDArrayFactory::create('f', {2, 1, 2}); + x = 10.; + auto y = x.tile({1, 2, 1}); + auto exp = NDArrayFactory::create('f', {2, 2, 2}); + exp = 10.; + ASSERT_TRUE(exp.equalsTo(y)); +} + +TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_3) { + auto x = NDArrayFactory::create('f', {2, 1, 2}); + x = 10.; + x.p(1, 0, 1, 20); + x.syncToDevice(); + auto y = x.tile({1, 2, 1}); + auto exp = NDArrayFactory::create('f', {2, 2, 2}); + exp = 10.; + exp.p(1, 0, 1, 20.); + exp.p(1, 1, 1, 20.); + exp.syncToDevice(); + ASSERT_TRUE(exp.equalsTo(y)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2) -{ - double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; - NDArray a('c', {4,4}, {1,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7}, sd::DataType::FLOAT32); - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); +TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2) { + double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; + NDArray a('c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 2, 1, 0, 4, 7}, + sd::DataType::FLOAT32); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(1); - y.linspace(1); - auto result = x + y; + x.linspace(1); + y.linspace(1); + auto result = x + y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayCudaBasicsTests, assign_2) -{ - NDArray x('c', {4}, {1.5f,2.5f,3.5f,4.5f}, sd::DataType::FLOAT32); - NDArray y('c', {4}, sd::DataType::INT32); - NDArray expected('c', {4}, {1,2,3,4}, sd::DataType::INT32); +TEST_F(NDArrayCudaBasicsTests, assign_2) { + NDArray x('c', {4}, {1.5f, 2.5f, 3.5f, 4.5f}, sd::DataType::FLOAT32); + NDArray y('c', {4}, sd::DataType::INT32); + NDArray expected('c', {4}, {1, 2, 3, 4}, sd::DataType::INT32); - y.assign(x); - // y.printBuffer("ASSIGN VECTOR"); + y.assign(x); + // y.printBuffer("ASSIGN VECTOR"); - ASSERT_TRUE(expected.equalsTo(&y)); + ASSERT_TRUE(expected.equalsTo(&y)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayCudaBasicsTests, subarray_1) -{ - NDArray x('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); - NDArray y('f', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); - - Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 1, 99}; - float buffExpX0[] = {1.f, 13.f}; - Nd4jLong shapeExpX1[] = {1, 2, 12, 8192, 1, 99}; - float buffExpX1[] = {2.f, 14.f}; - Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 1, 99}; - float buffExpX2[] = {1.f, 13.f}; - Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 1, 99}; - float buffExpX3[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f}; - Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 1, 99}; - float buffExpX4[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f}; - Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 1, 99}; - float buffExpX5[] = {4.f, 8.f, 12.f, 16.f, 20.f, 24.f}; - - Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 99}; - float buffExpY0[] = {1.f, 2.f}; - Nd4jLong shapeExpY1[] = {1, 2, 1, 8192, 1, 99}; - float buffExpY1[] = {7.f, 8.f}; - Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102}; - float buffExpY2[] = {1.f, 2.f}; - Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 1, 99}; - float buffExpY3[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f}; - Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 1, 102}; - float buffExpY4[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f}; - Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 99}; - float buffExpY5[] = {19.f, 21.f, 23.f, 20.f, 22.f, 24.f}; - - - NDArray x0 = x(0, {1,2}); - NDArray xExp(buffExpX0, shapeExpX0); - - ASSERT_TRUE(xExp.isSameShape(x0)); - ASSERT_TRUE(xExp.equalsTo(x0)); -// for(int i = 0; i < shape::shapeInfoLength(x0.rankOf()); ++i) -// ASSERT_TRUE(x0.shapeInfo()[i] == shapeExpX0[i]); -// for(int i = 0; i < x0.lengthOf(); ++i) -// ASSERT_TRUE(x0.e(i) == buffExpX0[i]); - - NDArray x1 = x(1, {1,2}); - NDArray x1Exp(buffExpX1, shapeExpX1); - ASSERT_TRUE(x1Exp.isSameShape(x1)); - ASSERT_TRUE(x1Exp.equalsTo(x1)); - -// for(int i = 0; i < shape::shapeInfoLength(x1.rankOf()); ++i) -// ASSERT_TRUE(x1.shapeInfo()[i] == shapeExpX1[i]); -// for(int i = 0; i < x1.lengthOf(); ++i) -// ASSERT_TRUE(x1.e(i) == buffExpX1[i]); - - NDArray x2 = x(0, {1,2}, true); - NDArray x2Exp(buffExpX2, shapeExpX2); - ASSERT_TRUE(x2Exp.isSameShape(x2)); -// x2.printBuffer("X2"); -// x2Exp.printBuffer("X2 EXPECT"); - ASSERT_TRUE(x2Exp.equalsTo(x2)); -// for(int i = 0; i < shape::shapeInfoLength(x2.rankOf()); ++i) -// ASSERT_TRUE(x2.shapeInfo()[i] == shapeExpX2[i]); -// for(int i = 0; i < x2.lengthOf(); ++i) -// ASSERT_TRUE(x2.e(i) == buffExpX2[i]); - - NDArray x3 = x(2, {1}); - NDArray x3Exp(buffExpX3, shapeExpX3); - ASSERT_TRUE(x3Exp.isSameShape(x3)); - ASSERT_TRUE(x3Exp.equalsTo(x3)); -// for(int i = 0; i < shape::shapeInfoLength(x3.rankOf()); ++i) -// ASSERT_TRUE(x3.shapeInfo()[i] == shapeExpX3[i]); -// for(int i = 0; i < x3.lengthOf(); ++i) -// ASSERT_TRUE(x3.e(i) == buffExpX3[i]); - - NDArray x4 = x(2, {1}, true); - NDArray x4Exp(buffExpX4, shapeExpX4); - ASSERT_TRUE(x4Exp.isSameShape(x4)); - ASSERT_TRUE(x4Exp.equalsTo(x4)); -// for(int i = 0; i < shape::shapeInfoLength(x4.rankOf()); ++i) -// ASSERT_TRUE(x4.shapeInfo()[i] == shapeExpX4[i]); -// for(int i = 0; i < x4.lengthOf(); ++i) -// ASSERT_TRUE(x4.e(i) == buffExpX4[i]); - - NDArray x5 = x(3, {2}); - NDArray x5Exp(buffExpX5, shapeExpX5); - ASSERT_TRUE(x5Exp.isSameShape(x5)); - ASSERT_TRUE(x5Exp.equalsTo(x5)); - -// for(int i = 0; i < shape::shapeInfoLength(x5.rankOf()); ++i) -// ASSERT_TRUE(x5.shapeInfo()[i] == shapeExpX5[i]); -// for(int i = 0; i < x5.lengthOf(); ++i) -// ASSERT_TRUE(x5.e(i) == buffExpX5[i]); - - // ******************* // - NDArray y0 = y(0, {1,2}); - NDArray y0Exp(buffExpY0, shapeExpY0); - ASSERT_TRUE(y0Exp.isSameShape(y0)); - ASSERT_TRUE(y0Exp.equalsTo(y0)); -// for(int i = 0; i < shape::shapeInfoLength(y0.rankOf()); ++i) -// ASSERT_TRUE(y0.shapeInfo()[i] == shapeExpY0[i]); -// for(int i = 0; i < y0.lengthOf(); ++i) -// ASSERT_TRUE(y0.e(i) == buffExpY0[i]); - - NDArray y1 = y(1, {1,2}); - NDArray y1Exp(buffExpY1, shapeExpY1); - ASSERT_TRUE(y1Exp.isSameShape(y1)); - ASSERT_TRUE(y1Exp.equalsTo(y1)); -// for(int i = 0; i < shape::shapeInfoLength(y1.rankOf()); ++i) -// ASSERT_TRUE(y1.shapeInfo()[i] == shapeExpY1[i]); -// for(int i = 0; i < y1.lengthOf(); ++i) -// ASSERT_TRUE(y1.e(i) == buffExpY1[i]); - - NDArray y2 = y(0, {1,2}, true); - NDArray y2Exp(buffExpY2, shapeExpY2); - ASSERT_TRUE(y2Exp.isSameShape(y2)); - ASSERT_TRUE(y2Exp.equalsTo(y2)); -// for(int i = 0; i < shape::shapeInfoLength(y2.rankOf()); ++i) -// ASSERT_TRUE(y2.shapeInfo()[i] == shapeExpY2[i]); -// for(int i = 0; i < y2.lengthOf(); ++i) -// ASSERT_TRUE(y2.e(i) == buffExpY2[i]); - - NDArray y3 = y(2, {1}); - NDArray y3Exp(buffExpY3, shapeExpY3); - ASSERT_TRUE(y3Exp.isSameShape(y3)); - ASSERT_TRUE(y3Exp.equalsTo(y3)); -// for(int i = 0; i < shape::shapeInfoLength(y3.rankOf()); ++i) -// ASSERT_TRUE(y3.shapeInfo()[i] == shapeExpY3[i]); -// for(int i = 0; i < y3.lengthOf(); ++i) -// ASSERT_TRUE(y3.e(i) == buffExpY3[i]); - - NDArray y4 = y(2, {1}, true); - NDArray y4Exp = NDArrayFactory::create('f', {2,1,4}, {5, 6, 11, 12, 17, 18, 23, 24}); - ASSERT_TRUE(y4Exp.isSameShape(y4)); - ASSERT_TRUE(y4Exp.equalsTo(y4)); -// for(int i = 0; i < shape::shapeInfoLength(y4.rankOf()); ++i) -// ASSERT_TRUE(y4.shapeInfo()[i] == shapeExpY4[i]); -// for(int i = 0; i < y4.lengthOf(); ++i) -// ASSERT_TRUE(y4.e(i) == buffExpY4[i]); - - NDArray y5 = y(3, {2}); - NDArray y5Exp(buffExpY5, shapeExpY5); - ASSERT_TRUE(y5Exp.isSameShape(y5)); - ASSERT_TRUE(y5Exp.equalsTo(y5)); -// for(int i = 0; i < shape::shapeInfoLength(y5.rankOf()); ++i) -// ASSERT_TRUE(y5.shapeInfo()[i] == shapeExpY5[i]); -// for(int i = 0; i < y5.lengthOf(); ++i) -// ASSERT_TRUE(y5.e(i) == buffExpY5[i]); +TEST_F(NDArrayCudaBasicsTests, subarray_1) { + NDArray x('c', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); + NDArray y('f', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); + Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 1, 99}; + float buffExpX0[] = {1.f, 13.f}; + Nd4jLong shapeExpX1[] = {1, 2, 12, 8192, 1, 99}; + float buffExpX1[] = {2.f, 14.f}; + Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 1, 99}; + float buffExpX2[] = {1.f, 13.f}; + Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 1, 99}; + float buffExpX3[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f}; + Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 1, 99}; + float buffExpX4[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f}; + Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 1, 99}; + float buffExpX5[] = {4.f, 8.f, 12.f, 16.f, 20.f, 24.f}; + + Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 99}; + float buffExpY0[] = {1.f, 2.f}; + Nd4jLong shapeExpY1[] = {1, 2, 1, 8192, 1, 99}; + float buffExpY1[] = {7.f, 8.f}; + Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102}; + float buffExpY2[] = {1.f, 2.f}; + Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 1, 99}; + float buffExpY3[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f}; + Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 1, 102}; + float buffExpY4[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f}; + Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 99}; + float buffExpY5[] = {19.f, 21.f, 23.f, 20.f, 22.f, 24.f}; + + NDArray x0 = x(0, {1, 2}); + NDArray xExp(buffExpX0, shapeExpX0); + + ASSERT_TRUE(xExp.isSameShape(x0)); + ASSERT_TRUE(xExp.equalsTo(x0)); + // for(int i = 0; i < shape::shapeInfoLength(x0.rankOf()); ++i) + // ASSERT_TRUE(x0.shapeInfo()[i] == shapeExpX0[i]); + // for(int i = 0; i < x0.lengthOf(); ++i) + // ASSERT_TRUE(x0.e(i) == buffExpX0[i]); + + NDArray x1 = x(1, {1, 2}); + NDArray x1Exp(buffExpX1, shapeExpX1); + ASSERT_TRUE(x1Exp.isSameShape(x1)); + ASSERT_TRUE(x1Exp.equalsTo(x1)); + + // for(int i = 0; i < shape::shapeInfoLength(x1.rankOf()); ++i) + // ASSERT_TRUE(x1.shapeInfo()[i] == shapeExpX1[i]); + // for(int i = 0; i < x1.lengthOf(); ++i) + // ASSERT_TRUE(x1.e(i) == buffExpX1[i]); + + NDArray x2 = x(0, {1, 2}, true); + NDArray x2Exp(buffExpX2, shapeExpX2); + ASSERT_TRUE(x2Exp.isSameShape(x2)); + // x2.printBuffer("X2"); + // x2Exp.printBuffer("X2 EXPECT"); + ASSERT_TRUE(x2Exp.equalsTo(x2)); + // for(int i = 0; i < shape::shapeInfoLength(x2.rankOf()); ++i) + // ASSERT_TRUE(x2.shapeInfo()[i] == shapeExpX2[i]); + // for(int i = 0; i < x2.lengthOf(); ++i) + // ASSERT_TRUE(x2.e(i) == buffExpX2[i]); + + NDArray x3 = x(2, {1}); + NDArray x3Exp(buffExpX3, shapeExpX3); + ASSERT_TRUE(x3Exp.isSameShape(x3)); + ASSERT_TRUE(x3Exp.equalsTo(x3)); + // for(int i = 0; i < shape::shapeInfoLength(x3.rankOf()); ++i) + // ASSERT_TRUE(x3.shapeInfo()[i] == shapeExpX3[i]); + // for(int i = 0; i < x3.lengthOf(); ++i) + // ASSERT_TRUE(x3.e(i) == buffExpX3[i]); + + NDArray x4 = x(2, {1}, true); + NDArray x4Exp(buffExpX4, shapeExpX4); + ASSERT_TRUE(x4Exp.isSameShape(x4)); + ASSERT_TRUE(x4Exp.equalsTo(x4)); + // for(int i = 0; i < shape::shapeInfoLength(x4.rankOf()); ++i) + // ASSERT_TRUE(x4.shapeInfo()[i] == shapeExpX4[i]); + // for(int i = 0; i < x4.lengthOf(); ++i) + // ASSERT_TRUE(x4.e(i) == buffExpX4[i]); + + NDArray x5 = x(3, {2}); + NDArray x5Exp(buffExpX5, shapeExpX5); + ASSERT_TRUE(x5Exp.isSameShape(x5)); + ASSERT_TRUE(x5Exp.equalsTo(x5)); + + // for(int i = 0; i < shape::shapeInfoLength(x5.rankOf()); ++i) + // ASSERT_TRUE(x5.shapeInfo()[i] == shapeExpX5[i]); + // for(int i = 0; i < x5.lengthOf(); ++i) + // ASSERT_TRUE(x5.e(i) == buffExpX5[i]); + + // ******************* // + NDArray y0 = y(0, {1, 2}); + NDArray y0Exp(buffExpY0, shapeExpY0); + ASSERT_TRUE(y0Exp.isSameShape(y0)); + ASSERT_TRUE(y0Exp.equalsTo(y0)); + // for(int i = 0; i < shape::shapeInfoLength(y0.rankOf()); ++i) + // ASSERT_TRUE(y0.shapeInfo()[i] == shapeExpY0[i]); + // for(int i = 0; i < y0.lengthOf(); ++i) + // ASSERT_TRUE(y0.e(i) == buffExpY0[i]); + + NDArray y1 = y(1, {1, 2}); + NDArray y1Exp(buffExpY1, shapeExpY1); + ASSERT_TRUE(y1Exp.isSameShape(y1)); + ASSERT_TRUE(y1Exp.equalsTo(y1)); + // for(int i = 0; i < shape::shapeInfoLength(y1.rankOf()); ++i) + // ASSERT_TRUE(y1.shapeInfo()[i] == shapeExpY1[i]); + // for(int i = 0; i < y1.lengthOf(); ++i) + // ASSERT_TRUE(y1.e(i) == buffExpY1[i]); + + NDArray y2 = y(0, {1, 2}, true); + NDArray y2Exp(buffExpY2, shapeExpY2); + ASSERT_TRUE(y2Exp.isSameShape(y2)); + ASSERT_TRUE(y2Exp.equalsTo(y2)); + // for(int i = 0; i < shape::shapeInfoLength(y2.rankOf()); ++i) + // ASSERT_TRUE(y2.shapeInfo()[i] == shapeExpY2[i]); + // for(int i = 0; i < y2.lengthOf(); ++i) + // ASSERT_TRUE(y2.e(i) == buffExpY2[i]); + + NDArray y3 = y(2, {1}); + NDArray y3Exp(buffExpY3, shapeExpY3); + ASSERT_TRUE(y3Exp.isSameShape(y3)); + ASSERT_TRUE(y3Exp.equalsTo(y3)); + // for(int i = 0; i < shape::shapeInfoLength(y3.rankOf()); ++i) + // ASSERT_TRUE(y3.shapeInfo()[i] == shapeExpY3[i]); + // for(int i = 0; i < y3.lengthOf(); ++i) + // ASSERT_TRUE(y3.e(i) == buffExpY3[i]); + + NDArray y4 = y(2, {1}, true); + NDArray y4Exp = NDArrayFactory::create('f', {2, 1, 4}, + {5, 6, 11, 12, 17, 18, 23, 24}); + ASSERT_TRUE(y4Exp.isSameShape(y4)); + ASSERT_TRUE(y4Exp.equalsTo(y4)); + // for(int i = 0; i < shape::shapeInfoLength(y4.rankOf()); ++i) + // ASSERT_TRUE(y4.shapeInfo()[i] == shapeExpY4[i]); + // for(int i = 0; i < y4.lengthOf(); ++i) + // ASSERT_TRUE(y4.e(i) == buffExpY4[i]); + + NDArray y5 = y(3, {2}); + NDArray y5Exp(buffExpY5, shapeExpY5); + ASSERT_TRUE(y5Exp.isSameShape(y5)); + ASSERT_TRUE(y5Exp.equalsTo(y5)); + // for(int i = 0; i < shape::shapeInfoLength(y5.rankOf()); ++i) + // ASSERT_TRUE(y5.shapeInfo()[i] == shapeExpY5[i]); + // for(int i = 0; i < y5.lengthOf(); ++i) + // ASSERT_TRUE(y5.e(i) == buffExpY5[i]); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, Test_diagonal_1) { - - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - auto exp = NDArrayFactory::create('c', {2, 1}, {1, 5}); - - auto diag = x.diagonal('c'); - //diag.syncToDevice(); - for (Nd4jLong e = 0; e < exp.lengthOf(); ++e) { - printf("VAL[%ld] = %f\n", e, diag.e(e)); //, exp.e(e), 1.e-5); - } - - for (Nd4jLong e = 0; e < exp.lengthOf(); ++e) { - ASSERT_NEAR(diag.e(e), exp.e(e), 1.e-5); - } - double eps(1.e-5); - NDArray tmp(sd::DataType::FLOAT32, x.getContext()); // scalar = 0 - - ExtraArguments extras({eps}); - NativeOpExecutioner::execReduce3Scalar(diag.getContext(), reduce3::EqualsWithEps, diag.buffer(), - diag.shapeInfo(), diag.specialBuffer(), diag.specialShapeInfo(), extras.argumentsAsT(sd::DataType::FLOAT32), - exp.buffer(), exp.shapeInfo(), exp.specialBuffer(), exp.specialShapeInfo(), - tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo()); - cudaStream_t* stream = x.getContext()->getCudaStream(); - auto res = cudaStreamSynchronize(*stream); - // tmp.printBuffer("Compare result is (expected 0)"); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto exp = NDArrayFactory::create('c', {2, 1}, {1, 5}); + + auto diag = x.diagonal('c'); + // diag.syncToDevice(); + for (Nd4jLong e = 0; e < exp.lengthOf(); ++e) { + printf("VAL[%ld] = %f\n", e, + diag.e(e)); //, exp.e(e), 1.e-5); + } + + for (Nd4jLong e = 0; e < exp.lengthOf(); ++e) { + ASSERT_NEAR(diag.e(e), exp.e(e), 1.e-5); + } + double eps(1.e-5); + NDArray tmp(sd::DataType::FLOAT32, x.getContext()); // scalar = 0 + + ExtraArguments extras({eps}); + NativeOpExecutioner::execReduce3Scalar( + diag.getContext(), reduce3::EqualsWithEps, diag.buffer(), + diag.shapeInfo(), diag.specialBuffer(), diag.specialShapeInfo(), + extras.argumentsAsT(sd::DataType::FLOAT32), exp.buffer(), exp.shapeInfo(), + exp.specialBuffer(), exp.specialShapeInfo(), tmp.buffer(), + tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo()); + cudaStream_t* stream = x.getContext()->getCudaStream(); + auto res = cudaStreamSynchronize(*stream); + // tmp.printBuffer("Compare result is (expected 0)"); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) { - auto x = NDArrayFactory::linspace(1.f, 60.f, 60); //('c', {1, 60}); - //x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); - x->reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::linspace(1.f, 60.f, 60); //('c', {1, 60}); + // x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, + 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, + 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, + 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, + 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); + x->reshapei('c', {3, 4, 5}); - x->permutei({0, 1, 2}); - x->streamline(); + x->permutei({0, 1, 2}); + x->streamline(); -// x.printShapeInfo("{0, 1, 2} shape"); -// x.printBuffer("{0, 1, 2} data"); + // x.printShapeInfo("{0, 1, 2} shape"); + // x.printBuffer("{0, 1, 2} data"); - ASSERT_TRUE(exp.isSameShape(x)); - ASSERT_TRUE(exp.equalsTo(x)); - delete x; + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); + delete x; } TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_0) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); - x.reshapei('c', {3, 4, 5}); - - x.permutei({0, 1, 2}); - x.streamline(); - -// x.printShapeInfo("{0, 1, 2} shape"); -// x.printBuffer("{0, 1, 2} data"); - - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, + 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, + 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, + 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, + 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); + x.reshapei('c', {3, 4, 5}); + + x.permutei({0, 1, 2}); + x.streamline(); + + // x.printShapeInfo("{0, 1, 2} shape"); + // x.printBuffer("{0, 1, 2} data"); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_1) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); - x.reshapei('c', {3, 4, 5}); - - x.permutei({0, 1, 2}); - x.streamline(); - -// x.printShapeInfo("{0, 1, 2} shape"); -// x.printBuffer("{0, 1, 2} data"); - - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, + 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, + 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, + 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, + 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); + x.reshapei('c', {3, 4, 5}); + + x.permutei({0, 1, 2}); + x.streamline(); + + // x.printShapeInfo("{0, 1, 2} shape"); + // x.printBuffer("{0, 1, 2} data"); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_2) { - //auto x = NDArrayFactory::create('c', {1, 60}); - auto xx = NDArrayFactory::linspace(1.f, 60.f, 60); //('c', {1, 60}); -// auto x = *xx; - //x.linspace(1); -// auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); -// x.reshapei('c', {3, 4, 5}); - -// x.permutei({0, 1, 2}); -// x.streamline(); - -// x.printShapeInfo("{0, 1, 2} shape"); -// x.printBuffer("{0, 1, 2} data"); - -// ASSERT_TRUE(exp.isSameShape(&x)); -// ASSERT_TRUE(exp.equalsTo(&x)); - delete xx; + // auto x = NDArrayFactory::create('c', {1, 60}); + auto xx = NDArrayFactory::linspace(1.f, 60.f, 60); //('c', {1, 60}); + // auto x = *xx; + // x.linspace(1); + // auto exp = NDArrayFactory::create('c', {3, 4, 5}, + // {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, + // 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, + // 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, + // 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, + // 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, + // 57.0f, 58.0f, 59.0f, 60.0}); x.reshapei('c', {3, 4, 5}); + + // x.permutei({0, 1, 2}); + // x.streamline(); + + // x.printShapeInfo("{0, 1, 2} shape"); + // x.printBuffer("{0, 1, 2} data"); + + // ASSERT_TRUE(exp.isSameShape(&x)); + // ASSERT_TRUE(exp.equalsTo(&x)); + delete xx; } TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_3) { - auto x = NDArrayFactory::create('c', {1, 60}); - //x.linspace(1); - for (int l = 0; l < x.lengthOf(); l++) - x.p(l, float(l + 1.f)); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + // x.linspace(1); + for (int l = 0; l < x.lengthOf(); l++) x.p(l, float(l + 1.f)); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, + 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, + 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, + 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, + 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); + x.reshapei('c', {3, 4, 5}); - x.permutei({0, 1, 2}); - x.streamline(); + x.permutei({0, 1, 2}); + x.streamline(); -// x.printShapeInfo("{0, 1, 2} shape"); -// x.printBuffer("{0, 1, 2} data"); + // x.printShapeInfo("{0, 1, 2} shape"); + // x.printBuffer("{0, 1, 2} data"); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayCudaBasicsTests, Test_Empty_1) { - auto x = NDArrayFactory::empty(); - ASSERT_TRUE(x.isActualOnHostSide()); - ASSERT_TRUE(x.isEmpty()); + auto x = NDArrayFactory::empty(); + ASSERT_TRUE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isEmpty()); } TEST_F(NDArrayCudaBasicsTests, Test_Empty_2) { - auto x = NDArrayFactory::empty_(); + auto x = NDArrayFactory::empty_(); - ASSERT_TRUE(x->isEmpty()); - delete x; + ASSERT_TRUE(x->isEmpty()); + delete x; } TEST_F(NDArrayCudaBasicsTests, Test_Empty_3) { - auto x = NDArrayFactory::empty(sd::DataType::FLOAT32); + auto x = NDArrayFactory::empty(sd::DataType::FLOAT32); - ASSERT_TRUE(x.isEmpty()); + ASSERT_TRUE(x.isEmpty()); } TEST_F(NDArrayCudaBasicsTests, Test_Empty_4) { - auto x = NDArrayFactory::empty_(sd::DataType::FLOAT32); + auto x = NDArrayFactory::empty_(sd::DataType::FLOAT32); - ASSERT_TRUE(x->isEmpty()); - delete x; + ASSERT_TRUE(x->isEmpty()); + delete x; } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp index 0400f4b90bf8..7110561e5ef8 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayListTests.cpp @@ -20,52 +20,50 @@ #include #include + #include "testlayers.h" using namespace sd; class NDArrayListTests : public testing::Test { -public: - + public: }; - TEST_F(NDArrayListTests, BasicTests_1) { - NDArrayList list(false); + NDArrayList list(false); - auto x = NDArrayFactory::create('c', {1, 10}); - auto y = NDArrayFactory::create('c', {1, 10}); + auto x = NDArrayFactory::create('c', {1, 10}); + auto y = NDArrayFactory::create('c', {1, 10}); - ASSERT_EQ(ND4J_STATUS_OK, list.write(1, x.dup())); + ASSERT_EQ(ND4J_STATUS_OK, list.write(1, x.dup())); - //ASSERT_EQ(ND4J_STATUS_DOUBLE_WRITE, list.write(1, &y)); + // ASSERT_EQ(ND4J_STATUS_DOUBLE_WRITE, list.write(1, &y)); } TEST_F(NDArrayListTests, BasicTests_2) { - NDArrayList list(false); + NDArrayList list(false); - auto x = NDArrayFactory::create('c', {1, 10}); - auto y = NDArrayFactory::create('c', {1, 7}); + auto x = NDArrayFactory::create('c', {1, 10}); + auto y = NDArrayFactory::create('c', {1, 7}); - ASSERT_EQ(ND4J_STATUS_OK, list.write(1, x.dup())); + ASSERT_EQ(ND4J_STATUS_OK, list.write(1, x.dup())); - ASSERT_EQ(ND4J_STATUS_BAD_INPUT, list.write(0, y)); + ASSERT_EQ(ND4J_STATUS_BAD_INPUT, list.write(0, y)); } - TEST_F(NDArrayListTests, Test_Stack_UnStack_1) { - auto input = NDArrayFactory::create('c', {10, 10}); - input.linspace(1); + auto input = NDArrayFactory::create('c', {10, 10}); + input.linspace(1); - NDArrayList list(false); + NDArrayList list(false); - list.unstack(input, 0); + list.unstack(input, 0); - ASSERT_EQ(10, list.elements()); + ASSERT_EQ(10, list.elements()); - auto array = list.stack(); + auto array = list.stack(); - ASSERT_TRUE(input.isSameShape(array)); + ASSERT_TRUE(input.isSameShape(array)); - ASSERT_TRUE(input.equalsTo(array)); + ASSERT_TRUE(input.equalsTo(array)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp index 669574fa7cc2..c85ade5d18c8 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp @@ -18,1410 +18,1484 @@ // Created by raver119 on 04.08.17. // -#include "testlayers.h" -#include #include #include +#include + +#include "testlayers.h" + using namespace sd; ////////////////////////////////////////////////////////////////////// class NDArrayTest : public testing::Test { -public: - int alpha = 0; + public: + int alpha = 0; - Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; - Nd4jLong *fShape = new Nd4jLong[8]{2, 2, 2, 1, 2, 8192, 1, 102}; + Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong *fShape = new Nd4jLong[8]{2, 2, 2, 1, 2, 8192, 1, 102}; - float arr1[6] = {1,2,3,4,5,6}; - Nd4jLong shape1[8] = {2,2,3,3,1,8192,1,99}; - float arr2[48] = {1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6}; - Nd4jLong shape2[10] = {3,2,4,6,24,6,1,8192,1,99}; - const std::vector tileShape1 = {2,2,2}; + float arr1[6] = {1, 2, 3, 4, 5, 6}; + Nd4jLong shape1[8] = {2, 2, 3, 3, 1, 8192, 1, 99}; + float arr2[48] = {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 1, 2, 3, 1, + 2, 3, 4, 5, 6, 4, 5, 6, 1, 2, 3, 1, 2, 3, 4, 5, + 6, 4, 5, 6, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}; + Nd4jLong shape2[10] = {3, 2, 4, 6, 24, 6, 1, 8192, 1, 99}; + const std::vector tileShape1 = {2, 2, 2}; - - ~NDArrayTest() { - delete[] cShape; - delete[] fShape; - } + ~NDArrayTest() { + delete[] cShape; + delete[] fShape; + } }; - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestDup1) { + NDArray array(arr1, shape1); - NDArray array(arr1, shape1); + auto arrC = new NDArray(array.dup('c')); + auto arrF = new NDArray(array.dup('f')); - auto arrC = new NDArray(array.dup('c')); - auto arrF = new NDArray(array.dup('f')); + ASSERT_TRUE(array.equalsTo(arrF)); + ASSERT_TRUE(array.equalsTo(arrC)); - ASSERT_TRUE(array.equalsTo(arrF)); - ASSERT_TRUE(array.equalsTo(arrC)); + ASSERT_TRUE(arrF->equalsTo(arrC)); - ASSERT_TRUE(arrF->equalsTo(arrC)); - - delete arrC; - delete arrF; + delete arrC; + delete arrF; } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, AssignScalar1) { - auto array = NDArrayFactory::create_('c', {1, 10}); + auto array = NDArrayFactory::create_('c', {1, 10}); - array->assign(2.0f); + array->assign(2.0f); - for (int i = 0; i < array->lengthOf(); i++) { - ASSERT_EQ(2.0f, array->e(i)); - } + for (int i = 0; i < array->lengthOf(); i++) { + ASSERT_EQ(2.0f, array->e(i)); + } - delete array; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, NDArrayOrder1) { - // original part - auto c = new float[4] {1, 2, 3, 4}; + // original part + auto c = new float[4]{1, 2, 3, 4}; - // expected part - auto f = new float[4] {1, 3, 2, 4}; + // expected part + auto f = new float[4]{1, 3, 2, 4}; - auto arrayC = new NDArray(c, cShape); - auto arrayF = new NDArray(arrayC->dup('f')); - auto arrayC2 = new NDArray(arrayF->dup('c')); + auto arrayC = new NDArray(c, cShape); + auto arrayF = new NDArray(arrayC->dup('f')); + auto arrayC2 = new NDArray(arrayF->dup('c')); - ASSERT_EQ('c', arrayC->ordering()); - ASSERT_EQ('f', arrayF->ordering()); - ASSERT_EQ('c', arrayC2->ordering()); + ASSERT_EQ('c', arrayC->ordering()); + ASSERT_EQ('f', arrayF->ordering()); + ASSERT_EQ('c', arrayC2->ordering()); - for (int i = 0; i < 4; i++) { - ASSERT_NEAR(f[i], arrayF->bufferAsT()[i], 1e-5f); - } + for (int i = 0; i < 4; i++) { + ASSERT_NEAR(f[i], arrayF->bufferAsT()[i], 1e-5f); + } - for (int i = 0; i < 8; i++) { - ASSERT_EQ(fShape[i], arrayF->shapeInfo()[i]); - } + for (int i = 0; i < 8; i++) { + ASSERT_EQ(fShape[i], arrayF->shapeInfo()[i]); + } - for (int i = 0; i < 4; i++) { - ASSERT_NEAR(c[i], arrayC2->bufferAsT()[i], 1e-5f); - } - - for (int i = 0; i < 8; i++) { - ASSERT_EQ(cShape[i], arrayC2->shapeInfo()[i]); - } + for (int i = 0; i < 4; i++) { + ASSERT_NEAR(c[i], arrayC2->bufferAsT()[i], 1e-5f); + } + for (int i = 0; i < 8; i++) { + ASSERT_EQ(cShape[i], arrayC2->shapeInfo()[i]); + } - delete[] c; - delete[] f; - delete arrayC; - delete arrayF; - delete arrayC2; + delete[] c; + delete[] f; + delete arrayC; + delete arrayF; + delete arrayC2; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestGetScalar1) { - auto c = new float[4] {1, 2, 3, 4}; - auto cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; - - auto arrayC = new NDArray(c, cShape); - - ASSERT_NEAR(3.0f, arrayC->e(1, 0), 1e-5f); - ASSERT_NEAR(4.0f, arrayC->e(1, 1), 1e-5f); + auto c = new float[4]{1, 2, 3, 4}; + auto cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; - auto arrayF = new NDArray(arrayC->dup('f')); + auto arrayC = new NDArray(c, cShape); - ASSERT_NEAR(3.0f, arrayF->e(1, 0), 1e-5f); - ASSERT_NEAR(4.0f, arrayF->e(1, 1), 1e-5f); + ASSERT_NEAR(3.0f, arrayC->e(1, 0), 1e-5f); + ASSERT_NEAR(4.0f, arrayC->e(1, 1), 1e-5f); + auto arrayF = new NDArray(arrayC->dup('f')); - arrayF->p(1, 0, 7.0f); - ASSERT_NEAR(7.0f, arrayF->e(1, 0), 1e-5f); + ASSERT_NEAR(3.0f, arrayF->e(1, 0), 1e-5f); + ASSERT_NEAR(4.0f, arrayF->e(1, 1), 1e-5f); + arrayF->p(1, 0, 7.0f); + ASSERT_NEAR(7.0f, arrayF->e(1, 0), 1e-5f); - arrayC->p(1, 1, 9.0f); - ASSERT_NEAR(9.0f, arrayC->e(1, 1), 1e-5f); + arrayC->p(1, 1, 9.0f); + ASSERT_NEAR(9.0f, arrayC->e(1, 1), 1e-5f); - delete[] c; - delete[] cShape; + delete[] c; + delete[] cShape; - delete arrayC; - delete arrayF; + delete arrayC; + delete arrayF; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, EqualityTest1) { - auto arrayA = NDArrayFactory::create_('f', {3, 5}); - auto arrayB = NDArrayFactory::create_('f', {3, 5}); - auto arrayC = NDArrayFactory::create_('f', {3, 5}); + auto arrayA = NDArrayFactory::create_('f', {3, 5}); + auto arrayB = NDArrayFactory::create_('f', {3, 5}); + auto arrayC = NDArrayFactory::create_('f', {3, 5}); - auto arrayD = NDArrayFactory::create_('f', {2, 4}); - auto arrayE = NDArrayFactory::create_('f', {1, 15}); + auto arrayD = NDArrayFactory::create_('f', {2, 4}); + auto arrayE = NDArrayFactory::create_('f', {1, 15}); - for (int i = 0; i < arrayA->rows(); i++) { - for (int k = 0; k < arrayA->columns(); k++) { - arrayA->p(i, k, (float) i); - } + for (int i = 0; i < arrayA->rows(); i++) { + for (int k = 0; k < arrayA->columns(); k++) { + arrayA->p(i, k, (float)i); } + } - for (int i = 0; i < arrayB->rows(); i++) { - for (int k = 0; k < arrayB->columns(); k++) { - arrayB->p(i, k, (float) i); - } + for (int i = 0; i < arrayB->rows(); i++) { + for (int k = 0; k < arrayB->columns(); k++) { + arrayB->p(i, k, (float)i); } + } - for (int i = 0; i < arrayC->rows(); i++) { - for (int k = 0; k < arrayC->columns(); k++) { - arrayC->p(i, k, (float) i+1); - } + for (int i = 0; i < arrayC->rows(); i++) { + for (int k = 0; k < arrayC->columns(); k++) { + arrayC->p(i, k, (float)i + 1); } + } - //nd4j_printf("A B\n",""); - ASSERT_TRUE(arrayA->equalsTo(arrayB, 1e-5)); + // nd4j_printf("A B\n",""); + ASSERT_TRUE(arrayA->equalsTo(arrayB, 1e-5)); - //nd4j_printf("C B\n",""); - ASSERT_FALSE(arrayC->equalsTo(arrayB, 1e-5)); + // nd4j_printf("C B\n",""); + ASSERT_FALSE(arrayC->equalsTo(arrayB, 1e-5)); - //nd4j_printf("D B\n",""); - ASSERT_FALSE(arrayD->equalsTo(arrayB, 1e-5)); + // nd4j_printf("D B\n",""); + ASSERT_FALSE(arrayD->equalsTo(arrayB, 1e-5)); - //nd4j_printf("E B\n",""); - ASSERT_FALSE(arrayE->equalsTo(arrayB, 1e-5)); + // nd4j_printf("E B\n",""); + ASSERT_FALSE(arrayE->equalsTo(arrayB, 1e-5)); - delete arrayA; - delete arrayB; - delete arrayC; - delete arrayD; - delete arrayE; + delete arrayA; + delete arrayB; + delete arrayC; + delete arrayD; + delete arrayE; } TEST_F(NDArrayTest, TestTad1) { - auto array = NDArrayFactory::create_('c', {3, 3}); + auto array = NDArrayFactory::create_('c', {3, 3}); - auto row2 = (*array)(1, {0}); + auto row2 = (*array)(1, {0}); - ASSERT_TRUE(row2.isView()); - ASSERT_EQ(3, row2.lengthOf()); + ASSERT_TRUE(row2.isView()); + ASSERT_EQ(3, row2.lengthOf()); - row2.assign(1.0); + row2.assign(1.0); - ASSERT_NEAR(3.0f, array->sumNumber().e(0), 1e-5); - delete array; + ASSERT_NEAR(3.0f, array->sumNumber().e(0), 1e-5); + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTad2) { - auto array = NDArrayFactory::create_('c', {3, 3}); + auto array = NDArrayFactory::create_('c', {3, 3}); - ASSERT_EQ(3, array->tensorsAlongDimension({1})); + ASSERT_EQ(3, array->tensorsAlongDimension({1})); - delete array; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTad3) { - auto array = NDArrayFactory::create_('c', {4, 3}); + auto array = NDArrayFactory::create_('c', {4, 3}); - auto row2 = (*array)(1, {0}); + auto row2 = (*array)(1, {0}); - ASSERT_TRUE(row2.isView()); - ASSERT_EQ(3, row2.lengthOf()); - delete array; + ASSERT_TRUE(row2.isView()); + ASSERT_EQ(3, row2.lengthOf()); + delete array; } - TEST_F(NDArrayTest, TestPermuteReshape1) { + NDArray array('c', {2, 2, 5, 5}, sd::DataType::FLOAT32); + int pShape[] = {4, 2, 5, 5, 2, 25, 5, 1, 50, 8192, 0, 99}; + int rShape[] = {3, 2, 25, 2, 25, 1, 50, 8192, 0, 99}; - NDArray array('c', {2, 2, 5, 5}, sd::DataType::FLOAT32); - int pShape[] = {4, 2, 5, 5, 2, 25, 5, 1, 50, 8192, 0, 99}; - int rShape[] = {3, 2, 25, 2, 25, 1, 50, 8192, 0, 99}; + array.permutei({1, 2, 3, 0}); - array.permutei({1, 2, 3, 0}); + for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) + ASSERT_EQ(pShape[e], array.shapeInfo()[e]); - for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) - ASSERT_EQ(pShape[e], array.shapeInfo()[e]); + array.reshapei('c', {2, 25, 2}); - array.reshapei('c', {2, 25, 2}); - - for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) - ASSERT_EQ(rShape[e], array.shapeInfo()[e]); + for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) + ASSERT_EQ(rShape[e], array.shapeInfo()[e]); } - TEST_F(NDArrayTest, TestPermuteReshape2) { - auto array = NDArrayFactory::create('c', {2, 2, 5, 5, 6, 6}); - int pShape[] = {6, 2, 2, 6, 6, 5, 5, 900, 1800, 6, 1, 180, 36, 8192, 0, 99}; - int rShape[] = {3, 2, 72, 25, 1800, 25, 1, 8192, 1, 99}; + auto array = NDArrayFactory::create('c', {2, 2, 5, 5, 6, 6}); + int pShape[] = {6, 2, 2, 6, 6, 5, 5, 900, 1800, 6, 1, 180, 36, 8192, 0, 99}; + int rShape[] = {3, 2, 72, 25, 1800, 25, 1, 8192, 1, 99}; + // array.printShapeInfo("before"); - // array.printShapeInfo("before"); + array.permutei({1, 0, 4, 5, 2, 3}); - array.permutei({1, 0, 4, 5, 2, 3}); + // array.printShapeInfo("after "); - // array.printShapeInfo("after "); + auto aShape = array.shapeInfo(); - auto aShape = array.shapeInfo(); + for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) + ASSERT_EQ(pShape[e], aShape[e]); - for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) - ASSERT_EQ(pShape[e], aShape[e]); + array.reshapei('c', {2, 72, 25}); - array.reshapei('c', {2, 72, 25}); - - for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) - ASSERT_EQ(rShape[e], array.shapeInfo()[e]); + for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) + ASSERT_EQ(rShape[e], array.shapeInfo()[e]); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestRepeat1) { + auto eBuffer = new float[8]{1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0}; + auto eShape = new Nd4jLong[8]{2, 4, 2, 2, 1, 8192, 1, 99}; + NDArray array('c', {2, 2}, sd::DataType::FLOAT32); + auto exp = new NDArray(eBuffer, eShape); + for (int e = 0; e < array.lengthOf(); e++) array.p(e, e + 1); - auto eBuffer = new float[8] {1.0,2.0,1.0,2.0,3.0,4.0,3.0,4.0}; - auto eShape = new Nd4jLong[8]{2, 4, 2, 2, 1, 8192, 1, 99}; - NDArray array('c', {2, 2}, sd::DataType::FLOAT32); - auto exp = new NDArray(eBuffer, eShape); - for (int e = 0; e < array.lengthOf(); e++) - array.p(e, e + 1); - - // array.printBuffer(); + // array.printBuffer(); - auto rep = array.repeat(0, {2}); + auto rep = array.repeat(0, {2}); - ASSERT_EQ(4, rep.sizeAt(0)); - ASSERT_EQ(2, rep.sizeAt(1)); + ASSERT_EQ(4, rep.sizeAt(0)); + ASSERT_EQ(2, rep.sizeAt(1)); - ASSERT_TRUE(exp->equalsTo(rep)); + ASSERT_TRUE(exp->equalsTo(rep)); - delete[] eBuffer; - delete[] eShape; - delete exp; + delete[] eBuffer; + delete[] eShape; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestRepeat2) { - auto eBuffer = new float[8] {1.0,2.0,1.0,2.0,3.0,4.0,3.0,4.0}; - auto eShape = new Nd4jLong[8]{2, 4, 2, 2, 1, 8192, 1, 99}; - auto array = NDArrayFactory::create_('c', {2, 2}); - auto exp = new NDArray(eBuffer, eShape); - for (int e = 0; e < array->lengthOf(); e++) - array->p(e, e + 1); + auto eBuffer = new float[8]{1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0}; + auto eShape = new Nd4jLong[8]{2, 4, 2, 2, 1, 8192, 1, 99}; + auto array = NDArrayFactory::create_('c', {2, 2}); + auto exp = new NDArray(eBuffer, eShape); + for (int e = 0; e < array->lengthOf(); e++) array->p(e, e + 1); - //array->printBuffer(); + // array->printBuffer(); - auto rep = new NDArray(exp->dup()); - rep->assign(0.); - array->repeat(0, {2}, *rep); - //rep->printIndexedBuffer("Repeated"); + auto rep = new NDArray(exp->dup()); + rep->assign(0.); + array->repeat(0, {2}, *rep); + // rep->printIndexedBuffer("Repeated"); - ASSERT_EQ(4, rep->sizeAt(0)); - ASSERT_EQ(2, rep->sizeAt(1)); + ASSERT_EQ(4, rep->sizeAt(0)); + ASSERT_EQ(2, rep->sizeAt(1)); - //rep->printBuffer(); + // rep->printBuffer(); - ASSERT_TRUE(exp->equalsTo(rep)); + ASSERT_TRUE(exp->equalsTo(rep)); - delete[] eBuffer; - delete[] eShape; - delete array; - delete exp; - delete rep; + delete[] eBuffer; + delete[] eShape; + delete array; + delete exp; + delete rep; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestIndexedPut1) { - auto array = NDArrayFactory::create_('f', {3, 3}); + auto array = NDArrayFactory::create_('f', {3, 3}); - array->p(4, 1.0f); - ASSERT_EQ(1.0f, array->e(4)); - //array->printBuffer(); + array->p(4, 1.0f); + ASSERT_EQ(1.0f, array->e(4)); + // array->printBuffer(); - delete array; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestSum1) { - // Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; - float *c = new float[4] {1, 2, 3, 4}; + // Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; + float *c = new float[4]{1, 2, 3, 4}; - auto array = new NDArray(c, cShape); + auto array = new NDArray(c, cShape); - ASSERT_EQ(10.0f, array->sumNumber().e(0)); - ASSERT_EQ(2.5f, array->meanNumber().e(0)); + ASSERT_EQ(10.0f, array->sumNumber().e(0)); + ASSERT_EQ(2.5f, array->meanNumber().e(0)); - delete[] c; - delete array; + delete[] c; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestAddiRowVector) { - float *c = new float[4] {1, 2, 3, 4}; - float *e = new float[4] {2, 3, 4, 5}; + float *c = new float[4]{1, 2, 3, 4}; + float *e = new float[4]{2, 3, 4, 5}; - auto array = new NDArray(c, cShape); - auto row = NDArrayFactory::create_('c', {1, 2}); - auto exp = new NDArray(e, cShape); - row->assign(1.0f); + auto array = new NDArray(c, cShape); + auto row = NDArrayFactory::create_('c', {1, 2}); + auto exp = new NDArray(e, cShape); + row->assign(1.0f); - array->addiRowVector(*row); + array->addiRowVector(*row); - ASSERT_TRUE(exp->equalsTo(array)); + ASSERT_TRUE(exp->equalsTo(array)); - delete[] c; - delete[] e; + delete[] c; + delete[] e; - delete array; - delete row; - delete exp; + delete array; + delete row; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestAddiColumnVector) { - float arr1[] = {1, 2, 3, 4}; - float arr2[] = {5, 6}; - float arr3[] = {6, 7, 9, 10}; - Nd4jLong shape1[] = {2,2,2,2,1,8192,1,99}; - Nd4jLong shape2[] = {2,2,1,1,1,8192,1,99}; - NDArray matrix(arr1, shape1); - NDArray column(arr2, shape2); - NDArray exp(arr3, shape1); + float arr1[] = {1, 2, 3, 4}; + float arr2[] = {5, 6}; + float arr3[] = {6, 7, 9, 10}; + Nd4jLong shape1[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong shape2[] = {2, 2, 1, 1, 1, 8192, 1, 99}; + NDArray matrix(arr1, shape1); + NDArray column(arr2, shape2); + NDArray exp(arr3, shape1); - matrix.addiColumnVector(column); - ASSERT_TRUE(exp.isSameShapeStrict(matrix)); - ASSERT_TRUE(exp.equalsTo(&matrix)); + matrix.addiColumnVector(column); + ASSERT_TRUE(exp.isSameShapeStrict(matrix)); + ASSERT_TRUE(exp.equalsTo(&matrix)); } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMuliColumnVector) { - float arr1[] = {1, 2, 3, 4}; - float arr2[] = {5, 6}; - float arr3[] = {5, 10, 18, 24}; - Nd4jLong shape1[] = {2,2,2,2,1,8192,1,99}; - Nd4jLong shape2[] = {2,2,1,1,1,8192,1,99}; - NDArray matrix(arr1, shape1); - NDArray column(arr2, shape2); - NDArray exp(arr3, shape1); + float arr1[] = {1, 2, 3, 4}; + float arr2[] = {5, 6}; + float arr3[] = {5, 10, 18, 24}; + Nd4jLong shape1[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong shape2[] = {2, 2, 1, 1, 1, 8192, 1, 99}; + NDArray matrix(arr1, shape1); + NDArray column(arr2, shape2); + NDArray exp(arr3, shape1); - matrix.muliColumnVector(column); + matrix.muliColumnVector(column); - ASSERT_TRUE(exp.isSameShapeStrict(matrix)); - ASSERT_TRUE(exp.equalsTo(&matrix)); + ASSERT_TRUE(exp.isSameShapeStrict(matrix)); + ASSERT_TRUE(exp.equalsTo(&matrix)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test3D_1) { - auto arrayC = NDArrayFactory::create_('c', {2, 5, 10}); - auto arrayF = NDArrayFactory::create_('f', {2, 5, 10}); + auto arrayC = NDArrayFactory::create_('c', {2, 5, 10}); + auto arrayF = NDArrayFactory::create_('f', {2, 5, 10}); - ASSERT_EQ(100, arrayC->lengthOf()); - ASSERT_EQ(100, arrayF->lengthOf()); + ASSERT_EQ(100, arrayC->lengthOf()); + ASSERT_EQ(100, arrayF->lengthOf()); - ASSERT_EQ('c', arrayC->ordering()); - ASSERT_EQ('f', arrayF->ordering()); + ASSERT_EQ('c', arrayC->ordering()); + ASSERT_EQ('f', arrayF->ordering()); - delete arrayC; - delete arrayF; + delete arrayC; + delete arrayF; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTranspose1) { - auto arrayC = NDArrayFactory::create_('c', {2, 5, 10}); + auto arrayC = NDArrayFactory::create_('c', {2, 5, 10}); - auto expC = new Nd4jLong[10] {3, 2, 5, 10, 50, 10, 1, 16384, 1, 99}; - auto expT = new Nd4jLong[10] {3, 10, 5, 2, 1, 10, 50, 16384, 1, 102}; + auto expC = new Nd4jLong[10]{3, 2, 5, 10, 50, 10, 1, 16384, 1, 99}; + auto expT = new Nd4jLong[10]{3, 10, 5, 2, 1, 10, 50, 16384, 1, 102}; - auto arrayT = arrayC->transpose(); + auto arrayT = arrayC->transpose(); - for (int e = 0; e < arrayC->rankOf(); e++) { - ASSERT_EQ(shape::shapeOf(expC)[e], arrayC->sizeAt(e)); - ASSERT_EQ(shape::shapeOf(expT)[e], arrayT.sizeAt(e)); - } + for (int e = 0; e < arrayC->rankOf(); e++) { + ASSERT_EQ(shape::shapeOf(expC)[e], arrayC->sizeAt(e)); + ASSERT_EQ(shape::shapeOf(expT)[e], arrayT.sizeAt(e)); + } - delete arrayC; - delete[] expC; - delete[] expT; + delete arrayC; + delete[] expC; + delete[] expT; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTranspose2) { - auto arrayC = NDArrayFactory::create_('c', {2, 5, 10}); + auto arrayC = NDArrayFactory::create_('c', {2, 5, 10}); - auto expC = new Nd4jLong[10] {3, 2, 5, 10, 50, 10, 1, 16384, 1, 99}; - auto expT = new Nd4jLong[10] {3, 10, 5, 2, 1, 10, 50, 16384, 1, 102}; + auto expC = new Nd4jLong[10]{3, 2, 5, 10, 50, 10, 1, 16384, 1, 99}; + auto expT = new Nd4jLong[10]{3, 10, 5, 2, 1, 10, 50, 16384, 1, 102}; - arrayC->transposei(); + arrayC->transposei(); + for (int e = 0; e < arrayC->rankOf(); e++) { + ASSERT_EQ(shape::shapeOf(expT)[e], arrayC->sizeAt(e)); + } - for (int e = 0; e < arrayC->rankOf(); e++) { - ASSERT_EQ(shape::shapeOf(expT)[e], arrayC->sizeAt(e)); - } - - delete arrayC; - delete[] expC; - delete[] expT; + delete arrayC; + delete[] expC; + delete[] expT; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestSumAlongDimension1) { + NDArray array('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); - NDArray array('c', {2,2}, {1,2,3,4}, sd::DataType::FLOAT32); - - auto res = array.reduceAlongDimension(reduce::Sum, {0}); + auto res = array.reduceAlongDimension(reduce::Sum, {0}); - ASSERT_EQ(2, res.lengthOf()); + ASSERT_EQ(2, res.lengthOf()); - ASSERT_EQ(4.0f, res.e(0)); - ASSERT_EQ(6.0f, res.e(1)); + ASSERT_EQ(4.0f, res.e(0)); + ASSERT_EQ(6.0f, res.e(1)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestSumAlongDimension2) { - float *c = new float[4] {1, 2, 3, 4}; - auto array = new NDArray(c, cShape); + float *c = new float[4]{1, 2, 3, 4}; + auto array = new NDArray(c, cShape); - auto res = array->reduceAlongDimension(reduce::Sum, {1}); + auto res = array->reduceAlongDimension(reduce::Sum, {1}); - ASSERT_EQ(2, res.lengthOf()); + ASSERT_EQ(2, res.lengthOf()); - ASSERT_EQ(3.0f, res.e(0)); - ASSERT_EQ(7.0f, res.e(1)); + ASSERT_EQ(3.0f, res.e(0)); + ASSERT_EQ(7.0f, res.e(1)); - delete[] c; - delete array; + delete[] c; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestReduceAlongDimension1) { - float *c = new float[4] {1, 2, 3, 4}; - auto array = new NDArray(c, cShape); + float *c = new float[4]{1, 2, 3, 4}; + auto array = new NDArray(c, cShape); - auto res = array->reduceAlongDimension(reduce::Sum, {1}); + auto res = array->reduceAlongDimension(reduce::Sum, {1}); - ASSERT_EQ(2, res.lengthOf()); + ASSERT_EQ(2, res.lengthOf()); - ASSERT_EQ(3.0f, res.e(0)); - ASSERT_EQ(7.0f, res.e(1)); + ASSERT_EQ(3.0f, res.e(0)); + ASSERT_EQ(7.0f, res.e(1)); - delete[] c; - delete array; + delete[] c; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTransform1) { - float *c = new float[4] {-1, -2, -3, -4}; - auto array = new NDArray(c, cShape); + float *c = new float[4]{-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); - float *e = new float[4] {1, 2, 3, 4}; - auto exp = new NDArray(e, cShape); + float *e = new float[4]{1, 2, 3, 4}; + auto exp = new NDArray(e, cShape); - array->applyTransform(transform::Abs, *array); + array->applyTransform(transform::Abs, *array); - ASSERT_TRUE(exp->equalsTo(array)); + ASSERT_TRUE(exp->equalsTo(array)); - delete[] c; - delete array; - delete[] e; - delete exp; + delete[] c; + delete array; + delete[] e; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestReduceScalar1) { - float *c = new float[4] {-1, -2, -3, -4}; - auto array = new NDArray(c, cShape); + float *c = new float[4]{-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); - ASSERT_EQ(-4, array->reduceNumber(reduce::Min, nullptr).e(0)); + ASSERT_EQ(-4, array->reduceNumber(reduce::Min, nullptr).e(0)); - delete[] c; - delete array; + delete[] c; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestReduceScalar2) { - float *c = new float[4] {-1, -2, -3, -4}; - auto array = new NDArray(c, cShape); + float *c = new float[4]{-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); - ASSERT_EQ(-10, array->reduceNumber(reduce::Sum, nullptr).e(0)); + ASSERT_EQ(-10, array->reduceNumber(reduce::Sum, nullptr).e(0)); - delete[] c; - delete array; + delete[] c; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestReduceScalar3) { - auto array = new NDArray(arr1, shape1); + auto array = new NDArray(arr1, shape1); - ASSERT_EQ(21, array->reduceNumber(reduce::Sum, nullptr).e(0)); + ASSERT_EQ(21, array->reduceNumber(reduce::Sum, nullptr).e(0)); - delete array; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestApplyTransform1) { - float *c = new float[4] {-1, -2, -3, -4}; - auto array = new NDArray(c, cShape); + float *c = new float[4]{-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); - float *e = new float[4] {1, 2, 3, 4}; - auto exp = new NDArray(e, cShape); + float *e = new float[4]{1, 2, 3, 4}; + auto exp = new NDArray(e, cShape); - array->applyTransform(transform::Abs, *array); + array->applyTransform(transform::Abs, *array); + ASSERT_TRUE(exp->equalsTo(array)); - ASSERT_TRUE(exp->equalsTo(array)); + delete[] c; + delete array; - delete[] c; - delete array; - - delete[] e; - delete exp; + delete[] e; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVectors1) { - float *c = new float[4]{-1, -2, -3, -4}; - auto array = new NDArray(c, cShape); - + float *c = new float[4]{-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); - auto vecShape = array->getShapeInfoAsVector(); - auto vecBuffer = array->getBufferAsVector(); + auto vecShape = array->getShapeInfoAsVector(); + auto vecBuffer = array->getBufferAsVector(); - ASSERT_EQ(8, vecShape.size()); - ASSERT_EQ(4, vecBuffer.size()); + ASSERT_EQ(8, vecShape.size()); + ASSERT_EQ(4, vecBuffer.size()); - for (int e = 0; e < vecBuffer.size(); e++) { - ASSERT_NEAR(c[e], vecBuffer[e], 1e-5); - } + for (int e = 0; e < vecBuffer.size(); e++) { + ASSERT_NEAR(c[e], vecBuffer[e], 1e-5); + } - for (int e = 0; e < vecShape.size(); e++) { - ASSERT_EQ(cShape[e], vecShape[e]); - } + for (int e = 0; e < vecShape.size(); e++) { + ASSERT_EQ(cShape[e], vecShape[e]); + } - delete[] c; - delete array; + delete[] c; + delete array; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestChecks1) { - auto array = NDArrayFactory::create('c', {1, 5}); + auto array = NDArrayFactory::create('c', {1, 5}); - ASSERT_FALSE(array.isMatrix()); - ASSERT_FALSE(array.isScalar()); - ASSERT_TRUE(array.isVector()); - ASSERT_FALSE(array.isColumnVector()); - ASSERT_TRUE(array.isRowVector()); + ASSERT_FALSE(array.isMatrix()); + ASSERT_FALSE(array.isScalar()); + ASSERT_TRUE(array.isVector()); + ASSERT_FALSE(array.isColumnVector()); + ASSERT_TRUE(array.isRowVector()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestChecks2) { - auto array = NDArrayFactory::create('c', {5, 5}); + auto array = NDArrayFactory::create('c', {5, 5}); - ASSERT_TRUE(array.isMatrix()); - ASSERT_FALSE(array.isScalar()); - ASSERT_FALSE(array.isVector()); - ASSERT_FALSE(array.isColumnVector()); - ASSERT_FALSE(array.isRowVector()); + ASSERT_TRUE(array.isMatrix()); + ASSERT_FALSE(array.isScalar()); + ASSERT_FALSE(array.isVector()); + ASSERT_FALSE(array.isColumnVector()); + ASSERT_FALSE(array.isRowVector()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestChecks3) { - auto array = NDArrayFactory::create('c', {5, 1}); + auto array = NDArrayFactory::create('c', {5, 1}); - ASSERT_FALSE(array.isMatrix()); - ASSERT_FALSE(array.isScalar()); - ASSERT_TRUE(array.isVector()); - ASSERT_TRUE(array.isColumnVector()); - ASSERT_FALSE(array.isRowVector()); + ASSERT_FALSE(array.isMatrix()); + ASSERT_FALSE(array.isScalar()); + ASSERT_TRUE(array.isVector()); + ASSERT_TRUE(array.isColumnVector()); + ASSERT_FALSE(array.isRowVector()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestChecks4) { - auto array = NDArrayFactory::create('c', {1, 1}); + auto array = NDArrayFactory::create('c', {1, 1}); - ASSERT_FALSE(array.isMatrix()); - ASSERT_FALSE(array.isVector()); - ASSERT_FALSE(array.isColumnVector()); - ASSERT_FALSE(array.isRowVector()); - ASSERT_TRUE(array.isScalar()); + ASSERT_FALSE(array.isMatrix()); + ASSERT_FALSE(array.isVector()); + ASSERT_FALSE(array.isColumnVector()); + ASSERT_FALSE(array.isRowVector()); + ASSERT_TRUE(array.isScalar()); } TEST_F(NDArrayTest, TestReductionAny1) { - auto array = NDArrayFactory::create('c', {2, 2}); - array.p(0, 1.0f); - array.p(1, 1.0f); - array.p(2, 0.0f); - array.p(3, 0.0f); - array.syncToDevice(); - auto result0 = array.reduceAlongDimension(reduce::Any, {0}); + auto array = NDArrayFactory::create('c', {2, 2}); + array.p(0, 1.0f); + array.p(1, 1.0f); + array.p(2, 0.0f); + array.p(3, 0.0f); + array.syncToDevice(); + auto result0 = array.reduceAlongDimension(reduce::Any, {0}); - ASSERT_EQ(2, result0.lengthOf()); + ASSERT_EQ(2, result0.lengthOf()); - ASSERT_NEAR(1.0f, result0.e(0), 1e-5f); - ASSERT_NEAR(1.0f, result0.e(1), 1e-5f); + ASSERT_NEAR(1.0f, result0.e(0), 1e-5f); + ASSERT_NEAR(1.0f, result0.e(1), 1e-5f); - auto result1 = array.reduceAlongDimension(reduce::Any, {1}); + auto result1 = array.reduceAlongDimension(reduce::Any, {1}); - ASSERT_EQ(2, result1.lengthOf()); + ASSERT_EQ(2, result1.lengthOf()); - ASSERT_NEAR(1.0f, result1.e(0), 1e-5f); - ASSERT_NEAR(0.0f, result1.e(1), 1e-5f); + ASSERT_NEAR(1.0f, result1.e(0), 1e-5f); + ASSERT_NEAR(0.0f, result1.e(1), 1e-5f); } TEST_F(NDArrayTest, TestReductionAll1) { - auto array = NDArrayFactory::create('c', {2, 2}); - array.p(0, 1.0f); - array.p(1, 1.0f); - array.p(2, 0.0f); - array.p(3, 0.0f); + auto array = NDArrayFactory::create('c', {2, 2}); + array.p(0, 1.0f); + array.p(1, 1.0f); + array.p(2, 0.0f); + array.p(3, 0.0f); - auto result0 = array.reduceAlongDimension(reduce::All, {0}); - auto result1 = array.reduceAlongDimension(reduce::All, {1}); + auto result0 = array.reduceAlongDimension(reduce::All, {0}); + auto result1 = array.reduceAlongDimension(reduce::All, {1}); - ASSERT_EQ(2, result0.lengthOf()); - ASSERT_EQ(2, result1.lengthOf()); + ASSERT_EQ(2, result0.lengthOf()); + ASSERT_EQ(2, result1.lengthOf()); - ASSERT_FALSE(result0.e(0)); - ASSERT_FALSE(result0.e(1)); + ASSERT_FALSE(result0.e(0)); + ASSERT_FALSE(result0.e(1)); - ASSERT_TRUE(result1.e(0)); - ASSERT_FALSE(result1.e(1)); + ASSERT_TRUE(result1.e(0)); + ASSERT_FALSE(result1.e(1)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestChecks5) { - auto array = NDArrayFactory::create('c', {5, 5, 5}); + auto array = NDArrayFactory::create('c', {5, 5, 5}); - ASSERT_FALSE(array.isMatrix()); - ASSERT_FALSE(array.isVector()); - ASSERT_FALSE(array.isColumnVector()); - ASSERT_FALSE(array.isRowVector()); - ASSERT_FALSE(array.isScalar()); + ASSERT_FALSE(array.isMatrix()); + ASSERT_FALSE(array.isVector()); + ASSERT_FALSE(array.isColumnVector()); + ASSERT_FALSE(array.isRowVector()); + ASSERT_FALSE(array.isScalar()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTile1) { + // float arr1[6] = {1,2,3,4,5,6}; + // Nd4jLong shape1[8] = {2,2,3,3,1,8192,1,99}; + // float arr2[48] = + // {1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6}; + // Nd4jLong shape2[10] = {3,2,4,6,24,6,1,8192,1,99}; - // float arr1[6] = {1,2,3,4,5,6}; - // Nd4jLong shape1[8] = {2,2,3,3,1,8192,1,99}; - // float arr2[48] = {1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6}; - // Nd4jLong shape2[10] = {3,2,4,6,24,6,1,8192,1,99}; + NDArray array1(arr1, shape1); // {2,3} + NDArray array2(arr2, shape2); // {2,4,6} + auto expA = new NDArray(array1.dup('c')); - NDArray array1(arr1,shape1); // {2,3} - NDArray array2(arr2,shape2); // {2,4,6} - auto expA = new NDArray(array1.dup('c')); + auto tiled = array1.tile(tileShape1); - auto tiled = array1.tile(tileShape1); + // array2.printShapeInfo("Expct shape"); + // tiled.printShapeInfo("Tiled shape"); + // tiled.printBuffer(); - // array2.printShapeInfo("Expct shape"); - // tiled.printShapeInfo("Tiled shape"); - // tiled.printBuffer(); + ASSERT_TRUE(tiled.isSameShape(&array2)); + ASSERT_TRUE(tiled.equalsTo(&array2)); - ASSERT_TRUE(tiled.isSameShape(&array2)); - ASSERT_TRUE(tiled.equalsTo(&array2)); + ASSERT_TRUE(expA->isSameShape(&array1)); + ASSERT_TRUE(expA->equalsTo(&array1)); - ASSERT_TRUE(expA->isSameShape(&array1)); - ASSERT_TRUE(expA->equalsTo(&array1)); - - // delete tiled; - delete expA; + // delete tiled; + delete expA; } TEST_F(NDArrayTest, TestTile2) { + NDArray array1(arr1, shape1); + NDArray array2(arr2, shape2); - NDArray array1(arr1,shape1); - NDArray array2(arr2,shape2); - - auto tiled = array1.tile(tileShape1); + auto tiled = array1.tile(tileShape1); - ASSERT_TRUE(tiled.isSameShape(&array2)); - ASSERT_TRUE(tiled.equalsTo(&array2)); - // delete tiled; + ASSERT_TRUE(tiled.isSameShape(&array2)); + ASSERT_TRUE(tiled.equalsTo(&array2)); + // delete tiled; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTile3) { + NDArray array1(arr1, shape1); + NDArray array2(arr2, shape2); - NDArray array1(arr1,shape1); - NDArray array2(arr2,shape2); + array1.tilei(tileShape1); - array1.tilei(tileShape1); - - ASSERT_TRUE(array1.isSameShapeStrict(array2)); - ASSERT_TRUE(array1.equalsTo(&array2)); + ASSERT_TRUE(array1.isSameShapeStrict(array2)); + ASSERT_TRUE(array1.equalsTo(&array2)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTile4) { + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float expBuff[] = {1.f, 2.f, 1.f, 2.f, 3.f, 4.f, + 3.f, 4.f, 5.f, 6.f, 5.f, 6.f}; - float xBuff[] = {1,2,3,4,5,6}; - float expBuff[] = {1.f,2.f, 1.f,2.f, 3.f,4.f, 3.f,4.f, 5.f,6.f, 5.f,6.f}; - - auto x = NDArrayFactory::create(xBuff, 'c', {3,1,2}); - auto exp = NDArrayFactory::create(expBuff, 'c', {3,2,2}); + auto x = NDArrayFactory::create(xBuff, 'c', {3, 1, 2}); + auto exp = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - auto result = x.tile({2,1}); + auto result = x.tile({2, 1}); - ASSERT_TRUE(result.isSameShapeStrict(exp)); - ASSERT_TRUE(result.equalsTo(&exp)); + ASSERT_TRUE(result.isSameShapeStrict(exp)); + ASSERT_TRUE(result.equalsTo(&exp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTile5) { + float xBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + float expBuff[] = {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 9.f, 10.f, 11.f, 12.f}; - float xBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12}; - float expBuff[] = {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 9.f,10.f, 11.f,12.f, 9.f,10.f, 11.f,12.f}; + auto x = NDArrayFactory::create(xBuff, 'c', {3, 2, 2}); + auto exp = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); - auto x = NDArrayFactory::create(xBuff, 'c', {3,2,2}); - auto exp = NDArrayFactory::create(expBuff, 'c', {3,4,2}); + auto result = x.tile({2, 1}); - auto result = x.tile({2,1}); - - ASSERT_TRUE(result.isSameShapeStrict(exp)); - ASSERT_TRUE(result.equalsTo(&exp)); + ASSERT_TRUE(result.isSameShapeStrict(exp)); + ASSERT_TRUE(result.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, TestTile6) -{ - double expBuff[] = {10.,11., 10.,11., 10.,11., 10.,11., 12.,13., 12.,13., 12.,13., 12.,13., 14.,15., 14.,15., 14.,15., 14.,15.}; +TEST_F(NDArrayTest, TestTile6) { + double expBuff[] = {10., 11., 10., 11., 10., 11., 10., 11., + 12., 13., 12., 13., 12., 13., 12., 13., + 14., 15., 14., 15., 14., 15., 14., 15.}; - auto x = NDArrayFactory::create('c', {3, 1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); - x.linspace(10); + x.linspace(10); - auto result = x.tile({1,4,1}); + auto result = x.tile({1, 4, 1}); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMmulHelper1) { - auto xBuffer = new float[3]{1.f, 2.f, 3.f}; - auto xShape = new Nd4jLong[8] {2, 1, 3, 1, 1, 8192, 1, 99}; - auto x = new NDArray(xBuffer, xShape); + auto xBuffer = new float[3]{1.f, 2.f, 3.f}; + auto xShape = new Nd4jLong[8]{2, 1, 3, 1, 1, 8192, 1, 99}; + auto x = new NDArray(xBuffer, xShape); - auto yBuffer = new float[3]{2.f, 4.f, 6.f}; - auto yShape = new Nd4jLong[8] {2, 1, 3, 1, 1, 8192, 1, 99}; - auto y = new NDArray(yBuffer, yShape); + auto yBuffer = new float[3]{2.f, 4.f, 6.f}; + auto yShape = new Nd4jLong[8]{2, 1, 3, 1, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape); - auto z = MmulHelper::mmul(x, y); + auto z = MmulHelper::mmul(x, y); - ASSERT_EQ(1, z->lengthOf()); - ASSERT_NEAR(28, z->e(0), 1e-5); + ASSERT_EQ(1, z->lengthOf()); + ASSERT_NEAR(28, z->e(0), 1e-5); - delete z; - delete[] xBuffer; - delete[] xShape; - delete[] yBuffer; - delete[] yShape; - delete y; - delete x; + delete z; + delete[] xBuffer; + delete[] xShape; + delete[] yBuffer; + delete[] yShape; + delete y; + delete x; } - TEST_F(NDArrayTest, TestPermuteReshapeMmul1) { - auto x = NDArrayFactory::create('c', {6, 3}); - auto y = NDArrayFactory::create('c', {3, 6}); + auto x = NDArrayFactory::create('c', {6, 3}); + auto y = NDArrayFactory::create('c', {3, 6}); - Nd4jLong _expS[] = {2, 3, 3, 1, 3, 8192, 1, 102}; - float _expB[] = {231.0f, 252.0f, 273.0f, 537.0f, 594.0f, 651.0f, 843.0f, 936.0f, 1029.0f}; - NDArray exp(_expB, _expS); + Nd4jLong _expS[] = {2, 3, 3, 1, 3, 8192, 1, 102}; + float _expB[] = {231.0f, 252.0f, 273.0f, 537.0f, 594.0f, + 651.0f, 843.0f, 936.0f, 1029.0f}; + NDArray exp(_expB, _expS); - for (int e = 0; e < x.lengthOf(); e++) - x.p(e, e+1); + for (int e = 0; e < x.lengthOf(); e++) x.p(e, e + 1); - for (int e = 0; e < y.lengthOf(); e++) - y.p(e, e+1); + for (int e = 0; e < y.lengthOf(); e++) y.p(e, e + 1); - x.permutei({1, 0}); - y.permutei({1, 0}); + x.permutei({1, 0}); + y.permutei({1, 0}); - auto z = MmulHelper::mmul(&x, &y); + auto z = MmulHelper::mmul(&x, &y); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); - delete z; + delete z; } TEST_F(NDArrayTest, TestPermuteReshapeMmul2) { - auto x = NDArrayFactory::create('c', {6, 3}); - auto y = NDArrayFactory::create('c', {3, 6}); + auto x = NDArrayFactory::create('c', {6, 3}); + auto y = NDArrayFactory::create('c', {3, 6}); - Nd4jLong _expS[] = {2, 3, 3, 1, 3, 8192, 1, 102}; - float _expB[] = {231.0f, 252.0f, 273.0f, 537.0f, 594.0f, 651.0f, 843.0f, 936.0f, 1029.0f}; - NDArray exp(_expB, _expS); + Nd4jLong _expS[] = {2, 3, 3, 1, 3, 8192, 1, 102}; + float _expB[] = {231.0f, 252.0f, 273.0f, 537.0f, 594.0f, + 651.0f, 843.0f, 936.0f, 1029.0f}; + NDArray exp(_expB, _expS); - for (int e = 0; e < x.lengthOf(); e++) - x.p(e, e+1); + for (int e = 0; e < x.lengthOf(); e++) x.p(e, e + 1); - for (int e = 0; e < y.lengthOf(); e++) - y.p(e, e+1); + for (int e = 0; e < y.lengthOf(); e++) y.p(e, e + 1); - auto x_ = new NDArray(x.dup('f')); - auto y_ = new NDArray(y.dup('f')); + auto x_ = new NDArray(x.dup('f')); + auto y_ = new NDArray(y.dup('f')); - x_->permutei({1, 0}); - y_->permutei({1, 0}); + x_->permutei({1, 0}); + y_->permutei({1, 0}); - auto z = MmulHelper::mmul(x_, y_); + auto z = MmulHelper::mmul(x_, y_); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); - delete z; - delete x_; - delete y_; + delete z; + delete x_; + delete y_; } - TEST_F(NDArrayTest, TestPermuteReshapeMmul3) { - auto x = NDArrayFactory::create('c', {2, 2, 2, 3, 2, 2}); - auto y = NDArrayFactory::create('c', {2, 3, 2 ,2}); + auto x = NDArrayFactory::create('c', {2, 2, 2, 3, 2, 2}); + auto y = NDArrayFactory::create('c', {2, 3, 2, 2}); - Nd4jLong _expS[] = {2, 8, 2, 1, 8, 8192, 1, 102}; - float _expB[] = {1624.0f, 1858.0f, 2092.0f, 2326.0f, 5368.0f, 5602.0f, 5836.0f, 6070.0f, 4504.0f, 5170.0f, 5836.0f, 6502.0f, 15160.0f, 15826.0f, 16492.0f, 17158.0f}; - NDArray exp(_expB, _expS); + Nd4jLong _expS[] = {2, 8, 2, 1, 8, 8192, 1, 102}; + float _expB[] = {1624.0f, 1858.0f, 2092.0f, 2326.0f, 5368.0f, 5602.0f, + 5836.0f, 6070.0f, 4504.0f, 5170.0f, 5836.0f, 6502.0f, + 15160.0f, 15826.0f, 16492.0f, 17158.0f}; + NDArray exp(_expB, _expS); - for (int e = 0; e < x.lengthOf(); e++) - x.p(e, e+1); + for (int e = 0; e < x.lengthOf(); e++) x.p(e, e + 1); - for (int e = 0; e < y.lengthOf(); e++) - y.p(e, e+1); + for (int e = 0; e < y.lengthOf(); e++) y.p(e, e + 1); - x.permutei({0, 3, 4, 5, 1, 2}); - y.permutei({3, 2, 1, 0}); + x.permutei({0, 3, 4, 5, 1, 2}); + y.permutei({3, 2, 1, 0}); - x.reshapei('c', {2 * 2 * 2, 3 * 2 * 2}); - y.reshapei('c', {2 * 2 * 3, 2}); + x.reshapei('c', {2 * 2 * 2, 3 * 2 * 2}); + y.reshapei('c', {2 * 2 * 3, 2}); - auto z = MmulHelper::mmul(&x, &y); + auto z = MmulHelper::mmul(&x, &y); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); - delete z; + delete z; } TEST_F(NDArrayTest, TestPermuteReshapeMmul4) { - auto x = NDArrayFactory::create('c', {2, 2, 2, 3, 2, 2}); - auto y = NDArrayFactory::create('c', {2, 3, 2 ,2}); + auto x = NDArrayFactory::create('c', {2, 2, 2, 3, 2, 2}); + auto y = NDArrayFactory::create('c', {2, 3, 2, 2}); - Nd4jLong _expS[] = {2, 8, 2, 1, 8, 8192, 1, 102}; - float _expB[] = {1624.0f, 1858.0f, 2092.0f, 2326.0f, 5368.0f, 5602.0f, 5836.0f, 6070.0f, 4504.0f, 5170.0f, 5836.0f, 6502.0f, 15160.0f, 15826.0f, 16492.0f, 17158.0f}; - NDArray exp(_expB, _expS); + Nd4jLong _expS[] = {2, 8, 2, 1, 8, 8192, 1, 102}; + float _expB[] = {1624.0f, 1858.0f, 2092.0f, 2326.0f, 5368.0f, 5602.0f, + 5836.0f, 6070.0f, 4504.0f, 5170.0f, 5836.0f, 6502.0f, + 15160.0f, 15826.0f, 16492.0f, 17158.0f}; + NDArray exp(_expB, _expS); - for (int e = 0; e < x.lengthOf(); e++) - x.p(e, e+1); + for (int e = 0; e < x.lengthOf(); e++) x.p(e, e + 1); - for (int e = 0; e < y.lengthOf(); e++) - y.p(e, e+1); + for (int e = 0; e < y.lengthOf(); e++) y.p(e, e + 1); - auto y_ = new NDArray(y.dup('f')); + auto y_ = new NDArray(y.dup('f')); - x.permutei({0, 3, 4, 5, 1, 2}); - y_->permutei({3, 2, 1, 0}); + x.permutei({0, 3, 4, 5, 1, 2}); + y_->permutei({3, 2, 1, 0}); - x.reshapei('c', {2 * 2 * 2, 3 * 2 * 2}); - y_->reshapei('c', {2 * 2 * 3, 2}); + x.reshapei('c', {2 * 2 * 2, 3 * 2 * 2}); + y_->reshapei('c', {2 * 2 * 3, 2}); - auto z = MmulHelper::mmul(&x, y_); + auto z = MmulHelper::mmul(&x, y_); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); - delete z; - delete y_; + delete z; + delete y_; } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMmulHelper2) { - auto xBuffer = new float[15]{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}; - Nd4jLong xShape[8] = {2, 5, 3, 3, 1, 8192, 1, 99}; - auto x = new NDArray(xBuffer, xShape, sd::LaunchContext ::defaultContext(), true); - + auto xBuffer = new float[15]{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}; + Nd4jLong xShape[8] = {2, 5, 3, 3, 1, 8192, 1, 99}; + auto x = + new NDArray(xBuffer, xShape, sd::LaunchContext ::defaultContext(), true); - auto yBuffer = new float[3]{2.f, 4.f, 6.f}; - Nd4jLong yShape[8] = {2, 3, 1, 1, 1, 8192, 1, 99}; - auto y = new NDArray(yBuffer, yShape, sd::LaunchContext ::defaultContext(), true); + auto yBuffer = new float[3]{2.f, 4.f, 6.f}; + Nd4jLong yShape[8] = {2, 3, 1, 1, 1, 8192, 1, 99}; + auto y = + new NDArray(yBuffer, yShape, sd::LaunchContext ::defaultContext(), true); - auto z = NDArrayFactory::create_('f', {5, 1}); + auto z = NDArrayFactory::create_('f', {5, 1}); - auto expBuffer = new float[5]{28.00f, 64.00f, 100.00f, 136.00f, 172.00f}; - auto exp = new NDArray(expBuffer, z->shapeInfo(), sd::LaunchContext ::defaultContext(), true); + auto expBuffer = new float[5]{28.00f, 64.00f, 100.00f, 136.00f, 172.00f}; + auto exp = new NDArray(expBuffer, z->shapeInfo(), + sd::LaunchContext ::defaultContext(), true); - //sd::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->buffer(), y->rows(), y->buffer(), 1, 0.0, z->buffer(), 1); + // sd::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->buffer(), + // y->rows(), y->buffer(), 1, 0.0, z->buffer(), 1); - MmulHelper::mmul(x, y, z); + MmulHelper::mmul(x, y, z); - //z->printBuffer(); + // z->printBuffer(); - ASSERT_TRUE(z->equalsTo(exp)); + ASSERT_TRUE(z->equalsTo(exp)); - delete x; - delete y; - delete z; - delete exp; + delete x; + delete y; + delete z; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMmulHelper3) { - auto xBuffer = new float[15]{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}; - auto xShape = new Nd4jLong[8] {2, 5, 3, 1, 5, 8192, 1, 102}; - auto x = new NDArray(xBuffer, xShape); + auto xBuffer = new float[15]{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}; + auto xShape = new Nd4jLong[8]{2, 5, 3, 1, 5, 8192, 1, 102}; + auto x = new NDArray(xBuffer, xShape); - auto yBuffer = new float[3]{2.f, 4.f, 6.f}; - auto yShape = new Nd4jLong[8] {2, 3, 1, 1, 1, 8192, 1, 99}; - auto y = new NDArray(yBuffer, yShape); + auto yBuffer = new float[3]{2.f, 4.f, 6.f}; + auto yShape = new Nd4jLong[8]{2, 3, 1, 1, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape); - auto z = NDArrayFactory::create_('f', {5, 1}); + auto z = NDArrayFactory::create_('f', {5, 1}); - auto expBuffer = new float[5]{92.00f, 104.00f, 116.00f, 128.00f, 140.00f}; - auto exp = new NDArray(expBuffer, z->shapeInfo()); + auto expBuffer = new float[5]{92.00f, 104.00f, 116.00f, 128.00f, 140.00f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); - //sd::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->buffer(), y->rows(), y->buffer(), 1, 0.0, z->buffer(), 1); + // sd::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->buffer(), + // y->rows(), y->buffer(), 1, 0.0, z->buffer(), 1); - MmulHelper::mmul(x, y, z); + MmulHelper::mmul(x, y, z); - //z->printBuffer(); + // z->printBuffer(); - ASSERT_TRUE(z->equalsTo(exp)); + ASSERT_TRUE(z->equalsTo(exp)); - delete[] expBuffer; - delete[] xBuffer; - delete[] yBuffer; - delete[] xShape; - delete[] yShape; + delete[] expBuffer; + delete[] xBuffer; + delete[] yBuffer; + delete[] xShape; + delete[] yShape; - delete x; - delete y; - delete z; - delete exp; + delete x; + delete y; + delete z; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMmulHelper4) { - auto xBuffer = new float[6]{1, 2, 3, 4, 5, 6}; - auto xShape = new Nd4jLong[8] {2, 3, 2, 2, 1, 8192, 1, 99}; - auto x = new NDArray(xBuffer, xShape); + auto xBuffer = new float[6]{1, 2, 3, 4, 5, 6}; + auto xShape = new Nd4jLong[8]{2, 3, 2, 2, 1, 8192, 1, 99}; + auto x = new NDArray(xBuffer, xShape); - auto yBuffer = new float[6]{7, 8, 9, 0, 1, 2}; - auto yShape = new Nd4jLong[8] {2, 2, 3, 3, 1, 8192, 1, 99}; - auto y = new NDArray(yBuffer, yShape); + auto yBuffer = new float[6]{7, 8, 9, 0, 1, 2}; + auto yShape = new Nd4jLong[8]{2, 2, 3, 3, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape); - auto z = NDArrayFactory::create_('f', {3, 3}); + auto z = NDArrayFactory::create_('f', {3, 3}); - auto expBuffer = new float[9]{7.0f, 21.0f, 35.0f, 10.0f, 28.0f, 46.0f, 13.0f, 35.0f, 57.0f}; - auto exp = new NDArray(expBuffer, z->shapeInfo()); + auto expBuffer = new float[9]{7.0f, 21.0f, 35.0f, 10.0f, 28.0f, + 46.0f, 13.0f, 35.0f, 57.0f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); - MmulHelper::mmul(x, y, z); - ASSERT_TRUE(z->equalsTo(exp)); + MmulHelper::mmul(x, y, z); + ASSERT_TRUE(z->equalsTo(exp)); - delete[] expBuffer; - delete[] xBuffer; - delete[] yBuffer; - delete[] xShape; - delete[] yShape; + delete[] expBuffer; + delete[] xBuffer; + delete[] yBuffer; + delete[] xShape; + delete[] yShape; - delete x; - delete y; - delete z; - delete exp; + delete x; + delete y; + delete z; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMmulHelper5) { - auto xBuffer = new float[6]{1, 2, 3, 4, 5, 6}; - auto xShape = new Nd4jLong[8] {2, 3, 2, 1, 3, 8192, 1, 102}; - auto x = new NDArray(xBuffer, xShape); + auto xBuffer = new float[6]{1, 2, 3, 4, 5, 6}; + auto xShape = new Nd4jLong[8]{2, 3, 2, 1, 3, 8192, 1, 102}; + auto x = new NDArray(xBuffer, xShape); - auto yBuffer = new float[6]{7, 8, 9, 0, 1, 2}; - auto yShape = new Nd4jLong[8] {2, 2, 3, 3, 1, 8192, 1, 99}; - auto y = new NDArray(yBuffer, yShape); + auto yBuffer = new float[6]{7, 8, 9, 0, 1, 2}; + auto yShape = new Nd4jLong[8]{2, 2, 3, 3, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape); - auto z = NDArrayFactory::create_('f', {3, 3}); + auto z = NDArrayFactory::create_('f', {3, 3}); - auto expBuffer = new float[9]{7.0f, 14.0f, 21.0f, 12.0f, 21.0f, 30.0f, 17.0f, 28.0f, 39.0f}; - auto exp = new NDArray(expBuffer, z->shapeInfo()); + auto expBuffer = new float[9]{7.0f, 14.0f, 21.0f, 12.0f, 21.0f, + 30.0f, 17.0f, 28.0f, 39.0f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); - MmulHelper::mmul(x, y, z); - ASSERT_TRUE(z->equalsTo(exp)); + MmulHelper::mmul(x, y, z); + ASSERT_TRUE(z->equalsTo(exp)); - delete[] expBuffer; - delete[] xBuffer; - delete[] yBuffer; - delete[] xShape; - delete[] yShape; + delete[] expBuffer; + delete[] xBuffer; + delete[] yBuffer; + delete[] xShape; + delete[] yShape; - delete x; - delete y; - delete z; - delete exp; + delete x; + delete y; + delete z; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMmulHelper6) { - auto xBuffer = new float[6]{1, 2, 3, 4, 5, 6}; - auto xShape = new Nd4jLong[8] {2, 3, 2, 1, 3, 8192, 1, 102}; - auto x = new NDArray(xBuffer, xShape); - - auto yBuffer = new float[6]{7, 8, 9, 0, 1, 2}; - auto yShape = new Nd4jLong[8] {2, 2, 3, 1, 2, 8192, 1, 102}; - auto y = new NDArray(yBuffer, yShape); + auto xBuffer = new float[6]{1, 2, 3, 4, 5, 6}; + auto xShape = new Nd4jLong[8]{2, 3, 2, 1, 3, 8192, 1, 102}; + auto x = new NDArray(xBuffer, xShape); - auto z = NDArrayFactory::create_('f', {3, 3}); + auto yBuffer = new float[6]{7, 8, 9, 0, 1, 2}; + auto yShape = new Nd4jLong[8]{2, 2, 3, 1, 2, 8192, 1, 102}; + auto y = new NDArray(yBuffer, yShape); - auto expBuffer = new float[9]{39.0f, 54.0f, 69.0f, 9.0f, 18.0f, 27.0f, 9.0f, 12.0f, 15.0f}; - auto exp = new NDArray(expBuffer, z->shapeInfo()); + auto z = NDArrayFactory::create_('f', {3, 3}); - MmulHelper::mmul(x, y, z); - ASSERT_TRUE(z->equalsTo(exp)); + auto expBuffer = + new float[9]{39.0f, 54.0f, 69.0f, 9.0f, 18.0f, 27.0f, 9.0f, 12.0f, 15.0f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); + MmulHelper::mmul(x, y, z); + ASSERT_TRUE(z->equalsTo(exp)); - delete[] expBuffer; - delete[] xBuffer; - delete[] yBuffer; - delete[] xShape; - delete[] yShape; + delete[] expBuffer; + delete[] xBuffer; + delete[] yBuffer; + delete[] xShape; + delete[] yShape; - delete x; - delete y; - delete z; - delete exp; + delete x; + delete y; + delete z; + delete exp; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMmulHelper7) { - auto xBuffer = new float[15]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - auto xShape = new Nd4jLong[8] {2, 5, 3, 1, 5, 8192, 1, 102}; - auto x = new NDArray(xBuffer, xShape); + auto xBuffer = + new float[15]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + auto xShape = new Nd4jLong[8]{2, 5, 3, 1, 5, 8192, 1, 102}; + auto x = new NDArray(xBuffer, xShape); - auto yBuffer = new float[5]{2, 4, 6, 8, 10}; - auto yShape = new Nd4jLong[8] {2, 1, 5, 1, 1, 8192, 1, 99}; - auto y = new NDArray(yBuffer, yShape); + auto yBuffer = new float[5]{2, 4, 6, 8, 10}; + auto yShape = new Nd4jLong[8]{2, 1, 5, 1, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape); - auto z = NDArrayFactory::create_('f', {1, 3}); + auto z = NDArrayFactory::create_('f', {1, 3}); - auto expBuffer = new float[9]{110.00f, 260.00f, 410.00f}; - auto exp = new NDArray(expBuffer, z->shapeInfo()); + auto expBuffer = new float[9]{110.00f, 260.00f, 410.00f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); - MmulHelper::mmul(y, x, z); + MmulHelper::mmul(y, x, z); - //z->printBuffer(); - ASSERT_TRUE(z->equalsTo(exp)); + // z->printBuffer(); + ASSERT_TRUE(z->equalsTo(exp)); - delete[] expBuffer; - delete[] xBuffer; - delete[] yBuffer; - delete[] xShape; - delete[] yShape; + delete[] expBuffer; + delete[] xBuffer; + delete[] yBuffer; + delete[] xShape; + delete[] yShape; - delete x; - delete y; - delete z; - delete exp; + delete x; + delete y; + delete z; + delete exp; } - TEST_F(NDArrayTest, TestMmulHelper_ND_1) { - Nd4jLong _expS[] = {3, 2, 3, 3, 9, 3, 1, 8192, 1, 99}; - float _expB[] = {70.f, 80.f, 90.f, 158.f, 184.f, 210.f, 246.f, 288.f, 330.f, 1030.f, 1088.f, 1146.f, 1310.f, 1384.f, 1458.f, 1590.f, 1680.f, 1770.f}; + Nd4jLong _expS[] = {3, 2, 3, 3, 9, 3, 1, 8192, 1, 99}; + float _expB[] = {70.f, 80.f, 90.f, 158.f, 184.f, 210.f, + 246.f, 288.f, 330.f, 1030.f, 1088.f, 1146.f, + 1310.f, 1384.f, 1458.f, 1590.f, 1680.f, 1770.f}; - auto a = NDArrayFactory::create('c', {2, 3, 4}); - for (int e = 0; e < a.lengthOf(); e++) - a.p(e, e+1); + auto a = NDArrayFactory::create('c', {2, 3, 4}); + for (int e = 0; e < a.lengthOf(); e++) a.p(e, e + 1); - auto b = NDArrayFactory::create('c', {2, 4, 3}); - for (int e = 0; e < b.lengthOf(); e++) - b.p(e, e+1); + auto b = NDArrayFactory::create('c', {2, 4, 3}); + for (int e = 0; e < b.lengthOf(); e++) b.p(e, e + 1); - NDArray exp(_expB, _expS); - auto c = MmulHelper::mmul(&a, &b); + NDArray exp(_expB, _expS); + auto c = MmulHelper::mmul(&a, &b); - ASSERT_TRUE(exp.isSameShape(c)); - ASSERT_TRUE(exp.equalsTo(c)); + ASSERT_TRUE(exp.isSameShape(c)); + ASSERT_TRUE(exp.equalsTo(c)); - delete c; + delete c; } - TEST_F(NDArrayTest, TestMmulHelper_ND_2) { - Nd4jLong _expS[] = {3, 2, 72, 2, 144, 2, 1, 8192, 1, 99}; - float _expB[] = { - 1.07250000e+04f, 1.10500000e+04f, 2.63500000e+04f, 2.73000000e+04f, 4.19750000e+04f, 4.35500000e+04f, - 5.76000000e+04f, 5.98000000e+04f, 7.32250000e+04f, 7.60500000e+04f, 8.88500000e+04f, 9.23000000e+04f, - 1.04475000e+05f, 1.08550000e+05f, 1.20100000e+05f, 1.24800000e+05f, 1.35725000e+05f, 1.41050000e+05f, - 1.51350000e+05f, 1.57300000e+05f, 1.66975000e+05f, 1.73550000e+05f, 1.82600000e+05f, 1.89800000e+05f, - 1.98225000e+05f, 2.06050000e+05f, 2.13850000e+05f, 2.22300000e+05f, 2.29475000e+05f, 2.38550000e+05f, - 2.45100000e+05f, 2.54800000e+05f, 2.60725000e+05f, 2.71050000e+05f, 2.76350000e+05f, 2.87300000e+05f, - 2.91975000e+05f, 3.03550000e+05f, 3.07600000e+05f, 3.19800000e+05f, 3.23225000e+05f, 3.36050000e+05f, - 3.38850000e+05f, 3.52300000e+05f, 3.54475000e+05f, 3.68550000e+05f, 3.70100000e+05f, 3.84800000e+05f, - 3.85725000e+05f, 4.01050000e+05f, 4.01350000e+05f, 4.17300000e+05f, 4.16975000e+05f, 4.33550000e+05f, - 4.32600000e+05f, 4.49800000e+05f, 4.48225000e+05f, 4.66050000e+05f, 4.63850000e+05f, 4.82300000e+05f, - 4.79475000e+05f, 4.98550000e+05f, 4.95100000e+05f, 5.14800000e+05f, 5.10725000e+05f, 5.31050000e+05f, - 5.26350000e+05f, 5.47300000e+05f, 5.41975000e+05f, 5.63550000e+05f, 5.57600000e+05f, 5.79800000e+05f, - 5.73225000e+05f, 5.96050000e+05f, 5.88850000e+05f, 6.12300000e+05f, 6.04475000e+05f, 6.28550000e+05f, - 6.20100000e+05f, 6.44800000e+05f, 6.35725000e+05f, 6.61050000e+05f, 6.51350000e+05f, 6.77300000e+05f, - 6.66975000e+05f, 6.93550000e+05f, 6.82600000e+05f, 7.09800000e+05f, 6.98225000e+05f, 7.26050000e+05f, - 7.13850000e+05f, 7.42300000e+05f, 7.29475000e+05f, 7.58550000e+05f, 7.45100000e+05f, 7.74800000e+05f, - 7.60725000e+05f, 7.91050000e+05f, 7.76350000e+05f, 8.07300000e+05f, 7.91975000e+05f, 8.23550000e+05f, - 8.07600000e+05f, 8.39800000e+05f, 8.23225000e+05f, 8.56050000e+05f, 8.38850000e+05f, 8.72300000e+05f, - 8.54475000e+05f, 8.88550000e+05f, 8.70100000e+05f, 9.04800000e+05f, 8.85725000e+05f, 9.21050000e+05f, - 9.01350000e+05f, 9.37300000e+05f, 9.16975000e+05f, 9.53550000e+05f, 9.32600000e+05f, 9.69800000e+05f, - 9.48225000e+05f, 9.86050000e+05f, 9.63850000e+05f, 1.00230000e+06f, 9.79475000e+05f, 1.01855000e+06f, - 9.95100000e+05f, 1.03480000e+06f, 1.01072500e+06f, 1.05105000e+06f, 1.02635000e+06f, 1.06730000e+06f, - 1.04197500e+06f, 1.08355000e+06f, 1.05760000e+06f, 1.09980000e+06f, 1.07322500e+06f, 1.11605000e+06f, - 1.08885000e+06f, 1.13230000e+06f, 1.10447500e+06f, 1.14855000e+06f, 1.12010000e+06f, 1.16480000e+06f, - 1.13572500e+06f, 1.18105000e+06f, 1.15135000e+06f, 1.19730000e+06f, 1.16697500e+06f, 1.21355000e+06f, - 3.54260000e+06f, 3.58980000e+06f, 3.58947500e+06f, 3.63730000e+06f, 3.63635000e+06f, 3.68480000e+06f, - 3.68322500e+06f, 3.73230000e+06f, 3.73010000e+06f, 3.77980000e+06f, 3.77697500e+06f, 3.82730000e+06f, - 3.82385000e+06f, 3.87480000e+06f, 3.87072500e+06f, 3.92230000e+06f, 3.91760000e+06f, 3.96980000e+06f, - 3.96447500e+06f, 4.01730000e+06f, 4.01135000e+06f, 4.06480000e+06f, 4.05822500e+06f, 4.11230000e+06f, - 4.10510000e+06f, 4.15980000e+06f, 4.15197500e+06f, 4.20730000e+06f, 4.19885000e+06f, 4.25480000e+06f, - 4.24572500e+06f, 4.30230000e+06f, 4.29260000e+06f, 4.34980000e+06f, 4.33947500e+06f, 4.39730000e+06f, - 4.38635000e+06f, 4.44480000e+06f, 4.43322500e+06f, 4.49230000e+06f, 4.48010000e+06f, 4.53980000e+06f, - 4.52697500e+06f, 4.58730000e+06f, 4.57385000e+06f, 4.63480000e+06f, 4.62072500e+06f, 4.68230000e+06f, - 4.66760000e+06f, 4.72980000e+06f, 4.71447500e+06f, 4.77730000e+06f, 4.76135000e+06f, 4.82480000e+06f, - 4.80822500e+06f, 4.87230000e+06f, 4.85510000e+06f, 4.91980000e+06f, 4.90197500e+06f, 4.96730000e+06f, - 4.94885000e+06f, 5.01480000e+06f, 4.99572500e+06f, 5.06230000e+06f, 5.04260000e+06f, 5.10980000e+06f, - 5.08947500e+06f, 5.15730000e+06f, 5.13635000e+06f, 5.20480000e+06f, 5.18322500e+06f, 5.25230000e+06f, - 5.23010000e+06f, 5.29980000e+06f, 5.27697500e+06f, 5.34730000e+06f, 5.32385000e+06f, 5.39480000e+06f, - 5.37072500e+06f, 5.44230000e+06f, 5.41760000e+06f, 5.48980000e+06f, 5.46447500e+06f, 5.53730000e+06f, - 5.51135000e+06f, 5.58480000e+06f, 5.55822500e+06f, 5.63230000e+06f, 5.60510000e+06f, 5.67980000e+06f, - 5.65197500e+06f, 5.72730000e+06f, 5.69885000e+06f, 5.77480000e+06f, 5.74572500e+06f, 5.82230000e+06f, - 5.79260000e+06f, 5.86980000e+06f, 5.83947500e+06f, 5.91730000e+06f, 5.88635000e+06f, 5.96480000e+06f, - 5.93322500e+06f, 6.01230000e+06f, 5.98010000e+06f, 6.05980000e+06f, 6.02697500e+06f, 6.10730000e+06f, - 6.07385000e+06f, 6.15480000e+06f, 6.12072500e+06f, 6.20230000e+06f, 6.16760000e+06f, 6.24980000e+06f, - 6.21447500e+06f, 6.29730000e+06f, 6.26135000e+06f, 6.34480000e+06f, 6.30822500e+06f, 6.39230000e+06f, - 6.35510000e+06f, 6.43980000e+06f, 6.40197500e+06f, 6.48730000e+06f, 6.44885000e+06f, 6.53480000e+06f, - 6.49572500e+06f, 6.58230000e+06f, 6.54260000e+06f, 6.62980000e+06f, 6.58947500e+06f, 6.67730000e+06f, - 6.63635000e+06f, 6.72480000e+06f, 6.68322500e+06f, 6.77230000e+06f, 6.73010000e+06f, 6.81980000e+06f, - 6.77697500e+06f, 6.86730000e+06f, 6.82385000e+06f, 6.91480000e+06f, 6.87072500e+06f, 6.96230000e+06f, - 6.91760000e+06f, 7.00980000e+06f, 6.96447500e+06f, 7.05730000e+06f, 7.01135000e+06f, 7.10480000e+06f, - 1.17619750e+07f, 1.18560500e+07f, 1.18401000e+07f, 1.19348000e+07f, 1.19182250e+07f, 1.20135500e+07f, - 1.19963500e+07f, 1.20923000e+07f, 1.20744750e+07f, 1.21710500e+07f, 1.21526000e+07f, 1.22498000e+07f, 1.22307250e+07f, 1.23285500e+07f, 1.23088500e+07f, 1.24073000e+07f, 1.23869750e+07f, 1.24860500e+07f, 1.24651000e+07f, 1.25648000e+07f, 1.25432250e+07f, 1.26435500e+07f, 1.26213500e+07f, 1.27223000e+07f, 1.26994750e+07f, 1.28010500e+07f, 1.27776000e+07f, 1.28798000e+07f, 1.28557250e+07f, 1.29585500e+07f, 1.29338500e+07f, 1.30373000e+07f, 1.30119750e+07f, 1.31160500e+07f, 1.30901000e+07f, 1.31948000e+07f, 1.31682250e+07f, 1.32735500e+07f, 1.32463500e+07f, 1.33523000e+07f, 1.33244750e+07f, 1.34310500e+07f, 1.34026000e+07f, 1.35098000e+07f, 1.34807250e+07f, 1.35885500e+07f, 1.35588500e+07f, 1.36673000e+07f, 1.36369750e+07f, 1.37460500e+07f, 1.37151000e+07f, 1.38248000e+07f, 1.37932250e+07f, 1.39035500e+07f, 1.38713500e+07f, 1.39823000e+07f, 1.39494750e+07f, 1.40610500e+07f, 1.40276000e+07f, 1.41398000e+07f, 1.41057250e+07f, 1.42185500e+07f, 1.41838500e+07f, 1.42973000e+07f, 1.42619750e+07f, 1.43760500e+07f, 1.43401000e+07f, 1.44548000e+07f, 1.44182250e+07f, 1.45335500e+07f, 1.44963500e+07f, 1.46123000e+07f, 1.45744750e+07f, 1.46910500e+07f, 1.46526000e+07f, 1.47698000e+07f, 1.47307250e+07f, 1.48485500e+07f, 1.48088500e+07f, 1.49273000e+07f, 1.48869750e+07f, 1.50060500e+07f, 1.49651000e+07f, 1.50848000e+07f, 1.50432250e+07f, 1.51635500e+07f, 1.51213500e+07f, 1.52423000e+07f, 1.51994750e+07f, 1.53210500e+07f, 1.52776000e+07f, 1.53998000e+07f, 1.53557250e+07f, 1.54785500e+07f, 1.54338500e+07f, 1.55573000e+07f, 1.55119750e+07f, 1.56360500e+07f, 1.55901000e+07f, 1.57148000e+07f, 1.56682250e+07f, 1.57935500e+07f, 1.57463500e+07f, 1.58723000e+07f, 1.58244750e+07f, 1.59510500e+07f, 1.59026000e+07f, 1.60298000e+07f, 1.59807250e+07f, 1.61085500e+07f, 1.60588500e+07f, 1.61873000e+07f, 1.61369750e+07f, 1.62660500e+07f, 1.62151000e+07f, 1.63448000e+07f, 1.62932250e+07f, 1.64235500e+07f, 1.63713500e+07f, 1.65023000e+07f, 1.64494750e+07f, 1.65810500e+07f, 1.65276000e+07f, 1.66598000e+07f, 1.66057250e+07f, 1.67385500e+07f, 1.66838500e+07f, 1.68173000e+07f, 1.67619750e+07f, 1.68960500e+07f, 1.68401000e+07f, 1.69748000e+07f, 1.69182250e+07f, 1.70535500e+07f, 1.69963500e+07f, 1.71323000e+07f, 1.70744750e+07f, 1.72110500e+07f, 1.71526000e+07f, 1.72898000e+07f, 1.72307250e+07f, 1.73685500e+07f, 1.73088500e+07f, 1.74473000e+07f, 1.73869750e+07f, 1.75260500e+07f, 1.74651000e+07f, 1.76048000e+07f, 1.75432250e+07f, 1.76835500e+07f, 2.46688500e+07f, 2.48098000e+07f, 2.47782250e+07f, 2.49198000e+07f, 2.48876000e+07f, 2.50298000e+07f, 2.49969750e+07f, 2.51398000e+07f, 2.51063500e+07f, 2.52498000e+07f, 2.52157250e+07f, 2.53598000e+07f, 2.53251000e+07f, 2.54698000e+07f, 2.54344750e+07f, 2.55798000e+07f, 2.55438500e+07f, 2.56898000e+07f, 2.56532250e+07f, 2.57998000e+07f, 2.57626000e+07f, 2.59098000e+07f, 2.58719750e+07f, 2.60198000e+07f, 2.59813500e+07f, 2.61298000e+07f, 2.60907250e+07f, 2.62398000e+07f, 2.62001000e+07f, 2.63498000e+07f, 2.63094750e+07f, 2.64598000e+07f, 2.64188500e+07f, 2.65698000e+07f, 2.65282250e+07f, 2.66798000e+07f, 2.66376000e+07f, 2.67898000e+07f, 2.67469750e+07f, 2.68998000e+07f, 2.68563500e+07f, 2.70098000e+07f, 2.69657250e+07f, 2.71198000e+07f, 2.70751000e+07f, 2.72298000e+07f, 2.71844750e+07f, 2.73398000e+07f, 2.72938500e+07f, 2.74498000e+07f, 2.74032250e+07f, 2.75598000e+07f, 2.75126000e+07f, 2.76698000e+07f, 2.76219750e+07f, 2.77798000e+07f, 2.77313500e+07f, 2.78898000e+07f, 2.78407250e+07f, 2.79998000e+07f, 2.79501000e+07f, 2.81098000e+07f, 2.80594750e+07f, 2.82198000e+07f, 2.81688500e+07f, 2.83298000e+07f, 2.82782250e+07f, 2.84398000e+07f, 2.83876000e+07f, 2.85498000e+07f, 2.84969750e+07f, 2.86598000e+07f, 2.86063500e+07f, 2.87698000e+07f, 2.87157250e+07f, 2.88798000e+07f, 2.88251000e+07f, 2.89898000e+07f, 2.89344750e+07f, 2.90998000e+07f, 2.90438500e+07f, 2.92098000e+07f, 2.91532250e+07f, 2.93198000e+07f, 2.92626000e+07f, 2.94298000e+07f, 2.93719750e+07f, 2.95398000e+07f, 2.94813500e+07f, 2.96498000e+07f, 2.95907250e+07f, 2.97598000e+07f, 2.97001000e+07f, 2.98698000e+07f, 2.98094750e+07f, 2.99798000e+07f, 2.99188500e+07f, 3.00898000e+07f, 3.00282250e+07f, 3.01998000e+07f, 3.01376000e+07f, 3.03098000e+07f, 3.02469750e+07f, 3.04198000e+07f, 3.03563500e+07f, 3.05298000e+07f, 3.04657250e+07f, 3.06398000e+07f, 3.05751000e+07f, 3.07498000e+07f, 3.06844750e+07f, 3.08598000e+07f, 3.07938500e+07f, 3.09698000e+07f, 3.09032250e+07f, 3.10798000e+07f, 3.10126000e+07f, 3.11898000e+07f, 3.11219750e+07f, 3.12998000e+07f, 3.12313500e+07f, 3.14098000e+07f, 3.13407250e+07f, 3.15198000e+07f, 3.14501000e+07f, 3.16298000e+07f, 3.15594750e+07f, 3.17398000e+07f, 3.16688500e+07f, 3.18498000e+07f, 3.17782250e+07f, 3.19598000e+07f, 3.18876000e+07f, 3.20698000e+07f, 3.19969750e+07f, 3.21798000e+07f, 3.21063500e+07f, 3.22898000e+07f, 3.22157250e+07f, 3.23998000e+07f, 3.23251000e+07f, 3.25098000e+07f, 3.24344750e+07f, 3.26198000e+07f, 3.25438500e+07f, 3.27298000e+07f, 3.26532250e+07f, 3.28398000e+07f, 3.27626000e+07f, 3.29498000e+07}; - - auto a = NDArrayFactory::create('c', {2, 72, 25}); - for (int e = 0; e < a.lengthOf(); e++) - a.p(e, e+1); - - auto b = NDArrayFactory::create('c', {2, 25, 2}); - for (int e = 0; e < b.lengthOf(); e++) - b.p(e, e+1); - - NDArray exp(_expB, _expS); - - auto c = MmulHelper::mmul(&a, &b); - - ASSERT_TRUE(exp.isSameShape(c)); - ASSERT_TRUE(exp.equalsTo(c, 1e1)); - - delete c; + Nd4jLong _expS[] = {3, 2, 72, 2, 144, 2, 1, 8192, 1, 99}; + float _expB[] = { + 1.07250000e+04f, 1.10500000e+04f, 2.63500000e+04f, 2.73000000e+04f, + 4.19750000e+04f, 4.35500000e+04f, 5.76000000e+04f, 5.98000000e+04f, + 7.32250000e+04f, 7.60500000e+04f, 8.88500000e+04f, 9.23000000e+04f, + 1.04475000e+05f, 1.08550000e+05f, 1.20100000e+05f, 1.24800000e+05f, + 1.35725000e+05f, 1.41050000e+05f, 1.51350000e+05f, 1.57300000e+05f, + 1.66975000e+05f, 1.73550000e+05f, 1.82600000e+05f, 1.89800000e+05f, + 1.98225000e+05f, 2.06050000e+05f, 2.13850000e+05f, 2.22300000e+05f, + 2.29475000e+05f, 2.38550000e+05f, 2.45100000e+05f, 2.54800000e+05f, + 2.60725000e+05f, 2.71050000e+05f, 2.76350000e+05f, 2.87300000e+05f, + 2.91975000e+05f, 3.03550000e+05f, 3.07600000e+05f, 3.19800000e+05f, + 3.23225000e+05f, 3.36050000e+05f, 3.38850000e+05f, 3.52300000e+05f, + 3.54475000e+05f, 3.68550000e+05f, 3.70100000e+05f, 3.84800000e+05f, + 3.85725000e+05f, 4.01050000e+05f, 4.01350000e+05f, 4.17300000e+05f, + 4.16975000e+05f, 4.33550000e+05f, 4.32600000e+05f, 4.49800000e+05f, + 4.48225000e+05f, 4.66050000e+05f, 4.63850000e+05f, 4.82300000e+05f, + 4.79475000e+05f, 4.98550000e+05f, 4.95100000e+05f, 5.14800000e+05f, + 5.10725000e+05f, 5.31050000e+05f, 5.26350000e+05f, 5.47300000e+05f, + 5.41975000e+05f, 5.63550000e+05f, 5.57600000e+05f, 5.79800000e+05f, + 5.73225000e+05f, 5.96050000e+05f, 5.88850000e+05f, 6.12300000e+05f, + 6.04475000e+05f, 6.28550000e+05f, 6.20100000e+05f, 6.44800000e+05f, + 6.35725000e+05f, 6.61050000e+05f, 6.51350000e+05f, 6.77300000e+05f, + 6.66975000e+05f, 6.93550000e+05f, 6.82600000e+05f, 7.09800000e+05f, + 6.98225000e+05f, 7.26050000e+05f, 7.13850000e+05f, 7.42300000e+05f, + 7.29475000e+05f, 7.58550000e+05f, 7.45100000e+05f, 7.74800000e+05f, + 7.60725000e+05f, 7.91050000e+05f, 7.76350000e+05f, 8.07300000e+05f, + 7.91975000e+05f, 8.23550000e+05f, 8.07600000e+05f, 8.39800000e+05f, + 8.23225000e+05f, 8.56050000e+05f, 8.38850000e+05f, 8.72300000e+05f, + 8.54475000e+05f, 8.88550000e+05f, 8.70100000e+05f, 9.04800000e+05f, + 8.85725000e+05f, 9.21050000e+05f, 9.01350000e+05f, 9.37300000e+05f, + 9.16975000e+05f, 9.53550000e+05f, 9.32600000e+05f, 9.69800000e+05f, + 9.48225000e+05f, 9.86050000e+05f, 9.63850000e+05f, 1.00230000e+06f, + 9.79475000e+05f, 1.01855000e+06f, 9.95100000e+05f, 1.03480000e+06f, + 1.01072500e+06f, 1.05105000e+06f, 1.02635000e+06f, 1.06730000e+06f, + 1.04197500e+06f, 1.08355000e+06f, 1.05760000e+06f, 1.09980000e+06f, + 1.07322500e+06f, 1.11605000e+06f, 1.08885000e+06f, 1.13230000e+06f, + 1.10447500e+06f, 1.14855000e+06f, 1.12010000e+06f, 1.16480000e+06f, + 1.13572500e+06f, 1.18105000e+06f, 1.15135000e+06f, 1.19730000e+06f, + 1.16697500e+06f, 1.21355000e+06f, 3.54260000e+06f, 3.58980000e+06f, + 3.58947500e+06f, 3.63730000e+06f, 3.63635000e+06f, 3.68480000e+06f, + 3.68322500e+06f, 3.73230000e+06f, 3.73010000e+06f, 3.77980000e+06f, + 3.77697500e+06f, 3.82730000e+06f, 3.82385000e+06f, 3.87480000e+06f, + 3.87072500e+06f, 3.92230000e+06f, 3.91760000e+06f, 3.96980000e+06f, + 3.96447500e+06f, 4.01730000e+06f, 4.01135000e+06f, 4.06480000e+06f, + 4.05822500e+06f, 4.11230000e+06f, 4.10510000e+06f, 4.15980000e+06f, + 4.15197500e+06f, 4.20730000e+06f, 4.19885000e+06f, 4.25480000e+06f, + 4.24572500e+06f, 4.30230000e+06f, 4.29260000e+06f, 4.34980000e+06f, + 4.33947500e+06f, 4.39730000e+06f, 4.38635000e+06f, 4.44480000e+06f, + 4.43322500e+06f, 4.49230000e+06f, 4.48010000e+06f, 4.53980000e+06f, + 4.52697500e+06f, 4.58730000e+06f, 4.57385000e+06f, 4.63480000e+06f, + 4.62072500e+06f, 4.68230000e+06f, 4.66760000e+06f, 4.72980000e+06f, + 4.71447500e+06f, 4.77730000e+06f, 4.76135000e+06f, 4.82480000e+06f, + 4.80822500e+06f, 4.87230000e+06f, 4.85510000e+06f, 4.91980000e+06f, + 4.90197500e+06f, 4.96730000e+06f, 4.94885000e+06f, 5.01480000e+06f, + 4.99572500e+06f, 5.06230000e+06f, 5.04260000e+06f, 5.10980000e+06f, + 5.08947500e+06f, 5.15730000e+06f, 5.13635000e+06f, 5.20480000e+06f, + 5.18322500e+06f, 5.25230000e+06f, 5.23010000e+06f, 5.29980000e+06f, + 5.27697500e+06f, 5.34730000e+06f, 5.32385000e+06f, 5.39480000e+06f, + 5.37072500e+06f, 5.44230000e+06f, 5.41760000e+06f, 5.48980000e+06f, + 5.46447500e+06f, 5.53730000e+06f, 5.51135000e+06f, 5.58480000e+06f, + 5.55822500e+06f, 5.63230000e+06f, 5.60510000e+06f, 5.67980000e+06f, + 5.65197500e+06f, 5.72730000e+06f, 5.69885000e+06f, 5.77480000e+06f, + 5.74572500e+06f, 5.82230000e+06f, 5.79260000e+06f, 5.86980000e+06f, + 5.83947500e+06f, 5.91730000e+06f, 5.88635000e+06f, 5.96480000e+06f, + 5.93322500e+06f, 6.01230000e+06f, 5.98010000e+06f, 6.05980000e+06f, + 6.02697500e+06f, 6.10730000e+06f, 6.07385000e+06f, 6.15480000e+06f, + 6.12072500e+06f, 6.20230000e+06f, 6.16760000e+06f, 6.24980000e+06f, + 6.21447500e+06f, 6.29730000e+06f, 6.26135000e+06f, 6.34480000e+06f, + 6.30822500e+06f, 6.39230000e+06f, 6.35510000e+06f, 6.43980000e+06f, + 6.40197500e+06f, 6.48730000e+06f, 6.44885000e+06f, 6.53480000e+06f, + 6.49572500e+06f, 6.58230000e+06f, 6.54260000e+06f, 6.62980000e+06f, + 6.58947500e+06f, 6.67730000e+06f, 6.63635000e+06f, 6.72480000e+06f, + 6.68322500e+06f, 6.77230000e+06f, 6.73010000e+06f, 6.81980000e+06f, + 6.77697500e+06f, 6.86730000e+06f, 6.82385000e+06f, 6.91480000e+06f, + 6.87072500e+06f, 6.96230000e+06f, 6.91760000e+06f, 7.00980000e+06f, + 6.96447500e+06f, 7.05730000e+06f, 7.01135000e+06f, 7.10480000e+06f, + 1.17619750e+07f, 1.18560500e+07f, 1.18401000e+07f, 1.19348000e+07f, + 1.19182250e+07f, 1.20135500e+07f, 1.19963500e+07f, 1.20923000e+07f, + 1.20744750e+07f, 1.21710500e+07f, 1.21526000e+07f, 1.22498000e+07f, + 1.22307250e+07f, 1.23285500e+07f, 1.23088500e+07f, 1.24073000e+07f, + 1.23869750e+07f, 1.24860500e+07f, 1.24651000e+07f, 1.25648000e+07f, + 1.25432250e+07f, 1.26435500e+07f, 1.26213500e+07f, 1.27223000e+07f, + 1.26994750e+07f, 1.28010500e+07f, 1.27776000e+07f, 1.28798000e+07f, + 1.28557250e+07f, 1.29585500e+07f, 1.29338500e+07f, 1.30373000e+07f, + 1.30119750e+07f, 1.31160500e+07f, 1.30901000e+07f, 1.31948000e+07f, + 1.31682250e+07f, 1.32735500e+07f, 1.32463500e+07f, 1.33523000e+07f, + 1.33244750e+07f, 1.34310500e+07f, 1.34026000e+07f, 1.35098000e+07f, + 1.34807250e+07f, 1.35885500e+07f, 1.35588500e+07f, 1.36673000e+07f, + 1.36369750e+07f, 1.37460500e+07f, 1.37151000e+07f, 1.38248000e+07f, + 1.37932250e+07f, 1.39035500e+07f, 1.38713500e+07f, 1.39823000e+07f, + 1.39494750e+07f, 1.40610500e+07f, 1.40276000e+07f, 1.41398000e+07f, + 1.41057250e+07f, 1.42185500e+07f, 1.41838500e+07f, 1.42973000e+07f, + 1.42619750e+07f, 1.43760500e+07f, 1.43401000e+07f, 1.44548000e+07f, + 1.44182250e+07f, 1.45335500e+07f, 1.44963500e+07f, 1.46123000e+07f, + 1.45744750e+07f, 1.46910500e+07f, 1.46526000e+07f, 1.47698000e+07f, + 1.47307250e+07f, 1.48485500e+07f, 1.48088500e+07f, 1.49273000e+07f, + 1.48869750e+07f, 1.50060500e+07f, 1.49651000e+07f, 1.50848000e+07f, + 1.50432250e+07f, 1.51635500e+07f, 1.51213500e+07f, 1.52423000e+07f, + 1.51994750e+07f, 1.53210500e+07f, 1.52776000e+07f, 1.53998000e+07f, + 1.53557250e+07f, 1.54785500e+07f, 1.54338500e+07f, 1.55573000e+07f, + 1.55119750e+07f, 1.56360500e+07f, 1.55901000e+07f, 1.57148000e+07f, + 1.56682250e+07f, 1.57935500e+07f, 1.57463500e+07f, 1.58723000e+07f, + 1.58244750e+07f, 1.59510500e+07f, 1.59026000e+07f, 1.60298000e+07f, + 1.59807250e+07f, 1.61085500e+07f, 1.60588500e+07f, 1.61873000e+07f, + 1.61369750e+07f, 1.62660500e+07f, 1.62151000e+07f, 1.63448000e+07f, + 1.62932250e+07f, 1.64235500e+07f, 1.63713500e+07f, 1.65023000e+07f, + 1.64494750e+07f, 1.65810500e+07f, 1.65276000e+07f, 1.66598000e+07f, + 1.66057250e+07f, 1.67385500e+07f, 1.66838500e+07f, 1.68173000e+07f, + 1.67619750e+07f, 1.68960500e+07f, 1.68401000e+07f, 1.69748000e+07f, + 1.69182250e+07f, 1.70535500e+07f, 1.69963500e+07f, 1.71323000e+07f, + 1.70744750e+07f, 1.72110500e+07f, 1.71526000e+07f, 1.72898000e+07f, + 1.72307250e+07f, 1.73685500e+07f, 1.73088500e+07f, 1.74473000e+07f, + 1.73869750e+07f, 1.75260500e+07f, 1.74651000e+07f, 1.76048000e+07f, + 1.75432250e+07f, 1.76835500e+07f, 2.46688500e+07f, 2.48098000e+07f, + 2.47782250e+07f, 2.49198000e+07f, 2.48876000e+07f, 2.50298000e+07f, + 2.49969750e+07f, 2.51398000e+07f, 2.51063500e+07f, 2.52498000e+07f, + 2.52157250e+07f, 2.53598000e+07f, 2.53251000e+07f, 2.54698000e+07f, + 2.54344750e+07f, 2.55798000e+07f, 2.55438500e+07f, 2.56898000e+07f, + 2.56532250e+07f, 2.57998000e+07f, 2.57626000e+07f, 2.59098000e+07f, + 2.58719750e+07f, 2.60198000e+07f, 2.59813500e+07f, 2.61298000e+07f, + 2.60907250e+07f, 2.62398000e+07f, 2.62001000e+07f, 2.63498000e+07f, + 2.63094750e+07f, 2.64598000e+07f, 2.64188500e+07f, 2.65698000e+07f, + 2.65282250e+07f, 2.66798000e+07f, 2.66376000e+07f, 2.67898000e+07f, + 2.67469750e+07f, 2.68998000e+07f, 2.68563500e+07f, 2.70098000e+07f, + 2.69657250e+07f, 2.71198000e+07f, 2.70751000e+07f, 2.72298000e+07f, + 2.71844750e+07f, 2.73398000e+07f, 2.72938500e+07f, 2.74498000e+07f, + 2.74032250e+07f, 2.75598000e+07f, 2.75126000e+07f, 2.76698000e+07f, + 2.76219750e+07f, 2.77798000e+07f, 2.77313500e+07f, 2.78898000e+07f, + 2.78407250e+07f, 2.79998000e+07f, 2.79501000e+07f, 2.81098000e+07f, + 2.80594750e+07f, 2.82198000e+07f, 2.81688500e+07f, 2.83298000e+07f, + 2.82782250e+07f, 2.84398000e+07f, 2.83876000e+07f, 2.85498000e+07f, + 2.84969750e+07f, 2.86598000e+07f, 2.86063500e+07f, 2.87698000e+07f, + 2.87157250e+07f, 2.88798000e+07f, 2.88251000e+07f, 2.89898000e+07f, + 2.89344750e+07f, 2.90998000e+07f, 2.90438500e+07f, 2.92098000e+07f, + 2.91532250e+07f, 2.93198000e+07f, 2.92626000e+07f, 2.94298000e+07f, + 2.93719750e+07f, 2.95398000e+07f, 2.94813500e+07f, 2.96498000e+07f, + 2.95907250e+07f, 2.97598000e+07f, 2.97001000e+07f, 2.98698000e+07f, + 2.98094750e+07f, 2.99798000e+07f, 2.99188500e+07f, 3.00898000e+07f, + 3.00282250e+07f, 3.01998000e+07f, 3.01376000e+07f, 3.03098000e+07f, + 3.02469750e+07f, 3.04198000e+07f, 3.03563500e+07f, 3.05298000e+07f, + 3.04657250e+07f, 3.06398000e+07f, 3.05751000e+07f, 3.07498000e+07f, + 3.06844750e+07f, 3.08598000e+07f, 3.07938500e+07f, 3.09698000e+07f, + 3.09032250e+07f, 3.10798000e+07f, 3.10126000e+07f, 3.11898000e+07f, + 3.11219750e+07f, 3.12998000e+07f, 3.12313500e+07f, 3.14098000e+07f, + 3.13407250e+07f, 3.15198000e+07f, 3.14501000e+07f, 3.16298000e+07f, + 3.15594750e+07f, 3.17398000e+07f, 3.16688500e+07f, 3.18498000e+07f, + 3.17782250e+07f, 3.19598000e+07f, 3.18876000e+07f, 3.20698000e+07f, + 3.19969750e+07f, 3.21798000e+07f, 3.21063500e+07f, 3.22898000e+07f, + 3.22157250e+07f, 3.23998000e+07f, 3.23251000e+07f, 3.25098000e+07f, + 3.24344750e+07f, 3.26198000e+07f, 3.25438500e+07f, 3.27298000e+07f, + 3.26532250e+07f, 3.28398000e+07f, 3.27626000e+07f, 3.29498000e+07}; + + auto a = NDArrayFactory::create('c', {2, 72, 25}); + for (int e = 0; e < a.lengthOf(); e++) a.p(e, e + 1); + + auto b = NDArrayFactory::create('c', {2, 25, 2}); + for (int e = 0; e < b.lengthOf(); e++) b.p(e, e + 1); + + NDArray exp(_expB, _expS); + + auto c = MmulHelper::mmul(&a, &b); + + ASSERT_TRUE(exp.isSameShape(c)); + ASSERT_TRUE(exp.equalsTo(c, 1e1)); + + delete c; } - TEST_F(NDArrayTest, TestNegSize1) { - auto array = NDArrayFactory::create('c', {2, 5, 7}); + auto array = NDArrayFactory::create('c', {2, 5, 7}); - ASSERT_EQ(7, array.sizeAt(-1)); - ASSERT_EQ(5, array.sizeAt(-2)); - ASSERT_EQ(2, array.sizeAt(-3)); + ASSERT_EQ(7, array.sizeAt(-1)); + ASSERT_EQ(5, array.sizeAt(-2)); + ASSERT_EQ(2, array.sizeAt(-3)); } ////////////////////////////////////////////////////////////////////// // not-in-place TEST_F(NDArrayTest, Permute1) { + Nd4jLong shape1[] = {3, 5, 10, 15, 150, 15, 1, 8192, 1, 99}; + Nd4jLong shape2[] = {3, 15, 5, 10, 1, 150, 15, 8192, 0, 99}; + const std::initializer_list perm = {2, 0, 1}; - Nd4jLong shape1[] = {3, 5, 10, 15, 150, 15, 1, 8192, 1, 99}; - Nd4jLong shape2[] = {3, 15, 5, 10, 1, 150, 15, 8192, 0, 99}; - const std::initializer_list perm = {2, 0, 1}; - - NDArray arr1(shape1,true); - NDArray arr2(shape2,true); + NDArray arr1(shape1, true); + NDArray arr2(shape2, true); - auto result = arr1.permute(perm); - ASSERT_TRUE(result.isSameShapeStrict(arr2)); + auto result = arr1.permute(perm); + ASSERT_TRUE(result.isSameShapeStrict(arr2)); } ////////////////////////////////////////////////////////////////////// // in-place TEST_F(NDArrayTest, Permute2) { + Nd4jLong shape1[] = {3, 5, 10, 15, 150, 15, 1, 8192, 1, 99}; + Nd4jLong shape2[] = {3, 15, 5, 10, 1, 150, 15, 8192, 0, 99}; + const std::initializer_list perm = {2, 0, 1}; - Nd4jLong shape1[] = {3, 5, 10, 15, 150, 15, 1, 8192, 1, 99}; - Nd4jLong shape2[] = {3, 15, 5, 10, 1, 150, 15, 8192, 0, 99}; - const std::initializer_list perm = {2, 0, 1}; + NDArray arr1(shape1, true); + NDArray arr2(shape2, true); - NDArray arr1(shape1,true); - NDArray arr2(shape2,true); - - ASSERT_TRUE(arr1.permutei(perm)); - ASSERT_TRUE(arr1.isSameShapeStrict(arr2)); + ASSERT_TRUE(arr1.permutei(perm)); + ASSERT_TRUE(arr1.isSameShapeStrict(arr2)); } TEST_F(NDArrayTest, RSubScalarTest1) { - auto array = NDArrayFactory::create('c', {1, 4}); - array.assign(2.0); + auto array = NDArrayFactory::create('c', {1, 4}); + array.assign(2.0); - auto result = NDArrayFactory::create('c', {1, 4}); + auto result = NDArrayFactory::create('c', {1, 4}); - array.applyScalar(scalar::ReverseSubtract, 1.0, result); + array.applyScalar(scalar::ReverseSubtract, 1.0, result); - ASSERT_NEAR(-1.0, result.meanNumber().e(0), 1e-5); + ASSERT_NEAR(-1.0, result.meanNumber().e(0), 1e-5); } TEST_F(NDArrayTest, BroadcastOpsTest1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto row = NDArrayFactory::linspace(1.0f, 5.0f, 5); + float *brow = new float[5]{1, 2, 3, 4, 5}; + auto bshape = new Nd4jLong[8]{2, 1, 5, 1, 1, 8192, 1, 99}; + float *ebuf = new float[25]{1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, + 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}; + auto eshape = new Nd4jLong[8]{2, 5, 5, 5, 1, 8192, 1, 99}; + NDArray expRow(brow, bshape); + NDArray exp(ebuf, eshape); - auto x = NDArrayFactory::create('c', {5, 5}); - auto row = NDArrayFactory::linspace(1.0f, 5.0f, 5); - float *brow = new float[5]{1,2,3,4,5}; - auto bshape = new Nd4jLong[8]{2, 1, 5, 1, 1, 8192, 1, 99}; - float *ebuf = new float[25] {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}; - auto eshape = new Nd4jLong[8] {2, 5, 5, 5, 1, 8192, 1, 99}; - NDArray expRow(brow, bshape); - NDArray exp(ebuf, eshape); - - ASSERT_TRUE(row->equalsTo(&expRow)); + ASSERT_TRUE(row->equalsTo(&expRow)); + x.applyBroadcast(broadcast::Add, {1}, *row, x); - x.applyBroadcast(broadcast::Add, {1}, *row, x); + // x.printBuffer("Result"); - //x.printBuffer("Result"); + ASSERT_TRUE(x.equalsTo(&exp)); - ASSERT_TRUE(x.equalsTo(&exp)); - - delete[] brow; - delete[] bshape; - delete[] ebuf; - delete[] eshape; - delete row; + delete[] brow; + delete[] bshape; + delete[] ebuf; + delete[] eshape; + delete row; } TEST_F(NDArrayTest, TestIndexedPut2) { - auto x = NDArrayFactory::create('f', {2, 2}); - //x.printShapeInfo("x shape"); - x.p(1, 1.0f); + auto x = NDArrayFactory::create('f', {2, 2}); + // x.printShapeInfo("x shape"); + x.p(1, 1.0f); - //x.printBuffer("after"); - ASSERT_NEAR(reinterpret_cast(x.buffer())[2], 1.0, 1e-5); + // x.printBuffer("after"); + ASSERT_NEAR(reinterpret_cast(x.buffer())[2], 1.0, 1e-5); } TEST_F(NDArrayTest, TestIndexedPut3) { - auto x = NDArrayFactory::create('c', {2, 2}); - x.p(1, 1.0f); + auto x = NDArrayFactory::create('c', {2, 2}); + x.p(1, 1.0f); - //x.printBuffer("after"); - ASSERT_NEAR(reinterpret_cast(x.buffer())[1], 1.0, 1e-5); + // x.printBuffer("after"); + ASSERT_NEAR(reinterpret_cast(x.buffer())[1], 1.0, 1e-5); } TEST_F(NDArrayTest, TestIndexedPut4) { - auto x = NDArrayFactory::create('f', {2, 2}); - x.p(0, 1, 1.0f); + auto x = NDArrayFactory::create('f', {2, 2}); + x.p(0, 1, 1.0f); - //x.printBuffer("after"); - ASSERT_NEAR(reinterpret_cast(x.buffer())[2], 1.0, 1e-5); + // x.printBuffer("after"); + ASSERT_NEAR(reinterpret_cast(x.buffer())[2], 1.0, 1e-5); } - TEST_F(NDArrayTest, TestIndexedPut5) { - auto x = NDArrayFactory::create('c', {2, 2}); - x.p(0, 1, 1.0f); + auto x = NDArrayFactory::create('c', {2, 2}); + x.p(0, 1, 1.0f); - //x.printBuffer("after"); - ASSERT_NEAR(x.bufferAsT()[1], 1.0, 1e-5); + // x.printBuffer("after"); + ASSERT_NEAR(x.bufferAsT()[1], 1.0, 1e-5); } TEST_F(NDArrayTest, TestAllTensors1) { - auto matrix = NDArrayFactory::create('c', {3, 5}); + auto matrix = NDArrayFactory::create('c', {3, 5}); - ResultSet rows = matrix.allTensorsAlongDimension({1}); + ResultSet rows = matrix.allTensorsAlongDimension({1}); - ASSERT_EQ(3, rows.size()); + ASSERT_EQ(3, rows.size()); } - TEST_F(NDArrayTest, TestIndexing1) { - auto matrix = NDArrayFactory::create('c', {5, 5}); - for (int e = 0; e < matrix.lengthOf(); e++) - matrix.p(e, (float) e); + auto matrix = NDArrayFactory::create('c', {5, 5}); + for (int e = 0; e < matrix.lengthOf(); e++) matrix.p(e, (float)e); - auto sub = matrix({2,4, 0,0}, true); + auto sub = matrix({2, 4, 0, 0}, true); - ASSERT_EQ(2, sub.rows()); - ASSERT_EQ(5, sub.columns()); + ASSERT_EQ(2, sub.rows()); + ASSERT_EQ(5, sub.columns()); - ASSERT_NEAR(10, sub.e(0), 1e-5); + ASSERT_NEAR(10, sub.e(0), 1e-5); } - TEST_F(NDArrayTest, TestIndexing2) { - auto matrix = NDArrayFactory::create('c', {2, 5, 4, 4}); - matrix.linspace(0); + auto matrix = NDArrayFactory::create('c', {2, 5, 4, 4}); + matrix.linspace(0); - auto sub = matrix({0,0, 2,4, 0,0, 0,0}, true); + auto sub = matrix({0, 0, 2, 4, 0, 0, 0, 0}, true); - ASSERT_EQ(2, sub.sizeAt(0)); - ASSERT_EQ(2, sub.sizeAt(1)); - ASSERT_EQ(4, sub.sizeAt(2)); - ASSERT_EQ(4, sub.sizeAt(3)); + ASSERT_EQ(2, sub.sizeAt(0)); + ASSERT_EQ(2, sub.sizeAt(1)); + ASSERT_EQ(4, sub.sizeAt(2)); + ASSERT_EQ(4, sub.sizeAt(3)); - ASSERT_EQ(64, sub.lengthOf()); - ASSERT_NEAR(32, sub.e(0), 1e-5); - ASSERT_NEAR(112, sub.e(32), 1e-5); + ASSERT_EQ(64, sub.lengthOf()); + ASSERT_NEAR(32, sub.e(0), 1e-5); + ASSERT_NEAR(112, sub.e(32), 1e-5); } TEST_F(NDArrayTest, TestIndexing3) { - auto matrix = NDArrayFactory::create('c', {5, 5}); - matrix.linspace(0); + auto matrix = NDArrayFactory::create('c', {5, 5}); + matrix.linspace(0); - auto sub = matrix({2,4, 0,0}); + auto sub = matrix({2, 4, 0, 0}); - ASSERT_EQ(2, sub.rows()); - ASSERT_EQ(5, sub.columns()); + ASSERT_EQ(2, sub.rows()); + ASSERT_EQ(5, sub.columns()); - ASSERT_NEAR(10, sub.e(0), 1e-5); + ASSERT_NEAR(10, sub.e(0), 1e-5); } - TEST_F(NDArrayTest, TestIndexing4) { - auto matrix = NDArrayFactory::create('c', {2, 5, 4, 4}); - matrix.linspace(0); - - auto sub = matrix({0,0, 2,4, 0,0, 0,0}); + auto matrix = NDArrayFactory::create('c', {2, 5, 4, 4}); + matrix.linspace(0); - ASSERT_EQ(2, sub.sizeAt(0)); - ASSERT_EQ(2, sub.sizeAt(1)); - ASSERT_EQ(4, sub.sizeAt(2)); - ASSERT_EQ(4, sub.sizeAt(3)); + auto sub = matrix({0, 0, 2, 4, 0, 0, 0, 0}); + ASSERT_EQ(2, sub.sizeAt(0)); + ASSERT_EQ(2, sub.sizeAt(1)); + ASSERT_EQ(4, sub.sizeAt(2)); + ASSERT_EQ(4, sub.sizeAt(3)); - ASSERT_EQ(64, sub.lengthOf()); - ASSERT_NEAR(32, sub.e(0), 1e-5); - ASSERT_NEAR(112, sub.e(32), 1e-5); + ASSERT_EQ(64, sub.lengthOf()); + ASSERT_NEAR(32, sub.e(0), 1e-5); + ASSERT_NEAR(112, sub.e(32), 1e-5); } TEST_F(NDArrayTest, TestReshapeNegative1) { - std::unique_ptr array(NDArrayFactory::create_('c', {2, 3, 4, 64})); + std::unique_ptr array( + NDArrayFactory::create_('c', {2, 3, 4, 64})); - array->reshapei('c', {-1, 64}); + array->reshapei('c', {-1, 64}); - ASSERT_EQ(24, array->sizeAt(0)); - ASSERT_EQ(64, array->sizeAt(1)); + ASSERT_EQ(24, array->sizeAt(0)); + ASSERT_EQ(64, array->sizeAt(1)); } TEST_F(NDArrayTest, TestReshapeNegative2) { - std::unique_ptr array(NDArrayFactory::create_('c', {2, 3, 4, 64})); + std::unique_ptr array( + NDArrayFactory::create_('c', {2, 3, 4, 64})); - auto reshaped = array->reshape('c', {-1, 64}); + auto reshaped = array->reshape('c', {-1, 64}); - ASSERT_EQ(24, reshaped.sizeAt(0)); - ASSERT_EQ(64, reshaped.sizeAt(1)); + ASSERT_EQ(24, reshaped.sizeAt(0)); + ASSERT_EQ(64, reshaped.sizeAt(1)); } ////////////////////////////////////////////////////////////////////// // TEST_F(NDArrayTest, SVD1) { // double arrA[8] = {1, 2, 3, 4, 5, 6, 7, 8}; -// double arrU[8] = {-0.822647, 0.152483, -0.421375, 0.349918, -0.020103, 0.547354, 0.381169, 0.744789}; -// double arrS[2] = {0.626828, 14.269095}; +// double arrU[8] = {-0.822647, 0.152483, -0.421375, 0.349918, -0.020103, +// 0.547354, 0.381169, 0.744789}; double arrS[2] = {0.626828, 14.269095}; // double arrVt[4] = {0.767187,-0.641423, 0.641423, 0.767187}; // int shapeA[8] = {2, 4, 2, 2, 1, 0, 1, 99}; @@ -1445,9 +1519,10 @@ TEST_F(NDArrayTest, TestReshapeNegative2) { // TEST_F(NDArrayTest, SVD2) { // double arrA[6] = {1, 2, 3, 4, 5, 6}; -// double arrU[6] = {-0.386318, -0.922366, 0.000000, -0.922366, 0.386318, 0.000000}; -// double arrS[3] = {9.508032, 0.77287, 0.000}; -// double arrVt[9] = {-0.428667, -0.566307, -0.703947, 0.805964, 0.112382, -0.581199, 0.408248, -0.816497, 0.408248}; +// double arrU[6] = {-0.386318, -0.922366, 0.000000, -0.922366, 0.386318, +// 0.000000}; double arrS[3] = {9.508032, 0.77287, 0.000}; double arrVt[9] = +// {-0.428667, -0.566307, -0.703947, 0.805964, 0.112382, -0.581199, 0.408248, +// -0.816497, 0.408248}; // int shapeA[8] = {2, 2, 3, 3, 1, 0, 1, 99}; // int shapeS[8] = {2, 1, 3, 3, 1, 0, 1, 99}; @@ -1470,8 +1545,8 @@ TEST_F(NDArrayTest, TestReshapeNegative2) { // TEST_F(NDArrayTest, SVD3) { // double arrA[8] = {1, 2, 3, 4, 5, 6, 7, 8}; -// double arrU[8] = {-0.822647, 0.152483, -0.421375, 0.349918, -0.020103, 0.547354, 0.381169, 0.744789}; -// double arrS[2] = {0.626828, 14.269095}; +// double arrU[8] = {-0.822647, 0.152483, -0.421375, 0.349918, -0.020103, +// 0.547354, 0.381169, 0.744789}; double arrS[2] = {0.626828, 14.269095}; // double arrVt[4] = {0.767187,-0.641423, 0.641423, 0.767187}; // int shapeA[8] = {2, 4, 2, 2, 1, 0, 1, 99}; @@ -1495,9 +1570,10 @@ TEST_F(NDArrayTest, TestReshapeNegative2) { // TEST_F(NDArrayTest, SVD4) { // double arrA[6] = {1, 2, 3, 4, 5, 6}; -// double arrU[6] = {-0.386318, -0.922366, 0.000000, -0.922366, 0.386318, 0.000000}; -// double arrS[3] = {9.508032, 0.77287, 0.000}; -// double arrVt[9] = {-0.428667, -0.566307, -0.703947, 0.805964, 0.112382, -0.581199, 0.408248, -0.816497, 0.408248}; +// double arrU[6] = {-0.386318, -0.922366, 0.000000, -0.922366, 0.386318, +// 0.000000}; double arrS[3] = {9.508032, 0.77287, 0.000}; double arrVt[9] = +// {-0.428667, -0.566307, -0.703947, 0.805964, 0.112382, -0.581199, 0.408248, +// -0.816497, 0.408248}; // int shapeA[8] = {2, 2, 3, 3, 1, 0, 1, 99}; // int shapeS[8] = {2, 1, 3, 3, 1, 0, 1, 99}; @@ -1518,1176 +1594,1171 @@ TEST_F(NDArrayTest, TestReshapeNegative2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestStdDev1) { - auto array = NDArrayFactory::create('c', {1, 5}); - for (int e = 0; e < array.lengthOf(); e++) - array.p(e, e+1); + auto array = NDArrayFactory::create('c', {1, 5}); + for (int e = 0; e < array.lengthOf(); e++) array.p(e, e + 1); - auto std = array.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); - ASSERT_NEAR(std, 1.58109, 1e-4); + auto std = array.varianceNumber(variance::SummaryStatsStandardDeviation, true) + .e(0); + ASSERT_NEAR(std, 1.58109, 1e-4); } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestStdDev2) { - auto array = NDArrayFactory::create('c', {5, 6}); - auto tad = array(0, {1}); + auto array = NDArrayFactory::create('c', {5, 6}); + auto tad = array(0, {1}); - ASSERT_EQ(5, tad.lengthOf()); + ASSERT_EQ(5, tad.lengthOf()); - for (int e = 0; e < tad.lengthOf(); e++) - tad.p(e, e+1); + for (int e = 0; e < tad.lengthOf(); e++) tad.p(e, e + 1); - ASSERT_NEAR(15, tad.sumNumber().e(0), 1e-5); + ASSERT_NEAR(15, tad.sumNumber().e(0), 1e-5); - auto std = tad.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); - ASSERT_NEAR(std, 1.58109, 1e-4); + auto std = tad.varianceNumber(variance::SummaryStatsStandardDeviation, true) + .e(0); + ASSERT_NEAR(std, 1.58109, 1e-4); } TEST_F(NDArrayTest, TestStdDev3) { - auto array = NDArrayFactory::create('c', {1, 50000}); - for (int e = 0; e < array.lengthOf(); e++) - array.p(e, 1.f + (e%2?0.5f:-0.5f)); + auto array = NDArrayFactory::create('c', {1, 50000}); + for (int e = 0; e < array.lengthOf(); e++) + array.p(e, 1.f + (e % 2 ? 0.5f : -0.5f)); - auto std = array.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); - // nd4j_printf("Variance is %f\n", std); - ASSERT_NEAR(std, 0.5f, 1.0e-5f); + auto std = array.varianceNumber(variance::SummaryStatsStandardDeviation, true) + .e(0); + // nd4j_printf("Variance is %f\n", std); + ASSERT_NEAR(std, 0.5f, 1.0e-5f); } TEST_F(NDArrayTest, TestStdDev4) { - auto array = NDArrayFactory::create('c', {1, 20000}); - float const ethalon = 1 / 3.f; - float x = ethalon; - int total = array.lengthOf(); - for (int e = 0; e < total; e++) { - array.p(e, 1.0f + (e % 2?ethalon:-ethalon)); - x *= (e % 2? 2.f: 0.5f); - } - x = 0.f; - for (int e = 0; e < total; ++e) { - x += array.e(e); - } - x /= array.lengthOf(); - float y = 0; - double M2 = 0; - for (int e = 0; e < total; ++e) { + auto array = NDArrayFactory::create('c', {1, 20000}); + float const ethalon = 1 / 3.f; + float x = ethalon; + int total = array.lengthOf(); + for (int e = 0; e < total; e++) { + array.p(e, 1.0f + (e % 2 ? ethalon : -ethalon)); + x *= (e % 2 ? 2.f : 0.5f); + } + x = 0.f; + for (int e = 0; e < total; ++e) { + x += array.e(e); + } + x /= array.lengthOf(); + float y = 0; + double M2 = 0; + for (int e = 0; e < total; ++e) { // y += sd::math::nd4j_abs(array(e) - x); - M2 += (array.e(e) - x) * (array.e(e) - x); - } - //y /= total; - M2 /= total; + M2 += (array.e(e) - x) * (array.e(e) - x); + } + // y /= total; + M2 /= total; - y = M2; - auto a = array.varianceNumber(variance::SummaryStatsStandardDeviation, false); - auto std = a.e(0); -// float bY = array.varianceNumber(); - float bY = 0.3333333f; - // nd4j_printf("Variance is %f, res is %f, internal is %f\n, deviance is %f(%f)\n", std, x, bY, y, sd::math::nd4j_sqrt(M2)); - ASSERT_NEAR(std, 0.3333333f, 1.0e-5f); + y = M2; + auto a = array.varianceNumber(variance::SummaryStatsStandardDeviation, false); + auto std = a.e(0); + // float bY = array.varianceNumber(); + float bY = 0.3333333f; + // nd4j_printf("Variance is %f, res is %f, internal is %f\n, deviance is + // %f(%f)\n", std, x, bY, y, sd::math::nd4j_sqrt(M2)); + ASSERT_NEAR(std, 0.3333333f, 1.0e-5f); } TEST_F(NDArrayTest, TestStdDev5) { - auto array = NDArrayFactory::create('c', {1, 10000}); //00000}); - auto arrayD = NDArrayFactory::create('c', {1, 10000}); //00000}); - for (int e = 0; e < array.lengthOf(); e++) { - array.p(e, 1.f + (e%2?1/5.f:-1/5.f)); - arrayD.p(e, 1.0 + (e%2?1/5.:-1/5.)); - } - float stdF = array.varianceNumber(variance::SummaryStatsStandardDeviation, false).e(0); - double stdD = arrayD.varianceNumber(variance::SummaryStatsStandardDeviation, false).e(0); - // nd4j_printf("Variance is %f(%f)\n", stdF, stdD); - ASSERT_NEAR(stdD, 0.2, 1.0e-8); // 1/5 = 0.2 - ASSERT_NEAR(stdF, 0.2f, 1.0e-5f); // 1/5 = 0.2 + auto array = NDArrayFactory::create('c', {1, 10000}); // 00000}); + auto arrayD = NDArrayFactory::create('c', {1, 10000}); // 00000}); + for (int e = 0; e < array.lengthOf(); e++) { + array.p(e, 1.f + (e % 2 ? 1 / 5.f : -1 / 5.f)); + arrayD.p(e, 1.0 + (e % 2 ? 1 / 5. : -1 / 5.)); + } + float stdF = + array.varianceNumber(variance::SummaryStatsStandardDeviation, false) + .e(0); + double stdD = + arrayD.varianceNumber(variance::SummaryStatsStandardDeviation, false) + .e(0); + // nd4j_printf("Variance is %f(%f)\n", stdF, stdD); + ASSERT_NEAR(stdD, 0.2, 1.0e-8); // 1/5 = 0.2 + ASSERT_NEAR(stdF, 0.2f, 1.0e-5f); // 1/5 = 0.2 } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestApplyIndexReduce1) { - float xBuff[] = {1, 5, 2, 12, 9, 3, 10, 7, 4, 11, 6, 8}; - Nd4jLong xShapeInfo[] = {3, 2, 3, 2, 6, 2, 1, 8192, 1, 99}; - std::vector dim = {0,1}; + float xBuff[] = {1, 5, 2, 12, 9, 3, 10, 7, 4, 11, 6, 8}; + Nd4jLong xShapeInfo[] = {3, 2, 3, 2, 6, 2, 1, 8192, 1, 99}; + std::vector dim = {0, 1}; - NDArray x(xBuff, xShapeInfo); - auto exp = NDArrayFactory::create({3, 1}); + NDArray x(xBuff, xShapeInfo); + auto exp = NDArrayFactory::create({3, 1}); - auto result = x.applyIndexReduce(indexreduce::IndexMax, dim); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = x.applyIndexReduce(indexreduce::IndexMax, dim); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, applyReduce3Dot) { - float xBuff[] = {1, 2, 3, 4, 5, 6}; - float yBuff[] = {2, 2, 2, 2, 2, 2}; - Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float yBuff[] = {2, 2, 2, 2, 2, 2}; + Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; - NDArray x(xBuff, xShapeInfo); - NDArray y(yBuff, xShapeInfo); + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, xShapeInfo); - auto result = x.applyReduce3(reduce3::Dot, y); - ASSERT_TRUE(result.lengthOf() == 1); - ASSERT_NEAR(42, result.e(0), 1e-5); + auto result = x.applyReduce3(reduce3::Dot, y); + ASSERT_TRUE(result.lengthOf() == 1); + ASSERT_NEAR(42, result.e(0), 1e-5); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, applyAllReduce3EuclideanDistance) { - float xBuff[] = {1, 2, 3, 4, 5, 6}; - float yBuff[] = {2, 2, 2, 2, 2, 2}; - float expBuff[] = {1.414214f, 1.414214f, 5.385165f, 5.385165f}; - Nd4jLong expShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; - Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float yBuff[] = {2, 2, 2, 2, 2, 2}; + float expBuff[] = {1.414214f, 1.414214f, 5.385165f, 5.385165f}; + Nd4jLong expShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; - NDArray x(xBuff, xShapeInfo); - NDArray y(yBuff, xShapeInfo); - auto exp = NDArrayFactory::create('c', {2, 2}, {1.414214f, 1.414214f, 5.385165f, 5.385165f}); + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, xShapeInfo); + auto exp = NDArrayFactory::create( + 'c', {2, 2}, {1.414214f, 1.414214f, 5.385165f, 5.385165f}); - auto result = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); + auto result = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, applyReduce3EuclideanDistance) { - float xBuff[] = {1, 2, 3, 4, 5, 6}; - float yBuff[] = {2, 2, 2, 2, 2, 2}; - float expBuff[] = {1.414214f, 1.414214f, 5.385165f, 5.385165f}; - Nd4jLong expShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; - Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float yBuff[] = {2, 2, 2, 2, 2, 2}; + float expBuff[] = {1.414214f, 1.414214f, 5.385165f, 5.385165f}; + Nd4jLong expShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; - NDArray x(xBuff, xShapeInfo); - NDArray y(yBuff, xShapeInfo); - NDArray exp(expBuff, expShapeInfo); + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, xShapeInfo); + NDArray exp(expBuff, expShapeInfo); - auto result = x.applyAllReduce3(reduce3::EuclideanDistance, y ,{1}); + auto result = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVarianceAlongDimension1) { + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float expBuff[] = {0.816497f, 0.816497f}; + Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {1, 2, 1, 8192, 1, 99}; - float xBuff[] = {1, 2, 3, 4, 5, 6}; - float expBuff[] = {0.816497f, 0.816497f}; - Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; - Nd4jLong expShapeInfo[] = {1, 2, 1, 8192, 1, 99}; + NDArray x(xBuff, xShapeInfo); + NDArray exp(expBuff, expShapeInfo); - NDArray x(xBuff, xShapeInfo); - NDArray exp(expBuff, expShapeInfo); + auto result = x.varianceAlongDimension( + variance::SummaryStatsStandardDeviation, false, {1}); - auto result = x.varianceAlongDimension(variance::SummaryStatsStandardDeviation, false, {1}); - - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVarianceAlongDimension2) { - float xBuff[] = {1, 2, 3, 4, 5, 6}; - float expBuff[] = {0.666667f, 0.666667f}; - Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; - Nd4jLong expShapeInfo[] = {1, 2, 1, 8192, 1, 99}; - + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float expBuff[] = {0.666667f, 0.666667f}; + Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {1, 2, 1, 8192, 1, 99}; - NDArray x(xBuff, xShapeInfo); - NDArray exp(expBuff, expShapeInfo); + NDArray x(xBuff, xShapeInfo); + NDArray exp(expBuff, expShapeInfo); - auto result = x.varianceAlongDimension(variance::SummaryStatsVariance, false, {1}); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + auto result = + x.varianceAlongDimension(variance::SummaryStatsVariance, false, {1}); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVarianceAlongDimension3) { - - - NDArray x = NDArrayFactory::create('c', {10, 10});//(xBuff, xShapeInfo); - NDArray exp = NDArrayFactory::create('c', {10});//(expBuff, expShapeInfo); - x.linspace(1); // 1, 2, 3, ..., 100 - exp.assign(825.f); - auto result = x.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = + NDArrayFactory::create('c', {10, 10}); //(xBuff, xShapeInfo); + NDArray exp = + NDArrayFactory::create('c', {10}); //(expBuff, expShapeInfo); + x.linspace(1); // 1, 2, 3, ..., 100 + exp.assign(825.f); + auto result = + x.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVarianceAlongDimension4) { - - - NDArray x = NDArrayFactory::create('c', {12, 1, 12});//(xBuff, xShapeInfo); - NDArray exp = NDArrayFactory::create('c', {1,12});//(expBuff, expShapeInfo); - x.linspace(1); // 1, 2, 3, ..., 100 - exp.assign(1716.); - auto result = x.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); - ASSERT_TRUE(exp.isSameShapeStrict(result)); - ASSERT_TRUE(exp.equalsTo(result)); + NDArray x = + NDArrayFactory::create('c', {12, 1, 12}); //(xBuff, xShapeInfo); + NDArray exp = + NDArrayFactory::create('c', {1, 12}); //(expBuff, expShapeInfo); + x.linspace(1); // 1, 2, 3, ..., 100 + exp.assign(1716.); + auto result = + x.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestSubRowVector1) { - float xBuff[] = {6, 7, 8, 9}; - float yBuff[] = {1, 2}; - float expBuff[] = {5, 5, 7, 7}; - Nd4jLong xShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; - Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; + float xBuff[] = {6, 7, 8, 9}; + float yBuff[] = {1, 2}; + float expBuff[] = {5, 5, 7, 7}; + Nd4jLong xShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; - NDArray x(xBuff, xShapeInfo); - NDArray y(yBuff, yShapeInfo); - NDArray target(x); - NDArray exp(expBuff, xShapeInfo); + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, yShapeInfo); + NDArray target(x); + NDArray exp(expBuff, xShapeInfo); - x.subRowVector(y, target); + x.subRowVector(y, target); - ASSERT_TRUE(exp.isSameShapeStrict(target)); - ASSERT_TRUE(exp.equalsTo(&target)); + ASSERT_TRUE(exp.isSameShapeStrict(target)); + ASSERT_TRUE(exp.equalsTo(&target)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestDivRowVector1) { - float xBuff[] = {6, 8, 10, 12}; - float yBuff[] = {2, 4}; - float expBuff[] = {3, 2, 5, 3}; - Nd4jLong xShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; - Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; + float xBuff[] = {6, 8, 10, 12}; + float yBuff[] = {2, 4}; + float expBuff[] = {3, 2, 5, 3}; + Nd4jLong xShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; - NDArray x(xBuff, xShapeInfo); - NDArray y(yBuff, yShapeInfo); - NDArray target(x); - NDArray exp(expBuff, xShapeInfo); + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, yShapeInfo); + NDArray target(x); + NDArray exp(expBuff, xShapeInfo); - x.divRowVector(y, target); + x.divRowVector(y, target); - ASSERT_TRUE(exp.isSameShapeStrict(target)); - ASSERT_TRUE(exp.equalsTo(&target)); + ASSERT_TRUE(exp.isSameShapeStrict(target)); + ASSERT_TRUE(exp.equalsTo(&target)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMulRowVector1) { - float xBuff[] = {6, 8, 10, 12}; - float yBuff[] = {2, 4}; - float expBuff[] = {12, 32, 20, 48}; - Nd4jLong xShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; - Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; + float xBuff[] = {6, 8, 10, 12}; + float yBuff[] = {2, 4}; + float expBuff[] = {12, 32, 20, 48}; + Nd4jLong xShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; - NDArray x(xBuff, xShapeInfo); - NDArray y(yBuff, yShapeInfo); - NDArray target(x); - NDArray exp(expBuff, xShapeInfo); + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, yShapeInfo); + NDArray target(x); + NDArray exp(expBuff, xShapeInfo); - x.mulRowVector(y, target); + x.mulRowVector(y, target); - ASSERT_TRUE(exp.isSameShapeStrict(target)); - ASSERT_TRUE(exp.equalsTo(&target)); + ASSERT_TRUE(exp.isSameShapeStrict(target)); + ASSERT_TRUE(exp.equalsTo(&target)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestTensorDotAgain_1) { - int sY = 1; - int sX = 1; - int pY = 0; - int pX = 0; - int iC = 2; - int oC = 2; - int kY = 3; - int kX = 3; - int iY = 2; - int iX = 2; - int oY = 6; - int oX = 6; - int eD = iC * oC; - int B = 2; - - /* - input = np.linspace(1, B * iC * iY * iX, B * iC * iY * iX).reshape(B, iC, iY, iX) - weights = np.linspace(1, iC * oC * kY * kX, iC * oC * kY * kX).reshape(iC, oC, kY, kX) - */ - double _expB[] = {96.0, 116.0, 136.0, 156.0, 256.0, 276.0, 296.0, 316.0, 102.0, 124.0, 146.0, 168.0, 278.0, 300.0, 322.0, 344.0, 108.0, 132.0, 156.0, 180.0, 300.0, 324.0, 348.0, 372.0, 114.0, 140.0, 166.0, 192.0, 322.0, 348.0, 374.0, 400.0, 120.0, 148.0, 176.0, 204.0, 344.0, 372.0, 400.0, 428.0, 126.0, 156.0, 186.0, 216.0, 366.0, 396.0, 426.0, 456.0, 132.0, 164.0, 196.0, 228.0, 388.0, 420.0, 452.0, 484.0, 138.0, 172.0, 206.0, 240.0, 410.0, 444.0, 478.0, 512.0, 144.0, 180.0, 216.0, 252.0, 432.0, 468.0, 504.0, 540.0, 150.0, 188.0, 226.0, 264.0, 454.0, 492.0, 530.0, 568.0, 156.0, 196.0, 236.0, 276.0, 476.0, 516.0, 556.0, 596.0, 162.0, 204.0, 246.0, 288.0, 498.0, 540.0, 582.0, 624.0, 168.0, 212.0, 256.0, 300.0, 520.0, 564.0, 608.0, 652.0, 174.0, 220.0, 266.0, 312.0, 542.0, 588.0, 634.0, 680.0, 180.0, 228.0, 276.0, 324.0, 564.0, 612.0, 660.0, 708.0, 186.0, 236.0, 286.0, 336.0, 586.0, 636.0, 686.0, 736.0, 192.0, 244.0, 296.0, 348.0, 608.0, 660.0, 712.0, 764.0, 198.0, 252.0, 306.0, 360.0, 630.0, 684.0, 738.0, 792.0}; - - Nd4jLong _expS[] = {6, 2, 3, 3, 2, 2, 2, 72, 24, 8, 4, 2, 1, 16384, 1, 99}; - NDArray exp(_expB, _expS, sd::LaunchContext ::defaultContext(), false); - - auto input = NDArrayFactory::create('c', {B, iC, iY, iX}); - auto weights = NDArrayFactory::create('c', {iC, oC, kY, kX}); - - input.linspace(1); - weights.linspace(1); - - auto result = MmulHelper::tensorDot(&weights, &input, {0}, {1}); - - //result->printShapeInfo("result shape"); - ASSERT_TRUE(exp.isSameShape(result)); - -// exp.printBuffer("Expctd buffer"); -// result->printBuffer("Result buffer"); - ASSERT_TRUE(exp.equalsTo(result)); - - delete result; + int sY = 1; + int sX = 1; + int pY = 0; + int pX = 0; + int iC = 2; + int oC = 2; + int kY = 3; + int kX = 3; + int iY = 2; + int iX = 2; + int oY = 6; + int oX = 6; + int eD = iC * oC; + int B = 2; + + /* + input = np.linspace(1, B * iC * iY * iX, B * iC * iY * iX).reshape(B, iC, iY, + iX) weights = np.linspace(1, iC * oC * kY * kX, iC * oC * kY * kX).reshape(iC, + oC, kY, kX) + */ + double _expB[] = { + 96.0, 116.0, 136.0, 156.0, 256.0, 276.0, 296.0, 316.0, 102.0, 124.0, + 146.0, 168.0, 278.0, 300.0, 322.0, 344.0, 108.0, 132.0, 156.0, 180.0, + 300.0, 324.0, 348.0, 372.0, 114.0, 140.0, 166.0, 192.0, 322.0, 348.0, + 374.0, 400.0, 120.0, 148.0, 176.0, 204.0, 344.0, 372.0, 400.0, 428.0, + 126.0, 156.0, 186.0, 216.0, 366.0, 396.0, 426.0, 456.0, 132.0, 164.0, + 196.0, 228.0, 388.0, 420.0, 452.0, 484.0, 138.0, 172.0, 206.0, 240.0, + 410.0, 444.0, 478.0, 512.0, 144.0, 180.0, 216.0, 252.0, 432.0, 468.0, + 504.0, 540.0, 150.0, 188.0, 226.0, 264.0, 454.0, 492.0, 530.0, 568.0, + 156.0, 196.0, 236.0, 276.0, 476.0, 516.0, 556.0, 596.0, 162.0, 204.0, + 246.0, 288.0, 498.0, 540.0, 582.0, 624.0, 168.0, 212.0, 256.0, 300.0, + 520.0, 564.0, 608.0, 652.0, 174.0, 220.0, 266.0, 312.0, 542.0, 588.0, + 634.0, 680.0, 180.0, 228.0, 276.0, 324.0, 564.0, 612.0, 660.0, 708.0, + 186.0, 236.0, 286.0, 336.0, 586.0, 636.0, 686.0, 736.0, 192.0, 244.0, + 296.0, 348.0, 608.0, 660.0, 712.0, 764.0, 198.0, 252.0, 306.0, 360.0, + 630.0, 684.0, 738.0, 792.0}; + + Nd4jLong _expS[] = {6, 2, 3, 3, 2, 2, 2, 72, 24, 8, 4, 2, 1, 16384, 1, 99}; + NDArray exp(_expB, _expS, sd::LaunchContext ::defaultContext(), false); + + auto input = NDArrayFactory::create('c', {B, iC, iY, iX}); + auto weights = NDArrayFactory::create('c', {iC, oC, kY, kX}); + + input.linspace(1); + weights.linspace(1); + + auto result = MmulHelper::tensorDot(&weights, &input, {0}, {1}); + + // result->printShapeInfo("result shape"); + ASSERT_TRUE(exp.isSameShape(result)); + + // exp.printBuffer("Expctd buffer"); + // result->printBuffer("Result buffer"); + ASSERT_TRUE(exp.equalsTo(result)); + + delete result; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestBroadcast_1) { - double _expB[] = {1.000000, 1.000000, 1.000000, 1.000000, 2.000000, 2.000000, 2.000000, 2.000000, 3.000000, 3.000000, 3.000000, 3.000000, 1.000000, 1.000000, 1.000000, 1.000000, 2.000000, 2.000000, 2.000000, 2.000000, 3.000000, 3.000000, 3.000000, 3.000000}; - Nd4jLong _expS[] = {4, 2, 3, 2, 2, 12, 4, 2, 1, 16384, 1, 99}; - NDArray exp(_expB, _expS, sd::LaunchContext ::defaultContext(), false); + double _expB[] = {1.000000, 1.000000, 1.000000, 1.000000, 2.000000, 2.000000, + 2.000000, 2.000000, 3.000000, 3.000000, 3.000000, 3.000000, + 1.000000, 1.000000, 1.000000, 1.000000, 2.000000, 2.000000, + 2.000000, 2.000000, 3.000000, 3.000000, 3.000000, 3.000000}; + Nd4jLong _expS[] = {4, 2, 3, 2, 2, 12, 4, 2, 1, 16384, 1, 99}; + NDArray exp(_expB, _expS, sd::LaunchContext ::defaultContext(), false); - auto input = NDArrayFactory::create('c',{ 2, 3, 2, 2}); - auto bias = NDArrayFactory::create('c', {1, 3}); + auto input = NDArrayFactory::create('c', {2, 3, 2, 2}); + auto bias = NDArrayFactory::create('c', {1, 3}); - bias.linspace(1); + bias.linspace(1); - input.applyBroadcast(broadcast::Add, {1}, bias, input); + input.applyBroadcast(broadcast::Add, {1}, bias, input); - //input.printBuffer("result"); - ASSERT_TRUE(exp.equalsTo(&input)); + // input.printBuffer("result"); + ASSERT_TRUE(exp.equalsTo(&input)); } TEST_F(NDArrayTest, TestTranspose_11) { - auto x = NDArrayFactory::create('c', {2, 3, 4}); - x.transposei(); + auto x = NDArrayFactory::create('c', {2, 3, 4}); + x.transposei(); - ASSERT_EQ(4, x.sizeAt(0)); - ASSERT_EQ(3, x.sizeAt(1)); - ASSERT_EQ(2, x.sizeAt(2)); + ASSERT_EQ(4, x.sizeAt(0)); + ASSERT_EQ(3, x.sizeAt(1)); + ASSERT_EQ(2, x.sizeAt(2)); } - TEST_F(NDArrayTest, TestTranspose_12) { - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto y = x.transpose(); + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = x.transpose(); - ASSERT_EQ(4, y.sizeAt(0)); - ASSERT_EQ(3, y.sizeAt(1)); - ASSERT_EQ(2, y.sizeAt(2)); + ASSERT_EQ(4, y.sizeAt(0)); + ASSERT_EQ(3, y.sizeAt(1)); + ASSERT_EQ(2, y.sizeAt(2)); - ASSERT_EQ(2, x.sizeAt(0)); - ASSERT_EQ(3, x.sizeAt(1)); - ASSERT_EQ(4, x.sizeAt(2)); + ASSERT_EQ(2, x.sizeAt(0)); + ASSERT_EQ(3, x.sizeAt(1)); + ASSERT_EQ(4, x.sizeAt(2)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMMulMultiDim) { - const int bS=2; - const int K=3; - const int N=4; + const int bS = 2; + const int K = 3; + const int N = 4; - auto input = NDArrayFactory::create('c', {bS, K, N}); - auto weights = NDArrayFactory::create('c', {3*K, K}); - auto expected = NDArrayFactory::create('c', {bS, 3*K, N}, { 38, 44, 50, 56, 83, 98, 113, 128, 128, 152, 176, 200, 173, 206, 239, 272, 218, 260, 302, 344, 263, 314, 365, 416, 308, 368, 428, 488, 353, 422, 491, 560, 398, 476, 554, 632, 110, 116, 122, 128, 263, 278, 293, 308, 416, 440, 464, 488, 569, 602, 635, 668, 722, 764, 806, 848, 875, 926, 977, 1028, 1028, 1088, 1148, 1208, 1181, 1250, 1319, 1388, 1334, 1412, 1490, 1568}); + auto input = NDArrayFactory::create('c', {bS, K, N}); + auto weights = NDArrayFactory::create('c', {3 * K, K}); + auto expected = NDArrayFactory::create( + 'c', {bS, 3 * K, N}, + {38, 44, 50, 56, 83, 98, 113, 128, 128, 152, 176, 200, + 173, 206, 239, 272, 218, 260, 302, 344, 263, 314, 365, 416, + 308, 368, 428, 488, 353, 422, 491, 560, 398, 476, 554, 632, + 110, 116, 122, 128, 263, 278, 293, 308, 416, 440, 464, 488, + 569, 602, 635, 668, 722, 764, 806, 848, 875, 926, 977, 1028, + 1028, 1088, 1148, 1208, 1181, 1250, 1319, 1388, 1334, 1412, 1490, 1568}); - input.linspace(1); - weights.linspace(1); + input.linspace(1); + weights.linspace(1); - auto result = MmulHelper::mmul(&weights, &input, nullptr, 1., 0.); - // result must have such shape [bS x 3K x N] + auto result = MmulHelper::mmul(&weights, &input, nullptr, 1., 0.); + // result must have such shape [bS x 3K x N] - ASSERT_TRUE(result->isSameShape(&expected)); + ASSERT_TRUE(result->isSameShape(&expected)); - //result->printShapeInfo("result shape"); - // result->printBuffer("result buffer"); - ASSERT_TRUE(result->equalsTo(&expected)); - delete result; + // result->printShapeInfo("result shape"); + // result->printBuffer("result buffer"); + ASSERT_TRUE(result->equalsTo(&expected)); + delete result; } - TEST_F(NDArrayTest, AdditionOperator1) { + auto input1 = NDArrayFactory::create('c', {2, 2}); + auto input2 = NDArrayFactory::create('c', {2, 2}); + auto expected = NDArrayFactory::create('c', {2, 2}); - auto input1 = NDArrayFactory::create('c', {2,2}); - auto input2 = NDArrayFactory::create('c', {2,2}); - auto expected = NDArrayFactory::create('c', {2,2}); - - input1.assign(1.5); - input2.assign(2.); - expected.assign(3.5); + input1.assign(1.5); + input2.assign(2.); + expected.assign(3.5); - input2 = input1 + input2; - - ASSERT_TRUE(input2.equalsTo(&expected)); + input2 = input1 + input2; + ASSERT_TRUE(input2.equalsTo(&expected)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMatmMul_Again_1) { - auto a = NDArrayFactory::create('c', {3, 4, 1}); - auto b = NDArrayFactory::create('c', {3, 1, 5}); + auto a = NDArrayFactory::create('c', {3, 4, 1}); + auto b = NDArrayFactory::create('c', {3, 1, 5}); - a.linspace(1); - b.linspace(1); + a.linspace(1); + b.linspace(1); - float _expB[] = {1.f, 2.f, 3.f, 4.f, 5.f, 2.f, 4.f, 6.f, 8.f, 10.f, 3.f, 6.f, 9.f, 12.f, 15.f, 4.f, 8.f, 12.f, 16.f, 20.f, 30.f, 35.f, 40.f, 45.f, 50.f, 36.f, 42.f, 48.f, 54.f, 60.f, 42.f, 49.f, 56.f, 63.f, 70.f, 48.f, 56.f, 64.f, 72.f, 80.f, 99.f, 108.f, 117.f, 126.f, 135.f, 110.f, 120.f, 130.f, 140.f, 150.f, 121.f, 132.f, 143.f, 154.f, 165.f, 132.f, 144.f, 156.f, 168.f, 180.f}; - Nd4jLong _expS[] = {3, 3, 4, 5, 20, 5, 1, 8192, 1, 99}; - NDArray c(_expB, _expS, sd::LaunchContext ::defaultContext(), false); + float _expB[] = { + 1.f, 2.f, 3.f, 4.f, 5.f, 2.f, 4.f, 6.f, 8.f, 10.f, + 3.f, 6.f, 9.f, 12.f, 15.f, 4.f, 8.f, 12.f, 16.f, 20.f, + 30.f, 35.f, 40.f, 45.f, 50.f, 36.f, 42.f, 48.f, 54.f, 60.f, + 42.f, 49.f, 56.f, 63.f, 70.f, 48.f, 56.f, 64.f, 72.f, 80.f, + 99.f, 108.f, 117.f, 126.f, 135.f, 110.f, 120.f, 130.f, 140.f, 150.f, + 121.f, 132.f, 143.f, 154.f, 165.f, 132.f, 144.f, 156.f, 168.f, 180.f}; + Nd4jLong _expS[] = {3, 3, 4, 5, 20, 5, 1, 8192, 1, 99}; + NDArray c(_expB, _expS, sd::LaunchContext ::defaultContext(), false); - auto c_ = MmulHelper::mmul(&a, &b); + auto c_ = MmulHelper::mmul(&a, &b); - ASSERT_TRUE(c.isSameShape(c_)); - ASSERT_TRUE(c.equalsTo(c_)); + ASSERT_TRUE(c.isSameShape(c_)); + ASSERT_TRUE(c.equalsTo(c_)); - delete c_; + delete c_; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestMatmMul_Again_2) { - auto a = NDArrayFactory::create('c', {2, 5, 4}); - auto b = NDArrayFactory::create('c', {2, 4, 1}); + auto a = NDArrayFactory::create('c', {2, 5, 4}); + auto b = NDArrayFactory::create('c', {2, 4, 1}); - a.linspace(1); - b.linspace(1); + a.linspace(1); + b.linspace(1); - double _expB[] = {30.f, 70.f, 110.f, 150.f, 190.f, 590.f, 694.f, 798.f, 902.f, 1006.f}; - Nd4jLong _expS[] = {3, 2, 5, 1, 5, 1, 1, 16384, 1, 99}; - NDArray c(_expB, _expS); + double _expB[] = {30.f, 70.f, 110.f, 150.f, 190.f, + 590.f, 694.f, 798.f, 902.f, 1006.f}; + Nd4jLong _expS[] = {3, 2, 5, 1, 5, 1, 1, 16384, 1, 99}; + NDArray c(_expB, _expS); - auto c_ = MmulHelper::mmul(&a, &b); + auto c_ = MmulHelper::mmul(&a, &b); - ASSERT_TRUE(c.isSameShape(c_)); + ASSERT_TRUE(c.isSameShape(c_)); - ASSERT_TRUE(c.equalsTo(c_)); + ASSERT_TRUE(c.equalsTo(c_)); - delete c_; + delete c_; } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Plus_Test_1) -{ - double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; +TEST_F(NDArrayTest, Operator_Plus_Test_1) { + double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; - auto x = NDArrayFactory::create('c', {3, 1, 2}); - auto y = NDArrayFactory::create('c', {2, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto y = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - //x.printShapeInfo("x shape"); - //y.printShapeInfo("y shape"); - //expected.printShapeInfo("e shape"); - //expected.printIndexedBuffer("e"); + // x.printShapeInfo("x shape"); + // y.printShapeInfo("y shape"); + // expected.printShapeInfo("e shape"); + // expected.printIndexedBuffer("e"); - x.linspace(1); - y.linspace(1); + x.linspace(1); + y.linspace(1); - auto result = x + y; + auto result = x + y; - //result.printIndexedBuffer("result"); + // result.printIndexedBuffer("result"); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Plus_Test_2) -{ - double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; +TEST_F(NDArrayTest, Operator_Plus_Test_2) { + double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(1); - y.linspace(1); + x.linspace(1); + y.linspace(1); - auto result = x + y; - // result.printIndexedBuffer(); + auto result = x + y; + // result.printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Plus_Test_3) -{ - double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; +TEST_F(NDArrayTest, Operator_Plus_Test_3) { + double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(1); - y.linspace(1); + x.linspace(1); + y.linspace(1); - auto result = x + y; - // result.printIndexedBuffer(); - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + auto result = x + y; + // result.printIndexedBuffer(); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Plus_Test_4) -{ - double expBuff[] = {11.,12., 12.,13., 13.,14., 14.,15., 13.,14., 14.,15., 15.,16., 16.,17., 15.,16., 16.,17., 17.,18., 18.,19.}; +TEST_F(NDArrayTest, Operator_Plus_Test_4) { + double expBuff[] = {11., 12., 12., 13., 13., 14., 14., 15., + 13., 14., 14., 15., 15., 16., 16., 17., + 15., 16., 16., 17., 17., 18., 18., 19.}; - auto x = NDArrayFactory::create('c', {3, 1, 2}); - auto y = NDArrayFactory::create('c', {4, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x + y; + auto result = x + y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Minus_Test_1) -{ - double expBuff[] = {9. ,10., 10.,11., 11.,12., 12.,13., 17.,18., 18.,19., 19.,20., 20.,21., 25.,26., 26.,27., 27.,28., 28.,29.}; +TEST_F(NDArrayTest, Operator_Minus_Test_1) { + double expBuff[] = {9., 10., 10., 11., 11., 12., 12., 13., + 17., 18., 18., 19., 19., 20., 20., 21., + 25., 26., 26., 27., 27., 28., 28., 29.}; - auto x = NDArrayFactory::create('c', {3, 4, 2}); - auto y = NDArrayFactory::create('c', {4, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + auto x = NDArrayFactory::create('c', {3, 4, 2}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x - y; + auto result = x - y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Minus_Test_2) -{ - double expBuff[] = {9., 8., 7., 6., 6., 5., 4., 3., 11.,10., 9., 8., 8., 7., 6., 5., 13.,12.,11.,10., 10., 9., 8., 7.}; +TEST_F(NDArrayTest, Operator_Minus_Test_2) { + double expBuff[] = {9., 8., 7., 6., 6., 5., 4., 3., 11., 10., 9., 8., + 8., 7., 6., 5., 13., 12., 11., 10., 10., 9., 8., 7.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2, 4}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x - y; + auto result = x - y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Minus_Test_3) -{ - double expBuff[] = {9., 8., 7., 6., 6., 5., 4., 3., 11.,10., 9., 8., 8., 7., 6., 5., 13.,12.,11.,10., 10., 9., 8., 7.}; +TEST_F(NDArrayTest, Operator_Minus_Test_3) { + double expBuff[] = {9., 8., 7., 6., 6., 5., 4., 3., 11., 10., 9., 8., + 8., 7., 6., 5., 13., 12., 11., 10., 10., 9., 8., 7.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {2, 4}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x - y; + auto result = x - y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Minus_Test_4) -{ - double expBuff[] = {9.,10., 8., 9., 11.,12.,10.,11., 13.,14.,12.,13.}; +TEST_F(NDArrayTest, Operator_Minus_Test_4) { + double expBuff[] = {9., 10., 8., 9., 11., 12., 10., 11., 13., 14., 12., 13.}; - auto x = NDArrayFactory::create('c', {3, 1, 2}); - auto y = NDArrayFactory::create('c', {2, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto y = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x - y; + auto result = x - y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Minus_Test_5) -{ - double expBuff[] = {9. ,8 ,10., 9., 11.,10, 12.,11., 13.,12, 14.,13.}; +TEST_F(NDArrayTest, Operator_Minus_Test_5) { + double expBuff[] = {9., 8, 10., 9., 11., 10, 12., 11., 13., 12, 14., 13.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x - y; + auto result = x - y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Minus_Test_6) -{ - double expBuff[] = {9., 8, 10., 9, 11.,10, 12.,11., 13.,12, 14.,13, 15.,14, 16.,15., 17.,16, 18.,17, 19.,18, 20.,19.}; +TEST_F(NDArrayTest, Operator_Minus_Test_6) { + double expBuff[] = {9., 8, 10., 9, 11., 10, 12., 11., 13., 12, 14., 13, + 15., 14, 16., 15., 17., 16, 18., 17, 19., 18, 20., 19.}; - auto x = NDArrayFactory::create('c', {3, 4, 1}); - auto y = NDArrayFactory::create('c', {1, 1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + auto x = NDArrayFactory::create('c', {3, 4, 1}); + auto y = NDArrayFactory::create('c', {1, 1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x - y; + auto result = x - y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Multiply_Test_1) -{ - double expBuff[] = {10., 11., 24., 26., 42., 45., 64., 68., 18., 19., 40., 42., 66., 69., 96.,100., 26., 27., 56., 58., 90., 93., 128.,132.}; +TEST_F(NDArrayTest, Operator_Multiply_Test_1) { + double expBuff[] = {10., 11., 24., 26., 42., 45., 64., 68., + 18., 19., 40., 42., 66., 69., 96., 100., + 26., 27., 56., 58., 90., 93., 128., 132.}; - auto x = NDArrayFactory::create('c', {3, 4, 2}); - auto y = NDArrayFactory::create('c', {4, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + auto x = NDArrayFactory::create('c', {3, 4, 2}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x * y; + auto result = x * y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Multiply_Test_2) -{ - double expBuff[] = {10.,20., 30., 40., 55.,66., 77., 88., 12.,24., 36., 48., 65.,78., 91.,104., 14.,28., 42., 56., 75.,90.,105.,120.}; +TEST_F(NDArrayTest, Operator_Multiply_Test_2) { + double expBuff[] = {10., 20., 30., 40., 55., 66., 77., 88., + 12., 24., 36., 48., 65., 78., 91., 104., + 14., 28., 42., 56., 75., 90., 105., 120.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2, 4}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x * y; + auto result = x * y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Multiply_Test_3) -{ - double expBuff[] = {10.,20., 30., 40.,55.,66., 77., 88., 12.,24., 36., 48.,65.,78., 91.,104., 14.,28., 42., 56.,75.,90.,105.,120.}; +TEST_F(NDArrayTest, Operator_Multiply_Test_3) { + double expBuff[] = {10., 20., 30., 40., 55., 66., 77., 88., + 12., 24., 36., 48., 65., 78., 91., 104., + 14., 28., 42., 56., 75., 90., 105., 120.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {2, 4}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x * y; + auto result = x * y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Multiply_Test_4) -{ - double expBuff[] = {10.,11.,20.,22., 12.,13.,24.,26., 14.,15.,28.,30.}; +TEST_F(NDArrayTest, Operator_Multiply_Test_4) { + double expBuff[] = {10., 11., 20., 22., 12., 13., + 24., 26., 14., 15., 28., 30.}; - auto x = NDArrayFactory::create('c', {3, 1, 2}); - auto y = NDArrayFactory::create('c', {2, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto y = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x * y; + auto result = x * y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Multiply_Test_5) -{ - double expBuff[] = {10.,20.,11.,22., 12.,24.,13.,26., 14.,28.,15.,30.}; +TEST_F(NDArrayTest, Operator_Multiply_Test_5) { + double expBuff[] = {10., 20., 11., 22., 12., 24., + 13., 26., 14., 28., 15., 30.}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x * y; + auto result = x * y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Multiply_Test_6) -{ - double expBuff[] = {10,11.,12.,13.,28.,30.,32.,34.,54.,57.,60.,63.}; +TEST_F(NDArrayTest, Operator_Multiply_Test_6) { + double expBuff[] = {10, 11., 12., 13., 28., 30., + 32., 34., 54., 57., 60., 63.}; - auto x = NDArrayFactory::create('c', {3, 4, 1}); - auto y = NDArrayFactory::create('c', {3, 1, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 1}); + auto x = NDArrayFactory::create('c', {3, 4, 1}); + auto y = NDArrayFactory::create('c', {3, 1, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 1}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x * y; + auto result = x * y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Divide_Test_1) -{ - double expBuff[] = {10. ,11. , 6. , 6.5 , 4.6666, 5. , 4. , 4.25 , 18. ,19. , 10. ,10.5 , 7.3333, 7.6666, 6. , 6.25 , 26. ,27. , 14. ,14.5 , 10. ,10.3333, 8. , 8.25}; +TEST_F(NDArrayTest, Operator_Divide_Test_1) { + double expBuff[] = {10., 11., 6., 6.5, 4.6666, 5., 4., 4.25, + 18., 19., 10., 10.5, 7.3333, 7.6666, 6., 6.25, + 26., 27., 14., 14.5, 10., 10.3333, 8., 8.25}; - auto x = NDArrayFactory::create('c', {3, 4, 2}); - auto y = NDArrayFactory::create('c', {4, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + auto x = NDArrayFactory::create('c', {3, 4, 2}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x / y; + auto result = x / y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result,1e-4)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result, 1e-4)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Divide_Test_2) -{ - double expBuff[] = {10. ,5. ,3.333333,2.5 , 2.2,1.83333,1.571428,1.375, 12. ,6. ,4. ,3. , 2.6,2.16666,1.857142,1.625, 14. ,7. ,4.666666,3.5 , 3. ,2.5 ,2.142857,1.875}; +TEST_F(NDArrayTest, Operator_Divide_Test_2) { + double expBuff[] = {10., 5., 3.333333, 2.5, 2.2, 1.83333, 1.571428, 1.375, + 12., 6., 4., 3., 2.6, 2.16666, 1.857142, 1.625, + 14., 7., 4.666666, 3.5, 3., 2.5, 2.142857, 1.875}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2, 4}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x / y; + auto result = x / y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Divide_Test_3) -{ - double expBuff[] = {10. ,5. ,3.333333,2.5 , 2.2,1.833333,1.571428,1.375, 12. ,6. ,4. ,3. , 2.6,2.166666,1.857142,1.625, 14. ,7. ,4.666666,3.5 , 3. ,2.5 ,2.142857,1.875}; +TEST_F(NDArrayTest, Operator_Divide_Test_3) { + double expBuff[] = {10., 5., 3.333333, 2.5, 2.2, 1.833333, 1.571428, 1.375, + 12., 6., 4., 3., 2.6, 2.166666, 1.857142, 1.625, + 14., 7., 4.666666, 3.5, 3., 2.5, 2.142857, 1.875}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {2, 4}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x / y; + auto result = x / y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Divide_Test_4) -{ - double expBuff[] = {10.,11., 5., 5.5, 12.,13., 6., 6.5, 14.,15., 7., 7.5}; +TEST_F(NDArrayTest, Operator_Divide_Test_4) { + double expBuff[] = {10., 11., 5., 5.5, 12., 13., 6., 6.5, 14., 15., 7., 7.5}; - auto x = NDArrayFactory::create('c', {3, 1, 2}); - auto y = NDArrayFactory::create('c', {2, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto y = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x / y; + auto result = x / y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Divide_Test_5) -{ - double expBuff[] = {10.,5., 11.,5.5, 12.,6., 13.,6.5, 14.,7., 15.,7.5}; +TEST_F(NDArrayTest, Operator_Divide_Test_5) { + double expBuff[] = {10., 5., 11., 5.5, 12., 6., 13., 6.5, 14., 7., 15., 7.5}; - auto x = NDArrayFactory::create('c', {3, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 2}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x / y; + auto result = x / y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Divide_Test_6) -{ - double expBuff[] = {10. , 5.5 , 4. , 3.25 ,14. , 7.5 , 5.333333, 4.25 ,18. , 9.5 , 6.666666, 5.25}; +TEST_F(NDArrayTest, Operator_Divide_Test_6) { + double expBuff[] = {10., 5.5, 4., 3.25, 14., 7.5, + 5.333333, 4.25, 18., 9.5, 6.666666, 5.25}; - auto x = NDArrayFactory::create('c', {3, 4, 1}); - auto y = NDArrayFactory::create('c', {1, 4, 1}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 1}); + auto x = NDArrayFactory::create('c', {3, 4, 1}); + auto y = NDArrayFactory::create('c', {1, 4, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 1}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x / y; + auto result = x / y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } - ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest, Operator_Divide_Test_7) -{ - double expBuff[] = {10., 5. ,3.333333,2.5 ,11., 5.5,3.666666,2.75,12., 6. ,4. ,3. ,13., 6.5,4.333333,3.25, 14., 7. ,4.666666,3.5 ,15., 7.5,5. ,3.75,16., 8. ,5.333333,4. ,17., 8.5,5.666666,4.25, 18., 9. ,6. ,4.5 ,19., 9.5,6.333333,4.75,20.,10. ,6.666666,5. ,21.,10.5,7. ,5.25}; +TEST_F(NDArrayTest, Operator_Divide_Test_7) { + double expBuff[] = { + 10., 5., 3.333333, 2.5, 11., 5.5, 3.666666, 2.75, 12., 6., 4., 3., + 13., 6.5, 4.333333, 3.25, 14., 7., 4.666666, 3.5, 15., 7.5, 5., 3.75, + 16., 8., 5.333333, 4., 17., 8.5, 5.666666, 4.25, 18., 9., 6., 4.5, + 19., 9.5, 6.333333, 4.75, 20., 10., 6.666666, 5., 21., 10.5, 7., 5.25}; - auto x = NDArrayFactory::create('c', {3, 4, 1}); - auto y = NDArrayFactory::create('c', {1, 1, 4}); - auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 4}); + auto x = NDArrayFactory::create('c', {3, 4, 1}); + auto y = NDArrayFactory::create('c', {1, 1, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 4}); - x.linspace(10); - y.linspace(1); + x.linspace(10); + y.linspace(1); - auto result = x / y; + auto result = x / y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } #ifndef __CUDABLAS__ ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_Lambda_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto exp = NDArrayFactory::create('c', {1, 5}, {4, 5, 6, 7, 8}); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto exp = NDArrayFactory::create('c', {1, 5}, {4, 5, 6, 7, 8}); - auto lambda = LAMBDA_F(_val) { - return _val + 3.0f; - }; + auto lambda = LAMBDA_F(_val) { return _val + 3.0f; }; - x.applyLambda(lambda, x); + x.applyLambda(lambda, x); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_Lambda_2) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); - auto exp = NDArrayFactory::create('c', {1, 5}, {3, 5, 3, 5, 3}); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {3, 5, 3, 5, 3}); - auto lambda = LAMBDA_FF(_x, _y) { - return _x + _y + 1.0f; - }; + auto lambda = LAMBDA_FF(_x, _y) { return _x + _y + 1.0f; }; - x.applyPairwiseLambda(y, lambda, x); + x.applyPairwiseLambda(y, lambda, x); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_Lambda_3) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); - auto y = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); - auto exp = NDArrayFactory::create('c', {1, 5}, {4, 8, 4, 8, 4}); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {4, 8, 4, 8, 4}); - auto lambda = LAMBDA_DD(_x, _y) { - return (_x + _y) * 2; - }; + auto lambda = LAMBDA_DD(_x, _y) { return (_x + _y) * 2; }; - x.applyPairwiseLambda(y, lambda, x); + x.applyPairwiseLambda(y, lambda, x); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } #endif ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_swapUnsafe_1) { + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {5, 6, 7, 8}); + auto expX = NDArrayFactory::create('c', {2, 2}, {5, 6, 7, 8}); + auto expY = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {1, 4}, {5, 6, 7, 8}); - auto expX = NDArrayFactory::create('c', {2, 2}, {5, 6, 7, 8}); - auto expY = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - - x.swapUnsafe(y); + x.swapUnsafe(y); - ASSERT_TRUE(expX.equalsTo(&x)); - ASSERT_TRUE(expY.equalsTo(&y)); + ASSERT_TRUE(expX.equalsTo(&x)); + ASSERT_TRUE(expY.equalsTo(&y)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_1) { + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto exp = NDArrayFactory::create('c', {2, 1}, {1, 5}); - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - auto exp = NDArrayFactory::create('c', {2, 1}, {1, 5}); + auto diag = x.diagonal('c'); - auto diag = x.diagonal('c'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_2) { + auto x = NDArrayFactory::create('f', {2, 3}); + auto exp = NDArrayFactory::create('f', {2, 1}, {1, 5}); + x.linspace(1); - auto x = NDArrayFactory::create('f', {2, 3}); - auto exp = NDArrayFactory::create('f', {2, 1}, {1, 5}); - x.linspace(1); - - auto diag = x.diagonal('c'); + auto diag = x.diagonal('c'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_3) { + auto x = NDArrayFactory::create('c', {2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 2}, {1, 4}); - auto x = NDArrayFactory::create('c', {2, 2}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1, 2}, {1, 4}); + auto diag = x.diagonal('r'); - auto diag = x.diagonal('r'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_4) { + auto x = NDArrayFactory::create('f', {2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {1, 2}, {1, 4}); - auto x = NDArrayFactory::create('f', {2, 2}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {1, 2}, {1, 4}); - - auto diag = x.diagonal('r'); + auto diag = x.diagonal('r'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_5) { + auto x = NDArrayFactory::create('c', {2, 2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 2}, {1, 8}); - auto x = NDArrayFactory::create('c', {2, 2, 2}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1, 2}, {1, 8}); - - auto diag = x.diagonal('r'); + auto diag = x.diagonal('r'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_6) { + auto x = NDArrayFactory::create('f', {2, 2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {1, 2}, {1, 8}); - auto x = NDArrayFactory::create('f', {2, 2, 2}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {1, 2}, {1, 8}); + auto diag = x.diagonal('r'); - auto diag = x.diagonal('r'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_7) { + auto x = NDArrayFactory::create('f', {2, 2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {2, 1}, {1, 8}); - auto x = NDArrayFactory::create('f', {2, 2, 2}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {2, 1}, {1, 8}); - - auto diag = x.diagonal('c'); + auto diag = x.diagonal('c'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_8) { + auto x = NDArrayFactory::create('c', {2, 3}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 2}, {1, 5}); - auto x = NDArrayFactory::create('c', {2, 3}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1, 2}, {1, 5}); + auto diag = x.diagonal('r'); - auto diag = x.diagonal('r'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_9) { + auto x = NDArrayFactory::create('c', {2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {2, 1}, {1, 4}); - auto x = NDArrayFactory::create('c', {2, 2}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {2, 1}, {1, 4}); - - auto diag = x.diagonal('c'); + auto diag = x.diagonal('c'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_10) { + auto x = NDArrayFactory::create('f', {2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {2, 1}, {1, 4}); - auto x = NDArrayFactory::create('f', {2, 2}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {2, 1}, {1, 4}); - - auto diag = x.diagonal('c'); + auto diag = x.diagonal('c'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_11) { + auto x = NDArrayFactory::create('f', {3, 3}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {3, 1}, {1, 5, 9}); - auto x = NDArrayFactory::create('f', {3, 3}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {3, 1}, {1, 5, 9}); + auto diag = x.diagonal('c'); - auto diag = x.diagonal('c'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_12) { + auto x = NDArrayFactory::create('c', {3, 3}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 5, 9}); - auto x = NDArrayFactory::create('c', {3, 3}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1, 3}, {1, 5, 9}); - - auto diag = x.diagonal('r'); + auto diag = x.diagonal('r'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_13) { + auto x = NDArrayFactory::create('c', {3, 3, 4}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {3, 1}, {1, 18, 35}); - auto x = NDArrayFactory::create('c', {3, 3, 4}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 1}, {1,18,35}); + auto diag = x.diagonal('c'); - auto diag = x.diagonal('c'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_14) { + auto x = NDArrayFactory::create('c', {3, 3, 4}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 18, 35}); - auto x = NDArrayFactory::create('c', {3, 3, 4}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1, 3}, {1,18,35}); - - auto diag = x.diagonal('r'); + auto diag = x.diagonal('r'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_15) { + auto x = NDArrayFactory::create('f', {3, 3, 4}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {1, 3}, {1, 18, 35}); - auto x = NDArrayFactory::create('f', {3, 3, 4}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {1, 3}, {1,18,35}); + auto diag = x.diagonal('r'); - auto diag = x.diagonal('r'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_16) { + auto x = NDArrayFactory::create('f', {1, 5}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {1, 1}, {1}); - auto x = NDArrayFactory::create('f', {1, 5}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {1, 1}, {1}); - - auto diag = x.diagonal('c'); + auto diag = x.diagonal('c'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_17) { + auto x = NDArrayFactory::create('c', {5, 1}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 1}, {1}); - auto x = NDArrayFactory::create('c', {5, 1}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {1, 1}, {1}); + auto diag = x.diagonal('r'); - auto diag = x.diagonal('r'); - - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, Test_diagonal_18) { + auto x = NDArrayFactory::create('f', {1, 1}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {1, 1}, {1}); - auto x = NDArrayFactory::create('f', {1, 1}); - x.linspace(1); - auto exp = NDArrayFactory::create('f', {1, 1}, {1}); - - auto diag = x.diagonal('r'); + auto diag = x.diagonal('r'); - ASSERT_TRUE(exp.isSameShape(diag)); - ASSERT_TRUE(exp.equalsTo(diag)); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, assign_test1) { + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + NDArray y('c', {2, 3}, {10, 20, 30, 40, 50, 60}); + y.reshapei('c', {3, 2}); - NDArray x('c', {2, 3}, {1,2,3,4,5,6}); - NDArray y('c', {2, 3}, {10,20,30,40,50,60}); - y.reshapei('c',{3, 2}); - - x.assign(y); - x.reshapei('c',{3, 2}); - ASSERT_TRUE(x.equalsTo(y)); + x.assign(y); + x.reshapei('c', {3, 2}); + ASSERT_TRUE(x.equalsTo(y)); } diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp index 4dd4c3abee4d..a5cb2b4ba320 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp @@ -18,1290 +18,1343 @@ // Created by raver119 on 21.11.17. // -#include "testlayers.h" -#include #include #include #include +#include + +#include "testlayers.h" + using namespace sd; ////////////////////////////////////////////////////////////////////// class NDArrayTest2 : public testing::Test { -public: - + public: }; - TEST_F(NDArrayTest2, Test_ByteVector_1) { - auto x = NDArrayFactory::create('c', {10, 10}); - x.linspace(1); - - auto vec = x.asByteVector(); + auto x = NDArrayFactory::create('c', {10, 10}); + x.linspace(1); - auto restored = new NDArray((float *)vec.data(), x.shapeInfo(), x.getContext(), false); + auto vec = x.asByteVector(); + auto restored = + new NDArray((float *)vec.data(), x.shapeInfo(), x.getContext(), false); - ASSERT_TRUE(x.equalsTo(restored)); + ASSERT_TRUE(x.equalsTo(restored)); - delete restored; + delete restored; } TEST_F(NDArrayTest2, Test_ByteVector_2) { - auto x = NDArrayFactory::create('c', {10, 10}); - x.linspace(1); + auto x = NDArrayFactory::create('c', {10, 10}); + x.linspace(1); - auto vec = x.asByteVector(); + auto vec = x.asByteVector(); - auto restored = new NDArray((bfloat16 *)vec.data(), x.shapeInfo(), x.getContext(), false); + auto restored = + new NDArray((bfloat16 *)vec.data(), x.shapeInfo(), x.getContext(), false); - ASSERT_TRUE(x.equalsTo(restored)); + ASSERT_TRUE(x.equalsTo(restored)); - delete restored; + delete restored; } TEST_F(NDArrayTest2, Test_ByteVector_3) { - auto x = NDArrayFactory::create('c', {10, 10}); - x.linspace(1); + auto x = NDArrayFactory::create('c', {10, 10}); + x.linspace(1); - auto vec = x.asByteVector(); + auto vec = x.asByteVector(); - auto restored = new NDArray((double *)vec.data(), x.shapeInfo(), x.getContext(), false); + auto restored = + new NDArray((double *)vec.data(), x.shapeInfo(), x.getContext(), false); - ASSERT_TRUE(x.equalsTo(restored)); + ASSERT_TRUE(x.equalsTo(restored)); - delete restored; + delete restored; } TEST_F(NDArrayTest2, Test_Reshape_Scalar_1) { - auto x = NDArrayFactory::create('c', {1, 1}, {1.0}); - auto e = NDArrayFactory::create(1.0); + auto x = NDArrayFactory::create('c', {1, 1}, {1.0}); + auto e = NDArrayFactory::create(1.0); - x.reshapei({}); + x.reshapei({}); - ASSERT_EQ(e, x); - ASSERT_EQ(e.rankOf(), x.rankOf()); + ASSERT_EQ(e, x); + ASSERT_EQ(e.rankOf(), x.rankOf()); } TEST_F(NDArrayTest2, Test_Reshape_Scalar_2) { - auto x = NDArrayFactory::create('c', {1, 1}, {1.0}); - auto e = NDArrayFactory::create('c', {1}, {1.0}); + auto x = NDArrayFactory::create('c', {1, 1}, {1.0}); + auto e = NDArrayFactory::create('c', {1}, {1.0}); - x.reshapei({1}); + x.reshapei({1}); - ASSERT_EQ(e, x); - ASSERT_EQ(e.rankOf(), x.rankOf()); + ASSERT_EQ(e, x); + ASSERT_EQ(e.rankOf(), x.rankOf()); } TEST_F(NDArrayTest2, Test_IndexReduce_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - ExtraArguments extras({3.0, 0.0, 10.0}); - int idx = x.indexReduceNumber(indexreduce::FirstIndex, &extras).e(0); + ExtraArguments extras({3.0, 0.0, 10.0}); + int idx = x.indexReduceNumber(indexreduce::FirstIndex, &extras).e(0); - ASSERT_EQ(2, idx); + ASSERT_EQ(2, idx); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_1) { + auto x = NDArrayFactory::create('c', {1, 5}); + auto xExp = NDArrayFactory::create('c', {1, 5}, {1, 0, 0, 0, 0}); - auto x = NDArrayFactory::create('c', {1, 5}); - auto xExp = NDArrayFactory::create('c', {1, 5}, {1, 0, 0, 0, 0}); - - x.setIdentity(); - ASSERT_TRUE(x.equalsTo(&xExp)); + x.setIdentity(); + ASSERT_TRUE(x.equalsTo(&xExp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_2) { + auto x = NDArrayFactory::create('f', {1, 5}); + auto xExp = NDArrayFactory::create('f', {1, 5}, {1, 0, 0, 0, 0}); - auto x = NDArrayFactory::create('f', {1, 5}); - auto xExp = NDArrayFactory::create('f', {1, 5}, {1, 0, 0, 0, 0}); - - x.setIdentity(); + x.setIdentity(); - ASSERT_TRUE(x.equalsTo(&xExp)); + ASSERT_TRUE(x.equalsTo(&xExp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_3) { + auto x = NDArrayFactory::create('f', {1, 1}); + auto xExp = NDArrayFactory::create('f', {1, 1}, {1}); - auto x = NDArrayFactory::create('f', {1, 1}); - auto xExp = NDArrayFactory::create('f', {1, 1}, {1}); + x.setIdentity(); - x.setIdentity(); - - ASSERT_TRUE(x.equalsTo(&xExp)); + ASSERT_TRUE(x.equalsTo(&xExp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_4) { + auto x = NDArrayFactory::create('f', {2, 1}); + auto xExp = NDArrayFactory::create('f', {2, 1}, {1, 0}); - auto x = NDArrayFactory::create('f', {2, 1}); - auto xExp = NDArrayFactory::create('f', {2, 1}, {1,0}); - - x.setIdentity(); + x.setIdentity(); - ASSERT_TRUE(x.equalsTo(&xExp)); + ASSERT_TRUE(x.equalsTo(&xExp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_5) { + auto x = NDArrayFactory::create('f', {2, 2}); + auto xExp = NDArrayFactory::create('f', {2, 2}, {1, 0, 0, 1}); - auto x = NDArrayFactory::create('f', {2, 2}); - auto xExp = NDArrayFactory::create('f', {2, 2}, {1,0,0,1}); + x.setIdentity(); - x.setIdentity(); - - ASSERT_TRUE(x.equalsTo(&xExp)); + ASSERT_TRUE(x.equalsTo(&xExp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_6) { + auto x = NDArrayFactory::create('c', {3, 2}); + auto xExp = NDArrayFactory::create('c', {3, 2}, + {1.f, 0.f, 0.f, 1.f, 0.f, 0.f}); - auto x = NDArrayFactory::create('c', {3, 2}); - auto xExp = NDArrayFactory::create('c', {3, 2}, {1.f, 0.f, 0.f, 1.f, 0.f, 0.f}); - - x.setIdentity(); + x.setIdentity(); - ASSERT_TRUE(x.equalsTo(&xExp)); + ASSERT_TRUE(x.equalsTo(&xExp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_7) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto xExp = NDArrayFactory::create( + 'c', {3, 4}, + {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); - auto x = NDArrayFactory::create('c', {3, 4}); - auto xExp = NDArrayFactory::create('c', {3, 4}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); + x.setIdentity(); - x.setIdentity(); - - ASSERT_TRUE(x.equalsTo(&xExp)); + ASSERT_TRUE(x.equalsTo(&xExp)); } #ifdef ALLOWED_3D_IDENTITY //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, SetIdentity_test_8) { + auto x = NDArrayFactory::create('c', {3, 3, 3}); + auto xExp = NDArrayFactory::create( + 'c', {3, 3, 3}, {1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.}); + x.setIdentity(); - auto x = NDArrayFactory::create('c', {3, 3, 3}); - auto xExp = NDArrayFactory::create('c', {3, 3, 3}, {1.,0.,0. ,0.,0.,0., 0.,0.,0., 0.,0.,0. ,0.,1.,0., 0.,0.,0., 0.,0.,0. ,0.,0.,0., 0.,0.,1.}); - x.setIdentity(); - - ASSERT_TRUE(x.equalsTo(&xExp)); + ASSERT_TRUE(x.equalsTo(&xExp)); } #endif //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_AllReduce3_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); - auto y = NDArrayFactory::create('c', {2, 3}, {2, 3, 4, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {2, 2}, {1.73205, 1.73205, 1.73205, 1.73205}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); + auto y = NDArrayFactory::create('c', {2, 3}, {2, 3, 4, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {2, 2}, {1.73205, 1.73205, 1.73205, 1.73205}); - auto z = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); + auto z = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_AllReduce3_2) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4 }); - auto y = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {2, 2}, {0., 1.73205, 1.73205, 0.}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {2, 2}, {0., 1.73205, 1.73205, 0.}); - auto z = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); + auto z = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, mmul_test1) { + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, {1, 2, 3, 4, 2, 4, 6, 8, 3, 6, 9, 12, 4, 8, 12, 16}); - auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); - - auto result = mmul(x, y); - ASSERT_TRUE(exp.isSameShape(&result)); - ASSERT_TRUE(exp.equalsTo(&result)); - + auto result = mmul(x, y); + ASSERT_TRUE(exp.isSameShape(&result)); + ASSERT_TRUE(exp.equalsTo(&result)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, mmul_test2) { + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1}, {30}); - auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {1, 1}, {30}); - - auto result = mmul(y ,x); - - ASSERT_TRUE(exp.isSameShape(&result)); - ASSERT_TRUE(exp.equalsTo(&result)); + auto result = mmul(y, x); + ASSERT_TRUE(exp.isSameShape(&result)); + ASSERT_TRUE(exp.equalsTo(&result)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, mmul_test3) { + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, + {1., 0.2, 0.3, 0.4, 0.2, 0.04, 0.06, 0.08, 0.3, 0.06, 0.09, 0.12, 0.4, + 0.08, 0.12, 0.16}); + auto w = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), 1}, + x.getContext()); // column-vector + auto wT = NDArrayFactory::create( + x.ordering(), {1, (int)x.lengthOf()}, + x.getContext()); // row-vector (transposed w) - auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1. ,0.2 ,0.3 ,0.4 ,0.2,0.04,0.06,0.08,0.3,0.06,0.09,0.12,0.4,0.08,0.12,0.16}); - auto w = NDArrayFactory::create( x.ordering(), {(int)x.lengthOf(), 1}, x.getContext()); // column-vector - auto wT = NDArrayFactory::create(x.ordering(), {1, (int)x.lengthOf()}, x.getContext()); // row-vector (transposed w) - - w = x / (float)10.; - w.p(0, 1.); - wT.assign(&w); + w = x / (float)10.; + w.p(0, 1.); + wT.assign(&w); - auto result = mmul(w ,wT); - - ASSERT_TRUE(exp.isSameShape(&result)); - ASSERT_TRUE(exp.equalsTo(&result)); + auto result = mmul(w, wT); + ASSERT_TRUE(exp.isSameShape(&result)); + ASSERT_TRUE(exp.equalsTo(&result)); } - TEST_F(NDArrayTest2, Test_Streamline_1) { - auto x = NDArrayFactory::create('c', {3, 4, 6}); - auto y = NDArrayFactory::create('c', {3, 4, 6}); - x.linspace(1); - y.linspace(1); + auto x = NDArrayFactory::create('c', {3, 4, 6}); + auto y = NDArrayFactory::create('c', {3, 4, 6}); + x.linspace(1); + y.linspace(1); - x.permutei({1, 0, 2}); - y.permutei({1, 0, 2}); + x.permutei({1, 0, 2}); + y.permutei({1, 0, 2}); - y.streamline(); + y.streamline(); - ASSERT_TRUE(x.isSameShape(&y)); - ASSERT_TRUE(x.equalsTo(&y)); - ASSERT_FALSE(x.isSameShapeStrict(y)); + ASSERT_TRUE(x.isSameShape(&y)); + ASSERT_TRUE(x.equalsTo(&y)); + ASSERT_FALSE(x.isSameShapeStrict(y)); } - TEST_F(NDArrayTest2, Test_Streamline_2) { - auto x = NDArrayFactory::create('c', {3, 4, 6}); - auto y = NDArrayFactory::create('f', {3, 4, 6}); - x.linspace(1); - y.linspace(1); + auto x = NDArrayFactory::create('c', {3, 4, 6}); + auto y = NDArrayFactory::create('f', {3, 4, 6}); + x.linspace(1); + y.linspace(1); - ASSERT_TRUE(x.isSameShape(&y)); - ASSERT_TRUE(x.equalsTo(&y)); + ASSERT_TRUE(x.isSameShape(&y)); + ASSERT_TRUE(x.equalsTo(&y)); - y.streamline('c'); + y.streamline('c'); - ASSERT_TRUE(x.isSameShape(&y)); - ASSERT_TRUE(x.equalsTo(&y)); + ASSERT_TRUE(x.isSameShape(&y)); + ASSERT_TRUE(x.equalsTo(&y)); } TEST_F(NDArrayTest2, Test_Enforce_1) { - auto x = NDArrayFactory::create('c', {4, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {4, 4}); + auto x = NDArrayFactory::create('c', {4, 1, 1, 4}); + auto exp = NDArrayFactory::create('c', {4, 4}); - x.linspace(1); - exp.linspace(1); + x.linspace(1); + exp.linspace(1); - x.enforce({4, 4}, 'c'); + x.enforce({4, 4}, 'c'); - ASSERT_TRUE(exp.isSameShapeStrict(x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShapeStrict(x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayTest2, TestVector_1) { - auto x = NDArrayFactory::create('c', {2, 3}); - auto row = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); + auto x = NDArrayFactory::create('c', {2, 3}); + auto row = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); - x.addiRowVector(row); + x.addiRowVector(row); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } ////////////////////////////////////////////////////////////////////// -TEST_F(NDArrayTest2, Operator_Plus_Test_5) -{ - - auto x = NDArrayFactory::create('c', {8, 8, 8}); - auto y = NDArrayFactory::create('c', {8, 1, 8}); - auto expected = NDArrayFactory::create('c', {8, 8, 8}); +TEST_F(NDArrayTest2, Operator_Plus_Test_5) { + auto x = NDArrayFactory::create('c', {8, 8, 8}); + auto y = NDArrayFactory::create('c', {8, 1, 8}); + auto expected = NDArrayFactory::create('c', {8, 8, 8}); - x = 1.; - y = 2.; - expected = 3.; + x = 1.; + y = 2.; + expected = 3.; - auto result = x + y; + auto result = x + y; - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Operator_Plus_Test_6) { + auto x = NDArrayFactory::create('c', {3, 3, 3}); + auto y = NDArrayFactory::create('c', {3, 1, 3}); + auto expected = NDArrayFactory::create( + 'c', {3, 3, 3}, + {2., 4., 6., 5., 7., 9., 8., 10., 12., 14., 16., 18., 17., 19., + 21., 20., 22., 24., 26., 28., 30., 29., 31., 33., 32., 34., 36.}); + x.linspace(1); + y.linspace(1); - auto x = NDArrayFactory::create('c', {3, 3, 3}); - auto y = NDArrayFactory::create('c', {3, 1, 3}); - auto expected = NDArrayFactory::create('c', {3, 3, 3}, {2., 4., 6., 5., 7., 9., 8.,10.,12., 14.,16.,18.,17.,19.,21.,20.,22.,24., 26.,28.,30.,29.,31.,33.,32.,34.,36.}); - x.linspace(1); - y.linspace(1); + auto result = x + y; - auto result = x + y; - - ASSERT_TRUE(expected.isSameShape(&result)); - ASSERT_TRUE(expected.equalsTo(&result)); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, tileToShape_test1) { + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 1, 2, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1,2,3,4,1,2,3,4}); - - x.tileToShape({2,2,2}, x); + x.tileToShape({2, 2, 2}, x); - ASSERT_TRUE(x.isSameShape(&exp)); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.isSameShape(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, tileToShape_test2) { + auto x = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 1, 2}, {1,2,3,4}); - auto exp = NDArrayFactory::create('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4}); + x.tileToShape({2, 3, 2}, x); - x.tileToShape({2,3,2}, x); - - ASSERT_TRUE(x.isSameShape(&exp)); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.isSameShape(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, tileToShape_test3) { + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto result = NDArrayFactory::create('c', {2, 2, 2}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 1, 2, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); - auto result = NDArrayFactory::create('c', {2, 2, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1,2,3,4,1,2,3,4}); - - x.tileToShape({2,2,2}, result); - // result.printIndexedBuffer(); + x.tileToShape({2, 2, 2}, result); + // result.printIndexedBuffer(); - ASSERT_TRUE(result.isSameShape(&exp)); - ASSERT_TRUE(result.equalsTo(&exp)); + ASSERT_TRUE(result.isSameShape(&exp)); + ASSERT_TRUE(result.equalsTo(&exp)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, tileToShape_test4) { + auto x = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); + auto result = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create( + 'c', {2, 3, 2}, {1, 2, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4}); - auto x = NDArrayFactory::create('c', {2, 1, 2}, {1,2,3,4}); - auto result = NDArrayFactory::create('c', {2, 3, 2}); - auto exp = NDArrayFactory::create('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4}); + x.tileToShape({2, 3, 2}, result); - x.tileToShape({2,3,2}, result); - - ASSERT_TRUE(result.isSameShape(&exp)); - ASSERT_TRUE(result.equalsTo(&exp)); + ASSERT_TRUE(result.isSameShape(&exp)); + ASSERT_TRUE(result.equalsTo(&exp)); } #ifndef __CUDABLAS__ TEST_F(NDArrayTest2, Test_TriplewiseLambda_1) { - auto t = NDArrayFactory::create('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); - auto u = NDArrayFactory::create('c', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); - auto v = NDArrayFactory::create('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3}); - auto exp = NDArrayFactory::create('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7}); + auto t = + NDArrayFactory::create('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); + auto u = + NDArrayFactory::create('c', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto v = + NDArrayFactory::create('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3}); + auto exp = + NDArrayFactory::create('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7}); - float extra = 1.0f; + float extra = 1.0f; - auto la = LAMBDA_DDD(_t, _u, _v, extra) { - return _t + _u + _v + extra; - }; + auto la = LAMBDA_DDD(_t, _u, _v, extra) { return _t + _u + _v + extra; }; - t.applyTriplewiseLambda(u, v, la, t); + t.applyTriplewiseLambda(u, v, la, t); - ASSERT_TRUE(t.equalsTo(&exp)); + ASSERT_TRUE(t.equalsTo(&exp)); } - TEST_F(NDArrayTest2, Test_TriplewiseLambda_2) { - auto t = NDArrayFactory::create('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); - auto u = NDArrayFactory::create('f', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); - auto v = NDArrayFactory::create('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3}); - auto exp = NDArrayFactory::create('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7}); + auto t = + NDArrayFactory::create('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); + auto u = + NDArrayFactory::create('f', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto v = + NDArrayFactory::create('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3}); + auto exp = + NDArrayFactory::create('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7}); - float extra = 1.0f; + float extra = 1.0f; - auto la = LAMBDA_DDD(_t, _u, _v, extra) { - return _t + _u + _v + extra; - }; + auto la = LAMBDA_DDD(_t, _u, _v, extra) { return _t + _u + _v + extra; }; - t.applyTriplewiseLambda(u, v, la, t); + t.applyTriplewiseLambda(u, v, la, t); - ASSERT_TRUE(t.equalsTo(&exp)); + ASSERT_TRUE(t.equalsTo(&exp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_Indexed_Lambda) { - auto x = NDArrayFactory::create('c', {2, 2}); - auto exp = NDArrayFactory::create('c', {2, 2}, {0, 1, 2, 3}); + auto x = NDArrayFactory::create('c', {2, 2}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0, 1, 2, 3}); - auto lambda = ILAMBDA_D(_x) { - return (float) _idx; - }; + auto lambda = ILAMBDA_D(_x) { return (float)_idx; }; - x.applyIndexedLambda(lambda, x); + x.applyIndexedLambda(lambda, x); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } #endif TEST_F(NDArrayTest2, Test_PermuteEquality_1) { - auto x = NDArrayFactory::create('c', {1, 60}); - auto exp = NDArrayFactory::create('c', {3, 5, 4}, {1.0, 6.0, 11.0, 16.0, 2.0, 7.0, 12.0, 17.0, 3.0, 8.0, 13.0, 18.0, 4.0, 9.0, 14.0, 19.0, 5.0, 10.0, 15.0, 20.0, 21.0, 26.0, 31.0, 36.0, 22.0, 27.0, 32.0, 37.0, 23.0, 28.0, 33.0, 38.0, 24.0, 29.0, 34.0, 39.0, 25.0, 30.0, 35.0, 40.0, 41.0, 46.0, 51.0, 56.0, 42.0, 47.0, 52.0, 57.0, 43.0, 48.0, 53.0, 58.0, 44.0, 49.0, 54.0, 59.0, 45.0, 50.0, 55.0, 60.0}); - x.linspace(1); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + auto exp = NDArrayFactory::create( + 'c', {3, 5, 4}, + {1.0, 6.0, 11.0, 16.0, 2.0, 7.0, 12.0, 17.0, 3.0, 8.0, 13.0, 18.0, + 4.0, 9.0, 14.0, 19.0, 5.0, 10.0, 15.0, 20.0, 21.0, 26.0, 31.0, 36.0, + 22.0, 27.0, 32.0, 37.0, 23.0, 28.0, 33.0, 38.0, 24.0, 29.0, 34.0, 39.0, + 25.0, 30.0, 35.0, 40.0, 41.0, 46.0, 51.0, 56.0, 42.0, 47.0, 52.0, 57.0, + 43.0, 48.0, 53.0, 58.0, 44.0, 49.0, 54.0, 59.0, 45.0, 50.0, 55.0, 60.0}); + x.linspace(1); + x.reshapei('c', {3, 4, 5}); - x.permutei({0, 2, 1}); - x.streamline(); + x.permutei({0, 2, 1}); + x.streamline(); -// x.printShapeInfo("{0, 2, 1} shape"); -// x.printBuffer("{0, 2, 1} data"); + // x.printShapeInfo("{0, 2, 1} shape"); + // x.printBuffer("{0, 2, 1} data"); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayTest2, Test_PermuteEquality_0) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {3, 4, 5}, + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, + 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, + 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, + 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + x.reshapei('c', {3, 4, 5}); - x.permutei({0, 1, 2}); - x.streamline(); + x.permutei({0, 1, 2}); + x.streamline(); -// x.printShapeInfo("{0, 1, 2} shape"); -// x.printBuffer("{0, 1, 2} data"); + // x.printShapeInfo("{0, 1, 2} shape"); + // x.printBuffer("{0, 1, 2} data"); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } - TEST_F(NDArrayTest2, Test_PermuteEquality_2) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {4, 3, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 41.0, 42.0, 43.0, 44.0, 45.0, 6.0, 7.0, 8.0, 9.0, 10.0, 26.0, 27.0, 28.0, 29.0, 30.0, 46.0, 47.0, 48.0, 49.0, 50.0, 11.0, 12.0, 13.0, 14.0, 15.0, 31.0, 32.0, 33.0, 34.0, 35.0, 51.0, 52.0, 53.0, 54.0, 55.0, 16.0, 17.0, 18.0, 19.0, 20.0, 36.0, 37.0, 38.0, 39.0, 40.0, 56.0, 57.0, 58.0, 59.0, 60.0}); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {4, 3, 5}, + {1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 41.0, 42.0, + 43.0, 44.0, 45.0, 6.0, 7.0, 8.0, 9.0, 10.0, 26.0, 27.0, 28.0, 29.0, + 30.0, 46.0, 47.0, 48.0, 49.0, 50.0, 11.0, 12.0, 13.0, 14.0, 15.0, 31.0, + 32.0, 33.0, 34.0, 35.0, 51.0, 52.0, 53.0, 54.0, 55.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 36.0, 37.0, 38.0, 39.0, 40.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + x.reshapei('c', {3, 4, 5}); - x.permutei({1, 0, 2}); - x.streamline(); + x.permutei({1, 0, 2}); + x.streamline(); -// x.printShapeInfo("{1, 0, 2} shape"); -// x.printBuffer("{1, 0, 2} data"); + // x.printShapeInfo("{1, 0, 2} shape"); + // x.printBuffer("{1, 0, 2} data"); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayTest2, Test_PermuteEquality_3) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {4, 5, 3}, {1.0, 21.0, 41.0, 2.0, 22.0, 42.0, 3.0, 23.0, 43.0, 4.0, 24.0, 44.0, 5.0, 25.0, 45.0, 6.0, 26.0, 46.0, 7.0, 27.0, 47.0, 8.0, 28.0, 48.0, 9.0, 29.0, 49.0, 10.0, 30.0, 50.0, 11.0, 31.0, 51.0, 12.0, 32.0, 52.0, 13.0, 33.0, 53.0, 14.0, 34.0, 54.0, 15.0, 35.0, 55.0, 16.0, 36.0, 56.0, 17.0, 37.0, 57.0, 18.0, 38.0, 58.0, 19.0, 39.0, 59.0, 20.0, 40.0, 60.0}); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {4, 5, 3}, + {1.0, 21.0, 41.0, 2.0, 22.0, 42.0, 3.0, 23.0, 43.0, 4.0, 24.0, 44.0, + 5.0, 25.0, 45.0, 6.0, 26.0, 46.0, 7.0, 27.0, 47.0, 8.0, 28.0, 48.0, + 9.0, 29.0, 49.0, 10.0, 30.0, 50.0, 11.0, 31.0, 51.0, 12.0, 32.0, 52.0, + 13.0, 33.0, 53.0, 14.0, 34.0, 54.0, 15.0, 35.0, 55.0, 16.0, 36.0, 56.0, + 17.0, 37.0, 57.0, 18.0, 38.0, 58.0, 19.0, 39.0, 59.0, 20.0, 40.0, 60.0}); + x.reshapei('c', {3, 4, 5}); - x.permutei({1, 2, 0}); - x.streamline(); + x.permutei({1, 2, 0}); + x.streamline(); -// x.printShapeInfo("{1, 2, 0} shape"); -// x.printBuffer("{1, 2, 0} data"); + // x.printShapeInfo("{1, 2, 0} shape"); + // x.printBuffer("{1, 2, 0} data"); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayTest2, Test_PermuteEquality_4) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {5, 3, 4}, {1.0, 6.0, 11.0, 16.0, 21.0, 26.0, 31.0, 36.0, 41.0, 46.0, 51.0, 56.0, 2.0, 7.0, 12.0, 17.0, 22.0, 27.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 3.0, 8.0, 13.0, 18.0, 23.0, 28.0, 33.0, 38.0, 43.0, 48.0, 53.0, 58.0, 4.0, 9.0, 14.0, 19.0, 24.0, 29.0, 34.0, 39.0, 44.0, 49.0, 54.0, 59.0, 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0}); - x.reshapei('c', {3, 4, 5}); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {5, 3, 4}, + {1.0, 6.0, 11.0, 16.0, 21.0, 26.0, 31.0, 36.0, 41.0, 46.0, 51.0, 56.0, + 2.0, 7.0, 12.0, 17.0, 22.0, 27.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, + 3.0, 8.0, 13.0, 18.0, 23.0, 28.0, 33.0, 38.0, 43.0, 48.0, 53.0, 58.0, + 4.0, 9.0, 14.0, 19.0, 24.0, 29.0, 34.0, 39.0, 44.0, 49.0, 54.0, 59.0, + 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0}); + x.reshapei('c', {3, 4, 5}); - x.permutei({2, 0, 1}); - x.streamline(); + x.permutei({2, 0, 1}); + x.streamline(); -// x.printShapeInfo("{2, 0, 1} shape"); -// x.printBuffer("{2, 0, 1} data"); + // x.printShapeInfo("{2, 0, 1} shape"); + // x.printBuffer("{2, 0, 1} data"); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(NDArrayTest2, Test_PermuteEquality_5) { - auto x = NDArrayFactory::create('c', {1, 60}); - x.linspace(1); - auto exp = NDArrayFactory::create('c', {5, 4, 3}, - {1.0, 21.0, 41.0, 6.0, 26.0, 46.0, 11.0, 31.0, 51.0, 16.0, 36.0, 56.0, 2.0, 22.0, 42.0, 7.0, - 27.0, 47.0, 12.0, 32.0, 52.0, 17.0, 37.0, 57.0, 3.0, 23.0, 43.0, 8.0, 28.0, 48.0, 13.0, 33.0, - 53.0, 18.0, 38.0, 58.0, 4.0, 24.0, 44.0, 9.0, 29.0, 49.0, 14.0, 34.0, 54.0, 19.0, 39.0, 59.0, - 5.0, 25.0, 45.0, 10.0, 30.0, 50.0, 15.0, 35.0, 55.0, 20.0, 40.0, 60.0}); - x.reshapei('c', {3, 4, 5}); - - x.permutei({2, 1, 0}); - x.streamline(); + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create( + 'c', {5, 4, 3}, + {1.0, 21.0, 41.0, 6.0, 26.0, 46.0, 11.0, 31.0, 51.0, 16.0, 36.0, 56.0, + 2.0, 22.0, 42.0, 7.0, 27.0, 47.0, 12.0, 32.0, 52.0, 17.0, 37.0, 57.0, + 3.0, 23.0, 43.0, 8.0, 28.0, 48.0, 13.0, 33.0, 53.0, 18.0, 38.0, 58.0, + 4.0, 24.0, 44.0, 9.0, 29.0, 49.0, 14.0, 34.0, 54.0, 19.0, 39.0, 59.0, + 5.0, 25.0, 45.0, 10.0, 30.0, 50.0, 15.0, 35.0, 55.0, 20.0, 40.0, 60.0}); + x.reshapei('c', {3, 4, 5}); -// x.printShapeInfo("{2, 0, 1} shape"); -// x.printBuffer("{2, 0, 1} data"); + x.permutei({2, 1, 0}); + x.streamline(); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + // x.printShapeInfo("{2, 0, 1} shape"); + // x.printBuffer("{2, 0, 1} data"); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, fillAsTriangular_test1) { + auto x = NDArrayFactory::create( + 'c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, {1, 0, 0, 0, 5, 6, 0, 0, 9, 10, 11, 0, 13, 14, 15, 16}); - auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1,0,0,0,5,6,0,0,9,10,11,0 ,13,14,15,16}); - - x.fillAsTriangular(0., 0, 0, x, 'u'); - - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + x.fillAsTriangular(0., 0, 0, x, 'u'); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, fillAsTriangular_test2) { + auto x = NDArrayFactory::create( + 'c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, {0, 0, 0, 0, 5, 0, 0, 0, 9, 10, 0, 0, 13, 14, 15, 0}); - auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); - auto exp = NDArrayFactory::create('c', {4, 4}, {0,0,0,0,5,0,0,0,9,10,0 ,0 ,13,14,15,0}); - - x.fillAsTriangular(0., 0, -1, x, 'u'); - - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + x.fillAsTriangular(0., 0, -1, x, 'u'); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, fillAsTriangular_test3) { + auto x = NDArrayFactory::create( + 'c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, {1, 2, 3, 4, 0, 6, 7, 8, 0, 0, 11, 12, 0, 0, 0, 16}); - auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); - auto exp = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,0,6,7,8,0,0 ,11,12,0 ,0 , 0,16}); - - x.fillAsTriangular(0., 0, 0, x, 'l'); - - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + x.fillAsTriangular(0., 0, 0, x, 'l'); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, fillAsTriangular_test4) { + auto x = NDArrayFactory::create( + 'c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto exp = NDArrayFactory::create( + 'c', {4, 4}, {0, 2, 3, 4, 0, 0, 7, 8, 0, 0, 0, 12, 0, 0, 0, 0}); - auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); - auto exp = NDArrayFactory::create('c', {4, 4}, {0,2,3,4,0,0,7,8,0,0 , 0,12, 0, 0, 0, 0}); - - x.fillAsTriangular(0., 1, 0, x, 'l'); + x.fillAsTriangular(0., 1, 0, x, 'l'); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_DType_Conversion_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); - auto xd = x.template asT(); + auto xd = x.template asT(); - auto xf = xd.template asT(); + auto xf = xd.template asT(); - ASSERT_TRUE(x.isSameShape(xf)); - ASSERT_TRUE(x.equalsTo(xf)); + ASSERT_TRUE(x.isSameShape(xf)); + ASSERT_TRUE(x.equalsTo(xf)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_ScalarArray_Assign_1) { - auto x = NDArrayFactory::create('c', {2, 2}); - auto y = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create('c', {2, 2}, {2.0f, 2.0f, 2.0f, 2.0f}); + auto x = NDArrayFactory::create('c', {2, 2}); + auto y = NDArrayFactory::create(2.0f); + auto exp = + NDArrayFactory::create('c', {2, 2}, {2.0f, 2.0f, 2.0f, 2.0f}); - x.assign(y); + x.assign(y); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_Reshape_To_Vector_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto exp = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto x = NDArrayFactory::create('c', {2, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto exp = + NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - x.reshapei({-1}); + x.reshapei({-1}); - ASSERT_TRUE(exp.isSameShape(x)); - ASSERT_TRUE(exp.equalsTo(x)); + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); } - TEST_F(NDArrayTest2, Test_toIndexedString_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {1.5f, 2.5f, 3.f, 4.5f}); + auto x = NDArrayFactory::create('c', {2, 2}, {1.5f, 2.5f, 3.f, 4.5f}); - auto str = x.asIndexedString(); - std::string exp = "[1.5, 2.5, 3, 4.5]"; + auto str = x.asIndexedString(); + std::string exp = "[1.5, 2.5, 3, 4.5]"; - ASSERT_EQ(exp, str); + ASSERT_EQ(exp, str); } - ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, permute_test4) { + Nd4jLong arr1ShapeInfo[] = {6, 1, 1, 4, 3, 2, 2, 48, + 48, 12, 4, 2, 1, 8192, 1, 99}; + Nd4jLong arr2ShapeInfo[] = {6, 1, 2, 2, 1, 4, 3, 48, + 2, 1, 48, 12, 4, 8192, 0, 99}; - Nd4jLong arr1ShapeInfo[] = {6, 1, 1, 4, 3, 2, 2, 48, 48, 12, 4, 2, 1, 8192, 1, 99}; - Nd4jLong arr2ShapeInfo[] = {6, 1, 2, 2, 1, 4, 3, 48, 2, 1, 48, 12, 4, 8192, 0, 99}; + auto arr1Buffer = new float[786432]; + auto arr2Buffer = new float[786432]; + NDArray arr1(arr1Buffer, arr1ShapeInfo, sd::LaunchContext ::defaultContext()); + NDArray arr2(arr2Buffer, arr2ShapeInfo, sd::LaunchContext ::defaultContext()); - auto arr1Buffer = new float[786432]; - auto arr2Buffer = new float[786432]; + const std::vector perm = {0, 4, 5, 1, 2, 3}; + auto arr1P = arr1.permute(perm); + // arr1P->printShapeInfo(); - NDArray arr1(arr1Buffer, arr1ShapeInfo, sd::LaunchContext ::defaultContext()); - NDArray arr2(arr2Buffer, arr2ShapeInfo, sd::LaunchContext ::defaultContext()); - - const std::vector perm = {0, 4, 5, 1, 2, 3}; - auto arr1P = arr1.permute(perm); - // arr1P->printShapeInfo(); - - // ASSERT_TRUE(arr1.isSameShapeStrict(&arr2)); - ASSERT_TRUE(arr1P.isSameShapeStrict(arr2)); - delete []arr1Buffer; - delete []arr2Buffer; + // ASSERT_TRUE(arr1.isSameShapeStrict(&arr2)); + ASSERT_TRUE(arr1P.isSameShapeStrict(arr2)); + delete[] arr1Buffer; + delete[] arr2Buffer; } //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, TestStdDev3) { + // autoarray('c', {10, 10}); + auto array = NDArrayFactory::create('c', {2, 2}, + {0.2946, 0.2084, 0.0345, 0.7368}); + const int len = array.lengthOf(); - // autoarray('c', {10, 10}); - auto array = NDArrayFactory::create('c', {2, 2}, {0.2946, 0.2084, 0.0345, 0.7368}); - const int len = array.lengthOf(); + double sum = 0.; + for (int i = 0; i < len; ++i) sum += array.e(i); - double sum = 0.; - for(int i=0; i < len; ++i) - sum += array.e(i); + const double mean = sum / len; - const double mean = sum / len; + double diffSquared = 0.; + for (int i = 0; i < len; ++i) + diffSquared += (array.e(i) - mean) * (array.e(i) - mean); - double diffSquared = 0.; - for(int i=0; i < len; ++i) - diffSquared += (array.e(i) - mean) * (array.e(i) - mean); + const double trueVariance = + math::nd4j_sqrt(diffSquared / len); + const double trueVarianceCorr = + math::nd4j_sqrt(diffSquared / (len - 1)); - const double trueVariance = math::nd4j_sqrt(diffSquared / len); - const double trueVarianceCorr = math::nd4j_sqrt(diffSquared / (len - 1)); + const double variance = + array.varianceNumber(variance::SummaryStatsStandardDeviation, false) + .e(0); + const double varianceCorr = + array.varianceNumber(variance::SummaryStatsStandardDeviation, true) + .e(0); - const double variance = array.varianceNumber(variance::SummaryStatsStandardDeviation, false).e(0); - const double varianceCorr = array.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + // printf("%s expected %.10f calculated %.10f\n","variance :", + // trueVariance, variance ); printf("%s expected %.10f calculated + // %.10f\n","variance corrected:", trueVarianceCorr, varianceCorr); - // printf("%s expected %.10f calculated %.10f\n","variance :", trueVariance, variance ); - // printf("%s expected %.10f calculated %.10f\n","variance corrected:", trueVarianceCorr, varianceCorr); - - ASSERT_NEAR(trueVariance, variance, 1e-8); - ASSERT_NEAR(trueVarianceCorr, varianceCorr, 1e-8); + ASSERT_NEAR(trueVariance, variance, 1e-8); + ASSERT_NEAR(trueVarianceCorr, varianceCorr, 1e-8); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_Linspace_1) { - auto exp = NDArrayFactory::create('c',{1,5}, {1., 2., 3., 4., 5.}); - auto x = NDArrayFactory::create('c', {1, 5}); - x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 5}, {1., 2., 3., 4., 5.}); + auto x = NDArrayFactory::create('c', {1, 5}); + x.linspace(1); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_Linspace_2) { - auto exp = NDArrayFactory::create('c',{1,5}, {1., 3., 5., 7., 9.}); - auto x = NDArrayFactory::create('c', {1, 5}); + auto exp = NDArrayFactory::create('c', {1, 5}, {1., 3., 5., 7., 9.}); + auto x = NDArrayFactory::create('c', {1, 5}); - x.linspace(1, 2); + x.linspace(1, 2); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_Linspace_3) { + auto exp = + NDArrayFactory::create('c', {1, 5}, {1., 4., 7., 10., 13.}); - auto exp = NDArrayFactory::create('c',{1,5}, {1., 4., 7., 10., 13.}); - - auto x = NDArrayFactory::create('c', {1, 5}); - x.linspace(1,3); + auto x = NDArrayFactory::create('c', {1, 5}); + x.linspace(1, 3); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_Linspace_4) { - auto exp = NDArrayFactory::create('c',{1,5}, {-1., -2., -3., -4., -5.}); + auto exp = + NDArrayFactory::create('c', {1, 5}, {-1., -2., -3., -4., -5.}); - auto x = NDArrayFactory::create('c', {1, 5}); - x.linspace(-1, -1); + auto x = NDArrayFactory::create('c', {1, 5}); + x.linspace(-1, -1); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, Test_Linspace_5) { - auto exp = NDArrayFactory::create('c',{1,5}, {9., 8., 7., 6., 5.}); + auto exp = NDArrayFactory::create('c', {1, 5}, {9., 8., 7., 6., 5.}); - auto x = NDArrayFactory::create('c', {1, 5}); - x.linspace(9, -1); + auto x = NDArrayFactory::create('c', {1, 5}); + x.linspace(9, -1); - ASSERT_TRUE(x.equalsTo(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); } - //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, allTensorsAlongDimension_test1) { + auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - - auto set = x.allTensorsAlongDimension({0}); - // set->at(0)->printShapeInfo(); - // set->at(0)->printIndexedBuffer(); + auto set = x.allTensorsAlongDimension({0}); + // set->at(0)->printShapeInfo(); + // set->at(0)->printIndexedBuffer(); - ASSERT_TRUE(set.size() == 1); - ASSERT_TRUE(exp.isSameShape(set.at(0))); - ASSERT_TRUE(exp.equalsTo(set.at(0))); + ASSERT_TRUE(set.size() == 1); + ASSERT_TRUE(exp.isSameShape(set.at(0))); + ASSERT_TRUE(exp.equalsTo(set.at(0))); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, scalar_get_test1) { + auto scalar1 = NDArrayFactory::create(20.f); - auto scalar1 = NDArrayFactory::create(20.f); + NDArray arr('c', {2, 2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); - NDArray arr('c', {2,2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); + NDArray scalar2 = arr.e(2); - NDArray scalar2 = arr.e(2); - - ASSERT_TRUE(scalar1.isSameShape(scalar2)); - ASSERT_TRUE(scalar1.equalsTo(scalar2)); - ASSERT_TRUE(scalar1.dataType() == scalar2.dataType()); + ASSERT_TRUE(scalar1.isSameShape(scalar2)); + ASSERT_TRUE(scalar1.equalsTo(scalar2)); + ASSERT_TRUE(scalar1.dataType() == scalar2.dataType()); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, scalar_get_test2) { + auto scalar1 = NDArrayFactory::create(20.f); - auto scalar1 = NDArrayFactory::create(20.f); - - NDArray arr('f', {2,2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); + NDArray arr('f', {2, 2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); - NDArray scalar2 = arr.e(1); + NDArray scalar2 = arr.e(1); - ASSERT_TRUE(scalar1.isSameShape(scalar2)); - ASSERT_TRUE(scalar1.equalsTo(scalar2)); - ASSERT_TRUE(scalar1.dataType() == scalar2.dataType()); + ASSERT_TRUE(scalar1.isSameShape(scalar2)); + ASSERT_TRUE(scalar1.equalsTo(scalar2)); + ASSERT_TRUE(scalar1.dataType() == scalar2.dataType()); } //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, scalar_set_test1) { + NDArray scalar1 = NDArrayFactory::create(20.f); - NDArray scalar1 = NDArrayFactory::create(20.f); + NDArray arr('c', {2, 2}, {0., 10., -20., 30.}, sd::DataType::FLOAT32); + NDArray exp('c', {2, 2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); - NDArray arr('c', {2,2}, {0., 10., -20., 30.}, sd::DataType::FLOAT32); - NDArray exp('c', {2,2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); + arr.p(2, scalar1); - arr.p(2, scalar1); - - ASSERT_TRUE(exp.equalsTo(arr)); + ASSERT_TRUE(exp.equalsTo(arr)); } - //////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, scalar_set_test2) { + NDArray scalar1 = NDArrayFactory::create(20.f); - NDArray scalar1 = NDArrayFactory::create(20.f); - - NDArray arr('f', {2,2}, {0., 10., -20., 30.}, sd::DataType::FLOAT32); - NDArray exp('f', {2,2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); + NDArray arr('f', {2, 2}, {0., 10., -20., 30.}, sd::DataType::FLOAT32); + NDArray exp('f', {2, 2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); - arr.p(1, scalar1); + arr.p(1, scalar1); - ASSERT_TRUE(exp.equalsTo(arr)); + ASSERT_TRUE(exp.equalsTo(arr)); } TEST_F(NDArrayTest2, big_dup_test) { - // auto arr = NDArrayFactory::linspace(1.0f, 10000000.0f, 100000000); - auto arr = NDArrayFactory::linspace(1.0f, 1000.0f, 10000); - auto dup = new NDArray(arr->dup('c')); + // auto arr = NDArrayFactory::linspace(1.0f, 10000000.0f, 100000000); + auto arr = NDArrayFactory::linspace(1.0f, 1000.0f, 10000); + auto dup = new NDArray(arr->dup('c')); - ASSERT_EQ(*arr, *dup); + ASSERT_EQ(*arr, *dup); - delete arr; - delete dup; + delete arr; + delete dup; } TEST_F(NDArrayTest2, debugInfoTest_1) { - NDArray testArray('c', {2, 4, 4, 4}, { - 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., -119., 12., 112., 13., 14., 114., 16., 117., - 91., -82., 37., 64., -55.1, 0, 73., 28., -119., 12., 112., 13., 14., 114., 16.2, 117., - 91., -82., 37., 64., 55., 46., 73., 28., -119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 0., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 91., 82., 37., 64., -3, 0, 73., 28., 119., 12., 112., 13., 140., 110., 160., 107.}, sd::DataType::DOUBLE); - NDArray res(sd::DataType::DOUBLE); - DebugInfo info = DebugHelper::debugStatistics(&testArray); - DebugInfo exp; // = {} - sd::ops::reduce_min minOp; - sd::ops::reduce_mean meanOp; - sd::ops::reduce_max maxOp; - sd::ops::reduce_stdev stdevOp; - - minOp.execute({&testArray}, {&res}, {}, {}, {}); - exp._minValue = res.e(0); - meanOp.execute({&testArray}, {&res}, {}, {}, {}); - exp._meanValue = res.e(0); - maxOp.execute({&testArray}, {&res}, {}, {}, {}); - exp._maxValue = res.e(0); - stdevOp.execute({&testArray}, {&res}, {}, {}, {}); - exp._stdDevValue = res.e(0); - exp._zeroCount = 3; - exp._negativeCount = 7; - exp._positiveCount = 118; - exp._infCount = 0; - exp._nanCount = 0; - printf("Output statistics %lf %lf %lf %lf\n", info._minValue, info._maxValue, info._meanValue, info._stdDevValue); - printf("Expect statistics %lf %lf %lf %lf\n", exp._minValue, exp._maxValue, exp._meanValue, exp._stdDevValue); - printf("%lld %lld %lld %lld %lld\n", info._zeroCount, info._negativeCount, info._positiveCount, info._infCount, info._nanCount); - ASSERT_EQ(exp, info); + NDArray testArray( + 'c', {2, 4, 4, 4}, + {91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., + 13., 14., 114., 16., 117., 51., 42., 67., 24., 15., 56., + 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., 31., + 22., 87., 44., 55., 46., 73., 28., -119., 12., 112., 13., + 14., 114., 16., 117., 91., -82., 37., 64., -55.1, 0, 73., + 28., -119., 12., 112., 13., 14., 114., 16.2, 117., 91., -82., + 37., 64., 55., 46., 73., 28., -119., 12., 112., 13., 14., + 114., 16., 117., 51., 42., 67., 24., 15., 0., 93., 28., + 109., 82., 12., 113., 114., 14., 116., 11., 31., 22., 87., + 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., + 16., 117., 91., 82., 37., 64., -3, 0, 73., 28., 119., + 12., 112., 13., 140., 110., 160., 107.}, + sd::DataType::DOUBLE); + NDArray res(sd::DataType::DOUBLE); + DebugInfo info = DebugHelper::debugStatistics(&testArray); + DebugInfo exp; // = {} + sd::ops::reduce_min minOp; + sd::ops::reduce_mean meanOp; + sd::ops::reduce_max maxOp; + sd::ops::reduce_stdev stdevOp; + + minOp.execute({&testArray}, {&res}, {}, {}, {}); + exp._minValue = res.e(0); + meanOp.execute({&testArray}, {&res}, {}, {}, {}); + exp._meanValue = res.e(0); + maxOp.execute({&testArray}, {&res}, {}, {}, {}); + exp._maxValue = res.e(0); + stdevOp.execute({&testArray}, {&res}, {}, {}, {}); + exp._stdDevValue = res.e(0); + exp._zeroCount = 3; + exp._negativeCount = 7; + exp._positiveCount = 118; + exp._infCount = 0; + exp._nanCount = 0; + printf("Output statistics %lf %lf %lf %lf\n", info._minValue, info._maxValue, + info._meanValue, info._stdDevValue); + printf("Expect statistics %lf %lf %lf %lf\n", exp._minValue, exp._maxValue, + exp._meanValue, exp._stdDevValue); + printf("%lld %lld %lld %lld %lld\n", info._zeroCount, info._negativeCount, + info._positiveCount, info._infCount, info._nanCount); + ASSERT_EQ(exp, info); } TEST_F(NDArrayTest2, debugInfoTest_2) { - NDArray testArray('c', {2, 4, 4, 4}, { - 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., -119., 12., 112., 13., 14., 114., 16., 117., - 91., -82., 37., 64., -55.1, 0, 73., 28., -119., 12., 112., 13., 14., 114., 16.2, 117., - 91., -82., 37., 64., 55., 46., 73., 28., -119., 12., 112., 13., 14., 114., 16., 117., - 51., 42., 67., 24., 15., 0., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., - 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., - 91., 82., 37., 64., -3, 0, 73., 28., 119., 12., 112., 13., 140., 110., 160., 107.}, sd::DataType::DOUBLE); - - DebugInfo info; - DebugInfo exp; // = {} - exp._minValue = -119; - exp._maxValue = 160.; - exp._meanValue = 51.328906; - exp._stdDevValue = 52.385694; - exp._zeroCount = 3; - exp._negativeCount = 7; - exp._positiveCount = 118; - exp._infCount = 0; - exp._nanCount = 0; - DebugHelper::retrieveDebugStatistics(&info, &testArray); - printf("Output statistics %lf %lf %lf %lf\n", info._minValue, info._maxValue, info._meanValue, info._stdDevValue); - printf("Expect statistics %lf %lf %lf %lf\n", exp._minValue, exp._maxValue, exp._meanValue, exp._stdDevValue); - printf("%lld %lld %lld %lld %lld\n", info._zeroCount, info._negativeCount, info._positiveCount, info._infCount, info._nanCount); - //printf("%lf %lf %lf %lf\n", info._minValue, info._maxValue, info._meanValue, info._stdDevValue); - //printf("%lld %lld %lld %lld %lld\n", info._zeroCount, info._negativeCount, info._positiveCount, info._infCount, info._nanCount); - ASSERT_EQ(exp, info); + NDArray testArray( + 'c', {2, 4, 4, 4}, + {91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., + 13., 14., 114., 16., 117., 51., 42., 67., 24., 15., 56., + 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., 31., + 22., 87., 44., 55., 46., 73., 28., -119., 12., 112., 13., + 14., 114., 16., 117., 91., -82., 37., 64., -55.1, 0, 73., + 28., -119., 12., 112., 13., 14., 114., 16.2, 117., 91., -82., + 37., 64., 55., 46., 73., 28., -119., 12., 112., 13., 14., + 114., 16., 117., 51., 42., 67., 24., 15., 0., 93., 28., + 109., 82., 12., 113., 114., 14., 116., 11., 31., 22., 87., + 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., + 16., 117., 91., 82., 37., 64., -3, 0, 73., 28., 119., + 12., 112., 13., 140., 110., 160., 107.}, + sd::DataType::DOUBLE); + + DebugInfo info; + DebugInfo exp; // = {} + exp._minValue = -119; + exp._maxValue = 160.; + exp._meanValue = 51.328906; + exp._stdDevValue = 52.385694; + exp._zeroCount = 3; + exp._negativeCount = 7; + exp._positiveCount = 118; + exp._infCount = 0; + exp._nanCount = 0; + DebugHelper::retrieveDebugStatistics(&info, &testArray); + printf("Output statistics %lf %lf %lf %lf\n", info._minValue, info._maxValue, + info._meanValue, info._stdDevValue); + printf("Expect statistics %lf %lf %lf %lf\n", exp._minValue, exp._maxValue, + exp._meanValue, exp._stdDevValue); + printf("%lld %lld %lld %lld %lld\n", info._zeroCount, info._negativeCount, + info._positiveCount, info._infCount, info._nanCount); + // printf("%lf %lf %lf %lf\n", info._minValue, info._maxValue, + // info._meanValue, info._stdDevValue); printf("%lld %lld %lld %lld %lld\n", + // info._zeroCount, info._negativeCount, info._positiveCount, info._infCount, + // info._nanCount); + ASSERT_EQ(exp, info); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, test_subarray_ews_1) { + NDArray x('c', {10, 5}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::all(), NDIndex::point(2)}); - NDArray x('c', {10, 5}, sd::DataType::FLOAT32); - auto subArr1 = x.subarray({NDIndex::all(), NDIndex::point(2)}); - - ASSERT_EQ(5, subArr1.ews()); + ASSERT_EQ(5, subArr1.ews()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, test_subarray_ews_2) { + NDArray x('f', {10, 5}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::all(), NDIndex::point(2)}); - NDArray x('f', {10, 5}, sd::DataType::FLOAT32); - auto subArr1 = x.subarray({NDIndex::all(), NDIndex::point(2)}); - - ASSERT_EQ(1, subArr1.ews()); + ASSERT_EQ(1, subArr1.ews()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, test_subarray_ews_3) { + NDArray x('c', {10, 5}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::point(2), NDIndex::all()}); - NDArray x('c', {10, 5}, sd::DataType::FLOAT32); - auto subArr1 = x.subarray({NDIndex::point(2), NDIndex::all()}); - - ASSERT_EQ(1, subArr1.ews()); + ASSERT_EQ(1, subArr1.ews()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, test_subarray_ews_4) { + NDArray x('f', {10, 5}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::point(2), NDIndex::all()}); - NDArray x('f', {10, 5}, sd::DataType::FLOAT32); - auto subArr1 = x.subarray({NDIndex::point(2), NDIndex::all()}); - - ASSERT_EQ(10, subArr1.ews()); + ASSERT_EQ(10, subArr1.ews()); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, subarray_1) { - - NDArray x('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); - NDArray y('f', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); - - Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 12, 99}; - float buffExpX0[] = {1.000000, 13.000000}; - float buffExpX1[] = {2.000000, 14.000000}; - Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 12, 99}; - float buffExpX2[] = {1.000000, 13.000000}; - Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 0, 99}; - float buffExpX3[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000}; - Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 0, 99}; - float buffExpX4[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000}; - Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 4, 99}; - float buffExpX5[] = {4.000000, 8.000000, 12.000000, 16.000000, 20.000000, 24.000000}; - - Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 102}; - float buffExpY0[] = {1.000000, 2.000000}; - float buffExpY1[] = {7.000000, 8.000000}; - Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102}; - float buffExpY2[] = {1.000000, 2.000000}; - Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 0, 102}; - float buffExpY3[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000}; - Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 0, 102}; - float buffExpY4[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000}; - Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 102}; - float buffExpY5[] = {19.000000, 21.000000, 23.000000, 20.000000, 22.000000, 24.000000}; - - - NDArray x0 = x(0, {1,2}); - for(int i = 0; i < shape::shapeInfoLength(x0.rankOf()); ++i) - ASSERT_TRUE(x0.shapeInfo()[i] == shapeExpX0[i]); - for(int i = 0; i < x0.lengthOf(); ++i) - ASSERT_TRUE(x0.e(i) == buffExpX0[i]); - - NDArray x1 = x(1, {1,2}); - for(int i = 0; i < shape::shapeInfoLength(x1.rankOf()); ++i) - ASSERT_TRUE(x1.shapeInfo()[i] == shapeExpX0[i]); - for(int i = 0; i < x1.lengthOf(); ++i) - ASSERT_TRUE(x1.e(i) == buffExpX1[i]); - - NDArray x2 = x(0, {1,2}, true); - for(int i = 0; i < shape::shapeInfoLength(x2.rankOf()); ++i) - ASSERT_TRUE(x2.shapeInfo()[i] == shapeExpX2[i]); - for(int i = 0; i < x2.lengthOf(); ++i) - ASSERT_TRUE(x2.e(i) == buffExpX2[i]); - - NDArray x3 = x(2, {1}); - for(int i = 0; i < shape::shapeInfoLength(x3.rankOf()); ++i) - ASSERT_TRUE(x3.shapeInfo()[i] == shapeExpX3[i]); - for(int i = 0; i < x3.lengthOf(); ++i) - ASSERT_TRUE(x3.e(i) == buffExpX3[i]); - - NDArray x4 = x(2, {1}, true); - for(int i = 0; i < shape::shapeInfoLength(x4.rankOf()); ++i) - ASSERT_TRUE(x4.shapeInfo()[i] == shapeExpX4[i]); - for(int i = 0; i < x4.lengthOf(); ++i) - ASSERT_TRUE(x4.e(i) == buffExpX4[i]); - - NDArray x5 = x(3, {2}); - for(int i = 0; i < shape::shapeInfoLength(x5.rankOf()); ++i) - ASSERT_TRUE(x5.shapeInfo()[i] == shapeExpX5[i]); - for(int i = 0; i < x5.lengthOf(); ++i) - ASSERT_TRUE(x5.e(i) == buffExpX5[i]); - - // ******************* // - NDArray y0 = y(0, {1,2}); - for(int i = 0; i < shape::shapeInfoLength(y0.rankOf()); ++i) - ASSERT_TRUE(y0.shapeInfo()[i] == shapeExpY0[i]); - for(int i = 0; i < y0.lengthOf(); ++i) - ASSERT_TRUE(y0.e(i) == buffExpY0[i]); - - NDArray y1 = y(1, {1,2}); - for(int i = 0; i < shape::shapeInfoLength(y1.rankOf()); ++i) - ASSERT_TRUE(y1.shapeInfo()[i] == shapeExpY0[i]); - for(int i = 0; i < y1.lengthOf(); ++i) - ASSERT_TRUE(y1.e(i) == buffExpY1[i]); - - NDArray y2 = y(0, {1,2}, true); - for(int i = 0; i < shape::shapeInfoLength(y2.rankOf()); ++i) - ASSERT_TRUE(y2.shapeInfo()[i] == shapeExpY2[i]); - for(int i = 0; i < y2.lengthOf(); ++i) - ASSERT_TRUE(y2.e(i) == buffExpY2[i]); - - NDArray y3 = y(2, {1}); - for(int i = 0; i < shape::shapeInfoLength(y3.rankOf()); ++i) - ASSERT_TRUE(y3.shapeInfo()[i] == shapeExpY3[i]); - for(int i = 0; i < y3.lengthOf(); ++i) - ASSERT_TRUE(y3.e(i) == buffExpY3[i]); - - NDArray y4 = y(2, {1}, true); - for(int i = 0; i < shape::shapeInfoLength(y4.rankOf()); ++i) - ASSERT_TRUE(y4.shapeInfo()[i] == shapeExpY4[i]); - for(int i = 0; i < y4.lengthOf(); ++i) - ASSERT_TRUE(y4.e(i) == buffExpY4[i]); - - NDArray y5 = y(3, {2}); - for(int i = 0; i < shape::shapeInfoLength(y5.rankOf()); ++i) - ASSERT_TRUE(y5.shapeInfo()[i] == shapeExpY5[i]); - for(int i = 0; i < y5.lengthOf(); ++i) - ASSERT_TRUE(y5.e(i) == buffExpY5[i]); - + NDArray x('c', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); + NDArray y('f', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, + sd::DataType::FLOAT32); + + Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 12, 99}; + float buffExpX0[] = {1.000000, 13.000000}; + float buffExpX1[] = {2.000000, 14.000000}; + Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 12, 99}; + float buffExpX2[] = {1.000000, 13.000000}; + Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 0, 99}; + float buffExpX3[] = {9.000000, 10.000000, 11.000000, 12.000000, + 21.000000, 22.000000, 23.000000, 24.000000}; + Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 0, 99}; + float buffExpX4[] = {9.000000, 10.000000, 11.000000, 12.000000, + 21.000000, 22.000000, 23.000000, 24.000000}; + Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 4, 99}; + float buffExpX5[] = {4.000000, 8.000000, 12.000000, + 16.000000, 20.000000, 24.000000}; + + Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 102}; + float buffExpY0[] = {1.000000, 2.000000}; + float buffExpY1[] = {7.000000, 8.000000}; + Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102}; + float buffExpY2[] = {1.000000, 2.000000}; + Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 0, 102}; + float buffExpY3[] = {5.000000, 11.000000, 17.000000, 23.000000, + 6.000000, 12.000000, 18.000000, 24.000000}; + Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 0, 102}; + float buffExpY4[] = {5.000000, 11.000000, 17.000000, 23.000000, + 6.000000, 12.000000, 18.000000, 24.000000}; + Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 102}; + float buffExpY5[] = {19.000000, 21.000000, 23.000000, + 20.000000, 22.000000, 24.000000}; + + NDArray x0 = x(0, {1, 2}); + for (int i = 0; i < shape::shapeInfoLength(x0.rankOf()); ++i) + ASSERT_TRUE(x0.shapeInfo()[i] == shapeExpX0[i]); + for (int i = 0; i < x0.lengthOf(); ++i) + ASSERT_TRUE(x0.e(i) == buffExpX0[i]); + + NDArray x1 = x(1, {1, 2}); + for (int i = 0; i < shape::shapeInfoLength(x1.rankOf()); ++i) + ASSERT_TRUE(x1.shapeInfo()[i] == shapeExpX0[i]); + for (int i = 0; i < x1.lengthOf(); ++i) + ASSERT_TRUE(x1.e(i) == buffExpX1[i]); + + NDArray x2 = x(0, {1, 2}, true); + for (int i = 0; i < shape::shapeInfoLength(x2.rankOf()); ++i) + ASSERT_TRUE(x2.shapeInfo()[i] == shapeExpX2[i]); + for (int i = 0; i < x2.lengthOf(); ++i) + ASSERT_TRUE(x2.e(i) == buffExpX2[i]); + + NDArray x3 = x(2, {1}); + for (int i = 0; i < shape::shapeInfoLength(x3.rankOf()); ++i) + ASSERT_TRUE(x3.shapeInfo()[i] == shapeExpX3[i]); + for (int i = 0; i < x3.lengthOf(); ++i) + ASSERT_TRUE(x3.e(i) == buffExpX3[i]); + + NDArray x4 = x(2, {1}, true); + for (int i = 0; i < shape::shapeInfoLength(x4.rankOf()); ++i) + ASSERT_TRUE(x4.shapeInfo()[i] == shapeExpX4[i]); + for (int i = 0; i < x4.lengthOf(); ++i) + ASSERT_TRUE(x4.e(i) == buffExpX4[i]); + + NDArray x5 = x(3, {2}); + for (int i = 0; i < shape::shapeInfoLength(x5.rankOf()); ++i) + ASSERT_TRUE(x5.shapeInfo()[i] == shapeExpX5[i]); + for (int i = 0; i < x5.lengthOf(); ++i) + ASSERT_TRUE(x5.e(i) == buffExpX5[i]); + + // ******************* // + NDArray y0 = y(0, {1, 2}); + for (int i = 0; i < shape::shapeInfoLength(y0.rankOf()); ++i) + ASSERT_TRUE(y0.shapeInfo()[i] == shapeExpY0[i]); + for (int i = 0; i < y0.lengthOf(); ++i) + ASSERT_TRUE(y0.e(i) == buffExpY0[i]); + + NDArray y1 = y(1, {1, 2}); + for (int i = 0; i < shape::shapeInfoLength(y1.rankOf()); ++i) + ASSERT_TRUE(y1.shapeInfo()[i] == shapeExpY0[i]); + for (int i = 0; i < y1.lengthOf(); ++i) + ASSERT_TRUE(y1.e(i) == buffExpY1[i]); + + NDArray y2 = y(0, {1, 2}, true); + for (int i = 0; i < shape::shapeInfoLength(y2.rankOf()); ++i) + ASSERT_TRUE(y2.shapeInfo()[i] == shapeExpY2[i]); + for (int i = 0; i < y2.lengthOf(); ++i) + ASSERT_TRUE(y2.e(i) == buffExpY2[i]); + + NDArray y3 = y(2, {1}); + for (int i = 0; i < shape::shapeInfoLength(y3.rankOf()); ++i) + ASSERT_TRUE(y3.shapeInfo()[i] == shapeExpY3[i]); + for (int i = 0; i < y3.lengthOf(); ++i) + ASSERT_TRUE(y3.e(i) == buffExpY3[i]); + + NDArray y4 = y(2, {1}, true); + for (int i = 0; i < shape::shapeInfoLength(y4.rankOf()); ++i) + ASSERT_TRUE(y4.shapeInfo()[i] == shapeExpY4[i]); + for (int i = 0; i < y4.lengthOf(); ++i) + ASSERT_TRUE(y4.e(i) == buffExpY4[i]); + + NDArray y5 = y(3, {2}); + for (int i = 0; i < shape::shapeInfoLength(y5.rankOf()); ++i) + ASSERT_TRUE(y5.shapeInfo()[i] == shapeExpY5[i]); + for (int i = 0; i < y5.lengthOf(); ++i) + ASSERT_TRUE(y5.e(i) == buffExpY5[i]); } TEST_F(NDArrayTest2, test_subarray_interval_1) { + NDArray x('f', {10, 10}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::all(), NDIndex::interval(0, 9)}); - NDArray x('f', {10, 10}, sd::DataType::FLOAT32); - auto subArr1 = x.subarray({NDIndex::all(), NDIndex::interval(0,9)}); - - ASSERT_EQ(10, subArr1.sizeAt(0)); - ASSERT_EQ(9, subArr1.sizeAt(1)); + ASSERT_EQ(10, subArr1.sizeAt(0)); + ASSERT_EQ(9, subArr1.sizeAt(1)); } TEST_F(NDArrayTest2, test_subarray_interval_2) { + NDArray x('c', {10, 10}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::all(), NDIndex::interval(0, 9)}); - NDArray x('c', {10, 10}, sd::DataType::FLOAT32); - auto subArr1 = x.subarray({NDIndex::all(), NDIndex::interval(0,9)}); - - ASSERT_EQ(10, subArr1.sizeAt(0)); - ASSERT_EQ(9, subArr1.sizeAt(1)); + ASSERT_EQ(10, subArr1.sizeAt(0)); + ASSERT_EQ(9, subArr1.sizeAt(1)); } TEST_F(NDArrayTest2, test_subarray_3d_cf) { - NDArray f('f', {10, 20, 30}, sd::DataType::FLOAT32); - NDArray c('c', {10, 20, 30}, sd::DataType::FLOAT32); + NDArray f('f', {10, 20, 30}, sd::DataType::FLOAT32); + NDArray c('c', {10, 20, 30}, sd::DataType::FLOAT32); - auto subarrayF = f({0,0, 0,0, 2,3}, true); + auto subarrayF = f({0, 0, 0, 0, 2, 3}, true); - auto subarrayC = c({2,3, 0,0, 0,0}, true); + auto subarrayC = c({2, 3, 0, 0, 0, 0}, true); } TEST_F(NDArrayTest2, test_broadcast_row_1) { - auto x = NDArrayFactory::create('c', {10, 5}); - auto y = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto e = NDArrayFactory::create('c', {10, 5}); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {10, 5}); + auto y = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {10, 5}); + e.assign(1.0f); - x += y; + x += y; - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(NDArrayTest2, test_broadcast_column_1) { - auto x = NDArrayFactory::create('c', {5, 10}); - auto y = NDArrayFactory::create('c', {5, 1}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto e = NDArrayFactory::create('c', {5, 10}); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {5, 10}); + auto y = + NDArrayFactory::create('c', {5, 1}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {5, 10}); + e.assign(1.0f); - x += y; + x += y; - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(NDArrayTest2, test_broadcast_column_2) { - auto x = NDArrayFactory::create('c', {5, 10}); - auto y = NDArrayFactory::create('c', {5, 1}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto e = NDArrayFactory::create('c', {5, 10}); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {5, 10}); + auto y = + NDArrayFactory::create('c', {5, 1}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {5, 10}); + e.assign(1.0f); - x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x, false); + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x, false); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(NDArrayTest2, test_broadcast_column_3) { - auto x = NDArrayFactory::create('c', {5, 10}); - auto y = NDArrayFactory::create('c', {5, 1}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto e = NDArrayFactory::create('c', {5, 10}); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {5, 10}); + auto y = + NDArrayFactory::create('c', {5, 1}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {5, 10}); + e.assign(1.0f); - x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x); + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(NDArrayTest2, test_broadcast_column_4) { - auto x = NDArrayFactory::create('f', {10, 5}); - auto y = NDArrayFactory::create('f', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - auto e = NDArrayFactory::create('f', {10, 5}); - e.assign(1.0f); + auto x = NDArrayFactory::create('f', {10, 5}); + auto y = NDArrayFactory::create('f', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('f', {10, 5}); + e.assign(1.0f); - x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x); + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x); - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(NDArrayTest2, test_not_tiled_1) { - auto x = NDArrayFactory::create('c', {4, 12, 128, 128}); - auto y = NDArrayFactory::create('c', {4, 1, 128, 128}); - auto e = NDArrayFactory::create('c', {4, 12, 128, 128}); - y.assign(1.0f); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {4, 12, 128, 128}); + auto y = NDArrayFactory::create('c', {4, 1, 128, 128}); + auto e = NDArrayFactory::create('c', {4, 12, 128, 128}); + y.assign(1.0f); + e.assign(1.0f); - x += y; + x += y; - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(NDArrayTest2, test_not_tiled_2) { - auto x = NDArrayFactory::create('c', {4, 128, 768}); - auto y = NDArrayFactory::create('c', {4, 128, 1}); - auto e = NDArrayFactory::create('c', {4, 128, 768}); - y.assign(1.0f); - e.assign(1.0f); + auto x = NDArrayFactory::create('c', {4, 128, 768}); + auto y = NDArrayFactory::create('c', {4, 128, 1}); + auto e = NDArrayFactory::create('c', {4, 128, 768}); + y.assign(1.0f); + e.assign(1.0f); - x += y; + x += y; - ASSERT_EQ(e, x); + ASSERT_EQ(e, x); } TEST_F(NDArrayTest2, test_long_sum_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto z = x.reduceAlongDimension(reduce::Sum, {0}); + auto z = x.reduceAlongDimension(reduce::Sum, {0}); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, reshapei_1) { + Nd4jLong shapeInfo1[] = {6, 2, 1, 2, 1, 7, 1, 7, + 7, 14, 28, 1, 1, 8192, 0, 99}; + Nd4jLong shapeInfo2[] = {2, 4, 7, 7, 1, 8192, 1, 99}; - Nd4jLong shapeInfo1[] = {6, 2,1,2,1,7,1, 7,7,14,28,1,1, 8192, 0, 99}; - Nd4jLong shapeInfo2[] = {2, 4, 7, 7, 1, 8192, 1, 99}; + auto buffer = new float[shape::length(shapeInfo1)]; + NDArray x(buffer, shapeInfo1); - auto buffer = new float[shape::length(shapeInfo1)]; - NDArray x(buffer, shapeInfo1); + const bool canReshape = x.reshapei({4, 7}); - const bool canReshape = x.reshapei({4,7}); + ASSERT_FALSE(canReshape); + ASSERT_TRUE(shape::equalsStrict(x.shapeInfo(), shapeInfo2)); - ASSERT_FALSE(canReshape); - ASSERT_TRUE(shape::equalsStrict(x.shapeInfo(), shapeInfo2)); - - delete[] buffer; + delete[] buffer; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, reshapei_2) { + Nd4jLong shapeInfo1[] = {6, 1, 2, 1, 2, 7, 1, 28, + 7, 7, 14, 1, 1, 8192, 0, 99}; + Nd4jLong shapeInfo2[] = {2, 4, 7, 7, 1, 8192, 1, 99}; - Nd4jLong shapeInfo1[] = {6, 1,2,1,2,7,1, 28,7,7,14,1,1, 8192, 0, 99}; - Nd4jLong shapeInfo2[] = {2, 4, 7, 7, 1, 8192, 1, 99}; - - auto buffer = new float[shape::length(shapeInfo1)]; - NDArray x(buffer, shapeInfo1); + auto buffer = new float[shape::length(shapeInfo1)]; + NDArray x(buffer, shapeInfo1); - const bool canReshape = x.reshapei({4,7}); + const bool canReshape = x.reshapei({4, 7}); - ASSERT_FALSE(canReshape); - ASSERT_TRUE(shape::equalsStrict(x.shapeInfo(), shapeInfo2)); + ASSERT_FALSE(canReshape); + ASSERT_TRUE(shape::equalsStrict(x.shapeInfo(), shapeInfo2)); - delete[] buffer; + delete[] buffer; } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, trueBroadcast_1) { + NDArray x('f', {2, 3}, {1., 2., 3., 4., 5., 6.}); + NDArray y('f', {1, 3}, {5., 4., 3.}); + NDArray z('c', {2, 3}, sd::DataType::DOUBLE); - NDArray x('f', {2, 3}, {1., 2., 3., 4., 5., 6.}); - NDArray y('f', {1, 3}, {5., 4., 3.}); - NDArray z('c', {2, 3}, sd::DataType::DOUBLE); + auto exp = x - y; + x.applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), y, z); - auto exp = x - y; - x.applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), y, z); + // exp.printIndexedBuffer(); + // z.printIndexedBuffer(); - // exp.printIndexedBuffer(); - // z.printIndexedBuffer(); - - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, reduce_1) { + NDArray arr6('f', {1, 1, 4, 4, 4, 4}, sd::DataType::DOUBLE); + NDArray exp('f', {1, 1, 4, 4}, sd::DataType::DOUBLE); - NDArray arr6('f', {1, 1, 4, 4, 4, 4}, sd::DataType::DOUBLE); - NDArray exp('f', {1, 1, 4, 4}, sd::DataType::DOUBLE); - - arr6.linspace(1); + arr6.linspace(1); - NDArray arr6s = arr6.reduceAlongDimension(sd::reduce::Sum, {2,3}); + NDArray arr6s = arr6.reduceAlongDimension(sd::reduce::Sum, {2, 3}); - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - double sum = 0; - for (int x = 0; x < 4; x++) { - for (int y = 0; y < 4; y++) { - Nd4jLong indices[] = {0, 0, x, y, i, j}; - Nd4jLong offset = shape::getOffset(arr6.shapeInfo(), indices); - sum += ((double*)arr6.buffer())[offset]; - } - } - exp.p(0, 0, i, j, sum); + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + double sum = 0; + for (int x = 0; x < 4; x++) { + for (int y = 0; y < 4; y++) { + Nd4jLong indices[] = {0, 0, x, y, i, j}; + Nd4jLong offset = shape::getOffset(arr6.shapeInfo(), indices); + sum += ((double *)arr6.buffer())[offset]; } + } + exp.p(0, 0, i, j, sum); } + } - // arr6s->printShapeInfo(); - // exp.printShapeInfo(); - // exp.printIndexedBuffer(); - // arr6s->printIndexedBuffer(); + // arr6s->printShapeInfo(); + // exp.printShapeInfo(); + // exp.printIndexedBuffer(); + // arr6s->printIndexedBuffer(); - ASSERT_TRUE(exp.equalsTo(arr6s)); + ASSERT_TRUE(exp.equalsTo(arr6s)); } ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest2, reduce3_1) { + NDArray x('c', {1, 4}, {1, 2, 3, 4}); + NDArray y('c', {1, 4}, {2, 3, 4, 5}); + NDArray exp('c', {4}, {1, 1, 1, 1}); - NDArray x('c', {1,4}, {1,2,3,4}); - NDArray y('c', {1,4}, {2,3,4,5}); - NDArray exp('c', {4}, {1,1,1,1}); + NDArray z = x.applyReduce3(sd::reduce3::EuclideanDistance, y, {0}, nullptr); - NDArray z = x.applyReduce3(sd::reduce3::EuclideanDistance, y, {0}, nullptr); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NDArrayTest2, all_tads_1) { - auto x = NDArrayFactory::create('c', {3, 5}); + auto x = NDArrayFactory::create('c', {3, 5}); - auto arrays = x.allTensorsAlongDimension({1}); - ASSERT_EQ(3, arrays.size()); + auto arrays = x.allTensorsAlongDimension({1}); + ASSERT_EQ(3, arrays.size()); } TEST_F(NDArrayTest2, test_trueBroadcast_empty_1) { - auto x = NDArrayFactory::create('c', {0, 2}); - auto y = NDArrayFactory::create('c', {1, 2}); + auto x = NDArrayFactory::create('c', {0, 2}); + auto y = NDArrayFactory::create('c', {1, 2}); - auto z = x + y; + auto z = x + y; - ASSERT_EQ(x, z); + ASSERT_EQ(x, z); } TEST_F(NDArrayTest2, test_trueBroadcast_empty_2) { - auto x = NDArrayFactory::create('c', {0, 2}); - auto y = NDArrayFactory::create('c', {1, 2}); + auto x = NDArrayFactory::create('c', {0, 2}); + auto y = NDArrayFactory::create('c', {1, 2}); - auto z = y + x; + auto z = y + x; - ASSERT_EQ(x, z); + ASSERT_EQ(x, z); } TEST_F(NDArrayTest2, test_subarray_followed_by_reshape_1) { + NDArray x('c', {5, 1, 3}, sd::DataType::FLOAT32); + NDArray e('c', {1, 3}, {7.f, 8.f, 9.f}, sd::DataType::FLOAT32); - NDArray x('c', {5, 1, 3}, sd::DataType::FLOAT32); - NDArray e('c', {1, 3}, {7.f, 8.f, 9.f}, sd::DataType::FLOAT32); - - x.linspace(1.); + x.linspace(1.); - auto s = x({2,3, 0,0, 0,0}); + auto s = x({2, 3, 0, 0, 0, 0}); - // s.printIndexedBuffer("s"); + // s.printIndexedBuffer("s"); - auto r = s.reshape(x.ordering(), {1, 3}); - // r.printIndexedBuffer("r"); + auto r = s.reshape(x.ordering(), {1, 3}); + // r.printIndexedBuffer("r"); - ASSERT_EQ(e, r); + ASSERT_EQ(e, r); } TEST_F(NDArrayTest2, test_numpy_import_1) { - std::string fname("./resources/arr_3,4_float32.npy"); - auto exp = NDArrayFactory::create('c', {3, 4}); - exp.linspace(0); + std::string fname("./resources/arr_3,4_float32.npy"); + auto exp = NDArrayFactory::create('c', {3, 4}); + exp.linspace(0); - auto array = NDArrayFactory::fromNpyFile(fname.c_str()); + auto array = NDArrayFactory::fromNpyFile(fname.c_str()); - ASSERT_EQ(exp, array); + ASSERT_EQ(exp, array); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp index 3421edf953c2..36df374425f7 100644 --- a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp @@ -18,1455 +18,1431 @@ // Created by GS on 22.07.2019. // -#include "testlayers.h" #include +#include #include -#include -#include -#include -#include -#include -#include -#include -#include #include -#include +#include #include #include +#include +#include +#include +#include +#include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; class NativeOpsTests : public testing::Test { -public: - + public: }; - TEST_F(NativeOpsTests, CreateContextTests_1) { -// auto x = NDArrayFactory::create('c', {5, 5}); -// x.assign(1.0); -// auto z = NDArrayFactory::create('c', {5,5}); -// auto exp = NDArrayFactory::create('c', {5, 5}); - auto context = ::createContext(); - ASSERT_TRUE(context == nullptr); - //delete context; + // auto x = NDArrayFactory::create('c', {5, 5}); + // x.assign(1.0); + // auto z = NDArrayFactory::create('c', {5,5}); + // auto exp = NDArrayFactory::create('c', {5, 5}); + auto context = ::createContext(); + ASSERT_TRUE(context == nullptr); + // delete context; } TEST_F(NativeOpsTests, CreateContextTests_2) { -// auto x = NDArrayFactory::create('c', {5, 5}); -// x.assign(1.0); -// auto z = NDArrayFactory::create('c', {5,5}); -// auto exp = NDArrayFactory::create('c', {5, 5}); - auto context1 = ::createContext(); - auto context2 = ::createContext(); - ASSERT_TRUE(context1 == context2); - //delete context1; - //delete context2; + // auto x = NDArrayFactory::create('c', {5, 5}); + // x.assign(1.0); + // auto z = NDArrayFactory::create('c', {5,5}); + // auto exp = NDArrayFactory::create('c', {5, 5}); + auto context1 = ::createContext(); + auto context2 = ::createContext(); + ASSERT_TRUE(context1 == context2); + // delete context1; + // delete context2; } TEST_F(NativeOpsTests, PointerTests_1) { - auto x = NDArrayFactory::create('c', {5}, {1,2,3,4,5}); + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); // x.linspace(1.0); #ifdef __CUDABLAS__ -printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - ::tryPointer(nullptr, x.buffer(), 4); + ::tryPointer(nullptr, x.buffer(), 4); #endif -// auto exp = NDArrayFactory::create('c', {5, 5}); -// exp.assign(-1.0); -// -// sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg -// auto result = op.execute({&x}, {}, {}); -// -// ASSERT_EQ(1, result->size()); -// -// auto z = result->at(0); -// -// ASSERT_TRUE(exp.equalsTo(z)); -// -// delete result; + // auto exp = NDArrayFactory::create('c', {5, 5}); + // exp.assign(-1.0); + // + // sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg + // auto result = op.execute({&x}, {}, {}); + // + // ASSERT_EQ(1, result->size()); + // + // auto z = result->at(0); + // + // ASSERT_TRUE(exp.equalsTo(z)); + // + // delete result; } TEST_F(NativeOpsTests, ThresholdTests_1) { // auto x = NDArrayFactory::create('c', {5}, {1,2,3,4,5}); // x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - ::setElementThreshold(4); - ASSERT_TRUE(4 == sd::Environment::getInstance()->elementwiseThreshold()); + ::setElementThreshold(4); + ASSERT_TRUE(4 == sd::Environment::getInstance()->elementwiseThreshold()); #endif - } TEST_F(NativeOpsTests, ThresholdTests_2) { // auto x = NDArrayFactory::create('c', {5}, {1,2,3,4,5}); // x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - ::setTADThreshold(4); - ASSERT_TRUE(4 == sd::Environment::getInstance()->tadThreshold()); + ::setTADThreshold(4); + ASSERT_TRUE(4 == sd::Environment::getInstance()->tadThreshold()); #endif - } TEST_F(NativeOpsTests, ExecIndexReduce_1) { - auto x = NDArrayFactory::create('c', {5}, {1,2,3,4,5}); - auto exp = NDArrayFactory::create(120); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto exp = NDArrayFactory::create(120); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execIndexReduceScalar(nullptr, - indexreduce::IndexMax, - &xBuf, x.shapeInfo(), - nullptr, - nullptr, - &expBuf, exp.shapeInfo(), - nullptr); - - ASSERT_TRUE(exp.e(0) == 4LL); -#endif + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execIndexReduceScalar(nullptr, indexreduce::IndexMax, &xBuf, x.shapeInfo(), + nullptr, nullptr, &expBuf, exp.shapeInfo(), nullptr); + ASSERT_TRUE(exp.e(0) == 4LL); +#endif } TEST_F(NativeOpsTests, ExecIndexReduce_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - NDArray dimension = NDArrayFactory::create({}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimensionBuf(dimension.dataBuffer()); - - ::execIndexReduce(nullptr, - indexreduce::IndexMax, - &xBuf, x.shapeInfo(), nullptr, - nullptr, - &expBuf, exp.shapeInfo(), - nullptr, - &dimensionBuf, dimension.shapeInfo(), - nullptr); - - ASSERT_TRUE(exp.e(0) == 24LL); -#endif + NDArray dimension = NDArrayFactory::create({}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimensionBuf(dimension.dataBuffer()); + ::execIndexReduce(nullptr, indexreduce::IndexMax, &xBuf, x.shapeInfo(), + nullptr, nullptr, &expBuf, exp.shapeInfo(), nullptr, + &dimensionBuf, dimension.shapeInfo(), nullptr); + + ASSERT_TRUE(exp.e(0) == 24LL); +#endif } TEST_F(NativeOpsTests, ExecBroadcast_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 1}); - auto exp = NDArrayFactory::create('c', {5, 5}); - x.linspace(1.0); - y.linspace(2,2); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 1}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1.0); + y.linspace(2, 2); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - auto dimension = NDArrayFactory::create('c', {1}, {1}); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - - ::execBroadcast(nullptr, - broadcast::Add, - &xBuf, x.shapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), - nullptr, - &expBuf, exp.shapeInfo(), - nullptr, - &dimBuf, dimension.shapeInfo(), - nullptr); - - ASSERT_TRUE(exp.e(0) == 3.); -#endif + auto dimension = NDArrayFactory::create('c', {1}, {1}); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + ::execBroadcast(nullptr, broadcast::Add, &xBuf, x.shapeInfo(), nullptr, &yBuf, + y.shapeInfo(), nullptr, &expBuf, exp.shapeInfo(), nullptr, + &dimBuf, dimension.shapeInfo(), nullptr); + + ASSERT_TRUE(exp.e(0) == 3.); +#endif } TEST_F(NativeOpsTests, ExecBroadcast_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 1}); - auto exp = NDArrayFactory::create('c', {5, 5}); - x.linspace(1.0); - y.linspace(2,2); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 1}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1.0); + y.linspace(2, 2); #ifdef __CUDABLAS__ -printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - int dimd = 0; - auto dimension = NDArrayFactory::create('c', {1}, {dimd}); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - - ::execBroadcastBool(nullptr, - broadcast::EqualTo, - &xBuf, x.shapeInfo(), nullptr, - &yBuf, y.shapeInfo(), nullptr, - &expBuf, exp.shapeInfo(), nullptr, nullptr, - &dimBuf, dimension.shapeInfo(), - nullptr); - ASSERT_TRUE(exp.e(1) && !exp.e(0)); + int dimd = 0; + auto dimension = NDArrayFactory::create('c', {1}, {dimd}); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execBroadcastBool(nullptr, broadcast::EqualTo, &xBuf, x.shapeInfo(), + nullptr, &yBuf, y.shapeInfo(), nullptr, &expBuf, + exp.shapeInfo(), nullptr, nullptr, &dimBuf, + dimension.shapeInfo(), nullptr); + ASSERT_TRUE(exp.e(1) && !exp.e(0)); #endif - } TEST_F(NativeOpsTests, ExecPairwise_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - x.linspace(1.0); - y.assign(2.); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1.0); + y.assign(2.); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execPairwiseTransform(nullptr, - pairwise::Add, - &xBuf, x.shapeInfo(), nullptr, - &yBuf, y.shapeInfo(), nullptr, - &expBuf, exp.shapeInfo(), nullptr, - nullptr); - ASSERT_TRUE(exp.e(5) == 8.); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execPairwiseTransform(nullptr, pairwise::Add, &xBuf, x.shapeInfo(), nullptr, + &yBuf, y.shapeInfo(), nullptr, &expBuf, + exp.shapeInfo(), nullptr, nullptr); + ASSERT_TRUE(exp.e(5) == 8.); #endif - } TEST_F(NativeOpsTests, ExecPairwise_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - x.assign(true); - y.assign(false); - y.t(5) = true; + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.assign(true); + y.assign(false); + y.t(5) = true; #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execPairwiseTransformBool(nullptr, - pairwise::And, - &xBuf, x.shapeInfo(), nullptr, - &yBuf, y.shapeInfo(), nullptr, - &expBuf, exp.shapeInfo(), nullptr, - nullptr); - ASSERT_TRUE(exp.e(5) && !exp.e(4)); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execPairwiseTransformBool(nullptr, pairwise::And, &xBuf, x.shapeInfo(), + nullptr, &yBuf, y.shapeInfo(), nullptr, &expBuf, + exp.shapeInfo(), nullptr, nullptr); + ASSERT_TRUE(exp.e(5) && !exp.e(4)); #endif - } TEST_F(NativeOpsTests, ReduceTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120.); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - auto dimension = NDArrayFactory::create('c', {1}, {1}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execReduceFloat(nullptr, - reduce::Mean, - &xBuf, x.shapeInfo(), nullptr, - nullptr, - &expBuf, exp.shapeInfo(), nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce Mean"); - ASSERT_TRUE(exp.e(0) == 13.); + auto dimension = NDArrayFactory::create('c', {1}, {1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceFloat(nullptr, reduce::Mean, &xBuf, x.shapeInfo(), nullptr, + nullptr, &expBuf, exp.shapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce Mean"); + ASSERT_TRUE(exp.e(0) == 13.); #endif - } TEST_F(NativeOpsTests, ReduceTest_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120.); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execReduceSame(nullptr, - reduce::Sum, - &xBuf, x.shapeInfo(), nullptr, - nullptr, - &expBuf, exp.shapeInfo(), nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce Sum"); - ASSERT_TRUE(exp.e(0) == 325.); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceSame(nullptr, reduce::Sum, &xBuf, x.shapeInfo(), nullptr, nullptr, + &expBuf, exp.shapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce Sum"); + ASSERT_TRUE(exp.e(0) == 325.); #endif - } TEST_F(NativeOpsTests, ReduceTest_3) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(false); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(false); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execReduceBool(nullptr, - reduce::All, - &xBuf, x.shapeInfo(), nullptr, - nullptr, - &expBuf, exp.shapeInfo(), nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce All"); - ASSERT_TRUE(exp.e(0) == true); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceBool(nullptr, reduce::All, &xBuf, x.shapeInfo(), nullptr, nullptr, + &expBuf, exp.shapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.e(0) == true); #endif - } TEST_F(NativeOpsTests, ReduceTest_4) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120LL); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120LL); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execReduceLong(nullptr, - reduce::CountNonZero, - &xBuf, x.shapeInfo(), nullptr, - nullptr, - &expBuf, exp.shapeInfo(), nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce CountNonZero"); - ASSERT_TRUE(exp.e(0) == 25LL); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceLong(nullptr, reduce::CountNonZero, &xBuf, x.shapeInfo(), nullptr, + nullptr, &expBuf, exp.shapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce CountNonZero"); + ASSERT_TRUE(exp.e(0) == 25LL); #endif - } TEST_F(NativeOpsTests, ReduceTest_5) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120LL); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120LL); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - auto dimension = NDArrayFactory::create({0, 1}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - - ::execReduceLong2(nullptr, - reduce::CountNonZero, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce CountNonZero"); - ASSERT_TRUE(exp.e(0) == 25LL); + auto dimension = NDArrayFactory::create({0, 1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execReduceLong2(nullptr, reduce::CountNonZero, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce CountNonZero"); + ASSERT_TRUE(exp.e(0) == 25LL); #endif - } TEST_F(NativeOpsTests, ReduceTest_6) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create({5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create({1,2,3,4,6}); - x.linspace(1.0); + auto x = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create({5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create({1, 2, 3, 4, 6}); + x.linspace(1.0); #ifdef __CUDABLAS__ - printf("Unsupported for cuda now.\n"); + printf("Unsupported for cuda now.\n"); #else - auto dimension = NDArrayFactory::create('c', {1}, {1}); - x.p(5, 0); - x.p(10, 0); x.p(11, 0); - x.p(15, 0); x.p(16, 0); x.p(17, 0); - x.p(20, 0); x.p(21, 0); x.p(22, 0); x.p(23, 0); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execReduceLong2(nullptr, - reduce::CountNonZero, - &xBuf, x.shapeInfo(), nullptr, - nullptr, - &expBuf, exp.shapeInfo(), nullptr, - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce CountNonZero"); - ASSERT_TRUE(exp.equalsTo(z)); + auto dimension = NDArrayFactory::create('c', {1}, {1}); + x.p(5, 0); + x.p(10, 0); + x.p(11, 0); + x.p(15, 0); + x.p(16, 0); + x.p(17, 0); + x.p(20, 0); + x.p(21, 0); + x.p(22, 0); + x.p(23, 0); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceLong2(nullptr, reduce::CountNonZero, &xBuf, x.shapeInfo(), + nullptr, nullptr, &expBuf, exp.shapeInfo(), nullptr, + &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce CountNonZero"); + ASSERT_TRUE(exp.equalsTo(z)); #endif - } TEST_F(NativeOpsTests, ReduceTest_7) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120.); - auto z = NDArrayFactory::create(13.); - + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + auto z = NDArrayFactory::create(13.); - auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); - Nd4jPointer extra[6]; + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - x.syncToHost(); - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; #endif - x.linspace(1.0); - x.syncToDevice(); - dimension.syncToHost(); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execReduceFloat2(extra, - reduce::Mean, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce Mean"); - ASSERT_TRUE(exp.equalsTo(z)); - + x.linspace(1.0); + x.syncToDevice(); + dimension.syncToHost(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceFloat2(extra, reduce::Mean, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce Mean"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, ReduceTest_8) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create(120.); - auto exp = NDArrayFactory::create(325.); - + auto x = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create(120.); + auto exp = NDArrayFactory::create(325.); - auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); - Nd4jPointer extra[6]; + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); #endif - x.linspace(1.0); - x.syncToDevice(); + x.linspace(1.0); + x.syncToDevice(); - dimension.syncToHost(); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - - ::execReduceSame2(extra, - reduce::Sum, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce Sum"); - ASSERT_TRUE(exp.equalsTo(z)); + dimension.syncToHost(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + ::execReduceSame2(extra, reduce::Sum, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce Sum"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, ReduceTest_9) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(false); - auto z = NDArrayFactory::create(true); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(false); + auto z = NDArrayFactory::create(true); - auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); - Nd4jPointer extra[6]; + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); #endif - x.linspace(1.0); - x.syncToDevice(); + x.linspace(1.0); + x.syncToDevice(); - dimension.syncToHost(); + dimension.syncToHost(); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execReduceBool2(extra, - reduce::All, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce All"); - ASSERT_TRUE(exp.equalsTo(z)); + ::execReduceBool2(extra, reduce::All, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, Reduce3Test_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120.); - auto z = NDArrayFactory::create(650.); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + auto z = NDArrayFactory::create(650.); - auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); - Nd4jPointer extra[6]; + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - y.assign(2.); - x.syncToDevice(); + x.linspace(1.0); + y.assign(2.); + x.syncToDevice(); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execReduce3(extra, - reduce3::Dot, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo()); - //z.printIndexedBuffer("Z"); - //exp.printIndexedBuffer("Reduce3 Dot"); - ASSERT_TRUE(exp.equalsTo(z)); + ::execReduce3(extra, reduce3::Dot, &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(), &expBuf, + exp.shapeInfo(), exp.specialShapeInfo()); + // z.printIndexedBuffer("Z"); + // exp.printIndexedBuffer("Reduce3 Dot"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, Reduce3Test_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120.); - auto z = NDArrayFactory::create(650.); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + auto z = NDArrayFactory::create(650.); - auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); - Nd4jPointer extra[6]; + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - y.assign(2.); - x.syncToDevice(); + x.linspace(1.0); + y.assign(2.); + x.syncToDevice(); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execReduce3Scalar(extra, - reduce3::Dot, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce3 Dot"); - ASSERT_TRUE(exp.equalsTo(z)); + ::execReduce3Scalar(extra, reduce3::Dot, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &yBuf, y.shapeInfo(), + y.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce3 Dot"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, Reduce3Test_3) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120.); - auto z = NDArrayFactory::create(650.); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + auto z = NDArrayFactory::create(650.); - auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); - Nd4jPointer extra[6]; + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - y.assign(2.); - x.syncToDevice(); - dimension.syncToHost(); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - - ::execReduce3Tad(extra, - reduce3::Dot, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(), - nullptr, nullptr, nullptr, nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce All"); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.0); + y.assign(2.); + x.syncToDevice(); + dimension.syncToHost(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execReduce3Tad( + extra, reduce3::Dot, &xBuf, x.shapeInfo(), x.specialShapeInfo(), nullptr, + &yBuf, y.shapeInfo(), y.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo(), nullptr, nullptr, nullptr, nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, Reduce3Test_4) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create(120.); - auto z = NDArrayFactory::create(650.); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + auto z = NDArrayFactory::create(650.); - auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); - Nd4jPointer extra[6]; + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - y.assign(2.); - x.syncToDevice(); - dimension.syncToHost(); - int* dimensions = reinterpret_cast(dimension.buffer()); - auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); - auto tadPackY = sd::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), dimensions, dimension.lengthOf()); - - auto hTADShapeInfoX = tadPackX.primaryShapeInfo(); - auto hTADOffsetsX = tadPackX.primaryOffsets(); - auto hTADShapeInfoY = tadPackY.primaryShapeInfo(); - auto hTADOffsetsY = tadPackY.primaryOffsets(); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - - ::execReduce3All(extra, - reduce3::Dot, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(), - hTADShapeInfoX, hTADOffsetsX, hTADShapeInfoY, hTADOffsetsY); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce All"); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.0); + y.assign(2.); + x.syncToDevice(); + dimension.syncToHost(); + int *dimensions = reinterpret_cast(dimension.buffer()); + auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x.shapeInfo(), dimensions, dimension.lengthOf()); + auto tadPackY = sd::ConstantTadHelper::getInstance()->tadForDimensions( + y.shapeInfo(), dimensions, dimension.lengthOf()); + + auto hTADShapeInfoX = tadPackX.primaryShapeInfo(); + auto hTADOffsetsX = tadPackX.primaryOffsets(); + auto hTADShapeInfoY = tadPackY.primaryShapeInfo(); + auto hTADOffsetsY = tadPackY.primaryOffsets(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execReduce3All(extra, reduce3::Dot, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), nullptr, &yBuf, y.shapeInfo(), + y.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo(), hTADShapeInfoX, hTADOffsetsX, + hTADShapeInfoY, hTADOffsetsY); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, ScalarTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create(10.); - auto exp = NDArrayFactory::create('c', {5,5}); - auto z = NDArrayFactory::create('c', {5,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create(10.); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); - Nd4jPointer extra[6]; + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - y.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + y.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - z.linspace(10., 10.); - //y.assign(2.); - x.syncToDevice(); - z.syncToDevice(); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execScalar(extra, - scalar::Multiply, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce All"); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.0); + z.linspace(10., 10.); + // y.assign(2.); + x.syncToDevice(); + z.syncToDevice(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execScalar(extra, scalar::Multiply, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &yBuf, y.shapeInfo(), + y.specialShapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, ScalarTest_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create(10.f); - auto exp = NDArrayFactory::create('c', {5,5}); - auto z = NDArrayFactory::create('c', {5,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create(10.f); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); - Nd4jPointer extra[6]; + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - y.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + y.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - z.assign(false); - //y.assign(2.); - x.syncToDevice(); - z.syncToDevice(); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - - ::execScalarBool(extra, - scalar::GreaterThan, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce All"); - ASSERT_TRUE(exp.e(5) == z.e(5) && exp.e(15) != z.e(15)); + x.linspace(1.0); + z.assign(false); + // y.assign(2.); + x.syncToDevice(); + z.syncToDevice(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execScalarBool(extra, scalar::GreaterThan, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &yBuf, y.shapeInfo(), + y.specialShapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.e(5) == z.e(5) && + exp.e(15) != z.e(15)); } TEST_F(NativeOpsTests, SummaryStatsScalarTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}, {0.1f, 0.2f, 0.3f, -0.3f, -0.5f, 0.5f, 0.7f, 0.9f, 0.8f, 0.1f, 0.11f, 0.12f, 0.5f, -0.8f, -0.9f, 0.4f, 0.1f, 0.2f, 0.3f, -0.3f, -0.5f, 0.2f, 0.3f, -0.3f, -0.5f}); - auto exp = NDArrayFactory::create(0.9f); - auto z = NDArrayFactory::create(0.21587136f); - - Nd4jPointer extra[6]; + auto x = NDArrayFactory::create( + 'c', {5, 5}, {0.1f, 0.2f, 0.3f, -0.3f, -0.5f, 0.5f, 0.7f, 0.9f, 0.8f, + 0.1f, 0.11f, 0.12f, 0.5f, -0.8f, -0.9f, 0.4f, 0.1f, 0.2f, + 0.3f, -0.3f, -0.5f, 0.2f, 0.3f, -0.3f, -0.5f}); + auto exp = NDArrayFactory::create(0.9f); + auto z = NDArrayFactory::create(0.21587136f); + + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execSummaryStatsScalar(extra, - variance::SummaryStatsVariance, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Standard Variance"); - ASSERT_TRUE(exp.equalsTo(z)); + ::execSummaryStatsScalar(extra, variance::SummaryStatsVariance, &xBuf, + x.shapeInfo(), x.specialShapeInfo(), nullptr, + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + false); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Standard Variance"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, SummaryStatsScalarTest_2) { - auto x = NDArrayFactory::create('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5}); - auto exp = NDArrayFactory::create(0.9); - auto z = NDArrayFactory::create(0.21587136); - - Nd4jPointer extra[6]; + auto x = NDArrayFactory::create( + 'c', {5, 5}, + {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, + -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5}); + auto exp = NDArrayFactory::create(0.9); + auto z = NDArrayFactory::create(0.21587136); + + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execSummaryStats(extra, - variance::SummaryStatsVariance, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Standard Variance"); - ASSERT_TRUE(exp.equalsTo(z)); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execSummaryStats(extra, variance::SummaryStatsVariance, &xBuf, + x.shapeInfo(), x.specialShapeInfo(), nullptr, &expBuf, + exp.shapeInfo(), exp.specialShapeInfo(), false); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Standard Variance"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, SummaryStatsScalarTest_3) { - auto x = NDArrayFactory::create('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5}); - auto exp = NDArrayFactory::create(0.9); - auto z = NDArrayFactory::create(0.21587136); - - Nd4jPointer extra[6]; + auto x = NDArrayFactory::create( + 'c', {5, 5}, + {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, + -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5}); + auto exp = NDArrayFactory::create(0.9); + auto z = NDArrayFactory::create(0.21587136); + + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - auto dimensions = NDArrayFactory::create({0, 1}); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimensions.dataBuffer()); - - ::execSummaryStatsTad(extra, - variance::SummaryStatsVariance, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - nullptr, - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &dimBuf, dimensions.shapeInfo(), dimensions.specialShapeInfo(), - false, - nullptr, nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Standard Variance"); - ASSERT_TRUE(exp.equalsTo(z)); + auto dimensions = NDArrayFactory::create({0, 1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimensions.dataBuffer()); + + ::execSummaryStatsTad(extra, variance::SummaryStatsVariance, &xBuf, + x.shapeInfo(), x.specialShapeInfo(), nullptr, &expBuf, + exp.shapeInfo(), exp.specialShapeInfo(), &dimBuf, + dimensions.shapeInfo(), dimensions.specialShapeInfo(), + false, nullptr, nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Standard Variance"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, TransformTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}, {1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625}); - auto exp = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create('c', {5,5}); - - Nd4jPointer extra[6]; + auto x = NDArrayFactory::create( + 'c', {5, 5}, + {1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, + 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); + + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - z.linspace(1.); + z.linspace(1.); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execTransformFloat(extra, - transform::Sqrt, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Sqrt is"); - ASSERT_TRUE(exp.equalsTo(z)); + ::execTransformFloat(extra, transform::Sqrt, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Sqrt is"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, TransformTest_2) { - auto x = NDArrayFactory::create('c', {5, 5}, {1.f, 4.f, 9.f, 16.f, 25.f, 36.f, 49.f, 64.f, 81.f, 100.f, 121.f, 144.f, 169.f, 196.f, 225.f, 256.f, 289.f, 324.f, 361.f, 400.f, 441.f, 484.f, 529.f, 576.f, 625.f}); - auto exp = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create('c', {5,5}); - - Nd4jPointer extra[6]; + auto x = NDArrayFactory::create( + 'c', {5, 5}, + {1.f, 4.f, 9.f, 16.f, 25.f, 36.f, 49.f, 64.f, 81.f, + 100.f, 121.f, 144.f, 169.f, 196.f, 225.f, 256.f, 289.f, 324.f, + 361.f, 400.f, 441.f, 484.f, 529.f, 576.f, 625.f}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); + + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - z.linspace(1.); + z.linspace(1.); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execTransformSame(extra, - transform::Square, - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Square is"); - ASSERT_TRUE(exp.equalsTo(x)); + ::execTransformSame(extra, transform::Square, &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Square is"); + ASSERT_TRUE(exp.equalsTo(x)); } TEST_F(NativeOpsTests, TransformTest_3) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create('c', {5,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); - Nd4jPointer extra[6]; + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.); - z.assign(true); - x.p(24, -25); - z.p(24, false); + x.linspace(1.); + z.assign(true); + x.p(24, -25); + z.p(24, false); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execTransformBool(extra, - transform::IsPositive, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("IsPositive"); - ASSERT_TRUE(exp.equalsTo(z)); + ::execTransformBool(extra, transform::IsPositive, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("IsPositive"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, TransformTest_4) { - auto x = NDArrayFactory::create('c', {5, 5}, {0, 1, 2, 3, 2, 1, 0, 1.57, 1.57, 1.57, 3.141592, 3.141592, - 3.141592, 0, 0, 0, 0, 1, 1, 2, 2, 2, 1, 0, 0}); - auto exp = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create('c', {5,5}, {1., 0.540302, -0.416147, -0.989992, -0.416147, 0.540302, 1.0, - 0.000796, 0.000796, 0.000796, -1, -1, -1, 1., 1., 1.0, 1.0, - 0.540302, 0.540302, -0.416147, -0.416147, -0.416147, 0.540302, 1., 1.}); - - Nd4jPointer extra[6]; + auto x = NDArrayFactory::create( + 'c', {5, 5}, + {0, 1, 2, 3, 2, 1, 0, 1.57, 1.57, 1.57, 3.141592, 3.141592, 3.141592, + 0, 0, 0, 0, 1, 1, 2, 2, 2, 1, 0, 0}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create( + 'c', {5, 5}, {1., 0.540302, -0.416147, -0.989992, -0.416147, + 0.540302, 1.0, 0.000796, 0.000796, 0.000796, + -1, -1, -1, 1., 1., + 1.0, 1.0, 0.540302, 0.540302, -0.416147, + -0.416147, -0.416147, 0.540302, 1., 1.}); + + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - //z.linspace(1.); - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); + // z.linspace(1.); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); - ::execTransformStrict(extra, - transform::Cosine, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - nullptr); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Cosine"); - ASSERT_TRUE(exp.equalsTo(z)); + ::execTransformStrict(extra, transform::Cosine, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), nullptr); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Cosine"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, ScalarTadTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create(10.f); - auto exp = NDArrayFactory::create('c', {5,5}); - auto z = NDArrayFactory::create('c', {5,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create(10.f); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); - Nd4jPointer extra[6]; + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - y.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + y.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - z.linspace(10., 10.); - //y.assign(2.); - x.syncToDevice(); - z.syncToDevice(); - auto dimension = NDArrayFactory::create({0, 1}); - auto dimensions = reinterpret_cast(dimension.buffer()); - auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); - auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - - ::execScalarTad(extra, - scalar::Multiply, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), - &yBuf, y.shapeInfo(), y.specialShapeInfo(), - nullptr, - &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(), - tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("Reduce All"); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.0); + z.linspace(10., 10.); + // y.assign(2.); + x.syncToDevice(); + z.syncToDevice(); + auto dimension = NDArrayFactory::create({0, 1}); + auto dimensions = reinterpret_cast(dimension.buffer()); + auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x.shapeInfo(), dimensions, dimension.lengthOf()); + auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + z.shapeInfo(), dimensions, dimension.lengthOf()); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execScalarTad(extra, scalar::Multiply, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &yBuf, y.shapeInfo(), + y.specialShapeInfo(), nullptr, &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo(), tadPackX.primaryShapeInfo(), + tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), + tadPackZ.primaryOffsets()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, ScalarTadTest_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create(true); - auto exp = NDArrayFactory::create('c', {5,5}); - auto z = NDArrayFactory::create('c', {5, 5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create(true); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); - Nd4jPointer extra[6]; + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - y.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + y.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.assign(false); - x.p(5, true); - x.p(15, true); - //z.linspace(10., 10.); - //y.assign(2.); - x.syncToDevice(); - z.syncToDevice(); - auto dimension = NDArrayFactory::create({0, 1}); - auto dimensions = reinterpret_cast(dimension.buffer()); - auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); - auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); - z.assign(true); - - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer expBuf(exp.dataBuffer()); - OpaqueDataBuffer dimBuf(dimension.dataBuffer()); - - ::execScalarBoolTad(extra, - scalar::And, - &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &expBuf, exp.shapeInfo(), - exp.specialShapeInfo(), - &yBuf, y.shapeInfo(), - y.specialShapeInfo(), - nullptr, - &dimBuf, dimension.shapeInfo(), - dimension.specialShapeInfo(), - tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets()); -// x.printIndexedBuffer("Input"); -// exp.printIndexedBuffer("And"); - ASSERT_TRUE(exp.e(5) == z.e(5) && exp.e(15)); + x.assign(false); + x.p(5, true); + x.p(15, true); + // z.linspace(10., 10.); + // y.assign(2.); + x.syncToDevice(); + z.syncToDevice(); + auto dimension = NDArrayFactory::create({0, 1}); + auto dimensions = reinterpret_cast(dimension.buffer()); + auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x.shapeInfo(), dimensions, dimension.lengthOf()); + auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + z.shapeInfo(), dimensions, dimension.lengthOf()); + z.assign(true); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execScalarBoolTad(extra, scalar::And, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), &yBuf, y.shapeInfo(), + y.specialShapeInfo(), nullptr, &dimBuf, + dimension.shapeInfo(), dimension.specialShapeInfo(), + tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), + tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets()); + // x.printIndexedBuffer("Input"); + // exp.printIndexedBuffer("And"); + ASSERT_TRUE(exp.e(5) == z.e(5) && exp.e(15)); } TEST_F(NativeOpsTests, ConcatTest_2) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {10,5}); - auto z = NDArrayFactory::create('c', {10,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {10, 5}); + auto z = NDArrayFactory::create('c', {10, 5}); - Nd4jPointer extra[6]; + Nd4jPointer extra[6]; #ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - y.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + y.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; #endif - x.linspace(1.0); - y.linspace(26); - - //y.assign(2.); - x.syncToDevice(); - z.syncToDevice(); - int d = 0; - auto dimension = NDArrayFactory::create('c', {1}, {d}); - auto dimensions = reinterpret_cast(dimension.buffer()); - //auto tadPackX = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); - auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); - exp.linspace(1); - Nd4jPointer datas[] = {x.buffer(), y.buffer()}; - Nd4jPointer shapes[] = {(Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo()}; - - ::specialConcat(extra, 0, 2, datas, shapes, z.buffer(), z.shapeInfo(), nullptr, nullptr); - -// exp.printIndexedBuffer("Exp"); -// z.printIndexedBuffer("Concat"); - ASSERT_TRUE(exp.equalsTo(z)); + x.linspace(1.0); + y.linspace(26); + + // y.assign(2.); + x.syncToDevice(); + z.syncToDevice(); + int d = 0; + auto dimension = NDArrayFactory::create('c', {1}, {d}); + auto dimensions = reinterpret_cast(dimension.buffer()); + // auto tadPackX = + // sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), + // dimensions, dimension.lengthOf()); + auto tadPackZ = sd::ConstantTadHelper::getInstance()->tadForDimensions( + z.shapeInfo(), dimensions, dimension.lengthOf()); + exp.linspace(1); + Nd4jPointer datas[] = {x.buffer(), y.buffer()}; + Nd4jPointer shapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)y.shapeInfo()}; + + ::specialConcat(extra, 0, 2, datas, shapes, z.buffer(), z.shapeInfo(), + nullptr, nullptr); + + // exp.printIndexedBuffer("Exp"); + // z.printIndexedBuffer("Concat"); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(NativeOpsTests, InitializeTest_1) { -// ::initializeDevicesAndFunctions(); + // ::initializeDevicesAndFunctions(); } TEST_F(NativeOpsTests, MallocTest_1) { - auto a = ::mallocHost(16, 0); - ::freeHost(a); - auto dA = ::mallocDevice(16, 0, 0); - ::freeDevice(dA, 0); + auto a = ::mallocHost(16, 0); + ::freeHost(a); + auto dA = ::mallocDevice(16, 0, 0); + ::freeDevice(dA, 0); } TEST_F(NativeOpsTests, OMPTest_1) { - auto maxThreads = ::ompGetMaxThreads(); - auto numThreads = ::ompGetNumThreads(); - //::setOmpMinThreads(maxThreads); - //::setOmpNumThreads(numThreads); + auto maxThreads = ::ompGetMaxThreads(); + auto numThreads = ::ompGetNumThreads(); + //::setOmpMinThreads(maxThreads); + //::setOmpNumThreads(numThreads); } TEST_F(NativeOpsTests, CreateTest_1) { - auto xx = ::createContext(); - auto yy = ::createStream(); - auto zz = ::createEvent(); - ::destroyEvent(zz); - if (xx) - delete (LaunchContext*)xx; - if (yy) - printf("Stream should be destoyed before."); - + auto xx = ::createContext(); + auto yy = ::createStream(); + auto zz = ::createEvent(); + ::destroyEvent(zz); + if (xx) delete (LaunchContext *)xx; + if (yy) printf("Stream should be destoyed before."); } TEST_F(NativeOpsTests, MemTest_1) { - auto x = NDArrayFactory::create({10, 20, 30, 40, 50}); - auto y = NDArrayFactory::create({20, 20, 20, 20, 20}); + auto x = NDArrayFactory::create({10, 20, 30, 40, 50}); + auto y = NDArrayFactory::create({20, 20, 20, 20, 20}); #ifdef __CUDABLAS__ - return ; + return; #endif - //ASSERT_TRUE(0 == ::memcpy(x.buffer(), y.buffer(), x.lengthOf() * sizeof(double), 0, nullptr)); - ASSERT_TRUE(0 == ::memcpyAsync(x.buffer(), y.buffer(), x.lengthOf() * sizeof(double), 0, nullptr)); - //ASSERT_TRUE(0 == ::memset(x.buffer(), 119, x.lengthOf() * sizeof(double), 0, nullptr)); - ASSERT_TRUE(0 == ::memsetAsync(x.buffer(), 119, x.lengthOf() * sizeof(double), 0, nullptr)); - + // ASSERT_TRUE(0 == ::memcpy(x.buffer(), y.buffer(), x.lengthOf() * + // sizeof(double), 0, nullptr)); + ASSERT_TRUE(0 == ::memcpyAsync(x.buffer(), y.buffer(), + x.lengthOf() * sizeof(double), 0, nullptr)); + // ASSERT_TRUE(0 == ::memset(x.buffer(), 119, x.lengthOf() * sizeof(double), + // 0, nullptr)); + ASSERT_TRUE(0 == ::memsetAsync(x.buffer(), 119, x.lengthOf() * sizeof(double), + 0, nullptr)); } TEST_F(NativeOpsTests, PullRowsTest_1) { - NDArray x('c', {5, 1}, {0,1,2,3,4}); - NDArray z('c', {4, 1}, sd::DataType::DOUBLE); - NDArray exp('c', {4, 1}, {0,2,3,4}); + NDArray x('c', {5, 1}, {0, 1, 2, 3, 4}); + NDArray z('c', {4, 1}, sd::DataType::DOUBLE); + NDArray exp('c', {4, 1}, {0, 2, 3, 4}); - Nd4jLong indexes[] = {0,2,3,4}; - PointersManager pm(LaunchContext::defaultContext(), "NativeOpsTests::pullRows"); - auto pidx = reinterpret_cast(pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); + Nd4jLong indexes[] = {0, 2, 3, 4}; + PointersManager pm(LaunchContext::defaultContext(), + "NativeOpsTests::pullRows"); + auto pidx = reinterpret_cast( + pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); - std::vector dims = {1}; + std::vector dims = {1}; - auto xTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dims); - auto zTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dims); + auto xTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x.shapeInfo(), dims); + auto zTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + z.shapeInfo(), dims); - Nd4jPointer nativeStart[2]; + Nd4jPointer nativeStart[2]; #ifdef __CUDABLAS__ - nativeStart[1] = (x.getContext()->getCudaStream()); + nativeStart[1] = (x.getContext()->getCudaStream()); #endif - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); - pullRows(nativeStart, &xBuf, x.shapeInfo(), x.specialShapeInfo(), - &zBuf, z.shapeInfo(), z.specialShapeInfo(), - 4, pidx, - xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), - zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); + pullRows(nativeStart, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &zBuf, + z.shapeInfo(), z.specialShapeInfo(), 4, pidx, + xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), + zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); - ASSERT_TRUE(z.equalsTo(exp)); - pm.synchronize(); + ASSERT_TRUE(z.equalsTo(exp)); + pm.synchronize(); } TEST_F(NativeOpsTests, TadPackTest_1) { - int dimension[] = {1}; - int const dimensionLength = 1; - auto x = NDArrayFactory::create('c', {2,3,4}); - sd::TadPack* pack = ::tadOnlyShapeInfo(x.shapeInfo(), - dimension, - dimensionLength); - ASSERT_TRUE(pack != nullptr); - delete pack; + int dimension[] = {1}; + int const dimensionLength = 1; + auto x = NDArrayFactory::create('c', {2, 3, 4}); + sd::TadPack *pack = + ::tadOnlyShapeInfo(x.shapeInfo(), dimension, dimensionLength); + ASSERT_TRUE(pack != nullptr); + delete pack; } TEST_F(NativeOpsTests, AverageTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5,5}); - auto z = NDArrayFactory::create('c', {5,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); #ifdef __CUDABLAS__ - return; + return; #endif - x.linspace(1); - exp.linspace(1); - Nd4jPointer xList[] = {x.buffer(), x.buffer()}; - Nd4jPointer dxList[] = {x.specialBuffer(), x.specialBuffer()}; - ::average(nullptr, - xList, x.shapeInfo(), - dxList, x.specialShapeInfo(), - z.buffer(), z.shapeInfo(), - z.specialBuffer(), z.specialShapeInfo(), - 2, - x.lengthOf(), - true); -// z.printIndexedBuffer("RES"); - ASSERT_TRUE(z.equalsTo(exp)); + x.linspace(1); + exp.linspace(1); + Nd4jPointer xList[] = {x.buffer(), x.buffer()}; + Nd4jPointer dxList[] = {x.specialBuffer(), x.specialBuffer()}; + ::average(nullptr, xList, x.shapeInfo(), dxList, x.specialShapeInfo(), + z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + 2, x.lengthOf(), true); + // z.printIndexedBuffer("RES"); + ASSERT_TRUE(z.equalsTo(exp)); } TEST_F(NativeOpsTests, AccumulateTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5,5}); - auto z = NDArrayFactory::create('c', {5,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); #ifdef __CUDABLAS__ - return; + return; #endif - x.linspace(1); - exp.linspace(2,2); - Nd4jPointer xList[] = {x.buffer(), x.buffer()}; - Nd4jPointer dxList[] = {x.specialBuffer(), x.specialBuffer()}; - ::accumulate(nullptr, - xList, x.shapeInfo(), - dxList, x.specialShapeInfo(), - z.buffer(), z.shapeInfo(), - z.specialBuffer(), z.specialShapeInfo(), - 2, - x.lengthOf()); -// z.printIndexedBuffer("RES"); - ASSERT_TRUE(z.equalsTo(exp)); + x.linspace(1); + exp.linspace(2, 2); + Nd4jPointer xList[] = {x.buffer(), x.buffer()}; + Nd4jPointer dxList[] = {x.specialBuffer(), x.specialBuffer()}; + ::accumulate(nullptr, xList, x.shapeInfo(), dxList, x.specialShapeInfo(), + z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo(), 2, x.lengthOf()); + // z.printIndexedBuffer("RES"); + ASSERT_TRUE(z.equalsTo(exp)); } TEST_F(NativeOpsTests, P2PTest_1) { - ::enableP2P(true); - ::checkP2P(); - ::isP2PAvailable(); + ::enableP2P(true); + ::checkP2P(); + ::isP2PAvailable(); } TEST_F(NativeOpsTests, ShuffleTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5,5}); - auto z = NDArrayFactory::create('c', {5,5}); + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); #ifdef __CUDABLAS__ - return; + return; #endif - x.linspace(1); - y.linspace(34); - exp.linspace(2,2); - Nd4jPointer xList[] = {x.buffer(), x.buffer()}; - Nd4jPointer dxList[] = {x.specialBuffer(), y.specialBuffer()}; - Nd4jPointer xShapeList[] = {(Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo()}; - Nd4jPointer dxShapeList[] = {(Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; - Nd4jPointer zList[] = {z.buffer(), z.buffer()}; - Nd4jPointer dzList[] = {z.specialBuffer(), z.specialBuffer()}; - Nd4jPointer zShapeList[] = {(Nd4jPointer)z.shapeInfo(), (Nd4jPointer)z.shapeInfo()}; - Nd4jPointer dzShapeList[] = {(Nd4jPointer)z.specialShapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; - int shuffleMap[] = {1, 0, 4, 3, 2}; - auto zTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), {1}); - Nd4jPointer zListOffset[] = {(Nd4jPointer)zTadPack.platformOffsets(), (Nd4jPointer)zTadPack.platformOffsets()}; - Nd4jPointer zListTADs[] = {(Nd4jPointer)zTadPack.platformShapeInfo(), (Nd4jPointer)zTadPack.platformShapeInfo()}; - ::shuffle(nullptr, - xList, xShapeList, - dxList, dxShapeList, - zList, zShapeList, - dzList, dzShapeList, - 2, - shuffleMap, zListTADs, zListOffset); -// z.printIndexedBuffer("RES"); -// x.printIndexedBuffer("INPUT shuffled"); -// y.printIndexedBuffer("INPUT 2 shuffled"); -// ASSERT_TRUE(z.equalsTo(exp)); + x.linspace(1); + y.linspace(34); + exp.linspace(2, 2); + Nd4jPointer xList[] = {x.buffer(), x.buffer()}; + Nd4jPointer dxList[] = {x.specialBuffer(), y.specialBuffer()}; + Nd4jPointer xShapeList[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)y.shapeInfo()}; + Nd4jPointer dxShapeList[] = {(Nd4jPointer)x.specialShapeInfo(), + (Nd4jPointer)y.specialShapeInfo()}; + Nd4jPointer zList[] = {z.buffer(), z.buffer()}; + Nd4jPointer dzList[] = {z.specialBuffer(), z.specialBuffer()}; + Nd4jPointer zShapeList[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.shapeInfo()}; + Nd4jPointer dzShapeList[] = {(Nd4jPointer)z.specialShapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; + int shuffleMap[] = {1, 0, 4, 3, 2}; + auto zTadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + x.shapeInfo(), {1}); + Nd4jPointer zListOffset[] = {(Nd4jPointer)zTadPack.platformOffsets(), + (Nd4jPointer)zTadPack.platformOffsets()}; + Nd4jPointer zListTADs[] = {(Nd4jPointer)zTadPack.platformShapeInfo(), + (Nd4jPointer)zTadPack.platformShapeInfo()}; + ::shuffle(nullptr, xList, xShapeList, dxList, dxShapeList, zList, zShapeList, + dzList, dzShapeList, 2, shuffleMap, zListTADs, zListOffset); + // z.printIndexedBuffer("RES"); + // x.printIndexedBuffer("INPUT shuffled"); + // y.printIndexedBuffer("INPUT 2 shuffled"); + // ASSERT_TRUE(z.equalsTo(exp)); } TEST_F(NativeOpsTests, ConvertTypesTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); + auto x = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {5, 5}); - auto z = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); #ifdef __CUDABLAS__ - return; + return; #endif - x.linspace(2, 2); - exp.linspace(2, 2); - ::convertTypes(nullptr, ND4J_FLOAT32, x.buffer(), x.lengthOf(), ND4J_DOUBLE, z.buffer()); - ASSERT_TRUE(z.equalsTo(exp)); + x.linspace(2, 2); + exp.linspace(2, 2); + ::convertTypes(nullptr, ND4J_FLOAT32, x.buffer(), x.lengthOf(), ND4J_DOUBLE, + z.buffer()); + ASSERT_TRUE(z.equalsTo(exp)); } -//TEST_F(NativeOpsTests, Test_Aggregations_1) { +// TEST_F(NativeOpsTests, Test_Aggregations_1) { // NativeOps ops; // auto x = NDArrayFactory::create('c', {5,5}); // auto y = NDArrayFactory::create('c', {5,5}); // // -// ops.execAggregate(nullptr, 0, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data(), sd::DataType::FLOAT32); -// void **arguments, -// int numArguments, -// Nd4jLong **shapeArguments, -// int numShapeArguments, -// int *indexArguments, -// int numIndexArguments, -// int **intArrays, -// int numIntArrays, -// void *realArguments, +// ops.execAggregate(nullptr, 0, maxArgs, maxShapes, maxIntArrays, +// maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data(), +// sd::DataType::FLOAT32); void **arguments, int numArguments, Nd4jLong +// **shapeArguments, int numShapeArguments, int *indexArguments, int +// numIndexArguments, int **intArrays, int numIntArrays, void *realArguments, // int numRealArguments, // sd::DataType dtype //} TEST_F(NativeOpsTests, RandomTest_1) { - auto z = NDArrayFactory::create('c', {100}); - Nd4jPointer extra[] = {nullptr, nullptr}; + auto z = NDArrayFactory::create('c', {100}); + Nd4jPointer extra[] = {nullptr, nullptr}; #ifdef __CUDABLAS__ - return; - extra[1] = z.getContext()->getCudaStream(); + return; + extra[1] = z.getContext()->getCudaStream(); #endif - graph::RandomGenerator rng(1023, 119); - double p = 0.5; - OpaqueDataBuffer zBuf(z.dataBuffer()); + graph::RandomGenerator rng(1023, 119); + double p = 0.5; + OpaqueDataBuffer zBuf(z.dataBuffer()); - ::execRandom(extra, random::BernoulliDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p); + ::execRandom(extra, random::BernoulliDistribution, &rng, &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &p); } TEST_F(NativeOpsTests, RandomTest_2) { - auto x = NDArrayFactory::create('c', {100}); - auto z = NDArrayFactory::create('c', {100}); - Nd4jPointer extra[] = {nullptr, nullptr}; + auto x = NDArrayFactory::create('c', {100}); + auto z = NDArrayFactory::create('c', {100}); + Nd4jPointer extra[] = {nullptr, nullptr}; #ifdef __CUDABLAS__ - return; - extra[1] = z.getContext()->getCudaStream(); + return; + extra[1] = z.getContext()->getCudaStream(); #endif - x.linspace(0, 0.01); - graph::RandomGenerator rng(1023, 119); - double p = 0.5; - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); + x.linspace(0, 0.01); + graph::RandomGenerator rng(1023, 119); + double p = 0.5; + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); - ::execRandom2(extra, random::DropOut, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p); + ::execRandom2(extra, random::DropOut, &rng, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &p); } TEST_F(NativeOpsTests, RandomTest_3) { - auto x = NDArrayFactory::create('c', {100}); - auto y = NDArrayFactory::create('c', {100}); - auto z = NDArrayFactory::create('c', {100}); - Nd4jPointer extra[] = {nullptr, nullptr}; + auto x = NDArrayFactory::create('c', {100}); + auto y = NDArrayFactory::create('c', {100}); + auto z = NDArrayFactory::create('c', {100}); + Nd4jPointer extra[] = {nullptr, nullptr}; #ifdef __CUDABLAS__ - return; - extra[1] = z.getContext()->getCudaStream(); + return; + extra[1] = z.getContext()->getCudaStream(); #endif - x.linspace(0, 0.01); - x.linspace(1, -0.01); - graph::RandomGenerator rng(1023, 119); - double p = 0.5; - OpaqueDataBuffer xBuf(x.dataBuffer()); - OpaqueDataBuffer yBuf(y.dataBuffer()); - OpaqueDataBuffer zBuf(z.dataBuffer()); + x.linspace(0, 0.01); + x.linspace(1, -0.01); + graph::RandomGenerator rng(1023, 119); + double p = 0.5; + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); - ::execRandom3(extra, random::ProbablisticMerge, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &yBuf, - y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p); + ::execRandom3(extra, random::ProbablisticMerge, &rng, &xBuf, x.shapeInfo(), + x.specialShapeInfo(), &yBuf, y.shapeInfo(), + y.specialShapeInfo(), &zBuf, z.shapeInfo(), + z.specialShapeInfo(), &p); } TEST_F(NativeOpsTests, RandomTest_4) { #ifdef __CUDABLAS__ - return ; + return; #endif - graph::RandomGenerator* rng = (graph::RandomGenerator*)::initRandom(nullptr, 1023, 0, nullptr); - ::refreshBuffer(nullptr, 1203L, rng); - ::reSeedBuffer(nullptr, 3113L, rng); - ::destroyRandom(rng); + graph::RandomGenerator *rng = + (graph::RandomGenerator *)::initRandom(nullptr, 1023, 0, nullptr); + ::refreshBuffer(nullptr, 1203L, rng); + ::reSeedBuffer(nullptr, 3113L, rng); + ::destroyRandom(rng); } TEST_F(NativeOpsTests, SortTest_1) { #ifdef __CUDABLAS__ - return ; + return; #endif - auto sortedVals = NDArrayFactory::create( - {10, 1, 5, 120, 34, 5, 78, 138, 3, 111, 331, 29, 91, 71, 73, 50, 56, 4}); - auto exp = NDArrayFactory::create({1, 3, 4, 5, 5, 10, 29, 34, 50, 56, 71, 73, 78, 91, 111, 120, 138, 331}); + auto sortedVals = NDArrayFactory::create( + {10, 1, 5, 120, 34, 5, 78, 138, 3, 111, 331, 29, 91, 71, 73, 50, 56, 4}); + auto exp = NDArrayFactory::create( + {1, 3, 4, 5, 5, 10, 29, 34, 50, 56, 71, 73, 78, 91, 111, 120, 138, 331}); - ::sort(nullptr, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), - sortedVals.specialShapeInfo(), false); - ASSERT_TRUE(sortedVals.equalsTo(exp)); + ::sort(nullptr, sortedVals.buffer(), sortedVals.shapeInfo(), + sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), false); + ASSERT_TRUE(sortedVals.equalsTo(exp)); } TEST_F(NativeOpsTests, SortTests_2) { - auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - Nd4jPointer extras[2]; + auto k = NDArrayFactory::create('c', {10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + Nd4jPointer extras[2]; #ifdef __CUDABLAS__ - extras[1] = LaunchContext::defaultContext()->getCudaStream(); + extras[1] = LaunchContext::defaultContext()->getCudaStream(); #endif -// OpaqueDataBuffer xBuf(x.dataBuffer()); -// OpaqueDataBuffer yBuf(y.dataBuffer()); -// OpaqueDataBuffer expBuf(exp.dataBuffer()); -// OpaqueDataBuffer dimBuf(exp.dataBuffer()); + // OpaqueDataBuffer xBuf(x.dataBuffer()); + // OpaqueDataBuffer yBuf(y.dataBuffer()); + // OpaqueDataBuffer expBuf(exp.dataBuffer()); + // OpaqueDataBuffer dimBuf(exp.dataBuffer()); - ::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); - k.tickWriteDevice(); - v.tickWriteDevice(); + ::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), false); + k.tickWriteDevice(); + v.tickWriteDevice(); - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(NativeOpsTests, SortTest_3) { - auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + auto k = NDArrayFactory::create('c', {10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + auto ek = NDArrayFactory::create('c', {10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); #ifdef __CUDABLAS__ - Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + Nd4jPointer extras[2] = {nullptr, + LaunchContext::defaultContext()->getCudaStream()}; #else - Nd4jPointer extras[2]; + Nd4jPointer extras[2]; #endif - ::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); - k.tickWriteDevice(); - v.tickWriteDevice(); + ::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), false); + k.tickWriteDevice(); + v.tickWriteDevice(); - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(NativeOpsTests, SortTest_4) { #ifdef __CUDABLAS__ - return ; + return; #endif - auto sortedVals = NDArrayFactory::create('c', {3, 6}, - { 10, 1, 5, 120, 34, 5, - 78, 138, 3, 111, 331, 29, - 91, 71, 73, 50, 56, 4}); - auto exp = NDArrayFactory::create('c', {3, 6}, {1, 5, 5, 10, 34, 120, 3, 29, 78, 111, 138, 331, 4, 50, 56, 71, 73, 91}); - - std::vector dims({1}); - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(sortedVals.shapeInfo(), {1}); - ::sortTad(nullptr, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), - sortedVals.specialShapeInfo(), dims.data(), dims.size(), packX.platformShapeInfo(), packX.platformOffsets(), false); -// sortedVals.printBuffer("OUT"); -// exp.printIndexedBuffer("EXP"); - ASSERT_TRUE(sortedVals.equalsTo(exp)); + auto sortedVals = NDArrayFactory::create( + 'c', {3, 6}, + {10, 1, 5, 120, 34, 5, 78, 138, 3, 111, 331, 29, 91, 71, 73, 50, 56, 4}); + auto exp = NDArrayFactory::create( + 'c', {3, 6}, + {1, 5, 5, 10, 34, 120, 3, 29, 78, 111, 138, 331, 4, 50, 56, 71, 73, 91}); + + std::vector dims({1}); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions( + sortedVals.shapeInfo(), {1}); + ::sortTad(nullptr, sortedVals.buffer(), sortedVals.shapeInfo(), + sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), + dims.data(), dims.size(), packX.platformShapeInfo(), + packX.platformOffsets(), false); + // sortedVals.printBuffer("OUT"); + // exp.printIndexedBuffer("EXP"); + ASSERT_TRUE(sortedVals.equalsTo(exp)); } TEST_F(NativeOpsTests, SortTests_5) { - auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - Nd4jPointer extras[2]; + auto k = NDArrayFactory::create( + 'c', {2, 10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, + 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create( + 'c', {2, 10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, + 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2]; #ifdef __CUDABLAS__ - extras[1] = LaunchContext::defaultContext()->getCudaStream(); + extras[1] = LaunchContext::defaultContext()->getCudaStream(); #endif - int axis = 1; + int axis = 1; - ::sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); - k.tickWriteDevice(); - v.tickWriteDevice(); + ::sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + k.tickWriteDevice(); + v.tickWriteDevice(); -// k.printIndexedBuffer("k"); -// v.printIndexedBuffer("v"); + // k.printIndexedBuffer("k"); + // v.printIndexedBuffer("v"); - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(NativeOpsTests, SortTests_6) { - auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - Nd4jPointer extras[2]; + auto k = NDArrayFactory::create( + 'c', {2, 10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, + 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create( + 'c', {2, 10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, + 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2]; #ifdef __CUDABLAS__ - extras[1] = LaunchContext::defaultContext()->getCudaStream(); + extras[1] = LaunchContext::defaultContext()->getCudaStream(); #endif - int axis = 1; + int axis = 1; - ::sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); - k.tickWriteDevice(); - v.tickWriteDevice(); + ::sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + k.tickWriteDevice(); + v.tickWriteDevice(); - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } -//TEST_F(NativeOpsTests, MapTests_1) { +// TEST_F(NativeOpsTests, MapTests_1) { //#ifdef __CUDABLAS__ // return ; //#endif @@ -1479,131 +1455,153 @@ TEST_F(NativeOpsTests, SortTests_6) { //} TEST_F(NativeOpsTests, MapTests_1) { - //printf("Custom ops: %s\n", ::getAllCustomOps()); - //printf("All ops: %s\n", ::getAllOperations()); + // printf("Custom ops: %s\n", ::getAllCustomOps()); + // printf("All ops: %s\n", ::getAllOperations()); - ::getAllCustomOps(); - ::getAllOperations(); + ::getAllCustomOps(); + ::getAllOperations(); } TEST_F(NativeOpsTests, CustomOpTest_1) { - auto x = NDArrayFactory::create('c', {1, 6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {6}); - auto e = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto x = NDArrayFactory::create('c', {1, 6}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {6}); + auto e = + NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - sd::ops::squeeze op; + sd::ops::squeeze op; - Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; - Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer)x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), + (Nd4jPointer)x.specialShapeInfo()}; - Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; - Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer)z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), + (Nd4jPointer)z.specialShapeInfo()}; + auto status = ::execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, + ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, + 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + ASSERT_EQ(Status::OK(), status); - auto status = ::execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); - ASSERT_EQ(Status::OK(), status); - - ASSERT_EQ(e, z); + ASSERT_EQ(e, z); } TEST_F(NativeOpsTests, CustomOpTests_2) { - auto array0 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto array1 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); - auto z = NDArrayFactory::create('c', {3, 2}); + auto array0 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array1 = NDArrayFactory::create('c', {3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {3, 2}); - auto exp = NDArrayFactory::create('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); - Context ctx(1); + auto exp = NDArrayFactory::create('c', {3, 2}, + {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); + Context ctx(1); - NDArray::prepareSpecialUse({&z}, {&array0, &array1}); + NDArray::prepareSpecialUse({&z}, {&array0, &array1}); - ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo()); - ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo()); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), + array0.specialBuffer(), array0.specialShapeInfo()); + ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), + array1.specialBuffer(), array1.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), + z.specialShapeInfo()); - ASSERT_EQ(2, ctx.width()); + ASSERT_EQ(2, ctx.width()); - sd::ops::add op; - ::execCustomOp2(nullptr, op.getOpHash(), &ctx); + sd::ops::add op; + ::execCustomOp2(nullptr, op.getOpHash(), &ctx); - NDArray::registerSpecialUse({&z}, {&array0, &array1}); + NDArray::registerSpecialUse({&z}, {&array0, &array1}); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(NativeOpsTests, CalculateOutputShapeTests_1) { - auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); - auto weights = NDArrayFactory::create('c', {2, 2, 2, 3}); - auto exp = NDArrayFactory::create('c', {1, 3, 5, 4}); + auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); + auto weights = NDArrayFactory::create('c', {2, 2, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3, 5, 4}); - sd::ops::conv2d op; + sd::ops::conv2d op; - std::vector tArgs({}); - std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); + std::vector tArgs({}); + std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); - Nd4jPointer ptrs[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) weights.shapeInfo()}; + Nd4jPointer ptrs[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)weights.shapeInfo()}; #ifdef __CUDABLAS__ - return; + return; #endif - auto shapeList = ::calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size()); + auto shapeList = + ::calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), + tArgs.size(), iArgs.data(), iArgs.size()); - ASSERT_EQ(1, shapeList->size()); + ASSERT_EQ(1, shapeList->size()); - ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); - ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); - ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]); - ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]); - ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]); + ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); + ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); + ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]); + ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]); + ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]); - //int *ptr = (int *) shapeList[0]; - //delete[] ptr; - //delete shapeList; + // int *ptr = (int *) shapeList[0]; + // delete[] ptr; + // delete shapeList; - ::deleteShapeList((Nd4jPointer) shapeList); + ::deleteShapeList((Nd4jPointer)shapeList); } TEST_F(NativeOpsTests, CalculateOutputShapeTests_2) { - auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); - auto weights = NDArrayFactory::create('c', {2, 2, 2, 3}); - auto exp = NDArrayFactory::create('c', {1, 3, 5, 4}); + auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); + auto weights = NDArrayFactory::create('c', {2, 2, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3, 5, 4}); - sd::ops::conv2d op; + sd::ops::conv2d op; - std::vector tArgs({}); - std::vector bArgsF({}); - std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); + std::vector tArgs({}); + std::vector bArgsF({}); + std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); - Nd4jPointer shapePtrs[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) weights.shapeInfo()}; - Nd4jPointer dataPtrs[] = {(Nd4jPointer)input.buffer(), (Nd4jPointer)weights.buffer()}; + Nd4jPointer shapePtrs[] = {(Nd4jPointer)input.shapeInfo(), + (Nd4jPointer)weights.shapeInfo()}; + Nd4jPointer dataPtrs[] = {(Nd4jPointer)input.buffer(), + (Nd4jPointer)weights.buffer()}; #ifdef __CUDABLAS__ - return; + return; #endif - auto shapeList = ::calculateOutputShapes2(nullptr, op.getOpHash(), dataPtrs, shapePtrs, 2, const_cast(tArgs.data()), tArgs.size(), - const_cast(iArgs.data()), iArgs.size(), nullptr, bArgsF.size(), nullptr, 0); -// Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs - ASSERT_EQ(1, shapeList->size()); + auto shapeList = ::calculateOutputShapes2( + nullptr, op.getOpHash(), dataPtrs, shapePtrs, 2, + const_cast(tArgs.data()), tArgs.size(), + const_cast(iArgs.data()), iArgs.size(), nullptr, + bArgsF.size(), nullptr, 0); + // Nd4jPointer* extraPointers, Nd4jLong hash, + // Nd4jPointer* inputBuffers, Nd4jPointer* + // inputShapes, int numInputShapes, double* + // tArgs, int numTArgs, Nd4jLong *iArgs, int + // numIArgs, bool *bArgs, int numBArgs + ASSERT_EQ(1, shapeList->size()); - ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); - ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); - ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]); - ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]); - ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]); + ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); + ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); + ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]); + ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]); + ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]); - //int *ptr = (int *) shapeList[0]; - //delete[] ptr; - //delete shapeList; + // int *ptr = (int *) shapeList[0]; + // delete[] ptr; + // delete shapeList; - ::deleteShapeList((Nd4jPointer) shapeList); + ::deleteShapeList((Nd4jPointer)shapeList); } - TEST_F(NativeOpsTests, interop_databuffer_tests_1) { - auto idb = ::allocateDataBuffer(100, 10, false); - auto ptr = ::dbPrimaryBuffer(idb); - ::deleteDataBuffer(idb); + auto idb = ::allocateDataBuffer(100, 10, false); + auto ptr = ::dbPrimaryBuffer(idb); + ::deleteDataBuffer(idb); } -//Uncomment when needed only - massive calculations -//TEST_F(NativeOpsTests, BenchmarkTests_1) { +// Uncomment when needed only - massive calculations +// TEST_F(NativeOpsTests, BenchmarkTests_1) { // // printf("%s\n", ::runLightBenchmarkSuit(true)); // printf("%s\n", ::runFullBenchmarkSuit(true)); diff --git a/libnd4j/tests_cpu/layers_tests/NlpTests.cpp b/libnd4j/tests_cpu/layers_tests/NlpTests.cpp index 2325e24455ed..b8afdfcf01eb 100644 --- a/libnd4j/tests_cpu/layers_tests/NlpTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NlpTests.cpp @@ -18,457 +18,476 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include #include +#include +#include +#include "testlayers.h" using namespace sd; - class NlpTests : public testing::Test { -public: - - NlpTests() { - printf("\n"); - fflush(stdout); - } + public: + NlpTests() { + printf("\n"); + fflush(stdout); + } }; TEST_F(NlpTests, basic_sg_hs_test_1) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.01001f); - exp1.assign(0.020005f); - - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::empty(); - auto indices = NDArrayFactory::create('c', {1}, {1}); - auto codes = NDArrayFactory::create('c', {1}); - auto syn0 = NDArrayFactory::create('c', {100, 10}); - auto syn1 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::empty(); - auto neu1e = NDArrayFactory::create('c', {10}); - - syn0.assign(0.01); - syn1.assign(0.02); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create(0.001); - auto randomValue = NDArrayFactory::create(1L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::skipgram op; - auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto row0 = syn0({0,1, 0,0}, true); - auto row1 = syn1({1,2, 0,0}, true); - - ASSERT_EQ(exp0, row0); - ASSERT_EQ(exp1, row1); - - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01001f); + exp1.assign(0.020005f); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto indices = NDArrayFactory::create('c', {1}, {1}); + auto codes = NDArrayFactory::create('c', {1}); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto neu1e = NDArrayFactory::create('c', {10}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.001); + auto randomValue = NDArrayFactory::create(1L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::skipgram op; + auto result = op.evaluate( + {&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, + &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {}, {false}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row0 = syn0({0, 1, 0, 0}, true); + auto row1 = syn1({1, 2, 0, 0}, true); + + ASSERT_EQ(exp0, row0); + ASSERT_EQ(exp1, row1); } TEST_F(NlpTests, basic_sg_hs_test_2) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.01f); - exp1.assign(0.020005f); - exp2.assign(0.019995f); - - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::empty(); - auto indices = NDArrayFactory::create('c', {2}, {1, 2}); - auto codes = NDArrayFactory::create('c', {2}, {0, 1}); - auto syn0 = NDArrayFactory::create('c', {100, 10}); - auto syn1 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::empty(); - auto neu1e = NDArrayFactory::create('c', {10}); - - syn0.assign(0.01); - syn1.assign(0.02); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create(0.001); - auto randomValue = NDArrayFactory::create(1L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::skipgram op; - auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto row0 = syn0({0,1, 0,0}, true); - auto row1 = syn1({1,2, 0,0}, true); - auto row2 = syn1({2,3, 0,0}, true); - - ASSERT_EQ(exp0, row0); - ASSERT_EQ(exp1, row1); - ASSERT_EQ(exp2, row2); - - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01f); + exp1.assign(0.020005f); + exp2.assign(0.019995f); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto indices = NDArrayFactory::create('c', {2}, {1, 2}); + auto codes = NDArrayFactory::create('c', {2}, {0, 1}); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto neu1e = NDArrayFactory::create('c', {10}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.001); + auto randomValue = NDArrayFactory::create(1L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::skipgram op; + auto result = op.evaluate( + {&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, + &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {}, {false}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row0 = syn0({0, 1, 0, 0}, true); + auto row1 = syn1({1, 2, 0, 0}, true); + auto row2 = syn1({2, 3, 0, 0}, true); + + ASSERT_EQ(exp0, row0); + ASSERT_EQ(exp1, row1); + ASSERT_EQ(exp2, row2); } TEST_F(NlpTests, basic_sg_hs_test_3) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.01f); - exp1.assign(0.020005f); - exp2.assign(0.019995f); - - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::empty(); - auto indices0 = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto indices1 = NDArrayFactory::create('c', {3}, {3, 1, 2}); - auto codes00 = NDArrayFactory::create('c', {3}, {0, 1, 1}); - auto codes01 = NDArrayFactory::create('c', {3}, {1, 0, 1}); - auto syn00 = NDArrayFactory::create('c', {100, 10}); - auto syn01 = NDArrayFactory::create('c', {100, 10}); - auto syn10 = NDArrayFactory::create('c', {100, 10}); - auto syn11 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::empty(); - auto neu1e = NDArrayFactory::create('c', {10}); - - RandomGenerator rng(119L, 198L); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &syn00, 0.0, 1.0); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &syn10, 0.0, 1.0); - - syn01.assign(syn00); - syn11.assign(syn10); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create(0.001); - auto randomValue = NDArrayFactory::create(1L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::skipgram op; - auto result0 = op.evaluate({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); - auto result1 = op.evaluate({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); - ASSERT_EQ(Status::OK(), result0.status()); - - auto row00 = syn00({0,1, 0,0}, true); - auto row01 = syn01({0,1, 0,0}, true); - auto row1 = syn10({1,2, 0,0}, true); - auto row2 = syn11({1,2, 0,0}, true); - - ASSERT_EQ(row2, row1); - ASSERT_EQ(row00, row01); + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01f); + exp1.assign(0.020005f); + exp2.assign(0.019995f); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto indices0 = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto indices1 = NDArrayFactory::create('c', {3}, {3, 1, 2}); + auto codes00 = NDArrayFactory::create('c', {3}, {0, 1, 1}); + auto codes01 = NDArrayFactory::create('c', {3}, {1, 0, 1}); + auto syn00 = NDArrayFactory::create('c', {100, 10}); + auto syn01 = NDArrayFactory::create('c', {100, 10}); + auto syn10 = NDArrayFactory::create('c', {100, 10}); + auto syn11 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto neu1e = NDArrayFactory::create('c', {10}); + + RandomGenerator rng(119L, 198L); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &syn00, 0.0, + 1.0); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &syn10, 0.0, + 1.0); + + syn01.assign(syn00); + syn11.assign(syn10); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.001); + auto randomValue = NDArrayFactory::create(1L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::skipgram op; + auto result0 = op.evaluate( + {&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, + &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {}, {false}, {}, true); + auto result1 = op.evaluate( + {&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, + &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {}, {false}, {}, true); + ASSERT_EQ(Status::OK(), result0.status()); + + auto row00 = syn00({0, 1, 0, 0}, true); + auto row01 = syn01({0, 1, 0, 0}, true); + auto row1 = syn10({1, 2, 0, 0}, true); + auto row2 = syn11({1, 2, 0, 0}, true); + + ASSERT_EQ(row2, row1); + ASSERT_EQ(row00, row01); } TEST_F(NlpTests, basic_sg_hs_ns_test_1) { - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::create(1); - auto indices = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto codes = NDArrayFactory::create('c', {5}, {1, 1, 0, 1, 1}); - auto syn0 = NDArrayFactory::create('c', {100, 150}); - auto syn1 = NDArrayFactory::create('c', {100, 150}); - auto syn1Neg = NDArrayFactory::create('c', {100, 150}); - auto expTable = NDArrayFactory::create('c', {1000}); - auto negTable = NDArrayFactory::create('c', {1000}); - auto neu1e = NDArrayFactory::create('c', {10}); - negTable.linspace(1.0); - - auto alpha = NDArrayFactory::create(1.25); - auto randomValue = NDArrayFactory::create(119L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::skipgram op; - auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::create(1); + auto indices = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto codes = NDArrayFactory::create('c', {5}, {1, 1, 0, 1, 1}); + auto syn0 = NDArrayFactory::create('c', {100, 150}); + auto syn1 = NDArrayFactory::create('c', {100, 150}); + auto syn1Neg = NDArrayFactory::create('c', {100, 150}); + auto expTable = NDArrayFactory::create('c', {1000}); + auto negTable = NDArrayFactory::create('c', {1000}); + auto neu1e = NDArrayFactory::create('c', {10}); + negTable.linspace(1.0); + + auto alpha = NDArrayFactory::create(1.25); + auto randomValue = NDArrayFactory::create(119L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::skipgram op; + auto result = op.evaluate( + {&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, + &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {3}, {false}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(NlpTests, basic_sg_ns_test_1) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.01); - - auto target = NDArrayFactory::create(1); - auto ngStarter = NDArrayFactory::create(3); - auto indices = NDArrayFactory::empty(); - auto codes = NDArrayFactory::empty(); - auto syn0 = NDArrayFactory::create('c', {10, 10}); - auto syn1 = NDArrayFactory::empty(); - auto syn1Neg = NDArrayFactory::create('c', {10, 10}); - auto expTable = NDArrayFactory::create('c', {1000}); - auto negTable = NDArrayFactory::create('c', {1000}); - auto neu1e = NDArrayFactory::create('c', {10}); - - auto syn1Neg2 = NDArrayFactory::create('c', {10, 10}); - - syn0.assign(0.01); - syn1.assign(0.02); - syn1Neg.assign(0.03); - syn1Neg2.assign(0.03); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create(0.001); - auto randomValue = NDArrayFactory::create(2L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::skipgram op; - auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto row0 = syn0({1,2, 0,0}, true); - - ASSERT_EQ(exp0, row0); - ASSERT_FALSE(syn1Neg2.equalsTo(syn1Neg, 1e-6)); - - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01); + + auto target = NDArrayFactory::create(1); + auto ngStarter = NDArrayFactory::create(3); + auto indices = NDArrayFactory::empty(); + auto codes = NDArrayFactory::empty(); + auto syn0 = NDArrayFactory::create('c', {10, 10}); + auto syn1 = NDArrayFactory::empty(); + auto syn1Neg = NDArrayFactory::create('c', {10, 10}); + auto expTable = NDArrayFactory::create('c', {1000}); + auto negTable = NDArrayFactory::create('c', {1000}); + auto neu1e = NDArrayFactory::create('c', {10}); + + auto syn1Neg2 = NDArrayFactory::create('c', {10, 10}); + + syn0.assign(0.01); + syn1.assign(0.02); + syn1Neg.assign(0.03); + syn1Neg2.assign(0.03); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.001); + auto randomValue = NDArrayFactory::create(2L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::skipgram op; + auto result = op.evaluate( + {&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, + &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {1, 1}, {false}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row0 = syn0({1, 2, 0, 0}, true); + + ASSERT_EQ(exp0, row0); + ASSERT_FALSE(syn1Neg2.equalsTo(syn1Neg, 1e-6)); } TEST_F(NlpTests, basic_cb_hs_test_1) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.0095f); - exp1.assign(0.019875f); - exp2.assign(0.02f); - - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::empty(); - auto context = NDArrayFactory::create('c', {3}, {0, 1, 2}); - auto locked = NDArrayFactory::create('c', {3}); - auto indices = NDArrayFactory::create('c', {2}, {4, 5}); - auto codes = NDArrayFactory::create('c', {2}, {1, 1}); - auto syn0 = NDArrayFactory::create('c', {100, 10}); - auto syn1 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::empty(); - auto numWords = NDArrayFactory::create('c', {1}, {1}); - - syn0.assign(0.01); - syn1.assign(0.02); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create(0.025); - auto randomValue = NDArrayFactory::create(2L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::cbow op; - auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto row_s0_0 = syn0({0,1, 0,0}, true); - auto row_s0_1 = syn0({1,2, 0,0}, true); - auto row_s0_2 = syn0({2,3, 0,0}, true); - - auto row_s1_4 = syn1({4,5, 0,0}, true); - auto row_s1_5 = syn1({5,6, 0,0}, true); - auto row_s1_6 = syn1({6,7, 0,0}, true); - - ASSERT_EQ(exp0, row_s0_0); - ASSERT_EQ(exp0, row_s0_1); - ASSERT_EQ(exp0, row_s0_2); - - ASSERT_EQ(exp1, row_s1_4); - ASSERT_EQ(exp1, row_s1_5); - ASSERT_EQ(exp2, row_s1_6); - - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.0095f); + exp1.assign(0.019875f); + exp2.assign(0.02f); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto context = NDArrayFactory::create('c', {3}, {0, 1, 2}); + auto locked = NDArrayFactory::create('c', {3}); + auto indices = NDArrayFactory::create('c', {2}, {4, 5}); + auto codes = NDArrayFactory::create('c', {2}, {1, 1}); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto numWords = NDArrayFactory::create('c', {1}, {1}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.025); + auto randomValue = NDArrayFactory::create(2L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::cbow op; + auto result = + op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, + &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, + &numWords, &locked, &inferenceVector}, + {}, {}, {true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row_s0_0 = syn0({0, 1, 0, 0}, true); + auto row_s0_1 = syn0({1, 2, 0, 0}, true); + auto row_s0_2 = syn0({2, 3, 0, 0}, true); + + auto row_s1_4 = syn1({4, 5, 0, 0}, true); + auto row_s1_5 = syn1({5, 6, 0, 0}, true); + auto row_s1_6 = syn1({6, 7, 0, 0}, true); + + ASSERT_EQ(exp0, row_s0_0); + ASSERT_EQ(exp0, row_s0_1); + ASSERT_EQ(exp0, row_s0_2); + + ASSERT_EQ(exp1, row_s1_4); + ASSERT_EQ(exp1, row_s1_5); + ASSERT_EQ(exp2, row_s1_6); } TEST_F(NlpTests, basic_cb_ns_test_1) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.0096265625); - exp1.assign(0.01); - exp2.assign(0.030125f); - - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::create(6); - auto context = NDArrayFactory::create('c', {3}, {0, 1, 2}); - auto locked = NDArrayFactory::create('c', {3}); - auto indices = NDArrayFactory::empty(); - auto codes = NDArrayFactory::empty(); - auto syn0 = NDArrayFactory::create('c', {100, 10}); - auto syn1 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::create('c', {100, 10}); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::create('c', {100000}); - auto numWords = NDArrayFactory::create('c', {2}, {1, 2}); - - syn0.assign(0.01); - syn1.assign(0.02); - syn1Neg.assign(0.03); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create(0.025); - auto randomValue = NDArrayFactory::create(2L); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::cbow op; - auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto row_s0_0 = syn0({0,1, 0,0}, true); - auto row_s0_1 = syn0({1,2, 0,0}, true); - auto row_s0_2 = syn0({2,3, 0,0}, true); - - auto row_s1_4 = syn1({4,5, 0,0}, true); - auto row_s1_5 = syn1({5,6, 0,0}, true); - auto row_s1_6 = syn1Neg({6,7, 0,0}, true); - - - ASSERT_EQ(exp0, row_s0_0); - ASSERT_EQ(exp0, row_s0_1); - ASSERT_EQ(exp0, row_s0_2); - ASSERT_EQ(exp2, row_s1_6); - - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.0096265625); + exp1.assign(0.01); + exp2.assign(0.030125f); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::create(6); + auto context = NDArrayFactory::create('c', {3}, {0, 1, 2}); + auto locked = NDArrayFactory::create('c', {3}); + auto indices = NDArrayFactory::empty(); + auto codes = NDArrayFactory::empty(); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::create('c', {100, 10}); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::create('c', {100000}); + auto numWords = NDArrayFactory::create('c', {2}, {1, 2}); + + syn0.assign(0.01); + syn1.assign(0.02); + syn1Neg.assign(0.03); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.025); + auto randomValue = NDArrayFactory::create(2L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::cbow op; + auto result = + op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, + &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, + &numWords, &locked, &inferenceVector}, + {}, {1, 2, 0}, {true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row_s0_0 = syn0({0, 1, 0, 0}, true); + auto row_s0_1 = syn0({1, 2, 0, 0}, true); + auto row_s0_2 = syn0({2, 3, 0, 0}, true); + + auto row_s1_4 = syn1({4, 5, 0, 0}, true); + auto row_s1_5 = syn1({5, 6, 0, 0}, true); + auto row_s1_6 = syn1Neg({6, 7, 0, 0}, true); + + ASSERT_EQ(exp0, row_s0_0); + ASSERT_EQ(exp0, row_s0_1); + ASSERT_EQ(exp0, row_s0_2); + ASSERT_EQ(exp2, row_s1_6); } TEST_F(NlpTests, test_sg_hs_batch_1) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.01f); - exp1.assign(0.020005f); - exp2.assign(0.019995f); - - auto target = NDArrayFactory::create('c', {2}, {0, 5}); - auto ngStarter = NDArrayFactory::empty(); - auto indices = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto codes = NDArrayFactory::create('c', {2, 2}, {0, 1, 1, 1}); - auto syn0 = NDArrayFactory::create('c', {100, 10}); - auto syn1 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::empty(); - - auto alpha = NDArrayFactory::create('c', {2}, {0.001, 0.024}); - auto randomValue = NDArrayFactory::create('c', {2}, {1L, 3L}); - auto inferenceVector = NDArrayFactory::empty(); - auto neu1e = NDArrayFactory::create('c', {2, 10}); - - syn0.assign(0.01); - syn1.assign(0.02); - expTable.assign(0.5); - - sd::ops::skipgram op; - auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto row0 = syn0({0,1, 0,0}, true); - auto row1 = syn1({1,2, 0,0}, true); - auto row2 = syn1({2,3, 0,0}, true); - - ASSERT_TRUE(exp0.equalsTo(row0, 1e-6)); - ASSERT_TRUE(exp1.equalsTo(row1, 1e-6)); - ASSERT_TRUE(exp2.equalsTo(row2, 1e-6)); - - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01f); + exp1.assign(0.020005f); + exp2.assign(0.019995f); + + auto target = NDArrayFactory::create('c', {2}, {0, 5}); + auto ngStarter = NDArrayFactory::empty(); + auto indices = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto codes = NDArrayFactory::create('c', {2, 2}, {0, 1, 1, 1}); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + + auto alpha = NDArrayFactory::create('c', {2}, {0.001, 0.024}); + auto randomValue = NDArrayFactory::create('c', {2}, {1L, 3L}); + auto inferenceVector = NDArrayFactory::empty(); + auto neu1e = NDArrayFactory::create('c', {2, 10}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + sd::ops::skipgram op; + auto result = op.evaluate( + {&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, + &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {}, {false, true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row0 = syn0({0, 1, 0, 0}, true); + auto row1 = syn1({1, 2, 0, 0}, true); + auto row2 = syn1({2, 3, 0, 0}, true); + + ASSERT_TRUE(exp0.equalsTo(row0, 1e-6)); + ASSERT_TRUE(exp1.equalsTo(row1, 1e-6)); + ASSERT_TRUE(exp2.equalsTo(row2, 1e-6)); } TEST_F(NlpTests, test_sg_ns_batch_1) { - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.01f); - exp1.assign(0.020005f); - exp2.assign(0.019995f); - - auto target = NDArrayFactory::create('c', {2}, {0, 5}); - auto ngStarter = NDArrayFactory::create('c', {2}, {3, 8}); - auto indices = NDArrayFactory::empty(); - auto codes = NDArrayFactory::empty(); - auto syn0 = NDArrayFactory::create('c', {100, 10}); - auto syn1Neg = NDArrayFactory::create('c', {100, 10}); - auto syn1 = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::create('c', {100000}); - - auto alpha = NDArrayFactory::create('c', {2}, {0.001, 0.024}); - auto randomValue = NDArrayFactory::create('c', {2}, {1L, 3L}); - auto inferenceVector = NDArrayFactory::empty(); - auto neu1e = NDArrayFactory::create('c', {2, 10}); - - syn0.assign(0.01); - syn1.assign(0.02); - expTable.assign(0.5); - negTable.linspace(0.0); - - sd::ops::skipgram op; - auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01f); + exp1.assign(0.020005f); + exp2.assign(0.019995f); + + auto target = NDArrayFactory::create('c', {2}, {0, 5}); + auto ngStarter = NDArrayFactory::create('c', {2}, {3, 8}); + auto indices = NDArrayFactory::empty(); + auto codes = NDArrayFactory::empty(); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::create('c', {100000}); + + auto alpha = NDArrayFactory::create('c', {2}, {0.001, 0.024}); + auto randomValue = NDArrayFactory::create('c', {2}, {1L, 3L}); + auto inferenceVector = NDArrayFactory::empty(); + auto neu1e = NDArrayFactory::create('c', {2, 10}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + negTable.linspace(0.0); + + sd::ops::skipgram op; + auto result = op.evaluate( + {&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, + &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, + {}, {4, 5}, {false, true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); } TEST_F(NlpTests, test_cbow_hs_batch_1) { #ifdef __CUDABLAS__ - return ; + return; #endif - auto target = NDArrayFactory::create(0); - auto ngStarter = NDArrayFactory::empty(); - auto context = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 100, 101, 102}); - auto locked = NDArrayFactory::create('c', {2, 3}); - auto indices = NDArrayFactory::create('c', {2, 2}, {4, 5, 40, 50}); - auto codes = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); - auto syn0 = NDArrayFactory::create('c', {244, 10}); - auto syn1 = NDArrayFactory::create('c', {244, 10}); - auto syn1Neg = NDArrayFactory::empty(); - auto expTable = NDArrayFactory::create('c', {10000}); - auto negTable = NDArrayFactory::empty(); - auto numWords = NDArrayFactory::create('c', {2}, {1, 2}); - - syn0.assign(0.01); - syn1.assign(0.02); - expTable.assign(0.5); - - auto alpha = NDArrayFactory::create('c', {2}, {0.025, 0.025}); - auto randomValue = NDArrayFactory::create('c', {2}, {2L, 2L}); - auto inferenceVector = NDArrayFactory::empty(); - - sd::ops::cbow op; - auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true); - ASSERT_EQ(Status::OK(), result.status()); - - auto exp0 = NDArrayFactory::create('c', {1, 10}); - auto exp1 = NDArrayFactory::create('c', {1, 10}); - auto exp2 = NDArrayFactory::create('c', {1, 10}); - - exp0.assign(0.0095f); - exp1.assign(0.019875f); - exp2.assign(0.02f); - - auto row_s0_0 = syn0({0,1, 0,0}, true); - auto row_s0_1 = syn0({1,2, 0,0}, true); - auto row_s0_2 = syn0({2,3, 0,0}, true); - - auto row_s1_4 = syn1({4,5, 0,0}, true); - auto row_s1_5 = syn1({5,6, 0,0}, true); - auto row_s1_6 = syn1({6,7, 0,0}, true); - - ASSERT_EQ(exp0, row_s0_0); - ASSERT_EQ(exp0, row_s0_1); - ASSERT_EQ(exp0, row_s0_2); - ASSERT_EQ(exp1, row_s1_4); - ASSERT_EQ(exp1, row_s1_5); - ASSERT_EQ(exp2, row_s1_6); - + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto context = + NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 100, 101, 102}); + auto locked = NDArrayFactory::create('c', {2, 3}); + auto indices = NDArrayFactory::create('c', {2, 2}, {4, 5, 40, 50}); + auto codes = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); + auto syn0 = NDArrayFactory::create('c', {244, 10}); + auto syn1 = NDArrayFactory::create('c', {244, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto numWords = NDArrayFactory::create('c', {2}, {1, 2}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create('c', {2}, {0.025, 0.025}); + auto randomValue = NDArrayFactory::create('c', {2}, {2L, 2L}); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::cbow op; + auto result = + op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, + &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, + &numWords, &locked, &inferenceVector}, + {}, {}, {true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.0095f); + exp1.assign(0.019875f); + exp2.assign(0.02f); + + auto row_s0_0 = syn0({0, 1, 0, 0}, true); + auto row_s0_1 = syn0({1, 2, 0, 0}, true); + auto row_s0_2 = syn0({2, 3, 0, 0}, true); + + auto row_s1_4 = syn1({4, 5, 0, 0}, true); + auto row_s1_5 = syn1({5, 6, 0, 0}, true); + auto row_s1_6 = syn1({6, 7, 0, 0}, true); + + ASSERT_EQ(exp0, row_s0_0); + ASSERT_EQ(exp0, row_s0_1); + ASSERT_EQ(exp0, row_s0_2); + ASSERT_EQ(exp1, row_s1_4); + ASSERT_EQ(exp1, row_s1_5); + ASSERT_EQ(exp2, row_s1_6); } diff --git a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp index 71cf5eb622da..f180bb24c038 100644 --- a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp @@ -18,56 +18,55 @@ // Created by raver119 on 21.02.18. // -#include "testlayers.h" #include -#include #include +#include #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class NodeTests : public testing::Test { -public: - + public: }; TEST_F(NodeTests, Test_Dtype_Conversion_1) { - auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}, {2}); + auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}, {2}); - auto nd = nodeA->asT(); - auto nf = nd->asT(); + auto nd = nodeA->asT(); + auto nf = nd->asT(); - ASSERT_EQ(nodeA->id(), nf->id()); - ASSERT_EQ(nodeA->name(), nf->name()); - ASSERT_EQ(nodeA->getOpClass(), nf->getOpClass()); - ASSERT_EQ(nodeA->opType(), nf->opType()); - ASSERT_EQ(nodeA->opNum(), nf->opNum()); + ASSERT_EQ(nodeA->id(), nf->id()); + ASSERT_EQ(nodeA->name(), nf->name()); + ASSERT_EQ(nodeA->getOpClass(), nf->getOpClass()); + ASSERT_EQ(nodeA->opType(), nf->opType()); + ASSERT_EQ(nodeA->opNum(), nf->opNum()); - delete nodeA; - delete nd; - delete nf; + delete nodeA; + delete nd; + delete nf; } - TEST_F(NodeTests, Test_Dtype_Conversion_2) { - sd::ops::add opA; + sd::ops::add opA; - //auto nodeA = new Node(OpType_CUSTOM, 0, 1, {-1}, {2}); - auto nodeA = new Node(&opA, 1, {-1}, {2}); - //nodeA->setCustomOp(&op); + // auto nodeA = new Node(OpType_CUSTOM, 0, 1, {-1}, {2}); + auto nodeA = new Node(&opA, 1, {-1}, {2}); + // nodeA->setCustomOp(&op); - auto nd = nodeA->asT(); - auto nf = nd->asT(); + auto nd = nodeA->asT(); + auto nf = nd->asT(); - ASSERT_EQ(nodeA->id(), nf->id()); - ASSERT_EQ(nodeA->name(), nf->name()); -// ASSERT_EQ(nodeA->getOpClass(), nf->getOpClass()); - ASSERT_EQ(nodeA->opType(), nf->opType()); - ASSERT_EQ(nodeA->opNum(), nf->opNum()); - ASSERT_EQ(nodeA->customOp()->getOpHash(), nf->customOp()->getOpHash()); + ASSERT_EQ(nodeA->id(), nf->id()); + ASSERT_EQ(nodeA->name(), nf->name()); + // ASSERT_EQ(nodeA->getOpClass(), nf->getOpClass()); + ASSERT_EQ(nodeA->opType(), nf->opType()); + ASSERT_EQ(nodeA->opNum(), nf->opNum()); + ASSERT_EQ(nodeA->customOp()->getOpHash(), nf->customOp()->getOpHash()); - delete nodeA; - delete nd; - delete nf; + delete nodeA; + delete nd; + delete nf; } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/OmpLaunchHelperTests.cpp b/libnd4j/tests_cpu/layers_tests/OmpLaunchHelperTests.cpp index a7c7eae24384..b62bf3fdd96f 100644 --- a/libnd4j/tests_cpu/layers_tests/OmpLaunchHelperTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OmpLaunchHelperTests.cpp @@ -18,106 +18,107 @@ // Created by raver119 on 30.06.18. // -#include "testlayers.h" #include #include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class OmpLaunchHelperTests : public testing::Test { -private: - int ewt = 0; -public: - OmpLaunchHelperTests() { - this->ewt = Environment::getInstance()->elementwiseThreshold(); - Environment::getInstance()->setElementwiseThreshold(1000); - }; - - ~OmpLaunchHelperTests() { - Environment::getInstance()->setElementwiseThreshold(this->ewt); - } + private: + int ewt = 0; + + public: + OmpLaunchHelperTests() { + this->ewt = Environment::getInstance()->elementwiseThreshold(); + Environment::getInstance()->setElementwiseThreshold(1000); + }; + + ~OmpLaunchHelperTests() { + Environment::getInstance()->setElementwiseThreshold(this->ewt); + } }; TEST_F(OmpLaunchHelperTests, Test_BetterSpan_1) { - auto span = OmpLaunchHelper::betterSpan(1000, 4); - ASSERT_EQ(250, span); + auto span = OmpLaunchHelper::betterSpan(1000, 4); + ASSERT_EQ(250, span); } TEST_F(OmpLaunchHelperTests, Test_BetterSpan_2) { - auto span = OmpLaunchHelper::betterSpan(1001, 4); - ASSERT_EQ(251, span); + auto span = OmpLaunchHelper::betterSpan(1001, 4); + ASSERT_EQ(251, span); } TEST_F(OmpLaunchHelperTests, Test_BetterSpan_3) { - auto span = OmpLaunchHelper::betterSpan(1002, 4); - ASSERT_EQ(251, span); + auto span = OmpLaunchHelper::betterSpan(1002, 4); + ASSERT_EQ(251, span); } TEST_F(OmpLaunchHelperTests, Test_BetterSpan_5) { - auto span = OmpLaunchHelper::betterSpan(1003, 4); - ASSERT_EQ(251, span); + auto span = OmpLaunchHelper::betterSpan(1003, 4); + ASSERT_EQ(251, span); } TEST_F(OmpLaunchHelperTests, Test_BetterSpan_6) { - auto span = OmpLaunchHelper::betterSpan(1004, 4); - ASSERT_EQ(251, span); + auto span = OmpLaunchHelper::betterSpan(1004, 4); + ASSERT_EQ(251, span); } - TEST_F(OmpLaunchHelperTests, Test_BetterThreads_1) { - auto n = OmpLaunchHelper::betterThreads(4000, 6); - ASSERT_EQ(4, n); + auto n = OmpLaunchHelper::betterThreads(4000, 6); + ASSERT_EQ(4, n); } TEST_F(OmpLaunchHelperTests, Test_BetterThreads_2) { - auto n = OmpLaunchHelper::betterThreads(12000, 6); - ASSERT_EQ(6, n); + auto n = OmpLaunchHelper::betterThreads(12000, 6); + ASSERT_EQ(6, n); } TEST_F(OmpLaunchHelperTests, Test_BetterThreads_3) { - auto n = OmpLaunchHelper::betterThreads(899, 6); - ASSERT_EQ(1, n); + auto n = OmpLaunchHelper::betterThreads(899, 6); + ASSERT_EQ(1, n); } TEST_F(OmpLaunchHelperTests, test_tad_threads_1) { - Nd4jLong numTads = 16; - Nd4jLong tadLength = 16; + Nd4jLong numTads = 16; + Nd4jLong tadLength = 16; -// nd4j_printf("TT: [%i]; ET: [%i];\n", Environment::getInstance()->tadThreshold(), Environment::getInstance()->elementwiseThreshold()); - ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads)); + // nd4j_printf("TT: [%i]; ET: [%i];\n", + // Environment::getInstance()->tadThreshold(), + // Environment::getInstance()->elementwiseThreshold()); + ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads)); } TEST_F(OmpLaunchHelperTests, test_tad_threads_2) { - if (omp_get_max_threads() <= 1) - return; + if (omp_get_max_threads() <= 1) return; - Nd4jLong numTads = 2; - Nd4jLong tadLength = Environment::getInstance()->elementwiseThreshold(); + Nd4jLong numTads = 2; + Nd4jLong tadLength = Environment::getInstance()->elementwiseThreshold(); - ASSERT_EQ(2, OmpLaunchHelper::tadThreads(tadLength, numTads)); + ASSERT_EQ(2, OmpLaunchHelper::tadThreads(tadLength, numTads)); } TEST_F(OmpLaunchHelperTests, test_tad_threads_3) { - Nd4jLong numTads = 2; - Nd4jLong tadLength = 128; + Nd4jLong numTads = 2; + Nd4jLong tadLength = 128; - ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads)); + ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads)); } TEST_F(OmpLaunchHelperTests, test_tad_threads_4) { - Nd4jLong numTads = 4; - Nd4jLong tadLength = 64; + Nd4jLong numTads = 4; + Nd4jLong tadLength = 64; - ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads)); + ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads)); } TEST_F(OmpLaunchHelperTests, test_tad_threads_5) { - auto exp = omp_get_max_threads(); + auto exp = omp_get_max_threads(); - Nd4jLong numTads = exp; - Nd4jLong tadLength = Environment::getInstance()->elementwiseThreshold(); + Nd4jLong numTads = exp; + Nd4jLong tadLength = Environment::getInstance()->elementwiseThreshold(); - ASSERT_EQ(exp, OmpLaunchHelper::tadThreads(tadLength, numTads)); + ASSERT_EQ(exp, OmpLaunchHelper::tadThreads(tadLength, numTads)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index 8f465d25f249..ea670bafaced 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -18,39 +18,40 @@ // Created by raver119 on 11.10.2017. // -#include "testlayers.h" -#include -#include -#include -#include +#include #include #include -#include +#include +#include +#include + +#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; class OneOffTests : public testing::Test { -public: - + public: }; TEST_F(OneOffTests, test_avg_pool_3d_1) { - auto graph = Graph::fromFlatBuffers("./resources/avg_pooling3d.fb"); + auto graph = Graph::fromFlatBuffers("./resources/avg_pooling3d.fb"); - graph.execute(); + graph.execute(); } TEST_F(OneOffTests, test_avg_pool_3d_2) { - auto graph = Graph::fromFlatBuffers("./resources/avg_pooling3d.fb"); + auto graph = Graph::fromFlatBuffers("./resources/avg_pooling3d.fb"); - graph.execute(); + graph.execute(); } TEST_F(OneOffTests, test_non2d_0A_1) { - auto graph = Graph::fromFlatBuffers("./resources/non2d_0A.fb"); + auto graph = Graph::fromFlatBuffers("./resources/non2d_0A.fb"); - graph.execute(); + graph.execute(); } /* @@ -70,38 +71,44 @@ TEST_F(OneOffTests, test_assert_scalar_float32_1) { }*/ TEST_F(OneOffTests, test_assert_scalar_float32_2) { - sd::ops::Assert op; - sd::ops::identity op1; - sd::ops::noop op2; - auto graph = Graph::fromFlatBuffers("./resources/assertsomething.fb"); + sd::ops::Assert op; + sd::ops::identity op1; + sd::ops::noop op2; + auto graph = Graph::fromFlatBuffers("./resources/assertsomething.fb"); - graph.printOut(); + graph.printOut(); - //graph.execute(); + // graph.execute(); } - TEST_F(OneOffTests, test_pad_1D_1) { - auto e = NDArrayFactory::create('c', {7}, {10.f,0.778786f, 0.801198f, 0.724375f, 0.230894f, 0.727141f,10.f}); - auto graph = Graph::fromFlatBuffers("./resources/pad_1D.fb"); + auto e = NDArrayFactory::create( + 'c', {7}, + {10.f, 0.778786f, 0.801198f, 0.724375f, 0.230894f, 0.727141f, 10.f}); + auto graph = Graph::fromFlatBuffers("./resources/pad_1D.fb"); - graph.execute(); + graph.execute(); - ASSERT_TRUE(graph.variableSpace().hasVariable(4)); + ASSERT_TRUE(graph.variableSpace().hasVariable(4)); - auto z = graph.variableSpace().getVariable(4)->getNDArray(); - ASSERT_TRUE(z != nullptr); + auto z = graph.variableSpace().getVariable(4)->getNDArray(); + ASSERT_TRUE(z != nullptr); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, *z); } /* TEST_F(OneOffTests, test_scatter_nd_update_1) { - auto e = NDArrayFactory::create('c', {10, 7}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.20446908f, 0.37918627f, 0.99792874f, 0.71881700f, 0.18677747f, - 0.78299069f, 0.55216062f, 0.40746713f, 0.92128086f, 0.57195139f, 0.44686234f, 0.30861020f, 0.31026053f, 0.09293187f, - 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.95073712f, 0.45613325f, 0.95149803f, 0.88341522f, 0.54366302f, 0.50060666f, 0.39031255f, - 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, - 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {10, 7}, +{1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.20446908f, 0.37918627f, 0.99792874f, +0.71881700f, 0.18677747f, 0.78299069f, 0.55216062f, 0.40746713f, 0.92128086f, +0.57195139f, 0.44686234f, 0.30861020f, 0.31026053f, 0.09293187f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, +1.f, 0.95073712f, 0.45613325f, 0.95149803f, 0.88341522f, 0.54366302f, +0.50060666f, 0.39031255f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, +1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, +1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); auto graph = Graph::fromFlatBuffers("./resources/scatter_nd_update.fb"); ASSERT_TRUE(graph != nullptr); @@ -125,107 +132,150 @@ TEST_F(OneOffTests, test_scatter_nd_update_1) { */ TEST_F(OneOffTests, test_conv2d_nhwc_failed_1) { - auto e = NDArrayFactory::create('c', {1, 5, 5, 6}, {0.55744928f, 0.76827729f, 1.09401524f, 0.00000000f, 0.00000000f, 0.00000000f, 0.56373537f, 0.90029907f, 0.78997850f, 0.00000000f, 0.00000000f, 0.00000000f, 0.14252824f, 0.95961076f, 0.87750554f, 0.00000000f, 0.00000000f, 0.00000000f, 0.44874173f, 0.99537718f, 1.17154264f, 0.00000000f, 0.00000000f, 0.00000000f, 0.60377145f, 0.79939061f, 0.56031001f, 0.00000000f, 0.00000000f, 0.00000000f, 0.52975273f, 0.90678585f, 0.73763013f, 0.00000000f, 0.00000000f, 0.00000000f, 0.22146404f, 0.82499605f, 0.47222072f, 0.00000000f, 0.00000000f, 0.00000000f, 0.42772964f, 0.39793295f, 0.71436501f, 0.00000000f, 0.00000000f, 0.00000000f, 0.48836520f, 1.01658893f, 0.74419701f, 0.00000000f, 0.00000000f, 0.00000000f, 0.78984612f, 0.94083673f, 0.83841157f, 0.00000000f, 0.00000000f, 0.00000000f, 0.40448499f, 0.67732805f, 0.75499672f, 0.00000000f, 0.00000000f, 0.00000000f, 0.43675962f, 0.79476535f, 0.72976631f, 0.00000000f, 0.00000000f, 0.00000000f, 0.58808053f, 0.65222591f, 0.72552216f, 0.00000000f, 0.00000000f, 0.00000000f, 0.37445742f, 1.22581339f, 1.05341125f, 0.00000000f, 0.00000000f, 0.00000000f, 0.30095795f, 0.59941679f, 0.63323414f, 0.00000000f, 0.00000000f, 0.00000000f, 0.24199286f, 1.02546394f, 0.69537812f, 0.00000000f, 0.00000000f, 0.00000000f, 0.23628944f, 0.90791851f, 1.01209974f, 0.00000000f, 0.00000000f, 0.00000000f, 0.62740159f, 0.56518674f, 0.76692569f, 0.00000000f, 0.00000000f, 0.00000000f, 0.13327584f, 0.32628393f, 0.10280430f, 0.00000000f, 0.00000000f, 0.00000000f, 0.42691272f, 0.25625113f, 0.30524066f, 0.00000000f, 0.00000000f, 0.00000000f, 0.17797673f, 0.84179950f, 0.80061519f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00199084f, 0.51838887f, 0.43932241f, 0.00000000f, 0.00000000f, 0.00000000f, 0.16684581f, 0.50822425f, 0.48668745f, 0.00000000f, 0.00000000f, 0.00000000f, 0.16749343f, 0.93093169f, 0.86871749f, 0.00000000f, 0.00000000f, 0.00000000f, 0.17486368f, 0.44460732f, 0.44499981f, 0.00000000f, 0.00000000f, 0.00000000f}); - - auto graph = Graph::fromFlatBuffers("./resources/channels_last_b1_k2_s1_d1_SAME_crelu.fb"); - - graph.execute(); - - ASSERT_TRUE(graph.variableSpace().hasVariable(9)); - - auto z = graph.variableSpace().getVariable(9)->getNDArray(); - ASSERT_TRUE(z != nullptr); - - ASSERT_EQ(e, *z); + auto e = NDArrayFactory::create( + 'c', {1, 5, 5, 6}, + {0.55744928f, 0.76827729f, 1.09401524f, 0.00000000f, 0.00000000f, + 0.00000000f, 0.56373537f, 0.90029907f, 0.78997850f, 0.00000000f, + 0.00000000f, 0.00000000f, 0.14252824f, 0.95961076f, 0.87750554f, + 0.00000000f, 0.00000000f, 0.00000000f, 0.44874173f, 0.99537718f, + 1.17154264f, 0.00000000f, 0.00000000f, 0.00000000f, 0.60377145f, + 0.79939061f, 0.56031001f, 0.00000000f, 0.00000000f, 0.00000000f, + 0.52975273f, 0.90678585f, 0.73763013f, 0.00000000f, 0.00000000f, + 0.00000000f, 0.22146404f, 0.82499605f, 0.47222072f, 0.00000000f, + 0.00000000f, 0.00000000f, 0.42772964f, 0.39793295f, 0.71436501f, + 0.00000000f, 0.00000000f, 0.00000000f, 0.48836520f, 1.01658893f, + 0.74419701f, 0.00000000f, 0.00000000f, 0.00000000f, 0.78984612f, + 0.94083673f, 0.83841157f, 0.00000000f, 0.00000000f, 0.00000000f, + 0.40448499f, 0.67732805f, 0.75499672f, 0.00000000f, 0.00000000f, + 0.00000000f, 0.43675962f, 0.79476535f, 0.72976631f, 0.00000000f, + 0.00000000f, 0.00000000f, 0.58808053f, 0.65222591f, 0.72552216f, + 0.00000000f, 0.00000000f, 0.00000000f, 0.37445742f, 1.22581339f, + 1.05341125f, 0.00000000f, 0.00000000f, 0.00000000f, 0.30095795f, + 0.59941679f, 0.63323414f, 0.00000000f, 0.00000000f, 0.00000000f, + 0.24199286f, 1.02546394f, 0.69537812f, 0.00000000f, 0.00000000f, + 0.00000000f, 0.23628944f, 0.90791851f, 1.01209974f, 0.00000000f, + 0.00000000f, 0.00000000f, 0.62740159f, 0.56518674f, 0.76692569f, + 0.00000000f, 0.00000000f, 0.00000000f, 0.13327584f, 0.32628393f, + 0.10280430f, 0.00000000f, 0.00000000f, 0.00000000f, 0.42691272f, + 0.25625113f, 0.30524066f, 0.00000000f, 0.00000000f, 0.00000000f, + 0.17797673f, 0.84179950f, 0.80061519f, 0.00000000f, 0.00000000f, + 0.00000000f, 0.00199084f, 0.51838887f, 0.43932241f, 0.00000000f, + 0.00000000f, 0.00000000f, 0.16684581f, 0.50822425f, 0.48668745f, + 0.00000000f, 0.00000000f, 0.00000000f, 0.16749343f, 0.93093169f, + 0.86871749f, 0.00000000f, 0.00000000f, 0.00000000f, 0.17486368f, + 0.44460732f, 0.44499981f, 0.00000000f, 0.00000000f, 0.00000000f}); + + auto graph = Graph::fromFlatBuffers( + "./resources/channels_last_b1_k2_s1_d1_SAME_crelu.fb"); + + graph.execute(); + + ASSERT_TRUE(graph.variableSpace().hasVariable(9)); + + auto z = graph.variableSpace().getVariable(9)->getNDArray(); + ASSERT_TRUE(z != nullptr); + + ASSERT_EQ(e, *z); } TEST_F(OneOffTests, test_tensor_array_1) { - auto e = NDArrayFactory::create('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f}); + auto e = + NDArrayFactory::create('c', {2, 3}, + {0.77878559f, 0.80119777f, 0.72437465f, + 0.23089433f, 0.72714126f, 0.18039072f}); - auto graph = Graph::fromFlatBuffers("./resources/tensor_array_close_sz1_float32_nodynamic_noname_noshape.fb"); + auto graph = Graph::fromFlatBuffers( + "./resources/tensor_array_close_sz1_float32_nodynamic_noname_noshape.fb"); - graph.execute(); + graph.execute(); - ASSERT_TRUE(graph.variableSpace().hasVariable(5)); + ASSERT_TRUE(graph.variableSpace().hasVariable(5)); - auto z = graph.variableSpace().getVariable(5)->getNDArray(); - ASSERT_TRUE(z != nullptr); + auto z = graph.variableSpace().getVariable(5)->getNDArray(); + ASSERT_TRUE(z != nullptr); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, *z); } TEST_F(OneOffTests, test_tensor_array_2) { - auto e = NDArrayFactory::create('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f}); + auto e = + NDArrayFactory::create('c', {2, 3}, + {0.77878559f, 0.80119777f, 0.72437465f, + 0.23089433f, 0.72714126f, 0.18039072f}); - auto graph = Graph::fromFlatBuffers("./resources/tensor_array_split_sz1_float32_nodynamic_noname_noshape.fb"); + auto graph = Graph::fromFlatBuffers( + "./resources/tensor_array_split_sz1_float32_nodynamic_noname_noshape.fb"); - graph.execute(); + graph.execute(); - ASSERT_TRUE(graph.variableSpace().hasVariable(6)); + ASSERT_TRUE(graph.variableSpace().hasVariable(6)); - auto z = graph.variableSpace().getVariable(6)->getNDArray(); - ASSERT_TRUE(z != nullptr); + auto z = graph.variableSpace().getVariable(6)->getNDArray(); + ASSERT_TRUE(z != nullptr); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, *z); } TEST_F(OneOffTests, test_tensor_array_3) { - if (1 > 0) - throw std::runtime_error("This test crashes"); + if (1 > 0) throw std::runtime_error("This test crashes"); - auto e = NDArrayFactory::create('c', {3, 2, 3}, {7, 2, 9, 4, 3, 3, 8, 7, 0, 0, 6, 8, 7, 9, 0, 1, 1, 4}); + auto e = NDArrayFactory::create( + 'c', {3, 2, 3}, {7, 2, 9, 4, 3, 3, 8, 7, 0, 0, 6, 8, 7, 9, 0, 1, 1, 4}); - auto graph = Graph::fromFlatBuffers("./resources/tensor_array_stack_sz3-1_int32_dynamic_name_shape.fb"); + auto graph = Graph::fromFlatBuffers( + "./resources/tensor_array_stack_sz3-1_int32_dynamic_name_shape.fb"); - graph.execute(); + graph.execute(); - ASSERT_TRUE(graph.variableSpace().hasVariable(15)); + ASSERT_TRUE(graph.variableSpace().hasVariable(15)); - auto z = graph.variableSpace().getVariable(15)->getNDArray(); - ASSERT_TRUE(z != nullptr); + auto z = graph.variableSpace().getVariable(15)->getNDArray(); + ASSERT_TRUE(z != nullptr); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, *z); } TEST_F(OneOffTests, test_tensor_array_4) { - auto e = NDArrayFactory::create('c', {2, 3}, {4, 3, 1, 1, 1, 0}); + auto e = NDArrayFactory::create('c', {2, 3}, {4, 3, 1, 1, 1, 0}); - auto graph = Graph::fromFlatBuffers("./resources/tensor_array_unstack_sz1_int64_nodynamic_noname_shape2-3.fb"); + auto graph = Graph::fromFlatBuffers( + "./resources/" + "tensor_array_unstack_sz1_int64_nodynamic_noname_shape2-3.fb"); - graph.execute(); + graph.execute(); - ASSERT_TRUE(graph.variableSpace().hasVariable(11)); + ASSERT_TRUE(graph.variableSpace().hasVariable(11)); - auto z = graph.variableSpace().getVariable(11)->getNDArray(); - ASSERT_TRUE(z != nullptr); + auto z = graph.variableSpace().getVariable(11)->getNDArray(); + ASSERT_TRUE(z != nullptr); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, *z); } TEST_F(OneOffTests, test_assert_4) { - auto e = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); + auto e = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); - auto graph = Graph::fromFlatBuffers("./resources/assert_type_rank2_int64.fb"); + auto graph = Graph::fromFlatBuffers("./resources/assert_type_rank2_int64.fb"); - graph.execute(); + graph.execute(); - ASSERT_TRUE(graph.variableSpace().hasVariable(1)); + ASSERT_TRUE(graph.variableSpace().hasVariable(1)); - auto z = graph.variableSpace().getVariable(1)->getNDArray(); - ASSERT_TRUE(z != nullptr); + auto z = graph.variableSpace().getVariable(1)->getNDArray(); + ASSERT_TRUE(z != nullptr); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, *z); } // TEST_F(OneOffTests, test_cond_true_1) { -// auto e = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); +// auto e = NDArrayFactory::create('c', {5}, +// {1.f, 2.f, 3.f, 4.f, 5.f}); // auto graph = Graph::fromFlatBuffers("./resources/cond_true.fb"); // ASSERT_TRUE(graph != nullptr); // graph->printOut(); - // Nd4jStatus status = GraphExecutioner::execute(graph); // ASSERT_EQ(Status::OK(), status); // ASSERT_TRUE(graph->variableSpace()->hasVariable(6)); @@ -266,55 +316,60 @@ TEST_F(OneOffTests, test_cond_false_1) { */ TEST_F(OneOffTests, test_identity_n_2) { - auto e = NDArrayFactory::create('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f}); + auto e = + NDArrayFactory::create('c', {2, 3}, + {0.77878559f, 0.80119777f, 0.72437465f, + 0.23089433f, 0.72714126f, 0.18039072f}); - sd::ops::identity_n op; + sd::ops::identity_n op; - auto graph = Graph::fromFlatBuffers("./resources/identity_n_2.fb"); + auto graph = Graph::fromFlatBuffers("./resources/identity_n_2.fb"); - graph.execute(); + graph.execute(); - ASSERT_TRUE(graph.variableSpace().hasVariable(1)); - ASSERT_TRUE(graph.variableSpace().hasVariable(1, 1)); + ASSERT_TRUE(graph.variableSpace().hasVariable(1)); + ASSERT_TRUE(graph.variableSpace().hasVariable(1, 1)); - auto z = graph.variableSpace().getVariable(1)->getNDArray(); - ASSERT_TRUE(z != nullptr); + auto z = graph.variableSpace().getVariable(1)->getNDArray(); + ASSERT_TRUE(z != nullptr); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, *z); } TEST_F(OneOffTests, test_non2d_1) { - //if (1 > 0) - // throw std::runtime_error("Test not implemented yet"); + // if (1 > 0) + // throw std::runtime_error("Test not implemented yet"); - auto e = NDArrayFactory::create('c', {1, 1}, {5.42746449f}); + auto e = NDArrayFactory::create('c', {1, 1}, {5.42746449f}); - auto graph = Graph::fromFlatBuffers("./resources/non2d_1.fb"); + auto graph = Graph::fromFlatBuffers("./resources/non2d_1.fb"); - graph.execute(); + graph.execute(); - ASSERT_TRUE(graph.variableSpace().hasVariable(3)); + ASSERT_TRUE(graph.variableSpace().hasVariable(3)); - auto z = graph.variableSpace().getVariable(3)->getNDArray(); - ASSERT_TRUE(z != nullptr); + auto z = graph.variableSpace().getVariable(3)->getNDArray(); + ASSERT_TRUE(z != nullptr); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, *z); } TEST_F(OneOffTests, test_reduce_all_1) { - auto e = NDArrayFactory::create('c', {1, 4}, {true, false, false, false}); + auto e = + NDArrayFactory::create('c', {1, 4}, {true, false, false, false}); - auto graph = Graph::fromFlatBuffers("./resources/reduce_all_rank2_d0_keep.fb"); + auto graph = + Graph::fromFlatBuffers("./resources/reduce_all_rank2_d0_keep.fb"); - graph.execute(); + graph.execute(); - ASSERT_TRUE(graph.variableSpace().hasVariable(1)); + ASSERT_TRUE(graph.variableSpace().hasVariable(1)); - ASSERT_TRUE(graph.variableSpace().hasVariable(2)); - auto in = graph.variableSpace().getVariable(2)->getNDArray(); + ASSERT_TRUE(graph.variableSpace().hasVariable(2)); + auto in = graph.variableSpace().getVariable(2)->getNDArray(); - auto z = graph.variableSpace().getVariable(1)->getNDArray(); - ASSERT_TRUE(z != nullptr); + auto z = graph.variableSpace().getVariable(1)->getNDArray(); + ASSERT_TRUE(z != nullptr); - ASSERT_EQ(e, *z); + ASSERT_EQ(e, *z); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp index c4dbc2fcc2be..0c0f795e03b5 100644 --- a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp @@ -18,62 +18,62 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include -#include #include +#include +#include #include #include -#include -#include + +#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; using namespace sd::graph; class OpSequenceTests : public testing::Test { -public: - - OpSequenceTests() { - } + public: + OpSequenceTests() {} }; TEST_F(OpSequenceTests, test_iterator_1) { - Graph graph; - OpSequence sequence; + Graph graph; + OpSequence sequence; - ASSERT_EQ(0, sequence.length()); + ASSERT_EQ(0, sequence.length()); - ops::add op1; - ops::multiply op2; + ops::add op1; + ops::multiply op2; - Context ctx1(1); - Context ctx2(2); + Context ctx1(1); + Context ctx2(2); - sequence.append(&op1, ctx1); - sequence.append(&op2, ctx2); + sequence.append(&op1, ctx1); + sequence.append(&op2, ctx2); - ASSERT_EQ(2, sequence.length()); + ASSERT_EQ(2, sequence.length()); - int cnt = 1; - for (const auto &v:sequence) { - ASSERT_EQ(cnt++, v.protoContext().nodeId()); - } + int cnt = 1; + for (const auto &v : sequence) { + ASSERT_EQ(cnt++, v.protoContext().nodeId()); + } - ASSERT_EQ(3, cnt); + ASSERT_EQ(3, cnt); - OptimizedGraph optimizedGraph; - ASSERT_EQ(0, optimizedGraph.layers()); + OptimizedGraph optimizedGraph; + ASSERT_EQ(0, optimizedGraph.layers()); - optimizedGraph.append(sequence); - ASSERT_EQ(1, optimizedGraph.layers()); + optimizedGraph.append(sequence); + ASSERT_EQ(1, optimizedGraph.layers()); - auto layer = optimizedGraph.layer(0); + auto layer = optimizedGraph.layer(0); - // we expect exactly 1 sequence in this layer - ASSERT_EQ(1, layer.width()); + // we expect exactly 1 sequence in this layer + ASSERT_EQ(1, layer.width()); - auto seq = layer[0]; + auto seq = layer[0]; - ASSERT_EQ(2, seq.length()); + ASSERT_EQ(2, seq.length()); } diff --git a/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp b/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp index 35e47788027a..f311f3355362 100644 --- a/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp @@ -17,51 +17,51 @@ // // Created by raver119 on 15.12.17. // -#include "testlayers.h" #include -#include #include #include #include +#include + +#include "testlayers.h" + using namespace sd; using namespace sd::ops; using namespace sd::graph; class OpTrackerTests : public testing::Test { -public: - int numIterations = 10; - int poolSize = 10; + public: + int numIterations = 10; + int poolSize = 10; - OpTrackerTests() { - } + OpTrackerTests() {} }; TEST_F(OpTrackerTests, Test_Existence_1) { - sd::_loader loader; + sd::_loader loader; - // nd4j_printf("Groups: %i; Operations: %i\n", OpTracker::getInstance()->totalGroups(), OpTracker::getInstance()->totalOperations()); + // nd4j_printf("Groups: %i; Operations: %i\n", + // OpTracker::getInstance()->totalGroups(), + // OpTracker::getInstance()->totalOperations()); - ASSERT_TRUE(OpTracker::getInstance()->totalGroups() > 0); - ASSERT_TRUE(OpTracker::getInstance()->totalOperations() > 0); + ASSERT_TRUE(OpTracker::getInstance()->totalGroups() > 0); + ASSERT_TRUE(OpTracker::getInstance()->totalOperations() > 0); - OpTracker::getInstance()->exportOperations(); + OpTracker::getInstance()->exportOperations(); } TEST_F(OpTrackerTests, Test_Ops_List_1) { - sd::ops::less op; - auto vec = OpRegistrator::getInstance()->getAllHashes(); + sd::ops::less op; + auto vec = OpRegistrator::getInstance()->getAllHashes(); - // nd4j_printf("Total ops: %lld\n", vec.size()); - // nd4j_printf("Less hash: %lld\n", op.getOpHash()); + // nd4j_printf("Total ops: %lld\n", vec.size()); + // nd4j_printf("Less hash: %lld\n", op.getOpHash()); - for (const auto &v: vec) { - if (v == 5484196977525668316L) { - auto op = OpRegistrator::getInstance()->getOperation(v); - // nd4j_printf("OpName: %s\n", op->getOpName()->c_str()); - } + for (const auto &v : vec) { + if (v == 5484196977525668316L) { + auto op = OpRegistrator::getInstance()->getOperation(v); + // nd4j_printf("OpName: %s\n", op->getOpName()->c_str()); } + } } - - - diff --git a/libnd4j/tests_cpu/layers_tests/OpTupleTests.cpp b/libnd4j/tests_cpu/layers_tests/OpTupleTests.cpp index bec75f056dfc..6c947c001444 100644 --- a/libnd4j/tests_cpu/layers_tests/OpTupleTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpTupleTests.cpp @@ -18,42 +18,39 @@ // Created by raver119 on 11.10.2017. // -#include "testlayers.h" #include #include +#include "testlayers.h" + using namespace sd; using namespace sd::ops; class OpTupleTests : public testing::Test { - public: + public: }; TEST_F(OpTupleTests, DirectConstructorTest1) { - auto alpha = NDArrayFactory::create_('c', {1, 2}); - auto beta = NDArrayFactory::create_('c', {1, 2}); - OpTuple tuple("dummy", {alpha, beta}, {12.0f}, {1,2, 3}); - - ASSERT_EQ("dummy", tuple._opName); - ASSERT_EQ(2, tuple._inputs.size()); - ASSERT_EQ(0, tuple._outputs.size()); - ASSERT_EQ(1, tuple._tArgs.size()); - ASSERT_EQ(3, tuple._iArgs.size()); + auto alpha = NDArrayFactory::create_('c', {1, 2}); + auto beta = NDArrayFactory::create_('c', {1, 2}); + OpTuple tuple("dummy", {alpha, beta}, {12.0f}, {1, 2, 3}); + + ASSERT_EQ("dummy", tuple._opName); + ASSERT_EQ(2, tuple._inputs.size()); + ASSERT_EQ(0, tuple._outputs.size()); + ASSERT_EQ(1, tuple._tArgs.size()); + ASSERT_EQ(3, tuple._iArgs.size()); } TEST_F(OpTupleTests, BuilderTest1) { - auto alpha = NDArrayFactory::create_('c', {1, 2}); - auto beta = NDArrayFactory::create_('c', {1, 2}); - OpTuple tuple("dummy"); - tuple.addInput(alpha) - ->addInput(beta) - ->setTArgs({12.0f}) - ->setIArgs({1, 2, 3}); - - - ASSERT_EQ("dummy", tuple._opName); - ASSERT_EQ(2, tuple._inputs.size()); - ASSERT_EQ(0, tuple._outputs.size()); - ASSERT_EQ(1, tuple._tArgs.size()); - ASSERT_EQ(3, tuple._iArgs.size()); + auto alpha = NDArrayFactory::create_('c', {1, 2}); + auto beta = NDArrayFactory::create_('c', {1, 2}); + OpTuple tuple("dummy"); + tuple.addInput(alpha)->addInput(beta)->setTArgs({12.0f})->setIArgs({1, 2, 3}); + + ASSERT_EQ("dummy", tuple._opName); + ASSERT_EQ(2, tuple._inputs.size()); + ASSERT_EQ(0, tuple._outputs.size()); + ASSERT_EQ(1, tuple._tArgs.size()); + ASSERT_EQ(3, tuple._iArgs.size()); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp b/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp index 0d73b369bae3..9a973bedfcb1 100644 --- a/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp @@ -17,34 +17,31 @@ // // Created by agibsonccc on 1/17/17. // -#include "testinclude.h" #include +#include "testinclude.h" + class EqualsTest : public testing::Test { -public: - const Nd4jLong firstShapeBuffer[8] = {2,1,2,1,1,0,1,102}; - float data[2] = {1.0f, 7.0f}; - const Nd4jLong secondShapeBuffer[8] = {2,2,1,6,1,0,6,99}; - float dataSecond[12] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; - int opNum = 4; - float extraArgs[1] = {1e-6f}; - int dimension[1] = {2147483647}; - int dimensionLength = 1; + public: + const Nd4jLong firstShapeBuffer[8] = {2, 1, 2, 1, 1, 0, 1, 102}; + float data[2] = {1.0f, 7.0f}; + const Nd4jLong secondShapeBuffer[8] = {2, 2, 1, 6, 1, 0, 6, 99}; + float dataSecond[12] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + int opNum = 4; + float extraArgs[1] = {1e-6f}; + int dimension[1] = {2147483647}; + int dimensionLength = 1; }; #ifndef __CUDABLAS__ -TEST_F(EqualsTest,Eps) { - auto val = sd::NDArrayFactory::create(0.0f); - functions::reduce3::Reduce3::execScalar(opNum, - data, - firstShapeBuffer, - extraArgs, - dataSecond, - secondShapeBuffer, - val.buffer(), - val.shapeInfo()); - ASSERT_TRUE(val.e(0) < 0.5); +TEST_F(EqualsTest, Eps) { + auto val = sd::NDArrayFactory::create(0.0f); + functions::reduce3::Reduce3::execScalar( + opNum, data, firstShapeBuffer, extraArgs, dataSecond, secondShapeBuffer, + val.buffer(), val.shapeInfo()); + ASSERT_TRUE(val.e(0) < 0.5); } #endif diff --git a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp index f9987a9d246c..f8dc57f087c2 100644 --- a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp @@ -18,1681 +18,1825 @@ // Created by raver119 on 12.10.2017. // -#include "testlayers.h" #include #include +#include "testlayers.h" using namespace sd; using namespace sd::ops; class ParityOpsTests : public testing::Test { -public: - + public: }; - TEST_F(ParityOpsTests, TestZeroAs1) { - auto x = NDArrayFactory::create('c', {10, 10}); - x.assign(1.0); + auto x = NDArrayFactory::create('c', {10, 10}); + x.assign(1.0); - auto exp = NDArrayFactory::create('c', {10, 10}); - exp.assign(0.0f); + auto exp = NDArrayFactory::create('c', {10, 10}); + exp.assign(0.0f); - sd::ops::zeros_as op; + sd::ops::zeros_as op; - auto result = op.evaluate({&x}, {}, {}); - - auto z = result.at(0); - - ASSERT_TRUE(z.isSameShape(&x)); - ASSERT_TRUE(z.equalsTo(&exp)); + auto result = op.evaluate({&x}, {}, {}); + auto z = result.at(0); + ASSERT_TRUE(z.isSameShape(&x)); + ASSERT_TRUE(z.equalsTo(&exp)); } TEST_F(ParityOpsTests, TestMaximum1) { - auto x = NDArrayFactory::create('c', {10, 10}); - x.assign(1.0); - - auto y = NDArrayFactory::create('c', {10, 10}); - y.assign(2.0); + auto x = NDArrayFactory::create('c', {10, 10}); + x.assign(1.0); - sd::ops::maximum op; + auto y = NDArrayFactory::create('c', {10, 10}); + y.assign(2.0); - auto result = op.evaluate({&x, &y}, {}, {}); + sd::ops::maximum op; - auto z = result.at(0); - - ASSERT_TRUE(y.equalsTo(z)); + auto result = op.evaluate({&x, &y}, {}, {}); + auto z = result.at(0); + ASSERT_TRUE(y.equalsTo(z)); } - TEST_F(ParityOpsTests, TestMinimum1) { - auto x = NDArrayFactory::create('c', {10, 10}); - x.assign(1.0f); - - auto y = NDArrayFactory::create('c', {10, 10}); - y.assign(-2.0f); - + auto x = NDArrayFactory::create('c', {10, 10}); + x.assign(1.0f); - sd::ops::minimum op; + auto y = NDArrayFactory::create('c', {10, 10}); + y.assign(-2.0f); - auto result = op.evaluate({&x, &y}, {}, {}); + sd::ops::minimum op; - auto z = result.at(0); - - ASSERT_TRUE(y.equalsTo(z)); + auto result = op.evaluate({&x, &y}, {}, {}); + auto z = result.at(0); + ASSERT_TRUE(y.equalsTo(z)); } TEST_F(ParityOpsTests, TestTear1) { - auto input = NDArrayFactory::create('c', {10, 5}); - auto tads = input.allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - ASSERT_EQ(5, tads.at(e).lengthOf()); - tads.at(e).assign((float) e + 1); - } - - sd::ops::tear op; - - auto result = op.evaluate({&input}, {}, {1}); + auto input = NDArrayFactory::create('c', {10, 5}); + auto tads = input.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + ASSERT_EQ(5, tads.at(e).lengthOf()); + tads.at(e).assign((float)e + 1); + } - ASSERT_EQ(10, result.size()); + sd::ops::tear op; - for (int e = 0; e < result.size(); e++) - ASSERT_TRUE(tads.at(e).equalsTo(result.at(e))); + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(10, result.size()); + for (int e = 0; e < result.size(); e++) + ASSERT_TRUE(tads.at(e).equalsTo(result.at(e))); } TEST_F(ParityOpsTests, TestUnstack1) { - auto input = NDArrayFactory::create('c', {10, 5}); - auto tads = input.allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - ASSERT_EQ(5, tads.at(e).lengthOf()); - tads.at(e).assign((float) e + 1); - } + auto input = NDArrayFactory::create('c', {10, 5}); + auto tads = input.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + ASSERT_EQ(5, tads.at(e).lengthOf()); + tads.at(e).assign((float)e + 1); + } - sd::ops::unstack op; + sd::ops::unstack op; - auto result = op.evaluate({&input}, {}, {0}); - - ASSERT_EQ(10, result.size()); - - for (int e = 0; e < result.size(); e++) - ASSERT_TRUE(tads.at(e).equalsTo(result.at(e))); + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(10, result.size()); + for (int e = 0; e < result.size(); e++) + ASSERT_TRUE(tads.at(e).equalsTo(result.at(e))); } - - TEST_F(ParityOpsTests, TestUnstack2) { - auto input = NDArrayFactory::create('c', {5,2,6}); - auto tads = input.allTensorsAlongDimension({0,1}); - for (int e = 0; e < tads.size(); e++) { - ASSERT_EQ(10, tads.at(e).lengthOf()); - tads.at(e).assign((float) e + 1); - } + auto input = NDArrayFactory::create('c', {5, 2, 6}); + auto tads = input.allTensorsAlongDimension({0, 1}); + for (int e = 0; e < tads.size(); e++) { + ASSERT_EQ(10, tads.at(e).lengthOf()); + tads.at(e).assign((float)e + 1); + } - sd::ops::unstack op; + sd::ops::unstack op; - auto result = op.evaluate({&input}, {}, {2}); - - ASSERT_EQ(6, result.size()); - - for (int e = 0; e < result.size(); e++) - ASSERT_TRUE(tads.at(e).equalsTo(result.at(e))); + auto result = op.evaluate({&input}, {}, {2}); + ASSERT_EQ(6, result.size()); + for (int e = 0; e < result.size(); e++) + ASSERT_TRUE(tads.at(e).equalsTo(result.at(e))); } TEST_F(ParityOpsTests, TestUnstack3) { - auto input = NDArrayFactory::create('c', {3,2,3}); - auto exp = NDArrayFactory::create('c', {3, 2}, {1.f, 4., 7., 10.f, 13.f, 16.f}); - input.linspace(1); - - sd::ops::unstack op; + auto input = NDArrayFactory::create('c', {3, 2, 3}); + auto exp = NDArrayFactory::create('c', {3, 2}, + {1.f, 4., 7., 10.f, 13.f, 16.f}); + input.linspace(1); - auto result = op.evaluate({&input}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::unstack op; - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ParityOpsTests, TestUnstack4) { - auto input = NDArrayFactory::create('c', {3,2,3}); - auto exp = NDArrayFactory::create('c', {3, 3}, { 1, 2, 3, 7, 8, 9, 13, 14, 15.}); - input.linspace(1); - - sd::ops::unstack op; - - auto result = op.evaluate({&input}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto input = NDArrayFactory::create('c', {3, 2, 3}); + auto exp = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, 7, 8, 9, 13, 14, 15.}); + input.linspace(1); - auto z = result.at(0); + sd::ops::unstack op; - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, TestUnstack5) { - auto input = NDArrayFactory::create('c', {3,2,3}); - auto exp = NDArrayFactory::create('c', {2, 3}, { 1, 2, 3, 4, 5, 6}); - input.linspace(1); + auto input = NDArrayFactory::create('c', {3, 2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + input.linspace(1); - sd::ops::unstack op; + sd::ops::unstack op; - auto result = op.evaluate({&input}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, TestUnstack6) { - auto input = NDArrayFactory::create('c', {1, 1, 1}); - auto exp = NDArrayFactory::create('c', {1, 1}, {1}); - input.linspace(1); - - sd::ops::unstack op; + auto input = NDArrayFactory::create('c', {1, 1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}, {1}); + input.linspace(1); - auto result = op.evaluate({&input}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::unstack op; - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, TestUnstack7) { - auto input = NDArrayFactory::create('c', {1, 1, 1}); - auto exp = NDArrayFactory::create('c', {1, 1}, {1}); - input.linspace(1); - - sd::ops::unstack op; + auto input = NDArrayFactory::create('c', {1, 1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}, {1}); + input.linspace(1); - auto result = op.evaluate({&input}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::unstack op; - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, TestUnstack8) { - auto input = NDArrayFactory::create('c', {1, 1}); - auto exp = NDArrayFactory::create('c', {1}, {1}); - input.linspace(1); - - sd::ops::unstack op; + auto input = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1}, {1}); + input.linspace(1); - auto result = op.evaluate({&input}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::unstack op; - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, TestUnstack9) { - auto input = NDArrayFactory::create('c', {1, 1}); - auto exp = NDArrayFactory::create('c', {1}, {1}); - input.linspace(1); - - sd::ops::unstack op; - - auto result = op.evaluate({&input}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto input = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1}, {1}); + input.linspace(1); - auto z = result.at(0); + sd::ops::unstack op; - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, TestUnstack10) { + auto input = NDArrayFactory::create('c', {3, 0, 2}); + auto exp = NDArrayFactory::create('c', {0, 2}); - auto input = NDArrayFactory::create('c', {3, 0, 2}); - auto exp = NDArrayFactory::create('c', {0,2}); - - sd::ops::unstack op; - - auto result = op.evaluate({&input}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.isSameShape(result.at(1))); - ASSERT_TRUE(exp.isSameShape(result.at(2))); + sd::ops::unstack op; + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.isSameShape(result.at(1))); + ASSERT_TRUE(exp.isSameShape(result.at(2))); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, TestUnstack11) { + auto input = NDArrayFactory::create('c', {3, 0, 2}); + auto exp = NDArrayFactory::create('c', {3, 0}); - auto input = NDArrayFactory::create('c', {3, 0, 2}); - auto exp = NDArrayFactory::create('c', {3,0}); - - sd::ops::unstack op; - - auto result = op.evaluate({&input}, {}, {2}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(exp.isSameShape(result.at(0))); - ASSERT_TRUE(exp.isSameShape(result.at(1))); + sd::ops::unstack op; + auto result = op.evaluate({&input}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.isSameShape(result.at(1))); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, TestUnstack12) { + auto input = NDArrayFactory::create('c', {3, 0, 2}); - auto input = NDArrayFactory::create('c', {3, 0, 2}); - - sd::ops::unstack op; - - auto result = op.evaluate({&input}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - ASSERT_TRUE(result.size() == 0); + sd::ops::unstack op; + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.size() == 0); } TEST_F(ParityOpsTests, TestUnstack13) { + auto x = NDArrayFactory::create('c', {2, 3}); - auto x = NDArrayFactory::create('c', {2, 3}); - - sd::ops::unstack op; - auto result = op.evaluate({&x}, {}, {1}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::unstack op; + auto result = op.evaluate({&x}, {}, {1}); - ASSERT_EQ(3, result.size()); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - for (int e = 0; e < 3; e++) - ASSERT_EQ(1, result.at(e).rankOf()); + ASSERT_EQ(3, result.size()); + for (int e = 0; e < 3; e++) ASSERT_EQ(1, result.at(e).rankOf()); } - - TEST_F(ParityOpsTests, ExpandDimsTest1) { - auto input = NDArrayFactory::create('c', {5, 5}); - input.linspace(1); - auto reshaped = input.reshape('c', {5, 1, 5}); - - sd::ops::expand_dims op; - auto result = op.evaluate({&input}, {}, {1}); + auto input = NDArrayFactory::create('c', {5, 5}); + input.linspace(1); + auto reshaped = input.reshape('c', {5, 1, 5}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::expand_dims op; + auto result = op.evaluate({&input}, {}, {1}); - auto z = result.at(0); - - ASSERT_TRUE(reshaped.isSameShape(z)); - ASSERT_TRUE(reshaped.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(reshaped.isSameShape(z)); + ASSERT_TRUE(reshaped.equalsTo(z)); } - TEST_F(ParityOpsTests, ExpandDimsTest2) { - auto input = NDArrayFactory::create('c', {3, 4}); - input.linspace(1); - auto reshaped = input.reshape('c', {1, 3, 4}); - - sd::ops::expand_dims op; - auto result = op.evaluate({&input}, {}, {0}); + auto input = NDArrayFactory::create('c', {3, 4}); + input.linspace(1); + auto reshaped = input.reshape('c', {1, 3, 4}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::expand_dims op; + auto result = op.evaluate({&input}, {}, {0}); - auto z = result.at(0); - - ASSERT_TRUE(reshaped.isSameShape(z)); - ASSERT_TRUE(reshaped.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(reshaped.isSameShape(z)); + ASSERT_TRUE(reshaped.equalsTo(z)); } - TEST_F(ParityOpsTests, ExpandDimsTest3) { - auto input = NDArrayFactory::create('c', {3, 4}); - input.linspace(1); - auto reshaped = input.reshape('c', {3, 1, 4}); - - sd::ops::expand_dims op; - auto result = op.evaluate({&input}, {}, {-2}); + auto input = NDArrayFactory::create('c', {3, 4}); + input.linspace(1); + auto reshaped = input.reshape('c', {3, 1, 4}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::expand_dims op; + auto result = op.evaluate({&input}, {}, {-2}); - auto z = result.at(0); - - ASSERT_TRUE(reshaped.isSameShape(z)); - ASSERT_TRUE(reshaped.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(reshaped.isSameShape(z)); + ASSERT_TRUE(reshaped.equalsTo(z)); } TEST_F(ParityOpsTests, ExpandDimsTest4) { - auto input = NDArrayFactory::create('c', {3, 4}); - input.linspace(1); - auto reshaped = input.reshape('c', {1, 3, 4}); - - sd::ops::expand_dims op; - auto result = op.evaluate({&input}, {}, {-3}); + auto input = NDArrayFactory::create('c', {3, 4}); + input.linspace(1); + auto reshaped = input.reshape('c', {1, 3, 4}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::expand_dims op; + auto result = op.evaluate({&input}, {}, {-3}); - auto z = result.at(0); - - ASSERT_TRUE(reshaped.isSameShape(z)); - ASSERT_TRUE(reshaped.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(reshaped.isSameShape(z)); + ASSERT_TRUE(reshaped.equalsTo(z)); } - TEST_F(ParityOpsTests, Test_Shape_1) { - auto x = NDArrayFactory::create('c', {3, 4, 5, 6}); - auto exp = NDArrayFactory::create('c', {4}, {3, 4, 5, 6}); + auto x = NDArrayFactory::create('c', {3, 4, 5, 6}); + auto exp = NDArrayFactory::create('c', {4}, {3, 4, 5, 6}); - sd::ops::shape_of op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::shape_of op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ParityOpsTests, Test_Equals_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {1, 5}, {1, 0, 3, 0, 5}); - auto exp = NDArrayFactory::create('c', {1, 5}, {1, 0, 1, 0, 1}); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {1, 0, 3, 0, 5}); + auto exp = NDArrayFactory::create('c', {1, 5}, {1, 0, 1, 0, 1}); - sd::ops::equals op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::equals op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ParityOpsTests, Test_NotEquals_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {1, 5}, {1, 0, 3, 0, 5}); - auto exp = NDArrayFactory::create('c', {1, 5}, {0, 1, 0, 1, 0}); - - sd::ops::not_equals op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {1, 0, 3, 0, 5}); + auto exp = NDArrayFactory::create('c', {1, 5}, {0, 1, 0, 1, 0}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::not_equals op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Less_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {1, 5}, {1, 1, 0, 0, 0}); - - sd::ops::less op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {1, 1, 0, 0, 0}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::less op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_LessEquals_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {1, 5}, {1, 1, 1, 0, 0}); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {1, 1, 1, 0, 0}); - sd::ops::less_equal op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::less_equal op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_GreaterEquals_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 1, 1, 1}); - - sd::ops::greater_equal op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 1, 1, 1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::greater_equal op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_GreaterEquals_2) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 1, 1, 1}); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 1, 1, 1}); - sd::ops::greater_equal op; - auto result = op.evaluate({&x, &y}, {}, {}, {}, {}, false); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::greater_equal op; + auto result = op.evaluate({&x, &y}, {}, {}, {}, {}, false); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Greater_1) { - auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 0, 1, 1}); - - sd::ops::greater op; - auto result = op.evaluate({&x, &y}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 0, 1, 1}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::greater op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Where_1) { - auto mask = NDArrayFactory::create('c', {3, 3}, {1, 1, 1, 0, 0, 0, 1, 1, 1}); - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y = NDArrayFactory::create('c', {3, 3}, {9, 8, 7, 6, 5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 7, 8, 9}); - - sd::ops::Where op; - auto result = op.evaluate({&mask, &x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto mask = + NDArrayFactory::create('c', {3, 3}, {1, 1, 1, 0, 0, 0, 1, 1, 1}); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('c', {3, 3}, {9, 8, 7, 6, 5, 4, 3, 2, 1}); + auto exp = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 7, 8, 9}); - // z->printIndexedBuffer("result"); + sd::ops::Where op; + auto result = op.evaluate({&mask, &x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printIndexedBuffer("result"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Where_2) { - auto mask = NDArrayFactory::create('c', {1, 3}, {1, 0, 0}); - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y = NDArrayFactory::create('c', {3, 3}, {9, 8, 7, 6, 5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1}); + auto mask = NDArrayFactory::create('c', {1, 3}, {1, 0, 0}); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('c', {3, 3}, {9, 8, 7, 6, 5, 4, 3, 2, 1}); + auto exp = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1}); - sd::ops::Where op; - auto result = op.evaluate({&mask, &x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::Where op; + auto result = op.evaluate({&mask, &x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ParityOpsTests, Test_Where_3) { - auto mask = NDArrayFactory::create('c', {2, 2, 3}, {0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1}); - auto exp = NDArrayFactory::create('c', {5, 3}, {0, 0, 1, 0, 0, 2, 0, 1, 1, 1, 0, 0, 1, 1, 2}); + auto mask = NDArrayFactory::create( + 'c', {2, 2, 3}, {0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1}); + auto exp = NDArrayFactory::create( + 'c', {5, 3}, {0, 0, 1, 0, 0, 2, 0, 1, 1, 1, 0, 0, 1, 1, 2}); - sd::ops::Where op; - auto result = op.evaluate({&mask}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::Where op; + auto result = op.evaluate({&mask}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - // z->printShapeInfo("z"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printShapeInfo("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Select_1) { - auto mask = NDArrayFactory::create('c', {1, 3}, {1, 0, 0}); - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y = NDArrayFactory::create('c', {3, 3}, {9, 8, 7, 6, 5, 4, 3, 2, 1}); - auto exp = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1}); - - sd::ops::select op; - auto result = op.evaluate({&mask, &x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto mask = NDArrayFactory::create('c', {1, 3}, {1, 0, 0}); + auto x = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = + NDArrayFactory::create('c', {3, 3}, {9, 8, 7, 6, 5, 4, 3, 2, 1}); + auto exp = + NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::select op; + auto result = op.evaluate({&mask, &x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Select_2) { - auto mask = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4 }); - auto y = NDArrayFactory::create('c', {2, 2}, {9, 8, 7, 6}); - auto exp = NDArrayFactory::create('c', {2, 2}, {1, 8, 3, 6}); + auto mask = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 2}, {9, 8, 7, 6}); + auto exp = NDArrayFactory::create('c', {2, 2}, {1, 8, 3, 6}); - sd::ops::select op; - auto result = op.evaluate({&mask, &x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::select op; + auto result = op.evaluate({&mask, &x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Select_3) { - bool value = false; - auto mask = NDArrayFactory::create('c', {1, 1}, {value}); - auto x = NDArrayFactory::create('c', {1, 1}, {1}); - auto y = NDArrayFactory::create('c', {1, 1}, {2}); - auto exp = NDArrayFactory::create('c', {1, 1}, {2}); - - sd::ops::select op; - auto result = op.evaluate({&mask, &x, &y}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + bool value = false; + auto mask = NDArrayFactory::create('c', {1, 1}, {value}); + auto x = NDArrayFactory::create('c', {1, 1}, {1}); + auto y = NDArrayFactory::create('c', {1, 1}, {2}); + auto exp = NDArrayFactory::create('c', {1, 1}, {2}); - auto z = result.at(0); + sd::ops::select op; + auto result = op.evaluate({&mask, &x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Bias_Add_1) { - auto x = NDArrayFactory::create('c', {10, 5}); - x.assign(0.0); - auto bias = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - sd::ops::biasadd op; + auto x = NDArrayFactory::create('c', {10, 5}); + x.assign(0.0); + auto bias = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + sd::ops::biasadd op; - auto result = op.evaluate({&x, &bias}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto result = op.evaluate({&x, &bias}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - auto tads = z.allTensorsAlongDimension({1}); - for (int e = 0; e < tads.size(); e++) { - ASSERT_TRUE(bias.equalsTo(tads.at(e))); - } + auto z = result.at(0); + auto tads = z.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + ASSERT_TRUE(bias.equalsTo(tads.at(e))); + } } TEST_F(ParityOpsTests, Test_Scatter_Add_1) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2}, {1, 1}); - auto exp = NDArrayFactory::create('c', {2, 2}, {2, 3, 3, 4}); - - sd::ops::scatter_add op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {1, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {2, 3, 3, 4}); - auto z = result.at(0); + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Scatter_Add_2) { + auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + NDArray idc('c', {1, 4}, {0., 1, 2, 3}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 4}, {1, 1, 1, 1}); + auto exp = NDArrayFactory::create('c', {1, 4}, {2, 3, 4, 5}); - auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - NDArray idc('c', {1, 4}, {0., 1, 2, 3}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 4}, {1, 1, 1, 1}); - auto exp = NDArrayFactory::create('c', {1, 4}, {2, 3, 4, 5}); - - sd::ops::scatter_add op; - auto result = op.evaluate({&vec, &idc, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + sd::ops::scatter_add op; + auto result = op.evaluate({&vec, &idc, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Scatter_Add_3) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2, 2}, {1, 1, 1, 1}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {2, 3, 4, 5, 5, 6, 7, 8}); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2, 2}, {1, 1, 1, 1}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {2, 3, 4, 5, 5, 6, 7, 8}); - sd::ops::scatter_add op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Scatter_Add_4) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1, 2}, std::vector{0, 0}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8}); - - sd::ops::scatter_add op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1, 2}, std::vector{0, 0}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, + {1, 1, 1, 1, 1, 1, 1, 1}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8}); - auto z = result.at(0); + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Scatter_Add_5) { - auto matrix = NDArrayFactory::create('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - NDArray idc('c', {2, 2}, {1., 1, 0, 0}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto exp = NDArrayFactory::create('c', {2, 2, 3}, {9., 11., 13.,15., 17., 19., 9., 11., 13.,15., 17., 19.}); - - sd::ops::scatter_add op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = NDArrayFactory::create( + 'c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + NDArray idc('c', {2, 2}, {1., 1, 0, 0}, sd::DataType::INT64); + auto updates = NDArrayFactory::create( + 'c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 3}, + {9., 11., 13., 15., 17., 19., 9., 11., 13., 15., 17., 19.}); - auto z = result.at(0); - // z->printBuffer(); + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Scatter_Add_6) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1}); - NDArray idc('c', {2, 2}, {1, 1, 0, 0}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {2, 2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {7, 9, 11, 13, 7, 9, 11, 13}); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1}); + NDArray idc('c', {2, 2}, {1, 1, 0, 0}, sd::DataType::INT64); + auto updates = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, + {7, 9, 11, 13, 7, 9, 11, 13}); - sd::ops::scatter_add op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, Test_Scatter_Add_7) { - auto matrix = NDArrayFactory::create('c', {10, 3}, {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f,25.f,26.f,27.f,28.f,29.f,30.f}); - NDArray idc('c', {}, std::vector{5}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {3}, {10.f, 20.f, 30.f}); - auto exp = NDArrayFactory::create('c', {10, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,11.f,12.f, 13.f,14.f,15.f, 26.f,37.f,48.f, 19.f,20.f,21.f, 22.f,23.f,24.f, 25.f,26.f,27.f, 28.f,29.f,30.f}); - - sd::ops::scatter_add op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = NDArrayFactory::create( + 'c', {10, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f}); + NDArray idc('c', {}, std::vector{5}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {3}, {10.f, 20.f, 30.f}); + auto exp = NDArrayFactory::create( + 'c', {10, 3}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, + 11.f, 12.f, 13.f, 14.f, 15.f, 26.f, 37.f, 48.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f}); - auto z = result.at(0); + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, Test_Scatter_Add_8) { + NDArray input('c', {8}, {1, 1, 1, 1, 1, 1, 1, 1}, sd::DataType::FLOAT32); + NDArray indices('c', {4}, {1, 1, 1, 1}, sd::DataType::INT32); + NDArray updates('c', {4}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray expected('c', {8}, {1.f, 11.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}, + sd::DataType::FLOAT32); - NDArray input('c', {8}, {1,1,1,1,1,1,1,1}, sd::DataType::FLOAT32); - NDArray indices('c', {4}, {1, 1, 1, 1}, sd::DataType::INT32); - NDArray updates('c', {4}, {1,2,3,4}, sd::DataType::FLOAT32); - NDArray expected('c', {8}, {1.f, 11.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}, sd::DataType::FLOAT32); + NDArray z('c', {8}, sd::DataType::FLOAT32); - NDArray z('c', {8}, sd::DataType::FLOAT32); + sd::ops::scatter_add op; + Nd4jStatus status = + op.execute({&input, &indices, &updates}, {&z}, {}, {}, {true}); + // z.printBuffer(); - sd::ops::scatter_add op; - Nd4jStatus status = op.execute({&input, &indices, &updates}, {&z}, {}, {}, {true}); - // z.printBuffer(); - - ASSERT_EQ(ND4J_STATUS_OK, status); - ASSERT_TRUE(expected.isSameShapeStrict(z)); - ASSERT_TRUE(expected.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.isSameShapeStrict(z)); + ASSERT_TRUE(expected.equalsTo(z)); } //////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, Test_Scatter_Add_9) { - auto matrix = NDArrayFactory::create('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - NDArray idc('c', {2, 2}, {1, 10, 0, 0}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - auto output = NDArrayFactory::create('c', {2, 2, 3}); + auto matrix = NDArrayFactory::create( + 'c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + NDArray idc('c', {2, 2}, {1, 10, 0, 0}, sd::DataType::INT64); + auto updates = NDArrayFactory::create( + 'c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto output = NDArrayFactory::create('c', {2, 2, 3}); - sd::ops::scatter_add op; + sd::ops::scatter_add op; - ASSERT_ANY_THROW(op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true})); + ASSERT_ANY_THROW( + op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true})); } //////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterMax_test1) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - NDArray idc('c', {1}, std::vector{0.}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2}, {10, 1}); - auto exp = NDArrayFactory::create('c', {2, 2}, {10, 2, 3, 4}); - - sd::ops::scatter_max op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + NDArray idc('c', {1}, std::vector{0.}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {10, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {10, 2, 3, 4}); - auto z = result.at(0); + sd::ops::scatter_max op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMax_test2) { - auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - NDArray idc('c', {1, 4}, {0, 1, 2, 3}, sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 4}, {10, 1, 30, 1}); - auto exp = NDArrayFactory::create('c', {1, 4}, {10, 2, 30, 4}); + auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + NDArray idc('c', {1, 4}, {0, 1, 2, 3}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 4}, {10, 1, 30, 1}); + auto exp = NDArrayFactory::create('c', {1, 4}, {10, 2, 30, 4}); - sd::ops::scatter_max op; - auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::scatter_max op; + auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMax_test3) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT64); - auto updates = NDArrayFactory::create('c', {1, 2, 2}, {10, 1, 30, 1}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {10, 2, 30, 4, 5, 6, 7, 8}); - - sd::ops::scatter_max op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2, 2}, {10, 1, 30, 1}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {10, 2, 30, 4, 5, 6, 7, 8}); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_max op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMax_test4) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1,2}, std::vector{0.,0}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 10, 10, 10, 5, 6, 7, 8}); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1, 2}, std::vector{0., 0}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, + {1, 10, 1, 10, 1, 1, 10, 1.}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, + {1, 10, 10, 10, 5, 6, 7, 8}); - sd::ops::scatter_max op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {true}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::scatter_max op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {true}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMax_test5) { - auto matrix = NDArrayFactory::create('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - NDArray idc('c', {2, 2}, {1, 1, 0, 0}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2, 2, 2, 3}, {2,10,1,10, 2,10,1,10, 2,10,1,10, 10,2,10,1, 10,2,10,1, 10,2,10,1.}); - auto exp = NDArrayFactory::create('c', {2, 2, 3}, {10, 2, 10, 2, 10, 2, 2, 10, 2, 10, 2, 10}); - - sd::ops::scatter_max op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = NDArrayFactory::create( + 'c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + NDArray idc('c', {2, 2}, {1, 1, 0, 0}, sd::DataType::INT32); + auto updates = NDArrayFactory::create( + 'c', {2, 2, 2, 3}, {2, 10, 1, 10, 2, 10, 1, 10, 2, 10, 1, 10, + 10, 2, 10, 1, 10, 2, 10, 1, 10, 2, 10, 1.}); + auto exp = NDArrayFactory::create( + 'c', {2, 2, 3}, {10, 2, 10, 2, 10, 2, 2, 10, 2, 10, 2, 10}); - auto z = result.at(0); + sd::ops::scatter_max op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMax_test6) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1}); - NDArray idc('c', {2, 2}, {1, 1, 0, 0}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2, 2, 2, 2}, {0,2,0,2, 0,2,0,2, 2,0,2,0., 2,0,2,0}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {2, 1, 2, 1, 1, 2, 1, 2}); - - sd::ops::scatter_max op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1}); + NDArray idc('c', {2, 2}, {1, 1, 0, 0}, sd::DataType::INT32); + auto updates = NDArrayFactory::create( + 'c', {2, 2, 2, 2}, {0, 2, 0, 2, 0, 2, 0, 2, 2, 0, 2, 0., 2, 0, 2, 0}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {2, 1, 2, 1, 1, 2, 1, 2}); - auto z = result.at(0); + sd::ops::scatter_max op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ParityOpsTests, scatterMin_test1) { - auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {1, 2}, {-1, 1}); - auto exp = NDArrayFactory::create('c', {2, 2}, {-1, 1, 3, 4}); - - sd::ops::scatter_min op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2}, {-1, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {-1, 1, 3, 4}); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_min op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMin_test2) { - auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); - NDArray idc('c', {1, 4}, {0, 1, 2, 3}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {1, 4}, {10, 1, 30, 1}); - auto exp = NDArrayFactory::create('c', {1, 4}, {1, 1, 3, 1}); + auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + NDArray idc('c', {1, 4}, {0, 1, 2, 3}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 4}, {10, 1, 30, 1}); + auto exp = NDArrayFactory::create('c', {1, 4}, {1, 1, 3, 1}); - sd::ops::scatter_min op; - auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::scatter_min op; + auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMin_test3) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {1, 2, 2}, {10, 1, 30, 2}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 3, 2, 5, 6, 7, 8}); - - sd::ops::scatter_min op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2, 2}, {10, 1, 30, 2}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 3, 2, 5, 6, 7, 8}); - auto z = result.at(0); + sd::ops::scatter_min op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ParityOpsTests, scatterMin_test4) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1,2}, std::vector{0.,0}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 5, 6, 7, 8}); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1, 2}, std::vector{0., 0}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, + {1, 10, 1, 10, 1, 1, 10, 1.}); + auto exp = + NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 5, 6, 7, 8}); - sd::ops::scatter_min op; - auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::scatter_min op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterMin_test5) { - auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); - NDArray idc('c', {1,2}, {10,10}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.}); - auto output = NDArrayFactory::create('c', {2, 2, 2}); + auto matrix = + NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1, 2}, {10, 10}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, + {1, 10, 1, 10, 1, 1, 10, 1.}); + auto output = NDArrayFactory::create('c', {2, 2, 2}); - sd::ops::scatter_min op; + sd::ops::scatter_min op; - ASSERT_ANY_THROW(op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true})); + ASSERT_ANY_THROW( + op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true})); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test1) { + NDArray indices('c', {2, 1}, {1., 0.}, sd::DataType::INT32); + auto updates = NDArrayFactory::create( + 'c', {2, 4}, {10.f, 20.f, 30.f, 40.f, 50.f, 60.f, 70.f, 80.f}); + auto shape = NDArrayFactory::create('c', {2}, {3, 4}); + auto exp = NDArrayFactory::create( + 'c', {3, 4}, + {50.f, 60.f, 70.f, 80.f, 10.f, 20.f, 30.f, 40.f, 0.f, 0.f, 0.f, 0.f}); - NDArray indices('c', {2, 1}, {1., 0.}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2, 4}, {10.f, 20.f, 30.f, 40.f, 50.f, 60.f, 70.f, 80.f}); - auto shape = NDArrayFactory::create('c', {2}, {3, 4}); - auto exp = NDArrayFactory::create('c', {3, 4}, {50.f, 60.f, 70.f, 80.f, 10.f, 20.f, 30.f, 40.f, 0.f, 0.f, 0.f, 0.f}); - - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test2) { + NDArray indices('c', {3, 1}, {4., 2., 0.}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 4}); + auto shape = NDArrayFactory::create('c', {2}, {5, 4}); + auto exp = NDArrayFactory::create( + 'c', {5, 4}, {9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 5.f, 6.f, + 7.f, 8.f, 0.f, 0.f, 0.f, 0.f, 1.f, 2.f, 3.f, 4.f}); + updates.linspace(1.f); - NDArray indices('c', {3, 1}, {4., 2., 0.}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3, 4}); - auto shape = NDArrayFactory::create('c', {2}, {5, 4}); - auto exp = NDArrayFactory::create('c', {5, 4}, {9.f,10.f,11.f,12.f, 0.f, 0.f, 0.f, 0.f, 5.f, 6.f, 7.f, 8.f, 0.f, 0.f, 0.f, 0.f, 1.f, 2.f, 3.f, 4.f}); - updates.linspace(1.f); - - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test3) { - - NDArray indices('c', {2, 3, 1}, {0., 2., 7., 3., 6., 9.}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,3, 3,4}); - auto shape = NDArrayFactory::create('c', {3}, {10, 3, 4}); - auto exp = NDArrayFactory::create('c', {10, 3, 4}, {1.f, 2.f, 3.f, 4., 5.f, 6.f, 7.f, 8., 9.f, 10.f, 11.f, 12., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., - 13.f, 14.f, 15.f, 16.,17.f, 18.f, 19.f, 20.,21.f, 22.f, 23.f, 24.,37.f, 38.f, 39.f, 40.,41.f, 42.f, 43.f, 44.,45.f, 46.f, 47.f, 48., - 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., - 49.f, 50.f, 51.f, 52.,53.f, 54.f, 55.f, 56.,57.f, 58.f, 59.f, 60.,25.f, 26.f, 27.f, 28.,29.f, 30.f, 31.f, 32.,33.f, 34.f, 35.f, 36., - 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0.,61.f, 62.f, 63.f, 64.,65.f, 66.f, 67.f, 68.,69.f, 70.f, 71.f, 72.,}); - updates.linspace(1.f); - - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - + NDArray indices('c', {2, 3, 1}, {0., 2., 7., 3., 6., 9.}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 3, 3, 4}); + auto shape = NDArrayFactory::create('c', {3}, {10, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {10, 3, 4}, + { + 1.f, 2.f, 3.f, 4., 5.f, 6.f, 7.f, 8., 9.f, 10.f, 11.f, 12., + 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., + 13.f, 14.f, 15.f, 16., 17.f, 18.f, 19.f, 20., 21.f, 22.f, 23.f, 24., + 37.f, 38.f, 39.f, 40., 41.f, 42.f, 43.f, 44., 45.f, 46.f, 47.f, 48., + 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., + 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., + 49.f, 50.f, 51.f, 52., 53.f, 54.f, 55.f, 56., 57.f, 58.f, 59.f, 60., + 25.f, 26.f, 27.f, 28., 29.f, 30.f, 31.f, 32., 33.f, 34.f, 35.f, 36., + 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., + 61.f, 62.f, 63.f, 64., 65.f, 66.f, 67.f, 68., 69.f, 70.f, 71.f, 72., + }); + updates.linspace(1.f); + + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test4) { + NDArray indices('c', {4, 1}, {4., 3., 1., 7.}, sd::DataType::INT32); + auto updates = + NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); + auto shape = NDArrayFactory::create('c', {1}, {8}); + auto exp = NDArrayFactory::create( + 'c', {8}, {0.f, 11.f, 0.f, 10.f, 9.f, 0.f, 0.f, 12.f}); - NDArray indices('c', {4, 1}, {4., 3., 1., 7.}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); - auto shape = NDArrayFactory::create('c', {1}, {8}); - auto exp = NDArrayFactory::create('c', {8}, {0.f, 11.f, 0.f, 10.f, 9.f, 0.f, 0.f, 12.f}); - - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test5) { + NDArray indices('c', {4, 1}, {1, 1, 1, 1}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto shape = NDArrayFactory::create('c', {1}, {8}); + auto exp = NDArrayFactory::create( + 'c', {8}, {0.f, 10.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - NDArray indices('c', {4, 1}, {1, 1, 1, 1}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); - auto shape = NDArrayFactory::create('c', {1}, {8}); - auto exp = NDArrayFactory::create('c', {8}, {0.f, 10.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test6) { + NDArray indices('c', {3, 2}, {0, 1, 1, 0, 3, 2}, sd::DataType::INT32); + NDArray updates('c', {3, 2, 3}, sd::DataType::FLOAT32); + NDArray shape('c', {4}, {5, 4, 2, 3}, sd::DataType::INT32); - NDArray indices('c', {3, 2}, {0,1,1,0,3,2}, sd::DataType::INT32); - NDArray updates('c', {3, 2, 3}, sd::DataType::FLOAT32); - NDArray shape('c', {4}, {5,4,2,3}, sd::DataType::INT32); - - NDArray exp('c', {5,4,2,3}, {0., 0., 0.,0., 0., 0.,1., 2., 3.,4., 5., 6.,0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., - 7., 8., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 13., 14., 15., 16., 17., 18., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, sd::DataType::FLOAT32); - updates.linspace(1); - - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + NDArray exp('c', {5, 4, 2, 3}, + {0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 5., 6., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 7., 8., 9., 10., 11., 12., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 13., 14., 15., 16., 17., 18., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, + sd::DataType::FLOAT32); + updates.linspace(1); - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test7) { - - NDArray indices('c', {4,3,2}, {0,1,1,0,3,2,1,0,0,1,1,0,3,2,1,0,0,1,1,0,3,2,1,0}, sd::DataType::INT32); - NDArray updates('c', {4,3,2,3}, sd::DataType::FLOAT32); - NDArray shape('c', {4}, {5,4,2,3}, sd::DataType::INT32); - - NDArray exp('c', {5,4,2,3}, {0., 0., 0., 0., 0., 0., 75., 78., 81., 84., 87., 90., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 222., 228., 234., 240., 246., 252., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 111., 114., 117., 120., 123., 126., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, sd::DataType::FLOAT32); - updates.linspace(1); - - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true, true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + NDArray indices('c', {4, 3, 2}, {0, 1, 1, 0, 3, 2, 1, 0, 0, 1, 1, 0, + 3, 2, 1, 0, 0, 1, 1, 0, 3, 2, 1, 0}, + sd::DataType::INT32); + NDArray updates('c', {4, 3, 2, 3}, sd::DataType::FLOAT32); + NDArray shape('c', {4}, {5, 4, 2, 3}, sd::DataType::INT32); + + NDArray exp('c', {5, 4, 2, 3}, + {0., 0., 0., 0., 0., 0., 75., 78., 81., 84., 87., 90., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 222., 228., 234., 240., 246., 252., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 111., 114., 117., 120., 123., 126., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, + sd::DataType::FLOAT32); + updates.linspace(1); + + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true, true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test8) { + NDArray indices('c', {3, 2}, {0, 0, 1, 1, 2, 2}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto shape = NDArrayFactory::create('c', {2}, {6, 4}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, + {1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); - NDArray indices('c', {3, 2}, {0,0, 1,1, 2,2}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); - auto shape = NDArrayFactory::create('c', {2}, {6,4}); - auto exp = NDArrayFactory::create('c', {6,4}, {1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); - - sd::ops::scatter_nd op; - auto result = op.evaluate({&indices, &updates, &shape}, {}, {true}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test9) { + NDArray indices('c', {2, 3, 1}, {0., 20., 7., 30., 6., 90.}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 3, 3, 4}); + auto shape = NDArrayFactory::create('c', {3}, {10, 3, 4}); + auto output = NDArrayFactory::create('c', {10, 3, 4}); - NDArray indices('c', {2, 3, 1}, {0., 20., 7., 30., 6., 90.}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,3, 3,4}); - auto shape = NDArrayFactory::create('c', {3}, {10, 3, 4}); - auto output = NDArrayFactory::create('c', {10, 3, 4}); - - sd::ops::scatter_nd op; + sd::ops::scatter_nd op; - ASSERT_ANY_THROW(auto result = op.execute({&indices, &updates, &shape}, {&output}, {}, {}, {false, true})); + ASSERT_ANY_THROW(auto result = op.execute({&indices, &updates, &shape}, + {&output}, {}, {}, {false, true})); } - //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_add_test1) { + auto input = NDArrayFactory::create( + 'c', {8}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + NDArray indices('c', {4, 1}, {4., 3., 1., 7.}, sd::DataType::INT32); + auto updates = + NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); + auto exp = NDArrayFactory::create( + 'c', {8}, {1.f, 13.f, 3.f, 14.f, 14.f, 6.f, 7.f, 20.f}); - auto input = NDArrayFactory::create('c', {8}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - NDArray indices('c', {4, 1}, {4., 3., 1., 7.}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); - auto exp = NDArrayFactory::create('c', {8}, {1.f, 13.f, 3.f, 14.f, 14.f, 6.f, 7.f, 20.f}); - - sd::ops::scatter_nd_add op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_add op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_add_test2) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {3, 3, 2}, + {0.f, 0.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 0.f, 5.f, 1.f, + 0.f, 2.f, 1.f, 3.f, 2.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 3}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, + {1.f, 0.f, 7.f, 0.f, 0.f, 2.f, 0.f, 8.f, 9.f, 0.f, 3.f, 0.f, + 0.f, 0.f, 0.f, 4.f, 5.f, 0.f, 0.f, 0.f, 0.f, 6.f, 0.f, 0.f}); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3,3}); - auto exp = NDArrayFactory::create('c', {6,4}, {1.f,0.f,7.f,0.f, 0.f,2.f,0.f,8.f, 9.f,0.f,3.f,0.f, 0.f,0.f,0.f,4.f, 5.f,0.f,0.f,0.f, 0.f,6.f,0.f,0.f}); - - input = 0.f; - updates.linspace(1.f); + input = 0.f; + updates.linspace(1.f); - sd::ops::scatter_nd_add op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - // z->printIndexedBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_add op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_add_test3) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {2, 3, 1}, {5.f, 1.f, 2.f, 3.f, 4.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, + {21.f, 22.f, 23.f, 24.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 1.f, 2.f, 3.f, 4.f}); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {2, 3, 1}, {5.f, 1.f, 2.f, 3.f, 4.f, 0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {6,4}, {21.f, 22.f, 23.f, 24.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f, 1.f, 2.f, 3.f, 4.f}); - - input = 0.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_add op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + input = 0.f; + updates.linspace(1.f); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_add op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_add_test4) { - - auto input = NDArrayFactory::create('c', {6, 4, 5}); - NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3,3,5}); - auto exp = NDArrayFactory::create('c', {6,4,5}, {1.f, 2.f, 3.f, 4.f, 5.f, 0.f, 0.f, 0.f, 0.f, 0.f,31.f, 32.f, 33.f, 34.f, 35.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 6.f, 7.f, 8.f, 9.f, 10.f, 0.f, 0.f, 0.f, 0.f, 0.f,36.f, 37.f, 38.f, 39.f, 40.f, - 41.f, 42.f, 43.f, 44.f, 45.f, 0.f, 0.f, 0.f, 0.f, 0.f,11.f, 12.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,16.f, 17.f, 18.f, 19.f, 20.f, - 21.f, 22.f, 23.f, 24.f, 25.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f,26.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - input = 0.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_add op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create('c', {6, 4, 5}); + NDArray indices('c', {3, 3, 2}, + {0.f, 0.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 0.f, 5.f, 1.f, + 0.f, 2.f, 1.f, 3.f, 2.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 3, 5}); + auto exp = NDArrayFactory::create( + 'c', {6, 4, 5}, + {1.f, 2.f, 3.f, 4.f, 5.f, 0.f, 0.f, 0.f, 0.f, 0.f, 31.f, 32.f, + 33.f, 34.f, 35.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 6.f, 7.f, 8.f, 9.f, 10.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 11.f, 12.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 25.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 26.f, 27.f, 28.f, + 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + input = 0.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_add op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_add_test5) { - - auto input = NDArrayFactory::create('c', {6,5,4,3,2}); - NDArray indices('c', {2,2,3}, {0.f,0.f,0.f, 1.f,1.f,1.f, 2.f,2.f,2.f, 3.f,3.f,3.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,2,3,2}); - auto exp = NDArrayFactory::create('c', {6,5,4,3,2}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f, 9.f, 10.f,11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,13.f, 14.f,15.f, 16.f,17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,19.f, 20.f,21.f, 22.f,23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - input = 0.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_add op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create('c', {6, 5, 4, 3, 2}); + NDArray indices('c', {2, 2, 3}, + {0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 2, 3, 2}); + auto exp = NDArrayFactory::create( + 'c', {6, 5, 4, 3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + input = 0.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_add op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_add_test6) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {2, 3, 1}, {50.f, 1.f, 2.f, 3.f, 40.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 3, 4}); + auto output = NDArrayFactory::create('c', {6, 4}); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {2, 3, 1}, {50.f, 1.f, 2.f, 3.f, 40.f, 0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,3,4}); - auto output = NDArrayFactory::create('c', {6,4}); + sd::ops::scatter_nd_add op; - sd::ops::scatter_nd_add op; - - ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, {false, true})); + ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, + {false, true})); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_sub_test1) { + auto input = NDArrayFactory::create( + 'c', {8}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + NDArray indices('c', {4, 1}, {4.f, 3.f, 1.f, 7.f}, sd::DataType::INT32); + auto updates = + NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); + auto exp = NDArrayFactory::create( + 'c', {8}, {1.f, -9.f, 3.f, -6.f, -4.f, 6.f, 7.f, -4.f}); - auto input = NDArrayFactory::create('c', {8}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - NDArray indices('c', {4, 1}, {4.f, 3.f, 1.f, 7.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); - auto exp = NDArrayFactory::create('c', {8}, {1.f, -9.f, 3.f, -6.f, -4.f, 6.f, 7.f, -4.f}); - - sd::ops::scatter_nd_sub op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_sub op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_sub_test2) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {3, 3, 2}, + {0.f, 0.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 0.f, 5.f, 1.f, + 0.f, 2.f, 1.f, 3.f, 2.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 3}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, + {-1.f, 0.f, -7.f, 0.f, 0.f, -2.f, 0.f, -8.f, -9.f, 0.f, -3.f, 0.f, + 0.f, 0.f, 0.f, -4.f, -5.f, 0.f, 0.f, 0.f, 0.f, -6.f, 0.f, 0.f}); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3,3}); - auto exp = NDArrayFactory::create('c', {6,4}, {-1.f,0.f,-7.f,0.f, 0.f,-2.f,0.f,-8.f, -9.f,0.f,-3.f,0.f, 0.f,0.f,0.f,-4.f, -5.f,0.f,0.f,0.f, 0.f,-6.f,0.f,0.f}); - - input = 0.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_sub op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - //exp.printIndexedBuffer("e"); - //z->printIndexedBuffer("z"); + input = 0.f; + updates.linspace(1.f); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_sub op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // exp.printIndexedBuffer("e"); + // z->printIndexedBuffer("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_sub_test3) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {2, 3, 1}, {5.f, 1.f, 2.f, 3.f, 4.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, {-21.f, -22.f, -23.f, -24., -5.f, -6.f, -7.f, -8., + -9.f, -10.f, -11.f, -12., -13.f, -14.f, -15.f, -16., + -17.f, -18.f, -19.f, -20., -1.f, -2.f, -3.f, -4.f}); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {2, 3, 1}, {5.f, 1.f, 2.f, 3.f,4.f, 0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {6,4}, {-21.f,-22.f,-23.f,-24., -5.f, -6.f, -7.f, -8., -9.f,-10.f,-11.f,-12., -13.f,-14.f,-15.f,-16., -17.f,-18.f,-19.f,-20., -1.f, -2.f, -3.f, -4.f}); - - input = 0.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_sub op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + input = 0.f; + updates.linspace(1.f); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_sub op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_sub_test4) { - - auto input = NDArrayFactory::create('c', {6, 4, 5}); - NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3,3,5}); - auto exp = NDArrayFactory::create('c', {6,4,5}, {-1.f, -2.f, -3.f, -4.f, -5.f, 0.f, 0.f, 0.f, 0.f, 0.f,-31.f, -32.f, -33.f, -34.f, -35.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, -6.f, -7.f, -8.f, -9.f, -10.f, 0.f, 0.f, 0.f, 0.f, 0.f,-36.f, -37.f, -38.f, -39.f, -40.f, - -41.f, -42.f, -43.f, -44.f, -45.f, 0.f, 0.f, 0.f, 0.f, 0.f,-11.f, -12.f, -13.f, -14.f, -15.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,-16.f, -17.f, -18.f, -19.f, -20.f, - -21.f, -22.f, -23.f, -24.f, -25.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f,-26.f, -27.f, -28.f, -29.f, -30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - input = 0.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_sub op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create('c', {6, 4, 5}); + NDArray indices('c', {3, 3, 2}, + {0.f, 0.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 0.f, 5.f, 1.f, + 0.f, 2.f, 1.f, 3.f, 2.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 3, 5}); + auto exp = NDArrayFactory::create( + 'c', {6, 4, 5}, + {-1.f, -2.f, -3.f, -4.f, -5.f, 0.f, 0.f, 0.f, 0.f, 0.f, + -31.f, -32.f, -33.f, -34.f, -35.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, -6.f, -7.f, -8.f, -9.f, -10.f, + 0.f, 0.f, 0.f, 0.f, 0.f, -36.f, -37.f, -38.f, -39.f, -40.f, + -41.f, -42.f, -43.f, -44.f, -45.f, 0.f, 0.f, 0.f, 0.f, 0.f, + -11.f, -12.f, -13.f, -14.f, -15.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, -16.f, -17.f, -18.f, -19.f, -20.f, + -21.f, -22.f, -23.f, -24.f, -25.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, -26.f, -27.f, -28.f, -29.f, -30.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + input = 0.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_sub op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_sub_test5) { - - auto input = NDArrayFactory::create('c', {6,5,4,3,2}); - NDArray indices('c', {2,2,3}, {0.f,0.f,0.f, 1.f,1.f,1.f, 2.f,2.f,2.f, 3.f,3.f,3.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,2,3,2}); - auto exp = NDArrayFactory::create('c', {6,5,4,3,2}, { -1.f, -2.f, -3.f, -4.f, -5.f, -6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -7.f, -8.f, -9.f, -10.f,-11.f, -12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,-13.f, -14.f,-15.f, -16.f,-17.f, -18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,-19.f, -20.f,-21.f, -22.f,-23.f,-24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); - input = 0.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_sub op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create('c', {6, 5, 4, 3, 2}); + NDArray indices('c', {2, 2, 3}, + {0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 2, 3, 2}); + auto exp = NDArrayFactory::create( + 'c', {6, 5, 4, 3, 2}, + {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + -7.f, -8.f, -9.f, -10.f, -11.f, -12.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + -13.f, -14.f, -15.f, -16.f, -17.f, -18.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + -19.f, -20.f, -21.f, -22.f, -23.f, -24.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + input = 0.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_sub op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_update_test1) { + auto input = NDArrayFactory::create( + 'c', {8}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + NDArray indices('c', {4, 1}, {4.f, 3.f, 1.f, 7.f}, sd::DataType::INT32); + auto updates = + NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); + auto exp = NDArrayFactory::create( + 'c', {8}, {1.f, 11.f, 3.f, 10.f, 9.f, 6.f, 7.f, 12.f}); - auto input = NDArrayFactory::create('c', {8}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - NDArray indices('c', {4, 1}, {4.f, 3.f, 1.f, 7.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); - auto exp = NDArrayFactory::create('c', {8}, {1.f, 11.f, 3.f, 10.f, 9.f, 6.f, 7.f, 12.f}); - - sd::ops::scatter_nd_update op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_update op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_update_test2) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {3, 3, 2}, + {0.f, 0.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 0.f, 5.f, 1.f, + 0.f, 2.f, 1.f, 3.f, 2.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 3}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, + {1.f, -1.f, 7.f, -1.f, -1.f, 2.f, -1.f, 8.f, 9.f, -1.f, 3.f, -1.f, + -1.f, -1.f, -1.f, 4.f, 5.f, -1.f, -1.f, -1.f, -1.f, 6.f, -1.f, -1.f}); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3,3}); - auto exp = NDArrayFactory::create('c', {6,4}, {1.f,-1.f,7.f,-1.f, -1.f,2.f,-1.f,8.f, 9.f,-1.f,3.f,-1.f, -1.f,-1.f,-1.f,4.f, 5.f,-1.f,-1.f,-1.f, -1.f,6.f,-1.f,-1.f}); - - input = -1.f; - updates.linspace(1.f); + input = -1.f; + updates.linspace(1.f); - sd::ops::scatter_nd_update op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - // z->printIndexedBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_update op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_update_test3) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {2, 3, 1}, {5.f, 1.f, 2.f, 3.f, 4.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create( + 'c', {6, 4}, + { + 21.f, 22.f, 23.f, 24.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 1.f, 2.f, 3.f, 4.f, + }); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {2, 3, 1}, {5.f, 1.f, 2.f, 3.f, 4.f, 0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,3,4}); - auto exp = NDArrayFactory::create('c', {6,4}, {21.f, 22.f, 23.f, 24.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f, 1.f, 2.f, 3.f, 4.f,}); - - input = -1.f; - updates.linspace(1.f); + input = -1.f; + updates.linspace(1.f); - sd::ops::scatter_nd_update op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - // z->printBuffer(); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::scatter_nd_update op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_update_test4) { - - auto input = NDArrayFactory::create('c', {6, 4, 5}); - NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3,3,5}); - auto exp = NDArrayFactory::create('c', {6,4,5}, {1.f, 2.f, 3.f, 4.f, 5.f, -1.f, -1.f, -1.f, -1.f, -1.f,31.f, 32.f, 33.f, 34.f, 35.f, -1.f, -1.f, -1.f, -1.f, -1.f, - -1.f, -1.f, -1.f, -1.f, -1.f, 6.f, 7.f, 8.f, 9.f, 10.f, -1.f, -1.f, -1.f, -1.f, -1.f,36.f, 37.f, 38.f, 39.f, 40.f, - 41.f, 42.f, 43.f, 44.f, 45.f, -1.f, -1.f, -1.f, -1.f, -1.f,11.f, 12.f, 13.f, 14.f, 15.f, -1.f, -1.f, -1.f, -1.f, -1.f, - -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f,16.f, 17.f, 18.f, 19.f, 20.f, - 21.f, 22.f, 23.f, 24.f, 25.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, - -1.f, -1.f, -1.f, -1.f, -1.f,26.f, 27.f, 28.f, 29.f, 30.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f}); - input = -1.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_update op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create('c', {6, 4, 5}); + NDArray indices('c', {3, 3, 2}, + {0.f, 0.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 0.f, 5.f, 1.f, + 0.f, 2.f, 1.f, 3.f, 2.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 3, 5}); + auto exp = NDArrayFactory::create( + 'c', {6, 4, 5}, + {1.f, 2.f, 3.f, 4.f, 5.f, -1.f, -1.f, -1.f, -1.f, -1.f, 31.f, 32.f, + 33.f, 34.f, 35.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, 6.f, 7.f, 8.f, 9.f, 10.f, -1.f, -1.f, -1.f, -1.f, -1.f, 36.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, -1.f, -1.f, -1.f, + -1.f, -1.f, 11.f, 12.f, 13.f, 14.f, 15.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + 25.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, 26.f, 27.f, 28.f, + 29.f, 30.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f}); + input = -1.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_update op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_update_test5) { - - auto input = NDArrayFactory::create('c', {6,5,4,3,2}); - NDArray indices('c', {2,2,3}, {0.f,0.f,0.f, 1.f,1.f,1.f, 2.f,2.f,2.f, 3.f,3.f,3.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {2,2,3,2}); - auto exp = NDArrayFactory::create('c', {6,5,4,3,2}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, 7.f, 8.f, 9.f, 10.f,11.f, 12.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f,13.f, 14.f,15.f, 16.f,17.f, 18.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f,19.f, 20.f,21.f, 22.f,23.f, 24.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, --1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f}); - input = -1.f; - updates.linspace(1.f); - - sd::ops::scatter_nd_update op; - auto result = op.evaluate({&input, &indices, &updates}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - + auto input = NDArrayFactory::create('c', {6, 5, 4, 3, 2}); + NDArray indices('c', {2, 2, 3}, + {0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 2, 3, 2}); + auto exp = NDArrayFactory::create( + 'c', {6, 5, 4, 3, 2}, + {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f}); + input = -1.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_update op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_update_test6) { + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {3, 3, 2}, + {0.f, 0.f, 10.f, 1.f, 20.f, 2.f, 30.f, 3.f, 40.f, 0.f, 50.f, + 1.f, 0.f, 2.f, 1.f, 3.f, 2.f, 0.f}, + sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 3}); + auto output = NDArrayFactory::create('c', {6, 4}); - auto input = NDArrayFactory::create('c', {6, 4}); - NDArray indices('c', {3, 3, 2}, {0.f,0.f, 10.f,1.f, 20.f,2.f, 30.f,3.f, 40.f,0.f, 50.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); - auto updates = NDArrayFactory::create('c', {3,3}); - auto output = NDArrayFactory::create('c', {6,4}); + sd::ops::scatter_nd_update op; - sd::ops::scatter_nd_update op; - - ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, {true, true})); + ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, + {true, true})); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatter_update_1) { + NDArray x('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray updates('c', {2, 2}, {10, 20, 30, 40}, sd::DataType::INT32); - NDArray x('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray updates('c', {2,2}, {10,20,30,40}, sd::DataType::INT32); - - NDArray exp('c', {2,2}, {30,40,10,20}, sd::DataType::INT32); - - sd::ops::scatter_update op; - auto results = op.evaluate({&x, &updates}, {}, {6, 1,1, 2,1,0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - // x.printBuffer(); + NDArray exp('c', {2, 2}, {30, 40, 10, 20}, sd::DataType::INT32); - ASSERT_TRUE(exp.isSameShape(x)); - ASSERT_TRUE(exp.equalsTo(x)); + sd::ops::scatter_update op; + auto results = op.evaluate({&x, &updates}, {}, {6, 1, 1, 2, 1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + // x.printBuffer(); + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatter_update_2) { + NDArray x('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray updates('c', {2, 2}, {10, 20, 30, 40}, sd::DataType::INT32); - NDArray x('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); - NDArray updates('c', {2,2}, {10,20,30,40}, sd::DataType::INT32); - - NDArray exp('c', {2,2}, {20,10,40,30}, sd::DataType::INT32); - - sd::ops::scatter_update op; - auto results = op.evaluate({&x, &updates}, {}, {6, 1,0, 2,1,0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + NDArray exp('c', {2, 2}, {20, 10, 40, 30}, sd::DataType::INT32); - ASSERT_TRUE(exp.isSameShape(x)); - ASSERT_TRUE(exp.equalsTo(x)); + sd::ops::scatter_update op; + auto results = op.evaluate({&x, &updates}, {}, {6, 1, 0, 2, 1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatter_update_3) { + NDArray x('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}, sd::DataType::INT32); + NDArray updates('c', {2, 2, 2}, {10, 20, 30, 40, 50, 60, 70, 80}, + sd::DataType::INT32); - NDArray x('c', {2,2,2}, {1,2,3,4,5,6,7,8}, sd::DataType::INT32); - NDArray updates('c', {2,2,2}, {10,20,30,40,50,60,70,80}, sd::DataType::INT32); + NDArray exp('c', {2, 2, 2}, {50, 60, 70, 80, 10, 20, 30, 40}, + sd::DataType::INT32); - NDArray exp('c', {2,2,2}, {50,60,70,80,10,20,30,40}, sd::DataType::INT32); - - sd::ops::scatter_update op; - auto results = op.evaluate({&x, &updates}, {}, {6, 2,1,2, 2,1,0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - ASSERT_TRUE(exp.isSameShape(x)); - ASSERT_TRUE(exp.equalsTo(x)); + sd::ops::scatter_update op; + auto results = op.evaluate({&x, &updates}, {}, {6, 2, 1, 2, 2, 1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); } ////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatter_update_4) { + NDArray x('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}, sd::DataType::INT32); + NDArray updates('c', {2, 2, 2}, {10, 20, 30, 40, 50, 60, 70, 80}, + sd::DataType::INT32); - NDArray x('c', {2,2,2}, {1,2,3,4,5,6,7,8}, sd::DataType::INT32); - NDArray updates('c', {2,2,2}, {10,20,30,40,50,60,70,80}, sd::DataType::INT32); - - NDArray exp('c', {2,2,2}, {20,2,3,10,60,6,7,50}, sd::DataType::INT32); - - sd::ops::scatter_update op; - auto results = op.evaluate({&x, &updates}, {}, {6, 1,0, 2,3,0}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); + NDArray exp('c', {2, 2, 2}, {20, 2, 3, 10, 60, 6, 7, 50}, + sd::DataType::INT32); - ASSERT_TRUE(exp.isSameShape(x)); - ASSERT_TRUE(exp.equalsTo(x)); + sd::ops::scatter_update op; + auto results = op.evaluate({&x, &updates}, {}, {6, 1, 0, 2, 3, 0}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); } diff --git a/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp b/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp index c6155eb0c235..32d87ba4f705 100644 --- a/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp @@ -18,127 +18,129 @@ // @author raver119@gmail.com // -#include "testlayers.h" +#include #include -#include #include -#include #include -#include -#include -#include -#include -#include +#include +#include +#include #include -#include #include +#include +#include #include - -#include +#include +#include +#include +#include +#include #include -#include -#include -#include +#include #include #include -#include -#include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class PerformanceTests : public testing::Test { -public: - int numIterations = 100; + public: + int numIterations = 100; - PerformanceTests() { - samediff::ThreadPool::getInstance(); - } + PerformanceTests() { samediff::ThreadPool::getInstance(); } }; - #ifdef RELEASE_BUILD TEST_F(PerformanceTests, test_matmul_c_f_1) { - int iterations = 500; - std::vector valuesC, valuesF; - for (int e = 0; e < iterations; e++) { - auto xc = NDArrayFactory::create('c', {512, 2048}); - auto yc = NDArrayFactory::create('c', {2048, 512}); - auto zc = NDArrayFactory::create('c', {512, 512}); + int iterations = 500; + std::vector valuesC, valuesF; + for (int e = 0; e < iterations; e++) { + auto xc = NDArrayFactory::create('c', {512, 2048}); + auto yc = NDArrayFactory::create('c', {2048, 512}); + auto zc = NDArrayFactory::create('c', {512, 512}); - auto xf = NDArrayFactory::create('f', {512, 2048}); - auto yf = NDArrayFactory::create('f', {2048, 512}); - auto zf = NDArrayFactory::create('f', {512, 512}); + auto xf = NDArrayFactory::create('f', {512, 2048}); + auto yf = NDArrayFactory::create('f', {2048, 512}); + auto zf = NDArrayFactory::create('f', {512, 512}); - auto warm = xc.like(); - warm.linspace(1.0); + auto warm = xc.like(); + warm.linspace(1.0); - //zc.linspace(1.0); - //zf.linspace(1.0); + // zc.linspace(1.0); + // zf.linspace(1.0); - sd::ops::matmul op; + sd::ops::matmul op; - auto timeStartF = std::chrono::system_clock::now(); + auto timeStartF = std::chrono::system_clock::now(); - op.execute({&xf, &yf}, {&zf}); + op.execute({&xf, &yf}, {&zf}); - auto timeEndF = std::chrono::system_clock::now(); - auto outerTimeF = std::chrono::duration_cast(timeEndF - timeStartF).count(); + auto timeEndF = std::chrono::system_clock::now(); + auto outerTimeF = std::chrono::duration_cast( + timeEndF - timeStartF) + .count(); + auto timeStartC = std::chrono::system_clock::now(); - auto timeStartC = std::chrono::system_clock::now(); + op.execute({&xc, &yc}, {&zc}); - op.execute({&xc, &yc}, {&zc}); + auto timeEndC = std::chrono::system_clock::now(); + auto outerTimeC = std::chrono::duration_cast( + timeEndC - timeStartC) + .count(); - auto timeEndC = std::chrono::system_clock::now(); - auto outerTimeC = std::chrono::duration_cast(timeEndC - timeStartC).count(); + valuesF.emplace_back(outerTimeF); + valuesC.emplace_back(outerTimeC); + } - valuesF.emplace_back(outerTimeF); - valuesC.emplace_back(outerTimeC); - } + std::sort(valuesC.begin(), valuesC.end()); + std::sort(valuesF.begin(), valuesF.end()); - std::sort(valuesC.begin(), valuesC.end()); - std::sort(valuesF.begin(), valuesF.end()); - - - nd4j_printf("Median time C: [%lld]; Median time F: [%lld];", valuesC[valuesC.size() / 2], valuesF[valuesF.size() / 2]); + nd4j_printf("Median time C: [%lld]; Median time F: [%lld];", + valuesC[valuesC.size() / 2], valuesF[valuesF.size() / 2]); } TEST_F(PerformanceTests, test_maxpooling2d_1) { - std::vector valuesX; - // auto x = NDArrayFactory::create('c', {32, 3, 224, 224}); - // auto z = NDArrayFactory::create('c', {32, 3, 224, 224}); - auto x = NDArrayFactory::create('c', {8, 3, 64, 64}); - auto z = NDArrayFactory::create('c', {8, 3, 64, 64}); - x.linspace(1.0f); - Nd4jLong k = 5; - - - Nd4jLong iArgs[] {k,k, 1,1, 0,0, 1,1, 1}; - Context ctx(1); - ctx.setInputArray(0, &x); - ctx.setOutputArray(0, &z); - ctx.setIArguments(iArgs, 9); - - sd::ops::maxpool2d op; - - for (int i = 0; i < numIterations; i++) { - auto timeStart = std::chrono::system_clock::now(); - - op.execute(&ctx); - - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - valuesX.emplace_back(outerTime); - - if ((i + 1) % 1000 == 0) - nd4j_printf("Iteration %i finished...\n", i + 1); - } - - std::sort(valuesX.begin(), valuesX.end()); - nd4j_printf("Execution time: %lld; Min: %lld; Max: %lld;\n", valuesX[valuesX.size() / 2], valuesX[0], valuesX[valuesX.size() - 1]); + std::vector valuesX; + // auto x = NDArrayFactory::create('c', {32, 3, 224, 224}); + // auto z = NDArrayFactory::create('c', {32, 3, 224, 224}); + auto x = NDArrayFactory::create('c', {8, 3, 64, 64}); + auto z = NDArrayFactory::create('c', {8, 3, 64, 64}); + x.linspace(1.0f); + Nd4jLong k = 5; + + Nd4jLong iArgs[]{k, k, 1, 1, 0, 0, 1, 1, 1}; + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setOutputArray(0, &z); + ctx.setIArguments(iArgs, 9); + + sd::ops::maxpool2d op; + + for (int i = 0; i < numIterations; i++) { + auto timeStart = std::chrono::system_clock::now(); + + op.execute(&ctx); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast( + timeEnd - timeStart) + .count(); + valuesX.emplace_back(outerTime); + + if ((i + 1) % 1000 == 0) nd4j_printf("Iteration %i finished...\n", i + 1); + } + + std::sort(valuesX.begin(), valuesX.end()); + nd4j_printf("Execution time: %lld; Min: %lld; Max: %lld;\n", + valuesX[valuesX.size() / 2], valuesX[0], + valuesX[valuesX.size() - 1]); } #endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index abf8d00141c5..46c191e32e36 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -19,201 +19,216 @@ // Created by raver119 on 20.11.17. // -#include "testlayers.h" #include -#include #include -#include #include -#include -#include -#include -#include -#include +#include +#include +#include +#include #include -#include #include +#include +#include #include +#include +#include +#include +#include #include - -#include +#include +#include #include -#include -#include -#include +#include #include #include -#include -#include -#include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class PlaygroundTests : public testing::Test { -public: - int numIterations = 3; - int poolSize = 10; - - PlaygroundTests() { - printf("\n"); - fflush(stdout); - } + public: + int numIterations = 3; + int poolSize = 10; + + PlaygroundTests() { + printf("\n"); + fflush(stdout); + } }; TEST_F(PlaygroundTests, test_avx) { - nd4j_printf("Optimal level: %i; Binary level: %i;\n", ::optimalLevel(), ::binaryLevel()); + nd4j_printf("Optimal level: %i; Binary level: %i;\n", ::optimalLevel(), + ::binaryLevel()); } - TEST_F(PlaygroundTests, test_biasAdd_1) { - auto x = NDArrayFactory::create('c', {512, 3072}); - auto y = NDArrayFactory::create('c', {3072}); + auto x = NDArrayFactory::create('c', {512, 3072}); + auto y = NDArrayFactory::create('c', {3072}); - std::vector values; + std::vector values; - sd::ops::biasadd op; + sd::ops::biasadd op; - for (int e = 0; e < 100; e++) { - auto timeStart = std::chrono::system_clock::now(); + for (int e = 0; e < 100; e++) { + auto timeStart = std::chrono::system_clock::now(); - op.execute({&x, &y}, {&x}); + op.execute({&x, &y}, {&x}); - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); - } + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast( + timeEnd - timeStart) + .count(); + values.emplace_back(outerTime); + } - std::sort(values.begin(), values.end()); + std::sort(values.begin(), values.end()); - nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); } - TEST_F(PlaygroundTests, test_bert_full_1) { #ifdef _RELEASE + // this test will run ONLY if this model exists + if (!FileUtils::fileExists("/home/raver119/Downloads/BertFull/model.fb")) + return; + auto graph = + Graph::fromFlatBuffers("/home/raver119/Downloads/BertFull/model.fb"); - // this test will run ONLY if this model exists - if (!FileUtils::fileExists("/home/raver119/Downloads/BertFull/model.fb")) - return; - - auto graph = Graph::fromFlatBuffers("/home/raver119/Downloads/BertFull/model.fb"); + auto t = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/BertFull/in0_IteratorGetNext.npy"); + auto u = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/BertFull/in1_IteratorGetNext_1.npy"); + auto v = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/BertFull/in2_IteratorGetNext_4.npy"); + auto z = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/BertFull/out_loss-Softmax.npy"); - auto t = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/BertFull/in0_IteratorGetNext.npy"); - auto u = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/BertFull/in1_IteratorGetNext_1.npy"); - auto v = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/BertFull/in2_IteratorGetNext_4.npy"); - auto z = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/BertFull/out_loss-Softmax.npy"); + // graph->printOut(); - //graph->printOut(); + graph->tagInplaceNodes(); - graph->tagInplaceNodes(); + graph->variableSpace()->putVariable(658, 0, t); + graph->variableSpace()->putVariable(659, 0, u); + graph->variableSpace()->putVariable(660, 0, v); - graph->variableSpace()->putVariable(658,0, t); - graph->variableSpace()->putVariable(659,0, u); - graph->variableSpace()->putVariable(660,0, v); + /* + // validating graph now + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->variableSpace()->hasVariable(1620)); -/* - // validating graph now - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->variableSpace()->hasVariable(1620)); + auto array = graph->variableSpace()->getVariable(1620)->getNDArray(); + ASSERT_EQ(z, *array); - auto array = graph->variableSpace()->getVariable(1620)->getNDArray(); - ASSERT_EQ(z, *array); + */ -*/ + sd::Environment::getInstance()->setProfiling(true); + auto profile = GraphProfilingHelper::profile(graph, 1); - sd::Environment::getInstance()->setProfiling(true); - auto profile = GraphProfilingHelper::profile(graph, 1); + profile->printOut(); - profile->printOut(); + sd::Environment::getInstance()->setProfiling(false); + delete profile; - sd::Environment::getInstance()->setProfiling(false); - delete profile; + /* + std::vector values; -/* - std::vector values; + for (int e = 0; e < 1; e++) { + auto timeStart = std::chrono::system_clock::now(); - for (int e = 0; e < 1; e++) { - auto timeStart = std::chrono::system_clock::now(); + GraphExecutioner::execute(graph); - GraphExecutioner::execute(graph); + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = + std::chrono::duration_cast(timeEnd - + timeStart).count(); values.emplace_back(outerTime); + } - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); - } + std::sort(values.begin(), values.end()); - std::sort(values.begin(), values.end()); - - nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); -*/ - delete graph; + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); + */ + delete graph; #endif } - TEST_F(PlaygroundTests, test_bert_1) { #ifdef _RELEASE - // this test will run ONLY if this model exists - if (!FileUtils::fileExists("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb")) - return; + // this test will run ONLY if this model exists + if (!FileUtils::fileExists( + "/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb")) + return; - auto graph = Graph::fromFlatBuffers("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb"); + auto graph = Graph::fromFlatBuffers( + "/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb"); - auto t = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_input_IteratorGetNext.numpy"); - auto u = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_input_IteratorGetNext_1.numpy"); - auto v = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_input_IteratorGetNext_4.numpy"); - auto z = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model_output.numpy"); + auto t = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/Bert_minimal_model/" + "bert_minimal_input_IteratorGetNext.numpy"); + auto u = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/Bert_minimal_model/" + "bert_minimal_input_IteratorGetNext_1.numpy"); + auto v = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/Bert_minimal_model/" + "bert_minimal_input_IteratorGetNext_4.numpy"); + auto z = NDArrayFactory::fromNpyFile( + "/home/raver119/Downloads/Bert_minimal_model/" + "bert_minimal_model_output.numpy"); - //graph->printOut(); + // graph->printOut(); - graph->tagInplaceNodes(); + graph->tagInplaceNodes(); - graph->variableSpace()->putVariable(85,0, t); - graph->variableSpace()->putVariable(86,0, u); - graph->variableSpace()->putVariable(87,0, v); + graph->variableSpace()->putVariable(85, 0, t); + graph->variableSpace()->putVariable(86, 0, u); + graph->variableSpace()->putVariable(87, 0, v); -/* - // validating graph now - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->variableSpace()->hasVariable(198)); + /* + // validating graph now + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->variableSpace()->hasVariable(198)); - auto array = graph->variableSpace()->getVariable(198)->getNDArray(); - ASSERT_EQ(z, *array); + auto array = graph->variableSpace()->getVariable(198)->getNDArray(); + ASSERT_EQ(z, *array); -*/ - sd::Environment::getInstance()->setProfiling(true); - auto profile = GraphProfilingHelper::profile(graph, 1); + */ + sd::Environment::getInstance()->setProfiling(true); + auto profile = GraphProfilingHelper::profile(graph, 1); - profile->printOut(); + profile->printOut(); - sd::Environment::getInstance()->setProfiling(false); - delete profile; + sd::Environment::getInstance()->setProfiling(false); + delete profile; -/* - std::vector values; + /* + std::vector values; - for (int e = 0; e < 1; e++) { - auto timeStart = std::chrono::system_clock::now(); + for (int e = 0; e < 1; e++) { + auto timeStart = std::chrono::system_clock::now(); - GraphExecutioner::execute(graph); + GraphExecutioner::execute(graph); - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); - } + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = + std::chrono::duration_cast(timeEnd - + timeStart).count(); values.emplace_back(outerTime); + } - std::sort(values.begin(), values.end()); + std::sort(values.begin(), values.end()); - nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); -*/ - delete graph; + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); + */ + delete graph; #endif } @@ -221,68 +236,68 @@ TEST_F(PlaygroundTests, test_bert_1) { TEST_F(PlaygroundTests, test_bert_2) { #ifdef _RELEASE - // this test will run ONLY if this model exists - if (!FileUtils::fileExists("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb")) - return; + // this test will run ONLY if this model exists + if (!FileUtils::fileExists( + "/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb")) + return; - auto graph = Graph::fromFlatBuffers("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb"); + auto graph = Graph::fromFlatBuffers( + "/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb"); - //graph->printOut(); + // graph->printOut(); - graph->tagInplaceNodes(); + graph->tagInplaceNodes(); + /* + // validating graph now + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->variableSpace()->hasVariable(198)); -/* - // validating graph now - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->variableSpace()->hasVariable(198)); + auto array = graph->variableSpace()->getVariable(198)->getNDArray(); + ASSERT_EQ(z, *array); + */ - auto array = graph->variableSpace()->getVariable(198)->getNDArray(); - ASSERT_EQ(z, *array); -*/ + sd::Environment::getInstance()->setProfiling(true); + auto profile = GraphProfilingHelper::profile(graph, 1); - sd::Environment::getInstance()->setProfiling(true); - auto profile = GraphProfilingHelper::profile(graph, 1); + profile->printOut(); - profile->printOut(); + sd::Environment::getInstance()->setProfiling(false); + delete profile; - sd::Environment::getInstance()->setProfiling(false); - delete profile; + /* + std::vector values; -/* - std::vector values; + for (int e = 0; e < 1; e++) { + auto timeStart = std::chrono::system_clock::now(); - for (int e = 0; e < 1; e++) { - auto timeStart = std::chrono::system_clock::now(); + GraphExecutioner::execute(graph); - GraphExecutioner::execute(graph); + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = + std::chrono::duration_cast(timeEnd - + timeStart).count(); values.emplace_back(outerTime); + } - auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); - } - - std::sort(values.begin(), values.end()); + std::sort(values.begin(), values.end()); - nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); -*/ - delete graph; + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); + */ + delete graph; #endif } - TEST_F(PlaygroundTests, test_one_off_ops_1) { - auto x = NDArrayFactory::create('c', {4, 128, 768}); - auto y = NDArrayFactory::create('c', {4, 128, 1}); - auto z = x.ulike(); + auto x = NDArrayFactory::create('c', {4, 128, 768}); + auto y = NDArrayFactory::create('c', {4, 128, 1}); + auto z = x.ulike(); - sd::ops::squaredsubtract op; - op.execute({&x, &y}, {&z}); + sd::ops::squaredsubtract op; + op.execute({&x, &y}, {&z}); } - /* TEST_F(PlaygroundTests, test_broadcast_1) { @@ -316,8 +331,9 @@ TEST_F(PlaygroundTests, test_broadcast_1) { sd::ops::helpers::addBias(ctx, *x, *y, *z, false); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); values.emplace_back(outerTime); } std::sort(values.begin(), values.end()); @@ -363,8 +379,9 @@ TEST_F(PlaygroundTests, test_broadcast_1) { x->applyTransform(transform::Tanh, *z, nullptr); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); values.emplace_back(outerTime); } std::sort(values.begin(), values.end()); @@ -382,8 +399,8 @@ TEST_F(PlaygroundTests, test_broadcast_1) { /* TEST_F(PlaygroundTests, test_s_0) { - std::vector> shapes = {{32, 224, 224, 3}, {32, 56, 56, 64}, {32, 7, 7, 512}}; - std::vector threads = {1, 2, 4, 8, 16}; + std::vector> shapes = {{32, 224, 224, 3}, {32, 56, 56, +64}, {32, 7, 7, 512}}; std::vector threads = {1, 2, 4, 8, 16}; for (auto shape: shapes) { for (auto t: threads) { @@ -409,20 +426,23 @@ TEST_F(PlaygroundTests, test_s_0) { sd::ops::helpers::addBias(ctx, x, y, z, false); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); values.emplace_back(outerTime); } std::sort(values.begin(), values.end()); - nd4j_printf("Shape: [%lld, %lld, %lld, %lld]; Threads: [%i]; Time: %lld us;\n", shape[0], shape[1], shape[2], shape[3], t, values[values.size() / 2]); + nd4j_printf("Shape: [%lld, %lld, %lld, %lld]; Threads: [%i]; Time: +%lld us;\n", shape[0], shape[1], shape[2], shape[3], t, values[values.size() / +2]); } } } TEST_F(PlaygroundTests, test_s_1) { - std::vector> shapes = {{32, 3, 224, 224}, {32, 64, 56, 56}, {32, 512, 7, 7}}; - std::vector threads = {1, 2, 4, 8, 16}; + std::vector> shapes = {{32, 3, 224, 224}, {32, 64, 56, +56}, {32, 512, 7, 7}}; std::vector threads = {1, 2, 4, 8, 16}; for (auto shape: shapes) { for (auto t: threads) { @@ -448,13 +468,16 @@ TEST_F(PlaygroundTests, test_s_1) { sd::ops::helpers::addBias(ctx, x, y, z, true); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); values.emplace_back(outerTime); } std::sort(values.begin(), values.end()); - nd4j_printf("Shape: [%lld, %lld, %lld, %lld]; Threads: [%i]; Time: %lld us;\n", shape[0], shape[1], shape[2], shape[3], t, values[values.size() / 2]); + nd4j_printf("Shape: [%lld, %lld, %lld, %lld]; Threads: [%i]; Time: +%lld us;\n", shape[0], shape[1], shape[2], shape[3], t, values[values.size() / +2]); } } } @@ -481,8 +504,8 @@ TEST_F(PlaygroundTests, test_s_0) { op.execute(&ctx); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = std::chrono::duration_cast +(timeEnd - timeStart).count(); values.emplace_back(outerTime); } std::sort(values.begin(), values.end()); @@ -525,8 +548,8 @@ TEST_F(PlaygroundTests, test_s_1) { op.execute(&ctx); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = std::chrono::duration_cast +(timeEnd - timeStart).count(); values.emplace_back(outerTime); } @@ -562,8 +585,8 @@ TEST_F(PlaygroundTests, test_s_2) { } auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = std::chrono::duration_cast +(timeEnd - timeStart).count(); values.emplace_back(outerTime); }; std::sort(values.begin(), values.end()); @@ -611,8 +634,9 @@ TEST_F(PlaygroundTests, test_s_4) { } } auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - valuesX.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); valuesX.emplace_back(outerTime); } @@ -624,7 +648,8 @@ TEST_F(PlaygroundTests, test_s_4) { for (auto k = 0; k < xs2; k++) { for (auto l = 0; l < xs3; l++) { - zbuffer[thread_id] += buffer[i * j + (k * l)] * 2.5f; + zbuffer[thread_id] += buffer[i * j + (k * l)] +* 2.5f; } } } @@ -633,18 +658,21 @@ TEST_F(PlaygroundTests, test_s_4) { samediff::Threads::parallel_for(f2d, 0, xs0, 1, 0, xs1, 1); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - valuesY.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); valuesY.emplace_back(outerTime); } if (valuesX.size() > 0) { std::sort(valuesX.begin(), valuesX.end()); - nd4j_printf("OpenMP time: %lld; Min: %lld; Max: %lld;\n", valuesX[valuesX.size() / 2], valuesX[0], valuesX[valuesX.size() - 1]); + nd4j_printf("OpenMP time: %lld; Min: %lld; Max: %lld;\n", +valuesX[valuesX.size() / 2], valuesX[0], valuesX[valuesX.size() - 1]); } if (valuesY.size() > 0) { std::sort(valuesY.begin(), valuesY.end()); - nd4j_printf("Threads time: %lld; Min: %lld; Max: %lld;\n", valuesY[valuesY.size() / 2], valuesY[0], valuesY[valuesY.size() - 1]); + nd4j_printf("Threads time: %lld; Min: %lld; Max: %lld;\n", +valuesY[valuesY.size() / 2], valuesY[0], valuesY[valuesY.size() - 1]); } nd4j_printf("Sum: %f\n", z.sumNumber().e(0)); @@ -677,17 +705,20 @@ TEST_F(PlaygroundTests, test_s_5) { auto timeStart = std::chrono::system_clock::now(); // picking best fit here - auto splitLoop = samediff::ThreadsHelper::pickLoop2d(numThreads, itersX, itersY); - auto span = samediff::Span2::build(splitLoop, 0, numThreads, startX, stopX, incX, startY, stopY, incY); + auto splitLoop = samediff::ThreadsHelper::pickLoop2d(numThreads, itersX, +itersY); auto span = samediff::Span2::build(splitLoop, 0, numThreads, startX, +stopX, incX, startY, stopY, incY); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); values.emplace_back(outerTime); } std::sort(values.begin(), values.end()); - nd4j_printf("Calculations time: [Median: %lld; Min: %lld; Max: %lld;]\n", values[values.size() / 2], values[0], values[values.size()-1]); + nd4j_printf("Calculations time: [Median: %lld; Min: %lld; Max: %lld;]\n", +values[values.size() / 2], values[0], values[values.size()-1]); } @@ -707,13 +738,15 @@ TEST_F(PlaygroundTests, test_s_6) { } auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); - values.emplace_back(outerTime); + auto outerTime = +std::chrono::duration_cast(timeEnd - +timeStart).count(); values.emplace_back(outerTime); } std::sort(values.begin(), values.end()); - nd4j_printf("Calculations time: [Median: %lld; Min: %lld; Max: %lld;]\n", values[values.size() / 2], values[0], values[values.size()-1]); + nd4j_printf("Calculations time: [Median: %lld; Min: %lld; Max: %lld;]\n", +values[values.size() / 2], values[0], values[values.size()-1]); } @@ -737,19 +770,21 @@ TEST_F(PlaygroundTests, test_relubp_1) { auto y = x.ulike(); auto z = x.ulike(); RandomGenerator rng(119, 120); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &x, -1.0, 1.0); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &y, -1.0, 1.0); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &x, +-1.0, 1.0); RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, +&y, -1.0, 1.0); int iterations = 10; auto timeStart = std::chrono::system_clock::now(); for (int e = 0; e < iterations; e++) - ops::helpers::reluDerivative(LaunchContext::defaultContext(), &x, &y, &z); - auto timeEnd = std::chrono::system_clock::now(); + ops::helpers::reluDerivative(LaunchContext::defaultContext(), &x, &y, +&z); auto timeEnd = std::chrono::system_clock::now(); - auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); - auto time = (Nd4jLong) outerTime / iterations; - auto bw = (1000000L * (float) (x.lengthOf() * x.sizeOfT()) / time) / 1024 / 1024 / 1024; + auto outerTime = std::chrono::duration_cast +(timeEnd - timeStart).count(); auto time = (Nd4jLong) outerTime / iterations; + auto bw = (1000000L * (float) (x.lengthOf() * x.sizeOfT()) / time) / 1024 / +1024 / 1024; nd4j_printf("Time: %lld; BW: %f GB/s\n", time, bw); } @@ -757,10 +792,11 @@ TEST_F(PlaygroundTests, test_relubp_1) { ////////////////////////////////////////////////////////////////////// TEST_F(PlaygroundTests, my) { - int bS=8, iD=32,iH=32,iW=32, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2; - int oD,oH,oW; + int bS=8, iD=32,iH=32,iW=32, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, +pD=0,pH=0,pW=0, dD=2,dH=2,dW=2; int oD,oH,oW; - sd::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0); + sd::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, +sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0); printf("!!%i, %i, %i\n", oD,oH,oW); @@ -774,9 +810,10 @@ TEST_F(PlaygroundTests, my) { auto block = new Context(1, variableSpace, false); // not-in-place auto timeStart = std::chrono::system_clock::now(); - sd::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, dD, dH, dW); - auto timeEnd = std::chrono::system_clock::now(); - auto time = std::chrono::duration_cast (timeEnd - timeStart).count(); + sd::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, +dD, dH, dW); auto timeEnd = std::chrono::system_clock::now(); auto time = +std::chrono::duration_cast (timeEnd - +timeStart).count(); printf("time: %i \n", time); @@ -786,11 +823,13 @@ TEST_F(PlaygroundTests, my) { TEST_F(PlaygroundTests, my) { - int bS=32, iD=32,iH=64,iW=64, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2; - int oD,oH,oW; + int bS=32, iD=32,iH=64,iW=64, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, +pD=0,pH=0,pW=0, dD=2,dH=2,dW=2; int oD,oH,oW; - // sd::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0); - sd::ops::ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, pW,dH, dW, iH, iW, 0); + // sd::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, +sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0); + sd::ops::ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, +pW,dH, dW, iH, iW, 0); printf("!!%i, %i, %i\n", oD,oH,oW); @@ -807,10 +846,11 @@ TEST_F(PlaygroundTests, my) { auto block = new Context(1, variableSpace, false); // not-in-place auto timeStart = std::chrono::system_clock::now(); - // sd::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, dD, dH, dW); - sd::ops::helpers::col2im(*col.getContext(), col, im, sH, sW, pH, pW, iH, iW, dH, dW); - auto timeEnd = std::chrono::system_clock::now(); - auto time = std::chrono::duration_cast (timeEnd - timeStart).count(); + // sd::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, +pW, dD, dH, dW); sd::ops::helpers::col2im(*col.getContext(), col, im, sH, sW, +pH, pW, iH, iW, dH, dW); auto timeEnd = std::chrono::system_clock::now(); auto +time = std::chrono::duration_cast (timeEnd - +timeStart).count(); printf("time: %i \n", time); @@ -821,8 +861,8 @@ TEST_F(PlaygroundTests, my) { TEST_F(PlaygroundTests, my) { int N = 100; - int bS=16, iH=128,iW=128, iC=32,oC=64, kH=4,kW=4, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=128,oW=128; + int bS=16, iH=128,iW=128, iC=32,oC=64, kH=4,kW=4, sH=1,sW=1, pH=0,pW=0, +dH=1,dW=1; int oH=128,oW=128; int paddingMode = 1; // 1-SAME, 0-VALID; int dataFormat = 1; // 1-NHWC, 0-NCHW @@ -831,22 +871,25 @@ TEST_F(PlaygroundTests, my) { // NDArray output('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); NDArray output('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); - // NDArray weights('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] - NDArray weights('c', {oC, iC, kH, kW}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, sd::DataType::FLOAT32); + // NDArray weights('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); // +permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] NDArray weights('c', {oC, iC, kH, +kW}, sd::DataType::FLOAT32); NDArray bias('c', {oC}, sd::DataType::FLOAT32); input = 5.; weights = 3.; bias = 1.; sd::ops::conv2d op; - auto err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, +pH,pW, dH,dW, paddingMode, dataFormat}); auto timeStart = std::chrono::system_clock::now(); for (int i = 0; i < N; ++i) - err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto timeEnd = std::chrono::system_clock::now(); - auto time = std::chrono::duration_cast ((timeEnd - timeStart) / N).count(); + err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, +pH,pW, dH,dW, paddingMode, dataFormat}); auto timeEnd = +std::chrono::system_clock::now(); auto time = +std::chrono::duration_cast ((timeEnd - timeStart) / +N).count(); printf("time: %i \n", time); } @@ -861,15 +904,17 @@ TEST_F(PlaygroundTests, lstmLayerCellBp_1) { // const int nOut = 6; const float cellClip = 1.1; // clipping value - const Nd4jLong gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid - const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid - const Nd4jLong cellAct = 0; // tanh activation for cell state - const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh - const float cellBeta = 0; // beta value for cell state activation, not required for tanh - const Nd4jLong outAct = 0; // tanh activation for output - const float outAlpha = 0; // alpha value for output activation, not required for tanh - const float outBeta = 0; // beta value for output activation, not required for tanh + const Nd4jLong gateAct = 2; // sigmoid activation for input (i), +forget (f) and output (o) gates const float gateAlpha = 0; // alpha value +for activation for gates, not required for sigmoid const float gateBeta = 0; // +beta value for activation for gates, not required for sigmoid const Nd4jLong +cellAct = 0; // tanh activation for cell state const float cellAlpha = 0; +// alpha value for cell state activation, not required for tanh const float +cellBeta = 0; // beta value for cell state activation, not required for +tanh const Nd4jLong outAct = 0; // tanh activation for output const +float outAlpha = 0; // alpha value for output activation, not required for +tanh const float outBeta = 0; // beta value for output activation, not +required for tanh NDArray x ('c', {bS, nIn}, sd::DataType::DOUBLE); NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); @@ -908,21 +953,23 @@ TEST_F(PlaygroundTests, lstmLayerCellBp_1) { std::vector iArgs = {gateAct, cellAct, outAct}; // std::vector bArgs = {false, false}; - // const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &hI, &cI}, tArgs, iArgs, bArgs); - // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &hI, &cI, &dLdh}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &hI, &cI}, tArgs, iArgs, +bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &hI, &cI, &dLdh}, tArgs, +iArgs, bArgs); std::vector bArgs = {true, true}; - const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, +iArgs, bArgs); const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, +&dLdh}, tArgs, iArgs, bArgs); sd::ops::lstmLayerCell opFF; sd::ops::lstmLayerCellBp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, true, true, true}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, +argsHolderBP, {true, true, true, true, true, true, true}); } */ - - diff --git a/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp b/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp index 2ab89663af0e..6ded3e6ffeca 100644 --- a/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ProtoBufTests.cpp @@ -18,7 +18,6 @@ // @author raver119@gmail.com // - #include "testlayers.h" /* @@ -32,7 +31,8 @@ class ProtoBufTests : public testing::Test { TEST_F(ProtoBufTests, TestBinaryLoad1) { GOOGLE_PROTOBUF_VERIFY_VERSION; - auto graph = GraphExecutioner::importFromTensorFlow("../../../tests/resources/tensorflow_inception_graph.pb"); + auto graph = +GraphExecutioner::importFromTensorFlow("../../../tests/resources/tensorflow_inception_graph.pb"); ASSERT_FALSE(graph == nullptr); } @@ -40,7 +40,8 @@ TEST_F(ProtoBufTests, TestBinaryLoad1) { TEST_F(ProtoBufTests, TestTextLoad1) { GOOGLE_PROTOBUF_VERIFY_VERSION; - auto graph = GraphExecutioner::importFromTensorFlow("../../../tests/resources/max_graph.pb.txt"); + auto graph = +GraphExecutioner::importFromTensorFlow("../../../tests/resources/max_graph.pb.txt"); ASSERT_FALSE(graph == nullptr); } @@ -49,7 +50,8 @@ TEST_F(ProtoBufTests, TestTextLoad1) { TEST_F(ProtoBufTests, TestTextLoad2) { GOOGLE_PROTOBUF_VERIFY_VERSION; - auto graph = GraphExecutioner::importFromTensorFlow("../../../tests/resources/max_add_2.pb.txt"); + auto graph = +GraphExecutioner::importFromTensorFlow("../../../tests/resources/max_add_2.pb.txt"); ASSERT_FALSE(graph == nullptr); @@ -69,9 +71,11 @@ TEST_F(ProtoBufTests, TestTextLoad2) { ASSERT_EQ(12, var0->getNDArray()->lengthOf()); ASSERT_EQ(12, var1->getNDArray()->lengthOf()); - ASSERT_NEAR(0.0f, var0->getNDArray()->reduceNumber>(), 1e-5); - ASSERT_NEAR(12.0f, var1->getNDArray()->reduceNumber>(), 1e-5); - ASSERT_NEAR(1.0f, var1->getNDArray()->reduceNumber>(), 1e-5); + ASSERT_NEAR(0.0f, var0->getNDArray()->reduceNumber>(), +1e-5); ASSERT_NEAR(12.0f, +var1->getNDArray()->reduceNumber>(), 1e-5); + ASSERT_NEAR(1.0f, var1->getNDArray()->reduceNumber>(), +1e-5); // now we're veryfying op graph @@ -79,22 +83,25 @@ TEST_F(ProtoBufTests, TestTextLoad2) { GraphExecutioner::execute(graph); - ASSERT_NEAR(12.0f, var0->getNDArray()->reduceNumber>(), 1e-5); - ASSERT_NEAR(1.0f, var0->getNDArray()->reduceNumber>(), 1e-5); + ASSERT_NEAR(12.0f, var0->getNDArray()->reduceNumber>(), +1e-5); ASSERT_NEAR(1.0f, +var0->getNDArray()->reduceNumber>(), 1e-5); } TEST_F(ProtoBufTests, TestTextLoad3) { GOOGLE_PROTOBUF_VERIFY_VERSION; - auto graph = GraphExecutioner::importFromTensorFlow("../../../tests/resources/max_multiply.pb.txt"); + auto graph = +GraphExecutioner::importFromTensorFlow("../../../tests/resources/max_multiply.pb.txt"); ASSERT_FALSE(graph == nullptr); ASSERT_EQ(2, graph->variableSpace()->externalEntries()); - auto var0 = graph->variableSpace()->getVariable(new std::string("Placeholder")); - auto var1 = graph->variableSpace()->getVariable(new std::string("Placeholder_1")); + auto var0 = graph->variableSpace()->getVariable(new +std::string("Placeholder")); auto var1 = graph->variableSpace()->getVariable(new +std::string("Placeholder_1")); ASSERT_TRUE(var0 != nullptr); ASSERT_TRUE(var1 != nullptr); diff --git a/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp b/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp index 97f6cd8cd3b6..5e1ef8de14b0 100644 --- a/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp @@ -18,53 +18,49 @@ // @author raver119@protonmail.com // - -#include "testlayers.h" #include #include +#include "testlayers.h" using namespace sd; -class QuantizationTests : public testing::Test { - -}; +class QuantizationTests : public testing::Test {}; TEST_F(QuantizationTests, Basic_Test_1) { #ifndef __CUDABLAS__ - auto s = TypeCast::estimateQuantizedSize(10); - ASSERT_EQ(18, s); + auto s = TypeCast::estimateQuantizedSize(10); + ASSERT_EQ(18, s); #endif } TEST_F(QuantizationTests, Basic_Test_2) { #ifndef __CUDABLAS__ - auto s = TypeCast::estimateQuantizedSize(1); - ASSERT_EQ(9, s); + auto s = TypeCast::estimateQuantizedSize(1); + ASSERT_EQ(9, s); #endif } TEST_F(QuantizationTests, Compression_Test_1) { +#ifndef __CUDABLAS__ - #ifndef __CUDABLAS__ - - auto x = NDArrayFactory::create('c', {10}); - auto z = NDArrayFactory::create('c', {10}); - x.linspace(1.0f); + auto x = NDArrayFactory::create('c', {10}); + auto z = NDArrayFactory::create('c', {10}); + x.linspace(1.0f); - auto q = new char[TypeCast::estimateQuantizedSize(x.lengthOf())]; + auto q = new char[TypeCast::estimateQuantizedSize(x.lengthOf())]; - TypeCast::convertToQuantized(nullptr, x.buffer(), x.lengthOf(), q); - TypeCast::convertFromQuantized(nullptr, q, x.lengthOf(), z.buffer()); + TypeCast::convertToQuantized(nullptr, x.buffer(), x.lengthOf(), q); + TypeCast::convertFromQuantized(nullptr, q, x.lengthOf(), z.buffer()); - ASSERT_TRUE(x.equalsTo(z, 0.1)); + ASSERT_TRUE(x.equalsTo(z, 0.1)); - auto fq = reinterpret_cast(q); + auto fq = reinterpret_cast(q); - ASSERT_NEAR(1.0f, fq[0], 1e-5); - ASSERT_NEAR(10.0f, fq[1], 1e-5); + ASSERT_NEAR(1.0f, fq[0], 1e-5); + ASSERT_NEAR(10.0f, fq[1], 1e-5); - delete[] q; + delete[] q; - #endif +#endif } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 96eb2465756a..76f9d8053ca3 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -19,1286 +19,1344 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include #include -#include #include +#include -using namespace sd; +#include -class RNGTests : public testing::Test { -private: - //Nd4jLong *_bufferA; - //Nd4jLong *_bufferB; - -public: - long _seed = 119L; - sd::graph::RandomGenerator _rngA; - sd::graph::RandomGenerator _rngB; - - NDArray* nexp0 = NDArrayFactory::create_('c', {10, 10}); - NDArray* nexp1 = NDArrayFactory::create_('c', {10, 10}); - NDArray* nexp2 = NDArrayFactory::create_('c', {10, 10}); - - RNGTests() { - //_bufferA = new Nd4jLong[100000]; - //_bufferB = new Nd4jLong[100000]; - //_rngA = (sd::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA); - //_rngB = (sd::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB); - _rngA.setStates(_seed, _seed); - _rngB.setStates(_seed, _seed); - nexp0->assign(-1.0f); - nexp1->assign(-2.0f); - nexp2->assign(-3.0f); - } +#include "testlayers.h" - ~RNGTests() { - //destroyRandom(_rngA); - //destroyRandom(_rngB); - //delete[] _bufferA; - //delete[] _bufferB; +using namespace sd; - delete nexp0; - delete nexp1; - delete nexp2; - } +class RNGTests : public testing::Test { + private: + // Nd4jLong *_bufferA; + // Nd4jLong *_bufferB; + + public: + long _seed = 119L; + sd::graph::RandomGenerator _rngA; + sd::graph::RandomGenerator _rngB; + + NDArray* nexp0 = NDArrayFactory::create_('c', {10, 10}); + NDArray* nexp1 = NDArrayFactory::create_('c', {10, 10}); + NDArray* nexp2 = NDArrayFactory::create_('c', {10, 10}); + + RNGTests() { + //_bufferA = new Nd4jLong[100000]; + //_bufferB = new Nd4jLong[100000]; + //_rngA = (sd::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, + //(Nd4jPointer) _bufferA); _rngB = (sd::random::RandomBuffer *) + //initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB); + _rngA.setStates(_seed, _seed); + _rngB.setStates(_seed, _seed); + nexp0->assign(-1.0f); + nexp1->assign(-2.0f); + nexp2->assign(-3.0f); + } + + ~RNGTests() { + // destroyRandom(_rngA); + // destroyRandom(_rngB); + // delete[] _bufferA; + // delete[] _bufferB; + + delete nexp0; + delete nexp1; + delete nexp2; + } }; TEST_F(RNGTests, TestSeeds_1) { - RandomGenerator generator(123L, 456L); + RandomGenerator generator(123L, 456L); - ASSERT_EQ(123, generator.rootState()); - ASSERT_EQ(456, generator.nodeState()); + ASSERT_EQ(123, generator.rootState()); + ASSERT_EQ(456, generator.nodeState()); - Nd4jPointer ptr = malloc(sizeof(RandomGenerator)); - memcpy(ptr, &generator, sizeof(RandomGenerator)); + Nd4jPointer ptr = malloc(sizeof(RandomGenerator)); + memcpy(ptr, &generator, sizeof(RandomGenerator)); - auto cast = reinterpret_cast(ptr); - ASSERT_EQ(123, cast->rootState()); - ASSERT_EQ(456, cast->nodeState()); + auto cast = reinterpret_cast(ptr); + ASSERT_EQ(123, cast->rootState()); + ASSERT_EQ(456, cast->nodeState()); - free(ptr); + free(ptr); } TEST_F(RNGTests, TestSeeds_2) { - RandomGenerator generator(12, 13); + RandomGenerator generator(12, 13); - generator.setStates(123L, 456L); + generator.setStates(123L, 456L); - ASSERT_EQ(123, generator.rootState()); - ASSERT_EQ(456, generator.nodeState()); + ASSERT_EQ(123, generator.rootState()); + ASSERT_EQ(456, generator.nodeState()); } TEST_F(RNGTests, TestGenerator_SGA_1) { - RandomGenerator generator(12, 13); - auto array= NDArrayFactory::create('c',{10000000}); - generator.setStates(123L, 456L); - for (auto idx = 0; idx < array.lengthOf(); idx++) { - float x = generator.relativeT(idx, -sd::DataTypeUtils::template max() / 10, - sd::DataTypeUtils::template max() / 10); - array.t(idx) = x; - } - auto minimum = array.reduceNumber(reduce::AMin); - minimum.printBuffer("Randomly float min on 1M array"); - ASSERT_EQ(123, generator.rootState()); - ASSERT_EQ(456, generator.nodeState()); + RandomGenerator generator(12, 13); + auto array = NDArrayFactory::create('c', {10000000}); + generator.setStates(123L, 456L); + for (auto idx = 0; idx < array.lengthOf(); idx++) { + float x = + generator.relativeT(idx, -sd::DataTypeUtils::template max() / 10, + sd::DataTypeUtils::template max() / 10); + array.t(idx) = x; + } + auto minimum = array.reduceNumber(reduce::AMin); + minimum.printBuffer("Randomly float min on 1M array"); + ASSERT_EQ(123, generator.rootState()); + ASSERT_EQ(456, generator.nodeState()); } - TEST_F(RNGTests, Test_Dropout_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); - - x0.linspace(1); - x1.linspace(1); - - float prob[] = {0.5f}; - - //x0.applyRandom(random::DropOut, _rngA, nullptr, &x0, prob); - //x1.applyRandom(random::DropOut, _rngB, nullptr, &x1, prob); - RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngA, &x0, 0.5); - RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngB, &x1, 0.5); - ASSERT_TRUE(x0.equalsTo(&x1)); - //x0.printIndexedBuffer("Dropout"); - // this check is required to ensure we're calling wrong signature - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + x0.linspace(1); + x1.linspace(1); + + float prob[] = {0.5f}; + + // x0.applyRandom(random::DropOut, _rngA, nullptr, &x0, prob); + // x1.applyRandom(random::DropOut, _rngB, nullptr, &x1, prob); + RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngA, &x0, + 0.5); + RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngB, &x1, + 0.5); + ASSERT_TRUE(x0.equalsTo(&x1)); + // x0.printIndexedBuffer("Dropout"); + // this check is required to ensure we're calling wrong signature + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } TEST_F(RNGTests, Test_DropoutInverted_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); - - x0.linspace(1); - x1.linspace(1); - - float prob[] = {0.5f}; - - //x0.template applyRandom>(_rngA, nullptr, &x0, prob); - //x1.template applyRandom>(_rngB, nullptr, &x1, prob); - RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngA, &x0, 0.5); - RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngB, &x1, 0.5); - ASSERT_TRUE(x0.equalsTo(&x1)); - //x0.printIndexedBuffer("DropoutInverted"); - // this check is required to ensure we're calling wrong signature - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + x0.linspace(1); + x1.linspace(1); + + float prob[] = {0.5f}; + + // x0.template applyRandom>(_rngA, nullptr, + // &x0, prob); x1.template + // applyRandom>(_rngB, nullptr, &x1, prob); + RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngA, + &x0, 0.5); + RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngB, + &x1, 0.5); + ASSERT_TRUE(x0.equalsTo(&x1)); + // x0.printIndexedBuffer("DropoutInverted"); + // this check is required to ensure we're calling wrong signature + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } - TEST_F(RNGTests, Test_Launcher_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngA, &x0, 0.5f); - RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngB, &x1, 0.5f); + RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngA, &x0, + 0.5f); + RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngB, &x1, + 0.5f); - ASSERT_TRUE(x0.equalsTo(&x1)); + ASSERT_TRUE(x0.equalsTo(&x1)); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } - TEST_F(RNGTests, Test_Launcher_2) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngA, &x0, 0.5f); - RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngB, &x1, 0.5f); + RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngA, + &x0, 0.5f); + RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngB, + &x1, 0.5f); - ASSERT_TRUE(x0.equalsTo(&x1)); + ASSERT_TRUE(x0.equalsTo(&x1)); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } - TEST_F(RNGTests, Test_Launcher_3) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::applyAlphaDropOut(LaunchContext::defaultContext(), _rngA, &x0, 0.5f, 0.2f, 0.1f, 0.3f); - RandomLauncher::applyAlphaDropOut(LaunchContext::defaultContext(), _rngB, &x1, 0.5f, 0.2f, 0.1f, 0.3f); + RandomLauncher::applyAlphaDropOut(LaunchContext::defaultContext(), _rngA, &x0, + 0.5f, 0.2f, 0.1f, 0.3f); + RandomLauncher::applyAlphaDropOut(LaunchContext::defaultContext(), _rngB, &x1, + 0.5f, 0.2f, 0.1f, 0.3f); - //x1.printIndexedBuffer("x1"); - ASSERT_TRUE(x0.equalsTo(&x1)); + // x1.printIndexedBuffer("x1"); + ASSERT_TRUE(x0.equalsTo(&x1)); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } TEST_F(RNGTests, Test_Uniform_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, + 2.0f); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, + 2.0f); - ASSERT_TRUE(x0.equalsTo(&x1)); + ASSERT_TRUE(x0.equalsTo(&x1)); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); - for (int e = 0; e < x0.lengthOf(); e++) { - float v = x0.e(e); - ASSERT_TRUE(v >= 1.0f && v <= 2.0f); - } + for (int e = 0; e < x0.lengthOf(); e++) { + float v = x0.e(e); + ASSERT_TRUE(v >= 1.0f && v <= 2.0f); + } } TEST_F(RNGTests, Test_Uniform_3) { - auto x0 = NDArrayFactory::create('c', {1000000}); + auto x0 = NDArrayFactory::create('c', {1000000}); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, + 2.0f); - for (int e = 0; e < x0.lengthOf(); e++) { - auto v = x0.t(e); - ASSERT_TRUE(v >= 1.0 && v <= 2.0); - } + for (int e = 0; e < x0.lengthOf(); e++) { + auto v = x0.t(e); + ASSERT_TRUE(v >= 1.0 && v <= 2.0); + } } TEST_F(RNGTests, Test_Bernoulli_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngA, &x0, 1.0f); - RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngB, &x1, 1.0f); + RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngA, &x0, + 1.0f); + RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngB, &x1, + 1.0f); - ASSERT_TRUE(x0.equalsTo(&x1)); + ASSERT_TRUE(x0.equalsTo(&x1)); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } TEST_F(RNGTests, Test_Gaussian_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, + 1.0f, 2.0f); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, + 1.0f, 2.0f); - //x0.printIndexedBuffer("x0"); - //x1.printIndexedBuffer("x1"); - ASSERT_TRUE(x0.equalsTo(&x1)); + // x0.printIndexedBuffer("x0"); + // x1.printIndexedBuffer("x1"); + ASSERT_TRUE(x0.equalsTo(&x1)); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } TEST_F(RNGTests, Test_Gaussian_21) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); - - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f); - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f); - -// x0.printIndexedBuffer("x0"); -// x1.printIndexedBuffer("x1"); - ASSERT_TRUE(x0.equalsTo(&x1)); - - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); - sd::ops::moments op; - auto result = op.evaluate({&x0}, {}, {}); - //x0.printIndexedBuffer("X0 Normal"); - //x1.printIndexedBuffer("X1 Normal"); - ASSERT_TRUE(result.status() == Status::OK()); - auto mean = result.at(0); - auto variance = result.at(1); - - // mean->printIndexedBuffer("Mean"); - // variance->printIndexedBuffer("Variance"); - - ASSERT_NEAR(sd::math::nd4j_abs(mean.e(0)), 0.f, 0.2f); - ASSERT_NEAR(variance.e(0), 1.0f, 0.2f); - - + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, + 0.0f, 1.0f); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, + 0.0f, 1.0f); + + // x0.printIndexedBuffer("x0"); + // x1.printIndexedBuffer("x1"); + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); + sd::ops::moments op; + auto result = op.evaluate({&x0}, {}, {}); + // x0.printIndexedBuffer("X0 Normal"); + // x1.printIndexedBuffer("X1 Normal"); + ASSERT_TRUE(result.status() == Status::OK()); + auto mean = result.at(0); + auto variance = result.at(1); + + // mean->printIndexedBuffer("Mean"); + // variance->printIndexedBuffer("Variance"); + + ASSERT_NEAR(sd::math::nd4j_abs(mean.e(0)), 0.f, 0.2f); + ASSERT_NEAR(variance.e(0), 1.0f, 0.2f); } #ifdef DEBUG_BUILD TEST_F(RNGTests, Test_Gaussian_22) { - auto x0 = NDArrayFactory::create('c', {1000, 800}); - auto x1 = NDArrayFactory::create('c', {1000, 800}); - - RandomLauncher::fillGaussian(sd::LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f); - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f); - - //x0.printIndexedBuffer("x0"); - //x1.printIndexedBuffer("x1"); - ASSERT_TRUE(x0.equalsTo(&x1)); - - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); - sd::ops::moments op; - auto result = op.evaluate({&x0}, {}, {}); - //x0.printIndexedBuffer("X0 Normal"); - //x1.printIndexedBuffer("X1 Normal"); - ASSERT_TRUE(result.status() == Status::OK()); - auto mean0 = result.at(0); - auto variance0 = result.at(1); - - //mean0->printIndexedBuffer("Mean"); - //variance0->printIndexedBuffer("Variance"); - ASSERT_NEAR(sd::math::nd4j_abs(mean0.e(0)), 0.f, 1.0e-3f); - ASSERT_NEAR(variance0.e(0), 1.0f, 1.e-3f); - + auto x0 = NDArrayFactory::create('c', {1000, 800}); + auto x1 = NDArrayFactory::create('c', {1000, 800}); + + RandomLauncher::fillGaussian(sd::LaunchContext::defaultContext(), _rngA, &x0, + 0.0f, 1.0f); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, + 0.0f, 1.0f); + + // x0.printIndexedBuffer("x0"); + // x1.printIndexedBuffer("x1"); + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); + sd::ops::moments op; + auto result = op.evaluate({&x0}, {}, {}); + // x0.printIndexedBuffer("X0 Normal"); + // x1.printIndexedBuffer("X1 Normal"); + ASSERT_TRUE(result.status() == Status::OK()); + auto mean0 = result.at(0); + auto variance0 = result.at(1); + + // mean0->printIndexedBuffer("Mean"); + // variance0->printIndexedBuffer("Variance"); + ASSERT_NEAR(sd::math::nd4j_abs(mean0.e(0)), 0.f, 1.0e-3f); + ASSERT_NEAR(variance0.e(0), 1.0f, 1.e-3f); } TEST_F(RNGTests, Test_Gaussian_3) { - auto x0 = NDArrayFactory::create('c', {800000}); - - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0, 1.0); - - auto mean = x0.meanNumber(); //.e(0); - auto stdev = x0.varianceNumber(sd::variance::SummaryStatsStandardDeviation, false);//.e(0); - auto meanExp = NDArrayFactory::create(0.); - auto devExp = NDArrayFactory::create(1.); - ASSERT_TRUE(meanExp.equalsTo(mean, 1.e-3)); - ASSERT_TRUE(devExp.equalsTo(stdev, 1.e-3)); + auto x0 = NDArrayFactory::create('c', {800000}); + + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0, + 1.0); + + auto mean = x0.meanNumber(); //.e(0); + auto stdev = x0.varianceNumber(sd::variance::SummaryStatsStandardDeviation, + false); //.e(0); + auto meanExp = NDArrayFactory::create(0.); + auto devExp = NDArrayFactory::create(1.); + ASSERT_TRUE(meanExp.equalsTo(mean, 1.e-3)); + ASSERT_TRUE(devExp.equalsTo(stdev, 1.e-3)); } TEST_F(RNGTests, Test_LogNormal_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); - RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngA, &x0, + 1.0f, 2.0f); + RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngB, &x1, + 1.0f, 2.0f); - ASSERT_TRUE(x0.equalsTo(&x1)); + ASSERT_TRUE(x0.equalsTo(&x1)); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } TEST_F(RNGTests, Test_Truncated_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); - - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); - - ASSERT_TRUE(x0.equalsTo(&x1)); - - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); - - /* Check up distribution */ - auto mean = x1.reduceNumber(reduce::Mean); - // mean.printIndexedBuffer("Mean 1.0"); - auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); - //deviation /= (double)x1.lengthOf(); - // deviation.printIndexedBuffer("Deviation should be 2.0"); - // x1.printIndexedBuffer("Distribution TN"); - + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, + &x0, 1.0f, 2.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, + &x1, 1.0f, 2.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); + + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + // mean.printIndexedBuffer("Mean 1.0"); + auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation = + x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation /= (double)x1.lengthOf(); + // deviation.printIndexedBuffer("Deviation should be 2.0"); + // x1.printIndexedBuffer("Distribution TN"); } TEST_F(RNGTests, Test_Truncated_2) { - auto x0 = NDArrayFactory::create('c', {1000, 1000}); - auto x1 = NDArrayFactory::create('c', {1000, 1000}); - - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); - - ASSERT_TRUE(x0.equalsTo(&x1)); - - //ASSERT_FALSE(x0.equalsTo(nexp0)); - //ASSERT_FALSE(x0.equalsTo(nexp1)); - //ASSERT_FALSE(x0.equalsTo(nexp2)); - - /* Check up distribution */ - auto mean = x1.reduceNumber(reduce::Mean); - // mean.printIndexedBuffer("Mean 1.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); - //deviation /= (double)x1.lengthOf(); - // deviation.printIndexedBuffer("Deviation should be 2.0"); - //x1.printIndexedBuffer("Distribution TN"); - ASSERT_NEAR(mean.e(0), 1.f, 0.5); - ASSERT_NEAR(deviation.e(0), 2.f, 0.5); + auto x0 = NDArrayFactory::create('c', {1000, 1000}); + auto x1 = NDArrayFactory::create('c', {1000, 1000}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, + &x0, 1.0f, 2.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, + &x1, 1.0f, 2.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + // ASSERT_FALSE(x0.equalsTo(nexp0)); + // ASSERT_FALSE(x0.equalsTo(nexp1)); + // ASSERT_FALSE(x0.equalsTo(nexp2)); + + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + // mean.printIndexedBuffer("Mean 1.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation = + x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation /= (double)x1.lengthOf(); + // deviation.printIndexedBuffer("Deviation should be 2.0"); + // x1.printIndexedBuffer("Distribution TN"); + ASSERT_NEAR(mean.e(0), 1.f, 0.5); + ASSERT_NEAR(deviation.e(0), 2.f, 0.5); } TEST_F(RNGTests, Test_Truncated_21) { - auto x0 = NDArrayFactory::create('c', {100, 100}); - auto x1 = NDArrayFactory::create('c', {100, 100}); - - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); - - ASSERT_TRUE(x0.equalsTo(&x1)); - - auto mean0 = x0.reduceNumber(reduce::Mean); - // mean0.printIndexedBuffer("0Mean 1.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation0 = x0.varianceNumber(variance::SummaryStatsStandardDeviation, false); - // deviation0.printIndexedBuffer("0Deviation should be 2.0"); - - //ASSERT_FALSE(x0.equalsTo(nexp0)); - //ASSERT_FALSE(x0.equalsTo(nexp1)); - //ASSERT_FALSE(x0.equalsTo(nexp2)); - - /* Check up distribution */ - auto mean = x1.reduceNumber(reduce::Mean); - // mean.printIndexedBuffer("Mean 1.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); - //deviation /= (double)x1.lengthOf(); - // deviation.printIndexedBuffer("Deviation should be 2.0"); - //x1.printIndexedBuffer("Distribution TN"); - ASSERT_NEAR(mean.e(0), 1.f, 0.002); - ASSERT_NEAR(deviation.e(0), 2.f, 0.5); - sd::ops::moments op; - auto result = op.evaluate({&x0}, {}, {}, {}, {}, false); - - // result.at(0)->printBuffer("MEAN"); - // result.at(1)->printBuffer("VARIANCE"); - - sd::ops::reduce_min minOp; - sd::ops::reduce_max maxOp; - - auto minRes = minOp.evaluate({&x1}, {}, {}, {}); - auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); - // minRes->at(0)->printBuffer("MIN for Truncated"); - // maxRes->at(0)->printBuffer("MAX for Truncated"); + auto x0 = NDArrayFactory::create('c', {100, 100}); + auto x1 = NDArrayFactory::create('c', {100, 100}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, + &x0, 1.0f, 2.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, + &x1, 1.0f, 2.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + auto mean0 = x0.reduceNumber(reduce::Mean); + // mean0.printIndexedBuffer("0Mean 1.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation0 = + x0.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation0.printIndexedBuffer("0Deviation should be 2.0"); + + // ASSERT_FALSE(x0.equalsTo(nexp0)); + // ASSERT_FALSE(x0.equalsTo(nexp1)); + // ASSERT_FALSE(x0.equalsTo(nexp2)); + + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + // mean.printIndexedBuffer("Mean 1.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation = + x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation /= (double)x1.lengthOf(); + // deviation.printIndexedBuffer("Deviation should be 2.0"); + // x1.printIndexedBuffer("Distribution TN"); + ASSERT_NEAR(mean.e(0), 1.f, 0.002); + ASSERT_NEAR(deviation.e(0), 2.f, 0.5); + sd::ops::moments op; + auto result = op.evaluate({&x0}, {}, {}, {}, {}, false); + + // result.at(0)->printBuffer("MEAN"); + // result.at(1)->printBuffer("VARIANCE"); + + sd::ops::reduce_min minOp; + sd::ops::reduce_max maxOp; + + auto minRes = minOp.evaluate({&x1}, {}, {}, {}); + auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); + // minRes->at(0)->printBuffer("MIN for Truncated"); + // maxRes->at(0)->printBuffer("MAX for Truncated"); } TEST_F(RNGTests, Test_Truncated_22) { - auto x0 = NDArrayFactory::create('c', {100, 100}); - auto x1 = NDArrayFactory::create('c', {100, 100}); - - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 2.0f, 4.0f); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 2.0f, 4.0f); - - ASSERT_TRUE(x0.equalsTo(&x1)); - - auto mean0 = x0.reduceNumber(reduce::Mean); - // mean0.printIndexedBuffer("0Mean 2.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation0 = x0.varianceNumber(variance::SummaryStatsStandardDeviation, false); - // deviation0.printIndexedBuffer("0Deviation should be 4.0"); - - //ASSERT_FALSE(x0.equalsTo(nexp0)); - //ASSERT_FALSE(x0.equalsTo(nexp1)); - //ASSERT_FALSE(x0.equalsTo(nexp2)); - - /* Check up distribution */ - auto mean = x1.reduceNumber(reduce::Mean); - // mean.printIndexedBuffer("Mean 2.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); - //deviation /= (double)x1.lengthOf(); - // deviation.printIndexedBuffer("Deviation should be 4.0"); - //x1.printIndexedBuffer("Distribution TN"); - ASSERT_NEAR(mean.e(0), 2.f, 0.01); - ASSERT_NEAR(deviation.e(0), 4.f, 0.52); - sd::ops::moments op; - auto result = op.evaluate({&x0}, {}, {}, {}, {}, false); - - sd::ops::reduce_min minOp; - sd::ops::reduce_max maxOp; - - auto minRes = minOp.evaluate({&x1}, {}, {}, {}); - auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); - // minRes->at(0)->printBuffer("MIN for Truncated2"); - // maxRes->at(0)->printBuffer("MAX for Truncated2"); - + auto x0 = NDArrayFactory::create('c', {100, 100}); + auto x1 = NDArrayFactory::create('c', {100, 100}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, + &x0, 2.0f, 4.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, + &x1, 2.0f, 4.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + auto mean0 = x0.reduceNumber(reduce::Mean); + // mean0.printIndexedBuffer("0Mean 2.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation0 = + x0.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation0.printIndexedBuffer("0Deviation should be 4.0"); + + // ASSERT_FALSE(x0.equalsTo(nexp0)); + // ASSERT_FALSE(x0.equalsTo(nexp1)); + // ASSERT_FALSE(x0.equalsTo(nexp2)); + + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + // mean.printIndexedBuffer("Mean 2.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation = + x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation /= (double)x1.lengthOf(); + // deviation.printIndexedBuffer("Deviation should be 4.0"); + // x1.printIndexedBuffer("Distribution TN"); + ASSERT_NEAR(mean.e(0), 2.f, 0.01); + ASSERT_NEAR(deviation.e(0), 4.f, 0.52); + sd::ops::moments op; + auto result = op.evaluate({&x0}, {}, {}, {}, {}, false); + + sd::ops::reduce_min minOp; + sd::ops::reduce_max maxOp; + + auto minRes = minOp.evaluate({&x1}, {}, {}, {}); + auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); + // minRes->at(0)->printBuffer("MIN for Truncated2"); + // maxRes->at(0)->printBuffer("MAX for Truncated2"); } TEST_F(RNGTests, Test_Truncated_23) { - auto x0 = NDArrayFactory::create('c', {1000, 1000}); - auto x1 = NDArrayFactory::create('c', {1000, 1000}); - - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f); - - ASSERT_TRUE(x0.equalsTo(&x1)); - - auto mean0 = x0.reduceNumber(reduce::Mean); - // mean0.printIndexedBuffer("0Mean 2.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation0 = x0.varianceNumber(variance::SummaryStatsStandardDeviation, false); - // deviation0.printIndexedBuffer("0Deviation should be 4.0"); - - //ASSERT_FALSE(x0.equalsTo(nexp0)); - //ASSERT_FALSE(x0.equalsTo(nexp1)); - //ASSERT_FALSE(x0.equalsTo(nexp2)); - - /* Check up distribution */ - auto mean = x1.reduceNumber(reduce::Mean); - // mean.printIndexedBuffer("Mean 2.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - - auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); - //deviation /= (double)x1.lengthOf(); - // deviation.printIndexedBuffer("Deviation should be 4.0"); - //x1.printIndexedBuffer("Distribution TN"); - ASSERT_NEAR(mean.e(0), 0.f, 0.01); - ASSERT_NEAR(deviation.e(0), 1.f, 0.5); - sd::ops::moments op; - auto result = op.evaluate({&x0}); - // result->at(0)->printBuffer("MEAN"); - // result->at(1)->printBuffer("VARIANCE"); - sd::ops::reduce_min minOp; - sd::ops::reduce_max maxOp; - - auto minRes = minOp.evaluate({&x1}, {}, {}, {}); - auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); - // minRes->at(0)->printBuffer("MIN for Truncated3"); - // maxRes->at(0)->printBuffer("MAX for Truncated3"); - + auto x0 = NDArrayFactory::create('c', {1000, 1000}); + auto x1 = NDArrayFactory::create('c', {1000, 1000}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, + &x0, 0.0f, 1.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, + &x1, 0.0f, 1.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + auto mean0 = x0.reduceNumber(reduce::Mean); + // mean0.printIndexedBuffer("0Mean 2.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation0 = + x0.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation0.printIndexedBuffer("0Deviation should be 4.0"); + + // ASSERT_FALSE(x0.equalsTo(nexp0)); + // ASSERT_FALSE(x0.equalsTo(nexp1)); + // ASSERT_FALSE(x0.equalsTo(nexp2)); + + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + // mean.printIndexedBuffer("Mean 2.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation = + x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + // deviation /= (double)x1.lengthOf(); + // deviation.printIndexedBuffer("Deviation should be 4.0"); + // x1.printIndexedBuffer("Distribution TN"); + ASSERT_NEAR(mean.e(0), 0.f, 0.01); + ASSERT_NEAR(deviation.e(0), 1.f, 0.5); + sd::ops::moments op; + auto result = op.evaluate({&x0}); + // result->at(0)->printBuffer("MEAN"); + // result->at(1)->printBuffer("VARIANCE"); + sd::ops::reduce_min minOp; + sd::ops::reduce_max maxOp; + + auto minRes = minOp.evaluate({&x1}, {}, {}, {}); + auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); + // minRes->at(0)->printBuffer("MIN for Truncated3"); + // maxRes->at(0)->printBuffer("MAX for Truncated3"); } TEST_F(RNGTests, Test_Truncated_3) { - auto x0 = NDArrayFactory::create('c', {2000, 2000}); - auto x1 = NDArrayFactory::create('c', {2000, 2000}); + auto x0 = NDArrayFactory::create('c', {2000, 2000}); + auto x1 = NDArrayFactory::create('c', {2000, 2000}); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, + &x0, 1.0f, 2.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, + &x1, 1.0f, 2.0f); - ASSERT_TRUE(x0.equalsTo(&x1)); + ASSERT_TRUE(x0.equalsTo(&x1)); - // Check up distribution - auto mean = x1.reduceNumber(reduce::Mean); - // mean.printIndexedBuffer("Mean 1.0"); - //auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + // Check up distribution + auto mean = x1.reduceNumber(reduce::Mean); + // mean.printIndexedBuffer("Mean 1.0"); + // auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); - ASSERT_NEAR(mean.e(0), 1.f, 0.001); - ASSERT_NEAR(deviation.e(0), 2.f, 0.3); + auto deviation = + x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + ASSERT_NEAR(mean.e(0), 1.f, 0.001); + ASSERT_NEAR(deviation.e(0), 2.f, 0.3); } #endif TEST_F(RNGTests, Test_Binomial_1) { - auto x0 = NDArrayFactory::create('c', {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngA, &x0, 3, 2.0f); - RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngB, &x1, 3, 2.0f); + RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngA, &x0, 3, + 2.0f); + RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngB, &x1, 3, + 2.0f); - ASSERT_TRUE(x0.equalsTo(&x1)); + ASSERT_TRUE(x0.equalsTo(&x1)); - //nexp2->printIndexedBuffer("nexp2"); - //x0.printIndexedBuffer("x0"); + // nexp2->printIndexedBuffer("nexp2"); + // x0.printIndexedBuffer("x0"); - ASSERT_FALSE(x0.equalsTo(nexp0)); - ASSERT_FALSE(x0.equalsTo(nexp1)); - ASSERT_FALSE(x0.equalsTo(nexp2)); + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); } - TEST_F(RNGTests, Test_Uniform_2) { - auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, + 2.0f); - sd::ops::LegacyRandomOp op(0); - auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}, {sd::DataType::FLOAT32}); + sd::ops::LegacyRandomOp op(0); + auto result = + op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}, {sd::DataType::FLOAT32}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(x1.isSameShape(z)); - ASSERT_TRUE(x1.equalsTo(z)); + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); } TEST_F(RNGTests, Test_Uniform_SGA_3) { - //auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); - - RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, -sd::DataTypeUtils::template max(), sd::DataTypeUtils::template max()); - auto minimumU = x1.reduceNumber(reduce::AMin); - minimumU.printBuffer("\nMinimum"); + // auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, + -sd::DataTypeUtils::template max(), + sd::DataTypeUtils::template max()); + auto minimumU = x1.reduceNumber(reduce::AMin); + minimumU.printBuffer("\nMinimum"); } TEST_F(RNGTests, Test_Gaussian_2) { - auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, + 1.0f, 2.0f); - sd::ops::LegacyRandomOp op(random::GaussianDistribution); - auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}); + sd::ops::LegacyRandomOp op(random::GaussianDistribution); + auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(x1.isSameShape(z)); - ASSERT_TRUE(x1.equalsTo(z)); + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); } TEST_F(RNGTests, Test_LogNorm_2) { - auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngB, &x1, + 1.0f, 2.0f); - sd::ops::LegacyRandomOp op(random::LogNormalDistribution); - auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}); + sd::ops::LegacyRandomOp op(random::LogNormalDistribution); + auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(x1.isSameShape(z)); - ASSERT_TRUE(x1.equalsTo(z)); + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); } TEST_F(RNGTests, Test_TruncatedNorm_2) { - auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, + &x1, 1.0f, 2.0f); - sd::ops::LegacyRandomOp op(random::TruncatedNormalDistribution); - auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}); + sd::ops::LegacyRandomOp op(random::TruncatedNormalDistribution); + auto result = op.execute(_rngA, {&input}, {1.0f, 2.0f}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(x1.isSameShape(z)); - ASSERT_TRUE(x1.equalsTo(z)); + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); } - TEST_F(RNGTests, Test_Binomial_2) { - auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngB, &x1, 3, 0.5f); + RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngB, &x1, 3, + 0.5f); - sd::ops::LegacyRandomOp op(random::BinomialDistributionEx); - auto result = op.execute(_rngA, {&input}, {0.5f}, {3}); + sd::ops::LegacyRandomOp op(random::BinomialDistributionEx); + auto result = op.execute(_rngA, {&input}, {0.5f}, {3}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(x1.isSameShape(z)); - ASSERT_TRUE(x1.equalsTo(z)); + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); } - TEST_F(RNGTests, Test_Bernoulli_2) { - auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); - auto x1 = NDArrayFactory::create('c', {10, 10}); + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); - RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngB, &x1, 0.5f); + RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngB, &x1, + 0.5f); - sd::ops::LegacyRandomOp op(random::BernoulliDistribution); - auto result = op.execute(_rngA, {&input}, {0.5f}, {}); + sd::ops::LegacyRandomOp op(random::BernoulliDistribution); + auto result = op.execute(_rngA, {&input}, {0.5f}, {}); - ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_TRUE(x1.isSameShape(z)); - ASSERT_TRUE(x1.equalsTo(z)); + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); } TEST_F(RNGTests, Test_GaussianDistribution_1) { - auto x = NDArrayFactory::create('c', {2}, {10, 10}); - auto exp0 = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + sd::ops::random_normal op; + auto result = op.evaluate({&x}, {0.0, 1.0f}, {}); + ASSERT_EQ(Status::OK(), result.status()); - sd::ops::random_normal op; - auto result = op.evaluate({&x}, {0.0, 1.0f}, {}); - ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); - auto z = result.at(0); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); - - - ASSERT_FALSE(nexp0->equalsTo(z)); - ASSERT_FALSE(nexp1->equalsTo(z)); - ASSERT_FALSE(nexp2->equalsTo(z)); - - + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); } TEST_F(RNGTests, Test_BernoulliDistribution_1) { - auto x = NDArrayFactory::create('c', {2}, {10, 10}); - auto exp0 = NDArrayFactory::create('c', {10, 10}); - + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); - sd::ops::random_bernoulli op; - auto result = op.evaluate({&x}, {0.5f}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::random_bernoulli op; + auto result = op.evaluate({&x}, {0.5f}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); + auto z = result.at(0); - ASSERT_FALSE(exp0.equalsTo(z)); + ASSERT_FALSE(exp0.equalsTo(z)); - ASSERT_FALSE(nexp0->equalsTo(z)); - ASSERT_FALSE(nexp1->equalsTo(z)); - ASSERT_FALSE(nexp2->equalsTo(z)); - - + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); } - TEST_F(RNGTests, Test_ExponentialDistribution_1) { - auto x = NDArrayFactory::create('c', {2}, {10, 10}); - auto exp0 = NDArrayFactory::create('c', {10, 10}); - - - sd::ops::random_exponential op; - auto result = op.evaluate({&x}, {0.25f}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); - // - //z->printBuffer("\nExponential1"); - auto mean = z.reduceNumber(reduce::Mean); - auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); - mean.printBuffer("Mean for exponential with param 0.25 (4 exp) is"); - variance.printBuffer("Variance for exponential with param 0.25 (16 exp) is"); - ASSERT_FALSE(nexp0->equalsTo(z)); - ASSERT_FALSE(nexp1->equalsTo(z)); - ASSERT_FALSE(nexp2->equalsTo(z)); - -// delete result; + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + + sd::ops::random_exponential op; + auto result = op.evaluate({&x}, {0.25f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + // + // z->printBuffer("\nExponential1"); + auto mean = z.reduceNumber(reduce::Mean); + auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 0.25 (4 exp) is"); + variance.printBuffer("Variance for exponential with param 0.25 (16 exp) is"); + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); + + // delete result; } TEST_F(RNGTests, Test_ExponentialDistribution_1_SGA) { - auto x = NDArrayFactory::create('c', {2}, {10, 10}); - auto exp0 = NDArrayFactory::create('c', {10, 10}); - - - sd::ops::random_exponential op; - auto result = op.evaluate({&x}, {1.f}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); - // - //z->printBuffer("\nExponential2"); - auto mean = z.reduceNumber(reduce::Mean); - auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); - mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); - variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); - ASSERT_FALSE(nexp0->equalsTo(z)); - ASSERT_FALSE(nexp1->equalsTo(z)); - ASSERT_FALSE(nexp2->equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + + sd::ops::random_exponential op; + auto result = op.evaluate({&x}, {1.f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + // + // z->printBuffer("\nExponential2"); + auto mean = z.reduceNumber(reduce::Mean); + auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); + variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); } TEST_F(RNGTests, Test_ExponentialDistribution_2_SGA) { - auto x = NDArrayFactory::create('c', {2}, {10, 10}); - auto exp0 = NDArrayFactory::create('c', {10, 10}); - RandomGenerator oc(2716049175077475646L, -6182841917129177862L); - - sd::ops::random_exponential op; - RandomLauncher::fillExponential(x.getContext(), oc, &exp0, 2.f); - auto result = op.evaluate({&x}, {1.f}, {0}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); - // -// z->printBuffer("\nExponential2+"); - auto mean = z.reduceNumber(reduce::Mean); - auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); - mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); - variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); - ASSERT_FALSE(nexp0->equalsTo(z)); - ASSERT_FALSE(nexp1->equalsTo(z)); - ASSERT_FALSE(nexp2->equalsTo(z)); - mean = exp0.reduceNumber(reduce::Mean); - variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); - mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is"); - variance.printBuffer("Variance for exponential with param 2. (1/2 exp) is"); + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + RandomGenerator oc(2716049175077475646L, -6182841917129177862L); + + sd::ops::random_exponential op; + RandomLauncher::fillExponential(x.getContext(), oc, &exp0, 2.f); + auto result = op.evaluate({&x}, {1.f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + // + // z->printBuffer("\nExponential2+"); + auto mean = z.reduceNumber(reduce::Mean); + auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); + variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); + mean = exp0.reduceNumber(reduce::Mean); + variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is"); + variance.printBuffer("Variance for exponential with param 2. (1/2 exp) is"); } TEST_F(RNGTests, Test_ExponentialDistribution_3_SGA) { - auto x = NDArrayFactory::create('c', {2}, {1000, 1000}); - auto exp0 = NDArrayFactory::create('c', {1000, 1000}); - RandomGenerator oc(2716049175077475646L, -6182841917129177862L); - auto expMean = NDArrayFactory::create(0.5f); - auto expVar = NDArrayFactory::create(0.25f); - sd::ops::random_exponential op; - RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 2.f); - - auto result = op.evaluate({&x}, {1.}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - //ASSERT_TRUE(exp0.isSameShape(z)); - //ASSERT_FALSE(exp0.equalsTo(z)); - // -// z->printBuffer("\nExponential2+"); - auto mean = z.reduceNumber(reduce::Mean); - auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); - mean.printBuffer("Mean"); - variance.printBuffer("Variance"); - ASSERT_NEAR(mean.e(0), 1.f, 1.e-2f); - ASSERT_NEAR(variance.e(0), 1.f, 1.e-2f); -// mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); -// variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); -// ASSERT_FALSE(nexp0->equalsTo(z)); -// ASSERT_FALSE(nexp1->equalsTo(z)); -// ASSERT_FALSE(nexp2->equalsTo(z)); - mean = exp0.reduceNumber(reduce::Mean); - variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); - mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is"); - variance.printBuffer("Variance for exponential with param 2. (1/4 exp) is"); - ASSERT_TRUE(mean.equalsTo(expMean, 1.e-3)); - ASSERT_TRUE(variance.equalsTo(expVar, 1.e-3)); - RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 1.f); - mean = exp0.reduceNumber(reduce::Mean); - variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); - mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); - variance.printBuffer("Variance for exponential with param 1.0 (1 exp) is"); + auto x = NDArrayFactory::create('c', {2}, {1000, 1000}); + auto exp0 = NDArrayFactory::create('c', {1000, 1000}); + RandomGenerator oc(2716049175077475646L, -6182841917129177862L); + auto expMean = NDArrayFactory::create(0.5f); + auto expVar = NDArrayFactory::create(0.25f); + sd::ops::random_exponential op; + RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 2.f); + + auto result = op.evaluate({&x}, {1.}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // ASSERT_TRUE(exp0.isSameShape(z)); + // ASSERT_FALSE(exp0.equalsTo(z)); + // + // z->printBuffer("\nExponential2+"); + auto mean = z.reduceNumber(reduce::Mean); + auto variance = z.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean"); + variance.printBuffer("Variance"); + ASSERT_NEAR(mean.e(0), 1.f, 1.e-2f); + ASSERT_NEAR(variance.e(0), 1.f, 1.e-2f); + // mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); + // variance.printBuffer("Variance for exponential with param 1. (1 exp) + // is"); ASSERT_FALSE(nexp0->equalsTo(z)); + // ASSERT_FALSE(nexp1->equalsTo(z)); + // ASSERT_FALSE(nexp2->equalsTo(z)); + mean = exp0.reduceNumber(reduce::Mean); + variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is"); + variance.printBuffer("Variance for exponential with param 2. (1/4 exp) is"); + ASSERT_TRUE(mean.equalsTo(expMean, 1.e-3)); + ASSERT_TRUE(variance.equalsTo(expVar, 1.e-3)); + RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 1.f); + mean = exp0.reduceNumber(reduce::Mean); + variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); + variance.printBuffer("Variance for exponential with param 1.0 (1 exp) is"); } TEST_F(RNGTests, Test_ExponentialDistribution_2) { - auto x = NDArrayFactory::create('c', {2}, {10, 10}); - auto y = NDArrayFactory::create('c', {10, 10}); - auto exp0 = NDArrayFactory::create('c', {10, 10}); - - y.assign(1.0); - + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto y = NDArrayFactory::create('c', {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); - sd::ops::random_exponential op; - auto result = op.evaluate({&x, &y}, {0.25f}, {0}); - ASSERT_EQ(Status::OK(), result.status()); + y.assign(1.0); - auto z = result.at(0); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); + sd::ops::random_exponential op; + auto result = op.evaluate({&x, &y}, {0.25f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); - ASSERT_FALSE(nexp0->equalsTo(z)); - ASSERT_FALSE(nexp1->equalsTo(z)); - ASSERT_FALSE(nexp2->equalsTo(z)); - - + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); } TEST_F(RNGTests, Test_PoissonDistribution_1) { - auto x = NDArrayFactory::create('c', {1}, {10}); - auto la = NDArrayFactory::create('c', {2, 3}); - auto exp0 = NDArrayFactory::create('c', {10, 2, 3}); - - la.linspace(1.0); + auto x = NDArrayFactory::create('c', {1}, {10}); + auto la = NDArrayFactory::create('c', {2, 3}); + auto exp0 = NDArrayFactory::create('c', {10, 2, 3}); + la.linspace(1.0); - sd::ops::random_poisson op; - auto result = op.evaluate({&x, &la}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::random_poisson op; + auto result = op.evaluate({&x, &la}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Poisson distribution"); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); - - + auto z = result.at(0); + // z->printIndexedBuffer("Poisson distribution"); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); } TEST_F(RNGTests, Test_GammaDistribution_1) { - auto x = NDArrayFactory::create('c', {1}, {10}); - auto al = NDArrayFactory::create('c', {2, 3}); - auto exp0 = NDArrayFactory::create('c', {10, 2, 3}); - - al.linspace(1.0); + auto x = NDArrayFactory::create('c', {1}, {10}); + auto al = NDArrayFactory::create('c', {2, 3}); + auto exp0 = NDArrayFactory::create('c', {10, 2, 3}); + al.linspace(1.0); - sd::ops::random_gamma op; - auto result = op.evaluate({&x, &al}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); + sd::ops::random_gamma op; + auto result = op.evaluate({&x, &al}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); - auto z = result.at(0); -// z->printIndexedBuffer("Gamma distribution"); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); - - + auto z = result.at(0); + // z->printIndexedBuffer("Gamma distribution"); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); } TEST_F(RNGTests, Test_GammaDistribution_2) { - auto x = NDArrayFactory::create('c', {1}, {10}); - auto al = NDArrayFactory::create('c', {2, 3}); - auto be = NDArrayFactory::create('c', {2, 3}); - auto exp0 = NDArrayFactory::create('c', {10, 2, 3}); - - al.linspace(1.0); - be.assign(1.0); - - sd::ops::random_gamma op; - auto result = op.evaluate({&x, &al, &be}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("Gamma distribution"); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {1}, {10}); + auto al = NDArrayFactory::create('c', {2, 3}); + auto be = NDArrayFactory::create('c', {2, 3}); + auto exp0 = NDArrayFactory::create('c', {10, 2, 3}); + + al.linspace(1.0); + be.assign(1.0); + + sd::ops::random_gamma op; + auto result = op.evaluate({&x, &al, &be}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printIndexedBuffer("Gamma distribution"); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); } TEST_F(RNGTests, Test_GammaDistribution_3) { - auto x = NDArrayFactory::create('c', {1}, {10}); - auto al = NDArrayFactory::create('c', {3, 1}); - auto be = NDArrayFactory::create('c', {1, 2}); - auto exp0 = NDArrayFactory::create('c', {10, 3, 2}); - - al.linspace(1.0); - be.assign(2.0); - - sd::ops::random_gamma op; - auto result = op.evaluate({&x, &al, &be}, {}, {}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); -// z->printIndexedBuffer("Gamma distribution"); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {1}, {10}); + auto al = NDArrayFactory::create('c', {3, 1}); + auto be = NDArrayFactory::create('c', {1, 2}); + auto exp0 = NDArrayFactory::create('c', {10, 3, 2}); + + al.linspace(1.0); + be.assign(2.0); + + sd::ops::random_gamma op; + auto result = op.evaluate({&x, &al, &be}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printIndexedBuffer("Gamma distribution"); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); } TEST_F(RNGTests, Test_UniformDistribution_04) { - auto x = NDArrayFactory::create('c', {1}, {10}); - auto al = NDArrayFactory::create(1); - auto be = NDArrayFactory::create(20); - auto exp0 = NDArrayFactory::create('c', {10}); - - - sd::ops::randomuniform op; - auto result = op.evaluate({&x, &al, &be}, {}, {DataType::INT32}); - ASSERT_EQ(Status::OK(), result.status()); - - auto z = result.at(0); - ASSERT_TRUE(exp0.isSameShape(z)); - ASSERT_FALSE(exp0.equalsTo(z)); - - + auto x = NDArrayFactory::create('c', {1}, {10}); + auto al = NDArrayFactory::create(1); + auto be = NDArrayFactory::create(20); + auto exp0 = NDArrayFactory::create('c', {10}); + + sd::ops::randomuniform op; + auto result = op.evaluate({&x, &al, &be}, {}, {DataType::INT32}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); } namespace sd { - namespace tests { - static void fillList(Nd4jLong seed, int numberOfArrays, std::vector &shape, std::vector &list, sd::graph::RandomGenerator *rng) { - rng->setSeed((int) seed); - - for (int i = 0; i < numberOfArrays; i++) { - auto arrayI = NDArrayFactory::create(shape); - auto arrayR = NDArrayFactory::create_('c', shape); - auto min = NDArrayFactory::create(0.0); - auto max = NDArrayFactory::create(1.0); - sd::ops::randomuniform op; - op.execute(*rng, {&arrayI, &min, &max}, {arrayR}, {}, {DataType::DOUBLE}, {}, {}, false); - - list.emplace_back(arrayR); - } - }; - } -} +namespace tests { +static void fillList(Nd4jLong seed, int numberOfArrays, + std::vector& shape, std::vector& list, + sd::graph::RandomGenerator* rng) { + rng->setSeed((int)seed); + + for (int i = 0; i < numberOfArrays; i++) { + auto arrayI = NDArrayFactory::create(shape); + auto arrayR = NDArrayFactory::create_('c', shape); + auto min = NDArrayFactory::create(0.0); + auto max = NDArrayFactory::create(1.0); + sd::ops::randomuniform op; + op.execute(*rng, {&arrayI, &min, &max}, {arrayR}, {}, {DataType::DOUBLE}, + {}, {}, false); + + list.emplace_back(arrayR); + } +}; +} // namespace tests +} // namespace sd TEST_F(RNGTests, Test_Reproducibility_1) { - Nd4jLong seed = 123; + Nd4jLong seed = 123; - std::vector shape = {32, 3, 28, 28}; - sd::graph::RandomGenerator rng; + std::vector shape = {32, 3, 28, 28}; + sd::graph::RandomGenerator rng; - std::vector expList; - sd::tests::fillList(seed, 10, shape, expList, &rng); + std::vector expList; + sd::tests::fillList(seed, 10, shape, expList, &rng); - for (int e = 0; e < 2; e++) { - std::vector trialList; - sd::tests::fillList(seed, 10, shape, trialList, &rng); + for (int e = 0; e < 2; e++) { + std::vector trialList; + sd::tests::fillList(seed, 10, shape, trialList, &rng); - for (int a = 0; a < expList.size(); a++) { - auto arrayE = expList[a]; - auto arrayT = trialList[a]; + for (int a = 0; a < expList.size(); a++) { + auto arrayE = expList[a]; + auto arrayT = trialList[a]; - bool t = arrayE->equalsTo(arrayT); - if (!t) { - // nd4j_printf("Failed at iteration [%i] for array [%i]\n", e, a); - ASSERT_TRUE(false); - } + bool t = arrayE->equalsTo(arrayT); + if (!t) { + // nd4j_printf("Failed at iteration [%i] for array [%i]\n", e, a); + ASSERT_TRUE(false); + } - delete arrayT; - } + delete arrayT; } + } - for (auto v: expList) - delete v; + for (auto v : expList) delete v; } #ifndef DEBUG_BUILD TEST_F(RNGTests, Test_Reproducibility_2) { - Nd4jLong seed = 123; + Nd4jLong seed = 123; - std::vector shape = {32, 3, 64, 64}; - sd::graph::RandomGenerator rng; + std::vector shape = {32, 3, 64, 64}; + sd::graph::RandomGenerator rng; - std::vector expList; - sd::tests::fillList(seed, 10, shape, expList, &rng); + std::vector expList; + sd::tests::fillList(seed, 10, shape, expList, &rng); - for (int e = 0; e < 2; e++) { - std::vector trialList; - sd::tests::fillList(seed, 10, shape, trialList, &rng); + for (int e = 0; e < 2; e++) { + std::vector trialList; + sd::tests::fillList(seed, 10, shape, trialList, &rng); - for (int a = 0; a < expList.size(); a++) { - auto arrayE = expList[a]; - auto arrayT = trialList[a]; + for (int a = 0; a < expList.size(); a++) { + auto arrayE = expList[a]; + auto arrayT = trialList[a]; - bool t = arrayE->equalsTo(arrayT); - if (!t) { - // nd4j_printf("Failed at iteration [%i] for array [%i]\n", e, a); + bool t = arrayE->equalsTo(arrayT); + if (!t) { + // nd4j_printf("Failed at iteration [%i] for array [%i]\n", e, a); - for (Nd4jLong f = 0; f < arrayE->lengthOf(); f++) { - double x = arrayE->e(f); - double y = arrayT->e(f); + for (Nd4jLong f = 0; f < arrayE->lengthOf(); f++) { + double x = arrayE->e(f); + double y = arrayT->e(f); - if (sd::math::nd4j_re(x, y) > 0.1) { - // nd4j_printf("E[%lld] %f != T[%lld] %f\n", (long long) f, (float) x, (long long) f, (float) y); - throw std::runtime_error("boom"); - } - } + if (sd::math::nd4j_re(x, y) > 0.1) { + // nd4j_printf("E[%lld] %f != T[%lld] %f\n", (long long) f, (float) + // x, (long long) f, (float) y); + throw std::runtime_error("boom"); + } + } - // just breaker, since test failed - ASSERT_TRUE(false); - } + // just breaker, since test failed + ASSERT_TRUE(false); + } - delete arrayT; - } + delete arrayT; } + } - for (auto v: expList) - delete v; + for (auto v : expList) delete v; } TEST_F(RNGTests, Test_Uniform_4) { - auto x1 = NDArrayFactory::create('c', {1000000}); + auto x1 = NDArrayFactory::create('c', {1000000}); - RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0, 2.0); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0, + 2.0); - /* Check up distribution */ - auto mean = x1.reduceNumber(reduce::Mean); - // mean.printIndexedBuffer("Mean should be 1.5"); - auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + // mean.printIndexedBuffer("Mean should be 1.5"); + auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); - auto deviation = x1.varianceNumber(variance::SummaryStatsVariance, false); - //deviation /= (double)x1.lengthOf(); - // deviation.printIndexedBuffer("Deviation should be 1/12 (0.083333)"); + auto deviation = x1.varianceNumber(variance::SummaryStatsVariance, false); + // deviation /= (double)x1.lengthOf(); + // deviation.printIndexedBuffer("Deviation should be 1/12 (0.083333)"); - ASSERT_NEAR(mean.e(0), 1.5, 1e-3); - ASSERT_NEAR(1/12., deviation.e(0), 1e-3); + ASSERT_NEAR(mean.e(0), 1.5, 1e-3); + ASSERT_NEAR(1 / 12., deviation.e(0), 1e-3); } #endif TEST_F(RNGTests, test_choice_1) { - const auto x = NDArrayFactory::linspace(0, 10, 11); - const auto prob = NDArrayFactory::valueOf({11}, 1.0/11, 'c'); - auto z = NDArrayFactory::create('c', {1000}); - - RandomGenerator rng(119, 256); - NativeOpExecutioner::execRandom(sd::LaunchContext ::defaultContext(), random::Choice, &rng, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), prob->buffer(), prob->shapeInfo(), prob->specialBuffer(), prob->specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr); - - // z.printIndexedBuffer("z"); - - delete x; - delete prob; + const auto x = NDArrayFactory::linspace(0, 10, 11); + const auto prob = NDArrayFactory::valueOf({11}, 1.0 / 11, 'c'); + auto z = NDArrayFactory::create('c', {1000}); + + RandomGenerator rng(119, 256); + NativeOpExecutioner::execRandom( + sd::LaunchContext ::defaultContext(), random::Choice, &rng, x->buffer(), + x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), prob->buffer(), + prob->shapeInfo(), prob->specialBuffer(), prob->specialShapeInfo(), + z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr); + + // z.printIndexedBuffer("z"); + + delete x; + delete prob; } TEST_F(RNGTests, test_uniform_119) { - auto x = NDArrayFactory::create('c', {2}, {1, 5}); - auto z = NDArrayFactory::create('c', {1, 5}); + auto x = NDArrayFactory::create('c', {2}, {1, 5}); + auto z = NDArrayFactory::create('c', {1, 5}); - - sd::ops::randomuniform op; - auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {}); - ASSERT_EQ(Status::OK(), status); + sd::ops::randomuniform op; + auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {}); + ASSERT_EQ(Status::OK(), status); } TEST_F(RNGTests, test_multinomial_1) { - - NDArray probs('f', { 3, 3 }, { 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3 }, sd::DataType::FLOAT32); - NDArray expected('f', { 3, 3 }, { 0., 1, 2, 2, 0, 0, 1, 2, 1 }, sd::DataType::INT64); - NDArray output('f', { 3, 3 }, sd::DataType::INT64); - NDArray samples('f', { 1 }, std::vector({3}), sd::DataType::INT32); - - sd::ops::random_multinomial op; - RandomGenerator rng(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0}, {}, {INT64}, false) ); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, sd::DataType::FLOAT32); - NDArray expectedZ('c', { 3, 3 }, { 0., 0, 0, 0, 0, 0, 0, 0, 0 }, sd::DataType::INT64); - - auto result = op.evaluate({ &probsZ, &samples }, { }, { 1 }, {}, {INT64}); - auto outputZ = result.at(0); - - ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(expectedZ.isSameShape(outputZ)); - ASSERT_TRUE(expectedZ.equalsTo(outputZ)); - + NDArray probs('f', {3, 3}, {0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3}, + sd::DataType::FLOAT32); + NDArray expected('f', {3, 3}, {0., 1, 2, 2, 0, 0, 1, 2, 1}, + sd::DataType::INT64); + NDArray output('f', {3, 3}, sd::DataType::INT64); + NDArray samples('f', {1}, std::vector({3}), sd::DataType::INT32); + + sd::ops::random_multinomial op; + RandomGenerator rng(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&output}, {}, + {0}, {}, {INT64}, false)); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + NDArray probsZ('c', {1, 3}, {0.3, 0.3, 0.3}, sd::DataType::FLOAT32); + NDArray expectedZ('c', {3, 3}, {0., 0, 0, 0, 0, 0, 0, 0, 0}, + sd::DataType::INT64); + + auto result = op.evaluate({&probsZ, &samples}, {}, {1}, {}, {INT64}); + auto outputZ = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expectedZ.isSameShape(outputZ)); + ASSERT_TRUE(expectedZ.equalsTo(outputZ)); } TEST_F(RNGTests, test_multinomial_2) { - - NDArray samples('c', { 1 }, std::vector{ 20 }, sd::DataType::INT32); - NDArray probs('c', { 3, 5 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, sd::DataType::FLOAT32); - NDArray expected('c', { 3, 20 }, { 0, 2, 0, 2, 0, 4, 2, 0, 1, 2, 0, 2, 3, 0, 0, 2, 4, 4, 1, 0, 2, 3, 2, 3, 0, 1, 3, 1, 1, 1, 2, 4, 3, 3, 1, 4, 4, 2, 0, 0, 3, 3, 3, 0, 0, 2, 2, 3, 3, 0, 0, 2, 3, 4, 2, 2, 3, 2, 1, 2 }, sd::DataType::INT64); - NDArray output('c', { 3, 20 }, sd::DataType::INT64); - - sd::ops::random_multinomial op; - RandomGenerator rng(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0 }, {}, {INT64}, false)); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - NDArray probs2('c', { 5, 3 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, sd::DataType::FLOAT32); - NDArray expected2('c', { 20, 3 }, { 0, 2, 3, 2, 3, 3, 0, 2, 3, 2, 3, 0, 0, 0, 0, 4, 1, 2, 2, 3, 2, 3, 1, 3, 1, 1, 3, 2, 1, 0, 0, 2, 0, 2, 4, 2, 3, 3, 3, 0, 3, 4, 0, 1, 2, 2, 0, 2, 4, 4, 0, 4, 2, 2, 1, 0, 1, 0, 0, 2 }, sd::DataType::INT64); - NDArray output2('c', { 20, 3 }, sd::DataType::INT64); - - rng.setStates(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1 }, {}, {INT64}, false)); - ASSERT_TRUE(expected2.isSameShape(output2)); - ASSERT_TRUE(expected2.equalsTo(output2)); + NDArray samples('c', {1}, std::vector{20}, sd::DataType::INT32); + NDArray probs('c', {3, 5}, + {0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, + 0.25, 0.25, 0.5}, + sd::DataType::FLOAT32); + NDArray expected('c', {3, 20}, + {0, 2, 0, 2, 0, 4, 2, 0, 1, 2, 0, 2, 3, 0, 0, 2, 4, 4, 1, 0, + 2, 3, 2, 3, 0, 1, 3, 1, 1, 1, 2, 4, 3, 3, 1, 4, 4, 2, 0, 0, + 3, 3, 3, 0, 0, 2, 2, 3, 3, 0, 0, 2, 3, 4, 2, 2, 3, 2, 1, 2}, + sd::DataType::INT64); + NDArray output('c', {3, 20}, sd::DataType::INT64); + + sd::ops::random_multinomial op; + RandomGenerator rng(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&output}, {}, + {0}, {}, {INT64}, false)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + NDArray probs2('c', {5, 3}, + {0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, + 0.25, 0.25, 0.5}, + sd::DataType::FLOAT32); + NDArray expected2('c', {20, 3}, {0, 2, 3, 2, 3, 3, 0, 2, 3, 2, 3, 0, 0, 0, 0, + 4, 1, 2, 2, 3, 2, 3, 1, 3, 1, 1, 3, 2, 1, 0, + 0, 2, 0, 2, 4, 2, 3, 3, 3, 0, 3, 4, 0, 1, 2, + 2, 0, 2, 4, 4, 0, 4, 2, 2, 1, 0, 1, 0, 0, 2}, + sd::DataType::INT64); + NDArray output2('c', {20, 3}, sd::DataType::INT64); + + rng.setStates(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs2, &samples}, {&output2}, {}, + {1}, {}, {INT64}, false)); + ASSERT_TRUE(expected2.isSameShape(output2)); + ASSERT_TRUE(expected2.equalsTo(output2)); } TEST_F(RNGTests, test_multinomial_3) { - - NDArray probs('c', { 4, 3 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, sd::DataType::FLOAT32); - NDArray expected('c', { 4, 5 }, sd::DataType::INT64); - NDArray output('c', { 4, 5 }, sd::DataType::INT64); - NDArray samples('c', { 1 }, std::vector{ 5 }, sd::DataType::INT32); - RandomGenerator rng(1234, 1234); - - sd::ops::random_multinomial op; - - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0 }, {}, {INT64}, false)); - - rng.setStates(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0 }, {}, {INT64}, false)); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray probs('c', {4, 3}, + {0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3}, + sd::DataType::FLOAT32); + NDArray expected('c', {4, 5}, sd::DataType::INT64); + NDArray output('c', {4, 5}, sd::DataType::INT64); + NDArray samples('c', {1}, std::vector{5}, sd::DataType::INT32); + RandomGenerator rng(1234, 1234); + + sd::ops::random_multinomial op; + + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&expected}, {}, + {0}, {}, {INT64}, false)); + + rng.setStates(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&output}, {}, + {0}, {}, {INT64}, false)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } TEST_F(RNGTests, test_multinomial_4) { - - NDArray probs('c', { 3, 4 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, sd::DataType::FLOAT32); - NDArray expected('c', { 5, 4 }, sd::DataType::INT64); - NDArray output('c', { 5, 4 }, sd::DataType::INT64); - NDArray samples('c', { 1 }, std::vector{ 5 }, sd::DataType::INT32); - - RandomGenerator rng(1234, 1234); - sd::ops::random_multinomial op; - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 1 }, {}, {INT64}, false)); - - rng.setStates(1234, 1234); - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, {INT64}, false)); - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); + NDArray probs('c', {3, 4}, + {0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3}, + sd::DataType::FLOAT32); + NDArray expected('c', {5, 4}, sd::DataType::INT64); + NDArray output('c', {5, 4}, sd::DataType::INT64); + NDArray samples('c', {1}, std::vector{5}, sd::DataType::INT32); + + RandomGenerator rng(1234, 1234); + sd::ops::random_multinomial op; + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&expected}, {}, + {1}, {}, {INT64}, false)); + + rng.setStates(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&output}, {}, + {1}, {}, {INT64}, false)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); } TEST_F(RNGTests, test_multinomial_5) { - // multinomial as binomial if 2 classes used - int batchValue = 1; - int ClassValue = 2; - int Samples = 100000; - - NDArray samples('c', { 1 }, std::vector{ 1.*Samples }, sd::DataType::INT32); - - NDArray probs('c', { ClassValue, batchValue }, { 1.0, 1.0 }, sd::DataType::FLOAT32); - - sd::ops::random_multinomial op; - - NDArray output('c', { Samples, batchValue }, sd::DataType::INT64); - RandomGenerator rng(1234, 1234); - - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, {}, false)); - - auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false); - auto mean = output.meanNumber(); - // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); - // theoretical values for binomial - ASSERT_NEAR(0.5, deviation.e(0), 4e-3); // 1000000 3e-3); - ASSERT_NEAR(0.5, mean.e(0), 4e-3); // 1000000 3e-3); - - for (int i = 0; i < output.lengthOf(); i++) { - auto value = output.e(i); - ASSERT_TRUE(value >= 0 && value < ClassValue); - } - - auto resultR = op.evaluate({ &probs, &samples }, { }, { 1 }); - auto outputR = resultR.at(0); - ASSERT_EQ(Status::OK(), resultR.status()); - - deviation = outputR.varianceNumber(variance::SummaryStatsStandardDeviation, false); - mean = outputR.meanNumber(); - // printf("Random seed - Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); - ASSERT_NEAR(0.5, deviation.e(0), 45e-3); // 1000000 35e-3); - ASSERT_NEAR(0.5, mean.e(0), 45e-3); // 1000000 35e-3); - - for (int i = 0; i < outputR.lengthOf(); i++) { - auto value = outputR.e(i); - ASSERT_TRUE(value >= 0 && value < ClassValue); - } - + // multinomial as binomial if 2 classes used + int batchValue = 1; + int ClassValue = 2; + int Samples = 100000; + + NDArray samples('c', {1}, std::vector{1. * Samples}, + sd::DataType::INT32); + + NDArray probs('c', {ClassValue, batchValue}, {1.0, 1.0}, + sd::DataType::FLOAT32); + + sd::ops::random_multinomial op; + + NDArray output('c', {Samples, batchValue}, sd::DataType::INT64); + RandomGenerator rng(1234, 1234); + + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&output}, {}, + {1}, {}, {}, false)); + + auto deviation = + output.varianceNumber(variance::SummaryStatsStandardDeviation, false); + auto mean = output.meanNumber(); + // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); + // theoretical values for binomial + ASSERT_NEAR(0.5, deviation.e(0), 4e-3); // 1000000 3e-3); + ASSERT_NEAR(0.5, mean.e(0), 4e-3); // 1000000 3e-3); + + for (int i = 0; i < output.lengthOf(); i++) { + auto value = output.e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + } + + auto resultR = op.evaluate({&probs, &samples}, {}, {1}); + auto outputR = resultR.at(0); + ASSERT_EQ(Status::OK(), resultR.status()); + + deviation = + outputR.varianceNumber(variance::SummaryStatsStandardDeviation, false); + mean = outputR.meanNumber(); + // printf("Random seed - Var: %f Mean: %f \n", deviation.e(0), + // mean.e(0)); + ASSERT_NEAR(0.5, deviation.e(0), 45e-3); // 1000000 35e-3); + ASSERT_NEAR(0.5, mean.e(0), 45e-3); // 1000000 35e-3); + + for (int i = 0; i < outputR.lengthOf(); i++) { + auto value = outputR.e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + } } - TEST_F(RNGTests, test_multinomial_6) { - - int batchValue = 1; - int ClassValue = 5; - int Samples = 100000; - - NDArray samples('c', { 1 }, std::vector{ 1. * Samples }, sd::DataType::INT32); - - sd::ops::random_multinomial op; - NDArray probExpect('c', { ClassValue }, { 0.058, 0.096, 0.1576, 0.2598, 0.4287 }, sd::DataType::DOUBLE); - - // without seed - NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32); - - auto resultR = op.evaluate({ &probsR, &samples }, { }, { 0 }); - auto outputR = resultR.at(0); - ASSERT_EQ(Status::OK(), resultR.status()); - - NDArray countsR('c', { ClassValue }, { 0., 0, 0, 0, 0 }, sd::DataType::DOUBLE); - - for (int i = 0; i < outputR.lengthOf(); i++) { - auto value = outputR.e(i); - ASSERT_TRUE(value >= 0 && value < ClassValue); - double* z = countsR.bufferAsT(); - z[value] += 1; - } - - for (int i = 0; i < countsR.lengthOf(); i++) { - auto c = countsR.e(i); - auto p = probExpect.e(i); - // printf("Get freq : %f Expect freq: %f \n", c / Samples, p); - ASSERT_NEAR((c / Samples), p, 45e-3); // 1000000 35e-3); - } - - auto deviation = outputR.varianceNumber(variance::SummaryStatsStandardDeviation, false); - auto mean = outputR.meanNumber(); - // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); - ASSERT_NEAR(1.2175, deviation.e(0), 45e-3); // 1000000 35e-3); - ASSERT_NEAR(2.906, mean.e(0), 45e-3); // 1000000 35e-3); - - - - RandomGenerator rng(1234, 1234); - NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32); - NDArray output('c', { batchValue, Samples }, sd::DataType::INT64); - - ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0 }, {}, {INT64}, false)); - - NDArray counts('c', { ClassValue }, { 0., 0, 0, 0, 0 }, sd::DataType::DOUBLE); - - for (int i = 0; i < output.lengthOf(); i++) { - auto value = output.e(i); - ASSERT_TRUE(value >= 0 && value < ClassValue); - double* z = counts.bufferAsT(); - z[value] += 1; - } - - for (int i = 0; i < counts.lengthOf(); i++) { - auto c = counts.e(i); - auto p = probExpect.e(i); - // printf("Get freq : %f Expect freq: %f \n", c / Samples, p); - ASSERT_NEAR((c / Samples), p, 4e-3); // 1000000 3e-3); - } - - deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false); - mean = output.meanNumber(); - // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); - ASSERT_NEAR(1.2175, deviation.e(0), 5e-3); // 1000000 3e-3); - ASSERT_NEAR(2.906, mean.e(0), 5e-3); // 1000000 3e-3); + int batchValue = 1; + int ClassValue = 5; + int Samples = 100000; + + NDArray samples('c', {1}, std::vector{1. * Samples}, + sd::DataType::INT32); + + sd::ops::random_multinomial op; + NDArray probExpect('c', {ClassValue}, {0.058, 0.096, 0.1576, 0.2598, 0.4287}, + sd::DataType::DOUBLE); + + // without seed + NDArray probsR('c', {batchValue, ClassValue}, {1., 1.5, 2., 2.5, 3.}, + sd::DataType::FLOAT32); + + auto resultR = op.evaluate({&probsR, &samples}, {}, {0}); + auto outputR = resultR.at(0); + ASSERT_EQ(Status::OK(), resultR.status()); + + NDArray countsR('c', {ClassValue}, {0., 0, 0, 0, 0}, sd::DataType::DOUBLE); + + for (int i = 0; i < outputR.lengthOf(); i++) { + auto value = outputR.e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + double* z = countsR.bufferAsT(); + z[value] += 1; + } + + for (int i = 0; i < countsR.lengthOf(); i++) { + auto c = countsR.e(i); + auto p = probExpect.e(i); + // printf("Get freq : %f Expect freq: %f \n", c / Samples, p); + ASSERT_NEAR((c / Samples), p, 45e-3); // 1000000 35e-3); + } + + auto deviation = + outputR.varianceNumber(variance::SummaryStatsStandardDeviation, false); + auto mean = outputR.meanNumber(); + // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); + ASSERT_NEAR(1.2175, deviation.e(0), 45e-3); // 1000000 35e-3); + ASSERT_NEAR(2.906, mean.e(0), 45e-3); // 1000000 35e-3); + + RandomGenerator rng(1234, 1234); + NDArray probs('c', {batchValue, ClassValue}, {1., 1.5, 2., 2.5, 3.}, + sd::DataType::FLOAT32); + NDArray output('c', {batchValue, Samples}, sd::DataType::INT64); + + ASSERT_EQ(Status::OK(), op.execute(rng, {&probs, &samples}, {&output}, {}, + {0}, {}, {INT64}, false)); + + NDArray counts('c', {ClassValue}, {0., 0, 0, 0, 0}, sd::DataType::DOUBLE); + + for (int i = 0; i < output.lengthOf(); i++) { + auto value = output.e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + double* z = counts.bufferAsT(); + z[value] += 1; + } + + for (int i = 0; i < counts.lengthOf(); i++) { + auto c = counts.e(i); + auto p = probExpect.e(i); + // printf("Get freq : %f Expect freq: %f \n", c / Samples, p); + ASSERT_NEAR((c / Samples), p, 4e-3); // 1000000 3e-3); + } + + deviation = + output.varianceNumber(variance::SummaryStatsStandardDeviation, false); + mean = output.meanNumber(); + // printf("Var: %f Mean: %f \n", deviation.e(0), mean.e(0)); + ASSERT_NEAR(1.2175, deviation.e(0), 5e-3); // 1000000 3e-3); + ASSERT_NEAR(2.906, mean.e(0), 5e-3); // 1000000 3e-3); } diff --git a/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp b/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp index bfee646cb7a8..ef4706aaaed5 100644 --- a/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ResultSetTests.cpp @@ -18,32 +18,31 @@ // Created by raver on 4/18/2019. // -#include "testlayers.h" #include #include #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class ResultSetTests : public testing::Test { -public: - + public: }; TEST_F(ResultSetTests, basic_test_1) { - auto x = NDArrayFactory::create('c', {3, 5}); + auto x = NDArrayFactory::create('c', {3, 5}); - auto tensors = x.allTensorsAlongDimension({1}); - ASSERT_EQ(3, tensors.size()); + auto tensors = x.allTensorsAlongDimension({1}); + ASSERT_EQ(3, tensors.size()); - ResultSet set = tensors; - ASSERT_EQ(3, tensors.size()); - ASSERT_EQ(3, set.size()); + ResultSet set = tensors; + ASSERT_EQ(3, tensors.size()); + ASSERT_EQ(3, set.size()); - for (int e = 0; e < set.size(); e++) - ASSERT_EQ(5, set.at(e).lengthOf()); + for (int e = 0; e < set.size(); e++) ASSERT_EQ(5, set.at(e).lengthOf()); - for (int e = 0; e < tensors.size(); e++) - ASSERT_EQ(5, tensors.at(e).lengthOf()); + for (int e = 0; e < tensors.size(); e++) + ASSERT_EQ(5, tensors.at(e).lengthOf()); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/SanityTests.cpp b/libnd4j/tests_cpu/layers_tests/SanityTests.cpp index 47cbd5f0518e..2c076bcd73e5 100644 --- a/libnd4j/tests_cpu/layers_tests/SanityTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/SanityTests.cpp @@ -18,35 +18,37 @@ // Created by raver119 on 13/11/17. // -#include "testlayers.h" #include #include #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class SanityTests : public testing::Test { -public: - + public: }; TEST_F(SanityTests, VariableSpace_2) { - VariableSpace variableSpace; - variableSpace.putVariable(1, NDArrayFactory::create('c', {3, 3})); - variableSpace.putVariable({1, 1}, NDArrayFactory::create('c', {3, 3})); + VariableSpace variableSpace; + variableSpace.putVariable(1, NDArrayFactory::create('c', {3, 3})); + variableSpace.putVariable({1, 1}, NDArrayFactory::create('c', {3, 3})); - std::pair pair(1, 2); - variableSpace.putVariable(pair, NDArrayFactory::create('c', {3, 3})); + std::pair pair(1, 2); + variableSpace.putVariable(pair, NDArrayFactory::create('c', {3, 3})); } - TEST_F(SanityTests, Graph_1) { - Graph graph; + Graph graph; - graph.variableSpace().putVariable(1, NDArrayFactory::create('c', {3, 3})); - graph.variableSpace().putVariable({1, 1}, NDArrayFactory::create('c', {3, 3})); + graph.variableSpace().putVariable(1, + NDArrayFactory::create('c', {3, 3})); + graph.variableSpace().putVariable({1, 1}, + NDArrayFactory::create('c', {3, 3})); - std::pair pair(1, 2); - graph.variableSpace().putVariable(pair, NDArrayFactory::create('c', {3, 3})); + std::pair pair(1, 2); + graph.variableSpace().putVariable(pair, + NDArrayFactory::create('c', {3, 3})); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp b/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp index 937ca4675608..fde5a69bbd2b 100644 --- a/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ScalarTests.cpp @@ -18,219 +18,202 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class ScalarTests : public testing::Test { -public: - + public: }; TEST_F(ScalarTests, Test_Create_1) { - auto x = NDArrayFactory::create(2.0f); - - ASSERT_EQ(0, x.rankOf()); - ASSERT_EQ(1, x.lengthOf()); - ASSERT_TRUE(x.isScalar()); - ASSERT_FALSE(x.isVector()); - ASSERT_FALSE(x.isRowVector()); - ASSERT_FALSE(x.isColumnVector()); - ASSERT_FALSE(x.isMatrix()); + auto x = NDArrayFactory::create(2.0f); + + ASSERT_EQ(0, x.rankOf()); + ASSERT_EQ(1, x.lengthOf()); + ASSERT_TRUE(x.isScalar()); + ASSERT_FALSE(x.isVector()); + ASSERT_FALSE(x.isRowVector()); + ASSERT_FALSE(x.isColumnVector()); + ASSERT_FALSE(x.isMatrix()); } TEST_F(ScalarTests, Test_Add_1) { - auto x = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create(5.0f); + auto x = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create(5.0f); - x += 3.0f; + x += 3.0f; - ASSERT_NEAR(5.0f, x.e(0), 1e-5f); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_NEAR(5.0f, x.e(0), 1e-5f); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(ScalarTests, Test_Add_2) { - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create(3.0f); - auto exp = NDArrayFactory::create(5.0f); + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create(5.0f); - x += y; + x += y; - ASSERT_NEAR(5.0f, x.e(0), 1e-5f); - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_NEAR(5.0f, x.e(0), 1e-5f); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(ScalarTests, Test_Add_3) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto y = NDArrayFactory::create(3.0f); - auto exp = NDArrayFactory::create('c', {3}, {4, 5, 6}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto y = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create('c', {3}, {4, 5, 6}); - x += y; + x += y; - ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(ScalarTests, Test_EQ_1) { - auto x = NDArrayFactory::create(2.0f); - auto y = NDArrayFactory::create(3.0f); + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create(3.0f); - ASSERT_TRUE(y.isSameShape(&x)); - ASSERT_FALSE(y.equalsTo(&x)); + ASSERT_TRUE(y.isSameShape(&x)); + ASSERT_FALSE(y.equalsTo(&x)); } TEST_F(ScalarTests, Test_Concat_1) { - auto t = NDArrayFactory::create(1.0f); - auto u = NDArrayFactory::create(2.0f); - auto v = NDArrayFactory::create(3.0f); - auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - - sd::ops::concat op; - auto result = op.evaluate({&t, &u, &v}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto t = NDArrayFactory::create(1.0f); + auto u = NDArrayFactory::create(2.0f); + auto v = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto z = result.at(0); + sd::ops::concat op; + auto result = op.evaluate({&t, &u, &v}, {}, {0}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ScalarTests, Test_Concat_2) { - auto t = NDArrayFactory::create(1.0f); - auto u = NDArrayFactory::create('c', {3}, {2, 3, 4}); - auto v = NDArrayFactory::create(5.0f); - auto exp = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - - sd::ops::concat op; - auto result = op.evaluate({&t, &u, &v}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto t = NDArrayFactory::create(1.0f); + auto u = NDArrayFactory::create('c', {3}, {2, 3, 4}); + auto v = NDArrayFactory::create(5.0f); + auto exp = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto z = result.at(0); - // z->printIndexedBuffer(); + sd::ops::concat op; + auto result = op.evaluate({&t, &u, &v}, {}, {0}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ScalarTests, Test_Concat_3) { - auto t = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto u = NDArrayFactory::create(4.0f); - auto v = NDArrayFactory::create(5.0f); - auto exp = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - - sd::ops::concat op; - auto result = op.evaluate({&t, &u, &v}, {}, {0}); - - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto t = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto u = NDArrayFactory::create(4.0f); + auto v = NDArrayFactory::create(5.0f); + auto exp = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - auto z = result.at(0); + sd::ops::concat op; + auto result = op.evaluate({&t, &u, &v}, {}, {0}); - //z->printShapeInfo("z"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + // z->printShapeInfo("z"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ScalarTests, Test_ExpandDims_1) { - auto x = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create('c', {1}, {2.0f}); + auto x = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {1}, {2.0f}); - sd::ops::expand_dims op; - auto result = op.evaluate({&x}, {}, {0}); + sd::ops::expand_dims op; + auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ScalarTests, Test_Squeeze_1) { - auto x = NDArrayFactory::create(2.0f); - auto exp = NDArrayFactory::create(2.0f); - - sd::ops::squeeze op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create(2.0f); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ScalarTests, Test_Permute_1) { - auto x = NDArrayFactory::create(3.0f); - auto exp = NDArrayFactory::create(3.0f); - - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create(3.0f); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ScalarTests, Test_Concat_Scalar_1) { - auto t = NDArrayFactory::create('c', {1, 1}, {1.0f}); - auto u = NDArrayFactory::create('c', {1, 1}, {2.0f}); - auto v = NDArrayFactory::create('c', {1, 1}, {3.0f}); - auto w = NDArrayFactory::create('c', {1, 1}, {4.0f}); - auto exp = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto t = NDArrayFactory::create('c', {1, 1}, {1.0f}); + auto u = NDArrayFactory::create('c', {1, 1}, {2.0f}); + auto v = NDArrayFactory::create('c', {1, 1}, {3.0f}); + auto w = NDArrayFactory::create('c', {1, 1}, {4.0f}); + auto exp = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); - sd::ops::concat op; - auto result = op.evaluate({&t, &u, &v, &w}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + sd::ops::concat op; + auto result = op.evaluate({&t, &u, &v, &w}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(ScalarTests, Test_Concat_Scalar_2) { - auto t = NDArrayFactory::create('c', {1, 1}, {1.0f}); - auto u = NDArrayFactory::create('c', {1, 1}, {2.0f}); - auto v = NDArrayFactory::create('c', {1, 1}, {3.0f}); - auto w = NDArrayFactory::create('c', {1, 1}, {4.0f}); - auto exp = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - - sd::ops::concat op; - auto result = op.evaluate({&t, &u, &v, &w}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto t = NDArrayFactory::create('c', {1, 1}, {1.0f}); + auto u = NDArrayFactory::create('c', {1, 1}, {2.0f}); + auto v = NDArrayFactory::create('c', {1, 1}, {3.0f}); + auto w = NDArrayFactory::create('c', {1, 1}, {4.0f}); + auto exp = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto z = result.at(0); + sd::ops::concat op; + auto result = op.evaluate({&t, &u, &v, &w}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp b/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp index 570f2c0c6aaf..77439018b3e0 100644 --- a/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ServerRelatedTests.cpp @@ -18,24 +18,25 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class ServerRelatedTests : public testing::Test { -public: - ServerRelatedTests() { - Environment::getInstance()->setDebug(true); - Environment::getInstance()->setVerbose(true); - } - - ~ServerRelatedTests() { - Environment::getInstance()->setDebug(false); - Environment::getInstance()->setVerbose(false); - } + public: + ServerRelatedTests() { + Environment::getInstance()->setDebug(true); + Environment::getInstance()->setVerbose(true); + } + + ~ServerRelatedTests() { + Environment::getInstance()->setDebug(false); + Environment::getInstance()->setVerbose(false); + } }; /* TEST_F(ServerRelatedTests, Basic_Output_Test_1) { @@ -82,106 +83,111 @@ TEST_F(ServerRelatedTests, Basic_Output_Test_1) { */ #if GRAPH_FILES_OK TEST_F(ServerRelatedTests, Basic_Execution_Test_1) { - flatbuffers::FlatBufferBuilder builder(4096); - auto oGraph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - oGraph->printOut(); + flatbuffers::FlatBufferBuilder builder(4096); + auto oGraph = GraphExecutioner::importFromFlatBuffers( + "./resources/reduce_dim_false.fb"); + oGraph->printOut(); - auto exp = NDArrayFactory::create('c', {3}, {3.f, 3.f, 3.f}); + auto exp = NDArrayFactory::create('c', {3}, {3.f, 3.f, 3.f}); - GraphHolder::getInstance()->registerGraph(11901L, oGraph); + GraphHolder::getInstance()->registerGraph(11901L, oGraph); - auto cGraph = GraphHolder::getInstance()->cloneGraph(11901L); + auto cGraph = GraphHolder::getInstance()->cloneGraph(11901L); - ASSERT_TRUE(oGraph != cGraph); + ASSERT_TRUE(oGraph != cGraph); - auto flatResult = GraphExecutioner::execute(cGraph, builder, nullptr); + auto flatResult = GraphExecutioner::execute(cGraph, builder, nullptr); - builder.Finish(flatResult); - auto ptr = builder.GetBufferPointer(); - auto received = GetFlatResult(ptr); + builder.Finish(flatResult); + auto ptr = builder.GetBufferPointer(); + auto received = GetFlatResult(ptr); - ExecutionResult restored(received); - ASSERT_EQ(1, restored.size()); + ExecutionResult restored(received); + ASSERT_EQ(1, restored.size()); - ASSERT_EQ(exp, *restored.at(0)->getNDArray()); + ASSERT_EQ(exp, *restored.at(0)->getNDArray()); - delete cGraph; + delete cGraph; - GraphHolder::getInstance()->dropGraphAny(11901L); + GraphHolder::getInstance()->dropGraphAny(11901L); } TEST_F(ServerRelatedTests, Basic_Execution_Test_2) { - flatbuffers::FlatBufferBuilder builder(4096); - flatbuffers::FlatBufferBuilder otherBuilder(4096); - auto oGraph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - oGraph->printOut(); + flatbuffers::FlatBufferBuilder builder(4096); + flatbuffers::FlatBufferBuilder otherBuilder(4096); + auto oGraph = GraphExecutioner::importFromFlatBuffers( + "./resources/reduce_dim_false.fb"); + oGraph->printOut(); - auto input0 = NDArrayFactory::create('c', {3, 3}, {2.f,2.f,2.f, 2.f,2.f,2.f, 2.f,2.f,2.f}); - auto exp = NDArrayFactory::create('c', {3}, {6.f, 6.f, 6.f}); + auto input0 = NDArrayFactory::create( + 'c', {3, 3}, {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); + auto exp = NDArrayFactory::create('c', {3}, {6.f, 6.f, 6.f}); - GraphHolder::getInstance()->registerGraph(11902L, oGraph); + GraphHolder::getInstance()->registerGraph(11902L, oGraph); - auto cGraph = GraphHolder::getInstance()->cloneGraph(11902L); + auto cGraph = GraphHolder::getInstance()->cloneGraph(11902L); - ASSERT_TRUE(oGraph != cGraph); + ASSERT_TRUE(oGraph != cGraph); - // mastering InferenceRequest - InferenceRequest ir(11902L); - ir.appendVariable(1, 0, &input0); + // mastering InferenceRequest + InferenceRequest ir(11902L); + ir.appendVariable(1, 0, &input0); - auto af = ir.asFlatInferenceRequest(otherBuilder); - otherBuilder.Finish(af); - auto fptr = otherBuilder.GetBufferPointer(); - auto fir = GetFlatInferenceRequest(fptr); + auto af = ir.asFlatInferenceRequest(otherBuilder); + otherBuilder.Finish(af); + auto fptr = otherBuilder.GetBufferPointer(); + auto fir = GetFlatInferenceRequest(fptr); - auto flatResult = GraphExecutioner::execute(cGraph, builder, fir); + auto flatResult = GraphExecutioner::execute(cGraph, builder, fir); - builder.Finish(flatResult); - auto ptr = builder.GetBufferPointer(); - auto received = GetFlatResult(ptr); + builder.Finish(flatResult); + auto ptr = builder.GetBufferPointer(); + auto received = GetFlatResult(ptr); - ExecutionResult restored(received); - ASSERT_EQ(1, restored.size()); + ExecutionResult restored(received); + ASSERT_EQ(1, restored.size()); - ASSERT_EQ(exp, *restored.at(0)->getNDArray()); + ASSERT_EQ(exp, *restored.at(0)->getNDArray()); - delete cGraph; + delete cGraph; - GraphHolder::getInstance()->dropGraphAny(11902L); + GraphHolder::getInstance()->dropGraphAny(11902L); } TEST_F(ServerRelatedTests, BasicExecutionTests_3) { - flatbuffers::FlatBufferBuilder builder(4096); - flatbuffers::FlatBufferBuilder otherBuilder(4096); - auto oGraph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); - oGraph->printOut(); + flatbuffers::FlatBufferBuilder builder(4096); + flatbuffers::FlatBufferBuilder otherBuilder(4096); + auto oGraph = GraphExecutioner::importFromFlatBuffers( + "./resources/reduce_dim_false.fb"); + oGraph->printOut(); - auto input0 = NDArrayFactory::create('c', {3, 3}, {2.f,2.f,2.f, 2.f,2.f,2.f, 2.f,2.f,2.f}); - auto exp = NDArrayFactory::create('c', {3}, {6.f, 6.f, 6.f}); + auto input0 = NDArrayFactory::create( + 'c', {3, 3}, {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); + auto exp = NDArrayFactory::create('c', {3}, {6.f, 6.f, 6.f}); - GraphHolder::getInstance()->registerGraph(11903L, oGraph); + GraphHolder::getInstance()->registerGraph(11903L, oGraph); - // mastering InferenceRequest - InferenceRequest ir(11903L); - ir.appendVariable(1, 0, &input0); + // mastering InferenceRequest + InferenceRequest ir(11903L); + ir.appendVariable(1, 0, &input0); - auto af = ir.asFlatInferenceRequest(otherBuilder); - otherBuilder.Finish(af); - auto fptr = otherBuilder.GetBufferPointer(); - auto fir = GetFlatInferenceRequest(fptr); + auto af = ir.asFlatInferenceRequest(otherBuilder); + otherBuilder.Finish(af); + auto fptr = otherBuilder.GetBufferPointer(); + auto fir = GetFlatInferenceRequest(fptr); + auto flatResult = + GraphHolder::getInstance()->execute(fir->id(), builder, fir); - auto flatResult = GraphHolder::getInstance()->execute(fir->id(), builder, fir); + builder.Finish(flatResult); + auto ptr = builder.GetBufferPointer(); + auto received = GetFlatResult(ptr); - builder.Finish(flatResult); - auto ptr = builder.GetBufferPointer(); - auto received = GetFlatResult(ptr); - - ExecutionResult restored(received); - ASSERT_EQ(1, restored.size()); + ExecutionResult restored(received); + ASSERT_EQ(1, restored.size()); - ASSERT_EQ(exp, *restored.at(0)->getNDArray()); + ASSERT_EQ(exp, *restored.at(0)->getNDArray()); - GraphHolder::getInstance()->dropGraphAny(11903L); + GraphHolder::getInstance()->dropGraphAny(11903L); } #endif diff --git a/libnd4j/tests_cpu/layers_tests/ShapeTests.cpp b/libnd4j/tests_cpu/layers_tests/ShapeTests.cpp index 46edc50d8bd1..b4e314767daf 100644 --- a/libnd4j/tests_cpu/layers_tests/ShapeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ShapeTests.cpp @@ -19,316 +19,305 @@ // #include -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class ShapeTests : public testing::Test { -public: - + public: }; - TEST_F(ShapeTests, Test_Basics_1) { - Nd4jLong shape[] = {2, 5, 3, 3, 1, 0, 1, 99}; + Nd4jLong shape[] = {2, 5, 3, 3, 1, 0, 1, 99}; - ASSERT_EQ(2, shape::rank(shape)); - ASSERT_EQ(1, shape::elementWiseStride(shape)); - ASSERT_EQ(5, shape::sizeAt(shape, 0)); - ASSERT_EQ(3, shape::sizeAt(shape, 1)); - ASSERT_EQ('c', shape::order(shape)); + ASSERT_EQ(2, shape::rank(shape)); + ASSERT_EQ(1, shape::elementWiseStride(shape)); + ASSERT_EQ(5, shape::sizeAt(shape, 0)); + ASSERT_EQ(3, shape::sizeAt(shape, 1)); + ASSERT_EQ('c', shape::order(shape)); } - TEST_F(ShapeTests, Test_Basics_2) { - Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; - - ASSERT_EQ(4, shape::rank(shape)); - ASSERT_EQ(-1, shape::elementWiseStride(shape)); - ASSERT_EQ(2, shape::sizeAt(shape, 0)); - ASSERT_EQ(3, shape::sizeAt(shape, 1)); - ASSERT_EQ(4, shape::sizeAt(shape, 2)); - ASSERT_EQ(5, shape::sizeAt(shape, 3)); - ASSERT_EQ('f', shape::order(shape)); + Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + + ASSERT_EQ(4, shape::rank(shape)); + ASSERT_EQ(-1, shape::elementWiseStride(shape)); + ASSERT_EQ(2, shape::sizeAt(shape, 0)); + ASSERT_EQ(3, shape::sizeAt(shape, 1)); + ASSERT_EQ(4, shape::sizeAt(shape, 2)); + ASSERT_EQ(5, shape::sizeAt(shape, 3)); + ASSERT_EQ('f', shape::order(shape)); } - TEST_F(ShapeTests, Test_tadLength_1) { - Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; - int axis[] = {2, 3}; + Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + int axis[] = {2, 3}; - ASSERT_EQ(20, shape::tadLength(shape, axis, 2)); + ASSERT_EQ(20, shape::tadLength(shape, axis, 2)); } - TEST_F(ShapeTests, Test_ShapeEquality_1) { - Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; - Nd4jLong shape_GOOD[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, 1, 99}; - Nd4jLong shape_BAD[] = {4, 3, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + Nd4jLong shape_GOOD[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, 1, 99}; + Nd4jLong shape_BAD[] = {4, 3, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; - - ASSERT_TRUE(shape::equalsSoft(shape, shape_GOOD)); - ASSERT_FALSE(shape::equalsSoft(shape, shape_BAD)); + ASSERT_TRUE(shape::equalsSoft(shape, shape_GOOD)); + ASSERT_FALSE(shape::equalsSoft(shape, shape_BAD)); } TEST_F(ShapeTests, Test_ShapeEquality_2) { - Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; - Nd4jLong shape_GOOD[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; - Nd4jLong shape_BAD[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 99}; - + Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + Nd4jLong shape_GOOD[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + Nd4jLong shape_BAD[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 99}; - ASSERT_TRUE(shape::equalsStrict(shape, shape_GOOD)); - ASSERT_FALSE(shape::equalsStrict(shape, shape_BAD)); + ASSERT_TRUE(shape::equalsStrict(shape, shape_GOOD)); + ASSERT_FALSE(shape::equalsStrict(shape, shape_BAD)); } TEST_F(ShapeTests, Test_Ind2SubC_1) { - Nd4jLong shape[] = {3, 5}; - Nd4jLong c0[2]; - shape::index2coords(0, 2, shape, c0); + Nd4jLong shape[] = {3, 5}; + Nd4jLong c0[2]; + shape::index2coords(0, 2, shape, c0); - ASSERT_EQ(0, c0[0]); - ASSERT_EQ(0, c0[1]); + ASSERT_EQ(0, c0[0]); + ASSERT_EQ(0, c0[1]); - Nd4jLong c1[2]; - shape::index2coords(1, 2, shape, c1); + Nd4jLong c1[2]; + shape::index2coords(1, 2, shape, c1); - ASSERT_EQ(0, c1[0]); - ASSERT_EQ(1, c1[1]); + ASSERT_EQ(0, c1[0]); + ASSERT_EQ(1, c1[1]); - Nd4jLong c6[2]; - shape::index2coords(5, 2, shape, c6); + Nd4jLong c6[2]; + shape::index2coords(5, 2, shape, c6); - ASSERT_EQ(1, c6[0]); - ASSERT_EQ(0, c6[1]); + ASSERT_EQ(1, c6[0]); + ASSERT_EQ(0, c6[1]); } - TEST_F(ShapeTests, Test_ShapeDetector_1) { - Nd4jLong shape[] = {2, 5, 3, 3, 1, 0, 1, 99}; + Nd4jLong shape[] = {2, 5, 3, 3, 1, 0, 1, 99}; - ASSERT_TRUE(shape::isMatrix(shape)); + ASSERT_TRUE(shape::isMatrix(shape)); } TEST_F(ShapeTests, Test_ShapeDetector_2) { - Nd4jLong shape[] = {3, 2, 5, 3, 15, 3, 1, 0, 1, 99}; + Nd4jLong shape[] = {3, 2, 5, 3, 15, 3, 1, 0, 1, 99}; - ASSERT_FALSE(shape::isMatrix(shape)); + ASSERT_FALSE(shape::isMatrix(shape)); } TEST_F(ShapeTests, Test_ShapeDetector_3) { - Nd4jLong shape[] = {2, 1, 3, 3, 1, 0, 1, 99}; + Nd4jLong shape[] = {2, 1, 3, 3, 1, 0, 1, 99}; - ASSERT_FALSE(shape::isColumnVector(shape)); - ASSERT_TRUE(shape::isVector(shape)); - ASSERT_TRUE(shape::isRowVector(shape)); - ASSERT_FALSE(shape::isMatrix(shape)); + ASSERT_FALSE(shape::isColumnVector(shape)); + ASSERT_TRUE(shape::isVector(shape)); + ASSERT_TRUE(shape::isRowVector(shape)); + ASSERT_FALSE(shape::isMatrix(shape)); } - TEST_F(ShapeTests, Test_ShapeDetector_4) { - Nd4jLong shape[] = {2, 3, 1, 1, 1, 0, 1, 99}; + Nd4jLong shape[] = {2, 3, 1, 1, 1, 0, 1, 99}; - ASSERT_TRUE(shape::isColumnVector(shape)); - ASSERT_TRUE(shape::isVector(shape)); - ASSERT_FALSE(shape::isRowVector(shape)); - ASSERT_FALSE(shape::isMatrix(shape)); + ASSERT_TRUE(shape::isColumnVector(shape)); + ASSERT_TRUE(shape::isVector(shape)); + ASSERT_FALSE(shape::isRowVector(shape)); + ASSERT_FALSE(shape::isMatrix(shape)); } TEST_F(ShapeTests, Test_ShapeDetector_5) { - Nd4jLong shape[] = {2, 1, 1, 1, 1, 0, 1, 99}; + Nd4jLong shape[] = {2, 1, 1, 1, 1, 0, 1, 99}; - ASSERT_TRUE(shape::isScalar(shape)); - ASSERT_FALSE(shape::isMatrix(shape)); + ASSERT_TRUE(shape::isScalar(shape)); + ASSERT_FALSE(shape::isMatrix(shape)); - // edge case here. Technicaly it's still a vector with length of 1 - ASSERT_TRUE(shape::isVector(shape)); + // edge case here. Technicaly it's still a vector with length of 1 + ASSERT_TRUE(shape::isVector(shape)); } TEST_F(ShapeTests, Test_ShapeDetector_6) { - Nd4jLong shape[] = {2, 1, 1, 1, 1, 0, 1, 99}; + Nd4jLong shape[] = {2, 1, 1, 1, 1, 0, 1, 99}; - ASSERT_EQ(8, shape::shapeInfoLength(shape)); - ASSERT_EQ(64, shape::shapeInfoByteLength(shape)); + ASSERT_EQ(8, shape::shapeInfoLength(shape)); + ASSERT_EQ(64, shape::shapeInfoByteLength(shape)); } TEST_F(ShapeTests, Test_ShapeDetector_7) { - Nd4jLong shape[] = {3, 1, 1, 1, 1, 1, 1, 0, 1, 99}; + Nd4jLong shape[] = {3, 1, 1, 1, 1, 1, 1, 0, 1, 99}; - ASSERT_EQ(10, shape::shapeInfoLength(shape)); - ASSERT_EQ(80, shape::shapeInfoByteLength(shape)); + ASSERT_EQ(10, shape::shapeInfoLength(shape)); + ASSERT_EQ(80, shape::shapeInfoByteLength(shape)); } TEST_F(ShapeTests, Test_Transpose_1) { - Nd4jLong shape[] = {3, 2, 5, 3, 15, 3, 1, 0, 1, 99}; - Nd4jLong exp[] = {3, 3, 5, 2, 1, 3, 15, 0, 1, 102}; + Nd4jLong shape[] = {3, 2, 5, 3, 15, 3, 1, 0, 1, 99}; + Nd4jLong exp[] = {3, 3, 5, 2, 1, 3, 15, 0, 1, 102}; - shape::transposeInplace(shape); + shape::transposeInplace(shape); - ASSERT_TRUE(shape::equalsStrict(exp, shape)); + ASSERT_TRUE(shape::equalsStrict(exp, shape)); } TEST_F(ShapeTests, Test_Transpose_2) { - Nd4jLong shape[] = {2, 5, 3, 3, 1, 0, 1, 99}; - Nd4jLong exp[] = {2, 3, 5, 1, 3, 0, 1, 102}; + Nd4jLong shape[] = {2, 5, 3, 3, 1, 0, 1, 99}; + Nd4jLong exp[] = {2, 3, 5, 1, 3, 0, 1, 102}; - shape::transposeInplace(shape); + shape::transposeInplace(shape); - ASSERT_TRUE(shape::equalsStrict(exp, shape)); + ASSERT_TRUE(shape::equalsStrict(exp, shape)); } TEST_F(ShapeTests, Test_Transpose_3) { - Nd4jLong shape[] = {2, 1, 3, 3, 1, 0, 1, 99}; - Nd4jLong exp[] = {2, 3, 1, 1, 3, 0, 1, 102}; + Nd4jLong shape[] = {2, 1, 3, 3, 1, 0, 1, 99}; + Nd4jLong exp[] = {2, 3, 1, 1, 3, 0, 1, 102}; - shape::transposeInplace(shape); + shape::transposeInplace(shape); - ASSERT_TRUE(shape::equalsStrict(exp, shape)); + ASSERT_TRUE(shape::equalsStrict(exp, shape)); } - TEST_F(ShapeTests, Test_Transpose_4) { - Nd4jLong shape[] = {4, 2, 3, 4, 5, 5, 4, 3, 2, 0, 1, 99}; - Nd4jLong exp[] = {4, 5, 4, 3, 2, 2, 3, 4, 5, 0, 1, 102}; + Nd4jLong shape[] = {4, 2, 3, 4, 5, 5, 4, 3, 2, 0, 1, 99}; + Nd4jLong exp[] = {4, 5, 4, 3, 2, 2, 3, 4, 5, 0, 1, 102}; - shape::transposeInplace(shape); + shape::transposeInplace(shape); - ASSERT_TRUE(shape::equalsStrict(exp, shape)); + ASSERT_TRUE(shape::equalsStrict(exp, shape)); } TEST_F(ShapeTests, Test_Edge_1) { - auto x = NDArrayFactory::create('f', {1, 4, 1, 4}); - x.linspace(1); + auto x = NDArrayFactory::create('f', {1, 4, 1, 4}); + x.linspace(1); - x.reshapei('c', {4, 4}); + x.reshapei('c', {4, 4}); - //x.printShapeInfo("reshape0"); - //x.printIndexedBuffer("x i"); - //x.printBuffer("x r"); + // x.printShapeInfo("reshape0"); + // x.printIndexedBuffer("x i"); + // x.printBuffer("x r"); - x.reshapei({4, 1, 1, 4}); + x.reshapei({4, 1, 1, 4}); - //x.printShapeInfo("reshape1"); + // x.printShapeInfo("reshape1"); } TEST_F(ShapeTests, Test_Edge_2) { - auto x = NDArrayFactory::create('c', {1, 4, 1, 3}); + auto x = NDArrayFactory::create('c', {1, 4, 1, 3}); - x.reshapei('c', {3, 4}); + x.reshapei('c', {3, 4}); - //x.printShapeInfo("reshape0"); + // x.printShapeInfo("reshape0"); - x.reshapei({3, 1, 1, 4}); + x.reshapei({3, 1, 1, 4}); - //x.printShapeInfo("reshape1"); + // x.printShapeInfo("reshape1"); } - TEST_F(ShapeTests, Test_Remove_Index_1) { - int array[] = {1, 2, 3}; - int idx[] = {0}; - int result[2]; - shape::removeIndex(array, idx, 3, 1, result); + int array[] = {1, 2, 3}; + int idx[] = {0}; + int result[2]; + shape::removeIndex(array, idx, 3, 1, result); - ASSERT_EQ(2, result[0]); - ASSERT_EQ(3, result[1]); + ASSERT_EQ(2, result[0]); + ASSERT_EQ(3, result[1]); } TEST_F(ShapeTests, Test_Remove_Index_2) { - int array[] = {1, 2, 3}; - int idx[] = {1}; - int result[2]; - shape::removeIndex(array, idx, 3, 1, result); + int array[] = {1, 2, 3}; + int idx[] = {1}; + int result[2]; + shape::removeIndex(array, idx, 3, 1, result); - ASSERT_EQ(1, result[0]); - ASSERT_EQ(3, result[1]); + ASSERT_EQ(1, result[0]); + ASSERT_EQ(3, result[1]); } TEST_F(ShapeTests, Test_Remove_Index_3) { - int array[] = {1, 2, 3}; - int idx[] = {2}; - int result[2]; - shape::removeIndex(array, idx, 3, 1, result); + int array[] = {1, 2, 3}; + int idx[] = {2}; + int result[2]; + shape::removeIndex(array, idx, 3, 1, result); - ASSERT_EQ(1, result[0]); - ASSERT_EQ(2, result[1]); + ASSERT_EQ(1, result[0]); + ASSERT_EQ(2, result[1]); } TEST_F(ShapeTests, Test_Remove_Index_4) { - int array[] = {1, 2, 3}; - int idx[] = {0, 2}; - int result[1]; - shape::removeIndex(array, idx, 3, 2, result); + int array[] = {1, 2, 3}; + int idx[] = {0, 2}; + int result[1]; + shape::removeIndex(array, idx, 3, 2, result); - ASSERT_EQ(2, result[0]); + ASSERT_EQ(2, result[0]); } TEST_F(ShapeTests, Test_Remove_Index_5) { - int array[] = {1, 2, 3}; - int idx[] = {1, 0}; - int result[1]; - shape::removeIndex(array, idx, 3, 2, result); + int array[] = {1, 2, 3}; + int idx[] = {1, 0}; + int result[1]; + shape::removeIndex(array, idx, 3, 2, result); - ASSERT_EQ(3, result[0]); + ASSERT_EQ(3, result[0]); } TEST_F(ShapeTests, Test_Remove_Index_6) { - int array[] = {1, 2, 3}; - int idx[] = {1, 2}; - int result[1]; - shape::removeIndex(array, idx, 3, 2, result); + int array[] = {1, 2, 3}; + int idx[] = {1, 2}; + int result[1]; + shape::removeIndex(array, idx, 3, 2, result); - ASSERT_EQ(1, result[0]); + ASSERT_EQ(1, result[0]); } TEST_F(ShapeTests, Tests_Transpose_119_1) { - auto x = NDArrayFactory::create('c', {3, 2}); - auto y = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); - auto z = NDArrayFactory::create('c', {2, 3}); + auto x = NDArrayFactory::create('c', {3, 2}); + auto y = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); + auto z = NDArrayFactory::create('c', {2, 3}); - x.linspace(1.f); + x.linspace(1.f); - auto e = x.permute({1, 0}); - e.streamline('c'); + auto e = x.permute({1, 0}); + e.streamline('c'); - sd::ops::transpose op; - auto result = op.execute({&x, &y}, {&z}, {}, {}, {}); + sd::ops::transpose op; + auto result = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); - ASSERT_TRUE(e.isSameShape(z)); - ASSERT_TRUE(e.equalsTo(z)); + ASSERT_EQ(Status::OK(), result); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); } TEST_F(ShapeTests, Tests_Transpose_119_2) { - auto x = NDArrayFactory::create('c', {3, 5}); - x.linspace(1.f); - - auto exp = x.transpose(); + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1.f); - sd::ops::transpose op; - auto result = op.evaluate({&x}); - ASSERT_EQ(Status::OK(), result.status()); + auto exp = x.transpose(); - auto z = result.at(0); + sd::ops::transpose op; + auto result = op.evaluate({&x}); + ASSERT_EQ(Status::OK(), result.status()); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(ShapeTests, Tests_Transpose_119_3) { - auto x = NDArrayFactory::create('c', {3, 5}); - x.linspace(1.f); + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1.f); - auto z = NDArrayFactory::create('c', {5, 3}); + auto z = NDArrayFactory::create('c', {5, 3}); - auto exp = x.transpose(); + auto exp = x.transpose(); - sd::ops::transpose op; - auto result = op.execute({&x}, {&z}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); + sd::ops::transpose op; + auto result = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ShapeTests2.cpp b/libnd4j/tests_cpu/layers_tests/ShapeTests2.cpp index b84213342126..79940359f808 100644 --- a/libnd4j/tests_cpu/layers_tests/ShapeTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ShapeTests2.cpp @@ -17,255 +17,233 @@ // // Created by agibsonccc on 1/6/17. // +#include +#include #include + #include "testinclude.h" -#include -#include class OnesTest : public testing::Test { -public: - Nd4jLong shapeBuffer[12] = {4,4,3,1,1,3,1,1,1,0,1,99}; - int dimension[3] = {0,2,3}; - Nd4jLong tadAssertionShape[10] = {3,1,1,4,1,1,3,0,3,99}; - int dimensionLength = 3; + public: + Nd4jLong shapeBuffer[12] = {4, 4, 3, 1, 1, 3, 1, 1, 1, 0, 1, 99}; + int dimension[3] = {0, 2, 3}; + Nd4jLong tadAssertionShape[10] = {3, 1, 1, 4, 1, 1, 3, 0, 3, 99}; + int dimensionLength = 3; }; class LabelTest : public testing::Test { -public: - float labels[450] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0}; - Nd4jLong shapeInfo[8] = {2,150,3,1,150,16384,1,102}; - int dimension[1] = {1}; - int dimensionLength = 1; - Nd4jLong tadShapeInfoAssert[8] = {2,1,3,1,150,16384,150,102}; + public: + float labels[450] = { + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0}; + Nd4jLong shapeInfo[8] = {2, 150, 3, 1, 150, 16384, 1, 102}; + int dimension[1] = {1}; + int dimensionLength = 1; + Nd4jLong tadShapeInfoAssert[8] = {2, 1, 3, 1, 150, 16384, 150, 102}; }; class ThreeDTest : public testing::Test { -public: - Nd4jLong shape[3] = {3,4,5}; - Nd4jLong *shapeBuffer; - ThreeDTest() { - shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); - } - ~ThreeDTest() { - delete[] shapeBuffer; - } + public: + Nd4jLong shape[3] = {3, 4, 5}; + Nd4jLong *shapeBuffer; + ThreeDTest() { + shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', + 3, shape); + } + ~ThreeDTest() { delete[] shapeBuffer; } }; -class VectorTest : public testing::Test { - -}; +class VectorTest : public testing::Test {}; class NumTadTests : public testing::Test { -public: - Nd4jLong shape[3] = {3,4,5}; - int dimension = 0; + public: + Nd4jLong shape[3] = {3, 4, 5}; + int dimension = 0; }; -class ShapeTest : public testing::Test { -public: - Nd4jLong vectorShape[2] = {1,2}; +class ShapeTest : public testing::Test { + public: + Nd4jLong vectorShape[2] = {1, 2}; }; class MatrixTest : public testing::Test { -public: - int rows = 3; - int cols = 4; - int rank = 2; - int dims[2] = {0,1}; - Nd4jLong expectedShapes[2][2] = { - {1,3}, - {1,4} - }; - Nd4jLong expectedStrides[2][2] = { - {1,4}, - {1,1} - }; + public: + int rows = 3; + int cols = 4; + int rank = 2; + int dims[2] = {0, 1}; + Nd4jLong expectedShapes[2][2] = {{1, 3}, {1, 4}}; + Nd4jLong expectedStrides[2][2] = {{1, 4}, {1, 1}}; }; class TADStall : public testing::Test { -public: - Nd4jLong shape[4] = {3,3,4,5}; - int dimensions[3] = {1,2,3}; + public: + Nd4jLong shape[4] = {3, 3, 4, 5}; + int dimensions[3] = {1, 2, 3}; }; class TensorOneDimTest : public testing::Test { -public: - int rows = 3; - int cols = 4; - int dim2 = 5; - int rank = 3; - int dims[3] = {0,1,2}; - Nd4jLong expectedShapes[3][2] = { - {1,3}, - {1,4}, - {1,5} - }; - Nd4jLong expectedStrides[3][2] = { - {1,20}, - {1,5}, - {1,1} - }; + public: + int rows = 3; + int cols = 4; + int dim2 = 5; + int rank = 3; + int dims[3] = {0, 1, 2}; + Nd4jLong expectedShapes[3][2] = {{1, 3}, {1, 4}, {1, 5}}; + Nd4jLong expectedStrides[3][2] = {{1, 20}, {1, 5}, {1, 1}}; }; class TensorTwoDimTest : public testing::Test { -public: - //From a 3d array: - int rows = 3; - int cols = 4; - int dim2 = 5; - int dimensionLength = 2; - int dims[3][2] = { - {0,1},{0,2},{1,2} - }; - - Nd4jLong shape[3] {rows,cols,dim2}; - - //Along dimension 0,1: expect matrix with shape [rows,cols] - //Along dimension 0,2: expect matrix with shape [rows,dim2] - //Along dimension 1,2: expect matrix with shape [cols,dim2] - Nd4jLong expectedShapes[3][2] = { - {rows,cols}, - {rows,dim2}, - {cols,dim2} - }; - - Nd4jLong expectedStrides[3][2] = { - {20,5}, - {20,1}, - {5,1} - }; - + public: + // From a 3d array: + int rows = 3; + int cols = 4; + int dim2 = 5; + int dimensionLength = 2; + int dims[3][2] = {{0, 1}, {0, 2}, {1, 2}}; + + Nd4jLong shape[3]{rows, cols, dim2}; + + // Along dimension 0,1: expect matrix with shape [rows,cols] + // Along dimension 0,2: expect matrix with shape [rows,dim2] + // Along dimension 1,2: expect matrix with shape [cols,dim2] + Nd4jLong expectedShapes[3][2] = {{rows, cols}, {rows, dim2}, {cols, dim2}}; + + Nd4jLong expectedStrides[3][2] = {{20, 5}, {20, 1}, {5, 1}}; }; class TensorTwoFromFourDDimTest : public testing::Test { -public: - //From a 3d array: - int rows = 3; - int cols = 4; - int dim2 = 5; - int dim3 = 6; - Nd4jLong shape[4] = {rows,cols,dim2,dim3}; - int dimensionLength = 2; - //Along dimension 0,1: expect matrix with shape [rows,cols] - //Along dimension 0,2: expect matrix with shape [rows,dim2] - //Along dimension 0,3: expect matrix with shape [rows,dim3] - //Along dimension 1,2: expect matrix with shape [cols,dim2] - //Along dimension 1,3: expect matrix with shape [cols,dim3] - //Along dimension 2,3: expect matrix with shape [dim2,dim3] - - int dims[6][2] = { - {0,1}, - {0,2}, - {0,3}, - {1,2}, - {1,3}, - {2,3} - }; - - Nd4jLong expectedShapes[6][2] = { - {rows,cols}, - {rows,dim2}, - {rows,dim3}, - {cols,dim2}, - {cols,dim3} - ,{dim2,dim3} - }; - - Nd4jLong expectedStrides[6][2] = { - {120,30}, - {120,6}, - {120,1}, - {30,6}, - {30,1}, - {6,1} - }; + public: + // From a 3d array: + int rows = 3; + int cols = 4; + int dim2 = 5; + int dim3 = 6; + Nd4jLong shape[4] = {rows, cols, dim2, dim3}; + int dimensionLength = 2; + // Along dimension 0,1: expect matrix with shape [rows,cols] + // Along dimension 0,2: expect matrix with shape [rows,dim2] + // Along dimension 0,3: expect matrix with shape [rows,dim3] + // Along dimension 1,2: expect matrix with shape [cols,dim2] + // Along dimension 1,3: expect matrix with shape [cols,dim3] + // Along dimension 2,3: expect matrix with shape [dim2,dim3] + + int dims[6][2] = {{0, 1}, {0, 2}, {0, 3}, {1, 2}, {1, 3}, {2, 3}}; + + Nd4jLong expectedShapes[6][2] = {{rows, cols}, {rows, dim2}, {rows, dim3}, + {cols, dim2}, {cols, dim3}, {dim2, dim3}}; + + Nd4jLong expectedStrides[6][2] = {{120, 30}, {120, 6}, {120, 1}, + {30, 6}, {30, 1}, {6, 1}}; }; - class OrderTest : public testing::Test { -public: - Nd4jLong expected[8] = {2,3,4,1,3,0,0,102}; - Nd4jLong test[8] = {2,3,4,1,3,0,0,102}; - + public: + Nd4jLong expected[8] = {2, 3, 4, 1, 3, 0, 0, 102}; + Nd4jLong test[8] = {2, 3, 4, 1, 3, 0, 0, 102}; }; - class LeadingOnes : public testing::Test { -public: - Nd4jLong shapeBufferF[16] = {4,1,1,4,4,1,1,1,4,16384,1,102}; // shapes with data type DOUBLE - Nd4jLong shapeBufferC[16] = {4,1,1,4,4,16,16,4,1,16384,1,99}; - int dimensionLength = 2; - int dimension[2] = {2,3}; - Nd4jLong tadAssertionC[10] = {3,4,4,1,4,1,16,16384,1,99}; - Nd4jLong tadCAssertionF[10] = {3,4,4,1,1,4,1,16384,1,102}; + public: + Nd4jLong shapeBufferF[16] = { + 4, 1, 1, 4, 4, 1, + 1, 1, 4, 16384, 1, 102}; // shapes with data type DOUBLE + Nd4jLong shapeBufferC[16] = {4, 1, 1, 4, 4, 16, 16, 4, 1, 16384, 1, 99}; + int dimensionLength = 2; + int dimension[2] = {2, 3}; + Nd4jLong tadAssertionC[10] = {3, 4, 4, 1, 4, 1, 16, 16384, 1, 99}; + Nd4jLong tadCAssertionF[10] = {3, 4, 4, 1, 1, 4, 1, 16384, 1, 102}; }; - -TEST_F(LeadingOnes,OnesTest) { - - shape::TAD *cTad = new shape::TAD; - cTad->init(shapeBufferC,dimension,dimensionLength); - cTad->createTadOnlyShapeInfo(); - cTad->createOffsets(); - shape::TAD *fTad = new shape::TAD; - fTad->init(shapeBufferF,dimension,dimensionLength); - fTad->createTadOnlyShapeInfo(); - fTad->createOffsets(); - // shape::printShapeInfoLinear(cTad->tadOnlyShapeInfo); - // shape::printShapeInfoLinear(fTad->tadOnlyShapeInfo); - ASSERT_TRUE(arrsEquals(10, tadCAssertionF, fTad->tadOnlyShapeInfo)); - ASSERT_TRUE(arrsEquals(10, tadAssertionC, cTad->tadOnlyShapeInfo)); - - delete cTad; - delete fTad; +TEST_F(LeadingOnes, OnesTest) { + shape::TAD *cTad = new shape::TAD; + cTad->init(shapeBufferC, dimension, dimensionLength); + cTad->createTadOnlyShapeInfo(); + cTad->createOffsets(); + shape::TAD *fTad = new shape::TAD; + fTad->init(shapeBufferF, dimension, dimensionLength); + fTad->createTadOnlyShapeInfo(); + fTad->createOffsets(); + // shape::printShapeInfoLinear(cTad->tadOnlyShapeInfo); + // shape::printShapeInfoLinear(fTad->tadOnlyShapeInfo); + ASSERT_TRUE(arrsEquals(10, tadCAssertionF, fTad->tadOnlyShapeInfo)); + ASSERT_TRUE(arrsEquals(10, tadAssertionC, cTad->tadOnlyShapeInfo)); + + delete cTad; + delete fTad; } - class NormalThreeFourFive : public testing::Test { -public: - Nd4jLong assertionBuffer[8] = {2, 3, 4, 20, 5, 16384, 5, 99}; - Nd4jLong inputShapeBuffer[10] = {3,3,4,5,20,5,1,16384,1,99}; - int dimensionLength = 2; - int dimension[2] = {0,1}; + public: + Nd4jLong assertionBuffer[8] = {2, 3, 4, 20, 5, 16384, 5, 99}; + Nd4jLong inputShapeBuffer[10] = {3, 3, 4, 5, 20, 5, 1, 16384, 1, 99}; + int dimensionLength = 2; + int dimension[2] = {0, 1}; }; +TEST_F(NormalThreeFourFive, DimensionTest) { + shape::TAD *tad = new shape::TAD; + tad->init(inputShapeBuffer, dimension, dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + ASSERT_TRUE(arrsEquals(8, assertionBuffer, tad->tadOnlyShapeInfo)); -TEST_F(NormalThreeFourFive,DimensionTest) { - shape::TAD *tad = new shape::TAD; - tad->init(inputShapeBuffer,dimension,dimensionLength); - tad->createTadOnlyShapeInfo(); - tad->createOffsets(); - ASSERT_TRUE(arrsEquals(8,assertionBuffer,tad->tadOnlyShapeInfo)); - - delete tad; + delete tad; } class DimensionWarning : public testing::Test { -public: - int dimensionLength = 2; - int dimensions[2] = {0,1}; - Nd4jLong shape[3] = {1,5,1}; - Nd4jLong *shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); - - ~DimensionWarning() { - delete[] shapeBuffer; - } + public: + int dimensionLength = 2; + int dimensions[2] = {0, 1}; + Nd4jLong shape[3] = {1, 5, 1}; + Nd4jLong *shapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); + + ~DimensionWarning() { delete[] shapeBuffer; } }; - -TEST_F(DimensionWarning,ShapeWarning) { - shape::TAD *tad = new shape::TAD; - tad->init(shapeBuffer,dimensions,dimensionLength); - tad->createTadOnlyShapeInfo(); - tad->createOffsets(); - delete tad; +TEST_F(DimensionWarning, ShapeWarning) { + shape::TAD *tad = new shape::TAD; + tad->init(shapeBuffer, dimensions, dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + delete tad; } - class TadRank : public testing::Test { - Nd4jLong shapeBuffer[12] = {4,2,1,3,3,9,9,3,1,0,1,99}; - int dimensionLength = 2; - int dimension[2] = {2,3}; - + Nd4jLong shapeBuffer[12] = {4, 2, 1, 3, 3, 9, 9, 3, 1, 0, 1, 99}; + int dimensionLength = 2; + int dimension[2] = {2, 3}; }; class TestRemoveIndex : public testing::Test {}; @@ -281,95 +259,88 @@ class SliceMatrixTest : public testing::Test {}; class SliceTensorTest : public testing::Test {}; class ElementWiseStrideTest : public testing::Test { -public: - Nd4jLong shape[3] = {3,4,5}; - Nd4jLong stride[2] = {20,5}; - int elementWiseStrideAssertion = -1; + public: + Nd4jLong shape[3] = {3, 4, 5}; + Nd4jLong stride[2] = {20, 5}; + int elementWiseStrideAssertion = -1; }; -class PermuteTest : public testing::Test{}; +class PermuteTest : public testing::Test {}; -class LengthPerSliceTest : public testing::Test{}; +class LengthPerSliceTest : public testing::Test {}; class ExpectedValuesTest : public testing::Test { -public: - Nd4jLong mainShape[4] = {9,7,5,3}; - int testDimensions[3] = {0,2,3}; - + public: + Nd4jLong mainShape[4] = {9, 7, 5, 3}; + int testDimensions[3] = {0, 2, 3}; }; class BeginOneTadTest : public testing::Test { -public: - Nd4jLong assertionShapeBuffer[8] = {2,3,5,1,3,16384,1,102}; - Nd4jLong inputShapeBuffer[10] = {3,1,3,5,1,1,3,16384,0,102}; - int dimensionLength = 2; - int dimension[2] = {1,2}; - //error: [2,1,1,1,1,0,1,97] + public: + Nd4jLong assertionShapeBuffer[8] = {2, 3, 5, 1, 3, 16384, 1, 102}; + Nd4jLong inputShapeBuffer[10] = {3, 1, 3, 5, 1, 1, 3, 16384, 0, 102}; + int dimensionLength = 2; + int dimension[2] = {1, 2}; + // error: [2,1,1,1,1,0,1,97] }; class FourDTest : public testing::Test { - /** - * INDArray array3d = Nd4j.ones(1, 10, 10); + /** + * INDArray array3d = Nd4j.ones(1, 10, 10); array3d.sum(1); INDArray array4d = Nd4j.ones(1, 10, 10, 10); INDArray sum40 = array4d.sum(0); - */ -public: - Nd4jLong threeDShape[3] = {1,10,10}; - Nd4jLong fourDShape[4] = {1,10,10,10}; - Nd4jLong *threeDShapeBuffer = nullptr,*fourDShapeBuffer = nullptr; - int dimensionThree = 1; - int dimensionThreeTwo = 0; - int dimensionFour = 0; - int dimensionLength = 1; - FourDTest() { - threeDShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'f', 3, threeDShape); - fourDShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'f', 4, fourDShape); - } - ~FourDTest() { - if(threeDShapeBuffer != nullptr) - delete[] threeDShapeBuffer; - if(fourDShapeBuffer != nullptr) - delete[] fourDShapeBuffer; - } - - - + */ + public: + Nd4jLong threeDShape[3] = {1, 10, 10}; + Nd4jLong fourDShape[4] = {1, 10, 10, 10}; + Nd4jLong *threeDShapeBuffer = nullptr, *fourDShapeBuffer = nullptr; + int dimensionThree = 1; + int dimensionThreeTwo = 0; + int dimensionFour = 0; + int dimensionLength = 1; + FourDTest() { + threeDShapeBuffer = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'f', 3, threeDShape); + fourDShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, + 'f', 4, fourDShape); + } + ~FourDTest() { + if (threeDShapeBuffer != nullptr) delete[] threeDShapeBuffer; + if (fourDShapeBuffer != nullptr) delete[] fourDShapeBuffer; + } }; - - -TEST_F(FourDTest,ThreeDFourDTest) { - shape::TAD *threeTadTwo = new shape::TAD; - threeTadTwo->init(threeDShapeBuffer,&dimensionThreeTwo,dimensionLength); - threeTadTwo->createTadOnlyShapeInfo(); - threeTadTwo->createOffsets(); - - shape::TAD *threeTad = new shape::TAD; - threeTad->init(threeDShapeBuffer,&dimensionThree,dimensionLength); - threeTad->createTadOnlyShapeInfo(); - threeTad->createOffsets(); - - shape::TAD *fourTad = new shape::TAD; - fourTad->init(fourDShapeBuffer,&dimensionFour,dimensionLength); - fourTad->createTadOnlyShapeInfo(); - fourTad->createOffsets(); - - delete threeTadTwo; - delete threeTad; - delete fourTad; +TEST_F(FourDTest, ThreeDFourDTest) { + shape::TAD *threeTadTwo = new shape::TAD; + threeTadTwo->init(threeDShapeBuffer, &dimensionThreeTwo, dimensionLength); + threeTadTwo->createTadOnlyShapeInfo(); + threeTadTwo->createOffsets(); + + shape::TAD *threeTad = new shape::TAD; + threeTad->init(threeDShapeBuffer, &dimensionThree, dimensionLength); + threeTad->createTadOnlyShapeInfo(); + threeTad->createOffsets(); + + shape::TAD *fourTad = new shape::TAD; + fourTad->init(fourDShapeBuffer, &dimensionFour, dimensionLength); + fourTad->createTadOnlyShapeInfo(); + fourTad->createOffsets(); + + delete threeTadTwo; + delete threeTad; + delete fourTad; } - - class RowVectorOnesTest : public testing::Test { -public: - Nd4jLong shapeBuffer[12] = {4,4,3,1,1,3,1,1,1,8192,1,99}; // float32 type of shape - float data[12] = {1,2,3,4,5,6,7,8,9,10,11,12}; - Nd4jLong assertionBuffer[10] = {3,4,1,1,3,1,1,8192,0,99}; - int dimensionLength = 3; - int dimension[3] = {0,2,3}; + public: + Nd4jLong shapeBuffer[12] = {4, 4, 3, 1, 1, 3, + 1, 1, 1, 8192, 1, 99}; // float32 type of shape + float data[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + Nd4jLong assertionBuffer[10] = {3, 4, 1, 1, 3, 1, 1, 8192, 0, 99}; + int dimensionLength = 3; + int dimension[3] = {0, 2, 3}; }; // TEST_F(RowVectorOnesTest,TadShape) { @@ -380,57 +351,58 @@ class RowVectorOnesTest : public testing::Test { // delete tad; // } - - class SixDTest : public testing::Test { -public: - Nd4jLong inputShapeBuffer[16] = {6,1,1,4,4,4,4,1,1,1,4,16,64,16384,1,102}; // shape with double data type - int dimensionLength = 2; - int dimension[2] = {2,3}; - Nd4jLong assertionShapeBuffer[8] = {2,4,4,1,4,16384,1,102}; // also double typed shape + public: + Nd4jLong inputShapeBuffer[16] = { + 6, 1, 1, 4, 4, 4, 4, 1, + 1, 1, 4, 16, 64, 16384, 1, 102}; // shape with double data type + int dimensionLength = 2; + int dimension[2] = {2, 3}; + Nd4jLong assertionShapeBuffer[8] = { + 2, 4, 4, 1, 4, 16384, 1, 102}; // also double typed shape }; TEST_F(SixDTest, SixDWithOnes) { - shape::TAD *tad = new shape::TAD; - tad->init(inputShapeBuffer,dimension,dimensionLength); - tad->createTadOnlyShapeInfo(); - tad->createOffsets(); - // shape::printShapeInfoLinear(inputShapeBuffer); - // shape::printShapeInfoLinear(tad->tadOnlyShapeInfo); - //[2,1,1,1,1,0,1,97] - ASSERT_TRUE(arrsEquals(8,assertionShapeBuffer,tad->tadOnlyShapeInfo)); - delete tad; + shape::TAD *tad = new shape::TAD; + tad->init(inputShapeBuffer, dimension, dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + // shape::printShapeInfoLinear(inputShapeBuffer); + // shape::printShapeInfoLinear(tad->tadOnlyShapeInfo); + //[2,1,1,1,1,0,1,97] + ASSERT_TRUE(arrsEquals(8, assertionShapeBuffer, tad->tadOnlyShapeInfo)); + delete tad; } class TrailingTest : public testing::Test { -public: - Nd4jLong inputShapeBuffer[12] = {4,5,5,5,1,1,5,25,125,16384,1,102}; - int dimensionLength = 1; - int dimension[1] = {0}; - Nd4jLong assertionShapeBuffer[8] = {2,1,5,125,1,16384,1,102}; + public: + Nd4jLong inputShapeBuffer[12] = {4, 5, 5, 5, 1, 1, 5, 25, 125, 16384, 1, 102}; + int dimensionLength = 1; + int dimension[1] = {0}; + Nd4jLong assertionShapeBuffer[8] = {2, 1, 5, 125, 1, 16384, 1, 102}; }; -TEST_F(TrailingTest,TrailingTest2) { - shape::TAD *tad = new shape::TAD; - tad->init(inputShapeBuffer,dimension,dimensionLength); - tad->createTadOnlyShapeInfo(); - tad->createOffsets(); - //[2,1,1,1,1,0,1,97] - ASSERT_TRUE(arrsEquals(8,assertionShapeBuffer,tad->tadOnlyShapeInfo)); - delete tad; +TEST_F(TrailingTest, TrailingTest2) { + shape::TAD *tad = new shape::TAD; + tad->init(inputShapeBuffer, dimension, dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + //[2,1,1,1,1,0,1,97] + ASSERT_TRUE(arrsEquals(8, assertionShapeBuffer, tad->tadOnlyShapeInfo)); + delete tad; } - class ScalarTest : public testing::Test { -public: - Nd4jLong inputShapeBuffer[12] = {3,2,3,4,12,4,1,16384,1,99}; - int dimensionLength = 1; - int dimension[1] = {1}; - Nd4jLong assertionShapeBuffer[8] = {2,1,1,1,1,16384,1,99}; + public: + Nd4jLong inputShapeBuffer[12] = {3, 2, 3, 4, 12, 4, 1, 16384, 1, 99}; + int dimensionLength = 1; + int dimension[1] = {1}; + Nd4jLong assertionShapeBuffer[8] = {2, 1, 1, 1, 1, 16384, 1, 99}; }; /* TEST_F(ScalarTest,ScalarTest2) { - shape::TAD *tad = new shape::TAD(inputShapeBuffer,dimension,dimensionLength); + shape::TAD *tad = new +shape::TAD(inputShapeBuffer,dimension,dimensionLength); tad->createTadOnlyShapeInfo(); tad ->createOffsets(); //[2,1,1,1,1,0,1,97] @@ -439,37 +411,34 @@ TEST_F(ScalarTest,ScalarTest2) { } */ - - class ThreeTest : public testing::Test { -public: - Nd4jLong inputShapeBuffer[10] = {3,4,3,2,6,2,1,16384,1,99}; - int dimensionLength = 1; - int dimension[1] = {0}; - Nd4jLong assertionShapeBuffer[8] = {2,1,4,1,6,16384,6,99}; + public: + Nd4jLong inputShapeBuffer[10] = {3, 4, 3, 2, 6, 2, 1, 16384, 1, 99}; + int dimensionLength = 1; + int dimension[1] = {0}; + Nd4jLong assertionShapeBuffer[8] = {2, 1, 4, 1, 6, 16384, 6, 99}; }; -TEST_F(ThreeTest,ThreeTest ) { - shape::TAD *tad = new shape::TAD; - tad->init(inputShapeBuffer,dimension,dimensionLength); - tad->createTadOnlyShapeInfo(); - tad->createOffsets(); - //[2,1,1,1,1,0,1,97] - ASSERT_TRUE(arrsEquals(8,assertionShapeBuffer,tad->tadOnlyShapeInfo)); - delete tad; +TEST_F(ThreeTest, ThreeTest) { + shape::TAD *tad = new shape::TAD; + tad->init(inputShapeBuffer, dimension, dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + //[2,1,1,1,1,0,1,97] + ASSERT_TRUE(arrsEquals(8, assertionShapeBuffer, tad->tadOnlyShapeInfo)); + delete tad; } - TEST_F(BeginOneTadTest, TadTest) { - shape::TAD *tad = new shape::TAD; - tad->init(inputShapeBuffer,dimension,dimensionLength); - tad->createTadOnlyShapeInfo(); - auto tadShapeBuffer = tad->tadOnlyShapeInfo; - // shape::printShapeInfoLinear(tadShapeBuffer); - //[2,1,1,1,1,0,1,97] - ASSERT_TRUE(arrsEquals(8,assertionShapeBuffer,tadShapeBuffer)); - - delete tad; + shape::TAD *tad = new shape::TAD; + tad->init(inputShapeBuffer, dimension, dimensionLength); + tad->createTadOnlyShapeInfo(); + auto tadShapeBuffer = tad->tadOnlyShapeInfo; + // shape::printShapeInfoLinear(tadShapeBuffer); + //[2,1,1,1,1,0,1,97] + ASSERT_TRUE(arrsEquals(8, assertionShapeBuffer, tadShapeBuffer)); + + delete tad; } /* @@ -481,338 +450,352 @@ TEST_F(OnesTest,OnesTadTest) { } */ -TEST_F(LabelTest,LabelTad) { - shape::TAD *tad = new shape::TAD; - tad->init(shapeInfo,dimension,dimensionLength); - tad->createTadOnlyShapeInfo(); - auto tadShapeInfo = tad->tadOnlyShapeInfo; - ASSERT_TRUE(arrsEquals(8,tadShapeInfoAssert,tadShapeInfo)); +TEST_F(LabelTest, LabelTad) { + shape::TAD *tad = new shape::TAD; + tad->init(shapeInfo, dimension, dimensionLength); + tad->createTadOnlyShapeInfo(); + auto tadShapeInfo = tad->tadOnlyShapeInfo; + ASSERT_TRUE(arrsEquals(8, tadShapeInfoAssert, tadShapeInfo)); - delete tad; + delete tad; } -TEST_F(ExpectedValuesTest,TadTest) { - auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, mainShape); - shape::TAD *tad = new shape::TAD; - tad->init(shapeBuffer,testDimensions,3); - tad->createTadOnlyShapeInfo(); - auto shapeInfo = tad->tadOnlyShapeInfo; - - delete tad; - delete[] shapeBuffer; -} +TEST_F(ExpectedValuesTest, TadTest) { + auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, + 'c', 4, mainShape); + shape::TAD *tad = new shape::TAD; + tad->init(shapeBuffer, testDimensions, 3); + tad->createTadOnlyShapeInfo(); + auto shapeInfo = tad->tadOnlyShapeInfo; -TEST_F(OrderTest,testOrder) { - int rank = shape::rank(expected); - auto expectedShape = shape::shapeOf(expected); - auto expectedStride = shape::stride(expected); - int realOrder = shape::getOrder(rank,expectedShape,expectedStride,1); - int expectedOrder = 102; - ASSERT_EQ(expectedOrder,realOrder); + delete tad; + delete[] shapeBuffer; } - -TEST_F(ThreeDTest,TensorAlongDimensionTest) { - int dimension[2] = {0,2}; - Nd4jLong tadShapeAssertion[2] = {3,5}; - Nd4jLong strideAssertion[2] = {20,1}; - shape::TAD *tad = new shape::TAD; - tad->init(0,this->shapeBuffer,dimension,2); - tad->createTadOnlyShapeInfo(); - auto shapeBufferTest = tad->tadOnlyShapeInfo; - auto shapeTest = shape::shapeOf(shapeBufferTest); - auto strideTest = shape::stride(shapeBufferTest); - ASSERT_TRUE(arrsEquals(2,tadShapeAssertion,shapeTest)); - ASSERT_TRUE(arrsEquals(2,strideAssertion,strideTest)); - delete tad; +TEST_F(OrderTest, testOrder) { + int rank = shape::rank(expected); + auto expectedShape = shape::shapeOf(expected); + auto expectedStride = shape::stride(expected); + int realOrder = shape::getOrder(rank, expectedShape, expectedStride, 1); + int expectedOrder = 102; + ASSERT_EQ(expectedOrder, realOrder); } - -TEST_F(NumTadTests,TadTest) { - auto shape = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, this->shape); - shape::TAD *tad = new shape::TAD; - tad->init(shape,&dimension,1); - int numTads = shape::tensorsAlongDimension(shape,&dimension,1); - ASSERT_EQ(20,numTads); - delete[] shape; - delete tad; +TEST_F(ThreeDTest, TensorAlongDimensionTest) { + int dimension[2] = {0, 2}; + Nd4jLong tadShapeAssertion[2] = {3, 5}; + Nd4jLong strideAssertion[2] = {20, 1}; + shape::TAD *tad = new shape::TAD; + tad->init(0, this->shapeBuffer, dimension, 2); + tad->createTadOnlyShapeInfo(); + auto shapeBufferTest = tad->tadOnlyShapeInfo; + auto shapeTest = shape::shapeOf(shapeBufferTest); + auto strideTest = shape::stride(shapeBufferTest); + ASSERT_TRUE(arrsEquals(2, tadShapeAssertion, shapeTest)); + ASSERT_TRUE(arrsEquals(2, strideAssertion, strideTest)); + delete tad; } -TEST_F(TADStall,TestStall) { - auto shapeInfo = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shape); - shape::TAD *tad = new shape::TAD; - tad->init(0,shapeInfo,this->dimensions,3); - tad->createTadOnlyShapeInfo(); - Nd4jLong *test = tad->tadOnlyShapeInfo; - - delete[] shapeInfo; - delete tad; +TEST_F(NumTadTests, TadTest) { + auto shape = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, + this->shape); + shape::TAD *tad = new shape::TAD; + tad->init(shape, &dimension, 1); + int numTads = shape::tensorsAlongDimension(shape, &dimension, 1); + ASSERT_EQ(20, numTads); + delete[] shape; + delete tad; } +TEST_F(TADStall, TestStall) { + auto shapeInfo = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shape); + shape::TAD *tad = new shape::TAD; + tad->init(0, shapeInfo, this->dimensions, 3); + tad->createTadOnlyShapeInfo(); + Nd4jLong *test = tad->tadOnlyShapeInfo; -TEST_F(LengthPerSliceTest,TestLengthPerSlice) { - Nd4jLong firstShape[2] = {5,3}; - int lengthPerSliceAssertionFirst = 3; - int firstDimension = 0; - int lengthPerSliceTest = shape::lengthPerSlice(2,firstShape,&firstDimension,1); - ASSERT_EQ(lengthPerSliceAssertionFirst,lengthPerSliceTest); + delete[] shapeInfo; + delete tad; } -TEST_F(PermuteTest,PermuteShapeBufferTest) { - int permuteOrder[4] = {3,2,1,0}; - int normalOrder[4] = {0,1,2,3}; - Nd4jLong shapeToPermute[4] = {5,3,2,6}; - Nd4jLong permutedOrder[4] = {6,2,3,5}; - auto shapeBufferOriginal = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shapeToPermute); - auto assertionShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shapeToPermute); - shape::permuteShapeBufferInPlace(shapeBufferOriginal,normalOrder,shapeBufferOriginal); - EXPECT_TRUE(arrsEquals(4,assertionShapeBuffer,shapeBufferOriginal)); - - auto backwardsAssertion = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, permutedOrder); - auto permuted = shape::permuteShapeBuffer(assertionShapeBuffer, permuteOrder); - EXPECT_TRUE(arrsEquals(4, backwardsAssertion, permuted)); - - - delete[] permuted; - delete[] backwardsAssertion; - delete[] shapeBufferOriginal; - delete[] assertionShapeBuffer; +TEST_F(LengthPerSliceTest, TestLengthPerSlice) { + Nd4jLong firstShape[2] = {5, 3}; + int lengthPerSliceAssertionFirst = 3; + int firstDimension = 0; + int lengthPerSliceTest = + shape::lengthPerSlice(2, firstShape, &firstDimension, 1); + ASSERT_EQ(lengthPerSliceAssertionFirst, lengthPerSliceTest); } -TEST_F(ElementWiseStrideTest,ElementWiseStrideTest) { - +TEST_F(PermuteTest, PermuteShapeBufferTest) { + int permuteOrder[4] = {3, 2, 1, 0}; + int normalOrder[4] = {0, 1, 2, 3}; + Nd4jLong shapeToPermute[4] = {5, 3, 2, 6}; + Nd4jLong permutedOrder[4] = {6, 2, 3, 5}; + auto shapeBufferOriginal = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 4, shapeToPermute); + auto assertionShapeBuffer = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 4, shapeToPermute); + shape::permuteShapeBufferInPlace(shapeBufferOriginal, normalOrder, + shapeBufferOriginal); + EXPECT_TRUE(arrsEquals(4, assertionShapeBuffer, shapeBufferOriginal)); + + auto backwardsAssertion = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 4, permutedOrder); + auto permuted = shape::permuteShapeBuffer(assertionShapeBuffer, permuteOrder); + EXPECT_TRUE(arrsEquals(4, backwardsAssertion, permuted)); + + delete[] permuted; + delete[] backwardsAssertion; + delete[] shapeBufferOriginal; + delete[] assertionShapeBuffer; } -TEST_F(SliceVectorTest,RowColumnVectorTest) { - Nd4jLong rowVectorShape[2] = {1,5}; - auto rowVectorShapeInfo = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, rowVectorShape); - Nd4jLong colVectorShape[2] = {5,1}; - auto colVectorShapeInfo = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, colVectorShape); - Nd4jLong *sliceRow = shape::sliceOfShapeBuffer(0,rowVectorShapeInfo); - EXPECT_TRUE(arrsEquals(2,rowVectorShapeInfo,sliceRow)); - Nd4jLong *scalarSliceInfo = shape::createScalarShapeInfo(); - Nd4jLong *scalarColumnAssertion = shape::createScalarShapeInfo(); - scalarColumnAssertion[shape::shapeInfoLength(2) - 3] = 1; - Nd4jLong *scalarColumnTest = shape::sliceOfShapeBuffer(1L,colVectorShapeInfo); - EXPECT_TRUE(arrsEquals(2,scalarColumnAssertion,scalarColumnTest)); - - delete[] scalarColumnTest; - delete[] scalarColumnAssertion; - delete[] scalarSliceInfo; - delete[] sliceRow; - delete[] rowVectorShapeInfo; - delete[] colVectorShapeInfo; +TEST_F(ElementWiseStrideTest, ElementWiseStrideTest) {} + +TEST_F(SliceVectorTest, RowColumnVectorTest) { + Nd4jLong rowVectorShape[2] = {1, 5}; + auto rowVectorShapeInfo = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 2, rowVectorShape); + Nd4jLong colVectorShape[2] = {5, 1}; + auto colVectorShapeInfo = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 2, colVectorShape); + Nd4jLong *sliceRow = shape::sliceOfShapeBuffer(0, rowVectorShapeInfo); + EXPECT_TRUE(arrsEquals(2, rowVectorShapeInfo, sliceRow)); + Nd4jLong *scalarSliceInfo = shape::createScalarShapeInfo(); + Nd4jLong *scalarColumnAssertion = shape::createScalarShapeInfo(); + scalarColumnAssertion[shape::shapeInfoLength(2) - 3] = 1; + Nd4jLong *scalarColumnTest = + shape::sliceOfShapeBuffer(1L, colVectorShapeInfo); + EXPECT_TRUE(arrsEquals(2, scalarColumnAssertion, scalarColumnTest)); + + delete[] scalarColumnTest; + delete[] scalarColumnAssertion; + delete[] scalarSliceInfo; + delete[] sliceRow; + delete[] rowVectorShapeInfo; + delete[] colVectorShapeInfo; } -TEST_F(SliceTensorTest,TestSlice) { - Nd4jLong shape[3] = {3,3,2}; - auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); - Nd4jLong sliceShape[2] = {3,2}; - auto sliceShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, sliceShape); - Nd4jLong *testSlice = shape::sliceOfShapeBuffer(0,shapeBuffer); - EXPECT_TRUE(arrsEquals(2,sliceShapeBuffer,testSlice)); - delete[] testSlice; - delete[] shapeBuffer; - delete[] sliceShapeBuffer; - -} - -TEST_F(SliceMatrixTest,TestSlice) { - Nd4jLong shape[2] = {3,2}; - auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, shape); - Nd4jLong sliceShape[2] = {1,2}; - auto sliceShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, sliceShape); - Nd4jLong *testSlice = shape::sliceOfShapeBuffer(0,shapeBuffer); - EXPECT_TRUE(arrsEquals(2,sliceShapeBuffer,testSlice)); - delete[] testSlice; - delete[] shapeBuffer; - delete[] sliceShapeBuffer; - +TEST_F(SliceTensorTest, TestSlice) { + Nd4jLong shape[3] = {3, 3, 2}; + auto shapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); + Nd4jLong sliceShape[2] = {3, 2}; + auto sliceShapeBuffer = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 2, sliceShape); + Nd4jLong *testSlice = shape::sliceOfShapeBuffer(0, shapeBuffer); + EXPECT_TRUE(arrsEquals(2, sliceShapeBuffer, testSlice)); + delete[] testSlice; + delete[] shapeBuffer; + delete[] sliceShapeBuffer; } - -TEST_F(TestConcat,ConcatTest) { - Nd4jLong firstArr[2] = {1,2}; - Nd4jLong secondConcat[2] = {3,4}; - Nd4jLong concatAssertion[4] = {1,2,3,4}; - Nd4jLong *concatTest = shape::concat(firstArr,2,secondConcat,2); - EXPECT_TRUE(arrsEquals(4,concatAssertion,concatTest)); - delete[] concatTest; +TEST_F(SliceMatrixTest, TestSlice) { + Nd4jLong shape[2] = {3, 2}; + auto shapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, shape); + Nd4jLong sliceShape[2] = {1, 2}; + auto sliceShapeBuffer = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 2, sliceShape); + Nd4jLong *testSlice = shape::sliceOfShapeBuffer(0, shapeBuffer); + EXPECT_TRUE(arrsEquals(2, sliceShapeBuffer, testSlice)); + delete[] testSlice; + delete[] shapeBuffer; + delete[] sliceShapeBuffer; } -TEST_F(TestReverseCopy,ReverseCopyTest) { - Nd4jLong toCopy[5] = {0,1,2,3,4}; - Nd4jLong reverseAssertion[5] = {4,3,2,1,0}; - Nd4jLong *reverseCopyTest = shape::reverseCopy(toCopy,5); - EXPECT_TRUE(arrsEquals(5,reverseAssertion,reverseCopyTest)); - delete[] reverseCopyTest; +TEST_F(TestConcat, ConcatTest) { + Nd4jLong firstArr[2] = {1, 2}; + Nd4jLong secondConcat[2] = {3, 4}; + Nd4jLong concatAssertion[4] = {1, 2, 3, 4}; + Nd4jLong *concatTest = shape::concat(firstArr, 2, secondConcat, 2); + EXPECT_TRUE(arrsEquals(4, concatAssertion, concatTest)); + delete[] concatTest; } -TEST_F(TestRemoveIndex,Remove) { - Nd4jLong input[5] = {0,1,2,3,4}; - Nd4jLong indexesToRemove[3] = {0,1,2}; - Nd4jLong indexesToRemoveAssertion[2] = {3,4}; - Nd4jLong *indexesToRemoveTest = shape::removeIndex(input,indexesToRemove, (Nd4jLong) 5, (Nd4jLong) 3); - EXPECT_TRUE(arrsEquals(2,indexesToRemoveAssertion,indexesToRemoveTest)); - delete[] indexesToRemoveTest; +TEST_F(TestReverseCopy, ReverseCopyTest) { + Nd4jLong toCopy[5] = {0, 1, 2, 3, 4}; + Nd4jLong reverseAssertion[5] = {4, 3, 2, 1, 0}; + Nd4jLong *reverseCopyTest = shape::reverseCopy(toCopy, 5); + EXPECT_TRUE(arrsEquals(5, reverseAssertion, reverseCopyTest)); + delete[] reverseCopyTest; } -TEST_F(TensorTwoFromFourDDimTest,TadTwoFromFourDimTest) { - //Along dimension 0,1: expect matrix with shape [rows,cols] - //Along dimension 0,2: expect matrix with shape [rows,dim2] - //Along dimension 0,3: expect matrix with shape [rows,dim3] - //Along dimension 1,2: expect matrix with shape [cols,dim2] - //Along dimension 1,3: expect matrix with shape [cols,dim3] - //Along dimension 2,3: expect matrix with shape [dim2,dim3] - auto baseShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shape); - for(int i = 0; i < 3; i++) { - int *dimArr = dims[i]; - Nd4jLong *expectedShape = expectedShapes[i]; - shape::TAD *tad = new shape::TAD; - tad->init(baseShapeBuffer,dimArr,dimensionLength); - auto expectedShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', dimensionLength, expectedShape); - tad->createTadOnlyShapeInfo(); - Nd4jLong *testShapeBuffer = tad->tadOnlyShapeInfo; - EXPECT_TRUE(arrsEquals(shape::rank(expectedShapeBuffer),expectedShape,shape::shapeOf(testShapeBuffer))); - EXPECT_TRUE(arrsEquals(shape::rank(expectedShapeBuffer),expectedStrides[i],shape::stride(testShapeBuffer))); - - delete[] expectedShapeBuffer; - delete tad; - } - - delete[] baseShapeBuffer; +TEST_F(TestRemoveIndex, Remove) { + Nd4jLong input[5] = {0, 1, 2, 3, 4}; + Nd4jLong indexesToRemove[3] = {0, 1, 2}; + Nd4jLong indexesToRemoveAssertion[2] = {3, 4}; + Nd4jLong *indexesToRemoveTest = shape::removeIndex( + input, indexesToRemove, (Nd4jLong)5, (Nd4jLong)3); + EXPECT_TRUE(arrsEquals(2, indexesToRemoveAssertion, indexesToRemoveTest)); + delete[] indexesToRemoveTest; } -TEST_F(TensorTwoDimTest,TadTwoDimTest) { - //Along dimension 0,1: expect matrix with shape [rows,cols] - //Along dimension 0,2: expect matrix with shape [rows,dim2] - //Along dimension 1,2: expect matrix with shape [cols,dim2] - auto baseShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); - - for(int i = 0; i < 3; i++) { - int *dimArr = dims[i]; - Nd4jLong *expectedShape = expectedShapes[i]; - shape::TAD *tad = new shape::TAD; - tad->init(baseShapeBuffer,dimArr,dimensionLength); - auto expectedShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', dimensionLength, expectedShape); - tad->createTadOnlyShapeInfo(); - Nd4jLong *testShapeBuffer = tad->tadOnlyShapeInfo; - Nd4jLong *expectedStride = expectedStrides[i]; - Nd4jLong *testShape = shape::shapeOf(testShapeBuffer); - Nd4jLong *testStride = shape::stride(testShapeBuffer); - EXPECT_TRUE(arrsEquals(shape::rank(expectedShapeBuffer),expectedShape,testShape)); - EXPECT_TRUE(arrsEquals(shape::rank(testShapeBuffer),expectedStride,testStride)); +TEST_F(TensorTwoFromFourDDimTest, TadTwoFromFourDimTest) { + // Along dimension 0,1: expect matrix with shape [rows,cols] + // Along dimension 0,2: expect matrix with shape [rows,dim2] + // Along dimension 0,3: expect matrix with shape [rows,dim3] + // Along dimension 1,2: expect matrix with shape [cols,dim2] + // Along dimension 1,3: expect matrix with shape [cols,dim3] + // Along dimension 2,3: expect matrix with shape [dim2,dim3] + auto baseShapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shape); + for (int i = 0; i < 3; i++) { + int *dimArr = dims[i]; + Nd4jLong *expectedShape = expectedShapes[i]; + shape::TAD *tad = new shape::TAD; + tad->init(baseShapeBuffer, dimArr, dimensionLength); + auto expectedShapeBuffer = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', dimensionLength, expectedShape); + tad->createTadOnlyShapeInfo(); + Nd4jLong *testShapeBuffer = tad->tadOnlyShapeInfo; + EXPECT_TRUE(arrsEquals(shape::rank(expectedShapeBuffer), expectedShape, + shape::shapeOf(testShapeBuffer))); + EXPECT_TRUE(arrsEquals(shape::rank(expectedShapeBuffer), expectedStrides[i], + shape::stride(testShapeBuffer))); - delete[] expectedShapeBuffer; - delete tad; + delete[] expectedShapeBuffer; + delete tad; + } - } + delete[] baseShapeBuffer; +} - delete[] baseShapeBuffer; +TEST_F(TensorTwoDimTest, TadTwoDimTest) { + // Along dimension 0,1: expect matrix with shape [rows,cols] + // Along dimension 0,2: expect matrix with shape [rows,dim2] + // Along dimension 1,2: expect matrix with shape [cols,dim2] + auto baseShapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); + for (int i = 0; i < 3; i++) { + int *dimArr = dims[i]; + Nd4jLong *expectedShape = expectedShapes[i]; + shape::TAD *tad = new shape::TAD; + tad->init(baseShapeBuffer, dimArr, dimensionLength); + auto expectedShapeBuffer = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', dimensionLength, expectedShape); + tad->createTadOnlyShapeInfo(); + Nd4jLong *testShapeBuffer = tad->tadOnlyShapeInfo; + Nd4jLong *expectedStride = expectedStrides[i]; + Nd4jLong *testShape = shape::shapeOf(testShapeBuffer); + Nd4jLong *testStride = shape::stride(testShapeBuffer); + EXPECT_TRUE( + arrsEquals(shape::rank(expectedShapeBuffer), expectedShape, testShape)); + EXPECT_TRUE( + arrsEquals(shape::rank(testShapeBuffer), expectedStride, testStride)); + + delete[] expectedShapeBuffer; + delete tad; + } + delete[] baseShapeBuffer; } -TEST_F(TensorOneDimTest,TadDimensionsForTensor) { - Nd4jLong shape[3] = {rows,cols,dim2}; - auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', rank, shape); - - for(int i = 0; i < rank; i++) { - //Along dimension 0: expect row vector with length 'dims[i]' - shape::TAD *zero = new shape::TAD; - zero->init(shapeBuffer,&dims[i],1); - zero->createTadOnlyShapeInfo(); - Nd4jLong *testDimZeroShapeBuffer = zero->tadOnlyShapeInfo; - Nd4jLong *testShape = shape::shapeOf(testDimZeroShapeBuffer); - Nd4jLong *testStride = shape::stride(testDimZeroShapeBuffer); - EXPECT_TRUE(arrsEquals(2,expectedShapes[i],testShape)); - EXPECT_TRUE(arrsEquals(2,expectedStrides[i],testStride)); - - delete zero; - } - - delete[] shapeBuffer; +TEST_F(TensorOneDimTest, TadDimensionsForTensor) { + Nd4jLong shape[3] = {rows, cols, dim2}; + auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, + 'c', rank, shape); + + for (int i = 0; i < rank; i++) { + // Along dimension 0: expect row vector with length 'dims[i]' + shape::TAD *zero = new shape::TAD; + zero->init(shapeBuffer, &dims[i], 1); + zero->createTadOnlyShapeInfo(); + Nd4jLong *testDimZeroShapeBuffer = zero->tadOnlyShapeInfo; + Nd4jLong *testShape = shape::shapeOf(testDimZeroShapeBuffer); + Nd4jLong *testStride = shape::stride(testDimZeroShapeBuffer); + EXPECT_TRUE(arrsEquals(2, expectedShapes[i], testShape)); + EXPECT_TRUE(arrsEquals(2, expectedStrides[i], testStride)); + + delete zero; + } + + delete[] shapeBuffer; } - -TEST_F(MatrixTest,TadDimensionsForMatrix) { - Nd4jLong shape[2] = {rows,cols}; - auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', rank, shape); - - shape::TAD *dimZero = new shape::TAD; - dimZero->init(shapeBuffer,&dims[0],1); - shape::TAD *dimOne = new shape::TAD; - dimOne->init(shapeBuffer,&dims[1],1); - //Along dimension 0: expect row vector with length 'rows' - Nd4jLong rowVectorShape[2] = {1,rows}; - auto expectedDimZeroShape = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, rowVectorShape); - dimZero->createTadOnlyShapeInfo(); - Nd4jLong *testDimZero = dimZero->tadOnlyShapeInfo; - EXPECT_TRUE(arrsEquals(2,expectedShapes[0],shape::shapeOf(testDimZero))); - EXPECT_TRUE(arrsEquals(2,expectedStrides[0],shape::stride(testDimZero))); - - delete[] expectedDimZeroShape; - //Along dimension 1: expect row vector with length 'cols' - Nd4jLong rowVectorColShape[2] {1,cols}; - auto expectedDimOneShape = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, rowVectorColShape); - dimOne->createTadOnlyShapeInfo(); - Nd4jLong *testDimOneShape = dimOne->tadOnlyShapeInfo; - EXPECT_TRUE(arrsEquals(2,expectedShapes[1],shape::shapeOf(testDimOneShape))); - EXPECT_TRUE(arrsEquals(2,expectedStrides[1],shape::stride(testDimOneShape))); - - delete[] expectedDimOneShape; - delete dimOne; - delete dimZero; - delete[] shapeBuffer; +TEST_F(MatrixTest, TadDimensionsForMatrix) { + Nd4jLong shape[2] = {rows, cols}; + auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, + 'c', rank, shape); + + shape::TAD *dimZero = new shape::TAD; + dimZero->init(shapeBuffer, &dims[0], 1); + shape::TAD *dimOne = new shape::TAD; + dimOne->init(shapeBuffer, &dims[1], 1); + // Along dimension 0: expect row vector with length 'rows' + Nd4jLong rowVectorShape[2] = {1, rows}; + auto expectedDimZeroShape = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 2, rowVectorShape); + dimZero->createTadOnlyShapeInfo(); + Nd4jLong *testDimZero = dimZero->tadOnlyShapeInfo; + EXPECT_TRUE(arrsEquals(2, expectedShapes[0], shape::shapeOf(testDimZero))); + EXPECT_TRUE(arrsEquals(2, expectedStrides[0], shape::stride(testDimZero))); + + delete[] expectedDimZeroShape; + // Along dimension 1: expect row vector with length 'cols' + Nd4jLong rowVectorColShape[2]{1, cols}; + auto expectedDimOneShape = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 2, rowVectorColShape); + dimOne->createTadOnlyShapeInfo(); + Nd4jLong *testDimOneShape = dimOne->tadOnlyShapeInfo; + EXPECT_TRUE( + arrsEquals(2, expectedShapes[1], shape::shapeOf(testDimOneShape))); + EXPECT_TRUE( + arrsEquals(2, expectedStrides[1], shape::stride(testDimOneShape))); + + delete[] expectedDimOneShape; + delete dimOne; + delete dimZero; + delete[] shapeBuffer; } -TEST_F(VectorTest,VectorTadShape) { - Nd4jLong rowVector[2] = {2,2}; - auto rowBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, rowVector); - int rowDimension = 1; - - Nd4jLong columnVector[2] = {2,2}; - auto colShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, columnVector); - int colDimension = 0; - - - shape::TAD *rowTad = new shape::TAD; - rowTad->init(rowBuffer,&rowDimension,1); - rowTad->createTadOnlyShapeInfo(); - Nd4jLong *rowTadShapeBuffer = rowTad->tadOnlyShapeInfo; - Nd4jLong *rowTadShape = shape::shapeOf(rowTadShapeBuffer); - shape::TAD *colTad = new shape::TAD; - colTad->init(colShapeBuffer,&colDimension,1); - colTad->createTadOnlyShapeInfo(); - Nd4jLong *colTadShapeBuffer = colTad->tadOnlyShapeInfo; - Nd4jLong *colTadShape = shape::shapeOf(colTadShapeBuffer); - Nd4jLong assertionShape[2] = {1,2}; - Nd4jLong assertionStride[2] = {1,1}; - EXPECT_TRUE(arrsEquals(2,assertionShape,rowTadShape)); - EXPECT_TRUE(arrsEquals(2,assertionStride,shape::stride(rowTadShapeBuffer))); - EXPECT_TRUE(arrsEquals(2,assertionShape,colTadShape)); - - delete[] rowBuffer; - delete[] colShapeBuffer; - delete rowTad; - delete colTad; +TEST_F(VectorTest, VectorTadShape) { + Nd4jLong rowVector[2] = {2, 2}; + auto rowBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, + 'c', 2, rowVector); + int rowDimension = 1; + + Nd4jLong columnVector[2] = {2, 2}; + auto colShapeBuffer = sd::ShapeBuilders::createShapeInfo( + sd::DataType::FLOAT32, 'c', 2, columnVector); + int colDimension = 0; + + shape::TAD *rowTad = new shape::TAD; + rowTad->init(rowBuffer, &rowDimension, 1); + rowTad->createTadOnlyShapeInfo(); + Nd4jLong *rowTadShapeBuffer = rowTad->tadOnlyShapeInfo; + Nd4jLong *rowTadShape = shape::shapeOf(rowTadShapeBuffer); + shape::TAD *colTad = new shape::TAD; + colTad->init(colShapeBuffer, &colDimension, 1); + colTad->createTadOnlyShapeInfo(); + Nd4jLong *colTadShapeBuffer = colTad->tadOnlyShapeInfo; + Nd4jLong *colTadShape = shape::shapeOf(colTadShapeBuffer); + Nd4jLong assertionShape[2] = {1, 2}; + Nd4jLong assertionStride[2] = {1, 1}; + EXPECT_TRUE(arrsEquals(2, assertionShape, rowTadShape)); + EXPECT_TRUE(arrsEquals(2, assertionStride, shape::stride(rowTadShapeBuffer))); + EXPECT_TRUE(arrsEquals(2, assertionShape, colTadShape)); + + delete[] rowBuffer; + delete[] colShapeBuffer; + delete rowTad; + delete colTad; } +TEST_F(ShapeTest, IsVector) { ASSERT_TRUE(shape::isVector(vectorShape, 2)); } +TEST_F(VectorTest, LinspaceCombinationTest) { + int rows = 3; + int cols = 4; + int len = rows * cols; + double *linspaced = linspace(1, rows * cols, len); + Nd4jLong shape[2] = {rows, cols}; + auto shapeBuffer = + sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, shape); - -TEST_F(ShapeTest,IsVector) { - ASSERT_TRUE(shape::isVector(vectorShape,2)); -} - -TEST_F(VectorTest,LinspaceCombinationTest) { - int rows = 3; - int cols = 4; - int len = rows * cols; - double *linspaced = linspace(1,rows * cols,len); - Nd4jLong shape[2] = {rows,cols}; - auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, shape); - - delete[] shapeBuffer; - delete[] linspaced; + delete[] shapeBuffer; + delete[] linspaced; } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ShapeUtilsTests.cpp b/libnd4j/tests_cpu/layers_tests/ShapeUtilsTests.cpp index 25f4f2c18001..8644960932a7 100644 --- a/libnd4j/tests_cpu/layers_tests/ShapeUtilsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ShapeUtilsTests.cpp @@ -18,276 +18,253 @@ // Created by raver119 on 01.11.2017. // -#include "testlayers.h" -#include #include +#include +#include "testlayers.h" using namespace sd; using namespace sd::graph; class ShapeUtilsTests : public testing::Test { -public: - + public: }; ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalDimsToExclude_1) { - std::vector res = ShapeUtils::evalDimsToExclude(3, {0}); + std::vector res = ShapeUtils::evalDimsToExclude(3, {0}); - ASSERT_EQ(2, res.size()); - ASSERT_EQ(1, res.at(0)); - ASSERT_EQ(2, res.at(1)); + ASSERT_EQ(2, res.size()); + ASSERT_EQ(1, res.at(0)); + ASSERT_EQ(2, res.at(1)); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalDimsToExclude_2) { - std::vector res = ShapeUtils::evalDimsToExclude(4, {2, 3}); + std::vector res = ShapeUtils::evalDimsToExclude(4, {2, 3}); - ASSERT_EQ(2, res.size()); - ASSERT_EQ(0, res.at(0)); - ASSERT_EQ(1, res.at(1)); + ASSERT_EQ(2, res.size()); + ASSERT_EQ(0, res.at(0)); + ASSERT_EQ(1, res.at(1)); } ////////////////////////////////////////////////////////////////// -TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_1) -{ - - Nd4jLong xShapeInfo[] = {3, 3, 2, 2, 4, 2, 1, 8192, 1, 99}; - Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; - Nd4jLong expShapeInfo[] = {3, 3, 2, 2, 4, 2, 1, 8192, 1, 99}; +TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_1) { + Nd4jLong xShapeInfo[] = {3, 3, 2, 2, 4, 2, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {3, 3, 2, 2, 4, 2, 1, 8192, 1, 99}; - NDArray x(xShapeInfo); - NDArray y(yShapeInfo); + NDArray x(xShapeInfo); + NDArray y(yShapeInfo); - const Nd4jLong *newShapeInfo = nullptr; - ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); + const Nd4jLong *newShapeInfo = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); - ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo)); + ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo)); } ////////////////////////////////////////////////////////////////// -TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_2) -{ +TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_2) { + Nd4jLong xShapeInfo[] = {4, 8, 1, 6, 1, 6, 6, 1, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {3, 7, 1, 5, 5, 5, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {4, 8, 7, 6, 5, 210, 30, 5, 1, 8192, 1, 99}; - Nd4jLong xShapeInfo[] = {4, 8, 1, 6, 1, 6, 6, 1, 1, 8192, 1, 99}; - Nd4jLong yShapeInfo[] = {3, 7, 1, 5, 5, 5, 1, 8192, 1, 99}; - Nd4jLong expShapeInfo[] = {4, 8, 7, 6, 5, 210, 30, 5, 1, 8192, 1, 99}; + NDArray x(xShapeInfo); + NDArray y(yShapeInfo); - NDArray x(xShapeInfo); - NDArray y(yShapeInfo); + const Nd4jLong *newShapeInfo = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); - const Nd4jLong *newShapeInfo = nullptr; - ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); - - ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo)); + ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo)); } ////////////////////////////////////////////////////////////////// -TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_3) -{ - - Nd4jLong xShapeInfo[] = {3, 15, 3, 5, 15, 5, 1, 8192, 1, 99}; - Nd4jLong yShapeInfo[] = {3, 15, 1, 5, 5, 5, 1, 8192, 1, 99}; - Nd4jLong expShapeInfo[] = {3, 15, 3, 5, 15, 5, 1, 8192, 1, 99}; +TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_3) { + Nd4jLong xShapeInfo[] = {3, 15, 3, 5, 15, 5, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {3, 15, 1, 5, 5, 5, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {3, 15, 3, 5, 15, 5, 1, 8192, 1, 99}; - NDArray x(xShapeInfo); - NDArray y(yShapeInfo); + NDArray x(xShapeInfo); + NDArray y(yShapeInfo); - const Nd4jLong *newShapeInfo = nullptr; - ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); + const Nd4jLong *newShapeInfo = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); - ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo)); + ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo)); } ////////////////////////////////////////////////////////////////// -TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_4) -{ +TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_4) { + Nd4jLong xShapeInfo[] = {3, 8, 1, 3, 3, 3, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {2, 4, 3, 3, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {3, 8, 4, 3, 12, 3, 1, 8192, 1, 99}; - Nd4jLong xShapeInfo[] = {3, 8, 1, 3, 3, 3, 1, 8192, 1, 99}; - Nd4jLong yShapeInfo[] = {2, 4, 3, 3, 1, 8192, 1, 99}; - Nd4jLong expShapeInfo[] = {3, 8, 4, 3, 12, 3, 1, 8192, 1, 99}; + NDArray x(xShapeInfo); + NDArray y(yShapeInfo); - NDArray x(xShapeInfo); - NDArray y(yShapeInfo); + const Nd4jLong *newShapeInfo = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); + // for(int i=0; i<2*newShapeInfo[0]+4; ++i) + // std::cout<('c',{2,3,4,5}); - auto expected = NDArrayFactory::create('c', {2,4,5}); - std::vector dimensions = {1}; +TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test1) { + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto expected = NDArrayFactory::create('c', {2, 4, 5}); + std::vector dimensions = {1}; - auto newShapeInfo = ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo()); + auto newShapeInfo = + ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo()); - ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); + ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); } ////////////////////////////////////////////////////////////////// -TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test2) -{ +TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test2) { + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto expected = NDArrayFactory::create('c', {2, 1, 4, 5}); + std::vector dimensions = {1}; - auto x = NDArrayFactory::create('c',{2,3,4,5}); - auto expected = NDArrayFactory::create('c', {2,1,4,5}); - std::vector dimensions = {1}; + auto newShapeInfo = + ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo(), true); - auto newShapeInfo = ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo(), true); - - ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); + ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); } ////////////////////////////////////////////////////////////////// -TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test3) -{ - - auto x = NDArrayFactory::create('c',{2,3,4,5}); - auto expected = NDArrayFactory::create('c', {1,1,1,5}); - std::vector dimensions = {0,1,2}; +TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test3) { + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto expected = NDArrayFactory::create('c', {1, 1, 1, 5}); + std::vector dimensions = {0, 1, 2}; - auto newShapeInfo = ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo(), true); - - ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); + auto newShapeInfo = + ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo(), true); + ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); } ////////////////////////////////////////////////////////////////// -TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test4) -{ - - auto x = NDArrayFactory::create('c',{2,3,4,5}); - auto expected = NDArrayFactory::create('c', {1,1,1,1}); - std::vector dimensions = {0,1,2,3}; +TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test4) { + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + auto expected = NDArrayFactory::create('c', {1, 1, 1, 1}); + std::vector dimensions = {0, 1, 2, 3}; - auto newShapeInfo = ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo(), true); + auto newShapeInfo = + ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo(), true); - ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); + ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); } TEST_F(ShapeUtilsTests, Test_Strings_1) { - auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); - std::string exp("[2, 3, 4, 5]"); + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + std::string exp("[2, 3, 4, 5]"); - auto s = ShapeUtils::shapeAsString(&x); + auto s = ShapeUtils::shapeAsString(&x); - ASSERT_EQ(exp, s); + ASSERT_EQ(exp, s); } TEST_F(ShapeUtilsTests, Test_Backward_Axis_1) { - auto x = NDArrayFactory::create('c', {2, 4, 3}); - auto y = NDArrayFactory::create('c', {4, 3}); - std::vector exp({0}); + auto x = NDArrayFactory::create('c', {2, 4, 3}); + auto y = NDArrayFactory::create('c', {4, 3}); + std::vector exp({0}); - auto z = ShapeUtils::evalBroadcastBackwardAxis(y.shapeInfo(), x.shapeInfo()); + auto z = ShapeUtils::evalBroadcastBackwardAxis(y.shapeInfo(), x.shapeInfo()); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } TEST_F(ShapeUtilsTests, Test_Backward_Axis_2) { - auto x = NDArrayFactory::create('c', {2, 4, 4, 3}); - auto y = NDArrayFactory::create('c', {4, 1, 3}); - std::vector exp({0, 2}); + auto x = NDArrayFactory::create('c', {2, 4, 4, 3}); + auto y = NDArrayFactory::create('c', {4, 1, 3}); + std::vector exp({0, 2}); - auto z = ShapeUtils::evalBroadcastBackwardAxis(y.shapeInfo(), x.shapeInfo()); + auto z = ShapeUtils::evalBroadcastBackwardAxis(y.shapeInfo(), x.shapeInfo()); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } - TEST_F(ShapeUtilsTests, Test_Backward_Axis_3) { - auto x = NDArrayFactory::create('c', {2, 4, 4, 3}); - auto y = NDArrayFactory::create('c', {2, 1, 1, 3}); - std::vector exp({1, 2}); + auto x = NDArrayFactory::create('c', {2, 4, 4, 3}); + auto y = NDArrayFactory::create('c', {2, 1, 1, 3}); + std::vector exp({1, 2}); - auto z = ShapeUtils::evalBroadcastBackwardAxis(y.shapeInfo(), x.shapeInfo()); + auto z = ShapeUtils::evalBroadcastBackwardAxis(y.shapeInfo(), x.shapeInfo()); - ASSERT_EQ(exp, z); + ASSERT_EQ(exp, z); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalPermutFromTo_test1) { + int a = 1, b = 2, c = 3, d = 4; + std::vector expected = {2, 3, 0, 1}; - int a=1, b=2, c=3, d=4; - std::vector expected = {2, 3, 0, 1}; - - std::vector result = ShapeUtils::evalPermutFromTo({a,b,c,d}, {c,d,a,b}); - - ASSERT_TRUE(std::equal(begin(expected), end(expected), begin(result))); + std::vector result = + ShapeUtils::evalPermutFromTo({a, b, c, d}, {c, d, a, b}); + ASSERT_TRUE(std::equal(begin(expected), end(expected), begin(result))); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalPermutFromTo_test2) { + int a = 1, b = 2, c = 3, d = 4; + std::vector expected = {0, 1, 3, 2}; - int a=1, b=2, c=3, d=4; - std::vector expected = {0, 1, 3, 2}; - - std::vector result = ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,d,c}); - - ASSERT_TRUE(std::equal(begin(expected), end(expected), begin(result))); + std::vector result = + ShapeUtils::evalPermutFromTo({a, b, c, d}, {a, b, d, c}); + ASSERT_TRUE(std::equal(begin(expected), end(expected), begin(result))); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalPermutFromTo_test3) { + int a = 2, b = 2, c = 3, d = 2; + std::vector expected = {0, 1, 3, 2}; - int a=2, b=2, c=3, d=2; - std::vector expected = {0, 1, 3, 2}; - - std::vector result = ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,d,c}); - - ASSERT_TRUE(std::equal(begin(expected), end(expected), begin(result))); + std::vector result = + ShapeUtils::evalPermutFromTo({a, b, c, d}, {a, b, d, c}); + ASSERT_TRUE(std::equal(begin(expected), end(expected), begin(result))); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalPermutFromTo_test4) { + int a = 2, b = 3, c = 4, d = 5; - int a=2, b=3, c=4, d=5; - - std::vector result = ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,c,d}); - - ASSERT_TRUE(result.empty()); + std::vector result = + ShapeUtils::evalPermutFromTo({a, b, c, d}, {a, b, c, d}); + ASSERT_TRUE(result.empty()); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalPermutFromTo_test5) { + int a = 1, b = 2, c = 3, d = 4; - int a=1, b=2, c=3, d=4; - - // EXPECT_THROW(ShapeUtils::evalPermutFromTo({a,b,c,d}, {c,d,a,8}), const char*); - ASSERT_TRUE(1); + // EXPECT_THROW(ShapeUtils::evalPermutFromTo({a,b,c,d}, {c,d,a,8}), const + // char*); + ASSERT_TRUE(1); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, evalPermutFromTo_test6) { + int a = 1, b = 2, c = 3, d = 4; - int a=1, b=2, c=3, d=4; - - // EXPECT_THROW(ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,c,d,d}), const char*); - ASSERT_TRUE(1); + // EXPECT_THROW(ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,c,d,d}), const + // char*); + ASSERT_TRUE(1); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, isPermutNecessary_test1) { - - ASSERT_TRUE(ShapeUtils::isPermutNecessary({1,0,2,3})); + ASSERT_TRUE(ShapeUtils::isPermutNecessary({1, 0, 2, 3})); } ////////////////////////////////////////////////////////////////// TEST_F(ShapeUtilsTests, isPermutNecessary_test2) { - - ASSERT_TRUE(!ShapeUtils::isPermutNecessary({0,1,2,3})); + ASSERT_TRUE(!ShapeUtils::isPermutNecessary({0, 1, 2, 3})); } - - diff --git a/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp b/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp index b8ac4c8876c5..ff1034a25f58 100644 --- a/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/SingleDimTests.cpp @@ -18,168 +18,152 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class SingleDimTests : public testing::Test { -public: - + public: }; TEST_F(SingleDimTests, Test_Create_1) { - auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); - ASSERT_EQ(5, x.lengthOf()); - ASSERT_EQ(1, x.rankOf()); - ASSERT_TRUE(x.isVector()); - ASSERT_TRUE(x.isRowVector()); - ASSERT_FALSE(x.isMatrix()); + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + ASSERT_EQ(5, x.lengthOf()); + ASSERT_EQ(1, x.rankOf()); + ASSERT_TRUE(x.isVector()); + ASSERT_TRUE(x.isRowVector()); + ASSERT_FALSE(x.isMatrix()); } TEST_F(SingleDimTests, Test_Add_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3}, {2, 3, 4}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {2, 3, 4}); - x += 1.0f; + x += 1.0f; - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } - TEST_F(SingleDimTests, Test_Pairwise_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3}, {2, 4, 6}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {2, 4, 6}); - x += x; + x += x; - ASSERT_TRUE(exp.isSameShape(&x)); - ASSERT_TRUE(exp.equalsTo(&x)); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); } TEST_F(SingleDimTests, Test_Concat_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto y = NDArrayFactory::create('c', {3}, {4, 5, 6}); - auto exp = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto y = NDArrayFactory::create('c', {3}, {4, 5, 6}); + auto exp = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); - sd::ops::concat op; - auto result = op.evaluate({&x, &y}, {}, {0}); + sd::ops::concat op; + auto result = op.evaluate({&x, &y}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(SingleDimTests, Test_Reduce_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - float r = x.reduceNumber(reduce::Sum).e(0); + float r = x.reduceNumber(reduce::Sum).e(0); - ASSERT_NEAR(6.0f, r, 1e-5f); + ASSERT_NEAR(6.0f, r, 1e-5f); } TEST_F(SingleDimTests, Test_IndexReduce_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto r = x.indexReduceNumber(indexreduce::IndexMax).e(0); + auto r = x.indexReduceNumber(indexreduce::IndexMax).e(0); - ASSERT_NEAR(2, r, 1e-5f); + ASSERT_NEAR(2, r, 1e-5f); } - TEST_F(SingleDimTests, Test_ExpandDims_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); - sd::ops::expand_dims op; - auto result = op.evaluate({&x}, {}, {0}); + sd::ops::expand_dims op; + auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(SingleDimTests, Test_ExpandDims_2) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); - sd::ops::expand_dims op; - auto result = op.evaluate({&x}, {}, {1}); + sd::ops::expand_dims op; + auto result = op.evaluate({&x}, {}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } - TEST_F(SingleDimTests, Test_Squeeze_1) { - std::vector vecS({1}); - std::vector vecB({3.0f}); - auto x = NDArrayFactory::create('c', vecS, vecB); - auto exp = NDArrayFactory::create(3.0f); + std::vector vecS({1}); + std::vector vecB({3.0f}); + auto x = NDArrayFactory::create('c', vecS, vecB); + auto exp = NDArrayFactory::create(3.0f); - sd::ops::squeeze op; - auto result = op.evaluate({&x}, {}, {}); + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); - - ASSERT_EQ(exp.rankOf(), z.rankOf()); - ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_EQ(exp.rankOf(), z.rankOf()); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(SingleDimTests, Test_Squeeze_2) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - - sd::ops::squeeze op; - auto result = op.evaluate({&x}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } TEST_F(SingleDimTests, Test_Permute_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); - auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - - sd::ops::permute op; - auto result = op.evaluate({&x}, {}, {0}); - ASSERT_EQ(ND4J_STATUS_OK, result.status()); - - auto z = result.at(0); + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/SortCpuTests.cpp b/libnd4j/tests_cpu/layers_tests/SortCpuTests.cpp index 4dcedf03527f..da9b3f2fad25 100644 --- a/libnd4j/tests_cpu/layers_tests/SortCpuTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/SortCpuTests.cpp @@ -18,87 +18,110 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class SortCpuTests : public testing::Test { -public: - + public: }; - TEST_F(SortCpuTests, test_linear_sort_by_key_1) { - if (!Environment::getInstance()->isCPU()) - return; + if (!Environment::getInstance()->isCPU()) return; - auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + auto k = NDArrayFactory::create('c', {10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + auto ek = NDArrayFactory::create('c', {10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + sortByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), + v.specialShapeInfo(), false); - sortByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); - - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(SortCpuTests, test_linear_sort_by_val_1) { - if (!Environment::getInstance()->isCPU()) - return; - - auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + if (!Environment::getInstance()->isCPU()) return; - auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + auto k = NDArrayFactory::create('c', {10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + auto ek = NDArrayFactory::create('c', {10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - sortByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); + sortByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), false); - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(SortCpuTests, test_tad_sort_by_key_1) { - if (!Environment::getInstance()->isCPU()) - return; - - auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - - int axis = 1; - sortTadByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); - - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + if (!Environment::getInstance()->isCPU()) return; + + auto k = NDArrayFactory::create( + 'c', {2, 10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, + 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create( + 'c', {2, 10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, + 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + int axis = 1; + sortTadByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(SortCpuTests, test_tad_sort_by_val_1) { - if (!Environment::getInstance()->isCPU()) - return; - - auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - - int axis = 1; - sortTadByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); - - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + if (!Environment::getInstance()->isCPU()) return; + + auto k = NDArrayFactory::create( + 'c', {2, 10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, + 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create( + 'c', {2, 10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, + 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + int axis = 1; + sortTadByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } diff --git a/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu b/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu index 5a5e75b1b31e..129f5acede02 100644 --- a/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu +++ b/libnd4j/tests_cpu/layers_tests/SortCudaTests.cu @@ -18,107 +18,148 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include #include -#include #include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class SortCudaTests : public testing::Test { -public: - + public: }; - TEST_F(SortCudaTests, test_linear_sort_by_key_1) { - auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; - - sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); - k.tickWriteDevice(); - v.tickWriteDevice(); - - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + auto k = NDArrayFactory::create('c', {10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2] = {nullptr, + LaunchContext::defaultContext()->getCudaStream()}; + + sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), + v.specialShapeInfo(), false); + k.tickWriteDevice(); + v.tickWriteDevice(); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(SortCudaTests, test_linear_sort_by_val_1) { - auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; - - sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); - k.tickWriteDevice(); - v.tickWriteDevice(); - - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + auto k = NDArrayFactory::create('c', {10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2] = {nullptr, + LaunchContext::defaultContext()->getCudaStream()}; + + sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), false); + k.tickWriteDevice(); + v.tickWriteDevice(); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(SortCudaTests, test_linear_sort_by_val_2) { - auto k = NDArrayFactory::create('c', {6}, {0, 1, 2, 3, 4, 5}); -// auto v = NDArrayFactory::create('c', {6}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - NDArray v = NDArrayFactory::create('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - auto ek = NDArrayFactory::create('c', {6}, {3, 0, 1, 2, 4, 5}); - auto ev = NDArrayFactory::create('c', {6}, {0.95, 0.9, 0.75, 0.6, 0.5, 0.3}); - - Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; - - sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true); - k.tickWriteDevice(); - v.tickWriteDevice(); - // k.printIndexedBuffer("KEYS"); - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + auto k = NDArrayFactory::create('c', {6}, {0, 1, 2, 3, 4, 5}); + // auto v = NDArrayFactory::create('c', {6}, {1.5, 3.5, 5.5, 9.5, + // 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + NDArray v = NDArrayFactory::create('c', {6}, + {0.9f, .75f, .6f, .95f, .5f, .3f}); + auto ek = NDArrayFactory::create('c', {6}, {3, 0, 1, 2, 4, 5}); + auto ev = NDArrayFactory::create('c', {6}, + {0.95, 0.9, 0.75, 0.6, 0.5, 0.3}); + + Nd4jPointer extras[2] = {nullptr, + LaunchContext::defaultContext()->getCudaStream()}; + + sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), true); + k.tickWriteDevice(); + v.tickWriteDevice(); + // k.printIndexedBuffer("KEYS"); + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(SortCudaTests, test_tad_sort_by_key_1) { - auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; - - int axis = 1; - sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); - k.tickWriteDevice(); - v.tickWriteDevice(); - - // k.printIndexedBuffer("k"); - // v.printIndexedBuffer("v"); - - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + auto k = NDArrayFactory::create( + 'c', {2, 10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, + 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create( + 'c', {2, 10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, + 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2] = {nullptr, + LaunchContext::defaultContext()->getCudaStream()}; + + int axis = 1; + sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + k.tickWriteDevice(); + v.tickWriteDevice(); + + // k.printIndexedBuffer("k"); + // v.printIndexedBuffer("v"); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } TEST_F(SortCudaTests, test_tad_sort_by_val_1) { - auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); - auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); - - auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); - - Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; - - int axis = 1; - sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); - k.tickWriteDevice(); - v.tickWriteDevice(); - - ASSERT_EQ(ek, k); - ASSERT_EQ(ev, v); + auto k = NDArrayFactory::create( + 'c', {2, 10}, + {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create( + 'c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, + 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create( + 'c', {2, 10}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create( + 'c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, + 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2] = {nullptr, + LaunchContext::defaultContext()->getCudaStream()}; + + int axis = 1; + sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), + k.specialShapeInfo(), v.buffer(), v.shapeInfo(), + v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + k.tickWriteDevice(); + v.tickWriteDevice(); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); } diff --git a/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp b/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp index 37f52568f42f..505b5557bffd 100644 --- a/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp +++ b/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp @@ -18,130 +18,66 @@ // Created by raver119 on 04.08.17. // -#include "testlayers.h" -#include #include + +#include + #include "ops/specials_sparse.h" +#include "testlayers.h" using namespace sd; ////////////////////////////////////////////////////////////////////// class SparseUtilsTest : public testing::Test { -public: - static const Nd4jLong nnz = 40; - static const int rank = 3; + public: + static const Nd4jLong nnz = 40; + static const int rank = 3; }; - ////////////////////////////////////////////////////////////////////// TEST_F(SparseUtilsTest, SortCOOindices_Test) { - - #ifndef __CUDABLAS__ - - Nd4jLong * indicesArr = new Nd4jLong[nnz * rank]{ - 0,2,7, - 2,36,35, - 3,30,17, - 5,12,22, - 5,43,45, - 6,32,11, - 8,8,32, - 9,29,11, - 5,11,22, - 15,26,16, - 17,48,49, - 24,28,31, - 26,6,23, - 31,21,31, - 35,46,45, - 37,13,14, - 6,38,18, - 7,28,20, - 8,29,39, - 8,32,30, - 9,42,43, - 11,15,18, - 13,18,45, - 29,26,39, - 30,8,25, - 42,31,24, - 28,33,5, - 31,27,1, - 35,43,26, - 36,8,37, - 39,22,14, - 39,24,42, - 42,48,2, - 43,26,48, - 44,23,49, - 45,18,34, - 46,28,5, - 46,32,17, - 48,34,44, - 49,38,39, - }; - - Nd4jLong * expIndicesArr = new Nd4jLong[nnz * rank]{ - 0, 2, 7, - 2, 36, 35, - 3, 30, 17, - 5, 11, 22, - 5, 12, 22, - 5, 43, 45, - 6, 32, 11, - 6, 38, 18, - 7, 28, 20, - 8, 8, 32, - 8, 29, 39, - 8, 32, 30, - 9, 29, 11, - 9, 42, 43, - 11, 15, 18, - 13, 18, 45, - 15, 26, 16, - 17, 48, 49, - 24, 28, 31, - 26, 6, 23, - 28, 33, 5, - 29, 26, 39, - 30, 8, 25, - 31, 21, 31, - 31, 27, 1, - 35, 43, 26, - 35, 46, 45, - 36, 8, 37, - 37, 13, 14, - 39, 22, 14, - 39, 24, 42, - 42, 31, 24, - 42, 48, 2, - 43, 26, 48, - 44, 23, 49, - 45, 18, 34, - 46, 28, 5, - 46, 32, 17, - 48, 34, 44, - 49, 38, 39, - }; - - auto values = NDArrayFactory::create('c', {40}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, - 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39}); - - auto expValues = NDArrayFactory::create('c', {40}, {0, 1, 2, 8, 3, 4, 5, 16, 17, 6, 18, 19, 7, 20, 21, 22, 9, - 10, 11, 12, 26, 23, 24, 13, 27, 28, 14, 29, 15, 30, 31, 25, 32, 33, - 34, 35, 36, 37, 38, 39 - }); - - sd::sparse::SparseUtils::sortCooIndicesGeneric(indicesArr, reinterpret_cast(values.buffer()), nnz, rank); - - for ( int i = 0; i < rank * nnz; ++i){ - ASSERT_EQ(expIndicesArr[i], indicesArr[i]); - } - - ASSERT_TRUE(expValues.equalsTo(values)); - - - delete[] indicesArr; - delete[] expIndicesArr; - - #endif +#ifndef __CUDABLAS__ + + Nd4jLong* indicesArr = new Nd4jLong[nnz * rank]{ + 0, 2, 7, 2, 36, 35, 3, 30, 17, 5, 12, 22, 5, 43, 45, 6, 32, 11, + 8, 8, 32, 9, 29, 11, 5, 11, 22, 15, 26, 16, 17, 48, 49, 24, 28, 31, + 26, 6, 23, 31, 21, 31, 35, 46, 45, 37, 13, 14, 6, 38, 18, 7, 28, 20, + 8, 29, 39, 8, 32, 30, 9, 42, 43, 11, 15, 18, 13, 18, 45, 29, 26, 39, + 30, 8, 25, 42, 31, 24, 28, 33, 5, 31, 27, 1, 35, 43, 26, 36, 8, 37, + 39, 22, 14, 39, 24, 42, 42, 48, 2, 43, 26, 48, 44, 23, 49, 45, 18, 34, + 46, 28, 5, 46, 32, 17, 48, 34, 44, 49, 38, 39, + }; + + Nd4jLong* expIndicesArr = new Nd4jLong[nnz * rank]{ + 0, 2, 7, 2, 36, 35, 3, 30, 17, 5, 11, 22, 5, 12, 22, 5, 43, 45, + 6, 32, 11, 6, 38, 18, 7, 28, 20, 8, 8, 32, 8, 29, 39, 8, 32, 30, + 9, 29, 11, 9, 42, 43, 11, 15, 18, 13, 18, 45, 15, 26, 16, 17, 48, 49, + 24, 28, 31, 26, 6, 23, 28, 33, 5, 29, 26, 39, 30, 8, 25, 31, 21, 31, + 31, 27, 1, 35, 43, 26, 35, 46, 45, 36, 8, 37, 37, 13, 14, 39, 22, 14, + 39, 24, 42, 42, 31, 24, 42, 48, 2, 43, 26, 48, 44, 23, 49, 45, 18, 34, + 46, 28, 5, 46, 32, 17, 48, 34, 44, 49, 38, 39, + }; + + auto values = NDArrayFactory::create( + 'c', {40}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, + 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39}); + + auto expValues = NDArrayFactory::create( + 'c', {40}, {0, 1, 2, 8, 3, 4, 5, 16, 17, 6, 18, 19, 7, 20, + 21, 22, 9, 10, 11, 12, 26, 23, 24, 13, 27, 28, 14, 29, + 15, 30, 31, 25, 32, 33, 34, 35, 36, 37, 38, 39}); + + sd::sparse::SparseUtils::sortCooIndicesGeneric( + indicesArr, reinterpret_cast(values.buffer()), nnz, rank); + + for (int i = 0; i < rank * nnz; ++i) { + ASSERT_EQ(expIndicesArr[i], indicesArr[i]); + } + + ASSERT_TRUE(expValues.equalsTo(values)); + + delete[] indicesArr; + delete[] expIndicesArr; + +#endif } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/StashTests.cpp b/libnd4j/tests_cpu/layers_tests/StashTests.cpp index 2cba6682dd11..5c2b8651016c 100644 --- a/libnd4j/tests_cpu/layers_tests/StashTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/StashTests.cpp @@ -22,67 +22,65 @@ #define LIBND4J_STASHTESTS_H #include -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class StashTests : public testing::Test { -public: - + public: }; TEST_F(StashTests, BasicTests_1) { - Stash stash; + Stash stash; - auto alpha = NDArrayFactory::create_('c',{5, 5}); - alpha->assign(1.0); + auto alpha = NDArrayFactory::create_('c', {5, 5}); + alpha->assign(1.0); - auto beta = NDArrayFactory::create_('c',{5, 5}); - beta->assign(2.0); + auto beta = NDArrayFactory::create_('c', {5, 5}); + beta->assign(2.0); - auto cappa = NDArrayFactory::create_('c',{5, 5}); - cappa->assign(3.0); + auto cappa = NDArrayFactory::create_('c', {5, 5}); + cappa->assign(3.0); - stash.storeArray(1, "alpha", alpha); - stash.storeArray(2, "alpha", beta); - stash.storeArray(3, "cappa", cappa); + stash.storeArray(1, "alpha", alpha); + stash.storeArray(2, "alpha", beta); + stash.storeArray(3, "cappa", cappa); - ASSERT_TRUE(stash.checkStash(1, "alpha")); - ASSERT_TRUE(stash.checkStash(2, "alpha")); - ASSERT_TRUE(stash.checkStash(3, "cappa")); + ASSERT_TRUE(stash.checkStash(1, "alpha")); + ASSERT_TRUE(stash.checkStash(2, "alpha")); + ASSERT_TRUE(stash.checkStash(3, "cappa")); - ASSERT_FALSE(stash.checkStash(3, "alpha")); - ASSERT_FALSE(stash.checkStash(2, "beta")); - ASSERT_FALSE(stash.checkStash(1, "cappa")); + ASSERT_FALSE(stash.checkStash(3, "alpha")); + ASSERT_FALSE(stash.checkStash(2, "beta")); + ASSERT_FALSE(stash.checkStash(1, "cappa")); } - TEST_F(StashTests, BasicTests_2) { - Stash stash; - - auto alpha = NDArrayFactory::create_('c',{5, 5}); - alpha->assign(1.0); + Stash stash; - auto beta = NDArrayFactory::create_('c',{5, 5}); - beta->assign(2.0); + auto alpha = NDArrayFactory::create_('c', {5, 5}); + alpha->assign(1.0); - auto cappa = NDArrayFactory::create_('c',{5, 5}); - cappa->assign(3.0); + auto beta = NDArrayFactory::create_('c', {5, 5}); + beta->assign(2.0); - stash.storeArray(1, "alpha", alpha); - stash.storeArray(1, "beta", beta); - stash.storeArray(1, "cappa", cappa); + auto cappa = NDArrayFactory::create_('c', {5, 5}); + cappa->assign(3.0); - ASSERT_FALSE(stash.checkStash(2, "alpha")); - ASSERT_FALSE(stash.checkStash(2, "beta")); - ASSERT_FALSE(stash.checkStash(2, "cappa")); + stash.storeArray(1, "alpha", alpha); + stash.storeArray(1, "beta", beta); + stash.storeArray(1, "cappa", cappa); - ASSERT_TRUE(alpha == stash.extractArray(1, "alpha")); - ASSERT_TRUE(beta == stash.extractArray(1, "beta")); - ASSERT_TRUE(cappa == stash.extractArray(1, "cappa")); + ASSERT_FALSE(stash.checkStash(2, "alpha")); + ASSERT_FALSE(stash.checkStash(2, "beta")); + ASSERT_FALSE(stash.checkStash(2, "cappa")); + ASSERT_TRUE(alpha == stash.extractArray(1, "alpha")); + ASSERT_TRUE(beta == stash.extractArray(1, "beta")); + ASSERT_TRUE(cappa == stash.extractArray(1, "cappa")); } -#endif //LIBND4J_STASHTESTS_H +#endif // LIBND4J_STASHTESTS_H diff --git a/libnd4j/tests_cpu/layers_tests/StringTests.cpp b/libnd4j/tests_cpu/layers_tests/StringTests.cpp index 272c410c7c4c..01d4fe00084b 100644 --- a/libnd4j/tests_cpu/layers_tests/StringTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/StringTests.cpp @@ -20,846 +20,857 @@ // @author Oleg Semeniv // - #include #include -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; class StringTests : public testing::Test { -public: - + public: }; ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_1) { - std::string f("alpha"); - auto array = NDArrayFactory::string(f); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + std::string f("alpha"); + auto array = NDArrayFactory::string(f); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); + auto z = array.e(0); - ASSERT_EQ(f, z); + ASSERT_EQ(f, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_2) { - std::string f("alpha"); - auto array = NDArrayFactory::string(f.c_str()); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + std::string f("alpha"); + auto array = NDArrayFactory::string(f.c_str()); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); + auto z = array.e(0); - ASSERT_EQ(f, z); + ASSERT_EQ(f, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_3) { + auto array = NDArrayFactory::string( + {3, 2}, {"alpha", "beta", "gamma", "phi", "theta", "omega"}); - auto array = NDArrayFactory::string({3, 2}, {"alpha", "beta", "gamma", "phi", "theta", "omega"}); - - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_4) { + NDArray array({3, 2}, + std::vector{U"alpha", U"beta", U"gamma€한", + U"pÿqwe", U"ß水𝄋", U"omega"}); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - NDArray array( { 3, 2 }, std::vector{ U"alpha", U"beta", U"gamma€한", U"pÿqwe", U"ß水𝄋", U"omega" }); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_5) { + NDArray array({3, 2}, + std::vector{u"alpha", u"beta", u"gamma€한", + u"pÿqwe", u"ß水𝄋", u"omega"}); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - NDArray array( { 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma€한", u"pÿqwe", u"ß水𝄋", u"omega" }); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_6) { + NDArray array({3, 2}, std::vector{"alpha", "beta", "gamma€한", + "pÿqwe", "ß水𝄋", "omega"}); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - NDArray array( { 3, 2 }, std::vector{ "alpha", "beta", "gamma€한", "pÿqwe", "ß水𝄋", "omega" }); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_7) { + NDArray array({3, 2}, + std::vector{U"alpha", U"beta", U"gamma€한", + U"pÿqwe", U"ß水𝄋", U"omega"}); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - NDArray array( { 3, 2 }, std::vector{ U"alpha", U"beta", U"gamma€한", U"pÿqwe", U"ß水𝄋", U"omega" }); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_8) { + NDArray array({3, 2}, + std::vector{u"alpha", u"beta", u"gamma€한", + u"pÿqwe", u"ß水𝄋", u"omega"}); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - NDArray array( { 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma€한", u"pÿqwe", u"ß水𝄋", u"omega" }); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_9) { + NDArray array({3, 2}, std::vector{"alpha", "beta", "gamma€한", + "pÿqwe", "ß水𝄋", "omega"}); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - NDArray array( { 3, 2 }, std::vector{ "alpha", "beta", "gamma€한", "pÿqwe", "ß水𝄋", "omega" }); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_10) { - - NDArray array(std::u32string(U"gamma€한")); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - array.printIndexedBuffer("String array"); + NDArray array(std::u32string(U"gamma€한")); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_11) { + NDArray array(U"gamma€한"); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - NDArray array(U"gamma€한"); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_12) { - - NDArray array(std::u16string(u"gamma€한")); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - array.printIndexedBuffer("String array"); + NDArray array(std::u16string(u"gamma€한")); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_13) { + NDArray array(u"gamma€한"); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - NDArray array(u"gamma€한"); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_14) { - - NDArray array(std::string("gamma€한")); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - array.printIndexedBuffer("String array"); + NDArray array(std::string("gamma€한")); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_15) { + NDArray array("gamma€한"); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - NDArray array("gamma€한"); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_16) { + auto array = NDArrayFactory::string( + {3, 2}, std::vector{"alpha", "beta", "gamma", "phi", "theta", + "omega"}); - auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ "alpha", "beta", "gamma", "phi", "theta", "omega" }); - - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_17) { + auto array = NDArrayFactory::string( + {3, 2}, std::vector{"alpha", "beta", "gamma", "phi", "theta", + "omega"}); - auto array = NDArrayFactory::string({ 3, 2 }, std::vector{ "alpha", "beta", "gamma", "phi", "theta", "omega" }); - - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_18) { + auto array = NDArrayFactory::string( + {3, 2}, std::vector{u"alpha", u"beta", u"gamma", u"phi", + u"theta", u"omega"}); - auto array = NDArrayFactory::string({ 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma", u"phi", u"theta", u"omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_19) { + auto array = NDArrayFactory::string( + {3, 2}, std::vector{u"alpha", u"beta", u"gamma", u"phi", + u"theta", u"omega"}); - auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma", u"phi", u"theta", u"omega" }); - - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_20) { + auto array = NDArrayFactory::string( + {3, 2}, std::vector{U"alpha", U"beta", U"gamma", U"phi", + U"theta", U"omega"}); - auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ U"alpha", U"beta", U"gamma", U"phi", U"theta", U"omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); - - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_21) { + auto array = NDArrayFactory::string( + {3, 2}, std::vector{U"alpha", U"òèçùà12345¤z", + U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї", U"phi", + U"theta", U"omega"}); - auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ U"alpha", U"òèçùà12345¤z", U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї", U"phi", U"theta", U"omega" }); - - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_22) { - std::u16string f(u"ß水𝄋ÿ€한𐍈®кею90ощъ]ї"); - auto array = NDArrayFactory::string(f.c_str()); - ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + std::u16string f(u"ß水𝄋ÿ€한𐍈®кею90ощъ]ї"); + auto array = NDArrayFactory::string(f.c_str()); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); + auto z = array.e(0); - ASSERT_EQ(f, z); + ASSERT_EQ(f, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_23) { - std::u32string f(U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї"); - auto array = NDArrayFactory::string(f.c_str()); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + std::u32string f(U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї"); + auto array = NDArrayFactory::string(f.c_str()); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); + auto z = array.e(0); - ASSERT_EQ(f, z); + ASSERT_EQ(f, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_1) { - auto array = NDArrayFactory::string( {3}, {"alpha", "beta", "gamma"}); - auto vector = array.asByteVector(); + auto array = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"}); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_dup_1) { - std::string f("alpha"); - auto array = NDArrayFactory::string(f); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + std::string f("alpha"); + auto array = NDArrayFactory::string(f); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto dup = new NDArray(array.dup()); + auto dup = new NDArray(array.dup()); - auto z0 = array.e(0); - auto z1 = dup->e(0); + auto z0 = array.e(0); + auto z1 = dup->e(0); - ASSERT_EQ(f, z0); - ASSERT_EQ(f, z1); + ASSERT_EQ(f, z0); + ASSERT_EQ(f, z1); - delete dup; + delete dup; } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, byte_length_test_1) { - std::string f("alpha"); - auto array = NDArrayFactory::string(f); + std::string f("alpha"); + auto array = NDArrayFactory::string(f); - ASSERT_EQ(f.length(), StringUtils::byteLength(array)); + ASSERT_EQ(f.length(), StringUtils::byteLength(array)); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, byte_length_test_2) { - auto array = NDArrayFactory::string( {2}, {"alpha", "beta"}); + auto array = NDArrayFactory::string({2}, {"alpha", "beta"}); - ASSERT_EQ(9, StringUtils::byteLength(array)); + ASSERT_EQ(9, StringUtils::byteLength(array)); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, test_split_1) { - auto split = StringUtils::split("alpha beta gamma", " "); + auto split = StringUtils::split("alpha beta gamma", " "); - ASSERT_EQ(3, split.size()); - ASSERT_EQ(std::string("alpha"), split[0]); - ASSERT_EQ(std::string("beta"), split[1]); - ASSERT_EQ(std::string("gamma"), split[2]); + ASSERT_EQ(3, split.size()); + ASSERT_EQ(std::string("alpha"), split[0]); + ASSERT_EQ(std::string("beta"), split[1]); + ASSERT_EQ(std::string("gamma"), split[2]); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, test_unicode_utf8_utf16) { + std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u16string utf16Exp = + u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::u16string utf16Exp = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - - std::u16string utf16Res; - ASSERT_TRUE(StringUtils::u8StringToU16String(utf8, utf16Res)); + std::u16string utf16Res; + ASSERT_TRUE(StringUtils::u8StringToU16String(utf8, utf16Res)); - ASSERT_EQ(utf16Res.size(), utf16Exp.size()); - for (auto i = 0; i < utf16Exp.size(); i++) { - ASSERT_EQ(utf16Exp[i], utf16Res[i]); - } + ASSERT_EQ(utf16Res.size(), utf16Exp.size()); + for (auto i = 0; i < utf16Exp.size(); i++) { + ASSERT_EQ(utf16Exp[i], utf16Res[i]); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, test_unicode_utf8_utf32) { + std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32Exp = + U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::u32string utf32Exp = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32Res; + ASSERT_TRUE(StringUtils::u8StringToU32String(utf8, utf32Res)); - std::u32string utf32Res; - ASSERT_TRUE(StringUtils::u8StringToU32String(utf8, utf32Res)); - - ASSERT_EQ(utf32Res.size(), utf32Exp.size()); - for (auto i = 0; i < utf32Exp.size(); i++) { - ASSERT_EQ(utf32Exp[i], utf32Res[i]); - } + ASSERT_EQ(utf32Res.size(), utf32Exp.size()); + for (auto i = 0; i < utf32Exp.size(); i++) { + ASSERT_EQ(utf32Exp[i], utf32Res[i]); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, test_unicode_utf16_utf8) { + std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - - std::string utf8Res; - ASSERT_TRUE(StringUtils::u16StringToU8String(utf16, utf8Res)); + std::string utf8Res; + ASSERT_TRUE(StringUtils::u16StringToU8String(utf16, utf8Res)); - ASSERT_EQ(utf8Res.size(), utf8Exp.size()); - for (auto i = 0; i < utf8Exp.size(); i++) { - ASSERT_EQ(utf8Exp[i], utf8Res[i]); - } + ASSERT_EQ(utf8Res.size(), utf8Exp.size()); + for (auto i = 0; i < utf8Exp.size(); i++) { + ASSERT_EQ(utf8Exp[i], utf8Res[i]); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, test_unicode_utf32_utf8) { + std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~"; - std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~"; - std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~"; - - std::string utf8Res; - ASSERT_TRUE(StringUtils::u32StringToU8String(utf32, utf8Res)); + std::string utf8Res; + ASSERT_TRUE(StringUtils::u32StringToU8String(utf32, utf8Res)); - ASSERT_EQ(utf8Res.size(), utf8Exp.size()); - for (auto i = 0; i < utf8Exp.size(); i++) { - ASSERT_EQ(utf8Exp[i], utf8Res[i]); - } + ASSERT_EQ(utf8Res.size(), utf8Exp.size()); + for (auto i = 0; i < utf8Exp.size(); i++) { + ASSERT_EQ(utf8Exp[i], utf8Res[i]); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, test_unicode_utf16_utf32) { + std::u32string utf32Exp = + U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::u32string utf32Exp = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32Res; + ASSERT_TRUE(StringUtils::u16StringToU32String(utf16, utf32Res)); - std::u32string utf32Res; - ASSERT_TRUE(StringUtils::u16StringToU32String(utf16, utf32Res)); - - ASSERT_EQ(utf32Res.size(), utf32Exp.size()); - for (auto i = 0; i < utf32Exp.size(); i++) { - ASSERT_EQ(utf32Exp[i], utf32Res[i]); - } + ASSERT_EQ(utf32Res.size(), utf32Exp.size()); + for (auto i = 0; i < utf32Exp.size(); i++) { + ASSERT_EQ(utf32Exp[i], utf32Res[i]); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, test_unicode_utf32_utf16) { + std::u16string utf16Exp = + u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::u16string utf16Exp = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; - - std::u16string utf16Res; - ASSERT_TRUE(StringUtils::u32StringToU16String(utf32, utf16Res)); + std::u16string utf16Res; + ASSERT_TRUE(StringUtils::u32StringToU16String(utf32, utf16Res)); - ASSERT_EQ(utf16Res.size(), utf16Exp.size()); - for (auto i = 0; i < utf16Exp.size(); i++) { - ASSERT_EQ(utf16Exp[i], utf16Res[i]); - } + ASSERT_EQ(utf16Res.size(), utf16Exp.size()); + for (auto i = 0; i < utf16Exp.size(); i++) { + ASSERT_EQ(utf16Exp[i], utf16Res[i]); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, byte_length_test_Default) { - - std::string f("alpha"); - auto array = NDArrayFactory::string(f); + std::string f("alpha"); + auto array = NDArrayFactory::string(f); + + ASSERT_EQ(f.length(), StringUtils::byteLength(array)); - ASSERT_EQ(f.length(), StringUtils::byteLength(array)); + std::u16string f16(u"alpha"); + auto array16 = NDArrayFactory::string(f16); - std::u16string f16(u"alpha"); - auto array16 = NDArrayFactory::string(f16); - - ASSERT_EQ(sizeof(char16_t)*f16.length(), StringUtils::byteLength(array16)); + ASSERT_EQ(sizeof(char16_t) * f16.length(), StringUtils::byteLength(array16)); - std::u32string f32(U"alpha"); - auto array32 = NDArrayFactory::string(f32); + std::u32string f32(U"alpha"); + auto array32 = NDArrayFactory::string(f32); - ASSERT_EQ(sizeof(char32_t) * f32.length(), StringUtils::byteLength(array32)); + ASSERT_EQ(sizeof(char32_t) * f32.length(), StringUtils::byteLength(array32)); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, byte_length_test_UTF16) { - std::string f(u8"alpha"); - auto array = NDArrayFactory::string(f, sd::DataType::UTF16); + std::string f(u8"alpha"); + auto array = NDArrayFactory::string(f, sd::DataType::UTF16); - ASSERT_EQ(sizeof(char16_t) * f.length(), StringUtils::byteLength(array)); + ASSERT_EQ(sizeof(char16_t) * f.length(), StringUtils::byteLength(array)); - std::u16string f16(u"alpha"); - auto array16 = NDArrayFactory::string(f16, sd::DataType::UTF16); + std::u16string f16(u"alpha"); + auto array16 = NDArrayFactory::string(f16, sd::DataType::UTF16); - ASSERT_EQ(sizeof(char16_t) * f16.length(), StringUtils::byteLength(array16)); + ASSERT_EQ(sizeof(char16_t) * f16.length(), StringUtils::byteLength(array16)); - std::u32string f32(U"alpha"); - auto array32 = NDArrayFactory::string(f32, sd::DataType::UTF16); + std::u32string f32(U"alpha"); + auto array32 = NDArrayFactory::string(f32, sd::DataType::UTF16); - ASSERT_EQ(sizeof(char16_t) * f32.length(), StringUtils::byteLength(array32)); + ASSERT_EQ(sizeof(char16_t) * f32.length(), StringUtils::byteLength(array32)); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_UTF16toU8) { + std::u16string f16(u"alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f16, sd::DataType::UTF8); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - std::u16string f16(u"alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f16, sd::DataType::UTF8); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + auto z = array.e(0); - auto z = array.e(0); - - std::string f(u8"alpha水𝄋ÿ€한𐍈®кею"); - ASSERT_EQ(f, z); + std::string f(u8"alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(f, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_UTF32toU8) { - std::u32string f32(U"alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f32.c_str(), sd::DataType::UTF8); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + std::u32string f32(U"alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f32.c_str(), sd::DataType::UTF8); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); - std::string f(u8"alpha水𝄋ÿ€한𐍈®кею"); - ASSERT_EQ(f, z); + auto z = array.e(0); + std::string f(u8"alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(f, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_UTF16toU16) { + std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f16, sd::DataType::UTF16); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f16, sd::DataType::UTF16); - ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + auto z = array.e(0); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); - - ASSERT_EQ(z, f16); + ASSERT_EQ(z, f16); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_UTF32toU16) { + std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f32, sd::DataType::UTF16); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f32, sd::DataType::UTF16); - ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); - std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); - ASSERT_EQ(z, f16); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + auto z = array.e(0); + std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(z, f16); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_UTF16toU32) { + std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f16, sd::DataType::UTF32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f16, sd::DataType::UTF32); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - - auto z = array.e(0); - std::u32string fres(U"€alpha水𝄋ÿ€한𐍈®кею"); - ASSERT_EQ(z, fres); + auto z = array.e(0); + std::u32string fres(U"€alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(z, fres); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_UTF32toU32) { + std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f32); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - auto z = array.e(0); - ASSERT_EQ(f32, z); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + auto z = array.e(0); + ASSERT_EQ(f32, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_UTF8toU32) { + std::string f(u8"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f, sd::DataType::UTF32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - std::string f(u8"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f, sd::DataType::UTF32); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); - auto z = array.e(0); - ASSERT_EQ(f32, z); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto z = array.e(0); + ASSERT_EQ(f32, z); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU8toUTF16) { - auto array = NDArrayFactory::string({ 3, 2 }, { "alpha€", "beta", "gamma水", "phi", "theta", "omega水" }, sd::DataType::UTF16); + auto array = NDArrayFactory::string( + {3, 2}, {"alpha€", "beta", "gamma水", "phi", "theta", "omega水"}, + sd::DataType::UTF16); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU8toUTF32) { - auto array = NDArrayFactory::string( { 3, 2 }, { "alpha€", "beta水", "gamma", "phi", "theta", "omega" }, sd::DataType::UTF32); + auto array = NDArrayFactory::string( + {3, 2}, {"alpha€", "beta水", "gamma", "phi", "theta", "omega"}, + sd::DataType::UTF32); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U8toUTF16) { - auto array = NDArrayFactory::string({ 3 }, { "alpha", "beta", "gamma" }, sd::DataType::UTF16); + auto array = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"}, + sd::DataType::UTF16); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U8toUTF32) { - auto array = NDArrayFactory::string({ 3 }, { "alpha", "beta", "gamma" }, sd::DataType::UTF32); + auto array = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"}, + sd::DataType::UTF32); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU16toUTF16) { - auto array = NDArrayFactory::string({ 3, 2 }, { u"alpha水", u"beta", u"gamma", u"phi", u"theta水", u"omega" }, sd::DataType::UTF16); + auto array = NDArrayFactory::string( + {3, 2}, {u"alpha水", u"beta", u"gamma", u"phi", u"theta水", u"omega"}, + sd::DataType::UTF16); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU16toUTF32) { - auto array = NDArrayFactory::string( { 3, 2 }, { u"alpha水", u"beta", u"gamma水", u"phi", u"theta", u"omega" }, sd::DataType::UTF32); + auto array = NDArrayFactory::string( + {3, 2}, {u"alpha水", u"beta", u"gamma水", u"phi", u"theta", u"omega"}, + sd::DataType::UTF32); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU16toUTF8) { - auto array = NDArrayFactory::string( { 3, 2 }, { u"alpha€", u"beta水", u"gamma", u"phi水", u"theta", u"omega" }, sd::DataType::UTF8); + auto array = NDArrayFactory::string( + {3, 2}, {u"alpha€", u"beta水", u"gamma", u"phi水", u"theta", u"omega"}, + sd::DataType::UTF8); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U16toUTF8) { - auto array = NDArrayFactory::string( { 3 }, { u"alpha", u"beta", u"gamma" }, sd::DataType::UTF8); + auto array = NDArrayFactory::string({3}, {u"alpha", u"beta", u"gamma"}, + sd::DataType::UTF8); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U16toUTF16) { - auto array = NDArrayFactory::string( { 3 }, { u"alpha", u"beta", u"gamma" }, sd::DataType::UTF16); + auto array = NDArrayFactory::string({3}, {u"alpha", u"beta", u"gamma"}, + sd::DataType::UTF16); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U16toUTF32) { - auto array = NDArrayFactory::string( { 3 }, { u"alpha水", u"beta", u"gamma水" }, sd::DataType::UTF32); + auto array = NDArrayFactory::string({3}, {u"alpha水", u"beta", u"gamma水"}, + sd::DataType::UTF32); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU32toUTF32) { - auto array = NDArrayFactory::string( { 3, 2 }, { U"alpha€", U"beta水", U"gamma", U"phi", U"theta", U"omega水" }, sd::DataType::UTF32); + auto array = NDArrayFactory::string( + {3, 2}, {U"alpha€", U"beta水", U"gamma", U"phi", U"theta", U"omega水"}, + sd::DataType::UTF32); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU32toUTF16) { - auto array = NDArrayFactory::string({ 3, 2 }, { U"alpha水", U"水beta", U"gamma", U"phi水", U"theta", U"omega" }, sd::DataType::UTF16); + auto array = NDArrayFactory::string( + {3, 2}, {U"alpha水", U"水beta", U"gamma", U"phi水", U"theta", U"omega"}, + sd::DataType::UTF16); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); - printf("Array elements size: \n"); - for (int e = 0; e < array.lengthOf(); e++) { - printf("Element %d size: %d\n", e, static_cast(array.e(e).size())); - } + printf("Array elements size: \n"); + for (int e = 0; e < array.lengthOf(); e++) { + printf("Element %d size: %d\n", e, + static_cast(array.e(e).size())); + } } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_Test_StringVecU32toUTF8) { - auto array = NDArrayFactory::string( { 3, 2 }, { U"alpha水", U"beta", U"gamma水", U"phi", U"theta", U"omega" }, sd::DataType::UTF8); + auto array = NDArrayFactory::string( + {3, 2}, {U"alpha水", U"beta", U"gamma水", U"phi", U"theta", U"omega"}, + sd::DataType::UTF8); - ASSERT_EQ(6, array.lengthOf()); - ASSERT_EQ(2, array.rankOf()); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); - array.printIndexedBuffer("String array"); + array.printIndexedBuffer("String array"); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U32toUTF32) { - auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta", U"gamma" }, sd::DataType::UTF32); + auto array = NDArrayFactory::string({3}, {U"alpha", U"beta", U"gamma"}, + sd::DataType::UTF32); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U32toUTF16) { - auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta水", U"gamma水" }, sd::DataType::UTF16); + auto array = NDArrayFactory::string({3}, {U"alpha", U"beta水", U"gamma水"}, + sd::DataType::UTF16); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Export_Test_U32toUTF8) { - auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta", U"gamma水" }, sd::DataType::UTF8); + auto array = NDArrayFactory::string({3}, {U"alpha", U"beta", U"gamma水"}, + sd::DataType::UTF8); - auto vector = array.asByteVector(); + auto vector = array.asByteVector(); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_dup_UTF16) { - std::u16string f(u"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f); - ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + std::u16string f(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto dup = new NDArray(array.dup()); + auto dup = new NDArray(array.dup()); - auto z0 = array.e(0); - auto z1 = dup->e(0); + auto z0 = array.e(0); + auto z1 = dup->e(0); - ASSERT_EQ(f, z0); - ASSERT_EQ(f, z1); + ASSERT_EQ(f, z0); + ASSERT_EQ(f, z1); - delete dup; + delete dup; } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_dup_UTF32) { - std::u32string f(U"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(f); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + std::u32string f(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto dup = new NDArray(array.dup()); + auto dup = new NDArray(array.dup()); - auto z0 = array.e(0); - auto z1 = dup->e(0); + auto z0 = array.e(0); + auto z1 = dup->e(0); - ASSERT_EQ(f, z0); - ASSERT_EQ(f, z1); + ASSERT_EQ(f, z0); + ASSERT_EQ(f, z1); - delete dup; + delete dup; } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF32toUTF8) { - - std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); - - std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); - - auto array = NDArrayFactory::string(u32); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); + + std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); + + auto array = NDArrayFactory::string(u32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto aCast = array.cast(sd::DataType::UTF8); + auto aCast = array.cast(sd::DataType::UTF8); - auto z0 = array.e(0); - auto z1 = aCast.e(0); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - ASSERT_EQ(u32, z0); - ASSERT_EQ(u8, z1); + ASSERT_EQ(u32, z0); + ASSERT_EQ(u8, z1); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF32toUTF16) { + std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); - std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); + std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); - std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(u32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - auto array = NDArrayFactory::string(u32); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); - - auto aCast = array.cast(sd::DataType::UTF16); + auto aCast = array.cast(sd::DataType::UTF16); - auto z0 = array.e(0); - auto z1 = aCast.e(0); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - ASSERT_EQ(u32, z0); - ASSERT_EQ(u16, z1); + ASSERT_EQ(u32, z0); + ASSERT_EQ(u16, z1); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF32toUTF32) { + std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); - std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(u32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); - auto array = NDArrayFactory::string(u32); - ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + auto aCast = array.cast(sd::DataType::UTF32); - auto aCast = array.cast(sd::DataType::UTF32); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - auto z0 = array.e(0); - auto z1 = aCast.e(0); - - ASSERT_EQ(u32, z0); - ASSERT_EQ(u32, z1); + ASSERT_EQ(u32, z0); + ASSERT_EQ(u32, z1); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF16toUTF16) { + std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); - std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); - - auto array = NDArrayFactory::string(u16); - ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + auto array = NDArrayFactory::string(u16); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto aCast = array.cast(sd::DataType::UTF16); + auto aCast = array.cast(sd::DataType::UTF16); - auto z0 = array.e(0); - auto z1 = aCast.e(0); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - ASSERT_EQ(u16, z0); - ASSERT_EQ(u16, z1); + ASSERT_EQ(u16, z0); + ASSERT_EQ(u16, z1); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF16toUTF32) { + std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); - std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); + std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); - std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(u16); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - auto array = NDArrayFactory::string(u16); - ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + auto aCast = array.cast(sd::DataType::UTF32); - auto aCast = array.cast(sd::DataType::UTF32); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - auto z0 = array.e(0); - auto z1 = aCast.e(0); - - ASSERT_EQ(u32, z1); - ASSERT_EQ(u16, z0); + ASSERT_EQ(u32, z1); + ASSERT_EQ(u16, z0); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF16toUTF8) { + std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); - std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); - - std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); + std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(u16); - ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + auto array = NDArrayFactory::string(u16); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto aCast = array.cast(sd::DataType::UTF8); + auto aCast = array.cast(sd::DataType::UTF8); - auto z0 = array.e(0); - auto z1 = aCast.e(0); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - ASSERT_EQ(u8, z1); - ASSERT_EQ(u16, z0); + ASSERT_EQ(u8, z1); + ASSERT_EQ(u16, z0); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF8toUTF8) { + std::string u8("€alpha水𝄋ÿ€한𐍈®кею"); - std::string u8("€alpha水𝄋ÿ€한𐍈®кею"); - - auto array = NDArrayFactory::string(u8); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + auto array = NDArrayFactory::string(u8); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto aCast = array.cast(sd::DataType::UTF8); + auto aCast = array.cast(sd::DataType::UTF8); - auto z0 = array.e(0); - auto z1 = aCast.e(0); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - ASSERT_EQ(u8, z1); - ASSERT_EQ(u8, z0); + ASSERT_EQ(u8, z1); + ASSERT_EQ(u8, z0); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF8toUTF16) { + std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); - std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); + std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); - std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(u8); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - auto array = NDArrayFactory::string(u8); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + auto aCast = array.cast(sd::DataType::UTF16); - auto aCast = array.cast(sd::DataType::UTF16); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - auto z0 = array.e(0); - auto z1 = aCast.e(0); - - ASSERT_EQ(u8, z0); - ASSERT_EQ(u16, z1); + ASSERT_EQ(u8, z0); + ASSERT_EQ(u16, z1); } ///////////////////////////////////////////////////////////////////////// TEST_F(StringTests, Basic_cast_UTF8toUTF32) { + std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); - std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); - - std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); + std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); - auto array = NDArrayFactory::string(u8); - ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + auto array = NDArrayFactory::string(u8); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); - ASSERT_EQ(1, array.lengthOf()); - ASSERT_EQ(0, array.rankOf()); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); - auto aCast = array.cast(sd::DataType::UTF32); + auto aCast = array.cast(sd::DataType::UTF32); - auto z0 = array.e(0); - auto z1 = aCast.e(0); + auto z0 = array.e(0); + auto z1 = aCast.e(0); - ASSERT_EQ(u8, z0); - ASSERT_EQ(u32, z1); + ASSERT_EQ(u8, z0); + ASSERT_EQ(u32, z1); } diff --git a/libnd4j/tests_cpu/layers_tests/TadTests.cpp b/libnd4j/tests_cpu/layers_tests/TadTests.cpp index a2cdec003c6b..5a5ecc236f69 100644 --- a/libnd4j/tests_cpu/layers_tests/TadTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/TadTests.cpp @@ -21,423 +21,433 @@ #ifndef LIBND4J_TADTESTS_H #define LIBND4J_TADTESTS_H -#include "testlayers.h" #include +#include #include + #include -#include + +#include "testlayers.h" using namespace sd; class TadTests : public testing::Test { -public: - int numLoops = 100000000; + public: + int numLoops = 100000000; - int extLoops = 1000; - int intLoops = 1000; + int extLoops = 1000; + int intLoops = 1000; }; TEST_F(TadTests, Test4DTad1) { + NDArray* arraySource = sd::NDArrayFactory::linspace(1.0f, 10000.0f, 10000); - NDArray* arraySource = sd::NDArrayFactory::linspace(1.0f, 10000.0f, 10000); - - Nd4jLong badShape[] = {4, 2, 1, 4, 4, 80, 16, 4, 1, 8192, -1, 99}; - Nd4jLong goodShape[] = {4, 2, 1, 4, 4, 16, 16, 4, 1, 8192, 1, 99}; + Nd4jLong badShape[] = {4, 2, 1, 4, 4, 80, 16, 4, 1, 8192, -1, 99}; + Nd4jLong goodShape[] = {4, 2, 1, 4, 4, 16, 16, 4, 1, 8192, 1, 99}; - std::vector buff = arraySource->getBufferAsVector(); + std::vector buff = arraySource->getBufferAsVector(); - NDArray* arrayExp = new NDArray(buff.data(), goodShape); - NDArray* arrayBad = new NDArray(buff.data(), badShape); + NDArray* arrayExp = new NDArray(buff.data(), goodShape); + NDArray* arrayBad = new NDArray(buff.data(), badShape); - int dim = 1; - shape::TAD tad; - tad.init(arrayBad->shapeInfo(), &dim, 1); - tad.createTadOnlyShapeInfo(); - tad.createOffsets(); + int dim = 1; + shape::TAD tad; + tad.init(arrayBad->shapeInfo(), &dim, 1); + tad.createTadOnlyShapeInfo(); + tad.createOffsets(); - int exp[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95 }; - for (int e = 0; e < 32; e++) - ASSERT_EQ((int) tad.tadOffsets[e], exp[e]); + int exp[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95}; + for (int e = 0; e < 32; e++) ASSERT_EQ((int)tad.tadOffsets[e], exp[e]); - delete arrayExp; - delete arrayBad; - delete arraySource; + delete arrayExp; + delete arrayBad; + delete arraySource; } TEST_F(TadTests, TestNumTads1) { - auto x = NDArrayFactory::create('c', {2, 3}); - auto y = NDArrayFactory::create('c', {2, 2}); + auto x = NDArrayFactory::create('c', {2, 3}); + auto y = NDArrayFactory::create('c', {2, 2}); - std::vector dim({0}); + std::vector dim({0}); - Nd4jLong tadLengthX = shape::tadLength(x.shapeInfo(), dim.data(), dim.size()); - Nd4jLong numTadsX = x.lengthOf() / tadLengthX; + Nd4jLong tadLengthX = shape::tadLength(x.shapeInfo(), dim.data(), dim.size()); + Nd4jLong numTadsX = x.lengthOf() / tadLengthX; - Nd4jLong tadLengthY = shape::tadLength(y.shapeInfo(), dim.data(), dim.size()); - Nd4jLong numTadsY = y.lengthOf() / tadLengthY; + Nd4jLong tadLengthY = shape::tadLength(y.shapeInfo(), dim.data(), dim.size()); + Nd4jLong numTadsY = y.lengthOf() / tadLengthY; - ASSERT_EQ(2, tadLengthX); - ASSERT_EQ(3, numTadsX); + ASSERT_EQ(2, tadLengthX); + ASSERT_EQ(3, numTadsX); - ASSERT_EQ(2, tadLengthY); - ASSERT_EQ(2, numTadsY); + ASSERT_EQ(2, tadLengthY); + ASSERT_EQ(2, numTadsY); } TEST_F(TadTests, TestShapeTad_1) { + float buff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 8192, 1, 99}; - float buff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,16,16,17,18,19,20,21,22,23,24}; - Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 8192, 1, 99}; + NDArray input(buff, shapeInfo); - NDArray input(buff, shapeInfo); + std::vector dimensions = {0, 1, 2}; + Nd4jLong tadLength = + shape::tadLength(input.shapeInfo(), dimensions.data(), dimensions.size()); + Nd4jLong numTads = input.lengthOf() / tadLength; - std::vector dimensions = {0,1,2}; - Nd4jLong tadLength = shape::tadLength(input.shapeInfo(), dimensions.data(), dimensions.size()); - Nd4jLong numTads = input.lengthOf() / tadLength; + shape::TAD tad; + tad.init(input.shapeInfo(), dimensions.data(), dimensions.size()); + tad.createTadOnlyShapeInfo(); + tad.createOffsets(); - shape::TAD tad; - tad.init(input.shapeInfo(), dimensions.data(), dimensions.size()); - tad.createTadOnlyShapeInfo(); - tad.createOffsets(); + auto tadShapeInfo = + new Nd4jLong[shape::shapeInfoLength(tad.tadOnlyShapeInfo[0])]; + std::memcpy(tadShapeInfo, tad.tadOnlyShapeInfo, + shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); - auto tadShapeInfo = new Nd4jLong[shape::shapeInfoLength(tad.tadOnlyShapeInfo[0])]; - std::memcpy(tadShapeInfo, tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + float* tadBuff = reinterpret_cast(input.buffer()) + tad.tadOffsets[0]; + NDArray tadArr(tadBuff, tadShapeInfo); - float* tadBuff = reinterpret_cast(input.buffer()) + tad.tadOffsets[0]; - NDArray tadArr(tadBuff, tadShapeInfo); + ASSERT_TRUE(numTads == 1); + ASSERT_TRUE(input.isSameShapeStrict(tadArr)); + ASSERT_TRUE(input.equalsTo(&tadArr)); - ASSERT_TRUE(numTads==1); - ASSERT_TRUE(input.isSameShapeStrict(tadArr)); - ASSERT_TRUE(input.equalsTo(&tadArr)); - - delete[] tadShapeInfo; + delete[] tadShapeInfo; } TEST_F(TadTests, TadNoAxis_1) { - auto array = NDArrayFactory::create('c', {2, 3}); + auto array = NDArrayFactory::create('c', {2, 3}); - shape::TAD tad; - tad.init(array.shapeInfo(), nullptr, 0); - tad.createTadOnlyShapeInfo(); - tad.createOffsets(); + shape::TAD tad; + tad.init(array.shapeInfo(), nullptr, 0); + tad.createTadOnlyShapeInfo(); + tad.createOffsets(); - ASSERT_TRUE(tad.wholeThing); + ASSERT_TRUE(tad.wholeThing); - ASSERT_TRUE(shape::equalsStrict(tad.tadOnlyShapeInfo, array.shapeInfo())); + ASSERT_TRUE(shape::equalsStrict(tad.tadOnlyShapeInfo, array.shapeInfo())); } TEST_F(TadTests, TadEdgeCase_1) { - auto array = NDArrayFactory::create('c', {5, 4, 1}); - auto exp = NDArrayFactory::create('c', {5, 4}); - array.linspace(1); + auto array = NDArrayFactory::create('c', {5, 4, 1}); + auto exp = NDArrayFactory::create('c', {5, 4}); + array.linspace(1); - auto tad = array(0, {2}); + auto tad = array(0, {2}); - ASSERT_TRUE(exp.isSameShape(tad)); + ASSERT_TRUE(exp.isSameShape(tad)); } TEST_F(TadTests, TestEdgeCase_2) { + auto array = + NDArrayFactory::create('f', {2, 3, 1}, {1, 4, 2, 5, 3, 6}); - auto array = NDArrayFactory::create('f', {2, 3, 1}, {1, 4, 2, 5, 3, 6}); - - for (int e = 0 ; e < array.lengthOf(); e++) { - auto tad = array(e, {0,1}); - ASSERT_NEAR(tad.e(0), array.e(e), 1e-5); - } + for (int e = 0; e < array.lengthOf(); e++) { + auto tad = array(e, {0, 1}); + ASSERT_NEAR(tad.e(0), array.e(e), 1e-5); + } } TEST_F(TadTests, TadEdgeCase_2) { - auto array = NDArrayFactory::create('c', {2, 3, 4}); + auto array = NDArrayFactory::create('c', {2, 3, 4}); - auto tad = array(0, {0,2}); + auto tad = array(0, {0, 2}); - ASSERT_EQ(3, tad.lengthOf()); + ASSERT_EQ(3, tad.lengthOf()); } - TEST_F(TadTests, test_Tad_Ews_optimization_1) { - shape::TAD xTad; + shape::TAD xTad; - std::array array = {1,2}; - ASSERT_TRUE(xTad.dimensionsDescending(3, array.data(), array.size())); + std::array array = {1, 2}; + ASSERT_TRUE(xTad.dimensionsDescending(3, array.data(), array.size())); } TEST_F(TadTests, test_Tad_Ews_optimization_2) { - shape::TAD xTad; + shape::TAD xTad; - std::array array = {0,2}; - ASSERT_FALSE(xTad.dimensionsDescending(3, array.data(), array.size())); + std::array array = {0, 2}; + ASSERT_FALSE(xTad.dimensionsDescending(3, array.data(), array.size())); } TEST_F(TadTests, test_Tad_Ews_optimization_3) { - shape::TAD xTad; + shape::TAD xTad; - std::array array = {1}; - ASSERT_TRUE(xTad.dimensionsDescending(2, array.data(), array.size())); + std::array array = {1}; + ASSERT_TRUE(xTad.dimensionsDescending(2, array.data(), array.size())); } TEST_F(TadTests, test_Tad_Ews_optimization_4) { - shape::TAD xTad; + shape::TAD xTad; - std::array array = {0}; - ASSERT_TRUE(xTad.dimensionsDescending(1, array.data(), array.size())); + std::array array = {0}; + ASSERT_TRUE(xTad.dimensionsDescending(1, array.data(), array.size())); } TEST_F(TadTests, test_Tad_Ews_optimization_5) { - shape::TAD xTad; + shape::TAD xTad; - std::array array = {2,3}; - ASSERT_TRUE(xTad.dimensionsDescending(4, array.data(), array.size())); + std::array array = {2, 3}; + ASSERT_TRUE(xTad.dimensionsDescending(4, array.data(), array.size())); } TEST_F(TadTests, test_TAD_empty_dims_1) { - Nd4jLong xShape[8] = {2, 150, 1, 3, 1, 16384, 3, 99}; - shape::TAD xTad; - xTad.init(xShape, reinterpret_cast(112L), 0); - xTad.createTadOnlyShapeInfo(); - xTad.createOffsets(); + Nd4jLong xShape[8] = {2, 150, 1, 3, 1, 16384, 3, 99}; + shape::TAD xTad; + xTad.init(xShape, reinterpret_cast(112L), 0); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); } TEST_F(TadTests, test_tad_order_1) { - Nd4jLong xShape[8] = {2, 150, 10, 10, 1, 8192, 1, 99}; - Nd4jLong tShape[8] = {2, 1, 10, 1, 1, 8192, 1, 99}; - shape::TAD xTad; - int dim = 1; - xTad.init(xShape, &dim, 1); - xTad.createTadOnlyShapeInfo(); - - ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); + Nd4jLong xShape[8] = {2, 150, 10, 10, 1, 8192, 1, 99}; + Nd4jLong tShape[8] = {2, 1, 10, 1, 1, 8192, 1, 99}; + shape::TAD xTad; + int dim = 1; + xTad.init(xShape, &dim, 1); + xTad.createTadOnlyShapeInfo(); + + ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); } TEST_F(TadTests, test_tad_order_2) { - Nd4jLong xShape[8] = {2, 150, 10, 10, 1, 8192, 1, 99}; - Nd4jLong tShape[8] = {2, 1, 150, 1, 10, 8192, 10, 99}; - shape::TAD xTad; - int dim = 0; - xTad.init(xShape, &dim, 1); - xTad.createTadOnlyShapeInfo(); - - ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); + Nd4jLong xShape[8] = {2, 150, 10, 10, 1, 8192, 1, 99}; + Nd4jLong tShape[8] = {2, 1, 150, 1, 10, 8192, 10, 99}; + shape::TAD xTad; + int dim = 0; + xTad.init(xShape, &dim, 1); + xTad.createTadOnlyShapeInfo(); + + ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); } - TEST_F(TadTests, test_tad_order_3) { - Nd4jLong xShape[10] = {3, 10, 20, 30, 600 ,30, 1, 8192, 1, 99}; - Nd4jLong tShape[8] = {2, 1, 30, 1, 1, 8192, 1, 99}; - shape::TAD xTad; - int dim = 2; - xTad.init(xShape, &dim, 1); - xTad.createTadOnlyShapeInfo(); - - ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); + Nd4jLong xShape[10] = {3, 10, 20, 30, 600, 30, 1, 8192, 1, 99}; + Nd4jLong tShape[8] = {2, 1, 30, 1, 1, 8192, 1, 99}; + shape::TAD xTad; + int dim = 2; + xTad.init(xShape, &dim, 1); + xTad.createTadOnlyShapeInfo(); + + ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); } - TEST_F(TadTests, test_tad_order_4) { - Nd4jLong xShape[10] = {3, 10, 20, 30, 600 ,30, 1, 8192, 1, 99}; - Nd4jLong tShape[8] = {2, 20, 30, 30, 1, 8192, 1, 99}; - shape::TAD xTad; - int dim[2] = {1, 2}; - xTad.init(xShape, dim, 2); - xTad.createTadOnlyShapeInfo(); - - ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); + Nd4jLong xShape[10] = {3, 10, 20, 30, 600, 30, 1, 8192, 1, 99}; + Nd4jLong tShape[8] = {2, 20, 30, 30, 1, 8192, 1, 99}; + shape::TAD xTad; + int dim[2] = {1, 2}; + xTad.init(xShape, dim, 2); + xTad.createTadOnlyShapeInfo(); + + ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); } TEST_F(TadTests, test_column_1) { - auto x = NDArrayFactory::create('c', {5, 2}); - auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), 0); + auto x = NDArrayFactory::create('c', {5, 2}); + auto tadPack = + sd::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), 0); - ASSERT_EQ(1, shape::rank(tadPack.primaryShapeInfo())); - ASSERT_EQ(5, shape::length(tadPack.primaryShapeInfo())); - ASSERT_TRUE(shape::isVector(tadPack.primaryShapeInfo())); + ASSERT_EQ(1, shape::rank(tadPack.primaryShapeInfo())); + ASSERT_EQ(5, shape::length(tadPack.primaryShapeInfo())); + ASSERT_TRUE(shape::isVector(tadPack.primaryShapeInfo())); - auto scalarViewPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(tadPack.primaryShapeInfo(), 0); + auto scalarViewPack = sd::ConstantTadHelper::getInstance()->tadForDimensions( + tadPack.primaryShapeInfo(), 0); - ASSERT_TRUE(shape::equalsStrict(tadPack.primaryShapeInfo(), scalarViewPack.primaryShapeInfo())); + ASSERT_TRUE(shape::equalsStrict(tadPack.primaryShapeInfo(), + scalarViewPack.primaryShapeInfo())); } /////////////////////////////////////////////////////////////////// TEST_F(TadTests, calcOffsets_1) { + Nd4jLong shapeInfoF[10] = {3, 2, 3, 4, 1, 2, 6, 8192, 1, 102}; + Nd4jLong shapeInfoC[10] = {3, 2, 3, 4, 12, 4, 1, 8192, 1, 99}; + Nd4jLong shapeInfoFC[10] = {3, 2, 3, 4, 1, 2, 6, 8192, 1, 99}; + ; - Nd4jLong shapeInfoF[10] = {3, 2,3,4, 1,2,6, 8192, 1, 102}; - Nd4jLong shapeInfoC[10] = {3, 2,3,4, 12,4,1, 8192, 1, 99}; - Nd4jLong shapeInfoFC[10] = {3, 2,3,4, 1,2,6, 8192, 1, 99};; - - Nd4jLong expOffsetsF[24] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23}; - Nd4jLong expOffsetsC[24] = {0,12,4,16,8,20,1,13,5,17,9,21,2,14,6,18,10,22,3,15,7,19,11,23}; + Nd4jLong expOffsetsF[24] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}; + Nd4jLong expOffsetsC[24] = {0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, + 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23}; - Nd4jLong offsets[24]; + Nd4jLong offsets[24]; - shape::calcOffsets(shapeInfoF, offsets, 'f'); + shape::calcOffsets(shapeInfoF, offsets, 'f'); - for (int e = 0; e < 24; e++) - ASSERT_TRUE(offsets[e] == expOffsetsF[e]); + for (int e = 0; e < 24; e++) ASSERT_TRUE(offsets[e] == expOffsetsF[e]); - shape::calcOffsets(shapeInfoC, offsets, 'f'); + shape::calcOffsets(shapeInfoC, offsets, 'f'); - for (int e = 0; e < 24; e++) - ASSERT_TRUE(offsets[e] == expOffsetsC[e]); + for (int e = 0; e < 24; e++) ASSERT_TRUE(offsets[e] == expOffsetsC[e]); - shape::calcOffsets(shapeInfoFC, offsets, 'f'); + shape::calcOffsets(shapeInfoFC, offsets, 'f'); - for (int e = 0; e < 24; e++) - ASSERT_TRUE(offsets[e] == expOffsetsF[e]); + for (int e = 0; e < 24; e++) ASSERT_TRUE(offsets[e] == expOffsetsF[e]); } - ///////////////////////////////////////////////////////////////// TEST_F(TadTests, outerArrayIndexes_1) { - - NDArray x('c', {2,3,4,5}, sd::DataType::FLOAT32); - int maxIdxs[120]; - - NDArray y1('c', {3,5}, sd::DataType::FLOAT32); - const std::vector dimsToExclude1 = {0,2}; - const int n1[] = {20,25,30,35, 80,85,90,95}; - int minIdx = 5; - - int N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y1.shapeInfo(), dimsToExclude1.data()); - ASSERT_TRUE(N == x.lengthOf()/y1.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n1[i] == maxIdxs[i]); - - NDArray y2('c', {4,5}, sd::DataType::FLOAT32); - const std::vector dimsToExclude2 = {0,1}; - const int n2[] = {12,32,52, 72,92,112}; - minIdx = 12; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y2.shapeInfo(), dimsToExclude2.data()); - ASSERT_TRUE(N == x.lengthOf()/y2.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n2[i] == maxIdxs[i]); - - NDArray y3('c', {2,5}, sd::DataType::FLOAT32); - const std::vector dimsToExclude3 = {1,2}; - const int n3[] = {64,69,74,79,84,89,94,99,104,109,114,119}; - minIdx = 9; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y3.shapeInfo(), dimsToExclude3.data()); - ASSERT_TRUE(N == x.lengthOf()/y3.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n3[i] == maxIdxs[i]); - - NDArray y4('c', {2,3}, sd::DataType::FLOAT32); - const std::vector dimsToExclude4 = {2,3}; - const int n4[] = {20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39}; - minIdx = 1; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y4.shapeInfo(), dimsToExclude4.data()); - ASSERT_TRUE(N == x.lengthOf()/y4.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n4[i] == maxIdxs[i]); - - NDArray y5('c', {2,4}, sd::DataType::FLOAT32); - const std::vector dimsToExclude5 = {1,3}; - const int n5[] = {65,66,67,68,69, 85,86,87,88,89, 105,106,107,108,109}; - minIdx = 5; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y5.shapeInfo(), dimsToExclude5.data()); - ASSERT_TRUE(N == x.lengthOf()/y5.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n5[i] == maxIdxs[i]); - - NDArray y6('c', {2,3,4}, sd::DataType::FLOAT32); - const std::vector dimsToExclude6 = {3}; - const int n6[] = {65,66,67,68,69}; - minIdx = 13; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y6.shapeInfo(), dimsToExclude6.data()); - ASSERT_TRUE(N == x.lengthOf()/y6.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n6[i] == maxIdxs[i]); - - NDArray y7('c', {4}, sd::DataType::FLOAT32); - const std::vector dimsToExclude7 = {0,1,3}; - const int n7[] = {15,16,17,18,19, 35,36,37,38,39, 55,56,57,58,59, 75,76,77,78,79, 95,96,97,98,99, 115,116,117,118,119}; - minIdx = 3; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y7.shapeInfo(), dimsToExclude7.data()); - ASSERT_TRUE(N == x.lengthOf()/y7.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n7[i] == maxIdxs[i]); - - NDArray y8('c', {5}, sd::DataType::FLOAT32); - const std::vector dimsToExclude8 = {0,1,2}; - const int n8[] = {0,5,10,15, 20,25,30,35, 40,45,50,55, 60,65,70,75, 80,85,90,95, 100,105,110,115}; - minIdx = 0; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y8.shapeInfo(), dimsToExclude8.data()); - ASSERT_TRUE(N == x.lengthOf()/y8.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n8[i] == maxIdxs[i]); - - NDArray y9('c', {2}, sd::DataType::FLOAT32); - const std::vector dimsToExclude9 = {1,2,3}; - const int n9[] = {60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119}; - minIdx = 1; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y9.shapeInfo(), dimsToExclude9.data()); - ASSERT_TRUE(N == x.lengthOf()/y9.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n9[i] == maxIdxs[i]); - - NDArray y10('c', {3,4,5}, sd::DataType::FLOAT32); - const std::vector dimsToExclude10 = {0}; - const int n10[] = {11, 71}; - minIdx = 11; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y10.shapeInfo(), dimsToExclude10.data()); - ASSERT_TRUE(N == x.lengthOf()/y10.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n10[i] == maxIdxs[i]); - - NDArray y11('c', {2,4,5}, sd::DataType::FLOAT32); - const std::vector dimsToExclude11 = {1}; - const int n11[] = {66, 86, 106}; - minIdx = 26; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y11.shapeInfo(), dimsToExclude11.data()); - ASSERT_TRUE(N == x.lengthOf()/y11.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n11[i] == maxIdxs[i]); - - NDArray y12('c', {3,2}, sd::DataType::FLOAT32); - const std::vector dimsToExclude12 = {0,2}; - const int n12[] = {0,2,4,5,7,9,10,12,14,15,17,19,60,62,64,65,67,69,70,72,74,75,77,79}; - minIdx = 0; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y12.shapeInfo(), dimsToExclude12.data()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n12[i] == maxIdxs[i]); - - NDArray y13('c', {3,2}, sd::DataType::FLOAT32); - const std::vector dimsToExclude13 = {0,2}; - const int n13[] = {1,3,6,8,11,13,16,18,61,63,66,68,71,73,76,78}; - minIdx = 1; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y13.shapeInfo(), dimsToExclude13.data()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n13[i] == maxIdxs[i]); - - NDArray y14('c', {4,5}, sd::DataType::FLOAT32); - const int n14[] = {12,32,52, 72,92,112}; - minIdx = 12; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y14.shapeInfo(), nullptr); - ASSERT_TRUE(N == x.lengthOf()/y14.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n14[i] == maxIdxs[i]); - - NDArray y15('c', {3,4,5}, sd::DataType::FLOAT32); - const int n15[] = {11, 71}; - minIdx = 11; - - N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y15.shapeInfo(), nullptr); - ASSERT_TRUE(N == x.lengthOf()/y15.lengthOf()); - for(int i = 0; i < N; ++i) - ASSERT_TRUE(n15[i] == maxIdxs[i]); + NDArray x('c', {2, 3, 4, 5}, sd::DataType::FLOAT32); + int maxIdxs[120]; + + NDArray y1('c', {3, 5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude1 = {0, 2}; + const int n1[] = {20, 25, 30, 35, 80, 85, 90, 95}; + int minIdx = 5; + + int N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), + y1.shapeInfo(), dimsToExclude1.data()); + ASSERT_TRUE(N == x.lengthOf() / y1.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n1[i] == maxIdxs[i]); + + NDArray y2('c', {4, 5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude2 = {0, 1}; + const int n2[] = {12, 32, 52, 72, 92, 112}; + minIdx = 12; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y2.shapeInfo(), + dimsToExclude2.data()); + ASSERT_TRUE(N == x.lengthOf() / y2.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n2[i] == maxIdxs[i]); + + NDArray y3('c', {2, 5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude3 = {1, 2}; + const int n3[] = {64, 69, 74, 79, 84, 89, 94, 99, 104, 109, 114, 119}; + minIdx = 9; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y3.shapeInfo(), + dimsToExclude3.data()); + ASSERT_TRUE(N == x.lengthOf() / y3.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n3[i] == maxIdxs[i]); + + NDArray y4('c', {2, 3}, sd::DataType::FLOAT32); + const std::vector dimsToExclude4 = {2, 3}; + const int n4[] = {20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39}; + minIdx = 1; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y4.shapeInfo(), + dimsToExclude4.data()); + ASSERT_TRUE(N == x.lengthOf() / y4.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n4[i] == maxIdxs[i]); + + NDArray y5('c', {2, 4}, sd::DataType::FLOAT32); + const std::vector dimsToExclude5 = {1, 3}; + const int n5[] = {65, 66, 67, 68, 69, 85, 86, 87, + 88, 89, 105, 106, 107, 108, 109}; + minIdx = 5; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y5.shapeInfo(), + dimsToExclude5.data()); + ASSERT_TRUE(N == x.lengthOf() / y5.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n5[i] == maxIdxs[i]); + + NDArray y6('c', {2, 3, 4}, sd::DataType::FLOAT32); + const std::vector dimsToExclude6 = {3}; + const int n6[] = {65, 66, 67, 68, 69}; + minIdx = 13; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y6.shapeInfo(), + dimsToExclude6.data()); + ASSERT_TRUE(N == x.lengthOf() / y6.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n6[i] == maxIdxs[i]); + + NDArray y7('c', {4}, sd::DataType::FLOAT32); + const std::vector dimsToExclude7 = {0, 1, 3}; + const int n7[] = {15, 16, 17, 18, 19, 35, 36, 37, 38, 39, + 55, 56, 57, 58, 59, 75, 76, 77, 78, 79, + 95, 96, 97, 98, 99, 115, 116, 117, 118, 119}; + minIdx = 3; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y7.shapeInfo(), + dimsToExclude7.data()); + ASSERT_TRUE(N == x.lengthOf() / y7.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n7[i] == maxIdxs[i]); + + NDArray y8('c', {5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude8 = {0, 1, 2}; + const int n8[] = {0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, + 60, 65, 70, 75, 80, 85, 90, 95, 100, 105, 110, 115}; + minIdx = 0; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y8.shapeInfo(), + dimsToExclude8.data()); + ASSERT_TRUE(N == x.lengthOf() / y8.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n8[i] == maxIdxs[i]); + + NDArray y9('c', {2}, sd::DataType::FLOAT32); + const std::vector dimsToExclude9 = {1, 2, 3}; + const int n9[] = {60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, + 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, + 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, + 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119}; + minIdx = 1; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y9.shapeInfo(), + dimsToExclude9.data()); + ASSERT_TRUE(N == x.lengthOf() / y9.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n9[i] == maxIdxs[i]); + + NDArray y10('c', {3, 4, 5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude10 = {0}; + const int n10[] = {11, 71}; + minIdx = 11; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y10.shapeInfo(), + dimsToExclude10.data()); + ASSERT_TRUE(N == x.lengthOf() / y10.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n10[i] == maxIdxs[i]); + + NDArray y11('c', {2, 4, 5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude11 = {1}; + const int n11[] = {66, 86, 106}; + minIdx = 26; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y11.shapeInfo(), + dimsToExclude11.data()); + ASSERT_TRUE(N == x.lengthOf() / y11.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n11[i] == maxIdxs[i]); + + NDArray y12('c', {3, 2}, sd::DataType::FLOAT32); + const std::vector dimsToExclude12 = {0, 2}; + const int n12[] = {0, 2, 4, 5, 7, 9, 10, 12, 14, 15, 17, 19, + 60, 62, 64, 65, 67, 69, 70, 72, 74, 75, 77, 79}; + minIdx = 0; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y12.shapeInfo(), + dimsToExclude12.data()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n12[i] == maxIdxs[i]); + + NDArray y13('c', {3, 2}, sd::DataType::FLOAT32); + const std::vector dimsToExclude13 = {0, 2}; + const int n13[] = {1, 3, 6, 8, 11, 13, 16, 18, + 61, 63, 66, 68, 71, 73, 76, 78}; + minIdx = 1; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y13.shapeInfo(), + dimsToExclude13.data()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n13[i] == maxIdxs[i]); + + NDArray y14('c', {4, 5}, sd::DataType::FLOAT32); + const int n14[] = {12, 32, 52, 72, 92, 112}; + minIdx = 12; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y14.shapeInfo(), + nullptr); + ASSERT_TRUE(N == x.lengthOf() / y14.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n14[i] == maxIdxs[i]); + + NDArray y15('c', {3, 4, 5}, sd::DataType::FLOAT32); + const int n15[] = {11, 71}; + minIdx = 11; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y15.shapeInfo(), + nullptr); + ASSERT_TRUE(N == x.lengthOf() / y15.lengthOf()); + for (int i = 0; i < N; ++i) ASSERT_TRUE(n15[i] == maxIdxs[i]); } - - -#endif //LIBND4J_TADTESTS_H +#endif // LIBND4J_TADTESTS_H diff --git a/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp b/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp index a9450e9d0ad7..dab7ba417e0b 100644 --- a/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp @@ -18,12 +18,14 @@ // @author raver119@gmail.com // -#include "testlayers.h" -#include -#include +#include #include +#include +#include + #include -#include + +#include "testlayers.h" using namespace samediff; using namespace sd; @@ -31,181 +33,179 @@ using namespace sd::ops; using namespace sd::graph; class ThreadsTests : public testing::Test { -public: - ThreadsTests() { - nd4j_printf("\n",""); - } + public: + ThreadsTests() { nd4j_printf("\n", ""); } }; TEST_F(ThreadsTests, th_test_1) { - ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 1023)); - ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 1024)); - ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 1026)); + ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 1023)); + ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 1024)); + ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 1026)); - ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 2043)); - ASSERT_EQ(2, ThreadsHelper::numberOfThreads(6, 2048)); + ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 2043)); + ASSERT_EQ(2, ThreadsHelper::numberOfThreads(6, 2048)); } - TEST_F(ThreadsTests, th_test_2) { - // in this case we'll get better split over second loop - exactly 32 elements per thread - ASSERT_EQ(2, ThreadsHelper::pickLoop2d(32, 48, 1024)); - ASSERT_EQ(2, ThreadsHelper::pickLoop2d(6, 4, 16384)); + // in this case we'll get better split over second loop - exactly 32 elements + // per thread + ASSERT_EQ(2, ThreadsHelper::pickLoop2d(32, 48, 1024)); + ASSERT_EQ(2, ThreadsHelper::pickLoop2d(6, 4, 16384)); - // in this case we'll get better split over first loop - 2 loops/2048 elements per thread - ASSERT_EQ(1, ThreadsHelper::pickLoop2d(32, 64, 1024)); - ASSERT_EQ(1, ThreadsHelper::pickLoop2d(6, 6, 16384)); + // in this case we'll get better split over first loop - 2 loops/2048 elements + // per thread + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(32, 64, 1024)); + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(6, 6, 16384)); - // in this case none of loops are good enough, but second loop is too small for split - ASSERT_EQ(1, ThreadsHelper::pickLoop2d(6, 64, 32)); + // in this case none of loops are good enough, but second loop is too small + // for split + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(6, 64, 32)); - // all loops are good enough, but we go with bigger one, since small - ASSERT_EQ(1, ThreadsHelper::pickLoop2d(2, 64, 32)); + // all loops are good enough, but we go with bigger one, since small + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(2, 64, 32)); - // obviously split goes into second loop, to give 1024 elements per thread - ASSERT_EQ(2, ThreadsHelper::pickLoop2d(2, 1, 2048)); + // obviously split goes into second loop, to give 1024 elements per thread + ASSERT_EQ(2, ThreadsHelper::pickLoop2d(2, 1, 2048)); } TEST_F(ThreadsTests, th_test_3) { - // typical conv cases - ASSERT_EQ(1, ThreadsHelper::pickLoop3d(4, 32, 3, 128)); - ASSERT_EQ(2, ThreadsHelper::pickLoop3d(4, 1, 128, 64)); - ASSERT_EQ(3, ThreadsHelper::pickLoop3d(4, 1, 3, 128)); - - // checking for optimal threads for conv inference - ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 1, 3, 128)); - ASSERT_EQ(4, ThreadsHelper::numberOfThreads3d(4, 1, 3, 128)); - ASSERT_EQ(8, ThreadsHelper::numberOfThreads3d(8, 1, 3, 128)); - - // checking for optimal threads for conv training - ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 16, 3, 128)); - ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 8, 3, 128)); - - - ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 8, 3, 64)); - ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 8, 3, 64)); + // typical conv cases + ASSERT_EQ(1, ThreadsHelper::pickLoop3d(4, 32, 3, 128)); + ASSERT_EQ(2, ThreadsHelper::pickLoop3d(4, 1, 128, 64)); + ASSERT_EQ(3, ThreadsHelper::pickLoop3d(4, 1, 3, 128)); + + // checking for optimal threads for conv inference + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 1, 3, 128)); + ASSERT_EQ(4, ThreadsHelper::numberOfThreads3d(4, 1, 3, 128)); + ASSERT_EQ(8, ThreadsHelper::numberOfThreads3d(8, 1, 3, 128)); + + // checking for optimal threads for conv training + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 16, 3, 128)); + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 8, 3, 128)); + + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 8, 3, 64)); + ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 8, 3, 64)); } TEST_F(ThreadsTests, th_test_5) { - ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 32, 112, 112)); + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 32, 112, 112)); - ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 32, 112, 112)); + ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 32, 112, 112)); - for (auto e = 0; e < 6; e++) { - auto span = Span3::build(1, e, 6, 0, 32, 1, 0, 112, 1, 0, 112, 1); + for (auto e = 0; e < 6; e++) { + auto span = Span3::build(1, e, 6, 0, 32, 1, 0, 112, 1, 0, 112, 1); - nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); - } + nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); + } } TEST_F(ThreadsTests, th_test_4) { - // typical conv cases - ASSERT_EQ(2, ThreadsHelper::numberOfThreads2d(2, 32, 3)); - ASSERT_EQ(4, ThreadsHelper::numberOfThreads2d(4, 32, 3)); - ASSERT_EQ(6, ThreadsHelper::numberOfThreads2d(6, 32, 1)); - ASSERT_EQ(8, ThreadsHelper::numberOfThreads2d(8, 16, 64)); + // typical conv cases + ASSERT_EQ(2, ThreadsHelper::numberOfThreads2d(2, 32, 3)); + ASSERT_EQ(4, ThreadsHelper::numberOfThreads2d(4, 32, 3)); + ASSERT_EQ(6, ThreadsHelper::numberOfThreads2d(6, 32, 1)); + ASSERT_EQ(8, ThreadsHelper::numberOfThreads2d(8, 16, 64)); - ASSERT_EQ(1, ThreadsHelper::pickLoop2d(4, 32, 1)); - ASSERT_EQ(1, ThreadsHelper::pickLoop2d(8, 19, 17)); + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(4, 32, 1)); + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(8, 19, 17)); - // primes edge cases - ASSERT_EQ(6, ThreadsHelper::numberOfThreads2d(6, 19, 17)); - ASSERT_EQ(8, ThreadsHelper::numberOfThreads2d(8, 19, 17)); + // primes edge cases + ASSERT_EQ(6, ThreadsHelper::numberOfThreads2d(6, 19, 17)); + ASSERT_EQ(8, ThreadsHelper::numberOfThreads2d(8, 19, 17)); - ASSERT_EQ(1, ThreadsHelper::pickLoop2d(8, 19, 17)); + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(8, 19, 17)); - for (auto e = 0; e < 6; e++) { - auto span = Span2::build(1, e, 6, 0, 19, 1, 0, 17, 1); + for (auto e = 0; e < 6; e++) { + auto span = Span2::build(1, e, 6, 0, 19, 1, 0, 17, 1); - nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); - } + nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); + } - nd4j_printf("-----------------------\n",""); - for (auto e = 0; e < 6; e++) { - auto span = Span2::build(1, e, 6, 0, 32, 1, 0, 3, 1); + nd4j_printf("-----------------------\n", ""); + for (auto e = 0; e < 6; e++) { + auto span = Span2::build(1, e, 6, 0, 32, 1, 0, 3, 1); - nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); - } + nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); + } } - TEST_F(ThreadsTests, test_span_converage_1) { - for (int b = 1; b <= 128; b++) { - for (int c = 1; c <= 64; c++) { - for (int t = 1; t <= 64; t++) { - - auto threads = ThreadsHelper::numberOfThreads2d(t, b, c); - auto loop = ThreadsHelper::pickLoop2d(threads, b, c); - - if (t > 1 && threads == 1 && (b > 1 && c > 1)) { - nd4j_printf("Got 1 thread for [%i, %i] loop; initial max threads: %i\n", b, c, t) - } - - auto sum = 0; - for (auto a = 0; a < threads; a++) { - auto span = Span2::build(loop, a,threads, 0, b, 1, 0, c, 1); - - if (loop == 1) - sum += span.stopX() - span.startX(); - else if (loop == 2) - sum += span.stopY() - span.startY(); - else - throw std::runtime_error("Bad loop!"); - } - - if (loop == 1) - ASSERT_EQ(b, sum); - else - ASSERT_EQ(c, sum); - } + for (int b = 1; b <= 128; b++) { + for (int c = 1; c <= 64; c++) { + for (int t = 1; t <= 64; t++) { + auto threads = ThreadsHelper::numberOfThreads2d(t, b, c); + auto loop = ThreadsHelper::pickLoop2d(threads, b, c); + + if (t > 1 && threads == 1 && (b > 1 && c > 1)) { + nd4j_printf( + "Got 1 thread for [%i, %i] loop; initial max threads: %i\n", b, c, + t) } + + auto sum = 0; + for (auto a = 0; a < threads; a++) { + auto span = Span2::build(loop, a, threads, 0, b, 1, 0, c, 1); + + if (loop == 1) + sum += span.stopX() - span.startX(); + else if (loop == 2) + sum += span.stopY() - span.startY(); + else + throw std::runtime_error("Bad loop!"); + } + + if (loop == 1) + ASSERT_EQ(b, sum); + else + ASSERT_EQ(c, sum); + } } + } } TEST_F(ThreadsTests, validation_test_2d_1) { - if (1 > 0) - return; - - std::vector threads({1, 2, 4, 6, 8, 12, 16, 20, 32, 48, 64}); + if (1 > 0) return; - for (int e = 1; e < 1024; e++) { - for (int i = 1; i <= 1024; i++ ) { - for (auto t:threads) { - std::atomic sum; - sum.store(0); + std::vector threads({1, 2, 4, 6, 8, 12, 16, 20, 32, 48, 64}); - auto func = PRAGMA_THREADS_FOR_2D { - for (auto x = start_x; x < stop_x; x += inc_x) { - for (auto y = start_y; y < stop_y; y += inc_y) { - sum++; - } - } - }; + for (int e = 1; e < 1024; e++) { + for (int i = 1; i <= 1024; i++) { + for (auto t : threads) { + std::atomic sum; + sum.store(0); - samediff::Threads::parallel_for(func, 0, e, 1, 0, i, 1, t, true); - - ASSERT_EQ(e * i, sum.load()); + auto func = PRAGMA_THREADS_FOR_2D { + for (auto x = start_x; x < stop_x; x += inc_x) { + for (auto y = start_y; y < stop_y; y += inc_y) { + sum++; } - } + } + }; - nd4j_printf("Finished iteration %i\n", e); + samediff::Threads::parallel_for(func, 0, e, 1, 0, i, 1, t, true); + + ASSERT_EQ(e * i, sum.load()); + } } + + nd4j_printf("Finished iteration %i\n", e); + } } TEST_F(ThreadsTests, reduction_test_1) { + auto func = PRAGMA_REDUCE_LONG { + int64_t sum = 0; - auto func = PRAGMA_REDUCE_LONG { - int64_t sum = 0; - - for (auto e = start; e < stop; e++) { - sum++; - }; - - return sum; + for (auto e = start; e < stop; e++) { + sum++; }; - auto sum = samediff::Threads::parallel_long(func, LAMBDA_AL {return _old + _new;}, 0, 8192, 1, 4); - ASSERT_EQ(8192, sum); + return sum; + }; + + auto sum = samediff::Threads::parallel_long( + func, LAMBDA_AL { return _old + _new; }, 0, 8192, 1, 4); + ASSERT_EQ(8192, sum); } /* @@ -230,7 +230,9 @@ TEST_F(ThreadsTests, basic_test_1) { auto timeStartThreads = std::chrono::system_clock::now(); samediff::Threads::parallel_for(func, 0, array.lengthOf()); auto timeEndThreads = std::chrono::system_clock::now(); - auto outerTimeThreads = std::chrono::duration_cast (timeEndThreads - timeStartThreads).count(); + auto outerTimeThreads = +std::chrono::duration_cast (timeEndThreads - +timeStartThreads).count(); auto timeStartOmp = std::chrono::system_clock::now(); PRAGMA_OMP_PARALLEL_FOR_SIMD @@ -238,10 +240,12 @@ TEST_F(ThreadsTests, basic_test_1) { lbuffer[e] += 1.0f; } auto timeEndOmp = std::chrono::system_clock::now(); - auto outerTimeOmp = std::chrono::duration_cast (timeEndOmp - timeStartOmp).count(); + auto outerTimeOmp = std::chrono::duration_cast +(timeEndOmp - timeStartOmp).count(); ASSERT_NEAR((float) array.lengthOf(), array.sumNumber().e(0), 1e-5f); - nd4j_printf("Threads time: %lld us; OMP time: %lld us; %p\n", outerTimeThreads, outerTimeOmp, instance) + nd4j_printf("Threads time: %lld us; OMP time: %lld us; %p\n", +outerTimeThreads, outerTimeOmp, instance) } */ \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp b/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp index 2c27f95f9066..9b7dd6cfde10 100644 --- a/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp @@ -18,55 +18,56 @@ // Created by raver119 on 02/07/18. // -#include "testlayers.h" -#include #include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::ops; using namespace sd::graph; class TypeCastTests : public testing::Test { -public: - + public: }; TEST_F(TypeCastTests, Test_Cast_1) { #ifndef __CUDABLAS__ - const int limit = 100; - auto src = new double[limit]; - auto z = new float[limit]; - auto exp = new float[limit]; - - for (int e = 0; e < limit; e++) { - src[e] = static_cast(e); - exp[e] = static_cast(e); - } - - TypeCast::convertGeneric(nullptr, reinterpret_cast(src), limit, reinterpret_cast(z)); - - for (int e = 0; e < limit; e++) { - ASSERT_NEAR(exp[e], z[e], 1e-5f); - } - - delete[] src; - delete[] z; - delete[] exp; + const int limit = 100; + auto src = new double[limit]; + auto z = new float[limit]; + auto exp = new float[limit]; + + for (int e = 0; e < limit; e++) { + src[e] = static_cast(e); + exp[e] = static_cast(e); + } + + TypeCast::convertGeneric(nullptr, + reinterpret_cast(src), limit, + reinterpret_cast(z)); + + for (int e = 0; e < limit; e++) { + ASSERT_NEAR(exp[e], z[e], 1e-5f); + } + + delete[] src; + delete[] z; + delete[] exp; #endif } TEST_F(TypeCastTests, Test_ConvertDtype_1) { +#ifndef __CUDABLAS__ - #ifndef __CUDABLAS__ - - float src[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; - float16 dst[5]; - float16 exp[] = {(float16) 1.0f, (float16) 2.0f, (float16) 3.0f, (float16) 4.0f, (float16) 5.0f}; + float src[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + float16 dst[5]; + float16 exp[] = {(float16)1.0f, (float16)2.0f, (float16)3.0f, (float16)4.0f, + (float16)5.0f}; - convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst); + convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst); - for (int e = 0; e < 5; e++) - ASSERT_NEAR(exp[e], dst[e], (float16) 0.01f); + for (int e = 0; e < 5; e++) ASSERT_NEAR(exp[e], dst[e], (float16)0.01f); - #endif +#endif } diff --git a/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp index 299972c7d0fb..2f32ac16cfdc 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp @@ -18,133 +18,129 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; class VariableProxyTests : public testing::Test { -public: - + public: }; - TEST_F(VariableProxyTests, Test_Simple_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - VariableSpace ref; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + VariableSpace ref; - ref.putVariable(119, x); + ref.putVariable(119, x); - ASSERT_TRUE(ref.hasVariable(119)); + ASSERT_TRUE(ref.hasVariable(119)); - VariableProxy proxy(&ref); + VariableProxy proxy(&ref); - ASSERT_TRUE(proxy.hasVariable(119)); + ASSERT_TRUE(proxy.hasVariable(119)); } - TEST_F(VariableProxyTests, Test_Simple_2) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - VariableSpace ref; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + VariableSpace ref; - ASSERT_FALSE(ref.hasVariable(119)); + ASSERT_FALSE(ref.hasVariable(119)); - VariableProxy proxy(&ref); + VariableProxy proxy(&ref); - ASSERT_FALSE(proxy.hasVariable(119)); + ASSERT_FALSE(proxy.hasVariable(119)); - proxy.putVariable(119, x); + proxy.putVariable(119, x); - ASSERT_FALSE(ref.hasVariable(119)); + ASSERT_FALSE(ref.hasVariable(119)); - ASSERT_TRUE(proxy.hasVariable(119)); + ASSERT_TRUE(proxy.hasVariable(119)); } - TEST_F(VariableProxyTests, Test_Simple_3) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {2, 2}, {4, 2, 3, 1}); - VariableSpace ref; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 2}, {4, 2, 3, 1}); + VariableSpace ref; - ref.putVariable(119, x); + ref.putVariable(119, x); - ASSERT_TRUE(ref.hasVariable(119)); + ASSERT_TRUE(ref.hasVariable(119)); - VariableProxy proxy(&ref); + VariableProxy proxy(&ref); - ASSERT_TRUE(proxy.hasVariable(119)); + ASSERT_TRUE(proxy.hasVariable(119)); - proxy.putVariable(119, y); + proxy.putVariable(119, y); - ASSERT_TRUE(ref.hasVariable(119)); + ASSERT_TRUE(ref.hasVariable(119)); - ASSERT_TRUE(proxy.hasVariable(119)); + ASSERT_TRUE(proxy.hasVariable(119)); - auto z0 = ref.getVariable(119)->getNDArray(); - auto z1 = proxy.getVariable(119)->getNDArray(); + auto z0 = ref.getVariable(119)->getNDArray(); + auto z1 = proxy.getVariable(119)->getNDArray(); - ASSERT_FALSE(z0 == z1); - ASSERT_TRUE(y == *z1); - ASSERT_TRUE(x == *z0); + ASSERT_FALSE(z0 == z1); + ASSERT_TRUE(y == *z1); + ASSERT_TRUE(x == *z0); } TEST_F(VariableProxyTests, Test_Simple_4) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {2, 2}, {4, 2, 3, 1}); - auto z = NDArrayFactory::create('c', {2, 2}, {4, 1, 3, 2}); - VariableSpace ref; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 2}, {4, 2, 3, 1}); + auto z = NDArrayFactory::create('c', {2, 2}, {4, 1, 3, 2}); + VariableSpace ref; - ref.putVariable(119, x); - ref.putVariable(118, z); + ref.putVariable(119, x); + ref.putVariable(118, z); - ASSERT_TRUE(ref.hasVariable(119)); + ASSERT_TRUE(ref.hasVariable(119)); - VariableProxy proxy(&ref); + VariableProxy proxy(&ref); - ASSERT_TRUE(proxy.hasVariable(119)); + ASSERT_TRUE(proxy.hasVariable(119)); - proxy.putVariable(119, y); + proxy.putVariable(119, y); - ASSERT_TRUE(ref.hasVariable(119)); - ASSERT_TRUE(ref.hasVariable(118)); + ASSERT_TRUE(ref.hasVariable(119)); + ASSERT_TRUE(ref.hasVariable(118)); - ASSERT_TRUE(proxy.hasVariable(119)); - ASSERT_TRUE(proxy.hasVariable(118)); + ASSERT_TRUE(proxy.hasVariable(119)); + ASSERT_TRUE(proxy.hasVariable(118)); - auto z0 = ref.getVariable(119)->getNDArray(); - auto z1 = proxy.getVariable(119)->getNDArray(); + auto z0 = ref.getVariable(119)->getNDArray(); + auto z1 = proxy.getVariable(119)->getNDArray(); - ASSERT_FALSE(z0 == z1); - ASSERT_TRUE(y == *z1); - ASSERT_TRUE(x == *z0); + ASSERT_FALSE(z0 == z1); + ASSERT_TRUE(y == *z1); + ASSERT_TRUE(x == *z0); } - TEST_F(VariableProxyTests, Test_Cast_1) { - auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); - auto y = NDArrayFactory::create('c', {2, 2}, {4, 2, 3, 1}); - VariableSpace ref; + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 2}, {4, 2, 3, 1}); + VariableSpace ref; - ref.putVariable(-119, x); + ref.putVariable(-119, x); - ASSERT_TRUE(ref.hasVariable(-119)); + ASSERT_TRUE(ref.hasVariable(-119)); - VariableProxy proxy(&ref); - auto cast = (VariableSpace *) &proxy; + VariableProxy proxy(&ref); + auto cast = (VariableSpace *)&proxy; - ASSERT_TRUE(cast->hasVariable(-119)); + ASSERT_TRUE(cast->hasVariable(-119)); - cast->putVariable(-119, y); + cast->putVariable(-119, y); - ASSERT_TRUE(ref.hasVariable(-119)); + ASSERT_TRUE(ref.hasVariable(-119)); - ASSERT_TRUE(cast->hasVariable(-119)); + ASSERT_TRUE(cast->hasVariable(-119)); - auto z0 = ref.getVariable(-119)->getNDArray(); - auto z1 = cast->getVariable(-119)->getNDArray(); + auto z0 = ref.getVariable(-119)->getNDArray(); + auto z1 = cast->getVariable(-119)->getNDArray(); - ASSERT_FALSE(z0 == z1); - ASSERT_TRUE(y == *z1); - ASSERT_TRUE(x == *z0); + ASSERT_FALSE(z0 == z1); + ASSERT_TRUE(y == *z1); + ASSERT_TRUE(x == *z0); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp index 21407178210b..bd9df549124a 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableSpaceTests.cpp @@ -18,104 +18,102 @@ // @author raver119@gmail.com // -#include "testlayers.h" +#include #include -#include -#include -#include #include -#include +#include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class VariableSpaceTest : public testing::Test { -public: - int *cShape = new int[8]{2, 2, 2, 2, 1, 0, 1, 99}; - int *fShape = new int[8]{2, 2, 2, 1, 2, 0, 1, 102}; - - - ~VariableSpaceTest() { - delete[] cShape; - delete[] fShape; - } + public: + int *cShape = new int[8]{2, 2, 2, 2, 1, 0, 1, 99}; + int *fShape = new int[8]{2, 2, 2, 1, 2, 0, 1, 102}; + + ~VariableSpaceTest() { + delete[] cShape; + delete[] fShape; + } }; - TEST_F(VariableSpaceTest, SettersGettersTest1) { - auto space1 = new VariableSpace(); - auto arrayA = NDArrayFactory::create('c', {5, 5}); - auto arrayB = NDArrayFactory::create('c', {3, 3}); + auto space1 = new VariableSpace(); + auto arrayA = NDArrayFactory::create('c', {5, 5}); + auto arrayB = NDArrayFactory::create('c', {3, 3}); - space1->putVariable(1, arrayA); - space1->putVariable(2, arrayB); + space1->putVariable(1, arrayA); + space1->putVariable(2, arrayB); - auto arrayRA = space1->getVariable(1); - auto arrayRB = space1->getVariable(2); + auto arrayRA = space1->getVariable(1); + auto arrayRB = space1->getVariable(2); - ASSERT_TRUE(arrayA == *arrayRA->getNDArray()); - ASSERT_TRUE(arrayB == *arrayRB->getNDArray()); + ASSERT_TRUE(arrayA == *arrayRA->getNDArray()); + ASSERT_TRUE(arrayB == *arrayRB->getNDArray()); - // we should survive this call - delete space1; + // we should survive this call + delete space1; } - TEST_F(VariableSpaceTest, SettersGettersTest2) { - auto space1 = new VariableSpace(); - auto arrayA = NDArrayFactory::create('c', {5, 5}); - auto arrayB = NDArrayFactory::create('c', {3, 3}); + auto space1 = new VariableSpace(); + auto arrayA = NDArrayFactory::create('c', {5, 5}); + auto arrayB = NDArrayFactory::create('c', {3, 3}); - space1->putVariable(-1, arrayA); - space1->putVariable(2, arrayB); + space1->putVariable(-1, arrayA); + space1->putVariable(2, arrayB); - Nd4jLong expExternal = (25 * 4) + (8 * 8); - Nd4jLong expInternal = (9 * 4) + (8 * 8); + Nd4jLong expExternal = (25 * 4) + (8 * 8); + Nd4jLong expInternal = (9 * 4) + (8 * 8); - ASSERT_EQ(expExternal, space1->externalMemory()); - ASSERT_EQ(expInternal, space1->internalMemory()); + ASSERT_EQ(expExternal, space1->externalMemory()); + ASSERT_EQ(expInternal, space1->internalMemory()); - delete space1; + delete space1; } TEST_F(VariableSpaceTest, EqualityTest1) { - VariableSpace space; + VariableSpace space; - std::string name("myvar"); + std::string name("myvar"); - auto arrayA = NDArrayFactory::create('c', {3, 3}); - auto variableA = std::make_shared(arrayA, name, 1); + auto arrayA = NDArrayFactory::create('c', {3, 3}); + auto variableA = std::make_shared(arrayA, name, 1); - space.putVariable(1, variableA); + space.putVariable(1, variableA); - std::pair pair(1,0); + std::pair pair(1, 0); - ASSERT_TRUE(space.hasVariable(1)); - ASSERT_TRUE(space.hasVariable(pair)); - ASSERT_TRUE(space.hasVariable(name)); + ASSERT_TRUE(space.hasVariable(1)); + ASSERT_TRUE(space.hasVariable(pair)); + ASSERT_TRUE(space.hasVariable(name)); - auto rV1 = space.getVariable(1); - auto rV2 = space.getVariable(pair); - auto rV3 = space.getVariable(name); + auto rV1 = space.getVariable(1); + auto rV2 = space.getVariable(pair); + auto rV3 = space.getVariable(name); - ASSERT_TRUE(rV1 == rV2); - ASSERT_TRUE(rV2 == rV3); + ASSERT_TRUE(rV1 == rV2); + ASSERT_TRUE(rV2 == rV3); } TEST_F(VariableSpaceTest, EqualityTest2) { - VariableSpace space; + VariableSpace space; - auto arrayA = NDArrayFactory::create('c', {3, 3}); + auto arrayA = NDArrayFactory::create('c', {3, 3}); - space.putVariable(1, arrayA); + space.putVariable(1, arrayA); - std::pair pair(1,0); + std::pair pair(1, 0); - ASSERT_TRUE(space.hasVariable(1)); - ASSERT_TRUE(space.hasVariable(pair)); + ASSERT_TRUE(space.hasVariable(1)); + ASSERT_TRUE(space.hasVariable(pair)); - auto rV1 = space.getVariable(1); - auto rV2 = space.getVariable(pair); + auto rV1 = space.getVariable(1); + auto rV2 = space.getVariable(pair); - ASSERT_TRUE(rV1 == rV2); + ASSERT_TRUE(rV1 == rV2); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp index 90deeab8d894..99342f5c9e6b 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableTests.cpp @@ -21,123 +21,128 @@ #ifndef LIBND4J_VARIABLETESTS_H #define LIBND4J_VARIABLETESTS_H -#include "testlayers.h" #include -#include #include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::graph; class VariableTests : public testing::Test { -public: - + public: }; TEST_F(VariableTests, Test_FlatVariableDataType_1) { - flatbuffers::FlatBufferBuilder builder(1024); - auto original = NDArrayFactory::create('c', {5, 10}); - original.linspace(1); + flatbuffers::FlatBufferBuilder builder(1024); + auto original = NDArrayFactory::create('c', {5, 10}); + original.linspace(1); - auto vec = original.asByteVector(); + auto vec = original.asByteVector(); - auto fShape = builder.CreateVector(original.getShapeInfoAsFlatVector()); - auto fBuffer = builder.CreateVector(vec); - auto fVid = CreateIntPair(builder, 1, 12); + auto fShape = builder.CreateVector(original.getShapeInfoAsFlatVector()); + auto fBuffer = builder.CreateVector(vec); + auto fVid = CreateIntPair(builder, 1, 12); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_FLOAT); + auto fArray = + CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_FLOAT); - auto flatVar = CreateFlatVariable(builder, fVid, 0, sd::graph::DType::DType_FLOAT, 0, fArray); + auto flatVar = CreateFlatVariable(builder, fVid, 0, + sd::graph::DType::DType_FLOAT, 0, fArray); - builder.Finish(flatVar); + builder.Finish(flatVar); - auto ptr = builder.GetBufferPointer(); + auto ptr = builder.GetBufferPointer(); - auto restoredVar = GetFlatVariable(ptr); + auto restoredVar = GetFlatVariable(ptr); - auto rv = new Variable(restoredVar); + auto rv = new Variable(restoredVar); - ASSERT_EQ(1, rv->id()); - ASSERT_EQ(12, rv->index()); + ASSERT_EQ(1, rv->id()); + ASSERT_EQ(12, rv->index()); - auto restoredArray = rv->getNDArray(); + auto restoredArray = rv->getNDArray(); - ASSERT_TRUE(original.isSameShape(*restoredArray)); - ASSERT_TRUE(original.equalsTo(*restoredArray)); + ASSERT_TRUE(original.isSameShape(*restoredArray)); + ASSERT_TRUE(original.equalsTo(*restoredArray)); - delete rv; + delete rv; } TEST_F(VariableTests, Test_FlatVariableDataType_2) { - flatbuffers::FlatBufferBuilder builder(1024); - auto original = NDArrayFactory::create('c', {5, 10}); - original.linspace(1); + flatbuffers::FlatBufferBuilder builder(1024); + auto original = NDArrayFactory::create('c', {5, 10}); + original.linspace(1); - auto vec = original.asByteVector(); + auto vec = original.asByteVector(); - auto fShape = builder.CreateVector(original.getShapeInfoAsFlatVector()); - auto fBuffer = builder.CreateVector(vec); - auto fVid = CreateIntPair(builder, 1, 12); + auto fShape = builder.CreateVector(original.getShapeInfoAsFlatVector()); + auto fBuffer = builder.CreateVector(vec); + auto fVid = CreateIntPair(builder, 1, 12); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_DOUBLE); + auto fArray = + CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_DOUBLE); - auto flatVar = CreateFlatVariable(builder, fVid, 0, sd::graph::DType::DType_DOUBLE, 0, fArray); + auto flatVar = CreateFlatVariable(builder, fVid, 0, + sd::graph::DType::DType_DOUBLE, 0, fArray); - builder.Finish(flatVar); + builder.Finish(flatVar); - auto ptr = builder.GetBufferPointer(); + auto ptr = builder.GetBufferPointer(); - auto restoredVar = GetFlatVariable(ptr); + auto restoredVar = GetFlatVariable(ptr); - auto rv = new Variable(restoredVar); + auto rv = new Variable(restoredVar); - ASSERT_EQ(1, rv->id()); - ASSERT_EQ(12, rv->index()); + ASSERT_EQ(1, rv->id()); + ASSERT_EQ(12, rv->index()); - auto restoredArray = rv->getNDArray(); + auto restoredArray = rv->getNDArray(); - ASSERT_TRUE(original.isSameShape(*restoredArray)); - ASSERT_TRUE(original.equalsTo(*restoredArray)); + ASSERT_TRUE(original.isSameShape(*restoredArray)); + ASSERT_TRUE(original.equalsTo(*restoredArray)); - delete rv; + delete rv; } - TEST_F(VariableTests, Test_FlatVariableDataType_3) { - flatbuffers::FlatBufferBuilder builder(1024); - auto original = NDArrayFactory::create('c', {5, 10}); - auto floating = NDArrayFactory::create('c', {5, 10}); - original.linspace(1); - floating.linspace(1); + flatbuffers::FlatBufferBuilder builder(1024); + auto original = NDArrayFactory::create('c', {5, 10}); + auto floating = NDArrayFactory::create('c', {5, 10}); + original.linspace(1); + floating.linspace(1); - auto vec = original.asByteVector(); + auto vec = original.asByteVector(); - auto fShape = builder.CreateVector(original.getShapeInfoAsFlatVector()); - auto fBuffer = builder.CreateVector(vec); - auto fVid = CreateIntPair(builder, 1, 12); + auto fShape = builder.CreateVector(original.getShapeInfoAsFlatVector()); + auto fBuffer = builder.CreateVector(vec); + auto fVid = CreateIntPair(builder, 1, 12); - auto fArray = CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_DOUBLE); + auto fArray = + CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_DOUBLE); - auto flatVar = CreateFlatVariable(builder, fVid, 0, sd::graph::DType::DType_DOUBLE, 0, fArray); + auto flatVar = CreateFlatVariable(builder, fVid, 0, + sd::graph::DType::DType_DOUBLE, 0, fArray); - builder.Finish(flatVar); + builder.Finish(flatVar); - auto ptr = builder.GetBufferPointer(); + auto ptr = builder.GetBufferPointer(); - auto restoredVar = GetFlatVariable(ptr); + auto restoredVar = GetFlatVariable(ptr); - auto rv = new Variable(restoredVar); + auto rv = new Variable(restoredVar); - ASSERT_EQ(1, rv->id()); - ASSERT_EQ(12, rv->index()); + ASSERT_EQ(1, rv->id()); + ASSERT_EQ(12, rv->index()); - auto restoredArray = rv->getNDArray(); - auto conv = restoredArray->asT(); + auto restoredArray = rv->getNDArray(); + auto conv = restoredArray->asT(); - ASSERT_TRUE(floating.isSameShape(*restoredArray)); - ASSERT_TRUE(floating.equalsTo(conv)); + ASSERT_TRUE(floating.isSameShape(*restoredArray)); + ASSERT_TRUE(floating.equalsTo(conv)); - delete rv; + delete rv; } /* @@ -151,7 +156,8 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) { auto fShape = builder.CreateVector(original.getShapeAsFlatVector()); auto fVid = CreateIntPair(builder, 37, 12); - auto flatVar = CreateFlatVariable(builder, fVid, 0, sd::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER); + auto flatVar = CreateFlatVariable(builder, fVid, 0, +sd::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER); builder.Finish(flatVar); @@ -174,4 +180,4 @@ TEST_F(VariableTests, Test_FlatVariableDataType_4) { delete rv; } */ -#endif //LIBND4J_VARIABLETESTS_H +#endif // LIBND4J_VARIABLETESTS_H diff --git a/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cpp b/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cpp index 74409f16a2cd..6fa22fcefa32 100644 --- a/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cpp @@ -21,260 +21,248 @@ #ifndef LIBND4J_WORKSPACETESTS_H #define LIBND4J_WORKSPACETESTS_H -#include "testlayers.h" #include -#include -#include #include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::memory; -class WorkspaceTests : public testing::Test { - -}; - +class WorkspaceTests : public testing::Test {}; TEST_F(WorkspaceTests, BasicInitialization1) { - Workspace workspace(1024); + Workspace workspace(1024); - ASSERT_EQ(1024, workspace.getCurrentSize()); - ASSERT_EQ(0, workspace.getCurrentOffset()); + ASSERT_EQ(1024, workspace.getCurrentSize()); + ASSERT_EQ(0, workspace.getCurrentOffset()); } TEST_F(WorkspaceTests, BasicInitialization2) { - Workspace workspace(65536); + Workspace workspace(65536); - ASSERT_EQ(0, workspace.getCurrentOffset()); - LaunchContext ctx; - ctx.setWorkspace(&workspace); - auto array = NDArrayFactory::create('c', {5, 5}, &ctx); + ASSERT_EQ(0, workspace.getCurrentOffset()); + LaunchContext ctx; + ctx.setWorkspace(&workspace); + auto array = NDArrayFactory::create('c', {5, 5}, &ctx); - array.p(0, 1.0f); - array.p(5, 1.0f); + array.p(0, 1.0f); + array.p(5, 1.0f); - auto v = array.reduceNumber(reduce::Sum); - auto f = v.e(0); + auto v = array.reduceNumber(reduce::Sum); + auto f = v.e(0); - ASSERT_NEAR(2.0f, f, 1e-5); + ASSERT_NEAR(2.0f, f, 1e-5); - ASSERT_TRUE(workspace.getCurrentOffset() > 0); + ASSERT_TRUE(workspace.getCurrentOffset() > 0); } - TEST_F(WorkspaceTests, BasicInitialization3) { - Workspace workspace; + Workspace workspace; - ASSERT_EQ(0, workspace.getCurrentOffset()); - LaunchContext ctx; - ctx.setWorkspace(&workspace); + ASSERT_EQ(0, workspace.getCurrentOffset()); + LaunchContext ctx; + ctx.setWorkspace(&workspace); - auto array = NDArrayFactory::create('c', {5, 5}, &ctx); + auto array = NDArrayFactory::create('c', {5, 5}, &ctx); - array.p(0, 1.0f); - array.p(5, 1.0f); + array.p(0, 1.0f); + array.p(5, 1.0f); - auto v = array.reduceNumber(reduce::Sum); - auto f = v.e(0); + auto v = array.reduceNumber(reduce::Sum); + auto f = v.e(0); - ASSERT_NEAR(2.0f, array.reduceNumber(reduce::Sum).e(0), 1e-5); + ASSERT_NEAR(2.0f, array.reduceNumber(reduce::Sum).e(0), 1e-5); - ASSERT_TRUE(workspace.getCurrentOffset() == 0); + ASSERT_TRUE(workspace.getCurrentOffset() == 0); } - TEST_F(WorkspaceTests, ResetTest1) { - Workspace workspace(65536); - LaunchContext ctx; - ctx.setWorkspace(&workspace); + Workspace workspace(65536); + LaunchContext ctx; + ctx.setWorkspace(&workspace); - auto array = NDArrayFactory::create('c', {5, 5}, &ctx); - array.p(0, 1.0f); - array.p(5, 1.0f); + auto array = NDArrayFactory::create('c', {5, 5}, &ctx); + array.p(0, 1.0f); + array.p(5, 1.0f); - workspace.scopeOut(); - for (int e = 0; e < 5; e++) { - workspace.scopeIn(); + workspace.scopeOut(); + for (int e = 0; e < 5; e++) { + workspace.scopeIn(); - auto array2 = NDArrayFactory::create('c', {5, 5}, &ctx); - array2.p(0, 1.0f); - array2.p(5, 1.0f); + auto array2 = NDArrayFactory::create('c', {5, 5}, &ctx); + array2.p(0, 1.0f); + array2.p(5, 1.0f); - ASSERT_NEAR(2.0f, array2.reduceNumber(reduce::Sum).e(0), 1e-5); + ASSERT_NEAR(2.0f, array2.reduceNumber(reduce::Sum).e(0), 1e-5); - workspace.scopeOut(); - } + workspace.scopeOut(); + } - ASSERT_EQ(65536, workspace.getCurrentSize()); - ASSERT_EQ(0, workspace.getCurrentOffset()); - ASSERT_EQ(0, workspace.getSpilledSize()); + ASSERT_EQ(65536, workspace.getCurrentSize()); + ASSERT_EQ(0, workspace.getCurrentOffset()); + ASSERT_EQ(0, workspace.getSpilledSize()); } - TEST_F(WorkspaceTests, StretchTest1) { - if (!Environment::getInstance()->isCPU()) - return; + if (!Environment::getInstance()->isCPU()) return; - Workspace workspace(128); - void* ptr = workspace.allocateBytes(8); - workspace.scopeOut(); - ASSERT_EQ(0, workspace.getSpilledSize()); - ASSERT_EQ(0, workspace.getSpilledSecondarySize()); - ASSERT_EQ(0, workspace.getCurrentOffset()); - ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); - - - workspace.scopeIn(); - for (int e = 0; e < 10; e++) { - - workspace.allocateBytes(128); - - } - ASSERT_EQ(128 * 9, workspace.getSpilledSize()); - workspace.scopeOut(); - workspace.scopeIn(); + Workspace workspace(128); + void* ptr = workspace.allocateBytes(8); + workspace.scopeOut(); + ASSERT_EQ(0, workspace.getSpilledSize()); + ASSERT_EQ(0, workspace.getSpilledSecondarySize()); + ASSERT_EQ(0, workspace.getCurrentOffset()); + ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); - ASSERT_EQ(0, workspace.getCurrentOffset()); + workspace.scopeIn(); + for (int e = 0; e < 10; e++) { + workspace.allocateBytes(128); + } + ASSERT_EQ(128 * 9, workspace.getSpilledSize()); + workspace.scopeOut(); + workspace.scopeIn(); - // we should have absolutely different pointer here, due to reallocation - void* ptr2 = workspace.allocateBytes(8); + ASSERT_EQ(0, workspace.getCurrentOffset()); - //ASSERT_FALSE(ptr == ptr2); + // we should have absolutely different pointer here, due to reallocation + void* ptr2 = workspace.allocateBytes(8); + // ASSERT_FALSE(ptr == ptr2); - ASSERT_EQ(1280, workspace.getCurrentSize()); - ASSERT_EQ(0, workspace.getSpilledSize()); + ASSERT_EQ(1280, workspace.getCurrentSize()); + ASSERT_EQ(0, workspace.getSpilledSize()); } TEST_F(WorkspaceTests, NewInWorkspaceTest1) { - if (!Environment::getInstance()->isCPU()) - return; + if (!Environment::getInstance()->isCPU()) return; - Workspace ws(65536); + Workspace ws(65536); - ASSERT_EQ(65536, ws.getCurrentSize()); - ASSERT_EQ(0, ws.getCurrentOffset()); + ASSERT_EQ(65536, ws.getCurrentSize()); + ASSERT_EQ(0, ws.getCurrentOffset()); - ASSERT_FALSE(MemoryRegistrator::getInstance()->hasWorkspaceAttached()); + ASSERT_FALSE(MemoryRegistrator::getInstance()->hasWorkspaceAttached()); - MemoryRegistrator::getInstance()->attachWorkspace(&ws); + MemoryRegistrator::getInstance()->attachWorkspace(&ws); - ASSERT_TRUE(MemoryRegistrator::getInstance()->hasWorkspaceAttached()); + ASSERT_TRUE(MemoryRegistrator::getInstance()->hasWorkspaceAttached()); - auto ast = NDArrayFactory::create_('c', {5, 5}); + auto ast = NDArrayFactory::create_('c', {5, 5}); - ASSERT_TRUE(ws.getCurrentOffset() > 0); + ASSERT_TRUE(ws.getCurrentOffset() > 0); - delete ast; + delete ast; - MemoryRegistrator::getInstance()->forgetWorkspace(); + MemoryRegistrator::getInstance()->forgetWorkspace(); - ASSERT_FALSE(MemoryRegistrator::getInstance()->hasWorkspaceAttached()); - ASSERT_TRUE(MemoryRegistrator::getInstance()->getWorkspace() == nullptr); + ASSERT_FALSE(MemoryRegistrator::getInstance()->hasWorkspaceAttached()); + ASSERT_TRUE(MemoryRegistrator::getInstance()->getWorkspace() == nullptr); } - TEST_F(WorkspaceTests, NewInWorkspaceTest2) { - Workspace ws(65536); - LaunchContext ctx; - ctx.setWorkspace(&ws); + Workspace ws(65536); + LaunchContext ctx; + ctx.setWorkspace(&ws); - ASSERT_EQ(65536, ws.getCurrentSize()); - ASSERT_EQ(0, ws.getCurrentOffset()); + ASSERT_EQ(65536, ws.getCurrentSize()); + ASSERT_EQ(0, ws.getCurrentOffset()); - MemoryRegistrator::getInstance()->attachWorkspace(&ws); + MemoryRegistrator::getInstance()->attachWorkspace(&ws); - auto ast = NDArrayFactory::create_('c', {5, 5}, &ctx); + auto ast = NDArrayFactory::create_('c', {5, 5}, &ctx); - ASSERT_TRUE(ws.getCurrentOffset() > 0); + ASSERT_TRUE(ws.getCurrentOffset() > 0); - delete ast; + delete ast; - MemoryRegistrator::getInstance()->forgetWorkspace(); + MemoryRegistrator::getInstance()->forgetWorkspace(); } TEST_F(WorkspaceTests, CloneTest1) { - if (!Environment::getInstance()->isCPU()) - return; + if (!Environment::getInstance()->isCPU()) return; - Workspace ws(65536); + Workspace ws(65536); - ws.allocateBytes(65536 * 2); + ws.allocateBytes(65536 * 2); - ASSERT_EQ(65536 * 2, ws.getSpilledSize()); + ASSERT_EQ(65536 * 2, ws.getSpilledSize()); - auto clone = ws.clone(); + auto clone = ws.clone(); - ASSERT_EQ(65536 * 2, clone->getCurrentSize()); - ASSERT_EQ(0, clone->getCurrentOffset()); - ASSERT_EQ(0, clone->getSpilledSize()); + ASSERT_EQ(65536 * 2, clone->getCurrentSize()); + ASSERT_EQ(0, clone->getCurrentOffset()); + ASSERT_EQ(0, clone->getSpilledSize()); - delete clone; + delete clone; } TEST_F(WorkspaceTests, Test_Arrays_1) { - Workspace ws(65536); - LaunchContext ctx; - ctx.setWorkspace(&ws); - - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}, &ctx); + Workspace ws(65536); + LaunchContext ctx; + ctx.setWorkspace(&ws); - // x.printIndexedBuffer("x0"); + auto x = NDArrayFactory::create('c', {3, 3}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}, &ctx); - auto y = NDArrayFactory::create('c', {3, 3}, {-1, -2, -3, -4, -5, -6, -7, -8, -9}, &ctx); + // x.printIndexedBuffer("x0"); - // x.printIndexedBuffer("x2"); + auto y = NDArrayFactory::create( + 'c', {3, 3}, {-1, -2, -3, -4, -5, -6, -7, -8, -9}, &ctx); - auto z = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 0, 0, 0, 0, 0, 0}, &ctx); + // x.printIndexedBuffer("x2"); - MmulHelper::mmul(&x, &y, &z); + auto z = NDArrayFactory::create('c', {3, 3}, + {0, 0, 0, 0, 0, 0, 0, 0, 0}, &ctx); - y.assign(&x); + MmulHelper::mmul(&x, &y, &z); + y.assign(&x); - // x.printIndexedBuffer("x3"); - // y.printIndexedBuffer("y"); - // z.printIndexedBuffer("z"); + // x.printIndexedBuffer("x3"); + // y.printIndexedBuffer("y"); + // z.printIndexedBuffer("z"); } #ifdef GRAPH_FILES_OK TEST_F(WorkspaceTests, Test_Graph_1) { - auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); - auto workspace = graph->variableSpace()->workspace(); + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); + auto workspace = graph->variableSpace()->workspace(); - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); - delete graph; + delete graph; } #endif TEST_F(WorkspaceTests, Test_Externalized_1) { - if (!Environment::getInstance()->isCPU()) - return; + if (!Environment::getInstance()->isCPU()) return; - char buffer[10000]; - ExternalWorkspace pojo((Nd4jPointer) buffer, 10000, nullptr, 0); + char buffer[10000]; + ExternalWorkspace pojo((Nd4jPointer)buffer, 10000, nullptr, 0); - ASSERT_EQ(10000, pojo.sizeHost()); - ASSERT_EQ(0, pojo.sizeDevice()); + ASSERT_EQ(10000, pojo.sizeHost()); + ASSERT_EQ(0, pojo.sizeDevice()); - Workspace ws(&pojo); - ASSERT_EQ(10000, ws.getCurrentSize()); - ASSERT_EQ(10000, ws.getAllocatedSize()); - LaunchContext ctx; - ctx.setWorkspace(&ws); + Workspace ws(&pojo); + ASSERT_EQ(10000, ws.getCurrentSize()); + ASSERT_EQ(10000, ws.getAllocatedSize()); + LaunchContext ctx; + ctx.setWorkspace(&ws); - auto x = NDArrayFactory::create('c', {10, 10}, &ctx); + auto x = NDArrayFactory::create('c', {10, 10}, &ctx); - // only buffer size goes into account - ASSERT_EQ(400, ws.getUsedSize()); - ASSERT_EQ(400, ws.getCurrentOffset()); + // only buffer size goes into account + ASSERT_EQ(400, ws.getUsedSize()); + ASSERT_EQ(400, ws.getCurrentOffset()); - x.assign(2.0); + x.assign(2.0); - float m = x.meanNumber().e(0); - ASSERT_NEAR(2.0f, m, 1e-5); + float m = x.meanNumber().e(0); + ASSERT_NEAR(2.0f, m, 1e-5); } // TODO: uncomment this test once long shapes are introduced @@ -285,5 +273,4 @@ TEST_F(WorkspaceTests, Test_Big_Allocation_1) { } */ - -#endif //LIBND4J_WORKSPACETESTS_H \ No newline at end of file +#endif // LIBND4J_WORKSPACETESTS_H \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cu b/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cu index 6fe157ac87ae..2f04fbfe2b6c 100644 --- a/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cu +++ b/libnd4j/tests_cpu/layers_tests/WorkspaceTests.cu @@ -18,43 +18,46 @@ // @author raver119@gmail.com // -#include "testlayers.h" #include -#include -#include #include +#include +#include + +#include "testlayers.h" using namespace sd; using namespace sd::memory; -class CudaWorkspaceTests : public testing::Test { - -}; +class CudaWorkspaceTests : public testing::Test {}; TEST_F(CudaWorkspaceTests, Basic_Tests_1) { - Workspace workspace(65536, 65536); + Workspace workspace(65536, 65536); - ASSERT_EQ(0, workspace.getCurrentOffset()); - LaunchContext ctx; - ctx.setWorkspace(&workspace); - auto array = NDArrayFactory::create('c', {5, 5}, &ctx); + ASSERT_EQ(0, workspace.getCurrentOffset()); + LaunchContext ctx; + ctx.setWorkspace(&workspace); + auto array = NDArrayFactory::create('c', {5, 5}, &ctx); - ASSERT_EQ(108, workspace.getCurrentOffset()); - ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); + ASSERT_EQ(108, workspace.getCurrentOffset()); + ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); - array.e(0); + array.e(0); - ASSERT_EQ(100, workspace.getCurrentSecondaryOffset()); + ASSERT_EQ(100, workspace.getCurrentSecondaryOffset()); } TEST_F(CudaWorkspaceTests, Basic_Tests_2) { - Workspace workspace(65536, 65536); - - ASSERT_EQ(0, workspace.getCurrentOffset()); - LaunchContext ctx; - ctx.setWorkspace(&workspace); - auto array = NDArrayFactory::create('c', {5, 5}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, &ctx); - - ASSERT_EQ(108, workspace.getCurrentOffset()); - ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); + Workspace workspace(65536, 65536); + + ASSERT_EQ(0, workspace.getCurrentOffset()); + LaunchContext ctx; + ctx.setWorkspace(&workspace); + auto array = NDArrayFactory::create( + 'c', {5, 5}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, + &ctx); + + ASSERT_EQ(108, workspace.getCurrentOffset()); + ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/testinclude.h b/libnd4j/tests_cpu/layers_tests/testinclude.h index b266019a9e31..48b23f3548e3 100644 --- a/libnd4j/tests_cpu/layers_tests/testinclude.h +++ b/libnd4j/tests_cpu/layers_tests/testinclude.h @@ -20,33 +20,38 @@ #ifndef LIBND4J_TESTINCLUDE_H #define LIBND4J_TESTINCLUDE_H -#include "testlayers.h" -#include #include -//https://stackoverflow.com/questions/228005/alternative-to-itoa-for-converting-integer-to-string-c -FORCEINLINE std::string int_array_to_string(Nd4jLong int_array[], Nd4jLong size_of_array) { - std::string returnstring = "["; - for (int temp = 0; temp < size_of_array; temp++) { - returnstring += std::to_string(int_array[temp]); - if(temp < size_of_array - 1) - returnstring += ","; - } - returnstring += "]"; - return returnstring; -} - -FORCEINLINE ::testing::AssertionResult arrsEquals(Nd4jLong n, Nd4jLong *assertion,Nd4jLong *other) { - for(int i = 0; i < n; i++) { - if(assertion[i] != other[i]) { - std::string message = std::string("Failure at index ") + std::to_string(i) + std::string(" assertion: ") + int_array_to_string(assertion,n) + std::string(" and test array ") + int_array_to_string(other,n) + std::string(" is not equal"); - return ::testing::AssertionFailure() << message; - } +#include - } - return ::testing::AssertionSuccess(); +#include "testlayers.h" +// https://stackoverflow.com/questions/228005/alternative-to-itoa-for-converting-integer-to-string-c +FORCEINLINE std::string int_array_to_string(Nd4jLong int_array[], + Nd4jLong size_of_array) { + std::string returnstring = "["; + for (int temp = 0; temp < size_of_array; temp++) { + returnstring += std::to_string(int_array[temp]); + if (temp < size_of_array - 1) returnstring += ","; + } + returnstring += "]"; + return returnstring; } +FORCEINLINE ::testing::AssertionResult arrsEquals(Nd4jLong n, + Nd4jLong *assertion, + Nd4jLong *other) { + for (int i = 0; i < n; i++) { + if (assertion[i] != other[i]) { + std::string message = + std::string("Failure at index ") + std::to_string(i) + + std::string(" assertion: ") + int_array_to_string(assertion, n) + + std::string(" and test array ") + int_array_to_string(other, n) + + std::string(" is not equal"); + return ::testing::AssertionFailure() << message; + } + } + return ::testing::AssertionSuccess(); +} -#endif //LIBND4J_TESTINCLUDE_H +#endif // LIBND4J_TESTINCLUDE_H diff --git a/libnd4j/tests_cpu/layers_tests/testlayers.h b/libnd4j/tests_cpu/layers_tests/testlayers.h index 697e61693e2e..6e4f03fdfcc1 100644 --- a/libnd4j/tests_cpu/layers_tests/testlayers.h +++ b/libnd4j/tests_cpu/layers_tests/testlayers.h @@ -21,20 +21,21 @@ #ifndef LIBND4J_TESTLAYERS_H #define LIBND4J_TESTLAYERS_H -#include -#include -#include -#include +#include +#include +#include #include #include #include -#include -#include -#include +#include #include +#include +#include #include -#include -#include +#include +#include +#include + #include -#endif //LIBND4J_TESTLAYERS_H +#endif // LIBND4J_TESTLAYERS_H From f6939aee78e4473a4864b0328539b37118c16784 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 11 May 2020 15:00:50 +0300 Subject: [PATCH 118/233] formatting Signed-off-by: raver119@gmail.com --- .../ops/declarable/helpers/impl/lstmLayer.cpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp index 6bfca621d6e6..5ec64180f023 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp @@ -685,17 +685,16 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, (*dLdWp)({2 * nOut, 3 * nOut}) += std::move(dLdzo) * (*c); // [nOut] } else if (Wp) { NDArray temp(Wp->ordering(), {nOut}, Wp->dataType(), Wp->getContext()); - (std::move(dLdzi) * (*cI)) - .reduceAlongDimension(reduce::Sum, temp, - {0}); // [bS, nOut] -> reduce -> [nOut] + // [bS, nOut] -> reduce -> [nOut] + (std::move(dLdzi) * (*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); (*dLdWp)({0, nOut}) += temp; - (std::move(dLdzf) * (*cI)) - .reduceAlongDimension(reduce::Sum, temp, - {0}); // [bS, nOut] -> reduce -> [nOut] + + // [bS, nOut] -> reduce -> [nOut] + (std::move(dLdzf) * (*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); (*dLdWp)({nOut, 2 * nOut}) += temp; - (std::move(dLdzo) * (*c)) - .reduceAlongDimension(reduce::Sum, temp, - {0}); // [bS, nOut] -> reduce -> [nOut] + + // [bS, nOut] -> reduce -> [nOut] + (std::move(dLdzo) * (*c)).reduceAlongDimension(reduce::Sum, temp, {0}); (*dLdWp)({2 * nOut, 3 * nOut}) += temp; } } From 31958e6fb07474beac3ad6aab5fd1e374694d073 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 11 May 2020 15:20:50 +0300 Subject: [PATCH 119/233] formatting Signed-off-by: raver119@gmail.com --- libnd4j/include/array/cpu/NDArray.cpp | 2 +- .../include/execution/cuda/LaunchContext.cu | 4 +- libnd4j/include/helpers/Loops.hpp | 6 +- libnd4j/include/helpers/MKLDNNStream.h | 3 +- libnd4j/include/helpers/TAD.h | 18 ++-- .../helpers/benchmark/BroadcastBenchmark.h | 3 +- libnd4j/include/helpers/cpu/MmulHelper.cpp | 10 +- libnd4j/include/helpers/impl/DebugHelper.cpp | 2 +- libnd4j/include/helpers/impl/MmulHelper.cpp | 10 +- libnd4j/include/helpers/shape.h | 4 +- .../include/loops/cuda/broadcasting_bool.cu | 10 +- .../include/loops/cuda/broadcasting_int.cu | 10 +- .../include/loops/cuda/reduce/reduce_bool.cu | 5 +- .../cuda/specials/fillDimensionalIsMax.cu | 5 +- .../include/loops/cuda/summarystatsreduce.cu | 7 +- .../declarable/generic/nn/fusedBatchNorm.cpp | 4 +- .../include/ops/declarable/headers/common.h | 2 +- .../helpers/cpu/convolutions_pooling2d.cpp | 72 ++++++++----- .../helpers/cpu/convolutions_pooling2dBP.cpp | 72 ++++++++----- .../ops/declarable/helpers/cpu/dropout.cpp | 5 +- .../ops/declarable/helpers/cpu/gather.cpp | 3 +- .../helpers/cpu/gatherTransforms.cpp | 3 +- .../ops/declarable/helpers/cpu/lstm.cpp | 3 +- .../ops/declarable/helpers/cuda/batchnorm.cu | 37 ++++--- .../ops/declarable/helpers/cuda/dropout.cu | 5 +- .../helpers/cuda/extract_patches.cu | 3 +- .../declarable/helpers/cuda/image_resize.cu | 18 ++-- .../ops/declarable/helpers/cuda/lstm.cu | 3 +- .../ops/declarable/helpers/cuda/lup.cu | 5 +- .../ops/declarable/helpers/cuda/sg_cb.cu | 10 +- .../ops/declarable/helpers/cuda/stack.cu | 10 +- .../include/ops/declarable/helpers/hamming.h | 2 +- .../declarable/impl/LegacyBroadcastBoolOp.cpp | 19 ++-- .../ops/declarable/impl/LegacyBroadcastOp.cpp | 19 ++-- .../ops/declarable/impl/LegacyReduce3Op.cpp | 22 ++-- .../declarable/impl/LegacyReduceBoolOp.cpp | 16 +-- .../declarable/impl/LegacyReduceFloatOp.cpp | 16 +-- .../declarable/impl/LegacyReduceLongOp.cpp | 22 ++-- .../declarable/impl/LegacyReduceSameOp.cpp | 22 ++-- .../ops/declarable/impl/LegacyStatsOp.cpp | 11 +- .../declarable/platform/mkldnn/lstmLayer.cpp | 6 +- libnd4j/include/system/play.h | 28 ++--- .../layers_tests/CudaBasicsTests1.cu | 28 ++--- .../layers_tests/DeclarableOpsTests10.cpp | 10 +- .../layers_tests/DeclarableOpsTests13.cpp | 8 +- .../layers_tests/DeclarableOpsTests7.cpp | 8 +- .../layers_tests/DeclarableOpsTests9.cpp | 28 ++--- libnd4j/tests_cpu/layers_tests/EmptyTests.cpp | 5 +- .../layers_tests/MultiDataTypeTests.cpp | 4 +- .../layers_tests/NDArrayCudaBasicsTests.cu | 102 +++++++++--------- .../tests_cpu/layers_tests/NDArrayTests.cpp | 6 +- libnd4j/tests_cpu/layers_tests/RNGTests.cpp | 2 +- 52 files changed, 419 insertions(+), 319 deletions(-) diff --git a/libnd4j/include/array/cpu/NDArray.cpp b/libnd4j/include/array/cpu/NDArray.cpp index 03d74f2e9902..b9aa5ec6fb1f 100644 --- a/libnd4j/include/array/cpu/NDArray.cpp +++ b/libnd4j/include/array/cpu/NDArray.cpp @@ -351,7 +351,7 @@ void NDArray::tile(const std::vector& reps, NDArray& target) const { const auto targetLen = target.lengthOf(); if (target.ordering() == 'c' && ews == 1) { // ews == 1 always here //#pragma omp parallel for simd if(targetLen > - //Environment::getInstance()->elementwiseThreshold()) schedule(guided) + // Environment::getInstance()->elementwiseThreshold()) schedule(guided) for (Nd4jLong i = 0; i < targetLen; ++i) { auto yOffset = shape::subArrayOffset(i, target.shapeInfo(), shapeInfo()); BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), diff --git a/libnd4j/include/execution/cuda/LaunchContext.cu b/libnd4j/include/execution/cuda/LaunchContext.cu index 8106d654de3c..78a78eebe727 100644 --- a/libnd4j/include/execution/cuda/LaunchContext.cu +++ b/libnd4j/include/execution/cuda/LaunchContext.cu @@ -43,8 +43,8 @@ LaunchContext::LaunchContext(cudaStream_t* cudaStream, //_cudaStream = cudaStream; //_cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; //*_cudaSpecialStream = specialCudaStream; _reductionPointer = - //reductionPointer; _scalarPointer = scalarPointer; _allocationPointer = - //allocationPointer; + // reductionPointer; _scalarPointer = scalarPointer; _allocationPointer = + // allocationPointer; _workspace = nullptr; _isAllocated = false; } diff --git a/libnd4j/include/helpers/Loops.hpp b/libnd4j/include/helpers/Loops.hpp index 24d08221805e..ee849b654a44 100644 --- a/libnd4j/include/helpers/Loops.hpp +++ b/libnd4j/include/helpers/Loops.hpp @@ -28,9 +28,9 @@ namespace sd {} // template void Loops::loopReduce(const double* x, const -// Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, double* z, const Nd4jLong* -// zShapeInfo, double* extraParams, std::function -// startVal, std::function update, +// Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, double* z, const +// Nd4jLong* zShapeInfo, double* extraParams, std::function startVal, std::function update, // std::function op, // std::function postPr); template void // Loops::loopReduce(const float* x, const Nd4jLong* tadShapeInfo, diff --git a/libnd4j/include/helpers/MKLDNNStream.h b/libnd4j/include/helpers/MKLDNNStream.h index 0b5caa28ceda..2db2cdac5077 100644 --- a/libnd4j/include/helpers/MKLDNNStream.h +++ b/libnd4j/include/helpers/MKLDNNStream.h @@ -27,9 +27,10 @@ #if defined(HAVE_MKLDNN) +#include + #include #include -#include namespace sd { class MKLDNNStream { diff --git a/libnd4j/include/helpers/TAD.h b/libnd4j/include/helpers/TAD.h index 40d811701bce..2f763b0c8c4b 100644 --- a/libnd4j/include/helpers/TAD.h +++ b/libnd4j/include/helpers/TAD.h @@ -109,9 +109,9 @@ class TAD { * This method is for GPU mostly, it allows to initialize TAD instance * with precalculated tadOnlyShapeInfo */ - INLINEDEF void - initWithExternalTAD(Nd4jLong *existingTAD, Nd4jLong *originalShape, - int *dimension, int dimensionLength); + INLINEDEF void + initWithExternalTAD(Nd4jLong *existingTAD, Nd4jLong *originalShape, + int *dimension, int dimensionLength); #ifdef __CUDACC__ __host__ __device__ @@ -329,11 +329,13 @@ INLINEDEF void TAD::initWithExternalTAD(Nd4jLong *existingTAD, Nd4jLong ews = shape::elementWiseStride(originalShape); - this->numTads = shape::length(originalShape) / - shape::length(existingTAD); // this->tensorsAlongDimension(this->shapeInfo, - // this->dimension, - // this->dimensionLength);//shape::length(originalShape) - // / shape::length(existingTAD); + this->numTads = + shape::length(originalShape) / + shape::length( + existingTAD); // this->tensorsAlongDimension(this->shapeInfo, + // this->dimension, + // this->dimensionLength);//shape::length(originalShape) + // / shape::length(existingTAD); this->wholeThing = this->numTads == 1 || ((this->dimensionLength == this->rank || this->numTads == shape::length(this->shapeInfo)) && diff --git a/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h b/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h index 0d5af63a5ecb..a4f51f09e4ae 100644 --- a/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h +++ b/libnd4j/include/helpers/benchmark/BroadcastBenchmark.h @@ -18,9 +18,10 @@ // @author Alex Black // -#include "../OpBenchmark.h" #include +#include "../OpBenchmark.h" + #ifndef SD_BROADCASTBENCHMARK_H #define SD_BROADCASTBENCHMARK_H diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp index 043ddf1725f7..90e48e353e56 100644 --- a/libnd4j/include/helpers/cpu/MmulHelper.cpp +++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp @@ -784,12 +784,12 @@ double beta, void* vY, const int incy) { */ // BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool -// transA, const bool transB, const int M, const int N, const int K, const double -// alpha, const void* A, const int lda, const void* B, const int ldb, const -// double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, +// transA, const bool transB, const int M, const int N, const int K, const +// double alpha, const void* A, const int lda, const void* B, const int ldb, +// const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, // FLOAT_TYPES); BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char -// aOrder, const int M, const int N, const double alpha, const void* A, const int -// lda, const void* B, const int incx, const double beta, void* C, const int +// aOrder, const int M, const int N, const double alpha, const void* A, const +// int lda, const void* B, const int incx, const double beta, void* C, const int // incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); // BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const // double alpha, const void* vX, const Nd4jLong incx, const void* vY, const diff --git a/libnd4j/include/helpers/impl/DebugHelper.cpp b/libnd4j/include/helpers/impl/DebugHelper.cpp index 67d8c4816aad..cd8df4261971 100644 --- a/libnd4j/include/helpers/impl/DebugHelper.cpp +++ b/libnd4j/include/helpers/impl/DebugHelper.cpp @@ -79,7 +79,7 @@ void DebugHelper::retrieveDebugStatistics(DebugInfo* info, _meanValue += current; //_meanValue += delta / n; // this is a perfect formula but not working - //with omp in this notation _stdDevValue += delta2 * e / n; + // with omp in this notation _stdDevValue += delta2 * e / n; _zeroCount += sd::math::nd4j_abs(current) > 0.00001 ? 0 : 1; _positiveCount += current > 0 ? 1 : 0; diff --git a/libnd4j/include/helpers/impl/MmulHelper.cpp b/libnd4j/include/helpers/impl/MmulHelper.cpp index f56755800b4d..829484ded58b 100644 --- a/libnd4j/include/helpers/impl/MmulHelper.cpp +++ b/libnd4j/include/helpers/impl/MmulHelper.cpp @@ -381,12 +381,12 @@ void MmulHelper::matmul(const sd::NDArray* x, const sd::NDArray* y, } // BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool -// transA, const bool transB, const int M, const int N, const int K, const double -// alpha, const void* A, const int lda, const void* B, const int ldb, const -// double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, +// transA, const bool transB, const int M, const int N, const int K, const +// double alpha, const void* A, const int lda, const void* B, const int ldb, +// const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, // FLOAT_TYPES); BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char -// aOrder, const int M, const int N, const double alpha, const void* A, const int -// lda, const void* B, const int incx, const double beta, void* C, const int +// aOrder, const int M, const int N, const double alpha, const void* A, const +// int lda, const void* B, const int incx, const double beta, void* C, const int // incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); // BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const // double alpha, const void* vX, const Nd4jLong incx, const void* vY, const diff --git a/libnd4j/include/helpers/shape.h b/libnd4j/include/helpers/shape.h index 0013314a787d..d212f70c0f30 100644 --- a/libnd4j/include/helpers/shape.h +++ b/libnd4j/include/helpers/shape.h @@ -3859,8 +3859,8 @@ INLINEDEF _CUDA_HD int rearMostLeftOverItem(Nd4jLong *data, Nd4jLong *dimension, int dimensionLength) { Nd4jLong *stride = shape::stride(data); // corner case: return the final item when its greater than the max, since its - // guaranteed to be left over note here that strides are interpreted in reverse - // for tad start from the front rather than the back + // guaranteed to be left over note here that strides are interpreted in + // reverse for tad start from the front rather than the back int rank = shape::rank(data); diff --git a/libnd4j/include/loops/cuda/broadcasting_bool.cu b/libnd4j/include/loops/cuda/broadcasting_bool.cu index 69c1352686bb..2f6a08b4b9c4 100644 --- a/libnd4j/include/loops/cuda/broadcasting_bool.cu +++ b/libnd4j/include/loops/cuda/broadcasting_bool.cu @@ -204,8 +204,9 @@ __device__ void BroadcastBool::transformInverseCuda( __shared__ Nd4jLong zEWS; if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, - // dimension, dimensionLength); + tadLength = + shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); numTads = shape::length(yShapeInfo) / tadLength; xEWS = shape::elementWiseStride(xShapeInfo); @@ -263,8 +264,9 @@ __device__ void BroadcastBool::transformCuda( __shared__ Nd4jLong zEWS; if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, - // dimension, dimensionLength); + tadLength = + shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); numTads = shape::length(xShapeInfo) / tadLength; yEWS = shape::elementWiseStride(yShapeInfo); diff --git a/libnd4j/include/loops/cuda/broadcasting_int.cu b/libnd4j/include/loops/cuda/broadcasting_int.cu index 452cdfc654ba..a7e92732e9c1 100644 --- a/libnd4j/include/loops/cuda/broadcasting_int.cu +++ b/libnd4j/include/loops/cuda/broadcasting_int.cu @@ -190,8 +190,9 @@ __device__ void BroadcastInt::transformInverseCuda( __shared__ Nd4jLong zEWS; if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, - // dimension, dimensionLength); + tadLength = + shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); numTads = shape::length(yShapeInfo) / tadLength; xEWS = shape::elementWiseStride(xShapeInfo); @@ -248,8 +249,9 @@ __device__ void BroadcastInt::transformCuda( __shared__ Nd4jLong zEWS; if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, - // dimension, dimensionLength); + tadLength = + shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, + // dimension, dimensionLength); tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); numTads = shape::length(xShapeInfo) / tadLength; yEWS = shape::elementWiseStride(yShapeInfo); diff --git a/libnd4j/include/loops/cuda/reduce/reduce_bool.cu b/libnd4j/include/loops/cuda/reduce/reduce_bool.cu index 30cf6ce6ad4f..117e116693fc 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_bool.cu +++ b/libnd4j/include/loops/cuda/reduce/reduce_bool.cu @@ -117,8 +117,9 @@ __device__ void ReduceBoolFunction::transformCudaXD( isPlainOutput = shape::order(zShapeInfo) == 'c' && shape::elementWiseStride(zShapeInfo) == 1; - tadLength = shape::length(tadOnlyShapeInfo); // tadLength(xShapeInfo, - // dimension, dimensionLength); + tadLength = + shape::length(tadOnlyShapeInfo); // tadLength(xShapeInfo, + // dimension, dimensionLength); numTads = shape::length(xShapeInfo) / tadLength; } __syncthreads(); diff --git a/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu b/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu index 4e2b31db53c2..3b278e67b7f2 100644 --- a/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu +++ b/libnd4j/include/loops/cuda/specials/fillDimensionalIsMax.cu @@ -38,8 +38,9 @@ __device__ void fillDimensionalIsMax(const void *vdX, void *vdZ, __shared__ int numTads; if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo); // shape::tadLength(zShapeInfo, - // dimension, dimensionLength); + tadLength = + shape::length(tadOnlyShapeInfo); // shape::tadLength(zShapeInfo, + // dimension, dimensionLength); tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); numTads = shape::length(zShapeInfo) / tadLength; } diff --git a/libnd4j/include/loops/cuda/summarystatsreduce.cu b/libnd4j/include/loops/cuda/summarystatsreduce.cu index d7acde30e8d7..231eef43eda0 100644 --- a/libnd4j/include/loops/cuda/summarystatsreduce.cu +++ b/libnd4j/include/loops/cuda/summarystatsreduce.cu @@ -249,9 +249,10 @@ _CUDA_D void SummaryStatsReduce::transform( __syncthreads(); if (threadIdx.x == 0) { - z[i] = OpType::getValue(postProcessOrNot, - sPartials[threadIdx.x]); // postProcess(sPartials[0],tadLength - // ,extraParams); + z[i] = OpType::getValue( + postProcessOrNot, + sPartials[threadIdx.x]); // postProcess(sPartials[0],tadLength + // ,extraParams); } } } diff --git a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp index a5ff3038e7d7..99d76e992510 100644 --- a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp @@ -84,8 +84,8 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) { iD, ShapeUtils::shapeAsString(variance).c_str()); } else { // REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when - // isTraining=true then number of input arrays must be equal to 3, but got %i - // instead !", block.width()); + // isTraining=true then number of input arrays must be equal to 3, but got + // %i instead !", block.width()); std::vector shape = {iD}; mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()); diff --git a/libnd4j/include/ops/declarable/headers/common.h b/libnd4j/include/ops/declarable/headers/common.h index 5c0624a8f97b..bde901a61049 100644 --- a/libnd4j/include/ops/declarable/headers/common.h +++ b/libnd4j/include/ops/declarable/headers/common.h @@ -28,8 +28,8 @@ #include #include #include -#include #include +#include #include #include #include diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp index 7653d985624f..a6942a7d30a8 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2d.cpp @@ -80,20 +80,28 @@ static void pooling2d_(sd::graph::Context& block, const NDArray& input, if (hstart < 0) hstart += - dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) - // / static_cast(dH)); + dH * + ((-hstart + dH - 1) / + dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); if (wstart < 0) wstart += - dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) - /// static_cast(dW)); + dW * + ((-wstart + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); if (hend > iH) hend -= - dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) - /// static_cast(dH)); + dH * + ((hend - iH + dH - 1) / + dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); if (wend > iW) wend -= - dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) - /// static_cast(dW)); + dW * + ((wend - iW + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); hstart *= iStride2; hend *= iStride2; @@ -136,20 +144,28 @@ static void pooling2d_(sd::graph::Context& block, const NDArray& input, if (hstart < 0) hstart += - dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) - // / static_cast(dH)); + dH * + ((-hstart + dH - 1) / + dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); if (wstart < 0) wstart += - dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) - /// static_cast(dW)); + dW * + ((-wstart + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); if (hend > iH) hend -= - dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) - /// static_cast(dH)); + dH * + ((hend - iH + dH - 1) / + dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); if (wend > iW) wend -= - dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) - /// static_cast(dW)); + dW * + ((wend - iW + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); hstart *= iStride2; hend *= iStride2; @@ -200,20 +216,28 @@ static void pooling2d_(sd::graph::Context& block, const NDArray& input, if (hstart < 0) hstart += - dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) - // / static_cast(dH)); + dH * + ((-hstart + dH - 1) / + dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); if (wstart < 0) wstart += - dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) - /// static_cast(dW)); + dW * + ((-wstart + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); if (hend > iH) hend -= - dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) - /// static_cast(dH)); + dH * + ((hend - iH + dH - 1) / + dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); if (wend > iW) wend -= - dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) - /// static_cast(dW)); + dW * + ((wend - iW + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); hstart *= iStride2; hend *= iStride2; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp index 7dd18e530e6f..a6874e8f82a9 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions_pooling2dBP.cpp @@ -94,20 +94,28 @@ static void pooling2dBP_(sd::graph::Context& block, const NDArray& input, if (hstart < 0) hstart += - dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) - // / static_cast(dH)); + dH * + ((-hstart + dH - 1) / + dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); if (wstart < 0) wstart += - dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) - /// static_cast(dW)); + dW * + ((-wstart + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); if (hend > iH) hend -= - dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) - /// static_cast(dH)); + dH * + ((hend - iH + dH - 1) / + dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); if (wend > iW) wend -= - dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) - /// static_cast(dW)); + dW * + ((wend - iW + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); sum = -DataTypeUtils::max(); valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + @@ -178,20 +186,28 @@ static void pooling2dBP_(sd::graph::Context& block, const NDArray& input, if (hstart < 0) hstart += - dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) - // / static_cast(dH)); + dH * + ((-hstart + dH - 1) / + dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); if (wstart < 0) wstart += - dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) - /// static_cast(dW)); + dW * + ((-wstart + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); if (hend > iH) hend -= - dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) - /// static_cast(dH)); + dH * + ((hend - iH + dH - 1) / + dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); if (wend > iW) wend -= - dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) - /// static_cast(dW)); + dW * + ((wend - iW + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); hstart *= gIStride2; hend *= gIStride2; @@ -244,20 +260,28 @@ static void pooling2dBP_(sd::graph::Context& block, const NDArray& input, if (hstart < 0) hstart += - dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) - // / static_cast(dH)); + dH * + ((-hstart + dH - 1) / + dH); // (Nd4jLong)sd::math::nd4j_ceil(static_cast(-hstart) + // / static_cast(dH)); if (wstart < 0) wstart += - dW * ((-wstart + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) - /// static_cast(dW)); + dW * + ((-wstart + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(-wstart) + /// static_cast(dW)); if (hend > iH) hend -= - dH * ((hend - iH + dH - 1) / dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) - /// static_cast(dH)); + dH * + ((hend - iH + dH - 1) / + dH); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(hend-iH) + /// static_cast(dH)); if (wend > iW) wend -= - dW * ((wend - iW + dW - 1) / dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) - /// static_cast(dW)); + dW * + ((wend - iW + dW - 1) / + dW); //(Nd4jLong)sd::math::nd4j_ceil(static_cast(wend-iW) + /// static_cast(dW)); sum = static_cast(0.f); valO = gO[b * oStride0 + c * oStride1 + oh * oStride2 + diff --git a/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp b/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp index 8041770a3749..d8b04295e3d1 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/dropout.cpp @@ -176,8 +176,9 @@ int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, probValue, alpha, alpha1, beta); if (res == ND4J_STATUS_OK) { (*output) *= alpha; - (*output) *= (*gradOut); //->applyPairwiseTransform(gradOut, - //output, nullptr); + (*output) *= + (*gradOut); //->applyPairwiseTransform(gradOut, + // output, nullptr); } return res; } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp index 923d3eac6e68..5a562c036b69 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/gather.cpp @@ -44,7 +44,8 @@ void gather(sd::LaunchContext* context, const NDArray* input, if (indices->isScalar()) { if (input->rankOf() <= 1) { // For scalar indices, rank 0 or 1 input: can't do tensor along - // dimension 0 as this is whole array... instead, we want to get a scalar + // dimension 0 as this is whole array... instead, we want to get a + // scalar auto idx = indices->e(0); auto scalarNDArray = input->e(idx); output->assign(scalarNDArray); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gatherTransforms.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gatherTransforms.cpp index b858ee9d7b37..cffa74461bb9 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/gatherTransforms.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/gatherTransforms.cpp @@ -111,7 +111,8 @@ static void gather_(NDArray* input, const NDArray* indices, NDArray* output, if (indices->isScalar()) { if (input->rankOf() <= 1) { // For scalar indices, rank 0 or 1 input: can't do tensor along - // dimension 0 as this is whole array... instead, we want to get a scalar + // dimension 0 as this is whole array... instead, we want to get a + // scalar auto idx = indices->e(0); auto scalarNDArray = input->e(idx); output->assign(scalarNDArray); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp index 40bc335aff59..58eab23237d9 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp @@ -217,7 +217,8 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, m += (*b); // addiRowVector // Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] - // to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o]) + // to match TF (TF code comments state [i,f,z/ci,o] but behaviour is + // [i,z,f,o]) auto zi = m({0, 0, 0, nOut}); // z for input modulation gate, [bS, nOut] auto zz = m({0, 0, nOut, 2 * nOut}); // z for block input, [bS, nOut] auto zf = m({0, 0, 2 * nOut, 3 * nOut}); // z for forget gate, [bS, nOut] diff --git a/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu b/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu index 130426548d04..a73a9a1ecb66 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu @@ -31,12 +31,16 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////// // template // __global__ static void batchnormCuda(const void* vx, const Nd4jLong* -// xShapeInfo, const void* vMean, const Nd4jLong* meanShapeInfo, const void* -// vVariance, const Nd4jLong* varianceShapeInfo, const void* vGamma, const -// Nd4jLong* gammaShapeInfo, const void* vBeta, const Nd4jLong* betaShapeInfo, -// void* vz, const Nd4jLong* -// zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const -// Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, const T epsilon) { +// xShapeInfo, const void* vMean, +// const Nd4jLong* meanShapeInfo, +// const void* +// vVariance, const Nd4jLong* varianceShapeInfo, +// const void* vGamma, const Nd4jLong* gammaShapeInfo, +// const void* vBeta, const Nd4jLong* betaShapeInfo, void* vz, const Nd4jLong* +// zShapeInfo, const Nd4jLong* +// xTadShapeInfo, const Nd4jLong* xTadOffsets, +// const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, +// const T epsilon) { // const auto x = reinterpret_cast(vx); // auto z = reinterpret_cast(vz); @@ -164,16 +168,21 @@ __global__ static void batchnormCuda2( /////////////////////////////////////////////////////////////////// // template // __host__ static void batchnormCudaLauncher(const int blocksPerGrid, const int -// threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* -// xShapeInfo, +// threadsPerBlock, const cudaStream_t *stream, +// const void* vx, const Nd4jLong* xShapeInfo, // const void* vMean, const // Nd4jLong* meanShapeInfo, -// const void* vVariance, -// const Nd4jLong* varianceShapeInfo, const void* vGamma, const Nd4jLong* -// gammaShapeInfo, const void* vBeta, const Nd4jLong* betaShapeInfo, void* vz, -// const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* -// xTadOffsets, const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, -// const double epsilon) +// const void* +// vVariance, +// const Nd4jLong* varianceShapeInfo, +// const void* vGamma, const Nd4jLong* +// gammaShapeInfo, const +// void* vBeta, const Nd4jLong* betaShapeInfo, +// void* vz, +// const Nd4jLong* zShapeInfo, +// const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, +// const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, const double +// epsilon) // { // batchnormCuda<<>>(vx, diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu index e35e50da000d..6c5e1691eb09 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu @@ -321,8 +321,9 @@ int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, if (res == ND4J_STATUS_OK) { // FIXME: can we make it single-loop? (*output) *= alpha; - (*output) *= (*gradOut); //->applyPairwiseTransform(gradOut, - //output, nullptr); + (*output) *= + (*gradOut); //->applyPairwiseTransform(gradOut, + // output, nullptr); } return res; } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/extract_patches.cu b/libnd4j/include/ops/declarable/helpers/cuda/extract_patches.cu index 377a734dc218..1fb472885eea 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/extract_patches.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/extract_patches.cu @@ -61,7 +61,8 @@ static __global__ void globalExtractPatchesKernel( // outColDim/outRowDim for (Nd4jLong batch = start; batch < batchCount; batch += step) { auto patch = input + inputOffsets[batch]; // listOfMatricies->at(batch); - auto outMatrix = output + outputOffsets[batch]; // listOfOutputs->at(batch); + auto outMatrix = + output + outputOffsets[batch]; // listOfOutputs->at(batch); for (Nd4jLong i = 0; i < outRowDim; i++) { for (Nd4jLong j = 0; j < outColDim; j++) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu index c85be400238c..5547d94ef007 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu @@ -177,10 +177,11 @@ static void resizeImage_(sd::LaunchContext* context, NDArray const* images, Nd4jLong inBatchNumValues = inHeight * inRowSize; Nd4jLong outRowSize = outWidth * channels; auto stream = context->getCudaStream(); - T const* pInput = images->getDataBuffer() - ->specialAsT(); // reinterpret_cast(images->specialBuffer()); // - // this works only with 'c' direction + T const* pInput = + images->getDataBuffer() + ->specialAsT(); // reinterpret_cast(images->specialBuffer()); // + // this works only with 'c' direction F* pOutput = output->dataBuffer() ->specialAsT(); // reinterpret_cast(output->specialBuffer()); @@ -1274,7 +1275,8 @@ static __global__ void resizeAreaKernel( // auto startPtr = sharedPtr + y * scalesDim * sizeof(float); // float* yScales = yScalesShare + y * sizeof(float) * - // scalesDim;//reinterpret_cast(startPtr); //shared + y * scalesDim + // scalesDim;//reinterpret_cast(startPtr); //shared + y * + // scalesDim // * y + scalesDim * sizeof(T const *) [scalesDim]; T const** yPtrs = // yPtrsShare + y * sizeof(T const*) * scalesDim; //[scalesDim]; yPtrs = // reinterpret_cast(sharedBuf); @@ -1520,13 +1522,13 @@ static __global__ void cropAndResizeKernel( topLeftPos)]); //->e(bIn, topYIndex, left_x_index, d)); const T topRight(images[shape::getOffset( imagesShape, topRightPos)]); //->e(bIn, topYIndex, - //right_x_index, d)); + // right_x_index, d)); const T bottomLeft(images[shape::getOffset( imagesShape, bottomLeftPos)]); //->e(bIn, bottomYIndex, - //left_x_index, d)); + // left_x_index, d)); const T bottomRight(images[shape::getOffset( imagesShape, bottomRightPos)]); //->e(bIn, bottomYIndex, - //right_x_index, d)); + // right_x_index, d)); const T top = topLeft + (topRight - topLeft) * x_lerp; const T bottom = bottomLeft + (bottomRight - bottomLeft) * x_lerp; Nd4jLong zPos[] = {b, y, x, d}; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu b/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu index 619dc76a0b79..e58944c679b7 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lstm.cu @@ -184,7 +184,8 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, m += (*b); // addiRowVector // Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] - // to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o]) + // to match TF (TF code comments state [i,f,z/ci,o] but behaviour is + // [i,z,f,o]) auto zi = m({0, 0, 0, nOut}); // z for input modulation gate, [bS, nOut] auto zz = m({0, 0, nOut, 2 * nOut}); // z for block input, [bS, nOut] auto zf = m({0, 0, 2 * nOut, 3 * nOut}); // z for forget gate, [bS, nOut] diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index 13a3d3329539..3257c08dbb94 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -1063,8 +1063,9 @@ int logdetFunctor_(sd::LaunchContext *context, NDArray *input, cholesky(context, input, &tempOutput, false); auto outputBuf = - output->dataBuffer()->specialAsT(); // reinterpret_cast(output->specialBuffer()); - // // + e * n2; // + e * n2; + output->dataBuffer() + ->specialAsT(); // reinterpret_cast(output->specialBuffer()); + // // + e * n2; // + e * n2; auto inputBuf = tempOutput.dataBuffer() ->specialAsT< diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu b/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu index 99e6c2b8d1e1..cbec9bc8d07f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu @@ -731,11 +731,13 @@ void cbowBatchExec_(LaunchContext *lc, NDArray &s0, NDArray &s1, NDArray &s1n, const auto bIndices = indices.dataBuffer()->primaryAsT(); // buffer());//AsT(); const auto bCodes = - codes.dataBuffer()->primaryAsT(); // reinterpret_cast(codes.buffer()); - // //bufferAsT(); + codes.dataBuffer() + ->primaryAsT(); // reinterpret_cast(codes.buffer()); + // //bufferAsT(); const auto bStarters = - negStarters.dataBuffer()->primaryAsT(); // reinterpret_cast(negStarters.buffer()); - // //AsT(); + negStarters.dataBuffer() + ->primaryAsT(); // reinterpret_cast(negStarters.buffer()); + // //AsT(); const auto numIndices = indices.isEmpty() ? 0 : indices.sizeAt(1); lr.syncToHost(); nLabels.syncToHost(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/stack.cu b/libnd4j/include/ops/declarable/helpers/cuda/stack.cu index f899da6db46d..fc27122c6818 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/stack.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/stack.cu @@ -265,8 +265,9 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, // /////////////////////////////////////////////////////////////////// // template // __host__ static void unstackCudaLauncher(const int blocksPerGrid, const int -// threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* -// xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) { +// threadsPerBlock, const cudaStream_t *stream, +// const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* +// zTadShapeInfo, const int axis) { // unstackCuda<<>>(vx, // xShapeInfo, pVz, zTadShapeInfo, axis); @@ -354,8 +355,9 @@ BUILD_SINGLE_TEMPLATE(template void unstack_, // /////////////////////////////////////////////////////////////////// // template // __host__ static void stackCudaLauncher(const int blocksPerGrid, const int -// threadsPerBlock, const cudaStream_t *stream, void* pVx, const Nd4jLong* -// xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) { +// threadsPerBlock, const cudaStream_t *stream, +// void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* +// zShapeInfo, const int axis) { // stackCuda<<>>(pVx, // xTadShapeInfo, vz, zShapeInfo, axis); diff --git a/libnd4j/include/ops/declarable/helpers/hamming.h b/libnd4j/include/ops/declarable/helpers/hamming.h index 78e17c2dda3a..77d0aa6df35b 100644 --- a/libnd4j/include/ops/declarable/helpers/hamming.h +++ b/libnd4j/include/ops/declarable/helpers/hamming.h @@ -21,8 +21,8 @@ #ifndef SD_HAMMING_H #define SD_HAMMING_H -#include #include +#include namespace sd { namespace ops { diff --git a/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp index 73b768bc777c..be445baa65ac 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyBroadcastBoolOp.cpp @@ -48,14 +48,14 @@ Nd4jStatus LegacyBroadcastBoolOp::validateAndExecute(Context &block) { ? packX.primaryShapeInfo() : packX .specialShapeInfo(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOnlyShapeInfo, - //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOffsets, - //tad.numTads * sizeof(Nd4jLong)); + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); REQUIRE_TRUE(shape::length(packX.primaryShapeInfo()) == y->lengthOf(), 0, "Length of broadcast TAD should be equal to length of Y " @@ -82,14 +82,15 @@ Nd4jStatus LegacyBroadcastBoolOp::validateAndExecute(Context &block) { ? packZ.primaryShapeInfo() : packZ .specialShapeInfo(); //(Nd4jLong *) - //manager.replicatePointer(tadZ.tadOnlyShapeInfo, - //shape::shapeInfoByteLength(tadZ.tadOnlyShapeInfo)); + // manager.replicatePointer(tadZ.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tadZ.tadOnlyShapeInfo)); auto zTadOffsets = Environment::getInstance()->isCPU() ? packZ.primaryOffsets() - : packZ.specialOffsets(); //(Nd4jLong *) - //manager.replicatePointer(tadZ.tadOffsets, - //tadZ.numTads * sizeof(Nd4jLong)); + : packZ + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tadZ.tadOffsets, + // tadZ.numTads * sizeof(Nd4jLong)); NativeOpExecutioner::execBroadcast( block.launchContext(), opNum, x->buffer(), x->shapeInfo(), diff --git a/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp index 516a88155d97..7d9d3a5ed3e6 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyBroadcastOp.cpp @@ -60,14 +60,14 @@ Nd4jStatus LegacyBroadcastOp::validateAndExecute(Context &block) { ? packX.primaryShapeInfo() : packX .specialShapeInfo(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOnlyShapeInfo, - //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() : packX.specialOffsets(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOffsets, - //tad.numTads * sizeof(Nd4jLong)); + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); if (x == z) NativeOpExecutioner::execBroadcast( @@ -87,14 +87,15 @@ Nd4jStatus LegacyBroadcastOp::validateAndExecute(Context &block) { ? packZ.primaryShapeInfo() : packZ .specialShapeInfo(); //(Nd4jLong *) - //manager.replicatePointer(tadZ.tadOnlyShapeInfo, - //shape::shapeInfoByteLength(tadZ.tadOnlyShapeInfo)); + // manager.replicatePointer(tadZ.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tadZ.tadOnlyShapeInfo)); auto zTadOffsets = Environment::getInstance()->isCPU() ? packZ.primaryOffsets() - : packZ.specialOffsets(); //(Nd4jLong *) - //manager.replicatePointer(tadZ.tadOffsets, - //tadZ.numTads * sizeof(Nd4jLong)); + : packZ + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tadZ.tadOffsets, + // tadZ.numTads * sizeof(Nd4jLong)); NativeOpExecutioner::execBroadcast( block.launchContext(), opNum, x->buffer(), x->shapeInfo(), diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp index 52fcd2f46c34..9f7b0b5f71d6 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduce3Op.cpp @@ -68,28 +68,30 @@ Nd4jStatus LegacyReduce3Op::validateAndExecute(Context &block) { ? packX.primaryShapeInfo() : packX .specialShapeInfo(); //(Nd4jLong *) - //manager.replicatePointer(tadX.tadOnlyShapeInfo, - //shape::shapeInfoByteLength(tadX.tadOnlyShapeInfo)); + // manager.replicatePointer(tadX.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tadX.tadOnlyShapeInfo)); auto xTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() - : packX.specialOffsets(); //(Nd4jLong *) - //manager.replicatePointer(tadX.tadOffsets, - //tadX.numTads * sizeof(Nd4jLong)); + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tadX.tadOffsets, + // tadX.numTads * sizeof(Nd4jLong)); auto yTadShape = Environment::getInstance()->isCPU() ? packZ.primaryShapeInfo() : packZ .specialOffsets(); //(Nd4jLong *) - //manager.replicatePointer(tadY.tadOnlyShapeInfo, - //shape::shapeInfoByteLength(tadY.tadOnlyShapeInfo)); + // manager.replicatePointer(tadY.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tadY.tadOnlyShapeInfo)); auto yTadOffsets = Environment::getInstance()->isCPU() ? packZ.primaryOffsets() - : packZ.specialOffsets(); //(Nd4jLong *) - //manager.replicatePointer(tadY.tadOffsets, - //tadY.numTads * sizeof(Nd4jLong)); + : packZ + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tadY.tadOffsets, + // tadY.numTads * sizeof(Nd4jLong)); NativeOpExecutioner::execReduce3( block.launchContext(), opNum, x->buffer(), x->shapeInfo(), diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp index d215bf8da546..fc92555a2116 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceBoolOp.cpp @@ -91,8 +91,9 @@ Nd4jStatus LegacyReduceBoolOp::validateAndExecute(Context &block) { auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() - : packX.specialOffsets(); // manager.replicatePointer(tad.tadOffsets, - // tad.numTads * sizeof(Nd4jLong)); + : packX + .specialOffsets(); // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); NativeOpExecutioner::execReduceBool( block.launchContext(), opNum, x->buffer(), x->shapeInfo(), @@ -140,14 +141,15 @@ Nd4jStatus LegacyReduceBoolOp::validateAndExecute(Context &block) { ? packX.primaryShapeInfo() : packX .specialShapeInfo(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOnlyShapeInfo, - //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() - : packX.specialOffsets(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOffsets, - //tad.numTads * sizeof(Nd4jLong)); + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); NativeOpExecutioner::execReduceBool( block.launchContext(), opNum, x->buffer(), x->shapeInfo(), diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp index 552cfcc049a2..d7484520bd39 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp @@ -91,8 +91,9 @@ Nd4jStatus LegacyReduceFloatOp::validateAndExecute(Context &block) { auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() - : packX.specialOffsets(); // manager.replicatePointer(tad.tadOffsets, - // tad.numTads * sizeof(Nd4jLong)); + : packX + .specialOffsets(); // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); NativeOpExecutioner::execReduceFloat( block.launchContext(), opNum, x->buffer(), x->shapeInfo(), @@ -138,14 +139,15 @@ Nd4jStatus LegacyReduceFloatOp::validateAndExecute(Context &block) { ? packX.primaryShapeInfo() : packX .specialShapeInfo(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOnlyShapeInfo, - //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() - : packX.specialOffsets(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOffsets, - //tad.numTads * sizeof(Nd4jLong)); + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); NativeOpExecutioner::execReduceFloat( block.launchContext(), opNum, x->buffer(), x->shapeInfo(), diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp index 47743b822d5c..0e976984e721 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceLongOp.cpp @@ -88,14 +88,15 @@ Nd4jStatus LegacyReduceLongOp::validateAndExecute(Context &block) { ? packX.primaryShapeInfo() : packX .specialShapeInfo(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOnlyShapeInfo, - //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() - : packX.specialOffsets(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOffsets, - //tad.numTads * sizeof(Nd4jLong)); + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); NativeOpExecutioner::execReduceLong( block.launchContext(), opNum, x->buffer(), x->shapeInfo(), @@ -140,14 +141,15 @@ Nd4jStatus LegacyReduceLongOp::validateAndExecute(Context &block) { ? packX.primaryShapeInfo() : packX .specialShapeInfo(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOnlyShapeInfo, - //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() - : packX.specialOffsets(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOffsets, - //tad.numTads * sizeof(Nd4jLong)); + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); NativeOpExecutioner::execReduceLong( block.launchContext(), opNum, x->buffer(), x->shapeInfo(), diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp index da1a60e3a606..bc252292a15b 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceSameOp.cpp @@ -84,14 +84,15 @@ Nd4jStatus LegacyReduceSameOp::validateAndExecute(Context &block) { ? packX.primaryShapeInfo() : packX .specialShapeInfo(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOnlyShapeInfo, - //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() - : packX.specialOffsets(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOffsets, - //tad.numTads * sizeof(Nd4jLong)); + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); NativeOpExecutioner::execReduceSame( block.launchContext(), opNum, x->buffer(), x->shapeInfo(), @@ -136,14 +137,15 @@ Nd4jStatus LegacyReduceSameOp::validateAndExecute(Context &block) { ? packX.primaryShapeInfo() : packX .specialShapeInfo(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOnlyShapeInfo, - //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() - : packX.specialOffsets(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOffsets, - //tad.numTads * sizeof(Nd4jLong)); + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); NativeOpExecutioner::execReduceSame( block.launchContext(), opNum, x->buffer(), x->shapeInfo(), diff --git a/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp index 9acce2851924..3ad3ed2064c0 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyStatsOp.cpp @@ -70,14 +70,15 @@ Nd4jStatus LegacyStatsOp::validateAndExecute(Context &block) { ? packX.primaryShapeInfo() : packX .specialShapeInfo(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOnlyShapeInfo, - //shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + // manager.replicatePointer(tad.tadOnlyShapeInfo, + // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); auto pTadOffsets = Environment::getInstance()->isCPU() ? packX.primaryOffsets() - : packX.specialOffsets(); //(Nd4jLong *) - //manager.replicatePointer(tad.tadOffsets, - //tad.numTads * sizeof(Nd4jLong)); + : packX + .specialOffsets(); //(Nd4jLong *) + // manager.replicatePointer(tad.tadOffsets, + // tad.numTads * sizeof(Nd4jLong)); NativeOpExecutioner::execSummaryStats( block.launchContext(), opNum, x->buffer(), x->shapeInfo(), diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index 25d2fa1d0961..eda15da540ee 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -681,9 +681,9 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) { && !hasSeqLen // Sequence length array not supported in MKL DNN && dataFormat < 2 // Data format - only 0 and 1 supported in MKL DNN- 0 = // [sL, bS, nIn], 1 = [bS, sL ,nIn] - && directionMode < 4 // Direction mode - only 0-3 supported in MKL DNN - // (no extra dim option) - 0 = fwd, 1 = bwd, 2 = - // bidirectional sum, 3 = bidirectional concat + && directionMode < 4 // Direction mode - only 0-3 supported in MKL DNN + // (no extra dim option) - 0 = fwd, 1 = bwd, 2 = + // bidirectional sum, 3 = bidirectional concat && retLastH == retLastC // Return both lastH and lastC, or return neither // (not just 1 or other) && hasInitH == hasInitC; // Need both or neither initial H and C diff --git a/libnd4j/include/system/play.h b/libnd4j/include/system/play.h index fc525fd0a26a..37a31def95ec 100644 --- a/libnd4j/include/system/play.h +++ b/libnd4j/include/system/play.h @@ -146,7 +146,7 @@ validateAndExecute(Block& block); \ sd::ops::NAME::validateAndExecute(Block& block) */ //#define END_OP(NAME) }; static sd::ops::__registrator> -//register_op##Name; +// register_op##Name; //#DECLARE_OP(Concat, -1, 1) @@ -160,7 +160,7 @@ sd::ops::NAME::validateAndExecute(Block& block) //_EXEC_KERNEL_F(scalarAlongDimension_, scalarAlongDimensionGeneric, float, //(float inputA, float inputB), (paramA, paramB), (10, SCALAR::Add), (11, -//SCALAR::Subtract), (12, SCALAR::Multiply)) +// SCALAR::Subtract), (12, SCALAR::Multiply)) // DISPATCH_KERNEL_SIMPLE(scalarAlongDimension_, scalarAlongDimensionGeneric, // float, INPUT(float inputA, float inputB), PARAMS(paramA, paramB), @@ -178,20 +178,20 @@ opTypeB, N, x, xShape, y, yShape, z, zShape, extrasA, extrasB, scalarA, scalarB), float, OPS_A(PAIRWISE_TRANSFORM_OPS), OPS_B(SCALAR_OPS));*/ // DISPATCH_KERNEL_META(invertedMetaPairwiseShaped_Pairwise_Scalar_, -// invertedMetaPairwiseShapedGeneric, float, simdOps::InvertedMetaOp, INPUT(const -// int opTypeA, const int opTypeB, long N, float *dx, int *xShapeInfo, float *dy, -// int *yShapeInfo, float *dz, int *zShapeInfo, float *extraA, float *extraB, -// float scalarA, float scalarB), PARAMS(opTypeA, opTypeB, N, dx, xShapeInfo, dy, -// yShapeInfo, dz, zShapeInfo, extraA, extraB, scalarA, scalarB), -// OPS_A(PAIRWISE_TRANSFORM_OPS), OPS_B(SCALAR_OPS)) +// invertedMetaPairwiseShapedGeneric, float, simdOps::InvertedMetaOp, +// INPUT(const int opTypeA, const int opTypeB, long N, float *dx, int +// *xShapeInfo, float *dy, int *yShapeInfo, float *dz, int *zShapeInfo, float +// *extraA, float *extraB, float scalarA, float scalarB), PARAMS(opTypeA, +// opTypeB, N, dx, xShapeInfo, dy, yShapeInfo, dz, zShapeInfo, extraA, extraB, +// scalarA, scalarB), OPS_A(PAIRWISE_TRANSFORM_OPS), OPS_B(SCALAR_OPS)) //_EXPAND_KERNEL_CALL(invertedMetaPairwiseShaped_Pairwise_Scalar_, -//invertedMetaPairwiseShapedGeneric, float, simdOps::InvertedMetaOp, INPUT(const -//int opTypeA, const int opTypeB, long N, float *dx, int *xShapeInfo, float *dy, -//int *yShapeInfo, float *dz, int *zShapeInfo, float *extraA, float *extraB, -//float scalarA, float scalarB), PARAMS(N, dx, dy, xStride, yStride, paramsPtr, -//dz, zStride, nullptr, nullptr, nullptr), 66, simdOps::SomeOpA, 99, -//simdOps::SomeOpB) +// invertedMetaPairwiseShapedGeneric, float, simdOps::InvertedMetaOp, +// INPUT(const int opTypeA, const int opTypeB, long N, float *dx, int +// *xShapeInfo, float *dy, int *yShapeInfo, float *dz, int *zShapeInfo, float +// *extraA, float *extraB, float scalarA, float scalarB), PARAMS(N, dx, dy, +// xStride, yStride, paramsPtr, dz, zStride, nullptr, nullptr, nullptr), 66, +// simdOps::SomeOpA, 99, simdOps::SomeOpB) /* extern "C" __global__ void diff --git a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu index 206345ed966c..9efbca4a04d4 100644 --- a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu +++ b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu @@ -3421,14 +3421,18 @@ TEST_F(CudaBasicsTests1, execRandom_1) { // LaunchContext lc(&stream); // // // ::execRandom(extraPointers, random::GaussianDistribution, &gen, - //z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &extra); + // z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + // &extra); // // call cuda kernel which calculates result // NativeOpExecutioner::execRandom(&lc, sd::random::GaussianDistribution, // &gen, - // nullptr, z.shapeInfo(), z.specialBuffer(), - //z.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), - //z.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), - //z.specialShapeInfo(), extraArguments.argumentsAsT(z.dataType())); + // nullptr, z.shapeInfo(), + //z.specialBuffer(), + // z.specialShapeInfo(), nullptr, z.shapeInfo(), + // z.specialBuffer(), + // z.specialShapeInfo(), nullptr, z.shapeInfo(), + // z.specialBuffer(), z.specialShapeInfo(), + // extraArguments.argumentsAsT(z.dataType())); // // cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); // ASSERT_EQ(cudaResult, 0); @@ -3461,8 +3465,8 @@ TEST_F(CudaBasicsTests1, execRandom_2) { // // prepare input arrays for prepareDataForCuda function // std::vector> hostData; // hostData.emplace_back(extraArguments.data(), extraArguments.size() * - //sizeof(double)); // 0 -- dimensions std::vector - //devicePtrs(hostData.size(), nullptr); + // sizeof(double)); // 0 -- dimensions std::vector + // devicePtrs(hostData.size(), nullptr); // // create cuda stream and LaunchContext cudaError_t cudaResult; @@ -3472,7 +3476,7 @@ TEST_F(CudaBasicsTests1, execRandom_2) { // allocate required amount of global device memory and copy host data to it // cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); - //ASSERT_EQ(0, cudaResult); + // ASSERT_EQ(0, cudaResult); // call cuda kernel which calculates result NativeOpExecutioner::execRandom( @@ -3562,8 +3566,8 @@ TEST_F(CudaBasicsTests1, execRandom_4) { // // prepare input arrays for prepareDataForCuda function // std::vector> hostData; // hostData.emplace_back(extraArguments.data(), extraArguments.size() * - //sizeof(double)); // 0 -- dimensions std::vector - //devicePtrs(hostData.size(), nullptr); + // sizeof(double)); // 0 -- dimensions std::vector + // devicePtrs(hostData.size(), nullptr); // create cuda stream and LaunchContext // cudaError_t cudaResult; @@ -3572,8 +3576,8 @@ TEST_F(CudaBasicsTests1, execRandom_4) { // LaunchContext lc(&stream); // // // allocate required amount of global device memory and copy host data - //to it cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); - //ASSERT_EQ(0, cudaResult); + // to it cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); + // ASSERT_EQ(0, cudaResult); auto context = z.getContext(); PointersManager pm(context, "execRandom4"); // call cuda kernel which calculates result diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index d28022f05932..07b2ec91b484 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -1098,8 +1098,9 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_6) { NDArray input = NDArrayFactory::create( 'c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); NDArray n = NDArrayFactory::create(0); - NDArray exp = NDArrayFactory::create(1.f); // NDArrayFactory::create('c', - // {2,2}, {1.f, 4.f, 7.f, 10.f}); + NDArray exp = + NDArrayFactory::create(1.f); // NDArrayFactory::create('c', + // {2,2}, {1.f, 4.f, 7.f, 10.f}); // input.linspace(1.f); @@ -1117,8 +1118,9 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_06) { NDArray input = NDArrayFactory::create( 'c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); NDArray n = NDArrayFactory::create(4); - NDArray exp = NDArrayFactory::create(8.f); // NDArrayFactory::create('c', - // {2,2}, {1.f, 4.f, 7.f, 10.f}); + NDArray exp = + NDArrayFactory::create(8.f); // NDArrayFactory::create('c', + // {2,2}, {1.f, 4.f, 7.f, 10.f}); // input.linspace(1.f); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index a9ae77875aef..71b362e5c8f9 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -477,11 +477,11 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) { 0.0038, 0.0038, 0.0030}); // auto exp = NDArrayFactory::create('c', {1, 39}, {15.000000, // 0.000000, 0.000000, 65.000000, 60.000000, - // 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, + // 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); // data.linspace(1); auto exp4 = NDArrayFactory::create( 'c', {1, 108}, @@ -528,11 +528,11 @@ TEST_F(DeclarableOpsTests13, CellContains_test_1) { NDArrayFactory::create({0.3000, 0.2625, 0.2674, 0.8604, 0.4803}); // auto exp = NDArrayFactory::create('c', {1, 39}, {15.000000, // 0.000000, 0.000000, 65.000000, 60.000000, - // 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, + // 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); // data.linspace(1); // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index d0640091f180..7b8e5c4d7c9d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -1828,8 +1828,8 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_02) { TEST_F(DeclarableOpsTests7, TestSegmentMean_021) { auto x = NDArrayFactory::create( 'c', {6, 3}); //, {1, - //2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., - //16., 17., 18.}); + // 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., + // 15., 16., 17., 18.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 2, 2}); auto exp = NDArrayFactory::create( 'c', {3, 3}, {2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f}); @@ -1845,8 +1845,8 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_021) { TEST_F(DeclarableOpsTests7, TestSegmentMean_022) { auto x = NDArrayFactory::create( 'c', {6, 3}); //, {1, - //2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., - //16., 17., 18.}); + // 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., + // 15., 16., 17., 18.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 2, 2}); auto z = NDArrayFactory::create( 'c', diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index a42cf64d856b..948854da981b 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -1037,21 +1037,21 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) { // 236.f, 240.f, 0.f, 248.f, 0.f, 0.f, 260.f, 0.f, 0.f, 0.f, 276.f, 0.f, 0.f, // 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 316.f, 0.f, 324.f, 0.f, 0.f, 336.f, 0.f, // 0.f, 0.f, 0.f, 356.f, 0.f, 0.f, 368.f, 0.f, 0.f, 0.f, 384.f, 388.f, 0.f, - // 0.f, 400.f}); 02Dropout result is [4.000000, 0.000000, 12.000000, 0.000000, - // 0.000000, 0.000000, 0.000000, 0.000000, 36.000000, 0.000000, 0.000000, - // 0.000000, 0.000000, 56.000000, 60.000000, 0.000000, 0.000000, 0.000000, - // 0.000000, 0.000000, 0.000000, 88.000000, 0.000000, 96.000000, 0.000000, - // 0.000000, 108.000000, 0.000000, 0.000000, 120.000000, 0.000000, 128.000000, - // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 156.000000, - // 0.000000, 164.000000, 0.000000, 0.000000, 0.000000, 0.000000, 184.000000, - // 0.000000, 0.000000, 0.000000, 200.000000, 0.000000, 0.000000, 0.000000, - // 216.000000, 0.000000, 0.000000, 0.000000, 232.000000, 0.000000, 240.000000, - // 0.000000, 248.000000, 0.000000, 0.000000, 260.000000, 0.000000, 0.000000, + // 0.f, 400.f}); 02Dropout result is [4.000000, 0.000000, 12.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 36.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 56.000000, 60.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 88.000000, 0.000000, 96.000000, + // 0.000000, 0.000000, 108.000000, 0.000000, 0.000000, 120.000000, 0.000000, + // 128.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 156.000000, 0.000000, 164.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // 184.000000, 0.000000, 0.000000, 0.000000, 200.000000, 0.000000, 0.000000, + // 0.000000, 216.000000, 0.000000, 0.000000, 0.000000, 232.000000, 0.000000, + // 240.000000, 0.000000, 248.000000, 0.000000, 0.000000, 260.000000, 0.000000, // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - // 0.000000, 0.000000, 308.000000, 0.000000, 0.000000, 0.000000, 0.000000, - // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 348.000000, 0.000000, - // 356.000000, 0.000000, 0.000000, 0.000000, 0.000000, 376.000000, 0.000000, - // 384.000000, 0.000000, 0.000000, 0.000000, 400.000000] + // 0.000000, 0.000000, 0.000000, 308.000000, 0.000000, 0.000000, 0.000000, + // 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 348.000000, + // 0.000000, 356.000000, 0.000000, 0.000000, 0.000000, 0.000000, 376.000000, + // 0.000000, 384.000000, 0.000000, 0.000000, 0.000000, 400.000000] auto ressX = op2.evaluate({&x1, &x1}, {0.5f}, {119}); // , false, sd::DataType::FLOAT32); // diff --git a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp index ec6b15698575..d9a9373a2ef4 100644 --- a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp @@ -59,8 +59,9 @@ TEST_F(EmptyTests, Test_Create_Empty_2) { TEST_F(EmptyTests, Test_Concat_1) { // auto empty = NDArrayFactory::empty_(); - auto empty = NDArray('c', {0}, sd::DataType::FLOAT32); // NDArrayFactory::create_('c', - // {(Nd4jLong)0}}; + auto empty = NDArray( + 'c', {0}, sd::DataType::FLOAT32); // NDArrayFactory::create_('c', + // {(Nd4jLong)0}}; auto vector = NDArrayFactory::create('c', {1}, {1.0f}); ASSERT_TRUE(empty.isEmpty()); diff --git a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp index 0652d9fd4c7a..68000833a957 100644 --- a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp @@ -387,7 +387,7 @@ TEST_F(MultiDataTypeTests, ndarray_operatorMinus_test2) { } //////////////////////////////////////////////////////////////////////////////// -///multiply +/// multiply TEST_F(MultiDataTypeTests, ndarray_operatorMultiply_test1) { if (!Environment::getInstance()->isExperimentalBuild()) return; @@ -425,7 +425,7 @@ TEST_F(MultiDataTypeTests, ndarray_operatorMultiply_test2) { } //////////////////////////////////////////////////////////////////////////////// -///multiply +/// multiply TEST_F(MultiDataTypeTests, ndarray_operatorDivide_test1) { if (!Environment::getInstance()->isExperimentalBuild()) return; diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu index f6af6438f4e6..9cf13b7ca615 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu +++ b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu @@ -178,8 +178,8 @@ TEST_F(NDArrayCudaBasicsTests, TestAdd_1) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", @@ -251,8 +251,8 @@ TEST_F(NDArrayCudaBasicsTests, TestAdd_3) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", @@ -308,8 +308,8 @@ TEST_F(NDArrayCudaBasicsTests, TestAdd_4) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Add, y, z); // @@ -337,8 +337,8 @@ TEST_F(NDArrayCudaBasicsTests, TestAdd_5) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x += y; // x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); x.syncToHost(); @@ -370,8 +370,8 @@ TEST_F(NDArrayCudaBasicsTests, TestAdd_6) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x += y; // x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); x.syncToHost(); @@ -400,8 +400,8 @@ TEST_F(NDArrayCudaBasicsTests, TestAdd_7) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x += 2.; // x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); x.syncToHost(); @@ -430,8 +430,8 @@ TEST_F(NDArrayCudaBasicsTests, TestMultiply_1) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, y, z); // x.printBuffer("3X = "); // y.printBuffer("3Y = "); @@ -462,8 +462,8 @@ TEST_F(NDArrayCudaBasicsTests, TestMultiply_2) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, y, z); // @@ -491,8 +491,8 @@ TEST_F(NDArrayCudaBasicsTests, TestMultiply_3) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, y, z); // x.printBuffer("23X = "); // y.printBuffer("23Y = "); @@ -523,10 +523,9 @@ TEST_F(NDArrayCudaBasicsTests, TestMultiply_4) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - // x.printBuffer("23X = "); - // y.printBuffer("23Y = "); + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, + // nullptr); x.printBuffer("23X = "); y.printBuffer("23Y = "); x *= y; // x.tickWriteDevice(); // x.printBuffer("33Result out"); @@ -548,8 +547,9 @@ TEST_F(NDArrayCudaBasicsTests, TestPrimitiveNeg_01) { auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); auto exp = NDArrayFactory::create('c', {5}, {-1, -2, -3, -4, -5}); - auto stream = x.getContext()->getCudaStream(); // reinterpret_cast(&nativeStream); + auto stream = + x.getContext()->getCudaStream(); // reinterpret_cast(&nativeStream); NativeOpExecutioner::execTransformSame( x.getContext(), transform::Neg, x.buffer(), x.shapeInfo(), @@ -893,8 +893,9 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_1) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, + // nullptr); x *= y; // x.syncToHost(); @@ -923,10 +924,9 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_01) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - // x.printBuffer("23X = "); - // y.printBuffer("23Y = "); + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, + // nullptr); x.printBuffer("23X = "); y.printBuffer("23Y = "); x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); // *= y; // z.printBuffer("53Result out"); @@ -958,10 +958,9 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_02) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - // x.printBuffer("23X = "); - // y.printBuffer("23Y = "); + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, + // nullptr); x.printBuffer("23X = "); y.printBuffer("23Y = "); x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); // *= y; // z.printBuffer("52Result out"); @@ -995,10 +994,9 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_002) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - // x.printBuffer("23X = "); - // y.printBuffer("23Y = "); + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, + // nullptr); x.printBuffer("23X = "); y.printBuffer("23Y = "); x.applyPairwiseTransform(pairwise::Multiply, y, z); // *= y; // z.printBuffer("51Result out"); @@ -1062,8 +1060,8 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcastRaw_1) { LaunchContext* pLc = x.getContext(); // allocate required amount of global device memory and copy host data to it - // cudaResult = allocateDeviceMem(*pLc, devicePtrs, hostData); ASSERT_EQ(0, - // cudaResult); + // cudaResult = allocateDeviceMem(*pLc, devicePtrs, hostData); + // ASSERT_EQ(0, cudaResult); for (size_t i = 0; i < devicePtrs.size(); ++i) { cudaResult = cudaMalloc( &devicePtrs[i], @@ -1112,10 +1110,9 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - // x.printBuffer("23X = "); - // y.printBuffer("23Y = "); + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, + // nullptr); x.printBuffer("23X = "); y.printBuffer("23Y = "); x *= y; // @@ -1145,13 +1142,11 @@ TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_2) { // x.lengthOf() * x.sizeOfT()); ASSERT_EQ(0, res); res = // cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * // x.sizeOfT()); ASSERT_EQ(0, res); res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); ASSERT_EQ(0, - // res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); - // x.printBuffer("23X = "); - // y.printBuffer("23Y = "); - // void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray* - // other, NDArray* target, const bool checkTargetShape, ExtraArguments - // *extraArgs) + // **>(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + // ASSERT_EQ(0, res); x.applyPairwiseTransform(pairwise::Multiply, &y, &z, + // nullptr); x.printBuffer("23X = "); y.printBuffer("23Y = "); void + // NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray* other, + // NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, exp); // @@ -1172,8 +1167,9 @@ TEST_F(NDArrayCudaBasicsTests, TestReduceSum_1) { auto y = NDArrayFactory::create(15); auto exp = NDArrayFactory::create(15); - auto stream = x.getContext()->getCudaStream(); // reinterpret_cast(&nativeStream); + auto stream = + x.getContext()->getCudaStream(); // reinterpret_cast(&nativeStream); NativeOpExecutioner::execReduceSameScalar( x.getContext(), reduce::Sum, x.buffer(), x.shapeInfo(), x.specialBuffer(), diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp index c85ade5d18c8..e21d90939403 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp @@ -1520,7 +1520,8 @@ TEST_F(NDArrayTest, TestReshapeNegative2) { // double arrA[6] = {1, 2, 3, 4, 5, 6}; // double arrU[6] = {-0.386318, -0.922366, 0.000000, -0.922366, 0.386318, -// 0.000000}; double arrS[3] = {9.508032, 0.77287, 0.000}; double arrVt[9] = +// 0.000000}; double arrS[3] = {9.508032, 0.77287, 0.000}; double arrVt[9] +// = // {-0.428667, -0.566307, -0.703947, 0.805964, 0.112382, -0.581199, 0.408248, // -0.816497, 0.408248}; @@ -1571,7 +1572,8 @@ TEST_F(NDArrayTest, TestReshapeNegative2) { // double arrA[6] = {1, 2, 3, 4, 5, 6}; // double arrU[6] = {-0.386318, -0.922366, 0.000000, -0.922366, 0.386318, -// 0.000000}; double arrS[3] = {9.508032, 0.77287, 0.000}; double arrVt[9] = +// 0.000000}; double arrS[3] = {9.508032, 0.77287, 0.000}; double arrVt[9] +// = // {-0.428667, -0.566307, -0.703947, 0.805964, 0.112382, -0.581199, 0.408248, // -0.816497, 0.408248}; diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 76f9d8053ca3..1ef3f261d7a3 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -49,7 +49,7 @@ class RNGTests : public testing::Test { //_bufferB = new Nd4jLong[100000]; //_rngA = (sd::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, //(Nd4jPointer) _bufferA); _rngB = (sd::random::RandomBuffer *) - //initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB); + // initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB); _rngA.setStates(_seed, _seed); _rngB.setStates(_seed, _seed); nexp0->assign(-1.0f); From 03d3d7b5888337fac0124ce1f1752eb4f24c0fb1 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 11 May 2020 20:44:59 +0300 Subject: [PATCH 120/233] merge Signed-off-by: raver119@gmail.com --- .../main/java/org/nd4j/nativeblas/Nd4jCpu.java | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index b67949a543aa..b97274ba1e90 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -16619,6 +16619,21 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } + @Namespace("sd::ops") public static class clipbyavgnorm_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public clipbyavgnorm_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public clipbyavgnorm_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public clipbyavgnorm_bp position(long position) { + return (clipbyavgnorm_bp)super.position(position); + } + + public clipbyavgnorm_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif // #if NOT_EXCLUDED(OP_cumsum) From e615438336dea0c07de9c2cd0ea69dfc3169a16e Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 14 May 2020 08:11:57 +0300 Subject: [PATCH 121/233] merge Signed-off-by: raver119@gmail.com --- libnd4j/include/array/impl/NDArray.cpp | 10 +++++----- .../transforms/clip_by_averaged_norm.cpp | 2 +- .../generic/transforms/merge_max_idx.cpp | 16 ++++++++-------- .../ops/declarable/helpers/cpu/clip.cpp | 10 +++++----- .../ops/declarable/platform/mkldnn/concat.cpp | 4 ++-- .../declarable/platform/mkldnn/deconv2d.cpp | 8 ++++---- .../declarable/platform/mkldnn/mkldnnUtils.cpp | 18 ++++++++++-------- .../declarable/platform/mkldnn/xw_plus_b.cpp | 10 ++++------ .../layers_tests/DeclarableOpsTests16.cpp | 4 ++-- .../layers_tests/DeclarableOpsTests19.cpp | 4 ++-- 10 files changed, 43 insertions(+), 43 deletions(-) diff --git a/libnd4j/include/array/impl/NDArray.cpp b/libnd4j/include/array/impl/NDArray.cpp index fc3422f416d3..c39e20a3ce2a 100644 --- a/libnd4j/include/array/impl/NDArray.cpp +++ b/libnd4j/include/array/impl/NDArray.cpp @@ -1124,16 +1124,16 @@ std::string NDArray::asString(Nd4jLong limit) { //////////////////////////////////////////////////////////////////////// template -std::vector NDArray::getBufferAsVector() { +std::vector NDArray::getBufferAsVector() const { std::vector vector(lengthOf()); for (Nd4jLong e = 0; e < lengthOf(); e++) vector[e] = this->e(e); return vector; } BUILD_SINGLE_TEMPLATE(template SD_EXPORT std::vector, - NDArray::getBufferAsVector(), LIBND4J_TYPES); + NDArray::getBufferAsVector() const, LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeAsFlatVector() { +std::vector NDArray::getShapeAsFlatVector() const { std::vector vector(this->rankOf()); for (int e = 0; e < this->rankOf(); e++) vector[e] = static_cast(this->sizeAt(e)); @@ -1158,7 +1158,7 @@ std::vector NDArray::getShapeAsVectorInt() const { } //////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeInfoAsFlatVector() { +std::vector NDArray::getShapeInfoAsFlatVector() const { int magicNumber = shape::shapeInfoLength(this->rankOf()); std::vector vector(magicNumber); @@ -1169,7 +1169,7 @@ std::vector NDArray::getShapeInfoAsFlatVector() { } //////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeInfoAsVector() { +std::vector NDArray::getShapeInfoAsVector() const { int magicNumber = shape::shapeInfoLength(this->rankOf()); std::vector vector(magicNumber); for (int e = 0; e < magicNumber; e++) vector[e] = this->_shapeInfo[e]; diff --git a/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp b/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp index e961f06afc5b..1255566462f1 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp @@ -58,7 +58,7 @@ CUSTOM_OP_IMPL(clipbyavgnorm_bp, 2, 1, false, 1, 0) { const auto clipNorm = NDArrayFactory::create(gradI->dataType(), T_ARG(0), block.launchContext()); - helpers::clipByNormBp(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm, true); + helpers::clipByNormBp(block.launchContext(), *input, *gradO, *gradI, block.getIArguments(), clipNorm, true); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp b/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp index 95327b031e68..aadcf88afa67 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp @@ -42,22 +42,22 @@ CUSTOM_OP_IMPL(mergemaxindex, -1, 1, false, 0, 0) { DECLARE_SYN(MergeMaxIndex, mergemaxindex); - DECLARE_TYPES(mergemaxindex) { - getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_INDICES}); - } +DECLARE_TYPES(mergemaxindex) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INDICES}); } -} // namespace ops + DECLARE_SHAPE_FN(mergemaxindex) { auto in = inputShape->at(0); auto dtype = DataType::INT32; if (block.numI() > 0) dtype = (DataType)INT_ARG(0); - auto resShape = - ShapeBuilders::copyShapeInfoAndType(in, dtype, block.workspace()); + auto resShape = ShapeBuilders::copyShapeInfoAndType(in, dtype, block.workspace()); return SHAPELIST(CONSTANT(resShape)); } + +} // namespace ops } // namespace sd #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp b/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp index 06ed5c0adce6..7520bf127012 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp @@ -52,9 +52,9 @@ z = &input; auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i++) { - const NDArray actualNorm = useAverage ? listOfSubArrs.at(i)->reduceAlongDimension(reduce::Norm2, {}) / listOfSubArrs.at(i)->lengthOf() : listOfSubArrs.at(i)->reduceAlongDimension(reduce::Norm2, {}); + const NDArray actualNorm = useAverage ? listOfSubArrs.at(i).reduceAlongDimension(reduce::Norm2, {}) / listOfSubArrs.at(i).lengthOf() : listOfSubArrs.at(i).reduceAlongDimension(reduce::Norm2, {}); if(actualNorm.e(0) > clipNorm.e(0)) - *listOfSubArrs.at(i) *= clipNorm / actualNorm; + listOfSubArrs.at(i) *= clipNorm / actualNorm; } }; samediff::Threads::parallel_tad(func, 0, listOfSubArrs.size()); @@ -107,7 +107,7 @@ static void clipByNormBp_(const NDArray& input, const NDArray& gradO, NDArray& g auto gradOSubArr = gradOSubArrs.at(i); auto gradISubArr = gradISubArrs.at(i); - const T norm = useAverage ? norm2.e(i) / gradISubArr->lengthOf() : norm2.e(i); + const T norm = useAverage ? norm2.e(i) / gradISubArr.lengthOf() : norm2.e(i); if (norm > clipVal) { @@ -121,10 +121,10 @@ static void clipByNormBp_(const NDArray& input, const NDArray& gradO, NDArray& g return factor1 * y * (static_cast(1.f) - factor2 * x * sum); }; - inputSubArr->applyPairwiseLambda(*gradOSubArr, lambda, *gradISubArr); + inputSubArr.applyPairwiseLambda(gradOSubArr, lambda, gradISubArr); } else - gradISubArr->assign(gradOSubArr); + gradISubArr.assign(gradOSubArr); } }; samediff::Threads::parallel_tad(func, 0, gradISubArrs.size()); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp index 9df63556eb89..1f6827d0988d 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp @@ -94,7 +94,7 @@ PLATFORM_IMPL(concat, ENGINE_CPU) { REQUIRE_TRUE(block.width() > 0, 0, "CONCAT MKLDNN op: No input arrays were provided"); - const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); @@ -105,7 +105,7 @@ PLATFORM_IMPL(concat, ENGINE_CPU) { int index = 0; bool allOfSameType = true; auto rankOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0; - auto typeOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : block.dataType(); + auto typeOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : DataType::FLOAT32; for(int i = 0; i < numOfInArrs; ++i) { auto input = INPUT_VARIABLE(i); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp index c254cc5cc213..1f385c721c7c 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp @@ -57,9 +57,9 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, std::vector permut; if (0 == wFormat) permut = {2,3,0,1}; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] - } else if (1 == wFormat) + else if (1 == wFormat) permut = {1,0,2,3}; // [iC, oC, kH, kW] -> [oC, iC, kH, kW] - else + else permut = {3,0,1,2}; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] @@ -202,9 +202,9 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, std::vector permut; if (0 == wFormat) permut = {2,3,0,1}; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] - } else if (1 == wFormat) + else if (1 == wFormat) permut = {1,0,2,3}; // [iC, oC, kH, kW] -> [oC, iC, kH, kW] - else + else permut = {3,0,1,2}; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp index 9dccd7d8cd92..52f52feda2b5 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp @@ -72,14 +72,16 @@ dnnl::memory::format_tag getFormat(const NDArray& arr) { ////////////////////////////////////////////////////////////////////// void setBlockStrides(const NDArray& array, dnnl::memory::desc& mklMd, const std::vector& permut) { if (array.ews() != 1 || (array.rankOf() > 3 && array.ordering() == 'f') || !permut.empty()) { - mklMd.data.format_kind = dnnl_blocked; // overrides formatif(permut.empty()) - for (auto i = 0; i < array.rankOf(); ++i) - mklMd.data.format_desc.blocking.strides[i] = array.strideAt(i); - else { - if(array.rankOf() != permut.size()) - throw std::invalid_argument("mkldnnUtils::setBlockStrides: size of permut vector is not equal to array rank !"); - for (auto i = 0; i < array.rankOf(); ++i) - mklMd.data.format_desc.blocking.strides[i] = array.strideAt(permut[i]); + mklMd.data.format_kind = dnnl_blocked; // overrides format + if(permut.empty()) + for (auto i = 0; i < array.rankOf(); ++i) + mklMd.data.format_desc.blocking.strides[i] = array.strideAt(i); + else { + if(array.rankOf() != permut.size()) + throw std::invalid_argument("mkldnnUtils::setBlockStrides: size of permut vector is not equal to array rank !"); + + for (auto i = 0; i < array.rankOf(); ++i) + mklMd.data.format_desc.blocking.strides[i] = array.strideAt(permut[i]); } } } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp index 5f373a40d00a..d7b8c20587e4 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp @@ -233,15 +233,13 @@ static void xwPlusBiasBp(const NDArray* x, const NDArray* weights, dnnl::memory::desc(xShape, dataType, mkldnnUtils::getFormat(*dLdx)); mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md); - // create engineauto engine = - mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + // create engine + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); // forward // operation primitive description dnnl::inner_product_forward::desc op_ff_desc( - dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, - dLdz_mkl_md); - dnnl::inner_product_forward::primitive_desc op_ff_prim_desc(op_ff_desc, - engine); + dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md,dLdz_mkl_md); + dnnl::inner_product_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); // backprob // dLdw diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index 5b39ae2cc910..36b4134641a4 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -1100,7 +1100,7 @@ TEST_F(DeclarableOpsTests16, clipbynorm_3) { auto result = op.evaluate({&x}, {1.0}, {1}); auto z = result.at(0); - auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true); + auto zNorm1 = z.reduceAlongDimension(reduce::Norm2, {1}, true); auto exp = NDArrayFactory::create('c', {3, 1}, {1., 1., xNorm1.e(2)}); ASSERT_TRUE(exp.isSameShape(&zNorm1)); @@ -1255,7 +1255,7 @@ TEST_F(DeclarableOpsTests16, clipbynorm_12) { sd::ops::clipbynorm op; auto result = op.evaluate({&x}, {0.54}, {}); - ASSERT_EQ(e, *result.at(0)); + ASSERT_EQ(e, result.at(0)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp index 8ecb3e52c6fe..227e1514873f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -241,7 +241,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { auto enc_result = enc.evaluate({&initial}, {0.5f}); auto encoded = enc_result.at(1); - ASSERT_EQ(135079944 + 4, encoded->lengthOf()); + ASSERT_EQ(135079944 + 4, encoded.lengthOf()); ASSERT_NE(exp, initial); /* for (int e = 0; e < initial.lengthOf(); e++) { @@ -261,7 +261,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { //} sd::ops::decode_threshold dec; - auto status = dec.execute({&initial, encoded}, {&initial}); + auto status = dec.execute({&initial, &encoded}, {&initial}); ASSERT_EQ(Status::OK(), status); // checking equality of all dedoded bits From 3114c90fe8fef96a7ae0cb5ac3ae0cb44bcdbe5e Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 14 May 2020 08:15:51 +0300 Subject: [PATCH 122/233] disable one long running test in debug builds Signed-off-by: raver119@gmail.com --- libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp index 227e1514873f..5c4101671b8d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -230,6 +230,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) { } TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { +#ifdef _RELEASE // [2,1,135079944,1,1,8192,1,99] auto initial = NDArrayFactory::create('c', {1, 135079944}); initial = 1.0f; @@ -274,6 +275,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { */ ASSERT_EQ(exp, initial); +#endif } From 1410b90f0f24c2c0a7372af8ffc9b14331e83be1 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 15 May 2020 11:05:48 +0300 Subject: [PATCH 123/233] master merged Signed-off-by: raver119@gmail.com --- libnd4j/include/array/NDArray.h | 6 +- libnd4j/include/helpers/cpu/svd.cpp | 4 +- .../include/helpers/impl/EigenValsAndVecs.cpp | 8 +- libnd4j/include/helpers/impl/FullPivLU.cpp | 8 +- .../helpers/impl/HessenbergAndSchur.cpp | 18 +- libnd4j/include/helpers/impl/Sqrtm.cpp | 8 +- libnd4j/include/helpers/impl/biDiagonalUp.cpp | 2 +- libnd4j/include/helpers/impl/hhSequence.cpp | 6 +- libnd4j/include/helpers/impl/householder.cpp | 8 +- libnd4j/include/helpers/impl/jacobiSVD.cpp | 53 +- .../parity_ops/unsorted_segment_min.cpp | 6 +- .../ops/declarable/generic/reduce/argamax.cpp | 6 +- .../ops/declarable/generic/reduce/argamin.cpp | 6 +- .../helpers/cpu/extract_patches.cpp | 2 +- .../ops/declarable/helpers/cpu/lstsq.cpp | 2 +- .../ops/declarable/helpers/cpu/segment.cpp | 24 +- .../ops/declarable/helpers/cpu/solve.cpp | 10 +- .../ops/declarable/helpers/cpu/svd.cpp | 3 +- .../helpers/cpu/triangular_solve.cpp | 4 +- .../ops/declarable/helpers/impl/sqrtm.cpp | 2 +- .../tests_cpu/layers_tests/HelpersTests1.cpp | 2896 +++++------------ .../layers_tests/PlaygroundTests.cpp | 2 + 22 files changed, 941 insertions(+), 2143 deletions(-) diff --git a/libnd4j/include/array/NDArray.h b/libnd4j/include/array/NDArray.h index a91a86eb93d2..b37eefc10e1f 100644 --- a/libnd4j/include/array/NDArray.h +++ b/libnd4j/include/array/NDArray.h @@ -2118,7 +2118,8 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const { return *(reinterpret_cast(bufferWithOffset(i * strideAt(0) + j * strideAt(1)))); } -////////////////////////////////////////////////////////////////////////template +//////////////////////////////////////////////////////////////////////// +template T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2)) throw std::invalid_argument( @@ -2134,7 +2135,8 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { return *(reinterpret_cast(bufferWithOffset(i * strideAt(0) + j * strideAt(1) + k * strideAt(2)))); } -////////////////////////////////////////////////////////////////////////template +//////////////////////////////////////////////////////////////////////// +template T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const { if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || diff --git a/libnd4j/include/helpers/cpu/svd.cpp b/libnd4j/include/helpers/cpu/svd.cpp index 26a94055b862..dfe100dc96ca 100644 --- a/libnd4j/include/helpers/cpu/svd.cpp +++ b/libnd4j/include/helpers/cpu/svd.cpp @@ -465,9 +465,9 @@ void SVD::calcSingVals(const NDArray& col0, const NDArray& diag, if (shift == left && (muCur < (T)0. || muCur > right - left)) useBisection = true; - elseif (shift == right && (muCur < -(right - left) || muCur > (T)0.)) + else if (shift == right && (muCur < -(right - left) || muCur > (T)0.)) useBisection = true; - elseif (math::nd4j_abs(fCur) > math::nd4j_abs(fPrev) && + else if (math::nd4j_abs(fCur) > math::nd4j_abs(fPrev) && math::nd4j_abs(fCur - fPrev) > (T)16. * DataTypeUtils::eps()) useBisection = true; } diff --git a/libnd4j/include/helpers/impl/EigenValsAndVecs.cpp b/libnd4j/include/helpers/impl/EigenValsAndVecs.cpp index 6eeb0c28bff1..8f11363cb7ec 100644 --- a/libnd4j/include/helpers/impl/EigenValsAndVecs.cpp +++ b/libnd4j/include/helpers/impl/EigenValsAndVecs.cpp @@ -283,10 +283,10 @@ void EigenValsAndVecs::calcEigenVecs(const NDArray& schurMatrixU) { } -template class ND4J_EXPORT EigenValsAndVecs; -template class ND4J_EXPORT EigenValsAndVecs; -template class ND4J_EXPORT EigenValsAndVecs; -template class ND4J_EXPORT EigenValsAndVecs; +template class SD_EXPORT EigenValsAndVecs; +template class SD_EXPORT EigenValsAndVecs; +template class SD_EXPORT EigenValsAndVecs; +template class SD_EXPORT EigenValsAndVecs; } } diff --git a/libnd4j/include/helpers/impl/FullPivLU.cpp b/libnd4j/include/helpers/impl/FullPivLU.cpp index efb7571ed0a4..6e7993b1bc50 100644 --- a/libnd4j/include/helpers/impl/FullPivLU.cpp +++ b/libnd4j/include/helpers/impl/FullPivLU.cpp @@ -160,10 +160,10 @@ void FullPivLU::solve(const NDArray& A, const NDArray& b, NDArray& x) { x({colsPermut[i],colsPermut[i]+1, 0,0}, true).nullify(); } -template class ND4J_EXPORT FullPivLU; -template class ND4J_EXPORT FullPivLU; -template class ND4J_EXPORT FullPivLU; -template class ND4J_EXPORT FullPivLU; +template class SD_EXPORT FullPivLU; +template class SD_EXPORT FullPivLU; +template class SD_EXPORT FullPivLU; +template class SD_EXPORT FullPivLU; } } diff --git a/libnd4j/include/helpers/impl/HessenbergAndSchur.cpp b/libnd4j/include/helpers/impl/HessenbergAndSchur.cpp index 31495cab9f10..1422cf261ceb 100644 --- a/libnd4j/include/helpers/impl/HessenbergAndSchur.cpp +++ b/libnd4j/include/helpers/impl/HessenbergAndSchur.cpp @@ -368,15 +368,15 @@ void Schur::calcFromHessenberg() { } } -template class ND4J_EXPORT Hessenberg; -template class ND4J_EXPORT Hessenberg; -template class ND4J_EXPORT Hessenberg; -template class ND4J_EXPORT Hessenberg; - -template class ND4J_EXPORT Schur; -template class ND4J_EXPORT Schur; -template class ND4J_EXPORT Schur; -template class ND4J_EXPORT Schur; +template class SD_EXPORT Hessenberg; +template class SD_EXPORT Hessenberg; +template class SD_EXPORT Hessenberg; +template class SD_EXPORT Hessenberg; + +template class SD_EXPORT Schur; +template class SD_EXPORT Schur; +template class SD_EXPORT Schur; +template class SD_EXPORT Schur; } } diff --git a/libnd4j/include/helpers/impl/Sqrtm.cpp b/libnd4j/include/helpers/impl/Sqrtm.cpp index 5fe45656f9b8..087097a59c8e 100644 --- a/libnd4j/include/helpers/impl/Sqrtm.cpp +++ b/libnd4j/include/helpers/impl/Sqrtm.cpp @@ -265,10 +265,10 @@ void Sqrtm::calc(const NDArray& in, NDArray& out) { MmulHelper::mmul(&schur._U, &temp, &out); } -template class ND4J_EXPORT Sqrtm; -template class ND4J_EXPORT Sqrtm; -template class ND4J_EXPORT Sqrtm; -template class ND4J_EXPORT Sqrtm; +template class SD_EXPORT Sqrtm; +template class SD_EXPORT Sqrtm; +template class SD_EXPORT Sqrtm; +template class SD_EXPORT Sqrtm; } diff --git a/libnd4j/include/helpers/impl/biDiagonalUp.cpp b/libnd4j/include/helpers/impl/biDiagonalUp.cpp index d5326c21a29d..c66fecf788e1 100644 --- a/libnd4j/include/helpers/impl/biDiagonalUp.cpp +++ b/libnd4j/include/helpers/impl/biDiagonalUp.cpp @@ -125,7 +125,7 @@ HHsequence BiDiagonalUp::makeHHsequence_(const char type) { const int diagSize = type == 'u' ? _HHbidiag.sizeAt(0) : _HHbidiag.sizeAt(0) - 1; - _hhCoeffs = NDArray(_HHmatrix.ordering(), {diagSize}, _HHmatrix.dataType(), _HHmatrix.getContext()); + auto _hhCoeffs = NDArray(_HHmatrix.ordering(), {diagSize}, _HHmatrix.dataType(), _HHmatrix.getContext()); if(type == 'u') for(int i = 0; i < diagSize; ++i) diff --git a/libnd4j/include/helpers/impl/hhSequence.cpp b/libnd4j/include/helpers/impl/hhSequence.cpp index a6fae699d06e..c20277934a04 100644 --- a/libnd4j/include/helpers/impl/hhSequence.cpp +++ b/libnd4j/include/helpers/impl/hhSequence.cpp @@ -89,13 +89,15 @@ void HHsequence::applyTo_(NDArray& dest) { } } -//////////////////////////////////////////////////////////////////////////void HHsequence::applyTo(NDArray& dest) { +////////////////////////////////////////////////////////////////////////// +void HHsequence::applyTo(NDArray& dest) { auto xType = _coeffs.dataType(); BUILD_SINGLE_SELECTOR(xType, applyTo_, (dest), FLOAT_TYPES); } -//////////////////////////////////////////////////////////////////////////void HHsequence::mulLeft(NDArray& matrix) { +///////////////////////////////////////////////////////////////////////// +void HHsequence::mulLeft(NDArray& matrix) { auto xType = _coeffs.dataType(); BUILD_SINGLE_SELECTOR(xType, mulLeft_, (matrix), FLOAT_TYPES); diff --git a/libnd4j/include/helpers/impl/householder.cpp b/libnd4j/include/helpers/impl/householder.cpp index e9572f9f666c..a149021fd965 100644 --- a/libnd4j/include/helpers/impl/householder.cpp +++ b/libnd4j/include/helpers/impl/householder.cpp @@ -202,10 +202,10 @@ void Householder::mulRight(NDArray& matrix, const NDArray& tail, const T coef } -template class ND4J_EXPORT Householder; -template class ND4J_EXPORT Householder; -template class ND4J_EXPORT Householder; -template class ND4J_EXPORT Householder; +template class SD_EXPORT Householder; +template class SD_EXPORT Householder; +template class SD_EXPORT Householder; +template class SD_EXPORT Householder; diff --git a/libnd4j/include/helpers/impl/jacobiSVD.cpp b/libnd4j/include/helpers/impl/jacobiSVD.cpp index 0378911f0945..8e8399a790cf 100644 --- a/libnd4j/include/helpers/impl/jacobiSVD.cpp +++ b/libnd4j/include/helpers/impl/jacobiSVD.cpp @@ -21,6 +21,7 @@ #include #include #include +#include namespace sd { namespace ops { @@ -30,9 +31,9 @@ namespace helpers { template JacobiSVD::JacobiSVD(const NDArray& matrix, const bool calcU, const bool calcV, const bool fullUV) { - if (matrix.rankOf() != 2 || matrix.isScalar()) - throw std::runtime_error( - "ops::helpers::JacobiSVD constructor: input array must be 2D matrix !"); + + if(matrix.rankOf() != 2 || matrix.isScalar()) + throw std::runtime_error("ops::helpers::JacobiSVD constructor: input array must be 2D matrix !"); _rows = static_cast(matrix.sizeAt(0)); _cols = static_cast(matrix.sizeAt(1)); @@ -42,37 +43,27 @@ JacobiSVD::JacobiSVD(const NDArray& matrix, const bool calcU, _calcV = calcV; _fullUV = fullUV; - _s = NDArray(matrix.ordering(), {_diagSize, 1}, - matrix.dataType(), matrix.getContext()); + _s = NDArray(matrix.ordering(), {_diagSize, 1}, matrix.dataType(), matrix.getContext()); - if(_calcU) { - if(_fullUV) - _u = NDArray(matrix.ordering(), {_rows, _rows}, matrix.dataType(), matrix.getContext()); - else - _u = NDArray(matrix.ordering(), {_rows, _diagSize}, matrix.dataType(), matrix.getContext()); - } + if(_calcU) { + if(_fullUV) + _u = NDArray(matrix.ordering(), {_rows, _rows}, matrix.dataType(), matrix.getContext()); else - _u = NDArrayFactory::create(matrix.ordering(), {_rows, _diagSize}, - matrix.dataType(), matrix.getContext()); - } else - _u = NDArray(matrix.ordering(), {_rows, 1}, - matrix.dataType(), matrix.getContext()); - - if(_calcV) { - if(_fullUV) - _v = NDArray(matrix.ordering(), {_cols, _cols}, matrix.dataType(), matrix.getContext()); - else - _v = NDArray(matrix.ordering(), {_cols, _diagSize}, matrix.dataType(), matrix.getContext()); - } + _u = NDArray(matrix.ordering(), {_rows, _diagSize}, matrix.dataType(), matrix.getContext()); + } + else + _u = NDArray(matrix.ordering(), {_rows, 1}, matrix.dataType(), matrix.getContext()); + + if(_calcV) { + if(_fullUV) + _v = NDArray(matrix.ordering(), {_cols, _cols}, matrix.dataType(), matrix.getContext()); else - _v = NDArrayFactory::create(matrix.ordering(), {_cols, _diagSize}, - matrix.dataType(), matrix.getContext()); - } else - _v = NDArray(matrix.ordering(), {_cols, 1}, - matrix.dataType(), matrix.getContext()); + _v = NDArray(matrix.ordering(), {_cols, _diagSize}, matrix.dataType(), matrix.getContext()); + } + else + _v = NDArray(matrix.ordering(), {_cols, 1}, matrix.dataType(), matrix.getContext()); - _m = NDArray(matrix.ordering(), {_diagSize, _diagSize}, - matrix.dataType(), matrix.getContext()); + _m = NDArray(matrix.ordering(), {_diagSize, _diagSize}, matrix.dataType(), matrix.getContext()); evalData(matrix); } @@ -193,7 +184,7 @@ bool JacobiSVD::isBlock2x2NotDiag(NDArray& block, int p, int q, T& maxElem) { template bool JacobiSVD::createJacobiRotation(const T& x, const T& y, const T& z, NDArray& rotation) { - T denom = (T)(2. f)* math::nd4j_abs(y); + T denom = (T)(2.f)* math::nd4j_abs(y); if (denom < DataTypeUtils::min()) { rotation.r(0, 0) = diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp index 5c3e49243f4d..54544b266a55 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp @@ -27,8 +27,10 @@ CUSTOM_OP_IMPL(unsorted_segment_min, 2, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); auto idxSegments = INPUT_VARIABLE(1); auto segmentedOutput = OUTPUT_NULLIFIED(0); - Nd4jLong numOfClasses = - block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); + Nd4jLong numOfClasses = block.width() == 3 + ? INPUT_VARIABLE(2)->e(0) + : INT_ARG(0); + REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_min: segment indexes array should be a " "vector, but it rank is %i.", diff --git a/libnd4j/include/ops/declarable/generic/reduce/argamax.cpp b/libnd4j/include/ops/declarable/generic/reduce/argamax.cpp index 5fb452227c87..43fbf1f39bae 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/argamax.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/argamax.cpp @@ -39,7 +39,7 @@ namespace sd { if (output->isEmpty()) return Status::OK(); - auto axis = *block.getIArguments(); + auto axis = block.getIArguments(); // axis might be dynamic (i.e. tf mode) if (block.width() > 1 && axis.size() == 0) { @@ -60,7 +60,7 @@ namespace sd { std::vector dims; if (block.width() == 1) { - dims = *block.getIArguments(); + dims = block.getIArguments(); } else { auto y = INPUT_VARIABLE(1); dims = y->template asVectorT(); @@ -87,7 +87,7 @@ namespace sd { return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(dtype)); } - return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace())); + return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.workspace())); } } } diff --git a/libnd4j/include/ops/declarable/generic/reduce/argamin.cpp b/libnd4j/include/ops/declarable/generic/reduce/argamin.cpp index 4f590aae8848..f72932239434 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/argamin.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/argamin.cpp @@ -39,7 +39,7 @@ namespace sd { if (output->isEmpty()) return Status::OK(); - auto axis = *block.getIArguments(); + auto axis = block.getIArguments(); // axis might be dynamic (i.e. tf mode) if (block.width() > 1 && axis.size() == 0) { @@ -60,7 +60,7 @@ namespace sd { std::vector dims; if (block.width() == 1) { - dims = *block.getIArguments(); + dims = block.getIArguments(); } else { auto y = INPUT_VARIABLE(1); dims = y->template asVectorT(); @@ -87,7 +87,7 @@ namespace sd { return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(dtype)); } - return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace())); + return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.workspace())); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp b/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp index ba142260fee4..9de2b03103d2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/extract_patches.cpp @@ -75,7 +75,7 @@ static void _extractPatches(NDArray* images, NDArray* output, int sizeRow, bool setUp = (theSame && row >= 0 && col >= 0 && row < rowDim && col < colDim) || (!theSame); if (setUp) { - outMatrix->r(i, j, pos) = patch->e(row, col, pixel); + outMatrix.r(i, j, pos) = patch.e(row, col, pixel); } pos++; } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp index c16b1ad3f543..376f72a3512e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lstsq.cpp @@ -39,7 +39,7 @@ static void fillRegularizer(NDArray& ioMatrix, double const value) { for (auto x = 0; x < lastDims.size(); x++) { for (auto r = 0; r < rows; r++) { - lastDims[x]->r(r,r) = (T)value; + lastDims[x].r(r,r) = (T)value; } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp index 1e427cbbb8a3..a52a9d5498a8 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp @@ -65,14 +65,14 @@ static void segmentMaxFunctor_(NDArray* input, NDArray* indices, for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { if (indices->e(i) == idx) { - for (Nd4jLong e = 0; e < maxT->lengthOf(); e++) { - maxT->r(e) = sd::math::nd4j_max(maxT->t(e), listOfTensors.at(i)->t(e)); + for (Nd4jLong e = 0; e < maxT.lengthOf(); e++) { + maxT.r(e) = sd::math::nd4j_max(maxT.t(e), listOfTensors.at(i).t(e)); } } else { idx = indices->e(i); maxT = listOfOutTensors.at(idx); - maxT->assign(listOfTensors.at(i)); + maxT.assign(listOfTensors.at(i)); } } @@ -460,8 +460,8 @@ static void unsortedSegmentMinFunctor_(NDArray* input, NDArray* indices, for (size_t idx = 1; idx < fi->second.size(); ++idx) { auto minT = listOfTensors.at(fi->second.at(idx)); - for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) { - outputT->r(e) = sd::math::nd4j_min(minT->t(e), outputT->t(e)); + for (Nd4jLong e = 0; e < outputT.lengthOf(); ++e) { + outputT.r(e) = sd::math::nd4j_min(minT.t(e), outputT.t(e)); } } //outputT->assign(maxT); @@ -469,10 +469,10 @@ static void unsortedSegmentMinFunctor_(NDArray* input, NDArray* indices, } } -BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMinFunctor_, - (NDArray * input, NDArray* indices, Nd4jLong numOfClasses, - NDArray* output), - NUMERIC_TYPES); +void unsortedSegmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentMinFunctor_, (input, indices, numOfClasses, output), + NUMERIC_TYPES); +} void unsortedSegmentMeanFunctor(sd::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, @@ -969,9 +969,9 @@ static int unsortedSegmentMinFunctorBP_(sd::LaunchContext* context, auto currentOut = listOfOutTensors.at(i); auto currentGradOut = listOfGradOuts.at(classNum); - for (Nd4jLong e = 0; e < current->lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->t(e) - current->t(e)) < 1.e-6) - currentOut->r(e) = currentGradOut->t(e); + for (Nd4jLong e = 0; e < current.lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum).t(e) - current.t(e)) < 1.e-6) + currentOut.r(e) = currentGradOut.t(e); } } //}; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp index 6aafb5354ea8..730129e41d99 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp @@ -46,7 +46,7 @@ static void adjointMatrix_(sd::LaunchContext* context, NDArray const* input, for (auto batch = start; batch < stop; batch++) { for (Nd4jLong r = 0; r < rows; r++) { for (Nd4jLong c = 0; c < r; c++) { - math::nd4j_swap(outputPart[batch]->r(r, c) , outputPart[batch]->r(c, r)); + math::nd4j_swap(outputPart[batch].r(r, c) , outputPart[batch].r(c, r)); } } } @@ -72,8 +72,8 @@ static int solveFunctor_(sd::LaunchContext* context, NDArray* leftInput, auto permutationsPart = permutations.allTensorsAlongDimension({-1}); for (auto batch = 0; batch < permutationsPart.size(); ++batch) { - for (Nd4jLong row = 0; row < PPart[batch]->rows(); ++row) { - PPart[batch]->r(row, permutationsPart[batch]->t(row)) = T(1.f); + for (Nd4jLong row = 0; row < PPart[batch].rows(); ++row) { + PPart[batch].r(row, permutationsPart[batch].t(row)) = T(1.f); } } @@ -83,8 +83,8 @@ static int solveFunctor_(sd::LaunchContext* context, NDArray* leftInput, MmulHelper::matmul(&P, rightInput, &rightPermuted, 0, 0); ResultSet leftLowerPart = leftLower.allTensorsAlongDimension({-2, -1}); for (auto i = 0; i < leftLowerPart.size(); i++) { - for (Nd4jLong r = 0; r < leftLowerPart[i]->rows(); r++) - leftLowerPart[i]->r(r,r) = (T)1.f; + for (Nd4jLong r = 0; r < leftLowerPart[i].rows(); r++) + leftLowerPart[i].r(r,r) = (T)1.f; } // stage 2: triangularSolveFunctor for Lower with given b helpers::triangularSolveFunctor(context, &leftLower, &rightPermuted, true, false, &rightOutput); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp index fdd8e9ec2f4f..c0dfbb9e9f80 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/svd.cpp @@ -67,7 +67,8 @@ static void svd_(const NDArray* x, const std::vector& outArrs, } } -//////////////////////////////////////////////////////////////////////////void svd(sd::LaunchContext* context, const NDArray* x, +////////////////////////////////////////////////////////////////////////// +void svd(sd::LaunchContext* context, const NDArray* x, const std::vector& outArrs, const bool fullUV, const bool calcUV, const int switchNum) { BUILD_SINGLE_SELECTOR(x->dataType(), svd_, diff --git a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp index 10811e68b06d..e9b598acfce6 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp @@ -146,13 +146,13 @@ static void adjointTriangularMatrix_(sd::LaunchContext* context, if (!lower) { for (Nd4jLong r = 0; r < rows; r++) { for (Nd4jLong c = 0; c <= r; c++) { - outputPart[batch]->r(r, c) = inputPart[batch]->t(c, r); + outputPart[batch].r(r, c) = inputPart[batch].t(c, r); } } } else { for (Nd4jLong r = 0; r < rows; r++) { for (Nd4jLong c = r; c < cols; c++) { - outputPart[batch]->r(r, c) = inputPart[batch]->t(c, r); + outputPart[batch].r(r, c) = inputPart[batch].t(c, r); } } } diff --git a/libnd4j/include/ops/declarable/helpers/impl/sqrtm.cpp b/libnd4j/include/ops/declarable/helpers/impl/sqrtm.cpp index b8cc6d8ac29b..f3bca80c03f6 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/sqrtm.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/sqrtm.cpp @@ -43,7 +43,7 @@ static void sqrtm_(const NDArray* x, NDArray* z) { auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i++) - ops::helpers::Sqrtm::calc(*listX.at(i), *listZ.at(i)); + ops::helpers::Sqrtm::calc(listX.at(i), listZ.at(i)); }; samediff::Threads::parallel_tad(func, 0, listX.size()); diff --git a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp index 5189df85d88e..8bdefd4f121e 100644 --- a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp @@ -15,73 +15,71 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -#include -#include +#include "testlayers.h" +#include #include -#include #include -#include -#include #include -#include -#include -#include +#include +#include +#include #include +#include #include #include +#include +#include +#include +#include -#include - -#include "testlayers.h" using namespace sd; class HelpersTests1 : public testing::Test { public: - HelpersTests1() { std::cout << std::endl << std::flush; } + + HelpersTests1() { + + std::cout<('c', { 4}, {14, 17, 3, 1}); -// auto exp = NDArrayFactory::create( - 'c', {4, 4}, - {-0.629253, -0.764093, -0.13484, -0.0449467, -0.764093, 0.641653, - -0.0632377, -0.0210792, -0.13484, -0.0632377, 0.98884, -0.00371987, - -0.0449467, -0.0210792, -0.00371987, 0.99876});// auto result = ops::helpers::Householder::evalHHmatrix(x); +// auto x = NDArrayFactory::create('c', {4}, {14,17,3,1}); +// auto exp = NDArrayFactory::create('c', {4,4}, {-0.629253, -0.764093, -0.13484, -0.0449467, -0.764093, 0.641653, -0.0632377, -0.0210792, -0.13484,-0.0632377, 0.98884,-0.00371987, -0.0449467,-0.0210792,-0.00371987, 0.99876}); + +// auto result = ops::helpers::Householder::evalHHmatrix(x); // ASSERT_TRUE(result.isSameShape(&exp)); // ASSERT_TRUE(result.equalsTo(&exp)); - // } +// } - // /////////////////////////////////////////////////////////////////// +// /////////////////////////////////////////////////////////////////// // TEST_F(HelpersTests1, evalHHmatrix_test2) { -// -#ifdef __CUDABLAS__ -// return; -//#endif -// auto x = NDArrayFactory::create('c', { 3}, {14, -4, 3}); -// auto exp = NDArrayFactory::create( - 'c', {3, 3}, - {-0.941742, 0.269069, -0.201802, 0.269069, 0.962715, 0.0279639, -0.201802, - 0.0279639, 0.979027});// auto result = ops::helpers::Householder::evalHHmatrix(x); - - // ASSERT_TRUE(result.isSameShape(&exp)); +// #ifdef __CUDABLAS__ +// return; +// #endif +// auto x = NDArrayFactory::create('c', {3}, {14,-4,3}); +// auto exp = NDArrayFactory::create('c', {3,3}, {-0.941742, 0.269069,-0.201802, 0.269069, 0.962715,0.0279639, -0.201802,0.0279639, 0.979027}); + +// auto result = ops::helpers::Householder::evalHHmatrix(x); + +// ASSERT_TRUE(result.isSameShape(&exp)); // ASSERT_TRUE(result.equalsTo(&exp)); - // } +// } ///////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, evalHHmatrixData_test1) { - auto x = NDArrayFactory::create('c', { 4}, {14, 17, 3, 1}); - auto tail = NDArrayFactory::create('c', { 3}); - auto expTail = NDArrayFactory::create( - 'c', { 3}, {0.468984, 0.0827618, 0.0275873}); + auto x = NDArrayFactory::create('c', {4}, {14,17,3,1}); + auto tail = NDArrayFactory::create('c', {3}); + auto expTail = NDArrayFactory::create('c', {3}, {0.468984, 0.0827618, 0.0275873}); const double normXExpected = -22.2486; const double coeffExpected = 1.62925; @@ -97,17 +95,12 @@ TEST_F(HelpersTests1, evalHHmatrixData_test1) { ///////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, Householder_mulLeft_test1) { - auto x = NDArrayFactory::create( - 'c', {4, 4}, {12, 19, 14, 3, 10, 4, 17, 19, 19, 18, 5, 3, 6, 4, 2, 16}); - auto tail = NDArrayFactory::create('c', {1, 3}, {0.5, 0.5, 0.5}); - auto exp = NDArrayFactory::create( - 'c', {4, 4}, - {9.05, 15.8, 11.4, 0.8, 8.525, 2.4, 15.7, 17.9, 17.525, 16.4, 3.7, 1.9, - 4.525, 2.4, 0.7, 14.9}); + auto x = NDArrayFactory::create('c', {4,4}, {12 ,19 ,14 ,3 ,10 ,4 ,17 ,19 ,19 ,18 ,5 ,3 ,6 ,4 ,2 ,16}); + auto tail = NDArrayFactory::create('c', {1,3}, {0.5,0.5,0.5}); + auto exp = NDArrayFactory::create('c', {4,4}, {9.05,15.8,11.4, 0.8, 8.525, 2.4,15.7,17.9, 17.525,16.4, 3.7, 1.9, 4.525, 2.4, 0.7,14.9}); ops::helpers::Householder::mulLeft(x, tail, 0.1); - ASSERT_TRUE(x.isSameShapeStrict(exp)); ASSERT_TRUE(x.equalsTo(&exp)); } @@ -115,30 +108,23 @@ TEST_F(HelpersTests1, Householder_mulLeft_test1) { ///////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, Householder_mulLeft_test2) { - auto x = NDArrayFactory::create( - 'c', {4, 4}, {12, 19, 14, 3, 10, 4, 17, 19, 19, 18, 5, 3, 6, 4, 2, 16}); - auto tail = NDArrayFactory::create('c', {3, 1}, {0.5, 0.5, 0.5}); - auto exp = NDArrayFactory::create( - 'c', {4, 4}, - {9.05, 15.8, 11.4, 0.8, 8.525, 2.4, 15.7, 17.9, 17.525, 16.4, 3.7, 1.9, - 4.525, 2.4, 0.7, 14.9}); + auto x = NDArrayFactory::create('c', {4,4}, {12 ,19 ,14 ,3 ,10 ,4 ,17 ,19 ,19 ,18 ,5 ,3 ,6 ,4 ,2 ,16}); + auto tail = NDArrayFactory::create('c', {3,1}, {0.5,0.5,0.5}); + auto exp = NDArrayFactory::create('c', {4,4}, {9.05,15.8,11.4, 0.8, 8.525, 2.4,15.7,17.9, 17.525,16.4, 3.7, 1.9, 4.525, 2.4, 0.7,14.9}); ops::helpers::Householder::mulLeft(x, tail, 0.1); ASSERT_TRUE(x.isSameShapeStrict(exp)); ASSERT_TRUE(x.equalsTo(&exp)); + } ///////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, Householder_mulRight_test1) { - auto x = NDArrayFactory::create( - 'c', {4, 4}, {12, 19, 14, 3, 10, 4, 17, 19, 19, 18, 5, 3, 6, 4, 2, 16}); - auto tail = NDArrayFactory::create('c', {1, 3}, {0.5, 0.5, 0.5}); - auto exp = NDArrayFactory::create( - 'c', {4, 4}, - {9, 17.5, 12.5, 1.5, 7, 2.5, 15.5, 17.5, 15.8, 16.4, 3.4, 1.4, 4.3, 3.15, - 1.15, 15.15}); + auto x = NDArrayFactory::create('c', {4,4}, {12 ,19 ,14 ,3 ,10 ,4 ,17 ,19 ,19 ,18 ,5 ,3 ,6 ,4 ,2 ,16}); + auto tail = NDArrayFactory::create('c', {1,3}, {0.5,0.5,0.5}); + auto exp = NDArrayFactory::create('c', {4,4}, {9,17.5,12.5, 1.5, 7, 2.5,15.5, 17.5, 15.8,16.4, 3.4, 1.4, 4.3,3.15,1.15,15.15}); ops::helpers::Householder::mulRight(x, tail, 0.1); @@ -149,17 +135,9 @@ TEST_F(HelpersTests1, Householder_mulRight_test1) { ///////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, BiDiagonalizeUp_test1) { - auto matrix = NDArrayFactory::create( - 'c', {4, 4}, {9, 13, 3, 6, 13, 11, 7, 6, 3, 7, 4, 7, 6, 6, 7, 10}); - auto hhMatrixExp = NDArrayFactory::create( - 'c', {4, 4}, - {1.524000, 1.75682, 0.233741, 0.289458, 0.496646, 1.5655, 1.02929, - 0.971124, 0.114611, -0.451039, 1.06367, 0, 0.229221, -0.272237, 0.938237, - 0}); - auto hhBidiagExp = NDArrayFactory::create( - 'c', {4, 4}, - {-17.1756, 24.3869, 0, 0, 0, -8.61985, -3.89823, 0, 0, 0, 4.03047, - 4.13018, 0, 0, 0, 1.21666}); + auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6,13,11,7,6,3,7,4,7,6,6,7,10}); + auto hhMatrixExp = NDArrayFactory::create('c', {4,4}, {1.524000, 1.75682,0.233741,0.289458, 0.496646, 1.5655, 1.02929,0.971124, 0.114611,-0.451039, 1.06367,0, 0.229221,-0.272237,0.938237,0}); + auto hhBidiagExp = NDArrayFactory::create('c', {4,4}, {-17.1756, 24.3869, 0, 0, 0,-8.61985,-3.89823, 0, 0, 0, 4.03047,4.13018, 0, 0, 0,1.21666}); ops::helpers::BiDiagonalUp object(matrix); // object._HHmatrix.printBuffer(); @@ -173,22 +151,12 @@ TEST_F(HelpersTests1, BiDiagonalizeUp_test1) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, BiDiagonalizeUp_test2) { - auto matrix = NDArrayFactory::create( - 'c', {5, 4}, - {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, -6, 6, 7, 10, 2, 17, 9, 12}); - auto hhMatrixExp = NDArrayFactory::create( - 'c', {5, 4}, - {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979, - -0.444696, 0.114105, 0.130601, 1.58392, 0, -0.22821, 0.215638, - 0.0524781, 1.99303, 0.0760699, 0.375605, 0.509835, 0.0591568}); - auto hhBidiagExp = NDArrayFactory::create( - 'c', {4, 4}, - {-17.2916, 7.03123, 0, 0, 0, 16.145, -22.9275, 0, 0, 0, -9.9264, -11.5516, - 0, 0, 0, -12.8554}); + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto hhMatrixExp = NDArrayFactory::create('c', {5,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979,-0.444696, 0.114105,0.130601, 1.58392, 0, -0.22821, 0.215638,0.0524781, 1.99303, 0.0760699,0.375605, 0.509835,0.0591568}); + auto hhBidiagExp = NDArrayFactory::create('c', {4,4}, {-17.2916,7.03123, 0, 0, 0, 16.145,-22.9275, 0, 0, 0, -9.9264,-11.5516, 0, 0, 0,-12.8554}); ops::helpers::BiDiagonalUp object(matrix); - ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); @@ -198,19 +166,9 @@ TEST_F(HelpersTests1, BiDiagonalizeUp_test2) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, BiDiagonalizeUp_test3) { - auto matrix = NDArrayFactory::create( - 'c', {6, 4}, {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, - -6, 6, 7, 10, 2, 17, 9, 12, 0, -15, 10, 2}); - auto hhMatrixExp = NDArrayFactory::create( - 'c', {6, 4}, - {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, - 1.59666, -0.502606, 0.114105, 0.129651, 1.35075, 0, - -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, - 0.389936, 0.2398, 0, 0.0935171, -0.563777, 0.428587}); - auto hhBidiagExp = NDArrayFactory::create( - 'c', {4, 4}, - {-17.2916, 7.03123, 0, 0, 0, 16.3413, -20.7828, 0, 0, 0, -18.4892, - 4.13261, 0, 0, 0, -21.323}); + auto matrix = NDArrayFactory::create('c', {6,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12, 0,-15,10,2}); + auto hhMatrixExp = NDArrayFactory::create('c', {6,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, 1.59666,-0.502606, 0.114105, 0.129651, 1.35075, 0, -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, 0.389936, 0.2398, 0,0.0935171,-0.563777, 0.428587}); + auto hhBidiagExp = NDArrayFactory::create('c', {4,4}, {-17.2916,7.03123, 0, 0, 0,16.3413,-20.7828, 0, 0, 0,-18.4892,4.13261, 0, 0, 0,-21.323}); ops::helpers::BiDiagonalUp object(matrix); // object._HHmatrix.printBuffer(); @@ -224,23 +182,11 @@ TEST_F(HelpersTests1, BiDiagonalizeUp_test3) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test1) { - auto matrix = NDArrayFactory::create( - 'c', {5, 4}, - {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, -6, 6, 7, 10, 2, 17, 9, 12}); - auto vectorsUseqExp = NDArrayFactory::create( - 'c', {5, 4}, - {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979, - -0.444696, 0.114105, 0.130601, 1.58392, 0, -0.22821, 0.215638, - 0.0524781, 1.99303, 0.0760699, 0.375605, 0.509835, 0.0591568}); - auto vectorsVseqExp = NDArrayFactory::create( - 'c', {5, 4}, - {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979, - -0.444696, 0.114105, 0.130601, 1.58392, 0, -0.22821, 0.215638, - 0.0524781, 1.99303, 0.0760699, 0.375605, 0.509835, 0.0591568}); - auto coeffsUseqExp = NDArrayFactory::create( - 'c', {4, 1}, {1.52048, 1.66025, 1.58392, 1.99303}); - auto coeffsVseqExp = - NDArrayFactory::create('c', {3, 1}, {1.37012, 1.66979, 0}); + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto vectorsUseqExp = NDArrayFactory::create('c', {5,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979,-0.444696, 0.114105,0.130601, 1.58392, 0, -0.22821,0.215638,0.0524781, 1.99303, 0.0760699,0.375605, 0.509835,0.0591568}); + auto vectorsVseqExp = NDArrayFactory::create('c', {5,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979,-0.444696, 0.114105,0.130601, 1.58392, 0, -0.22821,0.215638,0.0524781, 1.99303, 0.0760699,0.375605, 0.509835,0.0591568}); + auto coeffsUseqExp = NDArrayFactory::create('c', {4,1}, {1.52048,1.66025,1.58392,1.99303}); + auto coeffsVseqExp = NDArrayFactory::create('c', {3,1}, {1.37012,1.66979,0}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); @@ -254,30 +200,17 @@ TEST_F(HelpersTests1, HHsequence_test1) { ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); ASSERT_TRUE(vSeq._shift == 1); ASSERT_TRUE(uSeq._shift == 0); + } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test2) { - auto matrix = NDArrayFactory::create( - 'c', {6, 4}, {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, - -6, 6, 7, 10, 2, 17, 9, 12, 0, -15, 10, 2}); - auto vectorsUseqExp = NDArrayFactory::create( - 'c', {6, 4}, - {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, - 1.59666, -0.502606, 0.114105, 0.129651, 1.35075, 0, - -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, - 0.389936, 0.2398, 0, 0.0935171, -0.563777, 0.428587}); - auto vectorsVseqExp = NDArrayFactory::create( - 'c', {6, 4}, - {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, - 1.59666, -0.502606, 0.114105, 0.129651, 1.35075, 0, - -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, - 0.389936, 0.2398, 0, 0.0935171, -0.563777, 0.428587}); - auto coeffsUseqExp = NDArrayFactory::create( - 'c', {4, 1}, {1.52048, 1.65232, 1.35075, 1.61136}); - auto coeffsVseqExp = - NDArrayFactory::create('c', {3, 1}, {1.37012, 1.59666, 0}); + auto matrix = NDArrayFactory::create('c', {6,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12 ,0,-15,10,2}); + auto vectorsUseqExp = NDArrayFactory::create('c', {6,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, 1.59666,-0.502606, 0.114105, 0.129651, 1.35075, 0, -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, 0.389936, 0.2398, 0,0.0935171,-0.563777, 0.428587}); + auto vectorsVseqExp = NDArrayFactory::create('c', {6,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, 1.59666,-0.502606, 0.114105, 0.129651, 1.35075, 0, -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, 0.389936, 0.2398, 0,0.0935171,-0.563777, 0.428587}); + auto coeffsUseqExp = NDArrayFactory::create('c', {4,1}, {1.52048,1.65232,1.35075,1.61136}); + auto coeffsVseqExp = NDArrayFactory::create('c', {3,1}, {1.37012,1.59666,0}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); @@ -291,25 +224,17 @@ TEST_F(HelpersTests1, HHsequence_test2) { ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); ASSERT_TRUE(vSeq._shift == 1); ASSERT_TRUE(uSeq._shift == 0); + } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test3) { - auto matrix = NDArrayFactory::create( - 'c', {4, 4}, {9, 13, 3, 6, 13, 11, 7, 6, 3, 7, 4, 7, 6, 6, 7, 10}); - auto vectorsUseqExp = NDArrayFactory::create( - 'c', {4, 4}, - {1.524, 1.75682, 0.233741, 0.289458, 0.496646, 1.5655, 1.02929, 0.971124, - 0.114611, -0.451039, 1.06367, 0, 0.229221, -0.272237, 0.938237, 0}); - auto vectorsVseqExp = NDArrayFactory::create( - 'c', {4, 4}, - {1.524, 1.75682, 0.233741, 0.289458, 0.496646, 1.5655, 1.02929, 0.971124, - 0.114611, -0.451039, 1.06367, 0, 0.229221, -0.272237, 0.938237, 0}); - auto coeffsUseqExp = - NDArrayFactory::create('c', {4, 1}, {1.524, 1.5655, 1.06367, 0}); - auto coeffsVseqExp = - NDArrayFactory::create('c', {3, 1}, {1.75682, 1.02929, 0}); + auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); + auto vectorsUseqExp = NDArrayFactory::create('c', {4,4}, {1.524, 1.75682,0.233741,0.289458, 0.496646, 1.5655, 1.02929,0.971124, 0.114611,-0.451039, 1.06367, 0, 0.229221,-0.272237,0.938237, 0}); + auto vectorsVseqExp = NDArrayFactory::create('c', {4,4}, {1.524, 1.75682,0.233741,0.289458, 0.496646, 1.5655, 1.02929,0.971124, 0.114611,-0.451039, 1.06367, 0, 0.229221,-0.272237,0.938237, 0}); + auto coeffsUseqExp = NDArrayFactory::create('c', {4,1}, { 1.524, 1.5655,1.06367,0}); + auto coeffsVseqExp = NDArrayFactory::create('c', {3,1}, {1.75682,1.02929, 0}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); @@ -323,77 +248,57 @@ TEST_F(HelpersTests1, HHsequence_test3) { ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); ASSERT_TRUE(vSeq._shift == 1); ASSERT_TRUE(uSeq._shift == 0); + } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test4) { - auto matrix = NDArrayFactory::create( - 'c', {4, 4}, {9, 13, 3, 6, 13, 11, 7, 6, 3, 7, 4, 7, 6, 6, 7, 10}); - auto exp = NDArrayFactory::create( - 'c', {4, 4}, - {2.49369, 2.62176, 5.88386, 7.69905, -16.0588, -18.7319, -9.15007, - -12.6164, 4.7247, 3.46252, 1.02038, -1.4533, 2.9279, -2.29178, 1.90139, - -0.66187}); + auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); + auto exp = NDArrayFactory::create('c', {4,4}, {2.49369, 2.62176, 5.88386, 7.69905, -16.0588,-18.7319,-9.15007,-12.6164, 4.7247, 3.46252, 1.02038, -1.4533, 2.9279,-2.29178, 1.90139,-0.66187}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); uSeq.mulLeft(matrix); ASSERT_TRUE(matrix.equalsTo(&exp)); + } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test5) { - auto matrix = NDArrayFactory::create( - 'c', {5, 4}, - {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, -6, 6, 7, 10, 2, 17, 9, 12}); - auto exp = NDArrayFactory::create( - 'c', {5, 4}, - {4.52891, 8.09473, -2.73704, -13.0302, -11.0752, 7.41549, -3.75125, - 0.815252, -7.76818, -15.9102, -9.90869, -11.8677, 1.63942, -17.0312, - -9.05102, -4.49088, -9.63311, 0.540226, -1.52764, 5.79111}); + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto exp = NDArrayFactory::create('c', {5,4}, {4.52891, 8.09473,-2.73704,-13.0302, -11.0752, 7.41549,-3.75125,0.815252, -7.76818,-15.9102,-9.90869,-11.8677, 1.63942,-17.0312,-9.05102,-4.49088, -9.63311,0.540226,-1.52764, 5.79111}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); uSeq.mulLeft(matrix); ASSERT_TRUE(matrix.equalsTo(&exp)); + } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test6) { - auto matrix = NDArrayFactory::create( - 'c', {5, 4}, - {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, -6, 6, 7, 10, 2, 17, 9, 12}); - auto matrix2 = NDArrayFactory::create( - 'c', {6, 4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, - -1, 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15}); - auto exp = NDArrayFactory::create( - 'c', {6, 4}, - {9, -1, 3, 9, -4.43019, -15.1713, - -3.2854, -7.65743, -9.39162, -7.03599, 8.03827, 9.48453, - -2.97785, -16.424, 5.35265, -20.1171, -0.0436177, -13.118, - -8.37287, -17.3012, -1.14074, 4.18282, -10.0914, -5.69014}); + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto exp = NDArrayFactory::create('c', {6,4}, {9,-1,3,9, -4.43019,-15.1713, -3.2854,-7.65743, -9.39162,-7.03599, 8.03827, 9.48453, -2.97785, -16.424, 5.35265,-20.1171, -0.0436177, -13.118,-8.37287,-17.3012, -1.14074, 4.18282,-10.0914,-5.69014}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); uSeq.mulLeft(matrix2); ASSERT_TRUE(matrix2.equalsTo(&exp)); + } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test7) { - auto matrix = NDArrayFactory::create( - 'c', {4, 4}, {9, 13, 3, 6, 13, 11, 7, 6, 3, 7, 4, 7, 6, 6, 7, 10}); - auto exp = NDArrayFactory::create( - 'c', {4, 4}, - {9, 13, 3, 6, -5.90424, -2.30926, -0.447417, 3.05712, -10.504, -9.31339, - -8.85493, -10.8886, -8.29494, -10.6737, -5.94895, -7.55591}); + auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); + auto exp = NDArrayFactory::create('c', {4,4}, {9,13,3,6,-5.90424,-2.30926,-0.447417, 3.05712, -10.504,-9.31339, -8.85493,-10.8886, -8.29494,-10.6737, -5.94895,-7.55591}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); @@ -405,14 +310,8 @@ TEST_F(HelpersTests1, HHsequence_test7) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test8) { - auto matrix = NDArrayFactory::create( - 'c', {5, 4}, - {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, -6, 6, 7, 10, 2, 17, 9, 12}); - auto exp = NDArrayFactory::create( - 'c', {5, 4}, - {9, -13, 3, 6, 13, 11, 7, - -6, -6.90831, -5.01113, 0.381677, 0.440128, -0.80107, 0.961605, - -0.308019, -1.96153, -0.795985, 18.6538, 12.0731, 16.9988}); + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto exp = NDArrayFactory::create('c', {5,4}, {9, -13, 3, 6, 13, 11, 7, -6, -6.90831,-5.01113, 0.381677,0.440128, -0.80107,0.961605,-0.308019,-1.96153, -0.795985, 18.6538, 12.0731, 16.9988}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); @@ -424,14 +323,8 @@ TEST_F(HelpersTests1, HHsequence_test8) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test9) { - auto matrix = NDArrayFactory::create( - 'c', {6, 4}, {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, - -6, 6, 7, 10, 2, 17, 9, 12, 0, -15, 10, 2}); - auto exp = NDArrayFactory::create( - 'c', {6, 4}, {9, -13, 3, 6, 13, 11, - 7, -6, 3, 7, 4, 7, - 3.77597, 18.6226, -0.674868, 4.61365, 5.02738, -14.1486, - -2.22877, -8.98245, -0.683766, 1.73722, 14.9859, 12.0843}); + auto matrix = NDArrayFactory::create('c', {6,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12 ,0,-15,10,2}); + auto exp = NDArrayFactory::create('c', {6,4}, {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, 3.77597, 18.6226,-0.674868, 4.61365, 5.02738,-14.1486, -2.22877,-8.98245, -0.683766, 1.73722, 14.9859, 12.0843}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); @@ -443,17 +336,9 @@ TEST_F(HelpersTests1, HHsequence_test9) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test10) { - auto matrix = NDArrayFactory::create( - 'c', {4, 4}, {9, 13, 3, 6, 13, 11, 7, 6, 3, 7, 4, 7, 6, 6, 7, 10}); - auto matrix2 = NDArrayFactory::create( - 'c', {6, 4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, - -1, 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15}); - auto exp = NDArrayFactory::create( - 'c', {6, 4}, - {9, -1, 3, 9, 10, 11, - -7, -5, 3, 2, 4, 7, - 2.58863, 11.0295, -4.17483, -0.641012, -1.21892, -16.3151, - 6.12049, -20.0239, -0.901799, -15.0389, -12.4944, -20.2394}); + auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); + auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto exp = NDArrayFactory::create('c', {6,4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, 2.58863, 11.0295,-4.17483,-0.641012, -1.21892,-16.3151, 6.12049, -20.0239, -0.901799,-15.0389,-12.4944, -20.2394}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); @@ -465,17 +350,9 @@ TEST_F(HelpersTests1, HHsequence_test10) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test11) { - auto matrix = NDArrayFactory::create( - 'c', {5, 4}, - {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, -6, 6, 7, 10, 2, 17, 9, 12}); - auto matrix2 = NDArrayFactory::create( - 'c', {6, 4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, - -1, 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15}); - auto exp = NDArrayFactory::create( - 'c', {6, 4}, {9, -1, 3, 9, 10, 11, - -7, -5, 3, 2, 4, 7, - 1.14934, 4.40257, 8.70127, -1.18824, 1.5132, 0.220419, - -11.6285, -11.7549, 2.32148, 24.3838, 0.256531, 25.9116}); + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto exp = NDArrayFactory::create('c', {6,4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, 1.14934, 4.40257, 8.70127,-1.18824, 1.5132,0.220419,-11.6285,-11.7549, 2.32148, 24.3838,0.256531, 25.9116}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); @@ -487,16 +364,9 @@ TEST_F(HelpersTests1, HHsequence_test11) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test12) { - auto matrix = NDArrayFactory::create( - 'c', {5, 3}, {9, -13, 3, 13, 11, 7, 3, 7, 4, -6, 6, 7, 2, 17, 9}); - auto matrix2 = NDArrayFactory::create( - 'c', {6, 4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, - -1, 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15}); - auto exp = NDArrayFactory::create( - 'c', {6, 4}, {9, -1, 3, 9, 10, 11, - -7, -5, 3, 2, 4, 7, - -1, 6, 7, 19, -2.62252, -22.2914, - 4.76743, -19.6689, -1.05943, -9.00514, -11.8013, -7.94571}); + auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); + auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto exp = NDArrayFactory::create('c', {6,4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, 6, 7, 19, -2.62252,-22.2914, 4.76743,-19.6689, -1.05943,-9.00514,-11.8013,-7.94571}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); @@ -508,17 +378,9 @@ TEST_F(HelpersTests1, HHsequence_test12) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test13) { - auto matrix = NDArrayFactory::create( - 'c', {5, 3}, {9, -13, 3, 13, 11, 7, 3, 7, 4, -6, 6, 7, 2, 17, 9}); - auto matrix2 = NDArrayFactory::create( - 'c', {6, 4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, - -1, 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15}); - auto exp = NDArrayFactory::create( - 'c', {6, 4}, - {9, -1, 3, 9, -4.65167, 3.44652, - 7.83593, 22.6899, -9.48514, -21.902, 5.66559, -13.0533, - -0.343184, 15.2895, 7.2888, 14.0489, 0.289638, -1.87752, - 3.944, -1.49707, -2.48845, 3.18285, -10.6685, 0.406502}); + auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); + auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto exp = NDArrayFactory::create('c', {6,4}, {9 , -1 , 3 , 9, -4.65167, 3.44652, 7.83593, 22.6899, -9.48514, -21.902, 5.66559,-13.0533, -0.343184, 15.2895, 7.2888, 14.0489, 0.289638,-1.87752, 3.944,-1.49707, -2.48845, 3.18285,-10.6685,0.406502}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); @@ -530,17 +392,9 @@ TEST_F(HelpersTests1, HHsequence_test13) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test14) { - auto matrix = NDArrayFactory::create( - 'c', {5, 3}, {9, -13, 3, 13, 11, 7, 3, 7, 4, -6, 6, 7, 2, 17, 9}); - auto matrix2 = NDArrayFactory::create( - 'c', {5, 5}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, - 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15, 2}); - auto exp = NDArrayFactory::create( - 'c', {5, 5}, - {1.78958, 8.06962, -6.13687, 4.36267, 1.06472, -14.9578, -8.1522, - 1.30442, -18.3343, -13.2578, 13.5536, 5.50764, 15.7859, 7.60831, - 11.7871, -1.3626, -0.634986, 7.60934, -2.1841, 5.62694, -13.0577, - 15.1554, -7.6511, 3.76365, -5.87368}); + auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); + auto matrix2 = NDArrayFactory::create('c',{5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); + auto exp = NDArrayFactory::create('c', {5,5}, {1.78958, 8.06962,-6.13687, 4.36267, 1.06472, -14.9578, -8.1522, 1.30442,-18.3343,-13.2578, 13.5536, 5.50764, 15.7859, 7.60831, 11.7871, -1.3626,-0.634986, 7.60934, -2.1841, 5.62694, -13.0577, 15.1554, -7.6511, 3.76365,-5.87368}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); @@ -549,20 +403,13 @@ TEST_F(HelpersTests1, HHsequence_test14) { ASSERT_TRUE(matrix2.equalsTo(&exp)); } + /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test15) { - auto matrix = NDArrayFactory::create( - 'c', {5, 3}, {9, -13, 3, 13, 11, 7, 3, 7, 4, -6, 6, 7, 2, 17, 9}); - auto matrix2 = NDArrayFactory::create( - 'c', {5, 5}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, - 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15, 2}); - auto exp = NDArrayFactory::create( - 'c', {5, 5}, - {9, -1, 3, 9, 10, 11, -7, - -5, 3, 2, 4, 7, -1, 6, - 7, -9.26566, -16.4298, 1.64125, -17.3243, -7.70257, -16.7077, - 4.80216, -19.1652, -2.42279, -13.0258}); + auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); + auto matrix2 = NDArrayFactory::create('c',{5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); + auto exp = NDArrayFactory::create('c', {5,5}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, 6, 7, -9.26566,-16.4298, 1.64125,-17.3243,-7.70257, -16.7077, 4.80216,-19.1652,-2.42279,-13.0258}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); @@ -574,17 +421,10 @@ TEST_F(HelpersTests1, HHsequence_test15) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test16) { - auto matrix = NDArrayFactory::create( - 'c', {5, 5}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, - 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15, 2}); - auto matrix2 = NDArrayFactory::create('c', {10, 10}); + auto matrix = NDArrayFactory::create('c', {5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); + auto matrix2 = NDArrayFactory::create('c', {10,10}); matrix2 = 100.; - auto exp = NDArrayFactory::create( - 'c', {5, 5}, {-0.372742, 0.295145, 0.325359, 0.790947, 0.20615, - -0.455573, -0.824221, -0.239444, 0.216163, -0.0951492, - -0.165663, 0.285319, -0.18501, 0.130431, -0.916465, - -0.7869, 0.245393, 0.116952, -0.541267, 0.117997, - -0.0828315, 0.303191, -0.888202, 0.133021, 0.3076}); + auto exp = NDArrayFactory::create('c',{5,5}, {-0.372742, 0.295145, 0.325359, 0.790947, 0.20615, -0.455573,-0.824221,-0.239444, 0.216163,-0.0951492, -0.165663, 0.285319, -0.18501, 0.130431, -0.916465, -0.7869, 0.245393, 0.116952,-0.541267, 0.117997, -0.0828315, 0.303191,-0.888202, 0.133021, 0.3076}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); @@ -596,17 +436,10 @@ TEST_F(HelpersTests1, HHsequence_test16) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test17) { - auto matrix = NDArrayFactory::create( - 'c', {5, 5}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, - 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15, 2}); - auto matrix2 = NDArrayFactory::create('c', {10, 10}); + auto matrix = NDArrayFactory::create('c', {5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); + auto matrix2 = NDArrayFactory::create('c', {10,10}); matrix2 = 100.; - auto exp = NDArrayFactory::create( - 'c', {5, 5}, {1, 0, 0, 0, 0, - 0, -0.022902, 0.986163, 0.0411914, 0.158935, - 0, -0.44659, 0.021539, 0.797676, -0.404731, - 0, -0.554556, 0.103511, -0.600701, -0.56649, - 0, -0.701784, -0.127684, -0.0342758, 0.700015}); + auto exp = NDArrayFactory::create('c',{5,5}, {1, 0, 0, 0, 0, 0,-0.022902, 0.986163, 0.0411914, 0.158935, 0, -0.44659, 0.021539, 0.797676,-0.404731, 0,-0.554556, 0.103511, -0.600701, -0.56649, 0,-0.701784,-0.127684,-0.0342758, 0.700015}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); @@ -618,19 +451,10 @@ TEST_F(HelpersTests1, HHsequence_test17) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test18) { - auto matrix = NDArrayFactory::create( - 'c', {6, 4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, - -1, 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15}); - auto matrix2 = NDArrayFactory::create('c', {10, 10}); + auto matrix = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto matrix2 = NDArrayFactory::create('c', {10,10}); matrix2 = 100.; - auto exp = NDArrayFactory::create( - 'c', {6, 6}, - {-0.637993, 0.190621, -0.524821, -0.312287, 0.407189, 0.133659, - -0.708881, 0.0450803, 0.47462, 0.232701, -0.204602, -0.417348, - -0.212664, -0.0405892, -0.297123, 0.0240276, -0.821557, 0.435099, - 0.0708881, -0.432466, -0.49252, -0.145004, -0.199312, -0.710367, - -0.141776, -0.56468, -0.180549, 0.706094, 0.274317, 0.233707, - -0.141776, -0.673865, 0.368567, -0.572848, 0.0490246, 0.243733}); + auto exp = NDArrayFactory::create('c',{6,6}, {-0.637993, 0.190621,-0.524821,-0.312287, 0.407189, 0.133659, -0.708881, 0.0450803, 0.47462, 0.232701,-0.204602,-0.417348, -0.212664,-0.0405892,-0.297123,0.0240276,-0.821557, 0.435099, 0.0708881, -0.432466, -0.49252,-0.145004,-0.199312,-0.710367, -0.141776, -0.56468,-0.180549, 0.706094, 0.274317, 0.233707, -0.141776, -0.673865, 0.368567,-0.572848,0.0490246, 0.243733}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); @@ -642,15 +466,10 @@ TEST_F(HelpersTests1, HHsequence_test18) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHsequence_test19) { - auto matrix = NDArrayFactory::create( - 'c', {6, 4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, - -1, 6, 7, 19, 2, 17, 9, 15, 2, 17, -9, 15}); - auto matrix2 = NDArrayFactory::create('c', {10, 10}); + auto matrix = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto matrix2 = NDArrayFactory::create('c', {10,10}); matrix2 = 100.; - auto exp = NDArrayFactory::create( - 'c', {4, 4}, - {1, 0, 0, 0, 0, -0.859586, 0.28601, -0.42345, 0, 0.19328, -0.585133, - -0.787567, 0, -0.473027, -0.758826, 0.447693}); + auto exp = NDArrayFactory::create('c',{4,4}, {1, 0, 0, 0, 0,-0.859586, 0.28601, -0.42345, 0, 0.19328,-0.585133,-0.787567, 0,-0.473027,-0.758826, 0.447693}); ops::helpers::BiDiagonalUp object(matrix); ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); @@ -662,61 +481,61 @@ TEST_F(HelpersTests1, HHsequence_test19) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHcolPivQR_1) { - auto matrix1 = NDArrayFactory::create('c', {5,6}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix1 = NDArrayFactory::create('c', {5,6}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - auto expQR = NDArrayFactory::create('c', {5,6}, {-32.6649659, -4.9594419, -8.2657365, 7.2248659, 16.5927006, 11.7251002, -0.1354883, -29.0586293, 10.9775804, -14.6886248, 4.1884104, 20.7115773, 0.3483986, 0.3236753, 25.5376258, 1.6432380, 9.6395914, -9.0237996, -0.0580664, 0.0798999, -0.0799029, 19.5280665, -4.9773587, 16.0968604, 0.3483986, -0.6667832, 0.0252425, 0.0159188, 10.6978354, -4.6919842}); - auto expCoeffs = NDArrayFactory::create('c', {1,5}, {1.58166, 1.28555, 1.98605, 1.99949, 0}); - auto expPermut = NDArrayFactory::create('c', {6,6}, {0,1,0,0,0,0, 0,0,1,0,0,0, 1,0,0,0,0,0, 0,0,0,0,0,1, 0,0,0,0,1,0, 0,0,0,1,0,0}); + auto expQR = NDArrayFactory::create('c', {5,6}, {-32.6649659, -4.9594419, -8.2657365, 7.2248659, 16.5927006, 11.7251002, -0.1354883, -29.0586293, 10.9775804, -14.6886248, 4.1884104, 20.7115773, 0.3483986, 0.3236753, 25.5376258, 1.6432380, 9.6395914, -9.0237996, -0.0580664, 0.0798999, -0.0799029, 19.5280665, -4.9773587, 16.0968604, 0.3483986, -0.6667832, 0.0252425, 0.0159188, 10.6978354, -4.6919842}); + auto expCoeffs = NDArrayFactory::create('c', {1,5}, {1.58166, 1.28555, 1.98605, 1.99949, 0}); + auto expPermut = NDArrayFactory::create('c', {6,6}, {0,1,0,0,0,0, 0,0,1,0,0,0, 1,0,0,0,0,0, 0,0,0,0,0,1, 0,0,0,0,1,0, 0,0,0,1,0,0}); - ops::helpers::HHcolPivQR qr(matrix1); + ops::helpers::HHcolPivQR qr(matrix1); - ASSERT_TRUE(expQR.equalsTo(&qr._qr)); - ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); - ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); + ASSERT_TRUE(expQR.equalsTo(&qr._qr)); + ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); + ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); - ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); - ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); - ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHcolPivQR_2) { - auto matrix1 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix1 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); - auto expQR = NDArrayFactory::create('c', {6,6}, {38.1707, -3.03898, 5.16103, 23.0805, -7.57126, -13.885, -0.41519, 34.3623, 3.77403, 2.62327, -8.17784, 9.10312, 0.394431, 0.509952,-30.2179, -6.78341, 12.8421, 28.5491, -0.290633, 0.111912,0.450367, 28.1139, 15.5195, 2.60562, 0.332152, 0.405161,0.308163,0.0468127, 22.294,-2.94931, 0.249114,0.0627956,0.657873, 0.76767,-0.752594,-7.46986}); - auto expCoeffs = NDArrayFactory::create('c', {1,6}, {1.26198, 1.38824, 1.15567, 1.25667, 1.27682, 0}); - auto expPermut = NDArrayFactory::create('c', {6,6}, {0,0,1,0,0,0, 0,0,0,0,1,0, 0,0,0,1,0,0, 0,1,0,0,0,0, 0,0,0,0,0,1, 1,0,0,0,0,0}); + auto expQR = NDArrayFactory::create('c', {6,6}, {38.1707, -3.03898, 5.16103, 23.0805, -7.57126, -13.885, -0.41519, 34.3623, 3.77403, 2.62327, -8.17784, 9.10312, 0.394431, 0.509952,-30.2179, -6.78341, 12.8421, 28.5491, -0.290633, 0.111912,0.450367, 28.1139, 15.5195, 2.60562, 0.332152, 0.405161,0.308163,0.0468127, 22.294,-2.94931, 0.249114,0.0627956,0.657873, 0.76767,-0.752594,-7.46986}); + auto expCoeffs = NDArrayFactory::create('c', {1,6}, {1.26198, 1.38824, 1.15567, 1.25667, 1.27682, 0}); + auto expPermut = NDArrayFactory::create('c', {6,6}, {0,0,1,0,0,0, 0,0,0,0,1,0, 0,0,0,1,0,0, 0,1,0,0,0,0, 0,0,0,0,0,1, 1,0,0,0,0,0}); - ops::helpers::HHcolPivQR qr(matrix1); + ops::helpers::HHcolPivQR qr(matrix1); - ASSERT_TRUE(expQR.equalsTo(&qr._qr)); - ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); - ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); + ASSERT_TRUE(expQR.equalsTo(&qr._qr)); + ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); + ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); - ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); - ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); - ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, HHcolPivQR_3) { - NDArray matrix1('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + NDArray matrix1('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); - auto expQR = NDArrayFactory::create('c', {6,5}, {-37.054 , 0.323852 , 8.04231 , -22.9395 ,-13.089, 0.105164, 32.6021, 6.42277, -0.262898,-1.58766, 0.140218, -0.485058, 29.2073, -9.92301,-23.7111, -0.262909,-0.00866538, 0.103467, 8.55831,-1.86455, -0.315491, 0.539207, 0.40754,-0.0374124,-7.10401, 0.315491, 0.385363,-0.216459, -0.340008,0.628595}); - auto expCoeffs = NDArrayFactory::create('c', {1,5}, {1.53975, 1.19431, 1.63446, 1.7905, 1.43356}); - auto expPermut = NDArrayFactory::create('c', {5,5}, {0,0,0,1,0, 1,0,0,0,0, 0,0,0,0,1, 0,0,1,0,0, 0,1,0,0,0}); + auto expQR = NDArrayFactory::create('c', {6,5}, {-37.054 , 0.323852 , 8.04231 , -22.9395 ,-13.089, 0.105164, 32.6021, 6.42277, -0.262898,-1.58766, 0.140218, -0.485058, 29.2073, -9.92301,-23.7111, -0.262909,-0.00866538, 0.103467, 8.55831,-1.86455, -0.315491, 0.539207, 0.40754,-0.0374124,-7.10401, 0.315491, 0.385363,-0.216459, -0.340008,0.628595}); + auto expCoeffs = NDArrayFactory::create('c', {1,5}, {1.53975, 1.19431, 1.63446, 1.7905, 1.43356}); + auto expPermut = NDArrayFactory::create('c', {5,5}, {0,0,0,1,0, 1,0,0,0,0, 0,0,0,0,1, 0,0,1,0,0, 0,1,0,0,0}); - ops::helpers::HHcolPivQR qr(matrix1); + ops::helpers::HHcolPivQR qr(matrix1); - ASSERT_TRUE(expQR.equalsTo(&qr._qr)); - ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); - ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); + ASSERT_TRUE(expQR.equalsTo(&qr._qr)); + ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); + ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); - ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); - ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); - ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); } @@ -724,304 +543,293 @@ TEST_F(HelpersTests1, HHcolPivQR_3) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test1) { - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto left = NDArrayFactory::create('c', {2,2}); - auto right = NDArrayFactory::create('c', {2,2}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto left = NDArrayFactory::create('c', {2,2}); + auto right = NDArrayFactory::create('c', {2,2}); - auto expLeft = NDArrayFactory::create('c', {2,2}, {0.972022, 0.23489, -0.23489, 0.972022}); - auto expRight = NDArrayFactory::create('c', {2,2}, {0.827657, 0.561234, -0.561234, 0.827657}); + auto expLeft = NDArrayFactory::create('c', {2,2}, {0.972022, 0.23489, -0.23489, 0.972022}); + auto expRight = NDArrayFactory::create('c', {2,2}, {0.827657, 0.561234, -0.561234, 0.827657}); - ops::helpers::JacobiSVD::svd2x2(matrix3, 1, 3, left, right); + ops::helpers::JacobiSVD::svd2x2(matrix3, 1, 3, left, right); - ASSERT_TRUE(expLeft.equalsTo(&left)); - ASSERT_TRUE(expRight.equalsTo(&right)); + ASSERT_TRUE(expLeft.equalsTo(&left)); + ASSERT_TRUE(expRight.equalsTo(&right)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test2) { - auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto matrix4 = NDArrayFactory::create('c', {5,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19}); - auto matrix5 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix4 = NDArrayFactory::create('c', {5,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19}); + auto matrix5 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); - auto exp3 = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, -0.609208,19.6977, 8.63044,-11.9811,-4.67059, -2, -11, 8, 2, -6, 3.55371, 0,-12.5903, 7.51356, -5.5844, 16, 15, -3, 7, 0}); - auto exp4 = NDArrayFactory::create('c', {5,5}, {12, -10.9657,19,24.5714, -6, 3, -2.6399, 2,8.83351, -7, 14,-0.406138,18,18.7839, 18, -14, 12.8949, 1,-7.9197, 2, -3, 23.353, 8, 8.2243,-19}); - auto exp5 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); + auto exp3 = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, -0.609208,19.6977, 8.63044,-11.9811,-4.67059, -2, -11, 8, 2, -6, 3.55371, 0,-12.5903, 7.51356, -5.5844, 16, 15, -3, 7, 0}); + auto exp4 = NDArrayFactory::create('c', {5,5}, {12, -10.9657,19,24.5714, -6, 3, -2.6399, 2,8.83351, -7, 14,-0.406138,18,18.7839, 18, -14, 12.8949, 1,-7.9197, 2, -3, 23.353, 8, 8.2243,-19}); + auto exp5 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); - ops::helpers::JacobiSVD jac(matrix3, true, true, true); - jac._m = matrix3; - jac._u = matrix4; - jac._v = matrix5; + ops::helpers::JacobiSVD jac(matrix3, true, true, true); + jac._m = matrix3; + jac._u = matrix4; + jac._v = matrix5; - double maxElem; - bool result = jac.isBlock2x2NotDiag(matrix3, 1, 3, maxElem); + double maxElem; + bool result = jac.isBlock2x2NotDiag(matrix3, 1, 3, maxElem); - // ASSERT_NEAR(maxElem, 19.69772, 1e-5); - ASSERT_TRUE(exp3.equalsTo(&matrix3)); - ASSERT_TRUE(exp4.equalsTo(&jac._u)); - ASSERT_TRUE(exp5.equalsTo(&jac._v)); + // ASSERT_NEAR(maxElem, 19.69772, 1e-5); + ASSERT_TRUE(exp3.equalsTo(&matrix3)); + ASSERT_TRUE(exp4.equalsTo(&jac._u)); + ASSERT_TRUE(exp5.equalsTo(&jac._v)); - ASSERT_TRUE(result); + ASSERT_TRUE(result); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test3) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, -1.14919,-12.1206,3.59677, 4.34919,-4.24758, -1.94919, 11.7427,11.6698,-10.4444,-2.74919, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, -1.14919,-12.1206,3.59677, 4.34919,-4.24758, -1.94919, 11.7427,11.6698,-10.4444,-2.74919, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnLeft(1, 2, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnLeft(1, 2, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test4) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 1.94919, 4.92056,-8.79677,1.25081, 5.04758, 1.14919,-16.1427,-8.46976,11.2444,0.349193, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 1.94919, 4.92056,-8.79677,1.25081, 5.04758, 1.14919,-16.1427,-8.46976,11.2444,0.349193, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnLeft(2, 1, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnLeft(2, 1, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test5) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, 1.14919,6.32056,-4.59677,-1.14919, 3.44758, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, 1.14919,6.32056,-4.59677,-1.14919, 3.44758, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnLeft(2, 2, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnLeft(2, 2, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test6) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - auto expected = NDArrayFactory::create('c', {5,5}, {-18,-14.5173, 4.5746,-7, 1, 2, 6.46976,-16.5427,14, 2, -2,-8.39677,-6.92056, 2,-6, -3,-7.79677,-4.59677,-2, 7, 16, 5.32379, 11.019, 7, 0}); + auto expected = NDArrayFactory::create('c', {5,5}, {-18,-14.5173, 4.5746,-7, 1, 2, 6.46976,-16.5427,14, 2, -2,-8.39677,-6.92056, 2,-6, -3,-7.79677,-4.59677,-2, 7, 16, 5.32379, 11.019, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnRight(1, 2, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnRight(1, 2, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test7) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - auto expected = NDArrayFactory::create('c', {5,5}, {-18, 14.9173, 3.0254,-7, 1, 2,-13.6698,11.3427,14, 2, -2, 3.99677,10.1206, 2,-6, -3, 4.59677,7.79677,-2, 7, 16, 0.67621,-12.219, 7, 0}); + auto expected = NDArrayFactory::create('c', {5,5}, {-18, 14.9173, 3.0254,-7, 1, 2,-13.6698,11.3427,14, 2, -2, 3.99677,10.1206, 2,-6, -3, 4.59677,7.79677,-2, 7, 16, 0.67621,-12.219, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnRight(2, 1, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnRight(2, 1, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } ////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test8) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); - auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 18.5173,-7, 1, 2,-18,-12.6698,14, 2, -2,-11, 7.79677, 2,-6, -3, -8, 7.79677,-2, 7, 16, 15,-2.92379, 7, 0}); + auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 18.5173,-7, 1, 2,-18,-12.6698,14, 2, -2,-11, 7.79677, 2,-6, -3, -8, 7.79677,-2, 7, 16, 15,-2.92379, 7, 0}); - ops::helpers::JacobiSVD::mulRotationOnRight(2, 2, matrix, rotation); + ops::helpers::JacobiSVD::mulRotationOnRight(2, 2, matrix, rotation); - ASSERT_TRUE(expected.equalsTo(&matrix)); + ASSERT_TRUE(expected.equalsTo(&matrix)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test9) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expS = NDArrayFactory::create('c', {5,1}, {35.7975, 29.1924, 11.1935, 9.2846, 6.77071}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.744855,0.0686476, 0.079663,0.0889877, 0.65285, -0.386297,-0.760021,0.00624688, 0.156774, 0.498522, 0.186491,-0.322427, 0.773083,-0.468826,-0.209299, 0.246053,-0.215594, 0.240942, 0.821793,-0.399475, -0.447933, 0.516928, 0.581295, 0.269001, 0.349106}); - auto expV = NDArrayFactory::create('c', {5,5}, {-0.627363, 0.23317, 0.501211, 0.160272, -0.524545, -0.0849394, 0.917171,-0.155876,-0.0124053, 0.356555, 0.66983, 0.182569, 0.696897, 0.179807,0.000864568, -0.387647, -0.264316, 0.416597, 0.0941014, 0.772955, 0.0160818,-0.0351459,-0.255484, 0.965905, 0.0161524}); + auto expS = NDArrayFactory::create('c', {5,1}, {35.7975, 29.1924, 11.1935, 9.2846, 6.77071}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.744855,0.0686476, 0.079663,0.0889877, 0.65285, -0.386297,-0.760021,0.00624688, 0.156774, 0.498522, 0.186491,-0.322427, 0.773083,-0.468826,-0.209299, 0.246053,-0.215594, 0.240942, 0.821793,-0.399475, -0.447933, 0.516928, 0.581295, 0.269001, 0.349106}); + auto expV = NDArrayFactory::create('c', {5,5}, {-0.627363, 0.23317, 0.501211, 0.160272, -0.524545, -0.0849394, 0.917171,-0.155876,-0.0124053, 0.356555, 0.66983, 0.182569, 0.696897, 0.179807,0.000864568, -0.387647, -0.264316, 0.416597, 0.0941014, 0.772955, 0.0160818,-0.0351459,-0.255484, 0.965905, 0.0161524}); - ops::helpers::JacobiSVD jac(matrix, true, true, true); + ops::helpers::JacobiSVD jac(matrix, true, true, true); - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test10) { - auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expS = NDArrayFactory::create('c', {5,1}, {35.7975, 29.1924, 11.1935, 9.2846, 6.77071}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.744855,0.0686476, 0.079663,0.0889877, 0.65285, -0.386297,-0.760021,0.00624688, 0.156774, 0.498522, 0.186491,-0.322427, 0.773083,-0.468826,-0.209299, 0.246053,-0.215594, 0.240942, 0.821793,-0.399475, -0.447933, 0.516928, 0.581295, 0.269001, 0.349106}); - auto expV = NDArrayFactory::create('c', {5,5}, {-0.627363, 0.23317, 0.501211, 0.160272, -0.524545, -0.0849394, 0.917171,-0.155876,-0.0124053, 0.356555, 0.66983, 0.182569, 0.696897, 0.179807,0.000864568, -0.387647, -0.264316, 0.416597, 0.0941014, 0.772955, 0.0160818,-0.0351459,-0.255484, 0.965905, 0.0161524}); + auto expS = NDArrayFactory::create('c', {5,1}, {35.7975, 29.1924, 11.1935, 9.2846, 6.77071}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.744855,0.0686476, 0.079663,0.0889877, 0.65285, -0.386297,-0.760021,0.00624688, 0.156774, 0.498522, 0.186491,-0.322427, 0.773083,-0.468826,-0.209299, 0.246053,-0.215594, 0.240942, 0.821793,-0.399475, -0.447933, 0.516928, 0.581295, 0.269001, 0.349106}); + auto expV = NDArrayFactory::create('c', {5,5}, {-0.627363, 0.23317, 0.501211, 0.160272, -0.524545, -0.0849394, 0.917171,-0.155876,-0.0124053, 0.356555, 0.66983, 0.182569, 0.696897, 0.179807,0.000864568, -0.387647, -0.264316, 0.416597, 0.0941014, 0.772955, 0.0160818,-0.0351459,-0.255484, 0.965905, 0.0161524}); - ops::helpers::JacobiSVD jac(matrix, true, true, false); + ops::helpers::JacobiSVD jac(matrix, true, true, false); - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test11) { - auto matrix = NDArrayFactory::create('c', {6,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); + auto matrix = NDArrayFactory::create('c', {6,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); - auto expS = NDArrayFactory::create('c', {5,1}, {36.27, 32.1997, 15.9624, 10.6407, 6.9747}); - auto expU = NDArrayFactory::create('c', {6,5}, {0.720125,-0.149734, 0.227784,-0.0288531, 0.595353, -0.509487,-0.567298, -0.237169,-0.0469077, 0.38648, 0.120912, -0.32916,-0.0202265, 0.921633, -0.153994, 0.180033,-0.294831, 0.357867, -0.194106, -0.646595, -0.354033, 0.521937, 0.556566, 0.305582, 0.211013, -0.222425,-0.433662, 0.673515, -0.128465, 0.099309}); - auto expV = NDArrayFactory::create('c', {5,5}, {-0.581609, 0.315327,0.333158, 0.34476, -0.576582, 0.117364, 0.889461,0.175174,-0.166603, 0.369651, 0.643246,-0.0899117,0.613288, 0.442462,-0.0790943, -0.480818, -0.264384,0.395122, 0.223126, 0.702145, -0.0548207, -0.177325,0.571031,-0.779632, -0.1779}); + auto expS = NDArrayFactory::create('c', {5,1}, {36.27, 32.1997, 15.9624, 10.6407, 6.9747}); + auto expU = NDArrayFactory::create('c', {6,5}, {0.720125,-0.149734, 0.227784,-0.0288531, 0.595353, -0.509487,-0.567298, -0.237169,-0.0469077, 0.38648, 0.120912, -0.32916,-0.0202265, 0.921633, -0.153994, 0.180033,-0.294831, 0.357867, -0.194106, -0.646595, -0.354033, 0.521937, 0.556566, 0.305582, 0.211013, -0.222425,-0.433662, 0.673515, -0.128465, 0.099309}); + auto expV = NDArrayFactory::create('c', {5,5}, {-0.581609, 0.315327,0.333158, 0.34476, -0.576582, 0.117364, 0.889461,0.175174,-0.166603, 0.369651, 0.643246,-0.0899117,0.613288, 0.442462,-0.0790943, -0.480818, -0.264384,0.395122, 0.223126, 0.702145, -0.0548207, -0.177325,0.571031,-0.779632, -0.1779}); - ops::helpers::JacobiSVD jac(matrix, true, true, false); + ops::helpers::JacobiSVD jac(matrix, true, true, false); - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test12) { - auto matrix = NDArrayFactory::create('c', {6,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); + auto matrix = NDArrayFactory::create('c', {6,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); - auto expS = NDArrayFactory::create('c', {5,1}, {36.27, 32.1997, 15.9624, 10.6407, 6.9747}); - auto expU = NDArrayFactory::create('c', {6,6}, {0.720125,-0.149734, 0.227784,-0.0288531, 0.595353,-0.227676, -0.509487,-0.567298, -0.237169,-0.0469077, 0.38648,-0.459108, 0.120912, -0.32916,-0.0202265, 0.921633,-0.153994,0.0591992, 0.180033,-0.294831, 0.357867, -0.194106,-0.646595,-0.544823, -0.354033, 0.521937, 0.556566, 0.305582, 0.211013,-0.393155, -0.222425,-0.433662, 0.673515, -0.128465, 0.099309, 0.531485}); - auto expV = NDArrayFactory::create('c', {5,5}, {-0.581609, 0.315327,0.333158, 0.34476, -0.576582, 0.117364, 0.889461,0.175174,-0.166603, 0.369651, 0.643246,-0.0899117,0.613288, 0.442462,-0.0790943, -0.480818, -0.264384,0.395122, 0.223126, 0.702145, -0.0548207, -0.177325,0.571031,-0.779632, -0.1779}); + auto expS = NDArrayFactory::create('c', {5,1}, {36.27, 32.1997, 15.9624, 10.6407, 6.9747}); + auto expU = NDArrayFactory::create('c', {6,6}, {0.720125,-0.149734, 0.227784,-0.0288531, 0.595353,-0.227676, -0.509487,-0.567298, -0.237169,-0.0469077, 0.38648,-0.459108, 0.120912, -0.32916,-0.0202265, 0.921633,-0.153994,0.0591992, 0.180033,-0.294831, 0.357867, -0.194106,-0.646595,-0.544823, -0.354033, 0.521937, 0.556566, 0.305582, 0.211013,-0.393155, -0.222425,-0.433662, 0.673515, -0.128465, 0.099309, 0.531485}); + auto expV = NDArrayFactory::create('c', {5,5}, {-0.581609, 0.315327,0.333158, 0.34476, -0.576582, 0.117364, 0.889461,0.175174,-0.166603, 0.369651, 0.643246,-0.0899117,0.613288, 0.442462,-0.0790943, -0.480818, -0.264384,0.395122, 0.223126, 0.702145, -0.0548207, -0.177325,0.571031,-0.779632, -0.1779}); - ops::helpers::JacobiSVD jac(matrix, true, true, true); + ops::helpers::JacobiSVD jac(matrix, true, true, true); - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test13) { - auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); + auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); - auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); - auto expV = NDArrayFactory::create('c', {6,6}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, 0.53571, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079,-0.556052, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.431988, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339,-0.165176, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, 0.368038, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387, 0.233392}); + auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); + auto expV = NDArrayFactory::create('c', {6,6}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, 0.53571, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079,-0.556052, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.431988, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339,-0.165176, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, 0.368038, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387, 0.233392}); - ops::helpers::JacobiSVD jac(matrix, true, true, true); + ops::helpers::JacobiSVD jac(matrix, true, true, true); - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test14) { - auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); + auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); - auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); - auto expV = NDArrayFactory::create('c', {6,5}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387}); + auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); + auto expV = NDArrayFactory::create('c', {6,5}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387}); - ops::helpers::JacobiSVD jac(matrix, true, true, false); + ops::helpers::JacobiSVD jac(matrix, true, true, false); - ASSERT_TRUE(expS.equalsTo(&jac._s)); - ASSERT_TRUE(expU.equalsTo(&jac._u)); - ASSERT_TRUE(expV.equalsTo(&jac._v)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test15) { - auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); + auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); - auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); - auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); - auto expV = NDArrayFactory::create('c', {6,5}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387}); + auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); + auto expV = NDArrayFactory::create('c', {6,5}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387}); - ops::helpers::JacobiSVD jac(matrix, false, false, false); + ops::helpers::JacobiSVD jac(matrix, false, false, false); - ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expS.equalsTo(&jac._s)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, JacobiSVD_test16) { - NDArray rotation('c', {2,2}, sd::DataType::DOUBLE); + NDArray rotation('c', {2,2}, sd::DataType::DOUBLE); - NDArray exp1('c', {2,2}, {1,0,0,1 }, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {0,1,-1,0}, sd::DataType::DOUBLE); - NDArray exp3('c', {2,2}, {-1,0,0,-1}, sd::DataType::DOUBLE); - NDArray exp4('c', {2,2}, {0.983282, 0.182089, -0.182089, 0.983282}, sd::DataType::DOUBLE); - NDArray exp5('c', {2,2}, {0.249041, 0.968493, -0.968493, 0.249041}, sd::DataType::DOUBLE); + NDArray exp1('c', {2,2}, {1,0,0,1 }, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {0,1,-1,0}, sd::DataType::DOUBLE); + NDArray exp3('c', {2,2}, {-1,0,0,-1}, sd::DataType::DOUBLE); + NDArray exp4('c', {2,2}, {0.983282, 0.182089, -0.182089, 0.983282}, sd::DataType::DOUBLE); + NDArray exp5('c', {2,2}, {0.249041, 0.968493, -0.968493, 0.249041}, sd::DataType::DOUBLE); - ops::helpers::JacobiSVD::createJacobiRotationGivens(0, 0, rotation); - ASSERT_TRUE(rotation.equalsTo(exp1)); - ASSERT_TRUE(rotation.isSameShapeStrict(exp1)); + ops::helpers::JacobiSVD::createJacobiRotationGivens(0, 0, rotation); + ASSERT_TRUE(rotation.equalsTo(exp1)); + ASSERT_TRUE(rotation.isSameShapeStrict(exp1)); - ops::helpers::JacobiSVD::createJacobiRotationGivens(0, -0.5, rotation); - ASSERT_TRUE(rotation.equalsTo(exp2)); - ASSERT_TRUE(rotation.isSameShapeStrict(exp2)); + ops::helpers::JacobiSVD::createJacobiRotationGivens(0, -0.5, rotation); + ASSERT_TRUE(rotation.equalsTo(exp2)); + ASSERT_TRUE(rotation.isSameShapeStrict(exp2)); - ops::helpers::JacobiSVD::createJacobiRotationGivens(-0.5, 0, rotation); - ASSERT_TRUE(rotation.equalsTo(exp3)); - ASSERT_TRUE(rotation.isSameShapeStrict(exp3)); + ops::helpers::JacobiSVD::createJacobiRotationGivens(-0.5, 0, rotation); + ASSERT_TRUE(rotation.equalsTo(exp3)); + ASSERT_TRUE(rotation.isSameShapeStrict(exp3)); - ops::helpers::JacobiSVD::createJacobiRotationGivens(2.7, -0.5, rotation); - ASSERT_TRUE(rotation.equalsTo(exp4)); - ASSERT_TRUE(rotation.isSameShapeStrict(exp4)); + ops::helpers::JacobiSVD::createJacobiRotationGivens(2.7, -0.5, rotation); + ASSERT_TRUE(rotation.equalsTo(exp4)); + ASSERT_TRUE(rotation.isSameShapeStrict(exp4)); - ops::helpers::JacobiSVD::createJacobiRotationGivens(2.7, -10.5, rotation); - ASSERT_TRUE(rotation.equalsTo(exp5)); - ASSERT_TRUE(rotation.isSameShapeStrict(exp5)); + ops::helpers::JacobiSVD::createJacobiRotationGivens(2.7, -10.5, rotation); + ASSERT_TRUE(rotation.equalsTo(exp5)); + ASSERT_TRUE(rotation.isSameShapeStrict(exp5)); } TEST_F(HelpersTests1, test_binary_search_1) { - std::array array = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::array array = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - auto idx = sd::ops::helpers::binarySearch(array.data(), 2, 10); - ASSERT_EQ(2, idx); + auto idx = sd::ops::helpers::binarySearch(array.data(), 2, 10); + ASSERT_EQ(2, idx); } TEST_F(HelpersTests1, test_binary_search_2) { - std::array array = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::array array = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - auto idx = sd::ops::helpers::binarySearch(array.data(), 18, 10); - ASSERT_EQ(-1, idx); + auto idx = sd::ops::helpers::binarySearch(array.data(), 18, 10); + ASSERT_EQ(-1, idx); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test1) { - auto matrix = NDArrayFactory::create( - 'c', {5, 5}, {-17, 14, 9, -12, -12, 5, -4, -19, -7, -12, 15, 16, 17, - -6, 8, -10, 14, -15, 6, -10, -14, 12, -1, -16, 3}); - auto matrix2 = NDArrayFactory::create( - 'c', {5, 5}, {18, 3, 2, 7, -11, 7, 7, 10, -13, -8, 13, 20, -4, - -16, -9, -17, -5, -7, -19, -8, -9, 9, 6, 14, -11}); - auto expM = NDArrayFactory::create( - 'c', {5, 5}, - {-17, 14, 9, -12, -12, 5, -4, -19, -7, -12, 15, 16, 17.0294, - -6, 8, -10, 14, -15, 6, -10, -14, 12, 0, -16, 0}); - auto expU = NDArrayFactory::create( - 'c', {5, 5}, - {18, 3, 2, 7, -11, 7, 7.75131, 10, -12.5665, - -8, 13, 20.905, -4, -14.7979, -9, -17, -3.87565, -7, - -19.2608, -8, -9, 9, 6, 14, -11}); + auto matrix = NDArrayFactory::create('c', {5,5}, {-17 ,14 ,9 ,-12 ,-12 ,5 ,-4 ,-19 ,-7 ,-12 ,15 ,16 ,17 ,-6 ,8 ,-10 ,14 ,-15 ,6 ,-10 ,-14 ,12 ,-1 ,-16 ,3}); + auto matrix2 = NDArrayFactory::create('c', {5,5}, {18 ,3 ,2 ,7 ,-11 ,7 ,7 ,10 ,-13 ,-8 ,13 ,20 ,-4 ,-16 ,-9 ,-17 ,-5 ,-7 ,-19 ,-8 ,-9 ,9 ,6 ,14 ,-11}); + auto expM = NDArrayFactory::create('c', {5,5}, {-17,14,9,-12,-12, 5,-4, -19, -7,-12, 15,16,17.0294, -6, 8, -10,14, -15, 6,-10, -14,12, 0,-16, 0}); + auto expU = NDArrayFactory::create('c', {5,5}, {18,3, 2,7,-11, 7, 7.75131,10,-12.5665, -8, 13, 20.905,-4,-14.7979, -9, -17,-3.87565,-7,-19.2608, -8, -9, 9, 6, 14,-11}); ops::helpers::SVD svd(matrix, 4, true, true, true, 't'); svd._m = matrix; svd._u = matrix2; - svd.deflation1(1, 1, 2, 2); + svd.deflation1(1,1,2,2); ASSERT_TRUE(expM.equalsTo(&svd._m)); ASSERT_TRUE(expU.equalsTo(&svd._u)); @@ -1030,26 +838,15 @@ TEST_F(HelpersTests1, SVD_test1) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test2) { - auto matrix = NDArrayFactory::create( - 'c', {5, 5}, {-17, 14, 9, -12, -12, 5, -4, -19, -7, -12, 15, 16, 17, - -6, 8, -10, 14, -15, 6, -10, -14, 12, -1, -16, 3}); - auto matrix2 = NDArrayFactory::create( - 'c', {5, 5}, {18, 3, 2, 7, -11, 7, 7, 10, -13, -8, 13, 20, -4, - -16, -9, -17, -5, -7, -19, -8, -9, 9, 6, 14, -11}); - auto expM = NDArrayFactory::create( - 'c', {5, 5}, - {22.6716, 14, 9, -12, -12, 5, -4, -19, -7, -12, 0, 16, 0, - -6, 8, -10, 14, -15, 6, -10, -14, 12, -1, -16, 3}); - auto expU = NDArrayFactory::create( - 'c', {5, 5}, - {-12.1738, 3, -13.4089, 7, -11, 1.36735, 7, -12.1297, -13, - -8, -12.3944, 20, -5.60173, -16, -9, -17, -5, -7, - -19, -8, -9, 9, 6, 14, -11}); + auto matrix= NDArrayFactory::create('c', {5,5}, {-17 ,14 ,9 ,-12 ,-12 ,5 ,-4 ,-19 ,-7 ,-12 ,15 ,16 ,17 ,-6 ,8 ,-10 ,14 ,-15 ,6 ,-10 ,-14 ,12 ,-1 ,-16 ,3}); + auto matrix2 = NDArrayFactory::create('c', {5,5}, {18 ,3 ,2 ,7 ,-11 ,7 ,7 ,10 ,-13 ,-8 ,13 ,20 ,-4 ,-16 ,-9 ,-17 ,-5 ,-7 ,-19 ,-8 ,-9 ,9 ,6 ,14 ,-11}); + auto expM = NDArrayFactory::create('c', {5,5}, {22.6716,14, 9,-12,-12, 5,-4,-19, -7,-12, 0,16, 0, -6, 8, -10,14,-15, 6,-10, -14,12, -1,-16, 3}); + auto expU = NDArrayFactory::create('c', {5,5}, {-12.1738, 3, -13.4089, 7,-11, 1.36735, 7, -12.1297,-13, -8, -12.3944,20, -5.60173,-16, -9, -17,-5,-7,-19, -8, -9, 9, 6, 14,-11}); ops::helpers::SVD svd(matrix, 4, true, true, true); svd._m = matrix; svd._u = matrix2; - svd.deflation1(0, 0, 2, 2); + svd.deflation1(0,0,2,2); ASSERT_TRUE(expM.equalsTo(&svd._m)); ASSERT_TRUE(expU.equalsTo(&svd._u)); @@ -1058,23 +855,15 @@ TEST_F(HelpersTests1, SVD_test2) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test3) { - auto matrix = NDArrayFactory::create( - 'c', {5, 5}, {-17, 14, 9, -12, -12, 5, -4, -19, -7, -12, 15, 16, 17, - -6, 8, -10, 14, -15, 6, -10, -14, 12, -1, -16, 3}); - auto matrix2 = NDArrayFactory::create( - 'c', {2, 6}, {18, 3, 2, 7, -11, 7, 7, 10, -13, -8, 13, 20}); - auto expM = NDArrayFactory::create( - 'c', {5, 5}, - {-17, 14, 9, -12, -12, 5, -4, -19, -7, -12, 15, 16, 17.0294, - -6, 8, -10, 14, -15, 6, -10, -14, 12, 0, -16, 0}); - auto expU = NDArrayFactory::create( - 'c', {2, 6}, - {18, 2.58377, 2, 7.16409, -11, 7, 7, 10.4525, -13, -7.39897, 13, 20}); + auto matrix= NDArrayFactory::create('c', {5,5}, {-17 ,14 ,9 ,-12 ,-12 ,5 ,-4 ,-19 ,-7 ,-12 ,15 ,16 ,17 ,-6 ,8 ,-10 ,14 ,-15 ,6 ,-10 ,-14 ,12 ,-1 ,-16 ,3}); + auto matrix2 = NDArrayFactory::create('c', {2,6}, {18 ,3 ,2 ,7 ,-11 ,7 ,7 ,10 ,-13 ,-8 ,13 ,20}); + auto expM = NDArrayFactory::create('c', {5,5}, {-17,14,9,-12,-12, 5,-4, -19, -7,-12, 15,16,17.0294, -6, 8, -10,14, -15, 6,-10, -14,12, 0,-16, 0}); + auto expU = NDArrayFactory::create('c', {2,6}, {18, 2.58377, 2, 7.16409,-11, 7, 7 ,10.4525 ,-13, -7.39897 ,13 ,20}); ops::helpers::SVD svd(matrix, 4, false, true, true, 't'); svd._m = matrix; svd._u = matrix2; - svd.deflation1(1, 1, 2, 2); + svd.deflation1(1,1,2,2); ASSERT_TRUE(expM.equalsTo(&svd._m)); ASSERT_TRUE(expU.equalsTo(&svd._u)); @@ -1083,31 +872,12 @@ TEST_F(HelpersTests1, SVD_test3) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test4) { - auto matrix1 = NDArrayFactory::create( - 'c', {6, 5}, - {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 18, -17, 18, - -14, -15, 1, 2, 2, -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); - auto matrix2 = NDArrayFactory::create( - 'c', {6, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20, - -11, 19, -5, -18, 12, -19, 18, -18, 17, -10, -19, 14, - -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); - auto matrix3 = NDArrayFactory::create( - 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, - 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - auto expM = NDArrayFactory::create( - 'c', {6, 5}, - {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 18, -17, 18, - -14, -15, 8.06226, 2, 2, -3, -18, 0, -17, 2, 12, 18, 6, -2, -17}); - auto expU = NDArrayFactory::create( - 'c', {6, 6}, - {-10, -16, -20, 13, 20, -10, -9, -1, -20.7138, - 4.46525, -4, 20, -11, 19, -18.4812, 2.72876, 12, -19, - 18, -18, 17, -10, -19, 14, -2, -7, -17, - -14, -4, -16, 18, -6, -18, 1, -15, -12}); - auto expV = NDArrayFactory::create( - 'c', {5, 5}, - {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 2.97683, - -7.69015, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto expM = NDArrayFactory::create('c', {6,5}, {12, 20, 19,-18, -6, 3, 6, 2, -7, -7, 14, 8, 18,-17, 18, -14,-15,8.06226, 2, 2, -3,-18, 0,-17, 2, 12, 18, 6, -2,-17}); + auto expU = NDArrayFactory::create('c', {6,6}, {-10,-16, -20, 13, 20,-10, -9, -1,-20.7138,4.46525, -4, 20, -11, 19,-18.4812,2.72876, 12,-19, 18,-18, 17, -10,-19, 14, -2, -7, -17, -14, -4,-16, 18, -6, -18, 1,-15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-18, -13, 14, 2, -2,-11,2.97683,-7.69015,-6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); svd._m = matrix1; @@ -1123,31 +893,12 @@ TEST_F(HelpersTests1, SVD_test4) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test5) { - auto matrix1 = NDArrayFactory::create( - 'c', {6, 5}, - {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 18, -17, 18, - -14, -15, 1, 2, 2, -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); - auto matrix2 = NDArrayFactory::create( - 'c', {6, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20, - -11, 19, -5, -18, 12, -19, 18, -18, 17, -10, -19, 14, - -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); - auto matrix3 = NDArrayFactory::create( - 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, - 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - auto expM = NDArrayFactory::create( - 'c', {6, 5}, {18.4391, 20, 19, -18, -6, 3, 6, 2, -7, -7, - 0, 8, 18.4391, -17, 18, -14, -15, 1, 2, 2, - -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); - auto expU = NDArrayFactory::create( - 'c', {6, 6}, - {-10, -16, -20, 13, 20, -10, -9, -15.8359, -7, -12.2566, - -4, 20, -11, -1.30158, -5, -26.1401, 12, -19, 18, -19.3068, - 17, 7.15871, -19, 14, -2, -7, -17, -14, -4, -16, - 18, -6, -18, 1, -15, -12}); - auto expV = NDArrayFactory::create( - 'c', {5, 5}, - {-18, 1, 19, -7, 1, 2, -1.08465, -13, 22.7777, 2, -2, -5.64019, 8, - 9.65341, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto expM = NDArrayFactory::create('c', {6,5}, {18.4391, 20, 19,-18, -6, 3, 6, 2, -7, -7, 0, 8,18.4391,-17, 18, -14,-15, 1, 2, 2, -3,-18, 8,-17,-19, 12, 18, 6, -2,-17}); + auto expU = NDArrayFactory::create('c', {6,6}, {-10,-16,-20,13, 20,-10, -9,-15.8359, -7,-12.2566, -4, 20, -11,-1.30158, -5,-26.1401, 12,-19, 18,-19.3068, 17, 7.15871,-19, 14, -2, -7,-17, -14, -4,-16, 18, -6,-18, 1,-15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-1.08465,-13,22.7777, 2, -2,-5.64019, 8,9.65341,-6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); svd._m = matrix1; @@ -1163,27 +914,12 @@ TEST_F(HelpersTests1, SVD_test5) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test6) { - auto matrix1 = NDArrayFactory::create( - 'c', {6, 5}, - {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 18, -17, 18, - -14, -15, 1, 2, 2, -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); - auto matrix2 = NDArrayFactory::create( - 'c', {2, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20}); - auto matrix3 = NDArrayFactory::create( - 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, - 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - auto expM = NDArrayFactory::create( - 'c', {6, 5}, {18.4391, 20, 19, -18, -6, 3, 6, 2, -7, -7, - 0, 8, 18.4391, -17, 18, -14, -15, 1, 2, 2, - -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); - auto expU = - NDArrayFactory::create('c', {2, 6}, - {-10, -0.542326, -20, 20.6084, 20, -10, -9, - -15.8359, -7, -12.2566, -4, 20}); - auto expV = NDArrayFactory::create( - 'c', {5, 5}, - {-18, 1, 19, -7, 1, 2, -1.08465, -13, 22.7777, 2, -2, -5.64019, 8, - 9.65341, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix2 = NDArrayFactory::create('c', {2,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto expM = NDArrayFactory::create('c', {6,5}, {18.4391, 20, 19,-18, -6, 3, 6, 2, -7, -7, 0, 8,18.4391,-17, 18, -14,-15, 1, 2, 2, -3,-18, 8,-17,-19, 12, 18, 6, -2,-17}); + auto expU = NDArrayFactory::create('c', {2,6}, {-10, -0.542326,-20, 20.6084,20,-10, -9, -15.8359, -7,-12.2566,-4, 20}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-1.08465,-13,22.7777, 2, -2,-5.64019, 8,9.65341,-6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); ops::helpers::SVD svd(matrix3, 4, false, true, true, 't'); svd._m = matrix1; @@ -1199,31 +935,13 @@ TEST_F(HelpersTests1, SVD_test6) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test7) { - auto matrix1 = NDArrayFactory::create( - 'c', {6, 5}, - {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 18, -17, 18, - -14, -15, 1, 2, 2, -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); - auto matrix2 = NDArrayFactory::create( - 'c', {6, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20, - -11, 19, -5, -18, 12, -19, 18, -18, 17, -10, -19, 14, - -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); - auto matrix3 = NDArrayFactory::create( - 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, - 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - - auto expM = NDArrayFactory::create( - 'c', {6, 5}, - {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 19.6977, -17, 18, - -14, -15, 1, 2, 2, -3, -18, 0, -17, 0, 12, 18, 6, -2, -17}); - auto expU = NDArrayFactory::create( - 'c', {6, 6}, - {-10, -16, -20, 13, 20, -10, -9, -9.03658, -7, - -17.8701, -4, 20, -11, 10.0519, -5, -24.1652, 12, -19, - 18, -20.51, 17, -1.82762, -19, 14, -2, -12.0826, -17, - -9.95039, -4, -16, 18, -6, -18, 1, -15, -12}); - auto expV = NDArrayFactory::create( - 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, - 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + + auto expM = NDArrayFactory::create('c', {6,5}, {12, 20, 19,-18, -6, 3, 6, 2, -7, -7, 14, 8,19.6977,-17, 18, -14,-15, 1, 2, 2, -3,-18, 0,-17, 0, 12, 18, 6, -2,-17}); + auto expU = NDArrayFactory::create('c', {6,6}, {-10, -16,-20, 13, 20,-10, -9,-9.03658, -7,-17.8701, -4, 20, -11, 10.0519, -5,-24.1652, 12,-19, 18, -20.51, 17,-1.82762,-19, 14, -2,-12.0826,-17,-9.95039, -4,-16, 18, -6,-18, 1,-15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19,-7, 1, 2,-18,-13,14, 2, -2,-11, 8, 2,-6, -3, -8, 8,-2, 7, 16, 15, -3, 7, 0}); ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); svd._m = matrix1; @@ -1239,29 +957,13 @@ TEST_F(HelpersTests1, SVD_test7) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test8) { - auto matrix1 = NDArrayFactory::create( - 'c', {6, 5}, - {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, 8, 18, -17, 18, - -14, -15, 1, 2, 2, -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); - auto matrix2 = NDArrayFactory::create( - 'c', {6, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20, - -11, 19, -5, -18, 12, -19, 18, -18, 17, -10, -19, 14, - -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); - auto matrix3 = NDArrayFactory::create( - 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, - 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - - auto expM = NDArrayFactory::create( - 'c', {6, 5}, - {12, 20, 19, -18, -6, 3, 6, 2, -7, -7, 14, -15, 2, -17, 18, - -14, 8, 1, 18, 2, -3, -18, 8, -17, -19, 12, 18, 6, -2, -17}); - auto expU = NDArrayFactory::create( - 'c', {6, 6}, {-10, -20, -16, 13, 20, -10, -9, -7, -1, -20, -4, 20, - -11, -5, 19, -18, 12, -19, 18, 17, -18, -10, -19, 14, - -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); - auto expV = NDArrayFactory::create( - 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 2, 14, -2, -11, 8, - -6, 2, -3, -8, 8, 7, -2, 16, 15, -3, 7, 0}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + + auto expM = NDArrayFactory::create('c', {6,5}, {12, 20,19,-18, -6, 3, 6, 2, -7, -7, 14,-15, 2,-17, 18, -14, 8, 1, 18, 2, -3,-18, 8,-17,-19, 12, 18, 6, -2,-17}); + auto expU = NDArrayFactory::create('c', {6,6}, {-10,-20,-16, 13, 20,-10, -9, -7, -1,-20, -4, 20, -11, -5, 19,-18, 12,-19, 18, 17,-18,-10,-19, 14, -2, -7,-17,-14, -4,-16, 18, -6,-18, 1,-15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19,-7, 1, 2,-18,-13, 2,14, -2,-11, 8,-6, 2, -3, -8, 8, 7,-2, 16, 15, -3, 7, 0}); ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); svd._m = matrix1; @@ -1277,28 +979,18 @@ TEST_F(HelpersTests1, SVD_test8) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test9) { - auto col0 = NDArrayFactory::create( - 'c', {10, 1}, {12, 20, 19, -18, -6, 3, 6, 2, -7, 14}); - auto diag = NDArrayFactory::create( - 'c', {10, 1}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2}); - auto permut = NDArrayFactory::create('c', {1, 10}, - {8, 1, 4, 0, 5, 2, 9, 3, 7, 6}); - auto matrix3 = NDArrayFactory::create( - 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, - 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - - auto expSingVals = NDArrayFactory::create( - 'c', {10, 1}, - {-2, 15.304323, 11.2, -1, 1.73489, -12, -15.3043, -12.862, 5.6, 41.4039}); - auto expShifts = NDArrayFactory::create( - 'c', {10, 1}, {1, 19, 19, 1, 2, -18, -18, -13, 2, 2}); - auto expMus = NDArrayFactory::create( - 'c', {10, 1}, - {-3, -3.695677, -7.8, -2, -0.265108, 6, 2.69568, 0.138048, 3.6, 39.4039}); - - auto singVals = NDArrayFactory::create('c', {10, 1}); - auto shifts = NDArrayFactory::create('c', {10, 1}); - auto mus = NDArrayFactory::create('c', {10, 1}); + auto col0 = NDArrayFactory::create('c', {10,1}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,14}); + auto diag = NDArrayFactory::create('c', {10,1}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2}); + auto permut = NDArrayFactory::create('c', {1,10}, {8 ,1 ,4 ,0, 5 ,2 ,9 ,3 ,7 ,6}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + + auto expSingVals = NDArrayFactory::create('c', {10,1}, {-2, 15.304323, 11.2, -1, 1.73489, -12, -15.3043, -12.862, 5.6, 41.4039}); + auto expShifts = NDArrayFactory::create('c', {10,1}, {1, 19, 19, 1, 2, -18, -18, -13, 2, 2}); + auto expMus = NDArrayFactory::create('c', {10,1}, {-3, -3.695677, -7.8, -2, -0.265108, 6, 2.69568, 0.138048, 3.6, 39.4039}); + + auto singVals = NDArrayFactory::create('c', {10,1}); + auto shifts = NDArrayFactory::create('c', {10,1}); + auto mus = NDArrayFactory::create('c', {10,1}); ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); svd.calcSingVals(col0, diag, permut, singVals, shifts, mus); @@ -1311,91 +1003,62 @@ TEST_F(HelpersTests1, SVD_test9) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test10) { - auto singVals = NDArrayFactory::create('c', {4, 1}, {1, 1, 1, 1}); - auto col0 = NDArrayFactory::create('c', {4, 1}, {1, 1, 1, 1}); - auto diag = NDArrayFactory::create('c', {4, 1}, {5, 7, -13, 14}); - auto permut = NDArrayFactory::create('c', {1, 4}, {0, 2, 3, 1}); - auto mus = NDArrayFactory::create('c', {4, 1}, {4, 1, 4, 6}); - auto shifts = NDArrayFactory::create('c', {4, 1}, {4, 2, 5, 6}); - auto matrix3 = NDArrayFactory::create( - 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, - 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + auto singVals = NDArrayFactory::create('c', {4,1}, {1 ,1 ,1 ,1}); + auto col0 = NDArrayFactory::create('c', {4,1}, {1 ,1 ,1 ,1}); + auto diag = NDArrayFactory::create('c', {4,1}, {5 ,7 ,-13 ,14}); + auto permut = NDArrayFactory::create('c', {1,4}, {0 ,2 ,3 ,1 }); + auto mus = NDArrayFactory::create('c', {4,1}, {4,1,4,6}); + auto shifts = NDArrayFactory::create('c', {4,1}, {4,2,5,6}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); - auto expZhat = - NDArrayFactory::create('c', {4, 1}, {0, 0.278208, 72.501953, 0}); + auto expZhat = NDArrayFactory::create('c', {4,1}, {0, 0.278208, 72.501953, 0}); - auto zhat = NDArrayFactory::create('c', {4, 1}); + auto zhat = NDArrayFactory::create('c', {4,1}); ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd.perturb(col0, diag, permut, singVals, shifts, mus, zhat); + svd.perturb(col0, diag, permut, singVals, shifts, mus, zhat); ASSERT_NEAR(expZhat.e(1), zhat.e(1), EPS); ASSERT_NEAR(expZhat.e(2), zhat.e(2), EPS); } + /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test11) { - auto singVals = NDArrayFactory::create('c', {4, 1}, {1, 1, 1, 1}); - auto zhat = NDArrayFactory::create('c', {4, 1}, {2, 1, 2, 1}); - auto diag = NDArrayFactory::create('c', {4, 1}, {5, 7, -13, 14}); - auto permut = NDArrayFactory::create('c', {1, 4}, {0, 2, 3, 1}); - auto mus = NDArrayFactory::create('c', {4, 1}, {4, 1, 4, 6}); - auto shifts = NDArrayFactory::create('c', {4, 1}, {4, 2, 5, 6}); - auto matrix3 = NDArrayFactory::create( - 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, - 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - - auto expU = NDArrayFactory::create( - 'c', {5, 5}, {-0.662161, 0.980399, -0.791469, -0.748434, 0, - -0.744931, 0.183825, -0.593602, -0.392928, 0, - 0.0472972, 0.061275, 0.0719517, 0.104781, 0, - 0.0662161, 0.0356509, 0.126635, 0.523904, 0, - 0, 0, 0, 0, 1}); - auto expV = NDArrayFactory::create( - 'c', {4, 4}, - {-0.745259, -0.965209, -0.899497, -0.892319, -0.652102, 0.21114, -0.39353, - -0.156156, -0.0768918, -0.130705, -0.0885868, -0.0773343, 0.115929, - 0.0818966, 0.167906, 0.416415}); - auto U = NDArrayFactory::create('c', {5, 5}); - auto V = NDArrayFactory::create('c', {4, 4}); + auto singVals = NDArrayFactory::create('c', {4,1}, {1 ,1 ,1 ,1}); + auto zhat = NDArrayFactory::create('c', {4,1}, {2 ,1 ,2 ,1}); + auto diag = NDArrayFactory::create('c', {4,1}, {5 ,7 ,-13 ,14}); + auto permut = NDArrayFactory::create('c', {1,4}, {0 ,2 ,3 ,1 }); + auto mus = NDArrayFactory::create('c', {4,1}, {4,1,4,6}); + auto shifts = NDArrayFactory::create('c', {4,1}, {4,2,5,6}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + + auto expU = NDArrayFactory::create('c', {5,5}, {-0.662161, 0.980399,-0.791469,-0.748434, 0, -0.744931, 0.183825,-0.593602,-0.392928, 0, 0.0472972, 0.061275,0.0719517, 0.104781, 0, 0.0662161,0.0356509, 0.126635, 0.523904, 0, 0, 0, 0, 0, 1}); + auto expV = NDArrayFactory::create('c', {4,4}, {-0.745259,-0.965209, -0.899497, -0.892319, -0.652102, 0.21114, -0.39353, -0.156156, -0.0768918,-0.130705,-0.0885868,-0.0773343, 0.115929,0.0818966, 0.167906, 0.416415}); + auto U = NDArrayFactory::create('c', {5,5}); + auto V = NDArrayFactory::create('c', {4,4}); + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); - svd.calcSingVecs(zhat, diag, permut, singVals, shifts, mus, U, V); + svd.calcSingVecs(zhat, diag,permut, singVals, shifts, mus, U, V); ASSERT_TRUE(expU.equalsTo(&U)); ASSERT_TRUE(expV.equalsTo(&V)); + } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test12) { - auto matrix1 = NDArrayFactory::create( - 'c', {6, 5}, {-2, -3, 2, 1, 0, 0, -4, 5, -2, -3, -4, 0, 5, -1, -5, - -3, -5, 3, 3, 3, -5, 5, -5, 0, 2, -2, -3, -4, -5, -3}); - auto matrix2 = NDArrayFactory::create( - 'c', {6, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20, - -11, 19, -5, -18, 12, -19, 18, -18, 17, -10, -19, 14, - -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); - auto matrix3 = NDArrayFactory::create( - 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, - 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - auto matrix4 = NDArrayFactory::create( - 'c', {5, 5}, {3, -8, 5, 7, -8, 4, -19, -12, -4, -5, -11, 19, -2, - -7, 1, 16, -5, 10, 19, -19, 0, -20, 0, -8, -13}); - - auto expSingVals = - NDArrayFactory::create('c', {4, 1}, {8.43282, 5, 2.3, 1.10167}); - auto expU = NDArrayFactory::create( - 'c', {5, 5}, - {0.401972, 0, 0.206791, 0.891995, 0, 0, 1, - 0, 0, 0, 0.816018, 0, -0.522818, -0.246529, - 0, -0.415371, 0, -0.826982, 0.378904, 0, 0, - 0, 0, 0, 1}); - auto expV = NDArrayFactory::create( - 'c', {4, 4}, - {-0.951851, 0, -0.133555, -0.275939, 0, 1, 0, 0, 0.290301, 0, -0.681937, - -0.671333, -0.098513, 0, -0.719114, 0.687873}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {-2 ,-3 ,2 ,1 ,0 ,0 ,-4 ,5 ,-2 ,-3 ,-4 ,0 ,5 ,-1 ,-5 ,-3 ,-5 ,3 ,3 ,3 ,-5 ,5 ,-5 ,0 ,2 ,-2 ,-3 ,-4 ,-5 ,-3}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix4 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); + + auto expSingVals = NDArrayFactory::create('c', {4,1}, {8.43282, 5, 2.3, 1.10167}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.401972,0, 0.206791, 0.891995,0, 0,1, 0, 0,0, 0.816018,0,-0.522818,-0.246529,0, -0.415371,0,-0.826982, 0.378904,0, 0,0, 0, 0,1}); + auto expV = NDArrayFactory::create('c', {4,4}, {-0.951851,0,-0.133555,-0.275939, 0,1, 0, 0, 0.290301,0,-0.681937,-0.671333, -0.098513,0,-0.719114, 0.687873}); ops::helpers::SVD svd(matrix4, 4, true, true, true, 't'); svd._m = matrix1; @@ -1416,37 +1079,14 @@ TEST_F(HelpersTests1, SVD_test12) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test16) { - auto matrix1 = NDArrayFactory::create( - 'c', {6, 5}, {-2, -3, 2, 1, 0, 0, -4, 5, -2, -3, -4, 0, 5, -1, -5, - -3, -5, 3, 3, 3, -5, 5, -5, 0, 2, -2, -3, -4, -5, -3}); - auto matrix2 = NDArrayFactory::create( - 'c', {6, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20, - -11, 19, -5, -18, 12, -19, 18, -18, 17, -10, -19, 14, - -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); - auto matrix3 = NDArrayFactory::create( - 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, - 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - auto matrix4 = NDArrayFactory::create( - 'c', {5, 5}, {3, -8, 5, 7, -8, 4, -19, -12, -4, -5, -11, 19, -2, - -7, 1, 16, -5, 10, 19, -19, 0, -20, 0, -8, -13}); - - auto expM = NDArrayFactory::create( - 'c', {6, 5}, {-2, -3, 2, 1, 0, 0, 7.07022, 0, 0, 0, - -4, 0, 5.09585, 0, 0, -3, 0, 0, 3.32256, 0, - -5, 0, 0, 0, 1.00244, -2, -3, -4, -5, 0}); - auto expU = NDArrayFactory::create( - 'c', {6, 6}, {-5.58884, -2.18397, -11.0944, 3.30292, 0, -10, - 8.19094, 5.05917, 16.9641, -4.53112, 0, 20, - 6.55878, 3.76734, 15.9255, -3.76399, 0, -19, - 1.36021, 23.3551, -8.01165, -1.5816, 0, 14, - -15.6318, -2.85386, 8.83051, 2.74286, 1, -16, - 18, -6, -18, 1, -15, -12}); - auto expV = NDArrayFactory::create( - 'c', {5, 5}, - {-18, 1, 19, -7, 1, 2, 14.5866, - 3.90133, 1.06593, 9.99376, -2, 9.97311, 2.44445, 6.85159, - 2.37014, -3, 0.56907, -8.93313, -5.31596, 3.10096, 16, - -10.6859, 1.70708, -7.24295, -10.6975}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {-2 ,-3 ,2 ,1 ,0 ,0 ,-4 ,5 ,-2 ,-3 ,-4 ,0 ,5 ,-1 ,-5 ,-3 ,-5 ,3 ,3 ,3 ,-5 ,5 ,-5 ,0 ,2 ,-2 ,-3 ,-4 ,-5 ,-3}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix4 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); + + auto expM = NDArrayFactory::create('c', {6,5}, {-2, -3, 2, 1, 0, 0,7.07022, 0, 0, 0, -4, 0,5.09585, 0, 0, -3, 0, 0,3.32256, 0, -5, 0, 0, 0,1.00244, -2, -3, -4, -5, 0}); + auto expU = NDArrayFactory::create('c', {6,6}, {-5.58884,-2.18397,-11.0944, 3.30292, 0,-10, 8.19094, 5.05917, 16.9641,-4.53112, 0, 20, 6.55878, 3.76734, 15.9255,-3.76399, 0,-19, 1.36021, 23.3551,-8.01165, -1.5816, 0, 14, -15.6318,-2.85386, 8.83051, 2.74286, 1,-16, 18, -6, -18, 1,-15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2, 14.5866, 3.90133, 1.06593, 9.99376, -2, 9.97311, 2.44445, 6.85159, 2.37014, -3, 0.56907,-8.93313,-5.31596, 3.10096, 16,-10.6859, 1.70708,-7.24295,-10.6975}); ops::helpers::SVD svd(matrix4, 4, true, true, true, 't'); svd._m = matrix1; @@ -1467,38 +1107,14 @@ TEST_F(HelpersTests1, SVD_test16) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, SVD_test17) { - auto matrix1 = NDArrayFactory::create( - 'c', {6, 5}, {-2, -3, 2, 1, 0, 0, -4, 5, -2, -3, -4, 0, 5, -1, -5, - -3, -5, 3, 3, 3, -5, 5, -5, 0, 2, -2, -3, -4, -5, -3}); - auto matrix2 = NDArrayFactory::create( - 'c', {6, 6}, {-10, -16, -20, 13, 20, -10, -9, -1, -7, -20, -4, 20, - -11, 19, -5, -18, 12, -19, 18, -18, 17, -10, -19, 14, - -2, -7, -17, -14, -4, -16, 18, -6, -18, 1, -15, -12}); - auto matrix3 = NDArrayFactory::create( - 'c', {5, 5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, -2, -11, 8, - 2, -6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); - auto matrix4 = NDArrayFactory::create( - 'c', {5, 5}, {3, -8, 5, 7, -8, 4, -19, -12, -4, -5, -11, 19, -2, - -7, 1, 16, -5, 10, 19, -19, 0, -20, 0, -8, -13}); - - auto expM = NDArrayFactory::create( - 'c', {6, 5}, {-2, -3, 2, 1, 0, 0, 12.1676, 0, 0, 0, - -4, 0, 7.49514, 0, 0, -3, 0, 0, 5.00951, 0, - -5, 0, 0, 0, 1.63594, -2, 0, 0, 0, 0}); - auto expU = NDArrayFactory::create( - 'c', {6, 6}, - {0.295543, -0.238695, 0.262095, -0.231772, -0.85631, -10, - 0.519708, 0.0571492, -0.368706, -0.727615, 0.247527, 20, - 0.313717, -0.561567, -0.602941, 0.469567, -0.0468295, -19, - 0.474589, -0.372165, 0.656962, 0.124776, 0.434845, 14, - -0.564717, -0.697061, 0.0150082, -0.4252, 0.119081, -16, - 18, -6, -18, 1, -15, -12}); - auto expV = NDArrayFactory::create( - 'c', {5, 5}, {-18, 1, 19, -7, 1, - 2, -0.0366659, 0.977361, -0.0316106, 0.205967, - -2, -0.670795, -0.151697, -0.503288, 0.523185, - -3, 0.740124, -0.0841435, -0.486714, 0.456339, - 16, 0.0300945, -0.121135, 0.71331, 0.689645}); + auto matrix1 = NDArrayFactory::create('c', {6,5}, {-2 ,-3 ,2 ,1 ,0 ,0 ,-4 ,5 ,-2 ,-3 ,-4 ,0 ,5 ,-1 ,-5 ,-3 ,-5 ,3 ,3 ,3 ,-5 ,5 ,-5 ,0 ,2 ,-2 ,-3 ,-4 ,-5 ,-3}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix4 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); + + auto expM = NDArrayFactory::create('c', {6,5}, {-2, -3, 2, 1, 0, 0,12.1676, 0, 0, 0, -4, 0,7.49514, 0, 0, -3, 0, 0,5.00951, 0, -5, 0, 0, 0, 1.63594, -2, 0, 0, 0, 0}); + auto expU = NDArrayFactory::create('c', {6,6}, {0.295543,-0.238695, 0.262095,-0.231772, -0.85631,-10, 0.519708,0.0571492,-0.368706,-0.727615, 0.247527, 20, 0.313717,-0.561567,-0.602941, 0.469567,-0.0468295,-19, 0.474589,-0.372165, 0.656962, 0.124776, 0.434845, 14, -0.564717,-0.697061,0.0150082, -0.4252, 0.119081,-16, 18, -6, -18, 1, -15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-0.0366659, 0.977361,-0.0316106,0.205967, -2, -0.670795, -0.151697, -0.503288,0.523185, -3, 0.740124,-0.0841435, -0.486714,0.456339, 16, 0.0300945, -0.121135, 0.71331,0.689645}); ops::helpers::SVD svd(matrix4, 10, true, true, true, 't'); svd._m = matrix1; @@ -1519,97 +1135,37 @@ TEST_F(HelpersTests1, SVD_test17) { // /////////////////////////////////////////////////////////////////// // TEST_F(HelpersTests1, SVD_test18) { -// auto matrix('c', {10,10}, {10 ,7 ,5 ,2 ,17 ,18 ,-18 ,10 ,18 ,1 ,4 ,2 ,-7 -// ,-18 ,20 ,14 , -// -3 ,-10 ,-4 ,2 ,-17 ,-17 ,1 ,2 ,-9 -// ,-6 ,-13 ,16 ,-18 ,-13 , -10 ,16 -// ,-10 ,-13 ,-11 ,-6 ,-19 ,17 ,-12 ,3 -// ,-14 ,7 ,7 ,-9 , 5 ,-16 ,7 ,16 ,13 -// ,12 ,2 ,18 ,6 ,3 ,-8 ,11 ,-1 ,5 ,16 -// ,-16 , -9 ,8 ,10 ,-7 ,-4 ,1 ,-10 ,0 -// ,20 ,7 ,-11 ,-13 ,-3 ,20 ,-6 , 9 -// ,10 ,8 ,-20 ,1 ,19 ,19 ,-12 ,-20 -// ,-2 ,17 ,-18 ,-5 ,-14 ,0 ,9 ,-16 ,9 -// ,-15 ,7 ,18 ,-10 ,8 ,-11 ,-4}); - -// auto expS('c', {10, 1}, -// {65.0394, 56.1583, 48.9987, 39.2841, 35.7296, 22.8439, 17.474, 15.2708, 15.0768, -// 0.846648}); - -// auto expU('c', {10,10}, {0.413187, 0.159572,0.0238453, -// 0.601154,-0.0428558, -0.461779, 0.41787, -0.221153, 0.0206268, -// 0.0532219, -// 0.364377,-0.154281, -// 0.199857,-0.0943331, 0.415653, -// -0.139834, -0.258458, 0.10677, -// 0.72003,-0.0749772, -// -0.315063,-0.418079,-0.377499, -// 0.37031, 0.0123835, 0.300036, -// 0.153702, -0.129223, 0.390675, -// 0.403962, -// 0.102001,-0.216667, -// -0.74093,-0.166164,-0.0269665, -// -0.240065, 0.0549761,-0.0178001, -// 0.0197525, -0.55134, -// -0.107298, 0.386899,-0.377536, -// 0.033214, 0.486739, -0.245438, -// -0.43788, -0.208875, -0.170449, -// 0.365491, -// 0.18026, 0.240482,-0.115801, -// 0.237399, -0.643413, 0.139274, -// -0.582963, -0.116222, -// 0.224524,-0.0525887, -// 0.141172, 0.340505,-0.261653, -// 0.186411, 0.0625811, 0.19585, -// 0.128195, 0.832893, 0.0319884, -// 0.0864513, -// -0.385777,-0.330504, 0.128342, -// 0.156083, -0.200883, -0.648548, -// -0.256507, 0.40519,-0.0434365, -// 0.0909978, -// 0.574478,-0.371028,-0.136672,-0.328417, -// -0.190226,-0.0476664,-0.0399815, -// 0.0687528, -0.242039, 0.549918, -// 0.209886,-0.398294,0.0919207, -// 0.490454, 0.305228, 0.280486, -// -0.341358, 0.0540678, -0.432618, -// -0.264332}); - -// auto expV('c', {10,10}, {0.423823,-0.0845148, 0.389647, -// -0.10717,-0.168732, 0.123783, 0.159237, -0.450407, -0.611513,-0.0629076, -// 0.412121, 0.317493, -// -0.355665,-0.383203,-0.382616,-0.309073, -// -0.21869,-0.0746378, 0.0829771, -// 0.392186, -// -0.0603483, 0.232234, 0.0383737, -// 0.435441,0.0829318, -// 0.327822,-0.206101, 0.184083, -// -0.34018, 0.667018, -// -0.453935, 0.119616, 0.288392, -// 0.184366,-0.524289, -0.42264, -// 0.41005,-0.0505891,0.00333608, -// 0.195602, -// 0.247802, 0.0776165, 0.33026, -// 0.190986, -// 0.526809,-0.345006,0.0651023, -// -0.386472, 0.395169, 0.284091, -// 0.426355, -0.269507, 0.304685, -// 0.386708,-0.257916,-0.287742,-0.329622, -// 0.463719, 0.0613767, -0.16261, -// -0.384582, 0.241486, -// 0.425935,-0.292636,0.0465594,-0.125018,-0.685871, -// -0.112806,-0.0977978, -0.127356, -// -0.121678, -0.06796, -0.501443, -// 0.473165,0.0422977,-0.369324,-0.248758, -// -0.408769, -0.305785, -0.211138, -// 0.186099, 0.809997, 0.0338281, -// 0.268965, -0.04829, 0.141617, -// 0.12121, 0.0362537, 0.0831986, -// -0.436428, -// 0.0174496, -// 0.161638,-0.0334757,-0.224027, -// 0.439364,-0.478697, 0.237318, -// 0.457809, -0.483235,-0.0253522}); +// auto matrix('c', {10,10}, {10 ,7 ,5 ,2 ,17 ,18 ,-18 ,10 ,18 ,1 ,4 ,2 ,-7 ,-18 ,20 ,14 , +// -3 ,-10 ,-4 ,2 ,-17 ,-17 ,1 ,2 ,-9 ,-6 ,-13 ,16 ,-18 ,-13 , +// -10 ,16 ,-10 ,-13 ,-11 ,-6 ,-19 ,17 ,-12 ,3 ,-14 ,7 ,7 ,-9 , +// 5 ,-16 ,7 ,16 ,13 ,12 ,2 ,18 ,6 ,3 ,-8 ,11 ,-1 ,5 ,16 ,-16 , +// -9 ,8 ,10 ,-7 ,-4 ,1 ,-10 ,0 ,20 ,7 ,-11 ,-13 ,-3 ,20 ,-6 , +// 9 ,10 ,8 ,-20 ,1 ,19 ,19 ,-12 ,-20 ,-2 ,17 ,-18 ,-5 ,-14 ,0 +// ,9 ,-16 ,9 ,-15 ,7 ,18 ,-10 ,8 ,-11 ,-4}); + +// auto expS('c', {10, 1}, {65.0394, 56.1583, 48.9987, 39.2841, 35.7296, 22.8439, 17.474, 15.2708, 15.0768, 0.846648}); + +// auto expU('c', {10,10}, {0.413187, 0.159572,0.0238453, 0.601154,-0.0428558, -0.461779, 0.41787, -0.221153, 0.0206268, 0.0532219, +// 0.364377,-0.154281, 0.199857,-0.0943331, 0.415653, -0.139834, -0.258458, 0.10677, 0.72003,-0.0749772, +// -0.315063,-0.418079,-0.377499, 0.37031, 0.0123835, 0.300036, 0.153702, -0.129223, 0.390675, 0.403962, +// 0.102001,-0.216667, -0.74093,-0.166164,-0.0269665, -0.240065, 0.0549761,-0.0178001, 0.0197525, -0.55134, +// -0.107298, 0.386899,-0.377536, 0.033214, 0.486739, -0.245438, -0.43788, -0.208875, -0.170449, 0.365491, +// 0.18026, 0.240482,-0.115801, 0.237399, -0.643413, 0.139274, -0.582963, -0.116222, 0.224524,-0.0525887, +// 0.141172, 0.340505,-0.261653, 0.186411, 0.0625811, 0.19585, 0.128195, 0.832893, 0.0319884, 0.0864513, +// -0.385777,-0.330504, 0.128342, 0.156083, -0.200883, -0.648548, -0.256507, 0.40519,-0.0434365, 0.0909978, +// 0.574478,-0.371028,-0.136672,-0.328417, -0.190226,-0.0476664,-0.0399815, 0.0687528, -0.242039, 0.549918, +// 0.209886,-0.398294,0.0919207, 0.490454, 0.305228, 0.280486, -0.341358, 0.0540678, -0.432618, -0.264332}); + +// auto expV('c', {10,10}, {0.423823,-0.0845148, 0.389647, -0.10717,-0.168732, 0.123783, 0.159237, -0.450407, -0.611513,-0.0629076, +// 0.412121, 0.317493, -0.355665,-0.383203,-0.382616,-0.309073, -0.21869,-0.0746378, 0.0829771, 0.392186, +// -0.0603483, 0.232234, 0.0383737, 0.435441,0.0829318, 0.327822,-0.206101, 0.184083, -0.34018, 0.667018, +// -0.453935, 0.119616, 0.288392, 0.184366,-0.524289, -0.42264, 0.41005,-0.0505891,0.00333608, 0.195602, +// 0.247802, 0.0776165, 0.33026, 0.190986, 0.526809,-0.345006,0.0651023, -0.386472, 0.395169, 0.284091, +// 0.426355, -0.269507, 0.304685, 0.386708,-0.257916,-0.287742,-0.329622, 0.463719, 0.0613767, -0.16261, +// -0.384582, 0.241486, 0.425935,-0.292636,0.0465594,-0.125018,-0.685871, -0.112806,-0.0977978, -0.127356, +// -0.121678, -0.06796, -0.501443, 0.473165,0.0422977,-0.369324,-0.248758, -0.408769, -0.305785, -0.211138, +// 0.186099, 0.809997, 0.0338281, 0.268965, -0.04829, 0.141617, 0.12121, 0.0362537, 0.0831986, -0.436428, +// 0.0174496, 0.161638,-0.0334757,-0.224027, 0.439364,-0.478697, 0.237318, 0.457809, -0.483235,-0.0253522}); // ops::helpers::SVD svd(matrix, 8, true, true, true); // // svd._u.printShapeInfo(); @@ -1624,109 +1180,43 @@ TEST_F(HelpersTests1, SVD_test17) { // ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); // } + // /////////////////////////////////////////////////////////////////// // TEST_F(HelpersTests1, SVD_test19) { -// auto matrix('c', {11,10}, {10 ,7 ,5 ,2 ,17 ,18 ,-18 ,10 ,18 ,1 ,4 ,2 ,-7 -// ,-18 ,20 ,14 , -// -3 ,-10 ,-4 ,2 ,-17 ,-17 ,1 ,2 ,-9 -// ,-6 ,-13 ,16 ,-18 ,-13 , -10 ,16 -// ,-10 ,-13 ,-11 ,-6 ,-19 ,17 ,-12 ,3 -// ,-14 ,7 ,7 ,-9 , 5 ,-16 ,7 ,16 ,13 -// ,12 ,2 ,18 ,6 ,3 ,-8 ,11 ,-1 ,5 ,16 -// ,-16 , -9 ,8 ,10 ,-7 ,-4 ,1 ,-10 ,0 -// ,20 ,7 ,-11 ,-13 ,-3 ,20 ,-6 , 9 -// ,10 ,8 ,-20 ,1 ,19 ,19 ,-12 ,-20 -// ,-2 ,17 ,-18 ,-5 ,-14 ,0 ,9 ,-16 ,9 -// ,-15 ,7 ,18 ,-10 ,8 ,-11 ,-4, -7, -// 1, -2, 15, 0, 4, -9,19, -3, 10 -// }); - -// auto expS('c', {10, 1}, -// {65.5187, 56.305, 50.9808, 41.6565, 35.8698, 29.3898, 17.9743, 15.3568, 15.2223, -// 0.846847}); - -// auto expU('c', {11,11}, {-0.387999,-0.117659, 0.162976, -// 0.641067,-0.0174306, -0.181469,-0.218643, -0.308042, -// 0.0670776,-0.0632539, -0.462228, -// -0.37021, 0.14822, -// -0.195157,-0.0467394, -0.381275, -// -0.183363, 0.326599, -0.370579, -// -0.56626, 0.0798798, 0.225133, -// 0.339692, 0.433146, 0.30841, -// 0.134184, -0.108725, -// 0.466056,-0.153546, -0.359783, -// -0.189621, -0.402737, 0.0605675, -// -0.0650167, 0.268868, 0.662416, -// -0.327524, -// 0.0339198,-0.0916729,0.0415428, -// -0.0765093,-0.0288338, 0.546108, -// -0.247418, -// 0.114029,-0.361828, -// 0.379255,-0.0935836, -0.488912, -// -0.125232, 0.480666,-0.00544881, -// 0.280747, -0.36698,-0.0648559, -// -0.174798, -0.21859, 0.178313, -// 0.212153, 0.579101, 0.369942, -// 0.551063, -0.139813,-0.0296135, -// 0.0572204, 0.212783, -// -0.133981,-0.311817, 0.304673, -// 0.0865395, -0.104221, -// 0.196295,-0.191271, 0.571084, -// -0.603697,-0.0868996,-0.0196788, -// 0.398676, 0.319697, -0.112145, -// 0.235089, 0.201666, -0.337134, -// 0.43406, 0.261686, -// -0.283102,-0.0999458, -0.411893, -// -0.559998, 0.392802, 0.0996997, -// -0.281135, 0.24017, -// -0.136769,0.0121463, 0.218664, -// 0.127577, -0.550001,0.00227476, -// -0.197522, 0.403875,-0.0647804, -// 0.383315, -0.388502, 0.335719, -// 0.20912, 0.404926, 0.309087, -// 0.266437, 0.0942471, -// 0.140425,0.0934688, 0.325994, -// 0.345081, 0.0825574, -// -0.521239,-0.129018, 0.0806886, -// 0.0442647, 0.014397, 0.665103}); - -// auto expV('c', {10,10}, {-0.4428, 0.0661762,-0.361903, 0.0307317, -// 0.19574,-0.0356551,-0.241991, 0.0866805, 0.74701, 0.062837, -// -0.400091, -0.277277, 0.375095, -// -0.323052, 0.443668, -0.264809, -// 0.292881, -// -0.106586,-0.00623963,-0.392226, -// 0.0536693, -0.232105,0.0106246, -// 0.332557, -0.167406, -// 0.400872,0.0835708, 0.414598, -// 0.141906,-0.666936, -// 0.473793, -0.121962,-0.147941, -// 0.414665, 0.538964, -// -0.372149,-0.285458, -0.132952, -// -0.0166319,-0.195945, -// -0.251722,-0.0813691,-0.233887, -// 0.280439, -0.512597, -0.328782, -// 0.074277, -0.581806, -// -0.0327555,-0.284121, -0.406324, -// 0.284462,-0.168731, 0.518021, -// 0.226396, -0.109282, 0.381083, -// 0.305342, -0.359301, 0.162524, -// 0.335857, -0.302206,-0.484806, -// -0.196382,0.00286755, -0.111789, -// 0.672115, 0.0705632, 0.191787, -// 0.127533, 0.185896, 0.134279, -// 0.608397, 0.382412,-0.0997649, -// -0.117987, 0.326934,-0.0941208, -// 0.496913, 0.210914, -// -0.201675, -0.795446,0.0916484, -// 0.267237,0.00604554, 0.167517, -// -0.13914,-0.0355323, -0.0869256, -// 0.436465, -// 0.00123325, -// -0.142684,0.0978458,-0.0945446, -// -0.349755, -0.674457,-0.196126, -// 0.587134,-0.00964182,0.0249317}); +// auto matrix('c', {11,10}, {10 ,7 ,5 ,2 ,17 ,18 ,-18 ,10 ,18 ,1 ,4 ,2 ,-7 ,-18 ,20 ,14 , +// -3 ,-10 ,-4 ,2 ,-17 ,-17 ,1 ,2 ,-9 ,-6 ,-13 ,16 ,-18 ,-13 , +// -10 ,16 ,-10 ,-13 ,-11 ,-6 ,-19 ,17 ,-12 ,3 ,-14 ,7 ,7 ,-9 , +// 5 ,-16 ,7 ,16 ,13 ,12 ,2 ,18 ,6 ,3 ,-8 ,11 ,-1 ,5 ,16 ,-16 , +// -9 ,8 ,10 ,-7 ,-4 ,1 ,-10 ,0 ,20 ,7 ,-11 ,-13 ,-3 ,20 ,-6 , +// 9 ,10 ,8 ,-20 ,1 ,19 ,19 ,-12 ,-20 ,-2 ,17 ,-18 ,-5 ,-14 ,0 +// ,9 ,-16 ,9 ,-15 ,7 ,18 ,-10 ,8 ,-11 ,-4, +// -7, 1, -2, 15, 0, 4, -9,19, -3, 10 }); + +// auto expS('c', {10, 1}, {65.5187, 56.305, 50.9808, 41.6565, 35.8698, 29.3898, 17.9743, 15.3568, 15.2223, 0.846847}); + +// auto expU('c', {11,11}, {-0.387999,-0.117659, 0.162976, 0.641067,-0.0174306, -0.181469,-0.218643, -0.308042, 0.0670776,-0.0632539, -0.462228, +// -0.37021, 0.14822, -0.195157,-0.0467394, -0.381275, -0.183363, 0.326599, -0.370579, -0.56626, 0.0798798, 0.225133, +// 0.339692, 0.433146, 0.30841, 0.134184, -0.108725, 0.466056,-0.153546, -0.359783, -0.189621, -0.402737, 0.0605675, +// -0.0650167, 0.268868, 0.662416, -0.327524, 0.0339198,-0.0916729,0.0415428, -0.0765093,-0.0288338, 0.546108, -0.247418, +// 0.114029,-0.361828, 0.379255,-0.0935836, -0.488912, -0.125232, 0.480666,-0.00544881, 0.280747, -0.36698,-0.0648559, +// -0.174798, -0.21859, 0.178313, 0.212153, 0.579101, 0.369942, 0.551063, -0.139813,-0.0296135, 0.0572204, 0.212783, +// -0.133981,-0.311817, 0.304673, 0.0865395, -0.104221, 0.196295,-0.191271, 0.571084, -0.603697,-0.0868996,-0.0196788, +// 0.398676, 0.319697, -0.112145, 0.235089, 0.201666, -0.337134, 0.43406, 0.261686, -0.283102,-0.0999458, -0.411893, +// -0.559998, 0.392802, 0.0996997, -0.281135, 0.24017, -0.136769,0.0121463, 0.218664, 0.127577, -0.550001,0.00227476, +// -0.197522, 0.403875,-0.0647804, 0.383315, -0.388502, 0.335719, 0.20912, 0.404926, 0.309087, 0.266437, 0.0942471, +// 0.140425,0.0934688, 0.325994, 0.345081, 0.0825574, -0.521239,-0.129018, 0.0806886, 0.0442647, 0.014397, 0.665103}); + +// auto expV('c', {10,10}, {-0.4428, 0.0661762,-0.361903, 0.0307317, 0.19574,-0.0356551,-0.241991, 0.0866805, 0.74701, 0.062837, +// -0.400091, -0.277277, 0.375095, -0.323052, 0.443668, -0.264809, 0.292881, -0.106586,-0.00623963,-0.392226, +// 0.0536693, -0.232105,0.0106246, 0.332557, -0.167406, 0.400872,0.0835708, 0.414598, 0.141906,-0.666936, +// 0.473793, -0.121962,-0.147941, 0.414665, 0.538964, -0.372149,-0.285458, -0.132952, -0.0166319,-0.195945, +// -0.251722,-0.0813691,-0.233887, 0.280439, -0.512597, -0.328782, 0.074277, -0.581806, -0.0327555,-0.284121, +// -0.406324, 0.284462,-0.168731, 0.518021, 0.226396, -0.109282, 0.381083, 0.305342, -0.359301, 0.162524, +// 0.335857, -0.302206,-0.484806, -0.196382,0.00286755, -0.111789, 0.672115, 0.0705632, 0.191787, 0.127533, +// 0.185896, 0.134279, 0.608397, 0.382412,-0.0997649, -0.117987, 0.326934,-0.0941208, 0.496913, 0.210914, +// -0.201675, -0.795446,0.0916484, 0.267237,0.00604554, 0.167517, -0.13914,-0.0355323, -0.0869256, 0.436465, +// 0.00123325, -0.142684,0.0978458,-0.0945446, -0.349755, -0.674457,-0.196126, 0.587134,-0.00964182,0.0249317}); // ops::helpers::SVD svd(matrix, 8, true, true, true); @@ -1739,111 +1229,43 @@ TEST_F(HelpersTests1, SVD_test17) { // ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); // } + // /////////////////////////////////////////////////////////////////// // TEST_F(HelpersTests1, SVD_test20) { -// auto matrix('c', {10,11}, {10 ,7 ,5 ,2 ,17 ,18 ,-18 ,10 ,18 ,1 ,4 ,2 ,-7 -// ,-18 ,20 ,14 , -// -3 ,-10 ,-4 ,2 ,-17 ,-17 ,1 ,2 ,-9 -// ,-6 ,-13 ,16 ,-18 ,-13 , -10 ,16 -// ,-10 ,-13 ,-11 ,-6 ,-19 ,17 ,-12 ,3 -// ,-14 ,7 ,7 ,-9 , 5 ,-16 ,7 ,16 ,13 -// ,12 ,2 ,18 ,6 ,3 ,-8 ,11 ,-1 ,5 ,16 -// ,-16 , -9 ,8 ,10 ,-7 ,-4 ,1 ,-10 ,0 -// ,20 ,7 ,-11 ,-13 ,-3 ,20 ,-6 , 9 -// ,10 ,8 ,-20 ,1 ,19 ,19 ,-12 ,-20 -// ,-2 ,17 ,-18 ,-5 ,-14 ,0 ,9 ,-16 ,9 -// ,-15 ,7 ,18 ,-10 ,8 ,-11 ,-4, -7, -// 1, -2, 15, 0, 4, -9,19, -3, 10 -// }); - -// auto expS('c', {10, 1}, -// {68.9437, 54.8773, 50.7858, 42.4898, 35.1984, 26.6285, 21.376, 12.2334, 5.9112, -// 0.38292}); - -// auto expU('c', {10,10}, {0.30332,-0.0677785, 0.155514, -// -0.722623,-0.0843687,-0.0712535, 0.414936, -0.15422, -// -0.381536,-0.057561, -// 0.473286, 0.0231518, 0.0878106, -// 0.45493, -0.311654, 0.138957, -// 0.311305, 0.509971, -// -0.288207,0.0656506, -// -0.131548, 0.32051, -// 0.489848,-0.0539042, -0.521328, -// -0.363728, -// -0.328685,-0.0329672,-0.0726502, -// 0.344431, -// 0.072974, 0.522632, -0.477056, -// 0.0618953,-0.0980883, -0.095653, -// -0.26596, -0.15453, -// -0.475107,-0.388594, 0.267569, -// -0.336154,-0.0930604, -0.261336, -// -0.39945, 0.480346, -0.568317, -// 0.0593335, 0.102036,-0.106029, -// -0.0919782, -0.460136, 0.106434, -// 0.327722, 0.0952523, 0.0915698, -// -0.129052, -0.460878, -0.59722, -// 0.240608, -// -0.248827, -0.48834, -0.243788, -// -0.106636,-0.0803772, -0.567457, -// -0.12005, 0.480504, -// -0.188409,-0.139802, -// 0.643408, -0.16245, -0.152596, -// 0.16849,-0.0120438, -// -0.51616,-0.0694232, -0.36172, -// 0.322169,0.0440701, -// -0.229467,-0.0227008, -// -0.588303,-0.0327104, -0.482264, -// 0.0794715, 0.340158, -0.175969, -// 0.108784, 0.449731, -// 0.229718, 0.169979, -0.227516, -// -0.21815, 0.454459, 0.017476, -// -0.278516, 0.287333, -0.148844, -// 0.655637}); - -// auto expV('c', {11,11}, {0.190806, -0.193628, 0.383793,-0.0266376, -// 0.113035, 0.158361, 0.0297803, -0.793229, -0.13761,-0.260666, -// -0.152503, -// -0.303449, 0.0392386, 0.250627, -// -0.165231, 0.141567, 0.0479565, -// 0.72763, 0.14053, -0.339907, -// 0.224366, -0.280806, -0.159724, -// -0.38984, -0.256355, -0.337861, -// 0.075089, -0.237427, -0.153718, -// -0.217747, 0.320899, 0.455058, -// -0.446697, -// 0.376823, -0.560303, 0.269135, -// 0.265416,-0.00742902, 0.0263377, -// -0.192808, 0.435842, -// -0.275365,0.0511804, -0.30799, -// 0.522537, 0.209791, -0.44191, -// -0.282323, -0.12139, 0.226382, -// 0.221075, 0.0844301, -// 0.0285412,-0.297578, -0.443394, -// 0.0588008, 0.115035, 0.54835, -// -0.52266, -0.141345, 0.411122, -// -0.182423, 0.213721, 0.353022, -// 0.119504, 0.0508673, -// -0.299021,-0.0424794, -0.285618, -// 0.177961, 0.35831, 0.769783, -// -0.215983,-0.00423939, -// -0.110575,0.0928082,-0.0841152, -// -0.0977062, -0.624782, -0.240391, -// -0.276154, -0.342018, 0.199695, -// 0.268881, 0.00402219,-0.0536164, -// -0.17679, 0.450283, -// 0.428931, 0.0748696, -0.120853, -// -0.360103, 0.37093,-0.0611563, -// -0.100263, -0.0604207, -0.432926, -// 0.412875, 0.39142, -0.35553, -// 0.127463,-0.0199906, -0.343149, -// -0.315968, -0.115698, -0.442585, -// 0.0126156, -0.584161,-0.219242, -// -0.20156, -// -0.134753, -0.154272, 0.037343, -// -0.281348, 0.666324, -// -0.213813,-0.0427932, 0.238783, -// 0.132347,-0.557478, 0.0253325}); +// auto matrix('c', {10,11}, {10 ,7 ,5 ,2 ,17 ,18 ,-18 ,10 ,18 ,1 ,4 ,2 ,-7 ,-18 ,20 ,14 , +// -3 ,-10 ,-4 ,2 ,-17 ,-17 ,1 ,2 ,-9 ,-6 ,-13 ,16 ,-18 ,-13 , +// -10 ,16 ,-10 ,-13 ,-11 ,-6 ,-19 ,17 ,-12 ,3 ,-14 ,7 ,7 ,-9 , +// 5 ,-16 ,7 ,16 ,13 ,12 ,2 ,18 ,6 ,3 ,-8 ,11 ,-1 ,5 ,16 ,-16 , +// -9 ,8 ,10 ,-7 ,-4 ,1 ,-10 ,0 ,20 ,7 ,-11 ,-13 ,-3 ,20 ,-6 , +// 9 ,10 ,8 ,-20 ,1 ,19 ,19 ,-12 ,-20 ,-2 ,17 ,-18 ,-5 ,-14 ,0 +// ,9 ,-16 ,9 ,-15 ,7 ,18 ,-10 ,8 ,-11 ,-4, +// -7, 1, -2, 15, 0, 4, -9,19, -3, 10 }); + +// auto expS('c', {10, 1}, {68.9437, 54.8773, 50.7858, 42.4898, 35.1984, 26.6285, 21.376, 12.2334, 5.9112, 0.38292}); + +// auto expU('c', {10,10}, {0.30332,-0.0677785, 0.155514, -0.722623,-0.0843687,-0.0712535, 0.414936, -0.15422, -0.381536,-0.057561, +// 0.473286, 0.0231518, 0.0878106, 0.45493, -0.311654, 0.138957, 0.311305, 0.509971, -0.288207,0.0656506, +// -0.131548, 0.32051, 0.489848,-0.0539042, -0.521328, -0.363728, -0.328685,-0.0329672,-0.0726502, 0.344431, +// 0.072974, 0.522632, -0.477056, 0.0618953,-0.0980883, -0.095653, -0.26596, -0.15453, -0.475107,-0.388594, +// 0.267569, -0.336154,-0.0930604, -0.261336, -0.39945, 0.480346, -0.568317, 0.0593335, 0.102036,-0.106029, +// -0.0919782, -0.460136, 0.106434, 0.327722, 0.0952523, 0.0915698, -0.129052, -0.460878, -0.59722, 0.240608, +// -0.248827, -0.48834, -0.243788, -0.106636,-0.0803772, -0.567457, -0.12005, 0.480504, -0.188409,-0.139802, +// 0.643408, -0.16245, -0.152596, 0.16849,-0.0120438, -0.51616,-0.0694232, -0.36172, 0.322169,0.0440701, +// -0.229467,-0.0227008, -0.588303,-0.0327104, -0.482264, 0.0794715, 0.340158, -0.175969, 0.108784, 0.449731, +// 0.229718, 0.169979, -0.227516, -0.21815, 0.454459, 0.017476, -0.278516, 0.287333, -0.148844, 0.655637}); + +// auto expV('c', {11,11}, {0.190806, -0.193628, 0.383793,-0.0266376, 0.113035, 0.158361, 0.0297803, -0.793229, -0.13761,-0.260666, -0.152503, +// -0.303449, 0.0392386, 0.250627, -0.165231, 0.141567, 0.0479565, 0.72763, 0.14053, -0.339907, 0.224366, -0.280806, +// -0.159724, -0.38984, -0.256355, -0.337861, 0.075089, -0.237427, -0.153718, -0.217747, 0.320899, 0.455058, -0.446697, +// 0.376823, -0.560303, 0.269135, 0.265416,-0.00742902, 0.0263377, -0.192808, 0.435842, -0.275365,0.0511804, -0.30799, +// 0.522537, 0.209791, -0.44191, -0.282323, -0.12139, 0.226382, 0.221075, 0.0844301, 0.0285412,-0.297578, -0.443394, +// 0.0588008, 0.115035, 0.54835, -0.52266, -0.141345, 0.411122, -0.182423, 0.213721, 0.353022, 0.119504, 0.0508673, +// -0.299021,-0.0424794, -0.285618, 0.177961, 0.35831, 0.769783, -0.215983,-0.00423939, -0.110575,0.0928082,-0.0841152, +// -0.0977062, -0.624782, -0.240391, -0.276154, -0.342018, 0.199695, 0.268881, 0.00402219,-0.0536164, -0.17679, 0.450283, +// 0.428931, 0.0748696, -0.120853, -0.360103, 0.37093,-0.0611563, -0.100263, -0.0604207, -0.432926, 0.412875, 0.39142, +// -0.35553, 0.127463,-0.0199906, -0.343149, -0.315968, -0.115698, -0.442585, 0.0126156, -0.584161,-0.219242, -0.20156, +// -0.134753, -0.154272, 0.037343, -0.281348, 0.666324, -0.213813,-0.0427932, 0.238783, 0.132347,-0.557478, 0.0253325}); // ops::helpers::SVD svd(matrix, 8, true, true, true); @@ -1856,18 +1278,16 @@ TEST_F(HelpersTests1, SVD_test17) { // ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); // } + ///////////////////////////////////////////////////////////////////// -// TEST_F(HelpersTests1, reverseArray_test1) { +//TEST_F(HelpersTests1, reverseArray_test1) { // -// auto inArr = NDArrayFactory::create('c', {2,5}, -// {1,2,3,4,5,6,7,8,9,10}); auto exp = NDArrayFactory::create('c', -// {2,5}, {10,9,8,7,6,5,4,3,2,1}); auto outArr = -// NDArrayFactory::create('c', {2,5}); +// auto inArr = NDArrayFactory::create('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}); +// auto exp = NDArrayFactory::create('c', {2,5}, {10,9,8,7,6,5,4,3,2,1}); +// auto outArr = NDArrayFactory::create('c', {2,5}); // // -// ops::helpers::reverseArray(sd::LaunchContext ::defaultContext(), -// inArr.getBuffer(), inArr.shapeInfo(), outArr.getBuffer(), -// outArr.shapeInfo()); +// ops::helpers::reverseArray(sd::LaunchContext ::defaultContext(), inArr.getBuffer(), inArr.shapeInfo(), outArr.getBuffer(), outArr.shapeInfo()); // // ASSERT_TRUE(outArr.equalsTo(&exp)); // ASSERT_TRUE(outArr.isSameShapeStrict(exp)); @@ -1875,16 +1295,13 @@ TEST_F(HelpersTests1, SVD_test17) { // // ///////////////////////////////////////////////////////////////////// -// TEST_F(HelpersTests1, reverseArray_test2) { +//TEST_F(HelpersTests1, reverseArray_test2) { // -// auto inArr = NDArrayFactory::create('c', {2,5}, -// {1,2,3,4,5,6,7,8,9,10}); auto exp = NDArrayFactory::create('c', -// {2,5}, {10,9,8,7,6,5,4,3,2,1}); +// auto inArr = NDArrayFactory::create('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}); +// auto exp = NDArrayFactory::create('c', {2,5}, {10,9,8,7,6,5,4,3,2,1}); // // -// ops::helpers::reverseArray(sd::LaunchContext ::defaultContext(), -// inArr.getBuffer(), inArr.shapeInfo(), inArr.getBuffer(), -// inArr.shapeInfo()); +// ops::helpers::reverseArray(sd::LaunchContext ::defaultContext(), inArr.getBuffer(), inArr.shapeInfo(), inArr.getBuffer(), inArr.shapeInfo()); // // ASSERT_TRUE(inArr.equalsTo(&exp)); // ASSERT_TRUE(inArr.isSameShapeStrict(exp)); @@ -1892,16 +1309,13 @@ TEST_F(HelpersTests1, SVD_test17) { // // ///////////////////////////////////////////////////////////////////// -// TEST_F(HelpersTests1, reverseArray_test3) { +//TEST_F(HelpersTests1, reverseArray_test3) { // -// auto inArr = NDArrayFactory::create('c', {2,5}, -// {1,2,3,4,5,6,7,8,9,10}); auto exp = NDArrayFactory::create('c', -// {2,5}, {5,4,3,2,1,6,7,8,9,10}); auto outArr = -// NDArrayFactory::create('c', {2,5}); +// auto inArr = NDArrayFactory::create('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}); +// auto exp = NDArrayFactory::create('c', {2,5}, {5,4,3,2,1,6,7,8,9,10}); +// auto outArr = NDArrayFactory::create('c', {2,5}); // -// ops::helpers::reverseArray(sd::LaunchContext ::defaultContext(), -// inArr.getBuffer(), inArr.shapeInfo(), outArr.getBuffer(), -// outArr.shapeInfo(), 5); +// ops::helpers::reverseArray(sd::LaunchContext ::defaultContext(), inArr.getBuffer(), inArr.shapeInfo(), outArr.getBuffer(), outArr.shapeInfo(), 5); // // ASSERT_TRUE(outArr.equalsTo(&exp)); // ASSERT_TRUE(outArr.isSameShapeStrict(exp)); @@ -1909,15 +1323,16 @@ TEST_F(HelpersTests1, SVD_test17) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, rnnCell_test1) { + const int bS = 2; - const int inSize = 4; + const int inSize = 4; const int numUnits = 4; NDArray xt('c', {bS, inSize}, sd::DataType::DOUBLE); NDArray ht_1('c', {bS, numUnits}, sd::DataType::DOUBLE); NDArray Wx('c', {inSize, numUnits}, sd::DataType::DOUBLE); NDArray Wh('c', {numUnits, numUnits}, sd::DataType::DOUBLE); - NDArray b('c', {2 * numUnits}, {0.0, 0.0, 0.0, 0.0, 0.1, 0.2, 0.3, 0.4}); + NDArray b ('c', {2*numUnits}, {0.0,0.0,0.0,0.0, 0.1,0.2,0.3,0.4}); NDArray ht('c', {bS, numUnits}, sd::DataType::DOUBLE); xt.assign(0.1); @@ -1925,29 +1340,27 @@ TEST_F(HelpersTests1, rnnCell_test1) { Wx.assign(0.3); Wh.assign(0.4); - NDArray expHt('c', {bS, numUnits}, - {0.492988, 0.56489956, 0.6291452, 0.6858091, 0.492988, - 0.56489956, 0.6291452, 0.6858091}); + NDArray expHt('c', {bS, numUnits}, {0.492988, 0.56489956, 0.6291452 , 0.6858091,0.492988, 0.56489956, 0.6291452 , 0.6858091}); - ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, - &ht_1, &ht); + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); } + /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, rnnCell_test2) { + const int bS = 2; - const int inSize = 10; + const int inSize = 10; const int numUnits = 4; auto xt = NDArrayFactory::create('c', {bS, inSize}); auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create( - 'c', {2 * numUnits}, {0.0, 0.0, 0.0, 0.0, 0.1, 0.2, 0.3, 0.4}); + auto b = NDArrayFactory::create('c', {2*numUnits}, {0.0,0.0,0.0,0.0, 0.1,0.2,0.3,0.4}); auto ht = NDArrayFactory::create('c', {bS, numUnits}); @@ -1956,13 +1369,9 @@ TEST_F(HelpersTests1, rnnCell_test2) { Wx.assign(0.3); Wh.assign(0.4); - auto expHt = NDArrayFactory::create( - 'c', {bS, numUnits}, - {0.6169093, 0.67506987, 0.72589741, 0.76986654, 0.6169093, 0.67506987, - 0.72589741, 0.76986654}); + auto expHt = NDArrayFactory::create('c', {bS, numUnits}, {0.6169093,0.67506987,0.72589741,0.76986654,0.6169093,0.67506987,0.72589741,0.76986654}); - ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, - &ht_1, &ht); + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -1970,16 +1379,16 @@ TEST_F(HelpersTests1, rnnCell_test2) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, rnnCell_test3) { + const int bS = 2; - const int inSize = 10; + const int inSize = 10; const int numUnits = 4; auto xt = NDArrayFactory::create('c', {bS, inSize}); auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create( - 'c', {2 * numUnits}, {0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08}); + auto b = NDArrayFactory::create('c', {2*numUnits}, {0.01,0.02,0.03,0.04, 0.05,0.06,0.07,0.08}); auto ht = NDArrayFactory::create('c', {bS, numUnits}); @@ -1988,13 +1397,9 @@ TEST_F(HelpersTests1, rnnCell_test3) { Wx.assign(0.3); Wh.assign(0.4); - auto expHt = NDArrayFactory::create( - 'c', {bS, numUnits}, - {0.5915195, 0.6043678, 0.6169093, 0.6291452, 0.5915195, 0.6043678, - 0.6169093, 0.6291452}); + auto expHt = NDArrayFactory::create('c', {bS, numUnits}, {0.5915195, 0.6043678, 0.6169093, 0.6291452,0.5915195, 0.6043678, 0.6169093, 0.6291452}); - ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, - &ht_1, &ht); + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -2002,31 +1407,28 @@ TEST_F(HelpersTests1, rnnCell_test3) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, rnnCell_test4) { + const int bS = 2; - const int inSize = 3; + const int inSize = 3; const int numUnits = 4; auto xt = NDArrayFactory::create('c', {bS, inSize}); auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); - auto b = NDArrayFactory::create('c', {2 * numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}); auto ht = NDArrayFactory::create('c', {bS, numUnits}); xt.linspace(0.01, 0.01); ht_1 = 0.2; - Wx = 0.3; - Wh = 0.4; - b = 0.25; + Wx = 0.3; + Wh = 0.4; + b = 0.25; - auto expHt = NDArrayFactory::create( - 'c', {bS, numUnits}, - {0.68474828, 0.68474828, 0.68474828, 0.68474828, 0.69882484, 0.69882484, - 0.69882484, 0.69882484}); + auto expHt = NDArrayFactory::create('c', {bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484}); - ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, - &ht_1, &ht); + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); @@ -2035,12 +1437,10 @@ TEST_F(HelpersTests1, rnnCell_test4) { #endif //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_1) { - auto x = NDArrayFactory::create('c', {3, 3}, - {10, 11, 12, 13, 14, 15, 16, 17, 18}); - auto y = - NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto expected = NDArrayFactory::create( - 'c', {3, 3}, {138., 171., 204., 174., 216., 258., 210., 261., 312.}); + + auto x = NDArrayFactory::create('c', {3,3}, {10,11,12,13,14,15,16,17,18}); + auto y = NDArrayFactory::create('c', {3,3}, {1,2,3,4,5,6,7,8,9}); + auto expected = NDArrayFactory::create('c', {3,3}, {138.,171.,204. ,174.,216.,258. ,210.,261.,312.}); auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); @@ -2048,34 +1448,30 @@ TEST_F(HelpersTests1, mmulHelper_test_1) { ASSERT_TRUE(expected.equalsTo(result)); delete result; + } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_2) { - auto x = NDArrayFactory::create('c', {3, 3}, - {10, 11, 12, 13, 14, 15, 16, 17, 18}); - auto y = - NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto expected = NDArrayFactory::create( - 'c', {3, 3}, {138., 171., 204., 174., 216., 258., 210., 261., 312.}); - auto result = NDArrayFactory::create('c', {3, 3}); + + auto x = NDArrayFactory::create('c', {3,3}, {10,11,12,13,14,15,16,17,18}); + auto y = NDArrayFactory::create('c', {3,3}, {1,2,3,4,5,6,7,8,9}); + auto expected = NDArrayFactory::create('c', {3,3}, {138.,171.,204. ,174.,216.,258. ,210.,261.,312.}); + auto result = NDArrayFactory::create('c', {3,3}); MmulHelper::mmul(&x, &y, &result, 1., 0.); ASSERT_TRUE(expected.isSameShape(&result)); ASSERT_TRUE(expected.equalsTo(&result)); + } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_3) { - auto x = NDArrayFactory::create('c', {3, 4}); - x.linspace(1); - auto y = NDArrayFactory::create('c', {4, 5}); - y.linspace(1); - auto expected = NDArrayFactory::create( - 'c', {3, 5}, - {110., 120., 130., 140., 150., 246., 272., 298., 324., 350., 382., 424., - 466., 508., 550.}); + + auto x = NDArrayFactory::create('c', {3,4}); x.linspace(1); + auto y = NDArrayFactory::create('c', {4,5}); y.linspace(1); + auto expected = NDArrayFactory::create('c', {3,5}, {110.,120.,130.,140.,150.,246.,272.,298.,324.,350.,382.,424.,466.,508.,550.}); auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); @@ -2087,15 +1483,11 @@ TEST_F(HelpersTests1, mmulHelper_test_3) { //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_4) { - auto x = NDArrayFactory::create('c', {3, 4}); - x.linspace(1); - auto y = NDArrayFactory::create('c', {4, 5}); - y.linspace(1); - auto expected = NDArrayFactory::create( - 'c', {3, 5}, - {110., 120., 130., 140., 150., 246., 272., 298., 324., 350., 382., 424., - 466., 508., 550.}); - auto result = NDArrayFactory::create('c', {3, 5}); + + auto x = NDArrayFactory::create('c', {3,4}); x.linspace(1); + auto y = NDArrayFactory::create('c', {4,5}); y.linspace(1); + auto expected = NDArrayFactory::create('c', {3,5}, {110.,120.,130.,140.,150.,246.,272.,298.,324.,350.,382.,424.,466.,508.,550.}); + auto result = NDArrayFactory::create('c', {3,5}); MmulHelper::mmul(&x, &y, &result, 1., 0.); @@ -2103,16 +1495,13 @@ TEST_F(HelpersTests1, mmulHelper_test_4) { ASSERT_TRUE(expected.equalsTo(&result)); } + //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_5) { - auto x = NDArrayFactory::create('c', {4, 3}); - x.linspace(1); - auto y = NDArrayFactory::create('c', {3, 5}); - y.linspace(1); - auto expected = NDArrayFactory::create( - 'c', {4, 5}, - {46., 52., 58., 64., 70., 100., 115., 130., 145., 160., - 154., 178., 202., 226., 250., 208., 241., 274., 307., 340.}); + + auto x = NDArrayFactory::create('c', {4,3}); x.linspace(1); + auto y = NDArrayFactory::create('c', {3,5}); y.linspace(1); + auto expected = NDArrayFactory::create('c', {4,5}, {46., 52., 58., 64., 70.,100.,115.,130.,145.,160.,154.,178.,202.,226.,250.,208.,241.,274.,307.,340.}); auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); @@ -2124,141 +1513,95 @@ TEST_F(HelpersTests1, mmulHelper_test_5) { //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_6) { - auto x = NDArrayFactory::create('c', {4, 3}); - x.linspace(1); - auto y = NDArrayFactory::create('c', {3, 5}); - y.linspace(1); - auto expected = NDArrayFactory::create( - 'c', {4, 5}, - {46., 52., 58., 64., 70., 100., 115., 130., 145., 160., - 154., 178., 202., 226., 250., 208., 241., 274., 307., 340.}); - auto result = NDArrayFactory::create('c', {4, 5}); + + auto x = NDArrayFactory::create('c', {4,3}); x.linspace(1); + auto y = NDArrayFactory::create('c', {3,5}); y.linspace(1); + auto expected = NDArrayFactory::create('c', {4,5}, {46., 52., 58., 64., 70.,100.,115.,130.,145.,160.,154.,178.,202.,226.,250.,208.,241.,274.,307.,340.}); + auto result = NDArrayFactory::create('c', {4,5}); MmulHelper::mmul(&x, &y, &result, 1., 0.); ASSERT_TRUE(expected.isSameShape(&result)); ASSERT_TRUE(expected.equalsTo(&result)); + } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulHelper_test_7) { + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); - auto exp = NDArrayFactory::create( - 'c', {4, 4}, {1, 2, 3, 4, 2, 4, 6, 8, 3, 6, 9, 12, 4, 8, 12, 16}); - auto result = NDArrayFactory::create('c', {4, 4}); + auto exp = NDArrayFactory::create('c', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); + auto result = NDArrayFactory::create('c', {4,4}); MmulHelper::mmul(&x, &y, &result, 1., 0.); ASSERT_TRUE(exp.isSameShape(&result)); ASSERT_TRUE(exp.equalsTo(&result)); + } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_1) { + auto a = NDArrayFactory::create('c', {2, 3, 4}); auto b = NDArrayFactory::create('c', {2, 5, 3}); - auto c = MmulHelper::tensorDot(&a, &b, {1}, {2}); + auto c = MmulHelper::tensorDot(&a, &b, {1}, {2}); - ASSERT_TRUE(c->isSameShape({2, 4, 2, 5})); + ASSERT_TRUE(c->isSameShape({2,4,2,5})); delete c; } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_2) { + auto a = NDArrayFactory::create('c', {7, 3, 4, 6}); auto b = NDArrayFactory::create('c', {2, 5, 3, 8, 4}); - auto c = MmulHelper::tensorDot(&a, &b, {2, 1}, {4, 2}); + auto c = MmulHelper::tensorDot(&a, &b, {2,1}, {4,2}); - ASSERT_TRUE(c->isSameShape({7, 6, 2, 5, 8})); + ASSERT_TRUE(c->isSameShape({7,6,2,5,8})); delete c; } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_3) { + auto a = NDArrayFactory::create('c', {7, 3, 4, 6}); auto b = NDArrayFactory::create('c', {2, 5, 3, 8, 4}); - auto c = NDArrayFactory::create('f', {7, 6, 2, 8, 5}); + auto c = NDArrayFactory::create('f', {7,6,2,8,5}); - MmulHelper::tensorDot(&a, &b, &c, {2, 1}, {4, 2}, {0, 1, 2, 4, 3}); + MmulHelper::tensorDot(&a, &b, &c, {2,1}, {4,2}, {0,1,2,4,3}); - ASSERT_TRUE(c.isSameShape({7, 6, 2, 8, 5})); + ASSERT_TRUE(c.isSameShape({7,6,2,8,5})); } //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_4) { + auto a = NDArrayFactory::create('c', {7, 3, 4, 3}); auto b = NDArrayFactory::create('c', {2, 5, 3, 2, 4}); - auto c = NDArrayFactory::create('f', {7, 3, 2, 2, 5}); - auto expected = NDArrayFactory::create( - 'c', {7, 3, 2, 2, 5}, - {754.5, 2014.5, 3274.5, 4534.5, 5794.5, 964.5, 2224.5, - 3484.5, 4744.5, 6004.5, 7054.5, 8314.5, 9574.5, 10834.5, - 12094.5, 7264.5, 8524.5, 9784.5, 11044.5, 12304.5, 786., - 2118., 3450., 4782., 6114., 1008., 2340., 3672., - 5004., 6336., 7446., 8778., 10110., 11442., 12774., - 7668., 9000., 10332., 11664., 12996., 817.5, 2221.5, - 3625.5, 5029.5, 6433.5, 1051.5, 2455.5, 3859.5, 5263.5, - 6667.5, 7837.5, 9241.5, 10645.5, 12049.5, 13453.5, 8071.5, - 9475.5, 10879.5, 12283.5, 13687.5, 1888.5, 5740.5, 9592.5, - 13444.5, 17296.5, 2530.5, 6382.5, 10234.5, 14086.5, 17938.5, - 21148.5, 25000.5, 28852.5, 32704.5, 36556.5, 21790.5, 25642.5, - 29494.5, 33346.5, 37198.5, 1920., 5844., 9768., 13692., - 17616., 2574., 6498., 10422., 14346., 18270., 21540., - 25464., 29388., 33312., 37236., 22194., 26118., 30042., - 33966., 37890., 1951.5, 5947.5, 9943.5, 13939.5, 17935.5, - 2617.5, 6613.5, 10609.5, 14605.5, 18601.5, 21931.5, 25927.5, - 29923.5, 33919.5, 37915.5, 22597.5, 26593.5, 30589.5, 34585.5, - 38581.5, 3022.5, 9466.5, 15910.5, 22354.5, 28798.5, 4096.5, - 10540.5, 16984.5, 23428.5, 29872.5, 35242.5, 41686.5, 48130.5, - 54574.5, 61018.5, 36316.5, 42760.5, 49204.5, 55648.5, 62092.5, - 3054., 9570., 16086., 22602., 29118., 4140., 10656., - 17172., 23688., 30204., 35634., 42150., 48666., 55182., - 61698., 36720., 43236., 49752., 56268., 62784., 3085.5, - 9673.5, 16261.5, 22849.5, 29437.5, 4183.5, 10771.5, 17359.5, - 23947.5, 30535.5, 36025.5, 42613.5, 49201.5, 55789.5, 62377.5, - 37123.5, 43711.5, 50299.5, 56887.5, 63475.5, 4156.5, 13192.5, - 22228.5, 31264.5, 40300.5, 5662.5, 14698.5, 23734.5, 32770.5, - 41806.5, 49336.5, 58372.5, 67408.5, 76444.5, 85480.5, 50842.5, - 59878.5, 68914.5, 77950.5, 86986.5, 4188., 13296., 22404., - 31512., 40620., 5706., 14814., 23922., 33030., 42138., - 49728., 58836., 67944., 77052., 86160., 51246., 60354., - 69462., 78570., 87678., 4219.5, 13399.5, 22579.5, 31759.5, - 40939.5, 5749.5, 14929.5, 24109.5, 33289.5, 42469.5, 50119.5, - 59299.5, 68479.5, 77659.5, 86839.5, 51649.5, 60829.5, 70009.5, - 79189.5, 88369.5, 5290.5, 16918.5, 28546.5, 40174.5, 51802.5, - 7228.5, 18856.5, 30484.5, 42112.5, 53740.5, 63430.5, 75058.5, - 86686.5, 98314.5, 109942.5, 65368.5, 76996.5, 88624.5, 100252.5, - 111880.5, 5322., 17022., 28722., 40422., 52122., 7272., - 18972., 30672., 42372., 54072., 63822., 75522., 87222., - 98922., 110622., 65772., 77472., 89172., 100872., 112572., - 5353.5, 17125.5, 28897.5, 40669.5, 52441.5, 7315.5, 19087.5, - 30859.5, 42631.5, 54403.5, 64213.5, 75985.5, 87757.5, 99529.5, - 111301.5, 66175.5, 77947.5, 89719.5, 101491.5, 113263.5, 6424.5, - 20644.5, 34864.5, 49084.5, 63304.5, 8794.5, 23014.5, 37234.5, - 51454.5, 65674.5, 77524.5, 91744.5, 105964.5, 120184.5, 134404.5, - 79894.5, 94114.5, 108334.5, 122554.5, 136774.5, 6456., 20748., - 35040., 49332., 63624., 8838., 23130., 37422., 51714., - 66006., 77916., 92208., 106500., 120792., 135084., 80298., - 94590., 108882., 123174., 137466., 6487.5, 20851.5, 35215.5, - 49579.5, 63943.5, 8881.5, 23245.5, 37609.5, 51973.5, 66337.5, - 78307.5, 92671.5, 107035.5, 121399.5, 135763.5, 80701.5, 95065.5, - 109429.5, 123793.5, 138157.5, 7558.5, 24370.5, 41182.5, 57994.5, - 74806.5, 10360.5, 27172.5, 43984.5, 60796.5, 77608.5, 91618.5, - 108430.5, 125242.5, 142054.5, 158866.5, 94420.5, 111232.5, 128044.5, - 144856.5, 161668.5, 7590., 24474., 41358., 58242., 75126., - 10404., 27288., 44172., 61056., 77940., 92010., 108894., - 125778., 142662., 159546., 94824., 111708., 128592., 145476., - 162360., 7621.5, 24577.5, 41533.5, 58489.5, 75445.5, 10447.5, - 27403.5, 44359.5, 61315.5, 78271.5, 92401.5, 109357.5, 126313.5, - 143269.5, 160225.5, 95227.5, 112183.5, 129139.5, 146095.5, 163051.5}); + auto c = NDArrayFactory::create('f', {7,3,2,2,5}); + auto expected = NDArrayFactory::create('c', {7,3,2,2,5}, { 754.5, 2014.5, 3274.5, 4534.5 , 5794.5, 964.5, 2224.5, 3484.5, 4744.5, 6004.5, 7054.5, 8314.5, 9574.5, 10834.5, 12094.5, 7264.5, 8524.5, 9784.5, 11044.5, 12304.5, 786. , 2118. , 3450. , 4782. , 6114. , 1008. , 2340. , 3672. , 5004. , 6336. , + 7446. , 8778. , 10110. , 11442. , 12774. , 7668. , 9000. , 10332. , 11664. , 12996. , 817.5, 2221.5, 3625.5, 5029.5, 6433.5, 1051.5, 2455.5, 3859.5, 5263.5, 6667.5, 7837.5, 9241.5, 10645.5, 12049.5, 13453.5, 8071.5, 9475.5, 10879.5, 12283.5, 13687.5, + 1888.5, 5740.5, 9592.5, 13444.5, 17296.5, 2530.5, 6382.5, 10234.5, 14086.5, 17938.5,21148.5, 25000.5, 28852.5, 32704.5, 36556.5,21790.5, 25642.5, 29494.5, 33346.5, 37198.5, 1920. , 5844. , 9768. , 13692. , 17616. , 2574. , 6498. , 10422. , 14346. , 18270. , + 21540. , 25464. , 29388. , 33312. , 37236. ,22194. , 26118. , 30042. , 33966. , 37890. , 1951.5, 5947.5, 9943.5, 13939.5, 17935.5, 2617.5, 6613.5, 10609.5, 14605.5, 18601.5,21931.5, 25927.5, 29923.5, 33919.5, 37915.5,22597.5, 26593.5, 30589.5, 34585.5, 38581.5, + 3022.5, 9466.5, 15910.5, 22354.5, 28798.5, 4096.5, 10540.5, 16984.5, 23428.5, 29872.5,35242.5, 41686.5, 48130.5, 54574.5, 61018.5,36316.5, 42760.5, 49204.5, 55648.5, 62092.5, 3054. , 9570. , 16086. , 22602. , 29118. , 4140. , 10656. , 17172. , 23688. , 30204. , + 35634. , 42150. , 48666. , 55182. , 61698. ,36720. , 43236. , 49752. , 56268. , 62784. , 3085.5, 9673.5, 16261.5, 22849.5, 29437.5, 4183.5, 10771.5, 17359.5, 23947.5, 30535.5,36025.5, 42613.5, 49201.5, 55789.5, 62377.5,37123.5, 43711.5, 50299.5, 56887.5, 63475.5, + 4156.5, 13192.5, 22228.5, 31264.5, 40300.5, 5662.5, 14698.5, 23734.5, 32770.5, 41806.5,49336.5, 58372.5, 67408.5, 76444.5, 85480.5,50842.5, 59878.5, 68914.5, 77950.5, 86986.5, 4188. , 13296. , 22404. , 31512. , 40620. , 5706. , 14814. , 23922. , 33030. , 42138. , + 49728. , 58836. , 67944. , 77052. , 86160. ,51246. , 60354. , 69462. , 78570. , 87678. , 4219.5, 13399.5, 22579.5, 31759.5, 40939.5, 5749.5, 14929.5, 24109.5, 33289.5, 42469.5,50119.5, 59299.5, 68479.5, 77659.5, 86839.5,51649.5, 60829.5, 70009.5, 79189.5, 88369.5, + 5290.5, 16918.5, 28546.5, 40174.5, 51802.5, 7228.5, 18856.5, 30484.5, 42112.5, 53740.5,63430.5, 75058.5, 86686.5, 98314.5,109942.5,65368.5, 76996.5, 88624.5,100252.5,111880.5, 5322. , 17022. , 28722. , 40422. , 52122. , 7272. , 18972. , 30672. , 42372. , 54072. , + 63822. , 75522. , 87222. , 98922. ,110622. ,65772. , 77472. , 89172. ,100872. ,112572. , 5353.5, 17125.5, 28897.5, 40669.5, 52441.5, 7315.5, 19087.5, 30859.5, 42631.5, 54403.5,64213.5, 75985.5, 87757.5, 99529.5,111301.5,66175.5, 77947.5, 89719.5,101491.5,113263.5, + 6424.5, 20644.5, 34864.5, 49084.5, 63304.5, 8794.5, 23014.5, 37234.5, 51454.5, 65674.5,77524.5, 91744.5,105964.5,120184.5,134404.5,79894.5, 94114.5,108334.5,122554.5,136774.5, 6456. , 20748. , 35040. , 49332. , 63624. , 8838. , 23130. , 37422. , 51714. , 66006. , + 77916. , 92208. ,106500. ,120792. ,135084. ,80298. , 94590. ,108882. ,123174. ,137466. , 6487.5, 20851.5, 35215.5, 49579.5, 63943.5, 8881.5, 23245.5, 37609.5, 51973.5, 66337.5,78307.5, 92671.5,107035.5,121399.5,135763.5,80701.5, 95065.5,109429.5,123793.5,138157.5, + 7558.5, 24370.5, 41182.5, 57994.5, 74806.5,10360.5, 27172.5, 43984.5, 60796.5, 77608.5,91618.5,108430.5,125242.5,142054.5,158866.5,94420.5,111232.5,128044.5,144856.5,161668.5, 7590. , 24474. , 41358. , 58242. , 75126. ,10404. , 27288. , 44172. , 61056. , 77940. , + 92010. ,108894. ,125778. ,142662. ,159546. ,94824. ,111708. ,128592. ,145476. ,162360. , 7621.5, 24577.5, 41533.5, 58489.5, 75445.5,10447.5, 27403.5, 44359.5, 61315.5, 78271.5,92401.5,109357.5,126313.5,143269.5,160225.5,95227.5,112183.5,129139.5,146095.5,163051.5}); a.linspace(0.5, 0.5); b.linspace(0.5, 0.5); - MmulHelper::tensorDot(&a, &b, &c, {2, 1}, {4, 2}, {0, 1, 2, 4, 3}); + MmulHelper::tensorDot(&a, &b, &c, {2,1}, {4,2}, {0,1,2,4,3}); ASSERT_TRUE(c.isSameShape(expected)); ASSERT_TRUE(c.equalsTo(expected)); @@ -2266,11 +1609,11 @@ TEST_F(HelpersTests1, tensordot_test_4) { //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_5) { + auto a = NDArrayFactory::create('c', {2, 3}); auto b = NDArrayFactory::create('c', {3, 4}); auto c = NDArrayFactory::create('f', {2, 4}); - auto expected = NDArrayFactory::create( - 'c', {2, 4}, {9.5, 11., 12.5, 14., 20.75, 24.5, 28.25, 32.}); + auto expected = NDArrayFactory::create('c', {2, 4}, {9.5,11.,12.5 ,14.,20.75 ,24.5,28.25,32.}); a.linspace(0.5, 0.5); b.linspace(0.5, 0.5); @@ -2284,19 +1627,16 @@ TEST_F(HelpersTests1, tensordot_test_5) { //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, tensordot_test_6) { - int bS = 2, iH = 3, iW = 2, iC = 2, mC = 2, kH = 2, kW = 2; - int oC = iC * mC; - int oH = 3, oW = 2; + + int bS=2, iH=3,iW=2, iC=2,mC=2, kH=2,kW=2; + int oC=iC*mC; + int oH=3,oW=2; auto a = NDArrayFactory::create('c', {bS, iC, kH, kW, oH, oW}); auto b = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto c = NDArrayFactory::create('c', {bS, oH, oW, iC * mC}); - auto expected = NDArrayFactory::create( - 'c', {bS, oH, oW, iC * mC}, - {100., 110., 336., 370., 107., 118., 345., 380., 114., 126., 354., 390., - 121., 134., 363., 400., 128., 142., 372., 410., 135., 150., 381., 420., - 436., 494., 768., 850., 443., 502., 777., 860., 450., 510., 786., 870., - 457., 518., 795., 880., 464., 526., 804., 890., 471., 534., 813., 900.}); + auto c = NDArrayFactory::create('c', {bS, oH, oW, iC*mC}); + auto expected = NDArrayFactory::create('c', {bS, oH, oW, iC*mC}, {100.,110.,336.,370.,107.,118.,345.,380.,114.,126.,354.,390.,121.,134.,363.,400.,128.,142.,372.,410.,135.,150.,381.,420., + 436.,494.,768.,850.,443.,502.,777.,860.,450.,510.,786.,870.,457.,518.,795.,880.,464.,526.,804.,890.,471.,534.,813.,900.}); a.linspace(0.5, 0.5); b.linspace(0.5, 0.5); @@ -2304,10 +1644,7 @@ TEST_F(HelpersTests1, tensordot_test_6) { auto cR = c.reshape(a.ordering(), {bS, oH, oW, iC, mC}); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] - MmulHelper::tensorDot(&a, &b, &cR, - {{1, 0, 4, 5, 2, 3}, {iC, bS * oH * oW, kW * kH}}, - {{2, 0, 1, 3}, {iC, kH * kW, mC}}, - {{3, 0, 1, 2, 4}, {iC, bS * oH * oW, mC}}); + MmulHelper::tensorDot(&a, &b, &cR, {{1,0,4,5,2,3}, {iC,bS*oH*oW,kW*kH}}, {{2,0,1,3},{iC,kH*kW,mC}}, {{3,0,1,2,4},{iC, bS*oH*oW, mC}}); // c.printBuffer(); ASSERT_TRUE(c.isSameShape(expected)); @@ -2333,20 +1670,21 @@ TEST_F(HelpersTests1, mmmulHelperAgain) { //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, OpArgsHolder_test1) { + auto x1 = NDArrayFactory::create('c', {1, 1}); auto x2 = NDArrayFactory::create('c', {2, 2}); auto x3 = NDArrayFactory::create('c', {3, 3}); OpArgsHolder holder1({&x1}); - OpArgsHolder holder2({&x1, &x2, &x3}, {4.f, 5.f}, {6}); + OpArgsHolder holder2({&x1,&x2,&x3}, {4.f, 5.f}, {6}); ASSERT_TRUE(holder1.getNumInArrs() == 1); - ASSERT_TRUE(holder1.getNumTArgs() == 0); - ASSERT_TRUE(holder1.getNumIArgs() == 0); + ASSERT_TRUE(holder1.getNumTArgs() == 0); + ASSERT_TRUE(holder1.getNumIArgs() == 0); ASSERT_TRUE(holder2.getNumInArrs() == 3); - ASSERT_TRUE(holder2.getNumTArgs() == 2); - ASSERT_TRUE(holder2.getNumIArgs() == 1); + ASSERT_TRUE(holder2.getNumTArgs() == 2); + ASSERT_TRUE(holder2.getNumIArgs() == 1); const std::vector& isArrAlloc1 = holder1.getAllocInfo(); ASSERT_TRUE(isArrAlloc1.size() == 0); @@ -2357,50 +1695,46 @@ TEST_F(HelpersTests1, OpArgsHolder_test1) { //////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, OpArgsHolder_test2) { + auto x1 = NDArrayFactory::create('c', {1, 1}); auto x2 = NDArrayFactory::create('c', {2, 2}); auto x3 = NDArrayFactory::create('c', {3, 3}); auto grad = NDArrayFactory::create('c', {2, 3}); - OpArgsHolder holderFF({&x1, &x2, &x3}, {4.f, 5.f}, {6}); + OpArgsHolder holderFF({&x1,&x2,&x3}, {4.f, 5.f}, {6}); OpArgsHolder holderBP1 = holderFF.createArgsHolderForBP({&grad}); OpArgsHolder holderBP2 = holderFF.createArgsHolderForBP({&grad}, true); ASSERT_TRUE(holderBP1.getNumInArrs() == 4); - ASSERT_TRUE(holderBP1.getNumTArgs() == 2); - ASSERT_TRUE(holderBP1.getNumIArgs() == 1); + ASSERT_TRUE(holderBP1.getNumTArgs() == 2); + ASSERT_TRUE(holderBP1.getNumIArgs() == 1); ASSERT_TRUE(holderBP2.getNumInArrs() == 4); - ASSERT_TRUE(holderBP2.getNumTArgs() == 2); - ASSERT_TRUE(holderBP2.getNumIArgs() == 1); + ASSERT_TRUE(holderBP2.getNumTArgs() == 2); + ASSERT_TRUE(holderBP2.getNumIArgs() == 1); const std::vector& isArrAllocBP1 = holderBP1.getAllocInfo(); ASSERT_TRUE(isArrAllocBP1.size() == 0); const std::vector& isArrAllocBP2 = holderBP2.getAllocInfo(); - for (int i = 0; i < holderFF.getNumInArrs(); ++i) { + for(int i = 0; i < holderFF.getNumInArrs(); ++i) { ASSERT_TRUE(static_cast(isArrAllocBP2[i]) == true); } - ASSERT_TRUE(static_cast(isArrAllocBP2[holderFF.getNumInArrs() + 1]) == - false); + ASSERT_TRUE(static_cast(isArrAllocBP2[holderFF.getNumInArrs()+1]) == false); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, OpArgsHolder_test3) { - auto input = - NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); - auto gradO = NDArrayFactory::create('c', {4, 9}); - auto exp = NDArrayFactory::create( - 'c', {4, 9}, {1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, - 1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6}); - auto gradIExp = NDArrayFactory::create( - 'c', {2, 3}, {0.78, 0.84, 0.9, 1.32, 1.38, 1.44}); + + auto input = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); + auto gradO = NDArrayFactory::create('c', {4, 9}); + auto exp = NDArrayFactory::create('c', {4, 9}, {1, 2, 3, 1, 2, 3, 1, 2, 3,4, 5, 6, 4, 5, 6, 4, 5, 6,1, 2, 3, 1, 2, 3, 1, 2, 3,4, 5, 6, 4, 5, 6, 4, 5, 6}); + auto gradIExp = NDArrayFactory::create('c', {2, 3}, {0.78, 0.84, 0.9,1.32, 1.38, 1.44}); gradO.linspace(0.01, 0.01); OpArgsHolder holderFF({&input}, {}, {2, 3}); - sd::ops::tile opFF; // the kind of op doesn't matter, we simply check here - // whether op.execute() works with OpArgsHolder correctly + sd::ops::tile opFF; // the kind of op doesn't matter, we simply check here whether op.execute() works with OpArgsHolder correctly auto results = opFF.execute(holderFF); auto tiled = results.at(0); ASSERT_EQ(Status::OK(), results.status()); @@ -2414,12 +1748,14 @@ TEST_F(HelpersTests1, OpArgsHolder_test3) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI)); + } + ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test1) { - auto x = NDArrayFactory::create('c', {2, 3}, - {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}); + + auto x = NDArrayFactory::create('c', {2, 3}, {0.1, 0.2, 0.3, 0.4, 0.5 ,0.6}); auto gradO = NDArrayFactory::create('c', {2, 3}); const OpArgsHolder argsHolderFF({&x}, {}, {}); @@ -2428,146 +1764,135 @@ TEST_F(HelpersTests1, checkGrad_test1) { sd::ops::sigmoid opFF; sd::ops::sigmoid_bp opBP; - const bool isGradCorrect = - GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test2) { - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); - auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); x.linspace(1); weights.linspace(0.1, 0.1); - weights.permutei({2, 3, 1, 0}); + weights.permutei({2,3,1,0}); - const OpArgsHolder argsHolderFF({&x, &weights}, {}, - {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderBP({&x, &weights, &gradO}, {}, - {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderFF({&x, &weights}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); sd::ops::conv2d opFF; sd::ops::conv2d_bp opBP; - const bool isGradCorrect = - GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test3) { - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); - auto bias = NDArrayFactory::create('c', {2, 1}); - auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); x.linspace(1); weights.linspace(0.1, 0.1); bias = 0.5; - weights.permutei({2, 3, 1, 0}); + weights.permutei({2,3,1,0}); - const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, - {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, - {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); sd::ops::conv2d opFF; sd::ops::conv2d_bp opBP; - const bool isGradCorrect = - GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test4) { - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); - auto bias = NDArrayFactory::create('c', {2, 1}); - auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); x.linspace(1); weights.linspace(0.1, 0.1); bias = 0.5; - weights.permutei({2, 3, 1, 0}); + weights.permutei({2,3,1,0}); - const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, - {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, - {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); sd::ops::conv2d opFF; sd::ops::conv2d_bp opBP; - const bool isGradCorrect = - GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 0, 1}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 0, 1}); ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test5) { - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); - auto bias = NDArrayFactory::create('c', {2, 1}); - auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); x.linspace(1); weights.linspace(0.1, 0.1); bias = 0.5; - weights.permutei({2, 3, 1, 0}); + weights.permutei({2,3,1,0}); - const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, - {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, - {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); sd::ops::conv2d opFF; sd::ops::conv2d_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad( - opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1}, {0.5, 1}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1}, {0.5, 1}); ASSERT_TRUE(isGradCorrect); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, checkGrad_test6) { - auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); - auto bias = NDArrayFactory::create('c', {2, 1}); - auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); x.linspace(1); weights.linspace(0.1, 0.1); bias = 0.5; - weights.permutei({2, 3, 1, 0}); + weights.permutei({2,3,1,0}); - const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, - {2, 2, 1, 1, 0, 0, 1, 1, 1}); - const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, - {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); sd::ops::conv2d opFF; sd::ops::conv2d_bp opBP; - const bool isGradCorrect = - GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 0, 1}, - {0.5, 1}, GradCheck::MEAN); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 0, 1}, {0.5, 1}, GradCheck::MEAN); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softMaxForVector_test1) { - auto input = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto output = NDArrayFactory::create('c', {1, 5}); - auto expOutput = NDArrayFactory::create('c', {1, 5}); + + auto input = NDArrayFactory::create('c', {1,5}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {1,5}); + auto expOutput = NDArrayFactory::create('c', {1,5}); expOutput = 1; ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); @@ -2577,11 +1902,10 @@ TEST_F(HelpersTests1, softMaxForVector_test1) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softMaxForVector_test2) { - auto input = NDArrayFactory::create('c', {5, 1}, {1, 2, 3, 4, 5}); - auto output = NDArrayFactory::create('c', {5, 1}); - auto expOutput = NDArrayFactory::create( - 'c', {5, 1}, - {0.01165623, 0.03168492, 0.08612854, 0.23412166, 0.63640865}); + + auto input = NDArrayFactory::create('c', {5,1}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {5,1}); + auto expOutput = NDArrayFactory::create('c', {5,1}, {0.01165623, 0.03168492, 0.08612854, 0.23412166, 0.63640865}); ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); @@ -2590,10 +1914,10 @@ TEST_F(HelpersTests1, softMaxForVector_test2) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softMaxForVector_test3) { - auto input = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + + auto input= NDArrayFactory::create('c', {5}, {1,2,3,4,5}); auto output = NDArrayFactory::create('c', {5}); - auto expOutput = NDArrayFactory::create( - 'c', {5}, {0.01165623, 0.03168492, 0.08612854, 0.23412166, 0.63640865}); + auto expOutput = NDArrayFactory::create('c', {5}, {0.01165623, 0.03168492, 0.08612854, 0.23412166, 0.63640865}); ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); @@ -2602,226 +1926,30 @@ TEST_F(HelpersTests1, softMaxForVector_test3) { ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softMaxForVector_test4) { + NDArray input('c', {1500}, sd::DataType::DOUBLE); NDArray output('c', {1500}, sd::DataType::DOUBLE); - NDArray expOutput( - 'c', {1500}, - {0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.00001, 0.000001, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, - 0.000001, 0.000001, 0.000001, 0.000001, 0.000002, 0.000002, 0.000002, - 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, - 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, - 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, - 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, - 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, - 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, - 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000003, - 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, - 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, - 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, - 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, - 0.000003, 0.000003, 0.000003, 0.000003, 0.000004, 0.000004, 0.000004, - 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, - 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, - 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, - 0.000004, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, - 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, - 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, - 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, - 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, - 0.000006, 0.000006, 0.000006, 0.000007, 0.000007, 0.000007, 0.000007, - 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, - 0.000007, 0.000007, 0.000007, 0.000008, 0.000008, 0.000008, 0.000008, - 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, - 0.000008, 0.000008, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, - 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000010, - 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, - 0.000010, 0.000010, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, - 0.000011, 0.000011, 0.000011, 0.000011, 0.000012, 0.000012, 0.000012, - 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000013, - 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000014, - 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000015, - 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000016, - 0.000016, 0.000016, 0.000016, 0.000016, 0.000016, 0.000017, 0.000017, - 0.000017, 0.000017, 0.000017, 0.000017, 0.000018, 0.000018, 0.000018, - 0.000018, 0.000018, 0.000018, 0.000019, 0.000019, 0.000019, 0.000019, - 0.000019, 0.000020, 0.000020, 0.000020, 0.000020, 0.000020, 0.000021, - 0.000021, 0.000021, 0.000021, 0.000021, 0.000022, 0.000022, 0.000022, - 0.000022, 0.000023, 0.000023, 0.000023, 0.000023, 0.000023, 0.000024, - 0.000024, 0.000024, 0.000024, 0.000025, 0.000025, 0.000025, 0.000025, - 0.000026, 0.000026, 0.000026, 0.000026, 0.000027, 0.000027, 0.000027, - 0.000028, 0.000028, 0.000028, 0.000028, 0.000029, 0.000029, 0.000029, - 0.000030, 0.000030, 0.000030, 0.000030, 0.000031, 0.000031, 0.000031, - 0.000032, 0.000032, 0.000032, 0.000033, 0.000033, 0.000033, 0.000034, - 0.000034, 0.000034, 0.000035, 0.000035, 0.000035, 0.000036, 0.000036, - 0.000036, 0.000037, 0.000037, 0.000038, 0.000038, 0.000038, 0.000039, - 0.000039, 0.000039, 0.000040, 0.000040, 0.000041, 0.000041, 0.000041, - 0.000042, 0.000042, 0.000043, 0.000043, 0.000044, 0.000044, 0.000044, - 0.000045, 0.000045, 0.000046, 0.000046, 0.000047, 0.000047, 0.000048, - 0.000048, 0.000049, 0.000049, 0.000050, 0.000050, 0.000051, 0.000051, - 0.000052, 0.000052, 0.000053, 0.000053, 0.000054, 0.000054, 0.000055, - 0.000055, 0.000056, 0.000057, 0.000057, 0.000058, 0.000058, 0.000059, - 0.000059, 0.000060, 0.000061, 0.000061, 0.000062, 0.000063, 0.000063, - 0.000064, 0.000064, 0.000065, 0.000066, 0.000066, 0.000067, 0.000068, - 0.000068, 0.000069, 0.000070, 0.000070, 0.000071, 0.000072, 0.000073, - 0.000073, 0.000074, 0.000075, 0.000076, 0.000076, 0.000077, 0.000078, - 0.000079, 0.000079, 0.000080, 0.000081, 0.000082, 0.000083, 0.000084, - 0.000084, 0.000085, 0.000086, 0.000087, 0.000088, 0.000089, 0.000090, - 0.000090, 0.000091, 0.000092, 0.000093, 0.000094, 0.000095, 0.000096, - 0.000097, 0.000098, 0.000099, 0.000100, 0.000101, 0.000102, 0.000103, - 0.000104, 0.000105, 0.000106, 0.000107, 0.000108, 0.000109, 0.000111, - 0.000112, 0.000113, 0.000114, 0.000115, 0.000116, 0.000117, 0.000119, - 0.000120, 0.000121, 0.000122, 0.000123, 0.000125, 0.000126, 0.000127, - 0.000128, 0.000130, 0.000131, 0.000132, 0.000134, 0.000135, 0.000136, - 0.000138, 0.000139, 0.000141, 0.000142, 0.000143, 0.000145, 0.000146, - 0.000148, 0.000149, 0.000151, 0.000152, 0.000154, 0.000155, 0.000157, - 0.000158, 0.000160, 0.000162, 0.000163, 0.000165, 0.000167, 0.000168, - 0.000170, 0.000172, 0.000173, 0.000175, 0.000177, 0.000179, 0.000180, - 0.000182, 0.000184, 0.000186, 0.000188, 0.000190, 0.000192, 0.000194, - 0.000195, 0.000197, 0.000199, 0.000201, 0.000203, 0.000205, 0.000208, - 0.000210, 0.000212, 0.000214, 0.000216, 0.000218, 0.000220, 0.000223, - 0.000225, 0.000227, 0.000229, 0.000232, 0.000234, 0.000236, 0.000239, - 0.000241, 0.000244, 0.000246, 0.000248, 0.000251, 0.000253, 0.000256, - 0.000259, 0.000261, 0.000264, 0.000266, 0.000269, 0.000272, 0.000275, - 0.000277, 0.000280, 0.000283, 0.000286, 0.000289, 0.000292, 0.000295, - 0.000297, 0.000300, 0.000303, 0.000307, 0.000310, 0.000313, 0.000316, - 0.000319, 0.000322, 0.000325, 0.000329, 0.000332, 0.000335, 0.000339, - 0.000342, 0.000346, 0.000349, 0.000353, 0.000356, 0.000360, 0.000363, - 0.000367, 0.000371, 0.000374, 0.000378, 0.000382, 0.000386, 0.000390, - 0.000394, 0.000398, 0.000402, 0.000406, 0.000410, 0.000414, 0.000418, - 0.000422, 0.000426, 0.000431, 0.000435, 0.000439, 0.000444, 0.000448, - 0.000453, 0.000457, 0.000462, 0.000467, 0.000471, 0.000476, 0.000481, - 0.000486, 0.000490, 0.000495, 0.000500, 0.000505, 0.000510, 0.000516, - 0.000521, 0.000526, 0.000531, 0.000537, 0.000542, 0.000547, 0.000553, - 0.000559, 0.000564, 0.000570, 0.000576, 0.000581, 0.000587, 0.000593, - 0.000599, 0.000605, 0.000611, 0.000617, 0.000623, 0.000630, 0.000636, - 0.000642, 0.000649, 0.000655, 0.000662, 0.000669, 0.000675, 0.000682, - 0.000689, 0.000696, 0.000703, 0.000710, 0.000717, 0.000724, 0.000732, - 0.000739, 0.000746, 0.000754, 0.000762, 0.000769, 0.000777, 0.000785, - 0.000793, 0.000801, 0.000809, 0.000817, 0.000825, 0.000833, 0.000842, - 0.000850, 0.000859, 0.000867, 0.000876, 0.000885, 0.000894, 0.000903, - 0.000912, 0.000921, 0.000930, 0.000939, 0.000949, 0.000958, 0.000968, - 0.000978, 0.000988, 0.000998, 0.001008, 0.001018, 0.001028, 0.001038, - 0.001049, 0.001059, 0.001070, 0.001081, 0.001092, 0.001103, 0.001114, - 0.001125, 0.001136, 0.001148, 0.001159, 0.001171, 0.001182, 0.001194, - 0.001206, 0.001218, 0.001231, 0.001243, 0.001256, 0.001268, 0.001281, - 0.001294, 0.001307, 0.001320, 0.001333, 0.001347, 0.001360, 0.001374, - 0.001388, 0.001402, 0.001416, 0.001430, 0.001444, 0.001459, 0.001473, - 0.001488, 0.001503, 0.001518, 0.001534, 0.001549, 0.001565, 0.001580, - 0.001596, 0.001612, 0.001628, 0.001645, 0.001661, 0.001678, 0.001695, - 0.001712, 0.001729, 0.001746, 0.001764, 0.001782, 0.001800, 0.001818, - 0.001836, 0.001854, 0.001873, 0.001892, 0.001911, 0.001930, 0.001950, - 0.001969, 0.001989, 0.002009, 0.002029, 0.002049, 0.002070, 0.002091, - 0.002112, 0.002133, 0.002155, 0.002176, 0.002198, 0.002220, 0.002242, - 0.002265, 0.002288, 0.002311, 0.002334, 0.002357, 0.002381, 0.002405, - 0.002429, 0.002454, 0.002478, 0.002503, 0.002528, 0.002554, 0.002579, - 0.002605, 0.002632, 0.002658, 0.002685, 0.002712, 0.002739, 0.002767, - 0.002794, 0.002822, 0.002851, 0.002879, 0.002908, 0.002938, 0.002967, - 0.002997, 0.003027, 0.003057, 0.003088, 0.003119, 0.003151, 0.003182, - 0.003214, 0.003247, 0.003279, 0.003312, 0.003345, 0.003379, 0.003413, - 0.003447, 0.003482, 0.003517, 0.003552, 0.003588, 0.003624, 0.003660, - 0.003697, 0.003734, 0.003772, 0.003810, 0.003848, 0.003887, 0.003926, - 0.003965, 0.004005, 0.004045, 0.004086, 0.004127, 0.004169, 0.004211, - 0.004253, 0.004296, 0.004339, 0.004382, 0.004426, 0.004471, 0.004516, - 0.004561, 0.004607, 0.004653, 0.004700, 0.004747, 0.004795, 0.004843, - 0.004892, 0.004941, 0.004991, 0.005041, 0.005092, 0.005143, 0.005194, - 0.005247, 0.005299, 0.005353, 0.005406, 0.005461, 0.005516, 0.005571, - 0.005627, 0.005684, 0.005741, 0.005798, 0.005857, 0.005916, 0.005975, - 0.006035, 0.006096, 0.006157, 0.006219, 0.006281, 0.006345, 0.006408, - 0.006473, 0.006538, 0.006603, 0.006670, 0.006737, 0.006805, 0.006873, - 0.006942, 0.007012, 0.007082, 0.007153, 0.007225, 0.007298, 0.007371, - 0.007445, 0.007520, 0.007596, 0.007672, 0.007749, 0.007827, 0.007906, - 0.007985, 0.008065, 0.008147, 0.008228, 0.008311, 0.008395, 0.008479, - 0.008564, 0.008650, 0.008737, 0.008825, 0.008914, 0.009003, 0.009094, - 0.009185, 0.009277, 0.009371, 0.009465, 0.009560, 0.009656, 0.009753, - 0.009851, 0.009950}, - sd::DataType::DOUBLE); + NDArray expOutput('c', {1500}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.00001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001,0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, + 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002,0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, + 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003,0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, + 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005,0.000005, 0.000005, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, + 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009,0.000009, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, + 0.000012, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000016, 0.000016, 0.000016, 0.000016, 0.000016, 0.000016,0.000017, 0.000017, 0.000017, 0.000017, 0.000017, 0.000017, 0.000018, 0.000018, 0.000018, 0.000018, 0.000018, 0.000018, 0.000019, 0.000019, 0.000019, 0.000019, 0.000019, 0.000020, 0.000020, 0.000020, 0.000020, 0.000020, 0.000021, 0.000021, 0.000021, 0.000021, 0.000021, 0.000022, + 0.000022, 0.000022, 0.000022, 0.000023, 0.000023, 0.000023, 0.000023, 0.000023, 0.000024, 0.000024, 0.000024, 0.000024, 0.000025, 0.000025, 0.000025, 0.000025, 0.000026, 0.000026, 0.000026, 0.000026, 0.000027, 0.000027, 0.000027, 0.000028, 0.000028, 0.000028, 0.000028, 0.000029,0.000029, 0.000029, 0.000030, 0.000030, 0.000030, 0.000030, 0.000031, 0.000031, 0.000031, 0.000032, 0.000032, 0.000032, 0.000033, 0.000033, 0.000033, 0.000034, 0.000034, 0.000034, 0.000035, 0.000035, 0.000035, 0.000036, 0.000036, 0.000036, 0.000037, 0.000037, 0.000038, 0.000038, + 0.000038, 0.000039, 0.000039, 0.000039, 0.000040, 0.000040, 0.000041, 0.000041, 0.000041, 0.000042, 0.000042, 0.000043, 0.000043, 0.000044, 0.000044, 0.000044, 0.000045, 0.000045, 0.000046, 0.000046, 0.000047, 0.000047, 0.000048, 0.000048, 0.000049, 0.000049, 0.000050, 0.000050,0.000051, 0.000051, 0.000052, 0.000052, 0.000053, 0.000053, 0.000054, 0.000054, 0.000055, 0.000055, 0.000056, 0.000057, 0.000057, 0.000058, 0.000058, 0.000059, 0.000059, 0.000060, 0.000061, 0.000061, 0.000062, 0.000063, 0.000063, 0.000064, 0.000064, 0.000065, 0.000066, 0.000066, + 0.000067, 0.000068, 0.000068, 0.000069, 0.000070, 0.000070, 0.000071, 0.000072, 0.000073, 0.000073, 0.000074, 0.000075, 0.000076, 0.000076, 0.000077, 0.000078, 0.000079, 0.000079, 0.000080, 0.000081, 0.000082, 0.000083, 0.000084, 0.000084, 0.000085, 0.000086, 0.000087, 0.000088,0.000089, 0.000090, 0.000090, 0.000091, 0.000092, 0.000093, 0.000094, 0.000095, 0.000096, 0.000097, 0.000098, 0.000099, 0.000100, 0.000101, 0.000102, 0.000103, 0.000104, 0.000105, 0.000106, 0.000107, 0.000108, 0.000109, 0.000111, 0.000112, 0.000113, 0.000114, 0.000115, 0.000116, + 0.000117, 0.000119, 0.000120, 0.000121, 0.000122, 0.000123, 0.000125, 0.000126, 0.000127, 0.000128, 0.000130, 0.000131, 0.000132, 0.000134, 0.000135, 0.000136, 0.000138, 0.000139, 0.000141, 0.000142, 0.000143, 0.000145, 0.000146, 0.000148, 0.000149, 0.000151, 0.000152, 0.000154,0.000155, 0.000157, 0.000158, 0.000160, 0.000162, 0.000163, 0.000165, 0.000167, 0.000168, 0.000170, 0.000172, 0.000173, 0.000175, 0.000177, 0.000179, 0.000180, 0.000182, 0.000184, 0.000186, 0.000188, 0.000190, 0.000192, 0.000194, 0.000195, 0.000197, 0.000199, 0.000201, 0.000203, + 0.000205, 0.000208, 0.000210, 0.000212, 0.000214, 0.000216, 0.000218, 0.000220, 0.000223, 0.000225, 0.000227, 0.000229, 0.000232, 0.000234, 0.000236, 0.000239, 0.000241, 0.000244, 0.000246, 0.000248, 0.000251, 0.000253, 0.000256, 0.000259, 0.000261, 0.000264, 0.000266, 0.000269,0.000272, 0.000275, 0.000277, 0.000280, 0.000283, 0.000286, 0.000289, 0.000292, 0.000295, 0.000297, 0.000300, 0.000303, 0.000307, 0.000310, 0.000313, 0.000316, 0.000319, 0.000322, 0.000325, 0.000329, 0.000332, 0.000335, 0.000339, 0.000342, 0.000346, 0.000349, 0.000353, 0.000356, + 0.000360, 0.000363, 0.000367, 0.000371, 0.000374, 0.000378, 0.000382, 0.000386, 0.000390, 0.000394, 0.000398, 0.000402, 0.000406, 0.000410, 0.000414, 0.000418, 0.000422, 0.000426, 0.000431, 0.000435, 0.000439, 0.000444, 0.000448, 0.000453, 0.000457, 0.000462, 0.000467, 0.000471,0.000476, 0.000481, 0.000486, 0.000490, 0.000495, 0.000500, 0.000505, 0.000510, 0.000516, 0.000521, 0.000526, 0.000531, 0.000537, 0.000542, 0.000547, 0.000553, 0.000559, 0.000564, 0.000570, 0.000576, 0.000581, 0.000587, 0.000593, 0.000599, 0.000605, 0.000611, 0.000617, 0.000623, + 0.000630, 0.000636, 0.000642, 0.000649, 0.000655, 0.000662, 0.000669, 0.000675, 0.000682, 0.000689, 0.000696, 0.000703, 0.000710, 0.000717, 0.000724, 0.000732, 0.000739, 0.000746, 0.000754, 0.000762, 0.000769, 0.000777, 0.000785, 0.000793, 0.000801, 0.000809, 0.000817, 0.000825,0.000833, 0.000842, 0.000850, 0.000859, 0.000867, 0.000876, 0.000885, 0.000894, 0.000903, 0.000912, 0.000921, 0.000930, 0.000939, 0.000949, 0.000958, 0.000968, 0.000978, 0.000988, 0.000998, 0.001008, 0.001018, 0.001028, 0.001038, 0.001049, 0.001059, 0.001070, 0.001081, 0.001092, + 0.001103, 0.001114, 0.001125, 0.001136, 0.001148, 0.001159, 0.001171, 0.001182, 0.001194, 0.001206, 0.001218, 0.001231, 0.001243, 0.001256, 0.001268, 0.001281, 0.001294, 0.001307, 0.001320, 0.001333, 0.001347, 0.001360, 0.001374, 0.001388, 0.001402, 0.001416, 0.001430, 0.001444,0.001459, 0.001473, 0.001488, 0.001503, 0.001518, 0.001534, 0.001549, 0.001565, 0.001580, 0.001596, 0.001612, 0.001628, 0.001645, 0.001661, 0.001678, 0.001695, 0.001712, 0.001729, 0.001746, 0.001764, 0.001782, 0.001800, 0.001818, 0.001836, 0.001854, 0.001873, 0.001892, 0.001911, + 0.001930, 0.001950, 0.001969, 0.001989, 0.002009, 0.002029, 0.002049, 0.002070, 0.002091, 0.002112, 0.002133, 0.002155, 0.002176, 0.002198, 0.002220, 0.002242, 0.002265, 0.002288, 0.002311, 0.002334, 0.002357, 0.002381, 0.002405, 0.002429, 0.002454, 0.002478, 0.002503, 0.002528,0.002554, 0.002579, 0.002605, 0.002632, 0.002658, 0.002685, 0.002712, 0.002739, 0.002767, 0.002794, 0.002822, 0.002851, 0.002879, 0.002908, 0.002938, 0.002967, 0.002997, 0.003027, 0.003057, 0.003088, 0.003119, 0.003151, 0.003182, 0.003214, 0.003247, 0.003279, 0.003312, 0.003345, + 0.003379, 0.003413, 0.003447, 0.003482, 0.003517, 0.003552, 0.003588, 0.003624, 0.003660, 0.003697, 0.003734, 0.003772, 0.003810, 0.003848, 0.003887, 0.003926, 0.003965, 0.004005, 0.004045, 0.004086, 0.004127, 0.004169, 0.004211, 0.004253, 0.004296, 0.004339, 0.004382, 0.004426,0.004471, 0.004516, 0.004561, 0.004607, 0.004653, 0.004700, 0.004747, 0.004795, 0.004843, 0.004892, 0.004941, 0.004991, 0.005041, 0.005092, 0.005143, 0.005194, 0.005247, 0.005299, 0.005353, 0.005406, 0.005461, 0.005516, 0.005571, 0.005627, 0.005684, 0.005741, 0.005798, 0.005857, + 0.005916, 0.005975, 0.006035, 0.006096, 0.006157, 0.006219, 0.006281, 0.006345, 0.006408, 0.006473, 0.006538, 0.006603, 0.006670, 0.006737, 0.006805, 0.006873, 0.006942, 0.007012, 0.007082, 0.007153, 0.007225, 0.007298, 0.007371, 0.007445, 0.007520, 0.007596, 0.007672, 0.007749,0.007827, 0.007906, 0.007985, 0.008065, 0.008147, 0.008228, 0.008311, 0.008395, 0.008479, 0.008564, 0.008650, 0.008737, 0.008825, 0.008914, 0.009003, 0.009094, 0.009185, 0.009277, 0.009371, 0.009465, 0.009560, 0.009656, 0.009753, 0.009851, 0.009950}, sd::DataType::DOUBLE); input.linspace(0.01, 0.01); ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); @@ -2831,324 +1959,78 @@ TEST_F(HelpersTests1, softMaxForVector_test4) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, logSoftMaxForVector_test1) { - auto input = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); - auto output = NDArrayFactory::create('c', {1, 5}); - auto expOutput = NDArrayFactory::create('c', {1, 5}); + + auto input = NDArrayFactory::create('c', {1,5}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {1,5}); + auto expOutput = NDArrayFactory::create('c', {1,5}); expOutput = 0; - ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, - 0); + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); ASSERT_TRUE(output.equalsTo(&expOutput)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, logSoftMaxForVector_test2) { - auto input = NDArrayFactory::create('c', {5, 1}, {1, 2, 3, 4, 5}); - auto output = NDArrayFactory::create('c', {5, 1}); - auto expOutput = NDArrayFactory::create( - 'c', {5, 1}, - {-4.4519144, -3.4519144, -2.4519144, -1.4519144, -0.4519144}); - ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, - 0); + auto input= NDArrayFactory::create('c', {5,1}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {5,1}); + auto expOutput = NDArrayFactory::create('c', {5,1}, {-4.4519144, -3.4519144, -2.4519144, -1.4519144, -0.4519144}); + + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); ASSERT_TRUE(output.equalsTo(&expOutput)); } /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, logSoftMaxForVector_test3) { - auto input = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + + auto input= NDArrayFactory::create('c', {5}, {1,2,3,4,5}); auto output = NDArrayFactory::create('c', {5}); - auto expOutput = NDArrayFactory::create( - 'c', {5}, {-4.4519144, -3.4519144, -2.4519144, -1.4519144, -0.4519144}); + auto expOutput = NDArrayFactory::create('c', {5}, {-4.4519144, -3.4519144, -2.4519144, -1.4519144, -0.4519144}); - ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, - 0); + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); ASSERT_TRUE(output.equalsTo(&expOutput)); } ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, logSoftMaxForVector_test4) { + NDArray input('c', {1500}, sd::DataType::DOUBLE); NDArray output('c', {1500}, sd::DataType::DOUBLE); - NDArray expOutput( - 'c', {1500}, - {-8.154773, -8.153772, -8.152773, -8.151772, -8.150773, -8.149773, - -8.148773, -8.147773, -8.146772, -8.145773, -8.144773, -8.143773, - -8.142773, -8.141773, -8.140773, -8.139772, -8.138773, -8.137773, - -8.136773, -8.135773, -8.134773, -8.133773, -8.132772, -8.131773, - -8.130773, -8.129773, -8.128773, -8.127772, -8.126773, -8.125772, - -8.124773, -8.123773, -8.122773, -8.121773, -8.120772, -8.119773, - -8.118773, -8.117773, -8.116773, -8.115773, -8.114773, -8.113772, - -8.112773, -8.111773, -8.110773, -8.109773, -8.108773, -8.107773, - -8.106772, -8.105773, -8.104773, -8.103773, -8.102773, -8.101772, - -8.100773, -8.099772, -8.098773, -8.097773, -8.096773, -8.095773, - -8.094772, -8.093773, -8.092772, -8.091773, -8.090773, -8.089773, - -8.088773, -8.087772, -8.086773, -8.085773, -8.084773, -8.083773, - -8.082773, -8.081773, -8.080772, -8.079773, -8.078773, -8.077773, - -8.076773, -8.075773, -8.074773, -8.073772, -8.072773, -8.071773, - -8.070773, -8.069773, -8.068772, -8.067773, -8.066772, -8.065773, - -8.064773, -8.063773, -8.062773, -8.061772, -8.060773, -8.059772, - -8.058773, -8.057773, -8.056773, -8.055773, -8.054772, -8.053773, - -8.052773, -8.051773, -8.050773, -8.049773, -8.048773, -8.047772, - -8.046773, -8.045773, -8.044773, -8.043773, -8.042773, -8.041773, - -8.040772, -8.039773, -8.038773, -8.037773, -8.036773, -8.035772, - -8.034773, -8.033772, -8.032773, -8.031773, -8.030773, -8.029773, - -8.028772, -8.027773, -8.026772, -8.025773, -8.024773, -8.023773, - -8.022773, -8.021772, -8.020773, -8.019773, -8.018773, -8.017773, - -8.016773, -8.015773, -8.014772, -8.013773, -8.012773, -8.011773, - -8.010773, -8.009773, -8.008773, -8.007772, -8.006773, -8.005773, - -8.004773, -8.003773, -8.002772, -8.001773, -8.000772, -7.999773, - -7.998773, -7.997773, -7.996773, -7.995773, -7.994773, -7.993773, - -7.992773, -7.991773, -7.990773, -7.989773, -7.988773, -7.987773, - -7.986773, -7.985773, -7.984773, -7.983773, -7.982773, -7.981773, - -7.980773, -7.979773, -7.978773, -7.977773, -7.976773, -7.975773, - -7.974773, -7.973773, -7.972773, -7.971773, -7.970773, -7.969773, - -7.968773, -7.967773, -7.966773, -7.965773, -7.964773, -7.963773, - -7.962773, -7.961773, -7.960773, -7.959773, -7.958773, -7.957773, - -7.956773, -7.955773, -7.954773, -7.953773, -7.952773, -7.951773, - -7.950773, -7.949773, -7.948773, -7.947773, -7.946773, -7.945773, - -7.944773, -7.943773, -7.942773, -7.941773, -7.940773, -7.939773, - -7.938773, -7.937773, -7.936773, -7.935773, -7.934773, -7.933773, - -7.932773, -7.931773, -7.930773, -7.929773, -7.928773, -7.927773, - -7.926773, -7.925773, -7.924773, -7.923773, -7.922773, -7.921773, - -7.920773, -7.919773, -7.918773, -7.917773, -7.916773, -7.915773, - -7.914773, -7.913773, -7.912773, -7.911773, -7.910773, -7.909773, - -7.908773, -7.907773, -7.906773, -7.905773, -7.904773, -7.903773, - -7.902773, -7.901773, -7.900773, -7.899773, -7.898773, -7.897773, - -7.896773, -7.895773, -7.894773, -7.893773, -7.892773, -7.891773, - -7.890773, -7.889773, -7.888773, -7.887773, -7.886773, -7.885773, - -7.884773, -7.883773, -7.882773, -7.881773, -7.880773, -7.879773, - -7.878773, -7.877773, -7.876773, -7.875773, -7.874773, -7.873773, - -7.872773, -7.871773, -7.870773, -7.869773, -7.868773, -7.867773, - -7.866773, -7.865773, -7.864773, -7.863773, -7.862773, -7.861773, - -7.860773, -7.859773, -7.858773, -7.857773, -7.856773, -7.855773, - -7.854773, -7.853773, -7.852773, -7.851773, -7.850773, -7.849773, - -7.848773, -7.847773, -7.846773, -7.845773, -7.844773, -7.843773, - -7.842773, -7.841773, -7.840773, -7.839773, -7.838773, -7.837773, - -7.836773, -7.835773, -7.834773, -7.833773, -7.832773, -7.831773, - -7.830773, -7.829773, -7.828773, -7.827773, -7.826773, -7.825773, - -7.824773, -7.823773, -7.822773, -7.821773, -7.820773, -7.819773, - -7.818773, -7.817773, -7.816773, -7.815773, -7.814773, -7.813773, - -7.812773, -7.811773, -7.810773, -7.809773, -7.808773, -7.807773, - -7.806773, -7.805773, -7.804773, -7.803773, -7.802773, -7.801773, - -7.800773, -7.799773, -7.798773, -7.797773, -7.796773, -7.795773, - -7.794773, -7.793773, -7.792773, -7.791773, -7.790773, -7.789773, - -7.788773, -7.787773, -7.786773, -7.785773, -7.784773, -7.783773, - -7.782773, -7.781773, -7.780773, -7.779773, -7.778773, -7.777773, - -7.776773, -7.775773, -7.774773, -7.773773, -7.772773, -7.771773, - -7.770773, -7.769773, -7.768773, -7.767773, -7.766773, -7.765773, - -7.764773, -7.763773, -7.762773, -7.761773, -7.760773, -7.759773, - -7.758773, -7.757773, -7.756773, -7.755773, -7.754773, -7.753773, - -7.752773, -7.751773, -7.750773, -7.749773, -7.748773, -7.747773, - -7.746773, -7.745773, -7.744773, -7.743773, -7.742773, -7.741773, - -7.740773, -7.739773, -7.738773, -7.737773, -7.736773, -7.735773, - -7.734773, -7.733773, -7.732773, -7.731773, -7.730773, -7.729773, - -7.728773, -7.727773, -7.726773, -7.725773, -7.724773, -7.723773, - -7.722773, -7.721773, -7.720773, -7.719773, -7.718773, -7.717773, - -7.716773, -7.715773, -7.714773, -7.713773, -7.712773, -7.711773, - -7.710773, -7.709773, -7.708773, -7.707773, -7.706773, -7.705773, - -7.704773, -7.703773, -7.702773, -7.701773, -7.700773, -7.699773, - -7.698773, -7.697773, -7.696773, -7.695773, -7.694773, -7.693773, - -7.692773, -7.691773, -7.690773, -7.689773, -7.688773, -7.687773, - -7.686773, -7.685773, -7.684773, -7.683773, -7.682773, -7.681773, - -7.680773, -7.679773, -7.678773, -7.677773, -7.676773, -7.675773, - -7.674773, -7.673773, -7.672773, -7.671773, -7.670773, -7.669773, - -7.668773, -7.667773, -7.666773, -7.665773, -7.664773, -7.663773, - -7.662773, -7.661773, -7.660773, -7.659773, -7.658773, -7.657773, - -7.656773, -7.655773, -7.654773, -7.653773, -7.652773, -7.651773, - -7.650773, -7.649773, -7.648773, -7.647773, -7.646773, -7.645773, - -7.644773, -7.643773, -7.642773, -7.641773, -7.640773, -7.639773, - -7.638773, -7.637773, -7.636773, -7.635773, -7.634773, -7.633773, - -7.632773, -7.631773, -7.630773, -7.629773, -7.628773, -7.627773, - -7.626773, -7.625773, -7.624773, -7.623773, -7.622773, -7.621773, - -7.620773, -7.619773, -7.618773, -7.617773, -7.616773, -7.615773, - -7.614773, -7.613773, -7.612773, -7.611773, -7.610773, -7.609773, - -7.608773, -7.607773, -7.606773, -7.605773, -7.604773, -7.603773, - -7.602773, -7.601773, -7.600773, -7.599773, -7.598773, -7.597773, - -7.596773, -7.595773, -7.594773, -7.593773, -7.592773, -7.591773, - -7.590773, -7.589773, -7.588773, -7.587773, -7.586773, -7.585773, - -7.584773, -7.583773, -7.582773, -7.581773, -7.580773, -7.579773, - -7.578773, -7.577773, -7.576773, -7.575773, -7.574773, -7.573773, - -7.572773, -7.571773, -7.570773, -7.569773, -7.568773, -7.567773, - -7.566773, -7.565773, -7.564773, -7.563773, -7.562773, -7.561773, - -7.560773, -7.559773, -7.558773, -7.557773, -7.556773, -7.555773, - -7.554773, -7.553773, -7.552773, -7.551773, -7.550773, -7.549773, - -7.548773, -7.547773, -7.546773, -7.545773, -7.544773, -7.543773, - -7.542773, -7.541773, -7.540773, -7.539773, -7.538773, -7.537773, - -7.536773, -7.535773, -7.534773, -7.533773, -7.532773, -7.531773, - -7.530773, -7.529773, -7.528773, -7.527773, -7.526773, -7.525773, - -7.524773, -7.523773, -7.522773, -7.521773, -7.520773, -7.519773, - -7.518773, -7.517773, -7.516773, -7.515773, -7.514773, -7.513773, - -7.512773, -7.511773, -7.510773, -7.509773, -7.508773, -7.507773, - -7.506773, -7.505773, -7.504773, -7.503773, -7.502773, -7.501773, - -7.500773, -7.499773, -7.498773, -7.497773, -7.496773, -7.495773, - -7.494773, -7.493773, -7.492773, -7.491773, -7.490773, -7.489773, - -7.488773, -7.487773, -7.486773, -7.485773, -7.484773, -7.483773, - -7.482773, -7.481773, -7.480773, -7.479773, -7.478773, -7.477773, - -7.476773, -7.475773, -7.474773, -7.473773, -7.472773, -7.471773, - -7.470773, -7.469773, -7.468773, -7.467773, -7.466773, -7.465773, - -7.464773, -7.463773, -7.462773, -7.461773, -7.460773, -7.459773, - -7.458773, -7.457773, -7.456773, -7.455773, -7.454773, -7.453773, - -7.452773, -7.451773, -7.450773, -7.449773, -7.448773, -7.447773, - -7.446773, -7.445773, -7.444773, -7.443773, -7.442773, -7.441773, - -7.440773, -7.439773, -7.438773, -7.437773, -7.436773, -7.435773, - -7.434773, -7.433773, -7.432773, -7.431773, -7.430773, -7.429773, - -7.428773, -7.427773, -7.426773, -7.425773, -7.424773, -7.423773, - -7.422773, -7.421773, -7.420773, -7.419773, -7.418773, -7.417773, - -7.416773, -7.415773, -7.414773, -7.413773, -7.412773, -7.411773, - -7.410773, -7.409773, -7.408773, -7.407773, -7.406773, -7.405773, - -7.404773, -7.403773, -7.402773, -7.401773, -7.400773, -7.399773, - -7.398773, -7.397773, -7.396773, -7.395773, -7.394773, -7.393773, - -7.392773, -7.391773, -7.390773, -7.389773, -7.388773, -7.387773, - -7.386773, -7.385773, -7.384773, -7.383773, -7.382773, -7.381773, - -7.380773, -7.379773, -7.378773, -7.377773, -7.376773, -7.375773, - -7.374773, -7.373773, -7.372773, -7.371773, -7.370773, -7.369773, - -7.368773, -7.367773, -7.366773, -7.365773, -7.364773, -7.363773, - -7.362773, -7.361773, -7.360773, -7.359773, -7.358773, -7.357773, - -7.356773, -7.355773, -7.354773, -7.353773, -7.352773, -7.351773, - -7.350773, -7.349773, -7.348773, -7.347773, -7.346773, -7.345773, - -7.344773, -7.343773, -7.342773, -7.341773, -7.340773, -7.339773, - -7.338773, -7.337773, -7.336773, -7.335773, -7.334773, -7.333773, - -7.332773, -7.331773, -7.330773, -7.329773, -7.328773, -7.327773, - -7.326773, -7.325773, -7.324773, -7.323773, -7.322773, -7.321773, - -7.320773, -7.319773, -7.318773, -7.317773, -7.316773, -7.315773, - -7.314773, -7.313773, -7.312773, -7.311773, -7.310773, -7.309773, - -7.308773, -7.307773, -7.306773, -7.305773, -7.304773, -7.303773, - -7.302773, -7.301773, -7.300773, -7.299773, -7.298773, -7.297773, - -7.296773, -7.295773, -7.294773, -7.293773, -7.292773, -7.291773, - -7.290773, -7.289773, -7.288773, -7.287773, -7.286773, -7.285773, - -7.284773, -7.283773, -7.282773, -7.281773, -7.280773, -7.279773, - -7.278773, -7.277773, -7.276773, -7.275773, -7.274773, -7.273773, - -7.272773, -7.271773, -7.270773, -7.269773, -7.268773, -7.267773, - -7.266773, -7.265773, -7.264773, -7.263773, -7.262773, -7.261773, - -7.260773, -7.259773, -7.258773, -7.257773, -7.256773, -7.255773, - -7.254773, -7.253773, -7.252773, -7.251773, -7.250773, -7.249773, - -7.248773, -7.247773, -7.246773, -7.245773, -7.244773, -7.243773, - -7.242773, -7.241773, -7.240773, -7.239773, -7.238773, -7.237773, - -7.236773, -7.235773, -7.234773, -7.233773, -7.232773, -7.231773, - -7.230773, -7.229773, -7.228773, -7.227773, -7.226773, -7.225773, - -7.224773, -7.223773, -7.222773, -7.221773, -7.220773, -7.219773, - -7.218773, -7.217773, -7.216773, -7.215773, -7.214773, -7.213773, - -7.212773, -7.211773, -7.210773, -7.209773, -7.208773, -7.207773, - -7.206773, -7.205773, -7.204773, -7.203773, -7.202773, -7.201773, - -7.200773, -7.199773, -7.198773, -7.197773, -7.196773, -7.195773, - -7.194773, -7.193773, -7.192773, -7.191773, -7.190773, -7.189773, - -7.188773, -7.187773, -7.186773, -7.185773, -7.184773, -7.183773, - -7.182773, -7.181773, -7.180773, -7.179773, -7.178773, -7.177773, - -7.176773, -7.175773, -7.174773, -7.173773, -7.172773, -7.171773, - -7.170773, -7.169773, -7.168773, -7.167773, -7.166773, -7.165773, - -7.164773, -7.163773, -7.162773, -7.161773, -7.160773, -7.159773, - -7.158773, -7.157773, -7.156773, -7.155773, -7.154773, -7.153773, - -7.152773, -7.151773, -7.150773, -7.149773, -7.148773, -7.147773, - -7.146773, -7.145773, -7.144773, -7.143773, -7.142773, -7.141773, - -7.140773, -7.139773, -7.138773, -7.137773, -7.136773, -7.135773, - -7.134773, -7.133773, -7.132773, -7.131773, -7.130773, -7.129773, - -7.128773, -7.127773, -7.126773, -7.125773, -7.124773, -7.123773, - -7.122773, -7.121773, -7.120773, -7.119773, -7.118773, -7.117773, - -7.116773, -7.115773, -7.114773, -7.113773, -7.112773, -7.111773, - -7.110773, -7.109773, -7.108773, -7.107773, -7.106773, -7.105773, - -7.104773, -7.103773, -7.102773, -7.101773, -7.100773, -7.099773, - -7.098773, -7.097773, -7.096773, -7.095773, -7.094773, -7.093773, - -7.092773, -7.091773, -7.090773, -7.089773, -7.088773, -7.087773, - -7.086773, -7.085773, -7.084773, -7.083773, -7.082773, -7.081773, - -7.080773, -7.079773, -7.078773, -7.077773, -7.076773, -7.075773, - -7.074773, -7.073773, -7.072773, -7.071773, -7.070773, -7.069773, - -7.068773, -7.067773, -7.066773, -7.065773, -7.064773, -7.063773, - -7.062773, -7.061773, -7.060773, -7.059773, -7.058773, -7.057773, - -7.056773, -7.055773, -7.054773, -7.053773, -7.052773, -7.051773, - -7.050773, -7.049773, -7.048773, -7.047773, -7.046773, -7.045773, - -7.044773, -7.043773, -7.042773, -7.041773, -7.040773, -7.039773, - -7.038773, -7.037773, -7.036773, -7.035773, -7.034773, -7.033773, - -7.032773, -7.031773, -7.030773, -7.029773, -7.028773, -7.027773, - -7.026773, -7.025773, -7.024773, -7.023773, -7.022773, -7.021773, - -7.020773, -7.019773, -7.018773, -7.017773, -7.016773, -7.015773, - -7.014773, -7.013773, -7.012773, -7.011773, -7.010773, -7.009773, - -7.008773, -7.007773, -7.006773, -7.005773, -7.004773, -7.003773, - -7.002773, -7.001773, -7.000773, -6.999773, -6.998773, -6.997773, - -6.996773, -6.995773, -6.994773, -6.993773, -6.992773, -6.991773, - -6.990773, -6.989773, -6.988773, -6.987773, -6.986773, -6.985773, - -6.984773, -6.983773, -6.982773, -6.981773, -6.980773, -6.979773, - -6.978773, -6.977773, -6.976773, -6.975773, -6.974773, -6.973773, - -6.972773, -6.971773, -6.970773, -6.969773, -6.968773, -6.967773, - -6.966773, -6.965773, -6.964773, -6.963773, -6.962773, -6.961773, - -6.960773, -6.959773, -6.958773, -6.957773, -6.956773, -6.955773, - -6.954773, -6.953773, -6.952773, -6.951773, -6.950773, -6.949773, - -6.948773, -6.947773, -6.946773, -6.945773, -6.944773, -6.943773, - -6.942773, -6.941773, -6.940773, -6.939773, -6.938773, -6.937773, - -6.936773, -6.935773, -6.934773, -6.933773, -6.932773, -6.931773, - -6.930773, -6.929773, -6.928773, -6.927773, -6.926773, -6.925773, - -6.924773, -6.923773, -6.922773, -6.921773, -6.920773, -6.919773, - -6.918773, -6.917773, -6.916773, -6.915773, -6.914773, -6.913773, - -6.912773, -6.911773, -6.910773, -6.909773, -6.908773, -6.907773, - -6.906773, -6.905773, -6.904773, -6.903773, -6.902773, -6.901773, - -6.900773, -6.899773, -6.898773, -6.897773, -6.896773, -6.895773, - -6.894773, -6.893773, -6.892773, -6.891773, -6.890773, -6.889773, - -6.888773, -6.887773, -6.886773, -6.885773, -6.884773, -6.883773, - -6.882773, -6.881773, -6.880773, -6.879773, -6.878773, -6.877773, - -6.876773, -6.875773, -6.874773, -6.873773, -6.872773, -6.871773, - -6.870773, -6.869773, -6.868773, -6.867773, -6.866773, -6.865773, - -6.864773, -6.863773, -6.862773, -6.861773, -6.860773, -6.859773, - -6.858773, -6.857773, -6.856773, -6.855773, -6.854773, -6.853773, - -6.852773, -6.851773, -6.850773, -6.849773, -6.848773, -6.847773, - -6.846773, -6.845773, -6.844773, -6.843773, -6.842773, -6.841773, - -6.840773, -6.839773, -6.838773, -6.837773, -6.836773, -6.835773, - -6.834773, -6.833773, -6.832773, -6.831773, -6.830773, -6.829773, - -6.828773, -6.827773, -6.826773, -6.825773, -6.824773, -6.823773, - -6.822773, -6.821773, -6.820773, -6.819773, -6.818773, -6.817773, - -6.816773, -6.815773, -6.814773, -6.813773, -6.812773, -6.811773, - -6.810773, -6.809773, -6.808773, -6.807773, -6.806773, -6.805773, - -6.804773, -6.803773, -6.802773, -6.801773, -6.800773, -6.799773, - -6.798773, -6.797773, -6.796773, -6.795773, -6.794773, -6.793773, - -6.792773, -6.791773, -6.790773, -6.789773, -6.788773, -6.787773, - -6.786773, -6.785773, -6.784773, -6.783773, -6.782773, -6.781773, - -6.780773, -6.779773, -6.778773, -6.777773, -6.776773, -6.775773, - -6.774773, -6.773773, -6.772773, -6.771773, -6.770773, -6.769773, - -6.768773, -6.767773, -6.766773, -6.765773, -6.764773, -6.763773, - -6.762773, -6.761773, -6.760773, -6.759773, -6.758773, -6.757773, - -6.756773, -6.755773, -6.754773, -6.753773, -6.752773, -6.751773, - -6.750773, -6.749773, -6.748773, -6.747773, -6.746773, -6.745773, - -6.744773, -6.743773, -6.742773, -6.741773, -6.740773, -6.739773, - -6.738773, -6.737773, -6.736773, -6.735773, -6.734773, -6.733773, - -6.732773, -6.731773, -6.730773, -6.729773, -6.728773, -6.727773, - -6.726773, -6.725773, -6.724773, -6.723773, -6.722773, -6.721773, - -6.720773, -6.719773, -6.718773, -6.717773, -6.716773, -6.715773, - -6.714773, -6.713773, -6.712773, -6.711773, -6.710773, -6.709773, - -6.708773, -6.707773, -6.706773, -6.705773, -6.704773, -6.703773, - -6.702773, -6.701773, -6.700773, -6.699773, -6.698773, -6.697773, - -6.696773, -6.695773, -6.694773, -6.693773, -6.692773, -6.691773, - -6.690773, -6.689773, -6.688773, -6.687773, -6.686773, -6.685773, - -6.684773, -6.683773, -6.682773, -6.681773, -6.680773, -6.679773, - -6.678773, -6.677773, -6.676773, -6.675773, -6.674773, -6.673773, - -6.672773, -6.671773, -6.670773, -6.669773, -6.668773, -6.667773, - -6.666773, -6.665773, -6.664773, -6.663773, -6.662773, -6.661773, - -6.660773, -6.659773, -6.658773, -6.657773, -6.656773, -6.655773}, - sd::DataType::DOUBLE); + NDArray expOutput('c', {1500}, {-8.154773, -8.153772, -8.152773, -8.151772, -8.150773, -8.149773, -8.148773, -8.147773, -8.146772, -8.145773, -8.144773, -8.143773, -8.142773, -8.141773, -8.140773, -8.139772, -8.138773, -8.137773, -8.136773, -8.135773, -8.134773, -8.133773, -8.132772, -8.131773, -8.130773, -8.129773, -8.128773, -8.127772, -8.126773, -8.125772, -8.124773, -8.123773, -8.122773, -8.121773, -8.120772, -8.119773, -8.118773, -8.117773, -8.116773, -8.115773, -8.114773, -8.113772, -8.112773, -8.111773, -8.110773, -8.109773, -8.108773, -8.107773, -8.106772, -8.105773, -8.104773, -8.103773, -8.102773, -8.101772, -8.100773, -8.099772, -8.098773, -8.097773, -8.096773, -8.095773, -8.094772, -8.093773, -8.092772, -8.091773, -8.090773, -8.089773, -8.088773, -8.087772, -8.086773, -8.085773, -8.084773, -8.083773, -8.082773, -8.081773, -8.080772, -8.079773, -8.078773, -8.077773, -8.076773, -8.075773, -8.074773, -8.073772, -8.072773, -8.071773, -8.070773, -8.069773, -8.068772, -8.067773, -8.066772, -8.065773, -8.064773, -8.063773, -8.062773, -8.061772, -8.060773, -8.059772, -8.058773, -8.057773, -8.056773, -8.055773, -8.054772, + -8.053773, -8.052773, -8.051773, -8.050773, -8.049773, -8.048773, -8.047772, -8.046773, -8.045773, -8.044773, -8.043773, -8.042773, -8.041773, -8.040772, -8.039773, -8.038773, -8.037773, -8.036773, -8.035772, -8.034773, -8.033772, -8.032773, -8.031773, -8.030773, -8.029773, -8.028772, -8.027773, -8.026772, -8.025773, -8.024773, -8.023773, -8.022773, -8.021772, -8.020773, -8.019773, -8.018773, -8.017773, -8.016773, -8.015773, -8.014772, -8.013773, -8.012773, -8.011773, -8.010773, -8.009773, -8.008773, -8.007772, -8.006773, -8.005773, -8.004773, -8.003773, -8.002772, -8.001773, -8.000772, -7.999773, -7.998773, -7.997773, -7.996773, -7.995773, -7.994773, -7.993773, -7.992773, -7.991773, -7.990773, -7.989773, -7.988773, -7.987773, -7.986773, -7.985773, -7.984773, -7.983773, -7.982773, -7.981773, -7.980773, -7.979773, -7.978773, -7.977773, -7.976773, -7.975773, -7.974773, -7.973773, -7.972773, -7.971773, -7.970773, -7.969773, -7.968773, -7.967773, -7.966773, -7.965773, -7.964773, -7.963773, -7.962773, -7.961773, -7.960773, -7.959773, -7.958773, -7.957773, -7.956773, -7.955773, -7.954773, -7.953773, -7.952773, + -7.951773, -7.950773, -7.949773, -7.948773, -7.947773, -7.946773, -7.945773, -7.944773, -7.943773, -7.942773, -7.941773, -7.940773, -7.939773, -7.938773, -7.937773, -7.936773, -7.935773, -7.934773, -7.933773, -7.932773, -7.931773, -7.930773, -7.929773, -7.928773, -7.927773, -7.926773, -7.925773, -7.924773, -7.923773, -7.922773, -7.921773, -7.920773, -7.919773, -7.918773, -7.917773, -7.916773, -7.915773, -7.914773, -7.913773, -7.912773, -7.911773, -7.910773, -7.909773, -7.908773, -7.907773, -7.906773, -7.905773, -7.904773, -7.903773, -7.902773, -7.901773, -7.900773, -7.899773, -7.898773, -7.897773, -7.896773, -7.895773, -7.894773, -7.893773, -7.892773, -7.891773, -7.890773, -7.889773, -7.888773, -7.887773, -7.886773, -7.885773, -7.884773, -7.883773, -7.882773, -7.881773, -7.880773, -7.879773, -7.878773, -7.877773, -7.876773, -7.875773, -7.874773, -7.873773, -7.872773, -7.871773, -7.870773, -7.869773, -7.868773, -7.867773, -7.866773, -7.865773, -7.864773, -7.863773, -7.862773, -7.861773, -7.860773, -7.859773, -7.858773, -7.857773, -7.856773, -7.855773, -7.854773, -7.853773, -7.852773, -7.851773, -7.850773, -7.849773, + -7.848773, -7.847773, -7.846773, -7.845773, -7.844773, -7.843773, -7.842773, -7.841773, -7.840773, -7.839773, -7.838773, -7.837773, -7.836773, -7.835773, -7.834773, -7.833773, -7.832773, -7.831773, -7.830773, -7.829773, -7.828773, -7.827773, -7.826773, -7.825773, -7.824773, -7.823773, -7.822773, -7.821773, -7.820773, -7.819773, -7.818773, -7.817773, -7.816773, -7.815773, -7.814773, -7.813773, -7.812773, -7.811773, -7.810773, -7.809773, -7.808773, -7.807773, -7.806773, -7.805773, -7.804773, -7.803773, -7.802773, -7.801773, -7.800773, -7.799773, -7.798773, -7.797773, -7.796773, -7.795773, -7.794773, -7.793773, -7.792773, -7.791773, -7.790773, -7.789773, -7.788773, -7.787773, -7.786773, -7.785773, -7.784773, -7.783773, -7.782773, -7.781773, -7.780773, -7.779773, -7.778773, -7.777773, -7.776773, -7.775773, -7.774773, -7.773773, -7.772773, -7.771773, -7.770773, -7.769773, -7.768773, -7.767773, -7.766773, -7.765773, -7.764773, -7.763773, -7.762773, -7.761773, -7.760773, -7.759773, -7.758773, -7.757773, -7.756773, -7.755773, -7.754773, -7.753773, -7.752773, -7.751773, -7.750773, -7.749773, -7.748773, -7.747773, -7.746773, + -7.745773, -7.744773, -7.743773, -7.742773, -7.741773, -7.740773, -7.739773, -7.738773, -7.737773, -7.736773, -7.735773, -7.734773, -7.733773, -7.732773, -7.731773, -7.730773, -7.729773, -7.728773, -7.727773, -7.726773, -7.725773, -7.724773, -7.723773, -7.722773, -7.721773, -7.720773, -7.719773, -7.718773, -7.717773, -7.716773, -7.715773, -7.714773, -7.713773, -7.712773, -7.711773, -7.710773, -7.709773, -7.708773, -7.707773, -7.706773, -7.705773, -7.704773, -7.703773, -7.702773, -7.701773, -7.700773, -7.699773, -7.698773, -7.697773, -7.696773, -7.695773, -7.694773, -7.693773, -7.692773, -7.691773, -7.690773, -7.689773, -7.688773, -7.687773, -7.686773, -7.685773, -7.684773, -7.683773, -7.682773, -7.681773, -7.680773, -7.679773, -7.678773, -7.677773, -7.676773, -7.675773, -7.674773, -7.673773, -7.672773, -7.671773, -7.670773, -7.669773, -7.668773, -7.667773, -7.666773, -7.665773, -7.664773, -7.663773, -7.662773, -7.661773, -7.660773, -7.659773, -7.658773, -7.657773, -7.656773, -7.655773, -7.654773, -7.653773, -7.652773, -7.651773, -7.650773, -7.649773, -7.648773, -7.647773, -7.646773, -7.645773, -7.644773, -7.643773, + -7.642773, -7.641773, -7.640773, -7.639773, -7.638773, -7.637773, -7.636773, -7.635773, -7.634773, -7.633773, -7.632773, -7.631773, -7.630773, -7.629773, -7.628773, -7.627773, -7.626773, -7.625773, -7.624773, -7.623773, -7.622773, -7.621773, -7.620773, -7.619773, -7.618773, -7.617773, -7.616773, -7.615773, -7.614773, -7.613773, -7.612773, -7.611773, -7.610773, -7.609773, -7.608773, -7.607773, -7.606773, -7.605773, -7.604773, -7.603773, -7.602773, -7.601773, -7.600773, -7.599773, -7.598773, -7.597773, -7.596773, -7.595773, -7.594773, -7.593773, -7.592773, -7.591773, -7.590773, -7.589773, -7.588773, -7.587773, -7.586773, -7.585773, -7.584773, -7.583773, -7.582773, -7.581773, -7.580773, -7.579773, -7.578773, -7.577773, -7.576773, -7.575773, -7.574773, -7.573773, -7.572773, -7.571773, -7.570773, -7.569773, -7.568773, -7.567773, -7.566773, -7.565773, -7.564773, -7.563773, -7.562773, -7.561773, -7.560773, -7.559773, -7.558773, -7.557773, -7.556773, -7.555773, -7.554773, -7.553773, -7.552773, -7.551773, -7.550773, -7.549773, -7.548773, -7.547773, -7.546773, -7.545773, -7.544773, -7.543773, -7.542773, -7.541773, -7.540773, + -7.539773, -7.538773, -7.537773, -7.536773, -7.535773, -7.534773, -7.533773, -7.532773, -7.531773, -7.530773, -7.529773, -7.528773, -7.527773, -7.526773, -7.525773, -7.524773, -7.523773, -7.522773, -7.521773, -7.520773, -7.519773, -7.518773, -7.517773, -7.516773, -7.515773, -7.514773, -7.513773, -7.512773, -7.511773, -7.510773, -7.509773, -7.508773, -7.507773, -7.506773, -7.505773, -7.504773, -7.503773, -7.502773, -7.501773, -7.500773, -7.499773, -7.498773, -7.497773, -7.496773, -7.495773, -7.494773, -7.493773, -7.492773, -7.491773, -7.490773, -7.489773, -7.488773, -7.487773, -7.486773, -7.485773, -7.484773, -7.483773, -7.482773, -7.481773, -7.480773, -7.479773, -7.478773, -7.477773, -7.476773, -7.475773, -7.474773, -7.473773, -7.472773, -7.471773, -7.470773, -7.469773, -7.468773, -7.467773, -7.466773, -7.465773, -7.464773, -7.463773, -7.462773, -7.461773, -7.460773, -7.459773, -7.458773, -7.457773, -7.456773, -7.455773, -7.454773, -7.453773, -7.452773, -7.451773, -7.450773, -7.449773, -7.448773, -7.447773, -7.446773, -7.445773, -7.444773, -7.443773, -7.442773, -7.441773, -7.440773, -7.439773, -7.438773, -7.437773, + -7.436773, -7.435773, -7.434773, -7.433773, -7.432773, -7.431773, -7.430773, -7.429773, -7.428773, -7.427773, -7.426773, -7.425773, -7.424773, -7.423773, -7.422773, -7.421773, -7.420773, -7.419773, -7.418773, -7.417773, -7.416773, -7.415773, -7.414773, -7.413773, -7.412773, -7.411773, -7.410773, -7.409773, -7.408773, -7.407773, -7.406773, -7.405773, -7.404773, -7.403773, -7.402773, -7.401773, -7.400773, -7.399773, -7.398773, -7.397773, -7.396773, -7.395773, -7.394773, -7.393773, -7.392773, -7.391773, -7.390773, -7.389773, -7.388773, -7.387773, -7.386773, -7.385773, -7.384773, -7.383773, -7.382773, -7.381773, -7.380773, -7.379773, -7.378773, -7.377773, -7.376773, -7.375773, -7.374773, -7.373773, -7.372773, -7.371773, -7.370773, -7.369773, -7.368773, -7.367773, -7.366773, -7.365773, -7.364773, -7.363773, -7.362773, -7.361773, -7.360773, -7.359773, -7.358773, -7.357773, -7.356773, -7.355773, -7.354773, -7.353773, -7.352773, -7.351773, -7.350773, -7.349773, -7.348773, -7.347773, -7.346773, -7.345773, -7.344773, -7.343773, -7.342773, -7.341773, -7.340773, -7.339773, -7.338773, -7.337773, -7.336773, -7.335773, -7.334773, + -7.333773, -7.332773, -7.331773, -7.330773, -7.329773, -7.328773, -7.327773, -7.326773, -7.325773, -7.324773, -7.323773, -7.322773, -7.321773, -7.320773, -7.319773, -7.318773, -7.317773, -7.316773, -7.315773, -7.314773, -7.313773, -7.312773, -7.311773, -7.310773, -7.309773, -7.308773, -7.307773, -7.306773, -7.305773, -7.304773, -7.303773, -7.302773, -7.301773, -7.300773, -7.299773, -7.298773, -7.297773, -7.296773, -7.295773, -7.294773, -7.293773, -7.292773, -7.291773, -7.290773, -7.289773, -7.288773, -7.287773, -7.286773, -7.285773, -7.284773, -7.283773, -7.282773, -7.281773, -7.280773, -7.279773, -7.278773, -7.277773, -7.276773, -7.275773, -7.274773, -7.273773, -7.272773, -7.271773, -7.270773, -7.269773, -7.268773, -7.267773, -7.266773, -7.265773, -7.264773, -7.263773, -7.262773, -7.261773, -7.260773, -7.259773, -7.258773, -7.257773, -7.256773, -7.255773, -7.254773, -7.253773, -7.252773, -7.251773, -7.250773, -7.249773, -7.248773, -7.247773, -7.246773, -7.245773, -7.244773, -7.243773, -7.242773, -7.241773, -7.240773, -7.239773, -7.238773, -7.237773, -7.236773, -7.235773, -7.234773, -7.233773, -7.232773, -7.231773, + -7.230773, -7.229773, -7.228773, -7.227773, -7.226773, -7.225773, -7.224773, -7.223773, -7.222773, -7.221773, -7.220773, -7.219773, -7.218773, -7.217773, -7.216773, -7.215773, -7.214773, -7.213773, -7.212773, -7.211773, -7.210773, -7.209773, -7.208773, -7.207773, -7.206773, -7.205773, -7.204773, -7.203773, -7.202773, -7.201773, -7.200773, -7.199773, -7.198773, -7.197773, -7.196773, -7.195773, -7.194773, -7.193773, -7.192773, -7.191773, -7.190773, -7.189773, -7.188773, -7.187773, -7.186773, -7.185773, -7.184773, -7.183773, -7.182773, -7.181773, -7.180773, -7.179773, -7.178773, -7.177773, -7.176773, -7.175773, -7.174773, -7.173773, -7.172773, -7.171773, -7.170773, -7.169773, -7.168773, -7.167773, -7.166773, -7.165773, -7.164773, -7.163773, -7.162773, -7.161773, -7.160773, -7.159773, -7.158773, -7.157773, -7.156773, -7.155773, -7.154773, -7.153773, -7.152773, -7.151773, -7.150773, -7.149773, -7.148773, -7.147773, -7.146773, -7.145773, -7.144773, -7.143773, -7.142773, -7.141773, -7.140773, -7.139773, -7.138773, -7.137773, -7.136773, -7.135773, -7.134773, -7.133773, -7.132773, -7.131773, -7.130773, -7.129773, -7.128773, + -7.127773, -7.126773, -7.125773, -7.124773, -7.123773, -7.122773, -7.121773, -7.120773, -7.119773, -7.118773, -7.117773, -7.116773, -7.115773, -7.114773, -7.113773, -7.112773, -7.111773, -7.110773, -7.109773, -7.108773, -7.107773, -7.106773, -7.105773, -7.104773, -7.103773, -7.102773, -7.101773, -7.100773, -7.099773, -7.098773, -7.097773, -7.096773, -7.095773, -7.094773, -7.093773, -7.092773, -7.091773, -7.090773, -7.089773, -7.088773, -7.087773, -7.086773, -7.085773, -7.084773, -7.083773, -7.082773, -7.081773, -7.080773, -7.079773, -7.078773, -7.077773, -7.076773, -7.075773, -7.074773, -7.073773, -7.072773, -7.071773, -7.070773, -7.069773, -7.068773, -7.067773, -7.066773, -7.065773, -7.064773, -7.063773, -7.062773, -7.061773, -7.060773, -7.059773, -7.058773, -7.057773, -7.056773, -7.055773, -7.054773, -7.053773, -7.052773, -7.051773, -7.050773, -7.049773, -7.048773, -7.047773, -7.046773, -7.045773, -7.044773, -7.043773, -7.042773, -7.041773, -7.040773, -7.039773, -7.038773, -7.037773, -7.036773, -7.035773, -7.034773, -7.033773, -7.032773, -7.031773, -7.030773, -7.029773, -7.028773, -7.027773, -7.026773, -7.025773, + -7.024773, -7.023773, -7.022773, -7.021773, -7.020773, -7.019773, -7.018773, -7.017773, -7.016773, -7.015773, -7.014773, -7.013773, -7.012773, -7.011773, -7.010773, -7.009773, -7.008773, -7.007773, -7.006773, -7.005773, -7.004773, -7.003773, -7.002773, -7.001773, -7.000773, -6.999773, -6.998773, -6.997773, -6.996773, -6.995773, -6.994773, -6.993773, -6.992773, -6.991773, -6.990773, -6.989773, -6.988773, -6.987773, -6.986773, -6.985773, -6.984773, -6.983773, -6.982773, -6.981773, -6.980773, -6.979773, -6.978773, -6.977773, -6.976773, -6.975773, -6.974773, -6.973773, -6.972773, -6.971773, -6.970773, -6.969773, -6.968773, -6.967773, -6.966773, -6.965773, -6.964773, -6.963773, -6.962773, -6.961773, -6.960773, -6.959773, -6.958773, -6.957773, -6.956773, -6.955773, -6.954773, -6.953773, -6.952773, -6.951773, -6.950773, -6.949773, -6.948773, -6.947773, -6.946773, -6.945773, -6.944773, -6.943773, -6.942773, -6.941773, -6.940773, -6.939773, -6.938773, -6.937773, -6.936773, -6.935773, -6.934773, -6.933773, -6.932773, -6.931773, -6.930773, -6.929773, -6.928773, -6.927773, -6.926773, -6.925773, -6.924773, -6.923773, -6.922773, + -6.921773, -6.920773, -6.919773, -6.918773, -6.917773, -6.916773, -6.915773, -6.914773, -6.913773, -6.912773, -6.911773, -6.910773, -6.909773, -6.908773, -6.907773, -6.906773, -6.905773, -6.904773, -6.903773, -6.902773, -6.901773, -6.900773, -6.899773, -6.898773, -6.897773, -6.896773, -6.895773, -6.894773, -6.893773, -6.892773, -6.891773, -6.890773, -6.889773, -6.888773, -6.887773, -6.886773, -6.885773, -6.884773, -6.883773, -6.882773, -6.881773, -6.880773, -6.879773, -6.878773, -6.877773, -6.876773, -6.875773, -6.874773, -6.873773, -6.872773, -6.871773, -6.870773, -6.869773, -6.868773, -6.867773, -6.866773, -6.865773, -6.864773, -6.863773, -6.862773, -6.861773, -6.860773, -6.859773, -6.858773, -6.857773, -6.856773, -6.855773, -6.854773, -6.853773, -6.852773, -6.851773, -6.850773, -6.849773, -6.848773, -6.847773, -6.846773, -6.845773, -6.844773, -6.843773, -6.842773, -6.841773, -6.840773, -6.839773, -6.838773, -6.837773, -6.836773, -6.835773, -6.834773, -6.833773, -6.832773, -6.831773, -6.830773, -6.829773, -6.828773, -6.827773, -6.826773, -6.825773, -6.824773, -6.823773, -6.822773, -6.821773, -6.820773, -6.819773, + -6.818773, -6.817773, -6.816773, -6.815773, -6.814773, -6.813773, -6.812773, -6.811773, -6.810773, -6.809773, -6.808773, -6.807773, -6.806773, -6.805773, -6.804773, -6.803773, -6.802773, -6.801773, -6.800773, -6.799773, -6.798773, -6.797773, -6.796773, -6.795773, -6.794773, -6.793773, -6.792773, -6.791773, -6.790773, -6.789773, -6.788773, -6.787773, -6.786773, -6.785773, -6.784773, -6.783773, -6.782773, -6.781773, -6.780773, -6.779773, -6.778773, -6.777773, -6.776773, -6.775773, -6.774773, -6.773773, -6.772773, -6.771773, -6.770773, -6.769773, -6.768773, -6.767773, -6.766773, -6.765773, -6.764773, -6.763773, -6.762773, -6.761773, -6.760773, -6.759773, -6.758773, -6.757773, -6.756773, -6.755773, -6.754773, -6.753773, -6.752773, -6.751773, -6.750773, -6.749773, -6.748773, -6.747773, -6.746773, -6.745773, -6.744773, -6.743773, -6.742773, -6.741773, -6.740773, -6.739773, -6.738773, -6.737773, -6.736773, -6.735773, -6.734773, -6.733773, -6.732773, -6.731773, -6.730773, -6.729773, -6.728773, -6.727773, -6.726773, -6.725773, -6.724773, -6.723773, -6.722773, -6.721773, -6.720773, -6.719773, -6.718773, -6.717773, -6.716773, -6.715773, + -6.714773, -6.713773, -6.712773, -6.711773, -6.710773, -6.709773, -6.708773, -6.707773, -6.706773, -6.705773, -6.704773, -6.703773, -6.702773, -6.701773, -6.700773, -6.699773, -6.698773, -6.697773, -6.696773, -6.695773, -6.694773, -6.693773, -6.692773, -6.691773, -6.690773, -6.689773, -6.688773, -6.687773, -6.686773, -6.685773, -6.684773, -6.683773, -6.682773, -6.681773, -6.680773, -6.679773, -6.678773, -6.677773, -6.676773, -6.675773, -6.674773, -6.673773, -6.672773, -6.671773, -6.670773, -6.669773, -6.668773, -6.667773, -6.666773, -6.665773, -6.664773, -6.663773, -6.662773, -6.661773, -6.660773, -6.659773, -6.658773, -6.657773, -6.656773, -6.655773}, sd::DataType::DOUBLE); input.linspace(0.01, 0.001); - ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, - 0); + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); ASSERT_TRUE(output.equalsTo(&expOutput)); } + ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_1) { + const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('f', {M, N}, - {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, - sd::DataType::DOUBLE); - NDArray temp('f', {M, N, 5}, - {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, - 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, - 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, - 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, - sd::DataType::DOUBLE); - NDArray x = temp(6, {0, 2}); + NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(6, {0,2}); NDArray y('f', {M}, sd::DataType::DOUBLE); NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::DOUBLE); @@ -3159,20 +2041,14 @@ TEST_F(HelpersTests1, mmulMxV_1) { ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_2) { + const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('f', {N, M}, - {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, - sd::DataType::DOUBLE); - a.permutei({1, 0}); - NDArray temp('f', {M, N, 5}, - {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, - 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, - 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, - 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, - sd::DataType::DOUBLE); - NDArray x = temp(6, {0, 2}); + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(6, {0,2}); NDArray y('f', {M}, sd::DataType::DOUBLE); NDArray exp('f', {M}, {5.1, 3.3, 1.5}, sd::DataType::DOUBLE); @@ -3183,20 +2059,14 @@ TEST_F(HelpersTests1, mmulMxV_2) { ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_3) { + const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('f', {N, M}, - {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, - sd::DataType::DOUBLE); - a.permutei({1, 0}); - NDArray temp('f', {N, M, 5}, - {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, - 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, - 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, - 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, - sd::DataType::DOUBLE); - NDArray x = temp(4, {1, 2}); + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {N,M,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(4, {1,2}); NDArray y('f', {M}, sd::DataType::DOUBLE); NDArray exp('f', {M}, {6.2, 4.5, 1.7}, sd::DataType::DOUBLE); @@ -3207,20 +2077,14 @@ TEST_F(HelpersTests1, mmulMxV_3) { ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_4) { + const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('f', {N, M}, - {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, - sd::DataType::DOUBLE); - a.permutei({1, 0}); - NDArray temp('f', {5, M, N}, - {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, - 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, - 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, - 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, - sd::DataType::DOUBLE); - NDArray x = temp(3, {0, 1}); + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(3, {0,1}); NDArray y('f', {M}, sd::DataType::DOUBLE); NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::DOUBLE); @@ -3231,20 +2095,14 @@ TEST_F(HelpersTests1, mmulMxV_4) { ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_5) { + const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('c', {N, M}, - {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, - sd::DataType::DOUBLE); - a.permutei({1, 0}); - NDArray temp('f', {5, M, N}, - {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, - 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, - 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, - 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, - sd::DataType::DOUBLE); - NDArray x = temp(2, {0, 1}); + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(2, {0,1}); NDArray y('f', {M}, sd::DataType::DOUBLE); NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); @@ -3255,20 +2113,14 @@ TEST_F(HelpersTests1, mmulMxV_5) { ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_6) { + const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('c', {N, M}, - {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, - sd::DataType::DOUBLE); - a.permutei({1, 0}); - NDArray temp('c', {5, N, M}, - {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, - 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, - 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, - 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, - sd::DataType::DOUBLE); - NDArray x = temp(13, {0, 2}); + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(13, {0,2}); NDArray y('f', {M}, sd::DataType::DOUBLE); NDArray exp('f', {M}, {-12.1, -10.9, -9.7}, sd::DataType::DOUBLE); @@ -3279,20 +2131,14 @@ TEST_F(HelpersTests1, mmulMxV_6) { ////////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, mmulMxV_7) { + const Nd4jLong M = 3; const Nd4jLong N = 4; - NDArray a('c', {N, M}, - {1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3, 0.2, 0.1, 0}, - sd::DataType::DOUBLE); - a.permutei({1, 0}); - NDArray temp('c', {5, N, M}, - {16, 2, -6, 7, 2, -2, 4, -7, 6, 4, 4, 6, -3, 1, 3, - 9, 1, 4, 9, 10, -10, -3, -8, 7, -7, -7, 6, 9, 7, -6, - 8, 7, -3, -3, 4, -2, 5, -3, -3, 4, 6, -5, -1, 7, -5, - 4, -10, -1, 8, 0, -7, 4, -10, -7, -8, -9, 2, 9, 7, 9}, - sd::DataType::DOUBLE); - NDArray x = temp(10, {0, 2}); + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(10, {0,2}); NDArray y('c', {M}, sd::DataType::DOUBLE); NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); @@ -3303,13 +2149,10 @@ TEST_F(HelpersTests1, mmulMxV_7) { ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softmaxDerivative_1) { - NDArray input('c', {3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, 5.}, - sd::DataType::DOUBLE); - NDArray expOutput('c', {3, 3}, - {0.04508, 0.04514, 0.0008, 0.0472, 0.00087, 0.10492, - 0.00235, 0.04592, 0.10553}, - sd::DataType::DOUBLE); - NDArray output('c', {3, 3}, sd::DataType::DOUBLE); + + NDArray input('c', {3,3}, {-1, 1, -2, 2, -3, 3, -4, 4, 5.}, sd::DataType::DOUBLE); + NDArray expOutput('c', {3,3}, {0.04508, 0.04514, 0.0008 , 0.0472 , 0.00087, 0.10492, 0.00235, 0.04592, 0.10553}, sd::DataType::DOUBLE); + NDArray output('c', {3,3}, sd::DataType::DOUBLE); // input.applyTransform(sd::transform::SoftMaxDerivative, &output); @@ -3320,20 +2163,12 @@ TEST_F(HelpersTests1, softmaxDerivative_1) { ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softmaxDerivative_2) { - NDArray input('c', {3, 3, 3}, - {-1, 1, -2, 2, -3, 3, -4, 4, -5, 5, -6, 6, -7, 7, - -8, 8, -9, 9, -10, 10, -11, 11, -12, 12, -13, 13, 14.}, - sd::DataType::DOUBLE); - NDArray expOutput( - 'c', {3, 3, 3}, - {4.50755e-02, 4.51394e-02, 6.64586e-03, 4.72027e-02, 8.67128e-04, - 6.97440e-03, 2.35008e-03, 4.59243e-02, 3.32995e-04, 4.51766e-02, - 2.26032e-06, 4.51767e-02, 2.91394e-07, 2.37285e-06, 3.94360e-08, - 4.51769e-02, 1.12535e-07, 4.51767e-02, 7.58256e-10, 4.51767e-02, - 1.22325e-11, 7.96007e-10, 1.32293e-11, 1.04994e-01, 3.77513e-11, - 4.51767e-02, 1.04994e-01}, - sd::DataType::DOUBLE); - NDArray output('c', {3, 3, 3}, sd::DataType::DOUBLE); + + NDArray input('c', {3,3,3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14.}, sd::DataType::DOUBLE); + NDArray expOutput('c', {3,3,3}, {4.50755e-02, 4.51394e-02, 6.64586e-03,4.72027e-02, 8.67128e-04, 6.97440e-03,2.35008e-03, 4.59243e-02, 3.32995e-04, + 4.51766e-02, 2.26032e-06, 4.51767e-02,2.91394e-07, 2.37285e-06, 3.94360e-08,4.51769e-02, 1.12535e-07, 4.51767e-02, + 7.58256e-10, 4.51767e-02, 1.22325e-11,7.96007e-10, 1.32293e-11, 1.04994e-01,3.77513e-11, 4.51767e-02, 1.04994e-01}, sd::DataType::DOUBLE); + NDArray output('c', {3,3,3}, sd::DataType::DOUBLE); // input.applyTransform(sd::transform::SoftMaxDerivative, &output); @@ -3344,9 +2179,9 @@ TEST_F(HelpersTests1, softmaxDerivative_2) { ////////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, softmaxDerivative_3) { + NDArray input('c', {5}, {-1., 1, -2, 2, 3}, sd::DataType::DOUBLE); - NDArray expOutput('c', {5}, {0.01184, 0.08071, 0.00439, 0.18277, 0.22618}, - sd::DataType::DOUBLE); + NDArray expOutput('c', {5}, {0.01184, 0.08071, 0.00439, 0.18277, 0.22618}, sd::DataType::DOUBLE); NDArray output('c', {5}, sd::DataType::DOUBLE); // input.applyTransform(sd::transform::SoftMaxDerivative, &output); @@ -3358,52 +2193,38 @@ TEST_F(HelpersTests1, softmaxDerivative_3) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, lstmLayerCell_1) { - const int bS = 2; - const int nIn = 10; + + const int bS = 2; + const int nIn = 10; const int nOut = 4; - const float dataFormat = 0; // is ignored in cell op - const float cellClip = 5; // clipping value - const float gateAct = - 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const float gateAlpha = - 0; // alpha value for activation for gates, not required for sigmoid - const float gateBeta = - 0; // beta value for activation for gates, not required for sigmoid - const float cellAct = 0; // tanh activation for cell state - const float cellAlpha = - 0; // alpha value for cell state activation, not required for tanh - const float cellBeta = - 0; // beta value for cell state activation, not required for tanh - const float outAct = 0; // tanh activation for output - const float outAlpha = - 0; // alpha value for output activation, not required for tanh - const float outBeta = - 0; // beta value for output activation, not required for tanh - - NDArray x('c', {bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + const float dataFormat = 0; // is ignored in cell op + const float cellClip = 5; // clipping value + const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = 0; // alpha value for output activation, not required for tanh + const float outBeta = 0; // beta value for output activation, not required for tanh + + NDArray x ('c', {bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b ('c', {4*nOut}, sd::DataType::FLOAT32); NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray Wp('c', {3 * nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); NDArray h('c', {bS, nOut}, sd::DataType::FLOAT32); NDArray c('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray expH('c', {bS, nOut}, - {0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, - 0.999288, 0.999288}, - sd::DataType::FLOAT32); - NDArray expC('c', {bS, nOut}, - {3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, - 3.999778, 3.999778}, - sd::DataType::FLOAT32); + NDArray expH('c', {bS, nOut}, {0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288}, sd::DataType::FLOAT32); + NDArray expC('c', {bS, nOut}, {3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778}, sd::DataType::FLOAT32); - std::vector params = {dataFormat, 0, cellClip, gateAct, - gateAlpha, gateBeta, cellAct, cellAlpha, - cellBeta, outAct, outAlpha, outBeta}; + std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; x = 1.; hI = 2.; @@ -3413,8 +2234,7 @@ TEST_F(HelpersTests1, lstmLayerCell_1) { Wp = 0.3; b = 0.7; - sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, - &c); + sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); ASSERT_TRUE(expH.isSameShape(h)); ASSERT_TRUE(expH.equalsTo(h)); @@ -3424,49 +2244,38 @@ TEST_F(HelpersTests1, lstmLayerCell_1) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, lstmLayerCell_2) { - const int bS = 2; - const int nIn = 10; + + const int bS = 2; + const int nIn = 10; const int nOut = 4; - const float dataFormat = 0; // is ignored in cell op - const float cellClip = 3; // clipping value - const float gateAct = - 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const float gateAlpha = - 0; // alpha value for activation for gates, not required for sigmoid - const float gateBeta = - 0; // beta value for activation for gates, not required for sigmoid - const float cellAct = 0; // tanh activation for cell state - const float cellAlpha = - 0; // alpha value for cell state activation, not required for tanh - const float cellBeta = - 0; // beta value for cell state activation, not required for tanh - const float outAct = 0; // tanh activation for output - const float outAlpha = - 0; // alpha value for output activation, not required for tanh - const float outBeta = - 0; // beta value for output activation, not required for tanh - - NDArray x('c', {bS, nIn}, sd::DataType::FLOAT32); - NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); - NDArray Wr('c', {nOut, 4 * nOut}, sd::DataType::FLOAT32); - NDArray b('c', {4 * nOut}, sd::DataType::FLOAT32); + const float dataFormat = 0; // is ignored in cell op + const float cellClip = 3; // clipping value + const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = 0; // alpha value for output activation, not required for tanh + const float outBeta = 0; // beta value for output activation, not required for tanh + + NDArray x ('c', {bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b ('c', {4*nOut}, sd::DataType::FLOAT32); NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray Wp('c', {3 * nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); NDArray h('c', {bS, nOut}, sd::DataType::FLOAT32); NDArray c('c', {bS, nOut}, sd::DataType::FLOAT32); - NDArray expH('c', {bS, nOut}, - {0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995}, - sd::DataType::FLOAT32); - NDArray expC('c', {bS, nOut}, {3., 3., 3., 3., 3., 3., 3., 3.}, - sd::DataType::FLOAT32); + NDArray expH('c', {bS, nOut}, {0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995}, sd::DataType::FLOAT32); + NDArray expC('c', {bS, nOut}, {3., 3., 3., 3., 3., 3., 3., 3.}, sd::DataType::FLOAT32); - std::vector params = {dataFormat, 0, cellClip, gateAct, - gateAlpha, gateBeta, cellAct, cellAlpha, - cellBeta, outAct, outAlpha, outBeta}; + std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; x = 1.; hI = 2.; @@ -3476,8 +2285,7 @@ TEST_F(HelpersTests1, lstmLayerCell_2) { Wp = 0.3; b = 0.7; - sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, - &c); + sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); ASSERT_TRUE(expH.isSameShape(h)); ASSERT_TRUE(expH.equalsTo(h)); @@ -3487,27 +2295,21 @@ TEST_F(HelpersTests1, lstmLayerCell_2) { /////////////////////////////////////////////////////////////////// TEST_F(HelpersTests1, lstmLayerCell_3) { + const int nIn = 10; const int nOut = 4; - const float dataFormat = 0; // is ignored in cell op - const float cellClip = 5; // clipping value - const float gateAct = - 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const float gateAlpha = - 0; // alpha value for activation for gates, not required for sigmoid - const float gateBeta = - 0; // beta value for activation for gates, not required for sigmoid - const float cellAct = 0; // tanh activation for cell state - const float cellAlpha = - 0; // alpha value for cell state activation, not required for tanh - const float cellBeta = - 0; // beta value for cell state activation, not required for tanh - const float outAct = 0; // tanh activation for output - const float outAlpha = - 0; // alpha value for output activation, not required for tanh - const float outBeta = - 0; // beta value for output activation, not required for tanh + const float dataFormat = 0; // is ignored in cell op + const float cellClip = 5; // clipping value + const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = 0; // alpha value for output activation, not required for tanh + const float outBeta = 0; // beta value for output activation, not required for tanh NDArray x('c', {nIn}, sd::DataType::FLOAT32); NDArray Wx('c', {nIn, 4 * nOut}, sd::DataType::FLOAT32); @@ -3520,14 +2322,11 @@ TEST_F(HelpersTests1, lstmLayerCell_3) { NDArray h('c', {nOut}, sd::DataType::FLOAT32); NDArray c('c', {nOut}, sd::DataType::FLOAT32); - NDArray expH('c', {nOut}, {0.999288, 0.999288, 0.999288, 0.999288}, - sd::DataType::FLOAT32); - NDArray expC('c', {nOut}, {3.999778, 3.999778, 3.999778, 3.999778}, - sd::DataType::FLOAT32); + NDArray expH('c', {nOut}, {0.999288, 0.999288, 0.999288, 0.999288}, sd::DataType::FLOAT32); + NDArray expC('c', {nOut}, {3.999778, 3.999778, 3.999778, 3.999778}, sd::DataType::FLOAT32); - std::vector params = {dataFormat, 0, cellClip, gateAct, - gateAlpha, gateBeta, cellAct, cellAlpha, - cellBeta, outAct, outAlpha, outBeta}; + std::vector params = + {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; x = 1.; hI = 2.; @@ -3537,8 +2336,7 @@ TEST_F(HelpersTests1, lstmLayerCell_3) { Wp = 0.3; b = 0.7; - sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, - &c); + sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); ASSERT_TRUE(expH.isSameShape(h)); ASSERT_TRUE(expH.equalsTo(h)); diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 2cd828d309c5..5db060400082 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -49,6 +49,8 @@ #include #include +#include "testlayers.h" + using namespace sd; using namespace sd::graph; From c4af106eff9bb12557d0420d3aeafb47e2cc2676 Mon Sep 17 00:00:00 2001 From: Yurii Date: Fri, 15 May 2020 16:17:27 +0300 Subject: [PATCH 124/233] - got back row in JacobiSVD code which was accidentally deleted Signed-off-by: Yurii --- libnd4j/include/helpers/biDiagonalUp.h | 20 +++++++++------ libnd4j/include/helpers/impl/biDiagonalUp.cpp | 2 +- libnd4j/include/helpers/impl/hhSequence.cpp | 4 +-- libnd4j/include/helpers/impl/jacobiSVD.cpp | 25 ++++++++----------- libnd4j/include/helpers/jacobiSVD.h | 2 +- 5 files changed, 27 insertions(+), 26 deletions(-) diff --git a/libnd4j/include/helpers/biDiagonalUp.h b/libnd4j/include/helpers/biDiagonalUp.h index f008b5308981..2a4c839b69d5 100644 --- a/libnd4j/include/helpers/biDiagonalUp.h +++ b/libnd4j/include/helpers/biDiagonalUp.h @@ -15,7 +15,7 @@ ******************************************************************************/ // -// Created by Yurii Shyrma on 18.12.2017. +// @author Yurii Shyrma (iuriish@yahoo.com) // #ifndef LIBND4J_BIDIAGONALUP_H @@ -30,8 +30,10 @@ namespace helpers { class BiDiagonalUp { public: - NDArray _HHmatrix; // 2D Householder matrix - NDArray _HHbidiag; // vector which contains Householder coefficientsNDArray _hhCoeffs; // vector of Householder coefficients + NDArray _HHmatrix; // 2D Householder matrix + NDArray _HHbidiag; // vector which contains Householder coefficients + NDArray _hhCoeffs; // vector of Householder coefficients + /** * constructor @@ -41,11 +43,13 @@ class BiDiagonalUp { BiDiagonalUp(const NDArray& matrix); /** - * this method evaluates data (coeff, normX, tail) used in Householder - * transformation formula for Householder matrix: P = identity_matrix - coeff - * * w * w^T P * x = [normX, 0, 0 , 0, ...] coeff - scalar w = [1, w1, w2, w3, - * ...], "tail" is w except first unity element, that is "tail" = [w1, w2, w3, - * ...] tail and coeff are stored in _HHmatrix normX are stored in _HHbidiag + * this method evaluates data (coeff, normX, tail) used in Householder transformation + * formula for Householder matrix: P = identity_matrix - coeff * w * w^T + * P * x = [normX, 0, 0 , 0, ...] + * coeff - scalar + * w = [1, w1, w2, w3, ...], "tail" is w except first unity element, that is "tail" = [w1, w2, w3, ...] + * tail and coeff are stored in _HHmatrix + * normX are stored in _HHbidiag */ template void _evalData(); diff --git a/libnd4j/include/helpers/impl/biDiagonalUp.cpp b/libnd4j/include/helpers/impl/biDiagonalUp.cpp index c66fecf788e1..d5326c21a29d 100644 --- a/libnd4j/include/helpers/impl/biDiagonalUp.cpp +++ b/libnd4j/include/helpers/impl/biDiagonalUp.cpp @@ -125,7 +125,7 @@ HHsequence BiDiagonalUp::makeHHsequence_(const char type) { const int diagSize = type == 'u' ? _HHbidiag.sizeAt(0) : _HHbidiag.sizeAt(0) - 1; - auto _hhCoeffs = NDArray(_HHmatrix.ordering(), {diagSize}, _HHmatrix.dataType(), _HHmatrix.getContext()); + _hhCoeffs = NDArray(_HHmatrix.ordering(), {diagSize}, _HHmatrix.dataType(), _HHmatrix.getContext()); if(type == 'u') for(int i = 0; i < diagSize; ++i) diff --git a/libnd4j/include/helpers/impl/hhSequence.cpp b/libnd4j/include/helpers/impl/hhSequence.cpp index c20277934a04..6072e492324f 100644 --- a/libnd4j/include/helpers/impl/hhSequence.cpp +++ b/libnd4j/include/helpers/impl/hhSequence.cpp @@ -27,8 +27,7 @@ namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -HHsequence::HHsequence(const NDArray& vectors, const NDArray& coeffs, - const char type) +HHsequence::HHsequence(const NDArray& vectors, const NDArray& coeffs, const char type) : _vectors(vectors), _coeffs(coeffs) { _diagSize = sd::math::nd4j_min(_vectors.sizeAt(0), _vectors.sizeAt(1)); _shift = 0; @@ -70,6 +69,7 @@ NDArray HHsequence::getTail(const int idx) const { ////////////////////////////////////////////////////////////////////////// template void HHsequence::applyTo_(NDArray& dest) { + int size = _type == 'u' ? _vectors.sizeAt(0) : _vectors.sizeAt(1); if (dest.rankOf() != 2 || (dest.sizeAt(0) != size && dest.sizeAt(1) != size)) diff --git a/libnd4j/include/helpers/impl/jacobiSVD.cpp b/libnd4j/include/helpers/impl/jacobiSVD.cpp index 8e8399a790cf..53e87d602257 100644 --- a/libnd4j/include/helpers/impl/jacobiSVD.cpp +++ b/libnd4j/include/helpers/impl/jacobiSVD.cpp @@ -79,7 +79,7 @@ void JacobiSVD::mulRotationOnLeft(const int i, const int j, NDArray& block, "of array row range !"); auto temp = block({i,j+1,j-i, 0,0,0}, true, true); - temp.assign(mmul(rotation, temp)); + temp.assign(mmul(rotation, temp)); //auto pTemp = block({i, j + 1, j - i, 0, 0, 0}, true, true); // auto temp = pTemp.dup(); @@ -157,8 +157,7 @@ bool JacobiSVD::isBlock2x2NotDiag(NDArray& block, int p, int q, T& maxElem) { } else { T v = block.t(p, p) / n; - rotation.r(0,0) = - rotation.r(1, 1) = v; + rotation.r(0,0) = rotation.r(1, 1) = v; v = block.t(q, p) / n; rotation.r(0,1) = v; @@ -187,10 +186,8 @@ bool JacobiSVD::createJacobiRotation(const T& x, const T& y, const T& z, T denom = (T)(2.f)* math::nd4j_abs(y); if (denom < DataTypeUtils::min()) { - rotation.r(0, 0) = - rotation.r(1, 1) = (T)1.f; - rotation.r(0, 1) = - rotation.r(1, 0) = (T)0.f; + rotation.r(0, 0) = rotation.r(1, 1) = (T)1.f; + rotation.r(0, 1) = rotation.r(1, 0) = (T)0.f; return false; } else { T tau = (x - z) / denom; @@ -202,16 +199,14 @@ bool JacobiSVD::createJacobiRotation(const T& x, const T& y, const T& z, else t = (T)1.f / (tau - w); - T sign = t > (T)0. ? ( - T )1.f : ( T)-1.f; + T sign = t > (T)0. ? (T)1.f : (T)-1.f; T cos = (T)1.f / math::nd4j_sqrt(t*t + (T) 1.f); - T sin = - -sign * (y / math::nd4j_abs(y)) * math::nd4j_abs(t) * cos; + T sin = -sign * (y / math::nd4j_abs(y)) * math::nd4j_abs(t) * cos; - rotation.r(0,1) = sin; + rotation.r(0,1) = sin; rotation.r(1, 0) = -sin; - rotation.r(0, 0) = rotation.r(1,1) = cos; + rotation.r(0, 0) = rotation.r(1,1) = cos; return true; } @@ -345,7 +340,9 @@ void JacobiSVD::evalData(const NDArray& matrix) { _m.assign(matrix({0,_diagSize, 0,_diagSize}) / scale); - if (_calcV) _v.setIdentity(); + if(_calcU) _u.setIdentity(); + + if(_calcV) _v.setIdentity(); } T maxDiagElem = 0.; diff --git a/libnd4j/include/helpers/jacobiSVD.h b/libnd4j/include/helpers/jacobiSVD.h index 285c452823c0..bcfd75225a87 100644 --- a/libnd4j/include/helpers/jacobiSVD.h +++ b/libnd4j/include/helpers/jacobiSVD.h @@ -52,7 +52,7 @@ class JacobiSVD { static bool createJacobiRotation(const T& x, const T& y, const T& z, NDArray& rotation); -static void createJacobiRotationGivens(const T& p, const T& q, NDArray& rotation); + static void createJacobiRotationGivens(const T& p, const T& q, NDArray& rotation); static void svd2x2(const NDArray& block, int p, int q, NDArray& left, NDArray& right); From 4e5e53502b876cd32ccccb473f0538714af7ebe9 Mon Sep 17 00:00:00 2001 From: Yurii Date: Thu, 21 May 2020 18:50:36 +0300 Subject: [PATCH 125/233] - implementation of algorithm for topological sort of graph with no cycles Signed-off-by: Yurii --- libnd4j/include/graph/Graph.h | 2 +- libnd4j/include/graph/OptimizedGraph.h | 355 ++-- .../graph/execution/impl/GraphExecutor.cpp | 26 +- libnd4j/include/graph/impl/Graph.cpp | 16 +- libnd4j/include/graph/impl/OptimizedGraph.cpp | 719 ++++---- .../layers_tests/GraphAnalysisTests.cpp | 1500 ++++++++--------- .../layers_tests/GraphExecutorTests.cpp | 92 +- .../layers_tests/OpSequenceTests.cpp | 52 +- 8 files changed, 1441 insertions(+), 1321 deletions(-) diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index bef4ff065e75..e15949abcfa0 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -62,7 +62,7 @@ class SD_EXPORT Graph { // we want to know last node id int _maxId = 1; - const GraphMemoryManager &_memoryMaager; + const GraphMemoryManager &_memoryManager; //////////////////////////////////////// Nd4jStatus validateNode(Node *node); diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index fc20c414e630..507f2203e928 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -16,7 +16,6 @@ // // @author raver119@gmail.com -// @author oleg.semeniv@gmail.com // #ifndef SD_OPTIMIZEDGRAPH_H @@ -26,186 +25,210 @@ #include #include -#include #include -#include namespace sd { namespace graph { -class Graph; -class NodeInfo; -/** - * This class acts as a topologically sorted & optimized Graph representation, - * ready for execution - */ -class SD_EXPORT OptimizedGraph { - protected: - // here we store independent OpSequences - // Graph starts from layer 0, and goes deeper step by step - // on each layer we can have 1+ OpSequences that can be executed independent - std::map _onion; - - GraphMemoryManager* _memoryManager = nullptr; - Graph* _originalGraph = nullptr; - - mutable std::mutex _mutex; - - mutable size_t _size = 0; - - public: - OptimizedGraph(Graph* original); - OptimizedGraph() = default; - ~OptimizedGraph() = default; - - OptimizedGraph(const OptimizedGraph& other) noexcept; - - OptimizedGraph& operator=(const OptimizedGraph& other) noexcept; - - // move constructor - OptimizedGraph(OptimizedGraph&& other) noexcept; - - // move assignment operator - OptimizedGraph& operator=(OptimizedGraph&& other) noexcept; - - /** - * This method returns number of layers within OptimizedGraph - * @return - */ - uint64_t layers() const; - - /** - * This method returns OpSequences stored in a given layer - * @param index - * @return - */ - const ExecutionLayer& layer(uint64_t index) const; - - /** - * This method allows to append layer to this OptimizedGraph instance - */ - // FIXME: this method should be removed or made private - void append(const std::vector& layer); - void append(const ExecutionLayer& layer); - void append(OpSequence& sequence); - - /** - * This method returns GraphMemoryManager instance that manages this Graph - * @return - */ - const GraphMemoryManager& memoryManager() const; - - /** - * This method returns pointer to original Graph - * @return - */ - const Graph& originalGraph() const; - - /** - * This method returns number of nodes in this graph instance - * @return - */ - size_t size() const; - - /** - * This method prints out graph content - */ - void printOut() const; - - protected: - /* - * optimize original graph - */ - void createOptimizedGraph(); - /* - * Topological graph analysis - * @param const start node for search - * @param const reference for nodes infor container - * @param operation gather - * @return stop iterating - */ - bool topolSearch(const int startNode, - std::unordered_map& nodesConnections, - std::vector>& opSeq) const; - /* - * Optimized graph analysis prototyping, gather nodes infor - * @param reference to node information collector - * @param reference to start nodes - * @param reference to input branching nodes (input branching node - atleast 2 - * internal inputs) - * @return stop iterating - */ - bool opGraphProto(std::unordered_map& collector, - std::set& startNodes, - std::set& inBranchingNodes) const; - /* - * Define layers and sequence positions based on nodes infor - * @param reference to node information collector - * @param node ID - * @param layer ID - * @param sequence ID - * @param map of layers and max sequence - * @return stop iterating - */ - bool layersSeqDefine(std::unordered_map& collection, int ID, - int layer, int nStartSeq, - std::unordered_map& layersMaxSeq) const; - /* - * Initialize container with operations and context - * @param const reference to layers and sequence collection - * @param reference to opSequence collector - * @return stop iterating - */ - bool initOpSeqContainer(const std::unordered_map& layersMaxSeq, - std::vector>& vOpSeq) const; -}; - -class NodeInfo { - private: - std::set sConnections; - - bool bInBranching; - bool bOutBranching; - bool bProcessed; - int nLayer; - int nSequence; +class Graph; - sd::graph::OpType opType; +class SD_EXPORT OptimizedGraph { + private: + std::vector> _sortedGraph; + // const Graph& _originalGraph; - public: - NodeInfo() { reset(); } - ~NodeInfo() { reset(); } + public: + OptimizedGraph(const MAP_IMPL& map, const VariableSpace& varSpace); + OptimizedGraph() {}; + size_t size() const; - void setInBranching(bool bValue) { bInBranching = bValue; } - void setOutBranching(bool bValue) { bOutBranching = bValue; } - void setProcessed(bool bValue = true) { bProcessed = bValue; } - void reset() { - sConnections.clear(); - bProcessed = bInBranching = bOutBranching = false; - nLayer = 0; - nSequence = -1; - opType = OpType_CUSTOM; - } + struct NodeInfo { + std::vector _connections = std::vector(); + int _id = -1; + NodeInfo(const int id): _id(id), _connections(std::vector()) {} + NodeInfo() = delete; + }; - int layer() const { return nLayer; } - void setLayer(int layer) { nLayer = layer; } +}; - int sequence() const { return nSequence; } - void setSequence(int sequence) { nSequence = sequence; } - void addConnection(int id) { sConnections.emplace(id); } - const std::set& connections() const { return sConnections; } +// class Graph; +// class NodeInfo; +// /** +// * This class acts as a topologically sorted & optimized Graph representation, +// * ready for execution +// */ +// class SD_EXPORT OptimizedGraph { +// protected: +// // here we store independent OpSequences +// // Graph starts from layer 0, and goes deeper step by step +// // on each layer we can have 1+ OpSequences that can be executed independent +// std::map _onion; + +// GraphMemoryManager* _memoryManager = nullptr; +// Graph* _originalGraph = nullptr; + +// mutable std::mutex _mutex; + +// mutable size_t _size = 0; + +// public: +// OptimizedGraph(Graph* original); +// OptimizedGraph() = default; +// ~OptimizedGraph() = default; + +// OptimizedGraph(const OptimizedGraph& other) noexcept; + +// OptimizedGraph& operator=(const OptimizedGraph& other) noexcept; + +// // move constructor +// OptimizedGraph(OptimizedGraph&& other) noexcept; + +// // move assignment operator +// OptimizedGraph& operator=(OptimizedGraph&& other) noexcept; + +// /** +// * This method returns number of layers within OptimizedGraph +// * @return +// */ +// uint64_t layers() const; + +// /** +// * This method returns OpSequences stored in a given layer +// * @param index +// * @return +// */ +// const ExecutionLayer& layer(uint64_t index) const; + +// /** +// * This method allows to append layer to this OptimizedGraph instance +// */ +// // FIXME: this method should be removed or made private +// void append(const std::vector& layer); +// void append(const ExecutionLayer& layer); +// void append(OpSequence& sequence); + +// /** +// * This method returns GraphMemoryManager instance that manages this Graph +// * @return +// */ +// const GraphMemoryManager& memoryManager() const; + +// /** +// * This method returns pointer to original Graph +// * @return +// */ +// const Graph& originalGraph() const; + +// /** +// * This method returns number of nodes in this graph instance +// * @return +// */ +// size_t size() const; + +// /** +// * This method prints out graph content +// */ +// void printOut() const; + +// protected: +// /* +// * optimize original graph +// */ +// void createOptimizedGraph(); +// /* +// * Topological graph analysis +// * @param const start node for search +// * @param const reference for nodes infor container +// * @param operation gather +// * @return stop iterating +// */ +// bool topolSearch(const int startNode, +// std::unordered_map& nodesConnections, +// std::vector>& opSeq) const; + +// * Optimized graph analysis prototyping, gather nodes infor +// * @param reference to node information collector +// * @param reference to start nodes +// * @param reference to input branching nodes (input branching node - atleast 2 +// * internal inputs) +// * @return stop iterating + +// bool opGraphProto(std::unordered_map& collector, +// std::set& startNodes, +// std::set& inBranchingNodes) const; +// /* +// * Define layers and sequence positions based on nodes infor +// * @param reference to node information collector +// * @param node ID +// * @param layer ID +// * @param sequence ID +// * @param map of layers and max sequence +// * @return stop iterating +// */ +// bool layersSeqDefine(std::unordered_map& collection, int ID, +// int layer, int nStartSeq, +// std::unordered_map& layersMaxSeq) const; +// /* +// * Initialize container with operations and context +// * @param const reference to layers and sequence collection +// * @param reference to opSequence collector +// * @return stop iterating +// */ +// bool initOpSeqContainer(const std::unordered_map& layersMaxSeq, +// std::vector>& vOpSeq) const; +// }; + +// class NodeInfo { +// private: +// std::set sConnections; + +// bool bInBranching; +// bool bOutBranching; +// bool bProcessed; + +// int nLayer; +// int nSequence; + +// sd::graph::OpType opType; + +// public: +// NodeInfo() { reset(); } +// ~NodeInfo() { reset(); } + +// void setInBranching(bool bValue) { bInBranching = bValue; } +// void setOutBranching(bool bValue) { bOutBranching = bValue; } +// void setProcessed(bool bValue = true) { bProcessed = bValue; } + +// void reset() { +// sConnections.clear(); +// bProcessed = bInBranching = bOutBranching = false; +// nLayer = 0; +// nSequence = -1; +// opType = OpType_CUSTOM; +// } + +// int layer() const { return nLayer; } +// void setLayer(int layer) { nLayer = layer; } + +// int sequence() const { return nSequence; } +// void setSequence(int sequence) { nSequence = sequence; } + +// void addConnection(int id) { sConnections.emplace(id); } +// const std::set& connections() const { return sConnections; } + +// void setType(sd::graph::OpType value) { opType = value; } +// sd::graph::OpType type() const { return opType; } +// bool isLogic() { return opType == OpType_LOGIC; } + +// bool isInBranching() const { return bInBranching; } +// bool isOutBranching() const { return bOutBranching; } +// bool isProcessed() const { return bProcessed; } +// }; - void setType(sd::graph::OpType value) { opType = value; } - sd::graph::OpType type() const { return opType; } - bool isLogic() { return opType == OpType_LOGIC; } - bool isInBranching() const { return bInBranching; } - bool isOutBranching() const { return bOutBranching; } - bool isProcessed() const { return bProcessed; } -}; } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 20059bbb75f1..619b87fd43ca 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -54,7 +54,7 @@ Nd4jStatus GraphExecutor::execute( const ContextPrototype &contextPrototype, const OpSequence &sequence, const OptimizedGraph &graph, VariableProxy &proxy, const int deviceId) const { - auto ctx = prepareContext(contextPrototype, proxy, graph.memoryManager()); + auto ctx = prepareContext(contextPrototype, proxy, GraphMemoryManager()/*graph.memoryManager()*/); return op->execute(&ctx); // throw std::runtime_error("GraphExecutor::execute - Not implemented yet"); } @@ -85,20 +85,20 @@ Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, * execute them one by one sequentially */ Nd4jStatus result = Status::OK(); - for (uint64_t l = 0; l < graph.layers(); l++) { - const auto &layer = graph.layer(l); + // for (uint64_t l = 0; l < graph.layers(); l++) { + // const auto &layer = graph.layer(l); - for (uint64_t o = 0; o < layer.width(); o++) { - execute(layer[o], graph, proxy, -1); - } + // for (uint64_t o = 0; o < layer.width(); o++) { + // execute(layer[o], graph, proxy, -1); + // } - // optionally block until all sequences in this layer processed - if (layer.width() > 0 && numDevices > 1) - for (uint64_t o = 0; o < layer.width(); o++) { - result = layer[o].wait(); - if (result != Status::OK()) return result; - } - } + // // optionally block until all sequences in this layer processed + // if (layer.width() > 0 && numDevices > 1) + // for (uint64_t o = 0; o < layer.width(); o++) { + // result = layer[o].wait(); + // if (result != Status::OK()) return result; + // } + // } return result; } diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 65912d6d9e61..c989e8e9fb2d 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -124,9 +124,7 @@ void Graph::addNode(Node &node, throw std::runtime_error("Graph::addNode() - Not implemented yet"); } -Graph::Graph(const FlatGraph *flatGraph, - const GraphMemoryManager &memoryManager) - : _memoryMaager(memoryManager) { +Graph::Graph(const FlatGraph *flatGraph, const GraphMemoryManager &memoryManager): _memoryManager(memoryManager) { bool trusted = flatGraph != nullptr; // if there was no exec configuration in flatgraph - create default one @@ -279,7 +277,7 @@ void Graph::printOut() { nd4j_printf("\nPrinting out Nodes...\n", ""); // since we need structure - we'll print out nodes of OptimizedGraph - optimizedGraph().printOut(); + // optimizedGraph().printOut(); } } @@ -641,7 +639,7 @@ std::map Graph::execute( return result; } -Graph::Graph(const Graph &other) : _memoryMaager(other._memoryMaager) { +Graph::Graph(const Graph &other) : _memoryManager(other._memoryManager) { _configuration = other._configuration; _variableSpace = other._variableSpace; _stash = other._stash; @@ -665,7 +663,7 @@ Graph &Graph::operator=(const Graph &other) noexcept { return *this; } -Graph::Graph(Graph &&other) : _memoryMaager(other._memoryMaager) { +Graph::Graph(Graph &&other) : _memoryManager(other._memoryManager) { _configuration = other._configuration; _variableSpace = other._variableSpace; _stash = other._stash; @@ -693,14 +691,14 @@ Graph &Graph::operator=(Graph &&other) noexcept { return *this; } -const GraphMemoryManager &Graph::memoryManager() const { return _memoryMaager; } +const GraphMemoryManager &Graph::memoryManager() const { return _memoryManager; } const OptimizedGraph &Graph::optimizedGraph() const { std::lock_guard lock(_optimizedLock); // optionally rebuild optimized graph, if it's out of date - if (_optimized.size() != size()) - _optimized = OptimizedGraph(const_cast(this)); + // if (_optimized.size() != size()) + // _optimized = OptimizedGraph(unmappedNodes()); return _optimized; } diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 7a8059b3d573..f429b4cef1f4 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -15,353 +15,452 @@ ******************************************************************************/ // -// @author raver119@gmail.com -// @author oleg.semeniv@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) // #include #include +#include -namespace sd { +namespace sd { namespace graph { -OptimizedGraph::OptimizedGraph(Graph* original) { - _originalGraph = original; - _memoryManager = const_cast(&original->memoryManager()); - // create optimized graph - createOptimizedGraph(); -} - -OptimizedGraph::OptimizedGraph(const OptimizedGraph& other) noexcept { - _onion = other._onion; - _memoryManager = other._memoryManager; - _originalGraph = other._originalGraph; -} -OptimizedGraph& OptimizedGraph::operator=( - const OptimizedGraph& other) noexcept { - if (this == &other) return *this; +OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableSpace& varSpace) { - _onion = other._onion; - _memoryManager = other._memoryManager; - _originalGraph = other._originalGraph; + MAP_IMPL> workMap; // key is node id, value is vector containing connections (internal inputs, that is nodes, not input arrays) - return *this; -} + // fill workMap + for (const auto& i0 : inMap) { -OptimizedGraph::OptimizedGraph(OptimizedGraph&& other) noexcept { - _onion = std::move(other._onion); - _memoryManager = other._memoryManager; - _originalGraph = other._originalGraph; -} + std::forward_list& currList = workMap[i0.first] = {}; -OptimizedGraph& OptimizedGraph::operator=(OptimizedGraph&& other) noexcept { - if (this == &other) return *this; + // loop through inputs of node + for (const auto& i1 : i0.second.input()) + if (!varSpace.hasVariable(i1.first)) + currList.push_front(i1.first); + } - _onion = std::move(other._onion); - _memoryManager = other._memoryManager; - _originalGraph = other._originalGraph; + // 2d vector + std::vector> sortedNodes; // 0d - layers, 1d - nodes belonging to layer, OpSequence are separated by -1 in current layer - return *this; -} + // searches for of exit node + auto findExitNode = [] (const decltype(workMap)& m) { + return std::find_if(m.begin(), m.end(), [] (const std::pair> &p) {return p.second.empty();}); + }; -size_t OptimizedGraph::size() const { - std::lock_guard lock(_mutex); + // searches for next node in current OpSequence + auto findNextNodeInOpSeq = [] (const decltype(workMap)& m, const int &currId) { + return std::find_if(m.begin(), m.end(), [&currId] (const std::pair> &p) {return p.second.front() == currId && std::distance(p.second.begin(), p.second.end()) == 1;}); + }; - std::vector seq; - if (_size == 0) - for (const auto& v : _onion) { - for (int e = 0; e < v.second.width(); e++) { - _size += v.second.at(0).length(); - } - } + decltype(workMap)::const_iterator exitNode(nullptr), nextNodeInSeq(nullptr); - return _size; -} + while (!workMap.empty()) { -uint64_t OptimizedGraph::layers() const { return _onion.size(); } + sortedNodes.emplace_back(std::vector()); // add layer + auto& currLayer = sortedNodes.back(); -const ExecutionLayer& OptimizedGraph::layer(uint64_t index) const { - return _onion.at(index); -} + // loop of searching for exit nodes + while ((exitNode = findExitNode(workMap)) != workMap.end()) { -void OptimizedGraph::append(const std::vector& layer) { - std::lock_guard lock(_mutex); - _onion[_onion.size()] = layer; - _size = 0; -} + int currId = exitNode->first; // id of node under consideration -void OptimizedGraph::append(OpSequence& sequence) { - append(ExecutionLayer({sequence})); -} + currLayer.push_back(currId); + workMap.erase(currId); -void OptimizedGraph::append(const ExecutionLayer& layer) { - std::lock_guard lock(_mutex); - _onion[_onion.size()] = layer; - _size = 0; -} + // loop of searching for next nodes in OpSequence + while ((nextNodeInSeq = findNextNodeInOpSeq(workMap, currId)) != workMap.end()) { -const GraphMemoryManager& OptimizedGraph::memoryManager() const { - return *_memoryManager; -} + currId = nextNodeInSeq->first; -const Graph& OptimizedGraph::originalGraph() const { return *_originalGraph; } - -bool OptimizedGraph::opGraphProto(std::unordered_map& collector, - std::set& startNodes, - std::set& inBranchingNodes) const { - // double check to avoid unstable behavior - if (originalGraph().unmappedNodes().empty()) return false; - - const auto& unmappedNodes = originalGraph().unmappedNodes(); - // iterate via original graph nodes to gather node information - for (const auto& it : unmappedNodes) { - const auto& ID = it.first; - const auto& inputs = it.second.input(); - // if node info is not in collecter add it - if (collector.find(ID) == collector.end()) collector[ID] = NodeInfo(); - - NodeInfo& parentNode = collector[ID]; - // count external and internal inputs to find out the type of the node - // (start, in-branching, out-branching) - int inExCounts = 0, inInternalCounts = 0; - for (const auto& in : inputs) { - // find input id in original graph - if (unmappedNodes.find(in.first) == unmappedNodes.end()) { - // count external inputs, all inputs which id is not in unmapped - // container will be treaded as external - inExCounts++; - } else { - // count iternal inputs, all inputs that are not in external variable - // space will be treated as outputs from other nodes - inInternalCounts++; - // if node info is not in collector add it - if (collector.find(in.first) == collector.end()) - collector[in.first] = NodeInfo(); - // input node connection with discovered - collector[in.first].addConnection(ID); + sortedNodes.back().push_back(currId); + workMap.erase(currId); } + currLayer.push_back(-1); // mark end of OpSequence } - // set operation type - parentNode.setType(it.second.opType()); - - // if move then 1 internal input this is in-branching node - parentNode.setInBranching(inInternalCounts > 1); - // gather start and in-branching nodes for the loop when operations are put - // to OpSequence (topolSearch) - if (inExCounts == inputs.size()) { - startNodes.emplace(ID); - } else { - if (parentNode.isInBranching()) inBranchingNodes.emplace(ID); - } - } - return true; -} -bool OptimizedGraph::topolSearch( - const int startNode, std::unordered_map& collector, - std::vector>& opSeq) const { - // double check to avoid unstable behavior - if (originalGraph().unmappedNodes().empty() || collector.empty()) - return false; - - // skip nodes which are not pre-collected and pre-processed - auto itParent = collector.find(startNode); - if (itParent != collector.end()) { - // iterate via start (in-branching) nodes connections in depth - for (const auto& itNodes : itParent->second.connections()) { - auto itChild = collector.find(itNodes); - // double check - if (itChild != collector.end()) { - // if the child is in-branching node it will be treated as start node or - // it was proceed - if (itChild->second.isInBranching() || itChild->second.isProcessed()) { - continue; - } - // put operation to OpSequence container - const auto it = originalGraph().unmappedNodes().find(itNodes); - auto& child = itChild->second; - // the layer and sequence are pre-defined in layersSeqDefine method - opSeq[child.layer()][child.sequence()].append( - it->second.customOp(), it->second.contextPrototype()); - child.setProcessed(); - // go to the child node connections - topolSearch(itNodes, collector, opSeq); - } - } - } - return true; -} + // remove connections in all nodes pointing at exit nodes + for (uint i = 0; i < currLayer.size(); ++i) { -void OptimizedGraph::createOptimizedGraph() { - // container to store node infor - std::unordered_map collector; - // containers to store start and in-branching nodes - std::set startNodes, inBranching; - // container to store max sequences per layer - std::unordered_map layersMaxSeq; - - // optimizing graph prototyping - // select start nodes - // create connections between nodes - // select in-branching nodes ( more then one iternal input -> outputs from - // other nodes) - if (!opGraphProto(collector, startNodes, inBranching)) - throw std::runtime_error( - "OptimizedGraph::optimizedGraph() - not prototyped!"); - - // next step set the node layer and it sequence in layer - // define max layers and max sequence per layer - int startSeq = 0; - bool bOnlyStartNodes = collector.empty(); - for (const auto& id : startNodes) { - layersMaxSeq[0] = startSeq; - // if only start nodes exists they have to be add to connections - if (bOnlyStartNodes) { - auto node = NodeInfo(); - node.setLayer(0); - node.setProcessed(true); - node.setSequence(startSeq); - collector[id] = node; - } else { - layersSeqDefine(collector, id, 0, startSeq, layersMaxSeq); - } - startSeq++; - } + if(currLayer[i] != -1 || i == 0) // if not at the and of OpSequence + continue; - // init container to collect operations per node position (layer:sequence) - std::vector> vOpSeq; - if (!initOpSeqContainer(layersMaxSeq, vOpSeq)) - throw std::runtime_error( - "OptimizedGraph::initOpSeqContainer() - cannot initialize OpSequence, " - "not all nodes properly prototyped!"); - - // combine start nodes and in-branching nodes - startNodes.insert(inBranching.begin(), inBranching.end()); - // re-init proceed NodeInfo member to avoid append sequence several times - for (auto& it : collector) { - it.second.setProcessed(false); - } + const int idOfExitNode = currLayer[i-1]; - // iterate via start and in-branching nodes - for (const auto& id : startNodes) { - const auto it = originalGraph().unmappedNodes().find(id); - auto& nodeInfo = collector[id]; - // append start/in-branching node operation to sequence - if (!nodeInfo.isProcessed()) { - vOpSeq[nodeInfo.layer()][nodeInfo.sequence()].append( - it->second.customOp(), it->second.contextPrototype()); - nodeInfo.setProcessed(); + std::for_each(workMap.begin(), workMap.end(), [&idOfExitNode] (std::pair> &p) {p.second.remove(idOfExitNode);}); } - // search in depth via connections of "start" node - if (!topolSearch(id, collector, vOpSeq)) - throw std::runtime_error( - "OptimizedGraph::topolSearch() - cannot run topological search, " - "inputs incorrect!"); - } - // put results to optimized graph - for (auto& vSeq : vOpSeq) { - this->append(vSeq); - } -} -bool OptimizedGraph::initOpSeqContainer( - const std::unordered_map& layersMaxSeq, - std::vector>& vOpSeq) const { - // double check to avoid unstable behavior - if (layersMaxSeq.empty()) return false; - // pre-init op-sequence size layers/per-layer sequence - vOpSeq.resize(layersMaxSeq.size()); - for (const auto& it : layersMaxSeq) { - vOpSeq[it.first].resize(it.second + 1); } - return true; -} -bool OptimizedGraph::layersSeqDefine( - std::unordered_map& collection, int ID, int layer, - int startSeq, std::unordered_map& layersMaxSeq) const { - // double check to avoid unstable behavior - auto parent = collection.find(ID); - if (parent == collection.end()) return false; - - // if node was proceed and the current layer is less of it own return - if (parent->second.isProcessed() && parent->second.layer() >= layer) - return true; - - // put layer and sequence to container that collects layers and max sequence - // per layer - auto layerFound = layersMaxSeq.find(layer); - if (layerFound == layersMaxSeq.end()) { - // if layer was not treated before, create pair for it - layersMaxSeq[layer] = 0; - // set sequence value to 0, as this is first sequence in layer - startSeq = 0; - } else { - // if node sequence position was not checked use it for max sequence - // selection sequence have to be incremented as max + 1, without any jumps - if (startSeq > (layerFound->second + 1)) startSeq = layerFound->second + 1; - - layerFound->second = - (layerFound->second < startSeq && parent->second.sequence() < 0) - ? startSeq - : layerFound->second; - } - // double check if the layer is higher and set node layer - if (parent->second.layer() < layer) parent->second.setLayer(layer); - // double check if sequence was init, if not set current sequence - if (parent->second.sequence() < 0) parent->second.setSequence(startSeq); - // set is node out-branching - parent->second.setOutBranching(parent->second.connections().size() > 1); - // set that node was processed, to avoid it double processing (only for some - // cases it can be processed several times) - parent->second.setProcessed(); - - // if current node is out-branching it childs will be put to next layer - if (parent->second.isOutBranching() && !parent->second.isLogic()) layer++; - - // childs sequence position have to start from max defined sequence position - // in layer or if it is first node in layer from 0 - int seq = (layersMaxSeq.find(layer) == layersMaxSeq.end()) - ? 0 - : layersMaxSeq[layer]; - // if parent is out-branching node sequence have to be increment - // on the next stage the sequence value will be double checked with max per - // layer todo check logic part maybe here have to be check operation class - // (something likke Switch, If, While etc) probably for each of them could be - // other behavior - seq = (parent->second.isOutBranching() && !parent->second.isLogic()) ? seq + 1 - : seq; - - // loop via childs (connected nodes) - for (const auto& id : parent->second.connections()) { - // double check to avoid unstable behavior - auto child = collection.find(id); - if (child == collection.end()) return false; - - // in case parent was not out-branching node but child is in branching it - // will be put to next layer todo check logic part - if (!parent->second.isOutBranching() && child->second.isInBranching() && - !child->second.isLogic()) - layer++; - - // move in depth of connections - layersSeqDefine(collection, id, layer, seq, layersMaxSeq); - // increment sequence as childs are on the one layer in case if child was - // not processed earlier todo check logic part - if (!parent->second.isLogic()) seq++; - } - return true; -} -void OptimizedGraph::printOut() const { - for (uint64_t o = 0; o < _onion.size(); o++) { - const auto& layer = _onion.at(o); - printf("Layer [%lu]\n", o); - for (uint64_t l = 0; l < layer.width(); l++) layer.at(l).printOut(); - } + + + // auto fi = std::find_if(ops.begin(), ops.end(), [name](ops::OpDescriptor a) { return a.getOpName()->compare(name) == 0;}); + + + // for (const auto& i0 : idVsConnects) { + + // bool isExitNode = true; + + // for (const auto& i1 : idVsConnects) { + + // if (i0.first == i1.first) + // continue; + + // if (i1.second) { + // isExitNode = false; + // break; + // } + // } + + // } + } + + +// OptimizedGraph::OptimizedGraph(Graph* original) { +// _originalGraph = original; +// _memoryManager = const_cast(&original->memoryManager()); +// // create optimized graph +// createOptimizedGraph(); +// } + +// OptimizedGraph::OptimizedGraph(const OptimizedGraph& other) noexcept { +// _onion = other._onion; +// _memoryManager = other._memoryManager; +// _originalGraph = other._originalGraph; +// } + +// OptimizedGraph& OptimizedGraph::operator=( +// const OptimizedGraph& other) noexcept { +// if (this == &other) return *this; + +// _onion = other._onion; +// _memoryManager = other._memoryManager; +// _originalGraph = other._originalGraph; + +// return *this; +// } + +// OptimizedGraph::OptimizedGraph(OptimizedGraph&& other) noexcept { +// _onion = std::move(other._onion); +// _memoryManager = other._memoryManager; +// _originalGraph = other._originalGraph; +// } + +// OptimizedGraph& OptimizedGraph::operator=(OptimizedGraph&& other) noexcept { +// if (this == &other) return *this; + +// _onion = std::move(other._onion); +// _memoryManager = other._memoryManager; +// _originalGraph = other._originalGraph; + +// return *this; +// } + +// size_t OptimizedGraph::size() const { +// std::lock_guard lock(_mutex); + +// std::vector seq; +// if (_size == 0) +// for (const auto& v : _onion) { +// for (int e = 0; e < v.second.width(); e++) { +// _size += v.second.at(0).length(); +// } +// } + +// return _size; +// } + +// uint64_t OptimizedGraph::layers() const { return _onion.size(); } + +// const ExecutionLayer& OptimizedGraph::layer(uint64_t index) const { +// return _onion.at(index); +// } + +// void OptimizedGraph::append(const std::vector& layer) { +// std::lock_guard lock(_mutex); +// _onion[_onion.size()] = layer; +// _size = 0; +// } + +// void OptimizedGraph::append(OpSequence& sequence) { +// append(ExecutionLayer({sequence})); +// } + +// void OptimizedGraph::append(const ExecutionLayer& layer) { +// std::lock_guard lock(_mutex); +// _onion[_onion.size()] = layer; +// _size = 0; +// } + +// const GraphMemoryManager& OptimizedGraph::memoryManager() const { +// return *_memoryManager; +// } + +// const Graph& OptimizedGraph::originalGraph() const { return *_originalGraph; } + +// bool OptimizedGraph::opGraphProto(std::unordered_map& collector, +// std::set& startNodes, +// std::set& inBranchingNodes) const { +// // double check to avoid unstable behavior +// if (originalGraph().unmappedNodes().empty()) return false; + +// const auto& unmappedNodes = originalGraph().unmappedNodes(); +// // iterate via original graph nodes to gather node information +// for (const auto& it : unmappedNodes) { +// const auto& ID = it.first; +// const auto& inputs = it.second.input(); +// // if node info is not in collecter add it +// if (collector.find(ID) == collector.end()) collector[ID] = NodeInfo(); + +// NodeInfo& parentNode = collector[ID]; +// // count external and internal inputs to find out the type of the node +// // (start, in-branching, out-branching) +// int inExCounts = 0, inInternalCounts = 0; +// for (const auto& in : inputs) { +// // find input id in original graph +// if (unmappedNodes.find(in.first) == unmappedNodes.end()) { +// // count external inputs, all inputs which id is not in unmapped +// // container will be treaded as external +// inExCounts++; +// } else { +// // count iternal inputs, all inputs that are not in external variable +// // space will be treated as outputs from other nodes +// inInternalCounts++; +// // if node info is not in collector add it +// if (collector.find(in.first) == collector.end()) +// collector[in.first] = NodeInfo(); +// // input node connection with discovered +// collector[in.first].addConnection(ID); +// } +// } +// // set operation type +// parentNode.setType(it.second.opType()); + +// // if move then 1 internal input this is in-branching node +// parentNode.setInBranching(inInternalCounts > 1); +// // gather start and in-branching nodes for the loop when operations are put +// // to OpSequence (topolSearch) +// if (inExCounts == inputs.size()) { +// startNodes.emplace(ID); +// } else { +// if (parentNode.isInBranching()) inBranchingNodes.emplace(ID); +// } +// } +// return true; +// } + +// bool OptimizedGraph::topolSearch( +// const int startNode, std::unordered_map& collector, +// std::vector>& opSeq) const { +// // double check to avoid unstable behavior +// if (originalGraph().unmappedNodes().empty() || collector.empty()) +// return false; + +// // skip nodes which are not pre-collected and pre-processed +// auto itParent = collector.find(startNode); +// if (itParent != collector.end()) { +// // iterate via start (in-branching) nodes connections in depth +// for (const auto& itNodes : itParent->second.connections()) { +// auto itChild = collector.find(itNodes); +// // double check +// if (itChild != collector.end()) { +// // if the child is in-branching node it will be treated as start node or +// // it was proceed +// if (itChild->second.isInBranching() || itChild->second.isProcessed()) { +// continue; +// } +// // put operation to OpSequence container +// const auto it = originalGraph().unmappedNodes().find(itNodes); +// auto& child = itChild->second; +// // the layer and sequence are pre-defined in layersSeqDefine method +// opSeq[child.layer()][child.sequence()].append( +// it->second.customOp(), it->second.contextPrototype()); +// child.setProcessed(); +// // go to the child node connections +// topolSearch(itNodes, collector, opSeq); +// } +// } +// } +// return true; +// } + +// void OptimizedGraph::createOptimizedGraph() { +// // container to store node infor +// std::unordered_map collector; +// // containers to store start and in-branching nodes +// std::set startNodes, inBranching; +// // container to store max sequences per layer +// std::unordered_map layersMaxSeq; + +// // optimizing graph prototyping +// // select start nodes +// // create connections between nodes +// // select in-branching nodes ( more then one iternal input -> outputs from +// // other nodes) +// if (!opGraphProto(collector, startNodes, inBranching)) +// throw std::runtime_error( +// "OptimizedGraph::optimizedGraph() - not prototyped!"); + +// // next step set the node layer and it sequence in layer +// // define max layers and max sequence per layer +// int startSeq = 0; +// bool bOnlyStartNodes = collector.empty(); +// for (const auto& id : startNodes) { +// layersMaxSeq[0] = startSeq; +// // if only start nodes exists they have to be add to connections +// if (bOnlyStartNodes) { +// auto node = NodeInfo(); +// node.setLayer(0); +// node.setProcessed(true); +// node.setSequence(startSeq); +// collector[id] = node; +// } else { +// layersSeqDefine(collector, id, 0, startSeq, layersMaxSeq); +// } +// startSeq++; +// } + +// // init container to collect operations per node position (layer:sequence) +// std::vector> vOpSeq; +// if (!initOpSeqContainer(layersMaxSeq, vOpSeq)) +// throw std::runtime_error( +// "OptimizedGraph::initOpSeqContainer() - cannot initialize OpSequence, " +// "not all nodes properly prototyped!"); + +// // combine start nodes and in-branching nodes +// startNodes.insert(inBranching.begin(), inBranching.end()); +// // re-init proceed NodeInfo member to avoid append sequence several times +// for (auto& it : collector) { +// it.second.setProcessed(false); +// } + +// // iterate via start and in-branching nodes +// for (const auto& id : startNodes) { +// const auto it = originalGraph().unmappedNodes().find(id); +// auto& nodeInfo = collector[id]; +// // append start/in-branching node operation to sequence +// if (!nodeInfo.isProcessed()) { +// vOpSeq[nodeInfo.layer()][nodeInfo.sequence()].append( +// it->second.customOp(), it->second.contextPrototype()); +// nodeInfo.setProcessed(); +// } + +// // search in depth via connections of "start" node +// if (!topolSearch(id, collector, vOpSeq)) +// throw std::runtime_error( +// "OptimizedGraph::topolSearch() - cannot run topological search, " +// "inputs incorrect!"); +// } +// // put results to optimized graph +// for (auto& vSeq : vOpSeq) { +// this->append(vSeq); +// } +// } + +// bool OptimizedGraph::initOpSeqContainer( +// const std::unordered_map& layersMaxSeq, +// std::vector>& vOpSeq) const { +// // double check to avoid unstable behavior +// if (layersMaxSeq.empty()) return false; +// // pre-init op-sequence size layers/per-layer sequence +// vOpSeq.resize(layersMaxSeq.size()); +// for (const auto& it : layersMaxSeq) { +// vOpSeq[it.first].resize(it.second + 1); +// } +// return true; +// } + +// bool OptimizedGraph::layersSeqDefine( +// std::unordered_map& collection, int ID, int layer, +// int startSeq, std::unordered_map& layersMaxSeq) const { +// // double check to avoid unstable behavior +// auto parent = collection.find(ID); +// if (parent == collection.end()) return false; + +// // if node was proceed and the current layer is less of it own return +// if (parent->second.isProcessed() && parent->second.layer() >= layer) +// return true; + +// // put layer and sequence to container that collects layers and max sequence +// // per layer +// auto layerFound = layersMaxSeq.find(layer); +// if (layerFound == layersMaxSeq.end()) { +// // if layer was not treated before, create pair for it +// layersMaxSeq[layer] = 0; +// // set sequence value to 0, as this is first sequence in layer +// startSeq = 0; +// } else { +// // if node sequence position was not checked use it for max sequence +// // selection sequence have to be incremented as max + 1, without any jumps +// if (startSeq > (layerFound->second + 1)) startSeq = layerFound->second + 1; + +// layerFound->second = +// (layerFound->second < startSeq && parent->second.sequence() < 0) +// ? startSeq +// : layerFound->second; +// } + +// // double check if the layer is higher and set node layer +// if (parent->second.layer() < layer) parent->second.setLayer(layer); +// // double check if sequence was init, if not set current sequence +// if (parent->second.sequence() < 0) parent->second.setSequence(startSeq); +// // set is node out-branching +// parent->second.setOutBranching(parent->second.connections().size() > 1); +// // set that node was processed, to avoid it double processing (only for some +// // cases it can be processed several times) +// parent->second.setProcessed(); + +// // if current node is out-branching it childs will be put to next layer +// if (parent->second.isOutBranching() && !parent->second.isLogic()) layer++; + +// // childs sequence position have to start from max defined sequence position +// // in layer or if it is first node in layer from 0 +// int seq = (layersMaxSeq.find(layer) == layersMaxSeq.end()) +// ? 0 +// : layersMaxSeq[layer]; +// // if parent is out-branching node sequence have to be increment +// // on the next stage the sequence value will be double checked with max per +// // layer todo check logic part maybe here have to be check operation class +// // (something likke Switch, If, While etc) probably for each of them could be +// // other behavior +// seq = (parent->second.isOutBranching() && !parent->second.isLogic()) ? seq + 1 +// : seq; + +// // loop via childs (connected nodes) +// for (const auto& id : parent->second.connections()) { +// // double check to avoid unstable behavior +// auto child = collection.find(id); +// if (child == collection.end()) return false; + +// // in case parent was not out-branching node but child is in branching it +// // will be put to next layer todo check logic part +// if (!parent->second.isOutBranching() && child->second.isInBranching() && +// !child->second.isLogic()) +// layer++; + +// // move in depth of connections +// layersSeqDefine(collection, id, layer, seq, layersMaxSeq); +// // increment sequence as childs are on the one layer in case if child was +// // not processed earlier todo check logic part +// if (!parent->second.isLogic()) seq++; +// } + +// return true; +// } + +// void OptimizedGraph::printOut() const { +// for (uint64_t o = 0; o < _onion.size(); o++) { +// const auto& layer = _onion.at(o); +// printf("Layer [%lu]\n", o); +// for (uint64_t l = 0; l < layer.width(); l++) layer.at(l).printOut(); +// } +// } + + } // namespace graph } // namespace sd diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 4dd898f94c8f..46c657d3c11b 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -37,992 +37,992 @@ class GraphAnalysisTests : public testing::Test { } }; -TEST_F(GraphAnalysisTests, basic_toposort_test_1) { - Graph graph; +// TEST_F(GraphAnalysisTests, basic_toposort_test_1) { +// Graph graph; - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // A +// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); +// // B +// graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); +// // C +// graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - Node a(sd::ops::multiply(), "multiply"); - Node b(sd::ops::add(), "add"); +// Node a(sd::ops::multiply(), "multiply"); +// Node b(sd::ops::add(), "add"); - graph.addNode(a, {"A", "B"}); - graph.addNode(b, {"multiply", "C"}); +// graph.addNode(a, {"A", "B"}); +// graph.addNode(b, {"multiply", "C"}); - // we just check that nodes were really added - ASSERT_EQ(2, graph.size()); +// // we just check that nodes were really added +// ASSERT_EQ(2, graph.size()); - auto optimized = graph.optimizedGraph(); +// auto optimized = graph.optimizedGraph(); - // we expect that OptimizedGraph has exactly 1 layer - ASSERT_EQ(1, optimized.layers()); +// // we expect that OptimizedGraph has exactly 1 layer +// ASSERT_EQ(1, optimized.layers()); - auto layer = optimized.layer(0); +// auto layer = optimized.layer(0); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer.width()); - auto sequence = layer[0]; +// // we expect layer has exactly 1 OpSequence +// ASSERT_EQ(1, layer.width()); +// auto sequence = layer[0]; - // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(2, sequence.length()); +// // we expect that OpSequence has exactly 2 ops +// ASSERT_EQ(2, sequence.length()); - ASSERT_EQ(4, sequence.at(0).protoContext().nodeId()); - ASSERT_EQ(5, sequence.at(1).protoContext().nodeId()); -} +// ASSERT_EQ(4, sequence.at(0).protoContext().nodeId()); +// ASSERT_EQ(5, sequence.at(1).protoContext().nodeId()); +// } -TEST_F(GraphAnalysisTests, basic_toposort_test_2) { - Graph graph; +// TEST_F(GraphAnalysisTests, basic_toposort_test_2) { +// Graph graph; - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // A +// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); +// // B +// graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); +// // C +// graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {4, 4, 4})); +// // D +// graph.addVariable("D", NDArrayFactory::create('c', {3}, {4, 4, 4})); - Node a(sd::ops::multiply(), "multiply"); - Node b(sd::ops::add(), "add"); - Node c(sd::ops::subtract(), "subtract"); +// Node a(sd::ops::multiply(), "multiply"); +// Node b(sd::ops::add(), "add"); +// Node c(sd::ops::subtract(), "subtract"); - graph.addNode(a, {"A", "B"}); - graph.addNode(b, {"multiply", "C"}); - graph.addNode(c, {"multiply", "D"}); +// graph.addNode(a, {"A", "B"}); +// graph.addNode(b, {"multiply", "C"}); +// graph.addNode(c, {"multiply", "D"}); - // we just check that nodes were really added - ASSERT_EQ(3, graph.size()); +// // we just check that nodes were really added +// ASSERT_EQ(3, graph.size()); - auto optimized = graph.optimizedGraph(); +// auto optimized = graph.optimizedGraph(); - // graph size must stay the same - ASSERT_EQ(3, graph.size()); +// // graph size must stay the same +// ASSERT_EQ(3, graph.size()); - // we expect that OptimizedGraph has exactly 2 layers - ASSERT_EQ(2, optimized.layers()); +// // we expect that OptimizedGraph has exactly 2 layers +// ASSERT_EQ(2, optimized.layers()); - // checking first layer first - auto layer0 = optimized.layer(0); +// // checking first layer first +// auto layer0 = optimized.layer(0); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer0.width()); - ; +// // we expect layer has exactly 1 OpSequence +// ASSERT_EQ(1, layer0.width()); +// ; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer0[0].length()); +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer0[0].length()); - ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); +// ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); - // checking second layer now - auto layer1 = optimized.layer(1); +// // checking second layer now +// auto layer1 = optimized.layer(1); - // we expect layer has exactly 2 OpSequences - ASSERT_EQ(2, layer1.width()); +// // we expect layer has exactly 2 OpSequences +// ASSERT_EQ(2, layer1.width()); - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); +// ASSERT_EQ(1, layer1[0].length()); +// ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); - ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); -} +// ASSERT_EQ(1, layer1[1].length()); +// ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); +// } -TEST_F(GraphAnalysisTests, basic_toposort_test_3) { - Graph graph; +// TEST_F(GraphAnalysisTests, basic_toposort_test_3) { +// Graph graph; - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // A +// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // B +// graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // C +// graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // D +// graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); - Node a(sd::ops::multiply(), "a"); - Node b(sd::ops::add(), "b"); - Node c(sd::ops::subtract(), "c"); - Node d(sd::ops::add(), "d"); - Node e(sd::ops::multiply(), "e"); +// Node a(sd::ops::multiply(), "a"); +// Node b(sd::ops::add(), "b"); +// Node c(sd::ops::subtract(), "c"); +// Node d(sd::ops::add(), "d"); +// Node e(sd::ops::multiply(), "e"); - graph.addNode(a, {"A", "B"}); - graph.addNode(b, {"a", "C"}); +// graph.addNode(a, {"A", "B"}); +// graph.addNode(b, {"a", "C"}); - graph.addNode(c, {"b", "D"}); - graph.addNode(d, {"b", "D"}); +// graph.addNode(c, {"b", "D"}); +// graph.addNode(d, {"b", "D"}); - graph.addNode(e, {"c", "d"}); +// graph.addNode(e, {"c", "d"}); - // we just check that nodes were really added - ASSERT_EQ(5, graph.size()); +// // we just check that nodes were really added +// ASSERT_EQ(5, graph.size()); - auto optimized = graph.optimizedGraph(); +// auto optimized = graph.optimizedGraph(); - // we expect that OptimizedGraph has exactly 3 layer - ASSERT_EQ(3, optimized.layers()); +// // we expect that OptimizedGraph has exactly 3 layer +// ASSERT_EQ(3, optimized.layers()); - // checking first layer first - auto layer0 = optimized.layer(0); +// // checking first layer first +// auto layer0 = optimized.layer(0); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer0.width()); - // auto sequence = layer0[0]; +// // we expect layer has exactly 1 OpSequence +// ASSERT_EQ(1, layer0.width()); +// // auto sequence = layer0[0]; - // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(2, layer0[0].length()); +// // we expect that OpSequence has exactly 2 ops +// ASSERT_EQ(2, layer0[0].length()); - ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); - ASSERT_EQ(6, layer0[0].at(1).protoContext().nodeId()); +// ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); +// ASSERT_EQ(6, layer0[0].at(1).protoContext().nodeId()); - // checking second layer now - auto layer1 = optimized.layer(1); +// // checking second layer now +// auto layer1 = optimized.layer(1); - // we expect layer has exactly 2 OpSequences - ASSERT_EQ(2, layer1.width()); +// // we expect layer has exactly 2 OpSequences +// ASSERT_EQ(2, layer1.width()); - // sequence = layer1[0]; +// // sequence = layer1[0]; - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); +// ASSERT_EQ(1, layer1[0].length()); +// ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); - // sequence = layer1[1]; +// // sequence = layer1[1]; - ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); +// ASSERT_EQ(1, layer1[1].length()); +// ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); - // checking last layer - auto layer2 = optimized.layer(2); +// // checking last layer +// auto layer2 = optimized.layer(2); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer2.width()); - // sequence = layer2[0]; +// // we expect layer has exactly 1 OpSequence +// ASSERT_EQ(1, layer2.width()); +// // sequence = layer2[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[0].length()); - ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); -} +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[0].length()); +// ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); +// } -TEST_F(GraphAnalysisTests, basic_toposort_test_4) { - Graph graph; +// TEST_F(GraphAnalysisTests, basic_toposort_test_4) { +// Graph graph; - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // A +// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // B +// graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // C +// graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // D +// graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // E - graph.addVariable("E", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // E +// graph.addVariable("E", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // F - graph.addVariable("F", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // F +// graph.addVariable("F", NDArrayFactory::create('c', {3}, {1, 1, 1})); - Node a1(sd::ops::multiply(), "a1"); - Node a2(sd::ops::add(), "a2"); +// Node a1(sd::ops::multiply(), "a1"); +// Node a2(sd::ops::add(), "a2"); - Node b1(sd::ops::subtract(), "b1"); - Node b2(sd::ops::add(), "b2"); - Node b3(sd::ops::multiply(), "b3"); +// Node b1(sd::ops::subtract(), "b1"); +// Node b2(sd::ops::add(), "b2"); +// Node b3(sd::ops::multiply(), "b3"); - Node d1(sd::ops::multiply(), "d1"); - Node d2(sd::ops::add(), "d2"); +// Node d1(sd::ops::multiply(), "d1"); +// Node d2(sd::ops::add(), "d2"); - Node e(sd::ops::subtract(), "e"); +// Node e(sd::ops::subtract(), "e"); - graph.addNode(a1, {"A", "B"}); - graph.addNode(a2, {"C", "D"}); +// graph.addNode(a1, {"A", "B"}); +// graph.addNode(a2, {"C", "D"}); - graph.addNode(b1, {"a1", "E"}); - graph.addNode(b2, {"a1", "a2"}); - graph.addNode(b3, {"a2", "F"}); +// graph.addNode(b1, {"a1", "E"}); +// graph.addNode(b2, {"a1", "a2"}); +// graph.addNode(b3, {"a2", "F"}); - graph.addNode(d1, {"b1", "b2"}); - graph.addNode(d2, {"b3", "b2"}); +// graph.addNode(d1, {"b1", "b2"}); +// graph.addNode(d2, {"b3", "b2"}); - graph.addNode(e, {"d1", "d2"}); +// graph.addNode(e, {"d1", "d2"}); - // we just check that nodes were really added - ASSERT_EQ(8, graph.size()); +// // we just check that nodes were really added +// ASSERT_EQ(8, graph.size()); - auto optimized = graph.optimizedGraph(); +// auto optimized = graph.optimizedGraph(); - // we expect that OptimizedGraph has exactly 4 layer - ASSERT_EQ(4, optimized.layers()); +// // we expect that OptimizedGraph has exactly 4 layer +// ASSERT_EQ(4, optimized.layers()); - // checking first layer first - auto layer0 = optimized.layer(0); +// // checking first layer first +// auto layer0 = optimized.layer(0); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(2, layer0.width()); +// // we expect layer has exactly 2 OpSequence +// ASSERT_EQ(2, layer0.width()); - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer0[0].length()); - ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer0[0].length()); +// ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer0[1].length()); - ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer0[1].length()); +// ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); - // checking second layer now - auto layer1 = optimized.layer(1); +// // checking second layer now +// auto layer1 = optimized.layer(1); - // we expect layer has exactly 3 OpSequences - ASSERT_EQ(3, layer1.width()); +// // we expect layer has exactly 3 OpSequences +// ASSERT_EQ(3, layer1.width()); - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(9, layer1[0].at(0).protoContext().nodeId()); +// ASSERT_EQ(1, layer1[0].length()); +// ASSERT_EQ(9, layer1[0].at(0).protoContext().nodeId()); - ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(10, layer1[1].at(0).protoContext().nodeId()); +// ASSERT_EQ(1, layer1[1].length()); +// ASSERT_EQ(10, layer1[1].at(0).protoContext().nodeId()); - ASSERT_EQ(1, layer1[2].length()); - ASSERT_EQ(11, layer1[2].at(0).protoContext().nodeId()); +// ASSERT_EQ(1, layer1[2].length()); +// ASSERT_EQ(11, layer1[2].at(0).protoContext().nodeId()); - auto layer2 = optimized.layer(2); +// auto layer2 = optimized.layer(2); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(2, layer2.width()); +// // we expect layer has exactly 2 OpSequence +// ASSERT_EQ(2, layer2.width()); - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[0].length()); - ASSERT_EQ(12, layer2[0].at(0).protoContext().nodeId()); +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[0].length()); +// ASSERT_EQ(12, layer2[0].at(0).protoContext().nodeId()); - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[1].length()); - ASSERT_EQ(13, layer2[1].at(0).protoContext().nodeId()); +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[1].length()); +// ASSERT_EQ(13, layer2[1].at(0).protoContext().nodeId()); - // checking last layer - auto layer3 = optimized.layer(3); +// // checking last layer +// auto layer3 = optimized.layer(3); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer3.width()); +// // we expect layer has exactly 1 OpSequence +// ASSERT_EQ(1, layer3.width()); - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer3[0].length()); - ASSERT_EQ(14, layer3[0].at(0).protoContext().nodeId()); -} +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer3[0].length()); +// ASSERT_EQ(14, layer3[0].at(0).protoContext().nodeId()); +// } -TEST_F(GraphAnalysisTests, basic_toposort_test_5) { - Graph graph; +// TEST_F(GraphAnalysisTests, basic_toposort_test_5) { +// Graph graph; - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // A +// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // B +// graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // C +// graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // D +// graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); - Node a(sd::ops::multiply(), "a"); - Node b(sd::ops::add(), "b"); - Node c(sd::ops::subtract(), "c"); - Node d(sd::ops::add(), "d"); - Node e(sd::ops::multiply(), "e"); - Node f(sd::ops::multiply(), "f"); +// Node a(sd::ops::multiply(), "a"); +// Node b(sd::ops::add(), "b"); +// Node c(sd::ops::subtract(), "c"); +// Node d(sd::ops::add(), "d"); +// Node e(sd::ops::multiply(), "e"); +// Node f(sd::ops::multiply(), "f"); - Node g(sd::ops::multiply(), "g"); - Node h(sd::ops::multiply(), "h"); +// Node g(sd::ops::multiply(), "g"); +// Node h(sd::ops::multiply(), "h"); - graph.addNode(a, {"A", "B"}); - graph.addNode(b, {"C", "D"}); +// graph.addNode(a, {"A", "B"}); +// graph.addNode(b, {"C", "D"}); - graph.addNode(c, {"a", "b"}); - graph.addNode(d, {"a", "b"}); +// graph.addNode(c, {"a", "b"}); +// graph.addNode(d, {"a", "b"}); - graph.addNode(e, {"c", "d"}); - graph.addNode(f, {"c", "d"}); +// graph.addNode(e, {"c", "d"}); +// graph.addNode(f, {"c", "d"}); - graph.addNode(g, {"c", "e"}); - graph.addNode(h, {"d", "f"}); +// graph.addNode(g, {"c", "e"}); +// graph.addNode(h, {"d", "f"}); - // we just check that nodes were really added - ASSERT_EQ(8, graph.size()); +// // we just check that nodes were really added +// ASSERT_EQ(8, graph.size()); - auto optimized = graph.optimizedGraph(); +// auto optimized = graph.optimizedGraph(); - // we expect that OptimizedGraph has exactly 3 layer - ASSERT_EQ(4, optimized.layers()); +// // we expect that OptimizedGraph has exactly 3 layer +// ASSERT_EQ(4, optimized.layers()); - // checking first layer first - auto layer0 = optimized.layer(0); +// // checking first layer first +// auto layer0 = optimized.layer(0); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(2, layer0.width()); - // auto sequence = layer0[0]; +// // we expect layer has exactly 2 OpSequence +// ASSERT_EQ(2, layer0.width()); +// // auto sequence = layer0[0]; - // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, layer0[0].length()); +// // we expect that OpSequence has exactly 2 ops +// ASSERT_EQ(1, layer0[0].length()); - ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); +// ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); - // sequence = layer0[1]; +// // sequence = layer0[1]; - // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, layer0[1].length()); - ASSERT_EQ(6, layer0[1].at(0).protoContext().nodeId()); +// // we expect that OpSequence has exactly 2 ops +// ASSERT_EQ(1, layer0[1].length()); +// ASSERT_EQ(6, layer0[1].at(0).protoContext().nodeId()); - // checking second layer now - auto layer1 = optimized.layer(1); +// // checking second layer now +// auto layer1 = optimized.layer(1); - // we expect layer has exactly 2 OpSequences - ASSERT_EQ(2, layer1.width()); +// // we expect layer has exactly 2 OpSequences +// ASSERT_EQ(2, layer1.width()); - // sequence = layer1[0]; +// // sequence = layer1[0]; - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); +// ASSERT_EQ(1, layer1[0].length()); +// ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); - // sequence = layer1[1]; +// // sequence = layer1[1]; - ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); +// ASSERT_EQ(1, layer1[1].length()); +// ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); - // checking before last layer - auto layer2 = optimized.layer(2); +// // checking before last layer +// auto layer2 = optimized.layer(2); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(2, layer2.width()); - // sequence = layer2[0]; +// // we expect layer has exactly 2 OpSequence +// ASSERT_EQ(2, layer2.width()); +// // sequence = layer2[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[0].length()); - ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); - // sequence = layer2[1]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[0].length()); +// ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); +// // sequence = layer2[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[1].length()); - ASSERT_EQ(10, layer2[1].at(0).protoContext().nodeId()); +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[1].length()); +// ASSERT_EQ(10, layer2[1].at(0).protoContext().nodeId()); - // checking last layer - auto layer3 = optimized.layer(3); +// // checking last layer +// auto layer3 = optimized.layer(3); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(2, layer3.width()); - // sequence = layer3[0]; +// // we expect layer has exactly 2 OpSequence +// ASSERT_EQ(2, layer3.width()); +// // sequence = layer3[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer3[0].length()); - ASSERT_EQ(11, layer3[0].at(0).protoContext().nodeId()); +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer3[0].length()); +// ASSERT_EQ(11, layer3[0].at(0).protoContext().nodeId()); - // sequence = layer3[1]; +// // sequence = layer3[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer3[1].length()); - ASSERT_EQ(12, layer3[1].at(0).protoContext().nodeId()); -} +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer3[1].length()); +// ASSERT_EQ(12, layer3[1].at(0).protoContext().nodeId()); +// } -TEST_F(GraphAnalysisTests, basic_toposort_test_6) { - Graph graph; +// TEST_F(GraphAnalysisTests, basic_toposort_test_6) { +// Graph graph; - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // A +// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // B +// graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // C +// graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // D +// graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // E - graph.addVariable("E", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // E +// graph.addVariable("E", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // F - graph.addVariable("F", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // F +// graph.addVariable("F", NDArrayFactory::create('c', {3}, {1, 1, 1})); - Node a(sd::ops::multiply(), "a"); - Node b1(sd::ops::add(), "b1"); - Node b2(sd::ops::subtract(), "b2"); +// Node a(sd::ops::multiply(), "a"); +// Node b1(sd::ops::add(), "b1"); +// Node b2(sd::ops::subtract(), "b2"); - Node c1(sd::ops::add(), "c1"); - Node c2(sd::ops::multiply(), "c2"); - Node c3(sd::ops::subtract(), "c3"); +// Node c1(sd::ops::add(), "c1"); +// Node c2(sd::ops::multiply(), "c2"); +// Node c3(sd::ops::subtract(), "c3"); - Node d1(sd::ops::multiply(), "d1"); - Node d2(sd::ops::multiply(), "d2"); +// Node d1(sd::ops::multiply(), "d1"); +// Node d2(sd::ops::multiply(), "d2"); - Node e(sd::ops::add(), "e"); +// Node e(sd::ops::add(), "e"); - graph.addNode(a, {"A", "B"}); +// graph.addNode(a, {"A", "B"}); - graph.addNode(b1, {"a", "C"}); - graph.addNode(b2, {"a", "D"}); +// graph.addNode(b1, {"a", "C"}); +// graph.addNode(b2, {"a", "D"}); - graph.addNode(c1, {"b1", "E"}); - graph.addNode(c2, {"b1", "b2"}); - graph.addNode(c3, {"b2", "F"}); +// graph.addNode(c1, {"b1", "E"}); +// graph.addNode(c2, {"b1", "b2"}); +// graph.addNode(c3, {"b2", "F"}); - graph.addNode(d1, {"c1", "c2"}); - graph.addNode(d2, {"c2", "c3"}); +// graph.addNode(d1, {"c1", "c2"}); +// graph.addNode(d2, {"c2", "c3"}); - graph.addNode(e, {"d1", "d2"}); +// graph.addNode(e, {"d1", "d2"}); - // we just check that nodes were really added - ASSERT_EQ(9, graph.size()); +// // we just check that nodes were really added +// ASSERT_EQ(9, graph.size()); - auto optimized = graph.optimizedGraph(); +// auto optimized = graph.optimizedGraph(); - // we expect that OptimizedGraph has exactly 3 layer - ASSERT_EQ(5, optimized.layers()); +// // we expect that OptimizedGraph has exactly 3 layer +// ASSERT_EQ(5, optimized.layers()); - // checking first layer first - auto layer0 = optimized.layer(0); +// // checking first layer first +// auto layer0 = optimized.layer(0); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer0.width()); +// // we expect layer has exactly 1 OpSequence +// ASSERT_EQ(1, layer0.width()); - // auto sequence = layer0[0]; - // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, layer0[0].length()); - ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); +// // auto sequence = layer0[0]; +// // we expect that OpSequence has exactly 2 ops +// ASSERT_EQ(1, layer0[0].length()); +// ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); - // checking second layer now - auto layer1 = optimized.layer(1); +// // checking second layer now +// auto layer1 = optimized.layer(1); - // we expect layer has exactly 2 OpSequences - ASSERT_EQ(2, layer1.width()); +// // we expect layer has exactly 2 OpSequences +// ASSERT_EQ(2, layer1.width()); - // sequence = layer1[0]; - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(8, layer1[0].at(0).protoContext().nodeId()); +// // sequence = layer1[0]; +// ASSERT_EQ(1, layer1[0].length()); +// ASSERT_EQ(8, layer1[0].at(0).protoContext().nodeId()); - // sequence = layer1[1]; - ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(9, layer1[1].at(0).protoContext().nodeId()); +// // sequence = layer1[1]; +// ASSERT_EQ(1, layer1[1].length()); +// ASSERT_EQ(9, layer1[1].at(0).protoContext().nodeId()); - // checking midle layer - auto layer2 = optimized.layer(2); +// // checking midle layer +// auto layer2 = optimized.layer(2); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(3, layer2.width()); +// // we expect layer has exactly 2 OpSequence +// ASSERT_EQ(3, layer2.width()); - // sequence = layer2[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[0].length()); - ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); +// // sequence = layer2[0]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[0].length()); +// ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); - // sequence = layer2[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[1].length()); - ASSERT_EQ(11, layer2[1].at(0).protoContext().nodeId()); +// // sequence = layer2[1]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[1].length()); +// ASSERT_EQ(11, layer2[1].at(0).protoContext().nodeId()); - // sequence = layer2[2]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[2].length()); - ASSERT_EQ(12, layer2[2].at(0).protoContext().nodeId()); +// // sequence = layer2[2]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[2].length()); +// ASSERT_EQ(12, layer2[2].at(0).protoContext().nodeId()); - // checking before last layer - auto layer3 = optimized.layer(3); +// // checking before last layer +// auto layer3 = optimized.layer(3); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(2, layer3.width()); - // sequence = layer3[0]; +// // we expect layer has exactly 2 OpSequence +// ASSERT_EQ(2, layer3.width()); +// // sequence = layer3[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer3[0].length()); - ASSERT_EQ(13, layer3[0].at(0).protoContext().nodeId()); +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer3[0].length()); +// ASSERT_EQ(13, layer3[0].at(0).protoContext().nodeId()); - // sequence = layer3[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer3[1].length()); - ASSERT_EQ(14, layer3[1].at(0).protoContext().nodeId()); +// // sequence = layer3[1]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer3[1].length()); +// ASSERT_EQ(14, layer3[1].at(0).protoContext().nodeId()); - // checking last layer - auto layer4 = optimized.layer(4); +// // checking last layer +// auto layer4 = optimized.layer(4); - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(1, layer4.width()); - // sequence = layer4[0]; +// // we expect layer has exactly 2 OpSequence +// ASSERT_EQ(1, layer4.width()); +// // sequence = layer4[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer4[0].length()); - ASSERT_EQ(15, layer4[0].at(0).protoContext().nodeId()); -} +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer4[0].length()); +// ASSERT_EQ(15, layer4[0].at(0).protoContext().nodeId()); +// } -TEST_F(GraphAnalysisTests, basic_toposort_test_7) { - Graph graph; +// TEST_F(GraphAnalysisTests, basic_toposort_test_7) { +// Graph graph; - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // A +// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // B +// graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // C +// graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); - Node a(sd::ops::multiply(), "a"); - Node b(sd::ops::add(), "b"); - Node c(sd::ops::subtract(), "c"); - Node d(sd::ops::add(), "d"); - Node e(sd::ops::multiply(), "e"); +// Node a(sd::ops::multiply(), "a"); +// Node b(sd::ops::add(), "b"); +// Node c(sd::ops::subtract(), "c"); +// Node d(sd::ops::add(), "d"); +// Node e(sd::ops::multiply(), "e"); - graph.addNode(a, {"A", "B"}); - graph.addNode(b, {"a", "C"}); +// graph.addNode(a, {"A", "B"}); +// graph.addNode(b, {"a", "C"}); - graph.addNode(c, {"a", "b"}); - graph.addNode(d, {"b", "c"}); +// graph.addNode(c, {"a", "b"}); +// graph.addNode(d, {"b", "c"}); - graph.addNode(e, {"b", "c", "d"}); +// graph.addNode(e, {"b", "c", "d"}); - // we just check that nodes were really added - ASSERT_EQ(5, graph.size()); +// // we just check that nodes were really added +// ASSERT_EQ(5, graph.size()); - auto optimized = graph.optimizedGraph(); +// auto optimized = graph.optimizedGraph(); - // we expect that OptimizedGraph has exactly 3 layer - ASSERT_EQ(5, optimized.layers()); +// // we expect that OptimizedGraph has exactly 3 layer +// ASSERT_EQ(5, optimized.layers()); - // checking first layer first - auto layer0 = optimized.layer(0); +// // checking first layer first +// auto layer0 = optimized.layer(0); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer0.width()); - // auto sequence = layer0[0]; +// // we expect layer has exactly 1 OpSequence +// ASSERT_EQ(1, layer0.width()); +// // auto sequence = layer0[0]; - // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, layer0[0].length()); - ASSERT_EQ(4, layer0[0].at(0).protoContext().nodeId()); +// // we expect that OpSequence has exactly 2 ops +// ASSERT_EQ(1, layer0[0].length()); +// ASSERT_EQ(4, layer0[0].at(0).protoContext().nodeId()); - // checking second layer now - auto layer1 = optimized.layer(1); +// // checking second layer now +// auto layer1 = optimized.layer(1); - // we expect layer has exactly 2 OpSequences - ASSERT_EQ(1, layer1.width()); - // sequence = layer1[0]; +// // we expect layer has exactly 2 OpSequences +// ASSERT_EQ(1, layer1.width()); +// // sequence = layer1[0]; - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(5, layer1[0].at(0).protoContext().nodeId()); +// ASSERT_EQ(1, layer1[0].length()); +// ASSERT_EQ(5, layer1[0].at(0).protoContext().nodeId()); - // checking layer 2 - auto layer2 = optimized.layer(2); +// // checking layer 2 +// auto layer2 = optimized.layer(2); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer2.width()); - // sequence = layer2[0]; +// // we expect layer has exactly 1 OpSequence +// ASSERT_EQ(1, layer2.width()); +// // sequence = layer2[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[0].length()); - ASSERT_EQ(6, layer2[0].at(0).protoContext().nodeId()); +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[0].length()); +// ASSERT_EQ(6, layer2[0].at(0).protoContext().nodeId()); - // checking layer 3 - auto layer3 = optimized.layer(3); +// // checking layer 3 +// auto layer3 = optimized.layer(3); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer3.width()); - // sequence = layer3[0]; +// // we expect layer has exactly 1 OpSequence +// ASSERT_EQ(1, layer3.width()); +// // sequence = layer3[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer3[0].length()); - ASSERT_EQ(7, layer3[0].at(0).protoContext().nodeId()); +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer3[0].length()); +// ASSERT_EQ(7, layer3[0].at(0).protoContext().nodeId()); - // checking layer 3 - auto layer4 = optimized.layer(4); +// // checking layer 3 +// auto layer4 = optimized.layer(4); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer4.width()); - // sequence = layer4[0]; +// // we expect layer has exactly 1 OpSequence +// ASSERT_EQ(1, layer4.width()); +// // sequence = layer4[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer4[0].length()); - ASSERT_EQ(8, layer4[0].at(0).protoContext().nodeId()); -} +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer4[0].length()); +// ASSERT_EQ(8, layer4[0].at(0).protoContext().nodeId()); +// } -TEST_F(GraphAnalysisTests, basic_toposort_test_8) { - Graph graph; +// TEST_F(GraphAnalysisTests, basic_toposort_test_8) { +// Graph graph; - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // A +// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // B +// graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // C +// graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // D +// graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // E - graph.addVariable("E", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // E +// graph.addVariable("E", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // F - graph.addVariable("F", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // F +// graph.addVariable("F", NDArrayFactory::create('c', {3}, {1, 1, 1})); - Node a1(sd::ops::multiply(), "a1"); - Node a2(sd::ops::add(), "a2"); - Node a3(sd::ops::add(), "a3"); +// Node a1(sd::ops::multiply(), "a1"); +// Node a2(sd::ops::add(), "a2"); +// Node a3(sd::ops::add(), "a3"); - Node b1(sd::ops::subtract(), "b1"); - Node b2(sd::ops::add(), "b2"); - Node b3(sd::ops::multiply(), "b3"); +// Node b1(sd::ops::subtract(), "b1"); +// Node b2(sd::ops::add(), "b2"); +// Node b3(sd::ops::multiply(), "b3"); - graph.addNode(a1, {"A", "B"}); - graph.addNode(a2, {"C", "D"}); - graph.addNode(a3, {"E", "F"}); +// graph.addNode(a1, {"A", "B"}); +// graph.addNode(a2, {"C", "D"}); +// graph.addNode(a3, {"E", "F"}); - graph.addNode(b1, {"a1", "a2"}); - graph.addNode(b2, {"a1", "a2", "a3"}); - graph.addNode(b3, {"a2", "a3"}); +// graph.addNode(b1, {"a1", "a2"}); +// graph.addNode(b2, {"a1", "a2", "a3"}); +// graph.addNode(b3, {"a2", "a3"}); - // we just check that nodes were really added - ASSERT_EQ(6, graph.size()); +// // we just check that nodes were really added +// ASSERT_EQ(6, graph.size()); - auto optimized = graph.optimizedGraph(); +// auto optimized = graph.optimizedGraph(); - // we expect that OptimizedGraph has exactly 2 layer - ASSERT_EQ(2, optimized.layers()); +// // we expect that OptimizedGraph has exactly 2 layer +// ASSERT_EQ(2, optimized.layers()); - // checking first layer first - auto layer0 = optimized.layer(0); +// // checking first layer first +// auto layer0 = optimized.layer(0); - // we expect layer has exactly 3 OpSequence - ASSERT_EQ(3, layer0.width()); - // auto sequence = layer0[0]; +// // we expect layer has exactly 3 OpSequence +// ASSERT_EQ(3, layer0.width()); +// // auto sequence = layer0[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer0[0].length()); - ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer0[0].length()); +// ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); - // sequence = layer0[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer0[1].length()); - ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); +// // sequence = layer0[1]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer0[1].length()); +// ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); - // sequence = layer0[2]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer0[2].length()); - ASSERT_EQ(9, layer0[2].at(0).protoContext().nodeId()); +// // sequence = layer0[2]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer0[2].length()); +// ASSERT_EQ(9, layer0[2].at(0).protoContext().nodeId()); - // checking second layer now - auto layer1 = optimized.layer(1); +// // checking second layer now +// auto layer1 = optimized.layer(1); - // we expect layer has exactly 3 OpSequences - ASSERT_EQ(3, layer1.width()); +// // we expect layer has exactly 3 OpSequences +// ASSERT_EQ(3, layer1.width()); - // sequence = layer1[0]; - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(10, layer1[0].at(0).protoContext().nodeId()); +// // sequence = layer1[0]; +// ASSERT_EQ(1, layer1[0].length()); +// ASSERT_EQ(10, layer1[0].at(0).protoContext().nodeId()); - // sequence = layer1[1]; +// // sequence = layer1[1]; - ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(11, layer1[1].at(0).protoContext().nodeId()); +// ASSERT_EQ(1, layer1[1].length()); +// ASSERT_EQ(11, layer1[1].at(0).protoContext().nodeId()); - // sequence = layer1[2]; - ASSERT_EQ(1, layer1[2].length()); - ASSERT_EQ(12, layer1[2].at(0).protoContext().nodeId()); -} +// // sequence = layer1[2]; +// ASSERT_EQ(1, layer1[2].length()); +// ASSERT_EQ(12, layer1[2].at(0).protoContext().nodeId()); +// } -TEST_F(GraphAnalysisTests, basic_toposort_test_9) { - // start graph +// TEST_F(GraphAnalysisTests, basic_toposort_test_9) { +// // start graph - Graph graph; +// Graph graph; - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); +// // A +// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); +// // B +// graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); +// // C +// graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); +// // D +// graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); - Node a(sd::ops::multiply(), "a"); +// Node a(sd::ops::multiply(), "a"); - Node b1(sd::ops::add(), "b1"); - Node b2(sd::ops::multiply(), "b2"); - Node b3(sd::ops::subtract(), "b3"); - Node b4(sd::ops::Pow(), "b4"); +// Node b1(sd::ops::add(), "b1"); +// Node b2(sd::ops::multiply(), "b2"); +// Node b3(sd::ops::subtract(), "b3"); +// Node b4(sd::ops::Pow(), "b4"); - Node c1(sd::ops::Pow(), "c1"); - Node c2(sd::ops::subtract(), "c2"); - Node c3(sd::ops::multiply(), "c3"); - Node c4(sd::ops::add(), "c4"); +// Node c1(sd::ops::Pow(), "c1"); +// Node c2(sd::ops::subtract(), "c2"); +// Node c3(sd::ops::multiply(), "c3"); +// Node c4(sd::ops::add(), "c4"); - Node c5(sd::ops::Pow(), "c5"); - Node c6(sd::ops::subtract(), "c6"); - Node c7(sd::ops::multiply(), "c7"); - Node c8(sd::ops::add(), "c8"); +// Node c5(sd::ops::Pow(), "c5"); +// Node c6(sd::ops::subtract(), "c6"); +// Node c7(sd::ops::multiply(), "c7"); +// Node c8(sd::ops::add(), "c8"); - graph.addNode(a, {"A", "B"}); +// graph.addNode(a, {"A", "B"}); - graph.addNode(b1, {"a", "C"}); - graph.addNode(b2, {"a", "C"}); - graph.addNode(b3, {"a", "C"}); - graph.addNode(b4, {"a", "C"}); +// graph.addNode(b1, {"a", "C"}); +// graph.addNode(b2, {"a", "C"}); +// graph.addNode(b3, {"a", "C"}); +// graph.addNode(b4, {"a", "C"}); - graph.addNode(c1, {"b1", "D"}); - graph.addNode(c2, {"b2", "D"}); - graph.addNode(c3, {"b3", "D"}); - graph.addNode(c4, {"b4", "D"}); +// graph.addNode(c1, {"b1", "D"}); +// graph.addNode(c2, {"b2", "D"}); +// graph.addNode(c3, {"b3", "D"}); +// graph.addNode(c4, {"b4", "D"}); - graph.addNode(c5, {"b1", "D"}); - graph.addNode(c6, {"b2", "D"}); - graph.addNode(c7, {"b3", "D"}); - graph.addNode(c8, {"b4", "D"}); +// graph.addNode(c5, {"b1", "D"}); +// graph.addNode(c6, {"b2", "D"}); +// graph.addNode(c7, {"b3", "D"}); +// graph.addNode(c8, {"b4", "D"}); - // we just check that nodes were really added - ASSERT_EQ(13, graph.size()); +// // we just check that nodes were really added +// ASSERT_EQ(13, graph.size()); - auto optimized = graph.optimizedGraph(); +// auto optimized = graph.optimizedGraph(); - // we expect that OptimizedGraph has exactly 1 layer - ASSERT_EQ(3, optimized.layers()); +// // we expect that OptimizedGraph has exactly 1 layer +// ASSERT_EQ(3, optimized.layers()); - auto layer = optimized.layer(0); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer.width()); - // auto sequence = layer[0]; +// auto layer = optimized.layer(0); +// // we expect layer has exactly 1 OpSequence +// ASSERT_EQ(1, layer.width()); +// // auto sequence = layer[0]; - // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, layer[0].length()); - ASSERT_EQ(5, layer[0].at(0).protoContext().nodeId()); +// // we expect that OpSequence has exactly 2 ops +// ASSERT_EQ(1, layer[0].length()); +// ASSERT_EQ(5, layer[0].at(0).protoContext().nodeId()); - auto layer1 = optimized.layer(1); - // we expect layer has exactly 4 OpSequence - ASSERT_EQ(4, layer1.width()); - // sequence = layer1[0]; +// auto layer1 = optimized.layer(1); +// // we expect layer has exactly 4 OpSequence +// ASSERT_EQ(4, layer1.width()); +// // sequence = layer1[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer1[0].length()); +// ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); - // sequence = layer1[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); +// // sequence = layer1[1]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer1[1].length()); +// ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); - // sequence = layer1[2]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer1[2].length()); - ASSERT_EQ(8, layer1[2].at(0).protoContext().nodeId()); +// // sequence = layer1[2]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer1[2].length()); +// ASSERT_EQ(8, layer1[2].at(0).protoContext().nodeId()); - // sequence = layer1[3]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer1[3].length()); - ASSERT_EQ(9, layer1[3].at(0).protoContext().nodeId()); +// // sequence = layer1[3]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer1[3].length()); +// ASSERT_EQ(9, layer1[3].at(0).protoContext().nodeId()); - auto layer2 = optimized.layer(2); - // we expect layer has exactly 4 OpSequence - ASSERT_EQ(8, layer2.width()); - // sequence = layer2[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[0].length()); - ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); - - // sequence = layer2[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[1].length()); - ASSERT_EQ(14, layer2[1].at(0).protoContext().nodeId()); - - // sequence = layer2[2]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[2].length()); - ASSERT_EQ(11, layer2[2].at(0).protoContext().nodeId()); - - // sequence = layer2[3]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[3].length()); - ASSERT_EQ(15, layer2[3].at(0).protoContext().nodeId()); - - // sequence = layer2[4]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[4].length()); - ASSERT_EQ(12, layer2[4].at(0).protoContext().nodeId()); - - // sequence = layer2[5]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[5].length()); - ASSERT_EQ(16, layer2[5].at(0).protoContext().nodeId()); - - // sequence = layer2[6]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[6].length()); - ASSERT_EQ(13, layer2[6].at(0).protoContext().nodeId()); - - // sequence = layer2[7]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[7].length()); - ASSERT_EQ(17, layer2[7].at(0).protoContext().nodeId()); -} - -TEST_F(GraphAnalysisTests, basic_toposort_test_10) { - Graph graph; - - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); - - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); - - Node a(sd::ops::multiply(), "a"); - Node b(sd::ops::add(), "b"); - Node c(sd::ops::multiply(), "c"); - Node d(sd::ops::subtract(), "d"); - - graph.addNode(a, {"A", "B"}); - graph.addNode(b, {"a", "C"}); - graph.addNode(c, {"a", "D"}); - graph.addNode(d, {"a", "b", "c"}); - - // we just check that nodes were really added - ASSERT_EQ(4, graph.size()); - - auto optimized = graph.optimizedGraph(); - - // we expect that OptimizedGraph has exactly 1 layer - ASSERT_EQ(3, optimized.layers()); - - auto layer = optimized.layer(0); - - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer.width()); - auto sequence = layer[0]; - - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); - - auto layer1 = optimized.layer(1); - - // we expect layer has exactly 2 OpSequence - ASSERT_EQ(2, layer1.width()); - sequence = layer1[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); - sequence = layer1[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); - - auto layer2 = optimized.layer(2); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer2.width()); - sequence = layer2[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); -} - -TEST_F(GraphAnalysisTests, basic_toposort_test_11) { - Graph graph; - - // A - graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); - - // B - graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); - - // C - graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - - // D - graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); - - Node a(sd::ops::multiply(), "a"); - Node b(sd::ops::add(), "b"); - Node c(sd::ops::multiply(), "c"); - Node d(sd::ops::subtract(), "d"); - - graph.addNode(a, {"A", "B"}); - graph.addNode(b, {"A", "C"}); - graph.addNode(c, {"B", "D"}); - graph.addNode(d, {"C", "D"}); - - // we just check that nodes were really added - ASSERT_EQ(4, graph.size()); - - auto optimized = graph.optimizedGraph(); - - // we expect that OptimizedGraph has exactly 1 layer - ASSERT_EQ(1, optimized.layers()); - - auto layer = optimized.layer(0); - - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(4, layer.width()); - auto sequence = layer[0]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); - sequence = layer[1]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); - sequence = layer[2]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); - sequence = layer[3]; - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); -} - -TEST_F(GraphAnalysisTests, test_cond_1) { - auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); - - auto optimized = graph.optimizedGraph(); - /* - some infor that would be useful for implementation - currently on optimization graph is passing next data - - Node name: cond/switch_f; ID: 11; Input: 9, 0; Operation type: 21; Operation - class: -1719689536 Node name: cond/switch_t; ID: 10; Input: 9, 1; Operation - type: 21; Operation class: -1719689536 Node name: cond/Switch; ID: 9; - Input: 1, 0; Operation type: 119; Operation class: -1719689536 Node name: - cond/Switch; ID: 9; Input: 6, 0; Operation type: 119; Operation class: - -1719689536 Node name: cond/Merge; ID: 8; Input: 5, 0; Operation type: - 119; Operation class: -1719689536 Node name: cond/Merge; ID: 8; Input: 7, - 0; Operation type: 119; Operation class: -1719689536 Node name: in_0/read; ID: - 6; Input: 1, 0; Operation type: 21; Operation class: -1719689536 Node name: - cond/LinSpace; ID: 7; Input: 2, 0; Operation type: 21; Operation class: - -1719689536 Node name: cond/LinSpace; ID: 7; Input: 3, 0; Operation type: 21; - Operation class: -1719689536 Node name: cond/LinSpace; ID: 7; Input: 4, 0; - Operation type: 21; Operation class: -1719689536 - - as it can be seen cond/LinSpace is not connected with any switch node(s) that - causes wrong results of optimization. also maybe to cover all conditional - operations will be need "Operation class", but this have to discovered deeper. - - All above is true for test_cond_2 - */ -} - -TEST_F(GraphAnalysisTests, test_cond_2) { - auto graph = Graph::fromFlatBuffers("resources/cond_false.fb"); -} +// auto layer2 = optimized.layer(2); +// // we expect layer has exactly 4 OpSequence +// ASSERT_EQ(8, layer2.width()); +// // sequence = layer2[0]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[0].length()); +// ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); + +// // sequence = layer2[1]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[1].length()); +// ASSERT_EQ(14, layer2[1].at(0).protoContext().nodeId()); + +// // sequence = layer2[2]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[2].length()); +// ASSERT_EQ(11, layer2[2].at(0).protoContext().nodeId()); + +// // sequence = layer2[3]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[3].length()); +// ASSERT_EQ(15, layer2[3].at(0).protoContext().nodeId()); + +// // sequence = layer2[4]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[4].length()); +// ASSERT_EQ(12, layer2[4].at(0).protoContext().nodeId()); + +// // sequence = layer2[5]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[5].length()); +// ASSERT_EQ(16, layer2[5].at(0).protoContext().nodeId()); + +// // sequence = layer2[6]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[6].length()); +// ASSERT_EQ(13, layer2[6].at(0).protoContext().nodeId()); + +// // sequence = layer2[7]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, layer2[7].length()); +// ASSERT_EQ(17, layer2[7].at(0).protoContext().nodeId()); +// } + +// TEST_F(GraphAnalysisTests, basic_toposort_test_10) { +// Graph graph; + +// // A +// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + +// // B +// graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + +// // C +// graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + +// // D +// graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); + +// Node a(sd::ops::multiply(), "a"); +// Node b(sd::ops::add(), "b"); +// Node c(sd::ops::multiply(), "c"); +// Node d(sd::ops::subtract(), "d"); + +// graph.addNode(a, {"A", "B"}); +// graph.addNode(b, {"a", "C"}); +// graph.addNode(c, {"a", "D"}); +// graph.addNode(d, {"a", "b", "c"}); + +// // we just check that nodes were really added +// ASSERT_EQ(4, graph.size()); + +// auto optimized = graph.optimizedGraph(); + +// // we expect that OptimizedGraph has exactly 1 layer +// ASSERT_EQ(3, optimized.layers()); + +// auto layer = optimized.layer(0); + +// // we expect layer has exactly 1 OpSequence +// ASSERT_EQ(1, layer.width()); +// auto sequence = layer[0]; + +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, sequence.length()); +// ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); + +// auto layer1 = optimized.layer(1); + +// // we expect layer has exactly 2 OpSequence +// ASSERT_EQ(2, layer1.width()); +// sequence = layer1[0]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, sequence.length()); +// ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); +// sequence = layer1[1]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, sequence.length()); +// ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); + +// auto layer2 = optimized.layer(2); +// // we expect layer has exactly 1 OpSequence +// ASSERT_EQ(1, layer2.width()); +// sequence = layer2[0]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, sequence.length()); +// ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); +// } + +// TEST_F(GraphAnalysisTests, basic_toposort_test_11) { +// Graph graph; + +// // A +// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + +// // B +// graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + +// // C +// graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + +// // D +// graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); + +// Node a(sd::ops::multiply(), "a"); +// Node b(sd::ops::add(), "b"); +// Node c(sd::ops::multiply(), "c"); +// Node d(sd::ops::subtract(), "d"); + +// graph.addNode(a, {"A", "B"}); +// graph.addNode(b, {"A", "C"}); +// graph.addNode(c, {"B", "D"}); +// graph.addNode(d, {"C", "D"}); + +// // we just check that nodes were really added +// ASSERT_EQ(4, graph.size()); + +// auto optimized = graph.optimizedGraph(); + +// // we expect that OptimizedGraph has exactly 1 layer +// ASSERT_EQ(1, optimized.layers()); + +// auto layer = optimized.layer(0); + +// // we expect layer has exactly 1 OpSequence +// ASSERT_EQ(4, layer.width()); +// auto sequence = layer[0]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, sequence.length()); +// ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); +// sequence = layer[1]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, sequence.length()); +// ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); +// sequence = layer[2]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, sequence.length()); +// ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); +// sequence = layer[3]; +// // we expect that OpSequence has exactly 1 ops +// ASSERT_EQ(1, sequence.length()); +// ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); +// } + +// TEST_F(GraphAnalysisTests, test_cond_1) { +// auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); + +// auto optimized = graph.optimizedGraph(); +// /* +// some infor that would be useful for implementation +// currently on optimization graph is passing next data + +// Node name: cond/switch_f; ID: 11; Input: 9, 0; Operation type: 21; Operation +// class: -1719689536 Node name: cond/switch_t; ID: 10; Input: 9, 1; Operation +// type: 21; Operation class: -1719689536 Node name: cond/Switch; ID: 9; +// Input: 1, 0; Operation type: 119; Operation class: -1719689536 Node name: +// cond/Switch; ID: 9; Input: 6, 0; Operation type: 119; Operation class: +// -1719689536 Node name: cond/Merge; ID: 8; Input: 5, 0; Operation type: +// 119; Operation class: -1719689536 Node name: cond/Merge; ID: 8; Input: 7, +// 0; Operation type: 119; Operation class: -1719689536 Node name: in_0/read; ID: +// 6; Input: 1, 0; Operation type: 21; Operation class: -1719689536 Node name: +// cond/LinSpace; ID: 7; Input: 2, 0; Operation type: 21; Operation class: +// -1719689536 Node name: cond/LinSpace; ID: 7; Input: 3, 0; Operation type: 21; +// Operation class: -1719689536 Node name: cond/LinSpace; ID: 7; Input: 4, 0; +// Operation type: 21; Operation class: -1719689536 + +// as it can be seen cond/LinSpace is not connected with any switch node(s) that +// causes wrong results of optimization. also maybe to cover all conditional +// operations will be need "Operation class", but this have to discovered deeper. + +// All above is true for test_cond_2 +// */ +// } + +// TEST_F(GraphAnalysisTests, test_cond_2) { +// auto graph = Graph::fromFlatBuffers("resources/cond_false.fb"); +// } diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index 556b688e759e..641e5de291dc 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -39,67 +39,67 @@ class GraphExecutorTests : public testing::Test { public: }; -TEST_F(GraphExecutorTests, test_basic_exec_1) { - GraphMemoryManager memoryManager; - Graph graph; +// TEST_F(GraphExecutorTests, test_basic_exec_1) { +// GraphMemoryManager memoryManager; +// Graph graph; - OptimizedGraph optimizedGraph; - OpSequence sequence; +// OptimizedGraph optimizedGraph; +// OpSequence sequence; - optimizedGraph.append(sequence); +// optimizedGraph.append(sequence); - VariableProxy proxy(&graph.variableSpace()); - GraphExecutor executor; - executor.execute(optimizedGraph, proxy); -} +// VariableProxy proxy(&graph.variableSpace()); +// GraphExecutor executor; +// executor.execute(optimizedGraph, proxy); +// } -TEST_F(GraphExecutorTests, test_basic_exec_2) { - GraphMemoryManager mgr; - Graph graph(nullptr, mgr); +// TEST_F(GraphExecutorTests, test_basic_exec_2) { +// GraphMemoryManager mgr; +// Graph graph(nullptr, mgr); - auto A = NDArrayFactory::create('c', {3}, {1, 1, 1}); - auto B = NDArrayFactory::create('c', {3}, {2, 2, 2}); - auto C = NDArrayFactory::create('c', {3}, {3, 3, 3}); +// auto A = NDArrayFactory::create('c', {3}, {1, 1, 1}); +// auto B = NDArrayFactory::create('c', {3}, {2, 2, 2}); +// auto C = NDArrayFactory::create('c', {3}, {3, 3, 3}); - auto exp = NDArrayFactory::create('c', {3}, {5, 5, 5}); +// auto exp = NDArrayFactory::create('c', {3}, {5, 5, 5}); - graph.addVariable("A", A); - graph.addVariable("B", B); - graph.addVariable("C", C); +// graph.addVariable("A", A); +// graph.addVariable("B", B); +// graph.addVariable("C", C); - Node m(sd::ops::multiply(), "mul"); - Node a(sd::ops::add(), "add"); +// Node m(sd::ops::multiply(), "mul"); +// Node a(sd::ops::add(), "add"); - graph.addNode(m, {"A", "B"}); - graph.addNode(a, {"mul", "C"}); +// graph.addNode(m, {"A", "B"}); +// graph.addNode(a, {"mul", "C"}); - OptimizedGraph optimizedGraph; - OpSequence sequence; +// OptimizedGraph optimizedGraph; +// OpSequence sequence; - ASSERT_EQ(2, m.protoContext().inputs().size()); - ASSERT_EQ(2, a.protoContext().inputs().size()); +// ASSERT_EQ(2, m.protoContext().inputs().size()); +// ASSERT_EQ(2, a.protoContext().inputs().size()); - sequence.append(m.customOp(), m.protoContext()); - sequence.append(a.customOp(), a.protoContext()); +// sequence.append(m.customOp(), m.protoContext()); +// sequence.append(a.customOp(), a.protoContext()); - optimizedGraph.append(sequence); +// optimizedGraph.append(sequence); - ASSERT_EQ(2, sequence.length()); - ASSERT_EQ(1, optimizedGraph.layers()); +// ASSERT_EQ(2, sequence.length()); +// ASSERT_EQ(1, optimizedGraph.layers()); - VariableProxy proxy(&graph.variableSpace()); - GraphExecutor executor; - executor.execute(optimizedGraph, proxy); +// VariableProxy proxy(&graph.variableSpace()); +// GraphExecutor executor; +// executor.execute(optimizedGraph, proxy); - // checking results by ID - ASSERT_TRUE(proxy.hasVariable(m.id())); - ASSERT_TRUE(proxy.hasVariable(a.id())); +// // checking results by ID +// ASSERT_TRUE(proxy.hasVariable(m.id())); +// ASSERT_TRUE(proxy.hasVariable(a.id())); - // checking results by name - ASSERT_TRUE(proxy.hasVariable("mul")); - ASSERT_TRUE(proxy.hasVariable("add")); +// // checking results by name +// ASSERT_TRUE(proxy.hasVariable("mul")); +// ASSERT_TRUE(proxy.hasVariable("add")); - // checking if result is valid - auto result = proxy.getVariable(a.id())->getNDArray(); - ASSERT_EQ(exp, *result); -} \ No newline at end of file +// // checking if result is valid +// auto result = proxy.getVariable(a.id())->getNDArray(); +// ASSERT_EQ(exp, *result); +// } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp index 0c0f795e03b5..2abe7bbfbb83 100644 --- a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp @@ -38,42 +38,42 @@ class OpSequenceTests : public testing::Test { OpSequenceTests() {} }; -TEST_F(OpSequenceTests, test_iterator_1) { - Graph graph; - OpSequence sequence; +// TEST_F(OpSequenceTests, test_iterator_1) { +// Graph graph; +// OpSequence sequence; - ASSERT_EQ(0, sequence.length()); +// ASSERT_EQ(0, sequence.length()); - ops::add op1; - ops::multiply op2; +// ops::add op1; +// ops::multiply op2; - Context ctx1(1); - Context ctx2(2); +// Context ctx1(1); +// Context ctx2(2); - sequence.append(&op1, ctx1); - sequence.append(&op2, ctx2); +// sequence.append(&op1, ctx1); +// sequence.append(&op2, ctx2); - ASSERT_EQ(2, sequence.length()); +// ASSERT_EQ(2, sequence.length()); - int cnt = 1; - for (const auto &v : sequence) { - ASSERT_EQ(cnt++, v.protoContext().nodeId()); - } +// int cnt = 1; +// for (const auto &v : sequence) { +// ASSERT_EQ(cnt++, v.protoContext().nodeId()); +// } - ASSERT_EQ(3, cnt); +// ASSERT_EQ(3, cnt); - OptimizedGraph optimizedGraph; - ASSERT_EQ(0, optimizedGraph.layers()); +// OptimizedGraph optimizedGraph; +// ASSERT_EQ(0, optimizedGraph.layers()); - optimizedGraph.append(sequence); - ASSERT_EQ(1, optimizedGraph.layers()); +// optimizedGraph.append(sequence); +// ASSERT_EQ(1, optimizedGraph.layers()); - auto layer = optimizedGraph.layer(0); +// auto layer = optimizedGraph.layer(0); - // we expect exactly 1 sequence in this layer - ASSERT_EQ(1, layer.width()); +// // we expect exactly 1 sequence in this layer +// ASSERT_EQ(1, layer.width()); - auto seq = layer[0]; +// auto seq = layer[0]; - ASSERT_EQ(2, seq.length()); -} +// ASSERT_EQ(2, seq.length()); +// } From 69963ba4080f52d0a872db548a8e395a0f4bd79a Mon Sep 17 00:00:00 2001 From: Yurii Date: Fri, 22 May 2020 19:15:45 +0300 Subject: [PATCH 126/233] - testing and fixing bugs in algorithm for topological sort of graph Signed-off-by: Yurii --- libnd4j/include/graph/OptimizedGraph.h | 47 ++- .../include/graph/execution/ExecutionLayer.h | 1 + .../graph/execution/impl/ExecutionLayer.cpp | 12 +- .../graph/execution/impl/OpSequence.cpp | 4 +- libnd4j/include/graph/impl/Graph.cpp | 14 +- libnd4j/include/graph/impl/OptimizedGraph.cpp | 141 ++++--- .../layers_tests/GraphAnalysisTests.cpp | 390 ++++++++---------- 7 files changed, 323 insertions(+), 286 deletions(-) diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index 507f2203e928..69b51048d132 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -35,21 +35,40 @@ class Graph; class SD_EXPORT OptimizedGraph { private: - std::vector> _sortedGraph; + std::vector _sortedGraph; // const Graph& _originalGraph; public: OptimizedGraph(const MAP_IMPL& map, const VariableSpace& varSpace); - OptimizedGraph() {}; + // move constructor + OptimizedGraph(OptimizedGraph&& other) noexcept; + // default constructor + OptimizedGraph() = default; + + /** + * returns number of nodes in this graph instance + * @return + */ size_t size() const; + /** + * returns OpSequences stored in a given layer + * @param index + * @return + */ + const ExecutionLayer& layer(const uint64_t& index) const { return _sortedGraph.at(index); } + + + /** + * returns number of layers within OptimizedGraph + * @return + */ + uint64_t numOfLayers() const { return _sortedGraph.size(); } + + + // move assignment operator + OptimizedGraph& operator=(OptimizedGraph&& other) noexcept; - struct NodeInfo { - std::vector _connections = std::vector(); - int _id = -1; - NodeInfo(const int id): _id(id), _connections(std::vector()) {} - NodeInfo() = delete; - }; }; @@ -89,18 +108,8 @@ class SD_EXPORT OptimizedGraph { // // move assignment operator // OptimizedGraph& operator=(OptimizedGraph&& other) noexcept; -// /** -// * This method returns number of layers within OptimizedGraph -// * @return -// */ -// uint64_t layers() const; -// /** -// * This method returns OpSequences stored in a given layer -// * @param index -// * @return -// */ -// const ExecutionLayer& layer(uint64_t index) const; + // /** // * This method allows to append layer to this OptimizedGraph instance diff --git a/libnd4j/include/graph/execution/ExecutionLayer.h b/libnd4j/include/graph/execution/ExecutionLayer.h index 04dcb2bbd538..9149c1c7f4fa 100644 --- a/libnd4j/include/graph/execution/ExecutionLayer.h +++ b/libnd4j/include/graph/execution/ExecutionLayer.h @@ -63,6 +63,7 @@ class SD_EXPORT ExecutionLayer { * This method appends OpSequence to the end of this layer * @param sequence */ + void append(OpSequence&& sequence); void append(const OpSequence& sequence); }; } // namespace graph diff --git a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp index 127b1436f99b..5cb13496b0ba 100644 --- a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp +++ b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp @@ -36,10 +36,15 @@ const OpSequence &ExecutionLayer::operator[](uint64_t index) const { return at(index); } -void ExecutionLayer::append(const OpSequence &sequence) { + +void ExecutionLayer::append(OpSequence&& sequence) { + _sequences.emplace_back(std::move(sequence)); +} +void ExecutionLayer::append(const OpSequence& sequence) { _sequences.emplace_back(sequence); } + ExecutionLayer::ExecutionLayer(const ExecutionLayer &other) noexcept { _sequences = other._sequences; } @@ -53,8 +58,9 @@ ExecutionLayer &ExecutionLayer::operator=( return *this; } -ExecutionLayer::ExecutionLayer(ExecutionLayer &&other) noexcept { - _sequences = std::move(other._sequences); +// move constructor +ExecutionLayer::ExecutionLayer(ExecutionLayer &&other) noexcept: _sequences(std::move(other._sequences)) { + } ExecutionLayer &ExecutionLayer::operator=(ExecutionLayer &&other) noexcept { diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 0dd4b908e4a9..0aa78e5011d5 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -41,8 +41,8 @@ OpSequence::OpSequence(const OpSequence &other) noexcept { //////////////////////////////////////////////////////////////////////// // move constructor -OpSequence::OpSequence(OpSequence &&other) noexcept { - _ops = std::move(other._ops); +OpSequence::OpSequence(OpSequence &&other) noexcept: _ops(std::move(other._ops)) { + } OpSequence &OpSequence::operator=(OpSequence &&other) noexcept { diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index c989e8e9fb2d..3cc0531c871c 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -273,12 +273,12 @@ void Graph::printOut() { fflush(stdout); - if (size() > 0) { - nd4j_printf("\nPrinting out Nodes...\n", ""); + // if (size() > 0) { + // nd4j_printf("\nPrinting out Nodes...\n", ""); - // since we need structure - we'll print out nodes of OptimizedGraph - // optimizedGraph().printOut(); - } + // // since we need structure - we'll print out nodes of OptimizedGraph + // optimizedGraph().printOut(); + // } } Nd4jStatus Graph::validateNode(Node *node) { @@ -697,8 +697,8 @@ const OptimizedGraph &Graph::optimizedGraph() const { std::lock_guard lock(_optimizedLock); // optionally rebuild optimized graph, if it's out of date - // if (_optimized.size() != size()) - // _optimized = OptimizedGraph(unmappedNodes()); + if (_optimized.size() != size()) + _optimized = OptimizedGraph(unmappedNodes(), variableSpace()); return _optimized; } diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index f429b4cef1f4..9c815bb3cf0f 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -16,6 +16,7 @@ // // @author Yurii Shyrma (iuriish@yahoo.com) +// @author raver119@gmail.com // #include @@ -25,6 +26,24 @@ namespace sd { namespace graph { +/////////////////////////////////////////////////////////////////// +// move constructor +OptimizedGraph::OptimizedGraph(OptimizedGraph &&other) noexcept: _sortedGraph(std::move(other._sortedGraph)) { + +} + +/////////////////////////////////////////////////////////////////// +// move assignment operator +OptimizedGraph& OptimizedGraph::operator=(OptimizedGraph &&other) noexcept { + if (this == &other) + return *this; + + _sortedGraph = std::move(other._sortedGraph); + + return *this; +} + +/////////////////////////////////////////////////////////////////// OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableSpace& varSpace) { MAP_IMPL> workMap; // key is node id, value is vector containing connections (internal inputs, that is nodes, not input arrays) @@ -40,87 +59,129 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS currList.push_front(i1.first); } + + // for (const auto& p : workMap) { + // printf("node %i: ", p.first); + // std::for_each(p.second.begin(), p.second.end(), [] (const int &j) {printf("%i, ", j);}); + // printf("\n-----------------\n"); + // } + + // 2d vector - std::vector> sortedNodes; // 0d - layers, 1d - nodes belonging to layer, OpSequence are separated by -1 in current layer + std::vector> sortedNodes; // 0d - layers, 1d - nodes of layer, OpSequence are separated by -1 in current layer - // searches for of exit node - auto findExitNode = [] (const decltype(workMap)& m) { + //*** lambda: searches for of start node + auto findStartNode = [] (const decltype(workMap)& m) { return std::find_if(m.begin(), m.end(), [] (const std::pair> &p) {return p.second.empty();}); }; - // searches for next node in current OpSequence - auto findNextNodeInOpSeq = [] (const decltype(workMap)& m, const int &currId) { - return std::find_if(m.begin(), m.end(), [&currId] (const std::pair> &p) {return p.second.front() == currId && std::distance(p.second.begin(), p.second.end()) == 1;}); + //*** lambda: searches for next node in current OpSequence + auto findNextNodeInOpSeq = [] (const decltype(workMap)& m, const int &idOfStartNode, int& resultId) { + uint count = 0; + decltype(workMap)::const_iterator it; + for (it = m.cbegin(); it != m.cend(); ++it) + if (std::find(it->second.cbegin(), it->second.cend(), idOfStartNode) != it->second.end()) { ++count; if(count > 1) break; } + + if(count == 1 && !it->second.empty() && std::distance(it->second.begin(), it->second.end()) == 1) + resultId = it->first; + else + count = 0; + return count; }; - decltype(workMap)::const_iterator exitNode(nullptr), nextNodeInSeq(nullptr); + decltype(workMap)::const_iterator startNode; while (!workMap.empty()) { sortedNodes.emplace_back(std::vector()); // add layer auto& currLayer = sortedNodes.back(); - // loop of searching for exit nodes - while ((exitNode = findExitNode(workMap)) != workMap.end()) { + // loop of searching for start nodes + while ((startNode = findStartNode(workMap)) != workMap.end()) { - int currId = exitNode->first; // id of node under consideration + int nextId, currId(startNode->first); // id of node under consideration currLayer.push_back(currId); workMap.erase(currId); // loop of searching for next nodes in OpSequence - while ((nextNodeInSeq = findNextNodeInOpSeq(workMap, currId)) != workMap.end()) { - - currId = nextNodeInSeq->first; + while (findNextNodeInOpSeq(workMap, currId, nextId) == 1) { - sortedNodes.back().push_back(currId); - workMap.erase(currId); + currLayer.push_back(nextId); + workMap.erase(nextId); + currId = nextId; } + currLayer.push_back(-1); // mark end of OpSequence + } - // remove connections in all nodes pointing at exit nodes - for (uint i = 0; i < currLayer.size(); ++i) { + // remove connections in all nodes pointing at start nodes + for (int i = 0; i < currLayer.size(); ++i) { if(currLayer[i] != -1 || i == 0) // if not at the and of OpSequence continue; - const int idOfExitNode = currLayer[i-1]; + const int idOfStartNode = currLayer[i-1]; - std::for_each(workMap.begin(), workMap.end(), [&idOfExitNode] (std::pair> &p) {p.second.remove(idOfExitNode);}); + std::for_each(workMap.begin(), workMap.end(), [&idOfStartNode] (std::pair> &p) {p.second.remove(idOfStartNode);}); } - } + // int i = 0; + // for (const auto& vec : sortedNodes) { + // printf("layer %i: ",i++); + // for (int j = 0; j < vec.size(); ++j) { + // printf("%i, ", vec[j]); + // } + // printf("\n"); + // } + //*** fill _sortedGraph *** // + // loop through layers + for (const auto& vec : sortedNodes) { + ExecutionLayer layer; - // auto fi = std::find_if(ops.begin(), ops.end(), [name](ops::OpDescriptor a) { return a.getOpName()->compare(name) == 0;}); + // loop through OpSequences + uint i = 0; + while(i < vec.size()) { + OpSequence seq; - // for (const auto& i0 : idVsConnects) { + // loop through OpSequence + while(vec[i++] != -1) + seq.append(inMap.at(vec[i-1]).customOp(), inMap.at(vec[i-1]).protoContext()); - // bool isExitNode = true; + layer.append(std::move(seq)); + } + _sortedGraph.emplace_back(std::move(layer)); + } +} - // for (const auto& i1 : idVsConnects) { - // if (i0.first == i1.first) - // continue; +/////////////////////////////////////////////////////////////////// +size_t OptimizedGraph::size() const { + // std::lock_guard lock(_mutex); - // if (i1.second) { - // isExitNode = false; - // break; - // } - // } + size_t size = 0; - // } + std::for_each(_sortedGraph.begin(), _sortedGraph.end(), [&size] (const ExecutionLayer &l) { + for (int e = 0; e < l.width(); e++) { + size += l.at(0).length(); + } + ;}); + return size; } + + + // OptimizedGraph::OptimizedGraph(Graph* original) { // _originalGraph = original; // _memoryManager = const_cast(&original->memoryManager()); @@ -161,25 +222,9 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS // return *this; // } -// size_t OptimizedGraph::size() const { -// std::lock_guard lock(_mutex); -// std::vector seq; -// if (_size == 0) -// for (const auto& v : _onion) { -// for (int e = 0; e < v.second.width(); e++) { -// _size += v.second.at(0).length(); -// } -// } -// return _size; -// } - -// uint64_t OptimizedGraph::layers() const { return _onion.size(); } -// const ExecutionLayer& OptimizedGraph::layer(uint64_t index) const { -// return _onion.at(index); -// } // void OptimizedGraph::append(const std::vector& layer) { // std::lock_guard lock(_mutex); diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 46c657d3c11b..51293fdb77e9 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -37,287 +37,263 @@ class GraphAnalysisTests : public testing::Test { } }; -// TEST_F(GraphAnalysisTests, basic_toposort_test_1) { -// Graph graph; +TEST_F(GraphAnalysisTests, optimizedGraph_1) { -// // A -// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // A*B + C + Graph graph; -// // B -// graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + graph.addVariable("A", NDArray('c', {3}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, sd::DataType::INT32)); -// // C -// graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + Node a(sd::ops::multiply(), "multiply"); + Node b(sd::ops::add(), "add"); -// Node a(sd::ops::multiply(), "multiply"); -// Node b(sd::ops::add(), "add"); + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"multiply", "C"}); -// graph.addNode(a, {"A", "B"}); -// graph.addNode(b, {"multiply", "C"}); + // we just check that nodes were really added + ASSERT_EQ(2, graph.size()); -// // we just check that nodes were really added -// ASSERT_EQ(2, graph.size()); + const auto& optimized = graph.optimizedGraph(); -// auto optimized = graph.optimizedGraph(); + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(1, optimized.numOfLayers()); -// // we expect that OptimizedGraph has exactly 1 layer -// ASSERT_EQ(1, optimized.layers()); + auto layer = optimized.layer(0); -// auto layer = optimized.layer(0); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer.width()); + auto sequence = layer[0]; -// // we expect layer has exactly 1 OpSequence -// ASSERT_EQ(1, layer.width()); -// auto sequence = layer[0]; + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(2, sequence.length()); -// // we expect that OpSequence has exactly 2 ops -// ASSERT_EQ(2, sequence.length()); + ASSERT_EQ(4, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(5, sequence.at(1).protoContext().nodeId()); +} -// ASSERT_EQ(4, sequence.at(0).protoContext().nodeId()); -// ASSERT_EQ(5, sequence.at(1).protoContext().nodeId()); -// } +TEST_F(GraphAnalysisTests, optimizedGraph_2) { -// TEST_F(GraphAnalysisTests, basic_toposort_test_2) { -// Graph graph; + // 0 = A*B, 1_0 = 0+C, 1_1 = 0-D + Graph graph; -// // A -// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {2, 2, 2}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {3, 3, 3}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {4, 4, 4}, sd::DataType::INT32)); -// // B -// graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); - -// // C -// graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); - -// // D -// graph.addVariable("D", NDArrayFactory::create('c', {3}, {4, 4, 4})); - -// Node a(sd::ops::multiply(), "multiply"); -// Node b(sd::ops::add(), "add"); -// Node c(sd::ops::subtract(), "subtract"); - -// graph.addNode(a, {"A", "B"}); -// graph.addNode(b, {"multiply", "C"}); -// graph.addNode(c, {"multiply", "D"}); - -// // we just check that nodes were really added -// ASSERT_EQ(3, graph.size()); - -// auto optimized = graph.optimizedGraph(); - -// // graph size must stay the same -// ASSERT_EQ(3, graph.size()); - -// // we expect that OptimizedGraph has exactly 2 layers -// ASSERT_EQ(2, optimized.layers()); - -// // checking first layer first -// auto layer0 = optimized.layer(0); -// // we expect layer has exactly 1 OpSequence -// ASSERT_EQ(1, layer0.width()); -// ; + Node a(sd::ops::multiply(), "multiply"); + Node b(sd::ops::add(), "add"); + Node c(sd::ops::subtract(), "subtract"); -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer0[0].length()); + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"multiply", "C"}); + graph.addNode(c, {"multiply", "D"}); -// ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); + // we just check that nodes were really added + ASSERT_EQ(3, graph.size()); -// // checking second layer now -// auto layer1 = optimized.layer(1); + const auto& optimized = graph.optimizedGraph(); -// // we expect layer has exactly 2 OpSequences -// ASSERT_EQ(2, layer1.width()); + // graph size must stay the same + ASSERT_EQ(3, graph.size()); -// ASSERT_EQ(1, layer1[0].length()); -// ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); + // we expect that OptimizedGraph has exactly 2 layers + ASSERT_EQ(2, optimized.numOfLayers()); -// ASSERT_EQ(1, layer1[1].length()); -// ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); -// } + // checking first layer first + auto layer0 = optimized.layer(0); -// TEST_F(GraphAnalysisTests, basic_toposort_test_3) { -// Graph graph; + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer0.width()); -// // A -// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // we expect that OpSequence has exactly 1 node + ASSERT_EQ(1, layer0[0].length()); -// // B -// graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); + ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); -// // C -// graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // checking second layer now + auto layer1 = optimized.layer(1); -// // D -// graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.width()); -// Node a(sd::ops::multiply(), "a"); -// Node b(sd::ops::add(), "b"); -// Node c(sd::ops::subtract(), "c"); -// Node d(sd::ops::add(), "d"); -// Node e(sd::ops::multiply(), "e"); + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); -// graph.addNode(a, {"A", "B"}); -// graph.addNode(b, {"a", "C"}); + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); +} -// graph.addNode(c, {"b", "D"}); -// graph.addNode(d, {"b", "D"}); +TEST_F(GraphAnalysisTests, optimizedGraph_3) { -// graph.addNode(e, {"c", "d"}); + // 0 = A*B+C, 1_0 = 0-D, 1_1 = 0+D, 2 = 1_0*1_1 + Graph graph; -// // we just check that nodes were really added -// ASSERT_EQ(5, graph.size()); + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); -// auto optimized = graph.optimizedGraph(); + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::subtract(), "c"); + Node d(sd::ops::add(), "d"); + Node e(sd::ops::multiply(), "e"); -// // we expect that OptimizedGraph has exactly 3 layer -// ASSERT_EQ(3, optimized.layers()); + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"a", "C"}); -// // checking first layer first -// auto layer0 = optimized.layer(0); + graph.addNode(c, {"b", "D"}); + graph.addNode(d, {"b", "D"}); -// // we expect layer has exactly 1 OpSequence -// ASSERT_EQ(1, layer0.width()); -// // auto sequence = layer0[0]; + graph.addNode(e, {"c", "d"}); -// // we expect that OpSequence has exactly 2 ops -// ASSERT_EQ(2, layer0[0].length()); + // we just check that nodes were really added + ASSERT_EQ(5, graph.size()); -// ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); -// ASSERT_EQ(6, layer0[0].at(1).protoContext().nodeId()); + const auto& optimized = graph.optimizedGraph(); -// // checking second layer now -// auto layer1 = optimized.layer(1); + // we expect that OptimizedGraph has exactly 3 layer + ASSERT_EQ(3, optimized.numOfLayers()); -// // we expect layer has exactly 2 OpSequences -// ASSERT_EQ(2, layer1.width()); + // checking first layer first + auto layer0 = optimized.layer(0); -// // sequence = layer1[0]; + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer0.width()); + // auto sequence = layer0[0]; -// ASSERT_EQ(1, layer1[0].length()); -// ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(2, layer0[0].length()); -// // sequence = layer1[1]; + ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); + ASSERT_EQ(6, layer0[0].at(1).protoContext().nodeId()); -// ASSERT_EQ(1, layer1[1].length()); -// ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); + // checking second layer now + const auto& layer1 = optimized.layer(1); -// // checking last layer -// auto layer2 = optimized.layer(2); + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.width()); -// // we expect layer has exactly 1 OpSequence -// ASSERT_EQ(1, layer2.width()); -// // sequence = layer2[0]; + // sequence = layer1[0]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[0].length()); -// ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); -// } + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); -// TEST_F(GraphAnalysisTests, basic_toposort_test_4) { -// Graph graph; + // sequence = layer1[1]; -// // A -// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); -// // B -// graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // checking last layer + auto layer2 = optimized.layer(2); -// // C -// graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer2.width()); + // sequence = layer2[0]; -// // D -// graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); +} -// // E -// graph.addVariable("E", NDArrayFactory::create('c', {3}, {1, 1, 1})); +TEST_F(GraphAnalysisTests, optimizedGraph_4) { + Graph graph; -// // F -// graph.addVariable("F", NDArrayFactory::create('c', {3}, {1, 1, 1})); + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("E", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("F", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); -// Node a1(sd::ops::multiply(), "a1"); -// Node a2(sd::ops::add(), "a2"); + Node a1(sd::ops::multiply(), "a1"); + Node a2(sd::ops::add(), "a2"); -// Node b1(sd::ops::subtract(), "b1"); -// Node b2(sd::ops::add(), "b2"); -// Node b3(sd::ops::multiply(), "b3"); + Node b1(sd::ops::subtract(), "b1"); + Node b2(sd::ops::add(), "b2"); + Node b3(sd::ops::multiply(), "b3"); -// Node d1(sd::ops::multiply(), "d1"); -// Node d2(sd::ops::add(), "d2"); + Node d1(sd::ops::multiply(), "d1"); + Node d2(sd::ops::add(), "d2"); -// Node e(sd::ops::subtract(), "e"); + Node e(sd::ops::subtract(), "e"); -// graph.addNode(a1, {"A", "B"}); -// graph.addNode(a2, {"C", "D"}); + graph.addNode(a1, {"A", "B"}); + graph.addNode(a2, {"C", "D"}); -// graph.addNode(b1, {"a1", "E"}); -// graph.addNode(b2, {"a1", "a2"}); -// graph.addNode(b3, {"a2", "F"}); + graph.addNode(b1, {"a1", "E"}); + graph.addNode(b2, {"a1", "a2"}); + graph.addNode(b3, {"a2", "F"}); -// graph.addNode(d1, {"b1", "b2"}); -// graph.addNode(d2, {"b3", "b2"}); + graph.addNode(d1, {"b1", "b2"}); + graph.addNode(d2, {"b3", "b2"}); -// graph.addNode(e, {"d1", "d2"}); + graph.addNode(e, {"d1", "d2"}); -// // we just check that nodes were really added -// ASSERT_EQ(8, graph.size()); + // we just check that nodes were really added + ASSERT_EQ(8, graph.size()); -// auto optimized = graph.optimizedGraph(); + const auto& optimized = graph.optimizedGraph(); -// // we expect that OptimizedGraph has exactly 4 layer -// ASSERT_EQ(4, optimized.layers()); + // we expect that OptimizedGraph has exactly 4 layer + ASSERT_EQ(4, optimized.numOfLayers()); -// // checking first layer first -// auto layer0 = optimized.layer(0); + // checking first layer first + auto layer0 = optimized.layer(0); -// // we expect layer has exactly 2 OpSequence -// ASSERT_EQ(2, layer0.width()); + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer0.width()); -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer0[0].length()); -// ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[0].length()); + ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer0[1].length()); -// ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[1].length()); + ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); -// // checking second layer now -// auto layer1 = optimized.layer(1); + // checking second layer now + auto layer1 = optimized.layer(1); -// // we expect layer has exactly 3 OpSequences -// ASSERT_EQ(3, layer1.width()); + // we expect layer has exactly 3 OpSequences + ASSERT_EQ(3, layer1.width()); -// ASSERT_EQ(1, layer1[0].length()); -// ASSERT_EQ(9, layer1[0].at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(9, layer1[0].at(0).protoContext().nodeId()); -// ASSERT_EQ(1, layer1[1].length()); -// ASSERT_EQ(10, layer1[1].at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(10, layer1[1].at(0).protoContext().nodeId()); -// ASSERT_EQ(1, layer1[2].length()); -// ASSERT_EQ(11, layer1[2].at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[2].length()); + ASSERT_EQ(11, layer1[2].at(0).protoContext().nodeId()); -// auto layer2 = optimized.layer(2); + auto layer2 = optimized.layer(2); -// // we expect layer has exactly 2 OpSequence -// ASSERT_EQ(2, layer2.width()); + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer2.width()); -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[0].length()); -// ASSERT_EQ(12, layer2[0].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(12, layer2[0].at(0).protoContext().nodeId()); -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[1].length()); -// ASSERT_EQ(13, layer2[1].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(13, layer2[1].at(0).protoContext().nodeId()); -// // checking last layer -// auto layer3 = optimized.layer(3); + // checking last layer + auto layer3 = optimized.layer(3); -// // we expect layer has exactly 1 OpSequence -// ASSERT_EQ(1, layer3.width()); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer3.width()); -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer3[0].length()); -// ASSERT_EQ(14, layer3[0].at(0).protoContext().nodeId()); -// } + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[0].length()); + ASSERT_EQ(14, layer3[0].at(0).protoContext().nodeId()); +} // TEST_F(GraphAnalysisTests, basic_toposort_test_5) { // Graph graph; @@ -362,7 +338,7 @@ class GraphAnalysisTests : public testing::Test { // auto optimized = graph.optimizedGraph(); // // we expect that OptimizedGraph has exactly 3 layer -// ASSERT_EQ(4, optimized.layers()); +// ASSERT_EQ(4, optimized.numOfLayers()); // // checking first layer first // auto layer0 = optimized.layer(0); @@ -486,7 +462,7 @@ class GraphAnalysisTests : public testing::Test { // auto optimized = graph.optimizedGraph(); // // we expect that OptimizedGraph has exactly 3 layer -// ASSERT_EQ(5, optimized.layers()); +// ASSERT_EQ(5, optimized.numOfLayers()); // // checking first layer first // auto layer0 = optimized.layer(0); @@ -594,7 +570,7 @@ class GraphAnalysisTests : public testing::Test { // auto optimized = graph.optimizedGraph(); // // we expect that OptimizedGraph has exactly 3 layer -// ASSERT_EQ(5, optimized.layers()); +// ASSERT_EQ(5, optimized.numOfLayers()); // // checking first layer first // auto layer0 = optimized.layer(0); @@ -694,7 +670,7 @@ class GraphAnalysisTests : public testing::Test { // auto optimized = graph.optimizedGraph(); // // we expect that OptimizedGraph has exactly 2 layer -// ASSERT_EQ(2, optimized.layers()); +// ASSERT_EQ(2, optimized.numOfLayers()); // // checking first layer first // auto layer0 = optimized.layer(0); @@ -794,7 +770,7 @@ class GraphAnalysisTests : public testing::Test { // auto optimized = graph.optimizedGraph(); // // we expect that OptimizedGraph has exactly 1 layer -// ASSERT_EQ(3, optimized.layers()); +// ASSERT_EQ(3, optimized.numOfLayers()); // auto layer = optimized.layer(0); // // we expect layer has exactly 1 OpSequence @@ -904,7 +880,7 @@ class GraphAnalysisTests : public testing::Test { // auto optimized = graph.optimizedGraph(); // // we expect that OptimizedGraph has exactly 1 layer -// ASSERT_EQ(3, optimized.layers()); +// ASSERT_EQ(3, optimized.numOfLayers()); // auto layer = optimized.layer(0); @@ -969,7 +945,7 @@ class GraphAnalysisTests : public testing::Test { // auto optimized = graph.optimizedGraph(); // // we expect that OptimizedGraph has exactly 1 layer -// ASSERT_EQ(1, optimized.layers()); +// ASSERT_EQ(1, optimized.numOfLayers()); // auto layer = optimized.layer(0); From 03d4f343082bdba23a2be1ff613b6c5c85802761 Mon Sep 17 00:00:00 2001 From: Yurii Date: Mon, 25 May 2020 18:47:54 +0300 Subject: [PATCH 127/233] - algorithm of topological sort of graph with no cycles is completed Signed-off-by: Yurii --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 205 ++-- .../layers_tests/GraphAnalysisTests.cpp | 982 +++++++++--------- 2 files changed, 564 insertions(+), 623 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 9c815bb3cf0f..1723e03ef82d 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -35,134 +35,93 @@ OptimizedGraph::OptimizedGraph(OptimizedGraph &&other) noexcept: _sortedGraph(st /////////////////////////////////////////////////////////////////// // move assignment operator OptimizedGraph& OptimizedGraph::operator=(OptimizedGraph &&other) noexcept { - if (this == &other) - return *this; - _sortedGraph = std::move(other._sortedGraph); + if (this == &other) + return *this; + + _sortedGraph = std::move(other._sortedGraph); - return *this; + return *this; } /////////////////////////////////////////////////////////////////// OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableSpace& varSpace) { - MAP_IMPL> workMap; // key is node id, value is vector containing connections (internal inputs, that is nodes, not input arrays) - - // fill workMap - for (const auto& i0 : inMap) { - - std::forward_list& currList = workMap[i0.first] = {}; - - // loop through inputs of node - for (const auto& i1 : i0.second.input()) - if (!varSpace.hasVariable(i1.first)) - currList.push_front(i1.first); - } - - - // for (const auto& p : workMap) { - // printf("node %i: ", p.first); - // std::for_each(p.second.begin(), p.second.end(), [] (const int &j) {printf("%i, ", j);}); - // printf("\n-----------------\n"); - // } - - - // 2d vector - std::vector> sortedNodes; // 0d - layers, 1d - nodes of layer, OpSequence are separated by -1 in current layer - - //*** lambda: searches for of start node - auto findStartNode = [] (const decltype(workMap)& m) { - return std::find_if(m.begin(), m.end(), [] (const std::pair> &p) {return p.second.empty();}); - }; - - //*** lambda: searches for next node in current OpSequence - auto findNextNodeInOpSeq = [] (const decltype(workMap)& m, const int &idOfStartNode, int& resultId) { - uint count = 0; - decltype(workMap)::const_iterator it; - for (it = m.cbegin(); it != m.cend(); ++it) - if (std::find(it->second.cbegin(), it->second.cend(), idOfStartNode) != it->second.end()) { ++count; if(count > 1) break; } - - if(count == 1 && !it->second.empty() && std::distance(it->second.begin(), it->second.end()) == 1) - resultId = it->first; - else - count = 0; - return count; - }; - - decltype(workMap)::const_iterator startNode; - - while (!workMap.empty()) { - - sortedNodes.emplace_back(std::vector()); // add layer - auto& currLayer = sortedNodes.back(); - - // loop of searching for start nodes - while ((startNode = findStartNode(workMap)) != workMap.end()) { - - int nextId, currId(startNode->first); // id of node under consideration - - currLayer.push_back(currId); - workMap.erase(currId); - - // loop of searching for next nodes in OpSequence - while (findNextNodeInOpSeq(workMap, currId, nextId) == 1) { - - currLayer.push_back(nextId); - workMap.erase(nextId); - currId = nextId; - } - - currLayer.push_back(-1); // mark end of OpSequence - + struct NodeInfo { + uint _layerNum = 0; + std::vector _opSeq = {}; + std::vector _in = {}; + std::vector _out = {}; + }; + + MAP_IMPL workMap; // key is node id, value is class NodeInfo containing auxiliary information (layer number this node belongs to, input/output nodes, OpSequence that starts from this node) + + // create workMap, fill vectors containing input and output nodes per each node, and find start nodes + std::vector startNodes; + for (const auto& p : inMap) { + + for (const auto& v : p.second.input()) + if (!varSpace.hasVariable(v.first)) { + workMap[p.first]._in.push_back(v.first); + workMap[v.first]._out.push_back(p.first); + } + if(workMap[p.first]._in.empty()) + startNodes.push_back(p.first); } - // remove connections in all nodes pointing at start nodes - for (int i = 0; i < currLayer.size(); ++i) { - - if(currLayer[i] != -1 || i == 0) // if not at the and of OpSequence - continue; - - const int idOfStartNode = currLayer[i-1]; - - std::for_each(workMap.begin(), workMap.end(), [&idOfStartNode] (std::pair> &p) {p.second.remove(idOfStartNode);}); + // collect OpSequences (fill _opSeq) + std::vector nodesToDelete; + for (auto& p : workMap) { + + if(p.second._in.size() != 1) { + + auto& out = p.second._out; + while(out.size() == 1 && workMap[out[0]]._in.size() == 1) { + nodesToDelete.push_back(out[0]); + p.second._opSeq.push_back(out[0]); + out = workMap[out[0]]._out; + } + if(out != p.second._out) + p.second._out = std::move(out); + } } - } - - // int i = 0; - // for (const auto& vec : sortedNodes) { - // printf("layer %i: ",i++); - // for (int j = 0; j < vec.size(); ++j) { - // printf("%i, ", vec[j]); - // } - // printf("\n"); - // } + // delete nodes present in _opSeq, their ids are already stored in nodesToDelete + for (const auto& i : nodesToDelete) + workMap.erase(i); + // lambda for topological sort + std::function visit = [&visit, &workMap] (const int id, const uint layerNum, uint& numOfLayers) { + if(layerNum <= workMap[id]._layerNum) { return; } + workMap[id]._layerNum = layerNum; + if(numOfLayers < layerNum) { numOfLayers = layerNum; } + for (const auto& nextId : workMap[id]._out) + visit(nextId, layerNum+1, numOfLayers); + }; - //*** fill _sortedGraph *** // - // loop through layers - for (const auto& vec : sortedNodes) { + // perform topological sort + uint numOfLayers = 0; + for (const auto& id : startNodes) + for (const auto& nextId : workMap[id]._out) + visit(nextId, 1, numOfLayers); - ExecutionLayer layer; + // fill _sortedGraph + _sortedGraph = std::vector(numOfLayers+1); + for (const auto& p : workMap) { - // loop through OpSequences - uint i = 0; - while(i < vec.size()) { + OpSequence seq; + seq.append(inMap.at(p.first).customOp(), inMap.at(p.first).protoContext()); - OpSequence seq; + for (const auto& id : p.second._opSeq) + seq.append(inMap.at(id).customOp(), inMap.at(id).protoContext()); - // loop through OpSequence - while(vec[i++] != -1) - seq.append(inMap.at(vec[i-1]).customOp(), inMap.at(vec[i-1]).protoContext()); - - layer.append(std::move(seq)); + _sortedGraph[p.second._layerNum].append(std::move(seq)); } - _sortedGraph.emplace_back(std::move(layer)); - } } + /////////////////////////////////////////////////////////////////// size_t OptimizedGraph::size() const { // std::lock_guard lock(_mutex); @@ -180,6 +139,40 @@ size_t OptimizedGraph::size() const { + // _sortedGraph = std::vector(numOfLayers+1); + // std::vector> printGraph = std::vector>(numOfLayers+1); + // for (const auto& p : workMap) { + + // OpSequence seq; + // seq.append(inMap.at(p.first).customOp(), inMap.at(p.first).protoContext()); + // printGraph[p.second._layerNum].push_back(p.first); + + // for (const auto& id : p.second._opSeq) { + // seq.append(inMap.at(id).customOp(), inMap.at(id).protoContext()); + // printGraph[p.second._layerNum].push_back(id); + // } + + // _sortedGraph[p.second._layerNum].append(std::move(seq)); + // } + + // for (int i = 0; i < printGraph.size(); ++i) { + // printf("layer %i: ", i); + // for (int j = 0; j < printGraph[i].size(); ++j) + // printf("%i, ", printGraph[i][j]); + // printf("\n"); + // } + + // for (const auto& p : workMap) { + // printf("node %i: , layerNum %i: , opSeq: ", p.first,p.second._layerNum); + // std::for_each(p.second._opSeq.begin(), p.second._opSeq.end(), [] (const int &j) {printf("%i, ", j);}); + // printf(", ins: "); + // std::for_each(p.second._in.begin(), p.second._in.end(), [] (const int &j) {printf("%i, ", j);}); + // printf(", outs: "); + // std::for_each(p.second._out.begin(), p.second._out.end(), [] (const int &j) {printf("%i, ", j);}); + // printf("\n"); + // } + // printf("\n-----------------\n");; + // OptimizedGraph::OptimizedGraph(Graph* original) { diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 51293fdb77e9..4755fe62edcb 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -181,12 +181,12 @@ TEST_F(GraphAnalysisTests, optimizedGraph_3) { // sequence = layer1[0]; ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); + ASSERT_EQ(/*7*/8, layer1[0].at(0).protoContext().nodeId()); // sequence = layer1[1]; ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); + ASSERT_EQ(/*8*/7, layer1[1].at(0).protoContext().nodeId()); // checking last layer auto layer2 = optimized.layer(2); @@ -266,10 +266,10 @@ TEST_F(GraphAnalysisTests, optimizedGraph_4) { ASSERT_EQ(9, layer1[0].at(0).protoContext().nodeId()); ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(10, layer1[1].at(0).protoContext().nodeId()); + ASSERT_EQ(/*10*/11, layer1[1].at(0).protoContext().nodeId()); ASSERT_EQ(1, layer1[2].length()); - ASSERT_EQ(11, layer1[2].at(0).protoContext().nodeId()); + ASSERT_EQ(/*11*/10, layer1[2].at(0).protoContext().nodeId()); auto layer2 = optimized.layer(2); @@ -295,679 +295,627 @@ TEST_F(GraphAnalysisTests, optimizedGraph_4) { ASSERT_EQ(14, layer3[0].at(0).protoContext().nodeId()); } -// TEST_F(GraphAnalysisTests, basic_toposort_test_5) { -// Graph graph; +TEST_F(GraphAnalysisTests, optimizedGraph_5) { -// // A -// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + Graph graph; -// // B -// graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); -// // C -// graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::subtract(), "c"); + Node d(sd::ops::add(), "d"); + Node e(sd::ops::multiply(), "e"); + Node f(sd::ops::multiply(), "f"); -// // D -// graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); + Node g(sd::ops::multiply(), "g"); + Node h(sd::ops::multiply(), "h"); -// Node a(sd::ops::multiply(), "a"); -// Node b(sd::ops::add(), "b"); -// Node c(sd::ops::subtract(), "c"); -// Node d(sd::ops::add(), "d"); -// Node e(sd::ops::multiply(), "e"); -// Node f(sd::ops::multiply(), "f"); + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"C", "D"}); -// Node g(sd::ops::multiply(), "g"); -// Node h(sd::ops::multiply(), "h"); + graph.addNode(c, {"a", "b"}); + graph.addNode(d, {"a", "b"}); -// graph.addNode(a, {"A", "B"}); -// graph.addNode(b, {"C", "D"}); + graph.addNode(e, {"c", "d"}); + graph.addNode(f, {"c", "d"}); -// graph.addNode(c, {"a", "b"}); -// graph.addNode(d, {"a", "b"}); + graph.addNode(g, {"c", "e"}); + graph.addNode(h, {"d", "f"}); -// graph.addNode(e, {"c", "d"}); -// graph.addNode(f, {"c", "d"}); + // we just check that nodes were really added + ASSERT_EQ(8, graph.size()); -// graph.addNode(g, {"c", "e"}); -// graph.addNode(h, {"d", "f"}); + const auto& optimized = graph.optimizedGraph(); -// // we just check that nodes were really added -// ASSERT_EQ(8, graph.size()); + // we expect that OptimizedGraph has exactly 3 layer + ASSERT_EQ(4, optimized.numOfLayers()); -// auto optimized = graph.optimizedGraph(); + // checking first layer first + auto layer0 = optimized.layer(0); -// // we expect that OptimizedGraph has exactly 3 layer -// ASSERT_EQ(4, optimized.numOfLayers()); + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer0.width()); + // auto sequence = layer0[0]; -// // checking first layer first -// auto layer0 = optimized.layer(0); + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, layer0[0].length()); -// // we expect layer has exactly 2 OpSequence -// ASSERT_EQ(2, layer0.width()); -// // auto sequence = layer0[0]; + ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); -// // we expect that OpSequence has exactly 2 ops -// ASSERT_EQ(1, layer0[0].length()); + // sequence = layer0[1]; -// ASSERT_EQ(5, layer0[0].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, layer0[1].length()); + ASSERT_EQ(6, layer0[1].at(0).protoContext().nodeId()); -// // sequence = layer0[1]; + // checking second layer now + auto layer1 = optimized.layer(1); -// // we expect that OpSequence has exactly 2 ops -// ASSERT_EQ(1, layer0[1].length()); -// ASSERT_EQ(6, layer0[1].at(0).protoContext().nodeId()); + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.width()); -// // checking second layer now -// auto layer1 = optimized.layer(1); + // sequence = layer1[0]; -// // we expect layer has exactly 2 OpSequences -// ASSERT_EQ(2, layer1.width()); + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(/*7*/8, layer1[0].at(0).protoContext().nodeId()); -// // sequence = layer1[0]; + // sequence = layer1[1]; -// ASSERT_EQ(1, layer1[0].length()); -// ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(/*8*/7, layer1[1].at(0).protoContext().nodeId()); -// // sequence = layer1[1]; + // checking before last layer + auto layer2 = optimized.layer(2); -// ASSERT_EQ(1, layer1[1].length()); -// ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer2.width()); + // sequence = layer2[0]; -// // checking before last layer -// auto layer2 = optimized.layer(2); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(/*9*/10, layer2[0].at(0).protoContext().nodeId()); + // sequence = layer2[1]; -// // we expect layer has exactly 2 OpSequence -// ASSERT_EQ(2, layer2.width()); -// // sequence = layer2[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(/*10*/9, layer2[1].at(0).protoContext().nodeId()); -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[0].length()); -// ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); -// // sequence = layer2[1]; + // checking last layer + auto layer3 = optimized.layer(3); -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[1].length()); -// ASSERT_EQ(10, layer2[1].at(0).protoContext().nodeId()); + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer3.width()); + // sequence = layer3[0]; -// // checking last layer -// auto layer3 = optimized.layer(3); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[0].length()); + ASSERT_EQ(/*11*/12, layer3[0].at(0).protoContext().nodeId()); -// // we expect layer has exactly 2 OpSequence -// ASSERT_EQ(2, layer3.width()); -// // sequence = layer3[0]; + // sequence = layer3[1]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer3[0].length()); -// ASSERT_EQ(11, layer3[0].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[1].length()); + ASSERT_EQ(/*12*/11, layer3[1].at(0).protoContext().nodeId()); +} -// // sequence = layer3[1]; +TEST_F(GraphAnalysisTests, optimizedGraph_6) { + Graph graph; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer3[1].length()); -// ASSERT_EQ(12, layer3[1].at(0).protoContext().nodeId()); -// } + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("E", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("F", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); -// TEST_F(GraphAnalysisTests, basic_toposort_test_6) { -// Graph graph; + Node a(sd::ops::multiply(), "a"); + Node b1(sd::ops::add(), "b1"); + Node b2(sd::ops::subtract(), "b2"); -// // A -// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + Node c1(sd::ops::add(), "c1"); + Node c2(sd::ops::multiply(), "c2"); + Node c3(sd::ops::subtract(), "c3"); -// // B -// graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); + Node d1(sd::ops::multiply(), "d1"); + Node d2(sd::ops::multiply(), "d2"); -// // C -// graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); + Node e(sd::ops::add(), "e"); -// // D -// graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); + graph.addNode(a, {"A", "B"}); -// // E -// graph.addVariable("E", NDArrayFactory::create('c', {3}, {1, 1, 1})); + graph.addNode(b1, {"a", "C"}); + graph.addNode(b2, {"a", "D"}); -// // F -// graph.addVariable("F", NDArrayFactory::create('c', {3}, {1, 1, 1})); + graph.addNode(c1, {"b1", "E"}); + graph.addNode(c2, {"b1", "b2"}); + graph.addNode(c3, {"b2", "F"}); -// Node a(sd::ops::multiply(), "a"); -// Node b1(sd::ops::add(), "b1"); -// Node b2(sd::ops::subtract(), "b2"); + graph.addNode(d1, {"c1", "c2"}); + graph.addNode(d2, {"c2", "c3"}); -// Node c1(sd::ops::add(), "c1"); -// Node c2(sd::ops::multiply(), "c2"); -// Node c3(sd::ops::subtract(), "c3"); + graph.addNode(e, {"d1", "d2"}); -// Node d1(sd::ops::multiply(), "d1"); -// Node d2(sd::ops::multiply(), "d2"); + // we just check that nodes were really added + ASSERT_EQ(9, graph.size()); -// Node e(sd::ops::add(), "e"); + const auto& optimized = graph.optimizedGraph(); -// graph.addNode(a, {"A", "B"}); + // we expect that OptimizedGraph has exactly 3 layer + ASSERT_EQ(5, optimized.numOfLayers()); -// graph.addNode(b1, {"a", "C"}); -// graph.addNode(b2, {"a", "D"}); + // checking first layer first + auto layer0 = optimized.layer(0); -// graph.addNode(c1, {"b1", "E"}); -// graph.addNode(c2, {"b1", "b2"}); -// graph.addNode(c3, {"b2", "F"}); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer0.width()); -// graph.addNode(d1, {"c1", "c2"}); -// graph.addNode(d2, {"c2", "c3"}); + // auto sequence = layer0[0]; + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, layer0[0].length()); + ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); -// graph.addNode(e, {"d1", "d2"}); + // checking second layer now + auto layer1 = optimized.layer(1); -// // we just check that nodes were really added -// ASSERT_EQ(9, graph.size()); + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(2, layer1.width()); -// auto optimized = graph.optimizedGraph(); + // sequence = layer1[0]; + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(/*8*/9, layer1[0].at(0).protoContext().nodeId()); -// // we expect that OptimizedGraph has exactly 3 layer -// ASSERT_EQ(5, optimized.numOfLayers()); + // sequence = layer1[1]; + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(/*9*/8, layer1[1].at(0).protoContext().nodeId()); -// // checking first layer first -// auto layer0 = optimized.layer(0); + // checking midle layer + auto layer2 = optimized.layer(2); -// // we expect layer has exactly 1 OpSequence -// ASSERT_EQ(1, layer0.width()); + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(3, layer2.width()); -// // auto sequence = layer0[0]; -// // we expect that OpSequence has exactly 2 ops -// ASSERT_EQ(1, layer0[0].length()); -// ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); + // sequence = layer2[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(/*10*/11, layer2[0].at(0).protoContext().nodeId()); -// // checking second layer now -// auto layer1 = optimized.layer(1); + // sequence = layer2[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(/*11*/12, layer2[1].at(0).protoContext().nodeId()); -// // we expect layer has exactly 2 OpSequences -// ASSERT_EQ(2, layer1.width()); + // sequence = layer2[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[2].length()); + ASSERT_EQ(/*12*/10, layer2[2].at(0).protoContext().nodeId()); -// // sequence = layer1[0]; -// ASSERT_EQ(1, layer1[0].length()); -// ASSERT_EQ(8, layer1[0].at(0).protoContext().nodeId()); + // checking before last layer + auto layer3 = optimized.layer(3); -// // sequence = layer1[1]; -// ASSERT_EQ(1, layer1[1].length()); -// ASSERT_EQ(9, layer1[1].at(0).protoContext().nodeId()); + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer3.width()); + // sequence = layer3[0]; -// // checking midle layer -// auto layer2 = optimized.layer(2); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[0].length()); + ASSERT_EQ(13, layer3[0].at(0).protoContext().nodeId()); -// // we expect layer has exactly 2 OpSequence -// ASSERT_EQ(3, layer2.width()); + // sequence = layer3[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[1].length()); + ASSERT_EQ(14, layer3[1].at(0).protoContext().nodeId()); -// // sequence = layer2[0]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[0].length()); -// ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); + // checking last layer + auto layer4 = optimized.layer(4); -// // sequence = layer2[1]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[1].length()); -// ASSERT_EQ(11, layer2[1].at(0).protoContext().nodeId()); + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(1, layer4.width()); + // sequence = layer4[0]; -// // sequence = layer2[2]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[2].length()); -// ASSERT_EQ(12, layer2[2].at(0).protoContext().nodeId()); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer4[0].length()); + ASSERT_EQ(15, layer4[0].at(0).protoContext().nodeId()); +} -// // checking before last layer -// auto layer3 = optimized.layer(3); +TEST_F(GraphAnalysisTests, optimizedGraph_7) { + Graph graph; -// // we expect layer has exactly 2 OpSequence -// ASSERT_EQ(2, layer3.width()); -// // sequence = layer3[0]; + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer3[0].length()); -// ASSERT_EQ(13, layer3[0].at(0).protoContext().nodeId()); + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::subtract(), "c"); + Node d(sd::ops::add(), "d"); + Node e(sd::ops::multiply(), "e"); -// // sequence = layer3[1]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer3[1].length()); -// ASSERT_EQ(14, layer3[1].at(0).protoContext().nodeId()); + graph.addNode(a, {"A", "B"}); -// // checking last layer -// auto layer4 = optimized.layer(4); + graph.addNode(b, {"a", "C"}); -// // we expect layer has exactly 2 OpSequence -// ASSERT_EQ(1, layer4.width()); -// // sequence = layer4[0]; + graph.addNode(c, {"a", "b"}); -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer4[0].length()); -// ASSERT_EQ(15, layer4[0].at(0).protoContext().nodeId()); -// } + graph.addNode(d, {"b", "c"}); -// TEST_F(GraphAnalysisTests, basic_toposort_test_7) { -// Graph graph; + graph.addNode(e, {"b", "c", "d"}); -// // A -// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // we just check that nodes were really added + ASSERT_EQ(5, graph.size()); -// // B -// graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); + const auto& optimized = graph.optimizedGraph(); -// // C -// graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // we expect that OptimizedGraph has exactly 3 layer + ASSERT_EQ(5, optimized.numOfLayers()); -// Node a(sd::ops::multiply(), "a"); -// Node b(sd::ops::add(), "b"); -// Node c(sd::ops::subtract(), "c"); -// Node d(sd::ops::add(), "d"); -// Node e(sd::ops::multiply(), "e"); + // checking first layer first + auto layer0 = optimized.layer(0); -// graph.addNode(a, {"A", "B"}); -// graph.addNode(b, {"a", "C"}); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer0.width()); + // auto sequence = layer0[0]; -// graph.addNode(c, {"a", "b"}); -// graph.addNode(d, {"b", "c"}); + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, layer0[0].length()); + ASSERT_EQ(4, layer0[0].at(0).protoContext().nodeId()); -// graph.addNode(e, {"b", "c", "d"}); + // checking second layer now + auto layer1 = optimized.layer(1); -// // we just check that nodes were really added -// ASSERT_EQ(5, graph.size()); + // we expect layer has exactly 2 OpSequences + ASSERT_EQ(1, layer1.width()); + // sequence = layer1[0]; -// auto optimized = graph.optimizedGraph(); + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(5, layer1[0].at(0).protoContext().nodeId()); -// // we expect that OptimizedGraph has exactly 3 layer -// ASSERT_EQ(5, optimized.numOfLayers()); + // checking layer 2 + auto layer2 = optimized.layer(2); -// // checking first layer first -// auto layer0 = optimized.layer(0); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer2.width()); + // sequence = layer2[0]; -// // we expect layer has exactly 1 OpSequence -// ASSERT_EQ(1, layer0.width()); -// // auto sequence = layer0[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(6, layer2[0].at(0).protoContext().nodeId()); -// // we expect that OpSequence has exactly 2 ops -// ASSERT_EQ(1, layer0[0].length()); -// ASSERT_EQ(4, layer0[0].at(0).protoContext().nodeId()); + // checking layer 3 + auto layer3 = optimized.layer(3); -// // checking second layer now -// auto layer1 = optimized.layer(1); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer3.width()); + // sequence = layer3[0]; -// // we expect layer has exactly 2 OpSequences -// ASSERT_EQ(1, layer1.width()); -// // sequence = layer1[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer3[0].length()); + ASSERT_EQ(7, layer3[0].at(0).protoContext().nodeId()); -// ASSERT_EQ(1, layer1[0].length()); -// ASSERT_EQ(5, layer1[0].at(0).protoContext().nodeId()); + // checking layer 3 + auto layer4 = optimized.layer(4); -// // checking layer 2 -// auto layer2 = optimized.layer(2); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer4.width()); + // sequence = layer4[0]; -// // we expect layer has exactly 1 OpSequence -// ASSERT_EQ(1, layer2.width()); -// // sequence = layer2[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer4[0].length()); + ASSERT_EQ(8, layer4[0].at(0).protoContext().nodeId()); +} -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[0].length()); -// ASSERT_EQ(6, layer2[0].at(0).protoContext().nodeId()); +TEST_F(GraphAnalysisTests, optimizedGraph_8) { + Graph graph; -// // checking layer 3 -// auto layer3 = optimized.layer(3); + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("E", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("F", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); -// // we expect layer has exactly 1 OpSequence -// ASSERT_EQ(1, layer3.width()); -// // sequence = layer3[0]; + Node a1(sd::ops::multiply(), "a1"); + Node a2(sd::ops::add(), "a2"); + Node a3(sd::ops::add(), "a3"); -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer3[0].length()); -// ASSERT_EQ(7, layer3[0].at(0).protoContext().nodeId()); + Node b1(sd::ops::subtract(), "b1"); + Node b2(sd::ops::add(), "b2"); + Node b3(sd::ops::multiply(), "b3"); -// // checking layer 3 -// auto layer4 = optimized.layer(4); + graph.addNode(a1, {"A", "B"}); + graph.addNode(a2, {"C", "D"}); + graph.addNode(a3, {"E", "F"}); -// // we expect layer has exactly 1 OpSequence -// ASSERT_EQ(1, layer4.width()); -// // sequence = layer4[0]; + graph.addNode(b1, {"a1", "a2"}); + graph.addNode(b2, {"a1", "a2", "a3"}); + graph.addNode(b3, {"a2", "a3"}); -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer4[0].length()); -// ASSERT_EQ(8, layer4[0].at(0).protoContext().nodeId()); -// } + // we just check that nodes were really added + ASSERT_EQ(6, graph.size()); -// TEST_F(GraphAnalysisTests, basic_toposort_test_8) { -// Graph graph; + const auto& optimized = graph.optimizedGraph(); -// // A -// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // we expect that OptimizedGraph has exactly 2 layer + ASSERT_EQ(2, optimized.numOfLayers()); -// // B -// graph.addVariable("B", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // checking first layer first + auto layer0 = optimized.layer(0); -// // C -// graph.addVariable("C", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // we expect layer has exactly 3 OpSequence + ASSERT_EQ(3, layer0.width()); + // auto sequence = layer0[0]; -// // D -// graph.addVariable("D", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[0].length()); + ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); -// // E -// graph.addVariable("E", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // sequence = layer0[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[1].length()); + ASSERT_EQ(/*8*/9, layer0[1].at(0).protoContext().nodeId()); -// // F -// graph.addVariable("F", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // sequence = layer0[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer0[2].length()); + ASSERT_EQ(/*9*/8, layer0[2].at(0).protoContext().nodeId()); -// Node a1(sd::ops::multiply(), "a1"); -// Node a2(sd::ops::add(), "a2"); -// Node a3(sd::ops::add(), "a3"); + // checking second layer now + auto layer1 = optimized.layer(1); -// Node b1(sd::ops::subtract(), "b1"); -// Node b2(sd::ops::add(), "b2"); -// Node b3(sd::ops::multiply(), "b3"); + // we expect layer has exactly 3 OpSequences + ASSERT_EQ(3, layer1.width()); -// graph.addNode(a1, {"A", "B"}); -// graph.addNode(a2, {"C", "D"}); -// graph.addNode(a3, {"E", "F"}); + // sequence = layer1[0]; + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(10, layer1[0].at(0).protoContext().nodeId()); -// graph.addNode(b1, {"a1", "a2"}); -// graph.addNode(b2, {"a1", "a2", "a3"}); -// graph.addNode(b3, {"a2", "a3"}); + // sequence = layer1[1]; -// // we just check that nodes were really added -// ASSERT_EQ(6, graph.size()); + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(11, layer1[1].at(0).protoContext().nodeId()); -// auto optimized = graph.optimizedGraph(); + // sequence = layer1[2]; + ASSERT_EQ(1, layer1[2].length()); + ASSERT_EQ(12, layer1[2].at(0).protoContext().nodeId()); +} -// // we expect that OptimizedGraph has exactly 2 layer -// ASSERT_EQ(2, optimized.numOfLayers()); +TEST_F(GraphAnalysisTests, optimizedGraph_9) { + // start graph -// // checking first layer first -// auto layer0 = optimized.layer(0); + Graph graph; -// // we expect layer has exactly 3 OpSequence -// ASSERT_EQ(3, layer0.width()); -// // auto sequence = layer0[0]; + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {2, 2, 2}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {3, 3, 3}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {4, 4, 4}, sd::DataType::INT32)); -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer0[0].length()); -// ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); + Node a(sd::ops::multiply(), "a"); -// // sequence = layer0[1]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer0[1].length()); -// ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); + Node b1(sd::ops::add(), "b1"); + Node b2(sd::ops::multiply(), "b2"); + Node b3(sd::ops::subtract(), "b3"); + Node b4(sd::ops::Pow(), "b4"); -// // sequence = layer0[2]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer0[2].length()); -// ASSERT_EQ(9, layer0[2].at(0).protoContext().nodeId()); + Node c1(sd::ops::Pow(), "c1"); + Node c2(sd::ops::subtract(), "c2"); + Node c3(sd::ops::multiply(), "c3"); + Node c4(sd::ops::add(), "c4"); -// // checking second layer now -// auto layer1 = optimized.layer(1); + Node c5(sd::ops::Pow(), "c5"); + Node c6(sd::ops::subtract(), "c6"); + Node c7(sd::ops::multiply(), "c7"); + Node c8(sd::ops::add(), "c8"); -// // we expect layer has exactly 3 OpSequences -// ASSERT_EQ(3, layer1.width()); + graph.addNode(a, {"A", "B"}); -// // sequence = layer1[0]; -// ASSERT_EQ(1, layer1[0].length()); -// ASSERT_EQ(10, layer1[0].at(0).protoContext().nodeId()); + graph.addNode(b1, {"a", "C"}); + graph.addNode(b2, {"a", "C"}); + graph.addNode(b3, {"a", "C"}); + graph.addNode(b4, {"a", "C"}); -// // sequence = layer1[1]; + graph.addNode(c1, {"b1", "D"}); + graph.addNode(c2, {"b2", "D"}); + graph.addNode(c3, {"b3", "D"}); + graph.addNode(c4, {"b4", "D"}); -// ASSERT_EQ(1, layer1[1].length()); -// ASSERT_EQ(11, layer1[1].at(0).protoContext().nodeId()); + graph.addNode(c5, {"b1", "D"}); + graph.addNode(c6, {"b2", "D"}); + graph.addNode(c7, {"b3", "D"}); + graph.addNode(c8, {"b4", "D"}); -// // sequence = layer1[2]; -// ASSERT_EQ(1, layer1[2].length()); -// ASSERT_EQ(12, layer1[2].at(0).protoContext().nodeId()); -// } + // we just check that nodes were really added + ASSERT_EQ(13, graph.size()); -// TEST_F(GraphAnalysisTests, basic_toposort_test_9) { -// // start graph + const auto& optimized = graph.optimizedGraph(); -// Graph graph; + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(3, optimized.numOfLayers()); -// // A -// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + auto layer = optimized.layer(0); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer.width()); + // auto sequence = layer[0]; -// // B -// graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + // we expect that OpSequence has exactly 2 ops + ASSERT_EQ(1, layer[0].length()); + ASSERT_EQ(5, layer[0].at(0).protoContext().nodeId()); -// // C -// graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + auto layer1 = optimized.layer(1); + // we expect layer has exactly 4 OpSequence + ASSERT_EQ(4, layer1.width()); + // sequence = layer1[0]; -// // D -// graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer1[0].length()); + ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); -// Node a(sd::ops::multiply(), "a"); + // sequence = layer1[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer1[1].length()); + ASSERT_EQ(/*7*/9, layer1[1].at(0).protoContext().nodeId()); -// Node b1(sd::ops::add(), "b1"); -// Node b2(sd::ops::multiply(), "b2"); -// Node b3(sd::ops::subtract(), "b3"); -// Node b4(sd::ops::Pow(), "b4"); + // sequence = layer1[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer1[2].length()); + ASSERT_EQ(8, layer1[2].at(0).protoContext().nodeId()); -// Node c1(sd::ops::Pow(), "c1"); -// Node c2(sd::ops::subtract(), "c2"); -// Node c3(sd::ops::multiply(), "c3"); -// Node c4(sd::ops::add(), "c4"); + // sequence = layer1[3]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer1[3].length()); + ASSERT_EQ(/*9*/7, layer1[3].at(0).protoContext().nodeId()); -// Node c5(sd::ops::Pow(), "c5"); -// Node c6(sd::ops::subtract(), "c6"); -// Node c7(sd::ops::multiply(), "c7"); -// Node c8(sd::ops::add(), "c8"); + auto layer2 = optimized.layer(2); + // we expect layer has exactly 4 OpSequence + ASSERT_EQ(8, layer2.width()); + // sequence = layer2[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[0].length()); + ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); -// graph.addNode(a, {"A", "B"}); + // sequence = layer2[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[1].length()); + ASSERT_EQ(/*14*/11, layer2[1].at(0).protoContext().nodeId()); -// graph.addNode(b1, {"a", "C"}); -// graph.addNode(b2, {"a", "C"}); -// graph.addNode(b3, {"a", "C"}); -// graph.addNode(b4, {"a", "C"}); + // sequence = layer2[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[2].length()); + ASSERT_EQ(/*11*/12, layer2[2].at(0).protoContext().nodeId()); -// graph.addNode(c1, {"b1", "D"}); -// graph.addNode(c2, {"b2", "D"}); -// graph.addNode(c3, {"b3", "D"}); -// graph.addNode(c4, {"b4", "D"}); + // sequence = layer2[3]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[3].length()); + ASSERT_EQ(/*15*/13, layer2[3].at(0).protoContext().nodeId()); -// graph.addNode(c5, {"b1", "D"}); -// graph.addNode(c6, {"b2", "D"}); -// graph.addNode(c7, {"b3", "D"}); -// graph.addNode(c8, {"b4", "D"}); + // sequence = layer2[4]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[4].length()); + ASSERT_EQ(/*12*/14, layer2[4].at(0).protoContext().nodeId()); -// // we just check that nodes were really added -// ASSERT_EQ(13, graph.size()); + // sequence = layer2[5]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[5].length()); + ASSERT_EQ(16, layer2[5].at(0).protoContext().nodeId()); -// auto optimized = graph.optimizedGraph(); + // sequence = layer2[6]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[6].length()); + ASSERT_EQ(/*13*/17, layer2[6].at(0).protoContext().nodeId()); -// // we expect that OptimizedGraph has exactly 1 layer -// ASSERT_EQ(3, optimized.numOfLayers()); - -// auto layer = optimized.layer(0); -// // we expect layer has exactly 1 OpSequence -// ASSERT_EQ(1, layer.width()); -// // auto sequence = layer[0]; - -// // we expect that OpSequence has exactly 2 ops -// ASSERT_EQ(1, layer[0].length()); -// ASSERT_EQ(5, layer[0].at(0).protoContext().nodeId()); - -// auto layer1 = optimized.layer(1); -// // we expect layer has exactly 4 OpSequence -// ASSERT_EQ(4, layer1.width()); -// // sequence = layer1[0]; - -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer1[0].length()); -// ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); - -// // sequence = layer1[1]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer1[1].length()); -// ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); - -// // sequence = layer1[2]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer1[2].length()); -// ASSERT_EQ(8, layer1[2].at(0).protoContext().nodeId()); - -// // sequence = layer1[3]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer1[3].length()); -// ASSERT_EQ(9, layer1[3].at(0).protoContext().nodeId()); - -// auto layer2 = optimized.layer(2); -// // we expect layer has exactly 4 OpSequence -// ASSERT_EQ(8, layer2.width()); -// // sequence = layer2[0]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[0].length()); -// ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); - -// // sequence = layer2[1]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[1].length()); -// ASSERT_EQ(14, layer2[1].at(0).protoContext().nodeId()); - -// // sequence = layer2[2]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[2].length()); -// ASSERT_EQ(11, layer2[2].at(0).protoContext().nodeId()); - -// // sequence = layer2[3]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[3].length()); -// ASSERT_EQ(15, layer2[3].at(0).protoContext().nodeId()); - -// // sequence = layer2[4]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[4].length()); -// ASSERT_EQ(12, layer2[4].at(0).protoContext().nodeId()); - -// // sequence = layer2[5]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[5].length()); -// ASSERT_EQ(16, layer2[5].at(0).protoContext().nodeId()); - -// // sequence = layer2[6]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[6].length()); -// ASSERT_EQ(13, layer2[6].at(0).protoContext().nodeId()); - -// // sequence = layer2[7]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, layer2[7].length()); -// ASSERT_EQ(17, layer2[7].at(0).protoContext().nodeId()); -// } + // sequence = layer2[7]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, layer2[7].length()); + ASSERT_EQ(/*17*/15, layer2[7].at(0).protoContext().nodeId()); +} -// TEST_F(GraphAnalysisTests, basic_toposort_test_10) { -// Graph graph; +TEST_F(GraphAnalysisTests, optimizedGraph_10) { + Graph graph; -// // A -// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {2, 2, 2}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {3, 3, 3}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {3, 3, 3}, sd::DataType::INT32)); -// // B -// graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::multiply(), "c"); + Node d(sd::ops::subtract(), "d"); -// // C -// graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"a", "C"}); + graph.addNode(c, {"a", "D"}); + graph.addNode(d, {"a", "b", "c"}); -// // D -// graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); + // we just check that nodes were really added + ASSERT_EQ(4, graph.size()); -// Node a(sd::ops::multiply(), "a"); -// Node b(sd::ops::add(), "b"); -// Node c(sd::ops::multiply(), "c"); -// Node d(sd::ops::subtract(), "d"); + const auto& optimized = graph.optimizedGraph(); -// graph.addNode(a, {"A", "B"}); -// graph.addNode(b, {"a", "C"}); -// graph.addNode(c, {"a", "D"}); -// graph.addNode(d, {"a", "b", "c"}); + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(3, optimized.numOfLayers()); -// // we just check that nodes were really added -// ASSERT_EQ(4, graph.size()); + auto layer = optimized.layer(0); -// auto optimized = graph.optimizedGraph(); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer.width()); + auto sequence = layer[0]; -// // we expect that OptimizedGraph has exactly 1 layer -// ASSERT_EQ(3, optimized.numOfLayers()); - -// auto layer = optimized.layer(0); - -// // we expect layer has exactly 1 OpSequence -// ASSERT_EQ(1, layer.width()); -// auto sequence = layer[0]; - -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, sequence.length()); -// ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); - -// auto layer1 = optimized.layer(1); - -// // we expect layer has exactly 2 OpSequence -// ASSERT_EQ(2, layer1.width()); -// sequence = layer1[0]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, sequence.length()); -// ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); -// sequence = layer1[1]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, sequence.length()); -// ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); - -// auto layer2 = optimized.layer(2); -// // we expect layer has exactly 1 OpSequence -// ASSERT_EQ(1, layer2.width()); -// sequence = layer2[0]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, sequence.length()); -// ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); -// } + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); -// TEST_F(GraphAnalysisTests, basic_toposort_test_11) { -// Graph graph; + auto layer1 = optimized.layer(1); -// // A -// graph.addVariable("A", NDArrayFactory::create('c', {3}, {1, 1, 1})); + // we expect layer has exactly 2 OpSequence + ASSERT_EQ(2, layer1.width()); + sequence = layer1[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(/*6*/7, sequence.at(0).protoContext().nodeId()); + sequence = layer1[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(/*7*/6, sequence.at(0).protoContext().nodeId()); -// // B -// graph.addVariable("B", NDArrayFactory::create('c', {3}, {2, 2, 2})); + auto layer2 = optimized.layer(2); + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(1, layer2.width()); + sequence = layer2[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_11) { + Graph graph; + + graph.addVariable("A", NDArray('c', {3}, {1, 1, 1}, sd::DataType::INT32)); + graph.addVariable("B", NDArray('c', {3}, {2, 2, 2}, sd::DataType::INT32)); + graph.addVariable("C", NDArray('c', {3}, {3, 3, 3}, sd::DataType::INT32)); + graph.addVariable("D", NDArray('c', {3}, {3, 3, 3}, sd::DataType::INT32)); -// // C -// graph.addVariable("C", NDArrayFactory::create('c', {3}, {3, 3, 3})); + Node a(sd::ops::multiply(), "a"); + Node b(sd::ops::add(), "b"); + Node c(sd::ops::multiply(), "c"); + Node d(sd::ops::subtract(), "d"); -// // D -// graph.addVariable("D", NDArrayFactory::create('c', {3}, {3, 3, 3})); + graph.addNode(a, {"A", "B"}); + graph.addNode(b, {"A", "C"}); + graph.addNode(c, {"B", "D"}); + graph.addNode(d, {"C", "D"}); -// Node a(sd::ops::multiply(), "a"); -// Node b(sd::ops::add(), "b"); -// Node c(sd::ops::multiply(), "c"); -// Node d(sd::ops::subtract(), "d"); + // we just check that nodes were really added + ASSERT_EQ(4, graph.size()); -// graph.addNode(a, {"A", "B"}); -// graph.addNode(b, {"A", "C"}); -// graph.addNode(c, {"B", "D"}); -// graph.addNode(d, {"C", "D"}); + const auto& optimized = graph.optimizedGraph(); -// // we just check that nodes were really added -// ASSERT_EQ(4, graph.size()); + // we expect that OptimizedGraph has exactly 1 layer + ASSERT_EQ(1, optimized.numOfLayers()); -// auto optimized = graph.optimizedGraph(); + auto layer = optimized.layer(0); -// // we expect that OptimizedGraph has exactly 1 layer -// ASSERT_EQ(1, optimized.numOfLayers()); - -// auto layer = optimized.layer(0); - -// // we expect layer has exactly 1 OpSequence -// ASSERT_EQ(4, layer.width()); -// auto sequence = layer[0]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, sequence.length()); -// ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); -// sequence = layer[1]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, sequence.length()); -// ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); -// sequence = layer[2]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, sequence.length()); -// ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); -// sequence = layer[3]; -// // we expect that OpSequence has exactly 1 ops -// ASSERT_EQ(1, sequence.length()); -// ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); -// } + // we expect layer has exactly 1 OpSequence + ASSERT_EQ(4, layer.width()); + auto sequence = layer[0]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(/*5*/6, sequence.at(0).protoContext().nodeId()); + sequence = layer[1]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(/*6*/5, sequence.at(0).protoContext().nodeId()); + sequence = layer[2]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(/*7*/8, sequence.at(0).protoContext().nodeId()); + sequence = layer[3]; + // we expect that OpSequence has exactly 1 ops + ASSERT_EQ(1, sequence.length()); + ASSERT_EQ(/*8*/7, sequence.at(0).protoContext().nodeId()); +} // TEST_F(GraphAnalysisTests, test_cond_1) { // auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); From 768ffb04922de730da0bb6e1780f6d4b184cfb98 Mon Sep 17 00:00:00 2001 From: Yurii Date: Tue, 26 May 2020 14:07:23 +0300 Subject: [PATCH 128/233] - provide sort of OpSequences within ExecutionLayer - uncomment and update code in GraphExecutor::execute method Signed-off-by: Yurii --- .../include/graph/execution/ExecutionLayer.h | 8 ++ .../graph/execution/impl/ExecutionLayer.cpp | 19 ++++ .../graph/execution/impl/GraphExecutor.cpp | 24 ++--- libnd4j/include/graph/impl/OptimizedGraph.cpp | 4 + .../layers_tests/GraphAnalysisTests.cpp | 91 +++++++------------ 5 files changed, 76 insertions(+), 70 deletions(-) diff --git a/libnd4j/include/graph/execution/ExecutionLayer.h b/libnd4j/include/graph/execution/ExecutionLayer.h index 9149c1c7f4fa..83395a65d7a7 100644 --- a/libnd4j/include/graph/execution/ExecutionLayer.h +++ b/libnd4j/include/graph/execution/ExecutionLayer.h @@ -65,7 +65,15 @@ class SD_EXPORT ExecutionLayer { */ void append(OpSequence&& sequence); void append(const OpSequence& sequence); + + /** + * sort OpSequences in increasing order in respect to id of fist node in sequence + * @param sequence + */ + void sortOpSequences(); + }; + } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp index 5cb13496b0ba..bff7cea61602 100644 --- a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp +++ b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp @@ -70,5 +70,24 @@ ExecutionLayer &ExecutionLayer::operator=(ExecutionLayer &&other) noexcept { return *this; } + +//////////////////////////////////////////////////////////////////////// +void ExecutionLayer::sortOpSequences() { // bubble sort + + const int numOfOpSequences = this->width(); + + OpSequence temp; + + for (int i = 0; i < numOfOpSequences - 1; ++i) + for (int j = 0; j < numOfOpSequences - 1 - i; ++j) + if (_sequences[j][0].protoContext().nodeId() > _sequences[j + 1][0].protoContext().nodeId()) { + temp = _sequences[j]; + _sequences[j] = _sequences[j + 1]; + _sequences[j + 1] = temp; + + } +} + + } // namespace graph } // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 619b87fd43ca..5473395cbca9 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -85,20 +85,20 @@ Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, * execute them one by one sequentially */ Nd4jStatus result = Status::OK(); - // for (uint64_t l = 0; l < graph.layers(); l++) { - // const auto &layer = graph.layer(l); + for (uint64_t l = 0; l < graph.numOfLayers(); l++) { + const auto &layer = graph.layer(l); - // for (uint64_t o = 0; o < layer.width(); o++) { - // execute(layer[o], graph, proxy, -1); - // } + for (uint64_t o = 0; o < layer.width(); o++) { + execute(layer[o], graph, proxy, -1); + } - // // optionally block until all sequences in this layer processed - // if (layer.width() > 0 && numDevices > 1) - // for (uint64_t o = 0; o < layer.width(); o++) { - // result = layer[o].wait(); - // if (result != Status::OK()) return result; - // } - // } + // optionally block until all sequences in this layer processed + if (layer.width() > 0 && numDevices > 1) + for (uint64_t o = 0; o < layer.width(); o++) { + result = layer[o].wait(); + if (result != Status::OK()) return result; + } + } return result; } diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 1723e03ef82d..4e894340dd07 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -118,6 +118,10 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS _sortedGraph[p.second._layerNum].append(std::move(seq)); } + + // sort _sortedGraph + for (auto& l : _sortedGraph) + l.sortOpSequences(); } diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 4755fe62edcb..94aeeb7b5b40 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -181,12 +181,12 @@ TEST_F(GraphAnalysisTests, optimizedGraph_3) { // sequence = layer1[0]; ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(/*7*/8, layer1[0].at(0).protoContext().nodeId()); + ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); // sequence = layer1[1]; ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(/*8*/7, layer1[1].at(0).protoContext().nodeId()); + ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); // checking last layer auto layer2 = optimized.layer(2); @@ -266,10 +266,10 @@ TEST_F(GraphAnalysisTests, optimizedGraph_4) { ASSERT_EQ(9, layer1[0].at(0).protoContext().nodeId()); ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(/*10*/11, layer1[1].at(0).protoContext().nodeId()); + ASSERT_EQ(10, layer1[1].at(0).protoContext().nodeId()); ASSERT_EQ(1, layer1[2].length()); - ASSERT_EQ(/*11*/10, layer1[2].at(0).protoContext().nodeId()); + ASSERT_EQ(11, layer1[2].at(0).protoContext().nodeId()); auto layer2 = optimized.layer(2); @@ -361,12 +361,12 @@ TEST_F(GraphAnalysisTests, optimizedGraph_5) { // sequence = layer1[0]; ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(/*7*/8, layer1[0].at(0).protoContext().nodeId()); + ASSERT_EQ(7, layer1[0].at(0).protoContext().nodeId()); // sequence = layer1[1]; ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(/*8*/7, layer1[1].at(0).protoContext().nodeId()); + ASSERT_EQ(8, layer1[1].at(0).protoContext().nodeId()); // checking before last layer auto layer2 = optimized.layer(2); @@ -377,29 +377,26 @@ TEST_F(GraphAnalysisTests, optimizedGraph_5) { // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer2[0].length()); - ASSERT_EQ(/*9*/10, layer2[0].at(0).protoContext().nodeId()); + ASSERT_EQ(9, layer2[0].at(0).protoContext().nodeId()); // sequence = layer2[1]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer2[1].length()); - ASSERT_EQ(/*10*/9, layer2[1].at(0).protoContext().nodeId()); + ASSERT_EQ(10, layer2[1].at(0).protoContext().nodeId()); // checking last layer auto layer3 = optimized.layer(3); // we expect layer has exactly 2 OpSequence ASSERT_EQ(2, layer3.width()); - // sequence = layer3[0]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer3[0].length()); - ASSERT_EQ(/*11*/12, layer3[0].at(0).protoContext().nodeId()); - - // sequence = layer3[1]; + ASSERT_EQ(11, layer3[0].at(0).protoContext().nodeId()); // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer3[1].length()); - ASSERT_EQ(/*12*/11, layer3[1].at(0).protoContext().nodeId()); + ASSERT_EQ(12, layer3[1].at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, optimizedGraph_6) { @@ -466,11 +463,11 @@ TEST_F(GraphAnalysisTests, optimizedGraph_6) { // sequence = layer1[0]; ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(/*8*/9, layer1[0].at(0).protoContext().nodeId()); + ASSERT_EQ(8, layer1[0].at(0).protoContext().nodeId()); // sequence = layer1[1]; ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(/*9*/8, layer1[1].at(0).protoContext().nodeId()); + ASSERT_EQ(9, layer1[1].at(0).protoContext().nodeId()); // checking midle layer auto layer2 = optimized.layer(2); @@ -478,27 +475,23 @@ TEST_F(GraphAnalysisTests, optimizedGraph_6) { // we expect layer has exactly 2 OpSequence ASSERT_EQ(3, layer2.width()); - // sequence = layer2[0]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer2[0].length()); - ASSERT_EQ(/*10*/11, layer2[0].at(0).protoContext().nodeId()); + ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); - // sequence = layer2[1]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer2[1].length()); - ASSERT_EQ(/*11*/12, layer2[1].at(0).protoContext().nodeId()); + ASSERT_EQ(11, layer2[1].at(0).protoContext().nodeId()); - // sequence = layer2[2]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer2[2].length()); - ASSERT_EQ(/*12*/10, layer2[2].at(0).protoContext().nodeId()); + ASSERT_EQ(12, layer2[2].at(0).protoContext().nodeId()); // checking before last layer auto layer3 = optimized.layer(3); // we expect layer has exactly 2 OpSequence ASSERT_EQ(2, layer3.width()); - // sequence = layer3[0]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer3[0].length()); @@ -514,7 +507,6 @@ TEST_F(GraphAnalysisTests, optimizedGraph_6) { // we expect layer has exactly 2 OpSequence ASSERT_EQ(1, layer4.width()); - // sequence = layer4[0]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer4[0].length()); @@ -646,21 +638,18 @@ TEST_F(GraphAnalysisTests, optimizedGraph_8) { // we expect layer has exactly 3 OpSequence ASSERT_EQ(3, layer0.width()); - // auto sequence = layer0[0]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer0[0].length()); ASSERT_EQ(7, layer0[0].at(0).protoContext().nodeId()); - // sequence = layer0[1]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer0[1].length()); - ASSERT_EQ(/*8*/9, layer0[1].at(0).protoContext().nodeId()); + ASSERT_EQ(8, layer0[1].at(0).protoContext().nodeId()); - // sequence = layer0[2]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer0[2].length()); - ASSERT_EQ(/*9*/8, layer0[2].at(0).protoContext().nodeId()); + ASSERT_EQ(9, layer0[2].at(0).protoContext().nodeId()); // checking second layer now auto layer1 = optimized.layer(1); @@ -668,16 +657,12 @@ TEST_F(GraphAnalysisTests, optimizedGraph_8) { // we expect layer has exactly 3 OpSequences ASSERT_EQ(3, layer1.width()); - // sequence = layer1[0]; ASSERT_EQ(1, layer1[0].length()); ASSERT_EQ(10, layer1[0].at(0).protoContext().nodeId()); - // sequence = layer1[1]; - ASSERT_EQ(1, layer1[1].length()); ASSERT_EQ(11, layer1[1].at(0).protoContext().nodeId()); - // sequence = layer1[2]; ASSERT_EQ(1, layer1[2].length()); ASSERT_EQ(12, layer1[2].at(0).protoContext().nodeId()); } @@ -737,7 +722,6 @@ TEST_F(GraphAnalysisTests, optimizedGraph_9) { auto layer = optimized.layer(0); // we expect layer has exactly 1 OpSequence ASSERT_EQ(1, layer.width()); - // auto sequence = layer[0]; // we expect that OpSequence has exactly 2 ops ASSERT_EQ(1, layer[0].length()); @@ -746,16 +730,14 @@ TEST_F(GraphAnalysisTests, optimizedGraph_9) { auto layer1 = optimized.layer(1); // we expect layer has exactly 4 OpSequence ASSERT_EQ(4, layer1.width()); - // sequence = layer1[0]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer1[0].length()); ASSERT_EQ(6, layer1[0].at(0).protoContext().nodeId()); - // sequence = layer1[1]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer1[1].length()); - ASSERT_EQ(/*7*/9, layer1[1].at(0).protoContext().nodeId()); + ASSERT_EQ(7, layer1[1].at(0).protoContext().nodeId()); // sequence = layer1[2]; // we expect that OpSequence has exactly 1 ops @@ -765,50 +747,43 @@ TEST_F(GraphAnalysisTests, optimizedGraph_9) { // sequence = layer1[3]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer1[3].length()); - ASSERT_EQ(/*9*/7, layer1[3].at(0).protoContext().nodeId()); + ASSERT_EQ(9, layer1[3].at(0).protoContext().nodeId()); auto layer2 = optimized.layer(2); // we expect layer has exactly 4 OpSequence ASSERT_EQ(8, layer2.width()); - // sequence = layer2[0]; + // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer2[0].length()); ASSERT_EQ(10, layer2[0].at(0).protoContext().nodeId()); - // sequence = layer2[1]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer2[1].length()); - ASSERT_EQ(/*14*/11, layer2[1].at(0).protoContext().nodeId()); + ASSERT_EQ(11, layer2[1].at(0).protoContext().nodeId()); - // sequence = layer2[2]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer2[2].length()); - ASSERT_EQ(/*11*/12, layer2[2].at(0).protoContext().nodeId()); + ASSERT_EQ(12, layer2[2].at(0).protoContext().nodeId()); - // sequence = layer2[3]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer2[3].length()); - ASSERT_EQ(/*15*/13, layer2[3].at(0).protoContext().nodeId()); + ASSERT_EQ(13, layer2[3].at(0).protoContext().nodeId()); - // sequence = layer2[4]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer2[4].length()); - ASSERT_EQ(/*12*/14, layer2[4].at(0).protoContext().nodeId()); + ASSERT_EQ(14, layer2[4].at(0).protoContext().nodeId()); - // sequence = layer2[5]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer2[5].length()); - ASSERT_EQ(16, layer2[5].at(0).protoContext().nodeId()); + ASSERT_EQ(15, layer2[5].at(0).protoContext().nodeId()); - // sequence = layer2[6]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer2[6].length()); - ASSERT_EQ(/*13*/17, layer2[6].at(0).protoContext().nodeId()); + ASSERT_EQ(16, layer2[6].at(0).protoContext().nodeId()); - // sequence = layer2[7]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, layer2[7].length()); - ASSERT_EQ(/*17*/15, layer2[7].at(0).protoContext().nodeId()); + ASSERT_EQ(17, layer2[7].at(0).protoContext().nodeId()); } TEST_F(GraphAnalysisTests, optimizedGraph_10) { @@ -854,11 +829,11 @@ TEST_F(GraphAnalysisTests, optimizedGraph_10) { sequence = layer1[0]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(/*6*/7, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); sequence = layer1[1]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(/*7*/6, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); auto layer2 = optimized.layer(2); // we expect layer has exactly 1 OpSequence @@ -902,19 +877,19 @@ TEST_F(GraphAnalysisTests, optimizedGraph_11) { auto sequence = layer[0]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(/*5*/6, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(5, sequence.at(0).protoContext().nodeId()); sequence = layer[1]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(/*6*/5, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(6, sequence.at(0).protoContext().nodeId()); sequence = layer[2]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(/*7*/8, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(7, sequence.at(0).protoContext().nodeId()); sequence = layer[3]; // we expect that OpSequence has exactly 1 ops ASSERT_EQ(1, sequence.length()); - ASSERT_EQ(/*8*/7, sequence.at(0).protoContext().nodeId()); + ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); } // TEST_F(GraphAnalysisTests, test_cond_1) { From 63f623b4e126791cefe8131d4a3316a4a4736694 Mon Sep 17 00:00:00 2001 From: Yurii Date: Tue, 26 May 2020 18:21:54 +0300 Subject: [PATCH 129/233] - correct topo sort code: now take into account that nodes id may be present in variableSpace Signed-off-by: Yurii --- libnd4j/include/graph/OptimizedGraph.h | 9 +-- libnd4j/include/graph/impl/Graph.cpp | 10 ++- libnd4j/include/graph/impl/OptimizedGraph.cpp | 38 ++++++++--- libnd4j/include/system/pointercast.h | 4 +- .../layers_tests/GraphAnalysisTests.cpp | 68 ++++++++++--------- 5 files changed, 74 insertions(+), 55 deletions(-) diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index 69b51048d132..dff2494ce6fe 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -69,6 +69,11 @@ class SD_EXPORT OptimizedGraph { // move assignment operator OptimizedGraph& operator=(OptimizedGraph&& other) noexcept; + /** + * prints out graph content + */ + void printOut() const; + }; @@ -137,10 +142,6 @@ class SD_EXPORT OptimizedGraph { // */ // size_t size() const; -// /** -// * This method prints out graph content -// */ -// void printOut() const; // protected: // /* diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 3cc0531c871c..28268ee539fd 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -273,12 +273,10 @@ void Graph::printOut() { fflush(stdout); - // if (size() > 0) { - // nd4j_printf("\nPrinting out Nodes...\n", ""); - - // // since we need structure - we'll print out nodes of OptimizedGraph - // optimizedGraph().printOut(); - // } + if (size() > 0) { + nd4j_printf("\nPrinting out Nodes...\n", ""); + optimizedGraph().printOut(); + } } Nd4jStatus Graph::validateNode(Node *node) { diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 4e894340dd07..861d616a9eb9 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -61,10 +61,11 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS for (const auto& p : inMap) { for (const auto& v : p.second.input()) - if (!varSpace.hasVariable(v.first)) { + if (v.first >= inMap.begin()->first) { workMap[p.first]._in.push_back(v.first); workMap[v.first]._out.push_back(p.first); } + if(workMap[p.first]._in.empty()) startNodes.push_back(p.first); } @@ -86,6 +87,7 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS } } + // delete nodes present in _opSeq, their ids are already stored in nodesToDelete for (const auto& i : nodesToDelete) workMap.erase(i); @@ -106,6 +108,7 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS for (const auto& nextId : workMap[id]._out) visit(nextId, 1, numOfLayers); + // fill _sortedGraph _sortedGraph = std::vector(numOfLayers+1); for (const auto& p : workMap) { @@ -120,8 +123,8 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS } // sort _sortedGraph - for (auto& l : _sortedGraph) - l.sortOpSequences(); + // for (auto& l : _sortedGraph) + // l.sortOpSequences(); } @@ -141,8 +144,30 @@ size_t OptimizedGraph::size() const { return size; } +void OptimizedGraph::printOut() const { + + for (uint i = 0; i < _sortedGraph.size(); ++i) { + printf("Layer [%u]\n", i); + for (uint j = 0; j < _sortedGraph[i].width(); ++j) + _sortedGraph[i][j].printOut(); + } + + printf("And simple print:\n"); + for (int i = 0; i < _sortedGraph.size(); ++i) { + printf("layer %i: ", i); + for (int j = 0; j < _sortedGraph[i].width(); ++j) { + printf("("); + for (int k = 0; k < _sortedGraph[i][j].length(); ++k) { + printf("%i, ", _sortedGraph[i][j][k].protoContext().nodeId()); + } + printf("), "); + } + printf("\n"); + } +} +// std::for_each(inMap.begin(), inMap.end(), [] (const std::pair &p) {printf("node id %i \n", p.first);}); // _sortedGraph = std::vector(numOfLayers+1); // std::vector> printGraph = std::vector>(numOfLayers+1); // for (const auto& p : workMap) { @@ -495,13 +520,6 @@ size_t OptimizedGraph::size() const { // return true; // } -// void OptimizedGraph::printOut() const { -// for (uint64_t o = 0; o < _onion.size(); o++) { -// const auto& layer = _onion.at(o); -// printf("Layer [%lu]\n", o); -// for (uint64_t l = 0; l < layer.width(); l++) layer.at(l).printOut(); -// } -// } } // namespace graph diff --git a/libnd4j/include/system/pointercast.h b/libnd4j/include/system/pointercast.h index aedd9afbe9e5..ead4f4f70b5a 100644 --- a/libnd4j/include/system/pointercast.h +++ b/libnd4j/include/system/pointercast.h @@ -69,8 +69,8 @@ typedef int Nd4jStatus; #elif __GNUC__ -#include -#define MAP_IMPL std::unordered_map +#include +#define MAP_IMPL std::map #else diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 94aeeb7b5b40..edf60f3e9442 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -892,36 +892,38 @@ TEST_F(GraphAnalysisTests, optimizedGraph_11) { ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); } -// TEST_F(GraphAnalysisTests, test_cond_1) { -// auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); - -// auto optimized = graph.optimizedGraph(); -// /* -// some infor that would be useful for implementation -// currently on optimization graph is passing next data - -// Node name: cond/switch_f; ID: 11; Input: 9, 0; Operation type: 21; Operation -// class: -1719689536 Node name: cond/switch_t; ID: 10; Input: 9, 1; Operation -// type: 21; Operation class: -1719689536 Node name: cond/Switch; ID: 9; -// Input: 1, 0; Operation type: 119; Operation class: -1719689536 Node name: -// cond/Switch; ID: 9; Input: 6, 0; Operation type: 119; Operation class: -// -1719689536 Node name: cond/Merge; ID: 8; Input: 5, 0; Operation type: -// 119; Operation class: -1719689536 Node name: cond/Merge; ID: 8; Input: 7, -// 0; Operation type: 119; Operation class: -1719689536 Node name: in_0/read; ID: -// 6; Input: 1, 0; Operation type: 21; Operation class: -1719689536 Node name: -// cond/LinSpace; ID: 7; Input: 2, 0; Operation type: 21; Operation class: -// -1719689536 Node name: cond/LinSpace; ID: 7; Input: 3, 0; Operation type: 21; -// Operation class: -1719689536 Node name: cond/LinSpace; ID: 7; Input: 4, 0; -// Operation type: 21; Operation class: -1719689536 - -// as it can be seen cond/LinSpace is not connected with any switch node(s) that -// causes wrong results of optimization. also maybe to cover all conditional -// operations will be need "Operation class", but this have to discovered deeper. - -// All above is true for test_cond_2 -// */ -// } - -// TEST_F(GraphAnalysisTests, test_cond_2) { -// auto graph = Graph::fromFlatBuffers("resources/cond_false.fb"); -// } +TEST_F(GraphAnalysisTests, test_cond_1) { + auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); + + const auto& optimized = graph.optimizedGraph(); + // graph.printOut(); + /* + some infor that would be useful for implementation + currently on optimization graph is passing next data + + Node name: cond/switch_f; ID: 11; Input: 9, 0; Operation type: 21; Operation + class: -1719689536 Node name: cond/switch_t; ID: 10; Input: 9, 1; Operation + type: 21; Operation class: -1719689536 Node name: cond/Switch; ID: 9; + Input: 1, 0; Operation type: 119; Operation class: -1719689536 Node name: + cond/Switch; ID: 9; Input: 6, 0; Operation type: 119; Operation class: + -1719689536 Node name: cond/Merge; ID: 8; Input: 5, 0; Operation type: + 119; Operation class: -1719689536 Node name: cond/Merge; ID: 8; Input: 7, + 0; Operation type: 119; Operation class: -1719689536 Node name: in_0/read; ID: + 6; Input: 1, 0; Operation type: 21; Operation class: -1719689536 Node name: + cond/LinSpace; ID: 7; Input: 2, 0; Operation type: 21; Operation class: + -1719689536 Node name: cond/LinSpace; ID: 7; Input: 3, 0; Operation type: 21; + Operation class: -1719689536 Node name: cond/LinSpace; ID: 7; Input: 4, 0; + Operation type: 21; Operation class: -1719689536 + + as it can be seen cond/LinSpace is not connected with any switch node(s) that + causes wrong results of optimization. also maybe to cover all conditional + operations will be need "Operation class", but this have to discovered deeper. + + All above is true for test_cond_2 + */ +} + +TEST_F(GraphAnalysisTests, test_cond_2) { + auto graph = Graph::fromFlatBuffers("resources/cond_false.fb"); + // graph.printOut(); +} From 375b64421c4029c7583ed2eb3ac0c0a3d3d5e8d4 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 26 May 2020 19:48:46 +0300 Subject: [PATCH 130/233] minor tests CMakeLists.txt fix Signed-off-by: raver119@gmail.com --- libnd4j/tests_cpu/layers_tests/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt index 5ae202542e7e..175afbda6a22 100644 --- a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt +++ b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt @@ -61,7 +61,7 @@ else() endif() if (SD_CPU AND SD_SANITIZE) - set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS} -fsanitize=address") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address") else() # CUDA? endif() From c4e1c72049e38fd6a7c2cd5a0ddd1e65968b15fc Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 26 May 2020 20:01:22 +0300 Subject: [PATCH 131/233] - id propagation fixed for Logic ops - id propagation fix for copy/move constructor/operator Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/impl/Node.cpp | 41 +++++++++++++++++------------ 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index aaae8d129576..d5dc65af2230 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -555,34 +555,34 @@ Node::Node(const FlatNode *node) { for (auto v : _dimensions) block.appendA(v); if (node->extraParams() != nullptr && node->extraParams()->size() > 0) - for (int e = 0; e < (int)node->extraParams()->size(); e++) { + for (int e = 0; e < (int) node->extraParams()->size(); e++) { block.appendT(static_cast(node->extraParams()->Get(e))); } if (node->extraBools() != nullptr && node->extraBools()->size() > 0) - for (int e = 0; e < (int)node->extraBools()->size(); e++) { + for (int e = 0; e < (int) node->extraBools()->size(); e++) { block.appendB(node->extraBools()->Get(e)); } if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0) - for (int e = 0; e < (int)node->extraInteger()->size(); e++) { + for (int e = 0; e < (int) node->extraInteger()->size(); e++) { block.appendI(node->extraInteger()->Get(e)); } if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { - for (int e = 0; e < (int)node->extraTypes()->size(); e++) { - block.appendD((sd::DataType)node->extraTypes()->Get(e)); + for (int e = 0; e < (int) node->extraTypes()->size(); e++) { + block.appendD((sd::DataType) node->extraTypes()->Get(e)); } } this->setContextPrototype(block); this->setCustomOp(Node::buildOpByType( - _opType, (int)node->input()->size(), - (int)block.getIArguments().size(), - (int)block.getTArguments().size(), (int)_opNum)); + _opType, (int) node->input()->size(), + (int) block.getIArguments().size(), + (int) block.getTArguments().size(), (int) _opNum)); block.setOpDescriptor(this->customOp()->getOpDescriptor()); } else if (node->inputPaired() != nullptr && - node->inputPaired()->size() > 0) { + node->inputPaired()->size() > 0) { ContextPrototype block(nullptr, this->id(), false); for (int e = 0; e < this->input().size(); e++) { @@ -593,34 +593,37 @@ Node::Node(const FlatNode *node) { for (auto v : _dimensions) block.appendA(v); if (node->extraParams() != nullptr && node->extraParams()->size() > 0) - for (int e = 0; e < (int)node->extraParams()->size(); e++) { + for (int e = 0; e < (int) node->extraParams()->size(); e++) { block.appendT(static_cast(node->extraParams()->Get(e))); } if (node->extraBools() != nullptr && node->extraBools()->size() > 0) - for (int e = 0; e < (int)node->extraBools()->size(); e++) { + for (int e = 0; e < (int) node->extraBools()->size(); e++) { block.appendB(node->extraBools()->Get(e)); } if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0) - for (int e = 0; e < (int)node->extraInteger()->size(); e++) { + for (int e = 0; e < (int) node->extraInteger()->size(); e++) { block.appendI(node->extraInteger()->Get(e)); } if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { - for (int e = 0; e < (int)node->extraTypes()->size(); e++) { - block.appendD((sd::DataType)node->extraTypes()->Get(e)); + for (int e = 0; e < (int) node->extraTypes()->size(); e++) { + block.appendD((sd::DataType) node->extraTypes()->Get(e)); } } this->setContextPrototype(block); this->setCustomOp(Node::buildOpByType( - _opType, (int)node->inputPaired()->size(), - (int)block.getIArguments().size(), - (int)block.getTArguments().size(), (int)_opNum)); + _opType, (int) node->inputPaired()->size(), + (int) block.getIArguments().size(), + (int) block.getTArguments().size(), (int) _opNum)); block.setOpDescriptor(this->customOp()->getOpDescriptor()); } + } else if (this->_opType == OpType_LOGIC) { + ContextPrototype block(nullptr, this->id()); + this->setContextPrototype(block); } else if (this->_opType == OpType_CUSTOM) { auto op = sd::ops::OpRegistrator::getInstance()->getOperation(this->opNum()); @@ -705,6 +708,7 @@ Node::Node(const Node &other) noexcept { _scope_id = other._scope_id; _scope_name = other._scope_name; _rewindNode = other._rewindNode; + _id = other._id; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; @@ -738,6 +742,7 @@ Node &Node::operator=(const Node &other) noexcept { _scope_id = other._scope_id; _scope_name = other._scope_name; _rewindNode = other._rewindNode; + _id = other._id; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; @@ -771,6 +776,7 @@ Node::Node(Node &&other) noexcept { _name = std::move(other._name); _scope_name = std::move(other._scope_name); _rewindNode = other._rewindNode; + _id = other._id; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; @@ -806,6 +812,7 @@ Node &Node::operator=(Node &&other) noexcept { _name = std::move(other._name); _scope_name = std::move(other._scope_name); _rewindNode = other._rewindNode; + _id = other._id; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; From eaffa43ebe34db6d0a2bca48800771ceb7c05c08 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 26 May 2020 20:25:50 +0300 Subject: [PATCH 132/233] - inputs propagation fixed for Logic ops - Node name is propagated to ContextPrototype now Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/impl/Node.cpp | 11 +++++++++++ libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index d5dc65af2230..4ee866c6bbb6 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -551,6 +551,8 @@ Node::Node(const FlatNode *node) { if (node->input() != nullptr && node->input()->size() > 0) { ContextPrototype block(nullptr, this->id(), false); + if (!this->name().empty()) + block.setName(this->name()); for (auto v : _dimensions) block.appendA(v); @@ -623,6 +625,13 @@ Node::Node(const FlatNode *node) { } } else if (this->_opType == OpType_LOGIC) { ContextPrototype block(nullptr, this->id()); + if (!this->name().empty()) + block.setName(this->name()); + + for (int e = 0; e < this->input().size(); e++) { + block.pickInput(this->input().at(e)); + } + this->setContextPrototype(block); } else if (this->_opType == OpType_CUSTOM) { auto op = @@ -633,6 +642,8 @@ Node::Node(const FlatNode *node) { } ContextPrototype block(nullptr, this->id()); + if (!this->name().empty()) + block.setName(this->name()); for (int e = 0; e < this->input().size(); e++) { block.pickInput(this->input().at(e)); diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index edf60f3e9442..5117e1e1d430 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -896,7 +896,7 @@ TEST_F(GraphAnalysisTests, test_cond_1) { auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); const auto& optimized = graph.optimizedGraph(); - // graph.printOut(); + graph.printOut(); /* some infor that would be useful for implementation currently on optimization graph is passing next data From e5191e8abfdfeadccb1b4a4dc40945857b5c733d Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Wed, 27 May 2020 17:14:31 +0300 Subject: [PATCH 133/233] dependencies Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/Node.h | 8 +++++++ libnd4j/include/graph/Variable.h | 8 +++++++ libnd4j/include/graph/impl/Graph.cpp | 7 ++++++ libnd4j/include/graph/impl/Node.cpp | 29 +++++++++++++++++++++++++ libnd4j/include/graph/impl/Variable.cpp | 29 +++++++++++++++++++++++++ 5 files changed, 81 insertions(+) diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 1dc5ad7a9526..0f2e78488771 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -45,8 +45,12 @@ class SD_EXPORT Node { ContextPrototype _protoContext; Nd4jLong _opNum; int _id = 0; + std::vector> _input; std::vector> _output; + std::vector> _dependencies; + std::vector _stringDependencies; + std::vector _dimensions; std::vector _referencedBy; @@ -152,6 +156,7 @@ class SD_EXPORT Node { int id() const; const std::vector> &input() const; const std::vector> &output() const; + const std::vector> &dependencies() const; Nd4jLong getFrameId(); void setFrameId(Nd4jLong frameId); @@ -233,6 +238,9 @@ class SD_EXPORT Node { template Node *asT(); + // this method converts string deps to int deps + void actualizeDependencies(const MAP_IMPL &lookupTable) const; + FORCEINLINE void pullValues(Node *other) { this->_dataType = other->dataType(); this->_protoContext = other->protoContext(); diff --git a/libnd4j/include/graph/Variable.h b/libnd4j/include/graph/Variable.h index 4a3e241addb3..bd2d8634d799 100644 --- a/libnd4j/include/graph/Variable.h +++ b/libnd4j/include/graph/Variable.h @@ -76,6 +76,9 @@ class SD_EXPORT Variable { VariableType _variableType = VariableType::NDARRAY; + std::vector> _dependencies; + std::vector _stringDependencies; + public: explicit Variable(bool placeHolder, DataType dataType = DataType::ANY, const std::vector &shape = {}); @@ -130,6 +133,11 @@ class SD_EXPORT Variable { const std::vector &shape() const; DataType dataType() const; + const std::vector>& dependencies() const; + + // this method converts string deps to int deps + void actualizeDependencies(const MAP_IMPL &lookupTable) const; + #ifndef __JAVACPP_HACK__ /** * This method returns offset to this Variable in FlatBuffer diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 28268ee539fd..d7ba681a9faa 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -191,6 +191,13 @@ Graph::Graph(const FlatGraph *flatGraph, const GraphMemoryManager &memoryManager _symbolicLookupTable[nnode.name()] = nnode.id(); } } + + // now, once everything is deserializerd, time to roll through Variables/Nodes and update dependencies + for (const auto &v: _unmapped) + v.second.actualizeDependencies(_symbolicLookupTable); + + for (const auto &v:_variableSpace.variables()) + v->actualizeDependencies(_symbolicLookupTable); } /** diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 4ee866c6bbb6..6094016ac02c 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -384,6 +384,19 @@ Node::Node(sd::ops::DeclarableOp *customOp, int id, void Node::setOpType(OpType opType) { this->_opType = opType; } +const std::vector>& Node::dependencies() const { + return _dependencies; +} + +void Node::actualizeDependencies(const MAP_IMPL &lookupTable) const { + for (const auto &v: _stringDependencies) { + if (lookupTable.count(v) == 0) + throw std::runtime_error("Unknown Node dependency found: [" + v + "]"); + + const_cast(this)->_dependencies.emplace_back(std::pair{lookupTable.at(v), 0}); + } +} + Node::Node(OpType opType, int opNum, int id, std::initializer_list input, std::initializer_list output, std::initializer_list dimensions, float scalar, @@ -503,6 +516,22 @@ Node::Node(const FlatNode *node) { } } + // reading control deps, and filling _dependencies field + if (node->varControlDeps() != nullptr && node->varControlDeps()->size() > 0) { + for (int e = 0; e < node->varControlDeps()->size(); e++) + _stringDependencies.emplace_back(node->varControlDeps()->Get(e)->str()); + } + + if (node->controlDepFor() != nullptr && node->controlDepFor()->size() > 0) { + for (int e = 0; e < node->controlDepFor()->size(); e++) + _stringDependencies.emplace_back(node->controlDepFor()->Get(e)->str()); + } + + if (node->controlDeps() != nullptr && node->controlDeps()->size() > 0) { + for (int e = 0; e < node->controlDeps()->size(); e++) + _stringDependencies.emplace_back(node->controlDeps()->Get(e)->str()); + } + /* if (node->output() != nullptr) for (int e = 0; e < (int) node->output()->size(); e++) { diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index ad80bce30203..40eeb3df150e 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -147,6 +147,19 @@ void sd::graph::Variable::setNDArrayList( this->_list = list; } +const std::vector>& Variable::dependencies() const { + return _dependencies; +} + +void Variable::actualizeDependencies(const MAP_IMPL &lookupTable) const { + for (const auto &v: _stringDependencies) { + if (lookupTable.count(v) == 0) + throw std::runtime_error("Unknown Variable dependency found: [" + v + "]"); + + const_cast(this)->_dependencies.emplace_back(std::pair{lookupTable.at(v), 0}); + } +} + void sd::graph::Variable::setNDArray(std::shared_ptr array) { this->_variableType = VariableType::NDARRAY; this->_ndarray = array; @@ -167,6 +180,22 @@ sd::graph::Variable::Variable(const sd::graph::FlatVariable *flatVariable) { int8_t *buffer = nullptr; + // reading control deps, and filling _dependencies field + if (flatVariable->controlDepsForVar() != nullptr && flatVariable->controlDepsForVar()->size() > 0) { + for (int e = 0; e < flatVariable->controlDepsForVar()->size(); e++) + _stringDependencies.emplace_back(flatVariable->controlDepsForVar()->Get(e)->str()); + } + + if (flatVariable->controlDepForOp() != nullptr && flatVariable->controlDepForOp()->size() > 0) { + for (int e = 0; e < flatVariable->controlDepForOp()->size(); e++) + _stringDependencies.emplace_back(flatVariable->controlDepForOp()->Get(e)->str()); + } + + if (flatVariable->controlDeps() != nullptr && flatVariable->controlDeps()->size() > 0) { + for (int e = 0; e < flatVariable->controlDeps()->size(); e++) + _stringDependencies.emplace_back(flatVariable->controlDeps()->Get(e)->str()); + } + switch (flatVariable->variabletype()) { case VarType_VARIABLE: { // ????? From 0d96f84a0a0d5acd6da6df5da669054ca48c3b5f Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 28 May 2020 10:28:58 +0300 Subject: [PATCH 134/233] minor CMakeLists fix Signed-off-by: raver119@gmail.com --- libnd4j/blas/CMakeLists.txt | 2 +- libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index 8419cdd4cfd6..7ebc90bb5a49 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -297,7 +297,7 @@ elseif(SD_CPU) file(GLOB_RECURSE COMPILATION_UNITS false ../include/ops/declarable/helpers/cpu/compilation_units/*.cpp.in) foreach(FL_ITEM ${COMPILATION_UNITS}) - string(REGEX MATCH "^(.*)\\.cpp\.in$" dummy ${FL_ITEM}) + string(REGEX MATCH "^(.*)\\.cpp\\.in$" dummy ${FL_ITEM}) set(FL_ITEM_WLE ${CMAKE_MATCH_1}) foreach(FL_TYPE_INDEX RANGE 0 9) #message( "${FL_ITEM_WLE}_${FL_TYPE_INDEX}.cpp") diff --git a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt index 92084ef74d2c..cafd5c52d863 100644 --- a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt +++ b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt @@ -227,7 +227,7 @@ endif() file(GLOB_RECURSE COMPILATION_UNITS false ../../include/ops/declarable/helpers/cpu/compilation_units/*.cpp.in) foreach(FL_ITEM ${COMPILATION_UNITS}) - string(REGEX MATCH "^(.*)\\.cpp\.in$" dummy ${FL_ITEM}) + string(REGEX MATCH "^(.*)\\.cpp\\.in$" dummy ${FL_ITEM}) set(FL_ITEM_WLE ${CMAKE_MATCH_1}) foreach(FL_TYPE_INDEX RANGE 0 9) #message( "${FL_ITEM_WLE}_${FL_TYPE_INDEX}.cpp") From 5d829473f5e35f2c744eabcd75925289c928aee3 Mon Sep 17 00:00:00 2001 From: Yurii Date: Thu, 28 May 2020 14:36:36 +0300 Subject: [PATCH 135/233] - while sorting, take into account dependencies between variables and ops --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 28 +++++++++++-------- .../layers_tests/GraphAnalysisTests.cpp | 2 +- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 861d616a9eb9..35cba736cf9d 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -60,11 +60,20 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS std::vector startNodes; for (const auto& p : inMap) { - for (const auto& v : p.second.input()) - if (v.first >= inMap.begin()->first) { + for (const auto& v : p.second.input()) { + if (v.first >= inMap.begin()->first) { // is op workMap[p.first]._in.push_back(v.first); workMap[v.first]._out.push_back(p.first); } + else { // is variable + for (const auto& i : varSpace.getVariable(v.first).get()->dependencies()) { + if(std::find(workMap[p.first]._in.begin(), workMap[p.first]._in.end(), i.first) == workMap[p.first]._in.end()) { + workMap[p.first]._in.push_back(i.first); + workMap[i.first]._out.push_back(p.first); + } + } + } + } if(workMap[p.first]._in.empty()) startNodes.push_back(p.first); @@ -74,16 +83,11 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS std::vector nodesToDelete; for (auto& p : workMap) { - if(p.second._in.size() != 1) { - - auto& out = p.second._out; - while(out.size() == 1 && workMap[out[0]]._in.size() == 1) { - nodesToDelete.push_back(out[0]); - p.second._opSeq.push_back(out[0]); - out = workMap[out[0]]._out; - } - if(out != p.second._out) - p.second._out = std::move(out); + auto& out = p.second._out; + while(out.size() == 1 && workMap[out[0]]._in.size() == 1) { + nodesToDelete.push_back(out[0]); + p.second._opSeq.push_back(out[0]); + out = std::move(workMap[out[0]]._out); } } diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 5117e1e1d430..edf60f3e9442 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -896,7 +896,7 @@ TEST_F(GraphAnalysisTests, test_cond_1) { auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); const auto& optimized = graph.optimizedGraph(); - graph.printOut(); + // graph.printOut(); /* some infor that would be useful for implementation currently on optimization graph is passing next data From 72b83f4842904981f535962bf80813fe6e7c0aad Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 28 May 2020 14:48:05 +0300 Subject: [PATCH 136/233] while graph Signed-off-by: raver119@gmail.com --- .../layers_tests/GraphAnalysisTests.cpp | 5 +++++ libnd4j/tests_cpu/resources/while_iter1.fb | Bin 0 -> 9512 bytes .../org/nd4j/imports/TensorFlowImportTest.java | 6 ++++++ 3 files changed, 11 insertions(+) create mode 100644 libnd4j/tests_cpu/resources/while_iter1.fb diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index edf60f3e9442..f8559433aa12 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -927,3 +927,8 @@ TEST_F(GraphAnalysisTests, test_cond_2) { auto graph = Graph::fromFlatBuffers("resources/cond_false.fb"); // graph.printOut(); } + +TEST_F(GraphAnalysisTests, test_while_iter_1_1) { + auto graph = Graph::fromFlatBuffers("resources/while_iter1.fb"); + graph.printOut(); +} diff --git a/libnd4j/tests_cpu/resources/while_iter1.fb b/libnd4j/tests_cpu/resources/while_iter1.fb new file mode 100644 index 0000000000000000000000000000000000000000..d81f0b4b37703088f59c2cbea6efb7ef927d2385 GIT binary patch literal 9512 zcmchdUuYc19mhwCY$?vQQJr%Tj%crG9z-gz$+*KpYP6& zcJI6=Dbl7t_}uKy{N^`*e!rPn?W)^2Ich%3Wt-_R3Daqkrpp}Vehvm;3%a2fo&aM$ zjaG+ub=b6XCwZWy=8JdBPbaitce2s88a@2sB{)0bfjuEPeb!U`;bd|d(gd5kpd z=UMKD!HMQ*H$N+OBm6x2lruuQ7oCv0cd51&BTAe3`^PqNH)% zW__Eo700?g{p(sybMrK*&VrxZvg=LrScY}bI+G6;VIE$EG894culWv-ZH)U1;Kcbr zyP~4dk;IGc956P z<|d2z;wxNk`&IhZchCE!^}gM$T(%lTZK|G9#ib;?-pY(IOg=J9LWsolmq(jgTz0e8TN5lS9U$x(C?~5m1@%pH4W-rgILl0ssH!>V)xNcaa{t9 zX9bpE9x6}<`Je#u(GaQnk%AVJ5R&0+5#{%m{>y}*7fpPu&pnc#W>Zo8p0Bhf^_iyv0_QG-1%D$&LKSIi@7<23IHm;OchIP;xbQKmsXV9xq zh9YRc8Yb0zDNZ_J44wmtwgsX6v2gL?)adxd$*iAOA5%k=K4S|hmjv;4f0@+?eq9?% zo1H@wj4e%hA83D7pHq+o#j5(_M0TP?u^E0JvhLfItsanF%|nz;vLi+9x@Wxv>w19u z{n~3hhu&rHRo)AC?Q^`YOQ5r-3QI5#6)1!DgaXK~L!{w3I?w%$FIT_RwfUHewtcJq z7jiS+`y9=i?iD|FUl{UB$UvIq+tT1ycC^oFAGikUuSEAFpf>vd5x+kxzRms4Ki%^B z=KT5MU+sSV9a62YU;NvyFO97X8rwRoz!Jz8&DK|nGR4&xC`O!Uu6QKycqupiZJig+ z4?3f*AHJ@Vx<96NDx=zcBAXh6e9*;Bvvv0NKjZhy6aTU{kZDU=eQQWUT)P~x1>+T%(;`Hy6?8|!pdWuwQ?aF_xe>Goa z*n|y`kFUZa>}4Hi(Q~4?a_g8kBe`krm8Wq9YxP-`G+D>@mr>QTp?#xKTzg9A^n<9a z6xZ61M*RDwd-41G4mJ+XLzF8?cI*v?)dL^ zE-8LWumL(V<;NwMhcfK8-(=BoqP{na$6#OaV$M|wy$7n#mxAnRjv7Vz8n@Dleh}TK zxDLl*5A%~IzYlsLj(w+u&A8%l6I=h!IPCxElZiu}0!Q$GI|~$#!~VO*j~S20lD0;$8Ei`~pA~0`VE?QCM*TrVozv9O^UhA+GM{8UG`2NOJ$-MHU+i9M zzoUIKLib1Ls*mcc)eXM&d_O|}E%fysSaqy^{WpZyBkHs=HvOF${H^Kg@V9V;{wefz z76*UFs87Z`%35?|m5(&93!uJh-h)25_iy&S+}Th)=Z_!KhUx`*Me>bhW;NXPw}0mX{;+-#}l_*TGrj>6vebjyzjFGj-x0}~|yTg>71NlV# zDMhqLk=L$LPf|9CYHyPI+E0V_yk{SBjxr1Nv!cC4bCTtE0yIbJgZ@{^YPNeZE1mt0 SU*}v0y_YB5JnIbM*8C4hNFjaz literal 0 HcmV?d00001 diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java index 45ecc80fc325..9c002ed41b7c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TensorFlowImportTest.java @@ -1119,4 +1119,10 @@ public void testControlDependencies1() throws Exception { assertEquals(variables.get("cond/LinSpace/num"), Collections.singletonList("cond/switch_t")); assertEquals(variables.get("cond/ones"), Collections.singletonList("cond/switch_f")); } + + @Test + public void testWhile() throws Exception { + SameDiff sd = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/while1/iter_1/frozen_model.pb").getInputStream()); + sd.save(new File("../../../libnd4j/tests_cpu/resources/while_iter1.fb"), true); + } } \ No newline at end of file From ac4f755b9bd73142555237e9a9c0870da63e89d8 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 1 Jun 2020 12:32:27 +0300 Subject: [PATCH 137/233] master merged Signed-off-by: raver119@gmail.com --- libnd4j/include/helpers/cpu/loops/Reduction3Loops.cpp.in | 3 +-- libnd4j/include/helpers/cpu/loops/Reduction3Loops.hpp | 2 +- libnd4j/include/helpers/cpu/loops/ReductionLoops_float.cpp.in | 2 +- .../loops/cpu/compilation_units/broadcast_bool_p.cpp.in | 2 +- .../loops/cpu/compilation_units/broadcast_int_p.cpp.in | 2 +- .../include/loops/cpu/compilation_units/broadcast_p.cpp.in | 2 +- .../loops/cpu/compilation_units/indexreduce_int32.cpp.in | 2 +- .../loops/cpu/compilation_units/indexreduce_int64.cpp.in | 2 +- libnd4j/include/loops/cpu/compilation_units/pairwise_p.cpp.in | 2 +- libnd4j/include/loops/cpu/compilation_units/random.cpp.in | 2 +- .../loops/cpu/compilation_units/reduce3_bfloat16.cpp.in | 2 +- .../include/loops/cpu/compilation_units/reduce3_double.cpp.in | 2 +- .../include/loops/cpu/compilation_units/reduce3_float.cpp.in | 2 +- .../loops/cpu/compilation_units/reduce3_float16.cpp.in | 2 +- .../include/loops/cpu/compilation_units/reduce_float.cpp.in | 2 +- libnd4j/include/loops/cpu/compilation_units/scalar_p.cpp.in | 2 +- .../include/loops/cuda/compilation_units/broadcasting.cu.in | 2 +- libnd4j/include/loops/cuda/compilation_units/pairwise.cu.in | 2 +- libnd4j/include/loops/cuda/compilation_units/reduce3.cu.in | 2 +- .../include/loops/cuda/compilation_units/reduce_float.cu.in | 2 +- libnd4j/include/loops/cuda/compilation_units/scalar.cu.in | 2 +- .../ops/declarable/generic/linalg/matrix_band_part.cpp | 2 +- libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp | 2 +- libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp | 2 +- libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp | 2 +- libnd4j/tests_cpu/layers_tests/RNGTests.cpp | 4 ++-- 26 files changed, 27 insertions(+), 28 deletions(-) diff --git a/libnd4j/include/helpers/cpu/loops/Reduction3Loops.cpp.in b/libnd4j/include/helpers/cpu/loops/Reduction3Loops.cpp.in index 4f38b4d8f397..c762b192e7fe 100644 --- a/libnd4j/include/helpers/cpu/loops/Reduction3Loops.cpp.in +++ b/libnd4j/include/helpers/cpu/loops/Reduction3Loops.cpp.in @@ -22,6 +22,5 @@ #cmakedefine FLOAT_TYPE_GEN namespace sd { - - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); } diff --git a/libnd4j/include/helpers/cpu/loops/Reduction3Loops.hpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops.hpp index 241dc7e8cd25..d9edf2bdd2e7 100644 --- a/libnd4j/include/helpers/cpu/loops/Reduction3Loops.hpp +++ b/libnd4j/include/helpers/cpu/loops/Reduction3Loops.hpp @@ -18,9 +18,9 @@ // @author raver119@gmail.com // +#include #include #include -#include using namespace simdOps; diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float.cpp.in b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float.cpp.in index 5c1bb227d8d3..9443dc497230 100644 --- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float.cpp.in +++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float.cpp.in @@ -22,7 +22,7 @@ #cmakedefine FLOAT_TYPE_GEN namespace sd { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); } diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p.cpp.in b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p.cpp.in index b3c60462beea..7aa9e5824b52 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_bool_p.cpp.in @@ -23,6 +23,6 @@ namespace functions { namespace broadcast { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES_@FL_TYPE_INDEX@, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT BroadcastBool, , LIBND4J_TYPES_@FL_TYPE_INDEX@, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p.cpp.in b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p.cpp.in index a36c1a0b2518..d52ecccd232c 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_int_p.cpp.in @@ -23,6 +23,6 @@ namespace functions { namespace broadcast { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES_@FL_TYPE_INDEX@); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT BroadcastInt, , INTEGER_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/broadcast_p.cpp.in b/libnd4j/include/loops/cpu/compilation_units/broadcast_p.cpp.in index 1dbb4aac4006..e14963df2009 100644 --- a/libnd4j/include/loops/cpu/compilation_units/broadcast_p.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/broadcast_p.cpp.in @@ -22,6 +22,6 @@ #cmakedefine PAIRWISE_TYPE_GEN namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32.cpp.in b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32.cpp.in index 97402d38eedb..ad4df30fa112 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int32.cpp.in @@ -23,6 +23,6 @@ #cmakedefine LIBND4J_TYPE_GEN namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_@FL_TYPE_INDEX@, (sd::DataType::INT32, int32_t)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_@FL_TYPE_INDEX@, (sd::DataType::INT32, int32_t)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64.cpp.in b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64.cpp.in index 30fa30749bd9..7e4ed33957d8 100644 --- a/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/indexreduce_int64.cpp.in @@ -23,6 +23,6 @@ #cmakedefine LIBND4J_TYPE_GEN namespace functions { namespace indexreduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES_@FL_TYPE_INDEX@, (sd::DataType::INT64, Nd4jLong)); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT IndexReduce, , LIBND4J_TYPES_@FL_TYPE_INDEX@, (sd::DataType::INT64, Nd4jLong)); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/pairwise_p.cpp.in b/libnd4j/include/loops/cpu/compilation_units/pairwise_p.cpp.in index bbf809de8761..997350a3aaf4 100644 --- a/libnd4j/include/loops/cpu/compilation_units/pairwise_p.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/pairwise_p.cpp.in @@ -22,6 +22,6 @@ #cmakedefine PAIRWISE_TYPE_GEN namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/random.cpp.in b/libnd4j/include/loops/cpu/compilation_units/random.cpp.in index 921532ac881b..e6e7dac16b57 100644 --- a/libnd4j/include/loops/cpu/compilation_units/random.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/random.cpp.in @@ -22,6 +22,6 @@ #cmakedefine FLOAT_TYPE_GEN namespace functions { namespace random { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES_@FL_TYPE_INDEX@); + BUILD_SINGLE_TEMPLATE(template class SD_EXPORT RandomFunction, , FLOAT_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16.cpp.in b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16.cpp.in index 68616c3f9b74..cd2398e5b3d4 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_bfloat16.cpp.in @@ -23,6 +23,6 @@ #cmakedefine LIBND4J_TYPE_GEN namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_3); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_3); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_double.cpp.in b/libnd4j/include/loops/cpu/compilation_units/reduce3_double.cpp.in index 5c722838d1d4..a29a6952a97b 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_double.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_double.cpp.in @@ -23,6 +23,6 @@ #cmakedefine LIBND4J_TYPE_GEN namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_2); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_2); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float.cpp.in b/libnd4j/include/loops/cpu/compilation_units/reduce3_float.cpp.in index ee127c2d9b84..c0a62ad94425 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float.cpp.in @@ -23,6 +23,6 @@ #cmakedefine LIBND4J_TYPE_GEN namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_1); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_1); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16.cpp.in b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16.cpp.in index 65c2b563ad83..3699c541d2dd 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce3_float16.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/reduce3_float16.cpp.in @@ -23,6 +23,6 @@ #cmakedefine LIBND4J_TYPE_GEN namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_0); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES_@FL_TYPE_INDEX@, FLOAT_TYPES_0); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/reduce_float.cpp.in b/libnd4j/include/loops/cpu/compilation_units/reduce_float.cpp.in index 3837c7810b4d..25cefff53ee3 100644 --- a/libnd4j/include/loops/cpu/compilation_units/reduce_float.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/reduce_float.cpp.in @@ -23,6 +23,6 @@ #cmakedefine FLOAT_TYPE_GEN namespace functions { namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); } } diff --git a/libnd4j/include/loops/cpu/compilation_units/scalar_p.cpp.in b/libnd4j/include/loops/cpu/compilation_units/scalar_p.cpp.in index dc024170d8b7..768665099a76 100644 --- a/libnd4j/include/loops/cpu/compilation_units/scalar_p.cpp.in +++ b/libnd4j/include/loops/cpu/compilation_units/scalar_p.cpp.in @@ -22,6 +22,6 @@ #cmakedefine PAIRWISE_TYPE_GEN namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/broadcasting.cu.in b/libnd4j/include/loops/cuda/compilation_units/broadcasting.cu.in index 6349dcfc98b8..cd93088da9e3 100644 --- a/libnd4j/include/loops/cuda/compilation_units/broadcasting.cu.in +++ b/libnd4j/include/loops/cuda/compilation_units/broadcasting.cu.in @@ -22,6 +22,6 @@ #cmakedefine PAIRWISE_TYPE_GEN namespace functions { namespace broadcast { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT Broadcast, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/pairwise.cu.in b/libnd4j/include/loops/cuda/compilation_units/pairwise.cu.in index 312ed7416242..2903b2f2c2a4 100644 --- a/libnd4j/include/loops/cuda/compilation_units/pairwise.cu.in +++ b/libnd4j/include/loops/cuda/compilation_units/pairwise.cu.in @@ -22,6 +22,6 @@ #cmakedefine PAIRWISE_TYPE_GEN namespace functions { namespace pairwise_transforms { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT PairWiseTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT PairWiseTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce3.cu.in b/libnd4j/include/loops/cuda/compilation_units/reduce3.cu.in index dd74728369a9..92ce78a83da6 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce3.cu.in +++ b/libnd4j/include/loops/cuda/compilation_units/reduce3.cu.in @@ -22,6 +22,6 @@ #cmakedefine FLOAT_TYPE_GEN namespace functions { namespace reduce3 { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT Reduce3, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/reduce_float.cu.in b/libnd4j/include/loops/cuda/compilation_units/reduce_float.cu.in index 34c2bf8caadc..30d1c80a01be 100644 --- a/libnd4j/include/loops/cuda/compilation_units/reduce_float.cu.in +++ b/libnd4j/include/loops/cuda/compilation_units/reduce_float.cu.in @@ -22,6 +22,6 @@ #cmakedefine FLOAT_TYPE_GEN namespace functions { namespace reduce { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); + BUILD_DOUBLE_TEMPLATE(template class SD_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/compilation_units/scalar.cu.in b/libnd4j/include/loops/cuda/compilation_units/scalar.cu.in index 15608bdd1e65..4d5fbf05d9f4 100644 --- a/libnd4j/include/loops/cuda/compilation_units/scalar.cu.in +++ b/libnd4j/include/loops/cuda/compilation_units/scalar.cu.in @@ -22,6 +22,6 @@ #cmakedefine PAIRWISE_TYPE_GEN namespace functions { namespace scalar { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT ScalarTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); + BUILD_PAIRWISE_TEMPLATE(template class SD_EXPORT ScalarTransform, , PAIRWISE_TYPES_@FL_TYPE_INDEX@); } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/linalg/matrix_band_part.cpp b/libnd4j/include/ops/declarable/generic/linalg/matrix_band_part.cpp index 67153587f07e..388a23169255 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/matrix_band_part.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/matrix_band_part.cpp @@ -63,7 +63,7 @@ CONFIGURABLE_OP_IMPL(matrix_band_part, 1, 1, true, 0, 0) { return ND4J_STATUS_OK; } DECLARE_SYN(band_part, matrix_band_part); -} // namespace ops + DECLARE_TYPES(matrix_band_part) { getOpDescriptor() diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp index 73d35ec38022..ea51ca54850a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp @@ -178,7 +178,7 @@ PLATFORM_CHECK(concat, ENGINE_CPU) { const auto zType = z->dataType(); - const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); + const bool isAxisInLastArr = block.numB() == 0 ? false : B_ARG(0); const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); return z->rankOf() < 7 && numOfInArrs <= 3072 diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index d7a60611b14f..d934fc700051 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -1672,7 +1672,7 @@ TEST_F(DeclarableOpsTests10, ResizeImages_Test1) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray *result = results.at(0); + auto result = results.at(0); // result->printBuffer("Resized to 7x9"); // expected.printBuffer("Expect for 7x9"); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index 4c6d5a1c948d..b2cb13f5db6f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -1572,7 +1572,7 @@ TEST_F(DeclarableOpsTests11, ResizeImages_Test8) { ASSERT_EQ(ND4J_STATUS_OK, results.status()); - NDArray* result = results.at(0); + auto result = results.at(0); // result->printBuffer("Area Resized to 6x6"); // expected.printBuffer("Area Expect for 6x6"); diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 342a993f974d..84f29d954f62 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -1074,8 +1074,8 @@ TEST_F(RNGTests, Test_UniformDistribution_05) { ASSERT_FALSE(exp0.equalsTo(z)); sd::ops::reduce_max checkOp; - auto checkResult = checkOp.evaluate({z}); - checkResult[0]->printIndexedBuffer("Max on uniform with 0 to 1 on 100M cases is"); + auto checkResult = checkOp.evaluate({&z}); + checkResult[0].printIndexedBuffer("Max on uniform with 0 to 1 on 100M cases is"); } namespace sd { From 5c83564116cbaa96ad811225bf861adaf8ad8b8f Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 8 Jun 2020 10:36:12 +0300 Subject: [PATCH 138/233] stuff compiles again Signed-off-by: raver119@gmail.com --- libnd4j/include/array/ConstantOffsetsBuffer.h | 2 +- libnd4j/include/array/ConstantShapeBuffer.h | 2 +- .../include/array/CudaPointerDeallocator.h | 2 +- libnd4j/include/array/NDArray.h | 4 +- libnd4j/include/array/PointerDeallocator.h | 2 +- libnd4j/include/array/PointerWrapper.h | 2 +- .../include/array/PrimaryPointerDeallocator.h | 2 +- libnd4j/include/array/impl/NDArray.cpp | 176 ++++++++---------- .../graph/execution/impl/OpSequence.cpp | 2 +- libnd4j/include/graph/impl/Graph.cpp | 2 +- libnd4j/include/graph/impl/Node.cpp | 12 +- .../include/graph/logic/impl/LogicReturn.cpp | 4 +- .../include/graph/logic/impl/LogicWhile.cpp | 2 +- .../helpers/benchmark/ReductionBenchmark.h | 2 +- .../include/helpers/cpu/ConstantTadHelper.cpp | 3 +- .../declarable/impl/LegacyReduceFloatOp.cpp | 4 +- .../layers_tests/DeclarableOpsTests1.cpp | 4 +- .../layers_tests/GraphHolderTests.cpp | 6 +- 18 files changed, 106 insertions(+), 127 deletions(-) diff --git a/libnd4j/include/array/ConstantOffsetsBuffer.h b/libnd4j/include/array/ConstantOffsetsBuffer.h index 61c1e381f3aa..d99d91711585 100644 --- a/libnd4j/include/array/ConstantOffsetsBuffer.h +++ b/libnd4j/include/array/ConstantOffsetsBuffer.h @@ -28,7 +28,7 @@ namespace sd { -class ND4J_EXPORT ConstantOffsetsBuffer { +class SD_EXPORT ConstantOffsetsBuffer { private: std::shared_ptr _primaryOffsets; std::shared_ptr _specialOffsets; diff --git a/libnd4j/include/array/ConstantShapeBuffer.h b/libnd4j/include/array/ConstantShapeBuffer.h index 2996532710de..bc2118dab7ca 100644 --- a/libnd4j/include/array/ConstantShapeBuffer.h +++ b/libnd4j/include/array/ConstantShapeBuffer.h @@ -28,7 +28,7 @@ namespace sd { -class ND4J_EXPORT ConstantShapeBuffer { +class SD_EXPORT ConstantShapeBuffer { private: std::shared_ptr _primaryShapeInfo; std::shared_ptr _specialShapeInfo; diff --git a/libnd4j/include/array/CudaPointerDeallocator.h b/libnd4j/include/array/CudaPointerDeallocator.h index c5c817aebca2..8fbda6976955 100644 --- a/libnd4j/include/array/CudaPointerDeallocator.h +++ b/libnd4j/include/array/CudaPointerDeallocator.h @@ -26,7 +26,7 @@ #include namespace sd { -class ND4J_EXPORT CudaPointerDeallocator : public PointerDeallocator { +class SD_EXPORT CudaPointerDeallocator : public PointerDeallocator { public: CudaPointerDeallocator() = default; ~CudaPointerDeallocator() = default; diff --git a/libnd4j/include/array/NDArray.h b/libnd4j/include/array/NDArray.h index fde957fafc42..8726920c985b 100644 --- a/libnd4j/include/array/NDArray.h +++ b/libnd4j/include/array/NDArray.h @@ -192,8 +192,8 @@ class SD_EXPORT NDArray { * contains shape info: matrix rank, numbers of elements per each dimension, * dimensions strides, element-wise-stride, c-like or fortan-like order */ - constNd4jLong* _shapeInfo = nullptr; - constNd4jLong* _shapeInfoD = nullptr; + const Nd4jLong* _shapeInfo = nullptr; + const Nd4jLong* _shapeInfoD = nullptr; /** * pointer on device launch context (with all data needed there). diff --git a/libnd4j/include/array/PointerDeallocator.h b/libnd4j/include/array/PointerDeallocator.h index 5bf820421713..ec3e095528d0 100644 --- a/libnd4j/include/array/PointerDeallocator.h +++ b/libnd4j/include/array/PointerDeallocator.h @@ -26,7 +26,7 @@ namespace sd { -class ND4J_EXPORT PointerDeallocator { +class SD_EXPORT PointerDeallocator { public: PointerDeallocator() = default; ~PointerDeallocator() = default; diff --git a/libnd4j/include/array/PointerWrapper.h b/libnd4j/include/array/PointerWrapper.h index 9e15aaaa3398..dad090593d19 100644 --- a/libnd4j/include/array/PointerWrapper.h +++ b/libnd4j/include/array/PointerWrapper.h @@ -27,7 +27,7 @@ #include namespace sd { -class ND4J_EXPORT PointerWrapper { +class SD_EXPORT PointerWrapper { private: void* _pointer = nullptr; std::shared_ptr _deallocator; diff --git a/libnd4j/include/array/PrimaryPointerDeallocator.h b/libnd4j/include/array/PrimaryPointerDeallocator.h index b4fe34764560..6c8e0f553574 100644 --- a/libnd4j/include/array/PrimaryPointerDeallocator.h +++ b/libnd4j/include/array/PrimaryPointerDeallocator.h @@ -26,7 +26,7 @@ #include namespace sd { -class ND4J_EXPORT PrimaryPointerDeallocator : public PointerDeallocator { +class SD_EXPORT PrimaryPointerDeallocator : public PointerDeallocator { public: PrimaryPointerDeallocator() = default; ~PrimaryPointerDeallocator() = default; diff --git a/libnd4j/include/array/impl/NDArray.cpp b/libnd4j/include/array/impl/NDArray.cpp index c39e20a3ce2a..a198810f8b17 100644 --- a/libnd4j/include/array/impl/NDArray.cpp +++ b/libnd4j/include/array/impl/NDArray.cpp @@ -223,7 +223,7 @@ NDArray::NDArray(sd::DataType dtype, sd::LaunchContext* context, getContext()->getWorkspace()); _buffer->setToZeroBuffers(); } else - setShapeInfo(ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype)); + setShapeInfo(ConstantShapeHelper::getInstance().emptyShapeInfo(dtype)); } ////////////////////////////////////////////////////////////////////////// @@ -1227,14 +1227,15 @@ void NDArray::streamline(char o) { syncToDevice(); std::shared_ptr newBuffer = std::make_shared( this->lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace()); - auto shapeBuffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo( + auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo( dataType(), order, rankOf(), shapeOf()); NativeOpExecutioner::execTransformSame( getContext(), transform::Copy, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), newBuffer->primary(), - static_cast(shapeBuffer.primary()), newBuffer->special(), - static_cast(shapeBuffer.special()), nullptr, nullptr, nullptr); - setShapeInfo(static_cast(shapeBuffer.primary())); + shapeBuffer.primary(), newBuffer->special(), + shapeBuffer.special(), nullptr, nullptr, nullptr); + + setShapeInfo(shapeBuffer); _buffer = newBuffer; _offset = 0; tickWriteDevice(); @@ -1558,7 +1559,7 @@ NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, auto newShape = ShapeUtils::evalReduceShapeInfo( 'c', copy, *this, - isR() ? dataType() : Environment::getInstance()->defaultFloatDataType(), + isR() ? dataType() : Environment::getInstance().defaultFloatDataType(), keepDims, supportOldShapes, getContext()->getWorkspace()); NDArray result(newShape, true, getContext()); @@ -1665,7 +1666,7 @@ NDArray NDArray::reduceNumber(sd::reduce::FloatOps op, "NDArray::reduceNumber FloatOps: you can't use this method on String " "array!"); - auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo( + auto shape = ConstantShapeHelper::getInstance().scalarShapeInfo( DataTypeUtils::pickFloatingType(dataType())); NDArray result(shape, true, this->getContext()); @@ -1706,7 +1707,7 @@ NDArray NDArray::reduceNumber(sd::reduce::BoolOps op, void* extraParams) const { "array!"); auto shape = - ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL); + ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::BOOL); NDArray result(shape, true, this->getContext()); NDArray::prepareSpecialUse({&result}, {this}); @@ -1727,7 +1728,7 @@ NDArray NDArray::reduceNumber(sd::reduce::LongOps op, void* extraParams) const { "array!"); auto shape = - ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64); + ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::INT64); NDArray result(shape, true, this->getContext()); NDArray::prepareSpecialUse({&result}, {this}); @@ -2129,8 +2130,7 @@ void NDArray::setAttached(bool reallyAttached) { ////////////////////////////////////////////////////////////////////////// // calculate strides void NDArray::updateStrides(const char order) { - shape::updateStrides(_shapeInfo, order); - syncShape(); + throw std::runtime_error("Very bad method was invoked"); } ////////////////////////////////////////////////////////////////////////// @@ -2755,7 +2755,7 @@ void NDArray::operator+=(const NDArray& other) { if (isS()) throw std::runtime_error( "NDArray::operator+=: you can't use this method on String array!"); - if (!Environment::getInstance()->isExperimentalBuild() && + if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) throw sd::datatype_exception::build( @@ -2805,7 +2805,7 @@ void NDArray::operator-=(const NDArray& other) { throw std::runtime_error( "NDArray::operator-=: you can't use this method on String array!"); - if (!Environment::getInstance()->isExperimentalBuild() && + if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) throw sd::datatype_exception::build( @@ -2854,7 +2854,7 @@ void NDArray::operator*=(const NDArray& other) { if (isS()) throw std::runtime_error( "NDArray::operator*=: you can't use this method on String array!"); - if (!Environment::getInstance()->isExperimentalBuild() && + if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) throw sd::datatype_exception::build( @@ -2907,7 +2907,7 @@ void NDArray::operator/=(const NDArray& other) { throw std::runtime_error( "NDArray::operator/=: you can't divide by bool array!"); - if (!Environment::getInstance()->isExperimentalBuild() && + if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType()) { throw sd::datatype_exception::build( "NDArray operator/=: Cannot divide different types", this->dataType(), @@ -3245,16 +3245,13 @@ void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); if (!isSameShape(target)) { - auto xPack = - ConstantShapeHelper::getInstance() - ->createShapeInfoWithUnitiesForBroadcast( + auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); xShapeInfoH = reinterpret_cast(xPack.primary()); xShapeInfoD = reinterpret_cast(xPack.special()); } if (!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance() - ->createShapeInfoWithUnitiesForBroadcast( + auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace()); yShapeInfoH = reinterpret_cast(yPack.primary()); @@ -3319,16 +3316,13 @@ void NDArray::applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); if (!isSameShape(target)) { - auto xPack = - ConstantShapeHelper::getInstance() - ->createShapeInfoWithUnitiesForBroadcast( + auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); xShapeInfoH = reinterpret_cast(xPack.primary()); xShapeInfoD = reinterpret_cast(xPack.special()); } if (!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance() - ->createShapeInfoWithUnitiesForBroadcast( + auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace()); yShapeInfoH = reinterpret_cast(yPack.primary()); @@ -3393,16 +3387,14 @@ void NDArray::applyTrueBroadcast(sd::BroadcastIntOpsTuple op, Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); if (!isSameShape(target)) { - auto xPack = - ConstantShapeHelper::getInstance() - ->createShapeInfoWithUnitiesForBroadcast( + auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); + xShapeInfoH = reinterpret_cast(xPack.primary()); xShapeInfoD = reinterpret_cast(xPack.special()); } if (!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance() - ->createShapeInfoWithUnitiesForBroadcast( + auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace()); yShapeInfoH = reinterpret_cast(yPack.primary()); @@ -3595,16 +3587,14 @@ void NDArray::applyBroadcast(sd::broadcast::Ops op, Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); if (!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance() - ->createShapeInfoWithUnitiesForBroadcast( + auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy); xShapeInfoH = reinterpret_cast(xPack.primary()); xShapeInfoD = reinterpret_cast(xPack.special()); } if (!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance() - ->createShapeInfoWithUnitiesForBroadcast( + auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy); yShapeInfoH = reinterpret_cast(yPack.primary()); @@ -3673,16 +3663,14 @@ void NDArray::applyBroadcast(sd::broadcast::BoolOps op, Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); if (!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance() - ->createShapeInfoWithUnitiesForBroadcast( + auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy); xShapeInfoH = reinterpret_cast(xPack.primary()); xShapeInfoD = reinterpret_cast(xPack.special()); } if (!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance() - ->createShapeInfoWithUnitiesForBroadcast( + auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy); yShapeInfoH = reinterpret_cast(yPack.primary()); @@ -3751,16 +3739,14 @@ void NDArray::applyBroadcast(sd::broadcast::IntOps op, Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); if (!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance() - ->createShapeInfoWithUnitiesForBroadcast( + auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy); xShapeInfoH = reinterpret_cast(xPack.primary()); xShapeInfoD = reinterpret_cast(xPack.special()); } if (!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance() - ->createShapeInfoWithUnitiesForBroadcast( + auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast( target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy); yShapeInfoH = reinterpret_cast(yPack.primary()); @@ -3787,9 +3773,9 @@ void NDArray::applyBroadcast(sd::broadcast::Ops op, //////////////////////////////////////////////////////////////////////// void* NDArray::operator new(size_t i) { - if (sd::memory::MemoryRegistrator::getInstance()->hasWorkspaceAttached()) { + if (sd::memory::MemoryRegistrator::getInstance().hasWorkspaceAttached()) { sd::memory::Workspace* ws = - sd::memory::MemoryRegistrator::getInstance()->getWorkspace(); + sd::memory::MemoryRegistrator::getInstance().getWorkspace(); return ws->allocateBytes((Nd4jLong)i); } else { auto p = malloc(i); @@ -3800,7 +3786,7 @@ void* NDArray::operator new(size_t i) { //////////////////////////////////////////////////////////////////////// void NDArray::operator delete(void* p) { - if (!sd::memory::MemoryRegistrator::getInstance()->hasWorkspaceAttached()) + if (!sd::memory::MemoryRegistrator::getInstance().hasWorkspaceAttached()) free(p); } @@ -4072,8 +4058,8 @@ void NDArray::varianceAlongDimension(sd::variance::Ops op, NDArray& target, else { std::vector copy(dimensions); auto pDims = - sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( this->shapeInfo(), dimensions); NativeOpExecutioner::execSummaryStats( getContext(), op, buffer(), shapeInfo(), specialBuffer(), @@ -5021,8 +5007,8 @@ void NDArray::applyIndexReduce(sd::indexreduce::Ops op, NDArray& target, std::vector copy = dimensions; shape::checkDimensions(rankOf(), copy); auto pDims = - sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( shapeInfo(), copy); NativeOpExecutioner::execIndexReduce( getContext(), op, buffer(), shapeInfo(), specialBuffer(), @@ -5134,11 +5120,11 @@ NDArray NDArray::applyReduce3(sd::reduce3::Ops op, const NDArray& other, result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); } else { auto pDims = - sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; + sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( shapeInfo(), copy); - auto packY = sd::ConstantTadHelper::getInstance()->tadForDimensions( + auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions( other.shapeInfo(), copy); if (!shape::equalsSoft(packX.primaryShapeInfo(), @@ -5183,8 +5169,8 @@ NDArray NDArray::applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, shape::checkDimensions(other.rankOf(), copy); auto packX = - ConstantTadHelper::getInstance()->tadForDimensions(shapeInfo(), copy); - auto packY = ConstantTadHelper::getInstance()->tadForDimensions( + ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy); + auto packY = ConstantTadHelper::getInstance().tadForDimensions( other.shapeInfo(), copy); // check tads shapes @@ -5194,7 +5180,7 @@ NDArray NDArray::applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, "different !"); // set newShape for output array - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo( + auto newShape = ConstantShapeHelper::getInstance().createShapeInfo( DataTypeUtils::pickFloatingType(dataType()), 'c', {packX.numberOfTads(), packY.numberOfTads()}); @@ -5208,7 +5194,7 @@ NDArray NDArray::applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, ? const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; - auto pDims = sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; + auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; NDArray::prepareSpecialUse({&result}, {this, &other}); NativeOpExecutioner::execReduce3All( @@ -5259,7 +5245,7 @@ void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); } else { - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( shapeInfo(), copy); NativeOpExecutioner::execReduceFloat( getContext(), op, buffer(), shapeInfo(), specialBuffer(), @@ -5309,8 +5295,8 @@ void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, target.specialBuffer(), target.specialShapeInfo()); } else { // if (!isEmpty()) { auto pDims = - sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( this->shapeInfo(), copy); NativeOpExecutioner::execReduceSame( getContext(), op, buffer(), shapeInfo(), specialBuffer(), @@ -5360,8 +5346,8 @@ void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, target.specialBuffer(), target.specialShapeInfo()); } else { auto pDims = - sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( this->shapeInfo(), copy); NativeOpExecutioner::execReduceLong( getContext(), op, buffer(), shapeInfo(), specialBuffer(), @@ -5411,8 +5397,8 @@ void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, target.specialBuffer(), target.specialShapeInfo()); } else { auto pDims = - sd::Environment::getInstance()->isCPU() ? copy.data() : nullptr; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( this->shapeInfo(), copy); NativeOpExecutioner::execReduceBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), @@ -5668,7 +5654,7 @@ void NDArray::addRowVector(const NDArray& row, NDArray& target) const { int dimension = 1; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( this->shapeInfo(), dimension); NDArray::prepareSpecialUse({&target}, {this, &row}); @@ -5699,7 +5685,7 @@ void NDArray::subRowVector(const NDArray& row, NDArray& target) const { int dimension = 1; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( this->shapeInfo(), dimension); NDArray::prepareSpecialUse({&target}, {this, &row}); @@ -5729,7 +5715,7 @@ void NDArray::mulRowVector(const NDArray& row, NDArray& target) const { int dimension = 1; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( this->shapeInfo(), dimension); NDArray::prepareSpecialUse({&target}, {this, &row}); @@ -5762,7 +5748,7 @@ void NDArray::divRowVector(const NDArray& row, NDArray& target) const { int dimension = 1; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( this->shapeInfo(), dimension); NDArray::prepareSpecialUse({&target}, {this, &row}); @@ -5788,7 +5774,7 @@ void NDArray::addiRowVector(const NDArray& row) { int dimension = 1; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( this->shapeInfo(), dimension); NDArray::prepareSpecialUse({this}, {&row}); @@ -5818,7 +5804,7 @@ void NDArray::addColumnVector(const NDArray& column, NDArray& target) const { int dimension = 0; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( this->shapeInfo(), dimension); NDArray::prepareSpecialUse({&target}, {this, &column}); @@ -5845,7 +5831,7 @@ void NDArray::addiColumnVector(const NDArray& column) { int dimension = 0; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( this->shapeInfo(), dimension); NDArray::prepareSpecialUse({this}, {&column}); @@ -5872,7 +5858,7 @@ void NDArray::muliColumnVector(const NDArray& column) { int dimension = 0; - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions( + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( this->shapeInfo(), dimension); NDArray::prepareSpecialUse({this}, {&column}); @@ -5925,7 +5911,7 @@ ResultSet NDArray::multipleTensorsAlongDimension( if (indices.size() == 0) return result; - auto pack = ConstantTadHelper::getInstance()->tadForDimensions( + auto pack = ConstantTadHelper::getInstance().tadForDimensions( shapeInfo(), const_cast(dimensions.data()), dimensions.size()); auto tadLength = shape::length(pack.primaryShapeInfo()); @@ -6043,7 +6029,7 @@ ResultSet NDArray::allTensorsAlongDimension( "NDArray::allTensorsAlongDimension static function: all input " "dimensions must be smaller than rank of input array !"); - auto pack = ConstantTadHelper::getInstance()->tadForDimensions( + auto pack = ConstantTadHelper::getInstance().tadForDimensions( _shapeInfo, const_cast(dimensions.data()), dimensions.size()); auto numTads = pack.numberOfTads(); @@ -6176,12 +6162,10 @@ void NDArray::setShapeInfo(const Nd4jLong* shapeInfo) { if (shapeInfo != nullptr) { ShapeDescriptor descriptor(shapeInfo); auto shapeBuffer = - ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor); + ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor); - _shapeInfo = reinterpret_cast(shapeBuffer.primary()); -#ifdef __CUDABLAS__ - _shapeInfoD = reinterpret_cast(shapeBuffer.special()); -#endif + _shapeInfo = shapeBuffer.primary(); + _shapeInfoD = shapeBuffer.special(); if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) _length = 0; @@ -6203,12 +6187,10 @@ void NDArray::setShapeInfo(const Nd4jLong* shapeInfo, shapeInfo, dtype, true, getContext()->getWorkspace()); ShapeDescriptor descriptor(shapeInfoTemp); auto shapeBuffer = - ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor); + ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor); - _shapeInfo = reinterpret_cast(shapeBuffer.primary()); -#ifdef __CUDABLAS__ - _shapeInfoD = reinterpret_cast(shapeBuffer.special()); -#endif + _shapeInfo = shapeBuffer.primary(); + _shapeInfoD = shapeBuffer.special(); if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) _length = 0; @@ -6224,13 +6206,12 @@ void NDArray::setShapeInfo(const Nd4jLong* shapeInfo, ////////////////////////////////////////////////////////////////////////// void NDArray::setShapeInfo(const ShapeDescriptor& descriptor) { - auto shapeBuffer = ConstantShapeHelper::getInstance()->bufferForShapeInfo( + auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo( const_cast(descriptor)); - _shapeInfo = reinterpret_cast(shapeBuffer.primary()); -#ifdef __CUDABLAS__ - _shapeInfoD = reinterpret_cast(shapeBuffer.special()); -#endif + _shapeInfo = shapeBuffer.primary(); + _shapeInfoD = shapeBuffer.special(); + if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) _length = 0; @@ -6241,13 +6222,10 @@ void NDArray::setShapeInfo(const ShapeDescriptor& descriptor) { } ////////////////////////////////////////////////////////////////////////// -void NDArray::setShapeInfo(const ConstantDataBuffer& shapeBuffer) { - _shapeInfo = reinterpret_cast( - const_cast(shapeBuffer).primary()); -#ifdef __CUDABLAS__ - _shapeInfoD = reinterpret_cast( - const_cast(shapeBuffer).special()); -#endif +void NDArray::setShapeInfo(const ConstantShapeBuffer& shapeBuffer) { + _shapeInfo = shapeBuffer.primary(); + _shapeInfoD = shapeBuffer.special(); + if (ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) _length = 0; @@ -6709,7 +6687,7 @@ NDArray operator+(T1&& arr1, T2&& arr2) { throw std::runtime_error( "operator+(T&& arr1, T&& arr2): you can't use this method on String " "arrays!"); - if (!Environment::getInstance()->isExperimentalBuild() && + if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) throw sd::datatype_exception::build( @@ -6783,7 +6761,7 @@ NDArray operator-(T1&& arr1, T2&& arr2) { throw std::runtime_error( "operator-(T&& arr1, T&& arr2): you can't use this method on String " "arrays!"); - if (!Environment::getInstance()->isExperimentalBuild() && + if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) throw sd::datatype_exception::build( @@ -6857,7 +6835,7 @@ NDArray operator*(T1&& arr1, T2&& arr2) { throw std::runtime_error( "operator*(T&& arr1, T&& arr2): you can't use this method on String " "arrays!"); - if (!Environment::getInstance()->isExperimentalBuild() && + if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) throw sd::datatype_exception::build( @@ -6931,7 +6909,7 @@ NDArray operator/(T1&& arr1, T2&& arr2) { throw std::runtime_error( "operator/(T&& arr1, T&& arr2): you can't use this method on String " "arrays!"); - if (!Environment::getInstance()->isExperimentalBuild() && + if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) throw sd::datatype_exception::build( diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 0aa78e5011d5..07e1b1086476 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -87,7 +87,7 @@ void OpSequence::append(const std::shared_ptr &op, void OpSequence::append(sd::ops::DeclarableOp *op, const ContextPrototype &ctx) { auto rop = - sd::ops::OpRegistrator::getInstance()->getOperation(op->getOpHash()); + sd::ops::OpRegistrator::getInstance().getOperation(op->getOpHash()); append(rop, ctx); } diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 574f1019d44a..f602fb7ce336 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -499,7 +499,7 @@ Graph Graph::importFromTensorFlow(const char *fileName) { node.name().c_str(), node.op().c_str()); sd::ops::DeclarableOp *op = - sd::ops::OpRegistrator::getInstance()->getOperationFloat(node.op().c_str()); + sd::ops::OpRegistrator::getInstance().getOperationFloat(node.op().c_str()); if (op == nullptr) { nd4j_verbose("Op wasn't found: %s\n", node.op().c_str()); diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index a6ab7054001c..39f668ba64f9 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -48,7 +48,7 @@ Node::Node(const ops::DeclarableOp &opName, const std::string &nodeName, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs, const std::vector &dArgs) { auto customOp = - ops::OpRegistrator::getInstance()->getOperation(opName.getOpHash()); + ops::OpRegistrator::getInstance().getOperation(opName.getOpHash()); this->_name = nodeName; this->_opType = OpType_CUSTOM; @@ -77,7 +77,7 @@ Node::Node(const ops::DeclarableOp &opName, const std::string &nodeName, Node::Node(const std::string &opName, const std::string &nodeName, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs, const std::vector &dArgs) { - auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName); + auto customOp = ops::OpRegistrator::getInstance().getOperation(opName); this->_name = nodeName; this->_opType = OpType_CUSTOM; @@ -287,7 +287,7 @@ Node::Node(const std::string &opName, const std::string &nodeName, const int id, const std::vector &inputs, const std::vector &tArgs, const std::vector &iArgs) { - auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName); + auto customOp = ops::OpRegistrator::getInstance().getOperation(opName); this->_opType = OpType_CUSTOM; this->_id = id; @@ -316,7 +316,7 @@ Node::Node(const std::string &opName, const int id, const std::vector> &inputs, const std::vector &tArgs, const std::vector &iArgs) { - auto customOp = ops::OpRegistrator::getInstance()->getOperation(opName); + auto customOp = ops::OpRegistrator::getInstance().getOperation(opName); this->_opType = OpType_CUSTOM; this->_id = id; @@ -354,9 +354,9 @@ Node::Node(sd::ops::DeclarableOp *customOp, int id, // if custom op is a registered one - pull it from cache, otherwise - clone // locally - if (sd::ops::OpRegistrator::getInstance()->hasOperation(_opNum)) + if (sd::ops::OpRegistrator::getInstance().hasOperation(_opNum)) this->_customOp = - sd::ops::OpRegistrator::getInstance()->getOperation(_opNum); + sd::ops::OpRegistrator::getInstance().getOperation(_opNum); else throw std::runtime_error( "Can't create a node with custom operation within"); diff --git a/libnd4j/include/graph/logic/impl/LogicReturn.cpp b/libnd4j/include/graph/logic/impl/LogicReturn.cpp index 27a8fc76c78d..5b19f380cd22 100644 --- a/libnd4j/include/graph/logic/impl/LogicReturn.cpp +++ b/libnd4j/include/graph/logic/impl/LogicReturn.cpp @@ -37,7 +37,7 @@ Nd4jStatus LogicReturn::processNode(Graph *graph, Node *node) { // FIXME!! outputAddr.second = e; - if (Environment::getInstance()->isDebugAndVerbose()) + if (Environment::getInstance().isDebugAndVerbose()) nd4j_debug("Return input: <%i, %i>; Return output: <%i, %i>\n", inputAddr.first, inputAddr.second, outputAddr.first, outputAddr.second); @@ -50,7 +50,7 @@ Nd4jStatus LogicReturn::processNode(Graph *graph, Node *node) { // FIXME: this is obviously wrong, we should keep depth track for backprop here varOut->getNDArray()->assign(varIn->getNDArray()); - if (Environment::getInstance()->isDebugAndVerbose()) + if (Environment::getInstance().isDebugAndVerbose()) nd4j_debug("In after: [%f]; Out after: [%f]\n", varIn->getNDArray()->meanNumber().e(0), varOut->getNDArray()->meanNumber().e(0)); diff --git a/libnd4j/include/graph/logic/impl/LogicWhile.cpp b/libnd4j/include/graph/logic/impl/LogicWhile.cpp index c53d4dd86a03..c377d59058d8 100644 --- a/libnd4j/include/graph/logic/impl/LogicWhile.cpp +++ b/libnd4j/include/graph/logic/impl/LogicWhile.cpp @@ -98,7 +98,7 @@ Nd4jStatus LogicWhile::processNode(Graph *graph, Node *node) { // now we should take result of the Scope run, and evaluate it auto result = __variableSpace->getVariable(lastNode)->getNDArray(); - if (Environment::getInstance()->isDebugAndVerbose()) + if (Environment::getInstance().isDebugAndVerbose()) result->printBuffer("Result of the last node:"); // if result evaluates to 0.0 - condition returned FALSE diff --git a/libnd4j/include/helpers/benchmark/ReductionBenchmark.h b/libnd4j/include/helpers/benchmark/ReductionBenchmark.h index b71c12c8b43b..68d91143bf08 100644 --- a/libnd4j/include/helpers/benchmark/ReductionBenchmark.h +++ b/libnd4j/include/helpers/benchmark/ReductionBenchmark.h @@ -21,7 +21,7 @@ #include #include -#include +#include #ifndef SD_REDUCEBENCHMARK_H #define SD_REDUCEBENCHMARK_H diff --git a/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp b/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp index 7222cedf54b7..767aa7d35fcc 100644 --- a/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp @@ -72,7 +72,8 @@ TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { std::lock_guard lock(_mutex); if (_cache[deviceId].count(descriptor) == 0) { -// if there's no TadPack matching this descriptor - create one const auto shapeInfo = descriptor.originalShape().toShapeInfo(); +// if there's no TadPack matching this descriptor - create one + const auto shapeInfo = descriptor.originalShape().toShapeInfo(); const int rank = shape::rank(shapeInfo); const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(rank, descriptor.axis()); diff --git a/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp index 33f8351f4c55..7ddd8b962bba 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyReduceFloatOp.cpp @@ -135,14 +135,14 @@ Nd4jStatus LegacyReduceFloatOp::validateAndExecute(Context &block) { x->shapeInfo(), dims); auto pTadShape = - Environment::getInstance()->isCPU() + Environment::getInstance().isCPU() ? packX.primaryShapeInfo() : packX .specialShapeInfo(); //(Nd4jLong *) // manager.replicatePointer(tad.tadOnlyShapeInfo, // shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); auto pTadOffsets = - Environment::getInstance()->isCPU() + Environment::getInstance().isCPU() ? packX.primaryOffsets() : packX .specialOffsets(); //(Nd4jLong *) diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index b882a3a041c3..6a20db6a6677 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -1758,7 +1758,7 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) { // std::string opName("add"); // auto hash = -// sd::ops::HashHelper::getInstance().getInstance()->getLongHash(opName); +// sd::ops::HashHelper::getInstance().getInstance().getLongHash(opName); // auto inputBuffers = new Nd4jPointer[2]; // auto inputShapes = new Nd4jPointer[2]; @@ -1810,7 +1810,7 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) { // std::string opName("add"); // auto hash = -// sd::ops::HashHelper::getInstance().getInstance()->getLongHash(opName); +// sd::ops::HashHelper::getInstance().getInstance().getLongHash(opName); // auto inputBuffers = new Nd4jPointer[2]; // auto inputShapes = new Nd4jPointer[2]; diff --git a/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp index 983adbc7e7d9..5938f30ac72b 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphHolderTests.cpp @@ -33,11 +33,11 @@ class GraphHolderTests : public testing::Test { TEST_F(GraphHolderTests, SimpleTests_1) { Graph graph; Nd4jLong graphId = 119; - GraphHolder::getInstance()->registerGraph(graphId, graph); + GraphHolder::getInstance().registerGraph(graphId, graph); ASSERT_TRUE(GraphHolder::getInstance().hasGraph(graphId)); - GraphHolder::getInstance()->forgetGraph(graphId); + GraphHolder::getInstance().forgetGraph(graphId); - ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(graphId)); + ASSERT_FALSE(GraphHolder::getInstance().hasGraph(graphId)); } \ No newline at end of file From a570ed2c6463e751392066905e5125dcfadaafbc Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 8 Jun 2020 15:13:37 +0300 Subject: [PATCH 139/233] minor tweaks Signed-off-by: raver119@gmail.com --- libnd4j/blas/CMakeLists.txt | 4 ++-- libnd4j/include/array/impl/NDArrayFactory.cpp | 3 --- libnd4j/include/graph/VariableProxy.h | 2 +- libnd4j/include/graph/impl/Graph.cpp | 8 ++++---- libnd4j/include/helpers/files.h | 3 +-- libnd4j/include/memory/HotZoneManager.h | 4 ++-- .../generic/images/draw_bounding_boxes.cpp | 15 ++++++--------- .../ops/declarable/generic/random/multinomial.cpp | 2 +- .../ops/declarable/generic/tsne/symmetrized.cpp | 8 ++++---- .../ops/declarable/helpers/cpu/BarnesHutTsne.cpp | 8 +++----- .../ops/declarable/helpers/cpu/image_resize.cpp | 5 +++-- .../include/ops/declarable/helpers/cpu/lrn.cpp | 8 ++++---- .../include/ops/declarable/helpers/cpu/lstm.cpp | 2 +- .../ops/declarable/helpers/cpu/one_hot.cpp | 2 +- .../tests_cpu/layers_tests/GraphAnalysisTests.cpp | 2 +- 15 files changed, 34 insertions(+), 42 deletions(-) diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index c23ba635bfcd..e67d29ae1e37 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -108,10 +108,10 @@ ENDIF() if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang" AND SD_X86_BUILD) # apple clang but not ios-arm - SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}") + SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE} -Wno-defaulted-function-deleted") elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") # using Clang - SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}") + SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE} -Wno-defaulted-function-deleted") elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel") # using Intel C++ SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE} -O3 -fp-model fast") diff --git a/libnd4j/include/array/impl/NDArrayFactory.cpp b/libnd4j/include/array/impl/NDArrayFactory.cpp index 23ae2e070571..a84e639ec516 100644 --- a/libnd4j/include/array/impl/NDArrayFactory.cpp +++ b/libnd4j/include/array/impl/NDArrayFactory.cpp @@ -130,9 +130,6 @@ template SD_EXPORT NDArray NDArrayFactory::create( template SD_EXPORT NDArray NDArrayFactory::create( const char order, const std::vector& shape, const std::vector& data, sd::LaunchContext* context); -template SD_EXPORT NDArray NDArrayFactory::create( - const char order, const std::vector& shape, - const std::vector& data, sd::LaunchContext* context); //////////////////////////////////////////////////////////////////////// template diff --git a/libnd4j/include/graph/VariableProxy.h b/libnd4j/include/graph/VariableProxy.h index 729d171629a0..63eabf15e460 100644 --- a/libnd4j/include/graph/VariableProxy.h +++ b/libnd4j/include/graph/VariableProxy.h @@ -34,7 +34,7 @@ class SD_EXPORT VariableProxy : public VariableSpace { explicit VariableProxy(const VariableSpace* reference); ~VariableProxy(); - virtual VariableSpace& operator=(const VariableSpace& other); + virtual VariableSpace& operator=(const VariableSpace& other) override; virtual int numberOfPlaceholders() const override; virtual const std::vector>& placeholders() diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index f602fb7ce336..188ad9322b71 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -650,7 +650,7 @@ Graph::Graph(const Graph &other) : _memoryManager(other._memoryManager) { _unmapped = other._unmapped; _symbolicLookupTable = other._symbolicLookupTable; _built = false; - _maxId = _maxId; + _maxId = other._maxId; } Graph &Graph::operator=(const Graph &other) noexcept { @@ -662,7 +662,7 @@ Graph &Graph::operator=(const Graph &other) noexcept { _unmapped = other._unmapped; _symbolicLookupTable = other._symbolicLookupTable; _built = false; - _maxId = _maxId; + _maxId = other._maxId; return *this; } @@ -676,7 +676,7 @@ Graph::Graph(Graph &&other) : _memoryManager(other._memoryManager) { _symbolicLookupTable = std::move(other._symbolicLookupTable); _built = false; - _maxId = _maxId; + _maxId = other._maxId; } Graph &Graph::operator=(Graph &&other) noexcept { @@ -690,7 +690,7 @@ Graph &Graph::operator=(Graph &&other) noexcept { _symbolicLookupTable = std::move(other._symbolicLookupTable); _built = false; - _maxId = _maxId; + _maxId = other._maxId; return *this; } diff --git a/libnd4j/include/helpers/files.h b/libnd4j/include/helpers/files.h index 913f7bfc0a6b..6e4de204e2e3 100644 --- a/libnd4j/include/helpers/files.h +++ b/libnd4j/include/helpers/files.h @@ -89,9 +89,8 @@ unsigned maxpathlen(char *path[], const char *base) { return blen + n + 1; } bool file_exists(const char *name) { - // printf("Trying file: [%s]\n", name); FILE *file; - if (file = fopen(name, "r")) { + if ((file = fopen(name, "r"))) { fclose(file); return true; } diff --git a/libnd4j/include/memory/HotZoneManager.h b/libnd4j/include/memory/HotZoneManager.h index 1c8e197ee9b6..5499e488fc59 100644 --- a/libnd4j/include/memory/HotZoneManager.h +++ b/libnd4j/include/memory/HotZoneManager.h @@ -42,9 +42,9 @@ class SD_EXPORT HotZoneManager : public ZoneManager { uint64_t used() const override; - virtual MemoryDescriptor allocate(uint64_t numBytes) = 0; + virtual MemoryDescriptor allocate(uint64_t numBytes) override = 0; - virtual void release(MemoryDescriptor &descriptor) = 0; + virtual void release(MemoryDescriptor &descriptor) override = 0; }; } // namespace memory } // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/images/draw_bounding_boxes.cpp b/libnd4j/include/ops/declarable/generic/images/draw_bounding_boxes.cpp index 44ea0676f29e..45927669440b 100644 --- a/libnd4j/include/ops/declarable/generic/images/draw_bounding_boxes.cpp +++ b/libnd4j/include/ops/declarable/generic/images/draw_bounding_boxes.cpp @@ -62,19 +62,16 @@ OP_IMPL(draw_bounding_boxes, 3, 1, true) { "draw_bounding_boxes: Batches for images and boxes " "should be the same, but %lld and %lld occured.", images->sizeAt(0), boxes->sizeAt(0)); - helpers::drawBoundingBoxesFunctor(block.launchContext(), images, boxes, - colors, output); - return ND4J_STATUS_OK; + helpers::drawBoundingBoxesFunctor(block.launchContext(), images, boxes, colors, output); + return Status::OK(); } DECLARE_TYPES(draw_bounding_boxes) { getOpDescriptor() - ->setAllowedInputTypes( - 0, {HALF, FLOAT32}) // TF allows HALF and FLOAT32 only - ->setAllowedInputTypes(1, {FLOAT32}) // as TF - ->setAllowedInputTypes(2, {FLOAT32}) // as TF - ->setAllowedOutputTypes( - {HALF, FLOAT32}); // TF allows HALF and FLOAT32 only + ->setAllowedInputTypes(0, {HALF, FLOAT32}) // TF allows HALF and FLOAT32 only + ->setAllowedInputTypes(1, sd::DataType::FLOAT32) // as TF + ->setAllowedInputTypes(2, sd::DataType::FLOAT32) // as TF + ->setAllowedOutputTypes({HALF, FLOAT32}); // TF allows HALF and FLOAT32 only } } // namespace ops } // namespace sd diff --git a/libnd4j/include/ops/declarable/generic/random/multinomial.cpp b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp index 0438ca6d7a34..126ca5a588fe 100644 --- a/libnd4j/include/ops/declarable/generic/random/multinomial.cpp +++ b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp @@ -125,7 +125,7 @@ DECLARE_SHAPE_FN(random_multinomial) { DECLARE_TYPES(random_multinomial) { getOpDescriptor() ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {sd::DataType::INT32}) + ->setAllowedInputTypes(1, sd::DataType::INT32) ->setAllowedOutputTypes(0, {ALL_INDICES}); } } // namespace ops diff --git a/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp b/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp index 2b26c6dbbf90..016ea38553b7 100644 --- a/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp +++ b/libnd4j/include/ops/declarable/generic/tsne/symmetrized.cpp @@ -50,11 +50,11 @@ CUSTOM_OP_IMPL(barnes_symmetrized, 3, 3, false, 0, -1) { DECLARE_TYPES(barnes_symmetrized) { getOpDescriptor() - ->setAllowedInputTypes(0, {DataType::INT32}) - ->setAllowedInputTypes(1, {DataType::INT32}) + ->setAllowedInputTypes(0, DataType::INT32) + ->setAllowedInputTypes(1, DataType::INT32) ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) - ->setAllowedOutputTypes(1, {DataType::INT32}) - ->setAllowedOutputTypes(1, {DataType::INT32}) + ->setAllowedOutputTypes(1, DataType::INT32) + ->setAllowedOutputTypes(1, DataType::INT32) ->setAllowedOutputTypes(2, {ALL_INTS, ALL_FLOATS}) ->setSameMode(false); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp b/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp index a25d8957fd54..2ffd204a90dc 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/BarnesHutTsne.cpp @@ -220,7 +220,9 @@ static void barnes_gains_(NDArray* input, NDArray* gradX, NDArray* epsilon, T res = sd::math::nd4j_sign(grad) != sd::math::nd4j_sign(eps) ? x + T(.2) : x * T(.8); + if (res < .01) res = .01; + return res; }; @@ -233,7 +235,7 @@ void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, // gains.add(.2).muli(sign(yGrads)).neq(sign(yIncs)).castTo(Nd4j.defaultFloatingPointType()) // .addi(gains.mul(0.8).muli(sign(yGrads)).neq(sign(yIncs))); BUILD_SINGLE_SELECTOR(input->dataType(), barnes_gains_, - (input, gradX, epsilon, output), NUMERIC_TYPES); + (input, gradX, epsilon, output), FLOAT_TYPES); // auto signGradX = *gradX; // auto signEpsilon = *epsilon; // gradX->applyTransform(transform::Sign, &signGradX, nullptr); @@ -254,10 +256,6 @@ void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, // leftPart.applyPairwiseTransform(pairwise::Add, &rightPart, output, // nullptr); } -BUILD_SINGLE_TEMPLATE(template void barnes_gains_, - (NDArray * input, NDArray* gradX, NDArray* epsilon, - NDArray* output), - NUMERIC_TYPES); bool cell_contains(NDArray* corner, NDArray* width, NDArray* point, Nd4jLong dimension) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp index 3d03b7c90359..f086339c368c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp @@ -1441,9 +1441,10 @@ int resizeAreaFunctor(sd::LaunchContext* context, NDArray const* image, return resizeBicubicFunctor(context, image, width, height, alignCorners, false, output); case kResizeArea: return resizeAreaFunctor(context, image, width, height, alignCorners, output); + default: + nd4j_printf("helper::resizeImagesFunctor: Wrong resize method %i\n", (int)method); + return Status::CODE(ND4J_STATUS_BAD_INPUT, "helper::resizeImagesFunctor: Wrong resize method"); } - nd4j_printf("helper::resizeImagesFunctor: Wrong resize method %i\n", (int)method); - return Status::CODE(ND4J_STATUS_BAD_INPUT, "helper::resizeImagesFunctor: Wrong resize method"); } // ------------------------------------------------------------------------------------------------------------------ // int resizeFunctor(sd::LaunchContext * context, NDArray const* image, int const width, int const height, diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp index a2caf7c25d73..5a11a9647ec5 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp @@ -37,14 +37,14 @@ static int lrnFunctor_(sd::graph::Context& block, NDArray* input, const int rank = input->rankOf(); TadPack inTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( - input->shapeInfo(), {rank - 1}); + input->shapeInfo(), rank - 1); TadPack outTadPack; if (shape::haveSameShapeAndStrides(input->shapeInfo(), output->shapeInfo())) outTadPack = inTadPack; else outTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( - output->shapeInfo(), {rank - 1}); + output->shapeInfo(), rank - 1); const Nd4jLong numOfTads = inTadPack.numberOfTads(); const Nd4jLong tadLen = input->sizeAt(-1); @@ -168,14 +168,14 @@ static void lrnBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI, const int rank = input.rankOf(); TadPack inTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( - input.shapeInfo(), {rank - 1}); + input.shapeInfo(), rank - 1); TadPack gradITadPack; if (shape::haveSameShapeAndStrides(input.shapeInfo(), gradI.shapeInfo())) gradITadPack = inTadPack; else gradITadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( - gradI.shapeInfo(), {rank - 1}); + gradI.shapeInfo(), rank - 1); const Nd4jLong numOfTads = inTadPack.numberOfTads(); const Nd4jLong tadLen = input.sizeAt(-1); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp index 58eab23237d9..05ad99099f51 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp @@ -209,7 +209,7 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, xt->dataType(), xt->getContext()); helpers::concat(xt->getContext(), {const_cast(xt), const_cast(yLast)}, - concatOut, {1}); + concatOut, 1); auto m = mmul(concatOut, diff --git a/libnd4j/include/ops/declarable/helpers/cpu/one_hot.cpp b/libnd4j/include/ops/declarable/helpers/cpu/one_hot.cpp index 03a6d440eabd..05f77bb8c6e9 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/one_hot.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/one_hot.cpp @@ -36,7 +36,7 @@ static void onehot_(void* voutput, Nd4jLong const* zShapeInfo, auto indices = reinterpret_cast(vindices); auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( - zShapeInfo, {axis}); + zShapeInfo, axis); auto iLen = static_cast(shape::length(iShapeInfo)); auto tLen = diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index f8559433aa12..a55eae5b0b91 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -930,5 +930,5 @@ TEST_F(GraphAnalysisTests, test_cond_2) { TEST_F(GraphAnalysisTests, test_while_iter_1_1) { auto graph = Graph::fromFlatBuffers("resources/while_iter1.fb"); - graph.printOut(); + //graph.printOut(); } From 7cdd353533bce8a8c5f45dc566b5dbd470699b90 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 8 Jun 2020 15:24:26 +0300 Subject: [PATCH 140/233] ASAN should be active for GCC only Signed-off-by: raver119@gmail.com --- libnd4j/tests_cpu/layers_tests/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt index 563bf58f6c0b..ffe508987397 100644 --- a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt +++ b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt @@ -60,7 +60,8 @@ else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native") endif() - if (SD_CPU AND SD_SANITIZE) + # Use ASAN for GCC only + if (SD_CPU AND SD_SANITIZE AND "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" ) set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address") else() # CUDA? From ed8f34f8abd25112983f7b152cea6cd0147bf0b0 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 8 Jun 2020 15:55:50 +0300 Subject: [PATCH 141/233] un-MMAP Graphs once they're gone Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/Graph.h | 3 ++ libnd4j/include/graph/impl/Graph.cpp | 6 ++- libnd4j/include/legacy/cpu/NativeOps.cpp | 3 +- libnd4j/include/memory/GraphMemoryManager.h | 11 ++++- libnd4j/include/memory/MmapDeallocator.h | 41 +++++++++++++++++++ .../include/memory/cpu/GraphMemoryManager.cpp | 5 +++ .../include/memory/impl/MmapDeallocator.cpp | 33 +++++++++++++++ 7 files changed, 99 insertions(+), 3 deletions(-) create mode 100644 libnd4j/include/memory/MmapDeallocator.h create mode 100644 libnd4j/include/memory/impl/MmapDeallocator.cpp diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index e15949abcfa0..5a0d3093892c 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -77,6 +77,9 @@ class SD_EXPORT Graph { mutable std::mutex _optimizedLock; + + std::vector _handles; + public: Graph(const FlatGraph *flatGraph = nullptr, const GraphMemoryManager &memoryManager = GraphMemoryManager()); diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 188ad9322b71..a522e800b000 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -35,6 +35,7 @@ #include #include +#include namespace sd { namespace graph { @@ -54,7 +55,9 @@ VariableSpace &Graph::variableSpace() const { return const_cast(_variableSpace); } -Graph::~Graph() {} +Graph::~Graph() { + +} int Graph::idByName(const std::string &nodeName) const { if (_symbolicLookupTable.count(nodeName) == 0) @@ -335,6 +338,7 @@ Graph Graph::fromFlatBuffers(const char *fileName, // mmap this file ref = ::mmapFile(nullptr, fileName, fsize); ptrGraph = reinterpret_cast(ref[0]); + memoryManager.track(std::make_shared(ref, std::make_shared())); } else { // if mmap is not supported - load it directly diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp index 7ae6ede1e733..6ef6fd484c35 100644 --- a/libnd4j/include/legacy/cpu/NativeOps.cpp +++ b/libnd4j/include/legacy/cpu/NativeOps.cpp @@ -1607,7 +1607,7 @@ Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong *mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) { - auto hZ = new Nd4jLong[2]; + auto hZ = new Nd4jLong[3]; errno = 0; try { #if defined(_WIN32) || defined(_WIN64) @@ -1627,6 +1627,7 @@ Nd4jLong *mmapFile(Nd4jPointer *extraPointers, const char *fileName, hZ[1] = fd; #endif + hZ[2] = length; return hZ; } catch (std::exception &e) { diff --git a/libnd4j/include/memory/GraphMemoryManager.h b/libnd4j/include/memory/GraphMemoryManager.h index 4245a91f275e..98a85985996f 100644 --- a/libnd4j/include/memory/GraphMemoryManager.h +++ b/libnd4j/include/memory/GraphMemoryManager.h @@ -24,8 +24,10 @@ #include #include #include - +#include #include +#include +#include using namespace sd::memory; @@ -35,6 +37,7 @@ class GraphMemoryManager { protected: std::map _zones; + mutable std::vector> _attached; public: GraphMemoryManager(); ~GraphMemoryManager(); @@ -53,6 +56,12 @@ class GraphMemoryManager { * @param descriptor */ virtual void release(MemoryDescriptor &descriptor); + + /** + * This method allows to store reference to certain memory regions and keep them alive as long as Graph is alive + * @param ptr + */ + virtual void track(const std::shared_ptr &ptr) const; }; } // namespace graph } // namespace sd diff --git a/libnd4j/include/memory/MmapDeallocator.h b/libnd4j/include/memory/MmapDeallocator.h new file mode 100644 index 000000000000..6e2429c9db68 --- /dev/null +++ b/libnd4j/include/memory/MmapDeallocator.h @@ -0,0 +1,41 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_MEMORY_MMAPDEALLOCATOR_H_ +#define SD_MEMORY_MMAPDEALLOCATOR_H_ + +#include + +namespace sd { +namespace memory { + class MmapDeallocator : public sd::PointerDeallocator { + private: + public: + MmapDeallocator() = default; + ~MmapDeallocator() = default; + + void release(void *ptr) override; + }; + +} +} + + +#endif //SD_MEMORY_MMAPDEALLOCATOR_H_ diff --git a/libnd4j/include/memory/cpu/GraphMemoryManager.cpp b/libnd4j/include/memory/cpu/GraphMemoryManager.cpp index afca4b9dfe0d..44294a59f02d 100644 --- a/libnd4j/include/memory/cpu/GraphMemoryManager.cpp +++ b/libnd4j/include/memory/cpu/GraphMemoryManager.cpp @@ -47,5 +47,10 @@ MemoryDescriptor GraphMemoryManager::allocate(size_t numBytes, void GraphMemoryManager::release(MemoryDescriptor &descriptor) { _zones[descriptor.zone()]->release(descriptor); } + +void GraphMemoryManager::track(const std::shared_ptr &ptr) const { + _attached.emplace_back(ptr); +} + } // namespace graph } // namespace sd diff --git a/libnd4j/include/memory/impl/MmapDeallocator.cpp b/libnd4j/include/memory/impl/MmapDeallocator.cpp new file mode 100644 index 000000000000..b4df0ada8009 --- /dev/null +++ b/libnd4j/include/memory/impl/MmapDeallocator.cpp @@ -0,0 +1,33 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include + +namespace sd { +namespace memory { + +void MmapDeallocator::release(void *ptr) { + auto ref = reinterpret_cast(ptr); + ::munmapFile(nullptr, ref, ref[2]); +} + +} // namespace memory +} // namespace sd From c405d3cf847f4301458486e8b4ea7512bc3c23ea Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 8 Jun 2020 17:37:14 +0300 Subject: [PATCH 142/233] re-enabled GraphExecutorTests Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/OptimizedGraph.h | 10 ++ libnd4j/include/graph/impl/OptimizedGraph.cpp | 4 +- .../layers_tests/GraphExecutorTests.cpp | 94 +++++++++---------- 3 files changed, 60 insertions(+), 48 deletions(-) diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index dff2494ce6fe..1fbb0cafdc59 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -74,7 +74,17 @@ class SD_EXPORT OptimizedGraph { */ void printOut() const; + /** + * Returns number of layers within OptimizedGraph + * @return + */ + uint64_t layers() const { return _sortedGraph.size(); } + /** + * This method adds given OpSequence to execution queue + * @param sequence + */ + void append(const OpSequence &sequence); }; diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 35cba736cf9d..c034ade246a5 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -131,7 +131,9 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS // l.sortOpSequences(); } - +void OptimizedGraph::append(const OpSequence &sequence) { + _sortedGraph.emplace_back(ExecutionLayer({sequence})); +} /////////////////////////////////////////////////////////////////// size_t OptimizedGraph::size() const { diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index 641e5de291dc..2d48c876eb89 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -39,67 +39,67 @@ class GraphExecutorTests : public testing::Test { public: }; -// TEST_F(GraphExecutorTests, test_basic_exec_1) { -// GraphMemoryManager memoryManager; -// Graph graph; +TEST_F(GraphExecutorTests, test_basic_exec_1) { + GraphMemoryManager memoryManager; + Graph graph; + OptimizedGraph optimizedGraph; + OpSequence sequence; -// OptimizedGraph optimizedGraph; -// OpSequence sequence; + optimizedGraph.append(sequence); + ASSERT_EQ(1, optimizedGraph.layers()); -// optimizedGraph.append(sequence); + VariableProxy proxy(&graph.variableSpace()); + GraphExecutor executor; + executor.execute(optimizedGraph, proxy); +} -// VariableProxy proxy(&graph.variableSpace()); -// GraphExecutor executor; -// executor.execute(optimizedGraph, proxy); -// } +TEST_F(GraphExecutorTests, test_basic_exec_2) { + GraphMemoryManager mgr; + Graph graph(nullptr, mgr); -// TEST_F(GraphExecutorTests, test_basic_exec_2) { -// GraphMemoryManager mgr; -// Graph graph(nullptr, mgr); + auto A = NDArrayFactory::create('c', {3}, {1, 1, 1}); + auto B = NDArrayFactory::create('c', {3}, {2, 2, 2}); + auto C = NDArrayFactory::create('c', {3}, {3, 3, 3}); -// auto A = NDArrayFactory::create('c', {3}, {1, 1, 1}); -// auto B = NDArrayFactory::create('c', {3}, {2, 2, 2}); -// auto C = NDArrayFactory::create('c', {3}, {3, 3, 3}); + auto exp = NDArrayFactory::create('c', {3}, {5, 5, 5}); -// auto exp = NDArrayFactory::create('c', {3}, {5, 5, 5}); + graph.addVariable("A", A); + graph.addVariable("B", B); + graph.addVariable("C", C); -// graph.addVariable("A", A); -// graph.addVariable("B", B); -// graph.addVariable("C", C); + Node m(sd::ops::multiply(), "mul"); + Node a(sd::ops::add(), "add"); -// Node m(sd::ops::multiply(), "mul"); -// Node a(sd::ops::add(), "add"); + graph.addNode(m, {"A", "B"}); + graph.addNode(a, {"mul", "C"}); -// graph.addNode(m, {"A", "B"}); -// graph.addNode(a, {"mul", "C"}); + OptimizedGraph optimizedGraph; + OpSequence sequence; -// OptimizedGraph optimizedGraph; -// OpSequence sequence; + ASSERT_EQ(2, m.protoContext().inputs().size()); + ASSERT_EQ(2, a.protoContext().inputs().size()); -// ASSERT_EQ(2, m.protoContext().inputs().size()); -// ASSERT_EQ(2, a.protoContext().inputs().size()); + sequence.append(m.customOp(), m.protoContext()); + sequence.append(a.customOp(), a.protoContext()); -// sequence.append(m.customOp(), m.protoContext()); -// sequence.append(a.customOp(), a.protoContext()); + optimizedGraph.append(sequence); -// optimizedGraph.append(sequence); + ASSERT_EQ(2, sequence.length()); + ASSERT_EQ(1, optimizedGraph.layers()); -// ASSERT_EQ(2, sequence.length()); -// ASSERT_EQ(1, optimizedGraph.layers()); + VariableProxy proxy(&graph.variableSpace()); + GraphExecutor executor; + executor.execute(optimizedGraph, proxy); -// VariableProxy proxy(&graph.variableSpace()); -// GraphExecutor executor; -// executor.execute(optimizedGraph, proxy); + // checking results by ID + ASSERT_TRUE(proxy.hasVariable(m.id())); + ASSERT_TRUE(proxy.hasVariable(a.id())); -// // checking results by ID -// ASSERT_TRUE(proxy.hasVariable(m.id())); -// ASSERT_TRUE(proxy.hasVariable(a.id())); + // checking results by name + ASSERT_TRUE(proxy.hasVariable("mul")); + ASSERT_TRUE(proxy.hasVariable("add")); -// // checking results by name -// ASSERT_TRUE(proxy.hasVariable("mul")); -// ASSERT_TRUE(proxy.hasVariable("add")); - -// // checking if result is valid -// auto result = proxy.getVariable(a.id())->getNDArray(); -// ASSERT_EQ(exp, *result); -// } \ No newline at end of file + // checking if result is valid + auto result = proxy.getVariable(a.id())->getNDArray(); + ASSERT_EQ(exp, *result); +} \ No newline at end of file From 6ee056a65d13c940f3761380680ccf287d6e5fbb Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 8 Jun 2020 18:50:00 +0300 Subject: [PATCH 143/233] just few comments Signed-off-by: raver119@gmail.com --- .../include/graph/execution/GraphExecutor.h | 10 +++++--- .../graph/execution/impl/GraphExecutor.cpp | 25 ++++++++++++------- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/libnd4j/include/graph/execution/GraphExecutor.h b/libnd4j/include/graph/execution/GraphExecutor.h index da83c0b19043..ce54b07c3b19 100644 --- a/libnd4j/include/graph/execution/GraphExecutor.h +++ b/libnd4j/include/graph/execution/GraphExecutor.h @@ -69,13 +69,14 @@ class SD_EXPORT GraphExecutor { /** * This method executes OpSequence - * @param sequence + * @param seq * @param deviceId - this argument allows to override device affinity * specified in OpSequence, keep it < 0 to follow OpSequence * @return */ - virtual Nd4jStatus execute(const OpSequence &sequence, - const OptimizedGraph &graph, VariableProxy &proxy, + virtual Nd4jStatus execute(const OpSequence &seq, + const OptimizedGraph &graph, + VariableProxy &proxy, int deviceId) const; /** @@ -87,7 +88,8 @@ class SD_EXPORT GraphExecutor { virtual Nd4jStatus execute(const std::shared_ptr &op, const ContextPrototype &contextPrototype, const OpSequence &sequence, - const OptimizedGraph &graph, VariableProxy &proxy, + const OptimizedGraph &graph, + VariableProxy &proxy, const int deviceId) const; }; } // namespace graph diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 5473395cbca9..1a1b5198c1b6 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -59,17 +59,19 @@ Nd4jStatus GraphExecutor::execute( // throw std::runtime_error("GraphExecutor::execute - Not implemented yet"); } -Nd4jStatus GraphExecutor::execute(const OpSequence &sequence, +Nd4jStatus GraphExecutor::execute(const OpSequence &seq, const OptimizedGraph &graph, VariableProxy &proxy, const int deviceId) const { /* * this is a basic implementation that works without dispatching etc */ - for (int e = 0; e < sequence.length(); e++) { - auto v = sequence[e]; - auto result = execute(v.op(), v.protoContext(), sequence, graph, proxy, - deviceId >= 0 ? deviceId : sequence.deviceId()); + for (int e = 0; e < seq.length(); e++) { + auto v = seq[e]; + // only Ops can be executed this way :( + auto result = execute(v.op(), v.protoContext(), seq, graph, proxy, deviceId >= 0 ? deviceId : seq.deviceId()); + + // if any one op fails - there will be no sense in executing other ops if (result != Status::OK()) return result; } @@ -78,24 +80,29 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &sequence, Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, VariableProxy &proxy) const { - const auto numDevices = AffinityManager::numberOfDevices(); - /* * this is a basic exection logic: roll through layers and sequences and * execute them one by one sequentially */ + const auto numDevices = AffinityManager::numberOfDevices(); Nd4jStatus result = Status::OK(); - for (uint64_t l = 0; l < graph.numOfLayers(); l++) { + for (uint64_t l = 0; l < graph.layers(); l++) { const auto &layer = graph.layer(l); + //TODO: this loop is executable in parallel, so we should do this eventually for (uint64_t o = 0; o < layer.width(); o++) { - execute(layer[o], graph, proxy, -1); + result = execute(layer[o], graph, proxy, -1); } + // early termination + if (result != Status::OK()) return result; + // optionally block until all sequences in this layer processed if (layer.width() > 0 && numDevices > 1) for (uint64_t o = 0; o < layer.width(); o++) { result = layer[o].wait(); + + // early termination if (result != Status::OK()) return result; } } From 9bfdee7c3b0bb7307426dcfefcc05709ca530562 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 8 Jun 2020 21:44:04 +0300 Subject: [PATCH 144/233] ExecutionTask shouldn't rely on DeclarableOp Signed-off-by: raver119@gmail.com --- .../include/graph/execution/ExecutionTask.h | 8 +++--- libnd4j/include/graph/execution/OpSequence.h | 5 ++-- .../graph/execution/impl/ExecutionTask.cpp | 25 +++++++------------ .../graph/execution/impl/GraphExecutor.cpp | 10 +++++++- .../graph/execution/impl/OpSequence.cpp | 11 ++------ libnd4j/include/graph/impl/OptimizedGraph.cpp | 4 +-- .../layers_tests/ExecutionLayerTests.cpp | 12 ++++----- .../layers_tests/GraphExecutorTests.cpp | 4 +-- 8 files changed, 37 insertions(+), 42 deletions(-) diff --git a/libnd4j/include/graph/execution/ExecutionTask.h b/libnd4j/include/graph/execution/ExecutionTask.h index 9e9d22f2f304..c8532222557e 100644 --- a/libnd4j/include/graph/execution/ExecutionTask.h +++ b/libnd4j/include/graph/execution/ExecutionTask.h @@ -21,6 +21,7 @@ #ifndef SD_EXECUTIONTASK_H #define SD_EXECUTIONTASK_H +#include #include #include @@ -30,11 +31,12 @@ namespace sd { namespace graph { class SD_EXPORT ExecutionTask { protected: - std::shared_ptr _op; + // FIXME: do we really want references here? smart pointers would work better + const Node& _node; const ContextPrototype& _context; public: - ExecutionTask(const std::shared_ptr& op, + ExecutionTask(const Node& node, const ContextPrototype& ctx); ~ExecutionTask() = default; @@ -51,7 +53,7 @@ class SD_EXPORT ExecutionTask { void printOut() const; - std::shared_ptr op() const; + const Node& node() const; const ContextPrototype& protoContext() const; }; diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index eafd49f699e5..b625af8f3510 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -90,11 +90,10 @@ class SD_EXPORT OpSequence * @param ctx - ContextPrototype for this operation with inputs/outputs/args * defined */ - void append(const std::shared_ptr& op, - const sd::graph::ContextPrototype& ctx); - void append(sd::ops::DeclarableOp* op, + void append(const Node& node, const sd::graph::ContextPrototype& ctx); + /** * Iterator functionality for OpSequence * @return diff --git a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp index c2067041e0f8..e48a0e26aea2 100644 --- a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp +++ b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp @@ -22,40 +22,33 @@ namespace sd { namespace graph { -ExecutionTask::ExecutionTask(const std::shared_ptr &op, +ExecutionTask::ExecutionTask(const Node& node, const ContextPrototype &ctx) - : _op(op), _context(ctx) { - // -} + : _node(node), _context(ctx) { } -std::shared_ptr ExecutionTask::op() const { return _op; } +const Node& ExecutionTask::node() const { return _node; } const ContextPrototype &ExecutionTask::protoContext() const { return _context; } ExecutionTask::ExecutionTask(const ExecutionTask &other) - : _op(other._op), _context(other._context) { - // -} + : _node(other._node), _context(other._context) { } ExecutionTask &ExecutionTask::operator=(const ExecutionTask &other) noexcept { if (this == &other) return *this; - _op = other._op; + const_cast(_node) = other._node; const_cast(_context) = other._context; return *this; } ExecutionTask::ExecutionTask(ExecutionTask &&other) - : _op(other._op), _context(other._context) { - // -} + : _node(other._node), _context(other._context) { } void ExecutionTask::printOut() const { if (_context.name().empty()) { - if (_op != nullptr) - printf(" <%i:0>: {Op: %s}; ", _context.nodeId(), - _op->getOpName().c_str()); + if (_node.hasCustomOp()) + printf(" <%i:0>: {Op: %s}; ", _context.nodeId(), _node.customOp()->getOpName().c_str()); else printf(" <%i:0>: ", _context.nodeId()); } else { @@ -85,7 +78,7 @@ void ExecutionTask::printOut() const { ExecutionTask &ExecutionTask::operator=(ExecutionTask &&other) noexcept { if (this == &other) return *this; - _op = std::move(other._op); + const_cast(_node) = other._node; const_cast(_context) = std::move(other._context); return *this; diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 1a1b5198c1b6..86dadc7db91d 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -69,7 +69,15 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, for (int e = 0; e < seq.length(); e++) { auto v = seq[e]; // only Ops can be executed this way :( - auto result = execute(v.op(), v.protoContext(), seq, graph, proxy, deviceId >= 0 ? deviceId : seq.deviceId()); + Nd4jStatus result = Status::OK(); + + if (v.node().hasCustomOp()) + result = execute(v.node().customOp(), v.protoContext(), seq, graph, proxy, deviceId >= 0 ? deviceId : seq.deviceId()); + else { + nd4j_printf("Node <%i:%s> has no customOp set\n", + v.node().id(), + v.node().name().empty() ? "" : v.node().name().c_str()); + } // if any one op fails - there will be no sense in executing other ops if (result != Status::OK()) return result; diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 07e1b1086476..5fac00733b5a 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -78,19 +78,12 @@ const ExecutionTask &OpSequence::operator[](uint64_t index) const { uint64_t OpSequence::length() const { return _ops.size(); } -void OpSequence::append(const std::shared_ptr &op, +void OpSequence::append(const Node &node, const sd::graph::ContextPrototype &ctx) { - ExecutionTask task(op, ctx); + ExecutionTask task(node, ctx); _ops.emplace_back(task); } -void OpSequence::append(sd::ops::DeclarableOp *op, - const ContextPrototype &ctx) { - auto rop = - sd::ops::OpRegistrator::getInstance().getOperation(op->getOpHash()); - append(rop, ctx); -} - OpSequence::iterator OpSequence::begin() { return OpSequence::iterator(*this, 0); } diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index c034ade246a5..1f5e0db405b4 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -118,10 +118,10 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS for (const auto& p : workMap) { OpSequence seq; - seq.append(inMap.at(p.first).customOp(), inMap.at(p.first).protoContext()); + seq.append(inMap.at(p.first), inMap.at(p.first).protoContext()); for (const auto& id : p.second._opSeq) - seq.append(inMap.at(id).customOp(), inMap.at(id).protoContext()); + seq.append(inMap.at(id), inMap.at(id).protoContext()); _sortedGraph[p.second._layerNum].append(std::move(seq)); } diff --git a/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp b/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp index 59798c27ec7e..5cc96512c354 100644 --- a/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp @@ -41,17 +41,17 @@ TEST_F(ExecutionLayerTests, test_reassign_1) { ExecutionLayer layer; OpSequence sequence1, sequence2; - ops::add op1; - ops::multiply op2; - ops::divide op3; + Node a(sd::ops::add(), "add"); + Node m(sd::ops::multiply(), "mul"); + Node d(sd::ops::divide(), "div"); Context ctx1(1); Context ctx2(2); Context ctx3(3); - sequence1.append(&op1, ctx1); - sequence2.append(&op2, ctx2); - sequence2.append(&op3, ctx3); + sequence1.append(a, ctx1); + sequence2.append(m, ctx2); + sequence2.append(d, ctx3); layer.append(sequence1); layer.append(sequence2); diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index 2d48c876eb89..465141d7288a 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -79,8 +79,8 @@ TEST_F(GraphExecutorTests, test_basic_exec_2) { ASSERT_EQ(2, m.protoContext().inputs().size()); ASSERT_EQ(2, a.protoContext().inputs().size()); - sequence.append(m.customOp(), m.protoContext()); - sequence.append(a.customOp(), a.protoContext()); + sequence.append(m, m.protoContext()); + sequence.append(a, a.protoContext()); optimizedGraph.append(sequence); From d1389266f126a4f6b0021cd923530a2aef312292 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 9 Jun 2020 09:50:20 +0300 Subject: [PATCH 145/233] minor Node overhaul Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/Graph.h | 11 +- libnd4j/include/graph/Node.h | 151 ++------ libnd4j/include/graph/Scope.h | 108 ------ libnd4j/include/graph/impl/Graph.cpp | 16 +- libnd4j/include/graph/impl/Node.cpp | 333 ++---------------- libnd4j/include/graph/impl/Scope.cpp | 58 --- .../graph/logic/impl/LogicExecutor.cpp | 4 +- libnd4j/tests_cpu/layers_tests/NodeTests.cpp | 41 +-- 8 files changed, 84 insertions(+), 638 deletions(-) delete mode 100644 libnd4j/include/graph/Scope.h delete mode 100644 libnd4j/include/graph/impl/Scope.cpp diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 5a0d3093892c..81dfc67c31ce 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -29,7 +29,6 @@ #include #include #include -#include #include #include #include @@ -48,6 +47,8 @@ class SD_EXPORT Graph { protected: ExecutorConfiguration _configuration; VariableSpace _variableSpace; + + // TODO: these 2 fields should be deleted memory::Workspace _workspace; Stash _stash; @@ -134,12 +135,10 @@ class SD_EXPORT Graph { * These methods add given node to the graph * @param node */ - void addNode(Node &&node, const std::initializer_list &inputs); + void addNode(Node &&node, const std::vector &inputs); + + void addNode(Node &node, const std::vector &inputs); - void addNode(Node &node, const std::initializer_list &inputs); - void addNode(Node &node, const std::initializer_list &inputs); - void addNode(Node &node, - const std::initializer_list> &inputs); void addVariable(const std::string &name, NDArray &array); void addVariable(const std::string &name, NDArray &&array); diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 0f2e78488771..71b210b948a5 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -38,60 +38,40 @@ class Graph; class SD_EXPORT Node { protected: - // TODO: this field must be removed - sd::DataType _dataType; + // int and string IDs + int _id = 0; + std::string _name; - OpType _opType; + // Node state, basically ContextPrototype _protoContext; + + // these 2 fields are used for Logic ops only + OpType _opType; + OpClass _opClass; Nd4jLong _opNum; - int _id = 0; + // Inputs are stored in format std::vector> _input; + + // Outputs are stored in format std::vector> _output; + + // Control flow dependencies for Node std::vector> _dependencies; std::vector _stringDependencies; - std::vector _dimensions; - - std::vector _referencedBy; - - std::string _name; - - // many ops require extra parameters to run - double *_extraParams = nullptr; - - bool _hasExternalOutputs; - bool _hasExternalInputs; - bool _hasInternalOutputs; - bool _hasInternalInputs; - - // this field is used to check, if op should be used in-place (so it can/will - // modify its inputs) - bool _isInplace = false; - - OpClass _opClass; + // service state fields + bool _hasExternalOutputs = false; + bool _hasExternalInputs = false; + bool _hasInternalOutputs = false; + bool _hasInternalInputs = false; - // these fields are used to store embedded CustomOps and Graph in case of - // Graph-in-Graph scenario - Graph *_graph = nullptr; std::shared_ptr _customOp; // each node can be active or inactive, if used with divergents, like IF // statements bool _active = true; - // meh - mutable bool _removable = true; - - // these fields contain information about Scope these ops are related to - int _scope_id = 0; - std::string _scope_name; - - // TODO: these 3 fields should be removed - int _rewindNode = -1; - std::pair _rewindLayer = {-1, -1}; - Nd4jLong _frameId = -1; - public: explicit Node(const sd::ops::DeclarableOp &op, const std::string &nodeName = {}, @@ -147,120 +127,61 @@ class SD_EXPORT Node { // move assignment operator Node &operator=(Node &&other) noexcept; - bool equals(Node *other) const; + bool equals(const Node *other) const; + bool equals(const Node &other) const; - sd::DataType dataType(); const ContextPrototype &protoContext() const; - OpType opType() const; + + OpType opType() const { return _opType; }; + OpClass opClass() const { return _opClass;}; + Nd4jLong opNum() const; int id() const; + const std::vector> &input() const; const std::vector> &output() const; const std::vector> &dependencies() const; - Nd4jLong getFrameId(); - void setFrameId(Nd4jLong frameId); - - int getRewindNode(); - void setRewindNode(int nodeId); - - std::pair &getRewindLayer(); - void setRewindLayer(int layerId, int stepId = 0); - void setId(int id); - double *extraParams(); - bool isMultiInput(); bool isMultiOutput(); - bool isRemovable() const; - void markRemovable(bool reallyRemovable) const; - bool isDivergencePoint(); - void setActive(bool reallyActive); - bool isActive(); - bool hasExternalOutputs(); - bool hasExternalInputs(); - bool hasInternalOutputs(); - bool hasInternalInputs(); + bool hasExternalOutputs() const; + bool hasExternalInputs() const; + bool hasInternalOutputs() const; + bool hasInternalInputs() const; void pickOutputOnce(int outputId); void pickOutput(int outputId); void pickOutput(int nodeId, int outputId); + void pickExternalOutput(int outputId); + void pickInput(int inputId); void pickInput(int nodeId, int outputId); - void pickInput(std::pair &id); + void pickInput(const std::pair &id); void pickInput(const std::string &id); - void setName(std::string *name); void setName(const std::string &name); - const std::string &getName() const; const std::string &name() const; - int totalReferences(); - void addReference(int nodeId); - void setContextPrototype(const ContextPrototype &block); - const ContextPrototype &contextPrototype() const; - bool hasBlockAttached(); - void setCustomOp(std::shared_ptr customOp); + void setCustomOp(const std::shared_ptr &customOp); std::shared_ptr customOp() const; - bool hasCustomOp() const; - - void setGraph(Graph *graph = nullptr); - Graph *graph() const; - bool hasGraphEmbedded() const; - bool isInplace(); - void markInplace(bool reallyInplace); - - OpClass getOpClass() const; - - // these methods are used for internal profiling - void setOuterTime(Nd4jLong time); - void setInnerTime(Nd4jLong time); - - // methods related to scopes - bool isScoped(); - void setScopeInfo(int id, const char *name = nullptr); - int scopeId(); - std::string *scopeName(); + bool hasCustomOp() const; void setOpType(OpType opType); - // clone Node - Node *clone(); - - template - Node *asT(); - // this method converts string deps to int deps void actualizeDependencies(const MAP_IMPL &lookupTable) const; - FORCEINLINE void pullValues(Node *other) { - this->_dataType = other->dataType(); - this->_protoContext = other->protoContext(); - this->_hasExternalInputs = other->hasExternalInputs(); - this->_hasExternalOutputs = other->hasExternalOutputs(); - this->_hasInternalInputs = other->hasInternalInputs(); - this->_hasInternalOutputs = other->hasInternalOutputs(); - - this->markInplace(other->isInplace()); - this->setActive(other->isActive()); - this->setScopeInfo(other->scopeId(), other->scopeName()->c_str()); - - for (auto &v : other->input()) this->_input.emplace_back(v); - - for (auto &v : other->output()) this->_output.emplace_back(v); - } - - static std::shared_ptr buildOpByType( - OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum); - static void deleteOpByType(OpType opType, void *op); + // utility method that generates legacy ops out of OpType and OpNum + static std::shared_ptr buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum); }; } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/Scope.h b/libnd4j/include/graph/Scope.h deleted file mode 100644 index 660578503e45..000000000000 --- a/libnd4j/include/graph/Scope.h +++ /dev/null @@ -1,108 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 14.10.2017. -// - -#ifndef LIBND4J_SCOPE_H -#define LIBND4J_SCOPE_H - -#include - -#include -#include - -namespace sd { -namespace graph { - -/** - * Scope holds sequential list of operations, and made suitable for continuous - * re-execution of multiple operations. - * - * @tparam T - */ -class SD_EXPORT Scope { - protected: - // Graph-unique IDs for Scope instances - int _id; - std::string _name; - - // list of nodes to run, always sequential - // Graph takes care of topo sort - std::vector _nodes; - - public: - // attach GiG here, with shared namespace? - // or just rebuilt graph leaf? - // ¯\_(ツ)_/¯ - - // default consructor - explicit Scope(int id, const char* name = nullptr); - - // default destructor - ~Scope(); - - /** - * this method adds op node to the scope - * - * PLEASE NOTE: We assume that ops are being added ORDERED - */ - void push_back(Node* node); - - /** - * This method returns list of ops stored earlier, ready for execution - * - * PLEASE NOTE: If the scope is conditional - last op in list should be - * BooleanOp - * @return - */ - std::vector* nodes(); - - /** - * This function returns number of nodes in this scope - * - * @return - */ - int size(); - - /** - * Returns ID of this scope - * @return - */ - int id(); - - /** - * Returns name of this scope - * - * @return - */ - std::string* name(); - - /** - * This method returns clone of this Scope - */ - Scope* clone(); - - /** - * This method removes all Nodes from this scope - */ - void forgetNodes(); -}; -} // namespace graph -} // namespace sd - -#endif // LIBND4J_SCOPE_H diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index a522e800b000..b84e2afa75ba 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -78,13 +78,12 @@ void Graph::addVariable(const std::string &name, NDArray &&array) { } void Graph::addNode(Node &&node, - const std::initializer_list &inputs) { + const std::vector &inputs) { auto lvalue = std::move(node); addNode(lvalue, inputs); } -void Graph::addNode(Node &node, - const std::initializer_list &inputs) { +void Graph::addNode(Node &node, const std::vector &inputs) { // temporary check. basically we're okay if Node has id defined if (node.id() != 0) throw std::runtime_error("Graph::addNode - Node has id defined"); @@ -116,17 +115,6 @@ void Graph::addNode(Node &node, _unmapped[node.id()] = node; } -void Graph::addNode(Node &node, const std::initializer_list &inputs) { - throw std::runtime_error("Graph::addNode() - Not implemented yet"); -} - -void Graph::addNode(Node &node, - const std::initializer_list> &inputs) { - node.markRemovable(false); - - throw std::runtime_error("Graph::addNode() - Not implemented yet"); -} - Graph::Graph(const FlatGraph *flatGraph, const GraphMemoryManager &memoryManager): _memoryManager(memoryManager) { bool trusted = flatGraph != nullptr; diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 39f668ba64f9..359883d360aa 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -41,6 +41,7 @@ #include #include #include +#include namespace sd { namespace graph { @@ -53,8 +54,6 @@ Node::Node(const ops::DeclarableOp &opName, const std::string &nodeName, this->_name = nodeName; this->_opType = OpType_CUSTOM; this->_opNum = customOp->getOpHash(); - this->_extraParams = nullptr; - this->_dataType = sd::DataType::FLOAT32; // float as default this->_customOp = customOp; _hasExternalInputs = false; @@ -82,8 +81,6 @@ Node::Node(const std::string &opName, const std::string &nodeName, this->_name = nodeName; this->_opType = OpType_CUSTOM; this->_opNum = customOp->getOpHash(); - this->_extraParams = nullptr; - this->_dataType = sd::DataType::FLOAT32; // float as default this->_customOp = customOp; _hasExternalInputs = false; @@ -103,37 +100,6 @@ Node::Node(const std::string &opName, const std::string &nodeName, this->setContextPrototype(block); } -void Node::setOuterTime(Nd4jLong time) { - // if (hasBlockAttached()) - // _block->setOuterTime(time); -} - -void Node::setInnerTime(Nd4jLong time) { - // if (hasBlockAttached()) - // _block->setInnerTime(time); -} - -void Node::setGraph(Graph *graph) { _graph = graph; } - -Graph *Node::graph() const { return _graph; } - -void Node::markInplace(bool reallyInplace) { - _isInplace = reallyInplace; - _protoContext.markInplace(reallyInplace); -} - -bool Node::isRemovable() const { return _removable; } - -void Node::markRemovable(bool reallyRemovable) const { - _removable = reallyRemovable; -} - -OpClass Node::getOpClass() const { return _opClass; } - -bool Node::hasBlockAttached() { return true; } - -bool Node::isInplace() { return _isInplace; } - bool Node::isDivergencePoint() { if (hasCustomOp()) { return _customOp->getOpDescriptor()->isDivergent(); @@ -143,16 +109,6 @@ bool Node::isDivergencePoint() { return false; } -void Node::setActive(bool reallyActive) { _active = reallyActive; } - -bool Node::isActive() { return _active; } - -Nd4jLong Node::getFrameId() { return _frameId; } - -void Node::setFrameId(Nd4jLong frameId) { _frameId = frameId; } - -const ContextPrototype &Node::contextPrototype() const { return _protoContext; } - void Node::setContextPrototype(const ContextPrototype &block) { _protoContext = block; } @@ -166,25 +122,18 @@ std::shared_ptr Node::customOp() const { return _customOp; } -void Node::setCustomOp(std::shared_ptr customOp) { +void Node::setCustomOp(const std::shared_ptr& customOp) { _customOp = customOp; - - // divergent ops (Switch etc) are always inplace, they don't allocate anything - if (_customOp.get() != nullptr && _customOp->getOpDescriptor()->isDivergent()) - _isInplace = true; } bool Node::hasCustomOp() const { return _customOp != nullptr; } -const std::string &Node::name() const { return this->getName(); } +const std::string &Node::name() const { return _name; } -const std::string &Node::getName() const { return _name; } void Node::setName(const std::string &name) { _name = name; } -void Node::setName(std::string *name) { _name = *name; } - -void Node::pickInput(std::pair &pair) { +void Node::pickInput(const std::pair &pair) { _input.push_back(pair); _protoContext.pickInput(pair); } @@ -235,26 +184,18 @@ void Node::pickOutput(int outputId) { _hasInternalOutputs = true; } -bool Node::hasExternalOutputs() { return _hasExternalOutputs; } +bool Node::hasExternalOutputs() const { return _hasExternalOutputs; } -bool Node::hasExternalInputs() { return _hasExternalInputs; } +bool Node::hasExternalInputs() const { return _hasExternalInputs; } -bool Node::hasInternalOutputs() { return _hasInternalOutputs; } +bool Node::hasInternalOutputs() const { return _hasInternalOutputs; } -bool Node::hasInternalInputs() { return _hasInternalInputs; } +bool Node::hasInternalInputs() const { return _hasInternalInputs; } bool Node::isMultiInput() { return _input.size() > 1; } bool Node::isMultiOutput() { return _output.size() > 1; } -double *Node::extraParams() { return _extraParams; } - -int Node::totalReferences() { return _referencedBy.size(); } - -void Node::addReference(int nodeId) { _referencedBy.emplace_back(nodeId); } - -OpType Node::opType() const { return _opType; } - int Node::id() const { return _id; } Nd4jLong Node::opNum() const { return _opNum; } @@ -263,26 +204,6 @@ const std::vector> &Node::input() const { return _input; } const std::vector> &Node::output() const { return _output; } -bool Node::isScoped() { return _scope_id != 0; } - -void Node::setScopeInfo(int id, const char *name) { - _scope_id = id; - - if (name != nullptr) _scope_name = name; -} - -int Node::scopeId() { return _scope_id; } - -std::string *Node::scopeName() { return &_scope_name; } - -template -Node *Node::asT() { - auto node = this->clone(); - node->_dataType = DataTypeUtils::fromT(); - return node; -} -BUILD_SINGLE_TEMPLATE(template SD_EXPORT Node *Node::asT, (), LIBND4J_TYPES); - Node::Node(const std::string &opName, const std::string &nodeName, const int id, const std::vector &inputs, const std::vector &tArgs, @@ -292,15 +213,8 @@ Node::Node(const std::string &opName, const std::string &nodeName, const int id, this->_opType = OpType_CUSTOM; this->_id = id; this->_opNum = customOp->getOpHash(); - this->_extraParams = nullptr; - this->_dataType = sd::DataType::FLOAT32; // float as default this->_customOp = customOp; - _hasExternalInputs = false; - _hasExternalOutputs = false; - _hasInternalInputs = false; - _hasInternalOutputs = false; - for (auto i : inputs) pickInput(i); ContextPrototype block(this->customOp()->getOpDescriptor(), this->id(), @@ -321,14 +235,8 @@ Node::Node(const std::string &opName, const int id, this->_opType = OpType_CUSTOM; this->_id = id; this->_opNum = customOp->getOpHash(); - this->_extraParams = nullptr; - this->_dataType = sd::DataType::FLOAT32; // float as default this->_customOp = customOp; - _hasExternalInputs = false; - _hasExternalOutputs = false; - _hasInternalInputs = false; - _hasInternalOutputs = false; for (auto i : inputs) pickInput(i); @@ -349,8 +257,6 @@ Node::Node(sd::ops::DeclarableOp *customOp, int id, this->_opType = OpType_CUSTOM; this->_id = id; this->_opNum = customOp->getOpHash(); - this->_extraParams = nullptr; - this->_dataType = sd::DataType::FLOAT32; // float as default // if custom op is a registered one - pull it from cache, otherwise - clone // locally @@ -361,11 +267,6 @@ Node::Node(sd::ops::DeclarableOp *customOp, int id, throw std::runtime_error( "Can't create a node with custom operation within"); - _hasExternalInputs = false; - _hasExternalOutputs = false; - _hasInternalInputs = false; - _hasInternalOutputs = false; - for (auto i : input) pickInput(i); for (auto o : output) pickOutput(o); @@ -405,8 +306,6 @@ Node::Node(OpType opType, int opNum, int id, std::initializer_list input, this->_opType = opType; this->_id = id; this->_opNum = opNum; - this->_extraParams = nullptr; - this->_dataType = sd::DataType::FLOAT32; // float as default _hasExternalInputs = false; _hasExternalOutputs = false; @@ -421,9 +320,7 @@ Node::Node(OpType opType, int opNum, int id, std::initializer_list input, if (opType == OpType_TRANSFORM_SAME || opType == OpType_TRANSFORM_FLOAT || opType == OpType_TRANSFORM_STRICT || opType == OpType_TRANSFORM_BOOL || opType == OpType_SCALAR || opType == OpType_BROADCAST) { - if (_output.size() <= 1) { - _isInplace = true; - } + _opClass = OpClass_TRANSFORM; } else if (opType == OpType_REDUCE_SAME || opType == OpType_REDUCE_FLOAT || opType == OpType_REDUCE_BOOL || opType == OpType_REDUCE_LONG || @@ -472,17 +369,9 @@ Node::Node(OpType opType, int opNum, int id, std::initializer_list input, }; Node::Node(const FlatNode *node) { - _hasExternalInputs = false; - _hasExternalOutputs = false; - _hasInternalInputs = false; - _hasInternalOutputs = false; - _extraParams = nullptr; - - _dataType = sd::DataType::FLOAT32; // float as default - if (node->scope_id() != 0) this->_scope_id = node->scope_id(); - - if (node->scope_name() != nullptr && node->scope_name()->size() > 0) - this->_scope_name = node->scope_name()->str(); + // temporary holders for _extras and _dimensions, for transferring those into ContextPrototype + std::vector extras; + std::vector axis; if (node->scalar() != nullptr) throw std::runtime_error("FlatNode has scalar defined, it's deprecated"); @@ -532,36 +421,27 @@ Node::Node(const FlatNode *node) { _stringDependencies.emplace_back(node->controlDeps()->Get(e)->str()); } - /* - if (node->output() != nullptr) - for (int e = 0; e < (int) node->output()->size(); e++) { - auto oid = node->output()->Get(e); - if (oid != this->_id && oid != 0) { - nd4j_verbose("Picking output: %i\n", node->output()->Get(e)); - pickOutput(oid); - } - } - */ - + // transferring extraParams. Used for legacy ops only if (node->extraParams() != nullptr && node->extraParams()->size() > 0) { - _extraParams = new double[node->extraParams()->size()]; - for (int e = 0; e < (int)node->extraParams()->size(); e++) { - _extraParams[e] = static_cast(node->extraParams()->Get(e)); - } + extras.resize(node->extraParams()->size()); + + for (int e = 0; e < (int)node->extraParams()->size(); e++) + extras[e] = static_cast(node->extraParams()->Get(e)); } - // if (node->dimensions() != nullptr && node->dimensions()->size() > 0) - // throw std::runtime_error("FlatNode has dimensions defined. Graph is - // outdated"); + // transferring dimensions. Used for legacy ops only + if (node->dimensions() != nullptr && node->dimensions()->size() > 0) { + axis.resize(node->dimensions()->size()); + + for (int e = 0; e < (int) node->dimensions()->size(); e++) + axis[e] = node->dimensions()->Get(e); + } if (this->opType() == OpType_LOGIC && this->opNum() == 100L) { - if (node->extraInteger()->size() < 1) { - nd4j_printf("Node_%i is type of Enter, but has no FrameID defined\n", - this->id()); - throw std::runtime_error("Enter node must have FrameID specified"); - } + if (node->extraInteger()->size() < 1) + throw std::runtime_error("Enter Node [" + StringUtils::valueToString(this->id()) + "] must have FrameID specified"); - this->setFrameId(node->extraInteger()->Get(0)); + //this->setFrameId(node->extraInteger()->Get(0)); } // these ops allow in-place execution by design @@ -574,16 +454,13 @@ Node::Node(const FlatNode *node) { _opType == OpType_TRANSFORM_BOOL || _opType == OpType_RANDOM || _opType == OpType_PAIRWISE || _opType == OpType_PAIRWISE_BOOL || _opType == OpType_SCALAR_BOOL || _opType == OpType_SCALAR) { - if (_output.size() <= 1) { - _isInplace = true; - } if (node->input() != nullptr && node->input()->size() > 0) { ContextPrototype block(nullptr, this->id(), false); if (!this->name().empty()) block.setName(this->name()); - for (auto v : _dimensions) block.appendA(v); + for (auto v : axis) block.appendA(v); if (node->extraParams() != nullptr && node->extraParams()->size() > 0) for (int e = 0; e < (int) node->extraParams()->size(); e++) { @@ -621,7 +498,7 @@ Node::Node(const FlatNode *node) { } // there's no other IArgs in legacy options, actually - for (auto v : _dimensions) block.appendA(v); + for (auto v : axis) block.appendA(v); if (node->extraParams() != nullptr && node->extraParams()->size() > 0) for (int e = 0; e < (int) node->extraParams()->size(); e++) { @@ -700,7 +577,7 @@ Node::Node(const FlatNode *node) { } } - for (auto v : _dimensions) block.appendA(v); + for (auto v : axis) block.appendA(v); this->setContextPrototype(block); this->setCustomOp(op); @@ -711,131 +588,88 @@ Node::Node(const FlatNode *node) { } } -sd::DataType Node::dataType() { return _dataType; } const ContextPrototype &Node::protoContext() const { return _protoContext; } -Node::~Node() { - if (_extraParams != nullptr) delete[] _extraParams; -} - -int Node::getRewindNode() { return _rewindNode; } - -void Node::setRewindNode(int nodeId) { _rewindNode = nodeId; } +Node::~Node() { } -std::pair &Node::getRewindLayer() { return _rewindLayer; }; - -void Node::setRewindLayer(int layerId, int stepId) { - _rewindLayer.first = layerId; - _rewindLayer.second = stepId; -} - -bool Node::equals(Node *other) const { - if (_opType == other->_opType && _dataType == other->_dataType && - _opNum == other->_opNum) +bool Node::equals(const Node *other) const { + if (_opType == other->_opType && _opNum == other->_opNum) return true; return false; } +bool Node::equals(const Node &other) const { + return this->equals(&other); +} + Node::Node(const Node &other) noexcept { - _dataType = other._dataType; _opType = other._opType; _opClass = other._opClass; _opNum = other._opNum; _customOp = other._customOp; _name = other._name; - _scope_id = other._scope_id; - _scope_name = other._scope_name; - _rewindNode = other._rewindNode; _id = other._id; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; _hasInternalOutputs = other._hasInternalOutputs; _hasInternalInputs = other._hasInternalInputs; - _isInplace = other._isInplace; _active = other._active; - _removable = other._removable; - _graph = other._graph; _customOp = other._customOp; - _extraParams = other._extraParams; _protoContext = other._protoContext; _input = other._input; _output = other._output; - _dimensions = other._dimensions; - _rewindLayer = other._rewindLayer; - _referencedBy = other._referencedBy; } Node &Node::operator=(const Node &other) noexcept { if (this == &other) return *this; - _dataType = other._dataType; _opType = other._opType; _opClass = other._opClass; _opNum = other._opNum; _customOp = other._customOp; _name = other._name; - _scope_id = other._scope_id; - _scope_name = other._scope_name; - _rewindNode = other._rewindNode; _id = other._id; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; _hasInternalOutputs = other._hasInternalOutputs; _hasInternalInputs = other._hasInternalInputs; - _isInplace = other._isInplace; _active = other._active; - _removable = other._removable; - _graph = other._graph; _customOp = other._customOp; - _extraParams = other._extraParams; _protoContext = other._protoContext; _input = other._input; _output = other._output; - _dimensions = other._dimensions; - _rewindLayer = other._rewindLayer; - _referencedBy = other._referencedBy; return *this; } Node::Node(Node &&other) noexcept { - _dataType = other._dataType; + _opType = other._opType; _opClass = other._opClass; _opNum = other._opNum; _customOp = other._customOp; - _scope_id = other._scope_id; _name = std::move(other._name); - _scope_name = std::move(other._scope_name); - _rewindNode = other._rewindNode; _id = other._id; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; _hasInternalOutputs = other._hasInternalOutputs; _hasInternalInputs = other._hasInternalInputs; - _isInplace = other._isInplace; _active = other._active; - _removable = other._removable; - _graph = other._graph; - _extraParams = other._extraParams; _protoContext = other._protoContext; _customOp = std::move(other._customOp); _input = std::move(other._input); _output = std::move(other._output); - _dimensions = std::move(other._dimensions); - _rewindLayer = std::move(other._rewindLayer); - _referencedBy = std::move(other._referencedBy); other._customOp = nullptr; } @@ -843,103 +677,28 @@ Node::Node(Node &&other) noexcept { Node &Node::operator=(Node &&other) noexcept { if (this == &other) return *this; - _dataType = other._dataType; _opType = other._opType; _opClass = other._opClass; _opNum = other._opNum; _customOp = other._customOp; - _scope_id = other._scope_id; _name = std::move(other._name); - _scope_name = std::move(other._scope_name); - _rewindNode = other._rewindNode; _id = other._id; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; _hasInternalOutputs = other._hasInternalOutputs; _hasInternalInputs = other._hasInternalInputs; - _isInplace = other._isInplace; _active = other._active; - _removable = other._removable; - _graph = other._graph; - _extraParams = other._extraParams; _protoContext = other._protoContext; _customOp = std::move(other._customOp); _input = std::move(other._input); _output = std::move(other._output); - _dimensions = std::move(other._dimensions); - _rewindLayer = std::move(other._rewindLayer); - _referencedBy = std::move(other._referencedBy); return *this; } -void Node::deleteOpByType(OpType opType, void *op) { - switch (opType) { - case OpType_PAIRWISE: - delete reinterpret_cast(op); - break; - case OpType_PAIRWISE_BOOL: - delete reinterpret_cast(op); - break; - case OpType_TRANSFORM_STRICT: - delete reinterpret_cast(op); - break; - case OpType_TRANSFORM_SAME: - delete reinterpret_cast(op); - break; - case OpType_TRANSFORM_FLOAT: - delete reinterpret_cast(op); - break; - case OpType_TRANSFORM_BOOL: - delete reinterpret_cast(op); - break; - case OpType_SCALAR: - delete reinterpret_cast(op); - break; - case OpType_SCALAR_BOOL: - delete reinterpret_cast(op); - break; - case OpType_REDUCE_3: - delete reinterpret_cast(op); - break; - case OpType_REDUCE_SAME: - delete reinterpret_cast(op); - break; - case OpType_REDUCE_FLOAT: - delete reinterpret_cast(op); - break; - case OpType_REDUCE_LONG: - delete reinterpret_cast(op); - break; - case OpType_REDUCE_BOOL: - delete reinterpret_cast(op); - break; - case OpType_INDEX_REDUCE: - delete reinterpret_cast(op); - break; - case OpType_SUMMARYSTATS: - delete reinterpret_cast(op); - break; - case OpType_RANDOM: - delete reinterpret_cast(op); - break; - case OpType_BROADCAST: - delete reinterpret_cast(op); - break; - case OpType_BROADCAST_BOOL: - delete reinterpret_cast(op); - break; - case OpType_CUSTOM: - delete reinterpret_cast(op); - break; - default: - throw std::runtime_error("Bad opType passed in"); - } -} - std::shared_ptr Node::buildOpByType( OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum) { switch (opType) { @@ -984,21 +743,5 @@ std::shared_ptr Node::buildOpByType( } } -Node *Node::clone() { - if (this->_customOp && this->_opType == OpType_CUSTOM) { - auto clone = new Node(_customOp.get(), _id); - clone->pullValues(this); - return clone; - } else { - auto clone = new Node(_opType, _opNum, _id); - - clone->pullValues(this); - - // op time - clone->_customOp = _customOp; - - return clone; - } -} } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/impl/Scope.cpp b/libnd4j/include/graph/impl/Scope.cpp deleted file mode 100644 index 9332537d4021..000000000000 --- a/libnd4j/include/graph/impl/Scope.cpp +++ /dev/null @@ -1,58 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 14.10.2017. -// - -#include - -namespace sd { -namespace graph { -Scope::Scope(int id, const char* name) { - _id = id; - - if (name != nullptr) - _name = name; - else - name = ""; -} - -Scope::~Scope() { - for (auto v : _nodes) delete v; -} - -void Scope::push_back(Node* node) { _nodes.emplace_back(node); } - -std::vector* Scope::nodes() { return &_nodes; } - -int Scope::size() { return (int)_nodes.size(); } - -int Scope::id() { return _id; } - -std::string* Scope::name() { return &_name; } - -void Scope::forgetNodes() { _nodes.clear(); } - -Scope* Scope::clone() { - auto clone = new Scope(_id, _name.c_str()); - - for (auto v : _nodes) clone->_nodes.emplace_back(v->clone()); - - return clone; -} -} // namespace graph -} // namespace sd diff --git a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp index 2e8fbba2b195..ee9ece1e773b 100644 --- a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp @@ -59,12 +59,12 @@ Nd4jStatus LogicExecutor::processNode(Graph *graph, Node *node) { return LogicEnter::processNode(graph, node); } - if (node->getName().empty()) { + if (node->name().empty()) { nd4j_printf("Unknown LogicOp used at node [%i]: [%i]\n", node->id(), node->opNum()); } else { nd4j_printf("Unknown LogicOp used at node [%i:<%s>]: [%i]\n", node->id(), - node->getName().c_str(), node->opNum()); + node->name().c_str(), node->opNum()); } return ND4J_STATUS_BAD_INPUT; } diff --git a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp index f180bb24c038..93d73ad4af8a 100644 --- a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp @@ -30,43 +30,4 @@ using namespace sd::graph; class NodeTests : public testing::Test { public: -}; - -TEST_F(NodeTests, Test_Dtype_Conversion_1) { - auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}, {2}); - - auto nd = nodeA->asT(); - auto nf = nd->asT(); - - ASSERT_EQ(nodeA->id(), nf->id()); - ASSERT_EQ(nodeA->name(), nf->name()); - ASSERT_EQ(nodeA->getOpClass(), nf->getOpClass()); - ASSERT_EQ(nodeA->opType(), nf->opType()); - ASSERT_EQ(nodeA->opNum(), nf->opNum()); - - delete nodeA; - delete nd; - delete nf; -} - -TEST_F(NodeTests, Test_Dtype_Conversion_2) { - sd::ops::add opA; - - // auto nodeA = new Node(OpType_CUSTOM, 0, 1, {-1}, {2}); - auto nodeA = new Node(&opA, 1, {-1}, {2}); - // nodeA->setCustomOp(&op); - - auto nd = nodeA->asT(); - auto nf = nd->asT(); - - ASSERT_EQ(nodeA->id(), nf->id()); - ASSERT_EQ(nodeA->name(), nf->name()); - // ASSERT_EQ(nodeA->getOpClass(), nf->getOpClass()); - ASSERT_EQ(nodeA->opType(), nf->opType()); - ASSERT_EQ(nodeA->opNum(), nf->opNum()); - ASSERT_EQ(nodeA->customOp()->getOpHash(), nf->customOp()->getOpHash()); - - delete nodeA; - delete nd; - delete nf; -} \ No newline at end of file +}; \ No newline at end of file From 715b124c3528ddae20ccce5a8467c07a3a09be3c Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 9 Jun 2020 10:18:38 +0300 Subject: [PATCH 146/233] we don't need extras Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/impl/Node.cpp | 64 +++++++++-------------------- 1 file changed, 19 insertions(+), 45 deletions(-) diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 359883d360aa..252e744d0584 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -369,8 +369,7 @@ Node::Node(OpType opType, int opNum, int id, std::initializer_list input, }; Node::Node(const FlatNode *node) { - // temporary holders for _extras and _dimensions, for transferring those into ContextPrototype - std::vector extras; + // temporary holders _dimensions, for transferring axis into ContextPrototype std::vector axis; if (node->scalar() != nullptr) @@ -397,8 +396,7 @@ Node::Node(const FlatNode *node) { } else { if (this->opType() != OpType_LOGIC) { if (this->_name.size() > 0) { - nd4j_debug("Node [%i:<%s>] has no inputs defined\n", this->_id, - this->_name.c_str()); + nd4j_debug("Node [%i:<%s>] has no inputs defined\n", this->_id, this->_name.c_str()); } else { nd4j_debug("Node [%i:] has no inputs defined\n", this->_id); } @@ -406,28 +404,18 @@ Node::Node(const FlatNode *node) { } // reading control deps, and filling _dependencies field - if (node->varControlDeps() != nullptr && node->varControlDeps()->size() > 0) { + if (node->varControlDeps() != nullptr && node->varControlDeps()->size() > 0) for (int e = 0; e < node->varControlDeps()->size(); e++) _stringDependencies.emplace_back(node->varControlDeps()->Get(e)->str()); - } - if (node->controlDepFor() != nullptr && node->controlDepFor()->size() > 0) { + if (node->controlDepFor() != nullptr && node->controlDepFor()->size() > 0) for (int e = 0; e < node->controlDepFor()->size(); e++) _stringDependencies.emplace_back(node->controlDepFor()->Get(e)->str()); - } - if (node->controlDeps() != nullptr && node->controlDeps()->size() > 0) { + if (node->controlDeps() != nullptr && node->controlDeps()->size() > 0) for (int e = 0; e < node->controlDeps()->size(); e++) _stringDependencies.emplace_back(node->controlDeps()->Get(e)->str()); - } - // transferring extraParams. Used for legacy ops only - if (node->extraParams() != nullptr && node->extraParams()->size() > 0) { - extras.resize(node->extraParams()->size()); - - for (int e = 0; e < (int)node->extraParams()->size(); e++) - extras[e] = static_cast(node->extraParams()->Get(e)); - } // transferring dimensions. Used for legacy ops only if (node->dimensions() != nullptr && node->dimensions()->size() > 0) { @@ -463,25 +451,20 @@ Node::Node(const FlatNode *node) { for (auto v : axis) block.appendA(v); if (node->extraParams() != nullptr && node->extraParams()->size() > 0) - for (int e = 0; e < (int) node->extraParams()->size(); e++) { + for (int e = 0; e < (int) node->extraParams()->size(); e++) block.appendT(static_cast(node->extraParams()->Get(e))); - } if (node->extraBools() != nullptr && node->extraBools()->size() > 0) - for (int e = 0; e < (int) node->extraBools()->size(); e++) { + for (int e = 0; e < (int) node->extraBools()->size(); e++) block.appendB(node->extraBools()->Get(e)); - } if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0) - for (int e = 0; e < (int) node->extraInteger()->size(); e++) { + for (int e = 0; e < (int) node->extraInteger()->size(); e++) block.appendI(node->extraInteger()->Get(e)); - } - if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { - for (int e = 0; e < (int) node->extraTypes()->size(); e++) { + if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) + for (int e = 0; e < (int) node->extraTypes()->size(); e++) block.appendD((sd::DataType) node->extraTypes()->Get(e)); - } - } this->setContextPrototype(block); this->setCustomOp(Node::buildOpByType( @@ -501,25 +484,20 @@ Node::Node(const FlatNode *node) { for (auto v : axis) block.appendA(v); if (node->extraParams() != nullptr && node->extraParams()->size() > 0) - for (int e = 0; e < (int) node->extraParams()->size(); e++) { + for (int e = 0; e < (int) node->extraParams()->size(); e++) block.appendT(static_cast(node->extraParams()->Get(e))); - } if (node->extraBools() != nullptr && node->extraBools()->size() > 0) - for (int e = 0; e < (int) node->extraBools()->size(); e++) { + for (int e = 0; e < (int) node->extraBools()->size(); e++) block.appendB(node->extraBools()->Get(e)); - } if (node->extraInteger() != nullptr && node->extraInteger()->size() > 0) - for (int e = 0; e < (int) node->extraInteger()->size(); e++) { + for (int e = 0; e < (int) node->extraInteger()->size(); e++) block.appendI(node->extraInteger()->Get(e)); - } - if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) { - for (int e = 0; e < (int) node->extraTypes()->size(); e++) { + if (node->extraTypes() != nullptr && node->extraTypes()->size() > 0) + for (int e = 0; e < (int) node->extraTypes()->size(); e++) block.appendD((sd::DataType) node->extraTypes()->Get(e)); - } - } this->setContextPrototype(block); @@ -534,26 +512,22 @@ Node::Node(const FlatNode *node) { if (!this->name().empty()) block.setName(this->name()); - for (int e = 0; e < this->input().size(); e++) { + for (int e = 0; e < this->input().size(); e++) block.pickInput(this->input().at(e)); - } this->setContextPrototype(block); } else if (this->_opType == OpType_CUSTOM) { auto op = sd::ops::OpRegistrator::getInstance().getOperation(this->opNum()); - if (op == nullptr) { - nd4j_verbose("Can't find operation: %lld\n", this->opNum()); - throw std::runtime_error("Can't find requested operation"); - } + if (op == nullptr) + throw std::runtime_error("Can't find requested operation [" + StringUtils::valueToString(this->opNum()) + "]"); ContextPrototype block(nullptr, this->id()); if (!this->name().empty()) block.setName(this->name()); - for (int e = 0; e < this->input().size(); e++) { + for (int e = 0; e < this->input().size(); e++) block.pickInput(this->input().at(e)); - } if (node->extraInteger() != nullptr) for (uint32_t e = 0; e < node->extraInteger()->size(); e++) { From d0517346c2245b759844d19a0885a932b306be02 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 9 Jun 2020 10:52:14 +0300 Subject: [PATCH 147/233] meh Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/Node.h | 14 ++++++-------- libnd4j/include/graph/impl/Graph.cpp | 2 +- libnd4j/include/graph/impl/Node.cpp | 4 +--- libnd4j/include/graph/impl/OptimizedGraph.cpp | 8 ++++---- libnd4j/include/graph/logic/impl/LogicExit.cpp | 2 +- libnd4j/include/graph/logic/impl/LogicLoopCond.cpp | 2 +- .../tests_cpu/layers_tests/GraphExecutorTests.cpp | 8 ++++---- libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp | 2 +- 8 files changed, 19 insertions(+), 23 deletions(-) diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 71b210b948a5..073f864eac0c 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -130,13 +130,15 @@ class SD_EXPORT Node { bool equals(const Node *other) const; bool equals(const Node &other) const; - const ContextPrototype &protoContext() const; - OpType opType() const { return _opType; }; OpClass opClass() const { return _opClass;}; - Nd4jLong opNum() const; int id() const; + const std::string &name() const; + + void setName(const std::string &name); + + Nd4jLong opNum() const; const std::vector> &input() const; const std::vector> &output() const; @@ -165,9 +167,7 @@ class SD_EXPORT Node { void pickInput(const std::pair &id); void pickInput(const std::string &id); - void setName(const std::string &name); - const std::string &name() const; - + const ContextPrototype &contextPrototype() const; void setContextPrototype(const ContextPrototype &block); void setCustomOp(const std::shared_ptr &customOp); @@ -175,8 +175,6 @@ class SD_EXPORT Node { bool hasCustomOp() const; - void setOpType(OpType opType); - // this method converts string deps to int deps void actualizeDependencies(const MAP_IMPL &lookupTable) const; diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index b84e2afa75ba..c942f87f35a3 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -225,7 +225,7 @@ void Graph::printOutNode(const Node &node) const { } if (node.opType() == OpType_CUSTOM) { - auto ctx = node.protoContext(); + auto ctx = node.contextPrototype(); if (ctx.numI() > 0) { printf("]; iArgs: ["); diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 252e744d0584..5f1dfe243f2e 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -283,8 +283,6 @@ Node::Node(sd::ops::DeclarableOp *customOp, int id, this->setContextPrototype(block); } -void Node::setOpType(OpType opType) { this->_opType = opType; } - const std::vector>& Node::dependencies() const { return _dependencies; } @@ -563,7 +561,7 @@ Node::Node(const FlatNode *node) { } -const ContextPrototype &Node::protoContext() const { return _protoContext; } +const ContextPrototype &Node::contextPrototype() const { return _protoContext; } Node::~Node() { } diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 1f5e0db405b4..ae6f9a81eb72 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -118,10 +118,10 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS for (const auto& p : workMap) { OpSequence seq; - seq.append(inMap.at(p.first), inMap.at(p.first).protoContext()); + seq.append(inMap.at(p.first), inMap.at(p.first).contextPrototype()); for (const auto& id : p.second._opSeq) - seq.append(inMap.at(id), inMap.at(id).protoContext()); + seq.append(inMap.at(id), inMap.at(id).contextPrototype()); _sortedGraph[p.second._layerNum].append(std::move(seq)); } @@ -179,11 +179,11 @@ void OptimizedGraph::printOut() const { // for (const auto& p : workMap) { // OpSequence seq; - // seq.append(inMap.at(p.first).customOp(), inMap.at(p.first).protoContext()); + // seq.append(inMap.at(p.first).customOp(), inMap.at(p.first).contextPrototype()); // printGraph[p.second._layerNum].push_back(p.first); // for (const auto& id : p.second._opSeq) { - // seq.append(inMap.at(id).customOp(), inMap.at(id).protoContext()); + // seq.append(inMap.at(id).customOp(), inMap.at(id).contextPrototype()); // printGraph[p.second._layerNum].push_back(id); // } diff --git a/libnd4j/include/graph/logic/impl/LogicExit.cpp b/libnd4j/include/graph/logic/impl/LogicExit.cpp index 8753ed56c184..1d0be2c7da91 100644 --- a/libnd4j/include/graph/logic/impl/LogicExit.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExit.cpp @@ -30,7 +30,7 @@ Nd4jStatus LogicExit::processNode(Graph *graph, Node *node) { auto __variableSpace = graph->variableSpace(); auto __flowPath = __variableSpace->flowPath(); - Context ctx(node->protoContext(), __variableSpace); + Context ctx(node->contextPrototype(), __variableSpace); auto input = ctx.variable(0)->getNDArray(); std::pair pair0(node->id(), 0); diff --git a/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp index b9bc803b8ec9..d3e191b0cdc2 100644 --- a/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp +++ b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp @@ -28,7 +28,7 @@ Nd4jStatus LogicLoopCond::processNode(Graph *graph, Node *node) { auto __variableSpace = graph->variableSpace(); auto __flowPath = __variableSpace->flowPath(); - Context ctx(node->protoContext(), __variableSpace); + Context ctx(node->contextPrototype(), __variableSpace); auto input = ctx.variable(0)->getNDArray(); std::pair pair0(node->id(), 0); diff --git a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp index 465141d7288a..038fab85ca40 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphExecutorTests.cpp @@ -76,11 +76,11 @@ TEST_F(GraphExecutorTests, test_basic_exec_2) { OptimizedGraph optimizedGraph; OpSequence sequence; - ASSERT_EQ(2, m.protoContext().inputs().size()); - ASSERT_EQ(2, a.protoContext().inputs().size()); + ASSERT_EQ(2, m.contextPrototype().inputs().size()); + ASSERT_EQ(2, a.contextPrototype().inputs().size()); - sequence.append(m, m.protoContext()); - sequence.append(a, a.protoContext()); + sequence.append(m, m.contextPrototype()); + sequence.append(a, a.contextPrototype()); optimizedGraph.append(sequence); diff --git a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp index 2abe7bbfbb83..183650a4dee0 100644 --- a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp @@ -57,7 +57,7 @@ class OpSequenceTests : public testing::Test { // int cnt = 1; // for (const auto &v : sequence) { -// ASSERT_EQ(cnt++, v.protoContext().nodeId()); +// ASSERT_EQ(cnt++, v.contextPrototype().nodeId()); // } // ASSERT_EQ(3, cnt); From dbb670cb7f770bf174fe634e96ed709f60c4fedc Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 9 Jun 2020 12:06:00 +0300 Subject: [PATCH 148/233] meh Signed-off-by: raver119@gmail.com --- .../include/graph/execution/impl/GraphExecutor.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 86dadc7db91d..1adcfb156793 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -63,16 +63,21 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, const OptimizedGraph &graph, VariableProxy &proxy, const int deviceId) const { + // we either follow or override target deviceId specified in OpSequence + auto targetDevice = deviceId >= 0 + ? deviceId + : seq.deviceId(); + /* * this is a basic implementation that works without dispatching etc */ + auto result = Status::OK(); for (int e = 0; e < seq.length(); e++) { - auto v = seq[e]; - // only Ops can be executed this way :( - Nd4jStatus result = Status::OK(); + auto &v = seq[e]; + // only Ops can be executed this way :( if (v.node().hasCustomOp()) - result = execute(v.node().customOp(), v.protoContext(), seq, graph, proxy, deviceId >= 0 ? deviceId : seq.deviceId()); + result = execute(v.node().customOp(), v.protoContext(), seq, graph, proxy, targetDevice); else { nd4j_printf("Node <%i:%s> has no customOp set\n", v.node().id(), From 7af9efcc7e64286c4ef622ba924e035ad0ce3c4e Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 9 Jun 2020 14:19:35 +0300 Subject: [PATCH 149/233] additional graph sort Signed-off-by: raver119@gmail.com --- .../ops/declarable/generic/tensor/range.cpp | 1 + .../layers_tests/GraphAnalysisTests.cpp | 35 +++++++++++++++++-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/tensor/range.cpp b/libnd4j/include/ops/declarable/generic/tensor/range.cpp index 1a6969031de9..1751ae1c1c16 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/range.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/range.cpp @@ -283,6 +283,7 @@ DECLARE_TYPES(range) { ->setAllowedInputTypes(sd::DataType::ANY) ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); } + } // namespace ops } // namespace sd diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index a55eae5b0b91..d88aa9f47203 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -892,11 +892,41 @@ TEST_F(GraphAnalysisTests, optimizedGraph_11) { ASSERT_EQ(8, sequence.at(0).protoContext().nodeId()); } +TEST_F(GraphAnalysisTests, optimizedGraph_12) { + Graph graph; + + graph.addVariable("start", NDArrayFactory::create(0)); + graph.addVariable("step", NDArrayFactory::create(1)); + + graph.addVariable("const_1", NDArrayFactory::create(0)); + graph.addVariable("const_2", NDArrayFactory::create(2)); + + // generating "stop" argument for Range op + graph.addNode(Node(sd::ops::add(), "add"), {"const_1", "const_2"}); + + // generating axis, should be equal to {0} + graph.addNode(Node(sd::ops::range(), "range_1"), {"start", "add", "step"}); + + graph.addNode(Node(sd::ops::range(), "range_2"), {"range_1", "add", "step"}); + + auto &optimized = graph.optimizedGraph(); + + graph.printOut(); + + // we expect exactly 1 layer + ASSERT_EQ(1, optimized.layers()); + auto layer = optimized.layer(0); + + // we expect exactly 1 OpSequence wihtin this layer + ASSERT_EQ(1, layer.width()); +} + TEST_F(GraphAnalysisTests, test_cond_1) { auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); const auto& optimized = graph.optimizedGraph(); - // graph.printOut(); + graph.printOut(); + graph.execute(); /* some infor that would be useful for implementation currently on optimization graph is passing next data @@ -925,7 +955,8 @@ TEST_F(GraphAnalysisTests, test_cond_1) { TEST_F(GraphAnalysisTests, test_cond_2) { auto graph = Graph::fromFlatBuffers("resources/cond_false.fb"); - // graph.printOut(); + graph.printOut(); + graph.execute(); } TEST_F(GraphAnalysisTests, test_while_iter_1_1) { From 6ec49664527a55ee093a3a4878bd30c49325b4c4 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 9 Jun 2020 14:41:18 +0300 Subject: [PATCH 150/233] Additional assertion Signed-off-by: raver119@gmail.com --- libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index d88aa9f47203..11e91c18860b 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -919,6 +919,12 @@ TEST_F(GraphAnalysisTests, optimizedGraph_12) { // we expect exactly 1 OpSequence wihtin this layer ASSERT_EQ(1, layer.width()); + auto seq = layer.at(0); + + // this Graph doesn't allow any variance here. Order must be exactly the same as below + ASSERT_EQ(std::string("add"), seq[0].node().name()); + ASSERT_EQ(std::string("range_1"), seq[1].node().name()); + ASSERT_EQ(std::string("range_2"), seq[2].node().name()); } TEST_F(GraphAnalysisTests, test_cond_1) { From 9b4dea677332d31b2b39792b810fa1fa1b25ade7 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 9 Jun 2020 15:30:25 +0300 Subject: [PATCH 151/233] few updates of minifier Signed-off-by: raver119@gmail.com --- libnd4j/minifier/minifier.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/libnd4j/minifier/minifier.cpp b/libnd4j/minifier/minifier.cpp index 9f7802d7b6b6..81c2ac0073fe 100644 --- a/libnd4j/minifier/minifier.cpp +++ b/libnd4j/minifier/minifier.cpp @@ -21,8 +21,9 @@ #else #include #endif + +#include #include -#include #include #include @@ -113,10 +114,11 @@ int main(int argc, char *argv[]) { // std::cout << "File " << file << " exists and can be read" << // std::endl; auto graph = Graph::fromFlatBuffers(file.c_str()); - auto ops = graph->getOperations(); + auto ops = graph.unmappedNodes(); for (auto &v : ops) { - descriptors.emplace_back(v); + if (v.second.hasCustomOp()) + descriptors.emplace_back(*v.second.customOp()->getOpDescriptor()); } } else { std::cerr << "File " << file << " exists, but has zero size" From facc2cf15e8b24fea88c77e466a89592ccaf19c7 Mon Sep 17 00:00:00 2001 From: Yurii Date: Tue, 9 Jun 2020 17:08:26 +0300 Subject: [PATCH 152/233] - merge layers, containing only one opSeq with len=1, into one layer Signed-off-by: Yurii --- libnd4j/include/graph/execution/OpSequence.h | 2 ++ .../graph/execution/impl/OpSequence.cpp | 8 +++++ libnd4j/include/graph/impl/OptimizedGraph.cpp | 35 ++++++++++++++++--- 3 files changed, 41 insertions(+), 4 deletions(-) diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index b625af8f3510..5ce0a5acc398 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -92,6 +92,8 @@ class SD_EXPORT OpSequence */ void append(const Node& node, const sd::graph::ContextPrototype& ctx); + void append(const ExecutionTask& task); + void append(ExecutionTask&& task); /** diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 5fac00733b5a..c114f5f0d7cc 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -84,6 +84,14 @@ void OpSequence::append(const Node &node, _ops.emplace_back(task); } +void OpSequence::append(const ExecutionTask& task) { + _ops.emplace_back(task); +} + +void OpSequence::append(ExecutionTask&& task) { + _ops.emplace_back(std::move(task)); +} + OpSequence::iterator OpSequence::begin() { return OpSequence::iterator(*this, 0); } diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index ae6f9a81eb72..8ed116f13148 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -112,9 +112,8 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS for (const auto& nextId : workMap[id]._out) visit(nextId, 1, numOfLayers); - - // fill _sortedGraph - _sortedGraph = std::vector(numOfLayers+1); + // fill vectors with layers + std::vector sortedGraphTemp(numOfLayers+1); for (const auto& p : workMap) { OpSequence seq; @@ -123,7 +122,35 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS for (const auto& id : p.second._opSeq) seq.append(inMap.at(id), inMap.at(id).contextPrototype()); - _sortedGraph[p.second._layerNum].append(std::move(seq)); + sortedGraphTemp[p.second._layerNum].append(std::move(seq)); + } + + // check whether there is layer with one OpSequence containing only one op + bool isLayerWithOneOp = false; + for(auto& layer : sortedGraphTemp) { + if(layer.width() == 1 && layer.at(0).length() == 1) { + isLayerWithOneOp = true; + break; + } + } + + // fill _sortedGraph + if(!isLayerWithOneOp) { + _sortedGraph = std::move(sortedGraphTemp); + } + else { + for (uint i = 0; i < sortedGraphTemp.size();) { + + OpSequence seq; + while(i < sortedGraphTemp.size() && sortedGraphTemp[i].width() == 1 && sortedGraphTemp[i].at(0).length() == 1) + seq.append(std::move(sortedGraphTemp[i++].at(0).at(0))); + + if(seq.length() != 0) + _sortedGraph.emplace_back(ExecutionLayer({seq})); + else + _sortedGraph.emplace_back(std::move(sortedGraphTemp[i++])); + } + } // sort _sortedGraph From 0add137572e890e03b48b0506327ce858d50c859 Mon Sep 17 00:00:00 2001 From: Yurii Date: Tue, 9 Jun 2020 17:18:59 +0300 Subject: [PATCH 153/233] - correct GraphAnalysisTests.optimizedGraph_7 test Signed-off-by: Yurii --- .../layers_tests/GraphAnalysisTests.cpp | 70 ++++--------------- 1 file changed, 15 insertions(+), 55 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 11e91c18860b..8b857bffc42c 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -540,63 +540,23 @@ TEST_F(GraphAnalysisTests, optimizedGraph_7) { ASSERT_EQ(5, graph.size()); const auto& optimized = graph.optimizedGraph(); - + // graph.printOut(); // we expect that OptimizedGraph has exactly 3 layer - ASSERT_EQ(5, optimized.numOfLayers()); - - // checking first layer first - auto layer0 = optimized.layer(0); - - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer0.width()); - // auto sequence = layer0[0]; - - // we expect that OpSequence has exactly 2 ops - ASSERT_EQ(1, layer0[0].length()); - ASSERT_EQ(4, layer0[0].at(0).protoContext().nodeId()); - - // checking second layer now - auto layer1 = optimized.layer(1); - - // we expect layer has exactly 2 OpSequences - ASSERT_EQ(1, layer1.width()); - // sequence = layer1[0]; - - ASSERT_EQ(1, layer1[0].length()); - ASSERT_EQ(5, layer1[0].at(0).protoContext().nodeId()); - - // checking layer 2 - auto layer2 = optimized.layer(2); - - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer2.width()); - // sequence = layer2[0]; - - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer2[0].length()); - ASSERT_EQ(6, layer2[0].at(0).protoContext().nodeId()); - - // checking layer 3 - auto layer3 = optimized.layer(3); - - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer3.width()); - // sequence = layer3[0]; + ASSERT_EQ(1, optimized.numOfLayers()); - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer3[0].length()); - ASSERT_EQ(7, layer3[0].at(0).protoContext().nodeId()); + auto layer = optimized.layer(0); - // checking layer 3 - auto layer4 = optimized.layer(4); + ASSERT_EQ(1, layer.width()); - // we expect layer has exactly 1 OpSequence - ASSERT_EQ(1, layer4.width()); - // sequence = layer4[0]; + auto seq = layer.at(0); + ASSERT_EQ(5, seq.length()); - // we expect that OpSequence has exactly 1 ops - ASSERT_EQ(1, layer4[0].length()); - ASSERT_EQ(8, layer4[0].at(0).protoContext().nodeId()); + // this Graph doesn't allow any variance here. Order must be exactly the same as below + ASSERT_EQ(std::string("a"), seq[0].node().name()); + ASSERT_EQ(std::string("b"), seq[1].node().name()); + ASSERT_EQ(std::string("c"), seq[2].node().name()); + ASSERT_EQ(std::string("d"), seq[3].node().name()); + ASSERT_EQ(std::string("e"), seq[4].node().name()); } TEST_F(GraphAnalysisTests, optimizedGraph_8) { @@ -911,7 +871,7 @@ TEST_F(GraphAnalysisTests, optimizedGraph_12) { auto &optimized = graph.optimizedGraph(); - graph.printOut(); + // graph.printOut(); // we expect exactly 1 layer ASSERT_EQ(1, optimized.layers()); @@ -931,7 +891,7 @@ TEST_F(GraphAnalysisTests, test_cond_1) { auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); const auto& optimized = graph.optimizedGraph(); - graph.printOut(); + // graph.printOut(); graph.execute(); /* some infor that would be useful for implementation @@ -961,7 +921,7 @@ TEST_F(GraphAnalysisTests, test_cond_1) { TEST_F(GraphAnalysisTests, test_cond_2) { auto graph = Graph::fromFlatBuffers("resources/cond_false.fb"); - graph.printOut(); + // graph.printOut(); graph.execute(); } From 8fbeec664cd5e1aab81319e85ddef1a4bc4b2444 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 9 Jun 2020 20:27:38 +0300 Subject: [PATCH 154/233] meh 3 Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/execution/GraphExecutor.h | 3 ++- libnd4j/include/graph/execution/impl/GraphExecutor.cpp | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/graph/execution/GraphExecutor.h b/libnd4j/include/graph/execution/GraphExecutor.h index ce54b07c3b19..34b90e7eff04 100644 --- a/libnd4j/include/graph/execution/GraphExecutor.h +++ b/libnd4j/include/graph/execution/GraphExecutor.h @@ -65,7 +65,8 @@ class SD_EXPORT GraphExecutor { * @return */ virtual Nd4jStatus execute(const OptimizedGraph &graph, - VariableProxy &proxy) const; + VariableProxy &proxy, + bool isInference = true) const; /** * This method executes OpSequence diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 1adcfb156793..789ee0a9e5e6 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -92,7 +92,8 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, } Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, - VariableProxy &proxy) const { + VariableProxy &proxy, + bool isInference) const { /* * this is a basic exection logic: roll through layers and sequences and * execute them one by one sequentially From 9b27736cbf601e0cd176f3af5d6ca7539ae57d1c Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 9 Jun 2020 21:31:20 +0300 Subject: [PATCH 155/233] one more test Signed-off-by: raver119@gmail.com --- .../layers_tests/GraphAnalysisTests.cpp | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 8b857bffc42c..71f11764e849 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -887,6 +887,75 @@ TEST_F(GraphAnalysisTests, optimizedGraph_12) { ASSERT_EQ(std::string("range_2"), seq[2].node().name()); } +TEST_F(GraphAnalysisTests, optimizedGraph_13) { + Graph graph; + + graph.addPlaceholder("input", sd::DataType::FLOAT32, {-1, 3, 244, 244}); + graph.addVariable("weights_1", NDArrayFactory::create(0.1f)); + graph.addVariable("weights_2", NDArrayFactory::create(0.1f)); + graph.addVariable("weights_3", NDArrayFactory::create(0.1f)); + graph.addVariable("weights_4", NDArrayFactory::create(0.1f)); + + graph.addVariable("axis", NDArrayFactory::create(1)); + + graph.addNode(Node(sd::ops::tanh(), "conv_1"), {"input", "weights_1"}); + graph.addNode(Node(sd::ops::tanh(), "pooling_1"), {"conv_1"}); + + // branch 1 + graph.addNode(Node(sd::ops::tanh(), "conv_2"), {"pooling_1", "weights_2"}); + graph.addNode(Node(sd::ops::tanh(), "pooling_2"), {"conv_2"}); + + // branch 2 + graph.addNode(Node(sd::ops::tanh(), "conv_3"), {"pooling_1", "weights_3"}); + graph.addNode(Node(sd::ops::tanh(), "pooling_3"), {"conv_3"}); + + // branch 3 + graph.addNode(Node(sd::ops::tanh(), "conv_4"), {"pooling_1", "weights_4"}); + graph.addNode(Node(sd::ops::tanh(), "pooling_4"), {"conv_4"}); + + // merge branch + graph.addNode(Node(sd::ops::concat(), "concat"), {"pooling_2", "pooling_3", "pooling_4", "axis"}); + + auto &optimized = graph.optimizedGraph(); + + // we expect exactly 2 layers + ASSERT_EQ(3, optimized.layers()); + auto layer = optimized.layer(0); + + // layer 0 must have exactly 1 sequence of 2 ops: conv_1 and pooling_1 + ASSERT_EQ(1, layer.width()); + auto seq = layer[0]; + + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("conv_1"), seq[0].node().name()); + ASSERT_EQ(std::string("pooling_1"), seq[1].node().name()); + + layer = optimized.layer(1); + ASSERT_EQ(3, layer.width()); + + seq = layer[0]; + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("conv_2"), seq[0].node().name()); + ASSERT_EQ(std::string("pooling_2"), seq[1].node().name()); + + seq = layer[1]; + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("conv_3"), seq[0].node().name()); + ASSERT_EQ(std::string("pooling_3"), seq[1].node().name()); + + seq = layer[2]; + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("conv_4"), seq[0].node().name()); + ASSERT_EQ(std::string("pooling_4"), seq[1].node().name()); + + layer = optimized.layer(2); + ASSERT_EQ(1, layer.width()); + + seq = layer[0]; + ASSERT_EQ(1, seq.length()); + ASSERT_EQ(std::string("concat"), seq[0].node().name()); +} + TEST_F(GraphAnalysisTests, test_cond_1) { auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); From 5531e88de00da48b8cd67b5a2b217151959bd7d2 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 9 Jun 2020 21:32:58 +0300 Subject: [PATCH 156/233] typo Signed-off-by: raver119@gmail.com --- libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 71f11764e849..0f8fecac4fd1 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -918,7 +918,7 @@ TEST_F(GraphAnalysisTests, optimizedGraph_13) { auto &optimized = graph.optimizedGraph(); - // we expect exactly 2 layers + // we expect exactly 3 layers ASSERT_EQ(3, optimized.layers()); auto layer = optimized.layer(0); From dad0e2739610b0979a904503021873dc7cd8d054 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 9 Jun 2020 21:58:47 +0300 Subject: [PATCH 157/233] meh Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/logic/LogicConditional.h | 6 +++--- libnd4j/include/graph/logic/LogicEnter.h | 6 +++--- libnd4j/include/graph/logic/LogicExecutor.h | 6 +++--- libnd4j/include/graph/logic/LogicExit.h | 4 ++-- libnd4j/include/graph/logic/LogicExpose.h | 2 +- libnd4j/include/graph/logic/LogicLoopCond.h | 6 +++--- libnd4j/include/graph/logic/LogicMerge.h | 6 +++--- libnd4j/include/graph/logic/LogicNextIteration.h | 6 +++--- libnd4j/include/graph/logic/LogicReturn.h | 6 +++--- libnd4j/include/graph/logic/LogicScope.h | 6 +++--- libnd4j/include/graph/logic/LogicSwitch.h | 6 +++--- libnd4j/include/graph/logic/LogicWhile.h | 6 +++--- 12 files changed, 33 insertions(+), 33 deletions(-) diff --git a/libnd4j/include/graph/logic/LogicConditional.h b/libnd4j/include/graph/logic/LogicConditional.h index d84aa584c76f..ce7ef82b5ad5 100644 --- a/libnd4j/include/graph/logic/LogicConditional.h +++ b/libnd4j/include/graph/logic/LogicConditional.h @@ -18,8 +18,8 @@ // Created by raver119 on 20.10.2017. // -#ifndef LIBND4J_LOGICCONDITIONAL_H -#define LIBND4J_LOGICCONDITIONAL_H +#ifndef SD_LOGICCONDITIONAL_H +#define SD_LOGICCONDITIONAL_H #include #include @@ -46,4 +46,4 @@ class LogicConditional { } // namespace graph } // namespace sd -#endif // LIBND4J_LOGICCONDITIONAL_H +#endif // SD_LOGICCONDITIONAL_H diff --git a/libnd4j/include/graph/logic/LogicEnter.h b/libnd4j/include/graph/logic/LogicEnter.h index f0ba6a439767..480217205c4c 100644 --- a/libnd4j/include/graph/logic/LogicEnter.h +++ b/libnd4j/include/graph/logic/LogicEnter.h @@ -18,8 +18,8 @@ // Created by raver119 on 30.01.18. // -#ifndef LIBND4J_LOGICENTER_H -#define LIBND4J_LOGICENTER_H +#ifndef SD_LOGICENTER_H +#define SD_LOGICENTER_H #include #include @@ -33,4 +33,4 @@ class LogicEnter { } // namespace graph } // namespace sd -#endif // LIBND4J_LOGICEXIT_H +#endif // SD_LOGICEXIT_H diff --git a/libnd4j/include/graph/logic/LogicExecutor.h b/libnd4j/include/graph/logic/LogicExecutor.h index 3dea3ec2e9b9..cbf0ed0b64e1 100644 --- a/libnd4j/include/graph/logic/LogicExecutor.h +++ b/libnd4j/include/graph/logic/LogicExecutor.h @@ -18,8 +18,8 @@ // Created by raver119 on 20.10.2017. // -#ifndef LIBND4J_LOGICEXECUTOR_H -#define LIBND4J_LOGICEXECUTOR_H +#ifndef SD_LOGICEXECUTOR_H +#define SD_LOGICEXECUTOR_H #include #include @@ -39,4 +39,4 @@ class LogicExecutor { } // namespace graph } // namespace sd -#endif // LIBND4J_LOGICEXECUTOR_H +#endif // SD_LOGICEXECUTOR_H diff --git a/libnd4j/include/graph/logic/LogicExit.h b/libnd4j/include/graph/logic/LogicExit.h index 216409c38251..7beebee1929d 100644 --- a/libnd4j/include/graph/logic/LogicExit.h +++ b/libnd4j/include/graph/logic/LogicExit.h @@ -18,8 +18,8 @@ // Created by raver119 on 30.01.18. // -#ifndef LIBND4J_LOGICEXIT_H -#define LIBND4J_LOGICEXIT_H +#ifndef SD_LOGICEXIT_H +#define SD_LOGICEXIT_H #include #include diff --git a/libnd4j/include/graph/logic/LogicExpose.h b/libnd4j/include/graph/logic/LogicExpose.h index 6e4bb5e1b937..fc3614b9cded 100644 --- a/libnd4j/include/graph/logic/LogicExpose.h +++ b/libnd4j/include/graph/logic/LogicExpose.h @@ -18,7 +18,7 @@ // Created by raver119 on 12.11.2017. // -#ifndef LIBND4J_LOGICEXPOSE_H +#ifndef SD_LOGICEXPOSE_H #define LIBND4J_LOGICEXPOSE_H #include diff --git a/libnd4j/include/graph/logic/LogicLoopCond.h b/libnd4j/include/graph/logic/LogicLoopCond.h index 670c0d07faee..53ff911f5d8c 100644 --- a/libnd4j/include/graph/logic/LogicLoopCond.h +++ b/libnd4j/include/graph/logic/LogicLoopCond.h @@ -18,8 +18,8 @@ // Created by raver119 on 30.01.18. // -#ifndef LIBND4J_LOGICLOOPCOND_H -#define LIBND4J_LOGICLOOPCOND_H +#ifndef SD_LOGICLOOPCOND_H +#define SD_LOGICLOOPCOND_H #include #include @@ -33,4 +33,4 @@ class LogicLoopCond { } // namespace graph } // namespace sd -#endif // LIBND4J_LOGICLOOPCOND_H +#endif // SD_LOGICLOOPCOND_H diff --git a/libnd4j/include/graph/logic/LogicMerge.h b/libnd4j/include/graph/logic/LogicMerge.h index 8bd8cbe7d5a1..a808e046e83c 100644 --- a/libnd4j/include/graph/logic/LogicMerge.h +++ b/libnd4j/include/graph/logic/LogicMerge.h @@ -18,8 +18,8 @@ // Created by raver119 on 30.01.18. // -#ifndef LIBND4J_LOGICMERGE_H -#define LIBND4J_LOGICMERGE_H +#ifndef SD_LOGICMERGE_H +#define SD_LOGICMERGE_H #include #include @@ -33,4 +33,4 @@ class LogicMerge { } // namespace graph } // namespace sd -#endif // LIBND4J_LOGICMERGE_H +#endif // SD_LOGICMERGE_H diff --git a/libnd4j/include/graph/logic/LogicNextIteration.h b/libnd4j/include/graph/logic/LogicNextIteration.h index 415b44f6dbcc..7a17645f93fe 100644 --- a/libnd4j/include/graph/logic/LogicNextIteration.h +++ b/libnd4j/include/graph/logic/LogicNextIteration.h @@ -18,8 +18,8 @@ // Created by raver119 on 30.01.18. // -#ifndef LIBND4J_LOGICNEXTITERATION_H -#define LIBND4J_LOGICNEXTITERATION_H +#ifndef SD_LOGICNEXTITERATION_H +#define SD_LOGICNEXTITERATION_H #include #include @@ -33,4 +33,4 @@ class LogicNextIeration { } // namespace graph } // namespace sd -#endif // LIBND4J_LOGICNEXTITERATION_H +#endif // SD_LOGICNEXTITERATION_H diff --git a/libnd4j/include/graph/logic/LogicReturn.h b/libnd4j/include/graph/logic/LogicReturn.h index 8c342b091417..b462eb4c6618 100644 --- a/libnd4j/include/graph/logic/LogicReturn.h +++ b/libnd4j/include/graph/logic/LogicReturn.h @@ -18,8 +18,8 @@ // Created by raver119 on 28.10.2017. // -#ifndef LIBND4J_LOGICRETURN_H -#define LIBND4J_LOGICRETURN_H +#ifndef SD_LOGICRETURN_H +#define SD_LOGICRETURN_H #include #include @@ -41,4 +41,4 @@ class LogicReturn { } // namespace graph } // namespace sd -#endif // LIBND4J_LOGICRETURN_H +#endif // SD_LOGICRETURN_H diff --git a/libnd4j/include/graph/logic/LogicScope.h b/libnd4j/include/graph/logic/LogicScope.h index 17d13d83ac91..bd197b1a0dc8 100644 --- a/libnd4j/include/graph/logic/LogicScope.h +++ b/libnd4j/include/graph/logic/LogicScope.h @@ -18,8 +18,8 @@ // Created by raver119 on 20.10.2017. // -#ifndef LIBND4J_LOGICSCOPE_H -#define LIBND4J_LOGICSCOPE_H +#ifndef SD_LOGICSCOPE_H +#define SD_LOGICSCOPE_H #include #include @@ -41,4 +41,4 @@ class LogicScope { } // namespace graph } // namespace sd -#endif // LIBND4J_LOGICSCOPE_H +#endif // SD_LOGICSCOPE_H diff --git a/libnd4j/include/graph/logic/LogicSwitch.h b/libnd4j/include/graph/logic/LogicSwitch.h index d74ce87e4908..9d659cf4a91b 100644 --- a/libnd4j/include/graph/logic/LogicSwitch.h +++ b/libnd4j/include/graph/logic/LogicSwitch.h @@ -18,8 +18,8 @@ // Created by raver119 on 21.10.17. // -#ifndef LIBND4J_LOGICSWITCH_H -#define LIBND4J_LOGICSWITCH_H +#ifndef SD_LOGICSWITCH_H +#define SD_LOGICSWITCH_H #include #include @@ -41,4 +41,4 @@ class LogicSwitch { } // namespace graph } // namespace sd -#endif // LIBND4J_LOGICSWITCH_H +#endif // SD_LOGICSWITCH_H diff --git a/libnd4j/include/graph/logic/LogicWhile.h b/libnd4j/include/graph/logic/LogicWhile.h index e80d742cbf41..d5437c5eae8e 100644 --- a/libnd4j/include/graph/logic/LogicWhile.h +++ b/libnd4j/include/graph/logic/LogicWhile.h @@ -18,8 +18,8 @@ // Created by raver119 on 20.10.2017. // -#ifndef LIBND4J_LOGICWHILE_H -#define LIBND4J_LOGICWHILE_H +#ifndef SD_LOGICWHILE_H +#define SD_LOGICWHILE_H #include #include @@ -41,4 +41,4 @@ class LogicWhile { } // namespace graph } // namespace sd -#endif // LIBND4J_LOGICWHILE_H +#endif // SD_LOGICWHILE_H From ca370272a00bc0c550925060280ca25e27b7bc51 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Wed, 10 Jun 2020 09:23:27 +0300 Subject: [PATCH 158/233] few tests Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/VariableProxy.h | 12 ++++ libnd4j/include/graph/execution/StackFrame.h | 42 ++++++++++++++ .../graph/execution/impl/StackFrame.cpp | 24 ++++++++ libnd4j/include/graph/impl/VariableProxy.cpp | 8 +++ .../layers_tests/VariableProxyTests.cpp | 56 +++++++++++++++++++ 5 files changed, 142 insertions(+) create mode 100644 libnd4j/include/graph/execution/StackFrame.h create mode 100644 libnd4j/include/graph/execution/impl/StackFrame.cpp diff --git a/libnd4j/include/graph/VariableProxy.h b/libnd4j/include/graph/VariableProxy.h index 63eabf15e460..430e825a78fb 100644 --- a/libnd4j/include/graph/VariableProxy.h +++ b/libnd4j/include/graph/VariableProxy.h @@ -92,6 +92,18 @@ class SD_EXPORT VariableProxy : public VariableSpace { virtual int totalEntries() const override; virtual Stash* stash() const override; + + /** + * This method updates this proxy with entries from a given VariableProxy + * @param proxy + */ + void pullFrom(const VariableProxy &proxy); + + /** + * This method will update given VariableProxy with values from this proxy + * @param proxy + */ + void pushTo(VariableProxy &proxy) const; }; } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/execution/StackFrame.h b/libnd4j/include/graph/execution/StackFrame.h new file mode 100644 index 000000000000..574ea9b95e45 --- /dev/null +++ b/libnd4j/include/graph/execution/StackFrame.h @@ -0,0 +1,42 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_STACKFRAME_H_ +#define SD_STACKFRAME_H_ + +#include +#include + +namespace sd { +namespace graph { + +class SD_EXPORT StackFrame { + private: + VariableProxy _proxy; + + public: + explicit StackFrame(VariableProxy &proxy); + ~StackFrame() = default; +}; + +} // namespace graph +} // namespace sd + +#endif // SD_STACKFRAME_H_ diff --git a/libnd4j/include/graph/execution/impl/StackFrame.cpp b/libnd4j/include/graph/execution/impl/StackFrame.cpp new file mode 100644 index 000000000000..d96f48c5006f --- /dev/null +++ b/libnd4j/include/graph/execution/impl/StackFrame.cpp @@ -0,0 +1,24 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + + +#include + +sd::graph::StackFrame::StackFrame(sd::graph::VariableProxy &proxy) : _proxy(proxy) { } diff --git a/libnd4j/include/graph/impl/VariableProxy.cpp b/libnd4j/include/graph/impl/VariableProxy.cpp index d511dc9e4140..03e3f4e54968 100644 --- a/libnd4j/include/graph/impl/VariableProxy.cpp +++ b/libnd4j/include/graph/impl/VariableProxy.cpp @@ -218,5 +218,13 @@ VariableSpace &VariableProxy::operator=(const VariableSpace &other) { return *this; } + +void VariableProxy::pullFrom(const VariableProxy &proxy) { + +} +void VariableProxy::pushTo(VariableProxy &proxy) const { + +} + } // namespace graph } // namespace sd diff --git a/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp b/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp index 2f32ac16cfdc..fb4c0711bbeb 100644 --- a/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/VariableProxyTests.cpp @@ -143,4 +143,60 @@ TEST_F(VariableProxyTests, Test_Cast_1) { ASSERT_FALSE(z0 == z1); ASSERT_TRUE(y == *z1); ASSERT_TRUE(x == *z0); +} + +TEST_F(VariableProxyTests, test_update_1) { + VariableSpace ref; + + auto x = NDArrayFactory::create(1); + auto A = NDArrayFactory::create(2); + auto B = NDArrayFactory::create(3); + + // set initial states for all 3 VariableSpaces/Proxies + ref.putVariable(2, x); + + VariableProxy proxyA(&ref); + proxyA.putVariable(2, A); + + VariableProxy proxyB(&proxyA); + proxyB.putVariable(2, B); + + // check out initial state + ASSERT_EQ(x, *ref.getVariable(2)->getNDArray().get()); + ASSERT_EQ(A, *proxyA.getVariable(2)->getNDArray().get()); + ASSERT_EQ(B, *proxyB.getVariable(2)->getNDArray().get()); + + // update state now and check result + proxyB.pushTo(proxyA); + ASSERT_EQ(x, *ref.getVariable(2)->getNDArray().get()); + ASSERT_EQ(B, *proxyA.getVariable(2)->getNDArray().get()); + ASSERT_EQ(B, *proxyB.getVariable(2)->getNDArray().get()); +} + +TEST_F(VariableProxyTests, test_update_2) { + VariableSpace ref; + + auto x = NDArrayFactory::create(1); + auto A = NDArrayFactory::create(2); + auto B = NDArrayFactory::create(3); + + // set initial states for all 3 VariableSpaces/Proxies + ref.putVariable(2, x); + + VariableProxy proxyA(&ref); + proxyA.putVariable(2, A); + + VariableProxy proxyB(&proxyA); + proxyB.putVariable(2, B); + + // check out initial state + ASSERT_EQ(x, *ref.getVariable(2)->getNDArray().get()); + ASSERT_EQ(A, *proxyA.getVariable(2)->getNDArray().get()); + ASSERT_EQ(B, *proxyB.getVariable(2)->getNDArray().get()); + + // update state now and check result + proxyB.pullFrom(proxyA); + ASSERT_EQ(x, *ref.getVariable(2)->getNDArray().get()); + ASSERT_EQ(A, *proxyA.getVariable(2)->getNDArray().get()); + ASSERT_EQ(A, *proxyB.getVariable(2)->getNDArray().get()); } \ No newline at end of file From 769c08fd6d6fe5814b8ef43ee5c20ea53da0cac7 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Wed, 10 Jun 2020 10:15:54 +0300 Subject: [PATCH 159/233] VariableProxy update methods Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/VariableProxy.h | 3 +- libnd4j/include/graph/VariableSpace.h | 19 ++---- libnd4j/include/graph/impl/VariableProxy.cpp | 71 +++++++++++--------- libnd4j/include/graph/impl/VariableSpace.cpp | 22 +++--- 4 files changed, 58 insertions(+), 57 deletions(-) diff --git a/libnd4j/include/graph/VariableProxy.h b/libnd4j/include/graph/VariableProxy.h index 430e825a78fb..0b1c707a3272 100644 --- a/libnd4j/include/graph/VariableProxy.h +++ b/libnd4j/include/graph/VariableProxy.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -28,7 +29,7 @@ namespace graph { class SD_EXPORT VariableProxy : public VariableSpace { protected: const VariableSpace* _backed; - VariableSpace* _current = nullptr; + VariableSpace _current; public: explicit VariableProxy(const VariableSpace* reference); diff --git a/libnd4j/include/graph/VariableSpace.h b/libnd4j/include/graph/VariableSpace.h index 6a9a9a0b5e90..82b24f52bb26 100644 --- a/libnd4j/include/graph/VariableSpace.h +++ b/libnd4j/include/graph/VariableSpace.h @@ -97,19 +97,14 @@ class SD_EXPORT VariableSpace { virtual std::vector> variables() const; - virtual std::shared_ptr putVariable(const std::pair &pair, - const NDArray &array); + virtual std::shared_ptr putVariable(const std::pair &pair, const NDArray &array); virtual std::shared_ptr putVariable(int id, const NDArray &array); - virtual std::shared_ptr putVariable( - int id, int idx, const std::shared_ptr &array); - virtual std::shared_ptr putVariable(int id, int idx, - const NDArray &array); - virtual std::shared_ptr putVariable(const std::string &name, int id, - int idx, const NDArray &array); - virtual void putVariable(const std::string &name, int id, int idx, - const std::shared_ptr &variable); - virtual void putVariable(const std::pair &pair, - const std::shared_ptr &variable); + virtual std::shared_ptr putVariable(int id, int idx, const std::shared_ptr &array); + virtual std::shared_ptr putVariable(int id, int idx, const NDArray &array); + virtual std::shared_ptr putVariable(const std::string &name, int id, int idx, const NDArray &array); + + virtual void putVariable(const std::string &name, int id, int idx, const std::shared_ptr &variable); + virtual void putVariable(const std::pair &pair, const std::shared_ptr &variable); virtual void putVariable(int id, const std::shared_ptr &variable); virtual void dropVariable(const std::string &pair); diff --git a/libnd4j/include/graph/impl/VariableProxy.cpp b/libnd4j/include/graph/impl/VariableProxy.cpp index 03e3f4e54968..c67d30633035 100644 --- a/libnd4j/include/graph/impl/VariableProxy.cpp +++ b/libnd4j/include/graph/impl/VariableProxy.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -28,10 +29,9 @@ VariableProxy::VariableProxy(const VariableSpace *ref) { if (ref == nullptr) _backed = new VariableSpace(); _backed = ref; - _current = new VariableSpace(); } -VariableProxy::~VariableProxy() { delete _current; } +VariableProxy::~VariableProxy() { } int VariableProxy::numberOfPlaceholders() const { return _backed->numberOfPlaceholders(); @@ -55,15 +55,15 @@ bool VariableProxy::hasExternalVariable(const std::string &symbol) const { } bool VariableProxy::hasVariable(int id) const { - return _current->hasVariable(id) || _backed->hasVariable(id); + return _current.hasVariable(id) || _backed->hasVariable(id); } bool VariableProxy::hasVariable(int id, int idx) const { - return _current->hasVariable(id, idx) || _backed->hasVariable(id, idx); + return _current.hasVariable(id, idx) || _backed->hasVariable(id, idx); } bool VariableProxy::hasVariable(const std::pair &pair) const { - return _current->hasVariable(pair) || _backed->hasVariable(pair); + return _current.hasVariable(pair) || _backed->hasVariable(pair); } void VariableProxy::dropVariable(const std::pair &pair) { @@ -71,16 +71,16 @@ void VariableProxy::dropVariable(const std::pair &pair) { } void VariableProxy::dropVariable(int id, int idx) { - assert(_current->hasVariable(id, idx)); + assert(_current.hasVariable(id, idx)); - _current->dropVariable(id, idx); + _current.dropVariable(id, idx); } std::vector> VariableProxy::variables() const { std::vector> result; auto b = _backed->variables(); - auto c = _current->variables(); + auto c = _current.variables(); for (auto v : b) result.emplace_back(v); @@ -90,11 +90,11 @@ std::vector> VariableProxy::variables() const { } bool VariableProxy::hasVariable(const std::string &symbol) const { - return _current->hasVariable(symbol) || _backed->hasVariable(symbol); + return _current.hasVariable(symbol) || _backed->hasVariable(symbol); } std::shared_ptr VariableProxy::getVariable(int id) const { - if (_current->hasVariable(id)) return _current->getVariable(id); + if (_current.hasVariable(id)) return _current.getVariable(id); if (_backed->hasVariable(id)) return _backed->getVariable(id); @@ -103,7 +103,7 @@ std::shared_ptr VariableProxy::getVariable(int id) const { } std::shared_ptr VariableProxy::getVariable(int id, int idx) const { - if (_current->hasVariable(id, idx)) return _current->getVariable(id, idx); + if (_current.hasVariable(id, idx)) return _current.getVariable(id, idx); if (_backed->hasVariable(id, idx)) return _backed->getVariable(id, idx); @@ -113,7 +113,7 @@ std::shared_ptr VariableProxy::getVariable(int id, int idx) const { std::shared_ptr VariableProxy::getVariable( const std::pair &pair) const { - if (_current->hasVariable(pair)) return _current->getVariable(pair); + if (_current.hasVariable(pair)) return _current.getVariable(pair); if (_backed->hasVariable(pair)) return _backed->getVariable(pair); @@ -124,7 +124,7 @@ std::shared_ptr VariableProxy::getVariable( std::shared_ptr VariableProxy::getVariable( const std::string &symbol) const { - if (_current->hasVariable(symbol)) return _current->getVariable(symbol); + if (_current.hasVariable(symbol)) return _current.getVariable(symbol); if (_backed->hasVariable(symbol)) return _backed->getVariable(symbol); @@ -138,77 +138,77 @@ void VariableProxy::replaceVariable(std::shared_ptr variable) { if (_backed->hasVariable(variable->getName())) { auto origVar = _backed->getVariable(variable->getName()); variable->setId(origVar->id(), origVar->index()); - _current->replaceVariable(variable); + _current.replaceVariable(variable); } else - _current->replaceVariable(variable); + _current.replaceVariable(variable); } else // if proxy has variable - that's one story - _current->replaceVariable(variable); + _current.replaceVariable(variable); } std::shared_ptr VariableProxy::putVariable(const std::string &name, int id, int idx, const NDArray &array) { - return _current->putVariable(name, id, idx, array); + return _current.putVariable(name, id, idx, array); } void VariableProxy::putOutputVariable(std::shared_ptr variable) { - _current->putOutputVariable(variable); + _current.putOutputVariable(variable); } std::shared_ptr VariableProxy::putVariable( const std::pair &pair, const NDArray &array) { - return _current->putVariable(pair, array); + return _current.putVariable(pair, array); } void VariableProxy::putVariable(const std::pair &pair, const std::shared_ptr &variable) { - _current->putVariable(pair, variable); + _current.putVariable(pair, variable); } void VariableProxy::putVariable(int id, const std::shared_ptr &variable) { - _current->putVariable(id, variable); + _current.putVariable(id, variable); } std::shared_ptr VariableProxy::putVariable(int id, const NDArray &array) { - return _current->putVariable(id, array); + return _current.putVariable(id, array); } std::shared_ptr VariableProxy::putVariable(int id, int idx, const NDArray &array) { - return _current->putVariable(id, idx, array); + return _current.putVariable(id, idx, array); } void VariableProxy::putVariable(const std::string &name, int id, int idx, const std::shared_ptr &array) { - _current->putVariable(name, id, idx, array); + _current.putVariable(name, id, idx, array); } -Stash *VariableProxy::stash() const { return _current->stash(); } +Stash *VariableProxy::stash() const { return _current.stash(); } Nd4jLong VariableProxy::externalMemory() const { - return _backed->externalMemory() + _current->externalMemory(); + return _backed->externalMemory() + _current.externalMemory(); } Nd4jLong VariableProxy::internalMemory() const { - return _backed->internalMemory() + _current->internalMemory(); + return _backed->internalMemory() + _current.internalMemory(); } Nd4jLong VariableProxy::totalMemory() const { - return _backed->totalMemory() + _current->totalMemory(); + return _backed->totalMemory() + _current.totalMemory(); } int VariableProxy::externalEntries() const { - return _backed->externalEntries() + _current->externalEntries(); + return _backed->externalEntries() + _current.externalEntries(); } int VariableProxy::internalEntries() const { - return _backed->internalEntries() + _current->internalEntries(); + return _backed->internalEntries() + _current.internalEntries(); } int VariableProxy::totalEntries() const { - return _backed->totalEntries() + _current->totalEntries(); + return _backed->totalEntries() + _current.totalEntries(); } VariableSpace &VariableProxy::operator=(const VariableSpace &other) { @@ -220,10 +220,15 @@ VariableSpace &VariableProxy::operator=(const VariableSpace &other) { } void VariableProxy::pullFrom(const VariableProxy &proxy) { - + for (const auto &v:proxy._current.variables()) { + _current.replaceVariable(v); + } } -void VariableProxy::pushTo(VariableProxy &proxy) const { +void VariableProxy::pushTo(VariableProxy &proxy) const { + for (const auto &v:_current.variables()) { + proxy._current.replaceVariable(v); + } } } // namespace graph diff --git a/libnd4j/include/graph/impl/VariableSpace.cpp b/libnd4j/include/graph/impl/VariableSpace.cpp index 0e4307e69097..72443ccb19bd 100644 --- a/libnd4j/include/graph/impl/VariableSpace.cpp +++ b/libnd4j/include/graph/impl/VariableSpace.cpp @@ -339,35 +339,35 @@ VariableSpace &VariableSpace::operator=(const VariableSpace &other) { void VariableSpace::replaceVariable(std::shared_ptr variable) { bool replaced = false; - // trying name first + // trying name lookup first if (!variable->getName().empty()) { - nd4j_printf("Trying to replace variable by name: [%s]\n", - variable->getName().c_str()); if (hasVariable(variable->getName())) { - nd4j_printf("Replacing by name: [%s]\n", variable->getName().c_str()); auto vs = getVariable(variable->getName()); dropVariable(vs->id(), vs->index()); putVariable({vs->id(), vs->index()}, variable); - // delete vs; + + // if we're on zero index, we also must update index-less reference + if (vs->index() == 0) + _variables[vs->id()] = variable; + replaced = true; } } else { - nd4j_printf("Trying to replace variable by id: [%i:%i]\n", variable->id(), - variable->index()); if (hasVariable(variable->id(), variable->index())) { - nd4j_printf("Replacing by id: [%i:%i]\n", variable->id(), - variable->index()); auto vs = getVariable(variable->id(), variable->index()); dropVariable(variable->id(), variable->index()); putVariable({vs->id(), vs->index()}, variable); - // delete vs; + + // if we're on zero index, we also must update index-less reference + if (vs->index() == 0) + _variables[vs->id()] = variable; + replaced = true; } } if (!replaced) { - nd4j_printf("wasn't able to replace variable, putting\n", ""); putVariable({variable->id(), variable->index()}, variable); } } From 3a4bec1d1935c0bbcb5409968d19d25cbead3a45 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Wed, 10 Jun 2020 11:07:15 +0300 Subject: [PATCH 160/233] VariableSpace::dropVariable() implemented properly Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/impl/VariableSpace.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/libnd4j/include/graph/impl/VariableSpace.cpp b/libnd4j/include/graph/impl/VariableSpace.cpp index 72443ccb19bd..382710d9b5e6 100644 --- a/libnd4j/include/graph/impl/VariableSpace.cpp +++ b/libnd4j/include/graph/impl/VariableSpace.cpp @@ -376,7 +376,19 @@ void VariableSpace::dropVariable(const std::pair &pair) { dropVariable(pair.first, pair.second); } -void VariableSpace::dropVariable(int id, int idx) {} +void VariableSpace::dropVariable(int id, int idx) { + if (hasVariable(id, idx)) { + // we must check if string name is defined for this Variable + auto v = getVariable(id, idx); + if (!v->name().empty()) + _symbolic.erase(v->name()); + + _paired.erase(std::pair{id, idx}); + } + + if (idx == 0 && hasVariable(id)) + _variables.erase(id); +} VariableSpace::VariableSpace() {} } // namespace graph From 3ae5e456bda39f1c21b9baed2f75fd2ba7002977 Mon Sep 17 00:00:00 2001 From: Yurii Date: Wed, 10 Jun 2020 16:19:07 +0300 Subject: [PATCH 161/233] - check tests for graphs containing conditions Signed-off-by: Yurii --- .../layers_tests/GraphAnalysisTests.cpp | 66 ++++++++++++++++++- 1 file changed, 63 insertions(+), 3 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 0f8fecac4fd1..502c2e88a7f7 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -957,11 +957,41 @@ TEST_F(GraphAnalysisTests, optimizedGraph_13) { } TEST_F(GraphAnalysisTests, test_cond_1) { - auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); + auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); const auto& optimized = graph.optimizedGraph(); // graph.printOut(); - graph.execute(); + + // we expect exactly 3 layers + ASSERT_EQ(3, optimized.layers()); + + // layer 0 + auto layer = optimized.layer(0); + ASSERT_EQ(1, layer.width()); + auto seq = layer[0]; + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("in_0/read"), seq[0].node().name()); + ASSERT_EQ(std::string("cond/Switch"), seq[1].node().name()); + + // layer 1 + layer = optimized.layer(1); + ASSERT_EQ(2, layer.width()); + seq = layer[0]; + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("cond/switch_t"), seq[0].node().name()); + ASSERT_EQ(std::string("cond/LinSpace"), seq[1].node().name()); + seq = layer[1]; + ASSERT_EQ(1, seq.length()); + ASSERT_EQ(std::string("cond/switch_f"), seq[0].node().name()); + + // layer 2 + layer = optimized.layer(2); + ASSERT_EQ(1, layer.width()); + seq = layer[0]; + ASSERT_EQ(1, seq.length()); + ASSERT_EQ(std::string("cond/Merge"), seq[0].node().name()); + + // graph.execute(); /* some infor that would be useful for implementation currently on optimization graph is passing next data @@ -989,9 +1019,39 @@ TEST_F(GraphAnalysisTests, test_cond_1) { } TEST_F(GraphAnalysisTests, test_cond_2) { + auto graph = Graph::fromFlatBuffers("resources/cond_false.fb"); + const auto& optimized = graph.optimizedGraph(); // graph.printOut(); - graph.execute(); + + ASSERT_EQ(3, optimized.layers()); + + // layer 0 + auto layer = optimized.layer(0); + ASSERT_EQ(1, layer.width()); + auto seq = layer[0]; + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("in_0/read"), seq[0].node().name()); + ASSERT_EQ(std::string("cond/Switch"), seq[1].node().name()); + + // layer 1 + layer = optimized.layer(1); + ASSERT_EQ(2, layer.width()); + seq = layer[0]; + ASSERT_EQ(2, seq.length()); + ASSERT_EQ(std::string("cond/switch_t"), seq[0].node().name()); + ASSERT_EQ(std::string("cond/LinSpace"), seq[1].node().name()); + seq = layer[1]; + ASSERT_EQ(1, seq.length()); + ASSERT_EQ(std::string("cond/switch_f"), seq[0].node().name()); + + // layer 2 + layer = optimized.layer(2); + ASSERT_EQ(1, layer.width()); + seq = layer[0]; + ASSERT_EQ(1, seq.length()); + ASSERT_EQ(std::string("cond/Merge"), seq[0].node().name()); + } TEST_F(GraphAnalysisTests, test_while_iter_1_1) { From 6e07438efe9d6f77b4a9600bf63de0874ef07ed0 Mon Sep 17 00:00:00 2001 From: Yurii Date: Wed, 10 Jun 2020 17:33:42 +0300 Subject: [PATCH 162/233] - fill outputs of nodes in original map Signed-off-by: Yurii --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 24 ++++++++++++------- .../layers_tests/GraphAnalysisTests.cpp | 8 +++---- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 8ed116f13148..1ddf874c411b 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -60,16 +60,24 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS std::vector startNodes; for (const auto& p : inMap) { - for (const auto& v : p.second.input()) { - if (v.first >= inMap.begin()->first) { // is op - workMap[p.first]._in.push_back(v.first); - workMap[v.first]._out.push_back(p.first); + const auto& inputs = p.second.input(); + + for (int i = 0; i < inputs.size(); ++i) { + + if (inputs[i].first >= inMap.begin()->first) { // is op + workMap[p.first]._in.push_back(inputs[i].first); + workMap[inputs[i].first]._out.push_back(p.first); + const_cast&>(inMap)[inputs[i].first].pickOutput(p.first, i); } else { // is variable - for (const auto& i : varSpace.getVariable(v.first).get()->dependencies()) { - if(std::find(workMap[p.first]._in.begin(), workMap[p.first]._in.end(), i.first) == workMap[p.first]._in.end()) { - workMap[p.first]._in.push_back(i.first); - workMap[i.first]._out.push_back(p.first); + + const auto depends = varSpace.getVariable(inputs[i].first).get()->dependencies(); + + for (int j = 0; j < depends.size(); ++j) { + if(std::find(workMap[p.first]._in.begin(), workMap[p.first]._in.end(), depends[j].first) == workMap[p.first]._in.end()) { + workMap[p.first]._in.push_back(depends[j].first); + workMap[depends[j].first]._out.push_back(p.first); + const_cast&>(inMap)[depends[j].first].pickOutput(p.first, j); } } } diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 502c2e88a7f7..db3269ef7cfc 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -956,7 +956,7 @@ TEST_F(GraphAnalysisTests, optimizedGraph_13) { ASSERT_EQ(std::string("concat"), seq[0].node().name()); } -TEST_F(GraphAnalysisTests, test_cond_1) { +TEST_F(GraphAnalysisTests, optimizedGraph_cond1) { auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); const auto& optimized = graph.optimizedGraph(); @@ -1018,7 +1018,7 @@ TEST_F(GraphAnalysisTests, test_cond_1) { */ } -TEST_F(GraphAnalysisTests, test_cond_2) { +TEST_F(GraphAnalysisTests, optimizedGraph_cond2) { auto graph = Graph::fromFlatBuffers("resources/cond_false.fb"); const auto& optimized = graph.optimizedGraph(); @@ -1054,7 +1054,7 @@ TEST_F(GraphAnalysisTests, test_cond_2) { } -TEST_F(GraphAnalysisTests, test_while_iter_1_1) { +TEST_F(GraphAnalysisTests, optimizedGraph_while1) { auto graph = Graph::fromFlatBuffers("resources/while_iter1.fb"); - //graph.printOut(); + graph.printOut(); } From a0ee500786cbd5e13b6b8d7eaf938eebb2ade28e Mon Sep 17 00:00:00 2001 From: Yurii Date: Wed, 10 Jun 2020 17:52:35 +0300 Subject: [PATCH 163/233] - reproduce crash Signed-off-by: Yurii --- libnd4j/include/graph/OptimizedGraph.h | 2 +- libnd4j/include/graph/impl/OptimizedGraph.cpp | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index 1fbb0cafdc59..dcb2649d0fbc 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -39,7 +39,7 @@ class SD_EXPORT OptimizedGraph { // const Graph& _originalGraph; public: - OptimizedGraph(const MAP_IMPL& map, const VariableSpace& varSpace); + OptimizedGraph(MAP_IMPL map, const VariableSpace& varSpace); // move constructor OptimizedGraph(OptimizedGraph&& other) noexcept; // default constructor diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 1ddf874c411b..a28e28d9be1f 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -45,7 +45,7 @@ OptimizedGraph& OptimizedGraph::operator=(OptimizedGraph &&other) noexcept { } /////////////////////////////////////////////////////////////////// -OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableSpace& varSpace) { +OptimizedGraph::OptimizedGraph(MAP_IMPL inMap, const VariableSpace& varSpace) { struct NodeInfo { uint _layerNum = 0; @@ -67,7 +67,7 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS if (inputs[i].first >= inMap.begin()->first) { // is op workMap[p.first]._in.push_back(inputs[i].first); workMap[inputs[i].first]._out.push_back(p.first); - const_cast&>(inMap)[inputs[i].first].pickOutput(p.first, i); + inMap[inputs[i].first].pickOutput(p.first, i); } else { // is variable @@ -77,7 +77,7 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS if(std::find(workMap[p.first]._in.begin(), workMap[p.first]._in.end(), depends[j].first) == workMap[p.first]._in.end()) { workMap[p.first]._in.push_back(depends[j].first); workMap[depends[j].first]._out.push_back(p.first); - const_cast&>(inMap)[depends[j].first].pickOutput(p.first, j); + inMap[depends[j].first].pickOutput(p.first, j); } } } From 6f9bc96c2c5039c914065211323d209e9c15f801 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Wed, 10 Jun 2020 18:28:06 +0300 Subject: [PATCH 164/233] 3 new tests Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/ContextPrototype.h | 2 +- .../include/graph/execution/GraphExecutor.h | 4 +- libnd4j/include/graph/execution/StackFrame.h | 2 + .../graph/execution/impl/GraphExecutor.cpp | 21 +++++-- .../include/graph/impl/ContextPrototype.cpp | 2 +- libnd4j/include/graph/impl/Node.cpp | 6 +- libnd4j/include/graph/impl/VariableSpace.cpp | 4 +- libnd4j/tests_cpu/layers_tests/NodeTests.cpp | 55 ++++++++++++++++++- 8 files changed, 81 insertions(+), 15 deletions(-) diff --git a/libnd4j/include/graph/ContextPrototype.h b/libnd4j/include/graph/ContextPrototype.h index 4bce3b2dfd9b..e780c4aa93a0 100644 --- a/libnd4j/include/graph/ContextPrototype.h +++ b/libnd4j/include/graph/ContextPrototype.h @@ -132,7 +132,7 @@ class SD_EXPORT ContextPrototype { bool isUseMKLDNN() const { return _useMKLDNN; } void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN = useMKLDNN; } - std::string name() const; + const std::string& name() const; void setName(const std::string& name); /** diff --git a/libnd4j/include/graph/execution/GraphExecutor.h b/libnd4j/include/graph/execution/GraphExecutor.h index 34b90e7eff04..461267262a70 100644 --- a/libnd4j/include/graph/execution/GraphExecutor.h +++ b/libnd4j/include/graph/execution/GraphExecutor.h @@ -24,6 +24,8 @@ #include #include #include +#include +#include #include namespace sd { @@ -77,7 +79,7 @@ class SD_EXPORT GraphExecutor { */ virtual Nd4jStatus execute(const OpSequence &seq, const OptimizedGraph &graph, - VariableProxy &proxy, + std::deque &stackFrames, int deviceId) const; /** diff --git a/libnd4j/include/graph/execution/StackFrame.h b/libnd4j/include/graph/execution/StackFrame.h index 574ea9b95e45..dc036a636acb 100644 --- a/libnd4j/include/graph/execution/StackFrame.h +++ b/libnd4j/include/graph/execution/StackFrame.h @@ -34,6 +34,8 @@ class SD_EXPORT StackFrame { public: explicit StackFrame(VariableProxy &proxy); ~StackFrame() = default; + + const VariableProxy& variableProxy() const { return _proxy; } }; } // namespace graph diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 789ee0a9e5e6..1c124f0d98c6 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -21,6 +21,7 @@ #include #include + namespace sd { namespace graph { Context GraphExecutor::prepareContext( @@ -61,7 +62,7 @@ Nd4jStatus GraphExecutor::execute( Nd4jStatus GraphExecutor::execute(const OpSequence &seq, const OptimizedGraph &graph, - VariableProxy &proxy, + std::deque &stackFrames, const int deviceId) const { // we either follow or override target deviceId specified in OpSequence auto targetDevice = deviceId >= 0 @@ -74,10 +75,11 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, auto result = Status::OK(); for (int e = 0; e < seq.length(); e++) { auto &v = seq[e]; + auto &p = stackFrames.back().variableProxy(); // only Ops can be executed this way :( if (v.node().hasCustomOp()) - result = execute(v.node().customOp(), v.protoContext(), seq, graph, proxy, targetDevice); + result = execute(v.node().customOp(), v.protoContext(), seq, graph, const_cast(p), targetDevice); else { nd4j_printf("Node <%i:%s> has no customOp set\n", v.node().id(), @@ -98,14 +100,21 @@ Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, * this is a basic exection logic: roll through layers and sequences and * execute them one by one sequentially */ + std::deque stackFrames; + + StackFrame baseFrame(proxy); + + // now we create one default StackFrame. current one. + stackFrames.push_back(baseFrame); + const auto numDevices = AffinityManager::numberOfDevices(); - Nd4jStatus result = Status::OK(); + Nd4jStatus result = Status::OK(); // for (uint64_t l = 0; l < graph.layers(); l++) { const auto &layer = graph.layer(l); //TODO: this loop is executable in parallel, so we should do this eventually for (uint64_t o = 0; o < layer.width(); o++) { - result = execute(layer[o], graph, proxy, -1); + result = execute(layer[o], graph, stackFrames, -1); } // early termination @@ -121,7 +130,11 @@ Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, } } + // that's the rule. it can't be not equal to 1. + assert(stackFrames.size() == 1); + return result; } + } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/impl/ContextPrototype.cpp b/libnd4j/include/graph/impl/ContextPrototype.cpp index dc3409ec974e..05fb84b8c9dc 100644 --- a/libnd4j/include/graph/impl/ContextPrototype.cpp +++ b/libnd4j/include/graph/impl/ContextPrototype.cpp @@ -259,7 +259,7 @@ ContextPrototype &ContextPrototype::operator=( void ContextPrototype::setNodeId(int id) { _nodeId = id; } -std::string ContextPrototype::name() const { return _name; } +const std::string& ContextPrototype::name() const { return _name; } void ContextPrototype::setName(const std::string &name) { _name = name; } } // namespace graph diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 5f1dfe243f2e..f0c7a3391ff0 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -637,13 +637,11 @@ Node::Node(Node &&other) noexcept { _hasInternalInputs = other._hasInternalInputs; _active = other._active; - _protoContext = other._protoContext; + _protoContext = std::move(other._protoContext); _customOp = std::move(other._customOp); _input = std::move(other._input); _output = std::move(other._output); - - other._customOp = nullptr; } Node &Node::operator=(Node &&other) noexcept { @@ -662,7 +660,7 @@ Node &Node::operator=(Node &&other) noexcept { _hasInternalInputs = other._hasInternalInputs; _active = other._active; - _protoContext = other._protoContext; + _protoContext = std::move(other._protoContext); _customOp = std::move(other._customOp); _input = std::move(other._input); diff --git a/libnd4j/include/graph/impl/VariableSpace.cpp b/libnd4j/include/graph/impl/VariableSpace.cpp index 382710d9b5e6..a0c8e64acf13 100644 --- a/libnd4j/include/graph/impl/VariableSpace.cpp +++ b/libnd4j/include/graph/impl/VariableSpace.cpp @@ -261,9 +261,7 @@ std::shared_ptr VariableSpace::getVariable(int id) const { return _variables.at(id); } -VariableSpace::~VariableSpace() { - // -} +VariableSpace::~VariableSpace() { } VariableSpace::VariableSpace(const VariableSpace &other) { _stash = other._stash; diff --git a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp index 93d73ad4af8a..24e84eecb57d 100644 --- a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp @@ -30,4 +30,57 @@ using namespace sd::graph; class NodeTests : public testing::Test { public: -}; \ No newline at end of file +}; + +TEST_F(NodeTests, test_copy_1) { + Node a(sd::ops::add(), "add"); + + Node b(sd::ops::divide(), "div"); + + ASSERT_NE(a.name(), b.name()); + ASSERT_NE(a.customOp()->getOpName(), b.customOp()->getOpName()); + ASSERT_NE(a.contextPrototype().name(), b.contextPrototype().name()); + + b = a; + + ASSERT_EQ(a.name(), b.name()); + ASSERT_EQ(a.customOp()->getOpName(), b.customOp()->getOpName()); + ASSERT_EQ(a.contextPrototype().name(), b.contextPrototype().name()); + + ASSERT_NE(&a.contextPrototype(), &b.contextPrototype()); +} + +static FORCEINLINE Node copy(const Node& node) { + Node a(node); + return a; +} + +TEST_F(NodeTests, test_copy_2) { + Node a(sd::ops::add(), "add"); + + Node b(sd::ops::divide(), "div"); + + ASSERT_NE(a.name(), b.name()); + ASSERT_NE(a.customOp()->getOpName(), b.customOp()->getOpName()); + ASSERT_NE(a.contextPrototype().name(), b.contextPrototype().name()); + + b = copy(a); + + ASSERT_EQ(a.name(), b.name()); + ASSERT_EQ(a.customOp()->getOpName(), b.customOp()->getOpName()); + ASSERT_EQ(a.contextPrototype().name(), b.contextPrototype().name()); + + ASSERT_NE(&a.contextPrototype(), &b.contextPrototype()); +} + +TEST_F(NodeTests, test_copy_3) { + Node a(sd::ops::add(), "add"); + + Node b = copy(a); + + ASSERT_EQ(a.name(), b.name()); + ASSERT_EQ(a.customOp()->getOpName(), b.customOp()->getOpName()); + ASSERT_EQ(a.contextPrototype().name(), b.contextPrototype().name()); + + ASSERT_NE(&a.contextPrototype(), &b.contextPrototype()); +} \ No newline at end of file From 28ce87eb1a05d96214d4f1ed6077284e46f4cf4d Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Wed, 10 Jun 2020 18:28:38 +0300 Subject: [PATCH 165/233] && Signed-off-by: raver119@gmail.com --- libnd4j/tests_cpu/layers_tests/NodeTests.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp index 24e84eecb57d..28e4327f3edc 100644 --- a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp @@ -51,8 +51,7 @@ TEST_F(NodeTests, test_copy_1) { } static FORCEINLINE Node copy(const Node& node) { - Node a(node); - return a; + return Node(node); } TEST_F(NodeTests, test_copy_2) { From cf8f677a3a7da98306b4ed29903fc6310c5bff03 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Wed, 10 Jun 2020 20:21:27 +0300 Subject: [PATCH 166/233] one more test Signed-off-by: raver119@gmail.com --- libnd4j/tests_cpu/layers_tests/NodeTests.cpp | 25 ++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp index 28e4327f3edc..59d7db737528 100644 --- a/libnd4j/tests_cpu/layers_tests/NodeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NodeTests.cpp @@ -82,4 +82,29 @@ TEST_F(NodeTests, test_copy_3) { ASSERT_EQ(a.contextPrototype().name(), b.contextPrototype().name()); ASSERT_NE(&a.contextPrototype(), &b.contextPrototype()); +} + +static MAP_IMPL modifier(MAP_IMPL map) { + if (map.size() < 5) { + map[3] = Node(sd::ops::multiply(), "mul"); + return map; + } else { + map.erase(1); + return map; + } +} + +TEST_F(NodeTests, test_copy_4) { + MAP_IMPL map; + map[1] = Node(sd::ops::add(), "add"); + map[2] = Node(sd::ops::divide(), "div"); + + auto other = modifier(map); + + ASSERT_EQ(3, other.size()); + + ASSERT_EQ(map[1].name(), other[1].name()); + ASSERT_EQ(map[1].contextPrototype().name(), other[1].contextPrototype().name()); + + ASSERT_NE(&map[1].contextPrototype(), &other[1].contextPrototype()); } \ No newline at end of file From 92992462090daed5940fae9e60b4fb5fbd266f2c Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Wed, 10 Jun 2020 20:40:53 +0300 Subject: [PATCH 167/233] no mo Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/execution/ExecutionTask.h | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/libnd4j/include/graph/execution/ExecutionTask.h b/libnd4j/include/graph/execution/ExecutionTask.h index c8532222557e..be15ee185b6f 100644 --- a/libnd4j/include/graph/execution/ExecutionTask.h +++ b/libnd4j/include/graph/execution/ExecutionTask.h @@ -31,9 +31,11 @@ namespace sd { namespace graph { class SD_EXPORT ExecutionTask { protected: - // FIXME: do we really want references here? smart pointers would work better - const Node& _node; - const ContextPrototype& _context; + // TODO: smart pointers here? + const Node _node; + + // FIXME: this field can be removed. Node contains ContextPrototype. + const ContextPrototype _context; public: ExecutionTask(const Node& node, From 264efe05d97a26c3e4e436ed58aa4bf8015ac9b6 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Wed, 10 Jun 2020 21:00:39 +0300 Subject: [PATCH 168/233] meh Signed-off-by: raver119@gmail.com --- libnd4j/include/array/NDArray.h | 7 ------ libnd4j/include/array/cpu/NDArray.cpp | 2 -- libnd4j/include/array/cuda/NDArray.cu | 4 ++-- libnd4j/include/array/impl/NDArray.cpp | 6 ++--- .../include/execution/cpu/LaunchContext.cpp | 24 ++++++++----------- 5 files changed, 14 insertions(+), 29 deletions(-) diff --git a/libnd4j/include/array/NDArray.h b/libnd4j/include/array/NDArray.h index 8726920c985b..a8bf746727a2 100644 --- a/libnd4j/include/array/NDArray.h +++ b/libnd4j/include/array/NDArray.h @@ -847,12 +847,6 @@ class SD_EXPORT NDArray { */ Nd4jLong argMax(std::initializer_list dimensions = {}); - // FIXME: remove this method eventually - void makeBothActual() const { - syncToDevice(); - syncToHost(); - } - void applyTransform(sd::transform::FloatOps op, NDArray& target, ExtraArguments* extraParams = nullptr); void applyTransform(sd::transform::SameOps op, NDArray& target, @@ -2185,7 +2179,6 @@ const Nd4jLong* NDArray::shapeInfo() const { return _shapeInfo; } //////////////////////////////////////////////////////////////////////// const Nd4jLong* NDArray::specialShapeInfo() const { if (_shapeInfoD == nullptr) return _shapeInfo; - // FIXME: this should be fixed once CUDA backend added return _shapeInfoD; } diff --git a/libnd4j/include/array/cpu/NDArray.cpp b/libnd4j/include/array/cpu/NDArray.cpp index 3f50c0e9ea77..d57849458edb 100644 --- a/libnd4j/include/array/cpu/NDArray.cpp +++ b/libnd4j/include/array/cpu/NDArray.cpp @@ -257,14 +257,12 @@ const void* NDArray::specialBufferWithOffset(Nd4jLong offset) const { //////////////////////////////////////////////////////////////////////// void* NDArray::specialBuffer() { if (_buffer->special() == nullptr) return buffer(); - // FIXME: this should be fixed once CUDA backend added return static_cast(_buffer->special()) + (_offset * sizeOfT()); } //////////////////////////////////////////////////////////////////////// void const* NDArray::specialBuffer() const { if (_buffer->special() == nullptr) return buffer(); - // FIXME: this should be fixed once CUDA backend added return static_cast(_buffer->special()) + (_offset * sizeOfT()); } diff --git a/libnd4j/include/array/cuda/NDArray.cu b/libnd4j/include/array/cuda/NDArray.cu index 831abb8a8738..fd95c73637aa 100644 --- a/libnd4j/include/array/cuda/NDArray.cu +++ b/libnd4j/include/array/cuda/NDArray.cu @@ -615,7 +615,7 @@ void* NDArray::specialBuffer() { syncToDevice(); tickReadHost(); } - // FIXME: this should be fixed once CUDA backend added + return static_cast(_buffer->special()) + (_offset * sizeOfT()); } @@ -625,7 +625,7 @@ void const* NDArray::specialBuffer() const { syncToDevice(); tickReadHost(); } - // FIXME: this should be fixed once CUDA backend added + return static_cast(_buffer->special()) + (_offset * sizeOfT()); } diff --git a/libnd4j/include/array/impl/NDArray.cpp b/libnd4j/include/array/impl/NDArray.cpp index a198810f8b17..854a5c56ae90 100644 --- a/libnd4j/include/array/impl/NDArray.cpp +++ b/libnd4j/include/array/impl/NDArray.cpp @@ -1050,7 +1050,6 @@ bool NDArray::isR() const { ////////////////////////////////////////////////////////////////////////// bool NDArray::isZ() const { - // TODO: decide if we really want to exclude Bool here return !isC() && !isR() && !isB() && !isS(); } @@ -2744,9 +2743,8 @@ NDArray NDArray::cast(DataType dtype) const { //////////////////////////////////////////////////////////////////////// void NDArray::cast(NDArray& target, DataType dtype) { if (isS()) - throw std::runtime_error( - "NDArray::cast: you can't use this method on String array!"); - // TODO: to be implemented properly + throw std::runtime_error("NDArray::cast: you can't use this method on String array!"); + target.assign(this); } diff --git a/libnd4j/include/execution/cpu/LaunchContext.cpp b/libnd4j/include/execution/cpu/LaunchContext.cpp index 566867800167..8a0e9ad1045b 100644 --- a/libnd4j/include/execution/cpu/LaunchContext.cpp +++ b/libnd4j/include/execution/cpu/LaunchContext.cpp @@ -67,14 +67,13 @@ LaunchContext::LaunchContext(Nd4jPointer cudaStream, static std::mutex _lock; - LaunchContext* LaunchContext::defaultContext() { - { - // synchronous block goes here - std::lock_guard lock(_lock); - // TODO: we need it to be device-aware, but only once we add NUMA support for - // cpu - if (LaunchContext::_contexts.empty()) - LaunchContext::_contexts.emplace_back(std::make_shared()); +LaunchContext* LaunchContext::defaultContext() { + { + // synchronous block goes here + std::lock_guard lock(_lock); + // TODO: we need it to be device-aware, but only once we add NUMA support for cpu + if (LaunchContext::_contexts.empty()) + LaunchContext::_contexts.emplace_back(std::make_shared()); } // return context for current device @@ -83,19 +82,16 @@ static std::mutex _lock; std::mutex* LaunchContext::deviceMutex() { return &_mutex; } -void LaunchContext::swapContextBuffers(ContextBuffers& buffers) { - // -} +void LaunchContext::swapContextBuffers(ContextBuffers& buffers) { } bool LaunchContext::isInitialized() { return true; } -void LaunchContext::releaseBuffers() { - // -} +void LaunchContext::releaseBuffers() { } sd::ErrorReference* LaunchContext::errorReference() { return contextBuffers.errorReference(); } void* LaunchContext::engine() { return _engine; } + } // namespace sd \ No newline at end of file From 04087208711ae182f8d23437fcee4b6c06b68754 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 11 Jun 2020 08:33:49 +0300 Subject: [PATCH 169/233] few additional assertion Signed-off-by: raver119@gmail.com --- libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index db3269ef7cfc..3b71d06e0016 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -885,6 +885,16 @@ TEST_F(GraphAnalysisTests, optimizedGraph_12) { ASSERT_EQ(std::string("add"), seq[0].node().name()); ASSERT_EQ(std::string("range_1"), seq[1].node().name()); ASSERT_EQ(std::string("range_2"), seq[2].node().name()); + + std::pair exp; + + exp = {seq[1].node().id(), 0}; + ASSERT_EQ(exp, seq[0].node().output()[0]); + + exp = {seq[2].node().id(), 0}; + ASSERT_EQ(exp, seq[1].node().output()[0]); + + ASSERT_EQ(0, seq[2].node().output().size()); } TEST_F(GraphAnalysisTests, optimizedGraph_13) { From c3795cb20e45b480b3ed6b0a6fc8eb99e1993f62 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 11 Jun 2020 08:42:00 +0300 Subject: [PATCH 170/233] two Node signatures renamed Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/Node.h | 4 ++-- libnd4j/include/graph/impl/Graph.cpp | 6 +++--- libnd4j/include/graph/impl/Node.cpp | 16 ++++++++-------- libnd4j/include/graph/impl/OptimizedGraph.cpp | 2 +- .../layers_tests/GraphAnalysisTests.cpp | 6 +++--- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 073f864eac0c..68e6808f61f2 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -140,8 +140,8 @@ class SD_EXPORT Node { Nd4jLong opNum() const; - const std::vector> &input() const; - const std::vector> &output() const; + const std::vector> &inputs() const; + const std::vector> &outputs() const; const std::vector> &dependencies() const; void setId(int id); diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index c942f87f35a3..cd106678494c 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -218,10 +218,10 @@ void Graph::printOutNode(const Node &node) const { nd4j_printf("Inputs: [", ""); // auto block = node->getBlock(); - for (int e = 0; e < node.input().size(); e++) { - auto in = node.input()[e]; + for (int e = 0; e < node.inputs().size(); e++) { + auto in = node.inputs()[e]; printf("{%i:%i}", in.first, in.second); - if (e < node.input().size() - 1) nd4j_printf(", ", ""); + if (e < node.inputs().size() - 1) nd4j_printf(", ", ""); } if (node.opType() == OpType_CUSTOM) { diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index f0c7a3391ff0..0b132f42fb0b 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -200,9 +200,9 @@ int Node::id() const { return _id; } Nd4jLong Node::opNum() const { return _opNum; } -const std::vector> &Node::input() const { return _input; } +const std::vector> &Node::inputs() const { return _input; } -const std::vector> &Node::output() const { return _output; } +const std::vector> &Node::outputs() const { return _output; } Node::Node(const std::string &opName, const std::string &nodeName, const int id, const std::vector &inputs, @@ -474,8 +474,8 @@ Node::Node(const FlatNode *node) { node->inputPaired()->size() > 0) { ContextPrototype block(nullptr, this->id(), false); - for (int e = 0; e < this->input().size(); e++) { - block.pickInput(this->input().at(e)); + for (int e = 0; e < this->inputs().size(); e++) { + block.pickInput(this->inputs().at(e)); } // there's no other IArgs in legacy options, actually @@ -510,8 +510,8 @@ Node::Node(const FlatNode *node) { if (!this->name().empty()) block.setName(this->name()); - for (int e = 0; e < this->input().size(); e++) - block.pickInput(this->input().at(e)); + for (int e = 0; e < this->inputs().size(); e++) + block.pickInput(this->inputs().at(e)); this->setContextPrototype(block); } else if (this->_opType == OpType_CUSTOM) { @@ -524,8 +524,8 @@ Node::Node(const FlatNode *node) { if (!this->name().empty()) block.setName(this->name()); - for (int e = 0; e < this->input().size(); e++) - block.pickInput(this->input().at(e)); + for (int e = 0; e < this->inputs().size(); e++) + block.pickInput(this->inputs().at(e)); if (node->extraInteger() != nullptr) for (uint32_t e = 0; e < node->extraInteger()->size(); e++) { diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index a28e28d9be1f..356f57429283 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -60,7 +60,7 @@ OptimizedGraph::OptimizedGraph(MAP_IMPL inMap, const VariableSpace& v std::vector startNodes; for (const auto& p : inMap) { - const auto& inputs = p.second.input(); + const auto& inputs = p.second.inputs(); for (int i = 0; i < inputs.size(); ++i) { diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 3b71d06e0016..757f3695b020 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -889,12 +889,12 @@ TEST_F(GraphAnalysisTests, optimizedGraph_12) { std::pair exp; exp = {seq[1].node().id(), 0}; - ASSERT_EQ(exp, seq[0].node().output()[0]); + //ASSERT_EQ(exp, seq[0].node().output()[0]); exp = {seq[2].node().id(), 0}; - ASSERT_EQ(exp, seq[1].node().output()[0]); + ASSERT_EQ(exp, seq[1].node().outputs()[0]); - ASSERT_EQ(0, seq[2].node().output().size()); + ASSERT_EQ(0, seq[2].node().outputs().size()); } TEST_F(GraphAnalysisTests, optimizedGraph_13) { From adea25b5238b7c21a0967c525172bdd40112d4e0 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 11 Jun 2020 09:03:21 +0300 Subject: [PATCH 171/233] one more assert Signed-off-by: raver119@gmail.com --- .../tests_cpu/layers_tests/GraphAnalysisTests.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 757f3695b020..16338680a592 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -983,6 +983,9 @@ TEST_F(GraphAnalysisTests, optimizedGraph_cond1) { ASSERT_EQ(std::string("in_0/read"), seq[0].node().name()); ASSERT_EQ(std::string("cond/Switch"), seq[1].node().name()); + auto swtch = seq[1].node(); + ASSERT_EQ(2, swtch.outputs().size()); + // layer 1 layer = optimized.layer(1); ASSERT_EQ(2, layer.width()); @@ -994,6 +997,16 @@ TEST_F(GraphAnalysisTests, optimizedGraph_cond1) { ASSERT_EQ(1, seq.length()); ASSERT_EQ(std::string("cond/switch_f"), seq[0].node().name()); + std::pair exp; + + // checking True condition first + exp = {layer[0][0].node().id(), 1}; + ASSERT_EQ(exp, swtch.outputs()[0]); + + // checking False condition next + exp = {layer[1][0].node().id(), 0}; + ASSERT_EQ(exp, swtch.outputs()[1]); + // layer 2 layer = optimized.layer(2); ASSERT_EQ(1, layer.width()); From ffac730c651c6dca5fe1ba515da748df2757333f Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 11 Jun 2020 09:44:42 +0300 Subject: [PATCH 172/233] StackFrames Signed-off-by: raver119@gmail.com --- .../include/graph/execution/impl/GraphExecutor.cpp | 11 +++++++---- libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 1c124f0d98c6..0bfcddc76786 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -100,12 +100,12 @@ Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, * this is a basic exection logic: roll through layers and sequences and * execute them one by one sequentially */ - std::deque stackFrames; - StackFrame baseFrame(proxy); + // root StackFrame is built from VariableProxy copy + StackFrame rootFrame(proxy); - // now we create one default StackFrame. current one. - stackFrames.push_back(baseFrame); + // now we create out dequeue of frames with one root StackFrame. current one. + std::deque stackFrames({rootFrame}); const auto numDevices = AffinityManager::numberOfDevices(); Nd4jStatus result = Status::OK(); // @@ -133,6 +133,9 @@ Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, // that's the rule. it can't be not equal to 1. assert(stackFrames.size() == 1); + // update original VariableProxy + proxy.pullFrom(stackFrames.front().variableProxy()); + return result; } diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 16338680a592..6b0307c30d91 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -889,7 +889,7 @@ TEST_F(GraphAnalysisTests, optimizedGraph_12) { std::pair exp; exp = {seq[1].node().id(), 0}; - //ASSERT_EQ(exp, seq[0].node().output()[0]); + ASSERT_EQ(exp, seq[0].node().outputs()[0]); exp = {seq[2].node().id(), 0}; ASSERT_EQ(exp, seq[1].node().outputs()[0]); @@ -1079,5 +1079,5 @@ TEST_F(GraphAnalysisTests, optimizedGraph_cond2) { TEST_F(GraphAnalysisTests, optimizedGraph_while1) { auto graph = Graph::fromFlatBuffers("resources/while_iter1.fb"); - graph.printOut(); + //graph.printOut(); } From 41b30d29cbe3fd6bbb2c7e73eccc57a4af197b2c Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 11 Jun 2020 11:50:37 +0300 Subject: [PATCH 173/233] minor updates Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/execution/OpSequence.h | 13 +++++++++ .../graph/execution/impl/GraphExecutor.cpp | 11 +++++--- .../graph/execution/impl/OpSequence.cpp | 27 ++++++++++++++++++- .../include/graph/logic/LogicConditional.h | 3 ++- libnd4j/include/graph/logic/LogicEnter.h | 3 ++- libnd4j/include/graph/logic/LogicExecutor.h | 3 ++- libnd4j/include/graph/logic/LogicExit.h | 3 ++- libnd4j/include/graph/logic/LogicExpose.h | 7 ++--- libnd4j/include/graph/logic/LogicLoopCond.h | 3 ++- libnd4j/include/graph/logic/LogicMerge.h | 5 +++- .../include/graph/logic/LogicNextIteration.h | 5 +++- libnd4j/include/graph/logic/LogicReturn.h | 4 ++- libnd4j/include/graph/logic/LogicScope.h | 4 ++- libnd4j/include/graph/logic/LogicSwitch.h | 4 ++- libnd4j/include/graph/logic/LogicWhile.h | 4 ++- .../graph/logic/impl/LogicConditional.cpp | 3 ++- .../include/graph/logic/impl/LogicEnter.cpp | 3 ++- .../graph/logic/impl/LogicExecutor.cpp | 25 ++++++++--------- .../include/graph/logic/impl/LogicExit.cpp | 3 ++- .../include/graph/logic/impl/LogicExpose.cpp | 3 ++- .../graph/logic/impl/LogicLoopCond.cpp | 3 ++- .../include/graph/logic/impl/LogicMerge.cpp | 3 ++- .../graph/logic/impl/LogicNextIteration.cpp | 3 ++- .../include/graph/logic/impl/LogicReturn.cpp | 3 ++- .../include/graph/logic/impl/LogicScope.cpp | 3 ++- .../include/graph/logic/impl/LogicSwitch.cpp | 3 ++- .../include/graph/logic/impl/LogicWhile.cpp | 3 ++- 27 files changed, 115 insertions(+), 42 deletions(-) diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index 5ce0a5acc398..e477d4968443 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -42,6 +42,12 @@ class SD_EXPORT OpSequence int _deviceId = 0; + // this map contains Node::id() -> OpSequence index mappings + MAP_IMPL _idToIndex; + + // this map contains OpSequence index -> Node::id() mapping + MAP_IMPL _indexToId; + public: explicit OpSequence(const std::vector& ops, int deviceId = 0); OpSequence(int deviceId = 0); @@ -95,6 +101,13 @@ class SD_EXPORT OpSequence void append(const ExecutionTask& task); void append(ExecutionTask&& task); + /** + * These two methods provide access to index/id dictionalries + * @param index + * @return + */ + int nodeId(int index) const; + int nodeIndex(int id) const; /** * Iterator functionality for OpSequence diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 0bfcddc76786..1b01bf023b83 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -77,10 +77,13 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, auto &v = seq[e]; auto &p = stackFrames.back().variableProxy(); - // only Ops can be executed this way :( - if (v.node().hasCustomOp()) - result = execute(v.node().customOp(), v.protoContext(), seq, graph, const_cast(p), targetDevice); - else { + + if (v.node().opType() == OpType_LOGIC) { + + } else if (v.node().hasCustomOp()) { + // only Ops can be executed this way :( + result = execute(v.node().customOp(), v.protoContext(), seq, graph, const_cast(p), targetDevice); + } else { nd4j_printf("Node <%i:%s> has no customOp set\n", v.node().id(), v.node().name().empty() ? "" : v.node().name().c_str()); diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index c114f5f0d7cc..4bde45cf3a7b 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -20,6 +20,7 @@ #include #include +#include namespace sd { namespace graph { @@ -81,15 +82,39 @@ uint64_t OpSequence::length() const { return _ops.size(); } void OpSequence::append(const Node &node, const sd::graph::ContextPrototype &ctx) { ExecutionTask task(node, ctx); - _ops.emplace_back(task); + append(task); } void OpSequence::append(const ExecutionTask& task) { _ops.emplace_back(task); + + // updating dictionaries + auto index = _ops.size() - 1; + _idToIndex[task.node().id()] = index; + _indexToId[index] = task.node().id(); } void OpSequence::append(ExecutionTask&& task) { _ops.emplace_back(std::move(task)); + + // updating dictionaries + auto index = _ops.size() - 1; + _idToIndex[task.node().id()] = index; + _indexToId[index] = task.node().id(); +} + +int OpSequence::nodeId(int index) const { + if (index < 0 || index >= _ops.size() || _indexToId.count(index) < 1) + throw std::runtime_error("Out-of-size index requested: " + StringUtils::valueToString(index)); + + return _indexToId.at(index); +} + +int OpSequence::nodeIndex(int id) const { + if ( _idToIndex.count(id) < 1) + throw std::runtime_error("Unknown Node ID requested: " + StringUtils::valueToString(id)); + + return _idToIndex.at(id); } OpSequence::iterator OpSequence::begin() { diff --git a/libnd4j/include/graph/logic/LogicConditional.h b/libnd4j/include/graph/logic/LogicConditional.h index ce7ef82b5ad5..52d34933d770 100644 --- a/libnd4j/include/graph/logic/LogicConditional.h +++ b/libnd4j/include/graph/logic/LogicConditional.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -41,7 +42,7 @@ namespace graph { */ class LogicConditional { public: - static Nd4jStatus processNode(Graph* graph, Node* node); + static Nd4jStatus processNode(const Node* node); }; } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicEnter.h b/libnd4j/include/graph/logic/LogicEnter.h index 480217205c4c..021f63ec9b66 100644 --- a/libnd4j/include/graph/logic/LogicEnter.h +++ b/libnd4j/include/graph/logic/LogicEnter.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -28,7 +29,7 @@ namespace sd { namespace graph { class LogicEnter { public: - static Nd4jStatus processNode(Graph* graph, Node* node); + static Nd4jStatus processNode(const Node* node); }; } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicExecutor.h b/libnd4j/include/graph/logic/LogicExecutor.h index cbf0ed0b64e1..118e16bd9aa0 100644 --- a/libnd4j/include/graph/logic/LogicExecutor.h +++ b/libnd4j/include/graph/logic/LogicExecutor.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -34,7 +35,7 @@ namespace graph { */ class LogicExecutor { public: - static Nd4jStatus processNode(Graph* graph, Node* node); + static Nd4jStatus processNode(const Node* node); }; } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicExit.h b/libnd4j/include/graph/logic/LogicExit.h index 7beebee1929d..cc1e8dfc99c1 100644 --- a/libnd4j/include/graph/logic/LogicExit.h +++ b/libnd4j/include/graph/logic/LogicExit.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -28,7 +29,7 @@ namespace sd { namespace graph { class LogicExit { public: - static Nd4jStatus processNode(Graph* graph, Node* node); + static Nd4jStatus processNode(const Node* node); }; } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicExpose.h b/libnd4j/include/graph/logic/LogicExpose.h index fc3614b9cded..2b5eaa677349 100644 --- a/libnd4j/include/graph/logic/LogicExpose.h +++ b/libnd4j/include/graph/logic/LogicExpose.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,7 +20,7 @@ // #ifndef SD_LOGICEXPOSE_H -#define LIBND4J_LOGICEXPOSE_H +#define SD_LOGICEXPOSE_H #include #include @@ -29,9 +30,9 @@ namespace sd { namespace graph { class LogicExpose { public: - static Nd4jStatus processNode(Graph* graph, Node* node); + static Nd4jStatus processNode(const Node* node); }; } // namespace graph } // namespace sd -#endif // LIBND4J_LOGICEXPOSE_H +#endif // SD_LOGICEXPOSE_H diff --git a/libnd4j/include/graph/logic/LogicLoopCond.h b/libnd4j/include/graph/logic/LogicLoopCond.h index 53ff911f5d8c..1642d5955581 100644 --- a/libnd4j/include/graph/logic/LogicLoopCond.h +++ b/libnd4j/include/graph/logic/LogicLoopCond.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -28,7 +29,7 @@ namespace sd { namespace graph { class LogicLoopCond { public: - static Nd4jStatus processNode(Graph* graph, Node* node); + static Nd4jStatus processNode(const Node* node); }; } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicMerge.h b/libnd4j/include/graph/logic/LogicMerge.h index a808e046e83c..78d1ec3f60b7 100644 --- a/libnd4j/include/graph/logic/LogicMerge.h +++ b/libnd4j/include/graph/logic/LogicMerge.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -26,10 +27,12 @@ namespace sd { namespace graph { + class LogicMerge { public: - static Nd4jStatus processNode(Graph* graph, Node* node); + static Nd4jStatus processNode(const Node* node); }; + } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicNextIteration.h b/libnd4j/include/graph/logic/LogicNextIteration.h index 7a17645f93fe..13752130fe60 100644 --- a/libnd4j/include/graph/logic/LogicNextIteration.h +++ b/libnd4j/include/graph/logic/LogicNextIteration.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -26,10 +27,12 @@ namespace sd { namespace graph { + class LogicNextIeration { public: - static Nd4jStatus processNode(Graph* graph, Node* node); + static Nd4jStatus processNode(const Node* node); }; + } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicReturn.h b/libnd4j/include/graph/logic/LogicReturn.h index b462eb4c6618..eef1cdc6a9b5 100644 --- a/libnd4j/include/graph/logic/LogicReturn.h +++ b/libnd4j/include/graph/logic/LogicReturn.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -36,8 +37,9 @@ namespace graph { */ class LogicReturn { public: - static Nd4jStatus processNode(Graph* graph, Node* node); + static Nd4jStatus processNode(const Node* node); }; + } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicScope.h b/libnd4j/include/graph/logic/LogicScope.h index bd197b1a0dc8..cba7e9d0041f 100644 --- a/libnd4j/include/graph/logic/LogicScope.h +++ b/libnd4j/include/graph/logic/LogicScope.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -36,8 +37,9 @@ namespace graph { */ class LogicScope { public: - static Nd4jStatus processNode(Graph* graph, Node* node); + static Nd4jStatus processNode(const Node* node); }; + } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicSwitch.h b/libnd4j/include/graph/logic/LogicSwitch.h index 9d659cf4a91b..407788bdff24 100644 --- a/libnd4j/include/graph/logic/LogicSwitch.h +++ b/libnd4j/include/graph/logic/LogicSwitch.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -36,8 +37,9 @@ namespace graph { */ class LogicSwitch { public: - static Nd4jStatus processNode(Graph* graph, Node* node); + static Nd4jStatus processNode(const Node* node); }; + } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicWhile.h b/libnd4j/include/graph/logic/LogicWhile.h index d5437c5eae8e..cc4d6cdf2758 100644 --- a/libnd4j/include/graph/logic/LogicWhile.h +++ b/libnd4j/include/graph/logic/LogicWhile.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -36,8 +37,9 @@ namespace graph { */ class LogicWhile { public: - static Nd4jStatus processNode(Graph* graph, Node* node); + static Nd4jStatus processNode(const Node* node); }; + } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/impl/LogicConditional.cpp b/libnd4j/include/graph/logic/impl/LogicConditional.cpp index ec1437d7bbd5..392a1db6ebbc 100644 --- a/libnd4j/include/graph/logic/impl/LogicConditional.cpp +++ b/libnd4j/include/graph/logic/impl/LogicConditional.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -24,7 +25,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicConditional::processNode(Graph *graph, Node *node) { +Nd4jStatus LogicConditional::processNode(const Node *node) { throw std::runtime_error( "LogicConditional::processNode - not implemented yet"); /* diff --git a/libnd4j/include/graph/logic/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp index da4721bee3bd..db52b4b732c0 100644 --- a/libnd4j/include/graph/logic/impl/LogicEnter.cpp +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -23,7 +24,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicEnter::processNode(Graph *graph, Node *node) { +Nd4jStatus LogicEnter::processNode(const Node *node) { throw std::runtime_error("LogicEnter::processNode - not implemented yet"); /* // this op replicates input variable into the frame. basically happens once diff --git a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp index ee9ece1e773b..afc926ff8417 100644 --- a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -33,30 +34,30 @@ namespace sd { namespace graph { -Nd4jStatus LogicExecutor::processNode(Graph *graph, Node *node) { +Nd4jStatus LogicExecutor::processNode(const Node *node) { switch (node->opNum()) { case sd::logic::While: - return LogicWhile::processNode(graph, node); + return LogicWhile::processNode(node); case sd::logic::Scope: - return LogicScope::processNode(graph, node); + return LogicScope::processNode(node); case sd::logic::Conditional: - return LogicConditional::processNode(graph, node); + return LogicConditional::processNode(node); case sd::logic::Switch: - return LogicSwitch::processNode(graph, node); + return LogicSwitch::processNode(node); case sd::logic::Return: - return LogicReturn::processNode(graph, node); + return LogicReturn::processNode(node); case sd::logic::Expose: - return LogicExpose::processNode(graph, node); + return LogicExpose::processNode(node); case sd::logic::Merge: - return LogicMerge::processNode(graph, node); + return LogicMerge::processNode(node); case sd::logic::LoopCond: - return LogicLoopCond::processNode(graph, node); + return LogicLoopCond::processNode(node); case sd::logic::NextIteration: - return LogicNextIeration::processNode(graph, node); + return LogicNextIeration::processNode(node); case sd::logic::Exit: - return LogicExit::processNode(graph, node); + return LogicExit::processNode(node); case sd::logic::Enter: - return LogicEnter::processNode(graph, node); + return LogicEnter::processNode(node); } if (node->name().empty()) { diff --git a/libnd4j/include/graph/logic/impl/LogicExit.cpp b/libnd4j/include/graph/logic/impl/LogicExit.cpp index 1d0be2c7da91..c5329a1caaf9 100644 --- a/libnd4j/include/graph/logic/impl/LogicExit.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExit.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -22,7 +23,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicExit::processNode(Graph *graph, Node *node) { +Nd4jStatus LogicExit::processNode(const Node *node) { // this op is basically no-op // we just know it exists throw std::runtime_error("LogicExit::processNode - Not implemented yet"); diff --git a/libnd4j/include/graph/logic/impl/LogicExpose.cpp b/libnd4j/include/graph/logic/impl/LogicExpose.cpp index 3717adab45f2..825bb0db13c2 100644 --- a/libnd4j/include/graph/logic/impl/LogicExpose.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExpose.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -22,7 +23,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicExpose::processNode(Graph *graph, Node *node) { +Nd4jStatus LogicExpose::processNode(const Node *node) { // do we really want this? return ND4J_STATUS_OK; } diff --git a/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp index d3e191b0cdc2..5fa1c6e1bfa6 100644 --- a/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp +++ b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -22,7 +23,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicLoopCond::processNode(Graph *graph, Node *node) { +Nd4jStatus LogicLoopCond::processNode(const Node *node) { throw std::runtime_error("LogicLoopCond::processNode - Not implemented yet"); /* auto __variableSpace = graph->variableSpace(); diff --git a/libnd4j/include/graph/logic/impl/LogicMerge.cpp b/libnd4j/include/graph/logic/impl/LogicMerge.cpp index 4c9f1d2baa81..547722b271ba 100644 --- a/libnd4j/include/graph/logic/impl/LogicMerge.cpp +++ b/libnd4j/include/graph/logic/impl/LogicMerge.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -23,7 +24,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicMerge::processNode(Graph *graph, Node *node) { +Nd4jStatus LogicMerge::processNode(const Node *node) { throw std::runtime_error("LogicMerge::processNode - not implemented yet"); /* // at merge node only one of inputs exist if that's just switch and other node diff --git a/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp index 1a0eafd94395..fd59336ad85c 100644 --- a/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp +++ b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -22,7 +23,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicNextIeration::processNode(Graph *graph, Node *node) { +Nd4jStatus LogicNextIeration::processNode(const Node *node) { throw std::runtime_error( "LogicNextIeration::processNode - not implemented yet"); /* diff --git a/libnd4j/include/graph/logic/impl/LogicReturn.cpp b/libnd4j/include/graph/logic/impl/LogicReturn.cpp index 5b19f380cd22..a9d98b3140bd 100644 --- a/libnd4j/include/graph/logic/impl/LogicReturn.cpp +++ b/libnd4j/include/graph/logic/impl/LogicReturn.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -25,7 +26,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicReturn::processNode(Graph *graph, Node *node) { +Nd4jStatus LogicReturn::processNode(const Node *node) { throw std::runtime_error("LogicReturn::processNode - not implemented yet"); /* auto __variableSpace = graph->variableSpace(); diff --git a/libnd4j/include/graph/logic/impl/LogicScope.cpp b/libnd4j/include/graph/logic/impl/LogicScope.cpp index 89738507edf9..c1efb207b652 100644 --- a/libnd4j/include/graph/logic/impl/LogicScope.cpp +++ b/libnd4j/include/graph/logic/impl/LogicScope.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -23,7 +24,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicScope::processNode(Graph *graph, Node *node) { +Nd4jStatus LogicScope::processNode(const Node *node) { // this op is basically no-op // we just know it exists return sd::Status::OK(); diff --git a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp index 8de39b462b10..5f9997713d4f 100644 --- a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp +++ b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -24,7 +25,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicSwitch::processNode(Graph* graph, Node* node) { +Nd4jStatus LogicSwitch::processNode(const Node* node) { throw std::runtime_error("LogicSwitch::processNode - not implemented yet"); /* auto __variableSpace = graph->variableSpace(); diff --git a/libnd4j/include/graph/logic/impl/LogicWhile.cpp b/libnd4j/include/graph/logic/impl/LogicWhile.cpp index c377d59058d8..2472f346daa1 100644 --- a/libnd4j/include/graph/logic/impl/LogicWhile.cpp +++ b/libnd4j/include/graph/logic/impl/LogicWhile.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -25,7 +26,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicWhile::processNode(Graph *graph, Node *node) { +Nd4jStatus LogicWhile::processNode(const Node *node) { throw std::runtime_error("LogicWhile::processNode - not implemented yet"); /* auto __variableSpace = graph->variableSpace(); From 7ad0492782e1840a43548a4070451658ff762906 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 11 Jun 2020 11:51:46 +0300 Subject: [PATCH 174/233] swap Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/Node.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 68e6808f61f2..96c802cde815 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -50,10 +50,10 @@ class SD_EXPORT Node { OpClass _opClass; Nd4jLong _opNum; - // Inputs are stored in format + // Inputs are stored in format std::vector> _input; - // Outputs are stored in format + // Outputs are stored in format std::vector> _output; // Control flow dependencies for Node From 16e986bce7e4fdb2c7883d764364d2f0621916a9 Mon Sep 17 00:00:00 2001 From: Yurii Date: Thu, 11 Jun 2020 12:01:29 +0300 Subject: [PATCH 175/233] - correct number of outputs in node Signed-off-by: Yurii --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 356f57429283..de45c6fbf6bf 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -67,7 +67,7 @@ OptimizedGraph::OptimizedGraph(MAP_IMPL inMap, const VariableSpace& v if (inputs[i].first >= inMap.begin()->first) { // is op workMap[p.first]._in.push_back(inputs[i].first); workMap[inputs[i].first]._out.push_back(p.first); - inMap[inputs[i].first].pickOutput(p.first, i); + inMap[inputs[i].first].pickOutput(p.first, inputs[i].second); } else { // is variable @@ -77,7 +77,7 @@ OptimizedGraph::OptimizedGraph(MAP_IMPL inMap, const VariableSpace& v if(std::find(workMap[p.first]._in.begin(), workMap[p.first]._in.end(), depends[j].first) == workMap[p.first]._in.end()) { workMap[p.first]._in.push_back(depends[j].first); workMap[depends[j].first]._out.push_back(p.first); - inMap[depends[j].first].pickOutput(p.first, j); + inMap[depends[j].first].pickOutput(p.first, depends[j].second); } } } From f4a4f531b310485b552cf7ff5de2007253029b0d Mon Sep 17 00:00:00 2001 From: Yurii Date: Thu, 11 Jun 2020 17:25:46 +0300 Subject: [PATCH 176/233] - introduce special treatment for NextIteration op in graph loop Signed-off-by: Yurii --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 46 +++++++++++++++++-- .../layers_tests/GraphAnalysisTests.cpp | 3 +- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index de45c6fbf6bf..6cd3d3a003f6 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -47,6 +47,16 @@ OptimizedGraph& OptimizedGraph::operator=(OptimizedGraph &&other) noexcept { /////////////////////////////////////////////////////////////////// OptimizedGraph::OptimizedGraph(MAP_IMPL inMap, const VariableSpace& varSpace) { + + // for (const auto& p : inMap) { + // printf("%s %i, inputs: ", p.second.name().c_str(), p.first); + // const auto& inputs = p.second.inputs(); + // for (int i = 0; i < inputs.size(); ++i) + // printf("%i, ", inputs[i].first); + // printf("\n"); + // } + // return; + struct NodeInfo { uint _layerNum = 0; std::vector _opSeq = {}; @@ -65,9 +75,11 @@ OptimizedGraph::OptimizedGraph(MAP_IMPL inMap, const VariableSpace& v for (int i = 0; i < inputs.size(); ++i) { if (inputs[i].first >= inMap.begin()->first) { // is op - workMap[p.first]._in.push_back(inputs[i].first); - workMap[inputs[i].first]._out.push_back(p.first); inMap[inputs[i].first].pickOutput(p.first, inputs[i].second); + if(inMap[inputs[i].first].name().find("NextIteration") == std::string::npos) { + workMap[inputs[i].first]._out.push_back(p.first); + workMap[p.first]._in.push_back(inputs[i].first); + } } else { // is variable @@ -75,9 +87,11 @@ OptimizedGraph::OptimizedGraph(MAP_IMPL inMap, const VariableSpace& v for (int j = 0; j < depends.size(); ++j) { if(std::find(workMap[p.first]._in.begin(), workMap[p.first]._in.end(), depends[j].first) == workMap[p.first]._in.end()) { - workMap[p.first]._in.push_back(depends[j].first); - workMap[depends[j].first]._out.push_back(p.first); inMap[depends[j].first].pickOutput(p.first, depends[j].second); + if(inMap[depends[j].first].name().find("NextIteration") == std::string::npos) { + workMap[depends[j].first]._out.push_back(p.first); + workMap[p.first]._in.push_back(depends[j].first); + } } } } @@ -86,6 +100,28 @@ OptimizedGraph::OptimizedGraph(MAP_IMPL inMap, const VariableSpace& v if(workMap[p.first]._in.empty()) startNodes.push_back(p.first); } +// printf("\n\n\n\n\n"); + // for (const auto& p : workMap) { + // printf("id %i, inputs: ", p.first); + // for (const auto& i : p.second._in) + // printf("%i, ", i); + // printf("outputs: "); + // for (const auto& i : p.second._out) + // printf("%i, ", i); + // printf("\n"); + // } + + // for (const auto& p : inMap) { + // printf("id %i, inputs ", p.first); + // for(const auto& i : p.second.inputs()) + // printf("%i, ", i.first); + // printf(" outputs: "); + // for(const auto& i : p.second.outputs()) + // printf("%i, ", i.first); + // printf("\n"); + // } + // printf("\n\n\n\n\n"); + // collect OpSequences (fill _opSeq) std::vector nodesToDelete; @@ -133,7 +169,7 @@ OptimizedGraph::OptimizedGraph(MAP_IMPL inMap, const VariableSpace& v sortedGraphTemp[p.second._layerNum].append(std::move(seq)); } - // check whether there is layer with one OpSequence containing only one op + // check whether there are layers with one OpSequence which in turn contains only one op bool isLayerWithOneOp = false; for(auto& layer : sortedGraphTemp) { if(layer.width() == 1 && layer.at(0).length() == 1) { diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 6b0307c30d91..7e94798f0238 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -1079,5 +1079,6 @@ TEST_F(GraphAnalysisTests, optimizedGraph_cond2) { TEST_F(GraphAnalysisTests, optimizedGraph_while1) { auto graph = Graph::fromFlatBuffers("resources/while_iter1.fb"); - //graph.printOut(); + const auto& optimized = graph.optimizedGraph(); + // graph.printOut(); } From ca037cc44a7dd946144837f553b062a7ae170f31 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 11 Jun 2020 18:42:34 +0300 Subject: [PATCH 177/233] Stack? Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/Node.h | 7 +-- .../include/graph/execution/GraphExecutor.h | 3 +- libnd4j/include/graph/execution/Stack.h | 49 +++++++++++++++++++ libnd4j/include/graph/execution/StackFrame.h | 4 ++ .../graph/execution/impl/GraphExecutor.cpp | 15 +++--- .../include/graph/execution/impl/Stack.cpp | 47 ++++++++++++++++++ .../graph/execution/impl/StackFrame.cpp | 16 +++++- libnd4j/include/graph/impl/Node.cpp | 1 + 8 files changed, 128 insertions(+), 14 deletions(-) create mode 100644 libnd4j/include/graph/execution/Stack.h create mode 100644 libnd4j/include/graph/execution/impl/Stack.cpp diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 96c802cde815..327a4bbef42a 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,8 +19,8 @@ // @author raver119@gmail.com // -#ifndef LIBND4J_GNODE_H -#define LIBND4J_GNODE_H +#ifndef SD_GNODE_H +#define SD_GNODE_H #include #include @@ -184,4 +185,4 @@ class SD_EXPORT Node { } // namespace graph } // namespace sd -#endif // LIBND4J_GNODE_H +#endif // SD_GNODE_H diff --git a/libnd4j/include/graph/execution/GraphExecutor.h b/libnd4j/include/graph/execution/GraphExecutor.h index 461267262a70..281fd419fddb 100644 --- a/libnd4j/include/graph/execution/GraphExecutor.h +++ b/libnd4j/include/graph/execution/GraphExecutor.h @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -79,7 +80,7 @@ class SD_EXPORT GraphExecutor { */ virtual Nd4jStatus execute(const OpSequence &seq, const OptimizedGraph &graph, - std::deque &stackFrames, + Stack &stack, int deviceId) const; /** diff --git a/libnd4j/include/graph/execution/Stack.h b/libnd4j/include/graph/execution/Stack.h new file mode 100644 index 000000000000..fbebc04d0bbe --- /dev/null +++ b/libnd4j/include/graph/execution/Stack.h @@ -0,0 +1,49 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_STACK_H_ +#define SD_STACK_H_ + +#include +#include +#include + +namespace sd { +namespace graph { + +class SD_EXPORT Stack { + private: + std::deque _frames; + + public: + Stack(const VariableProxy &root); + ~Stack() = default; + + StackFrame& back(); + StackFrame& front(); + StackFrame& root(); + + const VariableProxy& rootVariableSpace() const; +}; + +} // namespace graph +} // namespace sd + +#endif //SD_STACK_H_ diff --git a/libnd4j/include/graph/execution/StackFrame.h b/libnd4j/include/graph/execution/StackFrame.h index dc036a636acb..9adb4ff2814d 100644 --- a/libnd4j/include/graph/execution/StackFrame.h +++ b/libnd4j/include/graph/execution/StackFrame.h @@ -31,11 +31,15 @@ class SD_EXPORT StackFrame { private: VariableProxy _proxy; + MAP_IMPL _disabledNodes; public: explicit StackFrame(VariableProxy &proxy); ~StackFrame() = default; const VariableProxy& variableProxy() const { return _proxy; } + + void disableNode(int nodeId); + bool isDisabled(int nodeId) const; }; } // namespace graph diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 1b01bf023b83..63883b23e740 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -62,7 +62,7 @@ Nd4jStatus GraphExecutor::execute( Nd4jStatus GraphExecutor::execute(const OpSequence &seq, const OptimizedGraph &graph, - std::deque &stackFrames, + Stack &stack, const int deviceId) const { // we either follow or override target deviceId specified in OpSequence auto targetDevice = deviceId >= 0 @@ -75,7 +75,7 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, auto result = Status::OK(); for (int e = 0; e < seq.length(); e++) { auto &v = seq[e]; - auto &p = stackFrames.back().variableProxy(); + auto &p = stack.back().variableProxy(); if (v.node().opType() == OpType_LOGIC) { @@ -104,11 +104,8 @@ Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, * execute them one by one sequentially */ - // root StackFrame is built from VariableProxy copy - StackFrame rootFrame(proxy); - // now we create out dequeue of frames with one root StackFrame. current one. - std::deque stackFrames({rootFrame}); + Stack stack(proxy); const auto numDevices = AffinityManager::numberOfDevices(); Nd4jStatus result = Status::OK(); // @@ -117,7 +114,7 @@ Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, //TODO: this loop is executable in parallel, so we should do this eventually for (uint64_t o = 0; o < layer.width(); o++) { - result = execute(layer[o], graph, stackFrames, -1); + result = execute(layer[o], graph, stack, -1); } // early termination @@ -134,10 +131,10 @@ Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, } // that's the rule. it can't be not equal to 1. - assert(stackFrames.size() == 1); + //assert(stackFrames.size() == 1); // update original VariableProxy - proxy.pullFrom(stackFrames.front().variableProxy()); + proxy.pullFrom(stack.front().variableProxy()); return result; } diff --git a/libnd4j/include/graph/execution/impl/Stack.cpp b/libnd4j/include/graph/execution/impl/Stack.cpp new file mode 100644 index 000000000000..0abbde8cbc0e --- /dev/null +++ b/libnd4j/include/graph/execution/impl/Stack.cpp @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace graph { + +Stack::Stack(const VariableProxy &root) { + _frames.push_back(StackFrame(const_cast(root))); +} + +const VariableProxy &Stack::rootVariableSpace() const { + return _frames.front().variableProxy(); +} + +StackFrame &Stack::back() { + return _frames.back(); +} + +StackFrame &Stack::front() { + return _frames.front(); +} + +StackFrame &Stack::root() { + return _frames.front(); +} + +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/execution/impl/StackFrame.cpp b/libnd4j/include/graph/execution/impl/StackFrame.cpp index d96f48c5006f..fb84551270b9 100644 --- a/libnd4j/include/graph/execution/impl/StackFrame.cpp +++ b/libnd4j/include/graph/execution/impl/StackFrame.cpp @@ -21,4 +21,18 @@ #include -sd::graph::StackFrame::StackFrame(sd::graph::VariableProxy &proxy) : _proxy(proxy) { } +namespace sd { +namespace graph { + +StackFrame::StackFrame(VariableProxy &proxy) : _proxy(proxy) { } + +void StackFrame::disableNode(int nodeId) { + _disabledNodes[nodeId] = 1; +} + +bool StackFrame::isDisabled(int nodeId) const { + return _disabledNodes.count(nodeId) > 0; +} + +} // namespace graph +} // namespace sd diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 0b132f42fb0b..93ffb856c010 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at From 78e96e88b78d60962a53d71544f957ae6410b899 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 12 Jun 2020 16:36:56 +0300 Subject: [PATCH 178/233] LogicSwitch draft Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/OptimizedGraph.h | 3 - .../graph/execution/impl/GraphExecutor.cpp | 17 +++- libnd4j/include/graph/logic/LogicExecutor.h | 3 +- libnd4j/include/graph/logic/LogicSwitch.h | 3 +- .../graph/logic/impl/LogicExecutor.cpp | 4 +- .../include/graph/logic/impl/LogicSwitch.cpp | 94 ++++--------------- .../layers_tests/GraphAnalysisTests.cpp | 3 +- 7 files changed, 37 insertions(+), 90 deletions(-) diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index dcb2649d0fbc..1c6808a06741 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -31,12 +31,9 @@ namespace sd { namespace graph { -class Graph; - class SD_EXPORT OptimizedGraph { private: std::vector _sortedGraph; - // const Graph& _originalGraph; public: OptimizedGraph(MAP_IMPL map, const VariableSpace& varSpace); diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 63883b23e740..68983d0cc199 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -20,6 +20,7 @@ #include #include +#include namespace sd { @@ -75,12 +76,21 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, auto result = Status::OK(); for (int e = 0; e < seq.length(); e++) { auto &v = seq[e]; - auto &p = stack.back().variableProxy(); + auto &f = stack.back(); + auto &p = f.variableProxy(); if (v.node().opType() == OpType_LOGIC) { + nd4j_printf("Node <%i:%s> is a logic op\n", + v.node().id(), + v.node().name().empty() ? "" : v.node().name().c_str()); + LogicExecutor::processNode(&v.node(), f); } else if (v.node().hasCustomOp()) { + // we skip all disabled nodes + if (f.isDisabled(v.node().id())) + continue; + // only Ops can be executed this way :( result = execute(v.node().customOp(), v.protoContext(), seq, graph, const_cast(p), targetDevice); } else { @@ -130,10 +140,7 @@ Nd4jStatus GraphExecutor::execute(const OptimizedGraph &graph, } } - // that's the rule. it can't be not equal to 1. - //assert(stackFrames.size() == 1); - - // update original VariableProxy + // update original VariableSpace from the top-level VariableSpace proxy.pullFrom(stack.front().variableProxy()); return result; diff --git a/libnd4j/include/graph/logic/LogicExecutor.h b/libnd4j/include/graph/logic/LogicExecutor.h index 118e16bd9aa0..7a63af5705c7 100644 --- a/libnd4j/include/graph/logic/LogicExecutor.h +++ b/libnd4j/include/graph/logic/LogicExecutor.h @@ -25,6 +25,7 @@ #include #include #include +#include namespace sd { namespace graph { @@ -35,7 +36,7 @@ namespace graph { */ class LogicExecutor { public: - static Nd4jStatus processNode(const Node* node); + static Nd4jStatus processNode(const Node* node, StackFrame &frame); }; } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicSwitch.h b/libnd4j/include/graph/logic/LogicSwitch.h index 407788bdff24..4b7cae7d92c2 100644 --- a/libnd4j/include/graph/logic/LogicSwitch.h +++ b/libnd4j/include/graph/logic/LogicSwitch.h @@ -25,6 +25,7 @@ #include #include #include +#include namespace sd { namespace graph { @@ -37,7 +38,7 @@ namespace graph { */ class LogicSwitch { public: - static Nd4jStatus processNode(const Node* node); + static Nd4jStatus processNode(const Node* node, StackFrame &frame); }; } // namespace graph diff --git a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp index afc926ff8417..b59f198f20bd 100644 --- a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp @@ -34,7 +34,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicExecutor::processNode(const Node *node) { +Nd4jStatus LogicExecutor::processNode(const Node *node, StackFrame &frame) { switch (node->opNum()) { case sd::logic::While: return LogicWhile::processNode(node); @@ -43,7 +43,7 @@ Nd4jStatus LogicExecutor::processNode(const Node *node) { case sd::logic::Conditional: return LogicConditional::processNode(node); case sd::logic::Switch: - return LogicSwitch::processNode(node); + return LogicSwitch::processNode(node, frame); case sd::logic::Return: return LogicReturn::processNode(node); case sd::logic::Expose: diff --git a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp index 5f9997713d4f..d4c1cde6e473 100644 --- a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp +++ b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp @@ -25,91 +25,31 @@ namespace sd { namespace graph { -Nd4jStatus LogicSwitch::processNode(const Node* node) { - throw std::runtime_error("LogicSwitch::processNode - not implemented yet"); - /* - auto __variableSpace = graph->variableSpace(); - auto __flowPath = __variableSpace->flowPath(); +Nd4jStatus LogicSwitch::processNode(const Node* node, StackFrame &frame) { + const auto &inputs = node->inputs(); + const auto &outputs = node->outputs(); - Context ctx(node->getContextPrototype(), __variableSpace); + auto &varSpace = const_cast(frame.variableProxy()); - // this can be either our format, or compatible format. - if (graph->hasScope(node->input()->at(0).first)) { - nd4j_debug("Node_%i: Scoped mode.\n", node->id()); - // first input is Scope, so it's ours - int scopeConditionIndex = node->input()->at(0).first; - auto input = ctx.variable(1); + REQUIRE_TRUE(inputs.size() == 2, 0, "Switch: op must have exactly 2 inputs"); + REQUIRE_TRUE(varSpace.hasVariable(inputs[0]), 0, "Switch: input Variable doesn't exist"); + REQUIRE_TRUE(varSpace.hasVariable(inputs[1]), 0, "Switch: condition Variable doesn't exist"); - auto scopeCondition = graph->scopeById(scopeConditionIndex); - int lastNode = 0; - for (auto v: *scopeCondition->nodes()) { - GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - lastNode = v->id(); - } + auto input = varSpace.getVariable(inputs[0]); + auto boolean = varSpace.getVariable(inputs[1]); - // now we should take result of the Scope run, and evaluate it - auto result = __variableSpace->getVariable(lastNode)->getNDArray(); - //result->printBuffer("Result of the last node"); + REQUIRE_TRUE(boolean->hasNDArray(), 0, "Switch: boolean Variable must have NDArray defined"); - - std::pair pair0(node->id(), 0); - std::pair pair1(node->id(), 1); - - if (!__variableSpace->hasVariable(pair0)) - __variableSpace->putVariable(pair0, new Variable(nullptr, nullptr, - node->id(), 0)); - - if (!__variableSpace->hasVariable(pair1)) - __variableSpace->putVariable(pair1, new Variable(nullptr, nullptr, - node->id(), 1)); - - if (!result->e(0)) { - __flowPath->markBranch(node->id(), 0); - __variableSpace->getVariable(pair0)->setNDArray(input->getNDArray()); - __variableSpace->getVariable(pair0)->markRemovable(false); - } else { - __flowPath->markBranch(node->id(), 1); - __variableSpace->getVariable(pair1)->setNDArray(input->getNDArray()); - __variableSpace->getVariable(pair1)->markRemovable(false); - } + if (boolean->getNDArray()->e(0)) { + // true branch + varSpace.putVariable(std::pair{node->id(), 1}, *input->getNDArray()); } else { - // first input is NOT a Scope, so it's compatible format - nd4j_debug("Node_%i: Compatible mode.\n", node->id()); - - auto input = ctx.variable(0)->getNDArray(); - auto boolean = ctx.variable(1)->getNDArray(); - - //input->printIndexedBuffer("0"); - //boolean->printIndexedBuffer("1"); - - std::pair pair0(node->id(), 0); - std::pair pair1(node->id(), 1); - - if (!__variableSpace->hasVariable(pair0)) - __variableSpace->putVariable(pair0, new Variable(nullptr, nullptr, - node->id(), 0)); - - if (!__variableSpace->hasVariable(pair1)) - __variableSpace->putVariable(pair1, new Variable(nullptr, nullptr, - node->id(), 1)); - - if (!boolean->e(0)) { - // false - nd4j_debug("Node_%i: FALSE branch active\n", node->id()); - __flowPath->markBranch(node->id(), 0); - __variableSpace->getVariable(pair0)->setNDArray(input); - __variableSpace->getVariable(pair0)->markRemovable(false); - } else { - //true - nd4j_debug("Node_%i: TRUE branch active\n", node->id()); - __flowPath->markBranch(node->id(), 1); - __variableSpace->getVariable(pair1)->setNDArray(input); - __variableSpace->getVariable(pair1)->markRemovable(false); - } + // false branch + varSpace.putVariable(std::pair{node->id(), 0}, *input->getNDArray()); } - return sd::Status::OK(); - */ + return Status::OK(); }; + } // namespace graph } // namespace sd diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 7e94798f0238..2335228d0c7e 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -1045,7 +1045,7 @@ TEST_F(GraphAnalysisTests, optimizedGraph_cond2) { auto graph = Graph::fromFlatBuffers("resources/cond_false.fb"); const auto& optimized = graph.optimizedGraph(); - // graph.printOut(); + graph.printOut(); ASSERT_EQ(3, optimized.layers()); @@ -1075,6 +1075,7 @@ TEST_F(GraphAnalysisTests, optimizedGraph_cond2) { ASSERT_EQ(1, seq.length()); ASSERT_EQ(std::string("cond/Merge"), seq[0].node().name()); + graph.execute(); } TEST_F(GraphAnalysisTests, optimizedGraph_while1) { From cbf0f68d623641b22210b2e03c2ff95f06ae8aa6 Mon Sep 17 00:00:00 2001 From: Yurii Date: Fri, 12 Jun 2020 16:48:38 +0300 Subject: [PATCH 179/233] - add new data member (map of nodes) to OptimizedGraph class Signed-off-by: Yurii --- libnd4j/include/graph/OptimizedGraph.h | 9 +++++++- libnd4j/include/graph/impl/OptimizedGraph.cpp | 23 +++++++++---------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index 1c6808a06741..7ab39681f65a 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -34,9 +34,10 @@ namespace graph { class SD_EXPORT OptimizedGraph { private: std::vector _sortedGraph; + MAP_IMPL _nodesMap; public: - OptimizedGraph(MAP_IMPL map, const VariableSpace& varSpace); + OptimizedGraph(const MAP_IMPL& map, const VariableSpace& varSpace); // move constructor OptimizedGraph(OptimizedGraph&& other) noexcept; // default constructor @@ -82,6 +83,12 @@ class SD_EXPORT OptimizedGraph { * @param sequence */ void append(const OpSequence &sequence); + + /** + * returns reference on _nodesMap + * @return + */ + const MAP_IMPL& getNodesMap() const { return _nodesMap; } }; diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 6cd3d3a003f6..d650c80fafbc 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -45,10 +45,9 @@ OptimizedGraph& OptimizedGraph::operator=(OptimizedGraph &&other) noexcept { } /////////////////////////////////////////////////////////////////// -OptimizedGraph::OptimizedGraph(MAP_IMPL inMap, const VariableSpace& varSpace) { +OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableSpace& varSpace): _nodesMap(inMap) { - - // for (const auto& p : inMap) { + // for (const auto& p : _nodesMap) { // printf("%s %i, inputs: ", p.second.name().c_str(), p.first); // const auto& inputs = p.second.inputs(); // for (int i = 0; i < inputs.size(); ++i) @@ -68,15 +67,15 @@ OptimizedGraph::OptimizedGraph(MAP_IMPL inMap, const VariableSpace& v // create workMap, fill vectors containing input and output nodes per each node, and find start nodes std::vector startNodes; - for (const auto& p : inMap) { + for (const auto& p : _nodesMap) { const auto& inputs = p.second.inputs(); for (int i = 0; i < inputs.size(); ++i) { - if (inputs[i].first >= inMap.begin()->first) { // is op - inMap[inputs[i].first].pickOutput(p.first, inputs[i].second); - if(inMap[inputs[i].first].name().find("NextIteration") == std::string::npos) { + if (inputs[i].first >= _nodesMap.begin()->first) { // is op + _nodesMap[inputs[i].first].pickOutput(p.first, inputs[i].second); + if(_nodesMap[inputs[i].first].name().find("NextIteration") == std::string::npos) { workMap[inputs[i].first]._out.push_back(p.first); workMap[p.first]._in.push_back(inputs[i].first); } @@ -87,8 +86,8 @@ OptimizedGraph::OptimizedGraph(MAP_IMPL inMap, const VariableSpace& v for (int j = 0; j < depends.size(); ++j) { if(std::find(workMap[p.first]._in.begin(), workMap[p.first]._in.end(), depends[j].first) == workMap[p.first]._in.end()) { - inMap[depends[j].first].pickOutput(p.first, depends[j].second); - if(inMap[depends[j].first].name().find("NextIteration") == std::string::npos) { + _nodesMap[depends[j].first].pickOutput(p.first, depends[j].second); + if(_nodesMap[depends[j].first].name().find("NextIteration") == std::string::npos) { workMap[depends[j].first]._out.push_back(p.first); workMap[p.first]._in.push_back(depends[j].first); } @@ -111,7 +110,7 @@ OptimizedGraph::OptimizedGraph(MAP_IMPL inMap, const VariableSpace& v // printf("\n"); // } - // for (const auto& p : inMap) { + // for (const auto& p : _nodesMap) { // printf("id %i, inputs ", p.first); // for(const auto& i : p.second.inputs()) // printf("%i, ", i.first); @@ -161,10 +160,10 @@ OptimizedGraph::OptimizedGraph(MAP_IMPL inMap, const VariableSpace& v for (const auto& p : workMap) { OpSequence seq; - seq.append(inMap.at(p.first), inMap.at(p.first).contextPrototype()); + seq.append(_nodesMap.at(p.first), _nodesMap.at(p.first).contextPrototype()); for (const auto& id : p.second._opSeq) - seq.append(inMap.at(id), inMap.at(id).contextPrototype()); + seq.append(_nodesMap.at(id), _nodesMap.at(id).contextPrototype()); sortedGraphTemp[p.second._layerNum].append(std::move(seq)); } From 4c77db9a0e19af19a15dcdc11a0f493d85e04173 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 12 Jun 2020 20:02:15 +0300 Subject: [PATCH 180/233] LogicSwitch draft Signed-off-by: raver119@gmail.com --- .../graph/execution/impl/GraphExecutor.cpp | 2 +- libnd4j/include/graph/logic/LogicExecutor.h | 3 +- libnd4j/include/graph/logic/LogicSwitch.h | 2 +- .../graph/logic/impl/LogicExecutor.cpp | 4 +- .../include/graph/logic/impl/LogicSwitch.cpp | 40 ++++++++++++++++++- 5 files changed, 45 insertions(+), 6 deletions(-) diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 68983d0cc199..66286bd7d7a2 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -85,7 +85,7 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, v.node().id(), v.node().name().empty() ? "" : v.node().name().c_str()); - LogicExecutor::processNode(&v.node(), f); + LogicExecutor::processNode(&v.node(), f, graph); } else if (v.node().hasCustomOp()) { // we skip all disabled nodes if (f.isDisabled(v.node().id())) diff --git a/libnd4j/include/graph/logic/LogicExecutor.h b/libnd4j/include/graph/logic/LogicExecutor.h index 7a63af5705c7..bc696ac05325 100644 --- a/libnd4j/include/graph/logic/LogicExecutor.h +++ b/libnd4j/include/graph/logic/LogicExecutor.h @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -36,7 +37,7 @@ namespace graph { */ class LogicExecutor { public: - static Nd4jStatus processNode(const Node* node, StackFrame &frame); + static Nd4jStatus processNode(const Node* node, StackFrame &frame, const OptimizedGraph& graph); }; } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicSwitch.h b/libnd4j/include/graph/logic/LogicSwitch.h index 4b7cae7d92c2..93bf514a12b9 100644 --- a/libnd4j/include/graph/logic/LogicSwitch.h +++ b/libnd4j/include/graph/logic/LogicSwitch.h @@ -38,7 +38,7 @@ namespace graph { */ class LogicSwitch { public: - static Nd4jStatus processNode(const Node* node, StackFrame &frame); + static Nd4jStatus processNode(const Node* node, StackFrame &frame, const OptimizedGraph& graph); }; } // namespace graph diff --git a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp index b59f198f20bd..a23658b78025 100644 --- a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp @@ -34,7 +34,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicExecutor::processNode(const Node *node, StackFrame &frame) { +Nd4jStatus LogicExecutor::processNode(const Node *node, StackFrame &frame, const OptimizedGraph& graph) { switch (node->opNum()) { case sd::logic::While: return LogicWhile::processNode(node); @@ -43,7 +43,7 @@ Nd4jStatus LogicExecutor::processNode(const Node *node, StackFrame &frame) { case sd::logic::Conditional: return LogicConditional::processNode(node); case sd::logic::Switch: - return LogicSwitch::processNode(node, frame); + return LogicSwitch::processNode(node, frame, graph); case sd::logic::Return: return LogicReturn::processNode(node); case sd::logic::Expose: diff --git a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp index d4c1cde6e473..0afc83769158 100644 --- a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp +++ b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp @@ -25,7 +25,43 @@ namespace sd { namespace graph { -Nd4jStatus LogicSwitch::processNode(const Node* node, StackFrame &frame) { + +static void disableBranch(StackFrame &frame, VariableProxy &varSpace, const OptimizedGraph &graph, const Node* node) { + const auto &outputs = node->outputs(); + + // we're going to roll through all consumers + for (const auto &o:outputs) { + // now fetch disabled node + const auto &n = graph.getNodesMap().at(o.first); + + // edge case here: don't disable Merge node + if (n.opType() == OpType_LOGIC && n.opNum() == sd::logic::Merge) + continue; + + // disable each consumer + frame.disableNode(o.first); + + // do recursive magic + disableBranch(frame, varSpace, graph, &n); + } +} + +static void disableBranch(StackFrame &frame, VariableProxy &varSpace, const OptimizedGraph &graph, const Node* node, bool branchToDisable) { + const auto &outputs = node->outputs(); + int second = branchToDisable ? 1 : 0; + + for (const auto &o:outputs) { + if (o.second == second) { + frame.disableNode(o.first); + + const auto &n = graph.getNodesMap().at(o.first); + + disableBranch(frame, varSpace, graph, &n); + } + } +} + +Nd4jStatus LogicSwitch::processNode(const Node* node, StackFrame &frame, const OptimizedGraph& graph) { const auto &inputs = node->inputs(); const auto &outputs = node->outputs(); @@ -43,9 +79,11 @@ Nd4jStatus LogicSwitch::processNode(const Node* node, StackFrame &frame) { if (boolean->getNDArray()->e(0)) { // true branch varSpace.putVariable(std::pair{node->id(), 1}, *input->getNDArray()); + disableBranch(frame, varSpace, graph, node, false); } else { // false branch varSpace.putVariable(std::pair{node->id(), 0}, *input->getNDArray()); + disableBranch(frame, varSpace, graph, node, true); } return Status::OK(); From 37ef5e974f631c41d5b7fb8d0dc236f884b5f82c Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 12 Jun 2020 22:06:01 +0300 Subject: [PATCH 181/233] LogicSwitch draft Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/VariableProxy.h | 2 + libnd4j/include/graph/VariableSpace.h | 1 + libnd4j/include/graph/impl/Graph.cpp | 10 +- libnd4j/include/graph/impl/OptimizedGraph.cpp | 5 +- libnd4j/include/graph/impl/VariableProxy.cpp | 4 + libnd4j/include/graph/logic/LogicMerge.h | 2 +- .../graph/logic/impl/LogicExecutor.cpp | 2 +- .../include/graph/logic/impl/LogicMerge.cpp | 120 ++---------------- .../include/graph/logic/impl/LogicSwitch.cpp | 17 ++- .../layers_tests/GraphAnalysisTests.cpp | 36 ++---- 10 files changed, 58 insertions(+), 141 deletions(-) diff --git a/libnd4j/include/graph/VariableProxy.h b/libnd4j/include/graph/VariableProxy.h index 0b1c707a3272..6f17b29ac1ae 100644 --- a/libnd4j/include/graph/VariableProxy.h +++ b/libnd4j/include/graph/VariableProxy.h @@ -41,6 +41,8 @@ class SD_EXPORT VariableProxy : public VariableSpace { virtual const std::vector>& placeholders() const override; + virtual const MAP_IMPL, std::shared_ptr>& externalPaired() const; + virtual bool hasExternalVariable(int it) const override; virtual bool hasExternalVariable( const std::pair& pair) const override; diff --git a/libnd4j/include/graph/VariableSpace.h b/libnd4j/include/graph/VariableSpace.h index 82b24f52bb26..67e43a6d12f9 100644 --- a/libnd4j/include/graph/VariableSpace.h +++ b/libnd4j/include/graph/VariableSpace.h @@ -39,6 +39,7 @@ namespace sd { namespace graph { class SD_EXPORT VariableSpace { + friend class VariableProxy; protected: // stash is NOT cloned Stash _stash; diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index cd106678494c..5e32633cd43e 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -622,11 +622,17 @@ std::map Graph::execute( // fetch outputs from our VariableProxy std::map result; for (const auto &v : outputs) { - if (!proxy.hasVariable(v)) + // resolve string -> int dep + int id = -119; + if (_symbolicLookupTable.count(v) > 0) + id = _symbolicLookupTable.at(v); + + + if (!proxy.hasVariable(id)) throw unresolved_output_exception::build( "Requested output doesn't exist after execution", v); - auto var = proxy.getVariable(v); + auto var = proxy.getVariable(id); // TODO: we want to make sure ManagedDataBuffer doesn't leak here result[v] = *var->getNDArray(); diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index d650c80fafbc..b679e4940129 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -28,9 +28,7 @@ namespace graph { /////////////////////////////////////////////////////////////////// // move constructor -OptimizedGraph::OptimizedGraph(OptimizedGraph &&other) noexcept: _sortedGraph(std::move(other._sortedGraph)) { - -} +OptimizedGraph::OptimizedGraph(OptimizedGraph &&other) noexcept: _sortedGraph(std::move(other._sortedGraph)), _nodesMap(std::move(other._nodesMap)) { } /////////////////////////////////////////////////////////////////// // move assignment operator @@ -40,6 +38,7 @@ OptimizedGraph& OptimizedGraph::operator=(OptimizedGraph &&other) noexcept { return *this; _sortedGraph = std::move(other._sortedGraph); + _nodesMap = std::move(other._nodesMap); return *this; } diff --git a/libnd4j/include/graph/impl/VariableProxy.cpp b/libnd4j/include/graph/impl/VariableProxy.cpp index c67d30633035..809a0c1fa1ae 100644 --- a/libnd4j/include/graph/impl/VariableProxy.cpp +++ b/libnd4j/include/graph/impl/VariableProxy.cpp @@ -160,6 +160,10 @@ std::shared_ptr VariableProxy::putVariable( return _current.putVariable(pair, array); } +const MAP_IMPL, std::shared_ptr> &VariableProxy::externalPaired() const { + return _backed->_paired; +} + void VariableProxy::putVariable(const std::pair &pair, const std::shared_ptr &variable) { _current.putVariable(pair, variable); diff --git a/libnd4j/include/graph/logic/LogicMerge.h b/libnd4j/include/graph/logic/LogicMerge.h index 78d1ec3f60b7..3130f1a97354 100644 --- a/libnd4j/include/graph/logic/LogicMerge.h +++ b/libnd4j/include/graph/logic/LogicMerge.h @@ -30,7 +30,7 @@ namespace graph { class LogicMerge { public: - static Nd4jStatus processNode(const Node* node); + static Nd4jStatus processNode(const Node* node, StackFrame &frame, const OptimizedGraph& graph); }; } // namespace graph diff --git a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp index a23658b78025..22fc63ea726b 100644 --- a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp @@ -49,7 +49,7 @@ Nd4jStatus LogicExecutor::processNode(const Node *node, StackFrame &frame, const case sd::logic::Expose: return LogicExpose::processNode(node); case sd::logic::Merge: - return LogicMerge::processNode(node); + return LogicMerge::processNode(node, frame, graph); case sd::logic::LoopCond: return LogicLoopCond::processNode(node); case sd::logic::NextIteration: diff --git a/libnd4j/include/graph/logic/impl/LogicMerge.cpp b/libnd4j/include/graph/logic/impl/LogicMerge.cpp index 547722b271ba..b99a83426005 100644 --- a/libnd4j/include/graph/logic/impl/LogicMerge.cpp +++ b/libnd4j/include/graph/logic/impl/LogicMerge.cpp @@ -16,7 +16,7 @@ ******************************************************************************/ // -// Created by raver119 on 30.01.18. +// @author raver119@gmail.com // #include @@ -24,118 +24,26 @@ namespace sd { namespace graph { -Nd4jStatus LogicMerge::processNode(const Node *node) { - throw std::runtime_error("LogicMerge::processNode - not implemented yet"); - /* - // at merge node only one of inputs exist if that's just switch and other node -isn't LogicNextItration auto __variableSpace = graph->variableSpace(); auto -__flowPath = __variableSpace->flowPath(); - - // merge MUST have 2 inputs - auto inputAddr0 = node->input().at(0); - auto inputAddr1 = node->input().at(1); - - bool isWhile = false; - - // now we want to check if second input is NextIteration - if (graph->hasNode(inputAddr1.first)) { - auto secondNode = graph->nodeById(inputAddr1.first); - - // checking for NextIteration - if (secondNode->opType() == OpType_LOGIC && secondNode->opNum() == 80L) { - isWhile = true; - - // notifying NextIteration node for rewind index - secondNode->setRewindLayer(node->getLayer()); - secondNode->setRewindNode(node->id()); - } +Nd4jStatus LogicMerge::processNode(const Node *node, StackFrame &frame, const OptimizedGraph& graph) { + const auto &inputs = node->inputs(); + auto &varSpace = const_cast(frame.variableProxy()); + REQUIRE_TRUE(inputs.size() == 2, 0, "Merge: op expects exactly 2 inputs, but only %i defined", (int) inputs.size()); + if (frame.isDisabled(inputs[0].first) && frame.isDisabled(inputs[1].first)) { + REQUIRE_TRUE(false, 0, "Merge: only 1 input should be disabled, but got both of them down"); } - // FIXME: we don't need this check. Just last input should survive, IF it -exists if (isWhile){ - - if (node->getFrameId() >= 0) - __flowPath->markFrameActive(node->getFrameId(), true); - - bool hasVar = __variableSpace->hasVariable(inputAddr1); - if ( hasVar && __flowPath->wasExecuted(inputAddr1.first)) { - nd4j_debug("Node_%i: propagating second input\n", node->id()); - auto var = __variableSpace->getVariable(inputAddr1); - - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName().c_str(), node->id(), -0); - -// if (lvar->hasNDArray()) -// delete lvar->getNDArray(); - - auto array = var->getNDArray(); - - //array->printIndexedBuffer("propagated"); - - lvar->setNDArray(array); - lvar->markReadOnly(true); - - __flowPath->markExecuted(inputAddr1.first, false); + // we're getting first non-disable input and propagate it + const auto &p = frame.isDisabled(inputs[0].first) ? inputs[1] : inputs[0]; + REQUIRE_TRUE(frame.variableProxy().hasVariable(p), 0, "Merge: Variable [%i:%i] doesn't exist", p.first, p.second); - } else { - nd4j_debug("Node_%i: propagating first input\n", node->id()); - auto var = __variableSpace->getVariable(inputAddr0); - - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName().c_str(), node->id(), -0); - -// if (lvar->hasNDArray()) -// delete lvar->getNDArray(); - - auto array = var->getNDArray(); - lvar->setNDArray(array); - lvar->markReadOnly(true); - - - } - } else { - - // basically, first non-null variable is our target - for (int e = 0; e < node->input().size(); e++) { - auto inputAddr = node->input().at(e); - - if (__variableSpace->hasVariable(inputAddr)) { - auto var = __variableSpace->getVariable(inputAddr); - if (!var->hasNDArray() || -!__flowPath->isNodeActive(inputAddr.first)) continue; - - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName().c_str(), -node->id(), 0); - - if (lvar->hasNDArray()) - delete lvar->getNDArray(); - - auto array = var->getNDArray(); - lvar->setNDArray(array); - lvar->markReadOnly(true); - //lvar->markExternal(false);h - - break; - } - } - } + std::pair t(node->id(), 0); + auto array = varSpace.getVariable(p)->getNDArray().get(); + varSpace.putVariable(t, *array); return Status::OK(); - */ } + } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp index 0afc83769158..189fa4373f43 100644 --- a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp +++ b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp @@ -16,7 +16,7 @@ ******************************************************************************/ // -// Created by raver119 on 21.10.17. +// @author raver119@gmail.com // #include @@ -29,8 +29,20 @@ namespace graph { static void disableBranch(StackFrame &frame, VariableProxy &varSpace, const OptimizedGraph &graph, const Node* node) { const auto &outputs = node->outputs(); + // we're going to disable certain external variables, if they depend on a current disabled node + // FIXME: it can be done in a better way rather than O(n^2) + for (const auto &var: varSpace.externalPaired()) { + for (const auto &d: var.second->dependencies()) { + if (d.first == node->id()) + frame.disableNode(var.second->id()); + } + } + // we're going to roll through all consumers for (const auto &o:outputs) { + if (graph.getNodesMap().count(o.first) == 0) + throw std::runtime_error("pew-pew"); + // now fetch disabled node const auto &n = graph.getNodesMap().at(o.first); @@ -54,6 +66,9 @@ static void disableBranch(StackFrame &frame, VariableProxy &varSpace, const Opti if (o.second == second) { frame.disableNode(o.first); + if (graph.getNodesMap().count(o.first) == 0) + throw std::runtime_error("pew-pew"); + const auto &n = graph.getNodesMap().at(o.first); disableBranch(frame, varSpace, graph, &n); diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 2335228d0c7e..501c5e7bc84a 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -970,7 +970,7 @@ TEST_F(GraphAnalysisTests, optimizedGraph_cond1) { auto graph = Graph::fromFlatBuffers("resources/cond_true.fb"); const auto& optimized = graph.optimizedGraph(); - // graph.printOut(); + graph.printOut(); // we expect exactly 3 layers ASSERT_EQ(3, optimized.layers()); @@ -1014,31 +1014,11 @@ TEST_F(GraphAnalysisTests, optimizedGraph_cond1) { ASSERT_EQ(1, seq.length()); ASSERT_EQ(std::string("cond/Merge"), seq[0].node().name()); - // graph.execute(); - /* - some infor that would be useful for implementation - currently on optimization graph is passing next data - - Node name: cond/switch_f; ID: 11; Input: 9, 0; Operation type: 21; Operation - class: -1719689536 Node name: cond/switch_t; ID: 10; Input: 9, 1; Operation - type: 21; Operation class: -1719689536 Node name: cond/Switch; ID: 9; - Input: 1, 0; Operation type: 119; Operation class: -1719689536 Node name: - cond/Switch; ID: 9; Input: 6, 0; Operation type: 119; Operation class: - -1719689536 Node name: cond/Merge; ID: 8; Input: 5, 0; Operation type: - 119; Operation class: -1719689536 Node name: cond/Merge; ID: 8; Input: 7, - 0; Operation type: 119; Operation class: -1719689536 Node name: in_0/read; ID: - 6; Input: 1, 0; Operation type: 21; Operation class: -1719689536 Node name: - cond/LinSpace; ID: 7; Input: 2, 0; Operation type: 21; Operation class: - -1719689536 Node name: cond/LinSpace; ID: 7; Input: 3, 0; Operation type: 21; - Operation class: -1719689536 Node name: cond/LinSpace; ID: 7; Input: 4, 0; - Operation type: 21; Operation class: -1719689536 - - as it can be seen cond/LinSpace is not connected with any switch node(s) that - causes wrong results of optimization. also maybe to cover all conditional - operations will be need "Operation class", but this have to discovered deeper. - - All above is true for test_cond_2 - */ + auto res = graph.execute({}, {"cond/Merge"}); + ASSERT_EQ(1, res.size()); + auto arr = res["cond/Merge"]; + + ASSERT_EQ(NDArrayFactory::create({1.f, 2.f, 3.f, 4.f, 5.f}), arr); } TEST_F(GraphAnalysisTests, optimizedGraph_cond2) { @@ -1075,7 +1055,9 @@ TEST_F(GraphAnalysisTests, optimizedGraph_cond2) { ASSERT_EQ(1, seq.length()); ASSERT_EQ(std::string("cond/Merge"), seq[0].node().name()); - graph.execute(); + auto res = graph.execute({}, {"cond/Merge"}); + ASSERT_EQ(1, res.size()); + ASSERT_EQ(NDArrayFactory::create({1.f, 1.f, 1.f, 1.f, 1.f}), res["cond/Merge"]); } TEST_F(GraphAnalysisTests, optimizedGraph_while1) { From 4d02fc504dcef1e9e804b88d95f7c7244eb836b6 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 15 Jun 2020 08:09:31 +0300 Subject: [PATCH 182/233] weird printout Signed-off-by: raver119@gmail.com --- libnd4j/tests_cpu/layers_tests/OneOffTests.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index ea670bafaced..3694587e6c13 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -39,6 +39,7 @@ class OneOffTests : public testing::Test { TEST_F(OneOffTests, test_avg_pool_3d_1) { auto graph = Graph::fromFlatBuffers("./resources/avg_pooling3d.fb"); + graph.printOut(); graph.execute(); } From 0512e9c82e6c47881d953cc49603da657594283c Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 15 Jun 2020 08:28:20 +0300 Subject: [PATCH 183/233] Node fields Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/Node.h | 10 +++------- libnd4j/include/graph/impl/Node.cpp | 4 ---- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 327a4bbef42a..96eadbde7f82 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -47,9 +47,9 @@ class SD_EXPORT Node { ContextPrototype _protoContext; // these 2 fields are used for Logic ops only - OpType _opType; - OpClass _opClass; - Nd4jLong _opNum; + OpType _opType = OpType_GRAPH; + OpClass _opClass = OpClass_GRAPH; + Nd4jLong _opNum = 0; // Inputs are stored in format std::vector> _input; @@ -69,10 +69,6 @@ class SD_EXPORT Node { std::shared_ptr _customOp; - // each node can be active or inactive, if used with divergents, like IF - // statements - bool _active = true; - public: explicit Node(const sd::ops::DeclarableOp &op, const std::string &nodeName = {}, diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 93ffb856c010..3bd020ab8c70 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -589,7 +589,6 @@ Node::Node(const Node &other) noexcept { _hasExternalInputs = other._hasExternalInputs; _hasInternalOutputs = other._hasInternalOutputs; _hasInternalInputs = other._hasInternalInputs; - _active = other._active; _customOp = other._customOp; _protoContext = other._protoContext; @@ -612,7 +611,6 @@ Node &Node::operator=(const Node &other) noexcept { _hasExternalInputs = other._hasExternalInputs; _hasInternalOutputs = other._hasInternalOutputs; _hasInternalInputs = other._hasInternalInputs; - _active = other._active; _customOp = other._customOp; _protoContext = other._protoContext; @@ -636,7 +634,6 @@ Node::Node(Node &&other) noexcept { _hasExternalInputs = other._hasExternalInputs; _hasInternalOutputs = other._hasInternalOutputs; _hasInternalInputs = other._hasInternalInputs; - _active = other._active; _protoContext = std::move(other._protoContext); @@ -659,7 +656,6 @@ Node &Node::operator=(Node &&other) noexcept { _hasExternalInputs = other._hasExternalInputs; _hasInternalOutputs = other._hasInternalOutputs; _hasInternalInputs = other._hasInternalInputs; - _active = other._active; _protoContext = std::move(other._protoContext); From bb0d89876f4f0bac45fc0ee0aeead73475977971 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 15 Jun 2020 08:28:50 +0300 Subject: [PATCH 184/233] one more comment Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/Node.h | 1 + 1 file changed, 1 insertion(+) diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index 96eadbde7f82..b872c283c09c 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -61,6 +61,7 @@ class SD_EXPORT Node { std::vector> _dependencies; std::vector _stringDependencies; + // TODO: these fields should be removed // service state fields bool _hasExternalOutputs = false; bool _hasExternalInputs = false; From e7f3b48f1bbd46979a5460ea8d099a991eb9c805 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 15 Jun 2020 09:02:43 +0300 Subject: [PATCH 185/233] minor Node tweaks Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/impl/Node.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index 3bd020ab8c70..fe256f82fc18 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -104,7 +104,7 @@ Node::Node(const std::string &opName, const std::string &nodeName, bool Node::isDivergencePoint() { if (hasCustomOp()) { return _customOp->getOpDescriptor()->isDivergent(); - } else if (opType() == OpType_LOGIC && opNum() == 30) + } else if (opType() == OpType_LOGIC && opNum() == sd::logic::Switch) return true; else return false; @@ -135,7 +135,7 @@ const std::string &Node::name() const { return _name; } void Node::setName(const std::string &name) { _name = name; } void Node::pickInput(const std::pair &pair) { - _input.push_back(pair); + _input.emplace_back(pair); _protoContext.pickInput(pair); } @@ -159,7 +159,7 @@ void Node::pickInput(int inputId) { void Node::pickExternalOutput(int outputId) { std::pair pair(outputId, 0); - _output.push_back(pair); + _output.emplace_back(pair); _hasExternalOutputs = true; } From db8b49cb6eaf4e88e1cf0687e980b89ce7549465 Mon Sep 17 00:00:00 2001 From: Yurii Date: Mon, 15 Jun 2020 14:38:30 +0300 Subject: [PATCH 186/233] - change condition on variable to be input array in graph Signed-off-by: Yurii --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index b679e4940129..f21790a24d14 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -72,7 +72,7 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS for (int i = 0; i < inputs.size(); ++i) { - if (inputs[i].first >= _nodesMap.begin()->first) { // is op + if (_nodesMap.count(inputs[i].first) != 0) { // is op _nodesMap[inputs[i].first].pickOutput(p.first, inputs[i].second); if(_nodesMap[inputs[i].first].name().find("NextIteration") == std::string::npos) { workMap[inputs[i].first]._out.push_back(p.first); From 38294ff4ff5ee45c7c16588167270275e3666021 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 15 Jun 2020 15:09:50 +0300 Subject: [PATCH 187/233] NDArray ostream Signed-off-by: raver119@gmail.com --- libnd4j/include/array/NDArray.h | 10 ++++++---- libnd4j/include/array/impl/NDArray.cpp | 17 ++++++++++++----- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/libnd4j/include/array/NDArray.h b/libnd4j/include/array/NDArray.h index a8bf746727a2..07c31a842271 100644 --- a/libnd4j/include/array/NDArray.h +++ b/libnd4j/include/array/NDArray.h @@ -219,7 +219,7 @@ class SD_EXPORT NDArray { int _deviceId = AffinityManager::currentDeviceId(); template - std::string toStringValue(T value); + std::string toStringValue(T value) const; public: NDArray() = default; @@ -688,8 +688,8 @@ class SD_EXPORT NDArray { */ void printIndexedBuffer(const char* msg = nullptr, Nd4jLong limit = -1) const; - std::string asIndexedString(Nd4jLong limit = -1); - std::string asString(Nd4jLong limit = -1); + std::string asIndexedString(Nd4jLong limit = -1) const; + std::string asString(Nd4jLong limit = -1) const; /** * this method assigns values of given array to this one @@ -1772,7 +1772,9 @@ class SD_EXPORT NDArray { FORCEINLINE bool operator==(const NDArray& other) const; FORCEINLINE bool operator!=(const NDArray& other) const; -}; +}; // class NDArray + +std::ostream &operator<<(std::ostream &os, const NDArray &m); ////////////////////////////////////////////////////////////////////////// ///// IMLEMENTATION OF INLINE METHODS ///// diff --git a/libnd4j/include/array/impl/NDArray.cpp b/libnd4j/include/array/impl/NDArray.cpp index 854a5c56ae90..e073ef0534a1 100644 --- a/libnd4j/include/array/impl/NDArray.cpp +++ b/libnd4j/include/array/impl/NDArray.cpp @@ -907,6 +907,13 @@ NDArray::NDArray(const std::vector& shape, tickWriteHost(); syncToDevice(); } + + +std::ostream& operator<<(std::ostream &os, const NDArray &m) { + os << m.asIndexedString(); + return os; +} + ///////////////////////////////////////////////////////////////////////// NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, @@ -1060,7 +1067,7 @@ bool NDArray::isB() const { ////////////////////////////////////////////////////////////////////////// template -std::string NDArray::toStringValue(T value) { +std::string NDArray::toStringValue(T value) const { std::ostringstream os; // throw the value into the string stream os << value; @@ -1070,7 +1077,7 @@ std::string NDArray::toStringValue(T value) { ////////////////////////////////////////////////////////////////////////// template <> -std::string NDArray::toStringValue(float16 value) { +std::string NDArray::toStringValue(float16 value) const { std::ostringstream os; // throw the value into the string stream os << (float)value; @@ -1080,7 +1087,7 @@ std::string NDArray::toStringValue(float16 value) { ////////////////////////////////////////////////////////////////////////// template <> -std::string NDArray::toStringValue(bfloat16 value) { +std::string NDArray::toStringValue(bfloat16 value) const { std::ostringstream os; // throw the value into the string stream os << (float)value; @@ -1089,7 +1096,7 @@ std::string NDArray::toStringValue(bfloat16 value) { } ////////////////////////////////////////////////////////////////////////// -std::string NDArray::asIndexedString(Nd4jLong limit) { +std::string NDArray::asIndexedString(Nd4jLong limit) const { std::ostringstream os; os << "["; if (limit < 1 || limit > this->lengthOf()) limit = this->lengthOf(); @@ -1102,7 +1109,7 @@ std::string NDArray::asIndexedString(Nd4jLong limit) { } ////////////////////////////////////////////////////////////////////////// -std::string NDArray::asString(Nd4jLong limit) { +std::string NDArray::asString(Nd4jLong limit) const { std::ostringstream os; os << "["; if (limit < 1 || limit > this->lengthOf()) limit = this->lengthOf(); From 64e66062b72e030b476343bc05ab1e54bee73499 Mon Sep 17 00:00:00 2001 From: shugeo Date: Mon, 15 Jun 2020 16:55:39 +0300 Subject: [PATCH 188/233] Refactored printIndexedBuffer routine. Signed-off-by: shugeo --- libnd4j/include/array/NDArray.h | 6 + libnd4j/include/array/impl/NDArray.cpp | 135 ++++++++++++++++-- .../layers_tests/DeclarableOpsTests12.cpp | 12 +- 3 files changed, 133 insertions(+), 20 deletions(-) diff --git a/libnd4j/include/array/NDArray.h b/libnd4j/include/array/NDArray.h index a8bf746727a2..d37f5fa8022c 100644 --- a/libnd4j/include/array/NDArray.h +++ b/libnd4j/include/array/NDArray.h @@ -667,6 +667,12 @@ class SD_EXPORT NDArray { void printBuffer(const char* msg = nullptr, Nd4jLong limit = -1, const bool sync = true) const; + /** + * make strings for current ndarray - linear or structured by dimensions + * */ + std::string linearString(Nd4jLong limit = -1) const; + std::string indexedBufferString(Nd4jLong limit = -1) const; + /** * print element by element consequently in a way they (elements) are stored * in physical memory diff --git a/libnd4j/include/array/impl/NDArray.cpp b/libnd4j/include/array/impl/NDArray.cpp index 854a5c56ae90..b868b1411770 100644 --- a/libnd4j/include/array/impl/NDArray.cpp +++ b/libnd4j/include/array/impl/NDArray.cpp @@ -35,6 +35,7 @@ #include #include #include +#include namespace sd { @@ -1885,6 +1886,8 @@ void NDArray::printShapeInfo(const char* msg) const { fflush(stdout); } +static std::string formattedString(NDArray const* arr, int depth, int limit, std::stringstream& ss); + ////////////////////////////////////////////////////////////////////////// void NDArray::printBuffer(const char* msg, Nd4jLong limit, const bool sync) const { @@ -1967,6 +1970,41 @@ void NDArray::printLinearBuffer() const { printf("]\n"); fflush(stdout); } + + std::string NDArray::linearString(Nd4jLong limit) const { + syncToHost(); + + const auto ews = this->ews() > 0 ? this->ews() : 1; + const auto len = this->lengthOf(); + std::stringstream ss; + ss << "["; + + for (Nd4jLong e = 0; e < len; e++) { + if (e) + ss << ", "; + switch (this->dataType()) { + case sd::DataType::INT32: + ss << this->bufferAsT()[e * ews]; + break; + case sd::DataType::INT64: + ss << this->bufferAsT()[e * ews]; + break; + case sd::DataType::FLOAT32: + ss << std::setprecision(6) << this->bufferAsT()[e * ews]; + break; + case sd::DataType::DOUBLE: + ss << std::setprecision(6) << this->bufferAsT()[e * ews]; + break; + //case sd::DataType::UTF8: ss << this->bufferAsT()[e * ews]; break; + default: + throw std::invalid_argument("NDArray::linearString: not implemented yet for this data type !"); + } + + } + ss << "]"; + return ss.str(); + } + ////////////////////////////////////////////////////////////////////////// static void printFormatted(NDArray const* arr, int depth, int limit) { if (arr->rankOf() == 1) { @@ -2030,38 +2068,107 @@ static void printFormatted(NDArray const* arr, int depth, int limit) { } } + static std::string formattedString(NDArray const* arr, int depth, int limit, std::stringstream& ss) { + + if (arr->rankOf() == 1) { + ss << "[ "; + for (Nd4jLong i = 0; i < arr->lengthOf(); ++i) { + if (arr->isR()) + ss << arr->e(i); + else if (arr->isZ()) + ss << arr->e(i); + else if (arr->isB()) + ss << arr->e(i) ? "true" : "false"; + else if (arr->isS()) { + ss << "\"" << arr->e(i).c_str() << "\""; + } + } + ss << "]"; + } else if (arr->rankOf() == 2) { + Nd4jLong rows = arr->rows(); + Nd4jLong cols = arr->columns(); + //memset(padding, ' ', depth); + ss << "["; + for (Nd4jLong row = 0; row < rows; ++row) { + if (row && depth > 0) + ss << std::setfill(' ') << std::setw(depth); + ss << "["; + Nd4jLong colLimit = cols > limit ? cols : limit; + for (Nd4jLong col = 0; col < colLimit; ++col) { + if (col) ss << (", "); + if (arr->isR()) + ss << arr->e(row, col); + else if (arr->isZ()) + ss << arr->e(row, col); + else if (arr->isB()) + ss << arr->e(row, col) ? "true" : "false"; + else if (arr->isS()) { + ss << "\"" << arr->e(row * cols + col).c_str() <<"\""; + } + } + if (row < rows - 1) + ss << "]" << std::endl; + else + ss << "]"; + } + ss << "]"; + } else { + // std::unique_ptr arrs(arr->allTensorsAlongDimension({0})); + size_t restCount = 2; + ss << "["; + restCount = ShapeUtils::getNumOfSubArrs(arr->shapeInfo(), {0}); + for (size_t arrIndex = 0; arrIndex < restCount; ++arrIndex) { + NDArray subArr = (*arr)(arrIndex, {0}); + formattedString(&subArr, depth + 1, limit, ss); + if (arrIndex < restCount - 1) { + for (Nd4jLong i = 1; i < arr->rankOf(); ++i) printf("\n"); + for (Nd4jLong i = 0; i < depth - 2; ++i) printf(" "); + } + } + ss << "]"; + } + return ss.str(); + } + +////////////////////////////////////////////////////////////////////////// + void NDArray::printIndexedBuffer(const char* msg, Nd4jLong limit) const { + auto indexedString = indexedBufferString(limit); + if (msg) + printf("%s:\n%s\n", msg, indexedString.c_str()); + else + printf("%s\n", indexedString.c_str()); + fflush(stdout); + } ////////////////////////////////////////////////////////////////////////// -void NDArray::printIndexedBuffer(const char* msg, Nd4jLong limit) const { +std::string NDArray::indexedBufferString(Nd4jLong limit) const { syncToHost(); - + std::string output; Nd4jLong rank = this->rankOf(); bool rowFlag = (rank < 2) || (rank == 2 && this->sizeAt(0) == 1); - if (msg) printf("%s: ", msg); - if (this->isEmpty()) { - printf("Empty\n"); + return std::string("Empty"); } else if (this->rankOf() == 0) { + std::stringstream ss; if (this->isZ()) - printf("%lld\n", this->e(0)); + ss << this->e(0); else if (this->isR()) - printf("%f\n", this->e(0)); + ss << this->e(0); else if (this->isB()) { - printf("%s\n", this->e(0) ? "true" : "false"); + ss << this->e(0) ? "true" : "false"; } else if (this->isS()) { // todo do we need this // printf("\"%lld\"\n", this->getOffset(e)); - printf("\"%s\"\n", this->e(0).c_str()); + ss << "\"" << this->e(0) << "\n"; } + return ss.str(); } else if (rowFlag && ews() == 1) - printBuffer(nullptr, limit); + return linearString(limit); else { - if (msg) printf("\n"); - printFormatted(this, 1, limit); - printf("\n"); + std::stringstream ss; + return formattedString(this, 1, limit, ss); } - fflush(stdout); } ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 841623866364..fcab5d25cc85 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -3035,12 +3035,12 @@ TEST_F(DeclarableOpsTests12, QR_Test_1) { ASSERT_EQ(res.status(), ND4J_STATUS_OK); auto q = res.at(0); auto r = res.at(1); - // q->printIndexedBuffer("Orthogonal 5x5"); - // expQ.printBuffer("Orthogonal Exp"); - // r->printIndexedBuffer("Upper triangular 5x3"); - // expR.printBuffer("Upper triangular Exp"); - // q->printShapeInfo("Q shape"); - // r->printShapeInfo("R shape"); + q.printIndexedBuffer("Orthogonal 5x5"); + expQ.printBuffer("Orthogonal Exp"); + r.printIndexedBuffer("Upper triangular 5x3"); + expR.printBuffer("Upper triangular Exp"); + q.printShapeInfo("Q shape"); + r.printShapeInfo("R shape"); sd::ops::matmul opMul; auto res2 = opMul.evaluate({&q, &r}); // MmulHelper::matmul(q, r, &in, false, false); From e5ba9c9d5c16b43d95d39a2860db6f8d32d341f6 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 15 Jun 2020 17:02:34 +0300 Subject: [PATCH 189/233] - Node::_axis should be copied/moved too - one test fixed Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/impl/ContextPrototype.cpp | 4 ++++ libnd4j/tests_cpu/layers_tests/OneOffTests.cpp | 17 +++++++---------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/libnd4j/include/graph/impl/ContextPrototype.cpp b/libnd4j/include/graph/impl/ContextPrototype.cpp index 05fb84b8c9dc..8f333b8f32ad 100644 --- a/libnd4j/include/graph/impl/ContextPrototype.cpp +++ b/libnd4j/include/graph/impl/ContextPrototype.cpp @@ -178,6 +178,7 @@ ContextPrototype::ContextPrototype(const ContextPrototype &other) noexcept { _bArgs = other._bArgs; _dArgs = other._dArgs; _name = other._name; + _axis = other._axis; _nodeId = other._nodeId; _isInplace = other._isInplace; @@ -200,6 +201,7 @@ ContextPrototype &ContextPrototype::operator=( _bArgs = other._bArgs; _dArgs = other._dArgs; _name = other._name; + _axis = other._axis; _nodeId = other._nodeId; _isInplace = other._isInplace; @@ -221,6 +223,7 @@ ContextPrototype::ContextPrototype(ContextPrototype &&other) noexcept { _bArgs = std::move(other._bArgs); _dArgs = std::move(other._dArgs); _name = std::move(other._name); + _axis = std::move(other._axis); _nodeId = other._nodeId; _isInplace = other._isInplace; @@ -243,6 +246,7 @@ ContextPrototype &ContextPrototype::operator=( _bArgs = std::move(other._bArgs); _dArgs = std::move(other._dArgs); _name = std::move(other._name); + _axis = std::move(other._axis); _nodeId = other._nodeId; _isInplace = other._isInplace; diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index 3694587e6c13..6779c40dedb3 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -338,30 +338,27 @@ TEST_F(OneOffTests, test_identity_n_2) { } TEST_F(OneOffTests, test_non2d_1) { - // if (1 > 0) - // throw std::runtime_error("Test not implemented yet"); - - auto e = NDArrayFactory::create('c', {1, 1}, {5.42746449f}); + auto e = NDArrayFactory::create('c', {1, 2}, {2.07706356f, 2.66380072f}); auto graph = Graph::fromFlatBuffers("./resources/non2d_1.fb"); + graph.printOut(); graph.execute(); - ASSERT_TRUE(graph.variableSpace().hasVariable(3)); + ASSERT_TRUE(graph.variableSpace().hasVariable(6)); - auto z = graph.variableSpace().getVariable(3)->getNDArray(); + auto z = graph.variableSpace().getVariable(6)->getNDArray(); ASSERT_TRUE(z != nullptr); ASSERT_EQ(e, *z); } TEST_F(OneOffTests, test_reduce_all_1) { - auto e = - NDArrayFactory::create('c', {1, 4}, {true, false, false, false}); + auto e = NDArrayFactory::create('c', {1, 4}, {true, false, false, false}); - auto graph = - Graph::fromFlatBuffers("./resources/reduce_all_rank2_d0_keep.fb"); + auto graph = Graph::fromFlatBuffers("./resources/reduce_all_rank2_d0_keep.fb"); + graph.printOut(); graph.execute(); ASSERT_TRUE(graph.variableSpace().hasVariable(1)); From 9dba7fbadf76a8be94f878137824bf26b61041ec Mon Sep 17 00:00:00 2001 From: shugeo Date: Mon, 15 Jun 2020 17:12:04 +0300 Subject: [PATCH 190/233] Corrected formattedString routine. Signed-off-by: shugeo --- libnd4j/include/array/impl/NDArray.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libnd4j/include/array/impl/NDArray.cpp b/libnd4j/include/array/impl/NDArray.cpp index ef2eb15b72d6..54d230e9f51d 100644 --- a/libnd4j/include/array/impl/NDArray.cpp +++ b/libnd4j/include/array/impl/NDArray.cpp @@ -2098,7 +2098,7 @@ static void printFormatted(NDArray const* arr, int depth, int limit) { ss << "["; for (Nd4jLong row = 0; row < rows; ++row) { if (row && depth > 0) - ss << std::setfill(' ') << std::setw(depth); + ss << std::setfill(' ') << std::setw(depth) << ' '; ss << "["; Nd4jLong colLimit = cols > limit ? cols : limit; for (Nd4jLong col = 0; col < colLimit; ++col) { From eceafbb38c85d24bfa6becabecdd9c82bd079dbb Mon Sep 17 00:00:00 2001 From: shugeo Date: Mon, 15 Jun 2020 17:25:13 +0300 Subject: [PATCH 191/233] Refactored formattedString routine. Signed-off-by: shugeo --- libnd4j/include/array/impl/NDArray.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libnd4j/include/array/impl/NDArray.cpp b/libnd4j/include/array/impl/NDArray.cpp index 54d230e9f51d..ee1e709b2a8c 100644 --- a/libnd4j/include/array/impl/NDArray.cpp +++ b/libnd4j/include/array/impl/NDArray.cpp @@ -2104,7 +2104,7 @@ static void printFormatted(NDArray const* arr, int depth, int limit) { for (Nd4jLong col = 0; col < colLimit; ++col) { if (col) ss << (", "); if (arr->isR()) - ss << arr->e(row, col); + ss << std::setw(12) << std::setprecision(6) << arr->e(row, col); else if (arr->isZ()) ss << arr->e(row, col); else if (arr->isB()) From 61db9e81732f9afe8fedb582f080c748376a6671 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 15 Jun 2020 17:31:40 +0300 Subject: [PATCH 192/233] other method should be used in << operator Signed-off-by: raver119@gmail.com --- libnd4j/include/array/impl/NDArray.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libnd4j/include/array/impl/NDArray.cpp b/libnd4j/include/array/impl/NDArray.cpp index ef2eb15b72d6..90aacb7619ab 100644 --- a/libnd4j/include/array/impl/NDArray.cpp +++ b/libnd4j/include/array/impl/NDArray.cpp @@ -911,7 +911,7 @@ NDArray::NDArray(const std::vector& shape, std::ostream& operator<<(std::ostream &os, const NDArray &m) { - os << m.asIndexedString(); + os << m.indexedBufferString(); return os; } From ee730f02573b956023f855f9cc79064e160b5e58 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Wed, 17 Jun 2020 09:56:18 +0300 Subject: [PATCH 193/233] nested while test Signed-off-by: raver119@gmail.com --- .../declarable/helpers/cpu/randomShuffle.cpp | 4 +-- .../layers_tests/DeclarableOpsTests5.cpp | 26 +++++++++---------- .../layers_tests/GraphAnalysisTests.cpp | 6 +++++ libnd4j/tests_cpu/layers_tests/RNGTests.cpp | 16 ++++++------ 4 files changed, 29 insertions(+), 23 deletions(-) diff --git a/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp b/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp index f1e87eaecc62..15dbc2fa0fb8 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp @@ -160,7 +160,7 @@ static void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGen for(int i = firstDim - 1; i > 0; --i) { const int j = rng.relativeInt(i) % (i + 1); if(i != j) - subArrsList.at(i)->swapUnsafe(*subArrsList.at(j)); + subArrsList.at(i).swapUnsafe(subArrsList.at(j)); } } else { @@ -177,7 +177,7 @@ static void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGen auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; ++i) - subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i])); + subArrsListOut.at(i).assign(subArrsListIn.at(indices[i])); }; samediff::Threads::parallel_for(func, 0, firstDim); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index 7fd376169353..61dc37514c5c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -1726,7 +1726,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test2) { auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(output->equalsTo(exp1)); + ASSERT_TRUE(output.equalsTo(exp1)); } ////////////////////////////////////////////////////////////////////// @@ -1763,8 +1763,8 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test4) { auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2) || output->equalsTo(exp3) - || output->equalsTo(exp4) || output->equalsTo(exp5) || output->equalsTo(exp6)); + ASSERT_TRUE(output.equalsTo(exp1) || output.equalsTo(exp2) || output.equalsTo(exp3) + || output.equalsTo(exp4) || output.equalsTo(exp5) || output.equalsTo(exp6)); } ////////////////////////////////////////////////////////////////////// @@ -1781,11 +1781,11 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test5) { // ASSERT_TRUE(!output->equalsTo(input)); bool hasDublicates = false; - for(int i = 0; i < output->lengthOf() - 1; ++i) - for(int j = i+1; j < output->lengthOf(); ++j) - if(output->t(i) == output->t(j)) { + for(int i = 0; i < output.lengthOf() - 1; ++i) + for(int j = i+1; j < output.lengthOf(); ++j) + if(output.t(i) == output.t(j)) { hasDublicates = true; - i = output->lengthOf(); + i = output.lengthOf(); break;} ASSERT_TRUE(!hasDublicates); } @@ -1803,11 +1803,11 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test6) { // ASSERT_TRUE(!output->equalsTo(input)); bool hasDublicates = false; - for(int i = 0; i < output->lengthOf() - 1; ++i) - for(int j = i+1; j < output->lengthOf(); ++j) - if(output->t(i) == output->t(j)) { + for(int i = 0; i < output.lengthOf() - 1; ++i) + for(int j = i+1; j < output.lengthOf(); ++j) + if(output.t(i) == output.t(j)) { hasDublicates = true; - i = output->lengthOf(); + i = output.lengthOf(); break; } ASSERT_TRUE(!hasDublicates); @@ -1824,10 +1824,10 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test7) { auto output = results.at(0); // output->printBuffer(); ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(!output->equalsTo(input)); + ASSERT_TRUE(!output.equalsTo(input)); auto vec1 = input.getBufferAsVector(); - auto vec2 = output->getBufferAsVector(); + auto vec2 = output.getBufferAsVector(); std::sort(vec2.begin(), vec2.end()); ASSERT_TRUE(std::equal(vec1.begin(), vec1.end(), vec2.begin())); } diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 501c5e7bc84a..a3011a7f4404 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -1065,3 +1065,9 @@ TEST_F(GraphAnalysisTests, optimizedGraph_while1) { const auto& optimized = graph.optimizedGraph(); // graph.printOut(); } + +TEST_F(GraphAnalysisTests, optimizedGraph_nested_while_1) { + auto graph = Graph::fromFlatBuffers("resources/simplewhile_nested.fb"); + const auto& optimized = graph.optimizedGraph(); + graph.printOut(); +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index b857c8917625..224414af35a9 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -1062,12 +1062,12 @@ TEST_F(RNGTests, Test_GammaDistribution_4) { ASSERT_FALSE(exp0.equalsTo(z)); sd::ops::reduce_mean testOps1; sd::ops::reduce_variance testOps2; - auto testRes1 = testOps1.evaluate({z}); - auto testRes2 = testOps2.evaluate({z}); + auto testRes1 = testOps1.evaluate({&z}); + auto testRes2 = testOps2.evaluate({&z}); // testRes1[0]->printBuffer("Mean (expected 1.0)"); // testRes2[0]->printBuffer("Variance (expected 0.5)"); - ASSERT_NEAR(testRes1[0]->t(0), 1.0f, 0.01); - ASSERT_NEAR(testRes2[0]->t(0), 0.5f, 0.02); + ASSERT_NEAR(testRes1[0].t(0), 1.0f, 0.01); + ASSERT_NEAR(testRes2[0].t(0), 0.5f, 0.02); } TEST_F(RNGTests, Test_GammaDistribution_5) { @@ -1090,12 +1090,12 @@ TEST_F(RNGTests, Test_GammaDistribution_5) { // z->printIndexedBuffer("Gamma distributed"); sd::ops::reduce_mean testOps1; sd::ops::reduce_variance testOps2; - auto testRes1 = testOps1.evaluate({z}); - auto testRes2 = testOps2.evaluate({z}); + auto testRes1 = testOps1.evaluate({&z}); + auto testRes2 = testOps2.evaluate({&z}); // testRes1[0]->printBuffer("Mean (expected 0.1)"); // testRes2[0]->printBuffer("Variance (expected 0.05)"); - ASSERT_NEAR(testRes1[0]->t(0), 0.1f, 0.02); - ASSERT_NEAR(testRes2[0]->t(0), 0.05f, 0.02); + ASSERT_NEAR(testRes1[0].t(0), 0.1f, 0.02); + ASSERT_NEAR(testRes2[0].t(0), 0.05f, 0.02); } TEST_F(RNGTests, Test_UniformDistribution_04) { From 0d892d3221d2e93c59e9b79746fb8944495aac9c Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 18 Jun 2020 08:42:53 +0300 Subject: [PATCH 194/233] one more test passes Signed-off-by: raver119@gmail.com --- libnd4j/tests_cpu/layers_tests/OneOffTests.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index 6779c40dedb3..0e53696e0201 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -218,8 +218,6 @@ TEST_F(OneOffTests, test_tensor_array_2) { } TEST_F(OneOffTests, test_tensor_array_3) { - if (1 > 0) throw std::runtime_error("This test crashes"); - auto e = NDArrayFactory::create( 'c', {3, 2, 3}, {7, 2, 9, 4, 3, 3, 8, 7, 0, 0, 6, 8, 7, 9, 0, 1, 1, 4}); From aa711611280e9727d21d83fedaacee3ecfc56a64 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 18 Jun 2020 09:04:45 +0300 Subject: [PATCH 195/233] one more test fixed Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/Context.h | 3 +-- libnd4j/include/graph/Variable.h | 19 ++++++++----------- libnd4j/include/graph/impl/Context.cpp | 15 +++++++++++++++ libnd4j/include/graph/impl/Variable.cpp | 8 +++++--- .../declarable/generic/parity_ops/expose.cpp | 6 +++--- .../layers_tests/DeclarableOpsTests1.cpp | 10 +++------- 6 files changed, 35 insertions(+), 26 deletions(-) diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index 8a36aefac87e..f6cffd6a7b7f 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -149,8 +149,7 @@ class SD_EXPORT Context : public sd::graph::ContextPrototype { void pushNDArrayToVariableSpace(const std::pair &pair, const NDArray &array); - void pushNDArrayListToVariableSpace(int nodeId, int index, - std::shared_ptr list); + void pushNDArrayListToVariableSpace(int nodeId, int index, std::shared_ptr list); void pushNDArrayListToVariableSpace(int nodeId, int index, const NDArrayList &list, bool track = true); diff --git a/libnd4j/include/graph/Variable.h b/libnd4j/include/graph/Variable.h index bd2d8634d799..9ce96a6cc768 100644 --- a/libnd4j/include/graph/Variable.h +++ b/libnd4j/include/graph/Variable.h @@ -61,7 +61,6 @@ class SD_EXPORT Variable { protected: int _id = 0; int _index = 0; - std::shared_ptr _ndarray; std::string _name; std::vector _shape; @@ -72,6 +71,8 @@ class SD_EXPORT Variable { bool _placeholder = false; bool _removable = true; + // actual content + std::shared_ptr _ndarray; std::shared_ptr _list; VariableType _variableType = VariableType::NDARRAY; @@ -80,16 +81,12 @@ class SD_EXPORT Variable { std::vector _stringDependencies; public: - explicit Variable(bool placeHolder, DataType dataType = DataType::ANY, - const std::vector &shape = {}); - explicit Variable(const sd::NDArray &array, const std::string &name, int id, - int idx = 0); - explicit Variable(std::shared_ptr array, const std::string &name, - int id, int idx = 0); - explicit Variable(std::shared_ptr array, - const char *name = nullptr); - explicit Variable(const NDArrayList &arrayList, const std::string &name, - int id, int idx = 0); + explicit Variable(bool placeHolder, DataType dataType = DataType::ANY, const std::vector &shape = {}); + explicit Variable(const sd::NDArray &array, const std::string &name, int id, int idx = 0); + explicit Variable(std::shared_ptr array, const std::string &name, int id, int idx = 0); + explicit Variable(std::shared_ptr array, const std::string &name, int id, int idx = 0); + explicit Variable(std::shared_ptr array, const char *name = nullptr); + explicit Variable(const NDArrayList &arrayList, const std::string &name, int id, int idx = 0); explicit Variable(); #ifndef __JAVACPP_HACK__ diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index 35ec8d38eb2f..1b92e6f8d3de 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -233,6 +233,21 @@ void Context::pushNDArrayListToVariableSpace(int nodeId, int index, pushNDArrayListToVariableSpace(pair, list, track); } + +void Context::pushNDArrayListToVariableSpace(int nodeId, int index, + std::shared_ptr list) { + std::pair pair(nodeId, index); + if (!_variableSpace->hasVariable(pair)) { + auto var = std::make_shared(); + var->setId(pair.first, pair.second); + var->setNDArrayList(list); + _variableSpace->putVariable(pair, var); + } else { + auto var = _variableSpace->getVariable(pair); + var->setNDArrayList(list); + } +} + void Context::pushNDArrayListToVariableSpace(const std::pair &pair, const NDArrayList &list, bool track) { diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index 40eeb3df150e..1b13a6c2cdc3 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -28,14 +28,16 @@ namespace sd { namespace graph { -Variable::Variable(const NDArrayList &arrayList, const std::string &name, - int id, int idx) { - _list = std::make_shared(arrayList); +Variable::Variable(const NDArrayList &arrayList, const std::string &name, int id, int idx) + : Variable(std::make_shared(arrayList), name, id, idx) { } +Variable::Variable(std::shared_ptr list, const std::string &name, int id, int idx) { + _list = list; if (!name.empty()) _name = name; _id = id; _index = idx; + _variableType = VariableType::ARRAY_LIST; } Variable::Variable(const NDArray &array, const std::string &name, int id, diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp index b52ed1d363b6..b3b2689046fd 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/expose.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -15,7 +16,7 @@ ******************************************************************************/ // -// Created by raver119 on 12/11/17. +// @author raver119@gmail.com // #include @@ -35,8 +36,7 @@ CUSTOM_OP_IMPL(expose, -1, -1, true, 0, 0) { if (!var->hasNDArrayList()) { auto list = inVar->getNDArrayList(); - // block.pushNDArrayListToVariableSpace(block.nodeId(), e, list); - throw std::runtime_error("Expose - not implemented yet"); + block.pushNDArrayListToVariableSpace(block.nodeId(), e, list); } } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 6a20db6a6677..491b369f5b01 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -3510,12 +3510,8 @@ TEST_F(DeclarableOpsTests1, Test_Expose_1) { } TEST_F(DeclarableOpsTests1, Test_Expose_2) { - if (1 > 0) throw std::runtime_error("Test not implemented yet"); - - auto list = new NDArrayList(0, true); - - auto var = std::make_shared(NDArray(), "arraylist", -1, 0); - // var->setNDArrayList(list); + auto list = std::make_shared(0, true); + auto var = std::make_shared(list, "arraylist", -1, 0); VariableSpace variableSpace; variableSpace.putVariable(-1, var); @@ -3535,7 +3531,7 @@ TEST_F(DeclarableOpsTests1, Test_Expose_2) { auto list1 = var1->getNDArrayList(); - ASSERT_TRUE(list == list1.get()); + ASSERT_TRUE(list.get() == list1.get()); } TEST_F(DeclarableOpsTests1, Test_Release) { From 638160bcc73da91268fb73994426319465f3ec75 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 18 Jun 2020 09:13:09 +0300 Subject: [PATCH 196/233] one outdated test removed Signed-off-by: raver119@gmail.com --- .../tests_cpu/layers_tests/OneOffTests.cpp | 97 ------------------- 1 file changed, 97 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index 0e53696e0201..b587e256a846 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -55,22 +55,6 @@ TEST_F(OneOffTests, test_non2d_0A_1) { graph.execute(); } -/* -TEST_F(OneOffTests, test_assert_scalar_float32_1) { - sd::ops::Assert op; - sd::ops::identity op1; - sd::ops::noop op2; - auto graph = Graph::fromFlatBuffers("./resources/scalar_float32.fb"); - - ASSERT_TRUE(graph != nullptr); - - graph->printOut(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - delete graph; -}*/ - TEST_F(OneOffTests, test_assert_scalar_float32_2) { sd::ops::Assert op; sd::ops::identity op1; @@ -97,87 +81,6 @@ TEST_F(OneOffTests, test_pad_1D_1) { ASSERT_EQ(e, *z); } -/* -TEST_F(OneOffTests, test_scatter_nd_update_1) { - - auto e = NDArrayFactory::create('c', {10, 7}, -{1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.20446908f, 0.37918627f, 0.99792874f, -0.71881700f, 0.18677747f, 0.78299069f, 0.55216062f, 0.40746713f, 0.92128086f, -0.57195139f, 0.44686234f, 0.30861020f, 0.31026053f, 0.09293187f, - 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, 0.95073712f, 0.45613325f, 0.95149803f, 0.88341522f, 0.54366302f, -0.50060666f, 0.39031255f, - 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); - - auto graph = Graph::fromFlatBuffers("./resources/scatter_nd_update.fb"); - ASSERT_TRUE(graph != nullptr); - - graph->printOut(); - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - - ASSERT_TRUE(graph->variableSpace()->hasVariable(6)); - - auto z = graph->variableSpace()->getVariable(6)->getNDArray(); - ASSERT_TRUE(z != nullptr); - - z->printIndexedBuffer("z"); - - ASSERT_EQ(e, *z); - - delete graph; -} - */ - -TEST_F(OneOffTests, test_conv2d_nhwc_failed_1) { - auto e = NDArrayFactory::create( - 'c', {1, 5, 5, 6}, - {0.55744928f, 0.76827729f, 1.09401524f, 0.00000000f, 0.00000000f, - 0.00000000f, 0.56373537f, 0.90029907f, 0.78997850f, 0.00000000f, - 0.00000000f, 0.00000000f, 0.14252824f, 0.95961076f, 0.87750554f, - 0.00000000f, 0.00000000f, 0.00000000f, 0.44874173f, 0.99537718f, - 1.17154264f, 0.00000000f, 0.00000000f, 0.00000000f, 0.60377145f, - 0.79939061f, 0.56031001f, 0.00000000f, 0.00000000f, 0.00000000f, - 0.52975273f, 0.90678585f, 0.73763013f, 0.00000000f, 0.00000000f, - 0.00000000f, 0.22146404f, 0.82499605f, 0.47222072f, 0.00000000f, - 0.00000000f, 0.00000000f, 0.42772964f, 0.39793295f, 0.71436501f, - 0.00000000f, 0.00000000f, 0.00000000f, 0.48836520f, 1.01658893f, - 0.74419701f, 0.00000000f, 0.00000000f, 0.00000000f, 0.78984612f, - 0.94083673f, 0.83841157f, 0.00000000f, 0.00000000f, 0.00000000f, - 0.40448499f, 0.67732805f, 0.75499672f, 0.00000000f, 0.00000000f, - 0.00000000f, 0.43675962f, 0.79476535f, 0.72976631f, 0.00000000f, - 0.00000000f, 0.00000000f, 0.58808053f, 0.65222591f, 0.72552216f, - 0.00000000f, 0.00000000f, 0.00000000f, 0.37445742f, 1.22581339f, - 1.05341125f, 0.00000000f, 0.00000000f, 0.00000000f, 0.30095795f, - 0.59941679f, 0.63323414f, 0.00000000f, 0.00000000f, 0.00000000f, - 0.24199286f, 1.02546394f, 0.69537812f, 0.00000000f, 0.00000000f, - 0.00000000f, 0.23628944f, 0.90791851f, 1.01209974f, 0.00000000f, - 0.00000000f, 0.00000000f, 0.62740159f, 0.56518674f, 0.76692569f, - 0.00000000f, 0.00000000f, 0.00000000f, 0.13327584f, 0.32628393f, - 0.10280430f, 0.00000000f, 0.00000000f, 0.00000000f, 0.42691272f, - 0.25625113f, 0.30524066f, 0.00000000f, 0.00000000f, 0.00000000f, - 0.17797673f, 0.84179950f, 0.80061519f, 0.00000000f, 0.00000000f, - 0.00000000f, 0.00199084f, 0.51838887f, 0.43932241f, 0.00000000f, - 0.00000000f, 0.00000000f, 0.16684581f, 0.50822425f, 0.48668745f, - 0.00000000f, 0.00000000f, 0.00000000f, 0.16749343f, 0.93093169f, - 0.86871749f, 0.00000000f, 0.00000000f, 0.00000000f, 0.17486368f, - 0.44460732f, 0.44499981f, 0.00000000f, 0.00000000f, 0.00000000f}); - - auto graph = Graph::fromFlatBuffers( - "./resources/channels_last_b1_k2_s1_d1_SAME_crelu.fb"); - - graph.execute(); - - ASSERT_TRUE(graph.variableSpace().hasVariable(9)); - - auto z = graph.variableSpace().getVariable(9)->getNDArray(); - ASSERT_TRUE(z != nullptr); - - ASSERT_EQ(e, *z); -} TEST_F(OneOffTests, test_tensor_array_1) { auto e = From c89220d3f81748c9d67b2a09656de8684555dc26 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 18 Jun 2020 09:42:38 +0300 Subject: [PATCH 197/233] few outdated tests removed Signed-off-by: raver119@gmail.com --- .../tests_cpu/layers_tests/OneOffTests.cpp | 50 ------------------- 1 file changed, 50 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp index b587e256a846..4f630ffcc826 100644 --- a/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OneOffTests.cpp @@ -169,54 +169,6 @@ TEST_F(OneOffTests, test_assert_4) { ASSERT_EQ(e, *z); } -// TEST_F(OneOffTests, test_cond_true_1) { -// auto e = NDArrayFactory::create('c', {5}, -// {1.f, 2.f, 3.f, 4.f, 5.f}); - -// auto graph = Graph::fromFlatBuffers("./resources/cond_true.fb"); -// ASSERT_TRUE(graph != nullptr); - -// graph->printOut(); - -// Nd4jStatus status = GraphExecutioner::execute(graph); -// ASSERT_EQ(Status::OK(), status); -// ASSERT_TRUE(graph->variableSpace()->hasVariable(6)); - -// auto z = graph->variableSpace()->getVariable(6)->getNDArray(); -// ASSERT_TRUE(z != nullptr); - -// z->printIndexedBuffer("z buffer"); - -// ASSERT_EQ(e, *z); - -// delete graph; -// } - -/* -TEST_F(OneOffTests, test_cond_false_1) { - auto e = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); - - auto graph = Graph::fromFlatBuffers("./resources/cond_false.fb"); - ASSERT_TRUE(graph != nullptr); - - graph->printOut(); - - - Nd4jStatus status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->variableSpace()->hasVariable(6)); - - auto z = graph->variableSpace()->getVariable(6)->getNDArray(); - ASSERT_TRUE(z != nullptr); - - z->printIndexedBuffer("z buffer"); - - ASSERT_EQ(e, *z); - - delete graph; -} -*/ - TEST_F(OneOffTests, test_identity_n_2) { auto e = NDArrayFactory::create('c', {2, 3}, @@ -243,7 +195,6 @@ TEST_F(OneOffTests, test_non2d_1) { auto graph = Graph::fromFlatBuffers("./resources/non2d_1.fb"); - graph.printOut(); graph.execute(); ASSERT_TRUE(graph.variableSpace().hasVariable(6)); @@ -259,7 +210,6 @@ TEST_F(OneOffTests, test_reduce_all_1) { auto graph = Graph::fromFlatBuffers("./resources/reduce_all_rank2_d0_keep.fb"); - graph.printOut(); graph.execute(); ASSERT_TRUE(graph.variableSpace().hasVariable(1)); From 516a79d33bb22a66b34cd068218c9ac00fd4d32a Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 18 Jun 2020 09:54:38 +0300 Subject: [PATCH 198/233] Layer/Sequence delimiters added to Graph::printOut Signed-off-by: raver119@gmail.com --- .../graph/execution/impl/ExecutionTask.cpp | 6 +-- libnd4j/include/graph/impl/OptimizedGraph.cpp | 39 +++++++++++-------- .../layers_tests/GraphAnalysisTests.cpp | 2 +- 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp index e48a0e26aea2..4445ea981b25 100644 --- a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp +++ b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp @@ -48,11 +48,11 @@ ExecutionTask::ExecutionTask(ExecutionTask &&other) void ExecutionTask::printOut() const { if (_context.name().empty()) { if (_node.hasCustomOp()) - printf(" <%i:0>: {Op: %s}; ", _context.nodeId(), _node.customOp()->getOpName().c_str()); + printf(" <%i:0>: {Op: %s}; ", _context.nodeId(), _node.customOp()->getOpName().c_str()); else - printf(" <%i:0>: ", _context.nodeId()); + printf(" <%i:0>: ", _context.nodeId()); } else { - printf(" <%s> <%i>: ", _context.name().c_str(), _context.nodeId()); + printf(" <%s> <%i>: ", _context.name().c_str(), _context.nodeId()); } auto sz = _context.inputs().size(); diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index f21790a24d14..7faeeb95e926 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -220,25 +220,30 @@ size_t OptimizedGraph::size() const { } void OptimizedGraph::printOut() const { - - for (uint i = 0; i < _sortedGraph.size(); ++i) { - printf("Layer [%u]\n", i); - for (uint j = 0; j < _sortedGraph[i].width(); ++j) - _sortedGraph[i][j].printOut(); + for (uint i = 0; i < _sortedGraph.size(); ++i) { + printf("Layer [%u] {\n", i); + for (uint j = 0; j < _sortedGraph[i].width(); ++j) { + printf(" Sequence [%u] {\n", j); + _sortedGraph[i][j].printOut(); + printf(" }\n"); } - - printf("And simple print:\n"); - for (int i = 0; i < _sortedGraph.size(); ++i) { - printf("layer %i: ", i); - for (int j = 0; j < _sortedGraph[i].width(); ++j) { - printf("("); - for (int k = 0; k < _sortedGraph[i][j].length(); ++k) { - printf("%i, ", _sortedGraph[i][j][k].protoContext().nodeId()); - } - printf("), "); - } - printf("\n"); + printf("}\n"); + } + + /* + printf("And simple print:\n"); + for (int i = 0; i < _sortedGraph.size(); ++i) { + printf("layer %i: ", i); + for (int j = 0; j < _sortedGraph[i].width(); ++j) { + printf("("); + for (int k = 0; k < _sortedGraph[i][j].length(); ++k) { + printf("%i, ", _sortedGraph[i][j][k].protoContext().nodeId()); + } + printf("), "); } + printf("\n"); + } + */ } diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index a3011a7f4404..0feba485afe1 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -1063,7 +1063,7 @@ TEST_F(GraphAnalysisTests, optimizedGraph_cond2) { TEST_F(GraphAnalysisTests, optimizedGraph_while1) { auto graph = Graph::fromFlatBuffers("resources/while_iter1.fb"); const auto& optimized = graph.optimizedGraph(); - // graph.printOut(); + graph.printOut(); } TEST_F(GraphAnalysisTests, optimizedGraph_nested_while_1) { From c65f16e3fbefbaa2b37a98c531a8684c3fd87d55 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 18 Jun 2020 11:35:42 +0300 Subject: [PATCH 199/233] CUDA side update Signed-off-by: raver119@gmail.com --- libnd4j/include/array/impl/NDArray.cpp | 6 +- libnd4j/include/legacy/NativeOps.h | 5 - libnd4j/include/legacy/cuda/NativeOps.cu | 185 +++--------------- .../ops/declarable/helpers/cuda/dropout.cu | 6 +- .../declarable/helpers/cuda/image_resize.cu | 4 +- .../include/ops/declarable/helpers/cuda/qr.cu | 3 +- .../ops/declarable/helpers/cuda/random.cu | 10 +- .../declarable/helpers/cuda/randomShuffle.cu | 4 +- .../ops/declarable/helpers/cuda/reverse.cu | 19 +- .../ops/declarable/helpers/cuda/roll.cu | 3 +- .../ops/declarable/helpers/cuda/svd.cu | 3 +- .../helpers/cuda/triangular_solve.cu | 6 +- .../declarable/platform/cudnn/avgpool2d.cu | 4 +- .../declarable/platform/cudnn/avgpool3d.cu | 4 +- .../declarable/platform/cudnn/batchnorm.cu | 8 +- .../ops/declarable/platform/cudnn/conv2d.cu | 10 +- .../ops/declarable/platform/cudnn/conv3d.cu | 10 +- .../platform/cudnn/depthwiseConv2d.cu | 14 +- .../declarable/platform/cudnn/maxpool2d.cu | 4 +- .../declarable/platform/cudnn/maxpool3d.cu | 4 +- libnd4j/include/system/op_boilerplate.h | 2 +- .../layers_tests/DeclarableOpsTestsCuda1.cu | 2 +- 22 files changed, 79 insertions(+), 237 deletions(-) diff --git a/libnd4j/include/array/impl/NDArray.cpp b/libnd4j/include/array/impl/NDArray.cpp index 5597eae6d290..8abb259e478a 100644 --- a/libnd4j/include/array/impl/NDArray.cpp +++ b/libnd4j/include/array/impl/NDArray.cpp @@ -2085,7 +2085,7 @@ static void printFormatted(NDArray const* arr, int depth, int limit) { else if (arr->isZ()) ss << arr->e(i); else if (arr->isB()) - ss << arr->e(i) ? "true" : "false"; + ss << (arr->e(i) ? "true" : "false"); else if (arr->isS()) { ss << "\"" << arr->e(i).c_str() << "\""; } @@ -2108,7 +2108,7 @@ static void printFormatted(NDArray const* arr, int depth, int limit) { else if (arr->isZ()) ss << arr->e(row, col); else if (arr->isB()) - ss << arr->e(row, col) ? "true" : "false"; + ss << (arr->e(row, col) ? "true" : "false"); else if (arr->isS()) { ss << "\"" << arr->e(row * cols + col).c_str() <<"\""; } @@ -2163,7 +2163,7 @@ std::string NDArray::indexedBufferString(Nd4jLong limit) const { else if (this->isR()) ss << this->e(0); else if (this->isB()) { - ss << this->e(0) ? "true" : "false"; + ss << (this->e(0) ? "true" : "false"); } else if (this->isS()) { // todo do we need this // printf("\"%lld\"\n", this->getOffset(e)); diff --git a/libnd4j/include/legacy/NativeOps.h b/libnd4j/include/legacy/NativeOps.h index f06dda84240b..74f7b759aee4 100644 --- a/libnd4j/include/legacy/NativeOps.h +++ b/libnd4j/include/legacy/NativeOps.h @@ -1517,11 +1517,6 @@ SD_EXPORT void deletePointerArray(Nd4jPointer pointer); SD_EXPORT void deleteVariablesSet(OpaqueVariablesSet* pointer); -// GraphState creation -SD_EXPORT Nd4jPointer getGraphState(Nd4jLong id); - -SD_EXPORT void deleteGraphState(Nd4jPointer state); - SD_EXPORT void deleteResultWrapper(Nd4jPointer ptr); SD_EXPORT int estimateThreshold(Nd4jPointer* extraPointers, Nd4jPointer x, diff --git a/libnd4j/include/legacy/cuda/NativeOps.cu b/libnd4j/include/legacy/cuda/NativeOps.cu index 116003830735..32892a67dd79 100644 --- a/libnd4j/include/legacy/cuda/NativeOps.cu +++ b/libnd4j/include/legacy/cuda/NativeOps.cu @@ -2834,7 +2834,8 @@ void munmapFile(Nd4jPointer *extraPointers, Nd4jLong *ptrMap, Nd4jLong length) { sd::graph::ResultWrapper *executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer) { try { - return sd::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); + //return sd::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); + return nullptr; } catch (std::exception &e) { sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( @@ -2855,7 +2856,7 @@ const char *getAllCustomOps() { } sd::ShapeList *_calculateOutputShapes( - Nd4jPointer *extraPointers, sd::ops::DeclarableOp *op, + Nd4jPointer *extraPointers, std::shared_ptr &op, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputShapes, double *tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs, int *dArgs, int numDArgs) { @@ -2863,14 +2864,14 @@ sd::ShapeList *_calculateOutputShapes( Context block(2, &varSpace); sd::ShapeList inShapes; - for (int e = 0; e < numIArgs; e++) block.getIArguments()->push_back(iArgs[e]); + for (int e = 0; e < numIArgs; e++) block.appendI(iArgs[e]); - for (int e = 0; e < numTArgs; e++) block.getTArguments()->push_back(tArgs[e]); + for (int e = 0; e < numTArgs; e++) block.appendT(tArgs[e]); - for (int e = 0; e < numBArgs; e++) block.getBArguments()->push_back(bArgs[e]); + for (int e = 0; e < numBArgs; e++) block.appendB(bArgs[e]); for (int e = 0; e < numDArgs; e++) - block.getDArguments()->push_back((sd::DataType)dArgs[e]); + block.appendD((sd::DataType) dArgs[e]); for (int e = 0; e < numInputShapes; e++) { auto shape_ = reinterpret_cast(inputShapes[e]); @@ -2883,7 +2884,7 @@ sd::ShapeList *_calculateOutputShapes( ? nullptr : inputBuffers[e + numInputShapes]; - auto array = new sd::NDArray(buffer_, bufferD_, shape_); + sd::NDArray array(buffer_, bufferD_, shape_); // block should contain references to proper variable varSpace.putVariable(1, e, array); @@ -2894,8 +2895,6 @@ sd::ShapeList *_calculateOutputShapes( auto shapeList = op->calculateOutputShape(&inShapes, block); - if (varSpace.launchContext()->getWorkspace() != nullptr) shapeList->detach(); - return shapeList; } @@ -2921,7 +2920,7 @@ sd::ShapeList *calculateOutputShapes2(Nd4jPointer *extraPointers, Nd4jLong hash, } sd::ShapeList *_calculateOutputShapes(Nd4jPointer *extraPointers, - sd::ops::DeclarableOp *op, + std::shared_ptr &op, Nd4jPointer *inputShapes, int numInputShapes, double *tArgs, int numTArgs, Nd4jLong *iArgs, @@ -2929,9 +2928,9 @@ sd::ShapeList *_calculateOutputShapes(Nd4jPointer *extraPointers, Context block(1); sd::ShapeList inShapes; - for (int e = 0; e < numIArgs; e++) block.getIArguments()->push_back(iArgs[e]); + for (int e = 0; e < numIArgs; e++) block.appendI(iArgs[e]); - for (int e = 0; e < numTArgs; e++) block.getTArguments()->push_back(tArgs[e]); + for (int e = 0; e < numTArgs; e++) block.appendT(tArgs[e]); for (int e = 0; e < numInputShapes; e++) inShapes.push_back(reinterpret_cast(inputShapes[e])); @@ -2966,12 +2965,12 @@ Nd4jLong const *getShape(sd::ShapeList *list, Nd4jLong i) { return list->at(i); } -static FORCEINLINE Nd4jStatus -realExec(sd::ops::DeclarableOp *op, Nd4jPointer *extraPointers, Nd4jLong hash, - Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, - Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs, - double *tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, - bool *bArgs, int numBArgs, bool isInplace) { +static FORCEINLINE Nd4jStatus realExec( + std::shared_ptr &op, Nd4jPointer *extraPointers, Nd4jLong hash, + Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, + Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs, + double *tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, + bool *bArgs, int numBArgs, bool isInplace) { if (op == nullptr) nd4j_printf("Can't find requested operation: [%lld]\n", hash); @@ -3128,64 +3127,12 @@ int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, } } -static VariablesSet *executeStoredGraphT(Nd4jPointer *extraPointers, - Nd4jLong graphId, - Nd4jPointer *inputBuffers, - Nd4jPointer *inputShapes, - int *inputIndices, int numInputs) { - auto graph = sd::graph::GraphHolder::getInstance().pullGraph(graphId); - auto varSpace = graph->variableSpace()->clone(); - - std::vector handles; - - for (int e = 0; e < numInputs; e++) { - auto idx = inputIndices[e]; - - // we'll delete this array later, together with cloned VariableSpace - auto array = new sd::NDArray(inputBuffers[e], - reinterpret_cast(inputShapes[e])); - handles.emplace_back(array); - - if (varSpace->hasVariable(idx)) { - auto var = varSpace->getVariable(idx); - if (var->hasNDArray()) delete var->getNDArray(); - - var->setNDArray(array); - } else - varSpace->putVariable(idx, array); - } - - auto dZ = sd::graph::GraphExecutioner::execute(graph, varSpace); - auto varSet = new sd::graph::VariablesSet(dZ); - - if (dZ == ND4J_STATUS_OK) { - // pull back results, and provide them - auto outputs = graph->fetchOutputs(); - for (int e = 0; e < outputs->size(); e++) { - // we're only getting variable ID/Index from original grap. values will be - // taken from cloned workspace - std::pair varId(outputs->at(e)->id(), outputs->at(e)->index()); - - auto var = varSpace->getVariable(varId); - - varSet->push_back(var->clone()); - } - - delete outputs; - } - - delete varSpace; - - return varSet; -} - VariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int *inputIndices, int numInputs) { try { - return executeStoredGraphT(extraPointers, graphId, inputBuffers, - inputShapes, inputIndices, numInputs); + throw std::runtime_error("Not implemented yet"); } catch (std::exception &e) { sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( @@ -3213,7 +3160,7 @@ int getVariableIndex(sd::graph::Variable *variable) { } const char *getVariableName(sd::graph::Variable *variable) { - return variable->getName()->c_str(); + return variable->name().c_str(); } Nd4jLong const *getVariableShape(sd::graph::Variable *variable) { @@ -3226,9 +3173,9 @@ void *getVariableBuffer(sd::graph::Variable *variable) { int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) { try { - sd::graph::GraphHolder::getInstance().dropGraphAny(graphId); + sd::graph::GraphHolder::getInstance().forgetGraph(graphId); - return ND4J_STATUS_OK; + return Status::OK(); } catch (std::exception &e) { sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( @@ -3270,85 +3217,7 @@ const char *getAllOperations() { return sd::OpTracker::getInstance().exportOperations(); } -Nd4jPointer getGraphState(Nd4jLong id) { - return (Nd4jPointer) new sd::graph::GraphState(id); -} -void deleteGraphState(Nd4jPointer state) { - auto stateP = reinterpret_cast(state); - delete stateP; -} - -Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, - sd::graph::GraphState *state, Nd4jLong opHash, - Nd4jLong *scopes, int numScopes, - Nd4jPointer *inputBuffers, - Nd4jPointer *inputShapes, int numInputs, - Nd4jPointer *outputBuffers, - Nd4jPointer *outputShapes, int numOutputs) { - /** - * That's basically exec, with VariableSpace provided in GraphState: - * depending on operation (i.e. while of if), different logic executors could - * be used - */ - - auto graph = state->graph(); - auto varSpace = state->variableSpace(); - - // Node is dynamically created, and has nothing beyond it: only inputs and - // outputs this node has id of 0, and inputs are - Node node(OpType_LOGIC, opHash, 0); - - // mapping inputs - for (int e = 0; e < numInputs; e++) { - auto buffer = inputBuffers[e]; - auto shapeInfo = reinterpret_cast(inputShapes[e]); - - auto array = new sd::NDArray(buffer, shapeInfo, varSpace->launchContext()); - - // now we just put array to VarSpace - varSpace->putVariable(0, e, array); - node.pickInput(0, e); - } - - // mapping scopes - for (int e = 0; e < numScopes; e++) { - // we should check scope existence in GraphState/Graph - int scopeId = (int)scopes[e]; - if (!state->hasScope(scopeId)) { - // nd4j_printf("execCustomOpWithScope: referenced scope [%i] doesn't - // exist\n", scopeId); - return Status::THROW(); - } - node.pickInput(scopeId, 0); - } - - auto dZ = LogicExecutor::processNode(graph, &node); - if (dZ != Status::OK()) return dZ; - - // mapping outputs - - for (int e = 0; e < numOutputs; e++) { - auto buffer = outputBuffers[e]; - auto shapeInfo = reinterpret_cast(outputShapes[e]); - - NDArray array(buffer, shapeInfo, varSpace->launchContext()); - - // now we just put array to VarSpace to the same ID - // varSpace->putVariable(0, e, array); - - auto t = varSpace->getVariable(0, e)->getNDArray(); - array.assign(t); - } - - // removing input variables - for (int e = 0; e < numInputs; e++) { - varSpace->dropVariable(0, e); - } - - // after some bla-bla-bla we should have Graph and Node for current op - return Status::OK(); -} Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, @@ -3356,17 +3225,7 @@ Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs) { - try { - return execCustomOpWithScope( - extraPointers, reinterpret_cast(state), opHash, - scopes, numScopes, inputBuffers, inputShapes, numInputs, outputBuffers, - outputShapes, numOutputs); - } catch (std::exception &e) { - sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); - sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( - e.what()); - return 1; - } + return 0; } void deleteResultWrapper(Nd4jPointer ptr) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu index 6c5e1691eb09..ef14d1e46780 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu @@ -332,7 +332,7 @@ int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { BUILD_SINGLE_SELECTOR( - context.dataType(), return dropOutFunctorBP_, + output->dataType(), return dropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue), FLOAT_TYPES); } @@ -341,7 +341,7 @@ int alphaDropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctor_, + BUILD_SINGLE_SELECTOR(output->dataType(), return alphaDropOutFunctor_, (context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); @@ -351,7 +351,7 @@ int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { - BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctorBP_, + BUILD_SINGLE_SELECTOR(output->dataType(), return alphaDropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu index 8cdc99250b76..cc349d98bc6d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu @@ -1429,8 +1429,8 @@ BUILD_SINGLE_TEMPLATE(template int resizeBicubicFunctorA_, // ------------------------------------------------------------------------------------------------------------------// -int resizeImagesFunctor(sd::LaunchContext* context, NDArray const* image, int constwidth, - int constheight, ImageResizeMethods method, +int resizeImagesFunctor(sd::LaunchContext* context, NDArray const* image, int const width, + int const height, ImageResizeMethods method, bool alignCorners, NDArray* output) { switch (method) { case kResizeBilinear: diff --git a/libnd4j/include/ops/declarable/helpers/cuda/qr.cu b/libnd4j/include/ops/declarable/helpers/cuda/qr.cu index 3f78567859c3..762e3f3b14fb 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/qr.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/qr.cu @@ -186,8 +186,7 @@ void qr_(LaunchContext* context, NDArray const* input, NDArray* outputQ, for (auto batch = start; batch < stop; batch += increment) { // qr here - qrSingle(context, listInput.at(batch), listOutQ.at(batch), - listOutR.at(batch), fullMatricies); + qrSingle(context, &listInput.at(batch), &listOutQ.at(batch), &listOutR.at(batch), fullMatricies); } NDArray::registerSpecialUse({outputQ, outputR}, {input}); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/random.cu b/libnd4j/include/ops/declarable/helpers/cuda/random.cu index 1550efe407cb..ddb40c23c87b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/random.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/random.cu @@ -133,8 +133,8 @@ namespace helpers { * output - distributed output. * */ template -static __global__ void fillGammaKernel(Tconst* uList, Nd4jLong uLength, T const* alpha, - const Nd4jLong* alphaShape, Tconst* beta, +static __global__ void fillGammaKernel(T const* uList, Nd4jLong uLength, T const* alpha, + const Nd4jLong* alphaShape, T const* beta, const Nd4jLong* betaShape, T* output, const Nd4jLong* outputShape) { // fill up @@ -180,11 +180,7 @@ static void fillRandomGamma_(LaunchContext* context, copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast( BroadcastOpsTuple::Assign(), *alpha)); - copyBeta = new NDArray( - betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta)); - // if (!copyAlpha->isActualOnDevice()) copyAlpha->syncToDevice(); -// if (! - copyBeta->isActualOnDevice()) copyBeta->syncToDevice(); + copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta)); } auto stream = context->getCudaStream(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu b/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu index bb7998e60ef3..29725ac86ab5 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu @@ -187,7 +187,7 @@ static void randomShuffle_(sd::LaunchContext* context, NDArray& input, NDArray& for(int i = firstDim - 1; i > 0; --i) { const int j = rng.relativeInt(i) % (i + 1); if(i != j) - subArrsList.at(i)->swapUnsafe(*subArrsList.at(j)); + subArrsList.at(i).swapUnsafe(subArrsList.at(j)); } } else { @@ -204,7 +204,7 @@ static void randomShuffle_(sd::LaunchContext* context, NDArray& input, NDArray& auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; ++i) - subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i])); + subArrsListOut.at(i).assign(subArrsListIn.at(indices[i])); }; samediff::Threads::parallel_for(func, 0, firstDim); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu index 4fb6ac960d29..a16d0dcbf1fc 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu @@ -218,15 +218,14 @@ static void reverseSequence_(sd::LaunchContext* context, const NDArray* input, int numOfElemsToReverse = seqLengths->e(i); if (numOfElemsToReverse == 0 || numOfElemsToReverse == 1) { - outSubArrsSet.at(i)->assign(inSubArrsSet.at(i)); + outSubArrsSet.at(i).assign(inSubArrsSet.at(i)); } else { auto inInnerSet = - inSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); + inSubArrsSet.at(i).allTensorsAlongDimension({seqDim}); auto outInnerSet = - outSubArrsSet.at(i)->allTensorsAlongDimension({seqDim}); + outSubArrsSet.at(i).allTensorsAlongDimension({seqDim}); for (int j = 0; j < inInnerSet.size(); ++j) - reverseArray(context, inInnerSet.at(j), outInnerSet.at(j), - numOfElemsToReverse); + reverseArray(context, &inInnerSet.at(j), &outInnerSet.at(j), numOfElemsToReverse); } } } @@ -250,16 +249,13 @@ void reverseSequence(sd::LaunchContext* context, const NDArray* input, void reverse(sd::LaunchContext* context, const NDArray* input, NDArray* output, const std::vector* intArgs) { - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( - input->shapeInfo(), *intArgs); - auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions( - output->shapeInfo(), *intArgs); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), *intArgs); + auto packZ = sd::ConstantTadHelper::getInstance().tadForDimensions(output->shapeInfo(), *intArgs); NDArray::prepareSpecialUse({output}, {input}); if (packX.numberOfTads() == 1) { - BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, - (context, input, output, 0), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, input, output, 0), LIBND4J_TYPES); } else { BUILD_SINGLE_SELECTOR( input->dataType(), reverseTad, @@ -272,6 +268,7 @@ void reverse(sd::LaunchContext* context, const NDArray* input, NDArray* output, NDArray::registerSpecialUse({output}, {input}); } + } // namespace helpers } // namespace ops } // namespace sd diff --git a/libnd4j/include/ops/declarable/helpers/cuda/roll.cu b/libnd4j/include/ops/declarable/helpers/cuda/roll.cu index 29a3ee32f82c..01db45ba6e78 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/roll.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/roll.cu @@ -272,8 +272,7 @@ static void rollFunctorFull_(NDArray *input, NDArray *output, // theShift -= fullLen * (theShift / fullLen - 1); // } for (int k = 0; k < fullLen; k++) { - rollFunctorLinear(output->getContext(), listOfTensors.at(k), - listOfOutTensors.at(k), theShift, true); + rollFunctorLinear(output->getContext(), &listOfTensors.at(k), &listOfOutTensors.at(k), theShift, true); } } else { std::vector dims(input->rankOf() - axe - 1); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu index bb0668cb95db..bab93a08685e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu @@ -728,8 +728,7 @@ void svd(sd::LaunchContext* context, const NDArray* x, } for (int i = 0; i < tadsX.size(); ++i) - svdJcb(context, tadsX.at(i), tadsS.at(i), calcUV ? tadsU->at(i) : nullptr, - calcUV ? tadsV->at(i) : nullptr, fullUV, calcUV); + svdJcb(context, &tadsX.at(i), &tadsS.at(i), calcUV ? &tadsU->at(i) : nullptr, calcUV ? &tadsV->at(i) : nullptr, fullUV, calcUV); if (calcUV) { delete tadsU; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu b/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu index a1061231d902..8884ffdce8c6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/triangular_solve.cu @@ -200,10 +200,8 @@ static int triangularSolveFunctor_(sd::LaunchContext* context, // output.syncToDevice(); } BUILD_SINGLE_TEMPLATE(template void triangularSolve2D, (sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output), FLOAT_TYPES); -// template void triangularSolve2D(sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output); -// template void triangularSolve2D(sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output); -// template void triangularSolve2D(sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output); -// template void triangularSolve2D(sd::LaunchContext* context, NDArray const& leftInput, NDArray const& rightInput, bool const lower, bool const unitsOnDiag, NDArray& output);int triangularSolveFunctor(sd::LaunchContext* context, NDArray* leftInput, + +int triangularSolveFunctor(sd::LaunchContext* context, NDArray* leftInput, NDArray* rightInput, bool lower, bool unitsOnDiag, NDArray* output) { BUILD_SINGLE_SELECTOR( diff --git a/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu index 737eaa01af41..4a4deb642f94 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu @@ -43,7 +43,7 @@ PLATFORM_IMPL(avgpool2d, ENGINE_CUDA) { const auto dW = INT_ARG(7); const auto paddingMode = static_cast(INT_ARG(8)); const auto extraParam0 = INT_ARG(9); - const int isNCHW = block.getIArguments()->size() > 10 + const int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC @@ -111,7 +111,7 @@ PLATFORM_IMPL(avgpool2d_bp, ENGINE_CUDA) { const auto dW = INT_ARG(7); // dilations width const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME const auto extraParam0 = INT_ARG(9); - const auto isNCHW = block.getIArguments()->size() > 10 + const auto isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC diff --git a/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu index aabf3300905c..15e1dd975161 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu @@ -47,7 +47,7 @@ PLATFORM_IMPL(avgpool3dnew, ENGINE_CUDA) { int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID int extraParam0 = INT_ARG(13); - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC REQUIRE_TRUE(input->rankOf() == 5, 0, @@ -128,7 +128,7 @@ PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CUDA) { const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging - const int isNCDHW = block.getIArguments()->size() > 14 + const int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC diff --git a/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu b/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu index f22c00d33573..c011a28af0ec 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu @@ -353,7 +353,7 @@ PLATFORM_IMPL(batchnorm, ENGINE_CUDA) { if (applyScale) gamma = INPUT_VARIABLE(3); if (applyOffset) beta = INPUT_VARIABLE(3 + (int)applyScale); - const int numOfIntArgs = block.getIArguments()->size(); + const int numOfIntArgs = block.numI(); const int inRank = input->rankOf(); // get axes args to normalize input array over @@ -463,7 +463,7 @@ PLATFORM_CHECK(batchnorm, ENGINE_CUDA) { NDArray* gamma = applyScale ? INPUT_VARIABLE(3) : nullptr; NDArray* beta = applyOffset ? INPUT_VARIABLE(3 + (int)applyScale) : nullptr; - const int numOfIntArgs = block.getIArguments()->size(); + const int numOfIntArgs = block.numI(); const int xRank = input->rankOf(); // *********************************** // @@ -545,7 +545,7 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CUDA) { gradB = OUTPUT_VARIABLE(3 + (int)applyScale); } - const int numOfIntArgs = block.getIArguments()->size(); + const int numOfIntArgs = block.numI(); const int inRank = input->rankOf(); // get axes args to normalize input array over @@ -671,7 +671,7 @@ PLATFORM_CHECK(batchnorm_bp, ENGINE_CUDA) { NDArray* gradG = nullptr; NDArray* gradB = nullptr; - const int numOfIntArgs = block.getIArguments()->size(); + const int numOfIntArgs = block.numI(); const int xRank = input->rankOf(); // *********************************** // diff --git a/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu index f93cd5d42b28..7bad4e2b0ca8 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/conv2d.cu @@ -427,10 +427,10 @@ PLATFORM_IMPL(conv2d, ENGINE_CUDA) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - bool isNCHW = block.getIArguments()->size() > 9 + bool isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - // [oC, kH, kW, iC] @@ -568,10 +568,10 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CUDA) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, oC], 1 - [oC, iC, kH, kW], 2 - // [oC, kH, kW, iC] @@ -697,7 +697,7 @@ PLATFORM_CHECK(conv2d_bp, ENGINE_CUDA) { // oH, oW] (NCHW), epsilon_next const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL - const int isNCHW = block.getIArguments()->size() > 9 + const int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC diff --git a/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu index 9a2007b7497c..dfd63d957eee 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/conv3d.cu @@ -475,10 +475,10 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CUDA) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = block.getIArguments()->size() > 13 + int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 + int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], // 2-[oC, kD, kH, kW, iC] @@ -620,10 +620,10 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CUDA) { int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.getIArguments()->size() > 13 + int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW - int wFormat = block.getIArguments()->size() > 14 + int wFormat = block.numI() > 14 ? INT_ARG(14) : 0; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], // 2-[oC, kD, kH, kW, iC] @@ -748,7 +748,7 @@ PLATFORM_CHECK(conv3dnew_bp, ENGINE_CUDA) { // oH, oW] (NCDHW), epsilon_next int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNCDHW = block.getIArguments()->size() > 13 + int isNCDHW = block.numI() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW diff --git a/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu index 512b51efad5e..8a09c6f33d57 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/depthwiseConv2d.cu @@ -478,10 +478,10 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CUDA) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - int wFormat = block.getIArguments()->size() > 10 + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - // [mC, kH, kW, iC] @@ -557,7 +557,7 @@ PLATFORM_CHECK(depthwise_conv2d, ENGINE_CUDA) { auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL - const int wFormat = block.getIArguments()->size() > 10 + const int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 // - [mC, kH, kW, iC] @@ -625,10 +625,10 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CUDA) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 + int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW - int wFormat = block.getIArguments()->size() > 10 + int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 - // [mC, kH, kW, iC] @@ -738,10 +738,10 @@ PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CUDA) { // oH, oW] (NCDHW), epsilon_next const int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME, 2-CAUSAL - const int isNCHW = block.getIArguments()->size() > 9 + const int isNCHW = block.numI() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - const int wFormat = block.getIArguments()->size() > 10 + const int wFormat = block.numI() > 10 ? INT_ARG(10) : 0; // 0 - [kH, kW, iC, mC], 1 - [mC, iC, kH, kW], 2 // - [mC, kH, kW, iC] diff --git a/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu index fb1682b9b2b3..50504fd29bc1 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu @@ -42,7 +42,7 @@ PLATFORM_IMPL(maxpool2d, ENGINE_CUDA) { const auto dH = INT_ARG(6); const auto dW = INT_ARG(7); const auto paddingMode = static_cast(INT_ARG(8)); - const int isNCHW = block.getIArguments()->size() > 10 + const int isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC @@ -105,7 +105,7 @@ PLATFORM_IMPL(maxpool2d_bp, ENGINE_CUDA) { const auto dH = INT_ARG(6); // dilations height const auto dW = INT_ARG(7); // dilations width const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - const auto isNCHW = block.getIArguments()->size() > 10 + const auto isNCHW = block.numI() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC diff --git a/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu index b9fdd58d1f58..ab5ff55f28d6 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu @@ -47,7 +47,7 @@ PLATFORM_IMPL(maxpool3dnew, ENGINE_CUDA) { int dW = INT_ARG(11); // dilations width int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID // int extraParam0 = INT_ARG(13); - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) + int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC REQUIRE_TRUE(input->rankOf() == 5, 0, @@ -124,7 +124,7 @@ PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CUDA) { const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID // const int extraParam0 = INT_ARG(13); // define what divisor to use while // averaging - const int isNCDHW = block.getIArguments()->size() > 14 + const int isNCDHW = block.numI() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC diff --git a/libnd4j/include/system/op_boilerplate.h b/libnd4j/include/system/op_boilerplate.h index 465a31c13e3c..0860c0dd8063 100644 --- a/libnd4j/include/system/op_boilerplate.h +++ b/libnd4j/include/system/op_boilerplate.h @@ -1255,7 +1255,7 @@ #define REGISTER_H(NAME) template \ struct __registrator_##NAME {\ __registrator_##NAME() {\ - OpName *ptr = new OpName(); \ + auto ptr = std::make_shared(); \ OpRegistrator::getInstance().registerOperation(ptr); \ }\ };\ diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu index f0f3fdc05c42..60d12a1745fa 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu @@ -63,7 +63,7 @@ TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) { auto z = result.at(1); - ASSERT_EQ(148, z->e(0)); + ASSERT_EQ(148, z.e(0)); // ASSERT_TRUE(exp.isSameShape(z)); } From 947fb480c684c96336b0c05ea69b5d642a2a7cc4 Mon Sep 17 00:00:00 2001 From: Yurii Date: Thu, 18 Jun 2020 14:38:51 +0300 Subject: [PATCH 200/233] - restore deleted row of code in lstmLayer mkl helper Signed-off-by: Yurii --- .../declarable/platform/mkldnn/lstmLayer.cpp | 365 +++++------------- .../platform/mkldnn/mkldnnUtils.cpp | 4 +- .../layers_tests/DeclarableOpsTests13.cpp | 14 +- 3 files changed, 116 insertions(+), 267 deletions(-) diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index 57c4222475b8..8ec643ef1d87 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -33,6 +33,7 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* hI, const NDArray* cI, const std::vector& params, NDArray* h, NDArray* hL, NDArray* cL) { + // equations (no peephole connections) // it = σ(Wxi * xt + Wri * ht-1 + bi) // ft = σ(Wxf * xt + Wrf * ht-1 + bf) @@ -108,21 +109,13 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const int dataFormat = params[0]; const int directionMode = params[1]; - const int sL = - x->sizeAt(0); // dataFormat == 0 ? x->sizeAt(0) : x->sizeAt(1); - const int bS = - x->sizeAt(1); // dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0); + const int sL = x->sizeAt(0); // dataFormat == 0 ? x->sizeAt(0) : x->sizeAt(1); + const int bS = x->sizeAt(1); // dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0); const int nIn = x->sizeAt(-1); const int nOut = Wx->sizeAt(-1); - const int dirDim = - directionMode < 2 - ? 1 - : 2; // number of dimensionss, 1 unidirectional, 2 for bidirectional - const int hDirDim = directionMode <= 2 - ? 1 - : 2; // for h array, take into account - // bidirectional_sum mode (directionMode == 2) + const int dirDim = directionMode < 2 ? 1 : 2; // number of dimensionss, 1 unidirectional, 2 for bidirectional + const int hDirDim = directionMode <= 2 ? 1 : 2; // for h array, take into account bidirectional_sum mode (directionMode == 2) // evaluate direction rnn_direction direction; @@ -159,7 +152,8 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, // weights type dnnl::memory::data_type wType = xType; - if (xType == dnnl::memory::data_type::u8) wType = dnnl::memory::data_type::s8; + if (xType == dnnl::memory::data_type::u8) + wType = dnnl::memory::data_type::s8; // bias type dnnl::memory::data_type bType = xType; @@ -177,98 +171,64 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, // memory descriptors for arrays // x - x_lstm_md = - dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::any); - // x_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, nIn}, type, - // dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, nIn}, type, - // dnnl::memory::format_tag::ntc); - x_user_md = - dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::tnc); - mkldnnUtils::setBlockStrides(*x, - x_user_md); + x_lstm_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::any); + // x_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, nIn}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, nIn}, type, dnnl::memory::format_tag::ntc); + x_user_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::tnc); + mkldnnUtils::setBlockStrides(*x, x_user_md); // wx - wx_lstm_md = dnnl::memory::desc({1, dirDim, nIn, 4, nOut}, wType, - dnnl::memory::format_tag::any); - wx_user_md = dnnl::memory::desc({1, dirDim, nIn, 4, nOut}, wType, - dnnl::memory::format_tag::ldigo); - mkldnnUtils::setBlockStrides(*Wx, - wx_user_md); + wx_lstm_md = dnnl::memory::desc({1, dirDim, nIn, 4, nOut}, wType, dnnl::memory::format_tag::any); + wx_user_md = dnnl::memory::desc({1, dirDim, nIn, 4, nOut}, wType, dnnl::memory::format_tag::ldigo); + mkldnnUtils::setBlockStrides(*Wx, wx_user_md); // wr - wr_lstm_md = dnnl::memory::desc({1, dirDim, nOut, 4, nOut}, wType, - dnnl::memory::format_tag::any); - wr_user_md = dnnl::memory::desc({1, dirDim, nOut, 4, nOut}, wType, - dnnl::memory::format_tag::ldigo); - mkldnnUtils::setBlockStrides(*Wr, - wr_user_md); + wr_lstm_md = dnnl::memory::desc({1, dirDim, nOut, 4, nOut}, wType, dnnl::memory::format_tag::any); + wr_user_md = dnnl::memory::desc({1, dirDim, nOut, 4, nOut}, wType, dnnl::memory::format_tag::ldigo); + mkldnnUtils::setBlockStrides(*Wr, wr_user_md); // h - h_lstm_md = dnnl::memory::desc({sL, bS, hDirDim * nOut}, hType, - dnnl::memory::format_tag::any); - // h_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, hDirDim*nOut}, - // type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, - // hDirDim*nOut}, type, dnnl::memory::format_tag::ntc); - h_user_md = dnnl::memory::desc({sL, bS, hDirDim * nOut}, hType, - dnnl::memory::format_tag::tnc); - mkldnnUtils::setBlockStrides(*h, - h_user_md); + h_lstm_md = dnnl::memory::desc({sL, bS, hDirDim * nOut}, hType, dnnl::memory::format_tag::any); + // h_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, hDirDim*nOut}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, hDirDim*nOut}, type, dnnl::memory::format_tag::ntc); + h_user_md = dnnl::memory::desc({sL, bS, hDirDim * nOut}, hType, dnnl::memory::format_tag::tnc); + mkldnnUtils::setBlockStrides(*h, h_user_md); // b if (b) { - b_lstm_md = dnnl::memory::desc({1, dirDim, 4, nOut}, bType, - dnnl::memory::format_tag::any); - b_user_md = dnnl::memory::desc({1, dirDim, 4, nOut}, bType, - dnnl::memory::format_tag::ldgo); - mkldnnUtils::setBlockStrides(*b, - b_user_md); + b_lstm_md = dnnl::memory::desc({1, dirDim, 4, nOut}, bType, dnnl::memory::format_tag::any); + b_user_md = dnnl::memory::desc({1, dirDim, 4, nOut}, bType, dnnl::memory::format_tag::ldgo); + mkldnnUtils::setBlockStrides(*b, b_user_md); } // hI if (hI) { - hI_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, - dnnl::memory::format_tag::any); - hI_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, - dnnl::memory::format_tag::ldnc); - mkldnnUtils::setBlockStrides(*hI, - hI_user_md); + hI_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, dnnl::memory::format_tag::any); + hI_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, dnnl::memory::format_tag::ldnc); + mkldnnUtils::setBlockStrides(*hI, hI_user_md); } // cI if (cI) { - cI_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, - dnnl::memory::format_tag::any); - cI_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, - dnnl::memory::format_tag::ldnc); - mkldnnUtils::setBlockStrides(*cI, - cI_user_md); + cI_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, dnnl::memory::format_tag::any); + cI_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, xType, dnnl::memory::format_tag::ldnc); + mkldnnUtils::setBlockStrides(*cI, cI_user_md); } // hL if (hL) { - hL_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, - dnnl::memory::format_tag::any); - hL_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, - dnnl::memory::format_tag::ldnc); + hL_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, dnnl::memory::format_tag::any); + hL_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, dnnl::memory::format_tag::ldnc); hL_user_md.data.format_kind = dnnl_blocked; // overrides format - mkldnnUtils::setBlockStrides(*hL, - hL_user_md); + mkldnnUtils::setBlockStrides(*hL, hL_user_md); } if (cL) { - cL_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, - dnnl::memory::format_tag::ldnc); - cL_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, - dnnl::memory::format_tag::ldnc); - mkldnnUtils::setBlockStrides(*cL, - cL_user_md); + cL_lstm_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, dnnl::memory::format_tag::ldnc); + cL_user_md = dnnl::memory::desc({1, dirDim, bS, nOut}, hType, dnnl::memory::format_tag::ldnc); + mkldnnUtils::setBlockStrides(*cL, cL_user_md); } // lstm memory description - lstm_forward::desc lstm_desc(prop_kind::forward_inference, direction, - x_lstm_md, hI_lstm_md, cI_lstm_md, wx_lstm_md, - wr_lstm_md, b_lstm_md, h_lstm_md, hL_lstm_md, - cL_lstm_md); + lstm_forward::desc lstm_desc(prop_kind::forward_inference, direction, x_lstm_md, hI_lstm_md, cI_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md); dnnl::stream stream(engine); @@ -280,46 +240,28 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, // provide memory and check whether reorder is required // x - mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, - lstm_prim_desc.src_layer_desc(), - args[DNNL_ARG_SRC_LAYER]); + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]); // wx - mkldnnUtils::loadDataToMklStream(*Wx, engine, stream, wx_user_md, - lstm_prim_desc.weights_layer_desc(), - args[DNNL_ARG_WEIGHTS_LAYER]); + mkldnnUtils::loadDataToMklStream(*Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]); // wr - mkldnnUtils::loadDataToMklStream(*Wr, engine, stream, wr_user_md, - lstm_prim_desc.weights_iter_desc(), - args[DNNL_ARG_WEIGHTS_ITER]); + mkldnnUtils::loadDataToMklStream(*Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]); // h - auto h_user_mem = mkldnnUtils::loadDataToMklStream(*h, engine, stream, h_user_md, - lstm_prim_desc.dst_layer_desc() , - args[DNNL_ARG_DST_LAYER] ); + auto h_user_mem = mkldnnUtils::loadDataToMklStream(*h, engine, stream, h_user_md, lstm_prim_desc.dst_layer_desc(), args[DNNL_ARG_DST_LAYER] ); // b if (b) - mkldnnUtils::loadDataToMklStream(*b, engine, stream, b_user_md, - lstm_prim_desc.bias_desc(), - args[DNNL_ARG_BIAS]); - + mkldnnUtils::loadDataToMklStream(*b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]); // hI if (hI) - mkldnnUtils::loadDataToMklStream(*hI, engine, stream, hI_user_md, - lstm_prim_desc.src_iter_desc(), - args[DNNL_ARG_SRC_ITER]); - + mkldnnUtils::loadDataToMklStream(*hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]); // cI if (cI) - mkldnnUtils::loadDataToMklStream(*cI, engine, stream, cI_user_md, - lstm_prim_desc.src_iter_c_desc(), - args[DNNL_ARG_SRC_ITER_C]); - - + mkldnnUtils::loadDataToMklStream(*cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]); dnnl::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem; @@ -349,93 +291,50 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(lstmLayer, ENGINE_CPU) { - const auto dataFormat = INT_ARG( - 0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, - // nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) - const auto directionMode = - INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = - // bidirectional concat, 4 = bidirectional extra output dim - // (in conjunction with format dataFormat = 3) - const auto hasBiases = - B_ARG(0); // indicates whether biases array is provided - const auto hasSeqLen = - B_ARG(1); // indicates whether seqLen array is provided - const auto hasInitH = - B_ARG(2); // indicates whether initial output is provided - const auto hasInitC = - B_ARG(3); // indicates whether initial cell state is provided - const auto hasPH = - B_ARG(4); // indicates whether peephole connections are present - const auto retFullSeq = B_ARG(5); // indicates whether to return whole time - // sequence h {h_0, h_1, ... , h_sL-1} - const auto retLastH = - B_ARG(6); // indicates whether to return output at last time step only, - // in this case shape would be [bS, nOut] (exact shape depends - // on dataFormat argument) - const auto retLastC = - B_ARG(7); // indicates whether to return cells state at last time step - // only, in this case shape would be [bS, nOut] (exact shape - // depends on dataFormat argument) + const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 =bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) - const auto cellClip = - T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} + const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + + const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping const auto x = INPUT_VARIABLE(0); // input const auto Wx = INPUT_VARIABLE(1); // input weights const auto Wr = INPUT_VARIABLE(2); // recurrent weights int count = 3; - const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases - const auto seqLen = - hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector - const auto hI = - hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output - const auto cI = - hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state - const auto Wp = - hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights - - REQUIRE_TRUE(cellClip == 0, 0, - "LSTM_LAYER_MKLDNN operation: cell clipping is not supported " - "currently !"); - REQUIRE_TRUE( - retFullSeq, 0, - "LSTM_LAYER_MKLDNN operation: option to calculate full time sequence " - "output h should be always true in case of mkl dnn library !"); - REQUIRE_TRUE(hasPH == false, 0, - "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support " - "peephole connections !"); - REQUIRE_TRUE(hasSeqLen == false, 0, - "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support " - "array specifying max time step per each example in batch !"); - REQUIRE_TRUE( - dataFormat < 2, 0, - "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are " - "allowed for input/output tensors in mkl dnn library: TNC and NTC!"); - REQUIRE_TRUE(directionMode < 4, 0, - "LSTM_LAYER_MKLDNN operation: option for bidirectional extra " - "output dimension is not valid in mkl dnn library !"); - REQUIRE_TRUE(retLastH == retLastC, 0, - "LSTM_LAYER_MKLDNN operation: only two options are present: 1) " - "calculate both output at last time and cell state at last " - "time; 2) do not calculate both !"); - REQUIRE_TRUE(hasInitH == hasInitC, 0, - "LSTM_LAYER_MKLDNN operation: either both of or neither of " - "initial C and initial H must be provided"); + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + + REQUIRE_TRUE(cellClip == 0, 0, "LSTM_LAYER_MKLDNN operation: cell clipping is not supported " "currently !"); + REQUIRE_TRUE(retFullSeq, 0, "LSTM_LAYER_MKLDNN operation: option to calculate full time sequence output h should be always true in case of mkl dnn library !"); + REQUIRE_TRUE(hasPH == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support peephole connections !"); + REQUIRE_TRUE(hasSeqLen == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support array specifying max time step per each example in batch !"); + REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!"); + REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !"); + REQUIRE_TRUE(retLastH == retLastC, 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !"); + REQUIRE_TRUE(hasInitH == hasInitC, 0, "LSTM_LAYER_MKLDNN operation: either both of or neither of initial C and initial H must be provided"); count = 0; - auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output - auto hL = - retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step - auto cL = - retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step + auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output + auto hL = retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step + auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step // evaluate dimensions - const Nd4jLong sL = x->sizeAt(dataFormat); - const Nd4jLong bS = - dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0); - const Nd4jLong nIn = x->sizeAt(2); + const Nd4jLong sL = x->sizeAt(dataFormat); + const Nd4jLong bS = dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0); + const Nd4jLong nIn = x->sizeAt(2); const Nd4jLong nOut = Wx->sizeAt(-1) / 4; // inputs validations @@ -443,116 +342,66 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) { // Wx validation if (Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) - REQUIRE_TRUE(false, 0, - "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, " - "expected is %s, but got %s instead !", - ShapeUtils::shapeAsString({nIn, 4 * nOut}).c_str(), - ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); // Wr validation if (Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4 * nOut) - REQUIRE_TRUE(false, 0, - "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent " - "weights, expected is %s, but got %s instead !", - ShapeUtils::shapeAsString({nOut, 4 * nOut}).c_str(), - ShapeUtils::shapeAsString(Wr).c_str()); + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); // biases validation if (b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4 * nOut)) - REQUIRE_TRUE(false, 0, - "LSTM_LAYER_MKLDNN operation: wrong shape of biases, " - "expected is %s, but got %s instead !", - ShapeUtils::shapeAsString({4 * nOut}).c_str(), - ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4 * nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); // initial output validation - if (hI != nullptr && - (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, - "LSTM_LAYER_MKLDNN operation: wrong shape of initial " - "output, expected is %s, but got %s instead !", - ShapeUtils::shapeAsString({bS, nOut}).c_str(), - ShapeUtils::shapeAsString(hI).c_str()); + if (hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); // initial cell validation - if (cI != nullptr && - (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) - REQUIRE_TRUE(false, 0, - "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell " - "state, expected is %s, but got %s instead !", - ShapeUtils::shapeAsString({bS, nOut}).c_str(), - ShapeUtils::shapeAsString(cI).c_str()); - } else { // bidirectional + if (cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); + } + else { // bidirectional // Wx validation if (Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) - REQUIRE_TRUE(false, 0, - "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, " - "expected is %s, but got %s instead !", - ShapeUtils::shapeAsString({2, nIn, 4 * nOut}).c_str(), - ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); // Wr validation - if (Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || - Wr->sizeAt(2) != 4 * nOut) - REQUIRE_TRUE(false, 0, - "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent " - "weights, expected is %s, but got %s instead !", - ShapeUtils::shapeAsString({2, nOut, 4 * nOut}).c_str(), - ShapeUtils::shapeAsString(Wr).c_str()); + if (Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4 * nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); // biases validation - if (b != nullptr && - (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4 * nOut)) - REQUIRE_TRUE(false, 0, - "LSTM_LAYER_MKLDNN operation: wrong shape of biases, " - "expected is %s, but got %s instead !", - ShapeUtils::shapeAsString({2, 4 * nOut}).c_str(), - ShapeUtils::shapeAsString(b).c_str()); + if (b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4 * nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); // initial output validation - if (hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || - hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, - "LSTM_LAYER_MKLDNN operation: wrong shape of initial " - "output, expected is %s, but got %s instead !", - ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), - ShapeUtils::shapeAsString(hI).c_str()); + if (hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); // initial cell validation - if (cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || - cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) - REQUIRE_TRUE(false, 0, - "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell " - "state, expected is %s, but got %s instead !", - ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), - ShapeUtils::shapeAsString(cI).c_str()); + if (cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); } - std::vector params = {static_cast(dataFormat), - static_cast(directionMode), - static_cast(cellClip)}; + std::vector params = {static_cast(dataFormat), static_cast(directionMode), static_cast(cellClip)}; - const int dirDim = - directionMode < 2 - ? 1 - : 2; // number of dimensions, 1 unidirectional, 2 for bidirectional + const int dirDim = directionMode < 2 ? 1 : 2; // number of dimensions, 1 unidirectional, 2 for bidirectional // permut x and h to tnc format if they have ntc format NDArray *xP(const_cast(x)), *hP(h); if (dataFormat == 1) { xP = new NDArray(x->permute({1, 0, 2})); // [bS, sL, nIn] -> [sL, bS, nIn] - hP = new NDArray( - h->permute({1, 0, 2})); // [bS, sL, dirDim*nOn] -> [sL, bS, dirDim*nOn] + hP = new NDArray(h->permute({1, 0, 2})); // [bS, sL, dirDim*nOn] -> [sL, bS, dirDim*nOn] } // reshape arrays in accordance to mkl allowed formats - NDArray *WxR(nullptr), *WrR(nullptr), *bR(nullptr), *hIR(nullptr), - *cIR(nullptr), *hLR(nullptr), *cLR(nullptr); + NDArray *WxR(nullptr), *WrR(nullptr), *bR(nullptr), *hIR(nullptr), *cIR(nullptr), *hLR(nullptr), *cLR(nullptr); WxR = new NDArray(Wx->reshape(Wx->ordering(), {1, dirDim, nIn, 4, nOut})); WrR = new NDArray(Wr->reshape(Wr->ordering(), {1, dirDim, nOut, 4, nOut})); - if (b) bR = new NDArray(b->reshape(b->ordering(), {1, dirDim, 4, nOut})); + if (b) + bR = new NDArray(b->reshape(b->ordering(), {1, dirDim, 4, nOut})); else - bR = new NDArray(x->ordering(), {1,dirDim,4,nOut}, x->dataType(), x->getContext()); // already nullifiedif (hI) hIR = new NDArray(hI->reshape(hI->ordering(), {1, dirDim, bS, nOut})); - if (cI) cIR = new NDArray(cI->reshape(cI->ordering(), {1, dirDim, bS, nOut})); + bR = new NDArray(x->ordering(), {1,dirDim,4,nOut}, x->dataType(), x->getContext()); // already nullifiedif (hI) hIR = new NDArray(hI->reshape(hI->ordering(), {1, dirDim, bS, nOut})); + if(hI) + hIR = new NDArray(hI->reshape(hI->ordering(), {1,dirDim,bS,nOut})); + if (cI) + cIR = new NDArray(cI->reshape(cI->ordering(), {1, dirDim, bS, nOut})); if (hL) - hLR = - new NDArray(hL->reshape(hL->ordering(), {1, dirDim, bS, nOut}, false)); + hLR = new NDArray(hL->reshape(hL->ordering(), {1, dirDim, bS, nOut}, false)); if (cL) - cLR = - new NDArray(cL->reshape(cL->ordering(), {1, dirDim, bS, nOut}, false)); + cLR = new NDArray(cL->reshape(cL->ordering(), {1, dirDim, bS, nOut}, false)); lstmLayerMKLDNN(xP, WxR, WrR, bR, hIR, cIR, params, hP, hLR, cLR); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp index 52f52feda2b5..2763e8fe2ffb 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp @@ -85,6 +85,7 @@ void setBlockStrides(const NDArray& array, dnnl::memory::desc& mklMd, const std: } } } + //////////////////////////////////////////////////////////////////////////////////////////////// dnnl::memory loadDataToMklStream(const NDArray& array, const dnnl::engine& engine, const dnnl::stream& stream, @@ -97,7 +98,8 @@ dnnl::memory loadDataToMklStream(const NDArray& array, const dnnl::engine& engin auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem; if (bReorder) dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem); - arg = mkl_mem;return user_mem; + arg = mkl_mem; + return user_mem; } ////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 680dbcea4230..0c62808c4ca6 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -1279,20 +1279,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_1) { const int dataFormat = 0; // [sL,bS,nIn] const int directionMode = 0; // forward - const int gateAct = - 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output const bool hasBiases = true; // biases array is provided const bool hasSeqLen = false; // seqLen array is not provided const auto hasInitH = true; // initial output is provided const auto hasInitC = true; // initial cell state is provided const auto hasPH = false; // peephole connections are absent - const auto retFullSeq = - true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = true; // do not return output at last time step - const auto retLastC = true; // return cells state at last time step + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step const double cellClip = 0; // do not apply clipping From 5828e6402f7ed84d0af0ae0a31349efcac8f2dcf Mon Sep 17 00:00:00 2001 From: Yurii Date: Thu, 18 Jun 2020 18:34:30 +0300 Subject: [PATCH 201/233] - first attempt to combine all ops belonging to loop into one opSequence Signed-off-by: Yurii --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 38 +++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 7faeeb95e926..a185c2141464 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -154,19 +154,51 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS for (const auto& nextId : workMap[id]._out) visit(nextId, 1, numOfLayers); + // gather all nodes belonging to loop and store them in one OpSequence + const char delimiter = '/'; + for (auto p0 = workMap.begin(); p0 != std::prev(workMap.end()); ++p0) { // std::prev(workMap.end()) == workMap.end() - 1 + if(!p0->second._opSeq.empty() && p0->second._opSeq[0] == -1) + continue; + const auto& name = _nodesMap[p0->first].name(); + if(name.find("Enter") == std::string::npos) + continue; + std::string loopName = name.substr(0, name.find(delimiter)); // evaluate name of loop + for (auto p1 = std::next(p0); p1 != workMap.end(); ++p1) { // std::next(p0) = p0 + 1 + if(_nodesMap[p1->first].name().find(loopName) == std::string::npos) + continue; + p0->second._opSeq.push_back(p1->first); + p0->second._opSeq.insert(p0->second._opSeq.end(), p1->second._opSeq.begin(), p1->second._opSeq.end()); + p1->second._opSeq.clear(); + p1->second._opSeq.push_back(-1); // mark node to be neglected + // p1->second._layerNum = p0->second._layerNum; + } + } + // fill vectors with layers - std::vector sortedGraphTemp(numOfLayers+1); + std::vector sortedGraphTemp;//(numOfLayers+1); for (const auto& p : workMap) { + if(!p.second._opSeq.empty() && p.second._opSeq[0] == -1) + continue; + OpSequence seq; seq.append(_nodesMap.at(p.first), _nodesMap.at(p.first).contextPrototype()); for (const auto& id : p.second._opSeq) seq.append(_nodesMap.at(id), _nodesMap.at(id).contextPrototype()); + while(sortedGraphTemp.size() <= p.second._layerNum) + sortedGraphTemp.emplace_back(ExecutionLayer()); + sortedGraphTemp[p.second._layerNum].append(std::move(seq)); + } + // delete empty layers + for (auto it = sortedGraphTemp.begin(); it != sortedGraphTemp.end(); ++it) + if(it->width() == 0) + sortedGraphTemp.erase(it--); + // check whether there are layers with one OpSequence which in turn contains only one op bool isLayerWithOneOp = false; for(auto& layer : sortedGraphTemp) { @@ -230,7 +262,7 @@ void OptimizedGraph::printOut() const { printf("}\n"); } - /* + printf("And simple print:\n"); for (int i = 0; i < _sortedGraph.size(); ++i) { printf("layer %i: ", i); @@ -243,7 +275,7 @@ void OptimizedGraph::printOut() const { } printf("\n"); } - */ + } From d4b79ac46fa1dab43f483d1077aee73100434ad7 Mon Sep 17 00:00:00 2001 From: Yurii Date: Thu, 18 Jun 2020 19:56:43 +0300 Subject: [PATCH 202/233] - correct alg for cycle ops Signed-off-by: Yurii --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 28 +++++++++++++++++-- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index a185c2141464..7db7b14c8987 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -159,11 +159,33 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS for (auto p0 = workMap.begin(); p0 != std::prev(workMap.end()); ++p0) { // std::prev(workMap.end()) == workMap.end() - 1 if(!p0->second._opSeq.empty() && p0->second._opSeq[0] == -1) continue; - const auto& name = _nodesMap[p0->first].name(); - if(name.find("Enter") == std::string::npos) + // const auto& name = _nodesMap[p0->first].name(); + // if(name.find("Enter") == std::string::npos) + // continue; + bool isInLoop = false; + auto* name = &_nodesMap[p0->first].name(); + + if(name->find("Enter") == std::string::npos) { + for (const auto& id : p0->second._opSeq) { + if(id == -1) + break; + name = &_nodesMap[id].name(); + if(name->find("Enter") != std::string::npos) { + isInLoop = true; + break; + } + } + } + else + isInLoop = true; + + if(!isInLoop) continue; - std::string loopName = name.substr(0, name.find(delimiter)); // evaluate name of loop + + std::string loopName = name->substr(0, name->find(delimiter)); // evaluate name of loop for (auto p1 = std::next(p0); p1 != workMap.end(); ++p1) { // std::next(p0) = p0 + 1 + if(!p1->second._opSeq.empty() && p1->second._opSeq[0] == -1) + continue; if(_nodesMap[p1->first].name().find(loopName) == std::string::npos) continue; p0->second._opSeq.push_back(p1->first); From aeb96b9f48092cd8276b06f652cbabd023d04713 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 19 Jun 2020 07:41:36 +0300 Subject: [PATCH 203/233] few more assertions Signed-off-by: raver119@gmail.com --- libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 0feba485afe1..95fe175c7681 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -1063,6 +1063,11 @@ TEST_F(GraphAnalysisTests, optimizedGraph_cond2) { TEST_F(GraphAnalysisTests, optimizedGraph_while1) { auto graph = Graph::fromFlatBuffers("resources/while_iter1.fb"); const auto& optimized = graph.optimizedGraph(); + + // this Graph must have exactly 1 Layer and 1 OpSequence, since all it has is While loop + ASSERT_EQ(1, optimized.layers()); + ASSERT_EQ(1, optimized.layer(0).width()); + graph.printOut(); } From 2971d0f2f583e5f1ad190588e00a74682dd6f4a0 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 19 Jun 2020 07:55:15 +0300 Subject: [PATCH 204/233] Outputs added to Graph::printOut Signed-off-by: raver119@gmail.com --- .../graph/execution/impl/ExecutionTask.cpp | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp index 4445ea981b25..5677852bc28f 100644 --- a/libnd4j/include/graph/execution/impl/ExecutionTask.cpp +++ b/libnd4j/include/graph/execution/impl/ExecutionTask.cpp @@ -59,7 +59,7 @@ void ExecutionTask::printOut() const { if (sz) { printf(" Inputs: ["); int cnt = 0; - for (const auto &v : _context.inputs()) { + for (const auto &v : _node.inputs()) { printf("<%i:%i>", v.first, v.second); if (cnt < sz - 1) printf(", "); @@ -71,6 +71,22 @@ void ExecutionTask::printOut() const { printf(" No inputs; "); } + sz = _node.outputs().size(); + if (sz) { + printf(" Outputs: ["); + int cnt = 0; + for (const auto &v : _node.outputs()) { + printf("<%i:%i>", v.first, v.second); + + if (cnt < sz - 1) printf(", "); + cnt++; + } + + printf("]; "); + } else { + printf(" No outputs; "); + } + printf("\n"); fflush(stdout); } From 9b8b41d30ac283098567464b1a3dd182879fc53b Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 19 Jun 2020 10:38:39 +0300 Subject: [PATCH 205/233] next step Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/execution/Stack.h | 3 + libnd4j/include/graph/execution/StackFrame.h | 7 +- .../graph/execution/impl/GraphExecutor.cpp | 2 +- .../include/graph/execution/impl/Stack.cpp | 10 ++- .../graph/execution/impl/StackFrame.cpp | 6 +- libnd4j/include/graph/logic/LogicEnter.h | 2 +- libnd4j/include/graph/logic/LogicExecutor.h | 4 +- libnd4j/include/graph/logic/LogicExit.h | 2 +- libnd4j/include/graph/logic/LogicMerge.h | 2 +- .../include/graph/logic/LogicNextIteration.h | 2 +- libnd4j/include/graph/logic/LogicSwitch.h | 2 +- .../include/graph/logic/impl/LogicEnter.cpp | 65 ++++++------------- .../graph/logic/impl/LogicExecutor.cpp | 12 ++-- .../include/graph/logic/impl/LogicExit.cpp | 2 +- .../include/graph/logic/impl/LogicMerge.cpp | 5 +- .../graph/logic/impl/LogicNextIteration.cpp | 2 +- .../include/graph/logic/impl/LogicSwitch.cpp | 5 +- 17 files changed, 67 insertions(+), 66 deletions(-) diff --git a/libnd4j/include/graph/execution/Stack.h b/libnd4j/include/graph/execution/Stack.h index fbebc04d0bbe..9fc04a874056 100644 --- a/libnd4j/include/graph/execution/Stack.h +++ b/libnd4j/include/graph/execution/Stack.h @@ -41,6 +41,9 @@ class SD_EXPORT Stack { StackFrame& root(); const VariableProxy& rootVariableSpace() const; + + void openFrame(const std::string &frameName); + void closeFrame(); }; } // namespace graph diff --git a/libnd4j/include/graph/execution/StackFrame.h b/libnd4j/include/graph/execution/StackFrame.h index 9adb4ff2814d..e3216c3bbe8f 100644 --- a/libnd4j/include/graph/execution/StackFrame.h +++ b/libnd4j/include/graph/execution/StackFrame.h @@ -23,6 +23,7 @@ #include #include +#include namespace sd { namespace graph { @@ -32,14 +33,18 @@ class SD_EXPORT StackFrame { VariableProxy _proxy; MAP_IMPL _disabledNodes; + + std::string _frameName; public: - explicit StackFrame(VariableProxy &proxy); + explicit StackFrame(const VariableProxy &proxy, const std::string &frameName); ~StackFrame() = default; const VariableProxy& variableProxy() const { return _proxy; } void disableNode(int nodeId); bool isDisabled(int nodeId) const; + + const std::string& frameName() const; }; } // namespace graph diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 66286bd7d7a2..28e6ac092f68 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -85,7 +85,7 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, v.node().id(), v.node().name().empty() ? "" : v.node().name().c_str()); - LogicExecutor::processNode(&v.node(), f, graph); + LogicExecutor::processNode(&v.node(), stack, graph); } else if (v.node().hasCustomOp()) { // we skip all disabled nodes if (f.isDisabled(v.node().id())) diff --git a/libnd4j/include/graph/execution/impl/Stack.cpp b/libnd4j/include/graph/execution/impl/Stack.cpp index 0abbde8cbc0e..2b58adf411cc 100644 --- a/libnd4j/include/graph/execution/impl/Stack.cpp +++ b/libnd4j/include/graph/execution/impl/Stack.cpp @@ -24,7 +24,7 @@ namespace sd { namespace graph { Stack::Stack(const VariableProxy &root) { - _frames.push_back(StackFrame(const_cast(root))); + _frames.push_back(StackFrame(const_cast(root), "defaultFrame")); } const VariableProxy &Stack::rootVariableSpace() const { @@ -43,5 +43,13 @@ StackFrame &Stack::root() { return _frames.front(); } +void Stack::openFrame(const std::string &frameName) { + _frames.emplace_back(StackFrame(_frames.back().variableProxy(), frameName)); +} + +void Stack::closeFrame() { + +} + } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/execution/impl/StackFrame.cpp b/libnd4j/include/graph/execution/impl/StackFrame.cpp index fb84551270b9..329330d23379 100644 --- a/libnd4j/include/graph/execution/impl/StackFrame.cpp +++ b/libnd4j/include/graph/execution/impl/StackFrame.cpp @@ -24,7 +24,7 @@ namespace sd { namespace graph { -StackFrame::StackFrame(VariableProxy &proxy) : _proxy(proxy) { } +StackFrame::StackFrame(const VariableProxy &proxy, const std::string &frameName) : _proxy(proxy), _frameName(frameName) { } void StackFrame::disableNode(int nodeId) { _disabledNodes[nodeId] = 1; @@ -34,5 +34,9 @@ bool StackFrame::isDisabled(int nodeId) const { return _disabledNodes.count(nodeId) > 0; } +const std::string& StackFrame::frameName() const { + return _frameName; +} + } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicEnter.h b/libnd4j/include/graph/logic/LogicEnter.h index 021f63ec9b66..f341595ce34f 100644 --- a/libnd4j/include/graph/logic/LogicEnter.h +++ b/libnd4j/include/graph/logic/LogicEnter.h @@ -29,7 +29,7 @@ namespace sd { namespace graph { class LogicEnter { public: - static Nd4jStatus processNode(const Node* node); + static Nd4jStatus processNode(const Node* node, Stack &stack, const OptimizedGraph& graph); }; } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicExecutor.h b/libnd4j/include/graph/logic/LogicExecutor.h index bc696ac05325..83a4e2a3e6e1 100644 --- a/libnd4j/include/graph/logic/LogicExecutor.h +++ b/libnd4j/include/graph/logic/LogicExecutor.h @@ -26,7 +26,7 @@ #include #include #include -#include +#include namespace sd { namespace graph { @@ -37,7 +37,7 @@ namespace graph { */ class LogicExecutor { public: - static Nd4jStatus processNode(const Node* node, StackFrame &frame, const OptimizedGraph& graph); + static Nd4jStatus processNode(const Node* node, Stack &stack, const OptimizedGraph& graph); }; } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicExit.h b/libnd4j/include/graph/logic/LogicExit.h index cc1e8dfc99c1..82e4e134f086 100644 --- a/libnd4j/include/graph/logic/LogicExit.h +++ b/libnd4j/include/graph/logic/LogicExit.h @@ -29,7 +29,7 @@ namespace sd { namespace graph { class LogicExit { public: - static Nd4jStatus processNode(const Node* node); + static Nd4jStatus processNode(const Node* node, Stack &stack, const OptimizedGraph& graph); }; } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/LogicMerge.h b/libnd4j/include/graph/logic/LogicMerge.h index 3130f1a97354..541b7396bff2 100644 --- a/libnd4j/include/graph/logic/LogicMerge.h +++ b/libnd4j/include/graph/logic/LogicMerge.h @@ -30,7 +30,7 @@ namespace graph { class LogicMerge { public: - static Nd4jStatus processNode(const Node* node, StackFrame &frame, const OptimizedGraph& graph); + static Nd4jStatus processNode(const Node* node, Stack &stack, const OptimizedGraph& graph); }; } // namespace graph diff --git a/libnd4j/include/graph/logic/LogicNextIteration.h b/libnd4j/include/graph/logic/LogicNextIteration.h index 13752130fe60..aeb2b16bd7e3 100644 --- a/libnd4j/include/graph/logic/LogicNextIteration.h +++ b/libnd4j/include/graph/logic/LogicNextIteration.h @@ -30,7 +30,7 @@ namespace graph { class LogicNextIeration { public: - static Nd4jStatus processNode(const Node* node); + static Nd4jStatus processNode(const Node* node, Stack &stack, const OptimizedGraph& graph); }; } // namespace graph diff --git a/libnd4j/include/graph/logic/LogicSwitch.h b/libnd4j/include/graph/logic/LogicSwitch.h index 93bf514a12b9..e34d475a501d 100644 --- a/libnd4j/include/graph/logic/LogicSwitch.h +++ b/libnd4j/include/graph/logic/LogicSwitch.h @@ -38,7 +38,7 @@ namespace graph { */ class LogicSwitch { public: - static Nd4jStatus processNode(const Node* node, StackFrame &frame, const OptimizedGraph& graph); + static Nd4jStatus processNode(const Node* node, Stack &stack, const OptimizedGraph& graph); }; } // namespace graph diff --git a/libnd4j/include/graph/logic/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp index db52b4b732c0..89e9b0a235b3 100644 --- a/libnd4j/include/graph/logic/impl/LogicEnter.cpp +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -24,58 +24,33 @@ namespace sd { namespace graph { -Nd4jStatus LogicEnter::processNode(const Node *node) { - throw std::runtime_error("LogicEnter::processNode - not implemented yet"); - /* - // this op replicates input variable into the frame. basically happens once - for single loop. - // sure, if there's inner loop within outer loop, it'll be called once for - outer loop and multiple times for inner loop - auto __variableSpace = graph->variableSpace(); - auto __flowPath = __variableSpace->flowPath(); +/** + * This function does 2 things: + * - Propagates input variable + * - Opens new StackFrame + */ +Nd4jStatus LogicEnter::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { + // if current frameName isn't equal to node frame name - we'll open new StackFrame then + // FIXME: instead of node->name() we should use proper + if (node->name() != stack.back().frameName()) + stack.openFrame(node->name()); - // basically, first non-null variable is our target - for (int e = 0; e < node->input()->size(); e++) { - auto inputAddr = node->input()->at(e); + const auto &frame = stack.back(); - if (__variableSpace->hasVariable(inputAddr)) { - auto var = __variableSpace->getVariable(inputAddr); - if (var->hasNDArray()) { - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName().c_str(), - node->id(), 0); + const auto &inputs = node->inputs(); + auto &varSpace = const_cast(frame.variableProxy()); - auto array = var->getNDArray(); - lvar->setNDArray(array); - lvar->markReadOnly(true); + // validate Node state + REQUIRE_TRUE(inputs.size() == 1, 0, "Enter: op must have exactly 1 inputs"); + REQUIRE_TRUE(varSpace.hasVariable(inputs[0]), 0, "Enter: input Variable doesn't exist"); - break; - } else if (var->hasNDArrayList()) { - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName().c_str(), - node->id(), 0); - - auto list = var->getNDArrayList(); - lvar->setNDArrayList(list); - lvar->markReadOnly(true); - - break; - } else { - // FIXME: can we really have third case here? - continue; - } - } - } + // now we propagate input as ouwn output + auto input = varSpace.getVariable(inputs[0]); + varSpace.putVariable(std::pair{node->id(), 0}, *input->getNDArray()); return sd::Status::OK(); - */ } + } // namespace graph } // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp index 22fc63ea726b..38a7a309f6ae 100644 --- a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp @@ -34,7 +34,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicExecutor::processNode(const Node *node, StackFrame &frame, const OptimizedGraph& graph) { +Nd4jStatus LogicExecutor::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { switch (node->opNum()) { case sd::logic::While: return LogicWhile::processNode(node); @@ -43,21 +43,21 @@ Nd4jStatus LogicExecutor::processNode(const Node *node, StackFrame &frame, const case sd::logic::Conditional: return LogicConditional::processNode(node); case sd::logic::Switch: - return LogicSwitch::processNode(node, frame, graph); + return LogicSwitch::processNode(node, stack, graph); case sd::logic::Return: return LogicReturn::processNode(node); case sd::logic::Expose: return LogicExpose::processNode(node); case sd::logic::Merge: - return LogicMerge::processNode(node, frame, graph); + return LogicMerge::processNode(node, stack, graph); case sd::logic::LoopCond: return LogicLoopCond::processNode(node); case sd::logic::NextIteration: - return LogicNextIeration::processNode(node); + return LogicNextIeration::processNode(node, stack, graph); case sd::logic::Exit: - return LogicExit::processNode(node); + return LogicExit::processNode(node, stack, graph); case sd::logic::Enter: - return LogicEnter::processNode(node); + return LogicEnter::processNode(node, stack, graph); } if (node->name().empty()) { diff --git a/libnd4j/include/graph/logic/impl/LogicExit.cpp b/libnd4j/include/graph/logic/impl/LogicExit.cpp index c5329a1caaf9..421c15cb084a 100644 --- a/libnd4j/include/graph/logic/impl/LogicExit.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExit.cpp @@ -23,7 +23,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicExit::processNode(const Node *node) { +Nd4jStatus LogicExit::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { // this op is basically no-op // we just know it exists throw std::runtime_error("LogicExit::processNode - Not implemented yet"); diff --git a/libnd4j/include/graph/logic/impl/LogicMerge.cpp b/libnd4j/include/graph/logic/impl/LogicMerge.cpp index b99a83426005..3858f75f13fb 100644 --- a/libnd4j/include/graph/logic/impl/LogicMerge.cpp +++ b/libnd4j/include/graph/logic/impl/LogicMerge.cpp @@ -24,7 +24,10 @@ namespace sd { namespace graph { -Nd4jStatus LogicMerge::processNode(const Node *node, StackFrame &frame, const OptimizedGraph& graph) { +Nd4jStatus LogicMerge::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { + // getting current frame first + auto &frame = stack.back(); + const auto &inputs = node->inputs(); auto &varSpace = const_cast(frame.variableProxy()); diff --git a/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp index fd59336ad85c..219c98d742d6 100644 --- a/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp +++ b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp @@ -23,7 +23,7 @@ namespace sd { namespace graph { -Nd4jStatus LogicNextIeration::processNode(const Node *node) { +Nd4jStatus LogicNextIeration::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { throw std::runtime_error( "LogicNextIeration::processNode - not implemented yet"); /* diff --git a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp index 189fa4373f43..c67a8181d474 100644 --- a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp +++ b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp @@ -76,7 +76,10 @@ static void disableBranch(StackFrame &frame, VariableProxy &varSpace, const Opti } } -Nd4jStatus LogicSwitch::processNode(const Node* node, StackFrame &frame, const OptimizedGraph& graph) { +Nd4jStatus LogicSwitch::processNode(const Node* node, Stack &stack, const OptimizedGraph& graph) { + // getting current frame first + auto &frame = stack.back(); + const auto &inputs = node->inputs(); const auto &outputs = node->outputs(); From 1ab388c9c77eb0480959e16a12f114fb7fae2bc6 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 19 Jun 2020 11:10:08 +0300 Subject: [PATCH 206/233] Numeric Node frameId Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/Node.h | 6 ++++++ libnd4j/include/graph/execution/Stack.h | 2 +- libnd4j/include/graph/execution/StackFrame.h | 6 +++--- libnd4j/include/graph/execution/impl/Stack.cpp | 6 +++--- .../include/graph/execution/impl/StackFrame.cpp | 6 +++--- libnd4j/include/graph/impl/Node.cpp | 14 +++++++++++++- libnd4j/include/graph/logic/impl/LogicEnter.cpp | 5 ++--- .../tests_cpu/layers_tests/GraphAnalysisTests.cpp | 2 ++ 8 files changed, 33 insertions(+), 14 deletions(-) diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index b872c283c09c..e5451c8337e1 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -70,6 +70,9 @@ class SD_EXPORT Node { std::shared_ptr _customOp; + // this field is for Enter nodes only + int _frameId = -1; + public: explicit Node(const sd::ops::DeclarableOp &op, const std::string &nodeName = {}, @@ -176,6 +179,9 @@ class SD_EXPORT Node { // this method converts string deps to int deps void actualizeDependencies(const MAP_IMPL &lookupTable) const; + int frameId() const; + void setFrameId(int frameId); + // utility method that generates legacy ops out of OpType and OpNum static std::shared_ptr buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum); }; diff --git a/libnd4j/include/graph/execution/Stack.h b/libnd4j/include/graph/execution/Stack.h index 9fc04a874056..30da79fecf0f 100644 --- a/libnd4j/include/graph/execution/Stack.h +++ b/libnd4j/include/graph/execution/Stack.h @@ -42,7 +42,7 @@ class SD_EXPORT Stack { const VariableProxy& rootVariableSpace() const; - void openFrame(const std::string &frameName); + void openFrame(int frameId); void closeFrame(); }; diff --git a/libnd4j/include/graph/execution/StackFrame.h b/libnd4j/include/graph/execution/StackFrame.h index e3216c3bbe8f..daf59a61bef5 100644 --- a/libnd4j/include/graph/execution/StackFrame.h +++ b/libnd4j/include/graph/execution/StackFrame.h @@ -34,9 +34,9 @@ class SD_EXPORT StackFrame { MAP_IMPL _disabledNodes; - std::string _frameName; + int _frameId; public: - explicit StackFrame(const VariableProxy &proxy, const std::string &frameName); + explicit StackFrame(const VariableProxy &proxy, int frameId); ~StackFrame() = default; const VariableProxy& variableProxy() const { return _proxy; } @@ -44,7 +44,7 @@ class SD_EXPORT StackFrame { void disableNode(int nodeId); bool isDisabled(int nodeId) const; - const std::string& frameName() const; + int frameId() const; }; } // namespace graph diff --git a/libnd4j/include/graph/execution/impl/Stack.cpp b/libnd4j/include/graph/execution/impl/Stack.cpp index 2b58adf411cc..58f1a165b452 100644 --- a/libnd4j/include/graph/execution/impl/Stack.cpp +++ b/libnd4j/include/graph/execution/impl/Stack.cpp @@ -24,7 +24,7 @@ namespace sd { namespace graph { Stack::Stack(const VariableProxy &root) { - _frames.push_back(StackFrame(const_cast(root), "defaultFrame")); + _frames.push_back(StackFrame(const_cast(root), -1)); } const VariableProxy &Stack::rootVariableSpace() const { @@ -43,8 +43,8 @@ StackFrame &Stack::root() { return _frames.front(); } -void Stack::openFrame(const std::string &frameName) { - _frames.emplace_back(StackFrame(_frames.back().variableProxy(), frameName)); +void Stack::openFrame(int frameId) { + _frames.emplace_back(StackFrame(_frames.back().variableProxy(), frameId)); } void Stack::closeFrame() { diff --git a/libnd4j/include/graph/execution/impl/StackFrame.cpp b/libnd4j/include/graph/execution/impl/StackFrame.cpp index 329330d23379..20798ac3b8f2 100644 --- a/libnd4j/include/graph/execution/impl/StackFrame.cpp +++ b/libnd4j/include/graph/execution/impl/StackFrame.cpp @@ -24,7 +24,7 @@ namespace sd { namespace graph { -StackFrame::StackFrame(const VariableProxy &proxy, const std::string &frameName) : _proxy(proxy), _frameName(frameName) { } +StackFrame::StackFrame(const VariableProxy &proxy, int frameId) : _proxy(proxy), _frameId(frameId) { } void StackFrame::disableNode(int nodeId) { _disabledNodes[nodeId] = 1; @@ -34,8 +34,8 @@ bool StackFrame::isDisabled(int nodeId) const { return _disabledNodes.count(nodeId) > 0; } -const std::string& StackFrame::frameName() const { - return _frameName; +int StackFrame::frameId() const { + return _frameId; } } // namespace graph diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index fe256f82fc18..f5bd3ad2ea55 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -367,6 +367,14 @@ Node::Node(OpType opType, int opNum, int id, std::initializer_list input, } }; +int Node::frameId() const { + return _frameId; +} + +void Node::setFrameId(int frameId) { + _frameId = frameId; +} + Node::Node(const FlatNode *node) { // temporary holders _dimensions, for transferring axis into ContextPrototype std::vector axis; @@ -428,7 +436,7 @@ Node::Node(const FlatNode *node) { if (node->extraInteger()->size() < 1) throw std::runtime_error("Enter Node [" + StringUtils::valueToString(this->id()) + "] must have FrameID specified"); - //this->setFrameId(node->extraInteger()->Get(0)); + this->setFrameId(node->extraInteger()->Get(0)); } // these ops allow in-place execution by design @@ -584,6 +592,7 @@ Node::Node(const Node &other) noexcept { _customOp = other._customOp; _name = other._name; _id = other._id; + _frameId = other._frameId; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; @@ -606,6 +615,7 @@ Node &Node::operator=(const Node &other) noexcept { _customOp = other._customOp; _name = other._name; _id = other._id; + _frameId = other._frameId; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; @@ -629,6 +639,7 @@ Node::Node(Node &&other) noexcept { _customOp = other._customOp; _name = std::move(other._name); _id = other._id; + _frameId = other._frameId; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; @@ -651,6 +662,7 @@ Node &Node::operator=(Node &&other) noexcept { _customOp = other._customOp; _name = std::move(other._name); _id = other._id; + _frameId = other._frameId; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; diff --git a/libnd4j/include/graph/logic/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp index 89e9b0a235b3..12536730fc16 100644 --- a/libnd4j/include/graph/logic/impl/LogicEnter.cpp +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -32,9 +32,8 @@ namespace graph { */ Nd4jStatus LogicEnter::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { // if current frameName isn't equal to node frame name - we'll open new StackFrame then - // FIXME: instead of node->name() we should use proper - if (node->name() != stack.back().frameName()) - stack.openFrame(node->name()); + if (node->frameId() != stack.back().frameId()) + stack.openFrame(node->frameId()); const auto &frame = stack.back(); diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 95fe175c7681..35eb069708b5 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -1069,6 +1069,8 @@ TEST_F(GraphAnalysisTests, optimizedGraph_while1) { ASSERT_EQ(1, optimized.layer(0).width()); graph.printOut(); + + //graph.execute(); } TEST_F(GraphAnalysisTests, optimizedGraph_nested_while_1) { From b13edaedb798d097a18f9cc8573c5880d77a1d0e Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 19 Jun 2020 12:37:40 +0300 Subject: [PATCH 207/233] Numeric Node frameId Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/Node.h | 6 ++++- libnd4j/include/graph/execution/Stack.h | 2 +- libnd4j/include/graph/execution/StackFrame.h | 9 +++++-- .../include/graph/execution/impl/Stack.cpp | 6 ++--- .../graph/execution/impl/StackFrame.cpp | 10 ++++++- libnd4j/include/graph/impl/Node.cpp | 12 +++++++++ .../include/graph/logic/impl/LogicEnter.cpp | 15 +++++++---- .../include/graph/logic/impl/LogicExit.cpp | 27 +++++-------------- 8 files changed, 54 insertions(+), 33 deletions(-) diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index e5451c8337e1..de04f6194d62 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -71,7 +71,8 @@ class SD_EXPORT Node { std::shared_ptr _customOp; // this field is for Enter nodes only - int _frameId = -1; + mutable int _frameId = -1; + mutable int _exitId = -1; public: explicit Node(const sd::ops::DeclarableOp &op, @@ -182,6 +183,9 @@ class SD_EXPORT Node { int frameId() const; void setFrameId(int frameId); + int exitId() const; + void setExitId(int exitId) const; + // utility method that generates legacy ops out of OpType and OpNum static std::shared_ptr buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum); }; diff --git a/libnd4j/include/graph/execution/Stack.h b/libnd4j/include/graph/execution/Stack.h index 30da79fecf0f..97e0fe50e6d8 100644 --- a/libnd4j/include/graph/execution/Stack.h +++ b/libnd4j/include/graph/execution/Stack.h @@ -42,7 +42,7 @@ class SD_EXPORT Stack { const VariableProxy& rootVariableSpace() const; - void openFrame(int frameId); + void openFrame(int frameId, int enterId); void closeFrame(); }; diff --git a/libnd4j/include/graph/execution/StackFrame.h b/libnd4j/include/graph/execution/StackFrame.h index daf59a61bef5..874e5d579c15 100644 --- a/libnd4j/include/graph/execution/StackFrame.h +++ b/libnd4j/include/graph/execution/StackFrame.h @@ -34,9 +34,12 @@ class SD_EXPORT StackFrame { MAP_IMPL _disabledNodes; - int _frameId; + // these fields are used + int _frameId = -119; + int _enterId = -119; + mutable int _exitId = -119; public: - explicit StackFrame(const VariableProxy &proxy, int frameId); + explicit StackFrame(const VariableProxy &proxy, int frameId, int enterId); ~StackFrame() = default; const VariableProxy& variableProxy() const { return _proxy; } @@ -45,6 +48,8 @@ class SD_EXPORT StackFrame { bool isDisabled(int nodeId) const; int frameId() const; + int enterId() const; + int exitId() const; }; } // namespace graph diff --git a/libnd4j/include/graph/execution/impl/Stack.cpp b/libnd4j/include/graph/execution/impl/Stack.cpp index 58f1a165b452..4e208649211e 100644 --- a/libnd4j/include/graph/execution/impl/Stack.cpp +++ b/libnd4j/include/graph/execution/impl/Stack.cpp @@ -24,7 +24,7 @@ namespace sd { namespace graph { Stack::Stack(const VariableProxy &root) { - _frames.push_back(StackFrame(const_cast(root), -1)); + _frames.push_back(StackFrame(const_cast(root), -1, 0)); } const VariableProxy &Stack::rootVariableSpace() const { @@ -43,8 +43,8 @@ StackFrame &Stack::root() { return _frames.front(); } -void Stack::openFrame(int frameId) { - _frames.emplace_back(StackFrame(_frames.back().variableProxy(), frameId)); +void Stack::openFrame(int frameId, int enterId) { + _frames.emplace_back(StackFrame(_frames.back().variableProxy(), frameId, enterId)); } void Stack::closeFrame() { diff --git a/libnd4j/include/graph/execution/impl/StackFrame.cpp b/libnd4j/include/graph/execution/impl/StackFrame.cpp index 20798ac3b8f2..a4e6c98079dc 100644 --- a/libnd4j/include/graph/execution/impl/StackFrame.cpp +++ b/libnd4j/include/graph/execution/impl/StackFrame.cpp @@ -24,7 +24,7 @@ namespace sd { namespace graph { -StackFrame::StackFrame(const VariableProxy &proxy, int frameId) : _proxy(proxy), _frameId(frameId) { } +StackFrame::StackFrame(const VariableProxy &proxy, int frameId, int enterId) : _proxy(proxy), _frameId(frameId), _enterId(enterId) { } void StackFrame::disableNode(int nodeId) { _disabledNodes[nodeId] = 1; @@ -38,5 +38,13 @@ int StackFrame::frameId() const { return _frameId; } +int StackFrame::enterId() const { + return _enterId; +} + +int StackFrame::exitId() const { + return _exitId; +} + } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/impl/Node.cpp b/libnd4j/include/graph/impl/Node.cpp index f5bd3ad2ea55..4a4081e7d779 100644 --- a/libnd4j/include/graph/impl/Node.cpp +++ b/libnd4j/include/graph/impl/Node.cpp @@ -375,6 +375,14 @@ void Node::setFrameId(int frameId) { _frameId = frameId; } +int Node::exitId() const { + return _exitId; +} + +void Node::setExitId(int exitId) const { + _exitId = exitId; +} + Node::Node(const FlatNode *node) { // temporary holders _dimensions, for transferring axis into ContextPrototype std::vector axis; @@ -593,6 +601,7 @@ Node::Node(const Node &other) noexcept { _name = other._name; _id = other._id; _frameId = other._frameId; + _exitId = other._exitId; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; @@ -616,6 +625,7 @@ Node &Node::operator=(const Node &other) noexcept { _name = other._name; _id = other._id; _frameId = other._frameId; + _exitId = other._exitId; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; @@ -640,6 +650,7 @@ Node::Node(Node &&other) noexcept { _name = std::move(other._name); _id = other._id; _frameId = other._frameId; + _exitId = other._exitId; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; @@ -663,6 +674,7 @@ Node &Node::operator=(Node &&other) noexcept { _name = std::move(other._name); _id = other._id; _frameId = other._frameId; + _exitId = other._exitId; _hasExternalOutputs = other._hasExternalOutputs; _hasExternalInputs = other._hasExternalInputs; diff --git a/libnd4j/include/graph/logic/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp index 12536730fc16..5088dcabab33 100644 --- a/libnd4j/include/graph/logic/impl/LogicEnter.cpp +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -27,13 +27,17 @@ namespace graph { /** * This function does 2 things: - * - Propagates input variable - * - Opens new StackFrame + * - Propagates input Variable + * - Opens new StackFrame (only if that's the first Enter in this Loop) */ Nd4jStatus LogicEnter::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { // if current frameName isn't equal to node frame name - we'll open new StackFrame then - if (node->frameId() != stack.back().frameId()) - stack.openFrame(node->frameId()); + if (node->frameId() != stack.back().frameId()) { + stack.openFrame(node->frameId(), node->id()); + + // since this is the loop entrance, we'll rewind to this Node once iteration ends + // Enter -> Merge -> NextIteration + } const auto &frame = stack.back(); @@ -44,7 +48,8 @@ Nd4jStatus LogicEnter::processNode(const Node *node, Stack &stack, const Optimiz REQUIRE_TRUE(inputs.size() == 1, 0, "Enter: op must have exactly 1 inputs"); REQUIRE_TRUE(varSpace.hasVariable(inputs[0]), 0, "Enter: input Variable doesn't exist"); - // now we propagate input as ouwn output + // now we propagate input as own output + // ssince we've opened new StackFrame, this Variable will end up in new VariableProxy auto input = varSpace.getVariable(inputs[0]); varSpace.putVariable(std::pair{node->id(), 0}, *input->getNDArray()); diff --git a/libnd4j/include/graph/logic/impl/LogicExit.cpp b/libnd4j/include/graph/logic/impl/LogicExit.cpp index 421c15cb084a..6fd07c875ff6 100644 --- a/libnd4j/include/graph/logic/impl/LogicExit.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExit.cpp @@ -23,28 +23,15 @@ namespace sd { namespace graph { + +/** + * This funciton does 2 things: + * - Propagates input Variable to outer StackFrame + * - closes current StackFrame (only if this is the last Exit node in this loop) + */ Nd4jStatus LogicExit::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { - // this op is basically no-op - // we just know it exists throw std::runtime_error("LogicExit::processNode - Not implemented yet"); - /* - auto __variableSpace = graph->variableSpace(); - auto __flowPath = __variableSpace->flowPath(); - - Context ctx(node->contextPrototype(), __variableSpace); - auto input = ctx.variable(0)->getNDArray(); - - std::pair pair0(node->id(), 0); - - if (!__variableSpace->hasVariable(pair0)) - __variableSpace->putVariable(pair0, new Variable(nullptr, - nullptr, node->id(), 0)); - - __variableSpace->getVariable(pair0)->setNDArray(input); - __variableSpace->getVariable(pair0)->markRemovable(false); - - return ND4J_STATUS_OK; - */ } + } // namespace graph } // namespace sd \ No newline at end of file From d45132d0996b2a1da98e010d160b3d71d4e8be44 Mon Sep 17 00:00:00 2001 From: Yurii Date: Fri, 19 Jun 2020 13:50:09 +0300 Subject: [PATCH 208/233] - determine correspondence between enter and exit nodes Signed-off-by: Yurii --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 7db7b14c8987..23168e3d2505 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -66,10 +66,16 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS // create workMap, fill vectors containing input and output nodes per each node, and find start nodes std::vector startNodes; - for (const auto& p : _nodesMap) { + for (auto& p : _nodesMap) { const auto& inputs = p.second.inputs(); + if(p.second.name().find("Exit") != std::string::npos) { + const int idOfEnter = _nodesMap[_nodesMap[inputs[0].first].inputs()[0].first].inputs()[0].first; + p.second.setFrameId(_nodesMap[idOfEnter].frameId()); + _nodesMap[idOfEnter].setExitId(p.first); + } + for (int i = 0; i < inputs.size(); ++i) { if (_nodesMap.count(inputs[i].first) != 0) { // is op @@ -167,14 +173,12 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS if(name->find("Enter") == std::string::npos) { for (const auto& id : p0->second._opSeq) { - if(id == -1) - break; name = &_nodesMap[id].name(); if(name->find("Enter") != std::string::npos) { isInLoop = true; break; } - } + } } else isInLoop = true; From 7a8b3792c54a801021bbc6511a41aab0e4524d5e Mon Sep 17 00:00:00 2001 From: Yurii Date: Fri, 19 Jun 2020 14:26:37 +0300 Subject: [PATCH 209/233] - correct procedure of opSeq creation for ops within given loop Signed-off-by: Yurii --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 23168e3d2505..c1ab69d08c4d 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -187,11 +187,30 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS continue; std::string loopName = name->substr(0, name->find(delimiter)); // evaluate name of loop + for (auto p1 = std::next(p0); p1 != workMap.end(); ++p1) { // std::next(p0) = p0 + 1 + if(!p1->second._opSeq.empty() && p1->second._opSeq[0] == -1) continue; - if(_nodesMap[p1->first].name().find(loopName) == std::string::npos) + + isInLoop = false; + name = &_nodesMap[p1->first].name(); + + if(name->find(loopName) == std::string::npos) { + for (const auto& id : p1->second._opSeq) { + name = &_nodesMap[id].name(); + if(name->find(loopName) != std::string::npos) { + isInLoop = true; + break; + } + } + } + else + isInLoop = true; + + if(!isInLoop) continue; + p0->second._opSeq.push_back(p1->first); p0->second._opSeq.insert(p0->second._opSeq.end(), p1->second._opSeq.begin(), p1->second._opSeq.end()); p1->second._opSeq.clear(); From 5bfed377dad063f8c3cbe8373f1ae335d273cda1 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 19 Jun 2020 14:57:13 +0300 Subject: [PATCH 210/233] next step Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/OptimizedGraph.h | 6 ++- .../include/graph/execution/ExecutionLayer.h | 6 +++ libnd4j/include/graph/execution/OpSequence.h | 1 + libnd4j/include/graph/execution/StackFrame.h | 1 + .../graph/execution/impl/ExecutionLayer.cpp | 7 ++++ .../graph/execution/impl/OpSequence.cpp | 15 ++++++- .../graph/execution/impl/StackFrame.cpp | 4 ++ libnd4j/include/graph/impl/OptimizedGraph.cpp | 39 +++++++++++++++++++ .../include/graph/logic/impl/LogicEnter.cpp | 14 +++++-- .../include/graph/logic/impl/LogicSwitch.cpp | 8 ++-- .../layers_tests/GraphAnalysisTests.cpp | 2 +- 11 files changed, 93 insertions(+), 10 deletions(-) diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index 7ab39681f65a..2b0736b97be9 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -88,7 +88,11 @@ class SD_EXPORT OptimizedGraph { * returns reference on _nodesMap * @return */ - const MAP_IMPL& getNodesMap() const { return _nodesMap; } + const MAP_IMPL& nodesMap() const { return _nodesMap; } + + int nodeLayer(int nodeId) const; + int nodeSequence(int nodeId) const; + int nodeIndex(int nodeId) const; }; diff --git a/libnd4j/include/graph/execution/ExecutionLayer.h b/libnd4j/include/graph/execution/ExecutionLayer.h index 83395a65d7a7..24b3c1bc2910 100644 --- a/libnd4j/include/graph/execution/ExecutionLayer.h +++ b/libnd4j/include/graph/execution/ExecutionLayer.h @@ -72,6 +72,12 @@ class SD_EXPORT ExecutionLayer { */ void sortOpSequences(); + /** + * This method checks if specified Node resides within this ExecutionLayer + * @param nodeId + * @return + */ + bool hasNode(int nodeId) const; }; } // namespace graph diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index e477d4968443..f3ad8e157a82 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -108,6 +108,7 @@ class SD_EXPORT OpSequence */ int nodeId(int index) const; int nodeIndex(int id) const; + bool hasNode(int id) const; /** * Iterator functionality for OpSequence diff --git a/libnd4j/include/graph/execution/StackFrame.h b/libnd4j/include/graph/execution/StackFrame.h index 874e5d579c15..60d7a418a956 100644 --- a/libnd4j/include/graph/execution/StackFrame.h +++ b/libnd4j/include/graph/execution/StackFrame.h @@ -50,6 +50,7 @@ class SD_EXPORT StackFrame { int frameId() const; int enterId() const; int exitId() const; + void setExitId(int id) const; }; } // namespace graph diff --git a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp index bff7cea61602..24ba6d5117ba 100644 --- a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp +++ b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp @@ -36,6 +36,13 @@ const OpSequence &ExecutionLayer::operator[](uint64_t index) const { return at(index); } +bool ExecutionLayer::hasNode(int nodeId) const { + for (const auto &v:_sequences) + if (v.hasNode(nodeId)) + return true; + + return false; +} void ExecutionLayer::append(OpSequence&& sequence) { _sequences.emplace_back(std::move(sequence)); diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 4bde45cf3a7b..6d8e54be277d 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -38,18 +38,24 @@ OpSequence::OpSequence(const OpSequence &other) noexcept { _ops.clear(); for (const auto &v : other._ops) _ops.emplace_back(v); + + _idToIndex = other._idToIndex; + _indexToId = other._indexToId; } //////////////////////////////////////////////////////////////////////// // move constructor OpSequence::OpSequence(OpSequence &&other) noexcept: _ops(std::move(other._ops)) { - + _idToIndex = std::move(other._idToIndex); + _indexToId = std::move(other._indexToId); } OpSequence &OpSequence::operator=(OpSequence &&other) noexcept { if (this == &other) return *this; _ops = std::move(other._ops); + _idToIndex = std::move(other._idToIndex); + _indexToId = std::move(other._indexToId); return *this; } @@ -60,6 +66,9 @@ OpSequence &OpSequence::operator=(const OpSequence &other) noexcept { _ops.clear(); for (const auto &v : other._ops) _ops.emplace_back(v); + _idToIndex = other._idToIndex; + _indexToId = other._indexToId; + return *this; } @@ -117,6 +126,10 @@ int OpSequence::nodeIndex(int id) const { return _idToIndex.at(id); } +bool OpSequence::hasNode(int id) const { + return _idToIndex.count(id) > 0; +} + OpSequence::iterator OpSequence::begin() { return OpSequence::iterator(*this, 0); } diff --git a/libnd4j/include/graph/execution/impl/StackFrame.cpp b/libnd4j/include/graph/execution/impl/StackFrame.cpp index a4e6c98079dc..690a87073d6b 100644 --- a/libnd4j/include/graph/execution/impl/StackFrame.cpp +++ b/libnd4j/include/graph/execution/impl/StackFrame.cpp @@ -46,5 +46,9 @@ int StackFrame::exitId() const { return _exitId; } +void StackFrame::setExitId(int id) const { + _exitId = id; +} + } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 23168e3d2505..41e21bd8e2ff 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -22,6 +22,7 @@ #include #include #include +#include namespace sd { namespace graph { @@ -277,6 +278,44 @@ size_t OptimizedGraph::size() const { return size; } +int OptimizedGraph::nodeLayer(int nodeId) const { + int cnt = 0; + for (const auto &v:_sortedGraph) { + if (v.hasNode(nodeId)) + return cnt; + + cnt++; + } + + throw std::runtime_error("Node [" + StringUtils::valueToString(nodeId) + "] wasn't found in OptimizedGraph"); +} + +int OptimizedGraph::nodeIndex(int nodeId) const { + for (const auto &v:_sortedGraph) { + if (v.hasNode(nodeId)) { + for (int e = 0; e < v.width(); e++) { + if (v[e].hasNode(nodeId)) + return v[e].nodeIndex(nodeId); + } + } + } + + throw std::runtime_error("Node [" + StringUtils::valueToString(nodeId) + "] wasn't found in OptimizedGraph"); +} + +int OptimizedGraph::nodeSequence(int nodeId) const { + for (const auto &v:_sortedGraph) { + if (v.hasNode(nodeId)) { + for (int e = 0; e < v.width(); e++) { + if (v[e].hasNode(nodeId)) + return e; + } + } + } + + throw std::runtime_error("Node [" + StringUtils::valueToString(nodeId) + "] wasn't found in OptimizedGraph"); +} + void OptimizedGraph::printOut() const { for (uint i = 0; i < _sortedGraph.size(); ++i) { printf("Layer [%u] {\n", i); diff --git a/libnd4j/include/graph/logic/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp index 5088dcabab33..da0fe9ee3012 100644 --- a/libnd4j/include/graph/logic/impl/LogicEnter.cpp +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -33,14 +33,22 @@ namespace graph { Nd4jStatus LogicEnter::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { // if current frameName isn't equal to node frame name - we'll open new StackFrame then if (node->frameId() != stack.back().frameId()) { - stack.openFrame(node->frameId(), node->id()); - // since this is the loop entrance, we'll rewind to this Node once iteration ends - // Enter -> Merge -> NextIteration + stack.openFrame(node->frameId(), node->id()); } + // getting current frame (it might be the new one!) const auto &frame = stack.back(); + // we need to find rewind point - it has to be NextIteration node with max index within OpSequence + // and we need to find exit point - it has to be Exit node, with max index within OpSequence + auto currentExitIndex = frame.exitId() >= 0 ? graph.nodeIndex(frame.exitId()) : -1; + auto thisExitIndex = graph.nodeIndex(node->exitId()); + + // we want to exit after the last Exit node + if (thisExitIndex > currentExitIndex) + frame.setExitId(node->exitId()); + const auto &inputs = node->inputs(); auto &varSpace = const_cast(frame.variableProxy()); diff --git a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp index c67a8181d474..c2a93f227908 100644 --- a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp +++ b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp @@ -40,11 +40,11 @@ static void disableBranch(StackFrame &frame, VariableProxy &varSpace, const Opti // we're going to roll through all consumers for (const auto &o:outputs) { - if (graph.getNodesMap().count(o.first) == 0) + if (graph.nodesMap().count(o.first) == 0) throw std::runtime_error("pew-pew"); // now fetch disabled node - const auto &n = graph.getNodesMap().at(o.first); + const auto &n = graph.nodesMap().at(o.first); // edge case here: don't disable Merge node if (n.opType() == OpType_LOGIC && n.opNum() == sd::logic::Merge) @@ -66,10 +66,10 @@ static void disableBranch(StackFrame &frame, VariableProxy &varSpace, const Opti if (o.second == second) { frame.disableNode(o.first); - if (graph.getNodesMap().count(o.first) == 0) + if (graph.nodesMap().count(o.first) == 0) throw std::runtime_error("pew-pew"); - const auto &n = graph.getNodesMap().at(o.first); + const auto &n = graph.nodesMap().at(o.first); disableBranch(frame, varSpace, graph, &n); } diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 35eb069708b5..b29b81da4d66 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -1070,7 +1070,7 @@ TEST_F(GraphAnalysisTests, optimizedGraph_while1) { graph.printOut(); - //graph.execute(); + graph.execute(); } TEST_F(GraphAnalysisTests, optimizedGraph_nested_while_1) { From d8892cd6ec2438c65871620f92b6ccda59d20415 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 19 Jun 2020 15:19:30 +0300 Subject: [PATCH 211/233] next step Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/execution/StackFrame.h | 4 ++++ .../graph/execution/impl/StackFrame.cpp | 8 ++++++++ .../include/graph/logic/impl/LogicEnter.cpp | 18 ++++++++++++++++-- 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/graph/execution/StackFrame.h b/libnd4j/include/graph/execution/StackFrame.h index 60d7a418a956..6de6d5fc91d5 100644 --- a/libnd4j/include/graph/execution/StackFrame.h +++ b/libnd4j/include/graph/execution/StackFrame.h @@ -37,6 +37,7 @@ class SD_EXPORT StackFrame { // these fields are used int _frameId = -119; int _enterId = -119; + mutable int _rewindId = -119; mutable int _exitId = -119; public: explicit StackFrame(const VariableProxy &proxy, int frameId, int enterId); @@ -50,6 +51,9 @@ class SD_EXPORT StackFrame { int frameId() const; int enterId() const; int exitId() const; + int rewindId() const; + + void setRewindId(int id) const; void setExitId(int id) const; }; diff --git a/libnd4j/include/graph/execution/impl/StackFrame.cpp b/libnd4j/include/graph/execution/impl/StackFrame.cpp index 690a87073d6b..4fefb87f289d 100644 --- a/libnd4j/include/graph/execution/impl/StackFrame.cpp +++ b/libnd4j/include/graph/execution/impl/StackFrame.cpp @@ -50,5 +50,13 @@ void StackFrame::setExitId(int id) const { _exitId = id; } +int StackFrame::rewindId() const { + return _rewindId; +} + +void StackFrame::setRewindId(int id) const { + _rewindId = id; +} + } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp index da0fe9ee3012..cb6b9840af25 100644 --- a/libnd4j/include/graph/logic/impl/LogicEnter.cpp +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -37,10 +37,12 @@ Nd4jStatus LogicEnter::processNode(const Node *node, Stack &stack, const Optimiz stack.openFrame(node->frameId(), node->id()); } + const auto &inputs = node->inputs(); + const auto &outputs = node->outputs(); + // getting current frame (it might be the new one!) const auto &frame = stack.back(); - // we need to find rewind point - it has to be NextIteration node with max index within OpSequence // and we need to find exit point - it has to be Exit node, with max index within OpSequence auto currentExitIndex = frame.exitId() >= 0 ? graph.nodeIndex(frame.exitId()) : -1; auto thisExitIndex = graph.nodeIndex(node->exitId()); @@ -49,7 +51,19 @@ Nd4jStatus LogicEnter::processNode(const Node *node, Stack &stack, const Optimiz if (thisExitIndex > currentExitIndex) frame.setExitId(node->exitId()); - const auto &inputs = node->inputs(); + + // we need to find rewind point - it has to be NextIteration node with max index within OpSequence + const auto &merge = graph.nodesMap().at(outputs[0].first); + const auto &iter = graph.nodesMap().at(merge.inputs()[1].first); + + // we must compare index of this NextIteration Node within OpSequence to the current one, if it's set + auto currentRewindIndex = frame.rewindId() >= 0 ? graph.nodeIndex(frame.rewindId()) : -1; + auto thisRewindIndex = graph.nodeIndex(iter.id()); + + // we want to rewind after the last NextIteration node + if (thisRewindIndex > currentRewindIndex) + frame.setRewindId(iter.id()); + auto &varSpace = const_cast(frame.variableProxy()); // validate Node state From 8be942c107f436e21f131a8c372460b53cb951ee Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 19 Jun 2020 15:25:07 +0300 Subject: [PATCH 212/233] next step Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/logic/impl/LogicEnter.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/libnd4j/include/graph/logic/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp index cb6b9840af25..23f27448453c 100644 --- a/libnd4j/include/graph/logic/impl/LogicEnter.cpp +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -36,12 +36,11 @@ Nd4jStatus LogicEnter::processNode(const Node *node, Stack &stack, const Optimiz // since this is the loop entrance, we'll rewind to this Node once iteration ends stack.openFrame(node->frameId(), node->id()); } - - const auto &inputs = node->inputs(); - const auto &outputs = node->outputs(); - // getting current frame (it might be the new one!) const auto &frame = stack.back(); + const auto &inputs = node->inputs(); + const auto &outputs = node->outputs(); + auto &varSpace = const_cast(frame.variableProxy()); // and we need to find exit point - it has to be Exit node, with max index within OpSequence auto currentExitIndex = frame.exitId() >= 0 ? graph.nodeIndex(frame.exitId()) : -1; @@ -51,7 +50,6 @@ Nd4jStatus LogicEnter::processNode(const Node *node, Stack &stack, const Optimiz if (thisExitIndex > currentExitIndex) frame.setExitId(node->exitId()); - // we need to find rewind point - it has to be NextIteration node with max index within OpSequence const auto &merge = graph.nodesMap().at(outputs[0].first); const auto &iter = graph.nodesMap().at(merge.inputs()[1].first); @@ -64,8 +62,6 @@ Nd4jStatus LogicEnter::processNode(const Node *node, Stack &stack, const Optimiz if (thisRewindIndex > currentRewindIndex) frame.setRewindId(iter.id()); - auto &varSpace = const_cast(frame.variableProxy()); - // validate Node state REQUIRE_TRUE(inputs.size() == 1, 0, "Enter: op must have exactly 1 inputs"); REQUIRE_TRUE(varSpace.hasVariable(inputs[0]), 0, "Enter: input Variable doesn't exist"); From 3e00fe09cfb5f60e5a67a176ace07e20d22d49cd Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 19 Jun 2020 18:03:45 +0300 Subject: [PATCH 213/233] Exit draft Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/execution/StackFrame.h | 10 +++++++++ .../include/graph/execution/impl/Stack.cpp | 4 ++-- .../graph/execution/impl/StackFrame.cpp | 14 ++++++++++++- .../include/graph/logic/impl/LogicEnter.cpp | 1 + .../include/graph/logic/impl/LogicExit.cpp | 21 ++++++++++++++++++- 5 files changed, 46 insertions(+), 4 deletions(-) diff --git a/libnd4j/include/graph/execution/StackFrame.h b/libnd4j/include/graph/execution/StackFrame.h index 6de6d5fc91d5..5d797830a77a 100644 --- a/libnd4j/include/graph/execution/StackFrame.h +++ b/libnd4j/include/graph/execution/StackFrame.h @@ -31,6 +31,7 @@ namespace graph { class SD_EXPORT StackFrame { private: VariableProxy _proxy; + StackFrame *_parent = nullptr; MAP_IMPL _disabledNodes; @@ -41,10 +42,13 @@ class SD_EXPORT StackFrame { mutable int _exitId = -119; public: explicit StackFrame(const VariableProxy &proxy, int frameId, int enterId); + explicit StackFrame(const VariableProxy &proxy, int frameId, int enterId, StackFrame &parent); ~StackFrame() = default; const VariableProxy& variableProxy() const { return _proxy; } + + void disableNode(int nodeId); bool isDisabled(int nodeId) const; @@ -55,6 +59,12 @@ class SD_EXPORT StackFrame { void setRewindId(int id) const; void setExitId(int id) const; + + /** + * This method returns parent frame + * @return + */ + StackFrame& parent() const; }; } // namespace graph diff --git a/libnd4j/include/graph/execution/impl/Stack.cpp b/libnd4j/include/graph/execution/impl/Stack.cpp index 4e208649211e..0fbbbcbefa05 100644 --- a/libnd4j/include/graph/execution/impl/Stack.cpp +++ b/libnd4j/include/graph/execution/impl/Stack.cpp @@ -44,11 +44,11 @@ StackFrame &Stack::root() { } void Stack::openFrame(int frameId, int enterId) { - _frames.emplace_back(StackFrame(_frames.back().variableProxy(), frameId, enterId)); + _frames.emplace_back(StackFrame(_frames.back().variableProxy(), frameId, enterId, _frames.back())); } void Stack::closeFrame() { - + _frames.pop_back(); } } // namespace graph diff --git a/libnd4j/include/graph/execution/impl/StackFrame.cpp b/libnd4j/include/graph/execution/impl/StackFrame.cpp index 4fefb87f289d..896e9183798e 100644 --- a/libnd4j/include/graph/execution/impl/StackFrame.cpp +++ b/libnd4j/include/graph/execution/impl/StackFrame.cpp @@ -24,7 +24,15 @@ namespace sd { namespace graph { -StackFrame::StackFrame(const VariableProxy &proxy, int frameId, int enterId) : _proxy(proxy), _frameId(frameId), _enterId(enterId) { } +StackFrame::StackFrame(const VariableProxy &proxy, int frameId, int enterId) + : _proxy(proxy), _frameId(frameId), _enterId(enterId) { + +} + +StackFrame::StackFrame(const VariableProxy &proxy, int frameId, int enterId, StackFrame &parent) + : StackFrame(proxy, frameId, enterId) { + _parent = &parent; +} void StackFrame::disableNode(int nodeId) { _disabledNodes[nodeId] = 1; @@ -58,5 +66,9 @@ void StackFrame::setRewindId(int id) const { _rewindId = id; } +StackFrame &StackFrame::parent() const { + return *_parent; +} + } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp index 23f27448453c..ce5d1a957230 100644 --- a/libnd4j/include/graph/logic/impl/LogicEnter.cpp +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -36,6 +36,7 @@ Nd4jStatus LogicEnter::processNode(const Node *node, Stack &stack, const Optimiz // since this is the loop entrance, we'll rewind to this Node once iteration ends stack.openFrame(node->frameId(), node->id()); } + // getting current frame (it might be the new one!) const auto &frame = stack.back(); const auto &inputs = node->inputs(); diff --git a/libnd4j/include/graph/logic/impl/LogicExit.cpp b/libnd4j/include/graph/logic/impl/LogicExit.cpp index 6fd07c875ff6..cc6fb482a00e 100644 --- a/libnd4j/include/graph/logic/impl/LogicExit.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExit.cpp @@ -30,7 +30,26 @@ namespace graph { * - closes current StackFrame (only if this is the last Exit node in this loop) */ Nd4jStatus LogicExit::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { - throw std::runtime_error("LogicExit::processNode - Not implemented yet"); + // getting current frame (it must be the StackFrame created for a While loop) + const auto &frame = stack.back(); + + // we must propagate variable from this frame to parent one + const auto &parent = frame.parent(); + + const auto &inputs = node->inputs(); + + REQUIRE_TRUE(inputs.size() == 1, 0, "Exit: op must have exactly 1 input1"); + REQUIRE_TRUE(frame.variableProxy().hasVariable(inputs[0]), 0, "Exit: input Variable doesn't exist"); + + // get Variable from current VariableProxy and put to the ParentOne + auto var = frame.variableProxy().getVariable(inputs[0]); + const_cast(parent.variableProxy()).putVariable(inputs[0], var); + + // if this is the last Exit node - we close current StackFrame + if (frame.exitId() == node->id()) + stack.closeFrame(); + + return Status::OK(); } } // namespace graph From e706c9e9dfac1336fc1f554a1703d0dec0757013 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 19 Jun 2020 19:37:44 +0300 Subject: [PATCH 214/233] first test passed Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/execution/Stack.h | 1 + .../graph/execution/impl/GraphExecutor.cpp | 14 ++++-- .../include/graph/execution/impl/Stack.cpp | 10 ++++ libnd4j/include/graph/logic/LogicLoopCond.h | 2 +- .../include/graph/logic/impl/LogicEnter.cpp | 13 ++++-- .../graph/logic/impl/LogicExecutor.cpp | 2 +- .../include/graph/logic/impl/LogicExit.cpp | 2 +- .../graph/logic/impl/LogicLoopCond.cpp | 46 +++++++------------ .../include/graph/logic/impl/LogicMerge.cpp | 26 ++++++++--- .../graph/logic/impl/LogicNextIteration.cpp | 32 +++++-------- .../layers_tests/GraphAnalysisTests.cpp | 6 ++- 11 files changed, 85 insertions(+), 69 deletions(-) diff --git a/libnd4j/include/graph/execution/Stack.h b/libnd4j/include/graph/execution/Stack.h index 97e0fe50e6d8..8be487a0f5a7 100644 --- a/libnd4j/include/graph/execution/Stack.h +++ b/libnd4j/include/graph/execution/Stack.h @@ -43,6 +43,7 @@ class SD_EXPORT Stack { const VariableProxy& rootVariableSpace() const; void openFrame(int frameId, int enterId); + void iterateFrame(int frameId, int enterId); void closeFrame(); }; diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 28e6ac092f68..867834fc90c9 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -79,6 +79,10 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, auto &f = stack.back(); auto &p = f.variableProxy(); + // we skip all disabled nodes + if (f.isDisabled(v.node().id())) + continue; + if (v.node().opType() == OpType_LOGIC) { nd4j_printf("Node <%i:%s> is a logic op\n", @@ -86,11 +90,13 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, v.node().name().empty() ? "" : v.node().name().c_str()); LogicExecutor::processNode(&v.node(), stack, graph); - } else if (v.node().hasCustomOp()) { - // we skip all disabled nodes - if (f.isDisabled(v.node().id())) - continue; + // next iteration is special case: we might rewind + if (v.node().opNum() == sd::logic::NextIteration) { + if (f.rewindId() == v.node().id()) + e = seq.nodeIndex(f.enterId()) - 1; + } + } else if (v.node().hasCustomOp()) { // only Ops can be executed this way :( result = execute(v.node().customOp(), v.protoContext(), seq, graph, const_cast(p), targetDevice); } else { diff --git a/libnd4j/include/graph/execution/impl/Stack.cpp b/libnd4j/include/graph/execution/impl/Stack.cpp index 0fbbbcbefa05..46efcfffdea0 100644 --- a/libnd4j/include/graph/execution/impl/Stack.cpp +++ b/libnd4j/include/graph/execution/impl/Stack.cpp @@ -47,7 +47,17 @@ void Stack::openFrame(int frameId, int enterId) { _frames.emplace_back(StackFrame(_frames.back().variableProxy(), frameId, enterId, _frames.back())); } +void Stack::iterateFrame(int frameId, int enterId) { + auto ¤t = this->back(); + auto &parent = current.parent(); + _frames.emplace_back(StackFrame(_frames.back().variableProxy(), frameId, enterId, parent)); +} + void Stack::closeFrame() { + // we should remove all frames untl we hit parent frame + auto ¤t = this->back(); + auto &parent = current.parent(); + _frames.pop_back(); } diff --git a/libnd4j/include/graph/logic/LogicLoopCond.h b/libnd4j/include/graph/logic/LogicLoopCond.h index 1642d5955581..5564ac6613f6 100644 --- a/libnd4j/include/graph/logic/LogicLoopCond.h +++ b/libnd4j/include/graph/logic/LogicLoopCond.h @@ -29,7 +29,7 @@ namespace sd { namespace graph { class LogicLoopCond { public: - static Nd4jStatus processNode(const Node* node); + static Nd4jStatus processNode(const Node* node, Stack &stack, const OptimizedGraph& graph); }; } // namespace graph } // namespace sd diff --git a/libnd4j/include/graph/logic/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp index ce5d1a957230..2a10fc184dcb 100644 --- a/libnd4j/include/graph/logic/impl/LogicEnter.cpp +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -31,10 +31,15 @@ namespace graph { * - Opens new StackFrame (only if that's the first Enter in this Loop) */ Nd4jStatus LogicEnter::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { - // if current frameName isn't equal to node frame name - we'll open new StackFrame then - if (node->frameId() != stack.back().frameId()) { - // since this is the loop entrance, we'll rewind to this Node once iteration ends - stack.openFrame(node->frameId(), node->id()); + // the only possible case for this equality is NextIteration rewind + if (node->id() == stack.back().enterId()) { + stack.iterateFrame(node->frameId(), node->id()); + } else { + // if current frameName isn't equal to node frame name - we'll open new StackFrame then + if (node->frameId() != stack.back().frameId()) { + // since this is the loop entrance, we'll rewind to this Node once iteration ends + stack.openFrame(node->frameId(), node->id()); + } } // getting current frame (it might be the new one!) diff --git a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp index 38a7a309f6ae..829b8593719c 100644 --- a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp @@ -51,7 +51,7 @@ Nd4jStatus LogicExecutor::processNode(const Node *node, Stack &stack, const Opti case sd::logic::Merge: return LogicMerge::processNode(node, stack, graph); case sd::logic::LoopCond: - return LogicLoopCond::processNode(node); + return LogicLoopCond::processNode(node, stack, graph); case sd::logic::NextIteration: return LogicNextIeration::processNode(node, stack, graph); case sd::logic::Exit: diff --git a/libnd4j/include/graph/logic/impl/LogicExit.cpp b/libnd4j/include/graph/logic/impl/LogicExit.cpp index cc6fb482a00e..03c19b71cd99 100644 --- a/libnd4j/include/graph/logic/impl/LogicExit.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExit.cpp @@ -43,7 +43,7 @@ Nd4jStatus LogicExit::processNode(const Node *node, Stack &stack, const Optimize // get Variable from current VariableProxy and put to the ParentOne auto var = frame.variableProxy().getVariable(inputs[0]); - const_cast(parent.variableProxy()).putVariable(inputs[0], var); + const_cast(parent.variableProxy()).putVariable({node->id(), 0}, *var->getNDArray()); // if this is the last Exit node - we close current StackFrame if (frame.exitId() == node->id()) diff --git a/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp index 5fa1c6e1bfa6..cdb595b678b2 100644 --- a/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp +++ b/libnd4j/include/graph/logic/impl/LogicLoopCond.cpp @@ -23,36 +23,22 @@ namespace sd { namespace graph { -Nd4jStatus LogicLoopCond::processNode(const Node *node) { - throw std::runtime_error("LogicLoopCond::processNode - Not implemented yet"); - /* - auto __variableSpace = graph->variableSpace(); - auto __flowPath = __variableSpace->flowPath(); - - Context ctx(node->contextPrototype(), __variableSpace); - auto input = ctx.variable(0)->getNDArray(); - - std::pair pair0(node->id(), 0); - - if (!__variableSpace->hasVariable(pair0)) - __variableSpace->putVariable(pair0, new Variable(nullptr, nullptr, -node->id(), 0)); - - __variableSpace->getVariable(pair0)->setNDArray(input); - __variableSpace->getVariable(pair0)->markRemovable(false); - - // pass further - if (input->e(0) > 0) { - // if condition is TRUE body will be invoked some time soon -// __flowPath->markFrameActive(node->getFrameId(), true); - //__flowPath->i - } else { - // body won't be activated -// __flowPath->markFrameActive(node->getFrameId(), false); - } - - return ND4J_STATUS_OK; - */ + +Nd4jStatus LogicLoopCond::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { + auto &frame = stack.back(); + + const auto &inputs = node->inputs(); + auto &varSpace = const_cast(frame.variableProxy()); + + REQUIRE_TRUE(inputs.size() == 1, 0, "LoopCond: op must have exactly 1 input1"); + REQUIRE_TRUE(frame.variableProxy().hasVariable(inputs[0]), 0, "LoopCond: input Variable doesn't exist"); + + // Propagate Variable + auto var = varSpace.getVariable(inputs[0]); + varSpace.putVariable({node->id(), 0}, *var->getNDArray()); + + return Status::OK(); } + } // namespace graph } // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicMerge.cpp b/libnd4j/include/graph/logic/impl/LogicMerge.cpp index 3858f75f13fb..fb3db88b9b82 100644 --- a/libnd4j/include/graph/logic/impl/LogicMerge.cpp +++ b/libnd4j/include/graph/logic/impl/LogicMerge.cpp @@ -36,14 +36,28 @@ Nd4jStatus LogicMerge::processNode(const Node *node, Stack &stack, const Optimiz REQUIRE_TRUE(false, 0, "Merge: only 1 input should be disabled, but got both of them down"); } - // we're getting first non-disable input and propagate it - const auto &p = frame.isDisabled(inputs[0].first) ? inputs[1] : inputs[0]; + const auto &firstNode = graph.nodesMap().at(inputs[0].first); + const auto &secondNode = graph.nodesMap().at(inputs[1].first); - REQUIRE_TRUE(frame.variableProxy().hasVariable(p), 0, "Merge: Variable [%i:%i] doesn't exist", p.first, p.second); + if ((firstNode.opType() == OpType_LOGIC && firstNode.opNum() == sd::logic::NextIteration)|| (secondNode.opType() == OpType_LOGIC && secondNode.opNum() == sd::logic::NextIteration)) { + // if we're on NextIteration merge, we'll propagate its output regardless of first arg existence + if (firstNode.opType() == OpType_LOGIC && firstNode.opNum() == sd::logic::NextIteration) { + auto id = varSpace.hasVariable(inputs[0]) && varSpace.getVariable(inputs[0])->hasNDArray() ? inputs[0] : inputs[1]; + varSpace.putVariable({node->id(), 0}, *varSpace.getVariable(id)->getNDArray()); + } else { + auto id = varSpace.hasVariable(inputs[1]) && varSpace.getVariable(inputs[1])->hasNDArray() ? inputs[1] : inputs[0]; + varSpace.putVariable({node->id(), 0}, *varSpace.getVariable(id)->getNDArray()); + } + } else { + // we're getting first non-disabled input and propagate it + const auto &p = frame.isDisabled(inputs[0].first) ? inputs[1] : inputs[0]; - std::pair t(node->id(), 0); - auto array = varSpace.getVariable(p)->getNDArray().get(); - varSpace.putVariable(t, *array); + REQUIRE_TRUE(frame.variableProxy().hasVariable(p), 0, "Merge: Variable [%i:%i] doesn't exist", p.first, p.second); + + std::pair t(node->id(), 0); + auto array = varSpace.getVariable(p)->getNDArray().get(); + varSpace.putVariable(t, *array); + } return Status::OK(); } diff --git a/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp index 219c98d742d6..86ecab4e54c1 100644 --- a/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp +++ b/libnd4j/include/graph/logic/impl/LogicNextIteration.cpp @@ -23,32 +23,22 @@ namespace sd { namespace graph { -Nd4jStatus LogicNextIeration::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { - throw std::runtime_error( - "LogicNextIeration::processNode - not implemented yet"); - /* - auto __variableSpace = graph->variableSpace(); - auto __flowPath = __variableSpace->flowPath(); - - auto inputAddr = node->input()->at(0); - auto var = __variableSpace->getVariable(inputAddr); +Nd4jStatus LogicNextIeration::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { + auto &frame = stack.back(); - Variable *lvar = nullptr; - if (__variableSpace->hasVariable(node->id(), 0)) - lvar = __variableSpace->getVariable(node->id(), 0); - else - lvar = new Variable(nullptr, node->getName().c_str(), node->id(), 0); + const auto &inputs = node->inputs(); + auto &varSpace = const_cast(frame.variableProxy()); -// if (lvar->hasNDArray()) -// delete lvar->getNDArray(); + REQUIRE_TRUE(inputs.size() == 1, 0, "LoopCond: op must have exactly 1 input1"); + REQUIRE_TRUE(frame.variableProxy().hasVariable(inputs[0]), 0, "LoopCond: input Variable doesn't exist"); - auto array = var->getNDArray(); - lvar->setNDArray(array); - lvar->markReadOnly(true); + // Propagate Variable + auto var = varSpace.getVariable(inputs[0]); + varSpace.putVariable({node->id(), 0}, *var->getNDArray()); - return ND4J_STATUS_OK; - */ + return Status::OK(); } + } // namespace graph } // namespace sd \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index b29b81da4d66..7e920694cc13 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -1070,7 +1070,11 @@ TEST_F(GraphAnalysisTests, optimizedGraph_while1) { graph.printOut(); - graph.execute(); + auto results = graph.execute({}, {"while/Exit", "while/Exit_1"}); + ASSERT_EQ(2, results.size()); + + ASSERT_EQ(NDArrayFactory::create(1.f), results["while/Exit"]); + ASSERT_EQ(NDArrayFactory::create(1.f), results["while/Exit_1"]); } TEST_F(GraphAnalysisTests, optimizedGraph_nested_while_1) { From 4f8fde26e9831847a5c4a6a26548a128c3618f97 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Fri, 19 Jun 2020 21:02:14 +0300 Subject: [PATCH 215/233] assertions for a new test Signed-off-by: raver119@gmail.com --- .../layers_tests/GraphAnalysisTests.cpp | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 7e920694cc13..47defbe38b44 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -1077,6 +1077,23 @@ TEST_F(GraphAnalysisTests, optimizedGraph_while1) { ASSERT_EQ(NDArrayFactory::create(1.f), results["while/Exit_1"]); } +TEST_F(GraphAnalysisTests, optimizedGraph_while2) { + auto graph = Graph::fromFlatBuffers("resources/while_iter3.fb"); + const auto& optimized = graph.optimizedGraph(); + + // this Graph must have exactly 1 Layer and 1 OpSequence, since all it has is While loop + ASSERT_EQ(1, optimized.layers()); + ASSERT_EQ(1, optimized.layer(0).width()); + + graph.printOut(); + + auto results = graph.execute({}, {"while/Exit", "while/Exit_1"}); + ASSERT_EQ(2, results.size()); + + ASSERT_EQ(NDArrayFactory::create(3.f), results["while/Exit"]); + ASSERT_EQ(NDArrayFactory::create(3.f), results["while/Exit_1"]); +} + TEST_F(GraphAnalysisTests, optimizedGraph_nested_while_1) { auto graph = Graph::fromFlatBuffers("resources/simplewhile_nested.fb"); const auto& optimized = graph.optimizedGraph(); From 0b00097254c9b90ca20176290cc5f78dc9568438 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Sat, 20 Jun 2020 12:05:48 +0300 Subject: [PATCH 216/233] full stack cleanup Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/execution/Stack.h | 1 + libnd4j/include/graph/execution/StackFrame.h | 7 ++++-- .../include/graph/execution/impl/Stack.cpp | 20 +++++++++++----- .../graph/execution/impl/StackFrame.cpp | 8 +++---- .../include/graph/logic/impl/LogicMerge.cpp | 24 ++++++++++++------- 5 files changed, 40 insertions(+), 20 deletions(-) diff --git a/libnd4j/include/graph/execution/Stack.h b/libnd4j/include/graph/execution/Stack.h index 8be487a0f5a7..20ffbaea0f93 100644 --- a/libnd4j/include/graph/execution/Stack.h +++ b/libnd4j/include/graph/execution/Stack.h @@ -32,6 +32,7 @@ class SD_EXPORT Stack { private: std::deque _frames; + int _counter = 0; public: Stack(const VariableProxy &root); ~Stack() = default; diff --git a/libnd4j/include/graph/execution/StackFrame.h b/libnd4j/include/graph/execution/StackFrame.h index 5d797830a77a..a94cda1f2df2 100644 --- a/libnd4j/include/graph/execution/StackFrame.h +++ b/libnd4j/include/graph/execution/StackFrame.h @@ -30,6 +30,7 @@ namespace graph { class SD_EXPORT StackFrame { private: + int _id; VariableProxy _proxy; StackFrame *_parent = nullptr; @@ -41,8 +42,8 @@ class SD_EXPORT StackFrame { mutable int _rewindId = -119; mutable int _exitId = -119; public: - explicit StackFrame(const VariableProxy &proxy, int frameId, int enterId); - explicit StackFrame(const VariableProxy &proxy, int frameId, int enterId, StackFrame &parent); + explicit StackFrame(const VariableProxy &proxy, int id, int frameId, int enterId); + explicit StackFrame(const VariableProxy &proxy, int id, int frameId, int enterId, StackFrame &parent); ~StackFrame() = default; const VariableProxy& variableProxy() const { return _proxy; } @@ -65,6 +66,8 @@ class SD_EXPORT StackFrame { * @return */ StackFrame& parent() const; + + int id() const { return _id; } }; } // namespace graph diff --git a/libnd4j/include/graph/execution/impl/Stack.cpp b/libnd4j/include/graph/execution/impl/Stack.cpp index 46efcfffdea0..eeafaf8e340c 100644 --- a/libnd4j/include/graph/execution/impl/Stack.cpp +++ b/libnd4j/include/graph/execution/impl/Stack.cpp @@ -24,7 +24,7 @@ namespace sd { namespace graph { Stack::Stack(const VariableProxy &root) { - _frames.push_back(StackFrame(const_cast(root), -1, 0)); + _frames.push_back(StackFrame(const_cast(root), _counter++, -1, 0)); } const VariableProxy &Stack::rootVariableSpace() const { @@ -44,21 +44,29 @@ StackFrame &Stack::root() { } void Stack::openFrame(int frameId, int enterId) { - _frames.emplace_back(StackFrame(_frames.back().variableProxy(), frameId, enterId, _frames.back())); + _frames.emplace_back(StackFrame(_frames.back().variableProxy(), _counter++, frameId, enterId, _frames.back())); } void Stack::iterateFrame(int frameId, int enterId) { auto ¤t = this->back(); auto &parent = current.parent(); - _frames.emplace_back(StackFrame(_frames.back().variableProxy(), frameId, enterId, parent)); + _frames.emplace_back(StackFrame(_frames.back().variableProxy(), _counter++, frameId, enterId, parent)); } void Stack::closeFrame() { // we should remove all frames untl we hit parent frame - auto ¤t = this->back(); - auto &parent = current.parent(); + auto &parent = this->back().parent(); + + while (!_frames.empty()) { + auto ¤t = this->back(); + + // if ID's match - we'll stop + if (current.id() == parent.id()) + break; + + _frames.pop_back(); + } - _frames.pop_back(); } } // namespace graph diff --git a/libnd4j/include/graph/execution/impl/StackFrame.cpp b/libnd4j/include/graph/execution/impl/StackFrame.cpp index 896e9183798e..127697f79028 100644 --- a/libnd4j/include/graph/execution/impl/StackFrame.cpp +++ b/libnd4j/include/graph/execution/impl/StackFrame.cpp @@ -24,13 +24,13 @@ namespace sd { namespace graph { -StackFrame::StackFrame(const VariableProxy &proxy, int frameId, int enterId) - : _proxy(proxy), _frameId(frameId), _enterId(enterId) { +StackFrame::StackFrame(const VariableProxy &proxy, int id, int frameId, int enterId) + : _proxy(proxy), _frameId(frameId), _enterId(enterId), _id(id) { } -StackFrame::StackFrame(const VariableProxy &proxy, int frameId, int enterId, StackFrame &parent) - : StackFrame(proxy, frameId, enterId) { +StackFrame::StackFrame(const VariableProxy &proxy, int id, int frameId, int enterId, StackFrame &parent) + : StackFrame(proxy, id, frameId, enterId) { _parent = &parent; } diff --git a/libnd4j/include/graph/logic/impl/LogicMerge.cpp b/libnd4j/include/graph/logic/impl/LogicMerge.cpp index fb3db88b9b82..2665d838f675 100644 --- a/libnd4j/include/graph/logic/impl/LogicMerge.cpp +++ b/libnd4j/include/graph/logic/impl/LogicMerge.cpp @@ -24,6 +24,14 @@ namespace sd { namespace graph { + +static bool isNextIterationCase(const OptimizedGraph& graph, int firstId, int secondId) { + const auto firstNode = graph.nodesMap().count(firstId) > 0 ? &graph.nodesMap().at(firstId) : nullptr; + const auto secondNode = graph.nodesMap().count(secondId) > 0 ? &graph.nodesMap().at(secondId) : nullptr; + + return (firstNode != nullptr && firstNode->opType() == OpType_LOGIC && firstNode->opNum() == sd::logic::NextIteration) || (secondNode != nullptr && secondNode->opType() == OpType_LOGIC && secondNode->opNum() == sd::logic::NextIteration); +} + Nd4jStatus LogicMerge::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { // getting current frame first auto &frame = stack.back(); @@ -36,15 +44,17 @@ Nd4jStatus LogicMerge::processNode(const Node *node, Stack &stack, const Optimiz REQUIRE_TRUE(false, 0, "Merge: only 1 input should be disabled, but got both of them down"); } - const auto &firstNode = graph.nodesMap().at(inputs[0].first); - const auto &secondNode = graph.nodesMap().at(inputs[1].first); + // if we're on NextIteration merge, we'll propagate its output regardless of first arg existence + if (isNextIterationCase(graph, inputs[0].first, inputs[1].first)) { + const auto firstNode = graph.nodesMap().count(inputs[0].first) > 0 ? &graph.nodesMap().at(inputs[0].first) : nullptr; + const auto secondNode = graph.nodesMap().count(inputs[1].first) > 0 ? &graph.nodesMap().at(inputs[1].first) : nullptr; - if ((firstNode.opType() == OpType_LOGIC && firstNode.opNum() == sd::logic::NextIteration)|| (secondNode.opType() == OpType_LOGIC && secondNode.opNum() == sd::logic::NextIteration)) { - // if we're on NextIteration merge, we'll propagate its output regardless of first arg existence - if (firstNode.opType() == OpType_LOGIC && firstNode.opNum() == sd::logic::NextIteration) { + if (firstNode != nullptr && firstNode->opType() == OpType_LOGIC && firstNode->opNum() == sd::logic::NextIteration) { + // we must check, if NextIteration Node already was executed. Or, pick initial value first auto id = varSpace.hasVariable(inputs[0]) && varSpace.getVariable(inputs[0])->hasNDArray() ? inputs[0] : inputs[1]; varSpace.putVariable({node->id(), 0}, *varSpace.getVariable(id)->getNDArray()); } else { + // we must check, if NextIteration Node already was executed. Or, pick initial value first auto id = varSpace.hasVariable(inputs[1]) && varSpace.getVariable(inputs[1])->hasNDArray() ? inputs[1] : inputs[0]; varSpace.putVariable({node->id(), 0}, *varSpace.getVariable(id)->getNDArray()); } @@ -54,9 +64,7 @@ Nd4jStatus LogicMerge::processNode(const Node *node, Stack &stack, const Optimiz REQUIRE_TRUE(frame.variableProxy().hasVariable(p), 0, "Merge: Variable [%i:%i] doesn't exist", p.first, p.second); - std::pair t(node->id(), 0); - auto array = varSpace.getVariable(p)->getNDArray().get(); - varSpace.putVariable(t, *array); + varSpace.putVariable({node->id(), 0}, *varSpace.getVariable(p)->getNDArray()); } return Status::OK(); From b5b9319c36b542e17e4deb9d0357aedddd6ef824 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Sat, 20 Jun 2020 17:26:02 +0300 Subject: [PATCH 217/233] next step Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/impl/Graph.cpp | 13 +++++++++++-- libnd4j/include/graph/impl/OptimizedGraph.cpp | 6 +++--- libnd4j/include/graph/impl/Variable.cpp | 14 ++++++-------- .../tests_cpu/layers_tests/GraphAnalysisTests.cpp | 10 ++++++++++ 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 5e32633cd43e..0123bc92cb01 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -143,6 +143,9 @@ Graph::Graph(const FlatGraph *flatGraph, const GraphMemoryManager &memoryManager } _variableSpace.putVariable(pair, var); + + if (var->isPlaceholder()) + _placeholders.emplace_back(var->name()); } } @@ -604,8 +607,14 @@ std::map Graph::execute( } // TODO: it would be nice if we'll print out unresolved placeholders - if (placeholdersCount != _placeholders.size()) - throw std::runtime_error("Some placeholders were not resolved"); + if (placeholdersCount != _placeholders.size()) { + std::string missing; + for (const auto &v:_placeholders) { + if (dictionary.count(v) == 0) + missing += "<" + v + ">, "; + } + throw std::runtime_error("Placeholders were not resolved: [" + missing + "]"); + } // we also must check existence of requested outputs for (const auto &v : outputs) { diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 94f3eaafe494..031dcfc70241 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -306,7 +306,7 @@ int OptimizedGraph::nodeLayer(int nodeId) const { cnt++; } - throw std::runtime_error("Node [" + StringUtils::valueToString(nodeId) + "] wasn't found in OptimizedGraph"); + throw std::runtime_error("Layer of Node [" + StringUtils::valueToString(nodeId) + "] wasn't found in OptimizedGraph"); } int OptimizedGraph::nodeIndex(int nodeId) const { @@ -319,7 +319,7 @@ int OptimizedGraph::nodeIndex(int nodeId) const { } } - throw std::runtime_error("Node [" + StringUtils::valueToString(nodeId) + "] wasn't found in OptimizedGraph"); + throw std::runtime_error("Index of Node [" + StringUtils::valueToString(nodeId) + "] wasn't found in OptimizedGraph"); } int OptimizedGraph::nodeSequence(int nodeId) const { @@ -332,7 +332,7 @@ int OptimizedGraph::nodeSequence(int nodeId) const { } } - throw std::runtime_error("Node [" + StringUtils::valueToString(nodeId) + "] wasn't found in OptimizedGraph"); + throw std::runtime_error("Sequence of Node [" + StringUtils::valueToString(nodeId) + "] wasn't found in OptimizedGraph"); } void OptimizedGraph::printOut() const { diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index 1b13a6c2cdc3..dc6f74e2541f 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -66,7 +66,7 @@ void sd::graph::Variable::setVariableType(VariableType variableType) { bool sd::graph::Variable::hasNDArrayList() const { return _list != nullptr; } -bool sd::graph::Variable::isPlaceholder() const { return _placeholder; } +bool sd::graph::Variable::isPlaceholder() const { return _placeholder || _variableType == sd::graph::VariableType::PLACEHOLDER; } const std::string &sd::graph::Variable::name() const { return _name; } @@ -236,16 +236,12 @@ sd::graph::Variable::Variable(const sd::graph::FlatVariable *flatVariable) { _variableType = VariableType::NDARRAY; } break; case VarType_PLACEHOLDER: { - if (flatVariable->shape() == nullptr && - flatVariable->ndarray() == nullptr) - throw std::runtime_error( - "PLACEHOLDER variable must have shape defined"); + if (flatVariable->shape() == nullptr && flatVariable->ndarray() == nullptr) + throw std::runtime_error("PLACEHOLDER variable must have shape defined"); if (flatVariable->ndarray() != nullptr) { auto ar = flatVariable->ndarray(); - _ndarray = std::make_shared( - sd::graph::FlatUtils::fromFlatArray(ar)); - // _ndarray->triggerAllocationFlag(true); + _ndarray = std::make_shared(FlatUtils::fromFlatArray(ar)); _variableType = VariableType::NDARRAY; } @@ -255,6 +251,8 @@ sd::graph::Variable::Variable(const sd::graph::FlatVariable *flatVariable) { for (int i = 0; i < flatVariable->shape()->size(); i++) _shape.emplace_back(flatVariable->shape()->Get(i)); + _dtype = (sd::DataType) flatVariable->dtype(); + if (_ndarray == nullptr) _variableType = VariableType::PLACEHOLDER; } } break; diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index 47defbe38b44..d3522207bb30 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -1095,7 +1095,17 @@ TEST_F(GraphAnalysisTests, optimizedGraph_while2) { } TEST_F(GraphAnalysisTests, optimizedGraph_nested_while_1) { + auto input_0 = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto input_1 = NDArrayFactory::create('c', {3, 3}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); + auto graph = Graph::fromFlatBuffers("resources/simplewhile_nested.fb"); const auto& optimized = graph.optimizedGraph(); + + ASSERT_EQ(2, graph.placeholders().size()); + graph.printOut(); + + auto results = graph.execute({{"input_0", input_0}, {"input_1", input_1}}, {"output"}); + + ASSERT_EQ(NDArrayFactory::create('c', {2, 2}, {13.f, 14.f, 15.f, 16.f}), results["output"]); } \ No newline at end of file From 5d74719bae08710bd288fbf2de44595d67f6662b Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 22 Jun 2020 11:08:01 +0300 Subject: [PATCH 218/233] bunch of small tweaks Signed-off-by: raver119@gmail.com --- .../include/graph/execution/impl/Stack.cpp | 5 +- .../graph/execution/impl/StackFrame.cpp | 4 +- .../include/graph/logic/impl/LogicEnter.cpp | 55 ++--- .../include/graph/logic/impl/LogicExit.cpp | 2 + .../include/graph/logic/impl/LogicMerge.cpp | 24 +++ .../include/graph/logic/impl/LogicSwitch.cpp | 2 + libnd4j/include/helpers/StringUtils.h | 201 ++++++++++-------- 7 files changed, 170 insertions(+), 123 deletions(-) diff --git a/libnd4j/include/graph/execution/impl/Stack.cpp b/libnd4j/include/graph/execution/impl/Stack.cpp index eeafaf8e340c..127a45d3cbb5 100644 --- a/libnd4j/include/graph/execution/impl/Stack.cpp +++ b/libnd4j/include/graph/execution/impl/Stack.cpp @@ -45,18 +45,22 @@ StackFrame &Stack::root() { void Stack::openFrame(int frameId, int enterId) { _frames.emplace_back(StackFrame(_frames.back().variableProxy(), _counter++, frameId, enterId, _frames.back())); + nd4j_printf("Opening frame [%i], parent: [%i]\n", _frames.back().id(), _frames.back().parent().id()); } void Stack::iterateFrame(int frameId, int enterId) { auto ¤t = this->back(); auto &parent = current.parent(); _frames.emplace_back(StackFrame(_frames.back().variableProxy(), _counter++, frameId, enterId, parent)); + nd4j_printf("Iterating frame, parent: [%i]\n", parent.id()); } void Stack::closeFrame() { // we should remove all frames untl we hit parent frame auto &parent = this->back().parent(); + nd4j_printf("Collapsed frame [%i], parent: [%i]\n", this->back().id(), parent.id()); + while (!_frames.empty()) { auto ¤t = this->back(); @@ -66,7 +70,6 @@ void Stack::closeFrame() { _frames.pop_back(); } - } } // namespace graph diff --git a/libnd4j/include/graph/execution/impl/StackFrame.cpp b/libnd4j/include/graph/execution/impl/StackFrame.cpp index 127697f79028..115aa03c5b3b 100644 --- a/libnd4j/include/graph/execution/impl/StackFrame.cpp +++ b/libnd4j/include/graph/execution/impl/StackFrame.cpp @@ -25,9 +25,7 @@ namespace sd { namespace graph { StackFrame::StackFrame(const VariableProxy &proxy, int id, int frameId, int enterId) - : _proxy(proxy), _frameId(frameId), _enterId(enterId), _id(id) { - -} + : _proxy(proxy), _frameId(frameId), _enterId(enterId), _id(id) { } StackFrame::StackFrame(const VariableProxy &proxy, int id, int frameId, int enterId, StackFrame &parent) : StackFrame(proxy, id, frameId, enterId) { diff --git a/libnd4j/include/graph/logic/impl/LogicEnter.cpp b/libnd4j/include/graph/logic/impl/LogicEnter.cpp index 2a10fc184dcb..76e29c18dfd7 100644 --- a/libnd4j/include/graph/logic/impl/LogicEnter.cpp +++ b/libnd4j/include/graph/logic/impl/LogicEnter.cpp @@ -31,15 +31,19 @@ namespace graph { * - Opens new StackFrame (only if that's the first Enter in this Loop) */ Nd4jStatus LogicEnter::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { - // the only possible case for this equality is NextIteration rewind - if (node->id() == stack.back().enterId()) { - stack.iterateFrame(node->frameId(), node->id()); - } else { - // if current frameName isn't equal to node frame name - we'll open new StackFrame then - if (node->frameId() != stack.back().frameId()) { - // since this is the loop entrance, we'll rewind to this Node once iteration ends - stack.openFrame(node->frameId(), node->id()); + if (node->exitId() >= 0) { + // the only possible case for this equality is NextIteration rewind + if (node->id() == stack.back().enterId()) { + stack.iterateFrame(node->frameId(), node->id()); + } else { + // if current frameName isn't equal to node frame name - we'll open new StackFrame then + if (node->frameId() != stack.back().frameId()) { + // since this is the loop entrance, we'll rewind to this Node once iteration ends + stack.openFrame(node->frameId(), node->id()); + } } + } else { + //nd4j_printf("Neg\n", ""); } // getting current frame (it might be the new one!) @@ -49,33 +53,34 @@ Nd4jStatus LogicEnter::processNode(const Node *node, Stack &stack, const Optimiz auto &varSpace = const_cast(frame.variableProxy()); // and we need to find exit point - it has to be Exit node, with max index within OpSequence - auto currentExitIndex = frame.exitId() >= 0 ? graph.nodeIndex(frame.exitId()) : -1; - auto thisExitIndex = graph.nodeIndex(node->exitId()); + if (node->exitId() >= 0) { + auto currentExitIndex = frame.exitId() >= 0 ? graph.nodeIndex(frame.exitId()) : -1; + auto thisExitIndex = graph.nodeIndex(node->exitId()); - // we want to exit after the last Exit node - if (thisExitIndex > currentExitIndex) - frame.setExitId(node->exitId()); + // we want to exit after the last Exit node + if (thisExitIndex > currentExitIndex) + frame.setExitId(node->exitId()); - // we need to find rewind point - it has to be NextIteration node with max index within OpSequence - const auto &merge = graph.nodesMap().at(outputs[0].first); - const auto &iter = graph.nodesMap().at(merge.inputs()[1].first); + // we need to find rewind point - it has to be NextIteration node with max index within OpSequence + const auto &merge = graph.nodesMap().at(outputs[0].first); + const auto &iter = graph.nodesMap().at(merge.inputs()[1].first); - // we must compare index of this NextIteration Node within OpSequence to the current one, if it's set - auto currentRewindIndex = frame.rewindId() >= 0 ? graph.nodeIndex(frame.rewindId()) : -1; - auto thisRewindIndex = graph.nodeIndex(iter.id()); + // we must compare index of this NextIteration Node within OpSequence to the current one, if it's set + auto currentRewindIndex = frame.rewindId() >= 0 ? graph.nodeIndex(frame.rewindId()) : -1; + auto thisRewindIndex = graph.nodeIndex(iter.id()); - // we want to rewind after the last NextIteration node - if (thisRewindIndex > currentRewindIndex) - frame.setRewindId(iter.id()); + // we want to rewind after the last NextIteration node + if (thisRewindIndex > currentRewindIndex) + frame.setRewindId(iter.id()); + } // validate Node state REQUIRE_TRUE(inputs.size() == 1, 0, "Enter: op must have exactly 1 inputs"); REQUIRE_TRUE(varSpace.hasVariable(inputs[0]), 0, "Enter: input Variable doesn't exist"); // now we propagate input as own output - // ssince we've opened new StackFrame, this Variable will end up in new VariableProxy - auto input = varSpace.getVariable(inputs[0]); - varSpace.putVariable(std::pair{node->id(), 0}, *input->getNDArray()); + // ssince we've opened new StackFrame, this Variable will end up in current VariableProxy + varSpace.putVariable(std::pair{node->id(), 0}, *varSpace.getVariable(inputs[0])->getNDArray()); return sd::Status::OK(); } diff --git a/libnd4j/include/graph/logic/impl/LogicExit.cpp b/libnd4j/include/graph/logic/impl/LogicExit.cpp index 03c19b71cd99..647152c544af 100644 --- a/libnd4j/include/graph/logic/impl/LogicExit.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExit.cpp @@ -41,6 +41,8 @@ Nd4jStatus LogicExit::processNode(const Node *node, Stack &stack, const Optimize REQUIRE_TRUE(inputs.size() == 1, 0, "Exit: op must have exactly 1 input1"); REQUIRE_TRUE(frame.variableProxy().hasVariable(inputs[0]), 0, "Exit: input Variable doesn't exist"); + nd4j_printf("Propagating Variable to the Frame [%i]\n", parent.id()); + // get Variable from current VariableProxy and put to the ParentOne auto var = frame.variableProxy().getVariable(inputs[0]); const_cast(parent.variableProxy()).putVariable({node->id(), 0}, *var->getNDArray()); diff --git a/libnd4j/include/graph/logic/impl/LogicMerge.cpp b/libnd4j/include/graph/logic/impl/LogicMerge.cpp index 2665d838f675..aea0c8b3bb61 100644 --- a/libnd4j/include/graph/logic/impl/LogicMerge.cpp +++ b/libnd4j/include/graph/logic/impl/LogicMerge.cpp @@ -21,6 +21,7 @@ #include #include +#include namespace sd { namespace graph { @@ -32,6 +33,13 @@ static bool isNextIterationCase(const OptimizedGraph& graph, int firstId, int se return (firstNode != nullptr && firstNode->opType() == OpType_LOGIC && firstNode->opNum() == sd::logic::NextIteration) || (secondNode != nullptr && secondNode->opType() == OpType_LOGIC && secondNode->opNum() == sd::logic::NextIteration); } +static bool checkViability(Stack &stack, const std::pair &first, const std::pair &second) { + auto &frame = stack.back(); + auto &varSpace = const_cast(frame.variableProxy()); + + return (!frame.isDisabled(first.first) && varSpace.hasVariable(first) && varSpace.getVariable(first)->hasNDArray()) || (!frame.isDisabled(second.first) && varSpace.hasVariable(second) && varSpace.getVariable(second)->hasNDArray()); +} + Nd4jStatus LogicMerge::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { // getting current frame first auto &frame = stack.back(); @@ -40,10 +48,20 @@ Nd4jStatus LogicMerge::processNode(const Node *node, Stack &stack, const Optimiz auto &varSpace = const_cast(frame.variableProxy()); REQUIRE_TRUE(inputs.size() == 2, 0, "Merge: op expects exactly 2 inputs, but only %i defined", (int) inputs.size()); + + // if both inputs are unavailable - this node is disabled and must be disabled + if (!checkViability(stack, inputs[0], inputs[1])) { + nd4j_printf("Both inputs absent, skipping\n", ""); + // TODO: disable this branch + return Status::OK(); + } + if (frame.isDisabled(inputs[0].first) && frame.isDisabled(inputs[1].first)) { REQUIRE_TRUE(false, 0, "Merge: only 1 input should be disabled, but got both of them down"); } + + // if we're on NextIteration merge, we'll propagate its output regardless of first arg existence if (isNextIterationCase(graph, inputs[0].first, inputs[1].first)) { const auto firstNode = graph.nodesMap().count(inputs[0].first) > 0 ? &graph.nodesMap().at(inputs[0].first) : nullptr; @@ -52,10 +70,16 @@ Nd4jStatus LogicMerge::processNode(const Node *node, Stack &stack, const Optimiz if (firstNode != nullptr && firstNode->opType() == OpType_LOGIC && firstNode->opNum() == sd::logic::NextIteration) { // we must check, if NextIteration Node already was executed. Or, pick initial value first auto id = varSpace.hasVariable(inputs[0]) && varSpace.getVariable(inputs[0])->hasNDArray() ? inputs[0] : inputs[1]; + if (!varSpace.hasVariable(id) || !varSpace.getVariable(id)->hasNDArray()) + throw std::runtime_error("Non-existent NDArray requested: [" + StringUtils::valueToString(id) +"]"); + varSpace.putVariable({node->id(), 0}, *varSpace.getVariable(id)->getNDArray()); } else { // we must check, if NextIteration Node already was executed. Or, pick initial value first auto id = varSpace.hasVariable(inputs[1]) && varSpace.getVariable(inputs[1])->hasNDArray() ? inputs[1] : inputs[0]; + if (!varSpace.hasVariable(id) || !varSpace.getVariable(id)->hasNDArray()) + throw std::runtime_error("Non-existent NDArray requested: [" + StringUtils::valueToString(id) +"]"); + varSpace.putVariable({node->id(), 0}, *varSpace.getVariable(id)->getNDArray()); } } else { diff --git a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp index c2a93f227908..461036f0a098 100644 --- a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp +++ b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp @@ -94,6 +94,8 @@ Nd4jStatus LogicSwitch::processNode(const Node* node, Stack &stack, const Optimi REQUIRE_TRUE(boolean->hasNDArray(), 0, "Switch: boolean Variable must have NDArray defined"); + nd4j_printf("Switch [%i] evaluated as [%s]\n", node->id(), boolean->getNDArray()->e(0) ? "true" : "false"); + if (boolean->getNDArray()->e(0)) { // true branch varSpace.putVariable(std::pair{node->id(), 1}, *input->getNDArray()); diff --git a/libnd4j/include/helpers/StringUtils.h b/libnd4j/include/helpers/StringUtils.h index 73958cb2d3f3..2eaba7e7e916 100644 --- a/libnd4j/include/helpers/StringUtils.h +++ b/libnd4j/include/helpers/StringUtils.h @@ -16,7 +16,7 @@ ******************************************************************************/ // -// Created by raver119 on 20/04/18. +// @author raver110@gmail.com // @author Oleg Semeniv // @@ -36,31 +36,23 @@ namespace sd { class SD_EXPORT StringUtils { public: template - static FORCEINLINE std::string valueToString(T value) { - std::ostringstream os; + static FORCEINLINE std::string valueToString(T value); - os << value; - - // convert the string stream into a string and return - return os.str(); - } + /** + * These methods convert integer values to string with 0s and 1s + * @param value + * @return + */ + template + static std::string bitsToString(T value); /** - * These methods convert integer values to string with 0s and 1s - * @param value - * @return - */ - template - static std::string bitsToString(T value); - - /** - * This method just concatenates error message with a given graphId - * @param message - * @param graphId - * @return - */ - static FORCEINLINE std::string buildGraphErrorMessage(const char* message, - Nd4jLong graphId) { + * This method just concatenates error message with a given graphId + * @param message + * @param graphId + * @return + */ + static FORCEINLINE std::string buildGraphErrorMessage(const char* message, Nd4jLong graphId) { std::string result(message); result += " ["; result += valueToString(graphId); @@ -70,89 +62,110 @@ class SD_EXPORT StringUtils { } /** - * This method returns number of needle matches within haystack - * PLEASE NOTE: this method operates on 8-bit arrays interpreted as uint8 - * - * @param haystack - * @param haystackLength - * @param needle - * @param needleLength - * @return - */ + * This method returns number of needle matches within haystack + * PLEASE NOTE: this method operates on 8-bit arrays interpreted as uint8 + * + * @param haystack + * @param haystackLength + * @param needle + * @param needleLength + * @return + */ static uint64_t countSubarrays(const void* haystack, uint64_t haystackLength, const void* needle, uint64_t needleLength); - /** - * This method returns number of bytes used for string NDArrays content - * PLEASE NOTE: this doesn't include header - * - * @param array - * @return - */ + /** + * This method returns number of bytes used for string NDArrays content + * PLEASE NOTE: this doesn't include header + * + * @param array + * @return + */ static uint64_t byteLength(const NDArray& array); - /** - * This method splits a string into substring by delimiter - * - * @param haystack - * @param delimiter - * @return - */ - static std::vector split(const std::string& haystack, - const std::string& delimiter); - - /** - * This method convert u8 string to u16 - * @param const reference to input string - * @param reference to output u16string - * @return boolean status - */ + /** + * This method splits a string into substring by delimiter + * + * @param haystack + * @param delimiter + * @return + */ + static std::vector split(const std::string& haystack, const std::string& delimiter); + + /** + * This method convert u8 string to u16 + * @param const reference to input string + * @param reference to output u16string + * @return boolean status + */ static bool u8StringToU16String(const std::string& u8, std::u16string& u16); - /** - * This method convert u8 string to u32 - * @param const reference to input string - * @param reference to output u32string - * @return boolean status - */ + /** + * This method convert u8 string to u32 + * @param const reference to input string + * @param reference to output u32string + * @return boolean status + */ static bool u8StringToU32String(const std::string& u8, std::u32string& u32); - /** - * This method convert u16 string to u32 - * @param const reference to input u16string - * @param reference to output u32string - * @return boolean status - */ - static bool u16StringToU32String(const std::u16string& u16, - std::u32string& u32); - - /** - * This method convert u16 string to u8 string - * @param const reference to input u16string - * @param reference to output string - * @return boolean status - */ + /** + * This method convert u16 string to u32 + * @param const reference to input u16string + * @param reference to output u32string + * @return boolean status + */ + static bool u16StringToU32String(const std::u16string& u16, std::u32string& u32); + + /** + * This method convert u16 string to u8 string + * @param const reference to input u16string + * @param reference to output string + * @return boolean status + */ static bool u16StringToU8String(const std::u16string& u16, std::string& u8); - /** - * This method convert u32 string to u16 string - * @param const reference to input u32string - * @param reference to output u16string - * @return boolean status - */ - static bool u32StringToU16String(const std::u32string& u32, - std::u16string& u16); - - /** - * This method convert u32 string to u8 string - * @param const reference to input u32string - * @param reference to output string - * @return boolean status - */ + /** + * This method convert u32 string to u16 string + * @param const reference to input u32string + * @param reference to output u16string + * @return boolean status + */ + static bool u32StringToU16String(const std::u32string& u32, std::u16string& u16); + + /** + * This method convert u32 string to u8 string + * @param const reference to input u32string + * @param reference to output string + * @return boolean status + */ static bool u32StringToU8String(const std::u32string& u32, std::string& u8); - template - static std::string vectorToString(const std::vector &vec); - };} // namespace sd + template + static std::string vectorToString(const std::vector &vec); +}; + +template <> +FORCEINLINE std::string StringUtils::valueToString(std::pair value) { + std::ostringstream os; + + os << value.first; + os << ":"; + os << value.second; + + // convert the string stream into a string and return + return os.str(); +} + +template +FORCEINLINE std::string StringUtils::valueToString(T value) { + std::ostringstream os; + + os << value; + + // convert the string stream into a string and return + return os.str(); +} + +} // namespace sd #endif // LIBND4J_STRINGUTILS_H From ba73c1097f8bb758003ce2b1af6735353fc1da6c Mon Sep 17 00:00:00 2001 From: Yurii Date: Mon, 22 Jun 2020 22:14:22 +0300 Subject: [PATCH 219/233] - apply another toposort for switch node Signed-off-by: Yurii --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 135 +++++++++--------- 1 file changed, 64 insertions(+), 71 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 031dcfc70241..f2eff634cf83 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -146,13 +146,29 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS workMap.erase(i); // lambda for topological sort - std::function visit = [&visit, &workMap] (const int id, const uint layerNum, uint& numOfLayers) { + std::function visit = [&visit, &workMap, this] (const int id, const uint layerNum, uint& numOfLayers) { - if(layerNum <= workMap[id]._layerNum) { return; } + if(layerNum <= workMap[id]._layerNum) + return; workMap[id]._layerNum = layerNum; - if(numOfLayers < layerNum) { numOfLayers = layerNum; } - for (const auto& nextId : workMap[id]._out) - visit(nextId, layerNum+1, numOfLayers); + if(numOfLayers < layerNum) + numOfLayers = layerNum; + + const bool isSwitch = this->_nodesMap[id].name().find("Switch") != std::string::npos; + if(!isSwitch) { + for (const auto& nextId : workMap[id]._out) + visit(nextId, layerNum+1, numOfLayers); + } + else { + if(this->_nodesMap[id].outputs()[0].second == 1) { + visit(_nodesMap[id].outputs()[0].first, layerNum+1, numOfLayers); // true branch + visit(_nodesMap[id].outputs()[1].first, numOfLayers+1, numOfLayers); // false branch + } + else { + visit(_nodesMap[id].outputs()[1].first, layerNum+1, numOfLayers); // true branch + visit(_nodesMap[id].outputs()[0].first, numOfLayers+1, numOfLayers); // false branch + } + } }; // perform topological sort @@ -161,89 +177,66 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS for (const auto& nextId : workMap[id]._out) visit(nextId, 1, numOfLayers); - // gather all nodes belonging to loop and store them in one OpSequence - const char delimiter = '/'; - for (auto p0 = workMap.begin(); p0 != std::prev(workMap.end()); ++p0) { // std::prev(workMap.end()) == workMap.end() - 1 - if(!p0->second._opSeq.empty() && p0->second._opSeq[0] == -1) - continue; - // const auto& name = _nodesMap[p0->first].name(); - // if(name.find("Enter") == std::string::npos) - // continue; - bool isInLoop = false; - auto* name = &_nodesMap[p0->first].name(); - - if(name->find("Enter") == std::string::npos) { - for (const auto& id : p0->second._opSeq) { - name = &_nodesMap[id].name(); - if(name->find("Enter") != std::string::npos) { - isInLoop = true; - break; - } - } - } - else - isInLoop = true; - - if(!isInLoop) - continue; - - std::string loopName = name->substr(0, name->find(delimiter)); // evaluate name of loop - - for (auto p1 = std::next(p0); p1 != workMap.end(); ++p1) { // std::next(p0) = p0 + 1 - - if(!p1->second._opSeq.empty() && p1->second._opSeq[0] == -1) - continue; - - isInLoop = false; - name = &_nodesMap[p1->first].name(); - - if(name->find(loopName) == std::string::npos) { - for (const auto& id : p1->second._opSeq) { - name = &_nodesMap[id].name(); - if(name->find(loopName) != std::string::npos) { - isInLoop = true; - break; - } - } - } - else - isInLoop = true; - - if(!isInLoop) - continue; - - p0->second._opSeq.push_back(p1->first); - p0->second._opSeq.insert(p0->second._opSeq.end(), p1->second._opSeq.begin(), p1->second._opSeq.end()); - p1->second._opSeq.clear(); - p1->second._opSeq.push_back(-1); // mark node to be neglected - // p1->second._layerNum = p0->second._layerNum; - } - } // fill vectors with layers - std::vector sortedGraphTemp;//(numOfLayers+1); + std::vector sortedGraphTemp(numOfLayers+1); for (const auto& p : workMap) { - if(!p.second._opSeq.empty() && p.second._opSeq[0] == -1) - continue; - OpSequence seq; seq.append(_nodesMap.at(p.first), _nodesMap.at(p.first).contextPrototype()); for (const auto& id : p.second._opSeq) seq.append(_nodesMap.at(id), _nodesMap.at(id).contextPrototype()); - while(sortedGraphTemp.size() <= p.second._layerNum) - sortedGraphTemp.emplace_back(ExecutionLayer()); + // while(sortedGraphTemp.size() <= p.second._layerNum) + // sortedGraphTemp.emplace_back(ExecutionLayer()); sortedGraphTemp[p.second._layerNum].append(std::move(seq)); + } + const char delimiter = '/'; + for (int i0 = 0; i0 < sortedGraphTemp.size(); ++i0) { + for (int i1 = 0; i1 < sortedGraphTemp[i0].width(); ++i1) { + for (int i2 = 0; i2 < sortedGraphTemp[i0][i1].length(); ++i2) { + // if(!p0->second._opSeq.empty() && p0->second._opSeq[0] == -1) + // continue; + + auto id = sortedGraphTemp[i0][i1][i2].node().id(); + auto* name = &_nodesMap[id].name(); + if (name->find("Enter") == std::string::npos) + continue; + std::string loopName = name->substr(0, name->find(delimiter)); // evaluate name of loop + for (int j0 = i0; j0 < sortedGraphTemp.size(); ++j0) { + for (int j1 = j0 == i0 ? i1 + 1 : 0; j1 < sortedGraphTemp[j0].width(); ++j1) { + for (int j2 = 0; j2 < sortedGraphTemp[j0][j1].length(); ++j2) { + id = sortedGraphTemp[j0][j1][j2].node().id(); + name = &_nodesMap[id].name(); + if (name->find(loopName) == std::string::npos) + continue; + for (int k = 0; k < sortedGraphTemp[j0][j1].length(); ++k) + const_cast(sortedGraphTemp[i0][i1]).append(sortedGraphTemp[j0][j1][k]); + const_cast(sortedGraphTemp[j0][j1]) = OpSequence(); + break; + } + } + } + break; + } + } } // delete empty layers - for (auto it = sortedGraphTemp.begin(); it != sortedGraphTemp.end(); ++it) - if(it->width() == 0) + for (auto it = sortedGraphTemp.begin(); it != sortedGraphTemp.end(); ++it) { + bool isEmpty = true; + for (uint i = 0; i < it->width(); ++i) { + if(it->at(i).length() != 0) + isEmpty = false; + break; + } + + if(isEmpty) sortedGraphTemp.erase(it--); + } // check whether there are layers with one OpSequence which in turn contains only one op bool isLayerWithOneOp = false; From e27d87d7764960f7791a646501ae6eb888a402d3 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 23 Jun 2020 09:05:55 +0300 Subject: [PATCH 220/233] - purge methods for ExecutionLayer and OptimizedGraph - remove legacy Logic ops Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/OptimizedGraph.h | 4 + .../include/graph/execution/ExecutionLayer.h | 5 + libnd4j/include/graph/execution/OpSequence.h | 8 +- .../graph/execution/impl/ExecutionLayer.cpp | 3 + .../graph/execution/impl/GraphExecutor.cpp | 15 +- .../graph/execution/impl/OpSequence.cpp | 14 ++ libnd4j/include/graph/impl/OptimizedGraph.cpp | 12 ++ .../include/graph/logic/LogicConditional.h | 50 ------ libnd4j/include/graph/logic/LogicExpose.h | 38 ----- libnd4j/include/graph/logic/LogicReturn.h | 46 ------ libnd4j/include/graph/logic/LogicScope.h | 46 ------ libnd4j/include/graph/logic/LogicWhile.h | 46 ------ .../graph/logic/impl/LogicConditional.cpp | 140 ----------------- .../graph/logic/impl/LogicExecutor.cpp | 17 +- .../include/graph/logic/impl/LogicExpose.cpp | 31 ---- .../include/graph/logic/impl/LogicReturn.cpp | 64 -------- .../include/graph/logic/impl/LogicScope.cpp | 33 ---- .../include/graph/logic/impl/LogicWhile.cpp | 147 ------------------ .../layers_tests/ExecutionLayerTests.cpp | 23 +++ 19 files changed, 79 insertions(+), 663 deletions(-) delete mode 100644 libnd4j/include/graph/logic/LogicConditional.h delete mode 100644 libnd4j/include/graph/logic/LogicExpose.h delete mode 100644 libnd4j/include/graph/logic/LogicReturn.h delete mode 100644 libnd4j/include/graph/logic/LogicScope.h delete mode 100644 libnd4j/include/graph/logic/LogicWhile.h delete mode 100644 libnd4j/include/graph/logic/impl/LogicConditional.cpp delete mode 100644 libnd4j/include/graph/logic/impl/LogicExpose.cpp delete mode 100644 libnd4j/include/graph/logic/impl/LogicReturn.cpp delete mode 100644 libnd4j/include/graph/logic/impl/LogicScope.cpp delete mode 100644 libnd4j/include/graph/logic/impl/LogicWhile.cpp diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index 2b0736b97be9..69032a102dd1 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -36,6 +36,10 @@ class SD_EXPORT OptimizedGraph { std::vector _sortedGraph; MAP_IMPL _nodesMap; + /** + * This method removes all empty ExecutionLayers from this graph + */ + void purgeEmptyLayers(); public: OptimizedGraph(const MAP_IMPL& map, const VariableSpace& varSpace); // move constructor diff --git a/libnd4j/include/graph/execution/ExecutionLayer.h b/libnd4j/include/graph/execution/ExecutionLayer.h index 24b3c1bc2910..d53aba220dcf 100644 --- a/libnd4j/include/graph/execution/ExecutionLayer.h +++ b/libnd4j/include/graph/execution/ExecutionLayer.h @@ -78,6 +78,11 @@ class SD_EXPORT ExecutionLayer { * @return */ bool hasNode(int nodeId) const; + + /** + * This method removes all empty OpSequences from this layer + */ + void purgeEmptySequences(); }; } // namespace graph diff --git a/libnd4j/include/graph/execution/OpSequence.h b/libnd4j/include/graph/execution/OpSequence.h index f3ad8e157a82..9ad7b10f8d21 100644 --- a/libnd4j/include/graph/execution/OpSequence.h +++ b/libnd4j/include/graph/execution/OpSequence.h @@ -88,7 +88,11 @@ class SD_EXPORT OpSequence * @return */ const ExecutionTask& at(uint64_t index) const; + ExecutionTask& at(uint64_t index); + const ExecutionTask& operator[](uint64_t index) const; + ExecutionTask& operator[](uint64_t index); + /** * This method allows to add DeclarableOp to the end of execution queue @@ -96,10 +100,10 @@ class SD_EXPORT OpSequence * @param ctx - ContextPrototype for this operation with inputs/outputs/args * defined */ - void append(const Node& node, - const sd::graph::ContextPrototype& ctx); + void append(const Node& node, const sd::graph::ContextPrototype& ctx); void append(const ExecutionTask& task); void append(ExecutionTask&& task); + void append(const OpSequence &sequence); /** * These two methods provide access to index/id dictionalries diff --git a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp index 24ba6d5117ba..d01fc2dccc9e 100644 --- a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp +++ b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp @@ -95,6 +95,9 @@ void ExecutionLayer::sortOpSequences() { // bubble sort } } +void ExecutionLayer::purgeEmptySequences() { + _sequences.erase(std::remove_if(_sequences.begin(), _sequences.end(), [](OpSequence &seq) -> bool { return seq.length() == 0; }), _sequences.end()); +} } // namespace graph } // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 867834fc90c9..99465d80c65e 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -85,9 +85,10 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, if (v.node().opType() == OpType_LOGIC) { - nd4j_printf("Node <%i:%s> is a logic op\n", + nd4j_printf("Node <%i:%s> is a logic op; Current frame: [%i]\n", v.node().id(), - v.node().name().empty() ? "" : v.node().name().c_str()); + v.node().name().empty() ? "" : v.node().name().c_str(), + f.id()); LogicExecutor::processNode(&v.node(), stack, graph); @@ -97,12 +98,18 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, e = seq.nodeIndex(f.enterId()) - 1; } } else if (v.node().hasCustomOp()) { + nd4j_printf("Node <%i:%s> is a custom op; Current frame: [%i]\n", + v.node().id(), + v.node().name().empty() ? "" : v.node().name().c_str(), + f.id()); + // only Ops can be executed this way :( result = execute(v.node().customOp(), v.protoContext(), seq, graph, const_cast(p), targetDevice); } else { - nd4j_printf("Node <%i:%s> has no customOp set\n", + nd4j_printf("Node <%i:%s> has no customOp set; Current frame: [%i]\n", v.node().id(), - v.node().name().empty() ? "" : v.node().name().c_str()); + v.node().name().empty() ? "" : v.node().name().c_str(), + f.id()); } // if any one op fails - there will be no sense in executing other ops diff --git a/libnd4j/include/graph/execution/impl/OpSequence.cpp b/libnd4j/include/graph/execution/impl/OpSequence.cpp index 6d8e54be277d..cc073d3dd26a 100644 --- a/libnd4j/include/graph/execution/impl/OpSequence.cpp +++ b/libnd4j/include/graph/execution/impl/OpSequence.cpp @@ -82,10 +82,18 @@ const ExecutionTask &OpSequence::at(uint64_t index) const { return _ops[index]; } +ExecutionTask &OpSequence::at(uint64_t index) { + return _ops[index]; +} + const ExecutionTask &OpSequence::operator[](uint64_t index) const { return at(index); } +ExecutionTask &OpSequence::operator[](uint64_t index) { + return at(index); +} + uint64_t OpSequence::length() const { return _ops.size(); } void OpSequence::append(const Node &node, @@ -112,6 +120,12 @@ void OpSequence::append(ExecutionTask&& task) { _indexToId[index] = task.node().id(); } +void OpSequence::append(const OpSequence &sequence) { + for (const auto &v:sequence._ops) { + this->append(v); + } +} + int OpSequence::nodeId(int index) const { if (index < 0 || index >= _ops.size() || _indexToId.count(index) < 1) throw std::runtime_error("Out-of-size index requested: " + StringUtils::valueToString(index)); diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index f2eff634cf83..d448658d99a3 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -269,6 +269,9 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS // sort _sortedGraph // for (auto& l : _sortedGraph) // l.sortOpSequences(); + + // clean up before exiting + purgeEmptyLayers(); } void OptimizedGraph::append(const OpSequence &sequence) { @@ -328,6 +331,15 @@ int OptimizedGraph::nodeSequence(int nodeId) const { throw std::runtime_error("Sequence of Node [" + StringUtils::valueToString(nodeId) + "] wasn't found in OptimizedGraph"); } +void OptimizedGraph::purgeEmptyLayers() { + // purge empty sequences first, if any + for (auto &v:_sortedGraph) + v.purgeEmptySequences(); + + // now purge all layers without sequences + _sortedGraph.erase(std::remove_if(_sortedGraph.begin(), _sortedGraph.end(), [](ExecutionLayer &layer) -> bool { return layer.width() == 0; }), _sortedGraph.end()); +} + void OptimizedGraph::printOut() const { for (uint i = 0; i < _sortedGraph.size(); ++i) { printf("Layer [%u] {\n", i); diff --git a/libnd4j/include/graph/logic/LogicConditional.h b/libnd4j/include/graph/logic/LogicConditional.h deleted file mode 100644 index 52d34933d770..000000000000 --- a/libnd4j/include/graph/logic/LogicConditional.h +++ /dev/null @@ -1,50 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.10.2017. -// - -#ifndef SD_LOGICCONDITIONAL_H -#define SD_LOGICCONDITIONAL_H - -#include -#include -#include - -namespace sd { -namespace graph { -/** - * This class is responsible for execution logic of Conditional logical - * abstraction - * - * TL/DR: Class takes 2 ops/scopes with the same number of inputs/outputs and - * condtion. Condition is evaluated, and based on its result - one of ops/scopes - * is executed. Results of this execution will be copied to Conditional node, - * and every other op in the graph will be sure that it's Conditional own - * result, both alternative nodes will stay in disguise. - * - * @tparam T - */ -class LogicConditional { - public: - static Nd4jStatus processNode(const Node* node); -}; -} // namespace graph -} // namespace sd - -#endif // SD_LOGICCONDITIONAL_H diff --git a/libnd4j/include/graph/logic/LogicExpose.h b/libnd4j/include/graph/logic/LogicExpose.h deleted file mode 100644 index 2b5eaa677349..000000000000 --- a/libnd4j/include/graph/logic/LogicExpose.h +++ /dev/null @@ -1,38 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 12.11.2017. -// - -#ifndef SD_LOGICEXPOSE_H -#define SD_LOGICEXPOSE_H - -#include -#include -#include - -namespace sd { -namespace graph { -class LogicExpose { - public: - static Nd4jStatus processNode(const Node* node); -}; -} // namespace graph -} // namespace sd - -#endif // SD_LOGICEXPOSE_H diff --git a/libnd4j/include/graph/logic/LogicReturn.h b/libnd4j/include/graph/logic/LogicReturn.h deleted file mode 100644 index eef1cdc6a9b5..000000000000 --- a/libnd4j/include/graph/logic/LogicReturn.h +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 28.10.2017. -// - -#ifndef SD_LOGICRETURN_H -#define SD_LOGICRETURN_H - -#include -#include -#include - -namespace sd { -namespace graph { -/** - * This class is responsible for execution logic of Return logical abstraction - * - * Basically we're just transferring input variable(s) to output variable(s), - * nothing beyond that - * @tparam T - */ -class LogicReturn { - public: - static Nd4jStatus processNode(const Node* node); -}; - -} // namespace graph -} // namespace sd - -#endif // SD_LOGICRETURN_H diff --git a/libnd4j/include/graph/logic/LogicScope.h b/libnd4j/include/graph/logic/LogicScope.h deleted file mode 100644 index cba7e9d0041f..000000000000 --- a/libnd4j/include/graph/logic/LogicScope.h +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.10.2017. -// - -#ifndef SD_LOGICSCOPE_H -#define SD_LOGICSCOPE_H - -#include -#include -#include - -namespace sd { -namespace graph { -/** - * This class is responsible for execution logic of Scope logical abstraction - * - * It's ultra-simple. It does nothing, and can't be executed directly. - * - * @tparam T - */ -class LogicScope { - public: - static Nd4jStatus processNode(const Node* node); -}; - -} // namespace graph -} // namespace sd - -#endif // SD_LOGICSCOPE_H diff --git a/libnd4j/include/graph/logic/LogicWhile.h b/libnd4j/include/graph/logic/LogicWhile.h deleted file mode 100644 index cc4d6cdf2758..000000000000 --- a/libnd4j/include/graph/logic/LogicWhile.h +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.10.2017. -// - -#ifndef SD_LOGICWHILE_H -#define SD_LOGICWHILE_H - -#include -#include -#include - -namespace sd { -namespace graph { -/** - * This class is responsible for execution logic of While logical abstraction - * - * Basic idea is simple: we take 2 scopes, one for condition and other one for - * body. and we re-execute body as long, as condition scope evaluates to TRUE - * @tparam T - */ -class LogicWhile { - public: - static Nd4jStatus processNode(const Node* node); -}; - -} // namespace graph -} // namespace sd - -#endif // SD_LOGICWHILE_H diff --git a/libnd4j/include/graph/logic/impl/LogicConditional.cpp b/libnd4j/include/graph/logic/impl/LogicConditional.cpp deleted file mode 100644 index 392a1db6ebbc..000000000000 --- a/libnd4j/include/graph/logic/impl/LogicConditional.cpp +++ /dev/null @@ -1,140 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.10.2017. -// - -#include -#include -#include - -namespace sd { -namespace graph { -Nd4jStatus LogicConditional::processNode(const Node *node) { - throw std::runtime_error( - "LogicConditional::processNode - not implemented yet"); - /* - auto __variableSpace = graph->variableSpace(); - - auto size = node->input()->size(); - - // propagating inputs (optional) - for (int e = 0; e < size - 3; e++) { - std::pair pair(node->id(), e); - if (!__variableSpace->hasVariable(pair)) { - __variableSpace->putVariable(pair, new Variable(nullptr, nullptr, - node->id(), e)); - } - - auto va = node->input()->at(e); - - auto inputVar = __variableSpace->getVariable(va); - - auto innerVar = __variableSpace->getVariable(pair); - if (innerVar->hasNDArray()) { - // TODO: ??? - } else { - // FIXME: in some cases it's possible to have no NDArray - if (inputVar->hasNDArray()) - innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup())); - } - } - - - int scopeConditionIndex = node->input()->at(size - 3).first; - int scopeFalseIndex = node->input()->at(size - 2).first; - int scopeTrueIndex = node->input()->at(size - 1).first; - - auto scopeCondition = graph->scopeById(scopeConditionIndex); - int lastNode = 0; - for (auto v: *scopeCondition->nodes()) { - GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - lastNode = v->id(); - } - - // now we should take result of the Scope run, and evaluate it - //nd4j_debug("", ""); - auto result = __variableSpace->getVariable(lastNode)->getNDArray(); - //result->printBuffer("Result of the last node:"); - - bool isReturn = false; - - // now we're executing one of the scopes, depending on condition evaluation - if (result->e(0) == 0) { - auto scopeFalse = graph->scopeById(scopeFalseIndex); - lastNode = 0; - int nodes = scopeFalse->nodes()->size(); - for (int e = 0; e < nodes - 1; e++) { - auto v = scopeFalse->nodes()->at(e); - GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - lastNode = v->id(); - } - - // last node is either return or just last op - auto *node = scopeFalse->nodes()->at(nodes -1); - if (node->opType() == OpType_LOGIC && node->opNum() == 40) { - isReturn = true; - LogicReturn::processNode(graph, node); - } else { - GraphExecutioner::executeFlatNode(graph, node, __variableSpace); - lastNode = node->id(); - } - } else { - auto scopeTrue = graph->scopeById(scopeTrueIndex); - lastNode = 0; - int nodes = scopeTrue->nodes()->size(); - for (int e = 0; e < nodes - 1; e++) { - auto v = scopeTrue->nodes()->at(e); - GraphExecutioner::executeFlatNode(graph, v, __variableSpace); - lastNode = v->id(); - } - - // last node is either return or just last op - auto node = scopeTrue->nodes()->at(nodes -1); - if (node->opType() == OpType_LOGIC && node->opNum() == 40) { - isReturn = true; - LogicReturn::processNode(graph, node); - } else { - GraphExecutioner::executeFlatNode(graph, node, __variableSpace); - lastNode = node->id(); - } - } - - // now fetch and transfer variables to Conditional node - // but only if return wasn't called at the end of scope - if (!isReturn) { - for (int e = 0; e < DataTypeUtils::max(); e++) { - std::pair pair(lastNode, e); - std::pair pairNew(node->id(), e); - if (__variableSpace->hasVariable(pair)) { - auto array = __variableSpace->getVariable(pair)->getNDArray(); - auto newVar = new Variable(array); - newVar->setId(lastNode, e); - newVar->markRemovable(false); - - __variableSpace->putVariable(pairNew, newVar); - } else - break; - } - } - - return sd::Status::OK(); - */ -} -} // namespace graph -} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp index 829b8593719c..a128f6b9ed39 100644 --- a/libnd4j/include/graph/logic/impl/LogicExecutor.cpp +++ b/libnd4j/include/graph/logic/impl/LogicExecutor.cpp @@ -16,38 +16,23 @@ ******************************************************************************/ // -// Created by raver119 on 20.10.2017. +// @author raver119@gmail.com // -#include #include #include #include -#include #include #include #include -#include -#include #include -#include namespace sd { namespace graph { Nd4jStatus LogicExecutor::processNode(const Node *node, Stack &stack, const OptimizedGraph& graph) { switch (node->opNum()) { - case sd::logic::While: - return LogicWhile::processNode(node); - case sd::logic::Scope: - return LogicScope::processNode(node); - case sd::logic::Conditional: - return LogicConditional::processNode(node); case sd::logic::Switch: return LogicSwitch::processNode(node, stack, graph); - case sd::logic::Return: - return LogicReturn::processNode(node); - case sd::logic::Expose: - return LogicExpose::processNode(node); case sd::logic::Merge: return LogicMerge::processNode(node, stack, graph); case sd::logic::LoopCond: diff --git a/libnd4j/include/graph/logic/impl/LogicExpose.cpp b/libnd4j/include/graph/logic/impl/LogicExpose.cpp deleted file mode 100644 index 825bb0db13c2..000000000000 --- a/libnd4j/include/graph/logic/impl/LogicExpose.cpp +++ /dev/null @@ -1,31 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 12.11.2017. -// - -#include - -namespace sd { -namespace graph { -Nd4jStatus LogicExpose::processNode(const Node *node) { - // do we really want this? - return ND4J_STATUS_OK; -} -} // namespace graph -} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicReturn.cpp b/libnd4j/include/graph/logic/impl/LogicReturn.cpp deleted file mode 100644 index a9d98b3140bd..000000000000 --- a/libnd4j/include/graph/logic/impl/LogicReturn.cpp +++ /dev/null @@ -1,64 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 28.10.2017. -// - -#include "graph/logic/LogicReturn.h" - -#include -#include - -namespace sd { -namespace graph { -Nd4jStatus LogicReturn::processNode(const Node *node) { - throw std::runtime_error("LogicReturn::processNode - not implemented yet"); - /* - auto __variableSpace = graph->variableSpace(); - - for (int e = 0; e < node->input()->size(); e++) { - auto inputAddr = node->input()->at(e); - auto outputAddr = node->output()->at(e); - - // FIXME!! - outputAddr.second = e; - - if (Environment::getInstance().isDebugAndVerbose()) - nd4j_debug("Return input: <%i, %i>; Return output: <%i, %i>\n", - inputAddr.first, inputAddr.second, outputAddr.first, outputAddr.second); - - auto varIn = __variableSpace->getVariable(inputAddr); - auto varOut = __variableSpace->getVariable(outputAddr); - - nd4j_debug("Returning varType: [%s]\n", - EnumUtils::_VariableTypeToString(varIn->variableType())); - - // FIXME: this is obviously wrong, we should keep depth track for backprop - here varOut->getNDArray()->assign(varIn->getNDArray()); - - if (Environment::getInstance().isDebugAndVerbose()) - nd4j_debug("In after: [%f]; Out after: [%f]\n", - varIn->getNDArray()->meanNumber().e(0), - varOut->getNDArray()->meanNumber().e(0)); - } - - return sd::Status::OK(); - */ -} -} // namespace graph -} // namespace sd diff --git a/libnd4j/include/graph/logic/impl/LogicScope.cpp b/libnd4j/include/graph/logic/impl/LogicScope.cpp deleted file mode 100644 index c1efb207b652..000000000000 --- a/libnd4j/include/graph/logic/impl/LogicScope.cpp +++ /dev/null @@ -1,33 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.10.2017. -// - -#include -#include - -namespace sd { -namespace graph { -Nd4jStatus LogicScope::processNode(const Node *node) { - // this op is basically no-op - // we just know it exists - return sd::Status::OK(); -} -} // namespace graph -} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/graph/logic/impl/LogicWhile.cpp b/libnd4j/include/graph/logic/impl/LogicWhile.cpp deleted file mode 100644 index 2472f346daa1..000000000000 --- a/libnd4j/include/graph/logic/impl/LogicWhile.cpp +++ /dev/null @@ -1,147 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.10.2017. -// - -#include -#include -#include -#include - -namespace sd { -namespace graph { -Nd4jStatus LogicWhile::processNode(const Node *node) { - throw std::runtime_error("LogicWhile::processNode - not implemented yet"); - /* - auto __variableSpace = graph->variableSpace(); - - nd4j_debug("Starting on WHILE loop: [%i]\n", node->id()); - - // total number of inputs. 2 last inputs are scopes - int inputs = node->input()->size(); - - if (inputs < 3) { - nd4j_printf("While [%i]: loop should have at least 1 external variable - announced\n", node->id()); return ND4J_STATUS_BAD_INPUT; - } - - for (int e = 0; e < inputs - 2; e++) { - std::pair pair(node->id(), e); - if (!__variableSpace->hasVariable(pair)) { - __variableSpace->putVariable(pair, new Variable(nullptr, nullptr, - node->id(), e)); - } - - auto va = node->input()->at(e); - - auto inputVar = __variableSpace->getVariable(va); - - auto innerVar = __variableSpace->getVariable(pair); - if (innerVar->hasNDArray()) { - // TODO: ??? - } else { - // FIXME: in some cases it's possible to have no NDArray - if (inputVar->hasNDArray()) - innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup())); - } - } - - int scopeConditionIndex = node->input()->at(inputs - 2).first; - int scopeBodyIndex = node->input()->at(inputs - 1).first; - - nd4j_debug("While [%i]: got [%i] inputs\n", node->id(), - node->input()->size()); - - // we're running condition nodes now - auto scope = graph->scopeById(scopeConditionIndex); - int breaker = 0; - while (true && breaker < 10000000) { - int lastNode = 0; - // we're running condition scope first - nd4j_debug("While [%i]: got [%i] ops in condition scope [%i]\n", - node->id(), scope->nodes()->size(), scopeConditionIndex); - - for (Node* v: *scope->nodes()) { - //v->getBlock()->updateVariables(); - if (v->opType() == OpType_LOGIC) { - nd4j_debug("Falling back to logic\n",""); - LogicExecutor::processNode(graph, v); - } else { - nd4j_debug("Op [<%s>]\n", v->getName().c_str()); - Nd4jStatus status = GraphExecutioner::executeFlatNode(graph, v, - __variableSpace); if (status != ND4J_STATUS_OK) return status; - } - - lastNode = v->id(); - } - - if (!__variableSpace->hasVariable(lastNode)) { - nd4j_printf("While [%i]: got no results out of conditional loop\n", - node->id()); return ND4J_STATUS_KERNEL_FAILURE; - } - - // now we should take result of the Scope run, and evaluate it - auto result = __variableSpace->getVariable(lastNode)->getNDArray(); - - if (Environment::getInstance().isDebugAndVerbose()) - result->printBuffer("Result of the last node:"); - - // if result evaluates to 0.0 - condition returned FALSE - if (result->e(0) == 0) - break; - else { - auto scopeBody = graph->scopeById(scopeBodyIndex); - int lastNode = 0; - int e = 0; - nd4j_debug("While [%i] got [%i] ops in body scope [%i]\n", node->id(), - scopeBody->nodes()->size(), scopeBodyIndex); for (; e < - scopeBody->nodes()->size() - 1; e++) { Node* v = scopeBody->nodes()->at(e); - - if (v->opType() == OpType_LOGIC) { - nd4j_debug("Falling back to logic\n",""); - LogicExecutor::processNode(graph, v); - } else { - nd4j_debug("Op [<%s>]\n", v->getName().c_str()); - //v->getBlock()->updateVariables(); - Nd4jStatus status = GraphExecutioner::executeFlatNode(graph, - v, __variableSpace); if (status != ND4J_STATUS_OK) return status; - } - - lastNode = v->id(); - } - - // now execute return statement - Node* ret = scopeBody->nodes()->at(e); - LogicReturn::processNode(graph, ret); - } - - breaker++; - } - - // if we've hit breaker limit - we should notify about that - if (breaker >= 10000000) { - nd4j_printf("While condition seems to be never ending, aborting...\n", - breaker); return ND4J_STATUS_KERNEL_FAILURE; - } - - return sd::Status::OK(); - */ -} -} // namespace graph -} // namespace sd diff --git a/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp b/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp index 5cc96512c354..ac2026f00fe7 100644 --- a/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ExecutionLayerTests.cpp @@ -62,3 +62,26 @@ TEST_F(ExecutionLayerTests, test_reassign_1) { seq = layer[1]; ASSERT_EQ(2, seq.length()); } + +TEST_F(ExecutionLayerTests, test_purge_1) { + ExecutionLayer layer; + OpSequence sequence1, sequence2; + + Node a(sd::ops::add(), "add"); + Node m(sd::ops::multiply(), "mul"); + + Context ctx1(1); + Context ctx2(2); + + sequence1.append(a, ctx1); + sequence1.append(m, ctx2); + + layer.append(sequence1); + layer.append(sequence2); + + ASSERT_EQ(2, layer.width()); + + layer.purgeEmptySequences(); + + ASSERT_EQ(1, layer.width()); +} \ No newline at end of file From e71e9d18d4bfce4b40c2d10a3e3b570624bbd533 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 23 Jun 2020 11:24:14 +0300 Subject: [PATCH 221/233] OpSequence::append test Signed-off-by: raver119@gmail.com --- .../layers_tests/OpSequenceTests.cpp | 68 ++++++++++++------- 1 file changed, 42 insertions(+), 26 deletions(-) diff --git a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp index 183650a4dee0..910e69c89b28 100644 --- a/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpSequenceTests.cpp @@ -38,42 +38,58 @@ class OpSequenceTests : public testing::Test { OpSequenceTests() {} }; -// TEST_F(OpSequenceTests, test_iterator_1) { -// Graph graph; -// OpSequence sequence; +TEST_F(OpSequenceTests, test_append_1) { + OpSequence sequenceA; + OpSequence sequenceB; -// ASSERT_EQ(0, sequence.length()); + ASSERT_EQ(0, sequenceA.length()); -// ops::add op1; -// ops::multiply op2; + Context ctx1(1); + Context ctx2(2); -// Context ctx1(1); -// Context ctx2(2); + sequenceA.append(Node(sd::ops::add(), "add"), ctx1); + sequenceB.append(Node(sd::ops::multiply(), "mul"), ctx2); -// sequence.append(&op1, ctx1); -// sequence.append(&op2, ctx2); + ASSERT_EQ(1, sequenceA.length()); -// ASSERT_EQ(2, sequence.length()); + sequenceA.append(sequenceB); -// int cnt = 1; -// for (const auto &v : sequence) { -// ASSERT_EQ(cnt++, v.contextPrototype().nodeId()); -// } + ASSERT_EQ(2, sequenceA.length()); +} -// ASSERT_EQ(3, cnt); +TEST_F(OpSequenceTests, test_iterator_1) { + Graph graph; + OpSequence sequence; -// OptimizedGraph optimizedGraph; -// ASSERT_EQ(0, optimizedGraph.layers()); + ASSERT_EQ(0, sequence.length()); -// optimizedGraph.append(sequence); -// ASSERT_EQ(1, optimizedGraph.layers()); + Context ctx1(1); + Context ctx2(2); -// auto layer = optimizedGraph.layer(0); + sequence.append(Node(ops::add(), "add"), ctx1); + sequence.append(Node(ops::divide(), "div"), ctx2); -// // we expect exactly 1 sequence in this layer -// ASSERT_EQ(1, layer.width()); + ASSERT_EQ(2, sequence.length()); -// auto seq = layer[0]; + int cnt = 1; + for (const auto &v : sequence) { + ASSERT_EQ(cnt++, v.protoContext().nodeId()); + } -// ASSERT_EQ(2, seq.length()); -// } + ASSERT_EQ(3, cnt); + + OptimizedGraph optimizedGraph; + ASSERT_EQ(0, optimizedGraph.layers()); + + optimizedGraph.append(sequence); + ASSERT_EQ(1, optimizedGraph.layers()); + + auto layer = optimizedGraph.layer(0); + + // we expect exactly 1 sequence in this layer + ASSERT_EQ(1, layer.width()); + + auto seq = layer[0]; + + ASSERT_EQ(2, seq.length()); +} From 44d1efa2d39127789e165345c988b3e76d1b0a0a Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Tue, 23 Jun 2020 11:44:21 +0300 Subject: [PATCH 222/233] few tweaks to make CLang happy Signed-off-by: raver119@gmail.com --- libnd4j/include/array/ConstantHolder.h | 4 ++-- libnd4j/include/array/impl/ConstantHolder.cpp | 18 ++++++++++++++++++ .../layers_tests/ConvolutionTests1.cpp | 2 +- .../layers_tests/ConvolutionTests2.cpp | 2 +- .../layers_tests/DeclarableOpsTests10.cpp | 2 +- .../layers_tests/DeclarableOpsTests13.cpp | 2 +- .../layers_tests/DeclarableOpsTests4.cpp | 2 +- .../layers_tests/DeclarableOpsTests6.cpp | 2 +- .../layers_tests/DeclarableOpsTests7.cpp | 2 +- .../layers_tests/DeclarableOpsTests8.cpp | 6 ++---- .../tests_cpu/layers_tests/LegacyOpsTests.cpp | 18 +++++++++--------- .../tests_cpu/layers_tests/NativeOpsTests.cpp | 4 ++-- .../tests_cpu/layers_tests/PlaygroundTests.cpp | 2 -- 13 files changed, 40 insertions(+), 26 deletions(-) diff --git a/libnd4j/include/array/ConstantHolder.h b/libnd4j/include/array/ConstantHolder.h index 0e606b7a421b..1006692b6e57 100644 --- a/libnd4j/include/array/ConstantHolder.h +++ b/libnd4j/include/array/ConstantHolder.h @@ -41,8 +41,8 @@ class ConstantHolder { ConstantHolder() = default; ~ConstantHolder() = default; - ConstantHolder& operator=(const ConstantHolder& other) = default; - ConstantHolder& operator=(ConstantHolder&& other) = default; + ConstantHolder& operator=(const ConstantHolder& other); + ConstantHolder& operator=(ConstantHolder&& other); bool hasBuffer(sd::DataType dataType); diff --git a/libnd4j/include/array/impl/ConstantHolder.cpp b/libnd4j/include/array/impl/ConstantHolder.cpp index b8146a1aeee2..35c45335baf5 100644 --- a/libnd4j/include/array/impl/ConstantHolder.cpp +++ b/libnd4j/include/array/impl/ConstantHolder.cpp @@ -28,6 +28,24 @@ ConstantHolder::ConstantHolder(const ConstantHolder& other) { _deviceId = other._deviceId; } +ConstantHolder &ConstantHolder::operator=(const ConstantHolder &other) { + if (this == &other) return *this; + + _buffers = other._buffers; + _deviceId = other._deviceId; + + return *this; +} + +ConstantHolder &ConstantHolder::operator=(ConstantHolder &&other) { + if (this == &other) return *this; + + _buffers = std::move(other._buffers); + _deviceId = other._deviceId; + + return *this; +} + bool ConstantHolder::hasBuffer(sd::DataType dataType) { return _buffers.count(dataType) > 0; } diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index c10e5e8eb95e..087c12bfebc3 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -51,7 +51,7 @@ class TypedConvolutionTests1 : public testing::Test { }; typedef ::testing::Types TestingTypes; -TYPED_TEST_CASE(TypedConvolutionTests1, TestingTypes); +TYPED_TEST_SUITE(TypedConvolutionTests1, TestingTypes); ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, conv2d_1) { diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index 87fe7e16d003..f9f39eb086b2 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -103,7 +103,7 @@ class TypedConvolutionTests2 : public testing::Test { }; typedef ::testing::Types TestingTypes; -TYPED_TEST_CASE(TypedConvolutionTests2, TestingTypes); +TYPED_TEST_SUITE(TypedConvolutionTests2, TestingTypes); ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests2, deconv2d_tf_test2) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index d934fc700051..785b48417078 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -46,7 +46,7 @@ class TypedDeclarableOpsTests10 : public testing::Test { }; typedef ::testing::Types TestingTypes; -TYPED_TEST_CASE(TypedDeclarableOpsTests10, TestingTypes); +TYPED_TEST_SUITE(TypedDeclarableOpsTests10, TestingTypes); TEST_F(DeclarableOpsTests10, Test_ArgMax_1) { auto x = NDArrayFactory::create('c', {3, 3}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 0c62808c4ca6..04d9cb615996 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -46,7 +46,7 @@ class TypedDeclarableOpsTests13 : public testing::Test { }; typedef ::testing::Types TestingTypes; -TYPED_TEST_CASE(TypedDeclarableOpsTests13, TestingTypes); +TYPED_TEST_SUITE(TypedDeclarableOpsTests13, TestingTypes); TEST_F(DeclarableOpsTests13, test_pow_1) { auto x = NDArrayFactory::create('c', {2, 2}, {2.f, 2.f, 2.f, 2.f}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index e83be07a2a92..148a249e4612 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -52,7 +52,7 @@ class TypedDeclarableOpsTests4 : public testing::Test { }; typedef ::testing::Types TestingTypes; -TYPED_TEST_CASE(TypedDeclarableOpsTests4, TestingTypes); +TYPED_TEST_SUITE(TypedDeclarableOpsTests4, TestingTypes); ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index e76f33d4a0c3..cc482adfa202 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -1780,7 +1780,7 @@ TEST_F(DeclarableOpsTests6, Test_Reduce3_Edge) { std::vector dims = {0, 1}; auto z = x.applyReduce3(reduce3::CosineSimilarity, y, dims); - ASSERT_TRUE(&z != nullptr); + ASSERT_TRUE(z.defined()); } /////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index 7b8e5c4d7c9d..ffa56549245f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -47,7 +47,7 @@ class TypedDeclarableOpsTests7 : public testing::Test { }; typedef ::testing::Types TestingTypes; -TYPED_TEST_CASE(TypedDeclarableOpsTests7, TestingTypes); +TYPED_TEST_SUITE(TypedDeclarableOpsTests7, TestingTypes); TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LARGE) { double inputData[150] = { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp index 4357a994c004..b51ff96745f7 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp @@ -45,7 +45,7 @@ class TypedDeclarableOpsTests8 : public testing::Test { }; typedef ::testing::Types TestingTypes; -TYPED_TEST_CASE(TypedDeclarableOpsTests8, TestingTypes); +TYPED_TEST_SUITE(TypedDeclarableOpsTests8, TestingTypes); //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, reduceVariance_test1) { @@ -466,9 +466,7 @@ TEST_F(DeclarableOpsTests8, reduceVarianceBP_test02) { 'c', {3, 4}, {-4.000000f, -8.000000f, -12.000000f, -16.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 4.000000f, 8.000000f, 12.000000f, 16.000000f}); - auto axes = NDArrayFactory::create({ - (int)0, - }); + auto axes = NDArrayFactory::create(0); x.linspace(1); sd::ops::reduce_variance_bp op; diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp index 28ab790eef7f..899a5ec48142 100644 --- a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp @@ -350,7 +350,7 @@ TEST_F(LegacyOpsTests, BroadcastingTests_2) { // shape::printShapeInfoLinear("tad shape", tad.tadOnlyShapeInfo); auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions( - y.shapeInfo(), {axis}); + y.shapeInfo(), axis); NDArray::prepareSpecialUse({&y}, {&x}); @@ -439,9 +439,9 @@ TEST_F(LegacyOpsTests, Reduce3_2) { #endif auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( - x.shapeInfo(), {1}); + x.shapeInfo(), 1); auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions( - y.shapeInfo(), {1}); + y.shapeInfo(), 1); NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); OpaqueDataBuffer xBuf(x.dataBuffer()); @@ -494,9 +494,9 @@ TEST_F(LegacyOpsTests, Reduce3_3) { #endif auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( - x.shapeInfo(), {1}); + x.shapeInfo(), 1); auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions( - y.shapeInfo(), {1}); + y.shapeInfo(), 1); NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); OpaqueDataBuffer xBuf(x.dataBuffer()); @@ -549,9 +549,9 @@ TEST_F(LegacyOpsTests, Reduce3_4) { #endif auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( - x.shapeInfo(), {1}); + x.shapeInfo(), 1); auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions( - y.shapeInfo(), {1}); + y.shapeInfo(), 1); NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); OpaqueDataBuffer xBuf(x.dataBuffer()); @@ -606,9 +606,9 @@ TEST_F(LegacyOpsTests, Reduce3_5) { #endif auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions( - x.shapeInfo(), {1}); + x.shapeInfo(), 1); auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions( - y.shapeInfo(), {1}); + y.shapeInfo(), 1); NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); diff --git a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp index 773e0728da98..9295db2af829 100644 --- a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp @@ -1169,7 +1169,7 @@ TEST_F(NativeOpsTests, ShuffleTest_1) { (Nd4jPointer)z.specialShapeInfo()}; int shuffleMap[] = {1, 0, 4, 3, 2}; auto zTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions( - x.shapeInfo(), {1}); + x.shapeInfo(), 1); Nd4jPointer zListOffset[] = {(Nd4jPointer)zTadPack.platformOffsets(), (Nd4jPointer)zTadPack.platformOffsets()}; Nd4jPointer zListTADs[] = {(Nd4jPointer)zTadPack.platformShapeInfo(), @@ -1365,7 +1365,7 @@ TEST_F(NativeOpsTests, SortTest_4) { std::vector dims({1}); auto packX = ConstantTadHelper::getInstance().tadForDimensions( - sortedVals.shapeInfo(), {1}); + sortedVals.shapeInfo(), 1); ::sortTad(nullptr, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), dims.data(), dims.size(), packX.platformShapeInfo(), diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 8ef3c4b5dfe5..9244dbdc3266 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -600,8 +600,6 @@ timeStart).count(); values.emplace_back(outerTime); } } - -/* TEST_F(PlaygroundTests, test_broadcast_1) { int pool = 500; std::vector aX(pool); From c555c446d4619f36bc69f712752b740d30274609 Mon Sep 17 00:00:00 2001 From: Yurii Date: Thu, 25 Jun 2020 16:11:14 +0300 Subject: [PATCH 223/233] - provide different sorting algorithm for graph with loops Signed-off-by: Yurii --- .../include/graph/execution/ExecutionLayer.h | 3 + .../graph/execution/impl/ExecutionLayer.cpp | 8 + libnd4j/include/graph/impl/OptimizedGraph.cpp | 303 +++++++++++------- 3 files changed, 202 insertions(+), 112 deletions(-) diff --git a/libnd4j/include/graph/execution/ExecutionLayer.h b/libnd4j/include/graph/execution/ExecutionLayer.h index d53aba220dcf..16ce1bf4096f 100644 --- a/libnd4j/include/graph/execution/ExecutionLayer.h +++ b/libnd4j/include/graph/execution/ExecutionLayer.h @@ -59,6 +59,9 @@ class SD_EXPORT ExecutionLayer { const OpSequence& at(uint64_t index) const; const OpSequence& operator[](uint64_t index) const; + OpSequence& at(uint64_t index); + OpSequence& operator[](uint64_t index); + /** * This method appends OpSequence to the end of this layer * @param sequence diff --git a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp index d01fc2dccc9e..50f8a600d6df 100644 --- a/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp +++ b/libnd4j/include/graph/execution/impl/ExecutionLayer.cpp @@ -32,10 +32,18 @@ const OpSequence &ExecutionLayer::at(uint64_t index) const { return _sequences[index]; } +OpSequence &ExecutionLayer::at(uint64_t index) { + return _sequences[index]; +} + const OpSequence &ExecutionLayer::operator[](uint64_t index) const { return at(index); } +OpSequence &ExecutionLayer::operator[](uint64_t index) { + return at(index); +} + bool ExecutionLayer::hasNode(int nodeId) const { for (const auto &v:_sequences) if (v.hasNode(nodeId)) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index d448658d99a3..24a1057aabab 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -61,30 +61,40 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS std::vector _opSeq = {}; std::vector _in = {}; std::vector _out = {}; + bool _isActive = false; }; MAP_IMPL workMap; // key is node id, value is class NodeInfo containing auxiliary information (layer number this node belongs to, input/output nodes, OpSequence that starts from this node) + bool containsLoop = false; + // create workMap, fill vectors containing input and output nodes per each node, and find start nodes std::vector startNodes; + std::unordered_map> endsOfFrame; // key - id of frame, value is vector containing ids of all exits from this frame + std::hash hasher; for (auto& p : _nodesMap) { const auto& inputs = p.second.inputs(); + const auto& nameOfNode = p.second.name(); - if(p.second.name().find("Exit") != std::string::npos) { + if(nameOfNode.find("Exit") != std::string::npos) { + containsLoop = true; const int idOfEnter = _nodesMap[_nodesMap[inputs[0].first].inputs()[0].first].inputs()[0].first; p.second.setFrameId(_nodesMap[idOfEnter].frameId()); _nodesMap[idOfEnter].setExitId(p.first); + } else if (nameOfNode.find("NextIteration") != std::string::npos) { + const std::string frameName = nameOfNode.substr(0, nameOfNode.find_last_of("/")); + endsOfFrame[hasher(frameName)].push_back(p.first); } for (int i = 0; i < inputs.size(); ++i) { if (_nodesMap.count(inputs[i].first) != 0) { // is op _nodesMap[inputs[i].first].pickOutput(p.first, inputs[i].second); - if(_nodesMap[inputs[i].first].name().find("NextIteration") == std::string::npos) { + // if(_nodesMap[inputs[i].first].name().find("NextIteration") == std::string::npos) { workMap[inputs[i].first]._out.push_back(p.first); workMap[p.first]._in.push_back(inputs[i].first); - } + // } } else { // is variable @@ -93,10 +103,10 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS for (int j = 0; j < depends.size(); ++j) { if(std::find(workMap[p.first]._in.begin(), workMap[p.first]._in.end(), depends[j].first) == workMap[p.first]._in.end()) { _nodesMap[depends[j].first].pickOutput(p.first, depends[j].second); - if(_nodesMap[depends[j].first].name().find("NextIteration") == std::string::npos) { + // if(_nodesMap[depends[j].first].name().find("NextIteration") == std::string::npos) { workMap[depends[j].first]._out.push_back(p.first); workMap[p.first]._in.push_back(depends[j].first); - } + // } } } } @@ -127,151 +137,220 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS // } // printf("\n\n\n\n\n"); + if (containsLoop) { + + std::vector seq; + uint numOfActive = 0; + auto it = workMap.begin(); + + while (numOfActive < workMap.size()) { + + bool makeActive = true; + + if(!it->second._isActive) { + + const auto& nameOfNode = _nodesMap[it->first].name(); + + for (const auto& inId : it->second._in) { + + if (_nodesMap[inId].name().find("NextIteration") != std::string::npos) + continue; + + if (!workMap[inId]._isActive) { + makeActive = false; + } + else if (nameOfNode.find("Exit") != std::string::npos) { + const std::string frameName = nameOfNode.substr(0, nameOfNode.find_last_of("/")); + + for (const auto& j : endsOfFrame[hasher(frameName)]) { + if(!workMap[j]._isActive){ + makeActive = false; + break; + } + } + } + if(!makeActive) + break; + } + } - // collect OpSequences (fill _opSeq) - std::vector nodesToDelete; - for (auto& p : workMap) { + if(makeActive && !it->second._isActive) { + ++numOfActive; + it->second._isActive = true; + seq.push_back(it->first); + } - auto& out = p.second._out; - while(out.size() == 1 && workMap[out[0]]._in.size() == 1) { - nodesToDelete.push_back(out[0]); - p.second._opSeq.push_back(out[0]); - out = std::move(workMap[out[0]]._out); + if(++it == workMap.end()) + it = workMap.begin(); } - } + OpSequence sequence; + for (auto i : seq) + sequence.append(_nodesMap.at(i), _nodesMap.at(i).contextPrototype()); + _sortedGraph.emplace_back(ExecutionLayer({sequence})); - // delete nodes present in _opSeq, their ids are already stored in nodesToDelete - for (const auto& i : nodesToDelete) - workMap.erase(i); - // lambda for topological sort - std::function visit = [&visit, &workMap, this] (const int id, const uint layerNum, uint& numOfLayers) { - if(layerNum <= workMap[id]._layerNum) - return; - workMap[id]._layerNum = layerNum; - if(numOfLayers < layerNum) - numOfLayers = layerNum; + // for (const auto& i : seq) + // printf("%i ", i); + // printf("\n\n"); + // for (const auto& i : workMap) + // if(i.second._isActive) + // printf("%i ",i.first); + // printf("\n"); + // for (const auto& i : workMap) + // if(!i.second._isActive) + // printf("%i ",i.first); - const bool isSwitch = this->_nodesMap[id].name().find("Switch") != std::string::npos; - if(!isSwitch) { - for (const auto& nextId : workMap[id]._out) - visit(nextId, layerNum+1, numOfLayers); + } + else { + // collect OpSequences (fill _opSeq) + std::vector nodesToDelete; + for (auto& p : workMap) { + + auto& out = p.second._out; + while(out.size() == 1 && workMap[out[0]]._in.size() == 1) { + nodesToDelete.push_back(out[0]); + p.second._opSeq.push_back(out[0]); + out = std::move(workMap[out[0]]._out); + } } - else { - if(this->_nodesMap[id].outputs()[0].second == 1) { - visit(_nodesMap[id].outputs()[0].first, layerNum+1, numOfLayers); // true branch - visit(_nodesMap[id].outputs()[1].first, numOfLayers+1, numOfLayers); // false branch + + + // delete nodes present in _opSeq, their ids are already stored in nodesToDelete + for (const auto& i : nodesToDelete) + workMap.erase(i); + + // lambda for topological sort + std::function visit = [&visit, &workMap, this] (const int id, const uint layerNum, uint& numOfLayers) { + + if(layerNum <= workMap[id]._layerNum) + return; + workMap[id]._layerNum = layerNum; + if(numOfLayers < layerNum) + numOfLayers = layerNum; + + const bool isSwitch = this->_nodesMap[id].name().find("Switch") != std::string::npos; + if(!isSwitch) { + for (const auto& nextId : workMap[id]._out) + visit(nextId, layerNum+1, numOfLayers); } else { - visit(_nodesMap[id].outputs()[1].first, layerNum+1, numOfLayers); // true branch - visit(_nodesMap[id].outputs()[0].first, numOfLayers+1, numOfLayers); // false branch + if(this->_nodesMap[id].outputs()[0].second == 1) { + visit(_nodesMap[id].outputs()[0].first, layerNum+1, numOfLayers); // true branch + visit(_nodesMap[id].outputs()[1].first, numOfLayers+1, numOfLayers); // false branch + } + else { + visit(_nodesMap[id].outputs()[1].first, layerNum+1, numOfLayers); // true branch + visit(_nodesMap[id].outputs()[0].first, numOfLayers+1, numOfLayers); // false branch + } } - } - }; + }; - // perform topological sort - uint numOfLayers = 0; - for (const auto& id : startNodes) - for (const auto& nextId : workMap[id]._out) - visit(nextId, 1, numOfLayers); + // perform topological sort + uint numOfLayers = 0; + for (const auto& id : startNodes) + for (const auto& nextId : workMap[id]._out) + visit(nextId, 1, numOfLayers); - // fill vectors with layers - std::vector sortedGraphTemp(numOfLayers+1); - for (const auto& p : workMap) { + // fill vectors with layers + std::vector sortedGraphTemp(numOfLayers+1); + for (const auto& p : workMap) { - OpSequence seq; - seq.append(_nodesMap.at(p.first), _nodesMap.at(p.first).contextPrototype()); + OpSequence seq; + seq.append(_nodesMap.at(p.first), _nodesMap.at(p.first).contextPrototype()); - for (const auto& id : p.second._opSeq) - seq.append(_nodesMap.at(id), _nodesMap.at(id).contextPrototype()); + for (const auto& id : p.second._opSeq) + seq.append(_nodesMap.at(id), _nodesMap.at(id).contextPrototype()); - // while(sortedGraphTemp.size() <= p.second._layerNum) - // sortedGraphTemp.emplace_back(ExecutionLayer()); + // while(sortedGraphTemp.size() <= p.second._layerNum) + // sortedGraphTemp.emplace_back(ExecutionLayer()); - sortedGraphTemp[p.second._layerNum].append(std::move(seq)); - } + sortedGraphTemp[p.second._layerNum].append(std::move(seq)); + } - const char delimiter = '/'; - for (int i0 = 0; i0 < sortedGraphTemp.size(); ++i0) { - for (int i1 = 0; i1 < sortedGraphTemp[i0].width(); ++i1) { - for (int i2 = 0; i2 < sortedGraphTemp[i0][i1].length(); ++i2) { - // if(!p0->second._opSeq.empty() && p0->second._opSeq[0] == -1) - // continue; - - auto id = sortedGraphTemp[i0][i1][i2].node().id(); - auto* name = &_nodesMap[id].name(); - if (name->find("Enter") == std::string::npos) - continue; - std::string loopName = name->substr(0, name->find(delimiter)); // evaluate name of loop - for (int j0 = i0; j0 < sortedGraphTemp.size(); ++j0) { - for (int j1 = j0 == i0 ? i1 + 1 : 0; j1 < sortedGraphTemp[j0].width(); ++j1) { - for (int j2 = 0; j2 < sortedGraphTemp[j0][j1].length(); ++j2) { - id = sortedGraphTemp[j0][j1][j2].node().id(); - name = &_nodesMap[id].name(); - if (name->find(loopName) == std::string::npos) - continue; - for (int k = 0; k < sortedGraphTemp[j0][j1].length(); ++k) - const_cast(sortedGraphTemp[i0][i1]).append(sortedGraphTemp[j0][j1][k]); - const_cast(sortedGraphTemp[j0][j1]) = OpSequence(); - break; + const char delimiter = '/'; + for (int i0 = 0; i0 < sortedGraphTemp.size(); ++i0) { + for (int i1 = 0; i1 < sortedGraphTemp[i0].width(); ++i1) { + for (int i2 = 0; i2 < sortedGraphTemp[i0][i1].length(); ++i2) { + // if(!p0->second._opSeq.empty() && p0->second._opSeq[0] == -1) + // continue; + + auto id = sortedGraphTemp[i0][i1][i2].node().id(); + auto* name = &_nodesMap[id].name(); + if (name->find("Enter") == std::string::npos) + continue; + std::string loopName = name->substr(0, name->find(delimiter)); // evaluate name of loop + for (int j0 = i0; j0 < sortedGraphTemp.size(); ++j0) { + for (int j1 = j0 == i0 ? i1 + 1 : 0; j1 < sortedGraphTemp[j0].width(); ++j1) { + for (int j2 = 0; j2 < sortedGraphTemp[j0][j1].length(); ++j2) { + id = sortedGraphTemp[j0][j1][j2].node().id(); + name = &_nodesMap[id].name(); + if (name->find(loopName) == std::string::npos) + continue; + for (int k = 0; k < sortedGraphTemp[j0][j1].length(); ++k) + const_cast(sortedGraphTemp[i0][i1]).append(sortedGraphTemp[j0][j1][k]); + const_cast(sortedGraphTemp[j0][j1]) = OpSequence(); + break; + } } } + break; } - break; } } - } - // delete empty layers - for (auto it = sortedGraphTemp.begin(); it != sortedGraphTemp.end(); ++it) { - bool isEmpty = true; - for (uint i = 0; i < it->width(); ++i) { - if(it->at(i).length() != 0) - isEmpty = false; - break; + // delete empty layers + for (auto it = sortedGraphTemp.begin(); it != sortedGraphTemp.end(); ++it) { + bool isEmpty = true; + for (uint i = 0; i < it->width(); ++i) { + if(it->at(i).length() != 0) + isEmpty = false; + break; + } + + if(isEmpty) + sortedGraphTemp.erase(it--); } - if(isEmpty) - sortedGraphTemp.erase(it--); - } + // check whether there are layers with one OpSequence which in turn contains only one op + bool isLayerWithOneOp = false; + for(auto& layer : sortedGraphTemp) { + if(layer.width() == 1 && layer.at(0).length() == 1) { + isLayerWithOneOp = true; + break; + } + } - // check whether there are layers with one OpSequence which in turn contains only one op - bool isLayerWithOneOp = false; - for(auto& layer : sortedGraphTemp) { - if(layer.width() == 1 && layer.at(0).length() == 1) { - isLayerWithOneOp = true; - break; + // fill _sortedGraph + if(!isLayerWithOneOp) { + _sortedGraph = std::move(sortedGraphTemp); } - } + else { + for (uint i = 0; i < sortedGraphTemp.size();) { - // fill _sortedGraph - if(!isLayerWithOneOp) { - _sortedGraph = std::move(sortedGraphTemp); - } - else { - for (uint i = 0; i < sortedGraphTemp.size();) { + OpSequence seq; + while(i < sortedGraphTemp.size() && sortedGraphTemp[i].width() == 1 && sortedGraphTemp[i].at(0).length() == 1) + seq.append(std::move(sortedGraphTemp[i++].at(0).at(0))); - OpSequence seq; - while(i < sortedGraphTemp.size() && sortedGraphTemp[i].width() == 1 && sortedGraphTemp[i].at(0).length() == 1) - seq.append(std::move(sortedGraphTemp[i++].at(0).at(0))); + if(seq.length() != 0) + _sortedGraph.emplace_back(ExecutionLayer({seq})); + else + _sortedGraph.emplace_back(std::move(sortedGraphTemp[i++])); + } - if(seq.length() != 0) - _sortedGraph.emplace_back(ExecutionLayer({seq})); - else - _sortedGraph.emplace_back(std::move(sortedGraphTemp[i++])); } + // sort _sortedGraph + // for (auto& l : _sortedGraph) + // l.sortOpSequences(); + + // clean up before exiting + purgeEmptyLayers(); } - // sort _sortedGraph - // for (auto& l : _sortedGraph) - // l.sortOpSequences(); - // clean up before exiting - purgeEmptyLayers(); } void OptimizedGraph::append(const OpSequence &sequence) { From 035416cfce508e87744b6a8af8f7ffb494170f46 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 25 Jun 2020 16:38:36 +0300 Subject: [PATCH 224/233] disable disabled Merge Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/logic/LogicSwitch.h | 2 +- libnd4j/include/graph/logic/LogicUtilities.h | 42 ++++++++++ .../include/graph/logic/impl/LogicMerge.cpp | 3 +- .../include/graph/logic/impl/LogicSwitch.cpp | 54 +------------ .../graph/logic/impl/LogicUtilities.cpp | 77 +++++++++++++++++++ 5 files changed, 125 insertions(+), 53 deletions(-) create mode 100644 libnd4j/include/graph/logic/LogicUtilities.h create mode 100644 libnd4j/include/graph/logic/impl/LogicUtilities.cpp diff --git a/libnd4j/include/graph/logic/LogicSwitch.h b/libnd4j/include/graph/logic/LogicSwitch.h index e34d475a501d..a7d21fd053fe 100644 --- a/libnd4j/include/graph/logic/LogicSwitch.h +++ b/libnd4j/include/graph/logic/LogicSwitch.h @@ -16,7 +16,7 @@ ******************************************************************************/ // -// Created by raver119 on 21.10.17. +// @author raver119@gmail.com // #ifndef SD_LOGICSWITCH_H diff --git a/libnd4j/include/graph/logic/LogicUtilities.h b/libnd4j/include/graph/logic/LogicUtilities.h new file mode 100644 index 000000000000..93f9d4677085 --- /dev/null +++ b/libnd4j/include/graph/logic/LogicUtilities.h @@ -0,0 +1,42 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SD_LOGICUTILITIES_H +#define SD_LOGICUTILITIES_H + +#include +#include +#include +#include + +namespace sd { +namespace graph { + +class SD_EXPORT LogicUtilities { + public: + static void disableBranch(StackFrame &frame, VariableProxy &varSpace, const OptimizedGraph &graph, const Node* node, bool branchToDisable); + static void disableBranch(StackFrame &frame, VariableProxy &varSpace, const OptimizedGraph &graph, const Node* node); +}; + + +} +} + +#endif //SD_LOGICUTILITIES_H diff --git a/libnd4j/include/graph/logic/impl/LogicMerge.cpp b/libnd4j/include/graph/logic/impl/LogicMerge.cpp index aea0c8b3bb61..815b322bf658 100644 --- a/libnd4j/include/graph/logic/impl/LogicMerge.cpp +++ b/libnd4j/include/graph/logic/impl/LogicMerge.cpp @@ -22,6 +22,7 @@ #include #include #include +#include namespace sd { namespace graph { @@ -52,7 +53,7 @@ Nd4jStatus LogicMerge::processNode(const Node *node, Stack &stack, const Optimiz // if both inputs are unavailable - this node is disabled and must be disabled if (!checkViability(stack, inputs[0], inputs[1])) { nd4j_printf("Both inputs absent, skipping\n", ""); - // TODO: disable this branch + LogicUtilities::disableBranch(frame, varSpace, graph, node); return Status::OK(); } diff --git a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp index 461036f0a098..53310a79a997 100644 --- a/libnd4j/include/graph/logic/impl/LogicSwitch.cpp +++ b/libnd4j/include/graph/logic/impl/LogicSwitch.cpp @@ -22,59 +22,11 @@ #include #include #include +#include namespace sd { namespace graph { -static void disableBranch(StackFrame &frame, VariableProxy &varSpace, const OptimizedGraph &graph, const Node* node) { - const auto &outputs = node->outputs(); - - // we're going to disable certain external variables, if they depend on a current disabled node - // FIXME: it can be done in a better way rather than O(n^2) - for (const auto &var: varSpace.externalPaired()) { - for (const auto &d: var.second->dependencies()) { - if (d.first == node->id()) - frame.disableNode(var.second->id()); - } - } - - // we're going to roll through all consumers - for (const auto &o:outputs) { - if (graph.nodesMap().count(o.first) == 0) - throw std::runtime_error("pew-pew"); - - // now fetch disabled node - const auto &n = graph.nodesMap().at(o.first); - - // edge case here: don't disable Merge node - if (n.opType() == OpType_LOGIC && n.opNum() == sd::logic::Merge) - continue; - - // disable each consumer - frame.disableNode(o.first); - - // do recursive magic - disableBranch(frame, varSpace, graph, &n); - } -} - -static void disableBranch(StackFrame &frame, VariableProxy &varSpace, const OptimizedGraph &graph, const Node* node, bool branchToDisable) { - const auto &outputs = node->outputs(); - int second = branchToDisable ? 1 : 0; - - for (const auto &o:outputs) { - if (o.second == second) { - frame.disableNode(o.first); - - if (graph.nodesMap().count(o.first) == 0) - throw std::runtime_error("pew-pew"); - - const auto &n = graph.nodesMap().at(o.first); - - disableBranch(frame, varSpace, graph, &n); - } - } -} Nd4jStatus LogicSwitch::processNode(const Node* node, Stack &stack, const OptimizedGraph& graph) { // getting current frame first @@ -99,11 +51,11 @@ Nd4jStatus LogicSwitch::processNode(const Node* node, Stack &stack, const Optimi if (boolean->getNDArray()->e(0)) { // true branch varSpace.putVariable(std::pair{node->id(), 1}, *input->getNDArray()); - disableBranch(frame, varSpace, graph, node, false); + LogicUtilities::disableBranch(frame, varSpace, graph, node, false); } else { // false branch varSpace.putVariable(std::pair{node->id(), 0}, *input->getNDArray()); - disableBranch(frame, varSpace, graph, node, true); + LogicUtilities::disableBranch(frame, varSpace, graph, node, true); } return Status::OK(); diff --git a/libnd4j/include/graph/logic/impl/LogicUtilities.cpp b/libnd4j/include/graph/logic/impl/LogicUtilities.cpp new file mode 100644 index 000000000000..1adca748a95a --- /dev/null +++ b/libnd4j/include/graph/logic/impl/LogicUtilities.cpp @@ -0,0 +1,77 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +namespace graph { + +void LogicUtilities::disableBranch(StackFrame &frame, VariableProxy &varSpace, const OptimizedGraph &graph, const Node* node) { + const auto &outputs = node->outputs(); + + // we're going to disable certain external variables, if they depend on a current disabled node + // FIXME: it can be done in a better way rather than O(n^2) + for (const auto &var: varSpace.externalPaired()) { + for (const auto &d: var.second->dependencies()) { + if (d.first == node->id()) + frame.disableNode(var.second->id()); + } + } + + // we're going to roll through all consumers + for (const auto &o:outputs) { + if (graph.nodesMap().count(o.first) == 0) + throw std::runtime_error("pew-pew"); + + // now fetch disabled node + const auto &n = graph.nodesMap().at(o.first); + + // edge case here: don't disable Merge node + if (n.opType() == OpType_LOGIC && n.opNum() == sd::logic::Merge) + continue; + + // disable each consumer + frame.disableNode(o.first); + + // do recursive magic + disableBranch(frame, varSpace, graph, &n); + } +} + +void LogicUtilities::disableBranch(StackFrame &frame, VariableProxy &varSpace, const OptimizedGraph &graph, const Node* node, bool branchToDisable) { + const auto &outputs = node->outputs(); + int second = branchToDisable ? 1 : 0; + + for (const auto &o:outputs) { + if (o.second == second) { + frame.disableNode(o.first); + + if (graph.nodesMap().count(o.first) == 0) + throw std::runtime_error("pew-pew"); + + const auto &n = graph.nodesMap().at(o.first); + + disableBranch(frame, varSpace, graph, &n); + } + } +} + +} +} From 39963b3295da8a98fe81ef838f1167f44b097e43 Mon Sep 17 00:00:00 2001 From: Yurii Date: Thu, 25 Jun 2020 17:32:05 +0300 Subject: [PATCH 225/233] - split graph sorting into two separate algorithms: first one for usual graphs without enter nodes, second for graphs containing enter nodes Signed-off-by: Yurii --- libnd4j/include/graph/OptimizedGraph.h | 3 + libnd4j/include/graph/impl/OptimizedGraph.cpp | 406 ++++++++---------- 2 files changed, 189 insertions(+), 220 deletions(-) diff --git a/libnd4j/include/graph/OptimizedGraph.h b/libnd4j/include/graph/OptimizedGraph.h index 69032a102dd1..0c553f967988 100644 --- a/libnd4j/include/graph/OptimizedGraph.h +++ b/libnd4j/include/graph/OptimizedGraph.h @@ -36,6 +36,9 @@ class SD_EXPORT OptimizedGraph { std::vector _sortedGraph; MAP_IMPL _nodesMap; + void sortUsualGraph(const VariableSpace& varSpace); + void sortGraphWithFrames(const VariableSpace& varSpace); + /** * This method removes all empty ExecutionLayers from this graph */ diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 24a1057aabab..eec971c7e072 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -45,32 +45,17 @@ OptimizedGraph& OptimizedGraph::operator=(OptimizedGraph &&other) noexcept { } /////////////////////////////////////////////////////////////////// -OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableSpace& varSpace): _nodesMap(inMap) { - - // for (const auto& p : _nodesMap) { - // printf("%s %i, inputs: ", p.second.name().c_str(), p.first); - // const auto& inputs = p.second.inputs(); - // for (int i = 0; i < inputs.size(); ++i) - // printf("%i, ", inputs[i].first); - // printf("\n"); - // } - // return; +void OptimizedGraph::sortGraphWithFrames(const VariableSpace& varSpace) { struct NodeInfo { - uint _layerNum = 0; - std::vector _opSeq = {}; std::vector _in = {}; - std::vector _out = {}; bool _isActive = false; }; - MAP_IMPL workMap; // key is node id, value is class NodeInfo containing auxiliary information (layer number this node belongs to, input/output nodes, OpSequence that starts from this node) - - bool containsLoop = false; + MAP_IMPL workMap; // key is node id, value is class NodeInfo containing auxiliary information (layer number this node belongs to, input/output nodes) // create workMap, fill vectors containing input and output nodes per each node, and find start nodes - std::vector startNodes; - std::unordered_map> endsOfFrame; // key - id of frame, value is vector containing ids of all exits from this frame + std::unordered_map> nextItersOfFrame; // key - id of frame, value is vector containing ids of all NextIterations of this frame std::hash hasher; for (auto& p : _nodesMap) { @@ -78,281 +63,262 @@ OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableS const auto& nameOfNode = p.second.name(); if(nameOfNode.find("Exit") != std::string::npos) { - containsLoop = true; + const int idOfEnter = _nodesMap[_nodesMap[inputs[0].first].inputs()[0].first].inputs()[0].first; p.second.setFrameId(_nodesMap[idOfEnter].frameId()); _nodesMap[idOfEnter].setExitId(p.first); + } else if (nameOfNode.find("NextIteration") != std::string::npos) { + const std::string frameName = nameOfNode.substr(0, nameOfNode.find_last_of("/")); - endsOfFrame[hasher(frameName)].push_back(p.first); + nextItersOfFrame[hasher(frameName)].push_back(p.first); } for (int i = 0; i < inputs.size(); ++i) { if (_nodesMap.count(inputs[i].first) != 0) { // is op + _nodesMap[inputs[i].first].pickOutput(p.first, inputs[i].second); - // if(_nodesMap[inputs[i].first].name().find("NextIteration") == std::string::npos) { - workMap[inputs[i].first]._out.push_back(p.first); - workMap[p.first]._in.push_back(inputs[i].first); - // } - } - else { // is variable + workMap[p.first]._in.push_back(inputs[i].first); + + } else { // is variable const auto depends = varSpace.getVariable(inputs[i].first).get()->dependencies(); for (int j = 0; j < depends.size(); ++j) { + if(std::find(workMap[p.first]._in.begin(), workMap[p.first]._in.end(), depends[j].first) == workMap[p.first]._in.end()) { + _nodesMap[depends[j].first].pickOutput(p.first, depends[j].second); - // if(_nodesMap[depends[j].first].name().find("NextIteration") == std::string::npos) { - workMap[depends[j].first]._out.push_back(p.first); - workMap[p.first]._in.push_back(depends[j].first); - // } + workMap[p.first]._in.push_back(depends[j].first); } } } } - - if(workMap[p.first]._in.empty()) - startNodes.push_back(p.first); } -// printf("\n\n\n\n\n"); - // for (const auto& p : workMap) { - // printf("id %i, inputs: ", p.first); - // for (const auto& i : p.second._in) - // printf("%i, ", i); - // printf("outputs: "); - // for (const auto& i : p.second._out) - // printf("%i, ", i); - // printf("\n"); - // } - - // for (const auto& p : _nodesMap) { - // printf("id %i, inputs ", p.first); - // for(const auto& i : p.second.inputs()) - // printf("%i, ", i.first); - // printf(" outputs: "); - // for(const auto& i : p.second.outputs()) - // printf("%i, ", i.first); - // printf("\n"); - // } - // printf("\n\n\n\n\n"); - if (containsLoop) { + std::vector seq; + uint numOfActive = 0; + auto it = workMap.begin(); - std::vector seq; - uint numOfActive = 0; - auto it = workMap.begin(); + // perform linear sort + while (numOfActive < workMap.size()) { - while (numOfActive < workMap.size()) { + bool makeActive = true; - bool makeActive = true; + if(!it->second._isActive) { - if(!it->second._isActive) { + const auto& nameOfNode = _nodesMap[it->first].name(); - const auto& nameOfNode = _nodesMap[it->first].name(); + for (const auto& inId : it->second._in) { - for (const auto& inId : it->second._in) { + if (_nodesMap[inId].name().find("NextIteration") != std::string::npos) + continue; - if (_nodesMap[inId].name().find("NextIteration") != std::string::npos) - continue; + if (!workMap[inId]._isActive) { + makeActive = false; + } + else if (nameOfNode.find("Exit") != std::string::npos) { + const std::string frameName = nameOfNode.substr(0, nameOfNode.find_last_of("/")); - if (!workMap[inId]._isActive) { - makeActive = false; - } - else if (nameOfNode.find("Exit") != std::string::npos) { - const std::string frameName = nameOfNode.substr(0, nameOfNode.find_last_of("/")); - - for (const auto& j : endsOfFrame[hasher(frameName)]) { - if(!workMap[j]._isActive){ - makeActive = false; - break; - } + for (const auto& j : nextItersOfFrame[hasher(frameName)]) { + if(!workMap[j]._isActive){ + makeActive = false; + break; } } - if(!makeActive) - break; } + if(!makeActive) + break; } + } - if(makeActive && !it->second._isActive) { - ++numOfActive; - it->second._isActive = true; - seq.push_back(it->first); - } - - if(++it == workMap.end()) - it = workMap.begin(); + if(makeActive && !it->second._isActive) { + ++numOfActive; + it->second._isActive = true; + seq.push_back(it->first); } - OpSequence sequence; - for (auto i : seq) - sequence.append(_nodesMap.at(i), _nodesMap.at(i).contextPrototype()); - _sortedGraph.emplace_back(ExecutionLayer({sequence})); + if(++it == workMap.end()) + it = workMap.begin(); + } + // fill _sortedGraph + OpSequence sequence; + for (auto i : seq) + sequence.append(_nodesMap.at(i), _nodesMap.at(i).contextPrototype()); + _sortedGraph.emplace_back(ExecutionLayer({sequence})); +} +/////////////////////////////////////////////////////////////////// +void OptimizedGraph::sortUsualGraph(const VariableSpace& varSpace) { - // for (const auto& i : seq) - // printf("%i ", i); - // printf("\n\n"); - // for (const auto& i : workMap) - // if(i.second._isActive) - // printf("%i ",i.first); - // printf("\n"); - // for (const auto& i : workMap) - // if(!i.second._isActive) - // printf("%i ",i.first); + struct NodeInfo { + uint _layerNum = 0; + std::vector _opSeq = {}; + std::vector _in = {}; + std::vector _out = {}; + }; - } - else { - // collect OpSequences (fill _opSeq) - std::vector nodesToDelete; - for (auto& p : workMap) { - - auto& out = p.second._out; - while(out.size() == 1 && workMap[out[0]]._in.size() == 1) { - nodesToDelete.push_back(out[0]); - p.second._opSeq.push_back(out[0]); - out = std::move(workMap[out[0]]._out); - } - } + MAP_IMPL workMap; // key is node id, value is class NodeInfo containing auxiliary information (layer number this node belongs to, input/output nodes, OpSequence that starts from this node) + // create workMap, fill vectors containing input and output nodes per each node, and find start nodes + std::vector startNodes; + for (auto& p : _nodesMap) { - // delete nodes present in _opSeq, their ids are already stored in nodesToDelete - for (const auto& i : nodesToDelete) - workMap.erase(i); + const auto& inputs = p.second.inputs(); - // lambda for topological sort - std::function visit = [&visit, &workMap, this] (const int id, const uint layerNum, uint& numOfLayers) { + if(p.second.name().find("Exit") != std::string::npos) { + const int idOfEnter = _nodesMap[_nodesMap[inputs[0].first].inputs()[0].first].inputs()[0].first; + p.second.setFrameId(_nodesMap[idOfEnter].frameId()); + _nodesMap[idOfEnter].setExitId(p.first); + } - if(layerNum <= workMap[id]._layerNum) - return; - workMap[id]._layerNum = layerNum; - if(numOfLayers < layerNum) - numOfLayers = layerNum; + for (int i = 0; i < inputs.size(); ++i) { - const bool isSwitch = this->_nodesMap[id].name().find("Switch") != std::string::npos; - if(!isSwitch) { - for (const auto& nextId : workMap[id]._out) - visit(nextId, layerNum+1, numOfLayers); - } - else { - if(this->_nodesMap[id].outputs()[0].second == 1) { - visit(_nodesMap[id].outputs()[0].first, layerNum+1, numOfLayers); // true branch - visit(_nodesMap[id].outputs()[1].first, numOfLayers+1, numOfLayers); // false branch + if (_nodesMap.count(inputs[i].first) != 0) { // is op + _nodesMap[inputs[i].first].pickOutput(p.first, inputs[i].second); + if(_nodesMap[inputs[i].first].name().find("NextIteration") == std::string::npos) { + workMap[inputs[i].first]._out.push_back(p.first); + workMap[p.first]._in.push_back(inputs[i].first); } - else { - visit(_nodesMap[id].outputs()[1].first, layerNum+1, numOfLayers); // true branch - visit(_nodesMap[id].outputs()[0].first, numOfLayers+1, numOfLayers); // false branch + } + else { // is variable + + const auto depends = varSpace.getVariable(inputs[i].first).get()->dependencies(); + + for (int j = 0; j < depends.size(); ++j) { + if(std::find(workMap[p.first]._in.begin(), workMap[p.first]._in.end(), depends[j].first) == workMap[p.first]._in.end()) { + _nodesMap[depends[j].first].pickOutput(p.first, depends[j].second); + if(_nodesMap[depends[j].first].name().find("NextIteration") == std::string::npos) { + workMap[depends[j].first]._out.push_back(p.first); + workMap[p.first]._in.push_back(depends[j].first); + } + } } } - }; + } - // perform topological sort - uint numOfLayers = 0; - for (const auto& id : startNodes) - for (const auto& nextId : workMap[id]._out) - visit(nextId, 1, numOfLayers); + if(workMap[p.first]._in.empty()) + startNodes.push_back(p.first); + } + // collect OpSequences (fill _opSeq) + std::vector nodesToDelete; + for (auto& p : workMap) { - // fill vectors with layers - std::vector sortedGraphTemp(numOfLayers+1); - for (const auto& p : workMap) { + auto& out = p.second._out; + while(out.size() == 1 && workMap[out[0]]._in.size() == 1) { + nodesToDelete.push_back(out[0]); + p.second._opSeq.push_back(out[0]); + out = std::move(workMap[out[0]]._out); + } + } - OpSequence seq; - seq.append(_nodesMap.at(p.first), _nodesMap.at(p.first).contextPrototype()); + // delete nodes present in _opSeq, their ids are already stored in nodesToDelete + for (const auto& i : nodesToDelete) + workMap.erase(i); - for (const auto& id : p.second._opSeq) - seq.append(_nodesMap.at(id), _nodesMap.at(id).contextPrototype()); + // lambda for topological sort + std::function visit = [&visit, &workMap, this] (const int id, const uint layerNum, uint& numOfLayers) { - // while(sortedGraphTemp.size() <= p.second._layerNum) - // sortedGraphTemp.emplace_back(ExecutionLayer()); + if(layerNum <= workMap[id]._layerNum) + return; + workMap[id]._layerNum = layerNum; + if(numOfLayers < layerNum) + numOfLayers = layerNum; - sortedGraphTemp[p.second._layerNum].append(std::move(seq)); - } + // const bool isSwitch = this->_nodesMap[id].name().find("Switch") != std::string::npos; - const char delimiter = '/'; - for (int i0 = 0; i0 < sortedGraphTemp.size(); ++i0) { - for (int i1 = 0; i1 < sortedGraphTemp[i0].width(); ++i1) { - for (int i2 = 0; i2 < sortedGraphTemp[i0][i1].length(); ++i2) { - // if(!p0->second._opSeq.empty() && p0->second._opSeq[0] == -1) - // continue; - - auto id = sortedGraphTemp[i0][i1][i2].node().id(); - auto* name = &_nodesMap[id].name(); - if (name->find("Enter") == std::string::npos) - continue; - std::string loopName = name->substr(0, name->find(delimiter)); // evaluate name of loop - for (int j0 = i0; j0 < sortedGraphTemp.size(); ++j0) { - for (int j1 = j0 == i0 ? i1 + 1 : 0; j1 < sortedGraphTemp[j0].width(); ++j1) { - for (int j2 = 0; j2 < sortedGraphTemp[j0][j1].length(); ++j2) { - id = sortedGraphTemp[j0][j1][j2].node().id(); - name = &_nodesMap[id].name(); - if (name->find(loopName) == std::string::npos) - continue; - for (int k = 0; k < sortedGraphTemp[j0][j1].length(); ++k) - const_cast(sortedGraphTemp[i0][i1]).append(sortedGraphTemp[j0][j1][k]); - const_cast(sortedGraphTemp[j0][j1]) = OpSequence(); - break; - } - } - } - break; - } - } - } + // if(!isSwitch) { + for (const auto& nextId : workMap[id]._out) + visit(nextId, layerNum+1, numOfLayers); + // } + // else { + // if(this->_nodesMap[id].outputs()[0].second == 1) { + // visit(_nodesMap[id].outputs()[0].first, layerNum+1, numOfLayers); // true branch + // visit(_nodesMap[id].outputs()[1].first, numOfLayers+1, numOfLayers); // false branch + // } + // else { + // visit(_nodesMap[id].outputs()[1].first, layerNum+1, numOfLayers); // true branch + // visit(_nodesMap[id].outputs()[0].first, numOfLayers+1, numOfLayers); // false branch + // } + // } + }; - // delete empty layers - for (auto it = sortedGraphTemp.begin(); it != sortedGraphTemp.end(); ++it) { - bool isEmpty = true; - for (uint i = 0; i < it->width(); ++i) { - if(it->at(i).length() != 0) - isEmpty = false; - break; - } + // perform topological sort + uint numOfLayers = 0; + for (const auto& id : startNodes) + for (const auto& nextId : workMap[id]._out) + visit(nextId, 1, numOfLayers); - if(isEmpty) - sortedGraphTemp.erase(it--); - } - // check whether there are layers with one OpSequence which in turn contains only one op - bool isLayerWithOneOp = false; - for(auto& layer : sortedGraphTemp) { - if(layer.width() == 1 && layer.at(0).length() == 1) { - isLayerWithOneOp = true; - break; - } - } + // fill vectors with layers + std::vector sortedGraphTemp(numOfLayers+1); + for (const auto& p : workMap) { - // fill _sortedGraph - if(!isLayerWithOneOp) { - _sortedGraph = std::move(sortedGraphTemp); - } - else { - for (uint i = 0; i < sortedGraphTemp.size();) { + OpSequence seq; + seq.append(_nodesMap.at(p.first), _nodesMap.at(p.first).contextPrototype()); - OpSequence seq; - while(i < sortedGraphTemp.size() && sortedGraphTemp[i].width() == 1 && sortedGraphTemp[i].at(0).length() == 1) - seq.append(std::move(sortedGraphTemp[i++].at(0).at(0))); + for (const auto& id : p.second._opSeq) + seq.append(_nodesMap.at(id), _nodesMap.at(id).contextPrototype()); - if(seq.length() != 0) - _sortedGraph.emplace_back(ExecutionLayer({seq})); - else - _sortedGraph.emplace_back(std::move(sortedGraphTemp[i++])); - } + // while(sortedGraphTemp.size() <= p.second._layerNum) + // sortedGraphTemp.emplace_back(ExecutionLayer()); + sortedGraphTemp[p.second._layerNum].append(std::move(seq)); + } + + + // check whether there are layers with one OpSequence which in turn contains only one op + bool isLayerWithOneOp = false; + for(auto& layer : sortedGraphTemp) { + if(layer.width() == 1 && layer.at(0).length() == 1) { + isLayerWithOneOp = true; + break; } + } - // sort _sortedGraph - // for (auto& l : _sortedGraph) - // l.sortOpSequences(); + // fill _sortedGraph + if(!isLayerWithOneOp) { + _sortedGraph = std::move(sortedGraphTemp); + } + else { + for (uint i = 0; i < sortedGraphTemp.size();) { + + OpSequence seq; + while(i < sortedGraphTemp.size() && sortedGraphTemp[i].width() == 1 && sortedGraphTemp[i].at(0).length() == 1) + seq.append(std::move(sortedGraphTemp[i++].at(0).at(0))); + + if(seq.length() != 0) + _sortedGraph.emplace_back(ExecutionLayer({seq})); + else + _sortedGraph.emplace_back(std::move(sortedGraphTemp[i++])); + } - // clean up before exiting - purgeEmptyLayers(); } + purgeEmptyLayers(); +} + +/////////////////////////////////////////////////////////////////// +OptimizedGraph::OptimizedGraph(const MAP_IMPL& inMap, const VariableSpace& varSpace): _nodesMap(inMap) { + bool hasEnter = false; + for (const auto& p : _nodesMap) { + if(p.second.name().find("Enter") != std::string::npos) { + hasEnter = true; + break; + } + } + + if(hasEnter) + sortGraphWithFrames(varSpace); + else + sortUsualGraph(varSpace); } +/////////////////////////////////////////////////////////////////// void OptimizedGraph::append(const OpSequence &sequence) { _sortedGraph.emplace_back(ExecutionLayer({sequence})); } From 8b6aa54be837eb207afaa35bb5bb535116ced428 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 25 Jun 2020 18:26:42 +0300 Subject: [PATCH 226/233] 2 more tests for Switch and While Signed-off-by: raver119@gmail.com --- .../layers_tests/GraphAnalysisTests.cpp | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp index d3522207bb30..94b4d0dbc570 100644 --- a/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/GraphAnalysisTests.cpp @@ -1108,4 +1108,26 @@ TEST_F(GraphAnalysisTests, optimizedGraph_nested_while_1) { auto results = graph.execute({{"input_0", input_0}, {"input_1", input_1}}, {"output"}); ASSERT_EQ(NDArrayFactory::create('c', {2, 2}, {13.f, 14.f, 15.f, 16.f}), results["output"]); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_simpleif_0_alt) { + auto input_0 = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto input_1 = NDArrayFactory::create(11.f); + + auto graph = Graph::fromFlatBuffers("resources/simpleif_0_alt.fb"); + + auto results = graph.execute({{"input_0", input_0}, {"input_1", input_1}}, {"output"}); + + ASSERT_EQ(NDArrayFactory::create('c', {2, 2}, {3.f, 4.f, 5.f, 6.f}), results["output"]); +} + +TEST_F(GraphAnalysisTests, optimizedGraph_simplewhile_1) { + auto input_0 = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto input_1 = NDArrayFactory::create(21.f); + + auto graph = Graph::fromFlatBuffers("resources/simplewhile_1.fb"); + + auto results = graph.execute({{"input_0", input_0}, {"input_1", input_1}}, {"output"}); + + ASSERT_EQ(NDArrayFactory::create('c', {2, 2}, {24.f, 25.f, 26.f, 27.f}), results["output"]); } \ No newline at end of file From 0ff755f101ceafb8131dd672adefa0c8f62edabb Mon Sep 17 00:00:00 2001 From: Yurii Date: Fri, 26 Jun 2020 14:40:51 +0300 Subject: [PATCH 227/233] - provide connections between Exit and NextIeration nodes within same frame in order to make toposort alg work correctly Signed-off-by: Yurii --- libnd4j/include/graph/impl/OptimizedGraph.cpp | 78 ++++++++++++------- 1 file changed, 52 insertions(+), 26 deletions(-) diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index eec971c7e072..6fff1dd55248 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -55,10 +55,10 @@ void OptimizedGraph::sortGraphWithFrames(const VariableSpace& varSpace) { MAP_IMPL workMap; // key is node id, value is class NodeInfo containing auxiliary information (layer number this node belongs to, input/output nodes) // create workMap, fill vectors containing input and output nodes per each node, and find start nodes - std::unordered_map> nextItersOfFrame; // key - id of frame, value is vector containing ids of all NextIterations of this frame - std::hash hasher; + std::unordered_map> nextItersOfFrame; // key - id of frame, value is vector containing ids of all NextIterations of this frame for (auto& p : _nodesMap) { + const auto& id = p.first; const auto& inputs = p.second.inputs(); const auto& nameOfNode = p.second.name(); @@ -66,31 +66,32 @@ void OptimizedGraph::sortGraphWithFrames(const VariableSpace& varSpace) { const int idOfEnter = _nodesMap[_nodesMap[inputs[0].first].inputs()[0].first].inputs()[0].first; p.second.setFrameId(_nodesMap[idOfEnter].frameId()); - _nodesMap[idOfEnter].setExitId(p.first); + _nodesMap[idOfEnter].setExitId(id); } else if (nameOfNode.find("NextIteration") != std::string::npos) { const std::string frameName = nameOfNode.substr(0, nameOfNode.find_last_of("/")); - nextItersOfFrame[hasher(frameName)].push_back(p.first); + nextItersOfFrame[frameName].push_back(id); } for (int i = 0; i < inputs.size(); ++i) { - if (_nodesMap.count(inputs[i].first) != 0) { // is op + const auto& inId = inputs[i].first; + if (_nodesMap.count(inId) != 0) { // is op - _nodesMap[inputs[i].first].pickOutput(p.first, inputs[i].second); - workMap[p.first]._in.push_back(inputs[i].first); + _nodesMap[inId].pickOutput(id, inputs[i].second); + workMap[id]._in.push_back(inId); } else { // is variable - const auto depends = varSpace.getVariable(inputs[i].first).get()->dependencies(); + const auto& depends = varSpace.getVariable(inId).get()->dependencies(); for (int j = 0; j < depends.size(); ++j) { - if(std::find(workMap[p.first]._in.begin(), workMap[p.first]._in.end(), depends[j].first) == workMap[p.first]._in.end()) { + if(std::find(workMap[id]._in.begin(), workMap[id]._in.end(), depends[j].first) == workMap[id]._in.end()) { - _nodesMap[depends[j].first].pickOutput(p.first, depends[j].second); - workMap[p.first]._in.push_back(depends[j].first); + _nodesMap[depends[j].first].pickOutput(id, depends[j].second); + workMap[id]._in.push_back(depends[j].first); } } } @@ -121,7 +122,7 @@ void OptimizedGraph::sortGraphWithFrames(const VariableSpace& varSpace) { else if (nameOfNode.find("Exit") != std::string::npos) { const std::string frameName = nameOfNode.substr(0, nameOfNode.find_last_of("/")); - for (const auto& j : nextItersOfFrame[hasher(frameName)]) { + for (const auto& j : nextItersOfFrame[frameName]) { if(!workMap[j]._isActive){ makeActive = false; break; @@ -163,44 +164,69 @@ void OptimizedGraph::sortUsualGraph(const VariableSpace& varSpace) { MAP_IMPL workMap; // key is node id, value is class NodeInfo containing auxiliary information (layer number this node belongs to, input/output nodes, OpSequence that starts from this node) // create workMap, fill vectors containing input and output nodes per each node, and find start nodes - std::vector startNodes; + std::vector startNodes, idsOfExits, idsOfNextIters; for (auto& p : _nodesMap) { + const auto& id = p.first; const auto& inputs = p.second.inputs(); if(p.second.name().find("Exit") != std::string::npos) { + idsOfExits.push_back(id); const int idOfEnter = _nodesMap[_nodesMap[inputs[0].first].inputs()[0].first].inputs()[0].first; p.second.setFrameId(_nodesMap[idOfEnter].frameId()); - _nodesMap[idOfEnter].setExitId(p.first); + _nodesMap[idOfEnter].setExitId(id); + } + else if(p.second.name().find("NextIteration") != std::string::npos) { + idsOfNextIters.push_back(id); } for (int i = 0; i < inputs.size(); ++i) { - if (_nodesMap.count(inputs[i].first) != 0) { // is op - _nodesMap[inputs[i].first].pickOutput(p.first, inputs[i].second); - if(_nodesMap[inputs[i].first].name().find("NextIteration") == std::string::npos) { - workMap[inputs[i].first]._out.push_back(p.first); - workMap[p.first]._in.push_back(inputs[i].first); + const auto& inId = inputs[i].first; + if (_nodesMap.count(inId) != 0) { // is op + + _nodesMap[inId].pickOutput(id, inputs[i].second); + + if(_nodesMap[inId].name().find("NextIteration") == std::string::npos) { + workMap[inId]._out.push_back(id); + workMap[id]._in.push_back(inId); } } else { // is variable - const auto depends = varSpace.getVariable(inputs[i].first).get()->dependencies(); + const auto& depends = varSpace.getVariable(inId).get()->dependencies(); for (int j = 0; j < depends.size(); ++j) { - if(std::find(workMap[p.first]._in.begin(), workMap[p.first]._in.end(), depends[j].first) == workMap[p.first]._in.end()) { - _nodesMap[depends[j].first].pickOutput(p.first, depends[j].second); + if(std::find(workMap[id]._in.begin(), workMap[id]._in.end(), depends[j].first) == workMap[id]._in.end()) { + _nodesMap[depends[j].first].pickOutput(id, depends[j].second); if(_nodesMap[depends[j].first].name().find("NextIteration") == std::string::npos) { - workMap[depends[j].first]._out.push_back(p.first); - workMap[p.first]._in.push_back(depends[j].first); + workMap[depends[j].first]._out.push_back(id); + workMap[id]._in.push_back(depends[j].first); } } } } } - if(workMap[p.first]._in.empty()) - startNodes.push_back(p.first); + if(workMap[id]._in.empty()) + startNodes.push_back(id); + } + + // make all NextIteration within current frame to be inputs for all Exits within the same frame + for (const auto& i0 : idsOfExits) { + + const std::string frameName0 = _nodesMap[i0].name().substr(0, _nodesMap[i0].name().find_last_of("/")); + + for (const auto& i1 : idsOfNextIters) { + + const std::string frameName1 = _nodesMap[i1].name().substr(0, _nodesMap[i1].name().find_last_of("/")); + + if(frameName0 != frameName1) + continue; + + workMap[i0]._in.push_back(i1); + workMap[i1]._out.push_back(i0); + } } // collect OpSequences (fill _opSeq) From f9ce92ccacbb8628ae1a308111b8e061fbee1130 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 29 Jun 2020 09:26:25 +0300 Subject: [PATCH 228/233] stubs for cuda Signed-off-by: raver119@gmail.com --- .../include/array/cuda/ManagedDataBuffer.cu | 27 +++++++++++++++++++ .../memory/{cpu => impl}/ColdZoneManager.cpp | 0 .../{cpu => impl}/GraphMemoryManager.cpp | 0 .../memory/{cpu => impl}/HotZoneManager.cpp | 0 4 files changed, 27 insertions(+) create mode 100644 libnd4j/include/array/cuda/ManagedDataBuffer.cu rename libnd4j/include/memory/{cpu => impl}/ColdZoneManager.cpp (100%) rename libnd4j/include/memory/{cpu => impl}/GraphMemoryManager.cpp (100%) rename libnd4j/include/memory/{cpu => impl}/HotZoneManager.cpp (100%) diff --git a/libnd4j/include/array/cuda/ManagedDataBuffer.cu b/libnd4j/include/array/cuda/ManagedDataBuffer.cu new file mode 100644 index 000000000000..610b2faf9a3b --- /dev/null +++ b/libnd4j/include/array/cuda/ManagedDataBuffer.cu @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace sd { +void *ManagedDataBuffer::primary() { return _descriptor.address(); } + +void *ManagedDataBuffer::special() { return nullptr; } +} // namespace sd \ No newline at end of file diff --git a/libnd4j/include/memory/cpu/ColdZoneManager.cpp b/libnd4j/include/memory/impl/ColdZoneManager.cpp similarity index 100% rename from libnd4j/include/memory/cpu/ColdZoneManager.cpp rename to libnd4j/include/memory/impl/ColdZoneManager.cpp diff --git a/libnd4j/include/memory/cpu/GraphMemoryManager.cpp b/libnd4j/include/memory/impl/GraphMemoryManager.cpp similarity index 100% rename from libnd4j/include/memory/cpu/GraphMemoryManager.cpp rename to libnd4j/include/memory/impl/GraphMemoryManager.cpp diff --git a/libnd4j/include/memory/cpu/HotZoneManager.cpp b/libnd4j/include/memory/impl/HotZoneManager.cpp similarity index 100% rename from libnd4j/include/memory/cpu/HotZoneManager.cpp rename to libnd4j/include/memory/impl/HotZoneManager.cpp From 50aa79f5a5b80531c78fce2539b5ab5c1d4c7f24 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 29 Jun 2020 11:28:21 +0300 Subject: [PATCH 229/233] BERT test updates Signed-off-by: raver119@gmail.com --- .../graph/execution/impl/GraphExecutor.cpp | 6 +- libnd4j/include/graph/impl/Graph.cpp | 13 +++- libnd4j/include/graph/impl/OptimizedGraph.cpp | 4 +- .../layers_tests/PlaygroundTests.cpp | 65 +++++-------------- 4 files changed, 33 insertions(+), 55 deletions(-) diff --git a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp index 99465d80c65e..6ffb29976e95 100644 --- a/libnd4j/include/graph/execution/impl/GraphExecutor.cpp +++ b/libnd4j/include/graph/execution/impl/GraphExecutor.cpp @@ -85,7 +85,7 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, if (v.node().opType() == OpType_LOGIC) { - nd4j_printf("Node <%i:%s> is a logic op; Current frame: [%i]\n", + nd4j_debug("Node <%i:%s> is a logic op; Current frame: [%i]\n", v.node().id(), v.node().name().empty() ? "" : v.node().name().c_str(), f.id()); @@ -98,7 +98,7 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, e = seq.nodeIndex(f.enterId()) - 1; } } else if (v.node().hasCustomOp()) { - nd4j_printf("Node <%i:%s> is a custom op; Current frame: [%i]\n", + nd4j_debug("Node <%i:%s> is a custom op; Current frame: [%i]\n", v.node().id(), v.node().name().empty() ? "" : v.node().name().c_str(), f.id()); @@ -106,7 +106,7 @@ Nd4jStatus GraphExecutor::execute(const OpSequence &seq, // only Ops can be executed this way :( result = execute(v.node().customOp(), v.protoContext(), seq, graph, const_cast(p), targetDevice); } else { - nd4j_printf("Node <%i:%s> has no customOp set; Current frame: [%i]\n", + nd4j_debug("Node <%i:%s> has no customOp set; Current frame: [%i]\n", v.node().id(), v.node().name().empty() ? "" : v.node().name().c_str(), f.id()); diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 0123bc92cb01..c67979bea9b7 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -246,7 +246,18 @@ void Graph::printOutNode(const Node &node) const { } void Graph::printOut() { - // print variables first + // print placeholders + if (_placeholders.size() > 0) { + nd4j_printf("\nPrinting out Placeholders...\n", ""); + for (auto &v:_placeholders) { + auto var = _variableSpace.getVariable(v); + auto shape = ShapeUtils::shapeAsString(var->shape()); + auto dtype = DataTypeUtils::asString(var->dataType()); + nd4j_printf("<%s> <%i> dtype: %s; shape: %s; \n", v.c_str(), var->id(), dtype.c_str(), shape.c_str()); + } + } + + // print variables if (_variableSpace.totalEntries() > 0) { nd4j_printf("\nPrinting out Variables...\n", ""); auto vars = _variableSpace.variables(); diff --git a/libnd4j/include/graph/impl/OptimizedGraph.cpp b/libnd4j/include/graph/impl/OptimizedGraph.cpp index 6fff1dd55248..f01a6d0e8c2c 100644 --- a/libnd4j/include/graph/impl/OptimizedGraph.cpp +++ b/libnd4j/include/graph/impl/OptimizedGraph.cpp @@ -422,7 +422,7 @@ void OptimizedGraph::printOut() const { printf("}\n"); } - +/* printf("And simple print:\n"); for (int i = 0; i < _sortedGraph.size(); ++i) { printf("layer %i: ", i); @@ -435,7 +435,7 @@ void OptimizedGraph::printOut() const { } printf("\n"); } - +*/ } diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 9244dbdc3266..9deb691507a2 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -95,7 +95,7 @@ TEST_F(PlaygroundTests, test_biasAdd_1) { } TEST_F(PlaygroundTests, test_bert_full_1) { -#ifdef _RELEASE +//#ifdef _RELEASE // this test will run ONLY if this model exists if (!FileUtils::fileExists("/home/raver119/Downloads/BertFull/model.fb")) @@ -113,32 +113,22 @@ TEST_F(PlaygroundTests, test_bert_full_1) { auto z = NDArrayFactory::fromNpyFile( "/home/raver119/Downloads/BertFull/out_loss-Softmax.npy"); - // graph->printOut(); - - graph->tagInplaceNodes(); + graph.printOut(); - graph->variableSpace()->putVariable(658, 0, t); - graph->variableSpace()->putVariable(659, 0, u); - graph->variableSpace()->putVariable(660, 0, v); - - /* - // validating graph now - auto status = GraphExecutioner::execute(graph); - ASSERT_EQ(Status::OK(), status); - ASSERT_TRUE(graph->variableSpace()->hasVariable(1620)); - - auto array = graph->variableSpace()->getVariable(1620)->getNDArray(); - ASSERT_EQ(z, *array); +/* + // validating graph now + auto results = graph.execute({{"IteratorGetNext", t}, {"IteratorGetNext:1", u}, {"IteratorGetNext:4", v}}, {"loss/Softmax"}); + ASSERT_EQ(z, results["loss/Softmax"]); +*/ - */ - sd::Environment::getInstance().setProfiling(true); - auto profile = GraphProfilingHelper::profile(graph, 1); +// sd::Environment::getInstance().setProfiling(true); + //auto profile = GraphProfilingHelper::profile(graph, 1); - profile->printOut(); + //profile->printOut(); - sd::Environment::getInstance().setProfiling(false); - delete profile; + //sd::Environment::getInstance().setProfiling(false); + //delete profile; /* std::vector values; @@ -158,9 +148,8 @@ TEST_F(PlaygroundTests, test_bert_full_1) { nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); */ - delete graph; -#endif +//#endif } TEST_F(PlaygroundTests, test_bert_1) { @@ -188,11 +177,10 @@ TEST_F(PlaygroundTests, test_bert_1) { // graph->printOut(); - graph->tagInplaceNodes(); - graph->variableSpace()->putVariable(85, 0, t); - graph->variableSpace()->putVariable(86, 0, u); - graph->variableSpace()->putVariable(87, 0, v); + graph.variableSpace().putVariable(85, 0, t); + graph.variableSpace().putVariable(86, 0, u); + graph.variableSpace().putVariable(87, 0, v); /* // validating graph now @@ -204,13 +192,6 @@ TEST_F(PlaygroundTests, test_bert_1) { ASSERT_EQ(z, *array); */ - sd::Environment::getInstance().setProfiling(true); - auto profile = GraphProfilingHelper::profile(graph, 1); - - profile->printOut(); - - sd::Environment::getInstance().setProfiling(false); - delete profile; /* std::vector values; @@ -230,7 +211,6 @@ TEST_F(PlaygroundTests, test_bert_1) { nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); */ - delete graph; #endif } @@ -246,10 +226,6 @@ TEST_F(PlaygroundTests, test_bert_2) { auto graph = Graph::fromFlatBuffers( "/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb"); - // graph->printOut(); - - graph->tagInplaceNodes(); - /* // validating graph now auto status = GraphExecutioner::execute(graph); @@ -260,14 +236,6 @@ TEST_F(PlaygroundTests, test_bert_2) { ASSERT_EQ(z, *array); */ - sd::Environment::getInstance().setProfiling(true); - auto profile = GraphProfilingHelper::profile(graph, 1); - - profile->printOut(); - - sd::Environment::getInstance().setProfiling(false); - delete profile; - /* std::vector values; @@ -286,7 +254,6 @@ TEST_F(PlaygroundTests, test_bert_2) { nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); */ - delete graph; #endif } From fc3220d388329b701b6538e0f71d9b5e66cb8cfa Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 29 Jun 2020 11:42:54 +0300 Subject: [PATCH 230/233] mmap for CUDA Signed-off-by: raver119@gmail.com --- libnd4j/include/legacy/cuda/NativeOps.cu | 47 +++++++++++++++- .../layers_tests/DeclarableOpsTests10.cpp | 56 +++++++------------ 2 files changed, 66 insertions(+), 37 deletions(-) diff --git a/libnd4j/include/legacy/cuda/NativeOps.cu b/libnd4j/include/legacy/cuda/NativeOps.cu index 32892a67dd79..da7da37d721a 100644 --- a/libnd4j/include/legacy/cuda/NativeOps.cu +++ b/libnd4j/include/legacy/cuda/NativeOps.cu @@ -40,6 +40,16 @@ #include #include +// this section is for MMAP +#ifndef _WIN32 +#include +#include +#include +#else +#include +#include +#endif + using namespace sd; #include @@ -2824,11 +2834,46 @@ void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong *mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) { - return nullptr; + auto hZ = new Nd4jLong[3]; + errno = 0; + try { +#if defined(_WIN32) || defined(_WIN64) + _mmap(hZ, static_cast(length), fileName); +#else + int fd = open(fileName, O_RDWR, 0); // checking for failed fopen + if (fd < 0) { + nd4j_printf("Errno: %i\n", errno); + throw std::runtime_error("Failed to open file for MMAP"); + } + void *ptr = mmap(NULL, length, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + + // check for failed allocation + if (ptr == MAP_FAILED) return nullptr; + + hZ[0] = (Nd4jLong)ptr; + hZ[1] = fd; + +#endif + hZ[2] = length; + + return hZ; + } catch (std::exception &e) { + sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage( + e.what()); + return nullptr; + } } void munmapFile(Nd4jPointer *extraPointers, Nd4jLong *ptrMap, Nd4jLong length) { + munmap((Nd4jPointer)ptrMap[0], length); +#if defined(_WIN32) || defined(_WIN64) + CloseHandle(reinterpret_cast(ptrMap[1])); +#else + close((int)ptrMap[1]); +#endif + delete[] ptrMap; } sd::graph::ResultWrapper *executeFlatGraph(Nd4jPointer *extraPointers, diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 785b48417078..a3211796b5bb 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -2489,16 +2489,11 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_7) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) { - NDArray boxes = NDArrayFactory::create( - 'c', {4, 4}, {0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, 0, 10, 1, 11}); - NDArray scores = - NDArrayFactory::create('c', {4}, {0.9, .75, .6, .95}); // 3 - NDArray max_num = NDArrayFactory::create(3); - NDArray expected = NDArrayFactory::create('c', - { - 1, - }, - {3}); + auto boxes = NDArrayFactory::create('c', {4, 4}, + {0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, 0, 10, 1, 11}); + auto scores = NDArrayFactory::create('c', {4}, {0.9, .75, .6, .95}); // 3 + auto max_num = NDArrayFactory::create(3); + auto expected = NDArrayFactory::create('c',{1}, {3}); sd::ops::non_max_suppression_overlaps op; auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); @@ -2513,16 +2508,11 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_2) { - NDArray boxes = NDArrayFactory::create( + auto boxes = NDArrayFactory::create( 'c', {4, 4}, {0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, 0, 10, 1, 11}); - NDArray scores = - NDArrayFactory::create('c', {4}, {0.9, .95, .6, .75}); // 3 - NDArray max_num = NDArrayFactory::create(3); - NDArray expected = NDArrayFactory::create('c', - { - 3, - }, - {1, 1, 1}); + auto scores = NDArrayFactory::create('c', {4}, {0.9, .95, .6, .75}); // 3 + auto max_num = NDArrayFactory::create(3); + auto expected = NDArrayFactory::create('c', {3}, {1, 1, 1}); sd::ops::non_max_suppression_overlaps op; auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); @@ -2537,16 +2527,11 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_2) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_3) { - NDArray boxes = NDArrayFactory::create( - 'c', {4, 4}, {0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, 0, 10, 1, 11}); - NDArray scores = - NDArrayFactory::create('c', {4}, {0.5, .95, -.6, .75}); // 3 - NDArray max_num = NDArrayFactory::create(5); - NDArray expected = NDArrayFactory::create('c', - { - 5, - }, - {1, 1, 1, 1, 1}); + auto boxes = NDArrayFactory::create('c', {4, 4}, + {0, 0, 1, 1, 0, 0.1, 1, 1.1, 0, -0.1, 1, 0.9, 0, 10, 1, 11}); + auto scores = NDArrayFactory::create('c', {4}, {0.5, .95, -.6, .75}); // 3 + auto max_num = NDArrayFactory::create(5); + auto expected = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); sd::ops::non_max_suppression_overlaps op; auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); @@ -2562,11 +2547,10 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_3) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) { int axis = 0; - NDArray images = - NDArrayFactory::create('c', {1, 2, 2, 1}, {1, 2, 3, 4}); - NDArray boxes = NDArrayFactory::create('c', {1, 4}, {0, 0, 1, 1}); - NDArray boxI = NDArrayFactory::create('c', {1}, {axis}); - NDArray cropSize = NDArrayFactory::create({1, 1}); + auto images = NDArrayFactory::create('c', {1, 2, 2, 1}, {1, 2, 3, 4}); + auto boxes = NDArrayFactory::create('c', {1, 4}, {0, 0, 1, 1}); + auto boxI = NDArrayFactory::create('c', {1}, {axis}); + auto cropSize = NDArrayFactory::create({1, 1}); // NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); NDArray expected = NDArrayFactory::create('c', {1, 1, 1, 1}, {2.5f}); @@ -2679,8 +2663,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) { NDArray images = NDArrayFactory::create('c', {2, 4, 5, 3}); NDArray boxes = NDArrayFactory::create( 'c', {2, 2, 4}, - {0.f, 0.f, 1.f, 1.f, 0.1f, 0.2f, 0.9f, 0.8f, 0.3f, 0.3f, 0.7f, 0.7f, 0.4f, - 0.4f, 0.6f, 0.6f}); + {0.f, 0.f, 1.f, 1.f, 0.1f, 0.2f, 0.9f, 0.8f, 0.3f, 0.3f, 0.7f, 0.7f, 0.4f, 0.4f, 0.6f, 0.6f}); NDArray colors = NDArrayFactory::create( 'c', {2, 3}, {201.f, 202.f, 203.f, 127.f, 128.f, 129.f}); @@ -2701,6 +2684,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) { 91.f, 92.f, 93.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 100.f, 101.f, 102.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 109.f, 110.f, 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f}); + images.linspace(1.); sd::ops::draw_bounding_boxes op; auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); From d4c0abd99c45ee6cb3db0967bea9cf2c93296f32 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 29 Jun 2020 12:05:47 +0300 Subject: [PATCH 231/233] non_max_suppression_overlaps CUDA fix Signed-off-by: raver119@gmail.com --- .../helpers/cuda/image_suppression.cu | 19 ++++++------------- .../layers_tests/DeclarableOpsTests10.cpp | 9 +++------ 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu index 87f6d6fc7982..96825de5968a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu @@ -422,17 +422,11 @@ static Nd4jLong nonMaxSuppressionGeneric_(sd::LaunchContext* context, if (!scores->isActualOnDeviceSide()) scores->syncToDevice(); } - NDArray indices = NDArrayFactory::create( - 'c', {scores->lengthOf()}, - context); // - 1, scales->lengthOf()); //, scales->getContext()); - NDArray startPositions = - NDArrayFactory::create('c', {scores->lengthOf()}, context); - NDArray selectedScores(*scores); + auto indices = NDArrayFactory::create('c', {scores->lengthOf()},context); + auto startPositions = NDArrayFactory::create('c', {scores->lengthOf()}, context); + auto selectedScores = scores->dup(); Nd4jPointer extras[2] = {nullptr, stream}; - auto indexBuf = - indices.dataBuffer() - ->specialAsT< - I>(); /// reinterpret_cast(indices->specialBuffer()); + auto indexBuf = indices.dataBuffer()->template specialAsT(); suppressScores<<<128, 128, 128, *stream>>>( selectedScores.dataBuffer()->specialAsT(), indexBuf, @@ -446,10 +440,9 @@ static Nd4jLong nonMaxSuppressionGeneric_(sd::LaunchContext* context, indices.tickWriteDevice(); selectedScores.tickWriteDevice(); - auto scoresData = selectedScores.dataBuffer() - ->specialAsT(); //, numBoxes, scoresData.begin()); + auto scoresData = selectedScores.dataBuffer()->template specialAsT(); - auto startIndices = startPositions.dataBuffer()->specialAsT(); + auto startIndices = startPositions.dataBuffer()->template specialAsT(); I selectedSize = 0; Nd4jLong res = 0; if (output) { // this part used when output shape already calculated to fill diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index a3211796b5bb..f2e7196a8c22 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -2502,8 +2502,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) { auto result = results.at(0); // result.printBuffer("NonMaxSuppressionOverlap1 Output"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(expected, result); } //////////////////////////////////////////////////////////////////// @@ -2521,8 +2520,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_2) { auto result = results.at(0); // result.printBuffer("NonMaxSuppressionOverlap Output"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(expected, result); } //////////////////////////////////////////////////////////////////// @@ -2540,8 +2538,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_3) { auto result = results.at(0); // result.printBuffer("NonMaxSuppressionOverlap Output"); - ASSERT_TRUE(expected.isSameShapeStrict(result)); - ASSERT_TRUE(expected.equalsTo(result)); + ASSERT_EQ(expected, result); } //////////////////////////////////////////////////////////////////// From bd9f9c2d2168d4044fb4e82a7c3394a642682f8a Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Mon, 29 Jun 2020 12:24:51 +0300 Subject: [PATCH 232/233] sqrtm :/ Signed-off-by: raver119@gmail.com --- libnd4j/include/helpers/EigenValsAndVecs.h | 67 ++++++++++--------- .../ops/declarable/helpers/impl/sqrtm.cpp | 12 ++-- .../layers_tests/DeclarableOpsTests15.cpp | 46 ++++++------- 3 files changed, 65 insertions(+), 60 deletions(-) diff --git a/libnd4j/include/helpers/EigenValsAndVecs.h b/libnd4j/include/helpers/EigenValsAndVecs.h index 222b9c36ed36..f4d52f97f868 100644 --- a/libnd4j/include/helpers/EigenValsAndVecs.h +++ b/libnd4j/include/helpers/EigenValsAndVecs.h @@ -30,51 +30,52 @@ namespace helpers { // this class calculates eigenvalues and eigenvectors of given input matrix template class EigenValsAndVecs { + public: + // suppose we got input square NxN matrix - public: - // suppose we got input square NxN matrix + // {N,2} matrix of eigenvalues, 2 means real and imaginary part + NDArray _Vals; + // {N,N,2} matrix, whose columns are the eigenvectors (complex), 2 means real and imaginary part + NDArray _Vecs; - NDArray _Vals; // {N,2} matrix of eigenvalues, 2 means real and imaginary part - NDArray _Vecs; // {N,N,2} matrix, whose columns are the eigenvectors (complex), 2 means real and imaginary part + explicit EigenValsAndVecs(const NDArray& matrix); - explicit EigenValsAndVecs(const NDArray& matrix); + ////////////////////////////////////////////////////////////////////////// + FORCEINLINE static void divideComplexNums(const T& a1, const T& b1, const T& a2, const T& b2, T& a3, T& b3) { + T norm2 = a2*a2 + b2*b2; + a3 = (a1*a2 + b1*b2) / norm2; + b3 = (a2*b1 - a1*b2) / norm2; + } - ////////////////////////////////////////////////////////////////////////// - FORCEINLINE static void divideComplexNums(const T& a1, const T& b1, const T& a2, const T& b2, T& a3, T& b3) { + ////////////////////////////////////////////////////////////////////////// + FORCEINLINE static void multiplyComplexNums(const T& a1, const T& b1, const T& a2, const T& b2, T& a3, T& b3) { + a3 = (a1*a2 - b1*b2); + b3 = (a1*b2 + b1*a2); + } - T norm2 = a2*a2 + b2*b2; + ////////////////////////////////////////////////////////////////////////// + FORCEINLINE static void sqrtComplexNum(T& a, T& b) { + T norm = math::nd4j_sqrt(a*a + b*b); - a3 = (a1*a2 + b1*b2) / norm2; - b3 = (a2*b1 - a1*b2) / norm2; - } + if(b < (T)0) + b = -math::nd4j_sqrt((T)0.5 * (norm - a)); + else + b = math::nd4j_sqrt((T)0.5 * (norm - a)); - ////////////////////////////////////////////////////////////////////////// - FORCEINLINE static void multiplyComplexNums(const T& a1, const T& b1, const T& a2, const T& b2, T& a3, T& b3) { + a = math::nd4j_sqrt((T)0.5 * (norm + a)); + } - a3 = (a1*a2 - b1*b2); - b3 = (a1*b2 + b1*a2); - } - ////////////////////////////////////////////////////////////////////////// - FORCEINLINE static void sqrtComplexNum(T& a, T& b) { + private: + // calculates _Vals + void calcEigenVals(const NDArray& schurMatrixT); - T norm = math::nd4j_sqrt(a*a + b*b); - - if(b < (T)0) - b = -math::nd4j_sqrt((T)0.5 * (norm - a)); - else - b = math::nd4j_sqrt((T)0.5 * (norm - a)); - a = math::nd4j_sqrt((T)0.5 * (norm + a)); - } - - - private: - - void calcEigenVals(const NDArray& schurMatrixT); // calculates _Vals - void calcPseudoEigenVecs(NDArray& schurMatrixT, NDArray& schurMatrixU); // makes changes both in schurMatrixT(NxN) and schurMatrixU(NxN), also calculates and stores pseudo-eigenvectors (real) in schurMatrixU columns - void calcEigenVecs(const NDArray& schurMatrixU); // calculates _Vecs + // makes changes both in schurMatrixT(NxN) and schurMatrixU(NxN), also calculates and stores pseudo-eigenvectors (real) in schurMatrixU columns + void calcPseudoEigenVecs(NDArray& schurMatrixT, NDArray& schurMatrixU); + // calculates _Vecs + void calcEigenVecs(const NDArray& schurMatrixU); }; diff --git a/libnd4j/include/ops/declarable/helpers/impl/sqrtm.cpp b/libnd4j/include/ops/declarable/helpers/impl/sqrtm.cpp index f3bca80c03f6..55d3f97d3bb7 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/sqrtm.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/sqrtm.cpp @@ -40,6 +40,7 @@ static void sqrtm_(const NDArray* x, NDArray* z) { auto listX = x->allTensorsAlongDimension({-2, -1}); auto listZ = z->allTensorsAlongDimension({-2, -1}); +#ifndef __CUDABLAS__ auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i++) @@ -47,16 +48,19 @@ static void sqrtm_(const NDArray* x, NDArray* z) { }; samediff::Threads::parallel_tad(func, 0, listX.size()); +#else + for (auto i = 0; i < listX.size(); i++) + ops::helpers::Sqrtm::calc(listX.at(i), listZ.at(i)); +#endif } } ////////////////////////////////////////////////////////////////////////// void sqrtm(sd::LaunchContext* context, const NDArray* x, NDArray* z) { - - x->syncToHost(); - BUILD_SINGLE_SELECTOR(z->dataType(), sqrtm_, (x, z), FLOAT_TYPES); - z->syncToDevice(); + x->syncToHost(); + BUILD_SINGLE_SELECTOR(z->dataType(), sqrtm_, (x, z), FLOAT_TYPES); + z->syncToDevice(); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 6ae1284a7766..2f446570309c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -2326,44 +2326,44 @@ TEST_F(DeclarableOpsTests15, gru_1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, sqrtm_1) { NDArray x1('c', {1,1}, {4.}, sd::DataType::DOUBLE); - NDArray x2('c', {2,2}, {1.3,2,0.3,.5}, sd::DataType::DOUBLE); + NDArray x2('c', {2,2}, {1.3,2,0.3,.5}, sd::DataType::DOUBLE); NDArray x3('c', {3,3}, {0.5 ,-0.4 ,1.2 ,-2.8 ,-0.2 ,-2.1 ,-2.4 ,-2.0 ,1.1}, sd::DataType::DOUBLE); NDArray x4('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE); NDArray x5('c', {5,5}, {2.4 ,0.3 ,0.0 ,1.1 ,1.8 ,0.1 ,1.7 ,2.7 ,1.5 ,2.6 ,0.6 ,2.1 ,2.2 ,1.0 ,0.2 ,1.2 ,2.8 ,1.9 ,0.8 ,2.0 ,0.5 ,1.6 ,0.9 ,1.4 ,2.5}, sd::DataType::DOUBLE); - NDArray exp1('c', {1,1}, {2.}, sd::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {1.0163674, 1.3341597,0.200124, 0.4827035}, sd::DataType::DOUBLE); - NDArray exp3('c', {3,3}, {6.5692188, 2.6273616,-0.1387864,-16.8404762,-7.0296495, 0.9204148,-11.4664296,-5.834273 , 2.2087478}, sd::DataType::DOUBLE); - NDArray exp4('c', {4,4}, {1.161387 ,-1.9343154, 0.230372 , 0.8660897,0.80588 , 3.4045446,-1.0152824,-2.0369467,2.2589629, 1.9674252, 1.5109997,-1.4283141,0.0226356, 1.3032279,-1.00396 , 1.8278487}, sd::DataType::DOUBLE); - NDArray exp5('c', {5,5}, {1.4175046,-0.4425298, 0.1846149, 0.3166522, 0.9140631,-0.1929139, 0.2889113, 1.4045273, 0.2600026, 1.552021 , 0.1372758, 0.5703854, 1.3336126, 0.3869317,-0.082492 , + NDArray exp1('c', {1,1}, {2.}, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {1.0163674, 1.3341597,0.200124, 0.4827035}, sd::DataType::DOUBLE); + NDArray exp3('c', {3,3}, {6.5692188, 2.6273616,-0.1387864,-16.8404762,-7.0296495, 0.9204148,-11.4664296,-5.834273 , 2.2087478}, sd::DataType::DOUBLE); + NDArray exp4('c', {4,4}, {1.161387 ,-1.9343154, 0.230372 , 0.8660897,0.80588 , 3.4045446,-1.0152824,-2.0369467,2.2589629, 1.9674252, 1.5109997,-1.4283141,0.0226356, 1.3032279,-1.00396 , 1.8278487}, sd::DataType::DOUBLE); + NDArray exp5('c', {5,5}, {1.4175046,-0.4425298, 0.1846149, 0.3166522, 0.9140631,-0.1929139, 0.2889113, 1.4045273, 0.2600026, 1.552021 , 0.1372758, 0.5703854, 1.3336126, 0.3869317,-0.082492 , 0.8607272, 3.1792474,-0.9499947, 0.8541668,-1.4243879, 0.0081136,-0.0622248, 0.4534325, 0.4641865, 1.8132138}, sd::DataType::DOUBLE); - sd::ops::sqrtm op; + sd::ops::sqrtm op; - auto results = op.evaluate({&x1}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(exp1.isSameShape(results.at(0))); - ASSERT_TRUE(exp1.equalsTo(results.at(0))); + auto results = op.evaluate({&x1}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp1.isSameShape(results.at(0))); + ASSERT_EQ(exp1, results.at(0)); - results = op.evaluate({&x2}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(exp2.isSameShape(results.at(0))); - ASSERT_TRUE(exp2.equalsTo(results.at( 0))); + results = op.evaluate({&x2}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp2.isSameShape(results.at(0))); + ASSERT_EQ(exp2, results.at( 0)); - results = op.evaluate({&x3}, {}, {}); + results = op.evaluate({&x3}); ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(exp3.isSameShape(results.at(0))); - ASSERT_TRUE(exp3.equalsTo(results.at(0))); + ASSERT_TRUE(exp3.isSameShape(results.at(0))); + ASSERT_EQ(exp3, results.at(0)); results = op.evaluate({&x4}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(exp4.isSameShape(results.at(0))); - ASSERT_TRUE(exp4.equalsTo(results.at(0))); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp4.isSameShape(results.at(0))); + ASSERT_EQ(exp4, results.at(0)); results = op.evaluate({&x5}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, results.status()); - ASSERT_TRUE(exp5.isSameShape(results.at( 0))); - ASSERT_TRUE(exp5.equalsTo(results.at(0))); + ASSERT_TRUE(exp5.isSameShape(results.at( 0))); + ASSERT_EQ(exp5, results.at(0)); } ////////////////////////////////////////////////////////////////////// From b4fa5bf5c11eaf3eecf6fad15c14c9d5a15a8168 Mon Sep 17 00:00:00 2001 From: "raver119@gmail.com" Date: Thu, 2 Jul 2020 10:15:36 +0300 Subject: [PATCH 233/233] few updates for Java side Signed-off-by: raver119@gmail.com --- libnd4j/include/graph/Node.h | 2 + libnd4j/include/graph/Variable.h | 2 + libnd4j/include/graph/VariableProxy.h | 8 +- libnd4j/include/graph/VariableSpace.h | 7 +- .../java/org/nd4j/nativeblas/NativeOps.java | 5 - .../java/org/nd4j/nativeblas/Nd4jCpu.java | 32988 +++++++++------- .../org/nd4j/nativeblas/Nd4jCpuPresets.java | 26 +- 7 files changed, 17811 insertions(+), 15227 deletions(-) diff --git a/libnd4j/include/graph/Node.h b/libnd4j/include/graph/Node.h index de04f6194d62..cc0803ab93bc 100644 --- a/libnd4j/include/graph/Node.h +++ b/libnd4j/include/graph/Node.h @@ -177,8 +177,10 @@ class SD_EXPORT Node { bool hasCustomOp() const; +#ifndef __JAVACPP_HACK__ // this method converts string deps to int deps void actualizeDependencies(const MAP_IMPL &lookupTable) const; +#endif int frameId() const; void setFrameId(int frameId); diff --git a/libnd4j/include/graph/Variable.h b/libnd4j/include/graph/Variable.h index 9ce96a6cc768..b948dc993c2c 100644 --- a/libnd4j/include/graph/Variable.h +++ b/libnd4j/include/graph/Variable.h @@ -132,8 +132,10 @@ class SD_EXPORT Variable { const std::vector>& dependencies() const; +#ifndef __JAVACPP_HACK__ // this method converts string deps to int deps void actualizeDependencies(const MAP_IMPL &lookupTable) const; +#endif #ifndef __JAVACPP_HACK__ /** diff --git a/libnd4j/include/graph/VariableProxy.h b/libnd4j/include/graph/VariableProxy.h index 6f17b29ac1ae..554c5296aa2f 100644 --- a/libnd4j/include/graph/VariableProxy.h +++ b/libnd4j/include/graph/VariableProxy.h @@ -38,8 +38,10 @@ class SD_EXPORT VariableProxy : public VariableSpace { virtual VariableSpace& operator=(const VariableSpace& other) override; virtual int numberOfPlaceholders() const override; - virtual const std::vector>& placeholders() - const override; +#ifndef __JAVACPP_HACK__ + virtual const std::vector>& placeholders() const override; + virtual std::vector> variables() const override; +#endif virtual const MAP_IMPL, std::shared_ptr>& externalPaired() const; @@ -60,8 +62,6 @@ class SD_EXPORT VariableProxy : public VariableSpace { virtual std::shared_ptr getVariable( const std::string& symbol) const override; - virtual std::vector> variables() const override; - virtual std::shared_ptr putVariable(const std::pair& pair, const NDArray& array) override; virtual std::shared_ptr putVariable(int id, diff --git a/libnd4j/include/graph/VariableSpace.h b/libnd4j/include/graph/VariableSpace.h index 67e43a6d12f9..3a78337d7a5a 100644 --- a/libnd4j/include/graph/VariableSpace.h +++ b/libnd4j/include/graph/VariableSpace.h @@ -78,7 +78,10 @@ class SD_EXPORT VariableSpace { virtual int numberOfPlaceholders() const; - virtual const std::vector> &placeholders() const; +#ifndef __JAVACPP_HACK__ + virtual const std::vector>& placeholders() const; + virtual std::vector> variables() const; +#endif virtual bool hasExternalVariable(int it) const; virtual bool hasExternalVariable(const std::pair &pair) const; @@ -96,8 +99,6 @@ class SD_EXPORT VariableSpace { virtual std::shared_ptr getVariable( const std::string &symbol) const; - virtual std::vector> variables() const; - virtual std::shared_ptr putVariable(const std::pair &pair, const NDArray &array); virtual std::shared_ptr putVariable(int id, const NDArray &array); virtual std::shared_ptr putVariable(int id, int idx, const std::shared_ptr &array); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index b4bd62096c25..2619f9a933aa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -1078,11 +1078,6 @@ void sortTad(PointerPointer extraPointers, void deleteVariablesSet(OpaqueVariablesSet pointer); - // GraphState creation - Pointer getGraphState(long id); - - void deleteGraphState(Pointer state); - int estimateThreshold(PointerPointer extraPointers, Pointer x, LongPointer xShapeInfo, int N, float threshold); // this method executes op that requires scope to be present: if/while/cond/whatever diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 2926c06b97a3..2799508c9607 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -273,16 +273,179 @@ public IntIntPair put(int firstValue, int secondValue) { // Created by raver119 on 07.05.19. // -// #ifndef DEV_TESTS_MEMORYTYPE_H -// #define DEV_TESTS_MEMORYTYPE_H - /** enum sd::memory::MemoryType */ - public static final int - HOST = 0, - DEVICE = 10; - +// #ifndef SD_MEMORYTYPE_H +// #define SD_MEMORYTYPE_H +/** enum sd::memory::MemoryType */ +public static final int + HOST = 0, + DEVICE = 10; + + // namespace sd + +// #endif // SD_MEMORYTYPE_H + + +// Parsed from memory/MemoryZone.h + +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +// #ifndef SD_MEMORYZONE_H +// #define SD_MEMORYZONE_H +/** enum sd::memory::MemoryZone */ +public static final int + COLD = 0, + WARM = 10, + HOT = 20; + + // namespace sd + +// #endif // SD_MEMORYZONE_H + + +// Parsed from memory/MemoryDescriptor.h + +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +// #ifndef SD_MEMORYDESCRIPTOR_H +// #define SD_MEMORYDESCRIPTOR_H + +// #include +// #include + +// #include +@Namespace("sd::memory") @NoOffset public static class MemoryDescriptor extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public MemoryDescriptor(Pointer p) { super(p); } + + public MemoryDescriptor(Pointer ptr, @Cast("sd::memory::MemoryZone") int zone, @Cast("uint64_t") long bytes) { super((Pointer)null); allocate(ptr, zone, bytes); } + private native void allocate(Pointer ptr, @Cast("sd::memory::MemoryZone") int zone, @Cast("uint64_t") long bytes); + + public MemoryDescriptor(@Const @ByRef MemoryDescriptor other) { super((Pointer)null); allocate(other); } + @NoException private native void allocate(@Const @ByRef MemoryDescriptor other); + + public native @ByRef @Name("operator =") @NoException MemoryDescriptor put(@Const @ByRef MemoryDescriptor other); + + // move constructor + + // move assignment operator + + public native @Name("address") Pointer _address(); + public native @Cast("sd::memory::MemoryZone") int zone(); + public native @Cast("uint64_t") long bytes(); +} + // namespace memory + // namespace sd + +// #endif // SD_MEMORYDESCRIPTOR_H + + +// Parsed from memory/GraphMemoryManager.h +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +// #ifndef SD_GRAPHMEMORYMANAGER_H +// #define SD_GRAPHMEMORYMANAGER_H + +// #include +// #include +// #include +// #include +// #include +// #include +// #include +@Namespace("sd::graph") @NoOffset public static class GraphMemoryManager extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public GraphMemoryManager(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public GraphMemoryManager(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public GraphMemoryManager position(long position) { + return (GraphMemoryManager)super.position(position); + } + + public GraphMemoryManager() { super((Pointer)null); allocate(); } + private native void allocate(); + + /** + * This method does allocation (probably) and returns structure that describes + * it + * @param numBytes - number of bytes to be allocated + * @param zone - memory zone for allocation + * @return + */ + public native @ByVal @Name("allocate") MemoryDescriptor _allocate(@Cast("size_t") long numBytes, @Cast("sd::memory::MemoryZone") int zone); + + /** + * This method releases (probably) memory chunk described by given descriptor + * @param descriptor + */ + public native void release(@ByRef MemoryDescriptor descriptor); + + /** + * This method allows to store reference to certain memory regions and keep them alive as long as Graph is alive + * @param ptr + */ +} + // namespace graph + // namespace sd -// #endif //DEV_TESTS_MEMORYTYPE_H +// #endif // SD_GRAPHMEMORYMANAGER_H // Parsed from array/DataType.h @@ -309,31 +472,31 @@ public IntIntPair put(int firstValue, int secondValue) { // #ifndef ND4J_DATATYPE_H // #define ND4J_DATATYPE_H - /** enum sd::DataType */ - public static final int - INHERIT = 0, - BOOL = 1, - FLOAT8 = 2, - HALF = 3, - HALF2 = 4, - FLOAT32 = 5, - DOUBLE = 6, - INT8 = 7, - INT16 = 8, - INT32 = 9, - INT64 = 10, - UINT8 = 11, - UINT16 = 12, - UINT32 = 13, - UINT64 = 14, - QINT8 = 15, - QINT16 = 16, - BFLOAT16 = 17, - UTF8 = 50, - UTF16 = 51, - UTF32 = 52, - ANY = 100, - AUTO = 200; +/** enum sd::DataType */ +public static final int + INHERIT = 0, + BOOL = 1, + FLOAT8 = 2, + HALF = 3, + HALF2 = 4, + FLOAT32 = 5, + DOUBLE = 6, + INT8 = 7, + INT16 = 8, + INT32 = 9, + INT64 = 10, + UINT8 = 11, + UINT16 = 12, + UINT32 = 13, + UINT64 = 14, + QINT8 = 15, + QINT16 = 16, + BFLOAT16 = 17, + UTF8 = 50, + UTF16 = 51, + UTF32 = 52, + ANY = 100, + AUTO = 200; // #endif @@ -361,16 +524,17 @@ public IntIntPair put(int firstValue, int secondValue) { // @author Yurii Shyrma (iuriish@yahoo.com) // -// #ifndef DEV_TESTS_DATABUFFER_H -// #define DEV_TESTS_DATABUFFER_H +// #ifndef SD_DATABUFFER_H +// #define SD_DATABUFFER_H -// #include -// #include -// #include -// #include // #include -// #include // #include +// #include +// #include +// #include +// #include + +// #include @Namespace("sd") @NoOffset public static class DataBuffer extends Pointer { static { Loader.load(); } @@ -383,112 +547,117 @@ public IntIntPair put(int firstValue, int secondValue) { return (DataBuffer)super.position(position); } + public DataBuffer(Pointer primary, Pointer special, @Cast("const size_t") long lenInBytes, + @Cast("const sd::DataType") int dataType, @Cast("const bool") boolean isOwnerPrimary/*=false*/, + @Cast("const bool") boolean isOwnerSpecial/*=false*/, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType, isOwnerPrimary, isOwnerSpecial, workspace); } + private native void allocate(Pointer primary, Pointer special, @Cast("const size_t") long lenInBytes, + @Cast("const sd::DataType") int dataType, @Cast("const bool") boolean isOwnerPrimary/*=false*/, + @Cast("const bool") boolean isOwnerSpecial/*=false*/, + Workspace workspace/*=nullptr*/); + public DataBuffer(Pointer primary, Pointer special, @Cast("const size_t") long lenInBytes, + @Cast("const sd::DataType") int dataType) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType); } + private native void allocate(Pointer primary, Pointer special, @Cast("const size_t") long lenInBytes, + @Cast("const sd::DataType") int dataType); + + public DataBuffer(Pointer primary, @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, lenInBytes, dataType, isOwnerPrimary, workspace); } + private native void allocate(Pointer primary, @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, + @Cast("const bool") boolean isOwnerPrimary/*=false*/, + Workspace workspace/*=nullptr*/); + public DataBuffer(Pointer primary, @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType) { super((Pointer)null); allocate(primary, lenInBytes, dataType); } + private native void allocate(Pointer primary, @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType); + + public DataBuffer(@Const Pointer hostBuffer, + @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes, + Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes, workspace); } + private native void allocate(@Const Pointer hostBuffer, + @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes, + Workspace workspace/*=nullptr*/); + public DataBuffer(@Const Pointer hostBuffer, + @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes); } + private native void allocate(@Const Pointer hostBuffer, + @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes); + + public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, + Workspace workspace/*=nullptr*/, + @Cast("const bool") boolean allocBoth/*=false*/) { super((Pointer)null); allocate(lenInBytes, dataType, workspace, allocBoth); } + private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, + Workspace workspace/*=nullptr*/, + @Cast("const bool") boolean allocBoth/*=false*/); + public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType) { super((Pointer)null); allocate(lenInBytes, dataType); } + private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType); + + public DataBuffer(@Const @ByRef DataBuffer other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef DataBuffer other); + public DataBuffer() { super((Pointer)null); allocate(); } + private native void allocate(); - public DataBuffer(Pointer primary, Pointer special, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, - @Cast("const bool") boolean isOwnerPrimary/*=false*/, @Cast("const bool") boolean isOwnerSpecial/*=false*/, - Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType, isOwnerPrimary, isOwnerSpecial, workspace); } - private native void allocate(Pointer primary, Pointer special, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, - @Cast("const bool") boolean isOwnerPrimary/*=false*/, @Cast("const bool") boolean isOwnerSpecial/*=false*/, - Workspace workspace/*=nullptr*/); - public DataBuffer(Pointer primary, Pointer special, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType) { super((Pointer)null); allocate(primary, special, lenInBytes, dataType); } - private native void allocate(Pointer primary, Pointer special, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType); - - public DataBuffer(Pointer primary, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, - @Cast("const bool") boolean isOwnerPrimary/*=false*/, - Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(primary, lenInBytes, dataType, isOwnerPrimary, workspace); } - private native void allocate(Pointer primary, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, - @Cast("const bool") boolean isOwnerPrimary/*=false*/, - Workspace workspace/*=nullptr*/); - public DataBuffer(Pointer primary, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType) { super((Pointer)null); allocate(primary, lenInBytes, dataType); } - private native void allocate(Pointer primary, - @Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType); - - public DataBuffer(@Const Pointer hostBuffer, - @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes, - Workspace workspace/*=nullptr*/) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes, workspace); } - private native void allocate(@Const Pointer hostBuffer, - @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes, - Workspace workspace/*=nullptr*/); - public DataBuffer(@Const Pointer hostBuffer, - @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes) { super((Pointer)null); allocate(hostBuffer, dataType, lenInBytes); } - private native void allocate(@Const Pointer hostBuffer, - @Cast("const sd::DataType") int dataType, @Cast("const size_t") long lenInBytes); - - public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, Workspace workspace/*=nullptr*/, @Cast("const bool") boolean allocBoth/*=false*/) { super((Pointer)null); allocate(lenInBytes, dataType, workspace, allocBoth); } - private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType, Workspace workspace/*=nullptr*/, @Cast("const bool") boolean allocBoth/*=false*/); - public DataBuffer(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType) { super((Pointer)null); allocate(lenInBytes, dataType); } - private native void allocate(@Cast("const size_t") long lenInBytes, @Cast("const sd::DataType") int dataType); - - public DataBuffer(@Const @ByRef DataBuffer other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef DataBuffer other); - public DataBuffer() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native @ByRef @Name("operator =") DataBuffer put(@Const @ByRef DataBuffer other); - - public native @Cast("sd::DataType") int getDataType(); - public native void setDataType(@Cast("sd::DataType") int dataType); - public native @Cast("size_t") long getLenInBytes(); - - public native Pointer primary(); - public native Pointer special(); - - public native void allocatePrimary(); - public native void allocateSpecial(); - - public native void writePrimary(); - public native void writeSpecial(); - public native void readPrimary(); - public native void readSpecial(); - public native @Cast("bool") boolean isPrimaryActual(); - public native @Cast("bool") boolean isSpecialActual(); - - public native void expand(@Cast("const uint64_t") long size); - - public native int deviceId(); - public native void setDeviceId(int deviceId); - public native void migrate(); - - public native void syncToPrimary(@Const LaunchContext context, @Cast("const bool") boolean forceSync/*=false*/); - public native void syncToPrimary(@Const LaunchContext context); - public native void syncToSpecial(@Cast("const bool") boolean forceSync/*=false*/); - public native void syncToSpecial(); - - public native void setToZeroBuffers(@Cast("const bool") boolean both/*=false*/); - public native void setToZeroBuffers(); - - public native void copyBufferFrom(@Const @ByRef DataBuffer other, @Cast("size_t") long sizeToCopyinBytes/*=0*/, @Cast("const Nd4jLong") long offsetThis/*=0*/, @Cast("const Nd4jLong") long offsetOther/*=0*/); - public native void copyBufferFrom(@Const @ByRef DataBuffer other); - - public static native void memcpy(@Const @ByRef DataBuffer dst, @Const @ByRef DataBuffer src); - - public native void setPrimaryBuffer(Pointer buffer, @Cast("size_t") long length); - public native void setSpecialBuffer(Pointer buffer, @Cast("size_t") long length); + public native @ByRef @Name("operator =") DataBuffer put(@Const @ByRef DataBuffer other); - /** - * This method deletes buffers, if we're owners - */ - public native @Name("close") void _close(); + public native @Cast("sd::DataType") int getDataType(); + public native void setDataType(@Cast("sd::DataType") int dataType); + public native @Cast("size_t") long getLenInBytes(); + + public native Pointer primary(); + public native Pointer special(); + public native Pointer platform(); + + public native void allocatePrimary(); + public native void allocateSpecial(); + + public native void writePrimary(); + public native void writeSpecial(); + public native void readPrimary(); + public native void readSpecial(); + public native @Cast("bool") boolean isPrimaryActual(); + public native @Cast("bool") boolean isSpecialActual(); + + public native void expand(@Cast("const uint64_t") long size); + + public native int deviceId(); + public native void setDeviceId(int deviceId); + + public native void migrate(); + + public native void syncToPrimary(@Const LaunchContext context, + @Cast("const bool") boolean forceSync/*=false*/); + public native void syncToPrimary(@Const LaunchContext context); + public native void syncToSpecial(@Cast("const bool") boolean forceSync/*=false*/); + public native void syncToSpecial(); + + public native void setToZeroBuffers(@Cast("const bool") boolean both/*=false*/); + public native void setToZeroBuffers(); + + public native void copyBufferFrom(@Const @ByRef DataBuffer other, @Cast("size_t") long sizeToCopyinBytes/*=0*/, + @Cast("const Nd4jLong") long offsetThis/*=0*/, + @Cast("const Nd4jLong") long offsetOther/*=0*/); + public native void copyBufferFrom(@Const @ByRef DataBuffer other); + + public static native void memcpy(@Const @ByRef DataBuffer dst, @Const @ByRef DataBuffer src); + + public native void setPrimaryBuffer(Pointer buffer, @Cast("size_t") long length); + public native void setSpecialBuffer(Pointer buffer, @Cast("size_t") long length); + + /** + * This method deletes buffers, if we're owners + */ + public native @Name("close") void _close(); } ///// IMLEMENTATION OF INLINE METHODS ///// //////////////////////////////////////////////////////////////////////// - + //////////////////////////////////////////////////////////////////////// - +//////////////////////////////////////////////////////////////////////// + // namespace sd -// #endif //DEV_TESTS_DATABUFFER_H +// #endif // SD_DATABUFFER_H // Parsed from array/PointerDeallocator.h @@ -587,33 +756,33 @@ private native void allocate(@Const Pointer hostBuffer, // #include // #include // #include - @Namespace("sd") @NoOffset public static class ConstantDataBuffer extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ConstantDataBuffer(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ConstantDataBuffer(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ConstantDataBuffer position(long position) { - return (ConstantDataBuffer)super.position(position); - } - - public ConstantDataBuffer(@Const @ByRef ConstantDataBuffer other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef ConstantDataBuffer other); - public ConstantDataBuffer() { super((Pointer)null); allocate(); } - private native void allocate(); +@Namespace("sd") @NoOffset public static class ConstantDataBuffer extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ConstantDataBuffer(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ConstantDataBuffer(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ConstantDataBuffer position(long position) { + return (ConstantDataBuffer)super.position(position); + } - public native @Cast("uint8_t") byte sizeOf(); - public native @Cast("uint64_t") long length(); + public ConstantDataBuffer(@Const @ByRef ConstantDataBuffer other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef ConstantDataBuffer other); + public ConstantDataBuffer() { super((Pointer)null); allocate(); } + private native void allocate(); - public native Pointer primary(); - public native Pointer special(); + public native @Cast("uint8_t") byte sizeOf(); + public native @Cast("uint64_t") long length(); - public native @ByRef @Name("operator =") ConstantDataBuffer put(@Const @ByRef ConstantDataBuffer other); - } + public native Pointer primary(); + public native Pointer special(); + public native @ByRef @Name("operator =") ConstantDataBuffer put(@Const @ByRef ConstantDataBuffer other); +} + // namespace sd -// #endif //DEV_TESTS_CONSTANTDATABUFFER_H +// #endif // SD_CONSTANTDATABUFFER_H // Parsed from array/ConstantShapeBuffer.h @@ -746,68 +915,68 @@ private native void allocate(@Const Pointer hostBuffer, // @author raver119@gmail.com // -// #ifndef DEV_TESTS_CONSTANTDESCRIPTOR_H -// #define DEV_TESTS_CONSTANTDESCRIPTOR_H +// #ifndef SD_CONSTANTDESCRIPTOR_H +// #define SD_CONSTANTDESCRIPTOR_H +// #include // #include -// #include -// #include -// #include // #include -// #include - @Namespace("sd") @NoOffset public static class ConstantDescriptor extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ConstantDescriptor(Pointer p) { super(p); } - - public ConstantDescriptor(DoublePointer values, int length) { super((Pointer)null); allocate(values, length); } - private native void allocate(DoublePointer values, int length); - public ConstantDescriptor(DoubleBuffer values, int length) { super((Pointer)null); allocate(values, length); } - private native void allocate(DoubleBuffer values, int length); - public ConstantDescriptor(double[] values, int length) { super((Pointer)null); allocate(values, length); } - private native void allocate(double[] values, int length); - public ConstantDescriptor(@Cast("const Nd4jLong*") LongPointer values, int length) { super((Pointer)null); allocate(values, length); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer values, int length); - public ConstantDescriptor(@Cast("const Nd4jLong*") LongBuffer values, int length) { super((Pointer)null); allocate(values, length); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer values, int length); - public ConstantDescriptor(@Cast("const Nd4jLong*") long[] values, int length) { super((Pointer)null); allocate(values, length); } - private native void allocate(@Cast("const Nd4jLong*") long[] values, int length); - - public ConstantDescriptor(@Cast("Nd4jLong*") @StdVector LongPointer values) { super((Pointer)null); allocate(values); } - private native void allocate(@Cast("Nd4jLong*") @StdVector LongPointer values); - public ConstantDescriptor(@Cast("Nd4jLong*") @StdVector LongBuffer values) { super((Pointer)null); allocate(values); } - private native void allocate(@Cast("Nd4jLong*") @StdVector LongBuffer values); - public ConstantDescriptor(@Cast("Nd4jLong*") @StdVector long[] values) { super((Pointer)null); allocate(values); } - private native void allocate(@Cast("Nd4jLong*") @StdVector long[] values); - public ConstantDescriptor(@StdVector DoublePointer values) { super((Pointer)null); allocate(values); } - private native void allocate(@StdVector DoublePointer values); - public ConstantDescriptor(@StdVector DoubleBuffer values) { super((Pointer)null); allocate(values); } - private native void allocate(@StdVector DoubleBuffer values); - public ConstantDescriptor(@StdVector double[] values) { super((Pointer)null); allocate(values); } - private native void allocate(@StdVector double[] values); - - // equal to operator - public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef ConstantDescriptor other); - - // less than operator - public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef ConstantDescriptor other); - - public native @Cast("bool") boolean isInteger(); - public native @Cast("bool") boolean isFloat(); - - public native @Cast("Nd4jLong") long length(); - - public native @Cast("Nd4jLong*") @StdVector LongPointer integerValues(); - public native @StdVector DoublePointer floatValues(); - } +// #include +// #include +// #include +@Namespace("sd") @NoOffset public static class ConstantDescriptor extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ConstantDescriptor(Pointer p) { super(p); } + + public ConstantDescriptor(DoublePointer values, int length) { super((Pointer)null); allocate(values, length); } + private native void allocate(DoublePointer values, int length); + public ConstantDescriptor(DoubleBuffer values, int length) { super((Pointer)null); allocate(values, length); } + private native void allocate(DoubleBuffer values, int length); + public ConstantDescriptor(double[] values, int length) { super((Pointer)null); allocate(values, length); } + private native void allocate(double[] values, int length); + public ConstantDescriptor(@Cast("const Nd4jLong*") LongPointer values, int length) { super((Pointer)null); allocate(values, length); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer values, int length); + public ConstantDescriptor(@Cast("const Nd4jLong*") LongBuffer values, int length) { super((Pointer)null); allocate(values, length); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer values, int length); + public ConstantDescriptor(@Cast("const Nd4jLong*") long[] values, int length) { super((Pointer)null); allocate(values, length); } + private native void allocate(@Cast("const Nd4jLong*") long[] values, int length); + + public ConstantDescriptor(@Cast("Nd4jLong*") @StdVector LongPointer values) { super((Pointer)null); allocate(values); } + private native void allocate(@Cast("Nd4jLong*") @StdVector LongPointer values); + public ConstantDescriptor(@Cast("Nd4jLong*") @StdVector LongBuffer values) { super((Pointer)null); allocate(values); } + private native void allocate(@Cast("Nd4jLong*") @StdVector LongBuffer values); + public ConstantDescriptor(@Cast("Nd4jLong*") @StdVector long[] values) { super((Pointer)null); allocate(values); } + private native void allocate(@Cast("Nd4jLong*") @StdVector long[] values); + public ConstantDescriptor(@StdVector DoublePointer values) { super((Pointer)null); allocate(values); } + private native void allocate(@StdVector DoublePointer values); + public ConstantDescriptor(@StdVector DoubleBuffer values) { super((Pointer)null); allocate(values); } + private native void allocate(@StdVector DoubleBuffer values); + public ConstantDescriptor(@StdVector double[] values) { super((Pointer)null); allocate(values); } + private native void allocate(@StdVector double[] values); + + // equal to operator + public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef ConstantDescriptor other); + + // less than operator + public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef ConstantDescriptor other); + + public native @Cast("bool") boolean isInteger(); + public native @Cast("bool") boolean isFloat(); + + public native @Cast("Nd4jLong") long length(); + + public native @Cast("Nd4jLong*") @StdVector LongPointer integerValues(); + public native @StdVector DoublePointer floatValues(); +} + // namespace sd // #ifndef __JAVACPP_HACK__ // #endif - -// #endif //DEV_TESTS_CONSTANTDESCRIPTOR_H +// #endif // SD_CONSTANTDESCRIPTOR_H // Parsed from array/TadPack.h @@ -832,47 +1001,49 @@ private native void allocate(@Const Pointer hostBuffer, // @author raver119@gmail.com // -// #ifndef DEV_TESTS_TADPACK_H -// #define DEV_TESTS_TADPACK_H +// #ifndef SD_TADPACK_H +// #define SD_TADPACK_H // #include // #include - @Namespace("sd") @NoOffset public static class TadPack extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public TadPack(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public TadPack(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public TadPack position(long position) { - return (TadPack)super.position(position); - } - - public TadPack(@Const @ByRef ConstantShapeBuffer shapes, @Const @ByRef ConstantOffsetsBuffer offets, @Cast("Nd4jLong") long numTads) { super((Pointer)null); allocate(shapes, offets, numTads); } - private native void allocate(@Const @ByRef ConstantShapeBuffer shapes, @Const @ByRef ConstantOffsetsBuffer offets, @Cast("Nd4jLong") long numTads); - public TadPack() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native @Cast("const Nd4jLong*") LongPointer primaryShapeInfo(); - public native @Cast("const Nd4jLong*") LongPointer primaryOffsets(); +@Namespace("sd") @NoOffset public static class TadPack extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public TadPack(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public TadPack(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public TadPack position(long position) { + return (TadPack)super.position(position); + } - public native @Cast("const Nd4jLong*") LongPointer specialShapeInfo(); - public native @Cast("const Nd4jLong*") LongPointer specialOffsets(); + public TadPack(@Const @ByRef ConstantShapeBuffer shapes, @Const @ByRef ConstantOffsetsBuffer offets, + @Cast("Nd4jLong") long numTads) { super((Pointer)null); allocate(shapes, offets, numTads); } + private native void allocate(@Const @ByRef ConstantShapeBuffer shapes, @Const @ByRef ConstantOffsetsBuffer offets, + @Cast("Nd4jLong") long numTads); + public TadPack() { super((Pointer)null); allocate(); } + private native void allocate(); - public native @Cast("Nd4jLong") long numberOfTads(); - public native int shapeInfoLength(); + public native @Cast("const Nd4jLong*") LongPointer primaryShapeInfo(); + public native @Cast("const Nd4jLong*") LongPointer primaryOffsets(); - /** - * These methods return either primary or special pointers depending on platform binaries were compiled for - * @return - */ - public native @Cast("const Nd4jLong*") LongPointer platformShapeInfo(); - public native @Cast("const Nd4jLong*") LongPointer platformOffsets(); - } + public native @Cast("const Nd4jLong*") LongPointer specialShapeInfo(); + public native @Cast("const Nd4jLong*") LongPointer specialOffsets(); + public native @Cast("Nd4jLong") long numberOfTads(); + public native int shapeInfoLength(); + /** + * These methods return either primary or special pointers depending on + * platform binaries were compiled for + * @return + */ + public native @Cast("const Nd4jLong*") LongPointer platformShapeInfo(); + public native @Cast("const Nd4jLong*") LongPointer platformOffsets(); +} + // namespace sd -// #endif //DEV_TESTS_TADPACK_H +// #endif // SD_TADPACK_H // Parsed from execution/ErrorReference.h @@ -897,36 +1068,36 @@ private native void allocate(@Const Pointer hostBuffer, // @author raver119@gmail.com // -// #ifndef DEV_TESTS_ERRORREFERENCE_H -// #define DEV_TESTS_ERRORREFERENCE_H +// #ifndef SD_ERRORREFERENCE_H +// #define SD_ERRORREFERENCE_H -// #include // #include - @Namespace("sd") @NoOffset public static class ErrorReference extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ErrorReference(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ErrorReference(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ErrorReference position(long position) { - return (ErrorReference)super.position(position); - } - - public ErrorReference() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native int errorCode(); - public native @Cast("char*") String errorMessage(); - public native void setErrorCode(int errorCode); - public native void setErrorMessage(@StdString BytePointer message); - public native void setErrorMessage(@StdString String message); +// #include +@Namespace("sd") @NoOffset public static class ErrorReference extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ErrorReference(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ErrorReference(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ErrorReference position(long position) { + return (ErrorReference)super.position(position); } + public ErrorReference() { super((Pointer)null); allocate(); } + private native void allocate(); + + public native int errorCode(); + public native @Cast("char*") String errorMessage(); + public native void setErrorCode(int errorCode); + public native void setErrorMessage(@StdString BytePointer message); + public native void setErrorMessage(@StdString String message); +} + // namespace sd -// #endif //DEV_TESTS_ERRORREFERENCE_H +// #endif // SD_ERRORREFERENCE_H // Parsed from execution/Engine.h @@ -953,13 +1124,13 @@ private native void allocate(@Const Pointer hostBuffer, // #ifndef SD_ENGINE_H // #define SD_ENGINE_H - /** enum samediff::Engine */ - public static final int - ENGINE_CPU = 0, - ENGINE_CUDA = 1; +/** enum samediff::Engine */ +public static final int + ENGINE_CPU = 0, + ENGINE_CUDA = 1; -// #endif //SD_ENGINE_H +// #endif // SD_ENGINE_H // Parsed from execution/ExecutionMode.h @@ -986,14 +1157,14 @@ private native void allocate(@Const Pointer hostBuffer, // #ifndef SD_EXECUTIONMODE_H // #define SD_EXECUTIONMODE_H - /** enum samediff::ExecutionMode */ - public static final int - MODE_UNDEFINED = 0, - MODE_TRAINING = 1, - MODE_INFERENCE = 2; +/** enum samediff::ExecutionMode */ +public static final int + MODE_UNDEFINED = 0, + MODE_TRAINING = 1, + MODE_INFERENCE = 2; -// #endif //SD_EXECUTIONMODE_H +// #endif // SD_EXECUTIONMODE_H // Parsed from system/Environment.h @@ -1021,100 +1192,101 @@ private native void allocate(@Const Pointer hostBuffer, // #ifndef LIBND4J_ENVIRONMENT_H // #define LIBND4J_ENVIRONMENT_H -// #include -// #include -// #include -// #include // #include -// #include +// #include // #include - @Namespace("sd") @NoOffset public static class Environment extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Environment(Pointer p) { super(p); } - - /** - * These 3 fields are mostly for CUDA/cuBLAS version tracking - */ - public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter); - public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter); - public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter); - - public static native @ByRef Environment getInstance(); - - public native @Cast("bool") boolean isVerbose(); - public native void setVerbose(@Cast("bool") boolean reallyVerbose); - public native @Cast("bool") boolean isDebug(); - public native @Cast("bool") boolean isProfiling(); - public native @Cast("bool") boolean isDetectingLeaks(); - public native @Cast("bool") boolean isDebugAndVerbose(); - public native void setDebug(@Cast("bool") boolean reallyDebug); - public native void setProfiling(@Cast("bool") boolean reallyProfile); - public native void setLeaksDetector(@Cast("bool") boolean reallyDetect); - public native @Cast("bool") boolean helpersAllowed(); - public native void allowHelpers(@Cast("bool") boolean reallyAllow); - - public native @Cast("bool") boolean blasFallback(); - - public native int tadThreshold(); - public native void setTadThreshold(int threshold); +// #include - public native int elementwiseThreshold(); - public native void setElementwiseThreshold(int threshold); +// #include +// #include +// #include +@Namespace("sd") @NoOffset public static class Environment extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Environment(Pointer p) { super(p); } - public native int maxThreads(); - public native void setMaxThreads(int max); - public native int maxMasterThreads(); - public native void setMaxMasterThreads(int max); + /** + * These 3 fields are mostly for CUDA/cuBLAS version tracking + */ + public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter); + public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter); + public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter); - /* - * Legacy memory limits API, still used in new API as simplified version - */ - public native void setMaxPrimaryMemory(@Cast("uint64_t") long maxBytes); - public native void setMaxSpecialyMemory(@Cast("uint64_t") long maxBytes); - public native void setMaxDeviceMemory(@Cast("uint64_t") long maxBytes); + public static native @ByRef Environment getInstance(); - public native @Cast("uint64_t") long maxPrimaryMemory(); - public native @Cast("uint64_t") long maxSpecialMemory(); - //////////////////////// + public native @Cast("bool") boolean isVerbose(); + public native void setVerbose(@Cast("bool") boolean reallyVerbose); + public native @Cast("bool") boolean isDebug(); + public native @Cast("bool") boolean isProfiling(); + public native @Cast("bool") boolean isDetectingLeaks(); + public native @Cast("bool") boolean isDebugAndVerbose(); + public native void setDebug(@Cast("bool") boolean reallyDebug); + public native void setProfiling(@Cast("bool") boolean reallyProfile); + public native void setLeaksDetector(@Cast("bool") boolean reallyDetect); + public native @Cast("bool") boolean helpersAllowed(); + public native void allowHelpers(@Cast("bool") boolean reallyAllow); - /* - * Methods for memory limits/counters - */ - public native void setGroupLimit(int group, @Cast("Nd4jLong") long numBytes); - public native void setDeviceLimit(int deviceId, @Cast("Nd4jLong") long numBytes); + public native @Cast("bool") boolean blasFallback(); - public native @Cast("Nd4jLong") long getGroupLimit(int group); - public native @Cast("Nd4jLong") long getDeviceLimit(int deviceId); + public native int tadThreshold(); + public native void setTadThreshold(int threshold); - public native @Cast("Nd4jLong") long getGroupCounter(int group); - public native @Cast("Nd4jLong") long getDeviceCounter(int deviceId); - //////////////////////// + public native int elementwiseThreshold(); + public native void setElementwiseThreshold(int threshold); - public native @Cast("bool") boolean isUseMKLDNN(); - public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN); + public native int maxThreads(); + public native void setMaxThreads(int max); - public native @Cast("sd::DataType") int defaultFloatDataType(); - public native void setDefaultFloatDataType(@Cast("sd::DataType") int dtype); + public native int maxMasterThreads(); + public native void setMaxMasterThreads(int max); - public native @Cast("bool") boolean precisionBoostAllowed(); - public native void allowPrecisionBoost(@Cast("bool") boolean reallyAllow); + /* + * Legacy memory limits API, still used in new API as simplified version + */ + public native void setMaxPrimaryMemory(@Cast("uint64_t") long maxBytes); + public native void setMaxSpecialyMemory(@Cast("uint64_t") long maxBytes); + public native void setMaxDeviceMemory(@Cast("uint64_t") long maxBytes); + + public native @Cast("uint64_t") long maxPrimaryMemory(); + public native @Cast("uint64_t") long maxSpecialMemory(); + //////////////////////// - public native @Cast("bool") boolean isExperimentalBuild(); + /* + * Methods for memory limits/counters + */ + public native void setGroupLimit(int group, @Cast("Nd4jLong") long numBytes); + public native void setDeviceLimit(int deviceId, @Cast("Nd4jLong") long numBytes); - public native @Cast("bool") boolean isCPU(); + public native @Cast("Nd4jLong") long getGroupLimit(int group); + public native @Cast("Nd4jLong") long getDeviceLimit(int deviceId); - public native int blasMajorVersion(); - public native int blasMinorVersion(); - public native int blasPatchVersion(); + public native @Cast("Nd4jLong") long getGroupCounter(int group); + public native @Cast("Nd4jLong") long getDeviceCounter(int deviceId); + //////////////////////// - public native @StdVector Pair capabilities(); - } + public native @Cast("bool") boolean isUseMKLDNN(); + public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN); + + public native @Cast("sd::DataType") int defaultFloatDataType(); + public native void setDefaultFloatDataType(@Cast("sd::DataType") int dtype); + + public native @Cast("bool") boolean precisionBoostAllowed(); + public native void allowPrecisionBoost(@Cast("bool") boolean reallyAllow); + + public native @Cast("bool") boolean isExperimentalBuild(); + + public native @Cast("bool") boolean isCPU(); + public native int blasMajorVersion(); + public native int blasMinorVersion(); + public native int blasPatchVersion(); + public native @StdVector Pair capabilities(); +} + // namespace sd -// #endif //LIBND4J_ENVIRONMENT_H +// #endif // LIBND4J_ENVIRONMENT_H // Parsed from types/utf8string.h @@ -1139,44 +1311,44 @@ private native void allocate(@Const Pointer hostBuffer, // @author raver119@gmail.com // -// #ifndef DEV_TESTS_UTF8STRING_H -// #define DEV_TESTS_UTF8STRING_H +// #ifndef SD_UTF8STRING_H +// #define SD_UTF8STRING_H -// #include // #include - @Namespace("sd") @NoOffset public static class utf8string extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public utf8string(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public utf8string(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public utf8string position(long position) { - return (utf8string)super.position(position); - } - - public native @Cast("char*") BytePointer _buffer(); public native utf8string _buffer(BytePointer setter); - public native @Cast("unsigned int") int _length(); public native utf8string _length(int setter); - public utf8string() { super((Pointer)null); allocate(); } - private native void allocate(); - - public utf8string(@Cast("char*") String string, int length) { super((Pointer)null); allocate(string, length); } - private native void allocate(@Cast("char*") String string, int length); - public utf8string(@Cast("char*") BytePointer string, int length) { super((Pointer)null); allocate(string, length); } - private native void allocate(@Cast("char*") BytePointer string, int length); - public utf8string(@StdString BytePointer string) { super((Pointer)null); allocate(string); } - private native void allocate(@StdString BytePointer string); - public utf8string(@StdString String string) { super((Pointer)null); allocate(string); } - private native void allocate(@StdString String string); - public utf8string(@Const @ByRef utf8string other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef utf8string other); - public native @ByRef @Name("operator =") utf8string put(@Const @ByRef utf8string other); +// #include +@Namespace("sd") @NoOffset public static class utf8string extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public utf8string(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public utf8string(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public utf8string position(long position) { + return (utf8string)super.position(position); } + public native @Cast("char*") BytePointer _buffer(); public native utf8string _buffer(BytePointer setter); + public native @Cast("unsigned int") int _length(); public native utf8string _length(int setter); + + public utf8string() { super((Pointer)null); allocate(); } + private native void allocate(); + public utf8string(@Cast("char*") String string, int length) { super((Pointer)null); allocate(string, length); } + private native void allocate(@Cast("char*") String string, int length); + public utf8string(@Cast("char*") BytePointer string, int length) { super((Pointer)null); allocate(string, length); } + private native void allocate(@Cast("char*") BytePointer string, int length); + public utf8string(@StdString BytePointer string) { super((Pointer)null); allocate(string); } + private native void allocate(@StdString BytePointer string); + public utf8string(@StdString String string) { super((Pointer)null); allocate(string); } + private native void allocate(@StdString String string); + public utf8string(@Const @ByRef utf8string other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef utf8string other); + public native @ByRef @Name("operator =") utf8string put(@Const @ByRef utf8string other); +} + // namespace sd -// #endif //DEV_TESTS_UTF8STRING_H +// #endif // SD_UTF8STRING_H // Parsed from legacy/NativeOps.h @@ -1214,7 +1386,7 @@ private native void allocate(@Const Pointer hostBuffer, defined __DMC__ || \ defined __BORLANDC__ ) # define thread_local __declspec(thread) -// note that ICC (linux) and Clang are covered by __GNUC__ +// note that ICC (linux) and Clang are covered by __GNUC__ # elif defined __GNUC__ || \ defined __SUNPRO_C || \ defined __xlC__ @@ -1225,17 +1397,17 @@ private native void allocate(@Const Pointer hostBuffer, #endif */ +// #include // #include // #include -// #include -//DO NOT REMOVE: THIS IS AN EDITOR SEMANTICS THING FOR CLION -//IT DEFINES THE EXPORT MACRO FOR THE EDITOR AND THEN -//RE ADDS THE DEFINITION VIA dll.h -// #ifdef _WIN32 -// #define ND4J_EXPORT __declspec(dllexport) +// DO NOT REMOVE: THIS IS AN EDITOR SEMANTICS THING FOR CLION +// IT DEFINES THE EXPORT MACRO FOR THE EDITOR AND THEN +// RE ADDS THE DEFINITION VIA dll.h +// #ifdef _WIN32 +// #define SD_EXPORT __declspec(dllexport) // #else -// #define ND4J_EXPORT +// #define SD_EXPORT // #endif // #include @@ -1247,17 +1419,16 @@ private native void allocate(@Const Pointer hostBuffer, bool verbose = false; */ -// #include -// #include -// #include // #include +// #include // #include -// #include +// #include // #include -// #include -// #include -// #include // #include +// #include +// #include +// #include +// #include // #include // #include @@ -1293,27 +1464,33 @@ private native void allocate(@Const Pointer hostBuffer, public native void setTADThreshold(int num); /** - * - * @param opNum - * @param x - * @param xShapeInfo - * @param extraParams - */ -public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); -public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); -public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); + * + * @param opNum + * @param x + * @param xShapeInfo + * @param extraParams + */ +public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo); +public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); +public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo); /** * @@ -1326,24 +1503,24 @@ public native void execIndexReduceScalar(@Cast("Nd4jPointer*") PointerPointer ex * @param dimension * @param dimensionLength */ -public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); -public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); -public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); +public native void execIndexReduce( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); +public native void execIndexReduce( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); +public native void execIndexReduce( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); /** * @@ -1357,53 +1534,61 @@ public native void execIndexReduce(@Cast("Nd4jPointer*") PointerPointer extraPoi * @param dimension * @param dimensionLength */ -public native void execBroadcast( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); -public native void execBroadcast( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); -public native void execBroadcast( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); - +public native void execBroadcast(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, OpaqueDataBuffer dbY, + @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, + @Cast("const Nd4jLong*") LongPointer dDimensionShape); +public native void execBroadcast(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, OpaqueDataBuffer dbY, + @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, + @Cast("const Nd4jLong*") LongBuffer dDimensionShape); +public native void execBroadcast(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, OpaqueDataBuffer dbY, + @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, + @Cast("const Nd4jLong*") long[] dDimensionShape); public native void execBroadcastBool( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, + @Cast("const Nd4jLong*") LongPointer dDimensionShape); public native void execBroadcastBool( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, + @Cast("const Nd4jLong*") LongBuffer dDimensionShape); public native void execBroadcastBool( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, + @Cast("const Nd4jLong*") long[] dDimensionShape); /** * @@ -1418,48 +1603,48 @@ public native void execBroadcastBool( * @param n */ public native void execPairwiseTransform( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, + Pointer extraParams); public native void execPairwiseTransform( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, + Pointer extraParams); public native void execPairwiseTransform( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, + Pointer extraParams); public native void execPairwiseTransformBool( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, + Pointer extraParams); public native void execPairwiseTransformBool( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, + Pointer extraParams); public native void execPairwiseTransformBool( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams); + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, + Pointer extraParams); /** * @@ -1470,70 +1655,93 @@ public native void execPairwiseTransformBool( * @param result * @param resultShapeInfo */ -public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); -public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); -public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); - -public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); -public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); -public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); - -public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); -public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); -public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); - - -public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); -public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); -public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); +public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo); +public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); +public native void execReduceFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo); + +public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo); +public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); +public native void execReduceSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo); + +public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo); +public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); +public native void execReduceBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo); + +public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo); +public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); +public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo); /** * @@ -1544,84 +1752,81 @@ public native void execReduceLong(@Cast("Nd4jPointer*") PointerPointer extraPoin * @param result * @param resultShapeInfo */ -public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); -public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); -public native void execReduceFloat2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); - - -public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); -public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); -public native void execReduceSame2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); - - -public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); -public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); -public native void execReduceBool2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); - - -public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); -public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); -public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); +public native void execReduceFloat2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); +public native void execReduceFloat2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); +public native void execReduceFloat2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); + +public native void execReduceSame2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); +public native void execReduceSame2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); +public native void execReduceSame2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); + +public native void execReduceBool2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); +public native void execReduceBool2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); +public native void execReduceBool2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); + +public native void execReduceLong2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape); +public native void execReduceLong2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape); +public native void execReduceLong2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape); /** * @@ -1634,24 +1839,27 @@ public native void execReduceLong2(@Cast("Nd4jPointer*") PointerPointer extraPoi * @param result * @param resultShapeInfo */ -public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); -public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); -public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); +public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParamsVals, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo); +public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParamsVals, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); +public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParamsVals, + OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo); /** * @@ -1662,24 +1870,24 @@ public native void execReduce3(@Cast("Nd4jPointer*") PointerPointer extraPointer * @param y * @param yShapeInfo */ -public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); -public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); -public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); +public native void execReduce3Scalar( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo); +public native void execReduce3Scalar( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo); +public native void execReduce3Scalar( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo); /** * * @param opNum @@ -1693,62 +1901,67 @@ public native void execReduce3Scalar(@Cast("Nd4jPointer*") PointerPointer extraP * @param dimension * @param dimensionLength */ -public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, - @Cast("const Nd4jLong*") LongPointer tadOnlyShapeInfo, @Cast("const Nd4jLong*") LongPointer tadOffsets, - @Cast("const Nd4jLong*") LongPointer yTadOnlyShapeInfo, @Cast("const Nd4jLong*") LongPointer yTadOffsets); -public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, - @Cast("const Nd4jLong*") LongBuffer tadOnlyShapeInfo, @Cast("const Nd4jLong*") LongBuffer tadOffsets, - @Cast("const Nd4jLong*") LongBuffer yTadOnlyShapeInfo, @Cast("const Nd4jLong*") LongBuffer yTadOffsets); -public native void execReduce3Tad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, - @Cast("const Nd4jLong*") long[] tadOnlyShapeInfo, @Cast("const Nd4jLong*") long[] tadOffsets, - @Cast("const Nd4jLong*") long[] yTadOnlyShapeInfo, @Cast("const Nd4jLong*") long[] yTadOffsets); - - -public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, - @Cast("const Nd4jLong*") LongPointer xTadShapeInfo, @Cast("const Nd4jLong*") LongPointer xOffsets, - @Cast("const Nd4jLong*") LongPointer yTadShapeInfo, @Cast("const Nd4jLong*") LongPointer yOffsets); -public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, - @Cast("const Nd4jLong*") LongBuffer xTadShapeInfo, @Cast("const Nd4jLong*") LongBuffer xOffsets, - @Cast("const Nd4jLong*") LongBuffer yTadShapeInfo, @Cast("const Nd4jLong*") LongBuffer yOffsets); -public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParamsVals, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] dYShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, - @Cast("const Nd4jLong*") long[] xTadShapeInfo, @Cast("const Nd4jLong*") long[] xOffsets, - @Cast("const Nd4jLong*") long[] yTadShapeInfo, @Cast("const Nd4jLong*") long[] yOffsets); +public native void execReduce3Tad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, + @Cast("const Nd4jLong*") LongPointer dDimensionShape, @Cast("const Nd4jLong*") LongPointer tadOnlyShapeInfo, + @Cast("const Nd4jLong*") LongPointer tadOffsets, @Cast("const Nd4jLong*") LongPointer yTadOnlyShapeInfo, + @Cast("const Nd4jLong*") LongPointer yTadOffsets); +public native void execReduce3Tad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, + @Cast("const Nd4jLong*") LongBuffer dDimensionShape, @Cast("const Nd4jLong*") LongBuffer tadOnlyShapeInfo, + @Cast("const Nd4jLong*") LongBuffer tadOffsets, @Cast("const Nd4jLong*") LongBuffer yTadOnlyShapeInfo, + @Cast("const Nd4jLong*") LongBuffer yTadOffsets); +public native void execReduce3Tad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, + @Cast("const Nd4jLong*") long[] dDimensionShape, @Cast("const Nd4jLong*") long[] tadOnlyShapeInfo, + @Cast("const Nd4jLong*") long[] tadOffsets, @Cast("const Nd4jLong*") long[] yTadOnlyShapeInfo, + @Cast("const Nd4jLong*") long[] yTadOffsets); + +public native void execReduce3All( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, + @Cast("const Nd4jLong*") LongPointer dDimensionShape, @Cast("const Nd4jLong*") LongPointer xTadShapeInfo, + @Cast("const Nd4jLong*") LongPointer xOffsets, @Cast("const Nd4jLong*") LongPointer yTadShapeInfo, + @Cast("const Nd4jLong*") LongPointer yOffsets); +public native void execReduce3All( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, + @Cast("const Nd4jLong*") LongBuffer dDimensionShape, @Cast("const Nd4jLong*") LongBuffer xTadShapeInfo, + @Cast("const Nd4jLong*") LongBuffer xOffsets, @Cast("const Nd4jLong*") LongBuffer yTadShapeInfo, + @Cast("const Nd4jLong*") LongBuffer yOffsets); +public native void execReduce3All( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + Pointer extraParamsVals, OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeInfo, + @Cast("const Nd4jLong*") long[] dYShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, + @Cast("const Nd4jLong*") long[] dDimensionShape, @Cast("const Nd4jLong*") long[] xTadShapeInfo, + @Cast("const Nd4jLong*") long[] xOffsets, @Cast("const Nd4jLong*") long[] yTadShapeInfo, + @Cast("const Nd4jLong*") long[] yOffsets); /** * @@ -1761,43 +1974,52 @@ public native void execReduce3All(@Cast("Nd4jPointer*") PointerPointer extraPoin * @param extraParams * @param n */ -public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbScalar, @Cast("const Nd4jLong*") LongPointer hSscalarShapeInfo, @Cast("const Nd4jLong*") LongPointer dSscalarShapeInfo, - Pointer extraParams); -public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbScalar, @Cast("const Nd4jLong*") LongBuffer hSscalarShapeInfo, @Cast("const Nd4jLong*") LongBuffer dSscalarShapeInfo, - Pointer extraParams); -public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbScalar, @Cast("const Nd4jLong*") long[] hSscalarShapeInfo, @Cast("const Nd4jLong*") long[] dSscalarShapeInfo, - Pointer extraParams); - -public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbScalar, @Cast("const Nd4jLong*") LongPointer hSscalarShapeInfo, @Cast("const Nd4jLong*") LongPointer dSscalarShapeInfo, - Pointer extraParams); -public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbScalar, @Cast("const Nd4jLong*") LongBuffer hSscalarShapeInfo, @Cast("const Nd4jLong*") LongBuffer dSscalarShapeInfo, - Pointer extraParams); -public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbScalar, @Cast("const Nd4jLong*") long[] hSscalarShapeInfo, @Cast("const Nd4jLong*") long[] dSscalarShapeInfo, - Pointer extraParams); +public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, + OpaqueDataBuffer dbScalar, + @Cast("const Nd4jLong*") LongPointer hSscalarShapeInfo, + @Cast("const Nd4jLong*") LongPointer dSscalarShapeInfo, Pointer extraParams); +public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, + OpaqueDataBuffer dbScalar, + @Cast("const Nd4jLong*") LongBuffer hSscalarShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dSscalarShapeInfo, Pointer extraParams); +public native void execScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, + OpaqueDataBuffer dbScalar, + @Cast("const Nd4jLong*") long[] hSscalarShapeInfo, + @Cast("const Nd4jLong*") long[] dSscalarShapeInfo, Pointer extraParams); + +public native void execScalarBool( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbScalar, + @Cast("const Nd4jLong*") LongPointer hSscalarShapeInfo, @Cast("const Nd4jLong*") LongPointer dSscalarShapeInfo, + Pointer extraParams); +public native void execScalarBool( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbScalar, + @Cast("const Nd4jLong*") LongBuffer hSscalarShapeInfo, @Cast("const Nd4jLong*") LongBuffer dSscalarShapeInfo, + Pointer extraParams); +public native void execScalarBool( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbScalar, + @Cast("const Nd4jLong*") long[] hSscalarShapeInfo, @Cast("const Nd4jLong*") long[] dSscalarShapeInfo, + Pointer extraParams); /** * @@ -1806,24 +2028,21 @@ public native void execScalarBool(@Cast("Nd4jPointer*") PointerPointer extraPoin * @param xShapeInfo * @param extraParams */ -public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - @Cast("bool") boolean biasCorrected); -public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - @Cast("bool") boolean biasCorrected); -public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - @Cast("bool") boolean biasCorrected); +public native void execSummaryStatsScalar( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, @Cast("bool") boolean biasCorrected); +public native void execSummaryStatsScalar( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, @Cast("bool") boolean biasCorrected); +public native void execSummaryStatsScalar( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, @Cast("bool") boolean biasCorrected); /** * * @param opNum @@ -1833,24 +2052,21 @@ public native void execSummaryStatsScalar(@Cast("Nd4jPointer*") PointerPointer e * @param result * @param resultShapeInfo */ -public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - @Cast("bool") boolean biasCorrected); -public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - @Cast("bool") boolean biasCorrected); -public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - @Cast("bool") boolean biasCorrected); +public native void execSummaryStats( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, @Cast("bool") boolean biasCorrected); +public native void execSummaryStats( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, @Cast("bool") boolean biasCorrected); +public native void execSummaryStats( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, @Cast("bool") boolean biasCorrected); /** * * @param opNum @@ -1862,30 +2078,30 @@ public native void execSummaryStats(@Cast("Nd4jPointer*") PointerPointer extraPo * @param dimension * @param dimensionLength */ -public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, - @Cast("bool") boolean biasCorrected, - @Cast("const Nd4jLong*") LongPointer tadShapeInfo, @Cast("const Nd4jLong*") LongPointer tadOffsets); -public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, - @Cast("bool") boolean biasCorrected, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, @Cast("const Nd4jLong*") LongBuffer tadOffsets); -public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, - @Cast("bool") boolean biasCorrected, - @Cast("const Nd4jLong*") long[] tadShapeInfo, @Cast("const Nd4jLong*") long[] tadOffsets); +public native void execSummaryStatsTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, + @Cast("bool") boolean biasCorrected, @Cast("const Nd4jLong*") LongPointer tadShapeInfo, + @Cast("const Nd4jLong*") LongPointer tadOffsets); +public native void execSummaryStatsTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, + @Cast("bool") boolean biasCorrected, @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, + @Cast("const Nd4jLong*") LongBuffer tadOffsets); +public native void execSummaryStatsTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, Pointer extraParams, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, + @Cast("bool") boolean biasCorrected, @Cast("const Nd4jLong*") long[] tadShapeInfo, + @Cast("const Nd4jLong*") long[] tadOffsets); /** * @@ -1897,85 +2113,91 @@ public native void execSummaryStatsTad(@Cast("Nd4jPointer*") PointerPointer extr * @param extraParams * @param n */ -public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams); -public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams); -public native void execTransformFloat(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams); - -public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams); -public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams); -public native void execTransformSame(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams); - -public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams); -public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams); -public native void execTransformBool(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams); - -public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams); -public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams); -public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams); - -public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - Pointer extraParams); -public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - Pointer extraParams); -public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - Pointer extraParams); +public native void execTransformFloat( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); +public native void execTransformFloat( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); +public native void execTransformFloat( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); + +public native void execTransformSame( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); +public native void execTransformSame( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); +public native void execTransformSame( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); + +public native void execTransformBool( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); +public native void execTransformBool( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); +public native void execTransformBool( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); + +public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); +public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); +public native void execTransformAny(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, + @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); + +public native void execTransformStrict( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, Pointer extraParams); +public native void execTransformStrict( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, Pointer extraParams); +public native void execTransformStrict( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, Pointer extraParams); /** * @@ -1990,92 +2212,86 @@ public native void execTransformStrict(@Cast("Nd4jPointer*") PointerPointer extr * @param dimension * @param dimensionLength */ -public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbScalars, @Cast("const Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("const Nd4jLong*") LongPointer dScalarShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, - @Cast("const Nd4jLong*") LongPointer tadShapeInfo, @Cast("const Nd4jLong*") LongPointer tadOffsets, - @Cast("const Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("const Nd4jLong*") LongPointer tadOffsetsZ); -public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbScalars, @Cast("const Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("const Nd4jLong*") LongBuffer dScalarShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, @Cast("const Nd4jLong*") LongBuffer tadOffsets, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("const Nd4jLong*") LongBuffer tadOffsetsZ); -public native void execScalarTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbScalars, @Cast("const Nd4jLong*") long[] hScalarShapeInfo, @Cast("const Nd4jLong*") long[] dScalarShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, - @Cast("const Nd4jLong*") long[] tadShapeInfo, @Cast("const Nd4jLong*") long[] tadOffsets, - @Cast("const Nd4jLong*") long[] tadShapeInfoZ, @Cast("const Nd4jLong*") long[] tadOffsetsZ); - -public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, @Cast("const Nd4jLong*") LongPointer dZShapeInfo, - OpaqueDataBuffer dbScalars, @Cast("const Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("const Nd4jLong*") LongPointer dScalarShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, - @Cast("const Nd4jLong*") LongPointer tadShapeInfo, @Cast("const Nd4jLong*") LongPointer tadOffsets, - @Cast("const Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("const Nd4jLong*") LongPointer tadOffsetsZ); -public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, - OpaqueDataBuffer dbScalars, @Cast("const Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("const Nd4jLong*") LongBuffer dScalarShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, @Cast("const Nd4jLong*") LongBuffer tadOffsets, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("const Nd4jLong*") LongBuffer tadOffsetsZ); -public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, @Cast("const Nd4jLong*") long[] dZShapeInfo, - OpaqueDataBuffer dbScalars, @Cast("const Nd4jLong*") long[] hScalarShapeInfo, @Cast("const Nd4jLong*") long[] dScalarShapeInfo, - Pointer extraParams, - OpaqueDataBuffer dbDimension, @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, - @Cast("const Nd4jLong*") long[] tadShapeInfo, @Cast("const Nd4jLong*") long[] tadOffsets, - @Cast("const Nd4jLong*") long[] tadShapeInfoZ, @Cast("const Nd4jLong*") long[] tadOffsetsZ); - -public native void specialConcat( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int dimension, - int numArrays, - @Cast("Nd4jPointer*") PointerPointer data, - @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, - Pointer result, - @Cast("const Nd4jLong*") LongPointer resultShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadPointers, - @Cast("Nd4jPointer*") PointerPointer offsetPointers); -public native void specialConcat( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int dimension, - int numArrays, - @Cast("Nd4jPointer*") PointerPointer data, - @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, - Pointer result, - @Cast("const Nd4jLong*") LongBuffer resultShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadPointers, - @Cast("Nd4jPointer*") PointerPointer offsetPointers); -public native void specialConcat( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int dimension, - int numArrays, - @Cast("Nd4jPointer*") PointerPointer data, - @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, - Pointer result, - @Cast("const Nd4jLong*") long[] resultShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadPointers, - @Cast("Nd4jPointer*") PointerPointer offsetPointers); +public native void execScalarTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbScalars, + @Cast("const Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("const Nd4jLong*") LongPointer dScalarShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, + @Cast("const Nd4jLong*") LongPointer tadShapeInfo, @Cast("const Nd4jLong*") LongPointer tadOffsets, + @Cast("const Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("const Nd4jLong*") LongPointer tadOffsetsZ); +public native void execScalarTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbScalars, + @Cast("const Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("const Nd4jLong*") LongBuffer dScalarShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, + @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, @Cast("const Nd4jLong*") LongBuffer tadOffsets, + @Cast("const Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("const Nd4jLong*") LongBuffer tadOffsetsZ); +public native void execScalarTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbScalars, + @Cast("const Nd4jLong*") long[] hScalarShapeInfo, @Cast("const Nd4jLong*") long[] dScalarShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, + @Cast("const Nd4jLong*") long[] tadShapeInfo, @Cast("const Nd4jLong*") long[] tadOffsets, + @Cast("const Nd4jLong*") long[] tadShapeInfoZ, @Cast("const Nd4jLong*") long[] tadOffsetsZ); + +public native void execScalarBoolTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeInfo, + @Cast("const Nd4jLong*") LongPointer dZShapeInfo, OpaqueDataBuffer dbScalars, + @Cast("const Nd4jLong*") LongPointer hScalarShapeInfo, @Cast("const Nd4jLong*") LongPointer dScalarShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongPointer hDimensionShape, @Cast("const Nd4jLong*") LongPointer dDimensionShape, + @Cast("const Nd4jLong*") LongPointer tadShapeInfo, @Cast("const Nd4jLong*") LongPointer tadOffsets, + @Cast("const Nd4jLong*") LongPointer tadShapeInfoZ, @Cast("const Nd4jLong*") LongPointer tadOffsetsZ); +public native void execScalarBoolTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dZShapeInfo, OpaqueDataBuffer dbScalars, + @Cast("const Nd4jLong*") LongBuffer hScalarShapeInfo, @Cast("const Nd4jLong*") LongBuffer dScalarShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") LongBuffer hDimensionShape, @Cast("const Nd4jLong*") LongBuffer dDimensionShape, + @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, @Cast("const Nd4jLong*") LongBuffer tadOffsets, + @Cast("const Nd4jLong*") LongBuffer tadShapeInfoZ, @Cast("const Nd4jLong*") LongBuffer tadOffsetsZ); +public native void execScalarBoolTad( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] dXShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeInfo, + @Cast("const Nd4jLong*") long[] dZShapeInfo, OpaqueDataBuffer dbScalars, + @Cast("const Nd4jLong*") long[] hScalarShapeInfo, @Cast("const Nd4jLong*") long[] dScalarShapeInfo, + Pointer extraParams, OpaqueDataBuffer dbDimension, + @Cast("const Nd4jLong*") long[] hDimensionShape, @Cast("const Nd4jLong*") long[] dDimensionShape, + @Cast("const Nd4jLong*") long[] tadShapeInfo, @Cast("const Nd4jLong*") long[] tadOffsets, + @Cast("const Nd4jLong*") long[] tadShapeInfoZ, @Cast("const Nd4jLong*") long[] tadOffsetsZ); + +public native void specialConcat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int dimension, + int numArrays, @Cast("Nd4jPointer*") PointerPointer data, + @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, Pointer result, + @Cast("const Nd4jLong*") LongPointer resultShapeInfo, + @Cast("Nd4jPointer*") PointerPointer tadPointers, + @Cast("Nd4jPointer*") PointerPointer offsetPointers); +public native void specialConcat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int dimension, + int numArrays, @Cast("Nd4jPointer*") PointerPointer data, + @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, Pointer result, + @Cast("const Nd4jLong*") LongBuffer resultShapeInfo, + @Cast("Nd4jPointer*") PointerPointer tadPointers, + @Cast("Nd4jPointer*") PointerPointer offsetPointers); +public native void specialConcat(@Cast("Nd4jPointer*") PointerPointer extraPointers, int dimension, + int numArrays, @Cast("Nd4jPointer*") PointerPointer data, + @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, Pointer result, + @Cast("const Nd4jLong*") long[] resultShapeInfo, + @Cast("Nd4jPointer*") PointerPointer tadPointers, + @Cast("Nd4jPointer*") PointerPointer offsetPointers); /** * This method implementation exists only for cuda. @@ -2099,10 +2315,12 @@ public native void specialConcat( * * @param pointer pointer that'll be used for allocation * @param memorySize memory size, in bytes - * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for OpenCL that's pointer to device_id, etc + * @param ptrToDeviceId pointer to deviceId. For cuda that's just and int, for + * OpenCL that's pointer to device_id, etc * @param flags optional parameter */ -public native @Cast("Nd4jPointer") Pointer mallocDevice(@Cast("Nd4jLong") long memorySize, int deviceId, int flags); +public native @Cast("Nd4jPointer") Pointer mallocDevice(@Cast("Nd4jLong") long memorySize, int deviceId, + int flags); /** * This method releases previously allocated host memory space @@ -2143,7 +2361,6 @@ public native void specialConcat( */ public native void setOmpMinThreads(int threads); - public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build); /** @@ -2263,11 +2480,8 @@ public native void specialConcat( * @param reserved * @return */ -public native int memcpySync(@Cast("Nd4jPointer") Pointer dst, - @Cast("Nd4jPointer") Pointer src, - @Cast("Nd4jLong") long size, - int flags, - @Cast("Nd4jPointer") Pointer reserved); +public native int memcpySync(@Cast("Nd4jPointer") Pointer dst, @Cast("Nd4jPointer") Pointer src, @Cast("Nd4jLong") long size, + int flags, @Cast("Nd4jPointer") Pointer reserved); /** * @@ -2278,11 +2492,8 @@ public native int memcpySync(@Cast("Nd4jPointer") Pointer dst, * @param reserved * @return */ -public native int memcpyAsync(@Cast("Nd4jPointer") Pointer dst, - @Cast("Nd4jPointer") Pointer src, - @Cast("Nd4jLong") long size, - int flags, - @Cast("Nd4jPointer") Pointer reserved); +public native int memcpyAsync(@Cast("Nd4jPointer") Pointer dst, @Cast("Nd4jPointer") Pointer src, @Cast("Nd4jLong") long size, + int flags, @Cast("Nd4jPointer") Pointer reserved); /** * @@ -2293,11 +2504,8 @@ public native int memcpyAsync(@Cast("Nd4jPointer") Pointer dst, * @param reserved * @return */ -public native int memsetSync(@Cast("Nd4jPointer") Pointer dst, - int value, - @Cast("Nd4jLong") long size, - int flags, - @Cast("Nd4jPointer") Pointer reserved); +public native int memsetSync(@Cast("Nd4jPointer") Pointer dst, int value, @Cast("Nd4jLong") long size, int flags, + @Cast("Nd4jPointer") Pointer reserved); /** * @@ -2308,11 +2516,8 @@ public native int memsetSync(@Cast("Nd4jPointer") Pointer dst, * @param reserved * @return */ -public native int memsetAsync(@Cast("Nd4jPointer") Pointer dst, - int value, - @Cast("Nd4jLong") long size, - int flags, - @Cast("Nd4jPointer") Pointer reserved); +public native int memsetAsync(@Cast("Nd4jPointer") Pointer dst, int value, @Cast("Nd4jLong") long size, int flags, + @Cast("Nd4jPointer") Pointer reserved); /** * @@ -2323,11 +2528,8 @@ public native int memsetAsync(@Cast("Nd4jPointer") Pointer dst, * @param reserved * @return */ -public native int memcpyConstantAsync(@Cast("Nd4jLong") long dst, - @Cast("Nd4jPointer") Pointer src, - @Cast("Nd4jLong") long size, - int flags, - @Cast("Nd4jPointer") Pointer reserved); +public native int memcpyConstantAsync(@Cast("Nd4jLong") long dst, @Cast("Nd4jPointer") Pointer src, @Cast("Nd4jLong") long size, + int flags, @Cast("Nd4jPointer") Pointer reserved); /** * @@ -2368,14 +2570,11 @@ public native int memcpyConstantAsync(@Cast("Nd4jLong") long dst, * @param offsetsBuffer */ public native OpaqueTadPack tadOnlyShapeInfo(@Cast("const Nd4jLong*") LongPointer xShapeInfo, - IntPointer dimension, - int dimensionLength); + IntPointer dimension, int dimensionLength); public native OpaqueTadPack tadOnlyShapeInfo(@Cast("const Nd4jLong*") LongBuffer xShapeInfo, - IntBuffer dimension, - int dimensionLength); + IntBuffer dimension, int dimensionLength); public native OpaqueTadPack tadOnlyShapeInfo(@Cast("const Nd4jLong*") long[] xShapeInfo, - int[] dimension, - int dimensionLength); + int[] dimension, int dimensionLength); public native @Cast("const Nd4jLong*") LongPointer getPrimaryShapeInfo(OpaqueTadPack pack); public native @Cast("const Nd4jLong*") LongPointer getPrimaryOffsets(OpaqueTadPack pack); @@ -2404,33 +2603,30 @@ public native OpaqueTadPack tadOnlyShapeInfo(@Cast("const Nd4jLong*") long[] xSh * @param zTadShapeInfo * @param zTadOffsets */ -public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer xShapeInfo, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer zShapeInfo, @Cast("const Nd4jLong*") LongPointer dzShapeInfo, - @Cast("Nd4jLong") long n, - @Cast("Nd4jLong*") LongPointer indexes, - @Cast("const Nd4jLong*") LongPointer tadShapeInfo, - @Cast("const Nd4jLong*") LongPointer tadOffsets, - @Cast("const Nd4jLong*") LongPointer zTadShapeInfo, - @Cast("const Nd4jLong*") LongPointer zTadOffsets); -public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer zShapeInfo, @Cast("const Nd4jLong*") LongBuffer dzShapeInfo, - @Cast("Nd4jLong") long n, - @Cast("Nd4jLong*") LongBuffer indexes, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, - @Cast("const Nd4jLong*") LongBuffer tadOffsets, - @Cast("const Nd4jLong*") LongBuffer zTadShapeInfo, - @Cast("const Nd4jLong*") LongBuffer zTadOffsets); -public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] xShapeInfo, @Cast("const Nd4jLong*") long[] dxShapeInfo, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] zShapeInfo, @Cast("const Nd4jLong*") long[] dzShapeInfo, - @Cast("Nd4jLong") long n, - @Cast("Nd4jLong*") long[] indexes, - @Cast("const Nd4jLong*") long[] tadShapeInfo, - @Cast("const Nd4jLong*") long[] tadOffsets, - @Cast("const Nd4jLong*") long[] zTadShapeInfo, - @Cast("const Nd4jLong*") long[] zTadOffsets); +public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer zShapeInfo, + @Cast("const Nd4jLong*") LongPointer dzShapeInfo, @Cast("Nd4jLong") long n, + @Cast("Nd4jLong*") LongPointer indexes, @Cast("const Nd4jLong*") LongPointer tadShapeInfo, + @Cast("const Nd4jLong*") LongPointer tadOffsets, + @Cast("const Nd4jLong*") LongPointer zTadShapeInfo, + @Cast("const Nd4jLong*") LongPointer zTadOffsets); +public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer zShapeInfo, + @Cast("const Nd4jLong*") LongBuffer dzShapeInfo, @Cast("Nd4jLong") long n, + @Cast("Nd4jLong*") LongBuffer indexes, @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, + @Cast("const Nd4jLong*") LongBuffer tadOffsets, + @Cast("const Nd4jLong*") LongBuffer zTadShapeInfo, + @Cast("const Nd4jLong*") LongBuffer zTadOffsets); +public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] xShapeInfo, @Cast("const Nd4jLong*") long[] dxShapeInfo, + OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] zShapeInfo, + @Cast("const Nd4jLong*") long[] dzShapeInfo, @Cast("Nd4jLong") long n, + @Cast("Nd4jLong*") long[] indexes, @Cast("const Nd4jLong*") long[] tadShapeInfo, + @Cast("const Nd4jLong*") long[] tadOffsets, + @Cast("const Nd4jLong*") long[] zTadShapeInfo, + @Cast("const Nd4jLong*") long[] zTadOffsets); /** * @@ -2441,54 +2637,40 @@ public native void pullRows(@Cast("Nd4jPointer*") PointerPointer extraPointers, * @param length * @param propagate */ -public native void average(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - Pointer z, @Cast("const Nd4jLong*") LongPointer zShapeInfo, - Pointer dz, @Cast("const Nd4jLong*") LongPointer dzShapeInfo, - int n, - @Cast("Nd4jLong") long length, - @Cast("bool") boolean propagate); -public native void average(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - Pointer z, @Cast("const Nd4jLong*") LongBuffer zShapeInfo, - Pointer dz, @Cast("const Nd4jLong*") LongBuffer dzShapeInfo, - int n, - @Cast("Nd4jLong") long length, - @Cast("bool") boolean propagate); -public native void average(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - Pointer z, @Cast("const Nd4jLong*") long[] zShapeInfo, - Pointer dz, @Cast("const Nd4jLong*") long[] dzShapeInfo, - int n, - @Cast("Nd4jLong") long length, - @Cast("bool") boolean propagate); - - -public native void accumulate(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - Pointer z, @Cast("const Nd4jLong*") LongPointer zShapeInfo, - Pointer dz, @Cast("const Nd4jLong*") LongPointer dzShapeInfo, - int n, - @Cast("Nd4jLong") long length); -public native void accumulate(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - Pointer z, @Cast("const Nd4jLong*") LongBuffer zShapeInfo, - Pointer dz, @Cast("const Nd4jLong*") LongBuffer dzShapeInfo, - int n, - @Cast("Nd4jLong") long length); -public native void accumulate(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - Pointer z, @Cast("const Nd4jLong*") long[] zShapeInfo, - Pointer dz, @Cast("const Nd4jLong*") long[] dzShapeInfo, - int n, - @Cast("Nd4jLong") long length); - +public native void average(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, Pointer z, + @Cast("const Nd4jLong*") LongPointer zShapeInfo, Pointer dz, + @Cast("const Nd4jLong*") LongPointer dzShapeInfo, int n, @Cast("Nd4jLong") long length, + @Cast("bool") boolean propagate); +public native void average(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, Pointer z, + @Cast("const Nd4jLong*") LongBuffer zShapeInfo, Pointer dz, + @Cast("const Nd4jLong*") LongBuffer dzShapeInfo, int n, @Cast("Nd4jLong") long length, + @Cast("bool") boolean propagate); +public native void average(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, Pointer z, + @Cast("const Nd4jLong*") long[] zShapeInfo, Pointer dz, + @Cast("const Nd4jLong*") long[] dzShapeInfo, int n, @Cast("Nd4jLong") long length, + @Cast("bool") boolean propagate); + +public native void accumulate(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, Pointer z, + @Cast("const Nd4jLong*") LongPointer zShapeInfo, Pointer dz, + @Cast("const Nd4jLong*") LongPointer dzShapeInfo, int n, @Cast("Nd4jLong") long length); +public native void accumulate(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, Pointer z, + @Cast("const Nd4jLong*") LongBuffer zShapeInfo, Pointer dz, + @Cast("const Nd4jLong*") LongBuffer dzShapeInfo, int n, @Cast("Nd4jLong") long length); +public native void accumulate(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, Pointer z, + @Cast("const Nd4jLong*") long[] zShapeInfo, Pointer dz, + @Cast("const Nd4jLong*") long[] dzShapeInfo, int n, @Cast("Nd4jLong") long length); /** * P2P enabler @@ -2526,34 +2708,24 @@ public native void accumulate(@Cast("Nd4jPointer*") PointerPointer extras, * @param tadShapeInfo * @param tadOffsets */ -public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("Nd4jPointer*") PointerPointer xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("Nd4jPointer*") PointerPointer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer z, @Cast("Nd4jPointer*") PointerPointer zShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dz, @Cast("Nd4jPointer*") PointerPointer dzShapeInfo, - int N, - IntPointer shuffleMap, - @Cast("Nd4jPointer*") PointerPointer tadShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadOffsets); -public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("Nd4jPointer*") PointerPointer xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("Nd4jPointer*") PointerPointer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer z, @Cast("Nd4jPointer*") PointerPointer zShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dz, @Cast("Nd4jPointer*") PointerPointer dzShapeInfo, - int N, - IntBuffer shuffleMap, - @Cast("Nd4jPointer*") PointerPointer tadShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadOffsets); -public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, - @Cast("Nd4jPointer*") PointerPointer x, @Cast("Nd4jPointer*") PointerPointer xShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dx, @Cast("Nd4jPointer*") PointerPointer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer z, @Cast("Nd4jPointer*") PointerPointer zShapeInfo, - @Cast("Nd4jPointer*") PointerPointer dz, @Cast("Nd4jPointer*") PointerPointer dzShapeInfo, - int N, - int[] shuffleMap, - @Cast("Nd4jPointer*") PointerPointer tadShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadOffsets); - +public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("Nd4jPointer*") PointerPointer xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("Nd4jPointer*") PointerPointer dxShapeInfo, @Cast("Nd4jPointer*") PointerPointer z, + @Cast("Nd4jPointer*") PointerPointer zShapeInfo, @Cast("Nd4jPointer*") PointerPointer dz, + @Cast("Nd4jPointer*") PointerPointer dzShapeInfo, int N, IntPointer shuffleMap, + @Cast("Nd4jPointer*") PointerPointer tadShapeInfo, @Cast("Nd4jPointer*") PointerPointer tadOffsets); +public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("Nd4jPointer*") PointerPointer xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("Nd4jPointer*") PointerPointer dxShapeInfo, @Cast("Nd4jPointer*") PointerPointer z, + @Cast("Nd4jPointer*") PointerPointer zShapeInfo, @Cast("Nd4jPointer*") PointerPointer dz, + @Cast("Nd4jPointer*") PointerPointer dzShapeInfo, int N, IntBuffer shuffleMap, + @Cast("Nd4jPointer*") PointerPointer tadShapeInfo, @Cast("Nd4jPointer*") PointerPointer tadOffsets); +public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, @Cast("Nd4jPointer*") PointerPointer x, + @Cast("Nd4jPointer*") PointerPointer xShapeInfo, @Cast("Nd4jPointer*") PointerPointer dx, + @Cast("Nd4jPointer*") PointerPointer dxShapeInfo, @Cast("Nd4jPointer*") PointerPointer z, + @Cast("Nd4jPointer*") PointerPointer zShapeInfo, @Cast("Nd4jPointer*") PointerPointer dz, + @Cast("Nd4jPointer*") PointerPointer dzShapeInfo, int N, int[] shuffleMap, + @Cast("Nd4jPointer*") PointerPointer tadShapeInfo, @Cast("Nd4jPointer*") PointerPointer tadOffsets); /** * Type Conversions @@ -2568,8 +2740,8 @@ public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, * @param dstType * @param z */ -public native void convertTypes(@Cast("Nd4jPointer*") PointerPointer extras, int srcType, @Cast("Nd4jPointer") Pointer x, @Cast("Nd4jLong") long N, int dstType, @Cast("Nd4jPointer") Pointer z); - +public native void convertTypes(@Cast("Nd4jPointer*") PointerPointer extras, int srcType, @Cast("Nd4jPointer") Pointer x, + @Cast("Nd4jLong") long N, int dstType, @Cast("Nd4jPointer") Pointer z); /** * @@ -2596,83 +2768,46 @@ public native void shuffle(@Cast("Nd4jPointer*") PointerPointer extras, * @param realArguments * @param numRealArguments */ -public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("void**") PointerPointer arguments, - int numArguments, - @Cast("Nd4jLong**") PointerPointer shapeArguments, - int numShapeArguments, - IntPointer indexArguments, - int numIndexArguments, - @Cast("int**") PointerPointer intArrays, - int numIntArrays, - Pointer realArguments, - int numRealArguments, - @Cast("sd::DataType") int dtype); -public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("void**") @ByPtrPtr Pointer arguments, - int numArguments, - @Cast("Nd4jLong**") @ByPtrPtr LongPointer shapeArguments, - int numShapeArguments, - IntPointer indexArguments, - int numIndexArguments, - @ByPtrPtr IntPointer intArrays, - int numIntArrays, - Pointer realArguments, - int numRealArguments, - @Cast("sd::DataType") int dtype); -public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("void**") @ByPtrPtr Pointer arguments, - int numArguments, - @Cast("Nd4jLong**") @ByPtrPtr LongBuffer shapeArguments, - int numShapeArguments, - IntBuffer indexArguments, - int numIndexArguments, - @ByPtrPtr IntBuffer intArrays, - int numIntArrays, - Pointer realArguments, - int numRealArguments, - @Cast("sd::DataType") int dtype); -public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("void**") @ByPtrPtr Pointer arguments, - int numArguments, - @Cast("Nd4jLong**") @ByPtrPtr long[] shapeArguments, - int numShapeArguments, - int[] indexArguments, - int numIndexArguments, - @ByPtrPtr int[] intArrays, - int numIntArrays, - Pointer realArguments, - int numRealArguments, - @Cast("sd::DataType") int dtype); - - -public native void batchExecutor(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int numAggregates, - int opNum, - int maxArgs, - int maxShapes, - int maxIntArrays, - int maxIntArraySize, - int maxIdx, - int maxReals, - Pointer ptrToArguments, - @Cast("sd::DataType") int dtype); - -public native void execAggregateBatch(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int numAggregates, - int opNum, - int maxArgs, - int maxShapes, - int maxIntArrays, - int maxIntArraySize, - int maxIdx, - int maxReals, - Pointer ptrToArguments, - @Cast("sd::DataType") int dtype); +public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("void**") PointerPointer arguments, int numArguments, + @Cast("Nd4jLong**") PointerPointer shapeArguments, int numShapeArguments, + IntPointer indexArguments, int numIndexArguments, + @Cast("int**") PointerPointer intArrays, int numIntArrays, + Pointer realArguments, int numRealArguments, + @Cast("sd::DataType") int dtype); +public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("void**") @ByPtrPtr Pointer arguments, int numArguments, + @Cast("Nd4jLong**") @ByPtrPtr LongPointer shapeArguments, int numShapeArguments, + IntPointer indexArguments, int numIndexArguments, + @ByPtrPtr IntPointer intArrays, int numIntArrays, + Pointer realArguments, int numRealArguments, + @Cast("sd::DataType") int dtype); +public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("void**") @ByPtrPtr Pointer arguments, int numArguments, + @Cast("Nd4jLong**") @ByPtrPtr LongBuffer shapeArguments, int numShapeArguments, + IntBuffer indexArguments, int numIndexArguments, + @ByPtrPtr IntBuffer intArrays, int numIntArrays, + Pointer realArguments, int numRealArguments, + @Cast("sd::DataType") int dtype); +public native void execAggregate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("void**") @ByPtrPtr Pointer arguments, int numArguments, + @Cast("Nd4jLong**") @ByPtrPtr long[] shapeArguments, int numShapeArguments, + int[] indexArguments, int numIndexArguments, + @ByPtrPtr int[] intArrays, int numIntArrays, + Pointer realArguments, int numRealArguments, + @Cast("sd::DataType") int dtype); + +public native void batchExecutor(@Cast("Nd4jPointer*") PointerPointer extraPointers, int numAggregates, + int opNum, int maxArgs, int maxShapes, + int maxIntArrays, int maxIntArraySize, int maxIdx, + int maxReals, Pointer ptrToArguments, + @Cast("sd::DataType") int dtype); + +public native void execAggregateBatch(@Cast("Nd4jPointer*") PointerPointer extraPointers, int numAggregates, + int opNum, int maxArgs, int maxShapes, + int maxIntArrays, int maxIntArraySize, + int maxIdx, int maxReals, + Pointer ptrToArguments, @Cast("sd::DataType") int dtype); /** * Random operations @@ -2687,21 +2822,18 @@ public native void execAggregateBatch(@Cast("Nd4jPointer*") PointerPointer extra * @param zShapeBuffer * @param extraArguments */ -public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeBuffer, @Cast("const Nd4jLong*") LongPointer dZShapeBuffer, - Pointer extraArguments); -public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("const Nd4jLong*") LongBuffer dZShapeBuffer, - Pointer extraArguments); -public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeBuffer, @Cast("const Nd4jLong*") long[] dZShapeBuffer, - Pointer extraArguments); +public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeBuffer, + @Cast("const Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); +public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeBuffer, + @Cast("const Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); +public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeBuffer, + @Cast("const Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); /** * @@ -2716,27 +2848,30 @@ public native void execRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers * @param zShapeBuffer * @param extraArguments */ -public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeBuffer, @Cast("const Nd4jLong*") LongPointer dXShapeBuffer, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongPointer hYShapeBuffer, @Cast("const Nd4jLong*") LongPointer dYShapeBuffer, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeBuffer, @Cast("const Nd4jLong*") LongPointer dZShapeBuffer, - Pointer extraArguments); -public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeBuffer, @Cast("const Nd4jLong*") LongBuffer dXShapeBuffer, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") LongBuffer hYShapeBuffer, @Cast("const Nd4jLong*") LongBuffer dYShapeBuffer, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("const Nd4jLong*") LongBuffer dZShapeBuffer, - Pointer extraArguments); -public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeBuffer, @Cast("const Nd4jLong*") long[] dXShapeBuffer, - OpaqueDataBuffer dbY, @Cast("const Nd4jLong*") long[] hYShapeBuffer, @Cast("const Nd4jLong*") long[] dYShapeBuffer, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeBuffer, @Cast("const Nd4jLong*") long[] dZShapeBuffer, - Pointer extraArguments); +public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeBuffer, + @Cast("const Nd4jLong*") LongPointer dXShapeBuffer, OpaqueDataBuffer dbY, + @Cast("const Nd4jLong*") LongPointer hYShapeBuffer, + @Cast("const Nd4jLong*") LongPointer dYShapeBuffer, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeBuffer, + @Cast("const Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); +public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeBuffer, + @Cast("const Nd4jLong*") LongBuffer dXShapeBuffer, OpaqueDataBuffer dbY, + @Cast("const Nd4jLong*") LongBuffer hYShapeBuffer, + @Cast("const Nd4jLong*") LongBuffer dYShapeBuffer, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeBuffer, + @Cast("const Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); +public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeBuffer, + @Cast("const Nd4jLong*") long[] dXShapeBuffer, OpaqueDataBuffer dbY, + @Cast("const Nd4jLong*") long[] hYShapeBuffer, + @Cast("const Nd4jLong*") long[] dYShapeBuffer, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeBuffer, + @Cast("const Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); /** * @@ -2749,25 +2884,24 @@ public native void execRandom3(@Cast("Nd4jPointer*") PointerPointer extraPointer * @param zShapeBuffer * @param extraArguments */ -public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer hXShapeBuffer, @Cast("const Nd4jLong*") LongPointer dXShapeBuffer, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongPointer hZShapeBuffer, @Cast("const Nd4jLong*") LongPointer dZShapeBuffer, - Pointer extraArguments); -public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer hXShapeBuffer, @Cast("const Nd4jLong*") LongBuffer dXShapeBuffer, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") LongBuffer hZShapeBuffer, @Cast("const Nd4jLong*") LongBuffer dZShapeBuffer, - Pointer extraArguments); -public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, - int opNum, - @Cast("Nd4jPointer") Pointer state, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] hXShapeBuffer, @Cast("const Nd4jLong*") long[] dXShapeBuffer, - OpaqueDataBuffer dbZ, @Cast("const Nd4jLong*") long[] hZShapeBuffer, @Cast("const Nd4jLong*") long[] dZShapeBuffer, - Pointer extraArguments); - +public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer hXShapeBuffer, + @Cast("const Nd4jLong*") LongPointer dXShapeBuffer, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongPointer hZShapeBuffer, + @Cast("const Nd4jLong*") LongPointer dZShapeBuffer, Pointer extraArguments); +public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer hXShapeBuffer, + @Cast("const Nd4jLong*") LongBuffer dXShapeBuffer, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") LongBuffer hZShapeBuffer, + @Cast("const Nd4jLong*") LongBuffer dZShapeBuffer, Pointer extraArguments); +public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opNum, + @Cast("Nd4jPointer") Pointer state, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] hXShapeBuffer, + @Cast("const Nd4jLong*") long[] dXShapeBuffer, OpaqueDataBuffer dbZ, + @Cast("const Nd4jLong*") long[] hZShapeBuffer, + @Cast("const Nd4jLong*") long[] dZShapeBuffer, Pointer extraArguments); /** * @@ -2777,10 +2911,8 @@ public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointer * @param ptrToBuffer * @return */ -public native @Cast("Nd4jPointer") Pointer initRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, - long seed, - long bufferSize, - @Cast("Nd4jPointer") Pointer ptrToBuffer); +public native @Cast("Nd4jPointer") Pointer initRandom(@Cast("Nd4jPointer*") PointerPointer extraPointers, long seed, + long bufferSize, @Cast("Nd4jPointer") Pointer ptrToBuffer); /** * @@ -2788,9 +2920,8 @@ public native void execRandom2(@Cast("Nd4jPointer*") PointerPointer extraPointer * @param seed * @param ptrRandom */ -public native void refreshBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, - long seed, - @Cast("Nd4jPointer") Pointer ptrRandom); +public native void refreshBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, long seed, + @Cast("Nd4jPointer") Pointer ptrRandom); /** * @@ -2798,9 +2929,8 @@ public native void refreshBuffer(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param seed * @param ptrRandom */ -public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, - long seed, - @Cast("Nd4jPointer") Pointer ptrRandom); +public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, long seed, + @Cast("Nd4jPointer") Pointer ptrRandom); /** * @@ -2809,95 +2939,93 @@ public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointe public native void destroyRandom(@Cast("Nd4jPointer") Pointer ptrRandom); /** -* -* @param data -* @param shapeBuffer -* @param wordSize -* @param headerSize -* @return -*/ + * + * @param data + * @param shapeBuffer + * @param wordSize + * @param headerSize + * @return + */ -public native @Cast("Nd4jPointer") Pointer numpyHeaderForNd4j(@Cast("Nd4jPointer") Pointer data,@Cast("Nd4jPointer") Pointer shapeBuffer,@Cast("Nd4jLong") long wordSize,@Cast("Nd4jLong*") LongPointer headerSize); -public native @Cast("Nd4jPointer") Pointer numpyHeaderForNd4j(@Cast("Nd4jPointer") Pointer data,@Cast("Nd4jPointer") Pointer shapeBuffer,@Cast("Nd4jLong") long wordSize,@Cast("Nd4jLong*") LongBuffer headerSize); -public native @Cast("Nd4jPointer") Pointer numpyHeaderForNd4j(@Cast("Nd4jPointer") Pointer data,@Cast("Nd4jPointer") Pointer shapeBuffer,@Cast("Nd4jLong") long wordSize,@Cast("Nd4jLong*") long[] headerSize); +public native @Cast("Nd4jPointer") Pointer numpyHeaderForNd4j(@Cast("Nd4jPointer") Pointer data, @Cast("Nd4jPointer") Pointer shapeBuffer, + @Cast("Nd4jLong") long wordSize, @Cast("Nd4jLong*") LongPointer headerSize); +public native @Cast("Nd4jPointer") Pointer numpyHeaderForNd4j(@Cast("Nd4jPointer") Pointer data, @Cast("Nd4jPointer") Pointer shapeBuffer, + @Cast("Nd4jLong") long wordSize, @Cast("Nd4jLong*") LongBuffer headerSize); +public native @Cast("Nd4jPointer") Pointer numpyHeaderForNd4j(@Cast("Nd4jPointer") Pointer data, @Cast("Nd4jPointer") Pointer shapeBuffer, + @Cast("Nd4jLong") long wordSize, @Cast("Nd4jLong*") long[] headerSize); /** -* Load numpy from a header -* based on the cnpy parse from header method. -* @param data the header data to parse -* @return a pointer to a numpy cnpy:NpyArray struct -*/ + * Load numpy from a header + * based on the cnpy parse from header method. + * @param data the header data to parse + * @return a pointer to a numpy cnpy:NpyArray struct + */ public native @Cast("Nd4jPointer") Pointer loadNpyFromHeader(@Cast("Nd4jPointer") Pointer data); /** -* Create a numpy array from an nd4j -* array -* @param data a pointer to the data -* @param shapeBuffer the shapebuffer for the nd4j array -* @param wordSize the word size (4 for float, 8 for doubles) -* @return a pointer to a numpy array -*/ - -public native @Cast("Nd4jPointer") Pointer numpyFromNd4j(@Cast("Nd4jPointer") Pointer data,@Cast("Nd4jPointer") Pointer shapeBuffer,@Cast("Nd4jLong") long wordSize); + * Create a numpy array from an nd4j + * array + * @param data a pointer to the data + * @param shapeBuffer the shapebuffer for the nd4j array + * @param wordSize the word size (4 for float, 8 for doubles) + * @return a pointer to a numpy array + */ +public native @Cast("Nd4jPointer") Pointer numpyFromNd4j(@Cast("Nd4jPointer") Pointer data, @Cast("Nd4jPointer") Pointer shapeBuffer, + @Cast("Nd4jLong") long wordSize); /** -* -* @param npyArray -* @return -*/ + * + * @param npyArray + * @return + */ public native @Cast("Nd4jPointer") Pointer shapeBufferForNumpy(@Cast("Nd4jPointer") Pointer npyArray); - /** -* Get the shape buffer from a -* numpy array. -* **Warning** this allocates memory -* @param npyArray -* @return -*/ + * Get the shape buffer from a + * numpy array. + * **Warning** this allocates memory + * @param npyArray + * @return + */ public native @Cast("Nd4jPointer") Pointer shapeBufferForNumpyHeader(@Cast("Nd4jPointer") Pointer npyArray); - - /** -* -* @param npyArray -* @return -*/ + * + * @param npyArray + * @return + */ public native @Cast("Nd4jPointer") Pointer dataPointForNumpyHeader(@Cast("Nd4jPointer") Pointer npyArray); /** -* -* @param npyArray -* @return -*/ + * + * @param npyArray + * @return + */ public native @Cast("Nd4jPointer") Pointer dataPointForNumpyStruct(@Cast("Nd4jPointer") Pointer npyArrayStruct); /** -* -* @param npyArray -* @param fromFile -* @return -*/ + * + * @param npyArray + * @param fromFile + * @return + */ public native @Cast("Nd4jPointer") Pointer dataPointForNumpy(@Cast("Nd4jPointer") Pointer npyArray); /** -* Load a numpy array from a file -* and return it as an Nd4jPointer -* @param path -* @return -*/ + * Load a numpy array from a file + * and return it as an Nd4jPointer + * @param path + * @return + */ public native @Cast("Nd4jPointer") Pointer numpyFromFile(@StdString BytePointer path); public native @Cast("Nd4jPointer") Pointer numpyFromFile(@StdString String path); - ////// NPZ ////// public native Pointer mapFromNpzFile(@StdString BytePointer path); public native Pointer mapFromNpzFile(@StdString String path); - public native int getNumNpyArraysInMap(Pointer map); public native @Cast("char*") String getNpyArrayNameFromMap(Pointer map, int index); @@ -2922,26 +3050,23 @@ public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointe ////// /** -* Get the element size for a numpy array -* @param npyArray the numpy array's address -* to get the length for -* @return -*/ + * Get the element size for a numpy array + * @param npyArray the numpy array's address + * to get the length for + * @return + */ public native int elementSizeForNpyArray(@Cast("Nd4jPointer") Pointer npyArray); - /** -* Get the element size for a numpy array -* @param npyArray the numpy array's address -* to get the length for -* @return -*/ + * Get the element size for a numpy array + * @param npyArray the numpy array's address + * to get the length for + * @return + */ public native int elementSizeForNpyArrayHeader(@Cast("Nd4jPointer") Pointer npyArray); - public native void releaseNumpy(@Cast("Nd4jPointer") Pointer npyArray); - /** * Return the length of a shape buffer * based on the pointer @@ -2950,18 +3075,18 @@ public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointe */ public native int lengthForShapeBufferPointer(@Cast("Nd4jPointer") Pointer buffer); - - /** -* The pointer to get the address for -* -* @param address the address to get the pointer -* @return the pointer for the given address -*/ +/** + * The pointer to get the address for + * + * @param address the address to get the pointer + * @return the pointer for the given address + */ public native @Cast("Nd4jPointer") Pointer pointerForAddress(@Cast("Nd4jLong") long _address); /** - * This method takes single N-dimensional tensor, and copies its TADs to target arrays + * This method takes single N-dimensional tensor, and copies its TADs to target + * arrays * * @param x * @param xShapeInfo @@ -2969,164 +3094,138 @@ public native void reSeedBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointe * @param zShapeInfo * @return */ -public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongPointer xShapeInfo, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("const Nd4jLong*") LongPointer zShapeInfo, - @Cast("const Nd4jLong*") LongPointer tadShapeInfo, - @Cast("const Nd4jLong*") LongPointer tadOffsets); -public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("const Nd4jLong*") LongBuffer zShapeInfo, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, - @Cast("const Nd4jLong*") LongBuffer tadOffsets); -public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, - OpaqueDataBuffer dbX, @Cast("const Nd4jLong*") long[] xShapeInfo, @Cast("const Nd4jLong*") long[] dxShapeInfo, - @Cast("Nd4jPointer*") PointerPointer targets, @Cast("const Nd4jLong*") long[] zShapeInfo, - @Cast("const Nd4jLong*") long[] tadShapeInfo, - @Cast("const Nd4jLong*") long[] tadOffsets); - -public native void sort(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - @Cast("bool") boolean descending); -public native void sort(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - @Cast("bool") boolean descending); -public native void sort(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - @Cast("bool") boolean descending); - -public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongPointer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongPointer dyShapeInfo, - @Cast("bool") boolean descending); -public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongBuffer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, - @Cast("bool") boolean descending); -public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") long[] yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") long[] dyShapeInfo, - @Cast("bool") boolean descending); - -public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongPointer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongPointer dyShapeInfo, - @Cast("bool") boolean descending); -public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongBuffer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, - @Cast("bool") boolean descending); -public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") long[] yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") long[] dyShapeInfo, - @Cast("bool") boolean descending); - -public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - IntPointer dimension, - int dimensionLength, - @Cast("const Nd4jLong*") LongPointer tadShapeInfo, - @Cast("const Nd4jLong*") LongPointer tadOffsets, - @Cast("bool") boolean descending); -public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - IntBuffer dimension, - int dimensionLength, - @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, - @Cast("const Nd4jLong*") LongBuffer tadOffsets, - @Cast("bool") boolean descending); -public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - int[] dimension, - int dimensionLength, - @Cast("const Nd4jLong*") long[] tadShapeInfo, - @Cast("const Nd4jLong*") long[] tadOffsets, - @Cast("bool") boolean descending); - -public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongPointer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongPointer dyShapeInfo, - IntPointer dimension, - int dimensionLength, - @Cast("bool") boolean descending); -public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongBuffer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, - IntBuffer dimension, - int dimensionLength, - @Cast("bool") boolean descending); -public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") long[] yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") long[] dyShapeInfo, - int[] dimension, - int dimensionLength, - @Cast("bool") boolean descending); - -public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongPointer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongPointer dyShapeInfo, - IntPointer dimension, - int dimensionLength, - @Cast("bool") boolean descending); -public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") LongBuffer yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, - IntBuffer dimension, - int dimensionLength, - @Cast("bool") boolean descending); -public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, - Pointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, - Pointer dx, @Cast("const Nd4jLong*") long[] dxShapeInfo, - Pointer y, @Cast("const Nd4jLong*") long[] yShapeInfo, - Pointer dy, @Cast("const Nd4jLong*") long[] dyShapeInfo, - int[] dimension, - int dimensionLength, - @Cast("bool") boolean descending); - +public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, @Cast("const Nd4jLong*") LongPointer dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("const Nd4jLong*") LongPointer zShapeInfo, + @Cast("const Nd4jLong*") LongPointer tadShapeInfo, @Cast("const Nd4jLong*") LongPointer tadOffsets); +public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("const Nd4jLong*") LongBuffer zShapeInfo, + @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, @Cast("const Nd4jLong*") LongBuffer tadOffsets); +public native void tear(@Cast("Nd4jPointer*") PointerPointer extraPointers, OpaqueDataBuffer dbX, + @Cast("const Nd4jLong*") long[] xShapeInfo, @Cast("const Nd4jLong*") long[] dxShapeInfo, + @Cast("Nd4jPointer*") PointerPointer targets, @Cast("const Nd4jLong*") long[] zShapeInfo, + @Cast("const Nd4jLong*") long[] tadShapeInfo, @Cast("const Nd4jLong*") long[] tadOffsets); + +public native void sort(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, @Cast("bool") boolean descending); +public native void sort(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, @Cast("bool") boolean descending); +public native void sort(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, @Cast("bool") boolean descending); + +public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongPointer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongPointer dyShapeInfo, @Cast("bool") boolean descending); +public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongBuffer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, @Cast("bool") boolean descending); +public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") long[] yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") long[] dyShapeInfo, @Cast("bool") boolean descending); + +public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongPointer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongPointer dyShapeInfo, @Cast("bool") boolean descending); +public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongBuffer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, @Cast("bool") boolean descending); +public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") long[] yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") long[] dyShapeInfo, @Cast("bool") boolean descending); + +public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, IntPointer dimension, + int dimensionLength, @Cast("const Nd4jLong*") LongPointer tadShapeInfo, + @Cast("const Nd4jLong*") LongPointer tadOffsets, @Cast("bool") boolean descending); +public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, IntBuffer dimension, + int dimensionLength, @Cast("const Nd4jLong*") LongBuffer tadShapeInfo, + @Cast("const Nd4jLong*") LongBuffer tadOffsets, @Cast("bool") boolean descending); +public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, int[] dimension, + int dimensionLength, @Cast("const Nd4jLong*") long[] tadShapeInfo, + @Cast("const Nd4jLong*") long[] tadOffsets, @Cast("bool") boolean descending); + +public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongPointer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongPointer dyShapeInfo, IntPointer dimension, + int dimensionLength, @Cast("bool") boolean descending); +public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongBuffer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, IntBuffer dimension, + int dimensionLength, @Cast("bool") boolean descending); +public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") long[] yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") long[] dyShapeInfo, int[] dimension, + int dimensionLength, @Cast("bool") boolean descending); + +public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongPointer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongPointer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongPointer dyShapeInfo, IntPointer dimension, + int dimensionLength, @Cast("bool") boolean descending); +public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") LongBuffer dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") LongBuffer yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") LongBuffer dyShapeInfo, IntBuffer dimension, + int dimensionLength, @Cast("bool") boolean descending); +public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, Pointer dx, + @Cast("const Nd4jLong*") long[] dxShapeInfo, Pointer y, + @Cast("const Nd4jLong*") long[] yShapeInfo, Pointer dy, + @Cast("const Nd4jLong*") long[] dyShapeInfo, int[] dimension, + int dimensionLength, @Cast("bool") boolean descending); // special sort impl for sorting out COO indices and values -public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank); -public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, Pointer values, @Cast("Nd4jLong") long length, int rank); -public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, Pointer values, @Cast("Nd4jLong") long length, int rank); - - -public native @Cast("Nd4jLong*") LongPointer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String fileName, @Cast("Nd4jLong") long length); -public native @Cast("Nd4jLong*") LongBuffer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer fileName, @Cast("Nd4jLong") long length); - -public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer ptrMap, @Cast("Nd4jLong") long length); -public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, @Cast("Nd4jLong") long length); -public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, @Cast("Nd4jLong") long length); +public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, + Pointer values, @Cast("Nd4jLong") long length, int rank); +public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, + Pointer values, @Cast("Nd4jLong") long length, int rank); +public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, + Pointer values, @Cast("Nd4jLong") long length, int rank); + +public native @Cast("Nd4jLong*") LongPointer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String fileName, + @Cast("Nd4jLong") long length); +public native @Cast("Nd4jLong*") LongBuffer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer fileName, + @Cast("Nd4jLong") long length); + +public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer ptrMap, + @Cast("Nd4jLong") long length); +public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, + @Cast("Nd4jLong") long length); +public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, + @Cast("Nd4jLong") long length); // flatbuffers execution -public native OpaqueResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer); +public native OpaqueResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("Nd4jPointer") Pointer flatBufferPointer); public native @Cast("Nd4jLong") long getResultWrapperSize(OpaqueResultWrapper ptr); public native @Cast("Nd4jPointer") Pointer getResultWrapperPointer(OpaqueResultWrapper ptr); @@ -3136,34 +3235,117 @@ public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPoin public native @Cast("char*") String getAllOperations(); // customOp executioner -public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, @Cast("bool") boolean isInplace); -public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace); -public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, @Cast("bool") boolean isInplace); -public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace); -public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, @Cast("bool") boolean isInplace); -public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace); -public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer") Pointer opContext); - -public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs); -public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs); -public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntPointer dArgs, int numDArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntBuffer dArgs, int numDArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, int[] dArgs, int numDArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntPointer dArgs, int numDArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntBuffer dArgs, int numDArgs); -public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, int[] dArgs, int numDArgs); +public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, + DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, + int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, + @Cast("bool") boolean isInplace); +public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, + DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, + int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, + @Cast("bool") boolean isInplace); +public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, + double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, + int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, + @Cast("bool") boolean isInplace); +public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, + DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, + int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, + @Cast("bool") boolean isInplace); +public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, + DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, + int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, + @Cast("bool") boolean isInplace); +public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, + double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, + int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, + @Cast("bool") boolean isInplace); +public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer") Pointer opContext); + +public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputShapes, + DoublePointer tArgs, int numTArgs, + @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs); +public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputShapes, + DoubleBuffer tArgs, int numTArgs, + @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs); +public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("Nd4jLong") long hash, + @Cast("Nd4jPointer*") PointerPointer inputShapes, + int numInputShapes, + double[] tArgs, int numTArgs, + @Cast("Nd4jLong*") long[] iArgs, int numIArgs); +public native OpaqueShapeList calculateOutputShapes2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, + @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntPointer dArgs, + int numDArgs); +public native OpaqueShapeList calculateOutputShapes2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, + @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntBuffer dArgs, + int numDArgs); +public native OpaqueShapeList calculateOutputShapes2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, + @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, int[] dArgs, + int numDArgs); +public native OpaqueShapeList calculateOutputShapes2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, + @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, IntPointer dArgs, + int numDArgs); +public native OpaqueShapeList calculateOutputShapes2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, + @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs, IntBuffer dArgs, + int numDArgs); +public native OpaqueShapeList calculateOutputShapes2( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, + @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, int[] dArgs, + int numDArgs); public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list); public native @Cast("const Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i); public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList); -public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer); +public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, + @Cast("Nd4jPointer") Pointer flatBufferPointer); -public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); -public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); -public native OpaqueVariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); +public native OpaqueVariablesSet executeStoredGraph( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); +public native OpaqueVariablesSet executeStoredGraph( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); +public native OpaqueVariablesSet executeStoredGraph( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); public native @Cast("Nd4jLong") long getVariablesSetSize(OpaqueVariablesSet set); public native @Cast("Nd4jStatus") int getVariablesSetStatus(OpaqueVariablesSet set); @@ -3183,68 +3365,123 @@ public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPoin public native void deleteVariablesSet(OpaqueVariablesSet pointer); -// GraphState creation -public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id); - -public native void deleteGraphState(@Cast("Nd4jPointer") Pointer state); - public native void deleteResultWrapper(@Cast("Nd4jPointer") Pointer ptr); -public native int estimateThreshold(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer x, @Cast("const Nd4jLong*") LongPointer xShapeInfo, int N, float threshold); -public native int estimateThreshold(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer x, @Cast("const Nd4jLong*") LongBuffer xShapeInfo, int N, float threshold); -public native int estimateThreshold(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer x, @Cast("const Nd4jLong*") long[] xShapeInfo, int N, float threshold); - -// this method executes op that requires scope to be present: if/while/cond/whatever -public native @Cast("Nd4jStatus") int execCustomOpWithScope(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer state, @Cast("Nd4jLong") long opHash, @Cast("Nd4jLong*") LongPointer scopes, int numScopes, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs); -public native @Cast("Nd4jStatus") int execCustomOpWithScope(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer state, @Cast("Nd4jLong") long opHash, @Cast("Nd4jLong*") LongBuffer scopes, int numScopes, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs); -public native @Cast("Nd4jStatus") int execCustomOpWithScope(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer state, @Cast("Nd4jLong") long opHash, @Cast("Nd4jLong*") long[] scopes, int numScopes, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs); - -//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer); -public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String string, int length); -public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer string, int length); -public native @Cast("Nd4jLong") long getUtf8StringLength(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr); -public native @Cast("char*") BytePointer getUtf8StringBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr); +public native int estimateThreshold(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer x, + @Cast("const Nd4jLong*") LongPointer xShapeInfo, int N, + float threshold); +public native int estimateThreshold(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer x, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo, int N, + float threshold); +public native int estimateThreshold(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer x, + @Cast("const Nd4jLong*") long[] xShapeInfo, int N, + float threshold); + +// this method executes op that requires scope to be present: +// if/while/cond/whatever +public native @Cast("Nd4jStatus") int execCustomOpWithScope( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer state, @Cast("Nd4jLong") long opHash, + @Cast("Nd4jLong*") LongPointer scopes, int numScopes, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs); +public native @Cast("Nd4jStatus") int execCustomOpWithScope( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer state, @Cast("Nd4jLong") long opHash, + @Cast("Nd4jLong*") LongBuffer scopes, int numScopes, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs); +public native @Cast("Nd4jStatus") int execCustomOpWithScope( + @Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer state, @Cast("Nd4jLong") long opHash, + @Cast("Nd4jLong*") long[] scopes, int numScopes, @Cast("Nd4jPointer*") PointerPointer inputBuffers, + @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, + @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs); + +// void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int +// numStrings, Nd4jPointer buffer); +public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("char*") String string, int length); +public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("char*") BytePointer string, int length); +public native @Cast("Nd4jLong") long getUtf8StringLength(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("Nd4jPointer") Pointer ptr); +public native @Cast("char*") BytePointer getUtf8StringBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("Nd4jPointer") Pointer ptr); public native void deleteUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr); -public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, - Pointer hX, @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer hXOffsets, - Pointer dX, @Cast("const Nd4jLong*") LongPointer dXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXOffsets, - Pointer hY, @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer hYOffsets, - Pointer dY, @Cast("const Nd4jLong*") LongPointer dYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYOffsets, - Pointer hIindexes, @Cast("const Nd4jLong*") LongPointer hIndicesShapeInfo, Pointer dIindexes, @Cast("const Nd4jLong*") LongPointer dIndicesShapeInfo); -public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, - Pointer hX, @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer hXOffsets, - Pointer dX, @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXOffsets, - Pointer hY, @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer hYOffsets, - Pointer dY, @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYOffsets, - Pointer hIindexes, @Cast("const Nd4jLong*") LongBuffer hIndicesShapeInfo, Pointer dIindexes, @Cast("const Nd4jLong*") LongBuffer dIndicesShapeInfo); -public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, - Pointer hX, @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] hXOffsets, - Pointer dX, @Cast("const Nd4jLong*") long[] dXShapeInfo, @Cast("const Nd4jLong*") long[] dXOffsets, - Pointer hY, @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] hYOffsets, - Pointer dY, @Cast("const Nd4jLong*") long[] dYShapeInfo, @Cast("const Nd4jLong*") long[] dYOffsets, - Pointer hIindexes, @Cast("const Nd4jLong*") long[] hIndicesShapeInfo, Pointer dIindexes, @Cast("const Nd4jLong*") long[] dIndicesShapeInfo); - -public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongPointer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); -public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); -public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); - -public native OpaqueConstantShapeBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("sd::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); -public native OpaqueConstantShapeBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("sd::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); -public native OpaqueConstantShapeBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("sd::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); - -public native OpaqueConstantDataBuffer constantBufferLong(@Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongPointer data, int length); -public native OpaqueConstantDataBuffer constantBufferLong(@Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongBuffer data, int length); -public native OpaqueConstantDataBuffer constantBufferLong(@Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") long[] data, int length); -public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("sd::DataType") int dtype, DoublePointer data, int length); -public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("sd::DataType") int dtype, DoubleBuffer data, int length); -public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("sd::DataType") int dtype, double[] data, int length); -public native OpaqueConstantDataBuffer constantBuffer(@Cast("sd::DataType") int dtype, ConstantDescriptor descriptor); +public native void scatterUpdate( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, Pointer hX, + @Cast("const Nd4jLong*") LongPointer hXShapeInfo, @Cast("const Nd4jLong*") LongPointer hXOffsets, Pointer dX, + @Cast("const Nd4jLong*") LongPointer dXShapeInfo, @Cast("const Nd4jLong*") LongPointer dXOffsets, Pointer hY, + @Cast("const Nd4jLong*") LongPointer hYShapeInfo, @Cast("const Nd4jLong*") LongPointer hYOffsets, Pointer dY, + @Cast("const Nd4jLong*") LongPointer dYShapeInfo, @Cast("const Nd4jLong*") LongPointer dYOffsets, Pointer hIindexes, + @Cast("const Nd4jLong*") LongPointer hIndicesShapeInfo, Pointer dIindexes, + @Cast("const Nd4jLong*") LongPointer dIndicesShapeInfo); +public native void scatterUpdate( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, Pointer hX, + @Cast("const Nd4jLong*") LongBuffer hXShapeInfo, @Cast("const Nd4jLong*") LongBuffer hXOffsets, Pointer dX, + @Cast("const Nd4jLong*") LongBuffer dXShapeInfo, @Cast("const Nd4jLong*") LongBuffer dXOffsets, Pointer hY, + @Cast("const Nd4jLong*") LongBuffer hYShapeInfo, @Cast("const Nd4jLong*") LongBuffer hYOffsets, Pointer dY, + @Cast("const Nd4jLong*") LongBuffer dYShapeInfo, @Cast("const Nd4jLong*") LongBuffer dYOffsets, Pointer hIindexes, + @Cast("const Nd4jLong*") LongBuffer hIndicesShapeInfo, Pointer dIindexes, + @Cast("const Nd4jLong*") LongBuffer dIndicesShapeInfo); +public native void scatterUpdate( + @Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, Pointer hX, + @Cast("const Nd4jLong*") long[] hXShapeInfo, @Cast("const Nd4jLong*") long[] hXOffsets, Pointer dX, + @Cast("const Nd4jLong*") long[] dXShapeInfo, @Cast("const Nd4jLong*") long[] dXOffsets, Pointer hY, + @Cast("const Nd4jLong*") long[] hYShapeInfo, @Cast("const Nd4jLong*") long[] hYOffsets, Pointer dY, + @Cast("const Nd4jLong*") long[] dYShapeInfo, @Cast("const Nd4jLong*") long[] dYOffsets, Pointer hIindexes, + @Cast("const Nd4jLong*") long[] hIndicesShapeInfo, Pointer dIindexes, + @Cast("const Nd4jLong*") long[] dIndicesShapeInfo); + +public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, + @Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, + @Cast("Nd4jLong*") LongPointer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); +public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, + @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, + @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); +public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, + @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, + @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); + +public native OpaqueConstantShapeBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, + @Cast("Nd4jLong*") LongPointer strides, + @Cast("sd::DataType") int dtype, char order, + @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); +public native OpaqueConstantShapeBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, + @Cast("Nd4jLong*") LongBuffer strides, + @Cast("sd::DataType") int dtype, char order, + @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); +public native OpaqueConstantShapeBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, + @Cast("Nd4jLong*") long[] strides, + @Cast("sd::DataType") int dtype, char order, + @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); + +public native OpaqueConstantDataBuffer constantBufferLong(@Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongPointer data, + int length); +public native OpaqueConstantDataBuffer constantBufferLong(@Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongBuffer data, + int length); +public native OpaqueConstantDataBuffer constantBufferLong(@Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") long[] data, + int length); +public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("sd::DataType") int dtype, + DoublePointer data, + int length); +public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("sd::DataType") int dtype, + DoubleBuffer data, + int length); +public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("sd::DataType") int dtype, + double[] data, + int length); +public native OpaqueConstantDataBuffer constantBuffer( + @Cast("sd::DataType") int dtype, ConstantDescriptor descriptor); public native @Cast("Nd4jPointer") Pointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer dbf); public native @Cast("Nd4jPointer") Pointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer dbf); public native @Cast("Nd4jLong") long getConstantDataBufferLength(OpaqueConstantDataBuffer dbf); +public native @Cast("Nd4jLong") long getConstantDataBufferSizeOf(OpaqueConstantDataBuffer dbf); +public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr); public native @Cast("Nd4jPointer") Pointer getConstantShapeBufferPrimary(OpaqueConstantShapeBuffer dbf); public native @Cast("Nd4jPointer") Pointer getConstantShapeBufferSpecial(OpaqueConstantShapeBuffer dbf); @@ -3252,28 +3489,58 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint public native void deleteConstantDataBuffer(OpaqueConstantDataBuffer ptr); public native OpaqueContext createGraphContext(int nodeId); -public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr); +public native OpaqueRandomGenerator getGraphContextRandomGenerator( + OpaqueContext ptr); public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow); -public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride); +public native void ctxShapeFunctionOverride(OpaqueContext ptr, + @Cast("bool") boolean reallyOverride); public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode); public native void ctxPurge(OpaqueContext ptr); public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); -public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); -public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); -public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); -public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); -public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, OpaqueDataBuffer buffer, Pointer shapeInfo, Pointer specialShapeInfo); -public native void setGraphContextDArguments(OpaqueContext ptr, IntPointer arguments, int numberOfArguments); -public native void setGraphContextDArguments(OpaqueContext ptr, IntBuffer arguments, int numberOfArguments); -public native void setGraphContextDArguments(OpaqueContext ptr, int[] arguments, int numberOfArguments); -public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); -public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments); -public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments); -public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongPointer arguments, int numberOfArguments); -public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongBuffer arguments, int numberOfArguments); -public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") long[] arguments, int numberOfArguments); -public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") BooleanPointer arguments, int numberOfArguments); -public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") boolean[] arguments, int numberOfArguments); +public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, + Pointer reductionPointer, + Pointer allocationPointer); +public native void setGraphContextInputArray(OpaqueContext ptr, int index, + Pointer buffer, Pointer shapeInfo, + Pointer specialBuffer, + Pointer specialShapeInfo); +public native void setGraphContextOutputArray(OpaqueContext ptr, int index, + Pointer buffer, Pointer shapeInfo, + Pointer specialBuffer, + Pointer specialShapeInfo); +public native void setGraphContextInputBuffer(OpaqueContext ptr, int index, + OpaqueDataBuffer buffer, + Pointer shapeInfo, + Pointer specialShapeInfo); +public native void setGraphContextOutputBuffer(OpaqueContext ptr, int index, + OpaqueDataBuffer buffer, + Pointer shapeInfo, + Pointer specialShapeInfo); +public native void setGraphContextDArguments(OpaqueContext ptr, IntPointer arguments, + int numberOfArguments); +public native void setGraphContextDArguments(OpaqueContext ptr, IntBuffer arguments, + int numberOfArguments); +public native void setGraphContextDArguments(OpaqueContext ptr, int[] arguments, + int numberOfArguments); +public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, + int numberOfArguments); +public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, + int numberOfArguments); +public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, + int numberOfArguments); +public native void setGraphContextIArguments(OpaqueContext ptr, + @Cast("Nd4jLong*") LongPointer arguments, + int numberOfArguments); +public native void setGraphContextIArguments(OpaqueContext ptr, + @Cast("Nd4jLong*") LongBuffer arguments, + int numberOfArguments); +public native void setGraphContextIArguments(OpaqueContext ptr, + @Cast("Nd4jLong*") long[] arguments, + int numberOfArguments); +public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") BooleanPointer arguments, + int numberOfArguments); +public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") boolean[] arguments, + int numberOfArguments); public native void deleteGraphContext(OpaqueContext ptr); public native OpaqueRandomGenerator createRandomGenerator(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/); @@ -3300,17 +3567,26 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint public native @Cast("Nd4jPointer") Pointer lcBlasHandle(OpaqueLaunchContext lc); public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc); -public native OpaqueDataBuffer allocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth); -public native OpaqueDataBuffer dbAllocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth); -public native OpaqueDataBuffer dbCreateExternalDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("Nd4jPointer") Pointer primary, @Cast("Nd4jPointer") Pointer special); -public native OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long length, @Cast("Nd4jLong") long offset); +public native OpaqueDataBuffer allocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, + @Cast("bool") boolean allocateBoth); +public native OpaqueDataBuffer dbAllocateDataBuffer(@Cast("Nd4jLong") long elements, + int dataType, + @Cast("bool") boolean allocateBoth); +public native OpaqueDataBuffer dbCreateExternalDataBuffer(@Cast("Nd4jLong") long elements, + int dataType, + @Cast("Nd4jPointer") Pointer primary, + @Cast("Nd4jPointer") Pointer special); +public native OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, + @Cast("Nd4jLong") long length, @Cast("Nd4jLong") long offset); public native @Cast("Nd4jPointer") Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer); public native @Cast("Nd4jPointer") Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer); public native void dbExpandBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long elements); public native void dbAllocatePrimaryBuffer(OpaqueDataBuffer dataBuffer); public native void dbAllocateSpecialBuffer(OpaqueDataBuffer dataBuffer); -public native void dbSetPrimaryBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jPointer") Pointer primaryBuffer, @Cast("Nd4jLong") long numBytes); -public native void dbSetSpecialBuffer(OpaqueDataBuffer dataBuffer, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong") long numBytes); +public native void dbSetPrimaryBuffer(OpaqueDataBuffer dataBuffer, + @Cast("Nd4jPointer") Pointer primaryBuffer, @Cast("Nd4jLong") long numBytes); +public native void dbSetSpecialBuffer(OpaqueDataBuffer dataBuffer, + @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong") long numBytes); public native void dbSyncToSpecial(OpaqueDataBuffer dataBuffer); public native void dbSyncToPrimary(OpaqueDataBuffer dataBuffer); public native int dbLocality(OpaqueDataBuffer dataBuffer); @@ -3324,14 +3600,13 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint public native void deleteDataBuffer(OpaqueDataBuffer dataBuffer); public native void dbExpand(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long elements); - public native int binaryLevel(); public native int optimalLevel(); public native @Cast("bool") boolean isMinimalRequirementsMet(); public native @Cast("bool") boolean isOptimalRequirementsMet(); -// #endif //NATIVEOPERATIONS_NATIVEOPS_H +// #endif // NATIVEOPERATIONS_NATIVEOPS_H // Parsed from memory/ExternalWorkspace.h @@ -3359,33 +3634,35 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_EXTERNALWORKSPACE_H // #define LIBND4J_EXTERNALWORKSPACE_H -// #include // #include - @Namespace("sd::memory") @NoOffset public static class ExternalWorkspace extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ExternalWorkspace(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ExternalWorkspace(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ExternalWorkspace position(long position) { - return (ExternalWorkspace)super.position(position); - } - - public ExternalWorkspace() { super((Pointer)null); allocate(); } - private native void allocate(); +// #include +@Namespace("sd::memory") @NoOffset public static class ExternalWorkspace extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ExternalWorkspace(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ExternalWorkspace(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ExternalWorkspace position(long position) { + return (ExternalWorkspace)super.position(position); + } - public ExternalWorkspace(@Cast("Nd4jPointer") Pointer ptrH, @Cast("Nd4jLong") long sizeH, @Cast("Nd4jPointer") Pointer ptrD, @Cast("Nd4jLong") long sizeD) { super((Pointer)null); allocate(ptrH, sizeH, ptrD, sizeD); } - private native void allocate(@Cast("Nd4jPointer") Pointer ptrH, @Cast("Nd4jLong") long sizeH, @Cast("Nd4jPointer") Pointer ptrD, @Cast("Nd4jLong") long sizeD); - - public native Pointer pointerHost(); - public native Pointer pointerDevice(); + public ExternalWorkspace() { super((Pointer)null); allocate(); } + private native void allocate(); - public native @Cast("Nd4jLong") long sizeHost(); - public native @Cast("Nd4jLong") long sizeDevice(); - } - + public ExternalWorkspace(@Cast("Nd4jPointer") Pointer ptrH, @Cast("Nd4jLong") long sizeH, @Cast("Nd4jPointer") Pointer ptrD, + @Cast("Nd4jLong") long sizeD) { super((Pointer)null); allocate(ptrH, sizeH, ptrD, sizeD); } + private native void allocate(@Cast("Nd4jPointer") Pointer ptrH, @Cast("Nd4jLong") long sizeH, @Cast("Nd4jPointer") Pointer ptrD, + @Cast("Nd4jLong") long sizeD); + public native Pointer pointerHost(); + public native Pointer pointerDevice(); + + public native @Cast("Nd4jLong") long sizeHost(); + public native @Cast("Nd4jLong") long sizeDevice(); +} + // namespace memory + // namespace sd // #endif @@ -3417,67 +3694,69 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_WORKSPACE_H // #define LIBND4J_WORKSPACE_H -// #include -// #include -// #include +// #include +// #include // #include // #include // #include -// #include -// #include - @Namespace("sd::memory") @NoOffset public static class Workspace extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Workspace(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Workspace(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Workspace position(long position) { - return (Workspace)super.position(position); - } - - public Workspace(ExternalWorkspace external) { super((Pointer)null); allocate(external); } - private native void allocate(ExternalWorkspace external); - public Workspace(@Cast("Nd4jLong") long initialSize/*=0L*/, @Cast("Nd4jLong") long secondaryBytes/*=0L*/) { super((Pointer)null); allocate(initialSize, secondaryBytes); } - private native void allocate(@Cast("Nd4jLong") long initialSize/*=0L*/, @Cast("Nd4jLong") long secondaryBytes/*=0L*/); - public Workspace() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native @Cast("Nd4jLong") long getAllocatedSize(); - public native @Cast("Nd4jLong") long getCurrentSize(); - public native @Cast("Nd4jLong") long getCurrentOffset(); - public native @Cast("Nd4jLong") long getSpilledSize(); - public native @Cast("Nd4jLong") long getUsedSize(); - - public native @Cast("Nd4jLong") long getAllocatedSecondarySize(); - public native @Cast("Nd4jLong") long getCurrentSecondarySize(); - public native @Cast("Nd4jLong") long getCurrentSecondaryOffset(); - public native @Cast("Nd4jLong") long getSpilledSecondarySize(); - public native @Cast("Nd4jLong") long getUsedSecondarySize(); - - public native void expandBy(@Cast("Nd4jLong") long primaryBytes, @Cast("Nd4jLong") long secondaryBytes/*=0L*/); - public native void expandBy(@Cast("Nd4jLong") long primaryBytes); - public native void expandTo(@Cast("Nd4jLong") long primaryBytes, @Cast("Nd4jLong") long secondaryBytes/*=0L*/); - public native void expandTo(@Cast("Nd4jLong") long primaryBytes); - -// bool resizeSupported(); - - public native Pointer allocateBytes(@Cast("Nd4jLong") long numBytes); - public native Pointer allocateBytes(@Cast("sd::memory::MemoryType") int type, @Cast("Nd4jLong") long numBytes); - - public native void scopeIn(); - public native void scopeOut(); - - /* - * This method creates NEW workspace of the same memory size and returns pointer to it - */ - public native Workspace clone(); - } - +// #include +// #include +// #include +@Namespace("sd::memory") @NoOffset public static class Workspace extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Workspace(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Workspace(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Workspace position(long position) { + return (Workspace)super.position(position); + } + + public Workspace(ExternalWorkspace external) { super((Pointer)null); allocate(external); } + private native void allocate(ExternalWorkspace external); + public Workspace(@Cast("Nd4jLong") long initialSize/*=0L*/, @Cast("Nd4jLong") long secondaryBytes/*=0L*/) { super((Pointer)null); allocate(initialSize, secondaryBytes); } + private native void allocate(@Cast("Nd4jLong") long initialSize/*=0L*/, @Cast("Nd4jLong") long secondaryBytes/*=0L*/); + public Workspace() { super((Pointer)null); allocate(); } + private native void allocate(); + + public native @Cast("Nd4jLong") long getAllocatedSize(); + public native @Cast("Nd4jLong") long getCurrentSize(); + public native @Cast("Nd4jLong") long getCurrentOffset(); + public native @Cast("Nd4jLong") long getSpilledSize(); + public native @Cast("Nd4jLong") long getUsedSize(); + + public native @Cast("Nd4jLong") long getAllocatedSecondarySize(); + public native @Cast("Nd4jLong") long getCurrentSecondarySize(); + public native @Cast("Nd4jLong") long getCurrentSecondaryOffset(); + public native @Cast("Nd4jLong") long getSpilledSecondarySize(); + public native @Cast("Nd4jLong") long getUsedSecondarySize(); + + public native void expandBy(@Cast("Nd4jLong") long primaryBytes, @Cast("Nd4jLong") long secondaryBytes/*=0L*/); + public native void expandBy(@Cast("Nd4jLong") long primaryBytes); + public native void expandTo(@Cast("Nd4jLong") long primaryBytes, @Cast("Nd4jLong") long secondaryBytes/*=0L*/); + public native void expandTo(@Cast("Nd4jLong") long primaryBytes); + + // bool resizeSupported(); + + public native Pointer allocateBytes(@Cast("Nd4jLong") long numBytes); + public native Pointer allocateBytes(@Cast("sd::memory::MemoryType") int type, @Cast("Nd4jLong") long numBytes); + + public native void scopeIn(); + public native void scopeOut(); + + /* + * This method creates NEW workspace of the same memory size and returns + * pointer to it + */ + public native Workspace clone(); +} + // namespace memory + // namespace sd -// #endif //LIBND4J_WORKSPACE_H +// #endif // LIBND4J_WORKSPACE_H // Parsed from indexing/NDIndex.h @@ -3505,79 +3784,77 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_NDINDEX_H // #define LIBND4J_NDINDEX_H +// #include // #include + // #include -// #include - @Namespace("sd") @NoOffset public static class NDIndex extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public NDIndex(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public NDIndex(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public NDIndex position(long position) { - return (NDIndex)super.position(position); - } - - public NDIndex() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native @Cast("bool") boolean isAll(); - public native @Cast("bool") boolean isPoint(); - public native @Cast("bool") boolean isInterval(); - - public native @Cast("Nd4jLong*") @StdVector LongPointer getIndices(); - public native @Cast("Nd4jLong") long stride(); - - public static native NDIndex all(); - public static native NDIndex point(@Cast("Nd4jLong") long pt); - public static native NDIndex interval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end, @Cast("Nd4jLong") long stride/*=1*/); - public static native NDIndex interval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end); - } - - @Namespace("sd") public static class NDIndexAll extends NDIndex { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public NDIndexAll(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public NDIndexAll(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public NDIndexAll position(long position) { - return (NDIndexAll)super.position(position); - } - - public NDIndexAll() { super((Pointer)null); allocate(); } - private native void allocate(); - public native @Cast("bool") boolean isInterval(); +@Namespace("sd") @NoOffset public static class NDIndex extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public NDIndex(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public NDIndex(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public NDIndex position(long position) { + return (NDIndex)super.position(position); } + public NDIndex() { super((Pointer)null); allocate(); } + private native void allocate(); + + public native @Cast("bool") boolean isAll(); + public native @Cast("bool") boolean isPoint(); + public native @Cast("bool") boolean isInterval(); - @Namespace("sd") public static class NDIndexPoint extends NDIndex { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public NDIndexPoint(Pointer p) { super(p); } - - public NDIndexPoint(@Cast("Nd4jLong") long point) { super((Pointer)null); allocate(point); } - private native void allocate(@Cast("Nd4jLong") long point); - public native @Cast("bool") boolean isInterval(); - } + public native @Cast("Nd4jLong*") @StdVector LongPointer getIndices(); + public native @Cast("Nd4jLong") long stride(); - @Namespace("sd") public static class NDIndexInterval extends NDIndex { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public NDIndexInterval(Pointer p) { super(p); } - - public NDIndexInterval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end, @Cast("Nd4jLong") long stride/*=1*/) { super((Pointer)null); allocate(start, end, stride); } - private native void allocate(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end, @Cast("Nd4jLong") long stride/*=1*/); - public NDIndexInterval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end) { super((Pointer)null); allocate(start, end); } - private native void allocate(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end); - public native @Cast("bool") boolean isInterval(); + public static native NDIndex all(); + public static native NDIndex point(@Cast("Nd4jLong") long pt); + public static native NDIndex interval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end, @Cast("Nd4jLong") long stride/*=1*/); + public static native NDIndex interval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end); +} + +@Namespace("sd") public static class NDIndexAll extends NDIndex { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public NDIndexAll(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public NDIndexAll(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public NDIndexAll position(long position) { + return (NDIndexAll)super.position(position); } + public NDIndexAll() { super((Pointer)null); allocate(); } + private native void allocate(); + public native @Cast("bool") boolean isInterval(); +} + +@Namespace("sd") public static class NDIndexPoint extends NDIndex { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public NDIndexPoint(Pointer p) { super(p); } + + public NDIndexPoint(@Cast("Nd4jLong") long point) { super((Pointer)null); allocate(point); } + private native void allocate(@Cast("Nd4jLong") long point); + public native @Cast("bool") boolean isInterval(); +} +@Namespace("sd") public static class NDIndexInterval extends NDIndex { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public NDIndexInterval(Pointer p) { super(p); } + public NDIndexInterval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end, @Cast("Nd4jLong") long stride/*=1*/) { super((Pointer)null); allocate(start, end, stride); } + private native void allocate(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end, @Cast("Nd4jLong") long stride/*=1*/); + public NDIndexInterval(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end) { super((Pointer)null); allocate(start, end); } + private native void allocate(@Cast("Nd4jLong") long start, @Cast("Nd4jLong") long end); + public native @Cast("bool") boolean isInterval(); +} + // namespace sd -// #endif //LIBND4J_NDINDEX_H +// #endif // LIBND4J_NDINDEX_H // Parsed from indexing/IndicesList.h @@ -3606,20 +3883,21 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_INDICESLIST_H // #include + // #include "NDIndex.h" - @Namespace("sd") @NoOffset public static class IndicesList extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public IndicesList(Pointer p) { super(p); } - +@Namespace("sd") @NoOffset public static class IndicesList extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public IndicesList(Pointer p) { super(p); } - public native int size(); - public native NDIndex at(int idx); - public native void push_back(NDIndex idx); - public native @Cast("bool") boolean isScalar(); - } -// #endif //LIBND4J_INDICESLIST_H + public native int size(); + public native NDIndex at(int idx); + public native void push_back(NDIndex idx); + public native @Cast("bool") boolean isScalar(); +} + // namespace sd +// #endif // LIBND4J_INDICESLIST_H // Parsed from graph/VariableType.h @@ -3646,15 +3924,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef ND4J_VARIABLE_TYPE_H // #define ND4J_VARIABLE_TYPE_H - /** enum sd::graph::VariableType */ - public static final int - NDARRAY = 0, - ARRAY_LIST = 1, - FLOW = 2, - CONSTANT = 3, - PLACEHOLDER = 4; - +/** enum sd::graph::VariableType */ +public static final int + NDARRAY = 0, + ARRAY_LIST = 1, + FLOW = 2, + CONSTANT = 3, + PLACEHOLDER = 4; + // namespace sd // #endif @@ -3683,44 +3961,45 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_INPUTLIST_H // #define LIBND4J_INPUTLIST_H +// #include // #include // #include -// #include -// #include // #include - @Namespace("sd::graph") @NoOffset public static class ArgumentsList extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ArgumentsList(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ArgumentsList(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ArgumentsList position(long position) { - return (ArgumentsList)super.position(position); - } - - public ArgumentsList() { super((Pointer)null); allocate(); } - private native void allocate(); - - /** - * This method returns number of argument pairs available - * - * @return - */ - public native int size(); - /** - * This method returns Pair at specified index - * - * @param index - * @return - */ - public native @ByRef Pair at(int index); +// #include +@Namespace("sd::graph") @NoOffset public static class ArgumentsList extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ArgumentsList(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ArgumentsList(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ArgumentsList position(long position) { + return (ArgumentsList)super.position(position); } + public ArgumentsList() { super((Pointer)null); allocate(); } + private native void allocate(); + + /** + * This method returns number of argument pairs available + * + * @return + */ + public native int size(); + /** + * This method returns Pair at specified index + * + * @param index + * @return + */ + public native @ByRef Pair at(int index); +} + // namespace graph + // namespace sd -// #endif //LIBND4J_INPUTLIST_H +// #endif // LIBND4J_INPUTLIST_H // Parsed from types/pair.h @@ -3749,29 +4028,28 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_PAIR_H // #include - @Namespace("sd") @NoOffset public static class Pair extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Pair(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Pair(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Pair position(long position) { - return (Pair)super.position(position); - } - - public Pair(int first/*=0*/, int second/*=0*/) { super((Pointer)null); allocate(first, second); } - private native void allocate(int first/*=0*/, int second/*=0*/); - public Pair() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native int first(); - public native int second(); +@Namespace("sd") @NoOffset public static class Pair extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Pair(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Pair(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Pair position(long position) { + return (Pair)super.position(position); } + public Pair(int first/*=0*/, int second/*=0*/) { super((Pointer)null); allocate(first, second); } + private native void allocate(int first/*=0*/, int second/*=0*/); + public Pair() { super((Pointer)null); allocate(); } + private native void allocate(); + public native int first(); + public native int second(); +} + // namespace sd -// #endif //LIBND4J_PAIR_H +// #endif // LIBND4J_PAIR_H // Parsed from array/NDArray.h @@ -3795,1110 +4073,1338 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef NDARRAY_H // #define NDARRAY_H -// #include -// #include -// #include -// #include -// #include "legacy/NativeOpExecutioner.h" -// #include -// #include -// #include -// #include -// #include -// #include // #include // #include +// #include +// #include +// #include +// #include +// #include // #include +// #include +// #include +// #include +// #include +// #include // #include -// #include -// #include +// #include +// #include +// #include +// #include // #include // #include -// #include -// #include -// #include -// #include -// #include -// #include +// #include +// #include +// #include +// #include + +// #include +// #include // #include // #include // #include // #include +// #include "legacy/NativeOpExecutioner.h" +@Namespace("sd") public static native @ByVal NDArray mmul(@Const @ByRef NDArray arg0, @Const @ByRef NDArray arg1); +@Namespace("sd") @NoOffset public static class NDArray extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public NDArray(Pointer p) { super(p); } - @Namespace("sd") public static native @ByVal NDArray mmul(@Const @ByRef NDArray arg0, @Const @ByRef NDArray arg1); - - @Namespace("sd") @NoOffset public static class NDArray extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public NDArray(Pointer p) { super(p); } - - public NDArray() { super((Pointer)null); allocate(); } - private native void allocate(); + public NDArray() { super((Pointer)null); allocate(); } + private native void allocate(); - /** - * do not allocate memory, memory for array is passed from outside - */ + /** + * do not allocate memory, memory for array is passed from outside + */ // #ifndef __JAVACPP_HACK__ // #endif - /** - * do not allocate memory, memory for array is passed from outside - */ - public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, shapeInfo, context, isBuffAlloc); } - private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/); - public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(buffer, shapeInfo); } - private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo); - public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, shapeInfo, context, isBuffAlloc); } - private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/); - public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(buffer, shapeInfo); } - private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo); - public NDArray(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, shapeInfo, context, isBuffAlloc); } - private native void allocate(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/); - public NDArray(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(buffer, shapeInfo); } - private native void allocate(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo); + /** + * do not allocate memory, memory for array is passed from outside + */ + public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, shapeInfo, context, isBuffAlloc); } + private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/); + public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(buffer, shapeInfo); } + private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongPointer shapeInfo); + public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, shapeInfo, context, isBuffAlloc); } + private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/); + public NDArray(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(buffer, shapeInfo); } + private native void allocate(Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo); + public NDArray(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, shapeInfo, context, isBuffAlloc); } + private native void allocate(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/); + public NDArray(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(buffer, shapeInfo); } + private native void allocate(Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo); - /** - * do not allocate memory, memory for array is passed from outside - * we suppose the content of both (device and host) buffers is identical - */ - public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo, context, isBuffAlloc, isBuffDAlloc); } - private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/); - public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo); } - private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo); - public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo, context, isBuffAlloc, isBuffDAlloc); } - private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/); - public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo); } - private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo); - public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo, context, isBuffAlloc, isBuffDAlloc); } - private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/); - public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo); } - private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo); + /** + * do not allocate memory, memory for array is passed from outside + * we suppose the content of both (device and host) buffers is identical + */ + public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo, context, isBuffAlloc, isBuffDAlloc); } + private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/); + public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo); } + private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongPointer shapeInfo); + public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo, context, isBuffAlloc, isBuffDAlloc); } + private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/); + public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo); } + private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") LongBuffer shapeInfo); + public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo, context, isBuffAlloc, isBuffDAlloc); } + private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isBuffAlloc/*=false*/, @Cast("bool") boolean isBuffDAlloc/*=false*/); + public NDArray(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(buffer, bufferD, shapeInfo); } + private native void allocate(Pointer buffer, Pointer bufferD, @Cast("const Nd4jLong*") long[] shapeInfo); - /** - * copy constructor - */ - public NDArray(@Const @ByRef NDArray other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef NDArray other); + /** + * copy constructor + */ + public NDArray(@Const @ByRef NDArray other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef NDArray other); - /** - * move constructor - */ + /** + * move constructor + */ - /** - * constructor, create array stored at given workspace - */ - public NDArray(LaunchContext context) { super((Pointer)null); allocate(context); } - private native void allocate(LaunchContext context); + /** + * constructor, create array stored at given workspace + */ + public NDArray(LaunchContext context) { super((Pointer)null); allocate(context); } + private native void allocate(LaunchContext context); + /** + * constructor creates new NDArray using shape information from "shapeInfo", + * set all elements in new array to zeros, if copyStrides is true then use + * stride values from "shapeInfo", else calculate strides independently + */ + public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/); + public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo); + public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/); + public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo); + public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/); + public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo); - /** - * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently - */ - public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/); - public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo); - public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/); - public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/); - public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo); - - /** - * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently - * set dtype as array type - */ - public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/); - public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype); - public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/); - public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype); - public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean nullify/*=true*/); - public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype); - - /** - * this constructor creates new array using shape information contained in vector argument - */ - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, dtype, context); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape) { super((Pointer)null); allocate(order, shape); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, dtype, context); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape) { super((Pointer)null); allocate(order, shape); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, dtype, context); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape) { super((Pointer)null); allocate(order, shape); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape); - - /** - * This constructor creates new array with elements copied from data and using shape information stored in shape, elements from data will be casted to dtype - */ - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @StdVector DoublePointer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, data, dtype, context); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @StdVector DoublePointer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @StdVector DoublePointer data) { super((Pointer)null); allocate(order, shape, data); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @StdVector DoublePointer data); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @StdVector DoubleBuffer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, data, dtype, context); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @StdVector DoubleBuffer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @StdVector DoubleBuffer data) { super((Pointer)null); allocate(order, shape, data); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @StdVector DoubleBuffer data); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape, @StdVector double[] data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, data, dtype, context); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape, @StdVector double[] data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); - public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape, @StdVector double[] data) { super((Pointer)null); allocate(order, shape, data); } - private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape, @StdVector double[] data); - - /** - * this constructor creates new array using given buffer (without memory allocation) and shape information stored in shape - */ - public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context, isBuffAlloc); } - private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/); - public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } - private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("sd::DataType") int dtype); - public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context, isBuffAlloc); } - private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/); - public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } - private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("sd::DataType") int dtype); - public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context, isBuffAlloc); } - private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/); - public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } - private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("sd::DataType") int dtype); - - /** - * This method returns new array with the same shape & data type - * @return - */ - public native @ByVal NDArray like(); + /** + * constructor creates new NDArray using shape information from "shapeInfo", + * set all elements in new array to be zeros, if copyStrides is true then use + * stride values from "shapeInfo", else calculate strides independently set + * dtype as array type + */ + public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype, + @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype, + @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/); + public NDArray(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("sd::DataType") int dtype); + public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype, + @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype, + @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/); + public NDArray(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("sd::DataType") int dtype); + public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype, + @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype, + @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean nullify/*=true*/); + public NDArray(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("sd::DataType") int dtype); - /** - * This method returns new uninitialized array with the same shape & data type - * @return - */ - public native @ByVal NDArray ulike(); + /** + * this constructor creates new array using shape information contained in + * vector argument + */ + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, dtype, context); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape) { super((Pointer)null); allocate(order, shape); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, dtype, context); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape) { super((Pointer)null); allocate(order, shape); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, dtype, context); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape) { super((Pointer)null); allocate(order, shape); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape); + /** + * This constructor creates new array with elements copied from data and using + * shape information stored in shape, elements from data will be casted to + * dtype + */ + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @StdVector DoublePointer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, data, dtype, context); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @StdVector DoublePointer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @StdVector DoublePointer data) { super((Pointer)null); allocate(order, shape, data); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @StdVector DoublePointer data); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @StdVector DoubleBuffer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, data, dtype, context); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @StdVector DoubleBuffer data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @StdVector DoubleBuffer data) { super((Pointer)null); allocate(order, shape, data); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @StdVector DoubleBuffer data); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @StdVector double[] data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(order, shape, data, dtype, context); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @StdVector double[] data, @Cast("sd::DataType") int dtype/*=sd::DOUBLE*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + public NDArray(char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @StdVector double[] data) { super((Pointer)null); allocate(order, shape, data); } + private native void allocate(char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @StdVector double[] data); - /** - * this constructor creates new NDArray with shape matching "other" array, - * doesn't copy "other" elements into new array !!! - */ - public NDArray(@Const NDArray other, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(other, copyStrides, context); } - private native void allocate(@Const NDArray other, @Cast("bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/); + /** + * this constructor creates new array using given buffer (without memory + * allocation) and shape information stored in shape + */ + public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("const bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context, isBuffAlloc); } + private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("const bool") boolean isBuffAlloc/*=false*/); + public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } + private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("sd::DataType") int dtype); + public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("const bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context, isBuffAlloc); } + private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("const bool") boolean isBuffAlloc/*=false*/); + public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } + private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("sd::DataType") int dtype); + public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("const bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context, isBuffAlloc); } + private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("const bool") boolean isBuffAlloc/*=false*/); + public NDArray(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } + private native void allocate(Pointer buffer, char order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("sd::DataType") int dtype); - /** - * this constructor creates scalar(and set its value = 0) or empty array depending on bool argument isScalar - */ - public NDArray(@Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isScalar/*=true*/) { super((Pointer)null); allocate(dtype, context, isScalar); } - private native void allocate(@Cast("sd::DataType") int dtype, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("bool") boolean isScalar/*=true*/); - public NDArray(@Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(dtype); } - private native void allocate(@Cast("sd::DataType") int dtype); + /** + * This method returns new array with the same shape & data type + * @return + */ + public native @ByVal NDArray like(); - /** - * This method blocks until asynchronous operation finishes - */ - public native void synchronize(@Cast("char*") String msg); - public native void synchronize(@Cast("char*") BytePointer msg); + /** + * This method returns new uninitialized array with the same shape & data type + * @return + */ + public native @ByVal NDArray ulike(); - /** - * This method allows to set _isAttached flag - * @param reallyAttached - */ - public native void setAttached(@Cast("bool") boolean reallyAttached); + /** + * this constructor creates new NDArray with shape matching "other" array, + * doesn't copy "other" elements into new array !!! + */ + public NDArray( + @Const NDArray other, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(other, copyStrides, context); } + private native void allocate( + @Const NDArray other, @Cast("bool") boolean copyStrides/*=false*/, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/); - public native void tickWriteHost(); - public native void tickWriteDevice(); - public native void tickReadHost(); - public native void tickReadDevice(); - public native void tickBothActual(); - public native @Cast("bool") boolean isActualOnHostSide(); - public native @Cast("bool") boolean isActualOnDeviceSide(); - public native void makeBothBuffersActual(); + /** + * this constructor creates scalar(and set its value = 0) or empty array + * depending on bool argument isScalar + */ + public NDArray(@Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isScalar/*=true*/) { super((Pointer)null); allocate(dtype, context, isScalar); } + private native void allocate(@Cast("sd::DataType") int dtype, + LaunchContext context/*=sd::LaunchContext::defaultContext()*/, + @Cast("bool") boolean isScalar/*=true*/); + public NDArray(@Cast("sd::DataType") int dtype) { super((Pointer)null); allocate(dtype); } + private native void allocate(@Cast("sd::DataType") int dtype); - public native void syncToHost(); - public native void syncToDevice(); - public native void syncShape(); + /** + * This method blocks until asynchronous operation finishes + */ + public native void synchronize(@Cast("char*") String msg); + public native void synchronize(@Cast("char*") BytePointer msg); - /** - * This method can be used on architectures that use special buffers - * @param writeList - * @param readList - */ - public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList); - public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList); - public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); - public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList); + /** + * This method allows to set _isAttached flag + * @param reallyAttached + */ + public native void setAttached(@Cast("bool") boolean reallyAttached); - public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList); - public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList); - public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); - public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList); + public native void tickWriteHost(); + public native void tickWriteDevice(); + public native void tickReadHost(); + public native void tickReadDevice(); + public native void tickBothActual(); + public native @Cast("bool") boolean isActualOnHostSide(); + public native @Cast("bool") boolean isActualOnDeviceSide(); + public native void makeBothBuffersActual(); - /** - * This method returns buffer pointer offset by given number of elements, wrt own data type - * @param offset - * @return - */ - public native Pointer bufferWithOffset(@Cast("Nd4jLong") long offset); - public native Pointer specialBufferWithOffset(@Cast("Nd4jLong") long offset); - /** - * copy assignment operator - * in particular, when _dataType != other._dataType and both shapes are the same, there will be allocation of new _buffer and _dataType acquires other._dataType - */ - public native @ByRef @Name("operator =") NDArray put(@Const @ByRef NDArray other); + public native void syncToHost(); + public native void syncToDevice(); + public native void syncShape(); - /** - * move assignment operator - */ + /** + * This method can be used on architectures that use special buffers + * @param writeList + * @param readList + */ + public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, + @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList); + public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, + @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList, + @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList); + + public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, + @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList); + public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, + @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList, + @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList); - /** - * assignment operator, assigns the same scalar to all array elements - */ + /** + * This method returns buffer pointer offset by given number of elements, wrt + * own data type + * @param offset + * @return + */ + public native Pointer bufferWithOffset(@Cast("Nd4jLong") long offset); + public native Pointer specialBufferWithOffset(@Cast("Nd4jLong") long offset); + /** + * copy assignment operator + * in particular, when _dataType != other._dataType and both shapes are the + * same, there will be allocation of new _buffer and _dataType acquires + * other._dataType + */ + public native @ByRef @Name("operator =") NDArray put(@Const @ByRef NDArray other); + /** + * move assignment operator + */ - /** - * operators for memory allocation and deletion - */ - public native @Name("operator new") Pointer _new(@Cast("size_t") long i); - public native @Name("operator delete") void _delete(Pointer p); + /** + * assignment operator, assigns the same scalar to all array elements + */ + /** + * operators for memory allocation and deletion + */ + public native @Name("operator new") Pointer _new(@Cast("size_t") long i); + public native @Name("operator delete") void _delete(Pointer p); - public native void setContext(LaunchContext context); + public native void setContext(LaunchContext context); - /** - * create a new array by replicating current array by repeats times along given dimension - * axis - axis along which to repeat elements - * repeats - number of repetitions - */ - public native @ByVal NDArray repeat(int axis, @StdVector IntPointer repeats); - public native @ByVal NDArray repeat(int axis, @StdVector IntBuffer repeats); - public native @ByVal NDArray repeat(int axis, @StdVector int[] repeats); + /** + * create a new array by replicating current array by repeats times along + * given dimension axis - axis along which to repeat elements repeats - number + * of repetitions + */ + public native @ByVal NDArray repeat(int axis, @StdVector IntPointer repeats); + public native @ByVal NDArray repeat(int axis, @StdVector IntBuffer repeats); + public native @ByVal NDArray repeat(int axis, @StdVector int[] repeats); - /** - * This method fills this array with zeros - */ - public native void nullify(); + /** + * This method fills this array with zeros + */ + public native void nullify(); - /** - * This method returns quantized copy of given array - * - * @param array - * @return - */ - public static native @ByVal NDArray quantize(@Const @ByRef NDArray array); + /** + * This method returns quantized copy of given array + * + * @param array + * @return + */ + public static native @ByVal NDArray quantize(@Const @ByRef NDArray array); - /** - * fill target array by repeating current array - * axis - axis along which to repeat elements - * repeats - vector containing numbers of repetition for elements at given axis - */ - public native void repeat(int axis, @StdVector IntPointer repeats, @ByRef NDArray target); - public native void repeat(int axis, @StdVector IntBuffer repeats, @ByRef NDArray target); - public native void repeat(int axis, @StdVector int[] repeats, @ByRef NDArray target); + /** + * fill target array by repeating current array + * axis - axis along which to repeat elements + * repeats - vector containing numbers of repetition for elements at given + * axis + */ + public native void repeat(int axis, @StdVector IntPointer repeats, + @ByRef NDArray target); + public native void repeat(int axis, @StdVector IntBuffer repeats, + @ByRef NDArray target); + public native void repeat(int axis, @StdVector int[] repeats, + @ByRef NDArray target); - /** - * creates array which points on certain sub-range of this array, sub-range is defined by given indices - */ - - - + /** + * creates array which points on certain sub-range of this array, sub-range + * is defined by given indices + */ + + + - /** - * cast array elements to given dtype - */ - public native @ByVal NDArray cast(@Cast("sd::DataType") int dtype); + /** + * cast array elements to given dtype + */ + public native @ByVal NDArray cast(@Cast("sd::DataType") int dtype); - public native void cast(@ByRef NDArray target, @Cast("sd::DataType") int dtype); + public native void cast(@ByRef NDArray target, @Cast("sd::DataType") int dtype); - /** - * returns _context - */ - public native LaunchContext getContext(); + /** + * returns _context + */ + public native LaunchContext getContext(); // #ifndef __JAVACPP_HACK__ // #endif - /** - * returns host buffer - */ - public native Pointer buffer(); - - - /** - * returns buffer offset (offset is the same for host and device buffers) - */ - public native @Cast("Nd4jLong") long bufferOffset(); - - /** - * if _bufferD==nullptr return _buffer, else return _bufferD - */ - public native Pointer specialBuffer(); + /** + * returns host buffer + */ + public native Pointer buffer(); - /** - * returns device buffer if compilation is for cuda case, otherwise returns host buffer - */ - public native Pointer platformBuffer(); + /** + * returns buffer offset (offset is the same for host and device buffers) + */ + public native @Cast("Nd4jLong") long bufferOffset(); - /** - * returns _shapeInfo - */ - public native @Cast("const Nd4jLong*") LongPointer shapeInfo(); + /** + * if _bufferD==nullptr return _buffer, else return _bufferD + */ + public native Pointer specialBuffer(); + /** + * returns device buffer if compilation is for cuda case, otherwise returns + * host buffer + */ + public native Pointer platformBuffer(); - /** - * Returns True if it's legally empty NDArray, or false otherwise - * @return - */ - public native @Cast("bool") boolean isEmpty(); + /** + * returns _shapeInfo + */ + public native @Cast("const Nd4jLong*") LongPointer shapeInfo(); - /** - * if _shapeInfoD==nullptr return _shapeInfo, else return _shapeInfoD - */ - public native @Cast("const Nd4jLong*") LongPointer specialShapeInfo(); + /** + * Returns True if it's legally empty NDArray, or false otherwise + * @return + */ + public native @Cast("bool") boolean isEmpty(); - public native @Cast("const Nd4jLong*") LongPointer platformShapeInfo(); + /** + * if _shapeInfoD==nullptr return _shapeInfo, else return _shapeInfoD + */ + public native @Cast("const Nd4jLong*") LongPointer specialShapeInfo(); - /** - * permutes (in-place) the dimensions in array according to "dimensions" array - */ - public native @Cast("bool") boolean permutei(@StdVector IntPointer dimensions); - public native @Cast("bool") boolean permutei(@StdVector IntBuffer dimensions); - public native @Cast("bool") boolean permutei(@StdVector int[] dimensions); - public native @Cast("bool") boolean permutei(@Const IntPointer dimensions, int rank); - public native @Cast("bool") boolean permutei(@Const IntBuffer dimensions, int rank); - public native @Cast("bool") boolean permutei(@Const int[] dimensions, int rank); - public native @Cast("bool") boolean permutei(@Cast("Nd4jLong*") @StdVector LongPointer dimensions); - public native @Cast("bool") boolean permutei(@Cast("Nd4jLong*") @StdVector LongBuffer dimensions); - public native @Cast("bool") boolean permutei(@Cast("Nd4jLong*") @StdVector long[] dimensions); - public native @Cast("bool") boolean permutei(@Cast("const Nd4jLong*") LongPointer dimensions, int rank); - public native @Cast("bool") boolean permutei(@Cast("const Nd4jLong*") LongBuffer dimensions, int rank); - public native @Cast("bool") boolean permutei(@Cast("const Nd4jLong*") long[] dimensions, int rank); - - public native @Cast("bool") boolean isFinite(); - public native @Cast("bool") boolean hasNaNs(); - public native @Cast("bool") boolean hasInfs(); - - public native void copyBuffersContinuouslyFrom(@Const @ByRef NDArray other, @Cast("size_t") long sizeToCopyInBytes/*=0*/, @Cast("Nd4jLong") long offsetThis/*=0*/, @Cast("Nd4jLong") long offsetOther/*=0*/); - public native void copyBuffersContinuouslyFrom(@Const @ByRef NDArray other); + public native @Cast("const Nd4jLong*") LongPointer platformShapeInfo(); - /** - * permutes the dimensions in array according to "dimensions" array, new array points on _buffer of this array - */ - public native @ByVal NDArray permute(@StdVector IntPointer dimensions); - public native @ByVal NDArray permute(@StdVector IntBuffer dimensions); - public native @ByVal NDArray permute(@StdVector int[] dimensions); - public native @ByVal NDArray permute(@Const IntPointer dimensions, int rank); - public native @ByVal NDArray permute(@Const IntBuffer dimensions, int rank); - public native @ByVal NDArray permute(@Const int[] dimensions, int rank); - - - + /** + * permutes (in-place) the dimensions in array according to "dimensions" + * array + */ + public native @Cast("bool") boolean permutei(@StdVector IntPointer dimensions); + public native @Cast("bool") boolean permutei(@StdVector IntBuffer dimensions); + public native @Cast("bool") boolean permutei(@StdVector int[] dimensions); + public native @Cast("bool") boolean permutei(@Const IntPointer dimensions, int rank); + public native @Cast("bool") boolean permutei(@Const IntBuffer dimensions, int rank); + public native @Cast("bool") boolean permutei(@Const int[] dimensions, int rank); + public native @Cast("bool") boolean permutei(@Cast("Nd4jLong*") @StdVector LongPointer dimensions); + public native @Cast("bool") boolean permutei(@Cast("Nd4jLong*") @StdVector LongBuffer dimensions); + public native @Cast("bool") boolean permutei(@Cast("Nd4jLong*") @StdVector long[] dimensions); + public native @Cast("bool") boolean permutei(@Cast("const Nd4jLong*") LongPointer dimensions, int rank); + public native @Cast("bool") boolean permutei(@Cast("const Nd4jLong*") LongBuffer dimensions, int rank); + public native @Cast("bool") boolean permutei(@Cast("const Nd4jLong*") long[] dimensions, int rank); + + public native @Cast("bool") boolean isFinite(); + public native @Cast("bool") boolean hasNaNs(); + public native @Cast("bool") boolean hasInfs(); + + public native void copyBuffersContinuouslyFrom(@Const @ByRef NDArray other, + @Cast("size_t") long sizeToCopyInBytes/*=0*/, + @Cast("Nd4jLong") long offsetThis/*=0*/, + @Cast("Nd4jLong") long offsetOther/*=0*/); + public native void copyBuffersContinuouslyFrom(@Const @ByRef NDArray other); - public native void permute(@Const IntPointer dimensions, int rank, @ByRef NDArray target); - public native void permute(@Const IntBuffer dimensions, int rank, @ByRef NDArray target); - public native void permute(@Const int[] dimensions, int rank, @ByRef NDArray target); - public native void permute(@StdVector IntPointer dimensions, @ByRef NDArray target); - public native void permute(@StdVector IntBuffer dimensions, @ByRef NDArray target); - public native void permute(@StdVector int[] dimensions, @ByRef NDArray target); - public native @ByVal NDArray permute(@Cast("Nd4jLong*") @StdVector LongPointer dimensions); - public native @ByVal NDArray permute(@Cast("Nd4jLong*") @StdVector LongBuffer dimensions); - public native @ByVal NDArray permute(@Cast("Nd4jLong*") @StdVector long[] dimensions); - public native @ByVal NDArray permute(@Cast("const Nd4jLong*") LongPointer dimensions, int rank); - public native @ByVal NDArray permute(@Cast("const Nd4jLong*") LongBuffer dimensions, int rank); - public native @ByVal NDArray permute(@Cast("const Nd4jLong*") long[] dimensions, int rank); - - - + /** + * permutes the dimensions in array according to "dimensions" array, new + * array points on _buffer of this array + */ + public native @ByVal NDArray permute(@StdVector IntPointer dimensions); + public native @ByVal NDArray permute(@StdVector IntBuffer dimensions); + public native @ByVal NDArray permute(@StdVector int[] dimensions); + public native @ByVal NDArray permute(@Const IntPointer dimensions, int rank); + public native @ByVal NDArray permute(@Const IntBuffer dimensions, int rank); + public native @ByVal NDArray permute(@Const int[] dimensions, int rank); + + + + + public native void permute(@Const IntPointer dimensions, int rank, @ByRef NDArray target); + public native void permute(@Const IntBuffer dimensions, int rank, @ByRef NDArray target); + public native void permute(@Const int[] dimensions, int rank, @ByRef NDArray target); + public native void permute(@StdVector IntPointer dimensions, @ByRef NDArray target); + public native void permute(@StdVector IntBuffer dimensions, @ByRef NDArray target); + public native void permute(@StdVector int[] dimensions, @ByRef NDArray target); + public native @ByVal NDArray permute(@Cast("Nd4jLong*") @StdVector LongPointer dimensions); + public native @ByVal NDArray permute(@Cast("Nd4jLong*") @StdVector LongBuffer dimensions); + public native @ByVal NDArray permute(@Cast("Nd4jLong*") @StdVector long[] dimensions); + public native @ByVal NDArray permute(@Cast("const Nd4jLong*") LongPointer dimensions, int rank); + public native @ByVal NDArray permute(@Cast("const Nd4jLong*") LongBuffer dimensions, int rank); + public native @ByVal NDArray permute(@Cast("const Nd4jLong*") long[] dimensions, int rank); + + + + + public native void permute(@Cast("const Nd4jLong*") LongPointer dimensions, int rank, + @ByRef NDArray target); + public native void permute(@Cast("const Nd4jLong*") LongBuffer dimensions, int rank, + @ByRef NDArray target); + public native void permute(@Cast("const Nd4jLong*") long[] dimensions, int rank, + @ByRef NDArray target); + public native void permute(@Cast("Nd4jLong*") @StdVector LongPointer dimensions, @ByRef NDArray target); + public native void permute(@Cast("Nd4jLong*") @StdVector LongBuffer dimensions, @ByRef NDArray target); + public native void permute(@Cast("Nd4jLong*") @StdVector long[] dimensions, @ByRef NDArray target); - public native void permute(@Cast("const Nd4jLong*") LongPointer dimensions, int rank, @ByRef NDArray target); - public native void permute(@Cast("const Nd4jLong*") LongBuffer dimensions, int rank, @ByRef NDArray target); - public native void permute(@Cast("const Nd4jLong*") long[] dimensions, int rank, @ByRef NDArray target); - public native void permute(@Cast("Nd4jLong*") @StdVector LongPointer dimensions, @ByRef NDArray target); - public native void permute(@Cast("Nd4jLong*") @StdVector LongBuffer dimensions, @ByRef NDArray target); - public native void permute(@Cast("Nd4jLong*") @StdVector long[] dimensions, @ByRef NDArray target); + /** + * This method streamlines given view or permuted array, and reallocates + * buffer + */ + public native void streamline(char order/*='a'*/); + public native void streamline(); - /** - * This method streamlines given view or permuted array, and reallocates buffer - */ - public native void streamline(char order/*='a'*/); - public native void streamline(); + /** + * prints information about array shape + * msg - message to print out + */ + public native void printShapeInfo(@Cast("char*") String msg/*=nullptr*/); + public native void printShapeInfo(); + public native void printShapeInfo(@Cast("char*") BytePointer msg/*=nullptr*/); - /** - * prints information about array shape - * msg - message to print out - */ - public native void printShapeInfo(@Cast("char*") String msg/*=nullptr*/); - public native void printShapeInfo(); - public native void printShapeInfo(@Cast("char*") BytePointer msg/*=nullptr*/); + /** + * prints buffer elements + * msg - message to print out + * limit - number of array elements to print out + * sync - if true check whether host buffer is actual, if it is not then make + * it so + */ + public native void printBuffer(@Cast("char*") String msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/, + @Cast("const bool") boolean sync/*=true*/); + public native void printBuffer(); + public native void printBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/, + @Cast("const bool") boolean sync/*=true*/); - /** - * prints buffer elements - * msg - message to print out - * limit - number of array elements to print out - * sync - if true check whether host buffer is actual, if it is not then make it so - */ - public native void printBuffer(@Cast("char*") String msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/, @Cast("const bool") boolean sync/*=true*/); - public native void printBuffer(); - public native void printBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/, @Cast("const bool") boolean sync/*=true*/); + /** + * make strings for current ndarray - linear or structured by dimensions + * */ + public native @StdString BytePointer linearString(@Cast("Nd4jLong") long _limit/*=-1*/); + public native @StdString BytePointer linearString(); + public native @StdString BytePointer indexedBufferString(@Cast("Nd4jLong") long _limit/*=-1*/); + public native @StdString BytePointer indexedBufferString(); - /** - * print element by element consequently in a way they (elements) are stored in physical memory - */ - public native void printLinearBuffer(); + /** + * print element by element consequently in a way they (elements) are stored + * in physical memory + */ + public native void printLinearBuffer(); - /** - * prints _buffer (if host = true) or _bufferD (if host = false) as it is, that is in current state without checking buffer status - */ + /** + * prints _buffer (if host = true) or _bufferD (if host = false) as it is, + * that is in current state without checking buffer status + */ - /** - * prints buffer elements, takes into account offset between elements (element-wise-stride) - * msg - message to print out - * limit - number of array elements to print out - */ - public native void printIndexedBuffer(@Cast("char*") String msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/); - public native void printIndexedBuffer(); - public native void printIndexedBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/); + /** + * prints buffer elements, takes into account offset between elements + * (element-wise-stride) msg - message to print out limit - number of array + * elements to print out + */ + public native void printIndexedBuffer(@Cast("char*") String msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/); + public native void printIndexedBuffer(); + public native void printIndexedBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long _limit/*=-1*/); - public native @StdString BytePointer asIndexedString(@Cast("Nd4jLong") long _limit/*=-1*/); - public native @StdString BytePointer asIndexedString(); - public native @StdString BytePointer asString(@Cast("Nd4jLong") long _limit/*=-1*/); - public native @StdString BytePointer asString(); + public native @StdString BytePointer asIndexedString(@Cast("Nd4jLong") long _limit/*=-1*/); + public native @StdString BytePointer asIndexedString(); + public native @StdString BytePointer asString(@Cast("Nd4jLong") long _limit/*=-1*/); + public native @StdString BytePointer asString(); - /** - * this method assigns values of given array to this one - */ - public native void assign(@Const NDArray other, @Cast("bool") boolean allowParallelism/*=true*/); - public native void assign(@Const NDArray other); + /** + * this method assigns values of given array to this one + */ + public native void assign(@Const NDArray other, @Cast("bool") boolean allowParallelism/*=true*/); + public native void assign(@Const NDArray other); - /** - * this method assigns values of given array to this one - */ + /** + * this method assigns values of given array to this one + */ - /** - * this method assigns given value to all elements in array - */ + /** + * this method assigns given value to all elements in array + */ - /** - * returns new copy of this array, optionally in different order - */ - public native @ByVal NDArray dup(byte newOrder/*='a'*/); - public native @ByVal NDArray dup(); + /** + * returns new copy of this array, optionally in different order + */ + public native @ByVal NDArray dup(byte newOrder/*='a'*/); + public native @ByVal NDArray dup(); - /** - * returns sum of all elements of array - */ - public native @ByVal NDArray sumNumber(); + /** + * returns sum of all elements of array + */ + public native @ByVal NDArray sumNumber(); - /** - * returns mean number of array - */ - public native @ByVal NDArray meanNumber(); + /** + * returns mean number of array + */ + public native @ByVal NDArray meanNumber(); // #ifndef __JAVACPP_HACK__ // #endif - /** - * apply transpose operation to the copy of this array, that is this array remains unaffected - */ - public native @ByVal NDArray transpose(); - - - /** - * perform transpose operation and store result in target, this array remains unaffected - * target - where to store result - */ - public native void transpose(@ByRef NDArray target); + /** + * apply transpose operation to the copy of this array, that is this array + * remains unaffected + */ + public native @ByVal NDArray transpose(); + - /** - * apply in-place transpose operation to this array, so this array becomes transposed - */ - public native void transposei(); + /** + * perform transpose operation and store result in target, this array remains + * unaffected target - where to store result + */ + public native void transpose(@ByRef NDArray target); - /** - * returns the number of arrays pointing on specified dimension(s) - * dimensions - array of dimensions to point on - */ - public native @Cast("Nd4jLong") long tensorsAlongDimension(@StdVector IntPointer dimensions); - public native @Cast("Nd4jLong") long tensorsAlongDimension(@StdVector IntBuffer dimensions); - public native @Cast("Nd4jLong") long tensorsAlongDimension(@StdVector int[] dimensions); + /** + * apply in-place transpose operation to this array, so this array becomes + * transposed + */ + public native void transposei(); - /** - * returns true if elements of two arrays are equal to within given epsilon value - * other - input array to compare - * eps - epsilon, this value defines the precision of elements comparison - */ - public native @Cast("bool") boolean equalsTo(@Const NDArray other, double eps/*=1e-5*/); - public native @Cast("bool") boolean equalsTo(@Const NDArray other); + /** + * returns the number of arrays pointing on specified dimension(s) + * dimensions - array of dimensions to point on + */ + public native @Cast("Nd4jLong") long tensorsAlongDimension(@StdVector IntPointer dimensions); + public native @Cast("Nd4jLong") long tensorsAlongDimension(@StdVector IntBuffer dimensions); + public native @Cast("Nd4jLong") long tensorsAlongDimension(@StdVector int[] dimensions); - /** - * add given row vector to all rows of this array - * row - row vector to add - */ - public native void addiRowVector(@Const @ByRef NDArray row); + /** + * returns true if elements of two arrays are equal to within given epsilon + * value other - input array to compare eps - epsilon, this value defines the + * precision of elements comparison + */ + public native @Cast("bool") boolean equalsTo(@Const NDArray other, double eps/*=1e-5*/); + public native @Cast("bool") boolean equalsTo(@Const NDArray other); - /** - * add given row vector to all rows of this array, store result in target - * row - row vector to add - * target - where to store result - */ - public native void addRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); + /** + * add given row vector to all rows of this array + * row - row vector to add + */ + public native void addiRowVector(@Const @ByRef NDArray row); - /** - * subtract given row vector from all rows of this array, store result in target - * row - row vector to subtract - * target - where to store result - */ - public native void subRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); + /** + * add given row vector to all rows of this array, store result in target + * row - row vector to add + * target - where to store result + */ + public native void addRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); - /** - * multiply all rows of this array on given row vector, store result in target - * row - row vector to multiply on - * target - where to store result - */ - public native void mulRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); + /** + * subtract given row vector from all rows of this array, store result in + * target row - row vector to subtract target - where to store result + */ + public native void subRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); - /** - * divide all rows of this array on given row vector, store result in target - * row - row vector to divide on - * target - where to store result - */ - public native void divRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); + /** + * multiply all rows of this array on given row vector, store result in + * target row - row vector to multiply on target - where to store result + */ + public native void mulRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); - /** - * add given column vector to all columns of this array, store result in target - * column - column vector to add - * target - where to store result - */ - public native void addColumnVector(@Const @ByRef NDArray column, @ByRef NDArray target); + /** + * divide all rows of this array on given row vector, store result in target + * row - row vector to divide on + * target - where to store result + */ + public native void divRowVector(@Const @ByRef NDArray row, @ByRef NDArray target); - /** - * add given column vector to all columns of this array, this array becomes affected (in-place operation) - * column - column vector to add - */ - public native void addiColumnVector(@Const @ByRef NDArray column); + /** + * add given column vector to all columns of this array, store result in + * target column - column vector to add target - where to store result + */ + public native void addColumnVector(@Const @ByRef NDArray column, @ByRef NDArray target); - /** - * multiply all columns of this array on given column vector, this array becomes affected (in-place operation) - * column - column vector to multiply on - */ - public native void muliColumnVector(@Const @ByRef NDArray column); + /** + * add given column vector to all columns of this array, this array becomes + * affected (in-place operation) column - column vector to add + */ + public native void addiColumnVector(@Const @ByRef NDArray column); - /** - * returns number of bytes used by _buffer & _shapeInfo - */ - public native @Cast("Nd4jLong") long memoryFootprint(); + /** + * multiply all columns of this array on given column vector, this array + * becomes affected (in-place operation) column - column vector to multiply on + */ + public native void muliColumnVector(@Const @ByRef NDArray column); - /** - * these methods suited for FlatBuffers use - */ - public native @Cast("Nd4jLong*") @StdVector LongPointer getShapeAsVector(); - public native @StdVector IntPointer getShapeAsVectorInt(); - public native @Cast("Nd4jLong*") @StdVector LongPointer getShapeInfoAsVector(); - public native @Cast("int64_t*") @StdVector LongPointer getShapeInfoAsFlatVector(); - public native @Cast("int64_t*") @StdVector LongPointer getShapeAsFlatVector(); + /** + * returns number of bytes used by _buffer & _shapeInfo + */ + public native @Cast("Nd4jLong") long memoryFootprint(); - /** - * set new order and shape in case of suitable array length (in-place operation) - * order - order to set - * shape - shape to set - * copyToNewBuff - if true then old buffer will be copied to new buffer if last one will be allocated after reshaping - * if there was permute applied before or there are weird strides, then new buffer is allocated for array - */ - public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); - public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); - public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector long[] shape); - public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongPointer shape); - public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongBuffer shape); - public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector long[] shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector long[] shape); + /** + * these methods suited for FlatBuffers use + */ + public native @Cast("Nd4jLong*") @StdVector LongPointer getShapeAsVector(); + public native @StdVector IntPointer getShapeAsVectorInt(); + public native @Cast("Nd4jLong*") @StdVector LongPointer getShapeInfoAsVector(); + public native @Cast("int64_t*") @StdVector LongPointer getShapeInfoAsFlatVector(); + public native @Cast("int64_t*") @StdVector LongPointer getShapeAsFlatVector(); - /** - * creates new array with corresponding order and shape, new array will point on _buffer of this array - * order - order to set - * shape - shape to set - * - * if permute have been applied before or there are weird strides, then new buffer is allocated for new array - */ - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("const bool") boolean copyToNewBuff/*=true*/); - public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector long[] shape); - + /** + * set new order and shape in case of suitable array length (in-place + * operation) order - order to set shape - shape to set copyToNewBuff - if + * true then old buffer will be copied to new buffer if last one will be + * allocated after reshaping if there was permute applied before or there are + * weird strides, then new buffer is allocated for array + */ + public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); + public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); + public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @Cast("bool") boolean reshapei(byte order, @Cast("Nd4jLong*") @StdVector long[] shape); + public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongPointer shape); + public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector LongBuffer shape); + public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @Cast("bool") boolean reshapei(@Cast("Nd4jLong*") @StdVector long[] shape); - /** - * calculate strides and set given order - * order - order to set - */ - public native void updateStrides(byte order); + /** + * creates new array with corresponding order and shape, new array will point + * on _buffer of this array order - order to set shape - shape to set + * + * if permute have been applied before or there are weird strides, then new + * buffer is allocated for new array + */ + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("const bool") boolean copyToNewBuff/*=true*/); + public native @ByVal NDArray reshape(byte order, @Cast("Nd4jLong*") @StdVector long[] shape); + - /** - * change an array by repeating it the number of times given by reps (in-place operation) - * repeats - contains numbers of repetitions - */ - public native void tilei(@Cast("Nd4jLong*") @StdVector LongPointer repeats); - public native void tilei(@Cast("Nd4jLong*") @StdVector LongBuffer repeats); - public native void tilei(@Cast("Nd4jLong*") @StdVector long[] repeats); + /** + * calculate strides and set given order + * order - order to set + */ + public native void updateStrides(byte order); - /** - * returns new array which is created by repeating of this array the number of times given by reps - * repeats - contains numbers of repetitions - */ - public native @ByVal NDArray tile(@Cast("Nd4jLong*") @StdVector LongPointer repeats); - public native @ByVal NDArray tile(@Cast("Nd4jLong*") @StdVector LongBuffer repeats); - public native @ByVal NDArray tile(@Cast("Nd4jLong*") @StdVector long[] repeats); + /** + * change an array by repeating it the number of times given by reps + * (in-place operation) repeats - contains numbers of repetitions + */ + public native void tilei(@Cast("Nd4jLong*") @StdVector LongPointer repeats); + public native void tilei(@Cast("Nd4jLong*") @StdVector LongBuffer repeats); + public native void tilei(@Cast("Nd4jLong*") @StdVector long[] repeats); - /** - * change an array by repeating it the number of times given by reps (in-place operation) - * repeats - contains numbers of repetitions - * target - where to store result - */ - public native void tile(@Cast("Nd4jLong*") @StdVector LongPointer repeats, @ByRef NDArray target); - public native void tile(@Cast("Nd4jLong*") @StdVector LongBuffer repeats, @ByRef NDArray target); - public native void tile(@Cast("Nd4jLong*") @StdVector long[] repeats, @ByRef NDArray target); + /** + * returns new array which is created by repeating of this array the number + * of times given by reps repeats - contains numbers of repetitions + */ + public native @ByVal NDArray tile(@Cast("Nd4jLong*") @StdVector LongPointer repeats); + public native @ByVal NDArray tile(@Cast("Nd4jLong*") @StdVector LongBuffer repeats); + public native @ByVal NDArray tile(@Cast("Nd4jLong*") @StdVector long[] repeats); - /** - * change an array by repeating it the number of times to acquire the new shape which is the same as target shape - * target - where to store result - */ - public native void tile(@ByRef NDArray target); + /** + * change an array by repeating it the number of times given by reps + * (in-place operation) repeats - contains numbers of repetitions target - + * where to store result + */ + public native void tile(@Cast("Nd4jLong*") @StdVector LongPointer repeats, @ByRef NDArray target); + public native void tile(@Cast("Nd4jLong*") @StdVector LongBuffer repeats, @ByRef NDArray target); + public native void tile(@Cast("Nd4jLong*") @StdVector long[] repeats, @ByRef NDArray target); - /** - * check whether array is identity matrix - */ - public native @Cast("bool") boolean isIdentityMatrix(); + /** + * change an array by repeating it the number of times to acquire the new + * shape which is the same as target shape target - where to store result + */ + public native void tile(@ByRef NDArray target); - /** - * check whether array is unitary matrix - */ - public native @Cast("bool") boolean isUnitary(); + /** + * check whether array is identity matrix + */ + public native @Cast("bool") boolean isIdentityMatrix(); - /** - * operator returns subarray with buffer pointing at this->_buffer with offset defined by given intervals - * idx - intervals of indexes which define the subarrays to point on, idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * this->rankOf()) - * when (dimStart == dimEnd) then whole range will be used for current dimension - * keepUnitiesInShape - if false then eliminate unities from resulting array shape, for example {1,a,1,b} -> {a,b} - * isStrided - if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd, - * so structure of idx is like {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} - */ - public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongPointer idx, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongPointer idx); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongBuffer idx, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongBuffer idx); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector long[] idx, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector long[] idx); + /** + * check whether array is unitary matrix + */ + public native @Cast("bool") boolean isUnitary(); - /** - * evaluates subarray with buffer pointing at this->_buffer and offset defined by given sequential index subArrIdx and dimensions in dimsToExclude - * subArrIdx - index of current sub-array - * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5], and subArrIdx must be in range [0,7] - * if dimsToExclude is empty then idxRanges containing all zeros (means whole array) will be returned. - * keepUnitiesInShape - if false then eliminate unities from resulting array shape, for example {1,a,1,b} -> {a,b} - */ - public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, @StdVector IntPointer dimsToExclude, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, @StdVector IntPointer dimsToExclude); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, @StdVector IntBuffer dimsToExclude, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, @StdVector IntBuffer dimsToExclude); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, @StdVector int[] dimsToExclude, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, @StdVector int[] dimsToExclude); + /** + * operator returns subarray with buffer pointing at this->_buffer with + * offset defined by given intervals idx - intervals of indexes which define + * the subarrays to point on, idx has form {dim0Start,dim0End, + * dim1Start,dim1End, ....} and length (2 * this->rankOf()) when (dimStart == + * dimEnd) then whole range will be used for current dimension + * keepUnitiesInShape - if false then eliminate unities from resulting array + * shape, for example {1,a,1,b} -> {a,b} isStrided - if true then idx has + * length (3 * this->rankOf()) and contains additional stride numbers which + * correspond to stride between dimStart and dimEnd, so structure of idx is + * like {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} + */ + public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongPointer idx, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/, + @Cast("const bool") boolean isStrided/*=false*/); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongPointer idx); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongBuffer idx, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/, + @Cast("const bool") boolean isStrided/*=false*/); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector LongBuffer idx); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector long[] idx, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/, + @Cast("const bool") boolean isStrided/*=false*/); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("Nd4jLong*") @StdVector long[] idx); - /** - * processes whole set of sub-arrays - * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) and their buffer offsets (each sub-array has its own unique offset from original this-buffer) - * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] - * if dimsToExclude.size() = array rank it means sub-array is whole array and copy of original_shapeInfo will be returned and one zero offset - * subArrShapeInfo - output argument, contains shapeInfo common for all sub-arrays - * subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer - * keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} - */ - public native void getSubArrShapeAndOffsets(@StdVector IntPointer dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - public native void getSubArrShapeAndOffsets(@StdVector IntPointer dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrOffsets); - public native void getSubArrShapeAndOffsets(@StdVector IntBuffer dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - public native void getSubArrShapeAndOffsets(@StdVector IntBuffer dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrOffsets); - public native void getSubArrShapeAndOffsets(@StdVector int[] dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - public native void getSubArrShapeAndOffsets(@StdVector int[] dimsToExclude, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef long[] subArrOffsets); + /** + * evaluates subarray with buffer pointing at this->_buffer and offset + * defined by given sequential index subArrIdx and dimensions in dimsToExclude + * subArrIdx - index of current sub-array + * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, + * i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 + * sub-arrays with shape [3,5], and subArrIdx must be in range [0,7] if + * dimsToExclude is empty then idxRanges containing all zeros (means whole + * array) will be returned. keepUnitiesInShape - if false then eliminate + * unities from resulting array shape, for example {1,a,1,b} -> {a,b} + */ + public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, + @StdVector IntPointer dimsToExclude, + @Cast("bool") boolean keepUnitiesInShape/*=false*/); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, + @StdVector IntPointer dimsToExclude); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, + @StdVector IntBuffer dimsToExclude, + @Cast("bool") boolean keepUnitiesInShape/*=false*/); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, + @StdVector IntBuffer dimsToExclude); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, + @StdVector int[] dimsToExclude, + @Cast("bool") boolean keepUnitiesInShape/*=false*/); + public native @ByVal @Name("operator ()") NDArray apply(@Cast("const Nd4jLong") long subArrIdx, + @StdVector int[] dimsToExclude); - /** - * addition unary operator array += other - * other - input array to add - */ - public native @Name("operator +=") void addPut(@Const @ByRef NDArray other); + /** + * processes whole set of sub-arrays + * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) + * and their buffer offsets (each sub-array has its own unique offset from + * original this-buffer) dimsToExclude - MUST BE SORTED, dimensions to + * evaluate sub-array along, i.e. when shape is [2,3,4,5] and + * dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] if + * dimsToExclude.size() = array rank it means sub-array is whole array and + * copy of original_shapeInfo will be returned and one zero offset + * subArrShapeInfo - output argument, contains shapeInfo common for all + * sub-arrays subArrOffsets - output argument, contains successive + * sub-arrays offsets from original this-buffer keepUnitiesInShape - if false + * then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> + * {a,b} + */ + public native void getSubArrShapeAndOffsets(@StdVector IntPointer dimsToExclude, + @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrOffsets, + @Cast("bool") boolean keepUnitiesInShape/*=false*/); + public native void getSubArrShapeAndOffsets(@StdVector IntPointer dimsToExclude, + @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef LongPointer subArrOffsets); + public native void getSubArrShapeAndOffsets(@StdVector IntBuffer dimsToExclude, + @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrOffsets, + @Cast("bool") boolean keepUnitiesInShape/*=false*/); + public native void getSubArrShapeAndOffsets(@StdVector IntBuffer dimsToExclude, + @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef LongBuffer subArrOffsets); + public native void getSubArrShapeAndOffsets(@StdVector int[] dimsToExclude, + @Cast("Nd4jLong*&") @ByPtrRef long[] subArrShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef long[] subArrOffsets, + @Cast("bool") boolean keepUnitiesInShape/*=false*/); + public native void getSubArrShapeAndOffsets(@StdVector int[] dimsToExclude, + @Cast("Nd4jLong*&") @ByPtrRef long[] subArrShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef long[] subArrOffsets); - /** - * subtraction unary operator array -= other - * other - input array to add - */ - public native @Name("operator -=") void subtractPut(@Const @ByRef NDArray other); + /** + * addition unary operator array += other + * other - input array to add + */ + public native @Name("operator +=") void addPut(@Const @ByRef NDArray other); - /** - * negative operator, it changes sign of all array elements on opposite - */ - public native @ByVal @Name("operator -") NDArray subtract(); - + /** + * subtraction unary operator array -= other + * other - input array to add + */ + public native @Name("operator -=") void subtractPut(@Const @ByRef NDArray other); - /** - * pairwise multiplication unary operator array *= other - * other - input array to multiply on - */ - public native @Name("operator *=") void multiplyPut(@Const @ByRef NDArray other); + /** + * negative operator, it changes sign of all array elements on opposite + */ + public native @ByVal @Name("operator -") NDArray subtract(); + - /** - * multiplication unary operator array *= scalar - * scalar - input scalar to multiply on - */ + /** + * pairwise multiplication unary operator array *= other + * other - input array to multiply on + */ + public native @Name("operator *=") void multiplyPut(@Const @ByRef NDArray other); - /** - * pairwise division unary operator: array /= other - * other - input array to divide on - */ - public native @Name("operator /=") void dividePut(@Const @ByRef NDArray other); + /** + * multiplication unary operator array *= scalar + * scalar - input scalar to multiply on + */ - /** - * division unary operator: array /= scalar - * scalar - input scalar to divide on - */ + /** + * pairwise division unary operator: array /= other + * other - input array to divide on + */ + public native @Name("operator /=") void dividePut(@Const @ByRef NDArray other); - /** - * friend function which implements mathematical multiplication of two arrays - * left - input array - * right - input array - */ - + /** + * division unary operator: array /= scalar + * scalar - input scalar to divide on + */ - /** - * return vector containing _buffer as flat binary array - */ - public native @StdVector BytePointer asByteVector(); + /** + * friend function which implements mathematical multiplication of two arrays + * left - input array + * right - input array + */ + - /** - * makes array to be identity matrix (not necessarily square), that is set all diagonal elements = 1, rest = 0 - */ - public native void setIdentity(); + /** + * return vector containing _buffer as flat binary array + */ + public native @StdVector BytePointer asByteVector(); - /** - * swaps the contents of tow arrays, - * PLEASE NOTE: method doesn't take into account the shapes of arrays, shapes may be different except one condition: arrays lengths must be the same - */ - public native void swapUnsafe(@ByRef NDArray other); + /** + * makes array to be identity matrix (not necessarily square), that is set + * all diagonal elements = 1, rest = 0 + */ + public native void setIdentity(); - /** - * return vector with buffer which points on corresponding diagonal elements of array - * type - means of vector to be returned: column ('c') or row ('r') - */ - public native @ByVal NDArray diagonal(byte type ); + /** + * swaps the contents of tow arrays, + * PLEASE NOTE: method doesn't take into account the shapes of arrays, shapes + * may be different except one condition: arrays lengths must be the same + */ + public native void swapUnsafe(@ByRef NDArray other); - /** - * fill target matrix with given value in one or two directions from main diagonal: - * - down from main diagonal starting at subdiagonal number "lower" if direction = 'l' (down) or 'b' (both) - * - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both) - * direction - in what direction to fill matrix. There are 3 possible directions: - * 'u' - fill up, mathematically this corresponds to lower triangular matrix, subdiagonal "lower" unaffected - * 'l' - fill down, mathematically this corresponds to upper triangular matrix, superdiagonal "upper" remains unaffected - * 'b' - fill in both directions, both "lower" and "upper" are taken into account - * rest of target elements are equal to this array elements - * target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2) - */ + /** + * return vector with buffer which points on corresponding diagonal elements + * of array type - means of vector to be returned: column ('c') or row ('r') + */ + public native @ByVal NDArray diagonal(byte type); - /** - * change an array by repeating it the number of times in order to acquire new shape equal to the input shape - * - * shape - contains new shape to broadcast array to - * target - optional argument, if target != nullptr the resulting array will be placed in target, in opposite case tile operation is done in place - */ - public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") LongPointer shapeInfo); - public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") long[] shapeInfo); - public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongPointer shape, @ByRef NDArray target); - public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongBuffer shape, @ByRef NDArray target); - public native void tileToShape(@Cast("Nd4jLong*") @StdVector long[] shape, @ByRef NDArray target); + /** + * fill target matrix with given value in one or two directions from main + * diagonal: + * - down from main diagonal starting at subdiagonal number "lower" if + * direction = 'l' (down) or 'b' (both) + * - up from main diagonal starting at superdiagonal number "upper"if + * direction = 'u' (up) or 'b' (both) direction - in what direction to fill + * matrix. There are 3 possible directions: 'u' - fill up, mathematically this + * corresponds to lower triangular matrix, subdiagonal "lower" unaffected 'l' + * - fill down, mathematically this corresponds to upper triangular matrix, + * superdiagonal "upper" remains unaffected 'b' - fill in both directions, + * both "lower" and "upper" are taken into account rest of target elements are + * equal to this array elements target and this array should have same shapes, + * except when this_rank = 1 (in that case should be target_rank = 2) + */ + + /** + * change an array by repeating it the number of times in order to acquire + * new shape equal to the input shape + * + * shape - contains new shape to broadcast array to + * target - optional argument, if target != nullptr the resulting array will + * be placed in target, in opposite case tile operation is done in place + */ + public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") LongPointer shapeInfo); + public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") LongBuffer shapeInfo); + public native @ByVal NDArray tileToShape(@Cast("const Nd4jLong*") long[] shapeInfo); + public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongPointer shape, @ByRef NDArray target); + public native void tileToShape(@Cast("Nd4jLong*") @StdVector LongBuffer shape, @ByRef NDArray target); + public native void tileToShape(@Cast("Nd4jLong*") @StdVector long[] shape, @ByRef NDArray target); // #ifndef __JAVACPP_HACK__ // #endif - public native @ByVal NDArray asT(@Cast("sd::DataType") int dtype); + public native @ByVal NDArray asT(@Cast("sd::DataType") int dtype); + public native void linspace(double start); - public native void linspace(double start); + public native void linspace(double start, double step); - public native void linspace(double start, double step); + /** + * calculates the trace of an array, that is sum of elements on main diagonal + * = sum array[i, i, i, ...] + */ + public native double getTrace(); - /** - * calculates the trace of an array, that is sum of elements on main diagonal = sum array[i, i, i, ...] - */ - public native double getTrace(); + public native @ByVal ResultSet multipleTensorsAlongDimension( + @StdVector IntPointer indices, + @StdVector IntPointer dimensions); + public native @ByVal ResultSet multipleTensorsAlongDimension( + @StdVector IntBuffer indices, + @StdVector IntBuffer dimensions); + public native @ByVal ResultSet multipleTensorsAlongDimension( + @StdVector int[] indices, + @StdVector int[] dimensions); - public native @ByVal ResultSet multipleTensorsAlongDimension(@StdVector IntPointer indices, @StdVector IntPointer dimensions); - public native @ByVal ResultSet multipleTensorsAlongDimension(@StdVector IntBuffer indices, @StdVector IntBuffer dimensions); - public native @ByVal ResultSet multipleTensorsAlongDimension(@StdVector int[] indices, @StdVector int[] dimensions); + public native @ByVal ResultSet allTensorsAlongDimension(@StdVector IntPointer dimensions); + public native @ByVal ResultSet allTensorsAlongDimension(@StdVector IntBuffer dimensions); + public native @ByVal ResultSet allTensorsAlongDimension(@StdVector int[] dimensions); - public native @ByVal ResultSet allTensorsAlongDimension(@StdVector IntPointer dimensions); - public native @ByVal ResultSet allTensorsAlongDimension(@StdVector IntBuffer dimensions); - public native @ByVal ResultSet allTensorsAlongDimension(@StdVector int[] dimensions); + public native @ByVal ResultSet allExamples(); - public native @ByVal ResultSet allExamples(); + /** + * set _shapeInfo + */ + public native void setShapeInfo(@Cast("const Nd4jLong*") LongPointer shapeInfo); + public native void setShapeInfo(@Cast("const Nd4jLong*") LongBuffer shapeInfo); + public native void setShapeInfo(@Cast("const Nd4jLong*") long[] shapeInfo); + public native void setShapeInfo(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype); + public native void setShapeInfo(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype); + public native void setShapeInfo(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype); + public native void setShapeInfo(@Const @ByRef ShapeDescriptor descriptor); + public native void setShapeInfo(@Const @ByRef ConstantShapeBuffer shapeBuffer); - /** - * set _shapeInfo - */ - public native void setShapeInfo(@Cast("const Nd4jLong*") LongPointer shapeInfo); - public native void setShapeInfo(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - public native void setShapeInfo(@Cast("const Nd4jLong*") long[] shapeInfo); - public native void setShapeInfo(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype); - public native void setShapeInfo(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype); - public native void setShapeInfo(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype); - public native void setShapeInfo(@Const @ByRef ShapeDescriptor descriptor); - public native void setShapeInfo(@Const @ByRef ConstantShapeBuffer shapeBuffer); + /** + * returns absolute offset which corresponds to given sequential index + */ + public native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong") long i); - /** - * returns absolute offset which corresponds to given sequential index - */ - public native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong") long i); + /** + * returns reference on array element with given index + */ - /** - * returns reference on array element with given index - */ + /** + * returns array element with given index + * i - element index in array + */ + /** + * default destructor + */ - /** - * returns array element with given index - * i - element index in array - */ + /** + * set _shapeInfo + */ + /** + * returns the value of "dim" dimension + */ + public native @Cast("Nd4jLong") long sizeAt(int dim); - /** - * default destructor - */ + /** + * returns stride of "dim" dimension + */ + public native @Cast("Nd4jLong") long strideAt(int dim); - /** - * set _shapeInfo - */ + /** + * returns order of array + */ + public native char ordering(); - /** - * returns the value of "dim" dimension - */ - public native @Cast("Nd4jLong") long sizeAt(int dim); + /** + * return _isView + */ + public native @Cast("bool") boolean isView(); - /** - * returns stride of "dim" dimension - */ - public native @Cast("Nd4jLong") long strideAt(int dim); + /** + * returns shape portion of shapeInfo + */ + public native @Cast("Nd4jLong*") LongPointer shapeOf(); - /** - * returns order of array - */ - public native char ordering(); + /** + * returns strides portion of shapeInfo + */ + public native @Cast("Nd4jLong*") LongPointer stridesOf(); - /** - * return _isView - */ - public native @Cast("bool") boolean isView(); + /** + * returns rank of array + */ + public native int rankOf(); - /** - * returns shape portion of shapeInfo - */ - public native @Cast("Nd4jLong*") LongPointer shapeOf(); + /** + * returns length of array + */ + public native @Cast("Nd4jLong") long lengthOf(); - /** - * returns strides portion of shapeInfo - */ - public native @Cast("Nd4jLong*") LongPointer stridesOf(); + /** + * returns number of rows in array + */ + public native @Cast("Nd4jLong") long rows(); - /** - * returns rank of array - */ - public native int rankOf(); + /** + * returns number of columns in array + */ + public native @Cast("Nd4jLong") long columns(); - /** - * returns length of array - */ - public native @Cast("Nd4jLong") long lengthOf(); + /** + * returns size of array elements type + */ + public native @Cast("size_t") long sizeOfT(); - /** - * returns number of rows in array - */ - public native @Cast("Nd4jLong") long rows(); + /** + * returns element-wise-stride + */ + public native @Cast("Nd4jLong") long ews(); - /** - * returns number of columns in array - */ - public native @Cast("Nd4jLong") long columns(); + // returns true if arrays have same shape + public native @Cast("bool") boolean isSameShape(@Const NDArray other); + public native @Cast("bool") boolean isSameShape(@Cast("Nd4jLong*") @StdVector LongPointer shape); + public native @Cast("bool") boolean isSameShape(@Cast("Nd4jLong*") @StdVector LongBuffer shape); + public native @Cast("bool") boolean isSameShape(@Cast("Nd4jLong*") @StdVector long[] shape); + public native @Cast("bool") boolean areSameShapeAndType(@Const @ByRef NDArray other); - /** - * returns size of array elements type - */ - public native @Cast("size_t") long sizeOfT(); + /** + * returns true if these two NDArrays have same rank, dimensions, strides, + * ews and order + */ + public native @Cast("bool") boolean isSameShapeStrict(@Const @ByRef NDArray other); - /** - * returns element-wise-stride - */ - public native @Cast("Nd4jLong") long ews(); + /** + * returns true if buffer && shapeInfo were defined (non nullptr) + */ + public native @Cast("bool") boolean nonNull(); - // returns true if arrays have same shape - public native @Cast("bool") boolean isSameShape(@Const NDArray other); - public native @Cast("bool") boolean isSameShape(@Cast("Nd4jLong*") @StdVector LongPointer shape); - public native @Cast("bool") boolean isSameShape(@Cast("Nd4jLong*") @StdVector LongBuffer shape); - public native @Cast("bool") boolean isSameShape(@Cast("Nd4jLong*") @StdVector long[] shape); - public native @Cast("bool") boolean areSameShapeAndType(@Const @ByRef NDArray other); + /** + * returns array element with given index from linear buffer + * i - element index in array + */ - /** - * returns true if these two NDArrays have same rank, dimensions, strides, ews and order - */ - public native @Cast("bool") boolean isSameShapeStrict(@Const @ByRef NDArray other); + /** + * returns element with given indexes from 2D array + * i - number of row + * j - number of column + */ - /** - * returns true if buffer && shapeInfo were defined (non nullptr) - */ - public native @Cast("bool") boolean nonNull(); + /** + * returns element with given indexes from 3D array + * i - height + * j - width + * k - depth + */ - /** - * returns array element with given index from linear buffer - * i - element index in array - */ + /** + * returns element with given indexes from DD array + */ - /** - * returns element with given indexes from 2D array - * i - number of row - * j - number of column - */ + /** + * returns array-scalar containing element of this array with given index + * i - element index in array + */ + public native @ByVal NDArray e(@Cast("const Nd4jLong") long i); - /** - * returns element with given indexes from 3D array - * i - height - * j - width - * k - depth - */ + /** + * assigns given scalar to array element by given index, regards array buffer + * as linear i - element index in array value - scalar value to assign + */ - /** - * returns element with given indexes from DD array - */ + public native void p(@Cast("const Nd4jLong") long i, @Const @ByRef NDArray value); - /** - * returns array-scalar containing element of this array with given index - * i - element index in array - */ - public native @ByVal NDArray e(@Cast("const Nd4jLong") long i); + /** + * assigns given scalar to 2D array element by given indexes + * i - number of row + * j - number of row + * value - scalar value to assign + */ - /** - * assigns given scalar to array element by given index, regards array buffer as linear - * i - element index in array - * value - scalar value to assign - */ + /** + * assigns given scalar to 3D array element by given indexes + * i - height + * j - width + * k - depth + * value - scalar value to assign + */ + public native void p(@Cast("const Nd4jLong") long i, @Cast("const Nd4jLong") long j, @Cast("const Nd4jLong") long k, @Cast("const Nd4jLong") long l, + @Const @ByRef NDArray value); - public native void p(@Cast("const Nd4jLong") long i, @Const @ByRef NDArray value); + /** + * returns true if array is 2D + */ + public native @Cast("bool") boolean isMatrix(); - /** - * assigns given scalar to 2D array element by given indexes - * i - number of row - * j - number of row - * value - scalar value to assign - */ + /** + * returns true if array is vector + */ + public native @Cast("bool") boolean isVector(); - /** - * assigns given scalar to 3D array element by given indexes - * i - height - * j - width - * k - depth - * value - scalar value to assign - */ - public native void p(@Cast("const Nd4jLong") long i, @Cast("const Nd4jLong") long j, @Cast("const Nd4jLong") long k, @Cast("const Nd4jLong") long l, @Const @ByRef NDArray value); + /** + * returns true if array is column vector + */ + public native @Cast("bool") boolean isColumnVector(); - /** - * returns true if array is 2D - */ - public native @Cast("bool") boolean isMatrix(); + /** + * returns true if array is row vector + */ + public native @Cast("bool") boolean isRowVector(); - /** - * returns true if array is vector - */ - public native @Cast("bool") boolean isVector(); + /** + * returns true if all dimensions of array except one are unities, for + * example: [1,1,n,1], [n,1,1], [n], ... posOfNonUnityDim - one dimension with + * value > 1 + */ + public native @Cast("bool") boolean isCommonVector(@ByRef IntPointer posOfNonUnityDim); + public native @Cast("bool") boolean isCommonVector(@ByRef IntBuffer posOfNonUnityDim); + public native @Cast("bool") boolean isCommonVector(@ByRef int[] posOfNonUnityDim); - /** - * returns true if array is column vector - */ - public native @Cast("bool") boolean isColumnVector(); + /** + * returns true if array is scalar + */ + public native @Cast("bool") boolean isScalar(); - /** - * returns true if array is row vector - */ - public native @Cast("bool") boolean isRowVector(); + /** + * Returns data type of this array + * @return + */ + public native @Cast("sd::DataType") int dataType(); - /** - * returns true if all dimensions of array except one are unities, for example: [1,1,n,1], [n,1,1], [n], ... - * posOfNonUnityDim - one dimension with value > 1 - */ - public native @Cast("bool") boolean isCommonVector(@ByRef IntPointer posOfNonUnityDim); - public native @Cast("bool") boolean isCommonVector(@ByRef IntBuffer posOfNonUnityDim); - public native @Cast("bool") boolean isCommonVector(@ByRef int[] posOfNonUnityDim); - - - /** - * returns true if array is scalar - */ - public native @Cast("bool") boolean isScalar(); - - /** - * Returns data type of this array - * @return - */ - public native @Cast("sd::DataType") int dataType(); - - /** - * This method returns true if value is from Integer space - * @return - */ - public native @Cast("bool") boolean isZ(); - - /** - * This method returns true if array is from Real space - * @return - */ - public native @Cast("bool") boolean isR(); + /** + * This method returns true if value is from Integer space + * @return + */ + public native @Cast("bool") boolean isZ(); - /** - * This method returns true if array is from Boolean space - * @return - */ - public native @Cast("bool") boolean isB(); + /** + * This method returns true if array is from Real space + * @return + */ + public native @Cast("bool") boolean isR(); - /** - * This method returns true if array contains Complex numbers - * @return - */ - public native @Cast("bool") boolean isC(); + /** + * This method returns true if array is from Boolean space + * @return + */ + public native @Cast("bool") boolean isB(); - /** - * This method returns true if array contains String - * @return - */ - public native @Cast("bool") boolean isS(); + /** + * This method returns true if array contains Complex numbers + * @return + */ + public native @Cast("bool") boolean isC(); - public native @Cast("bool") boolean isAttached(); + /** + * This method returns true if array contains String + * @return + */ + public native @Cast("bool") boolean isS(); - public native NDArray detach(); + public native @Cast("bool") boolean isAttached(); - public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef NDArray other); + public native @ByVal NDArray detach(); - public native @Cast("bool") @Name("operator !=") boolean notEquals(@Const @ByRef NDArray other); - } + /** + * This method returns true if array is valid array with some shape etc + * @return + */ + public native @Cast("bool") boolean defined(); + public native @Cast("bool") boolean undefined(); + public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef NDArray other); + public native @Cast("bool") @Name("operator !=") boolean notEquals(@Const @ByRef NDArray other); +} // class NDArray +@Namespace("sd") public static native @Cast("std::ostream*") @ByRef @Name("operator <<") Pointer shiftLeft(@Cast("std::ostream*") @ByRef Pointer os, @Const @ByRef NDArray m); ////////////////////////////////////////////////////////////////////////// ///// IMLEMENTATION OF INLINE METHODS ///// @@ -4965,7 +5471,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////// @@ -5043,13 +5548,12 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint //////////////////////////////////////////////////////////////////////// - -// #if defined(__CUDACC__) //&& defined(BUILD_TESTS) +// #if defined(__CUDACC__) //&& defined(BUILD_TESTS) // for CUDA we need stil stuff inline // #include // #endif - + // namespace sd // #endif @@ -5081,51 +5585,68 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef NDARRAY_LIST_H // #define NDARRAY_LIST_H -// #include -// #include -// #include // #include // #include // #include - @Namespace("sd") @NoOffset public static class NDArrayList extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public NDArrayList(Pointer p) { super(p); } - - public NDArrayList(int height, @Cast("bool") boolean expandable/*=false*/) { super((Pointer)null); allocate(height, expandable); } - private native void allocate(int height, @Cast("bool") boolean expandable/*=false*/); - public NDArrayList(int height) { super((Pointer)null); allocate(height); } - private native void allocate(int height); - public native @Cast("sd::DataType") int dataType(); +// #include +// #include +// #include +@Namespace("sd") @NoOffset public static class NDArrayList extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public NDArrayList(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public NDArrayList(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public NDArrayList position(long position) { + return (NDArrayList)super.position(position); + } + + public NDArrayList(int height/*=0*/, @Cast("bool") boolean expandable/*=false*/) { super((Pointer)null); allocate(height, expandable); } + private native void allocate(int height/*=0*/, @Cast("bool") boolean expandable/*=false*/); + public NDArrayList() { super((Pointer)null); allocate(); } + private native void allocate(); + + public NDArrayList(@Const @ByRef NDArrayList other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef NDArrayList other); - public native NDArray read(int idx); - public native NDArray readRaw(int idx); - public native @Cast("Nd4jStatus") int write(int idx, NDArray array); - public native NDArray pick(@StdVector IntPointer indices); - public native NDArray pick(@StdVector IntBuffer indices); - public native NDArray pick(@StdVector int[] indices); - public native @Cast("bool") boolean isWritten(int index); + public native @ByRef @Name("operator =") @NoException NDArrayList put(@Const @ByRef NDArrayList other); - public native @Cast("Nd4jLong*") @StdVector LongPointer shape(); + // move assignment operator - public native NDArray stack(); - public native void unstack(NDArray array, int axis); + public native @Cast("sd::DataType") int dataType(); - public native @ByRef IntIntPair id(); - public native @StdString @ByRef @Cast({"char*", "std::string*"}) BytePointer name(); - //sd::memory::Workspace* workspace(); - public native LaunchContext context(); - public native NDArrayList clone(); + public native @ByVal NDArray read(int idx); + public native @ByVal NDArray readRaw(int idx); + public native @Cast("Nd4jStatus") int write(int idx, @Const @ByRef NDArray array); - public native @Cast("bool") boolean equals(@ByRef NDArrayList other); + public native @ByVal NDArray pick(@StdVector IntPointer indices); + public native @ByVal NDArray pick(@StdVector IntBuffer indices); + public native @ByVal NDArray pick(@StdVector int[] indices); + public native @Cast("bool") boolean isWritten(int index); - public native int elements(); - public native int height(); + public native @Cast("Nd4jLong*") @StdVector LongPointer shape(); + public native void setShape(@Cast("Nd4jLong*") @StdVector LongPointer shape); + public native void setShape(@Cast("Nd4jLong*") @StdVector LongBuffer shape); + public native void setShape(@Cast("Nd4jLong*") @StdVector long[] shape); - public native int counter(); - } + public native @ByVal NDArray stack(); + public native void unstack(@Const @ByRef NDArray array, int axis); + + public native @Const @ByRef IntIntPair id(); + public native @StdString BytePointer name(); + + public native @ByVal NDArrayList clone(); + + public native @Cast("bool") boolean equals(@ByRef NDArrayList other); + + public native int elements(); + public native int height(); + public native int counter(); +} + // namespace sd // #endif @@ -5158,56 +5679,58 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_RESULTSET_H // #define LIBND4J_RESULTSET_H -// #include // #include +// #include // #include -// #include // forward declaration of template class NDArray - - @Namespace("sd") @NoOffset public static class ResultSet extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ResultSet(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ResultSet(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ResultSet position(long position) { - return (ResultSet)super.position(position); - } - - public ResultSet() { super((Pointer)null); allocate(); } - private native void allocate(); + +// #include // forward declaration of template class NDArray + +@Namespace("sd") @NoOffset public static class ResultSet extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ResultSet(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ResultSet(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ResultSet position(long position) { + return (ResultSet)super.position(position); + } + + public ResultSet() { super((Pointer)null); allocate(); } + private native void allocate(); // #ifndef __JAVACPP_HACK__ // #endif - public ResultSet(@Const @ByRef ResultSet other) { super((Pointer)null); allocate(other); } - @NoException private native void allocate(@Const @ByRef ResultSet other); + public ResultSet(@Const @ByRef ResultSet other) { super((Pointer)null); allocate(other); } + @NoException private native void allocate(@Const @ByRef ResultSet other); - public native @ByRef @Name("operator =") @NoException ResultSet put(@Const @ByRef ResultSet other); + public native @ByRef @Name("operator =") @NoException ResultSet put(@Const @ByRef ResultSet other); - // move constructor + // move constructor - // move assignment operator + // move assignment operator - public native int size(); - public native NDArray at(@Cast("const unsigned long") long idx); - public native @Name("operator []") NDArray get(@Cast("const unsigned long") long idx); - public native void push_back(NDArray array); - - public native @Cast("Nd4jStatus") int status(); - public native void setStatus(@Cast("Nd4jStatus") int status); - public native void purge(); - public native void setNonRemovable(); - } + public native int size(); + public native @ByRef NDArray at(@Cast("const unsigned long") long idx); + public native @ByRef @Name("operator []") NDArray get(@Cast("const unsigned long") long idx); + public native void push_back(@Const @ByRef NDArray array); + public native @Cast("Nd4jStatus") int status(); + public native void setStatus(@Cast("Nd4jStatus") int status); + public native void purge(); + public native void setNonRemovable(); +} + // namespace sd -// #endif //LIBND4J_RESULTSET_H +// #endif // LIBND4J_RESULTSET_H // Parsed from graph/RandomGenerator.h /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -5237,6 +5760,12 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include // #include // #include +// #include +// #include +// #include +// #include + +// #include // #include // #include @@ -5244,110 +5773,113 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #endif // #ifdef __CUDACC__ // #else - @Namespace("sd::graph") @NoOffset public static class RandomGenerator extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public RandomGenerator(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public RandomGenerator(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public RandomGenerator position(long position) { - return (RandomGenerator)super.position(position); - } - +@Namespace("sd::graph") @NoOffset public static class RandomGenerator extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public RandomGenerator(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public RandomGenerator(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public RandomGenerator position(long position) { + return (RandomGenerator)super.position(position); + } + public native @Cast("uint32_t") int xoroshiro32(@Cast("uint64_t") long index); - public native @Cast("uint64_t") long xoroshiro64(@Cast("uint64_t") long index); - public RandomGenerator(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/) { super((Pointer)null); allocate(rootSeed, nodeSeed); } - private native void allocate(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/); - public RandomGenerator() { super((Pointer)null); allocate(); } - private native void allocate(); + public native @Cast("uint64_t") long xoroshiro64(@Cast("uint64_t") long index); + public RandomGenerator(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/) { super((Pointer)null); allocate(rootSeed, nodeSeed); } + private native void allocate(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/); + public RandomGenerator() { super((Pointer)null); allocate(); } + private native void allocate(); + + public RandomGenerator(@Const @ByRef RandomGenerator other) { super((Pointer)null); allocate(other); } + @NoException private native void allocate(@Const @ByRef RandomGenerator other); + + public native @ByRef @Name("operator =") @NoException RandomGenerator put(@Const @ByRef RandomGenerator other); + + // move constructor + + // move assignment operator - /** - * This method allows to change graph-level state in runtime. - * PLEASE NOTE: this method will change state of node as well. - */ - public native void setStates(@Cast("Nd4jLong") long rootSeed, @Cast("Nd4jLong") long nodeState/*=0*/); - public native void setStates(@Cast("Nd4jLong") long rootSeed); + /** + * This method allows to change graph-level state in runtime. + * PLEASE NOTE: this method will change state of node as well. + */ + public native void setStates(@Cast("Nd4jLong") long rootSeed, @Cast("Nd4jLong") long nodeState/*=0*/); + public native void setStates(@Cast("Nd4jLong") long rootSeed); - + /** + * This method returns T value between from and to + */ - /** - * This method returns T value between from and to - */ + /** + * This method returns T value between 0 and MAX_T + */ - /** - * This method returns T value between 0 and MAX_T - */ + /** + * These two methods are made for JVM + * @param index + * @return + */ + public native int relativeInt(@Cast("Nd4jLong") long index); + public native @Cast("Nd4jLong") long relativeLong(@Cast("Nd4jLong") long index); - /** - * These two methods are made for JVM - * @param index - * @return - */ - public native int relativeInt(@Cast("Nd4jLong") long index); - public native @Cast("Nd4jLong") long relativeLong(@Cast("Nd4jLong") long index); + public native void rewindH(@Cast("uint64_t") long steps); - public native void rewindH(@Cast("uint64_t") long steps); + /** + * These methods set up only node states, with non-changed root ones + */ + public native void setSeed(int seed); - /** - * These methods set up only node states, with non-changed root ones - */ - public native void setSeed(int seed); + public native void setSeed(@Cast("uint64_t") long seed); - public native void setSeed(@Cast("uint64_t") long seed); + public native @Cast("Nd4jLong") long rootState(); - public native @Cast("Nd4jLong") long rootState(); + public native @Cast("Nd4jLong") long nodeState(); +} - public native @Cast("Nd4jLong") long nodeState(); - } - - - - - - - - - - - - - - - ////// - @Namespace("sd::graph") public static native @Cast("uint32_t") int rotl(@Cast("const uint32_t") int x, int k); - @Namespace("sd::graph") public static native @Cast("uint64_t") long rotl(@Cast("const uint64_t") long x, int k); - @Namespace("sd::graph") public static native @Cast("uint32_t") int next(@Cast("uint32_t") int s0, @Cast("uint32_t") int s1, @Cast("uint32_t") int s2, @Cast("uint32_t") int s3); - - + + + + + +////// +@Namespace("sd::graph") public static native @Cast("uint32_t") int rotl(@Cast("const uint32_t") int x, int k); + +@Namespace("sd::graph") public static native @Cast("uint64_t") long rotl(@Cast("const uint64_t") long x, int k); + +@Namespace("sd::graph") public static native @Cast("uint32_t") int next(@Cast("uint32_t") int s0, @Cast("uint32_t") int s1, @Cast("uint32_t") int s2, @Cast("uint32_t") int s3); - + + + // namespace graph + // namespace sd + // #endif @@ -5376,98 +5908,113 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_VARIABLE_H // #define LIBND4J_VARIABLE_H -// #include // #include // #include // #include // #include -// #include // #include +// #include + +// #include // #ifndef __JAVACPP_HACK__ // #endif - @Namespace("sd::graph") @NoOffset public static class Variable extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Variable(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Variable(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Variable position(long position) { - return (Variable)super.position(position); - } - - public Variable(@Cast("bool") boolean placeHolder) { super((Pointer)null); allocate(placeHolder); } - private native void allocate(@Cast("bool") boolean placeHolder); - public Variable(NDArray arrayw, @Cast("char*") String name, int id, int idx/*=0*/) { super((Pointer)null); allocate(arrayw, name, id, idx); } - private native void allocate(NDArray arrayw, @Cast("char*") String name, int id, int idx/*=0*/); - public Variable(NDArray arrayw, @Cast("char*") String name, int id) { super((Pointer)null); allocate(arrayw, name, id); } - private native void allocate(NDArray arrayw, @Cast("char*") String name, int id); - public Variable(NDArray arrayw, @Cast("char*") BytePointer name, int id, int idx/*=0*/) { super((Pointer)null); allocate(arrayw, name, id, idx); } - private native void allocate(NDArray arrayw, @Cast("char*") BytePointer name, int id, int idx/*=0*/); - public Variable(NDArray arrayw, @Cast("char*") BytePointer name, int id) { super((Pointer)null); allocate(arrayw, name, id); } - private native void allocate(NDArray arrayw, @Cast("char*") BytePointer name, int id); - public Variable(NDArray array/*=nullptr*/, @Cast("char*") String name/*=nullptr*/) { super((Pointer)null); allocate(array, name); } - private native void allocate(NDArray array/*=nullptr*/, @Cast("char*") String name/*=nullptr*/); - public Variable() { super((Pointer)null); allocate(); } - private native void allocate(); - public Variable(NDArray array/*=nullptr*/, @Cast("char*") BytePointer name/*=nullptr*/) { super((Pointer)null); allocate(array, name); } - private native void allocate(NDArray array/*=nullptr*/, @Cast("char*") BytePointer name/*=nullptr*/); +@Namespace("sd::graph") @NoOffset public static class Variable extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Variable(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Variable(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Variable position(long position) { + return (Variable)super.position(position); + } + + public Variable(@Cast("bool") boolean placeHolder, @Cast("sd::DataType") int dataType/*=sd::DataType::ANY*/, @Cast("Nd4jLong*") @StdVector LongPointer shape/*={}*/) { super((Pointer)null); allocate(placeHolder, dataType, shape); } + private native void allocate(@Cast("bool") boolean placeHolder, @Cast("sd::DataType") int dataType/*=sd::DataType::ANY*/, @Cast("Nd4jLong*") @StdVector LongPointer shape/*={}*/); + public Variable(@Cast("bool") boolean placeHolder) { super((Pointer)null); allocate(placeHolder); } + private native void allocate(@Cast("bool") boolean placeHolder); + public Variable(@Cast("bool") boolean placeHolder, @Cast("sd::DataType") int dataType/*=sd::DataType::ANY*/, @Cast("Nd4jLong*") @StdVector LongBuffer shape/*={}*/) { super((Pointer)null); allocate(placeHolder, dataType, shape); } + private native void allocate(@Cast("bool") boolean placeHolder, @Cast("sd::DataType") int dataType/*=sd::DataType::ANY*/, @Cast("Nd4jLong*") @StdVector LongBuffer shape/*={}*/); + public Variable(@Cast("bool") boolean placeHolder, @Cast("sd::DataType") int dataType/*=sd::DataType::ANY*/, @Cast("Nd4jLong*") @StdVector long[] shape/*={}*/) { super((Pointer)null); allocate(placeHolder, dataType, shape); } + private native void allocate(@Cast("bool") boolean placeHolder, @Cast("sd::DataType") int dataType/*=sd::DataType::ANY*/, @Cast("Nd4jLong*") @StdVector long[] shape/*={}*/); + public Variable(@Const @ByRef NDArray array, @StdString BytePointer name, int id, int idx/*=0*/) { super((Pointer)null); allocate(array, name, id, idx); } + private native void allocate(@Const @ByRef NDArray array, @StdString BytePointer name, int id, int idx/*=0*/); + public Variable(@Const @ByRef NDArray array, @StdString BytePointer name, int id) { super((Pointer)null); allocate(array, name, id); } + private native void allocate(@Const @ByRef NDArray array, @StdString BytePointer name, int id); + public Variable(@Const @ByRef NDArray array, @StdString String name, int id, int idx/*=0*/) { super((Pointer)null); allocate(array, name, id, idx); } + private native void allocate(@Const @ByRef NDArray array, @StdString String name, int id, int idx/*=0*/); + public Variable(@Const @ByRef NDArray array, @StdString String name, int id) { super((Pointer)null); allocate(array, name, id); } + private native void allocate(@Const @ByRef NDArray array, @StdString String name, int id); + public Variable(@SharedPtr NDArrayList array, @StdString BytePointer name, int id, int idx/*=0*/) { super((Pointer)null); allocate(array, name, id, idx); } + private native void allocate(@SharedPtr NDArrayList array, @StdString BytePointer name, int id, int idx/*=0*/); + public Variable(@SharedPtr NDArrayList array, @StdString BytePointer name, int id) { super((Pointer)null); allocate(array, name, id); } + private native void allocate(@SharedPtr NDArrayList array, @StdString BytePointer name, int id); + public Variable(@SharedPtr NDArrayList array, @StdString String name, int id, int idx/*=0*/) { super((Pointer)null); allocate(array, name, id, idx); } + private native void allocate(@SharedPtr NDArrayList array, @StdString String name, int id, int idx/*=0*/); + public Variable(@SharedPtr NDArrayList array, @StdString String name, int id) { super((Pointer)null); allocate(array, name, id); } + private native void allocate(@SharedPtr NDArrayList array, @StdString String name, int id); + public Variable(@SharedPtr NDArray array, @Cast("char*") String name/*=nullptr*/) { super((Pointer)null); allocate(array, name); } + private native void allocate(@SharedPtr NDArray array, @Cast("char*") String name/*=nullptr*/); + public Variable(@SharedPtr NDArray array) { super((Pointer)null); allocate(array); } + private native void allocate(@SharedPtr NDArray array); + public Variable(@SharedPtr NDArray array, @Cast("char*") BytePointer name/*=nullptr*/) { super((Pointer)null); allocate(array, name); } + private native void allocate(@SharedPtr NDArray array, @Cast("char*") BytePointer name/*=nullptr*/); + public Variable() { super((Pointer)null); allocate(); } + private native void allocate(); // #ifndef __JAVACPP_HACK__ // #endif - public native Variable clone(); + public native @Cast("bool") boolean hasNDArray(); + public native @SharedPtr NDArray getNDArray(); + public native void setNDArray(@SharedPtr NDArray array); - public native @Cast("bool") boolean hasNDArray(); - public native NDArray getNDArray(); - public native void setNDArray(NDArray array); + public native @Cast("bool") boolean hasNDArrayList(); + public native @SharedPtr NDArrayList getNDArrayList(); + public native void setNDArrayList(@SharedPtr NDArrayList list); - public native @Cast("bool") boolean hasNDArrayList(); - public native NDArrayList getNDArrayList(); - public native void setNDArrayList(NDArrayList list); + public native @Cast("bool") boolean isExternal(); + public native @Cast("bool") boolean isReadOnly(); + public native @Cast("bool") boolean isEmpty(); + public native @Cast("bool") boolean isRemovable(); - public native @Cast("bool") boolean isExternal(); - public native @Cast("bool") boolean isReadOnly(); - public native @Cast("bool") boolean isEmpty(); - public native @Cast("bool") boolean isRemovable(); + public native @Cast("bool") boolean isPlaceholder(); - public native @Cast("bool") boolean isPlaceholder(); + public native @Cast("sd::graph::VariableType") int variableType(); + public native void setVariableType(@Cast("sd::graph::VariableType") int variableType); - public native @Cast("sd::graph::VariableType") int variableType(); - public native void setVariableType(@Cast("sd::graph::VariableType") int variableType); + public native void markExternal(@Cast("bool") boolean reallyExternal); + public native void markReadOnly(@Cast("bool") boolean reallyReadOnly); + public native void markRemovable(@Cast("bool") boolean reallyRemovable); - /** - * This method returns InputType of this variable - */ - //InputType variableType() { - // return _variableType; - //} + public native int id(); + public native int index(); + public native void setIndex(int index); + public native void setId(int id); + public native void setId(int id, int idx); - public native void markExternal(@Cast("bool") boolean reallyExternal); - public native void markReadOnly(@Cast("bool") boolean reallyReadOnly); - public native void markRemovable(@Cast("bool") boolean reallyRemovable); + public native @StdString BytePointer name(); + public native @StdString BytePointer getName(); + public native void setName(@StdString BytePointer name); + public native void setName(@StdString String name); - public native int id(); - public native int index(); - public native void setIndex(int index); - public native void setId(int id); - public native void setId(int id, int idx); + public native @Cast("Nd4jLong*") @StdVector LongPointer shape(); + public native @Cast("sd::DataType") int dataType(); - public native @StdString @Cast({"char*", "std::string*"}) BytePointer getName(); - public native void setName(@StdString @Cast({"char*", "std::string*"}) BytePointer name); - - public native @Cast("Nd4jLong*") @StdVector LongPointer shape(); + public native @StdVector IntIntPair dependencies(); // #ifndef __JAVACPP_HACK__ // #endif - } - - +// #ifndef __JAVACPP_HACK__ +// #endif +} + // namespace graph + // namespace sd -// #endif //LIBND4J_VARIABLE_H +// #endif // LIBND4J_VARIABLE_H // Parsed from graph/VariablesSet.h @@ -5495,36 +6042,34 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_VARIABLESSET_H // #define LIBND4J_VARIABLESSET_H -// #include -// #include -// #include -// #include // #include - @Namespace("sd::graph") @NoOffset public static class VariablesSet extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public VariablesSet(Pointer p) { super(p); } - - public VariablesSet(@Cast("Nd4jStatus") int status/*=ND4J_STATUS_OK*/) { super((Pointer)null); allocate(status); } - private native void allocate(@Cast("Nd4jStatus") int status/*=ND4J_STATUS_OK*/); - public VariablesSet() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native @Cast("Nd4jStatus") int status(); - - public native int size(); +// #include +// #include - public native void push_back(Variable variable); +// #include +// #include +@Namespace("sd::graph") @NoOffset public static class VariablesSet extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public VariablesSet(Pointer p) { super(p); } - public native Variable at(int index); + public VariablesSet(@Cast("Nd4jStatus") int status/*=ND4J_STATUS_OK*/) { super((Pointer)null); allocate(status); } + private native void allocate(@Cast("Nd4jStatus") int status/*=ND4J_STATUS_OK*/); + public VariablesSet() { super((Pointer)null); allocate(); } + private native void allocate(); - } - + public native @Cast("Nd4jStatus") int status(); + public native int size(); + public native void push_back(Variable variable); + public native Variable at(int index); +} + // namespace graph + // namespace sd -// #endif //LIBND4J_VARIABLESSET_H +// #endif // LIBND4J_VARIABLESSET_H // Parsed from graph/FlowPath.h @@ -5552,68 +6097,68 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_FLOWPATH_H // #define LIBND4J_FLOWPATH_H -// #include -// #include -// #include -// #include -// #include // #include +// #include // #include // #include - @Namespace("sd::graph") @NoOffset public static class FlowPath extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public FlowPath(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public FlowPath(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public FlowPath position(long position) { - return (FlowPath)super.position(position); - } - - public FlowPath() { super((Pointer)null); allocate(); } - private native void allocate(); +// #include +// #include - public native void setInnerTime(int nodeId, @Cast("Nd4jLong") long time); - public native void setOuterTime(int nodeId, @Cast("Nd4jLong") long time); +// #include +// #include +@Namespace("sd::graph") @NoOffset public static class FlowPath extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public FlowPath(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public FlowPath(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public FlowPath position(long position) { + return (FlowPath)super.position(position); + } - public native @Cast("Nd4jLong") long innerTime(int nodeId); - public native @Cast("Nd4jLong") long outerTime(int nodeId); + public FlowPath() { super((Pointer)null); allocate(); } + private native void allocate(); - public native @Cast("bool") boolean isNodeActive(int nodeId); - public native void markNodeActive(int nodeId, @Cast("bool") boolean isActive); + public native void setInnerTime(int nodeId, @Cast("Nd4jLong") long time); + public native void setOuterTime(int nodeId, @Cast("Nd4jLong") long time); - public native @Cast("bool") boolean wasExecuted(int nodeId); - public native void markExecuted(int nodeId, @Cast("bool") boolean wasExecuted); + public native @Cast("Nd4jLong") long innerTime(int nodeId); + public native @Cast("Nd4jLong") long outerTime(int nodeId); - public native int branch(int nodeId); - public native void markBranch(int nodeId, int index); + public native @Cast("bool") boolean isNodeActive(int nodeId); + public native void markNodeActive(int nodeId, @Cast("bool") boolean isActive); - // Frame-related methods + public native @Cast("bool") boolean wasExecuted(int nodeId); + public native void markExecuted(int nodeId, @Cast("bool") boolean wasExecuted); - public native void registerFrame(@Cast("Nd4jLong") long frameId); - public native void forgetFrame(@Cast("Nd4jLong") long frameId); + public native int branch(int nodeId); + public native void markBranch(int nodeId, int index); - public native @Cast("bool") boolean isFrameActive(@Cast("Nd4jLong") long frameId); - public native void markFrameActive(@Cast("Nd4jLong") long frameId, @Cast("bool") boolean isActive); + // Frame-related methods - public native @Cast("bool") boolean isRewindPlanned(@Cast("Nd4jLong") long frameId); - public native void planRewind(@Cast("Nd4jLong") long frameId, @Cast("bool") boolean reallyRewind); + public native void registerFrame(@Cast("Nd4jLong") long frameId); + public native void forgetFrame(@Cast("Nd4jLong") long frameId); - public native int getRewindPosition(@Cast("Nd4jLong") long frameId); - public native void setRewindPosition(@Cast("Nd4jLong") long frameId, int _position); - public native void setRewindPositionOnce(@Cast("Nd4jLong") long frameId, int _position); + public native @Cast("bool") boolean isFrameActive(@Cast("Nd4jLong") long frameId); + public native void markFrameActive(@Cast("Nd4jLong") long frameId, @Cast("bool") boolean isActive); - public native void incrementNumberOfCycles(@Cast("Nd4jLong") long frameId); - public native @Cast("Nd4jLong") long getNumberOfCycles(@Cast("Nd4jLong") long frameId); + public native @Cast("bool") boolean isRewindPlanned(@Cast("Nd4jLong") long frameId); + public native void planRewind(@Cast("Nd4jLong") long frameId, @Cast("bool") boolean reallyRewind); - public native GraphProfile profile(); - } - + public native int getRewindPosition(@Cast("Nd4jLong") long frameId); + public native void setRewindPosition(@Cast("Nd4jLong") long frameId, int _position); + public native void setRewindPositionOnce(@Cast("Nd4jLong") long frameId, int _position); + public native void incrementNumberOfCycles(@Cast("Nd4jLong") long frameId); + public native @Cast("Nd4jLong") long getNumberOfCycles(@Cast("Nd4jLong") long frameId); + public native GraphProfile profile(); +} + // namespace graph + // namespace sd -// #endif //LIBND4J_FLOWPATH_H +// #endif // LIBND4J_FLOWPATH_H // Parsed from graph/Intervals.h @@ -5641,43 +6186,41 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_INTERVALS_H // #define LIBND4J_INTERVALS_H -// #include -// #include -// #include // #include +// #include - @Namespace("sd") @NoOffset public static class Intervals extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Intervals(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Intervals(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Intervals position(long position) { - return (Intervals)super.position(position); - } - +// #include +// #include - // default constructor - public Intervals() { super((Pointer)null); allocate(); } - private native void allocate(); - - // constructor - public Intervals(@Const @ByRef LongVectorVector content ) { super((Pointer)null); allocate(content); } - private native void allocate(@Const @ByRef LongVectorVector content ); - - // accessing operator - public native @Cast("Nd4jLong*") @StdVector @Name("operator []") LongPointer get(@Cast("const Nd4jLong") long i); +@Namespace("sd") @NoOffset public static class Intervals extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Intervals(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Intervals(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Intervals position(long position) { + return (Intervals)super.position(position); + } - // returns size of _content - public native int size(); + // default constructor + public Intervals() { super((Pointer)null); allocate(); } + private native void allocate(); - } + // constructor + public Intervals(@Const @ByRef LongVectorVector content) { super((Pointer)null); allocate(content); } + private native void allocate(@Const @ByRef LongVectorVector content); + // accessing operator + public native @Cast("Nd4jLong*") @StdVector @Name("operator []") LongPointer get(@Cast("const Nd4jLong") long i); + // returns size of _content + public native int size(); +} + // namespace sd -// #endif //LIBND4J_INTERVALS_H +// #endif // LIBND4J_INTERVALS_H // Parsed from graph/Stash.h @@ -5707,195 +6250,78 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint //#include // #include -// #include -// #include -// #include +// #include + // #include // #include -// #include - @Namespace("sd::graph") @NoOffset public static class KeyPair extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public KeyPair(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public KeyPair(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public KeyPair position(long position) { - return (KeyPair)super.position(position); - } - - public KeyPair(int node/*=0*/, @Cast("char*") String name/*=nullptr*/) { super((Pointer)null); allocate(node, name); } - private native void allocate(int node/*=0*/, @Cast("char*") String name/*=nullptr*/); - public KeyPair() { super((Pointer)null); allocate(); } - private native void allocate(); - public KeyPair(int node/*=0*/, @Cast("char*") BytePointer name/*=nullptr*/) { super((Pointer)null); allocate(node, name); } - private native void allocate(int node/*=0*/, @Cast("char*") BytePointer name/*=nullptr*/); +// #include +// #include +// #include +@Namespace("sd::graph") @NoOffset public static class KeyPair extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public KeyPair(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public KeyPair(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public KeyPair position(long position) { + return (KeyPair)super.position(position); + } - public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef KeyPair other); + public KeyPair(int node/*=0*/, @Cast("char*") String name/*=nullptr*/) { super((Pointer)null); allocate(node, name); } + private native void allocate(int node/*=0*/, @Cast("char*") String name/*=nullptr*/); + public KeyPair() { super((Pointer)null); allocate(); } + private native void allocate(); + public KeyPair(int node/*=0*/, @Cast("char*") BytePointer name/*=nullptr*/) { super((Pointer)null); allocate(node, name); } + private native void allocate(int node/*=0*/, @Cast("char*") BytePointer name/*=nullptr*/); - public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef KeyPair other); + public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef KeyPair other); - public native int key(); - public native @StdString BytePointer name(); - } - + public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef KeyPair other); + public native int key(); + public native @StdString BytePointer name(); +} + // namespace graph + // namespace sd // #ifndef __JAVACPP_HACK__ // #endif - @Namespace("sd::graph") @NoOffset public static class Stash extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Stash(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Stash(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Stash position(long position) { - return (Stash)super.position(position); - } - - public Stash() { super((Pointer)null); allocate(); } - private native void allocate(); - - //void storeArray(sd::graph::Block& block, const char *name, sd::NDArray *array); - public native void storeArray(int nodeId, @Cast("char*") String name, NDArray array); - public native void storeArray(int nodeId, @Cast("char*") BytePointer name, NDArray array); - - //bool checkStash(sd::graph::Block& block, const char *name); - public native @Cast("bool") boolean checkStash(int nodeId, @Cast("char*") String name); - public native @Cast("bool") boolean checkStash(int nodeId, @Cast("char*") BytePointer name); - - //sd::NDArray* extractArray(sd::graph::Block& block, const char *name); - public native NDArray extractArray(int nodeId, @Cast("char*") String name); - public native NDArray extractArray(int nodeId, @Cast("char*") BytePointer name); - - public native void clear(); - } - - - - - - - -// #endif //LIBND4J_STASH_H - - -// Parsed from graph/GraphState.h - -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 23.01.18. -// - -// #ifndef LIBND4J_GRAPHSTATE_H -// #define LIBND4J_GRAPHSTATE_H - -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include - - @Namespace("sd::graph") @NoOffset public static class GraphState extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public GraphState(Pointer p) { super(p); } - - public GraphState(@Cast("Nd4jLong") long id) { super((Pointer)null); allocate(id); } - private native void allocate(@Cast("Nd4jLong") long id); - - /** - * - * @return - */ - public native @Cast("Nd4jLong") long id(); - - /** - * This method adds scope to this state tracker - * - * @param scopeId - * @return - */ - public native @Cast("Nd4jStatus") int registerScope(int scopeId); - - /** - * This method cheks if scope with given ID exists - * - * @param scopeId - ID of the scope - * @return - TRUE if scope exists, FALSE otherwise - */ - public native @Cast("bool") boolean hasScope(int scopeId); - - /** - * This method removes specified scope from this state tracker - * - * @param scopeId - * @return - */ - public native @Cast("Nd4jStatus") int forgetScope(int scopeId); - -// #ifndef __JAVACPP_HACK__ -// #endif - /** - * This method adds given op to the end of specified scope - * - * @param scopeId - * @param opNum - * @param type - * @return - */ - public native @Cast("Nd4jStatus") int attachOpToScope(int scopeId, @Cast("Nd4jLong") long opNum, int type, @ByVal ArgumentsList inputs); +@Namespace("sd::graph") @NoOffset public static class Stash extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Stash(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Stash(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Stash position(long position) { + return (Stash)super.position(position); + } - /** - * This method adds return statement to specified scope - * - * PLEASE NOTE: should be used only in body scopes - * - * @param scopeId - * @param nodeId - * @param args - * @return - */ - public native @Cast("Nd4jStatus") int defineReturn(int scopeId, int nodeId, @ByVal ArgumentsList args); + public Stash() { super((Pointer)null); allocate(); } + private native void allocate(); - /** - * This method returns current variable space of this state holder - * - * @return - */ - public native VariableSpace variableSpace(); - } + // void storeArray(sd::graph::Block& block, const char *name, + // sd::NDArray *array); + public native void storeArray(int nodeId, @Cast("char*") String name, NDArray array); + public native void storeArray(int nodeId, @Cast("char*") BytePointer name, NDArray array); + // bool checkStash(sd::graph::Block& block, const char *name); + public native @Cast("bool") boolean checkStash(int nodeId, @Cast("char*") String name); + public native @Cast("bool") boolean checkStash(int nodeId, @Cast("char*") BytePointer name); + // sd::NDArray* extractArray(sd::graph::Block& block, const char *name); + public native NDArray extractArray(int nodeId, @Cast("char*") String name); + public native NDArray extractArray(int nodeId, @Cast("char*") BytePointer name); + public native void clear(); +} + // namespace graph + // namespace sd -// #endif //LIBND4J_GRAPHSTATE_H +// #endif // LIBND4J_STASH_H // Parsed from graph/VariableSpace.h @@ -5923,102 +6349,102 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_VARIABLESPACE_H // #define LIBND4J_VARIABLESPACE_H -// #include -// #include -// #include -// #include -// #include -// #include -// #include // #include // #include +// #include +// #include // #include +// #include +// #include // #include -// #include -// #include - @Namespace("sd::graph") @NoOffset public static class VariableSpace extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public VariableSpace(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public VariableSpace(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public VariableSpace position(long position) { - return (VariableSpace)super.position(position); - } - - public VariableSpace() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native @ByRef @Name("operator =") VariableSpace put(@Const @ByRef VariableSpace other); - - public native int numberOfPlaceholders(); - public native @Cast("sd::graph::Variable**") @StdVector PointerPointer getPlaceholders(); - public native void setWorkspace(Workspace workspace); - - public native LaunchContext launchContext(); - - public native @Cast("bool") boolean hasExternalVariable(int it); - public native @Cast("bool") boolean hasExternalVariable(@ByRef IntIntPair pair); - public native @Cast("bool") boolean hasExternalVariable(@StdString @Cast({"char*", "std::string*"}) BytePointer symbol); - - public native @Cast("bool") boolean hasVariable(int id); - public native @Cast("bool") boolean hasVariable(int id, int idx); - public native @Cast("bool") boolean hasVariable(@ByRef IntIntPair pair); - public native @Cast("bool") boolean hasVariable(@StdString @Cast({"char*", "std::string*"}) BytePointer symbol); - - public native Variable getVariable(int id); - public native Variable getVariable(int id, int idx); - public native Variable getVariable(@ByRef IntIntPair pair); - public native Variable getVariable(@StdString @Cast({"char*", "std::string*"}) BytePointer symbol); - - public native @Cast("sd::graph::Variable**") @StdVector PointerPointer getVariables(); - - public native Variable putVariable(@ByRef IntIntPair pair, NDArray array); - public native void putVariable(@ByRef IntIntPair pair, Variable variable); - public native void putVariable(int id, Variable variable); - public native void putVariable(int id, NDArray array); - public native Variable putVariable(int id, int idx, NDArray array); - public native void putVariable(int id, int idx, Variable array); - - public native void dropVariable(@ByRef IntIntPair pair); - public native void dropVariable(int id, int idx); - - public native void trackList(NDArrayList list); - - public native void putOutputVariable(Variable variable); - - public native void replaceVariable(Variable variable); - - // memory-related statistics - public native @Cast("Nd4jLong") long externalMemory(); - public native @Cast("Nd4jLong") long internalMemory(); - public native @Cast("Nd4jLong") long totalMemory(); - - public native int externalEntries(); - public native int internalEntries(); - public native int totalEntries(); - - public native VariableSpace clone(); - - public native @Cast("sd::graph::Variable**") @StdVector PointerPointer handles(); +// #include +// #include +// #include +// #include +// #include +@Namespace("sd::graph") @NoOffset public static class VariableSpace extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public VariableSpace(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public VariableSpace(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public VariableSpace position(long position) { + return (VariableSpace)super.position(position); + } - public native VariableSpace asT(); - public native void injectVariable(@ByRef IntIntPair pair, Variable variable); + public VariableSpace() { super((Pointer)null); allocate(); } + private native void allocate(); - public native Stash getStash(); + public VariableSpace(@Const @ByRef VariableSpace variableSpace) { super((Pointer)null); allocate(variableSpace); } + private native void allocate(@Const @ByRef VariableSpace variableSpace); - public native @Cast("sd::graph::Variable**") @StdVector PointerPointer getExternalVariables(); + public native @ByRef @Name("operator =") VariableSpace put(@Const @ByRef VariableSpace other); - public native void setFlowPath(FlowPath timers); - public native FlowPath flowPath(); - } - + public native int numberOfPlaceholders(); +// #ifndef __JAVACPP_HACK__ +// #endif + public native @Cast("bool") boolean hasExternalVariable(int it); + public native @Cast("bool") boolean hasExternalVariable(@Const @ByRef IntIntPair pair); + public native @Cast("bool") boolean hasExternalVariable(@StdString BytePointer symbol); + public native @Cast("bool") boolean hasExternalVariable(@StdString String symbol); + + public native @Cast("bool") boolean hasVariable(int id); + public native @Cast("bool") boolean hasVariable(int id, int idx); + public native @Cast("bool") boolean hasVariable(@Const @ByRef IntIntPair pair); + public native @Cast("bool") boolean hasVariable(@StdString BytePointer symbol); + public native @Cast("bool") boolean hasVariable(@StdString String symbol); + + public native @SharedPtr Variable getVariable(int id); + public native @SharedPtr Variable getVariable(int id, int idx); + public native @SharedPtr Variable getVariable( + @Const @ByRef IntIntPair pair); + public native @SharedPtr Variable getVariable( + @StdString BytePointer symbol); + public native @SharedPtr Variable getVariable( + @StdString String symbol); + + public native @SharedPtr Variable putVariable(@Const @ByRef IntIntPair pair, @Const @ByRef NDArray array); + public native @SharedPtr Variable putVariable(int id, @Const @ByRef NDArray array); + public native @SharedPtr Variable putVariable(int id, int idx, @SharedPtr NDArray array); + public native @SharedPtr Variable putVariable(@StdString BytePointer name, int id, int idx, @Const @ByRef NDArray array); + public native @SharedPtr Variable putVariable(@StdString String name, int id, int idx, @Const @ByRef NDArray array); + + public native void putVariable(@StdString BytePointer name, int id, int idx, @SharedPtr Variable variable); + public native void putVariable(@StdString String name, int id, int idx, @SharedPtr Variable variable); + public native void putVariable(@Const @ByRef IntIntPair pair, @SharedPtr Variable variable); + public native void putVariable(int id, @SharedPtr Variable variable); + + public native void dropVariable(@StdString BytePointer pair); + public native void dropVariable(@StdString String pair); + public native void dropVariable(@Const @ByRef IntIntPair pair); + public native void dropVariable(int id, int idx); + + public native void putOutputVariable(@SharedPtr Variable variable); + + public native void replaceVariable(@SharedPtr Variable variable); + + // memory-related statistics + public native @Cast("Nd4jLong") long externalMemory(); + public native @Cast("Nd4jLong") long internalMemory(); + public native @Cast("Nd4jLong") long totalMemory(); + + public native int externalEntries(); + public native int internalEntries(); + public native int totalEntries(); + + public native void injectVariable(@Const @ByRef IntIntPair pair, + @SharedPtr Variable variable); + + public native Stash stash(); +} + // namespace graph + // namespace sd -// #endif //LIBND4J_VARIABLESPACE_H +// #endif // LIBND4J_VARIABLESPACE_H // Parsed from helpers/helper_generator.h @@ -6046,10 +6472,10 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_HELPER_GENERATOR_H // #define LIBND4J_HELPER_GENERATOR_H -// #include -// #include // #include // #include +// #include +// #include // #ifdef _MSC_VER // include for uint64_t on MSVC @@ -6059,14 +6485,13 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef UINT64_C // #if defined(__LP64__) -// #define UINT64_C(c) c ## UL +// #define UINT64_C(c) c##UL // #else -// #define UINT64_C(c) c ## ULL -// #endif //LP64 -// #endif // UINT64 - -// #endif // MSVC/ANDROID +// #define UINT64_C(c) c##ULL +// #endif // LP64 +// #endif // UINT64 +// #endif // MSVC/ANDROID // #ifdef __GNUC__ // #include @@ -6074,200 +6499,194 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifdef __CUDACC__ // #else - @Namespace("sd::random") @NoOffset public static class RandomBuffer extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public RandomBuffer(Pointer p) { super(p); } - - /** - * This method allocates buffer of size * sizeof(Nd4jLong) - * - * @param size - * @return - */ +@Namespace("sd::random") @NoOffset public static class RandomBuffer extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public RandomBuffer(Pointer p) { super(p); } + + /** + * This method allocates buffer of size * sizeof(Nd4jLong) + * + * @param size + * @return + */ // #ifdef __CUDACC__ // #endif - public RandomBuffer(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongPointer buffer) { super((Pointer)null); allocate(seed, size, buffer); } - private native void allocate(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongPointer buffer); - public RandomBuffer(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongBuffer buffer) { super((Pointer)null); allocate(seed, size, buffer); } - private native void allocate(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongBuffer buffer); - public RandomBuffer(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") long[] buffer) { super((Pointer)null); allocate(seed, size, buffer); } - private native void allocate(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") long[] buffer); + public RandomBuffer(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongPointer buffer) { super((Pointer)null); allocate(seed, size, buffer); } + private native void allocate(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongPointer buffer); + public RandomBuffer(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongBuffer buffer) { super((Pointer)null); allocate(seed, size, buffer); } + private native void allocate(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") LongBuffer buffer); + public RandomBuffer(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") long[] buffer) { super((Pointer)null); allocate(seed, size, buffer); } + private native void allocate(@Cast("Nd4jLong") long seed, @Cast("Nd4jLong") long size, @Cast("uint64_t*") long[] buffer); - public native @Cast("uint64_t*") LongPointer getBuffer(); + public native @Cast("uint64_t*") LongPointer getBuffer(); - public native @Cast("uint64_t*") LongPointer getDeviceBuffer(); + public native @Cast("uint64_t*") LongPointer getDeviceBuffer(); // #ifdef __CUDACC__ // #endif - public native @Cast("Nd4jLong") long getSize(); - - public native @Cast("Nd4jLong") long getSeed(); + public native @Cast("Nd4jLong") long getSize(); - public native void setSeed(@Cast("Nd4jLong") long seed); + public native @Cast("Nd4jLong") long getSeed(); - public native @Cast("Nd4jLong") long getAllocatedSize(); + public native void setSeed(@Cast("Nd4jLong") long seed); - public native @Cast("Nd4jLong") long getOffset(); + public native @Cast("Nd4jLong") long getAllocatedSize(); - public native void setOffset(@Cast("Nd4jLong") long offset); + public native @Cast("Nd4jLong") long getOffset(); - public native void reSeed(@Cast("Nd4jLong") long amplifier); + public native void setOffset(@Cast("Nd4jLong") long offset); - public native @Cast("uint64_t") long getElement(@Cast("Nd4jLong") long _position); + public native void reSeed(@Cast("Nd4jLong") long amplifier); - public native @Cast("uint64_t") long next64(@Cast("uint64_t") long shiftedSeed); + public native @Cast("uint64_t") long getElement(@Cast("Nd4jLong") long _position); - public static native @Cast("uint64_t") long rotl(@Cast("const uint64_t") long x, @Cast("uint64_t") long k); + public native @Cast("uint64_t") long next64(@Cast("uint64_t") long shiftedSeed); - public static native @Cast("uint64_t") long safeShift(@Cast("uint64_t") long x, @Cast("uint64_t") long y); + public static native @Cast("uint64_t") long rotl(@Cast("const uint64_t") long x, @Cast("uint64_t") long k); - public native @Cast("uint64_t") long seedConv(@Cast("Nd4jLong") long seed); + public static native @Cast("uint64_t") long safeShift(@Cast("uint64_t") long x, @Cast("uint64_t") long y); - public native void incrementGeneration(); + public native @Cast("uint64_t") long seedConv(@Cast("Nd4jLong") long seed); - public native @Cast("Nd4jLong") long getNextIndex(); + public native void incrementGeneration(); - public native @Cast("uint64_t") long getNextElement(); + public native @Cast("Nd4jLong") long getNextIndex(); + public native @Cast("uint64_t") long getNextElement(); - /** - * This method skips X elements from buffer - * - * @param numberOfElements number of elements to skip - */ + /** + * This method skips X elements from buffer + * + * @param numberOfElements number of elements to skip + */ // #ifdef __CUDACC__ // #endif - public native void rewindH(@Cast("Nd4jLong") long numberOfElements); - - /** - * This method returns random int in range [0..MAX_INT] - * @return - */ - public native int nextInt(); - - public native @Cast("uint64_t") long nextUInt64(); - - /** - * This method returns random int in range [0..to] - * @param to - * @return - */ - public native int nextInt(int to); - - /** - * This method returns random int in range [from..to] - * @param from - * @param to - * @return - */ - public native int nextInt(int from, int to); - - - /** - * This method returns random T in range of [0..1] - * @return - */ - - /** - * This method returns random T in range of [0..to] - * @param to - * @return - */ - - /** - * This method returns random T in range [from..to] - * @param from - * @param to - * @return - */ - - public native @Cast("uint64_t") long relativeUInt64(@Cast("Nd4jLong") long index); - - /** - * relative methods are made as workaround for lock-free concurrent execution - */ - public native int relativeInt(@Cast("Nd4jLong") long index); - - /** - * This method returns random int within [0..to] - * - * @param index - * @param to - * @return - */ - public native int relativeInt(@Cast("Nd4jLong") long index, int to); - - /** - * This method returns random int within [from..to] - * - * @param index - * @param to - * @param from - * @return - */ - public native int relativeInt(@Cast("Nd4jLong") long index, int from, int to); - - /** - * This method returns random T within [0..1] - * - * @param index - * @return - */ - -/** - * This method returns random T within [0..to] - * - * @param index - * @param to - * @return - */ + public native void rewindH(@Cast("Nd4jLong") long numberOfElements); -/** - * This method returns random T within [from..to] - * - * @param index - * @param from - * @param to - * @return - */ + /** + * This method returns random int in range [0..MAX_INT] + * @return + */ + public native int nextInt(); - } + public native @Cast("uint64_t") long nextUInt64(); - @Namespace("sd::random") @NoOffset public static class IGenerator extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public IGenerator(Pointer p) { super(p); } - + /** + * This method returns random int in range [0..to] + * @param to + * @return + */ + public native int nextInt(int to); + /** + * This method returns random int in range [from..to] + * @param from + * @param to + * @return + */ + public native int nextInt(int from, int to); - public native RandomBuffer getBuffer(); + /** + * This method returns random T in range of [0..1] + * @return + */ - public native void setOffset(@Cast("Nd4jLong") long offset); + /** + * This method returns random T in range of [0..to] + * @param to + * @return + */ - public native @Cast("Nd4jLong") long getElementAbsolute(@Cast("Nd4jLong") long _position); + /** + * This method returns random T in range [from..to] + * @param from + * @param to + * @return + */ - public native @Cast("Nd4jLong") long getElementRelative(@Cast("Nd4jLong") long _position); + public native @Cast("uint64_t") long relativeUInt64(@Cast("Nd4jLong") long index); - public native void refreshBuffer(); - } + /** + * relative methods are made as workaround for lock-free concurrent execution + */ + public native int relativeInt(@Cast("Nd4jLong") long index); + /** + * This method returns random int within [0..to] + * + * @param index + * @param to + * @return + */ + public native int relativeInt(@Cast("Nd4jLong") long index, int to); + /** + * This method returns random int within [from..to] + * + * @param index + * @param to + * @param from + * @return + */ + public native int relativeInt(@Cast("Nd4jLong") long index, int from, int to); - @Namespace("sd::random") @NoOffset public static class Xoroshiro128 extends IGenerator { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Xoroshiro128(Pointer p) { super(p); } - - public Xoroshiro128(RandomBuffer buffer) { super((Pointer)null); allocate(buffer); } - private native void allocate(RandomBuffer buffer); + /** + * This method returns random T within [0..1] + * + * @param index + * @return + */ - public native void refreshBuffer(); - } - + /** + * This method returns random T within [0..to] + * + * @param index + * @param to + * @return + */ + + /** + * This method returns random T within [from..to] + * + * @param index + * @param from + * @param to + * @return + */ +} + +@Namespace("sd::random") @NoOffset public static class IGenerator extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public IGenerator(Pointer p) { super(p); } + + + public native RandomBuffer getBuffer(); + + public native void setOffset(@Cast("Nd4jLong") long offset); + + public native @Cast("Nd4jLong") long getElementAbsolute(@Cast("Nd4jLong") long _position); + + public native @Cast("Nd4jLong") long getElementRelative(@Cast("Nd4jLong") long _position); + + public native void refreshBuffer(); +} + +@Namespace("sd::random") @NoOffset public static class Xoroshiro128 extends IGenerator { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Xoroshiro128(Pointer p) { super(p); } -// #endif //LIBND4J_HELPER_GENERATOR_H + public Xoroshiro128(RandomBuffer buffer) { super((Pointer)null); allocate(buffer); } + private native void allocate(RandomBuffer buffer); + + public native void refreshBuffer(); +} + // namespace random + // namespace sd +// #endif // LIBND4J_HELPER_GENERATOR_H // Parsed from graph/profiling/GraphProfile.h @@ -6295,102 +6714,105 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef ND4J_GRAPH_PROFILE_H // #define ND4J_GRAPH_PROFILE_H -// #include "NodeProfile.h" -// #include // #include -// #include -// #include -// #include +// #include + // #include - @Namespace("sd::graph") @NoOffset public static class GraphProfile extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public GraphProfile(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public GraphProfile(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public GraphProfile position(long position) { - return (GraphProfile)super.position(position); - } - - public GraphProfile() { super((Pointer)null); allocate(); } - private native void allocate(); - - /** - * These methods just adding amount of bytes to various counters - */ - public native void addToTotal(@Cast("Nd4jLong") long bytes); - public native void addToActivations(@Cast("Nd4jLong") long bytes); - public native void addToTemporary(@Cast("Nd4jLong") long bytes); - public native void addToObjects(@Cast("Nd4jLong") long bytes); - - /** - * This method allows to set graph construction (i.e. deserialization) time in nanoseconds - */ - public native void setBuildTime(@Cast("Nd4jLong") long nanos); - - /** - * This method sets graph execution time in nanoseconds. - */ - public native void setExecutionTime(@Cast("Nd4jLong") long nanos); - - public native void startEvent(@Cast("char*") String name); - public native void startEvent(@Cast("char*") BytePointer name); - public native void recordEvent(@Cast("char*") String name); - public native void recordEvent(@Cast("char*") BytePointer name); - public native void deleteEvent(@Cast("char*") String name); - public native void deleteEvent(@Cast("char*") BytePointer name); - - /** - * This method saves time as delta from last saved time - */ - public native void spotEvent(@Cast("char*") String name); - public native void spotEvent(@Cast("char*") BytePointer name); - - /** - * This method returns pointer to NodeProfile by ID - * PLEASE NOTE: this method will create new NodeProfile if there's none - */ - public native NodeProfile nodeById(int id, @Cast("char*") String name/*=nullptr*/); - public native NodeProfile nodeById(int id); - public native NodeProfile nodeById(int id, @Cast("char*") BytePointer name/*=nullptr*/); - public native @Cast("bool") boolean nodeExists(int id); - - /** - * This method merges values from other profile report - * @param other - */ - public native void merge(GraphProfile other); - public native void assign(GraphProfile other); - - /** - * These methods are just utility methods for time - */ - public static native @Cast("Nd4jLong") long currentTime(); - public static native @Cast("Nd4jLong") long relativeTime(@Cast("Nd4jLong") long time); - - public native void printOut(); - } - +// #include +// #include +// #include +// #include "NodeProfile.h" +@Namespace("sd::graph") @NoOffset public static class GraphProfile extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public GraphProfile(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public GraphProfile(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public GraphProfile position(long position) { + return (GraphProfile)super.position(position); + } -// #endif + public GraphProfile() { super((Pointer)null); allocate(); } + private native void allocate(); -// Parsed from graph/profiling/NodeProfile.h + /** + * These methods just adding amount of bytes to various counters + */ + public native void addToTotal(@Cast("Nd4jLong") long bytes); + public native void addToActivations(@Cast("Nd4jLong") long bytes); + public native void addToTemporary(@Cast("Nd4jLong") long bytes); + public native void addToObjects(@Cast("Nd4jLong") long bytes); -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * + /** + * This method allows to set graph construction (i.e. deserialization) time in + * nanoseconds + */ + public native void setBuildTime(@Cast("Nd4jLong") long nanos); + + /** + * This method sets graph execution time in nanoseconds. + */ + public native void setExecutionTime(@Cast("Nd4jLong") long nanos); + + public native void startEvent(@Cast("char*") String name); + public native void startEvent(@Cast("char*") BytePointer name); + public native void recordEvent(@Cast("char*") String name); + public native void recordEvent(@Cast("char*") BytePointer name); + public native void deleteEvent(@Cast("char*") String name); + public native void deleteEvent(@Cast("char*") BytePointer name); + + /** + * This method saves time as delta from last saved time + */ + public native void spotEvent(@Cast("char*") String name); + public native void spotEvent(@Cast("char*") BytePointer name); + + /** + * This method returns pointer to NodeProfile by ID + * PLEASE NOTE: this method will create new NodeProfile if there's none + */ + public native NodeProfile nodeById(int id, @Cast("char*") String name/*=nullptr*/); + public native NodeProfile nodeById(int id); + public native NodeProfile nodeById(int id, @Cast("char*") BytePointer name/*=nullptr*/); + public native @Cast("bool") boolean nodeExists(int id); + + /** + * This method merges values from other profile report + * @param other + */ + public native void merge(GraphProfile other); + public native void assign(GraphProfile other); + + /** + * These methods are just utility methods for time + */ + public static native @Cast("Nd4jLong") long currentTime(); + public static native @Cast("Nd4jLong") long relativeTime(@Cast("Nd4jLong") long time); + + public native void printOut(); +} + // namespace graph + // namespace sd + +// #endif + +// Parsed from graph/profiling/NodeProfile.h + +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ @@ -6401,65 +6823,66 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_NODE_PROFILE_H // #define LIBND4J_NODE_PROFILE_H -// #include // #include +// #include + // #include // #include - @Namespace("sd::graph") @NoOffset public static class NodeProfile extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public NodeProfile(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public NodeProfile(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public NodeProfile position(long position) { - return (NodeProfile)super.position(position); - } - - public NodeProfile() { super((Pointer)null); allocate(); } - private native void allocate(); - - public NodeProfile(int id, @Cast("char*") String name) { super((Pointer)null); allocate(id, name); } - private native void allocate(int id, @Cast("char*") String name); - public NodeProfile(int id, @Cast("char*") BytePointer name) { super((Pointer)null); allocate(id, name); } - private native void allocate(int id, @Cast("char*") BytePointer name); - - public native void setBuildTime(@Cast("Nd4jLong") long time); - public native void setPreparationTime(@Cast("Nd4jLong") long time); - public native void setExecutionTime(@Cast("Nd4jLong") long time); - public native void setTotalTime(@Cast("Nd4jLong") long time); - public native void setShapeFunctionTime(@Cast("Nd4jLong") long time); - public native void setArrayTime(@Cast("Nd4jLong") long time); - public native void setInputTime(@Cast("Nd4jLong") long time); - - public native void setActivationsSize(@Cast("Nd4jLong") long bytes); - public native void setTemporarySize(@Cast("Nd4jLong") long bytes); - public native void setObjectsSize(@Cast("Nd4jLong") long bytes); - public native void setTotalSize(@Cast("Nd4jLong") long bytes); - - public native void addInputShape(@Cast("const Nd4jLong*") LongPointer shapeInfo); - public native void addInputShape(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - public native void addInputShape(@Cast("const Nd4jLong*") long[] shapeInfo); - public native void addOutputShape(@Cast("const Nd4jLong*") LongPointer shapeInfo); - public native void addOutputShape(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - public native void addOutputShape(@Cast("const Nd4jLong*") long[] shapeInfo); - - public native @Cast("Nd4jLong") long getActivationsSize(); - public native @Cast("Nd4jLong") long getTemporarySize(); - public native @Cast("Nd4jLong") long getObjectsSize(); - public native @Cast("Nd4jLong") long getTotalSize(); - - public native @Cast("Nd4jLong") long getExecutionTime(); - - public native @StdString @ByRef @Cast({"char*", "std::string*"}) BytePointer name(); - - public native void merge(NodeProfile other); - public native void assign(NodeProfile other); - - public native void printOut(); - } - +@Namespace("sd::graph") @NoOffset public static class NodeProfile extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public NodeProfile(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public NodeProfile(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public NodeProfile position(long position) { + return (NodeProfile)super.position(position); + } + + public NodeProfile() { super((Pointer)null); allocate(); } + private native void allocate(); + + public NodeProfile(int id, @Cast("char*") String name) { super((Pointer)null); allocate(id, name); } + private native void allocate(int id, @Cast("char*") String name); + public NodeProfile(int id, @Cast("char*") BytePointer name) { super((Pointer)null); allocate(id, name); } + private native void allocate(int id, @Cast("char*") BytePointer name); + + public native void setBuildTime(@Cast("Nd4jLong") long time); + public native void setPreparationTime(@Cast("Nd4jLong") long time); + public native void setExecutionTime(@Cast("Nd4jLong") long time); + public native void setTotalTime(@Cast("Nd4jLong") long time); + public native void setShapeFunctionTime(@Cast("Nd4jLong") long time); + public native void setArrayTime(@Cast("Nd4jLong") long time); + public native void setInputTime(@Cast("Nd4jLong") long time); + + public native void setActivationsSize(@Cast("Nd4jLong") long bytes); + public native void setTemporarySize(@Cast("Nd4jLong") long bytes); + public native void setObjectsSize(@Cast("Nd4jLong") long bytes); + public native void setTotalSize(@Cast("Nd4jLong") long bytes); + + public native void addInputShape(@Cast("const Nd4jLong*") LongPointer shapeInfo); + public native void addInputShape(@Cast("const Nd4jLong*") LongBuffer shapeInfo); + public native void addInputShape(@Cast("const Nd4jLong*") long[] shapeInfo); + public native void addOutputShape(@Cast("const Nd4jLong*") LongPointer shapeInfo); + public native void addOutputShape(@Cast("const Nd4jLong*") LongBuffer shapeInfo); + public native void addOutputShape(@Cast("const Nd4jLong*") long[] shapeInfo); + + public native @Cast("Nd4jLong") long getActivationsSize(); + public native @Cast("Nd4jLong") long getTemporarySize(); + public native @Cast("Nd4jLong") long getObjectsSize(); + public native @Cast("Nd4jLong") long getTotalSize(); + + public native @Cast("Nd4jLong") long getExecutionTime(); + + public native @StdString @ByRef @Cast({"char*", "std::string*"}) BytePointer name(); + + public native void merge(NodeProfile other); + public native void assign(NodeProfile other); + public native void printOut(); +} + // namespace graph + // namespace sd // #endif @@ -6489,214 +6912,209 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_CONTEXT_H // #define LIBND4J_CONTEXT_H -// #include // #include +// #include +// #include // #include // #include -// #include +// #include // #include -// #include + +// #include // CUDA-specific includes // #ifdef __CUDACC__ // #endif - /** - * This class defines input desired for any given node/operation within graph - */ - @Namespace("sd::graph") @NoOffset public static class Context extends ContextPrototype { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Context(Pointer p) { super(p); } - - public Context(ContextPrototype prototype, VariableSpace variableSpace) { super((Pointer)null); allocate(prototype, variableSpace); } - private native void allocate(ContextPrototype prototype, VariableSpace variableSpace); - - public Context(int nodeId, VariableSpace variableSpace/*=nullptr*/) { super((Pointer)null); allocate(nodeId, variableSpace); } - private native void allocate(int nodeId, VariableSpace variableSpace/*=nullptr*/); - public Context(int nodeId) { super((Pointer)null); allocate(nodeId); } - private native void allocate(int nodeId); - public Context(int nodeId, VariableSpace variableSpace, @Cast("bool") boolean isInplace) { super((Pointer)null); allocate(nodeId, variableSpace, isInplace); } - private native void allocate(int nodeId, VariableSpace variableSpace, @Cast("bool") boolean isInplace); - - // default destructor - - // these methods are for execution timing - public native void setOuterTime(@Cast("Nd4jLong") long time); - public native void setInnerTime(@Cast("Nd4jLong") long time); - public native @Cast("Nd4jLong") long getOuterTime(); - public native @Cast("Nd4jLong") long getInnerTime(); - - public native @Cast("sd::DataType") int dataType(); - - public native @Cast("sd::DataType") int dataType(int index); - public native void setDataType(int index, @Cast("sd::DataType") int type); - // these methods are related to Workspace abstraction - public native @Cast("bool") boolean hasWorkspaceProvided(); - public native void attachWorkspace(Workspace workspace); - public native void forgetWorkspace(); - - // these methods return full-time workspace - public native Workspace getWorkspace(); - public native Workspace workspace(); - public native Workspace fWorkspace(); - - // this method returns workspace for temporary allocations - public native Workspace tWorkspace(); - - // this method returns workspace for object allocations - public native Workspace oWorkspace(); - - public native void setVariableSpace(VariableSpace variableSpace); - - public native RandomBuffer getRNG(); - public native void setRNG(RandomBuffer rng); +/** + * This class defines input desired for any given node/operation within graph + */ +@Namespace("sd::graph") @NoOffset public static class Context extends ContextPrototype { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Context(Pointer p) { super(p); } - public native void setTargetEngine(@Cast("samediff::Engine") int engine); + public Context(@Const @ByRef ContextPrototype prototype, VariableSpace variableSpace, + GraphMemoryManager memoryManager/*=nullptr*/) { super((Pointer)null); allocate(prototype, variableSpace, memoryManager); } + private native void allocate(@Const @ByRef ContextPrototype prototype, VariableSpace variableSpace, + GraphMemoryManager memoryManager/*=nullptr*/); + public Context(@Const @ByRef ContextPrototype prototype, VariableSpace variableSpace) { super((Pointer)null); allocate(prototype, variableSpace); } + private native void allocate(@Const @ByRef ContextPrototype prototype, VariableSpace variableSpace); - public native VariableSpace getVariableSpace(); + public Context(int nodeId, VariableSpace variableSpace/*=nullptr*/) { super((Pointer)null); allocate(nodeId, variableSpace); } + private native void allocate(int nodeId, VariableSpace variableSpace/*=nullptr*/); + public Context(int nodeId) { super((Pointer)null); allocate(nodeId); } + private native void allocate(int nodeId); + public Context(int nodeId, VariableSpace variableSpace, @Cast("bool") boolean isInplace) { super((Pointer)null); allocate(nodeId, variableSpace, isInplace); } + private native void allocate(int nodeId, VariableSpace variableSpace, @Cast("bool") boolean isInplace); - public native LaunchContext launchContext(); + // default destructor - // these fields define, if we can execute specific node in-place, without generating new array + // these methods are for execution timing + public native void setOuterTime(@Cast("Nd4jLong") long time); + public native void setInnerTime(@Cast("Nd4jLong") long time); + public native @Cast("Nd4jLong") long outerTime(); + public native @Cast("Nd4jLong") long innerTime(); + // these methods are related to Workspace abstraction + public native @Cast("bool") boolean hasWorkspaceProvided(); + public native void attachWorkspace(Workspace workspace); - // these variables are only for Divergent Nodes - public native int getBranch(); - public native void setBranch(int branch); + public native Workspace workspace(); - /** - * - * @return - */ - public native Stash getStash(); + public native void setVariableSpace(VariableSpace variableSpace); - /** - * - */ - public native void trackList(NDArrayList list); + public native void setTargetEngine(@Cast("samediff::Engine") int engine); + public native VariableSpace getVariableSpace(); - /** - * This method returns variable for a given input index for this block - * @param idx - * @return - */ - public native Variable getVariable(int idx); - public native Variable variable(int idx); + public native LaunchContext launchContext(); - /** - * This method is shortcut to getVariable(int idx); - * - * + it check fastpath for array availability (preferred) - * @return - */ - public native NDArray getNDArray(int idx); - public native NDArray array(int idx); + /** + * + * @return + */ + public native Stash stash(); + /** + * This method returns variable for a given input index for this block + * @param idx + * @return + */ + public native @SharedPtr Variable getVariable(int idx); + public native @SharedPtr Variable variable(int idx); - /** - * This method fetches variable from VariableSpace DIRECTLY - * @param p - * @return - */ - public native Variable variable(int node, int index); - public native Variable variable(@ByRef IntIntPair p); + /** + * This method is shortcut to getVariable(int idx); + * + * + it check fastpath for array availability (preferred) + * @return + */ + public native @SharedPtr NDArray getNDArray(int idx); + public native @SharedPtr NDArray array(int idx); + /** + * This is special method, used only within Graph + * @param idx + * @return + */ + public native NDArray arrayForOp(int idx); - public native void pushNDArrayToVariableSpace(int nodeId, int index, NDArray array, @Cast("bool") boolean removable/*=true*/); - public native void pushNDArrayToVariableSpace(int nodeId, int index, NDArray array); - public native void pushNDArrayToVariableSpace(@ByRef IntIntPair pair, NDArray array, @Cast("bool") boolean removable/*=true*/); - public native void pushNDArrayToVariableSpace(@ByRef IntIntPair pair, NDArray array); + /** + * This method fetches variable from VariableSpace DIRECTLY + * @param p + * @return + */ + public native @SharedPtr Variable variable(int node, int index); + public native @SharedPtr Variable variable(@Const @ByRef IntIntPair p); + + public native void pushNDArrayToVariableSpace(int nodeId, int index, @Const @ByRef NDArray array); + public native void pushNDArrayToVariableSpace(@Const @ByRef IntIntPair pair, + @Const @ByRef NDArray array); + + public native void pushNDArrayListToVariableSpace(int nodeId, int index, @SharedPtr NDArrayList list); + public native void pushNDArrayListToVariableSpace(int nodeId, int index, + @Const @ByRef NDArrayList list, + @Cast("bool") boolean track/*=true*/); + public native void pushNDArrayListToVariableSpace(@Const @ByRef IntIntPair pair, + @Const @ByRef NDArrayList list, + @Cast("bool") boolean track/*=true*/); + public native void pushNDArrayListToVariableSpace(@Const @ByRef IntIntPair pair, + @Const @ByRef NDArrayList list); + + public native @Cast("bool") boolean isValueAvailable(@StdString BytePointer name, int id, int idx/*=0*/); + public native @Cast("bool") boolean isValueAvailable(@StdString BytePointer name, int id); + public native @Cast("bool") boolean isValueAvailable(@StdString String name, int id, int idx/*=0*/); + public native @Cast("bool") boolean isValueAvailable(@StdString String name, int id); + + public native @SharedPtr Variable ensureVariable(@StdString BytePointer name, int id, + int idx/*=0*/); + public native @SharedPtr Variable ensureVariable(@StdString BytePointer name, int id); + public native @SharedPtr Variable ensureVariable(@StdString String name, int id, + int idx/*=0*/); + public native @SharedPtr Variable ensureVariable(@StdString String name, int id); + + public native @Cast("unsigned long") long width(); + + // methods used in java interop + /** + * This method checks if Context uses fastpath variable access + * @return + */ + public native @Cast("bool") boolean isFastPath(); - public native void pushNDArrayListToVariableSpace(int nodeId, int index, NDArrayList list, @Cast("bool") boolean track/*=true*/); - public native void pushNDArrayListToVariableSpace(int nodeId, int index, NDArrayList list); - public native void pushNDArrayListToVariableSpace(@ByRef IntIntPair pair, NDArrayList list, @Cast("bool") boolean track/*=true*/); - public native void pushNDArrayListToVariableSpace(@ByRef IntIntPair pair, NDArrayList list); + /** + * Method allows to forbid FastPath execution + * @param reallyForbid + */ + public native void forbidFastPath(@Cast("bool") boolean reallyForbid); - public native @Cast("bool") boolean isValueAvailable(int idx/*=0*/); - public native @Cast("bool") boolean isValueAvailable(); +// #ifndef __JAVACPP_HACK__ +// #endif - public native Variable ensureVariable(int idx/*=0*/); - public native Variable ensureVariable(); + public native void setInputArray(int index, @SharedPtr NDArray array); + public native void setInputArray(int index, Pointer buffer, @Const Pointer shapeInfo, + Pointer specialBuffer, @Const Pointer specialShapeInfo); + public native void setInputArray(int index, Pointer databuffer, @Const Pointer shapeInfo, + @Const Pointer specialShapeInfo); + + public native void setOutputArray(int index, @SharedPtr NDArray array); + public native void setOutputArray(int index, Pointer buffer, @Const Pointer shapeInfo, + Pointer specialBuffer, @Const Pointer specialShapeInfo); + public native void setOutputArray(int index, Pointer databuffer, @Const Pointer shapeInfo, + @Const Pointer specialShapeInfo); + + public native void setTArguments(DoublePointer arguments, int numberOfArguments); + public native void setTArguments(DoubleBuffer arguments, int numberOfArguments); + public native void setTArguments(double[] arguments, int numberOfArguments); + public native void setIArguments(@Cast("Nd4jLong*") LongPointer arguments, int numberOfArguments); + public native void setIArguments(@Cast("Nd4jLong*") LongBuffer arguments, int numberOfArguments); + public native void setIArguments(@Cast("Nd4jLong*") long[] arguments, int numberOfArguments); + public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments); + public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments); + public native void setDArguments(@Cast("sd::DataType*") IntPointer arguments, int numberOfArguments); + public native void setDArguments(@Cast("sd::DataType*") IntBuffer arguments, int numberOfArguments); + public native void setDArguments(@Cast("sd::DataType*") int[] arguments, int numberOfArguments); + + public native void setTArguments(@StdVector DoublePointer tArgs); + public native void setTArguments(@StdVector DoubleBuffer tArgs); + public native void setTArguments(@StdVector double[] tArgs); + public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongPointer tArgs); + public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongBuffer tArgs); + public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs); + public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs); + public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs); + public native void setDArguments(@Cast("sd::DataType*") @StdVector IntPointer dArgs); + public native void setDArguments(@Cast("sd::DataType*") @StdVector IntBuffer dArgs); + public native void setDArguments(@Cast("sd::DataType*") @StdVector int[] dArgs); - public native @Cast("unsigned long") long width(); + /** + * This method purges fastpath in/out contents and releases all the handles. + * + * PLEASE NOTE: I/T/B/D args will stay intact + */ + public native void clearFastPath(); - // methods used in java interop - /** - * This method checks if Context uses fastpath variable access - * @return - */ - public native @Cast("bool") boolean isFastPath(); + public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, + @Cast("Nd4jPointer") Pointer allocationPointer); - /** - * Method allows to forbid FastPath execution - * @param reallyForbid - */ - public native void forbidFastPath(@Cast("bool") boolean reallyForbid); + public native void allowHelpers(@Cast("bool") boolean reallyAllow); + public native @Cast("bool") boolean helpersAllowed(); -// #ifndef __JAVACPP_HACK__ -// #endif + public native void setShapeFunctionOverride(@Cast("bool") boolean reallyOverride); + public native @Cast("bool") boolean shapeFunctionOverride(); - public native void setInputArray(int index, NDArray array, @Cast("bool") boolean removable/*=false*/); - public native void setInputArray(int index, NDArray array); - public native void setInputArray(int index, Pointer buffer, @Const Pointer shapeInfo, Pointer specialBuffer, @Const Pointer specialShapeInfo); - public native void setInputArray(int index, Pointer databuffer, @Const Pointer shapeInfo, @Const Pointer specialShapeInfo); - - public native void setOutputArray(int index, NDArray array, @Cast("bool") boolean removable/*=false*/); - public native void setOutputArray(int index, NDArray array); - public native void setOutputArray(int index, Pointer buffer, @Const Pointer shapeInfo, Pointer specialBuffer, @Const Pointer specialShapeInfo); - public native void setOutputArray(int index, Pointer databuffer, @Const Pointer shapeInfo, @Const Pointer specialShapeInfo); - - public native void setTArguments(DoublePointer arguments, int numberOfArguments); - public native void setTArguments(DoubleBuffer arguments, int numberOfArguments); - public native void setTArguments(double[] arguments, int numberOfArguments); - public native void setIArguments(@Cast("Nd4jLong*") LongPointer arguments, int numberOfArguments); - public native void setIArguments(@Cast("Nd4jLong*") LongBuffer arguments, int numberOfArguments); - public native void setIArguments(@Cast("Nd4jLong*") long[] arguments, int numberOfArguments); - public native void setBArguments(@Cast("bool*") BooleanPointer arguments, int numberOfArguments); - public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments); - public native void setDArguments(@Cast("sd::DataType*") IntPointer arguments, int numberOfArguments); - public native void setDArguments(@Cast("sd::DataType*") IntBuffer arguments, int numberOfArguments); - public native void setDArguments(@Cast("sd::DataType*") int[] arguments, int numberOfArguments); - - public native void setTArguments(@StdVector DoublePointer tArgs); - public native void setTArguments(@StdVector DoubleBuffer tArgs); - public native void setTArguments(@StdVector double[] tArgs); - public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongPointer tArgs); - public native void setIArguments(@Cast("Nd4jLong*") @StdVector LongBuffer tArgs); - public native void setIArguments(@Cast("Nd4jLong*") @StdVector long[] tArgs); - public native void setBArguments(@Cast("bool*") @StdVector BooleanPointer tArgs); - public native void setBArguments(@Cast("bool*") @StdVector boolean[] tArgs); - public native void setDArguments(@Cast("sd::DataType*") @StdVector IntPointer dArgs); - public native void setDArguments(@Cast("sd::DataType*") @StdVector IntBuffer dArgs); - public native void setDArguments(@Cast("sd::DataType*") @StdVector int[] dArgs); - - /** - * This method purges fastpath in/out contents and releases all the handles. - * - * PLEASE NOTE: I/T/B/D args will stay intact - */ - public native void clearFastPath(); - - public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); - - public native void allowHelpers(@Cast("bool") boolean reallyAllow); - public native @Cast("bool") boolean helpersAllowed(); - - public native void setShapeFunctionOverride(@Cast("bool") boolean reallyOverride); - public native @Cast("bool") boolean shapeFunctionOverride(); - - public native @Cast("samediff::ExecutionMode") int executionMode(); - public native void setExecutionMode(@Cast("samediff::ExecutionMode") int executionMode); - - public native @Cast("bool") boolean isTraining(); - public native @Cast("bool") boolean isInference(); - } - + public native @Cast("samediff::ExecutionMode") int executionMode(); + public native void setExecutionMode(@Cast("samediff::ExecutionMode") int executionMode); + public native @Cast("bool") boolean isTraining(); + public native @Cast("bool") boolean isInference(); + public native @Const @ByRef GraphMemoryManager memoryManager(); +} + // namespace graph + // namespace sd -// #endif //LIBND4J_BLOCK_H +// #endif // LIBND4J_BLOCK_H // Parsed from graph/ContextPrototype.h @@ -6725,99 +7143,130 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef ND4J_CONTEXT_PROTOTYPE_H // #define ND4J_CONTEXT_PROTOTYPE_H -// #include -// #include // #include -// #include -// #include -// #include // #include // #include +// #include +// #include +// #include +// #include + +// #include // #ifndef __STANDALONE_BUILD__ // #include // #endif - @Namespace("sd::graph") @NoOffset public static class ContextPrototype extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ContextPrototype(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ContextPrototype(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ContextPrototype position(long position) { - return (ContextPrototype)super.position(position); - } - - public ContextPrototype(OpDescriptor opDescriptor/*=nullptr*/, int nodeId/*=1*/, @Cast("bool") boolean inPlace/*=false*/) { super((Pointer)null); allocate(opDescriptor, nodeId, inPlace); } - private native void allocate(OpDescriptor opDescriptor/*=nullptr*/, int nodeId/*=1*/, @Cast("bool") boolean inPlace/*=false*/); - public ContextPrototype() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native int getNodeId(); - public native int nodeId(); - - // this method returns true, if inputs are defined - public native @Cast("bool") boolean hasVariablesFilled(); - - public native void setOpDescriptor(OpDescriptor opDescriptor); - - public native @Cast("sd::DataType") int dataType(); - public native @Cast("sd::DataType") int dataType(int index); - public native void setDataType(int index, @Cast("sd::DataType") int type); - - public native @Cast("bool") boolean isInplace(); - public native void markInplace(@Cast("bool") boolean reallyInplace); - - public native void pickInput(int input); - public native void pickInput(int input, int index); - public native void pickInput(@ByRef IntIntPair p); - public native void fillInputs(@StdVector IntPointer inputs); - public native void fillInputs(@StdVector IntBuffer inputs); - public native void fillInputs(@StdVector int[] inputs); - public native @StdVector IntIntPair inputs(); - - public native @StdVector DoublePointer getTArguments(); - public native @StdVector IntPointer getIArguments(); - public native @Cast("bool*") @StdVector BooleanPointer getBArguments(); - public native @Cast("sd::DataType*") @StdVector IntPointer getDArguments(); - public native @StdVector IntPointer getAxis(); - - public native @Cast("samediff::Engine") int engine(); - - public native @Cast("size_t") long numT(); - public native @Cast("size_t") long numI(); - public native @Cast("size_t") long numB(); - public native @Cast("size_t") long numD(); - - public native IntIntPair input(int idx); - - public native int opNum(); - public native void setOpNum(int opNum); - - public native @Cast("bool") boolean isUseMKLDNN(); - public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN); - - /** - * This method returns number of inputs available in this block - * @return - */ - public native @Cast("unsigned long") long width(); - - // just a clone - public native ContextPrototype clone(); - - public native @ByRef RandomGenerator randomGenerator(); - public native @Const @ByRef RandomGenerator getRng(); - public native void setRng(@Const @ByRef RandomGenerator anotherRng); - public native void setRandomGenerator(@Const @ByRef RandomGenerator anotherRng); - public native @Cast("uint64_t") long randomSeed(); - public native void setRandomSeed(@Cast("uint64_t") long seed); - } - +@Namespace("sd::graph") @NoOffset public static class ContextPrototype extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ContextPrototype(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ContextPrototype(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ContextPrototype position(long position) { + return (ContextPrototype)super.position(position); + } + + public ContextPrototype(OpDescriptor opDescriptor/*=nullptr*/, + int nodeId/*=1*/, @Cast("bool") boolean inPlace/*=false*/) { super((Pointer)null); allocate(opDescriptor, nodeId, inPlace); } + private native void allocate(OpDescriptor opDescriptor/*=nullptr*/, + int nodeId/*=1*/, @Cast("bool") boolean inPlace/*=false*/); + public ContextPrototype() { super((Pointer)null); allocate(); } + private native void allocate(); + + public ContextPrototype(@Const @ByRef ContextPrototype other) { super((Pointer)null); allocate(other); } + @NoException private native void allocate(@Const @ByRef ContextPrototype other); + + public native @ByRef @Name("operator =") @NoException ContextPrototype put(@Const @ByRef ContextPrototype other); + + // move constructor + + // move assignment operator + + public native int getNodeId(); + public native int nodeId(); + public native void setNodeId(int id); + + // this method returns true, if inputs are defined + public native @Cast("bool") boolean hasVariablesFilled(); + + public native void setOpDescriptor(OpDescriptor opDescriptor); + + public native @Cast("bool") boolean isInplace(); + public native void markInplace(@Cast("bool") boolean reallyInplace); + + public native void pickInput(int input); + public native void pickInput(int input, int index); + public native void pickInput(@Const @ByRef IntIntPair p); + public native void fillInputs(@StdVector IntPointer inputs); + public native void fillInputs(@StdVector IntBuffer inputs); + public native void fillInputs(@StdVector int[] inputs); + public native @StdVector IntIntPair inputs(); + + public native @StdVector DoublePointer getTArguments(); + public native @StdVector IntPointer getIArguments(); + public native @Cast("bool*") @StdVector BooleanPointer getBArguments(); + public native @Cast("sd::DataType*") @StdVector IntPointer getDArguments(); + public native @StdVector IntPointer getAxis(); + + public native void appendI(@Cast("Nd4jLong*") @StdVector LongPointer value); + public native void appendI(@Cast("Nd4jLong*") @StdVector LongBuffer value); + public native void appendI(@Cast("Nd4jLong*") @StdVector long[] value); + public native void appendT(@StdVector DoublePointer value); + public native void appendT(@StdVector DoubleBuffer value); + public native void appendT(@StdVector double[] value); + public native void appendB(@Cast("bool*") @StdVector BooleanPointer value); + public native void appendB(@Cast("bool*") @StdVector boolean[] value); + public native void appendD(@Cast("sd::DataType*") @StdVector IntPointer value); + public native void appendD(@Cast("sd::DataType*") @StdVector IntBuffer value); + public native void appendD(@Cast("sd::DataType*") @StdVector int[] value); + + public native void appendA(@Cast("Nd4jLong") long value); + public native void appendI(@Cast("Nd4jLong") long value); + public native void appendT(double value); + public native void appendB(@Cast("bool") boolean value); + public native void appendD(@Cast("sd::DataType") int value); + + public native @Cast("samediff::Engine") int engine(); + + public native @Cast("size_t") long numT(); + public native @Cast("size_t") long numI(); + public native @Cast("size_t") long numB(); + public native @Cast("size_t") long numD(); + + public native @Const @ByRef IntIntPair input(int idx); + + public native int opNum(); + public native void setOpNum(int opNum); + + public native @Cast("bool") boolean isUseMKLDNN(); + public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN); + public native @StdString BytePointer name(); + public native void setName(@StdString BytePointer name); + public native void setName(@StdString String name); -// #endif //ND4J_CONTEXT_PROTOTYPE_H + /** + * This method returns number of inputs available in this block + * @return + */ + public native @Cast("unsigned long") long width(); + + // just a clone + public native ContextPrototype clone(); + + public native @ByRef RandomGenerator randomGenerator(); + public native @Const @ByRef RandomGenerator getRng(); + public native void setRng(@Const @ByRef RandomGenerator anotherRng); + public native void setRandomGenerator(@Const @ByRef RandomGenerator anotherRng); + public native @Cast("uint64_t") long randomSeed(); + public native void setRandomSeed(@Cast("uint64_t") long seed); +} + // namespace graph + // namespace sd + +// #endif // ND4J_CONTEXT_PROTOTYPE_H // Parsed from graph/ResultWrapper.h @@ -6845,26 +7294,25 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_RESULTWRAPPER_H // #define LIBND4J_RESULTWRAPPER_H +// #include // #include // #include -// #include - @Namespace("sd::graph") @NoOffset public static class ResultWrapper extends org.nd4j.nativeblas.ResultWrapperAbstraction { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ResultWrapper(Pointer p) { super(p); } - - public ResultWrapper(@Cast("Nd4jLong") long size, @Cast("Nd4jPointer") Pointer ptr) { super((Pointer)null); allocate(size, ptr); } - private native void allocate(@Cast("Nd4jLong") long size, @Cast("Nd4jPointer") Pointer ptr); - - public native @Cast("Nd4jLong") long size(); +@Namespace("sd::graph") @NoOffset public static class ResultWrapper extends org.nd4j.nativeblas.ResultWrapperAbstraction { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ResultWrapper(Pointer p) { super(p); } - public native @Cast("Nd4jPointer") Pointer pointer(); - } - + public ResultWrapper(@Cast("Nd4jLong") long size, @Cast("Nd4jPointer") Pointer ptr) { super((Pointer)null); allocate(size, ptr); } + private native void allocate(@Cast("Nd4jLong") long size, @Cast("Nd4jPointer") Pointer ptr); + public native @Cast("Nd4jLong") long size(); + public native @Cast("Nd4jPointer") Pointer pointer(); +} + // namespace graph + // namespace sd -// #endif //LIBND4J_RESULTWRAPPER_H +// #endif // LIBND4J_RESULTWRAPPER_H // Parsed from helpers/shape.h @@ -6895,205 +7343,309 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef SHAPE_H_ // #define SHAPE_H_ -// #include +// #include + // #include +// #include + +// #include "../cnpy/cnpy.h" +// #include "../helpers/logger.h" +// #include "math/templatemath.h" // #include "system/dll.h" // #include "system/nd4jmalloc.h" -// #include "math/templatemath.h" -// #include "../helpers/logger.h" // #include "system/pointercast.h" -// #include "../cnpy/cnpy.h" -// #include public static final int MAX_DIMENSION = 0x7fffffff; -public static final int MAX_NUM_THREADS = 1024; +public static final int MAX_NUM_THREADS = 1024; public static final int MAX_RANK = 32; -public static final int MAX_SHAPEINFOLENGTH = 2*MAX_RANK+4; +public static final int MAX_SHAPEINFOLENGTH = 2 * MAX_RANK + 4; public static final int MAX_COORD = 3; public static final int PREALLOC_SIZE = 33554432; // #ifdef __CUDACC__ // #endif - // #ifdef __CUDACC__ // #else // #define INLINEDEF inline // #endif -// #include "system/pairwise_util.h" -// #include // #include +// #include + +// #include "system/pairwise_util.h" /** * Shape information approximating * the information on an ndarray */ - @Namespace("shape") @NoOffset public static class ShapeInformation extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ShapeInformation(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ShapeInformation(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ShapeInformation position(long position) { - return (ShapeInformation)super.position(position); - } - - public ShapeInformation(@Cast("Nd4jLong*") LongPointer shape_/*=nullptr*/, @Cast("Nd4jLong*") LongPointer stride_/*=nullptr*/, char order_/*=0*/, int rank_/*=0*/, int offset_/*=0*/, int elementWiseStride_/*=0*/) { super((Pointer)null); allocate(shape_, stride_, order_, rank_, offset_, elementWiseStride_); } - private native void allocate(@Cast("Nd4jLong*") LongPointer shape_/*=nullptr*/, @Cast("Nd4jLong*") LongPointer stride_/*=nullptr*/, char order_/*=0*/, int rank_/*=0*/, int offset_/*=0*/, int elementWiseStride_/*=0*/); - public ShapeInformation() { super((Pointer)null); allocate(); } - private native void allocate(); - public ShapeInformation(@Cast("Nd4jLong*") LongBuffer shape_/*=nullptr*/, @Cast("Nd4jLong*") LongBuffer stride_/*=nullptr*/, char order_/*=0*/, int rank_/*=0*/, int offset_/*=0*/, int elementWiseStride_/*=0*/) { super((Pointer)null); allocate(shape_, stride_, order_, rank_, offset_, elementWiseStride_); } - private native void allocate(@Cast("Nd4jLong*") LongBuffer shape_/*=nullptr*/, @Cast("Nd4jLong*") LongBuffer stride_/*=nullptr*/, char order_/*=0*/, int rank_/*=0*/, int offset_/*=0*/, int elementWiseStride_/*=0*/); - public ShapeInformation(@Cast("Nd4jLong*") long[] shape_/*=nullptr*/, @Cast("Nd4jLong*") long[] stride_/*=nullptr*/, char order_/*=0*/, int rank_/*=0*/, int offset_/*=0*/, int elementWiseStride_/*=0*/) { super((Pointer)null); allocate(shape_, stride_, order_, rank_, offset_, elementWiseStride_); } - private native void allocate(@Cast("Nd4jLong*") long[] shape_/*=nullptr*/, @Cast("Nd4jLong*") long[] stride_/*=nullptr*/, char order_/*=0*/, int rank_/*=0*/, int offset_/*=0*/, int elementWiseStride_/*=0*/); - - public native @Cast("Nd4jLong*") LongPointer shape(); public native ShapeInformation shape(LongPointer setter); - public native @Cast("Nd4jLong*") LongPointer stride(); public native ShapeInformation stride(LongPointer setter); - public native char order(); public native ShapeInformation order(char setter); - public native int rank(); public native ShapeInformation rank(int setter); - public native int offset(); public native ShapeInformation offset(int setter); - public native int elementWiseStride(); public native ShapeInformation elementWiseStride(int setter); +@Namespace("shape") @NoOffset public static class ShapeInformation extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ShapeInformation(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ShapeInformation(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ShapeInformation position(long position) { + return (ShapeInformation)super.position(position); } + public ShapeInformation(@Cast("Nd4jLong*") LongPointer shape_/*=nullptr*/, + @Cast("Nd4jLong*") LongPointer stride_/*=nullptr*/, char order_/*=0*/, + int rank_/*=0*/, int offset_/*=0*/, + int elementWiseStride_/*=0*/) { super((Pointer)null); allocate(shape_, stride_, order_, rank_, offset_, elementWiseStride_); } + private native void allocate(@Cast("Nd4jLong*") LongPointer shape_/*=nullptr*/, + @Cast("Nd4jLong*") LongPointer stride_/*=nullptr*/, char order_/*=0*/, + int rank_/*=0*/, int offset_/*=0*/, + int elementWiseStride_/*=0*/); + public ShapeInformation() { super((Pointer)null); allocate(); } + private native void allocate(); + public ShapeInformation(@Cast("Nd4jLong*") LongBuffer shape_/*=nullptr*/, + @Cast("Nd4jLong*") LongBuffer stride_/*=nullptr*/, char order_/*=0*/, + int rank_/*=0*/, int offset_/*=0*/, + int elementWiseStride_/*=0*/) { super((Pointer)null); allocate(shape_, stride_, order_, rank_, offset_, elementWiseStride_); } + private native void allocate(@Cast("Nd4jLong*") LongBuffer shape_/*=nullptr*/, + @Cast("Nd4jLong*") LongBuffer stride_/*=nullptr*/, char order_/*=0*/, + int rank_/*=0*/, int offset_/*=0*/, + int elementWiseStride_/*=0*/); + public ShapeInformation(@Cast("Nd4jLong*") long[] shape_/*=nullptr*/, + @Cast("Nd4jLong*") long[] stride_/*=nullptr*/, char order_/*=0*/, + int rank_/*=0*/, int offset_/*=0*/, + int elementWiseStride_/*=0*/) { super((Pointer)null); allocate(shape_, stride_, order_, rank_, offset_, elementWiseStride_); } + private native void allocate(@Cast("Nd4jLong*") long[] shape_/*=nullptr*/, + @Cast("Nd4jLong*") long[] stride_/*=nullptr*/, char order_/*=0*/, + int rank_/*=0*/, int offset_/*=0*/, + int elementWiseStride_/*=0*/); + + public native @Cast("Nd4jLong*") LongPointer shape(); public native ShapeInformation shape(LongPointer setter); + public native @Cast("Nd4jLong*") LongPointer stride(); public native ShapeInformation stride(LongPointer setter); + public native char order(); public native ShapeInformation order(char setter); + public native int rank(); public native ShapeInformation rank(int setter); + public native int offset(); public native ShapeInformation offset(int setter); + public native int elementWiseStride(); public native ShapeInformation elementWiseStride(int setter); +} + /** * Indexing information * for bounds checking */ - @Namespace("shape") public static class CurrentIndexing extends Pointer { - static { Loader.load(); } - /** Default native constructor. */ - public CurrentIndexing() { super((Pointer)null); allocate(); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public CurrentIndexing(long size) { super((Pointer)null); allocateArray(size); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public CurrentIndexing(Pointer p) { super(p); } - private native void allocate(); - private native void allocateArray(long size); - @Override public CurrentIndexing position(long position) { - return (CurrentIndexing)super.position(position); - } - - public native int numElementsPerThread(); public native CurrentIndexing numElementsPerThread(int setter); - public native int blockStartingIndex(); public native CurrentIndexing blockStartingIndex(int setter); - public native int startingThreadIndex(); public native CurrentIndexing startingThreadIndex(int setter); - public native int endingThreadIndex(); public native CurrentIndexing endingThreadIndex(int setter); - +@Namespace("shape") public static class CurrentIndexing extends Pointer { + static { Loader.load(); } + /** Default native constructor. */ + public CurrentIndexing() { super((Pointer)null); allocate(); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public CurrentIndexing(long size) { super((Pointer)null); allocateArray(size); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public CurrentIndexing(Pointer p) { super(p); } + private native void allocate(); + private native void allocateArray(long size); + @Override public CurrentIndexing position(long position) { + return (CurrentIndexing)super.position(position); } + public native int numElementsPerThread(); public native CurrentIndexing numElementsPerThread(int setter); + public native int blockStartingIndex(); public native CurrentIndexing blockStartingIndex(int setter); + public native int startingThreadIndex(); public native CurrentIndexing startingThreadIndex(int setter); + public native int endingThreadIndex(); public native CurrentIndexing endingThreadIndex(int setter); +} +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(int shape1Rank, + @Cast("const Nd4jLong*") LongPointer shape1, + int shape2Rank, + @Cast("const Nd4jLong*") LongPointer shape2); +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(int shape1Rank, + @Cast("const Nd4jLong*") LongBuffer shape1, + int shape2Rank, + @Cast("const Nd4jLong*") LongBuffer shape2); +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(int shape1Rank, + @Cast("const Nd4jLong*") long[] shape1, + int shape2Rank, + @Cast("const Nd4jLong*") long[] shape2); + +@Namespace("shape") public static native @Cast("const Nd4jLong*") LongPointer detachShape(@Cast("const Nd4jLong*") LongPointer originalShape); +@Namespace("shape") public static native @Cast("const Nd4jLong*") LongBuffer detachShape(@Cast("const Nd4jLong*") LongBuffer originalShape); +@Namespace("shape") public static native @Cast("const Nd4jLong*") long[] detachShape(@Cast("const Nd4jLong*") long[] originalShape); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer copyShape(@Cast("const Nd4jLong*") LongPointer originalShape); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer copyShape(@Cast("const Nd4jLong*") LongBuffer originalShape); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] copyShape(@Cast("const Nd4jLong*") long[] originalShape); + +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongPointer shapeInfo1, + @Cast("const Nd4jLong*") LongPointer shapeInfo2); +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, + @Cast("const Nd4jLong*") LongBuffer shapeInfo2); +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") long[] shapeInfo1, + @Cast("const Nd4jLong*") long[] shapeInfo2); + +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongPointer shapeInfo1, + @Cast("const Nd4jLong*") LongPointer shapeInfo2, + @Cast("const Nd4jLong*") LongPointer shapeInfo3); +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, + @Cast("const Nd4jLong*") LongBuffer shapeInfo2, + @Cast("const Nd4jLong*") LongBuffer shapeInfo3); +@Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") long[] shapeInfo1, + @Cast("const Nd4jLong*") long[] shapeInfo2, + @Cast("const Nd4jLong*") long[] shapeInfo3); + +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(int shape1Rank, + @Cast("const Nd4jLong*") LongPointer shape1, + int shape2Rank, + @Cast("const Nd4jLong*") LongPointer shape2); +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(int shape1Rank, + @Cast("const Nd4jLong*") LongBuffer shape1, + int shape2Rank, + @Cast("const Nd4jLong*") LongBuffer shape2); +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(int shape1Rank, + @Cast("const Nd4jLong*") long[] shape1, + int shape2Rank, + @Cast("const Nd4jLong*") long[] shape2); + +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongPointer shapeInfo1, + @Cast("const Nd4jLong*") LongPointer shapeInfo2); +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, + @Cast("const Nd4jLong*") LongBuffer shapeInfo2); +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") long[] shapeInfo1, + @Cast("const Nd4jLong*") long[] shapeInfo2); + +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongPointer stride1, int rank1, + @Cast("const Nd4jLong*") LongPointer stride2, int rank2); +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongBuffer stride1, int rank1, + @Cast("const Nd4jLong*") LongBuffer stride2, int rank2); +@Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") long[] stride1, int rank1, + @Cast("const Nd4jLong*") long[] stride2, int rank2); + +@Namespace("shape") public static native @Cast("bool") boolean equalsSoft(@Cast("const Nd4jLong*") LongPointer shapeA, + @Cast("const Nd4jLong*") LongPointer shapeB); +@Namespace("shape") public static native @Cast("bool") boolean equalsSoft(@Cast("const Nd4jLong*") LongBuffer shapeA, + @Cast("const Nd4jLong*") LongBuffer shapeB); +@Namespace("shape") public static native @Cast("bool") boolean equalsSoft(@Cast("const Nd4jLong*") long[] shapeA, + @Cast("const Nd4jLong*") long[] shapeB); + +@Namespace("shape") public static native @Cast("bool") boolean equalsTypesAndShapesSoft(@Cast("const Nd4jLong*") LongPointer shapeA, + @Cast("const Nd4jLong*") LongPointer shapeB); +@Namespace("shape") public static native @Cast("bool") boolean equalsTypesAndShapesSoft(@Cast("const Nd4jLong*") LongBuffer shapeA, + @Cast("const Nd4jLong*") LongBuffer shapeB); +@Namespace("shape") public static native @Cast("bool") boolean equalsTypesAndShapesSoft(@Cast("const Nd4jLong*") long[] shapeA, + @Cast("const Nd4jLong*") long[] shapeB); + +@Namespace("shape") public static native @Cast("bool") boolean equalsStrict(@Cast("const Nd4jLong*") LongPointer shapeA, + @Cast("const Nd4jLong*") LongPointer shapeB); +@Namespace("shape") public static native @Cast("bool") boolean equalsStrict(@Cast("const Nd4jLong*") LongBuffer shapeA, + @Cast("const Nd4jLong*") LongBuffer shapeB); +@Namespace("shape") public static native @Cast("bool") boolean equalsStrict(@Cast("const Nd4jLong*") long[] shapeA, + @Cast("const Nd4jLong*") long[] shapeB); + +// returns true if ranks, shapes and strides are the same +@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongPointer shapeInfo1, + @Cast("const Nd4jLong*") LongPointer shapeInfo2); +@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, + @Cast("const Nd4jLong*") LongBuffer shapeInfo2); +@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, + @Cast("const Nd4jLong*") long[] shapeInfo2); +@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongPointer shapeInfo1, + @Cast("const Nd4jLong*") LongPointer shapeInfo2, + @Cast("const Nd4jLong*") LongPointer shapeInfo3); +@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, + @Cast("const Nd4jLong*") LongBuffer shapeInfo2, + @Cast("const Nd4jLong*") LongBuffer shapeInfo3); +@Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, + @Cast("const Nd4jLong*") long[] shapeInfo2, + @Cast("const Nd4jLong*") long[] shapeInfo3); + +@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim); +@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim); +@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim); +@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim); +@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim); +@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim); + +@Namespace("shape") public static native void traceNew(int id); + +@Namespace("shape") public static native int tadIndexForLinear(int linearIndex, int tadLength); + +@Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("const Nd4jLong*") long[] shapeInfo, int[] dimension, + int dimensionLength); + +@Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongPointer oldShape, + int newRank, @Cast("Nd4jLong*") LongPointer newShape, + @Cast("bool") boolean isFOrder); +@Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongBuffer oldShape, + int newRank, @Cast("Nd4jLong*") LongBuffer newShape, + @Cast("bool") boolean isFOrder); +@Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") long[] oldShape, + int newRank, @Cast("Nd4jLong*") long[] newShape, + @Cast("bool") boolean isFOrder); + +@Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongPointer oldShapeInfo, + byte newOrder, int newRank, + @Cast("const Nd4jLong*") LongPointer newShape, + @Cast("Nd4jLong*") LongPointer newShapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongBuffer oldShapeInfo, + byte newOrder, int newRank, + @Cast("const Nd4jLong*") LongBuffer newShape, + @Cast("Nd4jLong*") LongBuffer newShapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") long[] oldShapeInfo, + byte newOrder, int newRank, + @Cast("const Nd4jLong*") long[] newShape, + @Cast("Nd4jLong*") long[] newShapeInfo); +/** + * newShapeInfo contains rank, shape and order only, no strides/ews/type + */ +@Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongPointer oldShapeInfo, + @Cast("Nd4jLong*") LongPointer newShapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongBuffer oldShapeInfo, + @Cast("Nd4jLong*") LongBuffer newShapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") long[] oldShapeInfo, + @Cast("Nd4jLong*") long[] newShapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(int shape1Rank, @Cast("const Nd4jLong*") LongPointer shape1, int shape2Rank, @Cast("const Nd4jLong*") LongPointer shape2); - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(int shape1Rank, @Cast("const Nd4jLong*") LongBuffer shape1, int shape2Rank, @Cast("const Nd4jLong*") LongBuffer shape2); - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(int shape1Rank, @Cast("const Nd4jLong*") long[] shape1, int shape2Rank, @Cast("const Nd4jLong*") long[] shape2); - - @Namespace("shape") public static native @Cast("const Nd4jLong*") LongPointer detachShape(@Cast("const Nd4jLong*") LongPointer originalShape); - @Namespace("shape") public static native @Cast("const Nd4jLong*") LongBuffer detachShape(@Cast("const Nd4jLong*") LongBuffer originalShape); - @Namespace("shape") public static native @Cast("const Nd4jLong*") long[] detachShape(@Cast("const Nd4jLong*") long[] originalShape); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer copyShape(@Cast("const Nd4jLong*") LongPointer originalShape); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer copyShape(@Cast("const Nd4jLong*") LongBuffer originalShape); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] copyShape(@Cast("const Nd4jLong*") long[] originalShape); - - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongPointer shapeInfo1, @Cast("const Nd4jLong*") LongPointer shapeInfo2); - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2); - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2); - - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongPointer shapeInfo1, @Cast("const Nd4jLong*") LongPointer shapeInfo2, @Cast("const Nd4jLong*") LongPointer shapeInfo3); - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2, @Cast("const Nd4jLong*") LongBuffer shapeInfo3); - @Namespace("shape") public static native @Cast("bool") boolean shapeEquals(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2, @Cast("const Nd4jLong*") long[] shapeInfo3); - - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(int shape1Rank,@Cast("const Nd4jLong*") LongPointer shape1,int shape2Rank, @Cast("const Nd4jLong*") LongPointer shape2); - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(int shape1Rank,@Cast("const Nd4jLong*") LongBuffer shape1,int shape2Rank, @Cast("const Nd4jLong*") LongBuffer shape2); - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(int shape1Rank,@Cast("const Nd4jLong*") long[] shape1,int shape2Rank, @Cast("const Nd4jLong*") long[] shape2); - - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongPointer shapeInfo1, @Cast("const Nd4jLong*") LongPointer shapeInfo2); - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2); - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2); - - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongPointer stride1,int rank1, @Cast("const Nd4jLong*") LongPointer stride2, int rank2); - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") LongBuffer stride1,int rank1, @Cast("const Nd4jLong*") LongBuffer stride2, int rank2); - @Namespace("shape") public static native @Cast("bool") boolean strideEquals(@Cast("const Nd4jLong*") long[] stride1,int rank1, @Cast("const Nd4jLong*") long[] stride2, int rank2); - - @Namespace("shape") public static native @Cast("bool") boolean equalsSoft(@Cast("const Nd4jLong*") LongPointer shapeA, @Cast("const Nd4jLong*") LongPointer shapeB); - @Namespace("shape") public static native @Cast("bool") boolean equalsSoft(@Cast("const Nd4jLong*") LongBuffer shapeA, @Cast("const Nd4jLong*") LongBuffer shapeB); - @Namespace("shape") public static native @Cast("bool") boolean equalsSoft(@Cast("const Nd4jLong*") long[] shapeA, @Cast("const Nd4jLong*") long[] shapeB); - - @Namespace("shape") public static native @Cast("bool") boolean equalsTypesAndShapesSoft(@Cast("const Nd4jLong*") LongPointer shapeA, @Cast("const Nd4jLong*") LongPointer shapeB); - @Namespace("shape") public static native @Cast("bool") boolean equalsTypesAndShapesSoft(@Cast("const Nd4jLong*") LongBuffer shapeA, @Cast("const Nd4jLong*") LongBuffer shapeB); - @Namespace("shape") public static native @Cast("bool") boolean equalsTypesAndShapesSoft(@Cast("const Nd4jLong*") long[] shapeA, @Cast("const Nd4jLong*") long[] shapeB); - - @Namespace("shape") public static native @Cast("bool") boolean equalsStrict(@Cast("const Nd4jLong*") LongPointer shapeA, @Cast("const Nd4jLong*") LongPointer shapeB); - @Namespace("shape") public static native @Cast("bool") boolean equalsStrict(@Cast("const Nd4jLong*") LongBuffer shapeA, @Cast("const Nd4jLong*") LongBuffer shapeB); - @Namespace("shape") public static native @Cast("bool") boolean equalsStrict(@Cast("const Nd4jLong*") long[] shapeA, @Cast("const Nd4jLong*") long[] shapeB); - - // returns true if ranks, shapes and strides are the same - @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongPointer shapeInfo1, @Cast("const Nd4jLong*") LongPointer shapeInfo2); - @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2); - @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2); - @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongPointer shapeInfo1, @Cast("const Nd4jLong*") LongPointer shapeInfo2, @Cast("const Nd4jLong*") LongPointer shapeInfo3); - @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") LongBuffer shapeInfo1, @Cast("const Nd4jLong*") LongBuffer shapeInfo2, @Cast("const Nd4jLong*") LongBuffer shapeInfo3); - @Namespace("shape") public static native @Cast("bool") boolean haveSameShapeAndStrides(@Cast("const Nd4jLong*") long[] shapeInfo1, @Cast("const Nd4jLong*") long[] shapeInfo2, @Cast("const Nd4jLong*") long[] shapeInfo3); - - @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim); - @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim); - @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim); - @Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim); - @Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim); - @Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim); - - @Namespace("shape") public static native void traceNew(int id); - - - @Namespace("shape") public static native int tadIndexForLinear(int linearIndex, int tadLength); - - @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("const Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength); - - @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongPointer oldShape, int newRank, @Cast("Nd4jLong*") LongPointer newShape, @Cast("bool") boolean isFOrder); - @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongBuffer oldShape, int newRank, @Cast("Nd4jLong*") LongBuffer newShape, @Cast("bool") boolean isFOrder); - @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") long[] oldShape, int newRank, @Cast("Nd4jLong*") long[] newShape, @Cast("bool") boolean isFOrder); - - @Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongPointer oldShapeInfo, byte newOrder, int newRank, @Cast("const Nd4jLong*") LongPointer newShape, @Cast("Nd4jLong*") LongPointer newShapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongBuffer oldShapeInfo, byte newOrder, int newRank, @Cast("const Nd4jLong*") LongBuffer newShape, @Cast("Nd4jLong*") LongBuffer newShapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") long[] oldShapeInfo, byte newOrder, int newRank, @Cast("const Nd4jLong*") long[] newShape, @Cast("Nd4jLong*") long[] newShapeInfo); - /** - * newShapeInfo contains rank, shape and order only, no strides/ews/type - */ - @Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongPointer oldShapeInfo, @Cast("Nd4jLong*") LongPointer newShapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") LongBuffer oldShapeInfo, @Cast("Nd4jLong*") LongBuffer newShapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean reshapeC(@Cast("const Nd4jLong*") long[] oldShapeInfo, @Cast("Nd4jLong*") long[] newShapeInfo); - - /** - * Get the shape info buffer - * for the given rank and shape. - */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongPointer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongBuffer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBuffer(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") long[] shape); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBuffer(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] buffer); +/** + * Get the shape info buffer + * for the given rank and shape. + */ +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongPointer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongBuffer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBuffer(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") long[] shape); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongPointer shape, + @Cast("Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBuffer(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBuffer(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") long[] shape, + @Cast("Nd4jLong*") long[] buffer); - /** - * Get the shape info buffer - * for the given rank and shape. - */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongPointer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongBuffer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") long[] shape); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer output); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer output); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] output); +/** + * Get the shape info buffer + * for the given rank and shape. + */ +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongPointer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongBuffer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") long[] shape); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongPointer shape, + @Cast("Nd4jLong*") LongPointer output); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("Nd4jLong*") LongBuffer output); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBufferFortran(int rank, @Cast("sd::DataType") int dtype, + @Cast("const Nd4jLong*") long[] shape, + @Cast("Nd4jLong*") long[] output); // #ifdef __CUDACC__ // #endif - - /** * Computes the standard packed array strides for a given shape. * @@ -7101,13 +7653,19 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, int rank); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, int rank, @Cast("Nd4jLong*") LongPointer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, int rank, @Cast("Nd4jLong*") LongBuffer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, int rank, @Cast("Nd4jLong*") long[] ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, + int rank); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, + int rank); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, + int rank); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, int rank, + @Cast("Nd4jLong*") LongPointer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, int rank, + @Cast("Nd4jLong*") LongBuffer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, int rank, + @Cast("Nd4jLong*") long[] ret); /** * Computes the standard packed array strides for a given shape. @@ -7117,23 +7675,29 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return the strides for a matrix of n dimensions */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank, @Cast("Nd4jLong*") LongPointer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank, @Cast("Nd4jLong*") LongBuffer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank, @Cast("Nd4jLong*") long[] ret); - - @Namespace("shape") public static native void updateStrides(@Cast("Nd4jLong*") LongPointer shape, byte order); - @Namespace("shape") public static native void updateStrides(@Cast("Nd4jLong*") LongBuffer shape, byte order); - @Namespace("shape") public static native void updateStrides(@Cast("Nd4jLong*") long[] shape, byte order); - @Namespace("shape") public static native void updateStrides(int rank, @Cast("const Nd4jLong*") LongPointer shapeOnly, @Cast("Nd4jLong*") LongPointer stridesOnly, byte order); - @Namespace("shape") public static native void updateStrides(int rank, @Cast("const Nd4jLong*") LongBuffer shapeOnly, @Cast("Nd4jLong*") LongBuffer stridesOnly, byte order); - @Namespace("shape") public static native void updateStrides(int rank, @Cast("const Nd4jLong*") long[] shapeOnly, @Cast("Nd4jLong*") long[] stridesOnly, byte order); - - -// check whether input dimensions are permuted, not permuted dimensions order have to be 0,....,rank-1 +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank, + @Cast("Nd4jLong*") LongPointer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank, + @Cast("Nd4jLong*") LongBuffer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank, + @Cast("Nd4jLong*") long[] ret); + +@Namespace("shape") public static native void updateStrides(@Cast("Nd4jLong*") LongPointer shape, byte order); +@Namespace("shape") public static native void updateStrides(@Cast("Nd4jLong*") LongBuffer shape, byte order); +@Namespace("shape") public static native void updateStrides(@Cast("Nd4jLong*") long[] shape, byte order); +@Namespace("shape") public static native void updateStrides(int rank, @Cast("const Nd4jLong*") LongPointer shapeOnly, + @Cast("Nd4jLong*") LongPointer stridesOnly, byte order); +@Namespace("shape") public static native void updateStrides(int rank, @Cast("const Nd4jLong*") LongBuffer shapeOnly, + @Cast("Nd4jLong*") LongBuffer stridesOnly, byte order); +@Namespace("shape") public static native void updateStrides(int rank, @Cast("const Nd4jLong*") long[] shapeOnly, + @Cast("Nd4jLong*") long[] stridesOnly, byte order); + +// check whether input dimensions are permuted, not permuted dimensions order +// have to be 0,....,rank-1 /** * Computes the standard packed array strides for a given shape. @@ -7142,13 +7706,19 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, int rank, int startNum); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, int rank, int startNum); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, int rank, int startNum); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, int rank, int startNum, @Cast("Nd4jLong*") LongPointer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, int rank, int startNum, @Cast("Nd4jLong*") LongBuffer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, int rank, int startNum, @Cast("Nd4jLong*") long[] ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, int rank, + int startNum); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, int rank, + int startNum); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, int rank, + int startNum); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStridesFortran(@Cast("const Nd4jLong*") LongPointer shape, int rank, + int startNum, @Cast("Nd4jLong*") LongPointer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStridesFortran(@Cast("const Nd4jLong*") LongBuffer shape, int rank, + int startNum, @Cast("Nd4jLong*") LongBuffer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStridesFortran(@Cast("const Nd4jLong*") long[] shape, int rank, + int startNum, @Cast("Nd4jLong*") long[] ret); /** * Computes the standard packed array strides for a given shape. @@ -7157,38 +7727,44 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param startNum the start number for the strides * @return the strides for a matrix of n dimensions */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank, int startNum); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank, int startNum); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank, int startNum); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank, int startNum, @Cast("Nd4jLong*") LongPointer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank, int startNum, @Cast("Nd4jLong*") LongBuffer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank, int startNum, @Cast("Nd4jLong*") long[] ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank, + int startNum); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank, + int startNum); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank, + int startNum); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer calcStrides(@Cast("const Nd4jLong*") LongPointer shape, int rank, + int startNum, @Cast("Nd4jLong*") LongPointer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer calcStrides(@Cast("const Nd4jLong*") LongBuffer shape, int rank, + int startNum, @Cast("Nd4jLong*") LongBuffer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] calcStrides(@Cast("const Nd4jLong*") long[] shape, int rank, + int startNum, @Cast("Nd4jLong*") long[] ret); /** * @param toCopy the shape to copy * @return a copy of the original struct */ - @Namespace("shape") public static native ShapeInformation shapeCopy( ShapeInformation toCopy); - - - @Namespace("shape") public static native @Cast("bool") boolean strideDescendingCAscendingF(@Cast("const Nd4jLong*") LongPointer shapeBuffer); - @Namespace("shape") public static native @Cast("bool") boolean strideDescendingCAscendingF(@Cast("const Nd4jLong*") LongBuffer shapeBuffer); - @Namespace("shape") public static native @Cast("bool") boolean strideDescendingCAscendingF(@Cast("const Nd4jLong*") long[] shapeBuffer); +@Namespace("shape") public static native ShapeInformation shapeCopy(ShapeInformation toCopy); - @Namespace("shape") public static native @Cast("bool") boolean isContiguous(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isContiguous(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isContiguous(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean strideDescendingCAscendingF( + @Cast("const Nd4jLong*") LongPointer shapeBuffer); +@Namespace("shape") public static native @Cast("bool") boolean strideDescendingCAscendingF( + @Cast("const Nd4jLong*") LongBuffer shapeBuffer); +@Namespace("shape") public static native @Cast("bool") boolean strideDescendingCAscendingF( + @Cast("const Nd4jLong*") long[] shapeBuffer); +@Namespace("shape") public static native @Cast("bool") boolean isContiguous(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isContiguous(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isContiguous(@Cast("const Nd4jLong*") long[] shapeInfo); /** * copy-past from java hasDefaultStridesForShape function * check whether array is not permuted and has contiguous elements in memory */ - @Namespace("shape") public static native @Cast("bool") boolean areStridesDefault(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean areStridesDefault(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean areStridesDefault(@Cast("const Nd4jLong*") long[] shapeInfo); - +@Namespace("shape") public static native @Cast("bool") boolean areStridesDefault(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean areStridesDefault(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean areStridesDefault(@Cast("const Nd4jLong*") long[] shapeInfo); /** * Compute the element wise stride @@ -7201,9 +7777,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return 0 if there is no element wise stride the * element wise stride of reshape(1,length) otherwise */ - @Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, int isFOrder); - @Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, int isFOrder); - @Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, int isFOrder); +@Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongPointer shape, + @Cast("const Nd4jLong*") LongPointer stride, + int isFOrder); +@Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("const Nd4jLong*") LongBuffer stride, + int isFOrder); +@Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") long[] shape, + @Cast("const Nd4jLong*") long[] stride, + int isFOrder); /** * Compute the element wise stride @@ -7216,17 +7798,41 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return 0 if there is no element wise stride the * element wise stride of reshape(1,length) otherwise */ - @Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, int isFOrder, @Cast("const Nd4jLong*") LongPointer dimension, int dimensionLength); - @Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, int isFOrder, @Cast("const Nd4jLong*") LongBuffer dimension, int dimensionLength); - @Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, int isFOrder, @Cast("const Nd4jLong*") long[] dimension, int dimensionLength); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeInfoOnlyShapeAndStride(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer dimension, int dimensionLength,@Cast("bool") boolean reverseCopyStride); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeInfoOnlyShapeAndStride(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer dimension, int dimensionLength,@Cast("bool") boolean reverseCopyStride); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeInfoOnlyShapeAndStride(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] dimension, int dimensionLength,@Cast("bool") boolean reverseCopyStride); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeInfoOnlyShapeAndStride(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer dimension, int dimensionLength,@Cast("bool") boolean reverseCopyStride, @Cast("Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeInfoOnlyShapeAndStride(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer dimension, int dimensionLength,@Cast("bool") boolean reverseCopyStride, @Cast("Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeInfoOnlyShapeAndStride(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] dimension, int dimensionLength,@Cast("bool") boolean reverseCopyStride, @Cast("Nd4jLong*") long[] buffer); +@Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongPointer shape, + @Cast("const Nd4jLong*") LongPointer stride, + int isFOrder, + @Cast("const Nd4jLong*") LongPointer dimension, + int dimensionLength); +@Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("const Nd4jLong*") LongBuffer stride, + int isFOrder, + @Cast("const Nd4jLong*") LongBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native int computeElementWiseStride(int rank, @Cast("const Nd4jLong*") long[] shape, + @Cast("const Nd4jLong*") long[] stride, + int isFOrder, + @Cast("const Nd4jLong*") long[] dimension, + int dimensionLength); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeInfoOnlyShapeAndStride( + @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer dimension, int dimensionLength, + @Cast("bool") boolean reverseCopyStride); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeInfoOnlyShapeAndStride( + @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer dimension, int dimensionLength, + @Cast("bool") boolean reverseCopyStride); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeInfoOnlyShapeAndStride( + @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] dimension, int dimensionLength, + @Cast("bool") boolean reverseCopyStride); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeInfoOnlyShapeAndStride( + @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer dimension, int dimensionLength, + @Cast("bool") boolean reverseCopyStride, @Cast("Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeInfoOnlyShapeAndStride( + @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer dimension, int dimensionLength, + @Cast("bool") boolean reverseCopyStride, @Cast("Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeInfoOnlyShapeAndStride( + @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] dimension, int dimensionLength, + @Cast("bool") boolean reverseCopyStride, @Cast("Nd4jLong*") long[] buffer); /** * * @param length @@ -7234,11 +7840,12 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param rearrange * @return */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer doPermuteSwap(int length, @Cast("Nd4jLong*") LongPointer shape, IntPointer rearrange); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer doPermuteSwap(int length, @Cast("Nd4jLong*") LongBuffer shape, IntBuffer rearrange); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] doPermuteSwap(int length, @Cast("Nd4jLong*") long[] shape, int[] rearrange); - - +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer doPermuteSwap(int length, @Cast("Nd4jLong*") LongPointer shape, + IntPointer rearrange); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer doPermuteSwap(int length, @Cast("Nd4jLong*") LongBuffer shape, + IntBuffer rearrange); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] doPermuteSwap(int length, @Cast("Nd4jLong*") long[] shape, + int[] rearrange); /** * In place permute swap @@ -7246,55 +7853,82 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param shape * @param rearrange */ - @Namespace("shape") public static native void doPermuteSwap(int length, @Cast("Nd4jLong**") PointerPointer shape, IntPointer rearrange); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer permuteShapeBuffer(@Cast("const Nd4jLong*") LongPointer shapeBuffer, IntPointer rearrange); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer permuteShapeBuffer(@Cast("const Nd4jLong*") LongBuffer shapeBuffer, IntBuffer rearrange); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] permuteShapeBuffer(@Cast("const Nd4jLong*") long[] shapeBuffer, int[] rearrange); - - @Namespace("shape") public static native void permuteShapeBufferInPlace(@Cast("Nd4jLong*") LongPointer shapeBuffer, IntPointer rearrange, @Cast("Nd4jLong*") LongPointer out); - @Namespace("shape") public static native void permuteShapeBufferInPlace(@Cast("Nd4jLong*") LongBuffer shapeBuffer, IntBuffer rearrange, @Cast("Nd4jLong*") LongBuffer out); - @Namespace("shape") public static native void permuteShapeBufferInPlace(@Cast("Nd4jLong*") long[] shapeBuffer, int[] rearrange, @Cast("Nd4jLong*") long[] out); - - @Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongPointer shapeBuffer, @Const IntPointer rearrange, @Cast("Nd4jLong") long len/*=-1*/); - @Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongPointer shapeBuffer, @Const IntPointer rearrange); - @Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeBuffer, @Const IntBuffer rearrange, @Cast("Nd4jLong") long len/*=-1*/); - @Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeBuffer, @Const IntBuffer rearrange); - @Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") long[] shapeBuffer, @Const int[] rearrange, @Cast("Nd4jLong") long len/*=-1*/); - @Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") long[] shapeBuffer, @Const int[] rearrange); - - /** - * Rearrange the permute indexes - * according to which dimensions are specified. - * - * For example, dimension is implicitly: - * 0,1,2 - * - * If you want to do a reduce along dimensions 0 and 1, - * you need to permute the indexes to be: - * 2,0,1 - * - * which will give us the ability to ierate along an element - * wise stride. - */ - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createPermuteIndexes(int originalRank, IntPointer dimension,int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createPermuteIndexes(int originalRank, IntBuffer dimension,int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] createPermuteIndexes(int originalRank, int[] dimension,int dimensionLength); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer computeResultShape(@Cast("const Nd4jLong*") LongPointer originalShapeBuffer, IntPointer dimension,int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer computeResultShape(@Cast("const Nd4jLong*") LongBuffer originalShapeBuffer, IntBuffer dimension,int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] computeResultShape(@Cast("const Nd4jLong*") long[] originalShapeBuffer, int[] dimension,int dimensionLength); - - /** - * This method does inplace transpose of given shapeBuffer - * - * @param shapeBuffer - */ - @Namespace("shape") public static native void transposeInplace(@Cast("Nd4jLong*") LongPointer shapeBuffer); - @Namespace("shape") public static native void transposeInplace(@Cast("Nd4jLong*") LongBuffer shapeBuffer); - @Namespace("shape") public static native void transposeInplace(@Cast("Nd4jLong*") long[] shapeBuffer); +@Namespace("shape") public static native void doPermuteSwap(int length, @Cast("Nd4jLong**") PointerPointer shape, + IntPointer rearrange); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer permuteShapeBuffer(@Cast("const Nd4jLong*") LongPointer shapeBuffer, + IntPointer rearrange); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer permuteShapeBuffer(@Cast("const Nd4jLong*") LongBuffer shapeBuffer, + IntBuffer rearrange); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] permuteShapeBuffer(@Cast("const Nd4jLong*") long[] shapeBuffer, + int[] rearrange); + +@Namespace("shape") public static native void permuteShapeBufferInPlace(@Cast("Nd4jLong*") LongPointer shapeBuffer, + IntPointer rearrange, + @Cast("Nd4jLong*") LongPointer out); +@Namespace("shape") public static native void permuteShapeBufferInPlace(@Cast("Nd4jLong*") LongBuffer shapeBuffer, + IntBuffer rearrange, + @Cast("Nd4jLong*") LongBuffer out); +@Namespace("shape") public static native void permuteShapeBufferInPlace(@Cast("Nd4jLong*") long[] shapeBuffer, + int[] rearrange, + @Cast("Nd4jLong*") long[] out); + +@Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongPointer shapeBuffer, + @Const IntPointer rearrange, + @Cast("Nd4jLong") long len/*=-1*/); +@Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongPointer shapeBuffer, + @Const IntPointer rearrange); +@Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeBuffer, + @Const IntBuffer rearrange, + @Cast("Nd4jLong") long len/*=-1*/); +@Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeBuffer, + @Const IntBuffer rearrange); +@Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") long[] shapeBuffer, + @Const int[] rearrange, + @Cast("Nd4jLong") long len/*=-1*/); +@Namespace("shape") public static native void doPermuteShapeInfo(@Cast("Nd4jLong*") long[] shapeBuffer, + @Const int[] rearrange); + +/** + * Rearrange the permute indexes + * according to which dimensions are specified. + * + * For example, dimension is implicitly: + * 0,1,2 + * + * If you want to do a reduce along dimensions 0 and 1, + * you need to permute the indexes to be: + * 2,0,1 + * + * which will give us the ability to ierate along an element + * wise stride. + */ + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createPermuteIndexes(int originalRank, + IntPointer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createPermuteIndexes(int originalRank, + IntBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] createPermuteIndexes(int originalRank, + int[] dimension, + int dimensionLength); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer computeResultShape( + @Cast("const Nd4jLong*") LongPointer originalShapeBuffer, IntPointer dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer computeResultShape( + @Cast("const Nd4jLong*") LongBuffer originalShapeBuffer, IntBuffer dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] computeResultShape( + @Cast("const Nd4jLong*") long[] originalShapeBuffer, int[] dimension, int dimensionLength); +/** + * This method does inplace transpose of given shapeBuffer + * + * @param shapeBuffer + */ +@Namespace("shape") public static native void transposeInplace(@Cast("Nd4jLong*") LongPointer shapeBuffer); +@Namespace("shape") public static native void transposeInplace(@Cast("Nd4jLong*") LongBuffer shapeBuffer); +@Namespace("shape") public static native void transposeInplace(@Cast("Nd4jLong*") long[] shapeBuffer); /** * Get the ordering for the device @@ -7304,9 +7938,12 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param elementStride * @return */ - @Namespace("shape") public static native char getOrder(int length, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int elementStride); - @Namespace("shape") public static native char getOrder(int length, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int elementStride); - @Namespace("shape") public static native char getOrder(int length, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride, int elementStride); +@Namespace("shape") public static native char getOrder(int length, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, + int elementStride); +@Namespace("shape") public static native char getOrder(int length, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, + int elementStride); +@Namespace("shape") public static native char getOrder(int length, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride, + int elementStride); /** * Ensure that every value in the re arrange @@ -7324,10 +7961,14 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param rearrange the order to re arrange * @param rank the rank of the rearrange array */ - @Namespace("shape") public static native void permute(@Cast("shape::ShapeInformation**") PointerPointer info, IntPointer rearrange, int rank); - @Namespace("shape") public static native void permute(@ByPtrPtr ShapeInformation info, IntPointer rearrange, int rank); - @Namespace("shape") public static native void permute(@ByPtrPtr ShapeInformation info, IntBuffer rearrange, int rank); - @Namespace("shape") public static native void permute(@ByPtrPtr ShapeInformation info, int[] rearrange, int rank); +@Namespace("shape") public static native void permute(@Cast("shape::ShapeInformation**") PointerPointer info, IntPointer rearrange, + int rank); +@Namespace("shape") public static native void permute(@ByPtrPtr ShapeInformation info, IntPointer rearrange, + int rank); +@Namespace("shape") public static native void permute(@ByPtrPtr ShapeInformation info, IntBuffer rearrange, + int rank); +@Namespace("shape") public static native void permute(@ByPtrPtr ShapeInformation info, int[] rearrange, + int rank); /** * Returns whether the @@ -7335,72 +7976,80 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param shape the shape of the array * @param rank the rank of cthe shape */ - @Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongPointer shape, int rank); - @Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongBuffer shape, int rank); - @Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") long[] shape, int rank); - - - /** - * When 1 dimension is the whole length of the - * array - */ - @Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongPointer shape, int rank); - @Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongBuffer shape, int rank); - @Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") long[] shape, int rank); - - @Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") long[] shapeInfo); - - @Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") long[] shapeInfo); - - @Namespace("shape") public static native @Cast("bool") boolean isLikeVector(@Cast("const Nd4jLong*") LongPointer shapeInfo, @ByRef IntPointer posOfNonUnityDim); - @Namespace("shape") public static native @Cast("bool") boolean isLikeVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @ByRef IntBuffer posOfNonUnityDim); - @Namespace("shape") public static native @Cast("bool") boolean isLikeVector(@Cast("const Nd4jLong*") long[] shapeInfo, @ByRef int[] posOfNonUnityDim); +@Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongPointer shape, int rank); +@Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongBuffer shape, int rank); +@Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") long[] shape, int rank); - @Namespace("shape") public static native @Cast("bool") boolean isCommonVector(@Cast("const Nd4jLong*") LongPointer shapeInfo, @ByRef IntPointer posOfNonUnityDim); - @Namespace("shape") public static native @Cast("bool") boolean isCommonVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @ByRef IntBuffer posOfNonUnityDim); - @Namespace("shape") public static native @Cast("bool") boolean isCommonVector(@Cast("const Nd4jLong*") long[] shapeInfo, @ByRef int[] posOfNonUnityDim); - - @Namespace("shape") public static native @Cast("bool") boolean isRowVector(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isRowVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isRowVector(@Cast("const Nd4jLong*") long[] shapeInfo); - - @Namespace("shape") public static native @Cast("bool") boolean isColumnVector(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isColumnVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isColumnVector(@Cast("const Nd4jLong*") long[] shapeInfo); +/** + * When 1 dimension is the whole length of the + * array + */ +@Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongPointer shape, int rank); +@Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongBuffer shape, int rank); +@Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") long[] shape, int rank); + +@Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native int oneDimEqualToLength(@Cast("Nd4jLong*") long[] shapeInfo); + +@Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native int isVector(@Cast("const Nd4jLong*") long[] shapeInfo); + +@Namespace("shape") public static native @Cast("bool") boolean isLikeVector(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @ByRef IntPointer posOfNonUnityDim); +@Namespace("shape") public static native @Cast("bool") boolean isLikeVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @ByRef IntBuffer posOfNonUnityDim); +@Namespace("shape") public static native @Cast("bool") boolean isLikeVector(@Cast("const Nd4jLong*") long[] shapeInfo, + @ByRef int[] posOfNonUnityDim); + +@Namespace("shape") public static native @Cast("bool") boolean isCommonVector(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @ByRef IntPointer posOfNonUnityDim); +@Namespace("shape") public static native @Cast("bool") boolean isCommonVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @ByRef IntBuffer posOfNonUnityDim); +@Namespace("shape") public static native @Cast("bool") boolean isCommonVector(@Cast("const Nd4jLong*") long[] shapeInfo, + @ByRef int[] posOfNonUnityDim); + +@Namespace("shape") public static native @Cast("bool") boolean isRowVector(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isRowVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isRowVector(@Cast("const Nd4jLong*") long[] shapeInfo); + +@Namespace("shape") public static native @Cast("bool") boolean isColumnVector(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isColumnVector(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isColumnVector(@Cast("const Nd4jLong*") long[] shapeInfo); - /** - * shape - input inShape is shape only, not shapeInfo - * returns number of non-unity dimensions in inShape - */ - @Namespace("shape") public static native int numOfNonUnitDims(int rank, @Cast("const Nd4jLong*") LongPointer inShape); - @Namespace("shape") public static native int numOfNonUnitDims(int rank, @Cast("const Nd4jLong*") LongBuffer inShape); - @Namespace("shape") public static native int numOfNonUnitDims(int rank, @Cast("const Nd4jLong*") long[] inShape); +/** + * shape - input inShape is shape only, not shapeInfo + * returns number of non-unity dimensions in inShape + */ +@Namespace("shape") public static native int numOfNonUnitDims(int rank, + @Cast("const Nd4jLong*") LongPointer inShape); +@Namespace("shape") public static native int numOfNonUnitDims(int rank, + @Cast("const Nd4jLong*") LongBuffer inShape); +@Namespace("shape") public static native int numOfNonUnitDims(int rank, + @Cast("const Nd4jLong*") long[] inShape); - /** +/** * Returns whether the * given shape is a vector or not * @param shape the shape of the array * @param rank the rank of the shape */ - @Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongPointer shape, int rank); - @Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongBuffer shape, int rank); - @Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") long[] shape, int rank); +@Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongPointer shape, int rank); +@Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongBuffer shape, int rank); +@Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") long[] shape, int rank); - @Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native int isMatrix(@Cast("const Nd4jLong*") long[] shapeInfo); /** * Returns the shape portion of an information * buffer */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeOf(@Cast("Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeOf(@Cast("Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeOf(@Cast("Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeOf(@Cast("Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeOf(@Cast("Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeOf(@Cast("Nd4jLong*") long[] shapeInfo); /** * Return a copy of a buffer. @@ -7408,19 +8057,22 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * that must be freed elsewhere. */ - /** +/** * Return a copy of a buffer. * This buffer allocates memory * that must be freed elsewhere. */ - /** -* Return a copy of a buffer. -* This buffer allocates memory -* that must be freed elsewhere. -*/ - @Namespace("shape") public static native void copyTo(int length, @Cast("const Nd4jLong*") LongPointer from, @Cast("Nd4jLong*") LongPointer to, @Cast("Nd4jLong*") LongPointer indexes); - @Namespace("shape") public static native void copyTo(int length, @Cast("const Nd4jLong*") LongBuffer from, @Cast("Nd4jLong*") LongBuffer to, @Cast("Nd4jLong*") LongBuffer indexes); - @Namespace("shape") public static native void copyTo(int length, @Cast("const Nd4jLong*") long[] from, @Cast("Nd4jLong*") long[] to, @Cast("Nd4jLong*") long[] indexes); +/** + * Return a copy of a buffer. + * This buffer allocates memory + * that must be freed elsewhere. + */ +@Namespace("shape") public static native void copyTo(int length, @Cast("const Nd4jLong*") LongPointer from, @Cast("Nd4jLong*") LongPointer to, + @Cast("Nd4jLong*") LongPointer indexes); +@Namespace("shape") public static native void copyTo(int length, @Cast("const Nd4jLong*") LongBuffer from, @Cast("Nd4jLong*") LongBuffer to, + @Cast("Nd4jLong*") LongBuffer indexes); +@Namespace("shape") public static native void copyTo(int length, @Cast("const Nd4jLong*") long[] from, @Cast("Nd4jLong*") long[] to, + @Cast("Nd4jLong*") long[] indexes); /** * Permute the given strides @@ -7431,24 +8083,28 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * and all must be filled in) * @return the rearranged array */ - //ND4J_EXPORT _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int shapeRank, Nd4jLong *rearrange); +// SD_EXPORT _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int +// shapeRank, Nd4jLong *rearrange); /** * Return the slice (shape + 1 in pointer arithmetic) * @param shape the shape to take the slice of * @return the shape array - the first entry */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer slice(@Cast("Nd4jLong*") LongPointer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer slice(@Cast("Nd4jLong*") LongBuffer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] slice(@Cast("Nd4jLong*") long[] shape); - - @Namespace("shape") public static native int slices(@Cast("Nd4jLong*") LongPointer shapeBuffer); - @Namespace("shape") public static native int slices(@Cast("Nd4jLong*") LongBuffer shapeBuffer); - @Namespace("shape") public static native int slices(@Cast("Nd4jLong*") long[] shapeBuffer); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer sliceOfShapeBuffer(@Cast("Nd4jLong") long sliceIdx, @Cast("Nd4jLong*") LongPointer shapeBuffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer sliceOfShapeBuffer(@Cast("Nd4jLong") long sliceIdx, @Cast("Nd4jLong*") LongBuffer shapeBuffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] sliceOfShapeBuffer(@Cast("Nd4jLong") long sliceIdx, @Cast("Nd4jLong*") long[] shapeBuffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer slice(@Cast("Nd4jLong*") LongPointer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer slice(@Cast("Nd4jLong*") LongBuffer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] slice(@Cast("Nd4jLong*") long[] shape); + +@Namespace("shape") public static native int slices(@Cast("Nd4jLong*") LongPointer shapeBuffer); +@Namespace("shape") public static native int slices(@Cast("Nd4jLong*") LongBuffer shapeBuffer); +@Namespace("shape") public static native int slices(@Cast("Nd4jLong*") long[] shapeBuffer); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer sliceOfShapeBuffer(@Cast("Nd4jLong") long sliceIdx, + @Cast("Nd4jLong*") LongPointer shapeBuffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer sliceOfShapeBuffer(@Cast("Nd4jLong") long sliceIdx, + @Cast("Nd4jLong*") LongBuffer shapeBuffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] sliceOfShapeBuffer(@Cast("Nd4jLong") long sliceIdx, + @Cast("Nd4jLong*") long[] shapeBuffer); /** * Returns the length of the * shape information buffer: @@ -7457,35 +8113,35 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * info length for * @return rank * 2 + 4 */ - @Namespace("shape") public static native int shapeInfoLength(int rank); +@Namespace("shape") public static native int shapeInfoLength(int rank); - @Namespace("shape") public static native int shapeInfoLength(@Cast("Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native int shapeInfoLength(@Cast("Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native int shapeInfoLength(@Cast("Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native int shapeInfoLength(@Cast("Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native int shapeInfoLength(@Cast("Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native int shapeInfoLength(@Cast("Nd4jLong*") long[] shapeInfo); - @Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(int rank); +@Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(int rank); - @Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("size_t") long shapeInfoByteLength(@Cast("const Nd4jLong*") long[] shapeInfo); /** * Returns the rank portion of * an information buffer */ - @Namespace("shape") public static native int rank(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native int rank(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native int rank(@Cast("const Nd4jLong*") long[] shapeInfo); - @Namespace("shape") public static native int rank(@Const IntPointer shapeInfo); - @Namespace("shape") public static native int rank(@Const IntBuffer shapeInfo); - @Namespace("shape") public static native int rank(@Const int[] shapeInfo); +@Namespace("shape") public static native int rank(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native int rank(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native int rank(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native int rank(@Const IntPointer shapeInfo); +@Namespace("shape") public static native int rank(@Const IntBuffer shapeInfo); +@Namespace("shape") public static native int rank(@Const int[] shapeInfo); - /** - * returns pointer on elementWiseStride - */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo); +/** + * returns pointer on elementWiseStride + */ +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ews(@Cast("Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ews(@Cast("Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ews(@Cast("Nd4jLong*") long[] shapeInfo); /** * Converts a raw int buffer of the layout: @@ -7497,81 +8153,83 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * * where shape and stride are both straight int pointers */ - @Namespace("shape") public static native ShapeInformation infoFromBuffer(@Cast("Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native ShapeInformation infoFromBuffer(@Cast("Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native ShapeInformation infoFromBuffer(@Cast("Nd4jLong*") long[] buffer); +@Namespace("shape") public static native ShapeInformation infoFromBuffer(@Cast("Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native ShapeInformation infoFromBuffer(@Cast("Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native ShapeInformation infoFromBuffer(@Cast("Nd4jLong*") long[] buffer); /** * Returns the stride portion of an information * buffer */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer stride(@Cast("Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer stride(@Cast("Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] stride(@Cast("Nd4jLong*") long[] buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer stride(@Cast("Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer stride(@Cast("Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] stride(@Cast("Nd4jLong*") long[] buffer); /** * Compute the length of the given shape */ - @Namespace("shape") public static native @Cast("bool") boolean isEmpty(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isEmpty(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("bool") boolean isEmpty(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isEmpty(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isEmpty(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("bool") boolean isEmpty(@Cast("const Nd4jLong*") long[] shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long length(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long length(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long length(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long length(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long length(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long length(@Cast("const Nd4jLong*") long[] shapeInfo); /*** * Returns the offset portion of an information buffer */ - @Namespace("shape") public static native @Cast("Nd4jLong") long offset(@Cast("Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong") long offset(@Cast("Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong") long offset(@Cast("Nd4jLong*") long[] buffer); +@Namespace("shape") public static native @Cast("Nd4jLong") long offset(@Cast("Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong") long offset(@Cast("Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong") long offset(@Cast("Nd4jLong*") long[] buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") @ByRef LongPointer extra(@Cast("Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") @ByRef LongBuffer extra(@Cast("Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") @ByRef long[] extra(@Cast("Nd4jLong*") long[] buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") @ByRef LongPointer extra(@Cast("Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") @ByRef LongBuffer extra(@Cast("Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") @ByRef long[] extra(@Cast("Nd4jLong*") long[] buffer); /** * Returns the ordering * for this shape information buffer */ - @Namespace("shape") public static native char order(@Cast("const Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native char order(@Cast("const Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native char order(@Cast("const Nd4jLong*") long[] buffer); +@Namespace("shape") public static native char order(@Cast("const Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native char order(@Cast("const Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native char order(@Cast("const Nd4jLong*") long[] buffer); /** * Returns the type */ - @Namespace("shape") public static native @Cast("Nd4jLong") long type(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long type(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long type(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long type(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long type(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long type(@Cast("const Nd4jLong*") long[] shapeInfo); /** * Returns the element wise stride for this information * buffer */ - @Namespace("shape") public static native @Cast("Nd4jLong") long elementWiseStride(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long elementWiseStride(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long elementWiseStride(@Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long elementWiseStride(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long elementWiseStride(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long elementWiseStride(@Cast("const Nd4jLong*") long[] shapeInfo); - - /** +/** * Returns the element wise stride for this information * buffer - * relative to a dimension and ordering for a reduction index + * relative to a dimension and ordering for a reduction index */ - @Namespace("shape") public static native @Cast("Nd4jLong") long reductionIndexElementWiseStride(@Cast("Nd4jLong*") LongPointer buffer, IntPointer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long reductionIndexElementWiseStride(@Cast("Nd4jLong*") LongBuffer buffer, IntBuffer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long reductionIndexElementWiseStride(@Cast("Nd4jLong*") long[] buffer, int[] dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long reductionIndexElementWiseStride( + @Cast("Nd4jLong*") LongPointer buffer, IntPointer dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long reductionIndexElementWiseStride( + @Cast("Nd4jLong*") LongBuffer buffer, IntBuffer dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long reductionIndexElementWiseStride( + @Cast("Nd4jLong*") long[] buffer, int[] dimension, int dimensionLength); /** * Returns whether * the given shape info buffer * represents a scalar shape */ - @Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") LongPointer info); - @Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") LongBuffer info); - @Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") long[] info); +@Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") LongPointer info); +@Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") LongBuffer info); +@Namespace("shape") public static native int isScalar(@Cast("const Nd4jLong*") long[] info); /** * Returns whether @@ -7579,7 +8237,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * represents a scalar * shape or not */ - @Namespace("shape") public static native int isScalar(ShapeInformation info); +@Namespace("shape") public static native int isScalar(ShapeInformation info); /** * Return a copy of this array with the @@ -7594,7 +8252,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * item */ - /** +/** * Return a copy of this array with the * given index omitted * @@ -7607,20 +8265,26 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * item */ - /** - * Iterate over a given set of indexes - * the begin and end indexes are 0 based. - * 1 padding is automatically assumed for the ending. - * - * For example if you want to iterate over 0 to 4 - * it will go to 4 rather than 3. - * - * indexes should be the indexes to exclude - * indexes length should be the length of indexes - */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer everyIndexBut(@Cast("const Nd4jLong*") LongPointer indexes,int indexesLength,int begin,int end); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer everyIndexBut(@Cast("const Nd4jLong*") LongBuffer indexes,int indexesLength,int begin,int end); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] everyIndexBut(@Cast("const Nd4jLong*") long[] indexes,int indexesLength,int begin,int end); +/** + * Iterate over a given set of indexes + * the begin and end indexes are 0 based. + * 1 padding is automatically assumed for the ending. + * + * For example if you want to iterate over 0 to 4 + * it will go to 4 rather than 3. + * + * indexes should be the indexes to exclude + * indexes length should be the length of indexes + */ +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer everyIndexBut(@Cast("const Nd4jLong*") LongPointer indexes, + int indexesLength, int begin, + int end); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer everyIndexBut(@Cast("const Nd4jLong*") LongBuffer indexes, + int indexesLength, int begin, + int end); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] everyIndexBut(@Cast("const Nd4jLong*") long[] indexes, + int indexesLength, int begin, + int end); /** * Computes the offset for accessing @@ -7630,7 +8294,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint //#ifdef __CUDACC__ // __device__ //#endif -// ND4J_EXPORT int tadOffset(shape::ShapeInformation *xInfo, int offset); +// SD_EXPORT int tadOffset(shape::ShapeInformation *xInfo, int offset); /** * Returns a shape @@ -7640,15 +8304,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * for the shape to be returned as * @return the new shape */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ensureVectorShape(@Cast("Nd4jLong*") LongPointer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ensureVectorShape(@Cast("Nd4jLong*") LongBuffer shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] ensureVectorShape(@Cast("Nd4jLong*") long[] shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ensureVectorShape(@Cast("Nd4jLong*") LongPointer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ensureVectorShape(@Cast("Nd4jLong*") LongBuffer shape); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ensureVectorShape(@Cast("Nd4jLong*") long[] shape); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createScalarShapeInfo(); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createScalarShapeInfo(); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createScalarShapeInfo(@Cast("Nd4jLong*") LongPointer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createScalarShapeInfo(@Cast("Nd4jLong*") LongBuffer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] createScalarShapeInfo(@Cast("Nd4jLong*") long[] ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createScalarShapeInfo(@Cast("Nd4jLong*") LongPointer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createScalarShapeInfo(@Cast("Nd4jLong*") LongBuffer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] createScalarShapeInfo(@Cast("Nd4jLong*") long[] ret); /** * Generate an int buffer @@ -7666,9 +8330,12 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * Keep the given indexes * in the data */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer keep(@Cast("Nd4jLong*") LongPointer data, @Const IntPointer index, int indexLength, int dataLength); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer keep(@Cast("Nd4jLong*") LongBuffer data, @Const IntBuffer index, int indexLength, int dataLength); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] keep(@Cast("Nd4jLong*") long[] data, @Const int[] index, int indexLength, int dataLength); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer keep(@Cast("Nd4jLong*") LongPointer data, @Const IntPointer index, + int indexLength, int dataLength); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer keep(@Cast("Nd4jLong*") LongBuffer data, @Const IntBuffer index, + int indexLength, int dataLength); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] keep(@Cast("Nd4jLong*") long[] data, @Const int[] index, + int indexLength, int dataLength); /** * Generate reverse copy of the data @@ -7706,9 +8373,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return the length per slice of the given shape * along the given dimension */ - @Namespace("shape") public static native @Cast("Nd4jLong") long lengthPerSlice(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Const IntPointer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long lengthPerSlice(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Const IntBuffer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long lengthPerSlice(int rank, @Cast("const Nd4jLong*") long[] shape, @Const int[] dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long lengthPerSlice(int rank, @Cast("const Nd4jLong*") LongPointer shape, + @Const IntPointer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long lengthPerSlice(int rank, @Cast("const Nd4jLong*") LongBuffer shape, + @Const IntBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long lengthPerSlice(int rank, @Cast("const Nd4jLong*") long[] shape, + @Const int[] dimension, + int dimensionLength); /** * calculates the offset for a tensor @@ -7717,27 +8390,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param tensorShape * @return */ - @Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor(int rank, - int index, - @Cast("const Nd4jLong*") LongPointer shape, - @Cast("const Nd4jLong*") LongPointer tensorShape, - int tensorShapeLength, - @Const IntPointer dimension, - int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor(int rank, - int index, - @Cast("const Nd4jLong*") LongBuffer shape, - @Cast("const Nd4jLong*") LongBuffer tensorShape, - int tensorShapeLength, - @Const IntBuffer dimension, - int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor(int rank, - int index, - @Cast("const Nd4jLong*") long[] shape, - @Cast("const Nd4jLong*") long[] tensorShape, - int tensorShapeLength, - @Const int[] dimension, - int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor( + int rank, int index, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer tensorShape, + int tensorShapeLength, @Const IntPointer dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor( + int rank, int index, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer tensorShape, + int tensorShapeLength, @Const IntBuffer dimension, int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor( + int rank, int index, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] tensorShape, + int tensorShapeLength, @Const int[] dimension, int dimensionLength); /** * calculates the offset for a tensor @@ -7746,53 +8407,55 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param tensorShape * @return */ - @Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor(int index,int tensorLength,int lengthPerSlice2); +@Namespace("shape") public static native @Cast("Nd4jLong") long sliceOffsetForTensor(int index, int tensorLength, + int lengthPerSlice2); /** * Computes the tensor along dimension * offset * @param index the index to get the offset for the tad for * @param rank the rank of the shapes and strides * @param info the shape information to use for tad - * @param dimension the dimensions to use for computing the tensor along dimensions + * @param dimension the dimensions to use for computing the tensor along + * dimensions */ -// ND4J_EXPORT _CUDA_HD int offset(int index, +// SD_EXPORT _CUDA_HD int offset(int index, // int rank, // shape::ShapeInformation *info, // Nd4jLong *dimension, // int dimensionLength); - /** * Computes the number * of tensors along * a given dimension */ - @Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(int rank, - int length, - @Cast("Nd4jLong*") LongPointer shape, - IntPointer dimension, - int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(int rank, - int length, - @Cast("Nd4jLong*") LongBuffer shape, - IntBuffer dimension, - int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(int rank, - int length, - @Cast("Nd4jLong*") long[] shape, - int[] dimension, - int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(int rank, int length, + @Cast("Nd4jLong*") LongPointer shape, + IntPointer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(int rank, int length, + @Cast("Nd4jLong*") LongBuffer shape, + IntBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(int rank, int length, + @Cast("Nd4jLong*") long[] shape, + int[] dimension, + int dimensionLength); /** * Computes the number * of tensors along * a given dimension */ - @Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength); - @Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength); - - +@Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(@Cast("Nd4jLong*") LongPointer shapeInfo, + IntPointer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(@Cast("Nd4jLong*") LongBuffer shapeInfo, + IntBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native @Cast("Nd4jLong") long tensorsAlongDimension(@Cast("Nd4jLong*") long[] shapeInfo, + int[] dimension, + int dimensionLength); /** * Returns the tensor along dimension @@ -7802,26 +8465,30 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param i * @return */ - @Namespace("shape") public static native int tadForBlockIndex(int blockSize, int blockIdx, int i); +@Namespace("shape") public static native int tadForBlockIndex(int blockSize, int blockIdx, int i); /** * Computes the number of tads per block * */ - @Namespace("shape") public static native int tadsPerBlock(int blockSize, int tads); +@Namespace("shape") public static native int tadsPerBlock(int blockSize, int tads); -// ND4J_EXPORT _CUDA_HD Nd4jLong *tadShapeInfo(int index, Nd4jLong *xShapeInfo, Nd4jLong *dimension, +// SD_EXPORT _CUDA_HD Nd4jLong *tadShapeInfo(int index, Nd4jLong *xShapeInfo, +// Nd4jLong *dimension, // int dimensionLength); /** * Returns a shape buffer * for the shape information metadata. */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer toShapeBuffer( ShapeInformation info); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer toShapeBuffer(ShapeInformation info); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer toShapeBuffer( ShapeInformation info, @Cast("Nd4jLong*") LongPointer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer toShapeBuffer( ShapeInformation info, @Cast("Nd4jLong*") LongBuffer ret); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] toShapeBuffer( ShapeInformation info, @Cast("Nd4jLong*") long[] ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer toShapeBuffer(ShapeInformation info, + @Cast("Nd4jLong*") LongPointer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer toShapeBuffer(ShapeInformation info, + @Cast("Nd4jLong*") LongBuffer ret); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] toShapeBuffer(ShapeInformation info, + @Cast("Nd4jLong*") long[] ret); /** * Returns the number of elements per thread @@ -7872,25 +8539,29 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param numElementsPerTad the number of elements * per tad */ - @Namespace("shape") public static native int tadIndex(int i, int elementWiseStride, int numElementsPerTad); +@Namespace("shape") public static native int tadIndex(int i, int elementWiseStride, + int numElementsPerTad); /** * Map a tad to a * reduction index. * @param tadIndexForOriginal the original tad index for the * split up problem (eg: split is dimension 3 mapping to a 2,3 problem) - * @param tadsForReduced the number of tads for the shrunk down problem (eg: 2,3) + * @param tadsForReduced the number of tads for the shrunk down problem (eg: + * 2,3) * @param tadsForOriginal the number of tads for the smaller problem (eg: 3) */ - @Namespace("shape") public static native int reductionIndexForTad(int tadIndexForOriginal, int tadsForReduced, - int tadsForOriginal); +@Namespace("shape") public static native int reductionIndexForTad(int tadIndexForOriginal, + int tadsForReduced, + int tadsForOriginal); /** * Computes the number of tads * per reduce index for the * reduction tad. */ - @Namespace("shape") public static native int tadsPerReduceIndex(int tadsForReduce, int tadsForOriginal); +@Namespace("shape") public static native int tadsPerReduceIndex(int tadsForReduce, + int tadsForOriginal); /** * Maps a linear index to a reduction index @@ -7900,351 +8571,653 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param tadNum the number of tads for the shrunken problem * @param originalTadNum the tad number for the reduced version of the problem */ - @Namespace("shape") public static native int reductionIndexForLinear(int i, int elementWiseStride, int numElementsPerTad, - int tadNum, int originalTadNum); +@Namespace("shape") public static native int reductionIndexForLinear(int i, int elementWiseStride, + int numElementsPerTad, + int tadNum, int originalTadNum); /** * Returns the prod of the data * up to the given length */ - @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongPointer data, int length); - @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongBuffer data, int length); - @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") long[] data, int length); - - /** - * Returns the rear most left over item not present in - * the dimension array. This assumes that the dimension array is sorted. - * - * For example, given a dimension array of: - * 0,2 - * - * and - * - * 12,4,2,1 in data - * - * You end up with 1 (data[3]) - * since the first item won't match - * the last item of the dimension array - */ - -// ND4J_EXPORT _CUDA_HD int rearMostLeftOverItem(Nd4jLong *data,int length,Nd4jLong *dimension,int dimensionLength); - - /** -* Get an offset for retrieval -* from a data buffer -* based on the given -* shape stride and given indices -* @param baseOffset the offset to start from -* @param shape the shape of the array -* @param stride the stride of the array -* @param indices the indices to iterate over -* @return the double at the specified index -*/ +@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongPointer data, int length); +@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongBuffer data, int length); +@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") long[] data, int length); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, @Cast("Nd4jLong") long baseOffset/*=0*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, @Cast("Nd4jLong") long baseOffset/*=0*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, @Cast("Nd4jLong") long baseOffset/*=0*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords, @Cast("Nd4jLong") long baseOffset/*=0*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords, @Cast("Nd4jLong") long baseOffset/*=0*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords, @Cast("Nd4jLong") long baseOffset/*=0*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] createShapeInfo(@Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride, int rank); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank, @Cast("Nd4jLong*") LongPointer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank, @Cast("Nd4jLong*") LongBuffer buffer); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] createShapeInfo(@Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride, int rank, @Cast("Nd4jLong*") long[] buffer); - - /** - * Convert a linear index to the corresponding coordinates - * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1] - */ - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords); - - @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords); - @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords); - @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords); - @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords); - - /** - * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! - */ - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords, int dimsSize, @Const IntPointer tadDims); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords, int dimsSize, @Const IntBuffer tadDims); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords, int dimsSize, @Const int[] tadDims); - - /** - * Convert coordinates to the corresponding linear index (sequence number in other words) - * for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned - */ - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Const IntPointer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Const IntBuffer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Const int[] coords); - /** - * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! - */ - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords, int dimsSize, @Const IntPointer tadDims); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords, int dimsSize, @Const IntBuffer tadDims); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords, int dimsSize, @Const int[] tadDims); - - /** - * increment n-dimensional array by one iteration by changing coord appropriately - * for example we have array with shape {2, 3}: - * - if input coord = {0,1}, then output coord = {0,2} - * - if input coord = {0,2}, then output coord = {1,0} - * so the aim is to produce following subsequence of coord: {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2} - */ +/** + * Returns the rear most left over item not present in + * the dimension array. This assumes that the dimension array is sorted. + * + * For example, given a dimension array of: + * 0,2 + * + * and + * + * 12,4,2,1 in data + * + * You end up with 1 (data[3]) + * since the first item won't match + * the last item of the dimension array + */ - /* calculates an array buffer offset for given "index" using following formula: offset = coord_0*stride_0 + coord_1*stride_1 + ... + coord_{rank-1}*stride_{rank-1} - */ - @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntPointer shapeInfo); - @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntBuffer shapeInfo); - @Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") int[] shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer lShapeInfo, @Cast("const uint*") IntPointer uShapeInfo, @Cast("const bool") boolean useUnsigned); - @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer lShapeInfo, @Cast("const uint*") IntBuffer uShapeInfo, @Cast("const bool") boolean useUnsigned); - @Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] lShapeInfo, @Cast("const uint*") int[] uShapeInfo, @Cast("const bool") boolean useUnsigned); - - @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") long[] shapeInfo); - - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("const Nd4jLong*") long[] shapeInfo); - - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, @Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, @Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, @Cast("const Nd4jLong*") long[] shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, @Cast("const Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, @Cast("const Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, @Cast("const Nd4jLong*") long[] shapeInfo); - - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides); - @Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides); - - @Namespace("shape") public static native void printIntArray(@Cast("const Nd4jLong*") LongPointer arr, int length); - @Namespace("shape") public static native void printIntArray(@Cast("const Nd4jLong*") LongBuffer arr, int length); - @Namespace("shape") public static native void printIntArray(@Cast("const Nd4jLong*") long[] arr, int length); - @Namespace("shape") public static native void printIntArray(@Const IntPointer arr, int length); - @Namespace("shape") public static native void printIntArray(@Const IntBuffer arr, int length); - @Namespace("shape") public static native void printIntArray(@Const int[] arr, int length); - - @Namespace("shape") public static native void printArray(FloatPointer arr,int length); - @Namespace("shape") public static native void printArray(FloatBuffer arr,int length); - @Namespace("shape") public static native void printArray(float[] arr,int length); - - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBufferOfNpy(int rank, @Cast("unsigned int*") IntPointer shape,@Cast("bool") boolean fortranOrder); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBufferOfNpy(int rank, @Cast("unsigned int*") IntBuffer shape,@Cast("bool") boolean fortranOrder); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBufferOfNpy(int rank, @Cast("unsigned int*") int[] shape,@Cast("bool") boolean fortranOrder); - -// ND4J_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpyBuffer(char *buffer); - - - // this function checks the consistence of dimensions with array rank (negative dimensions, too large dimensions, too big number of dimensions) - // also sort input array of dimensions, this operation is also necessary for creating TAD object - @Namespace("shape") public static native void checkDimensions(int rank, @StdVector IntPointer dimensions); - @Namespace("shape") public static native void checkDimensions(int rank, @StdVector IntBuffer dimensions); - @Namespace("shape") public static native void checkDimensions(int rank, @StdVector int[] dimensions); - - // function calculates linear index of array min, min is sub-array of max, index to be returned is min-array's index and corresponds to maxIdx of max array - // dimsToExclude - should be sorted in increasing order - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); - - // function calculates absolute offset of min array, min is sub-array of max, offset to be returned corresponds to maxIdx of max array - // dimsToExclude - should be sorted in increasing order - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); - - // max array is outer for min array, min array is sub-array of max array - // function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) - // dimsToExclude - should be sorted in increasing order - // dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank - @Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); - @Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); - @Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); - - // calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array - // dimsToExclude - should be sorted in increasing order - @Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); - @Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); - @Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); - - // calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array - // maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand - // dimsToExclude - should be sorted in increasing order - // memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff, @Const int[] dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff); - - // calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array - // rank is equal to size of shape - @Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides, @Cast("Nd4jLong*") LongPointer offsets, byte order/*='c'*/); - @Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides, @Cast("Nd4jLong*") LongPointer offsets); - @Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides, @Cast("Nd4jLong*") LongBuffer offsets, byte order/*='c'*/); - @Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides, @Cast("Nd4jLong*") LongBuffer offsets); - @Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides, @Cast("Nd4jLong*") long[] offsets, byte order/*='c'*/); - @Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides, @Cast("Nd4jLong*") long[] offsets); - @Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer offsets, byte order/*='c'*/); - @Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer offsets); - @Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer offsets, byte order/*='c'*/); - @Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer offsets); - @Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] offsets, byte order/*='c'*/); - @Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] offsets); - // ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order = 'c'); - // ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order = 'c'); - @Namespace("shape") public static native void shapeOldScalar(@Cast("sd::DataType") int dtype, @Cast("Nd4jLong*const") LongPointer buffer, byte order); - @Namespace("shape") public static native void shapeOldScalar(@Cast("sd::DataType") int dtype, @Cast("Nd4jLong*const") LongBuffer buffer, byte order); - @Namespace("shape") public static native void shapeOldScalar(@Cast("sd::DataType") int dtype, @Cast("Nd4jLong*const") long[] buffer, byte order); - - // deduce order and element-wise stride - // if array is scalar or unit length vector then ews = 1 and order is preserved - // if array is common vector then ews = stride of non-unity dimension and order is preserved - // if strides are normal/contiguous then ews = 1 and corresponding order is set, otherwise ews = 0 and order is preserved - @Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") LongPointer shapeInfo, byte proposedOrder, int numOfNonUnitDims, @Cast("const Nd4jLong*") LongPointer shapeNoUnities, @Cast("const Nd4jLong*") LongPointer stridesNoUnities); - @Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") LongBuffer shapeInfo, byte proposedOrder, int numOfNonUnitDims, @Cast("const Nd4jLong*") LongBuffer shapeNoUnities, @Cast("const Nd4jLong*") LongBuffer stridesNoUnities); - @Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") long[] shapeInfo, byte proposedOrder, int numOfNonUnitDims, @Cast("const Nd4jLong*") long[] shapeNoUnities, @Cast("const Nd4jLong*") long[] stridesNoUnities); - @Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") LongPointer shapeInfo); - @Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") LongBuffer shapeInfo); - @Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") long[] shapeInfo); - - /** - * processes whole set of sub-arrays - * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) and their buffer offsets (each sub-array has its own unique offset from original this-buffer) - * arguments: - * wholeShapeInfo - original shapeInfo of whole array - * numOfSubArrs - number of sub-arrays, size of subArrOffsets is equal to numOfSubArrs - * dimsSize - size of dimsToExclude, if dimsSize = array rank or dimsSize = 0 it means sub-array is whole array, copy of wholeShapeInfo and one zero offset will be returned - * dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] - * subArrShapeInfo - output argument, contains shapeInfo (same for all sub-arrays) - * subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer - * keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} - */ - @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, @Cast("Nd4jLong*") LongPointer subArrOffsets); - @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, @Cast("Nd4jLong*") LongBuffer subArrOffsets); - @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); - @Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets(@Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, @Cast("Nd4jLong*") long[] subArrOffsets); - - /** - * processes only one sub-array, evaluates shapeInfo of sub-array and its buffer offset from original array - * arguments: - * idx - input argument, intervals of indexes which define the sub-array to point on, - * when isStrided = false then idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * maxRank) - * when isStrided = true then idx has form {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and length (3 * maxRank) - * when (dimStart == dimEnd) then whole range will be used for current dimension - * maxShapeInfo - input argument, shapeInfo of original array - * minShapeInfo - output argument, shapeInfo of sub-array to be deduced - * minOffset - output argument, offset of sub-array buffer offsets from original buffer - * keepUnitiesInShape - input argument, if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} - * isStrided - input argument, if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd, - * numOfUntiesInMinShape - input argument, number of occurrences in idx when (dimEnd - dimStart) = 1 - */ - @Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/); - @Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongPointer minOffset); - @Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/); - @Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") @ByRef LongBuffer minOffset); - @Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/); - @Namespace("shape") public static native void calcSubArrShapeInfoAndOffset(@Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") @ByRef long[] minOffset); - - /** - * for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99} - * then output shapeNoUnities will contain {2,4, 4,1} - that is only shape and strides, no rank/type/ews/order - * stridesNoUnities will point on strides in shapeNoUnities that is on {4,1} - * returns number of non-unity dimensions in inShapeInfo - * if there is no unities in inShapeInfo, then no copy procedure will be performed and shapeNoUnities/stridesNoUnities will point on corresponding places in inShapeInfo - */ - @Namespace("shape") public static native int excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongPointer inShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef LongPointer shapeNoUnities, @Cast("Nd4jLong*&") @ByPtrRef LongPointer stridesNoUnities); - @Namespace("shape") public static native int excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef LongBuffer shapeNoUnities, @Cast("Nd4jLong*&") @ByPtrRef LongBuffer stridesNoUnities); - @Namespace("shape") public static native int excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, @Cast("Nd4jLong*&") @ByPtrRef long[] shapeNoUnities, @Cast("Nd4jLong*&") @ByPtrRef long[] stridesNoUnities); - - /** - * for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {1,3}, dimsSize = 2 - * then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99} - */ - @Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongPointer inShapeInfo, int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer outShapeInfo); - @Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer outShapeInfo); - @Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] outShapeInfo); - - /** - * get stride over contiguous axis (contiguous axis must have stride = 1) - * for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1) - */ - // INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo); - - - - - - -//END HEADERS - - - //BEGIN IMPLEMENTATIONS +// SD_EXPORT _CUDA_HD int rearMostLeftOverItem(Nd4jLong *data,int +// length,Nd4jLong *dimension,int dimensionLength); +/** + * Get an offset for retrieval + * from a data buffer + * based on the given + * shape stride and given indices + * @param baseOffset the offset to start from + * @param shape the shape of the array + * @param stride the stride of the array + * @param indices the indices to iterate over + * @return the double at the specified index + */ +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const Nd4jLong*") LongPointer coords, + @Cast("Nd4jLong") long baseOffset/*=0*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const Nd4jLong*") LongPointer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const Nd4jLong*") LongBuffer coords, + @Cast("Nd4jLong") long baseOffset/*=0*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const Nd4jLong*") LongBuffer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const Nd4jLong*") long[] coords, + @Cast("Nd4jLong") long baseOffset/*=0*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const Nd4jLong*") long[] coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Const IntPointer coords, + @Cast("Nd4jLong") long baseOffset/*=0*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Const IntPointer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Const IntBuffer coords, + @Cast("Nd4jLong") long baseOffset/*=0*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Const IntBuffer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, + @Const int[] coords, + @Cast("Nd4jLong") long baseOffset/*=0*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, + @Const int[] coords); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, + int rank); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, + int rank); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] createShapeInfo(@Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride, + int rank); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, + int rank, @Cast("Nd4jLong*") LongPointer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, + int rank, @Cast("Nd4jLong*") LongBuffer buffer); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] createShapeInfo(@Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] stride, + int rank, @Cast("Nd4jLong*") long[] buffer); -// #ifdef __CUDACC__ -// #endif +/** + * Convert a linear index to the corresponding coordinates + * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, + * 1] + */ +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("Nd4jLong*") LongPointer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("Nd4jLong*") LongBuffer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("Nd4jLong*") long[] coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, + IntPointer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, + IntBuffer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, + int[] coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, + @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, + @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, + @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, + @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, + @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, + @Cast("const Nd4jLong*") long[] shape, int[] coords); + +@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, + @Cast("const Nd4jLong") long index, + @Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("Nd4jLong*") LongPointer coords); +@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, + @Cast("const Nd4jLong") long index, + @Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("Nd4jLong*") LongBuffer coords); +@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, + @Cast("const Nd4jLong") long index, + @Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("Nd4jLong*") long[] coords); +@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, + @Cast("const Nd4jLong") long index, + @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords); +@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, + @Cast("const Nd4jLong") long index, + @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords); +@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, + @Cast("const Nd4jLong") long index, + @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords); /** -* Length of a tad given -* the shape information -*/ + * take into account only dimensions stored in tadDims, tadDims must be sorted + * in increasing order! + */ +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, + IntPointer coords, int dimsSize, + @Const IntPointer tadDims); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, + IntBuffer coords, int dimsSize, + @Const IntBuffer tadDims); +@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, + int[] coords, int dimsSize, + @Const int[] tadDims); + +/** + * Convert coordinates to the corresponding linear index (sequence number in + * other words) for example if shape is {2, 4} and coordinates [1, 1] then index + * 5 is returned + */ +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const Nd4jLong*") LongPointer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const Nd4jLong*") LongBuffer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const Nd4jLong*") long[] coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Const IntPointer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Const IntBuffer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, + @Const int[] coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, + @Const IntPointer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, + @Const IntBuffer coords); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, + @Const int[] coords); +/** + * take into account only dimensions stored in tadDims, tadDims must be sorted + * in increasing order! + */ +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Const IntPointer coords, int dimsSize, + @Const IntPointer tadDims); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Const IntBuffer coords, int dimsSize, + @Const IntBuffer tadDims); +@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, + @Const int[] coords, int dimsSize, + @Const int[] tadDims); + +/** + * increment n-dimensional array by one iteration by changing coord + * appropriately for example we have array with shape {2, 3}: + * - if input coord = {0,1}, then output coord = {0,2} + * - if input coord = {0,2}, then output coord = {1,0} + * so the aim is to produce following subsequence of coord: {0,0}, {0,1}, {0,2}, + * {1,0}, {1,1}, {1,2} + */ + +/* calculates an array buffer offset for given "index" using following formula: + * offset = coord_0*stride_0 + coord_1*stride_1 + ... + + * coord_{rank-1}*stride_{rank-1} + */ +@Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntPointer shapeInfo); +@Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") IntBuffer shapeInfo); +@Namespace("shape") public static native @Cast("uint") int getIndexOffset(@Cast("uint") int index, @Cast("const uint*") int[] shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, + @Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, + @Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long getIndexOffset(@Cast("Nd4jLong") long index, + @Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, + @Cast("const Nd4jLong*") LongPointer lShapeInfo, + @Cast("const uint*") IntPointer uShapeInfo, + @Cast("const bool") boolean useUnsigned); +@Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, + @Cast("const Nd4jLong*") LongBuffer lShapeInfo, + @Cast("const uint*") IntBuffer uShapeInfo, + @Cast("const bool") boolean useUnsigned); +@Namespace("shape") public static native @Cast("Nd4jLong") long indexOffset(@Cast("Nd4jLong") long index, + @Cast("const Nd4jLong*") long[] lShapeInfo, + @Cast("const uint*") int[] uShapeInfo, + @Cast("const bool") boolean useUnsigned); + +@Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native void printShapeInfo(@Cast("Nd4jLong*") long[] shapeInfo); + +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("const Nd4jLong*") long[] shapeInfo); + +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, + @Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, + @Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, + @Cast("const Nd4jLong*") long[] shapeInfo); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, + @Cast("const Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, + @Cast("const Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, + @Cast("const Nd4jLong*") long[] shapeInfo); + +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, + @Cast("const Nd4jLong*") LongPointer shape, + @Cast("const Nd4jLong*") LongPointer strides); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, + @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("const Nd4jLong*") LongBuffer strides); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, + @Cast("const Nd4jLong*") long[] shape, + @Cast("const Nd4jLong*") long[] strides); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, + @Cast("const Nd4jLong*") LongPointer shape, + @Cast("const Nd4jLong*") LongPointer strides); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") String msg, int rank, + @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("const Nd4jLong*") LongBuffer strides); +@Namespace("shape") public static native void printShapeInfoLinear(@Cast("char*") BytePointer msg, int rank, + @Cast("const Nd4jLong*") long[] shape, + @Cast("const Nd4jLong*") long[] strides); + +@Namespace("shape") public static native void printIntArray(@Cast("const Nd4jLong*") LongPointer arr, int length); +@Namespace("shape") public static native void printIntArray(@Cast("const Nd4jLong*") LongBuffer arr, int length); +@Namespace("shape") public static native void printIntArray(@Cast("const Nd4jLong*") long[] arr, int length); +@Namespace("shape") public static native void printIntArray(@Const IntPointer arr, int length); +@Namespace("shape") public static native void printIntArray(@Const IntBuffer arr, int length); +@Namespace("shape") public static native void printIntArray(@Const int[] arr, int length); + +@Namespace("shape") public static native void printArray(FloatPointer arr, int length); +@Namespace("shape") public static native void printArray(FloatBuffer arr, int length); +@Namespace("shape") public static native void printArray(float[] arr, int length); + +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer shapeBufferOfNpy(int rank, @Cast("unsigned int*") IntPointer shape, + @Cast("bool") boolean fortranOrder); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer shapeBufferOfNpy(int rank, @Cast("unsigned int*") IntBuffer shape, + @Cast("bool") boolean fortranOrder); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] shapeBufferOfNpy(int rank, @Cast("unsigned int*") int[] shape, + @Cast("bool") boolean fortranOrder); + +// SD_EXPORT _CUDA_HD Nd4jLong *shapeBufferOfNpyBuffer(char *buffer); + +// this function checks the consistence of dimensions with array rank (negative +// dimensions, too large dimensions, too big number of dimensions) also sort +// input array of dimensions, this operation is also necessary for creating TAD +// object +@Namespace("shape") public static native void checkDimensions(int rank, + @StdVector IntPointer dimensions); +@Namespace("shape") public static native void checkDimensions(int rank, + @StdVector IntBuffer dimensions); +@Namespace("shape") public static native void checkDimensions(int rank, + @StdVector int[] dimensions); + +// function calculates linear index of array min, min is sub-array of max, index +// to be returned is min-array's index and corresponds to maxIdx of max array +// dimsToExclude - should be sorted in increasing order +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo, + @Const IntPointer dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo, + @Const IntBuffer dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo, + @Const int[] dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayIndex(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo); + +// function calculates absolute offset of min array, min is sub-array of max, +// offset to be returned corresponds to maxIdx of max array dimsToExclude - +// should be sorted in increasing order +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo, + @Const IntPointer dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo, + @Const IntBuffer dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo, + @Const int[] dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native @Cast("Nd4jLong") long subArrayOffset(@Cast("const Nd4jLong") long maxIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo); + +// max array is outer for min array, min array is sub-array of max array +// function calculates the coordinates of min array (and saves them into +// minIdxs) given coordinates of max array (already stored in maxIdxs) +// dimsToExclude - should be sorted in increasing order +// dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated +// as maxRank - minRank +@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo, + @Const IntPointer dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo); +@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo, + @Const IntBuffer dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo); +@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo, + @Const int[] dimsToExclude/*=nullptr*/, + int dimsLen/*=-1*/); +@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo); + +// calculate indexes of max-array, these output indexes correspond to one minIdx +// index of min-array which is sub-array of max-array dimsToExclude - should be +// sorted in increasing order +@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo, + @Const IntPointer dimsToExclude/*=nullptr*/); +@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo); +@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo, + @Const IntBuffer dimsToExclude/*=nullptr*/); +@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo); +@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo, + @Const int[] dimsToExclude/*=nullptr*/); +@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo); + +// calculate offsets of max-array, these offsets correspond to one minIdx index +// of min-array which is sub-array of max-array maxOffsets - will contain +// calculated offsets of max-array, buffer for maxOffsets should be allocated +// beforehand dimsToExclude - should be sorted in increasing order memBuff - +// auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments +// storing, should be allocated beforehand +@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, + @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo, + IntPointer memBuff, + @Const IntPointer dimsToExclude/*=nullptr*/); +@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, + @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongPointer maxShapeInfo, + @Cast("const Nd4jLong*") LongPointer minShapeInfo, + IntPointer memBuff); +@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, + @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo, + IntBuffer memBuff, + @Const IntBuffer dimsToExclude/*=nullptr*/); +@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, + @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, + @Cast("const Nd4jLong*") LongBuffer minShapeInfo, + IntBuffer memBuff); +@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, + @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo, + int[] memBuff, + @Const int[] dimsToExclude/*=nullptr*/); +@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, + @Cast("const Nd4jLong") long minIdx, + @Cast("const Nd4jLong*") long[] maxShapeInfo, + @Cast("const Nd4jLong*") long[] minShapeInfo, + int[] memBuff); + +// calculates offsets for entities (elements or sub-arrays), shape in context of +// sub-array means dimensions excluded from outer array rank is equal to size of +// shape +@Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongPointer shape, + @Cast("const Nd4jLong*") LongPointer strides, @Cast("Nd4jLong*") LongPointer offsets, + byte order/*='c'*/); +@Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongPointer shape, + @Cast("const Nd4jLong*") LongPointer strides, @Cast("Nd4jLong*") LongPointer offsets); +@Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("const Nd4jLong*") LongBuffer strides, @Cast("Nd4jLong*") LongBuffer offsets, + byte order/*='c'*/); +@Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") LongBuffer shape, + @Cast("const Nd4jLong*") LongBuffer strides, @Cast("Nd4jLong*") LongBuffer offsets); +@Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") long[] shape, + @Cast("const Nd4jLong*") long[] strides, @Cast("Nd4jLong*") long[] offsets, + byte order/*='c'*/); +@Namespace("shape") public static native void calcOffsets(int rank, @Cast("const Nd4jLong*") long[] shape, + @Cast("const Nd4jLong*") long[] strides, @Cast("Nd4jLong*") long[] offsets); +@Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer offsets, + byte order/*='c'*/); +@Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer offsets); +@Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer offsets, + byte order/*='c'*/); +@Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer offsets); +@Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] offsets, + byte order/*='c'*/); +@Namespace("shape") public static native void calcOffsets(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] offsets); +// SD_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, +// const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order = 'c'); +// SD_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, +// const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, +// Nd4jLong*& zOffsets, const char order = 'c'); +@Namespace("shape") public static native void shapeOldScalar(@Cast("sd::DataType") int dtype, + @Cast("Nd4jLong*const") LongPointer buffer, + byte order); +@Namespace("shape") public static native void shapeOldScalar(@Cast("sd::DataType") int dtype, + @Cast("Nd4jLong*const") LongBuffer buffer, + byte order); +@Namespace("shape") public static native void shapeOldScalar(@Cast("sd::DataType") int dtype, + @Cast("Nd4jLong*const") long[] buffer, + byte order); + +// deduce order and element-wise stride +// if array is scalar or unit length vector then ews = 1 and order is preserved +// if array is common vector then ews = stride of non-unity dimension and order +// is preserved if strides are normal/contiguous then ews = 1 and corresponding +// order is set, otherwise ews = 0 and order is preserved +@Namespace("shape") public static native void checkStridesEwsAndOrder( + @Cast("Nd4jLong*") LongPointer shapeInfo, byte proposedOrder, int numOfNonUnitDims, + @Cast("const Nd4jLong*") LongPointer shapeNoUnities, @Cast("const Nd4jLong*") LongPointer stridesNoUnities); +@Namespace("shape") public static native void checkStridesEwsAndOrder( + @Cast("Nd4jLong*") LongBuffer shapeInfo, byte proposedOrder, int numOfNonUnitDims, + @Cast("const Nd4jLong*") LongBuffer shapeNoUnities, @Cast("const Nd4jLong*") LongBuffer stridesNoUnities); +@Namespace("shape") public static native void checkStridesEwsAndOrder( + @Cast("Nd4jLong*") long[] shapeInfo, byte proposedOrder, int numOfNonUnitDims, + @Cast("const Nd4jLong*") long[] shapeNoUnities, @Cast("const Nd4jLong*") long[] stridesNoUnities); +@Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") LongPointer shapeInfo); +@Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") LongBuffer shapeInfo); +@Namespace("shape") public static native void checkStridesEwsAndOrder(@Cast("Nd4jLong*") long[] shapeInfo); + +/** + * processes whole set of sub-arrays + * evaluates shapeInfo of sub-arrays (all sub-arrays have the same shapeInfo) + * and their buffer offsets (each sub-array has its own unique offset from + * original this-buffer) arguments: wholeShapeInfo - original shapeInfo of whole + * array numOfSubArrs - number of sub-arrays, size of subArrOffsets is equal to + * numOfSubArrs dimsSize - size of dimsToExclude, if dimsSize = array rank or + * dimsSize = 0 it means sub-array is whole array, copy of wholeShapeInfo and + * one zero offset will be returned dimsToExclude - MUST BE SORTED, dimensions + * to evaluate sub-array along, i.e. when shape is [2,3,4,5] and + * dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5] + * subArrShapeInfo - output argument, contains shapeInfo (same for all + * sub-arrays) subArrOffsets - output argument, contains successive + * sub-arrays offsets from original this-buffer keepUnitiesInShape - if false + * then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> + * {a,b} + */ +@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets( + @Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, + int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, + @Cast("Nd4jLong*") LongPointer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); +@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets( + @Cast("const Nd4jLong*") LongPointer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, + int dimsSize, @Const IntPointer dimsToExclude, @Cast("Nd4jLong*") LongPointer subArrShapeInfo, + @Cast("Nd4jLong*") LongPointer subArrOffsets); +@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets( + @Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, + int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, + @Cast("Nd4jLong*") LongBuffer subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); +@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets( + @Cast("const Nd4jLong*") LongBuffer wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, + int dimsSize, @Const IntBuffer dimsToExclude, @Cast("Nd4jLong*") LongBuffer subArrShapeInfo, + @Cast("Nd4jLong*") LongBuffer subArrOffsets); +@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets( + @Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, + int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, + @Cast("Nd4jLong*") long[] subArrOffsets, @Cast("bool") boolean keepUnitiesInShape/*=false*/); +@Namespace("shape") public static native void calcSubArrsShapeInfoAndOffsets( + @Cast("const Nd4jLong*") long[] wholeShapeInfo, @Cast("const Nd4jLong") long numOfSubArrs, + int dimsSize, @Const int[] dimsToExclude, @Cast("Nd4jLong*") long[] subArrShapeInfo, + @Cast("Nd4jLong*") long[] subArrOffsets); + +/** + * processes only one sub-array, evaluates shapeInfo of sub-array and its buffer + * offset from original array arguments: idx - input argument, intervals of + * indexes which define the sub-array to point on, when isStrided = false then + * idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * + * maxRank) when isStrided = true then idx has form + * {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and + * length (3 * maxRank) when (dimStart == dimEnd) then whole range will be used + * for current dimension maxShapeInfo - input argument, shapeInfo of original + * array minShapeInfo - output argument, shapeInfo of sub-array to be deduced + * minOffset - output argument, offset of sub-array buffer offsets from original + * buffer keepUnitiesInShape - input argument, if false then eliminate unities + * from sub-array shapeInfo, for example {1,a,1,b} -> {a,b} isStrided - input + * argument, if true then idx has length (3 * this->rankOf()) and contains + * additional stride numbers which correspond to stride between dimStart and + * dimEnd, numOfUntiesInMinShape - input argument, number of occurrences in idx + * when (dimEnd - dimStart) = 1 + */ +@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset( + @Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, + @Cast("Nd4jLong*") @ByRef LongPointer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, + @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/); +@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset( + @Cast("const Nd4jLong*") LongPointer idx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("Nd4jLong*") LongPointer minShapeInfo, + @Cast("Nd4jLong*") @ByRef LongPointer minOffset); +@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset( + @Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, + @Cast("Nd4jLong*") @ByRef LongBuffer minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, + @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/); +@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset( + @Cast("const Nd4jLong*") LongBuffer idx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("Nd4jLong*") LongBuffer minShapeInfo, + @Cast("Nd4jLong*") @ByRef LongBuffer minOffset); +@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset( + @Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, + @Cast("Nd4jLong*") @ByRef long[] minOffset, @Cast("const bool") boolean keepUnitiesInShape/*=false*/, + @Cast("const bool") boolean isStrided/*=false*/, int numOfUntiesInMinShape/*=0*/); +@Namespace("shape") public static native void calcSubArrShapeInfoAndOffset( + @Cast("const Nd4jLong*") long[] idx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("Nd4jLong*") long[] minShapeInfo, + @Cast("Nd4jLong*") @ByRef long[] minOffset); + +/** + * for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99} + * then output shapeNoUnities will contain {2,4, 4,1} - that is only shape and + * strides, no rank/type/ews/order stridesNoUnities will point on strides in + * shapeNoUnities that is on {4,1} returns number of non-unity dimensions in + * inShapeInfo if there is no unities in inShapeInfo, then no copy procedure + * will be performed and shapeNoUnities/stridesNoUnities will point on + * corresponding places in inShapeInfo + */ +@Namespace("shape") public static native int excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongPointer inShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef LongPointer shapeNoUnities, + @Cast("Nd4jLong*&") @ByPtrRef LongPointer stridesNoUnities); +@Namespace("shape") public static native int excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef LongBuffer shapeNoUnities, + @Cast("Nd4jLong*&") @ByPtrRef LongBuffer stridesNoUnities); +@Namespace("shape") public static native int excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, + @Cast("Nd4jLong*&") @ByPtrRef long[] shapeNoUnities, + @Cast("Nd4jLong*&") @ByPtrRef long[] stridesNoUnities); + +/** + * for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, + * dimsToExclude = {1,3}, dimsSize = 2 then outShapeInfo will contain {3, 2,3,4, + * 12,4,1, 16384,1,99} + */ +@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongPointer inShapeInfo, + int dimsSize, + @Const IntPointer dimsToExclude, + @Cast("Nd4jLong*") LongPointer outShapeInfo); +@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") LongBuffer inShapeInfo, + int dimsSize, + @Const IntBuffer dimsToExclude, + @Cast("Nd4jLong*") LongBuffer outShapeInfo); +@Namespace("shape") public static native void excludeUnitiesFromShapeInfo(@Cast("const Nd4jLong*") long[] inShapeInfo, + int dimsSize, + @Const int[] dimsToExclude, + @Cast("Nd4jLong*") long[] outShapeInfo); + +/** + * get stride over contiguous axis (contiguous axis must have stride = 1) + * for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then + * output is 5 (that is smallest stride in inShapeInfo except those equal to 1) + */ +// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const +// Nd4jLong* inShapeInfo); +// END HEADERS +// BEGIN IMPLEMENTATIONS + +// #ifdef __CUDACC__ +// #endif + +/** + * Length of a tad given + * the shape information + */ /** * Tad element wise stride: @@ -8272,9 +9245,12 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * Again: this may not preserve ordering of the tad * but maybe used for reductions. */ - @Namespace("shape") public static native int tadElementWiseStride(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension,int dimensionLength); - @Namespace("shape") public static native int tadElementWiseStride(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension,int dimensionLength); - @Namespace("shape") public static native int tadElementWiseStride(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension,int dimensionLength); +@Namespace("shape") public static native int tadElementWiseStride(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, + int dimensionLength); +@Namespace("shape") public static native int tadElementWiseStride(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native int tadElementWiseStride(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, + int dimensionLength); /** * Computes the standard packed array strides for a given shape. @@ -8312,9 +9288,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////// - -// check whether input dimensions are permuted, not permuted dimensions order have to be 0,....,rank-1 - +// check whether input dimensions are permuted, not permuted dimensions order +// have to be 0,....,rank-1 /** * @param toCopy the shape to copy @@ -8326,16 +9301,16 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * for the given rank and shape. */ - /** - * This is special method, it returns ONLY 2D shapebuffer. - * - * This method is used only for SoftMax - */ +/** + * This is special method, it returns ONLY 2D shapebuffer. + * + * This method is used only for SoftMax + */ /** -* Get the shape info buffer -* for the given rank and shape. -*/ + * Get the shape info buffer + * for the given rank and shape. + */ ////////////////////////////////////////////////////////////////////// @@ -8345,9 +9320,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////// - // ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) { +// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong +// *shapeInfo, Nd4jLong arrLen) { // const Nd4jLong ews = shapeInfo[shapeInfo[0] + shapeInfo[0] + 2]; @@ -8369,7 +9344,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // return offset; // } -// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, uint arrLen) { +// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, +// uint arrLen) { // const uint rank = shapeInfo[0]; // const uint ews = shapeInfo[rank + rank + 2]; @@ -8396,7 +9372,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////// /** @@ -8424,10 +9399,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return */ - - - - /** * Ensure that every value in the re arrange * array is unique @@ -8455,11 +9426,11 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////// /** -* Returns whether the -* given shape is a vector or not -* @param shape the shape of the array -* @param rank the rank of the shape -*/ + * Returns whether the + * given shape is a vector or not + * @param shape the shape of the array + * @param rank the rank of the shape + */ /** * Returns the shape portion of an information @@ -8473,16 +9444,16 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint */ /** -* Return a copy of a buffer. -* This buffer allocates memory -* that must be freed elsewhere. -*/ + * Return a copy of a buffer. + * This buffer allocates memory + * that must be freed elsewhere. + */ /** -* Return a copy of a buffer. -* This buffer allocates memory -* that must be freed elsewhere. -*/ + * Return a copy of a buffer. + * This buffer allocates memory + * that must be freed elsewhere. + */ /** * Permute the given strides @@ -8493,15 +9464,14 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * and all must be filled in) * @return the rearranged array */ - /* - INLINEDEF _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int shapeRank, int *rearrange) { - Nd4jLong *strideCopy = copyOf(shapeRank, toPermute); - checkArrangeArray(rearrange, shapeRank, shapeRank); - Nd4jLong *newStride = doPermuteSwap(shapeRank, strideCopy, rearrange); - delete[] strideCopy; - return newStride; - } - */ +/* + INLINEDEF _CUDA_HD Nd4jLong *permutedStrides(Nd4jLong *toPermute, int + shapeRank, int *rearrange) { Nd4jLong *strideCopy = copyOf(shapeRank, + toPermute); checkArrangeArray(rearrange, shapeRank, shapeRank); Nd4jLong + *newStride = doPermuteSwap(shapeRank, strideCopy, rearrange); delete[] + strideCopy; return newStride; + } + */ /** * Return the slice (shape + 1 in pointer arithmetic) @@ -8539,7 +9509,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * buffer */ - /** * Compute the length of the given shape */ @@ -8549,7 +9518,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * portion of an information buffer */ - /** * Returns the ordering * for this shape information buffer @@ -8565,9 +9533,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint */ /** -* Returns the element wise stride for this information -* buffer relative to a dimension and reduction index -*/ + * Returns the element wise stride for this information + * buffer relative to a dimension and reduction index + */ /** * Returns whether @@ -8595,7 +9563,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * item */ - /** +/** * Return a copy of this array with the * given index omitted * @@ -8624,9 +9592,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * for the shape to be returned as * @return the new shape */ - @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ensureVectorShape(@Cast("Nd4jLong*") LongPointer shape, int dimension); - @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ensureVectorShape(@Cast("Nd4jLong*") LongBuffer shape, int dimension); - @Namespace("shape") public static native @Cast("Nd4jLong*") long[] ensureVectorShape(@Cast("Nd4jLong*") long[] shape, int dimension); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer ensureVectorShape(@Cast("Nd4jLong*") LongPointer shape, int dimension); +@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer ensureVectorShape(@Cast("Nd4jLong*") LongBuffer shape, int dimension); +@Namespace("shape") public static native @Cast("Nd4jLong*") long[] ensureVectorShape(@Cast("Nd4jLong*") long[] shape, int dimension); /** * Returns a shape @@ -8637,23 +9605,24 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return the new shape */ - /** - * This method does STRICT comparison for two shape buffers - * - * @param shape - * @return - */ +/** + * This method does STRICT comparison for two shape buffers + * + * @param shape + * @return + */ ////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////// - /** - * This method does SOFT comparison for two shape buffers, we compare only rank & shapes - * - * @param shape - * @return - */ +/** + * This method does SOFT comparison for two shape buffers, we compare only rank + * & shapes + * + * @param shape + * @return + */ /** * Generate an int buffer @@ -8724,7 +9693,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return */ - /** +/** * calculates the offset for a tensor * @param index * @param arr @@ -8732,14 +9701,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @return */ - // #ifdef __CUDACC__ // #endif - - - - /** * Computes the number * of tensors along @@ -8752,20 +9716,17 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * a given dimension */ - - - /** -* Get an offset for retrieval -* from a data buffer -* based on the given -* shape stride and given indices -* @param baseOffset the offset to start from -* @param shape the shape of the array -* @param stride the stride of the array -* @param indices the indices to iterate over -* @return the double at the specified index -*/ + * Get an offset for retrieval + * from a data buffer + * based on the given + * shape stride and given indices + * @param baseOffset the offset to start from + * @param shape the shape of the array + * @param stride the stride of the array + * @param indices the indices to iterate over + * @return the double at the specified index + */ ////////////////////////////////////////////////////////////////////////// @@ -8773,7 +9734,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////////// - /** * Returns the tensor along dimension * for the given block index @@ -8807,7 +9767,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * reduction index. * @param tadIndexForOriginal the original tad index for the * split up problem (eg: split is dimension 3 mapping to a 2,3 problem) - * @param tadsForReduced the number of tads for the shrunk down problem (eg: 2,3) + * @param tadsForReduced the number of tads for the shrunk down problem (eg: + * 2,3) * @param tadsForOriginal the number of tads for the smaller problem (eg: 3) */ @@ -8833,24 +9794,21 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * @param originalTadNum the tad number for the reduced version of the problem */ - /** * Returns the prod of the data * up to the given length */ - @Namespace("shape") public static native int rearMostLeftOverItem(@Cast("Nd4jLong*") LongPointer data, @Cast("Nd4jLong*") LongPointer dimension,int dimensionLength); - @Namespace("shape") public static native int rearMostLeftOverItem(@Cast("Nd4jLong*") LongBuffer data, @Cast("Nd4jLong*") LongBuffer dimension,int dimensionLength); - @Namespace("shape") public static native int rearMostLeftOverItem(@Cast("Nd4jLong*") long[] data, @Cast("Nd4jLong*") long[] dimension,int dimensionLength); +@Namespace("shape") public static native int rearMostLeftOverItem(@Cast("Nd4jLong*") LongPointer data, @Cast("Nd4jLong*") LongPointer dimension, + int dimensionLength); +@Namespace("shape") public static native int rearMostLeftOverItem(@Cast("Nd4jLong*") LongBuffer data, @Cast("Nd4jLong*") LongBuffer dimension, + int dimensionLength); +@Namespace("shape") public static native int rearMostLeftOverItem(@Cast("Nd4jLong*") long[] data, @Cast("Nd4jLong*") long[] dimension, + int dimensionLength); // #ifdef __CUDACC__ // #endif - - - - - // INLINEDEF _CUDA_HD Nd4jLong *shapeBufferOfNpyBuffer(char *buffer) { // unsigned Nd4jLong *shape; // unsigned int ndims, wordSize; @@ -8864,17 +9822,17 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////////// // copy-past from java hasDefaultStridesForShape function -// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShapeOf, bool isFOrder, Nd4jLong* target) { +// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, Nd4jLong* oldShape, const +// int newRank, Nd4jLong* newShapeOf, bool isFOrder, Nd4jLong* target) { // int oldnd; // Nd4jLong* olddims = shape::copyOf(oldRank, shape::shapeOf(oldShape)); -// Nd4jLong* oldstrides = shape::copyOf(oldRank, shape::stride(oldShape)); -// int np, op, last_stride; -// int oi, oj, ok, ni, nj, nk; -// Nd4jLong* newStrides = new Nd4jLong[newRank]; -// oldnd = 0; +// Nd4jLong* oldstrides = shape::copyOf(oldRank, +// shape::stride(oldShape)); int np, op, last_stride; int oi, oj, ok, +// ni, nj, nk; Nd4jLong* newStrides = new Nd4jLong[newRank]; oldnd = 0; // /* -// * Remove axes with dimension 1 from the old array. They have no effect +// * Remove axes with dimension 1 from the old array. They have no +// effect // * but would need special cases since their strides do not matter. // */ // for (oi = 0; oi < oldRank; oi++) { @@ -8911,11 +9869,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // return false; // } -// /* oi to oj and ni to nj give the axis ranges currently worked with */ -// oi = 0; -// oj = 1; -// ni = 0; -// nj = 1; +// /* oi to oj and ni to nj give the axis ranges currently worked with +// */ oi = 0; oj = 1; ni = 0; nj = 1; // while (ni < newRank && oi < oldnd) { // np = newShapeOf[ni]; @@ -8943,7 +9898,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // } // } else { // /* C order */ -// if (oldstrides[ok] != olddims[ok + 1] * oldstrides[ok + 1]) { +// if (oldstrides[ok] != olddims[ok + 1] * oldstrides[ok + +// 1]) { // /* not contiguous enough */ // delete[] olddims; // delete[] oldstrides; @@ -8994,7 +9950,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // target[shape::shapeInfoLength(newRank) - 3] = 0; // target[shape::shapeInfoLength(newRank) - 2] = 0; // target[shape::shapeInfoLength(newRank) - 1] = isFOrder ? 102 : 99; -// sd::ArrayOptions::setDataType(target, sd::ArrayOptions::dataType(oldShape)); +// sd::ArrayOptions::setDataType(target, +// sd::ArrayOptions::dataType(oldShape)); // delete[] olddims; // delete[] oldstrides; @@ -9004,18 +9961,26 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // } ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo) { +// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* +// oldShapeInfo, const int newRank, const Nd4jLong* newShape, Nd4jLong* +// newShapeInfo) { -// // PLEASE NOTE !: reshaping not-permuted (ews=1) array in f order (except insertion/elimination of unities) will definitely cause allocation of new buffer for array elements -// // also this function takes into account identical shapes automatically, namely in that case oldShapeInfo is completely copied to newShapeInfo +// // PLEASE NOTE !: reshaping not-permuted (ews=1) array in f order +// (except insertion/elimination of unities) will definitely cause +// allocation of new buffer for array elements +// // also this function takes into account identical shapes +// automatically, namely in that case oldShapeInfo is completely copied +// to newShapeInfo // newShapeInfo[0] = newRank; // memcpy(newShapeInfo + 1, newShape, newRank * sizeof(Nd4jLong)); // Nd4jLong* newStrides = shape::stride(newShapeInfo); -// const Nd4jLong* oldShape = shape::shapeOf(const_cast(oldShapeInfo)); -// const Nd4jLong* oldStrides = shape::stride(const_cast(oldShapeInfo)); -// Nd4jLong oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim; +// const Nd4jLong* oldShape = +// shape::shapeOf(const_cast(oldShapeInfo)); const Nd4jLong* +// oldStrides = shape::stride(const_cast(oldShapeInfo)); +// Nd4jLong oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, +// oldDim; // while (newStart < newRank && oldStart < oldRank) { @@ -9026,13 +9991,16 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // if (newDim < oldDim) newDim *= newShape[newStop++]; // else oldDim *= oldShape[oldStop++]; -// // ------ Check whether the original axes can be combined ------ // -// for (int step = 1, i = oldStart; i < oldStop - 1; ++i) { -// if(oldShape[i] == 1) // skip unity-dimension and its stride +// // ------ Check whether the original axes can be combined ------ +// // for (int step = 1, i = oldStart; i < oldStop - 1; ++i) { +// if(oldShape[i] == 1) // skip unity-dimension +// and its stride // continue; // while((i + step) < oldRank && oldShape[i + step] == 1) -// ++step; // skip following unity-dimensions and its strides if such are present -// if((i + step) < oldRank && oldStrides[i] != oldShape[i + step] * oldStrides[i + step]) +// ++step; // skip following +// unity-dimensions and its strides if such are present +// if((i + step) < oldRank && oldStrides[i] != oldShape[i + +// step] * oldStrides[i + step]) // return false; // not contiguous enough // } @@ -9044,13 +10012,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // oldStart = oldStop++; // } -// // rest of strides should be unities (if there is remainder in strides space, that is newStart < newRank) -// for (int i = newStart; i < newRank; ++i) +// // rest of strides should be unities (if there is remainder in +// strides space, that is newStart < newRank) for (int i = newStart; i < +// newRank; ++i) // newStrides[i] = 1; -// newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo); // order -// newShapeInfo[2 * newRank + 2] = shape::elementWiseStride(oldShapeInfo); // ews -// newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); // type +// newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo); // order +// newShapeInfo[2 * newRank + 2] = +// shape::elementWiseStride(oldShapeInfo); // ews newShapeInfo[2 * +// newRank + 1] = shape::type(oldShapeInfo); // type // return true; // } @@ -9059,20 +10029,22 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////// - // this function checks the consistence of dimensions with array rank (negative dimensions, too large dimensions, too big number of dimensions) - // also it sorts input array of dimensions, this operation is also necessary for creating TAD object - +// this function checks the consistence of dimensions with array rank (negative +// dimensions, too large dimensions, too big number of dimensions) also it sorts +// input array of dimensions, this operation is also necessary for creating TAD +// object // max array is outer for min array, min array is sub-array of max array -// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) +// function calculates the coordinates of min array (and saves them into +// minIdxs) given coordinates of max array (already stored in maxIdxs) - ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////// @@ -9103,7 +10075,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) { +// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& +// xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* +// zShapeInfo, Nd4jLong*& zOffsets, const char order) { // // we assume all array have same length // const Nd4jLong len = shape::length(xShapeInfo); @@ -9116,22 +10090,27 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // const char yOrder = shape::order(yShapeInfo); // const char zOrder = shape::order(zShapeInfo); -// const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo); +// const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, +// zShapeInfo); -// if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (xOrder == 'c' || shapesSame)) { +// if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == +// zOrder && (xOrder == 'c' || shapesSame)) { // xOffsets = yOffsets = zOffsets = nullptr; // } -// else if(xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || shape::shapeEquals(xShapeInfo, yShapeInfo))) { +// else if(xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || +// shape::shapeEquals(xShapeInfo, yShapeInfo))) { // xOffsets = yOffsets = nullptr; // zOffsets = new Nd4jLong[len]; // shape::calcOffsets(zShapeInfo, zOffsets, xOrder); // } -// else if(xEws == 1 && zEws == 1 && xOrder == zOrder && (xOrder == 'c' || shape::shapeEquals(xShapeInfo, zShapeInfo))) { +// else if(xEws == 1 && zEws == 1 && xOrder == zOrder && (xOrder == 'c' || +// shape::shapeEquals(xShapeInfo, zShapeInfo))) { // xOffsets = zOffsets = nullptr; // yOffsets = new Nd4jLong[len]; // shape::calcOffsets(yShapeInfo, yOffsets, xOrder); // } -// else if(yEws == 1 && zEws == 1 && yOrder == zOrder && (yOrder == 'c' || shape::shapeEquals(yShapeInfo, zShapeInfo))) { +// else if(yEws == 1 && zEws == 1 && yOrder == zOrder && (yOrder == 'c' || +// shape::shapeEquals(yShapeInfo, zShapeInfo))) { // yOffsets = zOffsets = nullptr; // xOffsets = new Nd4jLong[len]; // shape::calcOffsets(xShapeInfo, xOffsets, yOrder); @@ -9184,7 +10163,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // } // } // } -// else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo, zShapeInfo)) { +// else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo, +// zShapeInfo)) { // xOffsets = new Nd4jLong[len]; // shape::calcOffsets(xShapeInfo, xOffsets); // yOffsets = zOffsets = xOffsets; @@ -9244,7 +10224,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // } ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order) { +// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& +// xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order) +// { // // we assume all array have same length // const Nd4jLong len = shape::length(xShapeInfo); @@ -9257,7 +10239,8 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo); -// if (xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || shapesSame)) { +// if (xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || +// shapesSame)) { // xOffsets = yOffsets = nullptr; // } // else if(xEws == 1) { @@ -9296,9 +10279,9 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint ////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////// -// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) { +// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const +// Nd4jLong* inShapeInfo) { // Nd4jLong result = 9223372036854775807LL; @@ -9316,9 +10299,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // return result == 9223372036854775807LL ? 1 : result; // } - - - + // namespace shape // #endif /* SHAPE_H_ */ @@ -9348,7 +10329,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_OPARGSHOLDER_H // #define LIBND4J_OPARGSHOLDER_H - // #include // #include @@ -9363,68 +10343,99 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint return (OpArgsHolder)super.position(position); } + // default constructor + public OpArgsHolder() { super((Pointer)null); allocate(); } + private native void allocate(); - // default constructor - public OpArgsHolder() { super((Pointer)null); allocate(); } - private native void allocate(); - - // copy constructor - public OpArgsHolder(@Const @ByRef OpArgsHolder other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef OpArgsHolder other); - - // constructor - public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } - private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/); - public OpArgsHolder(@Const @ByRef NDArrayVector inArrs) { super((Pointer)null); allocate(inArrs); } - private native void allocate(@Const @ByRef NDArrayVector inArrs); - public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoubleBuffer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } - private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoubleBuffer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/); - public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } - private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/); - public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } - private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoublePointer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/); - public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector DoubleBuffer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } - private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector DoubleBuffer tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/); - public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } - private native void allocate(@Const @ByRef NDArrayVector inArrs, @StdVector double[] tArgs/*=std::vector()*/, @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/); - - // move constructor - - // assignment operator - public native @ByRef @Name("operator =") OpArgsHolder put(@Const @ByRef OpArgsHolder other); - - // move assignment operator - - public native @Const @ByRef NDArrayVector getInArrs(); - - public native @StdVector DoublePointer getTArgs(); - - public native @Cast("Nd4jLong*") @StdVector LongPointer getIArgs(); - - public native @Cast("bool*") @StdVector BooleanPointer getBArgs(); - - public native @Cast("bool*") @StdVector BooleanPointer getAllocInfo(); - - public native int getNumInArrs(); - - public native int getNumTArgs(); - - public native int getNumIArgs(); - - public native int getNumBArgs(); - - public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs, @Cast("const bool") boolean isInPlace/*=false*/); - public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs); - + // copy constructor + public OpArgsHolder(@Const @ByRef OpArgsHolder other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef OpArgsHolder other); + + // constructor + public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, + @StdVector DoublePointer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } + private native void allocate(@Const @ByRef NDArrayVector inArrs, + @StdVector DoublePointer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/); + public OpArgsHolder(@Const @ByRef NDArrayVector inArrs) { super((Pointer)null); allocate(inArrs); } + private native void allocate(@Const @ByRef NDArrayVector inArrs); + public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, + @StdVector DoubleBuffer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } + private native void allocate(@Const @ByRef NDArrayVector inArrs, + @StdVector DoubleBuffer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/); + public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, + @StdVector double[] tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } + private native void allocate(@Const @ByRef NDArrayVector inArrs, + @StdVector double[] tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/); + public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, + @StdVector DoublePointer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } + private native void allocate(@Const @ByRef NDArrayVector inArrs, + @StdVector DoublePointer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/); + public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, + @StdVector DoubleBuffer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } + private native void allocate(@Const @ByRef NDArrayVector inArrs, + @StdVector DoubleBuffer tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/); + public OpArgsHolder(@Const @ByRef NDArrayVector inArrs, + @StdVector double[] tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/) { super((Pointer)null); allocate(inArrs, tArgs, iArgs, bArgs); } + private native void allocate(@Const @ByRef NDArrayVector inArrs, + @StdVector double[] tArgs/*=std::vector()*/, + @Cast("Nd4jLong*") @StdVector long[] iArgs/*=std::vector()*/, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/); + + // move constructor + + // assignment operator + public native @ByRef @Name("operator =") OpArgsHolder put(@Const @ByRef OpArgsHolder other); + + // move assignment operator + + public native @Const @ByRef NDArrayVector getInArrs(); + + public native @StdVector DoublePointer getTArgs(); + + public native @Cast("Nd4jLong*") @StdVector LongPointer getIArgs(); + + public native @Cast("bool*") @StdVector BooleanPointer getBArgs(); + + public native @Cast("bool*") @StdVector BooleanPointer getAllocInfo(); + + public native int getNumInArrs(); + + public native int getNumTArgs(); + + public native int getNumIArgs(); + + public native int getNumBArgs(); + + public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs, + @Cast("const bool") boolean isInPlace/*=false*/); + public native @ByVal OpArgsHolder createArgsHolderForBP(@Const @ByRef NDArrayVector inGradArrs); } + // namespace sd - - - - - -// #endif //LIBND4J_OPARGSHOLDER_H +// #endif // LIBND4J_OPARGSHOLDER_H // Parsed from array/ShapeList.h @@ -9452,57 +10463,58 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_SHAPELIST_H // #define LIBND4J_SHAPELIST_H -// #include // #include // #include - @Namespace("sd") @NoOffset public static class ShapeList extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ShapeList(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ShapeList(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ShapeList position(long position) { - return (ShapeList)super.position(position); - } - - public ShapeList(@Cast("const Nd4jLong*") LongPointer shape/*=nullptr*/) { super((Pointer)null); allocate(shape); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shape/*=nullptr*/); - public ShapeList() { super((Pointer)null); allocate(); } - private native void allocate(); - public ShapeList(@Cast("const Nd4jLong*") LongBuffer shape/*=nullptr*/) { super((Pointer)null); allocate(shape); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shape/*=nullptr*/); - public ShapeList(@Cast("const Nd4jLong*") long[] shape/*=nullptr*/) { super((Pointer)null); allocate(shape); } - private native void allocate(@Cast("const Nd4jLong*") long[] shape/*=nullptr*/); - public ShapeList(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } - private native void allocate(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes, @Cast("bool") boolean isWorkspace); - public ShapeList(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongPointer shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } - private native void allocate(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongPointer shapes, @Cast("bool") boolean isWorkspace); - public ShapeList(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongBuffer shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } - private native void allocate(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongBuffer shapes, @Cast("bool") boolean isWorkspace); - public ShapeList(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr long[] shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } - private native void allocate(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr long[] shapes, @Cast("bool") boolean isWorkspace); - public ShapeList(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes) { super((Pointer)null); allocate(shapes); } - private native void allocate(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes); - //ShapeList(bool autoRemovable); - - public native @Cast("const Nd4jLong**") @StdVector PointerPointer asVector(); - public native void destroy(); - public native int size(); - public native @Cast("const Nd4jLong*") LongPointer at(int idx); - public native void push_back(@Cast("const Nd4jLong*") LongPointer shape); - public native void push_back(@Cast("const Nd4jLong*") LongBuffer shape); - public native void push_back(@Cast("const Nd4jLong*") long[] shape); - /** - * PLEASE NOTE: This method should be called ONLY if shapes were generated at workspaces. Otherwise you'll get memory leak - */ - public native void detach(); +// #include +@Namespace("sd") @NoOffset public static class ShapeList extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ShapeList(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ShapeList(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ShapeList position(long position) { + return (ShapeList)super.position(position); } + public ShapeList(@Cast("const Nd4jLong*") LongPointer shape/*=nullptr*/) { super((Pointer)null); allocate(shape); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shape/*=nullptr*/); + public ShapeList() { super((Pointer)null); allocate(); } + private native void allocate(); + public ShapeList(@Cast("const Nd4jLong*") LongBuffer shape/*=nullptr*/) { super((Pointer)null); allocate(shape); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shape/*=nullptr*/); + public ShapeList(@Cast("const Nd4jLong*") long[] shape/*=nullptr*/) { super((Pointer)null); allocate(shape); } + private native void allocate(@Cast("const Nd4jLong*") long[] shape/*=nullptr*/); + public ShapeList(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } + private native void allocate(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes, @Cast("bool") boolean isWorkspace); + public ShapeList(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongPointer shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } + private native void allocate(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongPointer shapes, @Cast("bool") boolean isWorkspace); + public ShapeList(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongBuffer shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } + private native void allocate(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr LongBuffer shapes, @Cast("bool") boolean isWorkspace); + public ShapeList(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr long[] shapes, @Cast("bool") boolean isWorkspace) { super((Pointer)null); allocate(shapes, isWorkspace); } + private native void allocate(@Cast("const Nd4jLong**") @StdVector @ByPtrPtr long[] shapes, @Cast("bool") boolean isWorkspace); + public ShapeList(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes) { super((Pointer)null); allocate(shapes); } + private native void allocate(@Cast("const Nd4jLong**") @StdVector PointerPointer shapes); + // ShapeList(bool autoRemovable); + + public native @Cast("const Nd4jLong**") @StdVector PointerPointer asVector(); + public native void destroy(); + public native int size(); + public native @Cast("const Nd4jLong*") LongPointer at(int idx); + public native void push_back(@Cast("const Nd4jLong*") LongPointer shape); + public native void push_back(@Cast("const Nd4jLong*") LongBuffer shape); + public native void push_back(@Cast("const Nd4jLong*") long[] shape); + /** + * PLEASE NOTE: This method should be called ONLY if shapes were generated at + * workspaces. Otherwise you'll get memory leak + */ + public native void detach(); +} + // namespace sd -// #endif //LIBND4J_SHAPELIST_H +// #endif // LIBND4J_SHAPELIST_H // Parsed from system/type_boilerplate.h @@ -9535,539 +10547,1513 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define EXPAND3(...) __VA_ARGS__ // #define EXTRACT(...) EXTRACT __VA_ARGS__ // #define NOTHING_EXTRACT -// #define PASTE(x, ...) x ## __VA_ARGS__ -// #define PASTE2(x, ...) x ## __VA_ARGS__ -// #define PASTE3(x, ...) x ## __VA_ARGS__ +// #define PASTE(x, ...) x##__VA_ARGS__ +// #define PASTE2(x, ...) x##__VA_ARGS__ +// #define PASTE3(x, ...) x##__VA_ARGS__ // #define EVALUATING_PASTE(x, ...) PASTE(x, __VA_ARGS__) // #define EVALUATING_PASTE2(x, ...) PASTE2(x, __VA_ARGS__) // #define EVALUATING_PASTE3(x, ...) PASTE3(x, __VA_ARGS__) // #define UNPAREN(x) EVALUATING_PASTE(NOTHING_, EXTRACT x) // #define UNPAREN2(x) EVALUATING_PASTE2(NOTHING_, EXTRACT x) // #define UNPAREN3(x) EVALUATING_PASTE3(NOTHING_, EXTRACT x) -// #define EVAL( x ) x -// #define EVALX( x ) x -// #define EVAL0(...) EVAL1(EVAL1(EVAL1(__VA_ARGS__))) +// #define EVAL(x) x +// #define EVALX(x) x +// #define EVAL0(...) EVAL1(EVAL1(EVAL1(__VA_ARGS__))) // #define EVAL1(...) EVAL2(EVAL2(EVAL2(__VA_ARGS__))) // #define EVAL2(...) EVAL3(EVAL3(EVAL3(__VA_ARGS__))) // #define EVAL3(...) EVAL4(EVAL4(EVAL4(__VA_ARGS__))) // #define EVAL4(...) EVAL5(EVAL5(EVAL5(__VA_ARGS__))) // #define EVAL5(...) __VA_ARGS__ - // #define SEL_T_1(WHAT, NAME, SIGNATURE, TYPE_A) WHAT(NAME, SIGNATURE, TYPE_A) -// #define SEL_T_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define SEL_T_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(SEL_T_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - - -// #define SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) -// #define SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT1_20(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - - -// #define SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) -// #define SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P1_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - -// #define SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) -// #define SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_P2_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - - - - -// #define SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) -// #define SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define SEL_TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - +// #define SEL_T_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define SEL_T_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) +// EVAL(SEL_T_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +// #define SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// #define SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_1(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_2(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_3(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_4(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_5(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_6(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_7(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_8(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_9(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_10(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_11(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_12(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_13(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_14(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_15(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_16(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_17(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_18(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT1_20(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(YTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT1_19(WHAT, YTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) + +// #define SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// #define SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P1_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P1_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) + +// #define SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// #define SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_3(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_4(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_5(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_6(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_7(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_8(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_9(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_10(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_11(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_12(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_13(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_14(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_15(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_16(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_17(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_18(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define SEL_P2_20(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// ...) +// WHAT(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_P2_19(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) + +// #define SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// #define SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define SEL_TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(SEL_TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) // #define DS_1(WHAT, NAME, SIGNATURE, TYPE_A) WHAT(NAME, SIGNATURE, TYPE_A) -// #define DS_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DS_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DS_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - +// #define DS_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DS_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DS_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) // #define DP_1(WHAT, NAME, SIGNATURE, TYPE_A) WHAT(NAME, SIGNATURE, TYPE_A) -// #define DP_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_21(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_20(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_22(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_21(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_23(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_22(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_24(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_23(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_25(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_24(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_26(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_25(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_27(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_26(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_28(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_27(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_29(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_28(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_30(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_29(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_31(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_30(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_32(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_31(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_33(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_32(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_34(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_33(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_35(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_34(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_36(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_35(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_37(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_36(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_38(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_37(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_39(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_38(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_40(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_39(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_41(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_40(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_42(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_41(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_43(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_42(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_44(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_43(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_45(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_44(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_46(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_45(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_47(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_46(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_48(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_47(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_49(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_48(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_50(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_49(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_51(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_50(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_52(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_51(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_53(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_52(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_54(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_53(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_55(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_54(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_56(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_55(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_57(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_56(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_58(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_57(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_59(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_58(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_60(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_59(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_61(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_60(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_62(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_61(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_63(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_62(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_64(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_63(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_65(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_64(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_66(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_65(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_67(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_66(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_68(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_67(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_69(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_68(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_70(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_69(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_71(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_70(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_72(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_71(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_73(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_72(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_74(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_73(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_75(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_74(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_76(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_75(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_77(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_76(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_78(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_77(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_79(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_78(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_80(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_79(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_81(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_80(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_82(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_81(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_83(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_82(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_84(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_83(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_85(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_84(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_86(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_85(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_87(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_86(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_88(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_87(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_89(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_88(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_90(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_89(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_91(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_90(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_92(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_91(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_93(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_92(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_94(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_93(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_95(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_94(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_96(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_95(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_97(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_96(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_98(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_97(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_99(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_98(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_100(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_99(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_101(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_100(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_102(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_101(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_103(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_102(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_104(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_103(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_105(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_104(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_106(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_105(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_107(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_106(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_108(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_107(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_109(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_108(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_110(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_109(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_111(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_110(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_112(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_111(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_113(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_112(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_114(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_113(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_115(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_114(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_116(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_115(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_117(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_116(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_118(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_117(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_119(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_118(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_120(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_119(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_121(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_120(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_122(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_121(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_123(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_122(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_124(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_123(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define DP_125(WHAT, NAME, SIGNATURE, TYPE_A, ...) WHAT(NAME, SIGNATURE, TYPE_A)EVAL(DP_124(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - - -// #define DT_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) -// #define DT_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - -// #define DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) -// #define DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define DT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B)EVAL(DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) - -// #define TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -// #define TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT1_20(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - -// #define TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -// #define TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT2_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - - -// #define TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -// #define TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TTT3_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - - - -// #define TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -// #define TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT1_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - -// #define TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -// #define TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - -// #define TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) -// #define TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) -// #define TT3_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C)EVAL(TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) - - -// #define GET_MACRO_SEL_T(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_SEL_P1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_SEL_P2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_SEL_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_SEL_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_DS(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_DT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_DP(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, _65, _66, _67, _68, _69, _70, _71, _72, _73, _74, _75, _76, _77, _78, _79, _80, _81, _82, _83, _84, _85, _86, _87, _88, _89, _90, _91, _92, _93, _94, _95, _96, _97, _98, _99, _100, _101, _102, _103, _104, _105, _106, _107, _108, _109, _110, _111, _112, _113, _114, _115, _116, _117, _118, _119, _120, _121, _122, _123, _124, _125, NAME,...) NAME -// #define GET_MACRO_DT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME - - -// #define GET_MACRO_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_TT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME - -// #define GET_MACRO_TTT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_TTT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME -// #define GET_MACRO_TTT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, NAME,...) NAME - -// #define FOR_EACH_S1(WHAT, NAME, SIGNATURE, ...) EXPAND(GET_MACRO_SEL_T(__VA_ARGS__, SEL_T_20, SEL_T_19, SEL_T_18, SEL_T_17, SEL_T_16, SEL_T_15, SEL_T_14, SEL_T_13, SEL_T_12, SEL_T_11, SEL_T_10, SEL_T_9, SEL_T_8, SEL_T_7, SEL_T_6, SEL_T_5, SEL_T_4, SEL_T_3, SEL_T_2, SEL_T_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) EXPAND(GET_MACRO_SEL_TT1(__VA_ARGS__, SEL_TT1_20, SEL_TT1_19, SEL_TT1_18, SEL_TT1_17, SEL_TT1_16, SEL_TT1_15, SEL_TT1_14, SEL_TT1_13, SEL_TT1_12, SEL_TT1_11, SEL_TT1_10, SEL_TT1_9, SEL_TT1_8, SEL_TT1_7, SEL_TT1_6, SEL_TT1_5, SEL_TT1_4, SEL_TT1_3, SEL_TT1_2, SEL_TT1_1)(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -// #define FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) EXPAND(GET_MACRO_SEL_P1(__VA_ARGS__, SEL_P1_20, SEL_P1_19, SEL_P1_18, SEL_P1_17, SEL_P1_16, SEL_P1_15, SEL_P1_14, SEL_P1_13, SEL_P1_12, SEL_P1_11, SEL_P1_10, SEL_P1_9, SEL_P1_8, SEL_P1_7, SEL_P1_6, SEL_P1_5, SEL_P1_4, SEL_P1_3, SEL_P1_2, SEL_P1_1)(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -// #define FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) EXPAND2(GET_MACRO_SEL_P2(__VA_ARGS__, SEL_P2_20, SEL_P2_19, SEL_P2_18, SEL_P2_17, SEL_P2_16, SEL_P2_15, SEL_P2_14, SEL_P2_13, SEL_P2_12, SEL_P2_11, SEL_P2_10, SEL_P2_9, SEL_P2_8, SEL_P2_7, SEL_P2_6, SEL_P2_5, SEL_P2_4, SEL_P2_3, SEL_P2_2, SEL_P2_1)(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -// #define FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, ...) EXPAND(GET_MACRO_SEL_TT2(__VA_ARGS__, SEL_TT2_20, SEL_TT2_19, SEL_TT2_18, SEL_TT2_17, SEL_TT2_16, SEL_TT2_15, SEL_TT2_14, SEL_TT2_13, SEL_TT2_12, SEL_TT2_11, SEL_TT2_10, SEL_TT2_9, SEL_TT2_8, SEL_TT2_7, SEL_TT2_6, SEL_TT2_5, SEL_TT2_4, SEL_TT2_3, SEL_TT2_2, SEL_TT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define FOR_EACH_DS(WHAT, NAME, SIGNATURE, ...) EXPAND(GET_MACRO_DS(__VA_ARGS__, DS_20, DS_19, DS_18, DS_17, DS_16, DS_15, DS_14, DS_13, DS_12, DS_11, DS_10, DS_9, DS_8, DS_7, DS_6, DS_5, DS_4, DS_3, DS_2, DS_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define FOR_EACH_DT(WHAT, NAME, SIGNATURE, TYPES_A, ...) EXPAND(GET_MACRO_DT(__VA_ARGS__, DT_20, DT_19, DT_18, DT_17, DT_16, DT_15, DT_14, DT_13, DT_12, DT_11, DT_10, DT_9, DT_8, DT_7, DT_6, DT_5, DT_4, DT_3, DT_2, DT_1)(WHAT, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -// #define FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, ...) EXPAND(GET_MACRO_DT2(__VA_ARGS__, DT2_20, DT2_19, DT2_18, DT2_17, DT2_16, DT2_15, DT2_14, DT2_13, DT2_12, DT2_11, DT2_10, DT2_9, DT2_8, DT2_7, DT2_6, DT2_5, DT2_4, DT2_3, DT2_2, DT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define FOR_EACH_DP(WHAT, NAME, SIGNATURE, ...) EXPAND(GET_MACRO_DP(__VA_ARGS__, DP_125, DP_124, DP_123, DP_122, DP_121, DP_120, DP_119, DP_118, DP_117, DP_116, DP_115, DP_114, DP_113, DP_112, DP_111, DP_110, DP_109, DP_108, DP_107, DP_106, DP_105, DP_104, DP_103, DP_102, DP_101, DP_100, DP_99, DP_98, DP_97, DP_96, DP_95, DP_94, DP_93, DP_92, DP_91, DP_90, DP_89, DP_88, DP_87, DP_86, DP_85, DP_84, DP_83, DP_82, DP_81, DP_80, DP_79, DP_78, DP_77, DP_76, DP_75, DP_74, DP_73, DP_72, DP_71, DP_70, DP_69, DP_68, DP_67, DP_66, DP_65, DP_64, DP_63, DP_62, DP_61, DP_60, DP_59, DP_58, DP_57, DP_56, DP_55, DP_54, DP_53, DP_52, DP_51, DP_50, DP_49, DP_48, DP_47, DP_46, DP_45, DP_44, DP_43, DP_42, DP_41, DP_40, DP_39, DP_38, DP_37, DP_36, DP_35, DP_34, DP_33, DP_32, DP_31, DP_30, DP_29, DP_28, DP_27, DP_26, DP_25, DP_24, DP_23, DP_22, DP_21, DP_20, DP_19, DP_18, DP_17, DP_16, DP_15, DP_14, DP_13, DP_12, DP_11, DP_10, DP_9, DP_8, DP_7, DP_6, DP_5, DP_4, DP_3, DP_2, DP_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - - -// #define FOR_EACH_TT1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) EXPAND(GET_MACRO_TT1(__VA_ARGS__, TT1_20, TT1_19, TT1_18, TT1_17, TT1_16, TT1_15, TT1_14, TT1_13, TT1_12, TT1_11, TT1_10, TT1_9, TT1_8, TT1_7, TT1_6, TT1_5, TT1_4, TT1_3, TT1_2, TT1_1)(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, __VA_ARGS__)) -// #define FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) EXPAND(GET_MACRO_TT2(__VA_ARGS__, TT2_20, TT2_19, TT2_18, TT2_17, TT2_16, TT2_15, TT2_14, TT2_13, TT2_12, TT2_11, TT2_10, TT2_9, TT2_8, TT2_7, TT2_6, TT2_5, TT2_4, TT2_3, TT2_2, TT2_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, __VA_ARGS__)) -// #define FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) EXPAND(GET_MACRO_TT3(__VA_ARGS__, TT3_20, TT3_19, TT3_18, TT3_17, TT3_16, TT3_15, TT3_14, TT3_13, TT3_12, TT3_11, TT3_10, TT3_9, TT3_8, TT3_7, TT3_6, TT3_5, TT3_4, TT3_3, TT3_2, TT3_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, __VA_ARGS__)) - -// #define FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, ...) EXPAND(GET_MACRO_TTT1(__VA_ARGS__, TTT1_20, TTT1_19, TTT1_18, TTT1_17, TTT1_16, TTT1_15, TTT1_14, TTT1_13, TTT1_12, TTT1_11, TTT1_10, TTT1_9, TTT1_8, TTT1_7, TTT1_6, TTT1_5, TTT1_4, TTT1_3, TTT1_2, TTT1_1)(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, __VA_ARGS__)) -// #define FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) EXPAND(GET_MACRO_TTT2(__VA_ARGS__, TTT2_20, TTT2_19, TTT2_18, TTT2_17, TTT2_16, TTT2_15, TTT2_14, TTT2_13, TTT2_12, TTT2_11, TTT2_10, TTT2_9, TTT2_8, TTT2_7, TTT2_6, TTT2_5, TTT2_4, TTT2_3, TTT2_2, TTT2_1)(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) -// #define FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) EXPAND(GET_MACRO_TTT3(__VA_ARGS__, TTT3_20, TTT3_19, TTT3_18, TTT3_17, TTT3_16, TTT3_15, TTT3_14, TTT3_13, TTT3_12, TTT3_11, TTT3_10, TTT3_9, TTT3_8, TTT3_7, TTT3_6, TTT3_5, TTT3_4, TTT3_3, TTT3_2, TTT3_1)(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) - -// #define _EXEC_SELECTOR_T(WHAT, NAME, SIGNATURE, ...) EVAL(FOR_EACH_S1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define _EXEC_SELECTOR_P_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) EVAL(FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -// #define _EXEC_SELECTOR_P_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, ...) EVAL(FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define _EXEC_SELECTOR_TT_1(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) EVAL(FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) -// #define _EXEC_SELECTOR_TT_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) EVAL(FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define _EXEC_SINGLE_T(WHAT, NAME, SIGNATURE, ...) EVAL(FOR_EACH_DS(WHAT, NAME, SIGNATURE, __VA_ARGS__)) -// #define _EXEC_DOUBLE_T(WHAT, NAME, SIGNATURE, TYPES_A, ...) EVAL(FOR_EACH_DT(WHAT, NAME, SIGNATURE, LIST(TYPES_A), __VA_ARGS__)) -// #define _EXEC_DOUBLE_T2(WHAT, NAME, SIGNATURE, TYPE_A, ...) EVAL(FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) -// #define _EXEC_DOUBLE_P(WHAT, NAME, SIGNATURE, ...) EVAL(FOR_EACH_DP(WHAT, NAME, SIGNATURE, __VA_ARGS__)) - -// #define _EXEC_SELECTOR_TTT_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, ...) EVAL(FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, __VA_ARGS__)) -// #define _EXEC_SELECTOR_TTT_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) EVAL(FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) -// #define _EXEC_SELECTOR_TTT_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) EVAL(FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) - -// #define _EXEC_TRIPLE_T3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) EVAL(FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, __VA_ARGS__)) -// #define _EXEC_TRIPLE_T2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) EVAL(FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, LIST(TYPES_X), __VA_ARGS__)) -// #define _EXEC_TRIPLE_T1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) EVAL(FOR_EACH_TT1(WHAT, NAME, SIGNATURE, LIST(TYPES_X), LIST(TYPES_Y), __VA_ARGS__)) - -// #define DISPATCH_PAIRWISE(NAME, SIGNATURE, TYPE, TYPES_B) EVAL(_EXEC_DOUBLE_T2(RANDOMPAIRWISE2, NAME, SIGNATURE, TYPE, TYPES_B)) -// #define DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE, ...) EVAL(_EXEC_SELECTOR_P_2(SELECTOR_PAIRWISE_2, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE, __VA_ARGS__)) - -// #define DISPATCH_DTYPES(NAME, SIGNATURE, TYPE, TYPES_B) EVAL(_EXEC_DOUBLE_T2(RANDOMDOUBLE2, NAME, SIGNATURE, TYPE, TYPES_B)) -// #define DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE, ...) EVAL(_EXEC_SELECTOR_TT_2(SELECTOR_DOUBLE_2, NAME, SIGNATURE, TYPE, __VA_ARGS__)) - -// #define DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) EVAL(_EXEC_SELECTOR_TTT_2(SELECTOR_TRIPLE_2, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) -// #define DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) EVAL(_EXEC_SELECTOR_TTT_3(SELECTOR_TRIPLE_3, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) - +// #define DP_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_3(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_2(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_4(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_3(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_5(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_4(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_6(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_5(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_7(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_6(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_8(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_7(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_9(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_8(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_10(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_9(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_11(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_10(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_12(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_11(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_13(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_12(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_14(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_13(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_15(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_14(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_16(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_15(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_17(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_16(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_18(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_17(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_19(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_18(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_20(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_19(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_21(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_20(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_22(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_21(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_23(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_22(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_24(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_23(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_25(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_24(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_26(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_25(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_27(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_26(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_28(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_27(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_29(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_28(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_30(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_29(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_31(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_30(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_32(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_31(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_33(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_32(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_34(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_33(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_35(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_34(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_36(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_35(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_37(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_36(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_38(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_37(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_39(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_38(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_40(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_39(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_41(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_40(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_42(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_41(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_43(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_42(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_44(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_43(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_45(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_44(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_46(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_45(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_47(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_46(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_48(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_47(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_49(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_48(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_50(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_49(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_51(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_50(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_52(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_51(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_53(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_52(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_54(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_53(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_55(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_54(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_56(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_55(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_57(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_56(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_58(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_57(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_59(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_58(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_60(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_59(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_61(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_60(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_62(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_61(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_63(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_62(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_64(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_63(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_65(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_64(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_66(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_65(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_67(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_66(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_68(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_67(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_69(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_68(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_70(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_69(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_71(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_70(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_72(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_71(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_73(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_72(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_74(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_73(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_75(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_74(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_76(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_75(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_77(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_76(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_78(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_77(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_79(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_78(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_80(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_79(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_81(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_80(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_82(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_81(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_83(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_82(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_84(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_83(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_85(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_84(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_86(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_85(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_87(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_86(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_88(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_87(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_89(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_88(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_90(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_89(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_91(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_90(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_92(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_91(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_93(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_92(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_94(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_93(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_95(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_94(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_96(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_95(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_97(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_96(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_98(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_97(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_99(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_98(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_100(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_99(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_101(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_100(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_102(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_101(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_103(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_102(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_104(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_103(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_105(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_104(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_106(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_105(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_107(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_106(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_108(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_107(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_109(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_108(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_110(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_109(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_111(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_110(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_112(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_111(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_113(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_112(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_114(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_113(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_115(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_114(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_116(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_115(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_117(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_116(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_118(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_117(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_119(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_118(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_120(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_119(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_121(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_120(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_122(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_121(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_123(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_122(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_124(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_123(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define DP_125(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// WHAT(NAME, SIGNATURE, TYPE_A) EVAL(DP_124(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +// #define DT_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// #define DT_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) + +// #define DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// #define DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_1(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_4(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_5(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_6(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_7(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_8(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_9(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_10(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_11(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_12(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_13(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_14(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_15(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_16(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_17(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_18(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define DT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVAL(DT2_19(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) + +// #define TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// #define TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_2(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_3(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_4(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_5(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_6(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_7(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_8(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_9(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_10(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_11(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_12(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_13(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_14(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_15(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_16(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_17(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_18(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) +// #define TTT1_20(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, +// ...) +// WHAT(YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT1_19(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, +// __VA_ARGS__)) + +// #define TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// #define TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT2_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT2_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +// #define TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// #define TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_1(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_4(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_5(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_6(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_7(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_8(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_9(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_10(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_11(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_12(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_13(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_14(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_15(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_16(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_17(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_18(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TTT3_20(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TTT3_19(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +// #define TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// #define TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT1_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT1_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +// #define TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// #define TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT2_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT2_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +// #define TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// #define TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_1(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_2(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_3(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_4(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_5(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_6(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_7(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_8(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_9(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_10(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_11(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_12(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_13(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_14(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_15(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_16(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_17(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_18(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) +// #define TT3_20(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C, ...) +// WHAT(NAME, SIGNATURE, TYPE_A, TYPE_B, TYPE_C) +// EVAL(TT3_19(WHAT, NAME, SIGNATURE, TYPE_A, TYPE_B, __VA_ARGS__)) + +// #define GET_MACRO_SEL_T(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, +// _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_SEL_P1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, +// _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_SEL_P2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, +// _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_SEL_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, +// _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_SEL_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, +// _13, _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_DS(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_DT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_DP( +// _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, +// _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, +// _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, +// _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, +// _62, _63, _64, _65, _66, _67, _68, _69, _70, _71, _72, _73, _74, _75, _76, +// _77, _78, _79, _80, _81, _82, _83, _84, _85, _86, _87, _88, _89, _90, _91, +// _92, _93, _94, _95, _96, _97, _98, _99, _100, _101, _102, _103, _104, +// _105, _106, _107, _108, _109, _110, _111, _112, _113, _114, _115, _116, +// _117, _118, _119, _120, _121, _122, _123, _124, _125, NAME, ...) +// NAME +// #define GET_MACRO_DT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME + +// #define GET_MACRO_TT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_TT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_TT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME + +// #define GET_MACRO_TTT1(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_TTT2(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME +// #define GET_MACRO_TTT3(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, +// _14, _15, _16, _17, _18, _19, _20, NAME, ...) +// NAME + +// #define FOR_EACH_S1(WHAT, NAME, SIGNATURE, ...) +// EXPAND(GET_MACRO_SEL_T(__VA_ARGS__, SEL_T_20, SEL_T_19, SEL_T_18, SEL_T_17, +// SEL_T_16, SEL_T_15, SEL_T_14, SEL_T_13, SEL_T_12, +// SEL_T_11, SEL_T_10, SEL_T_9, SEL_T_8, SEL_T_7, +// SEL_T_6, SEL_T_5, SEL_T_4, SEL_T_3, SEL_T_2, +// SEL_T_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) +// EXPAND(GET_MACRO_SEL_TT1( +// __VA_ARGS__, SEL_TT1_20, SEL_TT1_19, SEL_TT1_18, SEL_TT1_17, SEL_TT1_16, +// SEL_TT1_15, SEL_TT1_14, SEL_TT1_13, SEL_TT1_12, SEL_TT1_11, SEL_TT1_10, +// SEL_TT1_9, SEL_TT1_8, SEL_TT1_7, SEL_TT1_6, SEL_TT1_5, SEL_TT1_4, +// SEL_TT1_3, SEL_TT1_2, +// SEL_TT1_1)(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +// #define FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) +// EXPAND(GET_MACRO_SEL_P1(__VA_ARGS__, SEL_P1_20, SEL_P1_19, SEL_P1_18, +// SEL_P1_17, SEL_P1_16, SEL_P1_15, SEL_P1_14, +// SEL_P1_13, SEL_P1_12, SEL_P1_11, SEL_P1_10, +// SEL_P1_9, SEL_P1_8, SEL_P1_7, SEL_P1_6, SEL_P1_5, +// SEL_P1_4, SEL_P1_3, SEL_P1_2, SEL_P1_1)( +// WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +// #define FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, ...) +// EXPAND2(GET_MACRO_SEL_P2(__VA_ARGS__, SEL_P2_20, SEL_P2_19, SEL_P2_18, +// SEL_P2_17, SEL_P2_16, SEL_P2_15, SEL_P2_14, +// SEL_P2_13, SEL_P2_12, SEL_P2_11, SEL_P2_10, +// SEL_P2_9, SEL_P2_8, SEL_P2_7, SEL_P2_6, SEL_P2_5, +// SEL_P2_4, SEL_P2_3, SEL_P2_2, SEL_P2_1)( +// WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +// #define FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// EXPAND(GET_MACRO_SEL_TT2( +// __VA_ARGS__, SEL_TT2_20, SEL_TT2_19, SEL_TT2_18, SEL_TT2_17, SEL_TT2_16, +// SEL_TT2_15, SEL_TT2_14, SEL_TT2_13, SEL_TT2_12, SEL_TT2_11, SEL_TT2_10, +// SEL_TT2_9, SEL_TT2_8, SEL_TT2_7, SEL_TT2_6, SEL_TT2_5, SEL_TT2_4, +// SEL_TT2_3, SEL_TT2_2, +// SEL_TT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define FOR_EACH_DS(WHAT, NAME, SIGNATURE, ...) +// EXPAND(GET_MACRO_DS(__VA_ARGS__, DS_20, DS_19, DS_18, DS_17, DS_16, DS_15, +// DS_14, DS_13, DS_12, DS_11, DS_10, DS_9, DS_8, DS_7, +// DS_6, DS_5, DS_4, DS_3, DS_2, +// DS_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define FOR_EACH_DT(WHAT, NAME, SIGNATURE, TYPES_A, ...) +// EXPAND(GET_MACRO_DT(__VA_ARGS__, DT_20, DT_19, DT_18, DT_17, DT_16, DT_15, +// DT_14, DT_13, DT_12, DT_11, DT_10, DT_9, DT_8, DT_7, +// DT_6, DT_5, DT_4, DT_3, DT_2, +// DT_1)(WHAT, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +// #define FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// EXPAND(GET_MACRO_DT2(__VA_ARGS__, DT2_20, DT2_19, DT2_18, DT2_17, DT2_16, +// DT2_15, DT2_14, DT2_13, DT2_12, DT2_11, DT2_10, DT2_9, +// DT2_8, DT2_7, DT2_6, DT2_5, DT2_4, DT2_3, DT2_2, +// DT2_1)(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define FOR_EACH_DP(WHAT, NAME, SIGNATURE, ...) +// EXPAND(GET_MACRO_DP( +// __VA_ARGS__, DP_125, DP_124, DP_123, DP_122, DP_121, DP_120, DP_119, +// DP_118, DP_117, DP_116, DP_115, DP_114, DP_113, DP_112, DP_111, DP_110, +// DP_109, DP_108, DP_107, DP_106, DP_105, DP_104, DP_103, DP_102, DP_101, +// DP_100, DP_99, DP_98, DP_97, DP_96, DP_95, DP_94, DP_93, DP_92, DP_91, +// DP_90, DP_89, DP_88, DP_87, DP_86, DP_85, DP_84, DP_83, DP_82, DP_81, +// DP_80, DP_79, DP_78, DP_77, DP_76, DP_75, DP_74, DP_73, DP_72, DP_71, +// DP_70, DP_69, DP_68, DP_67, DP_66, DP_65, DP_64, DP_63, DP_62, DP_61, +// DP_60, DP_59, DP_58, DP_57, DP_56, DP_55, DP_54, DP_53, DP_52, DP_51, +// DP_50, DP_49, DP_48, DP_47, DP_46, DP_45, DP_44, DP_43, DP_42, DP_41, +// DP_40, DP_39, DP_38, DP_37, DP_36, DP_35, DP_34, DP_33, DP_32, DP_31, +// DP_30, DP_29, DP_28, DP_27, DP_26, DP_25, DP_24, DP_23, DP_22, DP_21, +// DP_20, DP_19, DP_18, DP_17, DP_16, DP_15, DP_14, DP_13, DP_12, DP_11, +// DP_10, DP_9, DP_8, DP_7, DP_6, DP_5, DP_4, DP_3, DP_2, +// DP_1)(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +// #define FOR_EACH_TT1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) +// EXPAND(GET_MACRO_TT1(__VA_ARGS__, TT1_20, TT1_19, TT1_18, TT1_17, TT1_16, +// TT1_15, TT1_14, TT1_13, TT1_12, TT1_11, TT1_10, TT1_9, +// TT1_8, TT1_7, TT1_6, TT1_5, TT1_4, TT1_3, TT1_2, +// TT1_1)(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, +// __VA_ARGS__)) +// #define FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) +// EXPAND(GET_MACRO_TT2(__VA_ARGS__, TT2_20, TT2_19, TT2_18, TT2_17, TT2_16, +// TT2_15, TT2_14, TT2_13, TT2_12, TT2_11, TT2_10, TT2_9, +// TT2_8, TT2_7, TT2_6, TT2_5, TT2_4, TT2_3, TT2_2, +// TT2_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, +// __VA_ARGS__)) +// #define FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) +// EXPAND(GET_MACRO_TT3(__VA_ARGS__, TT3_20, TT3_19, TT3_18, TT3_17, TT3_16, +// TT3_15, TT3_14, TT3_13, TT3_12, TT3_11, TT3_10, TT3_9, +// TT3_8, TT3_7, TT3_6, TT3_5, TT3_4, TT3_3, TT3_2, +// TT3_1)(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, +// __VA_ARGS__)) + +// #define FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, +// ...) +// EXPAND(GET_MACRO_TTT1(__VA_ARGS__, TTT1_20, TTT1_19, TTT1_18, TTT1_17, +// TTT1_16, TTT1_15, TTT1_14, TTT1_13, TTT1_12, TTT1_11, +// TTT1_10, TTT1_9, TTT1_8, TTT1_7, TTT1_6, TTT1_5, +// TTT1_4, TTT1_3, TTT1_2, TTT1_1)( +// WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, __VA_ARGS__)) +// #define FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) +// EXPAND(GET_MACRO_TTT2(__VA_ARGS__, TTT2_20, TTT2_19, TTT2_18, TTT2_17, +// TTT2_16, TTT2_15, TTT2_14, TTT2_13, TTT2_12, TTT2_11, +// TTT2_10, TTT2_9, TTT2_8, TTT2_7, TTT2_6, TTT2_5, +// TTT2_4, TTT2_3, TTT2_2, TTT2_1)( +// WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__)) +// #define FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) +// EXPAND(GET_MACRO_TTT3(__VA_ARGS__, TTT3_20, TTT3_19, TTT3_18, TTT3_17, +// TTT3_16, TTT3_15, TTT3_14, TTT3_13, TTT3_12, TTT3_11, +// TTT3_10, TTT3_9, TTT3_8, TTT3_7, TTT3_6, TTT3_5, +// TTT3_4, TTT3_3, TTT3_2, TTT3_1)( +// WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) + +// #define _EXEC_SELECTOR_T(WHAT, NAME, SIGNATURE, ...) +// EVAL(FOR_EACH_S1(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define _EXEC_SELECTOR_P_1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, +// TYPES_A, ...) +// EVAL(FOR_EACH_P1(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, +// __VA_ARGS__)) +// #define _EXEC_SELECTOR_P_2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// ...) +// EVAL(FOR_EACH_P2(WHAT, XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)) +// #define _EXEC_SELECTOR_TT_1(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, ...) +// EVAL(FOR_EACH_S2(WHAT, YTYPE, NAME, SIGNATURE, TYPES_A, __VA_ARGS__)) +// #define _EXEC_SELECTOR_TT_2(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// EVAL(FOR_EACH_S3(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define _EXEC_SINGLE_T(WHAT, NAME, SIGNATURE, ...) +// EVAL(FOR_EACH_DS(WHAT, NAME, SIGNATURE, __VA_ARGS__)) +// #define _EXEC_DOUBLE_T(WHAT, NAME, SIGNATURE, TYPES_A, ...) +// EVAL(FOR_EACH_DT(WHAT, NAME, SIGNATURE, LIST(TYPES_A), __VA_ARGS__)) +// #define _EXEC_DOUBLE_T2(WHAT, NAME, SIGNATURE, TYPE_A, ...) +// EVAL(FOR_EACH_DT2(WHAT, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)) +// #define _EXEC_DOUBLE_P(WHAT, NAME, SIGNATURE, ...) +// EVAL(FOR_EACH_DP(WHAT, NAME, SIGNATURE, __VA_ARGS__)) + +// #define _EXEC_SELECTOR_TTT_1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, +// TYPES_Y, ...) +// EVAL(FOR_EACH_TTT1(WHAT, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, +// __VA_ARGS__)) +// #define _EXEC_SELECTOR_TTT_2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, +// ...) +// EVAL(FOR_EACH_TTT2(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, +// __VA_ARGS__)) +// #define _EXEC_SELECTOR_TTT_3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, +// ...) +// EVAL(FOR_EACH_TTT3(WHAT, ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, __VA_ARGS__)) + +// #define _EXEC_TRIPLE_T3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, ...) +// EVAL(FOR_EACH_TT3(WHAT, NAME, SIGNATURE, TYPE_Z, TYPE_Y, __VA_ARGS__)) +// #define _EXEC_TRIPLE_T2(WHAT, NAME, SIGNATURE, TYPE_Z, TYPES_X, ...) +// EVAL(FOR_EACH_TT2(WHAT, NAME, SIGNATURE, TYPE_Z, LIST(TYPES_X), __VA_ARGS__)) +// #define _EXEC_TRIPLE_T1(WHAT, NAME, SIGNATURE, TYPES_X, TYPES_Y, ...) +// EVAL(FOR_EACH_TT1(WHAT, NAME, SIGNATURE, LIST(TYPES_X), LIST(TYPES_Y), +// __VA_ARGS__)) + +// #define DISPATCH_PAIRWISE(NAME, SIGNATURE, TYPE, TYPES_B) +// EVAL(_EXEC_DOUBLE_T2(RANDOMPAIRWISE2, NAME, SIGNATURE, TYPE, TYPES_B)) +// #define DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE, ...) +// EVAL(_EXEC_SELECTOR_P_2(SELECTOR_PAIRWISE_2, XTYPE, YTYPE, ZTYPE, NAME, +// SIGNATURE, TYPE, __VA_ARGS__)) + +// #define DISPATCH_DTYPES(NAME, SIGNATURE, TYPE, TYPES_B) +// EVAL(_EXEC_DOUBLE_T2(RANDOMDOUBLE2, NAME, SIGNATURE, TYPE, TYPES_B)) +// #define DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE, ...) +// EVAL(_EXEC_SELECTOR_TT_2(SELECTOR_DOUBLE_2, NAME, SIGNATURE, TYPE, +// __VA_ARGS__)) + +// #define DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, ...) +// EVAL(_EXEC_SELECTOR_TTT_2(SELECTOR_TRIPLE_2, ZTYPE, NAME, SIGNATURE, TYPE_X, +// TYPES_Z, __VA_ARGS__)) +// #define DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, ...) +// EVAL(_EXEC_SELECTOR_TTT_3(SELECTOR_TRIPLE_3, ZTYPE, NAME, SIGNATURE, TYPE_X, +// TYPE_Y, __VA_ARGS__)) // #ifndef __CLION_IDE__ -// #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLEU, NAME, (SIGNATURE), TYPES)) -// #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLE, NAME, (SIGNATURE), TYPES)) -// #define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SELECTOR_T(TEMPLATE_SINGLE_TWICE, NAME, SIGNATURE, TYPES)) -// #define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) EVAL(_EXEC_DOUBLE_T(RANDOMDOUBLE, NAME, (SIGNATURE), (TYPES_A), TYPES_B)) -// #define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} -// #define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_TWICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} -// #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_THRICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} - - -// #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_PARTIAL_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); }} -// #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TT_1(SELECTOR_DOUBLE, YTYPE, NAME, (SIGNATURE), (TYPES_B), TYPES_A)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} -// #define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TTT_1(SELECTOR_TRIPLE, YTYPE, ZTYPE, NAME, SIGNATURE, (TYPES_Z), (TYPES_Y), TYPES_X)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); } } -// #define BUILD_TRIPLE_TEMPLATE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) EVAL(_EXEC_TRIPLE_T1(RANDOMTRIPLE, NAME, (SIGNATURE), (TYPES_X), (TYPES_Y), TYPES_Z)) -// #define BUILD_PAIRWISE_TEMPLATE(NAME, SIGNATURE, TYPES_A) EVAL(_EXEC_DOUBLE_P(RANDOMPAIRWISE, NAME, (SIGNATURE), TYPES_A)) -// #define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) switch(XTYPE) { EVAL(_EXEC_SELECTOR_P_1(SELECTOR_PAIRWISE, XTYPE, YTYPE, ZTYPE, NAME, (SIGNATURE), (TYPES_B), TYPES_A)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); }} +// #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) +// EVAL(_EXEC_SINGLE_T(RANDOMSINGLEU, NAME, (SIGNATURE), TYPES)) +// #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) +// EVAL(_EXEC_SINGLE_T(RANDOMSINGLE, NAME, (SIGNATURE), TYPES)) +// #define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) +// EVAL(_EXEC_SELECTOR_T(TEMPLATE_SINGLE_TWICE, NAME, SIGNATURE, TYPES)) +// #define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) +// EVAL(_EXEC_DOUBLE_T(RANDOMDOUBLE, NAME, (SIGNATURE), (TYPES_A), TYPES_B)) +// #define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) +// switch (XTYPE) { +// EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE, NAME, SIGNATURE, TYPES)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); +// fflush(stdout); +// throw std::runtime_error("bad data type"); +// } +// } +// #define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) +// switch (XTYPE) { +// EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_TWICE, NAME, SIGNATURE, TYPES)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); +// fflush(stdout); +// throw std::runtime_error("bad data type"); +// } +// } +// #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) +// switch (XTYPE) { +// EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_THRICE, NAME, SIGNATURE, TYPES)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); +// fflush(stdout); +// throw std::runtime_error("bad data type"); +// } +// } + +// #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) +// switch (XTYPE) { +// EVAL(_EXEC_SELECTOR_T(SELECTOR_PARTIAL_SINGLE, NAME, SIGNATURE, TYPES)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); +// fflush(stdout); +// throw std::runtime_error("bad data type"); +// } +// } +// #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) +// switch (XTYPE) { +// EVAL(_EXEC_SELECTOR_TT_1(SELECTOR_DOUBLE, YTYPE, NAME, (SIGNATURE), +// (TYPES_B), TYPES_A)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); +// fflush(stdout); +// throw std::runtime_error("bad data type"); +// } +// } +// #define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, +// TYPES_Y, TYPES_Z) +// switch (XTYPE) { +// EVAL(_EXEC_SELECTOR_TTT_1(SELECTOR_TRIPLE, YTYPE, ZTYPE, NAME, SIGNATURE, +// (TYPES_Z), (TYPES_Y), TYPES_X)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); +// fflush(stdout); +// throw std::runtime_error("bad data type"); +// } +// } +// #define BUILD_TRIPLE_TEMPLATE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) +// EVAL(_EXEC_TRIPLE_T1(RANDOMTRIPLE, NAME, (SIGNATURE), (TYPES_X), (TYPES_Y), +// TYPES_Z)) +// #define BUILD_PAIRWISE_TEMPLATE(NAME, SIGNATURE, TYPES_A) +// EVAL(_EXEC_DOUBLE_P(RANDOMPAIRWISE, NAME, (SIGNATURE), TYPES_A)) +// #define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, +// TYPES_B) +// switch (XTYPE) { +// EVAL(_EXEC_SELECTOR_P_1(SELECTOR_PAIRWISE, XTYPE, YTYPE, ZTYPE, NAME, +// (SIGNATURE), (TYPES_B), TYPES_A)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); +// fflush(stdout); +// throw std::runtime_error("bad data type"); +// } +// } // #else // #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) // #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) @@ -10078,77 +12064,208 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) // #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) // #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) -// #define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) +// #define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, +// TYPES_Y, TYPES_Z) // #define BUILD_TRIPLE_TEMPLATE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) // #define BUILD_PAIRWISE_TEMPLATE(NAME, SIGNATURE, TYPES_A) -// #define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) +// #define BUILD_PAIRWISE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_A, +// TYPES_B) // #endif // #define LIST(...) __VA_ARGS__ -// #define _SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, ENUM, TYPE_B) case ENUM: { NAME SIGNATURE; break; }; -// #define SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, TYPE_B) EVALUATING_PASTE2(_SELECT, OR_DOUBLE_2(NAME, UNPAREN3(SIGNATURE), TYPE_A, UNPAREN3(TYPE_B))) - -// #define _SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, ENUM, TYPE_A, ...) case ENUM: { switch(YTYPE) { EXPAND(DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE_A, __VA_ARGS__)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); fflush(stdout);}}; break; }; -// #define SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, TYPES_B, TYPE_A) EVALUATING_PASTE(_SELECTOR, _DOUBLE(YTYPE, NAME, SIGNATURE, UNPAREN(TYPE_A), UNPAREN(TYPES_B))) - -// #define _SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, ENUM, TYPE_B) case ENUM: { if (ZTYPE == YTYPE) {NAME SIGNATURE;} else if (XTYPE == ZTYPE ){NAME SIGNATURE;} else {printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("Unknown Z operand");}; break; }; -// #define SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, TYPE_B) EVALUATING_PASTE2(_SELECT, OR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, UNPAREN3(SIGNATURE), TYPE_A, UNPAREN3(TYPE_B))) -// #define _SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, ENUM, TYPE_A, ...) case ENUM: { switch(YTYPE) { EXPAND(DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, __VA_ARGS__)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); fflush(stdout);}}; break; }; -// #define SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_B, TYPE_A) EVALUATING_PASTE(_SELECTOR, _PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, UNPAREN(TYPE_A), UNPAREN(TYPES_B))) - -// #define _SELECTOR_TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, ENUM_Z, TYPE_Z) case ENUM_Z: { NAMESIGNATURE;}; break; -// #define SELECTOR_TRIPLE_3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, TYPE_Z) EVALUATING_PASTE3(_SELECTOR, _TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, UNPAREN3(TYPE_Z))) -// #define _SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, ENUM_Y, TYPE_Y, TYPES_Z) case ENUM_Y: { switch (ZTYPE) { EXPAND2(DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, UNPAREN3(TYPES_Z))); default: {printf("[ERROR] Unknown dtypeZ=%d on %s:%d\n", ZTYPE, __FILE__, __LINE__); ; fflush(stdout);} } break; }; -// #define SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, TYPE_Y) EVALUATING_PASTE2(_SELECTOR, _TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, UNPAREN2(TYPE_Y), TYPES_Z)) -// #define _SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, ENUM_X, TYPE_X, TYPES_Z, ...) case ENUM_X: { switch (YTYPE) { EXPAND(DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, __VA_ARGS__ )); default: {printf("[ERROR] Unknown dtypeY=%d on %s:%d\n", YTYPE, __FILE__, __LINE__); ; fflush(stdout);} } break; }; -// #define SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, TYPE_X) EVALUATING_PASTE(_SELECTOR, _TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, UNPAREN(TYPE_X), TYPES_Z, UNPAREN(TYPES_Y))) - -// #define _SELECTOR_SINGLE(A, B, C, D) case C: {AB; break;}; -// #define SELECTOR_SINGLE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE(A, B, UNPAREN(C))) - -// #define _SELECTOR_SINGLE_THRICE(A, B, C, D) case C: {AB; break;}; -// #define SELECTOR_SINGLE_THRICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_THRICE(A, B, UNPAREN(C))) - -// #define _SELECTOR_SINGLE_TWICE(A, B, C, D) case C: {AB; break;}; -// #define SELECTOR_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_TWICE(A, B, UNPAREN(C))) - -// #define _TEMPLATE_SINGLE_TWICE(A, B, C, D) AB; -// #define TEMPLATE_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_TEM, PLATE_SINGLE_TWICE(A, B, UNPAREN(C))) - -// #define _SELECTOR_PARTIAL_SINGLE(A, B, C, D) case C: {A D, UNPAREN2(B); break;}; -// #define SELECTOR_PARTIAL_SINGLE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_PARTIAL_SINGLE(A, B, UNPAREN(C))) - -// #define _RANDOMSINGLE(A, B, C, D) AB; +// #define _SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, ENUM, TYPE_B) +// case ENUM: { +// NAME SIGNATURE; +// break; +// }; +// #define SELECTOR_DOUBLE_2(NAME, SIGNATURE, TYPE_A, TYPE_B) +// EVALUATING_PASTE2(_SELECT, OR_DOUBLE_2(NAME, UNPAREN3(SIGNATURE), TYPE_A, +// UNPAREN3(TYPE_B))) + +// #define _SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, ENUM, TYPE_A, ...) +// case ENUM: { +// switch (YTYPE) { +// EXPAND(DISPATCH_DTYPES2(NAME, SIGNATURE, TYPE_A, __VA_ARGS__)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, +// __LINE__); +// fflush(stdout); +// } +// }; +// break; +// }; +// #define SELECTOR_DOUBLE(YTYPE, NAME, SIGNATURE, TYPES_B, TYPE_A) +// EVALUATING_PASTE(_SELECTOR, _DOUBLE(YTYPE, NAME, SIGNATURE, UNPAREN(TYPE_A), +// UNPAREN(TYPES_B))) + +// #define _SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// ENUM, TYPE_B) +// case ENUM: { +// if (ZTYPE == YTYPE) { +// NAME SIGNATURE; +// } else if (XTYPE == ZTYPE) { +// NAME SIGNATURE; +// } else { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, +// __LINE__); +// fflush(stdout); +// throw std::runtime_error("Unknown Z operand"); +// }; +// break; +// }; +// #define SELECTOR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// TYPE_B) +// EVALUATING_PASTE2( +// _SELECT, OR_PAIRWISE_2(XTYPE, YTYPE, ZTYPE, NAME, UNPAREN3(SIGNATURE), +// TYPE_A, UNPAREN3(TYPE_B))) +// #define _SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, ENUM, TYPE_A, +// ...) +// case ENUM: { +// switch (YTYPE) { +// EXPAND(DISPATCH_PAIRWISE2(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPE_A, +// __VA_ARGS__)); +// default: { +// printf("[ERROR] Unknown dtypeX=%d on %s:%d\n", YTYPE, __FILE__, +// __LINE__); +// fflush(stdout); +// } +// }; +// break; +// }; +// #define SELECTOR_PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_B, +// TYPE_A) +// EVALUATING_PASTE(_SELECTOR, _PAIRWISE(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, +// UNPAREN(TYPE_A), UNPAREN(TYPES_B))) + +// #define _SELECTOR_TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, ENUM_Z, TYPE_Z) +// case ENUM_Z: { +// NAME SIGNATURE; +// }; break; +// #define SELECTOR_TRIPLE_3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, TYPE_Z) +// EVALUATING_PASTE3( +// _SELECTOR, _TRIPLE_3(NAME, SIGNATURE, TYPE_X, TYPE_Y, UNPAREN3(TYPE_Z))) +// #define _SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, ENUM_Y, TYPE_Y, +// TYPES_Z) +// case ENUM_Y: { +// switch (ZTYPE) { +// EXPAND2(DISPATCH_TTYPES3(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPE_Y, +// UNPAREN3(TYPES_Z))); +// default: { +// printf("[ERROR] Unknown dtypeZ=%d on %s:%d\n", ZTYPE, __FILE__, +// __LINE__); +// ; +// fflush(stdout); +// } +// } +// break; +// }; +// #define SELECTOR_TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, TYPE_Y) +// EVALUATING_PASTE2(_SELECTOR, _TRIPLE_2(ZTYPE, NAME, SIGNATURE, TYPE_X, +// UNPAREN2(TYPE_Y), TYPES_Z)) +// #define _SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, ENUM_X, TYPE_X, +// TYPES_Z, ...) +// case ENUM_X: { +// switch (YTYPE) { +// EXPAND(DISPATCH_TTYPES2(ZTYPE, NAME, SIGNATURE, TYPE_X, TYPES_Z, +// __VA_ARGS__)); +// default: { +// printf("[ERROR] Unknown dtypeY=%d on %s:%d\n", YTYPE, __FILE__, +// __LINE__); +// ; +// fflush(stdout); +// } +// } +// break; +// }; +// #define SELECTOR_TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_Z, TYPES_Y, +// TYPE_X) +// EVALUATING_PASTE(_SELECTOR, +// _TRIPLE(YTYPE, ZTYPE, NAME, SIGNATURE, UNPAREN(TYPE_X), +// TYPES_Z, UNPAREN(TYPES_Y))) + +// #define _SELECTOR_SINGLE(A, B, C, D) +// case C: { +// A B; +// break; +// }; +// #define SELECTOR_SINGLE(A, B, C) +// EVALUATING_PASTE(_SEL, ECTOR_SINGLE(A, B, UNPAREN(C))) + +// #define _SELECTOR_SINGLE_THRICE(A, B, C, D) +// case C: { +// A B; +// break; +// }; +// #define SELECTOR_SINGLE_THRICE(A, B, C) +// EVALUATING_PASTE(_SEL, ECTOR_SINGLE_THRICE(A, B, UNPAREN(C))) + +// #define _SELECTOR_SINGLE_TWICE(A, B, C, D) +// case C: { +// A B; +// break; +// }; +// #define SELECTOR_SINGLE_TWICE(A, B, C) +// EVALUATING_PASTE(_SEL, ECTOR_SINGLE_TWICE(A, B, UNPAREN(C))) + +// #define _TEMPLATE_SINGLE_TWICE(A, B, C, D) A B; +// #define TEMPLATE_SINGLE_TWICE(A, B, C) +// EVALUATING_PASTE(_TEM, PLATE_SINGLE_TWICE(A, B, UNPAREN(C))) + +// #define _SELECTOR_PARTIAL_SINGLE(A, B, C, D) +// case C: { +// A D, UNPAREN2(B); +// break; +// }; +// #define SELECTOR_PARTIAL_SINGLE(A, B, C) +// EVALUATING_PASTE(_SEL, ECTOR_PARTIAL_SINGLE(A, B, UNPAREN(C))) + +// #define _RANDOMSINGLE(A, B, C, D) A B; // #define _RANDOMSINGLEU(A, B, C, D) A D B; -// #define RANDOMSINGLE(A, B, C) EVALUATING_PASTE(_RAND, OMSINGLE(A, UNPAREN(B), UNPAREN(C))) -// #define RANDOMSINGLEU(A, B, C) EVALUATING_PASTE(_RAND, OMSINGLEU(A, UNPAREN(B), UNPAREN(C))) -// #define RANDOMDOUBLE(A, B, C, D) EXPAND(DISPATCH_DTYPES(A, UNPAREN(B), D, UNPAREN(C))) - -// #define _RANDOMDOUBLE2(A, B, C, D, E, F) AB; -// #define RANDOMDOUBLE2(A, B, C, D) EVALUATING_PASTE(_RAND, OMDOUBLE2(A, B, UNPAREN(C), UNPAREN(D))) - -// #define _RANDOMPAIRWISE2(A, B, C, D, E) AE; -// #define RANDOMPAIRWISE(A, B, C) EVALUATING_PASTE(_RANDOM, PAIRWISE2(A, UNPAREN(C), UNPAREN(B))) - -// #define _RANDOMTRIPLE3(A, B, ZN, ZT, YN, YT, XN, XT) AB; -// #define RANDOMTRIPLE3(A, B, Z, Y, X) EVALUATING_PASTE(_RANDOM, TRIPLE3(A, UNPAREN(B), UNPAREN(Z), UNPAREN(Y), UNPAREN(X))) - -// #define _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) EVALX(_EXEC_TRIPLE_T3(RANDOMTRIPLE3, NAME, SIGNATURE, TYPE_Z, TYPE_Y, UNPAREN(TYPES_X))) -// #define RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPE_Y) _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) -// #define _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) EVAL(_EXEC_TRIPLE_T2(RANDOMTRIPLE2, NAME, SIGNATURE, TYPE_Z, TYPES_X, UNPAREN(TYPES_Y))) -// #define RANDOMTRIPLE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPE_Z) _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) - - -// #define BROADCAST(NAME) sd::BroadcastOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, sd::broadcast::NAME) -// #define BROADCAST_BOOL(NAME) sd::BroadcastBoolOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, sd::broadcast::NAME) +// #define RANDOMSINGLE(A, B, C) +// EVALUATING_PASTE(_RAND, OMSINGLE(A, UNPAREN(B), UNPAREN(C))) +// #define RANDOMSINGLEU(A, B, C) +// EVALUATING_PASTE(_RAND, OMSINGLEU(A, UNPAREN(B), UNPAREN(C))) +// #define RANDOMDOUBLE(A, B, C, D) +// EXPAND(DISPATCH_DTYPES(A, UNPAREN(B), D, UNPAREN(C))) + +// #define _RANDOMDOUBLE2(A, B, C, D, E, F) A B; +// #define RANDOMDOUBLE2(A, B, C, D) +// EVALUATING_PASTE(_RAND, OMDOUBLE2(A, B, UNPAREN(C), UNPAREN(D))) + +// #define _RANDOMPAIRWISE2(A, B, C, D, E) A E; +// #define RANDOMPAIRWISE(A, B, C) +// EVALUATING_PASTE(_RANDOM, PAIRWISE2(A, UNPAREN(C), UNPAREN(B))) + +// #define _RANDOMTRIPLE3(A, B, ZN, ZT, YN, YT, XN, XT) A B; +// #define RANDOMTRIPLE3(A, B, Z, Y, X) +// EVALUATING_PASTE(_RANDOM, +// TRIPLE3(A, UNPAREN(B), UNPAREN(Z), UNPAREN(Y), UNPAREN(X))) + +// #define _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) +// EVALX(_EXEC_TRIPLE_T3(RANDOMTRIPLE3, NAME, SIGNATURE, TYPE_Z, TYPE_Y, +// UNPAREN(TYPES_X))) +// #define RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPE_Y) +// _RANDOMTRIPLE2(NAME, SIGNATURE, TYPE_Z, TYPE_Y, TYPES_X) +// #define _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) +// EVAL(_EXEC_TRIPLE_T2(RANDOMTRIPLE2, NAME, SIGNATURE, TYPE_Z, TYPES_X, +// UNPAREN(TYPES_Y))) +// #define RANDOMTRIPLE(NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPE_Z) +// _RANDOMTRIPLE(NAME, SIGNATURE, TYPE_Z, TYPES_X, TYPES_Y) + +// #define BROADCAST(NAME) +// sd::BroadcastOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, +// sd::broadcast::NAME) +// #define BROADCAST_BOOL(NAME) +// sd::BroadcastBoolOpsTuple::custom(sd::scalar::NAME, sd::pairwise::NAME, +// sd::broadcast::NAME) public static final int ALL_STRINGS =UTF32; public static final int ALL_INDICES =INT64; public static final int ALL_INTS =UINT64; public static final int ALL_FLOATS =BFLOAT16; -// #endif //TESTS_CPU_TYPE_BOILERPLATE_H +// #endif // TESTS_CPU_TYPE_BOILERPLATE_H // Parsed from system/op_boilerplate.h @@ -11366,12 +13483,12 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define REQUIRE_OK(A) if (sd::ops::resultHelper( (A), #A, __FILE__, __LINE__ ) != 0) return ND4J_STATUS_VALIDATION; // #define REQUIRE_TRUE(COND, ...) if (!(COND)) { if (sd::ops::conditionHelper(__FILE__, __LINE__, COND, __VA_ARGS__) != 0) throw std::invalid_argument("Op validation failed");}; -// #define DECLARE_ENTRY(NAME, ...) template struct ND4J_EXPORT __registratorFloat>; -// template struct ND4J_EXPORT __registratorHalf>; -// template struct ND4J_EXPORT __registratorDouble>; -// template struct ND4J_EXPORT __registratorSynonymHalf>; -// template struct ND4J_EXPORT __registratorSynonymDouble>; -// template struct ND4J_EXPORT __registratorSynonymFloat>; +// #define DECLARE_ENTRY(NAME, ...) template struct SD_EXPORT __registratorFloat>; +// template struct SD_EXPORT __registratorHalf>; +// template struct SD_EXPORT __registratorDouble>; +// template struct SD_EXPORT __registratorSynonymHalf>; +// template struct SD_EXPORT __registratorSynonymDouble>; +// template struct SD_EXPORT __registratorSynonymFloat>; // #if defined(_MSC_VER) || defined(_WIN64) || defined(_WIN32) || defined(__CLION_IDE__) || defined(__VSCODE__) @@ -11389,7 +13506,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define REGISTER_H(NAME) template // struct __registrator_##NAME { // __registrator_##NAME() { -// OpName *ptr = new OpName(); +// auto ptr = std::make_shared(); // OpRegistrator::getInstance().registerOperation(ptr); // } // }; @@ -11403,7 +13520,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define REGISTER_C(NAME) // #endif -// #define DECLARE_OP(NAME, NIN, NOUT, INPLACEABLE) class ND4J_EXPORT NAME: public sd::ops::DeclarableOp { +// #define DECLARE_OP(NAME, NIN, NOUT, INPLACEABLE) class SD_EXPORT NAME: public sd::ops::DeclarableOp { // public: // NAME(); // sd::ShapeList* calculateOutputShape(sd::ShapeList* inputShape, sd::graph::Context& block); @@ -11413,7 +13530,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // }; // REGISTER_H(NAME) -// #define DECLARE_BOOLEAN_OP(NAME, NIN, SCALAR) class ND4J_EXPORT NAME: public sd::ops::BooleanOp { +// #define DECLARE_BOOLEAN_OP(NAME, NIN, SCALAR) class SD_EXPORT NAME: public sd::ops::BooleanOp { // public: // NAME(); // protected: @@ -11426,7 +13543,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // REGISTER_C(NAME) // Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -// #define DECLARE_LIST_OP(NAME, NIN, NOUT, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableListOp { +// #define DECLARE_LIST_OP(NAME, NIN, NOUT, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableListOp { // public: // NAME(); // protected: @@ -11438,7 +13555,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // REGISTER_C(NAME) // Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -// #define DECLARE_LOGIC_OP(NAME) class ND4J_EXPORT NAME: public sd::ops::LogicOp { +// #define DECLARE_LOGIC_OP(NAME) class SD_EXPORT NAME: public sd::ops::LogicOp { // public: // NAME(); // protected: @@ -11469,7 +13586,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define DECLARE_SYN(NAME, ORIGINAL) template // struct __registratorSynonym_##NAME { // __registratorSynonym_##NAME(const char *name, const char *oname) { -// auto ptr = reinterpret_cast(OpRegistrator::getInstance().getOperation(oname)); +// auto ptr = OpRegistrator::getInstance().getOperation(oname); // if (ptr == nullptr) { // std::string newName(name); // std::string oldName(oname); @@ -11481,7 +13598,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // }; // static sd::ops::__registratorSynonym_##NAME zzz_register_opd_##NAME(#NAME, #ORIGINAL) -// #define DECLARE_DIVERGENT_OP(NAME, NIN, NOUT, INPLACEABLE) class ND4J_EXPORT NAME: public sd::ops::DeclarableOp { +// #define DECLARE_DIVERGENT_OP(NAME, NIN, NOUT, INPLACEABLE) class SD_EXPORT NAME: public sd::ops::DeclarableOp { // public: // NAME(); // sd::ShapeList* calculateOutputShape(sd::ShapeList* inputShape, sd::graph::Context& block); @@ -11504,7 +13621,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // } // Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -// #define DECLARE_CONFIGURABLE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableOp { +// #define DECLARE_CONFIGURABLE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableOp { // public: // NAME(); // sd::ShapeList* calculateOutputShape(sd::ShapeList* inputShape, sd::graph::Context& block); @@ -11527,7 +13644,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // } // Nd4jStatus sd::ops::NAME::validateAndExecute(Context& block) -// #define DECLARE_REDUCTION_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableReductionOp { +// #define DECLARE_REDUCTION_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableReductionOp { // public: // NAME(); // protected: @@ -11541,7 +13658,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) -// #define DECLARE_CUSTOM_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::DeclarableCustomOp { +// #define DECLARE_CUSTOM_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::DeclarableCustomOp { // protected: // void registerTypes(); // Nd4jStatus validateAndExecute(Context& block); @@ -11563,7 +13680,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define DECLARE_TYPES(NAME) void sd::ops::NAME::registerTypes() -// #define DECLARE_BROADCASTABLE_OP(NAME,TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::BroadcastableOp { +// #define DECLARE_BROADCASTABLE_OP(NAME,TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::BroadcastableOp { // protected: // void registerTypes(); // Nd4jStatus validateAndExecute(Context& block); @@ -11572,7 +13689,7 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // }; // REGISTER_H(NAME) -// #define DECLARE_BROADCASTABLE_BOOL_OP(NAME,TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::BroadcastableBoolOp { +// #define DECLARE_BROADCASTABLE_BOOL_OP(NAME,TARGS, IARGS) class SD_EXPORT NAME: public sd::ops::BroadcastableBoolOp { // protected: // void registerTypes(); // Nd4jStatus validateAndExecute(Context& block); @@ -11637,20 +13754,20 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define CHECK_STASH(NAME) block.getStash()->checkStash(block.getNodeId(), NAME); // #define UNSTASH(NAME) block.getStash()->extractArray(block.getNodeId(), NAME); -// #define INPUT_VARIABLE(INDEX) block.array(INDEX) +// #define INPUT_VARIABLE(INDEX) block.arrayForOp(INDEX) // #define OUTPUT_VARIABLE(INDEX) reinterpret_cast(this->getZ(block, INDEX)) // #define OUTPUT_NULLIFIED(INDEX) reinterpret_cast(this->getNullifiedZ(block, INDEX)) -// #define INPUT_LIST(INDEX) reinterpret_cast(block.getVariable(INDEX)->getNDArrayList()) +// #define INPUT_LIST(INDEX) reinterpret_cast(block.getVariable(INDEX)->getNDArrayList().get()) -// #define D_ARG(INDEX) block.getDArguments()->at(INDEX) -// #define INT_ARG(INDEX) block.getIArguments()->at(INDEX) +// #define D_ARG(INDEX) block.getDArguments().at(INDEX) +// #define INT_ARG(INDEX) block.getIArguments().at(INDEX) // #define I_ARG(INDEX) INT_ARG(INDEX) -// #define T_ARG(INDEX) block.getTArguments()->at(INDEX) -// #define B_ARG(INDEX) block.getBArguments()->at(INDEX) +// #define T_ARG(INDEX) block.getTArguments().at(INDEX) +// #define B_ARG(INDEX) block.getBArguments().at(INDEX) -// #define COPY_SHAPE(SRC, TGT) TGT = ShapeBuilders::copyShapeInfo(SRC, true, block.getWorkspace()) +// #define COPY_SHAPE(SRC, TGT) TGT = ShapeBuilders::copyShapeInfo(SRC, true, block.workspace()) // #define COPY_SHAPE_EX(SRC, TGT, WORKSPACE) TGT = ShapeBuilders::copyShapeInfo(SRC, true, WORKSPACE) @@ -11754,15 +13871,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef ND4J_INPUTTYPE_H // #define ND4J_INPUTTYPE_H - /** enum sd::ops::InputType */ - public static final int - InputType_BOOLEAN = 0, - InputType_NUMERIC = 1, - InputType_STRINGULAR = 2, - InputType_NUMERIC_SET = 3, - InputType_STRINGULAR_SET = 4; - +/** enum sd::ops::InputType */ +public static final int + InputType_BOOLEAN = 0, + InputType_NUMERIC = 1, + InputType_STRINGULAR = 2, + InputType_NUMERIC_SET = 3, + InputType_STRINGULAR_SET = 4; + // namespace sd // #endif @@ -11791,134 +13908,155 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_OPDESCRIPTOR_H // #define LIBND4J_OPDESCRIPTOR_H -// #include -// #include -// #include +// #include +// #include // #include // #include -// #include -// #include - - /** - * This class is very basic info holder for ops. bean/pojo pretty much. - * - */ - @Namespace("sd::ops") @NoOffset public static class OpDescriptor extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public OpDescriptor(Pointer p) { super(p); } - - // default constructor - public OpDescriptor(int numInputs, int numOutputs, @StdString BytePointer opName, @Cast("bool") boolean allowsInplace) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace); } - private native void allocate(int numInputs, int numOutputs, @StdString BytePointer opName, @Cast("bool") boolean allowsInplace); - public OpDescriptor(int numInputs, int numOutputs, @StdString String opName, @Cast("bool") boolean allowsInplace) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace); } - private native void allocate(int numInputs, int numOutputs, @StdString String opName, @Cast("bool") boolean allowsInplace); - - // constructor for boolean ops - public OpDescriptor(int numInputs, @StdString BytePointer opName, @Cast("bool") boolean isScalar) { super((Pointer)null); allocate(numInputs, opName, isScalar); } - private native void allocate(int numInputs, @StdString BytePointer opName, @Cast("bool") boolean isScalar); - public OpDescriptor(int numInputs, @StdString String opName, @Cast("bool") boolean isScalar) { super((Pointer)null); allocate(numInputs, opName, isScalar); } - private native void allocate(int numInputs, @StdString String opName, @Cast("bool") boolean isScalar); - - // default constructor - // constructor for configurable op - public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") String opName, @Cast("bool") boolean allowsInplace, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs); } - private native void allocate(int numInputs, int numOutputs, @Cast("char*") String opName, @Cast("bool") boolean allowsInplace, int tArgs, int iArgs); - public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, @Cast("bool") boolean allowsInplace, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs); } - private native void allocate(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, @Cast("bool") boolean allowsInplace, int tArgs, int iArgs); - - // constructor for non-configurable divergent op - public OpDescriptor(int numInputs, int numOutputs, @StdString BytePointer opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent); } - private native void allocate(int numInputs, int numOutputs, @StdString BytePointer opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent); - public OpDescriptor(int numInputs, int numOutputs, @StdString String opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent); } - private native void allocate(int numInputs, int numOutputs, @StdString String opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent); - - // constructor for non-configurable divergent op - - // constructor for configurable divergent op - public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") String opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent, tArgs, iArgs); } - private native void allocate(int numInputs, int numOutputs, @Cast("char*") String opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs); - public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent, tArgs, iArgs); } - private native void allocate(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs); - - // constructor for logical ops (while, scope, etc) - public OpDescriptor(@Cast("char*") String opName, @Cast("bool") boolean isLogic) { super((Pointer)null); allocate(opName, isLogic); } - private native void allocate(@Cast("char*") String opName, @Cast("bool") boolean isLogic); - public OpDescriptor(@Cast("char*") BytePointer opName, @Cast("bool") boolean isLogic) { super((Pointer)null); allocate(opName, isLogic); } - private native void allocate(@Cast("char*") BytePointer opName, @Cast("bool") boolean isLogic); - - public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef OpDescriptor other); - - // default destructor - - // this method returns minimal expected number of T arguments - public native int getNumberOfTArgs(); - - // this method returns minimal expected number of Integer arguments - public native int getNumberOfIArgs(); - - // this method returns minimal expected number of inputs - public native int getNumberOfInputs(); - - // this method returns hash code for this operation - public native @Cast("Nd4jLong") long getHash(); - - // this method returns minimal expected number of outputs - public native int getNumberOfOutputs(); - - // this method returns opName (can be empty) - public native @StdString @Cast({"char*", "std::string*"}) BytePointer getOpName(); - - // returns TRUE if this op is divergent. FALSE otherwise - public native @Cast("bool") boolean isDivergent(); - - // returns TRUE if this op allows in-place execution - public native @Cast("bool") boolean allowsInplace(); - - // this method allows you to enable/disable inplace call for a given op - public native void allowInplace(@Cast("bool") boolean reallyAllow); - - // this method returns opNum (applicable for legacy XYZ ops only) - public native int getOpNum(); - - // this method allows to set specifc opNum - public native void setOpNum(int opNum); - - public native void setHash(@Cast("Nd4jLong") long hash); - - public native @Cast("sd::ops::InputType") int inputType(); - - - - public native OpDescriptor setInputType(@Cast("sd::ops::InputType") int type); - public native OpDescriptor setAllowedInputTypes(int index, @Cast("sd::DataType*") @StdVector IntPointer dtype); - public native OpDescriptor setAllowedInputTypes(int index, @Cast("sd::DataType*") @StdVector IntBuffer dtype); - public native OpDescriptor setAllowedInputTypes(int index, @Cast("sd::DataType*") @StdVector int[] dtype); - public native OpDescriptor setAllowedOutputTypes(int index, @Cast("sd::DataType*") @StdVector IntPointer dtype); - public native OpDescriptor setAllowedOutputTypes(int index, @Cast("sd::DataType*") @StdVector IntBuffer dtype); - public native OpDescriptor setAllowedOutputTypes(int index, @Cast("sd::DataType*") @StdVector int[] dtype); - public native OpDescriptor setAllowedInputTypes(int index, @Cast("sd::DataType") int dtype); - public native OpDescriptor setAllowedOutputTypes(int index, @Cast("sd::DataType") int dtype); - public native OpDescriptor setAllowedInputTypes(@Cast("sd::DataType") int dtype); - public native OpDescriptor setAllowedOutputTypes(@Cast("sd::DataType") int dtype); - public native OpDescriptor allowOverride(@Cast("bool") boolean reallyAllow); - public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame); - public native OpDescriptor setInputType(int idx, @Cast("sd::DataType") int dtype); - public native OpDescriptor setOutputType(int idx, @Cast("sd::DataType") int dtype); - - public native @Cast("sd::DataType*") @StdVector IntPointer getOutputTypesForOutput(int index); - - public native @Cast("bool") boolean checkInputMatch(int index, @Cast("sd::DataType") int dataType); - public native @Cast("bool") boolean checkOutputMatch(int index, @Cast("sd::DataType") int dataType); - public native @Cast("bool") boolean isSameMode(); - - public native @Cast("bool") boolean isInherit(int index); - } - +// #include +// #include +// #include +/** + * This class is very basic info holder for ops. bean/pojo pretty much. + * + */ +@Namespace("sd::ops") @NoOffset public static class OpDescriptor extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public OpDescriptor(Pointer p) { super(p); } + + // default constructor + public OpDescriptor(int numInputs, int numOutputs, @StdString BytePointer opName, + @Cast("bool") boolean allowsInplace) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace); } + private native void allocate(int numInputs, int numOutputs, @StdString BytePointer opName, + @Cast("bool") boolean allowsInplace); + public OpDescriptor(int numInputs, int numOutputs, @StdString String opName, + @Cast("bool") boolean allowsInplace) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace); } + private native void allocate(int numInputs, int numOutputs, @StdString String opName, + @Cast("bool") boolean allowsInplace); + + // constructor for boolean ops + public OpDescriptor(int numInputs, @StdString BytePointer opName, @Cast("bool") boolean isScalar) { super((Pointer)null); allocate(numInputs, opName, isScalar); } + private native void allocate(int numInputs, @StdString BytePointer opName, @Cast("bool") boolean isScalar); + public OpDescriptor(int numInputs, @StdString String opName, @Cast("bool") boolean isScalar) { super((Pointer)null); allocate(numInputs, opName, isScalar); } + private native void allocate(int numInputs, @StdString String opName, @Cast("bool") boolean isScalar); + + // default constructor + + // constructor for configurable op + public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") String opName, + @Cast("bool") boolean allowsInplace, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs); } + private native void allocate(int numInputs, int numOutputs, @Cast("char*") String opName, + @Cast("bool") boolean allowsInplace, int tArgs, int iArgs); + public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, + @Cast("bool") boolean allowsInplace, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs); } + private native void allocate(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, + @Cast("bool") boolean allowsInplace, int tArgs, int iArgs); + + // constructor for non-configurable divergent op + public OpDescriptor(int numInputs, int numOutputs, @StdString BytePointer opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent); } + private native void allocate(int numInputs, int numOutputs, @StdString BytePointer opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent); + public OpDescriptor(int numInputs, int numOutputs, @StdString String opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent); } + private native void allocate(int numInputs, int numOutputs, @StdString String opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent); + + // constructor for non-configurable divergent op + + // constructor for configurable divergent op + public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") String opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent, tArgs, iArgs); } + private native void allocate(int numInputs, int numOutputs, @Cast("char*") String opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs); + public OpDescriptor(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs) { super((Pointer)null); allocate(numInputs, numOutputs, opName, allowsInplace, divergent, tArgs, iArgs); } + private native void allocate(int numInputs, int numOutputs, @Cast("char*") BytePointer opName, + @Cast("bool") boolean allowsInplace, @Cast("bool") boolean divergent, int tArgs, int iArgs); + + // constructor for logical ops (while, scope, etc) + public OpDescriptor(@Cast("char*") String opName, @Cast("bool") boolean isLogic) { super((Pointer)null); allocate(opName, isLogic); } + private native void allocate(@Cast("char*") String opName, @Cast("bool") boolean isLogic); + public OpDescriptor(@Cast("char*") BytePointer opName, @Cast("bool") boolean isLogic) { super((Pointer)null); allocate(opName, isLogic); } + private native void allocate(@Cast("char*") BytePointer opName, @Cast("bool") boolean isLogic); + + public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef OpDescriptor other); + + // default destructor + + // this method returns minimal expected number of T arguments + public native int getNumberOfTArgs(); + + // this method returns minimal expected number of Integer arguments + public native int getNumberOfIArgs(); + + // this method returns minimal expected number of inputs + public native int getNumberOfInputs(); + + // this method returns hash code for this operation + public native @Cast("Nd4jLong") long getHash(); + + // this method returns minimal expected number of outputs + public native int getNumberOfOutputs(); + + // this method returns opName (can be empty) + public native @StdString @Cast({"char*", "std::string*"}) BytePointer getOpName(); + + // returns TRUE if this op is divergent. FALSE otherwise + public native @Cast("bool") boolean isDivergent(); + + // returns TRUE if this op allows in-place execution + public native @Cast("bool") boolean allowsInplace(); + + // this method allows you to enable/disable inplace call for a given op + public native void allowInplace(@Cast("bool") boolean reallyAllow); + + // this method returns opNum (applicable for legacy XYZ ops only) + public native int getOpNum(); + + // this method allows to set specifc opNum + public native void setOpNum(int opNum); + + public native void setHash(@Cast("Nd4jLong") long hash); + + public native @Cast("sd::ops::InputType") int inputType(); + + public native OpDescriptor setInputType(@Cast("sd::ops::InputType") int type); + public native OpDescriptor setAllowedInputTypes(int index, + @Cast("sd::DataType*") @StdVector IntPointer dtype); + public native OpDescriptor setAllowedInputTypes(int index, + @Cast("sd::DataType*") @StdVector IntBuffer dtype); + public native OpDescriptor setAllowedInputTypes(int index, + @Cast("sd::DataType*") @StdVector int[] dtype); + public native OpDescriptor setAllowedOutputTypes(int index, + @Cast("sd::DataType*") @StdVector IntPointer dtype); + public native OpDescriptor setAllowedOutputTypes(int index, + @Cast("sd::DataType*") @StdVector IntBuffer dtype); + public native OpDescriptor setAllowedOutputTypes(int index, + @Cast("sd::DataType*") @StdVector int[] dtype); + public native OpDescriptor setAllowedInputTypes(int index, @Cast("sd::DataType") int dtype); + public native OpDescriptor setAllowedOutputTypes(int index, @Cast("sd::DataType") int dtype); + public native OpDescriptor setAllowedInputTypes(@Cast("sd::DataType") int dtype); + public native OpDescriptor setAllowedOutputTypes(@Cast("sd::DataType") int dtype); + public native OpDescriptor allowOverride(@Cast("bool") boolean reallyAllow); + public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame); + public native OpDescriptor setInputType(int idx, @Cast("sd::DataType") int dtype); + public native OpDescriptor setOutputType(int idx, @Cast("sd::DataType") int dtype); + + public native @Cast("sd::DataType*") @StdVector IntPointer getOutputTypesForOutput(int index); + + public native @Cast("bool") boolean checkInputMatch(int index, @Cast("sd::DataType") int dataType); + public native @Cast("bool") boolean checkOutputMatch(int index, @Cast("sd::DataType") int dataType); + public native @Cast("bool") boolean isSameMode(); + + public native @Cast("bool") boolean isInherit(int index); +} + // namespace ops + // namespace sd -// #endif //LIBND4J_OPDESCRIPTOR_H +// #endif // LIBND4J_OPDESCRIPTOR_H // Parsed from ops/declarable/PlatformHelper.h @@ -11946,65 +14084,68 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef SD_PLATFORMHELPER_H // #define SD_PLATFORMHELPER_H -// #include // #include // #include -// #include -// #include +// #include // #include - /** - * This abstract class defines methods used by platform-specific helpers implementations - */ - @Namespace("sd::ops::platforms") @NoOffset public static class PlatformHelper extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public PlatformHelper(Pointer p) { super(p); } - - - public native @StdString BytePointer name(); - - public native @Cast("samediff::Engine") int engine(); - - public native @Cast("Nd4jLong") long hash(); - - /** - * This method checks, if given helper can be used with given input/output/configuration options - * - * @param context - * @return - */ - public native @Cast("bool") boolean isUsable(@ByRef Context context); - - /** - * This method invokes helper. Typically this method replaces actual op execution - * - * @param context - * @return - */ - public native @Cast("Nd4jStatus") int invokeHelper(@ByRef Context context); - - /** - * Helper method, needed for compatibility with DeclarableOp macros - * @param ctx - * @param inputId - * @return - */ - public native NDArray getZ(@ByRef Context ctx, int inputId); - - /** - * Helper method, needed for compatibility with DeclarableOp macros - * @param ctx - * @param inputId - * @return - */ - public native NDArray getNullifiedZ(@ByRef Context ctx, int inputId); - } - - +// #include + +// #include +/** + * This abstract class defines methods used by platform-specific helpers + * implementations + */ +@Namespace("sd::ops::platforms") @NoOffset public static class PlatformHelper extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public PlatformHelper(Pointer p) { super(p); } + + + public native @StdString BytePointer name(); + + public native @Cast("samediff::Engine") int engine(); + + public native @Cast("Nd4jLong") long hash(); + + /** + * This method checks, if given helper can be used with given + * input/output/configuration options + * + * @param context + * @return + */ + public native @Cast("bool") boolean isUsable(@ByRef Context context); + + /** + * This method invokes helper. Typically this method replaces actual op + * execution + * + * @param context + * @return + */ + public native @Cast("Nd4jStatus") int invokeHelper(@ByRef Context context); + /** + * Helper method, needed for compatibility with DeclarableOp macros + * @param ctx + * @param inputId + * @return + */ + public native NDArray getZ(@ByRef Context ctx, int inputId); + /** + * Helper method, needed for compatibility with DeclarableOp macros + * @param ctx + * @param inputId + * @return + */ + public native NDArray getNullifiedZ(@ByRef Context ctx, int inputId); +} + // namespace platforms + // namespace ops + // namespace sd -// #endif //SD_PLATFORMHELPER_H +// #endif // SD_PLATFORMHELPER_H // Parsed from ops/declarable/BroadcastableOp.h @@ -12033,22 +14174,23 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_BROADCASTABLEOP_H // #include -// #include "OpDescriptor.h" -// #include "DeclarableOp.h" -// #include "DeclarableCustomOp.h" - @Namespace("sd::ops") public static class BroadcastableOp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public BroadcastableOp(Pointer p) { super(p); } - - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - +// #include "DeclarableCustomOp.h" +// #include "DeclarableOp.h" +// #include "OpDescriptor.h" +@Namespace("sd::ops") public static class BroadcastableOp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public BroadcastableOp(Pointer p) { super(p); } + public native ShapeList calculateOutputShape(ShapeList inputShape, + @ByRef Context block); +} + // namespace ops + // namespace sd -// #endif //LIBND4J_BROADCASTABLEOP_H +// #endif // LIBND4J_BROADCASTABLEOP_H // Parsed from ops/declarable/BroadcastableBoolOp.h @@ -12077,22 +14219,23 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define SD_BROADCASTABLEBOOLOP_H // #include -// #include "OpDescriptor.h" -// #include "DeclarableOp.h" -// #include "DeclarableCustomOp.h" - @Namespace("sd::ops") public static class BroadcastableBoolOp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public BroadcastableBoolOp(Pointer p) { super(p); } - - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - +// #include "DeclarableCustomOp.h" +// #include "DeclarableOp.h" +// #include "OpDescriptor.h" +@Namespace("sd::ops") public static class BroadcastableBoolOp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public BroadcastableBoolOp(Pointer p) { super(p); } + public native ShapeList calculateOutputShape(ShapeList inputShape, + @ByRef Context block); +} + // namespace ops + // namespace sd -// #endif //SD_BROADCASTABLEBOOLOP_H +// #endif // SD_BROADCASTABLEBOOLOP_H // Parsed from ops/declarable/DeclarableOp.h @@ -12120,156 +14263,290 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_DECLARABLE_OPS_H // #define LIBND4J_DECLARABLE_OPS_H -// #include -// #include -// #include // #include -// #include -// #include "OpDescriptor.h" -// #include -// #include // #include +// #include +// #include // #include -// #include +// #include // #include +// #include +// #include +// #include + +// #include + +// #include "OpDescriptor.h" //#include // #include // #include // #include - @Namespace("sd::ops") public static native @Cast("Nd4jStatus") int conditionHelper(@Cast("char*") String file, int line, int condition, int argNumber, @Cast("char*") String format); - @Namespace("sd::ops") public static native @Cast("Nd4jStatus") int conditionHelper(@Cast("char*") BytePointer file, int line, int condition, int argNumber, @Cast("char*") BytePointer format); +@Namespace("sd::ops") public static native @Cast("Nd4jStatus") int conditionHelper(@Cast("char*") String file, int line, int condition, + int argNumber, @Cast("char*") String format); +@Namespace("sd::ops") public static native @Cast("Nd4jStatus") int conditionHelper(@Cast("char*") BytePointer file, int line, int condition, + int argNumber, @Cast("char*") BytePointer format); - /** - * This class is the basic building block of Graph Operations. Any CustomOp out there is built on top of this "abstract" class. - * - */ - @Namespace("sd::ops") @NoOffset public static class DeclarableOp extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public DeclarableOp(Pointer p) { super(p); } - - // for special cases, like BooleanOps - - // regular constructors - - // for LogicalOps - - // default testructor - - // this method returns OpDescriptor, describing this Op instance - public native OpDescriptor getOpDescriptor(); - - public native @Cast("Nd4jStatus") int validateDataTypes(@ByRef Context block); - - /** - * This method should be available in each implemented Op, and should return Op output shape(s), for a given input shape(s) - */ - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - - /** - * Returns opName - * - * @return - */ - public native @StdString @Cast({"char*", "std::string*"}) BytePointer getOpName(); - - /** - * Returns opHash - */ - public native @Cast("Nd4jLong") long getOpHash(); - - /** - * This method sets arguments for op - */ -// void setArguments(); - - /** - * This method returns pointer to results - */ -// void getResults(); - - /** - * This method executes given Op - * - * @param block - * @return 0 if OK, error code otherwise - */ - public native @Cast("Nd4jStatus") int execute(Context block); - - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs); - - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs); - - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); - public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs); - - public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/); - public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder); - - - // There methods provide various validation options - public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block); - - // this method checks if all input arrays have equal lengths - public native @Cast("Nd4jStatus") int validateInputLengthMatch(@ByRef Context block); - - // this method checks if all input arrays have the same shapes (orders/strides are NOT checked) - public native @Cast("Nd4jStatus") int validateInputDimensionsMatch(@ByRef Context block); - - // this method check if all input arrays have the same orders - public native @Cast("Nd4jStatus") int validateOrdersMatch(@ByRef Context block); - - // this method checks if all input arrays are 2D - public native @Cast("Nd4jStatus") int validateInput2D(@ByRef Context block); - - // this method checks if all input arrays are 3D - public native @Cast("Nd4jStatus") int validateInput3D(@ByRef Context block); - - // this method checks if all input arrays are 4D - public native @Cast("Nd4jStatus") int validateInput4D(@ByRef Context block); - - // this method checks if all input arrays are ND - public native @Cast("Nd4jStatus") int validateInputDimensions(@ByRef Context block, int rank); - - // this method checks if number of available arguments matches op expectations - public native @Cast("Nd4jStatus") int validateArguments(@ByRef Context block); - } - +/** + * This class is the basic building block of Graph Operations. Any CustomOp out + * there is built on top of this "abstract" class. + * + */ +@Namespace("sd::ops") @NoOffset public static class DeclarableOp extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DeclarableOp(Pointer p) { super(p); } + + // for special cases, like BooleanOps + + // regular constructors + // for LogicalOps -// #endif //LIBND4J_DECLARABLE_OPS_H + // default testructor + + // this method returns OpDescriptor, describing this Op instance + public native OpDescriptor getOpDescriptor(); + + public native @Cast("Nd4jStatus") int validateDataTypes(@ByRef Context block); + + /** + * This method should be available in each implemented Op, and should return + * Op output shape(s), for a given input shape(s) + */ + public native ShapeList calculateOutputShape(ShapeList inputShape, + @ByRef Context block); + + /** + * Returns opName + * + * @return + */ + public native @StdString BytePointer getOpName(); + + /** + * Returns opHash + */ + public native @Cast("Nd4jLong") long getOpHash(); + + /** + * This method sets arguments for op + */ + // void setArguments(); + + /** + * This method returns pointer to results + */ + // void getResults(); + + /** + * This method executes given Op + * + * @param block + * @return 0 if OK, error code otherwise + */ + public native @Cast("Nd4jStatus") int execute(Context block); + + public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs); + + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @Cast("Nd4jStatus") int execute( + @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, + @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs); + + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, + @Cast("Nd4jLong*") @StdVector long[] iArgs, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, + @Cast("Nd4jLong*") @StdVector long[] iArgs); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, + @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate( + @Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, + @Cast("Nd4jLong*") @StdVector long[] iArgs, + @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, + @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/); + + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, + @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, + @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, + @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, + @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, + @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs, + @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, + @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, + @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, + @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, + @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, + @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, + @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); + public native @Cast("Nd4jStatus") int execute( + @ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, + @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, + @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs); + + public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder); + + // There methods provide various validation options + public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block); + + // this method checks if all input arrays have equal lengths + public native @Cast("Nd4jStatus") int validateInputLengthMatch(@ByRef Context block); + + // this method checks if all input arrays have the same shapes (orders/strides + // are NOT checked) + public native @Cast("Nd4jStatus") int validateInputDimensionsMatch(@ByRef Context block); + + // this method check if all input arrays have the same orders + public native @Cast("Nd4jStatus") int validateOrdersMatch(@ByRef Context block); + + // this method checks if all input arrays are 2D + public native @Cast("Nd4jStatus") int validateInput2D(@ByRef Context block); + + // this method checks if all input arrays are 3D + public native @Cast("Nd4jStatus") int validateInput3D(@ByRef Context block); + + // this method checks if all input arrays are 4D + public native @Cast("Nd4jStatus") int validateInput4D(@ByRef Context block); + + // this method checks if all input arrays are ND + public native @Cast("Nd4jStatus") int validateInputDimensions(@ByRef Context block, int rank); + + // this method checks if number of available arguments matches op expectations + public native @Cast("Nd4jStatus") int validateArguments(@ByRef Context block); + + /** + * This method pre-allocates NDArrays for Op output, in case they are not + * available at op execution time + */ + public native int prepareOutputs(@ByRef Context block); +} + // namespace ops + // namespace sd + +// #endif // LIBND4J_DECLARABLE_OPS_H // Parsed from ops/declarable/DeclarableListOp.h @@ -12299,24 +14576,36 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include // #include -// #include // #include - @Namespace("sd::ops") public static class DeclarableListOp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public DeclarableListOp(Pointer p) { super(p); } - - - - public native @Cast("Nd4jStatus") int execute(Context block); - public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs); - public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs); - public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs); - - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - - +// #include +@Namespace("sd::ops") public static class DeclarableListOp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DeclarableListOp(Pointer p) { super(p); } + + + public native @Cast("Nd4jStatus") int execute(Context block); + + public native @ByVal ResultSet execute(@Const @ByRef NDArrayList list, + @Const @ByRef NDArrayVector inputs, + @StdVector DoublePointer tArgs/*={}*/, + @StdVector IntPointer iArgs/*={}*/); + public native @ByVal ResultSet execute(@Const @ByRef NDArrayList list, + @Const @ByRef NDArrayVector inputs); + public native @ByVal ResultSet execute(@Const @ByRef NDArrayList list, + @Const @ByRef NDArrayVector inputs, + @StdVector DoubleBuffer tArgs/*={}*/, + @StdVector IntBuffer iArgs/*={}*/); + public native @ByVal ResultSet execute(@Const @ByRef NDArrayList list, + @Const @ByRef NDArrayVector inputs, + @StdVector double[] tArgs/*={}*/, + @StdVector int[] iArgs/*={}*/); + + public native ShapeList calculateOutputShape(ShapeList inputShape, + @ByRef Context block); +} + // namespace ops + // namespace sd // #endif @@ -12346,18 +14635,19 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_DECLARABLE_REDUCTION_OP_H // #include - @Namespace("sd::ops") public static class DeclarableReductionOp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public DeclarableReductionOp(Pointer p) { super(p); } - +@Namespace("sd::ops") public static class DeclarableReductionOp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DeclarableReductionOp(Pointer p) { super(p); } - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - + public native ShapeList calculateOutputShape(ShapeList inputShape, + @ByRef Context block); +} + // namespace ops + // namespace sd -// #endif //LIBND4J_DECLARABLE_REDUCTION_OP_H +// #endif // LIBND4J_DECLARABLE_REDUCTION_OP_H // Parsed from ops/declarable/DeclarableCustomOp.h @@ -12386,18 +14676,19 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_DECLARABLECUSTOMOP_H // #include - @Namespace("sd::ops") public static class DeclarableCustomOp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public DeclarableCustomOp(Pointer p) { super(p); } - +@Namespace("sd::ops") public static class DeclarableCustomOp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DeclarableCustomOp(Pointer p) { super(p); } - public native ShapeList calculateOutputShape(ShapeList inputShapes, @ByRef Context block); - } - + public native ShapeList calculateOutputShape(ShapeList inputShapes, + @ByRef Context block); +} + // namespace ops + // namespace sd -// #endif //LIBND4J_DECLARABLECUSTOMOP_H +// #endif // LIBND4J_DECLARABLECUSTOMOP_H // Parsed from ops/declarable/BooleanOp.h @@ -12426,27 +14717,27 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_BOOLEANOP_H // #include -// #include "OpDescriptor.h" -// #include "DeclarableOp.h" - @Namespace("sd::ops") @NoOffset public static class BooleanOp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public BooleanOp(Pointer p) { super(p); } - - - public native @Cast("bool") boolean verify(@Const @ByRef NDArrayVector args); - public native @Cast("bool") boolean verify(@ByRef Context block); - public native @Cast("Nd4jStatus") int execute(Context block); +// #include "DeclarableOp.h" +// #include "OpDescriptor.h" +@Namespace("sd::ops") @NoOffset public static class BooleanOp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public BooleanOp(Pointer p) { super(p); } - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - + public native @Cast("bool") boolean verify(@Const @ByRef NDArrayVector args); + public native @Cast("bool") boolean verify(@ByRef Context block); + public native @Cast("Nd4jStatus") int execute(Context block); + public native ShapeList calculateOutputShape(ShapeList inputShape, + @ByRef Context block); +} + // namespace ops + // namespace sd -// #endif //LIBND4J_BOOLEANOP_H +// #endif // LIBND4J_BOOLEANOP_H // Parsed from ops/declarable/LogicOp.h @@ -12475,29 +14766,31 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include "DeclarableOp.h" - /** - * Logic ops are unique snowflakes in any Graph. They dramatically change Graph Execution process, by introducing loops, conditions, etc. - * - * Their code is the part of GraphExecutioner logic. But we still want them to be expressed via Graph - * \tparam T - */ - @Namespace("sd::ops") public static class LogicOp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public LogicOp(Pointer p) { super(p); } - - public LogicOp(@Cast("char*") String name) { super((Pointer)null); allocate(name); } - private native void allocate(@Cast("char*") String name); - public LogicOp(@Cast("char*") BytePointer name) { super((Pointer)null); allocate(name); } - private native void allocate(@Cast("char*") BytePointer name); - - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - +/** + * Logic ops are unique snowflakes in any Graph. They dramatically change Graph + * Execution process, by introducing loops, conditions, etc. + * + * Their code is the part of GraphExecutioner logic. But we still want them to + * be expressed via Graph + * \tparam T + */ +@Namespace("sd::ops") public static class LogicOp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public LogicOp(Pointer p) { super(p); } + public LogicOp(@Cast("char*") String name) { super((Pointer)null); allocate(name); } + private native void allocate(@Cast("char*") String name); + public LogicOp(@Cast("char*") BytePointer name) { super((Pointer)null); allocate(name); } + private native void allocate(@Cast("char*") BytePointer name); + public native ShapeList calculateOutputShape(ShapeList inputShape, + @ByRef Context block); +} + // namespace ops + // namespace sd -// #endif //LIBND4J_LOGICOP_H +// #endif // LIBND4J_LOGICOP_H // Parsed from ops/declarable/OpRegistrator.h @@ -12525,76 +14818,84 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_OPREGISTRATOR_H // #define LIBND4J_OPREGISTRATOR_H -// #include -// #include -// #include -// #include +// #include // #include // #include -// #include +// #include + +// #include +// #include +// #include // handlers part -// #include // #include +// #include // #ifndef __JAVACPP_HACK__ // #endif - /** - * This class provides runtime ops lookup, based on opName or opHash. - * To build lookup directory we use *_OP_IMPL macro, which puts static structs at compile time in .cpp files, - * so once binary is executed, static objects are initialized automatically, and we get list of all ops - * available at runtime via this singleton. - * - */ - @Namespace("sd::ops") @NoOffset public static class OpRegistrator extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public OpRegistrator(Pointer p) { super(p); } - +/** + * This class provides runtime ops lookup, based on opName or opHash. + * To build lookup directory we use *_OP_IMPL macro, which puts static structs + * at compile time in .cpp files, so once binary is executed, static objects are + * initialized automatically, and we get list of all ops available at runtime + * via this singleton. + * + */ +@Namespace("sd::ops") @NoOffset public static class OpRegistrator extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public OpRegistrator(Pointer p) { super(p); } - public static native @ByRef OpRegistrator getInstance(); - public static native void exitHandler(); - public static native void sigIntHandler(int sig); - public static native void sigSegVHandler(int sig); + public static native @ByRef OpRegistrator getInstance(); - - public native @Cast("char*") String getAllCustomOperations(); + public static native void exitHandler(); + public static native void sigIntHandler(int sig); + public static native void sigSegVHandler(int sig); - /** - * This method registers operation in our registry, so we can use them later - * - * @param op - */ - public native @Cast("bool") boolean registerOperation(@Cast("char*") String name, DeclarableOp op); - public native @Cast("bool") boolean registerOperation(@Cast("char*") BytePointer name, DeclarableOp op); - public native @Cast("bool") boolean registerOperation(DeclarableOp op); + + public native @Cast("char*") String getAllCustomOperations(); - public native void registerHelper(PlatformHelper op); + /** + * This method registers operation in our registry, so we can use them later + * + * @param op + */ + public native @Cast("bool") boolean registerOperation(@StdString BytePointer opName, + @SharedPtr DeclarableOp op); + public native @Cast("bool") boolean registerOperation(@StdString String opName, + @SharedPtr DeclarableOp op); + public native @Cast("bool") boolean registerOperation(@SharedPtr DeclarableOp op); - public native @Cast("bool") boolean hasHelper(@Cast("Nd4jLong") long hash, @Cast("samediff::Engine") int engine); + public native void registerHelper(PlatformHelper op); - public native DeclarableOp getOperation(@Cast("char*") String name); - public native DeclarableOp getOperation(@Cast("char*") BytePointer name); - public native DeclarableOp getOperation(@Cast("Nd4jLong") long hash); + public native @Cast("bool") boolean hasHelper(@Cast("Nd4jLong") long hash, @Cast("samediff::Engine") int engine); - public native PlatformHelper getPlatformHelper(@Cast("Nd4jLong") long hash, @Cast("samediff::Engine") int engine); + public native @SharedPtr DeclarableOp getOperation(@Cast("Nd4jLong") long hash); + public native @SharedPtr DeclarableOp getOperation(@StdString BytePointer name); + public native @SharedPtr DeclarableOp getOperation(@StdString String name); - public native @Cast("Nd4jLong*") @StdVector LongPointer getAllHashes(); + public native @Cast("bool") boolean hasOperation(@StdString BytePointer opName); + public native @Cast("bool") boolean hasOperation(@StdString String opName); + public native @Cast("bool") boolean hasOperation(@Cast("const Nd4jLong") long opName); - public native int numberOfOperations(); - } + public native PlatformHelper getPlatformHelper( + @Cast("Nd4jLong") long hash, @Cast("samediff::Engine") int engine); + public native @Cast("Nd4jLong*") @StdVector LongPointer getAllHashes(); - /* - * These structs are used to "register" our ops in OpRegistrator. - */ + public native int numberOfOperations(); +} - +/* + * These structs are used to "register" our ops in OpRegistrator. + */ + // namespace ops + // namespace sd -// #endif //LIBND4J_OPREGISTRATOR_H +// #endif // LIBND4J_OPREGISTRATOR_H // Parsed from ops/declarable/CustomOperations.h @@ -12622,156 +14923,156 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_CUSTOMOPERATIONS_H // #define LIBND4J_CUSTOMOPERATIONS_H +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include // #include +// #include +// #include // #include // #include -// #include +// #include // #include +// #include +// #include +// #include +// #include // #include -// #include -// #include -// #include -// #include -// #include +// #include // #include // #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include +// #include +// #include +// #include +// #include // #include -// #include -// #include -// #include -// #include +// #include +// #include +// #include // #include +// #include // #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include -// #include - @Namespace("sd") public static class _loader extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public _loader(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public _loader(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public _loader position(long position) { - return (_loader)super.position(position); - } - - public _loader() { super((Pointer)null); allocate(); } - private native void allocate(); +@Namespace("sd") public static class _loader extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public _loader(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public _loader(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public _loader position(long position) { + return (_loader)super.position(position); + } + + public _loader() { super((Pointer)null); allocate(); } + private native void allocate(); +} + +// logic ops +@Namespace("sd::ops") public static class Switch extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Switch(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Switch(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Switch position(long position) { + return (Switch)super.position(position); } - // logic ops - @Namespace("sd::ops") public static class Switch extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Switch(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Switch(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Switch position(long position) { - return (Switch)super.position(position); - } - public Switch() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class While extends LogicOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public While(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public While(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public While position(long position) { - return (While)super.position(position); - } - +@Namespace("sd::ops") public static class While extends LogicOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public While(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public While(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public While position(long position) { + return (While)super.position(position); + } + public While() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class Scope extends LogicOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Scope(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Scope(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Scope position(long position) { - return (Scope)super.position(position); - } - +@Namespace("sd::ops") public static class Scope extends LogicOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Scope(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Scope(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Scope position(long position) { + return (Scope)super.position(position); + } + public Scope() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class Conditional extends LogicOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Conditional(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Conditional(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Conditional position(long position) { - return (Conditional)super.position(position); - } - +@Namespace("sd::ops") public static class Conditional extends LogicOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Conditional(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Conditional(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Conditional position(long position) { + return (Conditional)super.position(position); + } + public Conditional() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class Return extends LogicOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Return(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Return(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Return position(long position) { - return (Return)super.position(position); - } - +@Namespace("sd::ops") public static class Return extends LogicOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Return(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Return(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Return position(long position) { + return (Return)super.position(position); + } + public Return() { super((Pointer)null); allocate(); } private native void allocate(); } +/** + * This operations exposes given arguments as it's own outputs, but does it only + * once. Subsequent calls will be served directly by this op. + * + * PLEASE NOTE: This operation is internal graph operation, and shouldn't be + * used directly usually. + */ +@Namespace("sd::ops") public static class expose extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public expose(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public expose(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public expose position(long position) { + return (expose)super.position(position); + } - /** - * This operations exposes given arguments as it's own outputs, but does it only once. - * Subsequent calls will be served directly by this op. - * - * PLEASE NOTE: This operation is internal graph operation, and shouldn't be used directly usually. - */ - @Namespace("sd::ops") public static class expose extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public expose(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public expose(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public expose position(long position) { - return (expose)super.position(position); - } - public expose() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - + // namespace ops + // namespace sd - -// #endif //LIBND4J_CUSTOMOPERATIONS_H +// #endif // LIBND4J_CUSTOMOPERATIONS_H // Parsed from ops/declarable/headers/activations.h @@ -12799,696 +15100,693 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_HEADERS_ACTIVATIONS_H // #define LIBND4J_HEADERS_ACTIVATIONS_H - // #include - /** - * This is Sigmoid activation function implementation - * Math is: 1 / 1 + exp(-x) - */ -// #if NOT_EXCLUDED(OP_sigmoid) - @Namespace("sd::ops") public static class sigmoid extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sigmoid(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sigmoid(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sigmoid position(long position) { - return (sigmoid)super.position(position); - } - +/** + * This is Sigmoid activation function implementation + * Math is: 1 / 1 + exp(-x) + */ +// #if NOT_EXCLUDED(OP_sigmoid) +@Namespace("sd::ops") public static class sigmoid extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sigmoid(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sigmoid(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sigmoid position(long position) { + return (sigmoid)super.position(position); + } + public sigmoid() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class sigmoid_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sigmoid_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sigmoid_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sigmoid_bp position(long position) { - return (sigmoid_bp)super.position(position); - } - +@Namespace("sd::ops") public static class sigmoid_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sigmoid_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sigmoid_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sigmoid_bp position(long position) { + return (sigmoid_bp)super.position(position); + } + public sigmoid_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Softsign activation function implementation + * Math is: x / 1 + abs(x) + */ +// #if NOT_EXCLUDED(OP_softsign) +@Namespace("sd::ops") public static class softsign extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softsign(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softsign(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softsign position(long position) { + return (softsign)super.position(position); + } - /** - * This is Softsign activation function implementation - * Math is: x / 1 + abs(x) - */ -// #if NOT_EXCLUDED(OP_softsign) - @Namespace("sd::ops") public static class softsign extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softsign(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softsign(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softsign position(long position) { - return (softsign)super.position(position); - } - public softsign() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class softsign_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softsign_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softsign_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softsign_bp position(long position) { - return (softsign_bp)super.position(position); - } - +@Namespace("sd::ops") public static class softsign_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softsign_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softsign_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softsign_bp position(long position) { + return (softsign_bp)super.position(position); + } + public softsign_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Tanh activation function implementation + */ +// #if NOT_EXCLUDED(OP_tanh) +@Namespace("sd::ops") public static class tanh extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tanh(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tanh(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tanh position(long position) { + return (tanh)super.position(position); + } - /** - * This is Tanh activation function implementation - */ -// #if NOT_EXCLUDED(OP_tanh) - @Namespace("sd::ops") public static class tanh extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tanh(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tanh(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tanh position(long position) { - return (tanh)super.position(position); - } - public tanh() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class tanh_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tanh_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tanh_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tanh_bp position(long position) { - return (tanh_bp)super.position(position); - } - +@Namespace("sd::ops") public static class tanh_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tanh_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tanh_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tanh_bp position(long position) { + return (tanh_bp)super.position(position); + } + public tanh_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Softplus activation function implementation + * Math is: log(1 + exp(x)) + */ +// #if NOT_EXCLUDED(OP_softplus) +@Namespace("sd::ops") public static class softplus extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softplus(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softplus(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softplus position(long position) { + return (softplus)super.position(position); + } - /** - * This is Softplus activation function implementation - * Math is: log(1 + exp(x)) - */ -// #if NOT_EXCLUDED(OP_softplus) - @Namespace("sd::ops") public static class softplus extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softplus(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softplus(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softplus position(long position) { - return (softplus)super.position(position); - } - public softplus() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class softplus_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softplus_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softplus_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softplus_bp position(long position) { - return (softplus_bp)super.position(position); - } - +@Namespace("sd::ops") public static class softplus_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softplus_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softplus_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softplus_bp position(long position) { + return (softplus_bp)super.position(position); + } + public softplus_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is RELU activation function implementation + */ +// #if NOT_EXCLUDED(OP_relu) +@Namespace("sd::ops") public static class relu extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public relu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public relu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public relu position(long position) { + return (relu)super.position(position); + } - /** - * This is RELU activation function implementation - */ -// #if NOT_EXCLUDED(OP_relu) - @Namespace("sd::ops") public static class relu extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public relu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public relu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public relu position(long position) { - return (relu)super.position(position); - } - public relu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class relu_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public relu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public relu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public relu_bp position(long position) { - return (relu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class relu_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public relu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public relu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public relu_bp position(long position) { + return (relu_bp)super.position(position); + } + public relu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is SELU activation function implementation + */ +// #if NOT_EXCLUDED(OP_selu) +@Namespace("sd::ops") public static class selu extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public selu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public selu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public selu position(long position) { + return (selu)super.position(position); + } - /** - * This is SELU activation function implementation - */ -// #if NOT_EXCLUDED(OP_selu) - @Namespace("sd::ops") public static class selu extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public selu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public selu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public selu position(long position) { - return (selu)super.position(position); - } - public selu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class selu_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public selu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public selu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public selu_bp position(long position) { - return (selu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class selu_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public selu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public selu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public selu_bp position(long position) { + return (selu_bp)super.position(position); + } + public selu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Leaky RELU activation function. + * Math is: x < 0 ? alpha * x : x; + */ +// #if NOT_EXCLUDED(OP_lrelu) +@Namespace("sd::ops") public static class lrelu extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lrelu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lrelu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lrelu position(long position) { + return (lrelu)super.position(position); + } - /** - * This is Leaky RELU activation function. - * Math is: x < 0 ? alpha * x : x; - */ -// #if NOT_EXCLUDED(OP_lrelu) - @Namespace("sd::ops") public static class lrelu extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lrelu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lrelu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lrelu position(long position) { - return (lrelu)super.position(position); - } - public lrelu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class lrelu_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lrelu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lrelu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lrelu_bp position(long position) { - return (lrelu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class lrelu_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lrelu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lrelu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lrelu_bp position(long position) { + return (lrelu_bp)super.position(position); + } + public lrelu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op is ELU activation function. + * Math is: x >= 0 ? x : exp(x) - 1; + */ +// #if NOT_EXCLUDED(OP_elu) +@Namespace("sd::ops") public static class elu extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public elu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public elu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public elu position(long position) { + return (elu)super.position(position); + } - /** - * This op is ELU activation function. - * Math is: x >= 0 ? x : exp(x) - 1; - */ -// #if NOT_EXCLUDED(OP_elu) - @Namespace("sd::ops") public static class elu extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public elu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public elu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public elu position(long position) { - return (elu)super.position(position); - } - public elu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class elu_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public elu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public elu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public elu_bp position(long position) { - return (elu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class elu_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public elu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public elu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public elu_bp position(long position) { + return (elu_bp)super.position(position); + } + public elu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Cube activation function. + * Math is: x^3 + */ +// #if NOT_EXCLUDED(OP_cube) +@Namespace("sd::ops") public static class cube extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cube(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cube(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cube position(long position) { + return (cube)super.position(position); + } - /** - * This is Cube activation function. - * Math is: x^3 - */ -// #if NOT_EXCLUDED(OP_cube) - @Namespace("sd::ops") public static class cube extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cube(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cube(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cube position(long position) { - return (cube)super.position(position); - } - public cube() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class cube_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cube_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cube_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cube_bp position(long position) { - return (cube_bp)super.position(position); - } - +@Namespace("sd::ops") public static class cube_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cube_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cube_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cube_bp position(long position) { + return (cube_bp)super.position(position); + } + public cube_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is RectifiedTanh activation function. + * Math is: max(0, tanh(x)) + */ +// #if NOT_EXCLUDED(OP_rectifiedtanh) +@Namespace("sd::ops") public static class rectifiedtanh extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public rectifiedtanh(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public rectifiedtanh(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public rectifiedtanh position(long position) { + return (rectifiedtanh)super.position(position); + } - /** - * This is RectifiedTanh activation function. - * Math is: max(0, tanh(x)) - */ -// #if NOT_EXCLUDED(OP_rectifiedtanh) - @Namespace("sd::ops") public static class rectifiedtanh extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public rectifiedtanh(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public rectifiedtanh(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public rectifiedtanh position(long position) { - return (rectifiedtanh)super.position(position); - } - public rectifiedtanh() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class rectifiedtanh_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public rectifiedtanh_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public rectifiedtanh_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public rectifiedtanh_bp position(long position) { - return (rectifiedtanh_bp)super.position(position); - } - +@Namespace("sd::ops") public static class rectifiedtanh_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public rectifiedtanh_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public rectifiedtanh_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public rectifiedtanh_bp position(long position) { + return (rectifiedtanh_bp)super.position(position); + } + public rectifiedtanh_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is RationalTanh activation function. + */ +// #if NOT_EXCLUDED(OP_rationaltanh) +@Namespace("sd::ops") public static class rationaltanh extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public rationaltanh(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public rationaltanh(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public rationaltanh position(long position) { + return (rationaltanh)super.position(position); + } - /** - * This is RationalTanh activation function. - */ -// #if NOT_EXCLUDED(OP_rationaltanh) - @Namespace("sd::ops") public static class rationaltanh extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public rationaltanh(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public rationaltanh(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public rationaltanh position(long position) { - return (rationaltanh)super.position(position); - } - public rationaltanh() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class rationaltanh_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public rationaltanh_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public rationaltanh_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public rationaltanh_bp position(long position) { - return (rationaltanh_bp)super.position(position); - } - +@Namespace("sd::ops") public static class rationaltanh_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public rationaltanh_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public rationaltanh_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public rationaltanh_bp position(long position) { + return (rationaltanh_bp)super.position(position); + } + public rationaltanh_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is HardTanh activation function. + * Math is: x < -1.0 ? -1.0 : x > 1.0 ? 1.0 : x; + */ +// #if NOT_EXCLUDED(OP_hardtanh) +@Namespace("sd::ops") public static class hardtanh extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public hardtanh(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public hardtanh(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public hardtanh position(long position) { + return (hardtanh)super.position(position); + } - /** - * This is HardTanh activation function. - * Math is: x < -1.0 ? -1.0 : x > 1.0 ? 1.0 : x; - */ -// #if NOT_EXCLUDED(OP_hardtanh) - @Namespace("sd::ops") public static class hardtanh extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public hardtanh(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public hardtanh(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public hardtanh position(long position) { - return (hardtanh)super.position(position); - } - public hardtanh() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class hardtanh_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public hardtanh_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public hardtanh_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public hardtanh_bp position(long position) { - return (hardtanh_bp)super.position(position); - } - +@Namespace("sd::ops") public static class hardtanh_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public hardtanh_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public hardtanh_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public hardtanh_bp position(long position) { + return (hardtanh_bp)super.position(position); + } + public hardtanh_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is HardSigmoid activation function. + * Math is: min(1, max(0, 0.2 * x + 0.5)) + */ +// #if NOT_EXCLUDED(OP_hardsigmoid) +@Namespace("sd::ops") public static class hardsigmoid extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public hardsigmoid(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public hardsigmoid(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public hardsigmoid position(long position) { + return (hardsigmoid)super.position(position); + } - /** - * This is HardSigmoid activation function. - * Math is: min(1, max(0, 0.2 * x + 0.5)) - */ -// #if NOT_EXCLUDED(OP_hardsigmoid) - @Namespace("sd::ops") public static class hardsigmoid extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public hardsigmoid(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public hardsigmoid(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public hardsigmoid position(long position) { - return (hardsigmoid)super.position(position); - } - public hardsigmoid() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class hardsigmoid_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public hardsigmoid_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public hardsigmoid_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public hardsigmoid_bp position(long position) { - return (hardsigmoid_bp)super.position(position); - } - +@Namespace("sd::ops") public static class hardsigmoid_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public hardsigmoid_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public hardsigmoid_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public hardsigmoid_bp position(long position) { + return (hardsigmoid_bp)super.position(position); + } + public hardsigmoid_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Indentity operation. It passes signal umodified in both directions. + */ +// #if NOT_EXCLUDED(OP_identity) +@Namespace("sd::ops") public static class identity extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public identity(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public identity(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public identity position(long position) { + return (identity)super.position(position); + } - /** - * This is Indentity operation. It passes signal umodified in both directions. - */ -// #if NOT_EXCLUDED(OP_identity) - @Namespace("sd::ops") public static class identity extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public identity(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public identity(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public identity position(long position) { - return (identity)super.position(position); - } - public identity() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class identity_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public identity_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public identity_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public identity_bp position(long position) { - return (identity_bp)super.position(position); - } - +@Namespace("sd::ops") public static class identity_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public identity_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public identity_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public identity_bp position(long position) { + return (identity_bp)super.position(position); + } + public identity_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Indentity operation. It passes signal umodified in both directions. + */ +// #if NOT_EXCLUDED(OP_identity_n) +@Namespace("sd::ops") public static class identity_n extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public identity_n(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public identity_n(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public identity_n position(long position) { + return (identity_n)super.position(position); + } - /** - * This is Indentity operation. It passes signal umodified in both directions. - */ -// #if NOT_EXCLUDED(OP_identity_n) - @Namespace("sd::ops") public static class identity_n extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public identity_n(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public identity_n(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public identity_n position(long position) { - return (identity_n)super.position(position); - } - public identity_n() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is Concatenated RELU implementation. + * What happens inside: RELU(Concat((x, -x, {-1}))) + * + * PLEASE NOTE: Concatenation will double amount of features available in input + */ +// #if NOT_EXCLUDED(OP_crelu) +@Namespace("sd::ops") public static class crelu extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public crelu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public crelu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public crelu position(long position) { + return (crelu)super.position(position); + } - /** - * This is Concatenated RELU implementation. - * What happens inside: RELU(Concat((x, -x, {-1}))) - * - * PLEASE NOTE: Concatenation will double amount of features available in input - */ -// #if NOT_EXCLUDED(OP_crelu) - @Namespace("sd::ops") public static class crelu extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public crelu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public crelu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public crelu position(long position) { - return (crelu)super.position(position); - } - public crelu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class crelu_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public crelu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public crelu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public crelu_bp position(long position) { - return (crelu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class crelu_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public crelu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public crelu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public crelu_bp position(long position) { + return (crelu_bp)super.position(position); + } + public crelu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is RELU6 activation function implementation + */ +// #if NOT_EXCLUDED(OP_relu6) +@Namespace("sd::ops") public static class relu6 extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public relu6(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public relu6(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public relu6 position(long position) { + return (relu6)super.position(position); + } - /** - * This is RELU6 activation function implementation - */ -// #if NOT_EXCLUDED(OP_relu6) - @Namespace("sd::ops") public static class relu6 extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public relu6(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public relu6(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public relu6 position(long position) { - return (relu6)super.position(position); - } - public relu6() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class relu6_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public relu6_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public relu6_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public relu6_bp position(long position) { - return (relu6_bp)super.position(position); - } - +@Namespace("sd::ops") public static class relu6_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public relu6_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public relu6_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public relu6_bp position(long position) { + return (relu6_bp)super.position(position); + } + public relu6_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * Parametric Rectified Linear Unit + * f(x) = alpha * x for x < 0, f(x) = x for x >= 0 + */ +// #if NOT_EXCLUDED(OP_prelu) +@Namespace("sd::ops") public static class prelu extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public prelu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public prelu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public prelu position(long position) { + return (prelu)super.position(position); + } - /** - * Parametric Rectified Linear Unit - * f(x) = alpha * x for x < 0, f(x) = x for x >= 0 - */ -// #if NOT_EXCLUDED(OP_prelu) - @Namespace("sd::ops") public static class prelu extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public prelu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public prelu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public prelu position(long position) { - return (prelu)super.position(position); - } - public prelu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class prelu_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public prelu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public prelu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public prelu_bp position(long position) { - return (prelu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class prelu_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public prelu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public prelu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public prelu_bp position(long position) { + return (prelu_bp)super.position(position); + } + public prelu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Thresholded Rectified Linear Unit + * f(x) = x for x > theta, f(x) = 0 otherwise + * theta must be >= 0 + */ +// #if NOT_EXCLUDED(OP_thresholdedrelu) +@Namespace("sd::ops") public static class thresholdedrelu extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public thresholdedrelu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public thresholdedrelu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public thresholdedrelu position(long position) { + return (thresholdedrelu)super.position(position); + } - /** - * Thresholded Rectified Linear Unit - * f(x) = x for x > theta, f(x) = 0 otherwise - * theta must be >= 0 - */ -// #if NOT_EXCLUDED(OP_thresholdedrelu) - @Namespace("sd::ops") public static class thresholdedrelu extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public thresholdedrelu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public thresholdedrelu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public thresholdedrelu position(long position) { - return (thresholdedrelu)super.position(position); - } - public thresholdedrelu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class thresholdedrelu_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public thresholdedrelu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public thresholdedrelu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public thresholdedrelu_bp position(long position) { - return (thresholdedrelu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class thresholdedrelu_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public thresholdedrelu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public thresholdedrelu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public thresholdedrelu_bp position(long position) { + return (thresholdedrelu_bp)super.position(position); + } + public thresholdedrelu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - - +// #endif + // namespace ops + // namespace sd // #endif @@ -13519,319 +15817,323 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x < y - */ -// #if NOT_EXCLUDED(OP_lt_scalar) - @Namespace("sd::ops") public static class lt_scalar extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lt_scalar(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lt_scalar(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lt_scalar position(long position) { - return (lt_scalar)super.position(position); - } - +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x < y + */ +// #if NOT_EXCLUDED(OP_lt_scalar) +@Namespace("sd::ops") public static class lt_scalar extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lt_scalar(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lt_scalar(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lt_scalar position(long position) { + return (lt_scalar)super.position(position); + } + public lt_scalar() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x > y + */ +// #if NOT_EXCLUDED(OP_gt_scalar) +@Namespace("sd::ops") public static class gt_scalar extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gt_scalar(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gt_scalar(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gt_scalar position(long position) { + return (gt_scalar)super.position(position); + } - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x > y - */ -// #if NOT_EXCLUDED(OP_gt_scalar) - @Namespace("sd::ops") public static class gt_scalar extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gt_scalar(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gt_scalar(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gt_scalar position(long position) { - return (gt_scalar)super.position(position); - } - public gt_scalar() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x <= y + */ +// #if NOT_EXCLUDED(OP_lte_scalar) +@Namespace("sd::ops") public static class lte_scalar extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lte_scalar(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lte_scalar(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lte_scalar position(long position) { + return (lte_scalar)super.position(position); + } - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x <= y - */ -// #if NOT_EXCLUDED(OP_lte_scalar) - @Namespace("sd::ops") public static class lte_scalar extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lte_scalar(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lte_scalar(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lte_scalar position(long position) { - return (lte_scalar)super.position(position); - } - public lte_scalar() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x >= y + */ +// #if NOT_EXCLUDED(OP_gte_scalar) +@Namespace("sd::ops") public static class gte_scalar extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gte_scalar(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gte_scalar(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gte_scalar position(long position) { + return (gte_scalar)super.position(position); + } - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x >= y - */ -// #if NOT_EXCLUDED(OP_gte_scalar) - @Namespace("sd::ops") public static class gte_scalar extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gte_scalar(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gte_scalar(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gte_scalar position(long position) { - return (gte_scalar)super.position(position); - } - public gte_scalar() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if both operands are equal. + */ +// #if NOT_EXCLUDED(OP_eq_scalar) +@Namespace("sd::ops") public static class eq_scalar extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public eq_scalar(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public eq_scalar(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public eq_scalar position(long position) { + return (eq_scalar)super.position(position); + } - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if both operands are equal. - */ -// #if NOT_EXCLUDED(OP_eq_scalar) - @Namespace("sd::ops") public static class eq_scalar extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public eq_scalar(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public eq_scalar(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public eq_scalar position(long position) { - return (eq_scalar)super.position(position); - } - public eq_scalar() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This is scalar boolean op. + * Both operands should be scalars. + * + * Returns true if x != y + */ +// #if NOT_EXCLUDED(OP_neq_scalar) +@Namespace("sd::ops") public static class neq_scalar extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public neq_scalar(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public neq_scalar(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public neq_scalar position(long position) { + return (neq_scalar)super.position(position); + } - /** - * This is scalar boolean op. - * Both operands should be scalars. - * - * Returns true if x != y - */ -// #if NOT_EXCLUDED(OP_neq_scalar) - @Namespace("sd::ops") public static class neq_scalar extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public neq_scalar(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public neq_scalar(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public neq_scalar position(long position) { - return (neq_scalar)super.position(position); - } - public neq_scalar() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 2 n-dimensional arrays as input, and return + * array of the same shape, with elements, either from x or y, depending on the + * condition. + */ +// #if NOT_EXCLUDED(OP_where) +@Namespace("sd::ops") public static class Where extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Where(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Where(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Where position(long position) { + return (Where)super.position(position); + } - /** - * This op takes 2 n-dimensional arrays as input, and return - * array of the same shape, with elements, either from x or y, depending on the condition. - */ -// #if NOT_EXCLUDED(OP_where) - @Namespace("sd::ops") public static class Where extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Where(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Where(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Where position(long position) { - return (Where)super.position(position); - } - public Where() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_where_np) +@Namespace("sd::ops") public static class where_np extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public where_np(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public where_np(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public where_np position(long position) { + return (where_np)super.position(position); + } -// #if NOT_EXCLUDED(OP_where_np) - @Namespace("sd::ops") public static class where_np extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public where_np(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public where_np(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public where_np position(long position) { - return (where_np)super.position(position); - } - public where_np() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op takes 2 n-dimensional arrays as input, and return + * array of the same shape, with elements, either from x or y, depending on the + * condition. + */ +// #if NOT_EXCLUDED(OP_select) +@Namespace("sd::ops") public static class select extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public select(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public select(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public select position(long position) { + return (select)super.position(position); + } - /** - * This op takes 2 n-dimensional arrays as input, and return - * array of the same shape, with elements, either from x or y, depending on the condition. - */ -// #if NOT_EXCLUDED(OP_select) - @Namespace("sd::ops") public static class select extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public select(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public select(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public select position(long position) { - return (select)super.position(position); - } - public select() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op takes either 1 argument and 1 scalar + * or 1 argument and another comparison array + * and runs a pre defined conditional op. + * + * The output of the op is dynamic in size and returns a flat vector of + * elements that return true on the given condition. In numpy parlance, most + * people might understand: a[a > 2] where a is a numpy array and the condition + * is true when an element is > 2. Libnd4j already implements a number of pre + * defined conditions. + * \tparam T + */ +// #if NOT_EXCLUDED(OP_choose) +@Namespace("sd::ops") public static class choose extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public choose(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public choose(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public choose position(long position) { + return (choose)super.position(position); + } - /** - * This op takes either 1 argument and 1 scalar - * or 1 argument and another comparison array - * and runs a pre defined conditional op. - * - * The output of the op is dynamic in size and returns a flat vector of elements - * that return true on the given condition. - * In numpy parlance, most people might understand: - * a[a > 2] - * where a is a numpy array and the condition is true when an element is - * > 2. Libnd4j already implements a number of pre defined conditions. - * \tparam T - */ -// #if NOT_EXCLUDED(OP_choose) - @Namespace("sd::ops") public static class choose extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public choose(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public choose(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public choose position(long position) { - return (choose)super.position(position); - } - public choose() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op takes 1 n-dimensional array as input, and returns true if for every + * adjacent pair we have x[i] <= x[i+1]. + */ +// #if NOT_EXCLUDED(OP_is_non_decreasing) +@Namespace("sd::ops") public static class is_non_decreasing extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public is_non_decreasing(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public is_non_decreasing(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public is_non_decreasing position(long position) { + return (is_non_decreasing)super.position(position); + } - /** - * This op takes 1 n-dimensional array as input, and returns true if for every adjacent pair we have x[i] <= x[i+1]. - */ -// #if NOT_EXCLUDED(OP_is_non_decreasing) - @Namespace("sd::ops") public static class is_non_decreasing extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public is_non_decreasing(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public is_non_decreasing(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public is_non_decreasing position(long position) { - return (is_non_decreasing)super.position(position); - } - public is_non_decreasing() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 1 n-dimensional array as input, and returns true if for every + * adjacent pair we have x[i] < x[i+1]. + */ +// #if NOT_EXCLUDED(OP_is_strictly_increasing) +@Namespace("sd::ops") public static class is_strictly_increasing extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public is_strictly_increasing(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public is_strictly_increasing(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public is_strictly_increasing position(long position) { + return (is_strictly_increasing)super.position(position); + } - /** - * This op takes 1 n-dimensional array as input, and returns true if for every adjacent pair we have x[i] < x[i+1]. - */ -// #if NOT_EXCLUDED(OP_is_strictly_increasing) - @Namespace("sd::ops") public static class is_strictly_increasing extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public is_strictly_increasing(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public is_strictly_increasing(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public is_strictly_increasing position(long position) { - return (is_strictly_increasing)super.position(position); - } - public is_strictly_increasing() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 1 n-dimensional array as input, and returns true if input is a + * numeric array. + */ +// #if NOT_EXCLUDED(OP_is_numeric_tensor) +@Namespace("sd::ops") public static class is_numeric_tensor extends BooleanOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public is_numeric_tensor(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public is_numeric_tensor(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public is_numeric_tensor position(long position) { + return (is_numeric_tensor)super.position(position); + } - /** - * This op takes 1 n-dimensional array as input, and returns true if input is a numeric array. - */ -// #if NOT_EXCLUDED(OP_is_numeric_tensor) - @Namespace("sd::ops") public static class is_numeric_tensor extends BooleanOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public is_numeric_tensor(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public is_numeric_tensor(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public is_numeric_tensor position(long position) { - return (is_numeric_tensor)super.position(position); - } - public is_numeric_tensor() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * + */ +// #if NOT_EXCLUDED(OP_boolean_not) +@Namespace("sd::ops") public static class boolean_not extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public boolean_not(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public boolean_not(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public boolean_not position(long position) { + return (boolean_not)super.position(position); + } - /** - * - */ -// #if NOT_EXCLUDED(OP_boolean_not) - @Namespace("sd::ops") public static class boolean_not extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public boolean_not(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public boolean_not(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public boolean_not position(long position) { - return (boolean_not)super.position(position); - } - public boolean_not() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -13860,280 +16162,293 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_HEADERS_BROADCASTABLE_H // #define LIBND4J_HEADERS_BROADCASTABLE_H -// #include // #include -// #include +// #include // #include - // TODO: make broadcastables separate class +// #include +// TODO: make broadcastables separate class + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Max(X, Y) + */ +// #if NOT_EXCLUDED(OP_maximum) +@Namespace("sd::ops") public static class maximum extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public maximum(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public maximum(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public maximum position(long position) { + return (maximum)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Max(X, Y) - */ -// #if NOT_EXCLUDED(OP_maximum) - @Namespace("sd::ops") public static class maximum extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public maximum(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public maximum(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public maximum position(long position) { - return (maximum)super.position(position); - } - public maximum() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class maximum_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public maximum_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public maximum_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public maximum_bp position(long position) { - return (maximum_bp)super.position(position); - } - +@Namespace("sd::ops") public static class maximum_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public maximum_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public maximum_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public maximum_bp position(long position) { + return (maximum_bp)super.position(position); + } + public maximum_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Min(X, Y) + */ +// #if NOT_EXCLUDED(OP_minimum) +@Namespace("sd::ops") public static class minimum extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public minimum(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public minimum(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public minimum position(long position) { + return (minimum)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Min(X, Y) - */ -// #if NOT_EXCLUDED(OP_minimum) - @Namespace("sd::ops") public static class minimum extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public minimum(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public minimum(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public minimum position(long position) { - return (minimum)super.position(position); - } - public minimum() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class minimum_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public minimum_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public minimum_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public minimum_bp position(long position) { - return (minimum_bp)super.position(position); - } - +@Namespace("sd::ops") public static class minimum_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public minimum_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public minimum_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public minimum_bp position(long position) { + return (minimum_bp)super.position(position); + } + public minimum_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Add(X, Y) + */ +// #if NOT_EXCLUDED(OP_add) +@Namespace("sd::ops") public static class add extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public add(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public add(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public add position(long position) { + return (add)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Add(X, Y) - */ -// #if NOT_EXCLUDED(OP_add) - @Namespace("sd::ops") public static class add extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public add(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public add(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public add position(long position) { - return (add)super.position(position); - } - public add() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class add_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public add_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public add_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public add_bp position(long position) { - return (add_bp)super.position(position); - } - +@Namespace("sd::ops") public static class add_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public add_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public add_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public add_bp position(long position) { + return (add_bp)super.position(position); + } + public add_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Subtract(X, Y) + */ +// #if NOT_EXCLUDED(OP_subtract) +@Namespace("sd::ops") public static class subtract extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public subtract(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public subtract(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public subtract position(long position) { + return (subtract)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Subtract(X, Y) - */ -// #if NOT_EXCLUDED(OP_subtract) - @Namespace("sd::ops") public static class subtract extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public subtract(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public subtract(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public subtract position(long position) { - return (subtract)super.position(position); - } - public subtract() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class subtract_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public subtract_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public subtract_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public subtract_bp position(long position) { - return (subtract_bp)super.position(position); - } - +@Namespace("sd::ops") public static class subtract_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public subtract_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public subtract_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public subtract_bp position(long position) { + return (subtract_bp)super.position(position); + } + public subtract_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Subtract(Y, X) + */ +// #if NOT_EXCLUDED(OP_reversesubtract) +@Namespace("sd::ops") public static class reversesubtract extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reversesubtract(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reversesubtract(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reversesubtract position(long position) { + return (reversesubtract)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Subtract(Y, X) - */ -// #if NOT_EXCLUDED(OP_reversesubtract) - @Namespace("sd::ops") public static class reversesubtract extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reversesubtract(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reversesubtract(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reversesubtract position(long position) { - return (reversesubtract)super.position(position); - } - public reversesubtract() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class reversesubtract_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reversesubtract_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reversesubtract_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reversesubtract_bp position(long position) { - return (reversesubtract_bp)super.position(position); - } - +@Namespace("sd::ops") public static class reversesubtract_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reversesubtract_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reversesubtract_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reversesubtract_bp position(long position) { + return (reversesubtract_bp)super.position(position); + } + public reversesubtract_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = ReverseMod(X, Y) == Mod(Y, X) + */ +// #if NOT_EXCLUDED(OP_reversemod) +@Namespace("sd::ops") public static class reversemod extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reversemod(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reversemod(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reversemod position(long position) { + return (reversemod)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = ReverseMod(X, Y) == Mod(Y, X) - */ -// #if NOT_EXCLUDED(OP_reversemod) - @Namespace("sd::ops") public static class reversemod extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reversemod(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reversemod(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reversemod position(long position) { - return (reversemod)super.position(position); - } - public reversemod() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class reversemod_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reversemod_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reversemod_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reversemod_bp position(long position) { - return (reversemod_bp)super.position(position); - } - +@Namespace("sd::ops") public static class reversemod_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reversemod_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reversemod_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reversemod_bp position(long position) { + return (reversemod_bp)super.position(position); + } + public reversemod_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Subtract(X, Y) * Subtract(X, Y) + */ +// #if NOT_EXCLUDED(OP_squaredsubtract) +@Namespace("sd::ops") public static class squaredsubtract extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public squaredsubtract(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public squaredsubtract(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public squaredsubtract position(long position) { + return (squaredsubtract)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Subtract(X, Y) * Subtract(X, Y) - */ -// #if NOT_EXCLUDED(OP_squaredsubtract) - @Namespace("sd::ops") public static class squaredsubtract extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public squaredsubtract(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public squaredsubtract(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public squaredsubtract position(long position) { - return (squaredsubtract)super.position(position); - } - public squaredsubtract() { super((Pointer)null); allocate(); } private native void allocate(); } @@ -14152,250 +16467,262 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Multiply(X, Y) + */ +// #if NOT_EXCLUDED(OP_multiply) +@Namespace("sd::ops") public static class multiply extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public multiply(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public multiply(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public multiply position(long position) { + return (multiply)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Multiply(X, Y) - */ -// #if NOT_EXCLUDED(OP_multiply) - @Namespace("sd::ops") public static class multiply extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public multiply(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public multiply(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public multiply position(long position) { - return (multiply)super.position(position); - } - public multiply() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class multiply_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public multiply_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public multiply_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public multiply_bp position(long position) { - return (multiply_bp)super.position(position); - } - +@Namespace("sd::ops") public static class multiply_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public multiply_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public multiply_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public multiply_bp position(long position) { + return (multiply_bp)super.position(position); + } + public multiply_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(X, Y) + */ +// #if NOT_EXCLUDED(OP_divide) +@Namespace("sd::ops") public static class divide extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public divide(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public divide(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public divide position(long position) { + return (divide)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(X, Y) - */ -// #if NOT_EXCLUDED(OP_divide) - @Namespace("sd::ops") public static class divide extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public divide(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public divide(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public divide position(long position) { - return (divide)super.position(position); - } - public divide() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class divide_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public divide_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public divide_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public divide_bp position(long position) { - return (divide_bp)super.position(position); - } - +@Namespace("sd::ops") public static class divide_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public divide_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public divide_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public divide_bp position(long position) { + return (divide_bp)super.position(position); + } + public divide_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(X, Y) with exception, 0 if Y = 0 + */ +// #if NOT_EXCLUDED(OP_divide_no_nan) +@Namespace("sd::ops") public static class divide_no_nan extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public divide_no_nan(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public divide_no_nan(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public divide_no_nan position(long position) { + return (divide_no_nan)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(X, Y) with exception, 0 if Y = 0 - */ -// #if NOT_EXCLUDED(OP_divide_no_nan) - @Namespace("sd::ops") public static class divide_no_nan extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public divide_no_nan(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public divide_no_nan(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public divide_no_nan position(long position) { - return (divide_no_nan)super.position(position); - } - public divide_no_nan() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(Y, x) - */ -// #if NOT_EXCLUDED(OP_reversedivide) - @Namespace("sd::ops") public static class reversedivide extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reversedivide(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reversedivide(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reversedivide position(long position) { - return (reversedivide)super.position(position); - } - +// #endif +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(Y, x) + */ +// #if NOT_EXCLUDED(OP_reversedivide) +@Namespace("sd::ops") public static class reversedivide extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reversedivide(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reversedivide(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reversedivide position(long position) { + return (reversedivide)super.position(position); + } + public reversedivide() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class reversedivide_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reversedivide_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reversedivide_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reversedivide_bp position(long position) { - return (reversedivide_bp)super.position(position); - } - +@Namespace("sd::ops") public static class reversedivide_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reversedivide_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reversedivide_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reversedivide_bp position(long position) { + return (reversedivide_bp)super.position(position); + } + public reversedivide_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = FloorMod(X, Y) + */ +// #if NOT_EXCLUDED(OP_floormod) +@Namespace("sd::ops") public static class floormod extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public floormod(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public floormod(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public floormod position(long position) { + return (floormod)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = FloorMod(X, Y) - */ -// #if NOT_EXCLUDED(OP_floormod) - @Namespace("sd::ops") public static class floormod extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public floormod(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public floormod(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public floormod position(long position) { - return (floormod)super.position(position); - } - public floormod() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class floormod_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public floormod_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public floormod_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public floormod_bp position(long position) { - return (floormod_bp)super.position(position); - } - +@Namespace("sd::ops") public static class floormod_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public floormod_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public floormod_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public floormod_bp position(long position) { + return (floormod_bp)super.position(position); + } + public floormod_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_mod) +@Namespace("sd::ops") public static class mod extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mod(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mod(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mod position(long position) { + return (mod)super.position(position); + } -// #if NOT_EXCLUDED(OP_mod) - @Namespace("sd::ops") public static class mod extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mod(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mod(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mod position(long position) { - return (mod)super.position(position); - } - public mod() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class mod_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mod_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mod_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mod_bp position(long position) { - return (mod_bp)super.position(position); - } - +@Namespace("sd::ops") public static class mod_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mod_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mod_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mod_bp position(long position) { + return (mod_bp)super.position(position); + } + public mod_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = FloorDiv(X, Y) + */ +// #if NOT_EXCLUDED(OP_floordiv) +@Namespace("sd::ops") public static class floordiv extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public floordiv(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public floordiv(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public floordiv position(long position) { + return (floordiv)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = FloorDiv(X, Y) - */ -// #if NOT_EXCLUDED(OP_floordiv) - @Namespace("sd::ops") public static class floordiv extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public floordiv(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public floordiv(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public floordiv position(long position) { - return (floordiv)super.position(position); - } - public floordiv() { super((Pointer)null); allocate(); } private native void allocate(); } @@ -14414,453 +16741,458 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - // #endif + // #endif + +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Divide(X, Y) + */ +// #if NOT_EXCLUDED(OP_realdiv) +@Namespace("sd::ops") public static class realdiv extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public realdiv(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public realdiv(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public realdiv position(long position) { + return (realdiv)super.position(position); + } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Divide(X, Y) - */ -// #if NOT_EXCLUDED(OP_realdiv) - @Namespace("sd::ops") public static class realdiv extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public realdiv(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public realdiv(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public realdiv position(long position) { - return (realdiv)super.position(position); - } - public realdiv() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class realdiv_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public realdiv_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public realdiv_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public realdiv_bp position(long position) { - return (realdiv_bp)super.position(position); - } - +@Namespace("sd::ops") public static class realdiv_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public realdiv_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public realdiv_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public realdiv_bp position(long position) { + return (realdiv_bp)super.position(position); + } + public realdiv_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * + * + * \tparam T + */ +@Namespace("sd::ops") public static class truncatediv extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public truncatediv(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public truncatediv(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public truncatediv position(long position) { + return (truncatediv)super.position(position); + } - /** - * - * - * \tparam T - */ - @Namespace("sd::ops") public static class truncatediv extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public truncatediv(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public truncatediv(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public truncatediv position(long position) { - return (truncatediv)super.position(position); - } - public truncatediv() { super((Pointer)null); allocate(); } private native void allocate(); } - /** - * This is one of auto-broadcastable operations. It accepts 2 operands, and operation is applied based on their shapes: - * 1) if shapes are equal that's pairwise operation, result will have the same shape. - * 2) if shape X is scalar and shape Y is array - result will have shape equal to Y. - * 3) if shape X is array and shape Y is scalar - result will have shape equal to X. - * 4) if shape X and Y are both arrays, but shapes aren't equal - result shape will be broadcast result. - * - * This operation returns Z = Assign(X, Y) - */ -// #if NOT_EXCLUDED(OP_assign) - @Namespace("sd::ops") public static class assign extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public assign(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public assign(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public assign position(long position) { - return (assign)super.position(position); - } - +/** + * This is one of auto-broadcastable operations. It accepts 2 operands, and + * operation is applied based on their shapes: 1) if shapes are equal that's + * pairwise operation, result will have the same shape. 2) if shape X is scalar + * and shape Y is array - result will have shape equal to Y. 3) if shape X is + * array and shape Y is scalar - result will have shape equal to X. 4) if shape + * X and Y are both arrays, but shapes aren't equal - result shape will be + * broadcast result. + * + * This operation returns Z = Assign(X, Y) + */ +// #if NOT_EXCLUDED(OP_assign) +@Namespace("sd::ops") public static class assign extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public assign(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public assign(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public assign position(long position) { + return (assign)super.position(position); + } + public assign() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class assign_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public assign_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public assign_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public assign_bp position(long position) { - return (assign_bp)super.position(position); - } - +@Namespace("sd::ops") public static class assign_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public assign_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public assign_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public assign_bp position(long position) { + return (assign_bp)super.position(position); + } + public assign_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_meshgrid) +@Namespace("sd::ops") public static class meshgrid extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public meshgrid(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public meshgrid(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public meshgrid position(long position) { + return (meshgrid)super.position(position); + } -// #if NOT_EXCLUDED(OP_meshgrid) - @Namespace("sd::ops") public static class meshgrid extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public meshgrid(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public meshgrid(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public meshgrid position(long position) { - return (meshgrid)super.position(position); - } - public meshgrid() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x == _y ? (T) 1.0f : (T) 0.0f; + * + */ +// #if NOT_EXCLUDED(OP_equals) +@Namespace("sd::ops") public static class equals extends BroadcastableBoolOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public equals(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public equals(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public equals position(long position) { + return (equals)super.position(position); + } - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x == _y ? (T) 1.0f : (T) 0.0f; - * - */ -// #if NOT_EXCLUDED(OP_equals) - @Namespace("sd::ops") public static class equals extends BroadcastableBoolOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public equals(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public equals(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public equals position(long position) { - return (equals)super.position(position); - } - public equals() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x != _y ? (T) 1.0f : (T) 0.0f; + */ +// #if NOT_EXCLUDED(OP_not_equals) +@Namespace("sd::ops") public static class not_equals extends BroadcastableBoolOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public not_equals(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public not_equals(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public not_equals position(long position) { + return (not_equals)super.position(position); + } - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x != _y ? (T) 1.0f : (T) 0.0f; - */ -// #if NOT_EXCLUDED(OP_not_equals) - @Namespace("sd::ops") public static class not_equals extends BroadcastableBoolOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public not_equals(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public not_equals(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public not_equals position(long position) { - return (not_equals)super.position(position); - } - public not_equals() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x <= _y ? (T) 1.0f : (T) 0.0f; + */ +// #if NOT_EXCLUDED(OP_less_equal) +@Namespace("sd::ops") public static class less_equal extends BroadcastableBoolOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public less_equal(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public less_equal(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public less_equal position(long position) { + return (less_equal)super.position(position); + } - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x <= _y ? (T) 1.0f : (T) 0.0f; - */ -// #if NOT_EXCLUDED(OP_less_equal) - @Namespace("sd::ops") public static class less_equal extends BroadcastableBoolOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public less_equal(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public less_equal(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public less_equal position(long position) { - return (less_equal)super.position(position); - } - public less_equal() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x >= _y ? (T) 1.0f : (T) 0.0f; + */ +// #if NOT_EXCLUDED(OP_greater_equal) +@Namespace("sd::ops") public static class greater_equal extends BroadcastableBoolOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public greater_equal(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public greater_equal(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public greater_equal position(long position) { + return (greater_equal)super.position(position); + } - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x >= _y ? (T) 1.0f : (T) 0.0f; - */ -// #if NOT_EXCLUDED(OP_greater_equal) - @Namespace("sd::ops") public static class greater_equal extends BroadcastableBoolOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public greater_equal(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public greater_equal(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public greater_equal position(long position) { - return (greater_equal)super.position(position); - } - public greater_equal() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x < _y ? (T) 1.0f : (T) 0.0f; + */ +// #if NOT_EXCLUDED(OP_less) +@Namespace("sd::ops") public static class less extends BroadcastableBoolOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public less(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public less(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public less position(long position) { + return (less)super.position(position); + } - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x < _y ? (T) 1.0f : (T) 0.0f; - */ -// #if NOT_EXCLUDED(OP_less) - @Namespace("sd::ops") public static class less extends BroadcastableBoolOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public less(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public less(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public less position(long position) { - return (less)super.position(position); - } - public less() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This op takes 2 equally shaped arrays as input, and provides binary matrix as + * output. Math is: _x > _y ? (T) 1.0f : (T) 0.0f; + */ +// #if NOT_EXCLUDED(OP_greater) +@Namespace("sd::ops") public static class greater extends BroadcastableBoolOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public greater(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public greater(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public greater position(long position) { + return (greater)super.position(position); + } - /** - * This op takes 2 equally shaped arrays as input, and provides binary matrix as output. - * Math is: _x > _y ? (T) 1.0f : (T) 0.0f; - */ -// #if NOT_EXCLUDED(OP_greater) - @Namespace("sd::ops") public static class greater extends BroadcastableBoolOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public greater(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public greater(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public greater position(long position) { - return (greater)super.position(position); - } - public greater() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * + */ +// #if NOT_EXCLUDED(OP_boolean_and) +@Namespace("sd::ops") public static class boolean_and extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public boolean_and(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public boolean_and(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public boolean_and position(long position) { + return (boolean_and)super.position(position); + } - /** - * - */ -// #if NOT_EXCLUDED(OP_boolean_and) - @Namespace("sd::ops") public static class boolean_and extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public boolean_and(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public boolean_and(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public boolean_and position(long position) { - return (boolean_and)super.position(position); - } - public boolean_and() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * + */ +// #if NOT_EXCLUDED(OP_boolean_or) +@Namespace("sd::ops") public static class boolean_or extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public boolean_or(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public boolean_or(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public boolean_or position(long position) { + return (boolean_or)super.position(position); + } - /** - * - */ -// #if NOT_EXCLUDED(OP_boolean_or) - @Namespace("sd::ops") public static class boolean_or extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public boolean_or(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public boolean_or(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public boolean_or position(long position) { - return (boolean_or)super.position(position); - } - public boolean_or() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * + */ +// #if NOT_EXCLUDED(OP_boolean_xor) +@Namespace("sd::ops") public static class boolean_xor extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public boolean_xor(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public boolean_xor(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public boolean_xor position(long position) { + return (boolean_xor)super.position(position); + } - /** - * - */ -// #if NOT_EXCLUDED(OP_boolean_xor) - @Namespace("sd::ops") public static class boolean_xor extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public boolean_xor(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public boolean_xor(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public boolean_xor position(long position) { - return (boolean_xor)super.position(position); - } - public boolean_xor() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation performs calculation of percentile of input array along given + * axises + * + * Input - tensor with rank N > 0 + * Output - tensor with rank (N - length(axis)) or scalar if number of Integer + * arguments is zero Float arguments: 0: percentile (scalar) in range [0,100] + * (inclusively) 1: interpolation (optional), possible values are 0-"lower", + * 1-"higher", 2-"nearest"(default) 2: keepDims (optional), if it is non zero, + * then unities are kept in reduced resulting shape of output array, default is + * 0 Integer arguments - axis - the sequence of axises to calculate percentile + * along, if sequence is empty then calculate percentile for whole input tensor + * and return result as scalar + * + */ +// #if NOT_EXCLUDED(OP_percentile) +@Namespace("sd::ops") public static class percentile extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public percentile(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public percentile(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public percentile position(long position) { + return (percentile)super.position(position); + } - /** - * This operation performs calculation of percentile of input array along given axises - * - * Input - tensor with rank N > 0 - * Output - tensor with rank (N - length(axis)) or scalar if number of Integer arguments is zero - * Float arguments: - * 0: percentile (scalar) in range [0,100] (inclusively) - * 1: interpolation (optional), possible values are 0-"lower", 1-"higher", 2-"nearest"(default) - * 2: keepDims (optional), if it is non zero, then unities are kept in reduced resulting shape of output array, default is 0 - * Integer arguments - axis - the sequence of axises to calculate percentile along, if sequence is empty then calculate percentile for whole input tensor and return result as scalar - * - */ -// #if NOT_EXCLUDED(OP_percentile) - @Namespace("sd::ops") public static class percentile extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public percentile(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public percentile(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public percentile position(long position) { - return (percentile)super.position(position); - } - public percentile() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * Special atan2 op impl for TF's args order + * \tparam T + */ +// #if NOT_EXCLUDED(OP_tf_atan2) +@Namespace("sd::ops") public static class tf_atan2 extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tf_atan2(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tf_atan2(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tf_atan2 position(long position) { + return (tf_atan2)super.position(position); + } - /** - * Special atan2 op impl for TF's args order - * \tparam T - */ -// #if NOT_EXCLUDED(OP_tf_atan2) - @Namespace("sd::ops") public static class tf_atan2 extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tf_atan2(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tf_atan2(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tf_atan2 position(long position) { - return (tf_atan2)super.position(position); - } - public tf_atan2() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * Broadcastable pow implementation + * \tparam T + */ +// #if NOT_EXCLUDED(OP_Pow) +@Namespace("sd::ops") public static class Pow extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Pow(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Pow(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Pow position(long position) { + return (Pow)super.position(position); + } - /** - * Broadcastable pow implementation - * \tparam T - */ -// #if NOT_EXCLUDED(OP_Pow) - @Namespace("sd::ops") public static class Pow extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Pow(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Pow(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Pow position(long position) { - return (Pow)super.position(position); - } - public Pow() { super((Pointer)null); allocate(); } private native void allocate(); } - @Namespace("sd::ops") public static class Pow_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Pow_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Pow_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Pow_bp position(long position) { - return (Pow_bp)super.position(position); - } - +@Namespace("sd::ops") public static class Pow_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Pow_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Pow_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Pow_bp position(long position) { + return (Pow_bp)super.position(position); + } + public Pow_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Broadcastable igamma implementation + * + * igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x) + * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } + * gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt } + * \tparam T + */ +// #if NOT_EXCLUDED(OP_igamma) +@Namespace("sd::ops") public static class igamma extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public igamma(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public igamma(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public igamma position(long position) { + return (igamma)super.position(position); + } - /** - * Broadcastable igamma implementation - * - * igamma(a, x) = gamma(а, x) / Gamma(a) - Gamma distribution function P(a,x) - * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } - * gamma(a, x) = int from 0 to x { t ^ {a - 1} e^{-t}dt } - * \tparam T - */ -// #if NOT_EXCLUDED(OP_igamma) - @Namespace("sd::ops") public static class igamma extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public igamma(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public igamma(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public igamma position(long position) { - return (igamma)super.position(position); - } - public igamma() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif - /** - * Broadcastable igammac implementation - * igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x) - * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } - * Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt } - * \tparam T - */ -// #if NOT_EXCLUDED(OP_igammac) - @Namespace("sd::ops") public static class igammac extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public igammac(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public igammac(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public igammac position(long position) { - return (igammac)super.position(position); - } - +// #endif +/** + * Broadcastable igammac implementation + * igammac(a, x) = Gamma(a,x)/Gamma(а) - Gamma distribution function Q(a,x) + * Gamma(a) = int from 0 to infinity { t ^ {a - 1} e^{-t}dt } + * Gamma(a, x) = int from x to infinity { t ^ {a - 1} e^{-t}dt } + * \tparam T + */ +// #if NOT_EXCLUDED(OP_igammac) +@Namespace("sd::ops") public static class igammac extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public igammac(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public igammac(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public igammac position(long position) { + return (igammac)super.position(position); + } + public igammac() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -14891,824 +17223,823 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include - /** - * 1D temporal convolution implementation - * Expected input: - * x: 3D array - * weight: 3D Array - * bias: optional vector - * - * Int args: - * 0: kernel - * 1: stride - * 2: padding - */ -// #if NOT_EXCLUDED(OP_conv1d) - @Namespace("sd::ops") public static class conv1d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public conv1d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public conv1d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public conv1d position(long position) { - return (conv1d)super.position(position); - } - +/** + * 1D temporal convolution implementation + * Expected input: + * x: 3D array + * weight: 3D Array + * bias: optional vector + * + * Int args: + * 0: kernel + * 1: stride + * 2: padding + */ +// #if NOT_EXCLUDED(OP_conv1d) +@Namespace("sd::ops") public static class conv1d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public conv1d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public conv1d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public conv1d position(long position) { + return (conv1d)super.position(position); + } + public conv1d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class conv1d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public conv1d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public conv1d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public conv1d_bp position(long position) { - return (conv1d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class conv1d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public conv1d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public conv1d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public conv1d_bp position(long position) { + return (conv1d_bp)super.position(position); + } + public conv1d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * 2D convolution implementation + * Expected input: + * x: 4D array + * weight: 4D Array + * bias: optional vector, length of outputChannels + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 1 true, 0 false + * 9: data format: 1 NHWC, 0 NCHW + */ +// #if NOT_EXCLUDED(OP_conv2d) +@Namespace("sd::ops") public static class conv2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public conv2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public conv2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public conv2d position(long position) { + return (conv2d)super.position(position); + } - /** - * 2D convolution implementation - * Expected input: - * x: 4D array - * weight: 4D Array - * bias: optional vector, length of outputChannels - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 1 true, 0 false - * 9: data format: 1 NHWC, 0 NCHW - */ -// #if NOT_EXCLUDED(OP_conv2d) - @Namespace("sd::ops") public static class conv2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public conv2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public conv2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public conv2d position(long position) { - return (conv2d)super.position(position); - } - public conv2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class conv2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public conv2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public conv2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public conv2d_bp position(long position) { - return (conv2d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class conv2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public conv2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public conv2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public conv2d_bp position(long position) { + return (conv2d_bp)super.position(position); + } + public conv2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class conv2d_input_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public conv2d_input_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public conv2d_input_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public conv2d_input_bp position(long position) { - return (conv2d_input_bp)super.position(position); - } - - public conv2d_input_bp() { super((Pointer)null); allocate(); } - private native void allocate(); +@Namespace("sd::ops") public static class conv2d_input_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public conv2d_input_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public conv2d_input_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public conv2d_input_bp position(long position) { + return (conv2d_input_bp)super.position(position); + } + + public conv2d_input_bp() { super((Pointer)null); allocate(); } + private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Depthwise convolution2d op: + * Expected inputs: + * x: 4D array, NCHW format + * weightsDepth: 4D array, + * weightsPointwise: optional, 4D array + * bias: optional, vector + */ +// #if NOT_EXCLUDED(OP_sconv2d) +@Namespace("sd::ops") public static class sconv2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sconv2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sconv2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sconv2d position(long position) { + return (sconv2d)super.position(position); + } - /** - * Depthwise convolution2d op: - * Expected inputs: - * x: 4D array, NCHW format - * weightsDepth: 4D array, - * weightsPointwise: optional, 4D array - * bias: optional, vector - */ -// #if NOT_EXCLUDED(OP_sconv2d) - @Namespace("sd::ops") public static class sconv2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sconv2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sconv2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sconv2d position(long position) { - return (sconv2d)super.position(position); - } - public sconv2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class sconv2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sconv2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sconv2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sconv2d_bp position(long position) { - return (sconv2d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class sconv2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sconv2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sconv2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sconv2d_bp position(long position) { + return (sconv2d_bp)super.position(position); + } + public sconv2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * 2D deconvolution implementation + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + */ +// #if NOT_EXCLUDED(OP_deconv2d) +@Namespace("sd::ops") public static class deconv2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public deconv2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public deconv2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public deconv2d position(long position) { + return (deconv2d)super.position(position); + } - /** - * 2D deconvolution implementation - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - */ -// #if NOT_EXCLUDED(OP_deconv2d) - @Namespace("sd::ops") public static class deconv2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public deconv2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public deconv2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public deconv2d position(long position) { - return (deconv2d)super.position(position); - } - public deconv2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class deconv2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public deconv2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public deconv2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public deconv2d_bp position(long position) { - return (deconv2d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class deconv2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public deconv2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public deconv2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public deconv2d_bp position(long position) { + return (deconv2d_bp)super.position(position); + } + public deconv2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif - /** - * 3D deconvolution implementation - * - * IntArgs: - * 0: filter(kernel) depth - * 1: filter(kernel) height - * 2: filter(kernel) width - * 3: strides depth - * 4: strides height - * 5: strides width - * 6: paddings depth - * 7: paddings height - * 8: paddings width - * 9: dilations depth - * 10: dilations height - * 11: dilations width - * 12: same mode: 0 false, 1 true - * 13: data format (optional): 0-NDHWC, 1-NCDHW, default is 1 - */ +/** + * 3D deconvolution implementation + * + * IntArgs: + * 0: filter(kernel) depth + * 1: filter(kernel) height + * 2: filter(kernel) width + * 3: strides depth + * 4: strides height + * 5: strides width + * 6: paddings depth + * 7: paddings height + * 8: paddings width + * 9: dilations depth + * 10: dilations height + * 11: dilations width + * 12: same mode: 0 false, 1 true + * 13: data format (optional): 0-NDHWC, 1-NCDHW, default is 1 + */ + +// #if NOT_EXCLUDED(OP_deconv3d) +@Namespace("sd::ops") public static class deconv3d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public deconv3d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public deconv3d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public deconv3d position(long position) { + return (deconv3d)super.position(position); + } -// #if NOT_EXCLUDED(OP_deconv3d) - @Namespace("sd::ops") public static class deconv3d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public deconv3d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public deconv3d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public deconv3d position(long position) { - return (deconv3d)super.position(position); - } - public deconv3d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class deconv3d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public deconv3d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public deconv3d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public deconv3d_bp position(long position) { - return (deconv3d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class deconv3d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public deconv3d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public deconv3d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public deconv3d_bp position(long position) { + return (deconv3d_bp)super.position(position); + } + public deconv3d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This op implements max pooling for convolution networks. + * Expected Input: 4D array, NCHW format. + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + */ +// #if NOT_EXCLUDED(OP_maxpool2d) +@Namespace("sd::ops") public static class maxpool2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public maxpool2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public maxpool2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public maxpool2d position(long position) { + return (maxpool2d)super.position(position); + } - /** - * This op implements max pooling for convolution networks. - * Expected Input: 4D array, NCHW format. - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - */ -// #if NOT_EXCLUDED(OP_maxpool2d) - @Namespace("sd::ops") public static class maxpool2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public maxpool2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public maxpool2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public maxpool2d position(long position) { - return (maxpool2d)super.position(position); - } - public maxpool2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class maxpool2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public maxpool2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public maxpool2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public maxpool2d_bp position(long position) { - return (maxpool2d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class maxpool2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public maxpool2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public maxpool2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public maxpool2d_bp position(long position) { + return (maxpool2d_bp)super.position(position); + } + public maxpool2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op implements average pooling for convolution networks. + * Expected Input: 4D array, NCHW format. + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + */ +// #if NOT_EXCLUDED(OP_avgpool2d) +@Namespace("sd::ops") public static class avgpool2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public avgpool2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public avgpool2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public avgpool2d position(long position) { + return (avgpool2d)super.position(position); + } - /** - * This op implements average pooling for convolution networks. - * Expected Input: 4D array, NCHW format. - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - */ -// #if NOT_EXCLUDED(OP_avgpool2d) - @Namespace("sd::ops") public static class avgpool2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public avgpool2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public avgpool2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public avgpool2d position(long position) { - return (avgpool2d)super.position(position); - } - public avgpool2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class avgpool2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public avgpool2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public avgpool2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public avgpool2d_bp position(long position) { - return (avgpool2d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class avgpool2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public avgpool2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public avgpool2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public avgpool2d_bp position(long position) { + return (avgpool2d_bp)super.position(position); + } + public avgpool2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op implements pnorm pooling for convolution networks. + * Expected Input: 4D array, NCHW format. + * + * IntArgs: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: same mode: 0 false, 1 true + * 9: p for p-norm + */ +// #if NOT_EXCLUDED(OP_pnormpool2d) +@Namespace("sd::ops") public static class pnormpool2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public pnormpool2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public pnormpool2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public pnormpool2d position(long position) { + return (pnormpool2d)super.position(position); + } - /** - * This op implements pnorm pooling for convolution networks. - * Expected Input: 4D array, NCHW format. - * - * IntArgs: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: same mode: 0 false, 1 true - * 9: p for p-norm - */ -// #if NOT_EXCLUDED(OP_pnormpool2d) - @Namespace("sd::ops") public static class pnormpool2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public pnormpool2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public pnormpool2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public pnormpool2d position(long position) { - return (pnormpool2d)super.position(position); - } - public pnormpool2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class pnormpool2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public pnormpool2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public pnormpool2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public pnormpool2d_bp position(long position) { - return (pnormpool2d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class pnormpool2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public pnormpool2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public pnormpool2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public pnormpool2d_bp position(long position) { + return (pnormpool2d_bp)super.position(position); + } + public pnormpool2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op implements im2col algorithm, widely used in convolution neural + * networks Input: 4D input expected + * + * Int args: + * 0: kernel height + * 1: kernel width + * 2: stride height + * 3: stride width + * 4: padding height + * 5: padding width + * 6: dilation height + * 7: dilation width + * 8: isSameMode + */ +// #if NOT_EXCLUDED(OP_im2col) +@Namespace("sd::ops") public static class im2col extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public im2col(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public im2col(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public im2col position(long position) { + return (im2col)super.position(position); + } - /** - * This op implements im2col algorithm, widely used in convolution neural networks - * Input: 4D input expected - * - * Int args: - * 0: kernel height - * 1: kernel width - * 2: stride height - * 3: stride width - * 4: padding height - * 5: padding width - * 6: dilation height - * 7: dilation width - * 8: isSameMode - */ -// #if NOT_EXCLUDED(OP_im2col) - @Namespace("sd::ops") public static class im2col extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public im2col(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public im2col(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public im2col position(long position) { - return (im2col)super.position(position); - } - public im2col() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class im2col_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public im2col_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public im2col_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public im2col_bp position(long position) { - return (im2col_bp)super.position(position); - } - +@Namespace("sd::ops") public static class im2col_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public im2col_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public im2col_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public im2col_bp position(long position) { + return (im2col_bp)super.position(position); + } + public im2col_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op implements col2im algorithm, widely used in convolution neural + * networks Input: 6D input expected (like output of im2col op) + * + * Int args: + * 0: stride height + * 1: stride width + * 2: padding height + * 3: padding width + * 4: image height + * 5: image width + * 6: dilation height + * 7: dilation width + */ +// #if NOT_EXCLUDED(OP_col2im) +@Namespace("sd::ops") public static class col2im extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public col2im(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public col2im(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public col2im position(long position) { + return (col2im)super.position(position); + } - /** - * This op implements col2im algorithm, widely used in convolution neural networks - * Input: 6D input expected (like output of im2col op) - * - * Int args: - * 0: stride height - * 1: stride width - * 2: padding height - * 3: padding width - * 4: image height - * 5: image width - * 6: dilation height - * 7: dilation width - */ -// #if NOT_EXCLUDED(OP_col2im) - @Namespace("sd::ops") public static class col2im extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public col2im(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public col2im(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public col2im position(long position) { - return (col2im)super.position(position); - } - public col2im() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Expected input: 4D array + * + * IntArgs: + * 0: scale factor for rows (height) + * 1: scale factor for columns (width) + * 2: data format: 0 NHWC (default), 1 NCHW + */ +// #if NOT_EXCLUDED(OP_upsampling2d) +@Namespace("sd::ops") public static class upsampling2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public upsampling2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public upsampling2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public upsampling2d position(long position) { + return (upsampling2d)super.position(position); + } + + public upsampling2d() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +@Namespace("sd::ops") public static class upsampling2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public upsampling2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public upsampling2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public upsampling2d_bp position(long position) { + return (upsampling2d_bp)super.position(position); + } - /** - * Expected input: 4D array - * - * IntArgs: - * 0: scale factor for rows (height) - * 1: scale factor for columns (width) - * 2: data format: 0 NHWC (default), 1 NCHW - */ -// #if NOT_EXCLUDED(OP_upsampling2d) - @Namespace("sd::ops") public static class upsampling2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public upsampling2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public upsampling2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public upsampling2d position(long position) { - return (upsampling2d)super.position(position); - } - - public upsampling2d() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - @Namespace("sd::ops") public static class upsampling2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public upsampling2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public upsampling2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public upsampling2d_bp position(long position) { - return (upsampling2d_bp)super.position(position); - } - public upsampling2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Expected input: 4D array + * + * IntArgs: + * 0: scale factor for depth + * 1: scale factor for rows (height) + * 2: scale factor for columns (width) + * 3: data format: 0 NDHWC (default), 1 NCDHW + */ +// #if NOT_EXCLUDED(OP_upsampling3d) +@Namespace("sd::ops") public static class upsampling3d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public upsampling3d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public upsampling3d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public upsampling3d position(long position) { + return (upsampling3d)super.position(position); + } - /** - * Expected input: 4D array - * - * IntArgs: - * 0: scale factor for depth - * 1: scale factor for rows (height) - * 2: scale factor for columns (width) - * 3: data format: 0 NDHWC (default), 1 NCDHW - */ -// #if NOT_EXCLUDED(OP_upsampling3d) - @Namespace("sd::ops") public static class upsampling3d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public upsampling3d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public upsampling3d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public upsampling3d position(long position) { - return (upsampling3d)super.position(position); - } - public upsampling3d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class upsampling3d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public upsampling3d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public upsampling3d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public upsampling3d_bp position(long position) { - return (upsampling3d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class upsampling3d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public upsampling3d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public upsampling3d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public upsampling3d_bp position(long position) { + return (upsampling3d_bp)super.position(position); + } + public upsampling3d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op produces binary matrix wrt to target dimension. + * Maximum value within each TAD is replaced with 1, other values are set to + * true. + * + * Int args: + * 0: axis + */ +// #if NOT_EXCLUDED(OP_ismax) +@Namespace("sd::ops") public static class ismax extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ismax(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ismax(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ismax position(long position) { + return (ismax)super.position(position); + } - /** - * This op produces binary matrix wrt to target dimension. - * Maximum value within each TAD is replaced with 1, other values are set to true. - * - * Int args: - * 0: axis - */ -// #if NOT_EXCLUDED(OP_ismax) - @Namespace("sd::ops") public static class ismax extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ismax(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ismax(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ismax position(long position) { - return (ismax)super.position(position); - } - public ismax() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Dilation2D op + * + * Int args: + * 0: isSameMode + */ +// #if NOT_EXCLUDED(OP_dilation2d) +@Namespace("sd::ops") public static class dilation2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dilation2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dilation2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dilation2d position(long position) { + return (dilation2d)super.position(position); + } - /** - * Dilation2D op - * - * Int args: - * 0: isSameMode - */ -// #if NOT_EXCLUDED(OP_dilation2d) - @Namespace("sd::ops") public static class dilation2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dilation2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dilation2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dilation2d position(long position) { - return (dilation2d)super.position(position); - } - public dilation2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_conv3dnew) +@Namespace("sd::ops") public static class conv3dnew extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public conv3dnew(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public conv3dnew(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public conv3dnew position(long position) { + return (conv3dnew)super.position(position); + } -// #if NOT_EXCLUDED(OP_conv3dnew) - @Namespace("sd::ops") public static class conv3dnew extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public conv3dnew(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public conv3dnew(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public conv3dnew position(long position) { - return (conv3dnew)super.position(position); - } - public conv3dnew() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class conv3dnew_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public conv3dnew_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public conv3dnew_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public conv3dnew_bp position(long position) { - return (conv3dnew_bp)super.position(position); - } - +@Namespace("sd::ops") public static class conv3dnew_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public conv3dnew_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public conv3dnew_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public conv3dnew_bp position(long position) { + return (conv3dnew_bp)super.position(position); + } + public conv3dnew_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_avgpool3dnew) +@Namespace("sd::ops") public static class avgpool3dnew extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public avgpool3dnew(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public avgpool3dnew(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public avgpool3dnew position(long position) { + return (avgpool3dnew)super.position(position); + } -// #if NOT_EXCLUDED(OP_avgpool3dnew) - @Namespace("sd::ops") public static class avgpool3dnew extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public avgpool3dnew(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public avgpool3dnew(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public avgpool3dnew position(long position) { - return (avgpool3dnew)super.position(position); - } - public avgpool3dnew() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class avgpool3dnew_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public avgpool3dnew_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public avgpool3dnew_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public avgpool3dnew_bp position(long position) { - return (avgpool3dnew_bp)super.position(position); - } - +@Namespace("sd::ops") public static class avgpool3dnew_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public avgpool3dnew_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public avgpool3dnew_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public avgpool3dnew_bp position(long position) { + return (avgpool3dnew_bp)super.position(position); + } + public avgpool3dnew_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_maxpool3dnew) +@Namespace("sd::ops") public static class maxpool3dnew extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public maxpool3dnew(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public maxpool3dnew(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public maxpool3dnew position(long position) { + return (maxpool3dnew)super.position(position); + } -// #if NOT_EXCLUDED(OP_maxpool3dnew) - @Namespace("sd::ops") public static class maxpool3dnew extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public maxpool3dnew(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public maxpool3dnew(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public maxpool3dnew position(long position) { - return (maxpool3dnew)super.position(position); - } - public maxpool3dnew() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class maxpool3dnew_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public maxpool3dnew_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public maxpool3dnew_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public maxpool3dnew_bp position(long position) { - return (maxpool3dnew_bp)super.position(position); - } - +@Namespace("sd::ops") public static class maxpool3dnew_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public maxpool3dnew_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public maxpool3dnew_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public maxpool3dnew_bp position(long position) { + return (maxpool3dnew_bp)super.position(position); + } + public maxpool3dnew_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op same as maxpool2d with a variant to return a matrix of indexes for + * max values + * + * Input - 4D tensor + * Output: + * 0 - 4D tensor as input + * 1 - 4D tensor with max value indexes + * + * Int params: + * 9 int with 2x4 vectors and 1 bool value + */ +// #if NOT_EXCLUDED(OP_max_pool_woth_argmax) +@Namespace("sd::ops") public static class max_pool_with_argmax extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public max_pool_with_argmax(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public max_pool_with_argmax(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public max_pool_with_argmax position(long position) { + return (max_pool_with_argmax)super.position(position); + } - /** - * This op same as maxpool2d with a variant to return a matrix of indexes for max values - * - * Input - 4D tensor - * Output: - * 0 - 4D tensor as input - * 1 - 4D tensor with max value indexes - * - * Int params: - * 9 int with 2x4 vectors and 1 bool value - */ -// #if NOT_EXCLUDED(OP_max_pool_woth_argmax) - @Namespace("sd::ops") public static class max_pool_with_argmax extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public max_pool_with_argmax(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public max_pool_with_argmax(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public max_pool_with_argmax position(long position) { - return (max_pool_with_argmax)super.position(position); - } - public max_pool_with_argmax() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +// #if NOT_EXCLUDED(OP_depthwise_conv2d) +@Namespace("sd::ops") public static class depthwise_conv2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public depthwise_conv2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public depthwise_conv2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public depthwise_conv2d position(long position) { + return (depthwise_conv2d)super.position(position); + } -// #if NOT_EXCLUDED(OP_depthwise_conv2d) - @Namespace("sd::ops") public static class depthwise_conv2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public depthwise_conv2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public depthwise_conv2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public depthwise_conv2d position(long position) { - return (depthwise_conv2d)super.position(position); - } - public depthwise_conv2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class depthwise_conv2d_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public depthwise_conv2d_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public depthwise_conv2d_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public depthwise_conv2d_bp position(long position) { - return (depthwise_conv2d_bp)super.position(position); - } - +@Namespace("sd::ops") public static class depthwise_conv2d_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public depthwise_conv2d_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public depthwise_conv2d_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public depthwise_conv2d_bp position(long position) { + return (depthwise_conv2d_bp)super.position(position); + } + public depthwise_conv2d_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * point-wise 2D convolution + * Expected input: + * x: 4D array + * weight: 4D Array [1, 1, iC, oC] (NHWC) or [oC, iC, 1, 1] (NCHW) + * bias: optional vector, length of oC + * + * IntArgs: + * 0: data format: 1 NHWC, 0 NCHW (optional, by default = NHWC) + */ +@Namespace("sd::ops") public static class pointwise_conv2d extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public pointwise_conv2d(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public pointwise_conv2d(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public pointwise_conv2d position(long position) { + return (pointwise_conv2d)super.position(position); + } - /** - * point-wise 2D convolution - * Expected input: - * x: 4D array - * weight: 4D Array [1, 1, iC, oC] (NHWC) or [oC, iC, 1, 1] (NCHW) - * bias: optional vector, length of oC - * - * IntArgs: - * 0: data format: 1 NHWC, 0 NCHW (optional, by default = NHWC) - */ - @Namespace("sd::ops") public static class pointwise_conv2d extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public pointwise_conv2d(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public pointwise_conv2d(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public pointwise_conv2d position(long position) { - return (pointwise_conv2d)super.position(position); - } - public pointwise_conv2d() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class deconv2d_tf extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public deconv2d_tf(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public deconv2d_tf(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public deconv2d_tf position(long position) { - return (deconv2d_tf)super.position(position); - } - +@Namespace("sd::ops") public static class deconv2d_tf extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public deconv2d_tf(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public deconv2d_tf(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public deconv2d_tf position(long position) { + return (deconv2d_tf)super.position(position); + } + public deconv2d_tf() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - - - + // namespace ops + // namespace sd // #endif @@ -15738,251 +18069,248 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_LIST_H // #include - // list operations, basically all around NDArrayList +// list operations, basically all around NDArrayList + +/** + * This operations puts given NDArray into (optionally) given NDArrayList. + * If no NDArrayList was provided - new one will be created + */ +// #if NOT_EXCLUDED(OP_write_list) +@Namespace("sd::ops") public static class write_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public write_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public write_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public write_list position(long position) { + return (write_list)super.position(position); + } - /** - * This operations puts given NDArray into (optionally) given NDArrayList. - * If no NDArrayList was provided - new one will be created - */ -// #if NOT_EXCLUDED(OP_write_list) - @Namespace("sd::ops") public static class write_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public write_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public write_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public write_list position(long position) { - return (write_list)super.position(position); - } - public write_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation concatenates given NDArrayList, and returns NDArray as result + */ +// #if NOT_EXCLUDED(OP_stack_list) +@Namespace("sd::ops") public static class stack_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public stack_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public stack_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public stack_list position(long position) { + return (stack_list)super.position(position); + } - /** - * This operation concatenates given NDArrayList, and returns NDArray as result - */ -// #if NOT_EXCLUDED(OP_stack_list) - @Namespace("sd::ops") public static class stack_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public stack_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public stack_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public stack_list position(long position) { - return (stack_list)super.position(position); - } - public stack_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operations selects specified index fron NDArrayList and returns it as + * NDArray Expected arguments: x: non-empty list indices: optional, scalar with + * index + * + * Int args: + * optional, index + */ +// #if NOT_EXCLUDED(OP_read_list) +@Namespace("sd::ops") public static class read_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public read_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public read_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public read_list position(long position) { + return (read_list)super.position(position); + } - /** - * This operations selects specified index fron NDArrayList and returns it as NDArray - * Expected arguments: - * x: non-empty list - * indices: optional, scalar with index - * - * Int args: - * optional, index - */ -// #if NOT_EXCLUDED(OP_read_list) - @Namespace("sd::ops") public static class read_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public read_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public read_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public read_list position(long position) { - return (read_list)super.position(position); - } - public read_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operations selects specified indices fron NDArrayList and returns them + * as NDArray Expected arguments: x: non-empty list indices: optional, vector + * with indices + * + * Int args: + * optional, indices + */ +// #if NOT_EXCLUDED(OP_pick_list) +@Namespace("sd::ops") public static class pick_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public pick_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public pick_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public pick_list position(long position) { + return (pick_list)super.position(position); + } - /** - * This operations selects specified indices fron NDArrayList and returns them as NDArray - * Expected arguments: - * x: non-empty list - * indices: optional, vector with indices - * - * Int args: - * optional, indices - */ -// #if NOT_EXCLUDED(OP_pick_list) - @Namespace("sd::ops") public static class pick_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public pick_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public pick_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public pick_list position(long position) { - return (pick_list)super.position(position); - } - public pick_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operations returns scalar, with number of existing arrays within given + * NDArrayList Expected arguments: x: list + */ +// #if NOT_EXCLUDED(OP_size_list) +@Namespace("sd::ops") public static class size_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public size_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public size_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public size_list position(long position) { + return (size_list)super.position(position); + } - /** - * This operations returns scalar, with number of existing arrays within given NDArrayList - * Expected arguments: - * x: list - */ -// #if NOT_EXCLUDED(OP_size_list) - @Namespace("sd::ops") public static class size_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public size_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public size_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public size_list position(long position) { - return (size_list)super.position(position); - } - public size_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation creates new empty NDArrayList + */ +// #if NOT_EXCLUDED(OP_create_list) +@Namespace("sd::ops") public static class create_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public create_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public create_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public create_list position(long position) { + return (create_list)super.position(position); + } - /** - * This operation creates new empty NDArrayList - */ -// #if NOT_EXCLUDED(OP_create_list) - @Namespace("sd::ops") public static class create_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public create_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public create_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public create_list position(long position) { - return (create_list)super.position(position); - } - public create_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation unpacks given NDArray into specified NDArrayList wrt specified + * indices + */ +// #if NOT_EXCLUDED(OP_scatter_list) +@Namespace("sd::ops") public static class scatter_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_list position(long position) { + return (scatter_list)super.position(position); + } - /** - * This operation unpacks given NDArray into specified NDArrayList wrt specified indices - */ -// #if NOT_EXCLUDED(OP_scatter_list) - @Namespace("sd::ops") public static class scatter_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_list position(long position) { - return (scatter_list)super.position(position); - } - public scatter_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation splits given NDArray into chunks, and stores them into given + * NDArrayList wert sizes Expected arguments: list: optional, NDArrayList. if + * not available - new NDArrayList will be created array: array to be split + * sizes: vector with sizes for each chunk + */ +// #if NOT_EXCLUDED(OP_split_list) +@Namespace("sd::ops") public static class split_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public split_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public split_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public split_list position(long position) { + return (split_list)super.position(position); + } - /** - * This operation splits given NDArray into chunks, and stores them into given NDArrayList wert sizes - * Expected arguments: - * list: optional, NDArrayList. if not available - new NDArrayList will be created - * array: array to be split - * sizes: vector with sizes for each chunk - */ -// #if NOT_EXCLUDED(OP_split_list) - @Namespace("sd::ops") public static class split_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public split_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public split_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public split_list position(long position) { - return (split_list)super.position(position); - } - public split_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation builds NDArray from NDArrayList using indices + * Expected arguments: + * x: non-empty list + * indices: vector with indices for gather operation + */ +// #if NOT_EXCLUDED(OP_gather_list) +@Namespace("sd::ops") public static class gather_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gather_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gather_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gather_list position(long position) { + return (gather_list)super.position(position); + } - /** - * This operation builds NDArray from NDArrayList using indices - * Expected arguments: - * x: non-empty list - * indices: vector with indices for gather operation - */ -// #if NOT_EXCLUDED(OP_gather_list) - @Namespace("sd::ops") public static class gather_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gather_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gather_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gather_list position(long position) { - return (gather_list)super.position(position); - } - public gather_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation clones given NDArrayList + */ +// #if NOT_EXCLUDED(OP_clone_list) +@Namespace("sd::ops") public static class clone_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public clone_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public clone_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public clone_list position(long position) { + return (clone_list)super.position(position); + } - /** - * This operation clones given NDArrayList - */ -// #if NOT_EXCLUDED(OP_clone_list) - @Namespace("sd::ops") public static class clone_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public clone_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public clone_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public clone_list position(long position) { - return (clone_list)super.position(position); - } - public clone_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation unstacks given NDArray into NDArrayList by the first dimension + */ +// #if NOT_EXCLUDED(OP_unstack_list) +@Namespace("sd::ops") public static class unstack_list extends DeclarableListOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unstack_list(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unstack_list(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unstack_list position(long position) { + return (unstack_list)super.position(position); + } - /** - * This operation unstacks given NDArray into NDArrayList by the first dimension - */ -// #if NOT_EXCLUDED(OP_unstack_list) - @Namespace("sd::ops") public static class unstack_list extends DeclarableListOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unstack_list(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unstack_list(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unstack_list position(long position) { - return (unstack_list)super.position(position); - } - public unstack_list() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -16013,714 +18341,742 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [3K x K] - * 2: row of biases with twice length [1 x 2K] - * 3: 2d tensor of previous cell state [bS x K] - * 4: optional, 2d tensor of dropout mask [bS x K] - * - * Output arrays: - * 0: 3d tensor of cell output [bS x K x N] - * 1: 3d tensor of cell state [bS x K x N] - */ -// #if NOT_EXCLUDED(OP_sru) - @Namespace("sd::ops") public static class sru extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sru(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sru(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sru position(long position) { - return (sru)super.position(position); - } - +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast + * as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - + * batch size, K - number of features 1: 2d tensor of weights [3K x K] 2: row of + * biases with twice length [1 x 2K] 3: 2d tensor of previous cell state [bS x + * K] 4: optional, 2d tensor of dropout mask [bS x K] + * + * Output arrays: + * 0: 3d tensor of cell output [bS x K x N] + * 1: 3d tensor of cell state [bS x K x N] + */ +// #if NOT_EXCLUDED(OP_sru) +@Namespace("sd::ops") public static class sru extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sru(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sru(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sru position(long position) { + return (sru)super.position(position); + } + public sru() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for Simple Recurrent Unit (bidirectional case): + * "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS + * - batch size, K - number of features 1: 2d tensor of weights [2K x 6K] 2: row + * of biases with twice length [1 x 4K] 3: 2d tensor of previous cell state [bS + * x 2K] 4: optional, 2d tensor of dropout mask [bS x 2K] + * + * Output arrays: + * 0: 3d tensor of cell output [N x bS x 2K] + * 1: 3d tensor of cell state [N x bS x 2K] + */ +// #if NOT_EXCLUDED(OP_sru_bi) +@Namespace("sd::ops") public static class sru_bi extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sru_bi(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sru_bi(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sru_bi position(long position) { + return (sru_bi)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for Simple Recurrent Unit (bidirectional case): "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [2K x 6K] - * 2: row of biases with twice length [1 x 4K] - * 3: 2d tensor of previous cell state [bS x 2K] - * 4: optional, 2d tensor of dropout mask [bS x 2K] - * - * Output arrays: - * 0: 3d tensor of cell output [N x bS x 2K] - * 1: 3d tensor of cell state [N x bS x 2K] - */ -// #if NOT_EXCLUDED(OP_sru_bi) - @Namespace("sd::ops") public static class sru_bi extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sru_bi(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sru_bi(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sru_bi position(long position) { - return (sru_bi)super.position(position); - } - public sru_bi() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for back propagation in Simple Recurrent Unit: + * "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - + * batch size, K - number of features 1: 2d tensor of weights [3K x K] 2: row of + * biases with twice length [1 x 2K] 3: 2d tensor of previous cell state [bS x + * K] 4: 3d tensor of cell state [bS x K x N] 5: 2d tensor of cell state + * gradients [bS x K] 6: 3d tensor of state output gradients [bS x K x N] 7: + * optional, 2d tensor of dropout mask [bS x K] + * + * Output arrays: + * 0: 3d tensor of input gradients [bS x K x N] + * 1: 3d tensor of weights gradients [bS x 3K x K] + * 2: 2d, row of biases gradients [1 x 2K] + * 3: 2d, tensor of state gradients [bS x K] + */ +// #if NOT_EXCLUDED(OP_sru) +@Namespace("sd::ops") public static class sru_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sru_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sru_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sru_bp position(long position) { + return (sru_bp)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for back propagation in Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [3K x K] - * 2: row of biases with twice length [1 x 2K] - * 3: 2d tensor of previous cell state [bS x K] - * 4: 3d tensor of cell state [bS x K x N] - * 5: 2d tensor of cell state gradients [bS x K] - * 6: 3d tensor of state output gradients [bS x K x N] - * 7: optional, 2d tensor of dropout mask [bS x K] - * - * Output arrays: - * 0: 3d tensor of input gradients [bS x K x N] - * 1: 3d tensor of weights gradients [bS x 3K x K] - * 2: 2d, row of biases gradients [1 x 2K] - * 3: 2d, tensor of state gradients [bS x K] - */ -// #if NOT_EXCLUDED(OP_sru) - @Namespace("sd::ops") public static class sru_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sru_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sru_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sru_bp position(long position) { - return (sru_bp)super.position(position); - } - public sru_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for back propagation in Simple Recurrent Unit + * (bidirectional case): "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav + * Artzi + * + * Input arrays: + * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS + * - batch size, K - number of features 1: 2d tensor of weights [2K x 6K] 2: row + * of biases with twice length [1 x 4K] 3: 2d tensor of previous cell state [bS + * x 2K] 4: 3d tensor of cell state [N x bS x 2K] 5: 2d tensor of cell state + * gradients [bS x 2K] 6: 3d tensor of state output gradients [N x bS x 2K] 7: + * optional, 2d tensor of dropout mask [bS x 2K] + * + * Output arrays: + * 0: 3d tensor of input gradients [N x bS x 2K] + * 1: 3d tensor of weights gradients [N x 2K x 6K] + * 2: 2d, row of biases gradients [1 x 4K] + * 3: 2d, tensor of state gradients [bS x 2K] + */ +// #if NOT_EXCLUDED(OP_sru_bi) +@Namespace("sd::ops") public static class sru_bi_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sru_bi_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sru_bi_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sru_bi_bp position(long position) { + return (sru_bi_bp)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for back propagation in Simple Recurrent Unit (bidirectional case): "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input 3d tensor with shape [N x bS x 2K], N - number of time steps, bS - batch size, K - number of features - * 1: 2d tensor of weights [2K x 6K] - * 2: row of biases with twice length [1 x 4K] - * 3: 2d tensor of previous cell state [bS x 2K] - * 4: 3d tensor of cell state [N x bS x 2K] - * 5: 2d tensor of cell state gradients [bS x 2K] - * 6: 3d tensor of state output gradients [N x bS x 2K] - * 7: optional, 2d tensor of dropout mask [bS x 2K] - * - * Output arrays: - * 0: 3d tensor of input gradients [N x bS x 2K] - * 1: 3d tensor of weights gradients [N x 2K x 6K] - * 2: 2d, row of biases gradients [1 x 4K] - * 3: 2d, tensor of state gradients [bS x 2K] - */ -// #if NOT_EXCLUDED(OP_sru_bi) - @Namespace("sd::ops") public static class sru_bi_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sru_bi_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sru_bi_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sru_bi_bp position(long position) { - return (sru_bi_bp)super.position(position); - } - public sru_bi_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for LSTM cell with peep hole connections: + * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural + * Computation and https://research.google.com/pubs/archive/43905.pdf Hasim Sak, + * Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent + * neural network architectures for large scale acoustic modeling." INTERSPEECH, + * 2014. + * + * Input arrays: + * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - + * number of features 1: previous cell output [batchSize x numProj], that is at + * previous time step t-1, in case of projection=false -> numProj=numUnits!!! 2: + * previous cell state [batchSize x numUnits], that is at previous time step + * t-1 3: input-to-hidden weights, [inSize x 4*numUnits] 4: hidden-to-hidden + * weights, [numProj x 4*numUnits] 5: diagonal weights for peephole connections + * [3*numUnits] 6: projection weights [numUnits x numProj] 7: biases, + * [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * 1: if not zero, then projection is performed, if zero then + * numProj==numUnits is mandatory! + * + * Input float arguments: + * 0: clipping value for cell state, if it is not equal to zero, then cell + * state is clipped 1: clipping value for projected cell output, if it is not + * equal to zero, then projected cell output is clipped 2: the bias added to + * forget gates in order to reduce the scale of forgetting in the beginning of + * the training + * + * Output arrays: + * 0: current cell output [batchSize x numProj], that is at current time step + * t 1: current cell state [batchSize x numUnits], that is at current time step + * t + */ +// #if NOT_EXCLUDED(OP_lstmCell) +@Namespace("sd::ops") public static class lstmCell extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmCell(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmCell(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmCell position(long position) { + return (lstmCell)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for LSTM cell with peep hole connections: - * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation - * and - * https://research.google.com/pubs/archive/43905.pdf - * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. - * - * Input arrays: - * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features - * 1: previous cell output [batchSize x numProj], that is at previous time step t-1, in case of projection=false -> numProj=numUnits!!! - * 2: previous cell state [batchSize x numUnits], that is at previous time step t-1 - * 3: input-to-hidden weights, [inSize x 4*numUnits] - * 4: hidden-to-hidden weights, [numProj x 4*numUnits] - * 5: diagonal weights for peephole connections [3*numUnits] - * 6: projection weights [numUnits x numProj] - * 7: biases, [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * 1: if not zero, then projection is performed, if zero then numProj==numUnits is mandatory! - * - * Input float arguments: - * 0: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * 1: clipping value for projected cell output, if it is not equal to zero, then projected cell output is clipped - * 2: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * - * Output arrays: - * 0: current cell output [batchSize x numProj], that is at current time step t - * 1: current cell state [batchSize x numUnits], that is at current time step t - */ -// #if NOT_EXCLUDED(OP_lstmCell) - @Namespace("sd::ops") public static class lstmCell extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstmCell(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstmCell(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstmCell position(long position) { - return (lstmCell)super.position(position); - } - public lstmCell() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_lstmLayerCell) +@Namespace("sd::ops") public static class lstmLayerCell extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayerCell(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayerCell(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayerCell position(long position) { + return (lstmLayerCell)super.position(position); + } -// #if NOT_EXCLUDED(OP_lstmLayerCell) - @Namespace("sd::ops") public static class lstmLayerCell extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstmLayerCell(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstmLayerCell(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstmLayerCell position(long position) { - return (lstmLayerCell)super.position(position); - } - public lstmLayerCell() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_lstmLayerCell) - @Namespace("sd::ops") public static class lstmLayerCellBp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstmLayerCellBp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstmLayerCellBp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstmLayerCellBp position(long position) { - return (lstmLayerCellBp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_lstmLayerCell) +@Namespace("sd::ops") public static class lstmLayerCellBp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayerCellBp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayerCellBp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayerCellBp position(long position) { + return (lstmLayerCellBp)super.position(position); + } + public lstmLayerCellBp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for LSTM cell with optional peep hole + * connections: S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". + * Neural Computation and https://research.google.com/pubs/archive/43905.pdf + * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory + * recurrent neural network architectures for large scale acoustic modeling." + * INTERSPEECH, 2014. See also: https://arxiv.org/pdf/1503.04069.pdf + * + * Input arrays: + * 0: input [bS, inSize] at time t + * 1: previous cell state [bS, numUnits], time t-1 + * 2: previous output [bS, numUnits], time t-1 + * 3: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) + * weights, [(inSize+numUnits), 4*numUnits] 4: weights - cell peephole (t-1) + * connections to input modulation gate, [numUnits] 5: weights - cell peephole + * (t-1) connections to forget gate, [numUnits] 6: weights - cell peephole (t) + * connections to output gate, [numUnits] 7: biases, shape [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * + * Input float arguments: + * 0: the bias added to forget gates in order to reduce the scale of + * forgetting in the beginning of the training 1: clipping value for cell state, + * if it is not equal to zero, then cell state is clipped + * + * Output arrays: + * 0: i - Input modulation gate activations [bS, numUnits] + * 1: c (cs) - Cell state (pre tanh) [bs, numUnits] (cs) + * 2: f - Output - forget gate activations [bs, numUnits] + * 3: o - Output - output gate activations [bs, numUnits] + * 4: z (ci) - Output - block input [bs, numUnits] + * 5: h (co) - Cell state, post tanh [bs, numUnits] + * 6: y (h) - Current cell output [bS, numUnits], time t + */ +// #if NOT_EXCLUDED(OP_lstmBlockCell) +@Namespace("sd::ops") public static class lstmBlockCell extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmBlockCell(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmBlockCell(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmBlockCell position(long position) { + return (lstmBlockCell)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for LSTM cell with optional peep hole connections: - * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation - * and - * https://research.google.com/pubs/archive/43905.pdf - * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. - * See also: https://arxiv.org/pdf/1503.04069.pdf - * - * Input arrays: - * 0: input [bS, inSize] at time t - * 1: previous cell state [bS, numUnits], time t-1 - * 2: previous output [bS, numUnits], time t-1 - * 3: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits] - * 4: weights - cell peephole (t-1) connections to input modulation gate, [numUnits] - * 5: weights - cell peephole (t-1) connections to forget gate, [numUnits] - * 6: weights - cell peephole (t) connections to output gate, [numUnits] - * 7: biases, shape [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * - * Input float arguments: - * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * - * Output arrays: - * 0: i - Input modulation gate activations [bS, numUnits] - * 1: c (cs) - Cell state (pre tanh) [bs, numUnits] (cs) - * 2: f - Output - forget gate activations [bs, numUnits] - * 3: o - Output - output gate activations [bs, numUnits] - * 4: z (ci) - Output - block input [bs, numUnits] - * 5: h (co) - Cell state, post tanh [bs, numUnits] - * 6: y (h) - Current cell output [bS, numUnits], time t - */ -// #if NOT_EXCLUDED(OP_lstmBlockCell) - @Namespace("sd::ops") public static class lstmBlockCell extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstmBlockCell(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstmBlockCell(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstmBlockCell position(long position) { - return (lstmBlockCell)super.position(position); - } - public lstmBlockCell() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation for LSTM layer with optional peep hole + * connections. See lstmBlockCell for details. lstmBlockCell is used internally + * for computation. This method expects as input (and returns as output) + * sequences in one of 3 formats, depending on the data format arg: dataFormat = + * 0 -> TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to + * as "time major" dataFormat = 1 -> NST: shape [numExamples, inOutSize, + * timeLength] dataFormat = 2 -> NTS: shape [numExamples, timeLength, inOutSize] + * - TF "time_major=false" layout + * + * + * Input arrays: + * 0: max sequence length; long/int64 scalar + * 1: input [seqLength, bS, inSize] at time t + * 2: previous/initial cell state [bS, numUnits] + * 3: previous/initial output [bS, numUnits] + * 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) + * weights, [(inSize+numUnits), 4*numUnits] 5: weights - cell peephole (t-1) + * connections to input modulation gate, [numUnits] 6: weights - cell peephole + * (t-1) connections to forget gate, [numUnits] 7: weights - cell peephole (t) + * connections to output gate, [numUnits] 8: biases, Shape [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; + * 2=NTS=[mb,seqLen,size] + * + * Input float arguments: + * 0: the bias added to forget gates in order to reduce the scale of + * forgetting in the beginning of the training 1: clipping value for cell state, + * if it is not equal to zero, then cell state is clipped + * + * Output arrays: + * 0: i - Input modulation gate activations, rank 3, shape as per + * dataFormat 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat + * 2: f - Output - forget gate activations, rank 3, shape as per + * dataFormat 3: o - Output - output gate activations, rank 3, shape as per + * dataFormat 4: z (ci) - Output - block input, rank 3, shape as per dataFormat + * 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat + * 6: y (h) - Current cell output, rank 3, shape as per dataFormat + */ +// #if NOT_EXCLUDED(OP_lstmBlock) +@Namespace("sd::ops") public static class lstmBlock extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmBlock(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmBlock(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmBlock position(long position) { + return (lstmBlock)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation for LSTM layer with optional peep hole connections. - * See lstmBlockCell for details. lstmBlockCell is used internally for computation. - * This method expects as input (and returns as output) sequences in one of 3 formats, depending on the data format arg: - * dataFormat = 0 -> TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major" - * dataFormat = 1 -> NST: shape [numExamples, inOutSize, timeLength] - * dataFormat = 2 -> NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout - * - * - * Input arrays: - * 0: max sequence length; long/int64 scalar - * 1: input [seqLength, bS, inSize] at time t - * 2: previous/initial cell state [bS, numUnits] - * 3: previous/initial output [bS, numUnits] - * 4: Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits] - * 5: weights - cell peephole (t-1) connections to input modulation gate, [numUnits] - * 6: weights - cell peephole (t-1) connections to forget gate, [numUnits] - * 7: weights - cell peephole (t) connections to output gate, [numUnits] - * 8: biases, Shape [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * 1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; 2=NTS=[mb,seqLen,size] - * - * Input float arguments: - * 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * - * Output arrays: - * 0: i - Input modulation gate activations, rank 3, shape as per dataFormat - * 1: c (cs) - Cell state (pre tanh), rank 3, shape as per dataFormat - * 2: f - Output - forget gate activations, rank 3, shape as per dataFormat - * 3: o - Output - output gate activations, rank 3, shape as per dataFormat - * 4: z (ci) - Output - block input, rank 3, shape as per dataFormat - * 5: h (co) - Cell state, post tanh, rank 3, shape as per dataFormat - * 6: y (h) - Current cell output, rank 3, shape as per dataFormat - */ -// #if NOT_EXCLUDED(OP_lstmBlock) - @Namespace("sd::ops") public static class lstmBlock extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstmBlock(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstmBlock(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstmBlock position(long position) { - return (lstmBlock)super.position(position); - } - public lstmBlock() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +// #if NOT_EXCLUDED(OP_lstmLayer) +@Namespace("sd::ops") public static class lstmLayer extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayer(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayer(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayer position(long position) { + return (lstmLayer)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// -// #if NOT_EXCLUDED(OP_lstmLayer) - @Namespace("sd::ops") public static class lstmLayer extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstmLayer(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstmLayer(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstmLayer position(long position) { - return (lstmLayer)super.position(position); - } - public lstmLayer() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +// #if NOT_EXCLUDED(OP_lstmLayer) +@Namespace("sd::ops") public static class lstmLayer_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayer_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayer_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayer_bp position(long position) { + return (lstmLayer_bp)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// -// #if NOT_EXCLUDED(OP_lstmLayer) - @Namespace("sd::ops") public static class lstmLayer_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstmLayer_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstmLayer_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstmLayer_bp position(long position) { - return (lstmLayer_bp)super.position(position); - } - public lstmLayer_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operations for Simple Recurrent Unit cell: "Training RNNs + * as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi + * + * Input arrays: + * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - + * number of features 1: previous cell state [batchSize x inSize], that is at + * previous time step t-1 2: weights [inSize x 3*inSize] 3: biases [1 x + * 2*inSize] + * + * Output arrays: + * 0: current cell output [batchSize x inSize], that is at current time step + * t 1: current cell state [batchSize x inSize], that is at current time step t + */ +// #if NOT_EXCLUDED(OP_sruCell) +@Namespace("sd::ops") public static class sruCell extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sruCell(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sruCell(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sruCell position(long position) { + return (sruCell)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operations for Simple Recurrent Unit cell: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi - * - * Input arrays: - * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features - * 1: previous cell state [batchSize x inSize], that is at previous time step t-1 - * 2: weights [inSize x 3*inSize] - * 3: biases [1 x 2*inSize] - * - * Output arrays: - * 0: current cell output [batchSize x inSize], that is at current time step t - * 1: current cell state [batchSize x inSize], that is at current time step t - */ -// #if NOT_EXCLUDED(OP_sruCell) - @Namespace("sd::ops") public static class sruCell extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sruCell(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sruCell(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sruCell position(long position) { - return (sruCell)super.position(position); - } - public sruCell() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of gated Recurrent Unit cell: + * Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, + * Fethi Bougares, Holger Schwenk, Yoshua Bengio "Learning Phrase + * Representations using RNN Encoder-Decoder for Statistical Machine + * Translation" + * + * Input arrays: + * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - + * number of features 1: previous cell output [batchSize x numUnits], that is + * at previous time step t-1 2: RU weights - [(inSize+numUnits), 2*numUnits] - + * reset and update gates (input/recurrent weights) 3: C weights - + * [(inSize+numUnits), numUnits] - cell gate (input/recurrent weights) 4: reset + * and update biases, [2*numUnits] - reset and update gates 5: cell biases, + * [numUnits] + * + * Output arrays: + * 0: Reset gate output [bS, numUnits] + * 1: Update gate output [bS, numUnits] + * 2: Cell gate output [bS, numUnits] + * 3: Current cell output [bS, numUnits] + */ +// #if NOT_EXCLUDED(OP_gruCell) +@Namespace("sd::ops") public static class gruCell extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gruCell(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gruCell(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gruCell position(long position) { + return (gruCell)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of gated Recurrent Unit cell: - * Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio - * "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" - * - * Input arrays: - * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features - * 1: previous cell output [batchSize x numUnits], that is at previous time step t-1 - * 2: RU weights - [(inSize+numUnits), 2*numUnits] - reset and update gates (input/recurrent weights) - * 3: C weights - [(inSize+numUnits), numUnits] - cell gate (input/recurrent weights) - * 4: reset and update biases, [2*numUnits] - reset and update gates - * 5: cell biases, [numUnits] - * - * Output arrays: - * 0: Reset gate output [bS, numUnits] - * 1: Update gate output [bS, numUnits] - * 2: Cell gate output [bS, numUnits] - * 3: Current cell output [bS, numUnits] - */ -// #if NOT_EXCLUDED(OP_gruCell) - @Namespace("sd::ops") public static class gruCell extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gruCell(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gruCell(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gruCell position(long position) { - return (gruCell)super.position(position); - } - public gruCell() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_gruCell) +@Namespace("sd::ops") public static class gruCell_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gruCell_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gruCell_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gruCell_bp position(long position) { + return (gruCell_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_gruCell) - @Namespace("sd::ops") public static class gruCell_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gruCell_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gruCell_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gruCell_bp position(long position) { - return (gruCell_bp)super.position(position); - } - public gruCell_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "LSTM time sequences" with peep hole connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: initial cell + * output [batchSize x numProj], that is at time step = 0, in case of + * projection=false -> numProj=numUnits!!! 2: initial cell state [batchSize x + * numUnits], that is at time step = 0 3: input-to-hidden weights, [inSize x + * 4*numUnits] 4: hidden-to-hidden weights, [numProj x 4*numUnits] 5: diagonal + * weights for peephole connections [3*numUnits] 6: projection weights [numUnits + * x numProj] 7: biases, [4*numUnits] + * + * Input integer arguments: + * 0: if not zero, provide peephole connections + * 1: if not zero, then projection is performed, if zero then + * numProj==numUnits is mandatory! + * + * Input float arguments: + * 0: clipping value for cell state, if it is not equal to zero, then cell + * state is clipped 1: clipping value for projected cell output, if it is not + * equal to zero, then projected cell output is clipped 2: the bias added to + * forget gates in order to reduce the scale of forgetting in the beginning of + * the training + * + * Output arrays: + * 0: cell outputs [time x batchSize x numProj], that is per each time step + * 1: cell states [time x batchSize x numUnits], that is per each time step + */ +// #if NOT_EXCLUDED(OP_lstm) +@Namespace("sd::ops") public static class lstm extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstm position(long position) { + return (lstm)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "LSTM time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: initial cell output [batchSize x numProj], that is at time step = 0, in case of projection=false -> numProj=numUnits!!! - * 2: initial cell state [batchSize x numUnits], that is at time step = 0 - * 3: input-to-hidden weights, [inSize x 4*numUnits] - * 4: hidden-to-hidden weights, [numProj x 4*numUnits] - * 5: diagonal weights for peephole connections [3*numUnits] - * 6: projection weights [numUnits x numProj] - * 7: biases, [4*numUnits] - * - * Input integer arguments: - * 0: if not zero, provide peephole connections - * 1: if not zero, then projection is performed, if zero then numProj==numUnits is mandatory! - * - * Input float arguments: - * 0: clipping value for cell state, if it is not equal to zero, then cell state is clipped - * 1: clipping value for projected cell output, if it is not equal to zero, then projected cell output is clipped - * 2: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training - * - * Output arrays: - * 0: cell outputs [time x batchSize x numProj], that is per each time step - * 1: cell states [time x batchSize x numUnits], that is per each time step - */ -// #if NOT_EXCLUDED(OP_lstm) - @Namespace("sd::ops") public static class lstm extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstm position(long position) { - return (lstm)super.position(position); - } - public lstm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of gated Recurrent Unit: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: initial cell + * output [batchSize x numUnits], that is at time step = 0 2: input-to-hidden + * weights, [inSize x 3*numUnits] 3: hidden-to-hidden weights, [numUnits x + * 3*numUnits] 4: biases, [3*numUnits] + * + * Output arrays: + * 0: cell outputs [time x batchSize x numUnits], that is per each time step + */ +// #if NOT_EXCLUDED(OP_gru) +@Namespace("sd::ops") public static class gru extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gru(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gru(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gru position(long position) { + return (gru)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of gated Recurrent Unit: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: initial cell output [batchSize x numUnits], that is at time step = 0 - * 2: input-to-hidden weights, [inSize x 3*numUnits] - * 3: hidden-to-hidden weights, [numUnits x 3*numUnits] - * 4: biases, [3*numUnits] - * - * Output arrays: - * 0: cell outputs [time x batchSize x numUnits], that is per each time step - */ -// #if NOT_EXCLUDED(OP_gru) - @Namespace("sd::ops") public static class gru extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gru(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gru(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gru position(long position) { - return (gru)super.position(position); - } - public gru() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_gru) +@Namespace("sd::ops") public static class gru_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gru_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gru_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gru_bp position(long position) { + return (gru_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_gru) - @Namespace("sd::ops") public static class gru_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gru_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gru_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gru_bp position(long position) { - return (gru_bp)super.position(position); - } - public gru_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: input-to-hidden + * weights, [inSize x numUnits] 2: hidden-to-hidden weights, [numUnits x + * numUnits] 3: biases, [2*numUnits] 4: (optional) initial cell output + * [batchSize x numUnits], that is at time step = 0 5: (optional) vector with + * shape [batchSize] containing integer values within [0,time), each element of + * this vector set max time step per each input in batch, this provides no + * calculations for time >= maxTimeStep + * + * Output arrays: + * 0: cell outputs [time x batchSize x numUnits] + * 1: cell final non-zero output [batchSize x numUnits] + */ +@Namespace("sd::ops") public static class static_rnn extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public static_rnn(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public static_rnn(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public static_rnn position(long position) { + return (static_rnn)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights, [inSize x numUnits] - * 2: hidden-to-hidden weights, [numUnits x numUnits] - * 3: biases, [2*numUnits] - * 4: (optional) initial cell output [batchSize x numUnits], that is at time step = 0 - * 5: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Output arrays: - * 0: cell outputs [time x batchSize x numUnits] - * 1: cell final non-zero output [batchSize x numUnits] - */ - @Namespace("sd::ops") public static class static_rnn extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public static_rnn(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public static_rnn(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public static_rnn position(long position) { - return (static_rnn)super.position(position); - } - public static_rnn() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize] or [batchSize x time x numUnits], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights, [inSize x numUnits] - * 2: hidden-to-hidden weights, [numUnits x numUnits] - * 3: biases, [2*numUnits] - * 4: (optional) initial cell output [batchSize x numUnits], that is at time step = 0 - * 5: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Input integer arguments: - * 0: (optional) timeMajor - if non zero then input shape is [time, batchSize, ...], else [batchSize, time, ...] - * - * Output arrays: - * 0: cell outputs [time x batchSize x numUnits] or [batchSize x time x numUnits] - * 1: cell final non-zero output [batchSize x numUnits] - */ - @Namespace("sd::ops") public static class dynamic_rnn extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dynamic_rnn(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dynamic_rnn(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dynamic_rnn position(long position) { - return (dynamic_rnn)super.position(position); - } - +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize] or [batchSize x time x + * numUnits], time - number of time steps, batchSize - batch size, inSize - + * number of features 1: input-to-hidden weights, [inSize x numUnits] 2: + * hidden-to-hidden weights, [numUnits x numUnits] 3: biases, [2*numUnits] 4: + * (optional) initial cell output [batchSize x numUnits], that is at time step = + * 0 5: (optional) vector with shape [batchSize] containing integer values + * within [0,time), each element of this vector set max time step per each input + * in batch, this provides no calculations for time >= maxTimeStep + * + * Input integer arguments: + * 0: (optional) timeMajor - if non zero then input shape is [time, + * batchSize, ...], else [batchSize, time, ...] + * + * Output arrays: + * 0: cell outputs [time x batchSize x numUnits] or [batchSize x time x + * numUnits] 1: cell final non-zero output [batchSize x numUnits] + */ +@Namespace("sd::ops") public static class dynamic_rnn extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dynamic_rnn(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dynamic_rnn(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dynamic_rnn position(long position) { + return (dynamic_rnn)super.position(position); + } + public dynamic_rnn() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - * 2: hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - * 3: biases for forward RNN, [2*numUnitsFW] - * 4: input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - * 5: hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - * 6: biases for backward RNN, [2*numUnitsBW] - * 7: (optional) initial cell output for forward RNN [batchSize x numUnitsFW], that is at time step = 0 - * 8: (optional) initial cell output for backward RNN [batchSize x numUnitsBW], that is at time step = 0 - * 9: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Output arrays: - * 0: cell outputs [time x batchSize x (numUnitsFW + numUnitsBW)] - * 1: cell final non-zero output for forward RNN [batchSize x numUnitsFW] - * 2: cell final non-zero output for backward RNN [batchSize x numUnitsBW] - */ - @Namespace("sd::ops") public static class static_bidirectional_rnn extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public static_bidirectional_rnn(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public static_bidirectional_rnn(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public static_bidirectional_rnn position(long position) { - return (static_bidirectional_rnn)super.position(position); - } - - public static_bidirectional_rnn() { super((Pointer)null); allocate(); } - private native void allocate(); +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize], time - number of time + * steps, batchSize - batch size, inSize - number of features 1: input-to-hidden + * weights for forward RNN, [inSize x numUnitsFW] 2: hidden-to-hidden weights + * for forward RNN, [numUnitsFW x numUnitsFW] 3: biases for forward RNN, + * [2*numUnitsFW] 4: input-to-hidden weights for backward RNN, [inSize x + * numUnitsBW] 5: hidden-to-hidden weights for backward RNN, [numUnitsBW x + * numUnitsBW] 6: biases for backward RNN, [2*numUnitsBW] 7: (optional) initial + * cell output for forward RNN [batchSize x numUnitsFW], that is at time step = + * 0 8: (optional) initial cell output for backward RNN [batchSize x + * numUnitsBW], that is at time step = 0 9: (optional) vector with shape + * [batchSize] containing integer values within [0,time), each element of this + * vector set max time step per each input in batch, this provides no + * calculations for time >= maxTimeStep + * + * Output arrays: + * 0: cell outputs [time x batchSize x (numUnitsFW + numUnitsBW)] + * 1: cell final non-zero output for forward RNN [batchSize x numUnitsFW] + * 2: cell final non-zero output for backward RNN [batchSize x numUnitsBW] + */ +@Namespace("sd::ops") public static class static_bidirectional_rnn extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public static_bidirectional_rnn(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public static_bidirectional_rnn(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public static_bidirectional_rnn position(long position) { + return (static_bidirectional_rnn)super.position(position); + } + + public static_bidirectional_rnn() { super((Pointer)null); allocate(); } + private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of operation "static RNN time sequences" with peep hole connections: - * - * Input arrays: - * 0: input with shape [time x batchSize x inSize] or [batchSize x time x inSize], time - number of time steps, batchSize - batch size, inSize - number of features - * 1: input-to-hidden weights for forward RNN, [inSize x numUnitsFW] - * 2: hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW] - * 3: biases for forward RNN, [2*numUnitsFW] - * 4: input-to-hidden weights for backward RNN, [inSize x numUnitsBW] - * 5: hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW] - * 6: biases for backward RNN, [2*numUnitsBW] - * 7: (optional) initial cell output for forward RNN [batchSize x numUnitsFW], that is at time step = 0 - * 8: (optional) initial cell output for backward RNN [batchSize x numUnitsBW], that is at time step = 0 - * 9: (optional) vector with shape [batchSize] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this provides no calculations for time >= maxTimeStep - * - * Input integer arguments: - * 0: (optional) timeMajor - if non zero then input shape is [time, batchSize, ...], else [batchSize, time, ...] - * - * Output arrays: - * 0: cell outputs for forward RNN [time x batchSize x numUnitsFW] or [batchSize x time x numUnitsFW] - * 1: cell outputs for backward RNN [time x batchSize x numUnitsBW] or [batchSize x time x numUnitsBW] - * 2: cell final non-zero output for forward RNN [batchSize x numUnitsFW] - * 3: cell final non-zero output for backward RNN [batchSize x numUnitsBW] - */ - @Namespace("sd::ops") public static class dynamic_bidirectional_rnn extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dynamic_bidirectional_rnn(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dynamic_bidirectional_rnn(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dynamic_bidirectional_rnn position(long position) { - return (dynamic_bidirectional_rnn)super.position(position); - } - +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of operation "static RNN time sequences" with peep hole + * connections: + * + * Input arrays: + * 0: input with shape [time x batchSize x inSize] or [batchSize x time x + * inSize], time - number of time steps, batchSize - batch size, inSize - number + * of features 1: input-to-hidden weights for forward RNN, [inSize x + * numUnitsFW] 2: hidden-to-hidden weights for forward RNN, [numUnitsFW x + * numUnitsFW] 3: biases for forward RNN, [2*numUnitsFW] 4: input-to-hidden + * weights for backward RNN, [inSize x numUnitsBW] 5: hidden-to-hidden weights + * for backward RNN, [numUnitsBW x numUnitsBW] 6: biases for backward RNN, + * [2*numUnitsBW] 7: (optional) initial cell output for forward RNN [batchSize x + * numUnitsFW], that is at time step = 0 8: (optional) initial cell output for + * backward RNN [batchSize x numUnitsBW], that is at time step = 0 9: (optional) + * vector with shape [batchSize] containing integer values within [0,time), each + * element of this vector set max time step per each input in batch, this + * provides no calculations for time >= maxTimeStep + * + * Input integer arguments: + * 0: (optional) timeMajor - if non zero then input shape is [time, + * batchSize, ...], else [batchSize, time, ...] + * + * Output arrays: + * 0: cell outputs for forward RNN [time x batchSize x numUnitsFW] or + * [batchSize x time x numUnitsFW] 1: cell outputs for backward RNN [time x + * batchSize x numUnitsBW] or [batchSize x time x numUnitsBW] 2: cell final + * non-zero output for forward RNN [batchSize x numUnitsFW] 3: cell final + * non-zero output for backward RNN [batchSize x numUnitsBW] + */ +@Namespace("sd::ops") public static class dynamic_bidirectional_rnn extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dynamic_bidirectional_rnn(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dynamic_bidirectional_rnn(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dynamic_bidirectional_rnn position(long position) { + return (dynamic_bidirectional_rnn)super.position(position); + } + public dynamic_bidirectional_rnn() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - - + // namespace ops + // namespace sd // #endif // Parsed from ops/declarable/headers/transforms.h @@ -16749,839 +19105,840 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_TRANSFORMS_H // #include -// #if NOT_EXCLUDED(OP_clipbyvalue) - @Namespace("sd::ops") public static class clipbyvalue extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public clipbyvalue(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public clipbyvalue(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public clipbyvalue position(long position) { - return (clipbyvalue)super.position(position); - } - +// #if NOT_EXCLUDED(OP_clipbyvalue) +@Namespace("sd::ops") public static class clipbyvalue extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public clipbyvalue(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public clipbyvalue(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public clipbyvalue position(long position) { + return (clipbyvalue)super.position(position); + } + public clipbyvalue() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_clipbynorm) +@Namespace("sd::ops") public static class clipbynorm extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public clipbynorm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public clipbynorm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public clipbynorm position(long position) { + return (clipbynorm)super.position(position); + } -// #if NOT_EXCLUDED(OP_clipbynorm) - @Namespace("sd::ops") public static class clipbynorm extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public clipbynorm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public clipbynorm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public clipbynorm position(long position) { - return (clipbynorm)super.position(position); - } - public clipbynorm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class clipbynorm_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public clipbynorm_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public clipbynorm_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public clipbynorm_bp position(long position) { - return (clipbynorm_bp)super.position(position); - } - +@Namespace("sd::ops") public static class clipbynorm_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public clipbynorm_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public clipbynorm_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public clipbynorm_bp position(long position) { + return (clipbynorm_bp)super.position(position); + } + public clipbynorm_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_clipbyavgnorm) +@Namespace("sd::ops") public static class clipbyavgnorm extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public clipbyavgnorm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public clipbyavgnorm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public clipbyavgnorm position(long position) { + return (clipbyavgnorm)super.position(position); + } -// #if NOT_EXCLUDED(OP_clipbyavgnorm) - @Namespace("sd::ops") public static class clipbyavgnorm extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public clipbyavgnorm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public clipbyavgnorm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public clipbyavgnorm position(long position) { - return (clipbyavgnorm)super.position(position); - } - public clipbyavgnorm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class clipbyavgnorm_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public clipbyavgnorm_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public clipbyavgnorm_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public clipbyavgnorm_bp position(long position) { - return (clipbyavgnorm_bp)super.position(position); - } - +@Namespace("sd::ops") public static class clipbyavgnorm_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public clipbyavgnorm_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public clipbyavgnorm_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public clipbyavgnorm_bp position(long position) { + return (clipbyavgnorm_bp)super.position(position); + } + public clipbyavgnorm_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_cumsum) +@Namespace("sd::ops") public static class cumsum extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cumsum(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cumsum(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cumsum position(long position) { + return (cumsum)super.position(position); + } -// #if NOT_EXCLUDED(OP_cumsum) - @Namespace("sd::ops") public static class cumsum extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cumsum(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cumsum(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cumsum position(long position) { - return (cumsum)super.position(position); - } - public cumsum() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_cumprod) +@Namespace("sd::ops") public static class cumprod extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cumprod(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cumprod(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cumprod position(long position) { + return (cumprod)super.position(position); + } -// #if NOT_EXCLUDED(OP_cumprod) - @Namespace("sd::ops") public static class cumprod extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cumprod(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cumprod(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cumprod position(long position) { - return (cumprod)super.position(position); - } - public cumprod() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_tile) +@Namespace("sd::ops") public static class tile extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tile(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tile(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tile position(long position) { + return (tile)super.position(position); + } -// #if NOT_EXCLUDED(OP_tile) - @Namespace("sd::ops") public static class tile extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tile(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tile(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tile position(long position) { - return (tile)super.position(position); - } - public tile() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class tile_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tile_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tile_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tile_bp position(long position) { - return (tile_bp)super.position(position); - } - +@Namespace("sd::ops") public static class tile_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tile_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tile_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tile_bp position(long position) { + return (tile_bp)super.position(position); + } + public tile_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_repeat) +@Namespace("sd::ops") public static class repeat extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public repeat(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public repeat(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public repeat position(long position) { + return (repeat)super.position(position); + } -// #if NOT_EXCLUDED(OP_repeat) - @Namespace("sd::ops") public static class repeat extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public repeat(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public repeat(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public repeat position(long position) { - return (repeat)super.position(position); - } - public repeat() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_invert_permutation) +@Namespace("sd::ops") public static class invert_permutation extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public invert_permutation(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public invert_permutation(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public invert_permutation position(long position) { + return (invert_permutation)super.position(position); + } -// #if NOT_EXCLUDED(OP_invert_permutation) - @Namespace("sd::ops") public static class invert_permutation extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public invert_permutation(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public invert_permutation(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public invert_permutation position(long position) { - return (invert_permutation)super.position(position); - } - public invert_permutation() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +@Namespace("sd::ops") public static class concat extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public concat(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public concat(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public concat position(long position) { + return (concat)super.position(position); + } - @Namespace("sd::ops") public static class concat extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public concat(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public concat(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public concat position(long position) { - return (concat)super.position(position); - } - public concat() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class concat_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public concat_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public concat_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public concat_bp position(long position) { - return (concat_bp)super.position(position); - } - +@Namespace("sd::ops") public static class concat_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public concat_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public concat_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public concat_bp position(long position) { + return (concat_bp)super.position(position); + } + public concat_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #if NOT_EXCLUDED(OP_mergemax) - @Namespace("sd::ops") public static class mergemax extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mergemax(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mergemax(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mergemax position(long position) { - return (mergemax)super.position(position); - } - +// #if NOT_EXCLUDED(OP_mergemax) +@Namespace("sd::ops") public static class mergemax extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergemax(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergemax(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergemax position(long position) { + return (mergemax)super.position(position); + } + public mergemax() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class mergemax_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mergemax_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mergemax_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mergemax_bp position(long position) { - return (mergemax_bp)super.position(position); - } - +@Namespace("sd::ops") public static class mergemax_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergemax_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergemax_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergemax_bp position(long position) { + return (mergemax_bp)super.position(position); + } + public mergemax_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - /* - * Complete tensor with max indices merged from all input tensors list - * - * INPUT: tensors with the same shape - * OUTPUT: integer tensor with the same shape - * INT_ARG: result type (one of int), INT32 by default - */ -// #if NOT_EXCLUDED(OP_mergemaxindex) - @Namespace("sd::ops") public static class mergemaxindex extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mergemaxindex(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mergemaxindex(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mergemaxindex position(long position) { - return (mergemaxindex)super.position(position); - } - +// #endif +/* + * Complete tensor with max indices merged from all input tensors list + * + * INPUT: tensors with the same shape + * OUTPUT: integer tensor with the same shape + * INT_ARG: result type (one of int), INT32 by default + */ +// #if NOT_EXCLUDED(OP_mergemaxindex) +@Namespace("sd::ops") public static class mergemaxindex extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergemaxindex(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergemaxindex(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergemaxindex position(long position) { + return (mergemaxindex)super.position(position); + } + public mergemaxindex() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_mergeadd) +@Namespace("sd::ops") public static class mergeadd extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergeadd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergeadd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergeadd position(long position) { + return (mergeadd)super.position(position); + } + + public mergeadd() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +@Namespace("sd::ops") public static class mergeadd_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergeadd_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergeadd_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergeadd_bp position(long position) { + return (mergeadd_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_mergeadd) - @Namespace("sd::ops") public static class mergeadd extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mergeadd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mergeadd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mergeadd position(long position) { - return (mergeadd)super.position(position); - } - - public mergeadd() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - @Namespace("sd::ops") public static class mergeadd_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mergeadd_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mergeadd_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mergeadd_bp position(long position) { - return (mergeadd_bp)super.position(position); - } - public mergeadd_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_mergeavg) +@Namespace("sd::ops") public static class mergeavg extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergeavg(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergeavg(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergeavg position(long position) { + return (mergeavg)super.position(position); + } -// #if NOT_EXCLUDED(OP_mergeavg) - @Namespace("sd::ops") public static class mergeavg extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mergeavg(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mergeavg(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mergeavg position(long position) { - return (mergeavg)super.position(position); - } - public mergeavg() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class mergeavg_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mergeavg_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mergeavg_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mergeavg_bp position(long position) { - return (mergeavg_bp)super.position(position); - } - +@Namespace("sd::ops") public static class mergeavg_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mergeavg_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mergeavg_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mergeavg_bp position(long position) { + return (mergeavg_bp)super.position(position); + } + public mergeavg_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_scatter_update) +@Namespace("sd::ops") public static class scatter_update extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_update(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_update(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_update position(long position) { + return (scatter_update)super.position(position); + } -// #if NOT_EXCLUDED(OP_scatter_update) - @Namespace("sd::ops") public static class scatter_update extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_update(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_update(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_update position(long position) { - return (scatter_update)super.position(position); - } - public scatter_update() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_Floor) +@Namespace("sd::ops") public static class Floor extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Floor(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Floor(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Floor position(long position) { + return (Floor)super.position(position); + } -// #if NOT_EXCLUDED(OP_Floor) - @Namespace("sd::ops") public static class Floor extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Floor(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Floor(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Floor position(long position) { - return (Floor)super.position(position); - } - public Floor() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_Log1p) +@Namespace("sd::ops") public static class Log1p extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Log1p(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Log1p(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Log1p position(long position) { + return (Log1p)super.position(position); + } -// #if NOT_EXCLUDED(OP_Log1p) - @Namespace("sd::ops") public static class Log1p extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Log1p(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Log1p(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Log1p position(long position) { - return (Log1p)super.position(position); - } - public Log1p() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_reverse) +@Namespace("sd::ops") public static class reverse extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reverse(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reverse(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reverse position(long position) { + return (reverse)super.position(position); + } -// #if NOT_EXCLUDED(OP_reverse) - @Namespace("sd::ops") public static class reverse extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reverse(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reverse(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reverse position(long position) { - return (reverse)super.position(position); - } - public reverse() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class reverse_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reverse_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reverse_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reverse_bp position(long position) { - return (reverse_bp)super.position(position); - } - +@Namespace("sd::ops") public static class reverse_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reverse_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reverse_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reverse_bp position(long position) { + return (reverse_bp)super.position(position); + } + public reverse_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_gather) +@Namespace("sd::ops") public static class gather extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gather(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gather(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gather position(long position) { + return (gather)super.position(position); + } -// #if NOT_EXCLUDED(OP_gather) - @Namespace("sd::ops") public static class gather extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gather(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gather(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gather position(long position) { - return (gather)super.position(position); - } - public gather() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_pad) +@Namespace("sd::ops") public static class pad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public pad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public pad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public pad position(long position) { + return (pad)super.position(position); + } -// #if NOT_EXCLUDED(OP_pad) - @Namespace("sd::ops") public static class pad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public pad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public pad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public pad position(long position) { - return (pad)super.position(position); - } - public pad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * creates identity 2D matrix or batch of identical 2D identity matrices + * + * Input array: + * provide some array - in any case operation simply neglects it + * + * Input float argument (if passed): + * TArgs[0] - type of elements of output array, default value is 5 (float) + * + * Input integer arguments: + * IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> + * 'f'-order IArgs[1] - the number of rows in output inner-most 2D + * identity matrix IArgs[2] - optional, the number of columns in output + * inner-most 2D identity matrix, if this argument is not provided then it is + * taken to be equal to number of rows IArgs[3,4,...] - optional, shape of + * batch, output matrix will have leading batch dimensions of this shape + */ +// #if NOT_EXCLUDED(OP_eye) +@Namespace("sd::ops") public static class eye extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public eye(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public eye(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public eye position(long position) { + return (eye)super.position(position); + } - /** - * creates identity 2D matrix or batch of identical 2D identity matrices - * - * Input array: - * provide some array - in any case operation simply neglects it - * - * Input float argument (if passed): - * TArgs[0] - type of elements of output array, default value is 5 (float) - * - * Input integer arguments: - * IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> 'f'-order - * IArgs[1] - the number of rows in output inner-most 2D identity matrix - * IArgs[2] - optional, the number of columns in output inner-most 2D identity matrix, if this argument is not provided then it is taken to be equal to number of rows - * IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape - */ -// #if NOT_EXCLUDED(OP_eye) - @Namespace("sd::ops") public static class eye extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public eye(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public eye(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public eye position(long position) { - return (eye)super.position(position); - } - public eye() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_gather_nd) +@Namespace("sd::ops") public static class gather_nd extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gather_nd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gather_nd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gather_nd position(long position) { + return (gather_nd)super.position(position); + } -// #if NOT_EXCLUDED(OP_gather_nd) - @Namespace("sd::ops") public static class gather_nd extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public gather_nd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public gather_nd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public gather_nd position(long position) { - return (gather_nd)super.position(position); - } - public gather_nd() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_reverse_sequence) +@Namespace("sd::ops") public static class reverse_sequence extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reverse_sequence(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reverse_sequence(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reverse_sequence position(long position) { + return (reverse_sequence)super.position(position); + } -// #if NOT_EXCLUDED(OP_reverse_sequence) - @Namespace("sd::ops") public static class reverse_sequence extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reverse_sequence(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reverse_sequence(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reverse_sequence position(long position) { - return (reverse_sequence)super.position(position); - } - public reverse_sequence() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_trace) +@Namespace("sd::ops") public static class trace extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public trace(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public trace(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public trace position(long position) { + return (trace)super.position(position); + } -// #if NOT_EXCLUDED(OP_trace) - @Namespace("sd::ops") public static class trace extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public trace(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public trace(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public trace position(long position) { - return (trace)super.position(position); - } - public trace() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_random_shuffle) +@Namespace("sd::ops") public static class random_shuffle extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_shuffle(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_shuffle(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_shuffle position(long position) { + return (random_shuffle)super.position(position); + } -// #if NOT_EXCLUDED(OP_random_shuffle) - @Namespace("sd::ops") public static class random_shuffle extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_shuffle(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_shuffle(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_shuffle position(long position) { - return (random_shuffle)super.position(position); - } - public random_shuffle() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * clip a list of given tensors with given average norm when needed + * + * Input: + * a list of tensors (at least one) + * + * Input floating point argument: + * clip_norm - a value that used as threshold value and norm to be used + * + * return a list of clipped tensors + * and global_norm as scalar tensor at the end + */ +// #if NOT_EXCLUDED(OP_clip_by_global_norm) +@Namespace("sd::ops") public static class clip_by_global_norm extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public clip_by_global_norm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public clip_by_global_norm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public clip_by_global_norm position(long position) { + return (clip_by_global_norm)super.position(position); + } - /** - * clip a list of given tensors with given average norm when needed - * - * Input: - * a list of tensors (at least one) - * - * Input floating point argument: - * clip_norm - a value that used as threshold value and norm to be used - * - * return a list of clipped tensors - * and global_norm as scalar tensor at the end - */ -// #if NOT_EXCLUDED(OP_clip_by_global_norm) - @Namespace("sd::ops") public static class clip_by_global_norm extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public clip_by_global_norm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public clip_by_global_norm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public clip_by_global_norm position(long position) { - return (clip_by_global_norm)super.position(position); - } - public clip_by_global_norm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +@Namespace("sd::ops") public static class tri extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tri(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tri(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tri position(long position) { + return (tri)super.position(position); + } - @Namespace("sd::ops") public static class tri extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tri(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tri(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tri position(long position) { - return (tri)super.position(position); - } - public tri() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class triu extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public triu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public triu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public triu position(long position) { - return (triu)super.position(position); - } - +@Namespace("sd::ops") public static class triu extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public triu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public triu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public triu position(long position) { + return (triu)super.position(position); + } + public triu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class triu_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public triu_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public triu_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public triu_bp position(long position) { - return (triu_bp)super.position(position); - } - +@Namespace("sd::ops") public static class triu_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public triu_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public triu_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public triu_bp position(long position) { + return (triu_bp)super.position(position); + } + public triu_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #if NOT_EXCLUDED(OP_mirror_pad) - @Namespace("sd::ops") public static class mirror_pad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mirror_pad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mirror_pad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mirror_pad position(long position) { - return (mirror_pad)super.position(position); - } - +// #if NOT_EXCLUDED(OP_mirror_pad) +@Namespace("sd::ops") public static class mirror_pad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mirror_pad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mirror_pad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mirror_pad position(long position) { + return (mirror_pad)super.position(position); + } + public mirror_pad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_cumsum) +@Namespace("sd::ops") public static class cumsum_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cumsum_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cumsum_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cumsum_bp position(long position) { + return (cumsum_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_cumsum) - @Namespace("sd::ops") public static class cumsum_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cumsum_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cumsum_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cumsum_bp position(long position) { - return (cumsum_bp)super.position(position); - } - public cumsum_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_cumprod) +@Namespace("sd::ops") public static class cumprod_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cumprod_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cumprod_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cumprod_bp position(long position) { + return (cumprod_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_cumprod) - @Namespace("sd::ops") public static class cumprod_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cumprod_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cumprod_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cumprod_bp position(long position) { - return (cumprod_bp)super.position(position); - } - public cumprod_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +// #if NOT_EXCLUDED(OP_flatten) +@Namespace("sd::ops") public static class flatten extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public flatten(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public flatten(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public flatten position(long position) { + return (flatten)super.position(position); + } -// #if NOT_EXCLUDED(OP_flatten) - @Namespace("sd::ops") public static class flatten extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public flatten(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public flatten(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public flatten position(long position) { - return (flatten)super.position(position); - } - public flatten() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * returns histogram (as 1D array) with fixed bins width + * + * Input arrays: + * - input array with elements to be binned into output histogram + * - range array with first element being bottom limit and second element being + top limit of histogram, please note that input_value <= range[0] will be mapped + to histogram[0], input_value >= range[1] will be mapped to histogram[-1] + * + * Input integer arguments: + * nbins (optional) - number of histogram bins, default value is 100 + */ +// #if NOT_EXCLUDED(OP_histogram_fixed_width) +@Namespace("sd::ops") public static class histogram_fixed_width extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public histogram_fixed_width(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public histogram_fixed_width(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public histogram_fixed_width position(long position) { + return (histogram_fixed_width)super.position(position); + } - /** - * returns histogram (as 1D array) with fixed bins width - * - * Input arrays: - * - input array with elements to be binned into output histogram - * - range array with first element being bottom limit and second element being top limit of histogram, - please note that input_value <= range[0] will be mapped to histogram[0], input_value >= range[1] will be mapped to histogram[-1] - * - * Input integer arguments: - * nbins (optional) - number of histogram bins, default value is 100 - */ -// #if NOT_EXCLUDED(OP_histogram_fixed_width) - @Namespace("sd::ops") public static class histogram_fixed_width extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public histogram_fixed_width(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public histogram_fixed_width(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public histogram_fixed_width position(long position) { - return (histogram_fixed_width)super.position(position); - } - public histogram_fixed_width() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * standardizes input array to be zero mean unit variance along the given axis + * + * + */ +// #if NOT_EXCLUDED(OP_standardize) +@Namespace("sd::ops") public static class standardize extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public standardize(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public standardize(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public standardize position(long position) { + return (standardize)super.position(position); + } - /** - * standardizes input array to be zero mean unit variance along the given axis - * - * - */ -// #if NOT_EXCLUDED(OP_standardize) - @Namespace("sd::ops") public static class standardize extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public standardize(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public standardize(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public standardize position(long position) { - return (standardize)super.position(position); - } - public standardize() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class standardize_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public standardize_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public standardize_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public standardize_bp position(long position) { - return (standardize_bp)super.position(position); - } - +@Namespace("sd::ops") public static class standardize_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public standardize_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public standardize_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public standardize_bp position(long position) { + return (standardize_bp)super.position(position); + } + public standardize_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation calculates hash code, optionally along dimension + */ +// #if NOT_EXCLUDED(OP_hashcode) +@Namespace("sd::ops") public static class hashcode extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public hashcode(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public hashcode(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public hashcode position(long position) { + return (hashcode)super.position(position); + } - /** - * This operation calculates hash code, optionally along dimension - */ -// #if NOT_EXCLUDED(OP_hashcode) - @Namespace("sd::ops") public static class hashcode extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public hashcode(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public hashcode(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public hashcode position(long position) { - return (hashcode)super.position(position); - } - public hashcode() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation calculates number of entries per bin + */ +// #if NOT_EXCLUDED(OP_histogram) +@Namespace("sd::ops") public static class histogram extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public histogram(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public histogram(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public histogram position(long position) { + return (histogram)super.position(position); + } - /** - * This operation calculates number of entries per bin - */ -// #if NOT_EXCLUDED(OP_histogram) - @Namespace("sd::ops") public static class histogram extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public histogram(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public histogram(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public histogram position(long position) { - return (histogram)super.position(position); - } - public histogram() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -17612,61 +19969,59 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_PARITY_H // #include - /** - * This operation returns index of max element in a given NDArray (optionally: along given dimension(s)) - * Expected input: - * 0: N-dimensional array - * 1: optional axis vector - * - * Int args: - * 0: optional axis - */ -// #if NOT_EXCLUDED(OP_argmax) - @Namespace("sd::ops") public static class argmax extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public argmax(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public argmax(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public argmax position(long position) { - return (argmax)super.position(position); - } - +/** + * This operation returns index of max element in a given NDArray (optionally: + * along given dimension(s)) Expected input: 0: N-dimensional array 1: optional + * axis vector + * + * Int args: + * 0: optional axis + */ +// #if NOT_EXCLUDED(OP_argmax) +@Namespace("sd::ops") public static class argmax extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public argmax(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public argmax(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public argmax position(long position) { + return (argmax)super.position(position); + } + public argmax() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation returns index of min element in a given NDArray (optionally: + * along given dimension(s)) Expected input: 0: N-dimensional array 1: optional + * axis vector + * + * Int args: + * 0: optional axis + */ +// #if NOT_EXCLUDED(OP_argmin) +@Namespace("sd::ops") public static class argmin extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public argmin(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public argmin(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public argmin position(long position) { + return (argmin)super.position(position); + } - /** - * This operation returns index of min element in a given NDArray (optionally: along given dimension(s)) - * Expected input: - * 0: N-dimensional array - * 1: optional axis vector - * - * Int args: - * 0: optional axis - */ -// #if NOT_EXCLUDED(OP_argmin) - @Namespace("sd::ops") public static class argmin extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public argmin(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public argmin(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public argmin position(long position) { - return (argmin)super.position(position); - } - public argmin() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif - /** +/** * This operation returns index of absolute max element in a given NDArray (optionally: along given dimension(s)) * Expected input: * 0: N-dimensional array @@ -17721,3654 +20076,3720 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #endif /** - * This operation provides various normalization modes: - * 0: frobenius - * 1: euclidean (norm2) - * 2: norm1 - * 3: norm2 - * 4: inf-norm - * 5: p-norm - * - * Expected arguments: - * input: N-dimensional array - * - * - * Int args: - * 0...: axis - * - * T args: - * 0: norm mode - * 1: p for p-norm - */ -// #if NOT_EXCLUDED(OP_norm) - @Namespace("sd::ops") public static class norm extends DeclarableReductionOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public norm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public norm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public norm position(long position) { - return (norm)super.position(position); - } - + * This operation provides various normalization modes: + * 0: frobenius + * 1: euclidean (norm2) + * 2: norm1 + * 3: norm2 + * 4: inf-norm + * 5: p-norm + * + * Expected arguments: + * input: N-dimensional array + * + * + * Int args: + * 0...: axis + * + * T args: + * 0: norm mode + * 1: p for p-norm + */ +// #if NOT_EXCLUDED(OP_norm) +@Namespace("sd::ops") public static class norm extends DeclarableReductionOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public norm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public norm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public norm position(long position) { + return (norm)super.position(position); + } + public norm() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * Inserts elements provided by diagonal array into the main diagonal of + * innermost matrices of input array + * + * Input arrays: + * 0: input array, considered as batch of matrices + * 1: diagonal array containing elements to be inserted into input array, + * following rank condition should be satisfied: diagonal_rank = input_rank + * - 1, the shapes of diagonal and input arrays must be equal except last + * dimension of input array, for example if input_shape = [A,B,C,D] then + * diagonal_shape = [A,B,C], also last dimension of diagonal array should be + * equal to smaller of last and last but one input dimensions that is: + * diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) + * + * Output array: + * 0: has the same shape as input, corresponding diagonal elements are + * substituted + */ +// #if NOT_EXCLUDED(OP_matrix_set_diag) +@Namespace("sd::ops") public static class matrix_set_diag extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matrix_set_diag(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matrix_set_diag(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matrix_set_diag position(long position) { + return (matrix_set_diag)super.position(position); + } - /** - * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array - * - * Input arrays: - * 0: input array, considered as batch of matrices - * 1: diagonal array containing elements to be inserted into input array, - * following rank condition should be satisfied: diagonal_rank = input_rank - 1, - * the shapes of diagonal and input arrays must be equal except last dimension of input array, - * for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C], - * also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions - * that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) - * - * Output array: - * 0: has the same shape as input, corresponding diagonal elements are substituted - */ -// #if NOT_EXCLUDED(OP_matrix_set_diag) - @Namespace("sd::ops") public static class matrix_set_diag extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matrix_set_diag(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matrix_set_diag(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matrix_set_diag position(long position) { - return (matrix_set_diag)super.position(position); - } - public matrix_set_diag() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Inserts elements provided by diagonal array into the main diagonal of + * innermost matrices of output array, rest output elements are set to zeros + * + * Input array: + * diagonal: array containing elements to be inserted into output array, + * following rank condition is present: diagonal_rank = ouput_rank + * - 1 + * + * Output array: + * 0: is considered as batch of matrices, if for example diagonal array has + * shape [A,B,C] then output array has shape [A,B,C,C] + */ +@Namespace("sd::ops") public static class matrix_diag extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matrix_diag(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matrix_diag(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matrix_diag position(long position) { + return (matrix_diag)super.position(position); + } - /** - * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of output array, - * rest output elements are set to zeros - * - * Input array: - * diagonal: array containing elements to be inserted into output array, - * following rank condition is present: diagonal_rank = ouput_rank - 1 - * - * Output array: - * 0: is considered as batch of matrices, if for example diagonal array has shape [A,B,C] then output array has shape [A,B,C,C] - */ - @Namespace("sd::ops") public static class matrix_diag extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matrix_diag(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matrix_diag(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matrix_diag position(long position) { - return (matrix_diag)super.position(position); - } - public matrix_diag() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - /** - * This op calculates regularized incomplete beta integral Ix(a, b). - * Implementation is based on two algorithms depending on input values of a and b: - * - when a and b are both > maxValue (3000.), then Gauss-Legendre quadrature method is applied - * - when a and b are both <= maxValue (3000.), then modified Lentz’s algorithm for continued fractions is applied - * - * Input arrays: - * a: defines power t^{a-1}, must be > 0, type float. - * b: defines power (1-t)^{b-1}, must be > 0, type float. - * x: defines upper limit of integration, must be within (0 <= x <= 1) range, type float. - * - * Output array: - * 0: values of regularized incomplete beta integral that corresponds to variable upper limit x, type float - * - * Three input and one output arrays must have the same shape - */ -// #if NOT_EXCLUDED(OP_betainc) - @Namespace("sd::ops") public static class betainc extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public betainc(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public betainc(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public betainc position(long position) { - return (betainc)super.position(position); - } - +/** + * This op calculates regularized incomplete beta integral Ix(a, b). + * Implementation is based on two algorithms depending on input values of a and + * b: + * - when a and b are both > maxValue (3000.), then Gauss-Legendre quadrature + * method is applied + * - when a and b are both <= maxValue (3000.), then modified Lentz’s algorithm + * for continued fractions is applied + * + * Input arrays: + * a: defines power t^{a-1}, must be > 0, type float. + * b: defines power (1-t)^{b-1}, must be > 0, type float. + * x: defines upper limit of integration, must be within (0 <= x <= 1) range, + * type float. + * + * Output array: + * 0: values of regularized incomplete beta integral that corresponds to + * variable upper limit x, type float + * + * Three input and one output arrays must have the same shape + */ +// #if NOT_EXCLUDED(OP_betainc) +@Namespace("sd::ops") public static class betainc extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public betainc(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public betainc(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public betainc position(long position) { + return (betainc)super.position(position); + } + public betainc() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation is added for compatibility purposes mostly. + * PLEASE NOTE: Please consider using Add instead + * Expected arguments: + * 0: N-dimensional input + * 1: bias vector + */ +// #if NOT_EXCLUDED(OP_biasadd) +@Namespace("sd::ops") public static class biasadd extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public biasadd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public biasadd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public biasadd position(long position) { + return (biasadd)super.position(position); + } - /** - * This operation is added for compatibility purposes mostly. - * PLEASE NOTE: Please consider using Add instead - * Expected arguments: - * 0: N-dimensional input - * 1: bias vector - */ -// #if NOT_EXCLUDED(OP_biasadd) - @Namespace("sd::ops") public static class biasadd extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public biasadd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public biasadd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public biasadd position(long position) { - return (biasadd)super.position(position); - } - public biasadd() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class biasadd_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public biasadd_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public biasadd_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public biasadd_bp position(long position) { - return (biasadd_bp)super.position(position); - } - +@Namespace("sd::ops") public static class biasadd_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public biasadd_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public biasadd_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public biasadd_bp position(long position) { + return (biasadd_bp)super.position(position); + } + public biasadd_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Returns a diagonal tensor with a given diagonal values. Given a diagonal, + * this operation returns a tensor with the diagonal and everything else padded + * with zeros. + */ +// #if NOT_EXCLUDED(OP_diag) +@Namespace("sd::ops") public static class diag extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public diag(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public diag(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public diag position(long position) { + return (diag)super.position(position); + } - /** - * Returns a diagonal tensor with a given diagonal values. Given a diagonal, this operation returns a tensor with the diagonal and everything else padded with zeros. - */ -// #if NOT_EXCLUDED(OP_diag) - @Namespace("sd::ops") public static class diag extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public diag(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public diag(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public diag position(long position) { - return (diag)super.position(position); - } - public diag() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Returns a diagonal tensor with a given diagonal values. Given a diagonal, + * this operation returns a tensor with the diagonal and everything else padded + * with zeros. + */ +// #if NOT_EXCLUDED(OP_diag_part) +@Namespace("sd::ops") public static class diag_part extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public diag_part(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public diag_part(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public diag_part position(long position) { + return (diag_part)super.position(position); + } - /** - * Returns a diagonal tensor with a given diagonal values. Given a diagonal, this operation returns a tensor with the diagonal and everything else padded with zeros. - */ -// #if NOT_EXCLUDED(OP_diag_part) - @Namespace("sd::ops") public static class diag_part extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public diag_part(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public diag_part(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public diag_part position(long position) { - return (diag_part)super.position(position); - } - public diag_part() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Returns a diagonal vector for any submatricies with in a given tensor. + * It is an op inverse to matrix_set_giag. + * Using input tensor as batched 2D diagonals flat them to vector (1D) with + * diagonal values. + * + * Input : batched tensor with rank >=2 + * Output: tensor with rank lesser by 1 from input + */ +// #if NOT_EXCLUDED(OP_matrix_diag_part) +@Namespace("sd::ops") public static class matrix_diag_part extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matrix_diag_part(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matrix_diag_part(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matrix_diag_part position(long position) { + return (matrix_diag_part)super.position(position); + } - /** - * Returns a diagonal vector for any submatricies with in a given tensor. - * It is an op inverse to matrix_set_giag. - * Using input tensor as batched 2D diagonals flat them to vector (1D) with diagonal values. - * - * Input : batched tensor with rank >=2 - * Output: tensor with rank lesser by 1 from input - */ -// #if NOT_EXCLUDED(OP_matrix_diag_part) - @Namespace("sd::ops") public static class matrix_diag_part extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matrix_diag_part(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matrix_diag_part(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matrix_diag_part position(long position) { - return (matrix_diag_part)super.position(position); - } - public matrix_diag_part() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * QR decomposition: A = QR, where Q is ortogonal (Q * QT = I) and R is upper + * triangular. For A (MxN) Q is M x M and R is (NxN). + * + * Input : + * 0 - float (or complex float) tensor with shape {.,..,...,M,N} - batch of + * float matricies + * + * Output: + * 0 - float tensor with shape {.,..,...,MxN} - batch of ortogonal matricies + * {Qs} 1 - float tensor with shape {.,..,...,NxN} - batch of upper triangular + * matricies {Rs} + */ +// #if NOT_EXCLUDED(OP_qr) +@Namespace("sd::ops") public static class qr extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public qr(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public qr(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public qr position(long position) { + return (qr)super.position(position); + } - /** - * QR decomposition: A = QR, where Q is ortogonal (Q * QT = I) and R is upper triangular. - * For A (MxN) Q is M x M and R is (NxN). - * - * Input : - * 0 - float (or complex float) tensor with shape {.,..,...,M,N} - batch of float matricies - * - * Output: - * 0 - float tensor with shape {.,..,...,MxN} - batch of ortogonal matricies {Qs} - * 1 - float tensor with shape {.,..,...,NxN} - batch of upper triangular matricies {Rs} - */ -// #if NOT_EXCLUDED(OP_qr) - @Namespace("sd::ops") public static class qr extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public qr(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public qr(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public qr position(long position) { - return (qr)super.position(position); - } - public qr() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation takes 2 arrays: original values, and values to be excluded. + * And returns 2 arrays: values left after exclusion, and indices in original + * array for surivals. Expected arguments: 0: vector with original values 1: + * vector with values to exclude + */ +// #if NOT_EXCLUDED(OP_listdiff) +@Namespace("sd::ops") public static class listdiff extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public listdiff(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public listdiff(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public listdiff position(long position) { + return (listdiff)super.position(position); + } - /** - * This operation takes 2 arrays: original values, and values to be excluded. And returns 2 arrays: values left after exclusion, and indices in original array for surivals. - * Expected arguments: - * 0: vector with original values - * 1: vector with values to exclude - */ -// #if NOT_EXCLUDED(OP_listdiff) - @Namespace("sd::ops") public static class listdiff extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public listdiff(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public listdiff(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public listdiff position(long position) { - return (listdiff)super.position(position); - } - public listdiff() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies Add operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_add) +@Namespace("sd::ops") public static class scatter_add extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_add(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_add(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_add position(long position) { + return (scatter_add)super.position(position); + } - /** - * This operation applies Add operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_add) - @Namespace("sd::ops") public static class scatter_add extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_add(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_add(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_add position(long position) { - return (scatter_add)super.position(position); - } - public scatter_add() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies Subtract operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_sub) +@Namespace("sd::ops") public static class scatter_sub extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_sub(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_sub(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_sub position(long position) { + return (scatter_sub)super.position(position); + } - /** - * This operation applies Subtract operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_sub) - @Namespace("sd::ops") public static class scatter_sub extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_sub(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_sub(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_sub position(long position) { - return (scatter_sub)super.position(position); - } - public scatter_sub() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies Multiply operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_mul) +@Namespace("sd::ops") public static class scatter_mul extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_mul(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_mul(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_mul position(long position) { + return (scatter_mul)super.position(position); + } - /** - * This operation applies Multiply operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_mul) - @Namespace("sd::ops") public static class scatter_mul extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_mul(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_mul(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_mul position(long position) { - return (scatter_mul)super.position(position); - } - public scatter_mul() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies Divide operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_div) +@Namespace("sd::ops") public static class scatter_div extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_div(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_div(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_div position(long position) { + return (scatter_div)super.position(position); + } - /** - * This operation applies Divide operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_div) - @Namespace("sd::ops") public static class scatter_div extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_div(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_div(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_div position(long position) { - return (scatter_div)super.position(position); - } - public scatter_div() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies Assign operation to specific inputs wrt indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_upd) +@Namespace("sd::ops") public static class scatter_upd extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_upd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_upd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_upd position(long position) { + return (scatter_upd)super.position(position); + } - /** - * This operation applies Assign operation to specific inputs wrt indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_upd) - @Namespace("sd::ops") public static class scatter_upd extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_upd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_upd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_upd position(long position) { - return (scatter_upd)super.position(position); - } - public scatter_upd() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies Max operation to specific inputs through given indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_max) +@Namespace("sd::ops") public static class scatter_max extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_max(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_max(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_max position(long position) { + return (scatter_max)super.position(position); + } - /** - * This operation applies Max operation to specific inputs through given indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_max) - @Namespace("sd::ops") public static class scatter_max extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_max(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_max(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_max position(long position) { - return (scatter_max)super.position(position); - } - public scatter_max() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies Min operation to specific inputs through given indices + * Expected arguments: + * input: array to be updated + * indices: array containing indexes for first dimension of input + * updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_min) +@Namespace("sd::ops") public static class scatter_min extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_min(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_min(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_min position(long position) { + return (scatter_min)super.position(position); + } - /** - * This operation applies Min operation to specific inputs through given indices - * Expected arguments: - * input: array to be updated - * indices: array containing indexes for first dimension of input - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_min) - @Namespace("sd::ops") public static class scatter_min extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_min(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_min(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_min position(long position) { - return (scatter_min)super.position(position); - } - public scatter_min() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation scatter "updates" elements into new output array according to + * given "indices" Expected arguments: indices: array containing elements/slices + * indexes of output array to put "updates" elements into, the rest output + * elements will be zeros updates: array containing elements to be inserted into + * output array shape: contains shape of output array + */ +// #if NOT_EXCLUDED(OP_scatter_nd) +@Namespace("sd::ops") public static class scatter_nd extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_nd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_nd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_nd position(long position) { + return (scatter_nd)super.position(position); + } - /** - * This operation scatter "updates" elements into new output array according to given "indices" - * Expected arguments: - * indices: array containing elements/slices indexes of output array to put "updates" elements into, the rest output elements will be zeros - * updates: array containing elements to be inserted into output array - * shape: contains shape of output array - */ -// #if NOT_EXCLUDED(OP_scatter_nd) - @Namespace("sd::ops") public static class scatter_nd extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_nd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_nd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_nd position(long position) { - return (scatter_nd)super.position(position); - } - public scatter_nd() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation scatter "updates" elements into input array along given + * "indices" Expected arguments: input: array to be updated indices: array + * containing elements/slices indexes of input array to put "updates" elements + * into updates: array containing elements to be inserted into input array + */ +// #if NOT_EXCLUDED(OP_scatter_nd_update) +@Namespace("sd::ops") public static class scatter_nd_update extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_nd_update(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_nd_update(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_nd_update position(long position) { + return (scatter_nd_update)super.position(position); + } - /** - * This operation scatter "updates" elements into input array along given "indices" - * Expected arguments: - * input: array to be updated - * indices: array containing elements/slices indexes of input array to put "updates" elements into - * updates: array containing elements to be inserted into input array - */ -// #if NOT_EXCLUDED(OP_scatter_nd_update) - @Namespace("sd::ops") public static class scatter_nd_update extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_nd_update(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_nd_update(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_nd_update position(long position) { - return (scatter_nd_update)super.position(position); - } - public scatter_nd_update() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation adds "updates" elements to input array along given "indices" + * Expected arguments: + * input: array to be updated + * indices: array containing elements/slices indexes of input array to add + * "updates" elements to updates: array containing elements to be interfered + * with input + */ +// #if NOT_EXCLUDED(OP_scatter_add) +@Namespace("sd::ops") public static class scatter_nd_add extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_nd_add(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_nd_add(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_nd_add position(long position) { + return (scatter_nd_add)super.position(position); + } - /** - * This operation adds "updates" elements to input array along given "indices" - * Expected arguments: - * input: array to be updated - * indices: array containing elements/slices indexes of input array to add "updates" elements to - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_add) - @Namespace("sd::ops") public static class scatter_nd_add extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_nd_add(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_nd_add(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_nd_add position(long position) { - return (scatter_nd_add)super.position(position); - } - public scatter_nd_add() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation subtract "updates" elements from input array along given + * "indices" Expected arguments: input: array to be updated indices: array + * containing elements/slices indexes of input array to subtract "updates" + * elements from updates: array containing elements to be interfered with input + */ +// #if NOT_EXCLUDED(OP_scatter_sub) +@Namespace("sd::ops") public static class scatter_nd_sub extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public scatter_nd_sub(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public scatter_nd_sub(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public scatter_nd_sub position(long position) { + return (scatter_nd_sub)super.position(position); + } - /** - * This operation subtract "updates" elements from input array along given "indices" - * Expected arguments: - * input: array to be updated - * indices: array containing elements/slices indexes of input array to subtract "updates" elements from - * updates: array containing elements to be interfered with input - */ -// #if NOT_EXCLUDED(OP_scatter_sub) - @Namespace("sd::ops") public static class scatter_nd_sub extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public scatter_nd_sub(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public scatter_nd_sub(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public scatter_nd_sub position(long position) { - return (scatter_nd_sub)super.position(position); - } - public scatter_nd_sub() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation takes input's shape, and returns new NDArray filled with + * specified value Expected arguments: input: N-dimensional array + * + * T args: + * 0: scalar value, used to fill NDArray + */ +// #if NOT_EXCLUDED(OP_fill_as) +@Namespace("sd::ops") public static class fill_as extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public fill_as(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public fill_as(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public fill_as position(long position) { + return (fill_as)super.position(position); + } - /** - * This operation takes input's shape, and returns new NDArray filled with specified value - * Expected arguments: - * input: N-dimensional array - * - * T args: - * 0: scalar value, used to fill NDArray - */ -// #if NOT_EXCLUDED(OP_fill_as) - @Namespace("sd::ops") public static class fill_as extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public fill_as(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public fill_as(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public fill_as position(long position) { - return (fill_as)super.position(position); - } - public fill_as() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies element-wise rint (round to integral value) operation + */ +// #if NOT_EXCLUDED(OP_rint) +@Namespace("sd::ops") public static class rint extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public rint(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public rint(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public rint position(long position) { + return (rint)super.position(position); + } - /** - * This operation applies element-wise rint (round to integral value) operation - */ -// #if NOT_EXCLUDED(OP_rint) - @Namespace("sd::ops") public static class rint extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public rint(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public rint(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public rint position(long position) { - return (rint)super.position(position); - } - public rint() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation returns unique elements from input array as vector, and their + * original indices in input array Expected input: input: N-dimensional array + */ +// #if NOT_EXCLUDED(OP_unique) +@Namespace("sd::ops") public static class unique extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unique(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unique(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unique position(long position) { + return (unique)super.position(position); + } - /** - * This operation returns unique elements from input array as vector, and their original indices in input array - * Expected input: - * input: N-dimensional array - */ -// #if NOT_EXCLUDED(OP_unique) - @Namespace("sd::ops") public static class unique extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unique(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unique(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unique position(long position) { - return (unique)super.position(position); - } - public unique() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation returns 3 1D arrays for given 1D array with unique element + * count and indexes input: 0 - 1D array + * + * output: + * 0 - 1D array with unique values + * 1 - 1D array with ids for values in array above + * 2 - 1D array with counts for values in array above + */ +// #if NOT_EXCLUDED(OP_unique_with_counts) +@Namespace("sd::ops") public static class unique_with_counts extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unique_with_counts(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unique_with_counts(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unique_with_counts position(long position) { + return (unique_with_counts)super.position(position); + } - /** - * This operation returns 3 1D arrays for given 1D array with unique element count and indexes - * input: - * 0 - 1D array - * - * output: - * 0 - 1D array with unique values - * 1 - 1D array with ids for values in array above - * 2 - 1D array with counts for values in array above - */ -// #if NOT_EXCLUDED(OP_unique_with_counts) - @Namespace("sd::ops") public static class unique_with_counts extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unique_with_counts(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unique_with_counts(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unique_with_counts position(long position) { - return (unique_with_counts)super.position(position); - } - public unique_with_counts() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation splits input NDArray into multiple TADs along given dimensions + * Expected arguments: + * input: N-dimensional array + * + * Int args: + * 0..: TAD axis + */ +// #if NOT_EXCLUDED(OP_tear) +@Namespace("sd::ops") public static class tear extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tear(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tear(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tear position(long position) { + return (tear)super.position(position); + } - /** - * This operation splits input NDArray into multiple TADs along given dimensions - * Expected arguments: - * input: N-dimensional array - * - * Int args: - * 0..: TAD axis - */ -// #if NOT_EXCLUDED(OP_tear) - @Namespace("sd::ops") public static class tear extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tear(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tear(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tear position(long position) { - return (tear)super.position(position); - } - public tear() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op does the same as tear, just uses different input format: + * \tparam T + */ +// #if NOT_EXCLUDED(OP_unstack) +@Namespace("sd::ops") public static class unstack extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unstack(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unstack(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unstack position(long position) { + return (unstack)super.position(position); + } - /** - * This op does the same as tear, just uses different input format: - * \tparam T - */ -// #if NOT_EXCLUDED(OP_unstack) - @Namespace("sd::ops") public static class unstack extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unstack(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unstack(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unstack position(long position) { - return (unstack)super.position(position); - } - public unstack() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation extracts a strided (optionally) slice from a tensor, + */ +// #if NOT_EXCLUDED(OP_strided_slice) +@Namespace("sd::ops") public static class strided_slice extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public strided_slice(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public strided_slice(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public strided_slice position(long position) { + return (strided_slice)super.position(position); + } - /** - * This operation extracts a strided (optionally) slice from a tensor, - */ -// #if NOT_EXCLUDED(OP_strided_slice) - @Namespace("sd::ops") public static class strided_slice extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public strided_slice(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public strided_slice(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public strided_slice position(long position) { - return (strided_slice)super.position(position); - } - public strided_slice() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } // TODO: new op type needed. that returns VIEW - @Namespace("sd::ops") public static class strided_slice_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public strided_slice_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public strided_slice_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public strided_slice_bp position(long position) { - return (strided_slice_bp)super.position(position); - } - + } // TODO: new op type needed. that returns VIEW +@Namespace("sd::ops") public static class strided_slice_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public strided_slice_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public strided_slice_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public strided_slice_bp position(long position) { + return (strided_slice_bp)super.position(position); + } + public strided_slice_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation extracts a slice from a tensor. + * + */ +// #if NOT_EXCLUDED(OP_slice) +@Namespace("sd::ops") public static class slice extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public slice(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public slice(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public slice position(long position) { + return (slice)super.position(position); + } - /** - * This operation extracts a slice from a tensor. - * - */ -// #if NOT_EXCLUDED(OP_slice) - @Namespace("sd::ops") public static class slice extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public slice(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public slice(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public slice position(long position) { - return (slice)super.position(position); - } - public slice() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class slice_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public slice_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public slice_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public slice_bp position(long position) { - return (slice_bp)super.position(position); - } - +@Namespace("sd::ops") public static class slice_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public slice_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public slice_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public slice_bp position(long position) { + return (slice_bp)super.position(position); + } + public slice_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation generate sequences. Basically from......to, with step used as + * increment. Expected arguments: start: optional scalar with starting value + * stop: optional scalar with end value + * step: optional scalar witn step value + * + * Int args: (optional) + * 0: optional scalar with starting value + * 1: optional scalar with end value + * 1: optional scalar witn step value + * + * T args: (optional) + * 0: optional scalar with starting value + * 1: optional scalar with end value + * 1: optional scalar witn step value + */ +// #if NOT_EXCLUDED(OP_range) +@Namespace("sd::ops") public static class range extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public range(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public range(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public range position(long position) { + return (range)super.position(position); + } - /** - * This operation generate sequences. Basically from......to, with step used as increment. - * Expected arguments: - * start: optional scalar with starting value - * stop: optional scalar with end value - * step: optional scalar witn step value - * - * Int args: (optional) - * 0: optional scalar with starting value - * 1: optional scalar with end value - * 1: optional scalar witn step value - * - * T args: (optional) - * 0: optional scalar with starting value - * 1: optional scalar with end value - * 1: optional scalar witn step value - */ -// #if NOT_EXCLUDED(OP_range) - @Namespace("sd::ops") public static class range extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public range(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public range(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public range position(long position) { - return (range)super.position(position); - } - public range() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation return one-hot encoded n-dimensional array + * Expected arguments: + * input: N-dimensional array + * + * T args: + * 0: 'on' value + * 1: 'off' value + * + * Int args: + * 0: depth + * 1: axis + */ +// #if NOT_EXCLUDED(OP_onehot) +@Namespace("sd::ops") public static class onehot extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public onehot(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public onehot(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public onehot position(long position) { + return (onehot)super.position(position); + } - /** - * This operation return one-hot encoded n-dimensional array - * Expected arguments: - * input: N-dimensional array - * - * T args: - * 0: 'on' value - * 1: 'off' value - * - * Int args: - * 0: depth - * 1: axis - */ -// #if NOT_EXCLUDED(OP_onehot) - @Namespace("sd::ops") public static class onehot extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public onehot(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public onehot(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public onehot position(long position) { - return (onehot)super.position(position); - } - public onehot() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This operation calculate the confusion matrix for a + * pair of prediction and label 1-D arrays. + * Expected arguments: + * Input arrays: + * 0 - predictions: 1-D array + * 1 - labels: 1-D array + * 2 - weights : optional + * Int args: + * 0 - num_classes: optional + * + */ +// #if NOT_EXCLUDED(OP_confusion_matrix) +@Namespace("sd::ops") public static class confusion_matrix extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public confusion_matrix(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public confusion_matrix(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public confusion_matrix position(long position) { + return (confusion_matrix)super.position(position); + } - /** - * This operation calculate the confusion matrix for a - * pair of prediction and label 1-D arrays. - * Expected arguments: - * Input arrays: - * 0 - predictions: 1-D array - * 1 - labels: 1-D array - * 2 - weights : optional - * Int args: - * 0 - num_classes: optional - * - */ -// #if NOT_EXCLUDED(OP_confusion_matrix) - @Namespace("sd::ops") public static class confusion_matrix extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public confusion_matrix(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public confusion_matrix(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public confusion_matrix position(long position) { - return (confusion_matrix)super.position(position); - } - public confusion_matrix() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation stacks a list of rank tensors into one rank-(R+1) tensor. + * Expected arguments: + * 0...: N-Dimensional arrays to stack + * + */ +// #if NOT_EXCLUDED(OP_stack) +@Namespace("sd::ops") public static class stack extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public stack(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public stack(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public stack position(long position) { + return (stack)super.position(position); + } - /** - * This operation stacks a list of rank tensors into one rank-(R+1) tensor. - * Expected arguments: - * 0...: N-Dimensional arrays to stack - * - */ -// #if NOT_EXCLUDED(OP_stack) - @Namespace("sd::ops") public static class stack extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public stack(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public stack(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public stack position(long position) { - return (stack)super.position(position); - } - public stack() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation returns length of input array + * Expected arguments: + * input: N-dimensional array + * + * TODO: make this operation reduction, to allow TAD -> size + */ +// #if NOT_EXCLUDED(OP_size) +@Namespace("sd::ops") public static class size extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public size(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public size(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public size position(long position) { + return (size)super.position(position); + } - /** - * This operation returns length of input array - * Expected arguments: - * input: N-dimensional array - * - * TODO: make this operation reduction, to allow TAD -> size - */ -// #if NOT_EXCLUDED(OP_size) - @Namespace("sd::ops") public static class size extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public size(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public size(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public size position(long position) { - return (size)super.position(position); - } - public size() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } // add DeclarableScalarOp? -// #endif + } // add DeclarableScalarOp? +// #endif +/** + * This operation returns rank of input array as scalar value. + */ +// #if NOT_EXCLUDED(OP_rank) +@Namespace("sd::ops") public static class rank extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public rank(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public rank(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public rank position(long position) { + return (rank)super.position(position); + } - /** - * This operation returns rank of input array as scalar value. - */ -// #if NOT_EXCLUDED(OP_rank) - @Namespace("sd::ops") public static class rank extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public rank(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public rank(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public rank position(long position) { - return (rank)super.position(position); - } - public rank() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } // ^ -// #endif + } // ^ +// #endif +// #if NOT_EXCLUDED(OP_broadcastgradientargs) +@Namespace("sd::ops") public static class broadcastgradientargs extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public broadcastgradientargs(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public broadcastgradientargs(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public broadcastgradientargs position(long position) { + return (broadcastgradientargs)super.position(position); + } -// #if NOT_EXCLUDED(OP_broadcastgradientargs) - @Namespace("sd::ops") public static class broadcastgradientargs extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public broadcastgradientargs(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public broadcastgradientargs(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public broadcastgradientargs position(long position) { - return (broadcastgradientargs)super.position(position); - } - public broadcastgradientargs() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation takes input's shape, and returns new NDArray filled with zeros + * Expected arguments: + * input: N-dimensional array + * + */ +// #if NOT_EXCLUDED(OP_zeros_as) +@Namespace("sd::ops") public static class zeros_as extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public zeros_as(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public zeros_as(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public zeros_as position(long position) { + return (zeros_as)super.position(position); + } - /** - * This operation takes input's shape, and returns new NDArray filled with zeros - * Expected arguments: - * input: N-dimensional array - * - */ -// #if NOT_EXCLUDED(OP_zeros_as) - @Namespace("sd::ops") public static class zeros_as extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public zeros_as(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public zeros_as(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public zeros_as position(long position) { - return (zeros_as)super.position(position); - } - public zeros_as() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation takes input's shape, and returns new NDArray filled with ones + * Expected arguments: + * input: N-dimensional array + * + */ +// #if NOT_EXCLUDED(OP_ones_as) +@Namespace("sd::ops") public static class ones_as extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ones_as(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ones_as(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ones_as position(long position) { + return (ones_as)super.position(position); + } - /** - * This operation takes input's shape, and returns new NDArray filled with ones - * Expected arguments: - * input: N-dimensional array - * - */ -// #if NOT_EXCLUDED(OP_ones_as) - @Namespace("sd::ops") public static class ones_as extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ones_as(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ones_as(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ones_as position(long position) { - return (ones_as)super.position(position); - } - public ones_as() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation applies element-wise pow(x, 2) to the given input + * Expected arguments: + * input: N-Dimensional array + */ +// #if NOT_EXCLUDED(OP_square) +@Namespace("sd::ops") public static class square extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public square(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public square(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public square position(long position) { + return (square)super.position(position); + } - /** - * This operation applies element-wise pow(x, 2) to the given input - * Expected arguments: - * input: N-Dimensional array - */ -// #if NOT_EXCLUDED(OP_square) - @Namespace("sd::ops") public static class square extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public square(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public square(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public square position(long position) { - return (square)super.position(position); - } - public square() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates Hurwitz zeta function zeta(x, q) = sum_{n=0}^{inf} (q + + * n)^{-x} Implementation is based on Euler-Maclaurin summation formula + * + * Input arrays: + * x: define power {-x}, must be > 1, type float. + * q: define summand in denominator, must be > 0, type float. + * + * Output array: + * 0: corresponding values of Hurwitz zeta function + * + * Two input and one output arrays must have the same shape + */ +// #if NOT_EXCLUDED(OP_zeta) +@Namespace("sd::ops") public static class zeta extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public zeta(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public zeta(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public zeta position(long position) { + return (zeta)super.position(position); + } - /** - * This op calculates Hurwitz zeta function zeta(x, q) = sum_{n=0}^{inf} (q + n)^{-x} - * Implementation is based on Euler-Maclaurin summation formula - * - * Input arrays: - * x: define power {-x}, must be > 1, type float. - * q: define summand in denominator, must be > 0, type float. - * - * Output array: - * 0: corresponding values of Hurwitz zeta function - * - * Two input and one output arrays must have the same shape - */ -// #if NOT_EXCLUDED(OP_zeta) - @Namespace("sd::ops") public static class zeta extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public zeta(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public zeta(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public zeta position(long position) { - return (zeta)super.position(position); - } - public zeta() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates polygamma function psi^(n)(x). Implementation is based on + * serial representation written in terms of the Hurwitz zeta function: + * polygamma = (-1)^{n+1} * n! * zeta(n+1, x). + * + * Input arrays: + * 0: n - define derivative order (n+1), type integer (however currently is + * implemented as float casted to integer) 1: x - abscissa points where to + * evaluate the polygamma function, type float + * + * Output array: + * 0: values of polygamma function at corresponding x, type float + * + * Two input and one output arrays have the same shape + */ +// #if NOT_EXCLUDED(OP_polygamma) +@Namespace("sd::ops") public static class polygamma extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public polygamma(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public polygamma(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public polygamma position(long position) { + return (polygamma)super.position(position); + } - /** - * This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in - * terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x). - * - * Input arrays: - * 0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer) - * 1: x - abscissa points where to evaluate the polygamma function, type float - * - * Output array: - * 0: values of polygamma function at corresponding x, type float - * - * Two input and one output arrays have the same shape - */ -// #if NOT_EXCLUDED(OP_polygamma) - @Namespace("sd::ops") public static class polygamma extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public polygamma(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public polygamma(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public polygamma position(long position) { - return (polygamma)super.position(position); - } - public polygamma() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates lgamma function lgamma(x) = log(Gamma(x)) + * + * Input arrays: + * 0: x - input matrix + * + * Output array: + * 0: log of Gamma(x) + * + */ +// #if NOT_EXCLUDED(OP_lgamma) +@Namespace("sd::ops") public static class lgamma extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lgamma(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lgamma(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lgamma position(long position) { + return (lgamma)super.position(position); + } - /** - * This op calculates lgamma function lgamma(x) = log(Gamma(x)) - * - * Input arrays: - * 0: x - input matrix - * - * Output array: - * 0: log of Gamma(x) - * - */ -// #if NOT_EXCLUDED(OP_lgamma) - @Namespace("sd::ops") public static class lgamma extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lgamma(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lgamma(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lgamma position(long position) { - return (lgamma)super.position(position); - } - public lgamma() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates digamma function psi(x) = derivative of log(Gamma(x)) + * + * Input arrays: + * 0: x - abscissa points where to evaluate the digamma function, type float + * + * Output array: + * 0: values of digamma function at corresponding x, type float + * + */ +// #if NOT_EXCLUDED(OP_digamma) +@Namespace("sd::ops") public static class digamma extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public digamma(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public digamma(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public digamma position(long position) { + return (digamma)super.position(position); + } - /** - * This op calculates digamma function psi(x) = derivative of log(Gamma(x)) - * - * Input arrays: - * 0: x - abscissa points where to evaluate the digamma function, type float - * - * Output array: - * 0: values of digamma function at corresponding x, type float - * - */ -// #if NOT_EXCLUDED(OP_digamma) - @Namespace("sd::ops") public static class digamma extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public digamma(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public digamma(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public digamma position(long position) { - return (digamma)super.position(position); - } - public digamma() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation takes shape as first argument, and returns new NDArray filled + * with specific scalar value. Input arrays: 0 - shape vector 1 - optional + * scalar NDArray + * + * T arguments: + * 0 - optional scalar value + * + */ +// #if NOT_EXCLUDED(OP_fill) +@Namespace("sd::ops") public static class fill extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public fill(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public fill(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public fill position(long position) { + return (fill)super.position(position); + } - /** - * This operation takes shape as first argument, and returns new NDArray filled with specific scalar value. - * Input arrays: - * 0 - shape vector - * 1 - optional scalar NDArray - * - * T arguments: - * 0 - optional scalar value - * - */ -// #if NOT_EXCLUDED(OP_fill) - @Namespace("sd::ops") public static class fill extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public fill(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public fill(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public fill position(long position) { - return (fill)super.position(position); - } - public fill() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation splits given NDArray into chunks of specific size, along given + * dimension Input arrays: 0 - input array 1 - array of sizes 2 - optional axis + * + * Integer arguments: + * 0 - optional axis + * + */ +// #if NOT_EXCLUDED(OP_split_v) +@Namespace("sd::ops") public static class split_v extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public split_v(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public split_v(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public split_v position(long position) { + return (split_v)super.position(position); + } - /** - * This operation splits given NDArray into chunks of specific size, along given dimension - * Input arrays: - * 0 - input array - * 1 - array of sizes - * 2 - optional axis - * - * Integer arguments: - * 0 - optional axis - * - */ -// #if NOT_EXCLUDED(OP_split_v) - @Namespace("sd::ops") public static class split_v extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public split_v(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public split_v(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public split_v position(long position) { - return (split_v)super.position(position); - } - public split_v() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation splits given NDArray into chunks of specific size, along given + * dimension 0 - input array 1 - optional axis + * + * Integer arguments: + * 0 - number of splits + * 1 - optional axis + */ +// #if NOT_EXCLUDED(OP_split) +@Namespace("sd::ops") public static class split extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public split(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public split(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public split position(long position) { + return (split)super.position(position); + } - /** - * This operation splits given NDArray into chunks of specific size, along given dimension - * 0 - input array - * 1 - optional axis - * - * Integer arguments: - * 0 - number of splits - * 1 - optional axis - */ -// #if NOT_EXCLUDED(OP_split) - @Namespace("sd::ops") public static class split extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public split(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public split(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public split position(long position) { - return (split)super.position(position); - } - public split() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This operation adjusts image hue by delta + * Input arrays: + * 0 - input array with rank >= 3, must have at least one dimension equal 3, + * that is dimension containing channels. 1 - optional argument, input + * scalar-array containing delta + * + * T arguments: + * 0 - optional argument, delta value + * + * Int arguments: + * 0 - optional argument, corresponds to dimension with 3 channels + */ +// #if NOT_EXCLUDED(OP_adjust_hue) +@Namespace("sd::ops") public static class adjust_hue extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public adjust_hue(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public adjust_hue(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public adjust_hue position(long position) { + return (adjust_hue)super.position(position); + } - /** - * This operation adjusts image hue by delta - * Input arrays: - * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. - * 1 - optional argument, input scalar-array containing delta - * - * T arguments: - * 0 - optional argument, delta value - * - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels - */ -// #if NOT_EXCLUDED(OP_adjust_hue) - @Namespace("sd::ops") public static class adjust_hue extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public adjust_hue(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public adjust_hue(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public adjust_hue position(long position) { - return (adjust_hue)super.position(position); - } - public adjust_hue() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation adjusts image saturation by delta + * Input arrays: + * 0 - input array with rank >= 3, must have at least one dimension equal 3, + * that is dimension containing channels. 1 - optional argument, input + * scalar-array containing saturation factor + * + * T arguments: + * 0 - optional argument, saturation factor + * + * Int arguments: + * 0 - optional argument, corresponds to dimension with 3 channels + */ +// #if NOT_EXCLUDED(OP_adjust_saturation) +@Namespace("sd::ops") public static class adjust_saturation extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public adjust_saturation(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public adjust_saturation(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public adjust_saturation position(long position) { + return (adjust_saturation)super.position(position); + } - /** - * This operation adjusts image saturation by delta - * Input arrays: - * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. - * 1 - optional argument, input scalar-array containing saturation factor - * - * T arguments: - * 0 - optional argument, saturation factor - * - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels - */ -// #if NOT_EXCLUDED(OP_adjust_saturation) - @Namespace("sd::ops") public static class adjust_saturation extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public adjust_saturation(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public adjust_saturation(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public adjust_saturation position(long position) { - return (adjust_saturation)super.position(position); - } - public adjust_saturation() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation adjusts image contrast by given factor ( z = (x - mean) * + * factor + mean ) Input arrays: 0 - input array with rank >= 3, must have last + * one dimension equal 3, that is dimension containing channels. 1 - optional + * argument, input scalar-array containing saturation contrast factor + * + * T arguments: + * 0 - optional argument, contrast factor + * + */ +// #if NOT_EXCLUDED(OP_adjust_contrast) +@Namespace("sd::ops") public static class adjust_contrast extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public adjust_contrast(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public adjust_contrast(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public adjust_contrast position(long position) { + return (adjust_contrast)super.position(position); + } - /** - * This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean ) - * Input arrays: - * 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels. - * 1 - optional argument, input scalar-array containing saturation contrast factor - * - * T arguments: - * 0 - optional argument, contrast factor - * - */ -// #if NOT_EXCLUDED(OP_adjust_contrast) - @Namespace("sd::ops") public static class adjust_contrast extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public adjust_contrast(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public adjust_contrast(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public adjust_contrast position(long position) { - return (adjust_contrast)super.position(position); - } - public adjust_contrast() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class adjust_contrast_v2 extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public adjust_contrast_v2(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public adjust_contrast_v2(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public adjust_contrast_v2 position(long position) { - return (adjust_contrast_v2)super.position(position); - } - +@Namespace("sd::ops") public static class adjust_contrast_v2 extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public adjust_contrast_v2(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public adjust_contrast_v2(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public adjust_contrast_v2 position(long position) { + return (adjust_contrast_v2)super.position(position); + } + public adjust_contrast_v2() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif +/** + * This operation rearranges data from depth into blocks of spatial data. This + * is the reverse transformation of space_to_depth op. This op output is a copy + * of the input tensor where values from the depth dimension are moved in + * spatial blocks to the height and width dimensions. Int attr 0 indicates the + * input block size and how the data is moved. Input: 0 - 4D tensor on given + * type Output: 0 - 4D tensor of given type and proper shape + * + * Int arguments: + * 0 - block size + * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels + * } 1 ("NCHW"): shape{ batch, channels, height, width } 2 ("NCHW_VECT_C"): int8 + * shape{ batch, channels / 4, height, width, 4 } optional (default 0) + */ +// #if NOT_EXCLUDED(OP_depth_to_space) +@Namespace("sd::ops") public static class depth_to_space extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public depth_to_space(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public depth_to_space(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public depth_to_space position(long position) { + return (depth_to_space)super.position(position); + } - /** - * This operation rearranges data from depth into blocks of spatial data. This is the reverse transformation - * of space_to_depth op. This op output is a copy of the input tensor where values from the depth dimension - * are moved in spatial blocks to the height and width dimensions. Int attr 0 indicates the input - * block size and how the data is moved. - * Input: - * 0 - 4D tensor on given type - * Output: - * 0 - 4D tensor of given type and proper shape - * - * Int arguments: - * 0 - block size - * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels } - * 1 ("NCHW"): shape{ batch, channels, height, width } - * 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 } - * optional (default 0) - */ -// #if NOT_EXCLUDED(OP_depth_to_space) - @Namespace("sd::ops") public static class depth_to_space extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public depth_to_space(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public depth_to_space(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public depth_to_space position(long position) { - return (depth_to_space)super.position(position); - } - public depth_to_space() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation rearranges blocks of spatial data, into depth.This op output + * is a copy of the input tensor where values from the height and width + * dimensions are moved to the depth dimension. Int attr 0 indicates the input + * block size. + * + * Input: + * - 4D tensor of given type + * Output: + * - 4D tensor + * + * Int arguments: + * 0 - block size + * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels + * } 1 ("NCHW"): shape{ batch, channels, height, width } 2 ("NCHW_VECT_C"): int8 + * shape{ batch, channels / 4, height, width, 4 } optional (default 0) + * + */ +// #if NOT_EXCLUDED(OP_space_to_depth) +@Namespace("sd::ops") public static class space_to_depth extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public space_to_depth(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public space_to_depth(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public space_to_depth position(long position) { + return (space_to_depth)super.position(position); + } - /** - * This operation rearranges blocks of spatial data, into depth.This op output is a copy of the input tensor - * where values from the height and width dimensions are moved to the depth dimension. Int attr 0 indicates - * the input block size. - * - * Input: - * - 4D tensor of given type - * Output: - * - 4D tensor - * - * Int arguments: - * 0 - block size - * 1 - output data format: 0 ("NHWC"): shape{ batch, height, width, channels } - * 1 ("NCHW"): shape{ batch, channels, height, width } - * 2 ("NCHW_VECT_C"): int8 shape{ batch, channels / 4, height, width, 4 } - * optional (default 0) - * - */ -// #if NOT_EXCLUDED(OP_space_to_depth) - @Namespace("sd::ops") public static class space_to_depth extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public space_to_depth(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public space_to_depth(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public space_to_depth position(long position) { - return (space_to_depth)super.position(position); - } - public space_to_depth() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif - /** - * This op calculates cross-product between input arguments - * Input arguments - * 0 - vector or tensor A - * 1 - vector or tensor B - */ -// #if NOT_EXCLUDED(OP_cross) - @Namespace("sd::ops") public static class cross extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cross(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cross(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cross position(long position) { - return (cross)super.position(position); - } - - public cross() { super((Pointer)null); allocate(); } - private native void allocate(); +/** + * This op calculates cross-product between input arguments + * Input arguments + * 0 - vector or tensor A + * 1 - vector or tensor B + */ +// #if NOT_EXCLUDED(OP_cross) +@Namespace("sd::ops") public static class cross extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cross(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cross(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cross position(long position) { + return (cross)super.position(position); + } + + public cross() { super((Pointer)null); allocate(); } + private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Zero-pads and then rearranges (permutes) blocks of spatial data into batch. + * More specifically, this op outputs a copy of the input tensor where values + * from the height and width dimensions are moved to the batch dimension. After + * the zero-padding, both height and width of the input must be divisible by the + * block size. + * + * Inputs: + * 0 - input tensor + * 1 - 2D paddings tensor (shape {M, 2}) + * + * Output: + * - result tensor + * + * Int args: + * 0 - block size (M) + * + */ +// #if NOT_EXCLUDED(OP_space_to_batch) +@Namespace("sd::ops") public static class space_to_batch extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public space_to_batch(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public space_to_batch(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public space_to_batch position(long position) { + return (space_to_batch)super.position(position); + } - /** - * Zero-pads and then rearranges (permutes) blocks of spatial data into batch. More specifically, this op - * outputs a copy of the input tensor where values from the height and width dimensions are moved to the - * batch dimension. After the zero-padding, both height and width of the input must be divisible by the block - * size. - * - * Inputs: - * 0 - input tensor - * 1 - 2D paddings tensor (shape {M, 2}) - * - * Output: - * - result tensor - * - * Int args: - * 0 - block size (M) - * - */ -// #if NOT_EXCLUDED(OP_space_to_batch) - @Namespace("sd::ops") public static class space_to_batch extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public space_to_batch(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public space_to_batch(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public space_to_batch position(long position) { - return (space_to_batch)super.position(position); - } - public space_to_batch() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/* + * This operation divides "spatial" dimensions [1, ..., M] of the input into a + * grid of blocks of shape block_shape, and interleaves these blocks with the + * "batch" dimension (0) such that in the output, the spatial dimensions [1, + * ..., M] correspond to the position within the grid, and the batch dimension + * combines both the position within a spatial block and the original batch + * position. Prior to division into blocks, the spatial dimensions of the input + * are optionally zero padded according to paddings. + * + * Inputs: + * 0 - input (N-D tensor) + * 1 - block_shape - int 1D tensor with M length + * 2 - paddings - int 2D tensor with shape {M, 2} + * + * Output: + * - N-D tensor with the same type as input 0. + * + * */ +// #if NOT_EXCLUDED(OP_space_to_batch_nd) +@Namespace("sd::ops") public static class space_to_batch_nd extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public space_to_batch_nd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public space_to_batch_nd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public space_to_batch_nd position(long position) { + return (space_to_batch_nd)super.position(position); + } - /* - * This operation divides "spatial" dimensions [1, ..., M] of the input into a grid of blocks of shape - * block_shape, and interleaves these blocks with the "batch" dimension (0) such that in the output, - * the spatial dimensions [1, ..., M] correspond to the position within the grid, and the batch dimension - * combines both the position within a spatial block and the original batch position. Prior to division into - * blocks, the spatial dimensions of the input are optionally zero padded according to paddings. - * - * Inputs: - * 0 - input (N-D tensor) - * 1 - block_shape - int 1D tensor with M length - * 2 - paddings - int 2D tensor with shape {M, 2} - * - * Output: - * - N-D tensor with the same type as input 0. - * - * */ -// #if NOT_EXCLUDED(OP_space_to_batch_nd) - @Namespace("sd::ops") public static class space_to_batch_nd extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public space_to_batch_nd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public space_to_batch_nd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public space_to_batch_nd position(long position) { - return (space_to_batch_nd)super.position(position); - } - public space_to_batch_nd() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * + * + */ +// #if NOT_EXCLUDED(OP_batch_to_space) +@Namespace("sd::ops") public static class batch_to_space extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public batch_to_space(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public batch_to_space(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public batch_to_space position(long position) { + return (batch_to_space)super.position(position); + } - /** - * - * - */ -// #if NOT_EXCLUDED(OP_batch_to_space) - @Namespace("sd::ops") public static class batch_to_space extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public batch_to_space(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public batch_to_space(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public batch_to_space position(long position) { - return (batch_to_space)super.position(position); - } - public batch_to_space() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_batch_to_space_nd) - @Namespace("sd::ops") public static class batch_to_space_nd extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public batch_to_space_nd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public batch_to_space_nd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public batch_to_space_nd position(long position) { - return (batch_to_space_nd)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_batch_to_space_nd) +@Namespace("sd::ops") public static class batch_to_space_nd extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public batch_to_space_nd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public batch_to_space_nd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public batch_to_space_nd position(long position) { + return (batch_to_space_nd)super.position(position); + } + public batch_to_space_nd() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * top_k operation returns a vector of k top values for + * given NDArray as tensor with default boolean (true) + * as sort for result index array + * will be sorted by the values in descending order. + * The first parameter is a NDArray for working. + * The second is k (default 1) - optional + * The third is boolean value(default is true) (0 - as is, 1 - sorted by value) + * optional + */ +// #if NOT_EXCLUDED(OP_top_k) +@Namespace("sd::ops") public static class top_k extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public top_k(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public top_k(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public top_k position(long position) { + return (top_k)super.position(position); + } - /** - * top_k operation returns a vector of k top values for - * given NDArray as tensor with default boolean (true) - * as sort for result index array - * will be sorted by the values in descending order. - * The first parameter is a NDArray for working. - * The second is k (default 1) - optional - * The third is boolean value(default is true) (0 - as is, 1 - sorted by value) optional - */ -// #if NOT_EXCLUDED(OP_top_k) - @Namespace("sd::ops") public static class top_k extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public top_k(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public top_k(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public top_k position(long position) { - return (top_k)super.position(position); - } - public top_k() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * in_top_k operation returns a vector of k boolean values for + * given NDArray as 2D matrix of predicted in the NDArray k top values + * The first parameter is a NDArray of predicted values (2d array). + * The second is NDArray as vector of indeces k top values will be search. + * The third is k + */ +// #if NOT_EXCLUDED(OP_in_top_k) +@Namespace("sd::ops") public static class in_top_k extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public in_top_k(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public in_top_k(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public in_top_k position(long position) { + return (in_top_k)super.position(position); + } - /** - * in_top_k operation returns a vector of k boolean values for - * given NDArray as 2D matrix of predicted in the NDArray k top values - * The first parameter is a NDArray of predicted values (2d array). - * The second is NDArray as vector of indeces k top values will be search. - * The third is k - */ -// #if NOT_EXCLUDED(OP_in_top_k) - @Namespace("sd::ops") public static class in_top_k extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public in_top_k(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public in_top_k(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public in_top_k position(long position) { - return (in_top_k)super.position(position); - } - public in_top_k() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * moments operation calculate a mean and variation for given NDArray + * with reduce a result according to axis array given. + * For full axis the result is both mean and variance of all members in array. + * Otherwise there are two NDArrays with means and variances for + * Axes can be put as the second NDArray or as int vector. + * + * the optional flag "keep_dims" can be set as T param + */ +// #if NOT_EXCLUDED(OP_moments) +@Namespace("sd::ops") public static class moments extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public moments(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public moments(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public moments position(long position) { + return (moments)super.position(position); + } - /** - * moments operation calculate a mean and variation for given NDArray - * with reduce a result according to axis array given. - * For full axis the result is both mean and variance of all members in array. - * Otherwise there are two NDArrays with means and variances for - * Axes can be put as the second NDArray or as int vector. - * - * the optional flag "keep_dims" can be set as T param - */ -// #if NOT_EXCLUDED(OP_moments) - @Namespace("sd::ops") public static class moments extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public moments(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public moments(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public moments position(long position) { - return (moments)super.position(position); - } - public moments() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * embedding_lookup - search for submatrices in given matrix and retunts them + * accordingly to index array given. + */ +// #if NOT_EXCLUDED(OP_embedding_lookup) +@Namespace("sd::ops") public static class embedding_lookup extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public embedding_lookup(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public embedding_lookup(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public embedding_lookup position(long position) { + return (embedding_lookup)super.position(position); + } - /** - * embedding_lookup - search for submatrices in given matrix and retunts them - * accordingly to index array given. - */ -// #if NOT_EXCLUDED(OP_embedding_lookup) - @Namespace("sd::ops") public static class embedding_lookup extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public embedding_lookup(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public embedding_lookup(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public embedding_lookup position(long position) { - return (embedding_lookup)super.position(position); - } - public embedding_lookup() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * dynamic_partition - partition a input tensor onto num_partitions + * accordingly to index array given. + * + * the first param - NDArray to be partitioned. + * the second param - index array + * the third param (integer param) - num or partitions. + * + * returns a num of NDArrays as output + */ +// #if NOT_EXCLUDED(OP_dynamic_partition) +@Namespace("sd::ops") public static class dynamic_partition extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dynamic_partition(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dynamic_partition(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dynamic_partition position(long position) { + return (dynamic_partition)super.position(position); + } - /** - * dynamic_partition - partition a input tensor onto num_partitions - * accordingly to index array given. - * - * the first param - NDArray to be partitioned. - * the second param - index array - * the third param (integer param) - num or partitions. - * - * returns a num of NDArrays as output - */ -// #if NOT_EXCLUDED(OP_dynamic_partition) - @Namespace("sd::ops") public static class dynamic_partition extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dynamic_partition(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dynamic_partition(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dynamic_partition position(long position) { - return (dynamic_partition)super.position(position); - } - public dynamic_partition() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_dynamic_partition_bp) +@Namespace("sd::ops") public static class dynamic_partition_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dynamic_partition_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dynamic_partition_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dynamic_partition_bp position(long position) { + return (dynamic_partition_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_dynamic_partition_bp) - @Namespace("sd::ops") public static class dynamic_partition_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dynamic_partition_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dynamic_partition_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dynamic_partition_bp position(long position) { - return (dynamic_partition_bp)super.position(position); - } - public dynamic_partition_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * dynamic_stitch - merge partitions from the second param a input tensor + * into a single tensor accordingly to index array given. + * + * the first param - index array + * the second params - tensors to be merged + * + * returns a num of NDArrays as output + * + * the operation is inversion od dynamic_partition + */ +// #if NOT_EXCLUDED(OP_dynamic_stitch) +@Namespace("sd::ops") public static class dynamic_stitch extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dynamic_stitch(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dynamic_stitch(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dynamic_stitch position(long position) { + return (dynamic_stitch)super.position(position); + } - /** - * dynamic_stitch - merge partitions from the second param a input tensor - * into a single tensor accordingly to index array given. - * - * the first param - index array - * the second params - tensors to be merged - * - * returns a num of NDArrays as output - * - * the operation is inversion od dynamic_partition - */ -// #if NOT_EXCLUDED(OP_dynamic_stitch) - @Namespace("sd::ops") public static class dynamic_stitch extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dynamic_stitch(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dynamic_stitch(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dynamic_stitch position(long position) { - return (dynamic_stitch)super.position(position); - } - public dynamic_stitch() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * zero_fraction op. + * compute a fraction of zeros in given array + * + * input param - an array (tensor) + * output value - a real number with given type (e.g. float or double) + */ +// #if NOT_EXCLUDED(OP_zero_fraction) +@Namespace("sd::ops") public static class zero_fraction extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public zero_fraction(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public zero_fraction(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public zero_fraction position(long position) { + return (zero_fraction)super.position(position); + } - /** - * zero_fraction op. - * compute a fraction of zeros in given array - * - * input param - an array (tensor) - * output value - a real number with given type (e.g. float or double) - */ -// #if NOT_EXCLUDED(OP_zero_fraction) - @Namespace("sd::ops") public static class zero_fraction extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public zero_fraction(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public zero_fraction(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public zero_fraction position(long position) { - return (zero_fraction)super.position(position); - } - public zero_fraction() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif - /** - * xw_plus_b op. - * multiply two first matrices and add third vector to each row of result - * - * input params: - * - 2D matrix NxM - * - 2D matrix MxN - * - 1D vector with N elements - * output value - 2D matrix NxN as multiply of matrixes and add vector - * Int args: - * 0 - optional switcher of weights format, if int arg == 1 - mkldnn, else mmul - */ -// #if NOT_EXCLUDED(OP_xw_plus_b) - @Namespace("sd::ops") public static class xw_plus_b extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public xw_plus_b(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public xw_plus_b(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public xw_plus_b position(long position) { - return (xw_plus_b)super.position(position); - } - - public xw_plus_b() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - @Namespace("sd::ops") public static class xw_plus_b_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public xw_plus_b_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public xw_plus_b_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public xw_plus_b_bp position(long position) { - return (xw_plus_b_bp)super.position(position); - } - - public xw_plus_b_bp() { super((Pointer)null); allocate(); } - private native void allocate(); +/** + * xw_plus_b op. + * multiply two first matrices and add third vector to each row of result + * + * input params: + * - 2D matrix NxM + * - 2D matrix MxN + * - 1D vector with N elements + * output value - 2D matrix NxN as multiply of matrixes and add vector + * Int args: + * 0 - optional switcher of weights format, if int arg == 1 - mkldnn, else + * mmul + */ +// #if NOT_EXCLUDED(OP_xw_plus_b) +@Namespace("sd::ops") public static class xw_plus_b extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public xw_plus_b(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public xw_plus_b(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public xw_plus_b position(long position) { + return (xw_plus_b)super.position(position); + } + + public xw_plus_b() { super((Pointer)null); allocate(); } + private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +@Namespace("sd::ops") public static class xw_plus_b_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public xw_plus_b_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public xw_plus_b_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public xw_plus_b_bp position(long position) { + return (xw_plus_b_bp)super.position(position); + } + + public xw_plus_b_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + +/** + * This operation is missed due it simplicy. + * Input and output params are the same after operation. + * Input - NDArray, output - NDArray with the same shape. + */ +// #if NOT_EXCLUDED(OP_stop_gradient) +@Namespace("sd::ops") public static class stop_gradient extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public stop_gradient(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public stop_gradient(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public stop_gradient position(long position) { + return (stop_gradient)super.position(position); + } - /** - * This operation is missed due it simplicy. - * Input and output params are the same after operation. - * Input - NDArray, output - NDArray with the same shape. - */ -// #if NOT_EXCLUDED(OP_stop_gradient) - @Namespace("sd::ops") public static class stop_gradient extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public stop_gradient(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public stop_gradient(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public stop_gradient position(long position) { - return (stop_gradient)super.position(position); - } - public stop_gradient() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_parallel_stack) +@Namespace("sd::ops") public static class parallel_stack extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public parallel_stack(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public parallel_stack(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public parallel_stack position(long position) { + return (parallel_stack)super.position(position); + } -// #if NOT_EXCLUDED(OP_parallel_stack) - @Namespace("sd::ops") public static class parallel_stack extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public parallel_stack(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public parallel_stack(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public parallel_stack position(long position) { - return (parallel_stack)super.position(position); - } - public parallel_stack() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * normalize_moments operation normalize already calculated mean and variation + * accordingly to shift and count. + * input params: + * - count of data + * - tensor with mean + * - tensor with variance (the same shape as before) + * + * - optional floating point param shift. + * + * returns a normalized pair mean and variance with the same shapes as input + */ +// #if NOT_EXCLUDED(OP_normalize_moments) +@Namespace("sd::ops") public static class normalize_moments extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public normalize_moments(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public normalize_moments(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public normalize_moments position(long position) { + return (normalize_moments)super.position(position); + } - /** - * normalize_moments operation normalize already calculated mean and variation - * accordingly to shift and count. - * input params: - * - count of data - * - tensor with mean - * - tensor with variance (the same shape as before) - * - * - optional floating point param shift. - * - * returns a normalized pair mean and variance with the same shapes as input - */ -// #if NOT_EXCLUDED(OP_normalize_moments) - @Namespace("sd::ops") public static class normalize_moments extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public normalize_moments(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public normalize_moments(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public normalize_moments position(long position) { - return (normalize_moments)super.position(position); - } - public normalize_moments() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * sufficient_statistics operation return calculated mean and variation with + * data count. this operation is invert for moments accordingly to shift and + * count. input params: + * - input tensor + * - axes vector + * + * + * - optional floating point param shift. + * - optional int (as bool) keep_dimension + * + * returns four tensors: + * - scalar tensor (data count) + * - sum elements of input (accross axises) + * - sum of squares of input (accross axises) + * - shift (if was given by input floating param) + */ +// #if NOT_EXCLUDED(OP_sufficient_statistics) +@Namespace("sd::ops") public static class sufficient_statistics extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sufficient_statistics(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sufficient_statistics(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sufficient_statistics position(long position) { + return (sufficient_statistics)super.position(position); + } - /** - * sufficient_statistics operation return calculated mean and variation with data count. - * this operation is invert for moments - * accordingly to shift and count. - * input params: - * - input tensor - * - axes vector - * - * - * - optional floating point param shift. - * - optional int (as bool) keep_dimension - * - * returns four tensors: - * - scalar tensor (data count) - * - sum elements of input (accross axises) - * - sum of squares of input (accross axises) - * - shift (if was given by input floating param) - */ -// #if NOT_EXCLUDED(OP_sufficient_statistics) - @Namespace("sd::ops") public static class sufficient_statistics extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sufficient_statistics(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sufficient_statistics(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sufficient_statistics position(long position) { - return (sufficient_statistics)super.position(position); - } - public sufficient_statistics() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates weighted logarithmic loss of input + * Input arguments + * 0 - target + * 1 - input + * 2 - weights (scalar or vector with same as last dimension) + * + * return value - a tensor with the same shape as target or input + */ +// #if NOT_EXCLUDED(OP_weighted_cross_entropy_with_logits) +@Namespace("sd::ops") public static class weighted_cross_entropy_with_logits extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public weighted_cross_entropy_with_logits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public weighted_cross_entropy_with_logits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public weighted_cross_entropy_with_logits position(long position) { + return (weighted_cross_entropy_with_logits)super.position(position); + } - /** - * This op calculates weighted logarithmic loss of input - * Input arguments - * 0 - target - * 1 - input - * 2 - weights (scalar or vector with same as last dimension) - * - * return value - a tensor with the same shape as target or input - */ -// #if NOT_EXCLUDED(OP_weighted_cross_entropy_with_logits) - @Namespace("sd::ops") public static class weighted_cross_entropy_with_logits extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public weighted_cross_entropy_with_logits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public weighted_cross_entropy_with_logits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public weighted_cross_entropy_with_logits position(long position) { - return (weighted_cross_entropy_with_logits)super.position(position); - } - public weighted_cross_entropy_with_logits() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates dropout of input + * Input arguments + * 0 - input tensor + * 1 - noise_shape - (vector with shape to reduce) - optional + * + * int parameter - seed for random numbers + * T parameter - probability (should be between 0 and 1) + * return value - a tensor with the same shape as target or input + */ +// #if NOT_EXCLUDED(OP_dropout) +@Namespace("sd::ops") public static class dropout extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dropout(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dropout(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dropout position(long position) { + return (dropout)super.position(position); + } - /** - * This op calculates dropout of input - * Input arguments - * 0 - input tensor - * 1 - noise_shape - (vector with shape to reduce) - optional - * - * int parameter - seed for random numbers - * T parameter - probability (should be between 0 and 1) - * return value - a tensor with the same shape as target or input - */ -// #if NOT_EXCLUDED(OP_dropout) - @Namespace("sd::ops") public static class dropout extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dropout(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dropout(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dropout position(long position) { - return (dropout)super.position(position); - } - public dropout() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_dropout_bp) - @Namespace("sd::ops") public static class dropout_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dropout_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dropout_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dropout_bp position(long position) { - return (dropout_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_dropout_bp) +@Namespace("sd::ops") public static class dropout_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dropout_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dropout_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dropout_bp position(long position) { + return (dropout_bp)super.position(position); + } + public dropout_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/* Calculates alpha weighted dropout + T params: + 0 - drop probability + 1 - alpha value + 2 - alpha' value + 3 - beta value + */ +// #if NOT_EXCLUDED(OP_alpha_dropout_bp) +@Namespace("sd::ops") public static class alpha_dropout_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public alpha_dropout_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public alpha_dropout_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public alpha_dropout_bp position(long position) { + return (alpha_dropout_bp)super.position(position); + } - /* Calculates alpha weighted dropout - T params: - 0 - drop probability - 1 - alpha value - 2 - alpha' value - 3 - beta value - */ -// #if NOT_EXCLUDED(OP_alpha_dropout_bp) - @Namespace("sd::ops") public static class alpha_dropout_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public alpha_dropout_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public alpha_dropout_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public alpha_dropout_bp position(long position) { - return (alpha_dropout_bp)super.position(position); - } - public alpha_dropout_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * bincount operation return a vector with element counted. + * + * input params: + * - input tensor - only int part are accepted + * - weights - the same shape tensor with integer weights for element + * (optional) default weight - 1,1,1..,1 for all values in the tensor + * + * optional ints: + * - min_length - zero or greater + * - max_length - between min_length and max(input) + 1 + * + * returns four tensors: + * - vector tensor with length to min(max_len, max(input) + 1) with count + * of values in indexed place + * + */ +// #if NOT_EXCLUDED(OP_bincount) +@Namespace("sd::ops") public static class bincount extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bincount(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bincount(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bincount position(long position) { + return (bincount)super.position(position); + } - /** - * bincount operation return a vector with element counted. - * - * input params: - * - input tensor - only int part are accepted - * - weights - the same shape tensor with integer weights for element (optional) - * default weight - 1,1,1..,1 for all values in the tensor - * - * optional ints: - * - min_length - zero or greater - * - max_length - between min_length and max(input) + 1 - * - * returns four tensors: - * - vector tensor with length to min(max_len, max(input) + 1) with count - * of values in indexed place - * - */ -// #if NOT_EXCLUDED(OP_bincount) - @Namespace("sd::ops") public static class bincount extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public bincount(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public bincount(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public bincount position(long position) { - return (bincount)super.position(position); - } - public bincount() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * broadcast_dynamic_shape op. + * + * input params: + * 0 - the first shape (vector with shape) + * 1 - the second shape (vector with shape) + * + * return value: + * vector with broadcasted shape + */ +// #if NOT_EXCLUDED(OP_broadcast_dynamic_shape) +@Namespace("sd::ops") public static class broadcast_dynamic_shape extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public broadcast_dynamic_shape(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public broadcast_dynamic_shape(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public broadcast_dynamic_shape position(long position) { + return (broadcast_dynamic_shape)super.position(position); + } - /** - * broadcast_dynamic_shape op. - * - * input params: - * 0 - the first shape (vector with shape) - * 1 - the second shape (vector with shape) - * - * return value: - * vector with broadcasted shape - */ -// #if NOT_EXCLUDED(OP_broadcast_dynamic_shape) - @Namespace("sd::ops") public static class broadcast_dynamic_shape extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public broadcast_dynamic_shape(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public broadcast_dynamic_shape(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public broadcast_dynamic_shape position(long position) { - return (broadcast_dynamic_shape)super.position(position); - } - public broadcast_dynamic_shape() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * matrix_determinant op. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: *) with determinant for all + * M x M matricies + */ +// #if NOT_EXCLUDED(OP_matrix_determinant) +@Namespace("sd::ops") public static class matrix_determinant extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matrix_determinant(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matrix_determinant(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matrix_determinant position(long position) { + return (matrix_determinant)super.position(position); + } - /** - * matrix_determinant op. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: *) with determinant for all - * M x M matricies - */ -// #if NOT_EXCLUDED(OP_matrix_determinant) - @Namespace("sd::ops") public static class matrix_determinant extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matrix_determinant(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matrix_determinant(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matrix_determinant position(long position) { - return (matrix_determinant)super.position(position); - } - public matrix_determinant() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif - /** - * log_matrix_determinant op. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: *) with log determinant for all - * M x M matricies - */ +/** + * log_matrix_determinant op. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: *) with log determinant for all + * M x M matricies + */ -// #if NOT_EXCLUDED(OP_log_matrix_determinant) - @Namespace("sd::ops") public static class log_matrix_determinant extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public log_matrix_determinant(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public log_matrix_determinant(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public log_matrix_determinant position(long position) { - return (log_matrix_determinant)super.position(position); - } - - public log_matrix_determinant() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } -// #endif +// #if NOT_EXCLUDED(OP_log_matrix_determinant) +@Namespace("sd::ops") public static class log_matrix_determinant extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public log_matrix_determinant(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public log_matrix_determinant(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public log_matrix_determinant position(long position) { + return (log_matrix_determinant)super.position(position); + } - /** - * logdet op. Logarithm of the determinant of hermitian positive matricies. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: *) with log determinant for all - * M x M matricies - */ + public log_matrix_determinant() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + +/** + * logdet op. Logarithm of the determinant of hermitian positive matricies. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: *) with log determinant for all + * M x M matricies + */ + +// #if NOT_EXCLUDED(OP_logdet) +@Namespace("sd::ops") public static class logdet extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public logdet(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public logdet(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public logdet position(long position) { + return (logdet)super.position(position); + } -// #if NOT_EXCLUDED(OP_logdet) - @Namespace("sd::ops") public static class logdet extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public logdet(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public logdet(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public logdet position(long position) { - return (logdet)super.position(position); - } - public logdet() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * matrix_solve_ls op (lstsq) - solves one or more linear least-squares + * problems. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * float args: + * 0 - l2_regularizer (default 0. and only for 0 implemented) + * + * boolean args: + * 0 - fast - default is true (optional) - use Cholesky decomposition instead + * QR decomposition of matricies. + * + * return value: + * tensor with dimension (x * y * z * ::: * N * K) with solutions + * + */ +// #if NOT_EXCLUDED(OP_lstsq) +@Namespace("sd::ops") public static class lstsq extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstsq(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstsq(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstsq position(long position) { + return (lstsq)super.position(position); + } - /** - * matrix_solve_ls op (lstsq) - solves one or more linear least-squares problems. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * float args: - * 0 - l2_regularizer (default 0. and only for 0 implemented) - * - * boolean args: - * 0 - fast - default is true (optional) - use Cholesky decomposition instead QR decomposition of matricies. - * - * return value: - * tensor with dimension (x * y * z * ::: * N * K) with solutions - * - */ -// #if NOT_EXCLUDED(OP_lstsq) - @Namespace("sd::ops") public static class lstsq extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lstsq(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lstsq(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lstsq position(long position) { - return (lstsq)super.position(position); - } - public lstsq() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/* solve_ls - analog of lstsq op with another solution approach + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * float args: + * 0 - l2_regularizer (default 0. and only for 0 implemented) + * + * boolean args: + * 0 - fast - default is true (optional) - use Cholesky decomposition instead + * QR decomposition of matricies. + * + * return value: + * tensor with dimension (x * y * z * ::: * N * K) with solutions + * + * Note: if fast is false - then l2_regularizer arg is ignored and used lstsq + * method due QR decomposition + * */ +// #if NOT_EXCLUDED(OP_solve_ls) +@Namespace("sd::ops") public static class solve_ls extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public solve_ls(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public solve_ls(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public solve_ls position(long position) { + return (solve_ls)super.position(position); + } - /* solve_ls - analog of lstsq op with another solution approach - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * N) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * float args: - * 0 - l2_regularizer (default 0. and only for 0 implemented) - * - * boolean args: - * 0 - fast - default is true (optional) - use Cholesky decomposition instead QR decomposition of matricies. - * - * return value: - * tensor with dimension (x * y * z * ::: * N * K) with solutions - * - * Note: if fast is false - then l2_regularizer arg is ignored and used lstsq method due QR decomposition - * */ -// #if NOT_EXCLUDED(OP_solve_ls) - @Namespace("sd::ops") public static class solve_ls extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public solve_ls(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public solve_ls(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public solve_ls position(long position) { - return (solve_ls)super.position(position); - } - public solve_ls() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * matrix_inverse op. - make inverse for all 2D square matricies found in the + * input tensor + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * tensor with dimension (x * y * z * ::: * M * M) with inverse M x M + * matricies in it + */ +// #if NOT_EXCLUDED(OP_matrix_inverse) +@Namespace("sd::ops") public static class matrix_inverse extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matrix_inverse(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matrix_inverse(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matrix_inverse position(long position) { + return (matrix_inverse)super.position(position); + } - /** - * matrix_inverse op. - make inverse for all 2D square matricies found in the input tensor - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * tensor with dimension (x * y * z * ::: * M * M) with inverse M x M matricies in it - */ -// #if NOT_EXCLUDED(OP_matrix_inverse) - @Namespace("sd::ops") public static class matrix_inverse extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matrix_inverse(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matrix_inverse(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matrix_inverse position(long position) { - return (matrix_inverse)super.position(position); - } - public matrix_inverse() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * triangular_solve op. - reverse Gaussian method for solve systems of linear + * equations. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * boolean args: + * 0 - lower - default is true (optional) - left part is lower triangular + * matrix 1 - adjoint - default is false (optional) - indicate input matrix or + * its adjoint (hermitian addition) should be used + * + * return value: + * tensor with dimension (x * y * z * ::: * M * K) with solutions + * + */ +// #if NOT_EXCLUDED(OP_triangular_solve) +@Namespace("sd::ops") public static class triangular_solve extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public triangular_solve(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public triangular_solve(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public triangular_solve position(long position) { + return (triangular_solve)super.position(position); + } - /** - * triangular_solve op. - reverse Gaussian method for solve systems of linear equations. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * boolean args: - * 0 - lower - default is true (optional) - left part is lower triangular matrix - * 1 - adjoint - default is false (optional) - indicate input matrix or its adjoint (hermitian addition) should be used - * - * return value: - * tensor with dimension (x * y * z * ::: * M * K) with solutions - * - */ -// #if NOT_EXCLUDED(OP_triangular_solve) - @Namespace("sd::ops") public static class triangular_solve extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public triangular_solve(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public triangular_solve(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public triangular_solve position(long position) { - return (triangular_solve)super.position(position); - } - public triangular_solve() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * solve op. - solve systems of linear equations - general method. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of + * equations 1 - the tensor with dimension (x * y * z * ::: * M * K) - right + * parts of equations + * + * boolean args: + * 0 - adjoint - default is false (optional) - indicate input matrix or its + * adjoint (hermitian addition) should be used + * + * return value: + * tensor with dimension (x * y * z * ::: * M * K) with solutions + * + */ +// #if NOT_EXCLUDED(OP_solve) +@Namespace("sd::ops") public static class solve extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public solve(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public solve(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public solve position(long position) { + return (solve)super.position(position); + } - /** - * solve op. - solve systems of linear equations - general method. - * - * input params: - * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations - * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations - * - * boolean args: - * 0 - adjoint - default is false (optional) - indicate input matrix or its adjoint (hermitian addition) should be used - * - * return value: - * tensor with dimension (x * y * z * ::: * M * K) with solutions - * - */ -// #if NOT_EXCLUDED(OP_solve) - @Namespace("sd::ops") public static class solve extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public solve(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public solve(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public solve position(long position) { - return (solve)super.position(position); - } - public solve() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif - /** - * lu op. - make LUP decomposition of given batch of 2D square matricies - * - * input params: - * 0 - float tensor with dimension (x * y * z * ::: * M * M) - * - * return value: - * 0 - float tensor with dimension (x * y * z * ::: * M * M) with LU M x M matricies in it - * 1 - int (32 or 64) batched vector of permutations with length M - shape (x * y * z * ::: * M) - * - * int argument: - * 0 - data type of output permutaion vector (int32 or int64), optional, default INT32 - */ +/** + * lu op. - make LUP decomposition of given batch of 2D square matricies + * + * input params: + * 0 - float tensor with dimension (x * y * z * ::: * M * M) + * + * return value: + * 0 - float tensor with dimension (x * y * z * ::: * M * M) with LU M x M + * matricies in it 1 - int (32 or 64) batched vector of permutations with length + * M - shape (x * y * z * ::: * M) + * + * int argument: + * 0 - data type of output permutaion vector (int32 or int64), optional, + * default INT32 + */ + +// #if NOT_EXCLUDED(OP_matrix_inverse) +@Namespace("sd::ops") public static class lu extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lu(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lu(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lu position(long position) { + return (lu)super.position(position); + } -// #if NOT_EXCLUDED(OP_matrix_inverse) - @Namespace("sd::ops") public static class lu extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lu(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lu(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lu position(long position) { - return (lu)super.position(position); - } - public lu() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * sequence_mask op. - make mask for given tensor filled by (j > x[i_1, + * i_2,...,i_n]) -> z[i_1, i_2,...,i_n,j] + * + * input params: + * 0 - the ND-tensor filled by integer-like values + * + * optional int param - maxlength (maxlength >= max(x)). By default maxlength = + * max(x). return value: (N+1)D tensor filled by 0 and 1 accordingly the mask + */ +// #if NOT_EXCLUDED(OP_sequence_mask) +@Namespace("sd::ops") public static class sequence_mask extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sequence_mask(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sequence_mask(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sequence_mask position(long position) { + return (sequence_mask)super.position(position); + } - /** - * sequence_mask op. - make mask for given tensor filled by (j > x[i_1, i_2,...,i_n]) -> z[i_1, i_2,...,i_n,j] - * - * input params: - * 0 - the ND-tensor filled by integer-like values - * - * optional int param - maxlength (maxlength >= max(x)). By default maxlength = max(x). - * return value: - * (N+1)D tensor filled by 0 and 1 accordingly the mask - */ -// #if NOT_EXCLUDED(OP_sequence_mask) - @Namespace("sd::ops") public static class sequence_mask extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sequence_mask(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sequence_mask(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sequence_mask position(long position) { - return (sequence_mask)super.position(position); - } - public sequence_mask() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - /** - * segment_max op. - make a tensor filled by max values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with max values according to indices sets. - */ +// #endif +/** + * segment_max op. - make a tensor filled by max values according to index + * tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with max values according to indices sets. + */ + +// #if NOT_EXCLUDED(OP_segment_max) +@Namespace("sd::ops") public static class segment_max extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_max(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_max(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_max position(long position) { + return (segment_max)super.position(position); + } -// #if NOT_EXCLUDED(OP_segment_max) - @Namespace("sd::ops") public static class segment_max extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_max(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_max(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_max position(long position) { - return (segment_max)super.position(position); - } - public segment_max() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_segment_max_bp) - @Namespace("sd::ops") public static class segment_max_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_max_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_max_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_max_bp position(long position) { - return (segment_max_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_segment_max_bp) +@Namespace("sd::ops") public static class segment_max_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_max_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_max_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_max_bp position(long position) { + return (segment_max_bp)super.position(position); + } + public segment_max_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * segment_min op. - make a tensor filled by min values according to index + * tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with min values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_segment_min) +@Namespace("sd::ops") public static class segment_min extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_min(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_min(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_min position(long position) { + return (segment_min)super.position(position); + } - /** - * segment_min op. - make a tensor filled by min values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with min values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_segment_min) - @Namespace("sd::ops") public static class segment_min extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_min(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_min(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_min position(long position) { - return (segment_min)super.position(position); - } - public segment_min() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_segment_min_bp) - @Namespace("sd::ops") public static class segment_min_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_min_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_min_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_min_bp position(long position) { - return (segment_min_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_segment_min_bp) +@Namespace("sd::ops") public static class segment_min_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_min_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_min_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_min_bp position(long position) { + return (segment_min_bp)super.position(position); + } + public segment_min_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * segment_sum op. - make a tensor filled by sum of values according to index + * tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with sum of values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_segment_sum) +@Namespace("sd::ops") public static class segment_sum extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_sum(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_sum(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_sum position(long position) { + return (segment_sum)super.position(position); + } - /** - * segment_sum op. - make a tensor filled by sum of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with sum of values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_segment_sum) - @Namespace("sd::ops") public static class segment_sum extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_sum(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_sum(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_sum position(long position) { - return (segment_sum)super.position(position); - } - public segment_sum() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_segment_sum_bp) - @Namespace("sd::ops") public static class segment_sum_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_sum_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_sum_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_sum_bp position(long position) { - return (segment_sum_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_segment_sum_bp) +@Namespace("sd::ops") public static class segment_sum_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_sum_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_sum_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_sum_bp position(long position) { + return (segment_sum_bp)super.position(position); + } + public segment_sum_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * segment_prod op. - make a tensor filled by product of values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with product of values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_segment_prod) +@Namespace("sd::ops") public static class segment_prod extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_prod(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_prod(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_prod position(long position) { + return (segment_prod)super.position(position); + } - /** - * segment_prod op. - make a tensor filled by product of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with product of values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_segment_prod) - @Namespace("sd::ops") public static class segment_prod extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_prod(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_prod(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_prod position(long position) { - return (segment_prod)super.position(position); - } - public segment_prod() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_segment_prod_bp) - @Namespace("sd::ops") public static class segment_prod_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_prod_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_prod_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_prod_bp position(long position) { - return (segment_prod_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_segment_prod_bp) +@Namespace("sd::ops") public static class segment_prod_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_prod_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_prod_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_prod_bp position(long position) { + return (segment_prod_bp)super.position(position); + } + public segment_prod_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - /** - * segment_mean op. - make a tensor filled by average of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with average of values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_segment_mean) - @Namespace("sd::ops") public static class segment_mean extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_mean(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_mean(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_mean position(long position) { - return (segment_mean)super.position(position); - } - +// #endif +/** + * segment_mean op. - make a tensor filled by average of values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with average of values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_segment_mean) +@Namespace("sd::ops") public static class segment_mean extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_mean(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_mean(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_mean position(long position) { + return (segment_mean)super.position(position); + } + public segment_mean() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_segment_mean_bp) - @Namespace("sd::ops") public static class segment_mean_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public segment_mean_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public segment_mean_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public segment_mean_bp position(long position) { - return (segment_mean_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_segment_mean_bp) +@Namespace("sd::ops") public static class segment_mean_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public segment_mean_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public segment_mean_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public segment_mean_bp position(long position) { + return (segment_mean_bp)super.position(position); + } + public segment_mean_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * unsorted_segment_max op. - make a tensor filled by max values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * return value: + * tensor with max values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_unsorted_segment_max) +@Namespace("sd::ops") public static class unsorted_segment_max extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_max(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_max(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_max position(long position) { + return (unsorted_segment_max)super.position(position); + } - /** - * unsorted_segment_max op. - make a tensor filled by max values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * return value: - * tensor with max values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_unsorted_segment_max) - @Namespace("sd::ops") public static class unsorted_segment_max extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_max(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_max(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_max position(long position) { - return (unsorted_segment_max)super.position(position); - } - public unsorted_segment_max() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_unsorted_segment_max_bp) - @Namespace("sd::ops") public static class unsorted_segment_max_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_max_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_max_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_max_bp position(long position) { - return (unsorted_segment_max_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_unsorted_segment_max_bp) +@Namespace("sd::ops") public static class unsorted_segment_max_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_max_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_max_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_max_bp position(long position) { + return (unsorted_segment_max_bp)super.position(position); + } + public unsorted_segment_max_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * unsorted_segment_min op. - make a tensor filled by min values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with min values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_unsorted_segment_min_bp) +@Namespace("sd::ops") public static class unsorted_segment_min extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_min(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_min(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_min position(long position) { + return (unsorted_segment_min)super.position(position); + } - /** - * unsorted_segment_min op. - make a tensor filled by min values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with min values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_unsorted_segment_min_bp) - @Namespace("sd::ops") public static class unsorted_segment_min extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_min(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_min(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_min position(long position) { - return (unsorted_segment_min)super.position(position); - } - public unsorted_segment_min() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_unsorted_segment_min_bp) - @Namespace("sd::ops") public static class unsorted_segment_min_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_min_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_min_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_min_bp position(long position) { - return (unsorted_segment_min_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_unsorted_segment_min_bp) +@Namespace("sd::ops") public static class unsorted_segment_min_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_min_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_min_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_min_bp position(long position) { + return (unsorted_segment_min_bp)super.position(position); + } + public unsorted_segment_min_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * unsorted_segment_sum op. - make a tensor filled by sum of values according to + * index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with sum of values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_unsorted_segment_sum) +@Namespace("sd::ops") public static class unsorted_segment_sum extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_sum(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_sum(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_sum position(long position) { + return (unsorted_segment_sum)super.position(position); + } - /** - * unsorted_segment_sum op. - make a tensor filled by sum of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with sum of values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_unsorted_segment_sum) - @Namespace("sd::ops") public static class unsorted_segment_sum extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_sum(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_sum(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_sum position(long position) { - return (unsorted_segment_sum)super.position(position); - } - public unsorted_segment_sum() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_unsorted_segment_sum_bp) - @Namespace("sd::ops") public static class unsorted_segment_sum_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_sum_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_sum_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_sum_bp position(long position) { - return (unsorted_segment_sum_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_unsorted_segment_sum_bp) +@Namespace("sd::ops") public static class unsorted_segment_sum_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_sum_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_sum_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_sum_bp position(long position) { + return (unsorted_segment_sum_bp)super.position(position); + } + public unsorted_segment_sum_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * unsorted_segment_prod op. - make a tensor filled by product of values + * according to index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with product of values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_unsorted_segment_prod) +@Namespace("sd::ops") public static class unsorted_segment_prod extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_prod(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_prod(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_prod position(long position) { + return (unsorted_segment_prod)super.position(position); + } + + public unsorted_segment_prod() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif +// #if NOT_EXCLUDED(OP_unsorted_segment_prod_bp) +@Namespace("sd::ops") public static class unsorted_segment_prod_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_prod_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_prod_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_prod_bp position(long position) { + return (unsorted_segment_prod_bp)super.position(position); + } - /** - * unsorted_segment_prod op. - make a tensor filled by product of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with product of values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_unsorted_segment_prod) - @Namespace("sd::ops") public static class unsorted_segment_prod extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_prod(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_prod(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_prod position(long position) { - return (unsorted_segment_prod)super.position(position); - } - - public unsorted_segment_prod() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } -// #endif -// #if NOT_EXCLUDED(OP_unsorted_segment_prod_bp) - @Namespace("sd::ops") public static class unsorted_segment_prod_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_prod_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_prod_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_prod_bp position(long position) { - return (unsorted_segment_prod_bp)super.position(position); - } - public unsorted_segment_prod_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * unsorted_segment_mean op. - make a tensor filled by average of values + * according to index tensor given. + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with average of values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_unsorted_segment_mean) +@Namespace("sd::ops") public static class unsorted_segment_mean extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_mean(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_mean(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_mean position(long position) { + return (unsorted_segment_mean)super.position(position); + } - /** - * unsorted_segment_mean op. - make a tensor filled by average of values according to index tensor given. - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with average of values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_unsorted_segment_mean) - @Namespace("sd::ops") public static class unsorted_segment_mean extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_mean(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_mean(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_mean position(long position) { - return (unsorted_segment_mean)super.position(position); - } - public unsorted_segment_mean() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_unsorted_segment_mean_bp) - @Namespace("sd::ops") public static class unsorted_segment_mean_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_mean_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_mean_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_mean_bp position(long position) { - return (unsorted_segment_mean_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_unsorted_segment_mean_bp) +@Namespace("sd::ops") public static class unsorted_segment_mean_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_mean_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_mean_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_mean_bp position(long position) { + return (unsorted_segment_mean_bp)super.position(position); + } + public unsorted_segment_mean_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * unsorted_segment_sqrt_n op. - computes the sum along segments of a tensor + * divided by the sqrt(N). + * + * input params: + * 0 - the tensor with data; + * 1 - the tensor with indices. + * + * integer param: + * 0 - num of segments + * + * return value: + * tensor with average of values according to indices sets. + */ +// #if NOT_EXCLUDED(OP_unsorted_segment_sqrt) +@Namespace("sd::ops") public static class unsorted_segment_sqrt_n extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_sqrt_n(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_sqrt_n(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_sqrt_n position(long position) { + return (unsorted_segment_sqrt_n)super.position(position); + } - /** - * unsorted_segment_sqrt_n op. - computes the sum along segments of a tensor divided by the sqrt(N). - * - * input params: - * 0 - the tensor with data; - * 1 - the tensor with indices. - * - * integer param: - * 0 - num of segments - * - * return value: - * tensor with average of values according to indices sets. - */ -// #if NOT_EXCLUDED(OP_unsorted_segment_sqrt) - @Namespace("sd::ops") public static class unsorted_segment_sqrt_n extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_sqrt_n(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_sqrt_n(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_sqrt_n position(long position) { - return (unsorted_segment_sqrt_n)super.position(position); - } - public unsorted_segment_sqrt_n() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_unsorted_segment_sqrt_n_bp) - @Namespace("sd::ops") public static class unsorted_segment_sqrt_n_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public unsorted_segment_sqrt_n_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public unsorted_segment_sqrt_n_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public unsorted_segment_sqrt_n_bp position(long position) { - return (unsorted_segment_sqrt_n_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_unsorted_segment_sqrt_n_bp) +@Namespace("sd::ops") public static class unsorted_segment_sqrt_n_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public unsorted_segment_sqrt_n_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public unsorted_segment_sqrt_n_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public unsorted_segment_sqrt_n_bp position(long position) { + return (unsorted_segment_sqrt_n_bp)super.position(position); + } + public unsorted_segment_sqrt_n_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * extract_image_patches op - Extract patches from images and put them in the + * "depth" output dimension. + * + * input params: + * 0 - images tensor (4D) + * + * int params: + * 0 - ksize_rows + * 1 - ksize_cols + * 2 - strides_rows + * 3 - strides_cols + * 4 - rates_rows + * 5 - rates_cols + * 6 - padding_type - 0 - equiv 'VALID', 1 - 'SAME' + */ +// #if NOT_EXCLUDED(OP_extract_image_patches) +@Namespace("sd::ops") public static class extract_image_patches extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public extract_image_patches(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public extract_image_patches(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public extract_image_patches position(long position) { + return (extract_image_patches)super.position(position); + } - /** - * extract_image_patches op - Extract patches from images and put them in the "depth" output dimension. - * - * input params: - * 0 - images tensor (4D) - * - * int params: - * 0 - ksize_rows - * 1 - ksize_cols - * 2 - strides_rows - * 3 - strides_cols - * 4 - rates_rows - * 5 - rates_cols - * 6 - padding_type - 0 - equiv 'VALID', 1 - 'SAME' - */ -// #if NOT_EXCLUDED(OP_extract_image_patches) - @Namespace("sd::ops") public static class extract_image_patches extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public extract_image_patches(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public extract_image_patches(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public extract_image_patches position(long position) { - return (extract_image_patches)super.position(position); - } - public extract_image_patches() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * draw_bounding_boxes op - modified input image with given colors exept given + * boxes. + * + * input params: + * 0 - images tensor (4D) with shape {batch, width, height, channels}, where + * channes is 1 (BW image), 3 (RGB) or 4 (RGBA) 1 - boxes tensor (3D) with shape + * {batch, number_of_boxes, 4} where last dimension encoded as (y_min, x_min, + * y_max, x_max), all values in between 0. and 1. 2 - colours tensor (2D) with + * shape {number_of_boxes, channels} -- bordering color set (palette) + * + * output: + * 0 - 4D tensor with same shape as images (input 0) + */ +// #if NOT_EXCLUDED(OP_draw_bounding_boxes) +@Namespace("sd::ops") public static class draw_bounding_boxes extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public draw_bounding_boxes(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public draw_bounding_boxes(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public draw_bounding_boxes position(long position) { + return (draw_bounding_boxes)super.position(position); + } - /** - * draw_bounding_boxes op - modified input image with given colors exept given boxes. - * - * input params: - * 0 - images tensor (4D) with shape {batch, width, height, channels}, where channes is 1 (BW image), - * 3 (RGB) or 4 (RGBA) - * 1 - boxes tensor (3D) with shape {batch, number_of_boxes, 4} where last dimension encoded as - * (y_min, x_min, y_max, x_max), all values in between 0. and 1. - * 2 - colours tensor (2D) with shape {number_of_boxes, channels} -- bordering color set (palette) - * - * output: - * 0 - 4D tensor with same shape as images (input 0) - */ -// #if NOT_EXCLUDED(OP_draw_bounding_boxes) - @Namespace("sd::ops") public static class draw_bounding_boxes extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public draw_bounding_boxes(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public draw_bounding_boxes(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public draw_bounding_boxes position(long position) { - return (draw_bounding_boxes)super.position(position); - } - public draw_bounding_boxes() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * roll - op porting from numpy + * (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html) + * + * input params: + * 0 - NDArray + * + * int params: + * 0 - shift + * 1 - axe 1 + * 2 - axe 2 + * ... + * N - axe N + * + * All axes are optional and should be between 0 and input->rankOf(). Of + * course, all axes can be repeated. + * + * output: + * 0 - NDArray with the same shape as input. + */ +// #if NOT_EXCLUDED(OP_roll) +@Namespace("sd::ops") public static class roll extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public roll(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public roll(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public roll position(long position) { + return (roll)super.position(position); + } - /** - * roll - op porting from numpy (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html) - * - * input params: - * 0 - NDArray - * - * int params: - * 0 - shift - * 1 - axe 1 - * 2 - axe 2 - * ... - * N - axe N - * - * All axes are optional and should be between 0 and input->rankOf(). Of course, all axes can be repeated. - * - * output: - * 0 - NDArray with the same shape as input. - */ -// #if NOT_EXCLUDED(OP_roll) - @Namespace("sd::ops") public static class roll extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public roll(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public roll(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public roll position(long position) { - return (roll)super.position(position); - } - public roll() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * lin_space - op porting from TF + * (https://www.tensorflow.org/api_docs/python/tf/lin_space) + * + * optional input params: + * 0 - startVal - NDArray scalar (float point) + * 1 - finishVal - NDArray scalar (float point) + * 2 - numOfElements - NDArray scalar (integer) + * Optional: + * T args + * 0 - startVal + * 1 - finishVal] + * 2 - numOfElements + * output: + * 0 - 1D NDArray with the same type as input and length as given with + * numOfElements param. + */ +// #if NOT_EXCLUDED(OP_lin_space) +@Namespace("sd::ops") public static class lin_space extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lin_space(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lin_space(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lin_space position(long position) { + return (lin_space)super.position(position); + } - /** - * lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space) - * - * optional input params: - * 0 - startVal - NDArray scalar (float point) - * 1 - finishVal - NDArray scalar (float point) - * 2 - numOfElements - NDArray scalar (integer) - * Optional: - * T args - * 0 - startVal - * 1 - finishVal] - * 2 - numOfElements - * output: - * 0 - 1D NDArray with the same type as input and length as given with numOfElements param. - */ -// #if NOT_EXCLUDED(OP_lin_space) - @Namespace("sd::ops") public static class lin_space extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lin_space(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lin_space(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lin_space position(long position) { - return (lin_space)super.position(position); - } - public lin_space() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * reduction_sum - tf.reduction_sum operation + * + * input params: + * 0 - NDArray + * + * T_ARG param (optional): + * 0 - keep_dims != 0. + * + * int params (optional): + * 0 - axe 1 + * 1 - axe 2 + * ... + * N-1 axe N + * + * All axes are optional and should be between 0 and input->rankOf() - 1 + * + * output: + * 0 - NDArray with reduces shape accordingly to axes (the scalar in default + * case). + */ +// #if NOT_EXCLUDED(OP_reduce_sum) +@Namespace("sd::ops") public static class reduce_sum extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_sum(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_sum(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_sum position(long position) { + return (reduce_sum)super.position(position); + } - /** - * reduction_sum - tf.reduction_sum operation - * - * input params: - * 0 - NDArray - * - * T_ARG param (optional): - * 0 - keep_dims != 0. - * - * int params (optional): - * 0 - axe 1 - * 1 - axe 2 - * ... - * N-1 axe N - * - * All axes are optional and should be between 0 and input->rankOf() - 1 - * - * output: - * 0 - NDArray with reduces shape accordingly to axes (the scalar in default case). - */ -// #if NOT_EXCLUDED(OP_reduce_sum) - @Namespace("sd::ops") public static class reduce_sum extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_sum(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_sum(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_sum position(long position) { - return (reduce_sum)super.position(position); - } - public reduce_sum() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_reduce_sum_bp) +@Namespace("sd::ops") public static class reduce_sum_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_sum_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_sum_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_sum_bp position(long position) { + return (reduce_sum_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_reduce_sum_bp) - @Namespace("sd::ops") public static class reduce_sum_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_sum_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_sum_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_sum_bp position(long position) { - return (reduce_sum_bp)super.position(position); - } - public reduce_sum_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * reduction_prod - tf.reduction_prod operation + * + * input params: + * 0 - NDArray + * + * T_ARG param (optional): + * 0 - keep_dims != 0. + * + * int params (optional): + * 0 - axe 1 + * 1 - axe 2 + * ... + * N-1 axe N + * + * All axes are optional and should be between 0 and input->rankOf() - 1 + * + * output: + * 0 - NDArray with reduces shape accordingly to axes (the scalar in default + * case). + */ +// #if NOT_EXCLUDED(OP_reduce_prod) +@Namespace("sd::ops") public static class reduce_prod extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_prod(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_prod(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_prod position(long position) { + return (reduce_prod)super.position(position); + } - /** - * reduction_prod - tf.reduction_prod operation - * - * input params: - * 0 - NDArray - * - * T_ARG param (optional): - * 0 - keep_dims != 0. - * - * int params (optional): - * 0 - axe 1 - * 1 - axe 2 - * ... - * N-1 axe N - * - * All axes are optional and should be between 0 and input->rankOf() - 1 - * - * output: - * 0 - NDArray with reduces shape accordingly to axes (the scalar in default case). - */ -// #if NOT_EXCLUDED(OP_reduce_prod) - @Namespace("sd::ops") public static class reduce_prod extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_prod(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_prod(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_prod position(long position) { - return (reduce_prod)super.position(position); - } - public reduce_prod() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_reduce_prod_bp) +@Namespace("sd::ops") public static class reduce_prod_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_prod_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_prod_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_prod_bp position(long position) { + return (reduce_prod_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_reduce_prod_bp) - @Namespace("sd::ops") public static class reduce_prod_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_prod_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_prod_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_prod_bp position(long position) { - return (reduce_prod_bp)super.position(position); - } - public reduce_prod_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates min of elements along given dimensions + * + * input array: + * x: tensor to calculate mins for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate min along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated mins + */ +// #if NOT_EXCLUDED(OP_reduce_min) +@Namespace("sd::ops") public static class reduce_min extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_min(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_min(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_min position(long position) { + return (reduce_min)super.position(position); + } - /** - * This op calculates min of elements along given dimensions - * - * input array: - * x: tensor to calculate mins for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate min along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated mins - */ -// #if NOT_EXCLUDED(OP_reduce_min) - @Namespace("sd::ops") public static class reduce_min extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_min(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_min(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_min position(long position) { - return (reduce_min)super.position(position); - } - public reduce_min() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_reduce_min_bp) - @Namespace("sd::ops") public static class reduce_min_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_min_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_min_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_min_bp position(long position) { - return (reduce_min_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_reduce_min_bp) +@Namespace("sd::ops") public static class reduce_min_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_min_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_min_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_min_bp position(long position) { + return (reduce_min_bp)super.position(position); + } + public reduce_min_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates max of elements along given dimensions + * + * input array: + * x: tensor to calculate maxes for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate max along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated maxes + */ +// #if NOT_EXCLUDED(OP_reduce_max) +@Namespace("sd::ops") public static class reduce_max extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_max(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_max(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_max position(long position) { + return (reduce_max)super.position(position); + } - /** - * This op calculates max of elements along given dimensions - * - * input array: - * x: tensor to calculate maxes for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate max along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated maxes - */ -// #if NOT_EXCLUDED(OP_reduce_max) - @Namespace("sd::ops") public static class reduce_max extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_max(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_max(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_max position(long position) { - return (reduce_max)super.position(position); - } - public reduce_max() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_reduce_max_bp) - @Namespace("sd::ops") public static class reduce_max_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_max_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_max_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_max_bp position(long position) { - return (reduce_max_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_reduce_max_bp) +@Namespace("sd::ops") public static class reduce_max_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_max_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_max_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_max_bp position(long position) { + return (reduce_max_bp)super.position(position); + } + public reduce_max_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates norm1 of elements along given dimensions + * + * input array: + * x: tensor to calculate norm1 for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate norm1 along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm1 + */ +// #if NOT_EXCLUDED(OP_reduce_norm1) +@Namespace("sd::ops") public static class reduce_norm1 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_norm1(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_norm1(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_norm1 position(long position) { + return (reduce_norm1)super.position(position); + } - /** - * This op calculates norm1 of elements along given dimensions - * - * input array: - * x: tensor to calculate norm1 for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate norm1 along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm1 - */ -// #if NOT_EXCLUDED(OP_reduce_norm1) - @Namespace("sd::ops") public static class reduce_norm1 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_norm1(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_norm1(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_norm1 position(long position) { - return (reduce_norm1)super.position(position); - } - public reduce_norm1() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_reduce_norm1_bp) - @Namespace("sd::ops") public static class reduce_norm1_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_norm1_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_norm1_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_norm1_bp position(long position) { - return (reduce_norm1_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_reduce_norm1_bp) +@Namespace("sd::ops") public static class reduce_norm1_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_norm1_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_norm1_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_norm1_bp position(long position) { + return (reduce_norm1_bp)super.position(position); + } + public reduce_norm1_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates norm2 of elements along given dimensions + * + * input array: + * x: tensor to calculate norm2 for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate norm2 along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm2 + */ +// #if NOT_EXCLUDED(OP_reduce_norm2) +@Namespace("sd::ops") public static class reduce_norm2 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_norm2(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_norm2(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_norm2 position(long position) { + return (reduce_norm2)super.position(position); + } - /** - * This op calculates norm2 of elements along given dimensions - * - * input array: - * x: tensor to calculate norm2 for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate norm2 along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm2 - */ -// #if NOT_EXCLUDED(OP_reduce_norm2) - @Namespace("sd::ops") public static class reduce_norm2 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_norm2(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_norm2(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_norm2 position(long position) { - return (reduce_norm2)super.position(position); - } - public reduce_norm2() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_reduce_norm2_bp) - @Namespace("sd::ops") public static class reduce_norm2_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_norm2_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_norm2_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_norm2_bp position(long position) { - return (reduce_norm2_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_reduce_norm2_bp) +@Namespace("sd::ops") public static class reduce_norm2_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_norm2_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_norm2_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_norm2_bp position(long position) { + return (reduce_norm2_bp)super.position(position); + } + public reduce_norm2_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This op calculates squared norm of elements along given dimensions + * + * input array: + * x: tensor to calculate squared norm for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate squared norm along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm + */ +// #if NOT_EXCLUDED(OP_reduce_sqnorm) +@Namespace("sd::ops") public static class reduce_sqnorm extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_sqnorm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_sqnorm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_sqnorm position(long position) { + return (reduce_sqnorm)super.position(position); + } - /** - * This op calculates squared norm of elements along given dimensions - * - * input array: - * x: tensor to calculate squared norm for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate squared norm along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm - */ -// #if NOT_EXCLUDED(OP_reduce_sqnorm) - @Namespace("sd::ops") public static class reduce_sqnorm extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_sqnorm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_sqnorm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_sqnorm position(long position) { - return (reduce_sqnorm)super.position(position); - } - public reduce_sqnorm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_reduce_sqnorm_bp) - @Namespace("sd::ops") public static class reduce_sqnorm_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_sqnorm_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_sqnorm_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_sqnorm_bp position(long position) { - return (reduce_sqnorm_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_reduce_sqnorm_bp) +@Namespace("sd::ops") public static class reduce_sqnorm_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_sqnorm_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_sqnorm_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_sqnorm_bp position(long position) { + return (reduce_sqnorm_bp)super.position(position); + } + public reduce_sqnorm_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates norm max of elements along given dimensions + * + * input array: + * x: tensor to calculate norm max for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate norm max along, default + * corresponds to empty list in which case calculation is performed for all + * dimensions and scalar is returned + * + * output array: + * reduced tensor with calculated norm + */ +// #if NOT_EXCLUDED(OP_reduce_norm_max) +@Namespace("sd::ops") public static class reduce_norm_max extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_norm_max(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_norm_max(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_norm_max position(long position) { + return (reduce_norm_max)super.position(position); + } - /** - * This op calculates norm max of elements along given dimensions - * - * input array: - * x: tensor to calculate norm max for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate norm max along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated norm - */ -// #if NOT_EXCLUDED(OP_reduce_norm_max) - @Namespace("sd::ops") public static class reduce_norm_max extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_norm_max(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_norm_max(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_norm_max position(long position) { - return (reduce_norm_max)super.position(position); - } - public reduce_norm_max() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_reduce_norm_max_bp) - @Namespace("sd::ops") public static class reduce_norm_max_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_norm_max_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_norm_max_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_norm_max_bp position(long position) { - return (reduce_norm_max_bp)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_reduce_norm_max_bp) +@Namespace("sd::ops") public static class reduce_norm_max_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_norm_max_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_norm_max_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_norm_max_bp position(long position) { + return (reduce_norm_max_bp)super.position(position); + } + public reduce_norm_max_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op calculates mean of elements along given dimensions + * + * input array: + * x: tensor to calculate mean for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero + * + * int arguments: + * list of integers - dimensions to calculate mean along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated means + */ +// #if NOT_EXCLUDED(OP_reduce_mean) +@Namespace("sd::ops") public static class reduce_mean extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_mean(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_mean(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_mean position(long position) { + return (reduce_mean)super.position(position); + } - /** - * This op calculates mean of elements along given dimensions - * - * input array: - * x: tensor to calculate mean for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate mean along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated means - */ -// #if NOT_EXCLUDED(OP_reduce_mean) - @Namespace("sd::ops") public static class reduce_mean extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_mean(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_mean(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_mean position(long position) { - return (reduce_mean)super.position(position); - } - public reduce_mean() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_reduce_mean_bp) +@Namespace("sd::ops") public static class reduce_mean_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_mean_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_mean_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_mean_bp position(long position) { + return (reduce_mean_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_reduce_mean_bp) - @Namespace("sd::ops") public static class reduce_mean_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_mean_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_mean_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_mean_bp position(long position) { - return (reduce_mean_bp)super.position(position); - } - public reduce_mean_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - // #endif - /** - * This op calculates sample variance of elements along given dimensions - * - * input array: - * x: tensor to calculate mean for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * biasCorrected - if non zero, then bias correction will be applied, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate mean along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated means - */ - @Namespace("sd::ops") public static class reduce_variance extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_variance(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_variance(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_variance position(long position) { - return (reduce_variance)super.position(position); - } - + // #endif +/** + * This op calculates sample variance of elements along given dimensions + * + * input array: + * x: tensor to calculate mean for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero biasCorrected - if non zero, then bias correction will + * be applied, default value is zero + * + * int arguments: + * list of integers - dimensions to calculate mean along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated means + */ +@Namespace("sd::ops") public static class reduce_variance extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_variance(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_variance(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_variance position(long position) { + return (reduce_variance)super.position(position); + } + public reduce_variance() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class reduce_variance_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_variance_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_variance_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_variance_bp position(long position) { - return (reduce_variance_bp)super.position(position); - } - +@Namespace("sd::ops") public static class reduce_variance_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_variance_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_variance_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_variance_bp position(long position) { + return (reduce_variance_bp)super.position(position); + } + public reduce_variance_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } /** - * This op calculates sample standard deviation of elements along given dimensions - * - * input array: - * x: tensor to calculate mean for - * - * float arguments: - * keepDims: if non zero, then keep reduced dimensions with length = 1, default value is zero - * biasCorrected - if non zero, then bias correction will be applied, default value is zero - * - * int arguments: - * list of integers - dimensions to calculate mean along, default corresponds to empty list in which case calculation is performed for all dimensions and scalar is returned - * - * output array: - * reduced tensor with calculated means - */ - @Namespace("sd::ops") public static class reduce_stdev extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_stdev(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_stdev(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_stdev position(long position) { - return (reduce_stdev)super.position(position); - } - + * This op calculates sample standard deviation of elements along given + * dimensions + * + * input array: + * x: tensor to calculate mean for + * + * float arguments: + * keepDims: if non zero, then keep reduced dimensions with length = 1, + * default value is zero biasCorrected - if non zero, then bias correction will + * be applied, default value is zero + * + * int arguments: + * list of integers - dimensions to calculate mean along, default corresponds + * to empty list in which case calculation is performed for all dimensions and + * scalar is returned + * + * output array: + * reduced tensor with calculated means + */ +@Namespace("sd::ops") public static class reduce_stdev extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_stdev(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_stdev(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_stdev position(long position) { + return (reduce_stdev)super.position(position); + } + public reduce_stdev() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class reduce_stdev_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_stdev_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_stdev_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_stdev_bp position(long position) { - return (reduce_stdev_bp)super.position(position); - } - +@Namespace("sd::ops") public static class reduce_stdev_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_stdev_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_stdev_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_stdev_bp position(long position) { + return (reduce_stdev_bp)super.position(position); + } + public reduce_stdev_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } /** - * This op calculates backprop dot for two tensors along given dimensions - * - * input array: - * x: tensor to calculate dot for - * y: tensor to calculate dot for - * z: tensor with gradient output of the FF dot for x and y - * - * int arguments: - * list of integers - dimensions to calculate dot along, - * default corresponds to empty list in which case calculation - * is performed for all dimensions and scalar is returned. - * - * output array: - * the tensor with calculated backproped dots - * - */ + * This op calculates backprop dot for two tensors along given dimensions + * + * input array: + * x: tensor to calculate dot for + * y: tensor to calculate dot for + * z: tensor with gradient output of the FF dot for x and y + * + * int arguments: + * list of integers - dimensions to calculate dot along, + * default corresponds to empty list in which case calculation + * is performed for all dimensions and scalar is returned. + * + * output array: + * the tensor with calculated backproped dots + * + */ + +// #if NOT_EXCLUDED(OP_reduce_dot_bp) +@Namespace("sd::ops") public static class reduce_dot_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_dot_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_dot_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_dot_bp position(long position) { + return (reduce_dot_bp)super.position(position); + } -// #if NOT_EXCLUDED(OP_reduce_dot_bp) - @Namespace("sd::ops") public static class reduce_dot_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_dot_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_dot_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_dot_bp position(long position) { - return (reduce_dot_bp)super.position(position); - } - public reduce_dot_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - /** - * reduce_logsumexp - tf.reduce_logsumexe operation - * - * input params: - * 0 - NDArray (input) - * 1 - 1D NDArray (axis) (optional) - integer array - * - * T_ARG param (optional): - * 0 - keep_dims != 0. - * - * int params (optional): - * 0 - axe 1 - * 1 - axe 2 - * ... - * N-1 axe N - * - * CAUTION: All axes are optional and should be between 0 and input->rankOf() - 1 - * and put either with second param or as integers but not both - * - * output: - * 0 - NDArray with reduces shape accordingly to axes (the scalar in default case). - */ -// #if NOT_EXCLUDED(OP_reduce_logsumexp) - @Namespace("sd::ops") public static class reduce_logsumexp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reduce_logsumexp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reduce_logsumexp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reduce_logsumexp position(long position) { - return (reduce_logsumexp)super.position(position); - } - +// #endif +/** + * reduce_logsumexp - tf.reduce_logsumexe operation + * + * input params: + * 0 - NDArray (input) + * 1 - 1D NDArray (axis) (optional) - integer array + * + * T_ARG param (optional): + * 0 - keep_dims != 0. + * + * int params (optional): + * 0 - axe 1 + * 1 - axe 2 + * ... + * N-1 axe N + * + * CAUTION: All axes are optional and should be between 0 and input->rankOf() - + * 1 and put either with second param or as integers but not both + * + * output: + * 0 - NDArray with reduces shape accordingly to axes (the scalar in default + * case). + */ +// #if NOT_EXCLUDED(OP_reduce_logsumexp) +@Namespace("sd::ops") public static class reduce_logsumexp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reduce_logsumexp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reduce_logsumexp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reduce_logsumexp position(long position) { + return (reduce_logsumexp)super.position(position); + } + public reduce_logsumexp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif /** * Copy a tensor setting everything outside a central band in each innermost matrix @@ -21385,294 +23806,297 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint * */ -// #if NOT_EXCLUDED(OP_matrix_band_part) - @Namespace("sd::ops") public static class matrix_band_part extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matrix_band_part(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matrix_band_part(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matrix_band_part position(long position) { - return (matrix_band_part)super.position(position); - } - +// #if NOT_EXCLUDED(OP_matrix_band_part) +@Namespace("sd::ops") public static class matrix_band_part extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matrix_band_part(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matrix_band_part(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matrix_band_part position(long position) { + return (matrix_band_part)super.position(position); + } + public matrix_band_part() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +// #if NOT_EXCLUDED(OP_Assert) +@Namespace("sd::ops") public static class Assert extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public Assert(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public Assert(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public Assert position(long position) { + return (Assert)super.position(position); + } -// #if NOT_EXCLUDED(OP_Assert) - @Namespace("sd::ops") public static class Assert extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public Assert(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public Assert(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public Assert position(long position) { - return (Assert)super.position(position); - } - public Assert() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * image.non_max_suppression ops. + * input: + * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type + * 1 - scales - 1D-tensor with shape (num_boxes) by float type + * 2 - output_size - 0D-tensor by int type (optional) + * float args: + * 0 - overlap_threshold - threshold value for overlap checks (optional, by + * default 0.5) 1 - score_threshold - the threshold for deciding when to remove + * boxes based on score (optional, by default -inf) int args: 0 - output_size - + * as arg 2 used for same target. Eigher this or arg 2 should be provided. + * + * output: + * - vector with size M, where M <= output_size by int type + * + * */ +// #if NOT_EXCLUDED(OP_image_non_max_suppression) +@Namespace("sd::ops") public static class non_max_suppression extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public non_max_suppression(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public non_max_suppression(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public non_max_suppression position(long position) { + return (non_max_suppression)super.position(position); + } - /** - * image.non_max_suppression ops. - * input: - * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type - * 1 - scales - 1D-tensor with shape (num_boxes) by float type - * 2 - output_size - 0D-tensor by int type (optional) - * float args: - * 0 - overlap_threshold - threshold value for overlap checks (optional, by default 0.5) - * 1 - score_threshold - the threshold for deciding when to remove boxes based on score (optional, by default -inf) - * int args: - * 0 - output_size - as arg 2 used for same target. Eigher this or arg 2 should be provided. - * - * output: - * - vector with size M, where M <= output_size by int type - * - * */ -// #if NOT_EXCLUDED(OP_image_non_max_suppression) - @Namespace("sd::ops") public static class non_max_suppression extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public non_max_suppression(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public non_max_suppression(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public non_max_suppression position(long position) { - return (non_max_suppression)super.position(position); - } - public non_max_suppression() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif -// #if NOT_EXCLUDED(OP_image_non_max_suppression_v3) - @Namespace("sd::ops") public static class non_max_suppression_v3 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public non_max_suppression_v3(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public non_max_suppression_v3(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public non_max_suppression_v3 position(long position) { - return (non_max_suppression_v3)super.position(position); - } - +// #endif +// #if NOT_EXCLUDED(OP_image_non_max_suppression_v3) +@Namespace("sd::ops") public static class non_max_suppression_v3 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public non_max_suppression_v3(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public non_max_suppression_v3(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public non_max_suppression_v3 position(long position) { + return (non_max_suppression_v3)super.position(position); + } + public non_max_suppression_v3() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/* + * image.non_max_suppression_overlaps op. + * input: + * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type + * 1 - scales - 1D-tensor with shape (num_boxes) by float type + * 2 - output_size - 0D-tensor by int type (optional) + * float args: + * 0 - overlap_threshold - threshold value for overlap checks (optional, by + * default 0.5) 1 - score_threshold - the threshold for deciding when to remove + * boxes based on score (optional, by default -inf) int args: 0 - output_size - + * as arg 2 used for same target. Eigher this or arg 2 should be provided. + * + * output: + * 0 - 1D integer tensor with shape [M], epresenting the selected indices + * from the overlaps tensor, where M <= max_output_size + * */ +// #if NOT_EXCLUDED(OP_image_non_max_suppression_overlaps) +@Namespace("sd::ops") public static class non_max_suppression_overlaps extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public non_max_suppression_overlaps(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public non_max_suppression_overlaps(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public non_max_suppression_overlaps position(long position) { + return (non_max_suppression_overlaps)super.position(position); + } - /* - * image.non_max_suppression_overlaps op. - * input: - * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type - * 1 - scales - 1D-tensor with shape (num_boxes) by float type - * 2 - output_size - 0D-tensor by int type (optional) - * float args: - * 0 - overlap_threshold - threshold value for overlap checks (optional, by default 0.5) - * 1 - score_threshold - the threshold for deciding when to remove boxes based on score (optional, by default -inf) - * int args: - * 0 - output_size - as arg 2 used for same target. Eigher this or arg 2 should be provided. - * - * output: - * 0 - 1D integer tensor with shape [M], epresenting the selected indices from the overlaps tensor, where M <= max_output_size - * */ -// #if NOT_EXCLUDED(OP_image_non_max_suppression_overlaps) - @Namespace("sd::ops") public static class non_max_suppression_overlaps extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public non_max_suppression_overlaps(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public non_max_suppression_overlaps(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public non_max_suppression_overlaps position(long position) { - return (non_max_suppression_overlaps)super.position(position); - } - public non_max_suppression_overlaps() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/* + * cholesky op - decomposite positive square symetric matrix (or matricies when + * rank > 2). input: 0 - matricies - tensor with shape (..., N, N) by float type + * + * output - lower triangular matrix (matricies when rank > 2) with the same + * shape as input. + * */ +// #if NOT_EXCLUDED(OP_cholesky) +@Namespace("sd::ops") public static class cholesky extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cholesky(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cholesky(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cholesky position(long position) { + return (cholesky)super.position(position); + } - /* - * cholesky op - decomposite positive square symetric matrix (or matricies when rank > 2). - * input: - * 0 - matricies - tensor with shape (..., N, N) by float type - * - * output - lower triangular matrix (matricies when rank > 2) with the same shape as input. - * */ -// #if NOT_EXCLUDED(OP_cholesky) - @Namespace("sd::ops") public static class cholesky extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cholesky(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cholesky(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cholesky position(long position) { - return (cholesky)super.position(position); - } - public cholesky() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - /* - * nth_element - apply nth_element for last dimension of input tensor - * input array: - * 0 - input array - * 1 - scalar tensor with n for operation. n should be less than last dimension - * - * output: - * 0 - NDArray with the same shape as input - */ -// #if NOT_EXCLUDED(OP_nth_element) - @Namespace("sd::ops") public static class nth_element extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public nth_element(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public nth_element(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public nth_element position(long position) { - return (nth_element)super.position(position); - } - +// #endif +/* + * nth_element - apply nth_element for last dimension of input tensor + * input array: + * 0 - input array + * 1 - scalar tensor with n for operation. n should be less than last + * dimension + * + * output: + * 0 - NDArray with the same shape as input + */ +// #if NOT_EXCLUDED(OP_nth_element) +@Namespace("sd::ops") public static class nth_element extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public nth_element(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public nth_element(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public nth_element position(long position) { + return (nth_element)super.position(position); + } + public nth_element() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op checks for Inf/NaN values within input array, and throws exception if + * there's at least one + */ +// #if NOT_EXCLUDED(OP_check_numerics) +@Namespace("sd::ops") public static class check_numerics extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public check_numerics(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public check_numerics(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public check_numerics position(long position) { + return (check_numerics)super.position(position); + } - /** - * This op checks for Inf/NaN values within input array, and throws exception if there's at least one - */ -// #if NOT_EXCLUDED(OP_check_numerics) - @Namespace("sd::ops") public static class check_numerics extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public check_numerics(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public check_numerics(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public check_numerics position(long position) { - return (check_numerics)super.position(position); - } - public check_numerics() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif /** - * fake_quant_with_min_max_vals - tf.quantization.fake_quant_with_min_max_vars - * - * input params: - * 0 - NDArray (input) - * 1 - 0D Tensor - min value - * 2 - 0D Tensor - max value - * - * int params (optional): - * 0 - num_bits (allowed interval [2, 16], default 8) - * 1 - narrow_range (default False) - * - * output: - * 0 - NDArray with the same shape as input - */ -// #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars) - @Namespace("sd::ops") public static class fake_quant_with_min_max_vars extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public fake_quant_with_min_max_vars(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public fake_quant_with_min_max_vars(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public fake_quant_with_min_max_vars position(long position) { - return (fake_quant_with_min_max_vars)super.position(position); - } - + * fake_quant_with_min_max_vals - tf.quantization.fake_quant_with_min_max_vars + * + * input params: + * 0 - NDArray (input) + * 1 - 0D Tensor - min value + * 2 - 0D Tensor - max value + * + * int params (optional): + * 0 - num_bits (allowed interval [2, 16], default 8) + * 1 - narrow_range (default False) + * + * output: + * 0 - NDArray with the same shape as input + */ +// #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars) +@Namespace("sd::ops") public static class fake_quant_with_min_max_vars extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public fake_quant_with_min_max_vars(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public fake_quant_with_min_max_vars(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public fake_quant_with_min_max_vars position(long position) { + return (fake_quant_with_min_max_vars)super.position(position); + } + public fake_quant_with_min_max_vars() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif /** - * fake_quant_with_min_max_vals_per_channel - tf.quantization.fake_quant_with_min_max_vars_per_channel - * - * input params: - * 0 - NDArray (input) - at least 2D. - * 1 - 1D Tensor - min values (min length equals to last dim of input) - * 2 - 1D Tensor - max value (length equals to min) - * - * int params (optional): - * 0 - num_bits (allowed interval [2, 16], default 8) - * 1 - narrow_range (default False) - * - * output: - * 0 - NDArray with the same shape as input - */ -// #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel) - @Namespace("sd::ops") public static class fake_quant_with_min_max_vars_per_channel extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public fake_quant_with_min_max_vars_per_channel(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public fake_quant_with_min_max_vars_per_channel(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public fake_quant_with_min_max_vars_per_channel position(long position) { - return (fake_quant_with_min_max_vars_per_channel)super.position(position); - } - + * fake_quant_with_min_max_vals_per_channel - + * tf.quantization.fake_quant_with_min_max_vars_per_channel + * + * input params: + * 0 - NDArray (input) - at least 2D. + * 1 - 1D Tensor - min values (min length equals to last dim of input) + * 2 - 1D Tensor - max value (length equals to min) + * + * int params (optional): + * 0 - num_bits (allowed interval [2, 16], default 8) + * 1 - narrow_range (default False) + * + * output: + * 0 - NDArray with the same shape as input + */ +// #if NOT_EXCLUDED(OP_fake_quant_with_min_max_vars_per_channel) +@Namespace("sd::ops") public static class fake_quant_with_min_max_vars_per_channel extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public fake_quant_with_min_max_vars_per_channel(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public fake_quant_with_min_max_vars_per_channel(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public fake_quant_with_min_max_vars_per_channel position(long position) { + return (fake_quant_with_min_max_vars_per_channel)super.position(position); + } + public fake_quant_with_min_max_vars_per_channel() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * compare_and_bitpack - compare with greater and pack result with uint8 + * + * input params: + * 0 - NDArray (input) + * 1 - 0D Tensor - threshold + * + * + * output: + * 0 - NDArray with the same shape as input and type uint8 + */ +// #if NOT_EXCLUDED(OP_compare_and_bitpack) +@Namespace("sd::ops") public static class compare_and_bitpack extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public compare_and_bitpack(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public compare_and_bitpack(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public compare_and_bitpack position(long position) { + return (compare_and_bitpack)super.position(position); + } - /** - * compare_and_bitpack - compare with greater and pack result with uint8 - * - * input params: - * 0 - NDArray (input) - * 1 - 0D Tensor - threshold - * - * - * output: - * 0 - NDArray with the same shape as input and type uint8 - */ -// #if NOT_EXCLUDED(OP_compare_and_bitpack) - @Namespace("sd::ops") public static class compare_and_bitpack extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public compare_and_bitpack(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public compare_and_bitpack(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public compare_and_bitpack position(long position) { - return (compare_and_bitpack)super.position(position); - } - public compare_and_bitpack() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -21703,308 +24127,307 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_SHAPE_H // #include -// #if NOT_EXCLUDED(OP_permute) - @Namespace("sd::ops") public static class permute extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public permute(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public permute(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public permute position(long position) { - return (permute)super.position(position); - } - +// #if NOT_EXCLUDED(OP_permute) +@Namespace("sd::ops") public static class permute extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public permute(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public permute(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public permute position(long position) { + return (permute)super.position(position); + } + public permute() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_reshapeas) +@Namespace("sd::ops") public static class reshapeas extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reshapeas(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reshapeas(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reshapeas position(long position) { + return (reshapeas)super.position(position); + } -// #if NOT_EXCLUDED(OP_reshapeas) - @Namespace("sd::ops") public static class reshapeas extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reshapeas(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reshapeas(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reshapeas position(long position) { - return (reshapeas)super.position(position); - } - public reshapeas() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_transpose) +@Namespace("sd::ops") public static class transpose extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public transpose(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public transpose(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public transpose position(long position) { + return (transpose)super.position(position); + } -// #if NOT_EXCLUDED(OP_transpose) - @Namespace("sd::ops") public static class transpose extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public transpose(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public transpose(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public transpose position(long position) { - return (transpose)super.position(position); - } - public transpose() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_shape_of) +@Namespace("sd::ops") public static class shape_of extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public shape_of(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public shape_of(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public shape_of position(long position) { + return (shape_of)super.position(position); + } -// #if NOT_EXCLUDED(OP_shape_of) - @Namespace("sd::ops") public static class shape_of extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public shape_of(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public shape_of(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public shape_of position(long position) { - return (shape_of)super.position(position); - } - public shape_of() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_shapes_of) +@Namespace("sd::ops") public static class shapes_of extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public shapes_of(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public shapes_of(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public shapes_of position(long position) { + return (shapes_of)super.position(position); + } -// #if NOT_EXCLUDED(OP_shapes_of) - @Namespace("sd::ops") public static class shapes_of extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public shapes_of(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public shapes_of(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public shapes_of position(long position) { - return (shapes_of)super.position(position); - } - public shapes_of() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_squeeze) +@Namespace("sd::ops") public static class squeeze extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public squeeze(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public squeeze(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public squeeze position(long position) { + return (squeeze)super.position(position); + } -// #if NOT_EXCLUDED(OP_squeeze) - @Namespace("sd::ops") public static class squeeze extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public squeeze(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public squeeze(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public squeeze position(long position) { - return (squeeze)super.position(position); - } - public squeeze() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_expand_dims) +@Namespace("sd::ops") public static class expand_dims extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public expand_dims(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public expand_dims(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public expand_dims position(long position) { + return (expand_dims)super.position(position); + } -// #if NOT_EXCLUDED(OP_expand_dims) - @Namespace("sd::ops") public static class expand_dims extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public expand_dims(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public expand_dims(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public expand_dims position(long position) { - return (expand_dims)super.position(position); - } - public expand_dims() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_reshape) +@Namespace("sd::ops") public static class reshape extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public reshape(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public reshape(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public reshape position(long position) { + return (reshape)super.position(position); + } -// #if NOT_EXCLUDED(OP_reshape) - @Namespace("sd::ops") public static class reshape extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public reshape(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public reshape(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public reshape position(long position) { - return (reshape)super.position(position); - } - public reshape() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_size_at) +@Namespace("sd::ops") public static class size_at extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public size_at(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public size_at(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public size_at position(long position) { + return (size_at)super.position(position); + } -// #if NOT_EXCLUDED(OP_size_at) - @Namespace("sd::ops") public static class size_at extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public size_at(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public size_at(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public size_at position(long position) { - return (size_at)super.position(position); - } - public size_at() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op changes order of given array to specified order. + * In other words: C/F order switch + * + * Int args: + * 0 - isForder. set to 1 for F order output, or 0 for C order output + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_order) +@Namespace("sd::ops") public static class order extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public order(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public order(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public order position(long position) { + return (order)super.position(position); + } - /** - * This op changes order of given array to specified order. - * In other words: C/F order switch - * - * Int args: - * 0 - isForder. set to 1 for F order output, or 0 for C order output - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_order) - @Namespace("sd::ops") public static class order extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public order(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public order(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public order position(long position) { - return (order)super.position(position); - } - public order() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op boosts specified input up to specified shape + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_tile_to_shape) +@Namespace("sd::ops") public static class tile_to_shape extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tile_to_shape(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tile_to_shape(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tile_to_shape position(long position) { + return (tile_to_shape)super.position(position); + } - /** - * This op boosts specified input up to specified shape - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_tile_to_shape) - @Namespace("sd::ops") public static class tile_to_shape extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tile_to_shape(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tile_to_shape(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tile_to_shape position(long position) { - return (tile_to_shape)super.position(position); - } - public tile_to_shape() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class tile_to_shape_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tile_to_shape_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tile_to_shape_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tile_to_shape_bp position(long position) { - return (tile_to_shape_bp)super.position(position); - } - +@Namespace("sd::ops") public static class tile_to_shape_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tile_to_shape_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tile_to_shape_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tile_to_shape_bp position(long position) { + return (tile_to_shape_bp)super.position(position); + } + public tile_to_shape_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op broadcast given input up to given shape + * + * inputs: + * input array - array to be broadcasted to given shape + * shape array - array containing shape be broadcasted to + */ +// #if NOT_EXCLUDED(OP_broadcast_to) +@Namespace("sd::ops") public static class broadcast_to extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public broadcast_to(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public broadcast_to(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public broadcast_to position(long position) { + return (broadcast_to)super.position(position); + } - /** - * This op broadcast given input up to given shape - * - * inputs: - * input array - array to be broadcasted to given shape - * shape array - array containing shape be broadcasted to - */ -// #if NOT_EXCLUDED(OP_broadcast_to) - @Namespace("sd::ops") public static class broadcast_to extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public broadcast_to(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public broadcast_to(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public broadcast_to position(long position) { - return (broadcast_to)super.position(position); - } - public broadcast_to() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } -// #endif + } +// #endif +// #if NOT_EXCLUDED(OP_evaluate_reduction_shape) +@Namespace("sd::ops") public static class evaluate_reduction_shape extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public evaluate_reduction_shape(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public evaluate_reduction_shape(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public evaluate_reduction_shape position(long position) { + return (evaluate_reduction_shape)super.position(position); + } -// #if NOT_EXCLUDED(OP_evaluate_reduction_shape) - @Namespace("sd::ops") public static class evaluate_reduction_shape extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public evaluate_reduction_shape(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public evaluate_reduction_shape(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public evaluate_reduction_shape position(long position) { - return (evaluate_reduction_shape)super.position(position); - } - public evaluate_reduction_shape() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation creates new array + * Input: + * array with shape values + * + * IArgs: + * order value + * data type value + * + * BArgs: + * initialization option + */ +// #if NOT_EXCLUDED(OP_create) +@Namespace("sd::ops") public static class create extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public create(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public create(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public create position(long position) { + return (create)super.position(position); + } - /** - * This operation creates new array - * Input: - * array with shape values - * - * IArgs: - * order value - * data type value - * - * BArgs: - * initialization option - */ -// #if NOT_EXCLUDED(OP_create) - @Namespace("sd::ops") public static class create extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public create(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public create(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public create position(long position) { - return (create)super.position(position); - } - public create() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -22035,218 +24458,219 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_RANDOM_H // #include -// #if NOT_EXCLUDED(OP_set_seed) - @Namespace("sd::ops") public static class set_seed extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public set_seed(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public set_seed(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public set_seed position(long position) { - return (set_seed)super.position(position); - } - +// #if NOT_EXCLUDED(OP_set_seed) +@Namespace("sd::ops") public static class set_seed extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public set_seed(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public set_seed(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public set_seed position(long position) { + return (set_seed)super.position(position); + } + public set_seed() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_get_seed) +@Namespace("sd::ops") public static class get_seed extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public get_seed(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public get_seed(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public get_seed position(long position) { + return (get_seed)super.position(position); + } -// #if NOT_EXCLUDED(OP_get_seed) - @Namespace("sd::ops") public static class get_seed extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public get_seed(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public get_seed(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public get_seed position(long position) { - return (get_seed)super.position(position); - } - public get_seed() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/* + * random_uniform distribution for types int32,int64, float16, float and double + * by default dtype is float32 + * + * input: + * 0 - shape of output (1D int tensor) + * 1 - min val (0D of output type) - optional (0 as default) + * 2 - max val (0D of output type) - optional (inf as default) + * + * output: + * 0 - uniformly distributed values of given type (between min and max) + */ +// #if NOT_EXCLUDED(OP_randomuniform) +@Namespace("sd::ops") public static class randomuniform extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public randomuniform(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public randomuniform(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public randomuniform position(long position) { + return (randomuniform)super.position(position); + } - /* - * random_uniform distribution for types int32,int64, float16, float and double - * by default dtype is float32 - * - * input: - * 0 - shape of output (1D int tensor) - * 1 - min val (0D of output type) - optional (0 as default) - * 2 - max val (0D of output type) - optional (inf as default) - * - * output: - * 0 - uniformly distributed values of given type (between min and max) - */ -// #if NOT_EXCLUDED(OP_randomuniform) - @Namespace("sd::ops") public static class randomuniform extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public randomuniform(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public randomuniform(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public randomuniform position(long position) { - return (randomuniform)super.position(position); - } - public randomuniform() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - /* - * multinomial (categorical) random generator draws samples from a multinomial distribution - * - * Input array: - * 0 - 2D ndarray with unnormalized log-probabilities with shape [batch_size (N), num_classes (K)] - * 1 - array with one int value of samples number, number of independent samples to draw for each experiment 1,N. - * Int arguments: - * 0 - optional argument, corresponds to dimension with batch_size - * 1 - optional argument, integer type to use for the output. Default int64. - * - * Output array: - * 0 - 2D ndarray with the drawn samples of shape [batch_size, num_samples] - */ -// #if NOT_EXCLUDED(OP_random_multinomial) - @Namespace("sd::ops") public static class random_multinomial extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_multinomial(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_multinomial(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_multinomial position(long position) { - return (random_multinomial)super.position(position); - } - +// #endif +/* + * multinomial (categorical) random generator draws samples from a multinomial + * distribution + * + * Input array: + * 0 - 2D ndarray with unnormalized log-probabilities with shape [batch_size + * (N), num_classes (K)] 1 - array with one int value of samples number, number + * of independent samples to draw for each experiment 1,N. Int arguments: 0 - + * optional argument, corresponds to dimension with batch_size 1 - optional + * argument, integer type to use for the output. Default int64. + * + * Output array: + * 0 - 2D ndarray with the drawn samples of shape [batch_size, num_samples] + */ +// #if NOT_EXCLUDED(OP_random_multinomial) +@Namespace("sd::ops") public static class random_multinomial extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_multinomial(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_multinomial(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_multinomial position(long position) { + return (random_multinomial)super.position(position); + } + public random_multinomial() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_random_normal) +@Namespace("sd::ops") public static class random_normal extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_normal(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_normal(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_normal position(long position) { + return (random_normal)super.position(position); + } -// #if NOT_EXCLUDED(OP_random_normal) - @Namespace("sd::ops") public static class random_normal extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_normal(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_normal(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_normal position(long position) { - return (random_normal)super.position(position); - } - public random_normal() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_random_bernoulli) +@Namespace("sd::ops") public static class random_bernoulli extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_bernoulli(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_bernoulli(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_bernoulli position(long position) { + return (random_bernoulli)super.position(position); + } -// #if NOT_EXCLUDED(OP_random_bernoulli) - @Namespace("sd::ops") public static class random_bernoulli extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_bernoulli(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_bernoulli(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_bernoulli position(long position) { - return (random_bernoulli)super.position(position); - } - public random_bernoulli() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_random_exponential) +@Namespace("sd::ops") public static class random_exponential extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_exponential(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_exponential(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_exponential position(long position) { + return (random_exponential)super.position(position); + } -// #if NOT_EXCLUDED(OP_random_exponential) - @Namespace("sd::ops") public static class random_exponential extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_exponential(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_exponential(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_exponential position(long position) { - return (random_exponential)super.position(position); - } - public random_exponential() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_random_crop) +@Namespace("sd::ops") public static class random_crop extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_crop(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_crop(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_crop position(long position) { + return (random_crop)super.position(position); + } -// #if NOT_EXCLUDED(OP_random_crop) - @Namespace("sd::ops") public static class random_crop extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_crop(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_crop(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_crop position(long position) { - return (random_crop)super.position(position); - } - public random_crop() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * random_gamma op. + */ +// #if NOT_EXCLUDED(OP_random_gamma) +@Namespace("sd::ops") public static class random_gamma extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_gamma(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_gamma(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_gamma position(long position) { + return (random_gamma)super.position(position); + } - /** - * random_gamma op. - */ -// #if NOT_EXCLUDED(OP_random_gamma) - @Namespace("sd::ops") public static class random_gamma extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_gamma(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_gamma(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_gamma position(long position) { - return (random_gamma)super.position(position); - } - public random_gamma() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * random_poisson op. + */ +// #if NOT_EXCLUDED(OP_random_poisson) +@Namespace("sd::ops") public static class random_poisson extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public random_poisson(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public random_poisson(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public random_poisson position(long position) { + return (random_poisson)super.position(position); + } - /** - * random_poisson op. - */ -// #if NOT_EXCLUDED(OP_random_poisson) - @Namespace("sd::ops") public static class random_poisson extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public random_poisson(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public random_poisson(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public random_poisson position(long position) { - return (random_poisson)super.position(position); - } - public random_poisson() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -22277,473 +24701,481 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include -// #if NOT_EXCLUDED(OP_softmax) - @Namespace("sd::ops") public static class softmax extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softmax(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softmax(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softmax position(long position) { - return (softmax)super.position(position); - } - +// #if NOT_EXCLUDED(OP_softmax) +@Namespace("sd::ops") public static class softmax extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softmax(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softmax(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softmax position(long position) { + return (softmax)super.position(position); + } + public softmax() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class softmax_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softmax_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softmax_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softmax_bp position(long position) { - return (softmax_bp)super.position(position); - } - +@Namespace("sd::ops") public static class softmax_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softmax_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softmax_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softmax_bp position(long position) { + return (softmax_bp)super.position(position); + } + public softmax_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Local response normalization implementation as TF. + * input: 4D array + * + * T args: + * + * 0: bias + * 1: alpha + * 2: beta + * + * Int arg: depth - optional local radius + * + * output - 4D array + */ +// #if NOT_EXCLUDED(OP_lrn) +@Namespace("sd::ops") public static class lrn extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lrn(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lrn(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lrn position(long position) { + return (lrn)super.position(position); + } - /** - * Local response normalization implementation as TF. - * input: 4D array - * - * T args: - * - * 0: bias - * 1: alpha - * 2: beta - * - * Int arg: depth - optional local radius - * - * output - 4D array - */ -// #if NOT_EXCLUDED(OP_lrn) - @Namespace("sd::ops") public static class lrn extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lrn(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lrn(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lrn position(long position) { - return (lrn)super.position(position); - } - public lrn() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Local response normalization - backprop variant. + * input: + * 0 - 4D array of data + * 1 - epsilon - 4D array of approximation + * + * T args: + * + * 0: bias + * 1: alpha + * 2: beta + * + * Int arg: depth - optional local radius + * + * output - next approximation as 4D array + */ +// #if NOT_EXCLUDED(OP_lrn) +@Namespace("sd::ops") public static class lrn_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lrn_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lrn_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lrn_bp position(long position) { + return (lrn_bp)super.position(position); + } - /** - * Local response normalization - backprop variant. - * input: - * 0 - 4D array of data - * 1 - epsilon - 4D array of approximation - * - * T args: - * - * 0: bias - * 1: alpha - * 2: beta - * - * Int arg: depth - optional local radius - * - * output - next approximation as 4D array - */ -// #if NOT_EXCLUDED(OP_lrn) - @Namespace("sd::ops") public static class lrn_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public lrn_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public lrn_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public lrn_bp position(long position) { - return (lrn_bp)super.position(position); - } - public lrn_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * Batch normalization implementation. + * Reference: https://arxiv.org/abs/1502.03167v3 + * + * Expected arguments: + * input: input array (any number of dimensions) + * mean: + * variance: + * gamma: + * beta: + * + * Int args: + * 0: apply scale + * 1: apply offset + * + * + * T args: + * 0: epsilon + */ +// #if NOT_EXCLUDED(OP_batchnorm) +@Namespace("sd::ops") public static class batchnorm extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public batchnorm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public batchnorm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public batchnorm position(long position) { + return (batchnorm)super.position(position); + } - /** - * Batch normalization implementation. - * Reference: https://arxiv.org/abs/1502.03167v3 - * - * Expected arguments: - * input: input array (any number of dimensions) - * mean: - * variance: - * gamma: - * beta: - * - * Int args: - * 0: apply scale - * 1: apply offset - * - * - * T args: - * 0: epsilon - */ -// #if NOT_EXCLUDED(OP_batchnorm) - @Namespace("sd::ops") public static class batchnorm extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public batchnorm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public batchnorm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public batchnorm position(long position) { - return (batchnorm)super.position(position); - } - public batchnorm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * back prop in batch normalization + * + * Expected arguments: + * input: input array (any number of dimensions) + * mean: + * variance: + * gamma: optional + * beta: optional + * dLdOut: next epsilon + * + * Int args: + * 0: apply scale + * 1: apply offset + * + * T args: + * 0: epsilon + * + * output arrays: + * dL/dInput + * dL/dMean + * dL/dVariance + * dL/dGamma, optional + * dL/dBeta, optional + */ +// #if NOT_EXCLUDED(OP_batchnorm) +@Namespace("sd::ops") public static class batchnorm_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public batchnorm_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public batchnorm_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public batchnorm_bp position(long position) { + return (batchnorm_bp)super.position(position); + } - /** - * back prop in batch normalization - * - * Expected arguments: - * input: input array (any number of dimensions) - * mean: - * variance: - * gamma: optional - * beta: optional - * dLdOut: next epsilon - * - * Int args: - * 0: apply scale - * 1: apply offset - * - * T args: - * 0: epsilon - * - * output arrays: - * dL/dInput - * dL/dMean - * dL/dVariance - * dL/dGamma, optional - * dL/dBeta, optional - */ -// #if NOT_EXCLUDED(OP_batchnorm) - @Namespace("sd::ops") public static class batchnorm_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public batchnorm_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public batchnorm_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public batchnorm_bp position(long position) { - return (batchnorm_bp)super.position(position); - } - public batchnorm_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This operation updates parameters with provided gradients, wrt learning rate + * Expected arguments: + * x: parameters, any shape + * y: gradients. same shape as x + * lr: optional, learning rate + * + * T args: + * 0: optional, learning rate + */ +// #if NOT_EXCLUDED(OP_apply_sgd) +@Namespace("sd::ops") public static class apply_sgd extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public apply_sgd(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public apply_sgd(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public apply_sgd position(long position) { + return (apply_sgd)super.position(position); + } - /** - * This operation updates parameters with provided gradients, wrt learning rate - * Expected arguments: - * x: parameters, any shape - * y: gradients. same shape as x - * lr: optional, learning rate - * - * T args: - * 0: optional, learning rate - */ -// #if NOT_EXCLUDED(OP_apply_sgd) - @Namespace("sd::ops") public static class apply_sgd extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public apply_sgd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public apply_sgd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public apply_sgd position(long position) { - return (apply_sgd)super.position(position); - } - public apply_sgd() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation performs batch normalization of layer, it is based on + * following article https://arxiv.org/abs/1502.03167. Expected arguments: x: + * input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] + * (data format = NCHW), where bS - batch size iH - input height iW - input + * width iD - input depth (or number of channels) scale: 1D input array of + * scale factors, shape [iD] offset: 1D input array of offsets (shifts), shape + * [iD] mean: 1D input array of population mean used for inference, shape [iD], + * this array is required only if isTraining = false variance: 1D input array of + * population mean used for inference, shape [iD], this array is required only + * if isTraining = false + * + * T input arguments: + * 0: epsilon, it is optional argument, default value is 0.001, this is small + * number to be added to the variance of x + * + * integer input arguments: + * 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW + * 1: isTraining, may have two values: zero -> inference, unity -> training + */ +// #if NOT_EXCLUDED(OP_fused_batch_norm) +@Namespace("sd::ops") public static class fused_batch_norm extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public fused_batch_norm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public fused_batch_norm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public fused_batch_norm position(long position) { + return (fused_batch_norm)super.position(position); + } - /** - * This operation performs batch normalization of layer, it is based on following article https://arxiv.org/abs/1502.03167. - * Expected arguments: - * x: input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] (data format = NCHW), where - * bS - batch size - * iH - input height - * iW - input width - * iD - input depth (or number of channels) - * scale: 1D input array of scale factors, shape [iD] - * offset: 1D input array of offsets (shifts), shape [iD] - * mean: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false - * variance: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false - * - * T input arguments: - * 0: epsilon, it is optional argument, default value is 0.001, this is small number to be added to the variance of x - * - * integer input arguments: - * 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW - * 1: isTraining, may have two values: zero -> inference, unity -> training - */ -// #if NOT_EXCLUDED(OP_fused_batch_norm) - @Namespace("sd::ops") public static class fused_batch_norm extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public fused_batch_norm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public fused_batch_norm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public fused_batch_norm position(long position) { - return (fused_batch_norm)super.position(position); - } - public fused_batch_norm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_log_softmax) +@Namespace("sd::ops") public static class log_softmax extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public log_softmax(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public log_softmax(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public log_softmax position(long position) { + return (log_softmax)super.position(position); + } -// #if NOT_EXCLUDED(OP_log_softmax) - @Namespace("sd::ops") public static class log_softmax extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public log_softmax(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public log_softmax(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public log_softmax position(long position) { - return (log_softmax)super.position(position); - } - public log_softmax() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class log_softmax_bp extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public log_softmax_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public log_softmax_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public log_softmax_bp position(long position) { - return (log_softmax_bp)super.position(position); - } - +@Namespace("sd::ops") public static class log_softmax_bp extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public log_softmax_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public log_softmax_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public log_softmax_bp position(long position) { + return (log_softmax_bp)super.position(position); + } + public log_softmax_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * relu_layer = relu(x*w + b) + */ +@Namespace("sd::ops") public static class relu_layer extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public relu_layer(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public relu_layer(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public relu_layer position(long position) { + return (relu_layer)super.position(position); + } - /** - * relu_layer = relu(x*w + b) - */ - @Namespace("sd::ops") public static class relu_layer extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public relu_layer(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public relu_layer(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public relu_layer position(long position) { - return (relu_layer)super.position(position); - } - public relu_layer() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - /** - * applies layer normalization to input - * y = g * standardize(x) + b - * - * see sd::ops::standardize - * - */ -// #if NOT_EXCLUDED(OP_layer_norm) - @Namespace("sd::ops") public static class layer_norm extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public layer_norm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public layer_norm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public layer_norm position(long position) { - return (layer_norm)super.position(position); - } - +/** + * applies layer normalization to input + * y = g * standardize(x) + b + * + * see sd::ops::standardize + * + */ +// #if NOT_EXCLUDED(OP_layer_norm) +@Namespace("sd::ops") public static class layer_norm extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public layer_norm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public layer_norm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public layer_norm position(long position) { + return (layer_norm)super.position(position); + } + public layer_norm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class layer_norm_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public layer_norm_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public layer_norm_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public layer_norm_bp position(long position) { - return (layer_norm_bp)super.position(position); - } - +@Namespace("sd::ops") public static class layer_norm_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public layer_norm_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public layer_norm_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public layer_norm_bp position(long position) { + return (layer_norm_bp)super.position(position); + } + public layer_norm_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation performs dot product attention on the given timeseries input + * with the given queries out = sum(similarity(k_i, q) * v_i) + * + * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q + * + * Optionally with normalization step: + * similarity(k, q) = softmax(k * q / sqrt(size(q)) + * + * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, + * eq. 1) + * + * Note: This supports multiple queries at once, if only one query is available + * the queries vector still has to be 3D but can have queryCount = 1 + * + * Note: keys and values usually is the same array. If you want to use it as the + * same array, simply pass it for both. + * + * Expected arguments: + * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] or + * 4D array of shape [batchSize, numHeads, featureKeys, queryCount] k: input 3D + * array "keys" of shape [batchSize, featureKeys, timesteps] or 4D array of + * shape [batchSize, numHeads, featureKeys, timesteps] v: input 3D array + * "values" of shape [batchSize, featureValues, timesteps] or 4D array of shape + * [batchSize, numHeads, featureValues, timesteps] mask: OPTIONAL; array that + * defines which values should be skipped of shape [batchSize, timesteps] + * + * integer input arguments: + * 0: normalization, may have two values: zero -> do not apply normalization, + * one -> apply normalization 1: withWeights, may have two values: zero -> do + * not return weights, one -> return weights + * + * Output Arrays: + * 0: Attention result arrays of shape [batchSize, featureValues, queryCount] or + * [batchSize, numHeads, featureValues, queryCount] 1: OPTIONAL; Attention + * weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, + * timesteps, queryCount] + */ +// #if NOT_EXCLUDED(OP_dot_product_attention) +@Namespace("sd::ops") public static class dot_product_attention extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dot_product_attention(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dot_product_attention(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dot_product_attention position(long position) { + return (dot_product_attention)super.position(position); + } - /** - * This operation performs dot product attention on the given timeseries input with the given queries - * out = sum(similarity(k_i, q) * v_i) - * - * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q - * - * Optionally with normalization step: - * similarity(k, q) = softmax(k * q / sqrt(size(q)) - * - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1) - * - * Note: This supports multiple queries at once, if only one query is available the queries vector still has to - * be 3D but can have queryCount = 1 - * - * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for - * both. - * - * Expected arguments: - * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] - * k: input 3D array "keys" of shape [batchSize, featureKeys, timesteps] or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] - * v: input 3D array "values" of shape [batchSize, featureValues, timesteps] or 4D array of shape [batchSize, numHeads, featureValues, timesteps] - * mask: OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] - * - * integer input arguments: - * 0: normalization, may have two values: zero -> do not apply normalization, one -> apply normalization - * 1: withWeights, may have two values: zero -> do not return weights, one -> return weights - * - * Output Arrays: - * 0: Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount] - * 1: OPTIONAL; Attention weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] - */ -// #if NOT_EXCLUDED(OP_dot_product_attention) - @Namespace("sd::ops") public static class dot_product_attention extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dot_product_attention(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dot_product_attention(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dot_product_attention position(long position) { - return (dot_product_attention)super.position(position); - } - public dot_product_attention() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class dot_product_attention_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public dot_product_attention_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public dot_product_attention_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public dot_product_attention_bp position(long position) { - return (dot_product_attention_bp)super.position(position); - } - +@Namespace("sd::ops") public static class dot_product_attention_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public dot_product_attention_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public dot_product_attention_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public dot_product_attention_bp position(long position) { + return (dot_product_attention_bp)super.position(position); + } + public dot_product_attention_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This performs multi-headed dot product attention on the given timeseries + * input out = concat(head_1, head_2, ..., head_n) * Wo head_i = + * dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v) + * + * Optionally with normalization when calculating the attention for each head. + * + * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. + * 4,5, "3.2.2 Multi-Head Attention") + * + * This makes use of dot_product_attention OP support for rank 4 inputs. + * + * Expected arguments: + * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] + * k: input 3D array "keys" of shape [batchSize, featureKeys, timesteps] + * v: input 3D array "values" of shape [batchSize, featureValues, timesteps] + * Wq: input query projection weights of shape [numHeads, projectedKeys, + * featureKeys] Wk: input key projection weights of shape [numHeads, + * projectedKeys, featureKeys] Wv: input value projection weights of shape + * [numHeads, projectedValues, featureValues] Wo: output projection weights of + * shape [numHeads * projectedValues, outSize] mask: OPTIONAL; array that + * defines which values should be skipped of shape [batchSize, timesteps] + * + * integer input arguments: + * 0: normalization, may have two values: zero -> do not apply normalization, + * one -> apply normalization 1: withWeights, may have two values: zero -> do + * not return weights, one -> return weights + * + * Output Arrays: + * 0: Attention result arrays of shape [batchSize, outSize, queryCount] + * 1: OPTIONAL; Attention weights of shape [batchSize, numHeads, timesteps, + * queryCount] + */ +// #if NOT_EXCLUDED(OP_multi_head_dot_product_attention) +@Namespace("sd::ops") public static class multi_head_dot_product_attention extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public multi_head_dot_product_attention(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public multi_head_dot_product_attention(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public multi_head_dot_product_attention position(long position) { + return (multi_head_dot_product_attention)super.position(position); + } - /** - * This performs multi-headed dot product attention on the given timeseries input - * out = concat(head_1, head_2, ..., head_n) * Wo - * head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v) - * - * Optionally with normalization when calculating the attention for each head. - * - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention") - * - * This makes use of dot_product_attention OP support for rank 4 inputs. - * - * Expected arguments: - * q: input 3D array "queries" of shape [batchSize, featureKeys, queryCount] - * k: input 3D array "keys" of shape [batchSize, featureKeys, timesteps] - * v: input 3D array "values" of shape [batchSize, featureValues, timesteps] - * Wq: input query projection weights of shape [numHeads, projectedKeys, featureKeys] - * Wk: input key projection weights of shape [numHeads, projectedKeys, featureKeys] - * Wv: input value projection weights of shape [numHeads, projectedValues, featureValues] - * Wo: output projection weights of shape [numHeads * projectedValues, outSize] - * mask: OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] - * - * integer input arguments: - * 0: normalization, may have two values: zero -> do not apply normalization, one -> apply normalization - * 1: withWeights, may have two values: zero -> do not return weights, one -> return weights - * - * Output Arrays: - * 0: Attention result arrays of shape [batchSize, outSize, queryCount] - * 1: OPTIONAL; Attention weights of shape [batchSize, numHeads, timesteps, queryCount] - */ -// #if NOT_EXCLUDED(OP_multi_head_dot_product_attention) - @Namespace("sd::ops") public static class multi_head_dot_product_attention extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public multi_head_dot_product_attention(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public multi_head_dot_product_attention(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public multi_head_dot_product_attention position(long position) { - return (multi_head_dot_product_attention)super.position(position); - } - public multi_head_dot_product_attention() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class multi_head_dot_product_attention_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public multi_head_dot_product_attention_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public multi_head_dot_product_attention_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public multi_head_dot_product_attention_bp position(long position) { - return (multi_head_dot_product_attention_bp)super.position(position); - } - +@Namespace("sd::ops") public static class multi_head_dot_product_attention_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public multi_head_dot_product_attention_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public multi_head_dot_product_attention_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public multi_head_dot_product_attention_bp position(long position) { + return (multi_head_dot_product_attention_bp)super.position(position); + } + public multi_head_dot_product_attention_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -22773,150 +25205,149 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include - /** - * This op is general matmum implementation. Depending on inputs dimensionality output result might be different. - * matrix x matrix = BLAS gemm - * vector x matrix = BLAS gemm - * vector x vector = BLAS dot - * vector x scalar = element-wise mul - * scalar x vector = element-wise mul - * - * Optional T arguments: - * 0: alpha (where applicable) - * 1: beta (where applicable) - * - * Optional Integer arguments: - * 0: transA (where applicable) - * 1: transB (where applicable) - */ -// #if NOT_EXCLUDED(OP_matmul) - @Namespace("sd::ops") public static class matmul extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matmul(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matmul(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matmul position(long position) { - return (matmul)super.position(position); - } - +/** + * This op is general matmum implementation. Depending on inputs dimensionality + * output result might be different. matrix x matrix = BLAS gemm vector x matrix + * = BLAS gemm vector x vector = BLAS dot vector x scalar = element-wise mul + * scalar x vector = element-wise mul + * + * Optional T arguments: + * 0: alpha (where applicable) + * 1: beta (where applicable) + * + * Optional Integer arguments: + * 0: transA (where applicable) + * 1: transB (where applicable) + */ +// #if NOT_EXCLUDED(OP_matmul) +@Namespace("sd::ops") public static class matmul extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matmul(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matmul(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matmul position(long position) { + return (matmul)super.position(position); + } + public matmul() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class matmul_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public matmul_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public matmul_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public matmul_bp position(long position) { - return (matmul_bp)super.position(position); - } - +@Namespace("sd::ops") public static class matmul_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public matmul_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public matmul_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public matmul_bp position(long position) { + return (matmul_bp)super.position(position); + } + public matmul_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * tensorMmul/tensorDot operation + * takes 2 ndarrays, and 2 sets of axes + * + * Integer argumens map: + * IArgs[0] - number of axes along for first array + * IArgs[1]... axes values for first array + * IArgs[] - number of axes along for second array + * IArgs[1]... axes values for second array + */ +// #if NOT_EXCLUDED(OP_tensormmul) +@Namespace("sd::ops") public static class tensormmul extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tensormmul(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tensormmul(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tensormmul position(long position) { + return (tensormmul)super.position(position); + } - /** - * tensorMmul/tensorDot operation - * takes 2 ndarrays, and 2 sets of axes - * - * Integer argumens map: - * IArgs[0] - number of axes along for first array - * IArgs[1]... axes values for first array - * IArgs[] - number of axes along for second array - * IArgs[1]... axes values for second array - */ -// #if NOT_EXCLUDED(OP_tensormmul) - @Namespace("sd::ops") public static class tensormmul extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tensormmul(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tensormmul(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tensormmul position(long position) { - return (tensormmul)super.position(position); - } - public tensormmul() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class tensormmul_bp extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public tensormmul_bp(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public tensormmul_bp(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public tensormmul_bp position(long position) { - return (tensormmul_bp)super.position(position); - } - +@Namespace("sd::ops") public static class tensormmul_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public tensormmul_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public tensormmul_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public tensormmul_bp position(long position) { + return (tensormmul_bp)super.position(position); + } + public tensormmul_bp() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This op is simple implementation of BLAS AXPY method. + * Math is: y += a * x; + */ +// #if NOT_EXCLUDED(OP_axpy) +@Namespace("sd::ops") public static class axpy extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public axpy(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public axpy(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public axpy position(long position) { + return (axpy)super.position(position); + } - /** - * This op is simple implementation of BLAS AXPY method. - * Math is: y += a * x; - */ -// #if NOT_EXCLUDED(OP_axpy) - @Namespace("sd::ops") public static class axpy extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public axpy(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public axpy(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public axpy position(long position) { - return (axpy)super.position(position); - } - public axpy() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation implements batched matrix multiplication + * Expected arguments: + * alpha: vector of T + * beta: vector of T + * ...: A, B matrices sequentially. i.e: AAAAABBBBB + * + * Integer arguments: + * transA, transB, M, N, K, ldA, ldB, ldC - usual BLAS gemm arguments + * batchCount - number of operations in this batch + * + * PLEASE NOTE: M, N, K, ldA, ldB, ldC should be equal for all matrices within + * batch. + */ +// #if NOT_EXCLUDED(OP_batched_gemm) +@Namespace("sd::ops") public static class batched_gemm extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public batched_gemm(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public batched_gemm(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public batched_gemm position(long position) { + return (batched_gemm)super.position(position); + } - /** - * This operation implements batched matrix multiplication - * Expected arguments: - * alpha: vector of T - * beta: vector of T - * ...: A, B matrices sequentially. i.e: AAAAABBBBB - * - * Integer arguments: - * transA, transB, M, N, K, ldA, ldB, ldC - usual BLAS gemm arguments - * batchCount - number of operations in this batch - * - * PLEASE NOTE: M, N, K, ldA, ldB, ldC should be equal for all matrices within batch. - */ -// #if NOT_EXCLUDED(OP_batched_gemm) - @Namespace("sd::ops") public static class batched_gemm extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public batched_gemm(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public batched_gemm(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public batched_gemm position(long position) { - return (batched_gemm)super.position(position); - } - public batched_gemm() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif /** * performs singular value decomposition (SVD) of one or more matrices, evaluates the SVD of each inner-most 2D matrix in input array: @@ -23009,114 +25440,114 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // @author raver119@gmail.com // // #include -// #if NOT_EXCLUDED(OP_test_output_reshape) - @Namespace("sd::ops") public static class test_output_reshape extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public test_output_reshape(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public test_output_reshape(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public test_output_reshape position(long position) { - return (test_output_reshape)super.position(position); - } - +// #if NOT_EXCLUDED(OP_test_output_reshape) +@Namespace("sd::ops") public static class test_output_reshape extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public test_output_reshape(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public test_output_reshape(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public test_output_reshape position(long position) { + return (test_output_reshape)super.position(position); + } + public test_output_reshape() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_test_scalar) +@Namespace("sd::ops") public static class test_scalar extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public test_scalar(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public test_scalar(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public test_scalar position(long position) { + return (test_scalar)super.position(position); + } -// #if NOT_EXCLUDED(OP_test_scalar) - @Namespace("sd::ops") public static class test_scalar extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public test_scalar(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public test_scalar(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public test_scalar position(long position) { - return (test_scalar)super.position(position); - } - public test_scalar() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_testreduction) +@Namespace("sd::ops") public static class testreduction extends DeclarableReductionOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public testreduction(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public testreduction(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public testreduction position(long position) { + return (testreduction)super.position(position); + } -// #if NOT_EXCLUDED(OP_testreduction) - @Namespace("sd::ops") public static class testreduction extends DeclarableReductionOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public testreduction(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public testreduction(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public testreduction position(long position) { - return (testreduction)super.position(position); - } - public testreduction() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_noop) +@Namespace("sd::ops") public static class noop extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public noop(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public noop(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public noop position(long position) { + return (noop)super.position(position); + } -// #if NOT_EXCLUDED(OP_noop) - @Namespace("sd::ops") public static class noop extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public noop(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public noop(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public noop position(long position) { - return (noop)super.position(position); - } - public noop() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_testop2i2o) +@Namespace("sd::ops") public static class testop2i2o extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public testop2i2o(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public testop2i2o(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public testop2i2o position(long position) { + return (testop2i2o)super.position(position); + } -// #if NOT_EXCLUDED(OP_testop2i2o) - @Namespace("sd::ops") public static class testop2i2o extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public testop2i2o(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public testop2i2o(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public testop2i2o position(long position) { - return (testop2i2o)super.position(position); - } - public testop2i2o() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +// #if NOT_EXCLUDED(OP_testcustom) +@Namespace("sd::ops") public static class testcustom extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public testcustom(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public testcustom(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public testcustom position(long position) { + return (testcustom)super.position(position); + } -// #if NOT_EXCLUDED(OP_testcustom) - @Namespace("sd::ops") public static class testcustom extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public testcustom(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public testcustom(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public testcustom position(long position) { - return (testcustom)super.position(position); - } - public testcustom() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // Parsed from ops/declarable/headers/bitwise.h @@ -23144,226 +25575,228 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_BITWISE_H // #include - /** - * This operation toggles individual bits of each element in array - * - * PLEASE NOTE: This operation is possible only on integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_toggle_bits) - @Namespace("sd::ops") public static class toggle_bits extends DeclarableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public toggle_bits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public toggle_bits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public toggle_bits position(long position) { - return (toggle_bits)super.position(position); - } - +/** + * This operation toggles individual bits of each element in array + * + * PLEASE NOTE: This operation is possible only on integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_toggle_bits) +@Namespace("sd::ops") public static class toggle_bits extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public toggle_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public toggle_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public toggle_bits position(long position) { + return (toggle_bits)super.position(position); + } + public toggle_bits() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This operation shift individual bits of each element in array to the left: << + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_shift_bits) +@Namespace("sd::ops") public static class shift_bits extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public shift_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public shift_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public shift_bits position(long position) { + return (shift_bits)super.position(position); + } - /** - * This operation shift individual bits of each element in array to the left: << - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_shift_bits) - @Namespace("sd::ops") public static class shift_bits extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public shift_bits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public shift_bits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public shift_bits position(long position) { - return (shift_bits)super.position(position); - } - public shift_bits() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation shift individual bits of each element in array to the right: + * >> + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_rshift_bits) +@Namespace("sd::ops") public static class rshift_bits extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public rshift_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public rshift_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public rshift_bits position(long position) { + return (rshift_bits)super.position(position); + } - /** - * This operation shift individual bits of each element in array to the right: >> - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_rshift_bits) - @Namespace("sd::ops") public static class rshift_bits extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public rshift_bits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public rshift_bits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public rshift_bits position(long position) { - return (rshift_bits)super.position(position); - } - public rshift_bits() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation shift individual bits of each element in array, shifting to + * the left + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_cyclic_shift_bits) +@Namespace("sd::ops") public static class cyclic_shift_bits extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cyclic_shift_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cyclic_shift_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cyclic_shift_bits position(long position) { + return (cyclic_shift_bits)super.position(position); + } - /** - * This operation shift individual bits of each element in array, shifting to the left - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_cyclic_shift_bits) - @Namespace("sd::ops") public static class cyclic_shift_bits extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cyclic_shift_bits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cyclic_shift_bits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cyclic_shift_bits position(long position) { - return (cyclic_shift_bits)super.position(position); - } - public cyclic_shift_bits() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation shift individual bits of each element in array, shifting to + * the right + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_cyclic_rshift_bits) +@Namespace("sd::ops") public static class cyclic_rshift_bits extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cyclic_rshift_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cyclic_rshift_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cyclic_rshift_bits position(long position) { + return (cyclic_rshift_bits)super.position(position); + } - /** - * This operation shift individual bits of each element in array, shifting to the right - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_cyclic_rshift_bits) - @Namespace("sd::ops") public static class cyclic_rshift_bits extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cyclic_rshift_bits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cyclic_rshift_bits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cyclic_rshift_bits position(long position) { - return (cyclic_rshift_bits)super.position(position); - } - public cyclic_rshift_bits() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation applies bitwise AND + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bitwise_and) +@Namespace("sd::ops") public static class bitwise_and extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitwise_and(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bitwise_and(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitwise_and position(long position) { + return (bitwise_and)super.position(position); + } - /** - * This operation applies bitwise AND - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_bitwise_and) - @Namespace("sd::ops") public static class bitwise_and extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public bitwise_and(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public bitwise_and(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public bitwise_and position(long position) { - return (bitwise_and)super.position(position); - } - public bitwise_and() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation applies bitwise OR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bitwise_or) +@Namespace("sd::ops") public static class bitwise_or extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitwise_or(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bitwise_or(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitwise_or position(long position) { + return (bitwise_or)super.position(position); + } - /** - * This operation applies bitwise OR - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_bitwise_or) - @Namespace("sd::ops") public static class bitwise_or extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public bitwise_or(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public bitwise_or(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public bitwise_or position(long position) { - return (bitwise_or)super.position(position); - } - public bitwise_or() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation applies bitwise XOR + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bitwise_xor) +@Namespace("sd::ops") public static class bitwise_xor extends BroadcastableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitwise_xor(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bitwise_xor(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitwise_xor position(long position) { + return (bitwise_xor)super.position(position); + } - /** - * This operation applies bitwise XOR - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_bitwise_xor) - @Namespace("sd::ops") public static class bitwise_xor extends BroadcastableOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public bitwise_xor(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public bitwise_xor(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public bitwise_xor position(long position) { - return (bitwise_xor)super.position(position); - } - public bitwise_xor() { super((Pointer)null); allocate(); } private native void allocate(); } -// #endif +// #endif + +/** + * This operation returns hamming distance based on bits + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bits_hamming_distance) +@Namespace("sd::ops") public static class bits_hamming_distance extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bits_hamming_distance(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bits_hamming_distance(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bits_hamming_distance position(long position) { + return (bits_hamming_distance)super.position(position); + } - /** - * This operation returns hamming distance based on bits - * - * PLEASE NOTE: This operation is applicable only to integer data types - * - * \tparam T - */ -// #if NOT_EXCLUDED(OP_bits_hamming_distance) - @Namespace("sd::ops") public static class bits_hamming_distance extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public bits_hamming_distance(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public bits_hamming_distance(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public bits_hamming_distance position(long position) { - return (bits_hamming_distance)super.position(position); - } - public bits_hamming_distance() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -23393,692 +25826,731 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_LOSS_H // #include - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of hinge loss function max(0, 1 - labels*logits) - * - * Input arrays: - * 0: logits - logits, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as logits or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_hinge_loss) - @Namespace("sd::ops") public static class hinge_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public hinge_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public hinge_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public hinge_loss position(long position) { - return (hinge_loss)super.position(position); - } - + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of hinge loss function max(0, 1 - labels*logits) + * + * Input arrays: + * 0: logits - logits, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as logits or just single scalar, + * depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_hinge_loss) +@Namespace("sd::ops") public static class hinge_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public hinge_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public hinge_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public hinge_loss position(long position) { + return (hinge_loss)super.position(position); + } + public hinge_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class hinge_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public hinge_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public hinge_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public hinge_loss_grad position(long position) { - return (hinge_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class hinge_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public hinge_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public hinge_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public hinge_loss_grad position(long position) { + return (hinge_loss_grad)super.position(position); + } + public hinge_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of Huber loss function: + * 0.5 * (labels-predictions)^2 if + * |labels-predictions| <= delta 0.5 * delta^2 + delta * (|labels-predictions| - + * delta) if |labels-predictions| > delta + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Input float arguments: + * 0: point where the huber loss function changes from a quadratic to linear. + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_huber_loss) +@Namespace("sd::ops") public static class huber_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public huber_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public huber_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public huber_loss position(long position) { + return (huber_loss)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of Huber loss function: - * 0.5 * (labels-predictions)^2 if |labels-predictions| <= delta - * 0.5 * delta^2 + delta * (|labels-predictions| - delta) if |labels-predictions| > delta - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: point where the huber loss function changes from a quadratic to linear. - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_huber_loss) - @Namespace("sd::ops") public static class huber_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public huber_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public huber_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public huber_loss position(long position) { - return (huber_loss)super.position(position); - } - public huber_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class huber_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public huber_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public huber_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public huber_loss_grad position(long position) { - return (huber_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class huber_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public huber_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public huber_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public huber_loss_grad position(long position) { + return (huber_loss_grad)super.position(position); + } + public huber_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of logarithmic loss function ( y_i * log(p_i) + (1 - y_i) * + * log(1 - p_i) ) + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Input float arguments: + * 0: a small increment to add to avoid taking a log of zero. + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_log_loss) +@Namespace("sd::ops") public static class log_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public log_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public log_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public log_loss position(long position) { + return (log_loss)super.position(position); + } - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of logarithmic loss function ( y_i * log(p_i) + (1 - y_i) * log(1 - p_i) ) - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: a small increment to add to avoid taking a log of zero. - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_log_loss) - @Namespace("sd::ops") public static class log_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public log_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public log_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public log_loss position(long position) { - return (log_loss)super.position(position); - } - public log_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class log_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public log_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public log_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public log_loss_grad position(long position) { - return (log_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class log_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public log_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public log_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public log_loss_grad position(long position) { + return (log_loss_grad)super.position(position); + } + public log_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * l2_loss op. + * compute a l2 norm for given array. + * + * input param - an array (tensor) + * output value - a real number with given type (e.g. float or double) + */ +// #if NOT_EXCLUDED(OP_l2_loss) +@Namespace("sd::ops") public static class l2_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public l2_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public l2_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public l2_loss position(long position) { + return (l2_loss)super.position(position); + } - /** - * l2_loss op. - * compute a l2 norm for given array. - * - * input param - an array (tensor) - * output value - a real number with given type (e.g. float or double) - */ -// #if NOT_EXCLUDED(OP_l2_loss) - @Namespace("sd::ops") public static class l2_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public l2_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public l2_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public l2_loss position(long position) { - return (l2_loss)super.position(position); - } - public l2_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +/** + * This op calculates logarithmic loss of poisson distributed input. + * Input arrays: + * 0: log_predictions - must be already pre-transformed to log(x) + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights 1: optional - boolean value compute_full_loss: 0 + * (default) or 1 (compute) + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as log_predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_log_poisson_loss) +@Namespace("sd::ops") public static class log_poisson_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public log_poisson_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public log_poisson_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public log_poisson_loss position(long position) { + return (log_poisson_loss)super.position(position); + } - /** - * This op calculates logarithmic loss of poisson distributed input. - * Input arrays: - * 0: log_predictions - must be already pre-transformed to log(x) - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * 1: optional - boolean value compute_full_loss: 0 (default) or 1 (compute) - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as log_predictions or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_log_poisson_loss) - @Namespace("sd::ops") public static class log_poisson_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public log_poisson_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public log_poisson_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public log_poisson_loss position(long position) { - return (log_poisson_loss)super.position(position); - } - public log_poisson_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class log_poisson_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public log_poisson_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public log_poisson_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public log_poisson_loss_grad position(long position) { - return (log_poisson_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class log_poisson_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public log_poisson_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public log_poisson_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public log_poisson_loss_grad position(long position) { + return (log_poisson_loss_grad)super.position(position); + } + public log_poisson_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of pairwise-errors-squared loss function + * + * Input arrays: + * 0: predictions - the predicted values, type float. + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Output array: + * 0: loss value, it is just single scalar, type float. + */ +// #if NOT_EXCLUDED(OP_mean_pairwssqerr_loss) +@Namespace("sd::ops") public static class mean_pairwssqerr_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mean_pairwssqerr_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mean_pairwssqerr_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mean_pairwssqerr_loss position(long position) { + return (mean_pairwssqerr_loss)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of pairwise-errors-squared loss function - * - * Input arrays: - * 0: predictions - the predicted values, type float. - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Output array: - * 0: loss value, it is just single scalar, type float. - */ -// #if NOT_EXCLUDED(OP_mean_pairwssqerr_loss) - @Namespace("sd::ops") public static class mean_pairwssqerr_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mean_pairwssqerr_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mean_pairwssqerr_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mean_pairwssqerr_loss position(long position) { - return (mean_pairwssqerr_loss)super.position(position); - } - public mean_pairwssqerr_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class mean_pairwssqerr_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mean_pairwssqerr_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mean_pairwssqerr_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mean_pairwssqerr_loss_grad position(long position) { - return (mean_pairwssqerr_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class mean_pairwssqerr_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mean_pairwssqerr_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mean_pairwssqerr_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mean_pairwssqerr_loss_grad position(long position) { + return (mean_pairwssqerr_loss_grad)super.position(position); + } + public mean_pairwssqerr_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of Sum-of-Squares loss function 1/N * + * sum_{i}^{N}(predictions_i - labels_i)^2 + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_mean_sqerr_loss) +@Namespace("sd::ops") public static class mean_sqerr_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mean_sqerr_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mean_sqerr_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mean_sqerr_loss position(long position) { + return (mean_sqerr_loss)super.position(position); + } - - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of Sum-of-Squares loss function 1/N * sum_{i}^{N}(predictions_i - labels_i)^2 - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_mean_sqerr_loss) - @Namespace("sd::ops") public static class mean_sqerr_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mean_sqerr_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mean_sqerr_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mean_sqerr_loss position(long position) { - return (mean_sqerr_loss)super.position(position); - } - public mean_sqerr_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class mean_sqerr_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public mean_sqerr_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public mean_sqerr_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public mean_sqerr_loss_grad position(long position) { - return (mean_sqerr_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class mean_sqerr_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public mean_sqerr_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public mean_sqerr_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public mean_sqerr_loss_grad position(long position) { + return (mean_sqerr_loss_grad)super.position(position); + } + public mean_sqerr_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of sigmoid cross-entropy loss function max(logits, 0.) - + * logits * labels + log(1. + exp(-abs(logits))); + * + * Input arrays: + * 0: logits - logits, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights + * + * Input float arguments: + * 0: smoothing value, if it is greater than 0 then apply smoothing to the + * labels (smooth the labels towards 1/2): new_labels = labels * (1 - + * labelsSmoothing)+ 0.5 * labelsSmoothing + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as logits or just single scalar, + * depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_sigm_cross_entropy_loss) +@Namespace("sd::ops") public static class sigm_cross_entropy_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sigm_cross_entropy_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sigm_cross_entropy_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sigm_cross_entropy_loss position(long position) { + return (sigm_cross_entropy_loss)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of sigmoid cross-entropy loss function max(logits, 0.) - logits * labels + log(1. + exp(-abs(logits))); - * - * Input arrays: - * 0: logits - logits, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: smoothing value, if it is greater than 0 then apply smoothing to the labels (smooth the labels towards 1/2): new_labels = labels * (1 - labelsSmoothing)+ 0.5 * labelsSmoothing - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as logits or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_sigm_cross_entropy_loss) - @Namespace("sd::ops") public static class sigm_cross_entropy_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sigm_cross_entropy_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sigm_cross_entropy_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sigm_cross_entropy_loss position(long position) { - return (sigm_cross_entropy_loss)super.position(position); - } - public sigm_cross_entropy_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class sigm_cross_entropy_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sigm_cross_entropy_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sigm_cross_entropy_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sigm_cross_entropy_loss_grad position(long position) { - return (sigm_cross_entropy_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class sigm_cross_entropy_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sigm_cross_entropy_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sigm_cross_entropy_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sigm_cross_entropy_loss_grad position(long position) { + return (sigm_cross_entropy_loss_grad)super.position(position); + } + public sigm_cross_entropy_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of softmax cross-entropy loss function max(logits, 0.) - + * logits * labels + log(1. + exp(-abs(logits))); + * + * Input arrays: + * 0: logits - logits, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels, and must be + * broadcastable to labels. 2: labels - ground truth vales, expected to be 0. + * or 1., type float. Must have the same shape as logits. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as logits. + * 1 - "weighted_sum", output is scalar and equal to sum of all elements + * of weightedLosses array 2 - "weighted_mean", output is scalar and equal to + * sum of all elements of weightedLosses array divided by sum of all elements of + * weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output is scalar + * and equal to scalar sum of all elements of weightedLosses array divided by + * number of non-zero weights + * + * Input float arguments: + * 0: smoothing value, if it is greater than 0 then apply smoothing to the + * labels (smooth the labels towards 1/numClasses): new_labels = labels * (1 - + * labelsSmoothing) + labelsSmoothing / numClasses + * + * Output array: + * 0: loss values, type float. + * Can be an array with shape as in logits except last dimension is equal + * to unity or just single scalar, depending on reduction mode (see input + * integer argument) + */ +// #if NOT_EXCLUDED(OP_softmax_cross_entropy_loss) +@Namespace("sd::ops") public static class softmax_cross_entropy_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softmax_cross_entropy_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softmax_cross_entropy_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softmax_cross_entropy_loss position(long position) { + return (softmax_cross_entropy_loss)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of softmax cross-entropy loss function max(logits, 0.) - logits * labels + log(1. + exp(-abs(logits))); - * - * Input arrays: - * 0: logits - logits, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels, and must be broadcastable to labels. - * 2: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as logits. - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Input float arguments: - * 0: smoothing value, if it is greater than 0 then apply smoothing to the labels (smooth the labels towards 1/numClasses): new_labels = labels * (1 - labelsSmoothing) + labelsSmoothing / numClasses - * - * Output array: - * 0: loss values, type float. - * Can be an array with shape as in logits except last dimension is equal to unity or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_softmax_cross_entropy_loss) - @Namespace("sd::ops") public static class softmax_cross_entropy_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softmax_cross_entropy_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softmax_cross_entropy_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softmax_cross_entropy_loss position(long position) { - return (softmax_cross_entropy_loss)super.position(position); - } - public softmax_cross_entropy_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } - @Namespace("sd::ops") public static class softmax_cross_entropy_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softmax_cross_entropy_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softmax_cross_entropy_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softmax_cross_entropy_loss_grad position(long position) { - return (softmax_cross_entropy_loss_grad)super.position(position); - } - + } +@Namespace("sd::ops") public static class softmax_cross_entropy_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softmax_cross_entropy_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softmax_cross_entropy_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softmax_cross_entropy_loss_grad position(long position) { + return (softmax_cross_entropy_loss_grad)super.position(position); + } + public softmax_cross_entropy_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } -// #endif + } +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of Absolute Difference loss function |predictions - labels| + * + * Input arrays: + * 0: predictions - the predicted values, type float. + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_absolute_difference_loss) +@Namespace("sd::ops") public static class absolute_difference_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public absolute_difference_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public absolute_difference_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public absolute_difference_loss position(long position) { + return (absolute_difference_loss)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of Absolute Difference loss function |predictions - labels| - * - * Input arrays: - * 0: predictions - the predicted values, type float. - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_absolute_difference_loss) - @Namespace("sd::ops") public static class absolute_difference_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public absolute_difference_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public absolute_difference_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public absolute_difference_loss position(long position) { - return (absolute_difference_loss)super.position(position); - } - public absolute_difference_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class absolute_difference_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public absolute_difference_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public absolute_difference_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public absolute_difference_loss_grad position(long position) { - return (absolute_difference_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class absolute_difference_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public absolute_difference_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public absolute_difference_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public absolute_difference_loss_grad position(long position) { + return (absolute_difference_loss_grad)super.position(position); + } + public absolute_difference_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of cosine-distance loss function 1. - (predictions * + * labels).reduce_sum_along(dimension) + * + * Input arrays: + * 0: predictions - the predicted values, type float + * 1: weights - is used for weighting (multiplying) of loss values, type + * float. Can be single scalar or has the same rank as labels and must be + * broadcastable to labels. 2: labels - ground truth vales, type float. Must + * have the same shape as predictions. + * + * Input integer arguments: + * 0: type of reduction to apply to loss + * 0 - "none", unreduced weighted losses with the same shape as + * predictions 1 - "weighted_sum", output is scalar and equal to sum of all + * elements of weightedLosses array 2 - "weighted_mean", output is scalar and + * equal to sum of all elements of weightedLosses array divided by sum of all + * elements of weightsBroad array 3 - "weighted_sum_by_nonzero_weights", output + * is scalar and equal to scalar sum of all elements of weightedLosses array + * divided by number of non-zero weights 1: dimension along which the cosine + * distance is computed. + * + * Output array: + * 0: loss values, type float. + * Can be an array with the same shape as predictions or just single + * scalar, depending on reduction mode (see input integer argument) + */ +// #if NOT_EXCLUDED(OP_cosine_distance_loss) +@Namespace("sd::ops") public static class cosine_distance_loss extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cosine_distance_loss(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cosine_distance_loss(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cosine_distance_loss position(long position) { + return (cosine_distance_loss)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of cosine-distance loss function 1. - (predictions * labels).reduce_sum_along(dimension) - * - * Input arrays: - * 0: predictions - the predicted values, type float - * 1: weights - is used for weighting (multiplying) of loss values, type float. - * Can be single scalar or has the same rank as labels and must be broadcastable to labels. - * 2: labels - ground truth vales, type float. - * Must have the same shape as predictions. - * - * Input integer arguments: - * 0: type of reduction to apply to loss - * 0 - "none", unreduced weighted losses with the same shape as predictions - * 1 - "weighted_sum", output is scalar and equal to sum of all elements of weightedLosses array - * 2 - "weighted_mean", output is scalar and equal to sum of all elements of weightedLosses array divided by sum of all elements of weightsBroad array - * 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of weightedLosses array divided by number of non-zero weights - * 1: dimension along which the cosine distance is computed. - * - * Output array: - * 0: loss values, type float. - * Can be an array with the same shape as predictions or just single scalar, depending on reduction mode (see input integer argument) - */ -// #if NOT_EXCLUDED(OP_cosine_distance_loss) - @Namespace("sd::ops") public static class cosine_distance_loss extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cosine_distance_loss(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cosine_distance_loss(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cosine_distance_loss position(long position) { - return (cosine_distance_loss)super.position(position); - } - public cosine_distance_loss() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class cosine_distance_loss_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cosine_distance_loss_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cosine_distance_loss_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cosine_distance_loss_grad position(long position) { - return (cosine_distance_loss_grad)super.position(position); - } - +@Namespace("sd::ops") public static class cosine_distance_loss_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cosine_distance_loss_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cosine_distance_loss_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cosine_distance_loss_grad position(long position) { + return (cosine_distance_loss_grad)super.position(position); + } + public cosine_distance_loss_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of softmax cross-entropy loss function + * + * Input arrays: + * 0: logits - logits, type float + * 1: labels - ground truth vales, expected to be 0. or 1., type float. + * Must have the same shape as logits. + * + * Input integer arguments: + * 0: optional (default is last dimension) dimension with classes + * + * Output array: + * 0: loss values, type float. An array with shape resulting from reducing of + * logits shape along dimension with classes + */ +// #if NOT_EXCLUDED(OP_softmax_cross_entropy_loss_with_logits) +@Namespace("sd::ops") public static class softmax_cross_entropy_loss_with_logits extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softmax_cross_entropy_loss_with_logits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softmax_cross_entropy_loss_with_logits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softmax_cross_entropy_loss_with_logits position(long position) { + return (softmax_cross_entropy_loss_with_logits)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of softmax cross-entropy loss function - * - * Input arrays: - * 0: logits - logits, type float - * 1: labels - ground truth vales, expected to be 0. or 1., type float. - * Must have the same shape as logits. - * - * Input integer arguments: - * 0: optional (default is last dimension) dimension with classes - * - * Output array: - * 0: loss values, type float. An array with shape resulting from reducing of logits shape along dimension with classes - */ -// #if NOT_EXCLUDED(OP_softmax_cross_entropy_loss_with_logits) - @Namespace("sd::ops") public static class softmax_cross_entropy_loss_with_logits extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softmax_cross_entropy_loss_with_logits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softmax_cross_entropy_loss_with_logits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softmax_cross_entropy_loss_with_logits position(long position) { - return (softmax_cross_entropy_loss_with_logits)super.position(position); - } - public softmax_cross_entropy_loss_with_logits() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class softmax_cross_entropy_loss_with_logits_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public softmax_cross_entropy_loss_with_logits_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public softmax_cross_entropy_loss_with_logits_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public softmax_cross_entropy_loss_with_logits_grad position(long position) { - return (softmax_cross_entropy_loss_with_logits_grad)super.position(position); - } - +@Namespace("sd::ops") public static class softmax_cross_entropy_loss_with_logits_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public softmax_cross_entropy_loss_with_logits_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public softmax_cross_entropy_loss_with_logits_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public softmax_cross_entropy_loss_with_logits_grad position(long position) { + return (softmax_cross_entropy_loss_with_logits_grad)super.position(position); + } + public softmax_cross_entropy_loss_with_logits_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +////////////////////////////////////////////////////////////////////////// +/** + * Implementation of sparse softmax cross-entropy loss function + * + * Input arrays: + * 0: labels - ground truth vales, expected to be within range [0, + * num_classes), type float. Must have rank equal logits rank minus 1. 1: logits + * - logits, type float + * + * Output array: + * 0: loss values, type float. Has the same shape as labels + */ +// #if NOT_EXCLUDED(OP_sparse_softmax_cross_entropy_loss_with_logits) +@Namespace("sd::ops") public static class sparse_softmax_cross_entropy_loss_with_logits extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sparse_softmax_cross_entropy_loss_with_logits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sparse_softmax_cross_entropy_loss_with_logits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sparse_softmax_cross_entropy_loss_with_logits position(long position) { + return (sparse_softmax_cross_entropy_loss_with_logits)super.position(position); + } - ////////////////////////////////////////////////////////////////////////// - /** - * Implementation of sparse softmax cross-entropy loss function - * - * Input arrays: - * 0: labels - ground truth vales, expected to be within range [0, num_classes), type float. - * Must have rank equal logits rank minus 1. - * 1: logits - logits, type float - * - * Output array: - * 0: loss values, type float. Has the same shape as labels - */ -// #if NOT_EXCLUDED(OP_sparse_softmax_cross_entropy_loss_with_logits) - @Namespace("sd::ops") public static class sparse_softmax_cross_entropy_loss_with_logits extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sparse_softmax_cross_entropy_loss_with_logits(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sparse_softmax_cross_entropy_loss_with_logits(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sparse_softmax_cross_entropy_loss_with_logits position(long position) { - return (sparse_softmax_cross_entropy_loss_with_logits)super.position(position); - } - public sparse_softmax_cross_entropy_loss_with_logits() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } - @Namespace("sd::ops") public static class sparse_softmax_cross_entropy_loss_with_logits_grad extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public sparse_softmax_cross_entropy_loss_with_logits_grad(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public sparse_softmax_cross_entropy_loss_with_logits_grad(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public sparse_softmax_cross_entropy_loss_with_logits_grad position(long position) { - return (sparse_softmax_cross_entropy_loss_with_logits_grad)super.position(position); - } - +@Namespace("sd::ops") public static class sparse_softmax_cross_entropy_loss_with_logits_grad extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public sparse_softmax_cross_entropy_loss_with_logits_grad(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public sparse_softmax_cross_entropy_loss_with_logits_grad(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public sparse_softmax_cross_entropy_loss_with_logits_grad position(long position) { + return (sparse_softmax_cross_entropy_loss_with_logits_grad)super.position(position); + } + public sparse_softmax_cross_entropy_loss_with_logits_grad() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - - +// #endif + // namespace ops + // namespace sd // #endif @@ -24106,218 +26578,221 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_HEADERS_DTYPE_H // #define LIBND4J_HEADERS_DTYPE_H -// #include - /** - * This operation casts elements of input array to double data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ -// #if NOT_EXCLUDED(OP_to_double) - @Namespace("sd::ops") public static class to_double extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public to_double(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public to_double(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public to_double position(long position) { - return (to_double)super.position(position); - } - +// #include +/** + * This operation casts elements of input array to double data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +// #if NOT_EXCLUDED(OP_to_double) +@Namespace("sd::ops") public static class to_double extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public to_double(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public to_double(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public to_double position(long position) { + return (to_double)super.position(position); + } + public to_double() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation casts elements of input array to float16 data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +// #if NOT_EXCLUDED(OP_to_float16) +@Namespace("sd::ops") public static class to_float16 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public to_float16(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public to_float16(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public to_float16 position(long position) { + return (to_float16)super.position(position); + } - /** - * This operation casts elements of input array to float16 data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ -// #if NOT_EXCLUDED(OP_to_float16) - @Namespace("sd::ops") public static class to_float16 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public to_float16(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public to_float16(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public to_float16 position(long position) { - return (to_float16)super.position(position); - } - public to_float16() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation casts elements of input array to float data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +// #if NOT_EXCLUDED(OP_to_float32) +@Namespace("sd::ops") public static class to_float32 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public to_float32(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public to_float32(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public to_float32 position(long position) { + return (to_float32)super.position(position); + } - /** - * This operation casts elements of input array to float data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ -// #if NOT_EXCLUDED(OP_to_float32) - @Namespace("sd::ops") public static class to_float32 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public to_float32(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public to_float32(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public to_float32 position(long position) { - return (to_float32)super.position(position); - } - public to_float32() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation casts elements of input array to int32 data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +// #if NOT_EXCLUDED(OP_to_int32) +@Namespace("sd::ops") public static class to_int32 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public to_int32(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public to_int32(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public to_int32 position(long position) { + return (to_int32)super.position(position); + } - /** - * This operation casts elements of input array to int32 data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ -// #if NOT_EXCLUDED(OP_to_int32) - @Namespace("sd::ops") public static class to_int32 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public to_int32(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public to_int32(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public to_int32 position(long position) { - return (to_int32)super.position(position); - } - public to_int32() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation casts elements of input array to int64 (aka long long) data + * type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +// #if NOT_EXCLUDED(OP_to_int64) +@Namespace("sd::ops") public static class to_int64 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public to_int64(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public to_int64(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public to_int64 position(long position) { + return (to_int64)super.position(position); + } - /** - * This operation casts elements of input array to int64 (aka long long) data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ -// #if NOT_EXCLUDED(OP_to_int64) - @Namespace("sd::ops") public static class to_int64 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public to_int64(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public to_int64(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public to_int64 position(long position) { - return (to_int64)super.position(position); - } - public to_int64() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation casts elements of input array to unsinged int32 data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +// #if NOT_EXCLUDED(OP_to_uint32) +@Namespace("sd::ops") public static class to_uint32 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public to_uint32(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public to_uint32(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public to_uint32 position(long position) { + return (to_uint32)super.position(position); + } - /** - * This operation casts elements of input array to unsinged int32 data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ -// #if NOT_EXCLUDED(OP_to_uint32) - @Namespace("sd::ops") public static class to_uint32 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public to_uint32(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public to_uint32(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public to_uint32 position(long position) { - return (to_uint32)super.position(position); - } - public to_uint32() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation casts elements of input array to unsigned int64 (aka unsigned + * long long) data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + */ +// #if NOT_EXCLUDED(OP_to_uint64) +@Namespace("sd::ops") public static class to_uint64 extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public to_uint64(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public to_uint64(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public to_uint64 position(long position) { + return (to_uint64)super.position(position); + } - /** - * This operation casts elements of input array to unsigned int64 (aka unsigned long long) data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - */ -// #if NOT_EXCLUDED(OP_to_uint64) - @Namespace("sd::ops") public static class to_uint64 extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public to_uint64(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public to_uint64(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public to_uint64 position(long position) { - return (to_uint64)super.position(position); - } - public to_uint64() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif +// #endif + +/** + * This operation casts elements of input array to specified data type + * + * PLEASE NOTE: This op is disabled atm, and reserved for future releases. + * + * + * Int args: + * 0: target DataType + */ +// #if NOT_EXCLUDED(OP_cast) +@Namespace("sd::ops") public static class cast extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cast(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cast(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cast position(long position) { + return (cast)super.position(position); + } - /** - * This operation casts elements of input array to specified data type - * - * PLEASE NOTE: This op is disabled atm, and reserved for future releases. - * - * - * Int args: - * 0: target DataType - */ -// #if NOT_EXCLUDED(OP_cast) - @Namespace("sd::ops") public static class cast extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public cast(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public cast(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public cast position(long position) { - return (cast)super.position(position); - } - public cast() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - /** - * This operation change type of input and modified shape of output to conform with given data type - * - * all as above op - * */ -// #if NOT_EXCLUDED(OP_bitcast) - @Namespace("sd::ops") public static class bitcast extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public bitcast(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public bitcast(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public bitcast position(long position) { - return (bitcast)super.position(position); - } - +// #endif +/** + * This operation change type of input and modified shape of output to conform + * with given data type + * + * all as above op + * */ +// #if NOT_EXCLUDED(OP_bitcast) +@Namespace("sd::ops") public static class bitcast extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bitcast(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bitcast(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bitcast position(long position) { + return (bitcast)super.position(position); + } + public bitcast() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif @@ -24346,56 +26821,57 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_CONTEXTBUFFERS_H // #define LIBND4J_CONTEXTBUFFERS_H +// #include // #include // #include -// #include - @Namespace("sd") @NoOffset public static class ContextBuffers extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public ContextBuffers(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public ContextBuffers(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public ContextBuffers position(long position) { - return (ContextBuffers)super.position(position); - } - - public ContextBuffers() { super((Pointer)null); allocate(); } - private native void allocate(); - public ContextBuffers(@Const @ByRef ContextBuffers other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef ContextBuffers other); - public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); } - private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/); - public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); } - private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer); - - public native @ByRef @Name("operator =") ContextBuffers put(@Const @ByRef ContextBuffers other); +@Namespace("sd") @NoOffset public static class ContextBuffers extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ContextBuffers(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ContextBuffers(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ContextBuffers position(long position) { + return (ContextBuffers)super.position(position); + } - public native void release(); + public ContextBuffers() { super((Pointer)null); allocate(); } + private native void allocate(); + public ContextBuffers(@Const @ByRef ContextBuffers other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef ContextBuffers other); + public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, + @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); } + private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, + @Cast("bool") boolean isOwner/*=false*/); + public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); } + private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer); - public native Pointer reductionBuffer(); - public native Pointer scalarBuffer(); - public native Pointer allocationBuffer(); + public native @ByRef @Name("operator =") ContextBuffers put(@Const @ByRef ContextBuffers other); - public native Pointer execStream(); - public native Pointer specialStream(); + public native void release(); - public native void setReductionBuffer(Pointer pointer); - public native void setScalarBuffer(Pointer pointer); - public native void setAllocationBuffer(Pointer pointer); + public native Pointer reductionBuffer(); + public native Pointer scalarBuffer(); + public native Pointer allocationBuffer(); - public native ErrorReference errorReference(); + public native Pointer execStream(); + public native Pointer specialStream(); - public native void triggerOwnership(@Cast("bool") boolean isOwner); + public native void setReductionBuffer(Pointer pointer); + public native void setScalarBuffer(Pointer pointer); + public native void setAllocationBuffer(Pointer pointer); - public native int deviceId(); + public native ErrorReference errorReference(); - public native @Cast("bool") boolean isInitialized(); - } + public native void triggerOwnership(@Cast("bool") boolean isOwner); + public native int deviceId(); + public native @Cast("bool") boolean isInitialized(); +} + // namespace sd -// #endif //DEV_TESTS_CONTEXTBUFFERS_H +// #endif // SD_CONTEXTBUFFERS_H // Parsed from execution/LaunchContext.h @@ -24423,7 +26899,6 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J_CUDACONTEXT_H // #define LIBND4J_CUDACONTEXT_H - // #ifdef __CUDABLAS__ // #endif @@ -24432,14 +26907,15 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #include "config.h" // #endif +// #include +// #include +// #include // #include -// #include // #include -// #include -// #include + +// #include // #include -// #include -// #include +// #include @Namespace("sd") @NoOffset public static class LaunchContext extends Pointer { static { Loader.load(); } @@ -24452,41 +26928,41 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifdef __CUDABLAS__ -// #endif // CUDA - public LaunchContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer/*=nullptr*/, @Cast("Nd4jPointer") Pointer scalarPointer/*=nullptr*/, @Cast("Nd4jPointer") Pointer allocationPointer/*=nullptr*/) { super((Pointer)null); allocate(cudaStream, reductionPointer, scalarPointer, allocationPointer); } - private native void allocate(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer/*=nullptr*/, @Cast("Nd4jPointer") Pointer scalarPointer/*=nullptr*/, @Cast("Nd4jPointer") Pointer allocationPointer/*=nullptr*/); - public LaunchContext(@Cast("Nd4jPointer") Pointer cudaStream) { super((Pointer)null); allocate(cudaStream); } - private native void allocate(@Cast("Nd4jPointer") Pointer cudaStream); - public LaunchContext() { super((Pointer)null); allocate(); } - private native void allocate(); - public native Workspace getWorkspace(); - public native void setWorkspace(Workspace theWorkspace); +// #endif // CUDA + public LaunchContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer/*=nullptr*/, + @Cast("Nd4jPointer") Pointer scalarPointer/*=nullptr*/, + @Cast("Nd4jPointer") Pointer allocationPointer/*=nullptr*/) { super((Pointer)null); allocate(cudaStream, reductionPointer, scalarPointer, allocationPointer); } + private native void allocate(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer/*=nullptr*/, + @Cast("Nd4jPointer") Pointer scalarPointer/*=nullptr*/, + @Cast("Nd4jPointer") Pointer allocationPointer/*=nullptr*/); + public LaunchContext(@Cast("Nd4jPointer") Pointer cudaStream) { super((Pointer)null); allocate(cudaStream); } + private native void allocate(@Cast("Nd4jPointer") Pointer cudaStream); + public LaunchContext() { super((Pointer)null); allocate(); } + private native void allocate(); + public native Workspace getWorkspace(); + public native void setWorkspace(Workspace theWorkspace); - public native Pointer engine(); + public native Pointer engine(); - public native int getDeviceID(); - public native void setDeviceID(int deviceID); - public native ErrorReference errorReference(); + public native int getDeviceID(); + public native void setDeviceID(int deviceID); + public native ErrorReference errorReference(); // #ifndef __JAVACPP_HACK__ // #endif - public static native @Cast("bool") boolean isInitialized(); - public static native void releaseBuffers(); - - - public static native LaunchContext defaultContext(); + public static native @Cast("bool") boolean isInitialized(); + public static native void releaseBuffers(); + public static native LaunchContext defaultContext(); - public static native void swapContextBuffers(@ByRef ContextBuffers buffers); - + public static native void swapContextBuffers(@ByRef ContextBuffers buffers); } + // namespace sd - - -// #endif //LIBND4J_CUDACONTEXT_H +// #endif // LIBND4J_CUDACONTEXT_H // Parsed from array/ShapeDescriptor.h @@ -24511,15 +26987,16 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // @author raver119@gmail.com // -// #ifndef DEV_TESTS_SHAPEDESCRIPTOR_H -// #define DEV_TESTS_SHAPEDESCRIPTOR_H +// #ifndef SD_SHAPEDESCRIPTOR_H +// #define SD_SHAPEDESCRIPTOR_H -// #include -// #include +// #include // #include // #include -// #include + // #include +// #include +// #include @Namespace("sd") @NoOffset public static class ShapeDescriptor extends Pointer { static { Loader.load(); } @@ -24532,108 +27009,185 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint return (ShapeDescriptor)super.position(position); } - public ShapeDescriptor(@Const @ByRef ShapeDescriptor other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef ShapeDescriptor other); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/) { super((Pointer)null); allocate(shapeInfo, inheritDtype); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/) { super((Pointer)null); allocate(shapeInfo, inheritDtype); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo); - public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/) { super((Pointer)null); allocate(shapeInfo, inheritDtype); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/); - public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtypeOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtypeOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtypeOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer dtypeOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer dtypeOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] dtypeOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer dtypeOverride, @Cast("const Nd4jLong*") LongPointer orderOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride, orderOverride); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer dtypeOverride, @Cast("const Nd4jLong*") LongPointer orderOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer dtypeOverride, @Cast("const Nd4jLong*") LongBuffer orderOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride, orderOverride); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer dtypeOverride, @Cast("const Nd4jLong*") LongBuffer orderOverride); - public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] dtypeOverride, @Cast("const Nd4jLong*") long[] orderOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride, orderOverride); } - private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] dtypeOverride, @Cast("const Nd4jLong*") long[] orderOverride); - public ShapeDescriptor(@Cast("const sd::DataType") int type, @Cast("const Nd4jLong") long length) { super((Pointer)null); allocate(type, length); } - private native void allocate(@Cast("const sd::DataType") int type, @Cast("const Nd4jLong") long length); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongPointer shape, int rank) { super((Pointer)null); allocate(type, order, shape, rank); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongPointer shape, int rank); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongBuffer shape, int rank) { super((Pointer)null); allocate(type, order, shape, rank); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongBuffer shape, int rank); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") long[] shape, int rank) { super((Pointer)null); allocate(type, order, shape, rank); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") long[] shape, int rank); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides, int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty) { super((Pointer)null); allocate(type, order, shape, strides, rank, ews, empty); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides, int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides, int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty) { super((Pointer)null); allocate(type, order, shape, strides, rank, ews, empty); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides, int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides, int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty) { super((Pointer)null); allocate(type, order, shape, strides, rank, ews, empty); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides, int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape) { super((Pointer)null); allocate(type, order, shape); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape) { super((Pointer)null); allocate(type, order, shape); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector long[] shape) { super((Pointer)null); allocate(type, order, shape); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector long[] shape); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("Nd4jLong*") @StdVector LongPointer strides) { super((Pointer)null); allocate(type, order, shape, strides); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("Nd4jLong*") @StdVector LongPointer strides); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("Nd4jLong*") @StdVector LongBuffer strides) { super((Pointer)null); allocate(type, order, shape, strides); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("Nd4jLong*") @StdVector LongBuffer strides); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("Nd4jLong*") @StdVector long[] strides) { super((Pointer)null); allocate(type, order, shape, strides); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("Nd4jLong*") @StdVector long[] strides); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("Nd4jLong*") @StdVector LongPointer strides, @Cast("const Nd4jLong") long ews) { super((Pointer)null); allocate(type, order, shape, strides, ews); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("Nd4jLong*") @StdVector LongPointer strides, @Cast("const Nd4jLong") long ews); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("Nd4jLong*") @StdVector LongBuffer strides, @Cast("const Nd4jLong") long ews) { super((Pointer)null); allocate(type, order, shape, strides, ews); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("Nd4jLong*") @StdVector LongBuffer strides, @Cast("const Nd4jLong") long ews); - public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("Nd4jLong*") @StdVector long[] strides, @Cast("const Nd4jLong") long ews) { super((Pointer)null); allocate(type, order, shape, strides, ews); } - private native void allocate(@Cast("const sd::DataType") int type, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("Nd4jLong*") @StdVector long[] strides, @Cast("const Nd4jLong") long ews); - public ShapeDescriptor() { super((Pointer)null); allocate(); } - private native void allocate(); - - public native int rank(); - public native @Cast("Nd4jLong") long ews(); - public native @Cast("Nd4jLong") long arrLength(); - public native char order(); - public native @Cast("sd::DataType") int dataType(); - public native @Cast("bool") boolean isEmpty(); - public native @Cast("Nd4jLong*") @StdVector LongPointer shape(); - public native @Cast("Nd4jLong*") @StdVector LongPointer strides(); - - // we use default copy assignment operator - public native @ByRef @Name("operator =") ShapeDescriptor put(@Const @ByRef ShapeDescriptor other); - - // we use default move assignment operator - - // equal to operator - public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef ShapeDescriptor other); - - // less than operator - public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef ShapeDescriptor other); - - public native @Cast("Nd4jLong*") LongPointer toShapeInfo(); - - - public static native @ByVal ShapeDescriptor emptyDescriptor(@Cast("const sd::DataType") int type); - public static native @ByVal ShapeDescriptor scalarDescriptor(@Cast("const sd::DataType") int type); - public static native @ByVal ShapeDescriptor vectorDescriptor(@Cast("const Nd4jLong") long length, @Cast("const sd::DataType") int type); - } + public ShapeDescriptor(@Const @ByRef ShapeDescriptor other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef ShapeDescriptor other); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/) { super((Pointer)null); allocate(shapeInfo, inheritDtype); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/) { super((Pointer)null); allocate(shapeInfo, inheritDtype); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo); + public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/) { super((Pointer)null); allocate(shapeInfo, inheritDtype); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("bool") boolean inheritDtype/*=true*/); + public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const sd::DataType") int dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const sd::DataType") int dtypeOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const sd::DataType") int dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const sd::DataType") int dtypeOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const sd::DataType") int dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const sd::DataType") int dtypeOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const Nd4jLong*") LongPointer dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const Nd4jLong*") LongPointer dtypeOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const Nd4jLong*") LongBuffer dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const Nd4jLong*") LongBuffer dtypeOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const Nd4jLong*") long[] dtypeOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const Nd4jLong*") long[] dtypeOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const Nd4jLong*") LongPointer dtypeOverride, + @Cast("const Nd4jLong*") LongPointer orderOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride, orderOverride); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer shapeInfo, + @Cast("const Nd4jLong*") LongPointer dtypeOverride, + @Cast("const Nd4jLong*") LongPointer orderOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const Nd4jLong*") LongBuffer dtypeOverride, + @Cast("const Nd4jLong*") LongBuffer orderOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride, orderOverride); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer shapeInfo, + @Cast("const Nd4jLong*") LongBuffer dtypeOverride, + @Cast("const Nd4jLong*") LongBuffer orderOverride); + public ShapeDescriptor(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const Nd4jLong*") long[] dtypeOverride, + @Cast("const Nd4jLong*") long[] orderOverride) { super((Pointer)null); allocate(shapeInfo, dtypeOverride, orderOverride); } + private native void allocate(@Cast("const Nd4jLong*") long[] shapeInfo, + @Cast("const Nd4jLong*") long[] dtypeOverride, + @Cast("const Nd4jLong*") long[] orderOverride); + public ShapeDescriptor(@Cast("const sd::DataType") int type, @Cast("const Nd4jLong") long length) { super((Pointer)null); allocate(type, length); } + private native void allocate(@Cast("const sd::DataType") int type, @Cast("const Nd4jLong") long length); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongPointer shape, int rank) { super((Pointer)null); allocate(type, order, shape, rank); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongPointer shape, int rank); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongBuffer shape, int rank) { super((Pointer)null); allocate(type, order, shape, rank); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongBuffer shape, int rank); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") long[] shape, int rank) { super((Pointer)null); allocate(type, order, shape, rank); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") long[] shape, int rank); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides, + int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty) { super((Pointer)null); allocate(type, order, shape, strides, rank, ews, empty); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer strides, + int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides, + int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty) { super((Pointer)null); allocate(type, order, shape, strides, rank, ews, empty); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer strides, + int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides, + int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty) { super((Pointer)null); allocate(type, order, shape, strides, rank, ews, empty); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] strides, + int rank, @Cast("Nd4jLong") long ews, @Cast("const bool") boolean empty); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongPointer shape) { super((Pointer)null); allocate(type, order, shape); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongPointer shape); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongBuffer shape) { super((Pointer)null); allocate(type, order, shape); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongBuffer shape); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector long[] shape) { super((Pointer)null); allocate(type, order, shape); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector long[] shape); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("Nd4jLong*") @StdVector LongPointer strides) { super((Pointer)null); allocate(type, order, shape, strides); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("Nd4jLong*") @StdVector LongPointer strides); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("Nd4jLong*") @StdVector LongBuffer strides) { super((Pointer)null); allocate(type, order, shape, strides); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("Nd4jLong*") @StdVector LongBuffer strides); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("Nd4jLong*") @StdVector long[] strides) { super((Pointer)null); allocate(type, order, shape, strides); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("Nd4jLong*") @StdVector long[] strides); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("Nd4jLong*") @StdVector LongPointer strides, + @Cast("const Nd4jLong") long ews) { super((Pointer)null); allocate(type, order, shape, strides, ews); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongPointer shape, + @Cast("Nd4jLong*") @StdVector LongPointer strides, + @Cast("const Nd4jLong") long ews); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("Nd4jLong*") @StdVector LongBuffer strides, + @Cast("const Nd4jLong") long ews) { super((Pointer)null); allocate(type, order, shape, strides, ews); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector LongBuffer shape, + @Cast("Nd4jLong*") @StdVector LongBuffer strides, + @Cast("const Nd4jLong") long ews); + public ShapeDescriptor(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("Nd4jLong*") @StdVector long[] strides, + @Cast("const Nd4jLong") long ews) { super((Pointer)null); allocate(type, order, shape, strides, ews); } + private native void allocate(@Cast("const sd::DataType") int type, byte order, + @Cast("Nd4jLong*") @StdVector long[] shape, + @Cast("Nd4jLong*") @StdVector long[] strides, + @Cast("const Nd4jLong") long ews); + public ShapeDescriptor() { super((Pointer)null); allocate(); } + private native void allocate(); + + public native int rank(); + public native @Cast("Nd4jLong") long ews(); + public native @Cast("Nd4jLong") long arrLength(); + public native char order(); + public native @Cast("sd::DataType") int dataType(); + public native @Cast("bool") boolean isEmpty(); + public native @Cast("Nd4jLong*") @StdVector LongPointer shape(); + public native @Cast("Nd4jLong*") @StdVector LongPointer strides(); + + // we use default copy assignment operator + public native @ByRef @Name("operator =") ShapeDescriptor put(@Const @ByRef ShapeDescriptor other); + + // we use default move assignment operator + + // equal to operator + public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef ShapeDescriptor other); + // less than operator + public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef ShapeDescriptor other); + + public native @Cast("Nd4jLong*") LongPointer toShapeInfo(); + + public static native @ByVal ShapeDescriptor emptyDescriptor(@Cast("const sd::DataType") int type); + public static native @ByVal ShapeDescriptor scalarDescriptor(@Cast("const sd::DataType") int type); + public static native @ByVal ShapeDescriptor vectorDescriptor(@Cast("const Nd4jLong") long length, + @Cast("const sd::DataType") int type); +} + // namespace sd // #ifndef __JAVACPP_HACK__ // #endif - -// #endif //DEV_TESTS_SHAPEDESCRIPTOR_H +// #endif // SD_SHAPEDESCRIPTOR_H // Parsed from array/TadDescriptor.h @@ -24658,67 +27212,103 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // @author raver119@gmail.com // -// #ifndef DEV_TESTS_TADDESCRIPTOR_H -// #define DEV_TESTS_TADDESCRIPTOR_H +// #ifndef SD_TADDESCRIPTOR_H +// #define SD_TADDESCRIPTOR_H -// #include "ShapeDescriptor.h" // #include - @Namespace("sd") @NoOffset public static class TadDescriptor extends Pointer { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public TadDescriptor(Pointer p) { super(p); } - - public TadDescriptor(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, int length, @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(originalShape, dimensions, length, keepUnitiesInShape); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, int length, @Cast("const bool") boolean keepUnitiesInShape/*=false*/); - public TadDescriptor(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, int length) { super((Pointer)null); allocate(originalShape, dimensions, length); } - private native void allocate(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, int length); - public TadDescriptor(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, int length, @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(originalShape, dimensions, length, keepUnitiesInShape); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, int length, @Cast("const bool") boolean keepUnitiesInShape/*=false*/); - public TadDescriptor(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, int length) { super((Pointer)null); allocate(originalShape, dimensions, length); } - private native void allocate(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, int length); - public TadDescriptor(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, int length, @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(originalShape, dimensions, length, keepUnitiesInShape); } - private native void allocate(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, int length, @Cast("const bool") boolean keepUnitiesInShape/*=false*/); - public TadDescriptor(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, int length) { super((Pointer)null); allocate(originalShape, dimensions, length); } - private native void allocate(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, int length); - public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntPointer dimensions, @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(descriptor, dimensions, keepUnitiesInShape); } - private native void allocate(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntPointer dimensions, @Cast("const bool") boolean keepUnitiesInShape/*=false*/); - public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntPointer dimensions) { super((Pointer)null); allocate(descriptor, dimensions); } - private native void allocate(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntPointer dimensions); - public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntBuffer dimensions, @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(descriptor, dimensions, keepUnitiesInShape); } - private native void allocate(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntBuffer dimensions, @Cast("const bool") boolean keepUnitiesInShape/*=false*/); - public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntBuffer dimensions) { super((Pointer)null); allocate(descriptor, dimensions); } - private native void allocate(@Const @ByRef ShapeDescriptor descriptor, @StdVector IntBuffer dimensions); - public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, @StdVector int[] dimensions, @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(descriptor, dimensions, keepUnitiesInShape); } - private native void allocate(@Const @ByRef ShapeDescriptor descriptor, @StdVector int[] dimensions, @Cast("const bool") boolean keepUnitiesInShape/*=false*/); - public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, @StdVector int[] dimensions) { super((Pointer)null); allocate(descriptor, dimensions); } - private native void allocate(@Const @ByRef ShapeDescriptor descriptor, @StdVector int[] dimensions); - public TadDescriptor(@Const @ByRef TadDescriptor other) { super((Pointer)null); allocate(other); } - private native void allocate(@Const @ByRef TadDescriptor other); - - // we use default copy assignment operator - public native @ByRef @Name("operator =") TadDescriptor put(@Const @ByRef TadDescriptor other); - - // we use default move assignment operator - - // equal to operator - public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef TadDescriptor other); - - // less than operator - public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef TadDescriptor other); - - public native @StdVector IntPointer axis(); - public native @ByRef ShapeDescriptor originalShape(); - public native @Const @ByRef ShapeDescriptor originalShapeConst(); - public native @Cast("bool") boolean areUnitiesinShape(); - } +// #include "ShapeDescriptor.h" +@Namespace("sd") @NoOffset public static class TadDescriptor extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public TadDescriptor(Pointer p) { super(p); } + + public TadDescriptor(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, + int length, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(originalShape, dimensions, length, keepUnitiesInShape); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, + int length, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/); + public TadDescriptor(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, + int length) { super((Pointer)null); allocate(originalShape, dimensions, length); } + private native void allocate(@Cast("const Nd4jLong*") LongPointer originalShape, @Const IntPointer dimensions, + int length); + public TadDescriptor(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, + int length, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(originalShape, dimensions, length, keepUnitiesInShape); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, + int length, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/); + public TadDescriptor(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, + int length) { super((Pointer)null); allocate(originalShape, dimensions, length); } + private native void allocate(@Cast("const Nd4jLong*") LongBuffer originalShape, @Const IntBuffer dimensions, + int length); + public TadDescriptor(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, + int length, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(originalShape, dimensions, length, keepUnitiesInShape); } + private native void allocate(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, + int length, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/); + public TadDescriptor(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, + int length) { super((Pointer)null); allocate(originalShape, dimensions, length); } + private native void allocate(@Cast("const Nd4jLong*") long[] originalShape, @Const int[] dimensions, + int length); + public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntPointer dimensions, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(descriptor, dimensions, keepUnitiesInShape); } + private native void allocate(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntPointer dimensions, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/); + public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntPointer dimensions) { super((Pointer)null); allocate(descriptor, dimensions); } + private native void allocate(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntPointer dimensions); + public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntBuffer dimensions, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(descriptor, dimensions, keepUnitiesInShape); } + private native void allocate(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntBuffer dimensions, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/); + public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntBuffer dimensions) { super((Pointer)null); allocate(descriptor, dimensions); } + private native void allocate(@Const @ByRef ShapeDescriptor descriptor, + @StdVector IntBuffer dimensions); + public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, + @StdVector int[] dimensions, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/) { super((Pointer)null); allocate(descriptor, dimensions, keepUnitiesInShape); } + private native void allocate(@Const @ByRef ShapeDescriptor descriptor, + @StdVector int[] dimensions, + @Cast("const bool") boolean keepUnitiesInShape/*=false*/); + public TadDescriptor(@Const @ByRef ShapeDescriptor descriptor, + @StdVector int[] dimensions) { super((Pointer)null); allocate(descriptor, dimensions); } + private native void allocate(@Const @ByRef ShapeDescriptor descriptor, + @StdVector int[] dimensions); + public TadDescriptor(@Const @ByRef TadDescriptor other) { super((Pointer)null); allocate(other); } + private native void allocate(@Const @ByRef TadDescriptor other); + + // we use default copy assignment operator + public native @ByRef @Name("operator =") TadDescriptor put(@Const @ByRef TadDescriptor other); + + // we use default move assignment operator + + // equal to operator + public native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef TadDescriptor other); + + // less than operator + public native @Cast("bool") @Name("operator <") boolean lessThan(@Const @ByRef TadDescriptor other); + + public native @StdVector IntPointer axis(); + public native @ByRef ShapeDescriptor originalShape(); + public native @Const @ByRef ShapeDescriptor originalShapeConst(); + public native @Cast("bool") boolean areUnitiesinShape(); +} + // namespace sd // #ifndef __JAVACPP_HACK__ // #endif - -// #endif //DEV_TESTS_TADDESCRIPTOR_H +// #endif // SD_TADDESCRIPTOR_H // Parsed from helpers/DebugInfo.h @@ -24746,48 +27336,48 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #ifndef LIBND4J__DEBUG_INFO_HELPER__H // #define LIBND4J__DEBUG_INFO_HELPER__H -// #include -// #include -// #include // #include -// #include -// #include // #include +// #include +// #include +// #include +// #include + +// #include // #ifdef __CUDACC__ // #endif - @Namespace("sd") public static class DebugInfo extends Pointer { - static { Loader.load(); } - /** Default native constructor. */ - public DebugInfo() { super((Pointer)null); allocate(); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public DebugInfo(long size) { super((Pointer)null); allocateArray(size); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public DebugInfo(Pointer p) { super(p); } - private native void allocate(); - private native void allocateArray(long size); - @Override public DebugInfo position(long position) { - return (DebugInfo)super.position(position); - } - - public native double _minValue(); public native DebugInfo _minValue(double setter); - public native double _maxValue(); public native DebugInfo _maxValue(double setter); - public native double _meanValue(); public native DebugInfo _meanValue(double setter); - public native double _stdDevValue(); public native DebugInfo _stdDevValue(double setter); - public native @Cast("Nd4jLong") long _zeroCount(); public native DebugInfo _zeroCount(long setter); - public native @Cast("Nd4jLong") long _positiveCount(); public native DebugInfo _positiveCount(long setter); - public native @Cast("Nd4jLong") long _negativeCount(); public native DebugInfo _negativeCount(long setter); - public native @Cast("Nd4jLong") long _infCount(); public native DebugInfo _infCount(long setter); - public native @Cast("Nd4jLong") long _nanCount(); public native DebugInfo _nanCount(long setter); +@Namespace("sd") public static class DebugInfo extends Pointer { + static { Loader.load(); } + /** Default native constructor. */ + public DebugInfo() { super((Pointer)null); allocate(); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public DebugInfo(long size) { super((Pointer)null); allocateArray(size); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DebugInfo(Pointer p) { super(p); } + private native void allocate(); + private native void allocateArray(long size); + @Override public DebugInfo position(long position) { + return (DebugInfo)super.position(position); } - @Namespace("sd") public static native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef DebugInfo first, @Const @ByRef DebugInfo second); - + public native double _minValue(); public native DebugInfo _minValue(double setter); + public native double _maxValue(); public native DebugInfo _maxValue(double setter); + public native double _meanValue(); public native DebugInfo _meanValue(double setter); + public native double _stdDevValue(); public native DebugInfo _stdDevValue(double setter); + public native @Cast("Nd4jLong") long _zeroCount(); public native DebugInfo _zeroCount(long setter); + public native @Cast("Nd4jLong") long _positiveCount(); public native DebugInfo _positiveCount(long setter); + public native @Cast("Nd4jLong") long _negativeCount(); public native DebugInfo _negativeCount(long setter); + public native @Cast("Nd4jLong") long _infCount(); public native DebugInfo _infCount(long setter); + public native @Cast("Nd4jLong") long _nanCount(); public native DebugInfo _nanCount(long setter); +} +@Namespace("sd") public static native @Cast("bool") @Name("operator ==") boolean equals(@Const @ByRef DebugInfo first, @Const @ByRef DebugInfo second); + // namespace sd -// #endif //LIBND4J_DEBUGHELPER_H +// #endif // LIBND4J_DEBUGHELPER_H // Parsed from ops/declarable/headers/third_party.h @@ -24816,25 +27406,25 @@ public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPoint // #define LIBND4J_HEADERS_TPARTY_H // #include -// #if NOT_EXCLUDED(OP_firas_sparse) - @Namespace("sd::ops") public static class firas_sparse extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public firas_sparse(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public firas_sparse(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public firas_sparse position(long position) { - return (firas_sparse)super.position(position); - } - +// #if NOT_EXCLUDED(OP_firas_sparse) +@Namespace("sd::ops") public static class firas_sparse extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public firas_sparse(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public firas_sparse(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public firas_sparse position(long position) { + return (firas_sparse)super.position(position); + } + public firas_sparse() { super((Pointer)null); allocate(); } private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } -// #endif - - +// #endif + // namespace ops + // namespace sd // #endif diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java index f10410314368..a2514d522ab9 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java @@ -35,6 +35,9 @@ @Properties(inherit = openblas.class, target = "org.nd4j.nativeblas.Nd4jCpu", helper = "org.nd4j.nativeblas.Nd4jCpuHelper", value = {@Platform(define = "LIBND4J_ALL_OPS", include = { "memory/MemoryType.h", + "memory/MemoryZone.h", + "memory/MemoryDescriptor.h", + "memory/GraphMemoryManager.h", "array/DataType.h", "array/DataBuffer.h", "array/PointerDeallocator.h", @@ -67,7 +70,6 @@ "graph/FlowPath.h", "graph/Intervals.h", "graph/Stash.h", - "graph/GraphState.h", "graph/VariableSpace.h", "helpers/helper_generator.h", "graph/profiling/GraphProfile.h", @@ -80,8 +82,6 @@ "array/ShapeList.h", "system/type_boilerplate.h", "system/op_boilerplate.h", - //"enum_boilerplate.h", - //"op_enums.h", "ops/InputType.h", "ops/declarable/OpDescriptor.h", "ops/declarable/PlatformHelper.h", @@ -159,7 +159,7 @@ public void init(Logger logger, java.util.Properties properties, String encoding @Override public void map(InfoMap infoMap) { - infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE", + infoMap.put(new Info("thread_local", "ND4J_EXPORT", "SD_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE", "_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations()) .put(new Info("NativeOps.h").objectify()) .put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack")) @@ -174,19 +174,13 @@ public void map(InfoMap infoMap) { .put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) .put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator")) .put(new Info("OpaqueLaunchContext").pointerTypes("OpaqueLaunchContext")) - .put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String", - "@Cast(\"char*\") BytePointer")) - .put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer", - "@Cast(\"char*\") String")) + .put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String", "@Cast(\"char*\") BytePointer")) + .put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer", "@Cast(\"char*\") String")) .put(new Info("Nd4jPointer").cast().valueTypes("Pointer").pointerTypes("PointerPointer")) - .put(new Info("Nd4jLong").cast().valueTypes("long").pointerTypes("LongPointer", "LongBuffer", - "long[]")) - .put(new Info("Nd4jStatus").cast().valueTypes("int").pointerTypes("IntPointer", "IntBuffer", - "int[]")) - .put(new Info("float16").cast().valueTypes("short").pointerTypes("ShortPointer", "ShortBuffer", - "short[]")) - .put(new Info("bfloat16").cast().valueTypes("short").pointerTypes("ShortPointer", "ShortBuffer", - "short[]")); + .put(new Info("Nd4jLong").cast().valueTypes("long").pointerTypes("LongPointer", "LongBuffer", "long[]")) + .put(new Info("Nd4jStatus").cast().valueTypes("int").pointerTypes("IntPointer", "IntBuffer", "int[]")) + .put(new Info("float16").cast().valueTypes("short").pointerTypes("ShortPointer", "ShortBuffer", "short[]")) + .put(new Info("bfloat16").cast().valueTypes("short").pointerTypes("ShortPointer", "ShortBuffer", "short[]")); infoMap.put(new Info("__CUDACC__", "MAX_UINT", "HAVE_MKLDNN", "__CUDABLAS__").define(false)) .put(new Info("__JAVACPP_HACK__", "LIBND4J_ALL_OPS").define(true))